summaryrefslogtreecommitdiff
path: root/net/ipv4/bpf_tcp_ca.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/bpf_tcp_ca.c')
-rw-r--r--net/ipv4/bpf_tcp_ca.c40
1 files changed, 35 insertions, 5 deletions
diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c
index 574972bc7299..e3939f76b024 100644
--- a/net/ipv4/bpf_tcp_ca.c
+++ b/net/ipv4/bpf_tcp_ca.c
@@ -7,6 +7,7 @@
#include <linux/btf.h>
#include <linux/filter.h>
#include <net/tcp.h>
+#include <net/bpf_sk_storage.h>
static u32 optional_ops[] = {
offsetof(struct tcp_congestion_ops, init),
@@ -27,6 +28,27 @@ static u32 unsupported_ops[] = {
static const struct btf_type *tcp_sock_type;
static u32 tcp_sock_id, sock_id;
+static int btf_sk_storage_get_ids[5];
+static struct bpf_func_proto btf_sk_storage_get_proto __read_mostly;
+
+static int btf_sk_storage_delete_ids[5];
+static struct bpf_func_proto btf_sk_storage_delete_proto __read_mostly;
+
+static void convert_sk_func_proto(struct bpf_func_proto *to, int *to_btf_ids,
+ const struct bpf_func_proto *from)
+{
+ int i;
+
+ *to = *from;
+ to->btf_id = to_btf_ids;
+ for (i = 0; i < ARRAY_SIZE(to->arg_type); i++) {
+ if (to->arg_type[i] == ARG_PTR_TO_SOCKET) {
+ to->arg_type[i] = ARG_PTR_TO_BTF_ID;
+ to->btf_id[i] = tcp_sock_id;
+ }
+ }
+}
+
static int bpf_tcp_ca_init(struct btf *btf)
{
s32 type_id;
@@ -42,6 +64,13 @@ static int bpf_tcp_ca_init(struct btf *btf)
tcp_sock_id = type_id;
tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);
+ convert_sk_func_proto(&btf_sk_storage_get_proto,
+ btf_sk_storage_get_ids,
+ &bpf_sk_storage_get_proto);
+ convert_sk_func_proto(&btf_sk_storage_delete_proto,
+ btf_sk_storage_delete_ids,
+ &bpf_sk_storage_delete_proto);
+
return 0;
}
@@ -167,6 +196,10 @@ bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
switch (func_id) {
case BPF_FUNC_tcp_send_ack:
return &bpf_tcp_send_ack_proto;
+ case BPF_FUNC_sk_storage_get:
+ return &btf_sk_storage_get_proto;
+ case BPF_FUNC_sk_storage_delete:
+ return &btf_sk_storage_delete_proto;
default:
return bpf_base_func_proto(func_id);
}
@@ -184,7 +217,6 @@ static int bpf_tcp_ca_init_member(const struct btf_type *t,
{
const struct tcp_congestion_ops *utcp_ca;
struct tcp_congestion_ops *tcp_ca;
- size_t tcp_ca_name_len;
int prog_fd;
u32 moff;
@@ -199,13 +231,11 @@ static int bpf_tcp_ca_init_member(const struct btf_type *t,
tcp_ca->flags = utcp_ca->flags;
return 1;
case offsetof(struct tcp_congestion_ops, name):
- tcp_ca_name_len = strnlen(utcp_ca->name, sizeof(utcp_ca->name));
- if (!tcp_ca_name_len ||
- tcp_ca_name_len == sizeof(utcp_ca->name))
+ if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name,
+ sizeof(tcp_ca->name)) <= 0)
return -EINVAL;
if (tcp_ca_find(utcp_ca->name))
return -EEXIST;
- memcpy(tcp_ca->name, utcp_ca->name, sizeof(tcp_ca->name));
return 1;
}