1 /* SPDX-License-Identifier: GPL-2.0 */
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
6 #include <linux/vm_sockets.h>
/*
 * Shared constants for the socket test helpers in this header.
 *  - SOCK_TYPE_MASK keeps only the low 4 bits of a sotype argument, i.e. the
 *    socket type without extra flag bits (value mirrors include/linux/net.h).
 *  - IO_TIMEOUT_SEC is the wait used by xaccept_nonblock() and by
 *    create_pair() when waiting for a connect to complete.
 *  - MAX_STRERR_LEN sizes the buffer FAIL_LIBBPF() formats errors into.
 *  - VMADDR_CID_LOCAL is given a fallback value for vm_sockets.h versions
 *    that do not define it.
 * NOTE(review): this copy of the header appears damaged (numeric prefixes on
 * every line, missing lines); e.g. the #endif closing the #ifndef below is not
 * present. Restore from upstream before building.
 */
8 /* include/linux/net.h */
9 #define SOCK_TYPE_MASK 0xf
11 #define IO_TIMEOUT_SEC 30
12 #define MAX_STRERR_LEN 256
14 /* workaround for older vm_sockets.h */
15 #ifndef VMADDR_CID_LOCAL
16 #define VMADDR_CID_LOCAL 1
/*
 * __get_and_null(p, nullvalue) - read @p through a pointer to it and yield
 * the saved value; per the include/linux/cleanup.h helper it mirrors, @p is
 * meant to be overwritten with @nullvalue.
 * take_fd(fd) - the same with -EBADF as the "no descriptor" value, presumably
 * so ownership of an fd can be moved out of a __close_fd variable.
 * NOTE(review): the statement that stores @nullvalue and the line yielding
 * __val are not visible in this copy - confirm against upstream.
 */
19 /* include/linux/cleanup.h */
20 #define __get_and_null(p, nullvalue) \
22 __auto_type __ptr = &(p); \
23 __auto_type __val = *__ptr; \
28 #define take_fd(fd) __get_and_null(fd, -EBADF)
/*
 * Failure reporting for the helpers below. _FAIL() calls error_at_line()
 * with status 0, so it reports and returns rather than exiting. __func__ is
 * passed in the file-name slot and __LINE__ as the line number.
 *  - FAIL(fmt...):       message only (errnum 0, no strerror text).
 *  - FAIL_ERRNO(fmt...): message plus the description of the current errno.
 *  - FAIL_LIBBPF(err, msg): renders libbpf error @err with libbpf_strerror()
 *    into a MAX_STRERR_LEN stack buffer, then reports "msg: <error text>".
 */
30 /* Wrappers that fail the test on error and report it. */
32 #define _FAIL(errnum, fmt...) \
34 error_at_line(0, (errnum), __func__, __LINE__, fmt); \
37 #define FAIL(fmt...) _FAIL(0, fmt)
38 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
39 #define FAIL_LIBBPF(err, msg) \
41 char __buf[MAX_STRERR_LEN]; \
42 libbpf_strerror((err), __buf, sizeof(__buf)); \
43 FAIL("%s: %s", (msg), __buf); \
/*
 * x*() socket-call wrappers. Each one evaluates the named call, keeps its
 * result in __ret, and reports failures with FAIL_ERRNO() naming the call.
 * xgetsockopt()/xsetsockopt() also include the stringified option name
 * (#name). Callers use the wrappers as expressions yielding the call's
 * result, e.g. `s = xsocket(...)` and `err = xbind(...)` elsewhere in this
 * header. xaccept_nonblock() and xrecv_nonblock() go through
 * accept_timeout()/recv_timeout(), so they wait a bounded time instead of
 * blocking indefinitely. xsend()/xrecv_nonblock() keep __ret as ssize_t.
 * NOTE(review): this copy lacks the failure checks (presumably
 * `__ret == -1`), the statement-expression braces, the #define line of the
 * close() wrapper (presumably xclose) and xbind()'s failure report -
 * restore from upstream.
 */
47 #define xaccept_nonblock(fd, addr, len) \
50 accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
52 FAIL_ERRNO("accept"); \
56 #define xbind(fd, addr, len) \
58 int __ret = bind((fd), (addr), (len)); \
66 int __ret = close((fd)); \
68 FAIL_ERRNO("close"); \
72 #define xconnect(fd, addr, len) \
74 int __ret = connect((fd), (addr), (len)); \
76 FAIL_ERRNO("connect"); \
80 #define xgetsockname(fd, addr, len) \
82 int __ret = getsockname((fd), (addr), (len)); \
84 FAIL_ERRNO("getsockname"); \
88 #define xgetsockopt(fd, level, name, val, len) \
90 int __ret = getsockopt((fd), (level), (name), (val), (len)); \
92 FAIL_ERRNO("getsockopt(" #name ")"); \
96 #define xlisten(fd, backlog) \
98 int __ret = listen((fd), (backlog)); \
100 FAIL_ERRNO("listen"); \
104 #define xsetsockopt(fd, level, name, val, len) \
106 int __ret = setsockopt((fd), (level), (name), (val), (len)); \
108 FAIL_ERRNO("setsockopt(" #name ")"); \
112 #define xsend(fd, buf, len, flags) \
114 ssize_t __ret = send((fd), (buf), (len), (flags)); \
116 FAIL_ERRNO("send"); \
120 #define xrecv_nonblock(fd, buf, len, flags) \
122 ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \
125 FAIL_ERRNO("recv"); \
129 #define xsocket(family, sotype, flags) \
131 int __ret = socket(family, sotype, flags); \
133 FAIL_ERRNO("socket"); \
/*
 * close_fd() is the cleanup handler behind __close_fd. The compiler calls
 * it with the address of the annotated int variable when that variable goes
 * out of scope.
 * NOTE(review): close_fd()'s body is not present in this copy. Confirm it
 * skips negative values, such as the -1 initializers in create_pair() and
 * the -EBADF left behind by take_fd().
 */
137 static inline void close_fd(int *fd)
/* Annotate an int fd variable so it is passed to close_fd() on scope exit. */
143 #define __close_fd __attribute__((cleanup(close_fd)))
/*
 * View a sockaddr_storage as the generic struct sockaddr expected by
 * bind()/connect()/getsockname(). This is a pure pointer cast; nothing is
 * copied.
 */
145 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
147 return (struct sockaddr *)ss;
/*
 * Fill @ss with an AF_INET loopback address (INADDR_LOOPBACK, converted to
 * network byte order) and set *@len to sizeof(struct sockaddr_in).
 * The whole sockaddr_storage is zeroed first. No visible line sets a
 * non-zero port, so presumably the port is left 0 (kernel-chosen on bind).
 * NOTE(review): the continuation line declaring @len is not visible in
 * this copy.
 */
150 static inline void init_addr_loopback4(struct sockaddr_storage *ss,
153 struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
155 addr4->sin_family = AF_INET;
157 addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
158 *len = sizeof(*addr4);
/*
 * Fill @ss with an AF_INET6 loopback address (in6addr_loopback) and port 0,
 * after zeroing the whole sockaddr_storage, and set *@len to
 * sizeof(struct sockaddr_in6).
 * NOTE(review): the continuation line declaring @len is not visible in
 * this copy.
 */
161 static inline void init_addr_loopback6(struct sockaddr_storage *ss,
164 struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
166 addr6->sin6_family = AF_INET6;
167 addr6->sin6_port = 0;
168 addr6->sin6_addr = in6addr_loopback;
169 *len = sizeof(*addr6);
/*
 * Fill @ss with an AF_VSOCK address after zeroing it:
 *  - port VMADDR_PORT_ANY;
 *  - CID VMADDR_CID_LOCAL (given a fallback definition near the top of this
 *    header when vm_sockets.h lacks it).
 * Set *@len to sizeof(struct sockaddr_vm).
 * NOTE(review): the continuation line declaring @len is not visible in
 * this copy.
 */
172 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
175 struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
177 addr->svm_family = AF_VSOCK;
178 addr->svm_port = VMADDR_PORT_ANY;
179 addr->svm_cid = VMADDR_CID_LOCAL;
180 *len = sizeof(*addr);
/*
 * Fill @ss / *@len with a loopback address for @family. Presumably it
 * dispatches AF_INET, AF_INET6 and AF_VSOCK to the matching
 * init_addr_loopback*() helper; the case labels are not visible in this
 * copy. Any other family is reported with FAIL(), which does not abort.
 * NOTE(review): confirm what *@ss and *@len hold after the unsupported
 * path; no visible line initializes them there.
 */
183 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
188 init_addr_loopback4(ss, len);
191 init_addr_loopback6(ss, len);
194 init_addr_loopback_vsock(ss, len);
197 FAIL("unsupported address family %d", family);
/*
 * Enable SO_REUSEPORT on socket @s, then attach the BPF program @progfd
 * with SO_ATTACH_REUSEPORT_EBPF. Both calls go through xsetsockopt(), so
 * failures are reported.
 * NOTE(review): the declarations of `one` and `err`, the early exit after
 * the first setsockopt and the return statement are not visible in this
 * copy. Presumably it returns err (0 on success).
 */
201 static inline int enable_reuseport(int s, int progfd)
205 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
/* Attach the reuseport selection program by fd. */
208 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
/*
 * Create a @family/@sotype socket bound to the loopback address from
 * init_addr_loopback(), optionally with a reuseport BPF program (@progfd).
 * Datagram sockets are returned right after bind(). Others are also put
 * into listen(SOMAXCONN).
 * NOTE(review): the declarations of len/err/s, the error exits and the
 * final return are not visible in this copy.
 */
216 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
218 struct sockaddr_storage addr;
222 init_addr_loopback(family, &addr, &len);
224 s = xsocket(family, sotype, 0);
/*
 * NOTE(review): socket_loopback() passes progfd == -1. No guard such as
 * `if (progfd >= 0)` is visible here; without one, SO_ATTACH_REUSEPORT_EBPF
 * would be attempted with fd -1. The return value is also ignored.
 */
229 enable_reuseport(s, progfd);
231 err = xbind(s, sockaddr(&addr), len);
/*
 * Connectionless sockets do not listen.
 * NOTE(review): on Linux this bit test is also true for SOCK_RAW (3);
 * (sotype & SOCK_TYPE_MASK) == SOCK_DGRAM would be stricter.
 */
235 if (sotype & SOCK_DGRAM)
238 err = xlisten(s, SOMAXCONN);
/*
 * Create a loopback socket with socket_loopback_reuseport(), passing
 * progfd = -1, i.e. no reuseport BPF program.
 */
248 static inline int socket_loopback(int family, int sotype)
250 return socket_loopback_reuseport(family, sotype, -1);
/*
 * Wait up to @timeout_sec seconds for @fd to become writable (select() on a
 * write set), then fetch the socket's pending error with SO_ERROR. Used by
 * create_pair() to finish a non-blocking connect().
 * NOTE(review): the fd_set setup, the timeout/error handling and the return
 * values are not visible in this copy. select() also cannot handle
 * fd >= FD_SETSIZE; poll() would lift that limit.
 */
253 static inline int poll_connect(int fd, unsigned int timeout_sec)
255 struct timeval timeout = { .tv_sec = timeout_sec };
258 socklen_t esize = sizeof(eval);
263 r = select(fd + 1, NULL, &wfds, NULL, &timeout);
269 if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
/*
 * Wait up to @timeout_sec seconds for @fd to become readable with select().
 * Returns 0 when select() reports exactly one ready descriptor, and -1
 * otherwise (timeout or select() error).
 * NOTE(review): the fd_set setup and any errno handling on timeout are not
 * visible in this copy.
 */
279 static inline int poll_read(int fd, unsigned int timeout_sec)
281 struct timeval timeout = { .tv_sec = timeout_sec };
288 r = select(fd + 1, &rfds, NULL, NULL, &timeout);
292 return r == 1 ? 0 : -1;
/*
 * accept() on listening socket @fd, waiting at most @timeout_sec seconds
 * (via poll_read()) for a connection to become ready.
 * NOTE(review): the value returned when poll_read() fails is not visible in
 * this copy (presumably -1).
 */
295 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
296 unsigned int timeout_sec)
298 if (poll_read(fd, timeout_sec))
301 return accept(fd, addr, len);
/*
 * recv() from @fd, waiting at most @timeout_sec seconds (via poll_read())
 * for data to become readable.
 * NOTE(review): recv() returns ssize_t but this function returns int, so
 * very large @len values would be narrowed. The value returned when
 * poll_read() fails is not visible in this copy.
 */
304 static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
305 unsigned int timeout_sec)
307 if (poll_read(fd, timeout_sec))
310 return recv(fd, buf, len, flags);
/*
 * create_pair - build a connected pair of loopback sockets.
 * @family, @sotype: passed to socket_loopback() and xsocket().
 * @p0, @p1: out-parameters for the two connected fds.
 *
 * Steps:
 *  1. Create a server socket s with socket_loopback() and read back its
 *     bound address with getsockname().
 *  2. Connect a client socket c to that address.
 *  3. Depending on the socket type, either connect s back to c's address
 *     (getsockname(c) + connect(s)) or accept a peer p from s with
 *     xaccept_nonblock(). Unsupported types are reported with FAIL().
 *
 * s, c and p carry __close_fd, so close_fd() runs on them on every exit
 * from this function.
 * NOTE(review): the case labels (presumably SOCK_DGRAM and SOCK_STREAM),
 * the error checks after each call, and the lines storing the fds in
 * *p0/*p1 are not visible in this copy.
 */
314 static inline int create_pair(int family, int sotype, int *p0, int *p1)
316 __close_fd int s, c = -1, p = -1;
317 struct sockaddr_storage addr;
318 socklen_t len = sizeof(addr);
321 s = socket_loopback(family, sotype);
325 err = xgetsockname(s, sockaddr(&addr), &len);
329 c = xsocket(family, sotype, 0);
/*
 * Plain connect(), not xconnect(): EINPROGRESS is not reported as a
 * failure (presumably because sotype may carry SOCK_NONBLOCK), and the
 * connect is completed with poll_connect() below.
 */
333 err = connect(c, sockaddr(&addr), len);
335 if (errno != EINPROGRESS) {
336 FAIL_ERRNO("connect");
340 err = poll_connect(c, IO_TIMEOUT_SEC);
342 FAIL_ERRNO("poll_connect");
347 switch (sotype & SOCK_TYPE_MASK) {
/*
 * NOTE(review): len still holds the length written by the first
 * getsockname() and is not reset to sizeof(addr); harmless while s and c
 * share a family.
 */
349 err = xgetsockname(c, sockaddr(&addr), &len);
353 err = xconnect(s, sockaddr(&addr), len);
361 p = xaccept_nonblock(s, NULL, NULL);
368 FAIL("Unsupported socket type %#x", sotype);
/*
 * Create two independent connected socket pairs with create_pair():
 * (*c0, *p0) first, then (*c1, *p1).
 * NOTE(review): the continuation line declaring p0/p1, the error checks
 * between the two calls, any cleanup of the first pair if the second fails,
 * and the return statement are not visible in this copy.
 */
376 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
381 err = create_pair(family, sotype, c0, p0);
385 err = create_pair(family, sotype, c1, p1);
394 #endif // __SOCKET_HELPERS__