Git Repo - J-linux.git/blob - tools/testing/selftests/bpf/prog_tests/socket_helpers.h
Merge tag 'vfs-6.13-rc7.fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/vfs/vfs
[J-linux.git] / tools / testing / selftests / bpf / prog_tests / socket_helpers.h
1 /* SPDX-License-Identifier: GPL-2.0 */
2
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5
6 #include <linux/vm_sockets.h>
7
8 /* include/linux/net.h */
9 #define SOCK_TYPE_MASK 0xf
10
11 #define IO_TIMEOUT_SEC 30
12 #define MAX_STRERR_LEN 256
13
14 /* workaround for older vm_sockets.h */
15 #ifndef VMADDR_CID_LOCAL
16 #define VMADDR_CID_LOCAL 1
17 #endif
18
/*
 * Ownership-transfer helpers, modelled on include/linux/cleanup.h.
 *
 * __get_and_null(p, nullvalue) evaluates to the current value of the lvalue
 * @p and then overwrites @p with @nullvalue. Only the address of @p is taken,
 * so @p itself is evaluated once.
 */
#define __get_and_null(p, nullvalue)                                           \
        ({                                                                     \
                __auto_type __slot = &(p);                                     \
                __auto_type __prev = *__slot;                                  \
                *__slot = nullvalue;                                           \
                __prev;                                                        \
        })

/* Yield the descriptor stored in @fd and leave -EBADF behind in its place. */
#define take_fd(fd) __get_and_null(fd, -EBADF)
29
30 /* Wrappers that fail the test on error and report it. */
31
/*
 * Print a diagnostic through error_at_line() (status 0, with the calling
 * function's name passed in the file-name slot and the current line number)
 * and then mark the test as failed with CHECK_FAIL(true).
 * @errnum is forwarded as error_at_line()'s errnum argument.
 */
#define _FAIL(errnum, ...)                                                     \
        ({                                                                     \
                error_at_line(0, (errnum), __func__, __LINE__, __VA_ARGS__);   \
                CHECK_FAIL(true);                                              \
        })
/* Fail with a printf-style message and no error number. */
#define FAIL(...) _FAIL(0, __VA_ARGS__)
/* Fail with a printf-style message plus the current errno. */
#define FAIL_ERRNO(...) _FAIL(errno, __VA_ARGS__)
/* Fail with "@msg: <text produced by libbpf_strerror() for @err>". */
#define FAIL_LIBBPF(err, msg)                                                  \
        ({                                                                     \
                char __errstr[MAX_STRERR_LEN];                                 \
                libbpf_strerror((err), __errstr, sizeof(__errstr));            \
                FAIL("%s: %s", (msg), __errstr);                               \
        })
45
46
/*
 * accept_timeout() with an IO_TIMEOUT_SEC timeout. Reports "accept" through
 * FAIL_ERRNO() when the result is -1 and evaluates to that result either way.
 */
#define xaccept_nonblock(fd, addr, len)                                        \
        ({                                                                     \
                int __rc = accept_timeout((fd), (addr), (len),                 \
                                          IO_TIMEOUT_SEC);                     \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("accept");                                  \
                __rc;                                                          \
        })
55
/* bind() that fails the test via FAIL_ERRNO("bind") on -1; yields bind()'s result. */
#define xbind(fd, addr, len)                                                   \
        ({                                                                     \
                int __rc = bind((fd), (addr), (len));                          \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("bind");                                    \
                __rc;                                                          \
        })
63
/* close() that fails the test via FAIL_ERRNO("close") on -1; yields close()'s result. */
#define xclose(fd)                                                             \
        ({                                                                     \
                int __rc = close((fd));                                        \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("close");                                   \
                __rc;                                                          \
        })
71
/* connect() that fails the test via FAIL_ERRNO("connect") on -1; yields connect()'s result. */
#define xconnect(fd, addr, len)                                                \
        ({                                                                     \
                int __rc = connect((fd), (addr), (len));                       \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("connect");                                 \
                __rc;                                                          \
        })
79
/*
 * getsockname() that fails the test via FAIL_ERRNO("getsockname") on -1;
 * yields getsockname()'s result.
 */
#define xgetsockname(fd, addr, len)                                            \
        ({                                                                     \
                int __rc = getsockname((fd), (addr), (len));                   \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("getsockname");                             \
                __rc;                                                          \
        })
87
/*
 * getsockopt() that fails the test on -1. The failure message embeds the
 * option name exactly as the caller spelled it (stringified @name).
 * Yields getsockopt()'s result.
 */
#define xgetsockopt(fd, level, name, val, len)                                 \
        ({                                                                     \
                int __rc = getsockopt((fd), (level), (name), (val), (len));    \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("getsockopt(" #name ")");                   \
                __rc;                                                          \
        })
95
/* listen() that fails the test via FAIL_ERRNO("listen") on -1; yields listen()'s result. */
#define xlisten(fd, backlog)                                                   \
        ({                                                                     \
                int __rc = listen((fd), (backlog));                            \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("listen");                                  \
                __rc;                                                          \
        })
103
/*
 * setsockopt() that fails the test on -1. The failure message embeds the
 * option name exactly as the caller spelled it (stringified @name).
 * Yields setsockopt()'s result.
 */
#define xsetsockopt(fd, level, name, val, len)                                 \
        ({                                                                     \
                int __rc = setsockopt((fd), (level), (name), (val), (len));    \
                if (__rc == -1)                                                \
                        FAIL_ERRNO("setsockopt(" #name ")");                   \
                __rc;                                                          \
        })
111
/* send() that fails the test via FAIL_ERRNO("send") on -1; yields send()'s ssize_t result. */
#define xsend(fd, buf, len, flags)                                             \
        ({                                                                     \
                ssize_t __sent = send((fd), (buf), (len), (flags));            \
                if (__sent == -1)                                              \
                        FAIL_ERRNO("send");                                    \
                __sent;                                                        \
        })
119
/*
 * recv_timeout() with an IO_TIMEOUT_SEC timeout. Reports "recv" through
 * FAIL_ERRNO() when the result is -1 and evaluates to that result (held in an
 * ssize_t) either way.
 */
#define xrecv_nonblock(fd, buf, len, flags)                                    \
        ({                                                                     \
                ssize_t __got = recv_timeout((fd), (buf), (len), (flags),      \
                                             IO_TIMEOUT_SEC);                  \
                if (__got == -1)                                               \
                        FAIL_ERRNO("recv");                                    \
                __got;                                                         \
        })
128
/* socket() that fails the test via FAIL_ERRNO("socket") on -1; yields socket()'s result. */
#define xsocket(family, sotype, flags)                                         \
        ({                                                                     \
                int __sk = socket((family), (sotype), (flags));                \
                if (__sk == -1)                                                \
                        FAIL_ERRNO("socket");                                  \
                __sk;                                                          \
        })
136
/*
 * Cleanup handler used by __close_fd: closes *fd through xclose() when it
 * holds a non-negative value and does nothing otherwise.
 */
static inline void close_fd(int *fd)
{
        if (*fd < 0)
                return;

        xclose(*fd);
}

/* Mark an int variable so that close_fd() runs on it when it leaves scope. */
#define __close_fd __attribute__((cleanup(close_fd)))
144
/* View a sockaddr_storage through the generic struct sockaddr pointer type. */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
        struct sockaddr *generic = (struct sockaddr *)ss;

        return generic;
}
149
/*
 * Zero the whole of @ss, then fill it in as an AF_INET address for
 * INADDR_LOOPBACK with port 0. *@len is set to sizeof(struct sockaddr_in).
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
                                       socklen_t *len)
{
        struct sockaddr_in *sin;

        memset(ss, 0, sizeof(*ss));

        sin = (struct sockaddr_in *)ss;
        sin->sin_family = AF_INET;
        sin->sin_port = 0;
        sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);

        *len = sizeof(*sin);
}
160
/*
 * Zero the whole of @ss, then fill it in as an AF_INET6 address for
 * in6addr_loopback with port 0. *@len is set to sizeof(struct sockaddr_in6).
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
                                       socklen_t *len)
{
        struct sockaddr_in6 *sin6;

        memset(ss, 0, sizeof(*ss));

        sin6 = (struct sockaddr_in6 *)ss;
        sin6->sin6_family = AF_INET6;
        sin6->sin6_port = 0;
        sin6->sin6_addr = in6addr_loopback;

        *len = sizeof(*sin6);
}
171
172 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
173                                             socklen_t *len)
174 {
175         struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
176
177         addr->svm_family = AF_VSOCK;
178         addr->svm_port = VMADDR_PORT_ANY;
179         addr->svm_cid = VMADDR_CID_LOCAL;
180         *len = sizeof(*addr);
181 }
182
/*
 * Fill @ss / *@len with the loopback address for @family by delegating to the
 * AF_INET, AF_INET6 or AF_VSOCK initializer. Any other family fails the test
 * and leaves @ss and *@len untouched.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
                                      socklen_t *len)
{
        if (family == AF_INET)
                init_addr_loopback4(ss, len);
        else if (family == AF_INET6)
                init_addr_loopback6(ss, len);
        else if (family == AF_VSOCK)
                init_addr_loopback_vsock(ss, len);
        else
                FAIL("unsupported address family %d", family);
}
200
201 static inline int enable_reuseport(int s, int progfd)
202 {
203         int err, one = 1;
204
205         err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
206         if (err)
207                 return -1;
208         err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
209                           sizeof(progfd));
210         if (err)
211                 return -1;
212
213         return 0;
214 }
215
216 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
217 {
218         struct sockaddr_storage addr;
219         socklen_t len = 0;
220         int err, s;
221
222         init_addr_loopback(family, &addr, &len);
223
224         s = xsocket(family, sotype, 0);
225         if (s == -1)
226                 return -1;
227
228         if (progfd >= 0)
229                 enable_reuseport(s, progfd);
230
231         err = xbind(s, sockaddr(&addr), len);
232         if (err)
233                 goto close;
234
235         if (sotype & SOCK_DGRAM)
236                 return s;
237
238         err = xlisten(s, SOMAXCONN);
239         if (err)
240                 goto close;
241
242         return s;
243 close:
244         xclose(s);
245         return -1;
246 }
247
/*
 * socket_loopback_reuseport() without a reuseport program: a negative progfd
 * is passed so that no SO_REUSEPORT setup is performed.
 */
static inline int socket_loopback(int family, int sotype)
{
        const int no_progfd = -1;

        return socket_loopback_reuseport(family, sotype, no_progfd);
}
252
/*
 * Wait up to @timeout_sec seconds for @fd to become writable, then read the
 * socket's pending error with SO_ERROR.
 *
 * Returns 0 when the socket is writable and SO_ERROR is 0. Returns -1 with
 * errno set otherwise:
 *   EBADF  - @fd is negative
 *   EINVAL - @fd does not fit in an fd_set (>= FD_SETSIZE)
 *   ETIME  - select() timed out
 *   other  - whatever select()/getsockopt() reported, or the SO_ERROR value
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
        struct timeval timeout = { .tv_sec = timeout_sec };
        fd_set wfds;
        int r, eval;
        socklen_t esize = sizeof(eval);

        /* FD_SET() on an fd outside [0, FD_SETSIZE) is undefined behavior. */
        if (fd < 0) {
                errno = EBADF;
                return -1;
        }
        if (fd >= FD_SETSIZE) {
                errno = EINVAL;
                return -1;
        }

        FD_ZERO(&wfds);
        FD_SET(fd, &wfds);

        r = select(fd + 1, NULL, &wfds, NULL, &timeout);
        if (r == 0)
                errno = ETIME;
        if (r != 1)
                return -1;

        if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
                return -1;
        if (eval != 0) {
                errno = eval;
                return -1;
        }

        return 0;
}
278
/*
 * Wait up to @timeout_sec seconds for @fd to become readable.
 *
 * Returns 0 when select() reports @fd as ready. Returns -1 with errno set
 * otherwise:
 *   EBADF  - @fd is negative
 *   EINVAL - @fd does not fit in an fd_set (>= FD_SETSIZE)
 *   ETIME  - select() timed out
 *   other  - whatever select() reported
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
        struct timeval timeout = { .tv_sec = timeout_sec };
        fd_set rfds;
        int r;

        /* FD_SET() on an fd outside [0, FD_SETSIZE) is undefined behavior. */
        if (fd < 0) {
                errno = EBADF;
                return -1;
        }
        if (fd >= FD_SETSIZE) {
                errno = EINVAL;
                return -1;
        }

        FD_ZERO(&rfds);
        FD_SET(fd, &rfds);

        r = select(fd + 1, &rfds, NULL, NULL, &timeout);
        if (r == 0)
                errno = ETIME;

        return r == 1 ? 0 : -1;
}
294
/*
 * accept() on @fd after waiting up to @timeout_sec seconds for it to become
 * readable. Returns -1 when the wait fails (errno as left by poll_read()),
 * otherwise whatever accept() returns.
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
                                 unsigned int timeout_sec)
{
        int not_ready = poll_read(fd, timeout_sec);

        return not_ready ? -1 : accept(fd, addr, len);
}
303
/*
 * recv() on @fd after waiting up to @timeout_sec seconds for it to become
 * readable. Returns -1 when the wait fails (errno as left by poll_read()),
 * otherwise recv()'s result.
 *
 * NOTE(review): recv() returns ssize_t but this helper returns int, so the
 * byte count is narrowed; confirm no caller receives more than INT_MAX bytes.
 */
static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
                               unsigned int timeout_sec)
{
        if (poll_read(fd, timeout_sec) != 0)
                return -1;

        return (int)recv(fd, buf, len, flags);
}
312
313
314 static inline int create_pair(int family, int sotype, int *p0, int *p1)
315 {
316         __close_fd int s, c = -1, p = -1;
317         struct sockaddr_storage addr;
318         socklen_t len = sizeof(addr);
319         int err;
320
321         s = socket_loopback(family, sotype);
322         if (s < 0)
323                 return s;
324
325         err = xgetsockname(s, sockaddr(&addr), &len);
326         if (err)
327                 return err;
328
329         c = xsocket(family, sotype, 0);
330         if (c < 0)
331                 return c;
332
333         err = connect(c, sockaddr(&addr), len);
334         if (err) {
335                 if (errno != EINPROGRESS) {
336                         FAIL_ERRNO("connect");
337                         return err;
338                 }
339
340                 err = poll_connect(c, IO_TIMEOUT_SEC);
341                 if (err) {
342                         FAIL_ERRNO("poll_connect");
343                         return err;
344                 }
345         }
346
347         switch (sotype & SOCK_TYPE_MASK) {
348         case SOCK_DGRAM:
349                 err = xgetsockname(c, sockaddr(&addr), &len);
350                 if (err)
351                         return err;
352
353                 err = xconnect(s, sockaddr(&addr), len);
354                 if (err)
355                         return err;
356
357                 *p0 = take_fd(s);
358                 break;
359         case SOCK_STREAM:
360         case SOCK_SEQPACKET:
361                 p = xaccept_nonblock(s, NULL, NULL);
362                 if (p < 0)
363                         return p;
364
365                 *p0 = take_fd(p);
366                 break;
367         default:
368                 FAIL("Unsupported socket type %#x", sotype);
369                 return -EOPNOTSUPP;
370         }
371
372         *p1 = take_fd(c);
373         return 0;
374 }
375
/*
 * Create two connected socket pairs of @family / @sotype with create_pair():
 * the first is stored in (*@c0, *@p0), the second in (*@c1, *@p1).
 *
 * Returns 0 on success, or create_pair()'s error value. If the second pair
 * cannot be created, the first pair is closed again and *@c0 / *@p0 are reset
 * to -EBADF, so a caller that closes its descriptors unconditionally cannot
 * double-close a number that may have been reused in the meantime.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
                                      int *p0, int *p1)
{
        int err;

        err = create_pair(family, sotype, c0, p0);
        if (err)
                return err;

        err = create_pair(family, sotype, c1, p1);
        if (err) {
                /* Undo the first pair and invalidate the caller's copies. */
                xclose(take_fd(*c0));
                xclose(take_fd(*p0));
        }

        return err;
}
393
394 #endif // __SOCKET_HELPERS__
This page took 0.049273 seconds and 4 git commands to generate.