]> Git Repo - linux.git/blob - tools/testing/selftests/net/tls.c
Linux 6.14-rc3
[linux.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/epoll.h>
19 #include <sys/types.h>
20 #include <sys/sendfile.h>
21 #include <sys/socket.h>
22 #include <sys/stat.h>
23
24 #include "../kselftest_harness.h"
25
/* Size used for the largest send/receive buffers in these tests. */
#define TLS_PAYLOAD_MAX_LEN 16384
/* Level passed to setsockopt() and in cmsg headers for the TLS_* options. */
#define SOL_TLS 282

/* When non-zero, FIXTURE_SETUP(tls) skips variants marked fips_non_compliant. */
static int fips_enabled;
30
/*
 * Storage for any of the per-cipher crypto_info layouts used by the tests,
 * together with the number of bytes to hand to setsockopt().
 */
struct tls_crypto_info_keys {
        /* Overlapping views; tls_crypto_info_init() fills the one matching the cipher. */
        union {
                struct tls_crypto_info crypto_info;
                struct tls12_crypto_info_aes_gcm_128 aes128;
                struct tls12_crypto_info_chacha20_poly1305 chacha20;
                struct tls12_crypto_info_sm4_gcm sm4gcm;
                struct tls12_crypto_info_sm4_ccm sm4ccm;
                struct tls12_crypto_info_aes_ccm_128 aesccm128;
                struct tls12_crypto_info_aes_gcm_256 aesgcm256;
                struct tls12_crypto_info_aria_gcm_128 ariagcm128;
                struct tls12_crypto_info_aria_gcm_256 ariagcm256;
        };
        /* sizeof() of the selected per-cipher struct; used as the setsockopt() optlen. */
        size_t len;
};
45
46 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
47                                  struct tls_crypto_info_keys *tls12,
48                                  char key_generation)
49 {
50         memset(tls12, key_generation, sizeof(*tls12));
51         memset(tls12, 0, sizeof(struct tls_crypto_info));
52
53         switch (cipher_type) {
54         case TLS_CIPHER_CHACHA20_POLY1305:
55                 tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
56                 tls12->chacha20.info.version = tls_version;
57                 tls12->chacha20.info.cipher_type = cipher_type;
58                 break;
59         case TLS_CIPHER_AES_GCM_128:
60                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
61                 tls12->aes128.info.version = tls_version;
62                 tls12->aes128.info.cipher_type = cipher_type;
63                 break;
64         case TLS_CIPHER_SM4_GCM:
65                 tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
66                 tls12->sm4gcm.info.version = tls_version;
67                 tls12->sm4gcm.info.cipher_type = cipher_type;
68                 break;
69         case TLS_CIPHER_SM4_CCM:
70                 tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
71                 tls12->sm4ccm.info.version = tls_version;
72                 tls12->sm4ccm.info.cipher_type = cipher_type;
73                 break;
74         case TLS_CIPHER_AES_CCM_128:
75                 tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
76                 tls12->aesccm128.info.version = tls_version;
77                 tls12->aesccm128.info.cipher_type = cipher_type;
78                 break;
79         case TLS_CIPHER_AES_GCM_256:
80                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
81                 tls12->aesgcm256.info.version = tls_version;
82                 tls12->aesgcm256.info.cipher_type = cipher_type;
83                 break;
84         case TLS_CIPHER_ARIA_GCM_128:
85                 tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
86                 tls12->ariagcm128.info.version = tls_version;
87                 tls12->ariagcm128.info.cipher_type = cipher_type;
88                 break;
89         case TLS_CIPHER_ARIA_GCM_256:
90                 tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
91                 tls12->ariagcm256.info.version = tls_version;
92                 tls12->ariagcm256.info.cipher_type = cipher_type;
93                 break;
94         default:
95                 break;
96         }
97 }
98
/*
 * Fill @n bytes starting at @s with output from rand().
 *
 * Each whole int-sized chunk receives one complete rand() value; every
 * remaining byte receives one rand() value truncated to a byte.  Chunks
 * are stored with memcpy() instead of through an int pointer, so @s may
 * have any alignment (callers pass plain char arrays).
 */
static void memrnd(void *s, size_t n)
{
        unsigned char *out = s;
        int word;

        for (; n >= sizeof(word); n -= sizeof(word)) {
                word = rand();
                memcpy(out, &word, sizeof(word));
                out += sizeof(word);
        }
        while (n--)
                *out++ = rand();
}
110
111 static void ulp_sock_pair(struct __test_metadata *_metadata,
112                           int *fd, int *cfd, bool *notls)
113 {
114         struct sockaddr_in addr;
115         socklen_t len;
116         int sfd, ret;
117
118         *notls = false;
119         len = sizeof(addr);
120
121         addr.sin_family = AF_INET;
122         addr.sin_addr.s_addr = htonl(INADDR_ANY);
123         addr.sin_port = 0;
124
125         *fd = socket(AF_INET, SOCK_STREAM, 0);
126         sfd = socket(AF_INET, SOCK_STREAM, 0);
127
128         ret = bind(sfd, &addr, sizeof(addr));
129         ASSERT_EQ(ret, 0);
130         ret = listen(sfd, 10);
131         ASSERT_EQ(ret, 0);
132
133         ret = getsockname(sfd, &addr, &len);
134         ASSERT_EQ(ret, 0);
135
136         ret = connect(*fd, &addr, sizeof(addr));
137         ASSERT_EQ(ret, 0);
138
139         *cfd = accept(sfd, &addr, &len);
140         ASSERT_GE(*cfd, 0);
141
142         close(sfd);
143
144         ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
145         if (ret != 0) {
146                 ASSERT_EQ(errno, ENOENT);
147                 *notls = true;
148                 printf("Failure setting TCP_ULP, testing without tls\n");
149                 return;
150         }
151
152         ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
153         ASSERT_EQ(ret, 0);
154 }
155
156 /* Produce a basic cmsg */
157 static int tls_send_cmsg(int fd, unsigned char record_type,
158                          void *data, size_t len, int flags)
159 {
160         char cbuf[CMSG_SPACE(sizeof(char))];
161         int cmsg_len = sizeof(char);
162         struct cmsghdr *cmsg;
163         struct msghdr msg;
164         struct iovec vec;
165
166         vec.iov_base = data;
167         vec.iov_len = len;
168         memset(&msg, 0, sizeof(struct msghdr));
169         msg.msg_iov = &vec;
170         msg.msg_iovlen = 1;
171         msg.msg_control = cbuf;
172         msg.msg_controllen = sizeof(cbuf);
173         cmsg = CMSG_FIRSTHDR(&msg);
174         cmsg->cmsg_level = SOL_TLS;
175         /* test sending non-record types. */
176         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
177         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
178         *CMSG_DATA(cmsg) = record_type;
179         msg.msg_controllen = cmsg->cmsg_len;
180
181         return sendmsg(fd, &msg, flags);
182 }
183
184 static int tls_recv_cmsg(struct __test_metadata *_metadata,
185                          int fd, unsigned char record_type,
186                          void *data, size_t len, int flags)
187 {
188         char cbuf[CMSG_SPACE(sizeof(char))];
189         struct cmsghdr *cmsg;
190         unsigned char ctype;
191         struct msghdr msg;
192         struct iovec vec;
193         int n;
194
195         vec.iov_base = data;
196         vec.iov_len = len;
197         memset(&msg, 0, sizeof(struct msghdr));
198         msg.msg_iov = &vec;
199         msg.msg_iovlen = 1;
200         msg.msg_control = cbuf;
201         msg.msg_controllen = sizeof(cbuf);
202
203         n = recvmsg(fd, &msg, flags);
204
205         cmsg = CMSG_FIRSTHDR(&msg);
206         EXPECT_NE(cmsg, NULL);
207         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
208         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
209         ctype = *((unsigned char *)CMSG_DATA(cmsg));
210         EXPECT_EQ(ctype, record_type);
211
212         return n;
213 }
214
/* Per-test state for the tls_basic tests. */
FIXTURE(tls_basic)
{
        /* Socket pair from ulp_sock_pair(): fd is the connecting end, cfd the accepted end. */
        int fd, cfd;
        /* Set by ulp_sock_pair() when the "tls" ULP could not be attached. */
        bool notls;
};
220
/* Build the socket pair with the ULP attached; no keys are installed here. */
FIXTURE_SETUP(tls_basic)
{
        ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
}
225
/* Close both ends of the socket pair created in setup. */
FIXTURE_TEARDOWN(tls_basic)
{
        close(self->fd);
        close(self->cfd);
}
231
232 /* Send some data through with ULP but no keys */
233 TEST_F(tls_basic, base_base)
234 {
235         char const *test_str = "test_read";
236         int send_len = 10;
237         char buf[10];
238
239         ASSERT_EQ(strlen(test_str) + 1, send_len);
240
241         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
242         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
243         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
244 };
245
246 TEST_F(tls_basic, bad_cipher)
247 {
248         struct tls_crypto_info_keys tls12;
249
250         tls12.crypto_info.version = 200;
251         tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
252         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
253
254         tls12.crypto_info.version = TLS_1_2_VERSION;
255         tls12.crypto_info.cipher_type = 50;
256         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
257
258         tls12.crypto_info.version = TLS_1_2_VERSION;
259         tls12.crypto_info.cipher_type = 59;
260         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
261
262         tls12.crypto_info.version = TLS_1_2_VERSION;
263         tls12.crypto_info.cipher_type = 10;
264         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
265
266         tls12.crypto_info.version = TLS_1_2_VERSION;
267         tls12.crypto_info.cipher_type = 70;
268         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
269 }
270
271 TEST_F(tls_basic, recseq_wrap)
272 {
273         struct tls_crypto_info_keys tls12;
274         char const *test_str = "test_read";
275         int send_len = 10;
276
277         if (self->notls)
278                 SKIP(return, "no TLS support");
279
280         tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
281         memset(&tls12.aes128.rec_seq, 0xff, sizeof(tls12.aes128.rec_seq));
282
283         ASSERT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
284         ASSERT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
285
286         EXPECT_EQ(send(self->fd, test_str, send_len, 0), -1);
287         EXPECT_EQ(errno, EBADMSG);
288 }
289
/* Per-test state for the keyed "tls" tests. */
FIXTURE(tls)
{
        /* Socket pair from ulp_sock_pair(): setup installs TLS_TX on fd and TLS_RX on cfd. */
        int fd, cfd;
        /* Set by ulp_sock_pair() when the "tls" ULP could not be attached. */
        bool notls;
};
295
/* Parameters that distinguish one "tls" fixture variant from another. */
FIXTURE_VARIANT(tls)
{
        /* Passed to tls_crypto_info_init() as the protocol version. */
        uint16_t tls_version;
        /* Passed to tls_crypto_info_init() as the cipher selector. */
        uint16_t cipher_type;
        /*
         * nopad: setup additionally enables TLS_RX_EXPECT_NO_PAD on cfd.
         * fips_non_compliant: setup skips the variant when fips_enabled is set.
         */
        bool nopad, fips_non_compliant;
};
302
/*
 * Version/cipher combinations every TEST_F(tls, ...) runs against.  Fields
 * not named in an initializer default to zero/false.
 */
FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
{
        .tls_version = TLS_1_2_VERSION,
        .cipher_type = TLS_CIPHER_AES_GCM_128,
};

FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_AES_GCM_128,
};

/* The ChaCha20 and SM4 variants are flagged so setup skips them in FIPS mode. */
FIXTURE_VARIANT_ADD(tls, 12_chacha)
{
        .tls_version = TLS_1_2_VERSION,
        .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
        .fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 13_chacha)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
        .fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_SM4_GCM,
        .fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_SM4_CCM,
        .fips_non_compliant = true,
};

FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
{
        .tls_version = TLS_1_2_VERSION,
        .cipher_type = TLS_CIPHER_AES_CCM_128,
};

FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_AES_CCM_128,
};

FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
{
        .tls_version = TLS_1_2_VERSION,
        .cipher_type = TLS_CIPHER_AES_GCM_256,
};

FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_AES_GCM_256,
};

/* Same cipher as 13_aes_gcm, but with TLS_RX_EXPECT_NO_PAD enabled by setup. */
FIXTURE_VARIANT_ADD(tls, 13_nopad)
{
        .tls_version = TLS_1_3_VERSION,
        .cipher_type = TLS_CIPHER_AES_GCM_128,
        .nopad = true,
};

FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
{
        .tls_version = TLS_1_2_VERSION,
        .cipher_type = TLS_CIPHER_ARIA_GCM_128,
};

FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
{
        .tls_version = TLS_1_2_VERSION,
        .cipher_type = TLS_CIPHER_ARIA_GCM_256,
};
385
386 FIXTURE_SETUP(tls)
387 {
388         struct tls_crypto_info_keys tls12;
389         int one = 1;
390         int ret;
391
392         if (fips_enabled && variant->fips_non_compliant)
393                 SKIP(return, "Unsupported cipher in FIPS mode");
394
395         tls_crypto_info_init(variant->tls_version, variant->cipher_type,
396                              &tls12, 0);
397
398         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
399
400         if (self->notls)
401                 return;
402
403         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
404         ASSERT_EQ(ret, 0);
405
406         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
407         ASSERT_EQ(ret, 0);
408
409         if (variant->nopad) {
410                 ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
411                                  (void *)&one, sizeof(one));
412                 ASSERT_EQ(ret, 0);
413         }
414 }
415
/* Close both ends of the socket pair created in setup. */
FIXTURE_TEARDOWN(tls)
{
        close(self->fd);
        close(self->cfd);
}
421
422 TEST_F(tls, sendfile)
423 {
424         int filefd = open("/proc/self/exe", O_RDONLY);
425         struct stat st;
426
427         EXPECT_GE(filefd, 0);
428         fstat(filefd, &st);
429         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
430 }
431
432 TEST_F(tls, send_then_sendfile)
433 {
434         int filefd = open("/proc/self/exe", O_RDONLY);
435         char const *test_str = "test_send";
436         int to_send = strlen(test_str) + 1;
437         char recv_buf[10];
438         struct stat st;
439         char *buf;
440
441         EXPECT_GE(filefd, 0);
442         fstat(filefd, &st);
443         buf = (char *)malloc(st.st_size);
444
445         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
446         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
447         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
448
449         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
450         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
451 }
452
453 static void chunked_sendfile(struct __test_metadata *_metadata,
454                              struct _test_data_tls *self,
455                              uint16_t chunk_size,
456                              uint16_t extra_payload_size)
457 {
458         char buf[TLS_PAYLOAD_MAX_LEN];
459         uint16_t test_payload_size;
460         int size = 0;
461         int ret;
462         char filename[] = "/tmp/mytemp.XXXXXX";
463         int fd = mkstemp(filename);
464         off_t offset = 0;
465
466         unlink(filename);
467         ASSERT_GE(fd, 0);
468         EXPECT_GE(chunk_size, 1);
469         test_payload_size = chunk_size + extra_payload_size;
470         ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
471         memset(buf, 1, test_payload_size);
472         size = write(fd, buf, test_payload_size);
473         EXPECT_EQ(size, test_payload_size);
474         fsync(fd);
475
476         while (size > 0) {
477                 ret = sendfile(self->fd, fd, &offset, chunk_size);
478                 EXPECT_GE(ret, 0);
479                 size -= ret;
480         }
481
482         EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
483                   test_payload_size);
484
485         close(fd);
486 }
487
488 TEST_F(tls, multi_chunk_sendfile)
489 {
490         chunked_sendfile(_metadata, self, 4096, 4096);
491         chunked_sendfile(_metadata, self, 4096, 0);
492         chunked_sendfile(_metadata, self, 4096, 1);
493         chunked_sendfile(_metadata, self, 4096, 2048);
494         chunked_sendfile(_metadata, self, 8192, 2048);
495         chunked_sendfile(_metadata, self, 4096, 8192);
496         chunked_sendfile(_metadata, self, 8192, 4096);
497         chunked_sendfile(_metadata, self, 12288, 1024);
498         chunked_sendfile(_metadata, self, 12288, 2000);
499         chunked_sendfile(_metadata, self, 15360, 100);
500         chunked_sendfile(_metadata, self, 15360, 300);
501         chunked_sendfile(_metadata, self, 1, 4096);
502         chunked_sendfile(_metadata, self, 2048, 4096);
503         chunked_sendfile(_metadata, self, 2048, 8192);
504         chunked_sendfile(_metadata, self, 4096, 8192);
505         chunked_sendfile(_metadata, self, 1024, 12288);
506         chunked_sendfile(_metadata, self, 2000, 12288);
507         chunked_sendfile(_metadata, self, 100, 15360);
508         chunked_sendfile(_metadata, self, 300, 15360);
509 }
510
511 TEST_F(tls, recv_max)
512 {
513         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
514         char recv_mem[TLS_PAYLOAD_MAX_LEN];
515         char buf[TLS_PAYLOAD_MAX_LEN];
516
517         memrnd(buf, sizeof(buf));
518
519         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
520         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
521         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
522 }
523
524 TEST_F(tls, recv_small)
525 {
526         char const *test_str = "test_read";
527         int send_len = 10;
528         char buf[10];
529
530         send_len = strlen(test_str) + 1;
531         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
532         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
533         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
534 }
535
536 TEST_F(tls, msg_more)
537 {
538         char const *test_str = "test_read";
539         int send_len = 10;
540         char buf[10 * 2];
541
542         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
543         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
544         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
545         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
546                   send_len * 2);
547         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
548 }
549
550 TEST_F(tls, msg_more_unsent)
551 {
552         char const *test_str = "test_read";
553         int send_len = 10;
554         char buf[10];
555
556         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
557         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
558 }
559
560 TEST_F(tls, msg_eor)
561 {
562         char const *test_str = "test_read";
563         int send_len = 10;
564         char buf[10];
565
566         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
567         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
568         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
569 }
570
571 TEST_F(tls, sendmsg_single)
572 {
573         struct msghdr msg;
574
575         char const *test_str = "test_sendmsg";
576         size_t send_len = 13;
577         struct iovec vec;
578         char buf[13];
579
580         vec.iov_base = (char *)test_str;
581         vec.iov_len = send_len;
582         memset(&msg, 0, sizeof(struct msghdr));
583         msg.msg_iov = &vec;
584         msg.msg_iovlen = 1;
585         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
586         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
587         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
588 }
589
590 #define MAX_FRAGS       64
591 #define SEND_LEN        13
592 TEST_F(tls, sendmsg_fragmented)
593 {
594         char const *test_str = "test_sendmsg";
595         char buf[SEND_LEN * MAX_FRAGS];
596         struct iovec vec[MAX_FRAGS];
597         struct msghdr msg;
598         int i, frags;
599
600         for (frags = 1; frags <= MAX_FRAGS; frags++) {
601                 for (i = 0; i < frags; i++) {
602                         vec[i].iov_base = (char *)test_str;
603                         vec[i].iov_len = SEND_LEN;
604                 }
605
606                 memset(&msg, 0, sizeof(struct msghdr));
607                 msg.msg_iov = vec;
608                 msg.msg_iovlen = frags;
609
610                 EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
611                 EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
612                           SEND_LEN * frags);
613
614                 for (i = 0; i < frags; i++)
615                         EXPECT_EQ(memcmp(buf + SEND_LEN * i,
616                                          test_str, SEND_LEN), 0);
617         }
618 }
619 #undef MAX_FRAGS
620 #undef SEND_LEN
621
622 TEST_F(tls, sendmsg_large)
623 {
624         void *mem = malloc(16384);
625         size_t send_len = 16384;
626         size_t sends = 128;
627         struct msghdr msg;
628         size_t recvs = 0;
629         size_t sent = 0;
630
631         memset(&msg, 0, sizeof(struct msghdr));
632         while (sent++ < sends) {
633                 struct iovec vec = { (void *)mem, send_len };
634
635                 msg.msg_iov = &vec;
636                 msg.msg_iovlen = 1;
637                 EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
638         }
639
640         while (recvs++ < sends) {
641                 EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
642         }
643
644         free(mem);
645 }
646
647 TEST_F(tls, sendmsg_multiple)
648 {
649         char const *test_str = "test_sendmsg_multiple";
650         struct iovec vec[5];
651         char *test_strs[5];
652         struct msghdr msg;
653         int total_len = 0;
654         int len_cmp = 0;
655         int iov_len = 5;
656         char *buf;
657         int i;
658
659         memset(&msg, 0, sizeof(struct msghdr));
660         for (i = 0; i < iov_len; i++) {
661                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
662                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
663                 vec[i].iov_base = (void *)test_strs[i];
664                 vec[i].iov_len = strlen(test_strs[i]) + 1;
665                 total_len += vec[i].iov_len;
666         }
667         msg.msg_iov = vec;
668         msg.msg_iovlen = iov_len;
669
670         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
671         buf = malloc(total_len);
672         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
673         for (i = 0; i < iov_len; i++) {
674                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
675                                  strlen(test_strs[i])),
676                           0);
677                 len_cmp += strlen(buf + len_cmp) + 1;
678         }
679         for (i = 0; i < iov_len; i++)
680                 free(test_strs[i]);
681         free(buf);
682 }
683
684 TEST_F(tls, sendmsg_multiple_stress)
685 {
686         char const *test_str = "abcdefghijklmno";
687         struct iovec vec[1024];
688         char *test_strs[1024];
689         int iov_len = 1024;
690         int total_len = 0;
691         char buf[1 << 14];
692         struct msghdr msg;
693         int len_cmp = 0;
694         int i;
695
696         memset(&msg, 0, sizeof(struct msghdr));
697         for (i = 0; i < iov_len; i++) {
698                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
699                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
700                 vec[i].iov_base = (void *)test_strs[i];
701                 vec[i].iov_len = strlen(test_strs[i]) + 1;
702                 total_len += vec[i].iov_len;
703         }
704         msg.msg_iov = vec;
705         msg.msg_iovlen = iov_len;
706
707         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
708         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
709
710         for (i = 0; i < iov_len; i++)
711                 len_cmp += strlen(buf + len_cmp) + 1;
712
713         for (i = 0; i < iov_len; i++)
714                 free(test_strs[i]);
715 }
716
717 TEST_F(tls, splice_from_pipe)
718 {
719         int send_len = TLS_PAYLOAD_MAX_LEN;
720         char mem_send[TLS_PAYLOAD_MAX_LEN];
721         char mem_recv[TLS_PAYLOAD_MAX_LEN];
722         int p[2];
723
724         ASSERT_GE(pipe(p), 0);
725         EXPECT_GE(write(p[1], mem_send, send_len), 0);
726         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
727         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
728         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
729 }
730
731 TEST_F(tls, splice_more)
732 {
733         unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
734         int send_len = TLS_PAYLOAD_MAX_LEN;
735         char mem_send[TLS_PAYLOAD_MAX_LEN];
736         int i, send_pipe = 1;
737         int p[2];
738
739         ASSERT_GE(pipe(p), 0);
740         EXPECT_GE(write(p[1], mem_send, send_len), 0);
741         for (i = 0; i < 32; i++)
742                 EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
743 }
744
745 TEST_F(tls, splice_from_pipe2)
746 {
747         int send_len = 16000;
748         char mem_send[16000];
749         char mem_recv[16000];
750         int p2[2];
751         int p[2];
752
753         memrnd(mem_send, sizeof(mem_send));
754
755         ASSERT_GE(pipe(p), 0);
756         ASSERT_GE(pipe(p2), 0);
757         EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
758         EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
759         EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
760         EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
761         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
762         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
763 }
764
765 TEST_F(tls, send_and_splice)
766 {
767         int send_len = TLS_PAYLOAD_MAX_LEN;
768         char mem_send[TLS_PAYLOAD_MAX_LEN];
769         char mem_recv[TLS_PAYLOAD_MAX_LEN];
770         char const *test_str = "test_read";
771         int send_len2 = 10;
772         char buf[10];
773         int p[2];
774
775         ASSERT_GE(pipe(p), 0);
776         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
777         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
778         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
779
780         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
781         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
782
783         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
784         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
785 }
786
787 TEST_F(tls, splice_to_pipe)
788 {
789         int send_len = TLS_PAYLOAD_MAX_LEN;
790         char mem_send[TLS_PAYLOAD_MAX_LEN];
791         char mem_recv[TLS_PAYLOAD_MAX_LEN];
792         int p[2];
793
794         memrnd(mem_send, sizeof(mem_send));
795
796         ASSERT_GE(pipe(p), 0);
797         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
798         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
799         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
800         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
801 }
802
803 TEST_F(tls, splice_cmsg_to_pipe)
804 {
805         char *test_str = "test_read";
806         char record_type = 100;
807         int send_len = 10;
808         char buf[10];
809         int p[2];
810
811         if (self->notls)
812                 SKIP(return, "no TLS support");
813
814         ASSERT_GE(pipe(p), 0);
815         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
816         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
817         EXPECT_EQ(errno, EINVAL);
818         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
819         EXPECT_EQ(errno, EIO);
820         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
821                                 buf, sizeof(buf), MSG_WAITALL),
822                   send_len);
823         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
824 }
825
826 TEST_F(tls, splice_dec_cmsg_to_pipe)
827 {
828         char *test_str = "test_read";
829         char record_type = 100;
830         int send_len = 10;
831         char buf[10];
832         int p[2];
833
834         if (self->notls)
835                 SKIP(return, "no TLS support");
836
837         ASSERT_GE(pipe(p), 0);
838         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
839         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
840         EXPECT_EQ(errno, EIO);
841         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
842         EXPECT_EQ(errno, EINVAL);
843         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
844                                 buf, sizeof(buf), MSG_WAITALL),
845                   send_len);
846         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
847 }
848
849 TEST_F(tls, recv_and_splice)
850 {
851         int send_len = TLS_PAYLOAD_MAX_LEN;
852         char mem_send[TLS_PAYLOAD_MAX_LEN];
853         char mem_recv[TLS_PAYLOAD_MAX_LEN];
854         int half = send_len / 2;
855         int p[2];
856
857         ASSERT_GE(pipe(p), 0);
858         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
859         /* Recv hald of the record, splice the other half */
860         EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
861         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
862                   half);
863         EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
864         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
865 }
866
867 TEST_F(tls, peek_and_splice)
868 {
869         int send_len = TLS_PAYLOAD_MAX_LEN;
870         char mem_send[TLS_PAYLOAD_MAX_LEN];
871         char mem_recv[TLS_PAYLOAD_MAX_LEN];
872         int chunk = TLS_PAYLOAD_MAX_LEN / 4;
873         int n, i, p[2];
874
875         memrnd(mem_send, sizeof(mem_send));
876
877         ASSERT_GE(pipe(p), 0);
878         for (i = 0; i < 4; i++)
879                 EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
880                           chunk);
881
882         EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
883                        MSG_WAITALL | MSG_PEEK),
884                   chunk * 5 / 2);
885         EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
886
887         n = 0;
888         while (n < send_len) {
889                 i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
890                 EXPECT_GT(i, 0);
891                 n += i;
892         }
893         EXPECT_EQ(n, send_len);
894         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
895         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
896 }
897
898 TEST_F(tls, recvmsg_single)
899 {
900         char const *test_str = "test_recvmsg_single";
901         int send_len = strlen(test_str) + 1;
902         char buf[20];
903         struct msghdr hdr;
904         struct iovec vec;
905
906         memset(&hdr, 0, sizeof(hdr));
907         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
908         vec.iov_base = (char *)buf;
909         vec.iov_len = send_len;
910         hdr.msg_iovlen = 1;
911         hdr.msg_iov = &vec;
912         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
913         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
914 }
915
916 TEST_F(tls, recvmsg_single_max)
917 {
918         int send_len = TLS_PAYLOAD_MAX_LEN;
919         char send_mem[TLS_PAYLOAD_MAX_LEN];
920         char recv_mem[TLS_PAYLOAD_MAX_LEN];
921         struct iovec vec;
922         struct msghdr hdr;
923
924         memrnd(send_mem, sizeof(send_mem));
925
926         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
927         vec.iov_base = (char *)recv_mem;
928         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
929
930         hdr.msg_iovlen = 1;
931         hdr.msg_iov = &vec;
932         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
933         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
934 }
935
936 TEST_F(tls, recvmsg_multiple)
937 {
938         unsigned int msg_iovlen = 1024;
939         struct iovec vec[1024];
940         char *iov_base[1024];
941         unsigned int iov_len = 16;
942         int send_len = 1 << 14;
943         char buf[1 << 14];
944         struct msghdr hdr;
945         int i;
946
947         memrnd(buf, sizeof(buf));
948
949         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
950         for (i = 0; i < msg_iovlen; i++) {
951                 iov_base[i] = (char *)malloc(iov_len);
952                 vec[i].iov_base = iov_base[i];
953                 vec[i].iov_len = iov_len;
954         }
955
956         hdr.msg_iovlen = msg_iovlen;
957         hdr.msg_iov = vec;
958         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
959
960         for (i = 0; i < msg_iovlen; i++)
961                 free(iov_base[i]);
962 }
963
964 TEST_F(tls, single_send_multiple_recv)
965 {
966         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
967         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
968         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
969         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
970
971         memrnd(send_mem, sizeof(send_mem));
972
973         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
974         memset(recv_mem, 0, total_len);
975
976         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
977         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
978         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
979 }
980
981 TEST_F(tls, multiple_send_single_recv)
982 {
983         unsigned int total_len = 2 * 10;
984         unsigned int send_len = 10;
985         char recv_mem[2 * 10];
986         char send_mem[10];
987
988         memrnd(send_mem, sizeof(send_mem));
989
990         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
991         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
992         memset(recv_mem, 0, total_len);
993         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
994
995         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
996         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
997 }
998
999 TEST_F(tls, single_send_multiple_recv_non_align)
1000 {
1001         const unsigned int total_len = 15;
1002         const unsigned int recv_len = 10;
1003         char recv_mem[recv_len * 2];
1004         char send_mem[total_len];
1005
1006         memrnd(send_mem, sizeof(send_mem));
1007
1008         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
1009         memset(recv_mem, 0, total_len);
1010
1011         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
1012         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
1013         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
1014 }
1015
1016 TEST_F(tls, recv_partial)
1017 {
1018         char const *test_str = "test_read_partial";
1019         char const *test_str_first = "test_read";
1020         char const *test_str_second = "_partial";
1021         int send_len = strlen(test_str) + 1;
1022         char recv_mem[18];
1023
1024         memset(recv_mem, 0, sizeof(recv_mem));
1025         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1026         EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1027                        MSG_WAITALL), strlen(test_str_first));
1028         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1029         memset(recv_mem, 0, sizeof(recv_mem));
1030         EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1031                        MSG_WAITALL), strlen(test_str_second));
1032         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1033                   0);
1034 }
1035
1036 TEST_F(tls, recv_nonblock)
1037 {
1038         char buf[4096];
1039         bool err;
1040
1041         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1042         err = (errno == EAGAIN || errno == EWOULDBLOCK);
1043         EXPECT_EQ(err, true);
1044 }
1045
1046 TEST_F(tls, recv_peek)
1047 {
1048         char const *test_str = "test_read_peek";
1049         int send_len = strlen(test_str) + 1;
1050         char buf[15];
1051
1052         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1053         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1054         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1055         memset(buf, 0, sizeof(buf));
1056         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1057         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1058 }
1059
1060 TEST_F(tls, recv_peek_multiple)
1061 {
1062         char const *test_str = "test_read_peek";
1063         int send_len = strlen(test_str) + 1;
1064         unsigned int num_peeks = 100;
1065         char buf[15];
1066         int i;
1067
1068         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1069         for (i = 0; i < num_peeks; i++) {
1070                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1071                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1072                 memset(buf, 0, sizeof(buf));
1073         }
1074         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1075         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1076 }
1077
1078 TEST_F(tls, recv_peek_multiple_records)
1079 {
1080         char const *test_str = "test_read_peek_mult_recs";
1081         char const *test_str_first = "test_read_peek";
1082         char const *test_str_second = "_mult_recs";
1083         int len;
1084         char buf[64];
1085
1086         len = strlen(test_str_first);
1087         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1088
1089         len = strlen(test_str_second) + 1;
1090         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1091
1092         len = strlen(test_str_first);
1093         memset(buf, 0, len);
1094         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1095
1096         /* MSG_PEEK can only peek into the current record. */
1097         len = strlen(test_str_first);
1098         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
1099
1100         len = strlen(test_str) + 1;
1101         memset(buf, 0, len);
1102         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
1103
1104         /* Non-MSG_PEEK will advance strparser (and therefore record)
1105          * however.
1106          */
1107         len = strlen(test_str) + 1;
1108         EXPECT_EQ(memcmp(test_str, buf, len), 0);
1109
1110         /* MSG_MORE will hold current record open, so later MSG_PEEK
1111          * will see everything.
1112          */
1113         len = strlen(test_str_first);
1114         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
1115
1116         len = strlen(test_str_second) + 1;
1117         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1118
1119         len = strlen(test_str) + 1;
1120         memset(buf, 0, len);
1121         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
1122
1123         len = strlen(test_str) + 1;
1124         EXPECT_EQ(memcmp(test_str, buf, len), 0);
1125 }
1126
1127 TEST_F(tls, recv_peek_large_buf_mult_recs)
1128 {
1129         char const *test_str = "test_read_peek_mult_recs";
1130         char const *test_str_first = "test_read_peek";
1131         char const *test_str_second = "_mult_recs";
1132         int len;
1133         char buf[64];
1134
1135         len = strlen(test_str_first);
1136         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1137
1138         len = strlen(test_str_second) + 1;
1139         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1140
1141         len = strlen(test_str) + 1;
1142         memset(buf, 0, len);
1143         EXPECT_NE((len = recv(self->cfd, buf, len,
1144                               MSG_PEEK | MSG_WAITALL)), -1);
1145         len = strlen(test_str) + 1;
1146         EXPECT_EQ(memcmp(test_str, buf, len), 0);
1147 }
1148
1149 TEST_F(tls, recv_lowat)
1150 {
1151         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1152         char recv_mem[20];
1153         int lowat = 8;
1154
1155         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1156         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1157
1158         memset(recv_mem, 0, 20);
1159         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1160                              &lowat, sizeof(lowat)), 0);
1161         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1162         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1163         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1164
1165         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1166         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1167 }
1168
1169 TEST_F(tls, bidir)
1170 {
1171         char const *test_str = "test_read";
1172         int send_len = 10;
1173         char buf[10];
1174         int ret;
1175
1176         if (!self->notls) {
1177                 struct tls_crypto_info_keys tls12;
1178
1179                 tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1180                                      &tls12, 0);
1181
1182                 ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1183                                  tls12.len);
1184                 ASSERT_EQ(ret, 0);
1185
1186                 ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1187                                  tls12.len);
1188                 ASSERT_EQ(ret, 0);
1189         }
1190
1191         ASSERT_EQ(strlen(test_str) + 1, send_len);
1192
1193         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1194         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1195         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1196
1197         memset(buf, 0, sizeof(buf));
1198
1199         EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1200         EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1201         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1202 };
1203
1204 TEST_F(tls, pollin)
1205 {
1206         char const *test_str = "test_poll";
1207         struct pollfd fd = { 0, 0, 0 };
1208         char buf[10];
1209         int send_len = 10;
1210
1211         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1212         fd.fd = self->cfd;
1213         fd.events = POLLIN;
1214
1215         EXPECT_EQ(poll(&fd, 1, 20), 1);
1216         EXPECT_EQ(fd.revents & POLLIN, 1);
1217         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1218         /* Test timing out */
1219         EXPECT_EQ(poll(&fd, 1, 20), 0);
1220 }
1221
1222 TEST_F(tls, poll_wait)
1223 {
1224         char const *test_str = "test_poll_wait";
1225         int send_len = strlen(test_str) + 1;
1226         struct pollfd fd = { 0, 0, 0 };
1227         char recv_mem[15];
1228
1229         fd.fd = self->cfd;
1230         fd.events = POLLIN;
1231         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1232         /* Set timeout to inf. secs */
1233         EXPECT_EQ(poll(&fd, 1, -1), 1);
1234         EXPECT_EQ(fd.revents & POLLIN, 1);
1235         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1236 }
1237
1238 TEST_F(tls, poll_wait_split)
1239 {
1240         struct pollfd fd = { 0, 0, 0 };
1241         char send_mem[20] = {};
1242         char recv_mem[15];
1243
1244         fd.fd = self->cfd;
1245         fd.events = POLLIN;
1246         /* Send 20 bytes */
1247         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1248                   sizeof(send_mem));
1249         /* Poll with inf. timeout */
1250         EXPECT_EQ(poll(&fd, 1, -1), 1);
1251         EXPECT_EQ(fd.revents & POLLIN, 1);
1252         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1253                   sizeof(recv_mem));
1254
1255         /* Now the remaining 5 bytes of record data are in TLS ULP */
1256         fd.fd = self->cfd;
1257         fd.events = POLLIN;
1258         EXPECT_EQ(poll(&fd, 1, -1), 1);
1259         EXPECT_EQ(fd.revents & POLLIN, 1);
1260         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1261                   sizeof(send_mem) - sizeof(recv_mem));
1262 }
1263
1264 TEST_F(tls, blocking)
1265 {
1266         size_t data = 100000;
1267         int res = fork();
1268
1269         EXPECT_NE(res, -1);
1270
1271         if (res) {
1272                 /* parent */
1273                 size_t left = data;
1274                 char buf[16384];
1275                 int status;
1276                 int pid2;
1277
1278                 while (left) {
1279                         int res = send(self->fd, buf,
1280                                        left > 16384 ? 16384 : left, 0);
1281
1282                         EXPECT_GE(res, 0);
1283                         left -= res;
1284                 }
1285
1286                 pid2 = wait(&status);
1287                 EXPECT_EQ(status, 0);
1288                 EXPECT_EQ(res, pid2);
1289         } else {
1290                 /* child */
1291                 size_t left = data;
1292                 char buf[16384];
1293
1294                 while (left) {
1295                         int res = recv(self->cfd, buf,
1296                                        left > 16384 ? 16384 : left, 0);
1297
1298                         EXPECT_GE(res, 0);
1299                         left -= res;
1300                 }
1301         }
1302 }
1303
1304 TEST_F(tls, nonblocking)
1305 {
1306         size_t data = 100000;
1307         int sendbuf = 100;
1308         int flags;
1309         int res;
1310
1311         flags = fcntl(self->fd, F_GETFL, 0);
1312         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1313         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1314
1315         /* Ensure nonblocking behavior by imposing a small send
1316          * buffer.
1317          */
1318         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1319                              &sendbuf, sizeof(sendbuf)), 0);
1320
1321         res = fork();
1322         EXPECT_NE(res, -1);
1323
1324         if (res) {
1325                 /* parent */
1326                 bool eagain = false;
1327                 size_t left = data;
1328                 char buf[16384];
1329                 int status;
1330                 int pid2;
1331
1332                 while (left) {
1333                         int res = send(self->fd, buf,
1334                                        left > 16384 ? 16384 : left, 0);
1335
1336                         if (res == -1 && errno == EAGAIN) {
1337                                 eagain = true;
1338                                 usleep(10000);
1339                                 continue;
1340                         }
1341                         EXPECT_GE(res, 0);
1342                         left -= res;
1343                 }
1344
1345                 EXPECT_TRUE(eagain);
1346                 pid2 = wait(&status);
1347
1348                 EXPECT_EQ(status, 0);
1349                 EXPECT_EQ(res, pid2);
1350         } else {
1351                 /* child */
1352                 bool eagain = false;
1353                 size_t left = data;
1354                 char buf[16384];
1355
1356                 while (left) {
1357                         int res = recv(self->cfd, buf,
1358                                        left > 16384 ? 16384 : left, 0);
1359
1360                         if (res == -1 && errno == EAGAIN) {
1361                                 eagain = true;
1362                                 usleep(10000);
1363                                 continue;
1364                         }
1365                         EXPECT_GE(res, 0);
1366                         left -= res;
1367                 }
1368                 EXPECT_TRUE(eagain);
1369         }
1370 }
1371
1372 static void
1373 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1374                bool sendpg, unsigned int n_readers, unsigned int n_writers)
1375 {
1376         const unsigned int n_children = n_readers + n_writers;
1377         const size_t data = 6 * 1000 * 1000;
1378         const size_t file_sz = data / 100;
1379         size_t read_bias, write_bias;
1380         int i, fd, child_id;
1381         char buf[file_sz];
1382         pid_t pid;
1383
1384         /* Only allow multiples for simplicity */
1385         ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1386         read_bias = n_writers / n_readers ?: 1;
1387         write_bias = n_readers / n_writers ?: 1;
1388
1389         /* prep a file to send */
1390         fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1391         ASSERT_GE(fd, 0);
1392
1393         memset(buf, 0xac, file_sz);
1394         ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1395
1396         /* spawn children */
1397         for (child_id = 0; child_id < n_children; child_id++) {
1398                 pid = fork();
1399                 ASSERT_NE(pid, -1);
1400                 if (!pid)
1401                         break;
1402         }
1403
1404         /* parent waits for all children */
1405         if (pid) {
1406                 for (i = 0; i < n_children; i++) {
1407                         int status;
1408
1409                         wait(&status);
1410                         EXPECT_EQ(status, 0);
1411                 }
1412
1413                 return;
1414         }
1415
1416         /* Split threads for reading and writing */
1417         if (child_id < n_readers) {
1418                 size_t left = data * read_bias;
1419                 char rb[8001];
1420
1421                 while (left) {
1422                         int res;
1423
1424                         res = recv(self->cfd, rb,
1425                                    left > sizeof(rb) ? sizeof(rb) : left, 0);
1426
1427                         EXPECT_GE(res, 0);
1428                         left -= res;
1429                 }
1430         } else {
1431                 size_t left = data * write_bias;
1432
1433                 while (left) {
1434                         int res;
1435
1436                         ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1437                         if (sendpg)
1438                                 res = sendfile(self->fd, fd, NULL,
1439                                                left > file_sz ? file_sz : left);
1440                         else
1441                                 res = send(self->fd, buf,
1442                                            left > file_sz ? file_sz : left, 0);
1443
1444                         EXPECT_GE(res, 0);
1445                         left -= res;
1446                 }
1447         }
1448 }
1449
1450 TEST_F(tls, mutliproc_even)
1451 {
1452         test_mutliproc(_metadata, self, false, 6, 6);
1453 }
1454
1455 TEST_F(tls, mutliproc_readers)
1456 {
1457         test_mutliproc(_metadata, self, false, 4, 12);
1458 }
1459
1460 TEST_F(tls, mutliproc_writers)
1461 {
1462         test_mutliproc(_metadata, self, false, 10, 2);
1463 }
1464
1465 TEST_F(tls, mutliproc_sendpage_even)
1466 {
1467         test_mutliproc(_metadata, self, true, 6, 6);
1468 }
1469
1470 TEST_F(tls, mutliproc_sendpage_readers)
1471 {
1472         test_mutliproc(_metadata, self, true, 4, 12);
1473 }
1474
1475 TEST_F(tls, mutliproc_sendpage_writers)
1476 {
1477         test_mutliproc(_metadata, self, true, 10, 2);
1478 }
1479
1480 TEST_F(tls, control_msg)
1481 {
1482         char *test_str = "test_read";
1483         char record_type = 100;
1484         int send_len = 10;
1485         char buf[10];
1486
1487         if (self->notls)
1488                 SKIP(return, "no TLS support");
1489
1490         EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1491                   send_len);
1492         /* Should fail because we didn't provide a control message */
1493         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1494
1495         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1496                                 buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1497                   send_len);
1498         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1499
1500         /* Recv the message again without MSG_PEEK */
1501         memset(buf, 0, sizeof(buf));
1502
1503         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1504                                 buf, sizeof(buf), MSG_WAITALL),
1505                   send_len);
1506         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1507 }
1508
1509 TEST_F(tls, control_msg_nomerge)
1510 {
1511         char *rec1 = "1111";
1512         char *rec2 = "2222";
1513         int send_len = 5;
1514         char buf[15];
1515
1516         if (self->notls)
1517                 SKIP(return, "no TLS support");
1518
1519         EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1520         EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1521
1522         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1523         EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1524
1525         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1526         EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1527
1528         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1529         EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1530
1531         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1532         EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1533 }
1534
1535 TEST_F(tls, data_control_data)
1536 {
1537         char *rec1 = "1111";
1538         char *rec2 = "2222";
1539         char *rec3 = "3333";
1540         int send_len = 5;
1541         char buf[15];
1542
1543         if (self->notls)
1544                 SKIP(return, "no TLS support");
1545
1546         EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1547         EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1548         EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1549
1550         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1551         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1552 }
1553
1554 TEST_F(tls, shutdown)
1555 {
1556         char const *test_str = "test_read";
1557         int send_len = 10;
1558         char buf[10];
1559
1560         ASSERT_EQ(strlen(test_str) + 1, send_len);
1561
1562         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1563         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1564         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1565
1566         shutdown(self->fd, SHUT_RDWR);
1567         shutdown(self->cfd, SHUT_RDWR);
1568 }
1569
1570 TEST_F(tls, shutdown_unsent)
1571 {
1572         char const *test_str = "test_read";
1573         int send_len = 10;
1574
1575         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1576
1577         shutdown(self->fd, SHUT_RDWR);
1578         shutdown(self->cfd, SHUT_RDWR);
1579 }
1580
1581 TEST_F(tls, shutdown_reuse)
1582 {
1583         struct sockaddr_in addr;
1584         int ret;
1585
1586         shutdown(self->fd, SHUT_RDWR);
1587         shutdown(self->cfd, SHUT_RDWR);
1588         close(self->cfd);
1589
1590         addr.sin_family = AF_INET;
1591         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1592         addr.sin_port = 0;
1593
1594         ret = bind(self->fd, &addr, sizeof(addr));
1595         EXPECT_EQ(ret, 0);
1596         ret = listen(self->fd, 10);
1597         EXPECT_EQ(ret, -1);
1598         EXPECT_EQ(errno, EINVAL);
1599
1600         ret = connect(self->fd, &addr, sizeof(addr));
1601         EXPECT_EQ(ret, -1);
1602         EXPECT_EQ(errno, EISCONN);
1603 }
1604
1605 TEST_F(tls, getsockopt)
1606 {
1607         struct tls_crypto_info_keys expect, get;
1608         socklen_t len;
1609
1610         /* get only the version/cipher */
1611         len = sizeof(struct tls_crypto_info);
1612         memrnd(&get, sizeof(get));
1613         EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1614         EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1615         EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1616         EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1617
1618         /* get the full crypto_info */
1619         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect, 0);
1620         len = expect.len;
1621         memrnd(&get, sizeof(get));
1622         EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1623         EXPECT_EQ(len, expect.len);
1624         EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1625         EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1626         EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1627
1628         /* short get should fail */
1629         len = sizeof(struct tls_crypto_info) - 1;
1630         EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1631         EXPECT_EQ(errno, EINVAL);
1632
1633         /* partial get of the cipher data should fail */
1634         len = expect.len - 1;
1635         EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1636         EXPECT_EQ(errno, EINVAL);
1637 }
1638
1639 TEST_F(tls, recv_efault)
1640 {
1641         char *rec1 = "1111111111";
1642         char *rec2 = "2222222222";
1643         struct msghdr hdr = {};
1644         struct iovec iov[2];
1645         char recv_mem[12];
1646         int ret;
1647
1648         if (self->notls)
1649                 SKIP(return, "no TLS support");
1650
1651         EXPECT_EQ(send(self->fd, rec1, 10, 0), 10);
1652         EXPECT_EQ(send(self->fd, rec2, 10, 0), 10);
1653
1654         iov[0].iov_base = recv_mem;
1655         iov[0].iov_len = sizeof(recv_mem);
1656         iov[1].iov_base = NULL; /* broken iov to make process_rx_list fail */
1657         iov[1].iov_len = 1;
1658
1659         hdr.msg_iovlen = 2;
1660         hdr.msg_iov = iov;
1661
1662         EXPECT_EQ(recv(self->cfd, recv_mem, 1, 0), 1);
1663         EXPECT_EQ(recv_mem[0], rec1[0]);
1664
1665         ret = recvmsg(self->cfd, &hdr, 0);
1666         EXPECT_LE(ret, sizeof(recv_mem));
1667         EXPECT_GE(ret, 9);
1668         EXPECT_EQ(memcmp(rec1, recv_mem, 9), 0);
1669         if (ret > 9)
1670                 EXPECT_EQ(memcmp(rec2, recv_mem + 9, ret - 9), 0);
1671 }
1672
#define TLS_RECORD_TYPE_HANDSHAKE      0x16
/* key_update, length 1, update_not_requested */
static const char key_update_msg[] = "\x18\x00\x00\x01\x00";

/* Send key_update_msg on @fd as a record of type
 * TLS_RECORD_TYPE_HANDSHAKE and expect all of it to be accepted.
 *
 * The full array is sent, i.e. sizeof(key_update_msg) bytes, which
 * counts the string literal's implicit trailing NUL.
 */
static void tls_send_keyupdate(struct __test_metadata *_metadata, int fd)
{
	EXPECT_EQ(tls_send_cmsg(fd, TLS_RECORD_TYPE_HANDSHAKE,
				(char *)key_update_msg,
				sizeof(key_update_msg), 0),
		  sizeof(key_update_msg));
}
1684
1685 static void tls_recv_keyupdate(struct __test_metadata *_metadata, int fd, int flags)
1686 {
1687         char buf[100];
1688
1689         EXPECT_EQ(tls_recv_cmsg(_metadata, fd, TLS_RECORD_TYPE_HANDSHAKE, buf, sizeof(buf), flags),
1690                   sizeof(key_update_msg));
1691         EXPECT_EQ(memcmp(buf, key_update_msg, sizeof(key_update_msg)), 0);
1692 }
1693
1694 /* set the key to 0 then 1 for RX, immediately to 1 for TX */
1695 TEST_F(tls_basic, rekey_rx)
1696 {
1697         struct tls_crypto_info_keys tls12_0, tls12_1;
1698         char const *test_str = "test_message";
1699         int send_len = strlen(test_str) + 1;
1700         char buf[20];
1701         int ret;
1702
1703         if (self->notls)
1704                 return;
1705
1706         tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1707                              &tls12_0, 0);
1708         tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1709                              &tls12_1, 1);
1710
1711         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1712         ASSERT_EQ(ret, 0);
1713
1714         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_0, tls12_0.len);
1715         ASSERT_EQ(ret, 0);
1716
1717         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1718         EXPECT_EQ(ret, 0);
1719
1720         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1721         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1722         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1723 }
1724
1725 /* set the key to 0 then 1 for TX, immediately to 1 for RX */
1726 TEST_F(tls_basic, rekey_tx)
1727 {
1728         struct tls_crypto_info_keys tls12_0, tls12_1;
1729         char const *test_str = "test_message";
1730         int send_len = strlen(test_str) + 1;
1731         char buf[20];
1732         int ret;
1733
1734         if (self->notls)
1735                 return;
1736
1737         tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1738                              &tls12_0, 0);
1739         tls_crypto_info_init(TLS_1_3_VERSION, TLS_CIPHER_AES_GCM_128,
1740                              &tls12_1, 1);
1741
1742         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_0, tls12_0.len);
1743         ASSERT_EQ(ret, 0);
1744
1745         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_1, tls12_1.len);
1746         ASSERT_EQ(ret, 0);
1747
1748         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_1, tls12_1.len);
1749         EXPECT_EQ(ret, 0);
1750
1751         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1752         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1753         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1754 }
1755
1756 TEST_F(tls, rekey)
1757 {
1758         char const *test_str_1 = "test_message_before_rekey";
1759         char const *test_str_2 = "test_message_after_rekey";
1760         struct tls_crypto_info_keys tls12;
1761         int send_len;
1762         char buf[100];
1763
1764         if (variant->tls_version != TLS_1_3_VERSION)
1765                 return;
1766
1767         /* initial send/recv */
1768         send_len = strlen(test_str_1) + 1;
1769         EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1770         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1771         EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1772
1773         /* update TX key */
1774         tls_send_keyupdate(_metadata, self->fd);
1775         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1776         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1777
1778         /* send after rekey */
1779         send_len = strlen(test_str_2) + 1;
1780         EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1781
1782         /* can't receive the KeyUpdate without a control message */
1783         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1784
1785         /* get KeyUpdate */
1786         tls_recv_keyupdate(_metadata, self->cfd, 0);
1787
1788         /* recv blocking -> -EKEYEXPIRED */
1789         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1790         EXPECT_EQ(errno, EKEYEXPIRED);
1791
1792         /* recv non-blocking -> -EKEYEXPIRED */
1793         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1794         EXPECT_EQ(errno, EKEYEXPIRED);
1795
1796         /* update RX key */
1797         EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1798
1799         /* recv after rekey */
1800         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1801         EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1802 }
1803
1804 TEST_F(tls, rekey_fail)
1805 {
1806         char const *test_str_1 = "test_message_before_rekey";
1807         char const *test_str_2 = "test_message_after_rekey";
1808         struct tls_crypto_info_keys tls12;
1809         int send_len;
1810         char buf[100];
1811
1812         /* initial send/recv */
1813         send_len = strlen(test_str_1) + 1;
1814         EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1815         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1816         EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1817
1818         /* update TX key */
1819         tls_send_keyupdate(_metadata, self->fd);
1820
1821         if (variant->tls_version != TLS_1_3_VERSION) {
1822                 /* just check that rekey is not supported and return */
1823                 tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1824                 EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1825                 EXPECT_EQ(errno, EBUSY);
1826                 return;
1827         }
1828
1829         /* successful update */
1830         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1831         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1832
1833         /* invalid update: change of version */
1834         tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1835         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1836         EXPECT_EQ(errno, EINVAL);
1837
1838         /* invalid update (RX socket): change of version */
1839         tls_crypto_info_init(TLS_1_2_VERSION, variant->cipher_type, &tls12, 1);
1840         EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), -1);
1841         EXPECT_EQ(errno, EINVAL);
1842
1843         /* invalid update: change of cipher */
1844         if (variant->cipher_type == TLS_CIPHER_AES_GCM_256)
1845                 tls_crypto_info_init(variant->tls_version, TLS_CIPHER_CHACHA20_POLY1305, &tls12, 1);
1846         else
1847                 tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_256, &tls12, 1);
1848         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), -1);
1849         EXPECT_EQ(errno, EINVAL);
1850
1851         /* send after rekey, the invalid updates shouldn't have an effect */
1852         send_len = strlen(test_str_2) + 1;
1853         EXPECT_EQ(send(self->fd, test_str_2, send_len, 0), send_len);
1854
1855         /* can't receive the KeyUpdate without a control message */
1856         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1857
1858         /* get KeyUpdate */
1859         tls_recv_keyupdate(_metadata, self->cfd, 0);
1860
1861         /* recv blocking -> -EKEYEXPIRED */
1862         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), 0), -1);
1863         EXPECT_EQ(errno, EKEYEXPIRED);
1864
1865         /* recv non-blocking -> -EKEYEXPIRED */
1866         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1867         EXPECT_EQ(errno, EKEYEXPIRED);
1868
1869         /* update RX key */
1870         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1871         EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1872
1873         /* recv after rekey */
1874         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1875         EXPECT_EQ(memcmp(buf, test_str_2, send_len), 0);
1876 }
1877
1878 TEST_F(tls, rekey_peek)
1879 {
1880         char const *test_str_1 = "test_message_before_rekey";
1881         struct tls_crypto_info_keys tls12;
1882         int send_len;
1883         char buf[100];
1884
1885         if (variant->tls_version != TLS_1_3_VERSION)
1886                 return;
1887
1888         send_len = strlen(test_str_1) + 1;
1889         EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1890
1891         /* update TX key */
1892         tls_send_keyupdate(_metadata, self->fd);
1893         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1894         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1895
1896         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1897         EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1898
1899         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1900         EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1901
1902         /* can't receive the KeyUpdate without a control message */
1903         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1904
1905         /* peek KeyUpdate */
1906         tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
1907
1908         /* get KeyUpdate */
1909         tls_recv_keyupdate(_metadata, self->cfd, 0);
1910
1911         /* update RX key */
1912         EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1913 }
1914
1915 TEST_F(tls, splice_rekey)
1916 {
1917         int send_len = TLS_PAYLOAD_MAX_LEN / 2;
1918         char mem_send[TLS_PAYLOAD_MAX_LEN];
1919         char mem_recv[TLS_PAYLOAD_MAX_LEN];
1920         struct tls_crypto_info_keys tls12;
1921         int p[2];
1922
1923         if (variant->tls_version != TLS_1_3_VERSION)
1924                 return;
1925
1926         memrnd(mem_send, sizeof(mem_send));
1927
1928         ASSERT_GE(pipe(p), 0);
1929         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
1930
1931         /* update TX key */
1932         tls_send_keyupdate(_metadata, self->fd);
1933         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1934         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1935
1936         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
1937
1938         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1939         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1940         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
1941
1942         /* can't splice the KeyUpdate */
1943         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
1944         EXPECT_EQ(errno, EINVAL);
1945
1946         /* peek KeyUpdate */
1947         tls_recv_keyupdate(_metadata, self->cfd, MSG_PEEK);
1948
1949         /* get KeyUpdate */
1950         tls_recv_keyupdate(_metadata, self->cfd, 0);
1951
1952         /* can't splice before updating the key */
1953         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), -1);
1954         EXPECT_EQ(errno, EKEYEXPIRED);
1955
1956         /* update RX key */
1957         EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
1958
1959         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1960         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1961         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
1962 }
1963
1964 TEST_F(tls, rekey_peek_splice)
1965 {
1966         char const *test_str_1 = "test_message_before_rekey";
1967         struct tls_crypto_info_keys tls12;
1968         int send_len;
1969         char buf[100];
1970         char mem_recv[TLS_PAYLOAD_MAX_LEN];
1971         int p[2];
1972
1973         if (variant->tls_version != TLS_1_3_VERSION)
1974                 return;
1975
1976         ASSERT_GE(pipe(p), 0);
1977
1978         send_len = strlen(test_str_1) + 1;
1979         EXPECT_EQ(send(self->fd, test_str_1, send_len, 0), send_len);
1980
1981         /* update TX key */
1982         tls_send_keyupdate(_metadata, self->fd);
1983         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
1984         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
1985
1986         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1987         EXPECT_EQ(memcmp(buf, test_str_1, send_len), 0);
1988
1989         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, TLS_PAYLOAD_MAX_LEN, 0), send_len);
1990         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
1991         EXPECT_EQ(memcmp(mem_recv, test_str_1, send_len), 0);
1992 }
1993
1994 TEST_F(tls, rekey_getsockopt)
1995 {
1996         struct tls_crypto_info_keys tls12;
1997         struct tls_crypto_info_keys tls12_get;
1998         socklen_t len;
1999
2000         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 0);
2001
2002         len = tls12.len;
2003         EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2004         EXPECT_EQ(len, tls12.len);
2005         EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2006
2007         len = tls12.len;
2008         EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2009         EXPECT_EQ(len, tls12.len);
2010         EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2011
2012         if (variant->tls_version != TLS_1_3_VERSION)
2013                 return;
2014
2015         tls_send_keyupdate(_metadata, self->fd);
2016         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2017         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2018
2019         tls_recv_keyupdate(_metadata, self->cfd, 0);
2020         EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2021
2022         len = tls12.len;
2023         EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &tls12_get, &len), 0);
2024         EXPECT_EQ(len, tls12.len);
2025         EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2026
2027         len = tls12.len;
2028         EXPECT_EQ(getsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12_get, &len), 0);
2029         EXPECT_EQ(len, tls12.len);
2030         EXPECT_EQ(memcmp(&tls12_get, &tls12, tls12.len), 0);
2031 }
2032
2033 TEST_F(tls, rekey_poll_pending)
2034 {
2035         char const *test_str = "test_message_after_rekey";
2036         struct tls_crypto_info_keys tls12;
2037         struct pollfd pfd = { };
2038         int send_len;
2039         int ret;
2040
2041         if (variant->tls_version != TLS_1_3_VERSION)
2042                 return;
2043
2044         /* update TX key */
2045         tls_send_keyupdate(_metadata, self->fd);
2046         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2047         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2048
2049         /* get KeyUpdate */
2050         tls_recv_keyupdate(_metadata, self->cfd, 0);
2051
2052         /* send immediately after rekey */
2053         send_len = strlen(test_str) + 1;
2054         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2055
2056         /* key hasn't been updated, expect cfd to be non-readable */
2057         pfd.fd = self->cfd;
2058         pfd.events = POLLIN;
2059         EXPECT_EQ(poll(&pfd, 1, 0), 0);
2060
2061         ret = fork();
2062         ASSERT_GE(ret, 0);
2063
2064         if (ret) {
2065                 int pid2, status;
2066
2067                 /* wait before installing the new key */
2068                 sleep(1);
2069
2070                 /* update RX key while poll() is sleeping */
2071                 EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2072
2073                 pid2 = wait(&status);
2074                 EXPECT_EQ(pid2, ret);
2075                 EXPECT_EQ(status, 0);
2076         } else {
2077                 pfd.fd = self->cfd;
2078                 pfd.events = POLLIN;
2079                 EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2080
2081                 exit(!__test_passed(_metadata));
2082         }
2083 }
2084
2085 TEST_F(tls, rekey_poll_delay)
2086 {
2087         char const *test_str = "test_message_after_rekey";
2088         struct tls_crypto_info_keys tls12;
2089         struct pollfd pfd = { };
2090         int send_len;
2091         int ret;
2092
2093         if (variant->tls_version != TLS_1_3_VERSION)
2094                 return;
2095
2096         /* update TX key */
2097         tls_send_keyupdate(_metadata, self->fd);
2098         tls_crypto_info_init(variant->tls_version, variant->cipher_type, &tls12, 1);
2099         EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2100
2101         /* get KeyUpdate */
2102         tls_recv_keyupdate(_metadata, self->cfd, 0);
2103
2104         ret = fork();
2105         ASSERT_GE(ret, 0);
2106
2107         if (ret) {
2108                 int pid2, status;
2109
2110                 /* wait before installing the new key */
2111                 sleep(1);
2112
2113                 /* update RX key while poll() is sleeping */
2114                 EXPECT_EQ(setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2115
2116                 sleep(1);
2117                 send_len = strlen(test_str) + 1;
2118                 EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
2119
2120                 pid2 = wait(&status);
2121                 EXPECT_EQ(pid2, ret);
2122                 EXPECT_EQ(status, 0);
2123         } else {
2124                 pfd.fd = self->cfd;
2125                 pfd.events = POLLIN;
2126                 EXPECT_EQ(poll(&pfd, 1, 5000), 1);
2127                 exit(!__test_passed(_metadata));
2128         }
2129 }
2130
/*
 * Per-test state for the tls_err tests: two socket pairs filled in by
 * ulp_sock_pair() during setup, and a flag the tests consult to skip
 * themselves when TLS is not available.
 */
FIXTURE(tls_err)
{
	/* first pair: fd gets the TLS TX key in setup */
	int fd;
	int cfd;
	/* second pair: cfd2 gets the TLS RX key in setup */
	int fd2;
	int cfd2;
	/* set by ulp_sock_pair(); tests SKIP with "no TLS support" when true */
	bool notls;
};
2137
2138 FIXTURE_VARIANT(tls_err)
2139 {
2140         uint16_t tls_version;
2141 };
2142
2143 FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
2144 {
2145         .tls_version = TLS_1_2_VERSION,
2146 };
2147
2148 FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
2149 {
2150         .tls_version = TLS_1_3_VERSION,
2151 };
2152
2153 FIXTURE_SETUP(tls_err)
2154 {
2155         struct tls_crypto_info_keys tls12;
2156         int ret;
2157
2158         tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
2159                              &tls12, 0);
2160
2161         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
2162         ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
2163         if (self->notls)
2164                 return;
2165
2166         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2167         ASSERT_EQ(ret, 0);
2168
2169         ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
2170         ASSERT_EQ(ret, 0);
2171 }
2172
2173 FIXTURE_TEARDOWN(tls_err)
2174 {
2175         close(self->fd);
2176         close(self->cfd);
2177         close(self->fd2);
2178         close(self->cfd2);
2179 }
2180
2181 TEST_F(tls_err, bad_rec)
2182 {
2183         char buf[64];
2184
2185         if (self->notls)
2186                 SKIP(return, "no TLS support");
2187
2188         memset(buf, 0x55, sizeof(buf));
2189         EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
2190         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2191         EXPECT_EQ(errno, EMSGSIZE);
2192         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
2193         EXPECT_EQ(errno, EAGAIN);
2194 }
2195
2196 TEST_F(tls_err, bad_auth)
2197 {
2198         char buf[128];
2199         int n;
2200
2201         if (self->notls)
2202                 SKIP(return, "no TLS support");
2203
2204         memrnd(buf, sizeof(buf) / 2);
2205         EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
2206         n = recv(self->cfd, buf, sizeof(buf), 0);
2207         EXPECT_GT(n, sizeof(buf) / 2);
2208
2209         buf[n - 1]++;
2210
2211         EXPECT_EQ(send(self->fd2, buf, n, 0), n);
2212         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2213         EXPECT_EQ(errno, EBADMSG);
2214         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2215         EXPECT_EQ(errno, EBADMSG);
2216 }
2217
2218 TEST_F(tls_err, bad_in_large_read)
2219 {
2220         char txt[3][64];
2221         char cip[3][128];
2222         char buf[3 * 128];
2223         int i, n;
2224
2225         if (self->notls)
2226                 SKIP(return, "no TLS support");
2227
2228         /* Put 3 records in the sockets */
2229         for (i = 0; i < 3; i++) {
2230                 memrnd(txt[i], sizeof(txt[i]));
2231                 EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
2232                           sizeof(txt[i]));
2233                 n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
2234                 EXPECT_GT(n, sizeof(txt[i]));
2235                 /* Break the third message */
2236                 if (i == 2)
2237                         cip[2][n - 1]++;
2238                 EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
2239         }
2240
2241         /* We should be able to receive the first two messages */
2242         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
2243         EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
2244         EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
2245         /* Third mesasge is bad */
2246         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2247         EXPECT_EQ(errno, EBADMSG);
2248         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2249         EXPECT_EQ(errno, EBADMSG);
2250 }
2251
2252 TEST_F(tls_err, bad_cmsg)
2253 {
2254         char *test_str = "test_read";
2255         int send_len = 10;
2256         char cip[128];
2257         char buf[128];
2258         char txt[64];
2259         int n;
2260
2261         if (self->notls)
2262                 SKIP(return, "no TLS support");
2263
2264         /* Queue up one data record */
2265         memrnd(txt, sizeof(txt));
2266         EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
2267         n = recv(self->cfd, cip, sizeof(cip), 0);
2268         EXPECT_GT(n, sizeof(txt));
2269         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2270
2271         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
2272         n = recv(self->cfd, cip, sizeof(cip), 0);
2273         cip[n - 1]++; /* Break it */
2274         EXPECT_GT(n, send_len);
2275         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
2276
2277         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
2278         EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
2279         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2280         EXPECT_EQ(errno, EBADMSG);
2281         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2282         EXPECT_EQ(errno, EBADMSG);
2283 }
2284
2285 TEST_F(tls_err, timeo)
2286 {
2287         struct timeval tv = { .tv_usec = 10000, };
2288         char buf[128];
2289         int ret;
2290
2291         if (self->notls)
2292                 SKIP(return, "no TLS support");
2293
2294         ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
2295         ASSERT_EQ(ret, 0);
2296
2297         ret = fork();
2298         ASSERT_GE(ret, 0);
2299
2300         if (ret) {
2301                 usleep(1000); /* Give child a head start */
2302
2303                 EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2304                 EXPECT_EQ(errno, EAGAIN);
2305
2306                 EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2307                 EXPECT_EQ(errno, EAGAIN);
2308
2309                 wait(&ret);
2310         } else {
2311                 EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
2312                 EXPECT_EQ(errno, EAGAIN);
2313                 exit(0);
2314         }
2315 }
2316
2317 TEST_F(tls_err, poll_partial_rec)
2318 {
2319         struct pollfd pfd = { };
2320         ssize_t rec_len;
2321         char rec[256];
2322         char buf[128];
2323
2324         if (self->notls)
2325                 SKIP(return, "no TLS support");
2326
2327         pfd.fd = self->cfd2;
2328         pfd.events = POLLIN;
2329         EXPECT_EQ(poll(&pfd, 1, 1), 0);
2330
2331         memrnd(buf, sizeof(buf));
2332         EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2333         rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2334         EXPECT_GT(rec_len, sizeof(buf));
2335
2336         /* Write 100B, not the full record ... */
2337         EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2338         /* ... no full record should mean no POLLIN */
2339         pfd.fd = self->cfd2;
2340         pfd.events = POLLIN;
2341         EXPECT_EQ(poll(&pfd, 1, 1), 0);
2342         /* Now write the rest, and it should all pop out of the other end. */
2343         EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2344         pfd.fd = self->cfd2;
2345         pfd.events = POLLIN;
2346         EXPECT_EQ(poll(&pfd, 1, 1), 1);
2347         EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2348         EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2349 }
2350
2351 TEST_F(tls_err, epoll_partial_rec)
2352 {
2353         struct epoll_event ev, events[10];
2354         ssize_t rec_len;
2355         char rec[256];
2356         char buf[128];
2357         int epollfd;
2358
2359         if (self->notls)
2360                 SKIP(return, "no TLS support");
2361
2362         epollfd = epoll_create1(0);
2363         ASSERT_GE(epollfd, 0);
2364
2365         memset(&ev, 0, sizeof(ev));
2366         ev.events = EPOLLIN;
2367         ev.data.fd = self->cfd2;
2368         ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
2369
2370         EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2371
2372         memrnd(buf, sizeof(buf));
2373         EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2374         rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2375         EXPECT_GT(rec_len, sizeof(buf));
2376
2377         /* Write 100B, not the full record ... */
2378         EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2379         /* ... no full record should mean no POLLIN */
2380         EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
2381         /* Now write the rest, and it should all pop out of the other end. */
2382         EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
2383         EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
2384         EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
2385         EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
2386
2387         close(epollfd);
2388 }
2389
2390 TEST_F(tls_err, poll_partial_rec_async)
2391 {
2392         struct pollfd pfd = { };
2393         ssize_t rec_len;
2394         char rec[256];
2395         char buf[128];
2396         char token;
2397         int p[2];
2398         int ret;
2399
2400         if (self->notls)
2401                 SKIP(return, "no TLS support");
2402
2403         ASSERT_GE(pipe(p), 0);
2404
2405         memrnd(buf, sizeof(buf));
2406         EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
2407         rec_len = recv(self->cfd, rec, sizeof(rec), 0);
2408         EXPECT_GT(rec_len, sizeof(buf));
2409
2410         ret = fork();
2411         ASSERT_GE(ret, 0);
2412
2413         if (ret) {
2414                 int status, pid2;
2415
2416                 close(p[1]);
2417                 usleep(1000); /* Give child a head start */
2418
2419                 EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
2420
2421                 EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
2422
2423                 EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
2424                           rec_len - 100);
2425
2426                 pid2 = wait(&status);
2427                 EXPECT_EQ(pid2, ret);
2428                 EXPECT_EQ(status, 0);
2429         } else {
2430                 close(p[0]);
2431
2432                 /* Child should sleep in poll(), never get a wake */
2433                 pfd.fd = self->cfd2;
2434                 pfd.events = POLLIN;
2435                 EXPECT_EQ(poll(&pfd, 1, 20), 0);
2436
2437                 EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
2438
2439                 pfd.fd = self->cfd2;
2440                 pfd.events = POLLIN;
2441                 EXPECT_EQ(poll(&pfd, 1, 20), 1);
2442
2443                 exit(!__test_passed(_metadata));
2444         }
2445 }
2446
2447 TEST(non_established) {
2448         struct tls12_crypto_info_aes_gcm_256 tls12;
2449         struct sockaddr_in addr;
2450         int sfd, ret, fd;
2451         socklen_t len;
2452
2453         len = sizeof(addr);
2454
2455         memset(&tls12, 0, sizeof(tls12));
2456         tls12.info.version = TLS_1_2_VERSION;
2457         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2458
2459         addr.sin_family = AF_INET;
2460         addr.sin_addr.s_addr = htonl(INADDR_ANY);
2461         addr.sin_port = 0;
2462
2463         fd = socket(AF_INET, SOCK_STREAM, 0);
2464         sfd = socket(AF_INET, SOCK_STREAM, 0);
2465
2466         ret = bind(sfd, &addr, sizeof(addr));
2467         ASSERT_EQ(ret, 0);
2468         ret = listen(sfd, 10);
2469         ASSERT_EQ(ret, 0);
2470
2471         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2472         EXPECT_EQ(ret, -1);
2473         /* TLS ULP not supported */
2474         if (errno == ENOENT)
2475                 return;
2476         EXPECT_EQ(errno, ENOTCONN);
2477
2478         ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2479         EXPECT_EQ(ret, -1);
2480         EXPECT_EQ(errno, ENOTCONN);
2481
2482         ret = getsockname(sfd, &addr, &len);
2483         ASSERT_EQ(ret, 0);
2484
2485         ret = connect(fd, &addr, sizeof(addr));
2486         ASSERT_EQ(ret, 0);
2487
2488         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2489         ASSERT_EQ(ret, 0);
2490
2491         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2492         EXPECT_EQ(ret, -1);
2493         EXPECT_EQ(errno, EEXIST);
2494
2495         close(fd);
2496         close(sfd);
2497 }
2498
2499 TEST(keysizes) {
2500         struct tls12_crypto_info_aes_gcm_256 tls12;
2501         int ret, fd, cfd;
2502         bool notls;
2503
2504         memset(&tls12, 0, sizeof(tls12));
2505         tls12.info.version = TLS_1_2_VERSION;
2506         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2507
2508         ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2509
2510         if (!notls) {
2511                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
2512                                  sizeof(tls12));
2513                 EXPECT_EQ(ret, 0);
2514
2515                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2516                                  sizeof(tls12));
2517                 EXPECT_EQ(ret, 0);
2518         }
2519
2520         close(fd);
2521         close(cfd);
2522 }
2523
2524 TEST(no_pad) {
2525         struct tls12_crypto_info_aes_gcm_256 tls12;
2526         int ret, fd, cfd, val;
2527         socklen_t len;
2528         bool notls;
2529
2530         memset(&tls12, 0, sizeof(tls12));
2531         tls12.info.version = TLS_1_3_VERSION;
2532         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2533
2534         ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2535
2536         if (notls)
2537                 exit(KSFT_SKIP);
2538
2539         ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2540         EXPECT_EQ(ret, 0);
2541
2542         ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2543         EXPECT_EQ(ret, 0);
2544
2545         val = 1;
2546         ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2547                          (void *)&val, sizeof(val));
2548         EXPECT_EQ(ret, 0);
2549
2550         len = sizeof(val);
2551         val = 2;
2552         ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2553                          (void *)&val, &len);
2554         EXPECT_EQ(ret, 0);
2555         EXPECT_EQ(val, 1);
2556         EXPECT_EQ(len, 4);
2557
2558         val = 0;
2559         ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2560                          (void *)&val, sizeof(val));
2561         EXPECT_EQ(ret, 0);
2562
2563         len = sizeof(val);
2564         val = 2;
2565         ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2566                          (void *)&val, &len);
2567         EXPECT_EQ(ret, 0);
2568         EXPECT_EQ(val, 0);
2569         EXPECT_EQ(len, 4);
2570
2571         close(fd);
2572         close(cfd);
2573 }
2574
2575 TEST(tls_v6ops) {
2576         struct tls_crypto_info_keys tls12;
2577         struct sockaddr_in6 addr, addr2;
2578         int sfd, ret, fd;
2579         socklen_t len, len2;
2580
2581         tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12, 0);
2582
2583         addr.sin6_family = AF_INET6;
2584         addr.sin6_addr = in6addr_any;
2585         addr.sin6_port = 0;
2586
2587         fd = socket(AF_INET6, SOCK_STREAM, 0);
2588         sfd = socket(AF_INET6, SOCK_STREAM, 0);
2589
2590         ret = bind(sfd, &addr, sizeof(addr));
2591         ASSERT_EQ(ret, 0);
2592         ret = listen(sfd, 10);
2593         ASSERT_EQ(ret, 0);
2594
2595         len = sizeof(addr);
2596         ret = getsockname(sfd, &addr, &len);
2597         ASSERT_EQ(ret, 0);
2598
2599         ret = connect(fd, &addr, sizeof(addr));
2600         ASSERT_EQ(ret, 0);
2601
2602         len = sizeof(addr);
2603         ret = getsockname(fd, &addr, &len);
2604         ASSERT_EQ(ret, 0);
2605
2606         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2607         if (ret) {
2608                 ASSERT_EQ(errno, ENOENT);
2609                 SKIP(return, "no TLS support");
2610         }
2611         ASSERT_EQ(ret, 0);
2612
2613         ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2614         ASSERT_EQ(ret, 0);
2615
2616         ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2617         ASSERT_EQ(ret, 0);
2618
2619         len2 = sizeof(addr2);
2620         ret = getsockname(fd, &addr2, &len2);
2621         ASSERT_EQ(ret, 0);
2622
2623         EXPECT_EQ(len2, len);
2624         EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
2625
2626         close(fd);
2627         close(sfd);
2628 }
2629
2630 TEST(prequeue) {
2631         struct tls_crypto_info_keys tls12;
2632         char buf[20000], buf2[20000];
2633         struct sockaddr_in addr;
2634         int sfd, cfd, ret, fd;
2635         socklen_t len;
2636
2637         len = sizeof(addr);
2638         memrnd(buf, sizeof(buf));
2639
2640         tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12, 0);
2641
2642         addr.sin_family = AF_INET;
2643         addr.sin_addr.s_addr = htonl(INADDR_ANY);
2644         addr.sin_port = 0;
2645
2646         fd = socket(AF_INET, SOCK_STREAM, 0);
2647         sfd = socket(AF_INET, SOCK_STREAM, 0);
2648
2649         ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
2650         ASSERT_EQ(listen(sfd, 10), 0);
2651         ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
2652         ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
2653         ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
2654         close(sfd);
2655
2656         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2657         if (ret) {
2658                 ASSERT_EQ(errno, ENOENT);
2659                 SKIP(return, "no TLS support");
2660         }
2661
2662         ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2663         EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
2664
2665         ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
2666         ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2667         EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
2668
2669         EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
2670
2671         close(fd);
2672         close(cfd);
2673 }
2674
2675 static void __attribute__((constructor)) fips_check(void) {
2676         int res;
2677         FILE *f;
2678
2679         f = fopen("/proc/sys/crypto/fips_enabled", "r");
2680         if (f) {
2681                 res = fscanf(f, "%d", &fips_enabled);
2682                 if (res != 1)
2683                         ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
2684                 fclose(f);
2685         }
2686 }
2687
2688 TEST_HARNESS_MAIN
This page took 0.185313 seconds and 4 git commands to generate.