diff options
Diffstat (limited to 'net/ipv4/inet_hashtables.c')
| -rw-r--r-- | net/ipv4/inet_hashtables.c | 458 |
1 files changed, 280 insertions, 178 deletions
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 7876b7d703cb..f5826ec4bcaa 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -23,21 +23,21 @@ #if IS_ENABLED(CONFIG_IPV6) #include <net/inet6_hashtables.h> #endif -#include <net/secure_seq.h> +#include <net/hotdata.h> #include <net/ip.h> -#include <net/tcp.h> +#include <net/rps.h> +#include <net/secure_seq.h> #include <net/sock_reuseport.h> +#include <net/tcp.h> u32 inet_ehashfn(const struct net *net, const __be32 laddr, const __u16 lport, const __be32 faddr, const __be16 fport) { - static u32 inet_ehash_secret __read_mostly; - net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret)); - return __inet_ehashfn(laddr, lport, faddr, fport, - inet_ehash_secret + net_hash_mix(net)); + return lport + __inet_ehashfn(laddr, 0, faddr, fport, + inet_ehash_secret + net_hash_mix(net)); } EXPORT_SYMBOL_GPL(inet_ehashfn); @@ -58,6 +58,14 @@ static u32 sk_ehashfn(const struct sock *sk) sk->sk_daddr, sk->sk_dport); } +static bool sk_is_connect_bind(const struct sock *sk) +{ + if (sk->sk_state == TCP_TIME_WAIT) + return inet_twsk(sk)->tw_connect_bind; + else + return sk->sk_userlocks & SOCK_CONNECT_BIND; +} + /* * Allocate and initialize a new local port bind bucket. * The bindhash mutex for snum's hash chain must be held here. @@ -76,8 +84,8 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, tb->port = snum; tb->fastreuse = 0; tb->fastreuseport = 0; - INIT_HLIST_HEAD(&tb->owners); - hlist_add_head(&tb->node, &head->chain); + INIT_HLIST_HEAD(&tb->bhash2); + hlist_add_head_rcu(&tb->node, &head->chain); } return tb; } @@ -85,12 +93,24 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, /* * Caller must hold hashbucket lock for this tb with local BH disabled */ -void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb) +void inet_bind_bucket_destroy(struct inet_bind_bucket *tb) { - if (hlist_empty(&tb->owners)) { - __hlist_del(&tb->node); - kmem_cache_free(cachep, tb); + const struct inet_bind2_bucket *tb2; + + if (hlist_empty(&tb->bhash2)) { + hlist_del_rcu(&tb->node); + kfree_rcu(tb, rcu); + return; } + + if (tb->fastreuse == -1 && tb->fastreuseport == -1) + return; + hlist_for_each_entry(tb2, &tb->bhash2, bhash_node) { + if (tb2->fastreuse != -1 || tb2->fastreuseport != -1) + return; + } + tb->fastreuse = -1; + tb->fastreuseport = -1; } bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net, @@ -100,61 +120,79 @@ bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net tb->l3mdev == l3mdev; } -static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb, +static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb2, struct net *net, struct inet_bind_hashbucket *head, - unsigned short port, int l3mdev, + struct inet_bind_bucket *tb, const struct sock *sk) { - write_pnet(&tb->ib_net, net); - tb->l3mdev = l3mdev; - tb->port = port; + write_pnet(&tb2->ib_net, net); + tb2->l3mdev = tb->l3mdev; + tb2->port = tb->port; #if IS_ENABLED(CONFIG_IPV6) - tb->family = sk->sk_family; - if (sk->sk_family == AF_INET6) - tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr; - else + BUILD_BUG_ON(USHRT_MAX < (IPV6_ADDR_ANY | IPV6_ADDR_MAPPED)); + if (sk->sk_family == AF_INET6) { + tb2->addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); + tb2->v6_rcv_saddr = sk->sk_v6_rcv_saddr; + } else { + tb2->addr_type = IPV6_ADDR_MAPPED; + ipv6_addr_set_v4mapped(sk->sk_rcv_saddr, &tb2->v6_rcv_saddr); + } +#else + tb2->rcv_saddr = sk->sk_rcv_saddr; #endif - tb->rcv_saddr = sk->sk_rcv_saddr; - INIT_HLIST_HEAD(&tb->owners); - INIT_HLIST_HEAD(&tb->deathrow); - hlist_add_head(&tb->node, &head->chain); + tb2->fastreuse = 0; + tb2->fastreuseport = 0; + INIT_HLIST_HEAD(&tb2->owners); + hlist_add_head(&tb2->node, &head->chain); + hlist_add_head(&tb2->bhash_node, &tb->bhash2); } struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep, struct net *net, struct inet_bind_hashbucket *head, - unsigned short port, - int l3mdev, + struct inet_bind_bucket *tb, const struct sock *sk) { - struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC); + struct inet_bind2_bucket *tb2 = kmem_cache_alloc(cachep, GFP_ATOMIC); - if (tb) - inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk); + if (tb2) + inet_bind2_bucket_init(tb2, net, head, tb, sk); - return tb; + return tb2; } /* Caller must hold hashbucket lock for this tb with local BH disabled */ void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb) { - if (hlist_empty(&tb->owners) && hlist_empty(&tb->deathrow)) { + const struct sock *sk; + + if (hlist_empty(&tb->owners)) { __hlist_del(&tb->node); + __hlist_del(&tb->bhash_node); kmem_cache_free(cachep, tb); + return; + } + + if (tb->fastreuse == -1 && tb->fastreuseport == -1) + return; + sk_for_each_bound(sk, &tb->owners) { + if (!sk_is_connect_bind(sk)) + return; } + tb->fastreuse = -1; + tb->fastreuseport = -1; } static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2, const struct sock *sk) { #if IS_ENABLED(CONFIG_IPV6) - if (sk->sk_family != tb2->family) - return false; - if (sk->sk_family == AF_INET6) - return ipv6_addr_equal(&tb2->v6_rcv_saddr, - &sk->sk_v6_rcv_saddr); + return ipv6_addr_equal(&tb2->v6_rcv_saddr, &sk->sk_v6_rcv_saddr); + + if (tb2->addr_type != IPV6_ADDR_MAPPED) + return false; #endif return tb2->rcv_saddr == sk->sk_rcv_saddr; } @@ -163,10 +201,9 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, struct inet_bind2_bucket *tb2, unsigned short port) { inet_sk(sk)->inet_num = port; - sk_add_bind_node(sk, &tb->owners); inet_csk(sk)->icsk_bind_hash = tb; - sk_add_bind2_node(sk, &tb2->owners); inet_csk(sk)->icsk_bind2_hash = tb2; + sk_add_bind_node(sk, &tb2->owners); } /* @@ -174,7 +211,7 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, */ static void __inet_put_port(struct sock *sk) { - struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); struct inet_bind_hashbucket *head, *head2; struct net *net = sock_net(sk); struct inet_bind_bucket *tb; @@ -186,21 +223,21 @@ static void __inet_put_port(struct sock *sk) spin_lock(&head->lock); tb = inet_csk(sk)->icsk_bind_hash; - __sk_del_bind_node(sk); inet_csk(sk)->icsk_bind_hash = NULL; inet_sk(sk)->inet_num = 0; - inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + sk->sk_userlocks &= ~SOCK_CONNECT_BIND; spin_lock(&head2->lock); if (inet_csk(sk)->icsk_bind2_hash) { struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash; - __sk_del_bind2_node(sk); + __sk_del_bind_node(sk); inet_csk(sk)->icsk_bind2_hash = NULL; inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); } spin_unlock(&head2->lock); + inet_bind_bucket_destroy(tb); spin_unlock(&head->lock); } @@ -214,7 +251,7 @@ EXPORT_SYMBOL(inet_put_port); int __inet_inherit_port(const struct sock *sk, struct sock *child) { - struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *table = tcp_get_hashinfo(sk); unsigned short port = inet_sk(child)->inet_num; struct inet_bind_hashbucket *head, *head2; bool created_inet_bind_bucket = false; @@ -269,14 +306,13 @@ bhash2_find: tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child); if (!tb2) { tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep, - net, head2, port, - l3mdev, child); + net, head2, tb, child); if (!tb2) goto error; } } if (update_fastreuse) - inet_csk_update_fastreuse(tb, child); + inet_csk_update_fastreuse(child, tb, tb2); inet_bind_hash(child, tb, tb2, port); spin_unlock(&head2->lock); spin_unlock(&head->lock); @@ -285,7 +321,7 @@ bhash2_find: error: if (created_inet_bind_bucket) - inet_bind_bucket_destroy(table->bind_bucket_cachep, tb); + inet_bind_bucket_destroy(tb); spin_unlock(&head2->lock); spin_unlock(&head->lock); return -ENOMEM; @@ -310,7 +346,7 @@ inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk) return inet_lhash2_bucket(h, hash); } -static inline int compute_score(struct sock *sk, struct net *net, +static inline int compute_score(struct sock *sk, const struct net *net, const unsigned short hnum, const __be32 daddr, const int dif, const int sdif) { @@ -348,7 +384,7 @@ static inline int compute_score(struct sock *sk, struct net *net, * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to * the selected sock or an error. */ -struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk, +struct sock *inet_lookup_reuseport(const struct net *net, struct sock *sk, struct sk_buff *skb, int doff, __be32 saddr, __be16 sport, __be32 daddr, unsigned short hnum, @@ -374,7 +410,7 @@ EXPORT_SYMBOL_GPL(inet_lookup_reuseport); */ /* called with rcu_read_lock() : No refcount taken on the socket */ -static struct sock *inet_lhash2_lookup(struct net *net, +static struct sock *inet_lhash2_lookup(const struct net *net, struct inet_listen_hashbucket *ilb2, struct sk_buff *skb, int doff, const __be32 saddr, __be16 sport, @@ -401,7 +437,7 @@ static struct sock *inet_lhash2_lookup(struct net *net, return result; } -struct sock *inet_lookup_run_sk_lookup(struct net *net, +struct sock *inet_lookup_run_sk_lookup(const struct net *net, int protocol, struct sk_buff *skb, int doff, __be32 saddr, __be16 sport, @@ -423,20 +459,19 @@ struct sock *inet_lookup_run_sk_lookup(struct net *net, return sk; } -struct sock *__inet_lookup_listener(struct net *net, - struct inet_hashinfo *hashinfo, +struct sock *__inet_lookup_listener(const struct net *net, struct sk_buff *skb, int doff, const __be32 saddr, __be16 sport, const __be32 daddr, const unsigned short hnum, const int dif, const int sdif) { struct inet_listen_hashbucket *ilb2; + struct inet_hashinfo *hashinfo; struct sock *result = NULL; unsigned int hash2; /* Lookup redirect from BPF */ - if (static_branch_unlikely(&bpf_sk_lookup_enabled) && - hashinfo == net->ipv4.tcp_death_row.hashinfo) { + if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { result = inet_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff, saddr, sport, daddr, hnum, dif, inet_ehashfn); @@ -444,6 +479,7 @@ struct sock *__inet_lookup_listener(struct net *net, goto done; } + hashinfo = net->ipv4.tcp_death_row.hashinfo; hash2 = ipv4_portaddr_hash(net, daddr, hnum); ilb2 = inet_lhash2_bucket(hashinfo, hash2); @@ -488,22 +524,23 @@ void sock_edemux(struct sk_buff *skb) } EXPORT_SYMBOL(sock_edemux); -struct sock *__inet_lookup_established(struct net *net, - struct inet_hashinfo *hashinfo, - const __be32 saddr, const __be16 sport, - const __be32 daddr, const u16 hnum, - const int dif, const int sdif) +struct sock *__inet_lookup_established(const struct net *net, + const __be32 saddr, const __be16 sport, + const __be32 daddr, const u16 hnum, + const int dif, const int sdif) { - INET_ADDR_COOKIE(acookie, saddr, daddr); const __portpair ports = INET_COMBINED_PORTS(sport, hnum); - struct sock *sk; + INET_ADDR_COOKIE(acookie, saddr, daddr); const struct hlist_nulls_node *node; - /* Optimize here for direct hit, only listening connections can - * have wildcards anyways. - */ - unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport); - unsigned int slot = hash & hashinfo->ehash_mask; - struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; + struct inet_ehash_bucket *head; + struct inet_hashinfo *hashinfo; + unsigned int hash, slot; + struct sock *sk; + + hashinfo = net->ipv4.tcp_death_row.hashinfo; + hash = inet_ehashfn(net, daddr, hnum, saddr, sport); + slot = hash & hashinfo->ehash_mask; + head = &hashinfo->ehash[slot]; begin: sk_nulls_for_each_rcu(sk, node, &head->chain) { @@ -537,7 +574,9 @@ EXPORT_SYMBOL_GPL(__inet_lookup_established); /* called with local bh disabled */ static int __inet_check_established(struct inet_timewait_death_row *death_row, struct sock *sk, __u16 lport, - struct inet_timewait_sock **twp) + struct inet_timewait_sock **twp, + bool rcu_lookup, + u32 hash) { struct inet_hashinfo *hinfo = death_row->hashinfo; struct inet_sock *inet = inet_sk(sk); @@ -548,14 +587,25 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, int sdif = l3mdev_master_ifindex_by_index(net, dif); INET_ADDR_COOKIE(acookie, saddr, daddr); const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); - unsigned int hash = inet_ehashfn(net, daddr, lport, - saddr, inet->inet_dport); struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); - spinlock_t *lock = inet_ehash_lockp(hinfo, hash); - struct sock *sk2; - const struct hlist_nulls_node *node; struct inet_timewait_sock *tw = NULL; + const struct hlist_nulls_node *node; + struct sock *sk2; + spinlock_t *lock; + + if (rcu_lookup) { + sk_nulls_for_each(sk2, node, &head->chain) { + if (sk2->sk_hash != hash || + !inet_match(net, sk2, acookie, ports, dif, sdif)) + continue; + if (sk2->sk_state == TCP_TIME_WAIT) + break; + return -EADDRNOTAVAIL; + } + return 0; + } + lock = inet_ehash_lockp(hinfo, hash); spin_lock(lock); sk_nulls_for_each(sk2, node, &head->chain) { @@ -565,7 +615,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) { if (sk2->sk_state == TCP_TIME_WAIT) { tw = inet_twsk(sk2); - if (twsk_unique(sk, sk2, twp)) + if (tcp_twsk_unique(sk, sk2, twp)) break; } goto not_unique; @@ -654,7 +704,7 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk, */ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) { - struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); struct inet_ehash_bucket *head; struct hlist_nulls_head *list; spinlock_t *lock; @@ -670,8 +720,11 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) spin_lock(lock); if (osk) { WARN_ON_ONCE(sk->sk_hash != osk->sk_hash); - ret = sk_nulls_del_node_init_rcu(osk); - } else if (found_dup_sk) { + ret = sk_nulls_replace_node_init_rcu(osk, sk); + goto unlock; + } + + if (found_dup_sk) { *found_dup_sk = inet_ehash_lookup_by_sk(sk, list); if (*found_dup_sk) ret = false; @@ -680,6 +733,7 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) if (ret) __sk_nulls_add_node_rcu(sk, list); +unlock: spin_unlock(lock); return ret; @@ -692,22 +746,22 @@ bool inet_ehash_nolisten(struct sock *sk, struct sock *osk, bool *found_dup_sk) if (ok) { sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); } else { - this_cpu_inc(*sk->sk_prot->orphan_count); + tcp_orphan_count_inc(); inet_sk_set_state(sk, TCP_CLOSE); sock_set_flag(sk, SOCK_DEAD); inet_csk_destroy_sock(sk); } return ok; } -EXPORT_SYMBOL_GPL(inet_ehash_nolisten); +EXPORT_IPV6_MOD(inet_ehash_nolisten); static int inet_reuseport_add_sock(struct sock *sk, struct inet_listen_hashbucket *ilb) { struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash; const struct hlist_nulls_node *node; + kuid_t uid = sk_uid(sk); struct sock *sk2; - kuid_t uid = sock_i_uid(sk); sk_nulls_for_each_rcu(sk2, node, &ilb->nulls_head) { if (sk2 != sk && @@ -715,7 +769,7 @@ static int inet_reuseport_add_sock(struct sock *sk, ipv6_only_sock(sk2) == ipv6_only_sock(sk) && sk2->sk_bound_dev_if == sk->sk_bound_dev_if && inet_csk(sk2)->icsk_bind_hash == tb && - sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && + sk2->sk_reuseport && uid_eq(uid, sk_uid(sk2)) && inet_rcv_saddr_equal(sk, sk2, false)) return reuseport_add_sock(sk, sk2, inet_rcv_saddr_any(sk)); @@ -724,15 +778,18 @@ static int inet_reuseport_add_sock(struct sock *sk, return reuseport_alloc(sk, inet_rcv_saddr_any(sk)); } -int __inet_hash(struct sock *sk, struct sock *osk) +int inet_hash(struct sock *sk) { - struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); struct inet_listen_hashbucket *ilb2; int err = 0; + if (sk->sk_state == TCP_CLOSE) + return 0; + if (sk->sk_state != TCP_LISTEN) { local_bh_disable(); - inet_ehash_nolisten(sk, osk, NULL); + inet_ehash_nolisten(sk, NULL, NULL); local_bh_enable(); return 0; } @@ -745,38 +802,28 @@ int __inet_hash(struct sock *sk, struct sock *osk) if (err) goto unlock; } + sock_set_flag(sk, SOCK_RCU_FREE); if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && sk->sk_family == AF_INET6) __sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head); else __sk_nulls_add_node_rcu(sk, &ilb2->nulls_head); - sock_set_flag(sk, SOCK_RCU_FREE); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); unlock: spin_unlock(&ilb2->lock); return err; } -EXPORT_SYMBOL(__inet_hash); - -int inet_hash(struct sock *sk) -{ - int err = 0; - - if (sk->sk_state != TCP_CLOSE) - err = __inet_hash(sk, NULL); - - return err; -} -EXPORT_SYMBOL_GPL(inet_hash); +EXPORT_IPV6_MOD(inet_hash); void inet_unhash(struct sock *sk) { - struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); if (sk_unhashed(sk)) return; + sock_rps_delete_flow(sk); if (sk->sk_state == TCP_LISTEN) { struct inet_listen_hashbucket *ilb2; @@ -785,11 +832,6 @@ void inet_unhash(struct sock *sk) * avoid circular locking dependency on PREEMPT_RT. */ spin_lock(&ilb2->lock); - if (sk_unhashed(sk)) { - spin_unlock(&ilb2->lock); - return; - } - if (rcu_access_pointer(sk->sk_reuseport_cb)) reuseport_stop_listen_sock(sk); @@ -800,56 +842,43 @@ void inet_unhash(struct sock *sk) spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); spin_lock_bh(lock); - if (sk_unhashed(sk)) { - spin_unlock_bh(lock); - return; - } __sk_nulls_del_node_init_rcu(sk); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); spin_unlock_bh(lock); } } -EXPORT_SYMBOL_GPL(inet_unhash); +EXPORT_IPV6_MOD(inet_unhash); static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb, const struct net *net, unsigned short port, int l3mdev, const struct sock *sk) { -#if IS_ENABLED(CONFIG_IPV6) - if (sk->sk_family != tb->family) + if (!net_eq(ib2_net(tb), net) || tb->port != port || + tb->l3mdev != l3mdev) return false; - if (sk->sk_family == AF_INET6) - return net_eq(ib2_net(tb), net) && tb->port == port && - tb->l3mdev == l3mdev && - ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr); - else -#endif - return net_eq(ib2_net(tb), net) && tb->port == port && - tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr; + return inet_bind2_bucket_addr_match(tb, sk); } bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net, unsigned short port, int l3mdev, const struct sock *sk) { + if (!net_eq(ib2_net(tb), net) || tb->port != port || + tb->l3mdev != l3mdev) + return false; + #if IS_ENABLED(CONFIG_IPV6) - if (sk->sk_family != tb->family) { - if (sk->sk_family == AF_INET) - return net_eq(ib2_net(tb), net) && tb->port == port && - tb->l3mdev == l3mdev && - ipv6_addr_any(&tb->v6_rcv_saddr); + if (tb->addr_type == IPV6_ADDR_ANY) + return true; + if (tb->addr_type != IPV6_ADDR_MAPPED) return false; - } - if (sk->sk_family == AF_INET6) - return net_eq(ib2_net(tb), net) && tb->port == port && - tb->l3mdev == l3mdev && - ipv6_addr_any(&tb->v6_rcv_saddr); - else + if (sk->sk_family == AF_INET6 && + !ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return false; #endif - return net_eq(ib2_net(tb), net) && tb->port == port && - tb->l3mdev == l3mdev && tb->rcv_saddr == 0; + return tb->rcv_saddr == 0; } /* The socket's bhash2 hashbucket spinlock must be held when this is called */ @@ -869,7 +898,7 @@ inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net struct inet_bind_hashbucket * inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port) { - struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *hinfo = tcp_get_hashinfo(sk); u32 hash; #if IS_ENABLED(CONFIG_IPV6) @@ -897,7 +926,7 @@ static void inet_update_saddr(struct sock *sk, void *saddr, int family) static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, bool reset) { - struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_hashinfo *hinfo = tcp_get_hashinfo(sk); struct inet_bind_hashbucket *head, *head2; struct inet_bind2_bucket *tb2, *new_tb2; int l3mdev = inet_sk_bound_l3mdev(sk); @@ -944,7 +973,7 @@ static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, spin_lock_bh(&head->lock); spin_lock(&head2->lock); - __sk_del_bind2_node(sk); + __sk_del_bind_node(sk); inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash); spin_unlock(&head2->lock); @@ -959,10 +988,14 @@ static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); if (!tb2) { tb2 = new_tb2; - inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk); + inet_bind2_bucket_init(tb2, net, head2, inet_csk(sk)->icsk_bind_hash, sk); + if (sk_is_connect_bind(sk)) { + tb2->fastreuse = -1; + tb2->fastreuseport = -1; + } } - sk_add_bind2_node(sk, &tb2->owners); inet_csk(sk)->icsk_bind2_hash = tb2; + sk_add_bind_node(sk, &tb2->owners); spin_unlock(&head2->lock); spin_unlock_bh(&head->lock); @@ -977,14 +1010,14 @@ int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family) { return __inet_bhash2_update_saddr(sk, saddr, family, false); } -EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr); +EXPORT_IPV6_MOD(inet_bhash2_update_saddr); void inet_bhash2_reset_saddr(struct sock *sk) { if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) __inet_bhash2_update_saddr(sk, NULL, 0, true); } -EXPORT_SYMBOL_GPL(inet_bhash2_reset_saddr); +EXPORT_IPV6_MOD(inet_bhash2_reset_saddr); /* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm * Note that we use 32bit integers (vs RFC 'short integers') @@ -1001,8 +1034,10 @@ static u32 *table_perturb; int __inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk, u64 port_offset, + u32 hash_port0, int (*check_established)(struct inet_timewait_death_row *, - struct sock *, __u16, struct inet_timewait_sock **)) + struct sock *, __u16, struct inet_timewait_sock **, + bool rcu_lookup, u32 hash)) { struct inet_hashinfo *hinfo = death_row->hashinfo; struct inet_bind_hashbucket *head, *head2; @@ -1014,22 +1049,26 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, bool tb_created = false; u32 remaining, offset; int ret, i, low, high; - int l3mdev; + bool local_ports; + int step, l3mdev; u32 index; if (port) { local_bh_disable(); - ret = check_established(death_row, sk, port, NULL); + ret = check_established(death_row, sk, port, NULL, false, + hash_port0 + port); local_bh_enable(); return ret; } l3mdev = inet_sk_bound_l3mdev(sk); - inet_sk_get_local_port_range(sk, &low, &high); + local_ports = inet_sk_get_local_port_range(sk, &low, &high); + step = local_ports ? 1 : 2; + high++; /* [32768, 60999] -> [32768, 61000[ */ remaining = high - low; - if (likely(remaining > 1)) + if (!local_ports && remaining > 1) remaining &= ~1U; get_random_sleepable_once(table_perturb, @@ -1042,16 +1081,33 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row, /* In first pass we try ports of @low parity. * inet_csk_get_port() does the opposite choice. */ - offset &= ~1U; + if (!local_ports) + offset &= ~1U; other_parity_scan: port = low + offset; - for (i = 0; i < remaining; i += 2, port += 2) { + for (i = 0; i < remaining; i += step, port += step) { if (unlikely(port >= high)) port -= remaining; if (inet_is_local_reserved_port(net, port)) continue; head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)]; + rcu_read_lock(); + hlist_for_each_entry_rcu(tb, &head->chain, node) { + if (!inet_bind_bucket_match(tb, net, port, l3mdev)) + continue; + if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) { + rcu_read_unlock(); + goto next_port; + } + if (!check_established(death_row, sk, port, &tw, true, + hash_port0 + port)) + break; + rcu_read_unlock(); + goto next_port; + } + rcu_read_unlock(); + spin_lock_bh(&head->lock); /* Does not bother with rcv_saddr checks, because @@ -1061,12 +1117,13 @@ other_parity_scan: if (inet_bind_bucket_match(tb, net, port, l3mdev)) { if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) - goto next_port; - WARN_ON(hlist_empty(&tb->owners)); + goto next_port_unlock; + WARN_ON(hlist_empty(&tb->bhash2)); if (!check_established(death_row, sk, - port, &tw)) + port, &tw, false, + hash_port0 + port)) goto ok; - goto next_port; + goto next_port_unlock; } } @@ -1080,15 +1137,17 @@ other_parity_scan: tb->fastreuse = -1; tb->fastreuseport = -1; goto ok; -next_port: +next_port_unlock: spin_unlock_bh(&head->lock); +next_port: cond_resched(); } - offset++; - if ((offset & 1) && remaining > 1) - goto other_parity_scan; - + if (!local_ports) { + offset++; + if ((offset & 1) && remaining > 1) + goto other_parity_scan; + } return -EADDRNOTAVAIL; ok: @@ -1101,9 +1160,11 @@ ok: tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); if (!tb2) { tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net, - head2, port, l3mdev, sk); + head2, tb, sk); if (!tb2) goto error; + tb2->fastreuse = -1; + tb2->fastreuseport = -1; } /* Here we want to add a little bit of randomness to the next source @@ -1111,11 +1172,12 @@ ok: * on low contention the randomness is maximal and on high contention * it may be inexistent. */ - i = max_t(int, i, get_random_u32_below(8) * 2); - WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2); + i = max_t(int, i, get_random_u32_below(8) * step); + WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + step); /* Head lock still held and bh's disabled */ inet_bind_hash(sk, tb, tb2, port); + sk->sk_userlocks |= SOCK_CONNECT_BIND; if (sk_unhashed(sk)) { inet_sk(sk)->inet_sport = htons(port); @@ -1133,10 +1195,33 @@ ok: return 0; error: + if (sk_hashed(sk)) { + spinlock_t *lock = inet_ehash_lockp(hinfo, sk->sk_hash); + + sock_prot_inuse_add(net, sk->sk_prot, -1); + + spin_lock(lock); + __sk_nulls_del_node_init_rcu(sk); + spin_unlock(lock); + + sk->sk_hash = 0; + inet_sk(sk)->inet_sport = 0; + inet_sk(sk)->inet_num = 0; + + if (tw) + inet_twsk_bind_unhash(tw, hinfo); + } + spin_unlock(&head2->lock); if (tb_created) - inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); - spin_unlock_bh(&head->lock); + inet_bind_bucket_destroy(tb); + spin_unlock(&head->lock); + + if (tw) + inet_twsk_deschedule_put(tw); + + local_bh_enable(); + return -ENOMEM; } @@ -1146,14 +1231,20 @@ error: int inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk) { + const struct inet_sock *inet = inet_sk(sk); + const struct net *net = sock_net(sk); u64 port_offset = 0; + u32 hash_port0; if (!inet_sk(sk)->inet_num) port_offset = inet_sk_port_offset(sk); - return __inet_hash_connect(death_row, sk, port_offset, + + hash_port0 = inet_ehashfn(net, inet->inet_rcv_saddr, 0, + inet->inet_daddr, inet->inet_dport); + + return __inet_hash_connect(death_row, sk, port_offset, hash_port0, __inet_check_established); } -EXPORT_SYMBOL_GPL(inet_hash_connect); static void init_hashinfo_lhash2(struct inet_hashinfo *h) { @@ -1204,32 +1295,45 @@ int inet_hashinfo2_init_mod(struct inet_hashinfo *h) init_hashinfo_lhash2(h); return 0; } -EXPORT_SYMBOL_GPL(inet_hashinfo2_init_mod); int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo) { unsigned int locksz = sizeof(spinlock_t); unsigned int i, nblocks = 1; + spinlock_t *ptr = NULL; - if (locksz != 0) { - /* allocate 2 cache lines or at least one spinlock per cpu */ - nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U); - nblocks = roundup_pow_of_two(nblocks * num_possible_cpus()); + if (locksz == 0) + goto set_mask; - /* no more locks than number of hash buckets */ - nblocks = min(nblocks, hashinfo->ehash_mask + 1); + /* Allocate 2 cache lines or at least one spinlock per cpu. */ + nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U) * num_possible_cpus(); - hashinfo->ehash_locks = kvmalloc_array(nblocks, locksz, GFP_KERNEL); - if (!hashinfo->ehash_locks) - return -ENOMEM; + /* At least one page per NUMA node. */ + nblocks = max(nblocks, num_online_nodes() * PAGE_SIZE / locksz); + + nblocks = roundup_pow_of_two(nblocks); + + /* No more locks than number of hash buckets. */ + nblocks = min(nblocks, hashinfo->ehash_mask + 1); - for (i = 0; i < nblocks; i++) - spin_lock_init(&hashinfo->ehash_locks[i]); + if (num_online_nodes() > 1) { + /* Use vmalloc() to allow NUMA policy to spread pages + * on all available nodes if desired. + */ + ptr = vmalloc_array(nblocks, locksz); + } + if (!ptr) { + ptr = kvmalloc_array(nblocks, locksz, GFP_KERNEL); + if (!ptr) + return -ENOMEM; } + for (i = 0; i < nblocks; i++) + spin_lock_init(&ptr[i]); + hashinfo->ehash_locks = ptr; +set_mask: hashinfo->ehash_locks_mask = nblocks - 1; return 0; } -EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc); struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo, unsigned int ehash_entries) @@ -1265,7 +1369,6 @@ free_hashinfo: err: return NULL; } -EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_alloc); void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo) { @@ -1276,4 +1379,3 @@ void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo) vfree(hashinfo->ehash); kfree(hashinfo); } -EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_free); |
