summaryrefslogtreecommitdiff
path: root/net/unix/unix_bpf.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/unix/unix_bpf.c')
-rw-r--r--net/unix/unix_bpf.c31
1 files changed, 29 insertions, 2 deletions
diff --git a/net/unix/unix_bpf.c b/net/unix/unix_bpf.c
index e9bf15513961..e0d30d6d22ac 100644
--- a/net/unix/unix_bpf.c
+++ b/net/unix/unix_bpf.c
@@ -1,11 +1,12 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */
-#include <linux/skmsg.h>
#include <linux/bpf.h>
-#include <net/sock.h>
+#include <linux/skmsg.h>
#include <net/af_unix.h>
+#include "af_unix.h"
+
#define unix_sk_has_data(__sk, __psock) \
({ !skb_queue_empty(&__sk->sk_receive_queue) || \
!skb_queue_empty(&__psock->ingress_skb) || \
@@ -54,6 +55,12 @@ static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
struct sk_psock *psock;
int copied;
+ if (flags & MSG_OOB)
+ return -EOPNOTSUPP;
+
+ if (!len)
+ return 0;
+
psock = sk_psock_get(sk);
if (unlikely(!psock))
return __unix_recvmsg(sk, msg, len, flags);
@@ -156,12 +163,32 @@ int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool re
int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
+ struct sock *sk_pair;
+
+ /* Restore does not decrement the sk_pair reference yet because we must
+ * keep the a reference to the socket until after an RCU grace period
+ * and any pending sends have completed.
+ */
if (restore) {
sk->sk_write_space = psock->saved_write_space;
sock_replace_proto(sk, psock->sk_proto);
return 0;
}
+ /* psock_update_sk_prot can be called multiple times if psock is
+ * added to multiple maps and/or slots in the same map. There is
+ * also an edge case where replacing a psock with itself can trigger
+ * an extra psock_update_sk_prot during the insert process. So it
+ * must be safe to do multiple calls. Here we need to ensure we don't
+ * increment the refcnt through sock_hold many times. There will only
+ * be a single matching destroy operation.
+ */
+ if (!psock->sk_pair) {
+ sk_pair = unix_peer(sk);
+ sock_hold(sk_pair);
+ psock->sk_pair = sk_pair;
+ }
+
unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
sock_replace_proto(sk, &unix_stream_bpf_prot);
return 0;