summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard/noise.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard/noise.c')
-rw-r--r--drivers/net/wireguard/noise.c55
1 files changed, 29 insertions, 26 deletions
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index 919d9d866446..708dc61c974f 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -44,32 +44,23 @@ void __init wg_noise_init(void)
}
/* Must hold peer->handshake.static_identity->lock */
-bool wg_noise_precompute_static_static(struct wg_peer *peer)
+void wg_noise_precompute_static_static(struct wg_peer *peer)
{
- bool ret;
-
down_write(&peer->handshake.lock);
- if (peer->handshake.static_identity->has_identity) {
- ret = curve25519(
- peer->handshake.precomputed_static_static,
+ if (!peer->handshake.static_identity->has_identity ||
+ !curve25519(peer->handshake.precomputed_static_static,
peer->handshake.static_identity->static_private,
- peer->handshake.remote_static);
- } else {
- u8 empty[NOISE_PUBLIC_KEY_LEN] = { 0 };
-
- ret = curve25519(empty, empty, peer->handshake.remote_static);
+ peer->handshake.remote_static))
memset(peer->handshake.precomputed_static_static, 0,
NOISE_PUBLIC_KEY_LEN);
- }
up_write(&peer->handshake.lock);
- return ret;
}
-bool wg_noise_handshake_init(struct noise_handshake *handshake,
- struct noise_static_identity *static_identity,
- const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
- const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
- struct wg_peer *peer)
+void wg_noise_handshake_init(struct noise_handshake *handshake,
+ struct noise_static_identity *static_identity,
+ const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
+ const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
+ struct wg_peer *peer)
{
memset(handshake, 0, sizeof(*handshake));
init_rwsem(&handshake->lock);
@@ -81,7 +72,7 @@ bool wg_noise_handshake_init(struct noise_handshake *handshake,
NOISE_SYMMETRIC_KEY_LEN);
handshake->static_identity = static_identity;
handshake->state = HANDSHAKE_ZEROED;
- return wg_noise_precompute_static_static(peer);
+ wg_noise_precompute_static_static(peer);
}
static void handshake_zero(struct noise_handshake *handshake)
@@ -403,6 +394,19 @@ static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
return true;
}
+static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
+ u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
+{
+ static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
+ if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
+ return false;
+ kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
+ NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
+ chaining_key);
+ return true;
+}
+
static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
{
struct blake2s_state blake;
@@ -531,10 +535,9 @@ wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
/* ss */
- kdf(handshake->chaining_key, key, NULL,
- handshake->precomputed_static_static, NOISE_HASH_LEN,
- NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
- handshake->chaining_key);
+ if (!mix_precomputed_dh(handshake->chaining_key, key,
+ handshake->precomputed_static_static))
+ goto out;
/* {t} */
tai64n_now(timestamp);
@@ -595,9 +598,9 @@ wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
handshake = &peer->handshake;
/* ss */
- kdf(chaining_key, key, NULL, handshake->precomputed_static_static,
- NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
- chaining_key);
+ if (!mix_precomputed_dh(chaining_key, key,
+ handshake->precomputed_static_static))
+ goto out;
/* {t} */
if (!message_decrypt(t, src->encrypted_timestamp,