summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r--drivers/net/wireguard/Makefile2
-rw-r--r--drivers/net/wireguard/allowedips.c102
-rw-r--r--drivers/net/wireguard/allowedips.h4
-rw-r--r--drivers/net/wireguard/cookie.c22
-rw-r--r--drivers/net/wireguard/device.c16
-rw-r--r--drivers/net/wireguard/generated/netlink.c73
-rw-r--r--drivers/net/wireguard/generated/netlink.h30
-rw-r--r--drivers/net/wireguard/netlink.c107
-rw-r--r--drivers/net/wireguard/noise.c36
-rw-r--r--drivers/net/wireguard/peer.h2
-rw-r--r--drivers/net/wireguard/queueing.h13
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c48
-rw-r--r--drivers/net/wireguard/socket.c4
-rw-r--r--drivers/net/wireguard/timers.c25
14 files changed, 322 insertions, 162 deletions
diff --git a/drivers/net/wireguard/Makefile b/drivers/net/wireguard/Makefile
index dbe1f8514efc..00cbcc9ab69d 100644
--- a/drivers/net/wireguard/Makefile
+++ b/drivers/net/wireguard/Makefile
@@ -13,5 +13,5 @@ wireguard-y += peerlookup.o
wireguard-y += allowedips.o
wireguard-y += ratelimiter.o
wireguard-y += cookie.o
-wireguard-y += netlink.o
+wireguard-y += netlink.o generated/netlink.o
obj-$(CONFIG_WIREGUARD) := wireguard.o
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 4b8528206cc8..09f7fcd7da78 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -249,6 +249,52 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0;
}
+static void remove_node(struct allowedips_node *node, struct mutex *lock)
+{
+ struct allowedips_node *child, **parent_bit, *parent;
+ bool free_parent;
+
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ if (node->bit[0] && node->bit[1])
+ return;
+ child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
+ lockdep_is_held(lock));
+ if (child)
+ child->parent_bit_packed = node->parent_bit_packed;
+ parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
+ *parent_bit = child;
+ parent = (void *)parent_bit -
+ offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
+ free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) &&
+ (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer);
+ if (free_parent)
+ child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)],
+ lockdep_is_held(lock));
+ call_rcu(&node->rcu, node_free_rcu);
+ if (!free_parent)
+ return;
+ if (child)
+ child->parent_bit_packed = parent->parent_bit_packed;
+ *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
+ call_rcu(&parent->rcu, node_free_rcu);
+}
+
+static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ struct allowedips_node *node;
+
+ if (unlikely(cidr > bits))
+ return -EINVAL;
+ if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) ||
+ peer != rcu_access_pointer(node->peer))
+ return 0;
+
+ remove_node(node, lock);
+ return 0;
+}
+
void wg_allowedips_init(struct allowedips *table)
{
table->root4 = table->root6 = NULL;
@@ -300,44 +346,38 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
return add(&table->root6, 128, key, cidr, peer, lock);
}
+int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ /* Aligned so it can be passed to fls */
+ u8 key[4] __aligned(__alignof(u32));
+
+ ++table->seq;
+ swap_endian(key, (const u8 *)ip, 32);
+ return remove(&table->root4, 32, key, cidr, peer, lock);
+}
+
+int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ /* Aligned so it can be passed to fls64 */
+ u8 key[16] __aligned(__alignof(u64));
+
+ ++table->seq;
+ swap_endian(key, (const u8 *)ip, 128);
+ return remove(&table->root6, 128, key, cidr, peer, lock);
+}
+
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
- struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
- bool free_parent;
+ struct allowedips_node *node, *tmp;
if (list_empty(&peer->allowedips_list))
return;
++table->seq;
- list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
- list_del_init(&node->peer_list);
- RCU_INIT_POINTER(node->peer, NULL);
- if (node->bit[0] && node->bit[1])
- continue;
- child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
- lockdep_is_held(lock));
- if (child)
- child->parent_bit_packed = node->parent_bit_packed;
- parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
- *parent_bit = child;
- parent = (void *)parent_bit -
- offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
- free_parent = !rcu_access_pointer(node->bit[0]) &&
- !rcu_access_pointer(node->bit[1]) &&
- (node->parent_bit_packed & 3) <= 1 &&
- !rcu_access_pointer(parent->peer);
- if (free_parent)
- child = rcu_dereference_protected(
- parent->bit[!(node->parent_bit_packed & 1)],
- lockdep_is_held(lock));
- call_rcu(&node->rcu, node_free_rcu);
- if (!free_parent)
- continue;
- if (child)
- child->parent_bit_packed = parent->parent_bit_packed;
- *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
- call_rcu(&parent->rcu, node_free_rcu);
- }
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list)
+ remove_node(node, lock);
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
index 2346c797eb4d..931958cb6e10 100644
--- a/drivers/net/wireguard/allowedips.h
+++ b/drivers/net/wireguard/allowedips.h
@@ -38,6 +38,10 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
u8 cidr, struct wg_peer *peer, struct mutex *lock);
int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
u8 cidr, struct wg_peer *peer, struct mutex *lock);
+int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock);
+int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock);
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock);
/* The ip input pointer should be __aligned(__alignof(u64))) */
diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c
index f89581b5e8cb..08731b3fa32b 100644
--- a/drivers/net/wireguard/cookie.c
+++ b/drivers/net/wireguard/cookie.c
@@ -26,14 +26,14 @@ void wg_cookie_checker_init(struct cookie_checker *checker,
}
enum { COOKIE_KEY_LABEL_LEN = 8 };
-static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----";
-static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--";
+static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "mac1----";
+static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "cookie--";
static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN],
const u8 pubkey[NOISE_PUBLIC_KEY_LEN],
const u8 label[COOKIE_KEY_LABEL_LEN])
{
- struct blake2s_state blake;
+ struct blake2s_ctx blake;
blake2s_init(&blake, NOISE_SYMMETRIC_KEY_LEN);
blake2s_update(&blake, label, COOKIE_KEY_LABEL_LEN);
@@ -77,7 +77,7 @@ static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len,
{
len = len - sizeof(struct message_macs) +
offsetof(struct message_macs, mac1);
- blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN);
+ blake2s(key, NOISE_SYMMETRIC_KEY_LEN, message, len, mac1, COOKIE_LEN);
}
static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
@@ -85,13 +85,13 @@ static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
{
len = len - sizeof(struct message_macs) +
offsetof(struct message_macs, mac2);
- blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN);
+ blake2s(cookie, COOKIE_LEN, message, len, mac2, COOKIE_LEN);
}
static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
struct cookie_checker *checker)
{
- struct blake2s_state state;
+ struct blake2s_ctx blake;
if (wg_birthdate_has_expired(checker->secret_birthdate,
COOKIE_SECRET_MAX_AGE)) {
@@ -103,15 +103,15 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
down_read(&checker->secret_lock);
- blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
+ blake2s_init_key(&blake, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
if (skb->protocol == htons(ETH_P_IP))
- blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr,
+ blake2s_update(&blake, (u8 *)&ip_hdr(skb)->saddr,
sizeof(struct in_addr));
else if (skb->protocol == htons(ETH_P_IPV6))
- blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr,
+ blake2s_update(&blake, (u8 *)&ipv6_hdr(skb)->saddr,
sizeof(struct in6_addr));
- blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
- blake2s_final(&state, cookie);
+ blake2s_update(&blake, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
+ blake2s_final(&blake, cookie);
up_read(&checker->secret_lock);
}
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 6cf173a008e7..46a71ec36af8 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -81,7 +81,7 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, v
list_for_each_entry(wg, &device_list, device_list) {
mutex_lock(&wg->device_update_lock);
list_for_each_entry(peer, &wg->peer_list, peer_list) {
- del_timer(&peer->timer_zero_key_material);
+ timer_delete(&peer->timer_zero_key_material);
wg_noise_handshake_clear(&peer->handshake);
wg_noise_keypairs_clear(&peer->keypairs);
}
@@ -307,14 +307,15 @@ static void wg_setup(struct net_device *dev)
wg->dev = dev;
}
-static int wg_newlink(struct net *src_net, struct net_device *dev,
- struct nlattr *tb[], struct nlattr *data[],
+static int wg_newlink(struct net_device *dev,
+ struct rtnl_newlink_params *params,
struct netlink_ext_ack *extack)
{
+ struct net *link_net = rtnl_newlink_link_net(params);
struct wg_device *wg = netdev_priv(dev);
int ret = -ENOMEM;
- rcu_assign_pointer(wg->creating_net, src_net);
+ rcu_assign_pointer(wg->creating_net, link_net);
init_rwsem(&wg->static_identity.lock);
mutex_init(&wg->socket_update_lock);
mutex_init(&wg->device_update_lock);
@@ -332,7 +333,8 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
goto err_free_peer_hashtable;
wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s",
- WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);
+ WQ_CPU_INTENSIVE | WQ_FREEZABLE | WQ_PERCPU, 0,
+ dev->name);
if (!wg->handshake_receive_wq)
goto err_free_index_hashtable;
@@ -342,7 +344,8 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
goto err_destroy_handshake_receive;
wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s",
- WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name);
+ WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM | WQ_PERCPU, 0,
+ dev->name);
if (!wg->packet_crypt_wq)
goto err_destroy_handshake_send;
@@ -365,6 +368,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
if (ret < 0)
goto err_free_handshake_queue;
+ netif_threaded_enable(dev);
ret = register_netdevice(dev);
if (ret < 0)
goto err_uninit_ratelimiter;
diff --git a/drivers/net/wireguard/generated/netlink.c b/drivers/net/wireguard/generated/netlink.c
new file mode 100644
index 000000000000..3ef8c29908c2
--- /dev/null
+++ b/drivers/net/wireguard/generated/netlink.c
@@ -0,0 +1,73 @@
+// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/wireguard.yaml */
+/* YNL-GEN kernel source */
+/* YNL-ARG --function-prefix wg */
+/* To regenerate run: tools/net/ynl/ynl-regen.sh */
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include "netlink.h"
+
+#include <uapi/linux/wireguard.h>
+#include <linux/time_types.h>
+
+/* Common nested types */
+const struct nla_policy wireguard_wgallowedip_nl_policy[WGALLOWEDIP_A_FLAGS + 1] = {
+ [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16, },
+ [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(4),
+ [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8, },
+ [WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x1),
+};
+
+const struct nla_policy wireguard_wgpeer_nl_policy[WGPEER_A_PROTOCOL_VERSION + 1] = {
+ [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x7),
+ [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(16),
+ [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16, },
+ [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(16),
+ [WGPEER_A_RX_BYTES] = { .type = NLA_U64, },
+ [WGPEER_A_TX_BYTES] = { .type = NLA_U64, },
+ [WGPEER_A_ALLOWEDIPS] = NLA_POLICY_NESTED_ARRAY(wireguard_wgallowedip_nl_policy),
+ [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32, },
+};
+
+/* WG_CMD_GET_DEVICE - dump */
+static const struct nla_policy wireguard_get_device_nl_policy[WGDEVICE_A_IFNAME + 1] = {
+ [WGDEVICE_A_IFINDEX] = { .type = NLA_U32, },
+ [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = 15, },
+};
+
+/* WG_CMD_SET_DEVICE - do */
+static const struct nla_policy wireguard_set_device_nl_policy[WGDEVICE_A_PEERS + 1] = {
+ [WGDEVICE_A_IFINDEX] = { .type = NLA_U32, },
+ [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = 15, },
+ [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x1),
+ [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16, },
+ [WGDEVICE_A_FWMARK] = { .type = NLA_U32, },
+ [WGDEVICE_A_PEERS] = NLA_POLICY_NESTED_ARRAY(wireguard_wgpeer_nl_policy),
+};
+
+/* Ops table for wireguard */
+const struct genl_split_ops wireguard_nl_ops[2] = {
+ {
+ .cmd = WG_CMD_GET_DEVICE,
+ .start = wg_get_device_start,
+ .dumpit = wg_get_device_dumpit,
+ .done = wg_get_device_done,
+ .policy = wireguard_get_device_nl_policy,
+ .maxattr = WGDEVICE_A_IFNAME,
+ .flags = GENL_UNS_ADMIN_PERM | GENL_CMD_CAP_DUMP,
+ },
+ {
+ .cmd = WG_CMD_SET_DEVICE,
+ .doit = wg_set_device_doit,
+ .policy = wireguard_set_device_nl_policy,
+ .maxattr = WGDEVICE_A_PEERS,
+ .flags = GENL_UNS_ADMIN_PERM | GENL_CMD_CAP_DO,
+ },
+};
diff --git a/drivers/net/wireguard/generated/netlink.h b/drivers/net/wireguard/generated/netlink.h
new file mode 100644
index 000000000000..5dc977ee9e7c
--- /dev/null
+++ b/drivers/net/wireguard/generated/netlink.h
@@ -0,0 +1,30 @@
+/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/wireguard.yaml */
+/* YNL-GEN kernel header */
+/* YNL-ARG --function-prefix wg */
+/* To regenerate run: tools/net/ynl/ynl-regen.sh */
+
+#ifndef _LINUX_WIREGUARD_GEN_H
+#define _LINUX_WIREGUARD_GEN_H
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include <uapi/linux/wireguard.h>
+#include <linux/time_types.h>
+
+/* Common nested types */
+extern const struct nla_policy wireguard_wgallowedip_nl_policy[WGALLOWEDIP_A_FLAGS + 1];
+extern const struct nla_policy wireguard_wgpeer_nl_policy[WGPEER_A_PROTOCOL_VERSION + 1];
+
+/* Ops table for wireguard */
+extern const struct genl_split_ops wireguard_nl_ops[2];
+
+int wg_get_device_start(struct netlink_callback *cb);
+int wg_get_device_done(struct netlink_callback *cb);
+
+int wg_get_device_dumpit(struct sk_buff *skb, struct netlink_callback *cb);
+int wg_set_device_doit(struct sk_buff *skb, struct genl_info *info);
+
+#endif /* _LINUX_WIREGUARD_GEN_H */
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index f7055180ba4a..1da7e98d0d50 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -9,6 +9,7 @@
#include "socket.h"
#include "queueing.h"
#include "messages.h"
+#include "generated/netlink.h"
#include <uapi/linux/wireguard.h>
@@ -19,36 +20,6 @@
static struct genl_family genl_family;
-static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
- [WGDEVICE_A_IFINDEX] = { .type = NLA_U32 },
- [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
- [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
- [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
- [WGDEVICE_A_FLAGS] = { .type = NLA_U32 },
- [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 },
- [WGDEVICE_A_FWMARK] = { .type = NLA_U32 },
- [WGDEVICE_A_PEERS] = { .type = NLA_NESTED }
-};
-
-static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
- [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
- [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
- [WGPEER_A_FLAGS] = { .type = NLA_U32 },
- [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
- [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 },
- [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
- [WGPEER_A_RX_BYTES] = { .type = NLA_U64 },
- [WGPEER_A_TX_BYTES] = { .type = NLA_U64 },
- [WGPEER_A_ALLOWEDIPS] = { .type = NLA_NESTED },
- [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32 }
-};
-
-static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
- [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 },
- [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
- [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 }
-};
-
static struct wg_device *lookup_interface(struct nlattr **attrs,
struct sk_buff *skb)
{
@@ -196,7 +167,7 @@ err:
return -EMSGSIZE;
}
-static int wg_get_device_start(struct netlink_callback *cb)
+int wg_get_device_start(struct netlink_callback *cb)
{
struct wg_device *wg;
@@ -207,7 +178,7 @@ static int wg_get_device_start(struct netlink_callback *cb)
return 0;
}
-static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
+int wg_get_device_dumpit(struct sk_buff *skb, struct netlink_callback *cb)
{
struct wg_peer *peer, *next_peer_cursor;
struct dump_ctx *ctx = DUMP_CTX(cb);
@@ -301,7 +272,7 @@ out:
*/
}
-static int wg_get_device_done(struct netlink_callback *cb)
+int wg_get_device_done(struct netlink_callback *cb)
{
struct dump_ctx *ctx = DUMP_CTX(cb);
@@ -329,6 +300,7 @@ static int set_port(struct wg_device *wg, u16 port)
static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
{
int ret = -EINVAL;
+ u32 flags = 0;
u16 family;
u8 cidr;
@@ -337,19 +309,30 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
return ret;
family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
+ if (attrs[WGALLOWEDIP_A_FLAGS])
+ flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]);
if (family == AF_INET && cidr <= 32 &&
- nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
- ret = wg_allowedips_insert_v4(
- &peer->device->peer_allowedips,
- nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
- &peer->device->device_update_lock);
- else if (family == AF_INET6 && cidr <= 128 &&
- nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
- ret = wg_allowedips_insert_v6(
- &peer->device->peer_allowedips,
- nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
- &peer->device->device_update_lock);
+ nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) {
+ if (flags & WGALLOWEDIP_F_REMOVE_ME)
+ ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ else
+ ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ } else if (family == AF_INET6 && cidr <= 128 &&
+ nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) {
+ if (flags & WGALLOWEDIP_F_REMOVE_ME)
+ ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ else
+ ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ }
return ret;
}
@@ -373,9 +356,6 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
if (attrs[WGPEER_A_FLAGS])
flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
- ret = -EOPNOTSUPP;
- if (flags & ~__WGPEER_F_ALL)
- goto out;
ret = -EPFNOSUPPORT;
if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
@@ -457,7 +437,7 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
- attr, allowedip_policy, NULL);
+ attr, NULL, NULL);
if (ret < 0)
goto out;
ret = set_allowedip(peer, allowedip);
@@ -490,7 +470,7 @@ out:
return ret;
}
-static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
+int wg_set_device_doit(struct sk_buff *skb, struct genl_info *info)
{
struct wg_device *wg = lookup_interface(info->attrs, skb);
u32 flags = 0;
@@ -506,9 +486,6 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
if (info->attrs[WGDEVICE_A_FLAGS])
flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
- ret = -EOPNOTSUPP;
- if (flags & ~__WGDEVICE_F_ALL)
- goto out;
if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
struct net *net;
@@ -586,7 +563,7 @@ skip_set_private_key:
nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
- peer_policy, NULL);
+ NULL, NULL);
if (ret < 0)
goto out;
ret = set_peer(wg, peer);
@@ -607,34 +584,20 @@ out_nodev:
return ret;
}
-static const struct genl_ops genl_ops[] = {
- {
- .cmd = WG_CMD_GET_DEVICE,
- .start = wg_get_device_start,
- .dumpit = wg_get_device_dump,
- .done = wg_get_device_done,
- .flags = GENL_UNS_ADMIN_PERM
- }, {
- .cmd = WG_CMD_SET_DEVICE,
- .doit = wg_set_device,
- .flags = GENL_UNS_ADMIN_PERM
- }
-};
-
static struct genl_family genl_family __ro_after_init = {
- .ops = genl_ops,
- .n_ops = ARRAY_SIZE(genl_ops),
- .resv_start_op = WG_CMD_SET_DEVICE + 1,
+ .split_ops = wireguard_nl_ops,
+ .n_split_ops = ARRAY_SIZE(wireguard_nl_ops),
.name = WG_GENL_NAME,
.version = WG_GENL_VERSION,
- .maxattr = WGDEVICE_A_MAX,
.module = THIS_MODULE,
- .policy = device_policy,
.netnsok = true
};
int __init wg_genetlink_init(void)
{
+ BUILD_BUG_ON(WG_KEY_LEN != NOISE_PUBLIC_KEY_LEN);
+ BUILD_BUG_ON(WG_KEY_LEN != NOISE_SYMMETRIC_KEY_LEN);
+
return genl_register_family(&genl_family);
}
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index 202a33af5a72..1fe8468f0bef 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -25,18 +25,18 @@
* <- e, ee, se, psk, {}
*/
-static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
-static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
+static const u8 handshake_name[37] __nonstring = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
+static const u8 identifier_name[34] __nonstring = "WireGuard v1 zx2c4 Jason@zx2c4.com";
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
void __init wg_noise_init(void)
{
- struct blake2s_state blake;
+ struct blake2s_ctx blake;
- blake2s(handshake_init_chaining_key, handshake_name, NULL,
- NOISE_HASH_LEN, sizeof(handshake_name), 0);
+ blake2s(NULL, 0, handshake_name, sizeof(handshake_name),
+ handshake_init_chaining_key, NOISE_HASH_LEN);
blake2s_init(&blake, NOISE_HASH_LEN);
blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
blake2s_update(&blake, identifier_name, sizeof(identifier_name));
@@ -304,33 +304,33 @@ void wg_noise_set_static_identity_private_key(
static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
{
- struct blake2s_state state;
+ struct blake2s_ctx blake;
u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
int i;
if (keylen > BLAKE2S_BLOCK_SIZE) {
- blake2s_init(&state, BLAKE2S_HASH_SIZE);
- blake2s_update(&state, key, keylen);
- blake2s_final(&state, x_key);
+ blake2s_init(&blake, BLAKE2S_HASH_SIZE);
+ blake2s_update(&blake, key, keylen);
+ blake2s_final(&blake, x_key);
} else
memcpy(x_key, key, keylen);
for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
x_key[i] ^= 0x36;
- blake2s_init(&state, BLAKE2S_HASH_SIZE);
- blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
- blake2s_update(&state, in, inlen);
- blake2s_final(&state, i_hash);
+ blake2s_init(&blake, BLAKE2S_HASH_SIZE);
+ blake2s_update(&blake, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&blake, in, inlen);
+ blake2s_final(&blake, i_hash);
for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
x_key[i] ^= 0x5c ^ 0x36;
- blake2s_init(&state, BLAKE2S_HASH_SIZE);
- blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
- blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
- blake2s_final(&state, i_hash);
+ blake2s_init(&blake, BLAKE2S_HASH_SIZE);
+ blake2s_update(&blake, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&blake, i_hash, BLAKE2S_HASH_SIZE);
+ blake2s_final(&blake, i_hash);
memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
@@ -431,7 +431,7 @@ static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
{
- struct blake2s_state blake;
+ struct blake2s_ctx blake;
blake2s_init(&blake, NOISE_HASH_LEN);
blake2s_update(&blake, hash, NOISE_HASH_LEN);
diff --git a/drivers/net/wireguard/peer.h b/drivers/net/wireguard/peer.h
index 76e4d3128ad4..718fb42bdac7 100644
--- a/drivers/net/wireguard/peer.h
+++ b/drivers/net/wireguard/peer.h
@@ -20,7 +20,7 @@ struct wg_device;
struct endpoint {
union {
- struct sockaddr addr;
+ struct sockaddr_inet addr; /* Large enough for both address families */
struct sockaddr_in addr4;
struct sockaddr_in6 addr6;
};
diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h
index 7eb76724b3ed..79b6d70de236 100644
--- a/drivers/net/wireguard/queueing.h
+++ b/drivers/net/wireguard/queueing.h
@@ -104,16 +104,11 @@ static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating)
static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id)
{
- unsigned int cpu = *stored_cpu, cpu_index, i;
+ unsigned int cpu = *stored_cpu;
+
+ while (unlikely(cpu >= nr_cpu_ids || !cpu_online(cpu)))
+ cpu = *stored_cpu = cpumask_nth(id % num_online_cpus(), cpu_online_mask);
- if (unlikely(cpu >= nr_cpu_ids ||
- !cpumask_test_cpu(cpu, cpu_online_mask))) {
- cpu_index = id % cpumask_weight(cpu_online_mask);
- cpu = cpumask_first(cpu_online_mask);
- for (i = 0; i < cpu_index; ++i)
- cpu = cpumask_next(cpu, cpu_online_mask);
- *stored_cpu = cpu;
- }
return cpu;
}
diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c
index 25de7058701a..41837efa70cb 100644
--- a/drivers/net/wireguard/selftest/allowedips.c
+++ b/drivers/net/wireguard/selftest/allowedips.c
@@ -460,6 +460,10 @@ static __init struct wg_peer *init_peer(void)
wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
cidr, mem, &mutex)
+#define remove(version, mem, ipa, ipb, ipc, ipd, cidr) \
+ wg_allowedips_remove_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
+ cidr, mem, &mutex)
+
#define maybe_fail() do { \
++i; \
if (!_s) { \
@@ -585,6 +589,50 @@ bool __init wg_allowedips_selftest(void)
test_negative(4, a, 192, 0, 0, 0);
test_negative(4, a, 255, 0, 0, 0);
+ insert(4, a, 1, 0, 0, 0, 32);
+ insert(4, a, 192, 0, 0, 0, 24);
+ insert(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ insert(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ test(4, a, 1, 0, 0, 0);
+ test(4, a, 192, 0, 0, 1);
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010);
+ /* Must be an exact match to remove */
+ remove(4, a, 192, 0, 0, 0, 32);
+ test(4, a, 192, 0, 0, 1);
+ /* NULL peer should have no effect and return 0 */
+ test_boolean(!remove(4, NULL, 192, 0, 0, 0, 24));
+ test(4, a, 192, 0, 0, 1);
+ /* different peer should have no effect and return 0 */
+ test_boolean(!remove(4, b, 192, 0, 0, 0, 24));
+ test(4, a, 192, 0, 0, 1);
+ /* invalid CIDR should have no effect and return -EINVAL */
+ test_boolean(remove(4, b, 192, 0, 0, 0, 33) == -EINVAL);
+ test(4, a, 192, 0, 0, 1);
+ remove(4, a, 192, 0, 0, 0, 24);
+ test_negative(4, a, 192, 0, 0, 1);
+ remove(4, a, 1, 0, 0, 0, 32);
+ test_negative(4, a, 1, 0, 0, 0);
+ /* Must be an exact match to remove */
+ remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96);
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* NULL peer should have no effect and return 0 */
+ test_boolean(!remove(6, NULL, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128));
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* different peer should have no effect and return 0 */
+ test_boolean(!remove(6, b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128));
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* invalid CIDR should have no effect and return -EINVAL */
+ test_boolean(remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 129) == -EINVAL);
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ test_negative(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* Must match the peer to remove */
+ remove(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010);
+ remove(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ test_negative(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010);
+
wg_allowedips_free(&t, &mutex);
wg_allowedips_init(&t);
insert(4, a, 192, 168, 0, 0, 16);
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index 0414d7a6ce74..253488f8c00f 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -84,7 +84,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
skb->ignore_df = 1;
udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport,
- fl.fl4_dport, false, false);
+ fl.fl4_dport, false, false, 0);
goto out;
err:
@@ -151,7 +151,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
skb->ignore_df = 1;
udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
ip6_dst_hoplimit(dst), 0, fl.fl6_sport,
- fl.fl6_dport, false);
+ fl.fl6_dport, false, 0);
goto out;
err:
diff --git a/drivers/net/wireguard/timers.c b/drivers/net/wireguard/timers.c
index 968bdb4df0b3..4016a3065602 100644
--- a/drivers/net/wireguard/timers.c
+++ b/drivers/net/wireguard/timers.c
@@ -40,15 +40,15 @@ static inline void mod_peer_timer(struct wg_peer *peer,
static void wg_expired_retransmit_handshake(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer,
- timer_retransmit_handshake);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_retransmit_handshake);
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) {
pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n",
peer->device->dev->name, peer->internal_id,
&peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2);
- del_timer(&peer->timer_send_keepalive);
+ timer_delete(&peer->timer_send_keepalive);
/* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
@@ -78,7 +78,8 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer)
static void wg_expired_send_keepalive(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer, timer_send_keepalive);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_send_keepalive);
wg_packet_send_keepalive(peer);
if (peer->timer_need_another_keepalive) {
@@ -90,7 +91,8 @@ static void wg_expired_send_keepalive(struct timer_list *timer)
static void wg_expired_new_handshake(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer, timer_new_handshake);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_new_handshake);
pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n",
peer->device->dev->name, peer->internal_id,
@@ -104,7 +106,8 @@ static void wg_expired_new_handshake(struct timer_list *timer)
static void wg_expired_zero_key_material(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer, timer_zero_key_material);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_zero_key_material);
rcu_read_lock_bh();
if (!READ_ONCE(peer->is_dead)) {
@@ -134,8 +137,8 @@ static void wg_queued_expired_zero_key_material(struct work_struct *work)
static void wg_expired_send_persistent_keepalive(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer,
- timer_persistent_keepalive);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_persistent_keepalive);
if (likely(peer->persistent_keepalive_interval))
wg_packet_send_keepalive(peer);
@@ -167,7 +170,7 @@ void wg_timers_data_received(struct wg_peer *peer)
*/
void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer)
{
- del_timer(&peer->timer_send_keepalive);
+ timer_delete(&peer->timer_send_keepalive);
}
/* Should be called after any type of authenticated packet is received, whether
@@ -175,7 +178,7 @@ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer)
*/
void wg_timers_any_authenticated_packet_received(struct wg_peer *peer)
{
- del_timer(&peer->timer_new_handshake);
+ timer_delete(&peer->timer_new_handshake);
}
/* Should be called after a handshake initiation message is sent. */
@@ -191,7 +194,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer)
*/
void wg_timers_handshake_complete(struct wg_peer *peer)
{
- del_timer(&peer->timer_retransmit_handshake);
+ timer_delete(&peer->timer_retransmit_handshake);
peer->timer_handshake_attempts = 0;
peer->sent_lastminute_handshake = false;
ktime_get_real_ts64(&peer->walltime_last_handshake);