]> Git Repo - J-linux.git/blob - tools/testing/selftests/net/tcp_ao/key-management.c
Merge tag 'vfs-6.13-rc7.fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/vfs/vfs
[J-linux.git] / tools / testing / selftests / net / tcp_ao / key-management.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <[email protected]> */
3 #include <inttypes.h>
4 #include "../../../../include/linux/kernel.h"
5 #include "aolib.h"
6
7 const size_t nr_packets = 20;
8 const size_t msg_len = 100;
9 const size_t quota = nr_packets * msg_len;
10 union tcp_addr wrong_addr;
11 #define SECOND_PASSWORD "at all times sincere friends of freedom have been rare"
12 #define fault(type)     (inj == FAULT_ ## type)
13
14 static const int test_vrf_ifindex = 200;
15 static const uint8_t test_vrf_tabid = 42;
16 static void setup_vrfs(void)
17 {
18         int err;
19
20         if (!kernel_config_has(KCONFIG_NET_VRF))
21                 return;
22
23         err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1);
24         if (err)
25                 test_error("Failed to add a VRF: %d", err);
26
27         err = link_set_up("ksft-vrf");
28         if (err)
29                 test_error("Failed to bring up a VRF");
30
31         err = ip_route_add_vrf(veth_name, TEST_FAMILY,
32                                this_ip_addr, this_ip_dest, test_vrf_tabid);
33         if (err)
34                 test_error("Failed to add a route to VRF");
35 }
36
37
38 static int prepare_sk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
39 {
40         int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
41
42         if (sk < 0)
43                 test_error("socket()");
44
45         if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
46                          DEFAULT_TEST_PREFIX, 100, 100))
47                 test_error("test_add_key()");
48
49         if (addr && test_add_key(sk, SECOND_PASSWORD, *addr,
50                                  DEFAULT_TEST_PREFIX, sndid, rcvid))
51                 test_error("test_add_key()");
52
53         return sk;
54 }
55
56 static int prepare_lsk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
57 {
58         int sk = prepare_sk(addr, sndid, rcvid);
59
60         if (listen(sk, 10))
61                 test_error("listen()");
62
63         return sk;
64 }
65
66 static int test_del_key(int sk, uint8_t sndid, uint8_t rcvid, bool async,
67                         int current_key, int rnext_key)
68 {
69         struct tcp_ao_info_opt ao_info = {};
70         struct tcp_ao_getsockopt key = {};
71         struct tcp_ao_del del = {};
72         sockaddr_af sockaddr;
73         int err;
74
75         tcp_addr_to_sockaddr_in(&del.addr, &this_ip_dest, 0);
76         del.prefix = DEFAULT_TEST_PREFIX;
77         del.sndid = sndid;
78         del.rcvid = rcvid;
79
80         if (current_key >= 0) {
81                 del.set_current = 1;
82                 del.current_key = (uint8_t)current_key;
83         }
84         if (rnext_key >= 0) {
85                 del.set_rnext = 1;
86                 del.rnext = (uint8_t)rnext_key;
87         }
88
89         err = setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, &del, sizeof(del));
90         if (err < 0)
91                 return -errno;
92
93         if (async)
94                 return 0;
95
96         tcp_addr_to_sockaddr_in(&sockaddr, &this_ip_dest, 0);
97         err = test_get_one_ao(sk, &key, &sockaddr, sizeof(sockaddr),
98                               DEFAULT_TEST_PREFIX, sndid, rcvid);
99         if (!err)
100                 return -EEXIST;
101         if (err != -E2BIG)
102                 test_error("getsockopt()");
103         if (current_key < 0 && rnext_key < 0)
104                 return 0;
105         if (test_get_ao_info(sk, &ao_info))
106                 test_error("getsockopt(TCP_AO_INFO) failed");
107         if (current_key >= 0 && ao_info.current_key != (uint8_t)current_key)
108                 return -ENOTRECOVERABLE;
109         if (rnext_key >= 0 && ao_info.rnext != (uint8_t)rnext_key)
110                 return -ENOTRECOVERABLE;
111         return 0;
112 }
113
114 static void try_delete_key(char *tst_name, int sk, uint8_t sndid, uint8_t rcvid,
115                            bool async, int current_key, int rnext_key,
116                            fault_t inj)
117 {
118         int err;
119
120         err = test_del_key(sk, sndid, rcvid, async, current_key, rnext_key);
121         if ((err == -EBUSY && fault(BUSY)) || (err == -EINVAL && fault(CURRNEXT))) {
122                 test_ok("%s: key deletion was prevented", tst_name);
123                 return;
124         }
125         if (err && fault(FIXME)) {
126                 test_xfail("%s: failed to delete the key %u:%u %d",
127                            tst_name, sndid, rcvid, err);
128                 return;
129         }
130         if (!err) {
131                 if (fault(BUSY) || fault(CURRNEXT)) {
132                         test_fail("%s: the key was deleted %u:%u %d", tst_name,
133                                   sndid, rcvid, err);
134                 } else {
135                         test_ok("%s: the key was deleted", tst_name);
136                 }
137                 return;
138         }
139         test_fail("%s: can't delete the key %u:%u %d", tst_name, sndid, rcvid, err);
140 }
141
142 static int test_set_key(int sk, int current_keyid, int rnext_keyid)
143 {
144         struct tcp_ao_info_opt ao_info = {};
145         int err;
146
147         if (current_keyid >= 0) {
148                 ao_info.set_current = 1;
149                 ao_info.current_key = (uint8_t)current_keyid;
150         }
151         if (rnext_keyid >= 0) {
152                 ao_info.set_rnext = 1;
153                 ao_info.rnext = (uint8_t)rnext_keyid;
154         }
155
156         err = test_set_ao_info(sk, &ao_info);
157         if (err)
158                 return err;
159         if (test_get_ao_info(sk, &ao_info))
160                 test_error("getsockopt(TCP_AO_INFO) failed");
161         if (current_keyid >= 0 && ao_info.current_key != (uint8_t)current_keyid)
162                 return -ENOTRECOVERABLE;
163         if (rnext_keyid >= 0 && ao_info.rnext != (uint8_t)rnext_keyid)
164                 return -ENOTRECOVERABLE;
165         return 0;
166 }
167
168 static int test_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
169                                       union tcp_addr in_addr, uint8_t prefix,
170                                       bool set_current, bool set_rnext,
171                                       uint8_t sndid, uint8_t rcvid)
172 {
173         struct tcp_ao_add tmp = {};
174         int err;
175
176         err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr,
177                                set_current, set_rnext,
178                                prefix, 0, sndid, rcvid, 0, keyflags,
179                                strlen(key), key);
180         if (err)
181                 return err;
182
183
184         err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
185         if (err < 0)
186                 return -errno;
187
188         return test_verify_socket_key(sk, &tmp);
189 }
190
191 static int __try_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
192                                        union tcp_addr in_addr, uint8_t prefix,
193                                        bool set_current, bool set_rnext,
194                                        uint8_t sndid, uint8_t rcvid)
195 {
196         struct tcp_ao_info_opt ao_info = {};
197         int err;
198
199         err = test_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
200                                          set_current, set_rnext, sndid, rcvid);
201         if (err)
202                 return err;
203
204         if (test_get_ao_info(sk, &ao_info))
205                 test_error("getsockopt(TCP_AO_INFO) failed");
206         if (set_current && ao_info.current_key != sndid)
207                 return -ENOTRECOVERABLE;
208         if (set_rnext && ao_info.rnext != rcvid)
209                 return -ENOTRECOVERABLE;
210         return 0;
211 }
212
213 static void try_add_current_rnext_key(char *tst_name, int sk, const char *key,
214                                      uint8_t keyflags,
215                                      union tcp_addr in_addr, uint8_t prefix,
216                                      bool set_current, bool set_rnext,
217                                      uint8_t sndid, uint8_t rcvid, fault_t inj)
218 {
219         int err;
220
221         err = __try_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
222                                           set_current, set_rnext, sndid, rcvid);
223         if (!err && !fault(CURRNEXT)) {
224                 test_ok("%s", tst_name);
225                 return;
226         }
227         if (err == -EINVAL && fault(CURRNEXT)) {
228                 test_ok("%s", tst_name);
229                 return;
230         }
231         test_fail("%s", tst_name);
232 }
233
234 static void check_closed_socket(void)
235 {
236         int sk;
237
238         sk = prepare_sk(&this_ip_dest, 200, 200);
239         try_delete_key("closed socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
240         try_delete_key("closed socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
241         close(sk);
242
243         sk = prepare_sk(&this_ip_dest, 200, 200);
244         if (test_set_key(sk, 100, 200))
245                 test_error("failed to set current/rnext keys");
246         try_delete_key("closed socket, delete current key", sk, 100, 100, 0, -1, -1, FAULT_BUSY);
247         try_delete_key("closed socket, delete rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
248         close(sk);
249
250         sk = prepare_sk(&this_ip_dest, 200, 200);
251         if (test_add_key(sk, "Glory to heros!", this_ip_dest,
252                          DEFAULT_TEST_PREFIX, 10, 11))
253                 test_error("test_add_key()");
254         if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
255                          DEFAULT_TEST_PREFIX, 12, 13))
256                 test_error("test_add_key()");
257         try_delete_key("closed socket, delete a key + set current/rnext", sk, 100, 100, 0, 10, 13, 0);
258         try_delete_key("closed socket, force-delete current key", sk, 10, 11, 0, 200, -1, 0);
259         try_delete_key("closed socket, force-delete rnext key", sk, 12, 13, 0, -1, 200, 0);
260         try_delete_key("closed socket, delete current+rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
261         close(sk);
262
263         sk = prepare_sk(&this_ip_dest, 200, 200);
264         if (test_set_key(sk, 100, 200))
265                 test_error("failed to set current/rnext keys");
266         try_add_current_rnext_key("closed socket, add + change current key",
267                                   sk, "Laaaa! Lalala-la-la-lalala...", 0,
268                                   this_ip_dest, DEFAULT_TEST_PREFIX,
269                                   true, false, 10, 20, 0);
270         try_add_current_rnext_key("closed socket, add + change rnext key",
271                                   sk, "Laaaa! Lalala-la-la-lalala...", 0,
272                                   this_ip_dest, DEFAULT_TEST_PREFIX,
273                                   false, true, 20, 10, 0);
274         close(sk);
275 }
276
277 static void assert_no_current_rnext(const char *tst_msg, int sk)
278 {
279         struct tcp_ao_info_opt ao_info = {};
280
281         if (test_get_ao_info(sk, &ao_info))
282                 test_error("getsockopt(TCP_AO_INFO) failed");
283
284         errno = 0;
285         if (ao_info.set_current || ao_info.set_rnext) {
286                 test_xfail("%s: the socket has current/rnext keys: %d:%d",
287                            tst_msg,
288                            (ao_info.set_current) ? ao_info.current_key : -1,
289                            (ao_info.set_rnext) ? ao_info.rnext : -1);
290         } else {
291                 test_ok("%s: the socket has no current/rnext keys", tst_msg);
292         }
293 }
294
295 static void assert_no_tcp_repair(void)
296 {
297         struct tcp_ao_repair ao_img = {};
298         socklen_t len = sizeof(ao_img);
299         int sk, err;
300
301         sk = prepare_sk(&this_ip_dest, 200, 200);
302         test_enable_repair(sk);
303         if (listen(sk, 10))
304                 test_error("listen()");
305         errno = 0;
306         err = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, &len);
307         if (err && errno == EPERM)
308                 test_ok("listen socket, getsockopt(TCP_AO_REPAIR) is restricted");
309         else
310                 test_fail("listen socket, getsockopt(TCP_AO_REPAIR) works");
311         errno = 0;
312         err = setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, sizeof(ao_img));
313         if (err && errno == EPERM)
314                 test_ok("listen socket, setsockopt(TCP_AO_REPAIR) is restricted");
315         else
316                 test_fail("listen socket, setsockopt(TCP_AO_REPAIR) works");
317         close(sk);
318 }
319
320 static void check_listen_socket(void)
321 {
322         int sk, err;
323
324         sk = prepare_lsk(&this_ip_dest, 200, 200);
325         try_delete_key("listen socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
326         try_delete_key("listen socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
327         close(sk);
328
329         sk = prepare_lsk(&this_ip_dest, 200, 200);
330         err = test_set_key(sk, 100, -1);
331         if (err == -EINVAL)
332                 test_ok("listen socket, setting current key not allowed");
333         else
334                 test_fail("listen socket, set current key");
335         err = test_set_key(sk, -1, 200);
336         if (err == -EINVAL)
337                 test_ok("listen socket, setting rnext key not allowed");
338         else
339                 test_fail("listen socket, set rnext key");
340         close(sk);
341
342         sk = prepare_sk(&this_ip_dest, 200, 200);
343         if (test_set_key(sk, 100, 200))
344                 test_error("failed to set current/rnext keys");
345         if (listen(sk, 10))
346                 test_error("listen()");
347         assert_no_current_rnext("listen() after current/rnext keys set", sk);
348         try_delete_key("listen socket, delete current key from before listen()", sk, 100, 100, 0, -1, -1, FAULT_FIXME);
349         try_delete_key("listen socket, delete rnext key from before listen()", sk, 200, 200, 0, -1, -1, FAULT_FIXME);
350         close(sk);
351
352         assert_no_tcp_repair();
353
354         sk = prepare_lsk(&this_ip_dest, 200, 200);
355         if (test_add_key(sk, "Glory to heros!", this_ip_dest,
356                          DEFAULT_TEST_PREFIX, 10, 11))
357                 test_error("test_add_key()");
358         if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
359                          DEFAULT_TEST_PREFIX, 12, 13))
360                 test_error("test_add_key()");
361         try_delete_key("listen socket, delete a key + set current/rnext", sk,
362                        100, 100, 0, 10, 13, FAULT_CURRNEXT);
363         try_delete_key("listen socket, force-delete current key", sk,
364                        10, 11, 0, 200, -1, FAULT_CURRNEXT);
365         try_delete_key("listen socket, force-delete rnext key", sk,
366                        12, 13, 0, -1, 200, FAULT_CURRNEXT);
367         try_delete_key("listen socket, delete a key", sk,
368                        200, 200, 0, -1, -1, 0);
369         close(sk);
370
371         sk = prepare_lsk(&this_ip_dest, 200, 200);
372         try_add_current_rnext_key("listen socket, add + change current key",
373                                   sk, "Laaaa! Lalala-la-la-lalala...", 0,
374                                   this_ip_dest, DEFAULT_TEST_PREFIX,
375                                   true, false, 10, 20, FAULT_CURRNEXT);
376         try_add_current_rnext_key("listen socket, add + change rnext key",
377                                   sk, "Laaaa! Lalala-la-la-lalala...", 0,
378                                   this_ip_dest, DEFAULT_TEST_PREFIX,
379                                   false, true, 20, 10, FAULT_CURRNEXT);
380         close(sk);
381 }
382
383 static const char *fips_fpath = "/proc/sys/crypto/fips_enabled";
384 static bool is_fips_enabled(void)
385 {
386         static int fips_checked = -1;
387         FILE *fenabled;
388         int enabled;
389
390         if (fips_checked >= 0)
391                 return !!fips_checked;
392         if (access(fips_fpath, R_OK)) {
393                 if (errno != ENOENT)
394                         test_error("Can't open %s", fips_fpath);
395                 fips_checked = 0;
396                 return false;
397         }
398         fenabled = fopen(fips_fpath, "r");
399         if (!fenabled)
400                 test_error("Can't open %s", fips_fpath);
401         if (fscanf(fenabled, "%d", &enabled) != 1)
402                 test_error("Can't read from %s", fips_fpath);
403         fclose(fenabled);
404         fips_checked = !!enabled;
405         return !!fips_checked;
406 }
407
408 struct test_key {
409         char password[TCP_AO_MAXKEYLEN];
410         const char *alg;
411         unsigned int len;
412         uint8_t client_keyid;
413         uint8_t server_keyid;
414         uint8_t maclen;
415         uint8_t matches_client          : 1,
416                 matches_server          : 1,
417                 matches_vrf             : 1,
418                 is_current              : 1,
419                 is_rnext                : 1,
420                 used_on_server_tx       : 1,
421                 used_on_client_tx       : 1,
422                 skip_counters_checks    : 1;
423 };
424
425 struct key_collection {
426         unsigned int nr_keys;
427         struct test_key *keys;
428 };
429
430 static struct key_collection collection;
431
432 #define TEST_MAX_MACLEN         16
433 const char *test_algos[] = {
434         "cmac(aes128)",
435         "hmac(sha1)", "hmac(sha512)", "hmac(sha384)", "hmac(sha256)",
436         "hmac(sha224)", "hmac(sha3-512)",
437         /* only if !CONFIG_FIPS */
438 #define TEST_NON_FIPS_ALGOS     2
439         "hmac(rmd160)", "hmac(md5)"
440 };
441 const unsigned int test_maclens[] = { 1, 4, 12, 16 };
442 #define MACLEN_SHIFT            2
443 #define ALGOS_SHIFT             4
444
445 static unsigned int make_mask(unsigned int shift, unsigned int prev_shift)
446 {
447         unsigned int ret = BIT(shift) - 1;
448
449         return ret << prev_shift;
450 }
451
452 static void init_key_in_collection(unsigned int index, bool randomized)
453 {
454         struct test_key *key = &collection.keys[index];
455         unsigned int algos_nr, algos_index;
456
457         /* Same for randomized and non-randomized test flows */
458         key->client_keyid = index;
459         key->server_keyid = 127 + index;
460         key->matches_client = 1;
461         key->matches_server = 1;
462         key->matches_vrf = 1;
463         /* not really even random, but good enough for a test */
464         key->len = rand() % (TCP_AO_MAXKEYLEN - TEST_TCP_AO_MINKEYLEN);
465         key->len += TEST_TCP_AO_MINKEYLEN;
466         randomize_buffer(key->password, key->len);
467
468         if (randomized) {
469                 key->maclen = (rand() % TEST_MAX_MACLEN) + 1;
470                 algos_index = rand();
471         } else {
472                 unsigned int shift = MACLEN_SHIFT;
473
474                 key->maclen = test_maclens[index & make_mask(shift, 0)];
475                 algos_index = index & make_mask(ALGOS_SHIFT, shift);
476         }
477         algos_nr = ARRAY_SIZE(test_algos);
478         if (is_fips_enabled())
479                 algos_nr -= TEST_NON_FIPS_ALGOS;
480         key->alg = test_algos[algos_index % algos_nr];
481 }
482
483 static int init_default_key_collection(unsigned int nr_keys, bool randomized)
484 {
485         size_t key_sz = sizeof(collection.keys[0]);
486
487         if (!nr_keys) {
488                 free(collection.keys);
489                 collection.keys = NULL;
490                 return 0;
491         }
492
493         /*
494          * All keys have uniq sndid/rcvid and sndid != rcvid in order to
495          * check for any bugs/issues for different keyids, visible to both
496          * peers. Keyid == 254 is unused.
497          */
498         if (nr_keys > 127)
499                 test_error("Test requires too many keys, correct the source");
500
501         collection.keys = reallocarray(collection.keys, nr_keys, key_sz);
502         if (!collection.keys)
503                 return -ENOMEM;
504
505         memset(collection.keys, 0, nr_keys * key_sz);
506         collection.nr_keys = nr_keys;
507         while (nr_keys--)
508                 init_key_in_collection(nr_keys, randomized);
509
510         return 0;
511 }
512
513 static void test_key_error(const char *msg, struct test_key *key)
514 {
515         test_error("%s: key: { %s, %u:%u, %u, %u:%u:%u:%u:%u (%u)}",
516                    msg, key->alg, key->client_keyid, key->server_keyid,
517                    key->maclen, key->matches_client, key->matches_server,
518                    key->matches_vrf, key->is_current, key->is_rnext, key->len);
519 }
520
521 static int test_add_key_cr(int sk, const char *pwd, unsigned int pwd_len,
522                            union tcp_addr addr, uint8_t vrf,
523                            uint8_t sndid, uint8_t rcvid,
524                            uint8_t maclen, const char *alg,
525                            bool set_current, bool set_rnext)
526 {
527         struct tcp_ao_add tmp = {};
528         uint8_t keyflags = 0;
529         int err;
530
531         if (!alg)
532                 alg = DEFAULT_TEST_ALGO;
533
534         if (vrf)
535                 keyflags |= TCP_AO_KEYF_IFINDEX;
536         err = test_prepare_key(&tmp, alg, addr, set_current, set_rnext,
537                                DEFAULT_TEST_PREFIX, vrf, sndid, rcvid, maclen,
538                                keyflags, pwd_len, pwd);
539         if (err)
540                 return err;
541
542         err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
543         if (err < 0)
544                 return -errno;
545
546         return test_verify_socket_key(sk, &tmp);
547 }
548
549 static void verify_current_rnext(const char *tst, int sk,
550                                  int current_keyid, int rnext_keyid)
551 {
552         struct tcp_ao_info_opt ao_info = {};
553
554         if (test_get_ao_info(sk, &ao_info))
555                 test_error("getsockopt(TCP_AO_INFO) failed");
556
557         errno = 0;
558         if (current_keyid >= 0) {
559                 if (!ao_info.set_current)
560                         test_fail("%s: the socket doesn't have current key", tst);
561                 else if (ao_info.current_key != current_keyid)
562                         test_fail("%s: current key is not the expected one %d != %u",
563                                   tst, current_keyid, ao_info.current_key);
564                 else
565                         test_ok("%s: current key %u as expected",
566                                 tst, ao_info.current_key);
567         }
568         if (rnext_keyid >= 0) {
569                 if (!ao_info.set_rnext)
570                         test_fail("%s: the socket doesn't have rnext key", tst);
571                 else if (ao_info.rnext != rnext_keyid)
572                         test_fail("%s: rnext key is not the expected one %d != %u",
573                                   tst, rnext_keyid, ao_info.rnext);
574                 else
575                         test_ok("%s: rnext key %u as expected", tst, ao_info.rnext);
576         }
577 }
578
579
580 static int key_collection_socket(bool server, unsigned int port)
581 {
582         unsigned int i;
583         int sk;
584
585         if (server)
586                 sk = test_listen_socket(this_ip_addr, port, 1);
587         else
588                 sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
589         if (sk < 0)
590                 test_error("socket()");
591
592         for (i = 0; i < collection.nr_keys; i++) {
593                 struct test_key *key = &collection.keys[i];
594                 union tcp_addr *addr = &wrong_addr;
595                 uint8_t sndid, rcvid, vrf;
596                 bool set_current = false, set_rnext = false;
597
598                 if (key->matches_vrf)
599                         vrf = 0;
600                 else
601                         vrf = test_vrf_ifindex;
602                 if (server) {
603                         if (key->matches_client)
604                                 addr = &this_ip_dest;
605                         sndid = key->server_keyid;
606                         rcvid = key->client_keyid;
607                 } else {
608                         if (key->matches_server)
609                                 addr = &this_ip_dest;
610                         sndid = key->client_keyid;
611                         rcvid = key->server_keyid;
612                         key->used_on_client_tx = set_current = key->is_current;
613                         key->used_on_server_tx = set_rnext = key->is_rnext;
614                 }
615
616                 if (test_add_key_cr(sk, key->password, key->len,
617                                     *addr, vrf, sndid, rcvid, key->maclen,
618                                     key->alg, set_current, set_rnext))
619                         test_key_error("setsockopt(TCP_AO_ADD_KEY)", key);
620 #ifdef DEBUG
621                 test_print("%s [%u/%u] key: { %s, %u:%u, %u, %u:%u:%u:%u (%u)}",
622                            server ? "server" : "client", i, collection.nr_keys,
623                            key->alg, rcvid, sndid, key->maclen,
624                            key->matches_client, key->matches_server,
625                            key->is_current, key->is_rnext, key->len);
626 #endif
627         }
628         return sk;
629 }
630
631 static void verify_counters(const char *tst_name, bool is_listen_sk, bool server,
632                             struct tcp_ao_counters *a, struct tcp_ao_counters *b)
633 {
634         unsigned int i;
635
636         __test_tcp_ao_counters_cmp(tst_name, a, b, TEST_CNT_GOOD);
637
638         for (i = 0; i < collection.nr_keys; i++) {
639                 struct test_key *key = &collection.keys[i];
640                 uint8_t sndid, rcvid;
641                 bool rx_cnt_expected;
642
643                 if (key->skip_counters_checks)
644                         continue;
645                 if (server) {
646                         sndid = key->server_keyid;
647                         rcvid = key->client_keyid;
648                         rx_cnt_expected = key->used_on_client_tx;
649                 } else {
650                         sndid = key->client_keyid;
651                         rcvid = key->server_keyid;
652                         rx_cnt_expected = key->used_on_server_tx;
653                 }
654
655                 test_tcp_ao_key_counters_cmp(tst_name, a, b,
656                                              rx_cnt_expected ? TEST_CNT_KEY_GOOD : 0,
657                                              sndid, rcvid);
658         }
659         test_tcp_ao_counters_free(a);
660         test_tcp_ao_counters_free(b);
661         test_ok("%s: passed counters checks", tst_name);
662 }
663
664 static struct tcp_ao_getsockopt *lookup_key(struct tcp_ao_getsockopt *buf,
665                                             size_t len, int sndid, int rcvid)
666 {
667         size_t i;
668
669         for (i = 0; i < len; i++) {
670                 if (sndid >= 0 && buf[i].sndid != sndid)
671                         continue;
672                 if (rcvid >= 0 && buf[i].rcvid != rcvid)
673                         continue;
674                 return &buf[i];
675         }
676         return NULL;
677 }
678
679 static void verify_keys(const char *tst_name, int sk,
680                         bool is_listen_sk, bool server)
681 {
682         socklen_t len = sizeof(struct tcp_ao_getsockopt);
683         struct tcp_ao_getsockopt *keys;
684         bool passed_test = true;
685         unsigned int i;
686
687         keys = calloc(collection.nr_keys, len);
688         if (!keys)
689                 test_error("calloc()");
690
691         keys->nkeys = collection.nr_keys;
692         keys->get_all = 1;
693
694         if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, keys, &len)) {
695                 free(keys);
696                 test_error("getsockopt(TCP_AO_GET_KEYS)");
697         }
698
699         for (i = 0; i < collection.nr_keys; i++) {
700                 struct test_key *key = &collection.keys[i];
701                 struct tcp_ao_getsockopt *dump_key;
702                 bool is_kdf_aes_128_cmac = false;
703                 bool is_cmac_aes = false;
704                 uint8_t sndid, rcvid;
705                 bool matches = false;
706
707                 if (server) {
708                         if (key->matches_client)
709                                 matches = true;
710                         sndid = key->server_keyid;
711                         rcvid = key->client_keyid;
712                 } else {
713                         if (key->matches_server)
714                                 matches = true;
715                         sndid = key->client_keyid;
716                         rcvid = key->server_keyid;
717                 }
718                 if (!key->matches_vrf)
719                         matches = false;
720                 /* no keys get removed on the original listener socket */
721                 if (is_listen_sk)
722                         matches = true;
723
724                 dump_key = lookup_key(keys, keys->nkeys, sndid, rcvid);
725                 if (matches != !!dump_key) {
726                         test_fail("%s: key %u:%u %s%s on the socket",
727                                   tst_name, sndid, rcvid,
728                                   key->matches_vrf ? "" : "[vrf] ",
729                                   matches ? "disappeared" : "yet present");
730                         passed_test = false;
731                         goto out;
732                 }
733                 if (!dump_key)
734                         continue;
735
736                 if (!strcmp("cmac(aes128)", key->alg)) {
737                         is_kdf_aes_128_cmac = (key->len != 16);
738                         is_cmac_aes = true;
739                 }
740
741                 if (is_cmac_aes) {
742                         if (strcmp(dump_key->alg_name, "cmac(aes)")) {
743                                 test_fail("%s: key %u:%u cmac(aes) has unexpected alg %s",
744                                           tst_name, sndid, rcvid,
745                                           dump_key->alg_name);
746                                 passed_test = false;
747                                 continue;
748                         }
749                 } else if (strcmp(dump_key->alg_name, key->alg)) {
750                         test_fail("%s: key %u:%u has unexpected alg %s != %s",
751                                   tst_name, sndid, rcvid,
752                                   dump_key->alg_name, key->alg);
753                         passed_test = false;
754                         continue;
755                 }
756                 if (is_kdf_aes_128_cmac) {
757                         if (dump_key->keylen != 16) {
758                                 test_fail("%s: key %u:%u cmac(aes128) has unexpected len %u",
759                                           tst_name, sndid, rcvid,
760                                           dump_key->keylen);
761                                 continue;
762                         }
763                 } else if (dump_key->keylen != key->len) {
764                         test_fail("%s: key %u:%u changed password len %u != %u",
765                                   tst_name, sndid, rcvid,
766                                   dump_key->keylen, key->len);
767                         passed_test = false;
768                         continue;
769                 }
770                 if (!is_kdf_aes_128_cmac &&
771                     memcmp(dump_key->key, key->password, key->len)) {
772                         test_fail("%s: key %u:%u has different password",
773                                   tst_name, sndid, rcvid);
774                         passed_test = false;
775                         continue;
776                 }
777                 if (dump_key->maclen != key->maclen) {
778                         test_fail("%s: key %u:%u changed maclen %u != %u",
779                                   tst_name, sndid, rcvid,
780                                   dump_key->maclen, key->maclen);
781                         passed_test = false;
782                         continue;
783                 }
784         }
785
786         if (passed_test)
787                 test_ok("%s: The socket keys are consistent with the expectations",
788                         tst_name);
789 out:
790         free(keys);
791 }
792
793 static int start_server(const char *tst_name, unsigned int port, size_t quota,
794                         struct tcp_ao_counters *begin,
795                         unsigned int current_index, unsigned int rnext_index)
796 {
797         struct tcp_ao_counters lsk_c1, lsk_c2;
798         ssize_t bytes;
799         int sk, lsk;
800
801         synchronize_threads(); /* 1: key collection initialized */
802         lsk = key_collection_socket(true, port);
803         if (test_get_tcp_ao_counters(lsk, &lsk_c1))
804                 test_error("test_get_tcp_ao_counters()");
805         synchronize_threads(); /* 2: MKTs added => connect() */
806         if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
807                 test_error("test_wait_fd()");
808
809         sk = accept(lsk, NULL, NULL);
810         if (sk < 0)
811                 test_error("accept()");
812         if (test_get_tcp_ao_counters(sk, begin))
813                 test_error("test_get_tcp_ao_counters()");
814
815         synchronize_threads(); /* 3: accepted => send data */
816         if (test_get_tcp_ao_counters(lsk, &lsk_c2))
817                 test_error("test_get_tcp_ao_counters()");
818         verify_keys(tst_name, lsk, true, true);
819         close(lsk);
820
821         bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
822         if (bytes != quota)
823                 test_fail("%s: server served: %zd", tst_name, bytes);
824         else
825                 test_ok("%s: server alive", tst_name);
826
827         verify_counters(tst_name, true, true, &lsk_c1, &lsk_c2);
828
829         return sk;
830 }
831
832 static void end_server(const char *tst_name, int sk,
833                        struct tcp_ao_counters *begin)
834 {
835         struct tcp_ao_counters end;
836
837         if (test_get_tcp_ao_counters(sk, &end))
838                 test_error("test_get_tcp_ao_counters()");
839         verify_keys(tst_name, sk, false, true);
840
841         synchronize_threads(); /* 4: verified => closed */
842         close(sk);
843
844         verify_counters(tst_name, false, true, begin, &end);
845         synchronize_threads(); /* 5: counters */
846 }
847
848 static void try_server_run(const char *tst_name, unsigned int port, size_t quota,
849                            unsigned int current_index, unsigned int rnext_index)
850 {
851         struct tcp_ao_counters tmp;
852         int sk;
853
854         sk = start_server(tst_name, port, quota, &tmp,
855                           current_index, rnext_index);
856         end_server(tst_name, sk, &tmp);
857 }
858
859 static void server_rotations(const char *tst_name, unsigned int port,
860                              size_t quota, unsigned int rotations,
861                              unsigned int current_index, unsigned int rnext_index)
862 {
863         struct tcp_ao_counters tmp;
864         unsigned int i;
865         int sk;
866
867         sk = start_server(tst_name, port, quota, &tmp,
868                           current_index, rnext_index);
869
870         for (i = current_index + 1; rotations > 0; i++, rotations--) {
871                 ssize_t bytes;
872
873                 if (i >= collection.nr_keys)
874                         i = 0;
875                 bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
876                 if (bytes != quota) {
877                         test_fail("%s: server served: %zd", tst_name, bytes);
878                         return;
879                 }
880                 verify_current_rnext(tst_name, sk,
881                                      collection.keys[i].server_keyid, -1);
882                 synchronize_threads(); /* verify current/rnext */
883         }
884         end_server(tst_name, sk, &tmp);
885 }
886
887 static int run_client(const char *tst_name, unsigned int port,
888                       unsigned int nr_keys, int current_index, int rnext_index,
889                       struct tcp_ao_counters *before,
890                       const size_t msg_sz, const size_t msg_nr)
891 {
892         int sk;
893
894         synchronize_threads(); /* 1: key collection initialized */
895         sk = key_collection_socket(false, port);
896
897         if (current_index >= 0 || rnext_index >= 0) {
898                 int sndid = -1, rcvid = -1;
899
900                 if (current_index >= 0)
901                         sndid = collection.keys[current_index].client_keyid;
902                 if (rnext_index >= 0)
903                         rcvid = collection.keys[rnext_index].server_keyid;
904                 if (test_set_key(sk, sndid, rcvid))
905                         test_error("failed to set current/rnext keys");
906         }
907         if (before && test_get_tcp_ao_counters(sk, before))
908                 test_error("test_get_tcp_ao_counters()");
909
910         synchronize_threads(); /* 2: MKTs added => connect() */
911         if (test_connect_socket(sk, this_ip_dest, port++) <= 0)
912                 test_error("failed to connect()");
913         if (current_index < 0)
914                 current_index = nr_keys - 1;
915         if (rnext_index < 0)
916                 rnext_index = nr_keys - 1;
917         collection.keys[current_index].used_on_client_tx = 1;
918         collection.keys[rnext_index].used_on_server_tx = 1;
919
920         synchronize_threads(); /* 3: accepted => send data */
921         if (test_client_verify(sk, msg_sz, msg_nr, TEST_TIMEOUT_SEC)) {
922                 test_fail("verify failed");
923                 close(sk);
924                 if (before)
925                         test_tcp_ao_counters_free(before);
926                 return -1;
927         }
928
929         return sk;
930 }
931
932 static int start_client(const char *tst_name, unsigned int port,
933                         unsigned int nr_keys, int current_index, int rnext_index,
934                         struct tcp_ao_counters *before,
935                         const size_t msg_sz, const size_t msg_nr)
936 {
937         if (init_default_key_collection(nr_keys, true))
938                 test_error("Failed to init the key collection");
939
940         return run_client(tst_name, port, nr_keys, current_index,
941                           rnext_index, before, msg_sz, msg_nr);
942 }
943
944 static void end_client(const char *tst_name, int sk, unsigned int nr_keys,
945                        int current_index, int rnext_index,
946                        struct tcp_ao_counters *start)
947 {
948         struct tcp_ao_counters end;
949
950         /* Some application may become dependent on this kernel choice */
951         if (current_index < 0)
952                 current_index = nr_keys - 1;
953         if (rnext_index < 0)
954                 rnext_index = nr_keys - 1;
955         verify_current_rnext(tst_name, sk,
956                              collection.keys[current_index].client_keyid,
957                              collection.keys[rnext_index].server_keyid);
958         if (start && test_get_tcp_ao_counters(sk, &end))
959                 test_error("test_get_tcp_ao_counters()");
960         verify_keys(tst_name, sk, false, false);
961         synchronize_threads(); /* 4: verify => closed */
962         close(sk);
963         if (start)
964                 verify_counters(tst_name, false, false, start, &end);
965         synchronize_threads(); /* 5: counters */
966 }
967
968 static void try_unmatched_keys(int sk, int *rnext_index, unsigned int port)
969 {
970         struct test_key *key;
971         unsigned int i = 0;
972         int err;
973
974         do {
975                 key = &collection.keys[i];
976                 if (!key->matches_server)
977                         break;
978         } while (++i < collection.nr_keys);
979         if (key->matches_server)
980                 test_error("all keys on client match the server");
981
982         err = test_add_key_cr(sk, key->password, key->len, wrong_addr,
983                               0, key->client_keyid, key->server_keyid,
984                               key->maclen, key->alg, 0, 0);
985         if (!err) {
986                 test_fail("Added a key with non-matching ip-address for established sk");
987                 return;
988         }
989         if (err == -EINVAL)
990                 test_ok("Can't add a key with non-matching ip-address for established sk");
991         else
992                 test_error("Failed to add a key");
993
994         err = test_add_key_cr(sk, key->password, key->len, this_ip_dest,
995                               test_vrf_ifindex,
996                               key->client_keyid, key->server_keyid,
997                               key->maclen, key->alg, 0, 0);
998         if (!err) {
999                 test_fail("Added a key with non-matching VRF for established sk");
1000                 return;
1001         }
1002         if (err == -EINVAL)
1003                 test_ok("Can't add a key with non-matching VRF for established sk");
1004         else
1005                 test_error("Failed to add a key");
1006
1007         for (i = 0; i < collection.nr_keys; i++) {
1008                 key = &collection.keys[i];
1009                 if (!key->matches_client)
1010                         break;
1011         }
1012         if (key->matches_client)
1013                 test_error("all keys on server match the client");
1014         if (test_set_key(sk, -1, key->server_keyid))
1015                 test_error("Can't change the current key");
1016         trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, this_ip_addr, this_ip_dest,
1017                               -1, port, 0, -1, -1, -1, -1, -1,
1018                               -1, key->server_keyid, -1);
1019         if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
1020                 test_fail("verify failed");
1021         *rnext_index = i;
1022 }
1023
1024 static int client_non_matching(const char *tst_name, unsigned int port,
1025                                unsigned int nr_keys,
1026                                int current_index, int rnext_index,
1027                                const size_t msg_sz, const size_t msg_nr)
1028 {
1029         unsigned int i;
1030
1031         if (init_default_key_collection(nr_keys, true))
1032                 test_error("Failed to init the key collection");
1033
1034         for (i = 0; i < nr_keys; i++) {
1035                 /* key (0, 0) matches */
1036                 collection.keys[i].matches_client = !!((i + 3) % 4);
1037                 collection.keys[i].matches_server = !!((i + 2) % 4);
1038                 if (kernel_config_has(KCONFIG_NET_VRF))
1039                         collection.keys[i].matches_vrf = !!((i + 1) % 4);
1040         }
1041
1042         return run_client(tst_name, port, nr_keys, current_index,
1043                           rnext_index, NULL, msg_sz, msg_nr);
1044 }
1045
1046 static void check_current_back(const char *tst_name, unsigned int port,
1047                                unsigned int nr_keys,
1048                                unsigned int current_index, unsigned int rnext_index,
1049                                unsigned int rotate_to_index)
1050 {
1051         struct tcp_ao_counters tmp;
1052         int sk;
1053
1054         sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1055                           &tmp, msg_len, nr_packets);
1056         if (sk < 0)
1057                 return;
1058         if (test_set_key(sk, collection.keys[rotate_to_index].client_keyid, -1))
1059                 test_error("Can't change the current key");
1060         trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, this_ip_dest, this_ip_addr,
1061                               port, -1, 0, -1, -1, -1, -1, -1,
1062                               collection.keys[rotate_to_index].client_keyid,
1063                               collection.keys[current_index].client_keyid, -1);
1064         if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
1065                 test_fail("verify failed");
1066         /* There is a race here: between setting the current_key with
1067          * setsockopt(TCP_AO_INFO) and starting to send some data - there
1068          * might have been a segment received with the desired
1069          * RNext_key set. In turn that would mean that the first outgoing
1070          * segment will have the desired current_key (flipped back).
1071          * Which is what the user/test wants. As it's racy, skip checking
1072          * the counters, yet check what are the resulting current/rnext
1073          * keys on both sides.
1074          */
1075         collection.keys[rotate_to_index].skip_counters_checks = 1;
1076
1077         end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1078 }
1079
1080 static void roll_over_keys(const char *tst_name, unsigned int port,
1081                            unsigned int nr_keys, unsigned int rotations,
1082                            unsigned int current_index, unsigned int rnext_index)
1083 {
1084         struct tcp_ao_counters tmp;
1085         unsigned int i;
1086         int sk;
1087
1088         sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1089                           &tmp, msg_len, nr_packets);
1090         if (sk < 0)
1091                 return;
1092         for (i = rnext_index + 1; rotations > 0; i++, rotations--) {
1093                 if (i >= collection.nr_keys)
1094                         i = 0;
1095                 trace_ao_event_expect(TCP_AO_RNEXT_REQUEST,
1096                                 this_ip_addr, this_ip_dest,
1097                                 -1, port, 0, -1, -1, -1, -1, -1,
1098                                 i == 0 ? -1 : collection.keys[i - 1].server_keyid,
1099                                 collection.keys[i].server_keyid, -1);
1100                 if (test_set_key(sk, -1, collection.keys[i].server_keyid))
1101                         test_error("Can't change the Rnext key");
1102                 if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC)) {
1103                         test_fail("verify failed");
1104                         close(sk);
1105                         test_tcp_ao_counters_free(&tmp);
1106                         return;
1107                 }
1108                 verify_current_rnext(tst_name, sk, -1,
1109                                      collection.keys[i].server_keyid);
1110                 collection.keys[i].used_on_server_tx = 1;
1111                 synchronize_threads(); /* verify current/rnext */
1112         }
1113         end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1114 }
1115
1116 static void try_client_run(const char *tst_name, unsigned int port,
1117                            unsigned int nr_keys, int current_index, int rnext_index)
1118 {
1119         struct tcp_ao_counters tmp;
1120         int sk;
1121
1122         sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
1123                           &tmp, msg_len, nr_packets);
1124         if (sk < 0)
1125                 return;
1126         end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
1127 }
1128
1129 static void try_client_match(const char *tst_name, unsigned int port,
1130                              unsigned int nr_keys,
1131                              int current_index, int rnext_index)
1132 {
1133         int sk;
1134
1135         sk = client_non_matching(tst_name, port, nr_keys, current_index,
1136                                  rnext_index, msg_len, nr_packets);
1137         if (sk < 0)
1138                 return;
1139         try_unmatched_keys(sk, &rnext_index, port);
1140         end_client(tst_name, sk, nr_keys, current_index, rnext_index, NULL);
1141 }
1142
1143 static void *server_fn(void *arg)
1144 {
1145         unsigned int port = test_server_port;
1146
1147         setup_vrfs();
1148         try_server_run("server: Check current/rnext keys unset before connect()",
1149                        port++, quota, 19, 19);
1150         try_server_run("server: Check current/rnext keys set before connect()",
1151                        port++, quota, 10, 10);
1152         try_server_run("server: Check current != rnext keys set before connect()",
1153                        port++, quota, 5, 10);
1154         try_server_run("server: Check current flapping back on peer's RnextKey request",
1155                        port++, quota * 2, 5, 10);
1156         server_rotations("server: Rotate over all different keys", port++,
1157                          quota, 20, 0, 0);
1158         try_server_run("server: Check accept() => established key matching",
1159                        port++, quota * 2, 0, 0);
1160
1161         synchronize_threads(); /* don't race to exit: client exits */
1162         return NULL;
1163 }
1164
1165 static void check_established_socket(void)
1166 {
1167         unsigned int port = test_server_port;
1168
1169         setup_vrfs();
1170         try_client_run("client: Check current/rnext keys unset before connect()",
1171                        port++, 20, -1, -1);
1172         try_client_run("client: Check current/rnext keys set before connect()",
1173                        port++, 20, 10, 10);
1174         try_client_run("client: Check current != rnext keys set before connect()",
1175                        port++, 20, 10, 5);
1176         check_current_back("client: Check current flapping back on peer's RnextKey request",
1177                            port++, 20, 10, 5, 2);
1178         roll_over_keys("client: Rotate over all different keys", port++,
1179                        20, 20, 0, 0);
1180         try_client_match("client: Check connect() => established key matching",
1181                          port++, 20, 0, 0);
1182 }
1183
1184 static void *client_fn(void *arg)
1185 {
1186         if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
1187                 test_error("Can't convert ip address %s", TEST_WRONG_IP);
1188         check_closed_socket();
1189         check_listen_socket();
1190         check_established_socket();
1191         return NULL;
1192 }
1193
1194 int main(int argc, char *argv[])
1195 {
1196         test_init(121, server_fn, client_fn);
1197         return 0;
1198 }
This page took 0.101569 seconds and 4 git commands to generate.