Git Repo - linux.git/blob - tools/testing/selftests/net/af_unix/scm_pidfd.c
Linux 6.14-rc3
[linux.git] / tools / testing / selftests / net / af_unix / scm_pidfd.c
1 // SPDX-License-Identifier: GPL-2.0 OR MIT
2 #define _GNU_SOURCE
3 #include <error.h>
4 #include <limits.h>
5 #include <stddef.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <sys/socket.h>
9 #include <linux/socket.h>
10 #include <unistd.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <sys/un.h>
14 #include <sys/signal.h>
15 #include <sys/types.h>
16 #include <sys/wait.h>
17
18 #include "../../kselftest_harness.h"
19
/* strerror(errno), or "None" when errno is 0. */
#define clean_errno() (errno == 0 ? "None" : strerror(errno))
/*
 * Print MSG (a printf-style format literal, with optional arguments) to
 * stderr, prefixed with the source file, line and current errno text.
 */
#define log_err(MSG, ...)                                                   \
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", __FILE__, __LINE__, \
		clean_errno(), ##__VA_ARGS__)

/* Fallback definition for when the system headers do not provide SCM_PIDFD. */
#ifndef SCM_PIDFD
#define SCM_PIDFD 0x04
#endif
28
/*
 * Terminate the forked client process with exit status 1.
 *
 * The parent treats any non-zero client exit status as a test failure.
 * Declared with an explicit (void) prototype: an empty parameter list
 * is an old-style declaration in pre-C23 C.
 */
static void child_die(void)
{
	exit(1);
}
33
/*
 * Convert the whole string @numstr to an int and store it in *@converted.
 *
 * The base is auto-detected as by strtol(..., 0): decimal, 0x-prefixed hex,
 * or 0-prefixed octal.  *@converted is written only on success.
 *
 * Returns 0 on success.  Returns -ERANGE when the value does not fit in a
 * long or an int.  Returns -EINVAL when there are no digits or there are
 * trailing characters.
 */
static int safe_int(const char *numstr, int *converted)
{
	char *end = NULL;
	long value;

	errno = 0;
	value = strtol(numstr, &end, 0);

	if (errno == ERANGE && (value == LONG_MIN || value == LONG_MAX))
		return -ERANGE;
	if (value == 0 && errno != 0)
		return -EINVAL;
	/* Nothing consumed, or leftovers after the number. */
	if (end == numstr || *end)
		return -EINVAL;
	if (value < INT_MIN || value > INT_MAX)
		return -ERANGE;

	*converted = (int)value;
	return 0;
}
56
/*
 * Count the leading blanks (space or tab) in the first @len bytes of
 * @buffer.
 *
 * Note: when every byte is blank, or @len is 0, this reports 0 rather
 * than @len.  trim_whitespace_in_place() still yields an empty string in
 * that case, because the right-hand trim removes everything.
 */
static int char_left_gc(const char *buffer, size_t len)
{
	size_t pos = 0;

	while (pos < len && (buffer[pos] == ' ' || buffer[pos] == '\t'))
		pos++;

	return pos == len ? 0 : (int)pos;
}
70
/*
 * Return the length of the first @len bytes of @buffer once trailing
 * blanks have been dropped.  Space, tab, newline and NUL all count as
 * blanks here.  Returns 0 when every byte is blank or @len is 0.
 */
static int char_right_gc(const char *buffer, size_t len)
{
	int end = (int)len;

	while (end > 0) {
		char c = buffer[end - 1];

		if (c != ' ' && c != '\t' && c != '\n' && c != '\0')
			break;
		end--;
	}

	return end;
}
85
/*
 * Trim the NUL-terminated string @buffer in place.
 *
 * Leading spaces and tabs are skipped.  Trailing spaces, tabs and
 * newlines are cut off by writing a NUL.  Returns a pointer into @buffer
 * at the first kept character.
 */
static char *trim_whitespace_in_place(char *buffer)
{
	char *start = buffer + char_left_gc(buffer, strlen(buffer));
	int keep = char_right_gc(start, strlen(start));

	start[keep] = '\0';
	return start;
}
92
/*
 * Borrowed (with all helpers) from pidfd/pidfd_open_test.c.
 *
 * Find the first line of /proc/self/fdinfo/<pidfd> that starts with @key
 * (@keylen bytes long).  The text after the key is trimmed and parsed as
 * an int.
 *
 * Returns the parsed value.  Returns -1 if the fdinfo file cannot be
 * opened, no line starts with @key, or the value is not a valid int.
 */
static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
{
	char path[512];
	char *line = NULL;
	size_t n = 0;
	pid_t result = -1;
	FILE *f;

	snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);

	f = fopen(path, "re");
	if (!f)
		return -1;

	while (getline(&line, &n, f) != -1) {
		char *numstr;

		if (strncmp(line, key, keylen))
			continue;

		/*
		 * Skip exactly the matched key.  The old hard-coded "+ 4" only
		 * worked for 4-byte keys such as "Pid:".
		 */
		numstr = trim_whitespace_in_place(line + keylen);
		if (safe_int(numstr, &result) < 0)
			result = -1;

		break;
	}

	free(line);
	fclose(f);
	return result;
}
128
129 static int cmsg_check(int fd)
130 {
131         struct msghdr msg = { 0 };
132         struct cmsghdr *cmsg;
133         struct iovec iov;
134         struct ucred *ucred = NULL;
135         int data = 0;
136         char control[CMSG_SPACE(sizeof(struct ucred)) +
137                      CMSG_SPACE(sizeof(int))] = { 0 };
138         int *pidfd = NULL;
139         pid_t parent_pid;
140         int err;
141
142         iov.iov_base = &data;
143         iov.iov_len = sizeof(data);
144
145         msg.msg_iov = &iov;
146         msg.msg_iovlen = 1;
147         msg.msg_control = control;
148         msg.msg_controllen = sizeof(control);
149
150         err = recvmsg(fd, &msg, 0);
151         if (err < 0) {
152                 log_err("recvmsg");
153                 return 1;
154         }
155
156         if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
157                 log_err("recvmsg: truncated");
158                 return 1;
159         }
160
161         for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
162              cmsg = CMSG_NXTHDR(&msg, cmsg)) {
163                 if (cmsg->cmsg_level == SOL_SOCKET &&
164                     cmsg->cmsg_type == SCM_PIDFD) {
165                         if (cmsg->cmsg_len < sizeof(*pidfd)) {
166                                 log_err("CMSG parse: SCM_PIDFD wrong len");
167                                 return 1;
168                         }
169
170                         pidfd = (void *)CMSG_DATA(cmsg);
171                 }
172
173                 if (cmsg->cmsg_level == SOL_SOCKET &&
174                     cmsg->cmsg_type == SCM_CREDENTIALS) {
175                         if (cmsg->cmsg_len < sizeof(*ucred)) {
176                                 log_err("CMSG parse: SCM_CREDENTIALS wrong len");
177                                 return 1;
178                         }
179
180                         ucred = (void *)CMSG_DATA(cmsg);
181                 }
182         }
183
184         /* send(pfd, "x", sizeof(char), 0) */
185         if (data != 'x') {
186                 log_err("recvmsg: data corruption");
187                 return 1;
188         }
189
190         if (!pidfd) {
191                 log_err("CMSG parse: SCM_PIDFD not found");
192                 return 1;
193         }
194
195         if (!ucred) {
196                 log_err("CMSG parse: SCM_CREDENTIALS not found");
197                 return 1;
198         }
199
200         /* pidfd from SCM_PIDFD should point to the parent process PID */
201         parent_pid =
202                 get_pid_from_fdinfo_file(*pidfd, "Pid:", sizeof("Pid:") - 1);
203         if (parent_pid != getppid()) {
204                 log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
205                 return 1;
206         }
207
208         return 0;
209 }
210
/* One AF_UNIX endpoint address, as filled in by fill_sockaddr(). */
struct sock_addr {
	/*
	 * Socket name without the abstract-namespace NUL prefix.  For
	 * pathname variants, teardown also passes this to unlink().
	 */
	char sock_name[32];
	/* Address handed to bind()/connect()/sendto(). */
	struct sockaddr_un listen_addr;
	/* Number of meaningful bytes in listen_addr. */
	socklen_t addrlen;
};
216
/* Per-test state used by setup, the test body and teardown. */
FIXTURE(scm_pidfd)
{
	/* Server socket (also listening for SOCK_STREAM). */
	int server;
	/* PID of the forked client. */
	pid_t client_pid;
	/* The client closes its write end once it is ready to receive. */
	int startup_pipe[2];
	struct sock_addr server_addr;
	/*
	 * MAP_SHARED (see setup), so the address the child binds for
	 * SOCK_DGRAM is visible to the parent.
	 */
	struct sock_addr *client_addr;
};
225
/* Each variant selects the socket type and the address namespace. */
FIXTURE_VARIANT(scm_pidfd)
{
	/* SOCK_STREAM or SOCK_DGRAM in the variants below. */
	int type;
	/* Non-zero: abstract-namespace address; zero: filesystem path. */
	bool abstract;
};
231
/* Cover both socket types with both pathname and abstract addresses. */
FIXTURE_VARIANT_ADD(scm_pidfd, stream_pathname)
{
	.type = SOCK_STREAM,
	.abstract = 0,
};

FIXTURE_VARIANT_ADD(scm_pidfd, stream_abstract)
{
	.type = SOCK_STREAM,
	.abstract = 1,
};

FIXTURE_VARIANT_ADD(scm_pidfd, dgram_pathname)
{
	.type = SOCK_DGRAM,
	.abstract = 0,
};

FIXTURE_VARIANT_ADD(scm_pidfd, dgram_abstract)
{
	.type = SOCK_DGRAM,
	.abstract = 1,
};
255
256 FIXTURE_SETUP(scm_pidfd)
257 {
258         self->client_addr = mmap(NULL, sizeof(*self->client_addr), PROT_READ | PROT_WRITE,
259                                  MAP_SHARED | MAP_ANONYMOUS, -1, 0);
260         ASSERT_NE(MAP_FAILED, self->client_addr);
261 }
262
263 FIXTURE_TEARDOWN(scm_pidfd)
264 {
265         close(self->server);
266
267         kill(self->client_pid, SIGKILL);
268         waitpid(self->client_pid, NULL, 0);
269
270         if (!variant->abstract) {
271                 unlink(self->server_addr.sock_name);
272                 unlink(self->client_addr->sock_name);
273         }
274 }
275
276 static void fill_sockaddr(struct sock_addr *addr, bool abstract)
277 {
278         char *sun_path_buf = (char *)&addr->listen_addr.sun_path;
279
280         addr->listen_addr.sun_family = AF_UNIX;
281         addr->addrlen = offsetof(struct sockaddr_un, sun_path);
282         snprintf(addr->sock_name, sizeof(addr->sock_name), "scm_pidfd_%d", getpid());
283         addr->addrlen += strlen(addr->sock_name);
284         if (abstract) {
285                 *sun_path_buf = '\0';
286                 addr->addrlen++;
287                 sun_path_buf++;
288         } else {
289                 unlink(addr->sock_name);
290         }
291         memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
292 }
293
294 static void client(FIXTURE_DATA(scm_pidfd) *self,
295                    const FIXTURE_VARIANT(scm_pidfd) *variant)
296 {
297         int cfd;
298         socklen_t len;
299         struct ucred peer_cred;
300         int peer_pidfd;
301         pid_t peer_pid;
302         int on = 0;
303
304         cfd = socket(AF_UNIX, variant->type, 0);
305         if (cfd < 0) {
306                 log_err("socket");
307                 child_die();
308         }
309
310         if (variant->type == SOCK_DGRAM) {
311                 fill_sockaddr(self->client_addr, variant->abstract);
312
313                 if (bind(cfd, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen)) {
314                         log_err("bind");
315                         child_die();
316                 }
317         }
318
319         if (connect(cfd, (struct sockaddr *)&self->server_addr.listen_addr,
320                     self->server_addr.addrlen) != 0) {
321                 log_err("connect");
322                 child_die();
323         }
324
325         on = 1;
326         if (setsockopt(cfd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
327                 log_err("Failed to set SO_PASSCRED");
328                 child_die();
329         }
330
331         if (setsockopt(cfd, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
332                 log_err("Failed to set SO_PASSPIDFD");
333                 child_die();
334         }
335
336         close(self->startup_pipe[1]);
337
338         if (cmsg_check(cfd)) {
339                 log_err("cmsg_check failed");
340                 child_die();
341         }
342
343         /* skip further for SOCK_DGRAM as it's not applicable */
344         if (variant->type == SOCK_DGRAM)
345                 return;
346
347         len = sizeof(peer_cred);
348         if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &peer_cred, &len)) {
349                 log_err("Failed to get SO_PEERCRED");
350                 child_die();
351         }
352
353         len = sizeof(peer_pidfd);
354         if (getsockopt(cfd, SOL_SOCKET, SO_PEERPIDFD, &peer_pidfd, &len)) {
355                 log_err("Failed to get SO_PEERPIDFD");
356                 child_die();
357         }
358
359         /* pid from SO_PEERCRED should point to the parent process PID */
360         if (peer_cred.pid != getppid()) {
361                 log_err("peer_cred.pid != getppid(): %d != %d", peer_cred.pid, getppid());
362                 child_die();
363         }
364
365         peer_pid = get_pid_from_fdinfo_file(peer_pidfd,
366                                             "Pid:", sizeof("Pid:") - 1);
367         if (peer_pid != peer_cred.pid) {
368                 log_err("peer_pid != peer_cred.pid: %d != %d", peer_pid, peer_cred.pid);
369                 child_die();
370         }
371 }
372
373 TEST_F(scm_pidfd, test)
374 {
375         int err;
376         int pfd;
377         int child_status = 0;
378
379         self->server = socket(AF_UNIX, variant->type, 0);
380         ASSERT_NE(-1, self->server);
381
382         fill_sockaddr(&self->server_addr, variant->abstract);
383
384         err = bind(self->server, (struct sockaddr *)&self->server_addr.listen_addr, self->server_addr.addrlen);
385         ASSERT_EQ(0, err);
386
387         if (variant->type == SOCK_STREAM) {
388                 err = listen(self->server, 1);
389                 ASSERT_EQ(0, err);
390         }
391
392         err = pipe(self->startup_pipe);
393         ASSERT_NE(-1, err);
394
395         self->client_pid = fork();
396         ASSERT_NE(-1, self->client_pid);
397         if (self->client_pid == 0) {
398                 close(self->server);
399                 close(self->startup_pipe[0]);
400                 client(self, variant);
401                 exit(0);
402         }
403         close(self->startup_pipe[1]);
404
405         if (variant->type == SOCK_STREAM) {
406                 pfd = accept(self->server, NULL, NULL);
407                 ASSERT_NE(-1, pfd);
408         } else {
409                 pfd = self->server;
410         }
411
412         /* wait until the child arrives at checkpoint */
413         read(self->startup_pipe[0], &err, sizeof(int));
414         close(self->startup_pipe[0]);
415
416         if (variant->type == SOCK_DGRAM) {
417                 err = sendto(pfd, "x", sizeof(char), 0, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen);
418                 ASSERT_NE(-1, err);
419         } else {
420                 err = send(pfd, "x", sizeof(char), 0);
421                 ASSERT_NE(-1, err);
422         }
423
424         close(pfd);
425         waitpid(self->client_pid, &child_status, 0);
426         ASSERT_EQ(0, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
427 }
428
/* kselftest harness entry point: provides main(), which runs every TEST_F variant above. */
TEST_HARNESS_MAIN
This page took 0.054418 seconds and 4 git commands to generate.