summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h12
-rw-r--r--net/core/filter.c6
-rw-r--r--net/core/sock_map.c6
3 files changed, 16 insertions, 8 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index b4256847c707..584d94be9c8b 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -507,6 +507,18 @@ static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
return !!psock->saved_data_ready;
}
+static inline bool sk_is_tcp(const struct sock *sk)
+{
+ return sk->sk_type == SOCK_STREAM &&
+ sk->sk_protocol == IPPROTO_TCP;
+}
+
+static inline bool sk_is_udp(const struct sock *sk)
+{
+ return sk->sk_type == SOCK_DGRAM &&
+ sk->sk_protocol == IPPROTO_UDP;
+}
+
#if IS_ENABLED(CONFIG_NET_SOCK_MSG)
#define BPF_F_STRPARSER (1UL << 1)
diff --git a/net/core/filter.c b/net/core/filter.c
index 8e8d3b49c297..a68418268e92 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -10423,8 +10423,10 @@ BPF_CALL_3(bpf_sk_lookup_assign, struct bpf_sk_lookup_kern *, ctx,
return -EINVAL;
if (unlikely(sk && sk_is_refcounted(sk)))
return -ESOCKTNOSUPPORT; /* reject non-RCU freed sockets */
- if (unlikely(sk && sk->sk_state == TCP_ESTABLISHED))
- return -ESOCKTNOSUPPORT; /* reject connected sockets */
+ if (unlikely(sk && sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN))
+ return -ESOCKTNOSUPPORT; /* only accept TCP socket in LISTEN */
+ if (unlikely(sk && sk_is_udp(sk) && sk->sk_state != TCP_CLOSE))
+ return -ESOCKTNOSUPPORT; /* only accept UDP socket in CLOSE */
/* Check if socket is suitable for packet L3/L4 protocol */
if (sk && sk->sk_protocol != ctx->protocol)
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index e252b8ec2b85..f39ef79ced67 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -511,12 +511,6 @@ static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}
-static bool sk_is_tcp(const struct sock *sk)
-{
- return sk->sk_type == SOCK_STREAM &&
- sk->sk_protocol == IPPROTO_TCP;
-}
-
static bool sock_map_redirect_allowed(const struct sock *sk)
{
if (sk_is_tcp(sk))