summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard/allowedips.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard/allowedips.c')
-rw-r--r--drivers/net/wireguard/allowedips.c102
1 files changed, 58 insertions, 44 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index c540dce8d224..b7197e80f226 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -30,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bitlen = bits;
memcpy(node->bits, src, bits / 8U);
}
-#define CHOOSE_NODE(parent, key) \
- parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
+
+static inline u8 choose(struct allowedips_node *node, const u8 *key)
+{
+ return (key[node->bit_at_a] >> node->bit_at_b) & 1;
+}
static void push_rcu(struct allowedips_node **stack,
struct allowedips_node __rcu *p, unsigned int *len)
@@ -112,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
found = node;
if (node->cidr == bits)
break;
- node = rcu_dereference_bh(CHOOSE_NODE(node, key));
+ node = rcu_dereference_bh(node->bit[choose(node, key)]);
}
return found;
}
@@ -144,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
u8 cidr, u8 bits, struct allowedips_node **rnode,
struct mutex *lock)
{
- struct allowedips_node *node = rcu_dereference_protected(trie,
- lockdep_is_held(lock));
+ struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
struct allowedips_node *parent = NULL;
bool exact = false;
@@ -155,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
exact = true;
break;
}
- node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
- lockdep_is_held(lock));
+ node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
}
*rnode = parent;
return exact;
}
+static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
+{
+ node->parent_bit_packed = (unsigned long)parent | bit;
+ rcu_assign_pointer(*parent, node);
+}
+
+static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
+{
+ u8 bit = choose(parent, node->bits);
+ connect_node(&parent->bit[bit], bit, node);
+}
+
static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
@@ -177,8 +190,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
RCU_INIT_POINTER(node->peer, peer);
list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
- rcu_assign_pointer(node->parent_bit, trie);
- rcu_assign_pointer(*trie, node);
+ connect_node(trie, 2, node);
return 0;
}
if (node_placement(*trie, key, cidr, bits, &node, lock)) {
@@ -197,10 +209,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else {
- down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock));
+ const u8 bit = choose(node, key);
+ down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
if (!down) {
- rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key));
- rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
+ connect_node(&node->bit[bit], bit, newnode);
return 0;
}
}
@@ -208,15 +220,11 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
parent = node;
if (newnode->cidr == cidr) {
- rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits));
- rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
- if (!parent) {
- rcu_assign_pointer(newnode->parent_bit, trie);
- rcu_assign_pointer(*trie, newnode);
- } else {
- rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits));
- rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode);
- }
+ choose_and_connect_node(newnode, down);
+ if (!parent)
+ connect_node(trie, 2, newnode);
+ else
+ choose_and_connect_node(parent, newnode);
return 0;
}
@@ -229,17 +237,12 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
INIT_LIST_HEAD(&node->peer_list);
copy_and_assign_cidr(node, newnode->bits, cidr, bits);
- rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits));
- rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
- rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits));
- rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
- if (!parent) {
- rcu_assign_pointer(node->parent_bit, trie);
- rcu_assign_pointer(*trie, node);
- } else {
- rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits));
- rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node);
- }
+ choose_and_connect_node(node, down);
+ choose_and_connect_node(node, newnode);
+ if (!parent)
+ connect_node(trie, 2, node);
+ else
+ choose_and_connect_node(parent, node);
return 0;
}
@@ -297,7 +300,8 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
- struct allowedips_node *node, *child, *tmp;
+ struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
+ bool free_parent;
if (list_empty(&peer->allowedips_list))
return;
@@ -307,19 +311,29 @@ void wg_allowedips_remove_by_peer(struct allowedips *table,
RCU_INIT_POINTER(node->peer, NULL);
if (node->bit[0] && node->bit[1])
continue;
- child = rcu_dereference_protected(
- node->bit[!rcu_access_pointer(node->bit[0])],
- lockdep_is_held(lock));
+ child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
+ lockdep_is_held(lock));
if (child)
- child->parent_bit = node->parent_bit;
- *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child;
+ child->parent_bit_packed = node->parent_bit_packed;
+ parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
+ *parent_bit = child;
+ parent = (void *)parent_bit -
+ offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
+ free_parent = !rcu_access_pointer(node->bit[0]) &&
+ !rcu_access_pointer(node->bit[1]) &&
+ (node->parent_bit_packed & 3) <= 1 &&
+ !rcu_access_pointer(parent->peer);
+ if (free_parent)
+ child = rcu_dereference_protected(
+ parent->bit[!(node->parent_bit_packed & 1)],
+ lockdep_is_held(lock));
call_rcu(&node->rcu, node_free_rcu);
-
- /* TODO: Note that we currently don't walk up and down in order to
- * free any potential filler nodes. This means that this function
- * doesn't free up as much as it could, which could be revisited
- * at some point.
- */
+ if (!free_parent)
+ continue;
+ if (child)
+ child->parent_bit_packed = parent->parent_bit_packed;
+ *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
+ call_rcu(&parent->rcu, node_free_rcu);
}
}