summaryrefslogtreecommitdiff
path: root/net/core/sock_map.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/core/sock_map.c')
-rw-r--r--net/core/sock_map.c306
1 files changed, 266 insertions, 40 deletions
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index b70c844a88ec..b08dfae10f88 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -10,6 +10,8 @@
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
+#include <linux/sock_diag.h>
+#include <net/udp.h>
struct bpf_stab {
struct bpf_map map;
@@ -31,7 +33,8 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
return ERR_PTR(-EPERM);
if (attr->max_entries == 0 ||
attr->key_size != 4 ||
- attr->value_size != 4 ||
+ (attr->value_size != sizeof(u32) &&
+ attr->value_size != sizeof(u64)) ||
attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
return ERR_PTR(-EINVAL);
@@ -139,12 +142,58 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
}
}
+static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
+{
+ struct proto *prot;
+
+ sock_owned_by_me(sk);
+
+ switch (sk->sk_type) {
+ case SOCK_STREAM:
+ prot = tcp_bpf_get_proto(sk, psock);
+ break;
+
+ case SOCK_DGRAM:
+ prot = udp_bpf_get_proto(sk, psock);
+ break;
+
+ default:
+ return -EINVAL;
+ }
+
+ if (IS_ERR(prot))
+ return PTR_ERR(prot);
+
+ sk_psock_update_proto(sk, psock, prot);
+ return 0;
+}
+
+static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
+{
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (psock) {
+ if (sk->sk_prot->close != sock_map_close) {
+ psock = ERR_PTR(-EBUSY);
+ goto out;
+ }
+
+ if (!refcount_inc_not_zero(&psock->refcnt))
+ psock = ERR_PTR(-EBUSY);
+ }
+out:
+ rcu_read_unlock();
+ return psock;
+}
+
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
struct sock *sk)
{
struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
- bool skb_progs, sk_psock_is_new = false;
struct sk_psock *psock;
+ bool skb_progs;
int ret;
skb_verdict = READ_ONCE(progs->skb_verdict);
@@ -170,7 +219,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
}
}
- psock = sk_psock_get_checked(sk);
+ psock = sock_map_psock_get_checked(sk);
if (IS_ERR(psock)) {
ret = PTR_ERR(psock);
goto out_progs;
@@ -189,18 +238,14 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
ret = -ENOMEM;
goto out_progs;
}
- sk_psock_is_new = true;
}
if (msg_parser)
psock_set_prog(&psock->progs.msg_parser, msg_parser);
- if (sk_psock_is_new) {
- ret = tcp_bpf_init(sk);
- if (ret < 0)
- goto out_drop;
- } else {
- tcp_bpf_reinit(sk);
- }
+
+ ret = sock_map_init_proto(sk, psock);
+ if (ret < 0)
+ goto out_drop;
write_lock_bh(&sk->sk_callback_lock);
if (skb_progs && !psock->parser.enabled) {
@@ -228,6 +273,27 @@ out:
return ret;
}
+static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
+{
+ struct sk_psock *psock;
+ int ret;
+
+ psock = sock_map_psock_get_checked(sk);
+ if (IS_ERR(psock))
+ return PTR_ERR(psock);
+
+ if (!psock) {
+ psock = sk_psock_init(sk, map->numa_node);
+ if (!psock)
+ return -ENOMEM;
+ }
+
+ ret = sock_map_init_proto(sk, psock);
+ if (ret < 0)
+ sk_psock_put(sk, psock);
+ return ret;
+}
+
static void sock_map_free(struct bpf_map *map)
{
struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
@@ -277,7 +343,22 @@ static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
static void *sock_map_lookup(struct bpf_map *map, void *key)
{
- return ERR_PTR(-EOPNOTSUPP);
+ return __sock_map_lookup_elem(map, *(u32 *)key);
+}
+
+static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
+{
+ struct sock *sk;
+
+ if (map->value_size != sizeof(u64))
+ return ERR_PTR(-ENOSPC);
+
+ sk = __sock_map_lookup_elem(map, *(u32 *)key);
+ if (!sk)
+ return ERR_PTR(-ENOENT);
+
+ sock_gen_cookie(sk);
+ return &sk->sk_cookie;
}
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
@@ -336,11 +417,15 @@ static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
return 0;
}
+static bool sock_map_redirect_allowed(const struct sock *sk)
+{
+ return sk->sk_state != TCP_LISTEN;
+}
+
static int sock_map_update_common(struct bpf_map *map, u32 idx,
struct sock *sk, u64 flags)
{
struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
- struct inet_connection_sock *icsk = inet_csk(sk);
struct sk_psock_link *link;
struct sk_psock *psock;
struct sock *osk;
@@ -351,14 +436,21 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
return -EINVAL;
if (unlikely(idx >= map->max_entries))
return -E2BIG;
- if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
+ if (inet_csk_has_ulp(sk))
return -EINVAL;
link = sk_psock_init_link();
if (!link)
return -ENOMEM;
- ret = sock_map_link(map, &stab->progs, sk);
+ /* Only sockets we can redirect into/from in BPF need to hold
+ * refs to parser/verdict progs and have their sk_data_ready
+ * and sk_write_space callbacks overridden.
+ */
+ if (sock_map_redirect_allowed(sk))
+ ret = sock_map_link(map, &stab->progs, sk);
+ else
+ ret = sock_map_link_no_progs(map, sk);
if (ret < 0)
goto out_free;
@@ -393,23 +485,52 @@ out_free:
static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
- ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
+ ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
+ ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}
-static bool sock_map_sk_is_suitable(const struct sock *sk)
+static bool sk_is_tcp(const struct sock *sk)
{
return sk->sk_type == SOCK_STREAM &&
sk->sk_protocol == IPPROTO_TCP;
}
+static bool sk_is_udp(const struct sock *sk)
+{
+ return sk->sk_type == SOCK_DGRAM &&
+ sk->sk_protocol == IPPROTO_UDP;
+}
+
+static bool sock_map_sk_is_suitable(const struct sock *sk)
+{
+ return sk_is_tcp(sk) || sk_is_udp(sk);
+}
+
+static bool sock_map_sk_state_allowed(const struct sock *sk)
+{
+ if (sk_is_tcp(sk))
+ return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
+ else if (sk_is_udp(sk))
+ return sk_hashed(sk);
+
+ return false;
+}
+
static int sock_map_update_elem(struct bpf_map *map, void *key,
void *value, u64 flags)
{
- u32 ufd = *(u32 *)value;
u32 idx = *(u32 *)key;
struct socket *sock;
struct sock *sk;
int ret;
+ u64 ufd;
+
+ if (map->value_size == sizeof(u64))
+ ufd = *(u64 *)value;
+ else
+ ufd = *(u32 *)value;
+ if (ufd > S32_MAX)
+ return -EINVAL;
sock = sockfd_lookup(ufd, &ret);
if (!sock)
@@ -425,7 +546,7 @@ static int sock_map_update_elem(struct bpf_map *map, void *key,
}
sock_map_sk_acquire(sk);
- if (sk->sk_state != TCP_ESTABLISHED)
+ if (!sock_map_sk_state_allowed(sk))
ret = -EOPNOTSUPP;
else
ret = sock_map_update_common(map, idx, sk, flags);
@@ -462,13 +583,17 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
struct bpf_map *, map, u32, key, u64, flags)
{
struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+ struct sock *sk;
if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP;
- tcb->bpf.flags = flags;
- tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
- if (!tcb->bpf.sk_redir)
+
+ sk = __sock_map_lookup_elem(map, key);
+ if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
return SK_DROP;
+
+ tcb->bpf.flags = flags;
+ tcb->bpf.sk_redir = sk;
return SK_PASS;
}
@@ -485,12 +610,17 @@ const struct bpf_func_proto bpf_sk_redirect_map_proto = {
BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
struct bpf_map *, map, u32, key, u64, flags)
{
+ struct sock *sk;
+
if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP;
- msg->flags = flags;
- msg->sk_redir = __sock_map_lookup_elem(map, key);
- if (!msg->sk_redir)
+
+ sk = __sock_map_lookup_elem(map, key);
+ if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
return SK_DROP;
+
+ msg->flags = flags;
+ msg->sk_redir = sk;
return SK_PASS;
}
@@ -508,6 +638,7 @@ const struct bpf_map_ops sock_map_ops = {
.map_alloc = sock_map_alloc,
.map_free = sock_map_free,
.map_get_next_key = sock_map_get_next_key,
+ .map_lookup_elem_sys_only = sock_map_lookup_sys,
.map_update_elem = sock_map_update_elem,
.map_delete_elem = sock_map_delete_elem,
.map_lookup_elem = sock_map_lookup,
@@ -520,7 +651,7 @@ struct bpf_htab_elem {
u32 hash;
struct sock *sk;
struct hlist_node node;
- u8 key[0];
+ u8 key[];
};
struct bpf_htab_bucket {
@@ -664,7 +795,6 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
struct sock *sk, u64 flags)
{
struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
- struct inet_connection_sock *icsk = inet_csk(sk);
u32 key_size = map->key_size, hash;
struct bpf_htab_elem *elem, *elem_new;
struct bpf_htab_bucket *bucket;
@@ -675,14 +805,21 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
WARN_ON_ONCE(!rcu_read_lock_held());
if (unlikely(flags > BPF_EXIST))
return -EINVAL;
- if (unlikely(icsk->icsk_ulp_data))
+ if (inet_csk_has_ulp(sk))
return -EINVAL;
link = sk_psock_init_link();
if (!link)
return -ENOMEM;
- ret = sock_map_link(map, &htab->progs, sk);
+ /* Only sockets we can redirect into/from in BPF need to hold
+ * refs to parser/verdict progs and have their sk_data_ready
+ * and sk_write_space callbacks overridden.
+ */
+ if (sock_map_redirect_allowed(sk))
+ ret = sock_map_link(map, &htab->progs, sk);
+ else
+ ret = sock_map_link_no_progs(map, sk);
if (ret < 0)
goto out_free;
@@ -731,10 +868,17 @@ out_free:
static int sock_hash_update_elem(struct bpf_map *map, void *key,
void *value, u64 flags)
{
- u32 ufd = *(u32 *)value;
struct socket *sock;
struct sock *sk;
int ret;
+ u64 ufd;
+
+ if (map->value_size == sizeof(u64))
+ ufd = *(u64 *)value;
+ else
+ ufd = *(u32 *)value;
+ if (ufd > S32_MAX)
+ return -EINVAL;
sock = sockfd_lookup(ufd, &ret);
if (!sock)
@@ -750,7 +894,7 @@ static int sock_hash_update_elem(struct bpf_map *map, void *key,
}
sock_map_sk_acquire(sk);
- if (sk->sk_state != TCP_ESTABLISHED)
+ if (!sock_map_sk_state_allowed(sk))
ret = -EOPNOTSUPP;
else
ret = sock_hash_update_common(map, key, sk, flags);
@@ -810,7 +954,8 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
return ERR_PTR(-EPERM);
if (attr->max_entries == 0 ||
attr->key_size == 0 ||
- attr->value_size != 4 ||
+ (attr->value_size != sizeof(u32) &&
+ attr->value_size != sizeof(u64)) ||
attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
return ERR_PTR(-EINVAL);
if (attr->key_size > MAX_BPF_STACK)
@@ -889,6 +1034,26 @@ static void sock_hash_free(struct bpf_map *map)
kfree(htab);
}
+static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
+{
+ struct sock *sk;
+
+ if (map->value_size != sizeof(u64))
+ return ERR_PTR(-ENOSPC);
+
+ sk = __sock_hash_lookup_elem(map, key);
+ if (!sk)
+ return ERR_PTR(-ENOENT);
+
+ sock_gen_cookie(sk);
+ return &sk->sk_cookie;
+}
+
+static void *sock_hash_lookup(struct bpf_map *map, void *key)
+{
+ return __sock_hash_lookup_elem(map, key);
+}
+
static void sock_hash_release_progs(struct bpf_map *map)
{
psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
@@ -920,13 +1085,17 @@ BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
struct bpf_map *, map, void *, key, u64, flags)
{
struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+ struct sock *sk;
if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP;
- tcb->bpf.flags = flags;
- tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
- if (!tcb->bpf.sk_redir)
+
+ sk = __sock_hash_lookup_elem(map, key);
+ if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
return SK_DROP;
+
+ tcb->bpf.flags = flags;
+ tcb->bpf.sk_redir = sk;
return SK_PASS;
}
@@ -943,12 +1112,17 @@ const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
struct bpf_map *, map, void *, key, u64, flags)
{
+ struct sock *sk;
+
if (unlikely(flags & ~(BPF_F_INGRESS)))
return SK_DROP;
- msg->flags = flags;
- msg->sk_redir = __sock_hash_lookup_elem(map, key);
- if (!msg->sk_redir)
+
+ sk = __sock_hash_lookup_elem(map, key);
+ if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
return SK_DROP;
+
+ msg->flags = flags;
+ msg->sk_redir = sk;
return SK_PASS;
}
@@ -968,7 +1142,8 @@ const struct bpf_map_ops sock_hash_ops = {
.map_get_next_key = sock_hash_get_next_key,
.map_update_elem = sock_hash_update_elem,
.map_delete_elem = sock_hash_delete_elem,
- .map_lookup_elem = sock_map_lookup,
+ .map_lookup_elem = sock_hash_lookup,
+ .map_lookup_elem_sys_only = sock_hash_lookup_sys,
.map_release_uref = sock_hash_release_progs,
.map_check_btf = map_check_no_btf,
};
@@ -1012,7 +1187,7 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
return 0;
}
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
+static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
switch (link->map->map_type) {
case BPF_MAP_TYPE_SOCKMAP:
@@ -1025,3 +1200,54 @@ void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
break;
}
}
+
+static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_link *link;
+
+ while ((link = sk_psock_link_pop(psock))) {
+ sock_map_unlink(sk, link);
+ sk_psock_free_link(link);
+ }
+}
+
+void sock_map_unhash(struct sock *sk)
+{
+ void (*saved_unhash)(struct sock *sk);
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ if (sk->sk_prot->unhash)
+ sk->sk_prot->unhash(sk);
+ return;
+ }
+
+ saved_unhash = psock->saved_unhash;
+ sock_map_remove_links(sk, psock);
+ rcu_read_unlock();
+ saved_unhash(sk);
+}
+
+void sock_map_close(struct sock *sk, long timeout)
+{
+ void (*saved_close)(struct sock *sk, long timeout);
+ struct sk_psock *psock;
+
+ lock_sock(sk);
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (unlikely(!psock)) {
+ rcu_read_unlock();
+ release_sock(sk);
+ return sk->sk_prot->close(sk, timeout);
+ }
+
+ saved_close = psock->saved_close;
+ sock_map_remove_links(sk, psock);
+ rcu_read_unlock();
+ release_sock(sk);
+ saved_close(sk, timeout);
+}