]> Git Repo - linux.git/blob - tools/testing/selftests/bpf/prog_tests/mptcp.c
Merge patch series "riscv: Extension parsing fixes"
[linux.git] / tools / testing / selftests / bpf / prog_tests / mptcp.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2020, Tessares SA. */
3 /* Copyright (c) 2022, SUSE. */
4
5 #include <linux/const.h>
6 #include <netinet/in.h>
7 #include <test_progs.h>
8 #include "cgroup_helpers.h"
9 #include "network_helpers.h"
10 #include "mptcp_sock.skel.h"
11 #include "mptcpify.skel.h"
12
/* Name of the network namespace every subtest creates, enters and deletes. */
#define NS_TEST "mptcp_ns"

/*
 * Fallback values for MPTCP-related constants.  Each one is defined here
 * only when no previously included header has already provided it.
 */
#if !defined(IPPROTO_MPTCP)
# define IPPROTO_MPTCP 262
#endif

#if !defined(SOL_MPTCP)
# define SOL_MPTCP 284
#endif

#if !defined(MPTCP_INFO)
# define MPTCP_INFO 1
#endif

/* Bits tested in __mptcp_info.mptcpi_flags by verify_mptcpify(). */
#if !defined(MPTCP_INFO_FLAG_FALLBACK)
# define MPTCP_INFO_FLAG_FALLBACK _BITUL(0)
#endif

#if !defined(MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED)
# define MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED _BITUL(1)
#endif

/* Size of the buffers holding a congestion-control algorithm name. */
#if !defined(TCP_CA_NAME_MAX)
# define TCP_CA_NAME_MAX 16
#endif
35
/*
 * Buffer type handed to getsockopt(SOL_MPTCP, MPTCP_INFO) by
 * verify_mptcpify(); only mptcpi_flags is inspected in this file.
 *
 * NOTE(review): presumably mirrors the kernel's UAPI struct mptcp_info;
 * field order and widths must stay in sync with it - confirm against
 * the UAPI header before changing anything here.
 */
struct __mptcp_info {
	/* Six single-byte counters/limits. */
	__u8 mptcpi_subflows;
	__u8 mptcpi_add_addr_signal;
	__u8 mptcpi_add_addr_accepted;
	__u8 mptcpi_subflows_max;
	__u8 mptcpi_add_addr_signal_max;
	__u8 mptcpi_add_addr_accepted_max;

	/* Tested against the MPTCP_INFO_FLAG_* bits. */
	__u32 mptcpi_flags;
	__u32 mptcpi_token;

	__u64 mptcpi_write_seq;
	__u64 mptcpi_snd_una;
	__u64 mptcpi_rcv_nxt;

	__u8 mptcpi_local_addr_used;
	__u8 mptcpi_local_addr_max;
	__u8 mptcpi_csum_enabled;

	__u32 mptcpi_retransmits;

	__u64 mptcpi_bytes_retrans;
	__u64 mptcpi_bytes_sent;
	__u64 mptcpi_bytes_received;
	__u64 mptcpi_bytes_acked;
};
57
58 struct mptcp_storage {
59         __u32 invoked;
60         __u32 is_mptcp;
61         struct sock *sk;
62         __u32 token;
63         struct sock *first;
64         char ca_name[TCP_CA_NAME_MAX];
65 };
66
67 static struct nstoken *create_netns(void)
68 {
69         SYS(fail, "ip netns add %s", NS_TEST);
70         SYS(fail, "ip -net %s link set dev lo up", NS_TEST);
71
72         return open_netns(NS_TEST);
73 fail:
74         return NULL;
75 }
76
77 static void cleanup_netns(struct nstoken *nstoken)
78 {
79         if (nstoken)
80                 close_netns(nstoken);
81
82         SYS_NOFAIL("ip netns del %s", NS_TEST);
83 }
84
85 static int start_mptcp_server(int family, const char *addr_str, __u16 port,
86                               int timeout_ms)
87 {
88         struct network_helper_opts opts = {
89                 .timeout_ms     = timeout_ms,
90                 .proto          = IPPROTO_MPTCP,
91         };
92         struct sockaddr_storage addr;
93         socklen_t addrlen;
94
95         if (make_sockaddr(family, addr_str, port, &addr, &addrlen))
96                 return -1;
97
98         return start_server_addr(SOCK_STREAM, &addr, addrlen, &opts);
99 }
100
101 static int verify_tsk(int map_fd, int client_fd)
102 {
103         int err, cfd = client_fd;
104         struct mptcp_storage val;
105
106         err = bpf_map_lookup_elem(map_fd, &cfd, &val);
107         if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
108                 return err;
109
110         if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
111                 err++;
112
113         if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
114                 err++;
115
116         return err;
117 }
118
119 static void get_msk_ca_name(char ca_name[])
120 {
121         size_t len;
122         int fd;
123
124         fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
125         if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
126                 return;
127
128         len = read(fd, ca_name, TCP_CA_NAME_MAX);
129         if (!ASSERT_GT(len, 0, "failed to read ca_name"))
130                 goto err;
131
132         if (len > 0 && ca_name[len - 1] == '\n')
133                 ca_name[len - 1] = '\0';
134
135 err:
136         close(fd);
137 }
138
139 static int verify_msk(int map_fd, int client_fd, __u32 token)
140 {
141         char ca_name[TCP_CA_NAME_MAX];
142         int err, cfd = client_fd;
143         struct mptcp_storage val;
144
145         if (!ASSERT_GT(token, 0, "invalid token"))
146                 return -1;
147
148         get_msk_ca_name(ca_name);
149
150         err = bpf_map_lookup_elem(map_fd, &cfd, &val);
151         if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
152                 return err;
153
154         if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
155                 err++;
156
157         if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
158                 err++;
159
160         if (!ASSERT_EQ(val.token, token, "unexpected token"))
161                 err++;
162
163         if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
164                 err++;
165
166         if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
167                 err++;
168
169         return err;
170 }
171
172 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
173 {
174         int client_fd, prog_fd, map_fd, err;
175         struct mptcp_sock *sock_skel;
176
177         sock_skel = mptcp_sock__open_and_load();
178         if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
179                 return libbpf_get_error(sock_skel);
180
181         err = mptcp_sock__attach(sock_skel);
182         if (!ASSERT_OK(err, "skel_attach"))
183                 goto out;
184
185         prog_fd = bpf_program__fd(sock_skel->progs._sockops);
186         map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
187         err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
188         if (!ASSERT_OK(err, "bpf_prog_attach"))
189                 goto out;
190
191         client_fd = connect_to_fd(server_fd, 0);
192         if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
193                 err = -EIO;
194                 goto out;
195         }
196
197         err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
198                           verify_tsk(map_fd, client_fd);
199
200         close(client_fd);
201
202 out:
203         mptcp_sock__destroy(sock_skel);
204         return err;
205 }
206
/*
 * "base" subtest: inside a fresh cgroup and network namespace, run
 * run_test() once against a plain TCP server and once against an MPTCP
 * server.  A failure to start the TCP server does not skip the MPTCP run.
 */
static void test_base(void)
{
	struct nstoken *nstoken = NULL;
	int cgroup_fd;
	int server_fd;

	cgroup_fd = test__join_cgroup("/mptcp");
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
		return;

	nstoken = create_netns();
	if (!ASSERT_OK_PTR(nstoken, "create_netns"))
		goto cleanup;

	/* without MPTCP */
	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
	if (ASSERT_GE(server_fd, 0, "start_server")) {
		ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
		close(server_fd);
	}

	/* with MPTCP */
	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
	if (ASSERT_GE(server_fd, 0, "start_mptcp_server")) {
		ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
		close(server_fd);
	}

cleanup:
	cleanup_netns(nstoken);
	close(cgroup_fd);
}
243
/* Write one byte (0x55) to @fd and assert that exactly one byte was written. */
static void send_byte(int fd)
{
	char payload = 0x55;
	ssize_t written;

	written = write(fd, &payload, sizeof(payload));
	ASSERT_EQ(written, 1, "send single byte");
}
250
251 static int verify_mptcpify(int server_fd, int client_fd)
252 {
253         struct __mptcp_info info;
254         socklen_t optlen;
255         int protocol;
256         int err = 0;
257
258         optlen = sizeof(protocol);
259         if (!ASSERT_OK(getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen),
260                        "getsockopt(SOL_PROTOCOL)"))
261                 return -1;
262
263         if (!ASSERT_EQ(protocol, IPPROTO_MPTCP, "protocol isn't MPTCP"))
264                 err++;
265
266         optlen = sizeof(info);
267         if (!ASSERT_OK(getsockopt(client_fd, SOL_MPTCP, MPTCP_INFO, &info, &optlen),
268                        "getsockopt(MPTCP_INFO)"))
269                 return -1;
270
271         if (!ASSERT_GE(info.mptcpi_flags, 0, "unexpected mptcpi_flags"))
272                 err++;
273         if (!ASSERT_FALSE(info.mptcpi_flags & MPTCP_INFO_FLAG_FALLBACK,
274                           "MPTCP fallback"))
275                 err++;
276         if (!ASSERT_TRUE(info.mptcpi_flags & MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED,
277                          "no remote key received"))
278                 err++;
279
280         return err;
281 }
282
283 static int run_mptcpify(int cgroup_fd)
284 {
285         int server_fd, client_fd, err = 0;
286         struct mptcpify *mptcpify_skel;
287
288         mptcpify_skel = mptcpify__open_and_load();
289         if (!ASSERT_OK_PTR(mptcpify_skel, "skel_open_load"))
290                 return libbpf_get_error(mptcpify_skel);
291
292         mptcpify_skel->bss->pid = getpid();
293
294         err = mptcpify__attach(mptcpify_skel);
295         if (!ASSERT_OK(err, "skel_attach"))
296                 goto out;
297
298         /* without MPTCP */
299         server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
300         if (!ASSERT_GE(server_fd, 0, "start_server")) {
301                 err = -EIO;
302                 goto out;
303         }
304
305         client_fd = connect_to_fd(server_fd, 0);
306         if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
307                 err = -EIO;
308                 goto close_server;
309         }
310
311         send_byte(client_fd);
312
313         err = verify_mptcpify(server_fd, client_fd);
314
315         close(client_fd);
316 close_server:
317         close(server_fd);
318 out:
319         mptcpify__destroy(mptcpify_skel);
320         return err;
321 }
322
/*
 * "mptcpify" subtest: join the /mptcpify cgroup, enter a fresh network
 * namespace and run run_mptcpify().  The namespace and the cgroup fd are
 * released whether or not namespace creation succeeded.
 */
static void test_mptcpify(void)
{
	struct nstoken *nstoken = NULL;
	int cgroup_fd;

	cgroup_fd = test__join_cgroup("/mptcpify");
	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
		return;

	nstoken = create_netns();
	if (ASSERT_OK_PTR(nstoken, "create_netns"))
		ASSERT_OK(run_mptcpify(cgroup_fd), "run_mptcpify");

	cleanup_netns(nstoken);
	close(cgroup_fd);
}
342
/*
 * Test entry point: runs each subtest for which test__start_subtest()
 * returns non-zero.
 */
void test_mptcp(void)
{
	if (test__start_subtest("base")) {
		test_base();
	}

	if (test__start_subtest("mptcpify")) {
		test_mptcpify();
	}
}
This page took 0.05706 seconds and 4 git commands to generate.