diff options
Diffstat (limited to 'drivers/net/wireguard')
| -rw-r--r-- | drivers/net/wireguard/Makefile | 2 | ||||
| -rw-r--r-- | drivers/net/wireguard/allowedips.c | 106 | ||||
| -rw-r--r-- | drivers/net/wireguard/allowedips.h | 4 | ||||
| -rw-r--r-- | drivers/net/wireguard/cookie.c | 24 | ||||
| -rw-r--r-- | drivers/net/wireguard/device.c | 36 | ||||
| -rw-r--r-- | drivers/net/wireguard/generated/netlink.c | 73 | ||||
| -rw-r--r-- | drivers/net/wireguard/generated/netlink.h | 30 | ||||
| -rw-r--r-- | drivers/net/wireguard/main.c | 2 | ||||
| -rw-r--r-- | drivers/net/wireguard/netlink.c | 119 | ||||
| -rw-r--r-- | drivers/net/wireguard/noise.c | 38 | ||||
| -rw-r--r-- | drivers/net/wireguard/peer.h | 2 | ||||
| -rw-r--r-- | drivers/net/wireguard/queueing.h | 17 | ||||
| -rw-r--r-- | drivers/net/wireguard/receive.c | 20 | ||||
| -rw-r--r-- | drivers/net/wireguard/selftest/allowedips.c | 49 | ||||
| -rw-r--r-- | drivers/net/wireguard/send.c | 5 | ||||
| -rw-r--r-- | drivers/net/wireguard/socket.c | 4 | ||||
| -rw-r--r-- | drivers/net/wireguard/timers.c | 25 |
17 files changed, 355 insertions, 201 deletions
diff --git a/drivers/net/wireguard/Makefile b/drivers/net/wireguard/Makefile index dbe1f8514efc..00cbcc9ab69d 100644 --- a/drivers/net/wireguard/Makefile +++ b/drivers/net/wireguard/Makefile @@ -13,5 +13,5 @@ wireguard-y += peerlookup.o wireguard-y += allowedips.o wireguard-y += ratelimiter.o wireguard-y += cookie.o -wireguard-y += netlink.o +wireguard-y += netlink.o generated/netlink.o obj-$(CONFIG_WIREGUARD) := wireguard.o diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 0ba714ca5185..09f7fcd7da78 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -15,8 +15,8 @@ static void swap_endian(u8 *dst, const u8 *src, u8 bits) if (bits == 32) { *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); } else if (bits == 128) { - ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]); - ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]); + ((u64 *)dst)[0] = get_unaligned_be64(src); + ((u64 *)dst)[1] = get_unaligned_be64(src + 8); } } @@ -249,6 +249,52 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } +static void remove_node(struct allowedips_node *node, struct mutex *lock) +{ + struct allowedips_node *child, **parent_bit, *parent; + bool free_parent; + + list_del_init(&node->peer_list); + RCU_INIT_POINTER(node->peer, NULL); + if (node->bit[0] && node->bit[1]) + return; + child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], + lockdep_is_held(lock)); + if (child) + child->parent_bit_packed = node->parent_bit_packed; + parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); + *parent_bit = child; + parent = (void *)parent_bit - + offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); + free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) && + (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer); + if (free_parent) + child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)], + lockdep_is_held(lock)); + call_rcu(&node->rcu, node_free_rcu); + if (!free_parent) + return; + if (child) + child->parent_bit_packed = parent->parent_bit_packed; + *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; + call_rcu(&parent->rcu, node_free_rcu); +} + +static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + struct allowedips_node *node; + + if (unlikely(cidr > bits)) + return -EINVAL; + if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) || + peer != rcu_access_pointer(node->peer)) + return 0; + + remove_node(node, lock); + return 0; +} + void wg_allowedips_init(struct allowedips *table) { table->root4 = table->root6 = NULL; @@ -300,44 +346,38 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, return add(&table->root6, 128, key, cidr, peer, lock); } +int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + /* Aligned so it can be passed to fls */ + u8 key[4] __aligned(__alignof(u32)); + + ++table->seq; + swap_endian(key, (const u8 *)ip, 32); + return remove(&table->root4, 32, key, cidr, peer, lock); +} + +int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + /* Aligned so it can be passed to fls64 */ + u8 key[16] __aligned(__alignof(u64)); + + ++table->seq; + swap_endian(key, (const u8 *)ip, 128); + return remove(&table->root6, 128, key, cidr, peer, lock); +} + void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock) { - struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; - bool free_parent; + struct allowedips_node *node, *tmp; if (list_empty(&peer->allowedips_list)) return; ++table->seq; - list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { - list_del_init(&node->peer_list); - RCU_INIT_POINTER(node->peer, NULL); - if (node->bit[0] && node->bit[1]) - continue; - child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], - lockdep_is_held(lock)); - if (child) - child->parent_bit_packed = node->parent_bit_packed; - parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); - *parent_bit = child; - parent = (void *)parent_bit - - offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); - free_parent = !rcu_access_pointer(node->bit[0]) && - !rcu_access_pointer(node->bit[1]) && - (node->parent_bit_packed & 3) <= 1 && - !rcu_access_pointer(parent->peer); - if (free_parent) - child = rcu_dereference_protected( - parent->bit[!(node->parent_bit_packed & 1)], - lockdep_is_held(lock)); - call_rcu(&node->rcu, node_free_rcu); - if (!free_parent) - continue; - if (child) - child->parent_bit_packed = parent->parent_bit_packed; - *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; - call_rcu(&parent->rcu, node_free_rcu); - } + list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) + remove_node(node, lock); } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h index 2346c797eb4d..931958cb6e10 100644 --- a/drivers/net/wireguard/allowedips.h +++ b/drivers/net/wireguard/allowedips.h @@ -38,6 +38,10 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); +int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock); +int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock); void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock); /* The ip input pointer should be __aligned(__alignof(u64))) */ diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c index 4956f0499c19..08731b3fa32b 100644 --- a/drivers/net/wireguard/cookie.c +++ b/drivers/net/wireguard/cookie.c @@ -12,9 +12,9 @@ #include <crypto/blake2s.h> #include <crypto/chacha20poly1305.h> +#include <crypto/utils.h> #include <net/ipv6.h> -#include <crypto/algapi.h> void wg_cookie_checker_init(struct cookie_checker *checker, struct wg_device *wg) @@ -26,14 +26,14 @@ void wg_cookie_checker_init(struct cookie_checker *checker, } enum { COOKIE_KEY_LABEL_LEN = 8 }; -static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----"; -static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--"; +static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "mac1----"; +static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "cookie--"; static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], const u8 label[COOKIE_KEY_LABEL_LEN]) { - struct blake2s_state blake; + struct blake2s_ctx blake; blake2s_init(&blake, NOISE_SYMMETRIC_KEY_LEN); blake2s_update(&blake, label, COOKIE_KEY_LABEL_LEN); @@ -77,7 +77,7 @@ static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, { len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac1); - blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN); + blake2s(key, NOISE_SYMMETRIC_KEY_LEN, message, len, mac1, COOKIE_LEN); } static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, @@ -85,13 +85,13 @@ static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, { len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac2); - blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN); + blake2s(cookie, COOKIE_LEN, message, len, mac2, COOKIE_LEN); } static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cookie_checker *checker) { - struct blake2s_state state; + struct blake2s_ctx blake; if (wg_birthdate_has_expired(checker->secret_birthdate, COOKIE_SECRET_MAX_AGE)) { @@ -103,15 +103,15 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, down_read(&checker->secret_lock); - blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN); + blake2s_init_key(&blake, COOKIE_LEN, checker->secret, NOISE_HASH_LEN); if (skb->protocol == htons(ETH_P_IP)) - blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr, + blake2s_update(&blake, (u8 *)&ip_hdr(skb)->saddr, sizeof(struct in_addr)); else if (skb->protocol == htons(ETH_P_IPV6)) - blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr, + blake2s_update(&blake, (u8 *)&ipv6_hdr(skb)->saddr, sizeof(struct in6_addr)); - blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16)); - blake2s_final(&state, cookie); + blake2s_update(&blake, (u8 *)&udp_hdr(skb)->source, sizeof(__be16)); + blake2s_final(&blake, cookie); up_read(&checker->secret_lock); } diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index 258dcc103921..46a71ec36af8 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -81,7 +81,7 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, v list_for_each_entry(wg, &device_list, device_list) { mutex_lock(&wg->device_update_lock); list_for_each_entry(peer, &wg->peer_list, peer_list) { - del_timer(&peer->timer_zero_key_material); + timer_delete(&peer->timer_zero_key_material); wg_noise_handshake_clear(&peer->handshake); wg_noise_keypairs_clear(&peer->keypairs); } @@ -210,7 +210,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev) */ while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS) { dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue)); - ++dev->stats.tx_dropped; + DEV_STATS_INC(dev, tx_dropped); } skb_queue_splice_tail(&packets, &peer->staged_packet_queue); spin_unlock_bh(&peer->staged_packet_queue.lock); @@ -228,7 +228,7 @@ err_icmp: else if (skb->protocol == htons(ETH_P_IPV6)) icmpv6_ndo_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0); err: - ++dev->stats.tx_errors; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return ret; } @@ -237,7 +237,6 @@ static const struct net_device_ops netdev_ops = { .ndo_open = wg_open, .ndo_stop = wg_stop, .ndo_start_xmit = wg_xmit, - .ndo_get_stats64 = dev_get_tstats64 }; static void wg_destruct(struct net_device *dev) @@ -262,7 +261,6 @@ static void wg_destruct(struct net_device *dev) rcu_barrier(); /* Wait for all the peers to be actually freed. */ wg_ratelimiter_uninit(); memzero_explicit(&wg->static_identity, sizeof(wg->static_identity)); - free_percpu(dev->tstats); kvfree(wg->index_hashtable); kvfree(wg->peer_hashtable); mutex_unlock(&wg->device_update_lock); @@ -291,30 +289,33 @@ static void wg_setup(struct net_device *dev) dev->type = ARPHRD_NONE; dev->flags = IFF_POINTOPOINT | IFF_NOARP; dev->priv_flags |= IFF_NO_QUEUE; - dev->features |= NETIF_F_LLTX; + dev->lltx = true; dev->features |= WG_NETDEV_FEATURES; dev->hw_features |= WG_NETDEV_FEATURES; dev->hw_enc_features |= WG_NETDEV_FEATURES; dev->mtu = ETH_DATA_LEN - overhead; dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead; + dev->pcpu_stat_type = NETDEV_PCPU_STAT_TSTATS; SET_NETDEV_DEVTYPE(dev, &device_type); /* We need to keep the dst around in case of icmp replies. */ netif_keep_dst(dev); - memset(wg, 0, sizeof(*wg)); + netif_set_tso_max_size(dev, GSO_MAX_SIZE); + wg->dev = dev; } -static int wg_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int wg_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { + struct net *link_net = rtnl_newlink_link_net(params); struct wg_device *wg = netdev_priv(dev); int ret = -ENOMEM; - rcu_assign_pointer(wg->creating_net, src_net); + rcu_assign_pointer(wg->creating_net, link_net); init_rwsem(&wg->static_identity.lock); mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); @@ -331,14 +332,11 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, if (!wg->index_hashtable) goto err_free_peer_hashtable; - dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); - if (!dev->tstats) - goto err_free_index_hashtable; - wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", - WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); + WQ_CPU_INTENSIVE | WQ_FREEZABLE | WQ_PERCPU, 0, + dev->name); if (!wg->handshake_receive_wq) - goto err_free_tstats; + goto err_free_index_hashtable; wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); @@ -346,7 +344,8 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, goto err_destroy_handshake_receive; wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", - WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name); + WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM | WQ_PERCPU, 0, + dev->name); if (!wg->packet_crypt_wq) goto err_destroy_handshake_send; @@ -369,6 +368,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, if (ret < 0) goto err_free_handshake_queue; + netif_threaded_enable(dev); ret = register_netdevice(dev); if (ret < 0) goto err_uninit_ratelimiter; @@ -397,8 +397,6 @@ err_destroy_handshake_send: destroy_workqueue(wg->handshake_send_wq); err_destroy_handshake_receive: destroy_workqueue(wg->handshake_receive_wq); -err_free_tstats: - free_percpu(dev->tstats); err_free_index_hashtable: kvfree(wg->index_hashtable); err_free_peer_hashtable: diff --git a/drivers/net/wireguard/generated/netlink.c b/drivers/net/wireguard/generated/netlink.c new file mode 100644 index 000000000000..3ef8c29908c2 --- /dev/null +++ b/drivers/net/wireguard/generated/netlink.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/wireguard.yaml */ +/* YNL-GEN kernel source */ +/* YNL-ARG --function-prefix wg */ +/* To regenerate run: tools/net/ynl/ynl-regen.sh */ + +#include <net/netlink.h> +#include <net/genetlink.h> + +#include "netlink.h" + +#include <uapi/linux/wireguard.h> +#include <linux/time_types.h> + +/* Common nested types */ +const struct nla_policy wireguard_wgallowedip_nl_policy[WGALLOWEDIP_A_FLAGS + 1] = { + [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16, }, + [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(4), + [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8, }, + [WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x1), +}; + +const struct nla_policy wireguard_wgpeer_nl_policy[WGPEER_A_PROTOCOL_VERSION + 1] = { + [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN), + [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN), + [WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x7), + [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(16), + [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16, }, + [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(16), + [WGPEER_A_RX_BYTES] = { .type = NLA_U64, }, + [WGPEER_A_TX_BYTES] = { .type = NLA_U64, }, + [WGPEER_A_ALLOWEDIPS] = NLA_POLICY_NESTED_ARRAY(wireguard_wgallowedip_nl_policy), + [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32, }, +}; + +/* WG_CMD_GET_DEVICE - dump */ +static const struct nla_policy wireguard_get_device_nl_policy[WGDEVICE_A_IFNAME + 1] = { + [WGDEVICE_A_IFINDEX] = { .type = NLA_U32, }, + [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = 15, }, +}; + +/* WG_CMD_SET_DEVICE - do */ +static const struct nla_policy wireguard_set_device_nl_policy[WGDEVICE_A_PEERS + 1] = { + [WGDEVICE_A_IFINDEX] = { .type = NLA_U32, }, + [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = 15, }, + [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN), + [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(WG_KEY_LEN), + [WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, 0x1), + [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16, }, + [WGDEVICE_A_FWMARK] = { .type = NLA_U32, }, + [WGDEVICE_A_PEERS] = NLA_POLICY_NESTED_ARRAY(wireguard_wgpeer_nl_policy), +}; + +/* Ops table for wireguard */ +const struct genl_split_ops wireguard_nl_ops[2] = { + { + .cmd = WG_CMD_GET_DEVICE, + .start = wg_get_device_start, + .dumpit = wg_get_device_dumpit, + .done = wg_get_device_done, + .policy = wireguard_get_device_nl_policy, + .maxattr = WGDEVICE_A_IFNAME, + .flags = GENL_UNS_ADMIN_PERM | GENL_CMD_CAP_DUMP, + }, + { + .cmd = WG_CMD_SET_DEVICE, + .doit = wg_set_device_doit, + .policy = wireguard_set_device_nl_policy, + .maxattr = WGDEVICE_A_PEERS, + .flags = GENL_UNS_ADMIN_PERM | GENL_CMD_CAP_DO, + }, +}; diff --git a/drivers/net/wireguard/generated/netlink.h b/drivers/net/wireguard/generated/netlink.h new file mode 100644 index 000000000000..5dc977ee9e7c --- /dev/null +++ b/drivers/net/wireguard/generated/netlink.h @@ -0,0 +1,30 @@ +/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */ +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/wireguard.yaml */ +/* YNL-GEN kernel header */ +/* YNL-ARG --function-prefix wg */ +/* To regenerate run: tools/net/ynl/ynl-regen.sh */ + +#ifndef _LINUX_WIREGUARD_GEN_H +#define _LINUX_WIREGUARD_GEN_H + +#include <net/netlink.h> +#include <net/genetlink.h> + +#include <uapi/linux/wireguard.h> +#include <linux/time_types.h> + +/* Common nested types */ +extern const struct nla_policy wireguard_wgallowedip_nl_policy[WGALLOWEDIP_A_FLAGS + 1]; +extern const struct nla_policy wireguard_wgpeer_nl_policy[WGPEER_A_PROTOCOL_VERSION + 1]; + +/* Ops table for wireguard */ +extern const struct genl_split_ops wireguard_nl_ops[2]; + +int wg_get_device_start(struct netlink_callback *cb); +int wg_get_device_done(struct netlink_callback *cb); + +int wg_get_device_dumpit(struct sk_buff *skb, struct netlink_callback *cb); +int wg_set_device_doit(struct sk_buff *skb, struct genl_info *info); + +#endif /* _LINUX_WIREGUARD_GEN_H */ diff --git a/drivers/net/wireguard/main.c b/drivers/net/wireguard/main.c index ee4da9ab8013..a00671b58701 100644 --- a/drivers/net/wireguard/main.c +++ b/drivers/net/wireguard/main.c @@ -14,7 +14,7 @@ #include <linux/init.h> #include <linux/module.h> -#include <linux/genetlink.h> +#include <net/genetlink.h> #include <net/rtnetlink.h> static int __init wg_mod_init(void) diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c index dc09b75a3248..1da7e98d0d50 100644 --- a/drivers/net/wireguard/netlink.c +++ b/drivers/net/wireguard/netlink.c @@ -9,46 +9,17 @@ #include "socket.h" #include "queueing.h" #include "messages.h" +#include "generated/netlink.h" #include <uapi/linux/wireguard.h> #include <linux/if.h> #include <net/genetlink.h> #include <net/sock.h> -#include <crypto/algapi.h> +#include <crypto/utils.h> static struct genl_family genl_family; -static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { - [WGDEVICE_A_IFINDEX] = { .type = NLA_U32 }, - [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 }, - [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), - [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), - [WGDEVICE_A_FLAGS] = { .type = NLA_U32 }, - [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 }, - [WGDEVICE_A_FWMARK] = { .type = NLA_U32 }, - [WGDEVICE_A_PEERS] = { .type = NLA_NESTED } -}; - -static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { - [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), - [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN), - [WGPEER_A_FLAGS] = { .type = NLA_U32 }, - [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)), - [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 }, - [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)), - [WGPEER_A_RX_BYTES] = { .type = NLA_U64 }, - [WGPEER_A_TX_BYTES] = { .type = NLA_U64 }, - [WGPEER_A_ALLOWEDIPS] = { .type = NLA_NESTED }, - [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32 } -}; - -static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = { - [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 }, - [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)), - [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 } -}; - static struct wg_device *lookup_interface(struct nlattr **attrs, struct sk_buff *skb) { @@ -164,8 +135,8 @@ get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx) if (!allowedips_node) goto no_allowedips; if (!ctx->allowedips_seq) - ctx->allowedips_seq = peer->device->peer_allowedips.seq; - else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq) + ctx->allowedips_seq = ctx->wg->peer_allowedips.seq; + else if (ctx->allowedips_seq != ctx->wg->peer_allowedips.seq) goto no_allowedips; allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); @@ -196,7 +167,7 @@ err: return -EMSGSIZE; } -static int wg_get_device_start(struct netlink_callback *cb) +int wg_get_device_start(struct netlink_callback *cb) { struct wg_device *wg; @@ -207,7 +178,7 @@ static int wg_get_device_start(struct netlink_callback *cb) return 0; } -static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) +int wg_get_device_dumpit(struct sk_buff *skb, struct netlink_callback *cb) { struct wg_peer *peer, *next_peer_cursor; struct dump_ctx *ctx = DUMP_CTX(cb); @@ -255,17 +226,17 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) if (!peers_nest) goto out; ret = 0; - /* If the last cursor was removed via list_del_init in peer_remove, then + lockdep_assert_held(&wg->device_update_lock); + /* If the last cursor was removed in peer_remove or peer_remove_all, then * we just treat this the same as there being no more peers left. The * reason is that seq_nr should indicate to userspace that this isn't a * coherent dump anyway, so they'll try again. */ if (list_empty(&wg->peer_list) || - (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) { + (ctx->next_peer && ctx->next_peer->is_dead)) { nla_nest_cancel(skb, peers_nest); goto out; } - lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list); list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { if (get_peer(peer, skb, ctx)) { @@ -301,7 +272,7 @@ out: */ } -static int wg_get_device_done(struct netlink_callback *cb) +int wg_get_device_done(struct netlink_callback *cb) { struct dump_ctx *ctx = DUMP_CTX(cb); @@ -329,6 +300,7 @@ static int set_port(struct wg_device *wg, u16 port) static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) { int ret = -EINVAL; + u32 flags = 0; u16 family; u8 cidr; @@ -337,19 +309,30 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) return ret; family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]); cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); + if (attrs[WGALLOWEDIP_A_FLAGS]) + flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]); if (family == AF_INET && cidr <= 32 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = wg_allowedips_insert_v4( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); - else if (family == AF_INET6 && cidr <= 128 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = wg_allowedips_insert_v6( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } else if (family == AF_INET6 && cidr <= 128 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } return ret; } @@ -373,9 +356,6 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) if (attrs[WGPEER_A_FLAGS]) flags = nla_get_u32(attrs[WGPEER_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGPEER_F_ALL) - goto out; ret = -EPFNOSUPPORT; if (attrs[WGPEER_A_PROTOCOL_VERSION]) { @@ -457,7 +437,7 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) { ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX, - attr, allowedip_policy, NULL); + attr, NULL, NULL); if (ret < 0) goto out; ret = set_allowedip(peer, allowedip); @@ -490,7 +470,7 @@ out: return ret; } -static int wg_set_device(struct sk_buff *skb, struct genl_info *info) +int wg_set_device_doit(struct sk_buff *skb, struct genl_info *info) { struct wg_device *wg = lookup_interface(info->attrs, skb); u32 flags = 0; @@ -506,9 +486,6 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) if (info->attrs[WGDEVICE_A_FLAGS]) flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGDEVICE_F_ALL) - goto out; if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) { struct net *net; @@ -586,7 +563,7 @@ skip_set_private_key: nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) { ret = nla_parse_nested(peer, WGPEER_A_MAX, attr, - peer_policy, NULL); + NULL, NULL); if (ret < 0) goto out; ret = set_peer(wg, peer); @@ -607,34 +584,20 @@ out_nodev: return ret; } -static const struct genl_ops genl_ops[] = { - { - .cmd = WG_CMD_GET_DEVICE, - .start = wg_get_device_start, - .dumpit = wg_get_device_dump, - .done = wg_get_device_done, - .flags = GENL_UNS_ADMIN_PERM - }, { - .cmd = WG_CMD_SET_DEVICE, - .doit = wg_set_device, - .flags = GENL_UNS_ADMIN_PERM - } -}; - static struct genl_family genl_family __ro_after_init = { - .ops = genl_ops, - .n_ops = ARRAY_SIZE(genl_ops), - .resv_start_op = WG_CMD_SET_DEVICE + 1, + .split_ops = wireguard_nl_ops, + .n_split_ops = ARRAY_SIZE(wireguard_nl_ops), .name = WG_GENL_NAME, .version = WG_GENL_VERSION, - .maxattr = WGDEVICE_A_MAX, .module = THIS_MODULE, - .policy = device_policy, .netnsok = true }; int __init wg_genetlink_init(void) { + BUILD_BUG_ON(WG_KEY_LEN != NOISE_PUBLIC_KEY_LEN); + BUILD_BUG_ON(WG_KEY_LEN != NOISE_SYMMETRIC_KEY_LEN); + return genl_register_family(&genl_family); } diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c index 720952b92e78..1fe8468f0bef 100644 --- a/drivers/net/wireguard/noise.c +++ b/drivers/net/wireguard/noise.c @@ -15,7 +15,7 @@ #include <linux/bitmap.h> #include <linux/scatterlist.h> #include <linux/highmem.h> -#include <crypto/algapi.h> +#include <crypto/utils.h> /* This implements Noise_IKpsk2: * @@ -25,18 +25,18 @@ * <- e, ee, se, psk, {} */ -static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; -static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com"; +static const u8 handshake_name[37] __nonstring = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; +static const u8 identifier_name[34] __nonstring = "WireGuard v1 zx2c4 Jason@zx2c4.com"; static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init; static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init; static atomic64_t keypair_counter = ATOMIC64_INIT(0); void __init wg_noise_init(void) { - struct blake2s_state blake; + struct blake2s_ctx blake; - blake2s(handshake_init_chaining_key, handshake_name, NULL, - NOISE_HASH_LEN, sizeof(handshake_name), 0); + blake2s(NULL, 0, handshake_name, sizeof(handshake_name), + handshake_init_chaining_key, NOISE_HASH_LEN); blake2s_init(&blake, NOISE_HASH_LEN); blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN); blake2s_update(&blake, identifier_name, sizeof(identifier_name)); @@ -304,33 +304,33 @@ void wg_noise_set_static_identity_private_key( static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen) { - struct blake2s_state state; + struct blake2s_ctx blake; u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 }; u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32)); int i; if (keylen > BLAKE2S_BLOCK_SIZE) { - blake2s_init(&state, BLAKE2S_HASH_SIZE); - blake2s_update(&state, key, keylen); - blake2s_final(&state, x_key); + blake2s_init(&blake, BLAKE2S_HASH_SIZE); + blake2s_update(&blake, key, keylen); + blake2s_final(&blake, x_key); } else memcpy(x_key, key, keylen); for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) x_key[i] ^= 0x36; - blake2s_init(&state, BLAKE2S_HASH_SIZE); - blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); - blake2s_update(&state, in, inlen); - blake2s_final(&state, i_hash); + blake2s_init(&blake, BLAKE2S_HASH_SIZE); + blake2s_update(&blake, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&blake, in, inlen); + blake2s_final(&blake, i_hash); for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) x_key[i] ^= 0x5c ^ 0x36; - blake2s_init(&state, BLAKE2S_HASH_SIZE); - blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); - blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE); - blake2s_final(&state, i_hash); + blake2s_init(&blake, BLAKE2S_HASH_SIZE); + blake2s_update(&blake, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&blake, i_hash, BLAKE2S_HASH_SIZE); + blake2s_final(&blake, i_hash); memcpy(out, i_hash, BLAKE2S_HASH_SIZE); memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE); @@ -431,7 +431,7 @@ static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN], static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len) { - struct blake2s_state blake; + struct blake2s_ctx blake; blake2s_init(&blake, NOISE_HASH_LEN); blake2s_update(&blake, hash, NOISE_HASH_LEN); diff --git a/drivers/net/wireguard/peer.h b/drivers/net/wireguard/peer.h index 76e4d3128ad4..718fb42bdac7 100644 --- a/drivers/net/wireguard/peer.h +++ b/drivers/net/wireguard/peer.h @@ -20,7 +20,7 @@ struct wg_device; struct endpoint { union { - struct sockaddr addr; + struct sockaddr_inet addr; /* Large enough for both address families */ struct sockaddr_in addr4; struct sockaddr_in6 addr6; }; diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h index 1ea4f874e367..79b6d70de236 100644 --- a/drivers/net/wireguard/queueing.h +++ b/drivers/net/wireguard/queueing.h @@ -104,16 +104,11 @@ static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating) static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id) { - unsigned int cpu = *stored_cpu, cpu_index, i; + unsigned int cpu = *stored_cpu; + + while (unlikely(cpu >= nr_cpu_ids || !cpu_online(cpu))) + cpu = *stored_cpu = cpumask_nth(id % num_online_cpus(), cpu_online_mask); - if (unlikely(cpu >= nr_cpu_ids || - !cpumask_test_cpu(cpu, cpu_online_mask))) { - cpu_index = id % cpumask_weight(cpu_online_mask); - cpu = cpumask_first(cpu_online_mask); - for (i = 0; i < cpu_index; ++i) - cpu = cpumask_next(cpu, cpu_online_mask); - *stored_cpu = cpu; - } return cpu; } @@ -124,10 +119,10 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id) */ static inline int wg_cpumask_next_online(int *last_cpu) { - int cpu = cpumask_next(*last_cpu, cpu_online_mask); + int cpu = cpumask_next(READ_ONCE(*last_cpu), cpu_online_mask); if (cpu >= nr_cpu_ids) cpu = cpumask_first(cpu_online_mask); - *last_cpu = cpu; + WRITE_ONCE(*last_cpu, cpu); return cpu; } diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c index 0b3f0c843550..eb8851113654 100644 --- a/drivers/net/wireguard/receive.c +++ b/drivers/net/wireguard/receive.c @@ -251,7 +251,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) if (unlikely(!READ_ONCE(keypair->receiving.is_valid) || wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) || - keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) { + READ_ONCE(keypair->receiving_counter.counter) >= REJECT_AFTER_MESSAGES)) { WRITE_ONCE(keypair->receiving.is_valid, false); return false; } @@ -263,7 +263,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) * call skb_cow_data, so that there's no chance that data is removed * from the skb, so that later we can extract the original endpoint. */ - offset = skb->data - skb_network_header(skb); + offset = -skb_network_offset(skb); skb_push(skb, offset); num_frags = skb_cow_data(skb, 0, &trailer); offset += sizeof(struct message_data); @@ -318,7 +318,7 @@ static bool counter_validate(struct noise_replay_counter *counter, u64 their_cou for (i = 1; i <= top; ++i) counter->backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; - counter->counter = their_counter; + WRITE_ONCE(counter->counter, their_counter); } index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; @@ -416,20 +416,20 @@ dishonest_packet_peer: net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", dev->name, skb, peer->internal_id, &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_frame_errors; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_frame_errors); goto packet_processed; dishonest_packet_type: net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_frame_errors; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_frame_errors); goto packet_processed; dishonest_packet_size: net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_length_errors; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_length_errors); goto packet_processed; packet_processed: dev_kfree_skb(skb); @@ -463,7 +463,7 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget) net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, - keypair->receiving_counter.counter); + READ_ONCE(keypair->receiving_counter.counter)); goto next; } diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c index 3d1f64ff2e12..41837efa70cb 100644 --- a/drivers/net/wireguard/selftest/allowedips.c +++ b/drivers/net/wireguard/selftest/allowedips.c @@ -383,7 +383,6 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { get_random_bytes(ip, 4); if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { - horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); pr_err("allowedips random v4 self-test: FAIL\n"); goto free; } @@ -461,6 +460,10 @@ static __init struct wg_peer *init_peer(void) wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ cidr, mem, &mutex) +#define remove(version, mem, ipa, ipb, ipc, ipd, cidr) \ + wg_allowedips_remove_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ + cidr, mem, &mutex) + #define maybe_fail() do { \ ++i; \ if (!_s) { \ @@ -586,6 +589,50 @@ bool __init wg_allowedips_selftest(void) test_negative(4, a, 192, 0, 0, 0); test_negative(4, a, 255, 0, 0, 0); + insert(4, a, 1, 0, 0, 0, 32); + insert(4, a, 192, 0, 0, 0, 24); + insert(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + insert(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test(4, a, 1, 0, 0, 0); + test(4, a, 192, 0, 0, 1); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + /* Must be an exact match to remove */ + remove(4, a, 192, 0, 0, 0, 32); + test(4, a, 192, 0, 0, 1); + /* NULL peer should have no effect and return 0 */ + test_boolean(!remove(4, NULL, 192, 0, 0, 0, 24)); + test(4, a, 192, 0, 0, 1); + /* different peer should have no effect and return 0 */ + test_boolean(!remove(4, b, 192, 0, 0, 0, 24)); + test(4, a, 192, 0, 0, 1); + /* invalid CIDR should have no effect and return -EINVAL */ + test_boolean(remove(4, b, 192, 0, 0, 0, 33) == -EINVAL); + test(4, a, 192, 0, 0, 1); + remove(4, a, 192, 0, 0, 0, 24); + test_negative(4, a, 192, 0, 0, 1); + remove(4, a, 1, 0, 0, 0, 32); + test_negative(4, a, 1, 0, 0, 0); + /* Must be an exact match to remove */ + remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* NULL peer should have no effect and return 0 */ + test_boolean(!remove(6, NULL, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* different peer should have no effect and return 0 */ + test_boolean(!remove(6, b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* invalid CIDR should have no effect and return -EINVAL */ + test_boolean(remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 129) == -EINVAL); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + test_negative(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* Must match the peer to remove */ + remove(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + remove(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test_negative(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + wg_allowedips_free(&t, &mutex); wg_allowedips_init(&t); insert(4, a, 192, 168, 0, 0, 16); diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c index 95c853b59e1d..26e09c30d596 100644 --- a/drivers/net/wireguard/send.c +++ b/drivers/net/wireguard/send.c @@ -222,7 +222,7 @@ void wg_packet_send_keepalive(struct wg_peer *peer) { struct sk_buff *skb; - if (skb_queue_empty(&peer->staged_packet_queue)) { + if (skb_queue_empty_lockless(&peer->staged_packet_queue)) { skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC); if (unlikely(!skb)) @@ -333,7 +333,8 @@ err: void wg_packet_purge_staged_packets(struct wg_peer *peer) { spin_lock_bh(&peer->staged_packet_queue.lock); - peer->device->dev->stats.tx_dropped += peer->staged_packet_queue.qlen; + DEV_STATS_ADD(peer->device->dev, tx_dropped, + peer->staged_packet_queue.qlen); __skb_queue_purge(&peer->staged_packet_queue); spin_unlock_bh(&peer->staged_packet_queue.lock); } diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index 0414d7a6ce74..253488f8c00f 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -84,7 +84,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, skb->ignore_df = 1; udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, - fl.fl4_dport, false, false); + fl.fl4_dport, false, false, 0); goto out; err: @@ -151,7 +151,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, skb->ignore_df = 1; udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, ip6_dst_hoplimit(dst), 0, fl.fl6_sport, - fl.fl6_dport, false); + fl.fl6_dport, false, 0); goto out; err: diff --git a/drivers/net/wireguard/timers.c b/drivers/net/wireguard/timers.c index 968bdb4df0b3..4016a3065602 100644 --- a/drivers/net/wireguard/timers.c +++ b/drivers/net/wireguard/timers.c @@ -40,15 +40,15 @@ static inline void mod_peer_timer(struct wg_peer *peer, static void wg_expired_retransmit_handshake(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, - timer_retransmit_handshake); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_retransmit_handshake); if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2); - del_timer(&peer->timer_send_keepalive); + timer_delete(&peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ @@ -78,7 +78,8 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer) static void wg_expired_send_keepalive(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, timer_send_keepalive); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_send_keepalive); wg_packet_send_keepalive(peer); if (peer->timer_need_another_keepalive) { @@ -90,7 +91,8 @@ static void wg_expired_send_keepalive(struct timer_list *timer) static void wg_expired_new_handshake(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, timer_new_handshake); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_new_handshake); pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, @@ -104,7 +106,8 @@ static void wg_expired_new_handshake(struct timer_list *timer) static void wg_expired_zero_key_material(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, timer_zero_key_material); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_zero_key_material); rcu_read_lock_bh(); if (!READ_ONCE(peer->is_dead)) { @@ -134,8 +137,8 @@ static void wg_queued_expired_zero_key_material(struct work_struct *work) static void wg_expired_send_persistent_keepalive(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, - timer_persistent_keepalive); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_persistent_keepalive); if (likely(peer->persistent_keepalive_interval)) wg_packet_send_keepalive(peer); @@ -167,7 +170,7 @@ void wg_timers_data_received(struct wg_peer *peer) */ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer) { - del_timer(&peer->timer_send_keepalive); + timer_delete(&peer->timer_send_keepalive); } /* Should be called after any type of authenticated packet is received, whether @@ -175,7 +178,7 @@ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer) */ void wg_timers_any_authenticated_packet_received(struct wg_peer *peer) { - del_timer(&peer->timer_new_handshake); + timer_delete(&peer->timer_new_handshake); } /* Should be called after a handshake initiation message is sent. */ @@ -191,7 +194,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer) */ void wg_timers_handshake_complete(struct wg_peer *peer) { - del_timer(&peer->timer_retransmit_handshake); + timer_delete(&peer->timer_retransmit_handshake); peer->timer_handshake_attempts = 0; peer->sent_lastminute_handshake = false; ktime_get_real_ts64(&peer->walltime_last_handshake); |
