]> Git Repo - linux.git/blob - drivers/net/wireguard/noise.c
drm/amdgpu:/navi10: use the ODCAP enum to index the caps array
[linux.git] / drivers / net / wireguard / noise.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12
#include <linux/rcupdate.h>
#include <linux/slab.h>
#include <linux/bitmap.h>
#include <linux/scatterlist.h>
#include <linux/highmem.h>
#include <crypto/algapi.h>
#include <asm/unaligned.h>
19
20 /* This implements Noise_IKpsk2:
21  *
22  * <- s
23  * ******
24  * -> e, es, s, ss, {t}
25  * <- e, ee, se, psk, {}
26  */
27
28 static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
29 static const u8 identifier_name[34] = "WireGuard v1 zx2c4 [email protected]";
30 static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
31 static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
32 static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33
34 void __init wg_noise_init(void)
35 {
36         struct blake2s_state blake;
37
38         blake2s(handshake_init_chaining_key, handshake_name, NULL,
39                 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40         blake2s_init(&blake, NOISE_HASH_LEN);
41         blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42         blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43         blake2s_final(&blake, handshake_init_hash);
44 }
45
46 /* Must hold peer->handshake.static_identity->lock */
47 bool wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49         bool ret;
50
51         down_write(&peer->handshake.lock);
52         if (peer->handshake.static_identity->has_identity) {
53                 ret = curve25519(
54                         peer->handshake.precomputed_static_static,
55                         peer->handshake.static_identity->static_private,
56                         peer->handshake.remote_static);
57         } else {
58                 u8 empty[NOISE_PUBLIC_KEY_LEN] = { 0 };
59
60                 ret = curve25519(empty, empty, peer->handshake.remote_static);
61                 memset(peer->handshake.precomputed_static_static, 0,
62                        NOISE_PUBLIC_KEY_LEN);
63         }
64         up_write(&peer->handshake.lock);
65         return ret;
66 }
67
68 bool wg_noise_handshake_init(struct noise_handshake *handshake,
69                            struct noise_static_identity *static_identity,
70                            const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
71                            const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
72                            struct wg_peer *peer)
73 {
74         memset(handshake, 0, sizeof(*handshake));
75         init_rwsem(&handshake->lock);
76         handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
77         handshake->entry.peer = peer;
78         memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
79         if (peer_preshared_key)
80                 memcpy(handshake->preshared_key, peer_preshared_key,
81                        NOISE_SYMMETRIC_KEY_LEN);
82         handshake->static_identity = static_identity;
83         handshake->state = HANDSHAKE_ZEROED;
84         return wg_noise_precompute_static_static(peer);
85 }
86
87 static void handshake_zero(struct noise_handshake *handshake)
88 {
89         memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
90         memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
91         memset(&handshake->hash, 0, NOISE_HASH_LEN);
92         memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
93         handshake->remote_index = 0;
94         handshake->state = HANDSHAKE_ZEROED;
95 }
96
97 void wg_noise_handshake_clear(struct noise_handshake *handshake)
98 {
99         wg_index_hashtable_remove(
100                         handshake->entry.peer->device->index_hashtable,
101                         &handshake->entry);
102         down_write(&handshake->lock);
103         handshake_zero(handshake);
104         up_write(&handshake->lock);
105         wg_index_hashtable_remove(
106                         handshake->entry.peer->device->index_hashtable,
107                         &handshake->entry);
108 }
109
110 static struct noise_keypair *keypair_create(struct wg_peer *peer)
111 {
112         struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
113
114         if (unlikely(!keypair))
115                 return NULL;
116         keypair->internal_id = atomic64_inc_return(&keypair_counter);
117         keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
118         keypair->entry.peer = peer;
119         kref_init(&keypair->refcount);
120         return keypair;
121 }
122
123 static void keypair_free_rcu(struct rcu_head *rcu)
124 {
125         kzfree(container_of(rcu, struct noise_keypair, rcu));
126 }
127
128 static void keypair_free_kref(struct kref *kref)
129 {
130         struct noise_keypair *keypair =
131                 container_of(kref, struct noise_keypair, refcount);
132
133         net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
134                             keypair->entry.peer->device->dev->name,
135                             keypair->internal_id,
136                             keypair->entry.peer->internal_id);
137         wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
138                                   &keypair->entry);
139         call_rcu(&keypair->rcu, keypair_free_rcu);
140 }
141
142 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
143 {
144         if (unlikely(!keypair))
145                 return;
146         if (unlikely(unreference_now))
147                 wg_index_hashtable_remove(
148                         keypair->entry.peer->device->index_hashtable,
149                         &keypair->entry);
150         kref_put(&keypair->refcount, keypair_free_kref);
151 }
152
153 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
154 {
155         RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
156                 "Taking noise keypair reference without holding the RCU BH read lock");
157         if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
158                 return NULL;
159         return keypair;
160 }
161
/* Empty all three keypair slots of @keypairs (next, then previous, then
 * current) under keypair_update_lock. Each displaced keypair is removed from
 * the index hashtable right away and has its slot reference dropped via
 * wg_noise_keypair_put(..., true).
 */
void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
{
	struct noise_keypair *old;

	spin_lock_bh(&keypairs->keypair_update_lock);

	/* We zero the next_keypair before zeroing the others, so that
	 * wg_noise_received_with_keypair returns early before subsequent ones
	 * are zeroed.
	 */
	old = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
	wg_noise_keypair_put(old, true);

	old = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
	wg_noise_keypair_put(old, true);

	old = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
	wg_noise_keypair_put(old, true);

	spin_unlock_bh(&keypairs->keypair_update_lock);
}
189
190 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
191 {
192         struct noise_keypair *keypair;
193
194         wg_noise_handshake_clear(&peer->handshake);
195         wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
196
197         spin_lock_bh(&peer->keypairs.keypair_update_lock);
198         keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
199                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
200         if (keypair)
201                 keypair->sending.is_valid = false;
202         keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
203                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
204         if (keypair)
205                 keypair->sending.is_valid = false;
206         spin_unlock_bh(&peer->keypairs.keypair_update_lock);
207 }
208
/* Install @new_keypair into the slots of @keypairs under keypair_update_lock.
 * As initiator, the new keypair becomes current immediately. As responder, it
 * is parked in the next slot until wg_noise_received_with_keypair() promotes
 * it. Displaced keypairs are released with wg_noise_keypair_put(..., true).
 * NOTE(review): no reference is taken here, so the slot presumably adopts the
 * caller's reference to @new_keypair — confirm against callers.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
			    struct noise_keypair *new_keypair)
{
	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

	spin_lock_bh(&keypairs->keypair_update_lock);
	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	if (new_keypair->i_am_the_initiator) {
		/* If we're the initiator, it means we've sent a handshake, and
		 * received a confirmation response, which means this new
		 * keypair can now be used.
		 */
		if (next_keypair) {
			/* If there already was a next keypair pending, we
			 * demote it to be the previous keypair, and free the
			 * existing current. Note that this means KCI can result
			 * in this transition. It would perhaps be more sound to
			 * always just get rid of the unused next keypair
			 * instead of putting it in the previous slot, but this
			 * might be a bit less robust. Something to think about
			 * for the future.
			 */
			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
			rcu_assign_pointer(keypairs->previous_keypair,
					   next_keypair);
			wg_noise_keypair_put(current_keypair, true);
		} else /* If there wasn't an existing next keypair, we replace
			* the previous with the current one.
			*/
			rcu_assign_pointer(keypairs->previous_keypair,
					   current_keypair);
		/* At this point we can get rid of the old previous keypair, and
		 * set up the new keypair.
		 */
		wg_noise_keypair_put(previous_keypair, true);
		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
	} else {
		/* If we're the responder, it means we can't use the new keypair
		 * until we receive confirmation via the first data packet, so
		 * we get rid of the existing previous one, the possibly
		 * existing next one, and slide in the new next one.
		 */
		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
		wg_noise_keypair_put(next_keypair, true);
		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
		wg_noise_keypair_put(previous_keypair, true);
	}
	spin_unlock_bh(&keypairs->keypair_update_lock);
}
263
/* If @received_keypair is the pending next keypair of @keypairs, promote it:
 *  - next becomes current;
 *  - the old current becomes previous;
 *  - the old previous is released.
 * Returns true if that promotion happened, false otherwise.
 * A lockless rcu_access_pointer() comparison keeps the common (no-op) case
 * cheap, and the match is re-validated under keypair_update_lock before
 * anything is changed.
 * NOTE(review): presumably called after a packet authenticates with
 * @received_keypair — confirm against callers.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
				    struct noise_keypair *received_keypair)
{
	struct noise_keypair *old_keypair;
	bool key_is_new;

	/* We first check without taking the spinlock. */
	key_is_new = received_keypair ==
		     rcu_access_pointer(keypairs->next_keypair);
	if (likely(!key_is_new))
		return false;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* After locking, we double check that things didn't change from
	 * beneath us.
	 */
	if (unlikely(received_keypair !=
		    rcu_dereference_protected(keypairs->next_keypair,
			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
		spin_unlock_bh(&keypairs->keypair_update_lock);
		return false;
	}

	/* When we've finally received the confirmation, we slide the next
	 * into the current, the current into the previous, and get rid of
	 * the old previous.
	 */
	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	rcu_assign_pointer(keypairs->previous_keypair,
		rcu_dereference_protected(keypairs->current_keypair,
			lockdep_is_held(&keypairs->keypair_update_lock)));
	wg_noise_keypair_put(old_keypair, true);
	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
	RCU_INIT_POINTER(keypairs->next_keypair, NULL);

	spin_unlock_bh(&keypairs->keypair_update_lock);
	return true;
}
303
304 /* Must hold static_identity->lock */
305 void wg_noise_set_static_identity_private_key(
306         struct noise_static_identity *static_identity,
307         const u8 private_key[NOISE_PUBLIC_KEY_LEN])
308 {
309         memcpy(static_identity->static_private, private_key,
310                NOISE_PUBLIC_KEY_LEN);
311         curve25519_clamp_secret(static_identity->static_private);
312         static_identity->has_identity = curve25519_generate_public(
313                 static_identity->static_public, private_key);
314 }
315
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * Instantiated with HMAC-BLAKE2s-256:
 *   secret = HMAC(chaining_key, data)
 *   T1     = HMAC(secret, 0x1)
 *   T2     = HMAC(secret, T1 || 0x2)
 *   T3     = HMAC(secret, T2 || 0x3)
 * first_dst, second_dst and third_dst receive the first *_len bytes of T1,
 * T2 and T3 respectively. Expansion stops at the first NULL destination or
 * zero length.
 *
 * chaining_key may alias first_dst: it is only read in the extract step,
 * before first_dst is written. Callers in this file rely on that.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
		size_t first_len, size_t second_len, size_t third_len,
		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
	/* One extra byte so the counter can be appended after the previous
	 * output block when expanding T2 and T3.
	 */
	u8 output[BLAKE2S_HASH_SIZE + 1];
	u8 secret[BLAKE2S_HASH_SIZE];

	/* Argument sanity checks, only active when DEBUG is enabled. */
	WARN_ON(IS_ENABLED(DEBUG) &&
		(first_len > BLAKE2S_HASH_SIZE ||
		 second_len > BLAKE2S_HASH_SIZE ||
		 third_len > BLAKE2S_HASH_SIZE ||
		 ((second_len || second_dst || third_len || third_dst) &&
		  (!first_len || !first_dst)) ||
		 ((third_len || third_dst) && (!second_len || !second_dst))));

	/* Extract entropy from data into secret */
	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

	if (!first_dst || !first_len)
		goto out;

	/* Expand first key: key = secret, data = 0x1 */
	output[0] = 1;
	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
	memcpy(first_dst, output, first_len);

	if (!second_dst || !second_len)
		goto out;

	/* Expand second key: key = secret, data = first-key || 0x2 */
	output[BLAKE2S_HASH_SIZE] = 2;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(second_dst, output, second_len);

	if (!third_dst || !third_len)
		goto out;

	/* Expand third key: key = secret, data = second-key || 0x3 */
	output[BLAKE2S_HASH_SIZE] = 3;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(third_dst, output, third_len);

out:
	/* Clear sensitive data from stack */
	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
369
370 static void symmetric_key_init(struct noise_symmetric_key *key)
371 {
372         spin_lock_init(&key->counter.receive.lock);
373         atomic64_set(&key->counter.counter, 0);
374         memset(key->counter.receive.backtrack, 0,
375                sizeof(key->counter.receive.backtrack));
376         key->birthdate = ktime_get_coarse_boottime_ns();
377         key->is_valid = true;
378 }
379
380 static void derive_keys(struct noise_symmetric_key *first_dst,
381                         struct noise_symmetric_key *second_dst,
382                         const u8 chaining_key[NOISE_HASH_LEN])
383 {
384         kdf(first_dst->key, second_dst->key, NULL, NULL,
385             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
386             chaining_key);
387         symmetric_key_init(first_dst);
388         symmetric_key_init(second_dst);
389 }
390
391 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
392                                 u8 key[NOISE_SYMMETRIC_KEY_LEN],
393                                 const u8 private[NOISE_PUBLIC_KEY_LEN],
394                                 const u8 public[NOISE_PUBLIC_KEY_LEN])
395 {
396         u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
397
398         if (unlikely(!curve25519(dh_calculation, private, public)))
399                 return false;
400         kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
401             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
402         memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
403         return true;
404 }
405
406 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
407 {
408         struct blake2s_state blake;
409
410         blake2s_init(&blake, NOISE_HASH_LEN);
411         blake2s_update(&blake, hash, NOISE_HASH_LEN);
412         blake2s_update(&blake, src, src_len);
413         blake2s_final(&blake, hash);
414 }
415
416 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
417                     u8 key[NOISE_SYMMETRIC_KEY_LEN],
418                     const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
419 {
420         u8 temp_hash[NOISE_HASH_LEN];
421
422         kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
423             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
424         mix_hash(hash, temp_hash, NOISE_HASH_LEN);
425         memzero_explicit(temp_hash, NOISE_HASH_LEN);
426 }
427
428 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
429                            u8 hash[NOISE_HASH_LEN],
430                            const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
431 {
432         memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
433         memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
434         mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
435 }
436
437 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
438                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
439                             u8 hash[NOISE_HASH_LEN])
440 {
441         chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
442                                  NOISE_HASH_LEN,
443                                  0 /* Always zero for Noise_IK */, key);
444         mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
445 }
446
447 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
448                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
449                             u8 hash[NOISE_HASH_LEN])
450 {
451         if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
452                                       hash, NOISE_HASH_LEN,
453                                       0 /* Always zero for Noise_IK */, key))
454                 return false;
455         mix_hash(hash, src_ciphertext, src_len);
456         return true;
457 }
458
459 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
460                               const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
461                               u8 chaining_key[NOISE_HASH_LEN],
462                               u8 hash[NOISE_HASH_LEN])
463 {
464         if (ephemeral_dst != ephemeral_src)
465                 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
466         mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
467         kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
468             NOISE_PUBLIC_KEY_LEN, chaining_key);
469 }
470
471 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
472 {
473         struct timespec64 now;
474
475         ktime_get_real_ts64(&now);
476
477         /* In order to prevent some sort of infoleak from precise timers, we
478          * round down the nanoseconds part to the closest rounded-down power of
479          * two to the maximum initiations per second allowed anyway by the
480          * implementation.
481          */
482         now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
483                 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
484
485         /* https://cr.yp.to/libtai/tai64.html */
486         *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
487         *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
488 }
489
/*
 * Initiator side: fill @dst with a handshake initiation for @handshake's peer.
 *
 * Performs the initiator's half of the first message (e, es, s, ss, {t}) and
 * leaves the running chaining key and hash in @handshake. It then inserts the
 * handshake into the device's index hashtable to obtain dst->sender_index,
 * and sets the state to HANDSHAKE_CREATED_INITIATION.
 *
 * Returns true on success. Returns false, with @dst possibly partially
 * written, when there is no local static identity or a curve25519 step
 * fails.
 *
 * Locking: waits for the CRNG first, then takes static_identity->lock for
 * reading and handshake->lock for writing.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss: uses the static-static DH cached by
	 * wg_noise_precompute_static_static() instead of recomputing it.
	 */
	kdf(handshake->chaining_key, key, NULL,
	    handshake->precomputed_static_static, NOISE_HASH_LEN,
	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
	    handshake->chaining_key);

	/* {t} */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	/* Publish the handshake so the response can find it by index. */
	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The transient message key is wiped on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
557
558 struct wg_peer *
559 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
560                                       struct wg_device *wg)
561 {
562         struct wg_peer *peer = NULL, *ret_peer = NULL;
563         struct noise_handshake *handshake;
564         bool replay_attack, flood_attack;
565         u8 key[NOISE_SYMMETRIC_KEY_LEN];
566         u8 chaining_key[NOISE_HASH_LEN];
567         u8 hash[NOISE_HASH_LEN];
568         u8 s[NOISE_PUBLIC_KEY_LEN];
569         u8 e[NOISE_PUBLIC_KEY_LEN];
570         u8 t[NOISE_TIMESTAMP_LEN];
571         u64 initiation_consumption;
572
573         down_read(&wg->static_identity.lock);
574         if (unlikely(!wg->static_identity.has_identity))
575                 goto out;
576
577         handshake_init(chaining_key, hash, wg->static_identity.static_public);
578
579         /* e */
580         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
581
582         /* es */
583         if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
584                 goto out;
585
586         /* s */
587         if (!message_decrypt(s, src->encrypted_static,
588                              sizeof(src->encrypted_static), key, hash))
589                 goto out;
590
591         /* Lookup which peer we're actually talking to */
592         peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
593         if (!peer)
594                 goto out;
595         handshake = &peer->handshake;
596
597         /* ss */
598         kdf(chaining_key, key, NULL, handshake->precomputed_static_static,
599             NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
600             chaining_key);
601
602         /* {t} */
603         if (!message_decrypt(t, src->encrypted_timestamp,
604                              sizeof(src->encrypted_timestamp), key, hash))
605                 goto out;
606
607         down_read(&handshake->lock);
608         replay_attack = memcmp(t, handshake->latest_timestamp,
609                                NOISE_TIMESTAMP_LEN) <= 0;
610         flood_attack = (s64)handshake->last_initiation_consumption +
611                                NSEC_PER_SEC / INITIATIONS_PER_SECOND >
612                        (s64)ktime_get_coarse_boottime_ns();
613         up_read(&handshake->lock);
614         if (replay_attack || flood_attack)
615                 goto out;
616
617         /* Success! Copy everything to peer */
618         down_write(&handshake->lock);
619         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
620         if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
621                 memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
622         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
623         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
624         handshake->remote_index = src->sender_index;
625         if ((s64)(handshake->last_initiation_consumption -
626             (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
627                 handshake->last_initiation_consumption = initiation_consumption;
628         handshake->state = HANDSHAKE_CONSUMED_INITIATION;
629         up_write(&handshake->lock);
630         ret_peer = peer;
631
632 out:
633         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
634         memzero_explicit(hash, NOISE_HASH_LEN);
635         memzero_explicit(chaining_key, NOISE_HASH_LEN);
636         up_read(&wg->static_identity.lock);
637         if (!ret_peer)
638                 wg_peer_put(peer);
639         return ret_peer;
640 }
641
/*
 * Responder side: fill @dst with a handshake response for @handshake.
 *
 * Only proceeds when the state is HANDSHAKE_CONSUMED_INITIATION. It then
 * performs the responder's half of the second message (e, ee, se, psk, {})
 * on the transcript stored in @handshake. It inserts the handshake into the
 * device's index hashtable to obtain dst->sender_index and sets the state to
 * HANDSHAKE_CREATED_RESPONSE.
 *
 * Returns true on success. Returns false, with @dst possibly partially
 * written, on a wrong state or a failed curve25519 step.
 *
 * Locking: waits for the CRNG first, then takes static_identity->lock for
 * reading and handshake->lock for writing.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	dst->receiver_index = handshake->remote_index;

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee: NULL key, so only the chaining key is updated here. */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se: likewise chaining key only. */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk: this step derives the message key used for {} below. */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {} */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	/* Publish the handshake so the initiator's reply can find it by index. */
	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The transient message key is wiped on every path. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
701
702 struct wg_peer *
703 wg_noise_handshake_consume_response(struct message_handshake_response *src,
704                                     struct wg_device *wg)
705 {
706         enum noise_handshake_state state = HANDSHAKE_ZEROED;
707         struct wg_peer *peer = NULL, *ret_peer = NULL;
708         struct noise_handshake *handshake;
709         u8 key[NOISE_SYMMETRIC_KEY_LEN];
710         u8 hash[NOISE_HASH_LEN];
711         u8 chaining_key[NOISE_HASH_LEN];
712         u8 e[NOISE_PUBLIC_KEY_LEN];
713         u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
714         u8 static_private[NOISE_PUBLIC_KEY_LEN];
715
716         down_read(&wg->static_identity.lock);
717
718         if (unlikely(!wg->static_identity.has_identity))
719                 goto out;
720
721         handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
722                 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
723                 src->receiver_index, &peer);
724         if (unlikely(!handshake))
725                 goto out;
726
727         down_read(&handshake->lock);
728         state = handshake->state;
729         memcpy(hash, handshake->hash, NOISE_HASH_LEN);
730         memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
731         memcpy(ephemeral_private, handshake->ephemeral_private,
732                NOISE_PUBLIC_KEY_LEN);
733         up_read(&handshake->lock);
734
735         if (state != HANDSHAKE_CREATED_INITIATION)
736                 goto fail;
737
738         /* e */
739         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
740
741         /* ee */
742         if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
743                 goto fail;
744
745         /* se */
746         if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
747                 goto fail;
748
749         /* psk */
750         mix_psk(chaining_key, hash, key, handshake->preshared_key);
751
752         /* {} */
753         if (!message_decrypt(NULL, src->encrypted_nothing,
754                              sizeof(src->encrypted_nothing), key, hash))
755                 goto fail;
756
757         /* Success! Copy everything to peer */
758         down_write(&handshake->lock);
759         /* It's important to check that the state is still the same, while we
760          * have an exclusive lock.
761          */
762         if (handshake->state != state) {
763                 up_write(&handshake->lock);
764                 goto fail;
765         }
766         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
767         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
768         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
769         handshake->remote_index = src->sender_index;
770         handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
771         up_write(&handshake->lock);
772         ret_peer = peer;
773         goto out;
774
775 fail:
776         wg_peer_put(peer);
777 out:
778         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
779         memzero_explicit(hash, NOISE_HASH_LEN);
780         memzero_explicit(chaining_key, NOISE_HASH_LEN);
781         memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
782         memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
783         up_read(&wg->static_identity.lock);
784         return ret_peer;
785 }
786
787 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
788                                       struct noise_keypairs *keypairs)
789 {
790         struct noise_keypair *new_keypair;
791         bool ret = false;
792
793         down_write(&handshake->lock);
794         if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
795             handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
796                 goto out;
797
798         new_keypair = keypair_create(handshake->entry.peer);
799         if (!new_keypair)
800                 goto out;
801         new_keypair->i_am_the_initiator = handshake->state ==
802                                           HANDSHAKE_CONSUMED_RESPONSE;
803         new_keypair->remote_index = handshake->remote_index;
804
805         if (new_keypair->i_am_the_initiator)
806                 derive_keys(&new_keypair->sending, &new_keypair->receiving,
807                             handshake->chaining_key);
808         else
809                 derive_keys(&new_keypair->receiving, &new_keypair->sending,
810                             handshake->chaining_key);
811
812         handshake_zero(handshake);
813         rcu_read_lock_bh();
814         if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
815                                            handshake)->is_dead))) {
816                 add_new_keypair(keypairs, new_keypair);
817                 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
818                                     handshake->entry.peer->device->dev->name,
819                                     new_keypair->internal_id,
820                                     handshake->entry.peer->internal_id);
821                 ret = wg_index_hashtable_replace(
822                         handshake->entry.peer->device->index_hashtable,
823                         &handshake->entry, &new_keypair->entry);
824         } else {
825                 kzfree(new_keypair);
826         }
827         rcu_read_unlock_bh();
828
829 out:
830         up_write(&handshake->lock);
831         return ret;
832 }
This page took 0.078385 seconds and 4 git commands to generate.