summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r--drivers/net/wireguard/Makefile2
-rw-r--r--drivers/net/wireguard/allowedips.c106
-rw-r--r--drivers/net/wireguard/allowedips.h4
-rw-r--r--drivers/net/wireguard/cookie.c24
-rw-r--r--drivers/net/wireguard/device.c36
-rw-r--r--drivers/net/wireguard/generated/netlink.c73
-rw-r--r--drivers/net/wireguard/generated/netlink.h30
-rw-r--r--drivers/net/wireguard/main.c2
-rw-r--r--drivers/net/wireguard/netlink.c119
-rw-r--r--drivers/net/wireguard/noise.c38
-rw-r--r--drivers/net/wireguard/peer.h2
-rw-r--r--drivers/net/wireguard/queueing.h17
-rw-r--r--drivers/net/wireguard/receive.c20
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c49
-rw-r--r--drivers/net/wireguard/send.c5
-rw-r--r--drivers/net/wireguard/socket.c4
-rw-r--r--drivers/net/wireguard/timers.c25
17 files changed, 355 insertions, 201 deletions
diff --git a/drivers/net/wireguard/Makefile b/drivers/net/wireguard/Makefile
index dbe1f8514efc..00cbcc9ab69d 100644
--- a/drivers/net/wireguard/Makefile
+++ b/drivers/net/wireguard/Makefile
@@ -13,5 +13,5 @@ wireguard-y += peerlookup.o
wireguard-y += allowedips.o
wireguard-y += ratelimiter.o
wireguard-y += cookie.o
-wireguard-y += netlink.o
+wireguard-y += netlink.o generated/netlink.o
obj-$(CONFIG_WIREGUARD) := wireguard.o
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 0ba714ca5185..09f7fcd7da78 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -15,8 +15,8 @@ static void swap_endian(u8 *dst, const u8 *src, u8 bits)
if (bits == 32) {
*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
} else if (bits == 128) {
- ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
- ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
+ ((u64 *)dst)[0] = get_unaligned_be64(src);
+ ((u64 *)dst)[1] = get_unaligned_be64(src + 8);
}
}
@@ -249,6 +249,52 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0;
}
+static void remove_node(struct allowedips_node *node, struct mutex *lock)
+{
+ struct allowedips_node *child, **parent_bit, *parent;
+ bool free_parent;
+
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ if (node->bit[0] && node->bit[1])
+ return;
+ child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
+ lockdep_is_held(lock));
+ if (child)
+ child->parent_bit_packed = node->parent_bit_packed;
+ parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
+ *parent_bit = child;
+ parent = (void *)parent_bit -
+ offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
+ free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) &&
+ (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer);
+ if (free_parent)
+ child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)],
+ lockdep_is_held(lock));
+ call_rcu(&node->rcu, node_free_rcu);
+ if (!free_parent)
+ return;
+ if (child)
+ child->parent_bit_packed = parent->parent_bit_packed;
+ *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
+ call_rcu(&parent->rcu, node_free_rcu);
+}
+
+static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ struct allowedips_node *node;
+
+ if (unlikely(cidr > bits))
+ return -EINVAL;
+ if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) ||
+ peer != rcu_access_pointer(node->peer))
+ return 0;
+
+ remove_node(node, lock);
+ return 0;
+}
+
void wg_allowedips_init(struct allowedips *table)
{
table->root4 = table->root6 = NULL;
@@ -300,44 +346,38 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
return add(&table->root6, 128, key, cidr, peer, lock);
}
+int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ /* Aligned so it can be passed to fls */
+ u8 key[4] __aligned(__alignof(u32));
+
+ ++table->seq;
+ swap_endian(key, (const u8 *)ip, 32);
+ return remove(&table->root4, 32, key, cidr, peer, lock);
+}
+
+int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ /* Aligned so it can be passed to fls64 */
+ u8 key[16] __aligned(__alignof(u64));
+
+ ++table->seq;
+ swap_endian(key, (const u8 *)ip, 128);
+ return remove(&table->root6, 128, key, cidr, peer, lock);
+}
+
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
- struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
- bool free_parent;
+ struct allowedips_node *node, *tmp;
if (list_empty(&peer->allowedips_list))
return;
++table->seq;
- list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
- list_del_init(&node->peer_list);
- RCU_INIT_POINTER(node->peer, NULL);
- if (node->bit[0] && node->bit[1])
- continue;
- child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
- lockdep_is_held(lock));
- if (child)
- child->parent_bit_packed = node->parent_bit_packed;
- parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
- *parent_bit = child;
- parent = (void *)parent_bit -
- offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
- free_parent = !rcu_access_pointer(node->bit[0]) &&
- !rcu_access_pointer(node->bit[1]) &&
- (node->parent_bit_packed & 3) <= 1 &&
- !rcu_access_pointer(parent->peer);
- if (free_parent)
- child = rcu_dereference_protected(
- parent->bit[!(node->parent_bit_packed & 1)],
- lockdep_is_held(lock));
- call_rcu(&node->rcu, node_free_rcu);
- if (!free_parent)
- continue;
- if (child)
- child->parent_bit_packed = parent->parent_bit_packed;
- *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
- call_rcu(&parent->rcu, node_free_rcu);
- }
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list)
+ remove_node(node, lock);
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
index 2346c797eb4d..931958cb6e10 100644
--- a/drivers/net/wireguard/allowedips.h
+++ b/drivers/net/wireguard/allowedips.h
@@ -38,6 +38,10 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
u8 cidr, struct wg_peer *peer, struct mutex *lock);
int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
u8 cidr, struct wg_peer *peer, struct mutex *lock);
+int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock);
+int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock);
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock);
/* The ip input pointer should be __aligned(__alignof(u64))) */
diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c
index 4956f0499c19..08731b3fa32b 100644
--- a/drivers/net/wireguard/cookie.c
+++ b/drivers/net/wireguard/cookie.c
@@ -12,9 +12,9 @@
#include <crypto/blake2s.h>
#include <crypto/chacha20poly1305.h>
+#include <crypto/utils.h>
#include <net/ipv6.h>
-#include <crypto/algapi.h>
void wg_cookie_checker_init(struct cookie_checker *checker,
struct wg_device *wg)
@@ -26,14 +26,14 @@ void wg_cookie_checker_init(struct cookie_checker *checker,
}
enum { COOKIE_KEY_LABEL_LEN = 8 };
-static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----";
-static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--";
+static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "mac1----";
+static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "cookie--";
static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN],
const u8 pubkey[NOISE_PUBLIC_KEY_LEN],
const u8 label[COOKIE_KEY_LABEL_LEN])
{
- struct blake2s_state blake;
+ struct blake2s_ctx blake;
blake2s_init(&blake, NOISE_SYMMETRIC_KEY_LEN);
blake2s_update(&blake, label, COOKIE_KEY_LABEL_LEN);
@@ -77,7 +77,7 @@ static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len,
{
len = len - sizeof(struct message_macs) +
offsetof(struct message_macs, mac1);
- blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN);
+ blake2s(key, NOISE_SYMMETRIC_KEY_LEN, message, len, mac1, COOKIE_LEN);
}
static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
@@ -85,13 +85,13 @@ static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
{
len = len - sizeof(struct message_macs) +
offsetof(struct message_macs, mac2);
- blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN);
+ blake2s(cookie, COOKIE_LEN, message, len, mac2, COOKIE_LEN);
}
static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
struct cookie_checker *checker)
{
- struct blake2s_state state;
+ struct blake2s_ctx blake;
if (wg_birthdate_has_expired(checker->secret_birthdate,
COOKIE_SECRET_MAX_AGE)) {
@@ -103,15 +103,15 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
down_read(&checker->secret_lock);
- blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
+ blake2s_init_key(&blake, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
if (skb->protocol == htons(ETH_P_IP))
- blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr,
+ blake2s_update(&blake, (u8 *)&ip_hdr(skb)->saddr,
sizeof(struct in_addr));
else if (skb->protocol == htons(ETH_P_IPV6))
- blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr,
+ blake2s_update(&blake, (u8 *)&ipv6_hdr(skb)->saddr,
sizeof(struct in6_addr));
- blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
- blake2s_final(&state, cookie);
+ blake2s_update(&blake, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
+ blake2s_final(&blake, cookie);
up_read(&checker->secret_lock);
}
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 258dcc103921..46a71ec36af8 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -81,7 +81,7 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, v
list_for_each_entry(wg, &device_list, device_list) {
mutex_lock(&wg->device_update_lock);
list_for_each_entry(peer, &wg->peer_list, peer_list) {
- del_timer(&peer->timer_zero_key_material);
+ timer_delete(&peer->timer_zero_key_material);
wg_noise_handshake_clear(&peer->handshake);
wg_noise_keypairs_clear(&peer->keypairs);
}
@@ -210,7 +210,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)
*/
while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS) {
dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue));
- ++dev->stats.tx_dropped;
+ DEV_STATS_INC(dev, tx_dropped);
}
skb_queue_splice_tail(&packets, &peer->staged_packet_queue);
spin_unlock_bh(&peer->staged_packet_queue.lock);
@@ -228,7 +228,7 @@ err_icmp:
else if (skb->protocol == htons(ETH_P_IPV6))
icmpv6_ndo_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
err:
- ++dev->stats.tx_errors;
+ DEV_STATS_INC(dev, tx_errors);
kfree_skb(skb);
return ret;
}
@@ -237,7 +237,6 @@ static const struct net_device_ops netdev_ops = {
.ndo_open = wg_open,
.ndo_stop = wg_stop,
.ndo_start_xmit = wg_xmit,
- .ndo_get_stats64 = dev_get_tstats64
};
static void wg_destruct(struct net_device *dev)
@@ -262,7 +261,6 @@ static void wg_destruct(struct net_device *dev)
rcu_barrier(); /* Wait for all the peers to be actually freed. */
wg_ratelimiter_uninit();
memzero_explicit(&wg->static_identity, sizeof(wg->static_identity));
- free_percpu(dev->tstats);
kvfree(wg->index_hashtable);
kvfree(wg->peer_hashtable);
mutex_unlock(&wg->device_update_lock);
@@ -291,30 +289,33 @@ static void wg_setup(struct net_device *dev)
dev->type = ARPHRD_NONE;
dev->flags = IFF_POINTOPOINT | IFF_NOARP;
dev->priv_flags |= IFF_NO_QUEUE;
- dev->features |= NETIF_F_LLTX;
+ dev->lltx = true;
dev->features |= WG_NETDEV_FEATURES;
dev->hw_features |= WG_NETDEV_FEATURES;
dev->hw_enc_features |= WG_NETDEV_FEATURES;
dev->mtu = ETH_DATA_LEN - overhead;
dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead;
+ dev->pcpu_stat_type = NETDEV_PCPU_STAT_TSTATS;
SET_NETDEV_DEVTYPE(dev, &device_type);
/* We need to keep the dst around in case of icmp replies. */
netif_keep_dst(dev);
- memset(wg, 0, sizeof(*wg));
+ netif_set_tso_max_size(dev, GSO_MAX_SIZE);
+
wg->dev = dev;
}
-static int wg_newlink(struct net *src_net, struct net_device *dev,
- struct nlattr *tb[], struct nlattr *data[],
+static int wg_newlink(struct net_device *dev,
+ struct rtnl_newlink_params *params,
struct netlink_ext_ack *extack)
{
+ struct net *link_net = rtnl_newlink_link_net(params);
struct wg_device *wg = netdev_priv(dev);
int ret = -ENOMEM;
- rcu_assign_pointer(wg->creating_net, src_net);
+ rcu_assign_pointer(wg->creating_net, link_net);
init_rwsem(&wg->static_identity.lock);
mutex_init(&wg->socket_update_lock);
mutex_init(&wg->device_update_lock);
@@ -331,14 +332,11 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
if (!wg->index_hashtable)
goto err_free_peer_hashtable;
- dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
- if (!dev->tstats)
- goto err_free_index_hashtable;
-
wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s",
- WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);
+ WQ_CPU_INTENSIVE | WQ_FREEZABLE | WQ_PERCPU, 0,
+ dev->name);
if (!wg->handshake_receive_wq)
- goto err_free_tstats;
+ goto err_free_index_hashtable;
wg->handshake_send_wq = alloc_workqueue("wg-kex-%s",
WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
@@ -346,7 +344,8 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
goto err_destroy_handshake_receive;
wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s",
- WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name);
+ WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM | WQ_PERCPU, 0,
+ dev->name);
if (!wg->packet_crypt_wq)
goto err_destroy_handshake_send;
@@ -369,6 +368,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
if (ret < 0)
goto err_free_handshake_queue;
+ netif_threaded_enable(dev);
ret = register_netdevice(dev);
if (ret < 0)
goto err_uninit_ratelimiter;
@@ -397,8 +397,6 @@ err_destroy_handshake_send:
destroy_workqueue(wg->handshake_send_wq);
err_destroy_handshake_receive:
destroy_workqueue(wg->handshake_receive_wq);
-err_free_tstats:
- free_percpu(dev->tstats);
err_free_index_hashtable:
kvfree(wg->index_hashtable);
err_free_peer_hashtable:
diff --git a/drivers/net/wireguard/generated/netlink.c b/drivers/net/wireguard/generated/netlink.c
new file mode 100644
index 000000000000..3ef8c29908c2
--- /dev/null
+++ b/drivers/net/wireguard/generated/netlink.c
@@ -0,0 +1,73 @@
+// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/wireguard.yaml */
+/* YNL-GEN kernel source */
+/* YNL-ARG --function-prefix wg */
+/* To regenerate run: tools/net/ynl/ynl-regen.sh */
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include "netlink.h"
+
+#include <uapi/linux/wireguard.h>
+#include <linux/time_types.h>
+
+/* Common nested types */
+const struct nla_policy wireguard_wgallowedip_nl_policy[WGALLOWEDIP_A_FLAGS + 1] = {
+ [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16, },
+ [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(4),
+ [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8, },
+ [WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x1),
+};
+
+const struct nla_policy wireguard_wgpeer_nl_policy[WGPEER_A_PROTOCOL_VERSION + 1] = {
+ [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x7),
+ [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(16),
+ [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16, },
+ [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(16),
+ [WGPEER_A_RX_BYTES] = { .type = NLA_U64, },
+ [WGPEER_A_TX_BYTES] = { .type = NLA_U64, },
+ [WGPEER_A_ALLOWEDIPS] = NLA_POLICY_NESTED_ARRAY(wireguard_wgallowedip_nl_policy),
+ [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32, },
+};
+
+/* WG_CMD_GET_DEVICE - dump */
+static const struct nla_policy wireguard_get_device_nl_policy[WGDEVICE_A_IFNAME + 1] = {
+ [WGDEVICE_A_IFINDEX] = { .type = NLA_U32, },
+ [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = 15, },
+};
+
+/* WG_CMD_SET_DEVICE - do */
+static const struct nla_policy wireguard_set_device_nl_policy[WGDEVICE_A_PEERS + 1] = {
+ [WGDEVICE_A_IFINDEX] = { .type = NLA_U32, },
+ [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = 15, },
+ [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN),
+ [WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x1),
+ [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16, },
+ [WGDEVICE_A_FWMARK] = { .type = NLA_U32, },
+ [WGDEVICE_A_PEERS] = NLA_POLICY_NESTED_ARRAY(wireguard_wgpeer_nl_policy),
+};
+
+/* Ops table for wireguard */
+const struct genl_split_ops wireguard_nl_ops[2] = {
+ {
+ .cmd = WG_CMD_GET_DEVICE,
+ .start = wg_get_device_start,
+ .dumpit = wg_get_device_dumpit,
+ .done = wg_get_device_done,
+ .policy = wireguard_get_device_nl_policy,
+ .maxattr = WGDEVICE_A_IFNAME,
+ .flags = GENL_UNS_ADMIN_PERM | GENL_CMD_CAP_DUMP,
+ },
+ {
+ .cmd = WG_CMD_SET_DEVICE,
+ .doit = wg_set_device_doit,
+ .policy = wireguard_set_device_nl_policy,
+ .maxattr = WGDEVICE_A_PEERS,
+ .flags = GENL_UNS_ADMIN_PERM | GENL_CMD_CAP_DO,
+ },
+};
diff --git a/drivers/net/wireguard/generated/netlink.h b/drivers/net/wireguard/generated/netlink.h
new file mode 100644
index 000000000000..5dc977ee9e7c
--- /dev/null
+++ b/drivers/net/wireguard/generated/netlink.h
@@ -0,0 +1,30 @@
+/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/wireguard.yaml */
+/* YNL-GEN kernel header */
+/* YNL-ARG --function-prefix wg */
+/* To regenerate run: tools/net/ynl/ynl-regen.sh */
+
+#ifndef _LINUX_WIREGUARD_GEN_H
+#define _LINUX_WIREGUARD_GEN_H
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include <uapi/linux/wireguard.h>
+#include <linux/time_types.h>
+
+/* Common nested types */
+extern const struct nla_policy wireguard_wgallowedip_nl_policy[WGALLOWEDIP_A_FLAGS + 1];
+extern const struct nla_policy wireguard_wgpeer_nl_policy[WGPEER_A_PROTOCOL_VERSION + 1];
+
+/* Ops table for wireguard */
+extern const struct genl_split_ops wireguard_nl_ops[2];
+
+int wg_get_device_start(struct netlink_callback *cb);
+int wg_get_device_done(struct netlink_callback *cb);
+
+int wg_get_device_dumpit(struct sk_buff *skb, struct netlink_callback *cb);
+int wg_set_device_doit(struct sk_buff *skb, struct genl_info *info);
+
+#endif /* _LINUX_WIREGUARD_GEN_H */
diff --git a/drivers/net/wireguard/main.c b/drivers/net/wireguard/main.c
index ee4da9ab8013..a00671b58701 100644
--- a/drivers/net/wireguard/main.c
+++ b/drivers/net/wireguard/main.c
@@ -14,7 +14,7 @@
#include <linux/init.h>
#include <linux/module.h>
-#include <linux/genetlink.h>
+#include <net/genetlink.h>
#include <net/rtnetlink.h>
static int __init wg_mod_init(void)
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index dc09b75a3248..1da7e98d0d50 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -9,46 +9,17 @@
#include "socket.h"
#include "queueing.h"
#include "messages.h"
+#include "generated/netlink.h"
#include <uapi/linux/wireguard.h>
#include <linux/if.h>
#include <net/genetlink.h>
#include <net/sock.h>
-#include <crypto/algapi.h>
+#include <crypto/utils.h>
static struct genl_family genl_family;
-static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
- [WGDEVICE_A_IFINDEX] = { .type = NLA_U32 },
- [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
- [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
- [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
- [WGDEVICE_A_FLAGS] = { .type = NLA_U32 },
- [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 },
- [WGDEVICE_A_FWMARK] = { .type = NLA_U32 },
- [WGDEVICE_A_PEERS] = { .type = NLA_NESTED }
-};
-
-static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
- [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
- [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
- [WGPEER_A_FLAGS] = { .type = NLA_U32 },
- [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
- [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 },
- [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
- [WGPEER_A_RX_BYTES] = { .type = NLA_U64 },
- [WGPEER_A_TX_BYTES] = { .type = NLA_U64 },
- [WGPEER_A_ALLOWEDIPS] = { .type = NLA_NESTED },
- [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32 }
-};
-
-static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
- [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 },
- [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
- [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 }
-};
-
static struct wg_device *lookup_interface(struct nlattr **attrs,
struct sk_buff *skb)
{
@@ -164,8 +135,8 @@ get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
if (!allowedips_node)
goto no_allowedips;
if (!ctx->allowedips_seq)
- ctx->allowedips_seq = peer->device->peer_allowedips.seq;
- else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq)
+ ctx->allowedips_seq = ctx->wg->peer_allowedips.seq;
+ else if (ctx->allowedips_seq != ctx->wg->peer_allowedips.seq)
goto no_allowedips;
allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
@@ -196,7 +167,7 @@ err:
return -EMSGSIZE;
}
-static int wg_get_device_start(struct netlink_callback *cb)
+int wg_get_device_start(struct netlink_callback *cb)
{
struct wg_device *wg;
@@ -207,7 +178,7 @@ static int wg_get_device_start(struct netlink_callback *cb)
return 0;
}
-static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
+int wg_get_device_dumpit(struct sk_buff *skb, struct netlink_callback *cb)
{
struct wg_peer *peer, *next_peer_cursor;
struct dump_ctx *ctx = DUMP_CTX(cb);
@@ -255,17 +226,17 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
if (!peers_nest)
goto out;
ret = 0;
- /* If the last cursor was removed via list_del_init in peer_remove, then
+ lockdep_assert_held(&wg->device_update_lock);
+ /* If the last cursor was removed in peer_remove or peer_remove_all, then
* we just treat this the same as there being no more peers left. The
* reason is that seq_nr should indicate to userspace that this isn't a
* coherent dump anyway, so they'll try again.
*/
if (list_empty(&wg->peer_list) ||
- (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) {
+ (ctx->next_peer && ctx->next_peer->is_dead)) {
nla_nest_cancel(skb, peers_nest);
goto out;
}
- lockdep_assert_held(&wg->device_update_lock);
peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list);
list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
if (get_peer(peer, skb, ctx)) {
@@ -301,7 +272,7 @@ out:
*/
}
-static int wg_get_device_done(struct netlink_callback *cb)
+int wg_get_device_done(struct netlink_callback *cb)
{
struct dump_ctx *ctx = DUMP_CTX(cb);
@@ -329,6 +300,7 @@ static int set_port(struct wg_device *wg, u16 port)
static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
{
int ret = -EINVAL;
+ u32 flags = 0;
u16 family;
u8 cidr;
@@ -337,19 +309,30 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
return ret;
family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
+ if (attrs[WGALLOWEDIP_A_FLAGS])
+ flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]);
if (family == AF_INET && cidr <= 32 &&
- nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
- ret = wg_allowedips_insert_v4(
- &peer->device->peer_allowedips,
- nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
- &peer->device->device_update_lock);
- else if (family == AF_INET6 && cidr <= 128 &&
- nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
- ret = wg_allowedips_insert_v6(
- &peer->device->peer_allowedips,
- nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
- &peer->device->device_update_lock);
+ nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) {
+ if (flags & WGALLOWEDIP_F_REMOVE_ME)
+ ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ else
+ ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ } else if (family == AF_INET6 && cidr <= 128 &&
+ nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) {
+ if (flags & WGALLOWEDIP_F_REMOVE_ME)
+ ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ else
+ ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr,
+ peer, &peer->device->device_update_lock);
+ }
return ret;
}
@@ -373,9 +356,6 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
if (attrs[WGPEER_A_FLAGS])
flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
- ret = -EOPNOTSUPP;
- if (flags & ~__WGPEER_F_ALL)
- goto out;
ret = -EPFNOSUPPORT;
if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
@@ -457,7 +437,7 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
- attr, allowedip_policy, NULL);
+ attr, NULL, NULL);
if (ret < 0)
goto out;
ret = set_allowedip(peer, allowedip);
@@ -490,7 +470,7 @@ out:
return ret;
}
-static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
+int wg_set_device_doit(struct sk_buff *skb, struct genl_info *info)
{
struct wg_device *wg = lookup_interface(info->attrs, skb);
u32 flags = 0;
@@ -506,9 +486,6 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
if (info->attrs[WGDEVICE_A_FLAGS])
flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
- ret = -EOPNOTSUPP;
- if (flags & ~__WGDEVICE_F_ALL)
- goto out;
if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
struct net *net;
@@ -586,7 +563,7 @@ skip_set_private_key:
nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
- peer_policy, NULL);
+ NULL, NULL);
if (ret < 0)
goto out;
ret = set_peer(wg, peer);
@@ -607,34 +584,20 @@ out_nodev:
return ret;
}
-static const struct genl_ops genl_ops[] = {
- {
- .cmd = WG_CMD_GET_DEVICE,
- .start = wg_get_device_start,
- .dumpit = wg_get_device_dump,
- .done = wg_get_device_done,
- .flags = GENL_UNS_ADMIN_PERM
- }, {
- .cmd = WG_CMD_SET_DEVICE,
- .doit = wg_set_device,
- .flags = GENL_UNS_ADMIN_PERM
- }
-};
-
static struct genl_family genl_family __ro_after_init = {
- .ops = genl_ops,
- .n_ops = ARRAY_SIZE(genl_ops),
- .resv_start_op = WG_CMD_SET_DEVICE + 1,
+ .split_ops = wireguard_nl_ops,
+ .n_split_ops = ARRAY_SIZE(wireguard_nl_ops),
.name = WG_GENL_NAME,
.version = WG_GENL_VERSION,
- .maxattr = WGDEVICE_A_MAX,
.module = THIS_MODULE,
- .policy = device_policy,
.netnsok = true
};
int __init wg_genetlink_init(void)
{
+ BUILD_BUG_ON(WG_KEY_LEN != NOISE_PUBLIC_KEY_LEN);
+ BUILD_BUG_ON(WG_KEY_LEN != NOISE_SYMMETRIC_KEY_LEN);
+
return genl_register_family(&genl_family);
}
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index 720952b92e78..1fe8468f0bef 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -15,7 +15,7 @@
#include <linux/bitmap.h>
#include <linux/scatterlist.h>
#include <linux/highmem.h>
-#include <crypto/algapi.h>
+#include <crypto/utils.h>
/* This implements Noise_IKpsk2:
*
@@ -25,18 +25,18 @@
* <- e, ee, se, psk, {}
*/
-static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
-static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
+static const u8 handshake_name[37] __nonstring = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
+static const u8 identifier_name[34] __nonstring = "WireGuard v1 zx2c4 Jason@zx2c4.com";
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
void __init wg_noise_init(void)
{
- struct blake2s_state blake;
+ struct blake2s_ctx blake;
- blake2s(handshake_init_chaining_key, handshake_name, NULL,
- NOISE_HASH_LEN, sizeof(handshake_name), 0);
+ blake2s(NULL, 0, handshake_name, sizeof(handshake_name),
+ handshake_init_chaining_key, NOISE_HASH_LEN);
blake2s_init(&blake, NOISE_HASH_LEN);
blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
blake2s_update(&blake, identifier_name, sizeof(identifier_name));
@@ -304,33 +304,33 @@ void wg_noise_set_static_identity_private_key(
static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
{
- struct blake2s_state state;
+ struct blake2s_ctx blake;
u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
int i;
if (keylen > BLAKE2S_BLOCK_SIZE) {
- blake2s_init(&state, BLAKE2S_HASH_SIZE);
- blake2s_update(&state, key, keylen);
- blake2s_final(&state, x_key);
+ blake2s_init(&blake, BLAKE2S_HASH_SIZE);
+ blake2s_update(&blake, key, keylen);
+ blake2s_final(&blake, x_key);
} else
memcpy(x_key, key, keylen);
for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
x_key[i] ^= 0x36;
- blake2s_init(&state, BLAKE2S_HASH_SIZE);
- blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
- blake2s_update(&state, in, inlen);
- blake2s_final(&state, i_hash);
+ blake2s_init(&blake, BLAKE2S_HASH_SIZE);
+ blake2s_update(&blake, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&blake, in, inlen);
+ blake2s_final(&blake, i_hash);
for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
x_key[i] ^= 0x5c ^ 0x36;
- blake2s_init(&state, BLAKE2S_HASH_SIZE);
- blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
- blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
- blake2s_final(&state, i_hash);
+ blake2s_init(&blake, BLAKE2S_HASH_SIZE);
+ blake2s_update(&blake, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&blake, i_hash, BLAKE2S_HASH_SIZE);
+ blake2s_final(&blake, i_hash);
memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
@@ -431,7 +431,7 @@ static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
{
- struct blake2s_state blake;
+ struct blake2s_ctx blake;
blake2s_init(&blake, NOISE_HASH_LEN);
blake2s_update(&blake, hash, NOISE_HASH_LEN);
diff --git a/drivers/net/wireguard/peer.h b/drivers/net/wireguard/peer.h
index 76e4d3128ad4..718fb42bdac7 100644
--- a/drivers/net/wireguard/peer.h
+++ b/drivers/net/wireguard/peer.h
@@ -20,7 +20,7 @@ struct wg_device;
struct endpoint {
union {
- struct sockaddr addr;
+ struct sockaddr_inet addr; /* Large enough for both address families */
struct sockaddr_in addr4;
struct sockaddr_in6 addr6;
};
diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h
index 1ea4f874e367..79b6d70de236 100644
--- a/drivers/net/wireguard/queueing.h
+++ b/drivers/net/wireguard/queueing.h
@@ -104,16 +104,11 @@ static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating)
static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id)
{
- unsigned int cpu = *stored_cpu, cpu_index, i;
+ unsigned int cpu = *stored_cpu;
+
+ while (unlikely(cpu >= nr_cpu_ids || !cpu_online(cpu)))
+ cpu = *stored_cpu = cpumask_nth(id % num_online_cpus(), cpu_online_mask);
- if (unlikely(cpu >= nr_cpu_ids ||
- !cpumask_test_cpu(cpu, cpu_online_mask))) {
- cpu_index = id % cpumask_weight(cpu_online_mask);
- cpu = cpumask_first(cpu_online_mask);
- for (i = 0; i < cpu_index; ++i)
- cpu = cpumask_next(cpu, cpu_online_mask);
- *stored_cpu = cpu;
- }
return cpu;
}
@@ -124,10 +119,10 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id)
*/
static inline int wg_cpumask_next_online(int *last_cpu)
{
- int cpu = cpumask_next(*last_cpu, cpu_online_mask);
+ int cpu = cpumask_next(READ_ONCE(*last_cpu), cpu_online_mask);
if (cpu >= nr_cpu_ids)
cpu = cpumask_first(cpu_online_mask);
- *last_cpu = cpu;
+ WRITE_ONCE(*last_cpu, cpu);
return cpu;
}
diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c
index 0b3f0c843550..eb8851113654 100644
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -251,7 +251,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
- keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
+ READ_ONCE(keypair->receiving_counter.counter) >= REJECT_AFTER_MESSAGES)) {
WRITE_ONCE(keypair->receiving.is_valid, false);
return false;
}
@@ -263,7 +263,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
* call skb_cow_data, so that there's no chance that data is removed
* from the skb, so that later we can extract the original endpoint.
*/
- offset = skb->data - skb_network_header(skb);
+ offset = -skb_network_offset(skb);
skb_push(skb, offset);
num_frags = skb_cow_data(skb, 0, &trailer);
offset += sizeof(struct message_data);
@@ -318,7 +318,7 @@ static bool counter_validate(struct noise_replay_counter *counter, u64 their_cou
for (i = 1; i <= top; ++i)
counter->backtrack[(i + index_current) &
((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
- counter->counter = their_counter;
+ WRITE_ONCE(counter->counter, their_counter);
}
index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
@@ -416,20 +416,20 @@ dishonest_packet_peer:
net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
dev->name, skb, peer->internal_id,
&peer->endpoint.addr);
- ++dev->stats.rx_errors;
- ++dev->stats.rx_frame_errors;
+ DEV_STATS_INC(dev, rx_errors);
+ DEV_STATS_INC(dev, rx_frame_errors);
goto packet_processed;
dishonest_packet_type:
net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
dev->name, peer->internal_id, &peer->endpoint.addr);
- ++dev->stats.rx_errors;
- ++dev->stats.rx_frame_errors;
+ DEV_STATS_INC(dev, rx_errors);
+ DEV_STATS_INC(dev, rx_frame_errors);
goto packet_processed;
dishonest_packet_size:
net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
dev->name, peer->internal_id, &peer->endpoint.addr);
- ++dev->stats.rx_errors;
- ++dev->stats.rx_length_errors;
+ DEV_STATS_INC(dev, rx_errors);
+ DEV_STATS_INC(dev, rx_length_errors);
goto packet_processed;
packet_processed:
dev_kfree_skb(skb);
@@ -463,7 +463,7 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
peer->device->dev->name,
PACKET_CB(skb)->nonce,
- keypair->receiving_counter.counter);
+ READ_ONCE(keypair->receiving_counter.counter));
goto next;
}
diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c
index 3d1f64ff2e12..41837efa70cb 100644
--- a/drivers/net/wireguard/selftest/allowedips.c
+++ b/drivers/net/wireguard/selftest/allowedips.c
@@ -383,7 +383,6 @@ static __init bool randomized_test(void)
for (i = 0; i < NUM_QUERIES; ++i) {
get_random_bytes(ip, 4);
if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
- horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
pr_err("allowedips random v4 self-test: FAIL\n");
goto free;
}
@@ -461,6 +460,10 @@ static __init struct wg_peer *init_peer(void)
wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
cidr, mem, &mutex)
+#define remove(version, mem, ipa, ipb, ipc, ipd, cidr) \
+ wg_allowedips_remove_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
+ cidr, mem, &mutex)
+
#define maybe_fail() do { \
++i; \
if (!_s) { \
@@ -586,6 +589,50 @@ bool __init wg_allowedips_selftest(void)
test_negative(4, a, 192, 0, 0, 0);
test_negative(4, a, 255, 0, 0, 0);
+ insert(4, a, 1, 0, 0, 0, 32);
+ insert(4, a, 192, 0, 0, 0, 24);
+ insert(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ insert(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ test(4, a, 1, 0, 0, 0);
+ test(4, a, 192, 0, 0, 1);
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010);
+ /* Must be an exact match to remove */
+ remove(4, a, 192, 0, 0, 0, 32);
+ test(4, a, 192, 0, 0, 1);
+ /* NULL peer should have no effect and return 0 */
+ test_boolean(!remove(4, NULL, 192, 0, 0, 0, 24));
+ test(4, a, 192, 0, 0, 1);
+ /* different peer should have no effect and return 0 */
+ test_boolean(!remove(4, b, 192, 0, 0, 0, 24));
+ test(4, a, 192, 0, 0, 1);
+ /* invalid CIDR should have no effect and return -EINVAL */
+ test_boolean(remove(4, b, 192, 0, 0, 0, 33) == -EINVAL);
+ test(4, a, 192, 0, 0, 1);
+ remove(4, a, 192, 0, 0, 0, 24);
+ test_negative(4, a, 192, 0, 0, 1);
+ remove(4, a, 1, 0, 0, 0, 32);
+ test_negative(4, a, 1, 0, 0, 0);
+ /* Must be an exact match to remove */
+ remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96);
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* NULL peer should have no effect and return 0 */
+ test_boolean(!remove(6, NULL, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128));
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* different peer should have no effect and return 0 */
+ test_boolean(!remove(6, b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128));
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* invalid CIDR should have no effect and return -EINVAL */
+ test_boolean(remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 129) == -EINVAL);
+ test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ test_negative(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef);
+ /* Must match the peer to remove */
+ remove(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010);
+ remove(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ test_negative(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010);
+
wg_allowedips_free(&t, &mutex);
wg_allowedips_init(&t);
insert(4, a, 192, 168, 0, 0, 16);
diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c
index 95c853b59e1d..26e09c30d596 100644
--- a/drivers/net/wireguard/send.c
+++ b/drivers/net/wireguard/send.c
@@ -222,7 +222,7 @@ void wg_packet_send_keepalive(struct wg_peer *peer)
{
struct sk_buff *skb;
- if (skb_queue_empty(&peer->staged_packet_queue)) {
+ if (skb_queue_empty_lockless(&peer->staged_packet_queue)) {
skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH,
GFP_ATOMIC);
if (unlikely(!skb))
@@ -333,7 +333,8 @@ err:
void wg_packet_purge_staged_packets(struct wg_peer *peer)
{
spin_lock_bh(&peer->staged_packet_queue.lock);
- peer->device->dev->stats.tx_dropped += peer->staged_packet_queue.qlen;
+ DEV_STATS_ADD(peer->device->dev, tx_dropped,
+ peer->staged_packet_queue.qlen);
__skb_queue_purge(&peer->staged_packet_queue);
spin_unlock_bh(&peer->staged_packet_queue.lock);
}
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index 0414d7a6ce74..253488f8c00f 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -84,7 +84,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
skb->ignore_df = 1;
udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport,
- fl.fl4_dport, false, false);
+ fl.fl4_dport, false, false, 0);
goto out;
err:
@@ -151,7 +151,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
skb->ignore_df = 1;
udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
ip6_dst_hoplimit(dst), 0, fl.fl6_sport,
- fl.fl6_dport, false);
+ fl.fl6_dport, false, 0);
goto out;
err:
diff --git a/drivers/net/wireguard/timers.c b/drivers/net/wireguard/timers.c
index 968bdb4df0b3..4016a3065602 100644
--- a/drivers/net/wireguard/timers.c
+++ b/drivers/net/wireguard/timers.c
@@ -40,15 +40,15 @@ static inline void mod_peer_timer(struct wg_peer *peer,
static void wg_expired_retransmit_handshake(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer,
- timer_retransmit_handshake);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_retransmit_handshake);
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) {
pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n",
peer->device->dev->name, peer->internal_id,
&peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2);
- del_timer(&peer->timer_send_keepalive);
+ timer_delete(&peer->timer_send_keepalive);
/* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
@@ -78,7 +78,8 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer)
static void wg_expired_send_keepalive(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer, timer_send_keepalive);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_send_keepalive);
wg_packet_send_keepalive(peer);
if (peer->timer_need_another_keepalive) {
@@ -90,7 +91,8 @@ static void wg_expired_send_keepalive(struct timer_list *timer)
static void wg_expired_new_handshake(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer, timer_new_handshake);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_new_handshake);
pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n",
peer->device->dev->name, peer->internal_id,
@@ -104,7 +106,8 @@ static void wg_expired_new_handshake(struct timer_list *timer)
static void wg_expired_zero_key_material(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer, timer_zero_key_material);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_zero_key_material);
rcu_read_lock_bh();
if (!READ_ONCE(peer->is_dead)) {
@@ -134,8 +137,8 @@ static void wg_queued_expired_zero_key_material(struct work_struct *work)
static void wg_expired_send_persistent_keepalive(struct timer_list *timer)
{
- struct wg_peer *peer = from_timer(peer, timer,
- timer_persistent_keepalive);
+ struct wg_peer *peer = timer_container_of(peer, timer,
+ timer_persistent_keepalive);
if (likely(peer->persistent_keepalive_interval))
wg_packet_send_keepalive(peer);
@@ -167,7 +170,7 @@ void wg_timers_data_received(struct wg_peer *peer)
*/
void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer)
{
- del_timer(&peer->timer_send_keepalive);
+ timer_delete(&peer->timer_send_keepalive);
}
/* Should be called after any type of authenticated packet is received, whether
@@ -175,7 +178,7 @@ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer)
*/
void wg_timers_any_authenticated_packet_received(struct wg_peer *peer)
{
- del_timer(&peer->timer_new_handshake);
+ timer_delete(&peer->timer_new_handshake);
}
/* Should be called after a handshake initiation message is sent. */
@@ -191,7 +194,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer)
*/
void wg_timers_handshake_complete(struct wg_peer *peer)
{
- del_timer(&peer->timer_retransmit_handshake);
+ timer_delete(&peer->timer_retransmit_handshake);
peer->timer_handshake_attempts = 0;
peer->sent_lastminute_handshake = false;
ktime_get_real_ts64(&peer->walltime_last_handshake);