diff options
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r-- | drivers/net/wireguard/allowedips.c | 106 | ||||
-rw-r--r-- | drivers/net/wireguard/allowedips.h | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/cookie.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/device.c | 26 | ||||
-rw-r--r-- | drivers/net/wireguard/main.c | 2 | ||||
-rw-r--r-- | drivers/net/wireguard/netlink.c | 57 | ||||
-rw-r--r-- | drivers/net/wireguard/noise.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/queueing.h | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/receive.c | 8 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/allowedips.c | 49 | ||||
-rw-r--r-- | drivers/net/wireguard/send.c | 2 | ||||
-rw-r--r-- | drivers/net/wireguard/timers.c | 25 |
12 files changed, 194 insertions, 97 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 0ba714ca5185..09f7fcd7da78 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -15,8 +15,8 @@ static void swap_endian(u8 *dst, const u8 *src, u8 bits) if (bits == 32) { *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); } else if (bits == 128) { - ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]); - ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]); + ((u64 *)dst)[0] = get_unaligned_be64(src); + ((u64 *)dst)[1] = get_unaligned_be64(src + 8); } } @@ -249,6 +249,52 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } +static void remove_node(struct allowedips_node *node, struct mutex *lock) +{ + struct allowedips_node *child, **parent_bit, *parent; + bool free_parent; + + list_del_init(&node->peer_list); + RCU_INIT_POINTER(node->peer, NULL); + if (node->bit[0] && node->bit[1]) + return; + child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], + lockdep_is_held(lock)); + if (child) + child->parent_bit_packed = node->parent_bit_packed; + parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); + *parent_bit = child; + parent = (void *)parent_bit - + offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); + free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) && + (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer); + if (free_parent) + child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)], + lockdep_is_held(lock)); + call_rcu(&node->rcu, node_free_rcu); + if (!free_parent) + return; + if (child) + child->parent_bit_packed = parent->parent_bit_packed; + *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; + call_rcu(&parent->rcu, node_free_rcu); +} + +static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + struct allowedips_node *node; + + if (unlikely(cidr > bits)) + return -EINVAL; + if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) || + peer != rcu_access_pointer(node->peer)) + return 0; + + remove_node(node, lock); + return 0; +} + void wg_allowedips_init(struct allowedips *table) { table->root4 = table->root6 = NULL; @@ -300,44 +346,38 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, return add(&table->root6, 128, key, cidr, peer, lock); } +int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + /* Aligned so it can be passed to fls */ + u8 key[4] __aligned(__alignof(u32)); + + ++table->seq; + swap_endian(key, (const u8 *)ip, 32); + return remove(&table->root4, 32, key, cidr, peer, lock); +} + +int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + /* Aligned so it can be passed to fls64 */ + u8 key[16] __aligned(__alignof(u64)); + + ++table->seq; + swap_endian(key, (const u8 *)ip, 128); + return remove(&table->root6, 128, key, cidr, peer, lock); +} + void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock) { - struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; - bool free_parent; + struct allowedips_node *node, *tmp; if (list_empty(&peer->allowedips_list)) return; ++table->seq; - list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { - list_del_init(&node->peer_list); - RCU_INIT_POINTER(node->peer, NULL); - if (node->bit[0] && node->bit[1]) - continue; - child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], - lockdep_is_held(lock)); - if (child) - child->parent_bit_packed = node->parent_bit_packed; - parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); - *parent_bit = child; - parent = (void *)parent_bit - - offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); - free_parent = !rcu_access_pointer(node->bit[0]) && - !rcu_access_pointer(node->bit[1]) && - (node->parent_bit_packed & 3) <= 1 && - !rcu_access_pointer(parent->peer); - if (free_parent) - child = rcu_dereference_protected( - parent->bit[!(node->parent_bit_packed & 1)], - lockdep_is_held(lock)); - call_rcu(&node->rcu, node_free_rcu); - if (!free_parent) - continue; - if (child) - child->parent_bit_packed = parent->parent_bit_packed; - *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; - call_rcu(&parent->rcu, node_free_rcu); - } + list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) + remove_node(node, lock); } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h index 2346c797eb4d..931958cb6e10 100644 --- a/drivers/net/wireguard/allowedips.h +++ b/drivers/net/wireguard/allowedips.h @@ -38,6 +38,10 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); +int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock); +int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock); void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock); /* The ip input pointer should be __aligned(__alignof(u64))) */ diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c index f89581b5e8cb..94d0a7206084 100644 --- a/drivers/net/wireguard/cookie.c +++ b/drivers/net/wireguard/cookie.c @@ -26,8 +26,8 @@ void wg_cookie_checker_init(struct cookie_checker *checker, } enum { COOKIE_KEY_LABEL_LEN = 8 }; -static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----"; -static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--"; +static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "mac1----"; +static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "cookie--"; static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index deb9636b0ecf..4a529f1f9bea 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -81,7 +81,7 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, v list_for_each_entry(wg, &device_list, device_list) { mutex_lock(&wg->device_update_lock); list_for_each_entry(peer, &wg->peer_list, peer_list) { - del_timer(&peer->timer_zero_key_material); + timer_delete(&peer->timer_zero_key_material); wg_noise_handshake_clear(&peer->handshake); wg_noise_keypairs_clear(&peer->keypairs); } @@ -237,7 +237,6 @@ static const struct net_device_ops netdev_ops = { .ndo_open = wg_open, .ndo_stop = wg_stop, .ndo_start_xmit = wg_xmit, - .ndo_get_stats64 = dev_get_tstats64 }; static void wg_destruct(struct net_device *dev) @@ -262,7 +261,6 @@ static void wg_destruct(struct net_device *dev) rcu_barrier(); /* Wait for all the peers to be actually freed. */ wg_ratelimiter_uninit(); memzero_explicit(&wg->static_identity, sizeof(wg->static_identity)); - free_percpu(dev->tstats); kvfree(wg->index_hashtable); kvfree(wg->peer_hashtable); mutex_unlock(&wg->device_update_lock); @@ -291,30 +289,33 @@ static void wg_setup(struct net_device *dev) dev->type = ARPHRD_NONE; dev->flags = IFF_POINTOPOINT | IFF_NOARP; dev->priv_flags |= IFF_NO_QUEUE; - dev->features |= NETIF_F_LLTX; + dev->lltx = true; dev->features |= WG_NETDEV_FEATURES; dev->hw_features |= WG_NETDEV_FEATURES; dev->hw_enc_features |= WG_NETDEV_FEATURES; dev->mtu = ETH_DATA_LEN - overhead; dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead; + dev->pcpu_stat_type = NETDEV_PCPU_STAT_TSTATS; SET_NETDEV_DEVTYPE(dev, &device_type); /* We need to keep the dst around in case of icmp replies. */ netif_keep_dst(dev); - memset(wg, 0, sizeof(*wg)); + netif_set_tso_max_size(dev, GSO_MAX_SIZE); + wg->dev = dev; } -static int wg_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int wg_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { + struct net *link_net = rtnl_newlink_link_net(params); struct wg_device *wg = netdev_priv(dev); int ret = -ENOMEM; - rcu_assign_pointer(wg->creating_net, src_net); + rcu_assign_pointer(wg->creating_net, link_net); init_rwsem(&wg->static_identity.lock); mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); @@ -331,14 +332,10 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, if (!wg->index_hashtable) goto err_free_peer_hashtable; - dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); - if (!dev->tstats) - goto err_free_index_hashtable; - wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); if (!wg->handshake_receive_wq) - goto err_free_tstats; + goto err_free_index_hashtable; wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); @@ -369,6 +366,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, if (ret < 0) goto err_free_handshake_queue; + dev_set_threaded(dev, true); ret = register_netdevice(dev); if (ret < 0) goto err_uninit_ratelimiter; @@ -397,8 +395,6 @@ err_destroy_handshake_send: destroy_workqueue(wg->handshake_send_wq); err_destroy_handshake_receive: destroy_workqueue(wg->handshake_receive_wq); -err_free_tstats: - free_percpu(dev->tstats); err_free_index_hashtable: kvfree(wg->index_hashtable); err_free_peer_hashtable: diff --git a/drivers/net/wireguard/main.c b/drivers/net/wireguard/main.c index ee4da9ab8013..a00671b58701 100644 --- a/drivers/net/wireguard/main.c +++ b/drivers/net/wireguard/main.c @@ -14,7 +14,7 @@ #include <linux/init.h> #include <linux/module.h> -#include <linux/genetlink.h> +#include <net/genetlink.h> #include <net/rtnetlink.h> static int __init wg_mod_init(void) diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c index e220d761b1f2..67f962eb8b46 100644 --- a/drivers/net/wireguard/netlink.c +++ b/drivers/net/wireguard/netlink.c @@ -24,7 +24,7 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 }, [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), - [WGDEVICE_A_FLAGS] = { .type = NLA_U32 }, + [WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGDEVICE_F_ALL), [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 }, [WGDEVICE_A_FWMARK] = { .type = NLA_U32 }, [WGDEVICE_A_PEERS] = { .type = NLA_NESTED } @@ -33,7 +33,7 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN), - [WGPEER_A_FLAGS] = { .type = NLA_U32 }, + [WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGPEER_F_ALL), [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)), [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 }, [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)), @@ -46,7 +46,8 @@ static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = { [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 }, [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)), - [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 } + [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 }, + [WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGALLOWEDIP_F_ALL), }; static struct wg_device *lookup_interface(struct nlattr **attrs, @@ -164,8 +165,8 @@ get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx) if (!allowedips_node) goto no_allowedips; if (!ctx->allowedips_seq) - ctx->allowedips_seq = peer->device->peer_allowedips.seq; - else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq) + ctx->allowedips_seq = ctx->wg->peer_allowedips.seq; + else if (ctx->allowedips_seq != ctx->wg->peer_allowedips.seq) goto no_allowedips; allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); @@ -255,17 +256,17 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) if (!peers_nest) goto out; ret = 0; - /* If the last cursor was removed via list_del_init in peer_remove, then + lockdep_assert_held(&wg->device_update_lock); + /* If the last cursor was removed in peer_remove or peer_remove_all, then * we just treat this the same as there being no more peers left. The * reason is that seq_nr should indicate to userspace that this isn't a * coherent dump anyway, so they'll try again. */ if (list_empty(&wg->peer_list) || - (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) { + (ctx->next_peer && ctx->next_peer->is_dead)) { nla_nest_cancel(skb, peers_nest); goto out; } - lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list); list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { if (get_peer(peer, skb, ctx)) { @@ -329,6 +330,7 @@ static int set_port(struct wg_device *wg, u16 port) static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) { int ret = -EINVAL; + u32 flags = 0; u16 family; u8 cidr; @@ -337,19 +339,30 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) return ret; family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]); cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); + if (attrs[WGALLOWEDIP_A_FLAGS]) + flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]); if (family == AF_INET && cidr <= 32 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = wg_allowedips_insert_v4( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); - else if (family == AF_INET6 && cidr <= 128 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = wg_allowedips_insert_v6( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } else if (family == AF_INET6 && cidr <= 128 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } return ret; } @@ -373,9 +386,6 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) if (attrs[WGPEER_A_FLAGS]) flags = nla_get_u32(attrs[WGPEER_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGPEER_F_ALL) - goto out; ret = -EPFNOSUPPORT; if (attrs[WGPEER_A_PROTOCOL_VERSION]) { @@ -506,9 +516,6 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) if (info->attrs[WGDEVICE_A_FLAGS]) flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGDEVICE_F_ALL) - goto out; if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) { struct net *net; diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c index 202a33af5a72..7eb9a23a3d4d 100644 --- a/drivers/net/wireguard/noise.c +++ b/drivers/net/wireguard/noise.c @@ -25,8 +25,8 @@ * <- e, ee, se, psk, {} */ -static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; -static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com"; +static const u8 handshake_name[37] __nonstring = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; +static const u8 identifier_name[34] __nonstring = "WireGuard v1 zx2c4 Jason@zx2c4.com"; static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init; static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init; static atomic64_t keypair_counter = ATOMIC64_INIT(0); diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h index 1ea4f874e367..7eb76724b3ed 100644 --- a/drivers/net/wireguard/queueing.h +++ b/drivers/net/wireguard/queueing.h @@ -124,10 +124,10 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id) */ static inline int wg_cpumask_next_online(int *last_cpu) { - int cpu = cpumask_next(*last_cpu, cpu_online_mask); + int cpu = cpumask_next(READ_ONCE(*last_cpu), cpu_online_mask); if (cpu >= nr_cpu_ids) cpu = cpumask_first(cpu_online_mask); - *last_cpu = cpu; + WRITE_ONCE(*last_cpu, cpu); return cpu; } diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c index a176653c8861..eb8851113654 100644 --- a/drivers/net/wireguard/receive.c +++ b/drivers/net/wireguard/receive.c @@ -251,7 +251,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) if (unlikely(!READ_ONCE(keypair->receiving.is_valid) || wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) || - keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) { + READ_ONCE(keypair->receiving_counter.counter) >= REJECT_AFTER_MESSAGES)) { WRITE_ONCE(keypair->receiving.is_valid, false); return false; } @@ -263,7 +263,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) * call skb_cow_data, so that there's no chance that data is removed * from the skb, so that later we can extract the original endpoint. */ - offset = skb->data - skb_network_header(skb); + offset = -skb_network_offset(skb); skb_push(skb, offset); num_frags = skb_cow_data(skb, 0, &trailer); offset += sizeof(struct message_data); @@ -318,7 +318,7 @@ static bool counter_validate(struct noise_replay_counter *counter, u64 their_cou for (i = 1; i <= top; ++i) counter->backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; - counter->counter = their_counter; + WRITE_ONCE(counter->counter, their_counter); } index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; @@ -463,7 +463,7 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget) net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, - keypair->receiving_counter.counter); + READ_ONCE(keypair->receiving_counter.counter)); goto next; } diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c index 3d1f64ff2e12..41837efa70cb 100644 --- a/drivers/net/wireguard/selftest/allowedips.c +++ b/drivers/net/wireguard/selftest/allowedips.c @@ -383,7 +383,6 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { get_random_bytes(ip, 4); if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { - horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); pr_err("allowedips random v4 self-test: FAIL\n"); goto free; } @@ -461,6 +460,10 @@ static __init struct wg_peer *init_peer(void) wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ cidr, mem, &mutex) +#define remove(version, mem, ipa, ipb, ipc, ipd, cidr) \ + wg_allowedips_remove_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ + cidr, mem, &mutex) + #define maybe_fail() do { \ ++i; \ if (!_s) { \ @@ -586,6 +589,50 @@ bool __init wg_allowedips_selftest(void) test_negative(4, a, 192, 0, 0, 0); test_negative(4, a, 255, 0, 0, 0); + insert(4, a, 1, 0, 0, 0, 32); + insert(4, a, 192, 0, 0, 0, 24); + insert(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + insert(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test(4, a, 1, 0, 0, 0); + test(4, a, 192, 0, 0, 1); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + /* Must be an exact match to remove */ + remove(4, a, 192, 0, 0, 0, 32); + test(4, a, 192, 0, 0, 1); + /* NULL peer should have no effect and return 0 */ + test_boolean(!remove(4, NULL, 192, 0, 0, 0, 24)); + test(4, a, 192, 0, 0, 1); + /* different peer should have no effect and return 0 */ + test_boolean(!remove(4, b, 192, 0, 0, 0, 24)); + test(4, a, 192, 0, 0, 1); + /* invalid CIDR should have no effect and return -EINVAL */ + test_boolean(remove(4, b, 192, 0, 0, 0, 33) == -EINVAL); + test(4, a, 192, 0, 0, 1); + remove(4, a, 192, 0, 0, 0, 24); + test_negative(4, a, 192, 0, 0, 1); + remove(4, a, 1, 0, 0, 0, 32); + test_negative(4, a, 1, 0, 0, 0); + /* Must be an exact match to remove */ + remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* NULL peer should have no effect and return 0 */ + test_boolean(!remove(6, NULL, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* different peer should have no effect and return 0 */ + test_boolean(!remove(6, b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* invalid CIDR should have no effect and return -EINVAL */ + test_boolean(remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 129) == -EINVAL); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + test_negative(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* Must match the peer to remove */ + remove(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + remove(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test_negative(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + wg_allowedips_free(&t, &mutex); wg_allowedips_init(&t); insert(4, a, 192, 168, 0, 0, 16); diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c index 0d48e0f4a1ba..26e09c30d596 100644 --- a/drivers/net/wireguard/send.c +++ b/drivers/net/wireguard/send.c @@ -222,7 +222,7 @@ void wg_packet_send_keepalive(struct wg_peer *peer) { struct sk_buff *skb; - if (skb_queue_empty(&peer->staged_packet_queue)) { + if (skb_queue_empty_lockless(&peer->staged_packet_queue)) { skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC); if (unlikely(!skb)) diff --git a/drivers/net/wireguard/timers.c b/drivers/net/wireguard/timers.c index 968bdb4df0b3..4016a3065602 100644 --- a/drivers/net/wireguard/timers.c +++ b/drivers/net/wireguard/timers.c @@ -40,15 +40,15 @@ static inline void mod_peer_timer(struct wg_peer *peer, static void wg_expired_retransmit_handshake(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, - timer_retransmit_handshake); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_retransmit_handshake); if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2); - del_timer(&peer->timer_send_keepalive); + timer_delete(&peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ @@ -78,7 +78,8 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer) static void wg_expired_send_keepalive(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, timer_send_keepalive); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_send_keepalive); wg_packet_send_keepalive(peer); if (peer->timer_need_another_keepalive) { @@ -90,7 +91,8 @@ static void wg_expired_send_keepalive(struct timer_list *timer) static void wg_expired_new_handshake(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, timer_new_handshake); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_new_handshake); pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, @@ -104,7 +106,8 @@ static void wg_expired_new_handshake(struct timer_list *timer) static void wg_expired_zero_key_material(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, timer_zero_key_material); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_zero_key_material); rcu_read_lock_bh(); if (!READ_ONCE(peer->is_dead)) { @@ -134,8 +137,8 @@ static void wg_queued_expired_zero_key_material(struct work_struct *work) static void wg_expired_send_persistent_keepalive(struct timer_list *timer) { - struct wg_peer *peer = from_timer(peer, timer, - timer_persistent_keepalive); + struct wg_peer *peer = timer_container_of(peer, timer, + timer_persistent_keepalive); if (likely(peer->persistent_keepalive_interval)) wg_packet_send_keepalive(peer); @@ -167,7 +170,7 @@ void wg_timers_data_received(struct wg_peer *peer) */ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer) { - del_timer(&peer->timer_send_keepalive); + timer_delete(&peer->timer_send_keepalive); } /* Should be called after any type of authenticated packet is received, whether @@ -175,7 +178,7 @@ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer) */ void wg_timers_any_authenticated_packet_received(struct wg_peer *peer) { - del_timer(&peer->timer_new_handshake); + timer_delete(&peer->timer_new_handshake); } /* Should be called after a handshake initiation message is sent. */ @@ -191,7 +194,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer) */ void wg_timers_handshake_complete(struct wg_peer *peer) { - del_timer(&peer->timer_retransmit_handshake); + timer_delete(&peer->timer_retransmit_handshake); peer->timer_handshake_attempts = 0; peer->sent_lastminute_handshake = false; ktime_get_real_ts64(&peer->walltime_last_handshake); |