summaryrefslogtreecommitdiff
path: root/include/linux/skmsg.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/linux/skmsg.h')
-rw-r--r--include/linux/skmsg.h364
1 files changed, 248 insertions, 116 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 178a3933a71b..49847888c287 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -14,6 +14,7 @@
#include <net/strparser.h>
#define MAX_MSG_FRAGS MAX_SKB_FRAGS
+#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
enum __sk_action {
__SK_DROP = 0,
@@ -28,12 +29,14 @@ struct sk_msg_sg {
u32 end;
u32 size;
u32 copybreak;
- bool copy[MAX_MSG_FRAGS];
- /* The extra element is used for chaining the front and sections when
- * the list becomes partitioned (e.g. end < start). The crypto APIs
- * require the chaining.
+ DECLARE_BITMAP(copy, MAX_MSG_FRAGS + 2);
+ /* The extra two elements:
+ * 1) used for chaining the front and sections when the list becomes
+ * partitioned (e.g. end < start). The crypto APIs require the
+ * chaining;
+ * 2) to chain tailer SG entries after the message.
*/
- struct scatterlist data[MAX_MSG_FRAGS + 1];
+ struct scatterlist data[MAX_MSG_FRAGS + 2];
};
/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
@@ -52,12 +55,18 @@ struct sk_msg {
struct sk_psock_progs {
struct bpf_prog *msg_parser;
- struct bpf_prog *skb_parser;
+ struct bpf_prog *stream_parser;
+ struct bpf_prog *stream_verdict;
struct bpf_prog *skb_verdict;
+ struct bpf_link *msg_parser_link;
+ struct bpf_link *stream_parser_link;
+ struct bpf_link *stream_verdict_link;
+ struct bpf_link *skb_verdict_link;
};
enum sk_psock_state_bits {
SK_PSOCK_TX_ENABLED,
+ SK_PSOCK_RX_STRP_ENABLED,
};
struct sk_psock_link {
@@ -66,14 +75,7 @@ struct sk_psock_link {
void *link_raw;
};
-struct sk_psock_parser {
- struct strparser strp;
- bool enabled;
- void (*saved_data_ready)(struct sock *sk);
-};
-
struct sk_psock_work_state {
- struct sk_buff *skb;
u32 len;
u32 off;
};
@@ -84,25 +86,39 @@ struct sk_psock {
u32 apply_bytes;
u32 cork_bytes;
u32 eval;
+ bool redir_ingress; /* undefined if sk_redir is null */
struct sk_msg *cork;
struct sk_psock_progs progs;
- struct sk_psock_parser parser;
+#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
+ struct strparser strp;
+ u32 copied_seq;
+ u32 ingress_bytes;
+#endif
struct sk_buff_head ingress_skb;
struct list_head ingress_msg;
+ spinlock_t ingress_lock;
unsigned long state;
struct list_head link;
spinlock_t link_lock;
refcount_t refcnt;
void (*saved_unhash)(struct sock *sk);
+ void (*saved_destroy)(struct sock *sk);
void (*saved_close)(struct sock *sk, long timeout);
void (*saved_write_space)(struct sock *sk);
+ void (*saved_data_ready)(struct sock *sk);
+ /* psock_update_sk_prot may be called with restore=false many times
+ * so the handler must be safe for this case. It will be called
+ * exactly once with restore=true when the psock is being destroyed
+ * and psock refcnt is zero, but before an RCU grace period.
+ */
+ int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
+ bool restore);
struct proto *sk_proto;
+ struct mutex work_mutex;
struct sk_psock_work_state work_state;
- struct work_struct work;
- union {
- struct rcu_head rcu;
- struct work_struct gc;
- };
+ struct delayed_work work;
+ struct sock *sk_pair;
+ struct rcu_work rwork;
};
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
@@ -123,6 +139,9 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes);
+int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
+ int len, int flags);
+bool sk_msg_is_readable(struct sock *sk);
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
@@ -139,10 +158,15 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
}
}
+static inline u32 sk_msg_iter_dist(u32 start, u32 end)
+{
+ return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
+}
+
#define sk_msg_iter_var_prev(var) \
do { \
if (var == 0) \
- var = MAX_MSG_FRAGS - 1; \
+ var = NR_MSG_FRAG_IDS - 1; \
else \
var--; \
} while (0)
@@ -150,7 +174,7 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
#define sk_msg_iter_var_next(var) \
do { \
var++; \
- if (var == MAX_MSG_FRAGS) \
+ if (var == NR_MSG_FRAG_IDS) \
var = 0; \
} while (0)
@@ -160,16 +184,11 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
#define sk_msg_iter_next(msg, which) \
sk_msg_iter_var_next(msg->sg.which)
-static inline void sk_msg_clear_meta(struct sk_msg *msg)
-{
- memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
-}
-
static inline void sk_msg_init(struct sk_msg *msg)
{
- BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
+ BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
memset(msg, 0, sizeof(*msg));
- sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
+ sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
@@ -178,6 +197,7 @@ static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
dst->sg.data[which] = src->sg.data[which];
dst->sg.data[which].length = size;
dst->sg.size += size;
+ src->sg.size -= size;
src->sg.data[which].length -= size;
src->sg.data[which].offset += size;
}
@@ -190,17 +210,12 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
static inline bool sk_msg_full(const struct sk_msg *msg)
{
- return (msg->sg.end == msg->sg.start) && msg->sg.size;
+ return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}
static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
- if (sk_msg_full(msg))
- return MAX_MSG_FRAGS;
-
- return msg->sg.end >= msg->sg.start ?
- msg->sg.end - msg->sg.start :
- msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
+ return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}
static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
@@ -227,7 +242,7 @@ static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
{
struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
- if (msg->sg.copy[msg->sg.start]) {
+ if (test_bit(msg->sg.start, msg->sg.copy)) {
msg->data = NULL;
msg->data_end = NULL;
} else {
@@ -246,7 +261,7 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
sg_set_page(sge, page, len, offset);
sg_unmark_end(sge);
- msg->sg.copy[msg->sg.end] = true;
+ __set_bit(msg->sg.end, msg->sg.copy);
msg->sg.size += len;
sk_msg_iter_next(msg, end);
}
@@ -254,7 +269,10 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
{
do {
- msg->sg.copy[i] = copy_state;
+ if (copy_state)
+ __set_bit(i, msg->sg.copy);
+ else
+ __clear_bit(i, msg->sg.copy);
sk_msg_iter_var_next(i);
if (i == msg->sg.end)
break;
@@ -273,13 +291,86 @@ static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
- return rcu_dereference_sk_user_data(sk);
+ return __rcu_dereference_sk_user_data_with_flags(sk,
+ SK_USER_DATA_PSOCK);
+}
+
+static inline void sk_psock_set_state(struct sk_psock *psock,
+ enum sk_psock_state_bits bit)
+{
+ set_bit(bit, &psock->state);
+}
+
+static inline void sk_psock_clear_state(struct sk_psock *psock,
+ enum sk_psock_state_bits bit)
+{
+ clear_bit(bit, &psock->state);
+}
+
+static inline bool sk_psock_test_state(const struct sk_psock *psock,
+ enum sk_psock_state_bits bit)
+{
+ return test_bit(bit, &psock->state);
+}
+
+static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
+{
+ sk_drops_skbadd(sk, skb);
+ kfree_skb(skb);
}
-static inline void sk_psock_queue_msg(struct sk_psock *psock,
+static inline bool sk_psock_queue_msg(struct sk_psock *psock,
struct sk_msg *msg)
{
- list_add_tail(&msg->list, &psock->ingress_msg);
+ bool ret;
+
+ spin_lock_bh(&psock->ingress_lock);
+ if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
+ list_add_tail(&msg->list, &psock->ingress_msg);
+ ret = true;
+ } else {
+ sk_msg_free(psock->sk, msg);
+ kfree(msg);
+ ret = false;
+ }
+ spin_unlock_bh(&psock->ingress_lock);
+ return ret;
+}
+
+static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
+{
+ struct sk_msg *msg;
+
+ spin_lock_bh(&psock->ingress_lock);
+ msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
+ if (msg)
+ list_del(&msg->list);
+ spin_unlock_bh(&psock->ingress_lock);
+ return msg;
+}
+
+static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
+{
+ struct sk_msg *msg;
+
+ spin_lock_bh(&psock->ingress_lock);
+ msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
+ spin_unlock_bh(&psock->ingress_lock);
+ return msg;
+}
+
+static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
+ struct sk_msg *msg)
+{
+ struct sk_msg *ret;
+
+ spin_lock_bh(&psock->ingress_lock);
+ if (list_is_last(&msg->list, &psock->ingress_msg))
+ ret = NULL;
+ else
+ ret = list_next_entry(msg, list);
+ spin_unlock_bh(&psock->ingress_lock);
+ return ret;
}
static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
@@ -287,28 +378,57 @@ static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
return psock ? list_empty(&psock->ingress_msg) : true;
}
+static inline void kfree_sk_msg(struct sk_msg *msg)
+{
+ if (msg->skb)
+ consume_skb(msg->skb);
+ kfree(msg);
+}
+
static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
struct sock *sk = psock->sk;
sk->sk_err = err;
- sk->sk_error_report(sk);
+ sk_error_report(sk);
}
struct sk_psock *sk_psock_init(struct sock *sk, int node);
+void sk_psock_stop(struct sk_psock *psock);
+#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
+#else
+static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
+{
+ return -EOPNOTSUPP;
+}
-int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
- struct sk_msg *msg);
+static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
+{
+}
-static inline struct sk_psock_link *sk_psock_init_link(void)
+static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
- return kzalloc(sizeof(struct sk_psock_link),
- GFP_ATOMIC | __GFP_NOWARN);
}
+#endif
+
+void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
+void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);
+
+int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
+ struct sk_msg *msg);
+
+/*
+ * This specialized allocator has to be a macro for its allocations to be
+ * accounted separately (to have a separate alloc_tag). The typecast is
+ * intentional to enforce typesafety.
+ */
+#define sk_psock_init_link() \
+ ((struct sk_psock_link *)kzalloc(sizeof(struct sk_psock_link), \
+ GFP_ATOMIC | __GFP_NOWARN))
static inline void sk_psock_free_link(struct sk_psock_link *link)
{
@@ -316,16 +436,6 @@ static inline void sk_psock_free_link(struct sk_psock_link *link)
}
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
-#if defined(CONFIG_BPF_STREAM_PARSER)
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
-#else
-static inline void sk_psock_unlink(struct sock *sk,
- struct sk_psock_link *link)
-{
-}
-#endif
-
-void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
static inline void sk_psock_cork_free(struct sk_psock *psock)
{
@@ -336,63 +446,11 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
}
}
-static inline void sk_psock_update_proto(struct sock *sk,
- struct sk_psock *psock,
- struct proto *ops)
-{
- psock->saved_unhash = sk->sk_prot->unhash;
- psock->saved_close = sk->sk_prot->close;
- psock->saved_write_space = sk->sk_write_space;
-
- psock->sk_proto = sk->sk_prot;
- sk->sk_prot = ops;
-}
-
static inline void sk_psock_restore_proto(struct sock *sk,
struct sk_psock *psock)
{
- if (psock->sk_proto) {
- sk->sk_prot = psock->sk_proto;
- psock->sk_proto = NULL;
- }
-}
-
-static inline void sk_psock_set_state(struct sk_psock *psock,
- enum sk_psock_state_bits bit)
-{
- set_bit(bit, &psock->state);
-}
-
-static inline void sk_psock_clear_state(struct sk_psock *psock,
- enum sk_psock_state_bits bit)
-{
- clear_bit(bit, &psock->state);
-}
-
-static inline bool sk_psock_test_state(const struct sk_psock *psock,
- enum sk_psock_state_bits bit)
-{
- return test_bit(bit, &psock->state);
-}
-
-static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
-{
- struct sk_psock *psock;
-
- rcu_read_lock();
- psock = sk_psock(sk);
- if (psock) {
- if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
- psock = ERR_PTR(-EBUSY);
- goto out;
- }
-
- if (!refcount_inc_not_zero(&psock->refcnt))
- psock = ERR_PTR(-EBUSY);
- }
-out:
- rcu_read_unlock();
- return psock;
+ if (psock->psock_update_sk_prot)
+ psock->psock_update_sk_prot(sk, psock, true);
}
static inline struct sk_psock *sk_psock_get(struct sock *sk)
@@ -407,8 +465,6 @@ static inline struct sk_psock *sk_psock_get(struct sock *sk)
return psock;
}
-void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
-void sk_psock_destroy(struct rcu_head *rcu);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
@@ -419,10 +475,12 @@ static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
{
- if (psock->parser.enabled)
- psock->parser.saved_data_ready(sk);
+ read_lock_bh(&sk->sk_callback_lock);
+ if (psock->saved_data_ready)
+ psock->saved_data_ready(sk);
else
sk->sk_data_ready(sk);
+ read_unlock_bh(&sk->sk_callback_lock);
}
static inline void psock_set_prog(struct bpf_prog **pprog,
@@ -433,11 +491,85 @@ static inline void psock_set_prog(struct bpf_prog **pprog,
bpf_prog_put(prog);
}
+static inline int psock_replace_prog(struct bpf_prog **pprog,
+ struct bpf_prog *prog,
+ struct bpf_prog *old)
+{
+ if (cmpxchg(pprog, old, prog) != old)
+ return -ENOENT;
+
+ if (old)
+ bpf_prog_put(old);
+
+ return 0;
+}
+
static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
psock_set_prog(&progs->msg_parser, NULL);
- psock_set_prog(&progs->skb_parser, NULL);
+ psock_set_prog(&progs->stream_parser, NULL);
+ psock_set_prog(&progs->stream_verdict, NULL);
psock_set_prog(&progs->skb_verdict, NULL);
}
+int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
+
+static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
+{
+ if (!psock)
+ return false;
+ return !!psock->saved_data_ready;
+}
+
+#if IS_ENABLED(CONFIG_NET_SOCK_MSG)
+
+#define BPF_F_STRPARSER (1UL << 1)
+
+/* We only have two bits so far. */
+#define BPF_F_PTR_MASK ~(BPF_F_INGRESS | BPF_F_STRPARSER)
+
+static inline bool skb_bpf_strparser(const struct sk_buff *skb)
+{
+ unsigned long sk_redir = skb->_sk_redir;
+
+ return sk_redir & BPF_F_STRPARSER;
+}
+
+static inline void skb_bpf_set_strparser(struct sk_buff *skb)
+{
+ skb->_sk_redir |= BPF_F_STRPARSER;
+}
+
+static inline bool skb_bpf_ingress(const struct sk_buff *skb)
+{
+ unsigned long sk_redir = skb->_sk_redir;
+
+ return sk_redir & BPF_F_INGRESS;
+}
+
+static inline void skb_bpf_set_ingress(struct sk_buff *skb)
+{
+ skb->_sk_redir |= BPF_F_INGRESS;
+}
+
+static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
+ bool ingress)
+{
+ skb->_sk_redir = (unsigned long)sk_redir;
+ if (ingress)
+ skb->_sk_redir |= BPF_F_INGRESS;
+}
+
+static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
+{
+ unsigned long sk_redir = skb->_sk_redir;
+
+ return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
+}
+
+static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
+{
+ skb->_sk_redir = 0;
+}
+#endif /* CONFIG_NET_SOCK_MSG */
#endif /* _LINUX_SKMSG_H */