summaryrefslogtreecommitdiff
path: root/kernel/bpf/lpm_trie.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/bpf/lpm_trie.c')
-rw-r--r--kernel/bpf/lpm_trie.c202
1 files changed, 128 insertions, 74 deletions
diff --git a/kernel/bpf/lpm_trie.c b/kernel/bpf/lpm_trie.c
index cec792a17e5f..e8a772e64324 100644
--- a/kernel/bpf/lpm_trie.c
+++ b/kernel/bpf/lpm_trie.c
@@ -14,6 +14,8 @@
#include <linux/vmalloc.h>
#include <net/ipv6.h>
#include <uapi/linux/btf.h>
+#include <linux/btf_ids.h>
+#include <linux/bpf_mem_alloc.h>
/* Intermediate node */
#define LPM_TREE_NODE_FLAG_IM BIT(0)
@@ -21,7 +23,6 @@
struct lpm_trie_node;
struct lpm_trie_node {
- struct rcu_head rcu;
struct lpm_trie_node __rcu *child[2];
u32 prefixlen;
u32 flags;
@@ -31,10 +32,11 @@ struct lpm_trie_node {
struct lpm_trie {
struct bpf_map map;
struct lpm_trie_node __rcu *root;
+ struct bpf_mem_alloc ma;
size_t n_entries;
size_t max_prefixlen;
size_t data_size;
- spinlock_t lock;
+ raw_spinlock_t lock;
};
/* This trie implements a longest prefix match algorithm that can be used to
@@ -154,22 +156,23 @@ static inline int extract_bit(const u8 *data, size_t index)
}
/**
- * longest_prefix_match() - determine the longest prefix
+ * __longest_prefix_match() - determine the longest prefix
* @trie: The trie to get internal sizes from
* @node: The node to operate on
* @key: The key to compare to @node
*
* Determine the longest prefix of @node that matches the bits in @key.
*/
-static size_t longest_prefix_match(const struct lpm_trie *trie,
- const struct lpm_trie_node *node,
- const struct bpf_lpm_trie_key *key)
+static __always_inline
+size_t __longest_prefix_match(const struct lpm_trie *trie,
+ const struct lpm_trie_node *node,
+ const struct bpf_lpm_trie_key_u8 *key)
{
u32 limit = min(node->prefixlen, key->prefixlen);
u32 prefixlen = 0, i = 0;
BUILD_BUG_ON(offsetof(struct lpm_trie_node, data) % sizeof(u32));
- BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key, data) % sizeof(u32));
+ BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key_u8, data) % sizeof(u32));
#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && defined(CONFIG_64BIT)
@@ -223,16 +226,27 @@ static size_t longest_prefix_match(const struct lpm_trie *trie,
return prefixlen;
}
+static size_t longest_prefix_match(const struct lpm_trie *trie,
+ const struct lpm_trie_node *node,
+ const struct bpf_lpm_trie_key_u8 *key)
+{
+ return __longest_prefix_match(trie, node, key);
+}
+
/* Called from syscall or from eBPF program */
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
struct lpm_trie_node *node, *found = NULL;
- struct bpf_lpm_trie_key *key = _key;
+ struct bpf_lpm_trie_key_u8 *key = _key;
+
+ if (key->prefixlen > trie->max_prefixlen)
+ return NULL;
/* Start walking the trie from the root node ... */
- for (node = rcu_dereference(trie->root); node;) {
+ for (node = rcu_dereference_check(trie->root, rcu_read_lock_bh_held());
+ node;) {
unsigned int next_bit;
size_t matchlen;
@@ -240,7 +254,7 @@ static void *trie_lookup_elem(struct bpf_map *map, void *_key)
* If it's the maximum possible prefix for this trie, we have
* an exact match and can return it directly.
*/
- matchlen = longest_prefix_match(trie, node, key);
+ matchlen = __longest_prefix_match(trie, node, key);
if (matchlen == trie->max_prefixlen) {
found = node;
break;
@@ -264,7 +278,8 @@ static void *trie_lookup_elem(struct bpf_map *map, void *_key)
* traverse down.
*/
next_bit = extract_bit(key->data, node->prefixlen);
- node = rcu_dereference(node->child[next_bit]);
+ node = rcu_dereference_check(node->child[next_bit],
+ rcu_read_lock_bh_held());
}
if (!found)
@@ -273,17 +288,13 @@ static void *trie_lookup_elem(struct bpf_map *map, void *_key)
return found->data + trie->data_size;
}
-static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
+static struct lpm_trie_node *lpm_trie_node_alloc(struct lpm_trie *trie,
const void *value)
{
struct lpm_trie_node *node;
- size_t size = sizeof(struct lpm_trie_node) + trie->data_size;
- if (value)
- size += trie->map.value_size;
+ node = bpf_mem_cache_alloc(&trie->ma);
- node = bpf_map_kmalloc_node(&trie->map, size, GFP_ATOMIC | __GFP_NOWARN,
- trie->map.numa_node);
if (!node)
return NULL;
@@ -296,14 +307,25 @@ static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
return node;
}
+static int trie_check_add_elem(struct lpm_trie *trie, u64 flags)
+{
+ if (flags == BPF_EXIST)
+ return -ENOENT;
+ if (trie->n_entries == trie->map.max_entries)
+ return -ENOSPC;
+ trie->n_entries++;
+ return 0;
+}
+
/* Called from syscall or from eBPF program */
-static int trie_update_elem(struct bpf_map *map,
- void *_key, void *value, u64 flags)
+static long trie_update_elem(struct bpf_map *map,
+ void *_key, void *value, u64 flags)
{
struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
- struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
+ struct lpm_trie_node *node, *im_node, *new_node;
+ struct lpm_trie_node *free_node = NULL;
struct lpm_trie_node __rcu **slot;
- struct bpf_lpm_trie_key *key = _key;
+ struct bpf_lpm_trie_key_u8 *key = _key;
unsigned long irq_flags;
unsigned int next_bit;
size_t matchlen = 0;
@@ -315,22 +337,12 @@ static int trie_update_elem(struct bpf_map *map,
if (key->prefixlen > trie->max_prefixlen)
return -EINVAL;
- spin_lock_irqsave(&trie->lock, irq_flags);
-
/* Allocate and fill a new node */
-
- if (trie->n_entries == trie->map.max_entries) {
- ret = -ENOSPC;
- goto out;
- }
-
new_node = lpm_trie_node_alloc(trie, value);
- if (!new_node) {
- ret = -ENOMEM;
- goto out;
- }
+ if (!new_node)
+ return -ENOMEM;
- trie->n_entries++;
+ raw_spin_lock_irqsave(&trie->lock, irq_flags);
new_node->prefixlen = key->prefixlen;
RCU_INIT_POINTER(new_node->child[0], NULL);
@@ -349,8 +361,7 @@ static int trie_update_elem(struct bpf_map *map,
matchlen = longest_prefix_match(trie, node, key);
if (node->prefixlen != matchlen ||
- node->prefixlen == key->prefixlen ||
- node->prefixlen == trie->max_prefixlen)
+ node->prefixlen == key->prefixlen)
break;
next_bit = extract_bit(key->data, node->prefixlen);
@@ -361,6 +372,10 @@ static int trie_update_elem(struct bpf_map *map,
* simply assign the @new_node to that slot and be done.
*/
if (!node) {
+ ret = trie_check_add_elem(trie, flags);
+ if (ret)
+ goto out;
+
rcu_assign_pointer(*slot, new_node);
goto out;
}
@@ -369,18 +384,30 @@ static int trie_update_elem(struct bpf_map *map,
* which already has the correct data array set.
*/
if (node->prefixlen == matchlen) {
+ if (!(node->flags & LPM_TREE_NODE_FLAG_IM)) {
+ if (flags == BPF_NOEXIST) {
+ ret = -EEXIST;
+ goto out;
+ }
+ } else {
+ ret = trie_check_add_elem(trie, flags);
+ if (ret)
+ goto out;
+ }
+
new_node->child[0] = node->child[0];
new_node->child[1] = node->child[1];
- if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
- trie->n_entries--;
-
rcu_assign_pointer(*slot, new_node);
- kfree_rcu(node, rcu);
+ free_node = node;
goto out;
}
+ ret = trie_check_add_elem(trie, flags);
+ if (ret)
+ goto out;
+
/* If the new node matches the prefix completely, it must be inserted
* as an ancestor. Simply insert it between @node and *@slot.
*/
@@ -393,6 +420,7 @@ static int trie_update_elem(struct bpf_map *map,
im_node = lpm_trie_node_alloc(trie, NULL);
if (!im_node) {
+ trie->n_entries--;
ret = -ENOMEM;
goto out;
}
@@ -410,28 +438,25 @@ static int trie_update_elem(struct bpf_map *map,
rcu_assign_pointer(im_node->child[1], node);
}
- /* Finally, assign the intermediate node to the determined spot */
+ /* Finally, assign the intermediate node to the determined slot */
rcu_assign_pointer(*slot, im_node);
out:
- if (ret) {
- if (new_node)
- trie->n_entries--;
-
- kfree(new_node);
- kfree(im_node);
- }
+ raw_spin_unlock_irqrestore(&trie->lock, irq_flags);
- spin_unlock_irqrestore(&trie->lock, irq_flags);
+ if (ret)
+ bpf_mem_cache_free(&trie->ma, new_node);
+ bpf_mem_cache_free_rcu(&trie->ma, free_node);
return ret;
}
/* Called from syscall or from eBPF program */
-static int trie_delete_elem(struct bpf_map *map, void *_key)
+static long trie_delete_elem(struct bpf_map *map, void *_key)
{
struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
- struct bpf_lpm_trie_key *key = _key;
+ struct lpm_trie_node *free_node = NULL, *free_parent = NULL;
+ struct bpf_lpm_trie_key_u8 *key = _key;
struct lpm_trie_node __rcu **trim, **trim2;
struct lpm_trie_node *node, *parent;
unsigned long irq_flags;
@@ -442,7 +467,7 @@ static int trie_delete_elem(struct bpf_map *map, void *_key)
if (key->prefixlen > trie->max_prefixlen)
return -EINVAL;
- spin_lock_irqsave(&trie->lock, irq_flags);
+ raw_spin_lock_irqsave(&trie->lock, irq_flags);
/* Walk the tree looking for an exact key/length match and keeping
* track of the path we traverse. We will need to know the node
@@ -500,8 +525,8 @@ static int trie_delete_elem(struct bpf_map *map, void *_key)
else
rcu_assign_pointer(
*trim2, rcu_access_pointer(parent->child[0]));
- kfree_rcu(parent, rcu);
- kfree_rcu(node, rcu);
+ free_parent = parent;
+ free_node = node;
goto out;
}
@@ -515,10 +540,13 @@ static int trie_delete_elem(struct bpf_map *map, void *_key)
rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
else
RCU_INIT_POINTER(*trim, NULL);
- kfree_rcu(node, rcu);
+ free_node = node;
out:
- spin_unlock_irqrestore(&trie->lock, irq_flags);
+ raw_spin_unlock_irqrestore(&trie->lock, irq_flags);
+
+ bpf_mem_cache_free_rcu(&trie->ma, free_parent);
+ bpf_mem_cache_free_rcu(&trie->ma, free_node);
return ret;
}
@@ -530,7 +558,7 @@ out:
sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN 1
-#define LPM_KEY_SIZE(X) (sizeof(struct bpf_lpm_trie_key) + (X))
+#define LPM_KEY_SIZE(X) (sizeof(struct bpf_lpm_trie_key_u8) + (X))
#define LPM_KEY_SIZE_MAX LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)
@@ -540,9 +568,8 @@ out:
static struct bpf_map *trie_alloc(union bpf_attr *attr)
{
struct lpm_trie *trie;
-
- if (!bpf_capable())
- return ERR_PTR(-EPERM);
+ size_t leaf_size;
+ int err;
/* check sanity of attributes */
if (attr->max_entries == 0 ||
@@ -555,19 +582,29 @@ static struct bpf_map *trie_alloc(union bpf_attr *attr)
attr->value_size > LPM_VAL_SIZE_MAX)
return ERR_PTR(-EINVAL);
- trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN | __GFP_ACCOUNT);
+ trie = bpf_map_area_alloc(sizeof(*trie), NUMA_NO_NODE);
if (!trie)
return ERR_PTR(-ENOMEM);
/* copy mandatory map attributes */
bpf_map_init_from_attr(&trie->map, attr);
trie->data_size = attr->key_size -
- offsetof(struct bpf_lpm_trie_key, data);
+ offsetof(struct bpf_lpm_trie_key_u8, data);
trie->max_prefixlen = trie->data_size * 8;
- spin_lock_init(&trie->lock);
+ raw_spin_lock_init(&trie->lock);
+ /* Allocate intermediate and leaf nodes from the same allocator */
+ leaf_size = sizeof(struct lpm_trie_node) + trie->data_size +
+ trie->map.value_size;
+ err = bpf_mem_alloc_init(&trie->ma, leaf_size, false);
+ if (err)
+ goto free_out;
return &trie->map;
+
+free_out:
+ bpf_map_area_free(trie);
+ return ERR_PTR(err);
}
static void trie_free(struct bpf_map *map)
@@ -599,25 +636,29 @@ static void trie_free(struct bpf_map *map)
continue;
}
- kfree(node);
+ /* No bpf program may access the map, so freeing the
+ * node without waiting for the extra RCU GP.
+ */
+ bpf_mem_cache_raw_free(node);
RCU_INIT_POINTER(*slot, NULL);
break;
}
}
out:
- kfree(trie);
+ bpf_mem_alloc_destroy(&trie->ma);
+ bpf_map_area_free(trie);
}
static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{
struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
- struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
+ struct bpf_lpm_trie_key_u8 *key = _key, *next_key = _next_key;
struct lpm_trie_node **node_stack = NULL;
int err = 0, stack_ptr = -1;
unsigned int next_bit;
- size_t matchlen;
+ size_t matchlen = 0;
/* The get_next_key follows postorder. For the 4 node example in
* the top of this file, the trie_get_next_key() returns the following
@@ -639,7 +680,7 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
if (!key || key->prefixlen > trie->max_prefixlen)
goto find_leftmost;
- node_stack = kmalloc_array(trie->max_prefixlen,
+ node_stack = kmalloc_array(trie->max_prefixlen + 1,
sizeof(struct lpm_trie_node *),
GFP_ATOMIC | __GFP_NOWARN);
if (!node_stack)
@@ -656,7 +697,7 @@ static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
next_bit = extract_bit(key->data, node->prefixlen);
node = rcu_dereference(node->child[next_bit]);
}
- if (!node || node->prefixlen != key->prefixlen ||
+ if (!node || node->prefixlen != matchlen ||
(node->flags & LPM_TREE_NODE_FLAG_IM))
goto find_leftmost;
@@ -700,7 +741,7 @@ find_leftmost:
}
do_copy:
next_key->prefixlen = next_node->prefixlen;
- memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
+ memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key_u8, data),
next_node->data, trie->data_size);
free_stack:
kfree(node_stack);
@@ -712,12 +753,22 @@ static int trie_check_btf(const struct bpf_map *map,
const struct btf_type *key_type,
const struct btf_type *value_type)
{
- /* Keys must have struct bpf_lpm_trie_key embedded. */
+ /* Keys must have struct bpf_lpm_trie_key_u8 embedded. */
return BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ?
-EINVAL : 0;
}
-static int trie_map_btf_id;
+static u64 trie_mem_usage(const struct bpf_map *map)
+{
+ struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
+ u64 elem_size;
+
+ elem_size = sizeof(struct lpm_trie_node) + trie->data_size +
+ trie->map.value_size;
+ return elem_size * READ_ONCE(trie->n_entries);
+}
+
+BTF_ID_LIST_SINGLE(trie_map_btf_ids, struct, lpm_trie)
const struct bpf_map_ops trie_map_ops = {
.map_meta_equal = bpf_map_meta_equal,
.map_alloc = trie_alloc,
@@ -726,7 +777,10 @@ const struct bpf_map_ops trie_map_ops = {
.map_lookup_elem = trie_lookup_elem,
.map_update_elem = trie_update_elem,
.map_delete_elem = trie_delete_elem,
+ .map_lookup_batch = generic_map_lookup_batch,
+ .map_update_batch = generic_map_update_batch,
+ .map_delete_batch = generic_map_delete_batch,
.map_check_btf = trie_check_btf,
- .map_btf_name = "lpm_trie",
- .map_btf_id = &trie_map_btf_id,
+ .map_mem_usage = trie_mem_usage,
+ .map_btf_id = &trie_map_btf_ids[0],
};