diff options
Diffstat (limited to 'net/ipv4/raw_diag.c')
| -rw-r--r-- | net/ipv4/raw_diag.c | 99 |
1 files changed, 46 insertions, 53 deletions
diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c index c200065ef9a5..943e5998e0ad 100644 --- a/net/ipv4/raw_diag.c +++ b/net/ipv4/raw_diag.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only #include <linux/module.h> #include <linux/inet_diag.h> @@ -23,9 +24,6 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r) return &raw_v6_hashinfo; #endif } else { - pr_warn_once("Unexpected inet family %d\n", - r->sdiag_family); - WARN_ON_ONCE(1); return ERR_PTR(-EINVAL); } } @@ -36,84 +34,82 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r) * use helper to figure it out. */ -static struct sock *raw_lookup(struct net *net, struct sock *from, - const struct inet_diag_req_v2 *req) +static bool raw_lookup(struct net *net, const struct sock *sk, + const struct inet_diag_req_v2 *req) { struct inet_diag_req_raw *r = (void *)req; - struct sock *sk = NULL; if (r->sdiag_family == AF_INET) - sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol, - r->id.idiag_dst[0], - r->id.idiag_src[0], - r->id.idiag_if, 0); + return raw_v4_match(net, sk, r->sdiag_raw_protocol, + r->id.idiag_dst[0], + r->id.idiag_src[0], + r->id.idiag_if, 0); #if IS_ENABLED(CONFIG_IPV6) else - sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol, - (const struct in6_addr *)r->id.idiag_src, - (const struct in6_addr *)r->id.idiag_dst, - r->id.idiag_if, 0); + return raw_v6_match(net, sk, r->sdiag_raw_protocol, + (const struct in6_addr *)r->id.idiag_src, + (const struct in6_addr *)r->id.idiag_dst, + r->id.idiag_if, 0); #endif - return sk; + return false; } static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r) { struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); - struct sock *sk = NULL, *s; + struct hlist_head *hlist; + struct sock *sk; int slot; if (IS_ERR(hashinfo)) return ERR_CAST(hashinfo); - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) { - sk_for_each(s, &hashinfo->ht[slot]) { - sk = raw_lookup(net, s, r); - if (sk) { + hlist = &hashinfo->ht[slot]; + sk_for_each_rcu(sk, hlist) { + if (raw_lookup(net, sk, r)) { /* * Grab it and keep until we fill - * diag meaage to be reported, so + * diag message to be reported, so * caller should call sock_put then. - * We can do that because we're keeping - * hashinfo->lock here. */ - sock_hold(sk); - goto out_unlock; + if (refcount_inc_not_zero(&sk->sk_refcnt)) + goto out_unlock; } } } + sk = ERR_PTR(-ENOENT); out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); - return sk ? sk : ERR_PTR(-ENOENT); + return sk; } -static int raw_diag_dump_one(struct sk_buff *in_skb, - const struct nlmsghdr *nlh, +static int raw_diag_dump_one(struct netlink_callback *cb, const struct inet_diag_req_v2 *r) { - struct net *net = sock_net(in_skb->sk); + struct sk_buff *in_skb = cb->skb; struct sk_buff *rep; struct sock *sk; + struct net *net; int err; + net = sock_net(in_skb->sk); sk = raw_sock_get(net, r); if (IS_ERR(sk)) return PTR_ERR(sk); - rep = nlmsg_new(sizeof(struct inet_diag_msg) + - sizeof(struct inet_diag_meminfo) + 64, + rep = nlmsg_new(nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct inet_diag_meminfo)) + 64, GFP_KERNEL); if (!rep) { sock_put(sk); return -ENOMEM; } - err = inet_sk_diag_fill(sk, NULL, rep, r, - sk_user_ns(NETLINK_CB(in_skb).sk), - NETLINK_CB(in_skb).portid, - nlh->nlmsg_seq, 0, nlh, + err = inet_sk_diag_fill(sk, NULL, rep, cb, r, 0, netlink_net_capable(in_skb, CAP_NET_ADMIN)); sock_put(sk); @@ -122,36 +118,30 @@ static int raw_diag_dump_one(struct sk_buff *in_skb, return err; } - err = netlink_unicast(net->diag_nlsk, rep, - NETLINK_CB(in_skb).portid, - MSG_DONTWAIT); - if (err > 0) - err = 0; + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + return err; } static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r, - struct nlattr *bc, bool net_admin) + bool net_admin) { - if (!inet_diag_bc_sk(bc, sk)) + if (!inet_diag_bc_sk(cb->data, sk)) return 0; - return inet_sk_diag_fill(sk, NULL, skb, r, - sk_user_ns(NETLINK_CB(cb->skb).sk), - NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, NLM_F_MULTI, - cb->nlh, net_admin); + return inet_sk_diag_fill(sk, NULL, skb, cb, r, NLM_F_MULTI, net_admin); } static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, - const struct inet_diag_req_v2 *r, struct nlattr *bc) + const struct inet_diag_req_v2 *r) { bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct net *net = sock_net(skb->sk); int num, s_num, slot, s_slot; + struct hlist_head *hlist; struct sock *sk = NULL; if (IS_ERR(hashinfo)) @@ -160,11 +150,12 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, s_slot = cb->args[0]; num = s_num = cb->args[1]; - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) { num = 0; - sk_for_each(sk, &hashinfo->ht[slot]) { + hlist = &hashinfo->ht[slot]; + sk_for_each_rcu(sk, hlist) { struct inet_sock *inet = inet_sk(sk); if (!net_eq(sock_net(sk), net)) @@ -179,7 +170,7 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, if (r->id.idiag_dport != inet->inet_dport && r->id.idiag_dport) goto next; - if (sk_diag_dump(sk, skb, cb, r, bc, net_admin) < 0) + if (sk_diag_dump(sk, skb, cb, r, net_admin) < 0) goto out_unlock; next: num++; @@ -187,7 +178,7 @@ next: } out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); cb->args[0] = slot; cb->args[1] = num; @@ -218,6 +209,7 @@ static int raw_diag_destroy(struct sk_buff *in_skb, #endif static const struct inet_diag_handler raw_diag_handler = { + .owner = THIS_MODULE, .dump = raw_diag_dump, .dump_one = raw_diag_dump_one, .idiag_get_info = raw_diag_get_info, @@ -262,5 +254,6 @@ static void __exit raw_diag_exit(void) module_init(raw_diag_init); module_exit(raw_diag_exit); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("RAW socket monitoring via SOCK_DIAG"); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-255 /* AF_INET - IPPROTO_RAW */); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10-255 /* AF_INET6 - IPPROTO_RAW */); |
