diff options
-rw-r--r-- | include/linux/skmsg.h | 5 | ||||
-rw-r--r-- | include/uapi/linux/bpf.h | 20 | ||||
-rw-r--r-- | net/core/filter.c | 134 | ||||
-rw-r--r-- | tools/include/uapi/linux/bpf.h | 20 | ||||
-rw-r--r-- | tools/testing/selftests/bpf/bpf_helpers.h | 2 | ||||
-rw-r--r-- | tools/testing/selftests/bpf/test_sockmap.c | 58 | ||||
-rw-r--r-- | tools/testing/selftests/bpf/test_sockmap_kern.h | 97 |
7 files changed, 308 insertions, 28 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 84e18863f6a4..2a11e9d91dfa 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -207,6 +207,11 @@ static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which) return &msg->sg.data[which]; } +static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which) +{ + return msg->sg.data[which]; +} + static inline struct page *sk_msg_page(struct sk_msg *msg, int which) { return sg_page(sk_msg_elem(msg, which)); diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index a2fb333290dc..852dc17ab47a 100644 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -2240,6 +2240,23 @@ union bpf_attr { * pointer that was returned from bpf_sk_lookup_xxx\ (). * Return * 0 on success, or a negative error in case of failure. + * + * int bpf_msg_push_data(struct sk_buff *skb, u32 start, u32 len, u64 flags) + * Description + * For socket policies, insert *len* bytes into msg at offset + * *start*. + * + * If a program of type **BPF_PROG_TYPE_SK_MSG** is run on a + * *msg* it may want to insert metadata or options into the msg. + * This can later be read and used by any of the lower layer BPF + * hooks. + * + * This helper may fail if under memory pressure (a malloc + * fails) in these cases BPF programs will get an appropriate + * error and BPF programs will need to handle them. + * + * Return + * 0 on success, or a negative error in case of failure. */ #define __BPF_FUNC_MAPPER(FN) \ FN(unspec), \ @@ -2331,7 +2348,8 @@ union bpf_attr { FN(sk_release), \ FN(map_push_elem), \ FN(map_pop_elem), \ - FN(map_peek_elem), + FN(map_peek_elem), \ + FN(msg_push_data), /* integer value in 'imm' field of BPF_CALL instruction selects which helper * function eBPF program intends to call diff --git a/net/core/filter.c b/net/core/filter.c index 5fd5139e8638..35c6933c2622 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2297,6 +2297,137 @@ static const struct bpf_func_proto bpf_msg_pull_data_proto = { .arg4_type = ARG_ANYTHING, }; +BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, + u32, len, u64, flags) +{ + struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge; + u32 new, i = 0, l, space, copy = 0, offset = 0; + u8 *raw, *to, *from; + struct page *page; + + if (unlikely(flags)) + return -EINVAL; + + /* First find the starting scatterlist element */ + i = msg->sg.start; + do { + l = sk_msg_elem(msg, i)->length; + + if (start < offset + l) + break; + offset += l; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); + + if (start >= offset + l) + return -EINVAL; + + space = MAX_MSG_FRAGS - sk_msg_elem_used(msg); + + /* If no space available will fallback to copy, we need at + * least one scatterlist elem available to push data into + * when start aligns to the beginning of an element or two + * when it falls inside an element. We handle the start equals + * offset case because its the common case for inserting a + * header. + */ + if (!space || (space == 1 && start != offset)) + copy = msg->sg.data[i].length; + + page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP, + get_order(copy + len)); + if (unlikely(!page)) + return -ENOMEM; + + if (copy) { + int front, back; + + raw = page_address(page); + + psge = sk_msg_elem(msg, i); + front = start - offset; + back = psge->length - front; + from = sg_virt(psge); + + if (front) + memcpy(raw, from, front); + + if (back) { + from += front; + to = raw + front + len; + + memcpy(to, from, back); + } + + put_page(sg_page(psge)); + } else if (start - offset) { + psge = sk_msg_elem(msg, i); + rsge = sk_msg_elem_cpy(msg, i); + + psge->length = start - offset; + rsge.length -= psge->length; + rsge.offset += start; + + sk_msg_iter_var_next(i); + sg_unmark_end(psge); + sk_msg_iter_next(msg, end); + } + + /* Slot(s) to place newly allocated data */ + new = i; + + /* Shift one or two slots as needed */ + if (!copy) { + sge = sk_msg_elem_cpy(msg, i); + + sk_msg_iter_var_next(i); + sg_unmark_end(&sge); + sk_msg_iter_next(msg, end); + + nsge = sk_msg_elem_cpy(msg, i); + if (rsge.length) { + sk_msg_iter_var_next(i); + nnsge = sk_msg_elem_cpy(msg, i); + } + + while (i != msg->sg.end) { + msg->sg.data[i] = sge; + sge = nsge; + sk_msg_iter_var_next(i); + if (rsge.length) { + nsge = nnsge; + nnsge = sk_msg_elem_cpy(msg, i); + } else { + nsge = sk_msg_elem_cpy(msg, i); + } + } + } + + /* Place newly allocated data buffer */ + sk_mem_charge(msg->sk, len); + msg->sg.size += len; + msg->sg.copy[new] = false; + sg_set_page(&msg->sg.data[new], page, len + copy, 0); + if (rsge.length) { + get_page(sg_page(&rsge)); + sk_msg_iter_var_next(new); + msg->sg.data[new] = rsge; + } + + sk_msg_compute_data_pointers(msg); + return 0; +} + +static const struct bpf_func_proto bpf_msg_push_data_proto = { + .func = bpf_msg_push_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb) { return task_get_classid(skb); @@ -4854,6 +4985,7 @@ bool bpf_helper_changes_pkt_data(void *func) func == bpf_xdp_adjust_head || func == bpf_xdp_adjust_meta || func == bpf_msg_pull_data || + func == bpf_msg_push_data || func == bpf_xdp_adjust_tail || #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) func == bpf_lwt_seg6_store_bytes || @@ -5130,6 +5262,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_msg_cork_bytes_proto; case BPF_FUNC_msg_pull_data: return &bpf_msg_pull_data_proto; + case BPF_FUNC_msg_push_data: + return &bpf_msg_push_data_proto; case BPF_FUNC_get_local_storage: return &bpf_get_local_storage_proto; default: diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index a2fb333290dc..852dc17ab47a 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -2240,6 +2240,23 @@ union bpf_attr { * pointer that was returned from bpf_sk_lookup_xxx\ (). * Return * 0 on success, or a negative error in case of failure. + * + * int bpf_msg_push_data(struct sk_buff *skb, u32 start, u32 len, u64 flags) + * Description + * For socket policies, insert *len* bytes into msg at offset + * *start*. + * + * If a program of type **BPF_PROG_TYPE_SK_MSG** is run on a + * *msg* it may want to insert metadata or options into the msg. + * This can later be read and used by any of the lower layer BPF + * hooks. + * + * This helper may fail if under memory pressure (a malloc + * fails) in these cases BPF programs will get an appropriate + * error and BPF programs will need to handle them. + * + * Return + * 0 on success, or a negative error in case of failure. */ #define __BPF_FUNC_MAPPER(FN) \ FN(unspec), \ @@ -2331,7 +2348,8 @@ union bpf_attr { FN(sk_release), \ FN(map_push_elem), \ FN(map_pop_elem), \ - FN(map_peek_elem), + FN(map_peek_elem), \ + FN(msg_push_data), /* integer value in 'imm' field of BPF_CALL instruction selects which helper * function eBPF program intends to call diff --git a/tools/testing/selftests/bpf/bpf_helpers.h b/tools/testing/selftests/bpf/bpf_helpers.h index 6407a3df0f3b..686e57ce40f4 100644 --- a/tools/testing/selftests/bpf/bpf_helpers.h +++ b/tools/testing/selftests/bpf/bpf_helpers.h @@ -111,6 +111,8 @@ static int (*bpf_msg_cork_bytes)(void *ctx, int len) = (void *) BPF_FUNC_msg_cork_bytes; static int (*bpf_msg_pull_data)(void *ctx, int start, int end, int flags) = (void *) BPF_FUNC_msg_pull_data; +static int (*bpf_msg_push_data)(void *ctx, int start, int end, int flags) = + (void *) BPF_FUNC_msg_push_data; static int (*bpf_bind)(void *ctx, void *addr, int addr_len) = (void *) BPF_FUNC_bind; static int (*bpf_xdp_adjust_tail)(void *ctx, int offset) = diff --git a/tools/testing/selftests/bpf/test_sockmap.c b/tools/testing/selftests/bpf/test_sockmap.c index cbd1c0be8680..622ade0a0957 100644 --- a/tools/testing/selftests/bpf/test_sockmap.c +++ b/tools/testing/selftests/bpf/test_sockmap.c @@ -77,6 +77,8 @@ int txmsg_apply; int txmsg_cork; int txmsg_start; int txmsg_end; +int txmsg_start_push; +int txmsg_end_push; int txmsg_ingress; int txmsg_skb; int ktls; @@ -100,6 +102,8 @@ static const struct option long_options[] = { {"txmsg_cork", required_argument, NULL, 'k'}, {"txmsg_start", required_argument, NULL, 's'}, {"txmsg_end", required_argument, NULL, 'e'}, + {"txmsg_start_push", required_argument, NULL, 'p'}, + {"txmsg_end_push", required_argument, NULL, 'q'}, {"txmsg_ingress", no_argument, &txmsg_ingress, 1 }, {"txmsg_skb", no_argument, &txmsg_skb, 1 }, {"ktls", no_argument, &ktls, 1 }, @@ -903,6 +907,30 @@ run: } } + if (txmsg_start_push) { + i = 2; + err = bpf_map_update_elem(map_fd[5], + &i, &txmsg_start_push, BPF_ANY); + if (err) { + fprintf(stderr, + "ERROR: bpf_map_update_elem (txmsg_start_push): %d (%s)\n", + err, strerror(errno)); + goto out; + } + } + + if (txmsg_end_push) { + i = 3; + err = bpf_map_update_elem(map_fd[5], + &i, &txmsg_end_push, BPF_ANY); + if (err) { + fprintf(stderr, + "ERROR: bpf_map_update_elem %i@%i (txmsg_end_push): %d (%s)\n", + txmsg_end_push, i, err, strerror(errno)); + goto out; + } + } + if (txmsg_ingress) { int in = BPF_F_INGRESS; @@ -1235,6 +1263,8 @@ static int test_mixed(int cgrp) txmsg_pass = txmsg_noisy = txmsg_redir_noisy = txmsg_drop = 0; txmsg_apply = txmsg_cork = 0; txmsg_start = txmsg_end = 0; + txmsg_start_push = txmsg_end_push = 0; + /* Test small and large iov_count values with pass/redir/apply/cork */ txmsg_pass = 1; txmsg_redir = 0; @@ -1351,6 +1381,8 @@ static int test_start_end(int cgrp) /* Test basic start/end with lots of iov_count and iov_lengths */ txmsg_start = 1; txmsg_end = 2; + txmsg_start_push = 1; + txmsg_end_push = 2; err = test_txmsg(cgrp); if (err) goto out; @@ -1364,6 +1396,8 @@ static int test_start_end(int cgrp) for (i = 99; i <= 1600; i += 500) { txmsg_start = 0; txmsg_end = i; + txmsg_start_push = 0; + txmsg_end_push = i; err = test_exec(cgrp, &opt); if (err) goto out; @@ -1373,6 +1407,8 @@ static int test_start_end(int cgrp) for (i = 199; i <= 1600; i += 500) { txmsg_start = 100; txmsg_end = i; + txmsg_start_push = 100; + txmsg_end_push = i; err = test_exec(cgrp, &opt); if (err) goto out; @@ -1381,6 +1417,8 @@ static int test_start_end(int cgrp) /* Test start/end with cork pulling last sg entry */ txmsg_start = 1500; txmsg_end = 1600; + txmsg_start_push = 1500; + txmsg_end_push = 1600; err = test_exec(cgrp, &opt); if (err) goto out; @@ -1388,6 +1426,8 @@ static int test_start_end(int cgrp) /* Test start/end pull of single byte in last page */ txmsg_start = 1111; txmsg_end = 1112; + txmsg_start_push = 1111; + txmsg_end_push = 1112; err = test_exec(cgrp, &opt); if (err) goto out; @@ -1395,6 +1435,8 @@ static int test_start_end(int cgrp) /* Test start/end with end < start */ txmsg_start = 1111; txmsg_end = 0; + txmsg_start_push = 1111; + txmsg_end_push = 0; err = test_exec(cgrp, &opt); if (err) goto out; @@ -1402,6 +1444,8 @@ static int test_start_end(int cgrp) /* Test start/end with end > data */ txmsg_start = 0; txmsg_end = 1601; + txmsg_start_push = 0; + txmsg_end_push = 1601; err = test_exec(cgrp, &opt); if (err) goto out; @@ -1409,6 +1453,8 @@ static int test_start_end(int cgrp) /* Test start/end with start > data */ txmsg_start = 1601; txmsg_end = 1600; + txmsg_start_push = 1601; + txmsg_end_push = 1600; err = test_exec(cgrp, &opt); out: @@ -1424,7 +1470,7 @@ char *map_names[] = { "sock_map_redir", "sock_apply_bytes", "sock_cork_bytes", - "sock_pull_bytes", + "sock_bytes", "sock_redir_flags", "sock_skb_opts", }; @@ -1531,7 +1577,7 @@ static int __test_suite(int cg_fd, char *bpf_file) } /* Tests basic commands and APIs with range of iov values */ - txmsg_start = txmsg_end = 0; + txmsg_start = txmsg_end = txmsg_start_push = txmsg_end_push = 0; err = test_txmsg(cg_fd); if (err) goto out; @@ -1580,7 +1626,7 @@ int main(int argc, char **argv) if (argc < 2) return test_suite(-1); - while ((opt = getopt_long(argc, argv, ":dhvc:r:i:l:t:", + while ((opt = getopt_long(argc, argv, ":dhvc:r:i:l:t:p:q:", long_options, &longindex)) != -1) { switch (opt) { case 's': @@ -1589,6 +1635,12 @@ int main(int argc, char **argv) case 'e': txmsg_end = atoi(optarg); break; + case 'p': + txmsg_start_push = atoi(optarg); + break; + case 'q': + txmsg_end_push = atoi(optarg); + break; case 'a': txmsg_apply = atoi(optarg); break; diff --git a/tools/testing/selftests/bpf/test_sockmap_kern.h b/tools/testing/selftests/bpf/test_sockmap_kern.h index 8e8e41780bb9..14b8bbac004f 100644 --- a/tools/testing/selftests/bpf/test_sockmap_kern.h +++ b/tools/testing/selftests/bpf/test_sockmap_kern.h @@ -70,11 +70,11 @@ struct bpf_map_def SEC("maps") sock_cork_bytes = { .max_entries = 1 }; -struct bpf_map_def SEC("maps") sock_pull_bytes = { +struct bpf_map_def SEC("maps") sock_bytes = { .type = BPF_MAP_TYPE_ARRAY, .key_size = sizeof(int), .value_size = sizeof(int), - .max_entries = 2 + .max_entries = 4 }; struct bpf_map_def SEC("maps") sock_redir_flags = { @@ -181,8 +181,8 @@ int bpf_sockmap(struct bpf_sock_ops *skops) SEC("sk_msg1") int bpf_prog4(struct sk_msg_md *msg) { - int *bytes, zero = 0, one = 1; - int *start, *end; + int *bytes, zero = 0, one = 1, two = 2, three = 3; + int *start, *end, *start_push, *end_push; bytes = bpf_map_lookup_elem(&sock_apply_bytes, &zero); if (bytes) @@ -190,18 +190,24 @@ int bpf_prog4(struct sk_msg_md *msg) bytes = bpf_map_lookup_elem(&sock_cork_bytes, &zero); if (bytes) bpf_msg_cork_bytes(msg, *bytes); - start = bpf_map_lookup_elem(&sock_pull_bytes, &zero); - end = bpf_map_lookup_elem(&sock_pull_bytes, &one); + start = bpf_map_lookup_elem(&sock_bytes, &zero); + end = bpf_map_lookup_elem(&sock_bytes, &one); if (start && end) bpf_msg_pull_data(msg, *start, *end, 0); + start_push = bpf_map_lookup_elem(&sock_bytes, &two); + end_push = bpf_map_lookup_elem(&sock_bytes, &three); + if (start_push && end_push) + bpf_msg_push_data(msg, *start_push, *end_push, 0); return SK_PASS; } SEC("sk_msg2") int bpf_prog5(struct sk_msg_md *msg) { - int err1 = -1, err2 = -1, zero = 0, one = 1; - int *bytes, *start, *end, len1, len2; + int zero = 0, one = 1, two = 2, three = 3; + int *start, *end, *start_push, *end_push; + int *bytes, len1, len2 = 0, len3; + int err1 = -1, err2 = -1; bytes = bpf_map_lookup_elem(&sock_apply_bytes, &zero); if (bytes) @@ -210,8 +216,8 @@ int bpf_prog5(struct sk_msg_md *msg) if (bytes) err2 = bpf_msg_cork_bytes(msg, *bytes); len1 = (__u64)msg->data_end - (__u64)msg->data; - start = bpf_map_lookup_elem(&sock_pull_bytes, &zero); - end = bpf_map_lookup_elem(&sock_pull_bytes, &one); + start = bpf_map_lookup_elem(&sock_bytes, &zero); + end = bpf_map_lookup_elem(&sock_bytes, &one); if (start && end) { int err; @@ -225,6 +231,23 @@ int bpf_prog5(struct sk_msg_md *msg) bpf_printk("sk_msg2: length update %i->%i\n", len1, len2); } + + start_push = bpf_map_lookup_elem(&sock_bytes, &two); + end_push = bpf_map_lookup_elem(&sock_bytes, &three); + if (start_push && end_push) { + int err; + + bpf_printk("sk_msg2: push(%i:%i)\n", + start_push ? *start_push : 0, + end_push ? *end_push : 0); + err = bpf_msg_push_data(msg, *start_push, *end_push, 0); + if (err) + bpf_printk("sk_msg2: push_data err %i\n", err); + len3 = (__u64)msg->data_end - (__u64)msg->data; + bpf_printk("sk_msg2: length push_update %i->%i\n", + len2 ? len2 : len1, len3); + } + bpf_printk("sk_msg2: data length %i err1 %i err2 %i\n", len1, err1, err2); return SK_PASS; @@ -233,8 +256,8 @@ int bpf_prog5(struct sk_msg_md *msg) SEC("sk_msg3") int bpf_prog6(struct sk_msg_md *msg) { - int *bytes, zero = 0, one = 1, key = 0; - int *start, *end, *f; + int *bytes, *start, *end, *start_push, *end_push, *f; + int zero = 0, one = 1, two = 2, three = 3, key = 0; __u64 flags = 0; bytes = bpf_map_lookup_elem(&sock_apply_bytes, &zero); @@ -243,10 +266,17 @@ int bpf_prog6(struct sk_msg_md *msg) bytes = bpf_map_lookup_elem(&sock_cork_bytes, &zero); if (bytes) bpf_msg_cork_bytes(msg, *bytes); - start = bpf_map_lookup_elem(&sock_pull_bytes, &zero); - end = bpf_map_lookup_elem(&sock_pull_bytes, &one); + + start = bpf_map_lookup_elem(&sock_bytes, &zero); + end = bpf_map_lookup_elem(&sock_bytes, &one); if (start && end) bpf_msg_pull_data(msg, *start, *end, 0); + + start_push = bpf_map_lookup_elem(&sock_bytes, &two); + end_push = bpf_map_lookup_elem(&sock_bytes, &three); + if (start_push && end_push) + bpf_msg_push_data(msg, *start_push, *end_push, 0); + f = bpf_map_lookup_elem(&sock_redir_flags, &zero); if (f && *f) { key = 2; @@ -262,8 +292,9 @@ int bpf_prog6(struct sk_msg_md *msg) SEC("sk_msg4") int bpf_prog7(struct sk_msg_md *msg) { - int err1 = 0, err2 = 0, zero = 0, one = 1, key = 0; - int *f, *bytes, *start, *end, len1, len2; + int zero = 0, one = 1, two = 2, three = 3, len1, len2 = 0, len3; + int *bytes, *start, *end, *start_push, *end_push, *f; + int err1 = 0, err2 = 0, key = 0; __u64 flags = 0; int err; @@ -274,10 +305,10 @@ int bpf_prog7(struct sk_msg_md *msg) if (bytes) err2 = bpf_msg_cork_bytes(msg, *bytes); len1 = (__u64)msg->data_end - (__u64)msg->data; - start = bpf_map_lookup_elem(&sock_pull_bytes, &zero); - end = bpf_map_lookup_elem(&sock_pull_bytes, &one); - if (start && end) { + start = bpf_map_lookup_elem(&sock_bytes, &zero); + end = bpf_map_lookup_elem(&sock_bytes, &one); + if (start && end) { bpf_printk("sk_msg2: pull(%i:%i)\n", start ? *start : 0, end ? *end : 0); err = bpf_msg_pull_data(msg, *start, *end, 0); @@ -288,6 +319,22 @@ int bpf_prog7(struct sk_msg_md *msg) bpf_printk("sk_msg2: length update %i->%i\n", len1, len2); } + + start_push = bpf_map_lookup_elem(&sock_bytes, &two); + end_push = bpf_map_lookup_elem(&sock_bytes, &three); + if (start_push && end_push) { + bpf_printk("sk_msg4: push(%i:%i)\n", + start_push ? *start_push : 0, + end_push ? *end_push : 0); + err = bpf_msg_push_data(msg, *start_push, *end_push, 0); + if (err) + bpf_printk("sk_msg4: push_data err %i\n", + err); + len3 = (__u64)msg->data_end - (__u64)msg->data; + bpf_printk("sk_msg4: length push_update %i->%i\n", + len2 ? len2 : len1, len3); + } + f = bpf_map_lookup_elem(&sock_redir_flags, &zero); if (f && *f) { key = 2; @@ -342,8 +389,8 @@ int bpf_prog9(struct sk_msg_md *msg) SEC("sk_msg7") int bpf_prog10(struct sk_msg_md *msg) { - int *bytes, zero = 0, one = 1; - int *start, *end; + int *bytes, *start, *end, *start_push, *end_push; + int zero = 0, one = 1, two = 2, three = 3; bytes = bpf_map_lookup_elem(&sock_apply_bytes, &zero); if (bytes) @@ -351,10 +398,14 @@ int bpf_prog10(struct sk_msg_md *msg) bytes = bpf_map_lookup_elem(&sock_cork_bytes, &zero); if (bytes) bpf_msg_cork_bytes(msg, *bytes); - start = bpf_map_lookup_elem(&sock_pull_bytes, &zero); - end = bpf_map_lookup_elem(&sock_pull_bytes, &one); + start = bpf_map_lookup_elem(&sock_bytes, &zero); + end = bpf_map_lookup_elem(&sock_bytes, &one); if (start && end) bpf_msg_pull_data(msg, *start, *end, 0); + start_push = bpf_map_lookup_elem(&sock_bytes, &two); + end_push = bpf_map_lookup_elem(&sock_bytes, &three); + if (start_push && end_push) + bpf_msg_push_data(msg, *start_push, *end_push, 0); return SK_DROP; } |