Git Repo - linux.git/blob - net/mptcp/pm_userspace.c
Merge tag 'amd-drm-next-6.5-2023-06-09' of https://gitlab.freedesktop.org/agd5f/linux...
[linux.git] / net / mptcp / pm_userspace.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2022, Intel Corporation.
5  */
6
7 #include "protocol.h"
8 #include "mib.h"
9
10 void mptcp_free_local_addr_list(struct mptcp_sock *msk)
11 {
12         struct mptcp_pm_addr_entry *entry, *tmp;
13         struct sock *sk = (struct sock *)msk;
14         LIST_HEAD(free_list);
15
16         if (!mptcp_pm_is_userspace(msk))
17                 return;
18
19         spin_lock_bh(&msk->pm.lock);
20         list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list);
21         spin_unlock_bh(&msk->pm.lock);
22
23         list_for_each_entry_safe(entry, tmp, &free_list, list) {
24                 sock_kfree_s(sk, entry, sizeof(*entry));
25         }
26 }
27
28 static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
29                                                     struct mptcp_pm_addr_entry *entry)
30 {
31         DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
32         struct mptcp_pm_addr_entry *match = NULL;
33         struct sock *sk = (struct sock *)msk;
34         struct mptcp_pm_addr_entry *e;
35         bool addr_match = false;
36         bool id_match = false;
37         int ret = -EINVAL;
38
39         bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
40
41         spin_lock_bh(&msk->pm.lock);
42         list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) {
43                 addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true);
44                 if (addr_match && entry->addr.id == 0)
45                         entry->addr.id = e->addr.id;
46                 id_match = (e->addr.id == entry->addr.id);
47                 if (addr_match && id_match) {
48                         match = e;
49                         break;
50                 } else if (addr_match || id_match) {
51                         break;
52                 }
53                 __set_bit(e->addr.id, id_bitmap);
54         }
55
56         if (!match && !addr_match && !id_match) {
57                 /* Memory for the entry is allocated from the
58                  * sock option buffer.
59                  */
60                 e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC);
61                 if (!e) {
62                         ret = -ENOMEM;
63                         goto append_err;
64                 }
65
66                 *e = *entry;
67                 if (!e->addr.id)
68                         e->addr.id = find_next_zero_bit(id_bitmap,
69                                                         MPTCP_PM_MAX_ADDR_ID + 1,
70                                                         1);
71                 list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list);
72                 ret = e->addr.id;
73         } else if (match) {
74                 ret = entry->addr.id;
75         }
76
77 append_err:
78         spin_unlock_bh(&msk->pm.lock);
79         return ret;
80 }
81
82 int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
83                                                    unsigned int id,
84                                                    u8 *flags, int *ifindex)
85 {
86         struct mptcp_pm_addr_entry *entry, *match = NULL;
87
88         *flags = 0;
89         *ifindex = 0;
90
91         spin_lock_bh(&msk->pm.lock);
92         list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
93                 if (id == entry->addr.id) {
94                         match = entry;
95                         break;
96                 }
97         }
98         spin_unlock_bh(&msk->pm.lock);
99         if (match) {
100                 *flags = match->flags;
101                 *ifindex = match->ifindex;
102         }
103
104         return 0;
105 }
106
107 int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
108                                     struct mptcp_addr_info *skc)
109 {
110         struct mptcp_pm_addr_entry new_entry;
111         __be16 msk_sport =  ((struct inet_sock *)
112                              inet_sk((struct sock *)msk))->inet_sport;
113
114         memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry));
115         new_entry.addr = *skc;
116         new_entry.addr.id = 0;
117         new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
118
119         if (new_entry.addr.port == msk_sport)
120                 new_entry.addr.port = 0;
121
122         return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry);
123 }
124
125 int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info)
126 {
127         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
128         struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR];
129         struct mptcp_pm_addr_entry addr_val;
130         struct mptcp_sock *msk;
131         int err = -EINVAL;
132         u32 token_val;
133
134         if (!addr || !token) {
135                 GENL_SET_ERR_MSG(info, "missing required inputs");
136                 return err;
137         }
138
139         token_val = nla_get_u32(token);
140
141         msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
142         if (!msk) {
143                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
144                 return err;
145         }
146
147         if (!mptcp_pm_is_userspace(msk)) {
148                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
149                 goto announce_err;
150         }
151
152         err = mptcp_pm_parse_entry(addr, info, true, &addr_val);
153         if (err < 0) {
154                 GENL_SET_ERR_MSG(info, "error parsing local address");
155                 goto announce_err;
156         }
157
158         if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
159                 GENL_SET_ERR_MSG(info, "invalid addr id or flags");
160                 err = -EINVAL;
161                 goto announce_err;
162         }
163
164         err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val);
165         if (err < 0) {
166                 GENL_SET_ERR_MSG(info, "did not match address and id");
167                 goto announce_err;
168         }
169
170         lock_sock((struct sock *)msk);
171         spin_lock_bh(&msk->pm.lock);
172
173         if (mptcp_pm_alloc_anno_list(msk, &addr_val)) {
174                 mptcp_pm_announce_addr(msk, &addr_val.addr, false);
175                 mptcp_pm_nl_addr_send_ack(msk);
176         }
177
178         spin_unlock_bh(&msk->pm.lock);
179         release_sock((struct sock *)msk);
180
181         err = 0;
182  announce_err:
183         sock_put((struct sock *)msk);
184         return err;
185 }
186
187 int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info)
188 {
189         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
190         struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID];
191         struct mptcp_pm_addr_entry *match = NULL;
192         struct mptcp_pm_addr_entry *entry;
193         struct mptcp_sock *msk;
194         LIST_HEAD(free_list);
195         int err = -EINVAL;
196         u32 token_val;
197         u8 id_val;
198
199         if (!id || !token) {
200                 GENL_SET_ERR_MSG(info, "missing required inputs");
201                 return err;
202         }
203
204         id_val = nla_get_u8(id);
205         token_val = nla_get_u32(token);
206
207         msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
208         if (!msk) {
209                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
210                 return err;
211         }
212
213         if (!mptcp_pm_is_userspace(msk)) {
214                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
215                 goto remove_err;
216         }
217
218         lock_sock((struct sock *)msk);
219
220         list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
221                 if (entry->addr.id == id_val) {
222                         match = entry;
223                         break;
224                 }
225         }
226
227         if (!match) {
228                 GENL_SET_ERR_MSG(info, "address with specified id not found");
229                 release_sock((struct sock *)msk);
230                 goto remove_err;
231         }
232
233         list_move(&match->list, &free_list);
234
235         mptcp_pm_remove_addrs_and_subflows(msk, &free_list);
236
237         release_sock((struct sock *)msk);
238
239         list_for_each_entry_safe(match, entry, &free_list, list) {
240                 sock_kfree_s((struct sock *)msk, match, sizeof(*match));
241         }
242
243         err = 0;
244  remove_err:
245         sock_put((struct sock *)msk);
246         return err;
247 }
248
249 int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info)
250 {
251         struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
252         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
253         struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
254         struct mptcp_addr_info addr_r;
255         struct mptcp_addr_info addr_l;
256         struct mptcp_sock *msk;
257         int err = -EINVAL;
258         struct sock *sk;
259         u32 token_val;
260
261         if (!laddr || !raddr || !token) {
262                 GENL_SET_ERR_MSG(info, "missing required inputs");
263                 return err;
264         }
265
266         token_val = nla_get_u32(token);
267
268         msk = mptcp_token_get_sock(genl_info_net(info), token_val);
269         if (!msk) {
270                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
271                 return err;
272         }
273
274         if (!mptcp_pm_is_userspace(msk)) {
275                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
276                 goto create_err;
277         }
278
279         err = mptcp_pm_parse_addr(laddr, info, &addr_l);
280         if (err < 0) {
281                 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
282                 goto create_err;
283         }
284
285         if (addr_l.id == 0) {
286                 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local addr id");
287                 err = -EINVAL;
288                 goto create_err;
289         }
290
291         err = mptcp_pm_parse_addr(raddr, info, &addr_r);
292         if (err < 0) {
293                 NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
294                 goto create_err;
295         }
296
297         sk = (struct sock *)msk;
298
299         if (!mptcp_pm_addr_families_match(sk, &addr_l, &addr_r)) {
300                 GENL_SET_ERR_MSG(info, "families mismatch");
301                 err = -EINVAL;
302                 goto create_err;
303         }
304
305         lock_sock(sk);
306
307         err = __mptcp_subflow_connect(sk, &addr_l, &addr_r);
308
309         release_sock(sk);
310
311  create_err:
312         sock_put((struct sock *)msk);
313         return err;
314 }
315
316 static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk,
317                                       const struct mptcp_addr_info *local,
318                                       const struct mptcp_addr_info *remote)
319 {
320         struct mptcp_subflow_context *subflow;
321
322         if (local->family != remote->family)
323                 return NULL;
324
325         mptcp_for_each_subflow(msk, subflow) {
326                 const struct inet_sock *issk;
327                 struct sock *ssk;
328
329                 ssk = mptcp_subflow_tcp_sock(subflow);
330
331                 if (local->family != ssk->sk_family)
332                         continue;
333
334                 issk = inet_sk(ssk);
335
336                 switch (ssk->sk_family) {
337                 case AF_INET:
338                         if (issk->inet_saddr != local->addr.s_addr ||
339                             issk->inet_daddr != remote->addr.s_addr)
340                                 continue;
341                         break;
342 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
343                 case AF_INET6: {
344                         const struct ipv6_pinfo *pinfo = inet6_sk(ssk);
345
346                         if (!ipv6_addr_equal(&local->addr6, &pinfo->saddr) ||
347                             !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr))
348                                 continue;
349                         break;
350                 }
351 #endif
352                 default:
353                         continue;
354                 }
355
356                 if (issk->inet_sport == local->port &&
357                     issk->inet_dport == remote->port)
358                         return ssk;
359         }
360
361         return NULL;
362 }
363
364 int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info)
365 {
366         struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
367         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
368         struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
369         struct mptcp_addr_info addr_l;
370         struct mptcp_addr_info addr_r;
371         struct mptcp_sock *msk;
372         struct sock *sk, *ssk;
373         int err = -EINVAL;
374         u32 token_val;
375
376         if (!laddr || !raddr || !token) {
377                 GENL_SET_ERR_MSG(info, "missing required inputs");
378                 return err;
379         }
380
381         token_val = nla_get_u32(token);
382
383         msk = mptcp_token_get_sock(genl_info_net(info), token_val);
384         if (!msk) {
385                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
386                 return err;
387         }
388
389         if (!mptcp_pm_is_userspace(msk)) {
390                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
391                 goto destroy_err;
392         }
393
394         err = mptcp_pm_parse_addr(laddr, info, &addr_l);
395         if (err < 0) {
396                 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
397                 goto destroy_err;
398         }
399
400         err = mptcp_pm_parse_addr(raddr, info, &addr_r);
401         if (err < 0) {
402                 NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
403                 goto destroy_err;
404         }
405
406         if (addr_l.family != addr_r.family) {
407                 GENL_SET_ERR_MSG(info, "address families do not match");
408                 err = -EINVAL;
409                 goto destroy_err;
410         }
411
412         if (!addr_l.port || !addr_r.port) {
413                 GENL_SET_ERR_MSG(info, "missing local or remote port");
414                 err = -EINVAL;
415                 goto destroy_err;
416         }
417
418         sk = (struct sock *)msk;
419         lock_sock(sk);
420         ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r);
421         if (ssk) {
422                 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
423
424                 mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN);
425                 mptcp_close_ssk(sk, ssk, subflow);
426                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);
427                 err = 0;
428         } else {
429                 err = -ESRCH;
430         }
431         release_sock(sk);
432
433 destroy_err:
434         sock_put((struct sock *)msk);
435         return err;
436 }
437
438 int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token,
439                                  struct mptcp_pm_addr_entry *loc,
440                                  struct mptcp_pm_addr_entry *rem, u8 bkup)
441 {
442         struct mptcp_sock *msk;
443         int ret = -EINVAL;
444         u32 token_val;
445
446         token_val = nla_get_u32(token);
447
448         msk = mptcp_token_get_sock(net, token_val);
449         if (!msk)
450                 return ret;
451
452         if (!mptcp_pm_is_userspace(msk))
453                 goto set_flags_err;
454
455         if (loc->addr.family == AF_UNSPEC ||
456             rem->addr.family == AF_UNSPEC)
457                 goto set_flags_err;
458
459         lock_sock((struct sock *)msk);
460         ret = mptcp_pm_nl_mp_prio_send_ack(msk, &loc->addr, &rem->addr, bkup);
461         release_sock((struct sock *)msk);
462
463 set_flags_err:
464         sock_put((struct sock *)msk);
465         return ret;
466 }
This page took 0.057645 seconds and 4 git commands to generate.