diff options
Diffstat (limited to 'net/ipv4')
130 files changed, 19834 insertions, 10966 deletions
diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 87983e70f03f..b71c22475c51 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -321,7 +321,6 @@ config NET_UDP_TUNNEL config NET_FOU tristate "IP: Foo (IP protocols) over UDP" - select XFRM select NET_UDP_TUNNEL help Foo over UDP allows any IP protocol to be directly encapsulated @@ -403,6 +402,16 @@ config INET_IPCOMP If unsure, say Y. +config INET_TABLE_PERTURB_ORDER + int "INET: Source port perturbation table size (as power of 2)" if EXPERT + default 16 + help + Source port perturbation table size (as power of 2) for + RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm. + + The default is almost always what you want. + Only change this if you know what you are doing. + config INET_XFRM_TUNNEL tristate select INET_TUNNEL @@ -416,7 +425,7 @@ config INET_DIAG tristate "INET: socket monitoring interface" default y help - Support for INET (TCP, DCCP, etc) socket monitoring interface used by + Support for INET (TCP, UDP, etc) socket monitoring interface used by native Linux tools such as ss. ss is included in iproute2, currently downloadable at: @@ -652,7 +661,8 @@ config TCP_CONG_CDG For further details see: D.A. Hayes and G. Armitage. "Revisiting TCP congestion control using - delay gradients." In Networking 2011. Preprint: http://goo.gl/No3vdg + delay gradients." In Networking 2011. Preprint: + http://caia.swin.edu.au/cv/dahayes/content/networking2011-cdg-preprint.pdf config TCP_CONG_BBR tristate "BBR TCP" @@ -732,10 +742,25 @@ config DEFAULT_TCP_CONG default "bbr" if DEFAULT_BBR default "cubic" +config TCP_SIGPOOL + tristate + +config TCP_AO + bool "TCP: Authentication Option (RFC5925)" + select CRYPTO + select TCP_SIGPOOL + depends on 64BIT && IPV6 != m # seq-number extension needs WRITE_ONCE(u64) + help + TCP-AO specifies the use of stronger Message Authentication Codes (MACs), + protects against replays for long-lived TCP connections, and + provides more details on the association of security with TCP + connections than TCP MD5 (See RFC5925) + + If unsure, say N. + config TCP_MD5SIG bool "TCP: MD5 Signature Option support (RFC2385)" - select CRYPTO - select CRYPTO_MD5 + select CRYPTO_LIB_MD5 help RFC2385 specifies a method of giving MD5 protection to TCP sessions. Its main (only?) use is to protect BGP sessions between core routers diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile index bbdd9c44f14e..ec36d2ec059e 100644 --- a/net/ipv4/Makefile +++ b/net/ipv4/Makefile @@ -10,14 +10,12 @@ obj-y := route.o inetpeer.o protocol.o \ tcp.o tcp_input.o tcp_output.o tcp_timer.o tcp_ipv4.o \ tcp_minisocks.o tcp_cong.o tcp_metrics.o tcp_fastopen.o \ tcp_rate.o tcp_recovery.o tcp_ulp.o \ - tcp_offload.o datagram.o raw.o udp.o udplite.o \ + tcp_offload.o tcp_plb.o datagram.o raw.o udp.o udplite.o \ udp_offload.o arp.o icmp.o devinet.o af_inet.o igmp.o \ fib_frontend.o fib_semantics.o fib_trie.o fib_notifier.o \ inet_fragment.o ping.o ip_tunnel_core.o gre_offload.o \ metrics.o netlink.o nexthop.o udp_tunnel_stub.o -obj-$(CONFIG_BPFILTER) += bpfilter/ - obj-$(CONFIG_NET_IP_TUNNEL) += ip_tunnel.o obj-$(CONFIG_SYSCTL) += sysctl_net_ipv4.o obj-$(CONFIG_PROC_FS) += proc.o @@ -26,6 +24,7 @@ obj-$(CONFIG_IP_MROUTE) += ipmr.o obj-$(CONFIG_IP_MROUTE_COMMON) += ipmr_base.o obj-$(CONFIG_NET_IPIP) += ipip.o gre-y := gre_demux.o +fou-y := fou_core.o fou_nl.o fou_bpf.o obj-$(CONFIG_NET_FOU) += fou.o obj-$(CONFIG_NET_IPGRE_DEMUX) += gre.o obj-$(CONFIG_NET_IPGRE) += ip_gre.o @@ -61,12 +60,14 @@ obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o +obj-$(CONFIG_TCP_SIGPOOL) += tcp_sigpool.o obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o obj-$(CONFIG_BPF_SYSCALL) += udp_bpf.o obj-$(CONFIG_NETLABEL) += cipso_ipv4.o obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \ xfrm4_output.o xfrm4_protocol.o +obj-$(CONFIG_TCP_AO) += tcp_ao.o ifeq ($(CONFIG_BPF_JIT),y) obj-$(CONFIG_BPF_SYSCALL) += bpf_tcp_ca.o diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 9c465bac1eb0..08d811f11896 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -100,7 +100,9 @@ #include <net/ip_fib.h> #include <net/inet_connection_sock.h> #include <net/gro.h> +#include <net/gso.h> #include <net/tcp.h> +#include <net/psp.h> #include <net/udp.h> #include <net/udplite.h> #include <net/ping.h> @@ -118,6 +120,7 @@ #endif #include <net/l3mdev.h> #include <net/compat.h> +#include <net/rps.h> #include <trace/events/sock.h> @@ -148,15 +151,15 @@ void inet_sock_destruct(struct sock *sk) return; } - WARN_ON(atomic_read(&sk->sk_rmem_alloc)); - WARN_ON(refcount_read(&sk->sk_wmem_alloc)); - WARN_ON(sk->sk_wmem_queued); - WARN_ON(sk_forward_alloc_get(sk)); + WARN_ON_ONCE(atomic_read(&sk->sk_rmem_alloc)); + WARN_ON_ONCE(refcount_read(&sk->sk_wmem_alloc)); + WARN_ON_ONCE(sk->sk_wmem_queued); + WARN_ON_ONCE(sk->sk_forward_alloc); kfree(rcu_dereference_protected(inet->inet_opt, 1)); dst_release(rcu_dereference_protected(sk->sk_dst_cache, 1)); dst_release(rcu_dereference_protected(sk->sk_rx_dst, 1)); - sk_refcnt_debug_dec(sk); + psp_sk_assoc_free(sk); } EXPORT_SYMBOL(inet_sock_destruct); @@ -187,24 +190,13 @@ static int inet_autobind(struct sock *sk) return 0; } -/* - * Move a socket into listening state. - */ -int inet_listen(struct socket *sock, int backlog) +int __inet_listen_sk(struct sock *sk, int backlog) { - struct sock *sk = sock->sk; - unsigned char old_state; + unsigned char old_state = sk->sk_state; int err, tcp_fastopen; - lock_sock(sk); - - err = -EINVAL; - if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) - goto out; - - old_state = sk->sk_state; if (!((1 << old_state) & (TCPF_CLOSE | TCPF_LISTEN))) - goto out; + return -EINVAL; WRITE_ONCE(sk->sk_max_ack_backlog, backlog); /* Really, if the socket is already in listen state @@ -217,7 +209,7 @@ int inet_listen(struct socket *sock, int backlog) * because the socket was in TCP_LISTEN state previously but * was shutdown() rather than close(). */ - tcp_fastopen = sock_net(sk)->ipv4.sysctl_tcp_fastopen; + tcp_fastopen = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen); if ((tcp_fastopen & TFO_SERVER_WO_SOCKOPT1) && (tcp_fastopen & TFO_SERVER_ENABLE) && !inet_csk(sk)->icsk_accept_queue.fastopenq.max_qlen) { @@ -227,10 +219,27 @@ int inet_listen(struct socket *sock, int backlog) err = inet_csk_listen_start(sk); if (err) - goto out; + return err; + tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_LISTEN_CB, 0, NULL); } - err = 0; + return 0; +} + +/* + * Move a socket into listening state. + */ +int inet_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int err = -EINVAL; + + lock_sock(sk); + + if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) + goto out; + + err = __inet_listen_sk(sk, backlog); out: release_sock(sk); @@ -324,40 +333,42 @@ lookup_protocol: if (INET_PROTOSW_REUSE & answer_flags) sk->sk_reuse = SK_CAN_REUSE; + if (INET_PROTOSW_ICSK & answer_flags) + inet_init_csk_locks(sk); + inet = inet_sk(sk); - inet->is_icsk = (INET_PROTOSW_ICSK & answer_flags) != 0; + inet_assign_bit(IS_ICSK, sk, INET_PROTOSW_ICSK & answer_flags); - inet->nodefrag = 0; + inet_clear_bit(NODEFRAG, sk); if (SOCK_RAW == sock->type) { inet->inet_num = protocol; if (IPPROTO_RAW == protocol) - inet->hdrincl = 1; + inet_set_bit(HDRINCL, sk); } - if (net->ipv4.sysctl_ip_no_pmtu_disc) + if (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc)) inet->pmtudisc = IP_PMTUDISC_DONT; else inet->pmtudisc = IP_PMTUDISC_WANT; - inet->inet_id = 0; + atomic_set(&inet->inet_id, 0); sock_init_data(sock, sk); sk->sk_destruct = inet_sock_destruct; sk->sk_protocol = protocol; sk->sk_backlog_rcv = sk->sk_prot->backlog_rcv; + sk->sk_txrehash = READ_ONCE(net->core.sysctl_txrehash); inet->uc_ttl = -1; - inet->mc_loop = 1; + inet_set_bit(MC_LOOP, sk); inet->mc_ttl = 1; - inet->mc_all = 1; + inet_set_bit(MC_ALL, sk); inet->mc_index = 0; inet->mc_list = NULL; inet->rcv_tos = 0; - sk_refcnt_debug_inc(sk); - if (inet->inet_num) { /* It assumes that any protocol which allows * the user to assign a number at socket @@ -367,32 +378,30 @@ lookup_protocol: inet->inet_sport = htons(inet->inet_num); /* Add to protocol hash chains. */ err = sk->sk_prot->hash(sk); - if (err) { - sk_common_release(sk); - goto out; - } + if (err) + goto out_sk_release; } if (sk->sk_prot->init) { err = sk->sk_prot->init(sk); - if (err) { - sk_common_release(sk); - goto out; - } + if (err) + goto out_sk_release; } if (!kern) { err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk); - if (err) { - sk_common_release(sk); - goto out; - } + if (err) + goto out_sk_release; } out: return err; out_rcu_unlock: rcu_read_unlock(); goto out; +out_sk_release: + sk_common_release(sk); + sock->sk = NULL; + goto out; } @@ -432,9 +441,8 @@ int inet_release(struct socket *sock) } EXPORT_SYMBOL(inet_release); -int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +int inet_bind_sk(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len) { - struct sock *sk = sock->sk; u32 flags = BIND_WITH_LOCK; int err; @@ -448,16 +456,21 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) /* BPF prog is run before any checks are done so that if the prog * changes context in a wrong way it will be caught. */ - err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, + err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, &addr_len, CGROUP_INET4_BIND, &flags); if (err) return err; return __inet_bind(sk, uaddr, addr_len, flags); } + +int inet_bind(struct socket *sock, struct sockaddr_unsized *uaddr, int addr_len) +{ + return inet_bind_sk(sock->sk, uaddr, addr_len); +} EXPORT_SYMBOL(inet_bind); -int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, +int __inet_bind(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len, u32 flags) { struct sockaddr_in *addr = (struct sockaddr_in *)uaddr; @@ -520,11 +533,11 @@ int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, inet->inet_saddr = 0; /* Use device */ /* Make sure we are allowed to bind here. */ - if (snum || !(inet->bind_address_no_port || + if (snum || !(inet_test_bit(BIND_ADDRESS_NO_PORT, sk) || (flags & BIND_FORCE_ADDRESS_NO_PORT))) { - if (sk->sk_prot->get_port(sk, snum)) { + err = sk->sk_prot->get_port(sk, snum); + if (err) { inet->inet_saddr = inet->inet_rcv_saddr = 0; - err = -EADDRINUSE; goto out_release_sock; } if (!(flags & BIND_FROM_BPF)) { @@ -554,26 +567,31 @@ out: return err; } -int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr, +int inet_dgram_connect(struct socket *sock, struct sockaddr_unsized *uaddr, int addr_len, int flags) { struct sock *sk = sock->sk; + const struct proto *prot; int err; if (addr_len < sizeof(uaddr->sa_family)) return -EINVAL; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + if (uaddr->sa_family == AF_UNSPEC) - return sk->sk_prot->disconnect(sk, flags); + return prot->disconnect(sk, flags); if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) { - err = sk->sk_prot->pre_connect(sk, uaddr, addr_len); + err = prot->pre_connect(sk, uaddr, addr_len); if (err) return err; } if (data_race(!inet_sk(sk)->inet_num) && inet_autobind(sk)) return -EAGAIN; - return sk->sk_prot->connect(sk, uaddr, addr_len); + return prot->connect(sk, uaddr, addr_len); } EXPORT_SYMBOL(inet_dgram_connect); @@ -605,7 +623,7 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias) * Connect to a remote host. There is regrettably still a little * TCP 'magic' in here. */ -int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, +int __inet_stream_connect(struct socket *sock, struct sockaddr_unsized *uaddr, int addr_len, int flags, int is_sendmsg) { struct sock *sk = sock->sk; @@ -626,6 +644,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, return -EINVAL; if (uaddr->sa_family == AF_UNSPEC) { + sk->sk_disconnects++; err = sk->sk_prot->disconnect(sk, flags); sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED; goto out; @@ -640,7 +659,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, err = -EISCONN; goto out; case SS_CONNECTING: - if (inet_sk(sk)->defer_connect) + if (inet_test_bit(DEFER_CONNECT, sk)) err = is_sendmsg ? -EINPROGRESS : -EISCONN; else err = -EALREADY; @@ -663,7 +682,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, sock->state = SS_CONNECTING; - if (!err && inet_sk(sk)->defer_connect) + if (!err && inet_test_bit(DEFER_CONNECT, sk)) goto out; /* Just entered SS_CONNECTING state; the only @@ -680,6 +699,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, int writebias = (sk->sk_protocol == IPPROTO_TCP) && tcp_sk(sk)->fastopen_req && tcp_sk(sk)->fastopen_req->data ? 1 : 0; + int dis = sk->sk_disconnects; /* Error code is set above */ if (!timeo || !inet_wait_for_connect(sk, timeo, writebias)) @@ -688,6 +708,11 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, err = sock_intr_errno(timeo); if (signal_pending(current)) goto out; + + if (dis != sk->sk_disconnects) { + err = -EPIPE; + goto out; + } } /* Connection was closed by RST, timeout, ICMP error @@ -709,13 +734,14 @@ out: sock_error: err = sock_error(sk) ? : -ECONNABORTED; sock->state = SS_UNCONNECTED; + sk->sk_disconnects++; if (sk->sk_prot->disconnect(sk, flags)) sock->state = SS_DISCONNECTING; goto out; } EXPORT_SYMBOL(__inet_stream_connect); -int inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, +int inet_stream_connect(struct socket *sock, struct sockaddr_unsized *uaddr, int addr_len, int flags) { int err; @@ -727,34 +753,47 @@ int inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, } EXPORT_SYMBOL(inet_stream_connect); +void __inet_accept(struct socket *sock, struct socket *newsock, struct sock *newsk) +{ + if (mem_cgroup_sockets_enabled) { + mem_cgroup_sk_alloc(newsk); + __sk_charge(newsk, GFP_KERNEL); + } + + sock_rps_record_flow(newsk); + WARN_ON(!((1 << newsk->sk_state) & + (TCPF_ESTABLISHED | TCPF_SYN_RECV | + TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2 | + TCPF_CLOSING | TCPF_CLOSE_WAIT | + TCPF_CLOSE))); + + if (test_bit(SOCK_SUPPORT_ZC, &sock->flags)) + set_bit(SOCK_SUPPORT_ZC, &newsock->flags); + sock_graft(newsk, newsock); + + newsock->state = SS_CONNECTED; +} +EXPORT_SYMBOL_GPL(__inet_accept); + /* * Accept a pending connection. The TCP layer now gives BSD semantics. */ -int inet_accept(struct socket *sock, struct socket *newsock, int flags, - bool kern) +int inet_accept(struct socket *sock, struct socket *newsock, + struct proto_accept_arg *arg) { - struct sock *sk1 = sock->sk; - int err = -EINVAL; - struct sock *sk2 = sk1->sk_prot->accept(sk1, flags, &err, kern); + struct sock *sk1 = sock->sk, *sk2; + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + arg->err = -EINVAL; + sk2 = READ_ONCE(sk1->sk_prot)->accept(sk1, arg); if (!sk2) - goto do_err; + return arg->err; lock_sock(sk2); - - sock_rps_record_flow(sk2); - WARN_ON(!((1 << sk2->sk_state) & - (TCPF_ESTABLISHED | TCPF_SYN_RECV | - TCPF_CLOSE_WAIT | TCPF_CLOSE))); - - sock_graft(sk2, newsock); - - newsock->state = SS_CONNECTED; - err = 0; + __inet_accept(sock, newsock, sk2); release_sock(sk2); -do_err: - return err; + return 0; } EXPORT_SYMBOL(inet_accept); @@ -767,6 +806,7 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr, struct sock *sk = sock->sk; struct inet_sock *inet = inet_sk(sk); DECLARE_SOCKADDR(struct sockaddr_in *, sin, uaddr); + int sin_addr_len = sizeof(*sin); sin->sin_family = AF_INET; lock_sock(sk); @@ -779,7 +819,7 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr, } sin->sin_port = inet->inet_dport; sin->sin_addr.s_addr = inet->inet_daddr; - BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, + BPF_CGROUP_RUN_SA_PROG(sk, sin, &sin_addr_len, CGROUP_INET4_GETPEERNAME); } else { __be32 addr = inet->inet_rcv_saddr; @@ -787,12 +827,12 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr, addr = inet->inet_saddr; sin->sin_port = inet->inet_sport; sin->sin_addr.s_addr = addr; - BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, + BPF_CGROUP_RUN_SA_PROG(sk, sin, &sin_addr_len, CGROUP_INET4_GETSOCKNAME); } release_sock(sk); memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); - return sizeof(*sin); + return sin_addr_len; } EXPORT_SYMBOL(inet_getname); @@ -821,22 +861,23 @@ int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) } EXPORT_SYMBOL(inet_sendmsg); -ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset, - size_t size, int flags) +void inet_splice_eof(struct socket *sock) { + const struct proto *prot; struct sock *sk = sock->sk; if (unlikely(inet_send_prepare(sk))) - return -EAGAIN; + return; - if (sk->sk_prot->sendpage) - return sk->sk_prot->sendpage(sk, page, offset, size, flags); - return sock_no_sendpage(sock, page, offset, size, flags); + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + if (prot->splice_eof) + prot->splice_eof(sock); } -EXPORT_SYMBOL(inet_sendpage); +EXPORT_SYMBOL_GPL(inet_splice_eof); INDIRECT_CALLABLE_DECLARE(int udp_recvmsg(struct sock *, struct msghdr *, - size_t, int, int, int *)); + size_t, int, int *)); int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, int flags) { @@ -848,8 +889,7 @@ int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, sock_rps_record_flow(sk); err = INDIRECT_CALL_2(sk->sk_prot->recvmsg, tcp_recvmsg, udp_recvmsg, - sk, msg, size, flags & MSG_DONTWAIT, - flags & ~MSG_DONTWAIT, &addr_len); + sk, msg, size, flags, &addr_len); if (err >= 0) msg->msg_namelen = addr_len; return err; @@ -886,7 +926,7 @@ int inet_shutdown(struct socket *sock, int how) EPOLLHUP, even on eg. unconnected UDP sockets -- RR */ fallthrough; default: - sk->sk_shutdown |= how; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | how); if (sk->sk_prot->shutdown) sk->sk_prot->shutdown(sk, how); break; @@ -970,7 +1010,7 @@ int inet_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) break; default: if (sk->sk_prot->ioctl) - err = sk->sk_prot->ioctl(sk, cmd, arg); + err = sk_ioctl(sk, cmd, (void __user *)arg); else err = -ENOIOCTLCMD; break; @@ -1038,11 +1078,12 @@ const struct proto_ops inet_stream_ops = { #ifdef CONFIG_MMU .mmap = tcp_mmap, #endif - .sendpage = inet_sendpage, + .splice_eof = inet_splice_eof, .splice_read = tcp_splice_read, + .set_peek_off = sk_set_peek_off, .read_sock = tcp_read_sock, + .read_skb = tcp_read_skb, .sendmsg_locked = tcp_sendmsg_locked, - .sendpage_locked = tcp_sendpage_locked, .peek_len = tcp_peek_len, #ifdef CONFIG_COMPAT .compat_ioctl = inet_compat_ioctl, @@ -1068,11 +1109,11 @@ const struct proto_ops inet_dgram_ops = { .setsockopt = sock_common_setsockopt, .getsockopt = sock_common_getsockopt, .sendmsg = inet_sendmsg, - .read_sock = udp_read_sock, + .read_skb = udp_read_skb, .recvmsg = inet_recvmsg, .mmap = sock_no_mmap, - .sendpage = inet_sendpage, - .set_peek_off = sk_set_peek_off, + .splice_eof = inet_splice_eof, + .set_peek_off = udp_set_peek_off, #ifdef CONFIG_COMPAT .compat_ioctl = inet_compat_ioctl, #endif @@ -1102,7 +1143,7 @@ static const struct proto_ops inet_sockraw_ops = { .sendmsg = inet_sendmsg, .recvmsg = inet_recvmsg, .mmap = sock_no_mmap, - .sendpage = inet_sendpage, + .splice_eof = inet_splice_eof, #ifdef CONFIG_COMPAT .compat_ioctl = inet_compat_ioctl, #endif @@ -1226,6 +1267,7 @@ static int inet_sk_reselect_saddr(struct sock *sk) struct rtable *rt; __be32 new_saddr; struct ip_options_rcu *inet_opt; + int err; inet_opt = rcu_dereference_protected(inet->inet_opt, lockdep_sock_is_held(sk)); @@ -1234,26 +1276,32 @@ static int inet_sk_reselect_saddr(struct sock *sk) /* Query new route. */ fl4 = &inet->cork.fl.u.ip4; - rt = ip_route_connect(fl4, daddr, 0, RT_CONN_FLAGS(sk), - sk->sk_bound_dev_if, sk->sk_protocol, - inet->inet_sport, inet->inet_dport, sk); + rt = ip_route_connect(fl4, daddr, 0, sk->sk_bound_dev_if, + sk->sk_protocol, inet->inet_sport, + inet->inet_dport, sk); if (IS_ERR(rt)) return PTR_ERR(rt); - sk_setup_caps(sk, &rt->dst); - new_saddr = fl4->saddr; - if (new_saddr == old_saddr) + if (new_saddr == old_saddr) { + sk_setup_caps(sk, &rt->dst); return 0; + } - if (sock_net(sk)->ipv4.sysctl_ip_dynaddr > 1) { + err = inet_bhash2_update_saddr(sk, &new_saddr, AF_INET); + if (err) { + ip_rt_put(rt); + return err; + } + + sk_setup_caps(sk, &rt->dst); + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) > 1) { pr_info("%s(): shifting inet->saddr from %pI4 to %pI4\n", __func__, &old_saddr, &new_saddr); } - inet->inet_saddr = inet->inet_rcv_saddr = new_saddr; - /* * XXX The only one ugly spot where we need to * XXX really change the sockets identity after @@ -1267,10 +1315,8 @@ static int inet_sk_reselect_saddr(struct sock *sk) int inet_sk_rebuild_header(struct sock *sk) { + struct rtable *rt = dst_rtable(__sk_dst_check(sk, 0)); struct inet_sock *inet = inet_sk(sk); - struct rtable *rt = (struct rtable *)__sk_dst_check(sk, 0); - __be32 daddr; - struct ip_options_rcu *inet_opt; struct flowi4 *fl4; int err; @@ -1279,17 +1325,9 @@ int inet_sk_rebuild_header(struct sock *sk) return 0; /* Reroute. */ - rcu_read_lock(); - inet_opt = rcu_dereference(inet->inet_opt); - daddr = inet->inet_daddr; - if (inet_opt && inet_opt->opt.srr) - daddr = inet_opt->opt.faddr; - rcu_read_unlock(); fl4 = &inet->cork.fl.u.ip4; - rt = ip_route_output_ports(sock_net(sk), fl4, sk, daddr, inet->inet_saddr, - inet->inet_dport, inet->inet_sport, - sk->sk_protocol, RT_CONN_FLAGS(sk), - sk->sk_bound_dev_if); + inet_sk_init_flowi4(inet, fl4); + rt = ip_route_output_flow(sock_net(sk), fl4, sk); if (!IS_ERR(rt)) { err = 0; sk_setup_caps(sk, &rt->dst); @@ -1298,15 +1336,12 @@ int inet_sk_rebuild_header(struct sock *sk) /* Routing failed... */ sk->sk_route_caps = 0; - /* - * Other protocols have to map its equivalent state to TCP_SYN_SENT. - * DCCP maps its DCCP_REQUESTING state to TCP_SYN_SENT. -acme - */ - if (!sock_net(sk)->ipv4.sysctl_ip_dynaddr || + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) || sk->sk_state != TCP_SYN_SENT || (sk->sk_userlocks & SOCK_BINDADDR_LOCK) || (err = inet_sk_reselect_saddr(sk)) != 0) - sk->sk_err_soft = -err; + WRITE_ONCE(sk->sk_err_soft, -err); } return err; @@ -1366,18 +1401,17 @@ struct sk_buff *inet_gso_segment(struct sk_buff *skb, segs = ERR_PTR(-EPROTONOSUPPORT); - if (!skb->encapsulation || encap) { - udpfrag = !!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP); - fixedid = !!(skb_shinfo(skb)->gso_type & SKB_GSO_TCP_FIXEDID); + fixedid = !!(skb_shinfo(skb)->gso_type & (SKB_GSO_TCP_FIXEDID << encap)); - /* fixed ID is invalid if DF bit is not set */ - if (fixedid && !(ip_hdr(skb)->frag_off & htons(IP_DF))) - goto out; - } + if (!skb->encapsulation || encap) + udpfrag = !!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP); ops = rcu_dereference(inet_offloads[proto]); - if (likely(ops && ops->callbacks.gso_segment)) + if (likely(ops && ops->callbacks.gso_segment)) { segs = ops->callbacks.gso_segment(skb, features); + if (!segs) + skb->network_header = skb_mac_header(skb) + nhoff - skb->head; + } if (IS_ERR_OR_NULL(segs)) goto out; @@ -1439,18 +1473,14 @@ struct sk_buff *inet_gro_receive(struct list_head *head, struct sk_buff *skb) struct sk_buff *p; unsigned int hlen; unsigned int off; - unsigned int id; int flush = 1; int proto; off = skb_gro_offset(skb); hlen = off + sizeof(*iph); - iph = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, hlen)) { - iph = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!iph)) - goto out; - } + iph = skb_gro_header(skb, hlen, off); + if (unlikely(!iph)) + goto out; proto = iph->protocol; @@ -1467,13 +1497,11 @@ struct sk_buff *inet_gro_receive(struct list_head *head, struct sk_buff *skb) if (unlikely(ip_fast_csum((u8 *)iph, 5))) goto out; - id = ntohl(*(__be32 *)&iph->id); - flush = (u16)((ntohl(*(__be32 *)iph) ^ skb_gro_len(skb)) | (id & ~IP_DF)); - id >>= 16; + NAPI_GRO_CB(skb)->proto = proto; + flush = (u16)((ntohl(*(__be32 *)iph) ^ skb_gro_len(skb)) | (ntohl(*(__be32 *)&iph->id) & ~IP_DF)); list_for_each_entry(p, head, list) { struct iphdr *iph2; - u16 flush_id; if (!NAPI_GRO_CB(p)->same_flow) continue; @@ -1490,48 +1518,10 @@ struct sk_buff *inet_gro_receive(struct list_head *head, struct sk_buff *skb) NAPI_GRO_CB(p)->same_flow = 0; continue; } - - /* All fields must match except length and checksum. */ - NAPI_GRO_CB(p)->flush |= - (iph->ttl ^ iph2->ttl) | - (iph->tos ^ iph2->tos) | - ((iph->frag_off ^ iph2->frag_off) & htons(IP_DF)); - - NAPI_GRO_CB(p)->flush |= flush; - - /* We need to store of the IP ID check to be included later - * when we can verify that this packet does in fact belong - * to a given flow. - */ - flush_id = (u16)(id - ntohs(iph2->id)); - - /* This bit of code makes it much easier for us to identify - * the cases where we are doing atomic vs non-atomic IP ID - * checks. Specifically an atomic check can return IP ID - * values 0 - 0xFFFF, while a non-atomic check can only - * return 0 or 0xFFFF. - */ - if (!NAPI_GRO_CB(p)->is_atomic || - !(iph->frag_off & htons(IP_DF))) { - flush_id ^= NAPI_GRO_CB(p)->count; - flush_id = flush_id ? 0xFFFF : 0; - } - - /* If the previous IP ID value was based on an atomic - * datagram we can overwrite the value and ignore it. - */ - if (NAPI_GRO_CB(skb)->is_atomic) - NAPI_GRO_CB(p)->flush_id = flush_id; - else - NAPI_GRO_CB(p)->flush_id |= flush_id; } - NAPI_GRO_CB(skb)->is_atomic = !!(iph->frag_off & htons(IP_DF)); NAPI_GRO_CB(skb)->flush |= flush; - skb_set_network_header(skb, off); - /* The above will be needed by the transport layer if there is one - * immediately following this IP hdr. - */ + NAPI_GRO_CB(skb)->network_offsets[NAPI_GRO_CB(skb)->encap_mark] = off; /* Note : No need to call skb_gro_postpull_rcsum() here, * as we already checked checksum over ipv4 header was 0 @@ -1589,20 +1579,23 @@ EXPORT_SYMBOL(inet_current_timestamp); int inet_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) { - if (sk->sk_family == AF_INET) + unsigned int family = READ_ONCE(sk->sk_family); + + if (family == AF_INET) return ip_recv_error(sk, msg, len, addr_len); #if IS_ENABLED(CONFIG_IPV6) - if (sk->sk_family == AF_INET6) + if (family == AF_INET6) return pingv6_ops.ipv6_recv_error(sk, msg, len, addr_len); #endif return -EINVAL; } +EXPORT_SYMBOL(inet_recv_error); int inet_gro_complete(struct sk_buff *skb, int nhoff) { - __be16 newlen = htons(skb->len - nhoff); struct iphdr *iph = (struct iphdr *)(skb->data + nhoff); const struct net_offload *ops; + __be16 totlen = iph->tot_len; int proto = iph->protocol; int err = -ENOSYS; @@ -1611,8 +1604,8 @@ int inet_gro_complete(struct sk_buff *skb, int nhoff) skb_set_inner_network_header(skb, nhoff); } - csum_replace2(&iph->check, iph->tot_len, newlen); - iph->tot_len = newlen; + iph_set_totlen(iph, skb->len - nhoff); + csum_replace2(&iph->check, totlen, iph->tot_len); ops = rcu_dereference(inet_offloads[proto]); if (WARN_ON(!ops || !ops->callbacks.gro_complete)) @@ -1647,6 +1640,7 @@ int inet_ctl_sock_create(struct sock **sk, unsigned short family, if (rc == 0) { *sk = sock->sk; (*sk)->sk_allocation = GFP_ATOMIC; + (*sk)->sk_use_task_frag = false; /* * Unhash it so that IP input processing does not even see it, * we do not wish this socket to see incoming packets. @@ -1681,9 +1675,9 @@ u64 snmp_get_cpu_field64(void __percpu *mib, int cpu, int offt, bhptr = per_cpu_ptr(mib, cpu); syncp = (struct u64_stats_sync *)(bhptr + syncp_offset); do { - start = u64_stats_fetch_begin_irq(syncp); + start = u64_stats_fetch_begin(syncp); v = *(((u64 *)bhptr) + offt); - } while (u64_stats_fetch_retry_irq(syncp, start)); + } while (u64_stats_fetch_retry(syncp, start)); return v; } @@ -1708,29 +1702,6 @@ static const struct net_protocol igmp_protocol = { }; #endif -/* thinking of making this const? Don't. - * early_demux can change based on sysctl. - */ -static struct net_protocol tcp_protocol = { - .early_demux = tcp_v4_early_demux, - .early_demux_handler = tcp_v4_early_demux, - .handler = tcp_v4_rcv, - .err_handler = tcp_v4_err, - .no_policy = 1, - .icmp_strict_tag_validation = 1, -}; - -/* thinking of making this const? Don't. - * early_demux can change based on sysctl. - */ -static struct net_protocol udp_protocol = { - .early_demux = udp_v4_early_demux, - .early_demux_handler = udp_v4_early_demux, - .handler = udp_rcv, - .err_handler = udp_err, - .no_policy = 1, -}; - static const struct net_protocol icmp_protocol = { .handler = icmp_rcv, .err_handler = icmp_err, @@ -1820,9 +1791,7 @@ static __net_init int inet_init_net(struct net *net) /* * Set defaults for local port range */ - seqlock_init(&net->ipv4.ip_local_ports.lock); - net->ipv4.ip_local_ports.range[0] = 32768; - net->ipv4.ip_local_ports.range[1] = 60999; + net->ipv4.ip_local_ports.range = 60999u << 16 | 32768u; seqlock_init(&net->ipv4.ping_group_range.lock); /* @@ -1873,14 +1842,6 @@ static int ipv4_proc_init(void); * IP protocol layer initialiser */ -static struct packet_offload ip_packet_offload __read_mostly = { - .type = cpu_to_be16(ETH_P_IP), - .callbacks = { - .gso_segment = inet_gso_segment, - .gro_receive = inet_gro_receive, - .gro_complete = inet_gro_complete, - }, -}; static const struct net_offload ipip_offload = { .callbacks = { @@ -1907,7 +1868,15 @@ static int __init ipv4_offload_init(void) if (ipip_offload_init() < 0) pr_crit("%s: Cannot add IPIP protocol offload\n", __func__); - dev_add_offload(&ip_packet_offload); + net_hotdata.ip_packet_offload = (struct packet_offload) { + .type = cpu_to_be16(ETH_P_IP), + .callbacks = { + .gso_segment = inet_gso_segment, + .gro_receive = inet_gro_receive, + .gro_complete = inet_gro_complete, + }, + }; + dev_add_offload(&net_hotdata.ip_packet_offload); return 0; } @@ -1927,6 +1896,8 @@ static int __init inet_init(void) sock_skb_cb_check_size(sizeof(struct inet_skb_parm)); + raw_hashinfo_init(&raw_v4_hashinfo); + rc = proto_register(&tcp_prot, 1); if (rc) goto out; @@ -1959,9 +1930,22 @@ static int __init inet_init(void) if (inet_add_protocol(&icmp_protocol, IPPROTO_ICMP) < 0) pr_crit("%s: Cannot add ICMP protocol\n", __func__); - if (inet_add_protocol(&udp_protocol, IPPROTO_UDP) < 0) + + net_hotdata.udp_protocol = (struct net_protocol) { + .handler = udp_rcv, + .err_handler = udp_err, + .no_policy = 1, + }; + if (inet_add_protocol(&net_hotdata.udp_protocol, IPPROTO_UDP) < 0) pr_crit("%s: Cannot add UDP protocol\n", __func__); - if (inet_add_protocol(&tcp_protocol, IPPROTO_TCP) < 0) + + net_hotdata.tcp_protocol = (struct net_protocol) { + .handler = tcp_v4_rcv, + .err_handler = tcp_v4_err, + .no_policy = 1, + .icmp_strict_tag_validation = 1, + }; + if (inet_add_protocol(&net_hotdata.tcp_protocol, IPPROTO_TCP) < 0) pr_crit("%s: Cannot add TCP protocol\n", __func__); #ifdef CONFIG_IP_MULTICAST if (inet_add_protocol(&igmp_protocol, IPPROTO_IGMP) < 0) diff --git a/net/ipv4/ah4.c b/net/ipv4/ah4.c index 6eea1e9e998d..64aec3dff8ec 100644 --- a/net/ipv4/ah4.c +++ b/net/ipv4/ah4.c @@ -1,8 +1,8 @@ // SPDX-License-Identifier: GPL-2.0-only #define pr_fmt(fmt) "IPsec: " fmt -#include <crypto/algapi.h> #include <crypto/hash.h> +#include <crypto/utils.h> #include <linux/err.h> #include <linux/module.h> #include <linux/slab.h> @@ -27,9 +27,7 @@ static void *ah_alloc_tmp(struct crypto_ahash *ahash, int nfrags, { unsigned int len; - len = size + crypto_ahash_digestsize(ahash) + - (crypto_ahash_alignmask(ahash) & - ~(crypto_tfm_ctx_alignment() - 1)); + len = size + crypto_ahash_digestsize(ahash); len = ALIGN(len, crypto_tfm_ctx_alignment()); @@ -46,10 +44,9 @@ static inline u8 *ah_tmp_auth(void *tmp, unsigned int offset) return tmp + offset; } -static inline u8 *ah_tmp_icv(struct crypto_ahash *ahash, void *tmp, - unsigned int offset) +static inline u8 *ah_tmp_icv(void *tmp, unsigned int offset) { - return PTR_ALIGN((u8 *)tmp + offset, crypto_ahash_alignmask(ahash) + 1); + return tmp + offset; } static inline struct ahash_request *ah_tmp_req(struct crypto_ahash *ahash, @@ -117,11 +114,11 @@ static int ip_clear_mutable_options(const struct iphdr *iph, __be32 *daddr) return 0; } -static void ah_output_done(struct crypto_async_request *base, int err) +static void ah_output_done(void *data, int err) { u8 *icv; struct iphdr *iph; - struct sk_buff *skb = base->data; + struct sk_buff *skb = data; struct xfrm_state *x = skb_dst(skb)->xfrm; struct ah_data *ahp = x->data; struct iphdr *top_iph = ip_hdr(skb); @@ -129,7 +126,7 @@ static void ah_output_done(struct crypto_async_request *base, int err) int ihl = ip_hdrlen(skb); iph = AH_SKB_CB(skb)->tmp; - icv = ah_tmp_icv(ahp->ahash, iph, ihl); + icv = ah_tmp_icv(iph, ihl); memcpy(ah->auth_data, icv, ahp->icv_trunc_len); top_iph->tos = iph->tos; @@ -182,7 +179,7 @@ static int ah_output(struct xfrm_state *x, struct sk_buff *skb) if (!iph) goto out; seqhi = (__be32 *)((char *)iph + ihl); - icv = ah_tmp_icv(ahash, seqhi, seqhi_len); + icv = ah_tmp_icv(seqhi, seqhi_len); req = ah_tmp_req(ahash, icv); sg = ah_req_sg(ahash, req); seqhisg = sg + nfrags; @@ -262,12 +259,12 @@ out: return err; } -static void ah_input_done(struct crypto_async_request *base, int err) +static void ah_input_done(void *data, int err) { u8 *auth_data; u8 *icv; struct iphdr *work_iph; - struct sk_buff *skb = base->data; + struct sk_buff *skb = data; struct xfrm_state *x = xfrm_input_state(skb); struct ah_data *ahp = x->data; struct ip_auth_hdr *ah = ip_auth_hdr(skb); @@ -279,7 +276,7 @@ static void ah_input_done(struct crypto_async_request *base, int err) work_iph = AH_SKB_CB(skb)->tmp; auth_data = ah_tmp_auth(work_iph, ihl); - icv = ah_tmp_icv(ahp->ahash, auth_data, ahp->icv_trunc_len); + icv = ah_tmp_icv(auth_data, ahp->icv_trunc_len); err = crypto_memneq(icv, auth_data, ahp->icv_trunc_len) ? -EBADMSG : 0; if (err) @@ -374,7 +371,7 @@ static int ah_input(struct xfrm_state *x, struct sk_buff *skb) seqhi = (__be32 *)((char *)work_iph + ihl); auth_data = ah_tmp_auth(seqhi, seqhi_len); - icv = ah_tmp_icv(ahash, auth_data, ahp->icv_trunc_len); + icv = ah_tmp_icv(auth_data, ahp->icv_trunc_len); req = ah_tmp_req(ahash, icv); sg = ah_req_sg(ahash, req); seqhisg = sg + nfrags; @@ -471,30 +468,38 @@ static int ah4_err(struct sk_buff *skb, u32 info) return 0; } -static int ah_init_state(struct xfrm_state *x) +static int ah_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) { struct ah_data *ahp = NULL; struct xfrm_algo_desc *aalg_desc; struct crypto_ahash *ahash; - if (!x->aalg) + if (!x->aalg) { + NL_SET_ERR_MSG(extack, "AH requires a state with an AUTH algorithm"); goto error; + } - if (x->encap) + if (x->encap) { + NL_SET_ERR_MSG(extack, "AH is not compatible with encapsulation"); goto error; + } ahp = kzalloc(sizeof(*ahp), GFP_KERNEL); if (!ahp) return -ENOMEM; ahash = crypto_alloc_ahash(x->aalg->alg_name, 0, 0); - if (IS_ERR(ahash)) + if (IS_ERR(ahash)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; + } ahp->ahash = ahash; if (crypto_ahash_setkey(ahash, x->aalg->alg_key, - (x->aalg->alg_key_len + 7) / 8)) + (x->aalg->alg_key_len + 7) / 8)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; + } /* * Lookup the algorithm description maintained by xfrm_algo, @@ -507,10 +512,7 @@ static int ah_init_state(struct xfrm_state *x) if (aalg_desc->uinfo.auth.icv_fullbits/8 != crypto_ahash_digestsize(ahash)) { - pr_info("%s: %s digestsize %u != %hu\n", - __func__, x->aalg->alg_name, - crypto_ahash_digestsize(ahash), - aalg_desc->uinfo.auth.icv_fullbits / 8); + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; } @@ -595,5 +597,6 @@ static void __exit ah4_fini(void) module_init(ah4_init); module_exit(ah4_fini); +MODULE_DESCRIPTION("IPv4 AH transformation library"); MODULE_LICENSE("GPL"); MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_AH); diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c index 4db0325f6e1a..7f3863daaa40 100644 --- a/net/ipv4/arp.c +++ b/net/ipv4/arp.c @@ -168,8 +168,9 @@ struct neigh_table arp_tbl = { [NEIGH_VAR_RETRANS_TIME] = 1 * HZ, [NEIGH_VAR_BASE_REACHABLE_TIME] = 30 * HZ, [NEIGH_VAR_DELAY_PROBE_TIME] = 5 * HZ, + [NEIGH_VAR_INTERVAL_PROBE_TIME_MS] = 5 * HZ, [NEIGH_VAR_GC_STALETIME] = 60 * HZ, - [NEIGH_VAR_QUEUE_LEN_BYTES] = SK_WMEM_MAX, + [NEIGH_VAR_QUEUE_LEN_BYTES] = SK_WMEM_DEFAULT, [NEIGH_VAR_PROXY_QLEN] = 64, [NEIGH_VAR_ANYCAST_DELAY] = 1 * HZ, [NEIGH_VAR_PROXY_DELAY] = (8 * HZ) / 10, @@ -293,7 +294,7 @@ static int arp_constructor(struct neighbour *neigh) static void arp_error_report(struct neighbour *neigh, struct sk_buff *skb) { dst_link_failure(skb); - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_FAILED); } /* Create and send an arp packet. */ @@ -374,7 +375,7 @@ static void arp_solicit(struct neighbour *neigh, struct sk_buff *skb) probes -= NEIGH_VAR(neigh->parms, UCAST_PROBES); if (probes < 0) { - if (!(neigh->nud_state & NUD_VALID)) + if (!(READ_ONCE(neigh->nud_state) & NUD_VALID)) pr_debug("trying to ucast probe in NUD_INVALID\n"); neigh_ha_snapshot(dst_ha, neigh, dev); dst_hw = dst_ha; @@ -428,6 +429,26 @@ static int arp_ignore(struct in_device *in_dev, __be32 sip, __be32 tip) return !inet_confirm_addr(net, in_dev, sip, tip, scope); } +static int arp_accept(struct in_device *in_dev, __be32 sip) +{ + struct net *net = dev_net(in_dev->dev); + int scope = RT_SCOPE_LINK; + + switch (IN_DEV_ARP_ACCEPT(in_dev)) { + case 0: /* Don't create new entries from garp */ + return 0; + case 1: /* Create new entries from garp */ + return 1; + case 2: /* Create a neighbor in the arp table only if sip + * is in the same subnet as an address configured + * on the interface that received the garp message + */ + return !!inet_confirm_addr(net, in_dev, sip, 0, scope); + default: + return 0; + } +} + static int arp_filter(__be32 sip, __be32 tip, struct net_device *dev) { struct rtable *rt; @@ -435,7 +456,8 @@ static int arp_filter(__be32 sip, __be32 tip, struct net_device *dev) /*unsigned long now; */ struct net *net = dev_net(dev); - rt = ip_route_output(net, sip, tip, 0, l3mdev_master_ifindex_rcu(dev)); + rt = ip_route_output(net, sip, tip, 0, l3mdev_master_ifindex_rcu(dev), + RT_SCOPE_UNIVERSE); if (IS_ERR(rt)) return 1; if (rt->dst.dev != dev) { @@ -637,10 +659,12 @@ static int arp_xmit_finish(struct net *net, struct sock *sk, struct sk_buff *skb */ void arp_xmit(struct sk_buff *skb) { + rcu_read_lock(); /* Send it off, maybe filter it using firewalling first. */ NF_HOOK(NFPROTO_ARP, NF_ARP_OUT, - dev_net(skb->dev), NULL, skb, NULL, skb->dev, + dev_net_rcu(skb->dev), NULL, skb, NULL, skb->dev, arp_xmit_finish); + rcu_read_unlock(); } EXPORT_SYMBOL(arp_xmit); @@ -840,7 +864,7 @@ static int arp_process(struct net *net, struct sock *sk, struct sk_buff *skb) (arp_fwd_proxy(in_dev, dev, rt) || arp_fwd_pvlan(in_dev, dev, rt, sip, tip) || (rt->dst.dev != dev && - pneigh_lookup(&arp_tbl, net, &tip, dev, 0)))) { + pneigh_lookup(&arp_tbl, net, &tip, dev)))) { n = neigh_event_ns(&arp_tbl, sha, &sip, dev); if (n) neigh_release(n); @@ -867,12 +891,12 @@ static int arp_process(struct net *net, struct sock *sk, struct sk_buff *skb) n = __neigh_lookup(&arp_tbl, &sip, dev, 0); addr_type = -1; - if (n || IN_DEV_ARP_ACCEPT(in_dev)) { + if (n || arp_accept(in_dev, sip)) { is_garp = arp_is_garp(net, dev, &addr_type, arp->ar_op, sip, tip, sha, tha); } - if (IN_DEV_ARP_ACCEPT(in_dev)) { + if (arp_accept(in_dev, sip)) { /* Unsolicited ARP is not accepted by default. It is possible, that this option should be enabled for some devices (strip is candidate) @@ -942,6 +966,7 @@ static int arp_is_multicast(const void *pkey) static int arp_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) { + enum skb_drop_reason drop_reason; const struct arphdr *arp; /* do not tweak dropwatch on an ARP we will ignore */ @@ -955,12 +980,15 @@ static int arp_rcv(struct sk_buff *skb, struct net_device *dev, goto out_of_mem; /* ARP header, plus 2 device addresses, plus 2 IP addresses. */ - if (!pskb_may_pull(skb, arp_hdr_len(dev))) + drop_reason = pskb_may_pull_reason(skb, arp_hdr_len(dev)); + if (drop_reason != SKB_NOT_DROPPED_YET) goto freeskb; arp = arp_hdr(skb); - if (arp->ar_hln != dev->addr_len || arp->ar_pln != 4) + if (arp->ar_hln != dev->addr_len || arp->ar_pln != 4) { + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; goto freeskb; + } memset(NEIGH_CB(skb), 0, sizeof(struct neighbour_cb)); @@ -972,7 +1000,7 @@ consumeskb: consume_skb(skb); return NET_RX_SUCCESS; freeskb: - kfree_skb(skb); + kfree_skb_reason(skb, drop_reason); out_of_mem: return NET_RX_DROP; } @@ -981,6 +1009,55 @@ out_of_mem: * User level interface (ioctl) */ +static struct net_device *arp_req_dev_by_name(struct net *net, struct arpreq *r, + bool getarp) +{ + struct net_device *dev; + + if (getarp) + dev = dev_get_by_name_rcu(net, r->arp_dev); + else + dev = __dev_get_by_name(net, r->arp_dev); + if (!dev) + return ERR_PTR(-ENODEV); + + /* Mmmm... It is wrong... ARPHRD_NETROM == 0 */ + if (!r->arp_ha.sa_family) + r->arp_ha.sa_family = dev->type; + + if ((r->arp_flags & ATF_COM) && r->arp_ha.sa_family != dev->type) + return ERR_PTR(-EINVAL); + + return dev; +} + +static struct net_device *arp_req_dev(struct net *net, struct arpreq *r) +{ + struct net_device *dev; + struct rtable *rt; + __be32 ip; + + if (r->arp_dev[0]) + return arp_req_dev_by_name(net, r, false); + + if (r->arp_flags & ATF_PUBL) + return NULL; + + ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; + + rt = ip_route_output(net, ip, 0, 0, 0, RT_SCOPE_LINK); + if (IS_ERR(rt)) + return ERR_CAST(rt); + + dev = rt->dst.dev; + ip_rt_put(rt); + + if (!dev) + return ERR_PTR(-EINVAL); + + return dev; +} + /* * Set (create) an ARP cache entry. */ @@ -991,8 +1068,8 @@ static int arp_req_set_proxy(struct net *net, struct net_device *dev, int on) IPV4_DEVCONF_ALL(net, PROXY_ARP) = on; return 0; } - if (__in_dev_get_rtnl(dev)) { - IN_DEV_CONF_SET(__in_dev_get_rtnl(dev), PROXY_ARP, on); + if (__in_dev_get_rtnl_net(dev)) { + IN_DEV_CONF_SET(__in_dev_get_rtnl_net(dev), PROXY_ARP, on); return 0; } return -ENXIO; @@ -1001,49 +1078,37 @@ static int arp_req_set_proxy(struct net *net, struct net_device *dev, int on) static int arp_req_set_public(struct net *net, struct arpreq *r, struct net_device *dev) { - __be32 ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; __be32 mask = ((struct sockaddr_in *)&r->arp_netmask)->sin_addr.s_addr; - if (mask && mask != htonl(0xFFFFFFFF)) - return -EINVAL; if (!dev && (r->arp_flags & ATF_COM)) { - dev = dev_getbyhwaddr_rcu(net, r->arp_ha.sa_family, + dev = dev_getbyhwaddr(net, r->arp_ha.sa_family, r->arp_ha.sa_data); if (!dev) return -ENODEV; } if (mask) { - if (!pneigh_lookup(&arp_tbl, net, &ip, dev, 1)) - return -ENOBUFS; - return 0; + __be32 ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; + + return pneigh_create(&arp_tbl, net, &ip, dev, 0, 0, false); } return arp_req_set_proxy(net, dev, 1); } -static int arp_req_set(struct net *net, struct arpreq *r, - struct net_device *dev) +static int arp_req_set(struct net *net, struct arpreq *r) { - __be32 ip; struct neighbour *neigh; + struct net_device *dev; + __be32 ip; int err; + dev = arp_req_dev(net, r); + if (IS_ERR(dev)) + return PTR_ERR(dev); + if (r->arp_flags & ATF_PUBL) return arp_req_set_public(net, r, dev); - ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; - if (r->arp_flags & ATF_PERM) - r->arp_flags |= ATF_COM; - if (!dev) { - struct rtable *rt = ip_route_output(net, ip, 0, RTO_ONLINK, 0); - - if (IS_ERR(rt)) - return PTR_ERR(rt); - dev = rt->dst.dev; - ip_rt_put(rt); - if (!dev) - return -EINVAL; - } switch (dev->type) { #if IS_ENABLED(CONFIG_FDDI) case ARPHRD_FDDI: @@ -1065,12 +1130,18 @@ static int arp_req_set(struct net *net, struct arpreq *r, break; } + ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; + neigh = __neigh_lookup_errno(&arp_tbl, &ip, dev); err = PTR_ERR(neigh); if (!IS_ERR(neigh)) { unsigned int state = NUD_STALE; - if (r->arp_flags & ATF_PERM) + + if (r->arp_flags & ATF_PERM) { + r->arp_flags |= ATF_COM; state = NUD_PERMANENT; + } + err = neigh_update(neigh, (r->arp_flags & ATF_COM) ? r->arp_ha.sa_data : NULL, state, NEIGH_UPDATE_F_OVERRIDE | @@ -1094,43 +1165,62 @@ static unsigned int arp_state_to_flags(struct neighbour *neigh) * Get an ARP cache entry. */ -static int arp_req_get(struct arpreq *r, struct net_device *dev) +static int arp_req_get(struct net *net, struct arpreq *r) { __be32 ip = ((struct sockaddr_in *) &r->arp_pa)->sin_addr.s_addr; struct neighbour *neigh; - int err = -ENXIO; + struct net_device *dev; + + if (!r->arp_dev[0]) + return -ENODEV; + + dev = arp_req_dev_by_name(net, r, true); + if (IS_ERR(dev)) + return PTR_ERR(dev); neigh = neigh_lookup(&arp_tbl, &ip, dev); - if (neigh) { - if (!(neigh->nud_state & NUD_NOARP)) { - read_lock_bh(&neigh->lock); - memcpy(r->arp_ha.sa_data, neigh->ha, dev->addr_len); - r->arp_flags = arp_state_to_flags(neigh); - read_unlock_bh(&neigh->lock); - r->arp_ha.sa_family = dev->type; - strlcpy(r->arp_dev, dev->name, sizeof(r->arp_dev)); - err = 0; - } + if (!neigh) + return -ENXIO; + + if (READ_ONCE(neigh->nud_state) & NUD_NOARP) { neigh_release(neigh); + return -ENXIO; } - return err; + + read_lock_bh(&neigh->lock); + memcpy(r->arp_ha.sa_data, neigh->ha, + min(dev->addr_len, sizeof(r->arp_ha.sa_data))); + r->arp_flags = arp_state_to_flags(neigh); + read_unlock_bh(&neigh->lock); + + neigh_release(neigh); + + r->arp_ha.sa_family = dev->type; + netdev_copy_name(dev, r->arp_dev); + + return 0; } -static int arp_invalidate(struct net_device *dev, __be32 ip) +int arp_invalidate(struct net_device *dev, __be32 ip, bool force) { struct neighbour *neigh = neigh_lookup(&arp_tbl, &ip, dev); int err = -ENXIO; struct neigh_table *tbl = &arp_tbl; if (neigh) { - if (neigh->nud_state & ~NUD_NOARP) + if ((READ_ONCE(neigh->nud_state) & NUD_VALID) && !force) { + neigh_release(neigh); + return 0; + } + + if (READ_ONCE(neigh->nud_state) & ~NUD_NOARP) err = neigh_update(neigh, NULL, NUD_FAILED, NEIGH_UPDATE_F_OVERRIDE| NEIGH_UPDATE_F_ADMIN, 0); - write_lock_bh(&tbl->lock); + spin_lock_bh(&tbl->lock); neigh_release(neigh); - neigh_remove_one(neigh, tbl); - write_unlock_bh(&tbl->lock); + neigh_remove_one(neigh); + spin_unlock_bh(&tbl->lock); } return err; @@ -1139,37 +1229,32 @@ static int arp_invalidate(struct net_device *dev, __be32 ip) static int arp_req_delete_public(struct net *net, struct arpreq *r, struct net_device *dev) { - __be32 ip = ((struct sockaddr_in *) &r->arp_pa)->sin_addr.s_addr; __be32 mask = ((struct sockaddr_in *)&r->arp_netmask)->sin_addr.s_addr; - if (mask == htonl(0xFFFFFFFF)) - return pneigh_delete(&arp_tbl, net, &ip, dev); + if (mask) { + __be32 ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; - if (mask) - return -EINVAL; + return pneigh_delete(&arp_tbl, net, &ip, dev); + } return arp_req_set_proxy(net, dev, 0); } -static int arp_req_delete(struct net *net, struct arpreq *r, - struct net_device *dev) +static int arp_req_delete(struct net *net, struct arpreq *r) { + struct net_device *dev; __be32 ip; + dev = arp_req_dev(net, r); + if (IS_ERR(dev)) + return PTR_ERR(dev); + if (r->arp_flags & ATF_PUBL) return arp_req_delete_public(net, r, dev); ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; - if (!dev) { - struct rtable *rt = ip_route_output(net, ip, 0, RTO_ONLINK, 0); - if (IS_ERR(rt)) - return PTR_ERR(rt); - dev = rt->dst.dev; - ip_rt_put(rt); - if (!dev) - return -EINVAL; - } - return arp_invalidate(dev, ip); + + return arp_invalidate(dev, ip, true); } /* @@ -1178,9 +1263,9 @@ static int arp_req_delete(struct net *net, struct arpreq *r, int arp_ioctl(struct net *net, unsigned int cmd, void __user *arg) { - int err; struct arpreq r; - struct net_device *dev = NULL; + __be32 *netmask; + int err; switch (cmd) { case SIOCDARP: @@ -1203,42 +1288,34 @@ int arp_ioctl(struct net *net, unsigned int cmd, void __user *arg) if (!(r.arp_flags & ATF_PUBL) && (r.arp_flags & (ATF_NETMASK | ATF_DONTPUB))) return -EINVAL; + + netmask = &((struct sockaddr_in *)&r.arp_netmask)->sin_addr.s_addr; if (!(r.arp_flags & ATF_NETMASK)) - ((struct sockaddr_in *)&r.arp_netmask)->sin_addr.s_addr = - htonl(0xFFFFFFFFUL); - rtnl_lock(); - if (r.arp_dev[0]) { - err = -ENODEV; - dev = __dev_get_by_name(net, r.arp_dev); - if (!dev) - goto out; - - /* Mmmm... It is wrong... ARPHRD_NETROM==0 */ - if (!r.arp_ha.sa_family) - r.arp_ha.sa_family = dev->type; - err = -EINVAL; - if ((r.arp_flags & ATF_COM) && r.arp_ha.sa_family != dev->type) - goto out; - } else if (cmd == SIOCGARP) { - err = -ENODEV; - goto out; - } + *netmask = htonl(0xFFFFFFFFUL); + else if (*netmask && *netmask != htonl(0xFFFFFFFFUL)) + return -EINVAL; switch (cmd) { case SIOCDARP: - err = arp_req_delete(net, &r, dev); + rtnl_net_lock(net); + err = arp_req_delete(net, &r); + rtnl_net_unlock(net); break; case SIOCSARP: - err = arp_req_set(net, &r, dev); + rtnl_net_lock(net); + err = arp_req_set(net, &r); + rtnl_net_unlock(net); break; case SIOCGARP: - err = arp_req_get(&r, dev); + rcu_read_lock(); + err = arp_req_get(net, &r); + rcu_read_unlock(); + + if (!err && copy_to_user(arg, &r, sizeof(r))) + err = -EFAULT; break; } -out: - rtnl_unlock(); - if (cmd == SIOCGARP && !err && copy_to_user(arg, &r, sizeof(r))) - err = -EFAULT; + return err; } @@ -1299,9 +1376,9 @@ static struct packet_type arp_packet_type __read_mostly = { .func = arp_rcv, }; +#ifdef CONFIG_PROC_FS #if IS_ENABLED(CONFIG_AX25) -/* ------------------------------------------------------------------------ */ /* * ax25 -> ASCII conversion */ @@ -1407,16 +1484,13 @@ static void *arp_seq_start(struct seq_file *seq, loff_t *pos) return neigh_seq_start(seq, pos, &arp_tbl, NEIGH_SEQ_SKIP_NOARP); } -/* ------------------------------------------------------------------------ */ - static const struct seq_operations arp_seq_ops = { .start = arp_seq_start, .next = neigh_seq_next, .stop = neigh_seq_stop, .show = arp_seq_show, }; - -/* ------------------------------------------------------------------------ */ +#endif /* CONFIG_PROC_FS */ static int __net_init arp_net_init(struct net *net) { diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index de610cb83694..e01492234b0b 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -1,6 +1,7 @@ // SPDX-License-Identifier: GPL-2.0 /* Copyright (c) 2019 Facebook */ +#include <linux/init.h> #include <linux/types.h> #include <linux/bpf_verifier.h> #include <linux/bpf.h> @@ -11,26 +12,11 @@ #include <net/bpf_sk_storage.h> /* "extern" is to avoid sparse warning. It is only used in bpf_struct_ops.c. */ -extern struct bpf_struct_ops bpf_tcp_congestion_ops; - -static u32 optional_ops[] = { - offsetof(struct tcp_congestion_ops, init), - offsetof(struct tcp_congestion_ops, release), - offsetof(struct tcp_congestion_ops, set_state), - offsetof(struct tcp_congestion_ops, cwnd_event), - offsetof(struct tcp_congestion_ops, in_ack_event), - offsetof(struct tcp_congestion_ops, pkts_acked), - offsetof(struct tcp_congestion_ops, min_tso_segs), - offsetof(struct tcp_congestion_ops, sndbuf_expand), - offsetof(struct tcp_congestion_ops, cong_control), -}; - -static u32 unsupported_ops[] = { - offsetof(struct tcp_congestion_ops, get_info), -}; +static struct bpf_struct_ops bpf_tcp_congestion_ops; static const struct btf_type *tcp_sock_type; static u32 tcp_sock_id, sock_id; +static const struct btf_type *tcp_congestion_ops_type; static int bpf_tcp_ca_init(struct btf *btf) { @@ -47,35 +33,14 @@ static int bpf_tcp_ca_init(struct btf *btf) tcp_sock_id = type_id; tcp_sock_type = btf_type_by_id(btf, tcp_sock_id); - return 0; -} - -static bool is_optional(u32 member_offset) -{ - unsigned int i; - - for (i = 0; i < ARRAY_SIZE(optional_ops); i++) { - if (member_offset == optional_ops[i]) - return true; - } - - return false; -} - -static bool is_unsupported(u32 member_offset) -{ - unsigned int i; - - for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) { - if (member_offset == unsupported_ops[i]) - return true; - } + type_id = btf_find_by_name_kind(btf, "tcp_congestion_ops", BTF_KIND_STRUCT); + if (type_id < 0) + return -EINVAL; + tcp_congestion_ops_type = btf_type_by_id(btf, type_id); - return false; + return 0; } -extern struct btf *btf_vmlinux; - static bool bpf_tcp_ca_is_valid_access(int off, int size, enum bpf_access_type type, const struct bpf_prog *prog, @@ -84,7 +49,9 @@ static bool bpf_tcp_ca_is_valid_access(int off, int size, if (!bpf_tracing_btf_ctx_access(off, size, type, prog, info)) return false; - if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id) + if (base_type(info->reg_type) == PTR_TO_BTF_ID && + !bpf_type_has_unsafe_modifiers(info->reg_type) && + info->btf_id == sock_id) /* promote it to tcp_sock */ info->btf_id = tcp_sock_id; @@ -92,22 +59,25 @@ static bool bpf_tcp_ca_is_valid_access(int off, int size, } static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log, - const struct btf *btf, - const struct btf_type *t, int off, - int size, enum bpf_access_type atype, - u32 *next_btf_id) + const struct bpf_reg_state *reg, + int off, int size) { + const struct btf_type *t; size_t end; - if (atype == BPF_READ) - return btf_struct_access(log, btf, t, off, size, atype, next_btf_id); - + t = btf_type_by_id(reg->btf, reg->btf_id); if (t != tcp_sock_type) { bpf_log(log, "only read is supported\n"); return -EACCES; } switch (off) { + case offsetof(struct sock, sk_pacing_rate): + end = offsetofend(struct sock, sk_pacing_rate); + break; + case offsetof(struct sock, sk_pacing_status): + end = offsetofend(struct sock, sk_pacing_status); + break; case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv): end = offsetofend(struct inet_connection_sock, icsk_ca_priv); break; @@ -121,12 +91,18 @@ static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log, case offsetof(struct tcp_sock, snd_cwnd_cnt): end = offsetofend(struct tcp_sock, snd_cwnd_cnt); break; + case offsetof(struct tcp_sock, snd_cwnd_stamp): + end = offsetofend(struct tcp_sock, snd_cwnd_stamp); + break; case offsetof(struct tcp_sock, snd_ssthresh): end = offsetofend(struct tcp_sock, snd_ssthresh); break; case offsetof(struct tcp_sock, ecn_flags): end = offsetofend(struct tcp_sock, ecn_flags); break; + case offsetof(struct tcp_sock, app_limited): + end = offsetofend(struct tcp_sock, app_limited); + break; default: bpf_log(log, "no write support to tcp_sock at off %d\n", off); return -EACCES; @@ -139,13 +115,13 @@ static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log, return -EACCES; } - return NOT_INIT; + return 0; } BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt) { /* bpf_tcp_ca prog cannot have NULL tp */ - __tcp_send_ack((struct sock *)tp, rcv_nxt); + __tcp_send_ack((struct sock *)tp, rcv_nxt, 0); return 0; } @@ -166,7 +142,7 @@ static u32 prog_ops_moff(const struct bpf_prog *prog) u32 midx; midx = prog->expected_attach_type; - t = bpf_tcp_congestion_ops.type; + t = tcp_congestion_ops_type; m = &btf_type_member(t)[midx]; return __btf_member_bit_offset(t, m) / 8; @@ -208,30 +184,27 @@ bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id, case BPF_FUNC_ktime_get_coarse_ns: return &bpf_ktime_get_coarse_ns_proto; default: - return bpf_base_func_proto(func_id); + return bpf_base_func_proto(func_id, prog); } } -BTF_SET_START(bpf_tcp_ca_kfunc_ids) -BTF_ID(func, tcp_reno_ssthresh) -BTF_ID(func, tcp_reno_cong_avoid) -BTF_ID(func, tcp_reno_undo_cwnd) -BTF_ID(func, tcp_slow_start) -BTF_ID(func, tcp_cong_avoid_ai) -BTF_SET_END(bpf_tcp_ca_kfunc_ids) - -static bool bpf_tcp_ca_check_kfunc_call(u32 kfunc_btf_id, struct module *owner) -{ - if (btf_id_set_contains(&bpf_tcp_ca_kfunc_ids, kfunc_btf_id)) - return true; - return bpf_check_mod_kfunc_call(&bpf_tcp_ca_kfunc_list, kfunc_btf_id, owner); -} +BTF_KFUNCS_START(bpf_tcp_ca_check_kfunc_ids) +BTF_ID_FLAGS(func, tcp_reno_ssthresh) +BTF_ID_FLAGS(func, tcp_reno_cong_avoid) +BTF_ID_FLAGS(func, tcp_reno_undo_cwnd) +BTF_ID_FLAGS(func, tcp_slow_start) +BTF_ID_FLAGS(func, tcp_cong_avoid_ai) +BTF_KFUNCS_END(bpf_tcp_ca_check_kfunc_ids) + +static const struct btf_kfunc_id_set bpf_tcp_ca_kfunc_set = { + .owner = THIS_MODULE, + .set = &bpf_tcp_ca_check_kfunc_ids, +}; static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = { .get_func_proto = bpf_tcp_ca_get_func_proto, .is_valid_access = bpf_tcp_ca_is_valid_access, .btf_struct_access = bpf_tcp_ca_btf_struct_access, - .check_kfunc_call = bpf_tcp_ca_check_kfunc_call, }; static int bpf_tcp_ca_init_member(const struct btf_type *t, @@ -240,7 +213,6 @@ static int bpf_tcp_ca_init_member(const struct btf_type *t, { const struct tcp_congestion_ops *utcp_ca; struct tcp_congestion_ops *tcp_ca; - int prog_fd; u32 moff; utcp_ca = (const struct tcp_congestion_ops *)udata; @@ -257,46 +229,121 @@ static int bpf_tcp_ca_init_member(const struct btf_type *t, if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name, sizeof(tcp_ca->name)) <= 0) return -EINVAL; - if (tcp_ca_find(utcp_ca->name)) - return -EEXIST; return 1; } - if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL)) - return 0; + return 0; +} + +static int bpf_tcp_ca_reg(void *kdata, struct bpf_link *link) +{ + return tcp_register_congestion_control(kdata); +} + +static void bpf_tcp_ca_unreg(void *kdata, struct bpf_link *link) +{ + tcp_unregister_congestion_control(kdata); +} - /* Ensure bpf_prog is provided for compulsory func ptr */ - prog_fd = (int)(*(unsigned long *)(udata + moff)); - if (!prog_fd && !is_optional(moff) && !is_unsupported(moff)) - return -EINVAL; +static int bpf_tcp_ca_update(void *kdata, void *old_kdata, struct bpf_link *link) +{ + return tcp_update_congestion_control(kdata, old_kdata); +} + +static int bpf_tcp_ca_validate(void *kdata) +{ + return tcp_validate_congestion_control(kdata); +} +static u32 bpf_tcp_ca_ssthresh(struct sock *sk) +{ return 0; } -static int bpf_tcp_ca_check_member(const struct btf_type *t, - const struct btf_member *member) +static void bpf_tcp_ca_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ +} + +static void bpf_tcp_ca_set_state(struct sock *sk, u8 new_state) +{ +} + +static void bpf_tcp_ca_cwnd_event(struct sock *sk, enum tcp_ca_event ev) +{ +} + +static void bpf_tcp_ca_in_ack_event(struct sock *sk, u32 flags) +{ +} + +static void bpf_tcp_ca_pkts_acked(struct sock *sk, const struct ack_sample *sample) +{ +} + +static u32 bpf_tcp_ca_min_tso_segs(struct sock *sk) { - if (is_unsupported(__btf_member_bit_offset(t, member) / 8)) - return -ENOTSUPP; return 0; } -static int bpf_tcp_ca_reg(void *kdata) +static void bpf_tcp_ca_cong_control(struct sock *sk, u32 ack, int flag, + const struct rate_sample *rs) { - return tcp_register_congestion_control(kdata); } -static void bpf_tcp_ca_unreg(void *kdata) +static u32 bpf_tcp_ca_undo_cwnd(struct sock *sk) +{ + return 0; +} + +static u32 bpf_tcp_ca_sndbuf_expand(struct sock *sk) +{ + return 0; +} + +static void __bpf_tcp_ca_init(struct sock *sk) { - tcp_unregister_congestion_control(kdata); } -struct bpf_struct_ops bpf_tcp_congestion_ops = { +static void __bpf_tcp_ca_release(struct sock *sk) +{ +} + +static struct tcp_congestion_ops __bpf_ops_tcp_congestion_ops = { + .ssthresh = bpf_tcp_ca_ssthresh, + .cong_avoid = bpf_tcp_ca_cong_avoid, + .set_state = bpf_tcp_ca_set_state, + .cwnd_event = bpf_tcp_ca_cwnd_event, + .in_ack_event = bpf_tcp_ca_in_ack_event, + .pkts_acked = bpf_tcp_ca_pkts_acked, + .min_tso_segs = bpf_tcp_ca_min_tso_segs, + .cong_control = bpf_tcp_ca_cong_control, + .undo_cwnd = bpf_tcp_ca_undo_cwnd, + .sndbuf_expand = bpf_tcp_ca_sndbuf_expand, + + .init = __bpf_tcp_ca_init, + .release = __bpf_tcp_ca_release, +}; + +static struct bpf_struct_ops bpf_tcp_congestion_ops = { .verifier_ops = &bpf_tcp_ca_verifier_ops, .reg = bpf_tcp_ca_reg, .unreg = bpf_tcp_ca_unreg, - .check_member = bpf_tcp_ca_check_member, + .update = bpf_tcp_ca_update, .init_member = bpf_tcp_ca_init_member, .init = bpf_tcp_ca_init, + .validate = bpf_tcp_ca_validate, .name = "tcp_congestion_ops", + .cfi_stubs = &__bpf_ops_tcp_congestion_ops, + .owner = THIS_MODULE, }; + +static int __init bpf_tcp_ca_kfunc_init(void) +{ + int ret; + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &bpf_tcp_ca_kfunc_set); + ret = ret ?: register_bpf_struct_ops(&bpf_tcp_congestion_ops, tcp_congestion_ops); + + return ret; +} +late_initcall(bpf_tcp_ca_kfunc_init); diff --git a/net/ipv4/bpfilter/Makefile b/net/ipv4/bpfilter/Makefile deleted file mode 100644 index 00af5305e05a..000000000000 --- a/net/ipv4/bpfilter/Makefile +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-License-Identifier: GPL-2.0-only -obj-$(CONFIG_BPFILTER) += sockopt.o diff --git a/net/ipv4/bpfilter/sockopt.c b/net/ipv4/bpfilter/sockopt.c deleted file mode 100644 index 1b34cb9a7708..000000000000 --- a/net/ipv4/bpfilter/sockopt.c +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0 -#include <linux/init.h> -#include <linux/module.h> -#include <linux/uaccess.h> -#include <linux/bpfilter.h> -#include <uapi/linux/bpf.h> -#include <linux/wait.h> -#include <linux/kmod.h> -#include <linux/fs.h> -#include <linux/file.h> - -struct bpfilter_umh_ops bpfilter_ops; -EXPORT_SYMBOL_GPL(bpfilter_ops); - -void bpfilter_umh_cleanup(struct umd_info *info) -{ - fput(info->pipe_to_umh); - fput(info->pipe_from_umh); - put_pid(info->tgid); - info->tgid = NULL; -} -EXPORT_SYMBOL_GPL(bpfilter_umh_cleanup); - -static int bpfilter_mbox_request(struct sock *sk, int optname, sockptr_t optval, - unsigned int optlen, bool is_set) -{ - int err; - mutex_lock(&bpfilter_ops.lock); - if (!bpfilter_ops.sockopt) { - mutex_unlock(&bpfilter_ops.lock); - request_module("bpfilter"); - mutex_lock(&bpfilter_ops.lock); - - if (!bpfilter_ops.sockopt) { - err = -ENOPROTOOPT; - goto out; - } - } - if (bpfilter_ops.info.tgid && - thread_group_exited(bpfilter_ops.info.tgid)) - bpfilter_umh_cleanup(&bpfilter_ops.info); - - if (!bpfilter_ops.info.tgid) { - err = bpfilter_ops.start(); - if (err) - goto out; - } - err = bpfilter_ops.sockopt(sk, optname, optval, optlen, is_set); -out: - mutex_unlock(&bpfilter_ops.lock); - return err; -} - -int bpfilter_ip_set_sockopt(struct sock *sk, int optname, sockptr_t optval, - unsigned int optlen) -{ - return bpfilter_mbox_request(sk, optname, optval, optlen, true); -} - -int bpfilter_ip_get_sockopt(struct sock *sk, int optname, char __user *optval, - int __user *optlen) -{ - int len; - - if (get_user(len, optlen)) - return -EFAULT; - - return bpfilter_mbox_request(sk, optname, USER_SOCKPTR(optval), len, - false); -} - -static int __init bpfilter_sockopt_init(void) -{ - mutex_init(&bpfilter_ops.lock); - bpfilter_ops.info.tgid = NULL; - bpfilter_ops.info.driver_name = "bpfilter_umh"; - - return 0; -} -device_initcall(bpfilter_sockopt_init); diff --git a/net/ipv4/cipso_ipv4.c b/net/ipv4/cipso_ipv4.c index 62d5f99760aa..709021197e1c 100644 --- a/net/ipv4/cipso_ipv4.c +++ b/net/ipv4/cipso_ipv4.c @@ -37,7 +37,7 @@ #include <net/cipso_ipv4.h> #include <linux/atomic.h> #include <linux/bug.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> /* List of available DOI definitions */ /* XXX - This currently assumes a minimal number of different DOIs in use, @@ -239,7 +239,7 @@ static int cipso_v4_cache_check(const unsigned char *key, struct cipso_v4_map_cache_entry *prev_entry = NULL; u32 hash; - if (!cipso_v4_cache_enabled) + if (!READ_ONCE(cipso_v4_cache_enabled)) return -ENOENT; hash = cipso_v4_map_cache_hash(key, key_len); @@ -296,13 +296,14 @@ static int cipso_v4_cache_check(const unsigned char *key, int cipso_v4_cache_add(const unsigned char *cipso_ptr, const struct netlbl_lsm_secattr *secattr) { + int bkt_size = READ_ONCE(cipso_v4_cache_bucketsize); int ret_val = -EPERM; u32 bkt; struct cipso_v4_map_cache_entry *entry = NULL; struct cipso_v4_map_cache_entry *old_entry = NULL; u32 cipso_ptr_len; - if (!cipso_v4_cache_enabled || cipso_v4_cache_bucketsize <= 0) + if (!READ_ONCE(cipso_v4_cache_enabled) || bkt_size <= 0) return 0; cipso_ptr_len = cipso_ptr[1]; @@ -322,7 +323,7 @@ int cipso_v4_cache_add(const unsigned char *cipso_ptr, bkt = entry->hash & (CIPSO_V4_CACHE_BUCKETS - 1); spin_lock_bh(&cipso_v4_cache[bkt].lock); - if (cipso_v4_cache[bkt].size < cipso_v4_cache_bucketsize) { + if (cipso_v4_cache[bkt].size < bkt_size) { list_add(&entry->list, &cipso_v4_cache[bkt].list); cipso_v4_cache[bkt].size += 1; } else { @@ -863,11 +864,8 @@ static int cipso_v4_map_cat_rbm_ntoh(const struct cipso_v4_doi *doi_def, net_clen_bits, net_spot + 1, 1); - if (net_spot < 0) { - if (net_spot == -2) - return -EFAULT; + if (net_spot < 0) return 0; - } switch (doi_def->type) { case CIPSO_V4_MAP_PASS: @@ -1199,7 +1197,8 @@ static int cipso_v4_gentag_rbm(const struct cipso_v4_doi *doi_def, /* This will send packets using the "optimized" format when * possible as specified in section 3.4.2.6 of the * CIPSO draft. */ - if (cipso_v4_rbm_optfmt && ret_val > 0 && ret_val <= 10) + if (READ_ONCE(cipso_v4_rbm_optfmt) && ret_val > 0 && + ret_val <= 10) tag_len = 14; else tag_len = 4 + ret_val; @@ -1603,7 +1602,7 @@ int cipso_v4_validate(const struct sk_buff *skb, unsigned char **option) * all the CIPSO validations here but it doesn't * really specify _exactly_ what we need to validate * ... so, just make it a sysctl tunable. */ - if (cipso_v4_rbm_strictvalid) { + if (READ_ONCE(cipso_v4_rbm_strictvalid)) { if (cipso_v4_map_lvl_valid(doi_def, tag[3]) < 0) { err_offset = opt_iter + 3; @@ -1716,8 +1715,7 @@ validate_return: */ void cipso_v4_error(struct sk_buff *skb, int error, u32 gateway) { - unsigned char optbuf[sizeof(struct ip_options) + 40]; - struct ip_options *opt = (struct ip_options *)optbuf; + struct inet_skb_parm parm; int res; if (ip_hdr(skb)->protocol == IPPROTO_ICMP || error != -EACCES) @@ -1728,19 +1726,19 @@ void cipso_v4_error(struct sk_buff *skb, int error, u32 gateway) * so we can not use icmp_send and IPCB here. */ - memset(opt, 0, sizeof(struct ip_options)); - opt->optlen = ip_hdr(skb)->ihl*4 - sizeof(struct iphdr); + memset(&parm, 0, sizeof(parm)); + parm.opt.optlen = ip_hdr(skb)->ihl * 4 - sizeof(struct iphdr); rcu_read_lock(); - res = __ip_options_compile(dev_net(skb->dev), opt, skb, NULL); + res = __ip_options_compile(dev_net(skb->dev), &parm.opt, skb, NULL); rcu_read_unlock(); if (res) return; if (gateway) - __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_NET_ANO, 0, opt); + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_NET_ANO, 0, &parm); else - __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_ANO, 0, opt); + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_ANO, 0, &parm); } /** @@ -1811,11 +1809,35 @@ static int cipso_v4_genopt(unsigned char *buf, u32 buf_len, return CIPSO_V4_HDR_LEN + ret_val; } +static int cipso_v4_get_actual_opt_len(const unsigned char *data, int len) +{ + int iter = 0, optlen = 0; + + /* determining the new total option length is tricky because of + * the padding necessary, the only thing i can think to do at + * this point is walk the options one-by-one, skipping the + * padding at the end to determine the actual option size and + * from there we can determine the new total option length + */ + while (iter < len) { + if (data[iter] == IPOPT_END) { + break; + } else if (data[iter] == IPOPT_NOP) { + iter++; + } else { + iter += data[iter + 1]; + optlen = iter; + } + } + return optlen; +} + /** * cipso_v4_sock_setattr - Add a CIPSO option to a socket * @sk: the socket * @doi_def: the CIPSO DOI to use * @secattr: the specific security attributes of the socket + * @sk_locked: true if caller holds the socket lock * * Description: * Set the CIPSO option on the given socket using the DOI definition and @@ -1827,7 +1849,8 @@ static int cipso_v4_genopt(unsigned char *buf, u32 buf_len, */ int cipso_v4_sock_setattr(struct sock *sk, const struct cipso_v4_doi *doi_def, - const struct netlbl_lsm_secattr *secattr) + const struct netlbl_lsm_secattr *secattr, + bool sk_locked) { int ret_val = -EPERM; unsigned char *buf = NULL; @@ -1877,9 +1900,8 @@ int cipso_v4_sock_setattr(struct sock *sk, sk_inet = inet_sk(sk); - old = rcu_dereference_protected(sk_inet->inet_opt, - lockdep_sock_is_held(sk)); - if (sk_inet->is_icsk) { + old = rcu_dereference_protected(sk_inet->inet_opt, sk_locked); + if (inet_test_bit(IS_ICSK, sk)) { sk_conn = inet_csk(sk); if (old) sk_conn->icsk_ext_hdr_len -= old->opt.optlen; @@ -1953,7 +1975,7 @@ int cipso_v4_req_setattr(struct request_sock *req, buf = NULL; req_inet = inet_rsk(req); - opt = xchg((__force struct ip_options_rcu **)&req_inet->ireq_opt, opt); + opt = unrcu_pointer(xchg(&req_inet->ireq_opt, RCU_INITIALIZER(opt))); if (opt) kfree_rcu(opt, rcu); @@ -1986,7 +2008,6 @@ static int cipso_v4_delopt(struct ip_options_rcu __rcu **opt_ptr) u8 cipso_len; u8 cipso_off; unsigned char *cipso_ptr; - int iter; int optlen_new; cipso_off = opt->opt.cipso - sizeof(struct iphdr); @@ -2006,19 +2027,8 @@ static int cipso_v4_delopt(struct ip_options_rcu __rcu **opt_ptr) memmove(cipso_ptr, cipso_ptr + cipso_len, opt->opt.optlen - cipso_off - cipso_len); - /* determining the new total option length is tricky because of - * the padding necessary, the only thing i can think to do at - * this point is walk the options one-by-one, skipping the - * padding at the end to determine the actual option size and - * from there we can determine the new total option length */ - iter = 0; - optlen_new = 0; - while (iter < opt->opt.optlen) - if (opt->opt.__data[iter] != IPOPT_NOP) { - iter += opt->opt.__data[iter + 1]; - optlen_new = iter; - } else - iter++; + optlen_new = cipso_v4_get_actual_opt_len(opt->opt.__data, + opt->opt.optlen); hdr_delta = opt->opt.optlen; opt->opt.optlen = (optlen_new + 3) & ~3; hdr_delta -= opt->opt.optlen; @@ -2049,7 +2059,7 @@ void cipso_v4_sock_delattr(struct sock *sk) sk_inet = inet_sk(sk); hdr_delta = cipso_v4_delopt(&sk_inet->inet_opt); - if (sk_inet->is_icsk && hdr_delta > 0) { + if (inet_test_bit(IS_ICSK, sk) && hdr_delta > 0) { struct inet_connection_sock *sk_conn = inet_csk(sk); sk_conn->icsk_ext_hdr_len -= hdr_delta; sk_conn->icsk_sync_mss(sk, sk_conn->icsk_pmtu_cookie); @@ -2220,7 +2230,7 @@ int cipso_v4_skbuff_setattr(struct sk_buff *skb, memset((char *)(iph + 1) + buf_len, 0, opt_len - buf_len); if (len_delta != 0) { iph->ihl = 5 + (opt_len >> 2); - iph->tot_len = htons(skb->len); + iph_set_totlen(iph, skb->len); } ip_send_check(iph); @@ -2238,7 +2248,8 @@ int cipso_v4_skbuff_setattr(struct sk_buff *skb, */ int cipso_v4_skbuff_delattr(struct sk_buff *skb) { - int ret_val; + int ret_val, cipso_len, hdr_len_actual, new_hdr_len_actual, new_hdr_len, + hdr_len_delta; struct iphdr *iph; struct ip_options *opt = &IPCB(skb)->opt; unsigned char *cipso_ptr; @@ -2251,16 +2262,37 @@ int cipso_v4_skbuff_delattr(struct sk_buff *skb) if (ret_val < 0) return ret_val; - /* the easiest thing to do is just replace the cipso option with noop - * options since we don't change the size of the packet, although we - * still need to recalculate the checksum */ - iph = ip_hdr(skb); cipso_ptr = (unsigned char *)iph + opt->cipso; - memset(cipso_ptr, IPOPT_NOOP, cipso_ptr[1]); + cipso_len = cipso_ptr[1]; + + hdr_len_actual = sizeof(struct iphdr) + + cipso_v4_get_actual_opt_len((unsigned char *)(iph + 1), + opt->optlen); + new_hdr_len_actual = hdr_len_actual - cipso_len; + new_hdr_len = (new_hdr_len_actual + 3) & ~3; + hdr_len_delta = (iph->ihl << 2) - new_hdr_len; + + /* 1. shift any options after CIPSO to the left */ + memmove(cipso_ptr, cipso_ptr + cipso_len, + new_hdr_len_actual - opt->cipso); + /* 2. move the whole IP header to its new place */ + memmove((unsigned char *)iph + hdr_len_delta, iph, new_hdr_len_actual); + /* 3. adjust the skb layout */ + skb_pull(skb, hdr_len_delta); + skb_reset_network_header(skb); + iph = ip_hdr(skb); + /* 4. re-fill new padding with IPOPT_END (may now be longer) */ + memset((unsigned char *)iph + new_hdr_len_actual, IPOPT_END, + new_hdr_len - new_hdr_len_actual); + + opt->optlen -= hdr_len_delta; opt->cipso = 0; opt->is_changed = 1; - + if (hdr_len_delta != 0) { + iph->ihl = new_hdr_len >> 2; + iph_set_totlen(iph, skb->len); + } ip_send_check(iph); return 0; diff --git a/net/ipv4/datagram.c b/net/ipv4/datagram.c index 48f337ccf949..1614593b6d72 100644 --- a/net/ipv4/datagram.c +++ b/net/ipv4/datagram.c @@ -16,7 +16,7 @@ #include <net/tcp_states.h> #include <net/sock_reuseport.h> -int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +int __ip4_datagram_connect(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len) { struct inet_sock *inet = inet_sk(sk); struct sockaddr_in *usin = (struct sockaddr_in *) uaddr; @@ -39,15 +39,16 @@ int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len saddr = inet->inet_saddr; if (ipv4_is_multicast(usin->sin_addr.s_addr)) { if (!oif || netif_index_is_l3_master(sock_net(sk), oif)) - oif = inet->mc_index; + oif = READ_ONCE(inet->mc_index); if (!saddr) - saddr = inet->mc_addr; + saddr = READ_ONCE(inet->mc_addr); + } else if (!oif) { + oif = READ_ONCE(inet->uc_index); } fl4 = &inet->cork.fl.u.ip4; - rt = ip_route_connect(fl4, usin->sin_addr.s_addr, saddr, - RT_CONN_FLAGS(sk), oif, - sk->sk_protocol, - inet->inet_sport, usin->sin_port, sk); + rt = ip_route_connect(fl4, usin->sin_addr.s_addr, saddr, oif, + sk->sk_protocol, inet->inet_sport, + usin->sin_port, sk); if (IS_ERR(rt)) { err = PTR_ERR(rt); if (err == -ENETUNREACH) @@ -60,19 +61,21 @@ int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len err = -EACCES; goto out; } + + /* Update addresses before rehashing */ + inet->inet_daddr = fl4->daddr; + inet->inet_dport = usin->sin_port; if (!inet->inet_saddr) - inet->inet_saddr = fl4->saddr; /* Update source address */ + inet->inet_saddr = fl4->saddr; if (!inet->inet_rcv_saddr) { inet->inet_rcv_saddr = fl4->saddr; if (sk->sk_prot->rehash) sk->sk_prot->rehash(sk); } - inet->inet_daddr = fl4->daddr; - inet->inet_dport = usin->sin_port; - reuseport_has_conns(sk, true); + reuseport_has_conns_set(sk); sk->sk_state = TCP_ESTABLISHED; sk_set_txhash(sk); - inet->inet_id = prandom_u32(); + atomic_set(&inet->inet_id, get_random_u16()); sk_dst_set(sk, &rt->dst); err = 0; @@ -81,7 +84,7 @@ out: } EXPORT_SYMBOL(__ip4_datagram_connect); -int ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +int ip4_datagram_connect(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len) { int res; @@ -99,8 +102,6 @@ EXPORT_SYMBOL(ip4_datagram_connect); void ip4_datagram_release_cb(struct sock *sk) { const struct inet_sock *inet = inet_sk(sk); - const struct ip_options_rcu *inet_opt; - __be32 daddr = inet->inet_daddr; struct dst_entry *dst; struct flowi4 fl4; struct rtable *rt; @@ -108,18 +109,13 @@ void ip4_datagram_release_cb(struct sock *sk) rcu_read_lock(); dst = __sk_dst_get(sk); - if (!dst || !dst->obsolete || dst->ops->check(dst, 0)) { + if (!dst || !READ_ONCE(dst->obsolete) || dst->ops->check(dst, 0)) { rcu_read_unlock(); return; } - inet_opt = rcu_dereference(inet->inet_opt); - if (inet_opt && inet_opt->opt.srr) - daddr = inet_opt->opt.faddr; - rt = ip_route_output_ports(sock_net(sk), &fl4, sk, daddr, - inet->inet_saddr, inet->inet_dport, - inet->inet_sport, sk->sk_protocol, - RT_CONN_FLAGS(sk), sk->sk_bound_dev_if); + inet_sk_init_flowi4(inet, &fl4); + rt = ip_route_output_flow(sock_net(sk), &fl4, sk); dst = !IS_ERR(rt) ? &rt->dst : NULL; sk_dst_set(sk, dst); diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c index fba2bffd65f7..942a887bf089 100644 --- a/net/ipv4/devinet.c +++ b/net/ipv4/devinet.c @@ -46,6 +46,7 @@ #include <linux/notifier.h> #include <linux/inetdevice.h> #include <linux/igmp.h> +#include "igmp_internal.h" #include <linux/slab.h> #include <linux/hash.h> #ifdef CONFIG_SYSCTL @@ -104,25 +105,15 @@ static const struct nla_policy ifa_ipv4_policy[IFA_MAX+1] = { [IFA_FLAGS] = { .type = NLA_U32 }, [IFA_RT_PRIORITY] = { .type = NLA_U32 }, [IFA_TARGET_NETNSID] = { .type = NLA_S32 }, -}; - -struct inet_fill_args { - u32 portid; - u32 seq; - int event; - unsigned int flags; - int netnsid; - int ifindex; + [IFA_PROTO] = { .type = NLA_U8 }, }; #define IN4_ADDR_HSIZE_SHIFT 8 #define IN4_ADDR_HSIZE (1U << IN4_ADDR_HSIZE_SHIFT) -static struct hlist_head inet_addr_lst[IN4_ADDR_HSIZE]; - static u32 inet_addr_hash(const struct net *net, __be32 addr) { - u32 val = (__force u32) addr ^ net_hash_mix(net); + u32 val = __ipv4_addr_hash(addr, net_hash_mix(net)); return hash_32(val, IN4_ADDR_HSIZE_SHIFT); } @@ -132,13 +123,13 @@ static void inet_hash_insert(struct net *net, struct in_ifaddr *ifa) u32 hash = inet_addr_hash(net, ifa->ifa_local); ASSERT_RTNL(); - hlist_add_head_rcu(&ifa->hash, &inet_addr_lst[hash]); + hlist_add_head_rcu(&ifa->addr_lst, &net->ipv4.inet_addr_lst[hash]); } static void inet_hash_remove(struct in_ifaddr *ifa) { ASSERT_RTNL(); - hlist_del_init_rcu(&ifa->hash); + hlist_del_init_rcu(&ifa->addr_lst); } /** @@ -185,9 +176,8 @@ struct in_ifaddr *inet_lookup_ifaddr_rcu(struct net *net, __be32 addr) u32 hash = inet_addr_hash(net, addr); struct in_ifaddr *ifa; - hlist_for_each_entry_rcu(ifa, &inet_addr_lst[hash], hash) - if (ifa->ifa_local == addr && - net_eq(dev_net(ifa->ifa_dev->dev), net)) + hlist_for_each_entry_rcu(ifa, &net->ipv4.inet_addr_lst[hash], addr_lst) + if (ifa->ifa_local == addr) return ifa; return NULL; @@ -215,22 +205,45 @@ static void devinet_sysctl_unregister(struct in_device *idev) /* Locks all the inet devices. */ -static struct in_ifaddr *inet_alloc_ifa(void) +static struct in_ifaddr *inet_alloc_ifa(struct in_device *in_dev) { - return kzalloc(sizeof(struct in_ifaddr), GFP_KERNEL_ACCOUNT); + struct in_ifaddr *ifa; + + ifa = kzalloc(sizeof(*ifa), GFP_KERNEL_ACCOUNT); + if (!ifa) + return NULL; + + in_dev_hold(in_dev); + ifa->ifa_dev = in_dev; + + INIT_HLIST_NODE(&ifa->addr_lst); + + return ifa; } static void inet_rcu_free_ifa(struct rcu_head *head) { struct in_ifaddr *ifa = container_of(head, struct in_ifaddr, rcu_head); - if (ifa->ifa_dev) - in_dev_put(ifa->ifa_dev); + + in_dev_put(ifa->ifa_dev); kfree(ifa); } static void inet_free_ifa(struct in_ifaddr *ifa) { - call_rcu(&ifa->rcu_head, inet_rcu_free_ifa); + /* Our reference to ifa->ifa_dev must be freed ASAP + * to release the reference to the netdev the same way. + * in_dev_put() -> in_dev_finish_destroy() -> netdev_put() + */ + call_rcu_hurry(&ifa->rcu_head, inet_rcu_free_ifa); +} + +static void in_dev_free_rcu(struct rcu_head *head) +{ + struct in_device *idev = container_of(head, struct in_device, rcu_head); + + kfree(rcu_dereference_protected(idev->mc_hash, 1)); + kfree(idev); } void in_dev_finish_destroy(struct in_device *idev) @@ -239,15 +252,14 @@ void in_dev_finish_destroy(struct in_device *idev) WARN_ON(idev->ifa_list); WARN_ON(idev->mc_list); - kfree(rcu_dereference_protected(idev->mc_hash, 1)); #ifdef NET_REFCNT_DEBUG pr_debug("%s: %p=%s\n", __func__, idev, dev ? dev->name : "NIL"); #endif - dev_put_track(dev, &idev->dev_tracker); + netdev_put(dev, &idev->dev_tracker); if (!idev->dead) pr_err("Freeing alive in_device %p\n", idev); else - kfree(idev); + call_rcu(&idev->rcu_head, in_dev_free_rcu); } EXPORT_SYMBOL(in_dev_finish_destroy); @@ -269,23 +281,25 @@ static struct in_device *inetdev_init(struct net_device *dev) if (!in_dev->arp_parms) goto out_kfree; if (IPV4_DEVCONF(in_dev->cnf, FORWARDING)) - dev_disable_lro(dev); + netif_disable_lro(dev); /* Reference in_dev->dev */ - dev_hold_track(dev, &in_dev->dev_tracker, GFP_KERNEL); + netdev_hold(dev, &in_dev->dev_tracker, GFP_KERNEL); /* Account for reference dev->ip_ptr (below) */ refcount_set(&in_dev->refcnt, 1); - err = devinet_sysctl_register(in_dev); - if (err) { - in_dev->dead = 1; - neigh_parms_release(&arp_tbl, in_dev->arp_parms); - in_dev_put(in_dev); - in_dev = NULL; - goto out; + if (dev != blackhole_netdev) { + err = devinet_sysctl_register(in_dev); + if (err) { + in_dev->dead = 1; + neigh_parms_release(&arp_tbl, in_dev->arp_parms); + in_dev_put(in_dev); + in_dev = NULL; + goto out; + } + ip_mc_init_dev(in_dev); + if (dev->flags & IFF_UP) + ip_mc_up(in_dev); } - ip_mc_init_dev(in_dev); - if (dev->flags & IFF_UP) - ip_mc_up(in_dev); /* we can receive as soon as ip_ptr is set -- do this last */ rcu_assign_pointer(dev->ip_ptr, in_dev); @@ -297,12 +311,6 @@ out_kfree: goto out; } -static void in_dev_rcu_put(struct rcu_head *head) -{ - struct in_device *idev = container_of(head, struct in_device, rcu_head); - in_dev_put(idev); -} - static void inetdev_destroy(struct in_device *in_dev) { struct net_device *dev; @@ -327,9 +335,21 @@ static void inetdev_destroy(struct in_device *in_dev) neigh_parms_release(&arp_tbl, in_dev->arp_parms); arp_ifdown(dev); - call_rcu(&in_dev->rcu_head, in_dev_rcu_put); + in_dev_put(in_dev); } +static int __init inet_blackhole_dev_init(void) +{ + struct in_device *in_dev; + + rtnl_lock(); + in_dev = inetdev_init(blackhole_netdev); + rtnl_unlock(); + + return PTR_ERR_OR_ZERO(in_dev); +} +late_initcall(inet_blackhole_dev_init); + int inet_addr_onlink(struct in_device *in_dev, __be32 a, __be32 b) { const struct in_ifaddr *ifa; @@ -353,14 +373,14 @@ static void __inet_del_ifa(struct in_device *in_dev, { struct in_ifaddr *promote = NULL; struct in_ifaddr *ifa, *ifa1; - struct in_ifaddr *last_prim; + struct in_ifaddr __rcu **last_prim; struct in_ifaddr *prev_prom = NULL; int do_promote = IN_DEV_PROMOTE_SECONDARIES(in_dev); ASSERT_RTNL(); ifa1 = rtnl_dereference(*ifap); - last_prim = rtnl_dereference(in_dev->ifa_list); + last_prim = ifap; if (in_dev->dead) goto no_promotions; @@ -374,7 +394,7 @@ static void __inet_del_ifa(struct in_device *in_dev, while ((ifa = rtnl_dereference(*ifap1)) != NULL) { if (!(ifa->ifa_flags & IFA_F_SECONDARY) && ifa1->ifa_scope <= ifa->ifa_scope) - last_prim = ifa; + last_prim = &ifa->ifa_next; if (!(ifa->ifa_flags & IFA_F_SECONDARY) || ifa1->ifa_mask != ifa->ifa_mask || @@ -438,9 +458,9 @@ no_promotions: rcu_assign_pointer(prev_prom->ifa_next, next_sec); - last_sec = rtnl_dereference(last_prim->ifa_next); + last_sec = rtnl_dereference(*last_prim); rcu_assign_pointer(promote->ifa_next, last_sec); - rcu_assign_pointer(last_prim->ifa_next, promote); + rcu_assign_pointer(*last_prim, promote); } promote->ifa_flags &= ~IFA_F_SECONDARY; @@ -467,26 +487,18 @@ static void inet_del_ifa(struct in_device *in_dev, __inet_del_ifa(in_dev, ifap, destroy, NULL, 0); } -static void check_lifetime(struct work_struct *work); - -static DECLARE_DELAYED_WORK(check_lifetime_work, check_lifetime); - static int __inet_insert_ifa(struct in_ifaddr *ifa, struct nlmsghdr *nlh, u32 portid, struct netlink_ext_ack *extack) { struct in_ifaddr __rcu **last_primary, **ifap; struct in_device *in_dev = ifa->ifa_dev; + struct net *net = dev_net(in_dev->dev); struct in_validator_info ivi; struct in_ifaddr *ifa1; int ret; ASSERT_RTNL(); - if (!ifa->ifa_local) { - inet_free_ifa(ifa); - return 0; - } - ifa->ifa_flags &= ~IFA_F_SECONDARY; last_primary = &in_dev->ifa_list; @@ -507,6 +519,7 @@ static int __inet_insert_ifa(struct in_ifaddr *ifa, struct nlmsghdr *nlh, return -EEXIST; } if (ifa1->ifa_scope != ifa->ifa_scope) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid scope value"); inet_free_ifa(ifa); return -EINVAL; } @@ -535,18 +548,16 @@ static int __inet_insert_ifa(struct in_ifaddr *ifa, struct nlmsghdr *nlh, return ret; } - if (!(ifa->ifa_flags & IFA_F_SECONDARY)) { - prandom_seed((__force u32) ifa->ifa_local); + if (!(ifa->ifa_flags & IFA_F_SECONDARY)) ifap = last_primary; - } rcu_assign_pointer(ifa->ifa_next, *ifap); rcu_assign_pointer(*ifap, ifa); inet_hash_insert(dev_net(in_dev->dev), ifa); - cancel_delayed_work(&check_lifetime_work); - queue_delayed_work(system_power_efficient_wq, &check_lifetime_work, 0); + cancel_delayed_work(&net->ipv4.addr_chk_work); + queue_delayed_work(system_power_efficient_wq, &net->ipv4.addr_chk_work, 0); /* Send message first, then call notifier. Notifier will trigger FIB update, so that @@ -559,26 +570,21 @@ static int __inet_insert_ifa(struct in_ifaddr *ifa, struct nlmsghdr *nlh, static int inet_insert_ifa(struct in_ifaddr *ifa) { + if (!ifa->ifa_local) { + inet_free_ifa(ifa); + return 0; + } + return __inet_insert_ifa(ifa, NULL, 0, NULL); } static int inet_set_ifa(struct net_device *dev, struct in_ifaddr *ifa) { - struct in_device *in_dev = __in_dev_get_rtnl(dev); - - ASSERT_RTNL(); + struct in_device *in_dev = __in_dev_get_rtnl_net(dev); - if (!in_dev) { - inet_free_ifa(ifa); - return -ENOBUFS; - } ipv4_devconf_setall(in_dev); neigh_parms_data_state_setall(in_dev->arp_parms); - if (ifa->ifa_dev != in_dev) { - WARN_ON(ifa->ifa_dev); - in_dev_hold(in_dev); - ifa->ifa_dev = in_dev; - } + if (ipv4_is_loopback(ifa->ifa_local)) ifa->ifa_scope = RT_SCOPE_HOST; return inet_insert_ifa(ifa); @@ -628,7 +634,7 @@ static int ip_mc_autojoin_config(struct net *net, bool join, struct sock *sk = net->ipv4.mc_autojoin_sk; int ret; - ASSERT_RTNL(); + ASSERT_RTNL_NET(net); lock_sock(sk); if (join) @@ -654,21 +660,24 @@ static int inet_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, struct in_ifaddr *ifa; int err; - ASSERT_RTNL(); - err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_ipv4_policy, extack); if (err < 0) - goto errout; + goto out; ifm = nlmsg_data(nlh); + + rtnl_net_lock(net); + in_dev = inetdev_by_index(net, ifm->ifa_index); if (!in_dev) { + NL_SET_ERR_MSG(extack, "ipv4: Device not found"); err = -ENODEV; - goto errout; + goto unlock; } - for (ifap = &in_dev->ifa_list; (ifa = rtnl_dereference(*ifap)) != NULL; + for (ifap = &in_dev->ifa_list; + (ifa = rtnl_net_dereference(net, *ifap)) != NULL; ifap = &ifa->ifa_next) { if (tb[IFA_LOCAL] && ifa->ifa_local != nla_get_in_addr(tb[IFA_LOCAL])) @@ -684,67 +693,76 @@ static int inet_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, if (ipv4_is_multicast(ifa->ifa_address)) ip_mc_autojoin_config(net, false, ifa); + __inet_del_ifa(in_dev, ifap, 1, nlh, NETLINK_CB(skb).portid); - return 0; + goto unlock; } + NL_SET_ERR_MSG(extack, "ipv4: Address not found"); err = -EADDRNOTAVAIL; -errout: +unlock: + rtnl_net_unlock(net); +out: return err; } -#define INFINITY_LIFE_TIME 0xFFFFFFFF - static void check_lifetime(struct work_struct *work) { unsigned long now, next, next_sec, next_sched; struct in_ifaddr *ifa; struct hlist_node *n; + struct net *net; int i; + net = container_of(to_delayed_work(work), struct net, ipv4.addr_chk_work); now = jiffies; next = round_jiffies_up(now + ADDR_CHECK_FREQUENCY); for (i = 0; i < IN4_ADDR_HSIZE; i++) { + struct hlist_head *head = &net->ipv4.inet_addr_lst[i]; bool change_needed = false; rcu_read_lock(); - hlist_for_each_entry_rcu(ifa, &inet_addr_lst[i], hash) { - unsigned long age; - - if (ifa->ifa_flags & IFA_F_PERMANENT) + hlist_for_each_entry_rcu(ifa, head, addr_lst) { + unsigned long age, tstamp; + u32 preferred_lft; + u32 valid_lft; + u32 flags; + + flags = READ_ONCE(ifa->ifa_flags); + if (flags & IFA_F_PERMANENT) continue; + preferred_lft = READ_ONCE(ifa->ifa_preferred_lft); + valid_lft = READ_ONCE(ifa->ifa_valid_lft); + tstamp = READ_ONCE(ifa->ifa_tstamp); /* We try to batch several events at once. */ - age = (now - ifa->ifa_tstamp + + age = (now - tstamp + ADDRCONF_TIMER_FUZZ_MINUS) / HZ; - if (ifa->ifa_valid_lft != INFINITY_LIFE_TIME && - age >= ifa->ifa_valid_lft) { + if (valid_lft != INFINITY_LIFE_TIME && + age >= valid_lft) { change_needed = true; - } else if (ifa->ifa_preferred_lft == + } else if (preferred_lft == INFINITY_LIFE_TIME) { continue; - } else if (age >= ifa->ifa_preferred_lft) { - if (time_before(ifa->ifa_tstamp + - ifa->ifa_valid_lft * HZ, next)) - next = ifa->ifa_tstamp + - ifa->ifa_valid_lft * HZ; + } else if (age >= preferred_lft) { + if (time_before(tstamp + valid_lft * HZ, next)) + next = tstamp + valid_lft * HZ; - if (!(ifa->ifa_flags & IFA_F_DEPRECATED)) + if (!(flags & IFA_F_DEPRECATED)) change_needed = true; - } else if (time_before(ifa->ifa_tstamp + - ifa->ifa_preferred_lft * HZ, + } else if (time_before(tstamp + preferred_lft * HZ, next)) { - next = ifa->ifa_tstamp + - ifa->ifa_preferred_lft * HZ; + next = tstamp + preferred_lft * HZ; } } rcu_read_unlock(); if (!change_needed) continue; - rtnl_lock(); - hlist_for_each_entry_safe(ifa, n, &inet_addr_lst[i], hash) { + + rtnl_net_lock(net); + hlist_for_each_entry_safe(ifa, n, head, addr_lst) { unsigned long age; if (ifa->ifa_flags & IFA_F_PERMANENT) @@ -760,7 +778,7 @@ static void check_lifetime(struct work_struct *work) struct in_ifaddr *tmp; ifap = &ifa->ifa_dev->ifa_list; - tmp = rtnl_dereference(*ifap); + tmp = rtnl_net_dereference(net, *ifap); while (tmp) { if (tmp == ifa) { inet_del_ifa(ifa->ifa_dev, @@ -768,7 +786,7 @@ static void check_lifetime(struct work_struct *work) break; } ifap = &tmp->ifa_next; - tmp = rtnl_dereference(*ifap); + tmp = rtnl_net_dereference(net, *ifap); } } else if (ifa->ifa_preferred_lft != INFINITY_LIFE_TIME && @@ -778,7 +796,7 @@ static void check_lifetime(struct work_struct *work) rtmsg_ifa(RTM_NEWADDR, ifa, NULL, 0); } } - rtnl_unlock(); + rtnl_net_unlock(net); } next_sec = round_jiffies_up(next); @@ -793,66 +811,97 @@ static void check_lifetime(struct work_struct *work) if (time_before(next_sched, now + ADDRCONF_TIMER_FUZZ_MAX)) next_sched = now + ADDRCONF_TIMER_FUZZ_MAX; - queue_delayed_work(system_power_efficient_wq, &check_lifetime_work, - next_sched - now); + queue_delayed_work(system_power_efficient_wq, &net->ipv4.addr_chk_work, + next_sched - now); } static void set_ifa_lifetime(struct in_ifaddr *ifa, __u32 valid_lft, __u32 prefered_lft) { unsigned long timeout; + u32 flags; - ifa->ifa_flags &= ~(IFA_F_PERMANENT | IFA_F_DEPRECATED); + flags = ifa->ifa_flags & ~(IFA_F_PERMANENT | IFA_F_DEPRECATED); timeout = addrconf_timeout_fixup(valid_lft, HZ); if (addrconf_finite_timeout(timeout)) - ifa->ifa_valid_lft = timeout; + WRITE_ONCE(ifa->ifa_valid_lft, timeout); else - ifa->ifa_flags |= IFA_F_PERMANENT; + flags |= IFA_F_PERMANENT; timeout = addrconf_timeout_fixup(prefered_lft, HZ); if (addrconf_finite_timeout(timeout)) { if (timeout == 0) - ifa->ifa_flags |= IFA_F_DEPRECATED; - ifa->ifa_preferred_lft = timeout; + flags |= IFA_F_DEPRECATED; + WRITE_ONCE(ifa->ifa_preferred_lft, timeout); } - ifa->ifa_tstamp = jiffies; + WRITE_ONCE(ifa->ifa_flags, flags); + WRITE_ONCE(ifa->ifa_tstamp, jiffies); if (!ifa->ifa_cstamp) - ifa->ifa_cstamp = ifa->ifa_tstamp; + WRITE_ONCE(ifa->ifa_cstamp, ifa->ifa_tstamp); } -static struct in_ifaddr *rtm_to_ifaddr(struct net *net, struct nlmsghdr *nlh, - __u32 *pvalid_lft, __u32 *pprefered_lft, - struct netlink_ext_ack *extack) +static int inet_validate_rtm(struct nlmsghdr *nlh, struct nlattr **tb, + struct netlink_ext_ack *extack, + __u32 *valid_lft, __u32 *prefered_lft) { - struct nlattr *tb[IFA_MAX+1]; - struct in_ifaddr *ifa; - struct ifaddrmsg *ifm; - struct net_device *dev; - struct in_device *in_dev; + struct ifaddrmsg *ifm = nlmsg_data(nlh); int err; err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_ipv4_policy, extack); if (err < 0) - goto errout; + return err; - ifm = nlmsg_data(nlh); - err = -EINVAL; - if (ifm->ifa_prefixlen > 32 || !tb[IFA_LOCAL]) - goto errout; + if (ifm->ifa_prefixlen > 32) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid prefix length"); + return -EINVAL; + } + + if (!tb[IFA_LOCAL]) { + NL_SET_ERR_MSG(extack, "ipv4: Local address is not supplied"); + return -EINVAL; + } + + if (tb[IFA_CACHEINFO]) { + struct ifa_cacheinfo *ci; + + ci = nla_data(tb[IFA_CACHEINFO]); + if (!ci->ifa_valid || ci->ifa_prefered > ci->ifa_valid) { + NL_SET_ERR_MSG(extack, "ipv4: address lifetime invalid"); + return -EINVAL; + } + + *valid_lft = ci->ifa_valid; + *prefered_lft = ci->ifa_prefered; + } + + return 0; +} + +static struct in_ifaddr *inet_rtm_to_ifa(struct net *net, struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifaddrmsg *ifm = nlmsg_data(nlh); + struct in_device *in_dev; + struct net_device *dev; + struct in_ifaddr *ifa; + int err; dev = __dev_get_by_index(net, ifm->ifa_index); err = -ENODEV; - if (!dev) + if (!dev) { + NL_SET_ERR_MSG(extack, "ipv4: Device not found"); goto errout; + } - in_dev = __in_dev_get_rtnl(dev); + in_dev = __in_dev_get_rtnl_net(dev); err = -ENOBUFS; if (!in_dev) goto errout; - ifa = inet_alloc_ifa(); + ifa = inet_alloc_ifa(in_dev); if (!ifa) /* * A potential indev allocation can be left alive, it stays @@ -862,19 +911,14 @@ static struct in_ifaddr *rtm_to_ifaddr(struct net *net, struct nlmsghdr *nlh, ipv4_devconf_setall(in_dev); neigh_parms_data_state_setall(in_dev->arp_parms); - in_dev_hold(in_dev); if (!tb[IFA_ADDRESS]) tb[IFA_ADDRESS] = tb[IFA_LOCAL]; - INIT_HLIST_NODE(&ifa->hash); ifa->ifa_prefixlen = ifm->ifa_prefixlen; ifa->ifa_mask = inet_make_mask(ifm->ifa_prefixlen); - ifa->ifa_flags = tb[IFA_FLAGS] ? nla_get_u32(tb[IFA_FLAGS]) : - ifm->ifa_flags; + ifa->ifa_flags = nla_get_u32_default(tb[IFA_FLAGS], ifm->ifa_flags); ifa->ifa_scope = ifm->ifa_scope; - ifa->ifa_dev = in_dev; - ifa->ifa_local = nla_get_in_addr(tb[IFA_LOCAL]); ifa->ifa_address = nla_get_in_addr(tb[IFA_ADDRESS]); @@ -889,82 +933,84 @@ static struct in_ifaddr *rtm_to_ifaddr(struct net *net, struct nlmsghdr *nlh, if (tb[IFA_RT_PRIORITY]) ifa->ifa_rt_priority = nla_get_u32(tb[IFA_RT_PRIORITY]); - if (tb[IFA_CACHEINFO]) { - struct ifa_cacheinfo *ci; - - ci = nla_data(tb[IFA_CACHEINFO]); - if (!ci->ifa_valid || ci->ifa_prefered > ci->ifa_valid) { - err = -EINVAL; - goto errout_free; - } - *pvalid_lft = ci->ifa_valid; - *pprefered_lft = ci->ifa_prefered; - } + if (tb[IFA_PROTO]) + ifa->ifa_proto = nla_get_u8(tb[IFA_PROTO]); return ifa; -errout_free: - inet_free_ifa(ifa); errout: return ERR_PTR(err); } -static struct in_ifaddr *find_matching_ifa(struct in_ifaddr *ifa) +static struct in_ifaddr *find_matching_ifa(struct net *net, struct in_ifaddr *ifa) { struct in_device *in_dev = ifa->ifa_dev; struct in_ifaddr *ifa1; - if (!ifa->ifa_local) - return NULL; - - in_dev_for_each_ifa_rtnl(ifa1, in_dev) { + in_dev_for_each_ifa_rtnl_net(net, ifa1, in_dev) { if (ifa1->ifa_mask == ifa->ifa_mask && inet_ifa_match(ifa1->ifa_address, ifa) && ifa1->ifa_local == ifa->ifa_local) return ifa1; } + return NULL; } static int inet_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + __u32 prefered_lft = INFINITY_LIFE_TIME; + __u32 valid_lft = INFINITY_LIFE_TIME; struct net *net = sock_net(skb->sk); - struct in_ifaddr *ifa; struct in_ifaddr *ifa_existing; - __u32 valid_lft = INFINITY_LIFE_TIME; - __u32 prefered_lft = INFINITY_LIFE_TIME; + struct nlattr *tb[IFA_MAX + 1]; + struct in_ifaddr *ifa; + int ret; - ASSERT_RTNL(); + ret = inet_validate_rtm(nlh, tb, extack, &valid_lft, &prefered_lft); + if (ret < 0) + return ret; + + if (!nla_get_in_addr(tb[IFA_LOCAL])) + return 0; - ifa = rtm_to_ifaddr(net, nlh, &valid_lft, &prefered_lft, extack); - if (IS_ERR(ifa)) - return PTR_ERR(ifa); + rtnl_net_lock(net); - ifa_existing = find_matching_ifa(ifa); + ifa = inet_rtm_to_ifa(net, nlh, tb, extack); + if (IS_ERR(ifa)) { + ret = PTR_ERR(ifa); + goto unlock; + } + + ifa_existing = find_matching_ifa(net, ifa); if (!ifa_existing) { /* It would be best to check for !NLM_F_CREATE here but * userspace already relies on not having to provide this. */ set_ifa_lifetime(ifa, valid_lft, prefered_lft); if (ifa->ifa_flags & IFA_F_MCAUTOJOIN) { - int ret = ip_mc_autojoin_config(net, true, ifa); - + ret = ip_mc_autojoin_config(net, true, ifa); if (ret < 0) { + NL_SET_ERR_MSG(extack, "ipv4: Multicast auto join failed"); inet_free_ifa(ifa); - return ret; + goto unlock; } } - return __inet_insert_ifa(ifa, nlh, NETLINK_CB(skb).portid, - extack); + + ret = __inet_insert_ifa(ifa, nlh, NETLINK_CB(skb).portid, extack); } else { u32 new_metric = ifa->ifa_rt_priority; + u8 new_proto = ifa->ifa_proto; inet_free_ifa(ifa); if (nlh->nlmsg_flags & NLM_F_EXCL || - !(nlh->nlmsg_flags & NLM_F_REPLACE)) - return -EEXIST; + !(nlh->nlmsg_flags & NLM_F_REPLACE)) { + NL_SET_ERR_MSG(extack, "ipv4: Address already assigned"); + ret = -EEXIST; + goto unlock; + } ifa = ifa_existing; if (ifa->ifa_rt_priority != new_metric) { @@ -972,13 +1018,19 @@ static int inet_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, ifa->ifa_rt_priority = new_metric; } + ifa->ifa_proto = new_proto; + set_ifa_lifetime(ifa, valid_lft, prefered_lft); - cancel_delayed_work(&check_lifetime_work); + cancel_delayed_work(&net->ipv4.addr_chk_work); queue_delayed_work(system_power_efficient_wq, - &check_lifetime_work, 0); + &net->ipv4.addr_chk_work, 0); rtmsg_ifa(RTM_NEWADDR, ifa, nlh, NETLINK_CB(skb).portid); } - return 0; + +unlock: + rtnl_net_unlock(net); + + return ret; } /* @@ -1065,7 +1117,7 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) goto out; } - rtnl_lock(); + rtnl_net_lock(net); ret = -ENODEV; dev = __dev_get_by_name(net, ifr->ifr_name); @@ -1075,7 +1127,7 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) if (colon) *colon = ':'; - in_dev = __in_dev_get_rtnl(dev); + in_dev = __in_dev_get_rtnl_net(dev); if (in_dev) { if (tryaddrmatch) { /* Matthias Andree */ @@ -1085,7 +1137,7 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) This is checked above. */ for (ifap = &in_dev->ifa_list; - (ifa = rtnl_dereference(*ifap)) != NULL; + (ifa = rtnl_net_dereference(net, *ifap)) != NULL; ifap = &ifa->ifa_next) { if (!strcmp(ifr->ifr_name, ifa->ifa_label) && sin_orig.sin_addr.s_addr == @@ -1099,7 +1151,7 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) comparing just the label */ if (!ifa) { for (ifap = &in_dev->ifa_list; - (ifa = rtnl_dereference(*ifap)) != NULL; + (ifa = rtnl_net_dereference(net, *ifap)) != NULL; ifap = &ifa->ifa_next) if (!strcmp(ifr->ifr_name, ifa->ifa_label)) break; @@ -1141,6 +1193,9 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) inet_del_ifa(in_dev, ifap, 1); break; } + + /* NETDEV_UP/DOWN/CHANGE could touch a peer dev */ + ASSERT_RTNL(); ret = dev_change_flags(dev, ifr->ifr_flags, NULL); break; @@ -1151,10 +1206,12 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) if (!ifa) { ret = -ENOBUFS; - ifa = inet_alloc_ifa(); + if (!in_dev) + break; + ifa = inet_alloc_ifa(in_dev); if (!ifa) break; - INIT_HLIST_NODE(&ifa->hash); + if (colon) memcpy(ifa->ifa_label, ifr->ifr_name, IFNAMSIZ); else @@ -1240,14 +1297,14 @@ int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) break; } done: - rtnl_unlock(); + rtnl_net_unlock(net); out: return ret; } int inet_gifconf(struct net_device *dev, char __user *buf, int len, int size) { - struct in_device *in_dev = __in_dev_get_rtnl(dev); + struct in_device *in_dev = __in_dev_get_rtnl_net(dev); const struct in_ifaddr *ifa; struct ifreq ifr; int done = 0; @@ -1258,7 +1315,7 @@ int inet_gifconf(struct net_device *dev, char __user *buf, int len, int size) if (!in_dev) goto out; - in_dev_for_each_ifa_rtnl(ifa, in_dev) { + in_dev_for_each_ifa_rtnl_net(dev_net(dev), ifa, in_dev) { if (!buf) { done += size; continue; @@ -1289,7 +1346,7 @@ static __be32 in_dev_select_addr(const struct in_device *in_dev, const struct in_ifaddr *ifa; in_dev_for_each_ifa_rcu(ifa, in_dev) { - if (ifa->ifa_flags & IFA_F_SECONDARY) + if (READ_ONCE(ifa->ifa_flags) & IFA_F_SECONDARY) continue; if (ifa->ifa_scope != RT_SCOPE_LINK && ifa->ifa_scope <= scope) @@ -1305,10 +1362,11 @@ __be32 inet_select_addr(const struct net_device *dev, __be32 dst, int scope) __be32 addr = 0; unsigned char localnet_scope = RT_SCOPE_HOST; struct in_device *in_dev; - struct net *net = dev_net(dev); + struct net *net; int master_idx; rcu_read_lock(); + net = dev_net_rcu(dev); in_dev = __in_dev_get_rcu(dev); if (!in_dev) goto no_in_dev; @@ -1317,7 +1375,7 @@ __be32 inet_select_addr(const struct net_device *dev, __be32 dst, int scope) localnet_scope = RT_SCOPE_LINK; in_dev_for_each_ifa_rcu(ifa, in_dev) { - if (ifa->ifa_flags & IFA_F_SECONDARY) + if (READ_ONCE(ifa->ifa_flags) & IFA_F_SECONDARY) continue; if (min(ifa->ifa_scope, localnet_scope) > scope) continue; @@ -1553,16 +1611,13 @@ static int inetdev_event(struct notifier_block *this, unsigned long event, if (!inetdev_valid_mtu(dev->mtu)) break; if (dev->flags & IFF_LOOPBACK) { - struct in_ifaddr *ifa = inet_alloc_ifa(); + struct in_ifaddr *ifa = inet_alloc_ifa(in_dev); if (ifa) { - INIT_HLIST_NODE(&ifa->hash); ifa->ifa_local = ifa->ifa_address = htonl(INADDR_LOOPBACK); ifa->ifa_prefixlen = 8; ifa->ifa_mask = inet_make_mask(8); - in_dev_hold(in_dev); - ifa->ifa_dev = in_dev; ifa->ifa_scope = RT_SCOPE_HOST; memcpy(ifa->ifa_label, dev->name, IFNAMSIZ); set_ifa_lifetime(ifa, INFINITY_LIFE_TIME, @@ -1625,6 +1680,7 @@ static size_t inet_nlmsg_size(void) + nla_total_size(4) /* IFA_BROADCAST */ + nla_total_size(IFNAMSIZ) /* IFA_LABEL */ + nla_total_size(4) /* IFA_FLAGS */ + + nla_total_size(1) /* IFA_PROTO */ + nla_total_size(4) /* IFA_RT_PRIORITY */ + nla_total_size(sizeof(struct ifa_cacheinfo)); /* IFA_CACHEINFO */ } @@ -1647,12 +1703,14 @@ static int put_cacheinfo(struct sk_buff *skb, unsigned long cstamp, return nla_put(skb, IFA_CACHEINFO, sizeof(ci), &ci); } -static int inet_fill_ifaddr(struct sk_buff *skb, struct in_ifaddr *ifa, +static int inet_fill_ifaddr(struct sk_buff *skb, const struct in_ifaddr *ifa, struct inet_fill_args *args) { struct ifaddrmsg *ifm; struct nlmsghdr *nlh; + unsigned long tstamp; u32 preferred, valid; + u32 flags; nlh = nlmsg_put(skb, args->portid, args->seq, args->event, sizeof(*ifm), args->flags); @@ -1662,7 +1720,13 @@ static int inet_fill_ifaddr(struct sk_buff *skb, struct in_ifaddr *ifa, ifm = nlmsg_data(nlh); ifm->ifa_family = AF_INET; ifm->ifa_prefixlen = ifa->ifa_prefixlen; - ifm->ifa_flags = ifa->ifa_flags; + + flags = READ_ONCE(ifa->ifa_flags); + /* Warning : ifm->ifa_flags is an __u8, it holds only 8 bits. + * The 32bit value is given in IFA_FLAGS attribute. + */ + ifm->ifa_flags = (__u8)flags; + ifm->ifa_scope = ifa->ifa_scope; ifm->ifa_index = ifa->ifa_dev->dev->ifindex; @@ -1670,11 +1734,12 @@ static int inet_fill_ifaddr(struct sk_buff *skb, struct in_ifaddr *ifa, nla_put_s32(skb, IFA_TARGET_NETNSID, args->netnsid)) goto nla_put_failure; - if (!(ifm->ifa_flags & IFA_F_PERMANENT)) { - preferred = ifa->ifa_preferred_lft; - valid = ifa->ifa_valid_lft; + tstamp = READ_ONCE(ifa->ifa_tstamp); + if (!(flags & IFA_F_PERMANENT)) { + preferred = READ_ONCE(ifa->ifa_preferred_lft); + valid = READ_ONCE(ifa->ifa_valid_lft); if (preferred != INFINITY_LIFE_TIME) { - long tval = (jiffies - ifa->ifa_tstamp) / HZ; + long tval = (jiffies - tstamp) / HZ; if (preferred > tval) preferred -= tval; @@ -1699,10 +1764,12 @@ static int inet_fill_ifaddr(struct sk_buff *skb, struct in_ifaddr *ifa, nla_put_in_addr(skb, IFA_BROADCAST, ifa->ifa_broadcast)) || (ifa->ifa_label[0] && nla_put_string(skb, IFA_LABEL, ifa->ifa_label)) || - nla_put_u32(skb, IFA_FLAGS, ifa->ifa_flags) || + (ifa->ifa_proto && + nla_put_u8(skb, IFA_PROTO, ifa->ifa_proto)) || + nla_put_u32(skb, IFA_FLAGS, flags) || (ifa->ifa_rt_priority && nla_put_u32(skb, IFA_RT_PRIORITY, ifa->ifa_rt_priority)) || - put_cacheinfo(skb, ifa->ifa_cstamp, ifa->ifa_tstamp, + put_cacheinfo(skb, READ_ONCE(ifa->ifa_cstamp), tstamp, preferred, valid)) goto nla_put_failure; @@ -1724,12 +1791,12 @@ static int inet_valid_dump_ifaddr_req(const struct nlmsghdr *nlh, struct ifaddrmsg *ifm; int err, i; - if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + ifm = nlmsg_payload(nlh, sizeof(*ifm)); + if (!ifm) { NL_SET_ERR_MSG(extack, "ipv4: Invalid header for address dump request"); return -EINVAL; } - ifm = nlmsg_data(nlh); if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { NL_SET_ERR_MSG(extack, "ipv4: Invalid values in header for address dump request"); return -EINVAL; @@ -1771,16 +1838,45 @@ static int inet_valid_dump_ifaddr_req(const struct nlmsghdr *nlh, return 0; } -static int in_dev_dump_addr(struct in_device *in_dev, struct sk_buff *skb, - struct netlink_callback *cb, int s_ip_idx, - struct inet_fill_args *fillargs) +static int in_dev_dump_ifmcaddr(struct in_device *in_dev, struct sk_buff *skb, + struct netlink_callback *cb, int *s_ip_idx, + struct inet_fill_args *fillargs) +{ + struct ip_mc_list *im; + int ip_idx = 0; + int err; + + for (im = rcu_dereference(in_dev->mc_list); + im; + im = rcu_dereference(im->next_rcu)) { + if (ip_idx < *s_ip_idx) { + ip_idx++; + continue; + } + err = inet_fill_ifmcaddr(skb, in_dev->dev, im, fillargs); + if (err < 0) + goto done; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + ip_idx++; + } + err = 0; + ip_idx = 0; +done: + *s_ip_idx = ip_idx; + return err; +} + +static int in_dev_dump_ifaddr(struct in_device *in_dev, struct sk_buff *skb, + struct netlink_callback *cb, int *s_ip_idx, + struct inet_fill_args *fillargs) { struct in_ifaddr *ifa; int ip_idx = 0; int err; - in_dev_for_each_ifa_rtnl(ifa, in_dev) { - if (ip_idx < s_ip_idx) { + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (ip_idx < *s_ip_idx) { ip_idx++; continue; } @@ -1792,95 +1888,112 @@ static int in_dev_dump_addr(struct in_device *in_dev, struct sk_buff *skb, ip_idx++; } err = 0; - + ip_idx = 0; done: - cb->args[2] = ip_idx; + *s_ip_idx = ip_idx; return err; } -static int inet_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) +static int in_dev_dump_addr(struct in_device *in_dev, struct sk_buff *skb, + struct netlink_callback *cb, int *s_ip_idx, + struct inet_fill_args *fillargs) +{ + switch (fillargs->event) { + case RTM_NEWADDR: + return in_dev_dump_ifaddr(in_dev, skb, cb, s_ip_idx, fillargs); + case RTM_GETMULTICAST: + return in_dev_dump_ifmcaddr(in_dev, skb, cb, s_ip_idx, + fillargs); + default: + return -EINVAL; + } +} + +/* Combine dev_addr_genid and dev_base_seq to detect changes. + */ +static u32 inet_base_seq(const struct net *net) +{ + u32 res = atomic_read(&net->ipv4.dev_addr_genid) + + READ_ONCE(net->dev_base_seq); + + /* Must not return 0 (see nl_dump_check_consistent()). + * Chose a value far away from 0. + */ + if (!res) + res = 0x80000000; + return res; +} + +static int inet_dump_addr(struct sk_buff *skb, struct netlink_callback *cb, + int event) { const struct nlmsghdr *nlh = cb->nlh; struct inet_fill_args fillargs = { .portid = NETLINK_CB(cb->skb).portid, .seq = nlh->nlmsg_seq, - .event = RTM_NEWADDR, + .event = event, .flags = NLM_F_MULTI, .netnsid = -1, }; struct net *net = sock_net(skb->sk); struct net *tgt_net = net; - int h, s_h; - int idx, s_idx; - int s_ip_idx; - struct net_device *dev; + struct { + unsigned long ifindex; + int ip_idx; + } *ctx = (void *)cb->ctx; struct in_device *in_dev; - struct hlist_head *head; + struct net_device *dev; int err = 0; - s_h = cb->args[0]; - s_idx = idx = cb->args[1]; - s_ip_idx = cb->args[2]; - + rcu_read_lock(); if (cb->strict_check) { err = inet_valid_dump_ifaddr_req(nlh, &fillargs, &tgt_net, skb->sk, cb); if (err < 0) - goto put_tgt_net; + goto done; - err = 0; if (fillargs.ifindex) { - dev = __dev_get_by_index(tgt_net, fillargs.ifindex); + dev = dev_get_by_index_rcu(tgt_net, fillargs.ifindex); if (!dev) { err = -ENODEV; - goto put_tgt_net; - } - - in_dev = __in_dev_get_rtnl(dev); - if (in_dev) { - err = in_dev_dump_addr(in_dev, skb, cb, s_ip_idx, - &fillargs); + goto done; } - goto put_tgt_net; - } - } - - for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { - idx = 0; - head = &tgt_net->dev_index_head[h]; - rcu_read_lock(); - cb->seq = atomic_read(&tgt_net->ipv4.dev_addr_genid) ^ - tgt_net->dev_base_seq; - hlist_for_each_entry_rcu(dev, head, index_hlist) { - if (idx < s_idx) - goto cont; - if (h > s_h || idx > s_idx) - s_ip_idx = 0; in_dev = __in_dev_get_rcu(dev); if (!in_dev) - goto cont; - - err = in_dev_dump_addr(in_dev, skb, cb, s_ip_idx, - &fillargs); - if (err < 0) { - rcu_read_unlock(); goto done; - } -cont: - idx++; + err = in_dev_dump_addr(in_dev, skb, cb, &ctx->ip_idx, + &fillargs); + goto done; } - rcu_read_unlock(); } + cb->seq = inet_base_seq(tgt_net); + + for_each_netdev_dump(tgt_net, dev, ctx->ifindex) { + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + continue; + err = in_dev_dump_addr(in_dev, skb, cb, &ctx->ip_idx, + &fillargs); + if (err < 0) + goto done; + } done: - cb->args[0] = h; - cb->args[1] = idx; -put_tgt_net: if (fillargs.netnsid >= 0) put_net(tgt_net); + rcu_read_unlock(); + return err; +} - return skb->len ? : err; +static int inet_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) +{ + return inet_dump_addr(skb, cb, RTM_NEWADDR); +} + +static int inet_dump_ifmcaddr(struct sk_buff *skb, struct netlink_callback *cb) +{ + return inet_dump_addr(skb, cb, RTM_GETMULTICAST); } static void rtmsg_ifa(int event, struct in_ifaddr *ifa, struct nlmsghdr *nlh, @@ -1912,8 +2025,7 @@ static void rtmsg_ifa(int event, struct in_ifaddr *ifa, struct nlmsghdr *nlh, rtnl_notify(skb, net, portid, RTNLGRP_IPV4_IFADDR, nlh, GFP_KERNEL); return; errout: - if (err < 0) - rtnl_set_sk_err(net, RTNLGRP_IPV4_IFADDR, err); + rtnl_set_sk_err(net, RTNLGRP_IPV4_IFADDR, err); } static size_t inet_get_link_af_size(const struct net_device *dev, @@ -1942,7 +2054,7 @@ static int inet_fill_link_af(struct sk_buff *skb, const struct net_device *dev, return -EMSGSIZE; for (i = 0; i < IPV4_DEVCONF_MAX; i++) - ((u32 *) nla_data(nla))[i] = in_dev->cnf.data[i]; + ((u32 *) nla_data(nla))[i] = READ_ONCE(in_dev->cnf.data[i]); return 0; } @@ -2028,9 +2140,9 @@ static int inet_netconf_msgsize_devconf(int type) } static int inet_netconf_fill_devconf(struct sk_buff *skb, int ifindex, - struct ipv4_devconf *devconf, u32 portid, - u32 seq, int event, unsigned int flags, - int type) + const struct ipv4_devconf *devconf, + u32 portid, u32 seq, int event, + unsigned int flags, int type) { struct nlmsghdr *nlh; struct netconfmsg *ncm; @@ -2055,27 +2167,28 @@ static int inet_netconf_fill_devconf(struct sk_buff *skb, int ifindex, if ((all || type == NETCONFA_FORWARDING) && nla_put_s32(skb, NETCONFA_FORWARDING, - IPV4_DEVCONF(*devconf, FORWARDING)) < 0) + IPV4_DEVCONF_RO(*devconf, FORWARDING)) < 0) goto nla_put_failure; if ((all || type == NETCONFA_RP_FILTER) && nla_put_s32(skb, NETCONFA_RP_FILTER, - IPV4_DEVCONF(*devconf, RP_FILTER)) < 0) + IPV4_DEVCONF_RO(*devconf, RP_FILTER)) < 0) goto nla_put_failure; if ((all || type == NETCONFA_MC_FORWARDING) && nla_put_s32(skb, NETCONFA_MC_FORWARDING, - IPV4_DEVCONF(*devconf, MC_FORWARDING)) < 0) + IPV4_DEVCONF_RO(*devconf, MC_FORWARDING)) < 0) goto nla_put_failure; if ((all || type == NETCONFA_BC_FORWARDING) && nla_put_s32(skb, NETCONFA_BC_FORWARDING, - IPV4_DEVCONF(*devconf, BC_FORWARDING)) < 0) + IPV4_DEVCONF_RO(*devconf, BC_FORWARDING)) < 0) goto nla_put_failure; if ((all || type == NETCONFA_PROXY_NEIGH) && nla_put_s32(skb, NETCONFA_PROXY_NEIGH, - IPV4_DEVCONF(*devconf, PROXY_ARP)) < 0) + IPV4_DEVCONF_RO(*devconf, PROXY_ARP)) < 0) goto nla_put_failure; if ((all || type == NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN) && nla_put_s32(skb, NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, - IPV4_DEVCONF(*devconf, IGNORE_ROUTES_WITH_LINKDOWN)) < 0) + IPV4_DEVCONF_RO(*devconf, + IGNORE_ROUTES_WITH_LINKDOWN)) < 0) goto nla_put_failure; out: @@ -2108,8 +2221,7 @@ void inet_netconf_notify_devconf(struct net *net, int event, int type, rtnl_notify(skb, net, 0, RTNLGRP_IPV4_NETCONF, NULL, GFP_KERNEL); return; errout: - if (err < 0) - rtnl_set_sk_err(net, RTNLGRP_IPV4_NETCONF, err); + rtnl_set_sk_err(net, RTNLGRP_IPV4_NETCONF, err); } static const struct nla_policy devconf_ipv4_policy[NETCONFA_MAX+1] = { @@ -2164,21 +2276,20 @@ static int inet_netconf_get_devconf(struct sk_buff *in_skb, struct netlink_ext_ack *extack) { struct net *net = sock_net(in_skb->sk); - struct nlattr *tb[NETCONFA_MAX+1]; + struct nlattr *tb[NETCONFA_MAX + 1]; + const struct ipv4_devconf *devconf; + struct in_device *in_dev = NULL; + struct net_device *dev = NULL; struct sk_buff *skb; - struct ipv4_devconf *devconf; - struct in_device *in_dev; - struct net_device *dev; int ifindex; int err; err = inet_netconf_valid_get_req(in_skb, nlh, tb, extack); if (err) - goto errout; + return err; - err = -EINVAL; if (!tb[NETCONFA_IFINDEX]) - goto errout; + return -EINVAL; ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]); switch (ifindex) { @@ -2189,10 +2300,10 @@ static int inet_netconf_get_devconf(struct sk_buff *in_skb, devconf = net->ipv4.devconf_dflt; break; default: - dev = __dev_get_by_index(net, ifindex); - if (!dev) - goto errout; - in_dev = __in_dev_get_rtnl(dev); + err = -ENODEV; + dev = dev_get_by_index(net, ifindex); + if (dev) + in_dev = in_dev_get(dev); if (!in_dev) goto errout; devconf = &in_dev->cnf; @@ -2216,6 +2327,9 @@ static int inet_netconf_get_devconf(struct sk_buff *in_skb, } err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); errout: + if (in_dev) + in_dev_put(in_dev); + dev_put(dev); return err; } @@ -2224,11 +2338,13 @@ static int inet_netconf_dump_devconf(struct sk_buff *skb, { const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); - int h, s_h; - int idx, s_idx; + struct { + unsigned long ifindex; + unsigned int all_default; + } *ctx = (void *)cb->ctx; + const struct in_device *in_dev; struct net_device *dev; - struct in_device *in_dev; - struct hlist_head *head; + int err = 0; if (cb->strict_check) { struct netlink_ext_ack *extack = cb->extack; @@ -2245,65 +2361,45 @@ static int inet_netconf_dump_devconf(struct sk_buff *skb, } } - s_h = cb->args[0]; - s_idx = idx = cb->args[1]; - - for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { - idx = 0; - head = &net->dev_index_head[h]; - rcu_read_lock(); - cb->seq = atomic_read(&net->ipv4.dev_addr_genid) ^ - net->dev_base_seq; - hlist_for_each_entry_rcu(dev, head, index_hlist) { - if (idx < s_idx) - goto cont; - in_dev = __in_dev_get_rcu(dev); - if (!in_dev) - goto cont; - - if (inet_netconf_fill_devconf(skb, dev->ifindex, - &in_dev->cnf, - NETLINK_CB(cb->skb).portid, - nlh->nlmsg_seq, - RTM_NEWNETCONF, - NLM_F_MULTI, - NETCONFA_ALL) < 0) { - rcu_read_unlock(); - goto done; - } - nl_dump_check_consistent(cb, nlmsg_hdr(skb)); -cont: - idx++; - } - rcu_read_unlock(); + rcu_read_lock(); + for_each_netdev_dump(net, dev, ctx->ifindex) { + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + continue; + err = inet_netconf_fill_devconf(skb, dev->ifindex, + &in_dev->cnf, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL); + if (err < 0) + goto done; } - if (h == NETDEV_HASHENTRIES) { - if (inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_ALL, - net->ipv4.devconf_all, - NETLINK_CB(cb->skb).portid, - nlh->nlmsg_seq, - RTM_NEWNETCONF, NLM_F_MULTI, - NETCONFA_ALL) < 0) + if (ctx->all_default == 0) { + err = inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_ALL, + net->ipv4.devconf_all, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL); + if (err < 0) goto done; - else - h++; - } - if (h == NETDEV_HASHENTRIES + 1) { - if (inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_DEFAULT, - net->ipv4.devconf_dflt, - NETLINK_CB(cb->skb).portid, - nlh->nlmsg_seq, - RTM_NEWNETCONF, NLM_F_MULTI, - NETCONFA_ALL) < 0) + ctx->all_default++; + } + if (ctx->all_default == 1) { + err = inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_DEFAULT, + net->ipv4.devconf_dflt, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL); + if (err < 0) goto done; - else - h++; + ctx->all_default++; } done: - cb->args[0] = h; - cb->args[1] = idx; - - return skb->len; + rcu_read_unlock(); + return err; } #ifdef CONFIG_SYSCTL @@ -2346,7 +2442,7 @@ static void inet_forward_change(struct net *net) if (on) dev_disable_lro(dev); - in_dev = __in_dev_get_rtnl(dev); + in_dev = __in_dev_get_rtnl_net(dev); if (in_dev) { IN_DEV_CONF_SET(in_dev, FORWARDING, on); inet_netconf_notify_devconf(net, RTM_NEWNETCONF, @@ -2369,7 +2465,7 @@ static int devinet_conf_ifindex(struct net *net, struct ipv4_devconf *cnf) } } -static int devinet_conf_proc(struct ctl_table *ctl, int write, +static int devinet_conf_proc(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { int old_value = *(int *)ctl->data; @@ -2421,7 +2517,7 @@ static int devinet_conf_proc(struct ctl_table *ctl, int write, return ret; } -static int devinet_sysctl_forward(struct ctl_table *ctl, int write, +static int devinet_sysctl_forward(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { int *valp = ctl->data; @@ -2437,7 +2533,7 @@ static int devinet_sysctl_forward(struct ctl_table *ctl, int write, if (write && *valp != val) { if (valp != &IPV4_DEVCONF_DFLT(net, FORWARDING)) { - if (!rtnl_trylock()) { + if (!rtnl_net_trylock(net)) { /* Restore the original values before restarting */ *valp = val; *ppos = pos; @@ -2456,7 +2552,7 @@ static int devinet_sysctl_forward(struct ctl_table *ctl, int write, idev->dev->ifindex, cnf); } - rtnl_unlock(); + rtnl_net_unlock(net); rt_cache_flush(net); } else inet_netconf_notify_devconf(net, RTM_NEWNETCONF, @@ -2468,7 +2564,7 @@ static int devinet_sysctl_forward(struct ctl_table *ctl, int write, return ret; } -static int ipv4_doint_and_flush(struct ctl_table *ctl, int write, +static int ipv4_doint_and_flush(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { int *valp = ctl->data; @@ -2507,7 +2603,7 @@ static int ipv4_doint_and_flush(struct ctl_table *ctl, int write, static struct devinet_sysctl_table { struct ctl_table_header *sysctl_header; - struct ctl_table devinet_vars[__IPV4_DEVCONF_MAX]; + struct ctl_table devinet_vars[IPV4_DEVCONF_MAX]; } devinet_sysctl = { .devinet_vars = { DEVINET_SYSCTL_COMPLEX_ENTRY(FORWARDING, "forwarding", @@ -2566,11 +2662,11 @@ static int __devinet_sysctl_register(struct net *net, char *dev_name, struct devinet_sysctl_table *t; char path[sizeof("net/ipv4/conf/") + IFNAMSIZ]; - t = kmemdup(&devinet_sysctl, sizeof(*t), GFP_KERNEL); + t = kmemdup(&devinet_sysctl, sizeof(*t), GFP_KERNEL_ACCOUNT); if (!t) goto out; - for (i = 0; i < ARRAY_SIZE(t->devinet_vars) - 1; i++) { + for (i = 0; i < ARRAY_SIZE(t->devinet_vars); i++) { t->devinet_vars[i].data += (char *)p - (char *)&ipv4_devconf; t->devinet_vars[i].extra1 = p; t->devinet_vars[i].extra2 = net; @@ -2644,20 +2740,26 @@ static struct ctl_table ctl_forward_entry[] = { .extra1 = &ipv4_devconf, .extra2 = &init_net, }, - { }, }; #endif static __net_init int devinet_init_net(struct net *net) { - int err; - struct ipv4_devconf *all, *dflt; #ifdef CONFIG_SYSCTL - struct ctl_table *tbl; struct ctl_table_header *forw_hdr; + struct ctl_table *tbl; #endif + struct ipv4_devconf *all, *dflt; + int err; + int i; err = -ENOMEM; + net->ipv4.inet_addr_lst = kmalloc_array(IN4_ADDR_HSIZE, + sizeof(struct hlist_head), + GFP_KERNEL); + if (!net->ipv4.inet_addr_lst) + goto err_alloc_hash; + all = kmemdup(&ipv4_devconf, sizeof(ipv4_devconf), GFP_KERNEL); if (!all) goto err_alloc_all; @@ -2677,23 +2779,27 @@ static __net_init int devinet_init_net(struct net *net) #endif if (!net_eq(net, &init_net)) { - if (IS_ENABLED(CONFIG_SYSCTL) && - sysctl_devconf_inherit_init_net == 3) { + switch (net_inherit_devconf()) { + case 3: /* copy from the current netns */ memcpy(all, current->nsproxy->net_ns->ipv4.devconf_all, sizeof(ipv4_devconf)); memcpy(dflt, current->nsproxy->net_ns->ipv4.devconf_dflt, sizeof(ipv4_devconf_dflt)); - } else if (!IS_ENABLED(CONFIG_SYSCTL) || - sysctl_devconf_inherit_init_net != 2) { - /* inherit == 0 or 1: copy from init_net */ + break; + case 0: + case 1: + /* copy from init_net */ memcpy(all, init_net.ipv4.devconf_all, sizeof(ipv4_devconf)); memcpy(dflt, init_net.ipv4.devconf_dflt, sizeof(ipv4_devconf_dflt)); + break; + case 2: + /* use compiled values */ + break; } - /* else inherit == 2: use compiled values */ } #ifdef CONFIG_SYSCTL @@ -2707,12 +2813,18 @@ static __net_init int devinet_init_net(struct net *net) goto err_reg_dflt; err = -ENOMEM; - forw_hdr = register_net_sysctl(net, "net/ipv4", tbl); + forw_hdr = register_net_sysctl_sz(net, "net/ipv4", tbl, + ARRAY_SIZE(ctl_forward_entry)); if (!forw_hdr) goto err_reg_ctl; net->ipv4.forw_hdr = forw_hdr; #endif + for (i = 0; i < IN4_ADDR_HSIZE; i++) + INIT_HLIST_HEAD(&net->ipv4.inet_addr_lst[i]); + + INIT_DEFERRABLE_WORK(&net->ipv4.addr_chk_work, check_lifetime); + net->ipv4.devconf_all = all; net->ipv4.devconf_dflt = dflt; return 0; @@ -2730,14 +2842,20 @@ err_alloc_ctl: err_alloc_dflt: kfree(all); err_alloc_all: + kfree(net->ipv4.inet_addr_lst); +err_alloc_hash: return err; } static __net_exit void devinet_exit_net(struct net *net) { #ifdef CONFIG_SYSCTL - struct ctl_table *tbl; + const struct ctl_table *tbl; +#endif + cancel_delayed_work_sync(&net->ipv4.addr_chk_work); + +#ifdef CONFIG_SYSCTL tbl = net->ipv4.forw_hdr->ctl_table_arg; unregister_net_sysctl_table(net->ipv4.forw_hdr); __devinet_sysctl_unregister(net, net->ipv4.devconf_dflt, @@ -2748,6 +2866,7 @@ static __net_exit void devinet_exit_net(struct net *net) #endif kfree(net->ipv4.devconf_dflt); kfree(net->ipv4.devconf_all); + kfree(net->ipv4.inet_addr_lst); } static __net_initdata struct pernet_operations devinet_ops = { @@ -2763,23 +2882,27 @@ static struct rtnl_af_ops inet_af_ops __read_mostly = { .set_link_af = inet_set_link_af, }; +static const struct rtnl_msg_handler devinet_rtnl_msg_handlers[] __initconst = { + {.protocol = PF_INET, .msgtype = RTM_NEWADDR, .doit = inet_rtm_newaddr, + .flags = RTNL_FLAG_DOIT_PERNET}, + {.protocol = PF_INET, .msgtype = RTM_DELADDR, .doit = inet_rtm_deladdr, + .flags = RTNL_FLAG_DOIT_PERNET}, + {.protocol = PF_INET, .msgtype = RTM_GETADDR, .dumpit = inet_dump_ifaddr, + .flags = RTNL_FLAG_DUMP_UNLOCKED | RTNL_FLAG_DUMP_SPLIT_NLM_DONE}, + {.protocol = PF_INET, .msgtype = RTM_GETNETCONF, + .doit = inet_netconf_get_devconf, .dumpit = inet_netconf_dump_devconf, + .flags = RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED}, + {.owner = THIS_MODULE, .protocol = PF_INET, .msgtype = RTM_GETMULTICAST, + .dumpit = inet_dump_ifmcaddr, .flags = RTNL_FLAG_DUMP_UNLOCKED}, +}; + void __init devinet_init(void) { - int i; - - for (i = 0; i < IN4_ADDR_HSIZE; i++) - INIT_HLIST_HEAD(&inet_addr_lst[i]); - register_pernet_subsys(&devinet_ops); register_netdevice_notifier(&ip_netdev_notifier); - queue_delayed_work(system_power_efficient_wq, &check_lifetime_work, 0); - - rtnl_af_register(&inet_af_ops); + if (rtnl_af_register(&inet_af_ops)) + panic("Unable to register inet_af_ops\n"); - rtnl_register(PF_INET, RTM_NEWADDR, inet_rtm_newaddr, NULL, 0); - rtnl_register(PF_INET, RTM_DELADDR, inet_rtm_deladdr, NULL, 0); - rtnl_register(PF_INET, RTM_GETADDR, NULL, inet_dump_ifaddr, 0); - rtnl_register(PF_INET, RTM_GETNETCONF, inet_netconf_get_devconf, - inet_netconf_dump_devconf, 0); + rtnl_register_many(devinet_rtnl_msg_handlers); } diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c index 851f542928a3..2c922afadb8f 100644 --- a/net/ipv4/esp4.c +++ b/net/ipv4/esp4.c @@ -20,6 +20,7 @@ #include <net/udp.h> #include <net/tcp.h> #include <net/espintcp.h> +#include <linux/skbuff_ref.h> #include <linux/highmem.h> @@ -95,7 +96,7 @@ static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead, __alignof__(struct scatterlist)); } -static void esp_ssg_unref(struct xfrm_state *x, void *tmp) +static void esp_ssg_unref(struct xfrm_state *x, void *tmp, struct sk_buff *skb) { struct crypto_aead *aead = x->data; int extralen = 0; @@ -114,54 +115,25 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp) */ if (req->src != req->dst) for (sg = sg_next(req->src); sg; sg = sg_next(sg)) - put_page(sg_page(sg)); + skb_page_unref(page_to_netmem(sg_page(sg)), + skb->pp_recycle); } #ifdef CONFIG_INET_ESPINTCP -struct esp_tcp_sk { - struct sock *sk; - struct rcu_head rcu; -}; - -static void esp_free_tcp_sk(struct rcu_head *head) -{ - struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); - - sock_put(esk->sk); - kfree(esk); -} - static struct sock *esp_find_tcp_sk(struct xfrm_state *x) { struct xfrm_encap_tmpl *encap = x->encap; - struct esp_tcp_sk *esk; + struct net *net = xs_net(x); __be16 sport, dport; - struct sock *nsk; struct sock *sk; - sk = rcu_dereference(x->encap_sk); - if (sk && sk->sk_state == TCP_ESTABLISHED) - return sk; - spin_lock_bh(&x->lock); sport = encap->encap_sport; dport = encap->encap_dport; - nsk = rcu_dereference_protected(x->encap_sk, - lockdep_is_held(&x->lock)); - if (sk && sk == nsk) { - esk = kmalloc(sizeof(*esk), GFP_ATOMIC); - if (!esk) { - spin_unlock_bh(&x->lock); - return ERR_PTR(-ENOMEM); - } - RCU_INIT_POINTER(x->encap_sk, NULL); - esk->sk = sk; - call_rcu(&esk->rcu, esp_free_tcp_sk); - } spin_unlock_bh(&x->lock); - sk = inet_lookup_established(xs_net(x), &tcp_hashinfo, x->id.daddr.a4, - dport, x->props.saddr.a4, sport, 0); + sk = inet_lookup_established(net, x->id.daddr.a4, dport, + x->props.saddr.a4, sport, 0); if (!sk) return ERR_PTR(-ENOENT); @@ -170,20 +142,6 @@ static struct sock *esp_find_tcp_sk(struct xfrm_state *x) return ERR_PTR(-EINVAL); } - spin_lock_bh(&x->lock); - nsk = rcu_dereference_protected(x->encap_sk, - lockdep_is_held(&x->lock)); - if (encap->encap_sport != sport || - encap->encap_dport != dport) { - sock_put(sk); - sk = nsk ?: ERR_PTR(-EREMCHG); - } else if (sk == nsk) { - sock_put(sk); - } else { - rcu_assign_pointer(x->encap_sk, sk); - } - spin_unlock_bh(&x->lock); - return sk; } @@ -196,8 +154,10 @@ static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) sk = esp_find_tcp_sk(x); err = PTR_ERR_OR_ZERO(sk); - if (err) + if (err) { + kfree_skb(skb); goto out; + } bh_lock_sock(sk); if (sock_owned_by_user(sk)) @@ -206,6 +166,8 @@ static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) err = espintcp_push_skb(sk, skb); bh_unlock_sock(sk); + sock_put(sk); + out: rcu_read_unlock(); return err; @@ -237,15 +199,14 @@ static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb) #else static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb) { - kfree_skb(skb); - + WARN_ON(1); return -EOPNOTSUPP; } #endif -static void esp_output_done(struct crypto_async_request *base, int err) +static void esp_output_done(void *data, int err) { - struct sk_buff *skb = base->data; + struct sk_buff *skb = data; struct xfrm_offload *xo = xfrm_offload(skb); void *tmp; struct xfrm_state *x; @@ -259,7 +220,7 @@ static void esp_output_done(struct crypto_async_request *base, int err) } tmp = ESP_SKB_CB(skb)->tmp; - esp_ssg_unref(x, tmp); + esp_ssg_unref(x, tmp, skb); kfree(tmp); if (xo && (xo->flags & XFRM_DEV_RESUME)) { @@ -277,7 +238,7 @@ static void esp_output_done(struct crypto_async_request *base, int err) x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP) esp_output_tail_tcp(x, skb); else - xfrm_output_resume(skb->sk, skb, err); + xfrm_output_resume(skb_to_full_sk(skb), skb, err); } } @@ -331,12 +292,12 @@ static struct ip_esp_hdr *esp_output_set_extra(struct sk_buff *skb, return esph; } -static void esp_output_done_esn(struct crypto_async_request *base, int err) +static void esp_output_done_esn(void *data, int err) { - struct sk_buff *skb = base->data; + struct sk_buff *skb = data; esp_output_restore_header(skb); - esp_output_done(base, err); + esp_output_done(data, err); } static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb, @@ -346,8 +307,8 @@ static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb, __be16 dport) { struct udphdr *uh; - __be32 *udpdata32; unsigned int len; + struct xfrm_offload *xo = xfrm_offload(skb); len = skb->len + esp->tailen - skb_transport_offset(skb); if (len + sizeof(struct iphdr) > IP_MAX_MTU) @@ -359,13 +320,12 @@ static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb, uh->len = htons(len); uh->check = 0; - *skb_mac_header(skb) = IPPROTO_UDP; - - if (encap_type == UDP_ENCAP_ESPINUDP_NON_IKE) { - udpdata32 = (__be32 *)(uh + 1); - udpdata32[0] = udpdata32[1] = 0; - return (struct ip_esp_hdr *)(udpdata32 + 2); - } + /* For IPv4 ESP with UDP encapsulation, if xo is not null, the skb is in the crypto offload + * data path, which means that esp_output_udp_encap is called outside of the XFRM stack. + * In this case, the mac header doesn't point to the IPv4 protocol field, so don't set it. + */ + if (!xo || encap_type != UDP_ENCAP_ESPINUDP) + *skb_mac_header(skb) = IPPROTO_UDP; return (struct ip_esp_hdr *)(uh + 1); } @@ -391,6 +351,8 @@ static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x, if (IS_ERR(sk)) return ERR_CAST(sk); + sock_put(sk); + *lenp = htons(len); esph = (struct ip_esp_hdr *)(lenp + 1); @@ -422,7 +384,6 @@ static int esp_output_encap(struct xfrm_state *x, struct sk_buff *skb, switch (encap_type) { default: case UDP_ENCAP_ESPINUDP: - case UDP_ENCAP_ESPINUDP_NON_IKE: esph = esp_output_udp_encap(skb, encap_type, esp, sport, dport); break; case TCP_ENCAP_ESPINTCP: @@ -455,6 +416,10 @@ int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info * return err; } + if (ALIGN(tailen, L1_CACHE_BYTES) > PAGE_SIZE || + ALIGN(skb->data_len, L1_CACHE_BYTES) > PAGE_SIZE) + goto cow; + if (!skb_cloned(skb)) { if (tailen <= skb_tailroom(skb)) { nfrags = 1; @@ -498,9 +463,7 @@ int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info * nfrags++; - skb->len += tailen; - skb->data_len += tailen; - skb->truesize += tailen; + skb_len_add(skb, tailen); if (sk && sk_fullsock(sk)) refcount_add(tailen, &sk->sk_wmem_alloc); @@ -636,7 +599,7 @@ int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info * } if (sg != dsg) - esp_ssg_unref(x, tmp); + esp_ssg_unref(x, tmp, skb); if (!err && x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP) err = esp_output_tail_tcp(x, skb); @@ -671,7 +634,7 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb) struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb); u32 padto; - padto = min(x->tfcpad, __xfrm_state_mtu(x, dst->child_mtu_cached)); + padto = min(x->tfcpad, xfrm_state_mtu(x, dst->child_mtu_cached)); if (skb->len < padto) esp.tfclen = padto - skb->len; } @@ -701,7 +664,6 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb) static inline int esp_remove_trailer(struct sk_buff *skb) { struct xfrm_state *x = xfrm_input_state(skb); - struct xfrm_offload *xo = xfrm_offload(skb); struct crypto_aead *aead = x->data; int alen, hlen, elen; int padlen, trimlen; @@ -713,11 +675,6 @@ static inline int esp_remove_trailer(struct sk_buff *skb) hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead); elen = skb->len - hlen; - if (xo && (xo->flags & XFRM_ESP_NO_TRAILER)) { - ret = xo->proto; - goto out; - } - if (skb_copy_bits(skb, skb->len - alen - 2, nexthdr, 2)) BUG(); @@ -735,7 +692,9 @@ static inline int esp_remove_trailer(struct sk_buff *skb) skb->csum = csum_block_sub(skb->csum, csumdiff, skb->len - trimlen); } - pskb_trim(skb, skb->len - trimlen); + ret = pskb_trim(skb, skb->len - trimlen); + if (unlikely(ret)) + return ret; ret = nexthdr[1]; @@ -776,7 +735,6 @@ int esp_input_done2(struct sk_buff *skb, int err) source = th->source; break; case UDP_ENCAP_ESPINUDP: - case UDP_ENCAP_ESPINUDP_NON_IKE: source = uh->source; break; default: @@ -787,7 +745,7 @@ int esp_input_done2(struct sk_buff *skb, int err) /* * 1) if the NAT-T peer's IP or port changed then - * advertize the change to the keying daemon. + * advertise the change to the keying daemon. * This is an inbound SA, so just compare * SRC ports. */ @@ -819,7 +777,8 @@ int esp_input_done2(struct sk_buff *skb, int err) } skb_pull_rcsum(skb, hlen); - if (x->props.mode == XFRM_MODE_TUNNEL) + if (x->props.mode == XFRM_MODE_TUNNEL || + x->props.mode == XFRM_MODE_IPTFS) skb_reset_transport_header(skb); else skb_set_transport_header(skb, -ihl); @@ -833,9 +792,9 @@ out: } EXPORT_SYMBOL_GPL(esp_input_done2); -static void esp_input_done(struct crypto_async_request *base, int err) +static void esp_input_done(void *data, int err) { - struct sk_buff *skb = base->data; + struct sk_buff *skb = data; xfrm_input_resume(skb, esp_input_done2(skb, err)); } @@ -863,12 +822,12 @@ static void esp_input_set_header(struct sk_buff *skb, __be32 *seqhi) } } -static void esp_input_done_esn(struct crypto_async_request *base, int err) +static void esp_input_done_esn(void *data, int err) { - struct sk_buff *skb = base->data; + struct sk_buff *skb = data; esp_input_restore_header(skb); - esp_input_done(base, err); + esp_input_done(data, err); } /* @@ -1011,16 +970,17 @@ static void esp_destroy(struct xfrm_state *x) crypto_free_aead(aead); } -static int esp_init_aead(struct xfrm_state *x) +static int esp_init_aead(struct xfrm_state *x, struct netlink_ext_ack *extack) { char aead_name[CRYPTO_MAX_ALG_NAME]; struct crypto_aead *aead; int err; - err = -ENAMETOOLONG; if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)", - x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME) - goto error; + x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + return -ENAMETOOLONG; + } aead = crypto_alloc_aead(aead_name, 0, 0); err = PTR_ERR(aead); @@ -1038,11 +998,15 @@ static int esp_init_aead(struct xfrm_state *x) if (err) goto error; + return 0; + error: + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); return err; } -static int esp_init_authenc(struct xfrm_state *x) +static int esp_init_authenc(struct xfrm_state *x, + struct netlink_ext_ack *extack) { struct crypto_aead *aead; struct crypto_authenc_key_param *param; @@ -1053,10 +1017,6 @@ static int esp_init_authenc(struct xfrm_state *x) unsigned int keylen; int err; - err = -EINVAL; - if (!x->ealg) - goto error; - err = -ENAMETOOLONG; if ((x->props.flags & XFRM_STATE_ESN)) { @@ -1065,22 +1025,28 @@ static int esp_init_authenc(struct xfrm_state *x) x->geniv ?: "", x->geniv ? "(" : "", x->aalg ? x->aalg->alg_name : "digest_null", x->ealg->alg_name, - x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); goto error; + } } else { if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME, "%s%sauthenc(%s,%s)%s", x->geniv ?: "", x->geniv ? "(" : "", x->aalg ? x->aalg->alg_name : "digest_null", x->ealg->alg_name, - x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); goto error; + } } aead = crypto_alloc_aead(authenc_name, 0, 0); err = PTR_ERR(aead); - if (IS_ERR(aead)) + if (IS_ERR(aead)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto error; + } x->data = aead; @@ -1110,17 +1076,16 @@ static int esp_init_authenc(struct xfrm_state *x) err = -EINVAL; if (aalg_desc->uinfo.auth.icv_fullbits / 8 != crypto_aead_authsize(aead)) { - pr_info("ESP: %s digestsize %u != %hu\n", - x->aalg->alg_name, - crypto_aead_authsize(aead), - aalg_desc->uinfo.auth.icv_fullbits / 8); + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto free_key; } err = crypto_aead_setauthsize( aead, x->aalg->alg_trunc_len / 8); - if (err) + if (err) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); goto free_key; + } } param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8); @@ -1129,13 +1094,13 @@ static int esp_init_authenc(struct xfrm_state *x) err = crypto_aead_setkey(aead, key, keylen); free_key: - kfree(key); + kfree_sensitive(key); error: return err; } -static int esp_init_state(struct xfrm_state *x) +static int esp_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) { struct crypto_aead *aead; u32 align; @@ -1143,10 +1108,14 @@ static int esp_init_state(struct xfrm_state *x) x->data = NULL; - if (x->aead) - err = esp_init_aead(x); - else - err = esp_init_authenc(x); + if (x->aead) { + err = esp_init_aead(x, extack); + } else if (x->ealg) { + err = esp_init_authenc(x, extack); + } else { + NL_SET_ERR_MSG(extack, "ESP: AEAD or CRYPT must be provided"); + err = -EINVAL; + } if (err) goto error; @@ -1164,14 +1133,12 @@ static int esp_init_state(struct xfrm_state *x) switch (encap->encap_type) { default: + NL_SET_ERR_MSG(extack, "Unsupported encapsulation type for ESP"); err = -EINVAL; goto error; case UDP_ENCAP_ESPINUDP: x->props.header_len += sizeof(struct udphdr); break; - case UDP_ENCAP_ESPINUDP_NON_IKE: - x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32); - break; #ifdef CONFIG_INET_ESPINTCP case TCP_ENCAP_ESPINTCP: /* only the length field, TCP encap is done by @@ -1237,5 +1204,6 @@ static void __exit esp4_fini(void) module_init(esp4_init); module_exit(esp4_fini); +MODULE_DESCRIPTION("IPv4 ESP transformation library"); MODULE_LICENSE("GPL"); MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_ESP); diff --git a/net/ipv4/esp4_offload.c b/net/ipv4/esp4_offload.c index d87f02a6e934..05828d4cb6cd 100644 --- a/net/ipv4/esp4_offload.c +++ b/net/ipv4/esp4_offload.c @@ -17,6 +17,7 @@ #include <linux/err.h> #include <linux/module.h> #include <net/gro.h> +#include <net/gso.h> #include <net/ip.h> #include <net/xfrm.h> #include <net/esp.h> @@ -32,6 +33,7 @@ static struct sk_buff *esp4_gro_receive(struct list_head *head, int offset = skb_gro_offset(skb); struct xfrm_offload *xo; struct xfrm_state *x; + int encap_type = 0; __be32 seq; __be32 spi; @@ -51,9 +53,16 @@ static struct sk_buff *esp4_gro_receive(struct list_head *head, if (sp->len == XFRM_MAX_DEPTH) goto out_reset; - x = xfrm_state_lookup(dev_net(skb->dev), skb->mark, - (xfrm_address_t *)&ip_hdr(skb)->daddr, - spi, IPPROTO_ESP, AF_INET); + x = xfrm_input_state_lookup(dev_net(skb->dev), skb->mark, + (xfrm_address_t *)&ip_hdr(skb)->daddr, + spi, IPPROTO_ESP, AF_INET); + + if (unlikely(x && x->dir && x->dir != XFRM_SA_DIR_IN)) { + /* non-offload path will record the error and audit log */ + xfrm_state_put(x); + x = NULL; + } + if (!x) goto out_reset; @@ -69,6 +78,9 @@ static struct sk_buff *esp4_gro_receive(struct list_head *head, xo->flags |= XFRM_GRO; + if (NAPI_GRO_CB(skb)->proto == IPPROTO_UDP) + encap_type = UDP_ENCAP_ESPINUDP; + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; XFRM_SPI_SKB_CB(skb)->family = AF_INET; XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); @@ -76,7 +88,7 @@ static struct sk_buff *esp4_gro_receive(struct list_head *head, /* We don't need to handle errors from xfrm_input, it does all * the error handling and frees the resources on error. */ - xfrm_input(skb, IPPROTO_ESP, spi, -2); + xfrm_input(skb, IPPROTO_ESP, spi, encap_type); return ERR_PTR(-EINPROGRESS); out_reset: @@ -110,8 +122,12 @@ static struct sk_buff *xfrm4_tunnel_gso_segment(struct xfrm_state *x, struct sk_buff *skb, netdev_features_t features) { - __skb_push(skb, skb->mac_len); - return skb_mac_gso_segment(skb, features); + const struct xfrm_mode *inner_mode = xfrm_ip2inner_mode(x, + XFRM_MODE_SKB_CB(skb)->protocol); + __be16 type = inner_mode->family == AF_INET6 ? htons(ETH_P_IPV6) + : htons(ETH_P_IP); + + return skb_eth_gso_segment(skb, features, type); } static struct sk_buff *xfrm4_transport_gso_segment(struct xfrm_state *x, @@ -160,6 +176,9 @@ static struct sk_buff *xfrm4_beet_gso_segment(struct xfrm_state *x, skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV4; } + if (proto == IPPROTO_IPV6) + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP4; + __skb_pull(skb, skb_transport_offset(skb)); ops = rcu_dereference(inet_offloads[proto]); if (likely(ops && ops->callbacks.gso_segment)) @@ -254,6 +273,7 @@ static int esp_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_ struct esp_info esp; bool hw_offload = true; __u32 seq; + int encap_type = 0; esp.inplace = true; @@ -286,8 +306,10 @@ static int esp_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_ esp.esph = ip_esp_hdr(skb); + if (x->encap) + encap_type = x->encap->encap_type; - if (!hw_offload || !skb_is_gso(skb)) { + if (!hw_offload || !skb_is_gso(skb) || (hw_offload && encap_type == UDP_ENCAP_ESPINUDP)) { esp.nfrags = esp_output_head(x, skb, &esp); if (esp.nfrags < 0) return esp.nfrags; @@ -309,8 +331,23 @@ static int esp_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_ xo->seq.low += skb_shinfo(skb)->gso_segs; } + if (xo->seq.low < seq) + xo->seq.hi++; + esp.seqno = cpu_to_be64(seq + ((u64)xo->seq.hi << 32)); + if (hw_offload && encap_type == UDP_ENCAP_ESPINUDP) { + /* In the XFRM stack, the encapsulation protocol is set to iphdr->protocol by + * setting *skb_mac_header(skb) (see esp_output_udp_encap()) where skb->mac_header + * points to iphdr->protocol (see xfrm4_tunnel_encap_add()). + * However, in esp_xmit(), skb->mac_header doesn't point to iphdr->protocol. + * Therefore, the protocol field needs to be corrected. + */ + ip_hdr(skb)->protocol = IPPROTO_UDP; + + esph->seq_no = htonl(seq); + } + ip_hdr(skb)->tot_len = htons(skb->len); ip_send_check(ip_hdr(skb)); @@ -332,6 +369,9 @@ static int esp_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_ secpath_reset(skb); + if (skb_needs_linearize(skb, skb->dev->features) && + __skb_linearize(skb)) + return -ENOMEM; return 0; } diff --git a/net/ipv4/fib_frontend.c b/net/ipv4/fib_frontend.c index 4d61ddd8a0ec..1dab44e13d3b 100644 --- a/net/ipv4/fib_frontend.c +++ b/net/ipv4/fib_frontend.c @@ -32,6 +32,8 @@ #include <linux/list.h> #include <linux/slab.h> +#include <net/flow.h> +#include <net/inet_dscp.h> #include <net/ip.h> #include <net/protocol.h> #include <net/route.h> @@ -290,9 +292,9 @@ __be32 fib_compute_spec_dst(struct sk_buff *skb) bool vmark = in_dev && IN_DEV_SRC_VMARK(in_dev); struct flowi4 fl4 = { .flowi4_iif = LOOPBACK_IFINDEX, - .flowi4_oif = l3mdev_master_ifindex_rcu(dev), + .flowi4_l3mdev = l3mdev_master_ifindex_rcu(dev), .daddr = ip_hdr(skb)->saddr, - .flowi4_tos = ip_hdr(skb)->tos & IPTOS_RT_MASK, + .flowi4_dscp = ip4h_dscp(ip_hdr(skb)), .flowi4_scope = scope, .flowi4_mark = vmark ? skb->mark : 0, }; @@ -341,10 +343,11 @@ EXPORT_SYMBOL_GPL(fib_info_nh_uses_dev); * called with rcu_read_lock() */ static int __fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, - u8 tos, int oif, struct net_device *dev, + dscp_t dscp, int oif, struct net_device *dev, int rpf, struct in_device *idev, u32 *itag) { struct net *net = dev_net(dev); + enum skb_drop_reason reason; struct flow_keys flkeys; int ret, no_addr; struct fib_result res; @@ -352,12 +355,11 @@ static int __fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, bool dev_match; fl4.flowi4_oif = 0; - fl4.flowi4_iif = l3mdev_master_ifindex_rcu(dev); - if (!fl4.flowi4_iif) - fl4.flowi4_iif = oif ? : LOOPBACK_IFINDEX; + fl4.flowi4_l3mdev = l3mdev_master_ifindex_rcu(dev); + fl4.flowi4_iif = oif ? : LOOPBACK_IFINDEX; fl4.daddr = src; fl4.saddr = dst; - fl4.flowi4_tos = tos; + fl4.flowi4_dscp = dscp; fl4.flowi4_scope = RT_SCOPE_UNIVERSE; fl4.flowi4_tun_key.tun_id = 0; fl4.flowi4_flags = 0; @@ -377,9 +379,15 @@ static int __fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, if (fib_lookup(net, &fl4, &res, 0)) goto last_resort; - if (res.type != RTN_UNICAST && - (res.type != RTN_LOCAL || !IN_DEV_ACCEPT_LOCAL(idev))) - goto e_inval; + if (res.type != RTN_UNICAST) { + if (res.type != RTN_LOCAL) { + reason = SKB_DROP_REASON_IP_INVALID_SOURCE; + goto e_inval; + } else if (!IN_DEV_ACCEPT_LOCAL(idev)) { + reason = SKB_DROP_REASON_IP_LOCAL_SOURCE; + goto e_inval; + } + } fib_combine_itag(itag, &res); dev_match = fib_info_nh_uses_dev(res.fi, dev); @@ -412,14 +420,14 @@ last_resort: return 0; e_inval: - return -EINVAL; + return -reason; e_rpf: - return -EXDEV; + return -SKB_DROP_REASON_IP_RPFILTER; } /* Ignore rp_filter for packets protected by IPsec. */ int fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, - u8 tos, int oif, struct net_device *dev, + dscp_t dscp, int oif, struct net_device *dev, struct in_device *idev, u32 *itag) { int r = secpath_exists(skb) ? 0 : IN_DEV_RPFILTER(idev); @@ -436,8 +444,11 @@ int fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, if (net->ipv4.fib_has_custom_local_routes || fib4_has_custom_rules(net)) goto full_check; + /* Within the same container, it is regarded as a martian source, + * and the same host but different containers are not. + */ if (inet_lookup_ifaddr_rcu(net, src)) - return -EINVAL; + return -SKB_DROP_REASON_IP_LOCAL_SOURCE; ok: *itag = 0; @@ -445,7 +456,8 @@ ok: } full_check: - return __fib_validate_source(skb, src, dst, tos, oif, dev, r, idev, itag); + return __fib_validate_source(skb, src, dst, dscp, oif, dev, r, idev, + itag); } static inline __be32 sk_extract_addr(struct sockaddr *addr) @@ -542,18 +554,16 @@ static int rtentry_to_fib_config(struct net *net, int cmd, struct rtentry *rt, const struct in_ifaddr *ifa; struct in_device *in_dev; - in_dev = __in_dev_get_rtnl(dev); + in_dev = __in_dev_get_rtnl_net(dev); if (!in_dev) return -ENODEV; *colon = ':'; - rcu_read_lock(); - in_dev_for_each_ifa_rcu(ifa, in_dev) { + in_dev_for_each_ifa_rtnl_net(net, ifa, in_dev) { if (strcmp(ifa->ifa_label, devname) == 0) break; } - rcu_read_unlock(); if (!ifa) return -ENODEV; @@ -573,6 +583,9 @@ static int rtentry_to_fib_config(struct net *net, int cmd, struct rtentry *rt, cfg->fc_scope = RT_SCOPE_UNIVERSE; } + if (!cfg->fc_table) + cfg->fc_table = RT_TABLE_MAIN; + if (cmd == SIOCDELRT) return 0; @@ -621,7 +634,7 @@ int ip_rt_ioctl(struct net *net, unsigned int cmd, struct rtentry *rt) if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) return -EPERM; - rtnl_lock(); + rtnl_net_lock(net); err = rtentry_to_fib_config(net, cmd, rt, &cfg); if (err == 0) { struct fib_table *tb; @@ -645,7 +658,7 @@ int ip_rt_ioctl(struct net *net, unsigned int cmd, struct rtentry *rt) /* allocated by rtentry_to_fib_config() */ kfree(cfg.fc_mx); } - rtnl_unlock(); + rtnl_net_unlock(net); return err; } return -EINVAL; @@ -735,8 +748,16 @@ static int rtm_to_fib_config(struct net *net, struct sk_buff *skb, memset(cfg, 0, sizeof(*cfg)); rtm = nlmsg_data(nlh); + + if (!inet_validate_dscp(rtm->rtm_tos)) { + NL_SET_ERR_MSG(extack, + "Invalid dsfield (tos): ECN bits must be 0"); + err = -EINVAL; + goto errout; + } + cfg->fc_dscp = inet_dsfield_to_dscp(rtm->rtm_tos); + cfg->fc_dst_len = rtm->rtm_dst_len; - cfg->fc_tos = rtm->rtm_tos; cfg->fc_table = rtm->rtm_table; cfg->fc_protocol = rtm->rtm_protocol; cfg->fc_scope = rtm->rtm_scope; @@ -815,21 +836,38 @@ static int rtm_to_fib_config(struct net *net, struct sk_buff *skb, } } + if (cfg->fc_dst_len > 32) { + NL_SET_ERR_MSG(extack, "Invalid prefix length"); + err = -EINVAL; + goto errout; + } + + if (cfg->fc_dst_len < 32 && (ntohl(cfg->fc_dst) << cfg->fc_dst_len)) { + NL_SET_ERR_MSG(extack, "Invalid prefix for given prefix length"); + err = -EINVAL; + goto errout; + } + if (cfg->fc_nh_id) { if (cfg->fc_oif || cfg->fc_gw_family || cfg->fc_encap || cfg->fc_mp) { NL_SET_ERR_MSG(extack, "Nexthop specification and nexthop id are mutually exclusive"); - return -EINVAL; + err = -EINVAL; + goto errout; } } if (has_gw && has_via) { NL_SET_ERR_MSG(extack, "Nexthop configuration can not contain both GATEWAY and VIA"); - return -EINVAL; + err = -EINVAL; + goto errout; } + if (!cfg->fc_table) + cfg->fc_table = RT_TABLE_MAIN; + return 0; errout: return err; @@ -847,20 +885,24 @@ static int inet_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, if (err < 0) goto errout; + rtnl_net_lock(net); + if (cfg.fc_nh_id && !nexthop_find_by_id(net, cfg.fc_nh_id)) { NL_SET_ERR_MSG(extack, "Nexthop id does not exist"); err = -EINVAL; - goto errout; + goto unlock; } tb = fib_get_table(net, cfg.fc_table); if (!tb) { NL_SET_ERR_MSG(extack, "FIB table does not exist"); err = -ESRCH; - goto errout; + goto unlock; } err = fib_table_delete(net, tb, &cfg, extack); +unlock: + rtnl_net_unlock(net); errout: return err; } @@ -877,15 +919,20 @@ static int inet_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, if (err < 0) goto errout; + rtnl_net_lock(net); + tb = fib_new_table(net, cfg.fc_table); if (!tb) { err = -ENOBUFS; - goto errout; + goto unlock; } err = fib_table_insert(net, tb, &cfg, extack); if (!err && cfg.fc_type == RTN_LOCAL) net->ipv4.fib_has_custom_local_routes = true; + +unlock: + rtnl_net_unlock(net); errout: return err; } @@ -899,14 +946,15 @@ int ip_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, struct rtmsg *rtm; int err, i; - ASSERT_RTNL(); + if (filter->rtnl_held) + ASSERT_RTNL(); - if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + rtm = nlmsg_payload(nlh, sizeof(*rtm)); + if (!rtm) { NL_SET_ERR_MSG(extack, "Invalid header for FIB dump request"); return -EINVAL; } - rtm = nlmsg_data(nlh); if (rtm->rtm_dst_len || rtm->rtm_src_len || rtm->rtm_tos || rtm->rtm_scope) { NL_SET_ERR_MSG(extack, "Invalid values in header for FIB dump request"); @@ -944,7 +992,10 @@ int ip_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, break; case RTA_OIF: ifindex = nla_get_u32(tb[i]); - filter->dev = __dev_get_by_index(net, ifindex); + if (filter->rtnl_held) + filter->dev = __dev_get_by_index(net, ifindex); + else + filter->dev = dev_get_by_index_rcu(net, ifindex); if (!filter->dev) return -ENODEV; break; @@ -966,20 +1017,24 @@ EXPORT_SYMBOL_GPL(ip_valid_fib_dump_req); static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) { - struct fib_dump_filter filter = { .dump_routes = true, - .dump_exceptions = true }; + struct fib_dump_filter filter = { + .dump_routes = true, + .dump_exceptions = true, + .rtnl_held = false, + }; const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); unsigned int h, s_h; unsigned int e = 0, s_e; struct fib_table *tb; struct hlist_head *head; - int dumped = 0, err; + int dumped = 0, err = 0; + rcu_read_lock(); if (cb->strict_check) { err = ip_valid_fib_dump_req(net, nlh, &filter, cb); if (err < 0) - return err; + goto unlock; } else if (nlmsg_len(nlh) >= sizeof(struct rtmsg)) { struct rtmsg *rtm = nlmsg_data(nlh); @@ -988,29 +1043,26 @@ static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) /* ipv4 does not use prefix flag */ if (filter.flags & RTM_F_PREFIX) - return skb->len; + goto unlock; if (filter.table_id) { tb = fib_get_table(net, filter.table_id); if (!tb) { if (rtnl_msg_family(cb->nlh) != PF_INET) - return skb->len; + goto unlock; NL_SET_ERR_MSG(cb->extack, "ipv4: FIB table does not exist"); - return -ENOENT; + err = -ENOENT; + goto unlock; } - - rcu_read_lock(); err = fib_table_dump(tb, skb, cb, &filter); - rcu_read_unlock(); - return skb->len ? : err; + goto unlock; } s_h = cb->args[0]; s_e = cb->args[1]; - rcu_read_lock(); - + err = 0; for (h = s_h; h < FIB_TABLE_HASHSZ; h++, s_e = 0) { e = 0; head = &net->ipv4.fib_table_hash[h]; @@ -1021,25 +1073,20 @@ static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) memset(&cb->args[2], 0, sizeof(cb->args) - 2 * sizeof(cb->args[0])); err = fib_table_dump(tb, skb, cb, &filter); - if (err < 0) { - if (likely(skb->len)) - goto out; - - goto out_err; - } + if (err < 0) + goto out; dumped = 1; next: e++; } } out: - err = skb->len; -out_err: - rcu_read_unlock(); cb->args[1] = e; cb->args[0] = h; +unlock: + rcu_read_unlock(); return err; } @@ -1112,9 +1159,11 @@ void fib_add_ifaddr(struct in_ifaddr *ifa) return; /* Add broadcast address, if it is explicitly assigned. */ - if (ifa->ifa_broadcast && ifa->ifa_broadcast != htonl(0xFFFFFFFF)) + if (ifa->ifa_broadcast && ifa->ifa_broadcast != htonl(0xFFFFFFFF)) { fib_magic(RTM_NEWROUTE, RTN_BROADCAST, ifa->ifa_broadcast, 32, prim, 0); + arp_invalidate(dev, ifa->ifa_broadcast, false); + } if (!ipv4_is_zeronet(prefix) && !(ifa->ifa_flags & IFA_F_SECONDARY) && (prefix != addr || ifa->ifa_prefixlen < 32)) { @@ -1128,6 +1177,7 @@ void fib_add_ifaddr(struct in_ifaddr *ifa) if (ifa->ifa_prefixlen < 31) { fib_magic(RTM_NEWROUTE, RTN_BROADCAST, prefix | ~mask, 32, prim, 0); + arp_invalidate(dev, prefix | ~mask, false); } } } @@ -1323,7 +1373,7 @@ static void nl_fib_lookup(struct net *net, struct fib_result_nl *frn) struct flowi4 fl4 = { .flowi4_mark = frn->fl_mark, .daddr = frn->fl_addr, - .flowi4_tos = frn->fl_tos, + .flowi4_dscp = inet_dsfield_to_dscp(frn->fl_tos), .flowi4_scope = frn->fl_scope, }; struct fib_table *tb; @@ -1370,7 +1420,7 @@ static void nl_fib_input(struct sk_buff *skb) return; nlh = nlmsg_hdr(skb); - frn = (struct fib_result_nl *) nlmsg_data(nlh); + frn = nlmsg_data(nlh); nl_fib_lookup(net, frn); portid = NETLINK_CB(skb).portid; /* netlink portid */ @@ -1411,7 +1461,7 @@ static void fib_disable_ip(struct net_device *dev, unsigned long event, static int fib_inetaddr_event(struct notifier_block *this, unsigned long event, void *ptr) { - struct in_ifaddr *ifa = (struct in_ifaddr *)ptr; + struct in_ifaddr *ifa = ptr; struct net_device *dev = ifa->ifa_dev->dev; struct net *net = dev_net(dev); @@ -1422,7 +1472,7 @@ static int fib_inetaddr_event(struct notifier_block *this, unsigned long event, fib_sync_up(dev, RTNH_F_DEAD); #endif atomic_inc(&net->ipv4.dev_addr_genid); - rt_cache_flush(dev_net(dev)); + rt_cache_flush(net); break; case NETDEV_DOWN: fib_del_ifaddr(ifa, NULL); @@ -1433,7 +1483,7 @@ static int fib_inetaddr_event(struct notifier_block *this, unsigned long event, */ fib_disable_ip(dev, event, true); } else { - rt_cache_flush(dev_net(dev)); + rt_cache_flush(net); } break; } @@ -1475,7 +1525,7 @@ static int fib_netdev_event(struct notifier_block *this, unsigned long event, vo fib_disable_ip(dev, event, false); break; case NETDEV_CHANGE: - flags = dev_get_flags(dev); + flags = netif_get_flags(dev); if (flags & (IFF_RUNNING | IFF_LOWER_UP)) fib_sync_up(dev, RTNH_F_LINKDOWN); else @@ -1547,7 +1597,7 @@ static void ip_fib_net_exit(struct net *net) { int i; - rtnl_lock(); + ASSERT_RTNL_NET(net); #ifdef CONFIG_IP_MULTIPLE_TABLES RCU_INIT_POINTER(net->ipv4.fib_main, NULL); RCU_INIT_POINTER(net->ipv4.fib_default, NULL); @@ -1572,7 +1622,7 @@ static void ip_fib_net_exit(struct net *net) #ifdef CONFIG_IP_MULTIPLE_TABLES fib4_rules_exit(net); #endif - rtnl_unlock(); + kfree(net->ipv4.fib_table_hash); fib4_notifier_exit(net); } @@ -1587,9 +1637,15 @@ static int __net_init fib_net_init(struct net *net) error = ip_fib_net_init(net); if (error < 0) goto out; + + error = fib4_semantics_init(net); + if (error) + goto out_semantics; + error = nl_fib_lookup_init(net); if (error < 0) goto out_nlfl; + error = fib_proc_init(net); if (error < 0) goto out_proc; @@ -1599,7 +1655,11 @@ out: out_proc: nl_fib_lookup_exit(net); out_nlfl: + fib4_semantics_exit(net); +out_semantics: + rtnl_net_lock(net); ip_fib_net_exit(net); + rtnl_net_unlock(net); goto out; } @@ -1607,12 +1667,37 @@ static void __net_exit fib_net_exit(struct net *net) { fib_proc_exit(net); nl_fib_lookup_exit(net); - ip_fib_net_exit(net); +} + +static void __net_exit fib_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + __rtnl_net_lock(net); + ip_fib_net_exit(net); + __rtnl_net_unlock(net); + } + rtnl_unlock(); + + list_for_each_entry(net, net_list, exit_list) + fib4_semantics_exit(net); } static struct pernet_operations fib_net_ops = { .init = fib_net_init, .exit = fib_net_exit, + .exit_batch = fib_net_exit_batch, +}; + +static const struct rtnl_msg_handler fib_rtnl_msg_handlers[] __initconst = { + {.protocol = PF_INET, .msgtype = RTM_NEWROUTE, + .doit = inet_rtm_newroute, .flags = RTNL_FLAG_DOIT_PERNET}, + {.protocol = PF_INET, .msgtype = RTM_DELROUTE, + .doit = inet_rtm_delroute, .flags = RTNL_FLAG_DOIT_PERNET}, + {.protocol = PF_INET, .msgtype = RTM_GETROUTE, .dumpit = inet_dump_fib, + .flags = RTNL_FLAG_DUMP_UNLOCKED | RTNL_FLAG_DUMP_SPLIT_NLM_DONE}, }; void __init ip_fib_init(void) @@ -1624,7 +1709,5 @@ void __init ip_fib_init(void) register_netdevice_notifier(&fib_netdev_notifier); register_inetaddr_notifier(&fib_inetaddr_notifier); - rtnl_register(PF_INET, RTM_NEWROUTE, inet_rtm_newroute, NULL, 0); - rtnl_register(PF_INET, RTM_DELROUTE, inet_rtm_delroute, NULL, 0); - rtnl_register(PF_INET, RTM_GETROUTE, NULL, inet_dump_fib, 0); + rtnl_register_many(fib_rtnl_msg_handlers); } diff --git a/net/ipv4/fib_lookup.h b/net/ipv4/fib_lookup.h index e184bcb19943..f9b9e26c32c1 100644 --- a/net/ipv4/fib_lookup.h +++ b/net/ipv4/fib_lookup.h @@ -4,22 +4,22 @@ #include <linux/types.h> #include <linux/list.h> +#include <net/inet_dscp.h> #include <net/ip_fib.h> #include <net/nexthop.h> struct fib_alias { struct hlist_node fa_list; struct fib_info *fa_info; - u8 fa_tos; + dscp_t fa_dscp; u8 fa_type; u8 fa_state; u8 fa_slen; u32 tb_id; s16 fa_default; - u8 offload:1, - trap:1, - offload_failed:1, - unused:5; + u8 offload; + u8 trap; + u8 offload_failed; struct rcu_head rcu; }; diff --git a/net/ipv4/fib_notifier.c b/net/ipv4/fib_notifier.c index 0e23ade74493..b1551c26554b 100644 --- a/net/ipv4/fib_notifier.c +++ b/net/ipv4/fib_notifier.c @@ -22,15 +22,15 @@ int call_fib4_notifiers(struct net *net, enum fib_event_type event_type, ASSERT_RTNL(); info->family = AF_INET; - net->ipv4.fib_seq++; + /* Paired with READ_ONCE() in fib4_seq_read() */ + WRITE_ONCE(net->ipv4.fib_seq, net->ipv4.fib_seq + 1); return call_fib_notifiers(net, event_type, info); } -static unsigned int fib4_seq_read(struct net *net) +static unsigned int fib4_seq_read(const struct net *net) { - ASSERT_RTNL(); - - return net->ipv4.fib_seq + fib4_rules_seq_read(net); + /* Paired with WRITE_ONCE() in call_fib4_notifiers() */ + return READ_ONCE(net->ipv4.fib_seq) + fib4_rules_seq_read(net); } static int fib4_dump(struct net *net, struct notifier_block *nb, diff --git a/net/ipv4/fib_rules.c b/net/ipv4/fib_rules.c index e0b6c8b6de57..51f0193092f0 100644 --- a/net/ipv4/fib_rules.c +++ b/net/ipv4/fib_rules.c @@ -23,6 +23,8 @@ #include <linux/list.h> #include <linux/rcupdate.h> #include <linux/export.h> +#include <net/flow.h> +#include <net/inet_dscp.h> #include <net/ip.h> #include <net/route.h> #include <net/tcp.h> @@ -35,7 +37,9 @@ struct fib4_rule { struct fib_rule common; u8 dst_len; u8 src_len; - u8 tos; + dscp_t dscp; + dscp_t dscp_mask; + u8 dscp_full:1; /* DSCP or TOS selector */ __be32 src; __be32 srcmask; __be32 dst; @@ -49,7 +53,7 @@ static bool fib4_rule_matchall(const struct fib_rule *rule) { struct fib4_rule *r = container_of(rule, struct fib4_rule, common); - if (r->dst_len || r->src_len || r->tos) + if (r->dst_len || r->src_len || r->dscp) return false; return fib_rule_matchall(rule); } @@ -72,7 +76,7 @@ int fib4_rules_dump(struct net *net, struct notifier_block *nb, return fib_rules_dump(net, nb, AF_INET, extack); } -unsigned int fib4_rules_seq_read(struct net *net) +unsigned int fib4_rules_seq_read(const struct net *net) { return fib_rules_seq_read(net, AF_INET); } @@ -144,7 +148,7 @@ INDIRECT_CALLABLE_SCOPE bool fib4_rule_suppress(struct fib_rule *rule, int flags, struct fib_lookup_arg *arg) { - struct fib_result *result = (struct fib_result *) arg->result; + struct fib_result *result = arg->result; struct net_device *dev = NULL; if (result->fi) { @@ -185,18 +189,26 @@ INDIRECT_CALLABLE_SCOPE int fib4_rule_match(struct fib_rule *rule, ((daddr ^ r->dst) & r->dstmask)) return 0; - if (r->tos && (r->tos != fl4->flowi4_tos)) + /* When DSCP selector is used we need to match on the entire DSCP field + * in the flow information structure. When TOS selector is used we need + * to mask the upper three DSCP bits prior to matching to maintain + * legacy behavior. + */ + if (r->dscp_full && (r->dscp ^ fl4->flowi4_dscp) & r->dscp_mask) + return 0; + else if (!r->dscp_full && r->dscp && + !fib_dscp_masked_match(r->dscp, fl4)) return 0; if (rule->ip_proto && (rule->ip_proto != fl4->flowi4_proto)) return 0; - if (fib_rule_port_range_set(&rule->sport_range) && - !fib_rule_port_inrange(&rule->sport_range, fl4->fl4_sport)) + if (!fib_rule_port_match(&rule->sport_range, rule->sport_mask, + fl4->fl4_sport)) return 0; - if (fib_rule_port_range_set(&rule->dport_range) && - !fib_rule_port_inrange(&rule->dport_range, fl4->fl4_dport)) + if (!fib_rule_port_match(&rule->dport_range, rule->dport_mask, + fl4->fl4_dport)) return 0; return 1; @@ -216,19 +228,78 @@ static struct fib_table *fib_empty_table(struct net *net) return NULL; } +static int fib4_nl2rule_dscp(const struct nlattr *nla, struct fib4_rule *rule4, + struct netlink_ext_ack *extack) +{ + if (rule4->dscp) { + NL_SET_ERR_MSG(extack, "Cannot specify both TOS and DSCP"); + return -EINVAL; + } + + rule4->dscp = inet_dsfield_to_dscp(nla_get_u8(nla) << 2); + rule4->dscp_mask = inet_dsfield_to_dscp(INET_DSCP_MASK); + rule4->dscp_full = true; + + return 0; +} + +static int fib4_nl2rule_dscp_mask(const struct nlattr *nla, + struct fib4_rule *rule4, + struct netlink_ext_ack *extack) +{ + dscp_t dscp_mask; + + if (!rule4->dscp_full) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "Cannot specify DSCP mask without DSCP value"); + return -EINVAL; + } + + dscp_mask = inet_dsfield_to_dscp(nla_get_u8(nla) << 2); + if (rule4->dscp & ~dscp_mask) { + NL_SET_ERR_MSG_ATTR(extack, nla, "Invalid DSCP mask"); + return -EINVAL; + } + + rule4->dscp_mask = dscp_mask; + + return 0; +} + static int fib4_rule_configure(struct fib_rule *rule, struct sk_buff *skb, struct fib_rule_hdr *frh, struct nlattr **tb, struct netlink_ext_ack *extack) { - struct net *net = sock_net(skb->sk); + struct fib4_rule *rule4 = (struct fib4_rule *)rule; + struct net *net = rule->fr_net; int err = -EINVAL; - struct fib4_rule *rule4 = (struct fib4_rule *) rule; + if (tb[FRA_FLOWLABEL] || tb[FRA_FLOWLABEL_MASK]) { + NL_SET_ERR_MSG(extack, + "Flow label cannot be specified for IPv4 FIB rules"); + goto errout; + } + + if (!inet_validate_dscp(frh->tos)) { + NL_SET_ERR_MSG(extack, + "Invalid dsfield (tos): ECN bits must be 0"); + goto errout; + } + /* IPv4 currently doesn't handle high order DSCP bits correctly */ if (frh->tos & ~IPTOS_TOS_MASK) { NL_SET_ERR_MSG(extack, "Invalid tos"); goto errout; } + rule4->dscp = inet_dsfield_to_dscp(frh->tos); + + if (tb[FRA_DSCP] && + fib4_nl2rule_dscp(tb[FRA_DSCP], rule4, extack) < 0) + goto errout; + + if (tb[FRA_DSCP_MASK] && + fib4_nl2rule_dscp_mask(tb[FRA_DSCP_MASK], rule4, extack) < 0) + goto errout; /* split local/main if they are not already split */ err = fib_unmerge(net); @@ -270,7 +341,6 @@ static int fib4_rule_configure(struct fib_rule *rule, struct sk_buff *skb, rule4->srcmask = inet_make_mask(rule4->src_len); rule4->dst_len = frh->dst_len; rule4->dstmask = inet_make_mask(rule4->dst_len); - rule4->tos = frh->tos; net->ipv4.fib_has_custom_rules = true; @@ -313,9 +383,27 @@ static int fib4_rule_compare(struct fib_rule *rule, struct fib_rule_hdr *frh, if (frh->dst_len && (rule4->dst_len != frh->dst_len)) return 0; - if (frh->tos && (rule4->tos != frh->tos)) + if (frh->tos && + (rule4->dscp_full || + inet_dscp_to_dsfield(rule4->dscp) != frh->tos)) return 0; + if (tb[FRA_DSCP]) { + dscp_t dscp; + + dscp = inet_dsfield_to_dscp(nla_get_u8(tb[FRA_DSCP]) << 2); + if (!rule4->dscp_full || rule4->dscp != dscp) + return 0; + } + + if (tb[FRA_DSCP_MASK]) { + dscp_t dscp_mask; + + dscp_mask = inet_dsfield_to_dscp(nla_get_u8(tb[FRA_DSCP_MASK]) << 2); + if (!rule4->dscp_full || rule4->dscp_mask != dscp_mask) + return 0; + } + #ifdef CONFIG_IP_ROUTE_CLASSID if (tb[FRA_FLOW] && (rule4->tclassid != nla_get_u32(tb[FRA_FLOW]))) return 0; @@ -337,7 +425,17 @@ static int fib4_rule_fill(struct fib_rule *rule, struct sk_buff *skb, frh->dst_len = rule4->dst_len; frh->src_len = rule4->src_len; - frh->tos = rule4->tos; + + if (rule4->dscp_full) { + frh->tos = 0; + if (nla_put_u8(skb, FRA_DSCP, + inet_dscp_to_dsfield(rule4->dscp) >> 2) || + nla_put_u8(skb, FRA_DSCP_MASK, + inet_dscp_to_dsfield(rule4->dscp_mask) >> 2)) + goto nla_put_failure; + } else { + frh->tos = inet_dscp_to_dsfield(rule4->dscp); + } if ((rule4->dst_len && nla_put_in_addr(skb, FRA_DST, rule4->dst)) || @@ -359,7 +457,9 @@ static size_t fib4_rule_nlmsg_payload(struct fib_rule *rule) { return nla_total_size(4) /* dst */ + nla_total_size(4) /* src */ - + nla_total_size(4); /* flow */ + + nla_total_size(4) /* flow */ + + nla_total_size(1) /* dscp */ + + nla_total_size(1); /* dscp mask */ } static void fib4_rule_flush_cache(struct fib_rules_ops *ops) @@ -388,13 +488,13 @@ static int fib_default_rules_init(struct fib_rules_ops *ops) { int err; - err = fib_default_rule_add(ops, 0, RT_TABLE_LOCAL, 0); + err = fib_default_rule_add(ops, 0, RT_TABLE_LOCAL); if (err < 0) return err; - err = fib_default_rule_add(ops, 0x7FFE, RT_TABLE_MAIN, 0); + err = fib_default_rule_add(ops, 0x7FFE, RT_TABLE_MAIN); if (err < 0) return err; - err = fib_default_rule_add(ops, 0x7FFF, RT_TABLE_DEFAULT, 0); + err = fib_default_rule_add(ops, 0x7FFF, RT_TABLE_DEFAULT); if (err < 0) return err; return 0; diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c index b4589861b84c..a5f3c8459758 100644 --- a/net/ipv4/fib_semantics.c +++ b/net/ipv4/fib_semantics.c @@ -30,8 +30,10 @@ #include <linux/slab.h> #include <linux/netlink.h> #include <linux/hash.h> +#include <linux/nospec.h> #include <net/arp.h> +#include <net/inet_dscp.h> #include <net/ip.h> #include <net/protocol.h> #include <net/route.h> @@ -48,17 +50,6 @@ #include "fib_lookup.h" -static DEFINE_SPINLOCK(fib_info_lock); -static struct hlist_head *fib_info_hash; -static struct hlist_head *fib_info_laddrhash; -static unsigned int fib_info_hash_size; -static unsigned int fib_info_hash_bits; -static unsigned int fib_info_cnt; - -#define DEVINDEX_HASHBITS 8 -#define DEVINDEX_HASHSIZE (1U << DEVINDEX_HASHBITS) -static struct hlist_head fib_info_devhash[DEVINDEX_HASHSIZE]; - /* for_nexthops and change_nexthops only used when nexthop object * is not set in a fib_info. The logic within can reference fib_nh. */ @@ -210,7 +201,7 @@ static void rt_fibinfo_free_cpus(struct rtable __rcu * __percpu *rtp) void fib_nh_common_release(struct fib_nh_common *nhc) { - dev_put_track(nhc->nhc_dev, &nhc->nhc_dev_tracker); + netdev_put(nhc->nhc_dev, &nhc->nhc_dev_tracker); lwtstate_put(nhc->nhc_lwtstate); rt_fibinfo_free_cpus(nhc->nhc_pcpu_rth_output); rt_fibinfo_free(&nhc->nhc_rth_input); @@ -252,18 +243,16 @@ void free_fib_info(struct fib_info *fi) return; } - call_rcu(&fi->rcu, free_fib_info_rcu); + call_rcu_hurry(&fi->rcu, free_fib_info_rcu); } EXPORT_SYMBOL_GPL(free_fib_info); void fib_release_info(struct fib_info *fi) { - spin_lock_bh(&fib_info_lock); + ASSERT_RTNL(); if (fi && refcount_dec_and_test(&fi->fib_treeref)) { hlist_del(&fi->fib_hash); - - /* Paired with READ_ONCE() in fib_create_info(). */ - WRITE_ONCE(fib_info_cnt, fib_info_cnt - 1); + fi->fib_net->ipv4.fib_info_cnt--; if (fi->fib_prefsrc) hlist_del(&fi->fib_lhash); @@ -273,13 +262,13 @@ void fib_release_info(struct fib_info *fi) change_nexthops(fi) { if (!nexthop_nh->fib_nh_dev) continue; - hlist_del(&nexthop_nh->nh_hash); + hlist_del_rcu(&nexthop_nh->nh_hash); } endfor_nexthops(fi) } - fi->fib_dead = 1; + /* Paired with READ_ONCE() from fib_table_lookup() */ + WRITE_ONCE(fi->fib_dead, 1); fib_info_put(fi); } - spin_unlock_bh(&fib_info_lock); } static inline int nh_comp(struct fib_info *fi, struct fib_info *ofi) @@ -319,17 +308,9 @@ static inline int nh_comp(struct fib_info *fi, struct fib_info *ofi) return 0; } -static inline unsigned int fib_devindex_hashfn(unsigned int val) +static struct hlist_head *fib_nh_head(struct net_device *dev) { - return hash_32(val, DEVINDEX_HASHBITS); -} - -static struct hlist_head * -fib_info_devhash_bucket(const struct net_device *dev) -{ - u32 val = net_hash_mix(dev_net(dev)) ^ dev->ifindex; - - return &fib_info_devhash[fib_devindex_hashfn(val)]; + return &dev->fib_nh_head; } static unsigned int fib_info_hashfn_1(int init_val, u8 protocol, u8 scope, @@ -344,15 +325,15 @@ static unsigned int fib_info_hashfn_1(int init_val, u8 protocol, u8 scope, return val; } -static unsigned int fib_info_hashfn_result(unsigned int val) +static unsigned int fib_info_hashfn_result(const struct net *net, + unsigned int val) { - unsigned int mask = (fib_info_hash_size - 1); - - return (val ^ (val >> 7) ^ (val >> 12)) & mask; + return hash_32(val ^ net_hash_mix(net), net->ipv4.fib_info_hash_bits); } -static inline unsigned int fib_info_hashfn(struct fib_info *fi) +static struct hlist_head *fib_info_hash_bucket(struct fib_info *fi) { + struct net *net = fi->fib_net; unsigned int val; val = fib_info_hashfn_1(fi->fib_nhs, fi->fib_protocol, @@ -360,14 +341,77 @@ static inline unsigned int fib_info_hashfn(struct fib_info *fi) fi->fib_priority); if (fi->nh) { - val ^= fib_devindex_hashfn(fi->nh->id); + val ^= fi->nh->id; } else { for_nexthops(fi) { - val ^= fib_devindex_hashfn(nh->fib_nh_oif); + val ^= nh->fib_nh_oif; } endfor_nexthops(fi) } - return fib_info_hashfn_result(val); + return &net->ipv4.fib_info_hash[fib_info_hashfn_result(net, val)]; +} + +static struct hlist_head *fib_info_laddrhash_bucket(const struct net *net, + __be32 val) +{ + unsigned int hash_bits = net->ipv4.fib_info_hash_bits; + u32 slot; + + slot = hash_32(net_hash_mix(net) ^ (__force u32)val, hash_bits); + + return &net->ipv4.fib_info_hash[(1 << hash_bits) + slot]; +} + +static struct hlist_head *fib_info_hash_alloc(unsigned int hash_bits) +{ + /* The second half is used for prefsrc */ + return kvcalloc((1 << hash_bits) * 2, sizeof(struct hlist_head), + GFP_KERNEL); +} + +static void fib_info_hash_free(struct hlist_head *head) +{ + kvfree(head); +} + +static void fib_info_hash_grow(struct net *net) +{ + unsigned int old_size = 1 << net->ipv4.fib_info_hash_bits; + struct hlist_head *new_info_hash, *old_info_hash; + unsigned int i; + + if (net->ipv4.fib_info_cnt < old_size) + return; + + new_info_hash = fib_info_hash_alloc(net->ipv4.fib_info_hash_bits + 1); + if (!new_info_hash) + return; + + old_info_hash = net->ipv4.fib_info_hash; + net->ipv4.fib_info_hash = new_info_hash; + net->ipv4.fib_info_hash_bits += 1; + + for (i = 0; i < old_size; i++) { + struct hlist_head *head = &old_info_hash[i]; + struct hlist_node *n; + struct fib_info *fi; + + hlist_for_each_entry_safe(fi, n, head, fib_hash) + hlist_add_head(&fi->fib_hash, fib_info_hash_bucket(fi)); + } + + for (i = 0; i < old_size; i++) { + struct hlist_head *lhead = &old_info_hash[old_size + i]; + struct hlist_node *n; + struct fib_info *fi; + + hlist_for_each_entry_safe(fi, n, lhead, fib_lhash) + hlist_add_head(&fi->fib_lhash, + fib_info_laddrhash_bucket(fi->fib_net, + fi->fib_prefsrc)); + } + + fib_info_hash_free(old_info_hash); } /* no metrics, only nexthop id */ @@ -378,18 +422,17 @@ static struct fib_info *fib_find_info_nh(struct net *net, struct fib_info *fi; unsigned int hash; - hash = fib_info_hashfn_1(fib_devindex_hashfn(cfg->fc_nh_id), + hash = fib_info_hashfn_1(cfg->fc_nh_id, cfg->fc_protocol, cfg->fc_scope, (__force u32)cfg->fc_prefsrc, cfg->fc_priority); - hash = fib_info_hashfn_result(hash); - head = &fib_info_hash[hash]; + hash = fib_info_hashfn_result(net, hash); + head = &net->ipv4.fib_info_hash[hash]; hlist_for_each_entry(fi, head, fib_hash) { - if (!net_eq(fi->fib_net, net)) - continue; if (!fi->nh || fi->nh->id != cfg->fc_nh_id) continue; + if (cfg->fc_protocol == fi->fib_protocol && cfg->fc_scope == fi->fib_scope && cfg->fc_prefsrc == fi->fib_prefsrc && @@ -405,23 +448,19 @@ static struct fib_info *fib_find_info_nh(struct net *net, static struct fib_info *fib_find_info(struct fib_info *nfi) { - struct hlist_head *head; + struct hlist_head *head = fib_info_hash_bucket(nfi); struct fib_info *fi; - unsigned int hash; - - hash = fib_info_hashfn(nfi); - head = &fib_info_hash[hash]; hlist_for_each_entry(fi, head, fib_hash) { - if (!net_eq(fi->fib_net, nfi->fib_net)) - continue; if (fi->fib_nhs != nfi->fib_nhs) continue; + if (nfi->fib_protocol == fi->fib_protocol && nfi->fib_scope == fi->fib_scope && nfi->fib_prefsrc == fi->fib_prefsrc && nfi->fib_priority == fi->fib_priority && nfi->fib_type == fi->fib_type && + nfi->fib_tb_id == fi->fib_tb_id && memcmp(nfi->fib_metrics, fi->fib_metrics, sizeof(u32) * RTAX_MAX) == 0 && !((nfi->fib_flags ^ fi->fib_flags) & ~RTNH_COMPARE_MASK) && @@ -433,28 +472,23 @@ static struct fib_info *fib_find_info(struct fib_info *nfi) } /* Check, that the gateway is already configured. - * Used only by redirect accept routine. + * Used only by redirect accept routine, under rcu_read_lock(); */ int ip_fib_check_default(__be32 gw, struct net_device *dev) { struct hlist_head *head; struct fib_nh *nh; - spin_lock(&fib_info_lock); + head = fib_nh_head(dev); - head = fib_info_devhash_bucket(dev); - - hlist_for_each_entry(nh, head, nh_hash) { - if (nh->fib_nh_dev == dev && - nh->fib_nh_gw4 == gw && + hlist_for_each_entry_rcu(nh, head, nh_hash) { + DEBUG_NET_WARN_ON_ONCE(nh->fib_nh_dev != dev); + if (nh->fib_nh_gw4 == gw && !(nh->fib_nh_flags & RTNH_F_DEAD)) { - spin_unlock(&fib_info_lock); return 0; } } - spin_unlock(&fib_info_lock); - return -1; } @@ -523,11 +557,11 @@ void rtmsg_fib(int event, __be32 key, struct fib_alias *fa, fri.tb_id = tb_id; fri.dst = key; fri.dst_len = dst_len; - fri.tos = fa->fa_tos; + fri.dscp = fa->fa_dscp; fri.type = fa->fa_type; - fri.offload = fa->offload; - fri.trap = fa->trap; - fri.offload_failed = fa->offload_failed; + fri.offload = READ_ONCE(fa->offload); + fri.trap = READ_ONCE(fa->trap); + fri.offload_failed = READ_ONCE(fa->offload_failed); err = fib_dump_info(skb, info->portid, seq, event, &fri, nlm_flags); if (err < 0) { /* -EMSGSIZE implies BUG in fib_nlmsg_size() */ @@ -539,8 +573,7 @@ void rtmsg_fib(int event, __be32 key, struct fib_alias *fa, info->nlh, GFP_KERNEL); return; errout: - if (err < 0) - rtnl_set_sk_err(info->nl_net, RTNLGRP_IPV4_ROUTE, err); + rtnl_set_sk_err(info->nl_net, RTNLGRP_IPV4_ROUTE, err); } static int fib_detect_death(struct fib_info *fi, int order, @@ -560,7 +593,7 @@ static int fib_detect_death(struct fib_info *fi, int order, n = NULL; if (n) { - state = n->nud_state; + state = READ_ONCE(n->nud_state); neigh_release(n); } else { return 0; @@ -592,11 +625,6 @@ int fib_nh_common_init(struct net *net, struct fib_nh_common *nhc, if (encap) { struct lwtunnel_state *lwtstate; - if (encap_type == LWTUNNEL_ENCAP_NONE) { - NL_SET_ERR_MSG(extack, "LWT encap type not specified"); - err = -EINVAL; - goto lwt_failure; - } err = lwtunnel_build_state(net, encap_type, encap, nhc->nhc_family, cfg, &lwtstate, extack); @@ -887,9 +915,16 @@ int fib_nh_match(struct net *net, struct fib_config *cfg, struct fib_info *fi, return 1; } + if (fi->nh) { + if (cfg->fc_oif || cfg->fc_gw_family || cfg->fc_mp) + return 1; + return 0; + } + if (cfg->fc_oif || cfg->fc_gw_family) { - struct fib_nh *nh = fib_info_nh(fi, 0); + struct fib_nh *nh; + nh = fib_info_nh(fi, 0); if (cfg->fc_encap) { if (fib_encap_match(net, cfg->fc_encap_type, cfg->fc_encap, nh, cfg, extack)) @@ -1013,12 +1048,13 @@ bool fib_metrics_match(struct fib_config *cfg, struct fib_info *fi) if (type > RTAX_MAX) return false; + type = array_index_nospec(type, RTAX_MAX + 1); if (type == RTAX_CC_ALGO) { char tmp[TCP_CA_NAME_MAX]; bool ecn_ca = false; nla_strscpy(tmp, nla, sizeof(tmp)); - val = tcp_ca_get_key_by_name(fi->fib_net, tmp, &ecn_ca); + val = tcp_ca_get_key_by_name(tmp, &ecn_ca); } else { if (nla_len(nla) != sizeof(u32)) return false; @@ -1051,7 +1087,8 @@ static int fib_check_nh_v6_gw(struct net *net, struct fib_nh *nh, err = ipv6_stub->fib6_nh_init(net, &fib6_nh, &cfg, GFP_KERNEL, extack); if (!err) { nh->fib_nh_dev = fib6_nh.fib_nh_dev; - dev_hold_track(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, GFP_KERNEL); + netdev_hold(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, + GFP_KERNEL); nh->fib_nh_oif = nh->fib_nh_dev->ifindex; nh->fib_nh_scope = RT_SCOPE_LINK; @@ -1135,7 +1172,7 @@ static int fib_check_nh_v4_gw(struct net *net, struct fib_nh *nh, u32 table, if (!netif_carrier_ok(dev)) nh->fib_nh_flags |= RTNH_F_LINKDOWN; nh->fib_nh_dev = dev; - dev_hold_track(dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); + netdev_hold(dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); nh->fib_nh_scope = RT_SCOPE_LINK; return 0; } @@ -1189,7 +1226,7 @@ static int fib_check_nh_v4_gw(struct net *net, struct fib_nh *nh, u32 table, "No egress device for nexthop gateway"); goto out; } - dev_hold_track(dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); + netdev_hold(dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); if (!netif_carrier_ok(dev)) nh->fib_nh_flags |= RTNH_F_LINKDOWN; err = (dev->flags & IFF_UP) ? 0 : -ENETDOWN; @@ -1223,7 +1260,7 @@ static int fib_check_nh_nongw(struct net *net, struct fib_nh *nh, } nh->fib_nh_dev = in_dev->dev; - dev_hold_track(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); + netdev_hold(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); nh->fib_nh_scope = RT_SCOPE_HOST; if (!netif_carrier_ok(nh->fib_nh_dev)) nh->fib_nh_flags |= RTNH_F_LINKDOWN; @@ -1248,101 +1285,22 @@ int fib_check_nh(struct net *net, struct fib_nh *nh, u32 table, u8 scope, return err; } -static struct hlist_head * -fib_info_laddrhash_bucket(const struct net *net, __be32 val) -{ - u32 slot = hash_32(net_hash_mix(net) ^ (__force u32)val, - fib_info_hash_bits); - - return &fib_info_laddrhash[slot]; -} - -static struct hlist_head *fib_info_hash_alloc(int bytes) -{ - if (bytes <= PAGE_SIZE) - return kzalloc(bytes, GFP_KERNEL); - else - return (struct hlist_head *) - __get_free_pages(GFP_KERNEL | __GFP_ZERO, - get_order(bytes)); -} - -static void fib_info_hash_free(struct hlist_head *hash, int bytes) -{ - if (!hash) - return; - - if (bytes <= PAGE_SIZE) - kfree(hash); - else - free_pages((unsigned long) hash, get_order(bytes)); -} - -static void fib_info_hash_move(struct hlist_head *new_info_hash, - struct hlist_head *new_laddrhash, - unsigned int new_size) -{ - struct hlist_head *old_info_hash, *old_laddrhash; - unsigned int old_size = fib_info_hash_size; - unsigned int i, bytes; - - spin_lock_bh(&fib_info_lock); - old_info_hash = fib_info_hash; - old_laddrhash = fib_info_laddrhash; - fib_info_hash_size = new_size; - fib_info_hash_bits = ilog2(new_size); - - for (i = 0; i < old_size; i++) { - struct hlist_head *head = &fib_info_hash[i]; - struct hlist_node *n; - struct fib_info *fi; - - hlist_for_each_entry_safe(fi, n, head, fib_hash) { - struct hlist_head *dest; - unsigned int new_hash; - - new_hash = fib_info_hashfn(fi); - dest = &new_info_hash[new_hash]; - hlist_add_head(&fi->fib_hash, dest); - } - } - fib_info_hash = new_info_hash; - - fib_info_laddrhash = new_laddrhash; - for (i = 0; i < old_size; i++) { - struct hlist_head *lhead = &old_laddrhash[i]; - struct hlist_node *n; - struct fib_info *fi; - - hlist_for_each_entry_safe(fi, n, lhead, fib_lhash) { - struct hlist_head *ldest; - - ldest = fib_info_laddrhash_bucket(fi->fib_net, - fi->fib_prefsrc); - hlist_add_head(&fi->fib_lhash, ldest); - } - } - - spin_unlock_bh(&fib_info_lock); - - bytes = old_size * sizeof(struct hlist_head *); - fib_info_hash_free(old_info_hash, bytes); - fib_info_hash_free(old_laddrhash, bytes); -} - __be32 fib_info_update_nhc_saddr(struct net *net, struct fib_nh_common *nhc, unsigned char scope) { struct fib_nh *nh; + __be32 saddr; if (nhc->nhc_family != AF_INET) return inet_select_addr(nhc->nhc_dev, 0, scope); nh = container_of(nhc, struct fib_nh, nh_common); - nh->nh_saddr = inet_select_addr(nh->fib_nh_dev, nh->fib_nh_gw4, scope); - nh->nh_saddr_genid = atomic_read(&net->ipv4.dev_addr_genid); + saddr = inet_select_addr(nh->fib_nh_dev, nh->fib_nh_gw4, scope); - return nh->nh_saddr; + WRITE_ONCE(nh->nh_saddr, saddr); + WRITE_ONCE(nh->nh_saddr_genid, atomic_read(&net->ipv4.dev_addr_genid)); + + return saddr; } __be32 fib_result_prefsrc(struct net *net, struct fib_result *res) @@ -1356,8 +1314,9 @@ __be32 fib_result_prefsrc(struct net *net, struct fib_result *res) struct fib_nh *nh; nh = container_of(nhc, struct fib_nh, nh_common); - if (nh->nh_saddr_genid == atomic_read(&net->ipv4.dev_addr_genid)) - return nh->nh_saddr; + if (READ_ONCE(nh->nh_saddr_genid) == + atomic_read(&net->ipv4.dev_addr_genid)) + return READ_ONCE(nh->nh_saddr); } return fib_info_update_nhc_saddr(net, nhc, res->fi->fib_scope); @@ -1397,6 +1356,7 @@ struct fib_info *fib_create_info(struct fib_config *cfg, int nhs = 1; struct net *net = cfg->fc_nlinfo.nl_net; + ASSERT_RTNL(); if (cfg->fc_type > RTN_MAX) goto err_inval; @@ -1437,35 +1397,15 @@ struct fib_info *fib_create_info(struct fib_config *cfg, } #endif - err = -ENOBUFS; - - /* Paired with WRITE_ONCE() in fib_release_info() */ - if (READ_ONCE(fib_info_cnt) >= fib_info_hash_size) { - unsigned int new_size = fib_info_hash_size << 1; - struct hlist_head *new_info_hash; - struct hlist_head *new_laddrhash; - unsigned int bytes; - - if (!new_size) - new_size = 16; - bytes = new_size * sizeof(struct hlist_head *); - new_info_hash = fib_info_hash_alloc(bytes); - new_laddrhash = fib_info_hash_alloc(bytes); - if (!new_info_hash || !new_laddrhash) { - fib_info_hash_free(new_info_hash, bytes); - fib_info_hash_free(new_laddrhash, bytes); - } else - fib_info_hash_move(new_info_hash, new_laddrhash, new_size); - - if (!fib_info_hash_size) - goto failure; - } + fib_info_hash_grow(net); fi = kzalloc(struct_size(fi, fib_nh, nhs), GFP_KERNEL); - if (!fi) + if (!fi) { + err = -ENOBUFS; goto failure; - fi->fib_metrics = ip_fib_metrics_init(fi->fib_net, cfg->fc_mx, - cfg->fc_mx_len, extack); + } + + fi->fib_metrics = ip_fib_metrics_init(cfg->fc_mx, cfg->fc_mx_len, extack); if (IS_ERR(fi->fib_metrics)) { err = PTR_ERR(fi->fib_metrics); kfree(fi); @@ -1591,6 +1531,7 @@ struct fib_info *fib_create_info(struct fib_config *cfg, link_it: ofi = fib_find_info(fi); if (ofi) { + /* fib_table_lookup() should not see @fi yet. */ fi->fib_dead = 1; free_fib_info(fi); refcount_inc(&ofi->fib_treeref); @@ -1599,10 +1540,10 @@ link_it: refcount_set(&fi->fib_treeref, 1); refcount_set(&fi->fib_clntref, 1); - spin_lock_bh(&fib_info_lock); - fib_info_cnt++; - hlist_add_head(&fi->fib_hash, - &fib_info_hash[fib_info_hashfn(fi)]); + + net->ipv4.fib_info_cnt++; + hlist_add_head(&fi->fib_hash, fib_info_hash_bucket(fi)); + if (fi->fib_prefsrc) { struct hlist_head *head; @@ -1617,11 +1558,10 @@ link_it: if (!nexthop_nh->fib_nh_dev) continue; - head = fib_info_devhash_bucket(nexthop_nh->fib_nh_dev); - hlist_add_head(&nexthop_nh->nh_hash, head); + head = fib_nh_head(nexthop_nh->fib_nh_dev); + hlist_add_head_rcu(&nexthop_nh->nh_hash, head); } endfor_nexthops(fi) } - spin_unlock_bh(&fib_info_lock); return fi; err_inval: @@ -1629,6 +1569,7 @@ err_inval: failure: if (fi) { + /* fib_table_lookup() should not see @fi yet. */ fi->fib_dead = 1; free_fib_info(fi); } @@ -1694,8 +1635,7 @@ int fib_nexthop_info(struct sk_buff *skb, const struct fib_nh_common *nhc, nla_put_u32(skb, RTA_OIF, nhc->nhc_dev->ifindex)) goto nla_put_failure; - if (nhc->nhc_lwtstate && - lwtunnel_fill_encap(skb, nhc->nhc_lwtstate, + if (lwtunnel_fill_encap(skb, nhc->nhc_lwtstate, RTA_ENCAP, RTA_ENCAP_TYPE) < 0) goto nla_put_failure; @@ -1797,7 +1737,7 @@ int fib_dump_info(struct sk_buff *skb, u32 portid, u32 seq, int event, rtm->rtm_family = AF_INET; rtm->rtm_dst_len = fri->dst_len; rtm->rtm_src_len = 0; - rtm->rtm_tos = fri->tos; + rtm->rtm_tos = inet_dscp_to_dsfield(fri->dscp); if (tb_id < 256) rtm->rtm_table = tb_id; else @@ -1827,7 +1767,7 @@ int fib_dump_info(struct sk_buff *skb, u32 portid, u32 seq, int event, goto nla_put_failure; if (nexthop_is_blackhole(fi->nh)) rtm->rtm_type = RTN_BLACKHOLE; - if (!fi->fib_net->ipv4.sysctl_nexthop_compat_mode) + if (!READ_ONCE(fi->fib_net->ipv4.sysctl_nexthop_compat_mode)) goto offload; } @@ -1884,7 +1824,7 @@ int fib_sync_down_addr(struct net_device *dev, __be32 local) struct fib_info *fi; int ret = 0; - if (!fib_info_laddrhash || local == 0) + if (!local) return 0; head = fib_info_laddrhash_bucket(net, local); @@ -1894,6 +1834,7 @@ int fib_sync_down_addr(struct net_device *dev, __be32 local) continue; if (fi->fib_prefsrc == local) { fi->fib_flags |= RTNH_F_DEAD; + fi->pfsrc_removed = true; ret++; } } @@ -1969,12 +1910,12 @@ void fib_nhc_update_mtu(struct fib_nh_common *nhc, u32 new, u32 orig) void fib_sync_mtu(struct net_device *dev, u32 orig_mtu) { - struct hlist_head *head = fib_info_devhash_bucket(dev); + struct hlist_head *head = fib_nh_head(dev); struct fib_nh *nh; hlist_for_each_entry(nh, head, nh_hash) { - if (nh->fib_nh_dev == dev) - fib_nhc_update_mtu(&nh->nh_common, dev->mtu, orig_mtu); + DEBUG_NET_WARN_ON_ONCE(nh->fib_nh_dev != dev); + fib_nhc_update_mtu(&nh->nh_common, dev->mtu, orig_mtu); } } @@ -1988,7 +1929,7 @@ void fib_sync_mtu(struct net_device *dev, u32 orig_mtu) */ int fib_sync_down_dev(struct net_device *dev, unsigned long event, bool force) { - struct hlist_head *head = fib_info_devhash_bucket(dev); + struct hlist_head *head = fib_nh_head(dev); struct fib_info *prev_fi = NULL; int scope = RT_SCOPE_NOWHERE; struct fib_nh *nh; @@ -2002,7 +1943,8 @@ int fib_sync_down_dev(struct net_device *dev, unsigned long event, bool force) int dead; BUG_ON(!fi->fib_nhs); - if (nh->fib_nh_dev != dev || fi == prev_fi) + DEBUG_NET_WARN_ON_ONCE(nh->fib_nh_dev != dev); + if (fi == prev_fi) continue; prev_fi = fi; dead = 0; @@ -2061,7 +2003,7 @@ static void fib_select_default(const struct flowi4 *flp, struct fib_result *res) int order = -1, last_idx = -1; struct fib_alias *fa, *fa1 = NULL; u32 last_prio = res->fi->fib_priority; - u8 last_tos = 0; + dscp_t last_dscp = 0; hlist_for_each_entry_rcu(fa, fa_head, fa_list) { struct fib_info *next_fi = fa->fa_info; @@ -2069,19 +2011,19 @@ static void fib_select_default(const struct flowi4 *flp, struct fib_result *res) if (fa->fa_slen != slen) continue; - if (fa->fa_tos && fa->fa_tos != flp->flowi4_tos) + if (fa->fa_dscp && !fib_dscp_masked_match(fa->fa_dscp, flp)) continue; if (fa->tb_id != tb->tb_id) continue; if (next_fi->fib_priority > last_prio && - fa->fa_tos == last_tos) { - if (last_tos) + fa->fa_dscp == last_dscp) { + if (last_dscp) continue; break; } if (next_fi->fib_flags & RTNH_F_DEAD) continue; - last_tos = fa->fa_tos; + last_dscp = fa->fa_dscp; last_prio = next_fi->fib_priority; if (next_fi->fib_scope != res->scope || @@ -2145,14 +2087,14 @@ int fib_sync_up(struct net_device *dev, unsigned char nh_flags) return 0; if (nh_flags & RTNH_F_DEAD) { - unsigned int flags = dev_get_flags(dev); + unsigned int flags = netif_get_flags(dev); if (flags & (IFF_RUNNING | IFF_LOWER_UP)) nh_flags |= RTNH_F_LINKDOWN; } prev_fi = NULL; - head = fib_info_devhash_bucket(dev); + head = fib_nh_head(dev); ret = 0; hlist_for_each_entry(nh, head, nh_hash) { @@ -2160,7 +2102,8 @@ int fib_sync_up(struct net_device *dev, unsigned char nh_flags) int alive; BUG_ON(!fi->fib_nhs); - if (nh->fib_nh_dev != dev || fi == prev_fi) + DEBUG_NET_WARN_ON_ONCE(nh->fib_nh_dev != dev); + if (fi == prev_fi) continue; prev_fi = fi; @@ -2200,7 +2143,7 @@ static bool fib_good_nh(const struct fib_nh *nh) if (nh->fib_nh_scope == RT_SCOPE_LINK) { struct neighbour *n; - rcu_read_lock_bh(); + rcu_read_lock(); if (likely(nh->fib_nh_gw_family == AF_INET)) n = __ipv4_neigh_lookup_noref(nh->fib_nh_dev, @@ -2211,42 +2154,60 @@ static bool fib_good_nh(const struct fib_nh *nh) else n = NULL; if (n) - state = n->nud_state; + state = READ_ONCE(n->nud_state); - rcu_read_unlock_bh(); + rcu_read_unlock(); } return !!(state & NUD_VALID); } -void fib_select_multipath(struct fib_result *res, int hash) +void fib_select_multipath(struct fib_result *res, int hash, + const struct flowi4 *fl4) { struct fib_info *fi = res->fi; struct net *net = fi->fib_net; - bool first = false; + bool found = false; + bool use_neigh; + __be32 saddr; if (unlikely(res->fi->nh)) { nexthop_path_fib_result(res, hash); return; } + use_neigh = READ_ONCE(net->ipv4.sysctl_fib_multipath_use_neigh); + saddr = fl4 ? fl4->saddr : 0; + change_nexthops(fi) { - if (net->ipv4.sysctl_fib_multipath_use_neigh) { - if (!fib_good_nh(nexthop_nh)) - continue; - if (!first) { - res->nh_sel = nhsel; - res->nhc = &nexthop_nh->nh_common; - first = true; - } + int nh_upper_bound; + + /* Nexthops without a carrier are assigned an upper bound of + * minus one when "ignore_routes_with_linkdown" is set. + */ + nh_upper_bound = atomic_read(&nexthop_nh->fib_nh_upper_bound); + if (nh_upper_bound == -1 || + (use_neigh && !fib_good_nh(nexthop_nh))) + continue; + + if (!found) { + res->nh_sel = nhsel; + res->nhc = &nexthop_nh->nh_common; + found = !saddr || nexthop_nh->nh_saddr == saddr; } - if (hash > atomic_read(&nexthop_nh->fib_nh_upper_bound)) + if (hash > nh_upper_bound) continue; - res->nh_sel = nhsel; - res->nhc = &nexthop_nh->nh_common; - return; + if (!saddr || nexthop_nh->nh_saddr == saddr) { + res->nh_sel = nhsel; + res->nhc = &nexthop_nh->nh_common; + return; + } + + if (found) + return; + } endfor_nexthops(fi); } #endif @@ -2254,14 +2215,14 @@ void fib_select_multipath(struct fib_result *res, int hash) void fib_select_path(struct net *net, struct fib_result *res, struct flowi4 *fl4, const struct sk_buff *skb) { - if (fl4->flowi4_oif && !(fl4->flowi4_flags & FLOWI_FLAG_SKIP_NH_OIF)) + if (fl4->flowi4_oif) goto check_saddr; #ifdef CONFIG_IP_ROUTE_MULTIPATH if (fib_info_num_path(res->fi) > 1) { int h = fib_multipath_hash(net, fl4, skb, NULL); - fib_select_multipath(res, h); + fib_select_multipath(res, h, fl4); } else #endif @@ -2271,6 +2232,34 @@ void fib_select_path(struct net *net, struct fib_result *res, fib_select_default(fl4, res); check_saddr: - if (!fl4->saddr) - fl4->saddr = fib_result_prefsrc(net, res); + if (!fl4->saddr) { + struct net_device *l3mdev; + + l3mdev = dev_get_by_index_rcu(net, fl4->flowi4_l3mdev); + + if (!l3mdev || + l3mdev_master_dev_rcu(FIB_RES_DEV(*res)) == l3mdev) + fl4->saddr = fib_result_prefsrc(net, res); + else + fl4->saddr = inet_select_addr(l3mdev, 0, RT_SCOPE_LINK); + } +} + +int __net_init fib4_semantics_init(struct net *net) +{ + unsigned int hash_bits = 4; + + net->ipv4.fib_info_hash = fib_info_hash_alloc(hash_bits); + if (!net->ipv4.fib_info_hash) + return -ENOMEM; + + net->ipv4.fib_info_hash_bits = hash_bits; + net->ipv4.fib_info_cnt = 0; + + return 0; +} + +void __net_exit fib4_semantics_exit(struct net *net) +{ + fib_info_hash_free(net->ipv4.fib_info_hash); } diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c index 8060524f4256..59a6f0a9638f 100644 --- a/net/ipv4/fib_trie.c +++ b/net/ipv4/fib_trie.c @@ -52,6 +52,7 @@ #include <linux/if_arp.h> #include <linux/proc_fs.h> #include <linux/rcupdate.h> +#include <linux/rcupdate_wait.h> #include <linux/skbuff.h> #include <linux/netlink.h> #include <linux/init.h> @@ -61,6 +62,7 @@ #include <linux/vmalloc.h> #include <linux/notifier.h> #include <net/net_namespace.h> +#include <net/inet_dscp.h> #include <net/ip.h> #include <net/protocol.h> #include <net/route.h> @@ -81,7 +83,7 @@ static int call_fib_entry_notifier(struct notifier_block *nb, .dst = dst, .dst_len = dst_len, .fi = fa->fa_info, - .tos = fa->fa_tos, + .dscp = fa->fa_dscp, .type = fa->fa_type, .tb_id = fa->tb_id, }; @@ -98,7 +100,7 @@ static int call_fib_entry_notifiers(struct net *net, .dst = dst, .dst_len = dst_len, .fi = fa->fa_info, - .tos = fa->fa_tos, + .dscp = fa->fa_dscp, .type = fa->fa_type, .tb_id = fa->tb_id, }; @@ -125,7 +127,7 @@ struct key_vector { /* This list pointer if valid if (pos | bits) == 0 (LEAF) */ struct hlist_head leaf; /* This array is valid if (pos | bits) > 0 (TNODE) */ - struct key_vector __rcu *tnode[0]; + DECLARE_FLEX_ARRAY(struct key_vector __rcu *, tnode); }; }; @@ -290,15 +292,9 @@ static const int inflate_threshold = 50; static const int halve_threshold_root = 15; static const int inflate_threshold_root = 30; -static void __alias_free_mem(struct rcu_head *head) -{ - struct fib_alias *fa = container_of(head, struct fib_alias, rcu); - kmem_cache_free(fn_alias_kmem, fa); -} - static inline void alias_free_mem_rcu(struct fib_alias *fa) { - call_rcu(&fa->rcu, __alias_free_mem); + kfree_rcu(fa, rcu); } #define TNODE_VMALLOC_MAX \ @@ -497,9 +493,9 @@ static void tnode_free(struct key_vector *tn) tn = container_of(head, struct tnode, rcu)->kv; } - if (tnode_free_size >= sysctl_fib_sync_mem) { + if (tnode_free_size >= READ_ONCE(sysctl_fib_sync_mem)) { tnode_free_size = 0; - synchronize_rcu(); + synchronize_net(); } } @@ -973,13 +969,13 @@ static struct key_vector *fib_find_node(struct trie *t, return n; } -/* Return the first fib alias matching TOS with +/* Return the first fib alias matching DSCP with * priority less than or equal to PRIO. * If 'find_first' is set, return the first matching - * fib alias, regardless of TOS and priority. + * fib alias, regardless of DSCP and priority. */ static struct fib_alias *fib_find_alias(struct hlist_head *fah, u8 slen, - u8 tos, u32 prio, u32 tb_id, + dscp_t dscp, u32 prio, u32 tb_id, bool find_first) { struct fib_alias *fa; @@ -988,6 +984,10 @@ static struct fib_alias *fib_find_alias(struct hlist_head *fah, u8 slen, return NULL; hlist_for_each_entry(fa, fah, fa_list) { + /* Avoid Sparse warning when using dscp_t in inequalities */ + u8 __fa_dscp = inet_dscp_to_dsfield(fa->fa_dscp); + u8 __dscp = inet_dscp_to_dsfield(dscp); + if (fa->fa_slen < slen) continue; if (fa->fa_slen != slen) @@ -998,9 +998,9 @@ static struct fib_alias *fib_find_alias(struct hlist_head *fah, u8 slen, break; if (find_first) return fa; - if (fa->fa_tos > tos) + if (__fa_dscp > __dscp) continue; - if (fa->fa_info->fib_priority >= prio || fa->fa_tos < tos) + if (fa->fa_info->fib_priority >= prio || __fa_dscp < __dscp) return fa; } @@ -1027,7 +1027,7 @@ fib_find_matching_alias(struct net *net, const struct fib_rt_info *fri) hlist_for_each_entry_rcu(fa, &l->leaf, fa_list) { if (fa->fa_slen == slen && fa->tb_id == fri->tb_id && - fa->fa_tos == fri->tos && fa->fa_info == fri->fi && + fa->fa_dscp == fri->dscp && fa->fa_info == fri->fi && fa->fa_type == fri->type) return fa; } @@ -1037,6 +1037,7 @@ fib_find_matching_alias(struct net *net, const struct fib_rt_info *fri) void fib_alias_hw_flags_set(struct net *net, const struct fib_rt_info *fri) { + u8 fib_notify_on_flag_change; struct fib_alias *fa_match; struct sk_buff *skb; int err; @@ -1047,21 +1048,27 @@ void fib_alias_hw_flags_set(struct net *net, const struct fib_rt_info *fri) if (!fa_match) goto out; - if (fa_match->offload == fri->offload && fa_match->trap == fri->trap && - fa_match->offload_failed == fri->offload_failed) + /* These are paired with the WRITE_ONCE() happening in this function. + * The reason is that we are only protected by RCU at this point. + */ + if (READ_ONCE(fa_match->offload) == fri->offload && + READ_ONCE(fa_match->trap) == fri->trap && + READ_ONCE(fa_match->offload_failed) == fri->offload_failed) goto out; - fa_match->offload = fri->offload; - fa_match->trap = fri->trap; + WRITE_ONCE(fa_match->offload, fri->offload); + WRITE_ONCE(fa_match->trap, fri->trap); + + fib_notify_on_flag_change = READ_ONCE(net->ipv4.sysctl_fib_notify_on_flag_change); /* 2 means send notifications only if offload_failed was changed. */ - if (net->ipv4.sysctl_fib_notify_on_flag_change == 2 && - fa_match->offload_failed == fri->offload_failed) + if (fib_notify_on_flag_change == 2 && + READ_ONCE(fa_match->offload_failed) == fri->offload_failed) goto out; - fa_match->offload_failed = fri->offload_failed; + WRITE_ONCE(fa_match->offload_failed, fri->offload_failed); - if (!net->ipv4.sysctl_fib_notify_on_flag_change) + if (!fib_notify_on_flag_change) goto out; skb = nlmsg_new(fib_nlmsg_size(fa_match->fa_info), GFP_ATOMIC); @@ -1180,22 +1187,6 @@ static int fib_insert_alias(struct trie *t, struct key_vector *tp, return 0; } -static bool fib_valid_key_len(u32 key, u8 plen, struct netlink_ext_ack *extack) -{ - if (plen > KEYLENGTH) { - NL_SET_ERR_MSG(extack, "Invalid prefix length"); - return false; - } - - if ((plen < KEYLENGTH) && (key << plen)) { - NL_SET_ERR_MSG(extack, - "Invalid prefix for given prefix length"); - return false; - } - - return true; -} - static void fib_remove_alias(struct trie *t, struct key_vector *tp, struct key_vector *l, struct fib_alias *old); @@ -1210,15 +1201,12 @@ int fib_table_insert(struct net *net, struct fib_table *tb, struct fib_info *fi; u8 plen = cfg->fc_dst_len; u8 slen = KEYLENGTH - plen; - u8 tos = cfg->fc_tos; + dscp_t dscp; u32 key; int err; key = ntohl(cfg->fc_dst); - if (!fib_valid_key_len(key, plen, extack)) - return -EINVAL; - pr_debug("Insert table=%u %08x/%d\n", tb->tb_id, key, plen); fi = fib_create_info(cfg, extack); @@ -1227,12 +1215,13 @@ int fib_table_insert(struct net *net, struct fib_table *tb, goto err; } + dscp = cfg->fc_dscp; l = fib_find_node(t, &tp, key); - fa = l ? fib_find_alias(&l->leaf, slen, tos, fi->fib_priority, + fa = l ? fib_find_alias(&l->leaf, slen, dscp, fi->fib_priority, tb->tb_id, false) : NULL; /* Now fa, if non-NULL, points to the first fib alias - * with the same keys [prefix,tos,priority], if such key already + * with the same keys [prefix,dscp,priority], if such key already * exists or to the node before which we will insert new one. * * If fa is NULL, we will need to allocate a new one and @@ -1240,7 +1229,7 @@ int fib_table_insert(struct net *net, struct fib_table *tb, * of the new alias. */ - if (fa && fa->fa_tos == tos && + if (fa && fa->fa_dscp == dscp && fa->fa_info->fib_priority == fi->fib_priority) { struct fib_alias *fa_first, *fa_match; @@ -1260,7 +1249,7 @@ int fib_table_insert(struct net *net, struct fib_table *tb, hlist_for_each_entry_from(fa, fa_list) { if ((fa->fa_slen != slen) || (fa->tb_id != tb->tb_id) || - (fa->fa_tos != tos)) + (fa->fa_dscp != dscp)) break; if (fa->fa_info->fib_priority != fi->fib_priority) break; @@ -1288,7 +1277,7 @@ int fib_table_insert(struct net *net, struct fib_table *tb, goto out; fi_drop = fa->fa_info; - new_fa->fa_tos = fa->fa_tos; + new_fa->fa_dscp = fa->fa_dscp; new_fa->fa_info = fi; new_fa->fa_type = cfg->fc_type; state = fa->fa_state; @@ -1351,7 +1340,7 @@ int fib_table_insert(struct net *net, struct fib_table *tb, goto out; new_fa->fa_info = fi; - new_fa->fa_tos = tos; + new_fa->fa_dscp = dscp; new_fa->fa_type = cfg->fc_type; new_fa->fa_state = 0; new_fa->fa_slen = slen; @@ -1368,8 +1357,10 @@ int fib_table_insert(struct net *net, struct fib_table *tb, /* The alias was already inserted, so the node must exist. */ l = l ? l : fib_find_node(t, &tp, key); - if (WARN_ON_ONCE(!l)) + if (WARN_ON_ONCE(!l)) { + err = -ENOENT; goto out_free_new_fa; + } if (fib_find_alias(&l->leaf, new_fa->fa_slen, 0, 0, tb->tb_id, true) == new_fa) { @@ -1419,11 +1410,8 @@ bool fib_lookup_good_nhc(const struct fib_nh_common *nhc, int fib_flags, !(fib_flags & FIB_LOOKUP_IGNORE_LINKSTATE)) return false; - if (!(flp->flowi4_flags & FLOWI_FLAG_SKIP_NH_OIF)) { - if (flp->flowi4_oif && - flp->flowi4_oif != nhc->nhc_oif) - return false; - } + if (flp->flowi4_oif && flp->flowi4_oif != nhc->nhc_oif) + return false; return true; } @@ -1567,9 +1555,10 @@ found: if (index >= (1ul << fa->fa_slen)) continue; } - if (fa->fa_tos && fa->fa_tos != flp->flowi4_tos) + if (fa->fa_dscp && !fib_dscp_masked_match(fa->fa_dscp, flp)) continue; - if (fi->fib_dead) + /* Paired with WRITE_ONCE() in fib_release_info() */ + if (READ_ONCE(fi->fib_dead)) continue; if (fa->fa_info->fib_scope < flp->flowi4_scope) continue; @@ -1614,6 +1603,7 @@ set_result: res->nhc = nhc; res->type = fa->fa_type; res->scope = fi->fib_scope; + res->dscp = fa->fa_dscp; res->fi = fi; res->table = tb; res->fa_head = &n->leaf; @@ -1703,23 +1693,22 @@ int fib_table_delete(struct net *net, struct fib_table *tb, struct key_vector *l, *tp; u8 plen = cfg->fc_dst_len; u8 slen = KEYLENGTH - plen; - u8 tos = cfg->fc_tos; + dscp_t dscp; u32 key; key = ntohl(cfg->fc_dst); - if (!fib_valid_key_len(key, plen, extack)) - return -EINVAL; - l = fib_find_node(t, &tp, key); if (!l) return -ESRCH; - fa = fib_find_alias(&l->leaf, slen, tos, 0, tb->tb_id, false); + dscp = cfg->fc_dscp; + fa = fib_find_alias(&l->leaf, slen, dscp, 0, tb->tb_id, false); if (!fa) return -ESRCH; - pr_debug("Deleting %08x/%d tos=%d t=%p\n", key, plen, tos, t); + pr_debug("Deleting %08x/%d dsfield=0x%02x t=%p\n", key, plen, + inet_dscp_to_dsfield(dscp), t); fa_to_delete = NULL; hlist_for_each_entry_from(fa, fa_list) { @@ -1727,7 +1716,7 @@ int fib_table_delete(struct net *net, struct fib_table *tb, if ((fa->fa_slen != slen) || (fa->tb_id != tb->tb_id) || - (fa->fa_tos != tos)) + (fa->fa_dscp != dscp)) break; if ((!cfg->fc_type || fa->fa_type == cfg->fc_type) && @@ -2011,6 +2000,7 @@ void fib_table_flush_external(struct fib_table *tb) int fib_table_flush(struct net *net, struct fib_table *tb, bool flush_all) { struct trie *t = (struct trie *)tb->tb_data; + struct nl_info info = { .nl_net = net }; struct key_vector *pn = t->kv; unsigned long cindex = 1; struct hlist_node *tmp; @@ -2073,6 +2063,9 @@ int fib_table_flush(struct net *net, struct fib_table *tb, bool flush_all) fib_notify_alias_delete(net, n->key, &n->leaf, fa, NULL); + if (fi->pfsrc_removed) + rtmsg_fib(RTM_DELROUTE, htonl(n->key), fa, + KEYLENGTH - fa->fa_slen, tb->tb_id, &info, 0); hlist_del_rcu(&fa->fa_list); fib_release_info(fa->fa_info); alias_free_mem_rcu(fa); @@ -2295,11 +2288,11 @@ static int fn_trie_dump_leaf(struct key_vector *l, struct fib_table *tb, fri.tb_id = tb->tb_id; fri.dst = xkey; fri.dst_len = KEYLENGTH - fa->fa_slen; - fri.tos = fa->fa_tos; + fri.dscp = fa->fa_dscp; fri.type = fa->fa_type; - fri.offload = fa->offload; - fri.trap = fa->trap; - fri.offload_failed = fa->offload_failed; + fri.offload = READ_ONCE(fa->offload); + fri.trap = READ_ONCE(fa->trap); + fri.offload_failed = READ_ONCE(fa->offload_failed); err = fib_dump_info(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, @@ -2347,7 +2340,7 @@ int fib_table_dump(struct fib_table *tb, struct sk_buff *skb, * and key == 0 means the dump has wrapped around and we are done. */ if (count && !key) - return skb->len; + return 0; while ((l = leaf_walk_rcu(&tp, key)) != NULL) { int err; @@ -2373,7 +2366,7 @@ int fib_table_dump(struct fib_table *tb, struct sk_buff *skb, cb->args[3] = key; cb->args[2] = count; - return skb->len; + return 0; } void __init fib_trie_init(void) @@ -2615,7 +2608,7 @@ static void fib_table_print(struct seq_file *seq, struct fib_table *tb) static int fib_triestat_seq_show(struct seq_file *seq, void *v) { - struct net *net = (struct net *)seq->private; + struct net *net = seq->private; unsigned int h; seq_printf(seq, @@ -2807,8 +2800,9 @@ static int fib_trie_seq_show(struct seq_file *seq, void *v) fa->fa_info->fib_scope), rtn_type(buf2, sizeof(buf2), fa->fa_type)); - if (fa->fa_tos) - seq_printf(seq, " tos=%d", fa->fa_tos); + if (fa->fa_dscp) + seq_printf(seq, " tos=%d", + inet_dscp_to_dsfield(fa->fa_dscp)); seq_putc(seq, '\n'); } } @@ -2983,7 +2977,7 @@ static int fib_route_seq_show(struct seq_file *seq, void *v) seq_printf(seq, "%s\t%08X\t%08X\t%04X\t%d\t%u\t" - "%d\t%08X\t%d\t%u\t%u", + "%u\t%08X\t%d\t%u\t%u", nhc->nhc_dev ? nhc->nhc_dev->name : "*", prefix, gw, flags, 0, 0, fi->fib_priority, @@ -2995,7 +2989,7 @@ static int fib_route_seq_show(struct seq_file *seq, void *v) } else { seq_printf(seq, "*\t%08X\t%08X\t%04X\t%d\t%u\t" - "%d\t%08X\t%d\t%u\t%u", + "%u\t%08X\t%d\t%u\t%u", prefix, 0, flags, 0, 0, 0, mask, 0, 0, 0); } diff --git a/net/ipv4/fou_bpf.c b/net/ipv4/fou_bpf.c new file mode 100644 index 000000000000..54984f3170a8 --- /dev/null +++ b/net/ipv4/fou_bpf.c @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Unstable Fou Helpers for TC-BPF hook + * + * These are called from SCHED_CLS BPF programs. Note that it is + * allowed to break compatibility for these functions since the interface they + * are exposed through to BPF programs is explicitly unstable. + */ + +#include <linux/bpf.h> +#include <linux/btf_ids.h> + +#include <net/dst_metadata.h> +#include <net/fou.h> + +struct bpf_fou_encap { + __be16 sport; + __be16 dport; +}; + +enum bpf_fou_encap_type { + FOU_BPF_ENCAP_FOU, + FOU_BPF_ENCAP_GUE, +}; + +__bpf_kfunc_start_defs(); + +/* bpf_skb_set_fou_encap - Set FOU encap parameters + * + * This function allows for using GUE or FOU encapsulation together with an + * ipip device in collect-metadata mode. + * + * It is meant to be used in BPF tc-hooks and after a call to the + * bpf_skb_set_tunnel_key helper, responsible for setting IP addresses. + * + * Parameters: + * @skb_ctx Pointer to ctx (__sk_buff) in TC program. Cannot be NULL + * @encap Pointer to a `struct bpf_fou_encap` storing UDP src and + * dst ports. If sport is set to 0 the kernel will auto-assign a + * port. This is similar to using `encap-sport auto`. + * Cannot be NULL + * @type Encapsulation type for the packet. Their definitions are + * specified in `enum bpf_fou_encap_type` + */ +__bpf_kfunc int bpf_skb_set_fou_encap(struct __sk_buff *skb_ctx, + struct bpf_fou_encap *encap, int type) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct ip_tunnel_info *info = skb_tunnel_info(skb); + + if (unlikely(!encap)) + return -EINVAL; + + if (unlikely(!info || !(info->mode & IP_TUNNEL_INFO_TX))) + return -EINVAL; + + switch (type) { + case FOU_BPF_ENCAP_FOU: + info->encap.type = TUNNEL_ENCAP_FOU; + break; + case FOU_BPF_ENCAP_GUE: + info->encap.type = TUNNEL_ENCAP_GUE; + break; + default: + info->encap.type = TUNNEL_ENCAP_NONE; + } + + if (test_bit(IP_TUNNEL_CSUM_BIT, info->key.tun_flags)) + info->encap.flags |= TUNNEL_ENCAP_FLAG_CSUM; + + info->encap.sport = encap->sport; + info->encap.dport = encap->dport; + + return 0; +} + +/* bpf_skb_get_fou_encap - Get FOU encap parameters + * + * This function allows for reading encap metadata from a packet received + * on an ipip device in collect-metadata mode. + * + * Parameters: + * @skb_ctx Pointer to ctx (__sk_buff) in TC program. Cannot be NULL + * @encap Pointer to a struct bpf_fou_encap storing UDP source and + * destination port. Cannot be NULL + */ +__bpf_kfunc int bpf_skb_get_fou_encap(struct __sk_buff *skb_ctx, + struct bpf_fou_encap *encap) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct ip_tunnel_info *info = skb_tunnel_info(skb); + + if (unlikely(!info)) + return -EINVAL; + + encap->sport = info->encap.sport; + encap->dport = info->encap.dport; + + return 0; +} + +__bpf_kfunc_end_defs(); + +BTF_KFUNCS_START(fou_kfunc_set) +BTF_ID_FLAGS(func, bpf_skb_set_fou_encap) +BTF_ID_FLAGS(func, bpf_skb_get_fou_encap) +BTF_KFUNCS_END(fou_kfunc_set) + +static const struct btf_kfunc_id_set fou_bpf_kfunc_set = { + .owner = THIS_MODULE, + .set = &fou_kfunc_set, +}; + +int register_fou_bpf(void) +{ + return register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, + &fou_bpf_kfunc_set); +} diff --git a/net/ipv4/fou.c b/net/ipv4/fou_core.c index 0d085cc8d96c..3970b6b7ace5 100644 --- a/net/ipv4/fou.c +++ b/net/ipv4/fou_core.c @@ -16,10 +16,11 @@ #include <net/protocol.h> #include <net/udp.h> #include <net/udp_tunnel.h> -#include <net/xfrm.h> #include <uapi/linux/fou.h> #include <uapi/linux/genetlink.h> +#include "fou_nl.h" + struct fou { struct socket *sock; u8 protocol; @@ -49,7 +50,7 @@ struct fou_net { static inline struct fou *fou_from_sock(struct sock *sk) { - return sk->sk_user_data; + return rcu_dereference_sk_user_data(sk); } static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len) @@ -227,15 +228,27 @@ drop: return 0; } +static const struct net_offload *fou_gro_ops(const struct sock *sk, + int proto) +{ + const struct net_offload __rcu **offloads; + + /* FOU doesn't allow IPv4 on IPv6 sockets. */ + offloads = sk->sk_family == AF_INET6 ? inet6_offloads : inet_offloads; + return rcu_dereference(offloads[proto]); +} + static struct sk_buff *fou_gro_receive(struct sock *sk, struct list_head *head, struct sk_buff *skb) { - const struct net_offload __rcu **offloads; - u8 proto = fou_from_sock(sk)->protocol; + struct fou *fou = fou_from_sock(sk); const struct net_offload *ops; struct sk_buff *pp = NULL; + if (!fou) + goto out; + /* We can clear the encap_mark for FOU as we are essentially doing * one of two possible things. We are either adding an L4 tunnel * header to the outer L3 tunnel header, or we are simply @@ -247,8 +260,7 @@ static struct sk_buff *fou_gro_receive(struct sock *sk, /* Flag this frame as already having an outer encap header */ NAPI_GRO_CB(skb)->is_fou = 1; - offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; - ops = rcu_dereference(offloads[proto]); + ops = fou_gro_ops(sk, fou->protocol); if (!ops || !ops->callbacks.gro_receive) goto out; @@ -261,15 +273,20 @@ out: static int fou_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff) { - const struct net_offload __rcu **offloads; - u8 proto = fou_from_sock(sk)->protocol; + struct fou *fou = fou_from_sock(sk); const struct net_offload *ops; - int err = -ENOSYS; + int err; - offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; - ops = rcu_dereference(offloads[proto]); - if (WARN_ON(!ops || !ops->callbacks.gro_complete)) + if (!fou) { + err = -ENOENT; goto out; + } + + ops = fou_gro_ops(sk, fou->protocol); + if (WARN_ON(!ops || !ops->callbacks.gro_complete)) { + err = -ENOSYS; + goto out; + } err = ops->callbacks.gro_complete(skb, nhoff); @@ -306,7 +323,6 @@ static struct sk_buff *gue_gro_receive(struct sock *sk, struct list_head *head, struct sk_buff *skb) { - const struct net_offload __rcu **offloads; const struct net_offload *ops; struct sk_buff *pp = NULL; struct sk_buff *p; @@ -321,15 +337,15 @@ static struct sk_buff *gue_gro_receive(struct sock *sk, skb_gro_remcsum_init(&grc); + if (!fou) + goto out; + off = skb_gro_offset(skb); len = off + sizeof(*guehdr); - guehdr = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, len)) { - guehdr = skb_gro_header_slow(skb, len, off); - if (unlikely(!guehdr)) - goto out; - } + guehdr = skb_gro_header(skb, len, off); + if (unlikely(!guehdr)) + goto out; switch (guehdr->version) { case 0: @@ -353,7 +369,7 @@ static struct sk_buff *gue_gro_receive(struct sock *sk, optlen = guehdr->hlen << 2; len += optlen; - if (skb_gro_header_hard(skb, len)) { + if (!skb_gro_may_pull(skb, len)) { guehdr = skb_gro_header_slow(skb, len, off); if (unlikely(!guehdr)) goto out; @@ -433,9 +449,8 @@ next_proto: /* Flag this frame as already having an outer encap header */ NAPI_GRO_CB(skb)->is_fou = 1; - offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; - ops = rcu_dereference(offloads[proto]); - if (WARN_ON_ONCE(!ops || !ops->callbacks.gro_receive)) + ops = fou_gro_ops(sk, proto); + if (!ops || !ops->callbacks.gro_receive) goto out; pp = call_gro_receive(ops->callbacks.gro_receive, head, skb); @@ -450,7 +465,6 @@ out: static int gue_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff) { struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff); - const struct net_offload __rcu **offloads; const struct net_offload *ops; unsigned int guehlen = 0; u8 proto; @@ -477,8 +491,7 @@ static int gue_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff) return err; } - offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; - ops = rcu_dereference(offloads[proto]); + ops = fou_gro_ops(sk, proto); if (WARN_ON(!ops || !ops->callbacks.gro_complete)) goto out; @@ -644,20 +657,6 @@ static int fou_destroy(struct net *net, struct fou_cfg *cfg) static struct genl_family fou_nl_family; -static const struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = { - [FOU_ATTR_PORT] = { .type = NLA_U16, }, - [FOU_ATTR_AF] = { .type = NLA_U8, }, - [FOU_ATTR_IPPROTO] = { .type = NLA_U8, }, - [FOU_ATTR_TYPE] = { .type = NLA_U8, }, - [FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, }, - [FOU_ATTR_LOCAL_V4] = { .type = NLA_U32, }, - [FOU_ATTR_PEER_V4] = { .type = NLA_U32, }, - [FOU_ATTR_LOCAL_V6] = { .len = sizeof(struct in6_addr), }, - [FOU_ATTR_PEER_V6] = { .len = sizeof(struct in6_addr), }, - [FOU_ATTR_PEER_PORT] = { .type = NLA_U16, }, - [FOU_ATTR_IFINDEX] = { .type = NLA_S32, }, -}; - static int parse_nl_config(struct genl_info *info, struct fou_cfg *cfg) { @@ -749,7 +748,7 @@ static int parse_nl_config(struct genl_info *info, return 0; } -static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info) +int fou_nl_add_doit(struct sk_buff *skb, struct genl_info *info) { struct net *net = genl_info_net(info); struct fou_cfg cfg; @@ -762,7 +761,7 @@ static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info) return fou_create(net, &cfg, NULL); } -static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info) +int fou_nl_del_doit(struct sk_buff *skb, struct genl_info *info) { struct net *net = genl_info_net(info); struct fou_cfg cfg; @@ -831,7 +830,7 @@ nla_put_failure: return -EMSGSIZE; } -static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info) +int fou_nl_get_doit(struct sk_buff *skb, struct genl_info *info) { struct net *net = genl_info_net(info); struct fou_net *fn = net_generic(net, fou_net_id); @@ -878,7 +877,7 @@ out_free: return ret; } -static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb) +int fou_nl_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb) { struct net *net = sock_net(skb->sk); struct fou_net *fn = net_generic(net, fou_net_id); @@ -901,37 +900,17 @@ static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb) return skb->len; } -static const struct genl_small_ops fou_nl_ops[] = { - { - .cmd = FOU_CMD_ADD, - .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, - .doit = fou_nl_cmd_add_port, - .flags = GENL_ADMIN_PERM, - }, - { - .cmd = FOU_CMD_DEL, - .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, - .doit = fou_nl_cmd_rm_port, - .flags = GENL_ADMIN_PERM, - }, - { - .cmd = FOU_CMD_GET, - .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, - .doit = fou_nl_cmd_get_port, - .dumpit = fou_nl_dump, - }, -}; - static struct genl_family fou_nl_family __ro_after_init = { .hdrsize = 0, .name = FOU_GENL_NAME, .version = FOU_GENL_VERSION, .maxattr = FOU_ATTR_MAX, - .policy = fou_nl_policy, + .policy = fou_nl_policy, .netnsok = true, .module = THIS_MODULE, .small_ops = fou_nl_ops, .n_small_ops = ARRAY_SIZE(fou_nl_ops), + .resv_start_op = FOU_CMD_GET + 1, }; size_t fou_encap_hlen(struct ip_tunnel_encap *e) @@ -1272,10 +1251,15 @@ static int __init fou_init(void) if (ret < 0) goto unregister; + ret = register_fou_bpf(); + if (ret < 0) + goto kfunc_failed; + ret = ip_tunnel_encap_add_fou_ops(); if (ret == 0) return 0; +kfunc_failed: genl_unregister_family(&fou_nl_family); unregister: unregister_pernet_device(&fou_net_ops); diff --git a/net/ipv4/fou_nl.c b/net/ipv4/fou_nl.c new file mode 100644 index 000000000000..7a99639204b1 --- /dev/null +++ b/net/ipv4/fou_nl.c @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/fou.yaml */ +/* YNL-GEN kernel source */ +/* To regenerate run: tools/net/ynl/ynl-regen.sh */ + +#include <net/netlink.h> +#include <net/genetlink.h> + +#include "fou_nl.h" + +#include <uapi/linux/fou.h> + +/* Global operation policy for fou */ +const struct nla_policy fou_nl_policy[FOU_ATTR_IFINDEX + 1] = { + [FOU_ATTR_PORT] = { .type = NLA_BE16, }, + [FOU_ATTR_AF] = { .type = NLA_U8, }, + [FOU_ATTR_IPPROTO] = { .type = NLA_U8, }, + [FOU_ATTR_TYPE] = { .type = NLA_U8, }, + [FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, }, + [FOU_ATTR_LOCAL_V4] = { .type = NLA_U32, }, + [FOU_ATTR_LOCAL_V6] = NLA_POLICY_EXACT_LEN(16), + [FOU_ATTR_PEER_V4] = { .type = NLA_U32, }, + [FOU_ATTR_PEER_V6] = NLA_POLICY_EXACT_LEN(16), + [FOU_ATTR_PEER_PORT] = { .type = NLA_BE16, }, + [FOU_ATTR_IFINDEX] = { .type = NLA_S32, }, +}; + +/* Ops table for fou */ +const struct genl_small_ops fou_nl_ops[3] = { + { + .cmd = FOU_CMD_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = fou_nl_add_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = FOU_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = fou_nl_del_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = FOU_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = fou_nl_get_doit, + .dumpit = fou_nl_get_dumpit, + }, +}; diff --git a/net/ipv4/fou_nl.h b/net/ipv4/fou_nl.h new file mode 100644 index 000000000000..438342dc8507 --- /dev/null +++ b/net/ipv4/fou_nl.h @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */ +/* Do not edit directly, auto-generated from: */ +/* Documentation/netlink/specs/fou.yaml */ +/* YNL-GEN kernel header */ +/* To regenerate run: tools/net/ynl/ynl-regen.sh */ + +#ifndef _LINUX_FOU_GEN_H +#define _LINUX_FOU_GEN_H + +#include <net/netlink.h> +#include <net/genetlink.h> + +#include <uapi/linux/fou.h> + +/* Global operation policy for fou */ +extern const struct nla_policy fou_nl_policy[FOU_ATTR_IFINDEX + 1]; + +/* Ops table for fou */ +extern const struct genl_small_ops fou_nl_ops[3]; + +int fou_nl_add_doit(struct sk_buff *skb, struct genl_info *info); +int fou_nl_del_doit(struct sk_buff *skb, struct genl_info *info); +int fou_nl_get_doit(struct sk_buff *skb, struct genl_info *info); +int fou_nl_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb); + +#endif /* _LINUX_FOU_GEN_H */ diff --git a/net/ipv4/gre_demux.c b/net/ipv4/gre_demux.c index cbb2b4bb0dfa..dafd68f3436a 100644 --- a/net/ipv4/gre_demux.c +++ b/net/ipv4/gre_demux.c @@ -73,7 +73,7 @@ int gre_parse_header(struct sk_buff *skb, struct tnl_ptk_info *tpi, if (unlikely(greh->flags & (GRE_VERSION | GRE_ROUTING))) return -EINVAL; - tpi->flags = gre_flags_to_tnl_flags(greh->flags); + gre_flags_to_tnl_flags(tpi->flags, greh->flags); hdr_len = gre_calc_hlen(tpi->flags); if (!pskb_may_pull(skb, nhs + hdr_len)) @@ -199,7 +199,7 @@ static const struct net_protocol net_gre_protocol = { static int __init gre_init(void) { - pr_info("GRE over IPv4 demultiplexor driver\n"); + pr_info("GRE over IPv4 demultiplexer driver\n"); if (inet_add_protocol(&net_gre_protocol, IPPROTO_GRE) < 0) { pr_err("can't add protocol\n"); @@ -217,5 +217,5 @@ module_init(gre_init); module_exit(gre_exit); MODULE_DESCRIPTION("GRE over IPv4 demultiplexer driver"); -MODULE_AUTHOR("D. Kozlov (xeb@mail.ru)"); +MODULE_AUTHOR("D. Kozlov <xeb@mail.ru>"); MODULE_LICENSE("GPL"); diff --git a/net/ipv4/gre_offload.c b/net/ipv4/gre_offload.c index 07073fa35205..5028c72d494a 100644 --- a/net/ipv4/gre_offload.c +++ b/net/ipv4/gre_offload.c @@ -11,6 +11,7 @@ #include <net/protocol.h> #include <net/gre.h> #include <net/gro.h> +#include <net/gso.h> static struct sk_buff *gre_gso_segment(struct sk_buff *skb, netdev_features_t features) @@ -137,12 +138,9 @@ static struct sk_buff *gre_gro_receive(struct list_head *head, off = skb_gro_offset(skb); hlen = off + sizeof(*greh); - greh = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, hlen)) { - greh = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!greh)) - goto out; - } + greh = skb_gro_header(skb, hlen, off); + if (unlikely(!greh)) + goto out; /* Only support version 0 and K (key), C (csum) flags. Note that * although the support for the S (seq#) flag can be added easily @@ -176,7 +174,7 @@ static struct sk_buff *gre_gro_receive(struct list_head *head, grehlen += GRE_HEADER_SECTION; hlen = off + grehlen; - if (skb_gro_header_hard(skb, hlen)) { + if (!skb_gro_may_pull(skb, hlen)) { greh = skb_gro_header_slow(skb, hlen, off); if (unlikely(!greh)) goto out; diff --git a/net/ipv4/icmp.c b/net/ipv4/icmp.c index b7e277d8a84d..4abbec2f47ef 100644 --- a/net/ipv4/icmp.c +++ b/net/ipv4/icmp.c @@ -72,6 +72,7 @@ #include <linux/string.h> #include <linux/netfilter_ipv4.h> #include <linux/slab.h> +#include <net/flow.h> #include <net/snmp.h> #include <net/ip.h> #include <net/route.h> @@ -92,6 +93,10 @@ #include <net/inet_common.h> #include <net/ip_fib.h> #include <net/l3mdev.h> +#include <net/addrconf.h> +#include <net/inet_dscp.h> +#define CREATE_TRACE_POINTS +#include <trace/events/icmp.h> /* * Build xmit assembly blocks @@ -186,30 +191,20 @@ EXPORT_SYMBOL(icmp_err_convert); */ struct icmp_control { - bool (*handler)(struct sk_buff *skb); + enum skb_drop_reason (*handler)(struct sk_buff *skb); short error; /* This ICMP is classed as an error message */ }; static const struct icmp_control icmp_pointers[NR_ICMP_TYPES+1]; -/* - * The ICMP socket(s). This is the most convenient way to flow control - * our ICMP output as well as maintain a clean interface throughout - * all layers. All Socketless IP sends will soon be gone. - * - * On SMP we have one ICMP socket per-cpu. - */ -static struct sock *icmp_sk(struct net *net) -{ - return this_cpu_read(*net->ipv4.icmp_sk); -} +static DEFINE_PER_CPU(struct sock *, ipv4_icmp_sk); /* Called with BH disabled */ static inline struct sock *icmp_xmit_lock(struct net *net) { struct sock *sk; - sk = icmp_sk(net); + sk = this_cpu_read(ipv4_icmp_sk); if (unlikely(!spin_trylock(&sk->sk_lock.slock))) { /* This can happen if the output path signals a @@ -217,68 +212,66 @@ static inline struct sock *icmp_xmit_lock(struct net *net) */ return NULL; } + sock_net_set(sk, net); return sk; } static inline void icmp_xmit_unlock(struct sock *sk) { + sock_net_set(sk, &init_net); spin_unlock(&sk->sk_lock.slock); } -int sysctl_icmp_msgs_per_sec __read_mostly = 1000; -int sysctl_icmp_msgs_burst __read_mostly = 50; - -static struct { - spinlock_t lock; - u32 credit; - u32 stamp; -} icmp_global = { - .lock = __SPIN_LOCK_UNLOCKED(icmp_global.lock), -}; - /** * icmp_global_allow - Are we allowed to send one more ICMP message ? + * @net: network namespace * * Uses a token bucket to limit our ICMP messages to ~sysctl_icmp_msgs_per_sec. * Returns false if we reached the limit and can not send another packet. - * Note: called with BH disabled + * Works in tandem with icmp_global_consume(). */ -bool icmp_global_allow(void) +bool icmp_global_allow(struct net *net) { - u32 credit, delta, incr = 0, now = (u32)jiffies; - bool rc = false; + u32 delta, now, oldstamp; + int incr, new, old; - /* Check if token bucket is empty and cannot be refilled - * without taking the spinlock. The READ_ONCE() are paired - * with the following WRITE_ONCE() in this same function. + /* Note: many cpus could find this condition true. + * Then later icmp_global_consume() could consume more credits, + * this is an acceptable race. */ - if (!READ_ONCE(icmp_global.credit)) { - delta = min_t(u32, now - READ_ONCE(icmp_global.stamp), HZ); - if (delta < HZ / 50) - return false; - } + if (atomic_read(&net->ipv4.icmp_global_credit) > 0) + return true; - spin_lock(&icmp_global.lock); - delta = min_t(u32, now - icmp_global.stamp, HZ); - if (delta >= HZ / 50) { - incr = sysctl_icmp_msgs_per_sec * delta / HZ ; - if (incr) - WRITE_ONCE(icmp_global.stamp, now); - } - credit = min_t(u32, icmp_global.credit + incr, sysctl_icmp_msgs_burst); - if (credit) { - /* We want to use a credit of one in average, but need to randomize - * it for security reasons. - */ - credit = max_t(int, credit - prandom_u32_max(3), 0); - rc = true; + now = jiffies; + oldstamp = READ_ONCE(net->ipv4.icmp_global_stamp); + delta = min_t(u32, now - oldstamp, HZ); + if (delta < HZ / 50) + return false; + + incr = READ_ONCE(net->ipv4.sysctl_icmp_msgs_per_sec) * delta / HZ; + if (!incr) + return false; + + if (cmpxchg(&net->ipv4.icmp_global_stamp, oldstamp, now) == oldstamp) { + old = atomic_read(&net->ipv4.icmp_global_credit); + do { + new = min(old + incr, READ_ONCE(net->ipv4.sysctl_icmp_msgs_burst)); + } while (!atomic_try_cmpxchg(&net->ipv4.icmp_global_credit, &old, new)); } - WRITE_ONCE(icmp_global.credit, credit); - spin_unlock(&icmp_global.lock); - return rc; + return true; } EXPORT_SYMBOL(icmp_global_allow); +void icmp_global_consume(struct net *net) +{ + int credits = get_random_u32_below(3); + + /* Note: this might make icmp_global.credit negative. */ + if (credits) + atomic_sub(credits, &net->ipv4.icmp_global_credit); +} +EXPORT_SYMBOL(icmp_global_consume); + static bool icmpv4_mask_allow(struct net *net, int type, int code) { if (type > NR_ICMP_TYPES) @@ -289,20 +282,23 @@ static bool icmpv4_mask_allow(struct net *net, int type, int code) return true; /* Limit if icmp type is enabled in ratemask. */ - if (!((1 << type) & net->ipv4.sysctl_icmp_ratemask)) + if (!((1 << type) & READ_ONCE(net->ipv4.sysctl_icmp_ratemask))) return true; return false; } -static bool icmpv4_global_allow(struct net *net, int type, int code) +static bool icmpv4_global_allow(struct net *net, int type, int code, + bool *apply_ratelimit) { if (icmpv4_mask_allow(net, type, code)) return true; - if (icmp_global_allow()) + if (icmp_global_allow(net)) { + *apply_ratelimit = true; return true; - + } + __ICMP_INC_STATS(net, ICMP_MIB_RATELIMITGLOBAL); return false; } @@ -311,26 +307,33 @@ static bool icmpv4_global_allow(struct net *net, int type, int code) */ static bool icmpv4_xrlim_allow(struct net *net, struct rtable *rt, - struct flowi4 *fl4, int type, int code) + struct flowi4 *fl4, int type, int code, + bool apply_ratelimit) { struct dst_entry *dst = &rt->dst; struct inet_peer *peer; + struct net_device *dev; bool rc = true; - int vif; - if (icmpv4_mask_allow(net, type, code)) - goto out; + if (!apply_ratelimit) + return true; /* No rate limit on loopback */ - if (dst->dev && (dst->dev->flags&IFF_LOOPBACK)) + rcu_read_lock(); + dev = dst_dev_rcu(dst); + if (dev && (dev->flags & IFF_LOOPBACK)) goto out; - vif = l3mdev_master_ifindex(dst->dev); - peer = inet_getpeer_v4(net->ipv4.peers, fl4->daddr, vif, 1); - rc = inet_peer_xrlim_allow(peer, net->ipv4.sysctl_icmp_ratelimit); - if (peer) - inet_putpeer(peer); + peer = inet_getpeer_v4(net->ipv4.peers, fl4->daddr, + l3mdev_master_ifindex_rcu(dev)); + rc = inet_peer_xrlim_allow(peer, + READ_ONCE(net->ipv4.sysctl_icmp_ratelimit)); out: + rcu_read_unlock(); + if (!rc) + __ICMP_INC_STATS(net, ICMP_MIB_RATELIMITHOST); + else + icmp_global_consume(net); return rc; } @@ -350,7 +353,7 @@ void icmp_out_count(struct net *net, unsigned char type) static int icmp_glue_bits(void *from, char *to, int offset, int len, int odd, struct sk_buff *skb) { - struct icmp_bxm *icmp_param = (struct icmp_bxm *)from; + struct icmp_bxm *icmp_param = from; __wsum csum; csum = skb_copy_and_csum_bits(icmp_param->skb, @@ -363,14 +366,13 @@ static int icmp_glue_bits(void *from, char *to, int offset, int len, int odd, return 0; } -static void icmp_push_reply(struct icmp_bxm *icmp_param, +static void icmp_push_reply(struct sock *sk, + struct icmp_bxm *icmp_param, struct flowi4 *fl4, struct ipcm_cookie *ipc, struct rtable **rt) { - struct sock *sk; struct sk_buff *skb; - sk = icmp_sk(dev_net((*rt)->dst.dev)); if (ip_append_data(sk, fl4, icmp_glue_bits, icmp_param, icmp_param->data_len+icmp_param->head_len, icmp_param->head_len, @@ -400,12 +402,12 @@ static void icmp_push_reply(struct icmp_bxm *icmp_param, static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb) { - struct ipcm_cookie ipc; struct rtable *rt = skb_rtable(skb); - struct net *net = dev_net(rt->dst.dev); + struct net *net = dev_net_rcu(rt->dst.dev); + bool apply_ratelimit = false; + struct ipcm_cookie ipc; struct flowi4 fl4; struct sock *sk; - struct inet_sock *inet; __be32 daddr, saddr; u32 mark = IP4_REPLY_MARK(net, skb->mark); int type = icmp_param->data.icmph.type; @@ -414,22 +416,21 @@ static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb) if (ip_options_echo(net, &icmp_param->replyopts.opt.opt, skb)) return; - /* Needed by both icmp_global_allow and icmp_xmit_lock */ + /* Needed by both icmpv4_global_allow and icmp_xmit_lock */ local_bh_disable(); - /* global icmp_msgs_per_sec */ - if (!icmpv4_global_allow(net, type, code)) + /* is global icmp_msgs_per_sec exhausted ? */ + if (!icmpv4_global_allow(net, type, code, &apply_ratelimit)) goto out_bh_enable; sk = icmp_xmit_lock(net); if (!sk) goto out_bh_enable; - inet = inet_sk(sk); icmp_param->data.icmph.checksum = 0; ipcm_init(&ipc); - inet->tos = ip_hdr(skb)->tos; + ipc.tos = ip_hdr(skb)->tos; ipc.sockc.mark = mark; daddr = ipc.addr = ip_hdr(skb)->saddr; saddr = fib_compute_spec_dst(skb); @@ -444,15 +445,15 @@ static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb) fl4.saddr = saddr; fl4.flowi4_mark = mark; fl4.flowi4_uid = sock_net_uid(net, NULL); - fl4.flowi4_tos = RT_TOS(ip_hdr(skb)->tos); + fl4.flowi4_dscp = ip4h_dscp(ip_hdr(skb)); fl4.flowi4_proto = IPPROTO_ICMP; fl4.flowi4_oif = l3mdev_master_ifindex(skb->dev); security_skb_classify_flow(skb, flowi4_to_flowi_common(&fl4)); rt = ip_route_output_key(net, &fl4); if (IS_ERR(rt)) goto out_unlock; - if (icmpv4_xrlim_allow(net, rt, &fl4, type, code)) - icmp_push_reply(icmp_param, &fl4, &ipc, &rt); + if (icmpv4_xrlim_allow(net, rt, &fl4, type, code, apply_ratelimit)) + icmp_push_reply(sk, icmp_param, &fl4, &ipc, &rt); ip_rt_put(rt); out_unlock: icmp_xmit_unlock(sk); @@ -468,24 +469,23 @@ out_bh_enable: */ static struct net_device *icmp_get_route_lookup_dev(struct sk_buff *skb) { - struct net_device *route_lookup_dev = NULL; + struct net_device *dev = skb->dev; + const struct dst_entry *dst; - if (skb->dev) - route_lookup_dev = skb->dev; - else if (skb_dst(skb)) - route_lookup_dev = skb_dst(skb)->dev; - return route_lookup_dev; + if (dev) + return dev; + dst = skb_dst(skb); + return dst ? dst_dev(dst) : NULL; } -static struct rtable *icmp_route_lookup(struct net *net, - struct flowi4 *fl4, +static struct rtable *icmp_route_lookup(struct net *net, struct flowi4 *fl4, struct sk_buff *skb_in, - const struct iphdr *iph, - __be32 saddr, u8 tos, u32 mark, - int type, int code, - struct icmp_bxm *param) + const struct iphdr *iph, __be32 saddr, + dscp_t dscp, u32 mark, int type, + int code, struct icmp_bxm *param) { struct net_device *route_lookup_dev; + struct dst_entry *dst, *dst2; struct rtable *rt, *rt2; struct flowi4 fl4_dec; int err; @@ -496,7 +496,7 @@ static struct rtable *icmp_route_lookup(struct net *net, fl4->saddr = saddr; fl4->flowi4_mark = mark; fl4->flowi4_uid = sock_net_uid(net, NULL); - fl4->flowi4_tos = RT_TOS(tos); + fl4->flowi4_dscp = dscp; fl4->flowi4_proto = IPPROTO_ICMP; fl4->fl4_icmp_type = type; fl4->fl4_icmp_code = code; @@ -511,17 +511,21 @@ static struct rtable *icmp_route_lookup(struct net *net, /* No need to clone since we're just using its address. */ rt2 = rt; - rt = (struct rtable *) xfrm_lookup(net, &rt->dst, - flowi4_to_flowi(fl4), NULL, 0); - if (!IS_ERR(rt)) { + dst = xfrm_lookup(net, &rt->dst, + flowi4_to_flowi(fl4), NULL, 0); + rt = dst_rtable(dst); + if (!IS_ERR(dst)) { if (rt != rt2) return rt; - } else if (PTR_ERR(rt) == -EPERM) { + if (inet_addr_type_dev_table(net, route_lookup_dev, + fl4->daddr) == RTN_LOCAL) + return rt; + } else if (PTR_ERR(dst) == -EPERM) { rt = NULL; - } else + } else { return rt; - - err = xfrm_decode_session_reverse(skb_in, flowi4_to_flowi(&fl4_dec), AF_INET); + } + err = xfrm_decode_session_reverse(net, skb_in, flowi4_to_flowi(&fl4_dec), AF_INET); if (err) goto relookup_failed; @@ -541,32 +545,33 @@ static struct rtable *icmp_route_lookup(struct net *net, goto relookup_failed; } /* Ugh! */ - orefdst = skb_in->_skb_refdst; /* save old refdst */ - skb_dst_set(skb_in, NULL); + orefdst = skb_dstref_steal(skb_in); err = ip_route_input(skb_in, fl4_dec.daddr, fl4_dec.saddr, - RT_TOS(tos), rt2->dst.dev); + dscp, rt2->dst.dev) ? -EINVAL : 0; dst_release(&rt2->dst); rt2 = skb_rtable(skb_in); - skb_in->_skb_refdst = orefdst; /* restore old refdst */ + /* steal dst entry from skb_in, don't drop refcnt */ + skb_dstref_steal(skb_in); + skb_dstref_restore(skb_in, orefdst); } if (err) goto relookup_failed; - rt2 = (struct rtable *) xfrm_lookup(net, &rt2->dst, - flowi4_to_flowi(&fl4_dec), NULL, - XFRM_LOOKUP_ICMP); - if (!IS_ERR(rt2)) { + dst2 = xfrm_lookup(net, &rt2->dst, flowi4_to_flowi(&fl4_dec), NULL, + XFRM_LOOKUP_ICMP); + rt2 = dst_rtable(dst2); + if (!IS_ERR(dst2)) { dst_release(&rt->dst); memcpy(fl4, &fl4_dec, sizeof(*fl4)); rt = rt2; - } else if (PTR_ERR(rt2) == -EPERM) { + } else if (PTR_ERR(dst2) == -EPERM) { if (rt) dst_release(&rt->dst); return rt2; } else { - err = PTR_ERR(rt2); + err = PTR_ERR(dst2); goto relookup_failed; } return rt; @@ -577,6 +582,185 @@ relookup_failed: return ERR_PTR(err); } +struct icmp_ext_iio_addr4_subobj { + __be16 afi; + __be16 reserved; + __be32 addr4; +}; + +static unsigned int icmp_ext_iio_len(void) +{ + return sizeof(struct icmp_extobj_hdr) + + /* ifIndex */ + sizeof(__be32) + + /* Interface Address Sub-Object */ + sizeof(struct icmp_ext_iio_addr4_subobj) + + /* Interface Name Sub-Object. Length must be a multiple of 4 + * bytes. + */ + ALIGN(sizeof(struct icmp_ext_iio_name_subobj), 4) + + /* MTU */ + sizeof(__be32); +} + +static unsigned int icmp_ext_max_len(u8 ext_objs) +{ + unsigned int ext_max_len; + + ext_max_len = sizeof(struct icmp_ext_hdr); + + if (ext_objs & BIT(ICMP_ERR_EXT_IIO_IIF)) + ext_max_len += icmp_ext_iio_len(); + + return ext_max_len; +} + +static __be32 icmp_ext_iio_addr4_find(const struct net_device *dev) +{ + struct in_device *in_dev; + struct in_ifaddr *ifa; + + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + return 0; + + /* It is unclear from RFC 5837 which IP address should be chosen, but + * it makes sense to choose a global unicast address. + */ + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (READ_ONCE(ifa->ifa_flags) & IFA_F_SECONDARY) + continue; + if (ifa->ifa_scope != RT_SCOPE_UNIVERSE || + ipv4_is_multicast(ifa->ifa_address)) + continue; + return ifa->ifa_address; + } + + return 0; +} + +static void icmp_ext_iio_iif_append(struct net *net, struct sk_buff *skb, + int iif) +{ + struct icmp_ext_iio_name_subobj *name_subobj; + struct icmp_extobj_hdr *objh; + struct net_device *dev; + __be32 data; + + if (!iif) + return; + + /* Add the fields in the order specified by RFC 5837. */ + objh = skb_put(skb, sizeof(*objh)); + objh->class_num = ICMP_EXT_OBJ_CLASS_IIO; + objh->class_type = ICMP_EXT_CTYPE_IIO_ROLE(ICMP_EXT_CTYPE_IIO_ROLE_IIF); + + data = htonl(iif); + skb_put_data(skb, &data, sizeof(__be32)); + objh->class_type |= ICMP_EXT_CTYPE_IIO_IFINDEX; + + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, iif); + if (!dev) + goto out; + + data = icmp_ext_iio_addr4_find(dev); + if (data) { + struct icmp_ext_iio_addr4_subobj *addr4_subobj; + + addr4_subobj = skb_put_zero(skb, sizeof(*addr4_subobj)); + addr4_subobj->afi = htons(ICMP_AFI_IP); + addr4_subobj->addr4 = data; + objh->class_type |= ICMP_EXT_CTYPE_IIO_IPADDR; + } + + name_subobj = skb_put_zero(skb, ALIGN(sizeof(*name_subobj), 4)); + name_subobj->len = ALIGN(sizeof(*name_subobj), 4); + netdev_copy_name(dev, name_subobj->name); + objh->class_type |= ICMP_EXT_CTYPE_IIO_NAME; + + data = htonl(READ_ONCE(dev->mtu)); + skb_put_data(skb, &data, sizeof(__be32)); + objh->class_type |= ICMP_EXT_CTYPE_IIO_MTU; + +out: + rcu_read_unlock(); + objh->length = htons(skb_tail_pointer(skb) - (unsigned char *)objh); +} + +static void icmp_ext_objs_append(struct net *net, struct sk_buff *skb, + u8 ext_objs, int iif) +{ + if (ext_objs & BIT(ICMP_ERR_EXT_IIO_IIF)) + icmp_ext_iio_iif_append(net, skb, iif); +} + +static struct sk_buff * +icmp_ext_append(struct net *net, struct sk_buff *skb_in, struct icmphdr *icmph, + unsigned int room, int iif) +{ + unsigned int payload_len, ext_max_len, ext_len; + struct icmp_ext_hdr *ext_hdr; + struct sk_buff *skb; + u8 ext_objs; + int nhoff; + + switch (icmph->type) { + case ICMP_DEST_UNREACH: + case ICMP_TIME_EXCEEDED: + case ICMP_PARAMETERPROB: + break; + default: + return NULL; + } + + ext_objs = READ_ONCE(net->ipv4.sysctl_icmp_errors_extension_mask); + if (!ext_objs) + return NULL; + + ext_max_len = icmp_ext_max_len(ext_objs); + if (ICMP_EXT_ORIG_DGRAM_MIN_LEN + ext_max_len > room) + return NULL; + + skb = skb_clone(skb_in, GFP_ATOMIC); + if (!skb) + return NULL; + + nhoff = skb_network_offset(skb); + payload_len = min(skb->len - nhoff, ICMP_EXT_ORIG_DGRAM_MIN_LEN); + + if (!pskb_network_may_pull(skb, payload_len)) + goto free_skb; + + if (pskb_trim(skb, nhoff + ICMP_EXT_ORIG_DGRAM_MIN_LEN) || + __skb_put_padto(skb, nhoff + ICMP_EXT_ORIG_DGRAM_MIN_LEN, false)) + goto free_skb; + + if (pskb_expand_head(skb, 0, ext_max_len, GFP_ATOMIC)) + goto free_skb; + + ext_hdr = skb_put_zero(skb, sizeof(*ext_hdr)); + ext_hdr->version = ICMP_EXT_VERSION_2; + + icmp_ext_objs_append(net, skb, ext_objs, iif); + + /* Do not send an empty extension structure. */ + ext_len = skb_tail_pointer(skb) - (unsigned char *)ext_hdr; + if (ext_len == sizeof(*ext_hdr)) + goto free_skb; + + ext_hdr->checksum = ip_compute_csum(ext_hdr, ext_len); + /* The length of the original datagram in 32-bit words (RFC 4884). */ + icmph->un.reserved[1] = ICMP_EXT_ORIG_DGRAM_MIN_LEN / sizeof(u32); + + return skb; + +free_skb: + consume_skb(skb); + return NULL; +} + /* * Send an ICMP message in response to a situation * @@ -589,12 +773,14 @@ relookup_failed: */ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, - const struct ip_options *opt) + const struct inet_skb_parm *parm) { struct iphdr *iph; int room; struct icmp_bxm icmp_param; struct rtable *rt = skb_rtable(skb_in); + bool apply_ratelimit = false; + struct sk_buff *ext_skb; struct ipcm_cookie ipc; struct flowi4 fl4; __be32 saddr; @@ -604,12 +790,14 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, struct sock *sk; if (!rt) - goto out; + return; + + rcu_read_lock(); if (rt->dst.dev) - net = dev_net(rt->dst.dev); + net = dev_net_rcu(rt->dst.dev); else if (skb_in->dev) - net = dev_net(skb_in->dev); + net = dev_net_rcu(skb_in->dev); else goto out; @@ -676,7 +864,7 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, } } - /* Needed by both icmp_global_allow and icmp_xmit_lock */ + /* Needed by both icmpv4_global_allow and icmp_xmit_lock */ local_bh_disable(); /* Check global sysctl_icmp_msgs_per_sec ratelimit, unless @@ -684,7 +872,7 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, * loopback, then peer ratelimit still work (in icmpv4_xrlim_allow) */ if (!(skb_in->dev && (skb_in->dev->flags&IFF_LOOPBACK)) && - !icmpv4_global_allow(net, type, code)) + !icmpv4_global_allow(net, type, code, &apply_ratelimit)) goto out_bh_enable; sk = icmp_xmit_lock(net); @@ -701,8 +889,9 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, rcu_read_lock(); if (rt_is_input_route(rt) && - net->ipv4.sysctl_icmp_errors_use_inbound_ifaddr) - dev = dev_get_by_index_rcu(net, inet_iif(skb_in)); + READ_ONCE(net->ipv4.sysctl_icmp_errors_use_inbound_ifaddr)) + dev = dev_get_by_index_rcu(net, parm->iif ? parm->iif : + inet_iif(skb_in)); if (dev) saddr = inet_select_addr(dev, iph->saddr, @@ -717,7 +906,8 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, iph->tos; mark = IP4_REPLY_MARK(net, skb_in->mark); - if (__ip_options_echo(net, &icmp_param.replyopts.opt.opt, skb_in, opt)) + if (__ip_options_echo(net, &icmp_param.replyopts.opt.opt, skb_in, + &parm->opt)) goto out_unlock; @@ -731,19 +921,20 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, icmp_param.data.icmph.checksum = 0; icmp_param.skb = skb_in; icmp_param.offset = skb_network_offset(skb_in); - inet_sk(sk)->tos = tos; ipcm_init(&ipc); + ipc.tos = tos; ipc.addr = iph->saddr; ipc.opt = &icmp_param.replyopts.opt; ipc.sockc.mark = mark; - rt = icmp_route_lookup(net, &fl4, skb_in, iph, saddr, tos, mark, - type, code, &icmp_param); + rt = icmp_route_lookup(net, &fl4, skb_in, iph, saddr, + inet_dsfield_to_dscp(tos), mark, type, code, + &icmp_param); if (IS_ERR(rt)) goto out_unlock; /* peer icmp_ratelimit */ - if (!icmpv4_xrlim_allow(net, rt, &fl4, type, code)) + if (!icmpv4_xrlim_allow(net, rt, &fl4, type, code, apply_ratelimit)) goto ende; /* RFC says return as much as we can without exceeding 576 bytes. */ @@ -753,8 +944,18 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, room = 576; room -= sizeof(struct iphdr) + icmp_param.replyopts.opt.opt.optlen; room -= sizeof(struct icmphdr); + /* Guard against tiny mtu. We need to include at least one + * IP network header for this message to make any sense. + */ + if (room <= (int)sizeof(struct iphdr)) + goto ende; + + ext_skb = icmp_ext_append(net, skb_in, &icmp_param.data.icmph, room, + parm->iif); + if (ext_skb) + icmp_param.skb = ext_skb; - icmp_param.data_len = skb_in->len - icmp_param.offset; + icmp_param.data_len = icmp_param.skb->len - icmp_param.offset; if (icmp_param.data_len > room) icmp_param.data_len = room; icmp_param.head_len = sizeof(struct icmphdr); @@ -766,14 +967,20 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, if (!fl4.saddr) fl4.saddr = htonl(INADDR_DUMMY); - icmp_push_reply(&icmp_param, &fl4, &ipc, &rt); + trace_icmp_send(skb_in, type, code); + + icmp_push_reply(sk, &icmp_param, &fl4, &ipc, &rt); + + if (ext_skb) + consume_skb(ext_skb); ende: ip_rt_put(rt); out_unlock: icmp_xmit_unlock(sk); out_bh_enable: local_bh_enable(); -out:; +out: + rcu_read_unlock(); } EXPORT_SYMBOL(__icmp_send); @@ -782,14 +989,16 @@ EXPORT_SYMBOL(__icmp_send); void icmp_ndo_send(struct sk_buff *skb_in, int type, int code, __be32 info) { struct sk_buff *cloned_skb = NULL; - struct ip_options opts = { 0 }; enum ip_conntrack_info ctinfo; + enum ip_conntrack_dir dir; + struct inet_skb_parm parm; struct nf_conn *ct; __be32 orig_ip; + memset(&parm, 0, sizeof(parm)); ct = nf_ct_get(skb_in, &ctinfo); - if (!ct || !(ct->status & IPS_SRC_NAT)) { - __icmp_send(skb_in, type, code, info, &opts); + if (!ct || !(READ_ONCE(ct->status) & IPS_NAT_MASK)) { + __icmp_send(skb_in, type, code, info, &parm); return; } @@ -803,8 +1012,9 @@ void icmp_ndo_send(struct sk_buff *skb_in, int type, int code, __be32 info) goto out; orig_ip = ip_hdr(skb_in)->saddr; - ip_hdr(skb_in)->saddr = ct->tuplehash[0].tuple.src.u3.ip; - __icmp_send(skb_in, type, code, info, &opts); + dir = CTINFO2DIR(ctinfo); + ip_hdr(skb_in)->saddr = ct->tuplehash[dir].tuple.src.u3.ip; + __icmp_send(skb_in, type, code, info, &parm); ip_hdr(skb_in)->saddr = orig_ip; out: consume_skb(cloned_skb); @@ -822,7 +1032,7 @@ static void icmp_socket_deliver(struct sk_buff *skb, u32 info) * avoid additional coding at protocol handlers. */ if (!pskb_may_pull(skb, iph->ihl * 4 + 8)) { - __ICMP_INC_STATS(dev_net(skb->dev), ICMP_MIB_INERRORS); + __ICMP_INC_STATS(dev_net_rcu(skb->dev), ICMP_MIB_INERRORS); return; } @@ -848,14 +1058,15 @@ static bool icmp_tag_validation(int proto) * ICMP_PARAMETERPROB. */ -static bool icmp_unreach(struct sk_buff *skb) +static enum skb_drop_reason icmp_unreach(struct sk_buff *skb) { + enum skb_drop_reason reason = SKB_NOT_DROPPED_YET; const struct iphdr *iph; struct icmphdr *icmph; struct net *net; u32 info = 0; - net = dev_net(skb_dst(skb)->dev); + net = skb_dst_dev_net_rcu(skb); /* * Incomplete header ? @@ -869,8 +1080,10 @@ static bool icmp_unreach(struct sk_buff *skb) icmph = icmp_hdr(skb); iph = (const struct iphdr *)skb->data; - if (iph->ihl < 5) /* Mangled header, drop. */ + if (iph->ihl < 5) { /* Mangled header, drop. */ + reason = SKB_DROP_REASON_IP_INHDR; goto out_err; + } switch (icmph->type) { case ICMP_DEST_UNREACH: @@ -885,7 +1098,7 @@ static bool icmp_unreach(struct sk_buff *skb) * values please see * Documentation/networking/ip-sysctl.rst */ - switch (net->ipv4.sysctl_ip_no_pmtu_disc) { + switch (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc)) { default: net_dbg_ratelimited("%pI4: fragmentation needed and DF set\n", &iph->daddr); @@ -938,7 +1151,7 @@ static bool icmp_unreach(struct sk_buff *skb) * get the other vendor to fix their kit. */ - if (!net->ipv4.sysctl_icmp_ignore_bogus_error_responses && + if (!READ_ONCE(net->ipv4.sysctl_icmp_ignore_bogus_error_responses) && inet_addr_type_dev_table(net, skb->dev, iph->daddr) == RTN_BROADCAST) { net_warn_ratelimited("%pI4 sent an invalid ICMP type %u, code %u error to a broadcast: %pI4 on %s\n", &ip_hdr(skb)->saddr, @@ -950,10 +1163,10 @@ static bool icmp_unreach(struct sk_buff *skb) icmp_socket_deliver(skb, info); out: - return true; + return reason; out_err: __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); - return false; + return reason ?: SKB_DROP_REASON_NOT_SPECIFIED; } @@ -961,20 +1174,20 @@ out_err: * Handle ICMP_REDIRECT. */ -static bool icmp_redirect(struct sk_buff *skb) +static enum skb_drop_reason icmp_redirect(struct sk_buff *skb) { if (skb->len < sizeof(struct iphdr)) { - __ICMP_INC_STATS(dev_net(skb->dev), ICMP_MIB_INERRORS); - return false; + __ICMP_INC_STATS(dev_net_rcu(skb->dev), ICMP_MIB_INERRORS); + return SKB_DROP_REASON_PKT_TOO_SMALL; } if (!pskb_may_pull(skb, sizeof(struct iphdr))) { /* there aught to be a stat */ - return false; + return SKB_DROP_REASON_NOMEM; } icmp_socket_deliver(skb, ntohl(icmp_hdr(skb)->un.gateway)); - return true; + return SKB_NOT_DROPPED_YET; } /* @@ -991,15 +1204,15 @@ static bool icmp_redirect(struct sk_buff *skb) * See also WRT handling of options once they are done and working. */ -static bool icmp_echo(struct sk_buff *skb) +static enum skb_drop_reason icmp_echo(struct sk_buff *skb) { struct icmp_bxm icmp_param; struct net *net; - net = dev_net(skb_dst(skb)->dev); + net = skb_dst_dev_net_rcu(skb); /* should there be an ICMP stat for ignored echos? */ - if (net->ipv4.sysctl_icmp_echo_ignore_all) - return true; + if (READ_ONCE(net->ipv4.sysctl_icmp_echo_ignore_all)) + return SKB_NOT_DROPPED_YET; icmp_param.data.icmph = *icmp_hdr(skb); icmp_param.skb = skb; @@ -1010,10 +1223,10 @@ static bool icmp_echo(struct sk_buff *skb) if (icmp_param.data.icmph.type == ICMP_ECHO) icmp_param.data.icmph.type = ICMP_ECHOREPLY; else if (!icmp_build_probe(skb, &icmp_param.data.icmph)) - return true; + return SKB_NOT_DROPPED_YET; icmp_reply(&icmp_param, skb); - return true; + return SKB_NOT_DROPPED_YET; } /* Helper for icmp_echo and icmpv6_echo_reply. @@ -1025,15 +1238,17 @@ static bool icmp_echo(struct sk_buff *skb) bool icmp_build_probe(struct sk_buff *skb, struct icmphdr *icmphdr) { + struct net *net = dev_net_rcu(skb->dev); struct icmp_ext_hdr *ext_hdr, _ext_hdr; struct icmp_ext_echo_iio *iio, _iio; - struct net *net = dev_net(skb->dev); + struct inet6_dev *in6_dev; + struct in_device *in_dev; struct net_device *dev; char buff[IFNAMSIZ]; u16 ident_len; u8 status; - if (!net->ipv4.sysctl_icmp_echo_enable_probe) + if (!READ_ONCE(net->ipv4.sysctl_icmp_echo_enable_probe)) return false; /* We currently only support probing interfaces on the proxy node @@ -1111,10 +1326,15 @@ bool icmp_build_probe(struct sk_buff *skb, struct icmphdr *icmphdr) /* Fill bits in reply message */ if (dev->flags & IFF_UP) status |= ICMP_EXT_ECHOREPLY_ACTIVE; - if (__in_dev_get_rcu(dev) && __in_dev_get_rcu(dev)->ifa_list) + + in_dev = __in_dev_get_rcu(dev); + if (in_dev && rcu_access_pointer(in_dev->ifa_list)) status |= ICMP_EXT_ECHOREPLY_IPV4; - if (!list_empty(&rcu_dereference(dev->ip6_ptr)->addr_list)) + + in6_dev = __in6_dev_get(dev); + if (in6_dev && !list_empty(&in6_dev->addr_list)) status |= ICMP_EXT_ECHOREPLY_IPV6; + dev_put(dev); icmphdr->un.echo.sequence |= htons(status); return true; @@ -1131,7 +1351,7 @@ EXPORT_SYMBOL_GPL(icmp_build_probe); * MUST be accurate to a few minutes. * MUST be updated at least at 15Hz. */ -static bool icmp_timestamp(struct sk_buff *skb) +static enum skb_drop_reason icmp_timestamp(struct sk_buff *skb) { struct icmp_bxm icmp_param; /* @@ -1156,17 +1376,17 @@ static bool icmp_timestamp(struct sk_buff *skb) icmp_param.data_len = 0; icmp_param.head_len = sizeof(struct icmphdr) + 12; icmp_reply(&icmp_param, skb); - return true; + return SKB_NOT_DROPPED_YET; out_err: - __ICMP_INC_STATS(dev_net(skb_dst(skb)->dev), ICMP_MIB_INERRORS); - return false; + __ICMP_INC_STATS(skb_dst_dev_net_rcu(skb), ICMP_MIB_INERRORS); + return SKB_DROP_REASON_PKT_TOO_SMALL; } -static bool icmp_discard(struct sk_buff *skb) +static enum skb_drop_reason icmp_discard(struct sk_buff *skb) { /* pretend it was a success */ - return true; + return SKB_NOT_DROPPED_YET; } /* @@ -1174,18 +1394,20 @@ static bool icmp_discard(struct sk_buff *skb) */ int icmp_rcv(struct sk_buff *skb) { - struct icmphdr *icmph; + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; struct rtable *rt = skb_rtable(skb); - struct net *net = dev_net(rt->dst.dev); - bool success; + struct net *net = dev_net_rcu(rt->dst.dev); + struct icmphdr *icmph; if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { struct sec_path *sp = skb_sec_path(skb); int nh; if (!(sp && sp->xvec[sp->len - 1]->props.flags & - XFRM_STATE_ICMP)) + XFRM_STATE_ICMP)) { + reason = SKB_DROP_REASON_XFRM_POLICY; goto drop; + } if (!pskb_may_pull(skb, sizeof(*icmph) + sizeof(struct iphdr))) goto drop; @@ -1193,8 +1415,11 @@ int icmp_rcv(struct sk_buff *skb) nh = skb_network_offset(skb); skb_set_network_header(skb, sizeof(*icmph)); - if (!xfrm4_policy_check_reverse(NULL, XFRM_POLICY_IN, skb)) + if (!xfrm4_policy_check_reverse(NULL, XFRM_POLICY_IN, + skb)) { + reason = SKB_DROP_REASON_XFRM_POLICY; goto drop; + } skb_set_network_header(skb, nh); } @@ -1216,25 +1441,11 @@ int icmp_rcv(struct sk_buff *skb) /* We can't use icmp_pointers[].handler() because it is an array of * size NR_ICMP_TYPES + 1 (19 elements) and PROBE has code 42. */ - success = icmp_echo(skb); - goto success_check; - } - - if (icmph->type == ICMP_EXT_ECHOREPLY) { - success = ping_rcv(skb); - goto success_check; + reason = icmp_echo(skb); + goto reason_check; } /* - * 18 is the highest 'known' ICMP type. Anything else is a mystery - * - * RFC 1122: 3.2.2 Unknown ICMP messages types MUST be silently - * discarded. - */ - if (icmph->type > NR_ICMP_TYPES) - goto error; - - /* * Parse the ICMP message */ @@ -1247,28 +1458,48 @@ int icmp_rcv(struct sk_buff *skb) */ if ((icmph->type == ICMP_ECHO || icmph->type == ICMP_TIMESTAMP) && - net->ipv4.sysctl_icmp_echo_ignore_broadcasts) { + READ_ONCE(net->ipv4.sysctl_icmp_echo_ignore_broadcasts)) { + reason = SKB_DROP_REASON_INVALID_PROTO; goto error; } if (icmph->type != ICMP_ECHO && icmph->type != ICMP_TIMESTAMP && icmph->type != ICMP_ADDRESS && icmph->type != ICMP_ADDRESSREPLY) { + reason = SKB_DROP_REASON_INVALID_PROTO; goto error; } } - success = icmp_pointers[icmph->type].handler(skb); -success_check: - if (success) { + if (icmph->type == ICMP_EXT_ECHOREPLY || + icmph->type == ICMP_ECHOREPLY) { + reason = ping_rcv(skb); + return reason ? NET_RX_DROP : NET_RX_SUCCESS; + } + + /* + * 18 is the highest 'known' ICMP type. Anything else is a mystery + * + * RFC 1122: 3.2.2 Unknown ICMP messages types MUST be silently + * discarded. + */ + if (icmph->type > NR_ICMP_TYPES) { + reason = SKB_DROP_REASON_UNHANDLED_PROTO; + goto error; + } + + reason = icmp_pointers[icmph->type].handler(skb); +reason_check: + if (!reason) { consume_skb(skb); return NET_RX_SUCCESS; } drop: - kfree_skb(skb); + kfree_skb_reason(skb, reason); return NET_RX_DROP; csum_error: + reason = SKB_DROP_REASON_ICMP_CSUM; __ICMP_INC_STATS(net, ICMP_MIB_CSUMERRORS); error: __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); @@ -1339,9 +1570,9 @@ int icmp_err(struct sk_buff *skb, u32 info) struct iphdr *iph = (struct iphdr *)skb->data; int offset = iph->ihl<<2; struct icmphdr *icmph = (struct icmphdr *)(skb->data + offset); + struct net *net = dev_net_rcu(skb->dev); int type = icmp_hdr(skb)->type; int code = icmp_hdr(skb)->code; - struct net *net = dev_net(skb->dev); /* * Use ping_err to handle all icmp errors except those @@ -1434,46 +1665,8 @@ static const struct icmp_control icmp_pointers[NR_ICMP_TYPES + 1] = { }, }; -static void __net_exit icmp_sk_exit(struct net *net) -{ - int i; - - for_each_possible_cpu(i) - inet_ctl_sock_destroy(*per_cpu_ptr(net->ipv4.icmp_sk, i)); - free_percpu(net->ipv4.icmp_sk); - net->ipv4.icmp_sk = NULL; -} - static int __net_init icmp_sk_init(struct net *net) { - int i, err; - - net->ipv4.icmp_sk = alloc_percpu(struct sock *); - if (!net->ipv4.icmp_sk) - return -ENOMEM; - - for_each_possible_cpu(i) { - struct sock *sk; - - err = inet_ctl_sock_create(&sk, PF_INET, - SOCK_RAW, IPPROTO_ICMP, net); - if (err < 0) - goto fail; - - *per_cpu_ptr(net->ipv4.icmp_sk, i) = sk; - - /* Enough space for 2 64K ICMP packets, including - * sk_buff/skb_shared_info struct overhead. - */ - sk->sk_sndbuf = 2 * SKB_TRUESIZE(64 * 1024); - - /* - * Speedup sock_wfree() - */ - sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); - inet_sk(sk)->pmtudisc = IP_PMTUDISC_DONT; - } - /* Control parameters for ECHO replies. */ net->ipv4.sysctl_icmp_echo_ignore_all = 0; net->ipv4.sysctl_icmp_echo_enable_probe = 0; @@ -1497,20 +1690,41 @@ static int __net_init icmp_sk_init(struct net *net) net->ipv4.sysctl_icmp_ratelimit = 1 * HZ; net->ipv4.sysctl_icmp_ratemask = 0x1818; net->ipv4.sysctl_icmp_errors_use_inbound_ifaddr = 0; + net->ipv4.sysctl_icmp_errors_extension_mask = 0; + net->ipv4.sysctl_icmp_msgs_per_sec = 1000; + net->ipv4.sysctl_icmp_msgs_burst = 50; return 0; - -fail: - icmp_sk_exit(net); - return err; } static struct pernet_operations __net_initdata icmp_sk_ops = { .init = icmp_sk_init, - .exit = icmp_sk_exit, }; int __init icmp_init(void) { + int err, i; + + for_each_possible_cpu(i) { + struct sock *sk; + + err = inet_ctl_sock_create(&sk, PF_INET, + SOCK_RAW, IPPROTO_ICMP, &init_net); + if (err < 0) + return err; + + per_cpu(ipv4_icmp_sk, i) = sk; + + /* Enough space for 2 64K ICMP packets, including + * sk_buff/skb_shared_info struct overhead. + */ + sk->sk_sndbuf = 2 * SKB_TRUESIZE(64 * 1024); + + /* + * Speedup sock_wfree() + */ + sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); + inet_sk(sk)->pmtudisc = IP_PMTUDISC_DONT; + } return register_pernet_subsys(&icmp_sk_ops); } diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c index 2ad3c7b42d6d..7182f1419c2a 100644 --- a/net/ipv4/igmp.c +++ b/net/ipv4/igmp.c @@ -81,6 +81,7 @@ #include <linux/skbuff.h> #include <linux/inetdevice.h> #include <linux/igmp.h> +#include "igmp_internal.h" #include <linux/if_arp.h> #include <linux/rtnetlink.h> #include <linux/times.h> @@ -88,6 +89,8 @@ #include <linux/byteorder/generic.h> #include <net/net_namespace.h> +#include <net/netlink.h> +#include <net/addrconf.h> #include <net/arp.h> #include <net/ip.h> #include <net/protocol.h> @@ -120,12 +123,12 @@ */ #define IGMP_V1_SEEN(in_dev) \ - (IPV4_DEVCONF_ALL(dev_net(in_dev->dev), FORCE_IGMP_VERSION) == 1 || \ + (IPV4_DEVCONF_ALL_RO(dev_net(in_dev->dev), FORCE_IGMP_VERSION) == 1 || \ IN_DEV_CONF_GET((in_dev), FORCE_IGMP_VERSION) == 1 || \ ((in_dev)->mr_v1_seen && \ time_before(jiffies, (in_dev)->mr_v1_seen))) #define IGMP_V2_SEEN(in_dev) \ - (IPV4_DEVCONF_ALL(dev_net(in_dev->dev), FORCE_IGMP_VERSION) == 2 || \ + (IPV4_DEVCONF_ALL_RO(dev_net(in_dev->dev), FORCE_IGMP_VERSION) == 2 || \ IN_DEV_CONF_GET((in_dev), FORCE_IGMP_VERSION) == 2 || \ ((in_dev)->mr_v2_seen && \ time_before(jiffies, (in_dev)->mr_v2_seen))) @@ -202,7 +205,7 @@ static void ip_sf_list_clear_all(struct ip_sf_list *psf) static void igmp_stop_timer(struct ip_mc_list *im) { spin_lock_bh(&im->lock); - if (del_timer(&im->timer)) + if (timer_delete(&im->timer)) refcount_dec(&im->refcnt); im->tm_running = 0; im->reporter = 0; @@ -213,16 +216,18 @@ static void igmp_stop_timer(struct ip_mc_list *im) /* It must be called with locked im->lock */ static void igmp_start_timer(struct ip_mc_list *im, int max_delay) { - int tv = prandom_u32() % max_delay; + int tv = get_random_u32_below(max_delay); im->tm_running = 1; - if (!mod_timer(&im->timer, jiffies+tv+2)) - refcount_inc(&im->refcnt); + if (refcount_inc_not_zero(&im->refcnt)) { + if (mod_timer(&im->timer, jiffies + tv + 2)) + ip_ma_put(im); + } } static void igmp_gq_start_timer(struct in_device *in_dev) { - int tv = prandom_u32() % in_dev->mr_maxdelay; + int tv = get_random_u32_below(in_dev->mr_maxdelay); unsigned long exp = jiffies + tv + 2; if (in_dev->mr_gq_running && @@ -236,7 +241,7 @@ static void igmp_gq_start_timer(struct in_device *in_dev) static void igmp_ifc_start_timer(struct in_device *in_dev, int delay) { - int tv = prandom_u32() % delay; + int tv = get_random_u32_below(delay); if (!mod_timer(&in_dev->mr_ifc_timer, jiffies+tv+2)) in_dev_hold(in_dev); @@ -246,7 +251,7 @@ static void igmp_mod_timer(struct ip_mc_list *im, int max_delay) { spin_lock_bh(&im->lock); im->unsolicit_count = 0; - if (del_timer(&im->timer)) { + if (timer_delete(&im->timer)) { if ((long)(im->timer.expires-jiffies) < max_delay) { add_timer(&im->timer); im->tm_running = 1; @@ -353,8 +358,9 @@ static struct sk_buff *igmpv3_newpack(struct net_device *dev, unsigned int mtu) struct flowi4 fl4; int hlen = LL_RESERVED_SPACE(dev); int tlen = dev->needed_tailroom; - unsigned int size = mtu; + unsigned int size; + size = min(mtu, IP_MAX_MTU); while (1) { skb = alloc_skb(size + hlen + tlen, GFP_ATOMIC | __GFP_NOWARN); @@ -421,7 +427,7 @@ static int igmpv3_sendpack(struct sk_buff *skb) pig->csum = ip_compute_csum(igmp_hdr(skb), igmplen); - return ip_local_out(dev_net(skb_dst(skb)->dev), skb->sk, skb); + return ip_local_out(skb_dst_dev_net(skb), skb->sk, skb); } static int grec_size(struct ip_mc_list *pmc, int type, int gdel, int sdel) @@ -467,7 +473,8 @@ static struct sk_buff *add_grec(struct sk_buff *skb, struct ip_mc_list *pmc, if (pmc->multiaddr == IGMP_ALL_HOSTS) return skb; - if (ipv4_is_local_multicast(pmc->multiaddr) && !net->ipv4.sysctl_igmp_llm_reports) + if (ipv4_is_local_multicast(pmc->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) return skb; mtu = READ_ONCE(dev->mtu); @@ -593,7 +600,7 @@ static int igmpv3_send_report(struct in_device *in_dev, struct ip_mc_list *pmc) if (pmc->multiaddr == IGMP_ALL_HOSTS) continue; if (ipv4_is_local_multicast(pmc->multiaddr) && - !net->ipv4.sysctl_igmp_llm_reports) + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) continue; spin_lock_bh(&pmc->lock); if (pmc->sfcount[MCAST_EXCLUDE]) @@ -736,7 +743,8 @@ static int igmp_send_report(struct in_device *in_dev, struct ip_mc_list *pmc, if (type == IGMPV3_HOST_MEMBERSHIP_REPORT) return igmpv3_send_report(in_dev, pmc); - if (ipv4_is_local_multicast(group) && !net->ipv4.sysctl_igmp_llm_reports) + if (ipv4_is_local_multicast(group) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) return 0; if (type == IGMP_HOST_LEAVE_MESSAGE) @@ -793,7 +801,7 @@ static int igmp_send_report(struct in_device *in_dev, struct ip_mc_list *pmc, static void igmp_gq_timer_expire(struct timer_list *t) { - struct in_device *in_dev = from_timer(in_dev, t, mr_gq_timer); + struct in_device *in_dev = timer_container_of(in_dev, t, mr_gq_timer); in_dev->mr_gq_running = 0; igmpv3_send_report(in_dev, NULL); @@ -802,7 +810,7 @@ static void igmp_gq_timer_expire(struct timer_list *t) static void igmp_ifc_timer_expire(struct timer_list *t) { - struct in_device *in_dev = from_timer(in_dev, t, mr_ifc_timer); + struct in_device *in_dev = timer_container_of(in_dev, t, mr_ifc_timer); u32 mr_ifc_count; igmpv3_send_cr(in_dev); @@ -825,14 +833,14 @@ static void igmp_ifc_event(struct in_device *in_dev) struct net *net = dev_net(in_dev->dev); if (IGMP_V1_SEEN(in_dev) || IGMP_V2_SEEN(in_dev)) return; - WRITE_ONCE(in_dev->mr_ifc_count, in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv); + WRITE_ONCE(in_dev->mr_ifc_count, in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv)); igmp_ifc_start_timer(in_dev, 1); } static void igmp_timer_expire(struct timer_list *t) { - struct ip_mc_list *im = from_timer(im, t, timer); + struct ip_mc_list *im = timer_container_of(im, t, timer); struct in_device *in_dev = im->interface; spin_lock(&im->lock); @@ -920,7 +928,8 @@ static bool igmp_heard_report(struct in_device *in_dev, __be32 group) if (group == IGMP_ALL_HOSTS) return false; - if (ipv4_is_local_multicast(group) && !net->ipv4.sysctl_igmp_llm_reports) + if (ipv4_is_local_multicast(group) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) return false; rcu_read_lock(); @@ -965,7 +974,7 @@ static bool igmp_heard_query(struct in_device *in_dev, struct sk_buff *skb, } /* cancel the interface change timer */ WRITE_ONCE(in_dev->mr_ifc_count, 0); - if (del_timer(&in_dev->mr_ifc_timer)) + if (timer_delete(&in_dev->mr_ifc_timer)) __in_dev_put(in_dev); /* clear deleted report items */ igmpv3_clear_delrec(in_dev); @@ -1006,7 +1015,7 @@ static bool igmp_heard_query(struct in_device *in_dev, struct sk_buff *skb, * received value was zero, use the default or statically * configured value. */ - in_dev->mr_qrv = ih3->qrv ?: net->ipv4.sysctl_igmp_qrv; + in_dev->mr_qrv = ih3->qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); in_dev->mr_qi = IGMPV3_QQIC(ih3->qqic)*HZ ?: IGMP_QUERY_INTERVAL; /* RFC3376, 8.3. Query Response Interval: @@ -1045,7 +1054,7 @@ static bool igmp_heard_query(struct in_device *in_dev, struct sk_buff *skb, if (im->multiaddr == IGMP_ALL_HOSTS) continue; if (ipv4_is_local_multicast(im->multiaddr) && - !net->ipv4.sysctl_igmp_llm_reports) + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) continue; spin_lock_bh(&im->lock); if (im->tm_running) @@ -1186,7 +1195,7 @@ static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im, pmc->interface = im->interface; in_dev_hold(in_dev); pmc->multiaddr = im->multiaddr; - pmc->crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + pmc->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); pmc->sfmode = im->sfmode; if (pmc->sfmode == MCAST_INCLUDE) { struct ip_sf_list *psf; @@ -1237,9 +1246,11 @@ static void igmpv3_del_delrec(struct in_device *in_dev, struct ip_mc_list *im) swap(im->tomb, pmc->tomb); swap(im->sources, pmc->sources); for (psf = im->sources; psf; psf = psf->sf_next) - psf->sf_crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + psf->sf_crcount = in_dev->mr_qrv ?: + READ_ONCE(net->ipv4.sysctl_igmp_qrv); } else { - im->crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + im->crcount = in_dev->mr_qrv ?: + READ_ONCE(net->ipv4.sysctl_igmp_qrv); } in_dev_put(pmc->interface); kfree_pmc(pmc); @@ -1296,7 +1307,8 @@ static void __igmp_group_dropped(struct ip_mc_list *im, gfp_t gfp) #ifdef CONFIG_IP_MULTICAST if (im->multiaddr == IGMP_ALL_HOSTS) return; - if (ipv4_is_local_multicast(im->multiaddr) && !net->ipv4.sysctl_igmp_llm_reports) + if (ipv4_is_local_multicast(im->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) return; reporter = im->reporter; @@ -1338,13 +1350,14 @@ static void igmp_group_added(struct ip_mc_list *im) #ifdef CONFIG_IP_MULTICAST if (im->multiaddr == IGMP_ALL_HOSTS) return; - if (ipv4_is_local_multicast(im->multiaddr) && !net->ipv4.sysctl_igmp_llm_reports) + if (ipv4_is_local_multicast(im->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) return; if (in_dev->dead) return; - im->unsolicit_count = net->ipv4.sysctl_igmp_qrv; + im->unsolicit_count = READ_ONCE(net->ipv4.sysctl_igmp_qrv); if (IGMP_V1_SEEN(in_dev) || IGMP_V2_SEEN(in_dev)) { spin_lock_bh(&im->lock); igmp_start_timer(im, IGMP_INITIAL_REPORT_DELAY); @@ -1358,7 +1371,7 @@ static void igmp_group_added(struct ip_mc_list *im) * IN() to IN(A). */ if (im->sfmode == MCAST_EXCLUDE) - im->crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + im->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); igmp_ifc_event(in_dev); #endif @@ -1420,6 +1433,70 @@ static void ip_mc_hash_remove(struct in_device *in_dev, *mc_hash = im->next_hash; } +int inet_fill_ifmcaddr(struct sk_buff *skb, struct net_device *dev, + const struct ip_mc_list *im, + struct inet_fill_args *args) +{ + struct ifa_cacheinfo ci; + struct ifaddrmsg *ifm; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, args->portid, args->seq, args->event, + sizeof(struct ifaddrmsg), args->flags); + if (!nlh) + return -EMSGSIZE; + + ifm = nlmsg_data(nlh); + ifm->ifa_family = AF_INET; + ifm->ifa_prefixlen = 32; + ifm->ifa_flags = IFA_F_PERMANENT; + ifm->ifa_scope = RT_SCOPE_UNIVERSE; + ifm->ifa_index = dev->ifindex; + + ci.cstamp = (READ_ONCE(im->mca_cstamp) - INITIAL_JIFFIES) * 100UL / HZ; + ci.tstamp = ci.cstamp; + ci.ifa_prefered = INFINITY_LIFE_TIME; + ci.ifa_valid = INFINITY_LIFE_TIME; + + if (nla_put_in_addr(skb, IFA_MULTICAST, im->multiaddr) < 0 || + nla_put(skb, IFA_CACHEINFO, sizeof(ci), &ci) < 0) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static void inet_ifmcaddr_notify(struct net_device *dev, + const struct ip_mc_list *im, int event) +{ + struct inet_fill_args fillargs = { + .event = event, + }; + struct net *net = dev_net(dev); + struct sk_buff *skb; + int err = -ENOMEM; + + skb = nlmsg_new(NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(sizeof(__be32)) + + nla_total_size(sizeof(struct ifa_cacheinfo)), + GFP_KERNEL); + if (!skb) + goto error; + + err = inet_fill_ifmcaddr(skb, dev, im, &fillargs); + if (err < 0) { + WARN_ON_ONCE(err == -EMSGSIZE); + nlmsg_free(skb); + goto error; + } + + rtnl_notify(skb, net, 0, RTNLGRP_IPV4_MCADDR, NULL, GFP_KERNEL); + return; +error: + rtnl_set_sk_err(net, RTNLGRP_IPV4_MCADDR, err); +} /* * A socket has joined a multicast group on device dev. @@ -1427,16 +1504,32 @@ static void ip_mc_hash_remove(struct in_device *in_dev, static void ____ip_mc_inc_group(struct in_device *in_dev, __be32 addr, unsigned int mode, gfp_t gfp) { + struct ip_mc_list __rcu **mc_hash; struct ip_mc_list *im; ASSERT_RTNL(); - for_each_pmc_rtnl(in_dev, im) { - if (im->multiaddr == addr) { - im->users++; - ip_mc_add_src(in_dev, &addr, mode, 0, NULL, 0); - goto out; + mc_hash = rtnl_dereference(in_dev->mc_hash); + if (mc_hash) { + u32 hash = hash_32((__force u32)addr, MC_HASH_SZ_LOG); + + for (im = rtnl_dereference(mc_hash[hash]); + im; + im = rtnl_dereference(im->next_hash)) { + if (im->multiaddr == addr) + break; } + } else { + for_each_pmc_rtnl(in_dev, im) { + if (im->multiaddr == addr) + break; + } + } + + if (im) { + im->users++; + ip_mc_add_src(in_dev, &addr, mode, 0, NULL, 0); + goto out; } im = kzalloc(sizeof(*im), gfp); @@ -1447,6 +1540,8 @@ static void ____ip_mc_inc_group(struct in_device *in_dev, __be32 addr, im->interface = in_dev; in_dev_hold(in_dev); im->multiaddr = addr; + im->mca_cstamp = jiffies; + im->mca_tstamp = im->mca_cstamp; /* initial mode is (EX, empty) */ im->sfmode = mode; im->sfcount[mode] = 1; @@ -1466,6 +1561,7 @@ static void ____ip_mc_inc_group(struct in_device *in_dev, __be32 addr, igmpv3_del_delrec(in_dev, im); #endif igmp_group_added(im); + inet_ifmcaddr_notify(in_dev->dev, im, RTM_NEWMULTICAST); if (!in_dev->dead) ip_rt_multicast_event(in_dev); out: @@ -1642,7 +1738,7 @@ static void ip_mc_rejoin_groups(struct in_device *in_dev) if (im->multiaddr == IGMP_ALL_HOSTS) continue; if (ipv4_is_local_multicast(im->multiaddr) && - !net->ipv4.sysctl_igmp_llm_reports) + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) continue; /* a failover is happening and switches @@ -1679,6 +1775,8 @@ void __ip_mc_dec_group(struct in_device *in_dev, __be32 addr, gfp_t gfp) *ip = i->next_rcu; in_dev->mc_count--; __igmp_group_dropped(i, gfp); + inet_ifmcaddr_notify(in_dev->dev, i, + RTM_DELMULTICAST); ip_mc_clear_src(i); if (!in_dev->dead) @@ -1732,10 +1830,10 @@ void ip_mc_down(struct in_device *in_dev) #ifdef CONFIG_IP_MULTICAST WRITE_ONCE(in_dev->mr_ifc_count, 0); - if (del_timer(&in_dev->mr_ifc_timer)) + if (timer_delete(&in_dev->mr_ifc_timer)) __in_dev_put(in_dev); in_dev->mr_gq_running = 0; - if (del_timer(&in_dev->mr_gq_timer)) + if (timer_delete(&in_dev->mr_gq_timer)) __in_dev_put(in_dev); #endif @@ -1749,7 +1847,7 @@ static void ip_mc_reset(struct in_device *in_dev) in_dev->mr_qi = IGMP_QUERY_INTERVAL; in_dev->mr_qri = IGMP_QUERY_RESPONSE_INTERVAL; - in_dev->mr_qrv = net->ipv4.sysctl_igmp_qrv; + in_dev->mr_qrv = READ_ONCE(net->ipv4.sysctl_igmp_qrv); } #else static void ip_mc_reset(struct in_device *in_dev) @@ -1832,7 +1930,8 @@ static struct in_device *ip_mc_find_dev(struct net *net, struct ip_mreqn *imr) if (!dev) { struct rtable *rt = ip_route_output(net, imr->imr_multiaddr.s_addr, - 0, 0, 0); + 0, 0, 0, + RT_SCOPE_UNIVERSE); if (!IS_ERR(rt)) { dev = rt->dst.dev; ip_rt_put(rt); @@ -1883,7 +1982,7 @@ static int ip_mc_del1_src(struct ip_mc_list *pmc, int sfmode, #ifdef CONFIG_IP_MULTICAST if (psf->sf_oldin && !IGMP_V1_SEEN(in_dev) && !IGMP_V2_SEEN(in_dev)) { - psf->sf_crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + psf->sf_crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); psf->sf_next = pmc->tomb; pmc->tomb = psf; rv = 1; @@ -1947,7 +2046,7 @@ static int ip_mc_del_src(struct in_device *in_dev, __be32 *pmca, int sfmode, /* filter mode change */ pmc->sfmode = MCAST_INCLUDE; #ifdef CONFIG_IP_MULTICAST - pmc->crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + pmc->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); WRITE_ONCE(in_dev->mr_ifc_count, pmc->crcount); for (psf = pmc->sources; psf; psf = psf->sf_next) psf->sf_crcount = 0; @@ -2126,7 +2225,7 @@ static int ip_mc_add_src(struct in_device *in_dev, __be32 *pmca, int sfmode, #ifdef CONFIG_IP_MULTICAST /* else no filters; keep old mode for reports */ - pmc->crcount = in_dev->mr_qrv ?: net->ipv4.sysctl_igmp_qrv; + pmc->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); WRITE_ONCE(in_dev->mr_ifc_count, pmc->crcount); for (psf = pmc->sources; psf; psf = psf->sf_next) psf->sf_crcount = 0; @@ -2192,7 +2291,7 @@ static int __ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr, count++; } err = -ENOBUFS; - if (count >= net->ipv4.sysctl_igmp_max_memberships) + if (count >= READ_ONCE(net->ipv4.sysctl_igmp_max_memberships)) goto done; iml = sock_kmalloc(sk, sizeof(*iml), GFP_KERNEL); if (!iml) @@ -2379,7 +2478,7 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct } /* else, add a new source to the filter */ - if (psl && psl->sl_count >= net->ipv4.sysctl_igmp_max_msf) { + if (psl && psl->sl_count >= READ_ONCE(net->ipv4.sysctl_igmp_max_msf)) { err = -ENOBUFS; goto done; } @@ -2403,9 +2502,10 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct /* decrease mem now to avoid the memleak warning */ atomic_sub(struct_size(psl, sl_addr, psl->sl_max), &sk->sk_omem_alloc); - kfree_rcu(psl, rcu); } rcu_assign_pointer(pmc->sflist, newpsl); + if (psl) + kfree_rcu(psl, rcu); psl = newpsl; } rv = 1; /* > 0 for insert logic below if sl_count is 0 */ @@ -2507,11 +2607,13 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex) /* decrease mem now to avoid the memleak warning */ atomic_sub(struct_size(psl, sl_addr, psl->sl_max), &sk->sk_omem_alloc); - kfree_rcu(psl, rcu); - } else + } else { (void) ip_mc_del_src(in_dev, &msf->imsf_multiaddr, pmc->sfmode, 0, NULL, 0); + } rcu_assign_pointer(pmc->sflist, newpsl); + if (psl) + kfree_rcu(psl, rcu); pmc->sfmode = msf->imsf_fmode; err = 0; done: @@ -2519,11 +2621,10 @@ done: err = ip_mc_leave_group(sk, &imr); return err; } - int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, - struct ip_msfilter __user *optval, int __user *optlen) + sockptr_t optval, sockptr_t optlen) { - int err, len, count, copycount; + int err, len, count, copycount, msf_size; struct ip_mreqn imr; __be32 addr = msf->imsf_multiaddr; struct ip_mc_socklist *pmc; @@ -2565,12 +2666,15 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, copycount = count < msf->imsf_numsrc ? count : msf->imsf_numsrc; len = flex_array_size(psl, sl_addr, copycount); msf->imsf_numsrc = count; - if (put_user(IP_MSFILTER_SIZE(copycount), optlen) || - copy_to_user(optval, msf, IP_MSFILTER_SIZE(0))) { + msf_size = IP_MSFILTER_SIZE(copycount); + if (copy_to_sockptr(optlen, &msf_size, sizeof(int)) || + copy_to_sockptr(optval, msf, IP_MSFILTER_SIZE(0))) { return -EFAULT; } if (len && - copy_to_user(&optval->imsf_slist_flex[0], psl->sl_addr, len)) + copy_to_sockptr_offset(optval, + offsetof(struct ip_msfilter, imsf_slist_flex), + psl->sl_addr, len)) return -EFAULT; return 0; done: @@ -2578,7 +2682,7 @@ done: } int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, - struct sockaddr_storage __user *p) + sockptr_t optval, size_t ss_offset) { int i, count, copycount; struct sockaddr_in *psin; @@ -2608,15 +2712,17 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, count = psl ? psl->sl_count : 0; copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc; gsf->gf_numsrc = count; - for (i = 0; i < copycount; i++, p++) { + for (i = 0; i < copycount; i++) { struct sockaddr_storage ss; psin = (struct sockaddr_in *)&ss; memset(&ss, 0, sizeof(ss)); psin->sin_family = AF_INET; psin->sin_addr.s_addr = psl->sl_addr[i]; - if (copy_to_user(p, &ss, sizeof(ss))) + if (copy_to_sockptr_offset(optval, ss_offset, + &ss, sizeof(ss))) return -EFAULT; + ss_offset += sizeof(ss); } return 0; } @@ -2624,10 +2730,10 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, /* * check if a multicast source filter allows delivery for a given <src,dst,intf> */ -int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, +int ip_mc_sf_allow(const struct sock *sk, __be32 loc_addr, __be32 rmt_addr, int dif, int sdif) { - struct inet_sock *inet = inet_sk(sk); + const struct inet_sock *inet = inet_sk(sk); struct ip_mc_socklist *pmc; struct ip_sf_socklist *psl; int i; @@ -2644,7 +2750,7 @@ int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, (sdif && pmc->multi.imr_ifindex == sdif))) break; } - ret = inet->mc_all; + ret = inet_test_bit(MC_ALL, sk); if (!pmc) goto unlock; psl = rcu_dereference(pmc->sflist); @@ -2836,7 +2942,7 @@ static int igmp_mc_seq_show(struct seq_file *seq, void *v) seq_puts(seq, "Idx\tDevice : Count Querier\tGroup Users Timer\tReporter\n"); else { - struct ip_mc_list *im = (struct ip_mc_list *)v; + struct ip_mc_list *im = v; struct igmp_mc_iter_state *state = igmp_mc_seq_private(seq); char *querier; long delta; @@ -2929,8 +3035,6 @@ static struct ip_sf_list *igmp_mcf_get_next(struct seq_file *seq, struct ip_sf_l continue; state->im = rcu_dereference(state->idev->mc_list); } - if (!state->im) - break; spin_lock_bh(&state->im->lock); psf = state->im->sources; } @@ -2980,7 +3084,7 @@ static void igmp_mcf_seq_stop(struct seq_file *seq, void *v) static int igmp_mcf_seq_show(struct seq_file *seq, void *v) { - struct ip_sf_list *psf = (struct ip_sf_list *)v; + struct ip_sf_list *psf = v; struct igmp_mcf_iter_state *state = igmp_mcf_seq_private(seq); if (v == SEQ_START_TOKEN) { diff --git a/net/ipv4/igmp_internal.h b/net/ipv4/igmp_internal.h new file mode 100644 index 000000000000..0a1bcc8ec8e1 --- /dev/null +++ b/net/ipv4/igmp_internal.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +#ifndef _LINUX_IGMP_INTERNAL_H +#define _LINUX_IGMP_INTERNAL_H + +struct inet_fill_args { + u32 portid; + u32 seq; + int event; + unsigned int flags; + int netnsid; + int ifindex; +}; + +int inet_fill_ifmcaddr(struct sk_buff *skb, struct net_device *dev, + const struct ip_mc_list *im, + struct inet_fill_args *args); +#endif diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index fc2a985f6064..97d57c52b9ad 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -117,29 +117,135 @@ bool inet_rcv_saddr_any(const struct sock *sk) return !sk->sk_rcv_saddr; } -void inet_get_local_port_range(struct net *net, int *low, int *high) +/** + * inet_sk_get_local_port_range - fetch ephemeral ports range + * @sk: socket + * @low: pointer to low port + * @high: pointer to high port + * + * Fetch netns port range (/proc/sys/net/ipv4/ip_local_port_range) + * Range can be overridden if socket got IP_LOCAL_PORT_RANGE option. + * Returns true if IP_LOCAL_PORT_RANGE was set on this socket. + */ +bool inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high) +{ + int lo, hi, sk_lo, sk_hi; + bool local_range = false; + u32 sk_range; + + inet_get_local_port_range(sock_net(sk), &lo, &hi); + + sk_range = READ_ONCE(inet_sk(sk)->local_port_range); + if (unlikely(sk_range)) { + sk_lo = sk_range & 0xffff; + sk_hi = sk_range >> 16; + + if (lo <= sk_lo && sk_lo <= hi) + lo = sk_lo; + if (lo <= sk_hi && sk_hi <= hi) + hi = sk_hi; + local_range = true; + } + + *low = lo; + *high = hi; + return local_range; +} +EXPORT_SYMBOL(inet_sk_get_local_port_range); + +static bool inet_use_bhash2_on_bind(const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + return false; + + if (!ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return true; + } +#endif + return sk->sk_rcv_saddr != htonl(INADDR_ANY); +} + +static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2, + kuid_t uid, bool relax, + bool reuseport_cb_ok, bool reuseport_ok) { - unsigned int seq; + int bound_dev_if2; + + if (sk == sk2) + return false; - do { - seq = read_seqbegin(&net->ipv4.ip_local_ports.lock); + bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if); - *low = net->ipv4.ip_local_ports.range[0]; - *high = net->ipv4.ip_local_ports.range[1]; - } while (read_seqretry(&net->ipv4.ip_local_ports.lock, seq)); + if (!sk->sk_bound_dev_if || !bound_dev_if2 || + sk->sk_bound_dev_if == bound_dev_if2) { + if (sk->sk_reuse && sk2->sk_reuse && + sk2->sk_state != TCP_LISTEN) { + if (!relax || (!reuseport_ok && sk->sk_reuseport && + sk2->sk_reuseport && reuseport_cb_ok && + (sk2->sk_state == TCP_TIME_WAIT || + uid_eq(uid, sk_uid(sk2))))) + return true; + } else if (!reuseport_ok || !sk->sk_reuseport || + !sk2->sk_reuseport || !reuseport_cb_ok || + (sk2->sk_state != TCP_TIME_WAIT && + !uid_eq(uid, sk_uid(sk2)))) { + return true; + } + } + return false; } -EXPORT_SYMBOL(inet_get_local_port_range); +static bool __inet_bhash2_conflict(const struct sock *sk, struct sock *sk2, + kuid_t uid, bool relax, + bool reuseport_cb_ok, bool reuseport_ok) +{ + if (ipv6_only_sock(sk2)) { + if (sk->sk_family == AF_INET) + return false; + +#if IS_ENABLED(CONFIG_IPV6) + if (ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return false; +#endif + } + + return inet_bind_conflict(sk, sk2, uid, relax, + reuseport_cb_ok, reuseport_ok); +} + +static bool inet_bhash2_conflict(const struct sock *sk, + const struct inet_bind2_bucket *tb2, + kuid_t uid, + bool relax, bool reuseport_cb_ok, + bool reuseport_ok) +{ + struct sock *sk2; + + sk_for_each_bound(sk2, &tb2->owners) { + if (__inet_bhash2_conflict(sk, sk2, uid, relax, + reuseport_cb_ok, reuseport_ok)) + return true; + } + + return false; +} + +#define sk_for_each_bound_bhash(__sk, __tb2, __tb) \ + hlist_for_each_entry(__tb2, &(__tb)->bhash2, bhash_node) \ + sk_for_each_bound((__sk), &(__tb2)->owners) + +/* This should be called only when the tb and tb2 hashbuckets' locks are held */ static int inet_csk_bind_conflict(const struct sock *sk, const struct inet_bind_bucket *tb, + const struct inet_bind2_bucket *tb2, /* may be null */ bool relax, bool reuseport_ok) { - struct sock *sk2; - bool reuseport_cb_ok; - bool reuse = sk->sk_reuse; - bool reuseport = !!sk->sk_reuseport; struct sock_reuseport *reuseport_cb; - kuid_t uid = sock_i_uid((struct sock *)sk); + kuid_t uid = sk_uid(sk); + bool reuseport_cb_ok; + struct sock *sk2; rcu_read_lock(); reuseport_cb = rcu_dereference(sk->sk_reuseport_cb); @@ -147,63 +253,97 @@ static int inet_csk_bind_conflict(const struct sock *sk, reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks); rcu_read_unlock(); - /* - * Unlike other sk lookup places we do not check + /* Conflicts with an existing IPV6_ADDR_ANY (if ipv6) or INADDR_ANY (if + * ipv4) should have been checked already. We need to do these two + * checks separately because their spinlocks have to be acquired/released + * independently of each other, to prevent possible deadlocks + */ + if (inet_use_bhash2_on_bind(sk)) + return tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, + reuseport_cb_ok, reuseport_ok); + + /* Unlike other sk lookup places we do not check * for sk_net here, since _all_ the socks listed - * in tb->owners list belong to the same net - the - * one this bucket belongs to. + * in tb->owners and tb2->owners list belong + * to the same net - the one this bucket belongs to. */ + sk_for_each_bound_bhash(sk2, tb2, tb) { + if (!inet_bind_conflict(sk, sk2, uid, relax, reuseport_cb_ok, reuseport_ok)) + continue; - sk_for_each_bound(sk2, &tb->owners) { - if (sk != sk2 && - (!sk->sk_bound_dev_if || - !sk2->sk_bound_dev_if || - sk->sk_bound_dev_if == sk2->sk_bound_dev_if)) { - if (reuse && sk2->sk_reuse && - sk2->sk_state != TCP_LISTEN) { - if ((!relax || - (!reuseport_ok && - reuseport && sk2->sk_reuseport && - reuseport_cb_ok && - (sk2->sk_state == TCP_TIME_WAIT || - uid_eq(uid, sock_i_uid(sk2))))) && - inet_rcv_saddr_equal(sk, sk2, true)) - break; - } else if (!reuseport_ok || - !reuseport || !sk2->sk_reuseport || - !reuseport_cb_ok || - (sk2->sk_state != TCP_TIME_WAIT && - !uid_eq(uid, sock_i_uid(sk2)))) { - if (inet_rcv_saddr_equal(sk, sk2, true)) - break; - } - } + if (inet_rcv_saddr_equal(sk, sk2, true)) + return true; } - return sk2 != NULL; + + return false; +} + +/* Determine if there is a bind conflict with an existing IPV6_ADDR_ANY (if ipv6) or + * INADDR_ANY (if ipv4) socket. + * + * Caller must hold bhash hashbucket lock with local bh disabled, to protect + * against concurrent binds on the port for addr any + */ +static bool inet_bhash2_addr_any_conflict(const struct sock *sk, int port, int l3mdev, + bool relax, bool reuseport_ok) +{ + const struct net *net = sock_net(sk); + struct sock_reuseport *reuseport_cb; + struct inet_bind_hashbucket *head2; + struct inet_bind2_bucket *tb2; + kuid_t uid = sk_uid(sk); + bool conflict = false; + bool reuseport_cb_ok; + + rcu_read_lock(); + reuseport_cb = rcu_dereference(sk->sk_reuseport_cb); + /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */ + reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks); + rcu_read_unlock(); + + head2 = inet_bhash2_addr_any_hashbucket(sk, net, port); + + spin_lock(&head2->lock); + + inet_bind_bucket_for_each(tb2, &head2->chain) { + if (!inet_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk)) + continue; + + if (!inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok, reuseport_ok)) + continue; + + conflict = true; + break; + } + + spin_unlock(&head2->lock); + + return conflict; } /* * Find an open port number for the socket. Returns with the - * inet_bind_hashbucket lock held. + * inet_bind_hashbucket locks held if successful. */ static struct inet_bind_hashbucket * -inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret) +inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret, + struct inet_bind2_bucket **tb2_ret, + struct inet_bind_hashbucket **head2_ret, int *port_ret) { - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; - int port = 0; - struct inet_bind_hashbucket *head; + struct inet_hashinfo *hinfo = tcp_get_hashinfo(sk); + int i, low, high, attempt_half, port, l3mdev; + struct inet_bind_hashbucket *head, *head2; struct net *net = sock_net(sk); - bool relax = false; - int i, low, high, attempt_half; + struct inet_bind2_bucket *tb2; struct inet_bind_bucket *tb; u32 remaining, offset; - int l3mdev; + bool relax = false; l3mdev = inet_sk_bound_l3mdev(sk); ports_exhausted: attempt_half = (sk->sk_reuse == SK_CAN_REUSE) ? 1 : 0; other_half_scan: - inet_get_local_port_range(net, &low, &high); + inet_sk_get_local_port_range(sk, &low, &high); high++; /* [32768, 60999] -> [32768, 61000[ */ if (high - low < 4) attempt_half = 0; @@ -219,7 +359,7 @@ other_half_scan: if (likely(remaining > 1)) remaining &= ~1U; - offset = prandom_u32() % remaining; + offset = get_random_u32_below(remaining); /* __inet_hash_connect() favors ports having @low parity * We do the opposite to not pollute connect() users. */ @@ -235,11 +375,20 @@ other_parity_scan: head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)]; spin_lock_bh(&head->lock); + if (inet_use_bhash2_on_bind(sk)) { + if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, relax, false)) + goto next_port; + } + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); inet_bind_bucket_for_each(tb, &head->chain) - if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev && - tb->port == port) { - if (!inet_csk_bind_conflict(sk, tb, relax, false)) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) { + if (!inet_csk_bind_conflict(sk, tb, tb2, + relax, false)) goto success; + spin_unlock(&head2->lock); goto next_port; } tb = NULL; @@ -259,7 +408,7 @@ next_port: goto other_half_scan; } - if (net->ipv4.sysctl_ip_autobind_reuse && !relax) { + if (READ_ONCE(net->ipv4.sysctl_ip_autobind_reuse) && !relax) { /* We still have a chance to connect to different destinations */ relax = true; goto ports_exhausted; @@ -268,21 +417,21 @@ next_port: success: *port_ret = port; *tb_ret = tb; + *tb2_ret = tb2; + *head2_ret = head2; return head; } static inline int sk_reuseport_match(struct inet_bind_bucket *tb, - struct sock *sk) + const struct sock *sk) { - kuid_t uid = sock_i_uid(sk); - if (tb->fastreuseport <= 0) return 0; if (!sk->sk_reuseport) return 0; if (rcu_access_pointer(sk->sk_reuseport_cb)) return 0; - if (!uid_eq(tb->fastuid, uid)) + if (!uid_eq(tb->fastuid, sk_uid(sk))) return 0; /* We only need to check the rcv_saddr if this tb was once marked * without fastreuseport and then was reset, as we can only know that @@ -304,17 +453,17 @@ static inline int sk_reuseport_match(struct inet_bind_bucket *tb, ipv6_only_sock(sk), true, false); } -void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, - struct sock *sk) +void inet_csk_update_fastreuse(const struct sock *sk, + struct inet_bind_bucket *tb, + struct inet_bind2_bucket *tb2) { - kuid_t uid = sock_i_uid(sk); bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; - if (hlist_empty(&tb->owners)) { + if (hlist_empty(&tb->bhash2)) { tb->fastreuse = reuse; if (sk->sk_reuseport) { tb->fastreuseport = FASTREUSEPORT_ANY; - tb->fastuid = uid; + tb->fastuid = sk_uid(sk); tb->fast_rcv_saddr = sk->sk_rcv_saddr; tb->fast_ipv6_only = ipv6_only_sock(sk); tb->fast_sk_family = sk->sk_family; @@ -341,7 +490,7 @@ void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, */ if (!sk_reuseport_match(tb, sk)) { tb->fastreuseport = FASTREUSEPORT_STRICT; - tb->fastuid = uid; + tb->fastuid = sk_uid(sk); tb->fast_rcv_saddr = sk->sk_rcv_saddr; tb->fast_ipv6_only = ipv6_only_sock(sk); tb->fast_sk_family = sk->sk_family; @@ -353,6 +502,9 @@ void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, tb->fastreuseport = 0; } } + + tb2->fastreuse = tb->fastreuse; + tb2->fastreuseport = tb->fastreuseport; } /* Obtain a reference to a local port for the given sock, @@ -362,55 +514,95 @@ void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, int inet_csk_get_port(struct sock *sk, unsigned short snum) { bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; - int ret = 1, port = snum; - struct inet_bind_hashbucket *head; - struct net *net = sock_net(sk); + bool found_port = false, check_bind_conflict = true; + bool bhash_created = false, bhash2_created = false; + struct inet_hashinfo *hinfo = tcp_get_hashinfo(sk); + int ret = -EADDRINUSE, port = snum, l3mdev; + struct inet_bind_hashbucket *head, *head2; + struct inet_bind2_bucket *tb2 = NULL; struct inet_bind_bucket *tb = NULL; - int l3mdev; + bool head2_lock_acquired = false; + struct net *net = sock_net(sk); l3mdev = inet_sk_bound_l3mdev(sk); if (!port) { - head = inet_csk_find_open_port(sk, &tb, &port); + head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port); if (!head) return ret; + + head2_lock_acquired = true; + + if (tb && tb2) + goto success; + found_port = true; + } else { + head = &hinfo->bhash[inet_bhashfn(net, port, + hinfo->bhash_size)]; + spin_lock_bh(&head->lock); + inet_bind_bucket_for_each(tb, &head->chain) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) + break; + } + + if (!tb) { + tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net, + head, port, l3mdev); if (!tb) - goto tb_not_found; - goto success; + goto fail_unlock; + bhash_created = true; } - head = &hinfo->bhash[inet_bhashfn(net, port, - hinfo->bhash_size)]; - spin_lock_bh(&head->lock); - inet_bind_bucket_for_each(tb, &head->chain) - if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev && - tb->port == port) - goto tb_found; -tb_not_found: - tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, - net, head, port, l3mdev); - if (!tb) - goto fail_unlock; -tb_found: - if (!hlist_empty(&tb->owners)) { - if (sk->sk_reuse == SK_FORCE_REUSE) - goto success; - if ((tb->fastreuse > 0 && reuse) || - sk_reuseport_match(tb, sk)) - goto success; - if (inet_csk_bind_conflict(sk, tb, true, true)) + if (!found_port) { + if (!hlist_empty(&tb->bhash2)) { + if (sk->sk_reuse == SK_FORCE_REUSE || + (tb->fastreuse > 0 && reuse) || + sk_reuseport_match(tb, sk)) + check_bind_conflict = false; + } + + if (check_bind_conflict && inet_use_bhash2_on_bind(sk)) { + if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, true, true)) + goto fail_unlock; + } + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + head2_lock_acquired = true; + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + } + + if (!tb2) { + tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, + net, head2, tb, sk); + if (!tb2) + goto fail_unlock; + bhash2_created = true; + } + + if (!found_port && check_bind_conflict) { + if (inet_csk_bind_conflict(sk, tb, tb2, true, true)) goto fail_unlock; } + success: - inet_csk_update_fastreuse(tb, sk); + inet_csk_update_fastreuse(sk, tb, tb2); if (!inet_csk(sk)->icsk_bind_hash) - inet_bind_hash(sk, tb, port); + inet_bind_hash(sk, tb, tb2, port); WARN_ON(inet_csk(sk)->icsk_bind_hash != tb); + WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2); ret = 0; fail_unlock: + if (ret) { + if (bhash2_created) + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, tb2); + if (bhash_created) + inet_bind_bucket_destroy(tb); + } + if (head2_lock_acquired) + spin_unlock(&head2->lock); spin_unlock_bh(&head->lock); return ret; } @@ -468,7 +660,7 @@ static int inet_csk_wait_for_connect(struct sock *sk, long timeo) /* * This will accept the next outstanding connection. */ -struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern) +struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg) { struct inet_connection_sock *icsk = inet_csk(sk); struct request_sock_queue *queue = &icsk->icsk_accept_queue; @@ -487,7 +679,7 @@ struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern) /* Find already established connection */ if (reqsk_queue_empty(queue)) { - long timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + long timeo = sock_rcvtimeo(sk, arg->flags & O_NONBLOCK); /* If this is a non blocking socket don't sleep */ error = -EAGAIN; @@ -499,6 +691,7 @@ struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern) goto out_err; } req = reqsk_queue_remove(queue, sk); + arg->is_empty = reqsk_queue_empty(queue); newsk = req->sk; if (sk->sk_protocol == IPPROTO_TCP && @@ -517,36 +710,18 @@ struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern) spin_unlock_bh(&queue->fastopenq.lock); } -out: release_sock(sk); - if (newsk && mem_cgroup_sockets_enabled) { - int amt; - - /* atomically get the memory usage, set and charge the - * newsk->sk_memcg. - */ - lock_sock(newsk); - /* The socket has not been accepted yet, no need to look at - * newsk->sk_wmem_queued. - */ - amt = sk_mem_pages(newsk->sk_forward_alloc + - atomic_read(&newsk->sk_rmem_alloc)); - mem_cgroup_sk_alloc(newsk); - if (newsk->sk_memcg && amt) - mem_cgroup_charge_skmem(newsk->sk_memcg, amt, - GFP_KERNEL | __GFP_NOFAIL); - - release_sock(newsk); - } if (req) reqsk_put(req); + + inet_init_csk_locks(newsk); return newsk; + out_err: - newsk = NULL; - req = NULL; - *err = error; - goto out; + release_sock(sk); + arg->err = error; + return NULL; } EXPORT_SYMBOL(inet_csk_accept); @@ -562,36 +737,38 @@ void inet_csk_init_xmit_timers(struct sock *sk, { struct inet_connection_sock *icsk = inet_csk(sk); - timer_setup(&icsk->icsk_retransmit_timer, retransmit_handler, 0); + timer_setup(&sk->tcp_retransmit_timer, retransmit_handler, 0); timer_setup(&icsk->icsk_delack_timer, delack_handler, 0); - timer_setup(&sk->sk_timer, keepalive_handler, 0); + timer_setup(&icsk->icsk_keepalive_timer, keepalive_handler, 0); icsk->icsk_pending = icsk->icsk_ack.pending = 0; } -EXPORT_SYMBOL(inet_csk_init_xmit_timers); void inet_csk_clear_xmit_timers(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); - icsk->icsk_pending = icsk->icsk_ack.pending = 0; + smp_store_release(&icsk->icsk_pending, 0); + smp_store_release(&icsk->icsk_ack.pending, 0); - sk_stop_timer(sk, &icsk->icsk_retransmit_timer); + sk_stop_timer(sk, &sk->tcp_retransmit_timer); sk_stop_timer(sk, &icsk->icsk_delack_timer); - sk_stop_timer(sk, &sk->sk_timer); + sk_stop_timer(sk, &icsk->icsk_keepalive_timer); } -EXPORT_SYMBOL(inet_csk_clear_xmit_timers); -void inet_csk_delete_keepalive_timer(struct sock *sk) +void inet_csk_clear_xmit_timers_sync(struct sock *sk) { - sk_stop_timer(sk, &sk->sk_timer); -} -EXPORT_SYMBOL(inet_csk_delete_keepalive_timer); + struct inet_connection_sock *icsk = inet_csk(sk); -void inet_csk_reset_keepalive_timer(struct sock *sk, unsigned long len) -{ - sk_reset_timer(sk, &sk->sk_timer, jiffies + len); + /* ongoing timer handlers need to acquire socket lock. */ + sock_not_owned_by_me(sk); + + smp_store_release(&icsk->icsk_pending, 0); + smp_store_release(&icsk->icsk_ack.pending, 0); + + sk_stop_timer_sync(sk, &sk->tcp_retransmit_timer); + sk_stop_timer_sync(sk, &icsk->icsk_delack_timer); + sk_stop_timer_sync(sk, &icsk->icsk_keepalive_timer); } -EXPORT_SYMBOL(inet_csk_reset_keepalive_timer); struct dst_entry *inet_csk_route_req(const struct sock *sk, struct flowi4 *fl4, @@ -606,11 +783,11 @@ struct dst_entry *inet_csk_route_req(const struct sock *sk, opt = rcu_dereference(ireq->ireq_opt); flowi4_init_output(fl4, ireq->ir_iif, ireq->ir_mark, - RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, + ip_sock_rt_tos(sk), ip_sock_rt_scope(sk), sk->sk_protocol, inet_sk_flowi_flags(sk), (opt && opt->opt.srr) ? opt->opt.faddr : ireq->ir_rmt_addr, ireq->ir_loc_addr, ireq->ir_rmt_port, - htons(ireq->ir_num), sk->sk_uid); + htons(ireq->ir_num), sk_uid(sk)); security_req_classify_flow(req, flowi4_to_flowi_common(fl4)); rt = ip_route_output_flow(net, fl4, sk); if (IS_ERR(rt)) @@ -627,7 +804,6 @@ no_route: __IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); return NULL; } -EXPORT_SYMBOL_GPL(inet_csk_route_req); struct dst_entry *inet_csk_route_child_sock(const struct sock *sk, struct sock *newsk, @@ -644,11 +820,11 @@ struct dst_entry *inet_csk_route_child_sock(const struct sock *sk, fl4 = &newinet->cork.fl.u.ip4; flowi4_init_output(fl4, ireq->ir_iif, ireq->ir_mark, - RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, + ip_sock_rt_tos(sk), ip_sock_rt_scope(sk), sk->sk_protocol, inet_sk_flowi_flags(sk), (opt && opt->opt.srr) ? opt->opt.faddr : ireq->ir_rmt_addr, ireq->ir_loc_addr, ireq->ir_rmt_port, - htons(ireq->ir_num), sk->sk_uid); + htons(ireq->ir_num), sk_uid(sk)); security_req_classify_flow(req, flowi4_to_flowi_common(fl4)); rt = ip_route_output_flow(net, fl4, sk); if (IS_ERR(rt)) @@ -686,15 +862,61 @@ static void syn_ack_recalc(struct request_sock *req, req->num_timeout >= rskq_defer_accept - 1; } -int inet_rtx_syn_ack(const struct sock *parent, struct request_sock *req) +static struct request_sock * +reqsk_alloc_noprof(const struct request_sock_ops *ops, struct sock *sk_listener, + bool attach_listener) { - int err = req->rsk_ops->rtx_syn_ack(parent, req); + struct request_sock *req; - if (!err) - req->num_retrans++; - return err; + req = kmem_cache_alloc_noprof(ops->slab, GFP_ATOMIC | __GFP_NOWARN); + if (!req) + return NULL; + req->rsk_listener = NULL; + if (attach_listener) { + if (unlikely(!refcount_inc_not_zero(&sk_listener->sk_refcnt))) { + kmem_cache_free(ops->slab, req); + return NULL; + } + req->rsk_listener = sk_listener; + } + req->rsk_ops = ops; + req_to_sk(req)->sk_prot = sk_listener->sk_prot; + sk_node_init(&req_to_sk(req)->sk_node); + sk_tx_queue_clear(req_to_sk(req)); + req->saved_syn = NULL; + req->syncookie = 0; + req->num_timeout = 0; + req->num_retrans = 0; + req->sk = NULL; + refcount_set(&req->rsk_refcnt, 0); + + return req; } -EXPORT_SYMBOL(inet_rtx_syn_ack); +#define reqsk_alloc(...) alloc_hooks(reqsk_alloc_noprof(__VA_ARGS__)) + +struct request_sock *inet_reqsk_alloc(const struct request_sock_ops *ops, + struct sock *sk_listener, + bool attach_listener) +{ + struct request_sock *req = reqsk_alloc(ops, sk_listener, + attach_listener); + + if (req) { + struct inet_request_sock *ireq = inet_rsk(req); + + ireq->ireq_opt = NULL; +#if IS_ENABLED(CONFIG_IPV6) + ireq->pktopts = NULL; +#endif + atomic64_set(&ireq->ir_cookie, 0); + ireq->ireq_state = TCP_NEW_SYN_RECV; + write_pnet(&ireq->ireq_net, sock_net(sk_listener)); + ireq->ireq_family = sk_listener->sk_family; + } + + return req; +} +EXPORT_SYMBOL(inet_reqsk_alloc); static struct request_sock *inet_reqsk_clone(struct request_sock *req, struct sock *sk) @@ -716,8 +938,9 @@ static struct request_sock *inet_reqsk_clone(struct request_sock *req, memcpy(nreq_sk, req_sk, offsetof(struct sock, sk_dontcopy_begin)); - memcpy(&nreq_sk->sk_dontcopy_end, &req_sk->sk_dontcopy_end, - req->rsk_ops->obj_size - offsetof(struct sock, sk_dontcopy_end)); + unsafe_memcpy(&nreq_sk->sk_dontcopy_end, &req_sk->sk_dontcopy_end, + req->rsk_ops->obj_size - offsetof(struct sock, sk_dontcopy_end), + /* alloc is larger than struct, see above */); sk_node_init(&nreq_sk->sk_node); nreq_sk->sk_tx_queue_mapping = req_sk->sk_tx_queue_mapping; @@ -759,43 +982,54 @@ static void reqsk_migrate_reset(struct request_sock *req) /* return true if req was found in the ehash table */ static bool reqsk_queue_unlink(struct request_sock *req) { - struct inet_hashinfo *hashinfo = req_to_sk(req)->sk_prot->h.hashinfo; + struct sock *sk = req_to_sk(req); bool found = false; - if (sk_hashed(req_to_sk(req))) { - spinlock_t *lock = inet_ehash_lockp(hashinfo, req->rsk_hash); + if (sk_hashed(sk)) { + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); + spinlock_t *lock; + lock = inet_ehash_lockp(hashinfo, req->rsk_hash); spin_lock(lock); - found = __sk_nulls_del_node_init_rcu(req_to_sk(req)); + found = __sk_nulls_del_node_init_rcu(sk); spin_unlock(lock); } - if (timer_pending(&req->rsk_timer) && del_timer_sync(&req->rsk_timer)) - reqsk_put(req); + return found; } -bool inet_csk_reqsk_queue_drop(struct sock *sk, struct request_sock *req) +static bool __inet_csk_reqsk_queue_drop(struct sock *sk, + struct request_sock *req, + bool from_timer) { bool unlinked = reqsk_queue_unlink(req); + if (!from_timer && timer_delete_sync(&req->rsk_timer)) + reqsk_put(req); + if (unlinked) { reqsk_queue_removed(&inet_csk(sk)->icsk_accept_queue, req); reqsk_put(req); } + return unlinked; } -EXPORT_SYMBOL(inet_csk_reqsk_queue_drop); + +bool inet_csk_reqsk_queue_drop(struct sock *sk, struct request_sock *req) +{ + return __inet_csk_reqsk_queue_drop(sk, req, false); +} void inet_csk_reqsk_queue_drop_and_put(struct sock *sk, struct request_sock *req) { inet_csk_reqsk_queue_drop(sk, req); reqsk_put(req); } -EXPORT_SYMBOL(inet_csk_reqsk_queue_drop_and_put); +EXPORT_IPV6_MOD(inet_csk_reqsk_queue_drop_and_put); static void reqsk_timer_handler(struct timer_list *t) { - struct request_sock *req = from_timer(req, t, rsk_timer); + struct request_sock *req = timer_container_of(req, t, rsk_timer); struct request_sock *nreq = NULL, *oreq = req; struct sock *sk_listener = req->rsk_listener; struct inet_connection_sock *icsk; @@ -829,7 +1063,8 @@ static void reqsk_timer_handler(struct timer_list *t) icsk = inet_csk(sk_listener); net = sock_net(sk_listener); - max_syn_ack_retries = icsk->icsk_syn_retries ? : net->ipv4.sysctl_tcp_synack_retries; + max_syn_ack_retries = READ_ONCE(icsk->icsk_syn_retries) ? : + READ_ONCE(net->ipv4.sysctl_tcp_synack_retries); /* Normally all the openreqs are young and become mature * (i.e. converted to established socket) for first timeout. * If synack was not acknowledged for 1 second, it means @@ -859,26 +1094,25 @@ static void reqsk_timer_handler(struct timer_list *t) young <<= 1; } } + syn_ack_recalc(req, max_syn_ack_retries, READ_ONCE(queue->rskq_defer_accept), &expire, &resend); - req->rsk_ops->syn_ack_timeout(req); + tcp_syn_ack_timeout(req); + if (!expire && (!resend || - !inet_rtx_syn_ack(sk_listener, req) || + !tcp_rtx_synack(sk_listener, req) || inet_rsk(req)->acked)) { - unsigned long timeo; - if (req->num_timeout++ == 0) atomic_dec(&queue->young); - timeo = min(TCP_TIMEOUT_INIT << req->num_timeout, TCP_RTO_MAX); - mod_timer(&req->rsk_timer, jiffies + timeo); + mod_timer(&req->rsk_timer, jiffies + tcp_reqsk_timeout(req)); if (!nreq) return; if (!inet_ehash_insert(req_to_sk(nreq), req_to_sk(oreq), NULL)) { /* delete timer */ - inet_csk_reqsk_queue_drop(sk_listener, nreq); + __inet_csk_reqsk_queue_drop(sk_listener, nreq, true); goto no_ownership; } @@ -904,30 +1138,38 @@ no_ownership: } drop: - inet_csk_reqsk_queue_drop_and_put(oreq->rsk_listener, oreq); + __inet_csk_reqsk_queue_drop(sk_listener, oreq, true); + reqsk_put(oreq); } -static void reqsk_queue_hash_req(struct request_sock *req, - unsigned long timeout) +static bool reqsk_queue_hash_req(struct request_sock *req) { + bool found_dup_sk = false; + + if (!inet_ehash_insert(req_to_sk(req), NULL, &found_dup_sk)) + return false; + + /* The timer needs to be setup after a successful insertion. */ + req->timeout = tcp_timeout_init((struct sock *)req); timer_setup(&req->rsk_timer, reqsk_timer_handler, TIMER_PINNED); - mod_timer(&req->rsk_timer, jiffies + timeout); + mod_timer(&req->rsk_timer, jiffies + req->timeout); - inet_ehash_insert(req_to_sk(req), NULL, NULL); /* before letting lookups find us, make sure all req fields * are committed to memory and refcnt initialized. */ smp_wmb(); refcount_set(&req->rsk_refcnt, 2 + 1); + return true; } -void inet_csk_reqsk_queue_hash_add(struct sock *sk, struct request_sock *req, - unsigned long timeout) +bool inet_csk_reqsk_queue_hash_add(struct sock *sk, struct request_sock *req) { - reqsk_queue_hash_req(req, timeout); + if (!reqsk_queue_hash_req(req)) + return false; + inet_csk_reqsk_queue_added(sk); + return true; } -EXPORT_SYMBOL_GPL(inet_csk_reqsk_queue_hash_add); static void inet_clone_ulp(const struct request_sock *req, struct sock *newsk, const gfp_t priority) @@ -937,8 +1179,7 @@ static void inet_clone_ulp(const struct request_sock *req, struct sock *newsk, if (!icsk->icsk_ulp_ops) return; - if (icsk->icsk_ulp_ops->clone) - icsk->icsk_ulp_ops->clone(req, newsk, priority); + icsk->icsk_ulp_ops->clone(req, newsk, priority); } /** @@ -954,41 +1195,61 @@ struct sock *inet_csk_clone_lock(const struct sock *sk, const gfp_t priority) { struct sock *newsk = sk_clone_lock(sk, priority); + struct inet_connection_sock *newicsk; + struct inet_request_sock *ireq; + struct inet_sock *newinet; + + if (!newsk) + return NULL; + + newicsk = inet_csk(newsk); + newinet = inet_sk(newsk); + ireq = inet_rsk(req); - if (newsk) { - struct inet_connection_sock *newicsk = inet_csk(newsk); + newicsk->icsk_bind_hash = NULL; + newicsk->icsk_bind2_hash = NULL; - inet_sk_set_state(newsk, TCP_SYN_RECV); - newicsk->icsk_bind_hash = NULL; + newinet->inet_dport = ireq->ir_rmt_port; + newinet->inet_num = ireq->ir_num; + newinet->inet_sport = htons(ireq->ir_num); - inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port; - inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num; - inet_sk(newsk)->inet_sport = htons(inet_rsk(req)->ir_num); + newsk->sk_bound_dev_if = ireq->ir_iif; - /* listeners have SOCK_RCU_FREE, not the children */ - sock_reset_flag(newsk, SOCK_RCU_FREE); + newsk->sk_daddr = ireq->ir_rmt_addr; + newsk->sk_rcv_saddr = ireq->ir_loc_addr; + newinet->inet_saddr = ireq->ir_loc_addr; - inet_sk(newsk)->mc_list = NULL; +#if IS_ENABLED(CONFIG_IPV6) + newsk->sk_v6_daddr = ireq->ir_v6_rmt_addr; + newsk->sk_v6_rcv_saddr = ireq->ir_v6_loc_addr; +#endif - newsk->sk_mark = inet_rsk(req)->ir_mark; - atomic64_set(&newsk->sk_cookie, - atomic64_read(&inet_rsk(req)->ir_cookie)); + /* listeners have SOCK_RCU_FREE, not the children */ + sock_reset_flag(newsk, SOCK_RCU_FREE); - newicsk->icsk_retransmits = 0; - newicsk->icsk_backoff = 0; - newicsk->icsk_probes_out = 0; - newicsk->icsk_probes_tstamp = 0; + inet_sk(newsk)->mc_list = NULL; - /* Deinitialize accept_queue to trap illegal accesses. */ - memset(&newicsk->icsk_accept_queue, 0, sizeof(newicsk->icsk_accept_queue)); + newsk->sk_mark = inet_rsk(req)->ir_mark; + atomic64_set(&newsk->sk_cookie, + atomic64_read(&inet_rsk(req)->ir_cookie)); - inet_clone_ulp(req, newsk, priority); + newicsk->icsk_retransmits = 0; + newicsk->icsk_backoff = 0; + newicsk->icsk_probes_out = 0; + newicsk->icsk_probes_tstamp = 0; + + /* Deinitialize accept_queue to trap illegal accesses. */ + memset(&newicsk->icsk_accept_queue, 0, + sizeof(newicsk->icsk_accept_queue)); + + inet_sk_set_state(newsk, TCP_SYN_RECV); + + inet_clone_ulp(req, newsk, priority); + + security_inet_csk_clone(newsk, req); - security_inet_csk_clone(newsk, req); - } return newsk; } -EXPORT_SYMBOL_GPL(inet_csk_clone_lock); /* * At this point, there should be no process reference to this @@ -1013,16 +1274,21 @@ void inet_csk_destroy_sock(struct sock *sk) xfrm_sk_free_policy(sk); - sk_refcnt_debug_release(sk); - - this_cpu_dec(*sk->sk_prot->orphan_count); + tcp_orphan_count_dec(); sock_put(sk); } EXPORT_SYMBOL(inet_csk_destroy_sock); +void inet_csk_prepare_for_destroy_sock(struct sock *sk) +{ + /* The below has to be done to allow calling inet_csk_destroy_sock */ + sock_set_flag(sk, SOCK_DEAD); + tcp_orphan_count_inc(); +} + /* This function allows to force a closure of a socket after the call to - * tcp/dccp_create_openreq_child(). + * tcp_create_openreq_child(). */ void inet_csk_prepare_forced_close(struct sock *sk) __releases(&sk->sk_lock.slock) @@ -1035,11 +1301,25 @@ void inet_csk_prepare_forced_close(struct sock *sk) } EXPORT_SYMBOL(inet_csk_prepare_forced_close); +static int inet_ulp_can_listen(const struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ulp_ops && !icsk->icsk_ulp_ops->clone) + return -EINVAL; + + return 0; +} + int inet_csk_listen_start(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); struct inet_sock *inet = inet_sk(sk); - int err = -EADDRINUSE; + int err; + + err = inet_ulp_can_listen(sk); + if (unlikely(err)) + return err; reqsk_queue_alloc(&icsk->icsk_accept_queue); @@ -1052,7 +1332,8 @@ int inet_csk_listen_start(struct sock *sk) * after validation is complete. */ inet_sk_state_store(sk, TCP_LISTEN); - if (!sk->sk_prot->get_port(sk, inet->inet_num)) { + err = sk->sk_prot->get_port(sk, inet->inet_num); + if (!err) { inet->inet_sport = htons(inet->inet_num); sk_dst_reset(sk); @@ -1065,7 +1346,6 @@ int inet_csk_listen_start(struct sock *sk) inet_sk_set_state(sk, TCP_CLOSE); return err; } -EXPORT_SYMBOL_GPL(inet_csk_listen_start); static void inet_child_forget(struct sock *sk, struct request_sock *req, struct sock *child) @@ -1074,7 +1354,7 @@ static void inet_child_forget(struct sock *sk, struct request_sock *req, sock_orphan(child); - this_cpu_inc(*sk->sk_prot->orphan_count); + tcp_orphan_count_inc(); if (sk->sk_protocol == IPPROTO_TCP && tcp_rsk(req)->tfo_listener) { BUG_ON(rcu_access_pointer(tcp_sk(child)->fastopen_rsk) != req); @@ -1160,7 +1440,6 @@ child_put: sock_put(child); return NULL; } -EXPORT_SYMBOL(inet_csk_complete_hashdance); /* * This routine closes sockets which have been at least partially @@ -1238,34 +1517,16 @@ skip_child_forget: } EXPORT_SYMBOL_GPL(inet_csk_listen_stop); -void inet_csk_addr2sockaddr(struct sock *sk, struct sockaddr *uaddr) -{ - struct sockaddr_in *sin = (struct sockaddr_in *)uaddr; - const struct inet_sock *inet = inet_sk(sk); - - sin->sin_family = AF_INET; - sin->sin_addr.s_addr = inet->inet_daddr; - sin->sin_port = inet->inet_dport; -} -EXPORT_SYMBOL_GPL(inet_csk_addr2sockaddr); - static struct dst_entry *inet_csk_rebuild_route(struct sock *sk, struct flowi *fl) { const struct inet_sock *inet = inet_sk(sk); - const struct ip_options_rcu *inet_opt; - __be32 daddr = inet->inet_daddr; struct flowi4 *fl4; struct rtable *rt; rcu_read_lock(); - inet_opt = rcu_dereference(inet->inet_opt); - if (inet_opt && inet_opt->opt.srr) - daddr = inet_opt->opt.faddr; fl4 = &fl->u.ip4; - rt = ip_route_output_ports(sock_net(sk), fl4, sk, daddr, - inet->inet_saddr, inet->inet_dport, - inet->inet_sport, sk->sk_protocol, - RT_CONN_FLAGS(sk), sk->sk_bound_dev_if); + inet_sk_init_flowi4(inet, fl4); + rt = ip_route_output_flow(sock_net(sk), fl4, sk); if (IS_ERR(rt)) rt = NULL; if (rt) @@ -1293,4 +1554,3 @@ struct dst_entry *inet_csk_update_pmtu(struct sock *sk, u32 mtu) out: return dst; } -EXPORT_SYMBOL_GPL(inet_csk_update_pmtu); diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 581b5b2d72a5..3f5b1418a610 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -20,9 +20,6 @@ #include <net/ipv6.h> #include <net/inet_common.h> #include <net/inet_connection_sock.h> -#include <net/inet_hashtables.h> -#include <net/inet_timewait_sock.h> -#include <net/inet6_hashtables.h> #include <net/bpf_sk_storage.h> #include <net/netlink.h> @@ -32,7 +29,7 @@ #include <linux/inet_diag.h> #include <linux/sock_diag.h> -static const struct inet_diag_handler **inet_diag_table; +static const struct inet_diag_handler __rcu **inet_diag_table; struct inet_diag_entry { const __be32 *saddr; @@ -48,77 +45,55 @@ struct inet_diag_entry { #endif }; -static DEFINE_MUTEX(inet_diag_table_mutex); - static const struct inet_diag_handler *inet_diag_lock_handler(int proto) { - if (proto < 0 || proto >= IPPROTO_MAX) { - mutex_lock(&inet_diag_table_mutex); - return ERR_PTR(-ENOENT); - } + const struct inet_diag_handler *handler; - if (!inet_diag_table[proto]) + if (proto < 0 || proto >= IPPROTO_MAX) + return NULL; + + if (!READ_ONCE(inet_diag_table[proto])) sock_load_diag_module(AF_INET, proto); - mutex_lock(&inet_diag_table_mutex); - if (!inet_diag_table[proto]) - return ERR_PTR(-ENOENT); + rcu_read_lock(); + handler = rcu_dereference(inet_diag_table[proto]); + if (handler && !try_module_get(handler->owner)) + handler = NULL; + rcu_read_unlock(); - return inet_diag_table[proto]; + return handler; } static void inet_diag_unlock_handler(const struct inet_diag_handler *handler) { - mutex_unlock(&inet_diag_table_mutex); + module_put(handler->owner); } void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk) { - r->idiag_family = sk->sk_family; + r->idiag_family = READ_ONCE(sk->sk_family); - r->id.idiag_sport = htons(sk->sk_num); - r->id.idiag_dport = sk->sk_dport; - r->id.idiag_if = sk->sk_bound_dev_if; + r->id.idiag_sport = htons(READ_ONCE(sk->sk_num)); + r->id.idiag_dport = READ_ONCE(sk->sk_dport); + r->id.idiag_if = READ_ONCE(sk->sk_bound_dev_if); sock_diag_save_cookie(sk, r->id.idiag_cookie); #if IS_ENABLED(CONFIG_IPV6) - if (sk->sk_family == AF_INET6) { - *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr; - *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr; + if (r->idiag_family == AF_INET6) { + data_race(*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr); + data_race(*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr); } else #endif { memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); - r->id.idiag_src[0] = sk->sk_rcv_saddr; - r->id.idiag_dst[0] = sk->sk_daddr; + r->id.idiag_src[0] = READ_ONCE(sk->sk_rcv_saddr); + r->id.idiag_dst[0] = READ_ONCE(sk->sk_daddr); } } EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill); -static size_t inet_sk_attr_size(struct sock *sk, - const struct inet_diag_req_v2 *req, - bool net_admin) -{ - const struct inet_diag_handler *handler; - size_t aux = 0; - - handler = inet_diag_table[req->sdiag_protocol]; - if (handler && handler->idiag_get_aux_size) - aux = handler->idiag_get_aux_size(sk, net_admin); - - return nla_total_size(sizeof(struct tcp_info)) - + nla_total_size(sizeof(struct inet_diag_msg)) - + inet_diag_msg_attrs_size() - + nla_total_size(sizeof(struct inet_diag_meminfo)) - + nla_total_size(SK_MEMINFO_VARS * sizeof(u32)) - + nla_total_size(TCP_CA_NAME_MAX) - + nla_total_size(sizeof(struct tcpvegas_info)) - + aux - + 64; -} - int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, struct inet_diag_msg *r, int ext, struct user_namespace *user_ns, @@ -134,7 +109,7 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, * hence this needs to be included regardless of socket family. */ if (ext & (1 << (INET_DIAG_TOS - 1))) - if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0) + if (nla_put_u8(skb, INET_DIAG_TOS, READ_ONCE(inet->tos)) < 0) goto errout; #if IS_ENABLED(CONFIG_IPV6) @@ -150,14 +125,14 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, } #endif - if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark)) + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, READ_ONCE(sk->sk_mark))) goto errout; if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) || ext & (1 << (INET_DIAG_TCLASS - 1))) { u32 classid = 0; -#ifdef CONFIG_SOCK_CGROUP_DATA +#ifdef CONFIG_CGROUP_NET_CLASSID classid = sock_cgroup_classid(&sk->sk_cgrp_data); #endif /* Fallback to socket priority if class id isn't set. @@ -165,7 +140,7 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, * For cgroup2 classid is always zero. */ if (!classid) - classid = sk->sk_priority; + classid = READ_ONCE(sk->sk_priority); if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid)) goto errout; @@ -178,21 +153,21 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, goto errout; #endif - r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk)); + r->idiag_uid = from_kuid_munged(user_ns, sk_uid(sk)); r->idiag_inode = sock_i_ino(sk); memset(&inet_sockopt, 0, sizeof(inet_sockopt)); - inet_sockopt.recverr = inet->recverr; - inet_sockopt.is_icsk = inet->is_icsk; - inet_sockopt.freebind = inet->freebind; - inet_sockopt.hdrincl = inet->hdrincl; - inet_sockopt.mc_loop = inet->mc_loop; - inet_sockopt.transparent = inet->transparent; - inet_sockopt.mc_all = inet->mc_all; - inet_sockopt.nodefrag = inet->nodefrag; - inet_sockopt.bind_address_no_port = inet->bind_address_no_port; - inet_sockopt.recverr_rfc4884 = inet->recverr_rfc4884; - inet_sockopt.defer_connect = inet->defer_connect; + inet_sockopt.recverr = inet_test_bit(RECVERR, sk); + inet_sockopt.is_icsk = inet_test_bit(IS_ICSK, sk); + inet_sockopt.freebind = inet_test_bit(FREEBIND, sk); + inet_sockopt.hdrincl = inet_test_bit(HDRINCL, sk); + inet_sockopt.mc_loop = inet_test_bit(MC_LOOP, sk); + inet_sockopt.transparent = inet_test_bit(TRANSPARENT, sk); + inet_sockopt.mc_all = inet_test_bit(MC_ALL, sk); + inet_sockopt.nodefrag = inet_test_bit(NODEFRAG, sk); + inet_sockopt.bind_address_no_port = inet_test_bit(BIND_ADDRESS_NO_PORT, sk); + inet_sockopt.recverr_rfc4884 = inet_test_bit(RECVERR_RFC4884, sk); + inet_sockopt.defer_connect = inet_test_bit(DEFER_CONNECT, sk); if (nla_put(skb, INET_DIAG_SOCKOPT, sizeof(inet_sockopt), &inet_sockopt)) goto errout; @@ -244,10 +219,17 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, struct nlmsghdr *nlh; struct nlattr *attr; void *info = NULL; + u8 icsk_pending; + int protocol; cb_data = cb->data; - handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)]; - BUG_ON(!handler); + protocol = inet_diag_get_protocol(req, cb_data); + + /* inet_diag_lock_handler() made sure inet_diag_table[] is stable. */ + handler = rcu_dereference_protected(inet_diag_table[protocol], 1); + DEBUG_NET_WARN_ON_ONCE(!handler); + if (!handler) + return -ENXIO; nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags); @@ -272,7 +254,7 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, struct inet_diag_meminfo minfo = { .idiag_rmem = sk_rmem_alloc_get(sk), .idiag_wmem = READ_ONCE(sk->sk_wmem_queued), - .idiag_fmem = sk_forward_alloc_get(sk), + .idiag_fmem = READ_ONCE(sk->sk_forward_alloc), .idiag_tmem = sk_wmem_alloc_get(sk), }; @@ -298,23 +280,24 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, goto out; } - if (icsk->icsk_pending == ICSK_TIME_RETRANS || - icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT || - icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { + icsk_pending = smp_load_acquire(&icsk->icsk_pending); + if (icsk_pending == ICSK_TIME_RETRANS || + icsk_pending == ICSK_TIME_REO_TIMEOUT || + icsk_pending == ICSK_TIME_LOSS_PROBE) { r->idiag_timer = 1; - r->idiag_retrans = icsk->icsk_retransmits; + r->idiag_retrans = READ_ONCE(icsk->icsk_retransmits); r->idiag_expires = - jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies); - } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) { + jiffies_delta_to_msecs(tcp_timeout_expires(sk) - jiffies); + } else if (icsk_pending == ICSK_TIME_PROBE0) { r->idiag_timer = 4; - r->idiag_retrans = icsk->icsk_probes_out; + r->idiag_retrans = READ_ONCE(icsk->icsk_probes_out); r->idiag_expires = - jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies); - } else if (timer_pending(&sk->sk_timer)) { + jiffies_delta_to_msecs(tcp_timeout_expires(sk) - jiffies); + } else if (timer_pending(&icsk->icsk_keepalive_timer)) { r->idiag_timer = 2; - r->idiag_retrans = icsk->icsk_probes_out; + r->idiag_retrans = READ_ONCE(icsk->icsk_probes_out); r->idiag_expires = - jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies); + jiffies_delta_to_msecs(icsk->icsk_keepalive_timer.expires - jiffies); } if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) { @@ -411,183 +394,6 @@ errout: } EXPORT_SYMBOL_GPL(inet_sk_diag_fill); -static int inet_twsk_diag_fill(struct sock *sk, - struct sk_buff *skb, - struct netlink_callback *cb, - u16 nlmsg_flags, bool net_admin) -{ - struct inet_timewait_sock *tw = inet_twsk(sk); - struct inet_diag_msg *r; - struct nlmsghdr *nlh; - long tmo; - - nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type, - sizeof(*r), nlmsg_flags); - if (!nlh) - return -EMSGSIZE; - - r = nlmsg_data(nlh); - BUG_ON(tw->tw_state != TCP_TIME_WAIT); - - inet_diag_msg_common_fill(r, sk); - r->idiag_retrans = 0; - - r->idiag_state = tw->tw_substate; - r->idiag_timer = 3; - tmo = tw->tw_timer.expires - jiffies; - r->idiag_expires = jiffies_delta_to_msecs(tmo); - r->idiag_rqueue = 0; - r->idiag_wqueue = 0; - r->idiag_uid = 0; - r->idiag_inode = 0; - - if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, - tw->tw_mark)) { - nlmsg_cancel(skb, nlh); - return -EMSGSIZE; - } - - nlmsg_end(skb, nlh); - return 0; -} - -static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, - struct netlink_callback *cb, - u16 nlmsg_flags, bool net_admin) -{ - struct request_sock *reqsk = inet_reqsk(sk); - struct inet_diag_msg *r; - struct nlmsghdr *nlh; - long tmo; - - nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, - cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags); - if (!nlh) - return -EMSGSIZE; - - r = nlmsg_data(nlh); - inet_diag_msg_common_fill(r, sk); - r->idiag_state = TCP_SYN_RECV; - r->idiag_timer = 1; - r->idiag_retrans = reqsk->num_retrans; - - BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) != - offsetof(struct sock, sk_cookie)); - - tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies; - r->idiag_expires = jiffies_delta_to_msecs(tmo); - r->idiag_rqueue = 0; - r->idiag_wqueue = 0; - r->idiag_uid = 0; - r->idiag_inode = 0; - - if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, - inet_rsk(reqsk)->ir_mark)) { - nlmsg_cancel(skb, nlh); - return -EMSGSIZE; - } - - nlmsg_end(skb, nlh); - return 0; -} - -static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, - struct netlink_callback *cb, - const struct inet_diag_req_v2 *r, - u16 nlmsg_flags, bool net_admin) -{ - if (sk->sk_state == TCP_TIME_WAIT) - return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags, net_admin); - - if (sk->sk_state == TCP_NEW_SYN_RECV) - return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin); - - return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags, - net_admin); -} - -struct sock *inet_diag_find_one_icsk(struct net *net, - struct inet_hashinfo *hashinfo, - const struct inet_diag_req_v2 *req) -{ - struct sock *sk; - - rcu_read_lock(); - if (req->sdiag_family == AF_INET) - sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0], - req->id.idiag_dport, req->id.idiag_src[0], - req->id.idiag_sport, req->id.idiag_if); -#if IS_ENABLED(CONFIG_IPV6) - else if (req->sdiag_family == AF_INET6) { - if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && - ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src)) - sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3], - req->id.idiag_dport, req->id.idiag_src[3], - req->id.idiag_sport, req->id.idiag_if); - else - sk = inet6_lookup(net, hashinfo, NULL, 0, - (struct in6_addr *)req->id.idiag_dst, - req->id.idiag_dport, - (struct in6_addr *)req->id.idiag_src, - req->id.idiag_sport, - req->id.idiag_if); - } -#endif - else { - rcu_read_unlock(); - return ERR_PTR(-EINVAL); - } - rcu_read_unlock(); - if (!sk) - return ERR_PTR(-ENOENT); - - if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) { - sock_gen_put(sk); - return ERR_PTR(-ENOENT); - } - - return sk; -} -EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk); - -int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, - struct netlink_callback *cb, - const struct inet_diag_req_v2 *req) -{ - struct sk_buff *in_skb = cb->skb; - bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN); - struct net *net = sock_net(in_skb->sk); - struct sk_buff *rep; - struct sock *sk; - int err; - - sk = inet_diag_find_one_icsk(net, hashinfo, req); - if (IS_ERR(sk)) - return PTR_ERR(sk); - - rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL); - if (!rep) { - err = -ENOMEM; - goto out; - } - - err = sk_diag_fill(sk, rep, cb, req, 0, net_admin); - if (err < 0) { - WARN_ON(err == -EMSGSIZE); - nlmsg_free(rep); - goto out; - } - err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); - -out: - if (sk) - sock_gen_put(sk); - - return err; -} -EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk); - static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb, const struct nlmsghdr *nlh, int hdrlen, @@ -605,9 +411,10 @@ static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb, protocol = inet_diag_get_protocol(req, &dump_data); handler = inet_diag_lock_handler(protocol); - if (IS_ERR(handler)) { - err = PTR_ERR(handler); - } else if (cmd == SOCK_DIAG_BY_FAMILY) { + if (!handler) + return -ENOENT; + + if (cmd == SOCK_DIAG_BY_FAMILY) { struct netlink_callback cb = { .nlh = nlh, .skb = in_skb, @@ -773,7 +580,7 @@ static void entry_fill_addrs(struct inet_diag_entry *entry, const struct sock *sk) { #if IS_ENABLED(CONFIG_IPV6) - if (sk->sk_family == AF_INET6) { + if (entry->family == AF_INET6) { entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32; entry->daddr = sk->sk_v6_daddr.s6_addr32; } else @@ -784,31 +591,36 @@ static void entry_fill_addrs(struct inet_diag_entry *entry, } } -int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk) +int inet_diag_bc_sk(const struct inet_diag_dump_data *cb_data, struct sock *sk) { - struct inet_sock *inet = inet_sk(sk); + const struct nlattr *bc = cb_data->inet_diag_nla_bc; + const struct inet_sock *inet = inet_sk(sk); struct inet_diag_entry entry; if (!bc) return 1; - entry.family = sk->sk_family; + entry.family = READ_ONCE(sk->sk_family); entry_fill_addrs(&entry, sk); - entry.sport = inet->inet_num; - entry.dport = ntohs(inet->inet_dport); - entry.ifindex = sk->sk_bound_dev_if; - entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0; - if (sk_fullsock(sk)) - entry.mark = sk->sk_mark; - else if (sk->sk_state == TCP_NEW_SYN_RECV) - entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark; - else if (sk->sk_state == TCP_TIME_WAIT) - entry.mark = inet_twsk(sk)->tw_mark; - else - entry.mark = 0; + entry.sport = READ_ONCE(inet->inet_num); + entry.dport = ntohs(READ_ONCE(inet->inet_dport)); + entry.ifindex = READ_ONCE(sk->sk_bound_dev_if); + if (cb_data->userlocks_needed) + entry.userlocks = sk_fullsock(sk) ? READ_ONCE(sk->sk_userlocks) : 0; + if (cb_data->mark_needed) { + if (sk_fullsock(sk)) + entry.mark = READ_ONCE(sk->sk_mark); + else if (sk->sk_state == TCP_NEW_SYN_RECV) + entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark; + else if (sk->sk_state == TCP_TIME_WAIT) + entry.mark = inet_twsk(sk)->tw_mark; + else + entry.mark = 0; + } #ifdef CONFIG_SOCK_CGROUP_DATA - entry.cgroup_id = sk_fullsock(sk) ? - cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0; + if (cb_data->cgroup_needed) + entry.cgroup_id = sk_fullsock(sk) ? + cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0; #endif return inet_diag_bc_run(bc, &entry); @@ -908,16 +720,21 @@ static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len, } #endif -static int inet_diag_bc_audit(const struct nlattr *attr, +static int inet_diag_bc_audit(struct inet_diag_dump_data *cb_data, const struct sk_buff *skb) { - bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN); + const struct nlattr *attr = cb_data->inet_diag_nla_bc; const void *bytecode, *bc; int bytecode_len, len; + bool net_admin; - if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op)) + if (!attr) + return 0; + + if (nla_len(attr) < sizeof(struct inet_diag_bc_op)) return -EINVAL; + net_admin = netlink_net_capable(skb, CAP_NET_ADMIN); bytecode = bc = nla_data(attr); len = bytecode_len = nla_len(attr); @@ -949,14 +766,18 @@ static int inet_diag_bc_audit(const struct nlattr *attr, return -EPERM; if (!valid_markcond(bc, len, &min_len)) return -EINVAL; + cb_data->mark_needed = true; break; #ifdef CONFIG_SOCK_CGROUP_DATA case INET_DIAG_BC_CGROUP_COND: if (!valid_cgroupcond(bc, len, &min_len)) return -EINVAL; + cb_data->cgroup_needed = true; break; #endif case INET_DIAG_BC_AUTO: + cb_data->userlocks_needed = true; + fallthrough; case INET_DIAG_BC_JMP: case INET_DIAG_BC_NOP: break; @@ -980,187 +801,6 @@ static int inet_diag_bc_audit(const struct nlattr *attr, return len == 0 ? 0 : -EINVAL; } -static void twsk_build_assert(void) -{ - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) != - offsetof(struct sock, sk_family)); - - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) != - offsetof(struct inet_sock, inet_num)); - - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) != - offsetof(struct inet_sock, inet_dport)); - - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) != - offsetof(struct inet_sock, inet_rcv_saddr)); - - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) != - offsetof(struct inet_sock, inet_daddr)); - -#if IS_ENABLED(CONFIG_IPV6) - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) != - offsetof(struct sock, sk_v6_rcv_saddr)); - - BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) != - offsetof(struct sock, sk_v6_daddr)); -#endif -} - -void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, - struct netlink_callback *cb, - const struct inet_diag_req_v2 *r) -{ - bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); - struct inet_diag_dump_data *cb_data = cb->data; - struct net *net = sock_net(skb->sk); - u32 idiag_states = r->idiag_states; - int i, num, s_i, s_num; - struct nlattr *bc; - struct sock *sk; - - bc = cb_data->inet_diag_nla_bc; - if (idiag_states & TCPF_SYN_RECV) - idiag_states |= TCPF_NEW_SYN_RECV; - s_i = cb->args[1]; - s_num = num = cb->args[2]; - - if (cb->args[0] == 0) { - if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport) - goto skip_listen_ht; - - for (i = s_i; i < INET_LHTABLE_SIZE; i++) { - struct inet_listen_hashbucket *ilb; - struct hlist_nulls_node *node; - - num = 0; - ilb = &hashinfo->listening_hash[i]; - spin_lock(&ilb->lock); - sk_nulls_for_each(sk, node, &ilb->nulls_head) { - struct inet_sock *inet = inet_sk(sk); - - if (!net_eq(sock_net(sk), net)) - continue; - - if (num < s_num) { - num++; - continue; - } - - if (r->sdiag_family != AF_UNSPEC && - sk->sk_family != r->sdiag_family) - goto next_listen; - - if (r->id.idiag_sport != inet->inet_sport && - r->id.idiag_sport) - goto next_listen; - - if (!inet_diag_bc_sk(bc, sk)) - goto next_listen; - - if (inet_sk_diag_fill(sk, inet_csk(sk), skb, - cb, r, NLM_F_MULTI, - net_admin) < 0) { - spin_unlock(&ilb->lock); - goto done; - } - -next_listen: - ++num; - } - spin_unlock(&ilb->lock); - - s_num = 0; - } -skip_listen_ht: - cb->args[0] = 1; - s_i = num = s_num = 0; - } - - if (!(idiag_states & ~TCPF_LISTEN)) - goto out; - -#define SKARR_SZ 16 - for (i = s_i; i <= hashinfo->ehash_mask; i++) { - struct inet_ehash_bucket *head = &hashinfo->ehash[i]; - spinlock_t *lock = inet_ehash_lockp(hashinfo, i); - struct hlist_nulls_node *node; - struct sock *sk_arr[SKARR_SZ]; - int num_arr[SKARR_SZ]; - int idx, accum, res; - - if (hlist_nulls_empty(&head->chain)) - continue; - - if (i > s_i) - s_num = 0; - -next_chunk: - num = 0; - accum = 0; - spin_lock_bh(lock); - sk_nulls_for_each(sk, node, &head->chain) { - int state; - - if (!net_eq(sock_net(sk), net)) - continue; - if (num < s_num) - goto next_normal; - state = (sk->sk_state == TCP_TIME_WAIT) ? - inet_twsk(sk)->tw_substate : sk->sk_state; - if (!(idiag_states & (1 << state))) - goto next_normal; - if (r->sdiag_family != AF_UNSPEC && - sk->sk_family != r->sdiag_family) - goto next_normal; - if (r->id.idiag_sport != htons(sk->sk_num) && - r->id.idiag_sport) - goto next_normal; - if (r->id.idiag_dport != sk->sk_dport && - r->id.idiag_dport) - goto next_normal; - twsk_build_assert(); - - if (!inet_diag_bc_sk(bc, sk)) - goto next_normal; - - if (!refcount_inc_not_zero(&sk->sk_refcnt)) - goto next_normal; - - num_arr[accum] = num; - sk_arr[accum] = sk; - if (++accum == SKARR_SZ) - break; -next_normal: - ++num; - } - spin_unlock_bh(lock); - res = 0; - for (idx = 0; idx < accum; idx++) { - if (res >= 0) { - res = sk_diag_fill(sk_arr[idx], skb, cb, r, - NLM_F_MULTI, net_admin); - if (res < 0) - num = num_arr[idx]; - } - sock_gen_put(sk_arr[idx]); - } - if (res < 0) - break; - cond_resched(); - if (accum == SKARR_SZ) { - s_num = num + 1; - goto next_chunk; - } - } - -done: - cb->args[1] = i; - cb->args[2] = num; -out: - ; -} -EXPORT_SYMBOL_GPL(inet_diag_dump_icsk); - static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r) { @@ -1174,12 +814,12 @@ static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, again: prev_min_dump_alloc = cb->min_dump_alloc; handler = inet_diag_lock_handler(protocol); - if (!IS_ERR(handler)) + if (handler) { handler->dump(skb, cb, r); - else - err = PTR_ERR(handler); - inet_diag_unlock_handler(handler); - + inet_diag_unlock_handler(handler); + } else { + err = -ENOENT; + } /* The skb is not large enough to fit one sk info and * inet_sk_diag_fill() has requested for a larger skb. */ @@ -1214,13 +854,10 @@ static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen) kfree(cb_data); return err; } - nla = cb_data->inet_diag_nla_bc; - if (nla) { - err = inet_diag_bc_audit(nla, skb); - if (err) { - kfree(cb_data); - return err; - } + err = inet_diag_bc_audit(cb_data, skb); + if (err) { + kfree(cb_data); + return err; } nla = cb_data->inet_diag_nla_bpf_stgs; @@ -1264,8 +901,6 @@ static int inet_diag_type2proto(int type) switch (type) { case TCPDIAG_GETSOCK: return IPPROTO_TCP; - case DCCPDIAG_GETSOCK: - return IPPROTO_DCCP; default: return 0; } @@ -1280,6 +915,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb, req.sdiag_family = AF_UNSPEC; /* compatibility */ req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type); req.idiag_ext = rc->idiag_ext; + req.pad = 0; req.idiag_states = rc->idiag_states; req.id = rc->id; @@ -1295,6 +931,7 @@ static int inet_diag_get_exact_compat(struct sk_buff *in_skb, req.sdiag_family = rc->idiag_family; req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type); req.idiag_ext = rc->idiag_ext; + req.pad = 0; req.idiag_states = rc->idiag_states; req.id = rc->id; @@ -1372,10 +1009,9 @@ int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk) } handler = inet_diag_lock_handler(sk->sk_protocol); - if (IS_ERR(handler)) { - inet_diag_unlock_handler(handler); + if (!handler) { nlmsg_cancel(skb, nlh); - return PTR_ERR(handler); + return -ENOENT; } attr = handler->idiag_info_size @@ -1394,6 +1030,7 @@ int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk) } static const struct sock_diag_handler inet_diag_handler = { + .owner = THIS_MODULE, .family = AF_INET, .dump = inet_diag_handler_cmd, .get_info = inet_diag_handler_get_info, @@ -1401,6 +1038,7 @@ static const struct sock_diag_handler inet_diag_handler = { }; static const struct sock_diag_handler inet6_diag_handler = { + .owner = THIS_MODULE, .family = AF_INET6, .dump = inet_diag_handler_cmd, .get_info = inet_diag_handler_get_info, @@ -1410,20 +1048,12 @@ static const struct sock_diag_handler inet6_diag_handler = { int inet_diag_register(const struct inet_diag_handler *h) { const __u16 type = h->idiag_type; - int err = -EINVAL; if (type >= IPPROTO_MAX) - goto out; + return -EINVAL; - mutex_lock(&inet_diag_table_mutex); - err = -EEXIST; - if (!inet_diag_table[type]) { - inet_diag_table[type] = h; - err = 0; - } - mutex_unlock(&inet_diag_table_mutex); -out: - return err; + return !cmpxchg((const struct inet_diag_handler **)&inet_diag_table[type], + NULL, h) ? 0 : -EEXIST; } EXPORT_SYMBOL_GPL(inet_diag_register); @@ -1434,12 +1064,16 @@ void inet_diag_unregister(const struct inet_diag_handler *h) if (type >= IPPROTO_MAX) return; - mutex_lock(&inet_diag_table_mutex); - inet_diag_table[type] = NULL; - mutex_unlock(&inet_diag_table_mutex); + xchg((const struct inet_diag_handler **)&inet_diag_table[type], + NULL); } EXPORT_SYMBOL_GPL(inet_diag_unregister); +static const struct sock_diag_inet_compat inet_diag_compat = { + .owner = THIS_MODULE, + .fn = inet_diag_rcv_msg_compat, +}; + static int __init inet_diag_init(void) { const int inet_diag_table_size = (IPPROTO_MAX * @@ -1458,7 +1092,7 @@ static int __init inet_diag_init(void) if (err) goto out_free_inet; - sock_diag_register_inet_compat(inet_diag_rcv_msg_compat); + sock_diag_register_inet_compat(&inet_diag_compat); out: return err; @@ -1473,12 +1107,13 @@ static void __exit inet_diag_exit(void) { sock_diag_unregister(&inet6_diag_handler); sock_diag_unregister(&inet_diag_handler); - sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat); + sock_diag_unregister_inet_compat(&inet_diag_compat); kfree(inet_diag_table); } module_init(inet_diag_init); module_exit(inet_diag_exit); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("INET/INET6: socket monitoring via SOCK_DIAG"); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */); diff --git a/net/ipv4/inet_fragment.c b/net/ipv4/inet_fragment.c index 341096807100..025895eb6ec5 100644 --- a/net/ipv4/inet_fragment.c +++ b/net/ipv4/inet_fragment.c @@ -24,6 +24,8 @@ #include <net/ip.h> #include <net/ipv6.h> +#include "../core/sock_destructor.h" + /* Use skb->cb to track consecutive/adjacent fragments coming at * the end of the queue. Nodes in the rb-tree queue will * contain "runs" of one or more adjacent fragments. @@ -39,6 +41,7 @@ struct ipfrag_skb_cb { }; struct sk_buff *next_frag; int frag_run_len; + int ip_defrag_offset; }; #define FRAG_CB(skb) ((struct ipfrag_skb_cb *)((skb)->cb)) @@ -130,9 +133,10 @@ static void inet_frags_free_cb(void *ptr, void *arg) struct inet_frag_queue *fq = ptr; int count; - count = del_timer_sync(&fq->timer) ? 1 : 0; + count = timer_delete_sync(&fq->timer) ? 1 : 0; spin_lock_bh(&fq->lock); + fq->flags |= INET_FRAG_DROP; if (!(fq->flags & INET_FRAG_COMPLETE)) { fq->flags |= INET_FRAG_COMPLETE; count++; @@ -141,8 +145,7 @@ static void inet_frags_free_cb(void *ptr, void *arg) } spin_unlock_bh(&fq->lock); - if (refcount_sub_and_test(count, &fq->refcnt)) - inet_frag_destroy(fq); + inet_frag_putn(fq, count); } static LLIST_HEAD(fqdir_free_list); @@ -171,7 +174,7 @@ static void fqdir_free_fn(struct work_struct *work) } } -static DECLARE_WORK(fqdir_free_work, fqdir_free_fn); +static DECLARE_DELAYED_WORK(fqdir_free_work, fqdir_free_fn); static void fqdir_work_fn(struct work_struct *work) { @@ -180,7 +183,7 @@ static void fqdir_work_fn(struct work_struct *work) rhashtable_free_and_destroy(&fqdir->rhashtable, inet_frags_free_cb, NULL); if (llist_add(&fqdir->free_list, &fqdir_free_list)) - queue_work(system_wq, &fqdir_free_work); + queue_delayed_work(system_percpu_wq, &fqdir_free_work, HZ); } int fqdir_init(struct fqdir **fqdirp, struct inet_frags *f, struct net *net) @@ -222,10 +225,10 @@ void fqdir_exit(struct fqdir *fqdir) } EXPORT_SYMBOL(fqdir_exit); -void inet_frag_kill(struct inet_frag_queue *fq) +void inet_frag_kill(struct inet_frag_queue *fq, int *refs) { - if (del_timer(&fq->timer)) - refcount_dec(&fq->refcnt); + if (timer_delete(&fq->timer)) + (*refs)++; if (!(fq->flags & INET_FRAG_COMPLETE)) { struct fqdir *fqdir = fq->fqdir; @@ -240,7 +243,7 @@ void inet_frag_kill(struct inet_frag_queue *fq) if (!READ_ONCE(fqdir->dead)) { rhashtable_remove_fast(&fqdir->rhashtable, &fq->node, fqdir->f->rhash_params); - refcount_dec(&fq->refcnt); + (*refs)++; } else { fq->flags |= INET_FRAG_HASH_DEAD; } @@ -260,7 +263,8 @@ static void inet_frag_destroy_rcu(struct rcu_head *head) kmem_cache_free(f->frags_cachep, q); } -unsigned int inet_frag_rbtree_purge(struct rb_root *root) +unsigned int inet_frag_rbtree_purge(struct rb_root *root, + enum skb_drop_reason reason) { struct rb_node *p = rb_first(root); unsigned int sum = 0; @@ -274,7 +278,7 @@ unsigned int inet_frag_rbtree_purge(struct rb_root *root) struct sk_buff *next = FRAG_CB(skb)->next_frag; sum += skb->truesize; - kfree_skb(skb); + kfree_skb_reason(skb, reason); skb = next; } } @@ -284,17 +288,21 @@ EXPORT_SYMBOL(inet_frag_rbtree_purge); void inet_frag_destroy(struct inet_frag_queue *q) { - struct fqdir *fqdir; unsigned int sum, sum_truesize = 0; + enum skb_drop_reason reason; struct inet_frags *f; + struct fqdir *fqdir; WARN_ON(!(q->flags & INET_FRAG_COMPLETE)); - WARN_ON(del_timer(&q->timer) != 0); + reason = (q->flags & INET_FRAG_DROP) ? + SKB_DROP_REASON_FRAG_REASM_TIMEOUT : + SKB_CONSUMED; + WARN_ON(timer_delete(&q->timer) != 0); /* Release all fragment data. */ fqdir = q->fqdir; f = fqdir->f; - sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments); + sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments, reason); sum = sum_truesize + f->qsize; call_rcu(&q->rcu, inet_frag_destroy_rcu); @@ -319,7 +327,8 @@ static struct inet_frag_queue *inet_frag_alloc(struct fqdir *fqdir, timer_setup(&q->timer, f->frag_expire, 0); spin_lock_init(&q->lock); - refcount_set(&q->refcnt, 3); + /* One reference for the timer, one for the hash table. */ + refcount_set(&q->refcnt, 2); return q; } @@ -341,15 +350,20 @@ static struct inet_frag_queue *inet_frag_create(struct fqdir *fqdir, *prev = rhashtable_lookup_get_insert_key(&fqdir->rhashtable, &q->key, &q->node, f->rhash_params); if (*prev) { + /* We could not insert in the hash table, + * we need to cancel what inet_frag_alloc() + * anticipated. + */ + int refs = 1; + q->flags |= INET_FRAG_COMPLETE; - inet_frag_kill(q); - inet_frag_destroy(q); + inet_frag_kill(q, &refs); + inet_frag_putn(q, refs); return NULL; } return q; } -/* TODO : call from rcu_read_lock() and no longer use refcount_inc_not_zero() */ struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key) { /* This pairs with WRITE_ONCE() in fqdir_pre_exit(). */ @@ -359,17 +373,11 @@ struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key) if (!high_thresh || frag_mem_limit(fqdir) > high_thresh) return NULL; - rcu_read_lock(); - prev = rhashtable_lookup(&fqdir->rhashtable, key, fqdir->f->rhash_params); if (!prev) fq = inet_frag_create(fqdir, key, &prev); - if (!IS_ERR_OR_NULL(prev)) { + if (!IS_ERR_OR_NULL(prev)) fq = prev; - if (!refcount_inc_not_zero(&fq->refcnt)) - fq = NULL; - } - rcu_read_unlock(); return fq; } EXPORT_SYMBOL(inet_frag_find); @@ -390,12 +398,12 @@ int inet_frag_queue_insert(struct inet_frag_queue *q, struct sk_buff *skb, */ if (!last) fragrun_create(q, skb); /* First fragment. */ - else if (last->ip_defrag_offset + last->len < end) { + else if (FRAG_CB(last)->ip_defrag_offset + last->len < end) { /* This is the common case: skb goes to the end. */ /* Detect and discard overlaps. */ - if (offset < last->ip_defrag_offset + last->len) + if (offset < FRAG_CB(last)->ip_defrag_offset + last->len) return IPFRAG_OVERLAP; - if (offset == last->ip_defrag_offset + last->len) + if (offset == FRAG_CB(last)->ip_defrag_offset + last->len) fragrun_append_to_last(q, skb); else fragrun_create(q, skb); @@ -412,13 +420,13 @@ int inet_frag_queue_insert(struct inet_frag_queue *q, struct sk_buff *skb, parent = *rbn; curr = rb_to_skb(parent); - curr_run_end = curr->ip_defrag_offset + + curr_run_end = FRAG_CB(curr)->ip_defrag_offset + FRAG_CB(curr)->frag_run_len; - if (end <= curr->ip_defrag_offset) + if (end <= FRAG_CB(curr)->ip_defrag_offset) rbn = &parent->rb_left; else if (offset >= curr_run_end) rbn = &parent->rb_right; - else if (offset >= curr->ip_defrag_offset && + else if (offset >= FRAG_CB(curr)->ip_defrag_offset && end <= curr_run_end) return IPFRAG_DUP; else @@ -432,7 +440,7 @@ int inet_frag_queue_insert(struct inet_frag_queue *q, struct sk_buff *skb, rb_insert_color(&skb->rbnode, &q->rb_fragments); } - skb->ip_defrag_offset = offset; + FRAG_CB(skb)->ip_defrag_offset = offset; return IPFRAG_OK; } @@ -442,13 +450,28 @@ void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, struct sk_buff *parent) { struct sk_buff *fp, *head = skb_rb_first(&q->rb_fragments); - struct sk_buff **nextp; + void (*destructor)(struct sk_buff *); + unsigned int orig_truesize = 0; + struct sk_buff **nextp = NULL; + struct sock *sk = skb->sk; int delta; + if (sk && is_skb_wmem(skb)) { + /* TX: skb->sk might have been passed as argument to + * dst->output and must remain valid until tx completes. + * + * Move sk to reassembled skb and fix up wmem accounting. + */ + orig_truesize = skb->truesize; + destructor = skb->destructor; + } + if (head != skb) { fp = skb_clone(skb, GFP_ATOMIC); - if (!fp) - return NULL; + if (!fp) { + head = skb; + goto out_restore_sk; + } FRAG_CB(fp)->next_frag = FRAG_CB(skb)->next_frag; if (RB_EMPTY_NODE(&skb->rbnode)) FRAG_CB(parent)->next_frag = fp; @@ -457,6 +480,12 @@ void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, &q->rb_fragments); if (q->fragments_tail == skb) q->fragments_tail = fp; + + if (orig_truesize) { + /* prevent skb_morph from releasing sk */ + skb->sk = NULL; + skb->destructor = NULL; + } skb_morph(skb, head); FRAG_CB(skb)->next_frag = FRAG_CB(head)->next_frag; rb_replace_node(&head->rbnode, &skb->rbnode, @@ -464,13 +493,13 @@ void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, consume_skb(head); head = skb; } - WARN_ON(head->ip_defrag_offset != 0); + WARN_ON(FRAG_CB(head)->ip_defrag_offset != 0); delta = -head->truesize; /* Head of list must not be cloned. */ if (skb_unclone(head, GFP_ATOMIC)) - return NULL; + goto out_restore_sk; delta += head->truesize; if (delta) @@ -486,7 +515,7 @@ void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, clone = alloc_skb(0, GFP_ATOMIC); if (!clone) - return NULL; + goto out_restore_sk; skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; skb_frag_list_init(head); for (i = 0; i < skb_shinfo(head)->nr_frags; i++) @@ -503,6 +532,21 @@ void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, nextp = &skb_shinfo(head)->frag_list; } +out_restore_sk: + if (orig_truesize) { + int ts_delta = head->truesize - orig_truesize; + + /* if this reassembled skb is fragmented later, + * fraglist skbs will get skb->sk assigned from head->sk, + * and each frag skb will be released via sock_wfree. + * + * Update sk_wmem_alloc. + */ + head->sk = sk; + head->destructor = destructor; + refcount_add(ts_delta, &sk->sk_wmem_alloc); + } + return nextp; } EXPORT_SYMBOL(inet_frag_reasm_prepare); @@ -510,7 +554,9 @@ EXPORT_SYMBOL(inet_frag_reasm_prepare); void inet_frag_reasm_finish(struct inet_frag_queue *q, struct sk_buff *head, void *reasm_data, bool try_coalesce) { - struct sk_buff **nextp = (struct sk_buff **)reasm_data; + struct sock *sk = is_skb_wmem(head) ? head->sk : NULL; + const unsigned int head_truesize = head->truesize; + struct sk_buff **nextp = reasm_data; struct rb_node *rbn; struct sk_buff *fp; int sum_truesize; @@ -572,6 +618,10 @@ void inet_frag_reasm_finish(struct inet_frag_queue *q, struct sk_buff *head, skb_mark_not_on_list(head); head->prev = NULL; head->tstamp = q->stamp; + head->tstamp_type = q->tstamp_type; + + if (sk) + refcount_add(sum_truesize - head_truesize, &sk->sk_wmem_alloc); } EXPORT_SYMBOL(inet_frag_reasm_finish); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 30ab717ff1b8..f5826ec4bcaa 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -23,22 +23,23 @@ #if IS_ENABLED(CONFIG_IPV6) #include <net/inet6_hashtables.h> #endif -#include <net/secure_seq.h> +#include <net/hotdata.h> #include <net/ip.h> -#include <net/tcp.h> +#include <net/rps.h> +#include <net/secure_seq.h> #include <net/sock_reuseport.h> +#include <net/tcp.h> -static u32 inet_ehashfn(const struct net *net, const __be32 laddr, - const __u16 lport, const __be32 faddr, - const __be16 fport) +u32 inet_ehashfn(const struct net *net, const __be32 laddr, + const __u16 lport, const __be32 faddr, + const __be16 fport) { - static u32 inet_ehash_secret __read_mostly; - net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret)); - return __inet_ehashfn(laddr, lport, faddr, fport, - inet_ehash_secret + net_hash_mix(net)); + return lport + __inet_ehashfn(laddr, 0, faddr, fport, + inet_ehash_secret + net_hash_mix(net)); } +EXPORT_SYMBOL_GPL(inet_ehashfn); /* This function handles inet_sock, but also timewait and request sockets * for IPv4/IPv6. @@ -57,6 +58,14 @@ static u32 sk_ehashfn(const struct sock *sk) sk->sk_daddr, sk->sk_dport); } +static bool sk_is_connect_bind(const struct sock *sk) +{ + if (sk->sk_state == TCP_TIME_WAIT) + return inet_twsk(sk)->tw_connect_bind; + else + return sk->sk_userlocks & SOCK_CONNECT_BIND; +} + /* * Allocate and initialize a new local port bind bucket. * The bindhash mutex for snum's hash chain must be held here. @@ -75,8 +84,8 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, tb->port = snum; tb->fastreuse = 0; tb->fastreuseport = 0; - INIT_HLIST_HEAD(&tb->owners); - hlist_add_head(&tb->node, &head->chain); + INIT_HLIST_HEAD(&tb->bhash2); + hlist_add_head_rcu(&tb->node, &head->chain); } return tb; } @@ -84,20 +93,117 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, /* * Caller must hold hashbucket lock for this tb with local BH disabled */ -void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb) +void inet_bind_bucket_destroy(struct inet_bind_bucket *tb) { + const struct inet_bind2_bucket *tb2; + + if (hlist_empty(&tb->bhash2)) { + hlist_del_rcu(&tb->node); + kfree_rcu(tb, rcu); + return; + } + + if (tb->fastreuse == -1 && tb->fastreuseport == -1) + return; + hlist_for_each_entry(tb2, &tb->bhash2, bhash_node) { + if (tb2->fastreuse != -1 || tb2->fastreuseport != -1) + return; + } + tb->fastreuse = -1; + tb->fastreuseport = -1; +} + +bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net, + unsigned short port, int l3mdev) +{ + return net_eq(ib_net(tb), net) && tb->port == port && + tb->l3mdev == l3mdev; +} + +static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb2, + struct net *net, + struct inet_bind_hashbucket *head, + struct inet_bind_bucket *tb, + const struct sock *sk) +{ + write_pnet(&tb2->ib_net, net); + tb2->l3mdev = tb->l3mdev; + tb2->port = tb->port; +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(USHRT_MAX < (IPV6_ADDR_ANY | IPV6_ADDR_MAPPED)); + if (sk->sk_family == AF_INET6) { + tb2->addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); + tb2->v6_rcv_saddr = sk->sk_v6_rcv_saddr; + } else { + tb2->addr_type = IPV6_ADDR_MAPPED; + ipv6_addr_set_v4mapped(sk->sk_rcv_saddr, &tb2->v6_rcv_saddr); + } +#else + tb2->rcv_saddr = sk->sk_rcv_saddr; +#endif + tb2->fastreuse = 0; + tb2->fastreuseport = 0; + INIT_HLIST_HEAD(&tb2->owners); + hlist_add_head(&tb2->node, &head->chain); + hlist_add_head(&tb2->bhash_node, &tb->bhash2); +} + +struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep, + struct net *net, + struct inet_bind_hashbucket *head, + struct inet_bind_bucket *tb, + const struct sock *sk) +{ + struct inet_bind2_bucket *tb2 = kmem_cache_alloc(cachep, GFP_ATOMIC); + + if (tb2) + inet_bind2_bucket_init(tb2, net, head, tb, sk); + + return tb2; +} + +/* Caller must hold hashbucket lock for this tb with local BH disabled */ +void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb) +{ + const struct sock *sk; + if (hlist_empty(&tb->owners)) { __hlist_del(&tb->node); + __hlist_del(&tb->bhash_node); kmem_cache_free(cachep, tb); + return; } + + if (tb->fastreuse == -1 && tb->fastreuseport == -1) + return; + sk_for_each_bound(sk, &tb->owners) { + if (!sk_is_connect_bind(sk)) + return; + } + tb->fastreuse = -1; + tb->fastreuseport = -1; +} + +static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2, + const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + return ipv6_addr_equal(&tb2->v6_rcv_saddr, &sk->sk_v6_rcv_saddr); + + if (tb2->addr_type != IPV6_ADDR_MAPPED) + return false; +#endif + return tb2->rcv_saddr == sk->sk_rcv_saddr; } void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, - const unsigned short snum) + struct inet_bind2_bucket *tb2, unsigned short port) { - inet_sk(sk)->inet_num = snum; - sk_add_bind_node(sk, &tb->owners); + inet_sk(sk)->inet_num = port; inet_csk(sk)->icsk_bind_hash = tb; + inet_csk(sk)->icsk_bind2_hash = tb2; + sk_add_bind_node(sk, &tb2->owners); } /* @@ -105,18 +211,33 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, */ static void __inet_put_port(struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num, - hashinfo->bhash_size); - struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash]; + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); + struct inet_bind_hashbucket *head, *head2; + struct net *net = sock_net(sk); struct inet_bind_bucket *tb; + int bhash; + + bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size); + head = &hashinfo->bhash[bhash]; + head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num); spin_lock(&head->lock); tb = inet_csk(sk)->icsk_bind_hash; - __sk_del_bind_node(sk); inet_csk(sk)->icsk_bind_hash = NULL; inet_sk(sk)->inet_num = 0; - inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + sk->sk_userlocks &= ~SOCK_CONNECT_BIND; + + spin_lock(&head2->lock); + if (inet_csk(sk)->icsk_bind2_hash) { + struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash; + + __sk_del_bind_node(sk); + inet_csk(sk)->icsk_bind2_hash = NULL; + inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); + } + spin_unlock(&head2->lock); + + inet_bind_bucket_destroy(tb); spin_unlock(&head->lock); } @@ -130,17 +251,26 @@ EXPORT_SYMBOL(inet_put_port); int __inet_inherit_port(const struct sock *sk, struct sock *child) { - struct inet_hashinfo *table = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *table = tcp_get_hashinfo(sk); unsigned short port = inet_sk(child)->inet_num; - const int bhash = inet_bhashfn(sock_net(sk), port, - table->bhash_size); - struct inet_bind_hashbucket *head = &table->bhash[bhash]; + struct inet_bind_hashbucket *head, *head2; + bool created_inet_bind_bucket = false; + struct net *net = sock_net(sk); + bool update_fastreuse = false; + struct inet_bind2_bucket *tb2; struct inet_bind_bucket *tb; - int l3mdev; + int bhash, l3mdev; + + bhash = inet_bhashfn(net, port, table->bhash_size); + head = &table->bhash[bhash]; + head2 = inet_bhashfn_portaddr(table, child, net, port); spin_lock(&head->lock); + spin_lock(&head2->lock); tb = inet_csk(sk)->icsk_bind_hash; - if (unlikely(!tb)) { + tb2 = inet_csk(sk)->icsk_bind2_hash; + if (unlikely(!tb || !tb2)) { + spin_unlock(&head2->lock); spin_unlock(&head->lock); return -ENOENT; } @@ -153,25 +283,48 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child) * as that of the child socket. We have to look up or * create a new bind bucket for the child here. */ inet_bind_bucket_for_each(tb, &head->chain) { - if (net_eq(ib_net(tb), sock_net(sk)) && - tb->l3mdev == l3mdev && tb->port == port) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) break; } if (!tb) { tb = inet_bind_bucket_create(table->bind_bucket_cachep, - sock_net(sk), head, port, - l3mdev); + net, head, port, l3mdev); if (!tb) { + spin_unlock(&head2->lock); spin_unlock(&head->lock); return -ENOMEM; } + created_inet_bind_bucket = true; + } + update_fastreuse = true; + + goto bhash2_find; + } else if (!inet_bind2_bucket_addr_match(tb2, child)) { + l3mdev = inet_sk_bound_l3mdev(sk); + +bhash2_find: + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child); + if (!tb2) { + tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep, + net, head2, tb, child); + if (!tb2) + goto error; } - inet_csk_update_fastreuse(tb, child); } - inet_bind_hash(child, tb, port); + if (update_fastreuse) + inet_csk_update_fastreuse(child, tb, tb2); + inet_bind_hash(child, tb, tb2, port); + spin_unlock(&head2->lock); spin_unlock(&head->lock); return 0; + +error: + if (created_inet_bind_bucket) + inet_bind_bucket_destroy(tb); + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + return -ENOMEM; } EXPORT_SYMBOL_GPL(__inet_inherit_port); @@ -193,43 +346,7 @@ inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk) return inet_lhash2_bucket(h, hash); } -static void inet_hash2(struct inet_hashinfo *h, struct sock *sk) -{ - struct inet_listen_hashbucket *ilb2; - - if (!h->lhash2) - return; - - ilb2 = inet_lhash2_bucket_sk(h, sk); - - spin_lock(&ilb2->lock); - if (sk->sk_reuseport && sk->sk_family == AF_INET6) - hlist_add_tail_rcu(&inet_csk(sk)->icsk_listen_portaddr_node, - &ilb2->head); - else - hlist_add_head_rcu(&inet_csk(sk)->icsk_listen_portaddr_node, - &ilb2->head); - ilb2->count++; - spin_unlock(&ilb2->lock); -} - -static void inet_unhash2(struct inet_hashinfo *h, struct sock *sk) -{ - struct inet_listen_hashbucket *ilb2; - - if (!h->lhash2 || - WARN_ON_ONCE(hlist_unhashed(&inet_csk(sk)->icsk_listen_portaddr_node))) - return; - - ilb2 = inet_lhash2_bucket_sk(h, sk); - - spin_lock(&ilb2->lock); - hlist_del_init_rcu(&inet_csk(sk)->icsk_listen_portaddr_node); - ilb2->count--; - spin_unlock(&ilb2->lock); -} - -static inline int compute_score(struct sock *sk, struct net *net, +static inline int compute_score(struct sock *sk, const struct net *net, const unsigned short hnum, const __be32 daddr, const int dif, const int sdif) { @@ -252,20 +369,38 @@ static inline int compute_score(struct sock *sk, struct net *net, return score; } -static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk, - struct sk_buff *skb, int doff, - __be32 saddr, __be16 sport, - __be32 daddr, unsigned short hnum) +/** + * inet_lookup_reuseport() - execute reuseport logic on AF_INET socket if necessary. + * @net: network namespace. + * @sk: AF_INET socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP. + * @skb: context for a potential SK_REUSEPORT program. + * @doff: header offset. + * @saddr: source address. + * @sport: source port. + * @daddr: destination address. + * @hnum: destination port in host byte order. + * @ehashfn: hash function used to generate the fallback hash. + * + * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to + * the selected sock or an error. + */ +struct sock *inet_lookup_reuseport(const struct net *net, struct sock *sk, + struct sk_buff *skb, int doff, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned short hnum, + inet_ehashfn_t *ehashfn) { struct sock *reuse_sk = NULL; u32 phash; if (sk->sk_reuseport) { - phash = inet_ehashfn(net, daddr, hnum, saddr, sport); + phash = INDIRECT_CALL_2(ehashfn, udp_ehashfn, inet_ehashfn, + net, daddr, hnum, saddr, sport); reuse_sk = reuseport_select_sock(sk, phash, skb, doff); } return reuse_sk; } +EXPORT_SYMBOL_GPL(inet_lookup_reuseport); /* * Here are some nice properties to exploit here. The BSD API @@ -275,23 +410,22 @@ static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk, */ /* called with rcu_read_lock() : No refcount taken on the socket */ -static struct sock *inet_lhash2_lookup(struct net *net, +static struct sock *inet_lhash2_lookup(const struct net *net, struct inet_listen_hashbucket *ilb2, struct sk_buff *skb, int doff, const __be32 saddr, __be16 sport, const __be32 daddr, const unsigned short hnum, const int dif, const int sdif) { - struct inet_connection_sock *icsk; struct sock *sk, *result = NULL; + struct hlist_nulls_node *node; int score, hiscore = 0; - inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) { - sk = (struct sock *)icsk; + sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) { score = compute_score(sk, net, hnum, daddr, dif, sdif); if (score > hiscore) { - result = lookup_reuseport(net, sk, skb, doff, - saddr, sport, daddr, hnum); + result = inet_lookup_reuseport(net, sk, skb, doff, + saddr, sport, daddr, hnum, inet_ehashfn); if (result) return result; @@ -303,48 +437,49 @@ static struct sock *inet_lhash2_lookup(struct net *net, return result; } -static inline struct sock *inet_lookup_run_bpf(struct net *net, - struct inet_hashinfo *hashinfo, - struct sk_buff *skb, int doff, - __be32 saddr, __be16 sport, - __be32 daddr, u16 hnum, const int dif) +struct sock *inet_lookup_run_sk_lookup(const struct net *net, + int protocol, + struct sk_buff *skb, int doff, + __be32 saddr, __be16 sport, + __be32 daddr, u16 hnum, const int dif, + inet_ehashfn_t *ehashfn) { struct sock *sk, *reuse_sk; bool no_reuseport; - if (hashinfo != &tcp_hashinfo) - return NULL; /* only TCP is supported */ - - no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport, + no_reuseport = bpf_sk_lookup_run_v4(net, protocol, saddr, sport, daddr, hnum, dif, &sk); if (no_reuseport || IS_ERR_OR_NULL(sk)) return sk; - reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum); + reuse_sk = inet_lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum, + ehashfn); if (reuse_sk) sk = reuse_sk; return sk; } -struct sock *__inet_lookup_listener(struct net *net, - struct inet_hashinfo *hashinfo, +struct sock *__inet_lookup_listener(const struct net *net, struct sk_buff *skb, int doff, const __be32 saddr, __be16 sport, const __be32 daddr, const unsigned short hnum, const int dif, const int sdif) { struct inet_listen_hashbucket *ilb2; + struct inet_hashinfo *hashinfo; struct sock *result = NULL; unsigned int hash2; /* Lookup redirect from BPF */ if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { - result = inet_lookup_run_bpf(net, hashinfo, skb, doff, - saddr, sport, daddr, hnum, dif); + result = inet_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff, + saddr, sport, daddr, hnum, dif, + inet_ehashfn); if (result) goto done; } + hashinfo = net->ipv4.tcp_death_row.hashinfo; hash2 = ipv4_portaddr_hash(net, daddr, hnum); ilb2 = inet_lhash2_bucket(hashinfo, hash2); @@ -389,34 +524,33 @@ void sock_edemux(struct sk_buff *skb) } EXPORT_SYMBOL(sock_edemux); -struct sock *__inet_lookup_established(struct net *net, - struct inet_hashinfo *hashinfo, - const __be32 saddr, const __be16 sport, - const __be32 daddr, const u16 hnum, - const int dif, const int sdif) +struct sock *__inet_lookup_established(const struct net *net, + const __be32 saddr, const __be16 sport, + const __be32 daddr, const u16 hnum, + const int dif, const int sdif) { - INET_ADDR_COOKIE(acookie, saddr, daddr); const __portpair ports = INET_COMBINED_PORTS(sport, hnum); - struct sock *sk; + INET_ADDR_COOKIE(acookie, saddr, daddr); const struct hlist_nulls_node *node; - /* Optimize here for direct hit, only listening connections can - * have wildcards anyways. - */ - unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport); - unsigned int slot = hash & hashinfo->ehash_mask; - struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; + struct inet_ehash_bucket *head; + struct inet_hashinfo *hashinfo; + unsigned int hash, slot; + struct sock *sk; + + hashinfo = net->ipv4.tcp_death_row.hashinfo; + hash = inet_ehashfn(net, daddr, hnum, saddr, sport); + slot = hash & hashinfo->ehash_mask; + head = &hashinfo->ehash[slot]; begin: sk_nulls_for_each_rcu(sk, node, &head->chain) { if (sk->sk_hash != hash) continue; - if (likely(INET_MATCH(sk, net, acookie, - saddr, daddr, ports, dif, sdif))) { + if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) { if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) goto out; - if (unlikely(!INET_MATCH(sk, net, acookie, - saddr, daddr, ports, - dif, sdif))) { + if (unlikely(!inet_match(net, sk, acookie, + ports, dif, sdif))) { sock_gen_put(sk); goto begin; } @@ -440,7 +574,9 @@ EXPORT_SYMBOL_GPL(__inet_lookup_established); /* called with local bh disabled */ static int __inet_check_established(struct inet_timewait_death_row *death_row, struct sock *sk, __u16 lport, - struct inet_timewait_sock **twp) + struct inet_timewait_sock **twp, + bool rcu_lookup, + u32 hash) { struct inet_hashinfo *hinfo = death_row->hashinfo; struct inet_sock *inet = inet_sk(sk); @@ -451,25 +587,35 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row, int sdif = l3mdev_master_ifindex_by_index(net, dif); INET_ADDR_COOKIE(acookie, saddr, daddr); const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); - unsigned int hash = inet_ehashfn(net, daddr, lport, - saddr, inet->inet_dport); struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); - spinlock_t *lock = inet_ehash_lockp(hinfo, hash); - struct sock *sk2; - const struct hlist_nulls_node *node; struct inet_timewait_sock *tw = NULL; + const struct hlist_nulls_node *node; + struct sock *sk2; + spinlock_t *lock; + + if (rcu_lookup) { + sk_nulls_for_each(sk2, node, &head->chain) { + if (sk2->sk_hash != hash || + !inet_match(net, sk2, acookie, ports, dif, sdif)) + continue; + if (sk2->sk_state == TCP_TIME_WAIT) + break; + return -EADDRNOTAVAIL; + } + return 0; + } + lock = inet_ehash_lockp(hinfo, hash); spin_lock(lock); sk_nulls_for_each(sk2, node, &head->chain) { if (sk2->sk_hash != hash) continue; - if (likely(INET_MATCH(sk2, net, acookie, - saddr, daddr, ports, dif, sdif))) { + if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) { if (sk2->sk_state == TCP_TIME_WAIT) { tw = inet_twsk(sk2); - if (twsk_unique(sk, sk2, twp)) + if (tcp_twsk_unique(sk, sk2, twp)) break; } goto not_unique; @@ -504,7 +650,7 @@ not_unique: return -EADDRNOTAVAIL; } -static u32 inet_sk_port_offset(const struct sock *sk) +static u64 inet_sk_port_offset(const struct sock *sk) { const struct inet_sock *inet = inet_sk(sk); @@ -532,16 +678,14 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk, if (esk->sk_hash != sk->sk_hash) continue; if (sk->sk_family == AF_INET) { - if (unlikely(INET_MATCH(esk, net, acookie, - sk->sk_daddr, - sk->sk_rcv_saddr, + if (unlikely(inet_match(net, esk, acookie, ports, dif, sdif))) { return true; } } #if IS_ENABLED(CONFIG_IPV6) else if (sk->sk_family == AF_INET6) { - if (unlikely(INET6_MATCH(esk, net, + if (unlikely(inet6_match(net, esk, &sk->sk_v6_daddr, &sk->sk_v6_rcv_saddr, ports, dif, sdif))) { @@ -560,9 +704,9 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk, */ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - struct hlist_nulls_head *list; + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); struct inet_ehash_bucket *head; + struct hlist_nulls_head *list; spinlock_t *lock; bool ret = true; @@ -576,8 +720,11 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) spin_lock(lock); if (osk) { WARN_ON_ONCE(sk->sk_hash != osk->sk_hash); - ret = sk_nulls_del_node_init_rcu(osk); - } else if (found_dup_sk) { + ret = sk_nulls_replace_node_init_rcu(osk, sk); + goto unlock; + } + + if (found_dup_sk) { *found_dup_sk = inet_ehash_lookup_by_sk(sk, list); if (*found_dup_sk) ret = false; @@ -586,6 +733,7 @@ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) if (ret) __sk_nulls_add_node_rcu(sk, list); +unlock: spin_unlock(lock); return ret; @@ -598,22 +746,22 @@ bool inet_ehash_nolisten(struct sock *sk, struct sock *osk, bool *found_dup_sk) if (ok) { sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); } else { - this_cpu_inc(*sk->sk_prot->orphan_count); + tcp_orphan_count_inc(); inet_sk_set_state(sk, TCP_CLOSE); sock_set_flag(sk, SOCK_DEAD); inet_csk_destroy_sock(sk); } return ok; } -EXPORT_SYMBOL_GPL(inet_ehash_nolisten); +EXPORT_IPV6_MOD(inet_ehash_nolisten); static int inet_reuseport_add_sock(struct sock *sk, struct inet_listen_hashbucket *ilb) { struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash; const struct hlist_nulls_node *node; + kuid_t uid = sk_uid(sk); struct sock *sk2; - kuid_t uid = sock_i_uid(sk); sk_nulls_for_each_rcu(sk2, node, &ilb->nulls_head) { if (sk2 != sk && @@ -621,7 +769,7 @@ static int inet_reuseport_add_sock(struct sock *sk, ipv6_only_sock(sk2) == ipv6_only_sock(sk) && sk2->sk_bound_dev_if == sk->sk_bound_dev_if && inet_csk(sk2)->icsk_bind_hash == tb && - sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && + sk2->sk_reuseport && uid_eq(uid, sk_uid(sk2)) && inet_rcv_saddr_equal(sk, sk2, false)) return reuseport_add_sock(sk, sk2, inet_rcv_saddr_any(sk)); @@ -630,172 +778,352 @@ static int inet_reuseport_add_sock(struct sock *sk, return reuseport_alloc(sk, inet_rcv_saddr_any(sk)); } -int __inet_hash(struct sock *sk, struct sock *osk) +int inet_hash(struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - struct inet_listen_hashbucket *ilb; + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); + struct inet_listen_hashbucket *ilb2; int err = 0; + if (sk->sk_state == TCP_CLOSE) + return 0; + if (sk->sk_state != TCP_LISTEN) { - inet_ehash_nolisten(sk, osk, NULL); + local_bh_disable(); + inet_ehash_nolisten(sk, NULL, NULL); + local_bh_enable(); return 0; } WARN_ON(!sk_unhashed(sk)); - ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)]; + ilb2 = inet_lhash2_bucket_sk(hashinfo, sk); - spin_lock(&ilb->lock); + spin_lock(&ilb2->lock); if (sk->sk_reuseport) { - err = inet_reuseport_add_sock(sk, ilb); + err = inet_reuseport_add_sock(sk, ilb2); if (err) goto unlock; } + sock_set_flag(sk, SOCK_RCU_FREE); if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && sk->sk_family == AF_INET6) - __sk_nulls_add_node_tail_rcu(sk, &ilb->nulls_head); + __sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head); else - __sk_nulls_add_node_rcu(sk, &ilb->nulls_head); - inet_hash2(hashinfo, sk); - ilb->count++; - sock_set_flag(sk, SOCK_RCU_FREE); + __sk_nulls_add_node_rcu(sk, &ilb2->nulls_head); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); unlock: - spin_unlock(&ilb->lock); + spin_unlock(&ilb2->lock); return err; } -EXPORT_SYMBOL(__inet_hash); +EXPORT_IPV6_MOD(inet_hash); -int inet_hash(struct sock *sk) +void inet_unhash(struct sock *sk) { - int err = 0; + struct inet_hashinfo *hashinfo = tcp_get_hashinfo(sk); - if (sk->sk_state != TCP_CLOSE) { - local_bh_disable(); - err = __inet_hash(sk, NULL); - local_bh_enable(); + if (sk_unhashed(sk)) + return; + + sock_rps_delete_flow(sk); + if (sk->sk_state == TCP_LISTEN) { + struct inet_listen_hashbucket *ilb2; + + ilb2 = inet_lhash2_bucket_sk(hashinfo, sk); + /* Don't disable bottom halves while acquiring the lock to + * avoid circular locking dependency on PREEMPT_RT. + */ + spin_lock(&ilb2->lock); + if (rcu_access_pointer(sk->sk_reuseport_cb)) + reuseport_stop_listen_sock(sk); + + __sk_nulls_del_node_init_rcu(sk); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + spin_unlock(&ilb2->lock); + } else { + spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); + + spin_lock_bh(lock); + __sk_nulls_del_node_init_rcu(sk); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + spin_unlock_bh(lock); } +} +EXPORT_IPV6_MOD(inet_unhash); - return err; +static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb, + const struct net *net, unsigned short port, + int l3mdev, const struct sock *sk) +{ + if (!net_eq(ib2_net(tb), net) || tb->port != port || + tb->l3mdev != l3mdev) + return false; + + return inet_bind2_bucket_addr_match(tb, sk); } -EXPORT_SYMBOL_GPL(inet_hash); -void inet_unhash(struct sock *sk) +bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net, + unsigned short port, int l3mdev, const struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - struct inet_listen_hashbucket *ilb = NULL; - spinlock_t *lock; + if (!net_eq(ib2_net(tb), net) || tb->port != port || + tb->l3mdev != l3mdev) + return false; - if (sk_unhashed(sk)) - return; +#if IS_ENABLED(CONFIG_IPV6) + if (tb->addr_type == IPV6_ADDR_ANY) + return true; - if (sk->sk_state == TCP_LISTEN) { - ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)]; - lock = &ilb->lock; - } else { - lock = inet_ehash_lockp(hashinfo, sk->sk_hash); + if (tb->addr_type != IPV6_ADDR_MAPPED) + return false; + + if (sk->sk_family == AF_INET6 && + !ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return false; +#endif + return tb->rcv_saddr == 0; +} + +/* The socket's bhash2 hashbucket spinlock must be held when this is called */ +struct inet_bind2_bucket * +inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net, + unsigned short port, int l3mdev, const struct sock *sk) +{ + struct inet_bind2_bucket *bhash2 = NULL; + + inet_bind_bucket_for_each(bhash2, &head->chain) + if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk)) + break; + + return bhash2; +} + +struct inet_bind_hashbucket * +inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port) +{ + struct inet_hashinfo *hinfo = tcp_get_hashinfo(sk); + u32 hash; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + hash = ipv6_portaddr_hash(net, &in6addr_any, port); + else +#endif + hash = ipv4_portaddr_hash(net, 0, port); + + return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)]; +} + +static void inet_update_saddr(struct sock *sk, void *saddr, int family) +{ + if (family == AF_INET) { + inet_sk(sk)->inet_saddr = *(__be32 *)saddr; + sk_rcv_saddr_set(sk, inet_sk(sk)->inet_saddr); } - spin_lock_bh(lock); - if (sk_unhashed(sk)) - goto unlock; +#if IS_ENABLED(CONFIG_IPV6) + else { + sk->sk_v6_rcv_saddr = *(struct in6_addr *)saddr; + } +#endif +} + +static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, bool reset) +{ + struct inet_hashinfo *hinfo = tcp_get_hashinfo(sk); + struct inet_bind_hashbucket *head, *head2; + struct inet_bind2_bucket *tb2, *new_tb2; + int l3mdev = inet_sk_bound_l3mdev(sk); + int port = inet_sk(sk)->inet_num; + struct net *net = sock_net(sk); + int bhash; + + if (!inet_csk(sk)->icsk_bind2_hash) { + /* Not bind()ed before. */ + if (reset) + inet_reset_saddr(sk); + else + inet_update_saddr(sk, saddr, family); - if (rcu_access_pointer(sk->sk_reuseport_cb)) - reuseport_stop_listen_sock(sk); - if (ilb) { - inet_unhash2(hashinfo, sk); - ilb->count--; + return 0; } - __sk_nulls_del_node_init_rcu(sk); - sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); -unlock: - spin_unlock_bh(lock); + + /* Allocate a bind2 bucket ahead of time to avoid permanently putting + * the bhash2 table in an inconsistent state if a new tb2 bucket + * allocation fails. + */ + new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC); + if (!new_tb2) { + if (reset) { + /* The (INADDR_ANY, port) bucket might have already + * been freed, then we cannot fixup icsk_bind2_hash, + * so we give up and unlink sk from bhash/bhash2 not + * to leave inconsistency in bhash2. + */ + inet_put_port(sk); + inet_reset_saddr(sk); + } + + return -ENOMEM; + } + + bhash = inet_bhashfn(net, port, hinfo->bhash_size); + head = &hinfo->bhash[bhash]; + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + + /* If we change saddr locklessly, another thread + * iterating over bhash might see corrupted address. + */ + spin_lock_bh(&head->lock); + + spin_lock(&head2->lock); + __sk_del_bind_node(sk); + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash); + spin_unlock(&head2->lock); + + if (reset) + inet_reset_saddr(sk); + else + inet_update_saddr(sk, saddr, family); + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + + spin_lock(&head2->lock); + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + if (!tb2) { + tb2 = new_tb2; + inet_bind2_bucket_init(tb2, net, head2, inet_csk(sk)->icsk_bind_hash, sk); + if (sk_is_connect_bind(sk)) { + tb2->fastreuse = -1; + tb2->fastreuseport = -1; + } + } + inet_csk(sk)->icsk_bind2_hash = tb2; + sk_add_bind_node(sk, &tb2->owners); + spin_unlock(&head2->lock); + + spin_unlock_bh(&head->lock); + + if (tb2 != new_tb2) + kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2); + + return 0; +} + +int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family) +{ + return __inet_bhash2_update_saddr(sk, saddr, family, false); } -EXPORT_SYMBOL_GPL(inet_unhash); +EXPORT_IPV6_MOD(inet_bhash2_update_saddr); + +void inet_bhash2_reset_saddr(struct sock *sk) +{ + if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) + __inet_bhash2_update_saddr(sk, NULL, 0, true); +} +EXPORT_IPV6_MOD(inet_bhash2_reset_saddr); /* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm * Note that we use 32bit integers (vs RFC 'short integers') * because 2^16 is not a multiple of num_ephemeral and this * property might be used by clever attacker. - * RFC claims using TABLE_LENGTH=10 buckets gives an improvement, - * we use 256 instead to really give more isolation and - * privacy, this only consumes 1 KB of kernel memory. + * + * RFC claims using TABLE_LENGTH=10 buckets gives an improvement, though + * attacks were since demonstrated, thus we use 65536 by default instead + * to really give more isolation and privacy, at the expense of 256kB + * of kernel memory. */ -#define INET_TABLE_PERTURB_SHIFT 8 -static u32 table_perturb[1 << INET_TABLE_PERTURB_SHIFT]; +#define INET_TABLE_PERTURB_SIZE (1 << CONFIG_INET_TABLE_PERTURB_ORDER) +static u32 *table_perturb; int __inet_hash_connect(struct inet_timewait_death_row *death_row, - struct sock *sk, u32 port_offset, + struct sock *sk, u64 port_offset, + u32 hash_port0, int (*check_established)(struct inet_timewait_death_row *, - struct sock *, __u16, struct inet_timewait_sock **)) + struct sock *, __u16, struct inet_timewait_sock **, + bool rcu_lookup, u32 hash)) { struct inet_hashinfo *hinfo = death_row->hashinfo; + struct inet_bind_hashbucket *head, *head2; struct inet_timewait_sock *tw = NULL; - struct inet_bind_hashbucket *head; int port = inet_sk(sk)->inet_num; struct net *net = sock_net(sk); + struct inet_bind2_bucket *tb2; struct inet_bind_bucket *tb; + bool tb_created = false; u32 remaining, offset; int ret, i, low, high; - int l3mdev; + bool local_ports; + int step, l3mdev; u32 index; if (port) { - head = &hinfo->bhash[inet_bhashfn(net, port, - hinfo->bhash_size)]; - tb = inet_csk(sk)->icsk_bind_hash; - spin_lock_bh(&head->lock); - if (sk_head(&tb->owners) == sk && !sk->sk_bind_node.next) { - inet_ehash_nolisten(sk, NULL, NULL); - spin_unlock_bh(&head->lock); - return 0; - } - spin_unlock(&head->lock); - /* No definite answer... Walk to established hash table */ - ret = check_established(death_row, sk, port, NULL); + local_bh_disable(); + ret = check_established(death_row, sk, port, NULL, false, + hash_port0 + port); local_bh_enable(); return ret; } l3mdev = inet_sk_bound_l3mdev(sk); - inet_get_local_port_range(net, &low, &high); + local_ports = inet_sk_get_local_port_range(sk, &low, &high); + step = local_ports ? 1 : 2; + high++; /* [32768, 60999] -> [32768, 61000[ */ remaining = high - low; - if (likely(remaining > 1)) + if (!local_ports && remaining > 1) remaining &= ~1U; - net_get_random_once(table_perturb, sizeof(table_perturb)); - index = hash_32(port_offset, INET_TABLE_PERTURB_SHIFT); + get_random_sleepable_once(table_perturb, + INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb)); + index = port_offset & (INET_TABLE_PERTURB_SIZE - 1); + + offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32); + offset %= remaining; - offset = (READ_ONCE(table_perturb[index]) + port_offset) % remaining; /* In first pass we try ports of @low parity. * inet_csk_get_port() does the opposite choice. */ - offset &= ~1U; + if (!local_ports) + offset &= ~1U; other_parity_scan: port = low + offset; - for (i = 0; i < remaining; i += 2, port += 2) { + for (i = 0; i < remaining; i += step, port += step) { if (unlikely(port >= high)) port -= remaining; if (inet_is_local_reserved_port(net, port)) continue; head = &hinfo->bhash[inet_bhashfn(net, port, hinfo->bhash_size)]; + rcu_read_lock(); + hlist_for_each_entry_rcu(tb, &head->chain, node) { + if (!inet_bind_bucket_match(tb, net, port, l3mdev)) + continue; + if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) { + rcu_read_unlock(); + goto next_port; + } + if (!check_established(death_row, sk, port, &tw, true, + hash_port0 + port)) + break; + rcu_read_unlock(); + goto next_port; + } + rcu_read_unlock(); + spin_lock_bh(&head->lock); /* Does not bother with rcv_saddr checks, because * the established check is already unique enough. */ inet_bind_bucket_for_each(tb, &head->chain) { - if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev && - tb->port == port) { + if (inet_bind_bucket_match(tb, net, port, l3mdev)) { if (tb->fastreuse >= 0 || tb->fastreuseport >= 0) - goto next_port; - WARN_ON(hlist_empty(&tb->owners)); + goto next_port_unlock; + WARN_ON(hlist_empty(&tb->bhash2)); if (!check_established(death_row, sk, - port, &tw)) + port, &tw, false, + hash_port0 + port)) goto ok; - goto next_port; + goto next_port_unlock; } } @@ -805,41 +1133,96 @@ other_parity_scan: spin_unlock_bh(&head->lock); return -ENOMEM; } + tb_created = true; tb->fastreuse = -1; tb->fastreuseport = -1; goto ok; -next_port: +next_port_unlock: spin_unlock_bh(&head->lock); +next_port: cond_resched(); } - offset++; - if ((offset & 1) && remaining > 1) - goto other_parity_scan; - + if (!local_ports) { + offset++; + if ((offset & 1) && remaining > 1) + goto other_parity_scan; + } return -EADDRNOTAVAIL; ok: - /* If our first attempt found a candidate, skip next candidate - * in 1/16 of cases to add some noise. + /* Find the corresponding tb2 bucket since we need to + * add the socket to the bhash2 table as well + */ + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + if (!tb2) { + tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net, + head2, tb, sk); + if (!tb2) + goto error; + tb2->fastreuse = -1; + tb2->fastreuseport = -1; + } + + /* Here we want to add a little bit of randomness to the next source + * port that will be chosen. We use a max() with a random here so that + * on low contention the randomness is maximal and on high contention + * it may be inexistent. */ - if (!i && !(prandom_u32() % 16)) - i = 2; - WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2); + i = max_t(int, i, get_random_u32_below(8) * step); + WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + step); /* Head lock still held and bh's disabled */ - inet_bind_hash(sk, tb, port); + inet_bind_hash(sk, tb, tb2, port); + sk->sk_userlocks |= SOCK_CONNECT_BIND; + if (sk_unhashed(sk)) { inet_sk(sk)->inet_sport = htons(port); inet_ehash_nolisten(sk, (struct sock *)tw, NULL); } if (tw) inet_twsk_bind_unhash(tw, hinfo); + + spin_unlock(&head2->lock); spin_unlock(&head->lock); + if (tw) inet_twsk_deschedule_put(tw); local_bh_enable(); return 0; + +error: + if (sk_hashed(sk)) { + spinlock_t *lock = inet_ehash_lockp(hinfo, sk->sk_hash); + + sock_prot_inuse_add(net, sk->sk_prot, -1); + + spin_lock(lock); + __sk_nulls_del_node_init_rcu(sk); + spin_unlock(lock); + + sk->sk_hash = 0; + inet_sk(sk)->inet_sport = 0; + inet_sk(sk)->inet_num = 0; + + if (tw) + inet_twsk_bind_unhash(tw, hinfo); + } + + spin_unlock(&head2->lock); + if (tb_created) + inet_bind_bucket_destroy(tb); + spin_unlock(&head->lock); + + if (tw) + inet_twsk_deschedule_put(tw); + + local_bh_enable(); + + return -ENOMEM; } /* @@ -848,29 +1231,20 @@ ok: int inet_hash_connect(struct inet_timewait_death_row *death_row, struct sock *sk) { - u32 port_offset = 0; + const struct inet_sock *inet = inet_sk(sk); + const struct net *net = sock_net(sk); + u64 port_offset = 0; + u32 hash_port0; if (!inet_sk(sk)->inet_num) port_offset = inet_sk_port_offset(sk); - return __inet_hash_connect(death_row, sk, port_offset, - __inet_check_established); -} -EXPORT_SYMBOL_GPL(inet_hash_connect); -void inet_hashinfo_init(struct inet_hashinfo *h) -{ - int i; + hash_port0 = inet_ehashfn(net, inet->inet_rcv_saddr, 0, + inet->inet_daddr, inet->inet_dport); - for (i = 0; i < INET_LHTABLE_SIZE; i++) { - spin_lock_init(&h->listening_hash[i].lock); - INIT_HLIST_NULLS_HEAD(&h->listening_hash[i].nulls_head, - i + LISTENING_NULLS_BASE); - h->listening_hash[i].count = 0; - } - - h->lhash2 = NULL; + return __inet_hash_connect(death_row, sk, port_offset, hash_port0, + __inet_check_established); } -EXPORT_SYMBOL_GPL(inet_hashinfo_init); static void init_hashinfo_lhash2(struct inet_hashinfo *h) { @@ -878,8 +1252,8 @@ static void init_hashinfo_lhash2(struct inet_hashinfo *h) for (i = 0; i <= h->lhash2_mask; i++) { spin_lock_init(&h->lhash2[i].lock); - INIT_HLIST_HEAD(&h->lhash2[i].head); - h->lhash2[i].count = 0; + INIT_HLIST_NULLS_HEAD(&h->lhash2[i].nulls_head, + i + LISTENING_NULLS_BASE); } } @@ -898,6 +1272,14 @@ void __init inet_hashinfo2_init(struct inet_hashinfo *h, const char *name, low_limit, high_limit); init_hashinfo_lhash2(h); + + /* this one is used for source ports of outgoing connections */ + table_perturb = alloc_large_system_hash("Table-perturb", + sizeof(*table_perturb), + INET_TABLE_PERTURB_SIZE, + 0, 0, NULL, NULL, + INET_TABLE_PERTURB_SIZE, + INET_TABLE_PERTURB_SIZE); } int inet_hashinfo2_init_mod(struct inet_hashinfo *h) @@ -913,29 +1295,87 @@ int inet_hashinfo2_init_mod(struct inet_hashinfo *h) init_hashinfo_lhash2(h); return 0; } -EXPORT_SYMBOL_GPL(inet_hashinfo2_init_mod); int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo) { unsigned int locksz = sizeof(spinlock_t); unsigned int i, nblocks = 1; + spinlock_t *ptr = NULL; - if (locksz != 0) { - /* allocate 2 cache lines or at least one spinlock per cpu */ - nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U); - nblocks = roundup_pow_of_two(nblocks * num_possible_cpus()); + if (locksz == 0) + goto set_mask; - /* no more locks than number of hash buckets */ - nblocks = min(nblocks, hashinfo->ehash_mask + 1); + /* Allocate 2 cache lines or at least one spinlock per cpu. */ + nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U) * num_possible_cpus(); - hashinfo->ehash_locks = kvmalloc_array(nblocks, locksz, GFP_KERNEL); - if (!hashinfo->ehash_locks) - return -ENOMEM; + /* At least one page per NUMA node. */ + nblocks = max(nblocks, num_online_nodes() * PAGE_SIZE / locksz); + + nblocks = roundup_pow_of_two(nblocks); - for (i = 0; i < nblocks; i++) - spin_lock_init(&hashinfo->ehash_locks[i]); + /* No more locks than number of hash buckets. */ + nblocks = min(nblocks, hashinfo->ehash_mask + 1); + + if (num_online_nodes() > 1) { + /* Use vmalloc() to allow NUMA policy to spread pages + * on all available nodes if desired. + */ + ptr = vmalloc_array(nblocks, locksz); + } + if (!ptr) { + ptr = kvmalloc_array(nblocks, locksz, GFP_KERNEL); + if (!ptr) + return -ENOMEM; } + for (i = 0; i < nblocks; i++) + spin_lock_init(&ptr[i]); + hashinfo->ehash_locks = ptr; +set_mask: hashinfo->ehash_locks_mask = nblocks - 1; return 0; } -EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc); + +struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo, + unsigned int ehash_entries) +{ + struct inet_hashinfo *new_hashinfo; + int i; + + new_hashinfo = kmemdup(hashinfo, sizeof(*hashinfo), GFP_KERNEL); + if (!new_hashinfo) + goto err; + + new_hashinfo->ehash = vmalloc_huge(ehash_entries * sizeof(struct inet_ehash_bucket), + GFP_KERNEL_ACCOUNT); + if (!new_hashinfo->ehash) + goto free_hashinfo; + + new_hashinfo->ehash_mask = ehash_entries - 1; + + if (inet_ehash_locks_alloc(new_hashinfo)) + goto free_ehash; + + for (i = 0; i < ehash_entries; i++) + INIT_HLIST_NULLS_HEAD(&new_hashinfo->ehash[i].chain, i); + + new_hashinfo->pernet = true; + + return new_hashinfo; + +free_ehash: + vfree(new_hashinfo->ehash); +free_hashinfo: + kfree(new_hashinfo); +err: + return NULL; +} + +void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo) +{ + if (!hashinfo->pernet) + return; + + inet_ehash_locks_free(hashinfo); + vfree(hashinfo->ehash); + kfree(hashinfo); +} diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c index 437afe392e66..d4c781a0667f 100644 --- a/net/ipv4/inet_timewait_sock.c +++ b/net/ipv4/inet_timewait_sock.c @@ -15,7 +15,8 @@ #include <net/inet_hashtables.h> #include <net/inet_timewait_sock.h> #include <net/ip.h> - +#include <net/tcp.h> +#include <net/psp.h> /** * inet_twsk_bind_unhash - unhash a timewait socket from bind hash @@ -29,14 +30,18 @@ void inet_twsk_bind_unhash(struct inet_timewait_sock *tw, struct inet_hashinfo *hashinfo) { + struct inet_bind2_bucket *tb2 = tw->tw_tb2; struct inet_bind_bucket *tb = tw->tw_tb; if (!tb) return; - __hlist_del(&tw->tw_bind_node); + __sk_del_bind_node((struct sock *)tw); tw->tw_tb = NULL; - inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + tw->tw_tb2 = NULL; + inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); + inet_bind_bucket_destroy(tb); + __sock_put((struct sock *)tw); } @@ -45,7 +50,7 @@ static void inet_twsk_kill(struct inet_timewait_sock *tw) { struct inet_hashinfo *hashinfo = tw->tw_dr->hashinfo; spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash); - struct inet_bind_hashbucket *bhead; + struct inet_bind_hashbucket *bhead, *bhead2; spin_lock(lock); sk_nulls_del_node_init_rcu((struct sock *)tw); @@ -54,22 +59,24 @@ static void inet_twsk_kill(struct inet_timewait_sock *tw) /* Disassociate with bind bucket. */ bhead = &hashinfo->bhash[inet_bhashfn(twsk_net(tw), tw->tw_num, hashinfo->bhash_size)]; + bhead2 = inet_bhashfn_portaddr(hashinfo, (struct sock *)tw, + twsk_net(tw), tw->tw_num); spin_lock(&bhead->lock); + spin_lock(&bhead2->lock); inet_twsk_bind_unhash(tw, hashinfo); + spin_unlock(&bhead2->lock); spin_unlock(&bhead->lock); - atomic_dec(&tw->tw_dr->tw_count); + refcount_dec(&tw->tw_dr->tw_refcount); inet_twsk_put(tw); } void inet_twsk_free(struct inet_timewait_sock *tw) { struct module *owner = tw->tw_prot->owner; - twsk_destructor((struct sock *)tw); -#ifdef SOCK_REFCNT_DEBUG - pr_debug("%s timewait_sock %p released\n", tw->tw_prot->name, tw); -#endif + + tcp_twsk_destructor((struct sock *)tw); kmem_cache_free(tw->tw_prot->twsk_prot->twsk_slab, tw); module_put(owner); } @@ -81,74 +88,80 @@ void inet_twsk_put(struct inet_timewait_sock *tw) } EXPORT_SYMBOL_GPL(inet_twsk_put); -static void inet_twsk_add_node_rcu(struct inet_timewait_sock *tw, - struct hlist_nulls_head *list) +static void inet_twsk_schedule(struct inet_timewait_sock *tw, int timeo) { - hlist_nulls_add_head_rcu(&tw->tw_node, list); -} - -static void inet_twsk_add_bind_node(struct inet_timewait_sock *tw, - struct hlist_head *list) -{ - hlist_add_head(&tw->tw_bind_node, list); + __inet_twsk_schedule(tw, timeo, false); } /* - * Enter the time wait state. This is called with locally disabled BH. + * Enter the time wait state. * Essentially we whip up a timewait bucket, copy the relevant info into it * from the SK, and mess with hash chains and list linkage. + * + * The caller must not access @tw anymore after this function returns. */ -void inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk, - struct inet_hashinfo *hashinfo) +void inet_twsk_hashdance_schedule(struct inet_timewait_sock *tw, + struct sock *sk, + struct inet_hashinfo *hashinfo, + int timeo) { const struct inet_sock *inet = inet_sk(sk); const struct inet_connection_sock *icsk = inet_csk(sk); - struct inet_ehash_bucket *ehead = inet_ehash_bucket(hashinfo, sk->sk_hash); spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); - struct inet_bind_hashbucket *bhead; - /* Step 1: Put TW into bind hash. Original socket stays there too. - Note, that any socket with inet->num != 0 MUST be bound in - binding cache, even if it is closed. + struct inet_bind_hashbucket *bhead, *bhead2; + + /* Put TW into bind hash. Original socket stays there too. + * Note, that any socket with inet->num != 0 MUST be bound in + * binding cache, even if it is closed. */ bhead = &hashinfo->bhash[inet_bhashfn(twsk_net(tw), inet->inet_num, hashinfo->bhash_size)]; + bhead2 = inet_bhashfn_portaddr(hashinfo, sk, twsk_net(tw), inet->inet_num); + + local_bh_disable(); spin_lock(&bhead->lock); + spin_lock(&bhead2->lock); + tw->tw_tb = icsk->icsk_bind_hash; WARN_ON(!icsk->icsk_bind_hash); - inet_twsk_add_bind_node(tw, &tw->tw_tb->owners); - spin_unlock(&bhead->lock); - - spin_lock(lock); - inet_twsk_add_node_rcu(tw, &ehead->chain); + tw->tw_tb2 = icsk->icsk_bind2_hash; + WARN_ON(!icsk->icsk_bind2_hash); + sk_add_bind_node((struct sock *)tw, &tw->tw_tb2->owners); - /* Step 3: Remove SK from hash chain */ - if (__sk_nulls_del_node_init_rcu(sk)) - sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + spin_unlock(&bhead2->lock); + spin_unlock(&bhead->lock); - spin_unlock(lock); + spin_lock(lock); /* tw_refcnt is set to 3 because we have : * - one reference for bhash chain. * - one reference for ehash chain. * - one reference for timer. - * We can use atomic_set() because prior spin_lock()/spin_unlock() - * committed into memory all tw fields. * Also note that after this point, we lost our implicit reference * so we are not allowed to use tw anymore. */ refcount_set(&tw->tw_refcnt, 3); + + /* Ensure tw_refcnt has been set before tw is published. + * smp_wmb() provides the necessary memory barrier to enforce this + * ordering. + */ + smp_wmb(); + + hlist_nulls_replace_init_rcu(&sk->sk_nulls_node, &tw->tw_node); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + + inet_twsk_schedule(tw, timeo); + + spin_unlock(lock); + local_bh_enable(); } -EXPORT_SYMBOL_GPL(inet_twsk_hashdance); static void tw_timer_handler(struct timer_list *t) { - struct inet_timewait_sock *tw = from_timer(tw, t, tw_timer); + struct inet_timewait_sock *tw = timer_container_of(tw, t, tw_timer); - if (tw->tw_kill) - __NET_INC_STATS(twsk_net(tw), LINUX_MIB_TIMEWAITKILLED); - else - __NET_INC_STATS(twsk_net(tw), LINUX_MIB_TIMEWAITED); inet_twsk_kill(tw); } @@ -158,7 +171,8 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, { struct inet_timewait_sock *tw; - if (atomic_read(&dr->tw_count) >= dr->sysctl_max_tw_buckets) + if (refcount_read(&dr->tw_refcount) - 1 >= + READ_ONCE(dr->sysctl_max_tw_buckets)) return NULL; tw = kmem_cache_alloc(sk->sk_prot_creator->twsk_prot->twsk_slab, @@ -182,11 +196,15 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, tw->tw_reuseport = sk->sk_reuseport; tw->tw_hash = sk->sk_hash; tw->tw_ipv6only = 0; - tw->tw_transparent = inet->transparent; + tw->tw_transparent = inet_test_bit(TRANSPARENT, sk); + tw->tw_connect_bind = !!(sk->sk_userlocks & SOCK_CONNECT_BIND); tw->tw_prot = sk->sk_prot_creator; atomic64_set(&tw->tw_cookie, atomic64_read(&sk->sk_cookie)); twsk_net_set(tw, sock_net(sk)); - timer_setup(&tw->tw_timer, tw_timer_handler, TIMER_PINNED); + timer_setup(&tw->tw_timer, tw_timer_handler, 0); +#ifdef CONFIG_SOCK_VALIDATE_XMIT + tw->tw_validate_xmit_skb = NULL; +#endif /* * Because we use RCU lookups, we should not set tw_refcnt * to a non null value before everything is setup for this @@ -195,11 +213,11 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, refcount_set(&tw->tw_refcnt, 0); __module_get(tw->tw_prot->owner); + psp_twsk_init(tw, sk); } return tw; } -EXPORT_SYMBOL_GPL(inet_twsk_alloc); /* These are always called from BH context. See callers in * tcp_input.c to verify this. @@ -211,7 +229,34 @@ EXPORT_SYMBOL_GPL(inet_twsk_alloc); */ void inet_twsk_deschedule_put(struct inet_timewait_sock *tw) { - if (del_timer_sync(&tw->tw_timer)) + struct inet_hashinfo *hashinfo = tw->tw_dr->hashinfo; + spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash); + + /* inet_twsk_purge() walks over all sockets, including tw ones, + * and removes them via inet_twsk_deschedule_put() after a + * refcount_inc_not_zero(). + * + * inet_twsk_hashdance_schedule() must (re)init the refcount before + * arming the timer, i.e. inet_twsk_purge can obtain a reference to + * a twsk that did not yet schedule the timer. + * + * The ehash lock synchronizes these two: + * After acquiring the lock, the timer is always scheduled (else + * timer_shutdown returns false), because hashdance_schedule releases + * the ehash lock only after completing the timer initialization. + * + * Without grabbing the ehash lock, we get: + * 1) cpu x sets twsk refcount to 3 + * 2) cpu y bumps refcount to 4 + * 3) cpu y calls inet_twsk_deschedule_put() and shuts timer down + * 4) cpu x tries to start timer, but mod_timer is a noop post-shutdown + * -> timer refcount is never decremented. + */ + spin_lock(lock); + /* Makes sure hashdance_schedule() has completed */ + spin_unlock(lock); + + if (timer_shutdown_sync(&tw->tw_timer)) inet_twsk_kill(tw); inet_twsk_put(tw); } @@ -244,49 +289,63 @@ void __inet_twsk_schedule(struct inet_timewait_sock *tw, int timeo, bool rearm) * of PAWS. */ - tw->tw_kill = timeo <= 4*HZ; if (!rearm) { + bool kill = timeo <= 4*HZ; + + __NET_INC_STATS(twsk_net(tw), kill ? LINUX_MIB_TIMEWAITKILLED : + LINUX_MIB_TIMEWAITED); BUG_ON(mod_timer(&tw->tw_timer, jiffies + timeo)); - atomic_inc(&tw->tw_dr->tw_count); + refcount_inc(&tw->tw_dr->tw_refcount); } else { mod_timer_pending(&tw->tw_timer, jiffies + timeo); } } -EXPORT_SYMBOL_GPL(__inet_twsk_schedule); -void inet_twsk_purge(struct inet_hashinfo *hashinfo, int family) +/* Remove all non full sockets (TIME_WAIT and NEW_SYN_RECV) for dead netns */ +void inet_twsk_purge(struct inet_hashinfo *hashinfo) { - struct inet_timewait_sock *tw; - struct sock *sk; + struct inet_ehash_bucket *head = &hashinfo->ehash[0]; + unsigned int ehash_mask = hashinfo->ehash_mask; struct hlist_nulls_node *node; unsigned int slot; + struct sock *sk; + + for (slot = 0; slot <= ehash_mask; slot++, head++) { + if (hlist_nulls_empty(&head->chain)) + continue; - for (slot = 0; slot <= hashinfo->ehash_mask; slot++) { - struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; restart_rcu: cond_resched(); rcu_read_lock(); restart: sk_nulls_for_each_rcu(sk, node, &head->chain) { - if (sk->sk_state != TCP_TIME_WAIT) + int state = inet_sk_state_load(sk); + + if ((1 << state) & ~(TCPF_TIME_WAIT | + TCPF_NEW_SYN_RECV)) continue; - tw = inet_twsk(sk); - if ((tw->tw_family != family) || - refcount_read(&twsk_net(tw)->ns.count)) + + if (check_net(sock_net(sk))) continue; - if (unlikely(!refcount_inc_not_zero(&tw->tw_refcnt))) + if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) continue; - if (unlikely((tw->tw_family != family) || - refcount_read(&twsk_net(tw)->ns.count))) { - inet_twsk_put(tw); + if (check_net(sock_net(sk))) { + sock_gen_put(sk); goto restart; } rcu_read_unlock(); local_bh_disable(); - inet_twsk_deschedule_put(tw); + if (state == TCP_TIME_WAIT) { + inet_twsk_deschedule_put(inet_twsk(sk)); + } else { + struct request_sock *req = inet_reqsk(sk); + + inet_csk_reqsk_queue_drop_and_put(req->rsk_listener, + req); + } local_bh_enable(); goto restart_rcu; } @@ -299,4 +358,3 @@ restart: rcu_read_unlock(); } } -EXPORT_SYMBOL_GPL(inet_twsk_purge); diff --git a/net/ipv4/inetpeer.c b/net/ipv4/inetpeer.c index da21dfce24d7..7b1e0a2d6906 100644 --- a/net/ipv4/inetpeer.c +++ b/net/ipv4/inetpeer.c @@ -60,7 +60,7 @@ void inet_peer_base_init(struct inet_peer_base *bp) seqlock_init(&bp->lock); bp->total = 0; } -EXPORT_SYMBOL_GPL(inet_peer_base_init); +EXPORT_IPV6_MOD_GPL(inet_peer_base_init); #define PEER_MAX_GC 32 @@ -81,10 +81,7 @@ void __init inet_initpeers(void) inet_peer_threshold = clamp_val(nr_entries, 4096, 65536 + 128); - peer_cachep = kmem_cache_create("inet_peer_cache", - sizeof(struct inet_peer), - 0, SLAB_HWCACHE_ALIGN | SLAB_PANIC, - NULL); + peer_cachep = KMEM_CACHE(inet_peer, SLAB_HWCACHE_ALIGN | SLAB_PANIC); } /* Called with rcu_read_lock() or base->lock held */ @@ -98,6 +95,7 @@ static struct inet_peer *lookup(const struct inetpeer_addr *daddr, { struct rb_node **pp, *parent, *next; struct inet_peer *p; + u32 now; pp = &base->rb_root.rb_node; parent = NULL; @@ -111,8 +109,9 @@ static struct inet_peer *lookup(const struct inetpeer_addr *daddr, p = rb_entry(parent, struct inet_peer, rb_node); cmp = inetpeer_addr_cmp(daddr, &p->daddr); if (cmp == 0) { - if (!refcount_inc_not_zero(&p->refcnt)) - break; + now = jiffies; + if (READ_ONCE(p->dtime) != now) + WRITE_ONCE(p->dtime, now); return p; } if (gc_stack) { @@ -131,32 +130,28 @@ static struct inet_peer *lookup(const struct inetpeer_addr *daddr, return NULL; } -static void inetpeer_free_rcu(struct rcu_head *head) -{ - kmem_cache_free(peer_cachep, container_of(head, struct inet_peer, rcu)); -} - /* perform garbage collect on all items stacked during a lookup */ static void inet_peer_gc(struct inet_peer_base *base, struct inet_peer *gc_stack[], unsigned int gc_cnt) { + int peer_threshold, peer_maxttl, peer_minttl; struct inet_peer *p; __u32 delta, ttl; int i; - if (base->total >= inet_peer_threshold) + peer_threshold = READ_ONCE(inet_peer_threshold); + peer_maxttl = READ_ONCE(inet_peer_maxttl); + peer_minttl = READ_ONCE(inet_peer_minttl); + + if (base->total >= peer_threshold) ttl = 0; /* be aggressive */ else - ttl = inet_peer_maxttl - - (inet_peer_maxttl - inet_peer_minttl) / HZ * - base->total / inet_peer_threshold * HZ; + ttl = peer_maxttl - (peer_maxttl - peer_minttl) / HZ * + base->total / peer_threshold * HZ; for (i = 0; i < gc_cnt; i++) { p = gc_stack[i]; - /* The READ_ONCE() pairs with the WRITE_ONCE() - * in inet_putpeer() - */ delta = (__u32)jiffies - READ_ONCE(p->dtime); if (delta < ttl || !refcount_dec_if_one(&p->refcnt)) @@ -167,36 +162,28 @@ static void inet_peer_gc(struct inet_peer_base *base, if (p) { rb_erase(&p->rb_node, &base->rb_root); base->total--; - call_rcu(&p->rcu, inetpeer_free_rcu); + kfree_rcu(p, rcu); } } } +/* Must be called under RCU : No refcount change is done here. */ struct inet_peer *inet_getpeer(struct inet_peer_base *base, - const struct inetpeer_addr *daddr, - int create) + const struct inetpeer_addr *daddr) { struct inet_peer *p, *gc_stack[PEER_MAX_GC]; struct rb_node **pp, *parent; unsigned int gc_cnt, seq; - int invalidated; /* Attempt a lockless lookup first. * Because of a concurrent writer, we might not find an existing entry. */ - rcu_read_lock(); seq = read_seqbegin(&base->lock); p = lookup(daddr, base, seq, NULL, &gc_cnt, &parent, &pp); - invalidated = read_seqretry(&base->lock, seq); - rcu_read_unlock(); if (p) return p; - /* If no writer did a change during our lookup, we can return early. */ - if (!create && !invalidated) - return NULL; - /* retry an exact lookup, taking the lock before. * At least, nodes should be hot in our cache. */ @@ -205,12 +192,12 @@ struct inet_peer *inet_getpeer(struct inet_peer_base *base, gc_cnt = 0; p = lookup(daddr, base, seq, gc_stack, &gc_cnt, &parent, &pp); - if (!p && create) { + if (!p) { p = kmem_cache_alloc(peer_cachep, GFP_ATOMIC); if (p) { p->daddr = *daddr; p->dtime = (__u32)jiffies; - refcount_set(&p->refcnt, 2); + refcount_set(&p->refcnt, 1); atomic_set(&p->rid, 0); p->metrics[RTAX_LOCK-1] = INETPEER_METRICS_NEW; p->rate_tokens = 0; @@ -231,19 +218,13 @@ struct inet_peer *inet_getpeer(struct inet_peer_base *base, return p; } -EXPORT_SYMBOL_GPL(inet_getpeer); +EXPORT_IPV6_MOD_GPL(inet_getpeer); void inet_putpeer(struct inet_peer *p) { - /* The WRITE_ONCE() pairs with itself (we run lockless) - * and the READ_ONCE() in inet_peer_gc() - */ - WRITE_ONCE(p->dtime, (__u32)jiffies); - if (refcount_dec_and_test(&p->refcnt)) - call_rcu(&p->rcu, inetpeer_free_rcu); + kfree_rcu(p, rcu); } -EXPORT_SYMBOL_GPL(inet_putpeer); /* * Check transmit rate limitation for given message. @@ -265,26 +246,30 @@ EXPORT_SYMBOL_GPL(inet_putpeer); #define XRLIM_BURST_FACTOR 6 bool inet_peer_xrlim_allow(struct inet_peer *peer, int timeout) { - unsigned long now, token; + unsigned long now, token, otoken, delta; bool rc = false; if (!peer) return true; - token = peer->rate_tokens; + token = otoken = READ_ONCE(peer->rate_tokens); now = jiffies; - token += now - peer->rate_last; - peer->rate_last = now; - if (token > XRLIM_BURST_FACTOR * timeout) - token = XRLIM_BURST_FACTOR * timeout; + delta = now - READ_ONCE(peer->rate_last); + if (delta) { + WRITE_ONCE(peer->rate_last, now); + token += delta; + if (token > XRLIM_BURST_FACTOR * timeout) + token = XRLIM_BURST_FACTOR * timeout; + } if (token >= timeout) { token -= timeout; rc = true; } - peer->rate_tokens = token; + if (token != otoken) + WRITE_ONCE(peer->rate_tokens, token); return rc; } -EXPORT_SYMBOL(inet_peer_xrlim_allow); +EXPORT_IPV6_MOD(inet_peer_xrlim_allow); void inetpeer_invalidate_tree(struct inet_peer_base *base) { @@ -301,4 +286,4 @@ void inetpeer_invalidate_tree(struct inet_peer_base *base) base->total = 0; } -EXPORT_SYMBOL(inetpeer_invalidate_tree); +EXPORT_IPV6_MOD(inetpeer_invalidate_tree); diff --git a/net/ipv4/ip_forward.c b/net/ipv4/ip_forward.c index 00ec819f949b..8b65f12583eb 100644 --- a/net/ipv4/ip_forward.c +++ b/net/ipv4/ip_forward.c @@ -66,9 +66,6 @@ static int ip_forward_finish(struct net *net, struct sock *sk, struct sk_buff *s { struct ip_options *opt = &(IPCB(skb)->opt); - __IP_INC_STATS(net, IPSTATS_MIB_OUTFORWDATAGRAMS); - __IP_ADD_STATS(net, IPSTATS_MIB_OUTOCTETS, skb->len); - #ifdef CONFIG_NET_SWITCHDEV if (skb->offload_l3_fwd_mark) { consume_skb(skb); @@ -79,7 +76,7 @@ static int ip_forward_finish(struct net *net, struct sock *sk, struct sk_buff *s if (unlikely(opt->optlen)) ip_forward_options(skb); - skb->tstamp = 0; + skb_clear_tstamp(skb); return dst_output(net, sk, skb); } @@ -90,6 +87,7 @@ int ip_forward(struct sk_buff *skb) struct rtable *rt; /* Route we use */ struct ip_options *opt = &(IPCB(skb)->opt); struct net *net; + SKB_DR(reason); /* that should never happen */ if (skb->pkt_type != PACKET_HOST) @@ -101,8 +99,10 @@ int ip_forward(struct sk_buff *skb) if (skb_warn_if_lro(skb)) goto drop; - if (!xfrm4_policy_check(NULL, XFRM_POLICY_FWD, skb)) + if (!xfrm4_policy_check(NULL, XFRM_POLICY_FWD, skb)) { + SKB_DR_SET(reason, XFRM_POLICY); goto drop; + } if (IPCB(skb)->opt.router_alert && ip_call_ra_chain(skb)) return NET_RX_SUCCESS; @@ -118,20 +118,25 @@ int ip_forward(struct sk_buff *skb) if (ip_hdr(skb)->ttl <= 1) goto too_many_hops; - if (!xfrm4_route_forward(skb)) + if (!xfrm4_route_forward(skb)) { + SKB_DR_SET(reason, XFRM_POLICY); goto drop; + } rt = skb_rtable(skb); if (opt->is_strictroute && rt->rt_uses_gateway) goto sr_failed; + __IP_INC_STATS(net, IPSTATS_MIB_OUTFORWDATAGRAMS); + IPCB(skb)->flags |= IPSKB_FORWARDED; mtu = ip_dst_mtu_maybe_forward(&rt->dst, true); if (ip_exceeds_mtu(skb, mtu)) { IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); icmp_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, htonl(mtu)); + SKB_DR_SET(reason, PKT_TOO_BIG); goto drop; } @@ -151,7 +156,7 @@ int ip_forward(struct sk_buff *skb) !skb_sec_path(skb)) ip_rt_send_redirect(skb); - if (net->ipv4.sysctl_ip_fwd_update_priority) + if (READ_ONCE(net->ipv4.sysctl_ip_fwd_update_priority)) skb->priority = rt_tos2priority(iph->tos); return NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, @@ -169,7 +174,8 @@ too_many_hops: /* Tell the sender its packet died... */ __IP_INC_STATS(net, IPSTATS_MIB_INHDRERRORS); icmp_send(skb, ICMP_TIME_EXCEEDED, ICMP_EXC_TTL, 0); + SKB_DR_SET(reason, IP_INHDR); drop: - kfree_skb(skb); + kfree_skb_reason(skb, reason); return NET_RX_DROP; } diff --git a/net/ipv4/ip_fragment.c b/net/ipv4/ip_fragment.c index fad803d2d711..f7012479713b 100644 --- a/net/ipv4/ip_fragment.c +++ b/net/ipv4/ip_fragment.c @@ -76,21 +76,27 @@ static u8 ip4_frag_ecn(u8 tos) static struct inet_frags ip4_frags; static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, - struct sk_buff *prev_tail, struct net_device *dev); + struct sk_buff *prev_tail, struct net_device *dev, + int *refs); static void ip4_frag_init(struct inet_frag_queue *q, const void *a) { struct ipq *qp = container_of(q, struct ipq, q); - struct net *net = q->fqdir->net; - const struct frag_v4_compare_key *key = a; + struct net *net = q->fqdir->net; + struct inet_peer *p = NULL; q->key.v4 = *key; qp->ecn = 0; - qp->peer = q->fqdir->max_dist ? - inet_getpeer_v4(net->ipv4.peers, key->saddr, key->vif, 1) : - NULL; + if (q->fqdir->max_dist) { + rcu_read_lock(); + p = inet_getpeer_v4(net->ipv4.peers, key->saddr, key->vif); + if (p && !refcount_inc_not_zero(&p->refcnt)) + p = NULL; + rcu_read_unlock(); + } + qp->peer = p; } static void ip4_frag_free(struct inet_frag_queue *q) @@ -102,22 +108,6 @@ static void ip4_frag_free(struct inet_frag_queue *q) inet_putpeer(qp->peer); } - -/* Destruction primitives. */ - -static void ipq_put(struct ipq *ipq) -{ - inet_frag_put(&ipq->q); -} - -/* Kill ipq entry. It is not destroyed immediately, - * because caller (and someone more) holds reference count. - */ -static void ipq_kill(struct ipq *ipq) -{ - inet_frag_kill(&ipq->q); -} - static bool frag_expire_skip_icmp(u32 user) { return user == IP_DEFRAG_AF_PACKET || @@ -132,12 +122,13 @@ static bool frag_expire_skip_icmp(u32 user) */ static void ip_expire(struct timer_list *t) { - struct inet_frag_queue *frag = from_timer(frag, t, timer); + enum skb_drop_reason reason = SKB_DROP_REASON_FRAG_REASM_TIMEOUT; + struct inet_frag_queue *frag = timer_container_of(frag, t, timer); const struct iphdr *iph; struct sk_buff *head = NULL; struct net *net; struct ipq *qp; - int err; + int refs = 1; qp = container_of(frag, struct ipq, q); net = qp->q.fqdir->net; @@ -153,7 +144,8 @@ static void ip_expire(struct timer_list *t) if (qp->q.flags & INET_FRAG_COMPLETE) goto out; - ipq_kill(qp); + qp->q.flags |= INET_FRAG_DROP; + inet_frag_kill(&qp->q, &refs); __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); __IP_INC_STATS(net, IPSTATS_MIB_REASMTIMEOUT); @@ -174,14 +166,15 @@ static void ip_expire(struct timer_list *t) /* skb has no dst, perform route lookup again */ iph = ip_hdr(head); - err = ip_route_input_noref(head, iph->daddr, iph->saddr, - iph->tos, head->dev); - if (err) + reason = ip_route_input_noref(head, iph->daddr, iph->saddr, + ip4h_dscp(iph), head->dev); + if (reason) goto out; /* Only an end host needs to send an ICMP * "Fragment Reassembly Timeout" message, per RFC792. */ + reason = SKB_DROP_REASON_FRAG_REASM_TIMEOUT; if (frag_expire_skip_icmp(qp->q.key.v4.user) && (skb_rtable(head)->rt_type != RTN_LOCAL)) goto out; @@ -194,8 +187,8 @@ out: spin_unlock(&qp->q.lock); out_rcu_unlock: rcu_read_unlock(); - kfree_skb(head); - ipq_put(qp); + kfree_skb_reason(head, reason); + inet_frag_putn(&qp->q, refs); } /* Find the correct entry in the "incomplete datagrams" queue for @@ -254,7 +247,8 @@ static int ip_frag_reinit(struct ipq *qp) return -ETIMEDOUT; } - sum_truesize = inet_frag_rbtree_purge(&qp->q.rb_fragments); + sum_truesize = inet_frag_rbtree_purge(&qp->q.rb_fragments, + SKB_DROP_REASON_FRAG_TOO_FAR); sub_frag_mem_limit(qp->q.fqdir, sum_truesize); qp->q.flags = 0; @@ -270,7 +264,7 @@ static int ip_frag_reinit(struct ipq *qp) } /* Add new segment to existing queue. */ -static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) +static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb, int *refs) { struct net *net = qp->q.fqdir->net; int ihl, end, flags, offset; @@ -278,15 +272,19 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) struct net_device *dev; unsigned int fragsize; int err = -ENOENT; + SKB_DR(reason); u8 ecn; - if (qp->q.flags & INET_FRAG_COMPLETE) + /* If reassembly is already done, @skb must be a duplicate frag. */ + if (qp->q.flags & INET_FRAG_COMPLETE) { + SKB_DR_SET(reason, DUP_FRAG); goto err; + } if (!(IPCB(skb)->flags & IPSKB_FRAG_COMPLETE) && unlikely(ip_frag_too_far(qp)) && unlikely(err = ip_frag_reinit(qp))) { - ipq_kill(qp); + inet_frag_kill(&qp->q, refs); goto err; } @@ -349,6 +347,7 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) qp->iif = dev->ifindex; qp->q.stamp = skb->tstamp; + qp->q.tstamp_type = skb->tstamp_type; qp->q.meat += skb->len; qp->ecn |= ecn; add_frag_mem_limit(qp->q.fqdir, skb->truesize); @@ -369,28 +368,30 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) unsigned long orefdst = skb->_skb_refdst; skb->_skb_refdst = 0UL; - err = ip_frag_reasm(qp, skb, prev_tail, dev); + err = ip_frag_reasm(qp, skb, prev_tail, dev, refs); skb->_skb_refdst = orefdst; if (err) - inet_frag_kill(&qp->q); + inet_frag_kill(&qp->q, refs); return err; } skb_dst_drop(skb); + skb_orphan(skb); return -EINPROGRESS; insert_error: if (err == IPFRAG_DUP) { - kfree_skb(skb); - return -EINVAL; + SKB_DR_SET(reason, DUP_FRAG); + err = -EINVAL; + goto err; } err = -EINVAL; __IP_INC_STATS(net, IPSTATS_MIB_REASM_OVERLAPS); discard_qp: - inet_frag_kill(&qp->q); + inet_frag_kill(&qp->q, refs); __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); err: - kfree_skb(skb); + kfree_skb_reason(skb, reason); return err; } @@ -401,7 +402,8 @@ static bool ip_frag_coalesce_ok(const struct ipq *qp) /* Build a new IP datagram from all its fragments. */ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, - struct sk_buff *prev_tail, struct net_device *dev) + struct sk_buff *prev_tail, struct net_device *dev, + int *refs) { struct net *net = qp->q.fqdir->net; struct iphdr *iph; @@ -409,7 +411,7 @@ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, int len, err; u8 ecn; - ipq_kill(qp); + inet_frag_kill(&qp->q, refs); ecn = ip_frag_ecn_table[qp->ecn]; if (unlikely(ecn == 0xff)) { @@ -474,26 +476,30 @@ out_fail: /* Process an incoming IP datagram fragment. */ int ip_defrag(struct net *net, struct sk_buff *skb, u32 user) { - struct net_device *dev = skb->dev ? : skb_dst(skb)->dev; - int vif = l3mdev_master_ifindex_rcu(dev); + struct net_device *dev; struct ipq *qp; + int vif; __IP_INC_STATS(net, IPSTATS_MIB_REASMREQDS); - skb_orphan(skb); /* Lookup (or create) queue header */ + rcu_read_lock(); + dev = skb->dev ? : skb_dst_dev_rcu(skb); + vif = l3mdev_master_ifindex_rcu(dev); qp = ip_find(net, ip_hdr(skb), user, vif); if (qp) { - int ret; + int ret, refs = 0; spin_lock(&qp->q.lock); - ret = ip_frag_queue(qp, skb); + ret = ip_frag_queue(qp, skb, &refs); spin_unlock(&qp->q.lock); - ipq_put(qp); + rcu_read_unlock(); + inet_frag_putn(&qp->q, refs); return ret; } + rcu_read_unlock(); __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); kfree_skb(skb); @@ -572,7 +578,6 @@ static struct ctl_table ip4_frags_ns_ctl_table[] = { .proc_handler = proc_dointvec_minmax, .extra1 = &dist_min, }, - { } }; /* secret interval has been deprecated */ @@ -585,7 +590,6 @@ static struct ctl_table ip4_frags_ctl_table[] = { .mode = 0644, .proc_handler = proc_dointvec_jiffies, }, - { } }; static int __net_init ip4_frags_ns_ctl_register(struct net *net) @@ -607,7 +611,8 @@ static int __net_init ip4_frags_ns_ctl_register(struct net *net) table[2].data = &net->ipv4.fqdir->timeout; table[3].data = &net->ipv4.fqdir->max_dist; - hdr = register_net_sysctl(net, "net/ipv4", table); + hdr = register_net_sysctl_sz(net, "net/ipv4", table, + ARRAY_SIZE(ip4_frags_ns_ctl_table)); if (!hdr) goto err_reg; @@ -623,7 +628,7 @@ err_alloc: static void __net_exit ip4_frags_ns_ctl_unregister(struct net *net) { - struct ctl_table *table; + const struct ctl_table *table; table = net->ipv4.frags_hdr->ctl_table_arg; unregister_net_sysctl_table(net->ipv4.frags_hdr); diff --git a/net/ipv4/ip_gre.c b/net/ipv4/ip_gre.c index 99db2e41ed10..761a53c6a89a 100644 --- a/net/ipv4/ip_gre.c +++ b/net/ipv4/ip_gre.c @@ -28,6 +28,7 @@ #include <linux/etherdevice.h> #include <linux/if_ether.h> +#include <net/flow.h> #include <net/sock.h> #include <net/ip.h> #include <net/icmp.h> @@ -140,7 +141,6 @@ static int ipgre_err(struct sk_buff *skb, u32 info, const struct iphdr *iph; const int type = icmp_hdr(skb)->type; const int code = icmp_hdr(skb)->code; - unsigned int data_len = 0; struct ip_tunnel *t; if (tpi->proto == htons(ETH_P_TEB)) @@ -181,7 +181,6 @@ static int ipgre_err(struct sk_buff *skb, u32 info, case ICMP_TIME_EXCEEDED: if (code != ICMP_EXC_TTL) return 0; - data_len = icmp_hdr(skb)->un.reserved[1] * 4; /* RFC 4884 4.1 */ break; case ICMP_REDIRECT: @@ -189,10 +188,16 @@ static int ipgre_err(struct sk_buff *skb, u32 info, } #if IS_ENABLED(CONFIG_IPV6) - if (tpi->proto == htons(ETH_P_IPV6) && - !ip6_err_gen_icmpv6_unreach(skb, iph->ihl * 4 + tpi->hdr_len, - type, data_len)) - return 0; + if (tpi->proto == htons(ETH_P_IPV6)) { + unsigned int data_len = 0; + + if (type == ICMP_TIME_EXCEEDED) + data_len = icmp_hdr(skb)->un.reserved[1] * 4; /* RFC 4884 4.1 */ + + if (!ip6_err_gen_icmpv6_unreach(skb, iph->ihl * 4 + tpi->hdr_len, + type, data_len)) + return 0; + } #endif if (t->parms.iph.daddr == 0 || @@ -265,6 +270,7 @@ static int erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi, struct net *net = dev_net(skb->dev); struct metadata_dst *tun_dst = NULL; struct erspan_base_hdr *ershdr; + IP_TUNNEL_DECLARE_FLAGS(flags); struct ip_tunnel_net *itn; struct ip_tunnel *tunnel; const struct iphdr *iph; @@ -272,18 +278,25 @@ static int erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi, int ver; int len; + ip_tunnel_flags_copy(flags, tpi->flags); + itn = net_generic(net, erspan_net_id); iph = ip_hdr(skb); if (is_erspan_type1(gre_hdr_len)) { ver = 0; - tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, - tpi->flags | TUNNEL_NO_KEY, + __set_bit(IP_TUNNEL_NO_KEY_BIT, flags); + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, flags, iph->saddr, iph->daddr, 0); } else { + if (unlikely(!pskb_may_pull(skb, + gre_hdr_len + sizeof(*ershdr)))) + return PACKET_REJECT; + ershdr = (struct erspan_base_hdr *)(skb->data + gre_hdr_len); ver = ershdr->ver; - tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, - tpi->flags | TUNNEL_KEY, + iph = ip_hdr(skb); + __set_bit(IP_TUNNEL_KEY_BIT, flags); + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, flags, iph->saddr, iph->daddr, tpi->key); } @@ -307,10 +320,9 @@ static int erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi, struct ip_tunnel_info *info; unsigned char *gh; __be64 tun_id; - __be16 flags; - tpi->flags |= TUNNEL_KEY; - flags = tpi->flags; + __set_bit(IP_TUNNEL_KEY_BIT, tpi->flags); + ip_tunnel_flags_copy(flags, tpi->flags); tun_id = key32_to_tunnel_id(tpi->key); tun_dst = ip_tun_rx_dst(skb, flags, @@ -333,7 +345,8 @@ static int erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi, ERSPAN_V2_MDSIZE); info = &tun_dst->u.tun_info; - info->key.tun_flags |= TUNNEL_ERSPAN_OPT; + __set_bit(IP_TUNNEL_ERSPAN_OPT_BIT, + info->key.tun_flags); info->options_len = sizeof(*md); } @@ -376,10 +389,13 @@ static int __ipgre_rcv(struct sk_buff *skb, const struct tnl_ptk_info *tpi, tnl_params = &tunnel->parms.iph; if (tunnel->collect_md || tnl_params->daddr == 0) { - __be16 flags; + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; __be64 tun_id; - flags = tpi->flags & (TUNNEL_CSUM | TUNNEL_KEY); + __set_bit(IP_TUNNEL_CSUM_BIT, flags); + __set_bit(IP_TUNNEL_KEY_BIT, flags); + ip_tunnel_flags_and(flags, tpi->flags, flags); + tun_id = key32_to_tunnel_id(tpi->key); tun_dst = ip_tun_rx_dst(skb, flags, tun_id, 0); if (!tun_dst) @@ -459,14 +475,15 @@ static void __gre_xmit(struct sk_buff *skb, struct net_device *dev, __be16 proto) { struct ip_tunnel *tunnel = netdev_priv(dev); + IP_TUNNEL_DECLARE_FLAGS(flags); - if (tunnel->parms.o_flags & TUNNEL_SEQ) - tunnel->o_seqno++; + ip_tunnel_flags_copy(flags, tunnel->parms.o_flags); /* Push GRE header. */ gre_build_header(skb, tunnel->tun_hlen, - tunnel->parms.o_flags, proto, tunnel->parms.o_key, - htonl(tunnel->o_seqno)); + flags, proto, tunnel->parms.o_key, + test_bit(IP_TUNNEL_SEQ_BIT, flags) ? + htonl(atomic_fetch_inc(&tunnel->o_seqno)) : 0); ip_tunnel_xmit(skb, dev, tnl_params, tnl_params->protocol); } @@ -480,10 +497,10 @@ static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, __be16 proto) { struct ip_tunnel *tunnel = netdev_priv(dev); + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; struct ip_tunnel_info *tun_info; const struct ip_tunnel_key *key; int tunnel_hlen; - __be16 flags; tun_info = skb_tunnel_info(skb); if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || @@ -497,14 +514,19 @@ static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, goto err_free_skb; /* Push Tunnel header. */ - if (gre_handle_offloads(skb, !!(tun_info->key.tun_flags & TUNNEL_CSUM))) + if (gre_handle_offloads(skb, test_bit(IP_TUNNEL_CSUM_BIT, + tunnel->parms.o_flags))) goto err_free_skb; - flags = tun_info->key.tun_flags & - (TUNNEL_CSUM | TUNNEL_KEY | TUNNEL_SEQ); + __set_bit(IP_TUNNEL_CSUM_BIT, flags); + __set_bit(IP_TUNNEL_KEY_BIT, flags); + __set_bit(IP_TUNNEL_SEQ_BIT, flags); + ip_tunnel_flags_and(flags, tun_info->key.tun_flags, flags); + gre_build_header(skb, tunnel_hlen, flags, proto, tunnel_id_to_key32(tun_info->key.tun_id), - (flags & TUNNEL_SEQ) ? htonl(tunnel->o_seqno++) : 0); + test_bit(IP_TUNNEL_SEQ_BIT, flags) ? + htonl(atomic_fetch_inc(&tunnel->o_seqno)) : 0); ip_md_tunnel_xmit(skb, dev, IPPROTO_GRE, tunnel_hlen); @@ -512,12 +534,13 @@ static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, err_free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); } static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) { struct ip_tunnel *tunnel = netdev_priv(dev); + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; struct ip_tunnel_info *tun_info; const struct ip_tunnel_key *key; struct erspan_metadata *md; @@ -526,7 +549,6 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) int tunnel_hlen; int version; int nhoff; - int thoff; tun_info = skb_tunnel_info(skb); if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || @@ -534,7 +556,7 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) goto err_free_skb; key = &tun_info->key; - if (!(tun_info->key.tun_flags & TUNNEL_ERSPAN_OPT)) + if (!test_bit(IP_TUNNEL_ERSPAN_OPT_BIT, tun_info->key.tun_flags)) goto err_free_skb; if (tun_info->options_len < sizeof(*md)) goto err_free_skb; @@ -551,19 +573,26 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) goto err_free_skb; if (skb->len > dev->mtu + dev->hard_header_len) { - pskb_trim(skb, dev->mtu + dev->hard_header_len); + if (pskb_trim(skb, dev->mtu + dev->hard_header_len)) + goto err_free_skb; truncate = true; } - nhoff = skb_network_header(skb) - skb_mac_header(skb); + nhoff = skb_network_offset(skb); if (skb->protocol == htons(ETH_P_IP) && (ntohs(ip_hdr(skb)->tot_len) > skb->len - nhoff)) truncate = true; - thoff = skb_transport_header(skb) - skb_mac_header(skb); - if (skb->protocol == htons(ETH_P_IPV6) && - (ntohs(ipv6_hdr(skb)->payload_len) > skb->len - thoff)) - truncate = true; + if (skb->protocol == htons(ETH_P_IPV6)) { + int thoff; + + if (skb_transport_header_was_set(skb)) + thoff = skb_transport_offset(skb); + else + thoff = nhoff + sizeof(struct ipv6hdr); + if (ntohs(ipv6_hdr(skb)->payload_len) > skb->len - thoff) + truncate = true; + } if (version == 1) { erspan_build_header(skb, ntohl(tunnel_id_to_key32(key->tun_id)), @@ -580,8 +609,9 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) goto err_free_skb; } - gre_build_header(skb, 8, TUNNEL_SEQ, - proto, 0, htonl(tunnel->o_seqno++)); + __set_bit(IP_TUNNEL_SEQ_BIT, flags); + gre_build_header(skb, 8, flags, proto, 0, + htonl(atomic_fetch_inc(&tunnel->o_seqno))); ip_md_tunnel_xmit(skb, dev, IPPROTO_GRE, tunnel_hlen); @@ -589,7 +619,7 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) err_free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); } static int gre_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb) @@ -605,8 +635,8 @@ static int gre_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb) key = &info->key; ip_tunnel_init_flow(&fl4, IPPROTO_GRE, key->u.ipv4.dst, key->u.ipv4.src, tunnel_id_to_key32(key->tun_id), - key->tos & ~INET_ECN_MASK, 0, skb->mark, - skb_get_hash(skb)); + key->tos & ~INET_ECN_MASK, dev_net(dev), 0, + skb->mark, skb_get_hash(skb), key->flow_flags); rt = ip_route_output_key(dev_net(dev), &fl4); if (IS_ERR(rt)) return PTR_ERR(rt); @@ -631,21 +661,23 @@ static netdev_tx_t ipgre_xmit(struct sk_buff *skb, } if (dev->header_ops) { - const int pull_len = tunnel->hlen + sizeof(struct iphdr); + int pull_len = tunnel->hlen + sizeof(struct iphdr); if (skb_cow_head(skb, 0)) goto free_skb; - tnl_params = (const struct iphdr *)skb->data; - - if (pull_len > skb_transport_offset(skb)) + if (!pskb_may_pull(skb, pull_len)) goto free_skb; - /* Pull skb since ip_tunnel_xmit() needs skb->data pointing - * to gre header. - */ + tnl_params = (const struct iphdr *)skb->data; + + /* ip_tunnel_xmit() needs skb->data pointing to gre header. */ skb_pull(skb, pull_len); skb_reset_mac_header(skb); + + if (skb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_start(skb) < skb->data) + goto free_skb; } else { if (skb_cow_head(skb, dev->needed_headroom)) goto free_skb; @@ -653,7 +685,8 @@ static netdev_tx_t ipgre_xmit(struct sk_buff *skb, tnl_params = &tunnel->parms.iph; } - if (gre_handle_offloads(skb, !!(tunnel->parms.o_flags & TUNNEL_CSUM))) + if (gre_handle_offloads(skb, test_bit(IP_TUNNEL_CSUM_BIT, + tunnel->parms.o_flags))) goto free_skb; __gre_xmit(skb, dev, tnl_params, skb->protocol); @@ -661,7 +694,7 @@ static netdev_tx_t ipgre_xmit(struct sk_buff *skb, free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); return NETDEV_TX_OK; } @@ -687,14 +720,15 @@ static netdev_tx_t erspan_xmit(struct sk_buff *skb, goto free_skb; if (skb->len > dev->mtu + dev->hard_header_len) { - pskb_trim(skb, dev->mtu + dev->hard_header_len); + if (pskb_trim(skb, dev->mtu + dev->hard_header_len)) + goto free_skb; truncate = true; } /* Push ERSPAN header */ if (tunnel->erspan_ver == 0) { proto = htons(ETH_P_ERSPAN); - tunnel->parms.o_flags &= ~TUNNEL_SEQ; + __clear_bit(IP_TUNNEL_SEQ_BIT, tunnel->parms.o_flags); } else if (tunnel->erspan_ver == 1) { erspan_build_header(skb, ntohl(tunnel->parms.o_key), tunnel->index, @@ -709,13 +743,13 @@ static netdev_tx_t erspan_xmit(struct sk_buff *skb, goto free_skb; } - tunnel->parms.o_flags &= ~TUNNEL_KEY; + __clear_bit(IP_TUNNEL_KEY_BIT, tunnel->parms.o_flags); __gre_xmit(skb, dev, &tunnel->parms.iph, proto); return NETDEV_TX_OK; free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); return NETDEV_TX_OK; } @@ -732,7 +766,8 @@ static netdev_tx_t gre_tap_xmit(struct sk_buff *skb, return NETDEV_TX_OK; } - if (gre_handle_offloads(skb, !!(tunnel->parms.o_flags & TUNNEL_CSUM))) + if (gre_handle_offloads(skb, test_bit(IP_TUNNEL_CSUM_BIT, + tunnel->parms.o_flags))) goto free_skb; if (skb_cow_head(skb, dev->needed_headroom)) @@ -743,7 +778,7 @@ static netdev_tx_t gre_tap_xmit(struct sk_buff *skb, free_skb: kfree_skb(skb); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); return NETDEV_TX_OK; } @@ -763,38 +798,42 @@ static void ipgre_link_update(struct net_device *dev, bool set_mtu) dev->needed_headroom += len; if (set_mtu) - dev->mtu = max_t(int, dev->mtu - len, 68); - - if (!(tunnel->parms.o_flags & TUNNEL_SEQ)) { - if (!(tunnel->parms.o_flags & TUNNEL_CSUM) || - tunnel->encap.type == TUNNEL_ENCAP_NONE) { - dev->features |= NETIF_F_GSO_SOFTWARE; - dev->hw_features |= NETIF_F_GSO_SOFTWARE; - } else { - dev->features &= ~NETIF_F_GSO_SOFTWARE; - dev->hw_features &= ~NETIF_F_GSO_SOFTWARE; - } - dev->features |= NETIF_F_LLTX; - } else { + WRITE_ONCE(dev->mtu, max_t(int, dev->mtu - len, 68)); + + if (test_bit(IP_TUNNEL_SEQ_BIT, tunnel->parms.o_flags) || + (test_bit(IP_TUNNEL_CSUM_BIT, tunnel->parms.o_flags) && + tunnel->encap.type != TUNNEL_ENCAP_NONE)) { + dev->features &= ~NETIF_F_GSO_SOFTWARE; dev->hw_features &= ~NETIF_F_GSO_SOFTWARE; - dev->features &= ~(NETIF_F_LLTX | NETIF_F_GSO_SOFTWARE); + } else { + dev->features |= NETIF_F_GSO_SOFTWARE; + dev->hw_features |= NETIF_F_GSO_SOFTWARE; } } -static int ipgre_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, +static int ipgre_tunnel_ctl(struct net_device *dev, + struct ip_tunnel_parm_kern *p, int cmd) { + __be16 i_flags, o_flags; int err; + if (!ip_tunnel_flags_is_be16_compat(p->i_flags) || + !ip_tunnel_flags_is_be16_compat(p->o_flags)) + return -EOVERFLOW; + + i_flags = ip_tunnel_flags_to_be16(p->i_flags); + o_flags = ip_tunnel_flags_to_be16(p->o_flags); + if (cmd == SIOCADDTUNNEL || cmd == SIOCCHGTUNNEL) { if (p->iph.version != 4 || p->iph.protocol != IPPROTO_GRE || p->iph.ihl != 5 || (p->iph.frag_off & htons(~IP_DF)) || - ((p->i_flags | p->o_flags) & (GRE_VERSION | GRE_ROUTING))) + ((i_flags | o_flags) & (GRE_VERSION | GRE_ROUTING))) return -EINVAL; } - p->i_flags = gre_flags_to_tnl_flags(p->i_flags); - p->o_flags = gre_flags_to_tnl_flags(p->o_flags); + gre_flags_to_tnl_flags(p->i_flags, i_flags); + gre_flags_to_tnl_flags(p->o_flags, o_flags); err = ip_tunnel_ctl(dev, p, cmd); if (err) @@ -803,15 +842,18 @@ static int ipgre_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, if (cmd == SIOCCHGTUNNEL) { struct ip_tunnel *t = netdev_priv(dev); - t->parms.i_flags = p->i_flags; - t->parms.o_flags = p->o_flags; + ip_tunnel_flags_copy(t->parms.i_flags, p->i_flags); + ip_tunnel_flags_copy(t->parms.o_flags, p->o_flags); if (strcmp(dev->rtnl_link_ops->kind, "erspan")) ipgre_link_update(dev, true); } - p->i_flags = gre_tnl_flags_to_gre_flags(p->i_flags); - p->o_flags = gre_tnl_flags_to_gre_flags(p->o_flags); + i_flags = gre_tnl_flags_to_gre_flags(p->i_flags); + ip_tunnel_flags_from_be16(p->i_flags, i_flags); + o_flags = gre_tnl_flags_to_gre_flags(p->o_flags); + ip_tunnel_flags_from_be16(p->o_flags, o_flags); + return 0; } @@ -886,15 +928,18 @@ static int ipgre_open(struct net_device *dev) struct ip_tunnel *t = netdev_priv(dev); if (ipv4_is_multicast(t->parms.iph.daddr)) { - struct flowi4 fl4; + struct flowi4 fl4 = { + .flowi4_oif = t->parms.link, + .flowi4_dscp = ip4h_dscp(&t->parms.iph), + .flowi4_scope = RT_SCOPE_UNIVERSE, + .flowi4_proto = IPPROTO_GRE, + .saddr = t->parms.iph.saddr, + .daddr = t->parms.iph.daddr, + .fl4_gre_key = t->parms.o_key, + }; struct rtable *rt; - rt = ip_route_output_gre(t->net, &fl4, - t->parms.iph.daddr, - t->parms.iph.saddr, - t->parms.o_key, - RT_TOS(t->parms.iph.tos), - t->parms.link); + rt = ip_route_output_key(t->net, &fl4); if (IS_ERR(rt)) return -EADDRNOTAVAIL; dev = rt->dst.dev; @@ -962,22 +1007,19 @@ static void __gre_tunnel_init(struct net_device *dev) dev->features |= GRE_FEATURES; dev->hw_features |= GRE_FEATURES; - if (!(tunnel->parms.o_flags & TUNNEL_SEQ)) { - /* TCP offload with GRE SEQ is not supported, nor - * can we support 2 levels of outer headers requiring - * an update. - */ - if (!(tunnel->parms.o_flags & TUNNEL_CSUM) || - (tunnel->encap.type == TUNNEL_ENCAP_NONE)) { - dev->features |= NETIF_F_GSO_SOFTWARE; - dev->hw_features |= NETIF_F_GSO_SOFTWARE; - } + /* TCP offload with GRE SEQ is not supported, nor can we support 2 + * levels of outer headers requiring an update. + */ + if (test_bit(IP_TUNNEL_SEQ_BIT, tunnel->parms.o_flags)) + return; + if (test_bit(IP_TUNNEL_CSUM_BIT, tunnel->parms.o_flags) && + tunnel->encap.type != TUNNEL_ENCAP_NONE) + return; - /* Can use a lockless transmit, unless we generate - * output sequences - */ - dev->features |= NETIF_F_LLTX; - } + dev->features |= NETIF_F_GSO_SOFTWARE; + dev->hw_features |= NETIF_F_GSO_SOFTWARE; + + dev->lltx = true; } static int ipgre_tunnel_init(struct net_device *dev) @@ -1024,14 +1066,15 @@ static int __net_init ipgre_init_net(struct net *net) return ip_tunnel_init_net(net, ipgre_net_id, &ipgre_link_ops, NULL); } -static void __net_exit ipgre_exit_batch_net(struct list_head *list_net) +static void __net_exit ipgre_exit_rtnl(struct net *net, + struct list_head *dev_to_kill) { - ip_tunnel_delete_nets(list_net, ipgre_net_id, &ipgre_link_ops); + ip_tunnel_delete_net(net, ipgre_net_id, &ipgre_link_ops, dev_to_kill); } static struct pernet_operations ipgre_net_ops = { .init = ipgre_init_net, - .exit_batch = ipgre_exit_batch_net, + .exit_rtnl = ipgre_exit_rtnl, .id = &ipgre_net_id, .size = sizeof(struct ip_tunnel_net), }; @@ -1128,7 +1171,7 @@ static int erspan_validate(struct nlattr *tb[], struct nlattr *data[], static int ipgre_netlink_parms(struct net_device *dev, struct nlattr *data[], struct nlattr *tb[], - struct ip_tunnel_parm *parms, + struct ip_tunnel_parm_kern *parms, __u32 *fwmark) { struct ip_tunnel *t = netdev_priv(dev); @@ -1144,10 +1187,12 @@ static int ipgre_netlink_parms(struct net_device *dev, parms->link = nla_get_u32(data[IFLA_GRE_LINK]); if (data[IFLA_GRE_IFLAGS]) - parms->i_flags = gre_flags_to_tnl_flags(nla_get_be16(data[IFLA_GRE_IFLAGS])); + gre_flags_to_tnl_flags(parms->i_flags, + nla_get_be16(data[IFLA_GRE_IFLAGS])); if (data[IFLA_GRE_OFLAGS]) - parms->o_flags = gre_flags_to_tnl_flags(nla_get_be16(data[IFLA_GRE_OFLAGS])); + gre_flags_to_tnl_flags(parms->o_flags, + nla_get_be16(data[IFLA_GRE_OFLAGS])); if (data[IFLA_GRE_IKEY]) parms->i_key = nla_get_be32(data[IFLA_GRE_IKEY]); @@ -1195,7 +1240,7 @@ static int ipgre_netlink_parms(struct net_device *dev, static int erspan_netlink_parms(struct net_device *dev, struct nlattr *data[], struct nlattr *tb[], - struct ip_tunnel_parm *parms, + struct ip_tunnel_parm_kern *parms, __u32 *fwmark) { struct ip_tunnel *t = netdev_priv(dev); @@ -1350,11 +1395,13 @@ ipgre_newlink_encap_setup(struct net_device *dev, struct nlattr *data[]) return 0; } -static int ipgre_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int ipgre_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { - struct ip_tunnel_parm p; + struct nlattr **data = params->data; + struct nlattr **tb = params->tb; + struct ip_tunnel_parm_kern p; __u32 fwmark = 0; int err; @@ -1365,14 +1412,17 @@ static int ipgre_newlink(struct net *src_net, struct net_device *dev, err = ipgre_netlink_parms(dev, data, tb, &p, &fwmark); if (err < 0) return err; - return ip_tunnel_newlink(dev, tb, &p, fwmark); + return ip_tunnel_newlink(params->link_net ? : dev_net(dev), dev, tb, &p, + fwmark); } -static int erspan_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int erspan_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { - struct ip_tunnel_parm p; + struct nlattr **data = params->data; + struct nlattr **tb = params->tb; + struct ip_tunnel_parm_kern p; __u32 fwmark = 0; int err; @@ -1383,7 +1433,8 @@ static int erspan_newlink(struct net *src_net, struct net_device *dev, err = erspan_netlink_parms(dev, data, tb, &p, &fwmark); if (err) return err; - return ip_tunnel_newlink(dev, tb, &p, fwmark); + return ip_tunnel_newlink(params->link_net ? : dev_net(dev), dev, tb, &p, + fwmark); } static int ipgre_changelink(struct net_device *dev, struct nlattr *tb[], @@ -1391,8 +1442,8 @@ static int ipgre_changelink(struct net_device *dev, struct nlattr *tb[], struct netlink_ext_ack *extack) { struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm_kern p; __u32 fwmark = t->fwmark; - struct ip_tunnel_parm p; int err; err = ipgre_newlink_encap_setup(dev, data); @@ -1407,8 +1458,8 @@ static int ipgre_changelink(struct net_device *dev, struct nlattr *tb[], if (err < 0) return err; - t->parms.i_flags = p.i_flags; - t->parms.o_flags = p.o_flags; + ip_tunnel_flags_copy(t->parms.i_flags, p.i_flags); + ip_tunnel_flags_copy(t->parms.o_flags, p.o_flags); ipgre_link_update(dev, !tb[IFLA_MTU]); @@ -1420,8 +1471,8 @@ static int erspan_changelink(struct net_device *dev, struct nlattr *tb[], struct netlink_ext_ack *extack) { struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm_kern p; __u32 fwmark = t->fwmark; - struct ip_tunnel_parm p; int err; err = ipgre_newlink_encap_setup(dev, data); @@ -1436,8 +1487,8 @@ static int erspan_changelink(struct net_device *dev, struct nlattr *tb[], if (err < 0) return err; - t->parms.i_flags = p.i_flags; - t->parms.o_flags = p.o_flags; + ip_tunnel_flags_copy(t->parms.i_flags, p.i_flags); + ip_tunnel_flags_copy(t->parms.o_flags, p.o_flags); return 0; } @@ -1493,26 +1544,10 @@ static size_t ipgre_get_size(const struct net_device *dev) static int ipgre_fill_info(struct sk_buff *skb, const struct net_device *dev) { struct ip_tunnel *t = netdev_priv(dev); - struct ip_tunnel_parm *p = &t->parms; - __be16 o_flags = p->o_flags; + struct ip_tunnel_parm_kern *p = &t->parms; + IP_TUNNEL_DECLARE_FLAGS(o_flags); - if (t->erspan_ver <= 2) { - if (t->erspan_ver != 0 && !t->collect_md) - o_flags |= TUNNEL_KEY; - - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) - goto nla_put_failure; - - if (t->erspan_ver == 1) { - if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) - goto nla_put_failure; - } else if (t->erspan_ver == 2) { - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) - goto nla_put_failure; - if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) - goto nla_put_failure; - } - } + ip_tunnel_flags_copy(o_flags, p->o_flags); if (nla_put_u32(skb, IFLA_GRE_LINK, p->link) || nla_put_be16(skb, IFLA_GRE_IFLAGS, @@ -1554,6 +1589,34 @@ nla_put_failure: return -EMSGSIZE; } +static int erspan_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (t->erspan_ver <= 2) { + if (t->erspan_ver != 0 && !t->collect_md) + __set_bit(IP_TUNNEL_KEY_BIT, t->parms.o_flags); + + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) + goto nla_put_failure; + + if (t->erspan_ver == 1) { + if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) + goto nla_put_failure; + } else if (t->erspan_ver == 2) { + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) + goto nla_put_failure; + } + } + + return ipgre_fill_info(skb, dev); + +nla_put_failure: + return -EMSGSIZE; +} + static void erspan_setup(struct net_device *dev) { struct ip_tunnel *t = netdev_priv(dev); @@ -1632,13 +1695,14 @@ static struct rtnl_link_ops erspan_link_ops __read_mostly = { .changelink = erspan_changelink, .dellink = ip_tunnel_dellink, .get_size = ipgre_get_size, - .fill_info = ipgre_fill_info, + .fill_info = erspan_fill_info, .get_link_net = ip_tunnel_get_link_net, }; struct net_device *gretap_fb_dev_create(struct net *net, const char *name, u8 name_assign_type) { + struct rtnl_newlink_params params = { .src_net = net }; struct nlattr *tb[IFLA_MAX + 1]; struct net_device *dev; LIST_HEAD(list_kill); @@ -1646,6 +1710,7 @@ struct net_device *gretap_fb_dev_create(struct net *net, const char *name, int err; memset(&tb, 0, sizeof(tb)); + params.tb = tb; dev = rtnl_create_link(net, name, name_assign_type, &ipgre_tap_ops, tb, NULL); @@ -1656,7 +1721,7 @@ struct net_device *gretap_fb_dev_create(struct net *net, const char *name, t = netdev_priv(dev); t->collect_md = true; - err = ipgre_newlink(net, dev, tb, NULL, NULL); + err = ipgre_newlink(dev, ¶ms, NULL); if (err < 0) { free_netdev(dev); return ERR_PTR(err); @@ -1669,7 +1734,7 @@ struct net_device *gretap_fb_dev_create(struct net *net, const char *name, if (err) goto out; - err = rtnl_configure_link(dev, NULL); + err = rtnl_configure_link(dev, NULL, 0, NULL); if (err < 0) goto out; @@ -1686,14 +1751,15 @@ static int __net_init ipgre_tap_init_net(struct net *net) return ip_tunnel_init_net(net, gre_tap_net_id, &ipgre_tap_ops, "gretap0"); } -static void __net_exit ipgre_tap_exit_batch_net(struct list_head *list_net) +static void __net_exit ipgre_tap_exit_rtnl(struct net *net, + struct list_head *dev_to_kill) { - ip_tunnel_delete_nets(list_net, gre_tap_net_id, &ipgre_tap_ops); + ip_tunnel_delete_net(net, gre_tap_net_id, &ipgre_tap_ops, dev_to_kill); } static struct pernet_operations ipgre_tap_net_ops = { .init = ipgre_tap_init_net, - .exit_batch = ipgre_tap_exit_batch_net, + .exit_rtnl = ipgre_tap_exit_rtnl, .id = &gre_tap_net_id, .size = sizeof(struct ip_tunnel_net), }; @@ -1704,14 +1770,15 @@ static int __net_init erspan_init_net(struct net *net) &erspan_link_ops, "erspan0"); } -static void __net_exit erspan_exit_batch_net(struct list_head *net_list) +static void __net_exit erspan_exit_rtnl(struct net *net, + struct list_head *dev_to_kill) { - ip_tunnel_delete_nets(net_list, erspan_net_id, &erspan_link_ops); + ip_tunnel_delete_net(net, erspan_net_id, &erspan_link_ops, dev_to_kill); } static struct pernet_operations erspan_net_ops = { .init = erspan_init_net, - .exit_batch = erspan_exit_batch_net, + .exit_rtnl = erspan_exit_rtnl, .id = &erspan_net_id, .size = sizeof(struct ip_tunnel_net), }; @@ -1782,6 +1849,7 @@ static void __exit ipgre_fini(void) module_init(ipgre_init); module_exit(ipgre_fini); +MODULE_DESCRIPTION("IPv4 GRE tunnels over IP library"); MODULE_LICENSE("GPL"); MODULE_ALIAS_RTNL_LINK("gre"); MODULE_ALIAS_RTNL_LINK("gretap"); diff --git a/net/ipv4/ip_input.c b/net/ipv4/ip_input.c index 3a025c011971..19d3141dad1f 100644 --- a/net/ipv4/ip_input.c +++ b/net/ipv4/ip_input.c @@ -141,6 +141,8 @@ #include <linux/mroute.h> #include <linux/netlink.h> #include <net/dst_metadata.h> +#include <net/udp.h> +#include <net/tcp.h> /* * Process Router Attention IP option (RFC 2113) @@ -196,7 +198,8 @@ resubmit: if (ipprot) { if (!ipprot->no_policy) { if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { - kfree_skb(skb); + kfree_skb_reason(skb, + SKB_DROP_REASON_XFRM_POLICY); return; } nf_reset_ct(skb); @@ -215,7 +218,7 @@ resubmit: icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PROT_UNREACH, 0); } - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_IP_NOPROTO); } else { __IP_INC_STATS(net, IPSTATS_MIB_INDELIVERS); consume_skb(skb); @@ -225,6 +228,13 @@ resubmit: static int ip_local_deliver_finish(struct net *net, struct sock *sk, struct sk_buff *skb) { + if (unlikely(skb_orphan_frags_rx(skb, GFP_ATOMIC))) { + __IP_INC_STATS(net, IPSTATS_MIB_INDISCARDS); + kfree_skb_reason(skb, SKB_DROP_REASON_NOMEM); + return 0; + } + + skb_clear_delivery_time(skb); __skb_pull(skb, skb_network_header_len(skb)); rcu_read_lock(); @@ -255,10 +265,11 @@ int ip_local_deliver(struct sk_buff *skb) } EXPORT_SYMBOL(ip_local_deliver); -static inline bool ip_rcv_options(struct sk_buff *skb, struct net_device *dev) +static inline enum skb_drop_reason +ip_rcv_options(struct sk_buff *skb, struct net_device *dev) { - struct ip_options *opt; const struct iphdr *iph; + struct ip_options *opt; /* It looks as overkill, because not all IP options require packet mangling. @@ -269,7 +280,7 @@ static inline bool ip_rcv_options(struct sk_buff *skb, struct net_device *dev) */ if (skb_cow(skb, skb_headroom(skb))) { __IP_INC_STATS(dev_net(dev), IPSTATS_MIB_INDISCARDS); - goto drop; + return SKB_DROP_REASON_NOMEM; } iph = ip_hdr(skb); @@ -278,7 +289,7 @@ static inline bool ip_rcv_options(struct sk_buff *skb, struct net_device *dev) if (ip_options_compile(dev_net(dev), opt, skb)) { __IP_INC_STATS(dev_net(dev), IPSTATS_MIB_INHDRERRORS); - goto drop; + return SKB_DROP_REASON_IP_INHDR; } if (unlikely(opt->srr)) { @@ -290,17 +301,15 @@ static inline bool ip_rcv_options(struct sk_buff *skb, struct net_device *dev) net_info_ratelimited("source route option %pI4 -> %pI4\n", &iph->saddr, &iph->daddr); - goto drop; + return SKB_DROP_REASON_NOT_SPECIFIED; } } if (ip_options_rcv_srr(skb, dev)) - goto drop; + return SKB_DROP_REASON_NOT_SPECIFIED; } - return false; -drop: - return true; + return SKB_NOT_DROPPED_YET; } static bool ip_can_use_hint(const struct sk_buff *skb, const struct iphdr *iph, @@ -310,39 +319,44 @@ static bool ip_can_use_hint(const struct sk_buff *skb, const struct iphdr *iph, ip_hdr(hint)->tos == iph->tos; } -INDIRECT_CALLABLE_DECLARE(int udp_v4_early_demux(struct sk_buff *)); -INDIRECT_CALLABLE_DECLARE(int tcp_v4_early_demux(struct sk_buff *)); -static int ip_rcv_finish_core(struct net *net, struct sock *sk, +static int ip_rcv_finish_core(struct net *net, struct sk_buff *skb, struct net_device *dev, const struct sk_buff *hint) { const struct iphdr *iph = ip_hdr(skb); - int (*edemux)(struct sk_buff *skb); struct rtable *rt; - int err; + int drop_reason; if (ip_can_use_hint(skb, iph, hint)) { - err = ip_route_use_hint(skb, iph->daddr, iph->saddr, iph->tos, - dev, hint); - if (unlikely(err)) + drop_reason = ip_route_use_hint(skb, iph->daddr, iph->saddr, + ip4h_dscp(iph), dev, hint); + if (unlikely(drop_reason)) goto drop_error; } - if (net->ipv4.sysctl_ip_early_demux && + if (READ_ONCE(net->ipv4.sysctl_ip_early_demux) && !skb_dst(skb) && !skb->sk && !ip_is_fragment(iph)) { - const struct net_protocol *ipprot; - int protocol = iph->protocol; - - ipprot = rcu_dereference(inet_protos[protocol]); - if (ipprot && (edemux = READ_ONCE(ipprot->early_demux))) { - err = INDIRECT_CALL_2(edemux, tcp_v4_early_demux, - udp_v4_early_demux, skb); - if (unlikely(err)) - goto drop_error; - /* must reload iph, skb->head might have changed */ - iph = ip_hdr(skb); + switch (iph->protocol) { + case IPPROTO_TCP: + if (READ_ONCE(net->ipv4.sysctl_tcp_early_demux)) { + tcp_v4_early_demux(skb); + + /* must reload iph, skb->head might have changed */ + iph = ip_hdr(skb); + } + break; + case IPPROTO_UDP: + if (READ_ONCE(net->ipv4.sysctl_udp_early_demux)) { + drop_reason = udp_v4_early_demux(skb); + if (unlikely(drop_reason)) + goto drop_error; + + /* must reload iph, skb->head might have changed */ + iph = ip_hdr(skb); + } + break; } } @@ -351,10 +365,15 @@ static int ip_rcv_finish_core(struct net *net, struct sock *sk, * how the packet travels inside Linux networking. */ if (!skb_valid_dst(skb)) { - err = ip_route_input_noref(skb, iph->daddr, iph->saddr, - iph->tos, dev); - if (unlikely(err)) + drop_reason = ip_route_input_noref(skb, iph->daddr, iph->saddr, + ip4h_dscp(iph), dev); + if (unlikely(drop_reason)) goto drop_error; + } else { + struct in_device *in_dev = __in_dev_get_rcu(dev); + + if (in_dev && IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; } #ifdef CONFIG_IP_ROUTE_CLASSID @@ -368,8 +387,11 @@ static int ip_rcv_finish_core(struct net *net, struct sock *sk, } #endif - if (iph->ihl > 5 && ip_rcv_options(skb, dev)) - goto drop; + if (iph->ihl > 5) { + drop_reason = ip_rcv_options(skb, dev); + if (drop_reason) + goto drop; + } rt = skb_rtable(skb); if (rt->rt_type == RTN_MULTICAST) { @@ -396,18 +418,20 @@ static int ip_rcv_finish_core(struct net *net, struct sock *sk, * so-called "hole-196" attack) so do it for both. */ if (in_dev && - IN_DEV_ORCONF(in_dev, DROP_UNICAST_IN_L2_MULTICAST)) + IN_DEV_ORCONF(in_dev, DROP_UNICAST_IN_L2_MULTICAST)) { + drop_reason = SKB_DROP_REASON_UNICAST_IN_L2_MULTICAST; goto drop; + } } return NET_RX_SUCCESS; drop: - kfree_skb(skb); + kfree_skb_reason(skb, drop_reason); return NET_RX_DROP; drop_error: - if (err == -EXDEV) + if (drop_reason == SKB_DROP_REASON_IP_RPFILTER) __NET_INC_STATS(net, LINUX_MIB_IPRPFILTER); goto drop; } @@ -424,7 +448,7 @@ static int ip_rcv_finish(struct net *net, struct sock *sk, struct sk_buff *skb) if (!skb) return NET_RX_SUCCESS; - ret = ip_rcv_finish_core(net, sk, skb, dev, NULL); + ret = ip_rcv_finish_core(net, skb, dev, NULL); if (ret != NET_RX_DROP) ret = dst_input(skb); return ret; @@ -436,13 +460,17 @@ static int ip_rcv_finish(struct net *net, struct sock *sk, struct sk_buff *skb) static struct sk_buff *ip_rcv_core(struct sk_buff *skb, struct net *net) { const struct iphdr *iph; + int drop_reason; u32 len; /* When the interface is in promisc. mode, drop all the crap * that it receives, do not try to analyse it. */ - if (skb->pkt_type == PACKET_OTHERHOST) + if (skb->pkt_type == PACKET_OTHERHOST) { + dev_core_stats_rx_otherhost_dropped_inc(skb->dev); + drop_reason = SKB_DROP_REASON_OTHERHOST; goto drop; + } __IP_UPD_PO_STATS(net, IPSTATS_MIB_IN, skb->len); @@ -452,6 +480,7 @@ static struct sk_buff *ip_rcv_core(struct sk_buff *skb, struct net *net) goto out; } + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; if (!pskb_may_pull(skb, sizeof(struct iphdr))) goto inhdr_error; @@ -486,8 +515,9 @@ static struct sk_buff *ip_rcv_core(struct sk_buff *skb, struct net *net) if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl))) goto csum_error; - len = ntohs(iph->tot_len); + len = iph_totlen(skb, iph); if (skb->len < len) { + drop_reason = SKB_DROP_REASON_PKT_TOO_SMALL; __IP_INC_STATS(net, IPSTATS_MIB_INTRUNCATEDPKTS); goto drop; } else if (len < (iph->ihl*4)) @@ -516,11 +546,14 @@ static struct sk_buff *ip_rcv_core(struct sk_buff *skb, struct net *net) return skb; csum_error: + drop_reason = SKB_DROP_REASON_IP_CSUM; __IP_INC_STATS(net, IPSTATS_MIB_CSUMERRORS); inhdr_error: + if (drop_reason == SKB_DROP_REASON_NOT_SPECIFIED) + drop_reason = SKB_DROP_REASON_IP_INHDR; __IP_INC_STATS(net, IPSTATS_MIB_INHDRERRORS); drop: - kfree_skb(skb); + kfree_skb_reason(skb, drop_reason); out: return NULL; } @@ -553,22 +586,25 @@ static void ip_sublist_rcv_finish(struct list_head *head) } static struct sk_buff *ip_extract_route_hint(const struct net *net, - struct sk_buff *skb, int rt_type) + struct sk_buff *skb) { - if (fib4_has_custom_rules(net) || rt_type == RTN_BROADCAST) + const struct iphdr *iph = ip_hdr(skb); + + if (fib4_has_custom_rules(net) || + ipv4_is_lbcast(iph->daddr) || + ipv4_is_zeronet(iph->daddr) || + IPCB(skb)->flags & IPSKB_MULTIPATH) return NULL; return skb; } -static void ip_list_rcv_finish(struct net *net, struct sock *sk, - struct list_head *head) +static void ip_list_rcv_finish(struct net *net, struct list_head *head) { struct sk_buff *skb, *next, *hint = NULL; struct dst_entry *curr_dst = NULL; - struct list_head sublist; + LIST_HEAD(sublist); - INIT_LIST_HEAD(&sublist); list_for_each_entry_safe(skb, next, head, list) { struct net_device *dev = skb->dev; struct dst_entry *dst; @@ -580,13 +616,12 @@ static void ip_list_rcv_finish(struct net *net, struct sock *sk, skb = l3mdev_ip_rcv(skb); if (!skb) continue; - if (ip_rcv_finish_core(net, sk, skb, dev, hint) == NET_RX_DROP) + if (ip_rcv_finish_core(net, skb, dev, hint) == NET_RX_DROP) continue; dst = skb_dst(skb); if (curr_dst != dst) { - hint = ip_extract_route_hint(net, skb, - ((struct rtable *)dst)->rt_type); + hint = ip_extract_route_hint(net, skb); /* dispatch old sublist */ if (!list_empty(&sublist)) @@ -606,7 +641,7 @@ static void ip_sublist_rcv(struct list_head *head, struct net_device *dev, { NF_HOOK_LIST(NFPROTO_IPV4, NF_INET_PRE_ROUTING, net, NULL, head, dev, NULL, ip_rcv_finish); - ip_list_rcv_finish(net, NULL, head); + ip_list_rcv_finish(net, head); } /* Receive a list of IP packets */ @@ -616,9 +651,8 @@ void ip_list_rcv(struct list_head *head, struct packet_type *pt, struct net_device *curr_dev = NULL; struct net *curr_net = NULL; struct sk_buff *skb, *next; - struct list_head sublist; + LIST_HEAD(sublist); - INIT_LIST_HEAD(&sublist); list_for_each_entry_safe(skb, next, head, list) { struct net_device *dev = skb->dev; struct net *net = dev_net(dev); diff --git a/net/ipv4/ip_options.c b/net/ipv4/ip_options.c index da1b5038bdfd..be8815ce3ac2 100644 --- a/net/ipv4/ip_options.c +++ b/net/ipv4/ip_options.c @@ -17,7 +17,7 @@ #include <linux/slab.h> #include <linux/types.h> #include <linux/uaccess.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> #include <linux/skbuff.h> #include <linux/ip.h> #include <linux/icmp.h> @@ -42,7 +42,7 @@ */ void ip_options_build(struct sk_buff *skb, struct ip_options *opt, - __be32 daddr, struct rtable *rt, int is_frag) + __be32 daddr, struct rtable *rt) { unsigned char *iph = skb_network_header(skb); @@ -53,28 +53,15 @@ void ip_options_build(struct sk_buff *skb, struct ip_options *opt, if (opt->srr) memcpy(iph + opt->srr + iph[opt->srr + 1] - 4, &daddr, 4); - if (!is_frag) { - if (opt->rr_needaddr) - ip_rt_get_source(iph + opt->rr + iph[opt->rr + 2] - 5, skb, rt); - if (opt->ts_needaddr) - ip_rt_get_source(iph + opt->ts + iph[opt->ts + 2] - 9, skb, rt); - if (opt->ts_needtime) { - __be32 midtime; + if (opt->rr_needaddr) + ip_rt_get_source(iph + opt->rr + iph[opt->rr + 2] - 5, skb, rt); + if (opt->ts_needaddr) + ip_rt_get_source(iph + opt->ts + iph[opt->ts + 2] - 9, skb, rt); + if (opt->ts_needtime) { + __be32 midtime; - midtime = inet_current_timestamp(); - memcpy(iph + opt->ts + iph[opt->ts + 2] - 5, &midtime, 4); - } - return; - } - if (opt->rr) { - memset(iph + opt->rr, IPOPT_NOP, iph[opt->rr + 1]); - opt->rr = 0; - opt->rr_needaddr = 0; - } - if (opt->ts) { - memset(iph + opt->ts, IPOPT_NOP, iph[opt->ts + 1]); - opt->ts = 0; - opt->ts_needaddr = opt->ts_needtime = 0; + midtime = inet_current_timestamp(); + memcpy(iph + opt->ts + iph[opt->ts + 2] - 5, &midtime, 4); } } @@ -628,13 +615,13 @@ int ip_options_rcv_srr(struct sk_buff *skb, struct net_device *dev) } memcpy(&nexthop, &optptr[srrptr-1], 4); - orefdst = skb->_skb_refdst; - skb_dst_set(skb, NULL); - err = ip_route_input(skb, nexthop, iph->saddr, iph->tos, dev); + orefdst = skb_dstref_steal(skb); + err = ip_route_input(skb, nexthop, iph->saddr, ip4h_dscp(iph), + dev) ? -EINVAL : 0; rt2 = skb_rtable(skb); if (err || (rt2->rt_type != RTN_UNICAST && rt2->rt_type != RTN_LOCAL)) { skb_dst_drop(skb); - skb->_skb_refdst = orefdst; + skb_dstref_restore(skb, orefdst); return -EINVAL; } refdst_drop(orefdst); diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c index 57c1d8431386..ff11d3a85a36 100644 --- a/net/ipv4/ip_output.c +++ b/net/ipv4/ip_output.c @@ -63,6 +63,7 @@ #include <linux/stat.h> #include <linux/init.h> +#include <net/flow.h> #include <net/snmp.h> #include <net/ip.h> #include <net/protocol.h> @@ -73,15 +74,17 @@ #include <net/arp.h> #include <net/icmp.h> #include <net/checksum.h> +#include <net/gso.h> #include <net/inetpeer.h> -#include <net/inet_ecn.h> #include <net/lwtunnel.h> +#include <net/inet_dscp.h> #include <linux/bpf-cgroup.h> #include <linux/igmp.h> #include <linux/netfilter_ipv4.h> #include <linux/netfilter_bridge.h> #include <linux/netlink.h> #include <linux/tcp.h> +#include <net/psp.h> static int ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, @@ -100,7 +103,9 @@ int __ip_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) { struct iphdr *iph = ip_hdr(skb); - iph->tot_len = htons(skb->len); + IP_INC_STATS(net, IPSTATS_MIB_OUTREQUESTS); + + iph_set_totlen(iph, skb->len); ip_send_check(iph); /* if egress device is enslaved to an L3 master device pass the @@ -113,7 +118,7 @@ int __ip_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) skb->protocol = htons(ETH_P_IP); return nf_hook(NFPROTO_IPV4, NF_INET_LOCAL_OUT, - net, sk, skb, NULL, skb_dst(skb)->dev, + net, sk, skb, NULL, skb_dst_dev(skb), dst_output); } @@ -129,9 +134,10 @@ int ip_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) } EXPORT_SYMBOL_GPL(ip_local_out); -static inline int ip_select_ttl(struct inet_sock *inet, struct dst_entry *dst) +static inline int ip_select_ttl(const struct inet_sock *inet, + const struct dst_entry *dst) { - int ttl = inet->uc_ttl; + int ttl = READ_ONCE(inet->uc_ttl); if (ttl < 0) ttl = ip4_dst_hoplimit(dst); @@ -146,7 +152,7 @@ int ip_build_and_send_pkt(struct sk_buff *skb, const struct sock *sk, __be32 saddr, __be32 daddr, struct ip_options_rcu *opt, u8 tos) { - struct inet_sock *inet = inet_sk(sk); + const struct inet_sock *inet = inet_sk(sk); struct rtable *rt = skb_rtable(skb); struct net *net = sock_net(sk); struct iphdr *iph; @@ -162,22 +168,29 @@ int ip_build_and_send_pkt(struct sk_buff *skb, const struct sock *sk, iph->daddr = (opt && opt->opt.srr ? opt->opt.faddr : daddr); iph->saddr = saddr; iph->protocol = sk->sk_protocol; - if (ip_dont_fragment(sk, &rt->dst)) { + /* Do not bother generating IPID for small packets (eg SYNACK) */ + if (skb->len <= IPV4_MIN_MTU || ip_dont_fragment(sk, &rt->dst)) { iph->frag_off = htons(IP_DF); iph->id = 0; } else { iph->frag_off = 0; - __ip_select_ident(net, iph, 1); + /* TCP packets here are SYNACK with fat IPv4/TCP options. + * Avoid using the hashed IP ident generator. + */ + if (sk->sk_protocol == IPPROTO_TCP) + iph->id = (__force __be16)get_random_u16(); + else + __ip_select_ident(net, iph, 1); } if (opt && opt->opt.optlen) { iph->ihl += opt->opt.optlen>>2; - ip_options_build(skb, &opt->opt, daddr, rt, 0); + ip_options_build(skb, &opt->opt, daddr, rt); } - skb->priority = sk->sk_priority; + skb->priority = READ_ONCE(sk->sk_priority); if (!skb->mark) - skb->mark = sk->sk_mark; + skb->mark = READ_ONCE(sk->sk_mark); /* Send it out. */ return ip_local_out(net, skb->sk, skb); @@ -187,8 +200,8 @@ EXPORT_SYMBOL_GPL(ip_build_and_send_pkt); static int ip_finish_output2(struct net *net, struct sock *sk, struct sk_buff *skb) { struct dst_entry *dst = skb_dst(skb); - struct rtable *rt = (struct rtable *)dst; - struct net_device *dev = dst->dev; + struct rtable *rt = dst_rtable(dst); + struct net_device *dev = dst_dev(dst); unsigned int hh_len = LL_RESERVED_SPACE(dev); struct neighbour *neigh; bool is_v6gw = false; @@ -198,6 +211,9 @@ static int ip_finish_output2(struct net *net, struct sock *sk, struct sk_buff *s } else if (rt->rt_type == RTN_BROADCAST) IP_UPD_PO_STATS(net, IPSTATS_MIB_OUTBCAST, skb->len); + /* OUTOCTETS should be counted after fragment */ + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); + if (unlikely(skb_headroom(skb) < hh_len && dev->header_ops)) { skb = skb_expand_head(skb, hh_len); if (!skb) @@ -207,11 +223,11 @@ static int ip_finish_output2(struct net *net, struct sock *sk, struct sk_buff *s if (lwtunnel_xmit_redirect(dst->lwtstate)) { int res = lwtunnel_xmit(skb); - if (res < 0 || res == LWTUNNEL_XMIT_DONE) + if (res != LWTUNNEL_XMIT_CONTINUE) return res; } - rcu_read_lock_bh(); + rcu_read_lock(); neigh = ip_neigh_for_gw(rt, skb, &is_v6gw); if (!IS_ERR(neigh)) { int res; @@ -219,15 +235,15 @@ static int ip_finish_output2(struct net *net, struct sock *sk, struct sk_buff *s sock_confirm_neigh(skb, neigh); /* if crossing protocols, can not use the cached header */ res = neigh_output(neigh, skb, is_v6gw); - rcu_read_unlock_bh(); + rcu_read_unlock(); return res; } - rcu_read_unlock_bh(); + rcu_read_unlock(); net_dbg_ratelimited("%s: No header cache and no neighbour!\n", __func__); - kfree_skb(skb); - return -EINVAL; + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_CREATEFAIL); + return PTR_ERR(neigh); } static int ip_finish_output_gso(struct net *net, struct sock *sk, @@ -310,7 +326,7 @@ static int ip_finish_output(struct net *net, struct sock *sk, struct sk_buff *sk case NET_XMIT_CN: return __ip_finish_output(net, sk, skb) ? : ret; default: - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_BPF_CGROUP_EGRESS); return ret; } } @@ -330,7 +346,7 @@ static int ip_mc_finish_output(struct net *net, struct sock *sk, case NET_XMIT_SUCCESS: break; default: - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_BPF_CGROUP_EGRESS); return ret; } @@ -357,8 +373,6 @@ int ip_mc_output(struct net *net, struct sock *sk, struct sk_buff *skb) /* * If the indicated interface is up and running, send the packet. */ - IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); - skb->dev = dev; skb->protocol = htons(ETH_P_IP); @@ -413,17 +427,20 @@ int ip_mc_output(struct net *net, struct sock *sk, struct sk_buff *skb) int ip_output(struct net *net, struct sock *sk, struct sk_buff *skb) { - struct net_device *dev = skb_dst(skb)->dev, *indev = skb->dev; - - IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); + struct net_device *dev, *indev = skb->dev; + int ret_val; + rcu_read_lock(); + dev = skb_dst_dev_rcu(skb); skb->dev = dev; skb->protocol = htons(ETH_P_IP); - return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, - net, sk, skb, indev, dev, - ip_finish_output, - !(IPCB(skb)->flags & IPSKB_REROUTED)); + ret_val = NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, + net, sk, skb, indev, dev, + ip_finish_output, + !(IPCB(skb)->flags & IPSKB_REROUTED)); + rcu_read_unlock(); + return ret_val; } EXPORT_SYMBOL(ip_output); @@ -465,26 +482,18 @@ int __ip_queue_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl, goto packet_routed; /* Make sure we can route this packet. */ - rt = (struct rtable *)__sk_dst_check(sk, 0); + rt = dst_rtable(__sk_dst_check(sk, 0)); if (!rt) { - __be32 daddr; + inet_sk_init_flowi4(inet, fl4); - /* Use correct destination address if we have options. */ - daddr = inet->inet_daddr; - if (inet_opt && inet_opt->opt.srr) - daddr = inet_opt->opt.faddr; + /* sctp_v4_xmit() uses its own DSCP value */ + fl4->flowi4_dscp = inet_dsfield_to_dscp(tos); /* If this fails, retransmit mechanism of transport layer will * keep trying until route appears or the connection times * itself out. */ - rt = ip_route_output_ports(net, fl4, sk, - daddr, inet->inet_saddr, - inet->inet_dport, - inet->inet_sport, - sk->sk_protocol, - RT_CONN_FLAGS_TOS(sk, tos), - sk->sk_bound_dev_if); + rt = ip_route_output_flow(net, fl4, sk); if (IS_ERR(rt)) goto no_route; sk_setup_caps(sk, &rt->dst); @@ -512,15 +521,15 @@ packet_routed: if (inet_opt && inet_opt->opt.optlen) { iph->ihl += inet_opt->opt.optlen >> 2; - ip_options_build(skb, &inet_opt->opt, inet->inet_daddr, rt, 0); + ip_options_build(skb, &inet_opt->opt, inet->inet_daddr, rt); } ip_select_ident_segs(net, skb, sk, skb_shinfo(skb)->gso_segs ?: 1); /* TODO : should we use skb->sk here instead of sk ? */ - skb->priority = sk->sk_priority; - skb->mark = sk->sk_mark; + skb->priority = READ_ONCE(sk->sk_priority); + skb->mark = READ_ONCE(sk->sk_mark); res = ip_local_out(net, sk, skb); rcu_read_unlock(); @@ -529,14 +538,14 @@ packet_routed: no_route: rcu_read_unlock(); IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_IP_OUTNOROUTES); return -EHOSTUNREACH; } EXPORT_SYMBOL(__ip_queue_xmit); int ip_queue_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl) { - return __ip_queue_xmit(sk, skb, fl, inet_sk(sk)->tos); + return __ip_queue_xmit(sk, skb, fl, READ_ONCE(inet_sk(sk)->tos)); } EXPORT_SYMBOL(ip_queue_xmit); @@ -754,6 +763,7 @@ int ip_do_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, { struct iphdr *iph; struct sk_buff *skb2; + u8 tstamp_type = skb->tstamp_type; struct rtable *rt = skb_rtable(skb); unsigned int mtu, hlen, ll_rs; struct ip_fraglist_iter iter; @@ -825,18 +835,27 @@ int ip_do_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, /* Everything is OK. Generate! */ ip_fraglist_init(skb, iph, hlen, &iter); - if (iter.frag) - ip_options_fragment(iter.frag); - for (;;) { /* Prepare header of the next frame, * before previous one went down. */ if (iter.frag) { + bool first_frag = (iter.offset == 0); + IPCB(iter.frag)->flags = IPCB(skb)->flags; ip_fraglist_prepare(skb, &iter); + if (first_frag && IPCB(skb)->opt.optlen) { + /* ipcb->opt is not populated for frags + * coming from __ip_make_skb(), + * ip_options_fragment() needs optlen + */ + IPCB(iter.frag)->opt.optlen = + IPCB(skb)->opt.optlen; + ip_options_fragment(iter.frag); + ip_send_check(iter.iph); + } } - skb->tstamp = tstamp; + skb_set_delivery_time(skb, tstamp, tstamp_type); err = output(net, sk, skb); if (!err) @@ -892,7 +911,7 @@ slow_path: /* * Put this fragment into the sending queue. */ - skb2->tstamp = tstamp; + skb_set_delivery_time(skb2, tstamp, tstamp_type); err = output(net, sk, skb2); if (err) goto fail; @@ -928,17 +947,6 @@ ip_generic_getfrag(void *from, char *to, int offset, int len, int odd, struct sk } EXPORT_SYMBOL(ip_generic_getfrag); -static inline __wsum -csum_page(struct page *page, int offset, int copy) -{ - char *kaddr; - __wsum csum; - kaddr = kmap(page); - csum = csum_partial(kaddr + offset, copy, 0); - kunmap(page); - return csum; -} - static int __ip_append_data(struct sock *sk, struct flowi4 *fl4, struct sk_buff_head *queue, @@ -952,7 +960,6 @@ static int __ip_append_data(struct sock *sk, struct inet_sock *inet = inet_sk(sk); struct ubuf_info *uarg = NULL; struct sk_buff *skb; - struct ip_options *opt = cork->opt; int hh_len; int exthdrlen; @@ -960,11 +967,12 @@ static int __ip_append_data(struct sock *sk, int copy; int err; int offset = 0; + bool zc = false; unsigned int maxfraglen, fragheaderlen, maxnonfragsize; int csummode = CHECKSUM_NONE; - struct rtable *rt = (struct rtable *)cork->dst; + struct rtable *rt = dst_rtable(cork->dst); + bool paged, hold_tskey = false, extra_uref = false; unsigned int wmem_alloc_delta = 0; - bool paged, extra_uref = false; u32 tskey = 0; skb = skb_peek_tail(queue); @@ -973,10 +981,6 @@ static int __ip_append_data(struct sock *sk, mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize; paged = !!cork->gso_size; - if (cork->tx_flags & SKBTX_ANY_SW_TSTAMP && - sk->sk_tsflags & SOF_TIMESTAMPING_OPT_ID) - tskey = sk->sk_tskey++; - hh_len = LL_RESERVED_SPACE(rt->dst.dev); fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0); @@ -1000,22 +1004,60 @@ static int __ip_append_data(struct sock *sk, (!exthdrlen || (rt->dst.dev->features & NETIF_F_HW_ESP_TX_CSUM))) csummode = CHECKSUM_PARTIAL; - if (flags & MSG_ZEROCOPY && length && sock_flag(sk, SOCK_ZEROCOPY)) { - uarg = msg_zerocopy_realloc(sk, length, skb_zcopy(skb)); - if (!uarg) - return -ENOBUFS; - extra_uref = !skb_zcopy(skb); /* only ref on new uarg */ + if ((flags & MSG_ZEROCOPY) && length) { + struct msghdr *msg = from; + + if (getfrag == ip_generic_getfrag && msg->msg_ubuf) { + if (skb_zcopy(skb) && msg->msg_ubuf != skb_zcopy(skb)) + return -EINVAL; + + /* Leave uarg NULL if can't zerocopy, callers should + * be able to handle it. + */ + if ((rt->dst.dev->features & NETIF_F_SG) && + csummode == CHECKSUM_PARTIAL) { + paged = true; + zc = true; + uarg = msg->msg_ubuf; + } + } else if (sock_flag(sk, SOCK_ZEROCOPY)) { + uarg = msg_zerocopy_realloc(sk, length, skb_zcopy(skb), + false); + if (!uarg) + return -ENOBUFS; + extra_uref = !skb_zcopy(skb); /* only ref on new uarg */ + if (rt->dst.dev->features & NETIF_F_SG && + csummode == CHECKSUM_PARTIAL) { + paged = true; + zc = true; + } else { + uarg_to_msgzc(uarg)->zerocopy = 0; + skb_zcopy_set(skb, uarg, &extra_uref); + } + } + } else if ((flags & MSG_SPLICE_PAGES) && length) { + if (inet_test_bit(HDRINCL, sk)) + return -EPERM; if (rt->dst.dev->features & NETIF_F_SG && - csummode == CHECKSUM_PARTIAL) { + getfrag == ip_generic_getfrag) + /* We need an empty buffer to attach stuff to */ paged = true; - } else { - uarg->zerocopy = 0; - skb_zcopy_set(skb, uarg, &extra_uref); - } + else + flags &= ~MSG_SPLICE_PAGES; } cork->length += length; + if (cork->tx_flags & SKBTX_ANY_TSTAMP && + READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID) { + if (cork->flags & IPCORK_TS_OPT_ID) { + tskey = cork->ts_opt_id; + } else { + tskey = atomic_inc_return(&sk->sk_tskey) - 1; + hold_tskey = true; + } + } + /* So, what's going on in the loop below? * * We use calculated fragment length to generate chained skb, @@ -1075,8 +1117,8 @@ alloc_new_skb: !(rt->dst.dev->features & NETIF_F_SG))) alloclen = fraglen; else { - alloclen = min_t(int, fraglen, MAX_HEADER); - pagedlen = fraglen - alloclen; + alloclen = fragheaderlen + transhdrlen; + pagedlen = datalen - transhdrlen; } alloclen += alloc_extra; @@ -1123,10 +1165,18 @@ alloc_new_skb: } copy = datalen - transhdrlen - fraggap - pagedlen; - if (copy > 0 && getfrag(from, data + transhdrlen, offset, copy, fraggap, skb) < 0) { + /* [!] NOTE: copy will be negative if pagedlen>0 + * because then the equation reduces to -fraggap. + */ + if (copy > 0 && + INDIRECT_CALL_1(getfrag, ip_generic_getfrag, + from, data + transhdrlen, offset, + copy, fraggap, skb) < 0) { err = -EFAULT; kfree_skb(skb); goto error; + } else if (flags & MSG_SPLICE_PAGES) { + copy = 0; } offset += copy; @@ -1165,19 +1215,33 @@ alloc_new_skb: unsigned int off; off = skb->len; - if (getfrag(from, skb_put(skb, copy), - offset, copy, off, skb) < 0) { + if (INDIRECT_CALL_1(getfrag, ip_generic_getfrag, + from, skb_put(skb, copy), + offset, copy, off, skb) < 0) { __skb_trim(skb, off); err = -EFAULT; goto error; } - } else if (!uarg || !uarg->zerocopy) { + } else if (flags & MSG_SPLICE_PAGES) { + struct msghdr *msg = from; + + err = -EIO; + if (WARN_ON_ONCE(copy > msg->msg_iter.count)) + goto error; + + err = skb_splice_from_iter(skb, &msg->msg_iter, copy); + if (err < 0) + goto error; + copy = err; + wmem_alloc_delta += copy; + } else if (!zc) { int i = skb_shinfo(skb)->nr_frags; err = -ENOMEM; if (!sk_page_frag_refill(sk, pfrag)) goto error; + skb_zcopy_downgrade_managed(skb); if (!skb_can_coalesce(skb, i, pfrag->page, pfrag->offset)) { err = -EMSGSIZE; @@ -1190,16 +1254,15 @@ alloc_new_skb: get_page(pfrag->page); } copy = min_t(int, copy, pfrag->size - pfrag->offset); - if (getfrag(from, + if (INDIRECT_CALL_1(getfrag, ip_generic_getfrag, + from, page_address(pfrag->page) + pfrag->offset, offset, copy, skb->len, skb) < 0) goto error_efault; pfrag->offset += copy; skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); - skb->len += copy; - skb->data_len += copy; - skb->truesize += copy; + skb_len_add(skb, copy); wmem_alloc_delta += copy; } else { err = skb_zerocopy_iter_dgram(skb, from, copy); @@ -1221,6 +1284,8 @@ error: cork->length -= length; IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS); refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); + if (hold_tskey) + atomic_dec(&sk->sk_tskey); return err; } @@ -1234,6 +1299,12 @@ static int ip_setup_cork(struct sock *sk, struct inet_cork *cork, if (unlikely(!rt)) return -EFAULT; + cork->fragsize = ip_sk_use_pmtu(sk) ? + dst_mtu(&rt->dst) : READ_ONCE(rt->dst.dev->mtu); + + if (!inetdev_valid_mtu(cork->fragsize)) + return -ENETUNREACH; + /* * setup for corking. */ @@ -1250,12 +1321,6 @@ static int ip_setup_cork(struct sock *sk, struct inet_cork *cork, cork->addr = ipc->addr; } - cork->fragsize = ip_sk_use_pmtu(sk) ? - dst_mtu(&rt->dst) : READ_ONCE(rt->dst.dev->mtu); - - if (!inetdev_valid_mtu(cork->fragsize)) - return -ENETUNREACH; - cork->gso_size = ipc->gso_size; cork->dst = &rt->dst; @@ -1266,19 +1331,23 @@ static int ip_setup_cork(struct sock *sk, struct inet_cork *cork, cork->ttl = ipc->ttl; cork->tos = ipc->tos; cork->mark = ipc->sockc.mark; - cork->priority = ipc->priority; + cork->priority = ipc->sockc.priority; cork->transmit_time = ipc->sockc.transmit_time; cork->tx_flags = 0; - sock_tx_timestamp(sk, ipc->sockc.tsflags, &cork->tx_flags); + sock_tx_timestamp(sk, &ipc->sockc, &cork->tx_flags); + if (ipc->sockc.tsflags & SOCKCM_FLAG_TS_OPT_ID) { + cork->flags |= IPCORK_TS_OPT_ID; + cork->ts_opt_id = ipc->sockc.ts_opt_id; + } return 0; } /* - * ip_append_data() and ip_append_page() can make one large IP datagram - * from many pieces of data. Each pieces will be holded on the socket - * until ip_push_pending_frames() is called. Each piece can be a page - * or non-page data. + * ip_append_data() can make one large IP datagram from many pieces of + * data. Each piece will be held on the socket until + * ip_push_pending_frames() is called. Each piece can be a page or + * non-page data. * * Not only UDP, other transport protocols - e.g. raw sockets - can use * this interface potentially. @@ -1311,136 +1380,6 @@ int ip_append_data(struct sock *sk, struct flowi4 *fl4, from, length, transhdrlen, flags); } -ssize_t ip_append_page(struct sock *sk, struct flowi4 *fl4, struct page *page, - int offset, size_t size, int flags) -{ - struct inet_sock *inet = inet_sk(sk); - struct sk_buff *skb; - struct rtable *rt; - struct ip_options *opt = NULL; - struct inet_cork *cork; - int hh_len; - int mtu; - int len; - int err; - unsigned int maxfraglen, fragheaderlen, fraggap, maxnonfragsize; - - if (inet->hdrincl) - return -EPERM; - - if (flags&MSG_PROBE) - return 0; - - if (skb_queue_empty(&sk->sk_write_queue)) - return -EINVAL; - - cork = &inet->cork.base; - rt = (struct rtable *)cork->dst; - if (cork->flags & IPCORK_OPT) - opt = cork->opt; - - if (!(rt->dst.dev->features & NETIF_F_SG)) - return -EOPNOTSUPP; - - hh_len = LL_RESERVED_SPACE(rt->dst.dev); - mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize; - - fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0); - maxfraglen = ((mtu - fragheaderlen) & ~7) + fragheaderlen; - maxnonfragsize = ip_sk_ignore_df(sk) ? 0xFFFF : mtu; - - if (cork->length + size > maxnonfragsize - fragheaderlen) { - ip_local_error(sk, EMSGSIZE, fl4->daddr, inet->inet_dport, - mtu - (opt ? opt->optlen : 0)); - return -EMSGSIZE; - } - - skb = skb_peek_tail(&sk->sk_write_queue); - if (!skb) - return -EINVAL; - - cork->length += size; - - while (size > 0) { - /* Check if the remaining data fits into current packet. */ - len = mtu - skb->len; - if (len < size) - len = maxfraglen - skb->len; - - if (len <= 0) { - struct sk_buff *skb_prev; - int alloclen; - - skb_prev = skb; - fraggap = skb_prev->len - maxfraglen; - - alloclen = fragheaderlen + hh_len + fraggap + 15; - skb = sock_wmalloc(sk, alloclen, 1, sk->sk_allocation); - if (unlikely(!skb)) { - err = -ENOBUFS; - goto error; - } - - /* - * Fill in the control structures - */ - skb->ip_summed = CHECKSUM_NONE; - skb->csum = 0; - skb_reserve(skb, hh_len); - - /* - * Find where to start putting bytes. - */ - skb_put(skb, fragheaderlen + fraggap); - skb_reset_network_header(skb); - skb->transport_header = (skb->network_header + - fragheaderlen); - if (fraggap) { - skb->csum = skb_copy_and_csum_bits(skb_prev, - maxfraglen, - skb_transport_header(skb), - fraggap); - skb_prev->csum = csum_sub(skb_prev->csum, - skb->csum); - pskb_trim_unique(skb_prev, maxfraglen); - } - - /* - * Put the packet on the pending queue. - */ - __skb_queue_tail(&sk->sk_write_queue, skb); - continue; - } - - if (len > size) - len = size; - - if (skb_append_pagefrags(skb, page, offset, len)) { - err = -EMSGSIZE; - goto error; - } - - if (skb->ip_summed == CHECKSUM_NONE) { - __wsum csum; - csum = csum_page(page, offset, len); - skb->csum = csum_block_add(skb->csum, csum, skb->len); - } - - skb->len += len; - skb->data_len += len; - skb->truesize += len; - refcount_add(len, &sk->sk_wmem_alloc); - offset += len; - size -= len; - } - return 0; - -error: - cork->length -= size; - IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS); - return err; -} - static void ip_cork_release(struct inet_cork *cork) { cork->flags &= ~IPCORK_OPT; @@ -1464,10 +1403,10 @@ struct sk_buff *__ip_make_skb(struct sock *sk, struct inet_sock *inet = inet_sk(sk); struct net *net = sock_net(sk); struct ip_options *opt = NULL; - struct rtable *rt = (struct rtable *)cork->dst; + struct rtable *rt = dst_rtable(cork->dst); struct iphdr *iph; + u8 pmtudisc, ttl; __be16 df = 0; - __u8 ttl; skb = __skb_dequeue(queue); if (!skb) @@ -1497,8 +1436,9 @@ struct sk_buff *__ip_make_skb(struct sock *sk, /* DF bit is set when we want to see DF on outgoing frames. * If ignore_df is set too, we still allow to fragment this frame * locally. */ - if (inet->pmtudisc == IP_PMTUDISC_DO || - inet->pmtudisc == IP_PMTUDISC_PROBE || + pmtudisc = READ_ONCE(inet->pmtudisc); + if (pmtudisc == IP_PMTUDISC_DO || + pmtudisc == IP_PMTUDISC_PROBE || (skb->len <= dst_mtu(&rt->dst) && ip_dont_fragment(sk, &rt->dst))) df = htons(IP_DF); @@ -1509,14 +1449,14 @@ struct sk_buff *__ip_make_skb(struct sock *sk, if (cork->ttl != 0) ttl = cork->ttl; else if (rt->rt_type == RTN_MULTICAST) - ttl = inet->mc_ttl; + ttl = READ_ONCE(inet->mc_ttl); else ttl = ip_select_ttl(inet, &rt->dst); iph = ip_hdr(skb); iph->version = 4; iph->ihl = 5; - iph->tos = (cork->tos != -1) ? cork->tos : inet->tos; + iph->tos = (cork->tos != -1) ? cork->tos : READ_ONCE(inet->tos); iph->frag_off = df; iph->ttl = ttl; iph->protocol = sk->sk_protocol; @@ -1525,12 +1465,15 @@ struct sk_buff *__ip_make_skb(struct sock *sk, if (opt) { iph->ihl += opt->optlen >> 2; - ip_options_build(skb, opt, cork->addr, rt, 0); + ip_options_build(skb, opt, cork->addr, rt); } - skb->priority = (cork->tos != -1) ? cork->priority: sk->sk_priority; + skb->priority = cork->priority; skb->mark = cork->mark; - skb->tstamp = cork->transmit_time; + if (sk_is_tcp(sk)) + skb_set_delivery_time(skb, cork->transmit_time, SKB_CLOCK_MONOTONIC); + else + skb_set_delivery_type_by_clockid(skb, cork->transmit_time, sk->sk_clockid); /* * Steal rt from cork.dst to avoid a pair of atomic_inc/atomic_dec * on dst refcount @@ -1538,9 +1481,20 @@ struct sk_buff *__ip_make_skb(struct sock *sk, cork->dst = NULL; skb_dst_set(skb, &rt->dst); - if (iph->protocol == IPPROTO_ICMP) - icmp_out_count(net, ((struct icmphdr *) - skb_transport_header(skb))->type); + if (iph->protocol == IPPROTO_ICMP) { + u8 icmp_type; + + /* For such sockets, transhdrlen is zero when do ip_append_data(), + * so icmphdr does not in skb linear region and can not get icmp_type + * by icmp_hdr(skb)->type. + */ + if (sk->sk_type == SOCK_RAW && + !(fl4->flowi4_flags & FLOWI_FLAG_KNOWN_NH)) + icmp_type = fl4->fl4_icmp_type; + else + icmp_type = icmp_hdr(skb)->type; + icmp_out_count(net, icmp_type); + } ip_cork_release(cork); out: @@ -1645,11 +1599,12 @@ static int ip_reply_glue_bits(void *dptr, char *to, int offset, * Generic function to send a packet as reply to another packet. * Used to send some TCP resets/acks so far. */ -void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb, +void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk, + struct sk_buff *skb, const struct ip_options *sopt, __be32 daddr, __be32 saddr, const struct ip_reply_arg *arg, - unsigned int len, u64 transmit_time) + unsigned int len, u64 transmit_time, u32 txhash) { struct ip_options_data replyopts; struct ipcm_cookie ipc; @@ -1680,22 +1635,22 @@ void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb, flowi4_init_output(&fl4, oif, IP4_REPLY_MARK(net, skb->mark) ?: sk->sk_mark, - RT_TOS(arg->tos), + arg->tos & INET_DSCP_MASK, RT_SCOPE_UNIVERSE, ip_hdr(skb)->protocol, ip_reply_arg_flowi_flags(arg), daddr, saddr, tcp_hdr(skb)->source, tcp_hdr(skb)->dest, arg->uid); security_skb_classify_flow(skb, flowi4_to_flowi_common(&fl4)); - rt = ip_route_output_key(net, &fl4); + rt = ip_route_output_flow(net, &fl4, sk); if (IS_ERR(rt)) return; - inet_sk(sk)->tos = arg->tos & ~INET_ECN_MASK; + inet_sk(sk)->tos = arg->tos; sk->sk_protocol = ip_hdr(skb)->protocol; sk->sk_bound_dev_if = arg->bound_dev_if; - sk->sk_sndbuf = sysctl_wmem_default; + sk->sk_sndbuf = READ_ONCE(sysctl_wmem_default); ipc.sockc.mark = fl4.flowi4_mark; err = ip_append_data(sk, &fl4, ip_reply_glue_bits, arg->iov->iov_base, len, 0, &ipc, &rt, MSG_DONTWAIT); @@ -1711,6 +1666,14 @@ void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb, arg->csumoffset) = csum_fold(csum_add(nskb->csum, arg->csum)); nskb->ip_summed = CHECKSUM_NONE; + if (orig_sk) { + skb_set_owner_edemux(nskb, (struct sock *)orig_sk); + psp_reply_set_decrypted(orig_sk, nskb); + } + if (transmit_time) + nskb->tstamp_type = SKB_CLOCK_MONOTONIC; + if (txhash) + skb_set_hash(nskb, txhash, PKT_HASH_TYPE_L4); ip_push_pending_frames(sk, &fl4); } out: diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c index 445a9ecaefa1..6d9c5c20b1c4 100644 --- a/net/ipv4/ip_sockglue.c +++ b/net/ipv4/ip_sockglue.c @@ -47,8 +47,6 @@ #include <linux/errqueue.h> #include <linux/uaccess.h> -#include <linux/bpfilter.h> - /* * SOL_IP control messages. */ @@ -130,20 +128,20 @@ static void ip_cmsg_recv_checksum(struct msghdr *msg, struct sk_buff *skb, static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb) { - char *secdata; - u32 seclen, secid; + struct lsm_context ctx; + u32 secid; int err; err = security_socket_getpeersec_dgram(NULL, skb, &secid); if (err) return; - err = security_secid_to_secctx(secid, &secdata, &seclen); - if (err) + err = security_secid_to_secctx(secid, &ctx); + if (err < 0) return; - put_cmsg(msg, SOL_IP, SCM_SECURITY, seclen, secdata); - security_release_secctx(secdata, seclen); + put_cmsg(msg, SOL_IP, SCM_SECURITY, ctx.len, ctx.context); + security_release_secctx(&ctx); } static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) @@ -171,8 +169,10 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk, struct sk_buff *skb, int tlen, int offset) { - struct inet_sock *inet = inet_sk(sk); - unsigned int flags = inet->cmsg_flags; + unsigned long flags = inet_cmsg_flags(inet_sk(sk)); + + if (!flags) + return; /* Ordered by supposed usage frequency */ if (flags & IP_CMSG_PKTINFO) { @@ -267,7 +267,7 @@ int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc, } #endif if (cmsg->cmsg_level == SOL_SOCKET) { - err = __sock_cmsg_send(sk, msg, cmsg, &ipc->sockc); + err = __sock_cmsg_send(sk, cmsg, &ipc->sockc); if (err) return err; continue; @@ -315,9 +315,16 @@ int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc, if (val < 0 || val > 255) return -EINVAL; ipc->tos = val; - ipc->priority = rt_tos2priority(ipc->tos); + ipc->sockc.priority = rt_tos2priority(ipc->tos); + break; + case IP_PROTOCOL: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + return -EINVAL; + val = *(int *)CMSG_DATA(cmsg); + if (val < 1 || val > 255) + return -EINVAL; + ipc->protocol = val; break; - default: return -EINVAL; } @@ -424,7 +431,7 @@ void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, serr->port = port; if (skb_pull(skb, payload - skb->data)) { - if (inet_sk(sk)->recverr_rfc4884) + if (inet_test_bit(RECVERR_RFC4884, sk)) ipv4_icmp_error_rfc4884(skb, &serr->ee.ee_rfc4884); skb_reset_transport_header(skb); @@ -433,15 +440,15 @@ void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, } kfree_skb(skb); } +EXPORT_SYMBOL_GPL(ip_icmp_error); void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 port, u32 info) { - struct inet_sock *inet = inet_sk(sk); struct sock_exterr_skb *serr; struct iphdr *iph; struct sk_buff *skb; - if (!inet->recverr) + if (!inet_test_bit(RECVERR, sk)) return; skb = alloc_skb(sizeof(struct iphdr), GFP_ATOMIC); @@ -502,7 +509,7 @@ static bool ipv4_datagram_support_cmsg(const struct sock *sk, * or without payload (SOF_TIMESTAMPING_OPT_TSONLY). */ info = PKTINFO_SKB_CB(skb); - if (!(sk->sk_tsflags & SOF_TIMESTAMPING_OPT_CMSG) || + if (!(READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_CMSG) || !info->ipi_ifindex) return false; @@ -560,7 +567,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) { sin->sin_family = AF_INET; sin->sin_addr.s_addr = ip_hdr(skb)->saddr; - if (inet_sk(sk)->cmsg_flags) + if (inet_cmsg_flags(inet_sk(sk))) ip_cmsg_recv(msg, skb); } @@ -578,38 +585,36 @@ out: void __ip_sock_set_tos(struct sock *sk, int val) { + u8 old_tos = inet_sk(sk)->tos; + if (sk->sk_type == SOCK_STREAM) { val &= ~INET_ECN_MASK; - val |= inet_sk(sk)->tos & INET_ECN_MASK; + val |= old_tos & INET_ECN_MASK; } - if (inet_sk(sk)->tos != val) { - inet_sk(sk)->tos = val; - sk->sk_priority = rt_tos2priority(val); + if (old_tos != val) { + WRITE_ONCE(inet_sk(sk)->tos, val); + WRITE_ONCE(sk->sk_priority, rt_tos2priority(val)); sk_dst_reset(sk); } } void ip_sock_set_tos(struct sock *sk, int val) { - lock_sock(sk); + sockopt_lock_sock(sk); __ip_sock_set_tos(sk, val); - release_sock(sk); + sockopt_release_sock(sk); } EXPORT_SYMBOL(ip_sock_set_tos); void ip_sock_set_freebind(struct sock *sk) { - lock_sock(sk); - inet_sk(sk)->freebind = true; - release_sock(sk); + inet_set_bit(FREEBIND, sk); } EXPORT_SYMBOL(ip_sock_set_freebind); void ip_sock_set_recverr(struct sock *sk) { - lock_sock(sk); - inet_sk(sk)->recverr = true; - release_sock(sk); + inet_set_bit(RECVERR, sk); } EXPORT_SYMBOL(ip_sock_set_recverr); @@ -617,18 +622,14 @@ int ip_sock_set_mtu_discover(struct sock *sk, int val) { if (val < IP_PMTUDISC_DONT || val > IP_PMTUDISC_OMIT) return -EINVAL; - lock_sock(sk); - inet_sk(sk)->pmtudisc = val; - release_sock(sk); + WRITE_ONCE(inet_sk(sk)->pmtudisc, val); return 0; } EXPORT_SYMBOL(ip_sock_set_mtu_discover); void ip_sock_set_pktinfo(struct sock *sk) { - lock_sock(sk); - inet_sk(sk)->cmsg_flags |= IP_CMSG_PKTINFO; - release_sock(sk); + inet_set_bit(PKTINFO, sk); } EXPORT_SYMBOL(ip_sock_set_pktinfo); @@ -772,7 +773,7 @@ static int ip_set_mcast_msfilter(struct sock *sk, sockptr_t optval, int optlen) if (optlen < GROUP_FILTER_SIZE(0)) return -EINVAL; - if (optlen > sysctl_optmem_max) + if (optlen > READ_ONCE(sock_net(sk)->core.sysctl_optmem_max)) return -ENOBUFS; gsf = memdup_sockptr(optval, optlen); @@ -782,7 +783,7 @@ static int ip_set_mcast_msfilter(struct sock *sk, sockptr_t optval, int optlen) /* numsrc >= (4G-140)/128 overflow in 32 bits */ err = -ENOBUFS; if (gsf->gf_numsrc >= 0x1ffffff || - gsf->gf_numsrc > sock_net(sk)->ipv4.sysctl_igmp_max_msf) + gsf->gf_numsrc > READ_ONCE(sock_net(sk)->ipv4.sysctl_igmp_max_msf)) goto out_free_gsf; err = -EINVAL; @@ -808,7 +809,7 @@ static int compat_ip_set_mcast_msfilter(struct sock *sk, sockptr_t optval, if (optlen < size0) return -EINVAL; - if (optlen > sysctl_optmem_max - 4) + if (optlen > READ_ONCE(sock_net(sk)->core.sysctl_optmem_max) - 4) return -ENOBUFS; p = kmalloc(optlen + 4, GFP_KERNEL); @@ -832,7 +833,7 @@ static int compat_ip_set_mcast_msfilter(struct sock *sk, sockptr_t optval, /* numsrc >= (4G-140)/128 overflow in 32 bits */ err = -ENOBUFS; - if (n > sock_net(sk)->ipv4.sysctl_igmp_max_msf) + if (n > READ_ONCE(sock_net(sk)->ipv4.sysctl_igmp_max_msf)) goto out_free_gsf; err = set_mcast_msfilter(sk, gf32->gf_interface, n, gf32->gf_fmode, &gf32->gf_group, gf32->gf_slist_flex); @@ -888,12 +889,12 @@ static int compat_ip_mcast_join_leave(struct sock *sk, int optname, DEFINE_STATIC_KEY_FALSE(ip4_min_ttl); -static int do_ip_setsockopt(struct sock *sk, int level, int optname, - sockptr_t optval, unsigned int optlen) +int do_ip_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) { struct inet_sock *inet = inet_sk(sk); struct net *net = sock_net(sk); - int val = 0, err; + int val = 0, err, retv; bool needs_rtnl = setsockopt_needs_rtnl(optname); switch (optname) { @@ -922,6 +923,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, case IP_CHECKSUM: case IP_RECVFRAGSIZE: case IP_RECVERR_RFC4884: + case IP_LOCAL_PORT_RANGE: if (optlen >= sizeof(int)) { if (copy_from_sockptr(&val, optval, sizeof(val))) return -EFAULT; @@ -936,15 +938,144 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, /* If optlen==0, it is equivalent to val == 0 */ - if (optname == IP_ROUTER_ALERT) - return ip_ra_control(sk, val ? 1 : 0, NULL); + if (optname == IP_ROUTER_ALERT) { + retv = ip_ra_control(sk, val ? 1 : 0, NULL); + if (retv == 0) + inet_assign_bit(RTALERT, sk, val); + return retv; + } if (ip_mroute_opt(optname)) return ip_mroute_setsockopt(sk, optname, optval, optlen); + /* Handle options that can be set without locking the socket. */ + switch (optname) { + case IP_PKTINFO: + inet_assign_bit(PKTINFO, sk, val); + return 0; + case IP_RECVTTL: + inet_assign_bit(TTL, sk, val); + return 0; + case IP_RECVTOS: + inet_assign_bit(TOS, sk, val); + return 0; + case IP_RECVOPTS: + inet_assign_bit(RECVOPTS, sk, val); + return 0; + case IP_RETOPTS: + inet_assign_bit(RETOPTS, sk, val); + return 0; + case IP_PASSSEC: + inet_assign_bit(PASSSEC, sk, val); + return 0; + case IP_RECVORIGDSTADDR: + inet_assign_bit(ORIGDSTADDR, sk, val); + return 0; + case IP_RECVFRAGSIZE: + if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM) + return -EINVAL; + inet_assign_bit(RECVFRAGSIZE, sk, val); + return 0; + case IP_RECVERR: + inet_assign_bit(RECVERR, sk, val); + if (!val) + skb_errqueue_purge(&sk->sk_error_queue); + return 0; + case IP_RECVERR_RFC4884: + if (val < 0 || val > 1) + return -EINVAL; + inet_assign_bit(RECVERR_RFC4884, sk, val); + return 0; + case IP_FREEBIND: + if (optlen < 1) + return -EINVAL; + inet_assign_bit(FREEBIND, sk, val); + return 0; + case IP_HDRINCL: + if (sk->sk_type != SOCK_RAW) + return -ENOPROTOOPT; + inet_assign_bit(HDRINCL, sk, val); + return 0; + case IP_MULTICAST_LOOP: + if (optlen < 1) + return -EINVAL; + inet_assign_bit(MC_LOOP, sk, val); + return 0; + case IP_MULTICAST_ALL: + if (optlen < 1) + return -EINVAL; + if (val != 0 && val != 1) + return -EINVAL; + inet_assign_bit(MC_ALL, sk, val); + return 0; + case IP_TRANSPARENT: + if (!!val && !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && + !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (optlen < 1) + return -EINVAL; + inet_assign_bit(TRANSPARENT, sk, val); + return 0; + case IP_NODEFRAG: + if (sk->sk_type != SOCK_RAW) + return -ENOPROTOOPT; + inet_assign_bit(NODEFRAG, sk, val); + return 0; + case IP_BIND_ADDRESS_NO_PORT: + inet_assign_bit(BIND_ADDRESS_NO_PORT, sk, val); + return 0; + case IP_TTL: + if (optlen < 1) + return -EINVAL; + if (val != -1 && (val < 1 || val > 255)) + return -EINVAL; + WRITE_ONCE(inet->uc_ttl, val); + return 0; + case IP_MINTTL: + if (optlen < 1) + return -EINVAL; + if (val < 0 || val > 255) + return -EINVAL; + + if (val) + static_branch_enable(&ip4_min_ttl); + + WRITE_ONCE(inet->min_ttl, val); + return 0; + case IP_MULTICAST_TTL: + if (sk->sk_type == SOCK_STREAM) + return -EINVAL; + if (optlen < 1) + return -EINVAL; + if (val == -1) + val = 1; + if (val < 0 || val > 255) + return -EINVAL; + WRITE_ONCE(inet->mc_ttl, val); + return 0; + case IP_MTU_DISCOVER: + return ip_sock_set_mtu_discover(sk, val); + case IP_TOS: /* This sets both TOS and Precedence */ + ip_sock_set_tos(sk, val); + return 0; + case IP_LOCAL_PORT_RANGE: + { + u16 lo = val; + u16 hi = val >> 16; + + if (optlen != sizeof(u32)) + return -EINVAL; + if (lo != 0 && hi != 0 && lo > hi) + return -EINVAL; + + WRITE_ONCE(inet->local_port_range, val); + return 0; + } + } + err = 0; if (needs_rtnl) rtnl_lock(); - lock_sock(sk); + sockopt_lock_sock(sk); switch (optname) { case IP_OPTIONS: @@ -958,7 +1089,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, break; old = rcu_dereference_protected(inet->inet_opt, lockdep_sock_is_held(sk)); - if (inet->is_icsk) { + if (inet_test_bit(IS_ICSK, sk)) { struct inet_connection_sock *icsk = inet_csk(sk); #if IS_ENABLED(CONFIG_IPV6) if (sk->sk_family == PF_INET || @@ -980,127 +1111,19 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, kfree_rcu(old, rcu); break; } - case IP_PKTINFO: - if (val) - inet->cmsg_flags |= IP_CMSG_PKTINFO; - else - inet->cmsg_flags &= ~IP_CMSG_PKTINFO; - break; - case IP_RECVTTL: - if (val) - inet->cmsg_flags |= IP_CMSG_TTL; - else - inet->cmsg_flags &= ~IP_CMSG_TTL; - break; - case IP_RECVTOS: - if (val) - inet->cmsg_flags |= IP_CMSG_TOS; - else - inet->cmsg_flags &= ~IP_CMSG_TOS; - break; - case IP_RECVOPTS: - if (val) - inet->cmsg_flags |= IP_CMSG_RECVOPTS; - else - inet->cmsg_flags &= ~IP_CMSG_RECVOPTS; - break; - case IP_RETOPTS: - if (val) - inet->cmsg_flags |= IP_CMSG_RETOPTS; - else - inet->cmsg_flags &= ~IP_CMSG_RETOPTS; - break; - case IP_PASSSEC: - if (val) - inet->cmsg_flags |= IP_CMSG_PASSSEC; - else - inet->cmsg_flags &= ~IP_CMSG_PASSSEC; - break; - case IP_RECVORIGDSTADDR: - if (val) - inet->cmsg_flags |= IP_CMSG_ORIGDSTADDR; - else - inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR; - break; case IP_CHECKSUM: if (val) { - if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) { + if (!(inet_test_bit(CHECKSUM, sk))) { inet_inc_convert_csum(sk); - inet->cmsg_flags |= IP_CMSG_CHECKSUM; + inet_set_bit(CHECKSUM, sk); } } else { - if (inet->cmsg_flags & IP_CMSG_CHECKSUM) { + if (inet_test_bit(CHECKSUM, sk)) { inet_dec_convert_csum(sk); - inet->cmsg_flags &= ~IP_CMSG_CHECKSUM; + inet_clear_bit(CHECKSUM, sk); } } break; - case IP_RECVFRAGSIZE: - if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM) - goto e_inval; - if (val) - inet->cmsg_flags |= IP_CMSG_RECVFRAGSIZE; - else - inet->cmsg_flags &= ~IP_CMSG_RECVFRAGSIZE; - break; - case IP_TOS: /* This sets both TOS and Precedence */ - __ip_sock_set_tos(sk, val); - break; - case IP_TTL: - if (optlen < 1) - goto e_inval; - if (val != -1 && (val < 1 || val > 255)) - goto e_inval; - inet->uc_ttl = val; - break; - case IP_HDRINCL: - if (sk->sk_type != SOCK_RAW) { - err = -ENOPROTOOPT; - break; - } - inet->hdrincl = val ? 1 : 0; - break; - case IP_NODEFRAG: - if (sk->sk_type != SOCK_RAW) { - err = -ENOPROTOOPT; - break; - } - inet->nodefrag = val ? 1 : 0; - break; - case IP_BIND_ADDRESS_NO_PORT: - inet->bind_address_no_port = val ? 1 : 0; - break; - case IP_MTU_DISCOVER: - if (val < IP_PMTUDISC_DONT || val > IP_PMTUDISC_OMIT) - goto e_inval; - inet->pmtudisc = val; - break; - case IP_RECVERR: - inet->recverr = !!val; - if (!val) - skb_queue_purge(&sk->sk_error_queue); - break; - case IP_RECVERR_RFC4884: - if (val < 0 || val > 1) - goto e_inval; - inet->recverr_rfc4884 = !!val; - break; - case IP_MULTICAST_TTL: - if (sk->sk_type == SOCK_STREAM) - goto e_inval; - if (optlen < 1) - goto e_inval; - if (val == -1) - val = 1; - if (val < 0 || val > 255) - goto e_inval; - inet->mc_ttl = val; - break; - case IP_MULTICAST_LOOP: - if (optlen < 1) - goto e_inval; - inet->mc_loop = !!val; - break; case IP_UNICAST_IF: { struct net_device *dev = NULL; @@ -1112,7 +1135,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, ifindex = (__force int)ntohl((__force __be32)val); if (ifindex == 0) { - inet->uc_index = 0; + WRITE_ONCE(inet->uc_index, 0); err = 0; break; } @@ -1129,7 +1152,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, if (sk->sk_bound_dev_if && midx != sk->sk_bound_dev_if) break; - inet->uc_index = ifindex; + WRITE_ONCE(inet->uc_index, ifindex); err = 0; break; } @@ -1167,8 +1190,8 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, if (!mreq.imr_ifindex) { if (mreq.imr_address.s_addr == htonl(INADDR_ANY)) { - inet->mc_index = 0; - inet->mc_addr = 0; + WRITE_ONCE(inet->mc_index, 0); + WRITE_ONCE(inet->mc_addr, 0); err = 0; break; } @@ -1193,8 +1216,8 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, midx != sk->sk_bound_dev_if) break; - inet->mc_index = mreq.imr_ifindex; - inet->mc_addr = mreq.imr_address.s_addr; + WRITE_ONCE(inet->mc_index, mreq.imr_ifindex); + WRITE_ONCE(inet->mc_addr, mreq.imr_address.s_addr); err = 0; break; } @@ -1205,7 +1228,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, struct ip_mreqn mreq; err = -EPROTO; - if (inet_sk(sk)->is_icsk) + if (inet_test_bit(IS_ICSK, sk)) break; if (optlen < sizeof(struct ip_mreq)) @@ -1233,7 +1256,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, if (optlen < IP_MSFILTER_SIZE(0)) goto e_inval; - if (optlen > sysctl_optmem_max) { + if (optlen > READ_ONCE(net->core.sysctl_optmem_max)) { err = -ENOBUFS; break; } @@ -1244,7 +1267,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, } /* numsrc >= (1G-4) overflow in 32 bits */ if (msf->imsf_numsrc >= 0x3ffffffcU || - msf->imsf_numsrc > net->ipv4.sysctl_igmp_max_msf) { + msf->imsf_numsrc > READ_ONCE(net->ipv4.sysctl_igmp_max_msf)) { kfree(msf); err = -ENOBUFS; break; @@ -1316,65 +1339,25 @@ static int do_ip_setsockopt(struct sock *sk, int level, int optname, else err = ip_set_mcast_msfilter(sk, optval, optlen); break; - case IP_MULTICAST_ALL: - if (optlen < 1) - goto e_inval; - if (val != 0 && val != 1) - goto e_inval; - inet->mc_all = val; - break; - - case IP_FREEBIND: - if (optlen < 1) - goto e_inval; - inet->freebind = !!val; - break; - case IP_IPSEC_POLICY: case IP_XFRM_POLICY: err = -EPERM; - if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) break; err = xfrm_user_policy(sk, optname, optval, optlen); break; - case IP_TRANSPARENT: - if (!!val && !ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && - !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { - err = -EPERM; - break; - } - if (optlen < 1) - goto e_inval; - inet->transparent = !!val; - break; - - case IP_MINTTL: - if (optlen < 1) - goto e_inval; - if (val < 0 || val > 255) - goto e_inval; - - if (val) - static_branch_enable(&ip4_min_ttl); - - /* tcp_v4_err() and tcp_v4_rcv() might read min_ttl - * while we are changint it. - */ - WRITE_ONCE(inet->min_ttl, val); - break; - default: err = -ENOPROTOOPT; break; } - release_sock(sk); + sockopt_release_sock(sk); if (needs_rtnl) rtnl_unlock(); return err; e_inval: - release_sock(sk); + sockopt_release_sock(sk); if (needs_rtnl) rtnl_unlock(); return -EINVAL; @@ -1384,15 +1367,16 @@ e_inval: * ipv4_pktinfo_prepare - transfer some info from rtable to skb * @sk: socket * @skb: buffer + * @drop_dst: if true, drops skb dst * * To support IP_CMSG_PKTINFO option, we store rt_iif and specific * destination in skb->cb[] before dst drop. * This way, receiver doesn't make cache line misses to read rtable. */ -void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb) +void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb, bool drop_dst) { struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb); - bool prepare = (inet_sk(sk)->cmsg_flags & IP_CMSG_PKTINFO) || + bool prepare = inet_test_bit(PKTINFO, sk) || ipv6_sk_rxinfo(sk); if (prepare && skb_rtable(skb)) { @@ -1418,7 +1402,8 @@ void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb) pktinfo->ipi_ifindex = 0; pktinfo->ipi_spec_dst.s_addr = 0; } - skb_dst_drop(skb); + if (drop_dst) + skb_dst_drop(skb); } int ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, @@ -1430,11 +1415,6 @@ int ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, return -ENOPROTOOPT; err = do_ip_setsockopt(sk, level, optname, optval, optlen); -#if IS_ENABLED(CONFIG_BPFILTER_UMH) - if (optname >= BPFILTER_IPT_SO_SET_REPLACE && - optname < BPFILTER_IPT_SET_MAX) - err = bpfilter_ip_set_sockopt(sk, optname, optval, optlen); -#endif #ifdef CONFIG_NETFILTER /* we need to exclude all possible ENOPROTOOPTs except default case */ if (err == -ENOPROTOOPT && optname != IP_HDRINCL && @@ -1462,37 +1442,37 @@ static bool getsockopt_needs_rtnl(int optname) return false; } -static int ip_get_mcast_msfilter(struct sock *sk, void __user *optval, - int __user *optlen, int len) +static int ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) { const int size0 = offsetof(struct group_filter, gf_slist_flex); - struct group_filter __user *p = optval; struct group_filter gsf; - int num; + int num, gsf_size; int err; if (len < size0) return -EINVAL; - if (copy_from_user(&gsf, p, size0)) + if (copy_from_sockptr(&gsf, optval, size0)) return -EFAULT; num = gsf.gf_numsrc; - err = ip_mc_gsfget(sk, &gsf, p->gf_slist_flex); + err = ip_mc_gsfget(sk, &gsf, optval, + offsetof(struct group_filter, gf_slist_flex)); if (err) return err; if (gsf.gf_numsrc < num) num = gsf.gf_numsrc; - if (put_user(GROUP_FILTER_SIZE(num), optlen) || - copy_to_user(p, &gsf, size0)) + gsf_size = GROUP_FILTER_SIZE(num); + if (copy_to_sockptr(optlen, &gsf_size, sizeof(int)) || + copy_to_sockptr(optval, &gsf, size0)) return -EFAULT; return 0; } -static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval, - int __user *optlen, int len) +static int compat_ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) { const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); - struct compat_group_filter __user *p = optval; struct compat_group_filter gf32; struct group_filter gf; int num; @@ -1500,7 +1480,7 @@ static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval, if (len < size0) return -EINVAL; - if (copy_from_user(&gf32, p, size0)) + if (copy_from_sockptr(&gf32, optval, size0)) return -EFAULT; gf.gf_interface = gf32.gf_interface; @@ -1508,21 +1488,24 @@ static int compat_ip_get_mcast_msfilter(struct sock *sk, void __user *optval, num = gf.gf_numsrc = gf32.gf_numsrc; gf.gf_group = gf32.gf_group; - err = ip_mc_gsfget(sk, &gf, p->gf_slist_flex); + err = ip_mc_gsfget(sk, &gf, optval, + offsetof(struct compat_group_filter, gf_slist_flex)); if (err) return err; if (gf.gf_numsrc < num) num = gf.gf_numsrc; len = GROUP_FILTER_SIZE(num) - (sizeof(gf) - sizeof(gf32)); - if (put_user(len, optlen) || - put_user(gf.gf_fmode, &p->gf_fmode) || - put_user(gf.gf_numsrc, &p->gf_numsrc)) + if (copy_to_sockptr(optlen, &len, sizeof(int)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode), + &gf.gf_fmode, sizeof(gf.gf_fmode)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc), + &gf.gf_numsrc, sizeof(gf.gf_numsrc))) return -EFAULT; return 0; } -static int do_ip_getsockopt(struct sock *sk, int level, int optname, - char __user *optval, int __user *optlen) +int do_ip_getsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, sockptr_t optlen) { struct inet_sock *inet = inet_sk(sk); bool needs_rtnl = getsockopt_needs_rtnl(optname); @@ -1535,93 +1518,116 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, if (ip_mroute_opt(optname)) return ip_mroute_getsockopt(sk, optname, optval, optlen); - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; if (len < 0) return -EINVAL; - if (needs_rtnl) - rtnl_lock(); - lock_sock(sk); - + /* Handle options that can be read without locking the socket. */ switch (optname) { + case IP_PKTINFO: + val = inet_test_bit(PKTINFO, sk); + goto copyval; + case IP_RECVTTL: + val = inet_test_bit(TTL, sk); + goto copyval; + case IP_RECVTOS: + val = inet_test_bit(TOS, sk); + goto copyval; + case IP_RECVOPTS: + val = inet_test_bit(RECVOPTS, sk); + goto copyval; + case IP_RETOPTS: + val = inet_test_bit(RETOPTS, sk); + goto copyval; + case IP_PASSSEC: + val = inet_test_bit(PASSSEC, sk); + goto copyval; + case IP_RECVORIGDSTADDR: + val = inet_test_bit(ORIGDSTADDR, sk); + goto copyval; + case IP_CHECKSUM: + val = inet_test_bit(CHECKSUM, sk); + goto copyval; + case IP_RECVFRAGSIZE: + val = inet_test_bit(RECVFRAGSIZE, sk); + goto copyval; + case IP_RECVERR: + val = inet_test_bit(RECVERR, sk); + goto copyval; + case IP_RECVERR_RFC4884: + val = inet_test_bit(RECVERR_RFC4884, sk); + goto copyval; + case IP_FREEBIND: + val = inet_test_bit(FREEBIND, sk); + goto copyval; + case IP_HDRINCL: + val = inet_test_bit(HDRINCL, sk); + goto copyval; + case IP_MULTICAST_LOOP: + val = inet_test_bit(MC_LOOP, sk); + goto copyval; + case IP_MULTICAST_ALL: + val = inet_test_bit(MC_ALL, sk); + goto copyval; + case IP_TRANSPARENT: + val = inet_test_bit(TRANSPARENT, sk); + goto copyval; + case IP_NODEFRAG: + val = inet_test_bit(NODEFRAG, sk); + goto copyval; + case IP_BIND_ADDRESS_NO_PORT: + val = inet_test_bit(BIND_ADDRESS_NO_PORT, sk); + goto copyval; + case IP_ROUTER_ALERT: + val = inet_test_bit(RTALERT, sk); + goto copyval; + case IP_TTL: + val = READ_ONCE(inet->uc_ttl); + if (val < 0) + val = READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_default_ttl); + goto copyval; + case IP_MINTTL: + val = READ_ONCE(inet->min_ttl); + goto copyval; + case IP_MULTICAST_TTL: + val = READ_ONCE(inet->mc_ttl); + goto copyval; + case IP_MTU_DISCOVER: + val = READ_ONCE(inet->pmtudisc); + goto copyval; + case IP_TOS: + val = READ_ONCE(inet->tos); + goto copyval; case IP_OPTIONS: { unsigned char optbuf[sizeof(struct ip_options)+40]; struct ip_options *opt = (struct ip_options *)optbuf; struct ip_options_rcu *inet_opt; - inet_opt = rcu_dereference_protected(inet->inet_opt, - lockdep_sock_is_held(sk)); + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); opt->optlen = 0; if (inet_opt) memcpy(optbuf, &inet_opt->opt, sizeof(struct ip_options) + inet_opt->opt.optlen); - release_sock(sk); + rcu_read_unlock(); - if (opt->optlen == 0) - return put_user(0, optlen); + if (opt->optlen == 0) { + len = 0; + return copy_to_sockptr(optlen, &len, sizeof(int)); + } ip_options_undo(opt); len = min_t(unsigned int, len, opt->optlen); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, opt->__data, len)) + if (copy_to_sockptr(optval, opt->__data, len)) return -EFAULT; return 0; } - case IP_PKTINFO: - val = (inet->cmsg_flags & IP_CMSG_PKTINFO) != 0; - break; - case IP_RECVTTL: - val = (inet->cmsg_flags & IP_CMSG_TTL) != 0; - break; - case IP_RECVTOS: - val = (inet->cmsg_flags & IP_CMSG_TOS) != 0; - break; - case IP_RECVOPTS: - val = (inet->cmsg_flags & IP_CMSG_RECVOPTS) != 0; - break; - case IP_RETOPTS: - val = (inet->cmsg_flags & IP_CMSG_RETOPTS) != 0; - break; - case IP_PASSSEC: - val = (inet->cmsg_flags & IP_CMSG_PASSSEC) != 0; - break; - case IP_RECVORIGDSTADDR: - val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0; - break; - case IP_CHECKSUM: - val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0; - break; - case IP_RECVFRAGSIZE: - val = (inet->cmsg_flags & IP_CMSG_RECVFRAGSIZE) != 0; - break; - case IP_TOS: - val = inet->tos; - break; - case IP_TTL: - { - struct net *net = sock_net(sk); - val = (inet->uc_ttl == -1 ? - net->ipv4.sysctl_ip_default_ttl : - inet->uc_ttl); - break; - } - case IP_HDRINCL: - val = inet->hdrincl; - break; - case IP_NODEFRAG: - val = inet->nodefrag; - break; - case IP_BIND_ADDRESS_NO_PORT: - val = inet->bind_address_no_port; - break; - case IP_MTU_DISCOVER: - val = inet->pmtudisc; - break; case IP_MTU: { struct dst_entry *dst; @@ -1631,40 +1637,72 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, val = dst_mtu(dst); dst_release(dst); } - if (!val) { - release_sock(sk); + if (!val) return -ENOTCONN; + goto copyval; + } + case IP_PKTOPTIONS: + { + struct msghdr msg; + + if (sk->sk_type != SOCK_STREAM) + return -ENOPROTOOPT; + + if (optval.is_kernel) { + msg.msg_control_is_user = false; + msg.msg_control = optval.kernel; + } else { + msg.msg_control_is_user = true; + msg.msg_control_user = optval.user; } - break; + msg.msg_controllen = len; + msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0; + + if (inet_test_bit(PKTINFO, sk)) { + struct in_pktinfo info; + + info.ipi_addr.s_addr = READ_ONCE(inet->inet_rcv_saddr); + info.ipi_spec_dst.s_addr = READ_ONCE(inet->inet_rcv_saddr); + info.ipi_ifindex = READ_ONCE(inet->mc_index); + put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info); + } + if (inet_test_bit(TTL, sk)) { + int hlim = READ_ONCE(inet->mc_ttl); + + put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim); + } + if (inet_test_bit(TOS, sk)) { + int tos = READ_ONCE(inet->rcv_tos); + put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos); + } + len -= msg.msg_controllen; + return copy_to_sockptr(optlen, &len, sizeof(int)); } - case IP_RECVERR: - val = inet->recverr; - break; - case IP_RECVERR_RFC4884: - val = inet->recverr_rfc4884; - break; - case IP_MULTICAST_TTL: - val = inet->mc_ttl; - break; - case IP_MULTICAST_LOOP: - val = inet->mc_loop; - break; case IP_UNICAST_IF: - val = (__force int)htonl((__u32) inet->uc_index); - break; + val = (__force int)htonl((__u32) READ_ONCE(inet->uc_index)); + goto copyval; case IP_MULTICAST_IF: { struct in_addr addr; len = min_t(unsigned int, len, sizeof(struct in_addr)); - addr.s_addr = inet->mc_addr; - release_sock(sk); + addr.s_addr = READ_ONCE(inet->mc_addr); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &addr, len)) + if (copy_to_sockptr(optval, &addr, len)) return -EFAULT; return 0; } + case IP_LOCAL_PORT_RANGE: + val = READ_ONCE(inet->local_port_range); + goto copyval; + } + + if (needs_rtnl) + rtnl_lock(); + sockopt_lock_sock(sk); + + switch (optname) { case IP_MSFILTER: { struct ip_msfilter msf; @@ -1673,12 +1711,11 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, err = -EINVAL; goto out; } - if (copy_from_user(&msf, optval, IP_MSFILTER_SIZE(0))) { + if (copy_from_sockptr(&msf, optval, IP_MSFILTER_SIZE(0))) { err = -EFAULT; goto out; } - err = ip_mc_msfget(sk, &msf, - (struct ip_msfilter __user *)optval, optlen); + err = ip_mc_msfget(sk, &msf, optval, optlen); goto out; } case MCAST_MSFILTER: @@ -1688,75 +1725,33 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, else err = ip_get_mcast_msfilter(sk, optval, optlen, len); goto out; - case IP_MULTICAST_ALL: - val = inet->mc_all; - break; - case IP_PKTOPTIONS: - { - struct msghdr msg; - - release_sock(sk); - - if (sk->sk_type != SOCK_STREAM) - return -ENOPROTOOPT; - - msg.msg_control_is_user = true; - msg.msg_control_user = optval; - msg.msg_controllen = len; - msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0; - - if (inet->cmsg_flags & IP_CMSG_PKTINFO) { - struct in_pktinfo info; - - info.ipi_addr.s_addr = inet->inet_rcv_saddr; - info.ipi_spec_dst.s_addr = inet->inet_rcv_saddr; - info.ipi_ifindex = inet->mc_index; - put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info); - } - if (inet->cmsg_flags & IP_CMSG_TTL) { - int hlim = inet->mc_ttl; - put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim); - } - if (inet->cmsg_flags & IP_CMSG_TOS) { - int tos = inet->rcv_tos; - put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos); - } - len -= msg.msg_controllen; - return put_user(len, optlen); - } - case IP_FREEBIND: - val = inet->freebind; - break; - case IP_TRANSPARENT: - val = inet->transparent; - break; - case IP_MINTTL: - val = inet->min_ttl; + case IP_PROTOCOL: + val = inet_sk(sk)->inet_num; break; default: - release_sock(sk); + sockopt_release_sock(sk); return -ENOPROTOOPT; } - release_sock(sk); - + sockopt_release_sock(sk); +copyval: if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) { unsigned char ucval = (unsigned char)val; len = 1; - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &ucval, 1)) + if (copy_to_sockptr(optval, &ucval, 1)) return -EFAULT; } else { len = min_t(unsigned int, sizeof(int), len); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &val, len)) + if (copy_to_sockptr(optval, &val, len)) return -EFAULT; } return 0; out: - release_sock(sk); + sockopt_release_sock(sk); if (needs_rtnl) rtnl_unlock(); return err; @@ -1767,13 +1762,9 @@ int ip_getsockopt(struct sock *sk, int level, { int err; - err = do_ip_getsockopt(sk, level, optname, optval, optlen); + err = do_ip_getsockopt(sk, level, optname, + USER_SOCKPTR(optval), USER_SOCKPTR(optlen)); -#if IS_ENABLED(CONFIG_BPFILTER_UMH) - if (optname >= BPFILTER_IPT_SO_GET_INFO && - optname < BPFILTER_IPT_GET_MAX) - err = bpfilter_ip_get_sockopt(sk, optname, optval, optlen); -#endif #ifdef CONFIG_NETFILTER /* we need to exclude all possible ENOPROTOOPTs except default case */ if (err == -ENOPROTOOPT && optname != IP_PKTOPTIONS && diff --git a/net/ipv4/ip_tunnel.c b/net/ipv4/ip_tunnel.c index 5a473319d3a5..158a30ae7c5f 100644 --- a/net/ipv4/ip_tunnel.c +++ b/net/ipv4/ip_tunnel.c @@ -40,9 +40,11 @@ #include <net/xfrm.h> #include <net/net_namespace.h> #include <net/netns/generic.h> +#include <net/netdev_lock.h> #include <net/rtnetlink.h> #include <net/udp.h> #include <net/dst_metadata.h> +#include <net/inet_dscp.h> #if IS_ENABLED(CONFIG_IPV6) #include <net/ipv6.h> @@ -56,17 +58,13 @@ static unsigned int ip_tunnel_hash(__be32 key, __be32 remote) IP_TNL_HASH_BITS); } -static bool ip_tunnel_key_match(const struct ip_tunnel_parm *p, - __be16 flags, __be32 key) +static bool ip_tunnel_key_match(const struct ip_tunnel_parm_kern *p, + const unsigned long *flags, __be32 key) { - if (p->i_flags & TUNNEL_KEY) { - if (flags & TUNNEL_KEY) - return key == p->i_key; - else - /* key expected, none present */ - return false; - } else - return !(flags & TUNNEL_KEY); + if (!test_bit(IP_TUNNEL_KEY_BIT, flags)) + return !test_bit(IP_TUNNEL_KEY_BIT, p->i_flags); + + return test_bit(IP_TUNNEL_KEY_BIT, p->i_flags) && p->i_key == key; } /* Fallback tunnel: no source, no destination, no key, no options @@ -81,7 +79,7 @@ static bool ip_tunnel_key_match(const struct ip_tunnel_parm *p, Given src, dst and key, find appropriate for input tunnel. */ struct ip_tunnel *ip_tunnel_lookup(struct ip_tunnel_net *itn, - int link, __be16 flags, + int link, const unsigned long *flags, __be32 remote, __be32 local, __be32 key) { @@ -102,10 +100,9 @@ struct ip_tunnel *ip_tunnel_lookup(struct ip_tunnel_net *itn, if (!ip_tunnel_key_match(&t->parms, flags, key)) continue; - if (t->parms.link == link) + if (READ_ONCE(t->parms.link) == link) return t; - else - cand = t; + cand = t; } hlist_for_each_entry_rcu(t, head, hash_node) { @@ -117,9 +114,9 @@ struct ip_tunnel *ip_tunnel_lookup(struct ip_tunnel_net *itn, if (!ip_tunnel_key_match(&t->parms, flags, key)) continue; - if (t->parms.link == link) + if (READ_ONCE(t->parms.link) == link) return t; - else if (!cand) + if (!cand) cand = t; } @@ -137,22 +134,23 @@ struct ip_tunnel *ip_tunnel_lookup(struct ip_tunnel_net *itn, if (!ip_tunnel_key_match(&t->parms, flags, key)) continue; - if (t->parms.link == link) + if (READ_ONCE(t->parms.link) == link) return t; - else if (!cand) + if (!cand) cand = t; } hlist_for_each_entry_rcu(t, head, hash_node) { - if ((!(flags & TUNNEL_NO_KEY) && t->parms.i_key != key) || + if ((!test_bit(IP_TUNNEL_NO_KEY_BIT, flags) && + t->parms.i_key != key) || t->parms.iph.saddr != 0 || t->parms.iph.daddr != 0 || !(t->dev->flags & IFF_UP)) continue; - if (t->parms.link == link) + if (READ_ONCE(t->parms.link) == link) return t; - else if (!cand) + if (!cand) cand = t; } @@ -172,7 +170,7 @@ struct ip_tunnel *ip_tunnel_lookup(struct ip_tunnel_net *itn, EXPORT_SYMBOL_GPL(ip_tunnel_lookup); static struct hlist_head *ip_bucket(struct ip_tunnel_net *itn, - struct ip_tunnel_parm *parms) + struct ip_tunnel_parm_kern *parms) { unsigned int h; __be32 remote; @@ -183,7 +181,8 @@ static struct hlist_head *ip_bucket(struct ip_tunnel_net *itn, else remote = 0; - if (!(parms->i_flags & TUNNEL_KEY) && (parms->i_flags & VTI_ISVTI)) + if (!test_bit(IP_TUNNEL_KEY_BIT, parms->i_flags) && + test_bit(IP_TUNNEL_VTI_BIT, parms->i_flags)) i_key = 0; h = ip_tunnel_hash(i_key, remote); @@ -207,21 +206,23 @@ static void ip_tunnel_del(struct ip_tunnel_net *itn, struct ip_tunnel *t) } static struct ip_tunnel *ip_tunnel_find(struct ip_tunnel_net *itn, - struct ip_tunnel_parm *parms, + struct ip_tunnel_parm_kern *parms, int type) { __be32 remote = parms->iph.daddr; __be32 local = parms->iph.saddr; + IP_TUNNEL_DECLARE_FLAGS(flags); __be32 key = parms->i_key; - __be16 flags = parms->i_flags; int link = parms->link; struct ip_tunnel *t = NULL; struct hlist_head *head = ip_bucket(itn, parms); - hlist_for_each_entry_rcu(t, head, hash_node) { + ip_tunnel_flags_copy(flags, parms->i_flags); + + hlist_for_each_entry_rcu(t, head, hash_node, lockdep_rtnl_is_held()) { if (local == t->parms.iph.saddr && remote == t->parms.iph.daddr && - link == t->parms.link && + link == READ_ONCE(t->parms.link) && type == t->dev->type && ip_tunnel_key_match(&t->parms, flags, key)) break; @@ -231,7 +232,7 @@ static struct ip_tunnel *ip_tunnel_find(struct ip_tunnel_net *itn, static struct net_device *__ip_tunnel_create(struct net *net, const struct rtnl_link_ops *ops, - struct ip_tunnel_parm *parms) + struct ip_tunnel_parm_kern *parms) { int err; struct ip_tunnel *tunnel; @@ -242,11 +243,11 @@ static struct net_device *__ip_tunnel_create(struct net *net, if (parms->name[0]) { if (!dev_valid_name(parms->name)) goto failed; - strlcpy(name, parms->name, IFNAMSIZ); + strscpy(name, parms->name); } else { if (strlen(ops->kind) > (IFNAMSIZ - 3)) goto failed; - strcpy(name, ops->kind); + strscpy(name, ops->kind); strcat(name, "%d"); } @@ -294,8 +295,8 @@ static int ip_tunnel_bind_dev(struct net_device *dev) ip_tunnel_init_flow(&fl4, iph->protocol, iph->daddr, iph->saddr, tunnel->parms.o_key, - RT_TOS(iph->tos), tunnel->parms.link, - tunnel->fwmark, 0); + iph->tos & INET_DSCP_MASK, tunnel->net, + tunnel->parms.link, tunnel->fwmark, 0, 0); rt = ip_route_output_key(tunnel->net, &fl4); if (!IS_ERR(rt)) { @@ -327,7 +328,7 @@ static int ip_tunnel_bind_dev(struct net_device *dev) static struct ip_tunnel *ip_tunnel_create(struct net *net, struct ip_tunnel_net *itn, - struct ip_tunnel_parm *parms) + struct ip_tunnel_parm_kern *parms) { struct ip_tunnel *nt; struct net_device *dev; @@ -359,47 +360,74 @@ err_dev_set_mtu: return ERR_PTR(err); } +void ip_tunnel_md_udp_encap(struct sk_buff *skb, struct ip_tunnel_info *info) +{ + const struct iphdr *iph = ip_hdr(skb); + const struct udphdr *udph; + + if (iph->protocol != IPPROTO_UDP) + return; + + udph = (struct udphdr *)((__u8 *)iph + (iph->ihl << 2)); + info->encap.sport = udph->source; + info->encap.dport = udph->dest; +} +EXPORT_SYMBOL(ip_tunnel_md_udp_encap); + int ip_tunnel_rcv(struct ip_tunnel *tunnel, struct sk_buff *skb, const struct tnl_ptk_info *tpi, struct metadata_dst *tun_dst, bool log_ecn_error) { const struct iphdr *iph = ip_hdr(skb); - int err; + int nh, err; #ifdef CONFIG_NET_IPGRE_BROADCAST if (ipv4_is_multicast(iph->daddr)) { - tunnel->dev->stats.multicast++; + DEV_STATS_INC(tunnel->dev, multicast); skb->pkt_type = PACKET_BROADCAST; } #endif - if ((!(tpi->flags&TUNNEL_CSUM) && (tunnel->parms.i_flags&TUNNEL_CSUM)) || - ((tpi->flags&TUNNEL_CSUM) && !(tunnel->parms.i_flags&TUNNEL_CSUM))) { - tunnel->dev->stats.rx_crc_errors++; - tunnel->dev->stats.rx_errors++; + if (test_bit(IP_TUNNEL_CSUM_BIT, tunnel->parms.i_flags) != + test_bit(IP_TUNNEL_CSUM_BIT, tpi->flags)) { + DEV_STATS_INC(tunnel->dev, rx_crc_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } - if (tunnel->parms.i_flags&TUNNEL_SEQ) { - if (!(tpi->flags&TUNNEL_SEQ) || + if (test_bit(IP_TUNNEL_SEQ_BIT, tunnel->parms.i_flags)) { + if (!test_bit(IP_TUNNEL_SEQ_BIT, tpi->flags) || (tunnel->i_seqno && (s32)(ntohl(tpi->seq) - tunnel->i_seqno) < 0)) { - tunnel->dev->stats.rx_fifo_errors++; - tunnel->dev->stats.rx_errors++; + DEV_STATS_INC(tunnel->dev, rx_fifo_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } tunnel->i_seqno = ntohl(tpi->seq) + 1; } + /* Save offset of outer header relative to skb->head, + * because we are going to reset the network header to the inner header + * and might change skb->head. + */ + nh = skb_network_header(skb) - skb->head; + skb_set_network_header(skb, (tunnel->dev->type == ARPHRD_ETHER) ? ETH_HLEN : 0); + if (!pskb_inet_may_pull(skb)) { + DEV_STATS_INC(tunnel->dev, rx_length_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); + goto drop; + } + iph = (struct iphdr *)(skb->head + nh); + err = IP_ECN_decapsulate(iph, skb); if (unlikely(err)) { if (log_ecn_error) net_info_ratelimited("non-ECT from %pI4 with TOS=%#x\n", &iph->saddr, iph->tos); if (err > 1) { - ++tunnel->dev->stats.rx_frame_errors; - ++tunnel->dev->stats.rx_errors; + DEV_STATS_INC(tunnel->dev, rx_frame_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); goto drop; } } @@ -517,7 +545,7 @@ static int tnl_update_pmtu(struct net_device *dev, struct sk_buff *skb, struct rt6_info *rt6; __be32 daddr; - rt6 = skb_valid_dst(skb) ? (struct rt6_info *)skb_dst(skb) : + rt6 = skb_valid_dst(skb) ? dst_rt6_info(skb_dst(skb)) : NULL; daddr = md ? dst : tunnel->parms.iph.daddr; @@ -569,9 +597,14 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, tos = ipv6_get_dsfield((const struct ipv6hdr *)inner_iph); } ip_tunnel_init_flow(&fl4, proto, key->u.ipv4.dst, key->u.ipv4.src, - tunnel_id_to_key32(key->tun_id), RT_TOS(tos), - 0, skb->mark, skb_get_hash(skb)); - if (tunnel->encap.type != TUNNEL_ENCAP_NONE) + tunnel_id_to_key32(key->tun_id), + tos & INET_DSCP_MASK, tunnel->net, 0, skb->mark, + skb_get_hash(skb), key->flow_flags); + + if (!tunnel_hlen) + tunnel_hlen = ip_encap_hlen(&tun_info->encap); + + if (ip_tunnel_encap(skb, &tun_info->encap, &proto, &fl4) < 0) goto tx_error; use_cache = ip_tunnel_dst_cache_usable(skb, tun_info); @@ -580,7 +613,7 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, if (!rt) { rt = ip_route_output_key(tunnel->net, &fl4); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error; } if (use_cache) @@ -589,11 +622,11 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, } if (rt->dst.dev == dev) { ip_rt_put(rt); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } - if (key->tun_flags & TUNNEL_DONT_FRAGMENT) + if (test_bit(IP_TUNNEL_DONT_FRAGMENT_BIT, key->tun_flags)) df = htons(IP_DF); if (tnl_update_pmtu(dev, skb, rt, df, inner_iph, tunnel_hlen, key->u.ipv4.dst, true)) { @@ -613,21 +646,21 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, } headroom += LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len; - if (headroom > dev->needed_headroom) - dev->needed_headroom = headroom; - - if (skb_cow_head(skb, dev->needed_headroom)) { + if (skb_cow_head(skb, headroom)) { ip_rt_put(rt); goto tx_dropped; } + + ip_tunnel_adj_headroom(dev, headroom); + iptunnel_xmit(NULL, rt, skb, fl4.saddr, fl4.daddr, proto, tos, ttl, - df, !net_eq(tunnel->net, dev_net(dev))); + df, !net_eq(tunnel->net, dev_net(dev)), 0); return; tx_error: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); goto kfree; tx_dropped: - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); kfree: kfree_skb(skb); } @@ -641,6 +674,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, const struct iphdr *inner_iph; unsigned int max_headroom; /* The extra header space needed */ struct rtable *rt = NULL; /* Route to the other host */ + __be16 payload_protocol; bool use_cache = false; struct flowi4 fl4; bool md = false; @@ -651,6 +685,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, inner_iph = (const struct iphdr *)skb_inner_network_header(skb); connected = (tunnel->parms.iph.daddr != 0); + payload_protocol = skb_protocol(skb, true); memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); @@ -659,7 +694,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, /* NBMA tunnel */ if (!skb_dst(skb)) { - dev->stats.tx_fifo_errors++; + DEV_STATS_INC(dev, tx_fifo_errors); goto tx_error; } @@ -670,13 +705,12 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, dst = tun_info->key.u.ipv4.dst; md = true; connected = true; - } - else if (skb->protocol == htons(ETH_P_IP)) { + } else if (payload_protocol == htons(ETH_P_IP)) { rt = skb_rtable(skb); dst = rt_nexthop(rt, inner_iph->daddr); } #if IS_ENABLED(CONFIG_IPV6) - else if (skb->protocol == htons(ETH_P_IPV6)) { + else if (payload_protocol == htons(ETH_P_IPV6)) { const struct in6_addr *addr6; struct neighbour *neigh; bool do_tx_error_icmp; @@ -716,20 +750,21 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, tos = tnl_params->tos; if (tos & 0x1) { tos &= ~0x1; - if (skb->protocol == htons(ETH_P_IP)) { + if (payload_protocol == htons(ETH_P_IP)) { tos = inner_iph->tos; connected = false; - } else if (skb->protocol == htons(ETH_P_IPV6)) { + } else if (payload_protocol == htons(ETH_P_IPV6)) { tos = ipv6_get_dsfield((const struct ipv6hdr *)inner_iph); connected = false; } } ip_tunnel_init_flow(&fl4, protocol, dst, tnl_params->saddr, - tunnel->parms.o_key, RT_TOS(tos), tunnel->parms.link, - tunnel->fwmark, skb_get_hash(skb)); + tunnel->parms.o_key, tos & INET_DSCP_MASK, + tunnel->net, READ_ONCE(tunnel->parms.link), + tunnel->fwmark, skb_get_hash(skb), 0); - if (ip_tunnel_encap(skb, tunnel, &protocol, &fl4) < 0) + if (ip_tunnel_encap(skb, &tunnel->encap, &protocol, &fl4) < 0) goto tx_error; if (connected && md) { @@ -746,7 +781,7 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, rt = ip_route_output_key(tunnel->net, &fl4); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error; } if (use_cache) @@ -759,12 +794,12 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, if (rt->dst.dev == dev) { ip_rt_put(rt); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } df = tnl_params->frag_off; - if (skb->protocol == htons(ETH_P_IP) && !tunnel->ignore_df) + if (payload_protocol == htons(ETH_P_IP) && !tunnel->ignore_df) df |= (inner_iph->frag_off & htons(IP_DF)); if (tnl_update_pmtu(dev, skb, rt, df, inner_iph, 0, 0, false)) { @@ -785,10 +820,10 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, tos = ip_tunnel_ecn_encap(tos, inner_iph, skb); ttl = tnl_params->ttl; if (ttl == 0) { - if (skb->protocol == htons(ETH_P_IP)) + if (payload_protocol == htons(ETH_P_IP)) ttl = inner_iph->ttl; #if IS_ENABLED(CONFIG_IPV6) - else if (skb->protocol == htons(ETH_P_IPV6)) + else if (payload_protocol == htons(ETH_P_IPV6)) ttl = ((const struct ipv6hdr *)inner_iph)->hop_limit; #endif else @@ -797,18 +832,18 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, max_headroom = LL_RESERVED_SPACE(rt->dst.dev) + sizeof(struct iphdr) + rt->dst.header_len + ip_encap_hlen(&tunnel->encap); - if (max_headroom > dev->needed_headroom) - dev->needed_headroom = max_headroom; - if (skb_cow_head(skb, dev->needed_headroom)) { + if (skb_cow_head(skb, max_headroom)) { ip_rt_put(rt); - dev->stats.tx_dropped++; + DEV_STATS_INC(dev, tx_dropped); kfree_skb(skb); return; } + ip_tunnel_adj_headroom(dev, max_headroom); + iptunnel_xmit(NULL, rt, skb, fl4.saddr, fl4.daddr, protocol, tos, ttl, - df, !net_eq(tunnel->net, dev_net(dev))); + df, !net_eq(tunnel->net, dev_net(dev)), 0); return; #if IS_ENABLED(CONFIG_IPV6) @@ -816,7 +851,7 @@ tx_error_icmp: dst_link_failure(skb); #endif tx_error: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); } EXPORT_SYMBOL_GPL(ip_tunnel_xmit); @@ -824,7 +859,7 @@ EXPORT_SYMBOL_GPL(ip_tunnel_xmit); static void ip_tunnel_update(struct ip_tunnel_net *itn, struct ip_tunnel *t, struct net_device *dev, - struct ip_tunnel_parm *p, + struct ip_tunnel_parm_kern *p, bool set_mtu, __u32 fwmark) { @@ -846,17 +881,18 @@ static void ip_tunnel_update(struct ip_tunnel_net *itn, if (t->parms.link != p->link || t->fwmark != fwmark) { int mtu; - t->parms.link = p->link; + WRITE_ONCE(t->parms.link, p->link); t->fwmark = fwmark; mtu = ip_tunnel_bind_dev(dev); if (set_mtu) - dev->mtu = mtu; + WRITE_ONCE(dev->mtu, mtu); } dst_cache_reset(&t->dst_cache); netdev_state_change(dev); } -int ip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +int ip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm_kern *p, + int cmd) { int err = 0; struct ip_tunnel *t = netdev_priv(dev); @@ -880,10 +916,10 @@ int ip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) goto done; if (p->iph.ttl) p->iph.frag_off |= htons(IP_DF); - if (!(p->i_flags & VTI_ISVTI)) { - if (!(p->i_flags & TUNNEL_KEY)) + if (!test_bit(IP_TUNNEL_VTI_BIT, p->i_flags)) { + if (!test_bit(IP_TUNNEL_KEY_BIT, p->i_flags)) p->i_key = 0; - if (!(p->o_flags & TUNNEL_KEY)) + if (!test_bit(IP_TUNNEL_KEY_BIT, p->o_flags)) p->o_key = 0; } @@ -958,16 +994,58 @@ done: } EXPORT_SYMBOL_GPL(ip_tunnel_ctl); +bool ip_tunnel_parm_from_user(struct ip_tunnel_parm_kern *kp, + const void __user *data) +{ + struct ip_tunnel_parm p; + + if (copy_from_user(&p, data, sizeof(p))) + return false; + + strscpy(kp->name, p.name); + kp->link = p.link; + ip_tunnel_flags_from_be16(kp->i_flags, p.i_flags); + ip_tunnel_flags_from_be16(kp->o_flags, p.o_flags); + kp->i_key = p.i_key; + kp->o_key = p.o_key; + memcpy(&kp->iph, &p.iph, min(sizeof(kp->iph), sizeof(p.iph))); + + return true; +} +EXPORT_SYMBOL_GPL(ip_tunnel_parm_from_user); + +bool ip_tunnel_parm_to_user(void __user *data, struct ip_tunnel_parm_kern *kp) +{ + struct ip_tunnel_parm p; + + if (!ip_tunnel_flags_is_be16_compat(kp->i_flags) || + !ip_tunnel_flags_is_be16_compat(kp->o_flags)) + return false; + + memset(&p, 0, sizeof(p)); + + strscpy(p.name, kp->name); + p.link = kp->link; + p.i_flags = ip_tunnel_flags_to_be16(kp->i_flags); + p.o_flags = ip_tunnel_flags_to_be16(kp->o_flags); + p.i_key = kp->i_key; + p.o_key = kp->o_key; + memcpy(&p.iph, &kp->iph, min(sizeof(p.iph), sizeof(kp->iph))); + + return !copy_to_user(data, &p, sizeof(p)); +} +EXPORT_SYMBOL_GPL(ip_tunnel_parm_to_user); + int ip_tunnel_siocdevprivate(struct net_device *dev, struct ifreq *ifr, void __user *data, int cmd) { - struct ip_tunnel_parm p; + struct ip_tunnel_parm_kern p; int err; - if (copy_from_user(&p, data, sizeof(p))) + if (!ip_tunnel_parm_from_user(&p, data)) return -EFAULT; err = dev->netdev_ops->ndo_tunnel_ctl(dev, &p, cmd); - if (!err && copy_to_user(data, &p, sizeof(p))) + if (!err && !ip_tunnel_parm_to_user(data, &p)) return -EFAULT; return err; } @@ -992,7 +1070,7 @@ int __ip_tunnel_change_mtu(struct net_device *dev, int new_mtu, bool strict) new_mtu = max_mtu; } - dev->mtu = new_mtu; + WRITE_ONCE(dev->mtu, new_mtu); return 0; } EXPORT_SYMBOL_GPL(__ip_tunnel_change_mtu); @@ -1009,7 +1087,6 @@ static void ip_tunnel_dev_free(struct net_device *dev) gro_cells_destroy(&tunnel->gro_cells); dst_cache_destroy(&tunnel->dst_cache); - free_percpu(dev->tstats); } void ip_tunnel_dellink(struct net_device *dev, struct list_head *head) @@ -1030,15 +1107,15 @@ struct net *ip_tunnel_get_link_net(const struct net_device *dev) { struct ip_tunnel *tunnel = netdev_priv(dev); - return tunnel->net; + return READ_ONCE(tunnel->net); } EXPORT_SYMBOL(ip_tunnel_get_link_net); int ip_tunnel_get_iflink(const struct net_device *dev) { - struct ip_tunnel *tunnel = netdev_priv(dev); + const struct ip_tunnel *tunnel = netdev_priv(dev); - return tunnel->parms.link; + return READ_ONCE(tunnel->parms.link); } EXPORT_SYMBOL(ip_tunnel_get_iflink); @@ -1046,7 +1123,7 @@ int ip_tunnel_init_net(struct net *net, unsigned int ip_tnl_net_id, struct rtnl_link_ops *ops, char *devname) { struct ip_tunnel_net *itn = net_generic(net, ip_tnl_net_id); - struct ip_tunnel_parm parms; + struct ip_tunnel_parm_kern parms; unsigned int i; itn->rtnl_link_ops = ops; @@ -1064,7 +1141,7 @@ int ip_tunnel_init_net(struct net *net, unsigned int ip_tnl_net_id, memset(&parms, 0, sizeof(parms)); if (devname) - strlcpy(parms.name, devname, IFNAMSIZ); + strscpy(parms.name, devname, IFNAMSIZ); rtnl_lock(); itn->fb_tunnel_dev = __ip_tunnel_create(net, ops, &parms); @@ -1072,7 +1149,7 @@ int ip_tunnel_init_net(struct net *net, unsigned int ip_tnl_net_id, * Allowing to move it to another netns is clearly unsafe. */ if (!IS_ERR(itn->fb_tunnel_dev)) { - itn->fb_tunnel_dev->features |= NETIF_F_NETNS_LOCAL; + itn->fb_tunnel_dev->netns_immutable = true; itn->fb_tunnel_dev->mtu = ip_tunnel_bind_dev(itn->fb_tunnel_dev); ip_tunnel_add(itn, netdev_priv(itn->fb_tunnel_dev)); itn->type = itn->fb_tunnel_dev->type; @@ -1083,13 +1160,16 @@ int ip_tunnel_init_net(struct net *net, unsigned int ip_tnl_net_id, } EXPORT_SYMBOL_GPL(ip_tunnel_init_net); -static void ip_tunnel_destroy(struct net *net, struct ip_tunnel_net *itn, - struct list_head *head, - struct rtnl_link_ops *ops) +void ip_tunnel_delete_net(struct net *net, unsigned int id, + struct rtnl_link_ops *ops, + struct list_head *head) { + struct ip_tunnel_net *itn = net_generic(net, id); struct net_device *dev, *aux; int h; + ASSERT_RTNL_NET(net); + for_each_netdev_safe(net, dev, aux) if (dev->rtnl_link_ops == ops) unregister_netdevice_queue(dev, head); @@ -1107,29 +1187,13 @@ static void ip_tunnel_destroy(struct net *net, struct ip_tunnel_net *itn, unregister_netdevice_queue(t->dev, head); } } +EXPORT_SYMBOL_GPL(ip_tunnel_delete_net); -void ip_tunnel_delete_nets(struct list_head *net_list, unsigned int id, - struct rtnl_link_ops *ops) -{ - struct ip_tunnel_net *itn; - struct net *net; - LIST_HEAD(list); - - rtnl_lock(); - list_for_each_entry(net, net_list, exit_list) { - itn = net_generic(net, id); - ip_tunnel_destroy(net, itn, &list, ops); - } - unregister_netdevice_many(&list); - rtnl_unlock(); -} -EXPORT_SYMBOL_GPL(ip_tunnel_delete_nets); - -int ip_tunnel_newlink(struct net_device *dev, struct nlattr *tb[], - struct ip_tunnel_parm *p, __u32 fwmark) +int ip_tunnel_newlink(struct net *net, struct net_device *dev, + struct nlattr *tb[], struct ip_tunnel_parm_kern *p, + __u32 fwmark) { struct ip_tunnel *nt; - struct net *net = dev_net(dev); struct ip_tunnel_net *itn; int mtu; int err; @@ -1180,7 +1244,7 @@ err_register_netdevice: EXPORT_SYMBOL_GPL(ip_tunnel_newlink); int ip_tunnel_changelink(struct net_device *dev, struct nlattr *tb[], - struct ip_tunnel_parm *p, __u32 fwmark) + struct ip_tunnel_parm_kern *p, __u32 fwmark) { struct ip_tunnel *t; struct ip_tunnel *tunnel = netdev_priv(dev); @@ -1225,31 +1289,26 @@ int ip_tunnel_init(struct net_device *dev) dev->needs_free_netdev = true; dev->priv_destructor = ip_tunnel_dev_free; - dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); - if (!dev->tstats) - return -ENOMEM; + dev->pcpu_stat_type = NETDEV_PCPU_STAT_TSTATS; err = dst_cache_init(&tunnel->dst_cache, GFP_KERNEL); - if (err) { - free_percpu(dev->tstats); + if (err) return err; - } err = gro_cells_init(&tunnel->gro_cells, dev); if (err) { dst_cache_destroy(&tunnel->dst_cache); - free_percpu(dev->tstats); return err; } tunnel->dev = dev; - tunnel->net = dev_net(dev); - strcpy(tunnel->parms.name, dev->name); + strscpy(tunnel->parms.name, dev->name); iph->version = 4; iph->ihl = 5; if (tunnel->collect_md) netif_keep_dst(dev); + netdev_lockdep_set_classes(dev); return 0; } EXPORT_SYMBOL_GPL(ip_tunnel_init); @@ -1277,4 +1336,5 @@ void ip_tunnel_setup(struct net_device *dev, unsigned int net_id) } EXPORT_SYMBOL_GPL(ip_tunnel_setup); +MODULE_DESCRIPTION("IPv4 tunnel implementation library"); MODULE_LICENSE("GPL"); diff --git a/net/ipv4/ip_tunnel_core.c b/net/ipv4/ip_tunnel_core.c index 6b2dc7b2b612..2e61ac137128 100644 --- a/net/ipv4/ip_tunnel_core.c +++ b/net/ipv4/ip_tunnel_core.c @@ -49,7 +49,8 @@ EXPORT_SYMBOL(ip6tun_encaps); void iptunnel_xmit(struct sock *sk, struct rtable *rt, struct sk_buff *skb, __be32 src, __be32 dst, __u8 proto, - __u8 tos, __u8 ttl, __be16 df, bool xnet) + __u8 tos, __u8 ttl, __be16 df, bool xnet, + u16 ipcb_flags) { int pkt_len = skb->len - skb_inner_network_offset(skb); struct net *net = dev_net(rt->dst.dev); @@ -62,6 +63,7 @@ void iptunnel_xmit(struct sock *sk, struct rtable *rt, struct sk_buff *skb, skb_clear_hash_if_not_l4(skb); skb_dst_set(skb, &rt->dst); memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + IPCB(skb)->flags = ipcb_flags; /* Push down and install the IP header. */ skb_push(skb, sizeof(struct iphdr)); @@ -125,6 +127,7 @@ EXPORT_SYMBOL_GPL(__iptunnel_pull_header); struct metadata_dst *iptunnel_metadata_reply(struct metadata_dst *md, gfp_t flags) { + IP_TUNNEL_DECLARE_FLAGS(tun_flags) = { }; struct metadata_dst *res; struct ip_tunnel_info *dst, *src; @@ -144,10 +147,10 @@ struct metadata_dst *iptunnel_metadata_reply(struct metadata_dst *md, sizeof(struct in6_addr)); else dst->key.u.ipv4.dst = src->key.u.ipv4.src; - dst->key.tun_flags = src->key.tun_flags; + ip_tunnel_flags_copy(dst->key.tun_flags, src->key.tun_flags); dst->mode = src->mode | IP_TUNNEL_INFO_TX; ip_tunnel_info_opts_set(dst, ip_tunnel_info_opts(src), - src->options_len, 0); + src->options_len, tun_flags); return res; } @@ -203,6 +206,9 @@ static int iptunnel_pmtud_build_icmp(struct sk_buff *skb, int mtu) if (!pskb_may_pull(skb, ETH_HLEN + sizeof(struct iphdr))) return -EINVAL; + if (skb_is_gso(skb)) + skb_gso_reset(skb); + skb_copy_bits(skb, skb_mac_offset(skb), &eh, ETH_HLEN); pskb_pull(skb, ETH_HLEN); skb_reset_network_header(skb); @@ -224,7 +230,7 @@ static int iptunnel_pmtud_build_icmp(struct sk_buff *skb, int mtu) .un.frag.__unused = 0, .un.frag.mtu = htons(mtu), }; - icmph->checksum = ip_compute_csum(icmph, len); + icmph->checksum = csum_fold(skb_checksum(skb, 0, len, 0)); skb_reset_transport_header(skb); niph = skb_push(skb, sizeof(*niph)); @@ -297,6 +303,9 @@ static int iptunnel_pmtud_build_icmpv6(struct sk_buff *skb, int mtu) if (!pskb_may_pull(skb, ETH_HLEN + sizeof(struct ipv6hdr))) return -EINVAL; + if (skb_is_gso(skb)) + skb_gso_reset(skb); + skb_copy_bits(skb, skb_mac_offset(skb), &eh, ETH_HLEN); pskb_pull(skb, ETH_HLEN); skb_reset_network_header(skb); @@ -332,7 +341,7 @@ static int iptunnel_pmtud_build_icmpv6(struct sk_buff *skb, int mtu) }; skb_reset_network_header(skb); - csum = csum_partial(icmp6h, len, 0); + csum = skb_checksum(skb, skb_transport_offset(skb), len, 0); icmp6h->icmp6_cksum = csum_ipv6_magic(&nip6h->saddr, &nip6h->daddr, len, IPPROTO_ICMPV6, csum); @@ -410,12 +419,12 @@ int skb_tunnel_check_pmtu(struct sk_buff *skb, struct dst_entry *encap_dst, u32 mtu = dst_mtu(encap_dst) - headroom; if ((skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) || - (!skb_is_gso(skb) && (skb->len - skb_mac_header_len(skb)) <= mtu)) + (!skb_is_gso(skb) && (skb->len - skb_network_offset(skb)) <= mtu)) return 0; skb_dst_update_pmtu_no_confirm(skb, mtu); - if (!reply || skb->pkt_type == PACKET_HOST) + if (!reply) return 0; if (skb->protocol == htons(ETH_P_IP)) @@ -450,7 +459,7 @@ static const struct nla_policy geneve_opt_policy[LWTUNNEL_IP_OPT_GENEVE_MAX + 1] = { [LWTUNNEL_IP_OPT_GENEVE_CLASS] = { .type = NLA_U16 }, [LWTUNNEL_IP_OPT_GENEVE_TYPE] = { .type = NLA_U8 }, - [LWTUNNEL_IP_OPT_GENEVE_DATA] = { .type = NLA_BINARY, .len = 128 }, + [LWTUNNEL_IP_OPT_GENEVE_DATA] = { .type = NLA_BINARY, .len = 127 }, }; static const struct nla_policy @@ -497,7 +506,7 @@ static int ip_tun_parse_opts_geneve(struct nlattr *attr, opt->opt_class = nla_get_be16(attr); attr = tb[LWTUNNEL_IP_OPT_GENEVE_TYPE]; opt->type = nla_get_u8(attr); - info->key.tun_flags |= TUNNEL_GENEVE_OPT; + __set_bit(IP_TUNNEL_GENEVE_OPT_BIT, info->key.tun_flags); } return sizeof(struct geneve_opt) + data_len; @@ -525,7 +534,7 @@ static int ip_tun_parse_opts_vxlan(struct nlattr *attr, attr = tb[LWTUNNEL_IP_OPT_VXLAN_GBP]; md->gbp = nla_get_u32(attr); md->gbp &= VXLAN_GBP_MASK; - info->key.tun_flags |= TUNNEL_VXLAN_OPT; + __set_bit(IP_TUNNEL_VXLAN_OPT_BIT, info->key.tun_flags); } return sizeof(struct vxlan_metadata); @@ -574,7 +583,7 @@ static int ip_tun_parse_opts_erspan(struct nlattr *attr, set_hwid(&md->u.md2, nla_get_u8(attr)); } - info->key.tun_flags |= TUNNEL_ERSPAN_OPT; + __set_bit(IP_TUNNEL_ERSPAN_OPT_BIT, info->key.tun_flags); } return sizeof(struct erspan_metadata); @@ -585,7 +594,7 @@ static int ip_tun_parse_opts(struct nlattr *attr, struct ip_tunnel_info *info, { int err, rem, opt_len, opts_len = 0; struct nlattr *nla; - __be16 type = 0; + u32 type = 0; if (!attr) return 0; @@ -598,7 +607,7 @@ static int ip_tun_parse_opts(struct nlattr *attr, struct ip_tunnel_info *info, nla_for_each_attr(nla, nla_data(attr), nla_len(attr), rem) { switch (nla_type(nla)) { case LWTUNNEL_IP_OPTS_GENEVE: - if (type && type != TUNNEL_GENEVE_OPT) + if (type && type != IP_TUNNEL_GENEVE_OPT_BIT) return -EINVAL; opt_len = ip_tun_parse_opts_geneve(nla, info, opts_len, extack); @@ -607,7 +616,7 @@ static int ip_tun_parse_opts(struct nlattr *attr, struct ip_tunnel_info *info, opts_len += opt_len; if (opts_len > IP_TUNNEL_OPTS_MAX) return -EINVAL; - type = TUNNEL_GENEVE_OPT; + type = IP_TUNNEL_GENEVE_OPT_BIT; break; case LWTUNNEL_IP_OPTS_VXLAN: if (type) @@ -617,7 +626,7 @@ static int ip_tun_parse_opts(struct nlattr *attr, struct ip_tunnel_info *info, if (opt_len < 0) return opt_len; opts_len += opt_len; - type = TUNNEL_VXLAN_OPT; + type = IP_TUNNEL_VXLAN_OPT_BIT; break; case LWTUNNEL_IP_OPTS_ERSPAN: if (type) @@ -627,7 +636,7 @@ static int ip_tun_parse_opts(struct nlattr *attr, struct ip_tunnel_info *info, if (opt_len < 0) return opt_len; opts_len += opt_len; - type = TUNNEL_ERSPAN_OPT; + type = IP_TUNNEL_ERSPAN_OPT_BIT; break; default: return -EINVAL; @@ -705,10 +714,16 @@ static int ip_tun_build_state(struct net *net, struct nlattr *attr, if (tb[LWTUNNEL_IP_TOS]) tun_info->key.tos = nla_get_u8(tb[LWTUNNEL_IP_TOS]); - if (tb[LWTUNNEL_IP_FLAGS]) - tun_info->key.tun_flags |= - (nla_get_be16(tb[LWTUNNEL_IP_FLAGS]) & - ~TUNNEL_OPTIONS_PRESENT); + if (tb[LWTUNNEL_IP_FLAGS]) { + IP_TUNNEL_DECLARE_FLAGS(flags); + + ip_tunnel_flags_from_be16(flags, + nla_get_be16(tb[LWTUNNEL_IP_FLAGS])); + ip_tunnel_clear_options_present(flags); + + ip_tunnel_flags_or(tun_info->key.tun_flags, + tun_info->key.tun_flags, flags); + } tun_info->mode = IP_TUNNEL_INFO_TX; tun_info->options_len = opt_len; @@ -812,18 +827,18 @@ static int ip_tun_fill_encap_opts(struct sk_buff *skb, int type, struct nlattr *nest; int err = 0; - if (!(tun_info->key.tun_flags & TUNNEL_OPTIONS_PRESENT)) + if (!ip_tunnel_is_options_present(tun_info->key.tun_flags)) return 0; nest = nla_nest_start_noflag(skb, type); if (!nest) return -ENOMEM; - if (tun_info->key.tun_flags & TUNNEL_GENEVE_OPT) + if (test_bit(IP_TUNNEL_GENEVE_OPT_BIT, tun_info->key.tun_flags)) err = ip_tun_fill_encap_opts_geneve(skb, tun_info); - else if (tun_info->key.tun_flags & TUNNEL_VXLAN_OPT) + else if (test_bit(IP_TUNNEL_VXLAN_OPT_BIT, tun_info->key.tun_flags)) err = ip_tun_fill_encap_opts_vxlan(skb, tun_info); - else if (tun_info->key.tun_flags & TUNNEL_ERSPAN_OPT) + else if (test_bit(IP_TUNNEL_ERSPAN_OPT_BIT, tun_info->key.tun_flags)) err = ip_tun_fill_encap_opts_erspan(skb, tun_info); if (err) { @@ -846,7 +861,8 @@ static int ip_tun_fill_encap_info(struct sk_buff *skb, nla_put_in_addr(skb, LWTUNNEL_IP_SRC, tun_info->key.u.ipv4.src) || nla_put_u8(skb, LWTUNNEL_IP_TOS, tun_info->key.tos) || nla_put_u8(skb, LWTUNNEL_IP_TTL, tun_info->key.ttl) || - nla_put_be16(skb, LWTUNNEL_IP_FLAGS, tun_info->key.tun_flags) || + nla_put_be16(skb, LWTUNNEL_IP_FLAGS, + ip_tunnel_flags_to_be16(tun_info->key.tun_flags)) || ip_tun_fill_encap_opts(skb, LWTUNNEL_IP_OPTS, tun_info)) return -ENOMEM; @@ -857,11 +873,11 @@ static int ip_tun_opts_nlsize(struct ip_tunnel_info *info) { int opt_len; - if (!(info->key.tun_flags & TUNNEL_OPTIONS_PRESENT)) + if (!ip_tunnel_is_options_present(info->key.tun_flags)) return 0; opt_len = nla_total_size(0); /* LWTUNNEL_IP_OPTS */ - if (info->key.tun_flags & TUNNEL_GENEVE_OPT) { + if (test_bit(IP_TUNNEL_GENEVE_OPT_BIT, info->key.tun_flags)) { struct geneve_opt *opt; int offset = 0; @@ -874,10 +890,10 @@ static int ip_tun_opts_nlsize(struct ip_tunnel_info *info) /* OPT_GENEVE_DATA */ offset += sizeof(*opt) + opt->length * 4; } - } else if (info->key.tun_flags & TUNNEL_VXLAN_OPT) { + } else if (test_bit(IP_TUNNEL_VXLAN_OPT_BIT, info->key.tun_flags)) { opt_len += nla_total_size(0) /* LWTUNNEL_IP_OPTS_VXLAN */ + nla_total_size(4); /* OPT_VXLAN_GBP */ - } else if (info->key.tun_flags & TUNNEL_ERSPAN_OPT) { + } else if (test_bit(IP_TUNNEL_ERSPAN_OPT_BIT, info->key.tun_flags)) { struct erspan_metadata *md = ip_tunnel_info_opts(info); opt_len += nla_total_size(0) /* LWTUNNEL_IP_OPTS_ERSPAN */ @@ -984,10 +1000,17 @@ static int ip6_tun_build_state(struct net *net, struct nlattr *attr, if (tb[LWTUNNEL_IP6_TC]) tun_info->key.tos = nla_get_u8(tb[LWTUNNEL_IP6_TC]); - if (tb[LWTUNNEL_IP6_FLAGS]) - tun_info->key.tun_flags |= - (nla_get_be16(tb[LWTUNNEL_IP6_FLAGS]) & - ~TUNNEL_OPTIONS_PRESENT); + if (tb[LWTUNNEL_IP6_FLAGS]) { + IP_TUNNEL_DECLARE_FLAGS(flags); + __be16 data; + + data = nla_get_be16(tb[LWTUNNEL_IP6_FLAGS]); + ip_tunnel_flags_from_be16(flags, data); + ip_tunnel_clear_options_present(flags); + + ip_tunnel_flags_or(tun_info->key.tun_flags, + tun_info->key.tun_flags, flags); + } tun_info->mode = IP_TUNNEL_INFO_TX | IP_TUNNEL_INFO_IPV6; tun_info->options_len = opt_len; @@ -1008,7 +1031,8 @@ static int ip6_tun_fill_encap_info(struct sk_buff *skb, nla_put_in6_addr(skb, LWTUNNEL_IP6_SRC, &tun_info->key.u.ipv6.src) || nla_put_u8(skb, LWTUNNEL_IP6_TC, tun_info->key.tos) || nla_put_u8(skb, LWTUNNEL_IP6_HOPLIMIT, tun_info->key.ttl) || - nla_put_be16(skb, LWTUNNEL_IP6_FLAGS, tun_info->key.tun_flags) || + nla_put_be16(skb, LWTUNNEL_IP6_FLAGS, + ip_tunnel_flags_to_be16(tun_info->key.tun_flags)) || ip_tun_fill_encap_opts(skb, LWTUNNEL_IP6_OPTS, tun_info)) return -ENOMEM; @@ -1079,3 +1103,74 @@ EXPORT_SYMBOL(ip_tunnel_parse_protocol); const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tunnel_parse_protocol }; EXPORT_SYMBOL(ip_tunnel_header_ops); + +/* This function returns true when ENCAP attributes are present in the nl msg */ +bool ip_tunnel_netlink_encap_parms(struct nlattr *data[], + struct ip_tunnel_encap *encap) +{ + bool ret = false; + + memset(encap, 0, sizeof(*encap)); + + if (!data) + return ret; + + if (data[IFLA_IPTUN_ENCAP_TYPE]) { + ret = true; + encap->type = nla_get_u16(data[IFLA_IPTUN_ENCAP_TYPE]); + } + + if (data[IFLA_IPTUN_ENCAP_FLAGS]) { + ret = true; + encap->flags = nla_get_u16(data[IFLA_IPTUN_ENCAP_FLAGS]); + } + + if (data[IFLA_IPTUN_ENCAP_SPORT]) { + ret = true; + encap->sport = nla_get_be16(data[IFLA_IPTUN_ENCAP_SPORT]); + } + + if (data[IFLA_IPTUN_ENCAP_DPORT]) { + ret = true; + encap->dport = nla_get_be16(data[IFLA_IPTUN_ENCAP_DPORT]); + } + + return ret; +} +EXPORT_SYMBOL_GPL(ip_tunnel_netlink_encap_parms); + +void ip_tunnel_netlink_parms(struct nlattr *data[], + struct ip_tunnel_parm_kern *parms) +{ + if (data[IFLA_IPTUN_LINK]) + parms->link = nla_get_u32(data[IFLA_IPTUN_LINK]); + + if (data[IFLA_IPTUN_LOCAL]) + parms->iph.saddr = nla_get_be32(data[IFLA_IPTUN_LOCAL]); + + if (data[IFLA_IPTUN_REMOTE]) + parms->iph.daddr = nla_get_be32(data[IFLA_IPTUN_REMOTE]); + + if (data[IFLA_IPTUN_TTL]) { + parms->iph.ttl = nla_get_u8(data[IFLA_IPTUN_TTL]); + if (parms->iph.ttl) + parms->iph.frag_off = htons(IP_DF); + } + + if (data[IFLA_IPTUN_TOS]) + parms->iph.tos = nla_get_u8(data[IFLA_IPTUN_TOS]); + + if (!data[IFLA_IPTUN_PMTUDISC] || nla_get_u8(data[IFLA_IPTUN_PMTUDISC])) + parms->iph.frag_off = htons(IP_DF); + + if (data[IFLA_IPTUN_FLAGS]) { + __be16 flags; + + flags = nla_get_be16(data[IFLA_IPTUN_FLAGS]); + ip_tunnel_flags_from_be16(parms->i_flags, flags); + } + + if (data[IFLA_IPTUN_PROTO]) + parms->iph.protocol = nla_get_u8(data[IFLA_IPTUN_PROTO]); +} +EXPORT_SYMBOL_GPL(ip_tunnel_netlink_parms); diff --git a/net/ipv4/ip_vti.c b/net/ipv4/ip_vti.c index 8c2bd1d9ddce..95b6bb78fcd2 100644 --- a/net/ipv4/ip_vti.c +++ b/net/ipv4/ip_vti.c @@ -51,8 +51,11 @@ static int vti_input(struct sk_buff *skb, int nexthdr, __be32 spi, const struct iphdr *iph = ip_hdr(skb); struct net *net = dev_net(skb->dev); struct ip_tunnel_net *itn = net_generic(net, vti_net_id); + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; - tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + __set_bit(IP_TUNNEL_NO_KEY_BIT, flags); + + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, flags, iph->saddr, iph->daddr, 0); if (tunnel) { if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) @@ -107,8 +110,8 @@ static int vti_rcv_cb(struct sk_buff *skb, int err) dev = tunnel->dev; if (err) { - dev->stats.rx_errors++; - dev->stats.rx_dropped++; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_dropped); return 0; } @@ -167,7 +170,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, struct flowi *fl) { struct ip_tunnel *tunnel = netdev_priv(dev); - struct ip_tunnel_parm *parms = &tunnel->parms; + struct ip_tunnel_parm_kern *parms = &tunnel->parms; struct dst_entry *dst = skb_dst(skb); struct net_device *tdev; /* Device to other host */ int pkt_len = skb->len; @@ -183,7 +186,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, fl->u.ip4.flowi4_flags |= FLOWI_FLAG_ANYSRC; rt = __ip_route_output_key(dev_net(dev), &fl->u.ip4); if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } dst = &rt->dst; @@ -198,14 +201,14 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, if (dst->error) { dst_release(dst); dst = NULL; - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } skb_dst_set(skb, dst); break; #endif default: - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } } @@ -213,7 +216,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, dst_hold(dst); dst = xfrm_lookup_route(tunnel->net, dst, fl, NULL, 0); if (IS_ERR(dst)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); goto tx_error_icmp; } @@ -221,16 +224,16 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, goto xmit; if (!vti_state_check(dst->xfrm, parms->iph.daddr, parms->iph.saddr)) { - dev->stats.tx_carrier_errors++; + DEV_STATS_INC(dev, tx_carrier_errors); dst_release(dst); goto tx_error_icmp; } - tdev = dst->dev; + tdev = dst_dev(dst); if (tdev == dev) { dst_release(dst); - dev->stats.collisions++; + DEV_STATS_INC(dev, collisions); goto tx_error; } @@ -256,7 +259,7 @@ static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, xmit: skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(dev))); skb_dst_set(skb, dst); - skb->dev = skb_dst(skb)->dev; + skb->dev = skb_dst_dev(skb); err = dst_output(tunnel->net, skb->sk, skb); if (net_xmit_eval(err) == 0) @@ -267,7 +270,7 @@ xmit: tx_error_icmp: dst_link_failure(skb); tx_error: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return NETDEV_TX_OK; } @@ -287,12 +290,12 @@ static netdev_tx_t vti_tunnel_xmit(struct sk_buff *skb, struct net_device *dev) switch (skb->protocol) { case htons(ETH_P_IP): - xfrm_decode_session(skb, &fl, AF_INET); memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + xfrm_decode_session(dev_net(dev), skb, &fl, AF_INET); break; case htons(ETH_P_IPV6): - xfrm_decode_session(skb, &fl, AF_INET6); memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + xfrm_decode_session(dev_net(dev), skb, &fl, AF_INET6); break; default: goto tx_err; @@ -304,7 +307,7 @@ static netdev_tx_t vti_tunnel_xmit(struct sk_buff *skb, struct net_device *dev) return vti_xmit(skb, dev, &fl); tx_err: - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return NETDEV_TX_OK; } @@ -322,8 +325,11 @@ static int vti4_err(struct sk_buff *skb, u32 info) const struct iphdr *iph = (const struct iphdr *)skb->data; int protocol = iph->protocol; struct ip_tunnel_net *itn = net_generic(net, vti_net_id); + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; + + __set_bit(IP_TUNNEL_NO_KEY_BIT, flags); - tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, flags, iph->daddr, iph->saddr, 0); if (!tunnel) return -1; @@ -373,8 +379,9 @@ static int vti4_err(struct sk_buff *skb, u32 info) } static int -vti_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +vti_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm_kern *p, int cmd) { + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; int err = 0; if (cmd == SIOCADDTUNNEL || cmd == SIOCCHGTUNNEL) { @@ -383,20 +390,26 @@ vti_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) return -EINVAL; } - if (!(p->i_flags & GRE_KEY)) + if (!ip_tunnel_flags_is_be16_compat(p->i_flags) || + !ip_tunnel_flags_is_be16_compat(p->o_flags)) + return -EOVERFLOW; + + if (!(ip_tunnel_flags_to_be16(p->i_flags) & GRE_KEY)) p->i_key = 0; - if (!(p->o_flags & GRE_KEY)) + if (!(ip_tunnel_flags_to_be16(p->o_flags) & GRE_KEY)) p->o_key = 0; - p->i_flags = VTI_ISVTI; + __set_bit(IP_TUNNEL_VTI_BIT, flags); + ip_tunnel_flags_copy(p->i_flags, flags); err = ip_tunnel_ctl(dev, p, cmd); if (err) return err; if (cmd != SIOCDELTUNNEL) { - p->i_flags |= GRE_KEY; - p->o_flags |= GRE_KEY; + ip_tunnel_flags_from_be16(flags, GRE_KEY); + ip_tunnel_flags_or(p->i_flags, p->i_flags, flags); + ip_tunnel_flags_or(p->o_flags, p->o_flags, flags); } return 0; } @@ -430,7 +443,7 @@ static int vti_tunnel_init(struct net_device *dev) dev->flags = IFF_NOARP; dev->addr_len = 4; - dev->features |= NETIF_F_LLTX; + dev->lltx = true; netif_keep_dst(dev); return ip_tunnel_init(dev); @@ -510,14 +523,15 @@ static int __net_init vti_init_net(struct net *net) return 0; } -static void __net_exit vti_exit_batch_net(struct list_head *list_net) +static void __net_exit vti_exit_rtnl(struct net *net, + struct list_head *dev_to_kill) { - ip_tunnel_delete_nets(list_net, vti_net_id, &vti_link_ops); + ip_tunnel_delete_net(net, vti_net_id, &vti_link_ops, dev_to_kill); } static struct pernet_operations vti_net_ops = { .init = vti_init_net, - .exit_batch = vti_exit_batch_net, + .exit_rtnl = vti_exit_rtnl, .id = &vti_net_id, .size = sizeof(struct ip_tunnel_net), }; @@ -529,7 +543,7 @@ static int vti_tunnel_validate(struct nlattr *tb[], struct nlattr *data[], } static void vti_netlink_parms(struct nlattr *data[], - struct ip_tunnel_parm *parms, + struct ip_tunnel_parm_kern *parms, __u32 *fwmark) { memset(parms, 0, sizeof(*parms)); @@ -539,7 +553,7 @@ static void vti_netlink_parms(struct nlattr *data[], if (!data) return; - parms->i_flags = VTI_ISVTI; + __set_bit(IP_TUNNEL_VTI_BIT, parms->i_flags); if (data[IFLA_VTI_LINK]) parms->link = nla_get_u32(data[IFLA_VTI_LINK]); @@ -560,15 +574,18 @@ static void vti_netlink_parms(struct nlattr *data[], *fwmark = nla_get_u32(data[IFLA_VTI_FWMARK]); } -static int vti_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int vti_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { - struct ip_tunnel_parm parms; + struct nlattr **data = params->data; + struct ip_tunnel_parm_kern parms; + struct nlattr **tb = params->tb; __u32 fwmark = 0; vti_netlink_parms(data, &parms, &fwmark); - return ip_tunnel_newlink(dev, tb, &parms, fwmark); + return ip_tunnel_newlink(params->link_net ? : dev_net(dev), dev, tb, + &parms, fwmark); } static int vti_changelink(struct net_device *dev, struct nlattr *tb[], @@ -576,8 +593,8 @@ static int vti_changelink(struct net_device *dev, struct nlattr *tb[], struct netlink_ext_ack *extack) { struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm_kern p; __u32 fwmark = t->fwmark; - struct ip_tunnel_parm p; vti_netlink_parms(data, &p, &fwmark); return ip_tunnel_changelink(dev, tb, &p, fwmark); @@ -604,7 +621,7 @@ static size_t vti_get_size(const struct net_device *dev) static int vti_fill_info(struct sk_buff *skb, const struct net_device *dev) { struct ip_tunnel *t = netdev_priv(dev); - struct ip_tunnel_parm *p = &t->parms; + struct ip_tunnel_parm_kern *p = &t->parms; if (nla_put_u32(skb, IFLA_VTI_LINK, p->link) || nla_put_be32(skb, IFLA_VTI_IKEY, p->i_key) || @@ -721,6 +738,7 @@ static void __exit vti_fini(void) module_init(vti_init); module_exit(vti_fini); +MODULE_DESCRIPTION("Virtual (secure) IP tunneling library"); MODULE_LICENSE("GPL"); MODULE_ALIAS_RTNL_LINK("vti"); MODULE_ALIAS_NETDEV("ip_vti0"); diff --git a/net/ipv4/ipcomp.c b/net/ipv4/ipcomp.c index 366094c1ce6c..9a45aed508d1 100644 --- a/net/ipv4/ipcomp.c +++ b/net/ipv4/ipcomp.c @@ -54,6 +54,7 @@ static int ipcomp4_err(struct sk_buff *skb, u32 info) } /* We always hold one tunnel user reference to indicate a tunnel */ +static struct lock_class_key xfrm_state_lock_key; static struct xfrm_state *ipcomp_tunnel_create(struct xfrm_state *x) { struct net *net = xs_net(x); @@ -62,6 +63,7 @@ static struct xfrm_state *ipcomp_tunnel_create(struct xfrm_state *x) t = xfrm_state_alloc(net); if (!t) goto out; + lockdep_set_class(&t->lock, &xfrm_state_lock_key); t->id.proto = IPPROTO_IPIP; t->id.spi = x->props.saddr.a4; @@ -117,7 +119,8 @@ out: return err; } -static int ipcomp4_init_state(struct xfrm_state *x) +static int ipcomp4_init_state(struct xfrm_state *x, + struct netlink_ext_ack *extack) { int err = -EINVAL; @@ -129,17 +132,20 @@ static int ipcomp4_init_state(struct xfrm_state *x) x->props.header_len += sizeof(struct iphdr); break; default: + NL_SET_ERR_MSG(extack, "Unsupported XFRM mode for IPcomp"); goto out; } - err = ipcomp_init_state(x); + err = ipcomp_init_state(x, extack); if (err) goto out; if (x->props.mode == XFRM_MODE_TUNNEL) { err = ipcomp_tunnel_attach(x); - if (err) + if (err) { + NL_SET_ERR_MSG(extack, "Kernel error: failed to initialize the associated state"); goto out; + } } err = 0; diff --git a/net/ipv4/ipconfig.c b/net/ipv4/ipconfig.c index 9d41d5d5cd1e..019408d3ca2c 100644 --- a/net/ipv4/ipconfig.c +++ b/net/ipv4/ipconfig.c @@ -274,9 +274,9 @@ static int __init ic_open_devs(void) /* wait for a carrier on at least one device */ start = jiffies; - next_msg = start + msecs_to_jiffies(20000); + next_msg = start + secs_to_jiffies(20); while (time_before(jiffies, start + - msecs_to_jiffies(carrier_timeout * 1000))) { + secs_to_jiffies(carrier_timeout))) { int wait, elapsed; rtnl_lock(); @@ -295,7 +295,7 @@ static int __init ic_open_devs(void) elapsed = jiffies_to_msecs(jiffies - start); wait = (carrier_timeout * 1000 - elapsed + 500) / 1000; pr_info("Waiting up to %d more seconds for network.\n", wait); - next_msg = jiffies + msecs_to_jiffies(20000); + next_msg = jiffies + secs_to_jiffies(20); } have_carrier: @@ -665,6 +665,9 @@ static struct packet_type bootp_packet_type __initdata = { .func = ic_bootp_recv, }; +/* DHCPACK can overwrite DNS if fallback was set upon first BOOTP reply */ +static int ic_nameservers_fallback __initdata; + /* * Initialize DHCP/BOOTP extension fields in the request. */ @@ -938,7 +941,8 @@ static void __init ic_do_bootp_ext(u8 *ext) if (servers > CONF_NAMESERVERS_MAX) servers = CONF_NAMESERVERS_MAX; for (i = 0; i < servers; i++) { - if (ic_nameservers[i] == NONE) + if (ic_nameservers[i] == NONE || + ic_nameservers_fallback) memcpy(&ic_nameservers[i], ext+1+4*i, 4); } break; @@ -1158,8 +1162,10 @@ static int __init ic_bootp_recv(struct sk_buff *skb, struct net_device *dev, str ic_addrservaddr = b->iph.saddr; if (ic_gateway == NONE && b->relay_ip) ic_gateway = b->relay_ip; - if (ic_nameservers[0] == NONE) + if (ic_nameservers[0] == NONE) { ic_nameservers[0] = ic_servaddr; + ic_nameservers_fallback = 1; + } ic_got_reply = IC_BOOTP; drop_unlock: @@ -1434,6 +1440,7 @@ __be32 __init root_nfs_parse_addr(char *name) static int __init wait_for_devices(void) { int i; + bool try_init_devs = true; for (i = 0; i < DEVICE_WAIT_MAX; i++) { struct net_device *dev; @@ -1452,6 +1459,11 @@ static int __init wait_for_devices(void) rtnl_unlock(); if (found) return 0; + if (try_init_devs && + (ROOT_DEV == Root_NFS || ROOT_DEV == Root_CIFS)) { + try_init_devs = false; + wait_for_init_devices_probe(); + } ssleep(1); } return -ENODEV; @@ -1678,7 +1690,8 @@ static int __init ic_proto_name(char *name) *v = 0; if (kstrtou8(client_id, 0, dhcp_client_identifier)) pr_debug("DHCP: Invalid client identifier type\n"); - strncpy(dhcp_client_identifier + 1, v + 1, 251); + strscpy(dhcp_client_identifier + 1, v + 1, + sizeof(dhcp_client_identifier) - 1); *v = ','; } return 1; @@ -1759,15 +1772,15 @@ static int __init ip_auto_config_setup(char *addrs) case 4: if ((dp = strchr(ip, '.'))) { *dp++ = '\0'; - strlcpy(utsname()->domainname, dp, + strscpy(utsname()->domainname, dp, sizeof(utsname()->domainname)); } - strlcpy(utsname()->nodename, ip, + strscpy(utsname()->nodename, ip, sizeof(utsname()->nodename)); ic_host_name_set = 1; break; case 5: - strlcpy(user_dev_name, ip, sizeof(user_dev_name)); + strscpy(user_dev_name, ip, sizeof(user_dev_name)); break; case 6: if (ic_proto_name(ip) == 0 && @@ -1814,7 +1827,7 @@ __setup("nfsaddrs=", nfsaddrs_config_setup); static int __init vendor_class_identifier_setup(char *addrs) { - if (strlcpy(vendor_class_identifier, addrs, + if (strscpy(vendor_class_identifier, addrs, sizeof(vendor_class_identifier)) >= sizeof(vendor_class_identifier)) pr_warn("DHCP: vendorclass too long, truncated to \"%s\"\n", diff --git a/net/ipv4/ipip.c b/net/ipv4/ipip.c index 123ea63a04cb..ff95b1b9908e 100644 --- a/net/ipv4/ipip.c +++ b/net/ipv4/ipip.c @@ -130,13 +130,16 @@ static int ipip_err(struct sk_buff *skb, u32 info) struct net *net = dev_net(skb->dev); struct ip_tunnel_net *itn = net_generic(net, ipip_net_id); const struct iphdr *iph = (const struct iphdr *)skb->data; + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; const int type = icmp_hdr(skb)->type; const int code = icmp_hdr(skb)->code; struct ip_tunnel *t; int err = 0; - t = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, - iph->daddr, iph->saddr, 0); + __set_bit(IP_TUNNEL_NO_KEY_BIT, flags); + + t = ip_tunnel_lookup(itn, skb->dev->ifindex, flags, iph->daddr, + iph->saddr, 0); if (!t) { err = -ENOENT; goto out; @@ -213,13 +216,16 @@ static int ipip_tunnel_rcv(struct sk_buff *skb, u8 ipproto) { struct net *net = dev_net(skb->dev); struct ip_tunnel_net *itn = net_generic(net, ipip_net_id); + IP_TUNNEL_DECLARE_FLAGS(flags) = { }; struct metadata_dst *tun_dst = NULL; struct ip_tunnel *tunnel; const struct iphdr *iph; + __set_bit(IP_TUNNEL_NO_KEY_BIT, flags); + iph = ip_hdr(skb); - tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, - iph->saddr, iph->daddr, 0); + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, flags, iph->saddr, + iph->daddr, 0); if (tunnel) { const struct tnl_ptk_info *tpi; @@ -238,9 +244,12 @@ static int ipip_tunnel_rcv(struct sk_buff *skb, u8 ipproto) if (iptunnel_pull_header(skb, 0, tpi->proto, false)) goto drop; if (tunnel->collect_md) { - tun_dst = ip_tun_rx_dst(skb, 0, 0, 0); + ip_tunnel_flags_zero(flags); + + tun_dst = ip_tun_rx_dst(skb, flags, 0, 0); if (!tun_dst) return 0; + ip_tunnel_md_udp_encap(skb, &tun_dst->u.tun_info); } skb_reset_mac_header(skb); @@ -310,7 +319,7 @@ static netdev_tx_t ipip_tunnel_xmit(struct sk_buff *skb, tx_error: kfree_skb(skb); - dev->stats.tx_errors++; + DEV_STATS_INC(dev, tx_errors); return NETDEV_TX_OK; } @@ -329,7 +338,7 @@ static bool ipip_tunnel_ioctl_verify_protocol(u8 ipproto) } static int -ipip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +ipip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm_kern *p, int cmd) { if (cmd == SIOCADDTUNNEL || cmd == SIOCCHGTUNNEL) { if (p->iph.version != 4 || @@ -339,10 +348,35 @@ ipip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) } p->i_key = p->o_key = 0; - p->i_flags = p->o_flags = 0; + ip_tunnel_flags_zero(p->i_flags); + ip_tunnel_flags_zero(p->o_flags); return ip_tunnel_ctl(dev, p, cmd); } +static int ipip_fill_forward_path(struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ + struct ip_tunnel *tunnel = netdev_priv(ctx->dev); + const struct iphdr *tiph = &tunnel->parms.iph; + struct rtable *rt; + + rt = ip_route_output(dev_net(ctx->dev), tiph->daddr, 0, 0, 0, + RT_SCOPE_UNIVERSE); + if (IS_ERR(rt)) + return PTR_ERR(rt); + + path->type = DEV_PATH_TUN; + path->tun.src_v4.s_addr = tiph->saddr; + path->tun.dst_v4.s_addr = tiph->daddr; + path->tun.l3_proto = IPPROTO_IPIP; + path->dev = ctx->dev; + + ctx->dev = rt->dst.dev; + ip_rt_put(rt); + + return 0; +} + static const struct net_device_ops ipip_netdev_ops = { .ndo_init = ipip_tunnel_init, .ndo_uninit = ip_tunnel_uninit, @@ -352,6 +386,7 @@ static const struct net_device_ops ipip_netdev_ops = { .ndo_get_stats64 = dev_get_tstats64, .ndo_get_iflink = ip_tunnel_get_iflink, .ndo_tunnel_ctl = ipip_tunnel_ctl, + .ndo_fill_forward_path = ipip_fill_forward_path, }; #define IPIP_FEATURES (NETIF_F_SG | \ @@ -368,7 +403,7 @@ static void ipip_tunnel_setup(struct net_device *dev) dev->type = ARPHRD_TUNNEL; dev->flags = IFF_NOARP; dev->addr_len = 4; - dev->features |= NETIF_F_LLTX; + dev->lltx = true; netif_keep_dst(dev); dev->features |= IPIP_FEATURES; @@ -404,8 +439,8 @@ static int ipip_tunnel_validate(struct nlattr *tb[], struct nlattr *data[], } static void ipip_netlink_parms(struct nlattr *data[], - struct ip_tunnel_parm *parms, bool *collect_md, - __u32 *fwmark) + struct ip_tunnel_parm_kern *parms, + bool *collect_md, __u32 *fwmark) { memset(parms, 0, sizeof(*parms)); @@ -417,29 +452,7 @@ static void ipip_netlink_parms(struct nlattr *data[], if (!data) return; - if (data[IFLA_IPTUN_LINK]) - parms->link = nla_get_u32(data[IFLA_IPTUN_LINK]); - - if (data[IFLA_IPTUN_LOCAL]) - parms->iph.saddr = nla_get_in_addr(data[IFLA_IPTUN_LOCAL]); - - if (data[IFLA_IPTUN_REMOTE]) - parms->iph.daddr = nla_get_in_addr(data[IFLA_IPTUN_REMOTE]); - - if (data[IFLA_IPTUN_TTL]) { - parms->iph.ttl = nla_get_u8(data[IFLA_IPTUN_TTL]); - if (parms->iph.ttl) - parms->iph.frag_off = htons(IP_DF); - } - - if (data[IFLA_IPTUN_TOS]) - parms->iph.tos = nla_get_u8(data[IFLA_IPTUN_TOS]); - - if (data[IFLA_IPTUN_PROTO]) - parms->iph.protocol = nla_get_u8(data[IFLA_IPTUN_PROTO]); - - if (!data[IFLA_IPTUN_PMTUDISC] || nla_get_u8(data[IFLA_IPTUN_PMTUDISC])) - parms->iph.frag_off = htons(IP_DF); + ip_tunnel_netlink_parms(data, parms); if (data[IFLA_IPTUN_COLLECT_METADATA]) *collect_md = true; @@ -448,50 +461,18 @@ static void ipip_netlink_parms(struct nlattr *data[], *fwmark = nla_get_u32(data[IFLA_IPTUN_FWMARK]); } -/* This function returns true when ENCAP attributes are present in the nl msg */ -static bool ipip_netlink_encap_parms(struct nlattr *data[], - struct ip_tunnel_encap *ipencap) -{ - bool ret = false; - - memset(ipencap, 0, sizeof(*ipencap)); - - if (!data) - return ret; - - if (data[IFLA_IPTUN_ENCAP_TYPE]) { - ret = true; - ipencap->type = nla_get_u16(data[IFLA_IPTUN_ENCAP_TYPE]); - } - - if (data[IFLA_IPTUN_ENCAP_FLAGS]) { - ret = true; - ipencap->flags = nla_get_u16(data[IFLA_IPTUN_ENCAP_FLAGS]); - } - - if (data[IFLA_IPTUN_ENCAP_SPORT]) { - ret = true; - ipencap->sport = nla_get_be16(data[IFLA_IPTUN_ENCAP_SPORT]); - } - - if (data[IFLA_IPTUN_ENCAP_DPORT]) { - ret = true; - ipencap->dport = nla_get_be16(data[IFLA_IPTUN_ENCAP_DPORT]); - } - - return ret; -} - -static int ipip_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int ipip_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { struct ip_tunnel *t = netdev_priv(dev); - struct ip_tunnel_parm p; + struct nlattr **data = params->data; + struct nlattr **tb = params->tb; struct ip_tunnel_encap ipencap; + struct ip_tunnel_parm_kern p; __u32 fwmark = 0; - if (ipip_netlink_encap_parms(data, &ipencap)) { + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { int err = ip_tunnel_encap_setup(t, &ipencap); if (err < 0) @@ -499,7 +480,8 @@ static int ipip_newlink(struct net *src_net, struct net_device *dev, } ipip_netlink_parms(data, &p, &t->collect_md, &fwmark); - return ip_tunnel_newlink(dev, tb, &p, fwmark); + return ip_tunnel_newlink(params->link_net ? : dev_net(dev), dev, tb, &p, + fwmark); } static int ipip_changelink(struct net_device *dev, struct nlattr *tb[], @@ -507,12 +489,12 @@ static int ipip_changelink(struct net_device *dev, struct nlattr *tb[], struct netlink_ext_ack *extack) { struct ip_tunnel *t = netdev_priv(dev); - struct ip_tunnel_parm p; struct ip_tunnel_encap ipencap; + struct ip_tunnel_parm_kern p; bool collect_md; __u32 fwmark = t->fwmark; - if (ipip_netlink_encap_parms(data, &ipencap)) { + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { int err = ip_tunnel_encap_setup(t, &ipencap); if (err < 0) @@ -565,7 +547,7 @@ static size_t ipip_get_size(const struct net_device *dev) static int ipip_fill_info(struct sk_buff *skb, const struct net_device *dev) { struct ip_tunnel *tunnel = netdev_priv(dev); - struct ip_tunnel_parm *parm = &tunnel->parms; + struct ip_tunnel_parm_kern *parm = &tunnel->parms; if (nla_put_u32(skb, IFLA_IPTUN_LINK, parm->link) || nla_put_in_addr(skb, IFLA_IPTUN_LOCAL, parm->iph.saddr) || @@ -647,14 +629,15 @@ static int __net_init ipip_init_net(struct net *net) return ip_tunnel_init_net(net, ipip_net_id, &ipip_link_ops, "tunl0"); } -static void __net_exit ipip_exit_batch_net(struct list_head *list_net) +static void __net_exit ipip_exit_rtnl(struct net *net, + struct list_head *dev_to_kill) { - ip_tunnel_delete_nets(list_net, ipip_net_id, &ipip_link_ops); + ip_tunnel_delete_net(net, ipip_net_id, &ipip_link_ops, dev_to_kill); } static struct pernet_operations ipip_net_ops = { .init = ipip_init_net, - .exit_batch = ipip_exit_batch_net, + .exit_rtnl = ipip_exit_rtnl, .id = &ipip_net_id, .size = sizeof(struct ip_tunnel_net), }; @@ -713,6 +696,7 @@ static void __exit ipip_fini(void) module_init(ipip_init); module_exit(ipip_fini); +MODULE_DESCRIPTION("IP/IP protocol decoder library"); MODULE_LICENSE("GPL"); MODULE_ALIAS_RTNL_LINK("ipip"); MODULE_ALIAS_NETDEV("tunl0"); diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index 07274619b9ea..ca9eaee4c2ef 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -42,6 +42,7 @@ #include <linux/init.h> #include <linux/if_ether.h> #include <linux/slab.h> +#include <net/flow.h> #include <net/net_namespace.h> #include <net/ip.h> #include <net/protocol.h> @@ -62,6 +63,7 @@ #include <net/fib_rules.h> #include <linux/netconf.h> #include <net/rtnh.h> +#include <net/inet_dscp.h> #include <linux/nospec.h> @@ -77,7 +79,12 @@ struct ipmr_result { * Note that the changes are semaphored via rtnl_lock. */ -static DEFINE_RWLOCK(mrt_lock); +static DEFINE_SPINLOCK(mrt_lock); + +static struct net_device *vif_dev_read(const struct vif_device *vif) +{ + return rcu_dereference(vif->dev); +} /* Multicast router control variables */ @@ -100,11 +107,11 @@ static void ipmr_free_table(struct mr_table *mrt); static void ip_mr_forward(struct net *net, struct mr_table *mrt, struct net_device *dev, struct sk_buff *skb, struct mfc_cache *cache, int local); -static int ipmr_cache_report(struct mr_table *mrt, +static int ipmr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt, vifi_t vifi, int assert); static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc, int cmd); -static void igmpmsg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt); +static void igmpmsg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt); static void mroute_clean_tables(struct mr_table *mrt, int flags); static void ipmr_expire_process(struct timer_list *t); @@ -131,7 +138,7 @@ static struct mr_table *ipmr_mr_table_iter(struct net *net, return ret; } -static struct mr_table *ipmr_get_table(struct net *net, u32 id) +static struct mr_table *__ipmr_get_table(struct net *net, u32 id) { struct mr_table *mrt; @@ -142,6 +149,16 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id) return NULL; } +static struct mr_table *ipmr_get_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + rcu_read_lock(); + mrt = __ipmr_get_table(net, id); + rcu_read_unlock(); + return mrt; +} + static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4, struct mr_table **mrt) { @@ -183,7 +200,7 @@ static int ipmr_rule_action(struct fib_rule *rule, struct flowi *flp, arg->table = fib_rule_get_table(rule, arg); - mrt = ipmr_get_table(rule->fr_net, arg->table); + mrt = __ipmr_get_table(rule->fr_net, arg->table); if (!mrt) return -EAGAIN; res->mrt = mrt; @@ -248,7 +265,7 @@ static int __net_init ipmr_rules_init(struct net *net) goto err1; } - err = fib_default_rule_add(ops, 0x7fff, RT_TABLE_DEFAULT, 0); + err = fib_default_rule_add(ops, 0x7fff, RT_TABLE_DEFAULT); if (err < 0) goto err2; @@ -256,7 +273,9 @@ static int __net_init ipmr_rules_init(struct net *net) return 0; err2: + rtnl_lock(); ipmr_free_table(mrt); + rtnl_unlock(); err1: fib_rules_unregister(ops); return err; @@ -266,13 +285,12 @@ static void __net_exit ipmr_rules_exit(struct net *net) { struct mr_table *mrt, *next; - rtnl_lock(); + ASSERT_RTNL(); list_for_each_entry_safe(mrt, next, &net->ipv4.mr_tables, list) { list_del(&mrt->list); ipmr_free_table(mrt); } fib_rules_unregister(net->ipv4.mr_rules_ops); - rtnl_unlock(); } static int ipmr_rules_dump(struct net *net, struct notifier_block *nb, @@ -281,7 +299,7 @@ static int ipmr_rules_dump(struct net *net, struct notifier_block *nb, return fib_rules_dump(net, nb, RTNL_FAMILY_IPMR, extack); } -static unsigned int ipmr_rules_seq_read(struct net *net) +static unsigned int ipmr_rules_seq_read(const struct net *net) { return fib_rules_seq_read(net, RTNL_FAMILY_IPMR); } @@ -308,6 +326,8 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id) return net->ipv4.mrt; } +#define __ipmr_get_table ipmr_get_table + static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4, struct mr_table **mrt) { @@ -328,10 +348,9 @@ static int __net_init ipmr_rules_init(struct net *net) static void __net_exit ipmr_rules_exit(struct net *net) { - rtnl_lock(); + ASSERT_RTNL(); ipmr_free_table(net->ipv4.mrt); net->ipv4.mrt = NULL; - rtnl_unlock(); } static int ipmr_rules_dump(struct net *net, struct notifier_block *nb, @@ -340,7 +359,7 @@ static int ipmr_rules_dump(struct net *net, struct notifier_block *nb, return 0; } -static unsigned int ipmr_rules_seq_read(struct net *net) +static unsigned int ipmr_rules_seq_read(const struct net *net) { return 0; } @@ -356,7 +375,7 @@ static inline int ipmr_hash_cmp(struct rhashtable_compare_arg *arg, const void *ptr) { const struct mfc_cache_cmp_arg *cmparg = arg->key; - struct mfc_cache *c = (struct mfc_cache *)ptr; + const struct mfc_cache *c = ptr; return cmparg->mfc_mcastgrp != c->mfc_mcastgrp || cmparg->mfc_origin != c->mfc_origin; @@ -397,7 +416,7 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id) if (id != RT_TABLE_DEFAULT && id >= 1000000000) return ERR_PTR(-EINVAL); - mrt = ipmr_get_table(net, id); + mrt = __ipmr_get_table(net, id); if (mrt) return mrt; @@ -407,7 +426,11 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id) static void ipmr_free_table(struct mr_table *mrt) { - del_timer_sync(&mrt->ipmr_expire_timer); + struct net *net = read_pnet(&mrt->net); + + WARN_ON_ONCE(!mr_can_free_table(net)); + + timer_shutdown_sync(&mrt->ipmr_expire_timer); mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC | MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC); rhltable_destroy(&mrt->mfc_hash); @@ -436,7 +459,7 @@ static bool ipmr_init_vif_indev(const struct net_device *dev) static struct net_device *ipmr_new_tunnel(struct net *net, struct vifctl *v) { struct net_device *tunnel_dev, *new_dev; - struct ip_tunnel_parm p = { }; + struct ip_tunnel_parm_kern p = { }; int err; tunnel_dev = __dev_get_by_name(net, "tunl0"); @@ -501,11 +524,15 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb, struct net_device *dev) return err; } - read_lock(&mrt_lock); - dev->stats.tx_bytes += skb->len; - dev->stats.tx_packets++; - ipmr_cache_report(mrt, skb, mrt->mroute_reg_vif_num, IGMPMSG_WHOLEPKT); - read_unlock(&mrt_lock); + DEV_STATS_ADD(dev, tx_bytes, skb->len); + DEV_STATS_INC(dev, tx_packets); + rcu_read_lock(); + + /* Pairs with WRITE_ONCE() in vif_add() and vif_delete() */ + ipmr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num), + IGMPMSG_WHOLEPKT); + + rcu_read_unlock(); kfree_skb(skb); return NETDEV_TX_OK; } @@ -527,7 +554,7 @@ static void reg_vif_setup(struct net_device *dev) dev->flags = IFF_NOARP; dev->netdev_ops = ®_vif_netdev_ops; dev->needs_free_netdev = true; - dev->features |= NETIF_F_NETNS_LOCAL; + dev->netns_immutable = true; } static struct net_device *ipmr_reg_vif(struct net *net, struct mr_table *mrt) @@ -572,6 +599,7 @@ static int __pim_rcv(struct mr_table *mrt, struct sk_buff *skb, { struct net_device *reg_dev = NULL; struct iphdr *encap; + int vif_num; encap = (struct iphdr *)(skb_transport_header(skb) + pimlen); /* Check that: @@ -584,11 +612,10 @@ static int __pim_rcv(struct mr_table *mrt, struct sk_buff *skb, ntohs(encap->tot_len) + pimlen > skb->len) return 1; - read_lock(&mrt_lock); - if (mrt->mroute_reg_vif_num >= 0) - reg_dev = mrt->vif_table[mrt->mroute_reg_vif_num].dev; - read_unlock(&mrt_lock); - + /* Pairs with WRITE_ONCE() in vif_add()/vid_delete() */ + vif_num = READ_ONCE(mrt->mroute_reg_vif_num); + if (vif_num >= 0) + reg_dev = vif_dev_read(&mrt->vif_table[vif_num]); if (!reg_dev) return 1; @@ -614,10 +641,11 @@ static struct net_device *ipmr_reg_vif(struct net *net, struct mr_table *mrt) static int call_ipmr_vif_entry_notifiers(struct net *net, enum fib_event_type event_type, struct vif_device *vif, + struct net_device *vif_dev, vifi_t vif_index, u32 tb_id) { return mr_call_vif_notifiers(net, RTNL_FAMILY_IPMR, event_type, - vif, vif_index, tb_id, + vif, vif_dev, vif_index, tb_id, &net->ipv4.ipmr_seq); } @@ -649,22 +677,19 @@ static int vif_delete(struct mr_table *mrt, int vifi, int notify, v = &mrt->vif_table[vifi]; - if (VIF_EXISTS(mrt, vifi)) - call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_DEL, v, vifi, - mrt->id); - - write_lock_bh(&mrt_lock); - dev = v->dev; - v->dev = NULL; - - if (!dev) { - write_unlock_bh(&mrt_lock); + dev = rtnl_dereference(v->dev); + if (!dev) return -EADDRNOTAVAIL; - } - if (vifi == mrt->mroute_reg_vif_num) - mrt->mroute_reg_vif_num = -1; + spin_lock(&mrt_lock); + call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_DEL, v, dev, + vifi, mrt->id); + RCU_INIT_POINTER(v->dev, NULL); + if (vifi == mrt->mroute_reg_vif_num) { + /* Pairs with READ_ONCE() in ipmr_cache_report() and reg_vif_xmit() */ + WRITE_ONCE(mrt->mroute_reg_vif_num, -1); + } if (vifi + 1 == mrt->maxvif) { int tmp; @@ -672,10 +697,10 @@ static int vif_delete(struct mr_table *mrt, int vifi, int notify, if (VIF_EXISTS(mrt, tmp)) break; } - mrt->maxvif = tmp+1; + WRITE_ONCE(mrt->maxvif, tmp + 1); } - write_unlock_bh(&mrt_lock); + spin_unlock(&mrt_lock); dev_set_allmulti(dev, -1); @@ -691,7 +716,7 @@ static int vif_delete(struct mr_table *mrt, int vifi, int notify, if (v->flags & (VIFF_TUNNEL | VIFF_REGISTER) && !notify) unregister_netdevice_queue(dev, head); - dev_put_track(dev, &v->dev_tracker); + netdev_put(dev, &v->dev_tracker); return 0; } @@ -741,7 +766,7 @@ static void ipmr_destroy_unres(struct mr_table *mrt, struct mfc_cache *c) /* Timer process for the unresolved queue. */ static void ipmr_expire_process(struct timer_list *t) { - struct mr_table *mrt = from_timer(mrt, t, ipmr_expire_timer); + struct mr_table *mrt = timer_container_of(mrt, t, ipmr_expire_timer); struct mr_mfc *c, *next; unsigned long expires; unsigned long now; @@ -777,7 +802,7 @@ out: spin_unlock(&mfc_unres_lock); } -/* Fill oifs list. It is called under write locked mrt_lock. */ +/* Fill oifs list. It is called under locked mrt_lock. */ static void ipmr_update_thresholds(struct mr_table *mrt, struct mr_mfc *cache, unsigned char *ttls) { @@ -797,7 +822,7 @@ static void ipmr_update_thresholds(struct mr_table *mrt, struct mr_mfc *cache, cache->mfc_un.res.maxvif = vifi + 1; } } - cache->mfc_un.res.lastuse = jiffies; + WRITE_ONCE(cache->mfc_un.res.lastuse, jiffies); } static int vif_add(struct net *net, struct mr_table *mrt, @@ -877,7 +902,7 @@ static int vif_add(struct net *net, struct mr_table *mrt, vifc->vifc_flags | (!mrtsock ? VIFF_STATIC : 0), (VIFF_TUNNEL | VIFF_REGISTER)); - err = dev_get_port_parent_id(dev, &ppid, true); + err = netif_get_port_parent_id(dev, &ppid, true); if (err == 0) { memcpy(v->dev_parent_id.id, ppid.id, ppid.id_len); v->dev_parent_id.id_len = ppid.id_len; @@ -889,15 +914,18 @@ static int vif_add(struct net *net, struct mr_table *mrt, v->remote = vifc->vifc_rmt_addr.s_addr; /* And finish update writing critical data */ - write_lock_bh(&mrt_lock); - v->dev = dev; + spin_lock(&mrt_lock); + rcu_assign_pointer(v->dev, dev); netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC); - if (v->flags & VIFF_REGISTER) - mrt->mroute_reg_vif_num = vifi; + if (v->flags & VIFF_REGISTER) { + /* Pairs with READ_ONCE() in ipmr_cache_report() and reg_vif_xmit() */ + WRITE_ONCE(mrt->mroute_reg_vif_num, vifi); + } if (vifi+1 > mrt->maxvif) - mrt->maxvif = vifi+1; - write_unlock_bh(&mrt_lock); - call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD, v, vifi, mrt->id); + WRITE_ONCE(mrt->maxvif, vifi + 1); + spin_unlock(&mrt_lock); + call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD, v, dev, + vifi, mrt->id); return 0; } @@ -994,16 +1022,18 @@ static void ipmr_cache_resolve(struct net *net, struct mr_table *mrt, rtnl_unicast(skb, net, NETLINK_CB(skb).portid); } else { + rcu_read_lock(); ip_mr_forward(net, mrt, skb->dev, skb, c, 0); + rcu_read_unlock(); } } } /* Bounce a cache query up to mrouted and netlink. * - * Called under mrt_lock. + * Called under rcu_read_lock(). */ -static int ipmr_cache_report(struct mr_table *mrt, +static int ipmr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt, vifi_t vifi, int assert) { const int ihl = ip_hdrlen(pkt); @@ -1013,6 +1043,10 @@ static int ipmr_cache_report(struct mr_table *mrt, struct sk_buff *skb; int ret; + mroute_sk = rcu_dereference(mrt->mroute_sk); + if (!mroute_sk) + return -EINVAL; + if (assert == IGMPMSG_WHOLEPKT || assert == IGMPMSG_WRVIFWHOLE) skb = skb_realloc_headroom(pkt, sizeof(struct iphdr)); else @@ -1038,8 +1072,11 @@ static int ipmr_cache_report(struct mr_table *mrt, msg->im_vif = vifi; msg->im_vif_hi = vifi >> 8; } else { - msg->im_vif = mrt->mroute_reg_vif_num; - msg->im_vif_hi = mrt->mroute_reg_vif_num >> 8; + /* Pairs with WRITE_ONCE() in vif_add() and vif_delete() */ + int vif_num = READ_ONCE(mrt->mroute_reg_vif_num); + + msg->im_vif = vif_num; + msg->im_vif_hi = vif_num >> 8; } ip_hdr(skb)->ihl = sizeof(struct iphdr) >> 2; ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(pkt)->tot_len) + @@ -1054,7 +1091,8 @@ static int ipmr_cache_report(struct mr_table *mrt, msg = (struct igmpmsg *)skb_network_header(skb); msg->im_vif = vifi; msg->im_vif_hi = vifi >> 8; - skb_dst_set(skb, dst_clone(skb_dst(pkt))); + ipv4_pktinfo_prepare(mroute_sk, pkt, false); + memcpy(skb->cb, pkt->cb, sizeof(skb->cb)); /* Add our header */ igmp = skb_put(skb, sizeof(struct igmphdr)); igmp->type = assert; @@ -1064,19 +1102,11 @@ static int ipmr_cache_report(struct mr_table *mrt, skb->transport_header = skb->network_header; } - rcu_read_lock(); - mroute_sk = rcu_dereference(mrt->mroute_sk); - if (!mroute_sk) { - rcu_read_unlock(); - kfree_skb(skb); - return -EINVAL; - } - igmpmsg_netlink_event(mrt, skb); /* Deliver to mrouted */ ret = sock_queue_rcv_skb(mroute_sk, skb); - rcu_read_unlock(); + if (ret < 0) { net_warn_ratelimited("mroute: pending queue full, dropping entries\n"); kfree_skb(skb); @@ -1086,6 +1116,7 @@ static int ipmr_cache_report(struct mr_table *mrt, } /* Queue a packet for resolution. It gets locked cache entry! */ +/* Called under rcu_read_lock() */ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi, struct sk_buff *skb, struct net_device *dev) { @@ -1198,12 +1229,12 @@ static int ipmr_mfc_add(struct net *net, struct mr_table *mrt, mfc->mfcc_mcastgrp.s_addr, parent); rcu_read_unlock(); if (c) { - write_lock_bh(&mrt_lock); + spin_lock(&mrt_lock); c->_c.mfc_parent = mfc->mfcc_parent; ipmr_update_thresholds(mrt, &c->_c, mfc->mfcc_ttls); if (!mrtsock) c->_c.mfc_flags |= MFC_STATIC; - write_unlock_bh(&mrt_lock); + spin_unlock(&mrt_lock); call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_REPLACE, c, mrt->id); mroute_netlink_event(mrt, c, RTM_NEWROUTE); @@ -1249,7 +1280,7 @@ static int ipmr_mfc_add(struct net *net, struct mr_table *mrt, } } if (list_empty(&mrt->mfc_unres_queue)) - del_timer(&mrt->ipmr_expire_timer); + timer_delete(&mrt->ipmr_expire_timer); spin_unlock_bh(&mfc_unres_lock); if (found) { @@ -1360,7 +1391,7 @@ int ip_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval, goto out_unlock; } - mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); + mrt = __ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); if (!mrt) { ret = -ENOENT; goto out_unlock; @@ -1533,8 +1564,31 @@ out: return ret; } +/* Execute if this ioctl is a special mroute ioctl */ +int ipmr_sk_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) +{ + switch (cmd) { + /* These userspace buffers will be consumed by ipmr_ioctl() */ + case SIOCGETVIFCNT: { + struct sioc_vif_req buffer; + + return sock_ioctl_inout(sk, cmd, arg, &buffer, + sizeof(buffer)); + } + case SIOCGETSGCNT: { + struct sioc_sg_req buffer; + + return sock_ioctl_inout(sk, cmd, arg, &buffer, + sizeof(buffer)); + } + } + /* return code > 0 means that the ioctl was not executed */ + return 1; +} + /* Getsock opt support for the multicast routing system. */ -int ip_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int __user *optlen) +int ip_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval, + sockptr_t optlen) { int olr; int val; @@ -1565,26 +1619,28 @@ int ip_mroute_getsockopt(struct sock *sk, int optname, char __user *optval, int return -ENOPROTOOPT; } - if (get_user(olr, optlen)) + if (copy_from_sockptr(&olr, optlen, sizeof(int))) return -EFAULT; - olr = min_t(unsigned int, olr, sizeof(int)); if (olr < 0) return -EINVAL; - if (put_user(olr, optlen)) + + olr = min_t(unsigned int, olr, sizeof(int)); + + if (copy_to_sockptr(optlen, &olr, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &val, olr)) + if (copy_to_sockptr(optval, &val, olr)) return -EFAULT; return 0; } /* The IP multicast ioctl support routines. */ -int ipmr_ioctl(struct sock *sk, int cmd, void __user *arg) +int ipmr_ioctl(struct sock *sk, int cmd, void *arg) { - struct sioc_sg_req sr; - struct sioc_vif_req vr; struct vif_device *vif; struct mfc_cache *c; struct net *net = sock_net(sk); + struct sioc_vif_req *vr; + struct sioc_sg_req *sr; struct mr_table *mrt; mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); @@ -1593,40 +1649,33 @@ int ipmr_ioctl(struct sock *sk, int cmd, void __user *arg) switch (cmd) { case SIOCGETVIFCNT: - if (copy_from_user(&vr, arg, sizeof(vr))) - return -EFAULT; - if (vr.vifi >= mrt->maxvif) + vr = (struct sioc_vif_req *)arg; + if (vr->vifi >= mrt->maxvif) return -EINVAL; - vr.vifi = array_index_nospec(vr.vifi, mrt->maxvif); - read_lock(&mrt_lock); - vif = &mrt->vif_table[vr.vifi]; - if (VIF_EXISTS(mrt, vr.vifi)) { - vr.icount = vif->pkt_in; - vr.ocount = vif->pkt_out; - vr.ibytes = vif->bytes_in; - vr.obytes = vif->bytes_out; - read_unlock(&mrt_lock); + vr->vifi = array_index_nospec(vr->vifi, mrt->maxvif); + rcu_read_lock(); + vif = &mrt->vif_table[vr->vifi]; + if (VIF_EXISTS(mrt, vr->vifi)) { + vr->icount = READ_ONCE(vif->pkt_in); + vr->ocount = READ_ONCE(vif->pkt_out); + vr->ibytes = READ_ONCE(vif->bytes_in); + vr->obytes = READ_ONCE(vif->bytes_out); + rcu_read_unlock(); - if (copy_to_user(arg, &vr, sizeof(vr))) - return -EFAULT; return 0; } - read_unlock(&mrt_lock); + rcu_read_unlock(); return -EADDRNOTAVAIL; case SIOCGETSGCNT: - if (copy_from_user(&sr, arg, sizeof(sr))) - return -EFAULT; + sr = (struct sioc_sg_req *)arg; rcu_read_lock(); - c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr); + c = ipmr_cache_find(mrt, sr->src.s_addr, sr->grp.s_addr); if (c) { - sr.pktcnt = c->_c.mfc_un.res.pkt; - sr.bytecnt = c->_c.mfc_un.res.bytes; - sr.wrong_if = c->_c.mfc_un.res.wrong_if; + sr->pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt); + sr->bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes); + sr->wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if); rcu_read_unlock(); - - if (copy_to_user(arg, &sr, sizeof(sr))) - return -EFAULT; return 0; } rcu_read_unlock(); @@ -1673,20 +1722,20 @@ int ipmr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) if (vr.vifi >= mrt->maxvif) return -EINVAL; vr.vifi = array_index_nospec(vr.vifi, mrt->maxvif); - read_lock(&mrt_lock); + rcu_read_lock(); vif = &mrt->vif_table[vr.vifi]; if (VIF_EXISTS(mrt, vr.vifi)) { - vr.icount = vif->pkt_in; - vr.ocount = vif->pkt_out; - vr.ibytes = vif->bytes_in; - vr.obytes = vif->bytes_out; - read_unlock(&mrt_lock); + vr.icount = READ_ONCE(vif->pkt_in); + vr.ocount = READ_ONCE(vif->pkt_out); + vr.ibytes = READ_ONCE(vif->bytes_in); + vr.obytes = READ_ONCE(vif->bytes_out); + rcu_read_unlock(); if (copy_to_user(arg, &vr, sizeof(vr))) return -EFAULT; return 0; } - read_unlock(&mrt_lock); + rcu_read_unlock(); return -EADDRNOTAVAIL; case SIOCGETSGCNT: if (copy_from_user(&sr, arg, sizeof(sr))) @@ -1695,9 +1744,9 @@ int ipmr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) rcu_read_lock(); c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr); if (c) { - sr.pktcnt = c->_c.mfc_un.res.pkt; - sr.bytecnt = c->_c.mfc_un.res.bytes; - sr.wrong_if = c->_c.mfc_un.res.wrong_if; + sr.pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt); + sr.bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes); + sr.wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if); rcu_read_unlock(); if (copy_to_user(arg, &sr, sizeof(sr))) @@ -1726,7 +1775,7 @@ static int ipmr_device_event(struct notifier_block *this, unsigned long event, v ipmr_for_each_table(mrt, net) { v = &mrt->vif_table[0]; for (ct = 0; ct < mrt->maxvif; ct++, v++) { - if (v->dev == dev) + if (rcu_access_pointer(v->dev) == dev) vif_delete(mrt, ct, 1, NULL); } } @@ -1774,7 +1823,6 @@ static inline int ipmr_forward_finish(struct net *net, struct sock *sk, struct ip_options *opt = &(IPCB(skb)->opt); IP_INC_STATS(net, IPSTATS_MIB_OUTFORWDATAGRAMS); - IP_ADD_STATS(net, IPSTATS_MIB_OUTOCTETS, skb->len); if (unlikely(opt->optlen)) ip_forward_options(skb); @@ -1804,53 +1852,49 @@ static bool ipmr_forward_offloaded(struct sk_buff *skb, struct mr_table *mrt, } #endif -/* Processing handlers for ipmr_forward */ +/* Processing handlers for ipmr_forward, under rcu_read_lock() */ -static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, - int in_vifi, struct sk_buff *skb, int vifi) +static int ipmr_prepare_xmit(struct net *net, struct mr_table *mrt, + struct sk_buff *skb, int vifi) { const struct iphdr *iph = ip_hdr(skb); struct vif_device *vif = &mrt->vif_table[vifi]; - struct net_device *dev; + struct net_device *vif_dev; struct rtable *rt; struct flowi4 fl4; int encap = 0; - if (!vif->dev) - goto out_free; + vif_dev = vif_dev_read(vif); + if (!vif_dev) + return -1; if (vif->flags & VIFF_REGISTER) { - vif->pkt_out++; - vif->bytes_out += skb->len; - vif->dev->stats.tx_bytes += skb->len; - vif->dev->stats.tx_packets++; + WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); + WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); + DEV_STATS_ADD(vif_dev, tx_bytes, skb->len); + DEV_STATS_INC(vif_dev, tx_packets); ipmr_cache_report(mrt, skb, vifi, IGMPMSG_WHOLEPKT); - goto out_free; + return -1; } - if (ipmr_forward_offloaded(skb, mrt, in_vifi, vifi)) - goto out_free; - if (vif->flags & VIFF_TUNNEL) { rt = ip_route_output_ports(net, &fl4, NULL, vif->remote, vif->local, 0, 0, IPPROTO_IPIP, - RT_TOS(iph->tos), vif->link); + iph->tos & INET_DSCP_MASK, vif->link); if (IS_ERR(rt)) - goto out_free; + return -1; encap = sizeof(struct iphdr); } else { rt = ip_route_output_ports(net, &fl4, NULL, iph->daddr, 0, 0, 0, IPPROTO_IPIP, - RT_TOS(iph->tos), vif->link); + iph->tos & INET_DSCP_MASK, vif->link); if (IS_ERR(rt)) - goto out_free; + return -1; } - dev = rt->dst.dev; - if (skb->len+encap > dst_mtu(&rt->dst) && (ntohs(iph->frag_off) & IP_DF)) { /* Do not fragment multicasts. Alas, IPv4 does not * allow to send ICMP, so that packets will disappear @@ -1858,18 +1902,18 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, */ IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); ip_rt_put(rt); - goto out_free; + return -1; } - encap += LL_RESERVED_SPACE(dev) + rt->dst.header_len; + encap += LL_RESERVED_SPACE(dst_dev_rcu(&rt->dst)) + rt->dst.header_len; if (skb_cow(skb, encap)) { ip_rt_put(rt); - goto out_free; + return -1; } - vif->pkt_out++; - vif->bytes_out += skb->len; + WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); + WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); skb_dst_drop(skb); skb_dst_set(skb, &rt->dst); @@ -1881,10 +1925,26 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, if (vif->flags & VIFF_TUNNEL) { ip_encap(net, skb, vif->local, vif->remote); /* FIXME: extra output firewall step used to be here. --RR */ - vif->dev->stats.tx_packets++; - vif->dev->stats.tx_bytes += skb->len; + DEV_STATS_INC(vif_dev, tx_packets); + DEV_STATS_ADD(vif_dev, tx_bytes, skb->len); } + return 0; +} + +static void ipmr_queue_fwd_xmit(struct net *net, struct mr_table *mrt, + int in_vifi, struct sk_buff *skb, int vifi) +{ + struct rtable *rt; + + if (ipmr_forward_offloaded(skb, mrt, in_vifi, vifi)) + goto out_free; + + if (ipmr_prepare_xmit(net, mrt, skb, vifi)) + goto out_free; + + rt = skb_rtable(skb); + IPCB(skb)->flags |= IPSKB_FORWARDED; /* RFC1584 teaches, that DVMRP/PIM router must deliver packets locally @@ -1898,7 +1958,7 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, * result in receiving multiple packets. */ NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, - net, NULL, skb, skb->dev, dev, + net, NULL, skb, skb->dev, dst_dev_rcu(&rt->dst), ipmr_forward_finish); return; @@ -1906,18 +1966,33 @@ out_free: kfree_skb(skb); } -static int ipmr_find_vif(struct mr_table *mrt, struct net_device *dev) +static void ipmr_queue_output_xmit(struct net *net, struct mr_table *mrt, + struct sk_buff *skb, int vifi) { - int ct; + if (ipmr_prepare_xmit(net, mrt, skb, vifi)) + goto out_free; + + ip_mc_output(net, NULL, skb); + return; - for (ct = mrt->maxvif-1; ct >= 0; ct--) { - if (mrt->vif_table[ct].dev == dev) +out_free: + kfree_skb(skb); +} + +/* Called with mrt_lock or rcu_read_lock() */ +static int ipmr_find_vif(const struct mr_table *mrt, struct net_device *dev) +{ + int ct; + /* Pairs with WRITE_ONCE() in vif_delete()/vif_add() */ + for (ct = READ_ONCE(mrt->maxvif) - 1; ct >= 0; ct--) { + if (rcu_access_pointer(mrt->vif_table[ct].dev) == dev) break; } return ct; } /* "local" means that we should preserve one skb (for local delivery) */ +/* Called uner rcu_read_lock() */ static void ip_mr_forward(struct net *net, struct mr_table *mrt, struct net_device *dev, struct sk_buff *skb, struct mfc_cache *c, int local) @@ -1927,9 +2002,9 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt, int vif, ct; vif = c->_c.mfc_parent; - c->_c.mfc_un.res.pkt++; - c->_c.mfc_un.res.bytes += skb->len; - c->_c.mfc_un.res.lastuse = jiffies; + atomic_long_inc(&c->_c.mfc_un.res.pkt); + atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes); + WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies); if (c->mfc_origin == htonl(INADDR_ANY) && true_vifi >= 0) { struct mfc_cache *cache_proxy; @@ -1944,7 +2019,7 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt, } /* Wrong interface: drop packet and (maybe) send PIM assert. */ - if (mrt->vif_table[vif].dev != dev) { + if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) { if (rt_is_output_route(skb_rtable(skb))) { /* It is our own packet, looped back. * Very complicated situation... @@ -1960,7 +2035,7 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt, goto dont_forward; } - c->_c.mfc_un.res.wrong_if++; + atomic_long_inc(&c->_c.mfc_un.res.wrong_if); if (true_vifi >= 0 && mrt->mroute_do_assert && /* pimsm uses asserts, when switching from RPT to SPT, @@ -1983,8 +2058,10 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt, } forward: - mrt->vif_table[vif].pkt_in++; - mrt->vif_table[vif].bytes_in += skb->len; + WRITE_ONCE(mrt->vif_table[vif].pkt_in, + mrt->vif_table[vif].pkt_in + 1); + WRITE_ONCE(mrt->vif_table[vif].bytes_in, + mrt->vif_table[vif].bytes_in + skb->len); /* Forward the frame */ if (c->mfc_origin == htonl(INADDR_ANY) && @@ -2012,8 +2089,8 @@ forward: struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); if (skb2) - ipmr_queue_xmit(net, mrt, true_vifi, - skb2, psend); + ipmr_queue_fwd_xmit(net, mrt, true_vifi, + skb2, psend); } psend = ct; } @@ -2024,10 +2101,10 @@ last_forward: struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); if (skb2) - ipmr_queue_xmit(net, mrt, true_vifi, skb2, - psend); + ipmr_queue_fwd_xmit(net, mrt, true_vifi, skb2, + psend); } else { - ipmr_queue_xmit(net, mrt, true_vifi, skb, psend); + ipmr_queue_fwd_xmit(net, mrt, true_vifi, skb, psend); return; } } @@ -2044,7 +2121,7 @@ static struct mr_table *ipmr_rt_fib_lookup(struct net *net, struct sk_buff *skb) struct flowi4 fl4 = { .daddr = iph->daddr, .saddr = iph->saddr, - .flowi4_tos = RT_TOS(iph->tos), + .flowi4_dscp = ip4h_dscp(iph), .flowi4_oif = (rt_is_output_route(rt) ? skb->dev->ifindex : 0), .flowi4_iif = (rt_is_output_route(rt) ? @@ -2140,22 +2217,14 @@ int ip_mr_input(struct sk_buff *skb) skb = skb2; } - read_lock(&mrt_lock); vif = ipmr_find_vif(mrt, dev); - if (vif >= 0) { - int err2 = ipmr_cache_unresolved(mrt, vif, skb, dev); - read_unlock(&mrt_lock); - - return err2; - } - read_unlock(&mrt_lock); + if (vif >= 0) + return ipmr_cache_unresolved(mrt, vif, skb, dev); kfree_skb(skb); return -ENODEV; } - read_lock(&mrt_lock); ip_mr_forward(net, mrt, dev, skb, cache, local); - read_unlock(&mrt_lock); if (local) return ip_local_deliver(skb); @@ -2169,6 +2238,110 @@ dont_forward: return 0; } +static void ip_mr_output_finish(struct net *net, struct mr_table *mrt, + struct net_device *dev, struct sk_buff *skb, + struct mfc_cache *c) +{ + int psend = -1; + int ct; + + atomic_long_inc(&c->_c.mfc_un.res.pkt); + atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes); + WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies); + + /* Forward the frame */ + if (c->mfc_origin == htonl(INADDR_ANY) && + c->mfc_mcastgrp == htonl(INADDR_ANY)) { + if (ip_hdr(skb)->ttl > + c->_c.mfc_un.res.ttls[c->_c.mfc_parent]) { + /* It's an (*,*) entry and the packet is not coming from + * the upstream: forward the packet to the upstream + * only. + */ + psend = c->_c.mfc_parent; + goto last_xmit; + } + goto dont_xmit; + } + + for (ct = c->_c.mfc_un.res.maxvif - 1; + ct >= c->_c.mfc_un.res.minvif; ct--) { + if (ip_hdr(skb)->ttl > c->_c.mfc_un.res.ttls[ct]) { + if (psend != -1) { + struct sk_buff *skb2; + + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) + ipmr_queue_output_xmit(net, mrt, + skb2, psend); + } + psend = ct; + } + } + +last_xmit: + if (psend != -1) { + ipmr_queue_output_xmit(net, mrt, skb, psend); + return; + } + +dont_xmit: + kfree_skb(skb); +} + +/* Multicast packets for forwarding arrive here + * Called with rcu_read_lock(); + */ +int ip_mr_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct rtable *rt = skb_rtable(skb); + struct mfc_cache *cache; + struct net_device *dev; + struct mr_table *mrt; + int vif; + + guard(rcu)(); + + dev = dst_dev_rcu(&rt->dst); + + if (IPCB(skb)->flags & IPSKB_FORWARDED) + goto mc_output; + if (!(IPCB(skb)->flags & IPSKB_MCROUTE)) + goto mc_output; + + skb->dev = dev; + + mrt = ipmr_rt_fib_lookup(net, skb); + if (IS_ERR(mrt)) + goto mc_output; + + cache = ipmr_cache_find(mrt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr); + if (!cache) { + vif = ipmr_find_vif(mrt, dev); + if (vif >= 0) + cache = ipmr_cache_find_any(mrt, ip_hdr(skb)->daddr, + vif); + } + + /* No usable cache entry */ + if (!cache) { + vif = ipmr_find_vif(mrt, dev); + if (vif >= 0) + return ipmr_cache_unresolved(mrt, vif, skb, dev); + goto mc_output; + } + + vif = cache->_c.mfc_parent; + if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) + goto mc_output; + + ip_mr_output_finish(net, mrt, dev, skb, cache); + return 0; + +mc_output: + return ip_mc_output(net, sk, skb); +} + #ifdef CONFIG_IP_PIMSM_V1 /* Handle IGMP messages of PIMv1 */ int pim_rcv_v1(struct sk_buff *skb) @@ -2233,11 +2406,13 @@ int ipmr_get_route(struct net *net, struct sk_buff *skb, struct mr_table *mrt; int err; - mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); - if (!mrt) + rcu_read_lock(); + mrt = __ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) { + rcu_read_unlock(); return -ENOENT; + } - rcu_read_lock(); cache = ipmr_cache_find(mrt, saddr, daddr); if (!cache && skb->dev) { int vif = ipmr_find_vif(mrt, skb->dev); @@ -2252,18 +2427,15 @@ int ipmr_get_route(struct net *net, struct sk_buff *skb, int vif = -1; dev = skb->dev; - read_lock(&mrt_lock); if (dev) vif = ipmr_find_vif(mrt, dev); if (vif < 0) { - read_unlock(&mrt_lock); rcu_read_unlock(); return -ENODEV; } skb2 = skb_realloc_headroom(skb, sizeof(struct iphdr)); if (!skb2) { - read_unlock(&mrt_lock); rcu_read_unlock(); return -ENOMEM; } @@ -2277,14 +2449,11 @@ int ipmr_get_route(struct net *net, struct sk_buff *skb, iph->daddr = daddr; iph->version = 0; err = ipmr_cache_unresolved(mrt, vif, skb2, dev); - read_unlock(&mrt_lock); rcu_read_unlock(); return err; } - read_lock(&mrt_lock); err = mr_fill_mroute(mrt, skb, &cache->_c, rtm); - read_unlock(&mrt_lock); rcu_read_unlock(); return err; } @@ -2384,8 +2553,7 @@ static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc, errout: kfree_skb(skb); - if (err < 0) - rtnl_set_sk_err(net, RTNLGRP_IPV4_MROUTE, err); + rtnl_set_sk_err(net, RTNLGRP_IPV4_MROUTE, err); } static size_t igmpmsg_netlink_msgsize(size_t payloadlen) @@ -2404,7 +2572,7 @@ static size_t igmpmsg_netlink_msgsize(size_t payloadlen) return len; } -static void igmpmsg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt) +static void igmpmsg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt) { struct net *net = read_pnet(&mrt->net); struct nlmsghdr *nlh; @@ -2461,7 +2629,8 @@ static int ipmr_rtm_valid_getroute_req(struct sk_buff *skb, struct rtmsg *rtm; int i, err; - if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + rtm = nlmsg_payload(nlh, sizeof(*rtm)); + if (!rtm) { NL_SET_ERR_MSG(extack, "ipv4: Invalid header for multicast route get request"); return -EINVAL; } @@ -2470,7 +2639,6 @@ static int ipmr_rtm_valid_getroute_req(struct sk_buff *skb, return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_ipv4_policy, extack); - rtm = nlmsg_data(nlh); if ((rtm->rtm_src_len && rtm->rtm_src_len != 32) || (rtm->rtm_dst_len && rtm->rtm_dst_len != 32) || rtm->rtm_tos || rtm->rtm_table || rtm->rtm_protocol || @@ -2524,11 +2692,11 @@ static int ipmr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, if (err < 0) goto errout; - src = tb[RTA_SRC] ? nla_get_in_addr(tb[RTA_SRC]) : 0; - grp = tb[RTA_DST] ? nla_get_in_addr(tb[RTA_DST]) : 0; - tableid = tb[RTA_TABLE] ? nla_get_u32(tb[RTA_TABLE]) : 0; + src = nla_get_in_addr_default(tb[RTA_SRC], 0); + grp = nla_get_in_addr_default(tb[RTA_DST], 0); + tableid = nla_get_u32_default(tb[RTA_TABLE], 0); - mrt = ipmr_get_table(net, tableid ? tableid : RT_TABLE_DEFAULT); + mrt = __ipmr_get_table(net, tableid ? tableid : RT_TABLE_DEFAULT); if (!mrt) { err = -ENOENT; goto errout_free; @@ -2567,7 +2735,9 @@ errout_free: static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) { - struct fib_dump_filter filter = {}; + struct fib_dump_filter filter = { + .rtnl_held = true, + }; int err; if (cb->strict_check) { @@ -2580,7 +2750,7 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) if (filter.table_id) { struct mr_table *mrt; - mrt = ipmr_get_table(sock_net(skb->sk), filter.table_id); + mrt = __ipmr_get_table(sock_net(skb->sk), filter.table_id); if (!mrt) { if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IPMR) return skb->len; @@ -2688,7 +2858,7 @@ static int rtm_to_ipmr_mfcc(struct net *net, struct nlmsghdr *nlh, break; } } - mrt = ipmr_get_table(net, tblid); + mrt = __ipmr_get_table(net, tblid); if (!mrt) { ret = -ENOENT; goto out; @@ -2744,18 +2914,21 @@ static bool ipmr_fill_table(struct mr_table *mrt, struct sk_buff *skb) static bool ipmr_fill_vif(struct mr_table *mrt, u32 vifid, struct sk_buff *skb) { + struct net_device *vif_dev; struct nlattr *vif_nest; struct vif_device *vif; + vif = &mrt->vif_table[vifid]; + vif_dev = rtnl_dereference(vif->dev); /* if the VIF doesn't exist just continue */ - if (!VIF_EXISTS(mrt, vifid)) + if (!vif_dev) return true; - vif = &mrt->vif_table[vifid]; vif_nest = nla_nest_start_noflag(skb, IPMRA_VIF); if (!vif_nest) return false; - if (nla_put_u32(skb, IPMRA_VIFA_IFINDEX, vif->dev->ifindex) || + + if (nla_put_u32(skb, IPMRA_VIFA_IFINDEX, vif_dev->ifindex) || nla_put_u32(skb, IPMRA_VIFA_VIF_ID, vifid) || nla_put_u16(skb, IPMRA_VIFA_FLAGS, vif->flags) || nla_put_u64_64bit(skb, IPMRA_VIFA_BYTES_IN, vif->bytes_in, @@ -2781,7 +2954,8 @@ static int ipmr_valid_dumplink(const struct nlmsghdr *nlh, { struct ifinfomsg *ifm; - if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + ifm = nlmsg_payload(nlh, sizeof(*ifm)); + if (!ifm) { NL_SET_ERR_MSG(extack, "ipv4: Invalid header for ipmr link dump"); return -EINVAL; } @@ -2791,7 +2965,6 @@ static int ipmr_valid_dumplink(const struct nlmsghdr *nlh, return -EINVAL; } - ifm = nlmsg_data(nlh); if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || ifm->ifi_change || ifm->ifi_index) { NL_SET_ERR_MSG(extack, "Invalid values in header for ipmr link dump request"); @@ -2887,26 +3060,28 @@ out: */ static void *ipmr_vif_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(mrt_lock) + __acquires(RCU) { struct mr_vif_iter *iter = seq->private; struct net *net = seq_file_net(seq); struct mr_table *mrt; - mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); - if (!mrt) + rcu_read_lock(); + mrt = __ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) { + rcu_read_unlock(); return ERR_PTR(-ENOENT); + } iter->mrt = mrt; - read_lock(&mrt_lock); return mr_vif_seq_start(seq, pos); } static void ipmr_vif_seq_stop(struct seq_file *seq, void *v) - __releases(mrt_lock) + __releases(RCU) { - read_unlock(&mrt_lock); + rcu_read_unlock(); } static int ipmr_vif_seq_show(struct seq_file *seq, void *v) @@ -2919,9 +3094,11 @@ static int ipmr_vif_seq_show(struct seq_file *seq, void *v) "Interface BytesIn PktsIn BytesOut PktsOut Flags Local Remote\n"); } else { const struct vif_device *vif = v; - const char *name = vif->dev ? - vif->dev->name : "none"; + const struct net_device *vif_dev; + const char *name; + vif_dev = vif_dev_read(vif); + name = vif_dev ? vif_dev->name : "none"; seq_printf(seq, "%2td %-10s %8ld %7ld %8ld %7ld %05X %08X %08X\n", vif - mrt->vif_table, @@ -2970,9 +3147,9 @@ static int ipmr_mfc_seq_show(struct seq_file *seq, void *v) if (it->cache != &mrt->mfc_unres_queue) { seq_printf(seq, " %8lu %8lu %8lu", - mfc->_c.mfc_un.res.pkt, - mfc->_c.mfc_un.res.bytes, - mfc->_c.mfc_un.res.wrong_if); + atomic_long_read(&mfc->_c.mfc_un.res.pkt), + atomic_long_read(&mfc->_c.mfc_un.res.bytes), + atomic_long_read(&mfc->_c.mfc_un.res.wrong_if)); for (n = mfc->_c.mfc_un.res.minvif; n < mfc->_c.mfc_un.res.maxvif; n++) { if (VIF_EXISTS(mrt, n) && @@ -3006,18 +3183,16 @@ static const struct net_protocol pim_protocol = { }; #endif -static unsigned int ipmr_seq_read(struct net *net) +static unsigned int ipmr_seq_read(const struct net *net) { - ASSERT_RTNL(); - - return net->ipv4.ipmr_seq + ipmr_rules_seq_read(net); + return READ_ONCE(net->ipv4.ipmr_seq) + ipmr_rules_seq_read(net); } static int ipmr_dump(struct net *net, struct notifier_block *nb, struct netlink_ext_ack *extack) { return mr_dump(net, nb, RTNL_FAMILY_IPMR, ipmr_rules_dump, - ipmr_mr_table_iter, &mrt_lock, extack); + ipmr_mr_table_iter, extack); } static const struct fib_notifier_ops ipmr_notifier_ops_template = { @@ -3075,7 +3250,9 @@ static int __net_init ipmr_net_init(struct net *net) proc_cache_fail: remove_proc_entry("ip_mr_vif", net->proc_net); proc_vif_fail: + rtnl_lock(); ipmr_rules_exit(net); + rtnl_unlock(); #endif ipmr_rules_fail: ipmr_notifier_exit(net); @@ -3090,22 +3267,40 @@ static void __net_exit ipmr_net_exit(struct net *net) remove_proc_entry("ip_mr_vif", net->proc_net); #endif ipmr_notifier_exit(net); - ipmr_rules_exit(net); +} + +static void __net_exit ipmr_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + ipmr_rules_exit(net); + rtnl_unlock(); } static struct pernet_operations ipmr_net_ops = { .init = ipmr_net_init, .exit = ipmr_net_exit, + .exit_batch = ipmr_net_exit_batch, +}; + +static const struct rtnl_msg_handler ipmr_rtnl_msg_handlers[] __initconst = { + {.protocol = RTNL_FAMILY_IPMR, .msgtype = RTM_GETLINK, + .dumpit = ipmr_rtm_dumplink}, + {.protocol = RTNL_FAMILY_IPMR, .msgtype = RTM_NEWROUTE, + .doit = ipmr_rtm_route}, + {.protocol = RTNL_FAMILY_IPMR, .msgtype = RTM_DELROUTE, + .doit = ipmr_rtm_route}, + {.protocol = RTNL_FAMILY_IPMR, .msgtype = RTM_GETROUTE, + .doit = ipmr_rtm_getroute, .dumpit = ipmr_rtm_dumproute}, }; int __init ip_mr_init(void) { int err; - mrt_cachep = kmem_cache_create("ip_mrt_cache", - sizeof(struct mfc_cache), - 0, SLAB_HWCACHE_ALIGN | SLAB_PANIC, - NULL); + mrt_cachep = KMEM_CACHE(mfc_cache, SLAB_HWCACHE_ALIGN | SLAB_PANIC); err = register_pernet_subsys(&ipmr_net_ops); if (err) @@ -3121,15 +3316,8 @@ int __init ip_mr_init(void) goto add_proto_fail; } #endif - rtnl_register(RTNL_FAMILY_IPMR, RTM_GETROUTE, - ipmr_rtm_getroute, ipmr_rtm_dumproute, 0); - rtnl_register(RTNL_FAMILY_IPMR, RTM_NEWROUTE, - ipmr_rtm_route, NULL, 0); - rtnl_register(RTNL_FAMILY_IPMR, RTM_DELROUTE, - ipmr_rtm_route, NULL, 0); - - rtnl_register(RTNL_FAMILY_IPMR, RTM_GETLINK, - NULL, ipmr_rtm_dumplink, 0); + rtnl_register_many(ipmr_rtnl_msg_handlers); + return 0; #ifdef CONFIG_IP_PIMSM_V2 diff --git a/net/ipv4/ipmr_base.c b/net/ipv4/ipmr_base.c index aa8738a91210..28d77d454d44 100644 --- a/net/ipv4/ipmr_base.c +++ b/net/ipv4/ipmr_base.c @@ -13,7 +13,7 @@ void vif_device_init(struct vif_device *v, unsigned short flags, unsigned short get_iflink_mask) { - v->dev = NULL; + RCU_INIT_POINTER(v->dev, NULL); v->bytes_in = 0; v->bytes_out = 0; v->pkt_in = 0; @@ -208,6 +208,7 @@ EXPORT_SYMBOL(mr_mfc_seq_next); int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, struct mr_mfc *c, struct rtmsg *rtm) { + struct net_device *vif_dev; struct rta_mfc_stats mfcs; struct nlattr *mp_attr; struct rtnexthop *nhp; @@ -220,10 +221,13 @@ int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, return -ENOENT; } - if (VIF_EXISTS(mrt, c->mfc_parent) && - nla_put_u32(skb, RTA_IIF, - mrt->vif_table[c->mfc_parent].dev->ifindex) < 0) + rcu_read_lock(); + vif_dev = rcu_dereference(mrt->vif_table[c->mfc_parent].dev); + if (vif_dev && nla_put_u32(skb, RTA_IIF, vif_dev->ifindex) < 0) { + rcu_read_unlock(); return -EMSGSIZE; + } + rcu_read_unlock(); if (c->mfc_flags & MFC_OFFLOAD) rtm->rtm_flags |= RTNH_F_OFFLOAD; @@ -232,32 +236,36 @@ int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, if (!mp_attr) return -EMSGSIZE; + rcu_read_lock(); for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) { - if (VIF_EXISTS(mrt, ct) && c->mfc_un.res.ttls[ct] < 255) { - struct vif_device *vif; + struct vif_device *vif = &mrt->vif_table[ct]; + + vif_dev = rcu_dereference(vif->dev); + if (vif_dev && c->mfc_un.res.ttls[ct] < 255) { nhp = nla_reserve_nohdr(skb, sizeof(*nhp)); if (!nhp) { + rcu_read_unlock(); nla_nest_cancel(skb, mp_attr); return -EMSGSIZE; } nhp->rtnh_flags = 0; nhp->rtnh_hops = c->mfc_un.res.ttls[ct]; - vif = &mrt->vif_table[ct]; - nhp->rtnh_ifindex = vif->dev->ifindex; + nhp->rtnh_ifindex = vif_dev->ifindex; nhp->rtnh_len = sizeof(*nhp); } } + rcu_read_unlock(); nla_nest_end(skb, mp_attr); lastuse = READ_ONCE(c->mfc_un.res.lastuse); lastuse = time_after_eq(jiffies, lastuse) ? jiffies - lastuse : 0; - mfcs.mfcs_packets = c->mfc_un.res.pkt; - mfcs.mfcs_bytes = c->mfc_un.res.bytes; - mfcs.mfcs_wrong_if = c->mfc_un.res.wrong_if; + mfcs.mfcs_packets = atomic_long_read(&c->mfc_un.res.pkt); + mfcs.mfcs_bytes = atomic_long_read(&c->mfc_un.res.bytes); + mfcs.mfcs_wrong_if = atomic_long_read(&c->mfc_un.res.wrong_if); if (nla_put_64bit(skb, RTA_MFC_STATS, sizeof(mfcs), &mfcs, RTA_PAD) || nla_put_u64_64bit(skb, RTA_EXPIRES, jiffies_to_clock_t(lastuse), RTA_PAD)) @@ -275,13 +283,14 @@ static bool mr_mfc_uses_dev(const struct mr_table *mrt, int ct; for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) { - if (VIF_EXISTS(mrt, ct) && c->mfc_un.res.ttls[ct] < 255) { - const struct vif_device *vif; - - vif = &mrt->vif_table[ct]; - if (vif->dev == dev) - return true; - } + const struct net_device *vif_dev; + const struct vif_device *vif; + + vif = &mrt->vif_table[ct]; + vif_dev = rcu_access_pointer(vif->dev); + if (vif_dev && c->mfc_un.res.ttls[ct] < 255 && + vif_dev == dev) + return true; } return false; } @@ -301,7 +310,8 @@ int mr_table_dump(struct mr_table *mrt, struct sk_buff *skb, if (filter->filter_set) flags |= NLM_F_DUMP_FILTERED; - list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) { + list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list, + lockdep_rtnl_is_held()) { if (e < s_e) goto next_entry; if (filter->dev && @@ -320,9 +330,6 @@ next_entry: list_for_each_entry(mfc, &mrt->mfc_unres_queue, list) { if (e < s_e) goto next_entry2; - if (filter->dev && - !mr_mfc_uses_dev(mrt, mfc, filter->dev)) - goto next_entry2; err = fill(mrt, skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags); @@ -390,7 +397,6 @@ int mr_dump(struct net *net, struct notifier_block *nb, unsigned short family, struct netlink_ext_ack *extack), struct mr_table *(*mr_iter)(struct net *net, struct mr_table *mrt), - rwlock_t *mrt_lock, struct netlink_ext_ack *extack) { struct mr_table *mrt; @@ -402,22 +408,25 @@ int mr_dump(struct net *net, struct notifier_block *nb, unsigned short family, for (mrt = mr_iter(net, NULL); mrt; mrt = mr_iter(net, mrt)) { struct vif_device *v = &mrt->vif_table[0]; + struct net_device *vif_dev; struct mr_mfc *mfc; int vifi; /* Notifiy on table VIF entries */ - read_lock(mrt_lock); + rcu_read_lock(); for (vifi = 0; vifi < mrt->maxvif; vifi++, v++) { - if (!v->dev) + vif_dev = rcu_dereference(v->dev); + if (!vif_dev) continue; err = mr_call_vif_notifier(nb, family, - FIB_EVENT_VIF_ADD, - v, vifi, mrt->id, extack); + FIB_EVENT_VIF_ADD, v, + vif_dev, vifi, + mrt->id, extack); if (err) break; } - read_unlock(mrt_lock); + rcu_read_unlock(); if (err) return err; diff --git a/net/ipv4/metrics.c b/net/ipv4/metrics.c index 25ea6ac44db9..8ddac1f595ed 100644 --- a/net/ipv4/metrics.c +++ b/net/ipv4/metrics.c @@ -1,12 +1,13 @@ // SPDX-License-Identifier: GPL-2.0-only #include <linux/netlink.h> +#include <linux/nospec.h> #include <linux/rtnetlink.h> #include <linux/types.h> #include <net/ip.h> #include <net/net_namespace.h> #include <net/tcp.h> -static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, +static int ip_metrics_convert(struct nlattr *fc_mx, int fc_mx_len, u32 *metrics, struct netlink_ext_ack *extack) { @@ -14,9 +15,6 @@ static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, struct nlattr *nla; int remaining; - if (!fc_mx) - return 0; - nla_for_each_attr(nla, fc_mx, fc_mx_len, remaining) { int type = nla_type(nla); u32 val; @@ -28,11 +26,12 @@ static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, return -EINVAL; } + type = array_index_nospec(type, RTAX_MAX + 1); if (type == RTAX_CC_ALGO) { char tmp[TCP_CA_NAME_MAX]; nla_strscpy(tmp, nla, sizeof(tmp)); - val = tcp_ca_get_key_by_name(net, tmp, &ecn_ca); + val = tcp_ca_get_key_by_name(tmp, &ecn_ca); if (val == TCP_CA_UNSPEC) { NL_SET_ERR_MSG(extack, "Unknown tcp congestion algorithm"); return -EINVAL; @@ -64,7 +63,7 @@ static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, return 0; } -struct dst_metrics *ip_fib_metrics_init(struct net *net, struct nlattr *fc_mx, +struct dst_metrics *ip_fib_metrics_init(struct nlattr *fc_mx, int fc_mx_len, struct netlink_ext_ack *extack) { @@ -78,7 +77,7 @@ struct dst_metrics *ip_fib_metrics_init(struct net *net, struct nlattr *fc_mx, if (unlikely(!fib_metrics)) return ERR_PTR(-ENOMEM); - err = ip_metrics_convert(net, fc_mx, fc_mx_len, fib_metrics->metrics, + err = ip_metrics_convert(fc_mx, fc_mx_len, fib_metrics->metrics, extack); if (!err) { refcount_set(&fib_metrics->refcnt, 1); diff --git a/net/ipv4/netfilter.c b/net/ipv4/netfilter.c index aff707988e23..ce310eb779e0 100644 --- a/net/ipv4/netfilter.c +++ b/net/ipv4/netfilter.c @@ -11,6 +11,7 @@ #include <linux/skbuff.h> #include <linux/gfp.h> #include <linux/export.h> +#include <net/flow.h> #include <net/route.h> #include <net/xfrm.h> #include <net/ip.h> @@ -19,12 +20,12 @@ /* route_me_harder function, used by iptable_nat, iptable_mangle + ip_queue */ int ip_route_me_harder(struct net *net, struct sock *sk, struct sk_buff *skb, unsigned int addr_type) { + struct net_device *dev = skb_dst_dev(skb); const struct iphdr *iph = ip_hdr(skb); struct rtable *rt; struct flowi4 fl4 = {}; __be32 saddr = iph->saddr; __u8 flags; - struct net_device *dev = skb_dst(skb)->dev; struct flow_keys flkeys; unsigned int hh_len; @@ -43,10 +44,9 @@ int ip_route_me_harder(struct net *net, struct sock *sk, struct sk_buff *skb, un */ fl4.daddr = iph->daddr; fl4.saddr = saddr; - fl4.flowi4_tos = RT_TOS(iph->tos); + fl4.flowi4_dscp = ip4h_dscp(iph); fl4.flowi4_oif = sk ? sk->sk_bound_dev_if : 0; - if (!fl4.flowi4_oif) - fl4.flowi4_oif = l3mdev_master_ifindex(dev); + fl4.flowi4_l3mdev = l3mdev_master_ifindex(dev); fl4.flowi4_mark = skb->mark; fl4.flowi4_flags = flags; fib4_rules_early_flow_dissect(net, skb, &fl4, &flkeys); @@ -63,9 +63,12 @@ int ip_route_me_harder(struct net *net, struct sock *sk, struct sk_buff *skb, un #ifdef CONFIG_XFRM if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) && - xfrm_decode_session(skb, flowi4_to_flowi(&fl4), AF_INET) == 0) { + xfrm_decode_session(net, skb, flowi4_to_flowi(&fl4), AF_INET) == 0) { struct dst_entry *dst = skb_dst(skb); - skb_dst_set(skb, NULL); + /* ignore return value from skb_dstref_steal, xfrm_lookup takes + * care of dropping the refcnt if needed. + */ + skb_dstref_steal(skb); dst = xfrm_lookup(net, dst, flowi4_to_flowi(&fl4), sk, 0); if (IS_ERR(dst)) return PTR_ERR(dst); @@ -74,7 +77,7 @@ int ip_route_me_harder(struct net *net, struct sock *sk, struct sk_buff *skb, un #endif /* Change in oif may mean change in hh_len. */ - hh_len = skb_dst(skb)->dev->hard_header_len; + hh_len = skb_dst_dev(skb)->hard_header_len; if (skb_headroom(skb) < hh_len && pskb_expand_head(skb, HH_DATA_ALIGN(hh_len - skb_headroom(skb)), 0, GFP_ATOMIC)) diff --git a/net/ipv4/netfilter/Kconfig b/net/ipv4/netfilter/Kconfig index 67087f95579f..7dc9772fe2d8 100644 --- a/net/ipv4/netfilter/Kconfig +++ b/net/ipv4/netfilter/Kconfig @@ -10,6 +10,17 @@ config NF_DEFRAG_IPV4 tristate default n +# old sockopt interface and eval loop +config IP_NF_IPTABLES_LEGACY + tristate "Legacy IP tables support" + depends on NETFILTER_XTABLES_LEGACY + depends on NETFILTER_XTABLES + default m if NETFILTER_XTABLES_LEGACY + help + iptables is a legacy packet classifier. + This is not needed if you are using iptables over nftables + (iptables-nft). + config NF_SOCKET_IPV4 tristate "IPv4 socket lookup support" help @@ -58,10 +69,6 @@ config NF_TABLES_ARP endif # NF_TABLES -config NF_FLOW_TABLE_IPV4 - tristate - select NF_FLOW_TABLE_INET - config NF_DUP_IPV4 tristate "Netfilter IPv4 packet duplication to alternate destination" depends on !NF_CONNTRACK || NF_CONNTRACK @@ -156,7 +163,7 @@ config IP_NF_MATCH_ECN config IP_NF_MATCH_RPFILTER tristate '"rpfilter" reverse path filter match support' depends on NETFILTER_ADVANCED - depends on IP_NF_MANGLE || IP_NF_RAW + depends on IP_NF_MANGLE || IP_NF_RAW || NFT_COMPAT help This option allows you to match packets whose replies would go out via the interface the packet came in. @@ -176,7 +183,8 @@ config IP_NF_MATCH_TTL # `filter', generic and specific targets config IP_NF_FILTER tristate "Packet filtering" - default m if NETFILTER_ADVANCED=n + default m if NETFILTER_ADVANCED=n || IP_NF_IPTABLES_LEGACY + depends on IP_NF_IPTABLES_LEGACY help Packet filtering defines a table `filter', which has a series of rules for simple packet filtering at local input, forwarding and @@ -186,7 +194,7 @@ config IP_NF_FILTER config IP_NF_TARGET_REJECT tristate "REJECT target support" - depends on IP_NF_FILTER + depends on IP_NF_FILTER || NFT_COMPAT select NF_REJECT_IPV4 default m if NETFILTER_ADVANCED=n help @@ -213,6 +221,7 @@ config IP_NF_TARGET_SYNPROXY config IP_NF_NAT tristate "iptables NAT support" depends on NF_CONNTRACK + depends on IP_NF_IPTABLES_LEGACY default m if NETFILTER_ADVANCED=n select NF_NAT select NETFILTER_XT_NAT @@ -255,7 +264,8 @@ endif # IP_NF_NAT # mangle + specific targets config IP_NF_MANGLE tristate "Packet mangling" - default m if NETFILTER_ADVANCED=n + default m if NETFILTER_ADVANCED=n || IP_NF_IPTABLES_LEGACY + depends on IP_NF_IPTABLES_LEGACY help This option adds a `mangle' table to iptables: see the man page for iptables(8). This table is used for various packet alterations @@ -263,23 +273,9 @@ config IP_NF_MANGLE To compile it as a module, choose M here. If unsure, say N. -config IP_NF_TARGET_CLUSTERIP - tristate "CLUSTERIP target support" - depends on IP_NF_MANGLE - depends on NF_CONNTRACK - depends on NETFILTER_ADVANCED - select NF_CONNTRACK_MARK - select NETFILTER_FAMILY_ARP - help - The CLUSTERIP target allows you to build load-balancing clusters of - network servers without having a dedicated load-balancing - router/server/switch. - - To compile it as a module, choose M here. If unsure, say N. - config IP_NF_TARGET_ECN tristate "ECN target support" - depends on IP_NF_MANGLE + depends on IP_NF_MANGLE || NFT_COMPAT depends on NETFILTER_ADVANCED help This option adds a `ECN' target, which can be used in the iptables mangle @@ -304,6 +300,7 @@ config IP_NF_TARGET_TTL # raw + specific targets config IP_NF_RAW tristate 'raw table support (required for NOTRACK/TRACE)' + depends on IP_NF_IPTABLES_LEGACY help This option adds a `raw' table to iptables. This table is the very first in the netfilter framework and hooks in at the PREROUTING @@ -317,6 +314,7 @@ config IP_NF_SECURITY tristate "Security table" depends on SECURITY depends on NETFILTER_ADVANCED + depends on IP_NF_IPTABLES_LEGACY help This option adds a `security' table to iptables, for use with Mandatory Access Control (MAC) policy. @@ -327,36 +325,44 @@ endif # IP_NF_IPTABLES # ARP tables config IP_NF_ARPTABLES - tristate "ARP tables support" - select NETFILTER_XTABLES - select NETFILTER_FAMILY_ARP - depends on NETFILTER_ADVANCED + tristate "Legacy ARPTABLES support" + depends on NETFILTER_XTABLES_LEGACY + depends on NETFILTER_XTABLES + default n help - arptables is a general, extensible packet identification framework. - The ARP packet filtering and mangling (manipulation)subsystems - use this: say Y or M here if you want to use either of those. - - To compile it as a module, choose M here. If unsure, say N. + arptables is a legacy packet classifier. + This is not needed if you are using arptables over nftables + (iptables-nft). -if IP_NF_ARPTABLES +config NFT_COMPAT_ARP + tristate + depends on NF_TABLES_ARP && NFT_COMPAT + default m if NFT_COMPAT=m + default y if NFT_COMPAT=y config IP_NF_ARPFILTER - tristate "ARP packet filtering" + tristate "arptables-legacy packet filtering support" + select IP_NF_ARPTABLES + select NETFILTER_FAMILY_ARP + depends on NETFILTER_XTABLES_LEGACY + depends on NETFILTER_XTABLES help ARP packet filtering defines a table `filter', which has a series of rules for simple ARP packet filtering at local input and - local output. On a bridge, you can also specify filtering rules - for forwarded ARP packets. See the man page for arptables(8). + local output. This is only needed for arptables-legacy(8). + Neither arptables-nft nor nftables need this to work. To compile it as a module, choose M here. If unsure, say N. config IP_NF_ARP_MANGLE tristate "ARP payload mangling" + depends on IP_NF_ARPTABLES || NFT_COMPAT_ARP help Allows altering the ARP packet payload: source and destination hardware and network addresses. -endif # IP_NF_ARPTABLES + This option is needed by both arptables-legacy and arptables-nft. + It is not used by nftables. endmenu diff --git a/net/ipv4/netfilter/Makefile b/net/ipv4/netfilter/Makefile index 93bad1184251..85502d4dfbb4 100644 --- a/net/ipv4/netfilter/Makefile +++ b/net/ipv4/netfilter/Makefile @@ -25,7 +25,7 @@ obj-$(CONFIG_NFT_FIB_IPV4) += nft_fib_ipv4.o obj-$(CONFIG_NFT_DUP_IPV4) += nft_dup_ipv4.o # generic IP tables -obj-$(CONFIG_IP_NF_IPTABLES) += ip_tables.o +obj-$(CONFIG_IP_NF_IPTABLES_LEGACY) += ip_tables.o # the three instances of ip_tables obj-$(CONFIG_IP_NF_FILTER) += iptable_filter.o @@ -39,7 +39,6 @@ obj-$(CONFIG_IP_NF_MATCH_AH) += ipt_ah.o obj-$(CONFIG_IP_NF_MATCH_RPFILTER) += ipt_rpfilter.o # targets -obj-$(CONFIG_IP_NF_TARGET_CLUSTERIP) += ipt_CLUSTERIP.o obj-$(CONFIG_IP_NF_TARGET_ECN) += ipt_ECN.o obj-$(CONFIG_IP_NF_TARGET_REJECT) += ipt_REJECT.o obj-$(CONFIG_IP_NF_TARGET_SYNPROXY) += ipt_SYNPROXY.o diff --git a/net/ipv4/netfilter/arp_tables.c b/net/ipv4/netfilter/arp_tables.c index ffc0cab7cf18..1cdd9c28ab2d 100644 --- a/net/ipv4/netfilter/arp_tables.c +++ b/net/ipv4/netfilter/arp_tables.c @@ -826,7 +826,7 @@ static int get_info(struct net *net, void __user *user, const int *len) sizeof(info.underflow)); info.num_entries = private->number; info.size = private->size; - strcpy(info.name, name); + strscpy(info.name, name); if (copy_to_user(user, &info, *len) != 0) ret = -EFAULT; @@ -956,6 +956,8 @@ static int do_replace(struct net *net, sockptr_t arg, unsigned int len) void *loc_cpu_entry; struct arpt_entry *iter; + if (len < sizeof(tmp)) + return -EINVAL; if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) return -EFAULT; @@ -964,6 +966,8 @@ static int do_replace(struct net *net, sockptr_t arg, unsigned int len) return -ENOMEM; if (tmp.num_counters == 0) return -EINVAL; + if ((u64)len < (u64)tmp.size + sizeof(tmp)) + return -EINVAL; tmp.name[sizeof(tmp.name)-1] = 0; @@ -1254,6 +1258,8 @@ static int compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) void *loc_cpu_entry; struct arpt_entry *iter; + if (len < sizeof(tmp)) + return -EINVAL; if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) return -EFAULT; @@ -1262,6 +1268,8 @@ static int compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) return -ENOMEM; if (tmp.num_counters == 0) return -EINVAL; + if ((u64)len < (u64)tmp.size + sizeof(tmp)) + return -EINVAL; tmp.name[sizeof(tmp.name)-1] = 0; @@ -1525,6 +1533,10 @@ int arpt_register_table(struct net *net, new_table = xt_register_table(net, table, &bootstrap, newinfo); if (IS_ERR(new_table)) { + struct arpt_entry *iter; + + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); xt_free_table_info(newinfo); return PTR_ERR(new_table); } @@ -1535,7 +1547,7 @@ int arpt_register_table(struct net *net, goto out_free; } - ops = kmemdup(template_ops, sizeof(*ops) * num_ops, GFP_KERNEL); + ops = kmemdup_array(template_ops, num_ops, sizeof(*ops), GFP_KERNEL); if (!ops) { ret = -ENOMEM; goto out_free; diff --git a/net/ipv4/netfilter/ip_tables.c b/net/ipv4/netfilter/ip_tables.c index 2ed7c58b471a..23c8deff8095 100644 --- a/net/ipv4/netfilter/ip_tables.c +++ b/net/ipv4/netfilter/ip_tables.c @@ -14,7 +14,6 @@ #include <linux/vmalloc.h> #include <linux/netdevice.h> #include <linux/module.h> -#include <linux/icmp.h> #include <net/ip.h> #include <net/compat.h> #include <linux/uaccess.h> @@ -31,7 +30,6 @@ MODULE_LICENSE("GPL"); MODULE_AUTHOR("Netfilter Core Team <coreteam@netfilter.org>"); MODULE_DESCRIPTION("IPv4 packet filter"); -MODULE_ALIAS("ipt_icmp"); void *ipt_alloc_initial_table(const struct xt_table *info) { @@ -272,7 +270,7 @@ ipt_do_table(void *priv, * but it is no problem since absolute verdict is issued by these. */ if (static_key_false(&xt_tee_enabled)) - jumpstack += private->stacksize * __this_cpu_read(nf_skb_duplicated); + jumpstack += private->stacksize * current->in_nf_duplicate; e = get_entry(table_base, private->hook_entry[hook]); @@ -983,7 +981,7 @@ static int get_info(struct net *net, void __user *user, const int *len) sizeof(info.underflow)); info.num_entries = private->number; info.size = private->size; - strcpy(info.name, name); + strscpy(info.name, name); if (copy_to_user(user, &info, *len) != 0) ret = -EFAULT; @@ -1045,7 +1043,6 @@ __do_replace(struct net *net, const char *name, unsigned int valid_hooks, struct xt_counters *counters; struct ipt_entry *iter; - ret = 0; counters = xt_counters_alloc(num_counters); if (!counters) { ret = -ENOMEM; @@ -1091,7 +1088,7 @@ __do_replace(struct net *net, const char *name, unsigned int valid_hooks, net_warn_ratelimited("iptables: counters copy to user failed while replacing table\n"); } vfree(counters); - return ret; + return 0; put_module: module_put(t->me); @@ -1111,6 +1108,8 @@ do_replace(struct net *net, sockptr_t arg, unsigned int len) void *loc_cpu_entry; struct ipt_entry *iter; + if (len < sizeof(tmp)) + return -EINVAL; if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) return -EFAULT; @@ -1119,6 +1118,8 @@ do_replace(struct net *net, sockptr_t arg, unsigned int len) return -ENOMEM; if (tmp.num_counters == 0) return -EINVAL; + if ((u64)len < (u64)tmp.size + sizeof(tmp)) + return -EINVAL; tmp.name[sizeof(tmp.name)-1] = 0; @@ -1495,6 +1496,8 @@ compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) void *loc_cpu_entry; struct ipt_entry *iter; + if (len < sizeof(tmp)) + return -EINVAL; if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) return -EFAULT; @@ -1503,6 +1506,8 @@ compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) return -ENOMEM; if (tmp.num_counters == 0) return -EINVAL; + if ((u64)len < (u64)tmp.size + sizeof(tmp)) + return -EINVAL; tmp.name[sizeof(tmp.name)-1] = 0; @@ -1742,6 +1747,10 @@ int ipt_register_table(struct net *net, const struct xt_table *table, new_table = xt_register_table(net, table, &bootstrap, newinfo); if (IS_ERR(new_table)) { + struct ipt_entry *iter; + + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); xt_free_table_info(newinfo); return PTR_ERR(new_table); } @@ -1758,7 +1767,7 @@ int ipt_register_table(struct net *net, const struct xt_table *table, goto out_free; } - ops = kmemdup(template_ops, sizeof(*ops) * num_ops, GFP_KERNEL); + ops = kmemdup_array(template_ops, num_ops, sizeof(*ops), GFP_KERNEL); if (!ops) { ret = -ENOMEM; goto out_free; @@ -1796,52 +1805,6 @@ void ipt_unregister_table_exit(struct net *net, const char *name) __ipt_unregister_table(net, table); } -/* Returns 1 if the type and code is matched by the range, 0 otherwise */ -static inline bool -icmp_type_code_match(u_int8_t test_type, u_int8_t min_code, u_int8_t max_code, - u_int8_t type, u_int8_t code, - bool invert) -{ - return ((test_type == 0xFF) || - (type == test_type && code >= min_code && code <= max_code)) - ^ invert; -} - -static bool -icmp_match(const struct sk_buff *skb, struct xt_action_param *par) -{ - const struct icmphdr *ic; - struct icmphdr _icmph; - const struct ipt_icmp *icmpinfo = par->matchinfo; - - /* Must not be a fragment. */ - if (par->fragoff != 0) - return false; - - ic = skb_header_pointer(skb, par->thoff, sizeof(_icmph), &_icmph); - if (ic == NULL) { - /* We've been asked to examine this packet, and we - * can't. Hence, no choice but to drop. - */ - par->hotdrop = true; - return false; - } - - return icmp_type_code_match(icmpinfo->type, - icmpinfo->code[0], - icmpinfo->code[1], - ic->type, ic->code, - !!(icmpinfo->invflags&IPT_ICMP_INV)); -} - -static int icmp_checkentry(const struct xt_mtchk_param *par) -{ - const struct ipt_icmp *icmpinfo = par->matchinfo; - - /* Must specify no unknown invflags */ - return (icmpinfo->invflags & ~IPT_ICMP_INV) ? -EINVAL : 0; -} - static struct xt_target ipt_builtin_tg[] __read_mostly = { { .name = XT_STANDARD_TARGET, @@ -1872,18 +1835,6 @@ static struct nf_sockopt_ops ipt_sockopts = { .owner = THIS_MODULE, }; -static struct xt_match ipt_builtin_mt[] __read_mostly = { - { - .name = "icmp", - .match = icmp_match, - .matchsize = sizeof(struct ipt_icmp), - .checkentry = icmp_checkentry, - .proto = IPPROTO_ICMP, - .family = NFPROTO_IPV4, - .me = THIS_MODULE, - }, -}; - static int __net_init ip_tables_net_init(struct net *net) { return xt_proto_init(net, NFPROTO_IPV4); @@ -1911,19 +1862,14 @@ static int __init ip_tables_init(void) ret = xt_register_targets(ipt_builtin_tg, ARRAY_SIZE(ipt_builtin_tg)); if (ret < 0) goto err2; - ret = xt_register_matches(ipt_builtin_mt, ARRAY_SIZE(ipt_builtin_mt)); - if (ret < 0) - goto err4; /* Register setsockopt */ ret = nf_register_sockopt(&ipt_sockopts); if (ret < 0) - goto err5; + goto err4; return 0; -err5: - xt_unregister_matches(ipt_builtin_mt, ARRAY_SIZE(ipt_builtin_mt)); err4: xt_unregister_targets(ipt_builtin_tg, ARRAY_SIZE(ipt_builtin_tg)); err2: @@ -1936,7 +1882,6 @@ static void __exit ip_tables_fini(void) { nf_unregister_sockopt(&ipt_sockopts); - xt_unregister_matches(ipt_builtin_mt, ARRAY_SIZE(ipt_builtin_mt)); xt_unregister_targets(ipt_builtin_tg, ARRAY_SIZE(ipt_builtin_tg)); unregister_pernet_subsys(&ip_tables_net_ops); } diff --git a/net/ipv4/netfilter/ipt_CLUSTERIP.c b/net/ipv4/netfilter/ipt_CLUSTERIP.c deleted file mode 100644 index f8e176c77d1c..000000000000 --- a/net/ipv4/netfilter/ipt_CLUSTERIP.c +++ /dev/null @@ -1,929 +0,0 @@ -// SPDX-License-Identifier: GPL-2.0-only -/* Cluster IP hashmark target - * (C) 2003-2004 by Harald Welte <laforge@netfilter.org> - * based on ideas of Fabio Olive Leite <olive@unixforge.org> - * - * Development of this code funded by SuSE Linux AG, https://www.suse.com/ - */ -#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt -#include <linux/module.h> -#include <linux/proc_fs.h> -#include <linux/jhash.h> -#include <linux/bitops.h> -#include <linux/skbuff.h> -#include <linux/slab.h> -#include <linux/ip.h> -#include <linux/tcp.h> -#include <linux/udp.h> -#include <linux/icmp.h> -#include <linux/if_arp.h> -#include <linux/seq_file.h> -#include <linux/refcount.h> -#include <linux/netfilter_arp.h> -#include <linux/netfilter/x_tables.h> -#include <linux/netfilter_ipv4/ip_tables.h> -#include <linux/netfilter_ipv4/ipt_CLUSTERIP.h> -#include <net/netfilter/nf_conntrack.h> -#include <net/net_namespace.h> -#include <net/netns/generic.h> -#include <net/checksum.h> -#include <net/ip.h> - -#define CLUSTERIP_VERSION "0.8" - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Harald Welte <laforge@netfilter.org>"); -MODULE_DESCRIPTION("Xtables: CLUSTERIP target"); - -struct clusterip_config { - struct list_head list; /* list of all configs */ - refcount_t refcount; /* reference count */ - refcount_t entries; /* number of entries/rules - * referencing us */ - - __be32 clusterip; /* the IP address */ - u_int8_t clustermac[ETH_ALEN]; /* the MAC address */ - int ifindex; /* device ifindex */ - u_int16_t num_total_nodes; /* total number of nodes */ - unsigned long local_nodes; /* node number array */ - -#ifdef CONFIG_PROC_FS - struct proc_dir_entry *pde; /* proc dir entry */ -#endif - enum clusterip_hashmode hash_mode; /* which hashing mode */ - u_int32_t hash_initval; /* hash initialization */ - struct rcu_head rcu; /* for call_rcu */ - struct net *net; /* netns for pernet list */ - char ifname[IFNAMSIZ]; /* device ifname */ -}; - -#ifdef CONFIG_PROC_FS -static const struct proc_ops clusterip_proc_ops; -#endif - -struct clusterip_net { - struct list_head configs; - /* lock protects the configs list */ - spinlock_t lock; - - bool clusterip_deprecated_warning; -#ifdef CONFIG_PROC_FS - struct proc_dir_entry *procdir; - /* mutex protects the config->pde*/ - struct mutex mutex; -#endif - unsigned int hook_users; -}; - -static unsigned int clusterip_arp_mangle(void *priv, struct sk_buff *skb, const struct nf_hook_state *state); - -static const struct nf_hook_ops cip_arp_ops = { - .hook = clusterip_arp_mangle, - .pf = NFPROTO_ARP, - .hooknum = NF_ARP_OUT, - .priority = -1 -}; - -static unsigned int clusterip_net_id __read_mostly; -static inline struct clusterip_net *clusterip_pernet(struct net *net) -{ - return net_generic(net, clusterip_net_id); -} - -static inline void -clusterip_config_get(struct clusterip_config *c) -{ - refcount_inc(&c->refcount); -} - -static void clusterip_config_rcu_free(struct rcu_head *head) -{ - struct clusterip_config *config; - struct net_device *dev; - - config = container_of(head, struct clusterip_config, rcu); - dev = dev_get_by_name(config->net, config->ifname); - if (dev) { - dev_mc_del(dev, config->clustermac); - dev_put(dev); - } - kfree(config); -} - -static inline void -clusterip_config_put(struct clusterip_config *c) -{ - if (refcount_dec_and_test(&c->refcount)) - call_rcu(&c->rcu, clusterip_config_rcu_free); -} - -/* decrease the count of entries using/referencing this config. If last - * entry(rule) is removed, remove the config from lists, but don't free it - * yet, since proc-files could still be holding references */ -static inline void -clusterip_config_entry_put(struct clusterip_config *c) -{ - struct clusterip_net *cn = clusterip_pernet(c->net); - - local_bh_disable(); - if (refcount_dec_and_lock(&c->entries, &cn->lock)) { - list_del_rcu(&c->list); - spin_unlock(&cn->lock); - local_bh_enable(); - /* In case anyone still accesses the file, the open/close - * functions are also incrementing the refcount on their own, - * so it's safe to remove the entry even if it's in use. */ -#ifdef CONFIG_PROC_FS - mutex_lock(&cn->mutex); - if (cn->procdir) - proc_remove(c->pde); - mutex_unlock(&cn->mutex); -#endif - return; - } - local_bh_enable(); -} - -static struct clusterip_config * -__clusterip_config_find(struct net *net, __be32 clusterip) -{ - struct clusterip_config *c; - struct clusterip_net *cn = clusterip_pernet(net); - - list_for_each_entry_rcu(c, &cn->configs, list) { - if (c->clusterip == clusterip) - return c; - } - - return NULL; -} - -static inline struct clusterip_config * -clusterip_config_find_get(struct net *net, __be32 clusterip, int entry) -{ - struct clusterip_config *c; - - rcu_read_lock_bh(); - c = __clusterip_config_find(net, clusterip); - if (c) { -#ifdef CONFIG_PROC_FS - if (!c->pde) - c = NULL; - else -#endif - if (unlikely(!refcount_inc_not_zero(&c->refcount))) - c = NULL; - else if (entry) { - if (unlikely(!refcount_inc_not_zero(&c->entries))) { - clusterip_config_put(c); - c = NULL; - } - } - } - rcu_read_unlock_bh(); - - return c; -} - -static void -clusterip_config_init_nodelist(struct clusterip_config *c, - const struct ipt_clusterip_tgt_info *i) -{ - int n; - - for (n = 0; n < i->num_local_nodes; n++) - set_bit(i->local_nodes[n] - 1, &c->local_nodes); -} - -static int -clusterip_netdev_event(struct notifier_block *this, unsigned long event, - void *ptr) -{ - struct net_device *dev = netdev_notifier_info_to_dev(ptr); - struct net *net = dev_net(dev); - struct clusterip_net *cn = clusterip_pernet(net); - struct clusterip_config *c; - - spin_lock_bh(&cn->lock); - list_for_each_entry_rcu(c, &cn->configs, list) { - switch (event) { - case NETDEV_REGISTER: - if (!strcmp(dev->name, c->ifname)) { - c->ifindex = dev->ifindex; - dev_mc_add(dev, c->clustermac); - } - break; - case NETDEV_UNREGISTER: - if (dev->ifindex == c->ifindex) { - dev_mc_del(dev, c->clustermac); - c->ifindex = -1; - } - break; - case NETDEV_CHANGENAME: - if (!strcmp(dev->name, c->ifname)) { - c->ifindex = dev->ifindex; - dev_mc_add(dev, c->clustermac); - } else if (dev->ifindex == c->ifindex) { - dev_mc_del(dev, c->clustermac); - c->ifindex = -1; - } - break; - } - } - spin_unlock_bh(&cn->lock); - - return NOTIFY_DONE; -} - -static struct clusterip_config * -clusterip_config_init(struct net *net, const struct ipt_clusterip_tgt_info *i, - __be32 ip, const char *iniface) -{ - struct clusterip_net *cn = clusterip_pernet(net); - struct clusterip_config *c; - struct net_device *dev; - int err; - - if (iniface[0] == '\0') { - pr_info("Please specify an interface name\n"); - return ERR_PTR(-EINVAL); - } - - c = kzalloc(sizeof(*c), GFP_ATOMIC); - if (!c) - return ERR_PTR(-ENOMEM); - - dev = dev_get_by_name(net, iniface); - if (!dev) { - pr_info("no such interface %s\n", iniface); - kfree(c); - return ERR_PTR(-ENOENT); - } - c->ifindex = dev->ifindex; - strcpy(c->ifname, dev->name); - memcpy(&c->clustermac, &i->clustermac, ETH_ALEN); - dev_mc_add(dev, c->clustermac); - dev_put(dev); - - c->clusterip = ip; - c->num_total_nodes = i->num_total_nodes; - clusterip_config_init_nodelist(c, i); - c->hash_mode = i->hash_mode; - c->hash_initval = i->hash_initval; - c->net = net; - refcount_set(&c->refcount, 1); - - spin_lock_bh(&cn->lock); - if (__clusterip_config_find(net, ip)) { - err = -EBUSY; - goto out_config_put; - } - - list_add_rcu(&c->list, &cn->configs); - spin_unlock_bh(&cn->lock); - -#ifdef CONFIG_PROC_FS - { - char buffer[16]; - - /* create proc dir entry */ - sprintf(buffer, "%pI4", &ip); - mutex_lock(&cn->mutex); - c->pde = proc_create_data(buffer, 0600, - cn->procdir, - &clusterip_proc_ops, c); - mutex_unlock(&cn->mutex); - if (!c->pde) { - err = -ENOMEM; - goto err; - } - } -#endif - - refcount_set(&c->entries, 1); - return c; - -#ifdef CONFIG_PROC_FS -err: -#endif - spin_lock_bh(&cn->lock); - list_del_rcu(&c->list); -out_config_put: - spin_unlock_bh(&cn->lock); - clusterip_config_put(c); - return ERR_PTR(err); -} - -#ifdef CONFIG_PROC_FS -static int -clusterip_add_node(struct clusterip_config *c, u_int16_t nodenum) -{ - - if (nodenum == 0 || - nodenum > c->num_total_nodes) - return 1; - - /* check if we already have this number in our bitfield */ - if (test_and_set_bit(nodenum - 1, &c->local_nodes)) - return 1; - - return 0; -} - -static bool -clusterip_del_node(struct clusterip_config *c, u_int16_t nodenum) -{ - if (nodenum == 0 || - nodenum > c->num_total_nodes) - return true; - - if (test_and_clear_bit(nodenum - 1, &c->local_nodes)) - return false; - - return true; -} -#endif - -static inline u_int32_t -clusterip_hashfn(const struct sk_buff *skb, - const struct clusterip_config *config) -{ - const struct iphdr *iph = ip_hdr(skb); - unsigned long hashval; - u_int16_t sport = 0, dport = 0; - int poff; - - poff = proto_ports_offset(iph->protocol); - if (poff >= 0) { - const u_int16_t *ports; - u16 _ports[2]; - - ports = skb_header_pointer(skb, iph->ihl * 4 + poff, 4, _ports); - if (ports) { - sport = ports[0]; - dport = ports[1]; - } - } else { - net_info_ratelimited("unknown protocol %u\n", iph->protocol); - } - - switch (config->hash_mode) { - case CLUSTERIP_HASHMODE_SIP: - hashval = jhash_1word(ntohl(iph->saddr), - config->hash_initval); - break; - case CLUSTERIP_HASHMODE_SIP_SPT: - hashval = jhash_2words(ntohl(iph->saddr), sport, - config->hash_initval); - break; - case CLUSTERIP_HASHMODE_SIP_SPT_DPT: - hashval = jhash_3words(ntohl(iph->saddr), sport, dport, - config->hash_initval); - break; - default: - /* to make gcc happy */ - hashval = 0; - /* This cannot happen, unless the check function wasn't called - * at rule load time */ - pr_info("unknown mode %u\n", config->hash_mode); - BUG(); - break; - } - - /* node numbers are 1..n, not 0..n */ - return reciprocal_scale(hashval, config->num_total_nodes) + 1; -} - -static inline int -clusterip_responsible(const struct clusterip_config *config, u_int32_t hash) -{ - return test_bit(hash - 1, &config->local_nodes); -} - -/*********************************************************************** - * IPTABLES TARGET - ***********************************************************************/ - -static unsigned int -clusterip_tg(struct sk_buff *skb, const struct xt_action_param *par) -{ - const struct ipt_clusterip_tgt_info *cipinfo = par->targinfo; - struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - u_int32_t hash; - - /* don't need to clusterip_config_get() here, since refcount - * is only decremented by destroy() - and ip_tables guarantees - * that the ->target() function isn't called after ->destroy() */ - - ct = nf_ct_get(skb, &ctinfo); - if (ct == NULL) - return NF_DROP; - - /* special case: ICMP error handling. conntrack distinguishes between - * error messages (RELATED) and information requests (see below) */ - if (ip_hdr(skb)->protocol == IPPROTO_ICMP && - (ctinfo == IP_CT_RELATED || - ctinfo == IP_CT_RELATED_REPLY)) - return XT_CONTINUE; - - /* nf_conntrack_proto_icmp guarantees us that we only have ICMP_ECHO, - * TIMESTAMP, INFO_REQUEST or ICMP_ADDRESS type icmp packets from here - * on, which all have an ID field [relevant for hashing]. */ - - hash = clusterip_hashfn(skb, cipinfo->config); - - switch (ctinfo) { - case IP_CT_NEW: - ct->mark = hash; - break; - case IP_CT_RELATED: - case IP_CT_RELATED_REPLY: - /* FIXME: we don't handle expectations at the moment. - * They can arrive on a different node than - * the master connection (e.g. FTP passive mode) */ - case IP_CT_ESTABLISHED: - case IP_CT_ESTABLISHED_REPLY: - break; - default: /* Prevent gcc warnings */ - break; - } - -#ifdef DEBUG - nf_ct_dump_tuple_ip(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); -#endif - pr_debug("hash=%u ct_hash=%u ", hash, ct->mark); - if (!clusterip_responsible(cipinfo->config, hash)) { - pr_debug("not responsible\n"); - return NF_DROP; - } - pr_debug("responsible\n"); - - /* despite being received via linklayer multicast, this is - * actually a unicast IP packet. TCP doesn't like PACKET_MULTICAST */ - skb->pkt_type = PACKET_HOST; - - return XT_CONTINUE; -} - -static int clusterip_tg_check(const struct xt_tgchk_param *par) -{ - struct ipt_clusterip_tgt_info *cipinfo = par->targinfo; - struct clusterip_net *cn = clusterip_pernet(par->net); - const struct ipt_entry *e = par->entryinfo; - struct clusterip_config *config; - int ret, i; - - if (par->nft_compat) { - pr_err("cannot use CLUSTERIP target from nftables compat\n"); - return -EOPNOTSUPP; - } - - if (cn->hook_users == UINT_MAX) - return -EOVERFLOW; - - if (cipinfo->hash_mode != CLUSTERIP_HASHMODE_SIP && - cipinfo->hash_mode != CLUSTERIP_HASHMODE_SIP_SPT && - cipinfo->hash_mode != CLUSTERIP_HASHMODE_SIP_SPT_DPT) { - pr_info("unknown mode %u\n", cipinfo->hash_mode); - return -EINVAL; - - } - if (e->ip.dmsk.s_addr != htonl(0xffffffff) || - e->ip.dst.s_addr == 0) { - pr_info("Please specify destination IP\n"); - return -EINVAL; - } - if (cipinfo->num_local_nodes > ARRAY_SIZE(cipinfo->local_nodes)) { - pr_info("bad num_local_nodes %u\n", cipinfo->num_local_nodes); - return -EINVAL; - } - for (i = 0; i < cipinfo->num_local_nodes; i++) { - if (cipinfo->local_nodes[i] - 1 >= - sizeof(config->local_nodes) * 8) { - pr_info("bad local_nodes[%d] %u\n", - i, cipinfo->local_nodes[i]); - return -EINVAL; - } - } - - config = clusterip_config_find_get(par->net, e->ip.dst.s_addr, 1); - if (!config) { - if (!(cipinfo->flags & CLUSTERIP_FLAG_NEW)) { - pr_info("no config found for %pI4, need 'new'\n", - &e->ip.dst.s_addr); - return -EINVAL; - } else { - config = clusterip_config_init(par->net, cipinfo, - e->ip.dst.s_addr, - e->ip.iniface); - if (IS_ERR(config)) - return PTR_ERR(config); - } - } else if (memcmp(&config->clustermac, &cipinfo->clustermac, ETH_ALEN)) { - clusterip_config_entry_put(config); - clusterip_config_put(config); - return -EINVAL; - } - - ret = nf_ct_netns_get(par->net, par->family); - if (ret < 0) { - pr_info("cannot load conntrack support for proto=%u\n", - par->family); - clusterip_config_entry_put(config); - clusterip_config_put(config); - return ret; - } - - if (cn->hook_users == 0) { - ret = nf_register_net_hook(par->net, &cip_arp_ops); - - if (ret < 0) { - clusterip_config_entry_put(config); - clusterip_config_put(config); - nf_ct_netns_put(par->net, par->family); - return ret; - } - } - - cn->hook_users++; - - if (!cn->clusterip_deprecated_warning) { - pr_info("ipt_CLUSTERIP is deprecated and it will removed soon, " - "use xt_cluster instead\n"); - cn->clusterip_deprecated_warning = true; - } - - cipinfo->config = config; - return ret; -} - -/* drop reference count of cluster config when rule is deleted */ -static void clusterip_tg_destroy(const struct xt_tgdtor_param *par) -{ - const struct ipt_clusterip_tgt_info *cipinfo = par->targinfo; - struct clusterip_net *cn = clusterip_pernet(par->net); - - /* if no more entries are referencing the config, remove it - * from the list and destroy the proc entry */ - clusterip_config_entry_put(cipinfo->config); - - clusterip_config_put(cipinfo->config); - - nf_ct_netns_put(par->net, par->family); - cn->hook_users--; - - if (cn->hook_users == 0) - nf_unregister_net_hook(par->net, &cip_arp_ops); -} - -#ifdef CONFIG_NETFILTER_XTABLES_COMPAT -struct compat_ipt_clusterip_tgt_info -{ - u_int32_t flags; - u_int8_t clustermac[6]; - u_int16_t num_total_nodes; - u_int16_t num_local_nodes; - u_int16_t local_nodes[CLUSTERIP_MAX_NODES]; - u_int32_t hash_mode; - u_int32_t hash_initval; - compat_uptr_t config; -}; -#endif /* CONFIG_NETFILTER_XTABLES_COMPAT */ - -static struct xt_target clusterip_tg_reg __read_mostly = { - .name = "CLUSTERIP", - .family = NFPROTO_IPV4, - .target = clusterip_tg, - .checkentry = clusterip_tg_check, - .destroy = clusterip_tg_destroy, - .targetsize = sizeof(struct ipt_clusterip_tgt_info), - .usersize = offsetof(struct ipt_clusterip_tgt_info, config), -#ifdef CONFIG_NETFILTER_XTABLES_COMPAT - .compatsize = sizeof(struct compat_ipt_clusterip_tgt_info), -#endif /* CONFIG_NETFILTER_XTABLES_COMPAT */ - .me = THIS_MODULE -}; - - -/*********************************************************************** - * ARP MANGLING CODE - ***********************************************************************/ - -/* hardcoded for 48bit ethernet and 32bit ipv4 addresses */ -struct arp_payload { - u_int8_t src_hw[ETH_ALEN]; - __be32 src_ip; - u_int8_t dst_hw[ETH_ALEN]; - __be32 dst_ip; -} __packed; - -#ifdef DEBUG -static void arp_print(struct arp_payload *payload) -{ -#define HBUFFERLEN 30 - char hbuffer[HBUFFERLEN]; - int j, k; - - for (k = 0, j = 0; k < HBUFFERLEN - 3 && j < ETH_ALEN; j++) { - hbuffer[k++] = hex_asc_hi(payload->src_hw[j]); - hbuffer[k++] = hex_asc_lo(payload->src_hw[j]); - hbuffer[k++] = ':'; - } - hbuffer[--k] = '\0'; - - pr_debug("src %pI4@%s, dst %pI4\n", - &payload->src_ip, hbuffer, &payload->dst_ip); -} -#endif - -static unsigned int -clusterip_arp_mangle(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct arphdr *arp = arp_hdr(skb); - struct arp_payload *payload; - struct clusterip_config *c; - struct net *net = state->net; - - /* we don't care about non-ethernet and non-ipv4 ARP */ - if (arp->ar_hrd != htons(ARPHRD_ETHER) || - arp->ar_pro != htons(ETH_P_IP) || - arp->ar_pln != 4 || arp->ar_hln != ETH_ALEN) - return NF_ACCEPT; - - /* we only want to mangle arp requests and replies */ - if (arp->ar_op != htons(ARPOP_REPLY) && - arp->ar_op != htons(ARPOP_REQUEST)) - return NF_ACCEPT; - - payload = (void *)(arp+1); - - /* if there is no clusterip configuration for the arp reply's - * source ip, we don't want to mangle it */ - c = clusterip_config_find_get(net, payload->src_ip, 0); - if (!c) - return NF_ACCEPT; - - /* normally the linux kernel always replies to arp queries of - * addresses on different interfacs. However, in the CLUSTERIP case - * this wouldn't work, since we didn't subscribe the mcast group on - * other interfaces */ - if (c->ifindex != state->out->ifindex) { - pr_debug("not mangling arp reply on different interface: cip'%d'-skb'%d'\n", - c->ifindex, state->out->ifindex); - clusterip_config_put(c); - return NF_ACCEPT; - } - - /* mangle reply hardware address */ - memcpy(payload->src_hw, c->clustermac, arp->ar_hln); - -#ifdef DEBUG - pr_debug("mangled arp reply: "); - arp_print(payload); -#endif - - clusterip_config_put(c); - - return NF_ACCEPT; -} - -/*********************************************************************** - * PROC DIR HANDLING - ***********************************************************************/ - -#ifdef CONFIG_PROC_FS - -struct clusterip_seq_position { - unsigned int pos; /* position */ - unsigned int weight; /* number of bits set == size */ - unsigned int bit; /* current bit */ - unsigned long val; /* current value */ -}; - -static void *clusterip_seq_start(struct seq_file *s, loff_t *pos) -{ - struct clusterip_config *c = s->private; - unsigned int weight; - u_int32_t local_nodes; - struct clusterip_seq_position *idx; - - /* FIXME: possible race */ - local_nodes = c->local_nodes; - weight = hweight32(local_nodes); - if (*pos >= weight) - return NULL; - - idx = kmalloc(sizeof(struct clusterip_seq_position), GFP_KERNEL); - if (!idx) - return ERR_PTR(-ENOMEM); - - idx->pos = *pos; - idx->weight = weight; - idx->bit = ffs(local_nodes); - idx->val = local_nodes; - clear_bit(idx->bit - 1, &idx->val); - - return idx; -} - -static void *clusterip_seq_next(struct seq_file *s, void *v, loff_t *pos) -{ - struct clusterip_seq_position *idx = v; - - *pos = ++idx->pos; - if (*pos >= idx->weight) { - kfree(v); - return NULL; - } - idx->bit = ffs(idx->val); - clear_bit(idx->bit - 1, &idx->val); - return idx; -} - -static void clusterip_seq_stop(struct seq_file *s, void *v) -{ - if (!IS_ERR(v)) - kfree(v); -} - -static int clusterip_seq_show(struct seq_file *s, void *v) -{ - struct clusterip_seq_position *idx = v; - - if (idx->pos != 0) - seq_putc(s, ','); - - seq_printf(s, "%u", idx->bit); - - if (idx->pos == idx->weight - 1) - seq_putc(s, '\n'); - - return 0; -} - -static const struct seq_operations clusterip_seq_ops = { - .start = clusterip_seq_start, - .next = clusterip_seq_next, - .stop = clusterip_seq_stop, - .show = clusterip_seq_show, -}; - -static int clusterip_proc_open(struct inode *inode, struct file *file) -{ - int ret = seq_open(file, &clusterip_seq_ops); - - if (!ret) { - struct seq_file *sf = file->private_data; - struct clusterip_config *c = pde_data(inode); - - sf->private = c; - - clusterip_config_get(c); - } - - return ret; -} - -static int clusterip_proc_release(struct inode *inode, struct file *file) -{ - struct clusterip_config *c = pde_data(inode); - int ret; - - ret = seq_release(inode, file); - - if (!ret) - clusterip_config_put(c); - - return ret; -} - -static ssize_t clusterip_proc_write(struct file *file, const char __user *input, - size_t size, loff_t *ofs) -{ - struct clusterip_config *c = pde_data(file_inode(file)); -#define PROC_WRITELEN 10 - char buffer[PROC_WRITELEN+1]; - unsigned long nodenum; - int rc; - - if (size > PROC_WRITELEN) - return -EIO; - if (copy_from_user(buffer, input, size)) - return -EFAULT; - buffer[size] = 0; - - if (*buffer == '+') { - rc = kstrtoul(buffer+1, 10, &nodenum); - if (rc) - return rc; - if (clusterip_add_node(c, nodenum)) - return -ENOMEM; - } else if (*buffer == '-') { - rc = kstrtoul(buffer+1, 10, &nodenum); - if (rc) - return rc; - if (clusterip_del_node(c, nodenum)) - return -ENOENT; - } else - return -EIO; - - return size; -} - -static const struct proc_ops clusterip_proc_ops = { - .proc_open = clusterip_proc_open, - .proc_read = seq_read, - .proc_write = clusterip_proc_write, - .proc_lseek = seq_lseek, - .proc_release = clusterip_proc_release, -}; - -#endif /* CONFIG_PROC_FS */ - -static int clusterip_net_init(struct net *net) -{ - struct clusterip_net *cn = clusterip_pernet(net); - - INIT_LIST_HEAD(&cn->configs); - - spin_lock_init(&cn->lock); - -#ifdef CONFIG_PROC_FS - cn->procdir = proc_mkdir("ipt_CLUSTERIP", net->proc_net); - if (!cn->procdir) { - pr_err("Unable to proc dir entry\n"); - return -ENOMEM; - } - mutex_init(&cn->mutex); -#endif /* CONFIG_PROC_FS */ - - return 0; -} - -static void clusterip_net_exit(struct net *net) -{ -#ifdef CONFIG_PROC_FS - struct clusterip_net *cn = clusterip_pernet(net); - - mutex_lock(&cn->mutex); - proc_remove(cn->procdir); - cn->procdir = NULL; - mutex_unlock(&cn->mutex); -#endif -} - -static struct pernet_operations clusterip_net_ops = { - .init = clusterip_net_init, - .exit = clusterip_net_exit, - .id = &clusterip_net_id, - .size = sizeof(struct clusterip_net), -}; - -static struct notifier_block cip_netdev_notifier = { - .notifier_call = clusterip_netdev_event -}; - -static int __init clusterip_tg_init(void) -{ - int ret; - - ret = register_pernet_subsys(&clusterip_net_ops); - if (ret < 0) - return ret; - - ret = xt_register_target(&clusterip_tg_reg); - if (ret < 0) - goto cleanup_subsys; - - ret = register_netdevice_notifier(&cip_netdev_notifier); - if (ret < 0) - goto unregister_target; - - pr_info("ClusterIP Version %s loaded successfully\n", - CLUSTERIP_VERSION); - - return 0; - -unregister_target: - xt_unregister_target(&clusterip_tg_reg); -cleanup_subsys: - unregister_pernet_subsys(&clusterip_net_ops); - return ret; -} - -static void __exit clusterip_tg_exit(void) -{ - pr_info("ClusterIP Version %s unloading\n", CLUSTERIP_VERSION); - - unregister_netdevice_notifier(&cip_netdev_notifier); - xt_unregister_target(&clusterip_tg_reg); - unregister_pernet_subsys(&clusterip_net_ops); - - /* Wait for completion of call_rcu()'s (clusterip_config_rcu_free) */ - rcu_barrier(); -} - -module_init(clusterip_tg_init); -module_exit(clusterip_tg_exit); diff --git a/net/ipv4/netfilter/ipt_rpfilter.c b/net/ipv4/netfilter/ipt_rpfilter.c index 8cd3224d913e..6d9bf5106868 100644 --- a/net/ipv4/netfilter/ipt_rpfilter.c +++ b/net/ipv4/netfilter/ipt_rpfilter.c @@ -9,6 +9,7 @@ #include <linux/skbuff.h> #include <linux/netdevice.h> #include <linux/ip.h> +#include <net/flow.h> #include <net/ip.h> #include <net/ip_fib.h> #include <net/route.h> @@ -33,7 +34,6 @@ static bool rpfilter_lookup_reverse(struct net *net, struct flowi4 *fl4, const struct net_device *dev, u8 flags) { struct fib_result res; - int ret __maybe_unused; if (fib_lookup(net, fl4, &res, FIB_LOOKUP_IGNORE_LINKSTATE)) return false; @@ -76,9 +76,10 @@ static bool rpfilter_mt(const struct sk_buff *skb, struct xt_action_param *par) flow.daddr = iph->saddr; flow.saddr = rpfilter_get_saddr(iph->daddr); flow.flowi4_mark = info->flags & XT_RPFILTER_VALID_MARK ? skb->mark : 0; - flow.flowi4_tos = iph->tos & IPTOS_RT_MASK; + flow.flowi4_dscp = ip4h_dscp(iph); flow.flowi4_scope = RT_SCOPE_UNIVERSE; - flow.flowi4_oif = l3mdev_master_ifindex_rcu(xt_in(par)); + flow.flowi4_l3mdev = l3mdev_master_ifindex_rcu(xt_in(par)); + flow.flowi4_uid = sock_net_uid(xt_net(par), NULL); return rpfilter_lookup_reverse(xt_net(par), &flow, xt_in(par), info->flags) ^ invert; } diff --git a/net/ipv4/netfilter/iptable_filter.c b/net/ipv4/netfilter/iptable_filter.c index b9062f4552ac..3ab908b74795 100644 --- a/net/ipv4/netfilter/iptable_filter.c +++ b/net/ipv4/netfilter/iptable_filter.c @@ -44,7 +44,7 @@ static int iptable_filter_table_init(struct net *net) return -ENOMEM; /* Entry 1 is the FORWARD hook */ ((struct ipt_standard *)repl->entries)[1].target.verdict = - forward ? -NF_ACCEPT - 1 : -NF_DROP - 1; + forward ? -NF_ACCEPT - 1 : NF_DROP - 1; err = ipt_register_table(net, &packet_filter, repl, filter_ops); kfree(repl); diff --git a/net/ipv4/netfilter/iptable_mangle.c b/net/ipv4/netfilter/iptable_mangle.c index 3abb430af9e6..385d945d8ebe 100644 --- a/net/ipv4/netfilter/iptable_mangle.c +++ b/net/ipv4/netfilter/iptable_mangle.c @@ -36,12 +36,12 @@ static const struct xt_table packet_mangler = { static unsigned int ipt_mangle_out(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { - unsigned int ret; + unsigned int ret, verdict; const struct iphdr *iph; - u_int8_t tos; __be32 saddr, daddr; - u_int32_t mark; + u32 mark; int err; + u8 tos; /* Save things which could affect route */ mark = skb->mark; @@ -51,8 +51,9 @@ ipt_mangle_out(void *priv, struct sk_buff *skb, const struct nf_hook_state *stat tos = iph->tos; ret = ipt_do_table(priv, skb, state); + verdict = ret & NF_VERDICT_MASK; /* Reroute for ANY change. */ - if (ret != NF_DROP && ret != NF_STOLEN) { + if (verdict != NF_DROP && verdict != NF_STOLEN) { iph = ip_hdr(skb); if (iph->saddr != saddr || diff --git a/net/ipv4/netfilter/iptable_nat.c b/net/ipv4/netfilter/iptable_nat.c index 56f6ecc43451..a5db7c67d61b 100644 --- a/net/ipv4/netfilter/iptable_nat.c +++ b/net/ipv4/netfilter/iptable_nat.c @@ -145,28 +145,31 @@ static struct pernet_operations iptable_nat_net_ops = { static int __init iptable_nat_init(void) { - int ret = xt_register_template(&nf_nat_ipv4_table, - iptable_nat_table_init); + int ret; + /* net->gen->ptr[iptable_nat_net_id] must be allocated + * before calling iptable_nat_table_init(). + */ + ret = register_pernet_subsys(&iptable_nat_net_ops); if (ret < 0) return ret; - ret = register_pernet_subsys(&iptable_nat_net_ops); - if (ret < 0) { - xt_unregister_template(&nf_nat_ipv4_table); - return ret; - } + ret = xt_register_template(&nf_nat_ipv4_table, + iptable_nat_table_init); + if (ret < 0) + unregister_pernet_subsys(&iptable_nat_net_ops); return ret; } static void __exit iptable_nat_exit(void) { - unregister_pernet_subsys(&iptable_nat_net_ops); xt_unregister_template(&nf_nat_ipv4_table); + unregister_pernet_subsys(&iptable_nat_net_ops); } module_init(iptable_nat_init); module_exit(iptable_nat_exit); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("iptables legacy nat table"); diff --git a/net/ipv4/netfilter/iptable_raw.c b/net/ipv4/netfilter/iptable_raw.c index ca5e5b21587c..0e7f53964d0a 100644 --- a/net/ipv4/netfilter/iptable_raw.c +++ b/net/ipv4/netfilter/iptable_raw.c @@ -108,3 +108,4 @@ static void __exit iptable_raw_fini(void) module_init(iptable_raw_init); module_exit(iptable_raw_fini); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("iptables legacy raw table"); diff --git a/net/ipv4/netfilter/nf_defrag_ipv4.c b/net/ipv4/netfilter/nf_defrag_ipv4.c index e61ea428ea18..482e733c3375 100644 --- a/net/ipv4/netfilter/nf_defrag_ipv4.c +++ b/net/ipv4/netfilter/nf_defrag_ipv4.c @@ -7,6 +7,7 @@ #include <linux/ip.h> #include <linux/netfilter.h> #include <linux/module.h> +#include <linux/rcupdate.h> #include <linux/skbuff.h> #include <net/netns/generic.h> #include <net/route.h> @@ -65,7 +66,7 @@ static unsigned int ipv4_conntrack_defrag(void *priv, struct sock *sk = skb->sk; if (sk && sk_fullsock(sk) && (sk->sk_family == PF_INET) && - inet_sk(sk)->nodefrag) + inet_test_bit(NODEFRAG, sk)) return NF_ACCEPT; #if IS_ENABLED(CONFIG_NF_CONNTRACK) @@ -113,17 +114,31 @@ static void __net_exit defrag4_net_exit(struct net *net) } } +static const struct nf_defrag_hook defrag_hook = { + .owner = THIS_MODULE, + .enable = nf_defrag_ipv4_enable, + .disable = nf_defrag_ipv4_disable, +}; + static struct pernet_operations defrag4_net_ops = { .exit = defrag4_net_exit, }; static int __init nf_defrag_init(void) { - return register_pernet_subsys(&defrag4_net_ops); + int err; + + err = register_pernet_subsys(&defrag4_net_ops); + if (err) + return err; + + rcu_assign_pointer(nf_defrag_v4_hook, &defrag_hook); + return err; } static void __exit nf_defrag_fini(void) { + rcu_assign_pointer(nf_defrag_v4_hook, NULL); unregister_pernet_subsys(&defrag4_net_ops); } @@ -171,3 +186,4 @@ module_init(nf_defrag_init); module_exit(nf_defrag_fini); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IPv4 defragmentation support"); diff --git a/net/ipv4/netfilter/nf_dup_ipv4.c b/net/ipv4/netfilter/nf_dup_ipv4.c index 6cc5743c553a..9a773502f10a 100644 --- a/net/ipv4/netfilter/nf_dup_ipv4.c +++ b/net/ipv4/netfilter/nf_dup_ipv4.c @@ -12,6 +12,7 @@ #include <linux/skbuff.h> #include <linux/netfilter.h> #include <net/checksum.h> +#include <net/flow.h> #include <net/icmp.h> #include <net/ip.h> #include <net/route.h> @@ -32,7 +33,7 @@ static bool nf_dup_ipv4_route(struct net *net, struct sk_buff *skb, fl4.flowi4_oif = oif; fl4.daddr = gw->s_addr; - fl4.flowi4_tos = RT_TOS(iph->tos); + fl4.flowi4_dscp = ip4h_dscp(iph); fl4.flowi4_scope = RT_SCOPE_UNIVERSE; fl4.flowi4_flags = FLOWI_FLAG_KNOWN_NH; rt = ip_route_output_key(net, &fl4); @@ -52,8 +53,9 @@ void nf_dup_ipv4(struct net *net, struct sk_buff *skb, unsigned int hooknum, { struct iphdr *iph; - if (this_cpu_read(nf_skb_duplicated)) - return; + local_bh_disable(); + if (current->in_nf_duplicate) + goto out; /* * Copy the skb, and route the copy. Will later return %XT_CONTINUE for * the original skb, which should continue on its way as if nothing has @@ -61,7 +63,7 @@ void nf_dup_ipv4(struct net *net, struct sk_buff *skb, unsigned int hooknum, */ skb = pskb_copy(skb, GFP_ATOMIC); if (skb == NULL) - return; + goto out; #if IS_ENABLED(CONFIG_NF_CONNTRACK) /* Avoid counting cloned packets towards the original connection. */ @@ -84,12 +86,14 @@ void nf_dup_ipv4(struct net *net, struct sk_buff *skb, unsigned int hooknum, --iph->ttl; if (nf_dup_ipv4_route(net, skb, gw, oif)) { - __this_cpu_write(nf_skb_duplicated, true); + current->in_nf_duplicate = true; ip_local_out(net, skb->sk, skb); - __this_cpu_write(nf_skb_duplicated, false); + current->in_nf_duplicate = false; } else { kfree_skb(skb); } +out: + local_bh_enable(); } EXPORT_SYMBOL_GPL(nf_dup_ipv4); diff --git a/net/ipv4/netfilter/nf_flow_table_ipv4.c b/net/ipv4/netfilter/nf_flow_table_ipv4.c deleted file mode 100644 index e69de29bb2d1..000000000000 --- a/net/ipv4/netfilter/nf_flow_table_ipv4.c +++ /dev/null diff --git a/net/ipv4/netfilter/nf_nat_h323.c b/net/ipv4/netfilter/nf_nat_h323.c index 3e2685c120c7..faee20af4856 100644 --- a/net/ipv4/netfilter/nf_nat_h323.c +++ b/net/ipv4/netfilter/nf_nat_h323.c @@ -291,20 +291,7 @@ static int nat_t120(struct sk_buff *skb, struct nf_conn *ct, exp->expectfn = nf_nat_follow_master; exp->dir = !dir; - /* Try to get same port: if not, try to change it. */ - for (; nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, nated_port); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_h323: out of TCP ports\n"); return 0; @@ -347,20 +334,7 @@ static int nat_h245(struct sk_buff *skb, struct nf_conn *ct, if (info->sig_port[dir] == port) nated_port = ntohs(info->sig_port[!dir]); - /* Try to get same port: if not, try to change it. */ - for (; nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, nated_port); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_q931: out of TCP ports\n"); return 0; @@ -439,20 +413,7 @@ static int nat_q931(struct sk_buff *skb, struct nf_conn *ct, if (info->sig_port[dir] == port) nated_port = ntohs(info->sig_port[!dir]); - /* Try to get same port: if not, try to change it. */ - for (; nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, nated_port); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_ras: out of TCP ports\n"); return 0; @@ -532,20 +493,7 @@ static int nat_callforwarding(struct sk_buff *skb, struct nf_conn *ct, exp->expectfn = ip_nat_callforwarding_expect; exp->dir = !dir; - /* Try to get same port: if not, try to change it. */ - for (nated_port = ntohs(port); nated_port != 0; nated_port++) { - int ret; - - exp->tuple.dst.u.tcp.port = htons(nated_port); - ret = nf_ct_expect_related(exp, 0); - if (ret == 0) - break; - else if (ret != -EBUSY) { - nated_port = 0; - break; - } - } - + nated_port = nf_nat_exp_find_port(exp, ntohs(port)); if (nated_port == 0) { /* No port available */ net_notice_ratelimited("nf_nat_q931: out of TCP ports\n"); return 0; @@ -579,53 +527,39 @@ static struct nf_ct_helper_expectfn callforwarding_nat = { .expectfn = ip_nat_callforwarding_expect, }; +static const struct nfct_h323_nat_hooks nathooks = { + .set_h245_addr = set_h245_addr, + .set_h225_addr = set_h225_addr, + .set_sig_addr = set_sig_addr, + .set_ras_addr = set_ras_addr, + .nat_rtp_rtcp = nat_rtp_rtcp, + .nat_t120 = nat_t120, + .nat_h245 = nat_h245, + .nat_callforwarding = nat_callforwarding, + .nat_q931 = nat_q931, +}; + /****************************************************************************/ -static int __init init(void) +static int __init nf_nat_h323_init(void) { - BUG_ON(set_h245_addr_hook != NULL); - BUG_ON(set_h225_addr_hook != NULL); - BUG_ON(set_sig_addr_hook != NULL); - BUG_ON(set_ras_addr_hook != NULL); - BUG_ON(nat_rtp_rtcp_hook != NULL); - BUG_ON(nat_t120_hook != NULL); - BUG_ON(nat_h245_hook != NULL); - BUG_ON(nat_callforwarding_hook != NULL); - BUG_ON(nat_q931_hook != NULL); - - RCU_INIT_POINTER(set_h245_addr_hook, set_h245_addr); - RCU_INIT_POINTER(set_h225_addr_hook, set_h225_addr); - RCU_INIT_POINTER(set_sig_addr_hook, set_sig_addr); - RCU_INIT_POINTER(set_ras_addr_hook, set_ras_addr); - RCU_INIT_POINTER(nat_rtp_rtcp_hook, nat_rtp_rtcp); - RCU_INIT_POINTER(nat_t120_hook, nat_t120); - RCU_INIT_POINTER(nat_h245_hook, nat_h245); - RCU_INIT_POINTER(nat_callforwarding_hook, nat_callforwarding); - RCU_INIT_POINTER(nat_q931_hook, nat_q931); + RCU_INIT_POINTER(nfct_h323_nat_hook, &nathooks); nf_ct_helper_expectfn_register(&q931_nat); nf_ct_helper_expectfn_register(&callforwarding_nat); return 0; } /****************************************************************************/ -static void __exit fini(void) +static void __exit nf_nat_h323_fini(void) { - RCU_INIT_POINTER(set_h245_addr_hook, NULL); - RCU_INIT_POINTER(set_h225_addr_hook, NULL); - RCU_INIT_POINTER(set_sig_addr_hook, NULL); - RCU_INIT_POINTER(set_ras_addr_hook, NULL); - RCU_INIT_POINTER(nat_rtp_rtcp_hook, NULL); - RCU_INIT_POINTER(nat_t120_hook, NULL); - RCU_INIT_POINTER(nat_h245_hook, NULL); - RCU_INIT_POINTER(nat_callforwarding_hook, NULL); - RCU_INIT_POINTER(nat_q931_hook, NULL); + RCU_INIT_POINTER(nfct_h323_nat_hook, NULL); nf_ct_helper_expectfn_unregister(&q931_nat); nf_ct_helper_expectfn_unregister(&callforwarding_nat); synchronize_rcu(); } /****************************************************************************/ -module_init(init); -module_exit(fini); +module_init(nf_nat_h323_init); +module_exit(nf_nat_h323_fini); MODULE_AUTHOR("Jing Min Zhao <zhaojingmin@users.sourceforge.net>"); MODULE_DESCRIPTION("H.323 NAT helper"); diff --git a/net/ipv4/netfilter/nf_nat_pptp.c b/net/ipv4/netfilter/nf_nat_pptp.c index 3f248a19faa3..fab357cc8559 100644 --- a/net/ipv4/netfilter/nf_nat_pptp.c +++ b/net/ipv4/netfilter/nf_nat_pptp.c @@ -295,28 +295,24 @@ pptp_inbound_pkt(struct sk_buff *skb, return NF_ACCEPT; } +static const struct nf_nat_pptp_hook pptp_hooks = { + .outbound = pptp_outbound_pkt, + .inbound = pptp_inbound_pkt, + .exp_gre = pptp_exp_gre, + .expectfn = pptp_nat_expected, +}; + static int __init nf_nat_helper_pptp_init(void) { - BUG_ON(nf_nat_pptp_hook_outbound != NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_outbound, pptp_outbound_pkt); - - BUG_ON(nf_nat_pptp_hook_inbound != NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_inbound, pptp_inbound_pkt); - - BUG_ON(nf_nat_pptp_hook_exp_gre != NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_exp_gre, pptp_exp_gre); + WARN_ON(nf_nat_pptp_hook != NULL); + RCU_INIT_POINTER(nf_nat_pptp_hook, &pptp_hooks); - BUG_ON(nf_nat_pptp_hook_expectfn != NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_expectfn, pptp_nat_expected); return 0; } static void __exit nf_nat_helper_pptp_fini(void) { - RCU_INIT_POINTER(nf_nat_pptp_hook_expectfn, NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_exp_gre, NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_inbound, NULL); - RCU_INIT_POINTER(nf_nat_pptp_hook_outbound, NULL); + RCU_INIT_POINTER(nf_nat_pptp_hook, NULL); synchronize_rcu(); } diff --git a/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 b/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 index 24b73268f362..dc2cc5794160 100644 --- a/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 +++ b/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 @@ -1,3 +1,11 @@ +-- SPDX-License-Identifier: BSD-3-Clause +-- +-- Copyright (C) 1990, 2002 IETF Trust and the persons identified as authors +-- of the code +-- +-- https://www.rfc-editor.org/rfc/rfc1157#section-4 +-- https://www.rfc-editor.org/rfc/rfc3416#section-3 + Message ::= SEQUENCE { version diff --git a/net/ipv4/netfilter/nf_reject_ipv4.c b/net/ipv4/netfilter/nf_reject_ipv4.c index 4eed5afca392..fae4aa4a5f09 100644 --- a/net/ipv4/netfilter/nf_reject_ipv4.c +++ b/net/ipv4/netfilter/nf_reject_ipv4.c @@ -12,6 +12,15 @@ #include <linux/netfilter_ipv4.h> #include <linux/netfilter_bridge.h> +static struct iphdr *nf_reject_iphdr_put(struct sk_buff *nskb, + const struct sk_buff *oldskb, + __u8 protocol, int ttl); +static void nf_reject_ip_tcphdr_put(struct sk_buff *nskb, const struct sk_buff *oldskb, + const struct tcphdr *oth); +static const struct tcphdr * +nf_reject_ip_tcphdr_get(struct sk_buff *oldskb, + struct tcphdr *_oth, int hook); + static int nf_reject_iphdr_validate(struct sk_buff *skb) { struct iphdr *iph; @@ -62,7 +71,7 @@ struct sk_buff *nf_reject_skb_v4_tcp_reset(struct net *net, skb_reserve(nskb, LL_MAX_HEADER); niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, - net->ipv4.sysctl_ip_default_ttl); + READ_ONCE(net->ipv4.sysctl_ip_default_ttl)); nf_reject_ip_tcphdr_put(nskb, oldskb, oth); niph->tot_len = htons(nskb->len); ip_send_check(niph); @@ -71,6 +80,27 @@ struct sk_buff *nf_reject_skb_v4_tcp_reset(struct net *net, } EXPORT_SYMBOL_GPL(nf_reject_skb_v4_tcp_reset); +static bool nf_skb_is_icmp_unreach(const struct sk_buff *skb) +{ + const struct iphdr *iph = ip_hdr(skb); + u8 *tp, _type; + int thoff; + + if (iph->protocol != IPPROTO_ICMP) + return false; + + thoff = skb_network_offset(skb) + sizeof(*iph); + + tp = skb_header_pointer(skb, + thoff + offsetof(struct icmphdr, type), + sizeof(_type), &_type); + + if (!tp) + return false; + + return *tp == ICMP_DEST_UNREACH; +} + struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, struct sk_buff *oldskb, const struct net_device *dev, @@ -80,6 +110,7 @@ struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, struct iphdr *niph; struct icmphdr *icmph; unsigned int len; + int dataoff; __wsum csum; u8 proto; @@ -90,6 +121,10 @@ struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, if (ip_hdr(oldskb)->frag_off & htons(IP_OFFSET)) return NULL; + /* don't reply to ICMP_DEST_UNREACH with ICMP_DEST_UNREACH. */ + if (nf_skb_is_icmp_unreach(oldskb)) + return NULL; + /* RFC says return as much as we can without exceeding 576 bytes. */ len = min_t(unsigned int, 536, oldskb->len); @@ -99,10 +134,11 @@ struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, if (pskb_trim_rcsum(oldskb, ntohs(ip_hdr(oldskb)->tot_len))) return NULL; + dataoff = ip_hdrlen(oldskb); proto = ip_hdr(oldskb)->protocol; if (!skb_csum_unnecessary(oldskb) && - nf_reject_verify_csum(proto) && + nf_reject_verify_csum(oldskb, dataoff, proto) && nf_ip_checksum(oldskb, hook, ip_hdrlen(oldskb), proto)) return NULL; @@ -115,7 +151,7 @@ struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, skb_reserve(nskb, LL_MAX_HEADER); niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_ICMP, - net->ipv4.sysctl_ip_default_ttl); + READ_ONCE(net->ipv4.sysctl_ip_default_ttl)); skb_reset_transport_header(nskb); icmph = skb_put_zero(nskb, sizeof(struct icmphdr)); @@ -134,8 +170,9 @@ struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, } EXPORT_SYMBOL_GPL(nf_reject_skb_v4_unreach); -const struct tcphdr *nf_reject_ip_tcphdr_get(struct sk_buff *oldskb, - struct tcphdr *_oth, int hook) +static const struct tcphdr * +nf_reject_ip_tcphdr_get(struct sk_buff *oldskb, + struct tcphdr *_oth, int hook) { const struct tcphdr *oth; @@ -161,11 +198,10 @@ const struct tcphdr *nf_reject_ip_tcphdr_get(struct sk_buff *oldskb, return oth; } -EXPORT_SYMBOL_GPL(nf_reject_ip_tcphdr_get); -struct iphdr *nf_reject_iphdr_put(struct sk_buff *nskb, - const struct sk_buff *oldskb, - __u8 protocol, int ttl) +static struct iphdr *nf_reject_iphdr_put(struct sk_buff *nskb, + const struct sk_buff *oldskb, + __u8 protocol, int ttl) { struct iphdr *niph, *oiph = ip_hdr(oldskb); @@ -186,10 +222,9 @@ struct iphdr *nf_reject_iphdr_put(struct sk_buff *nskb, return niph; } -EXPORT_SYMBOL_GPL(nf_reject_iphdr_put); -void nf_reject_ip_tcphdr_put(struct sk_buff *nskb, const struct sk_buff *oldskb, - const struct tcphdr *oth) +static void nf_reject_ip_tcphdr_put(struct sk_buff *nskb, const struct sk_buff *oldskb, + const struct tcphdr *oth) { struct iphdr *niph = ip_hdr(nskb); struct tcphdr *tcph; @@ -216,7 +251,6 @@ void nf_reject_ip_tcphdr_put(struct sk_buff *nskb, const struct sk_buff *oldskb, nskb->csum_start = (unsigned char *)tcph - nskb->head; nskb->csum_offset = offsetof(struct tcphdr, check); } -EXPORT_SYMBOL_GPL(nf_reject_ip_tcphdr_put); static int nf_reject_fill_skb_dst(struct sk_buff *skb_in) { @@ -237,18 +271,15 @@ static int nf_reject_fill_skb_dst(struct sk_buff *skb_in) void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, int hook) { - struct net_device *br_indev __maybe_unused; - struct sk_buff *nskb; - struct iphdr *niph; const struct tcphdr *oth; + struct sk_buff *nskb; struct tcphdr _oth; oth = nf_reject_ip_tcphdr_get(oldskb, &_oth, hook); if (!oth) return; - if ((hook == NF_INET_PRE_ROUTING || hook == NF_INET_INGRESS) && - nf_reject_fill_skb_dst(oldskb) < 0) + if (!skb_dst(oldskb) && nf_reject_fill_skb_dst(oldskb) < 0) return; if (skb_rtable(oldskb)->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) @@ -265,19 +296,18 @@ void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, nskb->mark = IP4_REPLY_MARK(net, oldskb->mark); skb_reserve(nskb, LL_MAX_HEADER); - niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, - ip4_dst_hoplimit(skb_dst(nskb))); + nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, + ip4_dst_hoplimit(skb_dst(nskb))); nf_reject_ip_tcphdr_put(nskb, oldskb, oth); if (ip_route_me_harder(net, sk, nskb, RTN_UNSPEC)) goto free_nskb; - niph = ip_hdr(nskb); - /* "Never happens" */ if (nskb->len > dst_mtu(skb_dst(nskb))) goto free_nskb; nf_ct_attach(nskb, oldskb); + nf_ct_set_closing(skb_nfct(oldskb)); #if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) /* If we use ip_local_out for bridged traffic, the MAC source on @@ -286,9 +316,14 @@ void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, * build the eth header using the original destination's MAC as the * source, and send the RST packet directly. */ - br_indev = nf_bridge_get_physindev(oldskb); - if (br_indev) { + if (nf_bridge_info_exists(oldskb)) { struct ethhdr *oeth = eth_hdr(oldskb); + struct iphdr *niph = ip_hdr(nskb); + struct net_device *br_indev; + + br_indev = nf_bridge_get_physindev(oldskb, net); + if (!br_indev) + goto free_nskb; nskb->dev = br_indev; niph->tot_len = htons(nskb->len); @@ -311,23 +346,25 @@ EXPORT_SYMBOL_GPL(nf_send_reset); void nf_send_unreach(struct sk_buff *skb_in, int code, int hook) { struct iphdr *iph = ip_hdr(skb_in); + int dataoff = ip_hdrlen(skb_in); u8 proto = iph->protocol; if (iph->frag_off & htons(IP_OFFSET)) return; - if ((hook == NF_INET_PRE_ROUTING || hook == NF_INET_INGRESS) && - nf_reject_fill_skb_dst(skb_in) < 0) + if (!skb_dst(skb_in) && nf_reject_fill_skb_dst(skb_in) < 0) return; - if (skb_csum_unnecessary(skb_in) || !nf_reject_verify_csum(proto)) { + if (skb_csum_unnecessary(skb_in) || + !nf_reject_verify_csum(skb_in, dataoff, proto)) { icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0); return; } - if (nf_ip_checksum(skb_in, hook, ip_hdrlen(skb_in), proto) == 0) + if (nf_ip_checksum(skb_in, hook, dataoff, proto) == 0) icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0); } EXPORT_SYMBOL_GPL(nf_send_unreach); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IPv4 packet rejection core"); diff --git a/net/ipv4/netfilter/nf_socket_ipv4.c b/net/ipv4/netfilter/nf_socket_ipv4.c index 2d42e4c35a20..5080fa5fbf6a 100644 --- a/net/ipv4/netfilter/nf_socket_ipv4.c +++ b/net/ipv4/netfilter/nf_socket_ipv4.c @@ -71,8 +71,7 @@ nf_socket_get_sock_v4(struct net *net, struct sk_buff *skb, const int doff, { switch (protocol) { case IPPROTO_TCP: - return inet_lookup(net, &tcp_hashinfo, skb, doff, - saddr, sport, daddr, dport, + return inet_lookup(net, skb, doff, saddr, sport, daddr, dport, in->ifindex); case IPPROTO_UDP: return udp4_lib_lookup(net, saddr, sport, daddr, dport, diff --git a/net/ipv4/netfilter/nf_tproxy_ipv4.c b/net/ipv4/netfilter/nf_tproxy_ipv4.c index b2bae0b0e42a..041c3f37f237 100644 --- a/net/ipv4/netfilter/nf_tproxy_ipv4.c +++ b/net/ipv4/netfilter/nf_tproxy_ipv4.c @@ -38,7 +38,7 @@ nf_tproxy_handle_time_wait4(struct net *net, struct sk_buff *skb, hp->source, lport ? lport : hp->dest, skb->dev, NF_TPROXY_LOOKUP_LISTENER); if (sk2) { - inet_twsk_deschedule_put(inet_twsk(sk)); + nf_tproxy_twsk_deschedule_put(inet_twsk(sk)); sk = sk2; } } @@ -58,6 +58,8 @@ __be32 nf_tproxy_laddr4(struct sk_buff *skb, __be32 user_laddr, __be32 daddr) laddr = 0; indev = __in_dev_get_rcu(skb->dev); + if (!indev) + return daddr; in_dev_for_each_ifa_rcu(ifa, indev) { if (ifa->ifa_flags & IFA_F_SECONDARY) @@ -92,12 +94,10 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, switch (lookup_type) { case NF_TPROXY_LOOKUP_LISTENER: - sk = inet_lookup_listener(net, &tcp_hashinfo, skb, - ip_hdrlen(skb) + - __tcp_hdrlen(hp), - saddr, sport, - daddr, dport, - in->ifindex, 0); + sk = inet_lookup_listener(net, skb, + ip_hdrlen(skb) + __tcp_hdrlen(hp), + saddr, sport, daddr, dport, + in->ifindex, 0); if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) sk = NULL; @@ -108,9 +108,8 @@ nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, */ break; case NF_TPROXY_LOOKUP_ESTABLISHED: - sk = inet_lookup_established(net, &tcp_hashinfo, - saddr, sport, daddr, dport, - in->ifindex); + sk = inet_lookup_established(net, saddr, sport, + daddr, dport, in->ifindex); break; default: BUG(); diff --git a/net/ipv4/netfilter/nft_dup_ipv4.c b/net/ipv4/netfilter/nft_dup_ipv4.c index aeb631760eb9..ef5dd88107dd 100644 --- a/net/ipv4/netfilter/nft_dup_ipv4.c +++ b/net/ipv4/netfilter/nft_dup_ipv4.c @@ -40,19 +40,20 @@ static int nft_dup_ipv4_init(const struct nft_ctx *ctx, if (tb[NFTA_DUP_SREG_ADDR] == NULL) return -EINVAL; - err = nft_parse_register_load(tb[NFTA_DUP_SREG_ADDR], &priv->sreg_addr, + err = nft_parse_register_load(ctx, tb[NFTA_DUP_SREG_ADDR], &priv->sreg_addr, sizeof(struct in_addr)); if (err < 0) return err; if (tb[NFTA_DUP_SREG_DEV]) - err = nft_parse_register_load(tb[NFTA_DUP_SREG_DEV], + err = nft_parse_register_load(ctx, tb[NFTA_DUP_SREG_DEV], &priv->sreg_dev, sizeof(int)); return err; } -static int nft_dup_ipv4_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_dup_ipv4_dump(struct sk_buff *skb, + const struct nft_expr *expr, bool reset) { struct nft_dup_ipv4 *priv = nft_expr_priv(expr); @@ -75,6 +76,7 @@ static const struct nft_expr_ops nft_dup_ipv4_ops = { .eval = nft_dup_ipv4_eval, .init = nft_dup_ipv4_init, .dump = nft_dup_ipv4_dump, + .reduce = NFT_REDUCE_READONLY, }; static const struct nla_policy nft_dup_ipv4_policy[NFTA_DUP_MAX + 1] = { diff --git a/net/ipv4/netfilter/nft_fib_ipv4.c b/net/ipv4/netfilter/nft_fib_ipv4.c index 03df986217b7..82af6cd76d13 100644 --- a/net/ipv4/netfilter/nft_fib_ipv4.c +++ b/net/ipv4/netfilter/nft_fib_ipv4.c @@ -10,6 +10,8 @@ #include <net/netfilter/nf_tables.h> #include <net/netfilter/nft_fib.h> +#include <net/flow.h> +#include <net/ip.h> #include <net/ip_fib.h> #include <net/route.h> @@ -22,8 +24,6 @@ static __be32 get_saddr(__be32 addr) return addr; } -#define DSCP_BITS 0xfc - void nft_fib4_eval_type(const struct nft_expr *expr, struct nft_regs *regs, const struct nft_pktinfo *pkt) { @@ -50,7 +50,12 @@ void nft_fib4_eval_type(const struct nft_expr *expr, struct nft_regs *regs, else addr = iph->saddr; - *dst = inet_dev_addr_type(nft_net(pkt), dev, addr); + if (priv->flags & (NFTA_FIB_F_IIF | NFTA_FIB_F_OIF)) { + *dst = inet_dev_addr_type(nft_net(pkt), dev, addr); + return; + } + + *dst = inet_addr_type_dev_table(nft_net(pkt), pkt->skb->dev, addr); } EXPORT_SYMBOL_GPL(nft_fib4_eval_type); @@ -65,10 +70,17 @@ void nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, struct flowi4 fl4 = { .flowi4_scope = RT_SCOPE_UNIVERSE, .flowi4_iif = LOOPBACK_IFINDEX, + .flowi4_proto = pkt->tprot, + .flowi4_uid = sock_net_uid(nft_net(pkt), NULL), }; const struct net_device *oif; const struct net_device *found; + if (nft_fib_can_skip(pkt)) { + nft_fib_store_result(dest, priv, nft_in(pkt)); + return; + } + /* * Do not set flowi4_oif, it restricts results (for example, asking * for oif 3 will get RTN_UNICAST result even if the daddr exits @@ -83,11 +95,7 @@ void nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, else oif = NULL; - if (nft_hook(pkt) == NF_INET_PRE_ROUTING && - nft_fib_is_loopback(pkt->skb, nft_in(pkt))) { - nft_fib_store_result(dest, priv, nft_in(pkt)); - return; - } + fl4.flowi4_l3mdev = nft_fib_l3mdev_master_ifindex_rcu(pkt, oif); iph = skb_header_pointer(pkt->skb, noff, sizeof(_iph), &_iph); if (!iph) { @@ -106,12 +114,16 @@ void nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, if (priv->flags & NFTA_FIB_F_MARK) fl4.flowi4_mark = pkt->skb->mark; - fl4.flowi4_tos = iph->tos & DSCP_BITS; + fl4.flowi4_dscp = ip4h_dscp(iph); if (priv->flags & NFTA_FIB_F_DADDR) { fl4.daddr = iph->daddr; fl4.saddr = get_saddr(iph->saddr); } else { + if (nft_hook(pkt) == NF_INET_FORWARD && + priv->flags & NFTA_FIB_F_IIF) + fl4.flowi4_iif = nft_out(pkt)->ifindex; + fl4.daddr = iph->saddr; fl4.saddr = get_saddr(iph->daddr); } @@ -130,12 +142,11 @@ void nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, break; } - if (!oif) { - found = FIB_RES_DEV(res); + if (!oif) { + found = FIB_RES_DEV(res); } else { if (!fib_info_nh_uses_dev(res.fi, oif)) return; - found = oif; } @@ -152,6 +163,7 @@ static const struct nft_expr_ops nft_fib4_type_ops = { .init = nft_fib_init, .dump = nft_fib_dump, .validate = nft_fib_validate, + .reduce = nft_fib_reduce, }; static const struct nft_expr_ops nft_fib4_ops = { @@ -161,6 +173,7 @@ static const struct nft_expr_ops nft_fib4_ops = { .init = nft_fib_init, .dump = nft_fib_dump, .validate = nft_fib_validate, + .reduce = nft_fib_reduce, }; static const struct nft_expr_ops * diff --git a/net/ipv4/netfilter/nft_reject_ipv4.c b/net/ipv4/netfilter/nft_reject_ipv4.c index 55fc23a8f7a7..6cb213bb7256 100644 --- a/net/ipv4/netfilter/nft_reject_ipv4.c +++ b/net/ipv4/netfilter/nft_reject_ipv4.c @@ -45,6 +45,7 @@ static const struct nft_expr_ops nft_reject_ipv4_ops = { .init = nft_reject_init, .dump = nft_reject_dump, .validate = nft_reject_validate, + .reduce = NFT_REDUCE_READONLY, }; static struct nft_expr_type nft_reject_ipv4_type __read_mostly = { diff --git a/net/ipv4/nexthop.c b/net/ipv4/nexthop.c index eeafeccebb8d..7b9d70f9b31c 100644 --- a/net/ipv4/nexthop.c +++ b/net/ipv4/nexthop.c @@ -26,6 +26,9 @@ static void remove_nexthop(struct net *net, struct nexthop *nh, #define NH_DEV_HASHBITS 8 #define NH_DEV_HASHSIZE (1U << NH_DEV_HASHBITS) +#define NHA_OP_FLAGS_DUMP_ALL (NHA_OP_FLAG_DUMP_STATS | \ + NHA_OP_FLAG_DUMP_HW_STATS) + static const struct nla_policy rtm_nh_policy_new[] = { [NHA_ID] = { .type = NLA_U32 }, [NHA_GROUP] = { .type = NLA_BINARY }, @@ -37,10 +40,17 @@ static const struct nla_policy rtm_nh_policy_new[] = { [NHA_ENCAP] = { .type = NLA_NESTED }, [NHA_FDB] = { .type = NLA_FLAG }, [NHA_RES_GROUP] = { .type = NLA_NESTED }, + [NHA_HW_STATS_ENABLE] = NLA_POLICY_MAX(NLA_U32, true), }; static const struct nla_policy rtm_nh_policy_get[] = { [NHA_ID] = { .type = NLA_U32 }, + [NHA_OP_FLAGS] = NLA_POLICY_MASK(NLA_U32, + NHA_OP_FLAGS_DUMP_ALL), +}; + +static const struct nla_policy rtm_nh_policy_del[] = { + [NHA_ID] = { .type = NLA_U32 }, }; static const struct nla_policy rtm_nh_policy_dump[] = { @@ -48,6 +58,8 @@ static const struct nla_policy rtm_nh_policy_dump[] = { [NHA_GROUPS] = { .type = NLA_FLAG }, [NHA_MASTER] = { .type = NLA_U32 }, [NHA_FDB] = { .type = NLA_FLAG }, + [NHA_OP_FLAGS] = NLA_POLICY_MASK(NLA_U32, + NHA_OP_FLAGS_DUMP_ALL), }; static const struct nla_policy rtm_nh_res_policy_new[] = { @@ -92,6 +104,7 @@ __nh_notifier_single_info_init(struct nh_notifier_single_info *nh_info, else if (nh_info->gw_family == AF_INET6) nh_info->ipv6 = nhi->fib_nhc.nhc_gw.ipv6; + nh_info->id = nhi->nh_parent->id; nh_info->is_reject = nhi->reject_nh; nh_info->is_fdb = nhi->fdb_nh; nh_info->has_encap = !!nhi->fib_nhc.nhc_lwtstate; @@ -131,13 +144,13 @@ static int nh_notifier_mpath_info_init(struct nh_notifier_info *info, info->nh_grp->num_nh = num_nh; info->nh_grp->is_fdb = nhg->fdb_nh; + info->nh_grp->hw_stats = nhg->hw_stats; for (i = 0; i < num_nh; i++) { struct nh_grp_entry *nhge = &nhg->nh_entries[i]; struct nh_info *nhi; nhi = rtnl_dereference(nhge->nh->nh_info); - info->nh_grp->nh_entries[i].id = nhge->nh->id; info->nh_grp->nh_entries[i].weight = nhge->weight; __nh_notifier_single_info_init(&info->nh_grp->nh_entries[i].nh, nhi); @@ -162,6 +175,7 @@ static int nh_notifier_res_table_info_init(struct nh_notifier_info *info, return -ENOMEM; info->nh_res_table->num_nh_buckets = num_nh_buckets; + info->nh_res_table->hw_stats = nhg->hw_stats; for (i = 0; i < num_nh_buckets; i++) { struct nh_res_bucket *bucket = &res_table->nh_buckets[i]; @@ -393,6 +407,7 @@ static int call_nexthop_res_table_notifiers(struct net *net, struct nexthop *nh, struct nh_notifier_info info = { .net = net, .extack = extack, + .id = nh->id, }; struct nh_group *nhg; int err; @@ -474,6 +489,7 @@ static void nexthop_free_group(struct nexthop *nh) struct nh_grp_entry *nhge = &nhg->nh_entries[i]; WARN_ON(!list_empty(&nhge->nh_list)); + free_percpu(nhge->stats); nexthop_put(nhge->nh); } @@ -525,6 +541,7 @@ static struct nexthop *nexthop_alloc(void) INIT_LIST_HEAD(&nh->f6i_list); INIT_LIST_HEAD(&nh->grp_list); INIT_LIST_HEAD(&nh->fdb_list); + spin_lock_init(&nh->lock); } return nh; } @@ -654,14 +671,213 @@ nla_put_failure: return -EMSGSIZE; } -static int nla_put_nh_group(struct sk_buff *skb, struct nh_group *nhg) +static void nh_grp_entry_stats_inc(struct nh_grp_entry *nhge) +{ + struct nh_grp_entry_stats *cpu_stats; + + cpu_stats = get_cpu_ptr(nhge->stats); + u64_stats_update_begin(&cpu_stats->syncp); + u64_stats_inc(&cpu_stats->packets); + u64_stats_update_end(&cpu_stats->syncp); + put_cpu_ptr(cpu_stats); +} + +static void nh_grp_entry_stats_read(struct nh_grp_entry *nhge, + u64 *ret_packets) +{ + int i; + + *ret_packets = 0; + + for_each_possible_cpu(i) { + struct nh_grp_entry_stats *cpu_stats; + unsigned int start; + u64 packets; + + cpu_stats = per_cpu_ptr(nhge->stats, i); + do { + start = u64_stats_fetch_begin(&cpu_stats->syncp); + packets = u64_stats_read(&cpu_stats->packets); + } while (u64_stats_fetch_retry(&cpu_stats->syncp, start)); + + *ret_packets += packets; + } +} + +static int nh_notifier_grp_hw_stats_init(struct nh_notifier_info *info, + const struct nexthop *nh) +{ + struct nh_group *nhg; + int i; + + ASSERT_RTNL(); + nhg = rtnl_dereference(nh->nh_grp); + + info->id = nh->id; + info->type = NH_NOTIFIER_INFO_TYPE_GRP_HW_STATS; + info->nh_grp_hw_stats = kzalloc(struct_size(info->nh_grp_hw_stats, + stats, nhg->num_nh), + GFP_KERNEL); + if (!info->nh_grp_hw_stats) + return -ENOMEM; + + info->nh_grp_hw_stats->num_nh = nhg->num_nh; + for (i = 0; i < nhg->num_nh; i++) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + + info->nh_grp_hw_stats->stats[i].id = nhge->nh->id; + } + + return 0; +} + +static void nh_notifier_grp_hw_stats_fini(struct nh_notifier_info *info) +{ + kfree(info->nh_grp_hw_stats); +} + +void nh_grp_hw_stats_report_delta(struct nh_notifier_grp_hw_stats_info *info, + unsigned int nh_idx, + u64 delta_packets) +{ + info->hw_stats_used = true; + info->stats[nh_idx].packets += delta_packets; +} +EXPORT_SYMBOL(nh_grp_hw_stats_report_delta); + +static void nh_grp_hw_stats_apply_update(struct nexthop *nh, + struct nh_notifier_info *info) +{ + struct nh_group *nhg; + int i; + + ASSERT_RTNL(); + nhg = rtnl_dereference(nh->nh_grp); + + for (i = 0; i < nhg->num_nh; i++) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + + nhge->packets_hw += info->nh_grp_hw_stats->stats[i].packets; + } +} + +static int nh_grp_hw_stats_update(struct nexthop *nh, bool *hw_stats_used) +{ + struct nh_notifier_info info = { + .net = nh->net, + }; + struct net *net = nh->net; + int err; + + if (nexthop_notifiers_is_empty(net)) { + *hw_stats_used = false; + return 0; + } + + err = nh_notifier_grp_hw_stats_init(&info, nh); + if (err) + return err; + + err = blocking_notifier_call_chain(&net->nexthop.notifier_chain, + NEXTHOP_EVENT_HW_STATS_REPORT_DELTA, + &info); + + /* Cache whatever we got, even if there was an error, otherwise the + * successful stats retrievals would get lost. + */ + nh_grp_hw_stats_apply_update(nh, &info); + *hw_stats_used = info.nh_grp_hw_stats->hw_stats_used; + + nh_notifier_grp_hw_stats_fini(&info); + return notifier_to_errno(err); +} + +static int nla_put_nh_group_stats_entry(struct sk_buff *skb, + struct nh_grp_entry *nhge, + u32 op_flags) +{ + struct nlattr *nest; + u64 packets; + + nh_grp_entry_stats_read(nhge, &packets); + + nest = nla_nest_start(skb, NHA_GROUP_STATS_ENTRY); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, NHA_GROUP_STATS_ENTRY_ID, nhge->nh->id) || + nla_put_uint(skb, NHA_GROUP_STATS_ENTRY_PACKETS, + packets + nhge->packets_hw)) + goto nla_put_failure; + + if (op_flags & NHA_OP_FLAG_DUMP_HW_STATS && + nla_put_uint(skb, NHA_GROUP_STATS_ENTRY_PACKETS_HW, + nhge->packets_hw)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int nla_put_nh_group_stats(struct sk_buff *skb, struct nexthop *nh, + u32 op_flags) +{ + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + struct nlattr *nest; + bool hw_stats_used; + int err; + int i; + + if (nla_put_u32(skb, NHA_HW_STATS_ENABLE, nhg->hw_stats)) + goto err_out; + + if (op_flags & NHA_OP_FLAG_DUMP_HW_STATS && + nhg->hw_stats) { + err = nh_grp_hw_stats_update(nh, &hw_stats_used); + if (err) + goto out; + + if (nla_put_u32(skb, NHA_HW_STATS_USED, hw_stats_used)) + goto err_out; + } + + nest = nla_nest_start(skb, NHA_GROUP_STATS); + if (!nest) + goto err_out; + + for (i = 0; i < nhg->num_nh; i++) + if (nla_put_nh_group_stats_entry(skb, &nhg->nh_entries[i], + op_flags)) + goto cancel_out; + + nla_nest_end(skb, nest); + return 0; + +cancel_out: + nla_nest_cancel(skb, nest); +err_out: + err = -EMSGSIZE; +out: + return err; +} + +static int nla_put_nh_group(struct sk_buff *skb, struct nexthop *nh, + u32 op_flags, u32 *resp_op_flags) { + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); struct nexthop_grp *p; size_t len = nhg->num_nh * sizeof(*p); struct nlattr *nla; u16 group_type = 0; + u16 weight; int i; + *resp_op_flags |= NHA_OP_FLAG_RESP_GRP_RESVD_0; + if (nhg->hash_threshold) group_type = NEXTHOP_GRP_TYPE_MPATH; else if (nhg->resilient) @@ -676,14 +892,23 @@ static int nla_put_nh_group(struct sk_buff *skb, struct nh_group *nhg) p = nla_data(nla); for (i = 0; i < nhg->num_nh; ++i) { - p->id = nhg->nh_entries[i].nh->id; - p->weight = nhg->nh_entries[i].weight - 1; - p += 1; + weight = nhg->nh_entries[i].weight - 1; + + *p++ = (struct nexthop_grp) { + .id = nhg->nh_entries[i].nh->id, + .weight = weight, + .weight_high = weight >> 8, + }; } if (nhg->resilient && nla_put_nh_group_res(skb, nhg)) goto nla_put_failure; + if (op_flags & NHA_OP_FLAG_DUMP_STATS && + (nla_put_u32(skb, NHA_HW_STATS_ENABLE, nhg->hw_stats) || + nla_put_nh_group_stats(skb, nh, op_flags))) + goto nla_put_failure; + return 0; nla_put_failure: @@ -691,7 +916,8 @@ nla_put_failure: } static int nh_fill_node(struct sk_buff *skb, struct nexthop *nh, - int event, u32 portid, u32 seq, unsigned int nlflags) + int event, u32 portid, u32 seq, unsigned int nlflags, + u32 op_flags) { struct fib6_nh *fib6_nh; struct fib_nh *fib_nh; @@ -715,10 +941,12 @@ static int nh_fill_node(struct sk_buff *skb, struct nexthop *nh, if (nh->is_group) { struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + u32 resp_op_flags = 0; if (nhg->fdb_nh && nla_put_flag(skb, NHA_FDB)) goto nla_put_failure; - if (nla_put_nh_group(skb, nhg)) + if (nla_put_nh_group(skb, nh, op_flags, &resp_op_flags) || + nla_put_u32(skb, NHA_OP_FLAGS, resp_op_flags)) goto nla_put_failure; goto out; } @@ -757,8 +985,7 @@ static int nh_fill_node(struct sk_buff *skb, struct nexthop *nh, break; } - if (nhi->fib_nhc.nhc_lwtstate && - lwtunnel_fill_encap(skb, nhi->fib_nhc.nhc_lwtstate, + if (lwtunnel_fill_encap(skb, nhi->fib_nhc.nhc_lwtstate, NHA_ENCAP, NHA_ENCAP_TYPE) < 0) goto nla_put_failure; @@ -831,7 +1058,9 @@ static size_t nh_nlmsg_size(struct nexthop *nh) sz += nla_total_size(4); /* NHA_ID */ if (nh->is_group) - sz += nh_nlmsg_size_grp(nh); + sz += nh_nlmsg_size_grp(nh) + + nla_total_size(4) + /* NHA_OP_FLAGS */ + 0; else sz += nh_nlmsg_size_single(nh); @@ -849,7 +1078,7 @@ static void nexthop_notify(int event, struct nexthop *nh, struct nl_info *info) if (!skb) goto errout; - err = nh_fill_node(skb, nh, event, info->portid, seq, nlflags); + err = nh_fill_node(skb, nh, event, info->portid, seq, nlflags, 0); if (err < 0) { /* -EMSGSIZE implies BUG in nh_nlmsg_size() */ WARN_ON(err == -EMSGSIZE); @@ -861,8 +1090,7 @@ static void nexthop_notify(int event, struct nexthop *nh, struct nl_info *info) info->nlh, gfp_any()); return; errout: - if (err < 0) - rtnl_set_sk_err(info->nl_net, RTNLGRP_NEXTHOP, err); + rtnl_set_sk_err(info->nl_net, RTNLGRP_NEXTHOP, err); } static unsigned long nh_res_bucket_used_time(const struct nh_res_bucket *bucket) @@ -982,8 +1210,7 @@ static void nexthop_bucket_notify(struct nh_res_table *res_table, rtnl_notify(skb, nh->net, 0, RTNLGRP_NEXTHOP, NULL, GFP_KERNEL); return; errout: - if (err < 0) - rtnl_set_sk_err(nh->net, RTNLGRP_NEXTHOP, err); + rtnl_set_sk_err(nh->net, RTNLGRP_NEXTHOP, err); } static bool valid_group_nh(struct nexthop *nh, unsigned int npaths, @@ -1045,10 +1272,8 @@ static int nh_check_attr_group(struct net *net, u16 nh_grp_type, struct netlink_ext_ack *extack) { unsigned int len = nla_len(tb[NHA_GROUP]); - u8 nh_family = AF_UNSPEC; struct nexthop_grp *nhg; unsigned int i, j; - u8 nhg_fdb = 0; if (!len || len & (sizeof(struct nexthop_grp) - 1)) { NL_SET_ERR_MSG(extack, @@ -1061,11 +1286,14 @@ static int nh_check_attr_group(struct net *net, nhg = nla_data(tb[NHA_GROUP]); for (i = 0; i < len; ++i) { - if (nhg[i].resvd1 || nhg[i].resvd2) { - NL_SET_ERR_MSG(extack, "Reserved fields in nexthop_grp must be 0"); + if (nhg[i].resvd2) { + NL_SET_ERR_MSG(extack, "Reserved field in nexthop_grp must be 0"); return -EINVAL; } - if (nhg[i].weight > 254) { + if (nexthop_grp_weight(&nhg[i]) == 0) { + /* 0xffff got passed in, representing weight of 0x10000, + * which is too heavy. + */ NL_SET_ERR_MSG(extack, "Invalid value for weight"); return -EINVAL; } @@ -1077,10 +1305,41 @@ static int nh_check_attr_group(struct net *net, } } - if (tb[NHA_FDB]) - nhg_fdb = 1; nhg = nla_data(tb[NHA_GROUP]); - for (i = 0; i < len; ++i) { + for (i = NHA_GROUP_TYPE + 1; i < tb_size; ++i) { + if (!tb[i]) + continue; + switch (i) { + case NHA_HW_STATS_ENABLE: + case NHA_FDB: + continue; + case NHA_RES_GROUP: + if (nh_grp_type == NEXTHOP_GRP_TYPE_RES) + continue; + break; + } + NL_SET_ERR_MSG(extack, + "No other attributes can be set in nexthop groups"); + return -EINVAL; + } + + return 0; +} + +static int nh_check_attr_group_rtnl(struct net *net, struct nlattr *tb[], + struct netlink_ext_ack *extack) +{ + u8 nh_family = AF_UNSPEC; + struct nexthop_grp *nhg; + unsigned int len; + unsigned int i; + u8 nhg_fdb; + + len = nla_len(tb[NHA_GROUP]) / sizeof(*nhg); + nhg = nla_data(tb[NHA_GROUP]); + nhg_fdb = !!tb[NHA_FDB]; + + for (i = 0; i < len; i++) { struct nexthop *nh; bool is_fdb_nh; @@ -1100,21 +1359,6 @@ static int nh_check_attr_group(struct net *net, return -EINVAL; } } - for (i = NHA_GROUP_TYPE + 1; i < tb_size; ++i) { - if (!tb[i]) - continue; - switch (i) { - case NHA_FDB: - continue; - case NHA_RES_GROUP: - if (nh_grp_type == NEXTHOP_GRP_TYPE_RES) - continue; - break; - } - NL_SET_ERR_MSG(extack, - "No other attributes can be set in nexthop groups"); - return -EINVAL; - } return 0; } @@ -1124,13 +1368,13 @@ static bool ipv6_good_nh(const struct fib6_nh *nh) int state = NUD_REACHABLE; struct neighbour *n; - rcu_read_lock_bh(); + rcu_read_lock(); n = __ipv6_neigh_lookup_noref_stub(nh->fib_nh_dev, &nh->fib_nh_gw6); if (n) - state = n->nud_state; + state = READ_ONCE(n->nud_state); - rcu_read_unlock_bh(); + rcu_read_unlock(); return !!(state & NUD_VALID); } @@ -1140,53 +1384,81 @@ static bool ipv4_good_nh(const struct fib_nh *nh) int state = NUD_REACHABLE; struct neighbour *n; - rcu_read_lock_bh(); + rcu_read_lock(); n = __ipv4_neigh_lookup_noref(nh->fib_nh_dev, (__force u32)nh->fib_nh_gw4); if (n) - state = n->nud_state; + state = READ_ONCE(n->nud_state); - rcu_read_unlock_bh(); + rcu_read_unlock(); return !!(state & NUD_VALID); } -static struct nexthop *nexthop_select_path_hthr(struct nh_group *nhg, int hash) +static bool nexthop_is_good_nh(const struct nexthop *nh) +{ + struct nh_info *nhi = rcu_dereference(nh->nh_info); + + switch (nhi->family) { + case AF_INET: + return ipv4_good_nh(&nhi->fib_nh); + case AF_INET6: + return ipv6_good_nh(&nhi->fib6_nh); + } + + return false; +} + +static struct nexthop *nexthop_select_path_fdb(struct nh_group *nhg, int hash) { - struct nexthop *rc = NULL; int i; - for (i = 0; i < nhg->num_nh; ++i) { + for (i = 0; i < nhg->num_nh; i++) { struct nh_grp_entry *nhge = &nhg->nh_entries[i]; - struct nh_info *nhi; if (hash > atomic_read(&nhge->hthr.upper_bound)) continue; - nhi = rcu_dereference(nhge->nh->nh_info); - if (nhi->fdb_nh) - return nhge->nh; + nh_grp_entry_stats_inc(nhge); + return nhge->nh; + } + + WARN_ON_ONCE(1); + return NULL; +} + +static struct nexthop *nexthop_select_path_hthr(struct nh_group *nhg, int hash) +{ + struct nh_grp_entry *nhge0 = NULL; + int i; + + if (nhg->fdb_nh) + return nexthop_select_path_fdb(nhg, hash); + + for (i = 0; i < nhg->num_nh; ++i) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; /* nexthops always check if it is good and does * not rely on a sysctl for this behavior */ - switch (nhi->family) { - case AF_INET: - if (ipv4_good_nh(&nhi->fib_nh)) - return nhge->nh; - break; - case AF_INET6: - if (ipv6_good_nh(&nhi->fib6_nh)) - return nhge->nh; - break; - } + if (!nexthop_is_good_nh(nhge->nh)) + continue; + + if (!nhge0) + nhge0 = nhge; + + if (hash > atomic_read(&nhge->hthr.upper_bound)) + continue; - if (!rc) - rc = nhge->nh; + nh_grp_entry_stats_inc(nhge); + return nhge->nh; } - return rc; + if (!nhge0) + nhge0 = &nhg->nh_entries[0]; + nh_grp_entry_stats_inc(nhge0); + return nhge0->nh; } static struct nexthop *nexthop_select_path_res(struct nh_group *nhg, int hash) @@ -1202,6 +1474,7 @@ static struct nexthop *nexthop_select_path_res(struct nh_group *nhg, int hash) bucket = &res_table->nh_buckets[bucket_index]; nh_res_bucket_set_busy(bucket); nhge = rcu_dereference(bucket->nh_entry); + nh_grp_entry_stats_inc(nhge); return nhge->nh; } @@ -1282,12 +1555,12 @@ int fib6_check_nexthop(struct nexthop *nh, struct fib6_config *cfg, if (nh->is_group) { struct nh_group *nhg; - nhg = rtnl_dereference(nh->nh_grp); + nhg = rcu_dereference_rtnl(nh->nh_grp); if (nhg->has_v4) goto no_v4_nh; is_fdb_nh = nhg->fdb_nh; } else { - nhi = rtnl_dereference(nh->nh_info); + nhi = rcu_dereference_rtnl(nh->nh_info); if (nhi->family == AF_INET) goto no_v4_nh; is_fdb_nh = nhi->fdb_nh; @@ -1631,9 +1904,9 @@ static void nh_res_table_cancel_upkeep(struct nh_res_table *res_table) static void nh_res_group_rebalance(struct nh_group *nhg, struct nh_res_table *res_table) { - int prev_upper_bound = 0; - int total = 0; - int w = 0; + u16 prev_upper_bound = 0; + u32 total = 0; + u32 w = 0; int i; INIT_LIST_HEAD(&res_table->uw_nh_entries); @@ -1643,11 +1916,12 @@ static void nh_res_group_rebalance(struct nh_group *nhg, for (i = 0; i < nhg->num_nh; ++i) { struct nh_grp_entry *nhge = &nhg->nh_entries[i]; - int upper_bound; + u16 upper_bound; + u64 btw; w += nhge->weight; - upper_bound = DIV_ROUND_CLOSEST(res_table->num_nh_buckets * w, - total); + btw = ((u64)res_table->num_nh_buckets) * w; + upper_bound = DIV_ROUND_CLOSEST_ULL(btw, total); nhge->res.wants_buckets = upper_bound - prev_upper_bound; prev_upper_bound = upper_bound; @@ -1713,8 +1987,8 @@ static void replace_nexthop_grp_res(struct nh_group *oldg, static void nh_hthr_group_rebalance(struct nh_group *nhg) { - int total = 0; - int w = 0; + u32 total = 0; + u32 w = 0; int i; for (i = 0; i < nhg->num_nh; ++i) @@ -1722,7 +1996,7 @@ static void nh_hthr_group_rebalance(struct nh_group *nhg) for (i = 0; i < nhg->num_nh; ++i) { struct nh_grp_entry *nhge = &nhg->nh_entries[i]; - int upper_bound; + u32 upper_bound; w += nhge->weight; upper_bound = DIV_ROUND_CLOSEST_ULL((u64)w << 31, total) - 1; @@ -1775,6 +2049,7 @@ static void remove_nh_grp_entry(struct net *net, struct nh_grp_entry *nhge, newg->has_v4 = true; list_del(&nhges[i].nh_list); + new_nhges[j].stats = nhges[i].stats; new_nhges[j].nh_parent = nhges[i].nh_parent; new_nhges[j].nh = nhges[i].nh; new_nhges[j].weight = nhges[i].weight; @@ -1790,6 +2065,7 @@ static void remove_nh_grp_entry(struct net *net, struct nh_grp_entry *nhge, rcu_assign_pointer(nhp->nh_grp, newg); list_del(&nhge->nh_list); + free_percpu(nhge->stats); nexthop_put(nhge->nh); /* Removal of a NH from a resilient group is notified through @@ -1811,6 +2087,12 @@ static void remove_nexthop_from_groups(struct net *net, struct nexthop *nh, { struct nh_grp_entry *nhge, *tmp; + /* If there is nothing to do, let's avoid the costly call to + * synchronize_net() + */ + if (list_empty(&nh->grp_list)) + return; + list_for_each_entry_safe(nhge, tmp, &nh->grp_list, nh_list) remove_nh_grp_entry(net, nhge, nlinfo); @@ -1842,7 +2124,7 @@ static void remove_nexthop_group(struct nexthop *nh, struct nl_info *nlinfo) /* not called for nexthop replace */ static void __remove_nexthop_fib(struct net *net, struct nexthop *nh) { - struct fib6_info *f6i, *tmp; + struct fib6_info *f6i; bool do_flush = false; struct fib_info *fi; @@ -1853,13 +2135,24 @@ static void __remove_nexthop_fib(struct net *net, struct nexthop *nh) if (do_flush) fib_flush(net); - /* ip6_del_rt removes the entry from this list hence the _safe */ - list_for_each_entry_safe(f6i, tmp, &nh->f6i_list, nh_list) { + spin_lock_bh(&nh->lock); + + nh->dead = true; + + while (!list_empty(&nh->f6i_list)) { + f6i = list_first_entry(&nh->f6i_list, typeof(*f6i), nh_list); + /* __ip6_del_rt does a release, so do a hold here */ fib6_info_hold(f6i); + + spin_unlock_bh(&nh->lock); ipv6_stub->ip6_del_rt(net, f6i, - !net->ipv4.sysctl_nexthop_compat_mode); + !READ_ONCE(net->ipv4.sysctl_nexthop_compat_mode)); + + spin_lock_bh(&nh->lock); } + + spin_unlock_bh(&nh->lock); } static void __remove_nexthop(struct net *net, struct nexthop *nh, @@ -2112,6 +2405,13 @@ static int replace_nexthop_single(struct net *net, struct nexthop *old, return -EINVAL; } + if (!list_empty(&old->grp_list) && + rtnl_dereference(new->nh_info)->fdb_nh != + rtnl_dereference(old->nh_info)->fdb_nh) { + NL_SET_ERR_MSG(extack, "Cannot change nexthop FDB status while in a group"); + return -EINVAL; + } + err = call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, new, extack); if (err) return err; @@ -2361,7 +2661,8 @@ out: if (!rc) { nh_base_seq_inc(net); nexthop_notify(RTM_NEWNEXTHOP, new_nh, &cfg->nlinfo); - if (replace_notify && net->ipv4.sysctl_nexthop_compat_mode) + if (replace_notify && + READ_ONCE(net->ipv4.sysctl_nexthop_compat_mode)) nexthop_replace_notify(net, new_nh, &cfg->nlinfo); } @@ -2415,9 +2716,6 @@ static struct nexthop *nexthop_create_group(struct net *net, int err; int i; - if (WARN_ON(!num_nh)) - return ERR_PTR(-EINVAL); - nh = nexthop_alloc(); if (!nh) return ERR_PTR(-ENOMEM); @@ -2453,8 +2751,16 @@ static struct nexthop *nexthop_create_group(struct net *net, if (nhi->family == AF_INET) nhg->has_v4 = true; + nhg->nh_entries[i].stats = + netdev_alloc_pcpu_stats(struct nh_grp_entry_stats); + if (!nhg->nh_entries[i].stats) { + err = -ENOMEM; + nexthop_put(nhe); + goto out_no_nh; + } nhg->nh_entries[i].nh = nhe; - nhg->nh_entries[i].weight = entry[i].weight + 1; + nhg->nh_entries[i].weight = nexthop_grp_weight(&entry[i]); + list_add(&nhg->nh_entries[i].nh_list, &nhe->grp_list); nhg->nh_entries[i].nh_parent = nh; } @@ -2485,6 +2791,9 @@ static struct nexthop *nexthop_create_group(struct net *net, if (cfg->nh_fdb) nhg->fdb_nh = 1; + if (cfg->nh_hw_stats) + nhg->hw_stats = true; + rcu_assign_pointer(nh->nh_grp, nhg); return nh; @@ -2492,6 +2801,7 @@ static struct nexthop *nexthop_create_group(struct net *net, out_no_nh: for (i--; i >= 0; --i) { list_del(&nhg->nh_entries[i].nh_list); + free_percpu(nhg->nh_entries[i].stats); nexthop_put(nhg->nh_entries[i].nh); } @@ -2533,7 +2843,7 @@ static int nh_create_ipv4(struct net *net, struct nexthop *nh, if (!err) { nh->nh_flags = fib_nh->fib_nh_flags; fib_info_update_nhc_saddr(net, &fib_nh->nh_common, - fib_nh->fib_nh_scope); + !fib_nh->fib_nh_scope ? 0 : fib_nh->fib_nh_scope - 1); } else { fib_nh_release(net, fib_nh); } @@ -2639,11 +2949,6 @@ static struct nexthop *nexthop_add(struct net *net, struct nh_config *cfg, struct nexthop *nh; int err; - if (cfg->nlflags & NLM_F_REPLACE && !cfg->nh_id) { - NL_SET_ERR_MSG(extack, "Replace requires nexthop id"); - return ERR_PTR(-EINVAL); - } - if (!cfg->nh_id) { cfg->nh_id = nh_find_unused_id(net); if (!cfg->nh_id) { @@ -2740,19 +3045,13 @@ static int rtm_to_nh_config_grp_res(struct nlattr *res, struct nh_config *cfg, } static int rtm_to_nh_config(struct net *net, struct sk_buff *skb, - struct nlmsghdr *nlh, struct nh_config *cfg, + struct nlmsghdr *nlh, struct nlattr **tb, + struct nh_config *cfg, struct netlink_ext_ack *extack) { struct nhmsg *nhm = nlmsg_data(nlh); - struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_new)]; int err; - err = nlmsg_parse(nlh, sizeof(*nhm), tb, - ARRAY_SIZE(rtm_nh_policy_new) - 1, - rtm_nh_policy_new, extack); - if (err < 0) - return err; - err = -EINVAL; if (nhm->resvd || nhm->nh_scope) { NL_SET_ERR_MSG(extack, "Invalid values in ancillary header"); @@ -2817,7 +3116,8 @@ static int rtm_to_nh_config(struct net *net, struct sk_buff *skb, NL_SET_ERR_MSG(extack, "Invalid group type"); goto out; } - err = nh_check_attr_group(net, tb, ARRAY_SIZE(tb), + + err = nh_check_attr_group(net, tb, ARRAY_SIZE(rtm_nh_policy_new), cfg->nh_grp_type, extack); if (err) goto out; @@ -2826,6 +3126,9 @@ static int rtm_to_nh_config(struct net *net, struct sk_buff *skb, err = rtm_to_nh_config_grp_res(tb[NHA_RES_GROUP], cfg, extack); + if (tb[NHA_HW_STATS_ENABLE]) + cfg->nh_hw_stats = nla_get_u32(tb[NHA_HW_STATS_ENABLE]); + /* no other attributes should be set */ goto out; } @@ -2847,25 +3150,6 @@ static int rtm_to_nh_config(struct net *net, struct sk_buff *skb, goto out; } - if (!cfg->nh_fdb && tb[NHA_OIF]) { - cfg->nh_ifindex = nla_get_u32(tb[NHA_OIF]); - if (cfg->nh_ifindex) - cfg->dev = __dev_get_by_index(net, cfg->nh_ifindex); - - if (!cfg->dev) { - NL_SET_ERR_MSG(extack, "Invalid device index"); - goto out; - } else if (!(cfg->dev->flags & IFF_UP)) { - NL_SET_ERR_MSG(extack, "Nexthop device is not up"); - err = -ENETDOWN; - goto out; - } else if (!netif_carrier_ok(cfg->dev)) { - NL_SET_ERR_MSG(extack, "Carrier for nexthop device is down"); - err = -ENETDOWN; - goto out; - } - } - err = -EINVAL; if (tb[NHA_GATEWAY]) { struct nlattr *gwa = tb[NHA_GATEWAY]; @@ -2917,34 +3201,92 @@ static int rtm_to_nh_config(struct net *net, struct sk_buff *skb, goto out; } + if (tb[NHA_HW_STATS_ENABLE]) { + NL_SET_ERR_MSG(extack, "Cannot enable nexthop hardware statistics for non-group nexthops"); + goto out; + } err = 0; out: return err; } +static int rtm_to_nh_config_rtnl(struct net *net, struct nlattr **tb, + struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + if (tb[NHA_GROUP]) + return nh_check_attr_group_rtnl(net, tb, extack); + + if (tb[NHA_OIF]) { + cfg->nh_ifindex = nla_get_u32(tb[NHA_OIF]); + if (cfg->nh_ifindex) + cfg->dev = __dev_get_by_index(net, cfg->nh_ifindex); + + if (!cfg->dev) { + NL_SET_ERR_MSG(extack, "Invalid device index"); + return -EINVAL; + } + + if (!(cfg->dev->flags & IFF_UP)) { + NL_SET_ERR_MSG(extack, "Nexthop device is not up"); + return -ENETDOWN; + } + + if (!netif_carrier_ok(cfg->dev)) { + NL_SET_ERR_MSG(extack, "Carrier for nexthop device is down"); + return -ENETDOWN; + } + } + + return 0; +} + /* rtnl */ static int rtm_new_nexthop(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_new)]; struct net *net = sock_net(skb->sk); struct nh_config cfg; struct nexthop *nh; int err; - err = rtm_to_nh_config(net, skb, nlh, &cfg, extack); - if (!err) { - nh = nexthop_add(net, &cfg, extack); - if (IS_ERR(nh)) - err = PTR_ERR(nh); + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_new) - 1, + rtm_nh_policy_new, extack); + if (err < 0) + goto out; + + err = rtm_to_nh_config(net, skb, nlh, tb, &cfg, extack); + if (err) + goto out; + + if (cfg.nlflags & NLM_F_REPLACE && !cfg.nh_id) { + NL_SET_ERR_MSG(extack, "Replace requires nexthop id"); + err = -EINVAL; + goto out; } + rtnl_net_lock(net); + + err = rtm_to_nh_config_rtnl(net, tb, &cfg, extack); + if (err) + goto unlock; + + nh = nexthop_add(net, &cfg, extack); + if (IS_ERR(nh)) + err = PTR_ERR(nh); + +unlock: + rtnl_net_unlock(net); +out: return err; } -static int __nh_valid_get_del_req(const struct nlmsghdr *nlh, - struct nlattr **tb, u32 *id, - struct netlink_ext_ack *extack) +static int nh_valid_get_del_req(const struct nlmsghdr *nlh, + struct nlattr **tb, u32 *id, u32 *op_flags, + struct netlink_ext_ack *extack) { struct nhmsg *nhm = nlmsg_data(nlh); @@ -2964,28 +3306,17 @@ static int __nh_valid_get_del_req(const struct nlmsghdr *nlh, return -EINVAL; } - return 0; -} - -static int nh_valid_get_del_req(const struct nlmsghdr *nlh, u32 *id, - struct netlink_ext_ack *extack) -{ - struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_get)]; - int err; + if (op_flags) + *op_flags = nla_get_u32_default(tb[NHA_OP_FLAGS], 0); - err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, - ARRAY_SIZE(rtm_nh_policy_get) - 1, - rtm_nh_policy_get, extack); - if (err < 0) - return err; - - return __nh_valid_get_del_req(nlh, tb, id, extack); + return 0; } /* rtnl */ static int rtm_del_nexthop(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_del)]; struct net *net = sock_net(skb->sk); struct nl_info nlinfo = { .nlh = nlh, @@ -2996,30 +3327,48 @@ static int rtm_del_nexthop(struct sk_buff *skb, struct nlmsghdr *nlh, int err; u32 id; - err = nh_valid_get_del_req(nlh, &id, extack); + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_del) - 1, rtm_nh_policy_del, + extack); + if (err < 0) + return err; + + err = nh_valid_get_del_req(nlh, tb, &id, NULL, extack); if (err) return err; + rtnl_net_lock(net); + nh = nexthop_find_by_id(net, id); - if (!nh) - return -ENOENT; + if (nh) + remove_nexthop(net, nh, &nlinfo); + else + err = -ENOENT; - remove_nexthop(net, nh, &nlinfo); + rtnl_net_unlock(net); - return 0; + return err; } /* rtnl */ static int rtm_get_nexthop(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_get)]; struct net *net = sock_net(in_skb->sk); struct sk_buff *skb = NULL; struct nexthop *nh; + u32 op_flags; int err; u32 id; - err = nh_valid_get_del_req(nlh, &id, extack); + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_get) - 1, rtm_nh_policy_get, + extack); + if (err < 0) + return err; + + err = nh_valid_get_del_req(nlh, tb, &id, &op_flags, extack); if (err) return err; @@ -3034,7 +3383,7 @@ static int rtm_get_nexthop(struct sk_buff *in_skb, struct nlmsghdr *nlh, goto errout_free; err = nh_fill_node(skb, nh, RTM_NEWNEXTHOP, NETLINK_CB(in_skb).portid, - nlh->nlmsg_seq, 0); + nlh->nlmsg_seq, 0, op_flags); if (err < 0) { WARN_ON(err == -EMSGSIZE); goto errout_free; @@ -3055,6 +3404,7 @@ struct nh_dump_filter { bool group_filter; bool fdb_filter; u32 res_bucket_nh_id; + u32 op_flags; }; static bool nh_dump_filtered(struct nexthop *nh, @@ -3142,6 +3492,8 @@ static int nh_valid_dump_req(const struct nlmsghdr *nlh, if (err < 0) return err; + filter->op_flags = nla_get_u32_default(tb[NHA_OP_FLAGS], 0); + return __nh_valid_dump_req(nlh, tb, filter, cb->extack); } @@ -3172,12 +3524,42 @@ static int rtm_dump_walk_nexthops(struct sk_buff *skb, int err; s_idx = ctx->idx; - for (node = rb_first(root); node; node = rb_next(node)) { + + /* If this is not the first invocation, ctx->idx will contain the id of + * the last nexthop we processed. Instead of starting from the very + * first element of the red/black tree again and linearly skipping the + * (potentially large) set of nodes with an id smaller than s_idx, walk + * the tree and find the left-most node whose id is >= s_idx. This + * provides an efficient O(log n) starting point for the dump + * continuation. + */ + if (s_idx != 0) { + struct rb_node *tmp = root->rb_node; + + node = NULL; + while (tmp) { + struct nexthop *nh; + + nh = rb_entry(tmp, struct nexthop, rb_node); + if (nh->id < s_idx) { + tmp = tmp->rb_right; + } else { + /* Track current candidate and keep looking on + * the left side to find the left-most + * (smallest id) that is still >= s_idx. + */ + node = tmp; + tmp = tmp->rb_left; + } + } + } else { + node = rb_first(root); + } + + for (; node; node = rb_next(node)) { struct nexthop *nh; nh = rb_entry(node, struct nexthop, rb_node); - if (nh->id < s_idx) - continue; ctx->idx = nh->id; err = nh_cb(skb, cb, nh, data); @@ -3185,7 +3567,6 @@ static int rtm_dump_walk_nexthops(struct sk_buff *skb, return err; } - ctx->idx++; return 0; } @@ -3200,7 +3581,7 @@ static int rtm_dump_nexthop_cb(struct sk_buff *skb, struct netlink_callback *cb, return nh_fill_node(skb, nh, RTM_NEWNEXTHOP, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, NLM_F_MULTI); + cb->nlh->nlmsg_seq, NLM_F_MULTI, filter->op_flags); } /* rtnl */ @@ -3218,15 +3599,7 @@ static int rtm_dump_nexthop(struct sk_buff *skb, struct netlink_callback *cb) err = rtm_dump_walk_nexthops(skb, cb, root, ctx, &rtm_dump_nexthop_cb, &filter); - if (err < 0) { - if (likely(skb->len)) - goto out; - goto out_err; - } -out: - err = skb->len; -out_err: cb->seq = net->nexthop.seq; nl_dump_check_consistent(cb, nlmsg_hdr(skb)); return err; @@ -3317,7 +3690,6 @@ static int nh_valid_dump_bucket_req(const struct nlmsghdr *nlh, struct rtm_dump_res_bucket_ctx { struct rtm_dump_nh_ctx nh; u16 bucket_index; - u32 done_nh_idx; /* 1 + the index of the last fully processed NH. */ }; static struct rtm_dump_res_bucket_ctx * @@ -3346,9 +3718,6 @@ static int rtm_dump_nexthop_bucket_nh(struct sk_buff *skb, u16 bucket_index; int err; - if (dd->ctx->nh.idx < dd->ctx->done_nh_idx) - return 0; - nhg = rtnl_dereference(nh->nh_grp); res_table = rtnl_dereference(nhg->res_table); for (bucket_index = dd->ctx->bucket_index; @@ -3366,25 +3735,18 @@ static int rtm_dump_nexthop_bucket_nh(struct sk_buff *skb, dd->filter.res_bucket_nh_id != nhge->nh->id) continue; + dd->ctx->bucket_index = bucket_index; err = nh_fill_res_bucket(skb, nh, bucket, bucket_index, RTM_NEWNEXTHOPBUCKET, portid, cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->extack); - if (err < 0) { - if (likely(skb->len)) - goto out; - goto out_err; - } + if (err) + return err; } - dd->ctx->done_nh_idx = dd->ctx->nh.idx + 1; - bucket_index = 0; + dd->ctx->bucket_index = 0; -out: - err = skb->len; -out_err: - dd->ctx->bucket_index = bucket_index; - return err; + return 0; } static int rtm_dump_nexthop_bucket_cb(struct sk_buff *skb, @@ -3431,15 +3793,6 @@ static int rtm_dump_nexthop_bucket(struct sk_buff *skb, &rtm_dump_nexthop_bucket_cb, &dd); } - if (err < 0) { - if (likely(skb->len)) - goto out; - goto out_err; - } - -out: - err = skb->len; -out_err: cb->seq = net->nexthop.seq; nl_dump_check_consistent(cb, nlmsg_hdr(skb)); return err; @@ -3479,7 +3832,7 @@ static int nh_valid_get_bucket_req(const struct nlmsghdr *nlh, if (err < 0) return err; - err = __nh_valid_get_del_req(nlh, tb, id, extack); + err = nh_valid_get_del_req(nlh, tb, id, NULL, extack); if (err) return err; @@ -3574,7 +3927,7 @@ static int nh_netdev_event(struct notifier_block *this, nexthop_flush_dev(dev, event); break; case NETDEV_CHANGE: - if (!(dev_get_flags(dev) & (IFF_RUNNING | IFF_LOWER_UP))) + if (!(netif_get_flags(dev) & (IFF_RUNNING | IFF_LOWER_UP))) nexthop_flush_dev(dev, event); break; case NETDEV_CHANGEMTU: @@ -3627,17 +3980,24 @@ unlock: } EXPORT_SYMBOL(register_nexthop_notifier); -int unregister_nexthop_notifier(struct net *net, struct notifier_block *nb) +int __unregister_nexthop_notifier(struct net *net, struct notifier_block *nb) { int err; - rtnl_lock(); err = blocking_notifier_chain_unregister(&net->nexthop.notifier_chain, nb); - if (err) - goto unlock; - nexthops_dump(net, nb, NEXTHOP_EVENT_DEL, NULL); -unlock: + if (!err) + nexthops_dump(net, nb, NEXTHOP_EVENT_DEL, NULL); + return err; +} +EXPORT_SYMBOL(__unregister_nexthop_notifier); + +int unregister_nexthop_notifier(struct net *net, struct notifier_block *nb) +{ + int err; + + rtnl_lock(); + err = __unregister_nexthop_notifier(net, nb); rtnl_unlock(); return err; } @@ -3733,12 +4093,17 @@ out: } EXPORT_SYMBOL(nexthop_res_grp_activity_update); -static void __net_exit nexthop_net_exit(struct net *net) +static void __net_exit nexthop_net_exit_rtnl(struct net *net, + struct list_head *dev_to_kill) { - rtnl_lock(); + ASSERT_RTNL_NET(net); flush_all_nexthops(net); - rtnl_unlock(); +} + +static void __net_exit nexthop_net_exit(struct net *net) +{ kfree(net->nexthop.devhash); + net->nexthop.devhash = NULL; } static int __net_init nexthop_net_init(struct net *net) @@ -3757,6 +4122,26 @@ static int __net_init nexthop_net_init(struct net *net) static struct pernet_operations nexthop_net_ops = { .init = nexthop_net_init, .exit = nexthop_net_exit, + .exit_rtnl = nexthop_net_exit_rtnl, +}; + +static const struct rtnl_msg_handler nexthop_rtnl_msg_handlers[] __initconst = { + {.msgtype = RTM_NEWNEXTHOP, .doit = rtm_new_nexthop, + .flags = RTNL_FLAG_DOIT_PERNET}, + {.msgtype = RTM_DELNEXTHOP, .doit = rtm_del_nexthop, + .flags = RTNL_FLAG_DOIT_PERNET}, + {.msgtype = RTM_GETNEXTHOP, .doit = rtm_get_nexthop, + .dumpit = rtm_dump_nexthop}, + {.msgtype = RTM_GETNEXTHOPBUCKET, .doit = rtm_get_nexthop_bucket, + .dumpit = rtm_dump_nexthop_bucket}, + {.protocol = PF_INET, .msgtype = RTM_NEWNEXTHOP, + .doit = rtm_new_nexthop, .flags = RTNL_FLAG_DOIT_PERNET}, + {.protocol = PF_INET, .msgtype = RTM_GETNEXTHOP, + .dumpit = rtm_dump_nexthop}, + {.protocol = PF_INET6, .msgtype = RTM_NEWNEXTHOP, + .doit = rtm_new_nexthop, .flags = RTNL_FLAG_DOIT_PERNET}, + {.protocol = PF_INET6, .msgtype = RTM_GETNEXTHOP, + .dumpit = rtm_dump_nexthop}, }; static int __init nexthop_init(void) @@ -3765,19 +4150,7 @@ static int __init nexthop_init(void) register_netdevice_notifier(&nh_netdev_notifier); - rtnl_register(PF_UNSPEC, RTM_NEWNEXTHOP, rtm_new_nexthop, NULL, 0); - rtnl_register(PF_UNSPEC, RTM_DELNEXTHOP, rtm_del_nexthop, NULL, 0); - rtnl_register(PF_UNSPEC, RTM_GETNEXTHOP, rtm_get_nexthop, - rtm_dump_nexthop, 0); - - rtnl_register(PF_INET, RTM_NEWNEXTHOP, rtm_new_nexthop, NULL, 0); - rtnl_register(PF_INET, RTM_GETNEXTHOP, NULL, rtm_dump_nexthop, 0); - - rtnl_register(PF_INET6, RTM_NEWNEXTHOP, rtm_new_nexthop, NULL, 0); - rtnl_register(PF_INET6, RTM_GETNEXTHOP, NULL, rtm_dump_nexthop, 0); - - rtnl_register(PF_UNSPEC, RTM_GETNEXTHOPBUCKET, rtm_get_nexthop_bucket, - rtm_dump_nexthop_bucket, 0); + rtnl_register_many(nexthop_rtnl_msg_handlers); return 0; } diff --git a/net/ipv4/ping.c b/net/ipv4/ping.c index 0e56df3a45e2..ad56588107cc 100644 --- a/net/ipv4/ping.c +++ b/net/ipv4/ping.c @@ -33,6 +33,7 @@ #include <linux/skbuff.h> #include <linux/proc_fs.h> #include <linux/export.h> +#include <linux/bpf-cgroup.h> #include <net/sock.h> #include <net/ping.h> #include <net/udp.h> @@ -49,15 +50,13 @@ #endif struct ping_table { - struct hlist_nulls_head hash[PING_HTABLE_SIZE]; - rwlock_t lock; + struct hlist_head hash[PING_HTABLE_SIZE]; + spinlock_t lock; }; static struct ping_table ping_table; struct pingv6_ops pingv6_ops; -EXPORT_SYMBOL_GPL(pingv6_ops); - -static u16 ping_port_rover; +EXPORT_IPV6_MOD_GPL(pingv6_ops); static inline u32 ping_hashfn(const struct net *net, u32 num, u32 mask) { @@ -66,33 +65,33 @@ static inline u32 ping_hashfn(const struct net *net, u32 num, u32 mask) pr_debug("hash(%u) = %u\n", num, res); return res; } -EXPORT_SYMBOL_GPL(ping_hash); -static inline struct hlist_nulls_head *ping_hashslot(struct ping_table *table, - struct net *net, unsigned int num) +static inline struct hlist_head *ping_hashslot(struct ping_table *table, + struct net *net, unsigned int num) { return &table->hash[ping_hashfn(net, num, PING_HTABLE_MASK)]; } int ping_get_port(struct sock *sk, unsigned short ident) { - struct hlist_nulls_node *node; - struct hlist_nulls_head *hlist; + struct net *net = sock_net(sk); struct inet_sock *isk, *isk2; + struct hlist_head *hlist; struct sock *sk2 = NULL; isk = inet_sk(sk); - write_lock_bh(&ping_table.lock); + spin_lock(&ping_table.lock); if (ident == 0) { + u16 result = net->ipv4.ping_port_rover + 1; u32 i; - u16 result = ping_port_rover + 1; for (i = 0; i < (1L << 16); i++, result++) { if (!result) - result++; /* avoid zero */ - hlist = ping_hashslot(&ping_table, sock_net(sk), - result); - ping_portaddr_for_each_entry(sk2, node, hlist) { + continue; /* avoid zero */ + hlist = ping_hashslot(&ping_table, net, result); + sk_for_each(sk2, hlist) { + if (!net_eq(sock_net(sk2), net)) + continue; isk2 = inet_sk(sk2); if (isk2->inet_num == result) @@ -100,7 +99,7 @@ int ping_get_port(struct sock *sk, unsigned short ident) } /* found */ - ping_port_rover = ident = result; + net->ipv4.ping_port_rover = ident = result; break; next_port: ; @@ -108,8 +107,10 @@ next_port: if (i >= (1L << 16)) goto fail; } else { - hlist = ping_hashslot(&ping_table, sock_net(sk), ident); - ping_portaddr_for_each_entry(sk2, node, hlist) { + hlist = ping_hashslot(&ping_table, net, ident); + sk_for_each(sk2, hlist) { + if (!net_eq(sock_net(sk2), net)) + continue; isk2 = inet_sk(sk2); /* BUG? Why is this reuse and not reuseaddr? ping.c @@ -127,66 +128,61 @@ next_port: isk->inet_num = ident; if (sk_unhashed(sk)) { pr_debug("was not hashed\n"); - sock_hold(sk); - hlist_nulls_add_head(&sk->sk_nulls_node, hlist); - sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + sk_add_node_rcu(sk, hlist); + sock_set_flag(sk, SOCK_RCU_FREE); + sock_prot_inuse_add(net, sk->sk_prot, 1); } - write_unlock_bh(&ping_table.lock); + spin_unlock(&ping_table.lock); return 0; fail: - write_unlock_bh(&ping_table.lock); - return 1; -} -EXPORT_SYMBOL_GPL(ping_get_port); - -int ping_hash(struct sock *sk) -{ - pr_debug("ping_hash(sk->port=%u)\n", inet_sk(sk)->inet_num); - BUG(); /* "Please do not press this button again." */ - - return 0; + spin_unlock(&ping_table.lock); + return -EADDRINUSE; } +EXPORT_IPV6_MOD_GPL(ping_get_port); void ping_unhash(struct sock *sk) { struct inet_sock *isk = inet_sk(sk); pr_debug("ping_unhash(isk=%p,isk->num=%u)\n", isk, isk->inet_num); - write_lock_bh(&ping_table.lock); - if (sk_hashed(sk)) { - hlist_nulls_del(&sk->sk_nulls_node); - sk_nulls_node_init(&sk->sk_nulls_node); - sock_put(sk); + spin_lock(&ping_table.lock); + if (sk_del_node_init_rcu(sk)) { isk->inet_num = 0; isk->inet_sport = 0; sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); } - write_unlock_bh(&ping_table.lock); + spin_unlock(&ping_table.lock); } -EXPORT_SYMBOL_GPL(ping_unhash); +EXPORT_IPV6_MOD_GPL(ping_unhash); +/* Called under rcu_read_lock() */ static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident) { - struct hlist_nulls_head *hslot = ping_hashslot(&ping_table, net, ident); + struct hlist_head *hslot = ping_hashslot(&ping_table, net, ident); struct sock *sk = NULL; struct inet_sock *isk; - struct hlist_nulls_node *hnode; - int dif = skb->dev->ifindex; + int dif, sdif; if (skb->protocol == htons(ETH_P_IP)) { + dif = inet_iif(skb); + sdif = inet_sdif(skb); pr_debug("try to find: num = %d, daddr = %pI4, dif = %d\n", (int)ident, &ip_hdr(skb)->daddr, dif); #if IS_ENABLED(CONFIG_IPV6) } else if (skb->protocol == htons(ETH_P_IPV6)) { + dif = inet6_iif(skb); + sdif = inet6_sdif(skb); pr_debug("try to find: num = %d, daddr = %pI6c, dif = %d\n", (int)ident, &ipv6_hdr(skb)->daddr, dif); #endif + } else { + return NULL; } - read_lock_bh(&ping_table.lock); - - ping_portaddr_for_each_entry(sk, hnode, hslot) { + sk_for_each_rcu(sk, hslot) { + if (!net_eq(sock_net(sk), net)) + continue; isk = inet_sk(sk); pr_debug("iterate\n"); @@ -220,16 +216,15 @@ static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident) continue; } - if (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif) + if (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif && + sk->sk_bound_dev_if != sdif) continue; - sock_hold(sk); goto exit; } sk = NULL; exit: - read_unlock_bh(&ping_table.lock); return sk; } @@ -279,7 +274,7 @@ out_release_group: put_group_info(group_info); return ret; } -EXPORT_SYMBOL_GPL(ping_init_sock); +EXPORT_IPV6_MOD_GPL(ping_init_sock); void ping_close(struct sock *sk, long timeout) { @@ -289,15 +284,29 @@ void ping_close(struct sock *sk, long timeout) sk_common_release(sk); } -EXPORT_SYMBOL_GPL(ping_close); +EXPORT_IPV6_MOD_GPL(ping_close); + +static int ping_pre_connect(struct sock *sk, struct sockaddr_unsized *uaddr, + int addr_len) +{ + /* This check is replicated from __ip4_datagram_connect() and + * intended to prevent BPF program called below from accessing bytes + * that are out of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr, &addr_len); +} /* Checks the bind address and possibly modifies sk->sk_bound_dev_if. */ static int ping_check_bind_addr(struct sock *sk, struct inet_sock *isk, - struct sockaddr *uaddr, int addr_len) + struct sockaddr_unsized *uaddr, int addr_len) { struct net *net = sock_net(sk); if (sk->sk_family == AF_INET) { struct sockaddr_in *addr = (struct sockaddr_in *) uaddr; + u32 tb_id = RT_TABLE_LOCAL; int chk_addr_ret; if (addr_len < sizeof(*addr)) @@ -311,11 +320,16 @@ static int ping_check_bind_addr(struct sock *sk, struct inet_sock *isk, pr_debug("ping_check_bind_addr(sk=%p,addr=%pI4,port=%d)\n", sk, &addr->sin_addr.s_addr, ntohs(addr->sin_port)); - chk_addr_ret = inet_addr_type(net, addr->sin_addr.s_addr); + if (addr->sin_addr.s_addr == htonl(INADDR_ANY)) + return 0; + + tb_id = l3mdev_fib_table_by_index(net, sk->sk_bound_dev_if) ? : tb_id; + chk_addr_ret = inet_addr_type_table(net, addr->sin_addr.s_addr, tb_id); - if (!inet_addr_valid_or_nonlocal(net, inet_sk(sk), - addr->sin_addr.s_addr, - chk_addr_ret)) + if (chk_addr_ret == RTN_MULTICAST || + chk_addr_ret == RTN_BROADCAST || + (chk_addr_ret != RTN_LOCAL && + !inet_can_nonlocal_bind(net, isk))) return -EADDRNOTAVAIL; #if IS_ENABLED(CONFIG_IPV6) @@ -348,6 +362,14 @@ static int ping_check_bind_addr(struct sock *sk, struct inet_sock *isk, return -ENODEV; } } + + if (!dev && sk->sk_bound_dev_if) { + dev = dev_get_by_index_rcu(net, sk->sk_bound_dev_if); + if (!dev) { + rcu_read_unlock(); + return -ENODEV; + } + } has_addr = pingv6_ops.ipv6_chk_addr(net, &addr->sin6_addr, dev, scoped); rcu_read_unlock(); @@ -365,7 +387,7 @@ static int ping_check_bind_addr(struct sock *sk, struct inet_sock *isk, return 0; } -static void ping_set_saddr(struct sock *sk, struct sockaddr *saddr) +static void ping_set_saddr(struct sock *sk, struct sockaddr_unsized *saddr) { if (saddr->sa_family == AF_INET) { struct inet_sock *isk = inet_sk(sk); @@ -385,7 +407,7 @@ static void ping_set_saddr(struct sock *sk, struct sockaddr *saddr) * Moreover, we don't allow binding to multi- and broadcast addresses. */ -int ping_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +int ping_bind(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len) { struct inet_sock *isk = inet_sk(sk); unsigned short snum; @@ -440,7 +462,7 @@ out: pr_debug("ping_v4_bind -> %d\n", err); return err; } -EXPORT_SYMBOL_GPL(ping_bind); +EXPORT_IPV6_MOD_GPL(ping_bind); /* * Is this a supported type of ICMP message? @@ -524,7 +546,7 @@ void ping_err(struct sk_buff *skb, int offset, u32 info) case ICMP_DEST_UNREACH: if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */ ipv4_sk_update_pmtu(skb, sk, info); - if (inet_sock->pmtudisc != IP_PMTUDISC_DONT) { + if (READ_ONCE(inet_sock->pmtudisc) != IP_PMTUDISC_DONT) { err = EMSGSIZE; harderr = 1; break; @@ -553,8 +575,8 @@ void ping_err(struct sk_buff *skb, int offset, u32 info) * RFC1122: OK. Passes ICMP errors back to application, as per * 4.1.3.3. */ - if ((family == AF_INET && !inet_sock->recverr) || - (family == AF_INET6 && !inet6_sk(sk)->recverr)) { + if ((family == AF_INET && !inet_test_bit(RECVERR, sk)) || + (family == AF_INET6 && !inet6_test_bit(RECVERR6, sk))) { if (!harderr || sk->sk_state != TCP_ESTABLISHED) goto out; } else { @@ -571,9 +593,9 @@ void ping_err(struct sk_buff *skb, int offset, u32 info) sk->sk_err = err; sk_error_report(sk); out: - sock_put(sk); + return; } -EXPORT_SYMBOL_GPL(ping_err); +EXPORT_IPV6_MOD_GPL(ping_err); /* * Copy and checksum an ICMP Echo packet from user space into a buffer @@ -583,23 +605,11 @@ EXPORT_SYMBOL_GPL(ping_err); int ping_getfrag(void *from, char *to, int offset, int fraglen, int odd, struct sk_buff *skb) { - struct pingfakehdr *pfh = (struct pingfakehdr *)from; - - if (offset == 0) { - fraglen -= sizeof(struct icmphdr); - if (fraglen < 0) - BUG(); - if (!csum_and_copy_from_iter_full(to + sizeof(struct icmphdr), - fraglen, &pfh->wcheck, - &pfh->msg->msg_iter)) - return -EFAULT; - } else if (offset < sizeof(struct icmphdr)) { - BUG(); - } else { - if (!csum_and_copy_from_iter_full(to, fraglen, &pfh->wcheck, - &pfh->msg->msg_iter)) - return -EFAULT; - } + struct pingfakehdr *pfh = from; + + if (!csum_and_copy_from_iter_full(to, fraglen, &pfh->wcheck, + &pfh->msg->msg_iter)) + return -EFAULT; #if IS_ENABLED(CONFIG_IPV6) /* For IPv6, checksum each skb as we go along, as expected by @@ -607,7 +617,7 @@ int ping_getfrag(void *from, char *to, * wcheck, it will be finalized in ping_v4_push_pending_frames. */ if (pfh->family == AF_INET6) { - skb->csum = pfh->wcheck; + skb->csum = csum_block_add(skb->csum, pfh->wcheck, odd); skb->ip_summed = CHECKSUM_NONE; pfh->wcheck = 0; } @@ -615,7 +625,7 @@ int ping_getfrag(void *from, char *to, return 0; } -EXPORT_SYMBOL_GPL(ping_getfrag); +EXPORT_IPV6_MOD_GPL(ping_getfrag); static int ping_v4_push_pending_frames(struct sock *sk, struct pingfakehdr *pfh, struct flowi4 *fl4) @@ -676,7 +686,7 @@ int ping_common_sendmsg(int family, struct msghdr *msg, size_t len, return 0; } -EXPORT_SYMBOL_GPL(ping_common_sendmsg); +EXPORT_IPV6_MOD_GPL(ping_common_sendmsg); static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) { @@ -690,7 +700,7 @@ static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) struct ip_options_data opt_copy; int free = 0; __be32 saddr, daddr, faddr; - u8 tos; + u8 scope; int err; pr_debug("ping_v4_sendmsg(sk=%p,sk->num=%u)\n", inet, inet->inet_num); @@ -753,25 +763,20 @@ static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) } faddr = ipc.opt->opt.faddr; } - tos = get_rttos(&ipc, inet); - if (sock_flag(sk, SOCK_LOCALROUTE) || - (msg->msg_flags & MSG_DONTROUTE) || - (ipc.opt && ipc.opt->opt.is_strictroute)) { - tos |= RTO_ONLINK; - } + scope = ip_sendmsg_scope(inet, &ipc, msg); if (ipv4_is_multicast(daddr)) { if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif)) - ipc.oif = inet->mc_index; + ipc.oif = READ_ONCE(inet->mc_index); if (!saddr) - saddr = inet->mc_addr; + saddr = READ_ONCE(inet->mc_addr); } else if (!ipc.oif) - ipc.oif = inet->uc_index; + ipc.oif = READ_ONCE(inet->uc_index); - flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, - RT_SCOPE_UNIVERSE, sk->sk_protocol, - inet_sk_flowi_flags(sk), faddr, saddr, 0, 0, - sk->sk_uid); + flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, + ipc.tos & INET_DSCP_MASK, scope, + sk->sk_protocol, inet_sk_flowi_flags(sk), faddr, + saddr, 0, 0, sk_uid(sk)); fl4.fl4_icmp_type = user_icmph.type; fl4.fl4_icmp_code = user_icmph.code; @@ -810,7 +815,8 @@ back_from_confirm: pfh.family = AF_INET; err = ip_append_data(sk, &fl4, ping_getfrag, &pfh, len, - 0, &ipc, &rt, msg->msg_flags); + sizeof(struct icmphdr), &ipc, &rt, + msg->msg_flags); if (err) ip_flush_pending_frames(sk); else @@ -837,8 +843,8 @@ do_confirm: goto out; } -int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, - int flags, int *addr_len) +int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) { struct inet_sock *isk = inet_sk(sk); int family = sk->sk_family; @@ -854,7 +860,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, if (flags & MSG_ERRQUEUE) return inet_recv_error(sk, msg, len, addr_len); - skb = skb_recv_datagram(sk, flags, noblock, &err); + skb = skb_recv_datagram(sk, flags, &err); if (!skb) goto out; @@ -883,12 +889,11 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, *addr_len = sizeof(*sin); } - if (isk->cmsg_flags) + if (inet_cmsg_flags(isk)) ip_cmsg_recv(msg, skb); #if IS_ENABLED(CONFIG_IPV6) } else if (family == AF_INET6) { - struct ipv6_pinfo *np = inet6_sk(sk); struct ipv6hdr *ip6 = ipv6_hdr(skb); DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); @@ -897,7 +902,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, sin6->sin6_port = 0; sin6->sin6_addr = ip6->saddr; sin6->sin6_flowinfo = 0; - if (np->sndflow) + if (inet6_test_bit(SNDFLOW, sk)) sin6->sin6_flowinfo = ip6_flowinfo(ip6); sin6->sin6_scope_id = ipv6_iface_scope_id(&sin6->sin6_addr, @@ -910,7 +915,8 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, if (skb->protocol == htons(ETH_P_IPV6) && inet6_sk(sk)->rxopt.all) pingv6_ops.ip6_datagram_recv_specific_ctl(sk, msg, skb); - else if (skb->protocol == htons(ETH_P_IP) && isk->cmsg_flags) + else if (skb->protocol == htons(ETH_P_IP) && + inet_cmsg_flags(isk)) ip_cmsg_recv(msg, skb); #endif } else { @@ -925,32 +931,39 @@ out: pr_debug("ping_recvmsg -> %d\n", err); return err; } -EXPORT_SYMBOL_GPL(ping_recvmsg); +EXPORT_IPV6_MOD_GPL(ping_recvmsg); -int ping_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +static enum skb_drop_reason __ping_queue_rcv_skb(struct sock *sk, + struct sk_buff *skb) { + enum skb_drop_reason reason; + pr_debug("ping_queue_rcv_skb(sk=%p,sk->num=%d,skb=%p)\n", inet_sk(sk), inet_sk(sk)->inet_num, skb); - if (sock_queue_rcv_skb(sk, skb) < 0) { - kfree_skb(skb); + if (sock_queue_rcv_skb_reason(sk, skb, &reason) < 0) { + sk_skb_reason_drop(sk, skb, reason); pr_debug("ping_queue_rcv_skb -> failed\n"); - return -1; + return reason; } - return 0; + return SKB_NOT_DROPPED_YET; +} + +int ping_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + return __ping_queue_rcv_skb(sk, skb) ? -1 : 0; } -EXPORT_SYMBOL_GPL(ping_queue_rcv_skb); +EXPORT_IPV6_MOD_GPL(ping_queue_rcv_skb); /* * All we need to do is get the socket. */ -bool ping_rcv(struct sk_buff *skb) +enum skb_drop_reason ping_rcv(struct sk_buff *skb) { - struct sock *sk; struct net *net = dev_net(skb->dev); struct icmphdr *icmph = icmp_hdr(skb); - bool rc = false; + struct sock *sk; /* We assume the packet has already been checked by icmp_rcv */ @@ -961,27 +974,20 @@ bool ping_rcv(struct sk_buff *skb) skb_push(skb, skb->data - (u8 *)icmph); sk = ping_lookup(net, skb, ntohs(icmph->un.echo.id)); - if (sk) { - struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); - - pr_debug("rcv on socket %p\n", sk); - if (skb2 && !ping_queue_rcv_skb(sk, skb2)) - rc = true; - sock_put(sk); - } - - if (!rc) - pr_debug("no socket, dropping\n"); + if (sk) + return __ping_queue_rcv_skb(sk, skb); - return rc; + kfree_skb_reason(skb, SKB_DROP_REASON_NO_SOCKET); + return SKB_DROP_REASON_NO_SOCKET; } -EXPORT_SYMBOL_GPL(ping_rcv); +EXPORT_IPV6_MOD_GPL(ping_rcv); struct proto ping_prot = { .name = "PING", .owner = THIS_MODULE, .init = ping_init_sock, .close = ping_close, + .pre_connect = ping_pre_connect, .connect = ip4_datagram_connect, .disconnect = __udp_disconnect, .setsockopt = ip_setsockopt, @@ -991,13 +997,12 @@ struct proto ping_prot = { .bind = ping_bind, .backlog_rcv = ping_queue_rcv_skb, .release_cb = ip4_datagram_release_cb, - .hash = ping_hash, .unhash = ping_unhash, .get_port = ping_get_port, .put_port = ping_unhash, .obj_size = sizeof(struct inet_sock), }; -EXPORT_SYMBOL(ping_prot); +EXPORT_IPV6_MOD(ping_prot); #ifdef CONFIG_PROC_FS @@ -1009,15 +1014,14 @@ static struct sock *ping_get_first(struct seq_file *seq, int start) for (state->bucket = start; state->bucket < PING_HTABLE_SIZE; ++state->bucket) { - struct hlist_nulls_node *node; - struct hlist_nulls_head *hslot; + struct hlist_head *hslot; hslot = &ping_table.hash[state->bucket]; - if (hlist_nulls_empty(hslot)) + if (hlist_empty(hslot)) continue; - sk_nulls_for_each(sk, node, hslot) { + sk_for_each(sk, hslot) { if (net_eq(sock_net(sk), net) && sk->sk_family == state->family) goto found; @@ -1034,7 +1038,7 @@ static struct sock *ping_get_next(struct seq_file *seq, struct sock *sk) struct net *net = seq_file_net(seq); do { - sk = sk_nulls_next(sk); + sk = sk_next(sk); } while (sk && (!net_eq(sock_net(sk), net))); if (!sk) @@ -1059,11 +1063,11 @@ void *ping_seq_start(struct seq_file *seq, loff_t *pos, sa_family_t family) state->bucket = 0; state->family = family; - read_lock_bh(&ping_table.lock); + spin_lock(&ping_table.lock); return *pos ? ping_get_idx(seq, *pos-1) : SEQ_START_TOKEN; } -EXPORT_SYMBOL_GPL(ping_seq_start); +EXPORT_IPV6_MOD_GPL(ping_seq_start); static void *ping_v4_seq_start(struct seq_file *seq, loff_t *pos) { @@ -1082,14 +1086,14 @@ void *ping_seq_next(struct seq_file *seq, void *v, loff_t *pos) ++*pos; return sk; } -EXPORT_SYMBOL_GPL(ping_seq_next); +EXPORT_IPV6_MOD_GPL(ping_seq_next); void ping_seq_stop(struct seq_file *seq, void *v) __releases(ping_table.lock) { - read_unlock_bh(&ping_table.lock); + spin_unlock(&ping_table.lock); } -EXPORT_SYMBOL_GPL(ping_seq_stop); +EXPORT_IPV6_MOD_GPL(ping_seq_stop); static void ping_v4_format_sock(struct sock *sp, struct seq_file *f, int bucket) @@ -1106,10 +1110,10 @@ static void ping_v4_format_sock(struct sock *sp, struct seq_file *f, sk_wmem_alloc_get(sp), sk_rmem_alloc_get(sp), 0, 0L, 0, - from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)), + from_kuid_munged(seq_user_ns(f), sk_uid(sp)), 0, sock_i_ino(sp), refcount_read(&sp->sk_refcnt), sp, - atomic_read(&sp->sk_drops)); + sk_drops_read(sp)); } static int ping_v4_seq_show(struct seq_file *seq, void *v) @@ -1140,6 +1144,8 @@ static int __net_init ping_v4_proc_init_net(struct net *net) if (!proc_create_net("icmp", 0444, net->proc_net, &ping_v4_seq_ops, sizeof(struct ping_iter_state))) return -ENOMEM; + + net->ipv4.ping_port_rover = get_random_u16(); return 0; } @@ -1170,6 +1176,6 @@ void __init ping_init(void) int i; for (i = 0; i < PING_HTABLE_SIZE; i++) - INIT_HLIST_NULLS_HEAD(&ping_table.hash[i], i); - rwlock_init(&ping_table.lock); + INIT_HLIST_HEAD(&ping_table.hash[i]); + spin_lock_init(&ping_table.lock); } diff --git a/net/ipv4/proc.c b/net/ipv4/proc.c index f30273afb539..974afc4ecbe2 100644 --- a/net/ipv4/proc.c +++ b/net/ipv4/proc.c @@ -33,6 +33,7 @@ #include <net/protocol.h> #include <net/tcp.h> #include <net/mptcp.h> +#include <net/proto_memory.h> #include <net/udp.h> #include <net/udplite.h> #include <linux/bottom_half.h> @@ -43,7 +44,7 @@ #include <net/sock.h> #include <net/raw.h> -#define TCPUDP_MIB_MAX max_t(u32, UDP_MIB_MAX, TCP_MIB_MAX) +#define TCPUDP_MIB_MAX MAX_T(u32, UDP_MIB_MAX, TCP_MIB_MAX) /* * Report socket allocation statistics [mea@utu.fi] @@ -59,8 +60,8 @@ static int sockstat_seq_show(struct seq_file *seq, void *v) socket_seq_show(seq); seq_printf(seq, "TCP: inuse %d orphan %d tw %d alloc %d mem %ld\n", sock_prot_inuse_get(net, &tcp_prot), orphans, - atomic_read(&net->ipv4.tcp_death_row.tw_count), sockets, - proto_memory_allocated(&tcp_prot)); + refcount_read(&net->ipv4.tcp_death_row.tw_refcount) - 1, + sockets, proto_memory_allocated(&tcp_prot)); seq_printf(seq, "UDP: inuse %d mem %ld\n", sock_prot_inuse_get(net, &udp_prot), proto_memory_allocated(&udp_prot)); @@ -83,7 +84,7 @@ static const struct snmp_mib snmp4_ipstats_list[] = { SNMP_MIB_ITEM("InUnknownProtos", IPSTATS_MIB_INUNKNOWNPROTOS), SNMP_MIB_ITEM("InDiscards", IPSTATS_MIB_INDISCARDS), SNMP_MIB_ITEM("InDelivers", IPSTATS_MIB_INDELIVERS), - SNMP_MIB_ITEM("OutRequests", IPSTATS_MIB_OUTPKTS), + SNMP_MIB_ITEM("OutRequests", IPSTATS_MIB_OUTREQUESTS), SNMP_MIB_ITEM("OutDiscards", IPSTATS_MIB_OUTDISCARDS), SNMP_MIB_ITEM("OutNoRoutes", IPSTATS_MIB_OUTNOROUTES), SNMP_MIB_ITEM("ReasmTimeout", IPSTATS_MIB_REASMTIMEOUT), @@ -93,7 +94,7 @@ static const struct snmp_mib snmp4_ipstats_list[] = { SNMP_MIB_ITEM("FragOKs", IPSTATS_MIB_FRAGOKS), SNMP_MIB_ITEM("FragFails", IPSTATS_MIB_FRAGFAILS), SNMP_MIB_ITEM("FragCreates", IPSTATS_MIB_FRAGCREATES), - SNMP_MIB_SENTINEL + SNMP_MIB_ITEM("OutTransmits", IPSTATS_MIB_OUTPKTS), }; /* Following items are displayed in /proc/net/netstat */ @@ -117,7 +118,6 @@ static const struct snmp_mib snmp4_ipextstats_list[] = { SNMP_MIB_ITEM("InECT0Pkts", IPSTATS_MIB_ECT0PKTS), SNMP_MIB_ITEM("InCEPkts", IPSTATS_MIB_CEPKTS), SNMP_MIB_ITEM("ReasmOverlaps", IPSTATS_MIB_REASM_OVERLAPS), - SNMP_MIB_SENTINEL }; static const struct { @@ -155,7 +155,6 @@ static const struct snmp_mib snmp4_tcp_list[] = { SNMP_MIB_ITEM("InErrs", TCP_MIB_INERRS), SNMP_MIB_ITEM("OutRsts", TCP_MIB_OUTRSTS), SNMP_MIB_ITEM("InCsumErrors", TCP_MIB_CSUMERRORS), - SNMP_MIB_SENTINEL }; static const struct snmp_mib snmp4_udp_list[] = { @@ -168,7 +167,6 @@ static const struct snmp_mib snmp4_udp_list[] = { SNMP_MIB_ITEM("InCsumErrors", UDP_MIB_CSUMERRORS), SNMP_MIB_ITEM("IgnoredMulti", UDP_MIB_IGNOREDMULTI), SNMP_MIB_ITEM("MemErrors", UDP_MIB_MEMERRORS), - SNMP_MIB_SENTINEL }; static const struct snmp_mib snmp4_net_list[] = { @@ -187,6 +185,10 @@ static const struct snmp_mib snmp4_net_list[] = { SNMP_MIB_ITEM("TWKilled", LINUX_MIB_TIMEWAITKILLED), SNMP_MIB_ITEM("PAWSActive", LINUX_MIB_PAWSACTIVEREJECTED), SNMP_MIB_ITEM("PAWSEstab", LINUX_MIB_PAWSESTABREJECTED), + SNMP_MIB_ITEM("BeyondWindow", LINUX_MIB_BEYOND_WINDOW), + SNMP_MIB_ITEM("TSEcrRejected", LINUX_MIB_TSECRREJECTED), + SNMP_MIB_ITEM("PAWSOldAck", LINUX_MIB_PAWS_OLD_ACK), + SNMP_MIB_ITEM("PAWSTimewait", LINUX_MIB_PAWS_TW_REJECTED), SNMP_MIB_ITEM("DelayedACKs", LINUX_MIB_DELAYEDACKS), SNMP_MIB_ITEM("DelayedACKLocked", LINUX_MIB_DELAYEDACKLOCKED), SNMP_MIB_ITEM("DelayedACKLost", LINUX_MIB_DELAYEDACKLOST), @@ -297,7 +299,12 @@ static const struct snmp_mib snmp4_net_list[] = { SNMP_MIB_ITEM("TCPDSACKIgnoredDubious", LINUX_MIB_TCPDSACKIGNOREDDUBIOUS), SNMP_MIB_ITEM("TCPMigrateReqSuccess", LINUX_MIB_TCPMIGRATEREQSUCCESS), SNMP_MIB_ITEM("TCPMigrateReqFailure", LINUX_MIB_TCPMIGRATEREQFAILURE), - SNMP_MIB_SENTINEL + SNMP_MIB_ITEM("TCPPLBRehash", LINUX_MIB_TCPPLBREHASH), + SNMP_MIB_ITEM("TCPAORequired", LINUX_MIB_TCPAOREQUIRED), + SNMP_MIB_ITEM("TCPAOBad", LINUX_MIB_TCPAOBAD), + SNMP_MIB_ITEM("TCPAOKeyNotFound", LINUX_MIB_TCPAOKEYNOTFOUND), + SNMP_MIB_ITEM("TCPAOGood", LINUX_MIB_TCPAOGOOD), + SNMP_MIB_ITEM("TCPAODroppedIcmps", LINUX_MIB_TCPAODROPPEDICMPS), }; static void icmpmsg_put_line(struct seq_file *seq, unsigned long *vals, @@ -352,7 +359,7 @@ static void icmp_put(struct seq_file *seq) seq_puts(seq, "\nIcmp: InMsgs InErrors InCsumErrors"); for (i = 0; icmpmibmap[i].name; i++) seq_printf(seq, " In%s", icmpmibmap[i].name); - seq_puts(seq, " OutMsgs OutErrors"); + seq_puts(seq, " OutMsgs OutErrors OutRateLimitGlobal OutRateLimitHost"); for (i = 0; icmpmibmap[i].name; i++) seq_printf(seq, " Out%s", icmpmibmap[i].name); seq_printf(seq, "\nIcmp: %lu %lu %lu", @@ -362,9 +369,11 @@ static void icmp_put(struct seq_file *seq) for (i = 0; icmpmibmap[i].name; i++) seq_printf(seq, " %lu", atomic_long_read(ptr + icmpmibmap[i].index)); - seq_printf(seq, " %lu %lu", + seq_printf(seq, " %lu %lu %lu %lu", snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_OUTMSGS), - snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_OUTERRORS)); + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_OUTERRORS), + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_RATELIMITGLOBAL), + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_RATELIMITHOST)); for (i = 0; icmpmibmap[i].name; i++) seq_printf(seq, " %lu", atomic_long_read(ptr + (icmpmibmap[i].index | 0x100))); @@ -375,25 +384,26 @@ static void icmp_put(struct seq_file *seq) */ static int snmp_seq_show_ipstats(struct seq_file *seq, void *v) { + const int cnt = ARRAY_SIZE(snmp4_ipstats_list); + u64 buff64[ARRAY_SIZE(snmp4_ipstats_list)]; struct net *net = seq->private; - u64 buff64[IPSTATS_MIB_MAX]; int i; - memset(buff64, 0, IPSTATS_MIB_MAX * sizeof(u64)); + memset(buff64, 0, sizeof(buff64)); seq_puts(seq, "Ip: Forwarding DefaultTTL"); - for (i = 0; snmp4_ipstats_list[i].name; i++) + for (i = 0; i < cnt; i++) seq_printf(seq, " %s", snmp4_ipstats_list[i].name); seq_printf(seq, "\nIp: %d %d", - IPV4_DEVCONF_ALL(net, FORWARDING) ? 1 : 2, - net->ipv4.sysctl_ip_default_ttl); + IPV4_DEVCONF_ALL_RO(net, FORWARDING) ? 1 : 2, + READ_ONCE(net->ipv4.sysctl_ip_default_ttl)); BUILD_BUG_ON(offsetof(struct ipstats_mib, mibs) != 0); - snmp_get_cpu_field64_batch(buff64, snmp4_ipstats_list, - net->mib.ip_statistics, - offsetof(struct ipstats_mib, syncp)); - for (i = 0; snmp4_ipstats_list[i].name; i++) + snmp_get_cpu_field64_batch_cnt(buff64, snmp4_ipstats_list, cnt, + net->mib.ip_statistics, + offsetof(struct ipstats_mib, syncp)); + for (i = 0; i < cnt; i++) seq_printf(seq, " %llu", buff64[i]); return 0; @@ -401,20 +411,23 @@ static int snmp_seq_show_ipstats(struct seq_file *seq, void *v) static int snmp_seq_show_tcp_udp(struct seq_file *seq, void *v) { + const int udp_cnt = ARRAY_SIZE(snmp4_udp_list); + const int tcp_cnt = ARRAY_SIZE(snmp4_tcp_list); unsigned long buff[TCPUDP_MIB_MAX]; struct net *net = seq->private; int i; - memset(buff, 0, TCPUDP_MIB_MAX * sizeof(unsigned long)); + memset(buff, 0, tcp_cnt * sizeof(unsigned long)); seq_puts(seq, "\nTcp:"); - for (i = 0; snmp4_tcp_list[i].name; i++) + for (i = 0; i < tcp_cnt; i++) seq_printf(seq, " %s", snmp4_tcp_list[i].name); seq_puts(seq, "\nTcp:"); - snmp_get_cpu_field_batch(buff, snmp4_tcp_list, - net->mib.tcp_statistics); - for (i = 0; snmp4_tcp_list[i].name; i++) { + snmp_get_cpu_field_batch_cnt(buff, snmp4_tcp_list, + tcp_cnt, + net->mib.tcp_statistics); + for (i = 0; i < tcp_cnt; i++) { /* MaxConn field is signed, RFC 2012 */ if (snmp4_tcp_list[i].entry == TCP_MIB_MAXCONN) seq_printf(seq, " %ld", buff[i]); @@ -422,27 +435,29 @@ static int snmp_seq_show_tcp_udp(struct seq_file *seq, void *v) seq_printf(seq, " %lu", buff[i]); } - memset(buff, 0, TCPUDP_MIB_MAX * sizeof(unsigned long)); + memset(buff, 0, udp_cnt * sizeof(unsigned long)); - snmp_get_cpu_field_batch(buff, snmp4_udp_list, - net->mib.udp_statistics); + snmp_get_cpu_field_batch_cnt(buff, snmp4_udp_list, + udp_cnt, + net->mib.udp_statistics); seq_puts(seq, "\nUdp:"); - for (i = 0; snmp4_udp_list[i].name; i++) + for (i = 0; i < udp_cnt; i++) seq_printf(seq, " %s", snmp4_udp_list[i].name); seq_puts(seq, "\nUdp:"); - for (i = 0; snmp4_udp_list[i].name; i++) + for (i = 0; i < udp_cnt; i++) seq_printf(seq, " %lu", buff[i]); - memset(buff, 0, TCPUDP_MIB_MAX * sizeof(unsigned long)); + memset(buff, 0, udp_cnt * sizeof(unsigned long)); /* the UDP and UDP-Lite MIBs are the same */ seq_puts(seq, "\nUdpLite:"); - snmp_get_cpu_field_batch(buff, snmp4_udp_list, - net->mib.udplite_statistics); - for (i = 0; snmp4_udp_list[i].name; i++) + snmp_get_cpu_field_batch_cnt(buff, snmp4_udp_list, + udp_cnt, + net->mib.udplite_statistics); + for (i = 0; i < udp_cnt; i++) seq_printf(seq, " %s", snmp4_udp_list[i].name); seq_puts(seq, "\nUdpLite:"); - for (i = 0; snmp4_udp_list[i].name; i++) + for (i = 0; i < udp_cnt; i++) seq_printf(seq, " %lu", buff[i]); seq_putc(seq, '\n'); @@ -466,8 +481,8 @@ static int snmp_seq_show(struct seq_file *seq, void *v) */ static int netstat_seq_show(struct seq_file *seq, void *v) { - const int ip_cnt = ARRAY_SIZE(snmp4_ipextstats_list) - 1; - const int tcp_cnt = ARRAY_SIZE(snmp4_net_list) - 1; + const int ip_cnt = ARRAY_SIZE(snmp4_ipextstats_list); + const int tcp_cnt = ARRAY_SIZE(snmp4_net_list); struct net *net = seq->private; unsigned long *buff; int i; @@ -480,8 +495,8 @@ static int netstat_seq_show(struct seq_file *seq, void *v) buff = kzalloc(max(tcp_cnt * sizeof(long), ip_cnt * sizeof(u64)), GFP_KERNEL); if (buff) { - snmp_get_cpu_field_batch(buff, snmp4_net_list, - net->mib.net_statistics); + snmp_get_cpu_field_batch_cnt(buff, snmp4_net_list, tcp_cnt, + net->mib.net_statistics); for (i = 0; i < tcp_cnt; i++) seq_printf(seq, " %lu", buff[i]); } else { @@ -499,7 +514,7 @@ static int netstat_seq_show(struct seq_file *seq, void *v) u64 *buff64 = (u64 *)buff; memset(buff64, 0, ip_cnt * sizeof(u64)); - snmp_get_cpu_field64_batch(buff64, snmp4_ipextstats_list, + snmp_get_cpu_field64_batch_cnt(buff64, snmp4_ipextstats_list, ip_cnt, net->mib.ip_statistics, offsetof(struct ipstats_mib, syncp)); for (i = 0; i < ip_cnt; i++) diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c index 9eb5fc247868..5998c4cc6f47 100644 --- a/net/ipv4/raw.c +++ b/net/ipv4/raw.c @@ -85,21 +85,20 @@ struct raw_frag_vec { int hlen; }; -struct raw_hashinfo raw_v4_hashinfo = { - .lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock), -}; +struct raw_hashinfo raw_v4_hashinfo; EXPORT_SYMBOL_GPL(raw_v4_hashinfo); int raw_hash_sk(struct sock *sk) { struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; - struct hlist_head *head; + struct hlist_head *hlist; - head = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)]; + hlist = &h->ht[raw_hashfunc(sock_net(sk), inet_sk(sk)->inet_num)]; - write_lock_bh(&h->lock); - sk_add_node(sk, head); - write_unlock_bh(&h->lock); + spin_lock(&h->lock); + sk_add_node_rcu(sk, hlist); + sock_set_flag(sk, SOCK_RCU_FREE); + spin_unlock(&h->lock); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); return 0; @@ -110,31 +109,26 @@ void raw_unhash_sk(struct sock *sk) { struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; - write_lock_bh(&h->lock); - if (sk_del_node_init(sk)) + spin_lock(&h->lock); + if (sk_del_node_init_rcu(sk)) sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); - write_unlock_bh(&h->lock); + spin_unlock(&h->lock); } EXPORT_SYMBOL_GPL(raw_unhash_sk); -struct sock *__raw_v4_lookup(struct net *net, struct sock *sk, - unsigned short num, __be32 raddr, __be32 laddr, - int dif, int sdif) +bool raw_v4_match(struct net *net, const struct sock *sk, unsigned short num, + __be32 raddr, __be32 laddr, int dif, int sdif) { - sk_for_each_from(sk) { - struct inet_sock *inet = inet_sk(sk); - - if (net_eq(sock_net(sk), net) && inet->inet_num == num && - !(inet->inet_daddr && inet->inet_daddr != raddr) && - !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) && - raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) - goto found; /* gotcha */ - } - sk = NULL; -found: - return sk; + const struct inet_sock *inet = inet_sk(sk); + + if (net_eq(sock_net(sk), net) && inet->inet_num == num && + !(inet->inet_daddr && inet->inet_daddr != raddr) && + !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) && + raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) + return true; + return false; } -EXPORT_SYMBOL_GPL(__raw_v4_lookup); +EXPORT_SYMBOL_GPL(raw_v4_match); /* * 0 - deliver @@ -166,25 +160,28 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb) * RFC 1122: SHOULD pass TOS value up to the transport layer. * -> It does. And not only TOS, but all IP header. */ -static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) +static int raw_v4_input(struct net *net, struct sk_buff *skb, + const struct iphdr *iph, int hash) { int sdif = inet_sdif(skb); + struct hlist_head *hlist; int dif = inet_iif(skb); - struct sock *sk; - struct hlist_head *head; int delivered = 0; - struct net *net; - - read_lock(&raw_v4_hashinfo.lock); - head = &raw_v4_hashinfo.ht[hash]; - if (hlist_empty(head)) - goto out; + struct sock *sk; - net = dev_net(skb->dev); - sk = __raw_v4_lookup(net, __sk_head(head), iph->protocol, - iph->saddr, iph->daddr, dif, sdif); + hlist = &raw_v4_hashinfo.ht[hash]; + rcu_read_lock(); + sk_for_each_rcu(sk, hlist) { + if (!raw_v4_match(net, sk, iph->protocol, + iph->saddr, iph->daddr, dif, sdif)) + continue; + + if (atomic_read(&sk->sk_rmem_alloc) >= + READ_ONCE(sk->sk_rcvbuf)) { + sk_drops_inc(sk); + continue; + } - while (sk) { delivered = 1; if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) && ip_mc_sf_allow(sk, iph->daddr, iph->saddr, @@ -195,31 +192,17 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) if (clone) raw_rcv(sk, clone); } - sk = __raw_v4_lookup(net, sk_next(sk), iph->protocol, - iph->saddr, iph->daddr, - dif, sdif); } -out: - read_unlock(&raw_v4_hashinfo.lock); + rcu_read_unlock(); return delivered; } int raw_local_deliver(struct sk_buff *skb, int protocol) { - int hash; - struct sock *raw_sk; - - hash = protocol & (RAW_HTABLE_SIZE - 1); - raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]); - - /* If there maybe a raw socket we must check - if not we - * don't care less - */ - if (raw_sk && !raw_v4_input(skb, ip_hdr(skb), hash)) - raw_sk = NULL; - - return raw_sk != NULL; + struct net *net = dev_net(skb->dev); + return raw_v4_input(net, skb, ip_hdr(skb), + raw_hashfunc(net, protocol)); } static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) @@ -227,8 +210,9 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) struct inet_sock *inet = inet_sk(sk); const int type = icmp_hdr(skb)->type; const int code = icmp_hdr(skb)->code; - int err = 0; int harderr = 0; + bool recverr; + int err = 0; if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) ipv4_sk_update_pmtu(skb, sk, info); @@ -242,7 +226,8 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) 2. Socket is connected (otherwise the error indication is useless without ip_recverr and error is hard. */ - if (!inet->recverr && sk->sk_state != TCP_ESTABLISHED) + recverr = inet_test_bit(RECVERR, sk); + if (!recverr && sk->sk_state != TCP_ESTABLISHED) return; switch (type) { @@ -261,7 +246,7 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) if (code > NR_ICMP_UNREACH) break; if (code == ICMP_FRAG_NEEDED) { - harderr = inet->pmtudisc != IP_PMTUDISC_DONT; + harderr = READ_ONCE(inet->pmtudisc) != IP_PMTUDISC_DONT; err = EMSGSIZE; } else { err = icmp_err_convert[code].errno; @@ -269,16 +254,16 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) } } - if (inet->recverr) { + if (recverr) { const struct iphdr *iph = (const struct iphdr *)skb->data; u8 *payload = skb->data + (iph->ihl << 2); - if (inet->hdrincl) + if (inet_test_bit(HDRINCL, sk)) payload = skb->data; ip_icmp_error(sk, skb, err, 0, info, payload); } - if (inet->recverr || harderr) { + if (recverr || harderr) { sk->sk_err = err; sk_error_report(sk); } @@ -286,40 +271,37 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info) { - int hash; - struct sock *raw_sk; + struct net *net = dev_net(skb->dev); + int dif = skb->dev->ifindex; + int sdif = inet_sdif(skb); + struct hlist_head *hlist; const struct iphdr *iph; - struct net *net; - - hash = protocol & (RAW_HTABLE_SIZE - 1); + struct sock *sk; + int hash; - read_lock(&raw_v4_hashinfo.lock); - raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]); - if (raw_sk) { - int dif = skb->dev->ifindex; - int sdif = inet_sdif(skb); + hash = raw_hashfunc(net, protocol); + hlist = &raw_v4_hashinfo.ht[hash]; + rcu_read_lock(); + sk_for_each_rcu(sk, hlist) { iph = (const struct iphdr *)skb->data; - net = dev_net(skb->dev); - - while ((raw_sk = __raw_v4_lookup(net, raw_sk, protocol, - iph->daddr, iph->saddr, - dif, sdif)) != NULL) { - raw_err(raw_sk, skb, info); - raw_sk = sk_next(raw_sk); - iph = (const struct iphdr *)skb->data; - } + if (!raw_v4_match(net, sk, iph->protocol, + iph->daddr, iph->saddr, dif, sdif)) + continue; + raw_err(sk, skb, info); } - read_unlock(&raw_v4_hashinfo.lock); + rcu_read_unlock(); } static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb) { + enum skb_drop_reason reason; + /* Charge it to the socket. */ - ipv4_pktinfo_prepare(sk, skb); - if (sock_queue_rcv_skb(sk, skb) < 0) { - kfree_skb(skb); + ipv4_pktinfo_prepare(sk, skb, true); + if (sock_queue_rcv_skb_reason(sk, skb, &reason) < 0) { + sk_skb_reason_drop(sk, skb, reason); return NET_RX_DROP; } @@ -329,13 +311,13 @@ static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb) int raw_rcv(struct sock *sk, struct sk_buff *skb) { if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) { - atomic_inc(&sk->sk_drops); - kfree_skb(skb); + sk_drops_inc(sk); + sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_XFRM_POLICY); return NET_RX_DROP; } nf_reset_ct(skb); - skb_push(skb, skb->data - skb_network_header(skb)); + skb_push(skb, -skb_network_offset(skb)); raw_rcv_skb(sk, skb); return 0; @@ -375,9 +357,10 @@ static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4, goto error; skb_reserve(skb, hlen); - skb->priority = sk->sk_priority; + skb->protocol = htons(ETH_P_IP); + skb->priority = sockc->priority; skb->mark = sockc->mark; - skb->tstamp = sockc->transmit_time; + skb_set_delivery_type_by_clockid(skb, sockc->transmit_time, sk->sk_clockid); skb_dst_set(skb, &rt->dst); *rtp = NULL; @@ -387,7 +370,7 @@ static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4, skb->ip_summed = CHECKSUM_NONE; - skb_setup_tx_timestamp(skb, sockc->tsflags); + skb_setup_tx_timestamp(skb, sockc); if (flags & MSG_CONFIRM) skb_set_dst_pending_confirm(skb, 1); @@ -440,7 +423,7 @@ error_free: kfree_skb(skb); error: IP_INC_STATS(net, IPSTATS_MIB_OUTDISCARDS); - if (err == -ENOBUFS && !inet->recverr) + if (err == -ENOBUFS && !inet_test_bit(RECVERR, sk)) err = 0; return err; } @@ -503,11 +486,11 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) struct ipcm_cookie ipc; struct rtable *rt = NULL; struct flowi4 fl4; + u8 scope; int free = 0; __be32 daddr; __be32 saddr; - u8 tos; - int err; + int uc_index, err; struct ip_options_data opt_copy; struct raw_frag_vec rfv; int hdrincl; @@ -516,12 +499,8 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) if (len > 0xFFFF) goto out; - /* hdrincl should be READ_ONCE(inet->hdrincl) - * but READ_ONCE() doesn't work with bit fields. - * Doing this indirectly yields the same result. - */ - hdrincl = inet->hdrincl; - hdrincl = READ_ONCE(hdrincl); + hdrincl = inet_test_bit(HDRINCL, sk); + /* * Check the flags. */ @@ -559,6 +538,9 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) } ipcm_init_sk(&ipc, inet); + /* Keep backward compat */ + if (hdrincl) + ipc.protocol = IPPROTO_RAW; if (msg->msg_controllen) { err = ip_cmsg_send(sk, msg, &ipc, false); @@ -599,37 +581,39 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) daddr = ipc.opt->opt.faddr; } } - tos = get_rtconn_flags(&ipc, sk); - if (msg->msg_flags & MSG_DONTROUTE) - tos |= RTO_ONLINK; + scope = ip_sendmsg_scope(inet, &ipc, msg); + uc_index = READ_ONCE(inet->uc_index); if (ipv4_is_multicast(daddr)) { if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif)) - ipc.oif = inet->mc_index; + ipc.oif = READ_ONCE(inet->mc_index); if (!saddr) - saddr = inet->mc_addr; + saddr = READ_ONCE(inet->mc_addr); } else if (!ipc.oif) { - ipc.oif = inet->uc_index; - } else if (ipv4_is_lbcast(daddr) && inet->uc_index) { + ipc.oif = uc_index; + } else if (ipv4_is_lbcast(daddr) && uc_index) { /* oif is set, packet is to local broadcast * and uc_index is set. oif is most likely set * by sk_bound_dev_if. If uc_index != oif check if the * oif is an L3 master and uc_index is an L3 slave. * If so, we want to allow the send using the uc_index. */ - if (ipc.oif != inet->uc_index && + if (ipc.oif != uc_index && ipc.oif == l3mdev_master_ifindex_by_index(sock_net(sk), - inet->uc_index)) { - ipc.oif = inet->uc_index; + uc_index)) { + ipc.oif = uc_index; } } - flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, - RT_SCOPE_UNIVERSE, - hdrincl ? IPPROTO_RAW : sk->sk_protocol, + flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, + ipc.tos & INET_DSCP_MASK, scope, + hdrincl ? ipc.protocol : sk->sk_protocol, inet_sk_flowi_flags(sk) | (hdrincl ? FLOWI_FLAG_KNOWN_NH : 0), - daddr, saddr, 0, 0, sk->sk_uid); + daddr, saddr, 0, 0, sk_uid(sk)); + + fl4.fl4_icmp_type = 0; + fl4.fl4_icmp_code = 0; if (!hdrincl) { rfv.msg = msg; @@ -671,7 +655,7 @@ back_from_confirm: ip_flush_pending_frames(sk); else if (!(msg->msg_flags & MSG_MORE)) { err = ip_push_pending_frames(sk, &fl4); - if (err == -ENOBUFS && !inet->recverr) + if (err == -ENOBUFS && !inet_test_bit(RECVERR, sk)) err = 0; } release_sock(sk); @@ -713,7 +697,8 @@ static void raw_destroy(struct sock *sk) } /* This gets rid of all the nasties in af_inet. -DaveM */ -static int raw_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +static int raw_bind(struct sock *sk, struct sockaddr_unsized *uaddr, + int addr_len) { struct inet_sock *inet = inet_sk(sk); struct sockaddr_in *addr = (struct sockaddr_in *) uaddr; @@ -722,6 +707,7 @@ static int raw_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) int ret = -EINVAL; int chk_addr_ret; + lock_sock(sk); if (sk->sk_state != TCP_CLOSE || addr_len < sizeof(struct sockaddr_in)) goto out; @@ -741,7 +727,9 @@ static int raw_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) inet->inet_saddr = 0; /* Use device */ sk_dst_reset(sk); ret = 0; -out: return ret; +out: + release_sock(sk); + return ret; } /* @@ -750,7 +738,7 @@ out: return ret; */ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int noblock, int flags, int *addr_len) + int flags, int *addr_len) { struct inet_sock *inet = inet_sk(sk); size_t copied = 0; @@ -766,7 +754,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, goto out; } - skb = skb_recv_datagram(sk, flags, noblock, &err); + skb = skb_recv_datagram(sk, flags, &err); if (!skb) goto out; @@ -780,7 +768,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (err) goto done; - sock_recv_ts_and_drops(msg, sk, skb); + sock_recv_cmsgs(msg, sk, skb); /* Copy the address. */ if (sin) { @@ -790,7 +778,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, memset(&sin->sin_zero, 0, sizeof(sin->sin_zero)); *addr_len = sizeof(*sin); } - if (inet->cmsg_flags) + if (inet_cmsg_flags(inet)) ip_cmsg_recv(msg, skb); if (flags & MSG_TRUNC) copied = skb->len; @@ -806,6 +794,7 @@ static int raw_sk_init(struct sock *sk) { struct raw_sock *rp = raw_sk(sk); + sk->sk_drop_counters = &rp->drop_counters; if (inet_sk(sk)->inet_num == IPPROTO_ICMP) memset(&rp->filter, 0, sizeof(rp->filter)); return 0; @@ -839,7 +828,7 @@ static int raw_geticmpfilter(struct sock *sk, char __user *optval, int __user *o out: return ret; } -static int do_raw_setsockopt(struct sock *sk, int level, int optname, +static int do_raw_setsockopt(struct sock *sk, int optname, sockptr_t optval, unsigned int optlen) { if (optname == ICMP_FILTER) { @@ -856,11 +845,11 @@ static int raw_setsockopt(struct sock *sk, int level, int optname, { if (level != SOL_RAW) return ip_setsockopt(sk, level, optname, optval, optlen); - return do_raw_setsockopt(sk, level, optname, optval, optlen); + return do_raw_setsockopt(sk, optname, optval, optlen); } -static int do_raw_getsockopt(struct sock *sk, int level, int optname, - char __user *optval, int __user *optlen) +static int do_raw_getsockopt(struct sock *sk, int optname, + char __user *optval, int __user *optlen) { if (optname == ICMP_FILTER) { if (inet_sk(sk)->inet_num != IPPROTO_ICMP) @@ -876,32 +865,32 @@ static int raw_getsockopt(struct sock *sk, int level, int optname, { if (level != SOL_RAW) return ip_getsockopt(sk, level, optname, optval, optlen); - return do_raw_getsockopt(sk, level, optname, optval, optlen); + return do_raw_getsockopt(sk, optname, optval, optlen); } -static int raw_ioctl(struct sock *sk, int cmd, unsigned long arg) +static int raw_ioctl(struct sock *sk, int cmd, int *karg) { switch (cmd) { case SIOCOUTQ: { - int amount = sk_wmem_alloc_get(sk); - - return put_user(amount, (int __user *)arg); + *karg = sk_wmem_alloc_get(sk); + return 0; } case SIOCINQ: { struct sk_buff *skb; - int amount = 0; spin_lock_bh(&sk->sk_receive_queue.lock); skb = skb_peek(&sk->sk_receive_queue); if (skb) - amount = skb->len; + *karg = skb->len; + else + *karg = 0; spin_unlock_bh(&sk->sk_receive_queue.lock); - return put_user(amount, (int __user *)arg); + return 0; } default: #ifdef CONFIG_IP_MROUTE - return ipmr_ioctl(sk, cmd, (void __user *)arg); + return ipmr_ioctl(sk, cmd, karg); #else return -ENOIOCTLCMD; #endif @@ -968,44 +957,40 @@ struct proto raw_prot = { }; #ifdef CONFIG_PROC_FS -static struct sock *raw_get_first(struct seq_file *seq) +static struct sock *raw_get_first(struct seq_file *seq, int bucket) { - struct sock *sk; struct raw_hashinfo *h = pde_data(file_inode(seq->file)); struct raw_iter_state *state = raw_seq_private(seq); + struct hlist_head *hlist; + struct sock *sk; - for (state->bucket = 0; state->bucket < RAW_HTABLE_SIZE; + for (state->bucket = bucket; state->bucket < RAW_HTABLE_SIZE; ++state->bucket) { - sk_for_each(sk, &h->ht[state->bucket]) + hlist = &h->ht[state->bucket]; + sk_for_each(sk, hlist) { if (sock_net(sk) == seq_file_net(seq)) - goto found; + return sk; + } } - sk = NULL; -found: - return sk; + return NULL; } static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk) { - struct raw_hashinfo *h = pde_data(file_inode(seq->file)); struct raw_iter_state *state = raw_seq_private(seq); do { sk = sk_next(sk); -try_again: - ; } while (sk && sock_net(sk) != seq_file_net(seq)); - if (!sk && ++state->bucket < RAW_HTABLE_SIZE) { - sk = sk_head(&h->ht[state->bucket]); - goto try_again; - } + if (!sk) + return raw_get_first(seq, state->bucket + 1); return sk; } static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos) { - struct sock *sk = raw_get_first(seq); + struct sock *sk = raw_get_first(seq, 0); if (sk) while (pos && (sk = raw_get_next(seq, sk)) != NULL) @@ -1018,7 +1003,8 @@ void *raw_seq_start(struct seq_file *seq, loff_t *pos) { struct raw_hashinfo *h = pde_data(file_inode(seq->file)); - read_lock(&h->lock); + spin_lock(&h->lock); + return *pos ? raw_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; } EXPORT_SYMBOL_GPL(raw_seq_start); @@ -1028,7 +1014,7 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos) struct sock *sk; if (v == SEQ_START_TOKEN) - sk = raw_get_first(seq); + sk = raw_get_first(seq, 0); else sk = raw_get_next(seq, v); ++*pos; @@ -1041,7 +1027,7 @@ void raw_seq_stop(struct seq_file *seq, void *v) { struct raw_hashinfo *h = pde_data(file_inode(seq->file)); - read_unlock(&h->lock); + spin_unlock(&h->lock); } EXPORT_SYMBOL_GPL(raw_seq_stop); @@ -1059,9 +1045,9 @@ static void raw_sock_seq_show(struct seq_file *seq, struct sock *sp, int i) sk_wmem_alloc_get(sp), sk_rmem_alloc_get(sp), 0, 0L, 0, - from_kuid_munged(seq_user_ns(seq), sock_i_uid(sp)), + from_kuid_munged(seq_user_ns(seq), sk_uid(sp)), 0, sock_i_ino(sp), - refcount_read(&sp->sk_refcnt), sp, atomic_read(&sp->sk_drops)); + refcount_read(&sp->sk_refcnt), sp, sk_drops_read(sp)); } static int raw_seq_show(struct seq_file *seq, void *v) @@ -1103,6 +1089,7 @@ static __net_initdata struct pernet_operations raw_net_ops = { int __init raw_proc_init(void) { + return register_pernet_subsys(&raw_net_ops); } diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c index ccacbde30a2c..943e5998e0ad 100644 --- a/net/ipv4/raw_diag.c +++ b/net/ipv4/raw_diag.c @@ -34,57 +34,56 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r) * use helper to figure it out. */ -static struct sock *raw_lookup(struct net *net, struct sock *from, - const struct inet_diag_req_v2 *req) +static bool raw_lookup(struct net *net, const struct sock *sk, + const struct inet_diag_req_v2 *req) { struct inet_diag_req_raw *r = (void *)req; - struct sock *sk = NULL; if (r->sdiag_family == AF_INET) - sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol, - r->id.idiag_dst[0], - r->id.idiag_src[0], - r->id.idiag_if, 0); + return raw_v4_match(net, sk, r->sdiag_raw_protocol, + r->id.idiag_dst[0], + r->id.idiag_src[0], + r->id.idiag_if, 0); #if IS_ENABLED(CONFIG_IPV6) else - sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol, - (const struct in6_addr *)r->id.idiag_src, - (const struct in6_addr *)r->id.idiag_dst, - r->id.idiag_if, 0); + return raw_v6_match(net, sk, r->sdiag_raw_protocol, + (const struct in6_addr *)r->id.idiag_src, + (const struct in6_addr *)r->id.idiag_dst, + r->id.idiag_if, 0); #endif - return sk; + return false; } static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r) { struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); - struct sock *sk = NULL, *s; + struct hlist_head *hlist; + struct sock *sk; int slot; if (IS_ERR(hashinfo)) return ERR_CAST(hashinfo); - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) { - sk_for_each(s, &hashinfo->ht[slot]) { - sk = raw_lookup(net, s, r); - if (sk) { + hlist = &hashinfo->ht[slot]; + sk_for_each_rcu(sk, hlist) { + if (raw_lookup(net, sk, r)) { /* * Grab it and keep until we fill - * diag meaage to be reported, so + * diag message to be reported, so * caller should call sock_put then. - * We can do that because we're keeping - * hashinfo->lock here. */ - sock_hold(sk); - goto out_unlock; + if (refcount_inc_not_zero(&sk->sk_refcnt)) + goto out_unlock; } } } + sk = ERR_PTR(-ENOENT); out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); - return sk ? sk : ERR_PTR(-ENOENT); + return sk; } static int raw_diag_dump_one(struct netlink_callback *cb, @@ -127,9 +126,9 @@ static int raw_diag_dump_one(struct netlink_callback *cb, static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r, - struct nlattr *bc, bool net_admin) + bool net_admin) { - if (!inet_diag_bc_sk(bc, sk)) + if (!inet_diag_bc_sk(cb->data, sk)) return 0; return inet_sk_diag_fill(sk, NULL, skb, cb, r, NLM_F_MULTI, net_admin); @@ -141,24 +140,22 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct net *net = sock_net(skb->sk); - struct inet_diag_dump_data *cb_data; int num, s_num, slot, s_slot; + struct hlist_head *hlist; struct sock *sk = NULL; - struct nlattr *bc; if (IS_ERR(hashinfo)) return; - cb_data = cb->data; - bc = cb_data->inet_diag_nla_bc; s_slot = cb->args[0]; num = s_num = cb->args[1]; - read_lock(&hashinfo->lock); + rcu_read_lock(); for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) { num = 0; - sk_for_each(sk, &hashinfo->ht[slot]) { + hlist = &hashinfo->ht[slot]; + sk_for_each_rcu(sk, hlist) { struct inet_sock *inet = inet_sk(sk); if (!net_eq(sock_net(sk), net)) @@ -173,7 +170,7 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, if (r->id.idiag_dport != inet->inet_dport && r->id.idiag_dport) goto next; - if (sk_diag_dump(sk, skb, cb, r, bc, net_admin) < 0) + if (sk_diag_dump(sk, skb, cb, r, net_admin) < 0) goto out_unlock; next: num++; @@ -181,7 +178,7 @@ next: } out_unlock: - read_unlock(&hashinfo->lock); + rcu_read_unlock(); cb->args[0] = slot; cb->args[1] = num; @@ -212,6 +209,7 @@ static int raw_diag_destroy(struct sk_buff *in_skb, #endif static const struct inet_diag_handler raw_diag_handler = { + .owner = THIS_MODULE, .dump = raw_diag_dump, .dump_one = raw_diag_dump_one, .idiag_get_info = raw_diag_get_info, @@ -256,5 +254,6 @@ static void __exit raw_diag_exit(void) module_init(raw_diag_init); module_exit(raw_diag_exit); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("RAW socket monitoring via SOCK_DIAG"); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-255 /* AF_INET - IPPROTO_RAW */); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10-255 /* AF_INET6 - IPPROTO_RAW */); diff --git a/net/ipv4/route.c b/net/ipv4/route.c index ff6f91cdb6c4..b549d6a57307 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -84,6 +84,8 @@ #include <linux/jhash.h> #include <net/dst.h> #include <net/dst_metadata.h> +#include <net/flow.h> +#include <net/inet_dscp.h> #include <net/net_namespace.h> #include <net/ip.h> #include <net/route.h> @@ -105,21 +107,17 @@ #include "fib_lookup.h" -#define RT_FL_TOS(oldflp4) \ - ((oldflp4)->flowi4_tos & (IPTOS_RT_MASK | RTO_ONLINK)) - #define RT_GC_TIMEOUT (300*HZ) #define DEFAULT_MIN_PMTU (512 + 20 + 20) #define DEFAULT_MTU_EXPIRES (10 * 60 * HZ) - +#define DEFAULT_MIN_ADVMSS 256 static int ip_rt_max_size; static int ip_rt_redirect_number __read_mostly = 9; static int ip_rt_redirect_load __read_mostly = HZ / 50; static int ip_rt_redirect_silence __read_mostly = ((HZ / 50) << (9 + 1)); static int ip_rt_error_cost __read_mostly = HZ; static int ip_rt_error_burst __read_mostly = 5 * HZ; -static int ip_rt_min_advmss __read_mostly = 256; static int ip_rt_gc_timeout __read_mostly = RT_GC_TIMEOUT; @@ -132,7 +130,8 @@ struct dst_entry *ipv4_dst_check(struct dst_entry *dst, u32 cookie); static unsigned int ipv4_default_advmss(const struct dst_entry *dst); INDIRECT_CALLABLE_SCOPE unsigned int ipv4_mtu(const struct dst_entry *dst); -static struct dst_entry *ipv4_negative_advice(struct dst_entry *dst); +static void ipv4_negative_advice(struct sock *sk, + struct dst_entry *dst); static void ipv4_link_failure(struct sk_buff *skb); static void ip_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, struct sk_buff *skb, u32 mtu, @@ -191,7 +190,11 @@ const __u8 ip_tos2prio[16] = { EXPORT_SYMBOL(ip_tos2prio); static DEFINE_PER_CPU(struct rt_cache_stat, rt_cache_stat); +#ifndef CONFIG_PREEMPT_RT #define RT_CACHE_STAT_INC(field) raw_cpu_inc(rt_cache_stat.field) +#else +#define RT_CACHE_STAT_INC(field) this_cpu_inc(rt_cache_stat.field) +#endif #ifdef CONFIG_PROC_FS static void *rt_cache_seq_start(struct seq_file *seq, loff_t *pos) @@ -392,7 +395,13 @@ static inline int ip_rt_proc_init(void) static inline bool rt_is_expired(const struct rtable *rth) { - return rth->rt_genid != rt_genid_ipv4(dev_net(rth->dst.dev)); + bool res; + + rcu_read_lock(); + res = rth->rt_genid != rt_genid_ipv4(dev_net_rcu(rth->dst.dev)); + rcu_read_unlock(); + + return res; } void rt_cache_flush(struct net *net) @@ -405,11 +414,11 @@ static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, const void *daddr) { const struct rtable *rt = container_of(dst, struct rtable, dst); - struct net_device *dev = dst->dev; + struct net_device *dev; struct neighbour *n; - rcu_read_lock_bh(); - + rcu_read_lock(); + dev = dst_dev_rcu(dst); if (likely(rt->rt_gw_family == AF_INET)) { n = ip_neigh_gw4(dev, rt->rt_gw4); } else if (rt->rt_gw_family == AF_INET6) { @@ -424,7 +433,7 @@ static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, if (!IS_ERR(n) && !refcount_inc_not_zero(&n->refcnt)) n = NULL; - rcu_read_unlock_bh(); + rcu_read_unlock(); return n; } @@ -432,7 +441,7 @@ static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, static void ipv4_confirm_neigh(const struct dst_entry *dst, const void *daddr) { const struct rtable *rt = container_of(dst, struct rtable, dst); - struct net_device *dev = dst->dev; + struct net_device *dev = dst_dev(dst); const __be32 *pkey = daddr; if (rt->rt_gw_family == AF_INET) { @@ -458,7 +467,7 @@ static u32 *ip_tstamps __read_mostly; * if one generator is seldom used. This makes hard for an attacker * to infer how many packets were sent between two points in time. */ -u32 ip_idents_reserve(u32 hash, int segs) +static u32 ip_idents_reserve(u32 hash, int segs) { u32 bucket, old, now = (u32)jiffies; atomic_t *p_id; @@ -471,7 +480,7 @@ u32 ip_idents_reserve(u32 hash, int segs) old = READ_ONCE(*p_tstamp); if (old != now && cmpxchg(p_tstamp, old, now) == old) - delta = prandom_u32_max(now - old); + delta = get_random_u32_below(now - old); /* If UBSAN reports an error there, please make sure your compiler * supports -fno-strict-overflow before reporting it that was a bug @@ -479,7 +488,6 @@ u32 ip_idents_reserve(u32 hash, int segs) */ return atomic_add_return(segs + delta, p_id) - segs; } -EXPORT_SYMBOL(ip_idents_reserve); void __ip_select_ident(struct net *net, struct iphdr *iph, int segs) { @@ -500,23 +508,23 @@ void __ip_select_ident(struct net *net, struct iphdr *iph, int segs) EXPORT_SYMBOL(__ip_select_ident); static void __build_flow_key(const struct net *net, struct flowi4 *fl4, - const struct sock *sk, - const struct iphdr *iph, - int oif, u8 tos, - u8 prot, u32 mark, int flow_flags) + const struct sock *sk, const struct iphdr *iph, + int oif, __u8 tos, u8 prot, u32 mark, + int flow_flags) { - if (sk) { - const struct inet_sock *inet = inet_sk(sk); + __u8 scope = RT_SCOPE_UNIVERSE; + if (sk) { oif = sk->sk_bound_dev_if; - mark = sk->sk_mark; - tos = RT_CONN_FLAGS(sk); - prot = inet->hdrincl ? IPPROTO_RAW : sk->sk_protocol; - } - flowi4_init_output(fl4, oif, mark, tos, - RT_SCOPE_UNIVERSE, prot, - flow_flags, - iph->daddr, iph->saddr, 0, 0, + mark = READ_ONCE(sk->sk_mark); + tos = ip_sock_rt_tos(sk); + scope = ip_sock_rt_scope(sk); + prot = inet_test_bit(HDRINCL, sk) ? IPPROTO_RAW : + sk->sk_protocol; + } + + flowi4_init_output(fl4, oif, mark, tos & INET_DSCP_MASK, scope, + prot, flow_flags, iph->daddr, iph->saddr, 0, 0, sock_net_uid(net, sk)); } @@ -526,9 +534,9 @@ static void build_skb_flow_key(struct flowi4 *fl4, const struct sk_buff *skb, const struct net *net = dev_net(skb->dev); const struct iphdr *iph = ip_hdr(skb); int oif = skb->dev->ifindex; - u8 tos = RT_TOS(iph->tos); u8 prot = iph->protocol; u32 mark = skb->mark; + __u8 tos = iph->tos; __build_flow_key(net, fl4, sk, iph, oif, tos, prot, mark, 0); } @@ -543,11 +551,14 @@ static void build_sk_flow_key(struct flowi4 *fl4, const struct sock *sk) inet_opt = rcu_dereference(inet->inet_opt); if (inet_opt && inet_opt->opt.srr) daddr = inet_opt->opt.faddr; - flowi4_init_output(fl4, sk->sk_bound_dev_if, sk->sk_mark, - RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, - inet->hdrincl ? IPPROTO_RAW : sk->sk_protocol, + flowi4_init_output(fl4, sk->sk_bound_dev_if, READ_ONCE(sk->sk_mark), + ip_sock_rt_tos(sk), + ip_sock_rt_scope(sk), + inet_test_bit(HDRINCL, sk) ? + IPPROTO_RAW : sk->sk_protocol, inet_sk_flowi_flags(sk), - daddr, inet->inet_saddr, 0, 0, sk->sk_uid); + daddr, inet->inet_saddr, 0, 0, + sk_uid(sk)); rcu_read_unlock(); } @@ -596,6 +607,11 @@ static void fnhe_remove_oldest(struct fnhe_hash_bucket *hash) oldest_p = fnhe_p; } } + + /* Clear oldest->fnhe_daddr to prevent this fnhe from being + * rebound with new dsts in rt_bind_exception(). + */ + oldest->fnhe_daddr = 0; fnhe_flush_routes(oldest); *oldest_p = oldest->fnhe_next; kfree_rcu(oldest, rcu); @@ -679,7 +695,7 @@ static void update_or_create_fnhe(struct fib_nh_common *nhc, __be32 daddr, } else { /* Randomize max depth to avoid some side channels attacks. */ int max_depth = FNHE_RECLAIM_DEPTH + - prandom_u32_max(FNHE_RECLAIM_DEPTH); + get_random_u32_below(FNHE_RECLAIM_DEPTH); while (depth > max_depth) { fnhe_remove_oldest(hash); @@ -707,7 +723,7 @@ static void update_or_create_fnhe(struct fib_nh_common *nhc, __be32 daddr, */ rt = rcu_dereference(nhc->nhc_rth_input); if (rt) - rt->dst.obsolete = DST_OBSOLETE_KILL; + WRITE_ONCE(rt->dst.obsolete, DST_OBSOLETE_KILL); for_each_possible_cpu(i) { struct rtable __rcu **prt; @@ -715,7 +731,7 @@ static void update_or_create_fnhe(struct fib_nh_common *nhc, __be32 daddr, prt = per_cpu_ptr(nhc->nhc_pcpu_rth_output, i); rt = rcu_dereference(*prt); if (rt) - rt->dst.obsolete = DST_OBSOLETE_KILL; + WRITE_ONCE(rt->dst.obsolete, DST_OBSOLETE_KILL); } } @@ -770,11 +786,11 @@ static void __ip_do_redirect(struct rtable *rt, struct sk_buff *skb, struct flow goto reject_redirect; } - n = __ipv4_neigh_lookup(rt->dst.dev, new_gw); + n = __ipv4_neigh_lookup(rt->dst.dev, (__force u32)new_gw); if (!n) n = neigh_create(&arp_tbl, &new_gw, rt->dst.dev); if (!IS_ERR(n)) { - if (!(n->nud_state & NUD_VALID)) { + if (!(READ_ONCE(n->nud_state) & NUD_VALID)) { neigh_event_send(n, NULL); } else { if (fib_lookup(net, fl4, &res, 0) == 0) { @@ -787,7 +803,7 @@ static void __ip_do_redirect(struct rtable *rt, struct sk_buff *skb, struct flow jiffies + ip_rt_gc_timeout); } if (kill_route) - rt->dst.obsolete = DST_OBSOLETE_KILL; + WRITE_ONCE(rt->dst.obsolete, DST_OBSOLETE_KILL); call_netevent_notifiers(NETEVENT_NEIGH_UPDATE, n); } neigh_release(n); @@ -817,32 +833,25 @@ static void ip_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_buf const struct iphdr *iph = (const struct iphdr *) skb->data; struct net *net = dev_net(skb->dev); int oif = skb->dev->ifindex; - u8 tos = RT_TOS(iph->tos); u8 prot = iph->protocol; u32 mark = skb->mark; + __u8 tos = iph->tos; - rt = (struct rtable *) dst; + rt = dst_rtable(dst); __build_flow_key(net, &fl4, sk, iph, oif, tos, prot, mark, 0); __ip_do_redirect(rt, skb, &fl4, true); } -static struct dst_entry *ipv4_negative_advice(struct dst_entry *dst) +static void ipv4_negative_advice(struct sock *sk, + struct dst_entry *dst) { - struct rtable *rt = (struct rtable *)dst; - struct dst_entry *ret = dst; + struct rtable *rt = dst_rtable(dst); - if (rt) { - if (dst->obsolete > 0) { - ip_rt_put(rt); - ret = NULL; - } else if ((rt->rt_flags & RTCF_REDIRECTED) || - rt->dst.expires) { - ip_rt_put(rt); - ret = NULL; - } - } - return ret; + if ((READ_ONCE(dst->obsolete) > 0) || + (rt->rt_flags & RTCF_REDIRECTED) || + READ_ONCE(rt->dst.expires)) + sk_dst_reset(sk); } /* @@ -878,11 +887,11 @@ void ip_rt_send_redirect(struct sk_buff *skb) } log_martians = IN_DEV_LOG_MARTIANS(in_dev); vif = l3mdev_master_ifindex_rcu(rt->dst.dev); - rcu_read_unlock(); net = dev_net(rt->dst.dev); - peer = inet_getpeer_v4(net->ipv4.peers, ip_hdr(skb)->saddr, vif, 1); + peer = inet_getpeer_v4(net->ipv4.peers, ip_hdr(skb)->saddr, vif); if (!peer) { + rcu_read_unlock(); icmp_send(skb, ICMP_REDIRECT, ICMP_REDIR_HOST, rt_nexthop(rt, ip_hdr(skb)->daddr)); return; @@ -901,7 +910,7 @@ void ip_rt_send_redirect(struct sk_buff *skb) */ if (peer->n_redirects >= ip_rt_redirect_number) { peer->rate_last = jiffies; - goto out_put_peer; + goto out_unlock; } /* Check for load limit; set rate_last to the latest sent @@ -916,16 +925,14 @@ void ip_rt_send_redirect(struct sk_buff *skb) icmp_send(skb, ICMP_REDIRECT, ICMP_REDIR_HOST, gw); peer->rate_last = jiffies; ++peer->n_redirects; -#ifdef CONFIG_IP_ROUTE_VERBOSE - if (log_martians && + if (IS_ENABLED(CONFIG_IP_ROUTE_VERBOSE) && log_martians && peer->n_redirects == ip_rt_redirect_number) net_warn_ratelimited("host %pI4/if%d ignores redirects for %pI4 to %pI4\n", &ip_hdr(skb)->saddr, inet_iif(skb), &ip_hdr(skb)->daddr, &gw); -#endif } -out_put_peer: - inet_putpeer(peer); +out_unlock: + rcu_read_unlock(); } static int ip_error(struct sk_buff *skb) @@ -936,6 +943,7 @@ static int ip_error(struct sk_buff *skb) struct inet_peer *peer; unsigned long now; struct net *net; + SKB_DR(reason); bool send; int code; @@ -955,10 +963,12 @@ static int ip_error(struct sk_buff *skb) if (!IN_DEV_FORWARD(in_dev)) { switch (rt->dst.error) { case EHOSTUNREACH: + SKB_DR_SET(reason, IP_INADDRERRORS); __IP_INC_STATS(net, IPSTATS_MIB_INADDRERRORS); break; case ENETUNREACH: + SKB_DR_SET(reason, IP_INNOROUTES); __IP_INC_STATS(net, IPSTATS_MIB_INNOROUTES); break; } @@ -974,6 +984,7 @@ static int ip_error(struct sk_buff *skb) break; case ENETUNREACH: code = ICMP_NET_UNREACH; + SKB_DR_SET(reason, IP_INNOROUTES); __IP_INC_STATS(net, IPSTATS_MIB_INNOROUTES); break; case EACCES: @@ -981,9 +992,9 @@ static int ip_error(struct sk_buff *skb) break; } + rcu_read_lock(); peer = inet_getpeer_v4(net->ipv4.peers, ip_hdr(skb)->saddr, - l3mdev_master_ifindex(skb->dev), 1); - + l3mdev_master_ifindex_rcu(skb->dev)); send = true; if (peer) { now = jiffies; @@ -995,21 +1006,22 @@ static int ip_error(struct sk_buff *skb) peer->rate_tokens -= ip_rt_error_cost; else send = false; - inet_putpeer(peer); } + rcu_read_unlock(); + if (send) icmp_send(skb, ICMP_DEST_UNREACH, code, 0); -out: kfree_skb(skb); +out: kfree_skb_reason(skb, reason); return 0; } static void __ip_rt_update_pmtu(struct rtable *rt, struct flowi4 *fl4, u32 mtu) { struct dst_entry *dst = &rt->dst; - struct net *net = dev_net(dst->dev); struct fib_result res; bool lock = false; + struct net *net; u32 old_mtu; if (ip_mtu_locked(dst)) @@ -1019,24 +1031,39 @@ static void __ip_rt_update_pmtu(struct rtable *rt, struct flowi4 *fl4, u32 mtu) if (old_mtu < mtu) return; + rcu_read_lock(); + net = dst_dev_net_rcu(dst); if (mtu < net->ipv4.ip_rt_min_pmtu) { lock = true; mtu = min(old_mtu, net->ipv4.ip_rt_min_pmtu); } if (rt->rt_pmtu == mtu && !lock && - time_before(jiffies, dst->expires - net->ipv4.ip_rt_mtu_expires / 2)) - return; + time_before(jiffies, READ_ONCE(dst->expires) - + net->ipv4.ip_rt_mtu_expires / 2)) + goto out; - rcu_read_lock(); if (fib_lookup(net, fl4, &res, 0) == 0) { struct fib_nh_common *nhc; fib_select_path(net, &res, fl4, NULL); +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (fib_info_num_path(res.fi) > 1) { + int nhsel; + + for (nhsel = 0; nhsel < fib_info_num_path(res.fi); nhsel++) { + nhc = fib_info_nhc(res.fi, nhsel); + update_or_create_fnhe(nhc, fl4->daddr, 0, mtu, lock, + jiffies + net->ipv4.ip_rt_mtu_expires); + } + goto out; + } +#endif /* CONFIG_IP_ROUTE_MULTIPATH */ nhc = FIB_RES_NHC(res); update_or_create_fnhe(nhc, fl4->daddr, 0, mtu, lock, jiffies + net->ipv4.ip_rt_mtu_expires); } +out: rcu_read_unlock(); } @@ -1044,7 +1071,7 @@ static void ip_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, struct sk_buff *skb, u32 mtu, bool confirm_neigh) { - struct rtable *rt = (struct rtable *) dst; + struct rtable *rt = dst_rtable(dst); struct flowi4 fl4; ip_rt_build_flow_key(&fl4, sk, skb); @@ -1064,8 +1091,8 @@ void ipv4_update_pmtu(struct sk_buff *skb, struct net *net, u32 mtu, struct rtable *rt; u32 mark = IP4_REPLY_MARK(net, skb->mark); - __build_flow_key(net, &fl4, NULL, iph, oif, - RT_TOS(iph->tos), protocol, mark, 0); + __build_flow_key(net, &fl4, NULL, iph, oif, iph->tos, protocol, mark, + 0); rt = __ip_route_output_key(net, &fl4); if (!IS_ERR(rt)) { __ip_rt_update_pmtu(rt, &fl4, mtu); @@ -1115,8 +1142,8 @@ void ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) __build_flow_key(net, &fl4, sk, iph, 0, 0, 0, 0, 0); - rt = (struct rtable *)odst; - if (odst->obsolete && !odst->ops->check(odst, 0)) { + rt = dst_rtable(odst); + if (READ_ONCE(odst->obsolete) && !odst->ops->check(odst, 0)) { rt = ip_route_output_flow(sock_net(sk), &fl4, sk); if (IS_ERR(rt)) goto out; @@ -1124,7 +1151,7 @@ void ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) new = true; } - __ip_rt_update_pmtu((struct rtable *)xfrm_dst_path(&rt->dst), &fl4, mtu); + __ip_rt_update_pmtu(dst_rtable(xfrm_dst_path(&rt->dst)), &fl4, mtu); if (!dst_check(&rt->dst, 0)) { if (new) @@ -1153,8 +1180,7 @@ void ipv4_redirect(struct sk_buff *skb, struct net *net, struct flowi4 fl4; struct rtable *rt; - __build_flow_key(net, &fl4, NULL, iph, oif, - RT_TOS(iph->tos), protocol, 0, 0); + __build_flow_key(net, &fl4, NULL, iph, oif, iph->tos, protocol, 0, 0); rt = __ip_route_output_key(net, &fl4); if (!IS_ERR(rt)) { __ip_do_redirect(rt, skb, &fl4, false); @@ -1182,7 +1208,7 @@ EXPORT_SYMBOL_GPL(ipv4_sk_redirect); INDIRECT_CALLABLE_SCOPE struct dst_entry *ipv4_dst_check(struct dst_entry *dst, u32 cookie) { - struct rtable *rt = (struct rtable *) dst; + struct rtable *rt = dst_rtable(dst); /* All IPV4 dsts are created with ->obsolete set to the value * DST_OBSOLETE_FORCE_CHK which forces validation calls down @@ -1192,7 +1218,8 @@ INDIRECT_CALLABLE_SCOPE struct dst_entry *ipv4_dst_check(struct dst_entry *dst, * this is indicated by setting obsolete to DST_OBSOLETE_KILL or * DST_OBSOLETE_DEAD. */ - if (dst->obsolete != DST_OBSOLETE_FORCE_CHK || rt_is_expired(rt)) + if (READ_ONCE(dst->obsolete) != DST_OBSOLETE_FORCE_CHK || + rt_is_expired(rt)) return NULL; return dst; } @@ -1200,7 +1227,8 @@ EXPORT_INDIRECT_CALLABLE(ipv4_dst_check); static void ipv4_send_dest_unreach(struct sk_buff *skb) { - struct ip_options opt; + struct inet_skb_parm parm; + struct net_device *dev; int res; /* Recompile ip options since IPCB may not be valid anymore. @@ -1210,20 +1238,21 @@ static void ipv4_send_dest_unreach(struct sk_buff *skb) ip_hdr(skb)->version != 4 || ip_hdr(skb)->ihl < 5) return; - memset(&opt, 0, sizeof(opt)); + memset(&parm, 0, sizeof(parm)); if (ip_hdr(skb)->ihl > 5) { if (!pskb_network_may_pull(skb, ip_hdr(skb)->ihl * 4)) return; - opt.optlen = ip_hdr(skb)->ihl * 4 - sizeof(struct iphdr); + parm.opt.optlen = ip_hdr(skb)->ihl * 4 - sizeof(struct iphdr); rcu_read_lock(); - res = __ip_options_compile(dev_net(skb->dev), &opt, skb, NULL); + dev = skb->dev ? skb->dev : skb_rtable(skb)->dst.dev; + res = __ip_options_compile(dev_net(dev), &parm.opt, skb, NULL); rcu_read_unlock(); if (res) return; } - __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0, &opt); + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0, &parm); } static void ipv4_link_failure(struct sk_buff *skb) @@ -1268,7 +1297,7 @@ void ip_rt_get_source(u8 *addr, struct sk_buff *skb, struct rtable *rt) struct flowi4 fl4 = { .daddr = iph->daddr, .saddr = iph->saddr, - .flowi4_tos = RT_TOS(iph->tos), + .flowi4_dscp = ip4h_dscp(iph), .flowi4_oif = rt->dst.dev->ifindex, .flowi4_iif = skb->dev->ifindex, .flowi4_mark = skb->mark, @@ -1299,8 +1328,14 @@ static void set_class_tag(struct rtable *rt, u32 tag) static unsigned int ipv4_default_advmss(const struct dst_entry *dst) { unsigned int header_size = sizeof(struct tcphdr) + sizeof(struct iphdr); - unsigned int advmss = max_t(unsigned int, ipv4_mtu(dst) - header_size, - ip_rt_min_advmss); + unsigned int advmss; + struct net *net; + + rcu_read_lock(); + net = dst_dev_net_rcu(dst); + advmss = max_t(unsigned int, ipv4_mtu(dst) - header_size, + net->ipv4.ip_rt_min_advmss); + rcu_read_unlock(); return min(advmss, IPV4_MAX_PMTU - header_size); } @@ -1384,7 +1419,7 @@ u32 ip_mtu_from_fib_result(struct fib_result *res, __be32 daddr) struct fib_info *fi = res->fi; u32 mtu = 0; - if (dev_net(dev)->ipv4.sysctl_ip_fwd_use_pmtu || + if (READ_ONCE(dev_net(dev)->ipv4.sysctl_ip_fwd_use_pmtu) || fi->fib_metrics->metrics[RTAX_LOCK - 1] & (1 << RTAX_MTU)) mtu = fi->fib_mtu; @@ -1493,48 +1528,49 @@ void rt_add_uncached_list(struct rtable *rt) { struct uncached_list *ul = raw_cpu_ptr(&rt_uncached_list); - rt->rt_uncached_list = ul; + rt->dst.rt_uncached_list = ul; spin_lock_bh(&ul->lock); - list_add_tail(&rt->rt_uncached, &ul->head); + list_add_tail(&rt->dst.rt_uncached, &ul->head); spin_unlock_bh(&ul->lock); } void rt_del_uncached_list(struct rtable *rt) { - if (!list_empty(&rt->rt_uncached)) { - struct uncached_list *ul = rt->rt_uncached_list; + if (!list_empty(&rt->dst.rt_uncached)) { + struct uncached_list *ul = rt->dst.rt_uncached_list; spin_lock_bh(&ul->lock); - list_del(&rt->rt_uncached); + list_del_init(&rt->dst.rt_uncached); spin_unlock_bh(&ul->lock); } } static void ipv4_dst_destroy(struct dst_entry *dst) { - struct rtable *rt = (struct rtable *)dst; - ip_dst_metrics_put(dst); - rt_del_uncached_list(rt); + rt_del_uncached_list(dst_rtable(dst)); } void rt_flush_dev(struct net_device *dev) { - struct rtable *rt; + struct rtable *rt, *safe; int cpu; for_each_possible_cpu(cpu) { struct uncached_list *ul = &per_cpu(rt_uncached_list, cpu); + if (list_empty(&ul->head)) + continue; + spin_lock_bh(&ul->lock); - list_for_each_entry(rt, &ul->head, rt_uncached) { + list_for_each_entry_safe(rt, safe, &ul->head, dst.rt_uncached) { if (rt->dst.dev != dev) continue; rt->dst.dev = blackhole_netdev; - dev_replace_track(dev, blackhole_netdev, - &rt->dst.dev_tracker, - GFP_ATOMIC); + netdev_ref_replace(dev, blackhole_netdev, + &rt->dst.dev_tracker, GFP_ATOMIC); + list_del_init(&rt->dst.rt_uncached); } spin_unlock_bh(&ul->lock); } @@ -1543,7 +1579,7 @@ void rt_flush_dev(struct net_device *dev) static bool rt_cache_valid(const struct rtable *rt) { return rt && - rt->dst.obsolete == DST_OBSOLETE_FORCE_CHK && + READ_ONCE(rt->dst.obsolete) == DST_OBSOLETE_FORCE_CHK && !rt_is_expired(rt); } @@ -1608,12 +1644,11 @@ static void rt_set_nexthop(struct rtable *rt, __be32 daddr, struct rtable *rt_dst_alloc(struct net_device *dev, unsigned int flags, u16 type, - bool nopolicy, bool noxfrm) + bool noxfrm) { struct rtable *rt; - rt = dst_alloc(&ipv4_dst_ops, dev, 1, DST_OBSOLETE_FORCE_CHK, - (nopolicy ? DST_NOPOLICY : 0) | + rt = dst_alloc(&ipv4_dst_ops, dev, DST_OBSOLETE_FORCE_CHK, (noxfrm ? DST_NOXFRM : 0)); if (rt) { @@ -1627,7 +1662,6 @@ struct rtable *rt_dst_alloc(struct net_device *dev, rt->rt_uses_gateway = 0; rt->rt_gw_family = 0; rt->rt_gw4 = 0; - INIT_LIST_HEAD(&rt->rt_uncached); rt->dst.output = ip_output; if (flags & RTCF_LOCAL) @@ -1642,7 +1676,7 @@ struct rtable *rt_dst_clone(struct net_device *dev, struct rtable *rt) { struct rtable *new_rt; - new_rt = dst_alloc(&ipv4_dst_ops, dev, 1, DST_OBSOLETE_FORCE_CHK, + new_rt = dst_alloc(&ipv4_dst_ops, dev, DST_OBSOLETE_FORCE_CHK, rt->dst.flags); if (new_rt) { @@ -1658,10 +1692,9 @@ struct rtable *rt_dst_clone(struct net_device *dev, struct rtable *rt) new_rt->rt_gw4 = rt->rt_gw4; else if (rt->rt_gw_family == AF_INET6) new_rt->rt_gw6 = rt->rt_gw6; - INIT_LIST_HEAD(&new_rt->rt_uncached); - new_rt->dst.input = rt->dst.input; - new_rt->dst.output = rt->dst.output; + new_rt->dst.input = READ_ONCE(rt->dst.input); + new_rt->dst.output = READ_ONCE(rt->dst.output); new_rt->dst.error = rt->dst.error; new_rt->dst.lastuse = jiffies; new_rt->dst.lwtstate = lwtstate_get(rt->dst.lwtstate); @@ -1671,57 +1704,65 @@ struct rtable *rt_dst_clone(struct net_device *dev, struct rtable *rt) EXPORT_SYMBOL(rt_dst_clone); /* called in rcu_read_lock() section */ -int ip_mc_validate_source(struct sk_buff *skb, __be32 daddr, __be32 saddr, - u8 tos, struct net_device *dev, - struct in_device *in_dev, u32 *itag) +enum skb_drop_reason +ip_mc_validate_source(struct sk_buff *skb, __be32 daddr, __be32 saddr, + dscp_t dscp, struct net_device *dev, + struct in_device *in_dev, u32 *itag) { - int err; + enum skb_drop_reason reason; /* Primary sanity checks. */ if (!in_dev) - return -EINVAL; + return SKB_DROP_REASON_NOT_SPECIFIED; - if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr) || - skb->protocol != htons(ETH_P_IP)) - return -EINVAL; + if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) + return SKB_DROP_REASON_IP_INVALID_SOURCE; + + if (skb->protocol != htons(ETH_P_IP)) + return SKB_DROP_REASON_INVALID_PROTO; if (ipv4_is_loopback(saddr) && !IN_DEV_ROUTE_LOCALNET(in_dev)) - return -EINVAL; + return SKB_DROP_REASON_IP_LOCALNET; if (ipv4_is_zeronet(saddr)) { if (!ipv4_is_local_multicast(daddr) && ip_hdr(skb)->protocol != IPPROTO_IGMP) - return -EINVAL; + return SKB_DROP_REASON_IP_INVALID_SOURCE; } else { - err = fib_validate_source(skb, saddr, 0, tos, 0, dev, - in_dev, itag); - if (err < 0) - return err; + reason = fib_validate_source_reason(skb, saddr, 0, dscp, 0, + dev, in_dev, itag); + if (reason) + return reason; } - return 0; + return SKB_NOT_DROPPED_YET; } /* called in rcu_read_lock() section */ -static int ip_route_input_mc(struct sk_buff *skb, __be32 daddr, __be32 saddr, - u8 tos, struct net_device *dev, int our) +static enum skb_drop_reason +ip_route_input_mc(struct sk_buff *skb, __be32 daddr, __be32 saddr, + dscp_t dscp, struct net_device *dev, int our) { struct in_device *in_dev = __in_dev_get_rcu(dev); unsigned int flags = RTCF_MULTICAST; + enum skb_drop_reason reason; struct rtable *rth; u32 itag = 0; - int err; - err = ip_mc_validate_source(skb, daddr, saddr, tos, dev, in_dev, &itag); - if (err) - return err; + reason = ip_mc_validate_source(skb, daddr, saddr, dscp, dev, in_dev, + &itag); + if (reason) + return reason; if (our) flags |= RTCF_LOCAL; + if (IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + rth = rt_dst_alloc(dev_net(dev)->loopback_dev, flags, RTN_MULTICAST, - IN_DEV_ORCONF(in_dev, NOPOLICY), false); + false); if (!rth) - return -ENOBUFS; + return SKB_DROP_REASON_NOMEM; #ifdef CONFIG_IP_ROUTE_CLASSID rth->dst.tclassid = itag; @@ -1735,8 +1776,9 @@ static int ip_route_input_mc(struct sk_buff *skb, __be32 daddr, __be32 saddr, #endif RT_CACHE_STAT_INC(in_slow_mc); + skb_dst_drop(skb); skb_dst_set(skb, &rth->dst); - return 0; + return SKB_NOT_DROPPED_YET; } @@ -1766,11 +1808,12 @@ static void ip_handle_martian_source(struct net_device *dev, } /* called in rcu_read_lock() section */ -static int __mkroute_input(struct sk_buff *skb, - const struct fib_result *res, - struct in_device *in_dev, - __be32 daddr, __be32 saddr, u32 tos) +static enum skb_drop_reason +__mkroute_input(struct sk_buff *skb, const struct fib_result *res, + struct in_device *in_dev, __be32 daddr, + __be32 saddr, dscp_t dscp) { + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; struct fib_nh_common *nhc = FIB_RES_NHC(*res); struct net_device *dev = nhc->nhc_dev; struct fib_nh_exception *fnhe; @@ -1784,12 +1827,13 @@ static int __mkroute_input(struct sk_buff *skb, out_dev = __in_dev_get_rcu(dev); if (!out_dev) { net_crit_ratelimited("Bug in ip_route_input_slow(). Please report.\n"); - return -EINVAL; + return reason; } - err = fib_validate_source(skb, saddr, daddr, tos, FIB_RES_OIF(*res), + err = fib_validate_source(skb, saddr, daddr, dscp, FIB_RES_OIF(*res), in_dev->dev, in_dev, &itag); if (err < 0) { + reason = -err; ip_handle_martian_source(in_dev->dev, in_dev, skb, daddr, saddr); @@ -1817,11 +1861,14 @@ static int __mkroute_input(struct sk_buff *skb, */ if (out_dev == in_dev && IN_DEV_PROXY_ARP_PVLAN(in_dev) == 0) { - err = -EINVAL; + reason = SKB_DROP_REASON_ARP_PVLAN_DISABLE; goto cleanup; } } + if (IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + fnhe = find_exception(nhc, daddr); if (do_cache) { if (fnhe) @@ -1835,10 +1882,9 @@ static int __mkroute_input(struct sk_buff *skb, } rth = rt_dst_alloc(out_dev->dev, 0, res->type, - IN_DEV_ORCONF(in_dev, NOPOLICY), IN_DEV_ORCONF(out_dev, NOXFRM)); if (!rth) { - err = -ENOBUFS; + reason = SKB_DROP_REASON_NOMEM; goto cleanup; } @@ -1852,9 +1898,9 @@ static int __mkroute_input(struct sk_buff *skb, lwtunnel_set_redirect(&rth->dst); skb_dst_set(skb, &rth->dst); out: - err = 0; - cleanup: - return err; + reason = SKB_NOT_DROPPED_YET; +cleanup: + return reason; } #ifdef CONFIG_IP_ROUTE_MULTIPATH @@ -1901,7 +1947,7 @@ static u32 fib_multipath_custom_hash_outer(const struct net *net, const struct sk_buff *skb, bool *p_has_inner) { - u32 hash_fields = net->ipv4.sysctl_fib_multipath_hash_fields; + u32 hash_fields = READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_fields); struct flow_keys keys, hash_keys; if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_OUTER_MASK)) @@ -1923,14 +1969,14 @@ static u32 fib_multipath_custom_hash_outer(const struct net *net, hash_keys.ports.dst = keys.ports.dst; *p_has_inner = !!(keys.control.flags & FLOW_DIS_ENCAPSULATION); - return flow_hash_from_keys(&hash_keys); + return fib_multipath_hash_from_keys(net, &hash_keys); } static u32 fib_multipath_custom_hash_inner(const struct net *net, const struct sk_buff *skb, bool has_inner) { - u32 hash_fields = net->ipv4.sysctl_fib_multipath_hash_fields; + u32 hash_fields = READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_fields); struct flow_keys keys, hash_keys; /* We assume the packet carries an encapsulation, but if none was @@ -1972,7 +2018,7 @@ static u32 fib_multipath_custom_hash_inner(const struct net *net, if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_PORT) hash_keys.ports.dst = keys.ports.dst; - return flow_hash_from_keys(&hash_keys); + return fib_multipath_hash_from_keys(net, &hash_keys); } static u32 fib_multipath_custom_hash_skb(const struct net *net, @@ -1990,7 +2036,7 @@ static u32 fib_multipath_custom_hash_skb(const struct net *net, static u32 fib_multipath_custom_hash_fl4(const struct net *net, const struct flowi4 *fl4) { - u32 hash_fields = net->ipv4.sysctl_fib_multipath_hash_fields; + u32 hash_fields = READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_fields); struct flow_keys hash_keys; if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_OUTER_MASK)) @@ -2004,12 +2050,16 @@ static u32 fib_multipath_custom_hash_fl4(const struct net *net, hash_keys.addrs.v4addrs.dst = fl4->daddr; if (hash_fields & FIB_MULTIPATH_HASH_FIELD_IP_PROTO) hash_keys.basic.ip_proto = fl4->flowi4_proto; - if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_PORT) - hash_keys.ports.src = fl4->fl4_sport; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_PORT) { + if (fl4->flowi4_flags & FLOWI_FLAG_ANY_SPORT) + hash_keys.ports.src = (__force __be16)get_random_u16(); + else + hash_keys.ports.src = fl4->fl4_sport; + } if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_PORT) hash_keys.ports.dst = fl4->fl4_dport; - return flow_hash_from_keys(&hash_keys); + return fib_multipath_hash_from_keys(net, &hash_keys); } /* if skb is set it will be used and fl4 can be NULL */ @@ -2020,7 +2070,7 @@ int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, struct flow_keys hash_keys; u32 mhash = 0; - switch (net->ipv4.sysctl_fib_multipath_hash_policy) { + switch (READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_policy)) { case 0: memset(&hash_keys, 0, sizeof(hash_keys)); hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; @@ -2030,7 +2080,7 @@ int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, hash_keys.addrs.v4addrs.src = fl4->saddr; hash_keys.addrs.v4addrs.dst = fl4->daddr; } - mhash = flow_hash_from_keys(&hash_keys); + mhash = fib_multipath_hash_from_keys(net, &hash_keys); break; case 1: /* skb is currently provided only when forwarding */ @@ -2060,11 +2110,14 @@ int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; hash_keys.addrs.v4addrs.src = fl4->saddr; hash_keys.addrs.v4addrs.dst = fl4->daddr; - hash_keys.ports.src = fl4->fl4_sport; + if (fl4->flowi4_flags & FLOWI_FLAG_ANY_SPORT) + hash_keys.ports.src = (__force __be16)get_random_u16(); + else + hash_keys.ports.src = fl4->fl4_sport; hash_keys.ports.dst = fl4->fl4_dport; hash_keys.basic.ip_proto = fl4->flowi4_proto; } - mhash = flow_hash_from_keys(&hash_keys); + mhash = fib_multipath_hash_from_keys(net, &hash_keys); break; case 2: memset(&hash_keys, 0, sizeof(hash_keys)); @@ -2095,7 +2148,7 @@ int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, hash_keys.addrs.v4addrs.src = fl4->saddr; hash_keys.addrs.v4addrs.dst = fl4->daddr; } - mhash = flow_hash_from_keys(&hash_keys); + mhash = fib_multipath_hash_from_keys(net, &hash_keys); break; case 3: if (skb) @@ -2112,62 +2165,72 @@ int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, } #endif /* CONFIG_IP_ROUTE_MULTIPATH */ -static int ip_mkroute_input(struct sk_buff *skb, - struct fib_result *res, - struct in_device *in_dev, - __be32 daddr, __be32 saddr, u32 tos, - struct flow_keys *hkeys) +static enum skb_drop_reason +ip_mkroute_input(struct sk_buff *skb, struct fib_result *res, + struct in_device *in_dev, __be32 daddr, + __be32 saddr, dscp_t dscp, struct flow_keys *hkeys) { #ifdef CONFIG_IP_ROUTE_MULTIPATH if (res->fi && fib_info_num_path(res->fi) > 1) { int h = fib_multipath_hash(res->fi->fib_net, NULL, skb, hkeys); - fib_select_multipath(res, h); + fib_select_multipath(res, h, NULL); + IPCB(skb)->flags |= IPSKB_MULTIPATH; } #endif /* create a routing cache entry */ - return __mkroute_input(skb, res, in_dev, daddr, saddr, tos); + return __mkroute_input(skb, res, in_dev, daddr, saddr, dscp); } /* Implements all the saddr-related checks as ip_route_input_slow(), * assuming daddr is valid and the destination is not a local broadcast one. * Uses the provided hint instead of performing a route lookup. */ -int ip_route_use_hint(struct sk_buff *skb, __be32 daddr, __be32 saddr, - u8 tos, struct net_device *dev, - const struct sk_buff *hint) +enum skb_drop_reason +ip_route_use_hint(struct sk_buff *skb, __be32 daddr, __be32 saddr, + dscp_t dscp, struct net_device *dev, + const struct sk_buff *hint) { + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; struct in_device *in_dev = __in_dev_get_rcu(dev); struct rtable *rt = skb_rtable(hint); struct net *net = dev_net(dev); - int err = -EINVAL; u32 tag = 0; - if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) + if (!in_dev) + return reason; + + if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) { + reason = SKB_DROP_REASON_IP_INVALID_SOURCE; goto martian_source; + } - if (ipv4_is_zeronet(saddr)) + if (ipv4_is_zeronet(saddr)) { + reason = SKB_DROP_REASON_IP_INVALID_SOURCE; goto martian_source; + } - if (ipv4_is_loopback(saddr) && !IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) + if (ipv4_is_loopback(saddr) && !IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) { + reason = SKB_DROP_REASON_IP_LOCALNET; goto martian_source; + } - if (rt->rt_type != RTN_LOCAL) + if (!(rt->rt_flags & RTCF_LOCAL)) goto skip_validate_source; - tos &= IPTOS_RT_MASK; - err = fib_validate_source(skb, saddr, daddr, tos, 0, dev, in_dev, &tag); - if (err < 0) + reason = fib_validate_source_reason(skb, saddr, daddr, dscp, 0, dev, + in_dev, &tag); + if (reason) goto martian_source; skip_validate_source: skb_dst_copy(skb, hint); - return 0; + return SKB_NOT_DROPPED_YET; martian_source: ip_handle_martian_source(dev, in_dev, skb, daddr, saddr); - return err; + return reason; } /* get device for dst_alloc with local routes */ @@ -2196,10 +2259,12 @@ static struct net_device *ip_rt_get_dev(struct net *net, * called with rcu_read_lock() */ -static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, - u8 tos, struct net_device *dev, - struct fib_result *res) +static enum skb_drop_reason +ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, + dscp_t dscp, struct net_device *dev, + struct fib_result *res) { + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; struct in_device *in_dev = __in_dev_get_rcu(dev); struct flow_keys *flkeys = NULL, _flkeys; struct net *net = dev_net(dev); @@ -2227,8 +2292,10 @@ static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, fl4.flowi4_tun_key.tun_id = 0; skb_dst_drop(skb); - if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) + if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) { + reason = SKB_DROP_REASON_IP_INVALID_SOURCE; goto martian_source; + } res->fi = NULL; res->table = NULL; @@ -2238,30 +2305,39 @@ static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, /* Accept zero addresses only to limited broadcast; * I even do not know to fix it or not. Waiting for complains :-) */ - if (ipv4_is_zeronet(saddr)) + if (ipv4_is_zeronet(saddr)) { + reason = SKB_DROP_REASON_IP_INVALID_SOURCE; goto martian_source; + } - if (ipv4_is_zeronet(daddr)) + if (ipv4_is_zeronet(daddr)) { + reason = SKB_DROP_REASON_IP_INVALID_DEST; goto martian_destination; + } /* Following code try to avoid calling IN_DEV_NET_ROUTE_LOCALNET(), * and call it once if daddr or/and saddr are loopback addresses */ if (ipv4_is_loopback(daddr)) { - if (!IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) + if (!IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) { + reason = SKB_DROP_REASON_IP_LOCALNET; goto martian_destination; + } } else if (ipv4_is_loopback(saddr)) { - if (!IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) + if (!IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) { + reason = SKB_DROP_REASON_IP_LOCALNET; goto martian_source; + } } /* * Now we are ready to route packet. */ + fl4.flowi4_l3mdev = 0; fl4.flowi4_oif = 0; fl4.flowi4_iif = dev->ifindex; fl4.flowi4_mark = skb->mark; - fl4.flowi4_tos = tos; + fl4.flowi4_dscp = dscp; fl4.flowi4_scope = RT_SCOPE_UNIVERSE; fl4.flowi4_flags = 0; fl4.daddr = daddr; @@ -2288,15 +2364,16 @@ static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, if (IN_DEV_BFORWARD(in_dev)) goto make_route; /* not do cache if bc_forwarding is enabled */ - if (IPV4_DEVCONF_ALL(net, BC_FORWARDING)) + if (IPV4_DEVCONF_ALL_RO(net, BC_FORWARDING)) do_cache = false; goto brd_input; } + err = -EINVAL; if (res->type == RTN_LOCAL) { - err = fib_validate_source(skb, saddr, daddr, tos, - 0, dev, in_dev, &itag); - if (err < 0) + reason = fib_validate_source_reason(skb, saddr, daddr, dscp, + 0, dev, in_dev, &itag); + if (reason) goto martian_source; goto local_input; } @@ -2305,21 +2382,28 @@ static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, err = -EHOSTUNREACH; goto no_route; } - if (res->type != RTN_UNICAST) + if (res->type != RTN_UNICAST) { + reason = SKB_DROP_REASON_IP_INVALID_DEST; goto martian_destination; + } make_route: - err = ip_mkroute_input(skb, res, in_dev, daddr, saddr, tos, flkeys); -out: return err; + reason = ip_mkroute_input(skb, res, in_dev, daddr, saddr, dscp, + flkeys); + +out: + return reason; brd_input: - if (skb->protocol != htons(ETH_P_IP)) - goto e_inval; + if (skb->protocol != htons(ETH_P_IP)) { + reason = SKB_DROP_REASON_INVALID_PROTO; + goto out; + } if (!ipv4_is_zeronet(saddr)) { - err = fib_validate_source(skb, saddr, 0, tos, 0, dev, - in_dev, &itag); - if (err < 0) + reason = fib_validate_source_reason(skb, saddr, 0, dscp, 0, + dev, in_dev, &itag); + if (reason) goto martian_source; } flags |= RTCF_BROADCAST; @@ -2327,6 +2411,9 @@ brd_input: RT_CACHE_STAT_INC(in_brd); local_input: + if (IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + do_cache &= res->fi && !itag; if (do_cache) { struct fib_nh_common *nhc = FIB_RES_NHC(*res); @@ -2334,14 +2421,13 @@ local_input: rth = rcu_dereference(nhc->nhc_rth_input); if (rt_cache_valid(rth)) { skb_dst_set_noref(skb, &rth->dst); - err = 0; + reason = SKB_NOT_DROPPED_YET; goto out; } } rth = rt_dst_alloc(ip_rt_get_dev(net, res), - flags | RTCF_LOCAL, res->type, - IN_DEV_ORCONF(in_dev, NOPOLICY), false); + flags | RTCF_LOCAL, res->type, false); if (!rth) goto e_nobufs; @@ -2372,7 +2458,7 @@ local_input: rt_add_uncached_list(rth); } skb_dst_set(skb, &rth->dst); - err = 0; + reason = SKB_NOT_DROPPED_YET; goto out; no_route: @@ -2392,13 +2478,10 @@ martian_destination: net_warn_ratelimited("martian destination %pI4 from %pI4, dev %s\n", &daddr, &saddr, dev->name); #endif - -e_inval: - err = -EINVAL; goto out; e_nobufs: - err = -ENOBUFS; + reason = SKB_DROP_REASON_NOMEM; goto out; martian_source: @@ -2406,24 +2489,11 @@ martian_source: goto out; } -int ip_route_input_noref(struct sk_buff *skb, __be32 daddr, __be32 saddr, - u8 tos, struct net_device *dev) -{ - struct fib_result res; - int err; - - tos &= IPTOS_RT_MASK; - rcu_read_lock(); - err = ip_route_input_rcu(skb, daddr, saddr, tos, dev, &res); - rcu_read_unlock(); - - return err; -} -EXPORT_SYMBOL(ip_route_input_noref); - /* called with rcu_read_lock held */ -int ip_route_input_rcu(struct sk_buff *skb, __be32 daddr, __be32 saddr, - u8 tos, struct net_device *dev, struct fib_result *res) +static enum skb_drop_reason +ip_route_input_rcu(struct sk_buff *skb, __be32 daddr, __be32 saddr, + dscp_t dscp, struct net_device *dev, + struct fib_result *res) { /* Multicast recognition logic is moved from route cache to here. * The problem was that too many Ethernet cards have broken/missing @@ -2437,12 +2507,13 @@ int ip_route_input_rcu(struct sk_buff *skb, __be32 daddr, __be32 saddr, * route cache entry is created eventually. */ if (ipv4_is_multicast(daddr)) { + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; struct in_device *in_dev = __in_dev_get_rcu(dev); int our = 0; - int err = -EINVAL; if (!in_dev) - return err; + return reason; + our = ip_check_mc_rcu(in_dev, daddr, saddr, ip_hdr(skb)->protocol); @@ -2463,14 +2534,29 @@ int ip_route_input_rcu(struct sk_buff *skb, __be32 daddr, __be32 saddr, IN_DEV_MFORWARD(in_dev)) #endif ) { - err = ip_route_input_mc(skb, daddr, saddr, - tos, dev, our); + reason = ip_route_input_mc(skb, daddr, saddr, dscp, + dev, our); } - return err; + return reason; } - return ip_route_input_slow(skb, daddr, saddr, tos, dev, res); + return ip_route_input_slow(skb, daddr, saddr, dscp, dev, res); +} + +enum skb_drop_reason ip_route_input_noref(struct sk_buff *skb, __be32 daddr, + __be32 saddr, dscp_t dscp, + struct net_device *dev) +{ + enum skb_drop_reason reason; + struct fib_result res; + + rcu_read_lock(); + reason = ip_route_input_rcu(skb, daddr, saddr, dscp, dev, &res); + rcu_read_unlock(); + + return reason; } +EXPORT_SYMBOL(ip_route_input_noref); /* called with rcu_read_lock() */ static struct rtable *__mkroute_output(const struct fib_result *res, @@ -2495,12 +2581,16 @@ static struct rtable *__mkroute_output(const struct fib_result *res, !netif_is_l3_master(dev_out)) return ERR_PTR(-EINVAL); - if (ipv4_is_lbcast(fl4->daddr)) + if (ipv4_is_lbcast(fl4->daddr)) { type = RTN_BROADCAST; - else if (ipv4_is_multicast(fl4->daddr)) + + /* reset fi to prevent gateway resolution */ + fi = NULL; + } else if (ipv4_is_multicast(fl4->daddr)) { type = RTN_MULTICAST; - else if (ipv4_is_zeronet(fl4->daddr)) + } else if (ipv4_is_zeronet(fl4->daddr)) { return ERR_PTR(-EINVAL); + } if (dev_out->flags & IFF_LOOPBACK) flags |= RTCF_LOCAL; @@ -2508,7 +2598,6 @@ static struct rtable *__mkroute_output(const struct fib_result *res, do_cache = true; if (type == RTN_BROADCAST) { flags |= RTCF_BROADCAST | RTCF_LOCAL; - fi = NULL; } else if (type == RTN_MULTICAST) { flags |= RTCF_MULTICAST | RTCF_LOCAL; if (!ip_check_mc_rcu(in_dev, fl4->daddr, fl4->saddr, @@ -2564,7 +2653,6 @@ static struct rtable *__mkroute_output(const struct fib_result *res, add: rth = rt_dst_alloc(dev_out, flags, type, - IN_DEV_ORCONF(in_dev, NOPOLICY), IN_DEV_ORCONF(in_dev, NOXFRM)); if (!rth) return ERR_PTR(-ENOBUFS); @@ -2584,7 +2672,7 @@ add: if (IN_DEV_MFORWARD(in_dev) && !ipv4_is_local_multicast(fl4->daddr)) { rth->dst.input = ip_mr_input; - rth->dst.output = ip_mc_output; + rth->dst.output = ip_mr_output; } } #endif @@ -2603,7 +2691,6 @@ add: struct rtable *ip_route_output_key_hash(struct net *net, struct flowi4 *fl4, const struct sk_buff *skb) { - __u8 tos = RT_FL_TOS(fl4); struct fib_result res = { .type = RTN_UNSPEC, .fi = NULL, @@ -2613,9 +2700,6 @@ struct rtable *ip_route_output_key_hash(struct net *net, struct flowi4 *fl4, struct rtable *rth; fl4->flowi4_iif = LOOPBACK_IFINDEX; - fl4->flowi4_tos = tos & IPTOS_RT_MASK; - fl4->flowi4_scope = ((tos & RTO_ONLINK) ? - RT_SCOPE_LINK : RT_SCOPE_UNIVERSE); rcu_read_lock(); rth = ip_route_output_key_hash_rcu(net, fl4, &res, skb); @@ -2637,8 +2721,7 @@ struct rtable *ip_route_output_key_hash_rcu(struct net *net, struct flowi4 *fl4, if (fl4->saddr) { if (ipv4_is_multicast(fl4->saddr) || - ipv4_is_lbcast(fl4->saddr) || - ipv4_is_zeronet(fl4->saddr)) { + ipv4_is_lbcast(fl4->saddr)) { rth = ERR_PTR(-EINVAL); goto out; } @@ -2733,8 +2816,7 @@ struct rtable *ip_route_output_key_hash_rcu(struct net *net, struct flowi4 *fl4, res->fi = NULL; res->table = NULL; if (fl4->flowi4_oif && - (ipv4_is_multicast(fl4->daddr) || - !netif_index_is_l3_master(net, fl4->flowi4_oif))) { + (ipv4_is_multicast(fl4->daddr) || !fl4->flowi4_l3mdev)) { /* Apparently, routing tables are wrong. Assume, * that the destination is on link. * @@ -2809,10 +2891,10 @@ static struct dst_ops ipv4_dst_blackhole_ops = { struct dst_entry *ipv4_blackhole_route(struct net *net, struct dst_entry *dst_orig) { - struct rtable *ort = (struct rtable *) dst_orig; + struct rtable *ort = dst_rtable(dst_orig); struct rtable *rt; - rt = dst_alloc(&ipv4_dst_blackhole_ops, NULL, 1, DST_OBSOLETE_DEAD, 0); + rt = dst_alloc(&ipv4_dst_blackhole_ops, NULL, DST_OBSOLETE_DEAD, 0); if (rt) { struct dst_entry *new = &rt->dst; @@ -2821,7 +2903,7 @@ struct dst_entry *ipv4_blackhole_route(struct net *net, struct dst_entry *dst_or new->output = dst_discard_out; new->dev = net->loopback_dev; - dev_hold_track(new->dev, &new->dev_tracker, GFP_ATOMIC); + netdev_hold(new->dev, &new->dev_tracker, GFP_ATOMIC); rt->rt_is_input = ort->rt_is_input; rt->rt_iif = ort->rt_iif; @@ -2837,8 +2919,6 @@ struct dst_entry *ipv4_blackhole_route(struct net *net, struct dst_entry *dst_or rt->rt_gw4 = ort->rt_gw4; else if (rt->rt_gw_family == AF_INET6) rt->rt_gw6 = ort->rt_gw6; - - INIT_LIST_HEAD(&rt->rt_uncached); } dst_release(dst_orig); @@ -2856,68 +2936,20 @@ struct rtable *ip_route_output_flow(struct net *net, struct flowi4 *flp4, if (flp4->flowi4_proto) { flp4->flowi4_oif = rt->dst.dev->ifindex; - rt = (struct rtable *)xfrm_lookup_route(net, &rt->dst, - flowi4_to_flowi(flp4), - sk, 0); + rt = dst_rtable(xfrm_lookup_route(net, &rt->dst, + flowi4_to_flowi(flp4), + sk, 0)); } return rt; } EXPORT_SYMBOL_GPL(ip_route_output_flow); -struct rtable *ip_route_output_tunnel(struct sk_buff *skb, - struct net_device *dev, - struct net *net, __be32 *saddr, - const struct ip_tunnel_info *info, - u8 protocol, bool use_cache) -{ -#ifdef CONFIG_DST_CACHE - struct dst_cache *dst_cache; -#endif - struct rtable *rt = NULL; - struct flowi4 fl4; - __u8 tos; - -#ifdef CONFIG_DST_CACHE - dst_cache = (struct dst_cache *)&info->dst_cache; - if (use_cache) { - rt = dst_cache_get_ip4(dst_cache, saddr); - if (rt) - return rt; - } -#endif - memset(&fl4, 0, sizeof(fl4)); - fl4.flowi4_mark = skb->mark; - fl4.flowi4_proto = protocol; - fl4.daddr = info->key.u.ipv4.dst; - fl4.saddr = info->key.u.ipv4.src; - tos = info->key.tos; - fl4.flowi4_tos = RT_TOS(tos); - - rt = ip_route_output_key(net, &fl4); - if (IS_ERR(rt)) { - netdev_dbg(dev, "no route to %pI4\n", &fl4.daddr); - return ERR_PTR(-ENETUNREACH); - } - if (rt->dst.dev == dev) { /* is this necessary? */ - netdev_dbg(dev, "circular route to %pI4\n", &fl4.daddr); - ip_rt_put(rt); - return ERR_PTR(-ELOOP); - } -#ifdef CONFIG_DST_CACHE - if (use_cache) - dst_cache_set_ip4(dst_cache, &rt->dst, fl4.saddr); -#endif - *saddr = fl4.saddr; - return rt; -} -EXPORT_SYMBOL_GPL(ip_route_output_tunnel); - /* called with rcu_read_lock held */ static int rt_fill_info(struct net *net, __be32 dst, __be32 src, - struct rtable *rt, u32 table_id, struct flowi4 *fl4, - struct sk_buff *skb, u32 portid, u32 seq, - unsigned int flags) + struct rtable *rt, u32 table_id, dscp_t dscp, + struct flowi4 *fl4, struct sk_buff *skb, u32 portid, + u32 seq, unsigned int flags) { struct rtmsg *r; struct nlmsghdr *nlh; @@ -2933,7 +2965,7 @@ static int rt_fill_info(struct net *net, __be32 dst, __be32 src, r->rtm_family = AF_INET; r->rtm_dst_len = 32; r->rtm_src_len = 0; - r->rtm_tos = fl4 ? fl4->flowi4_tos : 0; + r->rtm_tos = inet_dscp_to_dsfield(dscp); r->rtm_table = table_id < 256 ? table_id : RT_TABLE_COMPAT; if (nla_put_u32(skb, RTA_TABLE, table_id)) goto nla_put_failure; @@ -2956,8 +2988,7 @@ static int rt_fill_info(struct net *net, __be32 dst, __be32 src, if (rt->dst.dev && nla_put_u32(skb, RTA_OIF, rt->dst.dev->ifindex)) goto nla_put_failure; - if (rt->dst.lwtstate && - lwtunnel_fill_encap(skb, rt->dst.lwtstate, RTA_ENCAP, RTA_ENCAP_TYPE) < 0) + if (lwtunnel_fill_encap(skb, rt->dst.lwtstate, RTA_ENCAP, RTA_ENCAP_TYPE) < 0) goto nla_put_failure; #ifdef CONFIG_IP_ROUTE_CLASSID if (rt->dst.tclassid && @@ -2988,7 +3019,7 @@ static int rt_fill_info(struct net *net, __be32 dst, __be32 src, } } - expires = rt->dst.expires; + expires = READ_ONCE(rt->dst.expires); if (expires) { unsigned long now = jiffies; @@ -3021,7 +3052,7 @@ static int rt_fill_info(struct net *net, __be32 dst, __be32 src, #ifdef CONFIG_IP_MROUTE if (ipv4_is_multicast(dst) && !ipv4_is_local_multicast(dst) && - IPV4_DEVCONF_ALL(net, MC_FORWARDING)) { + IPV4_DEVCONF_ALL_RO(net, MC_FORWARDING)) { int err = ipmr_get_route(net, skb, fl4->saddr, fl4->daddr, r, portid); @@ -3083,7 +3114,7 @@ static int fnhe_dump_bucket(struct net *net, struct sk_buff *skb, goto next; err = rt_fill_info(net, fnhe->fnhe_daddr, 0, rt, - table_id, NULL, skb, + table_id, 0, NULL, skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, flags); if (err) @@ -3195,7 +3226,8 @@ static int inet_rtm_valid_getroute_req(struct sk_buff *skb, struct rtmsg *rtm; int i, err; - if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + rtm = nlmsg_payload(nlh, sizeof(*rtm)); + if (!rtm) { NL_SET_ERR_MSG(extack, "ipv4: Invalid header for route get request"); return -EINVAL; @@ -3205,7 +3237,6 @@ static int inet_rtm_valid_getroute_req(struct sk_buff *skb, return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_ipv4_policy, extack); - rtm = nlmsg_data(nlh); if ((rtm->rtm_src_len && rtm->rtm_src_len != 32) || (rtm->rtm_dst_len && rtm->rtm_dst_len != 32) || rtm->rtm_table || rtm->rtm_protocol || @@ -3271,6 +3302,7 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct flowi4 fl4 = {}; __be32 dst = 0; __be32 src = 0; + dscp_t dscp; kuid_t uid; u32 iif; int err; @@ -3281,10 +3313,11 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, return err; rtm = nlmsg_data(nlh); - src = tb[RTA_SRC] ? nla_get_in_addr(tb[RTA_SRC]) : 0; - dst = tb[RTA_DST] ? nla_get_in_addr(tb[RTA_DST]) : 0; - iif = tb[RTA_IIF] ? nla_get_u32(tb[RTA_IIF]) : 0; - mark = tb[RTA_MARK] ? nla_get_u32(tb[RTA_MARK]) : 0; + src = nla_get_in_addr_default(tb[RTA_SRC], 0); + dst = nla_get_in_addr_default(tb[RTA_DST], 0); + iif = nla_get_u32_default(tb[RTA_IIF], 0); + mark = nla_get_u32_default(tb[RTA_MARK], 0); + dscp = inet_dsfield_to_dscp(rtm->rtm_tos); if (tb[RTA_UID]) uid = make_kuid(current_user_ns(), nla_get_u32(tb[RTA_UID])); else @@ -3309,8 +3342,8 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, fl4.daddr = dst; fl4.saddr = src; - fl4.flowi4_tos = rtm->rtm_tos & IPTOS_RT_MASK; - fl4.flowi4_oif = tb[RTA_OIF] ? nla_get_u32(tb[RTA_OIF]) : 0; + fl4.flowi4_dscp = dscp; + fl4.flowi4_oif = nla_get_u32_default(tb[RTA_OIF], 0); fl4.flowi4_mark = mark; fl4.flowi4_uid = uid; if (sport) @@ -3333,9 +3366,8 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, fl4.flowi4_iif = iif; /* for rt_fill_info */ skb->dev = dev; skb->mark = mark; - err = ip_route_input_rcu(skb, dst, src, - rtm->rtm_tos & IPTOS_RT_MASK, dev, - &res); + err = ip_route_input_rcu(skb, dst, src, dscp, dev, + &res) ? -EINVAL : 0; rt = skb_rtable(skb); if (err == 0 && rt->dst.error) @@ -3379,7 +3411,7 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, fri.tb_id = table_id; fri.dst = res.prefix; fri.dst_len = res.prefixlen; - fri.tos = fl4.flowi4_tos; + fri.dscp = res.dscp; fri.type = rt->rt_type; fri.offload = 0; fri.trap = 0; @@ -3392,11 +3424,13 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, if (fa->fa_slen == slen && fa->tb_id == fri.tb_id && - fa->fa_tos == fri.tos && + fa->fa_dscp == fri.dscp && fa->fa_info == res.fi && fa->fa_type == fri.type) { - fri.offload = fa->offload; - fri.trap = fa->trap; + fri.offload = READ_ONCE(fa->offload); + fri.trap = READ_ONCE(fa->trap); + fri.offload_failed = + READ_ONCE(fa->offload_failed); break; } } @@ -3404,8 +3438,8 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, err = fib_dump_info(skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, RTM_NEWROUTE, &fri, 0); } else { - err = rt_fill_info(net, dst, src, rt, table_id, &fl4, skb, - NETLINK_CB(in_skb).portid, + err = rt_fill_info(net, dst, src, rt, table_id, res.dscp, &fl4, + skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, 0); } if (err < 0) @@ -3434,7 +3468,7 @@ static int ip_rt_gc_min_interval __read_mostly = HZ / 2; static int ip_rt_gc_elasticity __read_mostly = 8; static int ip_min_valid_pmtu __read_mostly = IPV4_MIN_MTU; -static int ipv4_sysctl_rtcache_flush(struct ctl_table *__ctl, int write, +static int ipv4_sysctl_rtcache_flush(const struct ctl_table *__ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct net *net = (struct net *)__ctl->extra1; @@ -3535,14 +3569,6 @@ static struct ctl_table ipv4_route_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, - { - .procname = "min_adv_mss", - .data = &ip_rt_min_advmss, - .maxlen = sizeof(int), - .mode = 0644, - .proc_handler = proc_dointvec, - }, - { } }; static const char ipv4_route_flush_procname[] = "flush"; @@ -3569,12 +3595,19 @@ static struct ctl_table ipv4_route_netns_table[] = { .mode = 0644, .proc_handler = proc_dointvec_jiffies, }, - { }, + { + .procname = "min_adv_mss", + .data = &init_net.ipv4.ip_rt_min_advmss, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, }; static __net_init int sysctl_route_net_init(struct net *net) { struct ctl_table *tbl; + size_t table_size = ARRAY_SIZE(ipv4_route_netns_table); tbl = ipv4_route_netns_table; if (!net_eq(net, &init_net)) { @@ -3587,18 +3620,19 @@ static __net_init int sysctl_route_net_init(struct net *net) /* Don't export non-whitelisted sysctls to unprivileged users */ if (net->user_ns != &init_user_ns) { if (tbl[0].procname != ipv4_route_flush_procname) - tbl[0].procname = NULL; + table_size = 0; } /* Update the variables to point into the current struct net * except for the first element flush */ - for (i = 1; i < ARRAY_SIZE(ipv4_route_netns_table) - 1; i++) + for (i = 1; i < table_size; i++) tbl[i].data += (void *)net - (void *)&init_net; } tbl[0].extra1 = net; - net->ipv4.route_hdr = register_net_sysctl(net, "net/ipv4/route", tbl); + net->ipv4.route_hdr = register_net_sysctl_sz(net, "net/ipv4/route", + tbl, table_size); if (!net->ipv4.route_hdr) goto err_reg; return 0; @@ -3612,7 +3646,7 @@ err_dup: static __net_exit void sysctl_route_net_exit(struct net *net) { - struct ctl_table *tbl; + const struct ctl_table *tbl; tbl = net->ipv4.route_hdr->ctl_table_arg; unregister_net_sysctl_table(net->ipv4.route_hdr); @@ -3631,6 +3665,7 @@ static __net_init int netns_ip_rt_init(struct net *net) /* Set default value for namespaceified sysctls */ net->ipv4.ip_rt_min_pmtu = DEFAULT_MIN_PMTU; net->ipv4.ip_rt_mtu_expires = DEFAULT_MTU_EXPIRES; + net->ipv4.ip_rt_min_advmss = DEFAULT_MIN_ADVMSS; return 0; } @@ -3642,7 +3677,7 @@ static __net_init int rt_genid_init(struct net *net) { atomic_set(&net->ipv4.rt_genid, 0); atomic_set(&net->fnhe_genid, 0); - atomic_set(&net->ipv4.dev_addr_genid, get_random_int()); + atomic_set(&net->ipv4.dev_addr_genid, get_random_u32()); return 0; } @@ -3679,6 +3714,11 @@ static __net_initdata struct pernet_operations ipv4_inetpeer_ops = { struct ip_rt_acct __percpu *ip_rt_acct __read_mostly; #endif /* CONFIG_IP_ROUTE_CLASSID */ +static const struct rtnl_msg_handler ip_rt_rtnl_msg_handlers[] __initconst = { + {.protocol = PF_INET, .msgtype = RTM_GETROUTE, + .doit = inet_rtm_getroute, .flags = RTNL_FLAG_DOIT_UNLOCKED}, +}; + int __init ip_rt_init(void) { void *idents_hash; @@ -3697,7 +3737,7 @@ int __init ip_rt_init(void) ip_idents = idents_hash; - prandom_bytes(ip_idents, (ip_idents_mask + 1) * sizeof(*ip_idents)); + get_random_bytes(ip_idents, (ip_idents_mask + 1) * sizeof(*ip_idents)); ip_tstamps = idents_hash + (ip_idents_mask + 1) * sizeof(*ip_idents); @@ -3713,9 +3753,8 @@ int __init ip_rt_init(void) panic("IP: failed to allocate ip_rt_acct\n"); #endif - ipv4_dst_ops.kmem_cachep = - kmem_cache_create("ip_dst_cache", sizeof(struct rtable), 0, - SLAB_HWCACHE_ALIGN|SLAB_PANIC, NULL); + ipv4_dst_ops.kmem_cachep = KMEM_CACHE(rtable, + SLAB_HWCACHE_ALIGN | SLAB_PANIC); ipv4_dst_blackhole_ops.kmem_cachep = ipv4_dst_ops.kmem_cachep; @@ -3737,8 +3776,7 @@ int __init ip_rt_init(void) xfrm_init(); xfrm4_init(); #endif - rtnl_register(PF_INET, RTM_GETROUTE, inet_rtm_getroute, NULL, - RTNL_FLAG_DOIT_UNLOCKED); + rtnl_register_many(ip_rt_rtnl_msg_handlers); #ifdef CONFIG_SYSCTL register_pernet_subsys(&sysctl_route_ops); diff --git a/net/ipv4/syncookies.c b/net/ipv4/syncookies.c index 2cb3b852d148..569befcf021b 100644 --- a/net/ipv4/syncookies.c +++ b/net/ipv4/syncookies.c @@ -12,6 +12,7 @@ #include <linux/export.h> #include <net/secure_seq.h> #include <net/tcp.h> +#include <net/tcp_ecn.h> #include <net/route.h> static siphash_aligned_key_t syncookie_secret[2]; @@ -41,7 +42,6 @@ static siphash_aligned_key_t syncookie_secret[2]; * requested/supported by the syn/synack exchange. */ #define TSBITS 6 -#define TSMASK (((__u32)1 << TSBITS) - 1) static u32 cookie_hash(__be32 saddr, __be32 daddr, __be16 sport, __be16 dport, u32 count, int c) @@ -52,7 +52,6 @@ static u32 cookie_hash(__be32 saddr, __be32 daddr, __be16 sport, __be16 dport, count, &syncookie_secret[c]); } - /* * when syncookies are in effect and tcp timestamps are enabled we encode * tcp options in the lower bits of the timestamp value that will be @@ -62,27 +61,24 @@ static u32 cookie_hash(__be32 saddr, __be32 daddr, __be16 sport, __be16 dport, */ u64 cookie_init_timestamp(struct request_sock *req, u64 now) { - struct inet_request_sock *ireq; - u32 ts, ts_now = tcp_ns_to_ts(now); + const struct inet_request_sock *ireq = inet_rsk(req); + u64 ts, ts_now = tcp_ns_to_ts(false, now); u32 options = 0; - ireq = inet_rsk(req); - options = ireq->wscale_ok ? ireq->snd_wscale : TS_OPT_WSCALE_MASK; if (ireq->sack_ok) options |= TS_OPT_SACK; if (ireq->ecn_ok) options |= TS_OPT_ECN; - ts = ts_now & ~TSMASK; + ts = (ts_now >> TSBITS) << TSBITS; ts |= options; - if (ts > ts_now) { - ts >>= TSBITS; - ts--; - ts <<= TSBITS; - ts |= options; - } - return (u64)ts * (NSEC_PER_SEC / TCP_TS_HZ); + if (ts > ts_now) + ts -= (1UL << TSBITS); + + if (tcp_rsk(req)->req_usec_ts) + return ts * NSEC_PER_USEC; + return ts * NSEC_PER_MSEC; } @@ -185,12 +181,14 @@ __u32 cookie_v4_init_sequence(const struct sk_buff *skb, __u16 *mssp) * Check if a ack sequence number is a valid syncookie. * Return the decoded mss if it is, or 0 if not. */ -int __cookie_v4_check(const struct iphdr *iph, const struct tcphdr *th, - u32 cookie) +int __cookie_v4_check(const struct iphdr *iph, const struct tcphdr *th) { + __u32 cookie = ntohl(th->ack_seq) - 1; __u32 seq = ntohl(th->seq) - 1; - __u32 mssind = check_tcp_syn_cookie(cookie, iph->saddr, iph->daddr, - th->source, th->dest, seq); + __u32 mssind; + + mssind = check_tcp_syn_cookie(cookie, iph->saddr, iph->daddr, + th->source, th->dest, seq); return mssind < ARRAY_SIZE(msstab) ? msstab[mssind] : 0; } @@ -198,7 +196,7 @@ EXPORT_SYMBOL_GPL(__cookie_v4_check); struct sock *tcp_get_cookie_sock(struct sock *sk, struct sk_buff *skb, struct request_sock *req, - struct dst_entry *dst, u32 tsoff) + struct dst_entry *dst) { struct inet_connection_sock *icsk = inet_csk(sk); struct sock *child; @@ -208,7 +206,6 @@ struct sock *tcp_get_cookie_sock(struct sock *sk, struct sk_buff *skb, NULL, &own_req); if (child) { refcount_set(&req->rsk_refcnt, 1); - tcp_sk(child)->tsoffset = tsoff; sock_rps_save_rxhash(child, skb); if (rsk_drop_req(req)) { @@ -226,7 +223,7 @@ struct sock *tcp_get_cookie_sock(struct sock *sk, struct sk_buff *skb, return NULL; } -EXPORT_SYMBOL(tcp_get_cookie_sock); +EXPORT_IPV6_MOD(tcp_get_cookie_sock); /* * when syncookies are in effect and tcp timestamps are enabled we stored @@ -247,12 +244,12 @@ bool cookie_timestamp_decode(const struct net *net, return true; } - if (!net->ipv4.sysctl_tcp_timestamps) + if (!READ_ONCE(net->ipv4.sysctl_tcp_timestamps)) return false; tcp_opt->sack_ok = (options & TS_OPT_SACK) ? TCP_SACK_SEEN : 0; - if (tcp_opt->sack_ok && !net->ipv4.sysctl_tcp_sack) + if (tcp_opt->sack_ok && !READ_ONCE(net->ipv4.sysctl_tcp_sack)) return false; if ((options & TS_OPT_WSCALE_MASK) == TS_OPT_WSCALE_MASK) @@ -261,149 +258,194 @@ bool cookie_timestamp_decode(const struct net *net, tcp_opt->wscale_ok = 1; tcp_opt->snd_wscale = options & TS_OPT_WSCALE_MASK; - return net->ipv4.sysctl_tcp_window_scaling != 0; + return READ_ONCE(net->ipv4.sysctl_tcp_window_scaling) != 0; } -EXPORT_SYMBOL(cookie_timestamp_decode); +EXPORT_IPV6_MOD(cookie_timestamp_decode); -bool cookie_ecn_ok(const struct tcp_options_received *tcp_opt, - const struct net *net, const struct dst_entry *dst) +static int cookie_tcp_reqsk_init(struct sock *sk, struct sk_buff *skb, + struct request_sock *req) { - bool ecn_ok = tcp_opt->rcv_tsecr & TS_OPT_ECN; + struct inet_request_sock *ireq = inet_rsk(req); + struct tcp_request_sock *treq = tcp_rsk(req); + const struct tcphdr *th = tcp_hdr(skb); - if (!ecn_ok) - return false; + req->num_retrans = 0; - if (net->ipv4.sysctl_tcp_ecn) - return true; + ireq->ir_num = ntohs(th->dest); + ireq->ir_rmt_port = th->source; + ireq->ir_iif = inet_request_bound_dev_if(sk, skb); + ireq->ir_mark = inet_request_mark(sk, skb); - return dst_feature(dst, RTAX_FEATURE_ECN); + if (IS_ENABLED(CONFIG_SMC)) + ireq->smc_ok = 0; + + treq->snt_synack = 0; + treq->snt_tsval_first = 0; + treq->tfo_listener = false; + treq->txhash = net_tx_rndhash(); + treq->rcv_isn = ntohl(th->seq) - 1; + treq->snt_isn = ntohl(th->ack_seq) - 1; + treq->syn_tos = TCP_SKB_CB(skb)->ip_dsfield; + treq->req_usec_ts = false; + +#if IS_ENABLED(CONFIG_MPTCP) + treq->is_mptcp = sk_is_mptcp(sk); + if (treq->is_mptcp) + return mptcp_subflow_init_cookie_req(req, sk, skb); +#endif + + return 0; } -EXPORT_SYMBOL(cookie_ecn_ok); + +#if IS_ENABLED(CONFIG_BPF) +struct request_sock *cookie_bpf_check(struct sock *sk, struct sk_buff *skb) +{ + struct request_sock *req = inet_reqsk(skb->sk); + + skb->sk = NULL; + skb->destructor = NULL; + + if (cookie_tcp_reqsk_init(sk, skb, req)) { + reqsk_free(req); + req = NULL; + } + + return req; +} +EXPORT_IPV6_MOD_GPL(cookie_bpf_check); +#endif struct request_sock *cookie_tcp_reqsk_alloc(const struct request_sock_ops *ops, - struct sock *sk, - struct sk_buff *skb) + struct sock *sk, struct sk_buff *skb, + struct tcp_options_received *tcp_opt, + int mss, u32 tsoff) { + struct inet_request_sock *ireq; struct tcp_request_sock *treq; struct request_sock *req; -#ifdef CONFIG_MPTCP if (sk_is_mptcp(sk)) - ops = &mptcp_subflow_request_sock_ops; -#endif + req = mptcp_subflow_reqsk_alloc(ops, sk, false); + else + req = inet_reqsk_alloc(ops, sk, false); - req = inet_reqsk_alloc(ops, sk, false); if (!req) return NULL; + if (cookie_tcp_reqsk_init(sk, skb, req)) { + reqsk_free(req); + return NULL; + } + + ireq = inet_rsk(req); treq = tcp_rsk(req); - treq->syn_tos = TCP_SKB_CB(skb)->ip_dsfield; -#if IS_ENABLED(CONFIG_MPTCP) - treq->is_mptcp = sk_is_mptcp(sk); - if (treq->is_mptcp) { - int err = mptcp_subflow_init_cookie_req(req, sk, skb); - if (err) { - reqsk_free(req); - return NULL; - } - } -#endif + req->mss = mss; + req->ts_recent = tcp_opt->saw_tstamp ? tcp_opt->rcv_tsval : 0; + + ireq->snd_wscale = tcp_opt->snd_wscale; + ireq->tstamp_ok = tcp_opt->saw_tstamp; + ireq->sack_ok = tcp_opt->sack_ok; + ireq->wscale_ok = tcp_opt->wscale_ok; + ireq->ecn_ok = !!(tcp_opt->rcv_tsecr & TS_OPT_ECN); + + treq->ts_off = tsoff; return req; } -EXPORT_SYMBOL_GPL(cookie_tcp_reqsk_alloc); +EXPORT_IPV6_MOD_GPL(cookie_tcp_reqsk_alloc); -/* On input, sk is a listener. - * Output is listener if incoming packet would not create a child - * NULL if memory could not be allocated. - */ -struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb) +static struct request_sock *cookie_tcp_check(struct net *net, struct sock *sk, + struct sk_buff *skb) { - struct ip_options *opt = &TCP_SKB_CB(skb)->header.h4.opt; struct tcp_options_received tcp_opt; - struct inet_request_sock *ireq; - struct tcp_request_sock *treq; - struct tcp_sock *tp = tcp_sk(sk); - const struct tcphdr *th = tcp_hdr(skb); - __u32 cookie = ntohl(th->ack_seq) - 1; - struct sock *ret = sk; - struct request_sock *req; - int full_space, mss; - struct rtable *rt; - __u8 rcv_wscale; - struct flowi4 fl4; u32 tsoff = 0; - - if (!sock_net(sk)->ipv4.sysctl_tcp_syncookies || !th->ack || th->rst) - goto out; + int mss; if (tcp_synq_no_recent_overflow(sk)) goto out; - mss = __cookie_v4_check(ip_hdr(skb), th, cookie); - if (mss == 0) { - __NET_INC_STATS(sock_net(sk), LINUX_MIB_SYNCOOKIESFAILED); + mss = __cookie_v4_check(ip_hdr(skb), tcp_hdr(skb)); + if (!mss) { + __NET_INC_STATS(net, LINUX_MIB_SYNCOOKIESFAILED); goto out; } - __NET_INC_STATS(sock_net(sk), LINUX_MIB_SYNCOOKIESRECV); + __NET_INC_STATS(net, LINUX_MIB_SYNCOOKIESRECV); /* check for timestamp cookie support */ memset(&tcp_opt, 0, sizeof(tcp_opt)); - tcp_parse_options(sock_net(sk), skb, &tcp_opt, 0, NULL); + tcp_parse_options(net, skb, &tcp_opt, 0, NULL); if (tcp_opt.saw_tstamp && tcp_opt.rcv_tsecr) { - tsoff = secure_tcp_ts_off(sock_net(sk), + tsoff = secure_tcp_ts_off(net, ip_hdr(skb)->daddr, ip_hdr(skb)->saddr); tcp_opt.rcv_tsecr -= tsoff; } - if (!cookie_timestamp_decode(sock_net(sk), &tcp_opt)) + if (!cookie_timestamp_decode(net, &tcp_opt)) goto out; - ret = NULL; - req = cookie_tcp_reqsk_alloc(&tcp_request_sock_ops, sk, skb); - if (!req) + return cookie_tcp_reqsk_alloc(&tcp_request_sock_ops, sk, skb, + &tcp_opt, mss, tsoff); +out: + return ERR_PTR(-EINVAL); +} + +/* On input, sk is a listener. + * Output is listener if incoming packet would not create a child + * NULL if memory could not be allocated. + */ +struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb) +{ + struct ip_options *opt = &TCP_SKB_CB(skb)->header.h4.opt; + const struct tcphdr *th = tcp_hdr(skb); + struct tcp_sock *tp = tcp_sk(sk); + struct inet_request_sock *ireq; + struct net *net = sock_net(sk); + struct tcp_request_sock *treq; + struct request_sock *req; + struct sock *ret = sk; + struct flowi4 fl4; + struct rtable *rt; + __u8 rcv_wscale; + int full_space; + SKB_DR(reason); + + if (!READ_ONCE(net->ipv4.sysctl_tcp_syncookies) || + !th->ack || th->rst) goto out; + if (cookie_bpf_ok(skb)) { + req = cookie_bpf_check(sk, skb); + } else { + req = cookie_tcp_check(net, sk, skb); + if (IS_ERR(req)) + goto out; + } + if (!req) { + SKB_DR_SET(reason, NO_SOCKET); + goto out_drop; + } + ireq = inet_rsk(req); treq = tcp_rsk(req); - treq->rcv_isn = ntohl(th->seq) - 1; - treq->snt_isn = cookie; - treq->ts_off = 0; - treq->txhash = net_tx_rndhash(); - req->mss = mss; - ireq->ir_num = ntohs(th->dest); - ireq->ir_rmt_port = th->source; + sk_rcv_saddr_set(req_to_sk(req), ip_hdr(skb)->daddr); sk_daddr_set(req_to_sk(req), ip_hdr(skb)->saddr); - ireq->ir_mark = inet_request_mark(sk, skb); - ireq->snd_wscale = tcp_opt.snd_wscale; - ireq->sack_ok = tcp_opt.sack_ok; - ireq->wscale_ok = tcp_opt.wscale_ok; - ireq->tstamp_ok = tcp_opt.saw_tstamp; - req->ts_recent = tcp_opt.saw_tstamp ? tcp_opt.rcv_tsval : 0; - treq->snt_synack = 0; - treq->tfo_listener = false; - - if (IS_ENABLED(CONFIG_SMC)) - ireq->smc_ok = 0; - - ireq->ir_iif = inet_request_bound_dev_if(sk, skb); /* We throwed the options of the initial SYN away, so we hope * the ACK carries the same options again (see RFC1122 4.2.3.8) */ - RCU_INIT_POINTER(ireq->ireq_opt, tcp_v4_save_options(sock_net(sk), skb)); + RCU_INIT_POINTER(ireq->ireq_opt, tcp_v4_save_options(net, skb)); if (security_inet_conn_request(sk, skb, req)) { - reqsk_free(req); - goto out; + SKB_DR_SET(reason, SECURITY_HOOK); + goto out_free; } - req->num_retrans = 0; + tcp_ao_syncookie(sk, skb, req, AF_INET); /* * We need to lookup the route here to get at the correct @@ -412,19 +454,21 @@ struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb) * no easy way to do this. */ flowi4_init_output(&fl4, ireq->ir_iif, ireq->ir_mark, - RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, IPPROTO_TCP, - inet_sk_flowi_flags(sk), + ip_sock_rt_tos(sk), ip_sock_rt_scope(sk), + IPPROTO_TCP, inet_sk_flowi_flags(sk), opt->srr ? opt->faddr : ireq->ir_rmt_addr, - ireq->ir_loc_addr, th->source, th->dest, sk->sk_uid); + ireq->ir_loc_addr, th->source, th->dest, + sk_uid(sk)); security_req_classify_flow(req, flowi4_to_flowi_common(&fl4)); - rt = ip_route_output_key(sock_net(sk), &fl4); + rt = ip_route_output_key(net, &fl4); if (IS_ERR(rt)) { - reqsk_free(req); - goto out; + SKB_DR_SET(reason, IP_OUTNOROUTES); + goto out_free; } /* Try to redo what tcp_v4_send_synack did. */ - req->rsk_window_clamp = tp->window_clamp ? :dst_metric(&rt->dst, RTAX_WINDOW); + req->rsk_window_clamp = READ_ONCE(tp->window_clamp) ? : + dst_metric(&rt->dst, RTAX_WINDOW); /* limit the window selection if the user enforce a smaller rx buffer */ full_space = tcp_full_space(sk); if (sk->sk_userlocks & SOCK_RCVBUF_LOCK && @@ -436,14 +480,28 @@ struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb) ireq->wscale_ok, &rcv_wscale, dst_metric(&rt->dst, RTAX_INITRWND)); - ireq->rcv_wscale = rcv_wscale; - ireq->ecn_ok = cookie_ecn_ok(&tcp_opt, sock_net(sk), &rt->dst); + /* req->syncookie is set true only if ACK is validated + * by BPF kfunc, then, rcv_wscale is already configured. + */ + if (!req->syncookie) + ireq->rcv_wscale = rcv_wscale; + ireq->ecn_ok &= cookie_ecn_ok(net, &rt->dst); + treq->accecn_ok = ireq->ecn_ok && cookie_accecn_ok(th); - ret = tcp_get_cookie_sock(sk, skb, req, &rt->dst, tsoff); + ret = tcp_get_cookie_sock(sk, skb, req, &rt->dst); /* ip_queue_xmit() depends on our flow being setup * Normal sockets get it right from inet_csk_route_child_sock() */ - if (ret) - inet_sk(ret)->cork.fl.u.ip4 = fl4; -out: return ret; + if (!ret) { + SKB_DR_SET(reason, NO_SOCKET); + goto out_drop; + } + inet_sk(ret)->cork.fl.u.ip4 = fl4; +out: + return ret; +out_free: + reqsk_free(req); +out_drop: + sk_skb_reason_drop(sk, skb, reason); + return NULL; } diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index 97eb54774924..a1a50a5c80dc 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -20,54 +20,57 @@ #include <net/protocol.h> #include <net/netevent.h> -static int two = 2; -static int three __maybe_unused = 3; -static int four = 4; -static int thousand = 1000; static int tcp_retr1_max = 255; static int ip_local_port_range_min[] = { 1, 1 }; static int ip_local_port_range_max[] = { 65535, 65535 }; static int tcp_adv_win_scale_min = -31; static int tcp_adv_win_scale_max = 31; +static int tcp_app_win_max = 31; static int tcp_min_snd_mss_min = TCP_MIN_SND_MSS; static int tcp_min_snd_mss_max = 65535; +static int tcp_rto_max_max = TCP_RTO_MAX_SEC * MSEC_PER_SEC; static int ip_privileged_port_min; static int ip_privileged_port_max = 65535; static int ip_ttl_min = 1; static int ip_ttl_max = 255; static int tcp_syn_retries_min = 1; static int tcp_syn_retries_max = MAX_TCP_SYNCNT; -static int ip_ping_group_range_min[] = { 0, 0 }; -static int ip_ping_group_range_max[] = { GID_T_MAX, GID_T_MAX }; +static int tcp_syn_linear_timeouts_max = MAX_TCP_SYNCNT; +static unsigned long ip_ping_group_range_min[] = { 0, 0 }; +static unsigned long ip_ping_group_range_max[] = { GID_T_MAX, GID_T_MAX }; static u32 u32_max_div_HZ = UINT_MAX / HZ; static int one_day_secs = 24 * 3600; static u32 fib_multipath_hash_fields_all_mask __maybe_unused = FIB_MULTIPATH_HASH_FIELD_ALL_MASK; +static unsigned int tcp_child_ehash_entries_max = 16 * 1024 * 1024; +static unsigned int udp_child_hash_entries_max = UDP_HTABLE_SIZE_MAX; +static int tcp_plb_max_rounds = 31; +static int tcp_plb_max_cong_thresh = 256; +static unsigned int tcp_tw_reuse_delay_max = TCP_PAWS_MSL * MSEC_PER_SEC; +static int tcp_ecn_mode_max = 2; +static u32 icmp_errors_extension_mask_all = + GENMASK_U8(ICMP_ERR_EXT_COUNT - 1, 0); /* obsolete */ static int sysctl_tcp_low_latency __read_mostly; /* Update system visible IP port range */ -static void set_local_port_range(struct net *net, int range[2]) +static void set_local_port_range(struct net *net, unsigned int low, unsigned int high) { - bool same_parity = !((range[0] ^ range[1]) & 1); + bool same_parity = !((low ^ high) & 1); - write_seqlock_bh(&net->ipv4.ip_local_ports.lock); if (same_parity && !net->ipv4.ip_local_ports.warned) { net->ipv4.ip_local_ports.warned = true; pr_err_ratelimited("ip_local_port_range: prefer different parity for start/end values.\n"); } - net->ipv4.ip_local_ports.range[0] = range[0]; - net->ipv4.ip_local_ports.range[1] = range[1]; - write_sequnlock_bh(&net->ipv4.ip_local_ports.lock); + WRITE_ONCE(net->ipv4.ip_local_ports.range, high << 16 | low); } /* Validate changes from /proc interface. */ -static int ipv4_local_port_range(struct ctl_table *table, int write, +static int ipv4_local_port_range(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { - struct net *net = - container_of(table->data, struct net, ipv4.ip_local_ports.range); + struct net *net = table->data; int ret; int range[2]; struct ctl_table tmp = { @@ -88,17 +91,17 @@ static int ipv4_local_port_range(struct ctl_table *table, int write, * port limit. */ if ((range[1] < range[0]) || - (range[0] < net->ipv4.sysctl_ip_prot_sock)) + (range[0] < READ_ONCE(net->ipv4.sysctl_ip_prot_sock))) ret = -EINVAL; else - set_local_port_range(net, range); + set_local_port_range(net, range[0], range[1]); } return ret; } /* Validate changes from /proc interface. */ -static int ipv4_privileged_ports(struct ctl_table *table, int write, +static int ipv4_privileged_ports(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct net *net = container_of(table->data, struct net, @@ -114,7 +117,7 @@ static int ipv4_privileged_ports(struct ctl_table *table, int write, .extra2 = &ip_privileged_port_max, }; - pports = net->ipv4.sysctl_ip_prot_sock; + pports = READ_ONCE(net->ipv4.sysctl_ip_prot_sock); ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); @@ -126,13 +129,14 @@ static int ipv4_privileged_ports(struct ctl_table *table, int write, if (range[0] < pports) ret = -EINVAL; else - net->ipv4.sysctl_ip_prot_sock = pports; + WRITE_ONCE(net->ipv4.sysctl_ip_prot_sock, pports); } return ret; } -static void inet_get_ping_group_range_table(struct ctl_table *table, kgid_t *low, kgid_t *high) +static void inet_get_ping_group_range_table(const struct ctl_table *table, + kgid_t *low, kgid_t *high) { kgid_t *data = table->data; struct net *net = @@ -147,7 +151,8 @@ static void inet_get_ping_group_range_table(struct ctl_table *table, kgid_t *low } /* Update system visible IP port range */ -static void set_ping_group_range(struct ctl_table *table, kgid_t low, kgid_t high) +static void set_ping_group_range(const struct ctl_table *table, + kgid_t low, kgid_t high) { kgid_t *data = table->data; struct net *net = @@ -159,12 +164,12 @@ static void set_ping_group_range(struct ctl_table *table, kgid_t low, kgid_t hig } /* Validate changes from /proc interface. */ -static int ipv4_ping_group_range(struct ctl_table *table, int write, +static int ipv4_ping_group_range(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct user_namespace *user_ns = current_user_ns(); int ret; - gid_t urange[2]; + unsigned long urange[2]; kgid_t low, high; struct ctl_table tmp = { .data = &urange, @@ -177,7 +182,7 @@ static int ipv4_ping_group_range(struct ctl_table *table, int write, inet_get_ping_group_range_table(table, &low, &high); urange[0] = from_kgid_munged(user_ns, low); urange[1] = from_kgid_munged(user_ns, high); - ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + ret = proc_doulongvec_minmax(&tmp, write, buffer, lenp, ppos); if (write && ret == 0) { low = make_kgid(user_ns, urange[0]); @@ -194,7 +199,7 @@ static int ipv4_ping_group_range(struct ctl_table *table, int write, return ret; } -static int ipv4_fwd_update_priority(struct ctl_table *table, int write, +static int ipv4_fwd_update_priority(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct net *net; @@ -210,7 +215,7 @@ static int ipv4_fwd_update_priority(struct ctl_table *table, int write, return ret; } -static int proc_tcp_congestion_control(struct ctl_table *ctl, int write, +static int proc_tcp_congestion_control(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct net *net = container_of(ctl->data, struct net, @@ -230,7 +235,7 @@ static int proc_tcp_congestion_control(struct ctl_table *ctl, int write, return ret; } -static int proc_tcp_available_congestion_control(struct ctl_table *ctl, +static int proc_tcp_available_congestion_control(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { @@ -246,7 +251,7 @@ static int proc_tcp_available_congestion_control(struct ctl_table *ctl, return ret; } -static int proc_allowed_congestion_control(struct ctl_table *ctl, +static int proc_allowed_congestion_control(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { @@ -283,7 +288,7 @@ static int sscanf_key(char *buf, __le32 *key) return ret; } -static int proc_tcp_fastopen_key(struct ctl_table *table, int write, +static int proc_tcp_fastopen_key(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct net *net = container_of(table->data, struct net, @@ -354,62 +359,7 @@ bad_key: return ret; } -static void proc_configure_early_demux(int enabled, int protocol) -{ - struct net_protocol *ipprot; -#if IS_ENABLED(CONFIG_IPV6) - struct inet6_protocol *ip6prot; -#endif - - rcu_read_lock(); - - ipprot = rcu_dereference(inet_protos[protocol]); - if (ipprot) - ipprot->early_demux = enabled ? ipprot->early_demux_handler : - NULL; - -#if IS_ENABLED(CONFIG_IPV6) - ip6prot = rcu_dereference(inet6_protos[protocol]); - if (ip6prot) - ip6prot->early_demux = enabled ? ip6prot->early_demux_handler : - NULL; -#endif - rcu_read_unlock(); -} - -static int proc_tcp_early_demux(struct ctl_table *table, int write, - void *buffer, size_t *lenp, loff_t *ppos) -{ - int ret = 0; - - ret = proc_dou8vec_minmax(table, write, buffer, lenp, ppos); - - if (write && !ret) { - int enabled = init_net.ipv4.sysctl_tcp_early_demux; - - proc_configure_early_demux(enabled, IPPROTO_TCP); - } - - return ret; -} - -static int proc_udp_early_demux(struct ctl_table *table, int write, - void *buffer, size_t *lenp, loff_t *ppos) -{ - int ret = 0; - - ret = proc_dou8vec_minmax(table, write, buffer, lenp, ppos); - - if (write && !ret) { - int enabled = init_net.ipv4.sysctl_udp_early_demux; - - proc_configure_early_demux(enabled, IPPROTO_UDP); - } - - return ret; -} - -static int proc_tfo_blackhole_detect_timeout(struct ctl_table *table, +static int proc_tfo_blackhole_detect_timeout(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { @@ -424,7 +374,7 @@ static int proc_tfo_blackhole_detect_timeout(struct ctl_table *table, return ret; } -static int proc_tcp_available_ulp(struct ctl_table *ctl, +static int proc_tcp_available_ulp(const struct ctl_table *ctl, int write, void *buffer, size_t *lenp, loff_t *ppos) { @@ -441,8 +391,55 @@ static int proc_tcp_available_ulp(struct ctl_table *ctl, return ret; } +static int proc_tcp_ehash_entries(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_tcp_child_ehash_entries); + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; + int tcp_ehash_entries; + struct ctl_table tbl; + + tcp_ehash_entries = hinfo->ehash_mask + 1; + + /* A negative number indicates that the child netns + * shares the global ehash. + */ + if (!net_eq(net, &init_net) && !hinfo->pernet) + tcp_ehash_entries *= -1; + + memset(&tbl, 0, sizeof(tbl)); + tbl.data = &tcp_ehash_entries; + tbl.maxlen = sizeof(int); + + return proc_dointvec(&tbl, write, buffer, lenp, ppos); +} + +static int proc_udp_hash_entries(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_udp_child_hash_entries); + int udp_hash_entries; + struct ctl_table tbl; + + udp_hash_entries = net->ipv4.udp_table->mask + 1; + + /* A negative number indicates that the child netns + * shares the global udp_table. + */ + if (!net_eq(net, &init_net) && net->ipv4.udp_table == &udp_table) + udp_hash_entries *= -1; + + memset(&tbl, 0, sizeof(tbl)); + tbl.data = &udp_hash_entries; + tbl.maxlen = sizeof(int); + + return proc_dointvec(&tbl, write, buffer, lenp, ppos); +} + #ifdef CONFIG_IP_ROUTE_MULTIPATH -static int proc_fib_multipath_hash_policy(struct ctl_table *table, int write, +static int proc_fib_multipath_hash_policy(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { @@ -457,7 +454,7 @@ static int proc_fib_multipath_hash_policy(struct ctl_table *table, int write, return ret; } -static int proc_fib_multipath_hash_fields(struct ctl_table *table, int write, +static int proc_fib_multipath_hash_fields(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { @@ -472,6 +469,61 @@ static int proc_fib_multipath_hash_fields(struct ctl_table *table, int write, return ret; } + +static u32 proc_fib_multipath_hash_rand_seed __ro_after_init; + +static void proc_fib_multipath_hash_init_rand_seed(void) +{ + get_random_bytes(&proc_fib_multipath_hash_rand_seed, + sizeof(proc_fib_multipath_hash_rand_seed)); +} + +static void proc_fib_multipath_hash_set_seed(struct net *net, u32 user_seed) +{ + struct sysctl_fib_multipath_hash_seed new = { + .user_seed = user_seed, + .mp_seed = (user_seed ? user_seed : + proc_fib_multipath_hash_rand_seed), + }; + + WRITE_ONCE(net->ipv4.sysctl_fib_multipath_hash_seed, new); +} + +static int proc_fib_multipath_hash_seed(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + struct sysctl_fib_multipath_hash_seed *mphs; + struct net *net = table->data; + struct ctl_table tmp; + u32 user_seed; + int ret; + + mphs = &net->ipv4.sysctl_fib_multipath_hash_seed; + user_seed = mphs->user_seed; + + tmp = *table; + tmp.data = &user_seed; + + ret = proc_douintvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) { + proc_fib_multipath_hash_set_seed(net, user_seed); + call_netevent_notifiers(NETEVENT_IPV4_MPATH_HASH_UPDATE, net); + } + + return ret; +} +#else + +static void proc_fib_multipath_hash_init_rand_seed(void) +{ +} + +static void proc_fib_multipath_hash_set_seed(struct net *net, u32 user_seed) +{ +} + #endif static struct ctl_table ipv4_table[] = { @@ -554,22 +606,6 @@ static struct ctl_table ipv4_table[] = { .proc_handler = proc_tcp_available_ulp, }, { - .procname = "icmp_msgs_per_sec", - .data = &sysctl_icmp_msgs_per_sec, - .maxlen = sizeof(int), - .mode = 0644, - .proc_handler = proc_dointvec_minmax, - .extra1 = SYSCTL_ZERO, - }, - { - .procname = "icmp_msgs_burst", - .data = &sysctl_icmp_msgs_burst, - .maxlen = sizeof(int), - .mode = 0644, - .proc_handler = proc_dointvec_minmax, - .extra1 = SYSCTL_ZERO, - }, - { .procname = "udp_mem", .data = &sysctl_udp_mem, .maxlen = sizeof(sysctl_udp_mem), @@ -585,16 +621,24 @@ static struct ctl_table ipv4_table[] = { .extra1 = &sysctl_fib_sync_mem_min, .extra2 = &sysctl_fib_sync_mem_max, }, - { } }; static struct ctl_table ipv4_net_table[] = { { + .procname = "tcp_max_tw_buckets", + .data = &init_net.ipv4.tcp_death_row.sysctl_max_tw_buckets, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { .procname = "icmp_echo_ignore_all", .data = &init_net.ipv4.sysctl_icmp_echo_ignore_all, .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE }, { .procname = "icmp_echo_enable_probe", @@ -611,6 +655,8 @@ static struct ctl_table ipv4_net_table[] = { .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE }, { .procname = "icmp_ignore_bogus_error_responses", @@ -618,6 +664,8 @@ static struct ctl_table ipv4_net_table[] = { .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE }, { .procname = "icmp_errors_use_inbound_ifaddr", @@ -625,6 +673,17 @@ static struct ctl_table ipv4_net_table[] = { .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "icmp_errors_extension_mask", + .data = &init_net.ipv4.sysctl_icmp_errors_extension_mask, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &icmp_errors_extension_mask_all, }, { .procname = "icmp_ratelimit", @@ -641,6 +700,22 @@ static struct ctl_table ipv4_net_table[] = { .proc_handler = proc_dointvec }, { + .procname = "icmp_msgs_per_sec", + .data = &init_net.ipv4.sysctl_icmp_msgs_per_sec, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "icmp_msgs_burst", + .data = &init_net.ipv4.sysctl_icmp_msgs_burst, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { .procname = "ping_group_range", .data = &init_net.ipv4.ping_group_range.range, .maxlen = sizeof(gid_t)*2, @@ -664,6 +739,26 @@ static struct ctl_table ipv4_net_table[] = { .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_ecn_mode_max, + }, + { + .procname = "tcp_ecn_option", + .data = &init_net.ipv4.sysctl_tcp_ecn_option, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "tcp_ecn_option_beacon", + .data = &init_net.ipv4.sysctl_tcp_ecn_option_beacon, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_THREE, }, { .procname = "tcp_ecn_fallback", @@ -671,6 +766,8 @@ static struct ctl_table ipv4_net_table[] = { .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, }, { .procname = "ip_dynaddr", @@ -691,14 +788,14 @@ static struct ctl_table ipv4_net_table[] = { .data = &init_net.ipv4.sysctl_udp_early_demux, .maxlen = sizeof(u8), .mode = 0644, - .proc_handler = proc_udp_early_demux + .proc_handler = proc_dou8vec_minmax, }, { .procname = "tcp_early_demux", .data = &init_net.ipv4.sysctl_tcp_early_demux, .maxlen = sizeof(u8), .mode = 0644, - .proc_handler = proc_tcp_early_demux + .proc_handler = proc_dou8vec_minmax, }, { .procname = "nexthop_compat_mode", @@ -720,8 +817,8 @@ static struct ctl_table ipv4_net_table[] = { }, { .procname = "ip_local_port_range", - .maxlen = sizeof(init_net.ipv4.ip_local_ports.range), - .data = &init_net.ipv4.ip_local_ports.range, + .maxlen = 0, + .data = &init_net, .mode = 0644, .proc_handler = ipv4_local_port_range, }, @@ -998,14 +1095,16 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = proc_dou8vec_minmax, .extra1 = SYSCTL_ZERO, - .extra2 = &two, + .extra2 = SYSCTL_TWO, }, { - .procname = "tcp_max_tw_buckets", - .data = &init_net.ipv4.tcp_death_row.sysctl_max_tw_buckets, - .maxlen = sizeof(int), + .procname = "tcp_tw_reuse_delay", + .data = &init_net.ipv4.sysctl_tcp_tw_reuse_delay, + .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = proc_dointvec + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &tcp_tw_reuse_delay_max, }, { .procname = "tcp_max_syn_backlog", @@ -1058,7 +1157,7 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = proc_fib_multipath_hash_policy, .extra1 = SYSCTL_ZERO, - .extra2 = &three, + .extra2 = SYSCTL_THREE, }, { .procname = "fib_multipath_hash_fields", @@ -1069,6 +1168,13 @@ static struct ctl_table ipv4_net_table[] = { .extra1 = SYSCTL_ONE, .extra2 = &fib_multipath_hash_fields_all_mask, }, + { + .procname = "fib_multipath_hash_seed", + .data = &init_net, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_fib_multipath_hash_seed, + }, #endif { .procname = "ip_unprivileged_port_start", @@ -1116,7 +1222,7 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = proc_dou8vec_minmax, .extra1 = SYSCTL_ZERO, - .extra2 = &four, + .extra2 = SYSCTL_FOUR, }, { .procname = "tcp_recovery", @@ -1194,6 +1300,8 @@ static struct ctl_table ipv4_net_table[] = { .maxlen = sizeof(u8), .mode = 0644, .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_app_win_max, }, { .procname = "tcp_adv_win_scale", @@ -1235,6 +1343,15 @@ static struct ctl_table ipv4_net_table[] = { .proc_handler = proc_dou8vec_minmax, }, { + .procname = "tcp_rcvbuf_low_rtt", + .data = &init_net.ipv4.sysctl_tcp_rcvbuf_low_rtt, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_INT_MAX, + }, + { .procname = "tcp_tso_win_divisor", .data = &init_net.ipv4.sysctl_tcp_tso_win_divisor, .maxlen = sizeof(u8), @@ -1271,6 +1388,13 @@ static struct ctl_table ipv4_net_table[] = { .extra1 = SYSCTL_ONE, }, { + .procname = "tcp_tso_rtt_log", + .data = &init_net.ipv4.sysctl_tcp_tso_rtt_log, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { .procname = "tcp_min_rtt_wlen", .data = &init_net.ipv4.sysctl_tcp_min_rtt_wlen, .maxlen = sizeof(int), @@ -1302,7 +1426,7 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = proc_dointvec_minmax, .extra1 = SYSCTL_ZERO, - .extra2 = &thousand, + .extra2 = SYSCTL_ONE_THOUSAND, }, { .procname = "tcp_pacing_ca_ratio", @@ -1311,7 +1435,7 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = proc_dointvec_minmax, .extra1 = SYSCTL_ZERO, - .extra2 = &thousand, + .extra2 = SYSCTL_ONE_THOUSAND, }, { .procname = "tcp_wmem", @@ -1337,6 +1461,15 @@ static struct ctl_table ipv4_net_table[] = { .proc_handler = proc_doulongvec_minmax, }, { + .procname = "tcp_comp_sack_rtt_percent", + .data = &init_net.ipv4.sysctl_tcp_comp_sack_rtt_percent, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = SYSCTL_ONE_THOUSAND, + }, + { .procname = "tcp_comp_sack_slack_ns", .data = &init_net.ipv4.sysctl_tcp_comp_sack_slack_ns, .maxlen = sizeof(unsigned long), @@ -1352,6 +1485,15 @@ static struct ctl_table ipv4_net_table[] = { .extra1 = SYSCTL_ZERO, }, { + .procname = "tcp_backlog_ack_defer", + .data = &init_net.ipv4.sysctl_tcp_backlog_ack_defer, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { .procname = "tcp_reflect_tos", .data = &init_net.ipv4.sysctl_tcp_reflect_tos, .maxlen = sizeof(u8), @@ -1361,6 +1503,36 @@ static struct ctl_table ipv4_net_table[] = { .extra2 = SYSCTL_ONE, }, { + .procname = "tcp_ehash_entries", + .data = &init_net.ipv4.sysctl_tcp_child_ehash_entries, + .mode = 0444, + .proc_handler = proc_tcp_ehash_entries, + }, + { + .procname = "tcp_child_ehash_entries", + .data = &init_net.ipv4.sysctl_tcp_child_ehash_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_child_ehash_entries_max, + }, + { + .procname = "udp_hash_entries", + .data = &init_net.ipv4.sysctl_udp_child_hash_entries, + .mode = 0444, + .proc_handler = proc_udp_hash_entries, + }, + { + .procname = "udp_child_hash_entries", + .data = &init_net.ipv4.sysctl_udp_child_hash_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &udp_child_hash_entries_max, + }, + { .procname = "udp_rmem_min", .data = &init_net.ipv4.sysctl_udp_rmem_min, .maxlen = sizeof(init_net.ipv4.sysctl_udp_rmem_min), @@ -1383,13 +1555,97 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = proc_dou8vec_minmax, .extra1 = SYSCTL_ZERO, - .extra2 = &two, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "tcp_plb_enabled", + .data = &init_net.ipv4.sysctl_tcp_plb_enabled, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_plb_idle_rehash_rounds", + .data = &init_net.ipv4.sysctl_tcp_plb_idle_rehash_rounds, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_plb_max_rounds, + }, + { + .procname = "tcp_plb_rehash_rounds", + .data = &init_net.ipv4.sysctl_tcp_plb_rehash_rounds, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_plb_max_rounds, + }, + { + .procname = "tcp_plb_suspend_rto_sec", + .data = &init_net.ipv4.sysctl_tcp_plb_suspend_rto_sec, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_plb_cong_thresh", + .data = &init_net.ipv4.sysctl_tcp_plb_cong_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_plb_max_cong_thresh, + }, + { + .procname = "tcp_syn_linear_timeouts", + .data = &init_net.ipv4.sysctl_tcp_syn_linear_timeouts, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_syn_linear_timeouts_max, + }, + { + .procname = "tcp_shrink_window", + .data = &init_net.ipv4.sysctl_tcp_shrink_window, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_pingpong_thresh", + .data = &init_net.ipv4.sysctl_tcp_pingpong_thresh, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "tcp_rto_min_us", + .data = &init_net.ipv4.sysctl_tcp_rto_min_us, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "tcp_rto_max_ms", + .data = &init_net.ipv4.sysctl_tcp_rto_max_ms, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE_THOUSAND, + .extra2 = &tcp_rto_max_max, }, - { } }; static __net_init int ipv4_sysctl_init_net(struct net *net) { + size_t table_size = ARRAY_SIZE(ipv4_net_table); struct ctl_table *table; table = ipv4_net_table; @@ -1400,7 +1656,7 @@ static __net_init int ipv4_sysctl_init_net(struct net *net) if (!table) goto err_alloc; - for (i = 0; i < ARRAY_SIZE(ipv4_net_table) - 1; i++) { + for (i = 0; i < table_size; i++) { if (table[i].data) { /* Update the variables to point into * the current struct net @@ -1415,7 +1671,8 @@ static __net_init int ipv4_sysctl_init_net(struct net *net) } } - net->ipv4.ipv4_hdr = register_net_sysctl(net, "net/ipv4", table); + net->ipv4.ipv4_hdr = register_net_sysctl_sz(net, "net/ipv4", table, + table_size); if (!net->ipv4.ipv4_hdr) goto err_reg; @@ -1423,6 +1680,8 @@ static __net_init int ipv4_sysctl_init_net(struct net *net) if (!net->ipv4.sysctl_local_reserved_ports) goto err_ports; + proc_fib_multipath_hash_set_seed(net, 0); + return 0; err_ports: @@ -1436,7 +1695,7 @@ err_alloc: static __net_exit void ipv4_sysctl_exit_net(struct net *net) { - struct ctl_table *table; + const struct ctl_table *table; kfree(net->ipv4.sysctl_local_reserved_ports); table = net->ipv4.ipv4_hdr->ctl_table_arg; @@ -1457,6 +1716,8 @@ static __init int sysctl_ipv4_init(void) if (!hdr) return -ENOMEM; + proc_fib_multipath_hash_init_rand_seed(); + if (register_pernet_subsys(&ipv4_sysctl_ops)) { unregister_net_sysctl_table(hdr); return -ENOMEM; diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 3b75836db19b..f035440c475a 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -243,7 +243,7 @@ #define pr_fmt(fmt) "TCP: " fmt -#include <crypto/hash.h> +#include <crypto/md5.h> #include <linux/kernel.h> #include <linux/module.h> #include <linux/types.h> @@ -253,7 +253,6 @@ #include <linux/init.h> #include <linux/fs.h> #include <linux/skbuff.h> -#include <linux/scatterlist.h> #include <linux/splice.h> #include <linux/net.h> #include <linux/socket.h> @@ -270,15 +269,25 @@ #include <net/icmp.h> #include <net/inet_common.h> +#include <net/inet_ecn.h> #include <net/tcp.h> +#include <net/tcp_ecn.h> #include <net/mptcp.h> +#include <net/proto_memory.h> #include <net/xfrm.h> #include <net/ip.h> +#include <net/psp.h> #include <net/sock.h> +#include <net/rstreason.h> #include <linux/uaccess.h> #include <asm/ioctls.h> #include <net/busy_poll.h> +#include <net/hotdata.h> +#include <trace/events/tcp.h> +#include <net/rps.h> + +#include "../core/devmem.h" /* Track pending CMSGs. */ enum { @@ -289,11 +298,14 @@ enum { DEFINE_PER_CPU(unsigned int, tcp_orphan_count); EXPORT_PER_CPU_SYMBOL_GPL(tcp_orphan_count); +DEFINE_PER_CPU(u32, tcp_tw_isn); +EXPORT_PER_CPU_SYMBOL_GPL(tcp_tw_isn); + long sysctl_tcp_mem[3] __read_mostly; -EXPORT_SYMBOL(sysctl_tcp_mem); +EXPORT_IPV6_MOD(sysctl_tcp_mem); -atomic_long_t tcp_memory_allocated ____cacheline_aligned_in_smp; /* Current allocated memory. */ -EXPORT_SYMBOL(tcp_memory_allocated); +DEFINE_PER_CPU(int, tcp_memory_per_cpu_fw_alloc); +EXPORT_PER_CPU_SYMBOL_GPL(tcp_memory_per_cpu_fw_alloc); #if IS_ENABLED(CONFIG_SMC) DEFINE_STATIC_KEY_FALSE(tcp_have_smc); @@ -304,7 +316,7 @@ EXPORT_SYMBOL(tcp_have_smc); * Current number of TCP sockets. */ struct percpu_counter tcp_sockets_allocated ____cacheline_aligned_in_smp; -EXPORT_SYMBOL(tcp_sockets_allocated); +EXPORT_IPV6_MOD(tcp_sockets_allocated); /* * TCP splice context @@ -337,7 +349,7 @@ void tcp_enter_memory_pressure(struct sock *sk) if (!cmpxchg(&tcp_memory_pressure, 0, val)) NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMEMORYPRESSURES); } -EXPORT_SYMBOL_GPL(tcp_enter_memory_pressure); +EXPORT_IPV6_MOD_GPL(tcp_enter_memory_pressure); void tcp_leave_memory_pressure(struct sock *sk) { @@ -350,7 +362,7 @@ void tcp_leave_memory_pressure(struct sock *sk) NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPMEMORYPRESSURESCHRONO, jiffies_to_msecs(jiffies - val)); } -EXPORT_SYMBOL_GPL(tcp_leave_memory_pressure); +EXPORT_IPV6_MOD_GPL(tcp_leave_memory_pressure); /* Convert seconds to retransmits based on initial and max timeout */ static u8 secs_to_retrans(int seconds, int timeout, int rto_max) @@ -402,6 +414,21 @@ static u64 tcp_compute_delivery_rate(const struct tcp_sock *tp) return rate64; } +#ifdef CONFIG_TCP_MD5SIG +void tcp_md5_destruct_sock(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->md5sig_info) { + + tcp_clear_md5_list(sk); + kfree(rcu_replace_pointer(tp->md5sig_info, NULL, 1)); + static_branch_slow_dec_deferred(&tcp_md5_needed); + } +} +EXPORT_IPV6_MOD_GPL(tcp_md5_destruct_sock); +#endif + /* Address-family independent initialization for a tcp_sock. * * NOTE: A lot of things set to zero explicitly by call to @@ -411,6 +438,7 @@ void tcp_init_sock(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); + int rto_min_us, rto_max_ms; tp->out_of_order_queue = RB_ROOT; sk->tcp_rtx_queue = RB_ROOT; @@ -419,7 +447,12 @@ void tcp_init_sock(struct sock *sk) INIT_LIST_HEAD(&tp->tsorted_sent_queue); icsk->icsk_rto = TCP_TIMEOUT_INIT; - icsk->icsk_rto_min = TCP_RTO_MIN; + + rto_max_ms = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rto_max_ms); + icsk->icsk_rto_max = msecs_to_jiffies(rto_max_ms); + + rto_min_us = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rto_min_us); + icsk->icsk_rto_min = usecs_to_jiffies(rto_min_us); icsk->icsk_delack_max = TCP_DELACK_MAX; tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); minmax_reset(&tp->rtt_min, tcp_jiffies32, ~0U); @@ -429,10 +462,11 @@ void tcp_init_sock(struct sock *sk) * algorithms that we must have the following bandaid to talk * efficiently to them. -DaveM */ - tp->snd_cwnd = TCP_INIT_CWND; + tcp_snd_cwnd_set(tp, TCP_INIT_CWND); /* There's a bubble in the pipe until at least the first ACK. */ tp->app_limited = ~0U; + tp->rate_app_limited = 1; /* See draft-stevens-tcpca-spec-01 for discussion of the * initialization of these values. @@ -441,7 +475,7 @@ void tcp_init_sock(struct sock *sk) tp->snd_cwnd_clamp = ~0; tp->mss_cache = TCP_MSS_DEFAULT; - tp->reordering = sock_net(sk)->ipv4.sysctl_tcp_reordering; + tp->reordering = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reordering); tcp_assign_congestion_control(sk); tp->tsoffset = 0; @@ -452,27 +486,35 @@ void tcp_init_sock(struct sock *sk) icsk->icsk_sync_mss = tcp_sync_mss; - WRITE_ONCE(sk->sk_sndbuf, sock_net(sk)->ipv4.sysctl_tcp_wmem[1]); - WRITE_ONCE(sk->sk_rcvbuf, sock_net(sk)->ipv4.sysctl_tcp_rmem[1]); + WRITE_ONCE(sk->sk_sndbuf, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[1])); + WRITE_ONCE(sk->sk_rcvbuf, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[1])); + tcp_scaling_ratio_init(sk); + set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); sk_sockets_allocated_inc(sk); + xa_init_flags(&sk->sk_user_frags, XA_FLAGS_ALLOC1); } -EXPORT_SYMBOL(tcp_init_sock); +EXPORT_IPV6_MOD(tcp_init_sock); -static void tcp_tx_timestamp(struct sock *sk, u16 tsflags) +static void tcp_tx_timestamp(struct sock *sk, struct sockcm_cookie *sockc) { struct sk_buff *skb = tcp_write_queue_tail(sk); + u32 tsflags = sockc->tsflags; if (tsflags && skb) { struct skb_shared_info *shinfo = skb_shinfo(skb); struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); - sock_tx_timestamp(sk, tsflags, &shinfo->tx_flags); + sock_tx_timestamp(sk, sockc, &shinfo->tx_flags); if (tsflags & SOF_TIMESTAMPING_TX_ACK) - tcb->txstamp_ack = 1; + tcb->txstamp_ack |= TSTAMP_ACK_SK; if (tsflags & SOF_TIMESTAMPING_TX_RECORD_MASK) shinfo->tskey = TCP_SKB_CB(skb)->seq + skb->len - 1; } + + if (cgroup_bpf_enabled(CGROUP_SOCK_OPS) && + SK_BPF_CB_FLAG_TEST(sk, SK_BPF_CB_TX_TIMESTAMPING) && skb) + bpf_skops_tx_timestamping(sk, skb, BPF_SOCK_OPS_TSTAMP_SENDMSG_CB); } static bool tcp_stream_is_readable(struct sock *sk, int target) @@ -494,6 +536,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) __poll_t mask; struct sock *sk = sock->sk; const struct tcp_sock *tp = tcp_sk(sk); + u8 shutdown; int state; sock_poll_wait(file, sock, wait); @@ -536,9 +579,10 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) * NOTE. Check for TCP_CLOSE is added. The goal is to prevent * blocking on fresh not-connected or disconnected socket. --ANK */ - if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) + shutdown = READ_ONCE(sk->sk_shutdown); + if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) mask |= EPOLLHUP; - if (sk->sk_shutdown & RCV_SHUTDOWN) + if (shutdown & RCV_SHUTDOWN) mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; /* Connected or passive Fast Open socket? */ @@ -555,7 +599,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) if (tcp_stream_is_readable(sk, target)) mask |= EPOLLIN | EPOLLRDNORM; - if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { + if (!(shutdown & SEND_SHUTDOWN)) { if (__sk_stream_is_writeable(sk, 1)) { mask |= EPOLLOUT | EPOLLWRNORM; } else { /* send SIGIO later */ @@ -576,23 +620,25 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) if (urg_data & TCP_URG_VALID) mask |= EPOLLPRI; - } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) { + } else if (state == TCP_SYN_SENT && + inet_test_bit(DEFER_CONNECT, sk)) { /* Active TCP fastopen socket with defer_connect * Return EPOLLOUT so application can call write() * in order for kernel to generate SYN+data */ mask |= EPOLLOUT | EPOLLWRNORM; } - /* This barrier is coupled with smp_wmb() in tcp_reset() */ + /* This barrier is coupled with smp_wmb() in tcp_done_with_error() */ smp_rmb(); - if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) + if (READ_ONCE(sk->sk_err) || + !skb_queue_empty_lockless(&sk->sk_error_queue)) mask |= EPOLLERR; return mask; } EXPORT_SYMBOL(tcp_poll); -int tcp_ioctl(struct sock *sk, int cmd, unsigned long arg) +int tcp_ioctl(struct sock *sk, int cmd, int *karg) { struct tcp_sock *tp = tcp_sk(sk); int answ; @@ -634,9 +680,10 @@ int tcp_ioctl(struct sock *sk, int cmd, unsigned long arg) return -ENOIOCTLCMD; } - return put_user(answ, (int __user *)arg); + *karg = answ; + return 0; } -EXPORT_SYMBOL(tcp_ioctl); +EXPORT_IPV6_MOD(tcp_ioctl); void tcp_mark_push(struct tcp_sock *tp, struct sk_buff *skb) { @@ -657,6 +704,7 @@ void tcp_skb_entail(struct sock *sk, struct sk_buff *skb) tcb->seq = tcb->end_seq = tp->write_seq; tcb->tcp_flags = TCPHDR_ACK; __skb_header_release(skb); + psp_enqueue_set_decrypted(sk, skb); tcp_add_write_queue_tail(sk, skb); sk_wmem_queued_add(sk, skb->truesize); sk_mem_charge(sk, skb->truesize); @@ -686,9 +734,10 @@ static bool tcp_should_autocork(struct sock *sk, struct sk_buff *skb, int size_goal) { return skb->len < size_goal && - sock_net(sk)->ipv4.sysctl_tcp_autocorking && + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_autocorking) && !tcp_rtx_queue_empty(sk) && - refcount_read(&sk->sk_wmem_alloc) > skb->truesize; + refcount_read(&sk->sk_wmem_alloc) > skb->truesize && + tcp_skb_can_collapse_to(skb); } void tcp_push(struct sock *sk, int flags, int mss_now, @@ -711,6 +760,7 @@ void tcp_push(struct sock *sk, int flags, int mss_now, if (!test_bit(TSQ_THROTTLED, &sk->sk_tsq_flags)) { NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAUTOCORKING); set_bit(TSQ_THROTTLED, &sk->sk_tsq_flags); + smp_mb__after_atomic(); } /* It is possible TX completion already happened * before we set TSQ_THROTTLED. @@ -820,7 +870,9 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos, */ if (!skb_queue_empty(&sk->sk_receive_queue)) break; - sk_wait_data(sk, &timeo, NULL); + ret = sk_wait_data(sk, &timeo, NULL); + if (ret < 0) + break; if (signal_pending(current)) { ret = sock_intr_errno(timeo); break; @@ -830,7 +882,7 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos, tss.len -= ret; spliced += ret; - if (!timeo) + if (!tss.len || !timeo) break; release_sock(sk); lock_sock(sk); @@ -848,17 +900,14 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos, return ret; } -EXPORT_SYMBOL(tcp_splice_read); +EXPORT_IPV6_MOD(tcp_splice_read); -struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, int size, gfp_t gfp, +struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, gfp_t gfp, bool force_schedule) { struct sk_buff *skb; - if (unlikely(tcp_under_memory_pressure(sk))) - sk_mem_reclaim_partial(sk); - - skb = alloc_skb_fclone(size + MAX_TCP_HEADER, gfp); + skb = alloc_skb_fclone(MAX_TCP_HEADER, gfp); if (likely(skb)) { bool mem_scheduled; @@ -877,7 +926,8 @@ struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, int size, gfp_t gfp, } __kfree_skb(skb); } else { - sk->sk_prot->enter_memory_pressure(sk); + if (!sk->sk_bypass_prot_mem) + tcp_enter_memory_pressure(sk); sk_stream_moderate_sndbuf(sk); } return NULL; @@ -893,8 +943,7 @@ static unsigned int tcp_xmit_size_goal(struct sock *sk, u32 mss_now, return mss_now; /* Note : tcp_tso_autosize() will eventually split this later */ - new_size_goal = sk->sk_gso_max_size - 1 - MAX_TCP_HEADER; - new_size_goal = tcp_bound_to_half_wnd(tp, new_size_goal); + new_size_goal = tcp_bound_to_half_wnd(tp, sk->sk_gso_max_size); /* We try hard to avoid divides here */ size_goal = tp->gso_segs * mss_now; @@ -918,11 +967,11 @@ int tcp_send_mss(struct sock *sk, int *size_goal, int flags) return mss_now; } -/* In some cases, both sendpage() and sendmsg() could have added - * an skb to the write queue, but failed adding payload on it. - * We need to remove it to consume less memory, but more - * importantly be able to generate EPOLLOUT for Edge Trigger epoll() - * users. +/* In some cases, sendmsg() could have added an skb to the write queue, + * but failed adding payload on it. We need to remove it to consume less + * memory, but more importantly be able to generate EPOLLOUT for Edge Trigger + * epoll() users. Another reason is that tcp_write_xmit() does not like + * finding an empty skb in the write queue. */ void tcp_remove_empty_skb(struct sock *sk) { @@ -936,186 +985,39 @@ void tcp_remove_empty_skb(struct sock *sk) } } -static struct sk_buff *tcp_build_frag(struct sock *sk, int size_goal, int flags, - struct page *page, int offset, size_t *size) +/* skb changing from pure zc to mixed, must charge zc */ +static int tcp_downgrade_zcopy_pure(struct sock *sk, struct sk_buff *skb) { - struct sk_buff *skb = tcp_write_queue_tail(sk); - struct tcp_sock *tp = tcp_sk(sk); - bool can_coalesce; - int copy, i; - - if (!skb || (copy = size_goal - skb->len) <= 0 || - !tcp_skb_can_collapse_to(skb)) { -new_segment: - if (!sk_stream_memory_free(sk)) - return NULL; - - skb = tcp_stream_alloc_skb(sk, 0, sk->sk_allocation, - tcp_rtx_and_write_queues_empty(sk)); - if (!skb) - return NULL; - -#ifdef CONFIG_TLS_DEVICE - skb->decrypted = !!(flags & MSG_SENDPAGE_DECRYPTED); -#endif - tcp_skb_entail(sk, skb); - copy = size_goal; - } - - if (copy > *size) - copy = *size; - - i = skb_shinfo(skb)->nr_frags; - can_coalesce = skb_can_coalesce(skb, i, page, offset); - if (!can_coalesce && i >= sysctl_max_skb_frags) { - tcp_mark_push(tp, skb); - goto new_segment; - } - if (!sk_wmem_schedule(sk, copy)) - return NULL; - - if (can_coalesce) { - skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); - } else { - get_page(page); - skb_fill_page_desc(skb, i, page, offset, copy); - } - - if (!(flags & MSG_NO_SHARED_FRAGS)) - skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG; - - skb->len += copy; - skb->data_len += copy; - skb->truesize += copy; - sk_wmem_queued_add(sk, copy); - sk_mem_charge(sk, copy); - WRITE_ONCE(tp->write_seq, tp->write_seq + copy); - TCP_SKB_CB(skb)->end_seq += copy; - tcp_skb_pcount_set(skb, 0); - - *size = copy; - return skb; -} - -ssize_t do_tcp_sendpages(struct sock *sk, struct page *page, int offset, - size_t size, int flags) -{ - struct tcp_sock *tp = tcp_sk(sk); - int mss_now, size_goal; - int err; - ssize_t copied; - long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); - - if (IS_ENABLED(CONFIG_DEBUG_VM) && - WARN_ONCE(!sendpage_ok(page), - "page must not be a Slab one and have page_count > 0")) - return -EINVAL; - - /* Wait for a connection to finish. One exception is TCP Fast Open - * (passive side) where data is allowed to be sent before a connection - * is fully established. - */ - if (((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) && - !tcp_passive_fastopen(sk)) { - err = sk_stream_wait_connect(sk, &timeo); - if (err != 0) - goto out_err; - } - - sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); - - mss_now = tcp_send_mss(sk, &size_goal, flags); - copied = 0; - - err = -EPIPE; - if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) - goto out_err; - - while (size > 0) { - struct sk_buff *skb; - size_t copy = size; - - skb = tcp_build_frag(sk, size_goal, flags, page, offset, ©); - if (!skb) - goto wait_for_space; - - if (!copied) - TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_PSH; - - copied += copy; - offset += copy; - size -= copy; - if (!size) - goto out; - - if (skb->len < size_goal || (flags & MSG_OOB)) - continue; + if (unlikely(skb_zcopy_pure(skb))) { + u32 extra = skb->truesize - + SKB_TRUESIZE(skb_end_offset(skb)); - if (forced_push(tp)) { - tcp_mark_push(tp, skb); - __tcp_push_pending_frames(sk, mss_now, TCP_NAGLE_PUSH); - } else if (skb == tcp_send_head(sk)) - tcp_push_one(sk, mss_now); - continue; - -wait_for_space: - set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); - tcp_push(sk, flags & ~MSG_MORE, mss_now, - TCP_NAGLE_PUSH, size_goal); - - err = sk_stream_wait_memory(sk, &timeo); - if (err != 0) - goto do_error; - - mss_now = tcp_send_mss(sk, &size_goal, flags); - } + if (!sk_wmem_schedule(sk, extra)) + return -ENOMEM; -out: - if (copied) { - tcp_tx_timestamp(sk, sk->sk_tsflags); - if (!(flags & MSG_SENDPAGE_NOTLAST)) - tcp_push(sk, flags, mss_now, tp->nonagle, size_goal); + sk_mem_charge(sk, extra); + skb_shinfo(skb)->flags &= ~SKBFL_PURE_ZEROCOPY; } - return copied; - -do_error: - tcp_remove_empty_skb(sk); - if (copied) - goto out; -out_err: - /* make sure we wake any epoll edge trigger waiter */ - if (unlikely(tcp_rtx_and_write_queues_empty(sk) && err == -EAGAIN)) { - sk->sk_write_space(sk); - tcp_chrono_stop(sk, TCP_CHRONO_SNDBUF_LIMITED); - } - return sk_stream_error(sk, flags, err); + return 0; } -EXPORT_SYMBOL_GPL(do_tcp_sendpages); -int tcp_sendpage_locked(struct sock *sk, struct page *page, int offset, - size_t size, int flags) -{ - if (!(sk->sk_route_caps & NETIF_F_SG)) - return sock_no_sendpage_locked(sk, page, offset, size, flags); - - tcp_rate_check_app_limited(sk); /* is sending application-limited? */ - - return do_tcp_sendpages(sk, page, offset, size, flags); -} -EXPORT_SYMBOL_GPL(tcp_sendpage_locked); -int tcp_sendpage(struct sock *sk, struct page *page, int offset, - size_t size, int flags) +int tcp_wmem_schedule(struct sock *sk, int copy) { - int ret; + int left; - lock_sock(sk); - ret = tcp_sendpage_locked(sk, page, offset, size, flags); - release_sock(sk); + if (likely(sk_wmem_schedule(sk, copy))) + return copy; - return ret; + /* We could be in trouble if we have nothing queued. + * Use whatever is left in sk->sk_forward_alloc and tcp_wmem[0] + * to guarantee some progress. + */ + left = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[0]) - sk->sk_wmem_queued; + if (left > 0) + sk_forced_mem_schedule(sk, min(left, copy)); + return min(copy, sk->sk_forward_alloc); } -EXPORT_SYMBOL(tcp_sendpage); void tcp_free_fastopen_req(struct tcp_sock *tp) { @@ -1125,16 +1027,16 @@ void tcp_free_fastopen_req(struct tcp_sock *tp) } } -static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, - int *copied, size_t size, - struct ubuf_info *uarg) +int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, int *copied, + size_t size, struct ubuf_info *uarg) { struct tcp_sock *tp = tcp_sk(sk); struct inet_sock *inet = inet_sk(sk); struct sockaddr *uaddr = msg->msg_name; int err, flags; - if (!(sock_net(sk)->ipv4.sysctl_tcp_fastopen & TFO_CLIENT_ENABLE) || + if (!(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen) & + TFO_CLIENT_ENABLE) || (uaddr && msg->msg_namelen >= sizeof(uaddr->sa_family) && uaddr->sa_family == AF_UNSPEC)) return -EOPNOTSUPP; @@ -1149,7 +1051,7 @@ static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, tp->fastopen_req->size = size; tp->fastopen_req->uarg = uarg; - if (inet->defer_connect) { + if (inet_test_bit(DEFER_CONNECT, sk)) { err = tcp_connect(sk); /* Same failure procedure as in tcp_v4/6_connect */ if (err) { @@ -1159,7 +1061,7 @@ static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, } } flags = (msg->msg_flags & MSG_DONTWAIT) ? O_NONBLOCK : 0; - err = __inet_stream_connect(sk->sk_socket, uaddr, + err = __inet_stream_connect(sk->sk_socket, (struct sockaddr_unsized *)uaddr, msg->msg_namelen, flags, 1); /* fastopen_req could already be freed in __inet_stream_connect * if the connection times out or gets rst @@ -1167,13 +1069,14 @@ static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, if (tp->fastopen_req) { *copied = tp->fastopen_req->copied; tcp_free_fastopen_req(tp); - inet->defer_connect = 0; + inet_clear_bit(DEFER_CONNECT, sk); } return err; } int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) { + struct net_devmem_dmabuf_binding *binding = NULL; struct tcp_sock *tp = tcp_sk(sk); struct ubuf_info *uarg = NULL; struct sk_buff *skb; @@ -1181,25 +1084,60 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) int flags, err, copied = 0; int mss_now = 0, size_goal, copied_syn = 0; int process_backlog = 0; - bool zc = false; + int sockc_err = 0; + int zc = 0; long timeo; flags = msg->msg_flags; - if (flags & MSG_ZEROCOPY && size && sock_flag(sk, SOCK_ZEROCOPY)) { - skb = tcp_write_queue_tail(sk); - uarg = msg_zerocopy_realloc(sk, size, skb_zcopy(skb)); - if (!uarg) { - err = -ENOBUFS; - goto out_err; + sockc = (struct sockcm_cookie){ .tsflags = READ_ONCE(sk->sk_tsflags) }; + if (msg->msg_controllen) { + sockc_err = sock_cmsg_send(sk, msg, &sockc); + /* Don't return error until MSG_FASTOPEN has been processed; + * that may succeed even if the cmsg is invalid. + */ + } + + if ((flags & MSG_ZEROCOPY) && size) { + if (msg->msg_ubuf) { + uarg = msg->msg_ubuf; + if (sk->sk_route_caps & NETIF_F_SG) + zc = MSG_ZEROCOPY; + } else if (sock_flag(sk, SOCK_ZEROCOPY)) { + skb = tcp_write_queue_tail(sk); + uarg = msg_zerocopy_realloc(sk, size, skb_zcopy(skb), + !sockc_err && sockc.dmabuf_id); + if (!uarg) { + err = -ENOBUFS; + goto out_err; + } + if (sk->sk_route_caps & NETIF_F_SG) + zc = MSG_ZEROCOPY; + else + uarg_to_msgzc(uarg)->zerocopy = 0; + + if (!sockc_err && sockc.dmabuf_id) { + binding = net_devmem_get_binding(sk, sockc.dmabuf_id); + if (IS_ERR(binding)) { + err = PTR_ERR(binding); + binding = NULL; + goto out_err; + } + } } + } else if (unlikely(msg->msg_flags & MSG_SPLICE_PAGES) && size) { + if (sk->sk_route_caps & NETIF_F_SG) + zc = MSG_SPLICE_PAGES; + } - zc = sk->sk_route_caps & NETIF_F_SG; - if (!zc) - uarg->zerocopy = 0; + if (!sockc_err && sockc.dmabuf_id && + (!(flags & MSG_ZEROCOPY) || !sock_flag(sk, SOCK_ZEROCOPY))) { + err = -EINVAL; + goto out_err; } - if (unlikely(flags & MSG_FASTOPEN || inet_sk(sk)->defer_connect) && + if (unlikely(flags & MSG_FASTOPEN || + inet_test_bit(DEFER_CONNECT, sk)) && !tp->repair) { err = tcp_sendmsg_fastopen(sk, msg, &copied_syn, size, uarg); if (err == -EINPROGRESS && copied_syn > 0) @@ -1236,13 +1174,9 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) /* 'common' sending to sendq */ } - sockcm_init(&sockc, sk); - if (msg->msg_controllen) { - err = sock_cmsg_send(sk, msg, &sockc); - if (unlikely(err)) { - err = -EINVAL; - goto out_err; - } + if (sockc_err) { + err = sockc_err; + goto out_err; } /* This should be in poll */ @@ -1265,6 +1199,8 @@ restart: if (skb) copy = size_goal - skb->len; + trace_tcp_sendmsg_locked(sk, msg, skb, size_goal); + if (copy <= 0 || !tcp_skb_can_collapse_to(skb)) { bool first_skb; @@ -1278,13 +1214,16 @@ new_segment: goto restart; } first_skb = tcp_rtx_and_write_queues_empty(sk); - skb = tcp_stream_alloc_skb(sk, 0, sk->sk_allocation, + skb = tcp_stream_alloc_skb(sk, sk->sk_allocation, first_skb); if (!skb) goto wait_for_space; process_backlog++; +#ifdef CONFIG_SKB_DECRYPTED + skb->decrypted = !!(flags & MSG_SENDPAGE_DECRYPTED); +#endif tcp_skb_entail(sk, skb); copy = size_goal; @@ -1300,7 +1239,7 @@ new_segment: if (copy > msg_data_left(msg)) copy = msg_data_left(msg); - if (!zc) { + if (zc == 0) { bool merge = true; int i = skb_shinfo(skb)->nr_frags; struct page_frag *pfrag = sk_page_frag(sk); @@ -1310,7 +1249,7 @@ new_segment: if (!skb_can_coalesce(skb, i, pfrag->page, pfrag->offset)) { - if (i >= sysctl_max_skb_frags) { + if (i >= READ_ONCE(net_hotdata.sysctl_max_skb_frags)) { tcp_mark_push(tp, skb); goto new_segment; } @@ -1319,16 +1258,14 @@ new_segment: copy = min_t(int, copy, pfrag->size - pfrag->offset); - /* skb changing from pure zc to mixed, must charge zc */ - if (unlikely(skb_zcopy_pure(skb))) { - if (!sk_wmem_schedule(sk, skb->data_len)) + if (unlikely(skb_zcopy_pure(skb) || skb_zcopy_managed(skb))) { + if (tcp_downgrade_zcopy_pure(sk, skb)) goto wait_for_space; - - sk_mem_charge(sk, skb->data_len); - skb_shinfo(skb)->flags &= ~SKBFL_PURE_ZEROCOPY; + skb_zcopy_downgrade_managed(skb); } - if (!sk_wmem_schedule(sk, copy)) + copy = tcp_wmem_schedule(sk, copy); + if (!copy) goto wait_for_space; err = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb, @@ -1347,7 +1284,7 @@ new_segment: page_ref_inc(pfrag->page); } pfrag->offset += copy; - } else { + } else if (zc == MSG_ZEROCOPY) { /* First append to a fragless skb builds initial * pure zerocopy skb */ @@ -1355,11 +1292,13 @@ new_segment: skb_shinfo(skb)->flags |= SKBFL_PURE_ZEROCOPY; if (!skb_zcopy_pure(skb)) { - if (!sk_wmem_schedule(sk, copy)) + copy = tcp_wmem_schedule(sk, copy); + if (!copy) goto wait_for_space; } - err = skb_zerocopy_iter_stream(sk, skb, msg, copy, uarg); + err = skb_zerocopy_iter_stream(sk, skb, msg, copy, uarg, + binding); if (err == -EMSGSIZE || err == -EEXIST) { tcp_mark_push(tp, skb); goto new_segment; @@ -1367,6 +1306,29 @@ new_segment: if (err < 0) goto do_error; copy = err; + } else if (zc == MSG_SPLICE_PAGES) { + /* Splice in data if we can; copy if we can't. */ + if (tcp_downgrade_zcopy_pure(sk, skb)) + goto wait_for_space; + copy = tcp_wmem_schedule(sk, copy); + if (!copy) + goto wait_for_space; + + err = skb_splice_from_iter(skb, &msg->msg_iter, copy); + if (err < 0) { + if (err == -EMSGSIZE) { + tcp_mark_push(tp, skb); + goto new_segment; + } + goto do_error; + } + copy = err; + + if (!(flags & MSG_NO_SHARED_FRAGS)) + skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG; + + sk_wmem_queued_add(sk, copy); + sk_mem_charge(sk, copy); } if (!copied) @@ -1395,6 +1357,7 @@ new_segment: wait_for_space: set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + tcp_remove_empty_skb(sk); if (copied) tcp_push(sk, flags & ~MSG_MORE, mss_now, TCP_NAGLE_PUSH, size_goal); @@ -1408,11 +1371,15 @@ wait_for_space: out: if (copied) { - tcp_tx_timestamp(sk, sockc.tsflags); + tcp_tx_timestamp(sk, &sockc); tcp_push(sk, flags, mss_now, tp->nonagle, size_goal); } out_nopush: - net_zcopy_put(uarg); + /* msg->msg_ubuf is pinned by the caller so we don't take extra refs */ + if (uarg && !msg->msg_ubuf) + net_zcopy_put(uarg); + if (binding) + net_devmem_dmabuf_binding_put(binding); return copied + copied_syn; do_error: @@ -1421,13 +1388,18 @@ do_error: if (copied + copied_syn) goto out; out_err: - net_zcopy_put_abort(uarg, true); + /* msg->msg_ubuf is pinned by the caller so we don't take extra refs */ + if (uarg && !msg->msg_ubuf) + net_zcopy_put_abort(uarg, true); err = sk_stream_error(sk, flags, err); /* make sure we wake any epoll edge trigger waiter */ if (unlikely(tcp_rtx_and_write_queues_empty(sk) && err == -EAGAIN)) { sk->sk_write_space(sk); tcp_chrono_stop(sk, TCP_CHRONO_SNDBUF_LIMITED); } + if (binding) + net_devmem_dmabuf_binding_put(binding); + return err; } EXPORT_SYMBOL_GPL(tcp_sendmsg_locked); @@ -1444,6 +1416,22 @@ int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) } EXPORT_SYMBOL(tcp_sendmsg); +void tcp_splice_eof(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct tcp_sock *tp = tcp_sk(sk); + int mss_now, size_goal; + + if (!tcp_write_queue_tail(sk)) + return; + + lock_sock(sk); + mss_now = tcp_send_mss(sk, &size_goal, 0); + tcp_push(sk, 0, mss_now, tp->nonagle, size_goal); + release_sock(sk); +} +EXPORT_IPV6_MOD_GPL(tcp_splice_eof); + /* * Handle reading urgent data. BSD has very simple semantics for * this, no blocking and very strange errors 8) @@ -1498,8 +1486,6 @@ static int tcp_peek_sndq(struct sock *sk, struct msghdr *msg, int len) struct sk_buff *skb; int copied = 0, err = 0; - /* XXX -- need to support SO_PEEK_OFF */ - skb_rbtree_walk(skb, &sk->tcp_rtx_queue) { err = skb_copy_datagram_msg(skb, 0, msg, skb->len); if (err) @@ -1524,17 +1510,11 @@ static int tcp_peek_sndq(struct sock *sk, struct msghdr *msg, int len) * calculation of whether or not we must ACK for the sake of * a window update. */ -void tcp_cleanup_rbuf(struct sock *sk, int copied) +void __tcp_cleanup_rbuf(struct sock *sk, int copied) { struct tcp_sock *tp = tcp_sk(sk); bool time_to_ack = false; - struct sk_buff *skb = skb_peek(&sk->sk_receive_queue); - - WARN(skb && !before(tp->copied_seq, TCP_SKB_CB(skb)->end_seq), - "cleanup rbuf bug: copied %X seq %X rcvnxt %X\n", - tp->copied_seq, TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt); - if (inet_csk_ack_scheduled(sk)) { const struct inet_connection_sock *icsk = inet_csk(sk); @@ -1576,23 +1556,22 @@ void tcp_cleanup_rbuf(struct sock *sk, int copied) time_to_ack = true; } } - if (time_to_ack) + if (time_to_ack) { + tcp_mstamp_refresh(tp); tcp_send_ack(sk); + } } -void __sk_defer_free_flush(struct sock *sk) +void tcp_cleanup_rbuf(struct sock *sk, int copied) { - struct llist_node *head; - struct sk_buff *skb, *n; + struct sk_buff *skb = skb_peek(&sk->sk_receive_queue); + struct tcp_sock *tp = tcp_sk(sk); - head = llist_del_all(&sk->defer_list); - llist_for_each_entry_safe(skb, n, head, ll_node) { - prefetch(n); - skb_mark_not_on_list(skb); - __kfree_skb(skb); - } + WARN(skb && !before(tp->copied_seq, TCP_SKB_CB(skb)->end_seq), + "cleanup rbuf bug: copied %X seq %X rcvnxt %X\n", + tp->copied_seq, TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt); + __tcp_cleanup_rbuf(sk, copied); } -EXPORT_SYMBOL(__sk_defer_free_flush); static void tcp_eat_recv_skb(struct sock *sk, struct sk_buff *skb) { @@ -1601,16 +1580,12 @@ static void tcp_eat_recv_skb(struct sock *sk, struct sk_buff *skb) sock_rfree(skb); skb->destructor = NULL; skb->sk = NULL; - if (!skb_queue_empty(&sk->sk_receive_queue) || - !llist_empty(&sk->defer_list)) { - llist_add(&skb->ll_node, &sk->defer_list); - return; - } + return skb_attempt_defer_free(skb); } __kfree_skb(skb); } -static struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) +struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) { struct sk_buff *skb; u32 offset; @@ -1633,6 +1608,7 @@ static struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) } return NULL; } +EXPORT_SYMBOL(tcp_recv_skb); /* * This routine provides an alternative to tcp_recvmsg() for routines @@ -1645,12 +1621,13 @@ static struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) * or for 'peeking' the socket using this routine * (although both would be easy to implement). */ -int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, - sk_read_actor_t recv_actor) +static int __tcp_read_sock(struct sock *sk, read_descriptor_t *desc, + sk_read_actor_t recv_actor, bool noack, + u32 *copied_seq) { struct sk_buff *skb; struct tcp_sock *tp = tcp_sk(sk); - u32 seq = tp->copied_seq; + u32 seq = *copied_seq; u32 offset; int copied = 0; @@ -1675,11 +1652,13 @@ int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, if (!copied) copied = used; break; - } else if (used <= len) { - seq += used; - copied += used; - offset += used; } + if (WARN_ON_ONCE(used > len)) + used = len; + seq += used; + copied += used; + offset += used; + /* If recv_actor drops the lock (e.g. TCP splice * receive) the skb pointer might be invalid when * getting here: tcp_collapse might have deleted it @@ -1702,9 +1681,12 @@ int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, tcp_eat_recv_skb(sk, skb); if (!desc->count) break; - WRITE_ONCE(tp->copied_seq, seq); + WRITE_ONCE(*copied_seq, seq); } - WRITE_ONCE(tp->copied_seq, seq); + WRITE_ONCE(*copied_seq, seq); + + if (noack) + goto out; tcp_rcv_space_adjust(sk); @@ -1713,25 +1695,110 @@ int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, tcp_recv_skb(sk, seq, &offset); tcp_cleanup_rbuf(sk, copied); } +out: return copied; } + +int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, + sk_read_actor_t recv_actor) +{ + return __tcp_read_sock(sk, desc, recv_actor, false, + &tcp_sk(sk)->copied_seq); +} EXPORT_SYMBOL(tcp_read_sock); +int tcp_read_sock_noack(struct sock *sk, read_descriptor_t *desc, + sk_read_actor_t recv_actor, bool noack, + u32 *copied_seq) +{ + return __tcp_read_sock(sk, desc, recv_actor, noack, copied_seq); +} + +int tcp_read_skb(struct sock *sk, skb_read_actor_t recv_actor) +{ + struct sk_buff *skb; + int copied = 0; + + if (sk->sk_state == TCP_LISTEN) + return -ENOTCONN; + + while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) { + u8 tcp_flags; + int used; + + __skb_unlink(skb, &sk->sk_receive_queue); + WARN_ON_ONCE(!skb_set_owner_sk_safe(skb, sk)); + tcp_flags = TCP_SKB_CB(skb)->tcp_flags; + used = recv_actor(sk, skb); + if (used < 0) { + if (!copied) + copied = used; + break; + } + copied += used; + + if (tcp_flags & TCPHDR_FIN) + break; + } + return copied; +} +EXPORT_IPV6_MOD(tcp_read_skb); + +void tcp_read_done(struct sock *sk, size_t len) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 seq = tp->copied_seq; + struct sk_buff *skb; + size_t left; + u32 offset; + + if (sk->sk_state == TCP_LISTEN) + return; + + left = len; + while (left && (skb = tcp_recv_skb(sk, seq, &offset)) != NULL) { + int used; + + used = min_t(size_t, skb->len - offset, left); + seq += used; + left -= used; + + if (skb->len > offset + used) + break; + + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) { + tcp_eat_recv_skb(sk, skb); + ++seq; + break; + } + tcp_eat_recv_skb(sk, skb); + } + WRITE_ONCE(tp->copied_seq, seq); + + tcp_rcv_space_adjust(sk); + + /* Clean up data we have read: This will do ACK frames. */ + if (left != len) + tcp_cleanup_rbuf(sk, len - left); +} +EXPORT_SYMBOL(tcp_read_done); + int tcp_peek_len(struct socket *sock) { return tcp_inq(sock->sk); } -EXPORT_SYMBOL(tcp_peek_len); +EXPORT_IPV6_MOD(tcp_peek_len); /* Make sure sk_rcvbuf is big enough to satisfy SO_RCVLOWAT hint */ int tcp_set_rcvlowat(struct sock *sk, int val) { - int cap; + struct tcp_sock *tp = tcp_sk(sk); + int space, cap; if (sk->sk_userlocks & SOCK_RCVBUF_LOCK) cap = sk->sk_rcvbuf >> 1; else - cap = sock_net(sk)->ipv4.sysctl_tcp_rmem[2] >> 1; + cap = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2]) >> 1; val = min(val, cap); WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); @@ -1741,14 +1808,16 @@ int tcp_set_rcvlowat(struct sock *sk, int val) if (sk->sk_userlocks & SOCK_RCVBUF_LOCK) return 0; - val <<= 1; - if (val > sk->sk_rcvbuf) { - WRITE_ONCE(sk->sk_rcvbuf, val); - tcp_sk(sk)->window_clamp = tcp_win_from_space(sk, val); + space = tcp_space_from_win(sk, val); + if (space > sk->sk_rcvbuf) { + WRITE_ONCE(sk->sk_rcvbuf, space); + + if (tp->window_clamp && tp->window_clamp < val) + WRITE_ONCE(tp->window_clamp, val); } return 0; } -EXPORT_SYMBOL(tcp_set_rcvlowat); +EXPORT_IPV6_MOD(tcp_set_rcvlowat); void tcp_update_recv_tstamps(struct sk_buff *skb, struct scm_timestamping_internal *tss) @@ -1773,15 +1842,15 @@ int tcp_mmap(struct file *file, struct socket *sock, { if (vma->vm_flags & (VM_WRITE | VM_EXEC)) return -EPERM; - vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC); + vm_flags_clear(vma, VM_MAYWRITE | VM_MAYEXEC); /* Instruct vm_insert_page() to not mmap_read_lock(mm) */ - vma->vm_flags |= VM_MIXEDMAP; + vm_flags_set(vma, VM_MIXEDMAP); vma->vm_ops = &tcp_vm_ops; return 0; } -EXPORT_SYMBOL(tcp_mmap); +EXPORT_IPV6_MOD(tcp_mmap); static skb_frag_t *skb_advance_to_frag(struct sk_buff *skb, u32 offset_skb, u32 *offset_frag) @@ -1810,7 +1879,17 @@ static skb_frag_t *skb_advance_to_frag(struct sk_buff *skb, u32 offset_skb, static bool can_map_frag(const skb_frag_t *frag) { - return skb_frag_size(frag) == PAGE_SIZE && !skb_frag_off(frag); + struct page *page; + + if (skb_frag_size(frag) != PAGE_SIZE || skb_frag_off(frag)) + return false; + + page = skb_frag_page(frag); + + if (PageCompound(page) || page->mapping) + return false; + + return true; } static int find_next_mappable_frag(const skb_frag_t *frag, @@ -1866,8 +1945,7 @@ static void tcp_zerocopy_set_hint_for_skb(struct sock *sk, } static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, - struct scm_timestamping_internal *tss, + int flags, struct scm_timestamping_internal *tss, int *cmsg_flags); static int receive_fallback_to_copy(struct sock *sk, struct tcp_zerocopy_receive *zc, int inq, @@ -1875,7 +1953,6 @@ static int receive_fallback_to_copy(struct sock *sk, { unsigned long copy_address = (unsigned long)zc->copybuf_address; struct msghdr msg = {}; - struct iovec iov; int err; zc->length = 0; @@ -1884,12 +1961,12 @@ static int receive_fallback_to_copy(struct sock *sk, if (copy_address != zc->copybuf_address) return -EINVAL; - err = import_single_range(READ, (void __user *)copy_address, - inq, &iov, &msg.msg_iter); + err = import_ubuf(ITER_DEST, (void __user *)copy_address, inq, + &msg.msg_iter); if (err) return err; - err = tcp_recvmsg_locked(sk, &msg, inq, /*nonblock=*/1, /*flags=*/0, + err = tcp_recvmsg_locked(sk, &msg, inq, MSG_DONTWAIT, tss, &zc->msg_flags); if (err < 0) return err; @@ -1912,14 +1989,13 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, { unsigned long copy_address = (unsigned long)zc->copybuf_address; struct msghdr msg = {}; - struct iovec iov; int err; if (copy_address != zc->copybuf_address) return -EINVAL; - err = import_single_range(READ, (void __user *)copy_address, - copylen, &iov, &msg.msg_iter); + err = import_ubuf(ITER_DEST, (void __user *)copy_address, copylen, + &msg.msg_iter); if (err) return err; err = skb_copy_datagram_msg(skb, *offset, &msg, copylen); @@ -1976,7 +2052,7 @@ static int tcp_zerocopy_vm_insert_batch_error(struct vm_area_struct *vma, maybe_zap_len = total_bytes_to_map - /* All bytes to map */ *length + /* Mapped or pending */ (pages_remaining * PAGE_SIZE); /* Failed map. */ - zap_page_range(vma, *address, maybe_zap_len); + zap_page_range_single(vma, *address, maybe_zap_len, NULL); err = 0; } @@ -1984,7 +2060,7 @@ static int tcp_zerocopy_vm_insert_batch_error(struct vm_area_struct *vma, unsigned long leftover_pages = pages_remaining; int bytes_mapped; - /* We called zap_page_range, try to reinsert. */ + /* We called zap_page_range_single, try to reinsert. */ err = vm_insert_pages(vma, *address, pending_pages, &pages_remaining); @@ -2047,7 +2123,7 @@ static void tcp_zc_finalize_rx_tstamp(struct sock *sk, struct msghdr cmsg_dummy; msg_control_addr = (unsigned long)zc->msg_control; - cmsg_dummy.msg_control = (void *)msg_control_addr; + cmsg_dummy.msg_control_user = (void __user *)msg_control_addr; cmsg_dummy.msg_controllen = (__kernel_size_t)zc->msg_controllen; cmsg_dummy.msg_flags = in_compat_syscall() @@ -2058,13 +2134,38 @@ static void tcp_zc_finalize_rx_tstamp(struct sock *sk, zc->msg_controllen == cmsg_dummy.msg_controllen) { tcp_recv_timestamp(&cmsg_dummy, sk, tss); zc->msg_control = (__u64) - ((uintptr_t)cmsg_dummy.msg_control); + ((uintptr_t)cmsg_dummy.msg_control_user); zc->msg_controllen = (__u64)cmsg_dummy.msg_controllen; zc->msg_flags = (__u32)cmsg_dummy.msg_flags; } } +static struct vm_area_struct *find_tcp_vma(struct mm_struct *mm, + unsigned long address, + bool *mmap_locked) +{ + struct vm_area_struct *vma = lock_vma_under_rcu(mm, address); + + if (vma) { + if (vma->vm_ops != &tcp_vm_ops) { + vma_end_read(vma); + return NULL; + } + *mmap_locked = false; + return vma; + } + + mmap_read_lock(mm); + vma = vma_lookup(mm, address); + if (!vma || vma->vm_ops != &tcp_vm_ops) { + mmap_read_unlock(mm); + return NULL; + } + *mmap_locked = true; + return vma; +} + #define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32 static int tcp_zerocopy_receive(struct sock *sk, struct tcp_zerocopy_receive *zc, @@ -2082,6 +2183,7 @@ static int tcp_zerocopy_receive(struct sock *sk, u32 seq = tp->copied_seq; u32 total_bytes_to_map; int inq = tcp_inq(sk); + bool mmap_locked; int ret; zc->copybuf_len = 0; @@ -2106,19 +2208,17 @@ static int tcp_zerocopy_receive(struct sock *sk, return 0; } - mmap_read_lock(current->mm); - - vma = vma_lookup(current->mm, address); - if (!vma || vma->vm_ops != &tcp_vm_ops) { - mmap_read_unlock(current->mm); + vma = find_tcp_vma(current->mm, address, &mmap_locked); + if (!vma) return -EINVAL; - } + vma_len = min_t(unsigned long, zc->length, vma->vm_end - address); avail_len = min_t(u32, vma_len, inq); total_bytes_to_map = avail_len & ~(PAGE_SIZE - 1); if (total_bytes_to_map) { if (!(zc->flags & TCP_RECEIVE_ZEROCOPY_FLAG_TLB_CLEAN_HINT)) - zap_page_range(vma, address, total_bytes_to_map); + zap_page_range_single(vma, address, total_bytes_to_map, + NULL); zc->length = total_bytes_to_map; zc->recv_skip_hint = 0; } else { @@ -2142,6 +2242,9 @@ static int tcp_zerocopy_receive(struct sock *sk, skb = tcp_recv_skb(sk, seq, &offset); } + if (!skb_frags_readable(skb)) + break; + if (TCP_SKB_CB(skb)->has_rxtstamp) { tcp_update_recv_tstamps(skb, tss); zc->msg_flags |= TCP_CMSG_TS; @@ -2159,6 +2262,9 @@ static int tcp_zerocopy_receive(struct sock *sk, break; } page = skb_frag_page(frags); + if (WARN_ON_ONCE(!page)) + break; + prefetchw(page); pages[pages_to_map++] = page; length += PAGE_SIZE; @@ -2185,7 +2291,10 @@ static int tcp_zerocopy_receive(struct sock *sk, zc, total_bytes_to_map); } out: - mmap_read_unlock(current->mm); + if (mmap_locked) + mmap_read_unlock(current->mm); + else + vma_end_read(vma); /* Try to copy straggler data. */ if (!ret) copylen = tcp_zc_handle_leftover(zc, sk, skb, &seq, copybuf_len, tss); @@ -2214,6 +2323,7 @@ void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk, struct scm_timestamping_internal *tss) { int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW); + u32 tsflags = READ_ONCE(sk->sk_tsflags); bool has_timestamping = false; if (tss->ts[0].tv_sec || tss->ts[0].tv_nsec) { @@ -2253,14 +2363,18 @@ void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk, } } - if (sk->sk_tsflags & SOF_TIMESTAMPING_SOFTWARE) + if (tsflags & SOF_TIMESTAMPING_SOFTWARE && + (tsflags & SOF_TIMESTAMPING_RX_SOFTWARE || + !(tsflags & SOF_TIMESTAMPING_OPT_RX_FILTER))) has_timestamping = true; else tss->ts[0] = (struct timespec64) {0}; } if (tss->ts[2].tv_sec || tss->ts[2].tv_nsec) { - if (sk->sk_tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) + if (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE && + (tsflags & SOF_TIMESTAMPING_RX_HARDWARE || + !(tsflags & SOF_TIMESTAMPING_OPT_RX_FILTER))) has_timestamping = true; else tss->ts[2] = (struct timespec64) {0}; @@ -2296,6 +2410,219 @@ static int tcp_inq_hint(struct sock *sk) return inq; } +/* batch __xa_alloc() calls and reduce xa_lock()/xa_unlock() overhead. */ +struct tcp_xa_pool { + u8 max; /* max <= MAX_SKB_FRAGS */ + u8 idx; /* idx <= max */ + __u32 tokens[MAX_SKB_FRAGS]; + netmem_ref netmems[MAX_SKB_FRAGS]; +}; + +static void tcp_xa_pool_commit_locked(struct sock *sk, struct tcp_xa_pool *p) +{ + int i; + + /* Commit part that has been copied to user space. */ + for (i = 0; i < p->idx; i++) + __xa_cmpxchg(&sk->sk_user_frags, p->tokens[i], XA_ZERO_ENTRY, + (__force void *)p->netmems[i], GFP_KERNEL); + /* Rollback what has been pre-allocated and is no longer needed. */ + for (; i < p->max; i++) + __xa_erase(&sk->sk_user_frags, p->tokens[i]); + + p->max = 0; + p->idx = 0; +} + +static void tcp_xa_pool_commit(struct sock *sk, struct tcp_xa_pool *p) +{ + if (!p->max) + return; + + xa_lock_bh(&sk->sk_user_frags); + + tcp_xa_pool_commit_locked(sk, p); + + xa_unlock_bh(&sk->sk_user_frags); +} + +static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p, + unsigned int max_frags) +{ + int err, k; + + if (p->idx < p->max) + return 0; + + xa_lock_bh(&sk->sk_user_frags); + + tcp_xa_pool_commit_locked(sk, p); + + for (k = 0; k < max_frags; k++) { + err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k], + XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL); + if (err) + break; + } + + xa_unlock_bh(&sk->sk_user_frags); + + p->max = k; + p->idx = 0; + return k ? 0 : err; +} + +/* On error, returns the -errno. On success, returns number of bytes sent to the + * user. May not consume all of @remaining_len. + */ +static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb, + unsigned int offset, struct msghdr *msg, + int remaining_len) +{ + struct dmabuf_cmsg dmabuf_cmsg = { 0 }; + struct tcp_xa_pool tcp_xa_pool; + unsigned int start; + int i, copy, n; + int sent = 0; + int err = 0; + + tcp_xa_pool.max = 0; + tcp_xa_pool.idx = 0; + do { + start = skb_headlen(skb); + + if (skb_frags_readable(skb)) { + err = -ENODEV; + goto out; + } + + /* Copy header. */ + copy = start - offset; + if (copy > 0) { + copy = min(copy, remaining_len); + + n = copy_to_iter(skb->data + offset, copy, + &msg->msg_iter); + if (n != copy) { + err = -EFAULT; + goto out; + } + + offset += copy; + remaining_len -= copy; + + /* First a dmabuf_cmsg for # bytes copied to user + * buffer. + */ + memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg)); + dmabuf_cmsg.frag_size = copy; + err = put_cmsg_notrunc(msg, SOL_SOCKET, + SO_DEVMEM_LINEAR, + sizeof(dmabuf_cmsg), + &dmabuf_cmsg); + if (err) + goto out; + + sent += copy; + + if (remaining_len == 0) + goto out; + } + + /* after that, send information of dmabuf pages through a + * sequence of cmsg + */ + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + struct net_iov *niov; + u64 frag_offset; + int end; + + /* !skb_frags_readable() should indicate that ALL the + * frags in this skb are dmabuf net_iovs. We're checking + * for that flag above, but also check individual frags + * here. If the tcp stack is not setting + * skb_frags_readable() correctly, we still don't want + * to crash here. + */ + if (!skb_frag_net_iov(frag)) { + net_err_ratelimited("Found non-dmabuf skb with net_iov"); + err = -ENODEV; + goto out; + } + + niov = skb_frag_net_iov(frag); + if (!net_is_devmem_iov(niov)) { + err = -ENODEV; + goto out; + } + + end = start + skb_frag_size(frag); + copy = end - offset; + + if (copy > 0) { + copy = min(copy, remaining_len); + + frag_offset = net_iov_virtual_addr(niov) + + skb_frag_off(frag) + offset - + start; + dmabuf_cmsg.frag_offset = frag_offset; + dmabuf_cmsg.frag_size = copy; + err = tcp_xa_pool_refill(sk, &tcp_xa_pool, + skb_shinfo(skb)->nr_frags - i); + if (err) + goto out; + + /* Will perform the exchange later */ + dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx]; + dmabuf_cmsg.dmabuf_id = net_devmem_iov_binding_id(niov); + + offset += copy; + remaining_len -= copy; + + err = put_cmsg_notrunc(msg, SOL_SOCKET, + SO_DEVMEM_DMABUF, + sizeof(dmabuf_cmsg), + &dmabuf_cmsg); + if (err) + goto out; + + atomic_long_inc(&niov->desc.pp_ref_count); + tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag); + + sent += copy; + + if (remaining_len == 0) + goto out; + } + start = end; + } + + tcp_xa_pool_commit(sk, &tcp_xa_pool); + if (!remaining_len) + goto out; + + /* if remaining_len is not satisfied yet, we need to go to the + * next frag in the frag_list to satisfy remaining_len. + */ + skb = skb_shinfo(skb)->frag_list ?: skb->next; + + offset = offset - start; + } while (skb); + + if (remaining_len) { + err = -EFAULT; + goto out; + } + +out: + tcp_xa_pool_commit(sk, &tcp_xa_pool); + if (!sent) + sent = err; + + return sent; +} + /* * This routine copies from a sock struct into the user buffer. * @@ -2305,11 +2632,11 @@ static int tcp_inq_hint(struct sock *sk) */ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, - struct scm_timestamping_internal *tss, + int flags, struct scm_timestamping_internal *tss, int *cmsg_flags) { struct tcp_sock *tp = tcp_sk(sk); + int last_copied_dmabuf = -1; /* uninitialized */ int copied = 0; u32 peek_seq; u32 *seq; @@ -2318,15 +2645,18 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, int target; /* Read at least this many bytes */ long timeo; struct sk_buff *skb, *last; + u32 peek_offset = 0; u32 urg_hole = 0; err = -ENOTCONN; if (sk->sk_state == TCP_LISTEN) goto out; - if (tp->recvmsg_inq) + if (tp->recvmsg_inq) { *cmsg_flags = TCP_CMSG_INQ; - timeo = sock_rcvtimeo(sk, nonblock); + msg->msg_get_inq = 1; + } + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); /* Urgent data needs to be handled specially. */ if (flags & MSG_OOB) @@ -2349,7 +2679,8 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, seq = &tp->copied_seq; if (flags & MSG_PEEK) { - peek_seq = tp->copied_seq; + peek_offset = max(sk_peek_offset(sk, flags), 0); + peek_seq = tp->copied_seq + peek_offset; seq = &peek_seq; } @@ -2444,16 +2775,19 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, __sk_flush_backlog(sk); } else { tcp_cleanup_rbuf(sk, copied); - sk_defer_free_flush(sk); - sk_wait_data(sk, &timeo, last); + err = sk_wait_data(sk, &timeo, last); + if (err < 0) { + err = copied ? : err; + goto out; + } } if ((flags & MSG_PEEK) && - (peek_seq - copied - urg_hole != tp->copied_seq)) { + (peek_seq - peek_offset - copied - urg_hole != tp->copied_seq)) { net_dbg_ratelimited("TCP(%s:%d): Application bug, race in MSG_PEEK\n", current->comm, task_pid_nr(current)); - peek_seq = tp->copied_seq; + peek_seq = tp->copied_seq + peek_offset; } continue; @@ -2482,19 +2816,51 @@ found_ok_skb: } if (!(flags & MSG_TRUNC)) { - err = skb_copy_datagram_msg(skb, offset, msg, used); - if (err) { - /* Exception. Bailout! */ - if (!copied) - copied = -EFAULT; + if (last_copied_dmabuf != -1 && + last_copied_dmabuf != !skb_frags_readable(skb)) break; + + if (skb_frags_readable(skb)) { + err = skb_copy_datagram_msg(skb, offset, msg, + used); + if (err) { + /* Exception. Bailout! */ + if (!copied) + copied = -EFAULT; + break; + } + } else { + if (!(flags & MSG_SOCK_DEVMEM)) { + /* dmabuf skbs can only be received + * with the MSG_SOCK_DEVMEM flag. + */ + if (!copied) + copied = -EFAULT; + + break; + } + + err = tcp_recvmsg_dmabuf(sk, skb, offset, msg, + used); + if (err < 0) { + if (!copied) + copied = err; + + break; + } + used = err; } } + last_copied_dmabuf = !skb_frags_readable(skb); + WRITE_ONCE(*seq, *seq + used); copied += used; len -= used; - + if (flags & MSG_PEEK) + sk_peek_offset_fwd(sk, used); + else + sk_peek_offset_bwd(sk, used); tcp_rcv_space_adjust(sk); skip_copy: @@ -2545,10 +2911,10 @@ recv_sndq: goto out; } -int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, - int flags, int *addr_len) +int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) { - int cmsg_flags = 0, ret, inq; + int cmsg_flags = 0, ret; struct scm_timestamping_internal tss; if (unlikely(flags & MSG_ERRQUEUE)) @@ -2557,25 +2923,25 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, if (sk_can_busy_loop(sk) && skb_queue_empty_lockless(&sk->sk_receive_queue) && sk->sk_state == TCP_ESTABLISHED) - sk_busy_loop(sk, nonblock); + sk_busy_loop(sk, flags & MSG_DONTWAIT); lock_sock(sk); - ret = tcp_recvmsg_locked(sk, msg, len, nonblock, flags, &tss, - &cmsg_flags); + ret = tcp_recvmsg_locked(sk, msg, len, flags, &tss, &cmsg_flags); release_sock(sk); - sk_defer_free_flush(sk); - if (cmsg_flags && ret >= 0) { + if ((cmsg_flags || msg->msg_get_inq) && ret >= 0) { if (cmsg_flags & TCP_CMSG_TS) tcp_recv_timestamp(msg, sk, &tss); - if (cmsg_flags & TCP_CMSG_INQ) { - inq = tcp_inq_hint(sk); - put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq); + if (msg->msg_get_inq) { + msg->msg_inq = tcp_inq_hint(sk); + if (cmsg_flags & TCP_CMSG_INQ) + put_cmsg(msg, SOL_TCP, TCP_CM_INQ, + sizeof(msg->msg_inq), &msg->msg_inq); } } return ret; } -EXPORT_SYMBOL(tcp_recvmsg); +EXPORT_IPV6_MOD(tcp_recvmsg); void tcp_set_state(struct sock *sk, int state) { @@ -2600,6 +2966,7 @@ void tcp_set_state(struct sock *sk, int state) BUILD_BUG_ON((int)BPF_TCP_LISTEN != (int)TCP_LISTEN); BUILD_BUG_ON((int)BPF_TCP_CLOSING != (int)TCP_CLOSING); BUILD_BUG_ON((int)BPF_TCP_NEW_SYN_RECV != (int)TCP_NEW_SYN_RECV); + BUILD_BUG_ON((int)BPF_TCP_BOUND_INACTIVE != (int)TCP_BOUND_INACTIVE); BUILD_BUG_ON((int)BPF_TCP_MAX_STATES != (int)TCP_MAX_STATES); /* bpf uapi header bpf.h defines an anonymous enum with values @@ -2621,6 +2988,10 @@ void tcp_set_state(struct sock *sk, int state) if (oldstate != TCP_ESTABLISHED) TCP_INC_STATS(sock_net(sk), TCP_MIB_CURRESTAB); break; + case TCP_CLOSE_WAIT: + if (oldstate == TCP_SYN_RECV) + TCP_INC_STATS(sock_net(sk), TCP_MIB_CURRESTAB); + break; case TCP_CLOSE: if (oldstate == TCP_CLOSE_WAIT || oldstate == TCP_ESTABLISHED) @@ -2632,7 +3003,7 @@ void tcp_set_state(struct sock *sk, int state) inet_put_port(sk); fallthrough; default: - if (oldstate == TCP_ESTABLISHED) + if (oldstate == TCP_ESTABLISHED || oldstate == TCP_CLOSE_WAIT) TCP_DEC_STATS(sock_net(sk), TCP_MIB_CURRESTAB); } @@ -2694,13 +3065,13 @@ void tcp_shutdown(struct sock *sk, int how) /* If we've already sent a FIN, or it's a closed state, skip this. */ if ((1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_SYN_SENT | - TCPF_SYN_RECV | TCPF_CLOSE_WAIT)) { + TCPF_CLOSE_WAIT)) { /* Clear out any half completed packets. FIN if needed. */ if (tcp_close_state(sk)) tcp_send_fin(sk); } } -EXPORT_SYMBOL(tcp_shutdown); +EXPORT_IPV6_MOD(tcp_shutdown); int tcp_orphan_count_sum(void) { @@ -2724,10 +3095,19 @@ static void tcp_orphan_update(struct timer_list *unused) static bool tcp_too_many_orphans(int shift) { - return READ_ONCE(tcp_orphan_cache) << shift > sysctl_tcp_max_orphans; + return READ_ONCE(tcp_orphan_cache) << shift > + READ_ONCE(sysctl_tcp_max_orphans); +} + +static bool tcp_out_of_memory(const struct sock *sk) +{ + if (sk->sk_wmem_queued > SOCK_MIN_SNDBUF && + sk_memory_allocated(sk) > sk_prot_mem_limits(sk, 2)) + return true; + return false; } -bool tcp_check_oom(struct sock *sk, int shift) +bool tcp_check_oom(const struct sock *sk, int shift) { bool too_many_orphans, out_of_socket_memory; @@ -2743,11 +3123,11 @@ bool tcp_check_oom(struct sock *sk, int shift) void __tcp_close(struct sock *sk, long timeout) { + bool data_was_unread = false; struct sk_buff *skb; - int data_was_unread = 0; int state; - sk->sk_shutdown = SHUTDOWN_MASK; + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); if (sk->sk_state == TCP_LISTEN) { tcp_set_state(sk, TCP_CLOSE); @@ -2762,17 +3142,16 @@ void __tcp_close(struct sock *sk, long timeout) * descriptor close, not protocol-sourced closes, because the * reader process may not have drained the data yet! */ - while ((skb = __skb_dequeue(&sk->sk_receive_queue)) != NULL) { - u32 len = TCP_SKB_CB(skb)->end_seq - TCP_SKB_CB(skb)->seq; + while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) { + u32 end_seq = TCP_SKB_CB(skb)->end_seq; if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) - len--; - data_was_unread += len; - __kfree_skb(skb); + end_seq--; + if (after(end_seq, tcp_sk(sk)->copied_seq)) + data_was_unread = true; + tcp_eat_recv_skb(sk, skb); } - sk_mem_reclaim(sk); - /* If socket has been already reset (e.g. in tcp_reset()) - kill it. */ if (sk->sk_state == TCP_CLOSE) goto adjudge_to_death; @@ -2790,7 +3169,8 @@ void __tcp_close(struct sock *sk, long timeout) /* Unread data was tossed, zap the connection. */ NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONCLOSE); tcp_set_state(sk, TCP_CLOSE); - tcp_send_active_reset(sk, sk->sk_allocation); + tcp_send_active_reset(sk, sk->sk_allocation, + SK_RST_REASON_TCP_ABORT_ON_CLOSE); } else if (sock_flag(sk, SOCK_LINGER) && !sk->sk_lingertime) { /* Check zero linger _after_ checking for unread data. */ sk->sk_prot->disconnect(sk, 0); @@ -2804,7 +3184,7 @@ void __tcp_close(struct sock *sk, long timeout) * machine. State transitions: * * TCP_ESTABLISHED -> TCP_FIN_WAIT1 - * TCP_SYN_RECV -> TCP_FIN_WAIT1 (forget it, it's impossible) + * TCP_SYN_RECV -> TCP_FIN_WAIT1 (it is difficult) * TCP_CLOSE_WAIT -> TCP_LAST_ACK * * are legal only when FIN has been sent (i.e. in window), @@ -2840,7 +3220,7 @@ adjudge_to_death: /* remove backlog if any, without releasing ownership. */ __release_sock(sk); - this_cpu_inc(tcp_orphan_count); + tcp_orphan_count_inc(); /* Have we already been destroyed by a softirq or backlog? */ if (state != TCP_CLOSE && sk->sk_state == TCP_CLOSE) @@ -2862,16 +3242,17 @@ adjudge_to_death: if (sk->sk_state == TCP_FIN_WAIT2) { struct tcp_sock *tp = tcp_sk(sk); - if (tp->linger2 < 0) { + if (READ_ONCE(tp->linger2) < 0) { tcp_set_state(sk, TCP_CLOSE); - tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_send_active_reset(sk, GFP_ATOMIC, + SK_RST_REASON_TCP_ABORT_ON_LINGER); __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONLINGER); } else { const int tmo = tcp_fin_time(sk); if (tmo > TCP_TIMEWAIT_LEN) { - inet_csk_reset_keepalive_timer(sk, + tcp_reset_keepalive_timer(sk, tmo - TCP_TIMEWAIT_LEN); } else { tcp_time_wait(sk, TCP_FIN_WAIT2, tmo); @@ -2880,10 +3261,10 @@ adjudge_to_death: } } if (sk->sk_state != TCP_CLOSE) { - sk_mem_reclaim(sk); if (tcp_check_oom(sk, 0)) { tcp_set_state(sk, TCP_CLOSE); - tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_send_active_reset(sk, GFP_ATOMIC, + SK_RST_REASON_TCP_ABORT_ON_MEMORY); __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONMEMORY); } else if (!check_net(sock_net(sk))) { @@ -2917,6 +3298,8 @@ void tcp_close(struct sock *sk, long timeout) lock_sock(sk); __tcp_close(sk, timeout); release_sock(sk); + if (!sk->sk_net_refcnt) + inet_csk_clear_xmit_timers_sync(sk); sock_put(sk); } EXPORT_SYMBOL(tcp_close); @@ -2958,7 +3341,6 @@ void tcp_write_queue_purge(struct sock *sk) } tcp_rtx_queue_purge(sk); INIT_LIST_HEAD(&tcp_sk(sk)->tsorted_sent_queue); - sk_mem_reclaim(sk); tcp_clear_all_retrans_hints(tcp_sk(sk)); tcp_sk(sk)->packets_out = 0; inet_csk(sk)->icsk_backoff = 0; @@ -2970,6 +3352,7 @@ int tcp_disconnect(struct sock *sk, int flags) struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); int old_state = sk->sk_state; + struct request_sock *req; u32 seq; if (old_state != TCP_CLOSE) @@ -2979,32 +3362,35 @@ int tcp_disconnect(struct sock *sk, int flags) if (old_state == TCP_LISTEN) { inet_csk_listen_stop(sk); } else if (unlikely(tp->repair)) { - sk->sk_err = ECONNABORTED; - } else if (tcp_need_reset(old_state) || - (tp->snd_nxt != tp->write_seq && - (1 << old_state) & (TCPF_CLOSING | TCPF_LAST_ACK))) { + WRITE_ONCE(sk->sk_err, ECONNABORTED); + } else if (tcp_need_reset(old_state)) { + tcp_send_active_reset(sk, gfp_any(), SK_RST_REASON_TCP_STATE); + WRITE_ONCE(sk->sk_err, ECONNRESET); + } else if (tp->snd_nxt != tp->write_seq && + (1 << old_state) & (TCPF_CLOSING | TCPF_LAST_ACK)) { /* The last check adjusts for discrepancy of Linux wrt. RFC * states */ - tcp_send_active_reset(sk, gfp_any()); - sk->sk_err = ECONNRESET; + tcp_send_active_reset(sk, gfp_any(), + SK_RST_REASON_TCP_DISCONNECT_WITH_DATA); + WRITE_ONCE(sk->sk_err, ECONNRESET); } else if (old_state == TCP_SYN_SENT) - sk->sk_err = ECONNRESET; + WRITE_ONCE(sk->sk_err, ECONNRESET); tcp_clear_xmit_timers(sk); __skb_queue_purge(&sk->sk_receive_queue); WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); WRITE_ONCE(tp->urg_data, 0); + sk_set_peek_off(sk, -1); tcp_write_queue_purge(sk); tcp_fastopen_active_disable_ofo_check(sk); skb_rbtree_purge(&tp->out_of_order_queue); inet->inet_dport = 0; - if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) - inet_reset_saddr(sk); + inet_bhash2_reset_saddr(sk); - sk->sk_shutdown = 0; + WRITE_ONCE(sk->sk_shutdown, 0); sock_reset_flag(sk, SOCK_DONE); tp->srtt_us = 0; tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); @@ -3016,18 +3402,25 @@ int tcp_disconnect(struct sock *sk, int flags) WRITE_ONCE(tp->write_seq, seq); icsk->icsk_backoff = 0; - icsk->icsk_probes_out = 0; + WRITE_ONCE(icsk->icsk_probes_out, 0); icsk->icsk_probes_tstamp = 0; icsk->icsk_rto = TCP_TIMEOUT_INIT; - icsk->icsk_rto_min = TCP_RTO_MIN; - icsk->icsk_delack_max = TCP_DELACK_MAX; + WRITE_ONCE(icsk->icsk_rto_min, TCP_RTO_MIN); + WRITE_ONCE(icsk->icsk_delack_max, TCP_DELACK_MAX); tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; - tp->snd_cwnd = TCP_INIT_CWND; + tcp_snd_cwnd_set(tp, TCP_INIT_CWND); tp->snd_cwnd_cnt = 0; + tp->is_cwnd_limited = 0; + tp->max_packets_out = 0; tp->window_clamp = 0; tp->delivered = 0; tp->delivered_ce = 0; - if (icsk->icsk_ca_ops->release) + tp->accecn_fail_mode = 0; + tp->saw_accecn_opt = TCP_ACCECN_OPT_NOT_SEEN; + tcp_accecn_init_counters(tp); + tp->prev_ecnfield = 0; + tp->accecn_opt_tstamp = 0; + if (icsk->icsk_ca_initialized && icsk->icsk_ca_ops->release) icsk->icsk_ca_ops->release(sk); memset(icsk->icsk_ca_priv, 0, sizeof(icsk->icsk_ca_priv)); icsk->icsk_ca_initialized = 0; @@ -3042,7 +3435,7 @@ int tcp_disconnect(struct sock *sk, int flags) icsk->icsk_ack.rcv_mss = TCP_MIN_MSS; memset(&tp->rx_opt, 0, sizeof(tp->rx_opt)); __sk_dst_reset(sk); - dst_release(xchg((__force struct dst_entry **)&sk->sk_rx_dst, NULL)); + dst_release(unrcu_pointer(xchg(&sk->sk_rx_dst, NULL))); tcp_saved_syn_free(tp); tp->compressed_ack = 0; tp->segs_in = 0; @@ -3061,8 +3454,10 @@ int tcp_disconnect(struct sock *sk, int flags) tp->sacked_out = 0; tp->tlp_high_seq = 0; tp->last_oow_ack_time = 0; + tp->plb_rehash = 0; /* There's a bubble in the pipe until at least the first ACK. */ tp->app_limited = ~0U; + tp->rate_app_limited = 1; tp->rack.mstamp = 0; tp->rack.advanced = 0; tp->rack.reo_wnd_steps = 1; @@ -3070,6 +3465,7 @@ int tcp_disconnect(struct sock *sk, int flags) tp->rack.reo_wnd_persist = 0; tp->rack.dsack_seen = 0; tp->syn_data_acked = 0; + tp->syn_fastopen_child = 0; tp->rx_opt.saw_tstamp = 0; tp->rx_opt.dsack = 0; tp->rx_opt.num_sacks = 0; @@ -3077,8 +3473,12 @@ int tcp_disconnect(struct sock *sk, int flags) /* Clean up fastopen related fields */ + req = rcu_dereference_protected(tp->fastopen_rsk, + lockdep_sock_is_held(sk)); + if (req) + reqsk_fastopen_remove(sk, req, false); tcp_free_fastopen_req(tp); - inet->defer_connect = 0; + inet_clear_bit(DEFER_CONNECT, sk); tp->fastopen_client_fail = 0; WARN_ON(inet->inet_num && !icsk->icsk_bind_hash); @@ -3088,7 +3488,6 @@ int tcp_disconnect(struct sock *sk, int flags) sk->sk_frag.page = NULL; sk->sk_frag.offset = 0; } - sk_defer_free_flush(sk); sk_error_report(sk); return 0; } @@ -3096,7 +3495,7 @@ EXPORT_SYMBOL(tcp_disconnect); static inline bool tcp_can_repair_sock(const struct sock *sk) { - return ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) && + return sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) && (sk->sk_state != TCP_LISTEN); } @@ -3183,11 +3582,14 @@ static int tcp_repair_options_est(struct sock *sk, sockptr_t optbuf, } DEFINE_STATIC_KEY_FALSE(tcp_tx_delay_enabled); -EXPORT_SYMBOL(tcp_tx_delay_enabled); +EXPORT_IPV6_MOD(tcp_tx_delay_enabled); -static void tcp_enable_tx_delay(void) +static void tcp_enable_tx_delay(struct sock *sk, int val) { - if (!static_branch_unlikely(&tcp_tx_delay_enabled)) { + struct tcp_sock *tp = tcp_sk(sk); + s32 delta = (val - tp->tcp_tx_delay) << 3; + + if (val && !static_branch_unlikely(&tcp_tx_delay_enabled)) { static int __tcp_tx_delay_enabled = 0; if (cmpxchg(&__tcp_tx_delay_enabled, 0, 1) == 0) { @@ -3195,6 +3597,22 @@ static void tcp_enable_tx_delay(void) pr_info("TCP_TX_DELAY enabled\n"); } } + /* If we change tcp_tx_delay on a live flow, adjust tp->srtt_us, + * tp->rtt_min, icsk_rto and sk->sk_pacing_rate. + * This is best effort. + */ + if (delta && sk->sk_state == TCP_ESTABLISHED) { + s64 srtt = (s64)tp->srtt_us + delta; + + tp->srtt_us = clamp_t(s64, srtt, 1, ~0U); + + /* Note: does not deal with non zero icsk_backoff */ + tcp_set_rto(sk); + + minmax_reset(&tp->rtt_min, tcp_jiffies32, ~0U); + + tcp_update_pacing_rate(sk); + } } /* When set indicates to always queue non-full frames. Later the user clears @@ -3282,18 +3700,21 @@ int tcp_sock_set_syncnt(struct sock *sk, int val) if (val < 1 || val > MAX_TCP_SYNCNT) return -EINVAL; - lock_sock(sk); - inet_csk(sk)->icsk_syn_retries = val; - release_sock(sk); + WRITE_ONCE(inet_csk(sk)->icsk_syn_retries, val); return 0; } EXPORT_SYMBOL(tcp_sock_set_syncnt); -void tcp_sock_set_user_timeout(struct sock *sk, u32 val) +int tcp_sock_set_user_timeout(struct sock *sk, int val) { - lock_sock(sk); - inet_csk(sk)->icsk_user_timeout = val; - release_sock(sk); + /* Cap the max time in ms TCP will retry or probe the window + * before giving up and aborting (ETIMEDOUT) a connection. + */ + if (val < 0) + return -EINVAL; + + WRITE_ONCE(inet_csk(sk)->icsk_user_timeout, val); + return 0; } EXPORT_SYMBOL(tcp_sock_set_user_timeout); @@ -3304,7 +3725,8 @@ int tcp_sock_set_keepidle_locked(struct sock *sk, int val) if (val < 1 || val > MAX_TCP_KEEPIDLE) return -EINVAL; - tp->keepalive_time = val * HZ; + /* Paired with WRITE_ONCE() in keepalive_time_when() */ + WRITE_ONCE(tp->keepalive_time, val * HZ); if (sock_flag(sk, SOCK_KEEPOPEN) && !((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) { u32 elapsed = keepalive_time_elapsed(tp); @@ -3313,7 +3735,7 @@ int tcp_sock_set_keepidle_locked(struct sock *sk, int val) elapsed = tp->keepalive_time - elapsed; else elapsed = 0; - inet_csk_reset_keepalive_timer(sk, elapsed); + tcp_reset_keepalive_timer(sk, elapsed); } return 0; @@ -3335,9 +3757,7 @@ int tcp_sock_set_keepintvl(struct sock *sk, int val) if (val < 1 || val > MAX_TCP_KEEPINTVL) return -EINVAL; - lock_sock(sk); - tcp_sk(sk)->keepalive_intvl = val * HZ; - release_sock(sk); + WRITE_ONCE(tcp_sk(sk)->keepalive_intvl, val * HZ); return 0; } EXPORT_SYMBOL(tcp_sock_set_keepintvl); @@ -3347,34 +3767,62 @@ int tcp_sock_set_keepcnt(struct sock *sk, int val) if (val < 1 || val > MAX_TCP_KEEPCNT) return -EINVAL; - lock_sock(sk); - tcp_sk(sk)->keepalive_probes = val; - release_sock(sk); + /* Paired with READ_ONCE() in keepalive_probes() */ + WRITE_ONCE(tcp_sk(sk)->keepalive_probes, val); return 0; } EXPORT_SYMBOL(tcp_sock_set_keepcnt); int tcp_set_window_clamp(struct sock *sk, int val) { + u32 old_window_clamp, new_window_clamp, new_rcv_ssthresh; struct tcp_sock *tp = tcp_sk(sk); if (!val) { if (sk->sk_state != TCP_CLOSE) return -EINVAL; - tp->window_clamp = 0; + WRITE_ONCE(tp->window_clamp, 0); + return 0; + } + + old_window_clamp = tp->window_clamp; + new_window_clamp = max_t(int, SOCK_MIN_RCVBUF / 2, val); + + if (new_window_clamp == old_window_clamp) + return 0; + + WRITE_ONCE(tp->window_clamp, new_window_clamp); + + /* Need to apply the reserved mem provisioning only + * when shrinking the window clamp. + */ + if (new_window_clamp < old_window_clamp) { + __tcp_adjust_rcv_ssthresh(sk, new_window_clamp); } else { - tp->window_clamp = val < SOCK_MIN_RCVBUF / 2 ? - SOCK_MIN_RCVBUF / 2 : val; - tp->rcv_ssthresh = min(tp->rcv_wnd, tp->window_clamp); + new_rcv_ssthresh = min(tp->rcv_wnd, new_window_clamp); + tp->rcv_ssthresh = max(new_rcv_ssthresh, tp->rcv_ssthresh); } return 0; } +int tcp_sock_set_maxseg(struct sock *sk, int val) +{ + /* Values greater than interface MTU won't take effect. However + * at the point when this call is done we typically don't yet + * know which interface is going to be used + */ + if (val && (val < TCP_MIN_MSS || val > MAX_TCP_WINDOW)) + return -EINVAL; + + WRITE_ONCE(tcp_sk(sk)->rx_opt.user_mss, val); + return 0; +} + /* * Socket option code for TCP. */ -static int do_tcp_setsockopt(struct sock *sk, int level, int optname, - sockptr_t optval, unsigned int optlen) +int do_tcp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) { struct tcp_sock *tp = tcp_sk(sk); struct inet_connection_sock *icsk = inet_csk(sk); @@ -3396,11 +3844,11 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, return -EFAULT; name[val] = 0; - lock_sock(sk); - err = tcp_set_congestion_control(sk, name, true, - ns_capable(sock_net(sk)->user_ns, - CAP_NET_ADMIN)); - release_sock(sk); + sockopt_lock_sock(sk); + err = tcp_set_congestion_control(sk, name, !has_current_bpf_ctx(), + sockopt_ns_capable(sock_net(sk)->user_ns, + CAP_NET_ADMIN)); + sockopt_release_sock(sk); return err; } case TCP_ULP: { @@ -3416,9 +3864,9 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, return -EFAULT; name[val] = 0; - lock_sock(sk); + sockopt_lock_sock(sk); err = tcp_set_ulp(sk, name); - release_sock(sk); + sockopt_release_sock(sk); return err; } case TCP_FASTOPEN_KEY: { @@ -3451,21 +3899,58 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, if (copy_from_sockptr(&val, optval, sizeof(val))) return -EFAULT; - lock_sock(sk); - + /* Handle options that can be set without locking the socket. */ switch (optname) { + case TCP_SYNCNT: + return tcp_sock_set_syncnt(sk, val); + case TCP_USER_TIMEOUT: + return tcp_sock_set_user_timeout(sk, val); + case TCP_KEEPINTVL: + return tcp_sock_set_keepintvl(sk, val); + case TCP_KEEPCNT: + return tcp_sock_set_keepcnt(sk, val); + case TCP_LINGER2: + if (val < 0) + WRITE_ONCE(tp->linger2, -1); + else if (val > TCP_FIN_TIMEOUT_MAX / HZ) + WRITE_ONCE(tp->linger2, TCP_FIN_TIMEOUT_MAX); + else + WRITE_ONCE(tp->linger2, val * HZ); + return 0; + case TCP_DEFER_ACCEPT: + /* Translate value in seconds to number of retransmits */ + WRITE_ONCE(icsk->icsk_accept_queue.rskq_defer_accept, + secs_to_retrans(val, TCP_TIMEOUT_INIT / HZ, + TCP_RTO_MAX / HZ)); + return 0; + case TCP_RTO_MAX_MS: + if (val < MSEC_PER_SEC || val > TCP_RTO_MAX_SEC * MSEC_PER_SEC) + return -EINVAL; + WRITE_ONCE(inet_csk(sk)->icsk_rto_max, msecs_to_jiffies(val)); + return 0; + case TCP_RTO_MIN_US: { + int rto_min = usecs_to_jiffies(val); + + if (rto_min > TCP_RTO_MIN || rto_min < TCP_TIMEOUT_MIN) + return -EINVAL; + WRITE_ONCE(inet_csk(sk)->icsk_rto_min, rto_min); + return 0; + } + case TCP_DELACK_MAX_US: { + int delack_max = usecs_to_jiffies(val); + + if (delack_max > TCP_DELACK_MAX || delack_max < TCP_TIMEOUT_MIN) + return -EINVAL; + WRITE_ONCE(inet_csk(sk)->icsk_delack_max, delack_max); + return 0; + } case TCP_MAXSEG: - /* Values greater than interface MTU won't take effect. However - * at the point when this call is done we typically don't yet - * know which interface is going to be used - */ - if (val && (val < TCP_MIN_MSS || val > MAX_TCP_WINDOW)) { - err = -EINVAL; - break; - } - tp->rx_opt.user_mss = val; - break; + return tcp_sock_set_maxseg(sk, val); + } + + sockopt_lock_sock(sk); + switch (optname) { case TCP_NODELAY: __tcp_sock_set_nodelay(sk, val); break; @@ -3533,7 +4018,7 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, case TCP_REPAIR_OPTIONS: if (!tp->repair) err = -EINVAL; - else if (sk->sk_state == TCP_ESTABLISHED) + else if (sk->sk_state == TCP_ESTABLISHED && !tp->bytes_sent) err = tcp_repair_options_est(sk, optval, optlen); else err = -EPERM; @@ -3546,25 +4031,6 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, case TCP_KEEPIDLE: err = tcp_sock_set_keepidle_locked(sk, val); break; - case TCP_KEEPINTVL: - if (val < 1 || val > MAX_TCP_KEEPINTVL) - err = -EINVAL; - else - tp->keepalive_intvl = val * HZ; - break; - case TCP_KEEPCNT: - if (val < 1 || val > MAX_TCP_KEEPCNT) - err = -EINVAL; - else - tp->keepalive_probes = val; - break; - case TCP_SYNCNT: - if (val < 1 || val > MAX_TCP_SYNCNT) - err = -EINVAL; - else - icsk->icsk_syn_retries = val; - break; - case TCP_SAVE_SYN: /* 0: disable, 1: enable, 2: start from ether_header */ if (val < 0 || val > 2) @@ -3573,22 +4039,6 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, tp->save_syn = val; break; - case TCP_LINGER2: - if (val < 0) - tp->linger2 = -1; - else if (val > TCP_FIN_TIMEOUT_MAX / HZ) - tp->linger2 = TCP_FIN_TIMEOUT_MAX; - else - tp->linger2 = val * HZ; - break; - - case TCP_DEFER_ACCEPT: - /* Translate value in seconds to number of retransmits */ - icsk->icsk_accept_queue.rskq_defer_accept = - secs_to_retrans(val, TCP_TIMEOUT_INIT / HZ, - TCP_RTO_MAX / HZ); - break; - case TCP_WINDOW_CLAMP: err = tcp_set_window_clamp(sk, val); break; @@ -3597,22 +4047,41 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, __tcp_sock_set_quickack(sk, val); break; + case TCP_AO_REPAIR: + if (!tcp_can_repair_sock(sk)) { + err = -EPERM; + break; + } + err = tcp_ao_set_repair(sk, optval, optlen); + break; +#ifdef CONFIG_TCP_AO + case TCP_AO_ADD_KEY: + case TCP_AO_DEL_KEY: + case TCP_AO_INFO: { + /* If this is the first TCP-AO setsockopt() on the socket, + * sk_state has to be LISTEN or CLOSE. Allow TCP_REPAIR + * in any state. + */ + if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) + goto ao_parse; + if (rcu_dereference_protected(tcp_sk(sk)->ao_info, + lockdep_sock_is_held(sk))) + goto ao_parse; + if (tp->repair) + goto ao_parse; + err = -EISCONN; + break; +ao_parse: + err = tp->af_specific->ao_parse(sk, optname, optval, optlen); + break; + } +#endif #ifdef CONFIG_TCP_MD5SIG case TCP_MD5SIG: case TCP_MD5SIG_EXT: err = tp->af_specific->md5_parse(sk, optname, optval, optlen); break; #endif - case TCP_USER_TIMEOUT: - /* Cap the max time in ms TCP will retry or probe the window - * before giving up and aborting (ETIMEDOUT) a connection. - */ - if (val < 0) - err = -EINVAL; - else - icsk->icsk_user_timeout = val; - break; - case TCP_FASTOPEN: if (val >= 0 && ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) { @@ -3626,7 +4095,8 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, case TCP_FASTOPEN_CONNECT: if (val > 1 || val < 0) { err = -EINVAL; - } else if (net->ipv4.sysctl_tcp_fastopen & TFO_CLIENT_ENABLE) { + } else if (READ_ONCE(net->ipv4.sysctl_tcp_fastopen) & + TFO_CLIENT_ENABLE) { if (sk->sk_state == TCP_CLOSE) tp->fastopen_connect = val; else @@ -3644,16 +4114,22 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, tp->fastopen_no_cookie = val; break; case TCP_TIMESTAMP: - if (!tp->repair) + if (!tp->repair) { err = -EPERM; - else - tp->tsoffset = val - tcp_time_stamp_raw(); + break; + } + /* val is an opaque field, + * and low order bit contains usec_ts enable bit. + * Its a best effort, and we do not care if user makes an error. + */ + tp->tcp_usec_ts = val & 1; + WRITE_ONCE(tp->tsoffset, val - tcp_clock_ts(tp->tcp_usec_ts)); break; case TCP_REPAIR_WINDOW: err = tcp_repair_set_window(tp, optval, optlen); break; case TCP_NOTSENT_LOWAT: - tp->notsent_lowat = val; + WRITE_ONCE(tp->notsent_lowat, val); sk->sk_write_space(sk); break; case TCP_INQ: @@ -3663,16 +4139,20 @@ static int do_tcp_setsockopt(struct sock *sk, int level, int optname, tp->recvmsg_inq = val; break; case TCP_TX_DELAY: - if (val) - tcp_enable_tx_delay(); - tp->tcp_tx_delay = val; + /* tp->srtt_us is u32, and is shifted by 3 */ + if (val < 0 || val >= (1U << (31 - 3))) { + err = -EINVAL; + break; + } + tcp_enable_tx_delay(sk, val); + WRITE_ONCE(tp->tcp_tx_delay, val); break; default: err = -ENOPROTOOPT; break; } - release_sock(sk); + sockopt_release_sock(sk); return err; } @@ -3682,11 +4162,12 @@ int tcp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, const struct inet_connection_sock *icsk = inet_csk(sk); if (level != SOL_TCP) - return icsk->icsk_af_ops->setsockopt(sk, level, optname, - optval, optlen); + /* Paired with WRITE_ONCE() in do_ipv6_setsockopt() and tcp_v6_connect() */ + return READ_ONCE(icsk->icsk_af_ops)->setsockopt(sk, level, optname, + optval, optlen); return do_tcp_setsockopt(sk, level, optname, optval, optlen); } -EXPORT_SYMBOL(tcp_setsockopt); +EXPORT_IPV6_MOD(tcp_setsockopt); static void tcp_get_info_chrono_stats(const struct tcp_sock *tp, struct tcp_info *info) @@ -3712,6 +4193,9 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) { const struct tcp_sock *tp = tcp_sk(sk); /* iff sk_type == SOCK_STREAM */ const struct inet_connection_sock *icsk = inet_csk(sk); + const u8 ect1_idx = INET_ECN_ECT_1 - 1; + const u8 ect0_idx = INET_ECN_ECT_0 - 1; + const u8 ce_idx = INET_ECN_CE - 1; unsigned long rate; u32 now; u64 rate64; @@ -3733,7 +4217,7 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) info->tcpi_max_pacing_rate = rate64; info->tcpi_reordering = tp->reordering; - info->tcpi_snd_cwnd = tp->snd_cwnd; + info->tcpi_snd_cwnd = tcp_snd_cwnd(tp); if (info->tcpi_state == TCP_LISTEN) { /* listeners aliased fields : @@ -3762,15 +4246,20 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) info->tcpi_rcv_wscale = tp->rx_opt.rcv_wscale; } - if (tp->ecn_flags & TCP_ECN_OK) + if (tcp_ecn_mode_any(tp)) info->tcpi_options |= TCPI_OPT_ECN; if (tp->ecn_flags & TCP_ECN_SEEN) info->tcpi_options |= TCPI_OPT_ECN_SEEN; if (tp->syn_data_acked) info->tcpi_options |= TCPI_OPT_SYN_DATA; + if (tp->tcp_usec_ts) + info->tcpi_options |= TCPI_OPT_USEC_TS; + if (tp->syn_fastopen_child) + info->tcpi_options |= TCPI_OPT_TFO_CHILD; info->tcpi_rto = jiffies_to_usecs(icsk->icsk_rto); - info->tcpi_ato = jiffies_to_usecs(icsk->icsk_ack.ato); + info->tcpi_ato = jiffies_to_usecs(min_t(u32, icsk->icsk_ack.ato, + tcp_delack_max(sk))); info->tcpi_snd_mss = tp->mss_cache; info->tcpi_rcv_mss = icsk->icsk_ack.rcv_mss; @@ -3823,7 +4312,26 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) info->tcpi_reord_seen = tp->reord_seen; info->tcpi_rcv_ooopack = tp->rcv_ooopack; info->tcpi_snd_wnd = tp->snd_wnd; + info->tcpi_rcv_wnd = tp->rcv_wnd; + info->tcpi_rehash = tp->plb_rehash + tp->timeout_rehash; info->tcpi_fastopen_client_fail = tp->fastopen_client_fail; + + info->tcpi_total_rto = tp->total_rto; + info->tcpi_total_rto_recoveries = tp->total_rto_recoveries; + info->tcpi_total_rto_time = tp->total_rto_time; + if (tp->rto_stamp) + info->tcpi_total_rto_time += tcp_clock_ms() - tp->rto_stamp; + + info->tcpi_accecn_fail_mode = tp->accecn_fail_mode; + info->tcpi_accecn_opt_seen = tp->saw_accecn_opt; + info->tcpi_received_ce = tp->received_ce; + info->tcpi_delivered_e1_bytes = tp->delivered_ecn_bytes[ect1_idx]; + info->tcpi_delivered_e0_bytes = tp->delivered_ecn_bytes[ect0_idx]; + info->tcpi_delivered_ce_bytes = tp->delivered_ecn_bytes[ce_idx]; + info->tcpi_received_e1_bytes = tp->received_ecn_bytes[ect1_idx]; + info->tcpi_received_e0_bytes = tp->received_ecn_bytes[ect0_idx]; + info->tcpi_received_ce_bytes = tp->received_ecn_bytes[ce_idx]; + unlock_sock_fast(sk, slow); } EXPORT_SYMBOL_GPL(tcp_get_info); @@ -3857,6 +4365,7 @@ static size_t tcp_opt_stats_get_size(void) nla_total_size(sizeof(u32)) + /* TCP_NLA_BYTES_NOTSENT */ nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_EDT */ nla_total_size(sizeof(u8)) + /* TCP_NLA_TTL */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_REHASH */ 0; } @@ -3904,11 +4413,12 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk, rate64 = tcp_compute_delivery_rate(tp); nla_put_u64_64bit(stats, TCP_NLA_DELIVERY_RATE, rate64, TCP_NLA_PAD); - nla_put_u32(stats, TCP_NLA_SND_CWND, tp->snd_cwnd); + nla_put_u32(stats, TCP_NLA_SND_CWND, tcp_snd_cwnd(tp)); nla_put_u32(stats, TCP_NLA_REORDERING, tp->reordering); nla_put_u32(stats, TCP_NLA_MIN_RTT, tcp_min_rtt(tp)); - nla_put_u8(stats, TCP_NLA_RECUR_RETRANS, inet_csk(sk)->icsk_retransmits); + nla_put_u8(stats, TCP_NLA_RECUR_RETRANS, + READ_ONCE(inet_csk(sk)->icsk_retransmits)); nla_put_u8(stats, TCP_NLA_DELIVERY_RATE_APP_LMT, !!tp->rate_app_limited); nla_put_u32(stats, TCP_NLA_SND_SSTHRESH, tp->snd_ssthresh); nla_put_u32(stats, TCP_NLA_DELIVERED, tp->delivered); @@ -3933,30 +4443,34 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk, nla_put_u8(stats, TCP_NLA_TTL, tcp_skb_ttl_or_hop_limit(ack_skb)); + nla_put_u32(stats, TCP_NLA_REHASH, tp->plb_rehash + tp->timeout_rehash); return stats; } -static int do_tcp_getsockopt(struct sock *sk, int level, - int optname, char __user *optval, int __user *optlen) +int do_tcp_getsockopt(struct sock *sk, int level, + int optname, sockptr_t optval, sockptr_t optlen) { struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); + int user_mss; int val, len; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; - len = min_t(unsigned int, len, sizeof(int)); - if (len < 0) return -EINVAL; + len = min_t(unsigned int, len, sizeof(int)); + switch (optname) { case TCP_MAXSEG: val = tp->mss_cache; - if (!val && ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) - val = tp->rx_opt.user_mss; + user_mss = READ_ONCE(tp->rx_opt.user_mss); + if (user_mss && + ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) + val = user_mss; if (tp->repair) val = tp->rx_opt.mss_clamp; break; @@ -3976,32 +4490,34 @@ static int do_tcp_getsockopt(struct sock *sk, int level, val = keepalive_probes(tp); break; case TCP_SYNCNT: - val = icsk->icsk_syn_retries ? : net->ipv4.sysctl_tcp_syn_retries; + val = READ_ONCE(icsk->icsk_syn_retries) ? : + READ_ONCE(net->ipv4.sysctl_tcp_syn_retries); break; case TCP_LINGER2: - val = tp->linger2; + val = READ_ONCE(tp->linger2); if (val >= 0) - val = (val ? : net->ipv4.sysctl_tcp_fin_timeout) / HZ; + val = (val ? : READ_ONCE(net->ipv4.sysctl_tcp_fin_timeout)) / HZ; break; case TCP_DEFER_ACCEPT: - val = retrans_to_secs(icsk->icsk_accept_queue.rskq_defer_accept, - TCP_TIMEOUT_INIT / HZ, TCP_RTO_MAX / HZ); + val = READ_ONCE(icsk->icsk_accept_queue.rskq_defer_accept); + val = retrans_to_secs(val, TCP_TIMEOUT_INIT / HZ, + TCP_RTO_MAX / HZ); break; case TCP_WINDOW_CLAMP: - val = tp->window_clamp; + val = READ_ONCE(tp->window_clamp); break; case TCP_INFO: { struct tcp_info info; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; tcp_get_info(sk, &info); len = min_t(unsigned int, len, sizeof(info)); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &info, len)) + if (copy_to_sockptr(optval, &info, len)) return -EFAULT; return 0; } @@ -4011,7 +4527,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, size_t sz = 0; int attr; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; ca_ops = icsk->icsk_ca_ops; @@ -4019,9 +4535,9 @@ static int do_tcp_getsockopt(struct sock *sk, int level, sz = ca_ops->get_info(sk, ~0U, &attr, &info); len = min_t(unsigned int, len, sz); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &info, len)) + if (copy_to_sockptr(optval, &info, len)) return -EFAULT; return 0; } @@ -4030,27 +4546,28 @@ static int do_tcp_getsockopt(struct sock *sk, int level, break; case TCP_CONGESTION: - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; len = min_t(unsigned int, len, TCP_CA_NAME_MAX); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, icsk->icsk_ca_ops->name, len)) + if (copy_to_sockptr(optval, icsk->icsk_ca_ops->name, len)) return -EFAULT; return 0; case TCP_ULP: - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; len = min_t(unsigned int, len, TCP_ULP_NAME_MAX); if (!icsk->icsk_ulp_ops) { - if (put_user(0, optlen)) + len = 0; + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; return 0; } - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, icsk->icsk_ulp_ops->name, len)) + if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len)) return -EFAULT; return 0; @@ -4058,15 +4575,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level, u64 key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(u64)]; unsigned int key_len; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; key_len = tcp_fastopen_get_cipher(net, icsk, key) * TCP_FASTOPEN_KEY_LENGTH; len = min_t(unsigned int, len, key_len); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, key, len)) + if (copy_to_sockptr(optval, key, len)) return -EFAULT; return 0; } @@ -4092,7 +4609,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, case TCP_REPAIR_WINDOW: { struct tcp_repair_window opt; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; if (len != sizeof(opt)) @@ -4107,7 +4624,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, opt.rcv_wnd = tp->rcv_wnd; opt.rcv_wup = tp->rcv_wup; - if (copy_to_user(optval, &opt, len)) + if (copy_to_sockptr(optval, &opt, len)) return -EFAULT; return 0; } @@ -4121,11 +4638,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level, break; case TCP_USER_TIMEOUT: - val = icsk->icsk_user_timeout; + val = READ_ONCE(icsk->icsk_user_timeout); break; case TCP_FASTOPEN: - val = icsk->icsk_accept_queue.fastopenq.max_qlen; + val = READ_ONCE(icsk->icsk_accept_queue.fastopenq.max_qlen); break; case TCP_FASTOPEN_CONNECT: @@ -4137,14 +4654,18 @@ static int do_tcp_getsockopt(struct sock *sk, int level, break; case TCP_TX_DELAY: - val = tp->tcp_tx_delay; + val = READ_ONCE(tp->tcp_tx_delay); break; case TCP_TIMESTAMP: - val = tcp_time_stamp_raw() + tp->tsoffset; + val = tcp_clock_ts(tp->tcp_usec_ts) + READ_ONCE(tp->tsoffset); + if (tp->tcp_usec_ts) + val |= 1; + else + val &= ~1; break; case TCP_NOTSENT_LOWAT: - val = tp->notsent_lowat; + val = READ_ONCE(tp->notsent_lowat); break; case TCP_INQ: val = tp->recvmsg_inq; @@ -4153,35 +4674,35 @@ static int do_tcp_getsockopt(struct sock *sk, int level, val = tp->save_syn; break; case TCP_SAVED_SYN: { - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; - lock_sock(sk); + sockopt_lock_sock(sk); if (tp->saved_syn) { if (len < tcp_saved_syn_len(tp->saved_syn)) { - if (put_user(tcp_saved_syn_len(tp->saved_syn), - optlen)) { - release_sock(sk); + len = tcp_saved_syn_len(tp->saved_syn); + if (copy_to_sockptr(optlen, &len, sizeof(int))) { + sockopt_release_sock(sk); return -EFAULT; } - release_sock(sk); + sockopt_release_sock(sk); return -EINVAL; } len = tcp_saved_syn_len(tp->saved_syn); - if (put_user(len, optlen)) { - release_sock(sk); + if (copy_to_sockptr(optlen, &len, sizeof(int))) { + sockopt_release_sock(sk); return -EFAULT; } - if (copy_to_user(optval, tp->saved_syn->data, len)) { - release_sock(sk); + if (copy_to_sockptr(optval, tp->saved_syn->data, len)) { + sockopt_release_sock(sk); return -EFAULT; } tcp_saved_syn_free(tp); - release_sock(sk); + sockopt_release_sock(sk); } else { - release_sock(sk); + sockopt_release_sock(sk); len = 0; - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; } return 0; @@ -4192,32 +4713,31 @@ static int do_tcp_getsockopt(struct sock *sk, int level, struct tcp_zerocopy_receive zc = {}; int err; - if (get_user(len, optlen)) + if (copy_from_sockptr(&len, optlen, sizeof(int))) return -EFAULT; if (len < 0 || len < offsetofend(struct tcp_zerocopy_receive, length)) return -EINVAL; if (unlikely(len > sizeof(zc))) { - err = check_zeroed_user(optval + sizeof(zc), - len - sizeof(zc)); + err = check_zeroed_sockptr(optval, sizeof(zc), + len - sizeof(zc)); if (err < 1) return err == 0 ? -EINVAL : err; len = sizeof(zc); - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; } - if (copy_from_user(&zc, optval, len)) + if (copy_from_sockptr(&zc, optval, len)) return -EFAULT; if (zc.reserved) return -EINVAL; if (zc.msg_flags & ~(TCP_VALID_ZC_MSG_FLAGS)) return -EINVAL; - lock_sock(sk); + sockopt_lock_sock(sk); err = tcp_zerocopy_receive(sk, &zc, &tss); err = BPF_CGROUP_RUN_PROG_GETSOCKOPT_KERN(sk, level, optname, &zc, &len, err); - release_sock(sk); - sk_defer_free_flush(sk); + sockopt_release_sock(sk); if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags)) goto zerocopy_rcv_cmsg; switch (len) { @@ -4247,18 +4767,47 @@ zerocopy_rcv_sk_err: zerocopy_rcv_inq: zc.inq = tcp_inq_hint(sk); zerocopy_rcv_out: - if (!err && copy_to_user(optval, &zc, len)) + if (!err && copy_to_sockptr(optval, &zc, len)) err = -EFAULT; return err; } #endif + case TCP_AO_REPAIR: + if (!tcp_can_repair_sock(sk)) + return -EPERM; + return tcp_ao_get_repair(sk, optval, optlen); + case TCP_AO_GET_KEYS: + case TCP_AO_INFO: { + int err; + + sockopt_lock_sock(sk); + if (optname == TCP_AO_GET_KEYS) + err = tcp_ao_get_mkts(sk, optval, optlen); + else + err = tcp_ao_get_sock_info(sk, optval, optlen); + sockopt_release_sock(sk); + + return err; + } + case TCP_IS_MPTCP: + val = 0; + break; + case TCP_RTO_MAX_MS: + val = jiffies_to_msecs(tcp_rto_max(sk)); + break; + case TCP_RTO_MIN_US: + val = jiffies_to_usecs(READ_ONCE(inet_csk(sk)->icsk_rto_min)); + break; + case TCP_DELACK_MAX_US: + val = jiffies_to_usecs(READ_ONCE(inet_csk(sk)->icsk_delack_max)); + break; default: return -ENOPROTOOPT; } - if (put_user(len, optlen)) + if (copy_to_sockptr(optlen, &len, sizeof(int))) return -EFAULT; - if (copy_to_user(optval, &val, len)) + if (copy_to_sockptr(optval, &val, len)) return -EFAULT; return 0; } @@ -4273,7 +4822,7 @@ bool tcp_bpf_bypass_getsockopt(int level, int optname) return false; } -EXPORT_SYMBOL(tcp_bpf_bypass_getsockopt); +EXPORT_IPV6_MOD(tcp_bpf_bypass_getsockopt); int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, int __user *optlen) @@ -4281,149 +4830,172 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, struct inet_connection_sock *icsk = inet_csk(sk); if (level != SOL_TCP) - return icsk->icsk_af_ops->getsockopt(sk, level, optname, - optval, optlen); - return do_tcp_getsockopt(sk, level, optname, optval, optlen); + /* Paired with WRITE_ONCE() in do_ipv6_setsockopt() and tcp_v6_connect() */ + return READ_ONCE(icsk->icsk_af_ops)->getsockopt(sk, level, optname, + optval, optlen); + return do_tcp_getsockopt(sk, level, optname, USER_SOCKPTR(optval), + USER_SOCKPTR(optlen)); } -EXPORT_SYMBOL(tcp_getsockopt); +EXPORT_IPV6_MOD(tcp_getsockopt); #ifdef CONFIG_TCP_MD5SIG -static DEFINE_PER_CPU(struct tcp_md5sig_pool, tcp_md5sig_pool); -static DEFINE_MUTEX(tcp_md5sig_mutex); -static bool tcp_md5sig_pool_populated = false; - -static void __tcp_alloc_md5sig_pool(void) +void tcp_md5_hash_skb_data(struct md5_ctx *ctx, const struct sk_buff *skb, + unsigned int header_len) { - struct crypto_ahash *hash; - int cpu; + const unsigned int head_data_len = skb_headlen(skb) > header_len ? + skb_headlen(skb) - header_len : 0; + const struct skb_shared_info *shi = skb_shinfo(skb); + struct sk_buff *frag_iter; + unsigned int i; - hash = crypto_alloc_ahash("md5", 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(hash)) - return; + md5_update(ctx, (const u8 *)tcp_hdr(skb) + header_len, head_data_len); - for_each_possible_cpu(cpu) { - void *scratch = per_cpu(tcp_md5sig_pool, cpu).scratch; - struct ahash_request *req; - - if (!scratch) { - scratch = kmalloc_node(sizeof(union tcp_md5sum_block) + - sizeof(struct tcphdr), - GFP_KERNEL, - cpu_to_node(cpu)); - if (!scratch) - return; - per_cpu(tcp_md5sig_pool, cpu).scratch = scratch; + for (i = 0; i < shi->nr_frags; ++i) { + const skb_frag_t *f = &shi->frags[i]; + u32 p_off, p_len, copied; + const void *vaddr; + struct page *p; + + skb_frag_foreach_page(f, skb_frag_off(f), skb_frag_size(f), + p, p_off, p_len, copied) { + vaddr = kmap_local_page(p); + md5_update(ctx, vaddr + p_off, p_len); + kunmap_local(vaddr); } - if (per_cpu(tcp_md5sig_pool, cpu).md5_req) - continue; + } - req = ahash_request_alloc(hash, GFP_KERNEL); - if (!req) - return; + skb_walk_frags(skb, frag_iter) + tcp_md5_hash_skb_data(ctx, frag_iter, 0); +} +EXPORT_IPV6_MOD(tcp_md5_hash_skb_data); - ahash_request_set_callback(req, 0, NULL, NULL); +void tcp_md5_hash_key(struct md5_ctx *ctx, + const struct tcp_md5sig_key *key) +{ + u8 keylen = READ_ONCE(key->keylen); /* paired with WRITE_ONCE() in tcp_md5_do_add */ - per_cpu(tcp_md5sig_pool, cpu).md5_req = req; - } - /* before setting tcp_md5sig_pool_populated, we must commit all writes - * to memory. See smp_rmb() in tcp_get_md5sig_pool() + /* We use data_race() because tcp_md5_do_add() might change + * key->key under us */ - smp_wmb(); - tcp_md5sig_pool_populated = true; + data_race(({ md5_update(ctx, key->key, keylen), 0; })); } +EXPORT_IPV6_MOD(tcp_md5_hash_key); -bool tcp_alloc_md5sig_pool(void) +/* Called with rcu_read_lock() */ +static enum skb_drop_reason +tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb, + const void *saddr, const void *daddr, + int family, int l3index, const __u8 *hash_location) { - if (unlikely(!tcp_md5sig_pool_populated)) { - mutex_lock(&tcp_md5sig_mutex); - - if (!tcp_md5sig_pool_populated) { - __tcp_alloc_md5sig_pool(); - if (tcp_md5sig_pool_populated) - static_branch_inc(&tcp_md5_needed); - } + /* This gets called for each TCP segment that has TCP-MD5 option. + * We have 2 drop cases: + * o An MD5 signature is present, but we're not expecting one. + * o The MD5 signature is wrong. + */ + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + u8 newhash[16]; + + key = tcp_md5_do_lookup(sk, l3index, saddr, family); + if (!key) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED); + trace_tcp_hash_md5_unexpected(sk, skb); + return SKB_DROP_REASON_TCP_MD5UNEXPECTED; + } - mutex_unlock(&tcp_md5sig_mutex); + /* Check the signature. + * To support dual stack listeners, we need to handle + * IPv4-mapped case. + */ + if (family == AF_INET) + tcp_v4_md5_hash_skb(newhash, key, NULL, skb); + else + tp->af_specific->calc_md5_hash(newhash, key, NULL, skb); + if (memcmp(hash_location, newhash, 16) != 0) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE); + trace_tcp_hash_md5_mismatch(sk, skb); + return SKB_DROP_REASON_TCP_MD5FAILURE; } - return tcp_md5sig_pool_populated; + return SKB_NOT_DROPPED_YET; } -EXPORT_SYMBOL(tcp_alloc_md5sig_pool); - - -/** - * tcp_get_md5sig_pool - get md5sig_pool for this user - * - * We use percpu structure, so if we succeed, we exit with preemption - * and BH disabled, to make sure another thread or softirq handling - * wont try to get same context. - */ -struct tcp_md5sig_pool *tcp_get_md5sig_pool(void) +#else +static inline enum skb_drop_reason +tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb, + const void *saddr, const void *daddr, + int family, int l3index, const __u8 *hash_location) { - local_bh_disable(); - - if (tcp_md5sig_pool_populated) { - /* coupled with smp_wmb() in __tcp_alloc_md5sig_pool() */ - smp_rmb(); - return this_cpu_ptr(&tcp_md5sig_pool); - } - local_bh_enable(); - return NULL; + return SKB_NOT_DROPPED_YET; } -EXPORT_SYMBOL(tcp_get_md5sig_pool); -int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *hp, - const struct sk_buff *skb, unsigned int header_len) +#endif + +/* Called with rcu_read_lock() */ +enum skb_drop_reason +tcp_inbound_hash(struct sock *sk, const struct request_sock *req, + const struct sk_buff *skb, + const void *saddr, const void *daddr, + int family, int dif, int sdif) { - struct scatterlist sg; - const struct tcphdr *tp = tcp_hdr(skb); - struct ahash_request *req = hp->md5_req; - unsigned int i; - const unsigned int head_data_len = skb_headlen(skb) > header_len ? - skb_headlen(skb) - header_len : 0; - const struct skb_shared_info *shi = skb_shinfo(skb); - struct sk_buff *frag_iter; + const struct tcphdr *th = tcp_hdr(skb); + const struct tcp_ao_hdr *aoh; + const __u8 *md5_location; + int l3index; + + /* Invalid option or two times meet any of auth options */ + if (tcp_parse_auth_options(th, &md5_location, &aoh)) { + trace_tcp_hash_bad_header(sk, skb); + return SKB_DROP_REASON_TCP_AUTH_HDR; + } - sg_init_table(&sg, 1); + if (req) { + if (tcp_rsk_used_ao(req) != !!aoh) { + u8 keyid, rnext, maclen; - sg_set_buf(&sg, ((u8 *) tp) + header_len, head_data_len); - ahash_request_set_crypt(req, &sg, NULL, head_data_len); - if (crypto_ahash_update(req)) - return 1; + if (aoh) { + keyid = aoh->keyid; + rnext = aoh->rnext_keyid; + maclen = tcp_ao_hdr_maclen(aoh); + } else { + keyid = rnext = maclen = 0; + } - for (i = 0; i < shi->nr_frags; ++i) { - const skb_frag_t *f = &shi->frags[i]; - unsigned int offset = skb_frag_off(f); - struct page *page = skb_frag_page(f) + (offset >> PAGE_SHIFT); - - sg_set_page(&sg, page, skb_frag_size(f), - offset_in_page(offset)); - ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f)); - if (crypto_ahash_update(req)) - return 1; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOBAD); + trace_tcp_ao_handshake_failure(sk, skb, keyid, rnext, maclen); + return SKB_DROP_REASON_TCP_AOFAILURE; + } } - skb_walk_frags(skb, frag_iter) - if (tcp_md5_hash_skb_data(hp, frag_iter, 0)) - return 1; - - return 0; -} -EXPORT_SYMBOL(tcp_md5_hash_skb_data); - -int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *key) -{ - u8 keylen = READ_ONCE(key->keylen); /* paired with WRITE_ONCE() in tcp_md5_do_add */ - struct scatterlist sg; + /* sdif set, means packet ingressed via a device + * in an L3 domain and dif is set to the l3mdev + */ + l3index = sdif ? dif : 0; + + /* Fast path: unsigned segments */ + if (likely(!md5_location && !aoh)) { + /* Drop if there's TCP-MD5 or TCP-AO key with any rcvid/sndid + * for the remote peer. On TCP-AO established connection + * the last key is impossible to remove, so there's + * always at least one current_key. + */ + if (tcp_ao_required(sk, saddr, family, l3index, true)) { + trace_tcp_hash_ao_required(sk, skb); + return SKB_DROP_REASON_TCP_AONOTFOUND; + } + if (unlikely(tcp_md5_do_lookup(sk, l3index, saddr, family))) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND); + trace_tcp_hash_md5_required(sk, skb); + return SKB_DROP_REASON_TCP_MD5NOTFOUND; + } + return SKB_NOT_DROPPED_YET; + } - sg_init_one(&sg, key->key, keylen); - ahash_request_set_crypt(hp->md5_req, &sg, NULL, keylen); + if (aoh) + return tcp_inbound_ao_hash(sk, skb, family, req, l3index, aoh); - /* We use data_race() because tcp_md5_do_add() might change key->key under us */ - return data_race(crypto_ahash_update(hp->md5_req)); + return tcp_inbound_md5_hash(sk, skb, saddr, daddr, family, + l3index, md5_location); } -EXPORT_SYMBOL(tcp_md5_hash_key); - -#endif +EXPORT_IPV6_MOD_GPL(tcp_inbound_hash); void tcp_done(struct sock *sk) { @@ -4443,7 +5015,7 @@ void tcp_done(struct sock *sk) if (req) reqsk_fastopen_remove(sk, req, false); - sk->sk_shutdown = SHUTDOWN_MASK; + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); if (!sock_flag(sk, SOCK_DEAD)) sk->sk_state_change(sk); @@ -4454,20 +5026,37 @@ EXPORT_SYMBOL_GPL(tcp_done); int tcp_abort(struct sock *sk, int err) { - if (!sk_fullsock(sk)) { - if (sk->sk_state == TCP_NEW_SYN_RECV) { - struct request_sock *req = inet_reqsk(sk); + int state = inet_sk_state_load(sk); - local_bh_disable(); - inet_csk_reqsk_queue_drop(req->rsk_listener, req); - local_bh_enable(); - return 0; - } - return -EOPNOTSUPP; + if (state == TCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + + local_bh_disable(); + inet_csk_reqsk_queue_drop(req->rsk_listener, req); + local_bh_enable(); + return 0; } + if (state == TCP_TIME_WAIT) { + struct inet_timewait_sock *tw = inet_twsk(sk); - /* Don't race with userspace socket closes such as tcp_close. */ - lock_sock(sk); + refcount_inc(&tw->tw_refcnt); + local_bh_disable(); + inet_twsk_deschedule_put(tw); + local_bh_enable(); + return 0; + } + + /* BPF context ensures sock locking. */ + if (!has_current_bpf_ctx()) + /* Don't race with userspace socket closes such as tcp_close. */ + lock_sock(sk); + + /* Avoid closing the same socket twice. */ + if (sk->sk_state == TCP_CLOSE) { + if (!has_current_bpf_ctx()) + release_sock(sk); + return -ENOENT; + } if (sk->sk_state == TCP_LISTEN) { tcp_set_state(sk, TCP_CLOSE); @@ -4478,20 +5067,15 @@ int tcp_abort(struct sock *sk, int err) local_bh_disable(); bh_lock_sock(sk); - if (!sock_flag(sk, SOCK_DEAD)) { - sk->sk_err = err; - /* This barrier is coupled with smp_rmb() in tcp_poll() */ - smp_wmb(); - sk_error_report(sk); - if (tcp_need_reset(sk->sk_state)) - tcp_send_active_reset(sk, GFP_ATOMIC); - tcp_done(sk); - } + if (tcp_need_reset(sk->sk_state)) + tcp_send_active_reset(sk, GFP_ATOMIC, + SK_RST_REASON_TCP_STATE); + tcp_done_with_error(sk, err); bh_unlock_sock(sk); local_bh_enable(); - tcp_write_queue_purge(sk); - release_sock(sk); + if (!has_current_bpf_ctx()) + release_sock(sk); return 0; } EXPORT_SYMBOL_GPL(tcp_abort); @@ -4524,6 +5108,98 @@ static void __init tcp_init_mem(void) sysctl_tcp_mem[2] = sysctl_tcp_mem[0] * 2; /* 9.37 % */ } +static void __init tcp_struct_check(void) +{ + /* TX read-mostly hotpath cache lines */ + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, max_window); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, rcv_ssthresh); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, reordering); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, notsent_lowat); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, gso_segs); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, retransmit_skb_hint); +#if IS_ENABLED(CONFIG_TLS_DEVICE) + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_tx, tcp_clean_acked); +#endif + + /* TXRX read-mostly hotpath cache lines */ + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, tsoffset); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, snd_wnd); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, mss_cache); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, snd_cwnd); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, prr_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, lost_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, sacked_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_txrx, scaling_ratio); + + /* RX read-mostly hotpath cache lines */ + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, copied_seq); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, snd_wl1); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, tlp_high_seq); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, rttvar_us); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, retrans_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, advmss); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, urg_data); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, lost); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, rtt_min); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, out_of_order_queue); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_read_rx, snd_ssthresh); + + /* TX read-write hotpath cache lines */ + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, segs_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, data_segs_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, bytes_sent); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, snd_sml); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, chrono_start); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, chrono_stat); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, write_seq); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, pushed_seq); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, lsndtime); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, mdev_us); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, tcp_wstamp_ns); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, accecn_opt_tstamp); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, rtt_seq); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, tsorted_sent_queue); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, highest_sack); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_tx, ecn_flags); + + /* TXRX read-write hotpath cache lines */ + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, pred_flags); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, tcp_clock_cache); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, tcp_mstamp); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, rcv_nxt); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, snd_nxt); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, snd_una); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, window_clamp); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, srtt_us); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, packets_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, snd_up); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, delivered); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, delivered_ce); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, received_ce); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, received_ecn_bytes); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, app_limited); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, rcv_wnd); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, rcv_tstamp); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_txrx, rx_opt); + + /* RX read-write hotpath cache lines */ + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, bytes_received); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, segs_in); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, data_segs_in); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, rcv_wup); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, max_packets_out); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, cwnd_usage_seq); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, rate_delivered); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, rate_interval_us); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, rcv_rtt_last_tsecr); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, delivered_ecn_bytes); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, first_tx_mstamp); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, delivered_mstamp); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, bytes_acked); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, rcv_rtt_est); + CACHELINE_ASSERT_GROUP_MEMBER(struct tcp_sock, tcp_sock_write_rx, rcvq_space); +} + void __init tcp_init(void) { int max_rshare, max_wshare, cnt; @@ -4534,12 +5210,13 @@ void __init tcp_init(void) BUILD_BUG_ON(sizeof(struct tcp_skb_cb) > sizeof_field(struct sk_buff, cb)); + tcp_struct_check(); + percpu_counter_init(&tcp_sockets_allocated, 0, GFP_KERNEL); timer_setup(&tcp_orphan_timer, tcp_orphan_update, TIMER_DEFERRABLE); mod_timer(&tcp_orphan_timer, jiffies + TCP_ORPHAN_TIMER_PERIOD); - inet_hashinfo_init(&tcp_hashinfo); inet_hashinfo2_init(&tcp_hashinfo, "tcp_listen_portaddr_hash", thash_entries, 21, /* one slot per 2 MB*/ 0, 64 * 1024); @@ -4549,6 +5226,12 @@ void __init tcp_init(void) SLAB_HWCACHE_ALIGN | SLAB_PANIC | SLAB_ACCOUNT, NULL); + tcp_hashinfo.bind2_bucket_cachep = + kmem_cache_create("tcp_bind2_bucket", + sizeof(struct inet_bind2_bucket), 0, + SLAB_HWCACHE_ALIGN | SLAB_PANIC | + SLAB_ACCOUNT, + NULL); /* Size and allocate the main established and bind bucket * hash tables. @@ -4572,7 +5255,7 @@ void __init tcp_init(void) panic("TCP: failed to alloc ehash_locks"); tcp_hashinfo.bhash = alloc_large_system_hash("TCP bind", - sizeof(struct inet_bind_hashbucket), + 2 * sizeof(struct inet_bind_hashbucket), tcp_hashinfo.ehash_mask + 1, 17, /* one slot per 128 KB of memory */ 0, @@ -4581,11 +5264,15 @@ void __init tcp_init(void) 0, 64 * 1024); tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size; + tcp_hashinfo.bhash2 = tcp_hashinfo.bhash + tcp_hashinfo.bhash_size; for (i = 0; i < tcp_hashinfo.bhash_size; i++) { spin_lock_init(&tcp_hashinfo.bhash[i].lock); INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain); + spin_lock_init(&tcp_hashinfo.bhash2[i].lock); + INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain); } + tcp_hashinfo.pernet = false; cnt = tcp_hashinfo.ehash_mask + 1; sysctl_tcp_max_orphans = cnt / 2; @@ -4594,13 +5281,13 @@ void __init tcp_init(void) /* Set per-socket limits to no more than 1/128 the pressure threshold */ limit = nr_free_buffer_pages() << (PAGE_SHIFT - 7); max_wshare = min(4UL*1024*1024, limit); - max_rshare = min(6UL*1024*1024, limit); + max_rshare = min(32UL*1024*1024, limit); - init_net.ipv4.sysctl_tcp_wmem[0] = SK_MEM_QUANTUM; + init_net.ipv4.sysctl_tcp_wmem[0] = PAGE_SIZE; init_net.ipv4.sysctl_tcp_wmem[1] = 16*1024; init_net.ipv4.sysctl_tcp_wmem[2] = max(64*1024, max_wshare); - init_net.ipv4.sysctl_tcp_rmem[0] = SK_MEM_QUANTUM; + init_net.ipv4.sysctl_tcp_rmem[0] = PAGE_SIZE; init_net.ipv4.sysctl_tcp_rmem[1] = 131072; init_net.ipv4.sysctl_tcp_rmem[2] = max(131072, max_rshare); @@ -4610,6 +5297,6 @@ void __init tcp_init(void) tcp_v4_init(); tcp_metrics_init(); BUG_ON(tcp_register_congestion_control(&tcp_reno) != 0); - tcp_tasklet_init(); + tcp_tsq_work_init(); mptcp_init(); } diff --git a/net/ipv4/tcp_ao.c b/net/ipv4/tcp_ao.c new file mode 100644 index 000000000000..34b8450829d0 --- /dev/null +++ b/net/ipv4/tcp_ao.c @@ -0,0 +1,2442 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP Authentication Option (TCP-AO). + * See RFC5925. + * + * Authors: Dmitry Safonov <dima@arista.com> + * Francesco Ruggeri <fruggeri@arista.com> + * Salam Noureddine <noureddine@arista.com> + */ +#define pr_fmt(fmt) "TCP: " fmt + +#include <crypto/hash.h> +#include <linux/inetdevice.h> +#include <linux/tcp.h> + +#include <net/tcp.h> +#include <net/ipv6.h> +#include <net/icmp.h> +#include <trace/events/tcp.h> + +DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_ao_needed, HZ); + +int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx, + unsigned int len, struct tcp_sigpool *hp) +{ + struct scatterlist sg; + int ret; + + if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp->req), + mkt->key, mkt->keylen)) + goto clear_hash; + + ret = crypto_ahash_init(hp->req); + if (ret) + goto clear_hash; + + sg_init_one(&sg, ctx, len); + ahash_request_set_crypt(hp->req, &sg, key, len); + crypto_ahash_update(hp->req); + + ret = crypto_ahash_final(hp->req); + if (ret) + goto clear_hash; + + return 0; +clear_hash: + memset(key, 0, tcp_ao_digest_size(mkt)); + return 1; +} + +bool tcp_ao_ignore_icmp(const struct sock *sk, int family, int type, int code) +{ + bool ignore_icmp = false; + struct tcp_ao_info *ao; + + if (!static_branch_unlikely(&tcp_ao_needed.key)) + return false; + + /* RFC5925, 7.8: + * >> A TCP-AO implementation MUST default to ignore incoming ICMPv4 + * messages of Type 3 (destination unreachable), Codes 2-4 (protocol + * unreachable, port unreachable, and fragmentation needed -- ’hard + * errors’), and ICMPv6 Type 1 (destination unreachable), Code 1 + * (administratively prohibited) and Code 4 (port unreachable) intended + * for connections in synchronized states (ESTABLISHED, FIN-WAIT-1, FIN- + * WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT) that match MKTs. + */ + if (family == AF_INET) { + if (type != ICMP_DEST_UNREACH) + return false; + if (code < ICMP_PROT_UNREACH || code > ICMP_FRAG_NEEDED) + return false; + } else { + if (type != ICMPV6_DEST_UNREACH) + return false; + if (code != ICMPV6_ADM_PROHIBITED && code != ICMPV6_PORT_UNREACH) + return false; + } + + rcu_read_lock(); + switch (sk->sk_state) { + case TCP_TIME_WAIT: + ao = rcu_dereference(tcp_twsk(sk)->ao_info); + break; + case TCP_SYN_SENT: + case TCP_SYN_RECV: + case TCP_LISTEN: + case TCP_NEW_SYN_RECV: + /* RFC5925 specifies to ignore ICMPs *only* on connections + * in synchronized states. + */ + rcu_read_unlock(); + return false; + default: + ao = rcu_dereference(tcp_sk(sk)->ao_info); + } + + if (ao && !ao->accept_icmps) { + ignore_icmp = true; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAODROPPEDICMPS); + atomic64_inc(&ao->counters.dropped_icmp); + } + rcu_read_unlock(); + + return ignore_icmp; +} + +/* Optimized version of tcp_ao_do_lookup(): only for sockets for which + * it's known that the keys in ao_info are matching peer's + * family/address/VRF/etc. + */ +struct tcp_ao_key *tcp_ao_established_key(const struct sock *sk, + struct tcp_ao_info *ao, + int sndid, int rcvid) +{ + struct tcp_ao_key *key; + + hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk)) { + if ((sndid >= 0 && key->sndid != sndid) || + (rcvid >= 0 && key->rcvid != rcvid)) + continue; + return key; + } + + return NULL; +} + +static int ipv4_prefix_cmp(const struct in_addr *addr1, + const struct in_addr *addr2, + unsigned int prefixlen) +{ + __be32 mask = inet_make_mask(prefixlen); + __be32 a1 = addr1->s_addr & mask; + __be32 a2 = addr2->s_addr & mask; + + if (a1 == a2) + return 0; + return memcmp(&a1, &a2, sizeof(a1)); +} + +static int __tcp_ao_key_cmp(const struct tcp_ao_key *key, int l3index, + const union tcp_ao_addr *addr, u8 prefixlen, + int family, int sndid, int rcvid) +{ + if (sndid >= 0 && key->sndid != sndid) + return (key->sndid > sndid) ? 1 : -1; + if (rcvid >= 0 && key->rcvid != rcvid) + return (key->rcvid > rcvid) ? 1 : -1; + if (l3index >= 0 && (key->keyflags & TCP_AO_KEYF_IFINDEX)) { + if (key->l3index != l3index) + return (key->l3index > l3index) ? 1 : -1; + } + + if (family == AF_UNSPEC) + return 0; + if (key->family != family) + return (key->family > family) ? 1 : -1; + + if (family == AF_INET) { + if (ntohl(key->addr.a4.s_addr) == INADDR_ANY) + return 0; + if (ntohl(addr->a4.s_addr) == INADDR_ANY) + return 0; + return ipv4_prefix_cmp(&key->addr.a4, &addr->a4, prefixlen); +#if IS_ENABLED(CONFIG_IPV6) + } else { + if (ipv6_addr_any(&key->addr.a6) || ipv6_addr_any(&addr->a6)) + return 0; + if (ipv6_prefix_equal(&key->addr.a6, &addr->a6, prefixlen)) + return 0; + return memcmp(&key->addr.a6, &addr->a6, sizeof(addr->a6)); +#endif + } + return -1; +} + +static int tcp_ao_key_cmp(const struct tcp_ao_key *key, int l3index, + const union tcp_ao_addr *addr, u8 prefixlen, + int family, int sndid, int rcvid) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (family == AF_INET6 && ipv6_addr_v4mapped(&addr->a6)) { + __be32 addr4 = addr->a6.s6_addr32[3]; + + return __tcp_ao_key_cmp(key, l3index, + (union tcp_ao_addr *)&addr4, + prefixlen, AF_INET, sndid, rcvid); + } +#endif + return __tcp_ao_key_cmp(key, l3index, addr, + prefixlen, family, sndid, rcvid); +} + +static struct tcp_ao_key *__tcp_ao_do_lookup(const struct sock *sk, int l3index, + const union tcp_ao_addr *addr, int family, u8 prefix, + int sndid, int rcvid) +{ + struct tcp_ao_key *key; + struct tcp_ao_info *ao; + + if (!static_branch_unlikely(&tcp_ao_needed.key)) + return NULL; + + ao = rcu_dereference_check(tcp_sk(sk)->ao_info, + lockdep_sock_is_held(sk)); + if (!ao) + return NULL; + + hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk)) { + u8 prefixlen = min(prefix, key->prefixlen); + + if (!tcp_ao_key_cmp(key, l3index, addr, prefixlen, + family, sndid, rcvid)) + return key; + } + return NULL; +} + +struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk, int l3index, + const union tcp_ao_addr *addr, + int family, int sndid, int rcvid) +{ + return __tcp_ao_do_lookup(sk, l3index, addr, family, U8_MAX, sndid, rcvid); +} + +static struct tcp_ao_info *tcp_ao_alloc_info(gfp_t flags) +{ + struct tcp_ao_info *ao; + + ao = kzalloc(sizeof(*ao), flags); + if (!ao) + return NULL; + INIT_HLIST_HEAD(&ao->head); + refcount_set(&ao->refcnt, 1); + + return ao; +} + +static void tcp_ao_link_mkt(struct tcp_ao_info *ao, struct tcp_ao_key *mkt) +{ + hlist_add_head_rcu(&mkt->node, &ao->head); +} + +static struct tcp_ao_key *tcp_ao_copy_key(struct sock *sk, + struct tcp_ao_key *key) +{ + struct tcp_ao_key *new_key; + + new_key = sock_kmalloc(sk, tcp_ao_sizeof_key(key), + GFP_ATOMIC); + if (!new_key) + return NULL; + + *new_key = *key; + INIT_HLIST_NODE(&new_key->node); + tcp_sigpool_get(new_key->tcp_sigpool_id); + atomic64_set(&new_key->pkt_good, 0); + atomic64_set(&new_key->pkt_bad, 0); + + return new_key; +} + +static void tcp_ao_key_free_rcu(struct rcu_head *head) +{ + struct tcp_ao_key *key = container_of(head, struct tcp_ao_key, rcu); + + tcp_sigpool_release(key->tcp_sigpool_id); + kfree_sensitive(key); +} + +static void tcp_ao_info_free(struct tcp_ao_info *ao) +{ + struct tcp_ao_key *key; + struct hlist_node *n; + + hlist_for_each_entry_safe(key, n, &ao->head, node) { + hlist_del(&key->node); + tcp_sigpool_release(key->tcp_sigpool_id); + kfree_sensitive(key); + } + kfree(ao); + static_branch_slow_dec_deferred(&tcp_ao_needed); +} + +static void tcp_ao_sk_omem_free(struct sock *sk, struct tcp_ao_info *ao) +{ + size_t total_ao_sk_mem = 0; + struct tcp_ao_key *key; + + hlist_for_each_entry(key, &ao->head, node) + total_ao_sk_mem += tcp_ao_sizeof_key(key); + atomic_sub(total_ao_sk_mem, &sk->sk_omem_alloc); +} + +void tcp_ao_destroy_sock(struct sock *sk, bool twsk) +{ + struct tcp_ao_info *ao; + + if (twsk) { + ao = rcu_dereference_protected(tcp_twsk(sk)->ao_info, 1); + rcu_assign_pointer(tcp_twsk(sk)->ao_info, NULL); + } else { + ao = rcu_dereference_protected(tcp_sk(sk)->ao_info, 1); + rcu_assign_pointer(tcp_sk(sk)->ao_info, NULL); + } + + if (!ao || !refcount_dec_and_test(&ao->refcnt)) + return; + + if (!twsk) + tcp_ao_sk_omem_free(sk, ao); + tcp_ao_info_free(ao); +} + +void tcp_ao_time_wait(struct tcp_timewait_sock *tcptw, struct tcp_sock *tp) +{ + struct tcp_ao_info *ao_info = rcu_dereference_protected(tp->ao_info, 1); + + if (ao_info) { + struct tcp_ao_key *key; + struct hlist_node *n; + int omem = 0; + + hlist_for_each_entry_safe(key, n, &ao_info->head, node) { + omem += tcp_ao_sizeof_key(key); + } + + refcount_inc(&ao_info->refcnt); + atomic_sub(omem, &(((struct sock *)tp)->sk_omem_alloc)); + rcu_assign_pointer(tcptw->ao_info, ao_info); + } else { + tcptw->ao_info = NULL; + } +} + +/* 4 tuple and ISNs are expected in NBO */ +static int tcp_v4_ao_calc_key(struct tcp_ao_key *mkt, u8 *key, + __be32 saddr, __be32 daddr, + __be16 sport, __be16 dport, + __be32 sisn, __be32 disn) +{ + /* See RFC5926 3.1.1 */ + struct kdf_input_block { + u8 counter; + u8 label[6]; + struct tcp4_ao_context ctx; + __be16 outlen; + } __packed * tmp; + struct tcp_sigpool hp; + int err; + + err = tcp_sigpool_start(mkt->tcp_sigpool_id, &hp); + if (err) + return err; + + tmp = hp.scratch; + tmp->counter = 1; + memcpy(tmp->label, "TCP-AO", 6); + tmp->ctx.saddr = saddr; + tmp->ctx.daddr = daddr; + tmp->ctx.sport = sport; + tmp->ctx.dport = dport; + tmp->ctx.sisn = sisn; + tmp->ctx.disn = disn; + tmp->outlen = htons(tcp_ao_digest_size(mkt) * 8); /* in bits */ + + err = tcp_ao_calc_traffic_key(mkt, key, tmp, sizeof(*tmp), &hp); + tcp_sigpool_end(&hp); + + return err; +} + +int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key, + const struct sock *sk, + __be32 sisn, __be32 disn, bool send) +{ + if (send) + return tcp_v4_ao_calc_key(mkt, key, sk->sk_rcv_saddr, + sk->sk_daddr, htons(sk->sk_num), + sk->sk_dport, sisn, disn); + else + return tcp_v4_ao_calc_key(mkt, key, sk->sk_daddr, + sk->sk_rcv_saddr, sk->sk_dport, + htons(sk->sk_num), disn, sisn); +} + +static int tcp_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key, + const struct sock *sk, + __be32 sisn, __be32 disn, bool send) +{ + if (mkt->family == AF_INET) + return tcp_v4_ao_calc_key_sk(mkt, key, sk, sisn, disn, send); +#if IS_ENABLED(CONFIG_IPV6) + else if (mkt->family == AF_INET6) + return tcp_v6_ao_calc_key_sk(mkt, key, sk, sisn, disn, send); +#endif + else + return -EOPNOTSUPP; +} + +int tcp_v4_ao_calc_key_rsk(struct tcp_ao_key *mkt, u8 *key, + struct request_sock *req) +{ + struct inet_request_sock *ireq = inet_rsk(req); + + return tcp_v4_ao_calc_key(mkt, key, + ireq->ir_loc_addr, ireq->ir_rmt_addr, + htons(ireq->ir_num), ireq->ir_rmt_port, + htonl(tcp_rsk(req)->snt_isn), + htonl(tcp_rsk(req)->rcv_isn)); +} + +static int tcp_v4_ao_calc_key_skb(struct tcp_ao_key *mkt, u8 *key, + const struct sk_buff *skb, + __be32 sisn, __be32 disn) +{ + const struct iphdr *iph = ip_hdr(skb); + const struct tcphdr *th = tcp_hdr(skb); + + return tcp_v4_ao_calc_key(mkt, key, iph->saddr, iph->daddr, + th->source, th->dest, sisn, disn); +} + +static int tcp_ao_calc_key_skb(struct tcp_ao_key *mkt, u8 *key, + const struct sk_buff *skb, + __be32 sisn, __be32 disn, int family) +{ + if (family == AF_INET) + return tcp_v4_ao_calc_key_skb(mkt, key, skb, sisn, disn); +#if IS_ENABLED(CONFIG_IPV6) + else if (family == AF_INET6) + return tcp_v6_ao_calc_key_skb(mkt, key, skb, sisn, disn); +#endif + return -EAFNOSUPPORT; +} + +static int tcp_v4_ao_hash_pseudoheader(struct tcp_sigpool *hp, + __be32 daddr, __be32 saddr, + int nbytes) +{ + struct tcp4_pseudohdr *bp; + struct scatterlist sg; + + bp = hp->scratch; + bp->saddr = saddr; + bp->daddr = daddr; + bp->pad = 0; + bp->protocol = IPPROTO_TCP; + bp->len = cpu_to_be16(nbytes); + + sg_init_one(&sg, bp, sizeof(*bp)); + ahash_request_set_crypt(hp->req, &sg, NULL, sizeof(*bp)); + return crypto_ahash_update(hp->req); +} + +static int tcp_ao_hash_pseudoheader(unsigned short int family, + const struct sock *sk, + const struct sk_buff *skb, + struct tcp_sigpool *hp, int nbytes) +{ + const struct tcphdr *th = tcp_hdr(skb); + + /* TODO: Can we rely on checksum being zero to mean outbound pkt? */ + if (!th->check) { + if (family == AF_INET) + return tcp_v4_ao_hash_pseudoheader(hp, sk->sk_daddr, + sk->sk_rcv_saddr, skb->len); +#if IS_ENABLED(CONFIG_IPV6) + else if (family == AF_INET6) + return tcp_v6_ao_hash_pseudoheader(hp, &sk->sk_v6_daddr, + &sk->sk_v6_rcv_saddr, skb->len); +#endif + else + return -EAFNOSUPPORT; + } + + if (family == AF_INET) { + const struct iphdr *iph = ip_hdr(skb); + + return tcp_v4_ao_hash_pseudoheader(hp, iph->daddr, + iph->saddr, skb->len); +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + const struct ipv6hdr *iph = ipv6_hdr(skb); + + return tcp_v6_ao_hash_pseudoheader(hp, &iph->daddr, + &iph->saddr, skb->len); +#endif + } + return -EAFNOSUPPORT; +} + +u32 tcp_ao_compute_sne(u32 next_sne, u32 next_seq, u32 seq) +{ + u32 sne = next_sne; + + if (before(seq, next_seq)) { + if (seq > next_seq) + sne--; + } else { + if (seq < next_seq) + sne++; + } + + return sne; +} + +/* tcp_ao_hash_sne(struct tcp_sigpool *hp) + * @hp - used for hashing + * @sne - sne value + */ +static int tcp_ao_hash_sne(struct tcp_sigpool *hp, u32 sne) +{ + struct scatterlist sg; + __be32 *bp; + + bp = (__be32 *)hp->scratch; + *bp = htonl(sne); + + sg_init_one(&sg, bp, sizeof(*bp)); + ahash_request_set_crypt(hp->req, &sg, NULL, sizeof(*bp)); + return crypto_ahash_update(hp->req); +} + +static int tcp_ao_hash_header(struct tcp_sigpool *hp, + const struct tcphdr *th, + bool exclude_options, u8 *hash, + int hash_offset, int hash_len) +{ + struct scatterlist sg; + u8 *hdr = hp->scratch; + int err, len; + + /* We are not allowed to change tcphdr, make a local copy */ + if (exclude_options) { + len = sizeof(*th) + sizeof(struct tcp_ao_hdr) + hash_len; + memcpy(hdr, th, sizeof(*th)); + memcpy(hdr + sizeof(*th), + (u8 *)th + hash_offset - sizeof(struct tcp_ao_hdr), + sizeof(struct tcp_ao_hdr)); + memset(hdr + sizeof(*th) + sizeof(struct tcp_ao_hdr), + 0, hash_len); + ((struct tcphdr *)hdr)->check = 0; + } else { + len = th->doff << 2; + memcpy(hdr, th, len); + /* zero out tcp-ao hash */ + ((struct tcphdr *)hdr)->check = 0; + memset(hdr + hash_offset, 0, hash_len); + } + + sg_init_one(&sg, hdr, len); + ahash_request_set_crypt(hp->req, &sg, NULL, len); + err = crypto_ahash_update(hp->req); + WARN_ON_ONCE(err != 0); + return err; +} + +int tcp_ao_hash_hdr(unsigned short int family, char *ao_hash, + struct tcp_ao_key *key, const u8 *tkey, + const union tcp_ao_addr *daddr, + const union tcp_ao_addr *saddr, + const struct tcphdr *th, u32 sne) +{ + int tkey_len = tcp_ao_digest_size(key); + int hash_offset = ao_hash - (char *)th; + struct tcp_sigpool hp; + void *hash_buf = NULL; + + hash_buf = kmalloc(tkey_len, GFP_ATOMIC); + if (!hash_buf) + goto clear_hash_noput; + + if (tcp_sigpool_start(key->tcp_sigpool_id, &hp)) + goto clear_hash_noput; + + if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp.req), tkey, tkey_len)) + goto clear_hash; + + if (crypto_ahash_init(hp.req)) + goto clear_hash; + + if (tcp_ao_hash_sne(&hp, sne)) + goto clear_hash; + if (family == AF_INET) { + if (tcp_v4_ao_hash_pseudoheader(&hp, daddr->a4.s_addr, + saddr->a4.s_addr, th->doff * 4)) + goto clear_hash; +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + if (tcp_v6_ao_hash_pseudoheader(&hp, &daddr->a6, + &saddr->a6, th->doff * 4)) + goto clear_hash; +#endif + } else { + WARN_ON_ONCE(1); + goto clear_hash; + } + if (tcp_ao_hash_header(&hp, th, + !!(key->keyflags & TCP_AO_KEYF_EXCLUDE_OPT), + ao_hash, hash_offset, tcp_ao_maclen(key))) + goto clear_hash; + ahash_request_set_crypt(hp.req, NULL, hash_buf, 0); + if (crypto_ahash_final(hp.req)) + goto clear_hash; + + memcpy(ao_hash, hash_buf, tcp_ao_maclen(key)); + tcp_sigpool_end(&hp); + kfree(hash_buf); + return 0; + +clear_hash: + tcp_sigpool_end(&hp); +clear_hash_noput: + memset(ao_hash, 0, tcp_ao_maclen(key)); + kfree(hash_buf); + return 1; +} + +int tcp_ao_hash_skb(unsigned short int family, + char *ao_hash, struct tcp_ao_key *key, + const struct sock *sk, const struct sk_buff *skb, + const u8 *tkey, int hash_offset, u32 sne) +{ + const struct tcphdr *th = tcp_hdr(skb); + int tkey_len = tcp_ao_digest_size(key); + struct tcp_sigpool hp; + void *hash_buf = NULL; + + hash_buf = kmalloc(tkey_len, GFP_ATOMIC); + if (!hash_buf) + goto clear_hash_noput; + + if (tcp_sigpool_start(key->tcp_sigpool_id, &hp)) + goto clear_hash_noput; + + if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp.req), tkey, tkey_len)) + goto clear_hash; + + /* For now use sha1 by default. Depends on alg in tcp_ao_key */ + if (crypto_ahash_init(hp.req)) + goto clear_hash; + + if (tcp_ao_hash_sne(&hp, sne)) + goto clear_hash; + if (tcp_ao_hash_pseudoheader(family, sk, skb, &hp, skb->len)) + goto clear_hash; + if (tcp_ao_hash_header(&hp, th, + !!(key->keyflags & TCP_AO_KEYF_EXCLUDE_OPT), + ao_hash, hash_offset, tcp_ao_maclen(key))) + goto clear_hash; + if (tcp_sigpool_hash_skb_data(&hp, skb, th->doff << 2)) + goto clear_hash; + ahash_request_set_crypt(hp.req, NULL, hash_buf, 0); + if (crypto_ahash_final(hp.req)) + goto clear_hash; + + memcpy(ao_hash, hash_buf, tcp_ao_maclen(key)); + tcp_sigpool_end(&hp); + kfree(hash_buf); + return 0; + +clear_hash: + tcp_sigpool_end(&hp); +clear_hash_noput: + memset(ao_hash, 0, tcp_ao_maclen(key)); + kfree(hash_buf); + return 1; +} + +int tcp_v4_ao_hash_skb(char *ao_hash, struct tcp_ao_key *key, + const struct sock *sk, const struct sk_buff *skb, + const u8 *tkey, int hash_offset, u32 sne) +{ + return tcp_ao_hash_skb(AF_INET, ao_hash, key, sk, skb, + tkey, hash_offset, sne); +} + +int tcp_v4_ao_synack_hash(char *ao_hash, struct tcp_ao_key *ao_key, + struct request_sock *req, const struct sk_buff *skb, + int hash_offset, u32 sne) +{ + void *hash_buf = NULL; + int err; + + hash_buf = kmalloc(tcp_ao_digest_size(ao_key), GFP_ATOMIC); + if (!hash_buf) + return -ENOMEM; + + err = tcp_v4_ao_calc_key_rsk(ao_key, hash_buf, req); + if (err) + goto out; + + err = tcp_ao_hash_skb(AF_INET, ao_hash, ao_key, req_to_sk(req), skb, + hash_buf, hash_offset, sne); +out: + kfree(hash_buf); + return err; +} + +struct tcp_ao_key *tcp_v4_ao_lookup_rsk(const struct sock *sk, + struct request_sock *req, + int sndid, int rcvid) +{ + struct inet_request_sock *ireq = inet_rsk(req); + union tcp_ao_addr *addr = (union tcp_ao_addr *)&ireq->ir_rmt_addr; + int l3index; + + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif); + return tcp_ao_do_lookup(sk, l3index, addr, AF_INET, sndid, rcvid); +} + +struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk, + int sndid, int rcvid) +{ + int l3index = l3mdev_master_ifindex_by_index(sock_net(sk), + addr_sk->sk_bound_dev_if); + union tcp_ao_addr *addr = (union tcp_ao_addr *)&addr_sk->sk_daddr; + + return tcp_ao_do_lookup(sk, l3index, addr, AF_INET, sndid, rcvid); +} + +int tcp_ao_prepare_reset(const struct sock *sk, struct sk_buff *skb, + const struct tcp_ao_hdr *aoh, int l3index, u32 seq, + struct tcp_ao_key **key, char **traffic_key, + bool *allocated_traffic_key, u8 *keyid, u32 *sne) +{ + const struct tcphdr *th = tcp_hdr(skb); + struct tcp_ao_info *ao_info; + + *allocated_traffic_key = false; + /* If there's no socket - than initial sisn/disn are unknown. + * Drop the segment. RFC5925 (7.7) advises to require graceful + * restart [RFC4724]. Alternatively, the RFC5925 advises to + * save/restore traffic keys before/after reboot. + * Linux TCP-AO support provides TCP_AO_ADD_KEY and TCP_AO_REPAIR + * options to restore a socket post-reboot. + */ + if (!sk) + return -ENOTCONN; + + if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV)) { + unsigned int family = READ_ONCE(sk->sk_family); + union tcp_ao_addr *addr; + __be32 disn, sisn; + + if (sk->sk_state == TCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + + sisn = htonl(tcp_rsk(req)->rcv_isn); + disn = htonl(tcp_rsk(req)->snt_isn); + *sne = tcp_ao_compute_sne(0, tcp_rsk(req)->snt_isn, seq); + } else { + sisn = th->seq; + disn = 0; + } + if (IS_ENABLED(CONFIG_IPV6) && family == AF_INET6) + addr = (union tcp_md5_addr *)&ipv6_hdr(skb)->saddr; + else + addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; +#if IS_ENABLED(CONFIG_IPV6) + if (family == AF_INET6 && ipv6_addr_v4mapped(&sk->sk_v6_daddr)) + family = AF_INET; +#endif + + sk = sk_const_to_full_sk(sk); + ao_info = rcu_dereference(tcp_sk(sk)->ao_info); + if (!ao_info) + return -ENOENT; + *key = tcp_ao_do_lookup(sk, l3index, addr, family, + -1, aoh->rnext_keyid); + if (!*key) + return -ENOENT; + *traffic_key = kmalloc(tcp_ao_digest_size(*key), GFP_ATOMIC); + if (!*traffic_key) + return -ENOMEM; + *allocated_traffic_key = true; + if (tcp_ao_calc_key_skb(*key, *traffic_key, skb, + sisn, disn, family)) + return -1; + *keyid = (*key)->rcvid; + } else { + struct tcp_ao_key *rnext_key; + u32 snd_basis; + + if (sk->sk_state == TCP_TIME_WAIT) { + ao_info = rcu_dereference(tcp_twsk(sk)->ao_info); + snd_basis = tcp_twsk(sk)->tw_snd_nxt; + } else { + ao_info = rcu_dereference(tcp_sk(sk)->ao_info); + snd_basis = tcp_sk(sk)->snd_una; + } + if (!ao_info) + return -ENOENT; + + *key = tcp_ao_established_key(sk, ao_info, aoh->rnext_keyid, -1); + if (!*key) + return -ENOENT; + *traffic_key = snd_other_key(*key); + rnext_key = READ_ONCE(ao_info->rnext_key); + *keyid = rnext_key->rcvid; + *sne = tcp_ao_compute_sne(READ_ONCE(ao_info->snd_sne), + snd_basis, seq); + } + return 0; +} + +int tcp_ao_transmit_skb(struct sock *sk, struct sk_buff *skb, + struct tcp_ao_key *key, struct tcphdr *th, + __u8 *hash_location) +{ + struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_ao_info *ao; + void *tkey_buf = NULL; + u8 *traffic_key; + u32 sne; + + ao = rcu_dereference_protected(tcp_sk(sk)->ao_info, + lockdep_sock_is_held(sk)); + traffic_key = snd_other_key(key); + if (unlikely(tcb->tcp_flags & TCPHDR_SYN)) { + __be32 disn; + + if (!(tcb->tcp_flags & TCPHDR_ACK)) { + disn = 0; + tkey_buf = kmalloc(tcp_ao_digest_size(key), GFP_ATOMIC); + if (!tkey_buf) + return -ENOMEM; + traffic_key = tkey_buf; + } else { + disn = ao->risn; + } + tp->af_specific->ao_calc_key_sk(key, traffic_key, + sk, ao->lisn, disn, true); + } + sne = tcp_ao_compute_sne(READ_ONCE(ao->snd_sne), READ_ONCE(tp->snd_una), + ntohl(th->seq)); + tp->af_specific->calc_ao_hash(hash_location, key, sk, skb, traffic_key, + hash_location - (u8 *)th, sne); + kfree(tkey_buf); + return 0; +} + +static struct tcp_ao_key *tcp_ao_inbound_lookup(unsigned short int family, + const struct sock *sk, const struct sk_buff *skb, + int sndid, int rcvid, int l3index) +{ + if (family == AF_INET) { + const struct iphdr *iph = ip_hdr(skb); + + return tcp_ao_do_lookup(sk, l3index, + (union tcp_ao_addr *)&iph->saddr, + AF_INET, sndid, rcvid); + } else { + const struct ipv6hdr *iph = ipv6_hdr(skb); + + return tcp_ao_do_lookup(sk, l3index, + (union tcp_ao_addr *)&iph->saddr, + AF_INET6, sndid, rcvid); + } +} + +void tcp_ao_syncookie(struct sock *sk, const struct sk_buff *skb, + struct request_sock *req, unsigned short int family) +{ + struct tcp_request_sock *treq = tcp_rsk(req); + const struct tcphdr *th = tcp_hdr(skb); + const struct tcp_ao_hdr *aoh; + struct tcp_ao_key *key; + int l3index; + + /* treq->af_specific is used to perform TCP_AO lookup + * in tcp_create_openreq_child(). + */ +#if IS_ENABLED(CONFIG_IPV6) + if (family == AF_INET6) + treq->af_specific = &tcp_request_sock_ipv6_ops; + else +#endif + treq->af_specific = &tcp_request_sock_ipv4_ops; + + treq->used_tcp_ao = false; + + if (tcp_parse_auth_options(th, NULL, &aoh) || !aoh) + return; + + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), inet_rsk(req)->ir_iif); + key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid, l3index); + if (!key) + /* Key not found, continue without TCP-AO */ + return; + + treq->ao_rcv_next = aoh->keyid; + treq->ao_keyid = aoh->rnext_keyid; + treq->used_tcp_ao = true; +} + +static enum skb_drop_reason +tcp_ao_verify_hash(const struct sock *sk, const struct sk_buff *skb, + unsigned short int family, struct tcp_ao_info *info, + const struct tcp_ao_hdr *aoh, struct tcp_ao_key *key, + u8 *traffic_key, u8 *phash, u32 sne, int l3index) +{ + const struct tcphdr *th = tcp_hdr(skb); + u8 maclen = tcp_ao_hdr_maclen(aoh); + void *hash_buf = NULL; + + if (maclen != tcp_ao_maclen(key)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOBAD); + atomic64_inc(&info->counters.pkt_bad); + atomic64_inc(&key->pkt_bad); + trace_tcp_ao_wrong_maclen(sk, skb, aoh->keyid, + aoh->rnext_keyid, maclen); + return SKB_DROP_REASON_TCP_AOFAILURE; + } + + hash_buf = kmalloc(tcp_ao_digest_size(key), GFP_ATOMIC); + if (!hash_buf) + return SKB_DROP_REASON_NOT_SPECIFIED; + + /* XXX: make it per-AF callback? */ + tcp_ao_hash_skb(family, hash_buf, key, sk, skb, traffic_key, + (phash - (u8 *)th), sne); + if (memcmp(phash, hash_buf, maclen)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOBAD); + atomic64_inc(&info->counters.pkt_bad); + atomic64_inc(&key->pkt_bad); + trace_tcp_ao_mismatch(sk, skb, aoh->keyid, + aoh->rnext_keyid, maclen); + kfree(hash_buf); + return SKB_DROP_REASON_TCP_AOFAILURE; + } + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOGOOD); + atomic64_inc(&info->counters.pkt_good); + atomic64_inc(&key->pkt_good); + kfree(hash_buf); + return SKB_NOT_DROPPED_YET; +} + +enum skb_drop_reason +tcp_inbound_ao_hash(struct sock *sk, const struct sk_buff *skb, + unsigned short int family, const struct request_sock *req, + int l3index, const struct tcp_ao_hdr *aoh) +{ + const struct tcphdr *th = tcp_hdr(skb); + u8 maclen = tcp_ao_hdr_maclen(aoh); + u8 *phash = (u8 *)(aoh + 1); /* hash goes just after the header */ + struct tcp_ao_info *info; + enum skb_drop_reason ret; + struct tcp_ao_key *key; + __be32 sisn, disn; + u8 *traffic_key; + int state; + u32 sne = 0; + + info = rcu_dereference(tcp_sk(sk)->ao_info); + if (!info) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOKEYNOTFOUND); + trace_tcp_ao_key_not_found(sk, skb, aoh->keyid, + aoh->rnext_keyid, maclen); + return SKB_DROP_REASON_TCP_AOUNEXPECTED; + } + + if (unlikely(th->syn)) { + sisn = th->seq; + disn = 0; + } + + state = READ_ONCE(sk->sk_state); + /* Fast-path */ + if (likely((1 << state) & TCP_AO_ESTABLISHED)) { + enum skb_drop_reason err; + struct tcp_ao_key *current_key; + + /* Check if this socket's rnext_key matches the keyid in the + * packet. If not we lookup the key based on the keyid + * matching the rcvid in the mkt. + */ + key = READ_ONCE(info->rnext_key); + if (key->rcvid != aoh->keyid) { + key = tcp_ao_established_key(sk, info, -1, aoh->keyid); + if (!key) + goto key_not_found; + } + + /* Delayed retransmitted SYN */ + if (unlikely(th->syn && !th->ack)) + goto verify_hash; + + sne = tcp_ao_compute_sne(info->rcv_sne, tcp_sk(sk)->rcv_nxt, + ntohl(th->seq)); + /* Established socket, traffic key are cached */ + traffic_key = rcv_other_key(key); + err = tcp_ao_verify_hash(sk, skb, family, info, aoh, key, + traffic_key, phash, sne, l3index); + if (err) + return err; + current_key = READ_ONCE(info->current_key); + /* Key rotation: the peer asks us to use new key (RNext) */ + if (unlikely(aoh->rnext_keyid != current_key->sndid)) { + trace_tcp_ao_rnext_request(sk, skb, current_key->sndid, + aoh->rnext_keyid, + tcp_ao_hdr_maclen(aoh)); + /* If the key is not found we do nothing. */ + key = tcp_ao_established_key(sk, info, aoh->rnext_keyid, -1); + if (key) + /* pairs with tcp_ao_del_cmd */ + WRITE_ONCE(info->current_key, key); + } + return SKB_NOT_DROPPED_YET; + } + + if (unlikely(state == TCP_CLOSE)) + return SKB_DROP_REASON_TCP_CLOSE; + + /* Lookup key based on peer address and keyid. + * current_key and rnext_key must not be used on tcp listen + * sockets as otherwise: + * - request sockets would race on those key pointers + * - tcp_ao_del_cmd() allows async key removal + */ + key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid, l3index); + if (!key) + goto key_not_found; + + if (th->syn && !th->ack) + goto verify_hash; + + if ((1 << state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV)) { + /* Make the initial syn the likely case here */ + if (unlikely(req)) { + sne = tcp_ao_compute_sne(0, tcp_rsk(req)->rcv_isn, + ntohl(th->seq)); + sisn = htonl(tcp_rsk(req)->rcv_isn); + disn = htonl(tcp_rsk(req)->snt_isn); + } else if (unlikely(th->ack && !th->syn)) { + /* Possible syncookie packet */ + sisn = htonl(ntohl(th->seq) - 1); + disn = htonl(ntohl(th->ack_seq) - 1); + sne = tcp_ao_compute_sne(0, ntohl(sisn), + ntohl(th->seq)); + } else if (unlikely(!th->syn)) { + /* no way to figure out initial sisn/disn - drop */ + return SKB_DROP_REASON_TCP_FLAGS; + } + } else if ((1 << state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) { + disn = info->lisn; + if (th->syn || th->rst) + sisn = th->seq; + else + sisn = info->risn; + } else { + WARN_ONCE(1, "TCP-AO: Unexpected sk_state %d", state); + return SKB_DROP_REASON_TCP_AOFAILURE; + } +verify_hash: + traffic_key = kmalloc(tcp_ao_digest_size(key), GFP_ATOMIC); + if (!traffic_key) + return SKB_DROP_REASON_NOT_SPECIFIED; + tcp_ao_calc_key_skb(key, traffic_key, skb, sisn, disn, family); + ret = tcp_ao_verify_hash(sk, skb, family, info, aoh, key, + traffic_key, phash, sne, l3index); + kfree(traffic_key); + return ret; + +key_not_found: + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOKEYNOTFOUND); + atomic64_inc(&info->counters.key_not_found); + trace_tcp_ao_key_not_found(sk, skb, aoh->keyid, + aoh->rnext_keyid, maclen); + return SKB_DROP_REASON_TCP_AOKEYNOTFOUND; +} + +static int tcp_ao_cache_traffic_keys(const struct sock *sk, + struct tcp_ao_info *ao, + struct tcp_ao_key *ao_key) +{ + u8 *traffic_key = snd_other_key(ao_key); + int ret; + + ret = tcp_ao_calc_key_sk(ao_key, traffic_key, sk, + ao->lisn, ao->risn, true); + if (ret) + return ret; + + traffic_key = rcv_other_key(ao_key); + ret = tcp_ao_calc_key_sk(ao_key, traffic_key, sk, + ao->lisn, ao->risn, false); + return ret; +} + +void tcp_ao_connect_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_ao_info *ao_info; + struct hlist_node *next; + union tcp_ao_addr *addr; + struct tcp_ao_key *key; + int family, l3index; + + ao_info = rcu_dereference_protected(tp->ao_info, + lockdep_sock_is_held(sk)); + if (!ao_info) + return; + + /* Remove all keys that don't match the peer */ + family = sk->sk_family; + if (family == AF_INET) + addr = (union tcp_ao_addr *)&sk->sk_daddr; +#if IS_ENABLED(CONFIG_IPV6) + else if (family == AF_INET6) + addr = (union tcp_ao_addr *)&sk->sk_v6_daddr; +#endif + else + return; + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), + sk->sk_bound_dev_if); + + hlist_for_each_entry_safe(key, next, &ao_info->head, node) { + if (!tcp_ao_key_cmp(key, l3index, addr, key->prefixlen, family, -1, -1)) + continue; + + if (key == ao_info->current_key) + ao_info->current_key = NULL; + if (key == ao_info->rnext_key) + ao_info->rnext_key = NULL; + hlist_del_rcu(&key->node); + atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc); + call_rcu(&key->rcu, tcp_ao_key_free_rcu); + } + + key = tp->af_specific->ao_lookup(sk, sk, -1, -1); + if (key) { + /* if current_key or rnext_key were not provided, + * use the first key matching the peer + */ + if (!ao_info->current_key) + ao_info->current_key = key; + if (!ao_info->rnext_key) + ao_info->rnext_key = key; + tp->tcp_header_len += tcp_ao_len_aligned(key); + + ao_info->lisn = htonl(tp->write_seq); + ao_info->snd_sne = 0; + } else { + /* Can't happen: tcp_connect() verifies that there's + * at least one tcp-ao key that matches the remote peer. + */ + WARN_ON_ONCE(1); + rcu_assign_pointer(tp->ao_info, NULL); + kfree(ao_info); + } +} + +void tcp_ao_established(struct sock *sk) +{ + struct tcp_ao_info *ao; + struct tcp_ao_key *key; + + ao = rcu_dereference_protected(tcp_sk(sk)->ao_info, + lockdep_sock_is_held(sk)); + if (!ao) + return; + + hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk)) + tcp_ao_cache_traffic_keys(sk, ao, key); +} + +void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_ao_info *ao; + struct tcp_ao_key *key; + + ao = rcu_dereference_protected(tcp_sk(sk)->ao_info, + lockdep_sock_is_held(sk)); + if (!ao) + return; + + /* sk with TCP_REPAIR_ON does not have skb in tcp_finish_connect */ + if (skb) + WRITE_ONCE(ao->risn, tcp_hdr(skb)->seq); + ao->rcv_sne = 0; + + hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk)) + tcp_ao_cache_traffic_keys(sk, ao, key); +} + +int tcp_ao_copy_all_matching(const struct sock *sk, struct sock *newsk, + struct request_sock *req, struct sk_buff *skb, + int family) +{ + struct tcp_ao_key *key, *new_key, *first_key; + struct tcp_ao_info *new_ao, *ao; + struct hlist_node *key_head; + int l3index, ret = -ENOMEM; + union tcp_ao_addr *addr; + bool match = false; + + ao = rcu_dereference(tcp_sk(sk)->ao_info); + if (!ao) + return 0; + + /* New socket without TCP-AO on it */ + if (!tcp_rsk_used_ao(req)) + return 0; + + new_ao = tcp_ao_alloc_info(GFP_ATOMIC); + if (!new_ao) + return -ENOMEM; + new_ao->lisn = htonl(tcp_rsk(req)->snt_isn); + new_ao->risn = htonl(tcp_rsk(req)->rcv_isn); + new_ao->ao_required = ao->ao_required; + new_ao->accept_icmps = ao->accept_icmps; + + if (family == AF_INET) { + addr = (union tcp_ao_addr *)&newsk->sk_daddr; +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + addr = (union tcp_ao_addr *)&newsk->sk_v6_daddr; +#endif + } else { + ret = -EAFNOSUPPORT; + goto free_ao; + } + l3index = l3mdev_master_ifindex_by_index(sock_net(newsk), + newsk->sk_bound_dev_if); + + hlist_for_each_entry_rcu(key, &ao->head, node) { + if (tcp_ao_key_cmp(key, l3index, addr, key->prefixlen, family, -1, -1)) + continue; + + new_key = tcp_ao_copy_key(newsk, key); + if (!new_key) + goto free_and_exit; + + tcp_ao_cache_traffic_keys(newsk, new_ao, new_key); + tcp_ao_link_mkt(new_ao, new_key); + match = true; + } + + if (!match) { + /* RFC5925 (7.4.1) specifies that the TCP-AO status + * of a connection is determined on the initial SYN. + * At this point the connection was TCP-AO enabled, so + * it can't switch to being unsigned if peer's key + * disappears on the listening socket. + */ + ret = -EKEYREJECTED; + goto free_and_exit; + } + + if (!static_key_fast_inc_not_disabled(&tcp_ao_needed.key.key)) { + ret = -EUSERS; + goto free_and_exit; + } + + key_head = rcu_dereference(hlist_first_rcu(&new_ao->head)); + first_key = hlist_entry_safe(key_head, struct tcp_ao_key, node); + + key = tcp_ao_established_key(req_to_sk(req), new_ao, tcp_rsk(req)->ao_keyid, -1); + if (key) + new_ao->current_key = key; + else + new_ao->current_key = first_key; + + /* set rnext_key */ + key = tcp_ao_established_key(req_to_sk(req), new_ao, -1, tcp_rsk(req)->ao_rcv_next); + if (key) + new_ao->rnext_key = key; + else + new_ao->rnext_key = first_key; + + sk_gso_disable(newsk); + rcu_assign_pointer(tcp_sk(newsk)->ao_info, new_ao); + + return 0; + +free_and_exit: + hlist_for_each_entry_safe(key, key_head, &new_ao->head, node) { + hlist_del(&key->node); + tcp_sigpool_release(key->tcp_sigpool_id); + atomic_sub(tcp_ao_sizeof_key(key), &newsk->sk_omem_alloc); + kfree_sensitive(key); + } +free_ao: + kfree(new_ao); + return ret; +} + +static bool tcp_ao_can_set_current_rnext(struct sock *sk) +{ + /* There aren't current/rnext keys on TCP_LISTEN sockets */ + if (sk->sk_state == TCP_LISTEN) + return false; + return true; +} + +static int tcp_ao_verify_ipv4(struct sock *sk, struct tcp_ao_add *cmd, + union tcp_ao_addr **addr) +{ + struct sockaddr_in *sin = (struct sockaddr_in *)&cmd->addr; + struct inet_sock *inet = inet_sk(sk); + + if (sin->sin_family != AF_INET) + return -EINVAL; + + /* Currently matching is not performed on port (or port ranges) */ + if (sin->sin_port != 0) + return -EINVAL; + + /* Check prefix and trailing 0's in addr */ + if (cmd->prefix != 0) { + __be32 mask; + + if (ntohl(sin->sin_addr.s_addr) == INADDR_ANY) + return -EINVAL; + if (cmd->prefix > 32) + return -EINVAL; + + mask = inet_make_mask(cmd->prefix); + if (sin->sin_addr.s_addr & ~mask) + return -EINVAL; + + /* Check that MKT address is consistent with socket */ + if (ntohl(inet->inet_daddr) != INADDR_ANY && + (inet->inet_daddr & mask) != sin->sin_addr.s_addr) + return -EINVAL; + } else { + if (ntohl(sin->sin_addr.s_addr) != INADDR_ANY) + return -EINVAL; + } + + *addr = (union tcp_ao_addr *)&sin->sin_addr; + return 0; +} + +static int tcp_ao_parse_crypto(struct tcp_ao_add *cmd, struct tcp_ao_key *key) +{ + unsigned int syn_tcp_option_space; + bool is_kdf_aes_128_cmac = false; + struct crypto_ahash *tfm; + struct tcp_sigpool hp; + void *tmp_key = NULL; + int err; + + /* RFC5926, 3.1.1.2. KDF_AES_128_CMAC */ + if (!strcmp("cmac(aes128)", cmd->alg_name)) { + strscpy(cmd->alg_name, "cmac(aes)", sizeof(cmd->alg_name)); + is_kdf_aes_128_cmac = (cmd->keylen != 16); + tmp_key = kmalloc(cmd->keylen, GFP_KERNEL); + if (!tmp_key) + return -ENOMEM; + } + + key->maclen = cmd->maclen ?: 12; /* 12 is the default in RFC5925 */ + + /* Check: maclen + tcp-ao header <= (MAX_TCP_OPTION_SPACE - mss + * - tstamp (including sackperm) + * - wscale), + * see tcp_syn_options(), tcp_synack_options(), commit 33ad798c924b. + * + * In order to allow D-SACK with TCP-AO, the header size should be: + * (MAX_TCP_OPTION_SPACE - TCPOLEN_TSTAMP_ALIGNED + * - TCPOLEN_SACK_BASE_ALIGNED + * - 2 * TCPOLEN_SACK_PERBLOCK) = 8 (maclen = 4), + * see tcp_established_options(). + * + * RFC5925, 2.2: + * Typical MACs are 96-128 bits (12-16 bytes), but any length + * that fits in the header of the segment being authenticated + * is allowed. + * + * RFC5925, 7.6: + * TCP-AO continues to consume 16 bytes in non-SYN segments, + * leaving a total of 24 bytes for other options, of which + * the timestamp consumes 10. This leaves 14 bytes, of which 10 + * are used for a single SACK block. When two SACK blocks are used, + * such as to handle D-SACK, a smaller TCP-AO MAC would be required + * to make room for the additional SACK block (i.e., to leave 18 + * bytes for the D-SACK variant of the SACK option) [RFC2883]. + * Note that D-SACK is not supportable in TCP MD5 in the presence + * of timestamps, because TCP MD5’s MAC length is fixed and too + * large to leave sufficient option space. + */ + syn_tcp_option_space = MAX_TCP_OPTION_SPACE; + syn_tcp_option_space -= TCPOLEN_MSS_ALIGNED; + syn_tcp_option_space -= TCPOLEN_TSTAMP_ALIGNED; + syn_tcp_option_space -= TCPOLEN_WSCALE_ALIGNED; + if (tcp_ao_len_aligned(key) > syn_tcp_option_space) { + err = -EMSGSIZE; + goto err_kfree; + } + + key->keylen = cmd->keylen; + memcpy(key->key, cmd->key, cmd->keylen); + + err = tcp_sigpool_start(key->tcp_sigpool_id, &hp); + if (err) + goto err_kfree; + + tfm = crypto_ahash_reqtfm(hp.req); + if (is_kdf_aes_128_cmac) { + void *scratch = hp.scratch; + struct scatterlist sg; + + memcpy(tmp_key, cmd->key, cmd->keylen); + sg_init_one(&sg, tmp_key, cmd->keylen); + + /* Using zero-key of 16 bytes as described in RFC5926 */ + memset(scratch, 0, 16); + err = crypto_ahash_setkey(tfm, scratch, 16); + if (err) + goto err_pool_end; + + err = crypto_ahash_init(hp.req); + if (err) + goto err_pool_end; + + ahash_request_set_crypt(hp.req, &sg, key->key, cmd->keylen); + err = crypto_ahash_update(hp.req); + if (err) + goto err_pool_end; + + err |= crypto_ahash_final(hp.req); + if (err) + goto err_pool_end; + key->keylen = 16; + } + + err = crypto_ahash_setkey(tfm, key->key, key->keylen); + if (err) + goto err_pool_end; + + tcp_sigpool_end(&hp); + kfree_sensitive(tmp_key); + + if (tcp_ao_maclen(key) > key->digest_size) + return -EINVAL; + + return 0; + +err_pool_end: + tcp_sigpool_end(&hp); +err_kfree: + kfree_sensitive(tmp_key); + return err; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int tcp_ao_verify_ipv6(struct sock *sk, struct tcp_ao_add *cmd, + union tcp_ao_addr **paddr, + unsigned short int *family) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd->addr; + struct in6_addr *addr = &sin6->sin6_addr; + u8 prefix = cmd->prefix; + + if (sin6->sin6_family != AF_INET6) + return -EINVAL; + + /* Currently matching is not performed on port (or port ranges) */ + if (sin6->sin6_port != 0) + return -EINVAL; + + /* Check prefix and trailing 0's in addr */ + if (cmd->prefix != 0 && ipv6_addr_v4mapped(addr)) { + __be32 addr4 = addr->s6_addr32[3]; + __be32 mask; + + if (prefix > 32 || ntohl(addr4) == INADDR_ANY) + return -EINVAL; + + mask = inet_make_mask(prefix); + if (addr4 & ~mask) + return -EINVAL; + + /* Check that MKT address is consistent with socket */ + if (!ipv6_addr_any(&sk->sk_v6_daddr)) { + __be32 daddr4 = sk->sk_v6_daddr.s6_addr32[3]; + + if (!ipv6_addr_v4mapped(&sk->sk_v6_daddr)) + return -EINVAL; + if ((daddr4 & mask) != addr4) + return -EINVAL; + } + + *paddr = (union tcp_ao_addr *)&addr->s6_addr32[3]; + *family = AF_INET; + return 0; + } else if (cmd->prefix != 0) { + struct in6_addr pfx; + + if (ipv6_addr_any(addr) || prefix > 128) + return -EINVAL; + + ipv6_addr_prefix(&pfx, addr, prefix); + if (ipv6_addr_cmp(&pfx, addr)) + return -EINVAL; + + /* Check that MKT address is consistent with socket */ + if (!ipv6_addr_any(&sk->sk_v6_daddr) && + !ipv6_prefix_equal(&sk->sk_v6_daddr, addr, prefix)) + + return -EINVAL; + } else { + if (!ipv6_addr_any(addr)) + return -EINVAL; + } + + *paddr = (union tcp_ao_addr *)addr; + return 0; +} +#else +static int tcp_ao_verify_ipv6(struct sock *sk, struct tcp_ao_add *cmd, + union tcp_ao_addr **paddr, + unsigned short int *family) +{ + return -EOPNOTSUPP; +} +#endif + +static struct tcp_ao_info *setsockopt_ao_info(struct sock *sk) +{ + if (sk_fullsock(sk)) { + return rcu_dereference_protected(tcp_sk(sk)->ao_info, + lockdep_sock_is_held(sk)); + } else if (sk->sk_state == TCP_TIME_WAIT) { + return rcu_dereference_protected(tcp_twsk(sk)->ao_info, + lockdep_sock_is_held(sk)); + } + return ERR_PTR(-ESOCKTNOSUPPORT); +} + +static struct tcp_ao_info *getsockopt_ao_info(struct sock *sk) +{ + if (sk_fullsock(sk)) + return rcu_dereference(tcp_sk(sk)->ao_info); + else if (sk->sk_state == TCP_TIME_WAIT) + return rcu_dereference(tcp_twsk(sk)->ao_info); + + return ERR_PTR(-ESOCKTNOSUPPORT); +} + +#define TCP_AO_KEYF_ALL (TCP_AO_KEYF_IFINDEX | TCP_AO_KEYF_EXCLUDE_OPT) +#define TCP_AO_GET_KEYF_VALID (TCP_AO_KEYF_IFINDEX) + +static struct tcp_ao_key *tcp_ao_key_alloc(struct sock *sk, + struct tcp_ao_add *cmd) +{ + const char *algo = cmd->alg_name; + unsigned int digest_size; + struct crypto_ahash *tfm; + struct tcp_ao_key *key; + struct tcp_sigpool hp; + int err, pool_id; + size_t size; + + /* Force null-termination of alg_name */ + cmd->alg_name[ARRAY_SIZE(cmd->alg_name) - 1] = '\0'; + + /* RFC5926, 3.1.1.2. KDF_AES_128_CMAC */ + if (!strcmp("cmac(aes128)", algo)) + algo = "cmac(aes)"; + + /* Full TCP header (th->doff << 2) should fit into scratch area, + * see tcp_ao_hash_header(). + */ + pool_id = tcp_sigpool_alloc_ahash(algo, 60); + if (pool_id < 0) + return ERR_PTR(pool_id); + + err = tcp_sigpool_start(pool_id, &hp); + if (err) + goto err_free_pool; + + tfm = crypto_ahash_reqtfm(hp.req); + digest_size = crypto_ahash_digestsize(tfm); + tcp_sigpool_end(&hp); + + size = sizeof(struct tcp_ao_key) + (digest_size << 1); + key = sock_kmalloc(sk, size, GFP_KERNEL); + if (!key) { + err = -ENOMEM; + goto err_free_pool; + } + + key->tcp_sigpool_id = pool_id; + key->digest_size = digest_size; + return key; + +err_free_pool: + tcp_sigpool_release(pool_id); + return ERR_PTR(err); +} + +static int tcp_ao_add_cmd(struct sock *sk, unsigned short int family, + sockptr_t optval, int optlen) +{ + struct tcp_ao_info *ao_info; + union tcp_ao_addr *addr; + struct tcp_ao_key *key; + struct tcp_ao_add cmd; + int ret, l3index = 0; + bool first = false; + + if (optlen < sizeof(cmd)) + return -EINVAL; + + ret = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen); + if (ret) + return ret; + + if (cmd.keylen > TCP_AO_MAXKEYLEN) + return -EINVAL; + + if (cmd.reserved != 0 || cmd.reserved2 != 0) + return -EINVAL; + + if (family == AF_INET) + ret = tcp_ao_verify_ipv4(sk, &cmd, &addr); + else + ret = tcp_ao_verify_ipv6(sk, &cmd, &addr, &family); + if (ret) + return ret; + + if (cmd.keyflags & ~TCP_AO_KEYF_ALL) + return -EINVAL; + + if (cmd.set_current || cmd.set_rnext) { + if (!tcp_ao_can_set_current_rnext(sk)) + return -EINVAL; + } + + if (cmd.ifindex && !(cmd.keyflags & TCP_AO_KEYF_IFINDEX)) + return -EINVAL; + + /* For cmd.tcp_ifindex = 0 the key will apply to the default VRF */ + if (cmd.keyflags & TCP_AO_KEYF_IFINDEX && cmd.ifindex) { + int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), cmd.ifindex); + if (dev && netif_is_l3_master(dev)) + l3index = dev->ifindex; + rcu_read_unlock(); + + if (!dev || !l3index) + return -EINVAL; + + if (!bound_dev_if || bound_dev_if != cmd.ifindex) { + /* tcp_ao_established_key() doesn't expect having + * non peer-matching key on an established TCP-AO + * connection. + */ + if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))) + return -EINVAL; + } + + /* It's still possible to bind after adding keys or even + * re-bind to a different dev (with CAP_NET_RAW). + * So, no reason to return error here, rather try to be + * nice and warn the user. + */ + if (bound_dev_if && bound_dev_if != cmd.ifindex) + net_warn_ratelimited("AO key ifindex %d != sk bound ifindex %d\n", + cmd.ifindex, bound_dev_if); + } + + /* Don't allow keys for peers that have a matching TCP-MD5 key */ + if (cmd.keyflags & TCP_AO_KEYF_IFINDEX) { + /* Non-_exact version of tcp_md5_do_lookup() will + * as well match keys that aren't bound to a specific VRF + * (that will make them match AO key with + * sysctl_tcp_l3dev_accept = 1 + */ + if (tcp_md5_do_lookup(sk, l3index, addr, family)) + return -EKEYREJECTED; + } else { + if (tcp_md5_do_lookup_any_l3index(sk, addr, family)) + return -EKEYREJECTED; + } + + ao_info = setsockopt_ao_info(sk); + if (IS_ERR(ao_info)) + return PTR_ERR(ao_info); + + if (!ao_info) { + ao_info = tcp_ao_alloc_info(GFP_KERNEL); + if (!ao_info) + return -ENOMEM; + first = true; + } else { + /* Check that neither RecvID nor SendID match any + * existing key for the peer, RFC5925 3.1: + * > The IDs of MKTs MUST NOT overlap where their + * > TCP connection identifiers overlap. + */ + if (__tcp_ao_do_lookup(sk, l3index, addr, family, cmd.prefix, -1, cmd.rcvid)) + return -EEXIST; + if (__tcp_ao_do_lookup(sk, l3index, addr, family, + cmd.prefix, cmd.sndid, -1)) + return -EEXIST; + } + + key = tcp_ao_key_alloc(sk, &cmd); + if (IS_ERR(key)) { + ret = PTR_ERR(key); + goto err_free_ao; + } + + INIT_HLIST_NODE(&key->node); + memcpy(&key->addr, addr, (family == AF_INET) ? sizeof(struct in_addr) : + sizeof(struct in6_addr)); + key->prefixlen = cmd.prefix; + key->family = family; + key->keyflags = cmd.keyflags; + key->sndid = cmd.sndid; + key->rcvid = cmd.rcvid; + key->l3index = l3index; + atomic64_set(&key->pkt_good, 0); + atomic64_set(&key->pkt_bad, 0); + + ret = tcp_ao_parse_crypto(&cmd, key); + if (ret < 0) + goto err_free_sock; + + if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))) { + tcp_ao_cache_traffic_keys(sk, ao_info, key); + if (first) { + ao_info->current_key = key; + ao_info->rnext_key = key; + } + } + + tcp_ao_link_mkt(ao_info, key); + if (first) { + if (!static_branch_inc(&tcp_ao_needed.key)) { + ret = -EUSERS; + goto err_free_sock; + } + sk_gso_disable(sk); + rcu_assign_pointer(tcp_sk(sk)->ao_info, ao_info); + } + + if (cmd.set_current) + WRITE_ONCE(ao_info->current_key, key); + if (cmd.set_rnext) + WRITE_ONCE(ao_info->rnext_key, key); + return 0; + +err_free_sock: + atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc); + tcp_sigpool_release(key->tcp_sigpool_id); + kfree_sensitive(key); +err_free_ao: + if (first) + kfree(ao_info); + return ret; +} + +static int tcp_ao_delete_key(struct sock *sk, struct tcp_ao_info *ao_info, + bool del_async, struct tcp_ao_key *key, + struct tcp_ao_key *new_current, + struct tcp_ao_key *new_rnext) +{ + int err; + + hlist_del_rcu(&key->node); + + /* Support for async delete on listening sockets: as they don't + * need current_key/rnext_key maintaining, we don't need to check + * them and we can just free all resources in RCU fashion. + */ + if (del_async) { + atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc); + call_rcu(&key->rcu, tcp_ao_key_free_rcu); + return 0; + } + + /* At this moment another CPU could have looked this key up + * while it was unlinked from the list. Wait for RCU grace period, + * after which the key is off-list and can't be looked up again; + * the rx path [just before RCU came] might have used it and set it + * as current_key (very unlikely). + * Free the key with next RCU grace period (in case it was + * current_key before tcp_ao_current_rnext() might have + * changed it in forced-delete). + */ + synchronize_rcu(); + if (new_current) + WRITE_ONCE(ao_info->current_key, new_current); + if (new_rnext) + WRITE_ONCE(ao_info->rnext_key, new_rnext); + + if (unlikely(READ_ONCE(ao_info->current_key) == key || + READ_ONCE(ao_info->rnext_key) == key)) { + err = -EBUSY; + goto add_key; + } + + atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc); + call_rcu(&key->rcu, tcp_ao_key_free_rcu); + + return 0; +add_key: + hlist_add_head_rcu(&key->node, &ao_info->head); + return err; +} + +#define TCP_AO_DEL_KEYF_ALL (TCP_AO_KEYF_IFINDEX) +static int tcp_ao_del_cmd(struct sock *sk, unsigned short int family, + sockptr_t optval, int optlen) +{ + struct tcp_ao_key *key, *new_current = NULL, *new_rnext = NULL; + int err, addr_len, l3index = 0; + struct tcp_ao_info *ao_info; + union tcp_ao_addr *addr; + struct tcp_ao_del cmd; + __u8 prefix; + u16 port; + + if (optlen < sizeof(cmd)) + return -EINVAL; + + err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen); + if (err) + return err; + + if (cmd.reserved != 0 || cmd.reserved2 != 0) + return -EINVAL; + + if (cmd.set_current || cmd.set_rnext) { + if (!tcp_ao_can_set_current_rnext(sk)) + return -EINVAL; + } + + if (cmd.keyflags & ~TCP_AO_DEL_KEYF_ALL) + return -EINVAL; + + /* No sanity check for TCP_AO_KEYF_IFINDEX as if a VRF + * was destroyed, there still should be a way to delete keys, + * that were bound to that l3intf. So, fail late at lookup stage + * if there is no key for that ifindex. + */ + if (cmd.ifindex && !(cmd.keyflags & TCP_AO_KEYF_IFINDEX)) + return -EINVAL; + + ao_info = setsockopt_ao_info(sk); + if (IS_ERR(ao_info)) + return PTR_ERR(ao_info); + if (!ao_info) + return -ENOENT; + + /* For sockets in TCP_CLOSED it's possible set keys that aren't + * matching the future peer (address/VRF/etc), + * tcp_ao_connect_init() will choose a correct matching MKT + * if there's any. + */ + if (cmd.set_current) { + new_current = tcp_ao_established_key(sk, ao_info, cmd.current_key, -1); + if (!new_current) + return -ENOENT; + } + if (cmd.set_rnext) { + new_rnext = tcp_ao_established_key(sk, ao_info, -1, cmd.rnext); + if (!new_rnext) + return -ENOENT; + } + if (cmd.del_async && sk->sk_state != TCP_LISTEN) + return -EINVAL; + + if (family == AF_INET) { + struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.addr; + + addr = (union tcp_ao_addr *)&sin->sin_addr; + addr_len = sizeof(struct in_addr); + port = ntohs(sin->sin_port); + } else { + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.addr; + struct in6_addr *addr6 = &sin6->sin6_addr; + + if (ipv6_addr_v4mapped(addr6)) { + addr = (union tcp_ao_addr *)&addr6->s6_addr32[3]; + addr_len = sizeof(struct in_addr); + family = AF_INET; + } else { + addr = (union tcp_ao_addr *)addr6; + addr_len = sizeof(struct in6_addr); + } + port = ntohs(sin6->sin6_port); + } + prefix = cmd.prefix; + + /* Currently matching is not performed on port (or port ranges) */ + if (port != 0) + return -EINVAL; + + /* We could choose random present key here for current/rnext + * but that's less predictable. Let's be strict and don't + * allow removing a key that's in use. RFC5925 doesn't + * specify how-to coordinate key removal, but says: + * "It is presumed that an MKT affecting a particular + * connection cannot be destroyed during an active connection" + */ + hlist_for_each_entry_rcu(key, &ao_info->head, node, + lockdep_sock_is_held(sk)) { + if (cmd.sndid != key->sndid || + cmd.rcvid != key->rcvid) + continue; + + if (family != key->family || + prefix != key->prefixlen || + memcmp(addr, &key->addr, addr_len)) + continue; + + if ((cmd.keyflags & TCP_AO_KEYF_IFINDEX) != + (key->keyflags & TCP_AO_KEYF_IFINDEX)) + continue; + + if (key->l3index != l3index) + continue; + + if (key == new_current || key == new_rnext) + continue; + + return tcp_ao_delete_key(sk, ao_info, cmd.del_async, key, + new_current, new_rnext); + } + return -ENOENT; +} + +/* cmd.ao_required makes a socket TCP-AO only. + * Don't allow any md5 keys for any l3intf on the socket together with it. + * Restricting it early in setsockopt() removes a check for + * ao_info->ao_required on inbound tcp segment fast-path. + */ +static int tcp_ao_required_verify(struct sock *sk) +{ +#ifdef CONFIG_TCP_MD5SIG + const struct tcp_md5sig_info *md5sig; + + if (!static_branch_unlikely(&tcp_md5_needed.key)) + return 0; + + md5sig = rcu_dereference_check(tcp_sk(sk)->md5sig_info, + lockdep_sock_is_held(sk)); + if (!md5sig) + return 0; + + if (rcu_dereference_check(hlist_first_rcu(&md5sig->head), + lockdep_sock_is_held(sk))) + return 1; +#endif + return 0; +} + +static int tcp_ao_info_cmd(struct sock *sk, unsigned short int family, + sockptr_t optval, int optlen) +{ + struct tcp_ao_key *new_current = NULL, *new_rnext = NULL; + struct tcp_ao_info *ao_info; + struct tcp_ao_info_opt cmd; + bool first = false; + int err; + + if (optlen < sizeof(cmd)) + return -EINVAL; + + err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen); + if (err) + return err; + + if (cmd.set_current || cmd.set_rnext) { + if (!tcp_ao_can_set_current_rnext(sk)) + return -EINVAL; + } + + if (cmd.reserved != 0 || cmd.reserved2 != 0) + return -EINVAL; + + ao_info = setsockopt_ao_info(sk); + if (IS_ERR(ao_info)) + return PTR_ERR(ao_info); + if (!ao_info) { + if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))) + return -EINVAL; + ao_info = tcp_ao_alloc_info(GFP_KERNEL); + if (!ao_info) + return -ENOMEM; + first = true; + } + + if (cmd.ao_required && tcp_ao_required_verify(sk)) { + err = -EKEYREJECTED; + goto out; + } + + /* For sockets in TCP_CLOSED it's possible set keys that aren't + * matching the future peer (address/port/VRF/etc), + * tcp_ao_connect_init() will choose a correct matching MKT + * if there's any. + */ + if (cmd.set_current) { + new_current = tcp_ao_established_key(sk, ao_info, cmd.current_key, -1); + if (!new_current) { + err = -ENOENT; + goto out; + } + } + if (cmd.set_rnext) { + new_rnext = tcp_ao_established_key(sk, ao_info, -1, cmd.rnext); + if (!new_rnext) { + err = -ENOENT; + goto out; + } + } + if (cmd.set_counters) { + atomic64_set(&ao_info->counters.pkt_good, cmd.pkt_good); + atomic64_set(&ao_info->counters.pkt_bad, cmd.pkt_bad); + atomic64_set(&ao_info->counters.key_not_found, cmd.pkt_key_not_found); + atomic64_set(&ao_info->counters.ao_required, cmd.pkt_ao_required); + atomic64_set(&ao_info->counters.dropped_icmp, cmd.pkt_dropped_icmp); + } + + ao_info->ao_required = cmd.ao_required; + ao_info->accept_icmps = cmd.accept_icmps; + if (new_current) + WRITE_ONCE(ao_info->current_key, new_current); + if (new_rnext) + WRITE_ONCE(ao_info->rnext_key, new_rnext); + if (first) { + if (!static_branch_inc(&tcp_ao_needed.key)) { + err = -EUSERS; + goto out; + } + sk_gso_disable(sk); + rcu_assign_pointer(tcp_sk(sk)->ao_info, ao_info); + } + return 0; +out: + if (first) + kfree(ao_info); + return err; +} + +int tcp_parse_ao(struct sock *sk, int cmd, unsigned short int family, + sockptr_t optval, int optlen) +{ + if (WARN_ON_ONCE(family != AF_INET && family != AF_INET6)) + return -EAFNOSUPPORT; + + switch (cmd) { + case TCP_AO_ADD_KEY: + return tcp_ao_add_cmd(sk, family, optval, optlen); + case TCP_AO_DEL_KEY: + return tcp_ao_del_cmd(sk, family, optval, optlen); + case TCP_AO_INFO: + return tcp_ao_info_cmd(sk, family, optval, optlen); + default: + WARN_ON_ONCE(1); + return -EINVAL; + } +} + +int tcp_v4_parse_ao(struct sock *sk, int cmd, sockptr_t optval, int optlen) +{ + return tcp_parse_ao(sk, cmd, AF_INET, optval, optlen); +} + +/* tcp_ao_copy_mkts_to_user(ao_info, optval, optlen) + * + * @ao_info: struct tcp_ao_info on the socket that + * socket getsockopt(TCP_AO_GET_KEYS) is executed on + * @optval: pointer to array of tcp_ao_getsockopt structures in user space. + * Must be != NULL. + * @optlen: pointer to size of tcp_ao_getsockopt structure. + * Must be != NULL. + * + * Return value: 0 on success, a negative error number otherwise. + * + * optval points to an array of tcp_ao_getsockopt structures in user space. + * optval[0] is used as both input and output to getsockopt. It determines + * which keys are returned by the kernel. + * optval[0].nkeys is the size of the array in user space. On return it contains + * the number of keys matching the search criteria. + * If tcp_ao_getsockopt::get_all is set, then all keys in the socket are + * returned, otherwise only keys matching <addr, prefix, sndid, rcvid> + * in optval[0] are returned. + * optlen is also used as both input and output. The user provides the size + * of struct tcp_ao_getsockopt in user space, and the kernel returns the size + * of the structure in kernel space. + * The size of struct tcp_ao_getsockopt may differ between user and kernel. + * There are three cases to consider: + * * If usize == ksize, then keys are copied verbatim. + * * If usize < ksize, then the userspace has passed an old struct to a + * newer kernel. The rest of the trailing bytes in optval[0] + * (ksize - usize) are interpreted as 0 by the kernel. + * * If usize > ksize, then the userspace has passed a new struct to an + * older kernel. The trailing bytes unknown to the kernel (usize - ksize) + * are checked to ensure they are zeroed, otherwise -E2BIG is returned. + * On return the kernel fills in min(usize, ksize) in each entry of the array. + * The layout of the fields in the user and kernel structures is expected to + * be the same (including in the 32bit vs 64bit case). + */ +static int tcp_ao_copy_mkts_to_user(const struct sock *sk, + struct tcp_ao_info *ao_info, + sockptr_t optval, sockptr_t optlen) +{ + struct tcp_ao_getsockopt opt_in, opt_out; + struct tcp_ao_key *key, *current_key; + bool do_address_matching = true; + union tcp_ao_addr *addr = NULL; + int err, l3index, user_len; + unsigned int max_keys; /* maximum number of keys to copy to user */ + size_t out_offset = 0; + size_t bytes_to_write; /* number of bytes to write to user level */ + u32 matched_keys; /* keys from ao_info matched so far */ + int optlen_out; + __be16 port = 0; + + if (copy_from_sockptr(&user_len, optlen, sizeof(int))) + return -EFAULT; + + if (user_len <= 0) + return -EINVAL; + + memset(&opt_in, 0, sizeof(struct tcp_ao_getsockopt)); + err = copy_struct_from_sockptr(&opt_in, sizeof(opt_in), + optval, user_len); + if (err < 0) + return err; + + if (opt_in.pkt_good || opt_in.pkt_bad) + return -EINVAL; + if (opt_in.keyflags & ~TCP_AO_GET_KEYF_VALID) + return -EINVAL; + if (opt_in.ifindex && !(opt_in.keyflags & TCP_AO_KEYF_IFINDEX)) + return -EINVAL; + + if (opt_in.reserved != 0) + return -EINVAL; + + max_keys = opt_in.nkeys; + l3index = (opt_in.keyflags & TCP_AO_KEYF_IFINDEX) ? opt_in.ifindex : -1; + + if (opt_in.get_all || opt_in.is_current || opt_in.is_rnext) { + if (opt_in.get_all && (opt_in.is_current || opt_in.is_rnext)) + return -EINVAL; + do_address_matching = false; + } + + switch (opt_in.addr.ss_family) { + case AF_INET: { + struct sockaddr_in *sin; + __be32 mask; + + sin = (struct sockaddr_in *)&opt_in.addr; + port = sin->sin_port; + addr = (union tcp_ao_addr *)&sin->sin_addr; + + if (opt_in.prefix > 32) + return -EINVAL; + + if (ntohl(sin->sin_addr.s_addr) == INADDR_ANY && + opt_in.prefix != 0) + return -EINVAL; + + mask = inet_make_mask(opt_in.prefix); + if (sin->sin_addr.s_addr & ~mask) + return -EINVAL; + + break; + } + case AF_INET6: { + struct sockaddr_in6 *sin6; + struct in6_addr *addr6; + + sin6 = (struct sockaddr_in6 *)&opt_in.addr; + addr = (union tcp_ao_addr *)&sin6->sin6_addr; + addr6 = &sin6->sin6_addr; + port = sin6->sin6_port; + + /* We don't have to change family and @addr here if + * ipv6_addr_v4mapped() like in key adding: + * tcp_ao_key_cmp() does it. Do the sanity checks though. + */ + if (opt_in.prefix != 0) { + if (ipv6_addr_v4mapped(addr6)) { + __be32 mask, addr4 = addr6->s6_addr32[3]; + + if (opt_in.prefix > 32 || + ntohl(addr4) == INADDR_ANY) + return -EINVAL; + mask = inet_make_mask(opt_in.prefix); + if (addr4 & ~mask) + return -EINVAL; + } else { + struct in6_addr pfx; + + if (ipv6_addr_any(addr6) || + opt_in.prefix > 128) + return -EINVAL; + + ipv6_addr_prefix(&pfx, addr6, opt_in.prefix); + if (ipv6_addr_cmp(&pfx, addr6)) + return -EINVAL; + } + } else if (!ipv6_addr_any(addr6)) { + return -EINVAL; + } + break; + } + case 0: + if (!do_address_matching) + break; + fallthrough; + default: + return -EAFNOSUPPORT; + } + + if (!do_address_matching) { + /* We could just ignore those, but let's do stricter checks */ + if (addr || port) + return -EINVAL; + if (opt_in.prefix || opt_in.sndid || opt_in.rcvid) + return -EINVAL; + } + + bytes_to_write = min_t(int, user_len, sizeof(struct tcp_ao_getsockopt)); + matched_keys = 0; + /* May change in RX, while we're dumping, pre-fetch it */ + current_key = READ_ONCE(ao_info->current_key); + + hlist_for_each_entry_rcu(key, &ao_info->head, node, + lockdep_sock_is_held(sk)) { + if (opt_in.get_all) + goto match; + + if (opt_in.is_current || opt_in.is_rnext) { + if (opt_in.is_current && key == current_key) + goto match; + if (opt_in.is_rnext && key == ao_info->rnext_key) + goto match; + continue; + } + + if (tcp_ao_key_cmp(key, l3index, addr, opt_in.prefix, + opt_in.addr.ss_family, + opt_in.sndid, opt_in.rcvid) != 0) + continue; +match: + matched_keys++; + if (matched_keys > max_keys) + continue; + + memset(&opt_out, 0, sizeof(struct tcp_ao_getsockopt)); + + if (key->family == AF_INET) { + struct sockaddr_in *sin_out = (struct sockaddr_in *)&opt_out.addr; + + sin_out->sin_family = key->family; + sin_out->sin_port = 0; + memcpy(&sin_out->sin_addr, &key->addr, sizeof(struct in_addr)); + } else { + struct sockaddr_in6 *sin6_out = (struct sockaddr_in6 *)&opt_out.addr; + + sin6_out->sin6_family = key->family; + sin6_out->sin6_port = 0; + memcpy(&sin6_out->sin6_addr, &key->addr, sizeof(struct in6_addr)); + } + opt_out.sndid = key->sndid; + opt_out.rcvid = key->rcvid; + opt_out.prefix = key->prefixlen; + opt_out.keyflags = key->keyflags; + opt_out.is_current = (key == current_key); + opt_out.is_rnext = (key == ao_info->rnext_key); + opt_out.nkeys = 0; + opt_out.maclen = key->maclen; + opt_out.keylen = key->keylen; + opt_out.ifindex = key->l3index; + opt_out.pkt_good = atomic64_read(&key->pkt_good); + opt_out.pkt_bad = atomic64_read(&key->pkt_bad); + memcpy(&opt_out.key, key->key, key->keylen); + tcp_sigpool_algo(key->tcp_sigpool_id, opt_out.alg_name, 64); + + /* Copy key to user */ + if (copy_to_sockptr_offset(optval, out_offset, + &opt_out, bytes_to_write)) + return -EFAULT; + out_offset += user_len; + } + + optlen_out = (int)sizeof(struct tcp_ao_getsockopt); + if (copy_to_sockptr(optlen, &optlen_out, sizeof(int))) + return -EFAULT; + + out_offset = offsetof(struct tcp_ao_getsockopt, nkeys); + if (copy_to_sockptr_offset(optval, out_offset, + &matched_keys, sizeof(u32))) + return -EFAULT; + + return 0; +} + +int tcp_ao_get_mkts(struct sock *sk, sockptr_t optval, sockptr_t optlen) +{ + struct tcp_ao_info *ao_info; + + ao_info = setsockopt_ao_info(sk); + if (IS_ERR(ao_info)) + return PTR_ERR(ao_info); + if (!ao_info) + return -ENOENT; + + return tcp_ao_copy_mkts_to_user(sk, ao_info, optval, optlen); +} + +int tcp_ao_get_sock_info(struct sock *sk, sockptr_t optval, sockptr_t optlen) +{ + struct tcp_ao_info_opt out, in = {}; + struct tcp_ao_key *current_key; + struct tcp_ao_info *ao; + int err, len; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + if (len <= 0) + return -EINVAL; + + /* Copying this "in" only to check ::reserved, ::reserved2, + * that may be needed to extend (struct tcp_ao_info_opt) and + * what getsockopt() provides in future. + */ + err = copy_struct_from_sockptr(&in, sizeof(in), optval, len); + if (err) + return err; + + if (in.reserved != 0 || in.reserved2 != 0) + return -EINVAL; + + ao = setsockopt_ao_info(sk); + if (IS_ERR(ao)) + return PTR_ERR(ao); + if (!ao) + return -ENOENT; + + memset(&out, 0, sizeof(out)); + out.ao_required = ao->ao_required; + out.accept_icmps = ao->accept_icmps; + out.pkt_good = atomic64_read(&ao->counters.pkt_good); + out.pkt_bad = atomic64_read(&ao->counters.pkt_bad); + out.pkt_key_not_found = atomic64_read(&ao->counters.key_not_found); + out.pkt_ao_required = atomic64_read(&ao->counters.ao_required); + out.pkt_dropped_icmp = atomic64_read(&ao->counters.dropped_icmp); + + current_key = READ_ONCE(ao->current_key); + if (current_key) { + out.set_current = 1; + out.current_key = current_key->sndid; + } + if (ao->rnext_key) { + out.set_rnext = 1; + out.rnext = ao->rnext_key->rcvid; + } + + if (copy_to_sockptr(optval, &out, min_t(int, len, sizeof(out)))) + return -EFAULT; + + return 0; +} + +int tcp_ao_set_repair(struct sock *sk, sockptr_t optval, unsigned int optlen) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_ao_repair cmd; + struct tcp_ao_key *key; + struct tcp_ao_info *ao; + int err; + + if (optlen < sizeof(cmd)) + return -EINVAL; + + err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen); + if (err) + return err; + + if (!tp->repair) + return -EPERM; + + ao = setsockopt_ao_info(sk); + if (IS_ERR(ao)) + return PTR_ERR(ao); + if (!ao) + return -ENOENT; + + WRITE_ONCE(ao->lisn, cmd.snt_isn); + WRITE_ONCE(ao->risn, cmd.rcv_isn); + WRITE_ONCE(ao->snd_sne, cmd.snd_sne); + WRITE_ONCE(ao->rcv_sne, cmd.rcv_sne); + + hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk)) + tcp_ao_cache_traffic_keys(sk, ao, key); + + return 0; +} + +int tcp_ao_get_repair(struct sock *sk, sockptr_t optval, sockptr_t optlen) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_ao_repair opt; + struct tcp_ao_info *ao; + int len; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + if (len <= 0) + return -EINVAL; + + if (!tp->repair) + return -EPERM; + + rcu_read_lock(); + ao = getsockopt_ao_info(sk); + if (IS_ERR_OR_NULL(ao)) { + rcu_read_unlock(); + return ao ? PTR_ERR(ao) : -ENOENT; + } + + opt.snt_isn = ao->lisn; + opt.rcv_isn = ao->risn; + opt.snd_sne = READ_ONCE(ao->snd_sne); + opt.rcv_sne = READ_ONCE(ao->rcv_sne); + rcu_read_unlock(); + + if (copy_to_sockptr(optval, &opt, min_t(int, len, sizeof(opt)))) + return -EFAULT; + return 0; +} diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index ec5550089b4d..760941e55153 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -258,7 +258,7 @@ static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) u64 rate = bw; rate = bbr_rate_bytes_per_sec(sk, rate, gain); - rate = min_t(u64, rate, sk->sk_max_pacing_rate); + rate = min_t(u64, rate, READ_ONCE(sk->sk_max_pacing_rate)); return rate; } @@ -276,9 +276,10 @@ static void bbr_init_pacing_rate_from_rtt(struct sock *sk) } else { /* no RTT sample yet */ rtt_us = USEC_PER_MSEC; /* use nominal default RTT */ } - bw = (u64)tp->snd_cwnd * BW_UNIT; + bw = (u64)tcp_snd_cwnd(tp) * BW_UNIT; do_div(bw, rtt_us); - sk->sk_pacing_rate = bbr_bw_to_pacing_rate(sk, bw, bbr_high_gain); + WRITE_ONCE(sk->sk_pacing_rate, + bbr_bw_to_pacing_rate(sk, bw, bbr_high_gain)); } /* Pace using current bw estimate and a gain factor. */ @@ -290,14 +291,14 @@ static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) if (unlikely(!bbr->has_seen_rtt && tp->srtt_us)) bbr_init_pacing_rate_from_rtt(sk); - if (bbr_full_bw_reached(sk) || rate > sk->sk_pacing_rate) - sk->sk_pacing_rate = rate; + if (bbr_full_bw_reached(sk) || rate > READ_ONCE(sk->sk_pacing_rate)) + WRITE_ONCE(sk->sk_pacing_rate, rate); } /* override sysctl_tcp_min_tso_segs */ -static u32 bbr_min_tso_segs(struct sock *sk) +__bpf_kfunc static u32 bbr_min_tso_segs(struct sock *sk) { - return sk->sk_pacing_rate < (bbr_min_tso_rate >> 3) ? 1 : 2; + return READ_ONCE(sk->sk_pacing_rate) < (bbr_min_tso_rate >> 3) ? 1 : 2; } static u32 bbr_tso_segs_goal(struct sock *sk) @@ -309,8 +310,8 @@ static u32 bbr_tso_segs_goal(struct sock *sk) * driver provided sk_gso_max_size. */ bytes = min_t(unsigned long, - sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift), - GSO_MAX_SIZE - 1 - MAX_TCP_HEADER); + READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift), + GSO_LEGACY_MAX_SIZE - 1 - MAX_TCP_HEADER); segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); return min(segs, 0x7FU); @@ -323,12 +324,12 @@ static void bbr_save_cwnd(struct sock *sk) struct bbr *bbr = inet_csk_ca(sk); if (bbr->prev_ca_state < TCP_CA_Recovery && bbr->mode != BBR_PROBE_RTT) - bbr->prior_cwnd = tp->snd_cwnd; /* this cwnd is good enough */ + bbr->prior_cwnd = tcp_snd_cwnd(tp); /* this cwnd is good enough */ else /* loss recovery or BBR_PROBE_RTT have temporarily cut cwnd */ - bbr->prior_cwnd = max(bbr->prior_cwnd, tp->snd_cwnd); + bbr->prior_cwnd = max(bbr->prior_cwnd, tcp_snd_cwnd(tp)); } -static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) +__bpf_kfunc static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); @@ -482,7 +483,7 @@ static bool bbr_set_cwnd_to_recover_or_restore( struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); u8 prev_state = bbr->prev_ca_state, state = inet_csk(sk)->icsk_ca_state; - u32 cwnd = tp->snd_cwnd; + u32 cwnd = tcp_snd_cwnd(tp); /* An ACK for P pkts should release at most 2*P packets. We do this * in two steps. First, here we deduct the number of lost packets. @@ -520,7 +521,7 @@ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u32 cwnd = tp->snd_cwnd, target_cwnd = 0; + u32 cwnd = tcp_snd_cwnd(tp), target_cwnd = 0; if (!acked) goto done; /* no packet fully ACKed; just apply caps */ @@ -544,9 +545,9 @@ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, cwnd = max(cwnd, bbr_cwnd_min_target); done: - tp->snd_cwnd = min(cwnd, tp->snd_cwnd_clamp); /* apply global cap */ + tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* apply global cap */ if (bbr->mode == BBR_PROBE_RTT) /* drain queue, refresh min_rtt */ - tp->snd_cwnd = min(tp->snd_cwnd, bbr_cwnd_min_target); + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), bbr_cwnd_min_target)); } /* End cycle phase if it's time and/or we hit the phase's in-flight target. */ @@ -618,7 +619,7 @@ static void bbr_reset_probe_bw_mode(struct sock *sk) struct bbr *bbr = inet_csk_ca(sk); bbr->mode = BBR_PROBE_BW; - bbr->cycle_idx = CYCLE_LEN - 1 - prandom_u32_max(bbr_cycle_rand); + bbr->cycle_idx = CYCLE_LEN - 1 - get_random_u32_below(bbr_cycle_rand); bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ } @@ -856,7 +857,7 @@ static void bbr_update_ack_aggregation(struct sock *sk, bbr->ack_epoch_acked = min_t(u32, 0xFFFFF, bbr->ack_epoch_acked + rs->acked_sacked); extra_acked = bbr->ack_epoch_acked - expected_acked; - extra_acked = min(extra_acked, tp->snd_cwnd); + extra_acked = min(extra_acked, tcp_snd_cwnd(tp)); if (extra_acked > bbr->extra_acked[bbr->extra_acked_win_idx]) bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; } @@ -914,7 +915,7 @@ static void bbr_check_probe_rtt_done(struct sock *sk) return; bbr->min_rtt_stamp = tcp_jiffies32; /* wait a while until PROBE_RTT */ - tp->snd_cwnd = max(tp->snd_cwnd, bbr->prior_cwnd); + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); bbr_reset_mode(sk); } @@ -1023,7 +1024,7 @@ static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) bbr_update_gains(sk); } -static void bbr_main(struct sock *sk, const struct rate_sample *rs) +__bpf_kfunc static void bbr_main(struct sock *sk, u32 ack, int flag, const struct rate_sample *rs) { struct bbr *bbr = inet_csk_ca(sk); u32 bw; @@ -1035,7 +1036,7 @@ static void bbr_main(struct sock *sk, const struct rate_sample *rs) bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain); } -static void bbr_init(struct sock *sk) +__bpf_kfunc static void bbr_init(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); @@ -1077,7 +1078,7 @@ static void bbr_init(struct sock *sk) cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); } -static u32 bbr_sndbuf_expand(struct sock *sk) +__bpf_kfunc static u32 bbr_sndbuf_expand(struct sock *sk) { /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ return 3; @@ -1086,18 +1087,18 @@ static u32 bbr_sndbuf_expand(struct sock *sk) /* In theory BBR does not need to undo the cwnd since it does not * always reduce cwnd on losses (see bbr_main()). Keep it for now. */ -static u32 bbr_undo_cwnd(struct sock *sk) +__bpf_kfunc static u32 bbr_undo_cwnd(struct sock *sk) { struct bbr *bbr = inet_csk_ca(sk); bbr->full_bw = 0; /* spurious slow-down; reset full pipe detection */ bbr->full_bw_cnt = 0; bbr_reset_lt_bw_sampling(sk); - return tcp_sk(sk)->snd_cwnd; + return tcp_snd_cwnd(tcp_sk(sk)); } /* Entering loss recovery, so save cwnd for when we exit or undo recovery. */ -static u32 bbr_ssthresh(struct sock *sk) +__bpf_kfunc static u32 bbr_ssthresh(struct sock *sk) { bbr_save_cwnd(sk); return tcp_sk(sk)->snd_ssthresh; @@ -1125,7 +1126,7 @@ static size_t bbr_get_info(struct sock *sk, u32 ext, int *attr, return 0; } -static void bbr_set_state(struct sock *sk, u8 new_state) +__bpf_kfunc static void bbr_set_state(struct sock *sk, u8 new_state) { struct bbr *bbr = inet_csk_ca(sk); @@ -1154,38 +1155,36 @@ static struct tcp_congestion_ops tcp_bbr_cong_ops __read_mostly = { .set_state = bbr_set_state, }; -BTF_SET_START(tcp_bbr_kfunc_ids) -#ifdef CONFIG_X86 -#ifdef CONFIG_DYNAMIC_FTRACE -BTF_ID(func, bbr_init) -BTF_ID(func, bbr_main) -BTF_ID(func, bbr_sndbuf_expand) -BTF_ID(func, bbr_undo_cwnd) -BTF_ID(func, bbr_cwnd_event) -BTF_ID(func, bbr_ssthresh) -BTF_ID(func, bbr_min_tso_segs) -BTF_ID(func, bbr_set_state) -#endif -#endif -BTF_SET_END(tcp_bbr_kfunc_ids) - -static DEFINE_KFUNC_BTF_ID_SET(&tcp_bbr_kfunc_ids, tcp_bbr_kfunc_btf_set); +BTF_KFUNCS_START(tcp_bbr_check_kfunc_ids) +BTF_ID_FLAGS(func, bbr_init) +BTF_ID_FLAGS(func, bbr_main) +BTF_ID_FLAGS(func, bbr_sndbuf_expand) +BTF_ID_FLAGS(func, bbr_undo_cwnd) +BTF_ID_FLAGS(func, bbr_cwnd_event) +BTF_ID_FLAGS(func, bbr_ssthresh) +BTF_ID_FLAGS(func, bbr_min_tso_segs) +BTF_ID_FLAGS(func, bbr_set_state) +BTF_KFUNCS_END(tcp_bbr_check_kfunc_ids) + +static const struct btf_kfunc_id_set tcp_bbr_kfunc_set = { + .owner = THIS_MODULE, + .set = &tcp_bbr_check_kfunc_ids, +}; static int __init bbr_register(void) { int ret; BUILD_BUG_ON(sizeof(struct bbr) > ICSK_CA_PRIV_SIZE); - ret = tcp_register_congestion_control(&tcp_bbr_cong_ops); - if (ret) + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &tcp_bbr_kfunc_set); + if (ret < 0) return ret; - register_kfunc_btf_id_set(&bpf_tcp_ca_kfunc_list, &tcp_bbr_kfunc_btf_set); - return 0; + return tcp_register_congestion_control(&tcp_bbr_cong_ops); } static void __exit bbr_unregister(void) { - unregister_kfunc_btf_id_set(&bpf_tcp_ca_kfunc_list, &tcp_bbr_kfunc_btf_set); tcp_unregister_congestion_control(&tcp_bbr_cong_ops); } diff --git a/net/ipv4/tcp_bic.c b/net/ipv4/tcp_bic.c index f5f588b1f6e9..58358bf92e1b 100644 --- a/net/ipv4/tcp_bic.c +++ b/net/ipv4/tcp_bic.c @@ -150,7 +150,7 @@ static void bictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) if (!acked) return; } - bictcp_update(ca, tp->snd_cwnd); + bictcp_update(ca, tcp_snd_cwnd(tp)); tcp_cong_avoid_ai(tp, ca->cnt, acked); } @@ -166,16 +166,16 @@ static u32 bictcp_recalc_ssthresh(struct sock *sk) ca->epoch_start = 0; /* end of epoch */ /* Wmax and fast convergence */ - if (tp->snd_cwnd < ca->last_max_cwnd && fast_convergence) - ca->last_max_cwnd = (tp->snd_cwnd * (BICTCP_BETA_SCALE + beta)) + if (tcp_snd_cwnd(tp) < ca->last_max_cwnd && fast_convergence) + ca->last_max_cwnd = (tcp_snd_cwnd(tp) * (BICTCP_BETA_SCALE + beta)) / (2 * BICTCP_BETA_SCALE); else - ca->last_max_cwnd = tp->snd_cwnd; + ca->last_max_cwnd = tcp_snd_cwnd(tp); - if (tp->snd_cwnd <= low_window) - return max(tp->snd_cwnd >> 1U, 2U); + if (tcp_snd_cwnd(tp) <= low_window) + return max(tcp_snd_cwnd(tp) >> 1U, 2U); else - return max((tp->snd_cwnd * beta) / BICTCP_BETA_SCALE, 2U); + return max((tcp_snd_cwnd(tp) * beta) / BICTCP_BETA_SCALE, 2U); } static void bictcp_state(struct sock *sk, u8 new_state) diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 9b9b02052fd3..a268e1595b22 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -6,12 +6,31 @@ #include <linux/bpf.h> #include <linux/init.h> #include <linux/wait.h> +#include <linux/util_macros.h> #include <net/inet_common.h> #include <net/tls.h> +void tcp_eat_skb(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tcp; + int copied; + + if (!skb || !skb->len || !sk_is_tcp(sk)) + return; + + if (skb_bpf_strparser(skb)) + return; + + tcp = tcp_sk(sk); + copied = tcp->copied_seq + skb->len; + WRITE_ONCE(tcp->copied_seq, copied); + tcp_rcv_space_adjust(sk); + __tcp_cleanup_rbuf(sk, skb->len); +} + static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, - struct sk_msg *msg, u32 apply_bytes, int flags) + struct sk_msg *msg, u32 apply_bytes) { bool apply = apply_bytes; struct scatterlist *sge; @@ -30,13 +49,14 @@ static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, sge = sk_msg_elem(msg, i); size = (apply && apply_bytes < sge->length) ? apply_bytes : sge->length; - if (!sk_wmem_schedule(sk, size)) { + if (!__sk_rmem_schedule(sk, size, false)) { if (!copied) ret = -ENOMEM; break; } sk_mem_charge(sk, size); + atomic_add(size, &sk->sk_rmem_alloc); sk_msg_xfer(tmp, msg, i, size); copied += size; if (sge->length) @@ -45,14 +65,18 @@ static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, tmp->sg.end = i; if (apply) { apply_bytes -= size; - if (!apply_bytes) + if (!apply_bytes) { + if (sge->length) + sk_msg_iter_var_prev(i); break; + } } } while (i != msg->sg.end); if (!ret) { msg->sg.start = i; - sk_psock_queue_msg(psock, tmp); + if (!sk_psock_queue_msg(psock, tmp)) + atomic_sub(copied, &sk->sk_rmem_alloc); sk_psock_data_ready(sk, psock); } else { sk_msg_free(sk, tmp); @@ -66,6 +90,7 @@ static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, int flags, bool uncharge) { + struct msghdr msghdr = {}; bool apply = apply_bytes; struct scatterlist *sge; struct page *page; @@ -73,6 +98,7 @@ static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, u32 off; while (1) { + struct bio_vec bvec; bool has_tx_ulp; sge = sk_msg_elem(msg, msg->sg.start); @@ -83,17 +109,20 @@ static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, tcp_rate_check_app_limited(sk); retry: + msghdr.msg_flags = flags | MSG_SPLICE_PAGES; has_tx_ulp = tls_sw_has_ctx_tx(sk); - if (has_tx_ulp) { - flags |= MSG_SENDPAGE_NOPOLICY; - ret = kernel_sendpage_locked(sk, - page, off, size, flags); - } else { - ret = do_tcp_sendpages(sk, page, off, size, flags); - } + if (has_tx_ulp) + msghdr.msg_flags |= MSG_SENDPAGE_NOPOLICY; + if (size < sge->length && msg->sg.start != msg->sg.end) + msghdr.msg_flags |= MSG_MORE; + + bvec_set_page(&bvec, page, size, off); + iov_iter_bvec(&msghdr.msg_iter, ITER_SOURCE, &bvec, 1, size); + ret = tcp_sendmsg_locked(sk, &msghdr, size); if (ret <= 0) return ret; + if (apply) apply_bytes -= ret; msg->sg.size -= ret; @@ -131,18 +160,16 @@ static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg, return ret; } -int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, - u32 bytes, int flags) +int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress, + struct sk_msg *msg, u32 bytes, int flags) { - bool ingress = sk_msg_to_ingress(msg); struct sk_psock *psock = sk_psock_get(sk); int ret; - if (unlikely(!psock)) { - sk_msg_free(sk, msg); - return 0; - } - ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) : + if (unlikely(!psock)) + return -EPIPE; + + ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes) : tcp_bpf_push_locked(sk, msg, bytes, flags, false); sk_psock_put(sk, psock); return ret; @@ -166,32 +193,91 @@ static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock, sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); ret = sk_wait_event(sk, &timeo, !list_empty(&psock->ingress_msg) || - !skb_queue_empty(&sk->sk_receive_queue), &wait); + !skb_queue_empty_lockless(&sk->sk_receive_queue), &wait); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); remove_wait_queue(sk_sleep(sk), &wait); return ret; } +static bool is_next_msg_fin(struct sk_psock *psock) +{ + struct scatterlist *sge; + struct sk_msg *msg_rx; + int i; + + msg_rx = sk_psock_peek_msg(psock); + i = msg_rx->sg.start; + sge = sk_msg_elem(msg_rx, i); + if (!sge->length) { + struct sk_buff *skb = msg_rx->skb; + + if (skb && TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + return true; + } + return false; +} + static int tcp_bpf_recvmsg_parser(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, int *addr_len) { + int peek = flags & MSG_PEEK; struct sk_psock *psock; - int copied; + struct tcp_sock *tcp; + int copied = 0; + u32 seq; if (unlikely(flags & MSG_ERRQUEUE)) return inet_recv_error(sk, msg, len, addr_len); + if (!len) + return 0; + psock = sk_psock_get(sk); if (unlikely(!psock)) - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); lock_sock(sk); + tcp = tcp_sk(sk); + seq = tcp->copied_seq; + /* We may have received data on the sk_receive_queue pre-accept and + * then we can not use read_skb in this context because we haven't + * assigned a sk_socket yet so have no link to the ops. The work-around + * is to check the sk_receive_queue and in these cases read skbs off + * queue again. The read_skb hook is not running at this point because + * of lock_sock so we avoid having multiple runners in read_skb. + */ + if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) { + tcp_data_ready(sk); + /* This handles the ENOMEM errors if we both receive data + * pre accept and are already under memory pressure. At least + * let user know to retry. + */ + if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) { + copied = -EAGAIN; + goto out; + } + } + msg_bytes_ready: copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + /* The typical case for EFAULT is the socket was gracefully + * shutdown with a FIN pkt. So check here the other case is + * some error on copy_page_to_iter which would be unexpected. + * On fin return correct return code to zero. + */ + if (copied == -EFAULT) { + bool is_fin = is_next_msg_fin(psock); + + if (is_fin) { + copied = 0; + seq++; + goto out; + } + } + seq += copied; if (!copied) { long timeo; int data; @@ -212,7 +298,7 @@ msg_bytes_ready: goto out; } - timeo = sock_rcvtimeo(sk, nonblock); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); if (!timeo) { copied = -EAGAIN; goto out; @@ -224,18 +310,29 @@ msg_bytes_ready: } data = tcp_msg_wait_data(sk, psock, timeo); + if (data < 0) { + copied = data; + goto unlock; + } if (data && !sk_psock_queue_empty(psock)) goto msg_bytes_ready; copied = -EAGAIN; } out: + if (!peek) + WRITE_ONCE(tcp->copied_seq, seq); + tcp_rcv_space_adjust(sk); + if (copied > 0) + __tcp_cleanup_rbuf(sk, copied); + +unlock: release_sock(sk); sk_psock_put(sk, psock); return copied; } static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, int *addr_len) + int flags, int *addr_len) { struct sk_psock *psock; int copied, ret; @@ -243,13 +340,16 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (unlikely(flags & MSG_ERRQUEUE)) return inet_recv_error(sk, msg, len, addr_len); + if (!len) + return 0; + psock = sk_psock_get(sk); if (unlikely(!psock)) - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); if (!skb_queue_empty(&sk->sk_receive_queue) && sk_psock_queue_empty(psock)) { sk_psock_put(sk, psock); - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); } lock_sock(sk); msg_bytes_ready: @@ -258,18 +358,24 @@ msg_bytes_ready: long timeo; int data; - timeo = sock_rcvtimeo(sk, nonblock); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); data = tcp_msg_wait_data(sk, psock, timeo); + if (data < 0) { + ret = data; + goto unlock; + } if (data) { if (!sk_psock_queue_empty(psock)) goto msg_bytes_ready; release_sock(sk); sk_psock_put(sk, psock); - return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return tcp_recvmsg(sk, msg, len, flags, addr_len); } copied = -EAGAIN; } ret = copied; + +unlock: release_sock(sk); sk_psock_put(sk, psock); return ret; @@ -278,10 +384,10 @@ msg_bytes_ready: static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, struct sk_msg *msg, int *copied, int flags) { - bool cork = false, enospc = sk_msg_full(msg); + bool cork = false, enospc = sk_msg_full(msg), redir_ingress; struct sock *sk_redir; - u32 tosend, delta = 0; - u32 eval = __SK_NONE; + u32 tosend, origsize, sent, delta = 0; + u32 eval; int ret; more_data: @@ -302,8 +408,11 @@ more_data: if (!psock->cork) { psock->cork = kzalloc(sizeof(*psock->cork), GFP_ATOMIC | __GFP_NOWARN); - if (!psock->cork) + if (!psock->cork) { + sk_msg_free(sk, msg); + *copied = 0; return -ENOMEM; + } } memcpy(psock->cork, msg, sizeof(*msg)); return 0; @@ -312,6 +421,7 @@ more_data: tosend = msg->sg.size; if (psock->apply_bytes && psock->apply_bytes < tosend) tosend = psock->apply_bytes; + eval = __SK_NONE; switch (psock->eval) { case __SK_PASS: @@ -323,6 +433,7 @@ more_data: sk_msg_apply_bytes(psock, tosend); break; case __SK_REDIRECT: + redir_ingress = psock->redir_ingress; sk_redir = psock->sk_redir; sk_msg_apply_bytes(psock, tosend); if (!psock->apply_bytes) { @@ -335,17 +446,20 @@ more_data: cork = true; psock->cork = NULL; } - sk_msg_return(sk, msg, tosend); release_sock(sk); - ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags); + origsize = msg->sg.size; + ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, + msg, tosend, flags); + sent = origsize - msg->sg.size; if (eval == __SK_REDIRECT) sock_put(sk_redir); lock_sock(sk); + sk_mem_uncharge(sk, sent); if (unlikely(ret < 0)) { - int free = sk_msg_free_nocharge(sk, msg); + int free = sk_msg_free(sk, msg); if (!cork) *copied -= free; @@ -359,7 +473,7 @@ more_data: break; case __SK_DROP: default: - sk_msg_free_partial(sk, msg, tosend); + sk_msg_free(sk, msg); sk_msg_apply_bytes(psock, tosend); *copied -= (tosend + delta); return -EACCES; @@ -384,12 +498,12 @@ more_data: static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) { struct sk_msg tmp, *msg_tx = NULL; - int copied = 0, err = 0; + int copied = 0, err = 0, ret = 0; struct sk_psock *psock; long timeo; int flags; - /* Don't let internal do_tcp_sendpages() flags through */ + /* Don't let internal flags through */ flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED); flags |= MSG_NO_SHARED_FRAGS; @@ -427,14 +541,14 @@ static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) copy = msg_tx->sg.size - osize; } - err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx, + ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx, copy); - if (err < 0) { + if (ret < 0) { sk_msg_trim(sk, msg_tx, osize); goto out_err; } - copied += copy; + copied += ret; if (psock->cork_bytes) { if (size > psock->cork_bytes) psock->cork_bytes = 0; @@ -466,55 +580,7 @@ out_err: err = sk_stream_error(sk, msg->msg_flags, err); release_sock(sk); sk_psock_put(sk, psock); - return copied ? copied : err; -} - -static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset, - size_t size, int flags) -{ - struct sk_msg tmp, *msg = NULL; - int err = 0, copied = 0; - struct sk_psock *psock; - bool enospc = false; - - psock = sk_psock_get(sk); - if (unlikely(!psock)) - return tcp_sendpage(sk, page, offset, size, flags); - - lock_sock(sk); - if (psock->cork) { - msg = psock->cork; - } else { - msg = &tmp; - sk_msg_init(msg); - } - - /* Catch case where ring is full and sendpage is stalled. */ - if (unlikely(sk_msg_full(msg))) - goto out_err; - - sk_msg_page_add(msg, page, size, offset); - sk_mem_charge(sk, size); - copied = size; - if (sk_msg_full(msg)) - enospc = true; - if (psock->cork_bytes) { - if (size > psock->cork_bytes) - psock->cork_bytes = 0; - else - psock->cork_bytes -= size; - if (psock->cork_bytes && !enospc) - goto out_err; - /* All cork bytes are accounted, rerun the prog. */ - psock->eval = __SK_NONE; - psock->cork_bytes = 0; - } - - err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags); -out_err: - release_sock(sk); - sk_psock_put(sk, psock); - return copied ? copied : err; + return copied > 0 ? copied : err; } enum { @@ -539,13 +605,13 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], struct proto *base) { prot[TCP_BPF_BASE] = *base; + prot[TCP_BPF_BASE].destroy = sock_map_destroy; prot[TCP_BPF_BASE].close = sock_map_close; prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable; prot[TCP_BPF_TX] = prot[TCP_BPF_BASE]; prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg; - prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage; prot[TCP_BPF_RX] = prot[TCP_BPF_BASE]; prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser; @@ -580,10 +646,45 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops) * indeed valid assumptions. */ return ops->recvmsg == tcp_recvmsg && - ops->sendmsg == tcp_sendmsg && - ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; + ops->sendmsg == tcp_sendmsg ? 0 : -ENOTSUPP; } +#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER) +int tcp_bpf_strp_read_sock(struct strparser *strp, read_descriptor_t *desc, + sk_read_actor_t recv_actor) +{ + struct sock *sk = strp->sk; + struct sk_psock *psock; + struct tcp_sock *tp; + int copied = 0; + + tp = tcp_sk(sk); + rcu_read_lock(); + psock = sk_psock(sk); + if (WARN_ON_ONCE(!psock)) { + desc->error = -EINVAL; + goto out; + } + + psock->ingress_bytes = 0; + copied = tcp_read_sock_noack(sk, desc, recv_actor, true, + &psock->copied_seq); + if (copied < 0) + goto out; + /* recv_actor may redirect skb to another socket (SK_REDIRECT) or + * just put skb into ingress queue of current socket (SK_PASS). + * For SK_REDIRECT, we need to ack the frame immediately but for + * SK_PASS, we want to delay the ack until tcp_bpf_recvmsg_parser(). + */ + tp->copied_seq = psock->copied_seq - psock->ingress_bytes; + tcp_rcv_space_adjust(sk); + __tcp_cleanup_rbuf(sk, copied - psock->ingress_bytes); +out: + rcu_read_unlock(); + return copied; +} +#endif /* CONFIG_BPF_STREAM_PARSER */ + int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) { int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; @@ -605,14 +706,11 @@ int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) } else { sk->sk_write_space = psock->saved_write_space; /* Pairs with lockless read in sk_clone_lock() */ - WRITE_ONCE(sk->sk_prot, psock->sk_proto); + sock_replace_proto(sk, psock->sk_proto); } return 0; } - if (inet_csk_has_ulp(sk)) - return -EINVAL; - if (sk->sk_family == AF_INET6) { if (tcp_bpf_assert_proto_ops(psock->sk_proto)) return -EINVAL; @@ -621,7 +719,7 @@ int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) } /* Pairs with lockless read in sk_clone_lock() */ - WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]); + sock_replace_proto(sk, &tcp_bpf_prots[family][config]); return 0; } EXPORT_SYMBOL_GPL(tcp_bpf_update_proto); @@ -633,10 +731,9 @@ EXPORT_SYMBOL_GPL(tcp_bpf_update_proto); */ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk) { - int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; struct proto *prot = newsk->sk_prot; - if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE]) + if (is_insidevar(prot, tcp_bpf_prots)) newsk->sk_prot = sk->sk_prot_creator; } #endif /* CONFIG_BPF_SYSCALL */ diff --git a/net/ipv4/tcp_cdg.c b/net/ipv4/tcp_cdg.c index 709d23801823..fbad6c35dee9 100644 --- a/net/ipv4/tcp_cdg.c +++ b/net/ipv4/tcp_cdg.c @@ -161,8 +161,8 @@ static void tcp_cdg_hystart_update(struct sock *sk) LINUX_MIB_TCPHYSTARTTRAINDETECT); NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPHYSTARTTRAINCWND, - tp->snd_cwnd); - tp->snd_ssthresh = tp->snd_cwnd; + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); return; } } @@ -180,8 +180,8 @@ static void tcp_cdg_hystart_update(struct sock *sk) LINUX_MIB_TCPHYSTARTDELAYDETECT); NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPHYSTARTDELAYCWND, - tp->snd_cwnd); - tp->snd_ssthresh = tp->snd_cwnd; + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); } } } @@ -243,7 +243,7 @@ static bool tcp_cdg_backoff(struct sock *sk, u32 grad) struct cdg *ca = inet_csk_ca(sk); struct tcp_sock *tp = tcp_sk(sk); - if (prandom_u32() <= nexp_u32(grad * backoff_factor)) + if (get_random_u32() <= nexp_u32(grad * backoff_factor)) return false; if (use_ineff) { @@ -252,7 +252,7 @@ static bool tcp_cdg_backoff(struct sock *sk, u32 grad) return false; } - ca->shadow_wnd = max(ca->shadow_wnd, tp->snd_cwnd); + ca->shadow_wnd = max(ca->shadow_wnd, tcp_snd_cwnd(tp)); ca->state = CDG_BACKOFF; tcp_enter_cwr(sk); return true; @@ -285,14 +285,14 @@ static void tcp_cdg_cong_avoid(struct sock *sk, u32 ack, u32 acked) } if (!tcp_is_cwnd_limited(sk)) { - ca->shadow_wnd = min(ca->shadow_wnd, tp->snd_cwnd); + ca->shadow_wnd = min(ca->shadow_wnd, tcp_snd_cwnd(tp)); return; } - prior_snd_cwnd = tp->snd_cwnd; + prior_snd_cwnd = tcp_snd_cwnd(tp); tcp_reno_cong_avoid(sk, ack, acked); - incr = tp->snd_cwnd - prior_snd_cwnd; + incr = tcp_snd_cwnd(tp) - prior_snd_cwnd; ca->shadow_wnd = max(ca->shadow_wnd, ca->shadow_wnd + incr); } @@ -331,15 +331,15 @@ static u32 tcp_cdg_ssthresh(struct sock *sk) struct tcp_sock *tp = tcp_sk(sk); if (ca->state == CDG_BACKOFF) - return max(2U, (tp->snd_cwnd * min(1024U, backoff_beta)) >> 10); + return max(2U, (tcp_snd_cwnd(tp) * min(1024U, backoff_beta)) >> 10); if (ca->state == CDG_NONFULL && use_tolerance) - return tp->snd_cwnd; + return tcp_snd_cwnd(tp); - ca->shadow_wnd = min(ca->shadow_wnd >> 1, tp->snd_cwnd); + ca->shadow_wnd = min(ca->shadow_wnd >> 1, tcp_snd_cwnd(tp)); if (use_shadow) - return max3(2U, ca->shadow_wnd, tp->snd_cwnd >> 1); - return max(2U, tp->snd_cwnd >> 1); + return max3(2U, ca->shadow_wnd, tcp_snd_cwnd(tp) >> 1); + return max(2U, tcp_snd_cwnd(tp) >> 1); } static void tcp_cdg_cwnd_event(struct sock *sk, const enum tcp_ca_event ev) @@ -357,7 +357,7 @@ static void tcp_cdg_cwnd_event(struct sock *sk, const enum tcp_ca_event ev) ca->gradients = gradients; ca->rtt_seq = tp->snd_nxt; - ca->shadow_wnd = tp->snd_cwnd; + ca->shadow_wnd = tcp_snd_cwnd(tp); break; case CA_EVENT_COMPLETE_CWR: ca->state = CDG_UNKNOWN; @@ -375,12 +375,13 @@ static void tcp_cdg_init(struct sock *sk) struct cdg *ca = inet_csk_ca(sk); struct tcp_sock *tp = tcp_sk(sk); + ca->gradients = NULL; /* We silently fall back to window = 1 if allocation fails. */ if (window > 1) ca->gradients = kcalloc(window, sizeof(ca->gradients[0]), - GFP_NOWAIT | __GFP_NOWARN); + GFP_NOWAIT); ca->rtt_seq = tp->snd_nxt; - ca->shadow_wnd = tp->snd_cwnd; + ca->shadow_wnd = tcp_snd_cwnd(tp); } static void tcp_cdg_release(struct sock *sk) @@ -388,6 +389,7 @@ static void tcp_cdg_release(struct sock *sk) struct cdg *ca = inet_csk_ca(sk); kfree(ca->gradients); + ca->gradients = NULL; } static struct tcp_congestion_ops tcp_cdg __read_mostly = { diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c index db5831e6c136..df758adbb445 100644 --- a/net/ipv4/tcp_cong.c +++ b/net/ipv4/tcp_cong.c @@ -16,6 +16,7 @@ #include <linux/gfp.h> #include <linux/jhash.h> #include <net/tcp.h> +#include <trace/events/tcp.h> static DEFINE_SPINLOCK(tcp_cong_list_lock); static LIST_HEAD(tcp_cong_list); @@ -33,9 +34,19 @@ struct tcp_congestion_ops *tcp_ca_find(const char *name) return NULL; } +void tcp_set_ca_state(struct sock *sk, const u8 ca_state) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + trace_tcp_cong_state_set(sk, ca_state); + + if (icsk->icsk_ca_ops->set_state) + icsk->icsk_ca_ops->set_state(sk, ca_state); + icsk->icsk_ca_state = ca_state; +} + /* Must be called with rcu lock held */ -static struct tcp_congestion_ops *tcp_ca_find_autoload(struct net *net, - const char *name) +static struct tcp_congestion_ops *tcp_ca_find_autoload(const char *name) { struct tcp_congestion_ops *ca = tcp_ca_find(name); @@ -63,14 +74,8 @@ struct tcp_congestion_ops *tcp_ca_find_key(u32 key) return NULL; } -/* - * Attach new congestion control algorithm to the list - * of available options. - */ -int tcp_register_congestion_control(struct tcp_congestion_ops *ca) +int tcp_validate_congestion_control(struct tcp_congestion_ops *ca) { - int ret = 0; - /* all algorithms must implement these */ if (!ca->ssthresh || !ca->undo_cwnd || !(ca->cong_avoid || ca->cong_control)) { @@ -78,6 +83,20 @@ int tcp_register_congestion_control(struct tcp_congestion_ops *ca) return -EINVAL; } + return 0; +} + +/* Attach new congestion control algorithm to the list + * of available options. + */ +int tcp_register_congestion_control(struct tcp_congestion_ops *ca) +{ + int ret; + + ret = tcp_validate_congestion_control(ca); + if (ret) + return ret; + ca->key = jhash(ca->name, sizeof(ca->name), strlen(ca->name)); spin_lock(&tcp_cong_list_lock); @@ -118,7 +137,47 @@ void tcp_unregister_congestion_control(struct tcp_congestion_ops *ca) } EXPORT_SYMBOL_GPL(tcp_unregister_congestion_control); -u32 tcp_ca_get_key_by_name(struct net *net, const char *name, bool *ecn_ca) +/* Replace a registered old ca with a new one. + * + * The new ca must have the same name as the old one, that has been + * registered. + */ +int tcp_update_congestion_control(struct tcp_congestion_ops *ca, struct tcp_congestion_ops *old_ca) +{ + struct tcp_congestion_ops *existing; + int ret = 0; + + ca->key = jhash(ca->name, sizeof(ca->name), strlen(ca->name)); + + spin_lock(&tcp_cong_list_lock); + existing = tcp_ca_find_key(old_ca->key); + if (ca->key == TCP_CA_UNSPEC || !existing || strcmp(existing->name, ca->name)) { + pr_notice("%s not registered or non-unique key\n", + ca->name); + ret = -EINVAL; + } else if (existing != old_ca) { + pr_notice("invalid old congestion control algorithm to replace\n"); + ret = -EINVAL; + } else { + /* Add the new one before removing the old one to keep + * one implementation available all the time. + */ + list_add_tail_rcu(&ca->list, &tcp_cong_list); + list_del_rcu(&existing->list); + pr_debug("%s updated\n", ca->name); + } + spin_unlock(&tcp_cong_list_lock); + + /* Wait for outstanding readers to complete before the + * module or struct_ops gets removed entirely. + */ + if (!ret) + synchronize_rcu(); + + return ret; +} + +u32 tcp_ca_get_key_by_name(const char *name, bool *ecn_ca) { const struct tcp_congestion_ops *ca; u32 key = TCP_CA_UNSPEC; @@ -126,7 +185,7 @@ u32 tcp_ca_get_key_by_name(struct net *net, const char *name, bool *ecn_ca) might_sleep(); rcu_read_lock(); - ca = tcp_ca_find_autoload(net, name); + ca = tcp_ca_find_autoload(name); if (ca) { key = ca->key; *ecn_ca = ca->flags & TCP_CONG_NEEDS_ECN; @@ -135,7 +194,6 @@ u32 tcp_ca_get_key_by_name(struct net *net, const char *name, bool *ecn_ca) return key; } -EXPORT_SYMBOL_GPL(tcp_ca_get_key_by_name); char *tcp_ca_get_name_by_key(u32 key, char *buffer) { @@ -144,14 +202,14 @@ char *tcp_ca_get_name_by_key(u32 key, char *buffer) rcu_read_lock(); ca = tcp_ca_find_key(key); - if (ca) - ret = strncpy(buffer, ca->name, - TCP_CA_NAME_MAX); + if (ca) { + strscpy(buffer, ca->name, TCP_CA_NAME_MAX); + ret = buffer; + } rcu_read_unlock(); return ret; } -EXPORT_SYMBOL_GPL(tcp_ca_get_name_by_key); /* Assign choice of congestion control. */ void tcp_assign_congestion_control(struct sock *sk) @@ -212,8 +270,9 @@ void tcp_cleanup_congestion_control(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); - if (icsk->icsk_ca_ops->release) + if (icsk->icsk_ca_initialized && icsk->icsk_ca_ops->release) icsk->icsk_ca_ops->release(sk); + icsk->icsk_ca_initialized = 0; bpf_module_put(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner); } @@ -225,7 +284,7 @@ int tcp_set_default_congestion_control(struct net *net, const char *name) int ret; rcu_read_lock(); - ca = tcp_ca_find_autoload(net, name); + ca = tcp_ca_find_autoload(name); if (!ca) { ret = -ENOENT; } else if (!bpf_try_module_get(ca, ca->owner)) { @@ -280,7 +339,7 @@ void tcp_get_default_congestion_control(struct net *net, char *name) rcu_read_lock(); ca = rcu_dereference(net->ipv4.tcp_congestion_control); - strncpy(name, ca->name, TCP_CA_NAME_MAX); + strscpy(name, ca->name, TCP_CA_NAME_MAX); rcu_read_unlock(); } @@ -363,7 +422,7 @@ int tcp_set_congestion_control(struct sock *sk, const char *name, bool load, if (!load) ca = tcp_ca_find(name); else - ca = tcp_ca_find_autoload(sock_net(sk), name); + ca = tcp_ca_find_autoload(name); /* No change asking for existing value */ if (ca == icsk->icsk_ca_ops) { @@ -393,12 +452,12 @@ int tcp_set_congestion_control(struct sock *sk, const char *name, bool load, * ABC caps N to 2. Slow start exits when cwnd grows over ssthresh and * returns the leftover acks to adjust cwnd in congestion avoidance mode. */ -u32 tcp_slow_start(struct tcp_sock *tp, u32 acked) +__bpf_kfunc u32 tcp_slow_start(struct tcp_sock *tp, u32 acked) { - u32 cwnd = min(tp->snd_cwnd + acked, tp->snd_ssthresh); + u32 cwnd = min(tcp_snd_cwnd(tp) + acked, tp->snd_ssthresh); - acked -= cwnd - tp->snd_cwnd; - tp->snd_cwnd = min(cwnd, tp->snd_cwnd_clamp); + acked -= cwnd - tcp_snd_cwnd(tp); + tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); return acked; } @@ -407,12 +466,12 @@ EXPORT_SYMBOL_GPL(tcp_slow_start); /* In theory this is tp->snd_cwnd += 1 / tp->snd_cwnd (or alternative w), * for every packet that was ACKed. */ -void tcp_cong_avoid_ai(struct tcp_sock *tp, u32 w, u32 acked) +__bpf_kfunc void tcp_cong_avoid_ai(struct tcp_sock *tp, u32 w, u32 acked) { /* If credits accumulated at a higher w, apply them gently now. */ if (tp->snd_cwnd_cnt >= w) { tp->snd_cwnd_cnt = 0; - tp->snd_cwnd++; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); } tp->snd_cwnd_cnt += acked; @@ -420,9 +479,9 @@ void tcp_cong_avoid_ai(struct tcp_sock *tp, u32 w, u32 acked) u32 delta = tp->snd_cwnd_cnt / w; tp->snd_cwnd_cnt -= delta * w; - tp->snd_cwnd += delta; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + delta); } - tp->snd_cwnd = min(tp->snd_cwnd, tp->snd_cwnd_clamp); + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), tp->snd_cwnd_clamp)); } EXPORT_SYMBOL_GPL(tcp_cong_avoid_ai); @@ -433,7 +492,7 @@ EXPORT_SYMBOL_GPL(tcp_cong_avoid_ai); /* This is Jacobson's slow start and congestion avoidance. * SIGCOMM '88, p. 328. */ -void tcp_reno_cong_avoid(struct sock *sk, u32 ack, u32 acked) +__bpf_kfunc void tcp_reno_cong_avoid(struct sock *sk, u32 ack, u32 acked) { struct tcp_sock *tp = tcp_sk(sk); @@ -447,24 +506,24 @@ void tcp_reno_cong_avoid(struct sock *sk, u32 ack, u32 acked) return; } /* In dangerous area, increase slowly. */ - tcp_cong_avoid_ai(tp, tp->snd_cwnd, acked); + tcp_cong_avoid_ai(tp, tcp_snd_cwnd(tp), acked); } EXPORT_SYMBOL_GPL(tcp_reno_cong_avoid); /* Slow start threshold is half the congestion window (min 2) */ -u32 tcp_reno_ssthresh(struct sock *sk) +__bpf_kfunc u32 tcp_reno_ssthresh(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); - return max(tp->snd_cwnd >> 1U, 2U); + return max(tcp_snd_cwnd(tp) >> 1U, 2U); } EXPORT_SYMBOL_GPL(tcp_reno_ssthresh); -u32 tcp_reno_undo_cwnd(struct sock *sk) +__bpf_kfunc u32 tcp_reno_undo_cwnd(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); - return max(tp->snd_cwnd, tp->prior_cwnd); + return max(tcp_snd_cwnd(tp), tp->prior_cwnd); } EXPORT_SYMBOL_GPL(tcp_reno_undo_cwnd); diff --git a/net/ipv4/tcp_cubic.c b/net/ipv4/tcp_cubic.c index e07837e23b3f..76c23675ae50 100644 --- a/net/ipv4/tcp_cubic.c +++ b/net/ipv4/tcp_cubic.c @@ -126,7 +126,7 @@ static inline void bictcp_hystart_reset(struct sock *sk) ca->sample_cnt = 0; } -static void cubictcp_init(struct sock *sk) +__bpf_kfunc static void cubictcp_init(struct sock *sk) { struct bictcp *ca = inet_csk_ca(sk); @@ -139,7 +139,7 @@ static void cubictcp_init(struct sock *sk) tcp_sk(sk)->snd_ssthresh = initial_ssthresh; } -static void cubictcp_cwnd_event(struct sock *sk, enum tcp_ca_event event) +__bpf_kfunc static void cubictcp_cwnd_event(struct sock *sk, enum tcp_ca_event event) { if (event == CA_EVENT_TX_START) { struct bictcp *ca = inet_csk_ca(sk); @@ -321,7 +321,7 @@ tcp_friendliness: ca->cnt = max(ca->cnt, 2U); } -static void cubictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +__bpf_kfunc static void cubictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) { struct tcp_sock *tp = tcp_sk(sk); struct bictcp *ca = inet_csk_ca(sk); @@ -334,11 +334,11 @@ static void cubictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) if (!acked) return; } - bictcp_update(ca, tp->snd_cwnd, acked); + bictcp_update(ca, tcp_snd_cwnd(tp), acked); tcp_cong_avoid_ai(tp, ca->cnt, acked); } -static u32 cubictcp_recalc_ssthresh(struct sock *sk) +__bpf_kfunc static u32 cubictcp_recalc_ssthresh(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); struct bictcp *ca = inet_csk_ca(sk); @@ -346,16 +346,16 @@ static u32 cubictcp_recalc_ssthresh(struct sock *sk) ca->epoch_start = 0; /* end of epoch */ /* Wmax and fast convergence */ - if (tp->snd_cwnd < ca->last_max_cwnd && fast_convergence) - ca->last_max_cwnd = (tp->snd_cwnd * (BICTCP_BETA_SCALE + beta)) + if (tcp_snd_cwnd(tp) < ca->last_max_cwnd && fast_convergence) + ca->last_max_cwnd = (tcp_snd_cwnd(tp) * (BICTCP_BETA_SCALE + beta)) / (2 * BICTCP_BETA_SCALE); else - ca->last_max_cwnd = tp->snd_cwnd; + ca->last_max_cwnd = tcp_snd_cwnd(tp); - return max((tp->snd_cwnd * beta) / BICTCP_BETA_SCALE, 2U); + return max((tcp_snd_cwnd(tp) * beta) / BICTCP_BETA_SCALE, 2U); } -static void cubictcp_state(struct sock *sk, u8 new_state) +__bpf_kfunc static void cubictcp_state(struct sock *sk, u8 new_state) { if (new_state == TCP_CA_Loss) { bictcp_reset(inet_csk_ca(sk)); @@ -372,7 +372,7 @@ static void cubictcp_state(struct sock *sk, u8 new_state) * We apply another 100% factor because @rate is doubled at this point. * We cap the cushion to 1ms. */ -static u32 hystart_ack_delay(struct sock *sk) +static u32 hystart_ack_delay(const struct sock *sk) { unsigned long rate; @@ -380,7 +380,7 @@ static u32 hystart_ack_delay(struct sock *sk) if (!rate) return 0; return min_t(u64, USEC_PER_MSEC, - div64_ul((u64)GSO_MAX_SIZE * 4 * USEC_PER_SEC, rate)); + div64_ul((u64)sk->sk_gso_max_size * 4 * USEC_PER_SEC, rate)); } static void hystart_update(struct sock *sk, u32 delay) @@ -392,6 +392,10 @@ static void hystart_update(struct sock *sk, u32 delay) if (after(tp->snd_una, ca->end_seq)) bictcp_hystart_reset(sk); + /* hystart triggers when cwnd is larger than some threshold */ + if (tcp_snd_cwnd(tp) < hystart_low_window) + return; + if (hystart_detect & HYSTART_ACK_TRAIN) { u32 now = bictcp_clock_us(sk); @@ -413,13 +417,13 @@ static void hystart_update(struct sock *sk, u32 delay) ca->found = 1; pr_debug("hystart_ack_train (%u > %u) delay_min %u (+ ack_delay %u) cwnd %u\n", now - ca->round_start, threshold, - ca->delay_min, hystart_ack_delay(sk), tp->snd_cwnd); + ca->delay_min, hystart_ack_delay(sk), tcp_snd_cwnd(tp)); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPHYSTARTTRAINDETECT); NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPHYSTARTTRAINCWND, - tp->snd_cwnd); - tp->snd_ssthresh = tp->snd_cwnd; + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); } } } @@ -438,14 +442,14 @@ static void hystart_update(struct sock *sk, u32 delay) LINUX_MIB_TCPHYSTARTDELAYDETECT); NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPHYSTARTDELAYCWND, - tp->snd_cwnd); - tp->snd_ssthresh = tp->snd_cwnd; + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); } } } } -static void cubictcp_acked(struct sock *sk, const struct ack_sample *sample) +__bpf_kfunc static void cubictcp_acked(struct sock *sk, const struct ack_sample *sample) { const struct tcp_sock *tp = tcp_sk(sk); struct bictcp *ca = inet_csk_ca(sk); @@ -467,9 +471,7 @@ static void cubictcp_acked(struct sock *sk, const struct ack_sample *sample) if (ca->delay_min == 0 || ca->delay_min > delay) ca->delay_min = delay; - /* hystart triggers when cwnd is larger than some threshold */ - if (!ca->found && tcp_in_slow_start(tp) && hystart && - tp->snd_cwnd >= hystart_low_window) + if (!ca->found && tcp_in_slow_start(tp) && hystart) hystart_update(sk, delay); } @@ -485,20 +487,19 @@ static struct tcp_congestion_ops cubictcp __read_mostly = { .name = "cubic", }; -BTF_SET_START(tcp_cubic_kfunc_ids) -#ifdef CONFIG_X86 -#ifdef CONFIG_DYNAMIC_FTRACE -BTF_ID(func, cubictcp_init) -BTF_ID(func, cubictcp_recalc_ssthresh) -BTF_ID(func, cubictcp_cong_avoid) -BTF_ID(func, cubictcp_state) -BTF_ID(func, cubictcp_cwnd_event) -BTF_ID(func, cubictcp_acked) -#endif -#endif -BTF_SET_END(tcp_cubic_kfunc_ids) - -static DEFINE_KFUNC_BTF_ID_SET(&tcp_cubic_kfunc_ids, tcp_cubic_kfunc_btf_set); +BTF_KFUNCS_START(tcp_cubic_check_kfunc_ids) +BTF_ID_FLAGS(func, cubictcp_init) +BTF_ID_FLAGS(func, cubictcp_recalc_ssthresh) +BTF_ID_FLAGS(func, cubictcp_cong_avoid) +BTF_ID_FLAGS(func, cubictcp_state) +BTF_ID_FLAGS(func, cubictcp_cwnd_event) +BTF_ID_FLAGS(func, cubictcp_acked) +BTF_KFUNCS_END(tcp_cubic_check_kfunc_ids) + +static const struct btf_kfunc_id_set tcp_cubic_kfunc_set = { + .owner = THIS_MODULE, + .set = &tcp_cubic_check_kfunc_ids, +}; static int __init cubictcp_register(void) { @@ -534,16 +535,14 @@ static int __init cubictcp_register(void) /* divide by bic_scale and by constant Srtt (100ms) */ do_div(cube_factor, bic_scale * 10); - ret = tcp_register_congestion_control(&cubictcp); - if (ret) + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &tcp_cubic_kfunc_set); + if (ret < 0) return ret; - register_kfunc_btf_id_set(&bpf_tcp_ca_kfunc_list, &tcp_cubic_kfunc_btf_set); - return 0; + return tcp_register_congestion_control(&cubictcp); } static void __exit cubictcp_unregister(void) { - unregister_kfunc_btf_id_set(&bpf_tcp_ca_kfunc_list, &tcp_cubic_kfunc_btf_set); tcp_unregister_congestion_control(&cubictcp); } diff --git a/net/ipv4/tcp_dctcp.c b/net/ipv4/tcp_dctcp.c index 0d7ab3cc7b61..03abe0848420 100644 --- a/net/ipv4/tcp_dctcp.c +++ b/net/ipv4/tcp_dctcp.c @@ -54,10 +54,22 @@ struct dctcp { u32 next_seq; u32 ce_state; u32 loss_cwnd; + struct tcp_plb_state plb; }; static unsigned int dctcp_shift_g __read_mostly = 4; /* g = 1/2^4 */ -module_param(dctcp_shift_g, uint, 0644); + +static int dctcp_shift_g_set(const char *val, const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, 0, 10); +} + +static const struct kernel_param_ops dctcp_shift_g_ops = { + .set = dctcp_shift_g_set, + .get = param_get_uint, +}; + +module_param_cb(dctcp_shift_g, &dctcp_shift_g_ops, &dctcp_shift_g, 0644); MODULE_PARM_DESC(dctcp_shift_g, "parameter g for updating dctcp_alpha"); static unsigned int dctcp_alpha_on_init __read_mostly = DCTCP_MAX_ALPHA; @@ -74,11 +86,11 @@ static void dctcp_reset(const struct tcp_sock *tp, struct dctcp *ca) ca->old_delivered_ce = tp->delivered_ce; } -static void dctcp_init(struct sock *sk) +__bpf_kfunc static void dctcp_init(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); - if ((tp->ecn_flags & TCP_ECN_OK) || + if (tcp_ecn_mode_any(tp) || (sk->sk_state == TCP_LISTEN || sk->sk_state == TCP_CLOSE)) { struct dctcp *ca = inet_csk_ca(sk); @@ -91,6 +103,8 @@ static void dctcp_init(struct sock *sk) ca->ce_state = 0; dctcp_reset(tp, ca); + tcp_plb_init(sk, &ca->plb); + return; } @@ -101,30 +115,44 @@ static void dctcp_init(struct sock *sk) INET_ECN_dontxmit(sk); } -static u32 dctcp_ssthresh(struct sock *sk) +__bpf_kfunc static u32 dctcp_ssthresh(struct sock *sk) { struct dctcp *ca = inet_csk_ca(sk); struct tcp_sock *tp = tcp_sk(sk); - ca->loss_cwnd = tp->snd_cwnd; - return max(tp->snd_cwnd - ((tp->snd_cwnd * ca->dctcp_alpha) >> 11U), 2U); + ca->loss_cwnd = tcp_snd_cwnd(tp); + return max(tcp_snd_cwnd(tp) - ((tcp_snd_cwnd(tp) * ca->dctcp_alpha) >> 11U), 2U); } -static void dctcp_update_alpha(struct sock *sk, u32 flags) +__bpf_kfunc static void dctcp_update_alpha(struct sock *sk, u32 flags) { const struct tcp_sock *tp = tcp_sk(sk); struct dctcp *ca = inet_csk_ca(sk); /* Expired RTT */ if (!before(tp->snd_una, ca->next_seq)) { + u32 delivered = tp->delivered - ca->old_delivered; u32 delivered_ce = tp->delivered_ce - ca->old_delivered_ce; u32 alpha = ca->dctcp_alpha; + u32 ce_ratio = 0; + + if (delivered > 0) { + /* dctcp_alpha keeps EWMA of fraction of ECN marked + * packets. Because of EWMA smoothing, PLB reaction can + * be slow so we use ce_ratio which is an instantaneous + * measure of congestion. ce_ratio is the fraction of + * ECN marked packets in the previous RTT. + */ + if (delivered_ce > 0) + ce_ratio = (delivered_ce << TCP_PLB_SCALE) / delivered; + tcp_plb_update_state(sk, &ca->plb, (int)ce_ratio); + tcp_plb_check_rehash(sk, &ca->plb); + } /* alpha = (1 - g) * alpha + g * F */ alpha -= min_not_zero(alpha, alpha >> dctcp_shift_g); if (delivered_ce) { - u32 delivered = tp->delivered - ca->old_delivered; /* If dctcp_shift_g == 1, a 32bit value would overflow * after 8 M packets. @@ -148,11 +176,11 @@ static void dctcp_react_to_loss(struct sock *sk) struct dctcp *ca = inet_csk_ca(sk); struct tcp_sock *tp = tcp_sk(sk); - ca->loss_cwnd = tp->snd_cwnd; - tp->snd_ssthresh = max(tp->snd_cwnd >> 1U, 2U); + ca->loss_cwnd = tcp_snd_cwnd(tp); + tp->snd_ssthresh = max(tcp_snd_cwnd(tp) >> 1U, 2U); } -static void dctcp_state(struct sock *sk, u8 new_state) +__bpf_kfunc static void dctcp_state(struct sock *sk, u8 new_state) { if (new_state == TCP_CA_Recovery && new_state != inet_csk(sk)->icsk_ca_state) @@ -162,7 +190,7 @@ static void dctcp_state(struct sock *sk, u8 new_state) */ } -static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) +__bpf_kfunc static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) { struct dctcp *ca = inet_csk_ca(sk); @@ -172,8 +200,12 @@ static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state); break; case CA_EVENT_LOSS: + tcp_plb_update_state_upon_rto(sk, &ca->plb); dctcp_react_to_loss(sk); break; + case CA_EVENT_TX_START: + tcp_plb_check_rehash(sk, &ca->plb); /* Maybe rehash when inflight is 0 */ + break; default: /* Don't care for the rest. */ break; @@ -208,11 +240,12 @@ static size_t dctcp_get_info(struct sock *sk, u32 ext, int *attr, return 0; } -static u32 dctcp_cwnd_undo(struct sock *sk) +__bpf_kfunc static u32 dctcp_cwnd_undo(struct sock *sk) { const struct dctcp *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); - return max(tcp_sk(sk)->snd_cwnd, ca->loss_cwnd); + return max(tcp_snd_cwnd(tp), ca->loss_cwnd); } static struct tcp_congestion_ops dctcp __read_mostly = { @@ -238,36 +271,34 @@ static struct tcp_congestion_ops dctcp_reno __read_mostly = { .name = "dctcp-reno", }; -BTF_SET_START(tcp_dctcp_kfunc_ids) -#ifdef CONFIG_X86 -#ifdef CONFIG_DYNAMIC_FTRACE -BTF_ID(func, dctcp_init) -BTF_ID(func, dctcp_update_alpha) -BTF_ID(func, dctcp_cwnd_event) -BTF_ID(func, dctcp_ssthresh) -BTF_ID(func, dctcp_cwnd_undo) -BTF_ID(func, dctcp_state) -#endif -#endif -BTF_SET_END(tcp_dctcp_kfunc_ids) - -static DEFINE_KFUNC_BTF_ID_SET(&tcp_dctcp_kfunc_ids, tcp_dctcp_kfunc_btf_set); +BTF_KFUNCS_START(tcp_dctcp_check_kfunc_ids) +BTF_ID_FLAGS(func, dctcp_init) +BTF_ID_FLAGS(func, dctcp_update_alpha) +BTF_ID_FLAGS(func, dctcp_cwnd_event) +BTF_ID_FLAGS(func, dctcp_ssthresh) +BTF_ID_FLAGS(func, dctcp_cwnd_undo) +BTF_ID_FLAGS(func, dctcp_state) +BTF_KFUNCS_END(tcp_dctcp_check_kfunc_ids) + +static const struct btf_kfunc_id_set tcp_dctcp_kfunc_set = { + .owner = THIS_MODULE, + .set = &tcp_dctcp_check_kfunc_ids, +}; static int __init dctcp_register(void) { int ret; BUILD_BUG_ON(sizeof(struct dctcp) > ICSK_CA_PRIV_SIZE); - ret = tcp_register_congestion_control(&dctcp); - if (ret) + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &tcp_dctcp_kfunc_set); + if (ret < 0) return ret; - register_kfunc_btf_id_set(&bpf_tcp_ca_kfunc_list, &tcp_dctcp_kfunc_btf_set); - return 0; + return tcp_register_congestion_control(&dctcp); } static void __exit dctcp_unregister(void) { - unregister_kfunc_btf_id_set(&bpf_tcp_ca_kfunc_list, &tcp_dctcp_kfunc_btf_set); tcp_unregister_congestion_control(&dctcp); } diff --git a/net/ipv4/tcp_dctcp.h b/net/ipv4/tcp_dctcp.h index d69a77cbd0c7..4b0259111d81 100644 --- a/net/ipv4/tcp_dctcp.h +++ b/net/ipv4/tcp_dctcp.h @@ -28,7 +28,7 @@ static inline void dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt, */ if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) { dctcp_ece_ack_cwr(sk, *ce_state); - __tcp_send_ack(sk, *prior_rcv_nxt); + __tcp_send_ack(sk, *prior_rcv_nxt, 0); } inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; } diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c index 75a1c985f49a..d83efd91f461 100644 --- a/net/ipv4/tcp_diag.c +++ b/net/ipv4/tcp_diag.c @@ -12,6 +12,9 @@ #include <linux/tcp.h> +#include <net/inet_hashtables.h> +#include <net/inet6_hashtables.h> +#include <net/inet_timewait_sock.h> #include <net/netlink.h> #include <net/tcp.h> @@ -83,7 +86,7 @@ static int tcp_diag_put_md5sig(struct sk_buff *skb, #endif static int tcp_diag_put_ulp(struct sk_buff *skb, struct sock *sk, - const struct tcp_ulp_ops *ulp_ops) + const struct tcp_ulp_ops *ulp_ops, bool net_admin) { struct nlattr *nest; int err; @@ -97,7 +100,7 @@ static int tcp_diag_put_ulp(struct sk_buff *skb, struct sock *sk, goto nla_failure; if (ulp_ops->get_info) - err = ulp_ops->get_info(sk, skb); + err = ulp_ops->get_info(sk, skb, net_admin); if (err) goto nla_failure; @@ -113,6 +116,7 @@ static int tcp_diag_get_aux(struct sock *sk, bool net_admin, struct sk_buff *skb) { struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_ulp_ops *ulp_ops; int err = 0; #ifdef CONFIG_TCP_MD5SIG @@ -129,15 +133,13 @@ static int tcp_diag_get_aux(struct sock *sk, bool net_admin, } #endif - if (net_admin) { - const struct tcp_ulp_ops *ulp_ops; - - ulp_ops = icsk->icsk_ulp_ops; - if (ulp_ops) - err = tcp_diag_put_ulp(skb, sk, ulp_ops); - if (err) + ulp_ops = icsk->icsk_ulp_ops; + if (ulp_ops) { + err = tcp_diag_put_ulp(skb, sk, ulp_ops, net_admin); + if (err < 0) return err; } + return 0; } @@ -164,7 +166,7 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin) } #endif - if (net_admin && sk_fullsock(sk)) { + if (sk_fullsock(sk)) { const struct tcp_ulp_ops *ulp_ops; ulp_ops = icsk->icsk_ulp_ops; @@ -172,22 +174,468 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin) size += nla_total_size(0) + nla_total_size(TCP_ULP_NAME_MAX); if (ulp_ops->get_info_size) - size += ulp_ops->get_info_size(sk); + size += ulp_ops->get_info_size(sk, net_admin); } } - return size; + + return size + + nla_total_size(sizeof(struct tcp_info)) + + nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct inet_diag_meminfo)) + + nla_total_size(SK_MEMINFO_VARS * sizeof(u32)) + + nla_total_size(TCP_CA_NAME_MAX) + + nla_total_size(sizeof(struct tcpvegas_info)) + + 64; +} + +static int tcp_twsk_diag_fill(struct sock *sk, + struct sk_buff *skb, + struct netlink_callback *cb, + u16 nlmsg_flags, bool net_admin) +{ + struct inet_timewait_sock *tw = inet_twsk(sk); + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + long tmo; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type, + sizeof(*r), nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + DEBUG_NET_WARN_ON_ONCE(tw->tw_state != TCP_TIME_WAIT); + + inet_diag_msg_common_fill(r, sk); + r->idiag_retrans = 0; + + r->idiag_state = READ_ONCE(tw->tw_substate); + r->idiag_timer = 3; + tmo = tw->tw_timer.expires - jiffies; + r->idiag_expires = jiffies_delta_to_msecs(tmo); + r->idiag_rqueue = 0; + r->idiag_wqueue = 0; + r->idiag_uid = 0; + r->idiag_inode = 0; + + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, + tw->tw_mark)) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int tcp_req_diag_fill(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + u16 nlmsg_flags, bool net_admin) +{ + struct request_sock *reqsk = inet_reqsk(sk); + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + long tmo; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + inet_diag_msg_common_fill(r, sk); + r->idiag_state = TCP_SYN_RECV; + r->idiag_timer = 1; + r->idiag_retrans = READ_ONCE(reqsk->num_retrans); + + BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) != + offsetof(struct sock, sk_cookie)); + + tmo = READ_ONCE(inet_reqsk(sk)->rsk_timer.expires) - jiffies; + r->idiag_expires = jiffies_delta_to_msecs(tmo); + r->idiag_rqueue = 0; + r->idiag_wqueue = 0; + r->idiag_uid = 0; + r->idiag_inode = 0; + + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, + inet_rsk(reqsk)->ir_mark)) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *r, + u16 nlmsg_flags, bool net_admin) +{ + if (sk->sk_state == TCP_TIME_WAIT) + return tcp_twsk_diag_fill(sk, skb, cb, nlmsg_flags, net_admin); + + if (sk->sk_state == TCP_NEW_SYN_RECV) + return tcp_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin); + + return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags, + net_admin); +} + +static void twsk_build_assert(void) +{ + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) != + offsetof(struct sock, sk_family)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) != + offsetof(struct inet_sock, inet_num)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) != + offsetof(struct inet_sock, inet_dport)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) != + offsetof(struct inet_sock, inet_rcv_saddr)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) != + offsetof(struct inet_sock, inet_daddr)); + +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) != + offsetof(struct sock, sk_v6_rcv_saddr)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) != + offsetof(struct sock, sk_v6_daddr)); +#endif } static void tcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r) { - inet_diag_dump_icsk(&tcp_hashinfo, skb, cb, r); + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); + struct inet_diag_dump_data *cb_data = cb->data; + struct net *net = sock_net(skb->sk); + u32 idiag_states = r->idiag_states; + struct inet_hashinfo *hashinfo; + int i, num, s_i, s_num; + struct sock *sk; + + hashinfo = net->ipv4.tcp_death_row.hashinfo; + if (idiag_states & TCPF_SYN_RECV) + idiag_states |= TCPF_NEW_SYN_RECV; + s_i = cb->args[1]; + s_num = num = cb->args[2]; + + if (cb->args[0] == 0) { + if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport) + goto skip_listen_ht; + + for (i = s_i; i <= hashinfo->lhash2_mask; i++) { + struct inet_listen_hashbucket *ilb; + struct hlist_nulls_node *node; + + num = 0; + ilb = &hashinfo->lhash2[i]; + + if (hlist_nulls_empty(&ilb->nulls_head)) { + s_num = 0; + continue; + } + spin_lock(&ilb->lock); + sk_nulls_for_each(sk, node, &ilb->nulls_head) { + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net)) + continue; + + if (num < s_num) { + num++; + continue; + } + + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next_listen; + + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next_listen; + + if (!inet_diag_bc_sk(cb_data, sk)) + goto next_listen; + + if (inet_sk_diag_fill(sk, inet_csk(sk), skb, + cb, r, NLM_F_MULTI, + net_admin) < 0) { + spin_unlock(&ilb->lock); + goto done; + } + +next_listen: + ++num; + } + spin_unlock(&ilb->lock); + + s_num = 0; + } +skip_listen_ht: + cb->args[0] = 1; + s_i = num = s_num = 0; + } + +/* Process a maximum of SKARR_SZ sockets at a time when walking hash buckets + * with bh disabled. + */ +#define SKARR_SZ 16 + + /* Dump bound but inactive (not listening, connecting, etc.) sockets */ + if (cb->args[0] == 1) { + if (!(idiag_states & TCPF_BOUND_INACTIVE)) + goto skip_bind_ht; + + for (i = s_i; i < hashinfo->bhash_size; i++) { + struct inet_bind_hashbucket *ibb; + struct inet_bind2_bucket *tb2; + struct sock *sk_arr[SKARR_SZ]; + int num_arr[SKARR_SZ]; + int idx, accum, res; + +resume_bind_walk: + num = 0; + accum = 0; + ibb = &hashinfo->bhash2[i]; + + if (hlist_empty(&ibb->chain)) { + s_num = 0; + continue; + } + spin_lock_bh(&ibb->lock); + inet_bind_bucket_for_each(tb2, &ibb->chain) { + if (!net_eq(ib2_net(tb2), net)) + continue; + + sk_for_each_bound(sk, &tb2->owners) { + struct inet_sock *inet = inet_sk(sk); + + if (num < s_num) + goto next_bind; + + if (sk->sk_state != TCP_CLOSE || + !inet->inet_num) + goto next_bind; + + if (r->sdiag_family != AF_UNSPEC && + r->sdiag_family != sk->sk_family) + goto next_bind; + + if (!inet_diag_bc_sk(cb_data, sk)) + goto next_bind; + + sock_hold(sk); + num_arr[accum] = num; + sk_arr[accum] = sk; + if (++accum == SKARR_SZ) + goto pause_bind_walk; +next_bind: + num++; + } + } +pause_bind_walk: + spin_unlock_bh(&ibb->lock); + + res = 0; + for (idx = 0; idx < accum; idx++) { + if (res >= 0) { + res = inet_sk_diag_fill(sk_arr[idx], + NULL, skb, cb, + r, NLM_F_MULTI, + net_admin); + if (res < 0) + num = num_arr[idx]; + } + sock_put(sk_arr[idx]); + } + if (res < 0) + goto done; + + cond_resched(); + + if (accum == SKARR_SZ) { + s_num = num + 1; + goto resume_bind_walk; + } + + s_num = 0; + } +skip_bind_ht: + cb->args[0] = 2; + s_i = num = s_num = 0; + } + + if (!(idiag_states & ~TCPF_LISTEN)) + goto out; + + for (i = s_i; i <= hashinfo->ehash_mask; i++) { + struct inet_ehash_bucket *head = &hashinfo->ehash[i]; + spinlock_t *lock = inet_ehash_lockp(hashinfo, i); + struct hlist_nulls_node *node; + struct sock *sk_arr[SKARR_SZ]; + int num_arr[SKARR_SZ]; + int idx, accum, res; + + if (hlist_nulls_empty(&head->chain)) + continue; + + if (i > s_i) + s_num = 0; + +next_chunk: + num = 0; + accum = 0; + spin_lock_bh(lock); + sk_nulls_for_each(sk, node, &head->chain) { + int state; + + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) + goto next_normal; + state = (sk->sk_state == TCP_TIME_WAIT) ? + READ_ONCE(inet_twsk(sk)->tw_substate) : sk->sk_state; + if (!(idiag_states & (1 << state))) + goto next_normal; + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next_normal; + if (r->id.idiag_sport != htons(sk->sk_num) && + r->id.idiag_sport) + goto next_normal; + if (r->id.idiag_dport != sk->sk_dport && + r->id.idiag_dport) + goto next_normal; + twsk_build_assert(); + + if (!inet_diag_bc_sk(cb_data, sk)) + goto next_normal; + + if (!refcount_inc_not_zero(&sk->sk_refcnt)) + goto next_normal; + + num_arr[accum] = num; + sk_arr[accum] = sk; + if (++accum == SKARR_SZ) + break; +next_normal: + ++num; + } + spin_unlock_bh(lock); + + res = 0; + for (idx = 0; idx < accum; idx++) { + if (res >= 0) { + res = sk_diag_fill(sk_arr[idx], skb, cb, r, + NLM_F_MULTI, net_admin); + if (res < 0) + num = num_arr[idx]; + } + sock_gen_put(sk_arr[idx]); + } + if (res < 0) + break; + + cond_resched(); + + if (accum == SKARR_SZ) { + s_num = num + 1; + goto next_chunk; + } + } + +done: + cb->args[1] = i; + cb->args[2] = num; +out: + ; +} + +static struct sock *tcp_diag_find_one_icsk(struct net *net, + const struct inet_diag_req_v2 *req) +{ + struct sock *sk; + + rcu_read_lock(); + if (req->sdiag_family == AF_INET) { + sk = inet_lookup(net, NULL, 0, req->id.idiag_dst[0], + req->id.idiag_dport, req->id.idiag_src[0], + req->id.idiag_sport, req->id.idiag_if); +#if IS_ENABLED(CONFIG_IPV6) + } else if (req->sdiag_family == AF_INET6) { + if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && + ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src)) + sk = inet_lookup(net, NULL, 0, req->id.idiag_dst[3], + req->id.idiag_dport, req->id.idiag_src[3], + req->id.idiag_sport, req->id.idiag_if); + else + sk = inet6_lookup(net, NULL, 0, + (struct in6_addr *)req->id.idiag_dst, + req->id.idiag_dport, + (struct in6_addr *)req->id.idiag_src, + req->id.idiag_sport, + req->id.idiag_if); +#endif + } else { + rcu_read_unlock(); + return ERR_PTR(-EINVAL); + } + rcu_read_unlock(); + if (!sk) + return ERR_PTR(-ENOENT); + + if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) { + sock_gen_put(sk); + return ERR_PTR(-ENOENT); + } + + return sk; } static int tcp_diag_dump_one(struct netlink_callback *cb, const struct inet_diag_req_v2 *req) { - return inet_diag_dump_one_icsk(&tcp_hashinfo, cb, req); + struct sk_buff *in_skb = cb->skb; + struct sk_buff *rep; + struct sock *sk; + struct net *net; + bool net_admin; + int err; + + net = sock_net(in_skb->sk); + sk = tcp_diag_find_one_icsk(net, req); + if (IS_ERR(sk)) + return PTR_ERR(sk); + + net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN); + rep = nlmsg_new(tcp_diag_get_aux_size(sk, net_admin), GFP_KERNEL); + if (!rep) { + err = -ENOMEM; + goto out; + } + + err = sk_diag_fill(sk, rep, cb, req, 0, net_admin); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + nlmsg_free(rep); + goto out; + } + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + +out: + if (sk) + sock_gen_put(sk); + + return err; } #ifdef CONFIG_INET_DIAG_DESTROY @@ -195,9 +643,10 @@ static int tcp_diag_destroy(struct sk_buff *in_skb, const struct inet_diag_req_v2 *req) { struct net *net = sock_net(in_skb->sk); - struct sock *sk = inet_diag_find_one_icsk(net, &tcp_hashinfo, req); + struct sock *sk; int err; + sk = tcp_diag_find_one_icsk(net, req); if (IS_ERR(sk)) return PTR_ERR(sk); @@ -210,11 +659,11 @@ static int tcp_diag_destroy(struct sk_buff *in_skb, #endif static const struct inet_diag_handler tcp_diag_handler = { + .owner = THIS_MODULE, .dump = tcp_diag_dump, .dump_one = tcp_diag_dump_one, .idiag_get_info = tcp_diag_get_info, .idiag_get_aux = tcp_diag_get_aux, - .idiag_get_aux_size = tcp_diag_get_aux_size, .idiag_type = IPPROTO_TCP, .idiag_info_size = sizeof(struct tcp_info), #ifdef CONFIG_INET_DIAG_DESTROY @@ -235,4 +684,5 @@ static void __exit tcp_diag_exit(void) module_init(tcp_diag_init); module_exit(tcp_diag_exit); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP socket monitoring via SOCK_DIAG"); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-6 /* AF_INET - IPPROTO_TCP */); diff --git a/net/ipv4/tcp_fastopen.c b/net/ipv4/tcp_fastopen.c index fdbcf2a6d08e..7d945a527daf 100644 --- a/net/ipv4/tcp_fastopen.c +++ b/net/ipv4/tcp_fastopen.c @@ -3,6 +3,7 @@ #include <linux/tcp.h> #include <linux/rcupdate.h> #include <net/tcp.h> +#include <net/busy_poll.h> void tcp_fastopen_init_key_once(struct net *net) { @@ -49,7 +50,7 @@ void tcp_fastopen_ctx_destroy(struct net *net) { struct tcp_fastopen_context *ctxt; - ctxt = xchg((__force struct tcp_fastopen_context **)&net->ipv4.tcp_fastopen_ctx, NULL); + ctxt = unrcu_pointer(xchg(&net->ipv4.tcp_fastopen_ctx, NULL)); if (ctxt) call_rcu(&ctxt->rcu, tcp_fastopen_ctx_free); @@ -80,9 +81,10 @@ int tcp_fastopen_reset_cipher(struct net *net, struct sock *sk, if (sk) { q = &inet_csk(sk)->icsk_accept_queue.fastopenq; - octx = xchg((__force struct tcp_fastopen_context **)&q->ctx, ctx); + octx = unrcu_pointer(xchg(&q->ctx, RCU_INITIALIZER(ctx))); } else { - octx = xchg((__force struct tcp_fastopen_context **)&net->ipv4.tcp_fastopen_ctx, ctx); + octx = unrcu_pointer(xchg(&net->ipv4.tcp_fastopen_ctx, + RCU_INITIALIZER(ctx))); } if (octx) @@ -177,7 +179,7 @@ void tcp_fastopen_add_skb(struct sock *sk, struct sk_buff *skb) if (!skb) return; - skb_dst_drop(skb); + tcp_cleanup_skb(skb); /* segs_in has been initialized to 1 in tcp_create_openreq_child(). * Hence, reset segs_in to 0 before calling tcp_segs_in() * to avoid double counting. Also, tcp_segs_in() expects @@ -194,7 +196,7 @@ void tcp_fastopen_add_skb(struct sock *sk, struct sk_buff *skb) TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_SYN; tp->rcv_nxt = TCP_SKB_CB(skb)->end_seq; - __skb_queue_tail(&sk->sk_receive_queue, skb); + tcp_add_receive_queue(sk, skb); tp->syn_data_acked = 1; /* u64_stats_update_begin(&tp->syncp) not needed here, @@ -272,11 +274,14 @@ static struct sock *tcp_fastopen_create_child(struct sock *sk, * The request socket is not added to the ehash * because it's been added to the accept queue directly. */ - inet_csk_reset_xmit_timer(child, ICSK_TIME_RETRANS, - TCP_TIMEOUT_INIT, TCP_RTO_MAX); + req->timeout = tcp_timeout_init(child); + tcp_reset_xmit_timer(child, ICSK_TIME_RETRANS, + req->timeout, false); refcount_set(&req->rsk_refcnt, 2); + sk_mark_napi_id_set(child, skb); + /* Now finish processing the fastopen child socket. */ tcp_init_transfer(child, BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB, skb); @@ -295,6 +300,7 @@ static struct sock *tcp_fastopen_create_child(struct sock *sk, static bool tcp_fastopen_queue_check(struct sock *sk) { struct fastopen_queue *fastopenq; + int max_qlen; /* Make sure the listener has enabled fastopen, and we don't * exceed the max # of pending TFO requests allowed before trying @@ -307,10 +313,11 @@ static bool tcp_fastopen_queue_check(struct sock *sk) * temporarily vs a server not supporting Fast Open at all. */ fastopenq = &inet_csk(sk)->icsk_accept_queue.fastopenq; - if (fastopenq->max_qlen == 0) + max_qlen = READ_ONCE(fastopenq->max_qlen); + if (max_qlen == 0) return false; - if (fastopenq->qlen >= fastopenq->max_qlen) { + if (fastopenq->qlen >= max_qlen) { struct request_sock *req1; spin_lock(&fastopenq->lock); req1 = fastopenq->rskq_rst_head; @@ -332,7 +339,7 @@ static bool tcp_fastopen_no_cookie(const struct sock *sk, const struct dst_entry *dst, int flag) { - return (sock_net(sk)->ipv4.sysctl_tcp_fastopen & flag) || + return (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen) & flag) || tcp_sk(sk)->fastopen_no_cookie || (dst && dst_metric(dst, RTAX_FASTOPEN_NO_COOKIE)); } @@ -347,7 +354,7 @@ struct sock *tcp_try_fastopen(struct sock *sk, struct sk_buff *skb, const struct dst_entry *dst) { bool syn_data = TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(skb)->seq + 1; - int tcp_fastopen = sock_net(sk)->ipv4.sysctl_tcp_fastopen; + int tcp_fastopen = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen); struct tcp_fastopen_cookie valid_foc = { .len = -1 }; struct sock *child; int ret = 0; @@ -397,6 +404,7 @@ fastopen: } NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPFASTOPENPASSIVE); + tcp_sk(child)->syn_fastopen_child = 1; return child; } NET_INC_STATS(sock_net(sk), @@ -448,7 +456,7 @@ bool tcp_fastopen_defer_connect(struct sock *sk, int *err) if (tp->fastopen_connect && !tp->fastopen_req) { if (tcp_fastopen_cookie_check(sk, &mss, &cookie)) { - inet_sk(sk)->defer_connect = 1; + inet_set_bit(DEFER_CONNECT, sk); return true; } @@ -464,7 +472,7 @@ bool tcp_fastopen_defer_connect(struct sock *sk, int *err) } return false; } -EXPORT_SYMBOL(tcp_fastopen_defer_connect); +EXPORT_IPV6_MOD(tcp_fastopen_defer_connect); /* * The following code block is to deal with middle box issues with TFO: @@ -489,7 +497,7 @@ void tcp_fastopen_active_disable(struct sock *sk) { struct net *net = sock_net(sk); - if (!sock_net(sk)->ipv4.sysctl_tcp_fastopen_blackhole_timeout) + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen_blackhole_timeout)) return; /* Paired with READ_ONCE() in tcp_fastopen_active_should_disable() */ @@ -510,7 +518,8 @@ void tcp_fastopen_active_disable(struct sock *sk) */ bool tcp_fastopen_active_should_disable(struct sock *sk) { - unsigned int tfo_bh_timeout = sock_net(sk)->ipv4.sysctl_tcp_fastopen_blackhole_timeout; + unsigned int tfo_bh_timeout = + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen_blackhole_timeout); unsigned long timeout; int tfo_da_times; int multiplier; @@ -550,6 +559,7 @@ bool tcp_fastopen_active_should_disable(struct sock *sk) void tcp_fastopen_active_disable_ofo_check(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); + struct net_device *dev; struct dst_entry *dst; struct sk_buff *skb; @@ -566,10 +576,12 @@ void tcp_fastopen_active_disable_ofo_check(struct sock *sk) } } else if (tp->syn_fastopen_ch && atomic_read(&sock_net(sk)->ipv4.tfo_active_disable_times)) { - dst = sk_dst_get(sk); - if (!(dst && dst->dev && (dst->dev->flags & IFF_LOOPBACK))) + rcu_read_lock(); + dst = __sk_dst_get(sk); + dev = dst ? dst_dev_rcu(dst) : NULL; + if (!(dev && (dev->flags & IFF_LOOPBACK))) atomic_set(&sock_net(sk)->ipv4.tfo_active_disable_times, 0); - dst_release(dst); + rcu_read_unlock(); } } diff --git a/net/ipv4/tcp_highspeed.c b/net/ipv4/tcp_highspeed.c index 349069d6cd0a..c6de5ce79ad3 100644 --- a/net/ipv4/tcp_highspeed.c +++ b/net/ipv4/tcp_highspeed.c @@ -127,22 +127,22 @@ static void hstcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) * snd_cwnd <= * hstcp_aimd_vals[ca->ai].cwnd */ - if (tp->snd_cwnd > hstcp_aimd_vals[ca->ai].cwnd) { - while (tp->snd_cwnd > hstcp_aimd_vals[ca->ai].cwnd && + if (tcp_snd_cwnd(tp) > hstcp_aimd_vals[ca->ai].cwnd) { + while (tcp_snd_cwnd(tp) > hstcp_aimd_vals[ca->ai].cwnd && ca->ai < HSTCP_AIMD_MAX - 1) ca->ai++; - } else if (ca->ai && tp->snd_cwnd <= hstcp_aimd_vals[ca->ai-1].cwnd) { - while (ca->ai && tp->snd_cwnd <= hstcp_aimd_vals[ca->ai-1].cwnd) + } else if (ca->ai && tcp_snd_cwnd(tp) <= hstcp_aimd_vals[ca->ai-1].cwnd) { + while (ca->ai && tcp_snd_cwnd(tp) <= hstcp_aimd_vals[ca->ai-1].cwnd) ca->ai--; } /* Do additive increase */ - if (tp->snd_cwnd < tp->snd_cwnd_clamp) { + if (tcp_snd_cwnd(tp) < tp->snd_cwnd_clamp) { /* cwnd = cwnd + a(w) / cwnd */ tp->snd_cwnd_cnt += ca->ai + 1; - if (tp->snd_cwnd_cnt >= tp->snd_cwnd) { - tp->snd_cwnd_cnt -= tp->snd_cwnd; - tp->snd_cwnd++; + if (tp->snd_cwnd_cnt >= tcp_snd_cwnd(tp)) { + tp->snd_cwnd_cnt -= tcp_snd_cwnd(tp); + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); } } } @@ -154,7 +154,7 @@ static u32 hstcp_ssthresh(struct sock *sk) struct hstcp *ca = inet_csk_ca(sk); /* Do multiplicative decrease */ - return max(tp->snd_cwnd - ((tp->snd_cwnd * hstcp_aimd_vals[ca->ai].md) >> 8), 2U); + return max(tcp_snd_cwnd(tp) - ((tcp_snd_cwnd(tp) * hstcp_aimd_vals[ca->ai].md) >> 8), 2U); } static struct tcp_congestion_ops tcp_highspeed __read_mostly = { diff --git a/net/ipv4/tcp_htcp.c b/net/ipv4/tcp_htcp.c index 55adcfcf96fe..81b96331b2bb 100644 --- a/net/ipv4/tcp_htcp.c +++ b/net/ipv4/tcp_htcp.c @@ -124,7 +124,7 @@ static void measure_achieved_throughput(struct sock *sk, ca->packetcount += sample->pkts_acked; - if (ca->packetcount >= tp->snd_cwnd - (ca->alpha >> 7 ? : 1) && + if (ca->packetcount >= tcp_snd_cwnd(tp) - (ca->alpha >> 7 ? : 1) && now - ca->lasttime >= ca->minRTT && ca->minRTT > 0) { __u32 cur_Bi = ca->packetcount * HZ / (now - ca->lasttime); @@ -185,7 +185,7 @@ static inline void htcp_alpha_update(struct htcp *ca) u32 scale = (HZ << 3) / (10 * minRTT); /* clamping ratio to interval [0.5,10]<<3 */ - scale = min(max(scale, 1U << 2), 10U << 3); + scale = clamp(scale, 1U << 2, 10U << 3); factor = (factor << 3) / scale; if (!factor) factor = 1; @@ -225,7 +225,7 @@ static u32 htcp_recalc_ssthresh(struct sock *sk) const struct htcp *ca = inet_csk_ca(sk); htcp_param_update(sk); - return max((tp->snd_cwnd * ca->beta) >> 7, 2U); + return max((tcp_snd_cwnd(tp) * ca->beta) >> 7, 2U); } static void htcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) @@ -242,9 +242,9 @@ static void htcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) /* In dangerous area, increase slowly. * In theory this is tp->snd_cwnd += alpha / tp->snd_cwnd */ - if ((tp->snd_cwnd_cnt * ca->alpha)>>7 >= tp->snd_cwnd) { - if (tp->snd_cwnd < tp->snd_cwnd_clamp) - tp->snd_cwnd++; + if ((tp->snd_cwnd_cnt * ca->alpha)>>7 >= tcp_snd_cwnd(tp)) { + if (tcp_snd_cwnd(tp) < tp->snd_cwnd_clamp) + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); tp->snd_cwnd_cnt = 0; htcp_alpha_update(ca); } else diff --git a/net/ipv4/tcp_hybla.c b/net/ipv4/tcp_hybla.c index be39327e04e6..abd7d91807e5 100644 --- a/net/ipv4/tcp_hybla.c +++ b/net/ipv4/tcp_hybla.c @@ -54,7 +54,7 @@ static void hybla_init(struct sock *sk) ca->rho2_7ls = 0; ca->snd_cwnd_cents = 0; ca->hybla_en = true; - tp->snd_cwnd = 2; + tcp_snd_cwnd_set(tp, 2); tp->snd_cwnd_clamp = 65535; /* 1st Rho measurement based on initial srtt */ @@ -62,7 +62,7 @@ static void hybla_init(struct sock *sk) /* set minimum rtt as this is the 1st ever seen */ ca->minrtt_us = tp->srtt_us; - tp->snd_cwnd = ca->rho; + tcp_snd_cwnd_set(tp, ca->rho); } static void hybla_state(struct sock *sk, u8 ca_state) @@ -137,31 +137,31 @@ static void hybla_cong_avoid(struct sock *sk, u32 ack, u32 acked) * as long as increment is estimated as (rho<<7)/window * it already is <<7 and we can easily count its fractions. */ - increment = ca->rho2_7ls / tp->snd_cwnd; + increment = ca->rho2_7ls / tcp_snd_cwnd(tp); if (increment < 128) tp->snd_cwnd_cnt++; } odd = increment % 128; - tp->snd_cwnd += increment >> 7; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + (increment >> 7)); ca->snd_cwnd_cents += odd; /* check when fractions goes >=128 and increase cwnd by 1. */ while (ca->snd_cwnd_cents >= 128) { - tp->snd_cwnd++; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); ca->snd_cwnd_cents -= 128; tp->snd_cwnd_cnt = 0; } /* check when cwnd has not been incremented for a while */ - if (increment == 0 && odd == 0 && tp->snd_cwnd_cnt >= tp->snd_cwnd) { - tp->snd_cwnd++; + if (increment == 0 && odd == 0 && tp->snd_cwnd_cnt >= tcp_snd_cwnd(tp)) { + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); tp->snd_cwnd_cnt = 0; } /* clamp down slowstart cwnd to ssthresh value. */ if (is_slowstart) - tp->snd_cwnd = min(tp->snd_cwnd, tp->snd_ssthresh); + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), tp->snd_ssthresh)); - tp->snd_cwnd = min_t(u32, tp->snd_cwnd, tp->snd_cwnd_clamp); + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), tp->snd_cwnd_clamp)); } static struct tcp_congestion_ops tcp_hybla __read_mostly = { diff --git a/net/ipv4/tcp_illinois.c b/net/ipv4/tcp_illinois.c index 00e54873213e..c0c81a2c77fa 100644 --- a/net/ipv4/tcp_illinois.c +++ b/net/ipv4/tcp_illinois.c @@ -224,7 +224,7 @@ static void update_params(struct sock *sk) struct tcp_sock *tp = tcp_sk(sk); struct illinois *ca = inet_csk_ca(sk); - if (tp->snd_cwnd < win_thresh) { + if (tcp_snd_cwnd(tp) < win_thresh) { ca->alpha = ALPHA_BASE; ca->beta = BETA_BASE; } else if (ca->cnt_rtt > 0) { @@ -284,9 +284,9 @@ static void tcp_illinois_cong_avoid(struct sock *sk, u32 ack, u32 acked) * tp->snd_cwnd += alpha/tp->snd_cwnd */ delta = (tp->snd_cwnd_cnt * ca->alpha) >> ALPHA_SHIFT; - if (delta >= tp->snd_cwnd) { - tp->snd_cwnd = min(tp->snd_cwnd + delta / tp->snd_cwnd, - (u32)tp->snd_cwnd_clamp); + if (delta >= tcp_snd_cwnd(tp)) { + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp) + delta / tcp_snd_cwnd(tp), + (u32)tp->snd_cwnd_clamp)); tp->snd_cwnd_cnt = 0; } } @@ -296,9 +296,11 @@ static u32 tcp_illinois_ssthresh(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); struct illinois *ca = inet_csk_ca(sk); + u32 decr; /* Multiplicative decrease */ - return max(tp->snd_cwnd - ((tp->snd_cwnd * ca->beta) >> BETA_SHIFT), 2U); + decr = (tcp_snd_cwnd(tp) * ca->beta) >> BETA_SHIFT; + return max(tcp_snd_cwnd(tp) - decr, 2U); } /* Extract info for Tcp socket info provided via netlink. */ diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index dc49a3d551eb..198f8a0d37be 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -70,11 +70,14 @@ #include <linux/sysctl.h> #include <linux/kernel.h> #include <linux/prefetch.h> +#include <linux/bitops.h> #include <net/dst.h> #include <net/tcp.h> +#include <net/tcp_ecn.h> +#include <net/proto_memory.h> #include <net/inet_common.h> #include <linux/ipsec.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> #include <linux/errqueue.h> #include <trace/events/tcp.h> #include <linux/jump_label_ratelimit.h> @@ -101,6 +104,7 @@ int sysctl_tcp_max_orphans __read_mostly = NR_FILE; #define FLAG_NO_CHALLENGE_ACK 0x8000 /* do not call tcp_send_challenge_ack() */ #define FLAG_ACK_MAYBE_DELAYED 0x10000 /* Likely a delayed ACK */ #define FLAG_DSACK_TLP 0x20000 /* DSACK for tail loss probe */ +#define FLAG_TS_PROGRESS 0x40000 /* Positive timestamp delta */ #define FLAG_ACKED (FLAG_DATA_ACKED|FLAG_SYN_ACKED) #define FLAG_NOT_DUP (FLAG_DATA|FLAG_WIN_UPDATE|FLAG_ACKED) @@ -117,18 +121,18 @@ int sysctl_tcp_max_orphans __read_mostly = NR_FILE; #if IS_ENABLED(CONFIG_TLS_DEVICE) static DEFINE_STATIC_KEY_DEFERRED_FALSE(clean_acked_data_enabled, HZ); -void clean_acked_data_enable(struct inet_connection_sock *icsk, +void clean_acked_data_enable(struct tcp_sock *tp, void (*cad)(struct sock *sk, u32 ack_seq)) { - icsk->icsk_clean_acked = cad; + tp->tcp_clean_acked = cad; static_branch_deferred_inc(&clean_acked_data_enabled); } EXPORT_SYMBOL_GPL(clean_acked_data_enable); -void clean_acked_data_disable(struct inet_connection_sock *icsk) +void clean_acked_data_disable(struct tcp_sock *tp) { static_branch_slow_dec_deferred(&clean_acked_data_enabled); - icsk->icsk_clean_acked = NULL; + tp->tcp_clean_acked = NULL; } EXPORT_SYMBOL_GPL(clean_acked_data_disable); @@ -168,6 +172,7 @@ static void bpf_skops_parse_hdr(struct sock *sk, struct sk_buff *skb) memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); sock_ops.op = BPF_SOCK_OPS_PARSE_HDR_OPT_CB; sock_ops.is_fullsock = 1; + sock_ops.is_locked_tcp_sock = 1; sock_ops.sk = sk; bpf_skops_init_skb(&sock_ops, skb, tcp_hdrlen(skb)); @@ -184,6 +189,7 @@ static void bpf_skops_established(struct sock *sk, int bpf_op, memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); sock_ops.op = bpf_op; sock_ops.is_fullsock = 1; + sock_ops.is_locked_tcp_sock = 1; sock_ops.sk = sk; /* sk with TCP_REPAIR_ON does not have skb in tcp_finish_connect */ if (skb) @@ -202,23 +208,17 @@ static void bpf_skops_established(struct sock *sk, int bpf_op, } #endif -static void tcp_gro_dev_warn(struct sock *sk, const struct sk_buff *skb, - unsigned int len) +static __cold void tcp_gro_dev_warn(const struct sock *sk, const struct sk_buff *skb, + unsigned int len) { - static bool __once __read_mostly; + struct net_device *dev; - if (!__once) { - struct net_device *dev; - - __once = true; - - rcu_read_lock(); - dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif); - if (!dev || len >= dev->mtu) - pr_warn("%s: Driver has suspect GRO implementation, TCP performance may be compromised.\n", - dev ? dev->name : "Unknown driver"); - rcu_read_unlock(); - } + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif); + if (!dev || len >= READ_ONCE(dev->mtu)) + pr_warn("%s: Driver has suspect GRO implementation, TCP performance may be compromised.\n", + dev ? dev->name : "Unknown driver"); + rcu_read_unlock(); } /* Adapt the MSS value used to make delayed ack decision to the @@ -237,12 +237,45 @@ static void tcp_measure_rcv_mss(struct sock *sk, const struct sk_buff *skb) */ len = skb_shinfo(skb)->gso_size ? : skb->len; if (len >= icsk->icsk_ack.rcv_mss) { + /* Note: divides are still a bit expensive. + * For the moment, only adjust scaling_ratio + * when we update icsk_ack.rcv_mss. + */ + if (unlikely(len != icsk->icsk_ack.rcv_mss)) { + u64 val = (u64)skb->len << TCP_RMEM_TO_WIN_SCALE; + u8 old_ratio = tcp_sk(sk)->scaling_ratio; + + do_div(val, skb->truesize); + tcp_sk(sk)->scaling_ratio = val ? val : 1; + + if (old_ratio != tcp_sk(sk)->scaling_ratio) { + struct tcp_sock *tp = tcp_sk(sk); + + val = tcp_win_from_space(sk, sk->sk_rcvbuf); + tcp_set_window_clamp(sk, val); + + if (tp->window_clamp < tp->rcvq_space.space) + tp->rcvq_space.space = tp->window_clamp; + } + } icsk->icsk_ack.rcv_mss = min_t(unsigned int, len, tcp_sk(sk)->advmss); /* Account for possibly-removed options */ - if (unlikely(len > icsk->icsk_ack.rcv_mss + - MAX_TCP_OPTION_SPACE)) - tcp_gro_dev_warn(sk, skb, len); + DO_ONCE_LITE_IF(len > icsk->icsk_ack.rcv_mss + MAX_TCP_OPTION_SPACE, + tcp_gro_dev_warn, sk, skb, len); + /* If the skb has a len of exactly 1*MSS and has the PSH bit + * set then it is likely the end of an application write. So + * more data may not be arriving soon, and yet the data sender + * may be waiting for an ACK if cwnd-bound or using TX zero + * copy. So we set ICSK_ACK_PUSHED here so that + * tcp_cleanup_rbuf() will send an ACK immediately if the app + * reads all of the data and is not ping-pong. If len > MSS + * then this logic does not matter (and does not hurt) because + * tcp_cleanup_rbuf() will always ACK immediately if the app + * reads data and there is more than an MSS of unACKed data. + */ + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_PSH) + icsk->icsk_ack.pending |= ICSK_ACK_PUSHED; } else { /* Otherwise, we make more careful check taking into account, * that SACKs block is variable. @@ -287,7 +320,7 @@ static void tcp_incr_quickack(struct sock *sk, unsigned int max_quickacks) icsk->icsk_ack.quick = quickacks; } -void tcp_enter_quickack_mode(struct sock *sk, unsigned int max_quickacks) +static void tcp_enter_quickack_mode(struct sock *sk, unsigned int max_quickacks) { struct inet_connection_sock *icsk = inet_csk(sk); @@ -295,7 +328,6 @@ void tcp_enter_quickack_mode(struct sock *sk, unsigned int max_quickacks) inet_csk_exit_pingpong_mode(sk); icsk->icsk_ack.ato = TCP_ATO_MIN; } -EXPORT_SYMBOL(tcp_enter_quickack_mode); /* Send ACKs quickly, if "quick" count is not exhausted * and the session is not interactive. @@ -304,41 +336,18 @@ EXPORT_SYMBOL(tcp_enter_quickack_mode); static bool tcp_in_quickack_mode(struct sock *sk) { const struct inet_connection_sock *icsk = inet_csk(sk); - const struct dst_entry *dst = __sk_dst_get(sk); - return (dst && dst_metric(dst, RTAX_QUICKACK)) || + return icsk->icsk_ack.dst_quick_ack || (icsk->icsk_ack.quick && !inet_csk_in_pingpong_mode(sk)); } -static void tcp_ecn_queue_cwr(struct tcp_sock *tp) -{ - if (tp->ecn_flags & TCP_ECN_OK) - tp->ecn_flags |= TCP_ECN_QUEUE_CWR; -} - -static void tcp_ecn_accept_cwr(struct sock *sk, const struct sk_buff *skb) -{ - if (tcp_hdr(skb)->cwr) { - tcp_sk(sk)->ecn_flags &= ~TCP_ECN_DEMAND_CWR; - - /* If the sender is telling us it has entered CWR, then its - * cwnd may be very low (even just 1 packet), so we should ACK - * immediately. - */ - if (TCP_SKB_CB(skb)->seq != TCP_SKB_CB(skb)->end_seq) - inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; - } -} - -static void tcp_ecn_withdraw_cwr(struct tcp_sock *tp) -{ - tp->ecn_flags &= ~TCP_ECN_QUEUE_CWR; -} - -static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) +static void tcp_data_ecn_check(struct sock *sk, const struct sk_buff *skb) { struct tcp_sock *tp = tcp_sk(sk); + if (tcp_ecn_disabled(tp)) + return; + switch (TCP_SKB_CB(skb)->ip_dsfield & INET_ECN_MASK) { case INET_ECN_NOT_ECT: /* Funny extension: if ECT is not set on a segment, @@ -352,44 +361,222 @@ static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) if (tcp_ca_needs_ecn(sk)) tcp_ca_event(sk, CA_EVENT_ECN_IS_CE); - if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR)) { + if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR) && + tcp_ecn_mode_rfc3168(tp)) { /* Better not delay acks, sender can have a very low cwnd */ tcp_enter_quickack_mode(sk, 2); tp->ecn_flags |= TCP_ECN_DEMAND_CWR; } + /* As for RFC3168 ECN, the TCP_ECN_SEEN flag is set by + * tcp_data_ecn_check() when the ECN codepoint of + * received TCP data contains ECT(0), ECT(1), or CE. + */ + if (!tcp_ecn_mode_rfc3168(tp)) + break; tp->ecn_flags |= TCP_ECN_SEEN; break; default: if (tcp_ca_needs_ecn(sk)) tcp_ca_event(sk, CA_EVENT_ECN_NO_CE); + if (!tcp_ecn_mode_rfc3168(tp)) + break; tp->ecn_flags |= TCP_ECN_SEEN; break; } } -static void tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) +/* Returns true if the byte counters can be used */ +static bool tcp_accecn_process_option(struct tcp_sock *tp, + const struct sk_buff *skb, + u32 delivered_bytes, int flag) { - if (tcp_sk(sk)->ecn_flags & TCP_ECN_OK) - __tcp_ecn_check_ce(sk, skb); + u8 estimate_ecnfield = tp->est_ecnfield; + bool ambiguous_ecn_bytes_incr = false; + bool first_changed = false; + unsigned int optlen; + bool order1, res; + unsigned int i; + u8 *ptr; + + if (tcp_accecn_opt_fail_recv(tp)) + return false; + + if (!(flag & FLAG_SLOWPATH) || !tp->rx_opt.accecn) { + if (!tp->saw_accecn_opt) { + /* Too late to enable after this point due to + * potential counter wraps + */ + if (tp->bytes_sent >= (1 << 23) - 1) { + u8 saw_opt = TCP_ACCECN_OPT_FAIL_SEEN; + + tcp_accecn_saw_opt_fail_recv(tp, saw_opt); + } + return false; + } + + if (estimate_ecnfield) { + u8 ecnfield = estimate_ecnfield - 1; + + tp->delivered_ecn_bytes[ecnfield] += delivered_bytes; + return true; + } + return false; + } + + ptr = skb_transport_header(skb) + tp->rx_opt.accecn; + optlen = ptr[1] - 2; + if (WARN_ON_ONCE(ptr[0] != TCPOPT_ACCECN0 && ptr[0] != TCPOPT_ACCECN1)) + return false; + order1 = (ptr[0] == TCPOPT_ACCECN1); + ptr += 2; + + if (tp->saw_accecn_opt < TCP_ACCECN_OPT_COUNTER_SEEN) { + tp->saw_accecn_opt = tcp_accecn_option_init(skb, + tp->rx_opt.accecn); + if (tp->saw_accecn_opt == TCP_ACCECN_OPT_FAIL_SEEN) + tcp_accecn_fail_mode_set(tp, TCP_ACCECN_OPT_FAIL_RECV); + } + + res = !!estimate_ecnfield; + for (i = 0; i < 3; i++) { + u32 init_offset; + u8 ecnfield; + s32 delta; + u32 *cnt; + + if (optlen < TCPOLEN_ACCECN_PERFIELD) + break; + + ecnfield = tcp_accecn_optfield_to_ecnfield(i, order1); + init_offset = tcp_accecn_field_init_offset(ecnfield); + cnt = &tp->delivered_ecn_bytes[ecnfield - 1]; + delta = tcp_update_ecn_bytes(cnt, ptr, init_offset); + if (delta && delta < 0) { + res = false; + ambiguous_ecn_bytes_incr = true; + } + if (delta && ecnfield != estimate_ecnfield) { + if (!first_changed) { + tp->est_ecnfield = ecnfield; + first_changed = true; + } else { + res = false; + ambiguous_ecn_bytes_incr = true; + } + } + + optlen -= TCPOLEN_ACCECN_PERFIELD; + ptr += TCPOLEN_ACCECN_PERFIELD; + } + if (ambiguous_ecn_bytes_incr) + tp->est_ecnfield = 0; + + return res; } -static void tcp_ecn_rcv_synack(struct tcp_sock *tp, const struct tcphdr *th) +static void tcp_count_delivered_ce(struct tcp_sock *tp, u32 ecn_count) { - if ((tp->ecn_flags & TCP_ECN_OK) && (!th->ece || th->cwr)) - tp->ecn_flags &= ~TCP_ECN_OK; + tp->delivered_ce += ecn_count; } -static void tcp_ecn_rcv_syn(struct tcp_sock *tp, const struct tcphdr *th) +/* Updates the delivered and delivered_ce counts */ +static void tcp_count_delivered(struct tcp_sock *tp, u32 delivered, + bool ece_ack) { - if ((tp->ecn_flags & TCP_ECN_OK) && (!th->ece || !th->cwr)) - tp->ecn_flags &= ~TCP_ECN_OK; + tp->delivered += delivered; + if (tcp_ecn_mode_rfc3168(tp) && ece_ack) + tcp_count_delivered_ce(tp, delivered); } -static bool tcp_ecn_rcv_ecn_echo(const struct tcp_sock *tp, const struct tcphdr *th) +/* Returns the ECN CE delta */ +static u32 __tcp_accecn_process(struct sock *sk, const struct sk_buff *skb, + u32 delivered_pkts, u32 delivered_bytes, + int flag) { - if (th->ece && !th->syn && (tp->ecn_flags & TCP_ECN_OK)) - return true; - return false; + u32 old_ceb = tcp_sk(sk)->delivered_ecn_bytes[INET_ECN_CE - 1]; + const struct tcphdr *th = tcp_hdr(skb); + struct tcp_sock *tp = tcp_sk(sk); + u32 delta, safe_delta, d_ceb; + bool opt_deltas_valid; + u32 corrected_ace; + + /* Reordered ACK or uncertain due to lack of data to send and ts */ + if (!(flag & (FLAG_FORWARD_PROGRESS | FLAG_TS_PROGRESS))) + return 0; + + opt_deltas_valid = tcp_accecn_process_option(tp, skb, + delivered_bytes, flag); + + if (!(flag & FLAG_SLOWPATH)) { + /* AccECN counter might overflow on large ACKs */ + if (delivered_pkts <= TCP_ACCECN_CEP_ACE_MASK) + return 0; + } + + /* ACE field is not available during handshake */ + if (flag & FLAG_SYN_ACKED) + return 0; + + if (tp->received_ce_pending >= TCP_ACCECN_ACE_MAX_DELTA) + inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; + + corrected_ace = tcp_accecn_ace(th) - TCP_ACCECN_CEP_INIT_OFFSET; + delta = (corrected_ace - tp->delivered_ce) & TCP_ACCECN_CEP_ACE_MASK; + if (delivered_pkts <= TCP_ACCECN_CEP_ACE_MASK) + return delta; + + safe_delta = delivered_pkts - + ((delivered_pkts - delta) & TCP_ACCECN_CEP_ACE_MASK); + + if (opt_deltas_valid) { + d_ceb = tp->delivered_ecn_bytes[INET_ECN_CE - 1] - old_ceb; + if (!d_ceb) + return delta; + + if ((delivered_pkts >= (TCP_ACCECN_CEP_ACE_MASK + 1) * 2) && + (tcp_is_sack(tp) || + ((1 << inet_csk(sk)->icsk_ca_state) & + (TCPF_CA_Open | TCPF_CA_CWR)))) { + u32 est_d_cep; + + if (delivered_bytes <= d_ceb) + return safe_delta; + + est_d_cep = DIV_ROUND_UP_ULL((u64)d_ceb * + delivered_pkts, + delivered_bytes); + return min(safe_delta, + delta + + (est_d_cep & ~TCP_ACCECN_CEP_ACE_MASK)); + } + + if (d_ceb > delta * tp->mss_cache) + return safe_delta; + if (d_ceb < + safe_delta * tp->mss_cache >> TCP_ACCECN_SAFETY_SHIFT) + return delta; + } + + return safe_delta; +} + +static u32 tcp_accecn_process(struct sock *sk, const struct sk_buff *skb, + u32 delivered_pkts, u32 delivered_bytes, + int *flag) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 delta; + + delta = __tcp_accecn_process(sk, skb, delivered_pkts, + delivered_bytes, *flag); + if (delta > 0) { + tcp_count_delivered_ce(tp, delta); + *flag |= FLAG_ECE; + /* Recalculate header predictor */ + if (tp->pred_flags) + tcp_fast_path_on(tp); + } + return delta; } /* Buffer size and advertised window tuning. @@ -414,7 +601,7 @@ static void tcp_sndbuf_expand(struct sock *sk) per_mss = roundup_pow_of_two(per_mss) + SKB_DATA_ALIGN(sizeof(struct sk_buff)); - nr_segs = max_t(u32, TCP_INIT_CWND, tp->snd_cwnd); + nr_segs = max_t(u32, TCP_INIT_CWND, tcp_snd_cwnd(tp)); nr_segs = max_t(u32, nr_segs, tp->reordering + 1); /* Fast Recovery (RFC 5681 3.2) : @@ -426,7 +613,7 @@ static void tcp_sndbuf_expand(struct sock *sk) if (sk->sk_sndbuf < sndmem) WRITE_ONCE(sk->sk_sndbuf, - min(sndmem, sock_net(sk)->ipv4.sysctl_tcp_wmem[2])); + min(sndmem, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[2]))); } /* 2. Tuning advertised window (window_clamp, rcv_ssthresh) @@ -458,10 +645,10 @@ static void tcp_sndbuf_expand(struct sock *sk) static int __tcp_grow_window(const struct sock *sk, const struct sk_buff *skb, unsigned int skbtruesize) { - struct tcp_sock *tp = tcp_sk(sk); + const struct tcp_sock *tp = tcp_sk(sk); /* Optimize this! */ int truesize = tcp_win_from_space(sk, skbtruesize) >> 1; - int window = tcp_win_from_space(sk, sock_net(sk)->ipv4.sysctl_tcp_rmem[2]) >> 1; + int window = tcp_win_from_space(sk, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2])) >> 1; while (tp->rcv_ssthresh <= window) { if (truesize <= skb->len) @@ -534,7 +721,7 @@ static void tcp_grow_window(struct sock *sk, const struct sk_buff *skb, */ static void tcp_init_buffer_space(struct sock *sk) { - int tcp_app_win = sock_net(sk)->ipv4.sysctl_tcp_app_win; + int tcp_app_win = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_app_win); struct tcp_sock *tp = tcp_sk(sk); int maxwin; @@ -548,19 +735,20 @@ static void tcp_init_buffer_space(struct sock *sk) maxwin = tcp_full_space(sk); if (tp->window_clamp >= maxwin) { - tp->window_clamp = maxwin; + WRITE_ONCE(tp->window_clamp, maxwin); if (tcp_app_win && maxwin > 4 * tp->advmss) - tp->window_clamp = max(maxwin - - (maxwin >> tcp_app_win), - 4 * tp->advmss); + WRITE_ONCE(tp->window_clamp, + max(maxwin - (maxwin >> tcp_app_win), + 4 * tp->advmss)); } /* Force reservation of one segment. */ if (tcp_app_win && tp->window_clamp > 2 * tp->advmss && tp->window_clamp + tp->advmss > maxwin) - tp->window_clamp = max(2 * tp->advmss, maxwin - tp->advmss); + WRITE_ONCE(tp->window_clamp, + max(2 * tp->advmss, maxwin - tp->advmss)); tp->rcv_ssthresh = min(tp->rcv_ssthresh, tp->window_clamp); tp->snd_cwnd_stamp = tcp_jiffies32; @@ -574,16 +762,17 @@ static void tcp_clamp_window(struct sock *sk) struct tcp_sock *tp = tcp_sk(sk); struct inet_connection_sock *icsk = inet_csk(sk); struct net *net = sock_net(sk); + int rmem2; icsk->icsk_ack.quick = 0; + rmem2 = READ_ONCE(net->ipv4.sysctl_tcp_rmem[2]); - if (sk->sk_rcvbuf < net->ipv4.sysctl_tcp_rmem[2] && + if (sk->sk_rcvbuf < rmem2 && !(sk->sk_userlocks & SOCK_RCVBUF_LOCK) && !tcp_under_memory_pressure(sk) && sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)) { WRITE_ONCE(sk->sk_rcvbuf, - min(atomic_read(&sk->sk_rmem_alloc), - net->ipv4.sysctl_tcp_rmem[2])); + min(atomic_read(&sk->sk_rmem_alloc), rmem2)); } if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf) tp->rcv_ssthresh = min(tp->window_clamp, 2U * tp->advmss); @@ -607,7 +796,7 @@ void tcp_initialize_rcv_mss(struct sock *sk) inet_csk(sk)->icsk_ack.rcv_mss = hint; } -EXPORT_SYMBOL(tcp_initialize_rcv_mss); +EXPORT_IPV6_MOD(tcp_initialize_rcv_mss); /* Receiver "autotuning" code. * @@ -622,10 +811,12 @@ EXPORT_SYMBOL(tcp_initialize_rcv_mss); */ static void tcp_rcv_rtt_update(struct tcp_sock *tp, u32 sample, int win_dep) { - u32 new_sample = tp->rcv_rtt_est.rtt_us; - long m = sample; + u32 new_sample, old_sample = tp->rcv_rtt_est.rtt_us; + long m = sample << 3; - if (new_sample != 0) { + if (old_sample == 0 || m < old_sample) { + new_sample = m; + } else { /* If we sample in larger samples in the non-timestamp * case, we could grossly overestimate the RTT especially * with chatty applications or bulk transfer apps which @@ -636,17 +827,12 @@ static void tcp_rcv_rtt_update(struct tcp_sock *tp, u32 sample, int win_dep) * else with timestamps disabled convergence takes too * long. */ - if (!win_dep) { - m -= (new_sample >> 3); - new_sample += m; - } else { - m <<= 3; - if (m < new_sample) - new_sample = m; - } - } else { - /* No previous measure. */ - new_sample = m << 3; + if (win_dep) + return; + /* Do not use this sample if receive queue is not empty. */ + if (tp->rcv_nxt != tp->copied_seq) + return; + new_sample = old_sample - (old_sample >> 3) + sample; } tp->rcv_rtt_est.rtt_us = new_sample; @@ -670,6 +856,23 @@ new_measure: tp->rcv_rtt_est.time = tp->tcp_mstamp; } +static s32 tcp_rtt_tsopt_us(const struct tcp_sock *tp, u32 min_delta) +{ + u32 delta, delta_us; + + delta = tcp_time_stamp_ts(tp) - tp->rx_opt.rcv_tsecr; + if (tp->tcp_usec_ts) + return delta; + + if (likely(delta < INT_MAX / (USEC_PER_SEC / TCP_TS_HZ))) { + if (!delta) + delta = min_delta; + delta_us = delta * (USEC_PER_SEC / TCP_TS_HZ); + return delta_us; + } + return -1; +} + static inline void tcp_rcv_rtt_measure_ts(struct sock *sk, const struct sk_buff *skb) { @@ -681,18 +884,58 @@ static inline void tcp_rcv_rtt_measure_ts(struct sock *sk, if (TCP_SKB_CB(skb)->end_seq - TCP_SKB_CB(skb)->seq >= inet_csk(sk)->icsk_ack.rcv_mss) { - u32 delta = tcp_time_stamp(tp) - tp->rx_opt.rcv_tsecr; - u32 delta_us; - - if (likely(delta < INT_MAX / (USEC_PER_SEC / TCP_TS_HZ))) { - if (!delta) - delta = 1; - delta_us = delta * (USEC_PER_SEC / TCP_TS_HZ); - tcp_rcv_rtt_update(tp, delta_us, 0); - } + s32 delta = tcp_rtt_tsopt_us(tp, 0); + + if (delta > 0) + tcp_rcv_rtt_update(tp, delta, 0); } } +void tcp_rcvbuf_grow(struct sock *sk, u32 newval) +{ + const struct net *net = sock_net(sk); + struct tcp_sock *tp = tcp_sk(sk); + u32 rcvwin, rcvbuf, cap, oldval; + u32 rtt_threshold, rtt_us; + u64 grow; + + oldval = tp->rcvq_space.space; + tp->rcvq_space.space = newval; + + if (!READ_ONCE(net->ipv4.sysctl_tcp_moderate_rcvbuf) || + (sk->sk_userlocks & SOCK_RCVBUF_LOCK)) + return; + + /* DRS is always one RTT late. */ + rcvwin = newval << 1; + + rtt_us = tp->rcv_rtt_est.rtt_us >> 3; + rtt_threshold = READ_ONCE(net->ipv4.sysctl_tcp_rcvbuf_low_rtt); + if (rtt_us < rtt_threshold) { + /* For small RTT, we set @grow to rcvwin * rtt_us/rtt_threshold. + * It might take few additional ms to reach 'line rate', + * but will avoid sk_rcvbuf inflation and poor cache use. + */ + grow = div_u64((u64)rcvwin * rtt_us, rtt_threshold); + } else { + /* slow start: allow the sender to double its rate. */ + grow = div_u64(((u64)rcvwin << 1) * (newval - oldval), oldval); + } + rcvwin += grow; + + if (!RB_EMPTY_ROOT(&tp->out_of_order_queue)) + rcvwin += TCP_SKB_CB(tp->ooo_last_skb)->end_seq - tp->rcv_nxt; + + cap = READ_ONCE(net->ipv4.sysctl_tcp_rmem[2]); + + rcvbuf = min_t(u32, tcp_space_from_win(sk, rcvwin), cap); + if (rcvbuf > sk->sk_rcvbuf) { + WRITE_ONCE(sk->sk_rcvbuf, rcvbuf); + /* Make the window clamp follow along. */ + WRITE_ONCE(tp->window_clamp, + tcp_win_from_space(sk, rcvbuf)); + } +} /* * This function should be called every time data is copied to user space. * It calculates the appropriate TCP receive buffer space. @@ -700,66 +943,48 @@ static inline void tcp_rcv_rtt_measure_ts(struct sock *sk, void tcp_rcv_space_adjust(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); - u32 copied; - int time; + int time, inq, copied; trace_tcp_rcv_space_adjust(sk); - tcp_mstamp_refresh(tp); + if (unlikely(!tp->rcv_rtt_est.rtt_us)) + return; + + /* We do not refresh tp->tcp_mstamp here. + * Some platforms have expensive ktime_get() implementations. + * Using the last cached value is enough for DRS. + */ time = tcp_stamp_us_delta(tp->tcp_mstamp, tp->rcvq_space.time); - if (time < (tp->rcv_rtt_est.rtt_us >> 3) || tp->rcv_rtt_est.rtt_us == 0) + if (time < (tp->rcv_rtt_est.rtt_us >> 3)) return; /* Number of bytes copied to user in last RTT */ copied = tp->copied_seq - tp->rcvq_space.seq; + /* Number of bytes in receive queue. */ + inq = tp->rcv_nxt - tp->copied_seq; + copied -= inq; if (copied <= tp->rcvq_space.space) goto new_measure; - /* A bit of theory : - * copied = bytes received in previous RTT, our base window - * To cope with packet losses, we need a 2x factor - * To cope with slow start, and sender growing its cwin by 100 % - * every RTT, we need a 4x factor, because the ACK we are sending - * now is for the next RTT, not the current one : - * <prev RTT . ><current RTT .. ><next RTT .... > - */ + trace_tcp_rcvbuf_grow(sk, time); - if (sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf && - !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) { - int rcvmem, rcvbuf; - u64 rcvwin, grow; - - /* minimal window to cope with packet losses, assuming - * steady state. Add some cushion because of small variations. - */ - rcvwin = ((u64)copied << 1) + 16 * tp->advmss; - - /* Accommodate for sender rate increase (eg. slow start) */ - grow = rcvwin * (copied - tp->rcvq_space.space); - do_div(grow, tp->rcvq_space.space); - rcvwin += (grow << 1); - - rcvmem = SKB_TRUESIZE(tp->advmss + MAX_TCP_HEADER); - while (tcp_win_from_space(sk, rcvmem) < tp->advmss) - rcvmem += 128; - - do_div(rcvwin, tp->advmss); - rcvbuf = min_t(u64, rcvwin * rcvmem, - sock_net(sk)->ipv4.sysctl_tcp_rmem[2]); - if (rcvbuf > sk->sk_rcvbuf) { - WRITE_ONCE(sk->sk_rcvbuf, rcvbuf); - - /* Make the window clamp follow along. */ - tp->window_clamp = tcp_win_from_space(sk, rcvbuf); - } - } - tp->rcvq_space.space = copied; + tcp_rcvbuf_grow(sk, copied); new_measure: tp->rcvq_space.seq = tp->copied_seq; tp->rcvq_space.time = tp->tcp_mstamp; } +static void tcp_save_lrcv_flowlabel(struct sock *sk, const struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct inet_connection_sock *icsk = inet_csk(sk); + + if (skb->protocol == htons(ETH_P_IPV6)) + icsk->icsk_ack.lrcv_flowlabel = ntohl(ip6_flowlabel(ipv6_hdr(skb))); +#endif +} + /* There is something which you must keep in mind when you analyze the * behavior of the tp->ato delayed ack timeout interval. When a * connection starts up, we want to ack as quickly as possible. The @@ -805,12 +1030,12 @@ static void tcp_event_data_recv(struct sock *sk, struct sk_buff *skb) * restart window, so that we send ACKs quickly. */ tcp_incr_quickack(sk, TCP_MAX_QUICKACKS); - sk_mem_reclaim(sk); } } icsk->icsk_ack.lrcvtime = now; + tcp_save_lrcv_flowlabel(sk, skb); - tcp_ecn_check_ce(sk, skb); + tcp_data_ecn_check(sk, skb); if (skb->len >= 128) tcp_grow_window(sk, skb, true); @@ -878,7 +1103,7 @@ static void tcp_rtt_estimator(struct sock *sk, long mrtt_us) tp->rtt_seq = tp->snd_nxt; tp->mdev_max_us = tcp_rto_min_us(sk); - tcp_bpf_rtt(sk); + tcp_bpf_rtt(sk, mrtt_us, srtt); } } else { /* no previous measure. */ @@ -888,12 +1113,12 @@ static void tcp_rtt_estimator(struct sock *sk, long mrtt_us) tp->mdev_max_us = tp->rttvar_us; tp->rtt_seq = tp->snd_nxt; - tcp_bpf_rtt(sk); + tcp_bpf_rtt(sk, mrtt_us, srtt); } tp->srtt_us = max(1U, srtt); } -static void tcp_update_pacing_rate(struct sock *sk) +void tcp_update_pacing_rate(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); u64 rate; @@ -909,12 +1134,12 @@ static void tcp_update_pacing_rate(struct sock *sk) * If snd_cwnd >= (tp->snd_ssthresh / 2), we are approaching * end of slow start and should slow down. */ - if (tp->snd_cwnd < tp->snd_ssthresh / 2) - rate *= sock_net(sk)->ipv4.sysctl_tcp_pacing_ss_ratio; + if (tcp_snd_cwnd(tp) < tp->snd_ssthresh / 2) + rate *= READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_pacing_ss_ratio); else - rate *= sock_net(sk)->ipv4.sysctl_tcp_pacing_ca_ratio; + rate *= READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_pacing_ca_ratio); - rate *= max(tp->snd_cwnd, tp->packets_out); + rate *= max(tcp_snd_cwnd(tp), tp->packets_out); if (likely(tp->srtt_us)) do_div(rate, tp->srtt_us); @@ -923,14 +1148,14 @@ static void tcp_update_pacing_rate(struct sock *sk) * without any lock. We want to make sure compiler wont store * intermediate values in this location. */ - WRITE_ONCE(sk->sk_pacing_rate, min_t(u64, rate, - sk->sk_max_pacing_rate)); + WRITE_ONCE(sk->sk_pacing_rate, + min_t(u64, rate, READ_ONCE(sk->sk_max_pacing_rate))); } /* Calculate rto without backoff. This is the second half of Van Jacobson's * routine referred to above. */ -static void tcp_set_rto(struct sock *sk) +void tcp_set_rto(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); /* Old crap is replaced with new one. 8) @@ -975,6 +1200,7 @@ struct tcp_sacktag_state { u64 last_sackt; u32 reord; u32 sack_delivered; + u32 delivered_bytes; int flag; unsigned int mss_now; struct rate_sample *rate; @@ -1051,7 +1277,7 @@ static void tcp_check_sack_reordering(struct sock *sk, const u32 low_seq, tp->undo_marker ? tp->undo_retrans : 0); #endif tp->reordering = min_t(u32, (metric + mss - 1) / mss, - sock_net(sk)->ipv4.sysctl_tcp_max_reordering); + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_max_reordering)); } /* This exciting event is worth to be remembered. 8) */ @@ -1107,15 +1333,6 @@ void tcp_mark_skb_lost(struct sock *sk, struct sk_buff *skb) } } -/* Updates the delivered and delivered_ce counts */ -static void tcp_count_delivered(struct tcp_sock *tp, u32 delivered, - bool ece_ack) -{ - tp->delivered += delivered; - if (ece_ack) - tp->delivered_ce += delivered; -} - /* This procedure tags the retransmission queue when SACKs arrive. * * We have three tag bits: SACKED(S), RETRANS(R) and LOST(L). @@ -1131,7 +1348,7 @@ static void tcp_count_delivered(struct tcp_sock *tp, u32 delivered, * L|R 1 - orig is lost, retransmit is in flight. * S|R 1 - orig reached receiver, retrans is still in flight. * (L|S|R is logically valid, it could occur when L|R is sacked, - * but it is equivalent to plain S and code short-curcuits it to S. + * but it is equivalent to plain S and code short-circuits it to S. * L|S is logically invalid, it would mean -1 packet in flight 8)) * * These 6 states form finite state machine, controlled by the following events: @@ -1345,7 +1562,7 @@ static int tcp_match_skb_to_sack(struct sock *sk, struct sk_buff *skb, static u8 tcp_sacktag_one(struct sock *sk, struct tcp_sacktag_state *state, u8 sacked, u32 start_seq, u32 end_seq, - int dup_sack, int pcount, + int dup_sack, int pcount, u32 plen, u64 xmit_time) { struct tcp_sock *tp = tcp_sk(sk); @@ -1405,11 +1622,7 @@ static u8 tcp_sacktag_one(struct sock *sk, tp->sacked_out += pcount; /* Out-of-order packets delivered */ state->sack_delivered += pcount; - - /* Lost marker hint past SACKed? Tweak RFC3517 cnt */ - if (tp->lost_skb_hint && - before(start_seq, TCP_SKB_CB(tp->lost_skb_hint)->seq)) - tp->lost_cnt_hint += pcount; + state->delivered_bytes += plen; } /* D-SACK. We can detect redundant retransmission in S|R and plain R @@ -1446,13 +1659,10 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, * tcp_highest_sack_seq() when skb is highest_sack. */ tcp_sacktag_one(sk, state, TCP_SKB_CB(skb)->sacked, - start_seq, end_seq, dup_sack, pcount, + start_seq, end_seq, dup_sack, pcount, skb->len, tcp_skb_timestamp_us(skb)); tcp_rate_skb_delivered(sk, skb, state->rate); - if (skb == tp->lost_skb_hint) - tp->lost_cnt_hint += pcount; - TCP_SKB_CB(prev)->end_seq += shifted; TCP_SKB_CB(skb)->seq += shifted; @@ -1485,10 +1695,6 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, if (skb == tp->retransmit_skb_hint) tp->retransmit_skb_hint = prev; - if (skb == tp->lost_skb_hint) { - tp->lost_skb_hint = prev; - tp->lost_cnt_hint -= tcp_skb_pcount(prev); - } TCP_SKB_CB(prev)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags; TCP_SKB_CB(prev)->eor = TCP_SKB_CB(skb)->eor; @@ -1660,6 +1866,8 @@ static struct sk_buff *tcp_shift_skb_data(struct sock *sk, struct sk_buff *skb, (mss != tcp_skb_seglen(skb))) goto out; + if (!tcp_skb_can_collapse(prev, skb)) + goto out; len = skb->len; pcount = tcp_skb_pcount(skb); if (tcp_skb_shift(prev, skb, pcount, len)) @@ -1736,6 +1944,7 @@ static struct sk_buff *tcp_sacktag_walk(struct sk_buff *skb, struct sock *sk, TCP_SKB_CB(skb)->end_seq, dup_sack, tcp_skb_pcount(skb), + skb->len, tcp_skb_timestamp_us(skb)); tcp_rate_skb_delivered(sk, skb, state->rate); if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) @@ -2028,7 +2237,7 @@ static void tcp_check_reno_reordering(struct sock *sk, const int addend) return; tp->reordering = min_t(u32, tp->packets_out + addend, - sock_net(sk)->ipv4.sysctl_tcp_max_reordering); + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_max_reordering)); tp->reord_seen++; NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRENOREORDER); } @@ -2082,18 +2291,25 @@ void tcp_clear_retrans(struct tcp_sock *tp) tp->undo_marker = 0; tp->undo_retrans = -1; tp->sacked_out = 0; + tp->rto_stamp = 0; + tp->total_rto = 0; + tp->total_rto_recoveries = 0; + tp->total_rto_time = 0; } static inline void tcp_init_undo(struct tcp_sock *tp) { tp->undo_marker = tp->snd_una; - /* Retransmission still in flight may cause DSACKs later. */ - tp->undo_retrans = tp->retrans_out ? : -1; -} -static bool tcp_is_rack(const struct sock *sk) -{ - return sock_net(sk)->ipv4.sysctl_tcp_recovery & TCP_RACK_LOSS_DETECTION; + /* Retransmission still in flight may cause DSACKs later. */ + /* First, account for regular retransmits in flight: */ + tp->undo_retrans = tp->retrans_out; + /* Next, account for TLP retransmits in flight: */ + if (tp->tlp_high_seq && tp->tlp_retrans) + tp->undo_retrans++; + /* Finally, avoid 0, because undo_retrans==0 means "can undo now": */ + if (!tp->undo_retrans) + tp->undo_retrans = -1; } /* If we detect SACK reneging, forget all SACK information @@ -2121,8 +2337,7 @@ static void tcp_timeout_mark_lost(struct sock *sk) skb_rbtree_walk_from(skb) { if (is_reneg) TCP_SKB_CB(skb)->sacked &= ~TCPCB_SACKED_ACKED; - else if (tcp_is_rack(sk) && skb != head && - tcp_rack_skb_timeout(tp, skb, 0) > 0) + else if (skb != head && tcp_rack_skb_timeout(tp, skb, 0) > 0) continue; /* Don't mark recently sent ones lost yet */ tcp_mark_skb_lost(sk, skb); } @@ -2137,6 +2352,7 @@ void tcp_enter_loss(struct sock *sk) struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); bool new_recovery = icsk->icsk_ca_state < TCP_CA_Recovery; + u8 reordering; tcp_timeout_mark_lost(sk); @@ -2145,31 +2361,34 @@ void tcp_enter_loss(struct sock *sk) !after(tp->high_seq, tp->snd_una) || (icsk->icsk_ca_state == TCP_CA_Loss && !icsk->icsk_retransmits)) { tp->prior_ssthresh = tcp_current_ssthresh(sk); - tp->prior_cwnd = tp->snd_cwnd; + tp->prior_cwnd = tcp_snd_cwnd(tp); tp->snd_ssthresh = icsk->icsk_ca_ops->ssthresh(sk); tcp_ca_event(sk, CA_EVENT_LOSS); tcp_init_undo(tp); } - tp->snd_cwnd = tcp_packets_in_flight(tp) + 1; + tcp_snd_cwnd_set(tp, tcp_packets_in_flight(tp) + 1); tp->snd_cwnd_cnt = 0; tp->snd_cwnd_stamp = tcp_jiffies32; /* Timeout in disordered state after receiving substantial DUPACKs * suggests that the degree of reordering is over-estimated. */ + reordering = READ_ONCE(net->ipv4.sysctl_tcp_reordering); if (icsk->icsk_ca_state <= TCP_CA_Disorder && - tp->sacked_out >= net->ipv4.sysctl_tcp_reordering) + tp->sacked_out >= reordering) tp->reordering = min_t(unsigned int, tp->reordering, - net->ipv4.sysctl_tcp_reordering); + reordering); + tcp_set_ca_state(sk, TCP_CA_Loss); tp->high_seq = tp->snd_nxt; + tp->tlp_high_seq = 0; tcp_ecn_queue_cwr(tp); /* F-RTO RFC5682 sec 3.1 step 1: retransmit SND.UNA if no previous * loss recovery is underway except recurring timeout(s) on * the same SND.UNA (sec 3.2). Disable F-RTO on path MTU probing */ - tp->frto = net->ipv4.sysctl_tcp_frto && + tp->frto = READ_ONCE(net->ipv4.sysctl_tcp_frto) && (new_recovery || icsk->icsk_retransmits) && !inet_csk(sk)->icsk_mtup.probe_size; } @@ -2184,36 +2403,21 @@ void tcp_enter_loss(struct sock *sk) * restore sanity to the SACK scoreboard. If the apparent reneging * persists until this RTO then we'll clear the SACK scoreboard. */ -static bool tcp_check_sack_reneging(struct sock *sk, int flag) +static bool tcp_check_sack_reneging(struct sock *sk, int *ack_flag) { - if (flag & FLAG_SACK_RENEGING) { + if (*ack_flag & FLAG_SACK_RENEGING && + *ack_flag & FLAG_SND_UNA_ADVANCED) { struct tcp_sock *tp = tcp_sk(sk); unsigned long delay = max(usecs_to_jiffies(tp->srtt_us >> 4), msecs_to_jiffies(10)); - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - delay, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, delay, false); + *ack_flag &= ~FLAG_SET_XMIT_TIMER; return true; } return false; } -/* Heurestics to calculate number of duplicate ACKs. There's no dupACKs - * counter when SACK is enabled (without SACK, sacked_out is used for - * that purpose). - * - * With reordering, holes may still be in flight, so RFC3517 recovery - * uses pure sacked_out (total number of SACKed segments) even though - * it violates the RFC that uses duplicate ACKs, often these are equal - * but when e.g. out-of-window ACKs or packet duplication occurs, - * they differ. Since neither occurs due to loss, TCP should really - * ignore them. - */ -static inline int tcp_dupack_heuristics(const struct tcp_sock *tp) -{ - return tp->sacked_out + 1; -} - /* Linux NewReno/SACK/ECN state machine. * -------------------------------------- * @@ -2266,13 +2470,7 @@ static inline int tcp_dupack_heuristics(const struct tcp_sock *tp) * * If the receiver supports SACK: * - * RFC6675/3517: It is the conventional algorithm. A packet is - * considered lost if the number of higher sequence packets - * SACKed is greater than or equal the DUPACK thoreshold - * (reordering). This is implemented in tcp_mark_head_lost and - * tcp_update_scoreboard. - * - * RACK (draft-ietf-tcpm-rack-01): it is a newer algorithm + * RACK (RFC8985): RACK is a newer loss detection algorithm * (2017-) that checks timing instead of counting DUPACKs. * Essentially a packet is considered lost if it's not S/ACKed * after RTT + reordering_window, where both metrics are @@ -2287,8 +2485,8 @@ static inline int tcp_dupack_heuristics(const struct tcp_sock *tp) * is lost (NewReno). This heuristics are the same in NewReno * and SACK. * - * Really tricky (and requiring careful tuning) part of algorithm - * is hidden in functions tcp_time_to_recover() and tcp_xmit_retransmit_queue(). + * The really tricky (and requiring careful tuning) part of the algorithm + * is hidden in the RACK code in tcp_recovery.c and tcp_xmit_retransmit_queue(). * The first determines the moment _when_ we should reduce CWND and, * hence, slow down forward transmission. In fact, it determines the moment * when we decide that hole is caused by loss, rather than by a reorder. @@ -2311,83 +2509,10 @@ static inline int tcp_dupack_heuristics(const struct tcp_sock *tp) * Main question: may we further continue forward transmission * with the same cwnd? */ -static bool tcp_time_to_recover(struct sock *sk, int flag) -{ - struct tcp_sock *tp = tcp_sk(sk); - - /* Trick#1: The loss is proven. */ - if (tp->lost_out) - return true; - - /* Not-A-Trick#2 : Classic rule... */ - if (!tcp_is_rack(sk) && tcp_dupack_heuristics(tp) > tp->reordering) - return true; - - return false; -} - -/* Detect loss in event "A" above by marking head of queue up as lost. - * For RFC3517 SACK, a segment is considered lost if it - * has at least tp->reordering SACKed seqments above it; "packets" refers to - * the maximum SACKed segments to pass before reaching this limit. - */ -static void tcp_mark_head_lost(struct sock *sk, int packets, int mark_head) -{ - struct tcp_sock *tp = tcp_sk(sk); - struct sk_buff *skb; - int cnt; - /* Use SACK to deduce losses of new sequences sent during recovery */ - const u32 loss_high = tp->snd_nxt; - - WARN_ON(packets > tp->packets_out); - skb = tp->lost_skb_hint; - if (skb) { - /* Head already handled? */ - if (mark_head && after(TCP_SKB_CB(skb)->seq, tp->snd_una)) - return; - cnt = tp->lost_cnt_hint; - } else { - skb = tcp_rtx_queue_head(sk); - cnt = 0; - } - - skb_rbtree_walk_from(skb) { - /* TODO: do this better */ - /* this is not the most efficient way to do this... */ - tp->lost_skb_hint = skb; - tp->lost_cnt_hint = cnt; - - if (after(TCP_SKB_CB(skb)->end_seq, loss_high)) - break; - - if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) - cnt += tcp_skb_pcount(skb); - - if (cnt > packets) - break; - - if (!(TCP_SKB_CB(skb)->sacked & TCPCB_LOST)) - tcp_mark_skb_lost(sk, skb); - - if (mark_head) - break; - } - tcp_verify_left_out(tp); -} - -/* Account newly detected lost packet(s) */ - -static void tcp_update_scoreboard(struct sock *sk, int fast_rexmit) +static bool tcp_time_to_recover(const struct tcp_sock *tp) { - struct tcp_sock *tp = tcp_sk(sk); - - if (tcp_is_sack(tp)) { - int sacked_upto = tp->sacked_out - tp->reordering; - if (sacked_upto >= 0) - tcp_mark_head_lost(sk, sacked_upto, 0); - else if (fast_rexmit) - tcp_mark_head_lost(sk, 1, 1); - } + /* Has loss detection marked at least one packet lost? */ + return tp->lost_out != 0; } static bool tcp_tsopt_ecr_before(const struct tcp_sock *tp, u32 when) @@ -2403,7 +2528,7 @@ static bool tcp_skb_spurious_retrans(const struct tcp_sock *tp, const struct sk_buff *skb) { return (TCP_SKB_CB(skb)->sacked & TCPCB_RETRANS) && - tcp_tsopt_ecr_before(tp, tcp_skb_timestamp(skb)); + tcp_tsopt_ecr_before(tp, tcp_skb_timestamp_ts(tp->tcp_usec_ts, skb)); } /* Nothing was retransmitted or returned timestamp is less @@ -2411,8 +2536,35 @@ static bool tcp_skb_spurious_retrans(const struct tcp_sock *tp, */ static inline bool tcp_packet_delayed(const struct tcp_sock *tp) { - return tp->retrans_stamp && - tcp_tsopt_ecr_before(tp, tp->retrans_stamp); + const struct sock *sk = (const struct sock *)tp; + + /* Received an echoed timestamp before the first retransmission? */ + if (tp->retrans_stamp) + return tcp_tsopt_ecr_before(tp, tp->retrans_stamp); + + /* We set tp->retrans_stamp upon the first retransmission of a loss + * recovery episode, so normally if tp->retrans_stamp is 0 then no + * retransmission has happened yet (likely due to TSQ, which can cause + * fast retransmits to be delayed). So if snd_una advanced while + * (tp->retrans_stamp is 0 then apparently a packet was merely delayed, + * not lost. But there are exceptions where we retransmit but then + * clear tp->retrans_stamp, so we check for those exceptions. + */ + + /* (1) For non-SACK connections, tcp_is_non_sack_preventing_reopen() + * clears tp->retrans_stamp when snd_una == high_seq. + */ + if (!tcp_is_sack(tp) && !before(tp->snd_una, tp->high_seq)) + return false; + + /* (2) In TCP_SYN_SENT tcp_clean_rtx_queue() clears tp->retrans_stamp + * when setting FLAG_SYN_ACKED is set, even if the SYN was + * retransmitted. + */ + if (sk->sk_state == TCP_SYN_SENT) + return false; + + return true; /* tp->retrans_stamp is zero; no retransmit yet */ } /* Undo procedures. */ @@ -2446,6 +2598,16 @@ static bool tcp_any_retrans_done(const struct sock *sk) return false; } +/* If loss recovery is finished and there are no retransmits out in the + * network, then we clear retrans_stamp so that upon the next loss recovery + * retransmits_timed_out() and timestamp-undo are using the correct value. + */ +static void tcp_retrans_stamp_cleanup(struct sock *sk) +{ + if (!tcp_any_retrans_done(sk)) + tcp_sk(sk)->retrans_stamp = 0; +} + static void DBGUNDO(struct sock *sk, const char *msg) { #if FASTRETRANS_DEBUG > 1 @@ -2456,7 +2618,7 @@ static void DBGUNDO(struct sock *sk, const char *msg) pr_debug("Undo %s %pI4/%u c%u l%u ss%u/%u p%u\n", msg, &inet->inet_daddr, ntohs(inet->inet_dport), - tp->snd_cwnd, tcp_left_out(tp), + tcp_snd_cwnd(tp), tcp_left_out(tp), tp->snd_ssthresh, tp->prior_ssthresh, tp->packets_out); } @@ -2465,7 +2627,7 @@ static void DBGUNDO(struct sock *sk, const char *msg) pr_debug("Undo %s %pI6/%u c%u l%u ss%u/%u p%u\n", msg, &sk->sk_v6_daddr, ntohs(inet->inet_dport), - tp->snd_cwnd, tcp_left_out(tp), + tcp_snd_cwnd(tp), tcp_left_out(tp), tp->snd_ssthresh, tp->prior_ssthresh, tp->packets_out); } @@ -2490,7 +2652,7 @@ static void tcp_undo_cwnd_reduction(struct sock *sk, bool unmark_loss) if (tp->prior_ssthresh) { const struct inet_connection_sock *icsk = inet_csk(sk); - tp->snd_cwnd = icsk->icsk_ca_ops->undo_cwnd(sk); + tcp_snd_cwnd_set(tp, icsk->icsk_ca_ops->undo_cwnd(sk)); if (tp->prior_ssthresh > tp->snd_ssthresh) { tp->snd_ssthresh = tp->prior_ssthresh; @@ -2507,6 +2669,21 @@ static inline bool tcp_may_undo(const struct tcp_sock *tp) return tp->undo_marker && (!tp->undo_retrans || tcp_packet_delayed(tp)); } +static bool tcp_is_non_sack_preventing_reopen(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->snd_una == tp->high_seq && tcp_is_reno(tp)) { + /* Hold old state until something *above* high_seq + * is ACKed. For Reno it is MUST to prevent false + * fast retransmits (RFC2582). SACK TCP is safe. */ + if (!tcp_any_retrans_done(sk)) + tp->retrans_stamp = 0; + return true; + } + return false; +} + /* People celebrate: "We love our President!" */ static bool tcp_try_undo_recovery(struct sock *sk) { @@ -2529,14 +2706,8 @@ static bool tcp_try_undo_recovery(struct sock *sk) } else if (tp->rack.reo_wnd_persist) { tp->rack.reo_wnd_persist--; } - if (tp->snd_una == tp->high_seq && tcp_is_reno(tp)) { - /* Hold old state until something *above* high_seq - * is ACKed. For Reno it is MUST to prevent false - * fast retransmits (RFC2582). SACK TCP is safe. */ - if (!tcp_any_retrans_done(sk)) - tp->retrans_stamp = 0; + if (tcp_is_non_sack_preventing_reopen(sk)) return true; - } tcp_set_ca_state(sk, TCP_CA_Open); tp->is_sack_reneg = 0; return false; @@ -2571,7 +2742,9 @@ static bool tcp_try_undo_loss(struct sock *sk, bool frto_undo) if (frto_undo) NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSPURIOUSRTOS); - inet_csk(sk)->icsk_retransmits = 0; + WRITE_ONCE(inet_csk(sk)->icsk_retransmits, 0); + if (tcp_is_non_sack_preventing_reopen(sk)) + return true; if (frto_undo || tcp_is_sack(tp)) { tcp_set_ca_state(sk, TCP_CA_Open); tp->is_sack_reneg = 0; @@ -2597,7 +2770,7 @@ static void tcp_init_cwnd_reduction(struct sock *sk) tp->high_seq = tp->snd_nxt; tp->tlp_high_seq = 0; tp->snd_cwnd_cnt = 0; - tp->prior_cwnd = tp->snd_cwnd; + tp->prior_cwnd = tcp_snd_cwnd(tp); tp->prr_delivered = 0; tp->prr_out = 0; tp->snd_ssthresh = inet_csk(sk)->icsk_ca_ops->ssthresh(sk); @@ -2613,21 +2786,23 @@ void tcp_cwnd_reduction(struct sock *sk, int newly_acked_sacked, int newly_lost, if (newly_acked_sacked <= 0 || WARN_ON_ONCE(!tp->prior_cwnd)) return; + trace_tcp_cwnd_reduction_tp(sk, newly_acked_sacked, newly_lost, flag); + tp->prr_delivered += newly_acked_sacked; if (delta < 0) { u64 dividend = (u64)tp->snd_ssthresh * tp->prr_delivered + tp->prior_cwnd - 1; sndcnt = div_u64(dividend, tp->prior_cwnd) - tp->prr_out; - } else if (flag & FLAG_SND_UNA_ADVANCED && !newly_lost) { - sndcnt = min_t(int, delta, - max_t(int, tp->prr_delivered - tp->prr_out, - newly_acked_sacked) + 1); } else { - sndcnt = min(delta, newly_acked_sacked); + sndcnt = max_t(int, tp->prr_delivered - tp->prr_out, + newly_acked_sacked); + if (flag & FLAG_SND_UNA_ADVANCED && !newly_lost) + sndcnt++; + sndcnt = min(delta, sndcnt); } /* Force a fast retransmit upon entering fast recovery */ sndcnt = max(sndcnt, (tp->prr_out ? 0 : 1)); - tp->snd_cwnd = tcp_packets_in_flight(tp) + sndcnt; + tcp_snd_cwnd_set(tp, tcp_packets_in_flight(tp) + sndcnt); } static inline void tcp_end_cwnd_reduction(struct sock *sk) @@ -2640,7 +2815,7 @@ static inline void tcp_end_cwnd_reduction(struct sock *sk) /* Reset cwnd to ssthresh in CWR or Recovery (unless it's undone) */ if (tp->snd_ssthresh < TCP_INFINITE_SSTHRESH && (inet_csk(sk)->icsk_ca_state == TCP_CA_CWR || tp->undo_marker)) { - tp->snd_cwnd = tp->snd_ssthresh; + tcp_snd_cwnd_set(tp, tp->snd_ssthresh); tp->snd_cwnd_stamp = tcp_jiffies32; } tcp_ca_event(sk, CA_EVENT_COMPLETE_CWR); @@ -2704,12 +2879,15 @@ static void tcp_mtup_probe_success(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); struct inet_connection_sock *icsk = inet_csk(sk); + u64 val; - /* FIXME: breaks with very large cwnd */ tp->prior_ssthresh = tcp_current_ssthresh(sk); - tp->snd_cwnd = tp->snd_cwnd * - tcp_mss_to_mtu(sk, tp->mss_cache) / - icsk->icsk_mtup.probe_size; + + val = (u64)tcp_snd_cwnd(tp) * tcp_mss_to_mtu(sk, tp->mss_cache); + do_div(val, icsk->icsk_mtup.probe_size); + DEBUG_NET_WARN_ON_ONCE((u32)val != val); + tcp_snd_cwnd_set(tp, max_t(u32, 1U, val)); + tp->snd_cwnd_cnt = 0; tp->snd_cwnd_stamp = tcp_jiffies32; tp->snd_ssthresh = tcp_current_ssthresh(sk); @@ -2720,13 +2898,37 @@ static void tcp_mtup_probe_success(struct sock *sk) NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMTUPSUCCESS); } +/* Sometimes we deduce that packets have been dropped due to reasons other than + * congestion, like path MTU reductions or failed client TFO attempts. In these + * cases we call this function to retransmit as many packets as cwnd allows, + * without reducing cwnd. Given that retransmits will set retrans_stamp to a + * non-zero value (and may do so in a later calling context due to TSQ), we + * also enter CA_Loss so that we track when all retransmitted packets are ACKed + * and clear retrans_stamp when that happens (to ensure later recurring RTOs + * are using the correct retrans_stamp and don't declare ETIMEDOUT + * prematurely). + */ +static void tcp_non_congestion_loss_retransmit(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (icsk->icsk_ca_state != TCP_CA_Loss) { + tp->high_seq = tp->snd_nxt; + tp->snd_ssthresh = tcp_current_ssthresh(sk); + tp->prior_ssthresh = 0; + tp->undo_marker = 0; + tcp_set_ca_state(sk, TCP_CA_Loss); + } + tcp_xmit_retransmit_queue(sk); +} + /* Do a simple retransmit without using the backoff mechanisms in * tcp_timer. This is used for path mtu discovery. * The socket is already locked here. */ void tcp_simple_retransmit(struct sock *sk) { - const struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *skb; int mss; @@ -2751,8 +2953,6 @@ void tcp_simple_retransmit(struct sock *sk) tcp_mark_skb_lost(sk, skb); } - tcp_clear_retrans_hints_partial(tp); - if (!tp->lost_out) return; @@ -2766,22 +2966,18 @@ void tcp_simple_retransmit(struct sock *sk) * in network, but units changed and effective * cwnd/ssthresh really reduced now. */ - if (icsk->icsk_ca_state != TCP_CA_Loss) { - tp->high_seq = tp->snd_nxt; - tp->snd_ssthresh = tcp_current_ssthresh(sk); - tp->prior_ssthresh = 0; - tp->undo_marker = 0; - tcp_set_ca_state(sk, TCP_CA_Loss); - } - tcp_xmit_retransmit_queue(sk); + tcp_non_congestion_loss_retransmit(sk); } -EXPORT_SYMBOL(tcp_simple_retransmit); +EXPORT_IPV6_MOD(tcp_simple_retransmit); void tcp_enter_recovery(struct sock *sk, bool ece_ack) { struct tcp_sock *tp = tcp_sk(sk); int mib_idx; + /* Start the clock with our fast retransmit, for undo and ETIMEDOUT. */ + tcp_retrans_stamp_cleanup(sk); + if (tcp_is_reno(tp)) mib_idx = LINUX_MIB_TCPRENORECOVERY; else @@ -2800,6 +2996,14 @@ void tcp_enter_recovery(struct sock *sk, bool ece_ack) tcp_set_ca_state(sk, TCP_CA_Recovery); } +static void tcp_update_rto_time(struct tcp_sock *tp) +{ + if (tp->rto_stamp) { + tp->total_rto_time += tcp_time_stamp_ms(tp) - tp->rto_stamp; + tp->rto_stamp = 0; + } +} + /* Process an ACK in CA_Loss state. Move to CA_Open if lost data are * recovered or spurious. Otherwise retransmits more on partial ACKs. */ @@ -2846,7 +3050,7 @@ static void tcp_process_loss(struct sock *sk, int flag, int num_dupack, } if (tcp_is_reno(tp)) { /* A Reno DUPACK means new data in F-RTO step 2.b above are - * delivered. Lower inflight to clock out (re)tranmissions. + * delivered. Lower inflight to clock out (re)transmissions. */ if (after(tp->snd_nxt, tp->high_seq) && num_dupack) tcp_add_reno_sack(sk, num_dupack, flag & FLAG_ECE); @@ -2856,17 +3060,8 @@ static void tcp_process_loss(struct sock *sk, int flag, int num_dupack, *rexmit = REXMIT_LOST; } -static bool tcp_force_fast_retransmit(struct sock *sk) -{ - struct tcp_sock *tp = tcp_sk(sk); - - return after(tcp_highest_sack_seq(tp), - tp->snd_una + tp->reordering * tp->mss_cache); -} - /* Undo during fast recovery after partial ACK. */ -static bool tcp_try_undo_partial(struct sock *sk, u32 prior_snd_una, - bool *do_lost) +static bool tcp_try_undo_partial(struct sock *sk, u32 prior_snd_una) { struct tcp_sock *tp = tcp_sk(sk); @@ -2891,9 +3086,6 @@ static bool tcp_try_undo_partial(struct sock *sk, u32 prior_snd_una, tcp_undo_cwnd_reduction(sk, true); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPPARTIALUNDO); tcp_try_keep_open(sk); - } else { - /* Partial ACK arrived. Force fast retransmit. */ - *do_lost = tcp_force_fast_retransmit(sk); } return false; } @@ -2907,7 +3099,7 @@ static void tcp_identify_packet_loss(struct sock *sk, int *ack_flag) if (unlikely(tcp_is_reno(tp))) { tcp_newreno_mark_lost(sk, *ack_flag & FLAG_SND_UNA_ADVANCED); - } else if (tcp_is_rack(sk)) { + } else { u32 prior_retrans = tp->retrans_out; if (tcp_rack_mark_lost(sk)) @@ -2934,10 +3126,8 @@ static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, { struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); - int fast_rexmit = 0, flag = *ack_flag; + int flag = *ack_flag; bool ece_ack = flag & FLAG_ECE; - bool do_lost = num_dupack || ((flag & FLAG_DATA_SACKED) && - tcp_force_fast_retransmit(sk)); if (!tp->packets_out && tp->sacked_out) tp->sacked_out = 0; @@ -2948,7 +3138,7 @@ static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, tp->prior_ssthresh = 0; /* B. In all the states check for reneging SACKs. */ - if (tcp_check_sack_reneging(sk, flag)) + if (tcp_check_sack_reneging(sk, ack_flag)) return; /* C. Check consistency of the current state. */ @@ -2986,15 +3176,15 @@ static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, if (!(flag & FLAG_SND_UNA_ADVANCED)) { if (tcp_is_reno(tp)) tcp_add_reno_sack(sk, num_dupack, ece_ack); - } else if (tcp_try_undo_partial(sk, prior_snd_una, &do_lost)) + } else if (tcp_try_undo_partial(sk, prior_snd_una)) return; if (tcp_try_undo_dsack(sk)) - tcp_try_keep_open(sk); + tcp_try_to_open(sk, flag); tcp_identify_packet_loss(sk, ack_flag); if (icsk->icsk_ca_state != TCP_CA_Recovery) { - if (!tcp_time_to_recover(sk, flag)) + if (!tcp_time_to_recover(tp)) return; /* Undo reverts the recovery state. If loss is evident, * starts a new recovery (e.g. reordering then loss); @@ -3004,6 +3194,8 @@ static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, break; case TCP_CA_Loss: tcp_process_loss(sk, flag, num_dupack, rexmit); + if (icsk->icsk_ca_state != TCP_CA_Loss) + tcp_update_rto_time(tp); tcp_identify_packet_loss(sk, ack_flag); if (!(icsk->icsk_ca_state == TCP_CA_Open || (*ack_flag & FLAG_LOST_RETRANS))) @@ -3021,7 +3213,7 @@ static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, tcp_try_undo_dsack(sk); tcp_identify_packet_loss(sk, ack_flag); - if (!tcp_time_to_recover(sk, flag)) { + if (!tcp_time_to_recover(tp)) { tcp_try_to_open(sk, flag); return; } @@ -3032,24 +3224,21 @@ static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, tp->snd_una == tp->mtu_probe.probe_seq_start) { tcp_mtup_probe_failed(sk); /* Restores the reduction we did in tcp_mtup_probe() */ - tp->snd_cwnd++; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); tcp_simple_retransmit(sk); return; } /* Otherwise enter Recovery state */ tcp_enter_recovery(sk, ece_ack); - fast_rexmit = 1; } - if (!tcp_is_rack(sk) && do_lost) - tcp_update_scoreboard(sk, fast_rexmit); *rexmit = REXMIT_LOST; } static void tcp_update_rtt_min(struct sock *sk, u32 rtt_us, const int flag) { - u32 wlen = sock_net(sk)->ipv4.sysctl_tcp_min_rtt_wlen * HZ; + u32 wlen = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_rtt_wlen) * HZ; struct tcp_sock *tp = tcp_sk(sk); if ((flag & FLAG_ACK_MAYBE_DELAYED) && rtt_us > tcp_min_rtt(tp)) { @@ -3083,17 +3272,10 @@ static bool tcp_ack_update_rtt(struct sock *sk, const int flag, * left edge of the send window. * See draft-ietf-tcplw-high-performance-00, section 3.3. */ - if (seq_rtt_us < 0 && tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr && - flag & FLAG_ACKED) { - u32 delta = tcp_time_stamp(tp) - tp->rx_opt.rcv_tsecr; - - if (likely(delta < INT_MAX / (USEC_PER_SEC / TCP_TS_HZ))) { - if (!delta) - delta = 1; - seq_rtt_us = delta * (USEC_PER_SEC / TCP_TS_HZ); - ca_rtt_us = seq_rtt_us; - } - } + if (seq_rtt_us < 0 && tp->rx_opt.saw_tstamp && + tp->rx_opt.rcv_tsecr && flag & FLAG_ACKED) + seq_rtt_us = ca_rtt_us = tcp_rtt_tsopt_us(tp, 1); + rs->rtt_us = ca_rtt_us; /* RTT of last (S)ACKed packet (or -1) */ if (seq_rtt_us < 0) return false; @@ -3159,8 +3341,7 @@ void tcp_rearm_rto(struct sock *sk) */ rto = usecs_to_jiffies(max_t(int, delta_us, 1)); } - tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, rto, - TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, rto, true); } } @@ -3272,6 +3453,8 @@ static int tcp_clean_rtx_queue(struct sock *sk, const struct sk_buff *ack_skb, if (sacked & TCPCB_SACKED_ACKED) { tp->sacked_out -= acked_pcount; + /* snd_una delta covers these skbs */ + sack->delivered_bytes -= skb->len; } else if (tcp_is_sack(tp)) { tcp_count_delivered(tp, acked_pcount, ece_ack); if (!tcp_skb_spurious_retrans(tp, skb)) @@ -3307,8 +3490,6 @@ static int tcp_clean_rtx_queue(struct sock *sk, const struct sk_buff *ack_skb, next = skb_rb_next(skb); if (unlikely(skb == tp->retransmit_skb_hint)) tp->retransmit_skb_hint = NULL; - if (unlikely(skb == tp->lost_skb_hint)) - tp->lost_skb_hint = NULL; tcp_highest_sack_replace(sk, skb, next); tcp_rtx_queue_unlink_and_free(skb, sk); } @@ -3366,15 +3547,14 @@ static int tcp_clean_rtx_queue(struct sock *sk, const struct sk_buff *ack_skb, if (flag & FLAG_RETRANS_DATA_ACKED) flag &= ~FLAG_ORIG_SACK_ACKED; } else { - int delta; - /* Non-retransmitted hole got filled? That's reordering */ if (before(reord, prior_fack)) tcp_check_sack_reordering(sk, reord, 0); - - delta = prior_sacked - tp->sacked_out; - tp->lost_cnt_hint -= min(tp->lost_cnt_hint, delta); } + + sack->delivered_bytes = (skb ? + TCP_SKB_CB(skb)->seq : tp->snd_una) - + prior_snd_una; } else if (skb && rtt_update && sack_rtt_us >= 0 && sack_rtt_us > tcp_stamp_us_delta(tp->tcp_mstamp, tcp_skb_timestamp_us(skb))) { @@ -3437,10 +3617,10 @@ static void tcp_ack_probe(struct sock *sk) * This function is not for random using! */ } else { - unsigned long when = tcp_probe0_when(sk, TCP_RTO_MAX); + unsigned long when = tcp_probe0_when(sk, tcp_rto_max(sk)); when = tcp_clamp_probe0_to_user_timeout(sk, when); - tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, when, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, when, true); } } @@ -3459,7 +3639,8 @@ static inline bool tcp_may_raise_cwnd(const struct sock *sk, const int flag) * new SACK or ECE mark may first advance cwnd here and later reduce * cwnd in tcp_fastretrans_alert() based on more states. */ - if (tcp_sk(sk)->reordering > sock_net(sk)->ipv4.sysctl_tcp_reordering) + if (tcp_sk(sk)->reordering > + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reordering)) return flag & FLAG_FORWARD_PROGRESS; return flag & FLAG_DATA_ACKED; @@ -3476,7 +3657,7 @@ static void tcp_cong_control(struct sock *sk, u32 ack, u32 acked_sacked, const struct inet_connection_sock *icsk = inet_csk(sk); if (icsk->icsk_ca_ops->cong_control) { - icsk->icsk_ca_ops->cong_control(sk, rs); + icsk->icsk_ca_ops->cong_control(sk, ack, flag, rs); return; } @@ -3499,7 +3680,24 @@ static inline bool tcp_may_update_window(const struct tcp_sock *tp, { return after(ack, tp->snd_una) || after(ack_seq, tp->snd_wl1) || - (ack_seq == tp->snd_wl1 && nwin > tp->snd_wnd); + (ack_seq == tp->snd_wl1 && (nwin > tp->snd_wnd || !nwin)); +} + +static void tcp_snd_sne_update(struct tcp_sock *tp, u32 ack) +{ +#ifdef CONFIG_TCP_AO + struct tcp_ao_info *ao; + + if (!static_branch_unlikely(&tcp_ao_needed.key)) + return; + + ao = rcu_dereference_protected(tp->ao_info, + lockdep_sock_is_held((struct sock *)tp)); + if (ao && ack < tp->snd_una) { + ao->snd_sne++; + trace_tcp_ao_snd_sne_update((struct sock *)tp, ao->snd_sne); + } +#endif } /* If we update tp->snd_una, also update tp->bytes_acked */ @@ -3509,9 +3707,27 @@ static void tcp_snd_una_update(struct tcp_sock *tp, u32 ack) sock_owned_by_me((struct sock *)tp); tp->bytes_acked += delta; + tcp_snd_sne_update(tp, ack); tp->snd_una = ack; } +static void tcp_rcv_sne_update(struct tcp_sock *tp, u32 seq) +{ +#ifdef CONFIG_TCP_AO + struct tcp_ao_info *ao; + + if (!static_branch_unlikely(&tcp_ao_needed.key)) + return; + + ao = rcu_dereference_protected(tp->ao_info, + lockdep_sock_is_held((struct sock *)tp)); + if (ao && seq < tp->rcv_nxt) { + ao->rcv_sne++; + trace_tcp_ao_rcv_sne_update((struct sock *)tp, ao->rcv_sne); + } +#endif +} + /* If we update tp->rcv_nxt, also update tp->bytes_received */ static void tcp_rcv_nxt_update(struct tcp_sock *tp, u32 seq) { @@ -3519,6 +3735,7 @@ static void tcp_rcv_nxt_update(struct tcp_sock *tp, u32 seq) sock_owned_by_me((struct sock *)tp); tp->bytes_received += delta; + tcp_rcv_sne_update(tp, seq); WRITE_ONCE(tp->rcv_nxt, seq); } @@ -3568,16 +3785,23 @@ static int tcp_ack_update_window(struct sock *sk, const struct sk_buff *skb, u32 static bool __tcp_oow_rate_limited(struct net *net, int mib_idx, u32 *last_oow_ack_time) { - if (*last_oow_ack_time) { - s32 elapsed = (s32)(tcp_jiffies32 - *last_oow_ack_time); + /* Paired with the WRITE_ONCE() in this function. */ + u32 val = READ_ONCE(*last_oow_ack_time); + + if (val) { + s32 elapsed = (s32)(tcp_jiffies32 - val); - if (0 <= elapsed && elapsed < net->ipv4.sysctl_tcp_invalid_ratelimit) { + if (0 <= elapsed && + elapsed < READ_ONCE(net->ipv4.sysctl_tcp_invalid_ratelimit)) { NET_INC_STATS(net, mib_idx); return true; /* rate-limited: don't send yet! */ } } - *last_oow_ack_time = tcp_jiffies32; + /* Paired with the prior READ_ONCE() and with itself, + * as we might be lockless. + */ + WRITE_ONCE(*last_oow_ack_time, tcp_jiffies32); return false; /* not rate-limited: go ahead, send dupack now! */ } @@ -3600,15 +3824,22 @@ bool tcp_oow_rate_limited(struct net *net, const struct sk_buff *skb, return __tcp_oow_rate_limited(net, mib_idx, last_oow_ack_time); } +static void tcp_send_ack_reflect_ect(struct sock *sk, bool accecn_reflector) +{ + struct tcp_sock *tp = tcp_sk(sk); + u16 flags = 0; + + if (accecn_reflector) + flags = tcp_accecn_reflector_flags(tp->syn_ect_rcv); + __tcp_send_ack(sk, tp->rcv_nxt, flags); +} + /* RFC 5961 7 [ACK Throttling] */ -static void tcp_send_challenge_ack(struct sock *sk) +static void tcp_send_challenge_ack(struct sock *sk, bool accecn_reflector) { - /* unprotected vars, we dont care of overwrites */ - static u32 challenge_timestamp; - static unsigned int challenge_count; struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); - u32 count, now; + u32 count, now, ack_limit; /* First check our per-socket dupack rate limit. */ if (__tcp_oow_rate_limited(net, @@ -3616,20 +3847,25 @@ static void tcp_send_challenge_ack(struct sock *sk) &tp->last_oow_ack_time)) return; + ack_limit = READ_ONCE(net->ipv4.sysctl_tcp_challenge_ack_limit); + if (ack_limit == INT_MAX) + goto send_ack; + /* Then check host-wide RFC 5961 rate limit. */ now = jiffies / HZ; - if (now != challenge_timestamp) { - u32 ack_limit = net->ipv4.sysctl_tcp_challenge_ack_limit; + if (now != READ_ONCE(net->ipv4.tcp_challenge_timestamp)) { u32 half = (ack_limit + 1) >> 1; - challenge_timestamp = now; - WRITE_ONCE(challenge_count, half + prandom_u32_max(ack_limit)); + WRITE_ONCE(net->ipv4.tcp_challenge_timestamp, now); + WRITE_ONCE(net->ipv4.tcp_challenge_count, + get_random_u32_inclusive(half, ack_limit + half - 1)); } - count = READ_ONCE(challenge_count); + count = READ_ONCE(net->ipv4.tcp_challenge_count); if (count > 0) { - WRITE_ONCE(challenge_count, count - 1); + WRITE_ONCE(net->ipv4.tcp_challenge_count, count - 1); +send_ack: NET_INC_STATS(net, LINUX_MIB_TCPCHALLENGEACK); - tcp_send_ack(sk); + tcp_send_ack_reflect_ect(sk, accecn_reflector); } } @@ -3639,8 +3875,16 @@ static void tcp_store_ts_recent(struct tcp_sock *tp) tp->rx_opt.ts_recent_stamp = ktime_get_seconds(); } -static void tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) +static int __tcp_replace_ts_recent(struct tcp_sock *tp, s32 tstamp_delta) { + tcp_store_ts_recent(tp); + return tstamp_delta > 0 ? FLAG_TS_PROGRESS : 0; +} + +static int tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) +{ + s32 delta; + if (tp->rx_opt.saw_tstamp && !after(seq, tp->rcv_wup)) { /* PAWS bug workaround wrt. ACK frames, the PAWS discard * extra check below makes sure this can only happen @@ -3649,13 +3893,17 @@ static void tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) * Not only, also it occurs for expired timestamps. */ - if (tcp_paws_check(&tp->rx_opt, 0)) - tcp_store_ts_recent(tp); + if (tcp_paws_check(&tp->rx_opt, 0)) { + delta = tp->rx_opt.rcv_tsval - tp->rx_opt.ts_recent; + return __tcp_replace_ts_recent(tp, delta); + } } + + return 0; } /* This routine deals with acks during a TLP episode and ends an episode by - * resetting tlp_high_seq. Ref: TLP algorithm in draft-ietf-tcpm-rack + * resetting tlp_high_seq. Ref: TLP algorithm in RFC8985 */ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) { @@ -3687,12 +3935,23 @@ static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) } } -static inline void tcp_in_ack_event(struct sock *sk, u32 flags) +static void tcp_in_ack_event(struct sock *sk, int flag) { const struct inet_connection_sock *icsk = inet_csk(sk); - if (icsk->icsk_ca_ops->in_ack_event) - icsk->icsk_ca_ops->in_ack_event(sk, flags); + if (icsk->icsk_ca_ops->in_ack_event) { + u32 ack_ev_flags = 0; + + if (flag & FLAG_WIN_UPDATE) + ack_ev_flags |= CA_ACK_WIN_UPDATE; + if (flag & FLAG_SLOWPATH) { + ack_ev_flags |= CA_ACK_SLOWPATH; + if (flag & FLAG_ECE) + ack_ev_flags |= CA_ACK_ECE; + } + + icsk->icsk_ca_ops->in_ack_event(sk, ack_ev_flags); + } } /* Congestion control has updated the cwnd already. So if we're in @@ -3717,7 +3976,8 @@ static void tcp_xmit_recovery(struct sock *sk, int rexmit) } /* Returns the number of packets newly acked or sacked by the current ACK */ -static u32 tcp_newly_delivered(struct sock *sk, u32 prior_delivered, int flag) +static u32 tcp_newly_delivered(struct sock *sk, u32 prior_delivered, + u32 ecn_count, int flag) { const struct net *net = sock_net(sk); struct tcp_sock *tp = tcp_sk(sk); @@ -3725,8 +3985,12 @@ static u32 tcp_newly_delivered(struct sock *sk, u32 prior_delivered, int flag) delivered = tp->delivered - prior_delivered; NET_ADD_STATS(net, LINUX_MIB_TCPDELIVERED, delivered); - if (flag & FLAG_ECE) - NET_ADD_STATS(net, LINUX_MIB_TCPDELIVEREDCE, delivered); + + if (flag & FLAG_ECE) { + if (tcp_ecn_mode_rfc3168(tp)) + ecn_count = delivered; + NET_ADD_STATS(net, LINUX_MIB_TCPDELIVEREDCE, ecn_count); + } return delivered; } @@ -3747,11 +4011,13 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) u32 delivered = tp->delivered; u32 lost = tp->lost; int rexmit = REXMIT_NONE; /* Flag to (re)transmit to recover losses */ + u32 ecn_count = 0; /* Did we receive ECE/an AccECN ACE update? */ u32 prior_fack; sack_state.first_sackt = 0; sack_state.rate = &rs; sack_state.sack_delivered = 0; + sack_state.delivered_bytes = 0; /* We very likely will need to access rtx queue. */ prefetch(sk->tcp_rtx_queue.rb_node); @@ -3760,11 +4026,15 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) * then we can probably ignore it. */ if (before(ack, prior_snd_una)) { + u32 max_window; + + /* do not accept ACK for bytes we never sent. */ + max_window = min_t(u64, tp->max_window, tp->bytes_acked); /* RFC 5961 5.2 [Blind Data Injection Attack].[Mitigation] */ - if (before(ack, prior_snd_una - tp->max_window)) { + if (before(ack, prior_snd_una - max_window)) { if (!(flag & FLAG_NO_CHALLENGE_ACK)) - tcp_send_challenge_ack(sk); - return -1; + tcp_send_challenge_ack(sk, false); + return -SKB_DROP_REASON_TCP_TOO_OLD_ACK; } goto old_ack; } @@ -3773,16 +4043,16 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) * this segment (RFC793 Section 3.9). */ if (after(ack, tp->snd_nxt)) - return -1; + return -SKB_DROP_REASON_TCP_ACK_UNSENT_DATA; if (after(ack, prior_snd_una)) { flag |= FLAG_SND_UNA_ADVANCED; - icsk->icsk_retransmits = 0; + WRITE_ONCE(icsk->icsk_retransmits, 0); #if IS_ENABLED(CONFIG_TLS_DEVICE) if (static_branch_unlikely(&clean_acked_data_enabled.key)) - if (icsk->icsk_clean_acked) - icsk->icsk_clean_acked(sk, ack); + if (tp->tcp_clean_acked) + tp->tcp_clean_acked(sk, ack); #endif } @@ -3793,7 +4063,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) * is in window. */ if (flag & FLAG_UPDATE_TS_RECENT) - tcp_replace_ts_recent(tp, TCP_SKB_CB(skb)->seq); + flag |= tcp_replace_ts_recent(tp, TCP_SKB_CB(skb)->seq); if ((flag & (FLAG_SLOWPATH | FLAG_SND_UNA_ADVANCED)) == FLAG_SND_UNA_ADVANCED) { @@ -3805,12 +4075,8 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) tcp_snd_una_update(tp, ack); flag |= FLAG_WIN_UPDATE; - tcp_in_ack_event(sk, CA_ACK_WIN_UPDATE); - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPHPACKS); } else { - u32 ack_ev_flags = CA_ACK_SLOWPATH; - if (ack_seq != TCP_SKB_CB(skb)->end_seq) flag |= FLAG_DATA; else @@ -3822,19 +4088,12 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) flag |= tcp_sacktag_write_queue(sk, skb, prior_snd_una, &sack_state); - if (tcp_ecn_rcv_ecn_echo(tp, tcp_hdr(skb))) { + if (tcp_ecn_rcv_ecn_echo(tp, tcp_hdr(skb))) flag |= FLAG_ECE; - ack_ev_flags |= CA_ACK_ECE; - } if (sack_state.sack_delivered) tcp_count_delivered(tp, sack_state.sack_delivered, flag & FLAG_ECE); - - if (flag & FLAG_WIN_UPDATE) - ack_ev_flags |= CA_ACK_WIN_UPDATE; - - tcp_in_ack_event(sk, ack_ev_flags); } /* This is a deviation from RFC3168 since it states that: @@ -3849,8 +4108,9 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) /* We passed data and got it acked, remove any soft error * log. Something worked... */ - sk->sk_err_soft = 0; - icsk->icsk_probes_out = 0; + if (READ_ONCE(sk->sk_err_soft)) + WRITE_ONCE(sk->sk_err_soft, 0); + WRITE_ONCE(icsk->icsk_probes_out, 0); tp->rcv_tstamp = tcp_jiffies32; if (!prior_packets) goto no_queue; @@ -3861,11 +4121,20 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) tcp_rack_update_reo_wnd(sk, &rs); + if (tcp_ecn_mode_accecn(tp)) + ecn_count = tcp_accecn_process(sk, skb, + tp->delivered - delivered, + sack_state.delivered_bytes, + &flag); + + tcp_in_ack_event(sk, flag); + if (tp->tlp_high_seq) tcp_process_tlp_ack(sk, ack, flag); if (tcp_ack_is_dubious(sk, flag)) { - if (!(flag & (FLAG_SND_UNA_ADVANCED | FLAG_NOT_DUP))) { + if (!(flag & (FLAG_SND_UNA_ADVANCED | + FLAG_NOT_DUP | FLAG_DSACKING_ACK))) { num_dupack = 1; /* Consider if pure acks were aggregated in tcp_add_backlog() */ if (!(flag & FLAG_DATA)) @@ -3882,7 +4151,8 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) if ((flag & FLAG_FORWARD_PROGRESS) || !(flag & FLAG_NOT_DUP)) sk_dst_confirm(sk); - delivered = tcp_newly_delivered(sk, delivered, flag); + delivered = tcp_newly_delivered(sk, delivered, ecn_count, flag); + lost = tp->lost - lost; /* freshly marked lost */ rs.is_ack_delayed = !!(flag & FLAG_ACK_MAYBE_DELAYED); tcp_rate_gen(sk, delivered, lost, is_sack_reneg, sack_state.rate); @@ -3891,11 +4161,17 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) return 1; no_queue: + if (tcp_ecn_mode_accecn(tp)) + ecn_count = tcp_accecn_process(sk, skb, + tp->delivered - delivered, + sack_state.delivered_bytes, + &flag); + tcp_in_ack_event(sk, flag); /* If data was DSACKed, see if we can undo a cwnd reduction. */ if (flag & FLAG_DSACKING_ACK) { tcp_fastretrans_alert(sk, prior_snd_una, num_dupack, &flag, &rexmit); - tcp_newly_delivered(sk, delivered, flag); + tcp_newly_delivered(sk, delivered, ecn_count, flag); } /* If this ack opens up a zero window, clear backoff. It was * being used to time the probes, and is probably far higher than @@ -3916,7 +4192,7 @@ old_ack: &sack_state); tcp_fastretrans_alert(sk, prior_snd_una, num_dupack, &flag, &rexmit); - tcp_newly_delivered(sk, delivered, flag); + tcp_newly_delivered(sk, delivered, ecn_count, flag); tcp_xmit_recovery(sk, rexmit); } @@ -3961,7 +4237,7 @@ static bool smc_parse_options(const struct tcphdr *th, /* Try to parse the MSS option from the TCP header. Return 0 on failure, clamped * value on success. */ -static u16 tcp_parse_mss_option(const struct tcphdr *th, u16 user_mss) +u16 tcp_parse_mss_option(const struct tcphdr *th, u16 user_mss) { const unsigned char *ptr = (const unsigned char *)(th + 1); int length = (th->doff * 4) - sizeof(struct tcphdr); @@ -4016,6 +4292,7 @@ void tcp_parse_options(const struct net *net, ptr = (const unsigned char *)(th + 1); opt_rx->saw_tstamp = 0; + opt_rx->accecn = 0; opt_rx->saw_unknown = 0; while (length > 0) { @@ -4050,7 +4327,7 @@ void tcp_parse_options(const struct net *net, break; case TCPOPT_WINDOW: if (opsize == TCPOLEN_WINDOW && th->syn && - !estab && net->ipv4.sysctl_tcp_window_scaling) { + !estab && READ_ONCE(net->ipv4.sysctl_tcp_window_scaling)) { __u8 snd_wscale = *(__u8 *)ptr; opt_rx->wscale_ok = 1; if (snd_wscale > TCP_MAX_WSCALE) { @@ -4066,7 +4343,7 @@ void tcp_parse_options(const struct net *net, case TCPOPT_TIMESTAMP: if ((opsize == TCPOLEN_TIMESTAMP) && ((estab && opt_rx->tstamp_ok) || - (!estab && net->ipv4.sysctl_tcp_timestamps))) { + (!estab && READ_ONCE(net->ipv4.sysctl_tcp_timestamps)))) { opt_rx->saw_tstamp = 1; opt_rx->rcv_tsval = get_unaligned_be32(ptr); opt_rx->rcv_tsecr = get_unaligned_be32(ptr + 4); @@ -4074,7 +4351,7 @@ void tcp_parse_options(const struct net *net, break; case TCPOPT_SACK_PERM: if (opsize == TCPOLEN_SACK_PERM && th->syn && - !estab && net->ipv4.sysctl_tcp_sack) { + !estab && READ_ONCE(net->ipv4.sysctl_tcp_sack)) { opt_rx->sack_ok = TCP_SACK_SEEN; tcp_sack_reset(opt_rx); } @@ -4089,9 +4366,15 @@ void tcp_parse_options(const struct net *net, break; #ifdef CONFIG_TCP_MD5SIG case TCPOPT_MD5SIG: - /* - * The MD5 Hash has already been - * checked (see tcp_v{4,6}_do_rcv()). + /* The MD5 Hash has already been + * checked (see tcp_v{4,6}_rcv()). + */ + break; +#endif +#ifdef CONFIG_TCP_AO + case TCPOPT_AO: + /* TCP AO has already been checked + * (see tcp_inbound_ao_hash()). */ break; #endif @@ -4101,6 +4384,12 @@ void tcp_parse_options(const struct net *net, ptr, th->syn, foc, false); break; + case TCPOPT_ACCECN0: + case TCPOPT_ACCECN1: + /* Save offset of AccECN option in TCP header */ + opt_rx->accecn = (ptr - 2) - (__u8 *)th; + break; + case TCPOPT_EXP: /* Fast Open option shares code 254 using a * 16 bits magic number. @@ -4161,11 +4450,14 @@ static bool tcp_fast_parse_options(const struct net *net, */ if (th->doff == (sizeof(*th) / 4)) { tp->rx_opt.saw_tstamp = 0; + tp->rx_opt.accecn = 0; return false; } else if (tp->rx_opt.tstamp_ok && th->doff == ((sizeof(*th) + TCPOLEN_TSTAMP_ALIGNED) / 4)) { - if (tcp_parse_aligned_timestamp(tp, th)) + if (tcp_parse_aligned_timestamp(tp, th)) { + tp->rx_opt.accecn = 0; return true; + } } tcp_parse_options(net, skb, &tp->rx_opt, 1, NULL); @@ -4175,39 +4467,58 @@ static bool tcp_fast_parse_options(const struct net *net, return true; } -#ifdef CONFIG_TCP_MD5SIG +#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO) /* - * Parse MD5 Signature option + * Parse Signature options */ -const u8 *tcp_parse_md5sig_option(const struct tcphdr *th) +int tcp_do_parse_auth_options(const struct tcphdr *th, + const u8 **md5_hash, const u8 **ao_hash) { int length = (th->doff << 2) - sizeof(*th); const u8 *ptr = (const u8 *)(th + 1); + unsigned int minlen = TCPOLEN_MD5SIG; + + if (IS_ENABLED(CONFIG_TCP_AO)) + minlen = sizeof(struct tcp_ao_hdr) + 1; + + *md5_hash = NULL; + *ao_hash = NULL; /* If not enough data remaining, we can short cut */ - while (length >= TCPOLEN_MD5SIG) { + while (length >= minlen) { int opcode = *ptr++; int opsize; switch (opcode) { case TCPOPT_EOL: - return NULL; + return 0; case TCPOPT_NOP: length--; continue; default: opsize = *ptr++; if (opsize < 2 || opsize > length) - return NULL; - if (opcode == TCPOPT_MD5SIG) - return opsize == TCPOLEN_MD5SIG ? ptr : NULL; + return -EINVAL; + if (opcode == TCPOPT_MD5SIG) { + if (opsize != TCPOLEN_MD5SIG) + return -EINVAL; + if (unlikely(*md5_hash || *ao_hash)) + return -EEXIST; + *md5_hash = ptr; + } else if (opcode == TCPOPT_AO) { + if (opsize <= sizeof(struct tcp_ao_hdr)) + return -EINVAL; + if (unlikely(*md5_hash || *ao_hash)) + return -EEXIST; + *ao_hash = ptr; + } } ptr += opsize - 2; length -= opsize; } - return NULL; + return 0; } -EXPORT_SYMBOL(tcp_parse_md5sig_option); +EXPORT_SYMBOL(tcp_do_parse_auth_options); #endif /* Sorry, PAWS as specified is broken wrt. pure-ACKs -DaveM @@ -4233,33 +4544,57 @@ EXPORT_SYMBOL(tcp_parse_md5sig_option); * up to bandwidth of 18Gigabit/sec. 8) ] */ -static int tcp_disordered_ack(const struct sock *sk, const struct sk_buff *skb) +/* Estimates max number of increments of remote peer TSval in + * a replay window (based on our current RTO estimation). + */ +static u32 tcp_tsval_replay(const struct sock *sk) +{ + /* If we use usec TS resolution, + * then expect the remote peer to use the same resolution. + */ + if (tcp_sk(sk)->tcp_usec_ts) + return inet_csk(sk)->icsk_rto * (USEC_PER_SEC / HZ); + + /* RFC 7323 recommends a TSval clock between 1ms and 1sec. + * We know that some OS (including old linux) can use 1200 Hz. + */ + return inet_csk(sk)->icsk_rto * 1200 / HZ; +} + +static enum skb_drop_reason tcp_disordered_ack_check(const struct sock *sk, + const struct sk_buff *skb) { const struct tcp_sock *tp = tcp_sk(sk); const struct tcphdr *th = tcp_hdr(skb); - u32 seq = TCP_SKB_CB(skb)->seq; + SKB_DR_INIT(reason, TCP_RFC7323_PAWS); u32 ack = TCP_SKB_CB(skb)->ack_seq; + u32 seq = TCP_SKB_CB(skb)->seq; - return (/* 1. Pure ACK with correct sequence number. */ - (th->ack && seq == TCP_SKB_CB(skb)->end_seq && seq == tp->rcv_nxt) && + /* 1. Is this not a pure ACK ? */ + if (!th->ack || seq != TCP_SKB_CB(skb)->end_seq) + return reason; - /* 2. ... and duplicate ACK. */ - ack == tp->snd_una && + /* 2. Is its sequence not the expected one ? */ + if (seq != tp->rcv_nxt) + return before(seq, tp->rcv_nxt) ? + SKB_DROP_REASON_TCP_RFC7323_PAWS_ACK : + reason; - /* 3. ... and does not update window. */ - !tcp_may_update_window(tp, ack, seq, ntohs(th->window) << tp->rx_opt.snd_wscale) && + /* 3. Is this not a duplicate ACK ? */ + if (ack != tp->snd_una) + return reason; - /* 4. ... and sits in replay window. */ - (s32)(tp->rx_opt.ts_recent - tp->rx_opt.rcv_tsval) <= (inet_csk(sk)->icsk_rto * 1024) / HZ); -} + /* 4. Is this updating the window ? */ + if (tcp_may_update_window(tp, ack, seq, ntohs(th->window) << + tp->rx_opt.snd_wscale)) + return reason; -static inline bool tcp_paws_discard(const struct sock *sk, - const struct sk_buff *skb) -{ - const struct tcp_sock *tp = tcp_sk(sk); + /* 5. Is this not in the replay window ? */ + if ((s32)(tp->rx_opt.ts_recent - tp->rx_opt.rcv_tsval) > + tcp_tsval_replay(sk)) + return reason; - return !tcp_paws_check(&tp->rx_opt, TCP_PAWS_WINDOW) && - !tcp_disordered_ack(sk, skb); + return 0; } /* Check segment sequence number for validity. @@ -4275,15 +4610,46 @@ static inline bool tcp_paws_discard(const struct sock *sk, * (borrowed from freebsd) */ -static inline bool tcp_sequence(const struct tcp_sock *tp, u32 seq, u32 end_seq) +static enum skb_drop_reason tcp_sequence(const struct sock *sk, + u32 seq, u32 end_seq) { - return !before(end_seq, tp->rcv_wup) && - !after(seq, tp->rcv_nxt + tcp_receive_window(tp)); + const struct tcp_sock *tp = tcp_sk(sk); + + if (before(end_seq, tp->rcv_wup)) + return SKB_DROP_REASON_TCP_OLD_SEQUENCE; + + if (after(end_seq, tp->rcv_nxt + tcp_receive_window(tp))) { + if (after(seq, tp->rcv_nxt + tcp_receive_window(tp))) + return SKB_DROP_REASON_TCP_INVALID_SEQUENCE; + + /* Only accept this packet if receive queue is empty. */ + if (skb_queue_len(&sk->sk_receive_queue)) + return SKB_DROP_REASON_TCP_INVALID_END_SEQUENCE; + } + + return SKB_NOT_DROPPED_YET; +} + + +void tcp_done_with_error(struct sock *sk, int err) +{ + /* This barrier is coupled with smp_rmb() in tcp_poll() */ + WRITE_ONCE(sk->sk_err, err); + smp_wmb(); + + tcp_write_queue_purge(sk); + tcp_done(sk); + + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); } +EXPORT_IPV6_MOD(tcp_done_with_error); /* When we get a reset we do this. */ void tcp_reset(struct sock *sk, struct sk_buff *skb) { + int err; + trace_tcp_receive_reset(sk); /* mptcp can't tell us to ignore reset pkts, @@ -4295,24 +4661,17 @@ void tcp_reset(struct sock *sk, struct sk_buff *skb) /* We want the right error as BSD sees it (and indeed as we do). */ switch (sk->sk_state) { case TCP_SYN_SENT: - sk->sk_err = ECONNREFUSED; + err = ECONNREFUSED; break; case TCP_CLOSE_WAIT: - sk->sk_err = EPIPE; + err = EPIPE; break; case TCP_CLOSE: return; default: - sk->sk_err = ECONNRESET; + err = ECONNRESET; } - /* This barrier is coupled with smp_rmb() in tcp_poll() */ - smp_wmb(); - - tcp_write_queue_purge(sk); - tcp_done(sk); - - if (!sock_flag(sk, SOCK_DEAD)) - sk_error_report(sk); + tcp_done_with_error(sk, err); } /* @@ -4335,7 +4694,7 @@ void tcp_fin(struct sock *sk) inet_csk_schedule_ack(sk); - sk->sk_shutdown |= RCV_SHUTDOWN; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); sock_set_flag(sk, SOCK_DONE); switch (sk->sk_state) { @@ -4384,7 +4743,6 @@ void tcp_fin(struct sock *sk) skb_rbtree_purge(&tp->out_of_order_queue); if (tcp_is_sack(tp)) tcp_sack_reset(&tp->rx_opt); - sk_mem_reclaim(sk); if (!sock_flag(sk, SOCK_DEAD)) { sk->sk_state_change(sk); @@ -4415,7 +4773,7 @@ static void tcp_dsack_set(struct sock *sk, u32 seq, u32 end_seq) { struct tcp_sock *tp = tcp_sk(sk); - if (tcp_is_sack(tp) && sock_net(sk)->ipv4.sysctl_tcp_dsack) { + if (tcp_is_sack(tp) && READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_dsack)) { int mib_idx; if (before(seq, tp->rcv_nxt)) @@ -4445,12 +4803,23 @@ static void tcp_rcv_spurious_retrans(struct sock *sk, const struct sk_buff *skb) { /* When the ACK path fails or drops most ACKs, the sender would * timeout and spuriously retransmit the same segment repeatedly. - * The receiver remembers and reflects via DSACKs. Leverage the - * DSACK state and change the txhash to re-route speculatively. + * If it seems our ACKs are not reaching the other side, + * based on receiving a duplicate data segment with new flowlabel + * (suggesting the sender suffered an RTO), and we are not already + * repathing due to our own RTO, then rehash the socket to repath our + * packets. */ - if (TCP_SKB_CB(skb)->seq == tcp_sk(sk)->duplicate_sack[0].start_seq && +#if IS_ENABLED(CONFIG_IPV6) + if (inet_csk(sk)->icsk_ca_state != TCP_CA_Loss && + skb->protocol == htons(ETH_P_IPV6) && + (tcp_sk(sk)->inet_conn.icsk_ack.lrcv_flowlabel != + ntohl(ip6_flowlabel(ipv6_hdr(skb)))) && sk_rethink_txhash(sk)) NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDUPLICATEDATAREHASH); + + /* Save last flowlabel after a spurious retrans. */ + tcp_save_lrcv_flowlabel(sk, skb); +#endif } static void tcp_send_dupack(struct sock *sk, const struct sk_buff *skb) @@ -4462,7 +4831,7 @@ static void tcp_send_dupack(struct sock *sk, const struct sk_buff *skb) NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKLOST); tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); - if (tcp_is_sack(tp) && sock_net(sk)->ipv4.sysctl_tcp_dsack) { + if (tcp_is_sack(tp) && READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_dsack)) { u32 end_seq = TCP_SKB_CB(skb)->end_seq; tcp_rcv_spurious_retrans(sk, skb); @@ -4504,7 +4873,7 @@ static void tcp_sack_maybe_coalesce(struct tcp_sock *tp) } } -static void tcp_sack_compress_send_ack(struct sock *sk) +void tcp_sack_compress_send_ack(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); @@ -4638,14 +5007,9 @@ static bool tcp_try_coalesce(struct sock *sk, if (TCP_SKB_CB(from)->seq != TCP_SKB_CB(to)->end_seq) return false; - if (!mptcp_skb_can_collapse(to, from)) + if (!tcp_skb_can_collapse_rx(to, from)) return false; -#ifdef CONFIG_TLS_DEVICE - if (from->decrypted != to->decrypted) - return false; -#endif - if (!skb_try_coalesce(to, from, fragstolen, &delta)) return false; @@ -4672,7 +5036,7 @@ static bool tcp_ooo_try_coalesce(struct sock *sk, { bool res = tcp_try_coalesce(sk, to, from, fragstolen); - /* In case tcp_drop() is called later, update to->gso_segs */ + /* In case tcp_drop_reason() is called later, update to->gso_segs */ if (res) { u32 gso_segs = max_t(u16, 1, skb_shinfo(to)->gso_segs) + max_t(u16, 1, skb_shinfo(from)->gso_segs); @@ -4682,10 +5046,11 @@ static bool tcp_ooo_try_coalesce(struct sock *sk, return res; } -static void tcp_drop(struct sock *sk, struct sk_buff *skb) +noinline_for_tracing static void +tcp_drop_reason(struct sock *sk, struct sk_buff *skb, enum skb_drop_reason reason) { - sk_drops_add(sk, skb); - __kfree_skb(skb); + sk_drops_skbadd(sk, skb); + sk_skb_reason_drop(sk, skb, reason); } /* This one checks to see if we can put data from the @@ -4707,15 +5072,16 @@ static void tcp_ofo_queue(struct sock *sk) if (before(TCP_SKB_CB(skb)->seq, dsack_high)) { __u32 dsack = dsack_high; + if (before(TCP_SKB_CB(skb)->end_seq, dsack_high)) - dsack_high = TCP_SKB_CB(skb)->end_seq; + dsack = TCP_SKB_CB(skb)->end_seq; tcp_dsack_extend(sk, TCP_SKB_CB(skb)->seq, dsack); } p = rb_next(p); rb_erase(&skb->rbnode, &tp->out_of_order_queue); if (unlikely(!after(TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt))) { - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, SKB_DROP_REASON_TCP_OFO_DROP); continue; } @@ -4724,7 +5090,7 @@ static void tcp_ofo_queue(struct sock *sk) tcp_rcv_nxt_update(tp, TCP_SKB_CB(skb)->end_seq); fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN; if (!eaten) - __skb_queue_tail(&sk->sk_receive_queue, skb); + tcp_add_receive_queue(sk, skb); else kfree_skb_partial(skb, fragstolen); @@ -4738,20 +5104,41 @@ static void tcp_ofo_queue(struct sock *sk) } } -static bool tcp_prune_ofo_queue(struct sock *sk); -static int tcp_prune_queue(struct sock *sk); +static bool tcp_prune_ofo_queue(struct sock *sk, const struct sk_buff *in_skb); +static int tcp_prune_queue(struct sock *sk, const struct sk_buff *in_skb); -static int tcp_try_rmem_schedule(struct sock *sk, struct sk_buff *skb, +/* Check if this incoming skb can be added to socket receive queues + * while satisfying sk->sk_rcvbuf limit. + * + * In theory we should use skb->truesize, but this can cause problems + * when applications use too small SO_RCVBUF values. + * When LRO / hw gro is used, the socket might have a high tp->scaling_ratio, + * allowing RWIN to be close to available space. + * Whenever the receive queue gets full, we can receive a small packet + * filling RWIN, but with a high skb->truesize, because most NIC use 4K page + * plus sk_buff metadata even when receiving less than 1500 bytes of payload. + * + * Note that we use skb->len to decide to accept or drop this packet, + * but sk->sk_rmem_alloc is the sum of all skb->truesize. + */ +static bool tcp_can_ingest(const struct sock *sk, const struct sk_buff *skb) +{ + unsigned int rmem = atomic_read(&sk->sk_rmem_alloc); + + return rmem + skb->len <= sk->sk_rcvbuf; +} + +static int tcp_try_rmem_schedule(struct sock *sk, const struct sk_buff *skb, unsigned int size) { - if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf || + if (!tcp_can_ingest(sk, skb) || !sk_rmem_schedule(sk, skb, size)) { - if (tcp_prune_queue(sk) < 0) + if (tcp_prune_queue(sk, skb) < 0) return -1; while (!sk_rmem_schedule(sk, skb, size)) { - if (!tcp_prune_ofo_queue(sk)) + if (!tcp_prune_ofo_queue(sk, skb)) return -1; } } @@ -4766,15 +5153,17 @@ static void tcp_data_queue_ofo(struct sock *sk, struct sk_buff *skb) u32 seq, end_seq; bool fragstolen; - tcp_ecn_check_ce(sk, skb); + tcp_save_lrcv_flowlabel(sk, skb); + tcp_data_ecn_check(sk, skb); if (unlikely(tcp_try_rmem_schedule(sk, skb, skb->truesize))) { NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFODROP); sk->sk_data_ready(sk); - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, SKB_DROP_REASON_PROTO_MEM); return; } + tcp_measure_rcv_mss(sk, skb); /* Disable header prediction. */ tp->pred_flags = 0; inet_csk_schedule_ack(sk); @@ -4834,7 +5223,8 @@ coalesce_done: /* All the bits are present. Drop. */ NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFOMERGE); - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, + SKB_DROP_REASON_TCP_OFOMERGE); skb = NULL; tcp_dsack_set(sk, seq, end_seq); goto add_sack; @@ -4853,7 +5243,8 @@ coalesce_done: TCP_SKB_CB(skb1)->end_seq); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFOMERGE); - tcp_drop(sk, skb1); + tcp_drop_reason(sk, skb1, + SKB_DROP_REASON_TCP_OFOMERGE); goto merge_right; } } else if (tcp_ooo_try_coalesce(sk, skb1, @@ -4881,7 +5272,7 @@ merge_right: tcp_dsack_extend(sk, TCP_SKB_CB(skb1)->seq, TCP_SKB_CB(skb1)->end_seq); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFOMERGE); - tcp_drop(sk, skb1); + tcp_drop_reason(sk, skb1, SKB_DROP_REASON_TCP_OFOMERGE); } /* If there is no skb after us, we are the last_skb ! */ if (!skb1) @@ -4900,6 +5291,9 @@ end: skb_condense(skb); skb_set_owner_r(skb, sk); } + /* do not grow rcvbuf for not-yet-accepted or orphaned sockets. */ + if (sk->sk_socket) + tcp_rcvbuf_grow(sk, tp->rcvq_space.space); } static int __must_check tcp_queue_rcv(struct sock *sk, struct sk_buff *skb, @@ -4913,7 +5307,7 @@ static int __must_check tcp_queue_rcv(struct sock *sk, struct sk_buff *skb, skb, fragstolen)) ? 1 : 0; tcp_rcv_nxt_update(tcp_sk(sk), TCP_SKB_CB(skb)->end_seq); if (!eaten) { - __skb_queue_tail(&sk->sk_receive_queue, skb); + tcp_add_receive_queue(sk, skb); skb_set_owner_r(skb, sk); } return eaten; @@ -4980,6 +5374,7 @@ void tcp_data_ready(struct sock *sk) static void tcp_data_queue(struct sock *sk, struct sk_buff *skb) { struct tcp_sock *tp = tcp_sk(sk); + enum skb_drop_reason reason; bool fragstolen; int eaten; @@ -4995,9 +5390,10 @@ static void tcp_data_queue(struct sock *sk, struct sk_buff *skb) __kfree_skb(skb); return; } - skb_dst_drop(skb); + tcp_cleanup_skb(skb); __skb_pull(skb, tcp_hdr(skb)->doff * 4); + reason = SKB_DROP_REASON_NOT_SPECIFIED; tp->rx_opt.dsack = 0; /* Queue data for delivery to the user. @@ -5006,18 +5402,36 @@ static void tcp_data_queue(struct sock *sk, struct sk_buff *skb) */ if (TCP_SKB_CB(skb)->seq == tp->rcv_nxt) { if (tcp_receive_window(tp) == 0) { + /* Some stacks are known to send bare FIN packets + * in a loop even if we send RWIN 0 in our ACK. + * Accepting this FIN does not hurt memory pressure + * because the FIN flag will simply be merged to the + * receive queue tail skb in most cases. + */ + if (!skb->len && + (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)) + goto queue_and_out; + + reason = SKB_DROP_REASON_TCP_ZEROWINDOW; NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPZEROWINDOWDROP); goto out_of_window; } /* Ok. In sequence. In window. */ queue_and_out: - if (skb_queue_len(&sk->sk_receive_queue) == 0) - sk_forced_mem_schedule(sk, skb->truesize); - else if (tcp_try_rmem_schedule(sk, skb, skb->truesize)) { - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRCVQDROP); + if (tcp_try_rmem_schedule(sk, skb, skb->truesize)) { + /* TODO: maybe ratelimit these WIN 0 ACK ? */ + inet_csk(sk)->icsk_ack.pending |= + (ICSK_ACK_NOMEM | ICSK_ACK_NOW); + inet_csk_schedule_ack(sk); sk->sk_data_ready(sk); - goto drop; + + if (skb_queue_len(&sk->sk_receive_queue) && skb->len) { + reason = SKB_DROP_REASON_PROTO_MEM; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRCVQDROP); + goto drop; + } + sk_forced_mem_schedule(sk, skb->truesize); } eaten = tcp_queue_rcv(sk, skb, &fragstolen); @@ -5051,6 +5465,7 @@ queue_and_out: if (!after(TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt)) { tcp_rcv_spurious_retrans(sk, skb); /* A retransmit, 2nd most common case. Force an immediate ack. */ + reason = SKB_DROP_REASON_TCP_OLD_DATA; NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKLOST); tcp_dsack_set(sk, TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq); @@ -5058,13 +5473,16 @@ out_of_window: tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); inet_csk_schedule_ack(sk); drop: - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, reason); return; } /* Out of window. F.e. zero window probe. */ - if (!before(TCP_SKB_CB(skb)->seq, tp->rcv_nxt + tcp_receive_window(tp))) + if (!before(TCP_SKB_CB(skb)->seq, + tp->rcv_nxt + tcp_receive_window(tp))) { + reason = SKB_DROP_REASON_TCP_OVERWINDOW; goto out_of_window; + } if (before(TCP_SKB_CB(skb)->seq, tp->rcv_nxt)) { /* Partial packet, seq < rcv_next < end_seq */ @@ -5074,6 +5492,7 @@ drop: * remembering D-SACK for its head made in previous line. */ if (!tcp_receive_window(tp)) { + reason = SKB_DROP_REASON_TCP_ZEROWINDOW; NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPZEROWINDOWDROP); goto out_of_window; } @@ -5150,6 +5569,9 @@ restart: for (end_of_skbs = true; skb != NULL && skb != tail; skb = n) { n = tcp_skb_next(skb, list); + if (!skb_frags_readable(skb)) + goto skip_this; + /* No new bits? It is possible on ofo queue. */ if (!before(start, TCP_SKB_CB(skb)->end_seq)) { skb = tcp_collapse_one(sk, skb, list, root); @@ -5170,17 +5592,20 @@ restart: break; } - if (n && n != tail && mptcp_skb_can_collapse(skb, n) && + if (n && n != tail && skb_frags_readable(n) && + tcp_skb_can_collapse_rx(skb, n) && TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(n)->seq) { end_of_skbs = false; break; } +skip_this: /* Decided to skip this, advance start seq. */ start = TCP_SKB_CB(skb)->end_seq; } if (end_of_skbs || - (TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN))) + (TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN)) || + !skb_frags_readable(skb)) return; __skb_queue_head_init(&tmp); @@ -5194,9 +5619,7 @@ restart: break; memcpy(nskb->cb, skb->cb, sizeof(skb->cb)); -#ifdef CONFIG_TLS_DEVICE - nskb->decrypted = skb->decrypted; -#endif + skb_copy_decrypted(nskb, skb); TCP_SKB_CB(nskb)->seq = TCP_SKB_CB(nskb)->end_seq = start; if (list) __skb_queue_before(list, skb, nskb); @@ -5223,13 +5646,10 @@ restart: skb = tcp_collapse_one(sk, skb, list, root); if (!skb || skb == tail || - !mptcp_skb_can_collapse(nskb, skb) || - (TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN))) - goto end; -#ifdef CONFIG_TLS_DEVICE - if (skb->decrypted != nskb->decrypted) + !tcp_skb_can_collapse_rx(nskb, skb) || + (TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN)) || + !skb_frags_readable(skb)) goto end; -#endif } } } @@ -5269,7 +5689,7 @@ new_range: before(TCP_SKB_CB(skb)->end_seq, start)) { /* Do not attempt collapsing tiny skbs */ if (range_truesize != head->truesize || - end - start >= SKB_WITH_OVERHEAD(SK_MEM_QUANTUM)) { + end - start >= SKB_WITH_OVERHEAD(PAGE_SIZE)) { tcp_collapse(sk, NULL, &tp->out_of_order_queue, head, skb, start, end); } else { @@ -5292,6 +5712,8 @@ new_range: * Clean the out-of-order queue to make room. * We drop high sequences packets to : * 1) Let a chance for holes to be filled. + * This means we do not drop packets from ooo queue if their sequence + * is before incoming packet sequence. * 2) not add too big latencies if thousands of packets sit there. * (But if application shrinks SO_RCVBUF, we could still end up * freeing whole queue here) @@ -5299,42 +5721,51 @@ new_range: * * Return true if queue has shrunk. */ -static bool tcp_prune_ofo_queue(struct sock *sk) +static bool tcp_prune_ofo_queue(struct sock *sk, const struct sk_buff *in_skb) { struct tcp_sock *tp = tcp_sk(sk); struct rb_node *node, *prev; + bool pruned = false; int goal; if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) return false; - NET_INC_STATS(sock_net(sk), LINUX_MIB_OFOPRUNED); goal = sk->sk_rcvbuf >> 3; node = &tp->ooo_last_skb->rbnode; + do { + struct sk_buff *skb = rb_to_skb(node); + + /* If incoming skb would land last in ofo queue, stop pruning. */ + if (after(TCP_SKB_CB(in_skb)->seq, TCP_SKB_CB(skb)->seq)) + break; + pruned = true; prev = rb_prev(node); rb_erase(node, &tp->out_of_order_queue); - goal -= rb_to_skb(node)->truesize; - tcp_drop(sk, rb_to_skb(node)); + goal -= skb->truesize; + tcp_drop_reason(sk, skb, SKB_DROP_REASON_TCP_OFO_QUEUE_PRUNE); + tp->ooo_last_skb = rb_to_skb(prev); if (!prev || goal <= 0) { - sk_mem_reclaim(sk); - if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf && + if (tcp_can_ingest(sk, in_skb) && !tcp_under_memory_pressure(sk)) break; goal = sk->sk_rcvbuf >> 3; } node = prev; } while (node); - tp->ooo_last_skb = rb_to_skb(prev); - /* Reset SACK state. A conforming SACK implementation will - * do the same at a timeout based retransmit. When a connection - * is in a sad state like this, we care only about integrity - * of the connection not performance. - */ - if (tp->rx_opt.sack_ok) - tcp_sack_reset(&tp->rx_opt); - return true; + if (pruned) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_OFOPRUNED); + /* Reset SACK state. A conforming SACK implementation will + * do the same at a timeout based retransmit. When a connection + * is in a sad state like this, we care only about integrity + * of the connection not performance. + */ + if (tp->rx_opt.sack_ok) + tcp_sack_reset(&tp->rx_opt); + } + return pruned; } /* Reduce allocated memory if we can, trying to get @@ -5344,18 +5775,22 @@ static bool tcp_prune_ofo_queue(struct sock *sk) * until the socket owning process reads some of the data * to stabilize the situation. */ -static int tcp_prune_queue(struct sock *sk) +static int tcp_prune_queue(struct sock *sk, const struct sk_buff *in_skb) { struct tcp_sock *tp = tcp_sk(sk); + /* Do nothing if our queues are empty. */ + if (!atomic_read(&sk->sk_rmem_alloc)) + return -1; + NET_INC_STATS(sock_net(sk), LINUX_MIB_PRUNECALLED); - if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) + if (!tcp_can_ingest(sk, in_skb)) tcp_clamp_window(sk); else if (tcp_under_memory_pressure(sk)) tcp_adjust_rcv_ssthresh(sk); - if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + if (tcp_can_ingest(sk, in_skb)) return 0; tcp_collapse_ofo_queue(sk); @@ -5364,17 +5799,16 @@ static int tcp_prune_queue(struct sock *sk) skb_peek(&sk->sk_receive_queue), NULL, tp->copied_seq, tp->rcv_nxt); - sk_mem_reclaim(sk); - if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + if (tcp_can_ingest(sk, in_skb)) return 0; /* Collapsing did not help, destructive actions follow. * This must not ever occur. */ - tcp_prune_ofo_queue(sk); + tcp_prune_ofo_queue(sk, in_skb); - if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + if (tcp_can_ingest(sk, in_skb)) return 0; /* If we are really being abused, tell the caller to silently @@ -5417,7 +5851,7 @@ static bool tcp_should_expand_sndbuf(struct sock *sk) return false; /* If we filled the congestion window, do not expand. */ - if (tcp_packets_in_flight(tp) >= tp->snd_cwnd) + if (tcp_packets_in_flight(tp) >= tcp_snd_cwnd(tp)) return false; return true; @@ -5435,7 +5869,17 @@ static void tcp_new_space(struct sock *sk) INDIRECT_CALL_1(sk->sk_write_space, sk_stream_write_space, sk); } -static void tcp_check_space(struct sock *sk) +/* Caller made space either from: + * 1) Freeing skbs in rtx queues (after tp->snd_una has advanced) + * 2) Sent skbs from output queue (and thus advancing tp->snd_nxt) + * + * We might be able to generate EPOLLOUT to the application if: + * 1) Space consumed in output/rtx queues is below sk->sk_sndbuf/2 + * 2) notsent amount (tp->write_seq - tp->snd_nxt) became + * small enough that tcp_stream_memory_free() decides it + * is time to generate EPOLLOUT. + */ +void tcp_check_space(struct sock *sk) { /* pairs with tcp_poll() */ smp_mb(); @@ -5459,7 +5903,9 @@ static inline void tcp_data_snd_check(struct sock *sk) static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) { struct tcp_sock *tp = tcp_sk(sk); - unsigned long rtt, delay; + struct net *net = sock_net(sk); + unsigned long rtt; + u64 delay; /* More than one full frame received... */ if (((tp->rcv_nxt - tp->rcv_wup) > inet_csk(sk)->icsk_ack.rcv_mss && @@ -5474,6 +5920,14 @@ static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) tcp_in_quickack_mode(sk) || /* Protocol state mandates a one-time immediate ACK */ inet_csk(sk)->icsk_ack.pending & ICSK_ACK_NOW) { + /* If we are running from __release_sock() in user context, + * Defer the ack until tcp_release_cb(). + */ + if (sock_owned_by_user_nocheck(sk) && + READ_ONCE(net->ipv4.sysctl_tcp_backlog_ack_defer)) { + set_bit(TCP_ACK_DEFERRED, &sk->sk_tsq_flags); + return; + } send_now: tcp_send_ack(sk); return; @@ -5485,7 +5939,7 @@ send_now: } if (!tcp_is_sack(tp) || - tp->compressed_ack >= sock_net(sk)->ipv4.sysctl_tcp_comp_sack_nr) + tp->compressed_ack >= READ_ONCE(net->ipv4.sysctl_tcp_comp_sack_nr)) goto send_now; if (tp->compressed_ack_rcv_nxt != tp->rcv_nxt) { @@ -5500,17 +5954,26 @@ send_now: if (hrtimer_is_queued(&tp->compressed_ack_timer)) return; - /* compress ack timer : 5 % of rtt, but no more than tcp_comp_sack_delay_ns */ + /* compress ack timer : comp_sack_rtt_percent of rtt, + * but no more than tcp_comp_sack_delay_ns. + */ rtt = tp->rcv_rtt_est.rtt_us; if (tp->srtt_us && tp->srtt_us < rtt) rtt = tp->srtt_us; - delay = min_t(unsigned long, sock_net(sk)->ipv4.sysctl_tcp_comp_sack_delay_ns, - rtt * (NSEC_PER_USEC >> 3)/20); + /* delay = (rtt >> 3) * NSEC_PER_USEC * comp_sack_rtt_percent / 100 + * -> + * delay = rtt * 1.25 * comp_sack_rtt_percent + */ + delay = (u64)(rtt + (rtt >> 2)) * + READ_ONCE(net->ipv4.sysctl_tcp_comp_sack_rtt_percent); + + delay = min(delay, READ_ONCE(net->ipv4.sysctl_tcp_comp_sack_delay_ns)); + sock_hold(sk); hrtimer_start_range_ns(&tp->compressed_ack_timer, ns_to_ktime(delay), - sock_net(sk)->ipv4.sysctl_tcp_comp_sack_slack_ns, + READ_ONCE(net->ipv4.sysctl_tcp_comp_sack_slack_ns), HRTIMER_MODE_REL_PINNED_SOFT); } @@ -5538,7 +6001,7 @@ static void tcp_check_urg(struct sock *sk, const struct tcphdr *th) struct tcp_sock *tp = tcp_sk(sk); u32 ptr = ntohs(th->urg_ptr); - if (ptr && !sock_net(sk)->ipv4.sysctl_tcp_stdurg) + if (ptr && !READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_stdurg)) ptr--; ptr += ntohl(th->seq); @@ -5634,7 +6097,7 @@ static void tcp_urg(struct sock *sk, struct sk_buff *skb, const struct tcphdr *t */ static bool tcp_reset_check(const struct sock *sk, const struct sk_buff *skb) { - struct tcp_sock *tp = tcp_sk(sk); + const struct tcp_sock *tp = tcp_sk(sk); return unlikely(TCP_SKB_CB(skb)->seq == (tp->rcv_nxt - 1) && (1 << sk->sk_state) & (TCPF_CLOSE_WAIT | TCPF_LAST_ACK | @@ -5648,25 +6111,42 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb, const struct tcphdr *th, int syn_inerr) { struct tcp_sock *tp = tcp_sk(sk); - bool rst_seq_match = false; + bool accecn_reflector = false; + SKB_DR(reason); /* RFC1323: H1. Apply PAWS check first. */ - if (tcp_fast_parse_options(sock_net(sk), skb, th, tp) && - tp->rx_opt.saw_tstamp && - tcp_paws_discard(sk, skb)) { - if (!th->rst) { - NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSESTABREJECTED); - if (!tcp_oow_rate_limited(sock_net(sk), skb, - LINUX_MIB_TCPACKSKIPPEDPAWS, - &tp->last_oow_ack_time)) - tcp_send_dupack(sk, skb); - goto discard; - } - /* Reset is accepted even if it did not pass PAWS. */ + if (!tcp_fast_parse_options(sock_net(sk), skb, th, tp) || + !tp->rx_opt.saw_tstamp || + tcp_paws_check(&tp->rx_opt, TCP_PAWS_WINDOW)) + goto step1; + + reason = tcp_disordered_ack_check(sk, skb); + if (!reason) + goto step1; + /* Reset is accepted even if it did not pass PAWS. */ + if (th->rst) + goto step1; + if (unlikely(th->syn)) + goto syn_challenge; + + /* Old ACK are common, increment PAWS_OLD_ACK + * and do not send a dupack. + */ + if (reason == SKB_DROP_REASON_TCP_RFC7323_PAWS_ACK) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWS_OLD_ACK); + goto discard; } + NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSESTABREJECTED); + if (!tcp_oow_rate_limited(sock_net(sk), skb, + LINUX_MIB_TCPACKSKIPPEDPAWS, + &tp->last_oow_ack_time)) + tcp_send_dupack(sk, skb); + goto discard; +step1: /* Step 1: check sequence number */ - if (!tcp_sequence(tp, TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq)) { + reason = tcp_sequence(sk, TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq); + if (reason) { /* RFC793, page 37: "In all states except SYN-SENT, all reset * (RST) segments are validated by checking their SEQ-fields." * And page 69: "If an incoming segment is not acceptable, @@ -5676,12 +6156,17 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb, if (!th->rst) { if (th->syn) goto syn_challenge; + + if (reason == SKB_DROP_REASON_TCP_INVALID_SEQUENCE || + reason == SKB_DROP_REASON_TCP_INVALID_END_SEQUENCE) + NET_INC_STATS(sock_net(sk), + LINUX_MIB_BEYOND_WINDOW); if (!tcp_oow_rate_limited(sock_net(sk), skb, LINUX_MIB_TCPACKSKIPPEDSEQ, &tp->last_oow_ack_time)) tcp_send_dupack(sk, skb); } else if (tcp_reset_check(sk, skb)) { - tcp_reset(sk, skb); + goto reset; } goto discard; } @@ -5698,9 +6183,10 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb, * Send a challenge ACK */ if (TCP_SKB_CB(skb)->seq == tp->rcv_nxt || - tcp_reset_check(sk, skb)) { - rst_seq_match = true; - } else if (tcp_is_sack(tp) && tp->rx_opt.num_sacks > 0) { + tcp_reset_check(sk, skb)) + goto reset; + + if (tcp_is_sack(tp) && tp->rx_opt.num_sacks > 0) { struct tcp_sack_block *sp = &tp->selective_acks[0]; int max_sack = sp[0].end_seq; int this_sack; @@ -5713,21 +6199,18 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb, } if (TCP_SKB_CB(skb)->seq == max_sack) - rst_seq_match = true; + goto reset; } - if (rst_seq_match) - tcp_reset(sk, skb); - else { - /* Disable TFO if RST is out-of-order - * and no data has been received - * for current active TFO socket - */ - if (tp->syn_fastopen && !tp->data_segs_in && - sk->sk_state == TCP_ESTABLISHED) - tcp_fastopen_active_disable(sk); - tcp_send_challenge_ack(sk); - } + /* Disable TFO if RST is out-of-order + * and no data has been received + * for current active TFO socket + */ + if (tp->syn_fastopen && !tp->data_segs_in && + sk->sk_state == TCP_ESTABLISHED) + tcp_fastopen_active_disable(sk); + tcp_send_challenge_ack(sk, false); + SKB_DR_SET(reason, TCP_RESET); goto discard; } @@ -5737,20 +6220,42 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb, * RFC 5961 4.2 : Send a challenge ack */ if (th->syn) { + if (tcp_ecn_mode_accecn(tp)) { + accecn_reflector = true; + if (tp->rx_opt.accecn && + tp->saw_accecn_opt < TCP_ACCECN_OPT_COUNTER_SEEN) { + u8 saw_opt = tcp_accecn_option_init(skb, tp->rx_opt.accecn); + + tcp_accecn_saw_opt_fail_recv(tp, saw_opt); + tcp_accecn_opt_demand_min(sk, 1); + } + } + if (sk->sk_state == TCP_SYN_RECV && sk->sk_socket && th->ack && + TCP_SKB_CB(skb)->seq + 1 == TCP_SKB_CB(skb)->end_seq && + TCP_SKB_CB(skb)->seq + 1 == tp->rcv_nxt && + TCP_SKB_CB(skb)->ack_seq == tp->snd_nxt) + goto pass; syn_challenge: if (syn_inerr) TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSYNCHALLENGE); - tcp_send_challenge_ack(sk); + tcp_send_challenge_ack(sk, accecn_reflector); + SKB_DR_SET(reason, TCP_INVALID_SYN); goto discard; } +pass: bpf_skops_parse_hdr(sk, skb); return true; discard: - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, reason); + return false; + +reset: + tcp_reset(sk, skb); + __kfree_skb(skb); return false; } @@ -5779,6 +6284,7 @@ discard: */ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) { + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; const struct tcphdr *th = (const struct tcphdr *)skb->data; struct tcp_sock *tp = tcp_sk(sk); unsigned int len = skb->len; @@ -5805,6 +6311,7 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) */ tp->rx_opt.saw_tstamp = 0; + tp->rx_opt.accecn = 0; /* pred_flags is 0xS?10 << 16 + snd_wnd * if header_prediction is to be made @@ -5819,6 +6326,8 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) TCP_SKB_CB(skb)->seq == tp->rcv_nxt && !after(TCP_SKB_CB(skb)->ack_seq, tp->snd_nxt)) { int tcp_header_len = tp->tcp_header_len; + s32 delta = 0; + int flag = 0; /* Timestamp header prediction: tcp_header_len * is automatically equal to th->doff*4 due to pred_flags @@ -5831,8 +6340,10 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) if (!tcp_parse_aligned_timestamp(tp, th)) goto slow_path; + delta = tp->rx_opt.rcv_tsval - + tp->rx_opt.ts_recent; /* If PAWS failed, check it more carefully in slow path */ - if ((s32)(tp->rx_opt.rcv_tsval - tp->rx_opt.ts_recent) < 0) + if (delta < 0) goto slow_path; /* DO NOT update ts_recent here, if checksum fails @@ -5852,12 +6363,15 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) if (tcp_header_len == (sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED) && tp->rcv_nxt == tp->rcv_wup) - tcp_store_ts_recent(tp); + flag |= __tcp_replace_ts_recent(tp, + delta); + + tcp_ecn_received_counters(sk, skb, 0); /* We know that such packets are checksummed * on entry. */ - tcp_ack(sk, skb, 0); + tcp_ack(sk, skb, flag); __kfree_skb(skb); tcp_data_snd_check(sk); /* When receiving pure ack in fast path, update @@ -5867,6 +6381,7 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) tp->rcv_rtt_last_tsecr = tp->rx_opt.rcv_tsecr; return; } else { /* Header too small */ + reason = SKB_DROP_REASON_PKT_TOO_SMALL; TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); goto discard; } @@ -5877,6 +6392,10 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) if (tcp_checksum_complete(skb)) goto csum_error; + if (after(TCP_SKB_CB(skb)->end_seq, + tp->rcv_nxt + tcp_receive_window(tp))) + goto validate; + if ((int)skb->truesize > sk->sk_forward_alloc) goto step5; @@ -5887,21 +6406,25 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) if (tcp_header_len == (sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED) && tp->rcv_nxt == tp->rcv_wup) - tcp_store_ts_recent(tp); + flag |= __tcp_replace_ts_recent(tp, + delta); tcp_rcv_rtt_measure_ts(sk, skb); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPHPHITS); /* Bulk data transfer: receiver */ + tcp_cleanup_skb(skb); __skb_pull(skb, tcp_header_len); + tcp_ecn_received_counters(sk, skb, + len - tcp_header_len); eaten = tcp_queue_rcv(sk, skb, &fragstolen); tcp_event_data_recv(sk, skb); if (TCP_SKB_CB(skb)->ack_seq != tp->snd_una) { /* Well, only one small jumplet in fast path... */ - tcp_ack(sk, skb, FLAG_DATA); + tcp_ack(sk, skb, flag | FLAG_DATA); tcp_data_snd_check(sk); if (!inet_csk_ack_scheduled(sk)) goto no_ack; @@ -5922,20 +6445,26 @@ slow_path: if (len < (th->doff << 2) || tcp_checksum_complete(skb)) goto csum_error; - if (!th->ack && !th->rst && !th->syn) + if (!th->ack && !th->rst && !th->syn) { + reason = SKB_DROP_REASON_TCP_FLAGS; goto discard; + } /* * Standard slow path. */ - +validate: if (!tcp_validate_incoming(sk, skb, th, 1)) return; step5: - if (tcp_ack(sk, skb, FLAG_SLOWPATH | FLAG_UPDATE_TS_RECENT) < 0) - goto discard; + tcp_ecn_received_counters_payload(sk, skb); + reason = tcp_ack(sk, skb, FLAG_SLOWPATH | FLAG_UPDATE_TS_RECENT); + if ((int)reason < 0) { + reason = -reason; + goto discard; + } tcp_rcv_rtt_measure_ts(sk, skb); /* Process urgent data. */ @@ -5949,14 +6478,15 @@ step5: return; csum_error: + reason = SKB_DROP_REASON_TCP_CSUM; trace_tcp_bad_csum(skb); TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); discard: - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, reason); } -EXPORT_SYMBOL(tcp_rcv_established); +EXPORT_IPV6_MOD(tcp_rcv_established); void tcp_init_transfer(struct sock *sk, int bpf_op, struct sk_buff *skb) { @@ -5974,9 +6504,9 @@ void tcp_init_transfer(struct sock *sk, int bpf_op, struct sk_buff *skb) * retransmission has occurred. */ if (tp->total_retrans > 1 && tp->undo_marker) - tp->snd_cwnd = 1; + tcp_snd_cwnd_set(tp, 1); else - tp->snd_cwnd = tcp_init_cwnd(tp, __sk_dst_get(sk)); + tcp_snd_cwnd_set(tp, tcp_init_cwnd(tp, __sk_dst_get(sk))); tp->snd_cwnd_stamp = tcp_jiffies32; bpf_skops_established(sk, bpf_op, skb); @@ -5991,6 +6521,7 @@ void tcp_finish_connect(struct sock *sk, struct sk_buff *skb) struct tcp_sock *tp = tcp_sk(sk); struct inet_connection_sock *icsk = inet_csk(sk); + tcp_ao_finish_connect(sk, skb); tcp_set_state(sk, TCP_ESTABLISHED); icsk->icsk_ack.lrcvtime = tcp_jiffies32; @@ -6008,7 +6539,7 @@ void tcp_finish_connect(struct sock *sk, struct sk_buff *skb) tp->lsndtime = tcp_jiffies32; if (sock_flag(sk, SOCK_KEEPOPEN)) - inet_csk_reset_keepalive_timer(sk, keepalive_time_when(tp)); + tcp_reset_keepalive_timer(sk, keepalive_time_when(tp)); if (!tp->rx_opt.snd_wscale) __tcp_fast_path_on(tp, tp->snd_wnd); @@ -6024,7 +6555,7 @@ static bool tcp_rcv_fastopen_synack(struct sock *sk, struct sk_buff *synack, u16 mss = tp->rx_opt.mss_clamp, try_exp = 0; bool syn_drop = false; - if (mss == tp->rx_opt.user_mss) { + if (mss == READ_ONCE(tp->rx_opt.user_mss)) { struct tcp_options_received opt; /* Get original SYNACK MSS value if user MSS sets mss_clamp */ @@ -6061,7 +6592,7 @@ static bool tcp_rcv_fastopen_synack(struct sock *sk, struct sk_buff *synack, tp->fastopen_client_fail = TFO_DATA_NOT_ACKED; skb_rbtree_walk_from(data) tcp_mark_skb_lost(sk, data); - tcp_xmit_retransmit_queue(sk); + tcp_non_congestion_loss_retransmit(sk); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPFASTOPENACTIVEFAIL); return true; @@ -6112,6 +6643,7 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, struct tcp_fastopen_cookie foc = { .len = -1 }; int saved_clamp = tp->rx_opt.mss_clamp; bool fastopen_fail; + SKB_DR(reason); tcp_parse_options(sock_net(sk), skb, &tp->rx_opt, 0, &foc); if (tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr) @@ -6130,17 +6662,18 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, after(TCP_SKB_CB(skb)->ack_seq, tp->snd_nxt)) { /* Previous FIN/ACK or RST/ACK might be ignored. */ if (icsk->icsk_retransmits == 0) - inet_csk_reset_xmit_timer(sk, - ICSK_TIME_RETRANS, - TCP_TIMEOUT_MIN, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + TCP_TIMEOUT_MIN, false); + SKB_DR_SET(reason, TCP_INVALID_ACK_SEQUENCE); goto reset_and_undo; } if (tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr && !between(tp->rx_opt.rcv_tsecr, tp->retrans_stamp, - tcp_time_stamp(tp))) { + tcp_time_stamp_ts(tp))) { NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSACTIVEREJECTED); + SKB_DR_SET(reason, TCP_RFC7323_PAWS); goto reset_and_undo; } @@ -6154,7 +6687,9 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, if (th->rst) { tcp_reset(sk, skb); - goto discard; +consume: + __kfree_skb(skb); + return 0; } /* rfc793: @@ -6164,9 +6699,10 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, * See note below! * --ANK(990513) */ - if (!th->syn) + if (!th->syn) { + SKB_DR_SET(reason, TCP_FLAGS); goto discard_and_undo; - + } /* rfc793: * "If the SYN bit is on ... * are acceptable then ... @@ -6174,7 +6710,9 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, * state to ESTABLISHED..." */ - tcp_ecn_rcv_synack(tp, th); + if (tcp_ecn_mode_any(tp)) + tcp_ecn_rcv_synack(sk, skb, th, + TCP_SKB_CB(skb)->ip_dsfield); tcp_init_wl(tp, TCP_SKB_CB(skb)->seq); tcp_try_undo_spurious_syn(sk); @@ -6193,7 +6731,8 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, if (!tp->rx_opt.wscale_ok) { tp->rx_opt.snd_wscale = tp->rx_opt.rcv_wscale = 0; - tp->window_clamp = min(tp->window_clamp, 65535U); + WRITE_ONCE(tp->window_clamp, + min(tp->window_clamp, 65535U)); } if (tp->rx_opt.saw_tstamp) { @@ -6230,7 +6769,7 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, if (fastopen_fail) return -1; if (sk->sk_write_pending || - icsk->icsk_accept_queue.rskq_defer_accept || + READ_ONCE(icsk->icsk_accept_queue.rskq_defer_accept) || inet_csk_in_pingpong_mode(sk)) { /* Save one ACK. Data will be ready after * several ticks, if write_pending is set. @@ -6241,15 +6780,11 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, */ inet_csk_schedule_ack(sk); tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); - inet_csk_reset_xmit_timer(sk, ICSK_TIME_DACK, - TCP_DELACK_MAX, TCP_RTO_MAX); - -discard: - tcp_drop(sk, skb); - return 0; - } else { - tcp_send_ack(sk); + tcp_reset_xmit_timer(sk, ICSK_TIME_DACK, + TCP_DELACK_MAX, false); + goto consume; } + tcp_send_ack_reflect_ect(sk, tcp_ecn_mode_accecn(tp)); return -1; } @@ -6261,20 +6796,31 @@ discard: * * Otherwise (no ACK) drop the segment and return." */ - + SKB_DR_SET(reason, TCP_RESET); goto discard_and_undo; } /* PAWS check. */ if (tp->rx_opt.ts_recent_stamp && tp->rx_opt.saw_tstamp && - tcp_paws_reject(&tp->rx_opt, 0)) + tcp_paws_reject(&tp->rx_opt, 0)) { + SKB_DR_SET(reason, TCP_RFC7323_PAWS); goto discard_and_undo; - + } if (th->syn) { /* We see SYN without ACK. It is attempt of * simultaneous connect with crossed SYNs. * Particularly, it can be connect to self. */ +#ifdef CONFIG_TCP_AO + struct tcp_ao_info *ao; + + ao = rcu_dereference_protected(tp->ao_info, + lockdep_sock_is_held(sk)); + if (ao) { + WRITE_ONCE(ao->risn, th->seq); + ao->rcv_sne = 0; + } +#endif tcp_set_state(sk, TCP_SYN_RECV); if (tp->rx_opt.saw_tstamp) { @@ -6297,7 +6843,7 @@ discard: tp->snd_wl1 = TCP_SKB_CB(skb)->seq; tp->max_window = tp->snd_wnd; - tcp_ecn_rcv_syn(tp, th); + tcp_ecn_rcv_syn(tp, th, skb); tcp_mtup_init(sk); tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); @@ -6318,7 +6864,7 @@ discard: */ return -1; #else - goto discard; + goto consume; #endif } /* "fifth, if neither of the SYN or RST bits is set then @@ -6328,32 +6874,43 @@ discard: discard_and_undo: tcp_clear_options(&tp->rx_opt); tp->rx_opt.mss_clamp = saved_clamp; - goto discard; + tcp_drop_reason(sk, skb, reason); + return 0; reset_and_undo: tcp_clear_options(&tp->rx_opt); tp->rx_opt.mss_clamp = saved_clamp; - return 1; + /* we can reuse/return @reason to its caller to handle the exception */ + return reason; } static void tcp_rcv_synrecv_state_fastopen(struct sock *sk) { + struct tcp_sock *tp = tcp_sk(sk); struct request_sock *req; /* If we are still handling the SYNACK RTO, see if timestamp ECR allows * undo. If peer SACKs triggered fast recovery, we can't undo here. */ - if (inet_csk(sk)->icsk_ca_state == TCP_CA_Loss) - tcp_try_undo_loss(sk, false); + if (inet_csk(sk)->icsk_ca_state == TCP_CA_Loss && !tp->packets_out) + tcp_try_undo_recovery(sk); - /* Reset rtx states to prevent spurious retransmits_timed_out() */ - tcp_sk(sk)->retrans_stamp = 0; - inet_csk(sk)->icsk_retransmits = 0; + tcp_update_rto_time(tp); + WRITE_ONCE(inet_csk(sk)->icsk_retransmits, 0); + /* In tcp_fastopen_synack_timer() on the first SYNACK RTO we set + * retrans_stamp but don't enter CA_Loss, so in case that happened we + * need to zero retrans_stamp here to prevent spurious + * retransmits_timed_out(). However, if the ACK of our SYNACK caused us + * to enter CA_Recovery then we need to leave retrans_stamp as it was + * set entering CA_Recovery, for correct retransmits_timed_out() and + * undo behavior. + */ + tcp_retrans_stamp_cleanup(sk); /* Once we leave TCP_SYN_RECV or TCP_FIN_WAIT_1, * we no longer need req so release it. */ - req = rcu_dereference_protected(tcp_sk(sk)->fastopen_rsk, + req = rcu_dereference_protected(tp->fastopen_rsk, lockdep_sock_is_held(sk)); reqsk_fastopen_remove(sk, req, false); @@ -6375,43 +6932,47 @@ static void tcp_rcv_synrecv_state_fastopen(struct sock *sk) * address independent. */ -int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) +enum skb_drop_reason +tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) { struct tcp_sock *tp = tcp_sk(sk); struct inet_connection_sock *icsk = inet_csk(sk); const struct tcphdr *th = tcp_hdr(skb); struct request_sock *req; int queued = 0; - bool acceptable; + SKB_DR(reason); switch (sk->sk_state) { case TCP_CLOSE: + SKB_DR_SET(reason, TCP_CLOSE); goto discard; case TCP_LISTEN: if (th->ack) - return 1; + return SKB_DROP_REASON_TCP_FLAGS; - if (th->rst) + if (th->rst) { + SKB_DR_SET(reason, TCP_RESET); goto discard; - + } if (th->syn) { - if (th->fin) + if (th->fin) { + SKB_DR_SET(reason, TCP_FLAGS); goto discard; + } /* It is possible that we process SYN packets from backlog, * so we need to make sure to disable BH and RCU right there. */ rcu_read_lock(); local_bh_disable(); - acceptable = icsk->icsk_af_ops->conn_request(sk, skb) >= 0; + icsk->icsk_af_ops->conn_request(sk, skb); local_bh_enable(); rcu_read_unlock(); - if (!acceptable) - return 1; consume_skb(skb); return 0; } + SKB_DR_SET(reason, TCP_FLAGS); goto discard; case TCP_SYN_SENT: @@ -6438,33 +6999,47 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) WARN_ON_ONCE(sk->sk_state != TCP_SYN_RECV && sk->sk_state != TCP_FIN_WAIT1); - if (!tcp_check_req(sk, skb, req, true, &req_stolen)) + SKB_DR_SET(reason, TCP_FASTOPEN); + if (!tcp_check_req(sk, skb, req, true, &req_stolen, &reason)) goto discard; } - if (!th->ack && !th->rst && !th->syn) + if (!th->ack && !th->rst && !th->syn) { + SKB_DR_SET(reason, TCP_FLAGS); goto discard; - + } if (!tcp_validate_incoming(sk, skb, th, 0)) return 0; /* step 5: check the ACK field */ - acceptable = tcp_ack(sk, skb, FLAG_SLOWPATH | - FLAG_UPDATE_TS_RECENT | - FLAG_NO_CHALLENGE_ACK) > 0; - - if (!acceptable) { - if (sk->sk_state == TCP_SYN_RECV) - return 1; /* send one RST */ - tcp_send_challenge_ack(sk); - goto discard; + reason = tcp_ack(sk, skb, FLAG_SLOWPATH | + FLAG_UPDATE_TS_RECENT | + FLAG_NO_CHALLENGE_ACK); + + if ((int)reason <= 0) { + if (sk->sk_state == TCP_SYN_RECV) { + /* send one RST */ + if (!reason) + return SKB_DROP_REASON_TCP_OLD_ACK; + return -reason; + } + /* accept old ack during closing */ + if ((int)reason < 0) { + tcp_send_challenge_ack(sk, false); + reason = -reason; + goto discard; + } } + SKB_DR_SET(reason, NOT_SPECIFIED); switch (sk->sk_state) { case TCP_SYN_RECV: tp->delivered++; /* SYN-ACK delivery isn't tracked in tcp_ack */ if (!tp->srtt_us) tcp_synack_rtt_meas(sk, req); + if (tp->rx_opt.tstamp_ok) + tp->advmss -= TCPOLEN_TSTAMP_ALIGNED; + if (req) { tcp_rcv_synrecv_state_fastopen(sk); } else { @@ -6474,6 +7049,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) skb); WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); } + tcp_ao_established(sk); smp_mb(); tcp_set_state(sk, TCP_ESTABLISHED); sk->sk_state_change(sk); @@ -6489,9 +7065,6 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) tp->snd_wnd = ntohs(th->window) << tp->rx_opt.snd_wscale; tcp_init_wl(tp, TCP_SKB_CB(skb)->seq); - if (tp->rx_opt.tstamp_ok) - tp->advmss -= TCPOLEN_TSTAMP_ALIGNED; - if (!inet_csk(sk)->icsk_ca_ops->cong_control) tcp_update_pacing_rate(sk); @@ -6499,7 +7072,12 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) tp->lsndtime = tcp_jiffies32; tcp_initialize_rcv_mss(sk); + if (tcp_ecn_mode_accecn(tp)) + tcp_accecn_third_ack(sk, skb, tp->syn_ect_snt); tcp_fast_path_on(tp); + if (sk->sk_shutdown & SEND_SHUTDOWN) + tcp_shutdown(sk, SEND_SHUTDOWN); + break; case TCP_FIN_WAIT1: { @@ -6512,7 +7090,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) break; tcp_set_state(sk, TCP_FIN_WAIT2); - sk->sk_shutdown |= SEND_SHUTDOWN; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | SEND_SHUTDOWN); sk_dst_confirm(sk); @@ -6522,10 +7100,10 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) break; } - if (tp->linger2 < 0) { + if (READ_ONCE(tp->linger2) < 0) { tcp_done(sk); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); - return 1; + return SKB_DROP_REASON_TCP_ABORT_ON_DATA; } if (TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(skb)->seq && after(TCP_SKB_CB(skb)->end_seq - th->fin, tp->rcv_nxt)) { @@ -6534,12 +7112,12 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) tcp_fastopen_active_disable(sk); tcp_done(sk); NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); - return 1; + return SKB_DROP_REASON_TCP_ABORT_ON_DATA; } tmo = tcp_fin_time(sk); if (tmo > TCP_TIMEWAIT_LEN) { - inet_csk_reset_keepalive_timer(sk, tmo - TCP_TIMEWAIT_LEN); + tcp_reset_keepalive_timer(sk, tmo - TCP_TIMEWAIT_LEN); } else if (th->fin || sock_owned_by_user(sk)) { /* Bad case. We could lose such FIN otherwise. * It is not a big problem, but it looks confusing @@ -6547,10 +7125,10 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) * if it spins in bh_lock_sock(), but it is really * marginal case. */ - inet_csk_reset_keepalive_timer(sk, tmo); + tcp_reset_keepalive_timer(sk, tmo); } else { tcp_time_wait(sk, TCP_FIN_WAIT2, tmo); - goto discard; + goto consume; } break; } @@ -6558,7 +7136,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) case TCP_CLOSING: if (tp->snd_una == tp->write_seq) { tcp_time_wait(sk, TCP_TIME_WAIT, 0); - goto discard; + goto consume; } break; @@ -6566,7 +7144,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) if (tp->snd_una == tp->write_seq) { tcp_update_metrics(sk); tcp_done(sk); - goto discard; + goto consume; } break; } @@ -6599,7 +7177,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) after(TCP_SKB_CB(skb)->end_seq - th->fin, tp->rcv_nxt)) { NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); tcp_reset(sk, skb); - return 1; + return SKB_DROP_REASON_TCP_ABORT_ON_DATA; } } fallthrough; @@ -6617,11 +7195,15 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) if (!queued) { discard: - tcp_drop(sk, skb); + tcp_drop_reason(sk, skb, reason); } return 0; + +consume: + __kfree_skb(skb); + return 0; } -EXPORT_SYMBOL(tcp_rcv_state_process); +EXPORT_IPV6_MOD(tcp_rcv_state_process); static inline void pr_drop_req(struct request_sock *req, __u16 port, int family) { @@ -6665,14 +7247,24 @@ static void tcp_ecn_create_request(struct request_sock *req, bool ect, ecn_ok; u32 ecn_ok_dst; + if (tcp_accecn_syn_requested(th) && + READ_ONCE(net->ipv4.sysctl_tcp_ecn) >= 3) { + inet_rsk(req)->ecn_ok = 1; + tcp_rsk(req)->accecn_ok = 1; + tcp_rsk(req)->syn_ect_rcv = TCP_SKB_CB(skb)->ip_dsfield & + INET_ECN_MASK; + return; + } + if (!th_ecn) return; ect = !INET_ECN_is_not_ect(TCP_SKB_CB(skb)->ip_dsfield); ecn_ok_dst = dst_feature(dst, DST_FEATURE_ECN_MASK); - ecn_ok = net->ipv4.sysctl_tcp_ecn || ecn_ok_dst; + ecn_ok = READ_ONCE(net->ipv4.sysctl_tcp_ecn) || ecn_ok_dst; - if (((!ect || th->res1) && ecn_ok) || tcp_ca_needs_ecn(listen_sk) || + if (((!ect || th->res1 || th->ae) && ecn_ok) || + tcp_ca_needs_ecn(listen_sk) || (ecn_ok_dst & DST_FEATURE_ECN_CA) || tcp_bpf_ca_needs_ecn((struct sock *)req)) inet_rsk(req)->ecn_ok = 1; @@ -6688,7 +7280,13 @@ static void tcp_openreq_init(struct request_sock *req, tcp_rsk(req)->rcv_isn = TCP_SKB_CB(skb)->seq; tcp_rsk(req)->rcv_nxt = TCP_SKB_CB(skb)->seq + 1; tcp_rsk(req)->snt_synack = 0; + tcp_rsk(req)->snt_tsval_first = 0; tcp_rsk(req)->last_oow_ack_time = 0; + tcp_rsk(req)->accecn_ok = 0; + tcp_rsk(req)->saw_accecn_opt = TCP_ACCECN_OPT_NOT_SEEN; + tcp_rsk(req)->accecn_fail_mode = 0; + tcp_rsk(req)->syn_ect_rcv = 0; + tcp_rsk(req)->syn_ect_snt = 0; req->mss = rx_opt->mss_clamp; req->ts_recent = rx_opt->saw_tstamp ? rx_opt->rcv_tsval : 0; ireq->tstamp_ok = rx_opt->tstamp_ok; @@ -6701,46 +7299,26 @@ static void tcp_openreq_init(struct request_sock *req, ireq->ir_num = ntohs(tcp_hdr(skb)->dest); ireq->ir_mark = inet_request_mark(sk, skb); #if IS_ENABLED(CONFIG_SMC) - ireq->smc_ok = rx_opt->smc_ok; + ireq->smc_ok = rx_opt->smc_ok && !(tcp_sk(sk)->smc_hs_congested && + tcp_sk(sk)->smc_hs_congested(sk)); #endif } -struct request_sock *inet_reqsk_alloc(const struct request_sock_ops *ops, - struct sock *sk_listener, - bool attach_listener) -{ - struct request_sock *req = reqsk_alloc(ops, sk_listener, - attach_listener); - - if (req) { - struct inet_request_sock *ireq = inet_rsk(req); - - ireq->ireq_opt = NULL; -#if IS_ENABLED(CONFIG_IPV6) - ireq->pktopts = NULL; -#endif - atomic64_set(&ireq->ir_cookie, 0); - ireq->ireq_state = TCP_NEW_SYN_RECV; - write_pnet(&ireq->ireq_net, sock_net(sk_listener)); - ireq->ireq_family = sk_listener->sk_family; - } - - return req; -} -EXPORT_SYMBOL(inet_reqsk_alloc); - /* * Return true if a syncookie should be sent */ -static bool tcp_syn_flood_action(const struct sock *sk, const char *proto) +static bool tcp_syn_flood_action(struct sock *sk, const char *proto) { struct request_sock_queue *queue = &inet_csk(sk)->icsk_accept_queue; const char *msg = "Dropping request"; - bool want_cookie = false; struct net *net = sock_net(sk); + bool want_cookie = false; + u8 syncookies; + + syncookies = READ_ONCE(net->ipv4.sysctl_tcp_syncookies); #ifdef CONFIG_SYN_COOKIES - if (net->ipv4.sysctl_tcp_syncookies) { + if (syncookies) { msg = "Sending cookies"; want_cookie = true; __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPREQQFULLDOCOOKIES); @@ -6748,11 +7326,18 @@ static bool tcp_syn_flood_action(const struct sock *sk, const char *proto) #endif __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPREQQFULLDROP); - if (!queue->synflood_warned && - net->ipv4.sysctl_tcp_syncookies != 2 && - xchg(&queue->synflood_warned, 1) == 0) - net_info_ratelimited("%s: Possible SYN flooding on port %d. %s. Check SNMP counters.\n", - proto, sk->sk_num, msg); + if (syncookies != 2 && !READ_ONCE(queue->synflood_warned)) { + WRITE_ONCE(queue->synflood_warned, 1); + if (IS_ENABLED(CONFIG_IPV6) && sk->sk_family == AF_INET6) { + net_info_ratelimited("%s: Possible SYN flooding on port [%pI6c]:%u. %s.\n", + proto, inet6_rcv_saddr(sk), + sk->sk_num, msg); + } else { + net_info_ratelimited("%s: Possible SYN flooding on port %pI4:%u. %s.\n", + proto, &sk->sk_rcv_saddr, + sk->sk_num, msg); + } + } return want_cookie; } @@ -6798,7 +7383,7 @@ u16 tcp_get_syncookie_mss(struct request_sock_ops *rsk_ops, struct tcp_sock *tp = tcp_sk(sk); u16 mss; - if (sock_net(sk)->ipv4.sysctl_tcp_syncookies != 2 && + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_syncookies) != 2 && !inet_csk_reqsk_queue_is_full(sk)) return 0; @@ -6810,38 +7395,50 @@ u16 tcp_get_syncookie_mss(struct request_sock_ops *rsk_ops, return 0; } - mss = tcp_parse_mss_option(th, tp->rx_opt.user_mss); + mss = tcp_parse_mss_option(th, READ_ONCE(tp->rx_opt.user_mss)); if (!mss) mss = af_ops->mss_clamp; return mss; } -EXPORT_SYMBOL_GPL(tcp_get_syncookie_mss); +EXPORT_IPV6_MOD_GPL(tcp_get_syncookie_mss); int tcp_conn_request(struct request_sock_ops *rsk_ops, const struct tcp_request_sock_ops *af_ops, struct sock *sk, struct sk_buff *skb) { struct tcp_fastopen_cookie foc = { .len = -1 }; - __u32 isn = TCP_SKB_CB(skb)->tcp_tw_isn; struct tcp_options_received tmp_opt; - struct tcp_sock *tp = tcp_sk(sk); + const struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); struct sock *fastopen_sk = NULL; struct request_sock *req; bool want_cookie = false; struct dst_entry *dst; struct flowi fl; + u8 syncookies; + u32 isn; - /* TW buckets are converted to open requests without - * limitations, they conserve resources and peer is - * evidently real one. - */ - if ((net->ipv4.sysctl_tcp_syncookies == 2 || - inet_csk_reqsk_queue_is_full(sk)) && !isn) { - want_cookie = tcp_syn_flood_action(sk, rsk_ops->slab_name); - if (!want_cookie) - goto drop; +#ifdef CONFIG_TCP_AO + const struct tcp_ao_hdr *aoh; +#endif + + isn = __this_cpu_read(tcp_tw_isn); + if (isn) { + /* TW buckets are converted to open requests without + * limitations, they conserve resources and peer is + * evidently real one. + */ + __this_cpu_write(tcp_tw_isn, 0); + } else { + syncookies = READ_ONCE(net->ipv4.sysctl_tcp_syncookies); + + if (syncookies == 2 || inet_csk_reqsk_queue_is_full(sk)) { + want_cookie = tcp_syn_flood_action(sk, + rsk_ops->slab_name); + if (!want_cookie) + goto drop; + } } if (sk_acceptq_is_full(sk)) { @@ -6856,13 +7453,14 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops, req->syncookie = want_cookie; tcp_rsk(req)->af_specific = af_ops; tcp_rsk(req)->ts_off = 0; + tcp_rsk(req)->req_usec_ts = false; #if IS_ENABLED(CONFIG_MPTCP) tcp_rsk(req)->is_mptcp = 0; #endif tcp_clear_options(&tmp_opt); tmp_opt.mss_clamp = af_ops->mss_clamp; - tmp_opt.user_mss = tp->rx_opt.user_mss; + tmp_opt.user_mss = READ_ONCE(tp->rx_opt.user_mss); tcp_parse_options(sock_net(sk), skb, &tmp_opt, 0, want_cookie ? NULL : &foc); @@ -6874,23 +7472,26 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops, tmp_opt.tstamp_ok = tmp_opt.saw_tstamp; tcp_openreq_init(req, &tmp_opt, skb, sk); - inet_rsk(req)->no_srccheck = inet_sk(sk)->transparent; + inet_rsk(req)->no_srccheck = inet_test_bit(TRANSPARENT, sk); /* Note: tcp_v6_init_req() might override ir_iif for link locals */ inet_rsk(req)->ir_iif = inet_request_bound_dev_if(sk, skb); - dst = af_ops->route_req(sk, skb, &fl, req); + dst = af_ops->route_req(sk, skb, &fl, req, isn); if (!dst) goto drop_and_free; - if (tmp_opt.tstamp_ok) + if (tmp_opt.tstamp_ok) { + tcp_rsk(req)->req_usec_ts = dst_tcp_usec_ts(dst); tcp_rsk(req)->ts_off = af_ops->init_ts_off(net, skb); - + } if (!want_cookie && !isn) { + int max_syn_backlog = READ_ONCE(net->ipv4.sysctl_max_syn_backlog); + /* Kill the following clause, if you dislike this way. */ - if (!net->ipv4.sysctl_tcp_syncookies && - (net->ipv4.sysctl_max_syn_backlog - inet_csk_reqsk_queue_len(sk) < - (net->ipv4.sysctl_max_syn_backlog >> 2)) && + if (!syncookies && + (max_syn_backlog - inet_csk_reqsk_queue_len(sk) < + (max_syn_backlog >> 2)) && !tcp_peer_is_proven(req, dst)) { /* Without syncookies last quarter of * backlog is filled with destinations, @@ -6915,6 +7516,18 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops, inet_rsk(req)->ecn_ok = 0; } +#ifdef CONFIG_TCP_AO + if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh)) + goto drop_and_release; /* Invalid TCP options */ + if (aoh) { + tcp_rsk(req)->used_tcp_ao = true; + tcp_rsk(req)->ao_rcv_next = aoh->keyid; + tcp_rsk(req)->ao_keyid = aoh->rnext_keyid; + + } else { + tcp_rsk(req)->used_tcp_ao = false; + } +#endif tcp_rsk(req)->snt_isn = isn; tcp_rsk(req)->txhash = net_tx_rndhash(); tcp_rsk(req)->syn_tos = TCP_SKB_CB(skb)->ip_dsfield; @@ -6929,7 +7542,6 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops, &foc, TCP_SYNACK_FASTOPEN, skb); /* Add the child socket directly into the accept queue */ if (!inet_csk_reqsk_queue_add(sk, req, fastopen_sk)) { - reqsk_fastopen_remove(fastopen_sk, req, false); bh_unlock_sock(fastopen_sk); sock_put(fastopen_sk); goto drop_and_free; @@ -6939,9 +7551,12 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops, sock_put(fastopen_sk); } else { tcp_rsk(req)->tfo_listener = false; - if (!want_cookie) - inet_csk_reqsk_queue_hash_add(sk, req, - tcp_timeout_init((struct sock *)req)); + if (!want_cookie && + unlikely(!inet_csk_reqsk_queue_hash_add(sk, req))) { + reqsk_free(req); + dst_release(dst); + return 0; + } af_ops->send_synack(sk, dst, &fl, req, &foc, !want_cookie ? TCP_SYNACK_NORMAL : TCP_SYNACK_COOKIE, @@ -6962,4 +7577,4 @@ drop: tcp_listendrop(sk); return 0; } -EXPORT_SYMBOL(tcp_conn_request); +EXPORT_IPV6_MOD(tcp_conn_request); diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index b53476e78c84..f8a9596e8f4d 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -53,22 +53,30 @@ #include <linux/module.h> #include <linux/random.h> #include <linux/cache.h> +#include <linux/fips.h> #include <linux/jhash.h> #include <linux/init.h> #include <linux/times.h> #include <linux/slab.h> +#include <linux/sched.h> +#include <linux/sock_diag.h> +#include <net/aligned_data.h> #include <net/net_namespace.h> #include <net/icmp.h> #include <net/inet_hashtables.h> #include <net/tcp.h> +#include <net/tcp_ecn.h> #include <net/transp_v6.h> #include <net/ipv6.h> #include <net/inet_common.h> +#include <net/inet_ecn.h> #include <net/timewait_sock.h> #include <net/xfrm.h> #include <net/secure_seq.h> #include <net/busy_poll.h> +#include <net/rstreason.h> +#include <net/psp.h> #include <linux/inet.h> #include <linux/ipv6.h> @@ -77,19 +85,24 @@ #include <linux/seq_file.h> #include <linux/inetdevice.h> #include <linux/btf_ids.h> +#include <linux/skbuff_ref.h> -#include <crypto/hash.h> -#include <linux/scatterlist.h> +#include <crypto/md5.h> #include <trace/events/tcp.h> #ifdef CONFIG_TCP_MD5SIG -static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, - __be32 daddr, __be32 saddr, const struct tcphdr *th); +static void tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, + __be32 daddr, __be32 saddr, const struct tcphdr *th); #endif struct inet_hashinfo tcp_hashinfo; -EXPORT_SYMBOL(tcp_hashinfo); + +static DEFINE_PER_CPU(struct sock_bh_locked, ipv4_tcp_sk) = { + .bh_lock = INIT_LOCAL_LOCK(bh_lock), +}; + +static DEFINE_MUTEX(tcp_exit_batch_mutex); static u32 tcp_v4_init_seq(const struct sk_buff *skb) { @@ -106,10 +119,15 @@ static u32 tcp_v4_init_ts_off(const struct net *net, const struct sk_buff *skb) int tcp_twsk_unique(struct sock *sk, struct sock *sktw, void *twp) { + int reuse = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tw_reuse); const struct inet_timewait_sock *tw = inet_twsk(sktw); const struct tcp_timewait_sock *tcptw = tcp_twsk(sktw); struct tcp_sock *tp = tcp_sk(sk); - int reuse = sock_net(sk)->ipv4.sysctl_tcp_tw_reuse; + int ts_recent_stamp; + u32 reuse_thresh; + + if (READ_ONCE(tw->tw_substate) == TCP_FIN_WAIT2) + reuse = 0; if (reuse == 2) { /* Still does not detect *everything* that goes through @@ -148,9 +166,17 @@ int tcp_twsk_unique(struct sock *sk, struct sock *sktw, void *twp) If TW bucket has been already destroyed we fall back to VJ's scheme and use initial timestamp retrieved from peer table. */ - if (tcptw->tw_ts_recent_stamp && - (!twp || (reuse && time_after32(ktime_get_seconds(), - tcptw->tw_ts_recent_stamp)))) { + ts_recent_stamp = READ_ONCE(tcptw->tw_ts_recent_stamp); + reuse_thresh = READ_ONCE(tw->tw_entry_stamp) + + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tw_reuse_delay); + if (ts_recent_stamp && + (!twp || (reuse && time_after32(tcp_clock_ms(), reuse_thresh)))) { + /* inet_twsk_hashdance_schedule() sets sk_refcnt after putting twsk + * and releasing the bucket lock. + */ + if (unlikely(!refcount_inc_not_zero(&sktw->sk_refcnt))) + return 0; + /* In case of repair and re-using TIME-WAIT sockets we still * want to be sure that it is safe as above but honor the * sequence numbers and time stamps set as part of the repair @@ -168,18 +194,18 @@ int tcp_twsk_unique(struct sock *sk, struct sock *sktw, void *twp) if (!seq) seq = 1; WRITE_ONCE(tp->write_seq, seq); - tp->rx_opt.ts_recent = tcptw->tw_ts_recent; - tp->rx_opt.ts_recent_stamp = tcptw->tw_ts_recent_stamp; + tp->rx_opt.ts_recent = READ_ONCE(tcptw->tw_ts_recent); + tp->rx_opt.ts_recent_stamp = ts_recent_stamp; } - sock_hold(sktw); + return 1; } return 0; } -EXPORT_SYMBOL_GPL(tcp_twsk_unique); +EXPORT_IPV6_MOD_GPL(tcp_twsk_unique); -static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr, +static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len) { /* This check is replicated from tcp_v4_connect() and intended to @@ -191,22 +217,23 @@ static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr, sock_owned_by_me(sk); - return BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr); + return BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr, &addr_len); } /* This will initiate an outgoing connection. */ -int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +int tcp_v4_connect(struct sock *sk, struct sockaddr_unsized *uaddr, int addr_len) { struct sockaddr_in *usin = (struct sockaddr_in *)uaddr; + struct inet_timewait_death_row *tcp_death_row; struct inet_sock *inet = inet_sk(sk); struct tcp_sock *tp = tcp_sk(sk); + struct ip_options_rcu *inet_opt; + struct net *net = sock_net(sk); __be16 orig_sport, orig_dport; __be32 daddr, nexthop; struct flowi4 *fl4; struct rtable *rt; int err; - struct ip_options_rcu *inet_opt; - struct inet_timewait_death_row *tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; if (addr_len < sizeof(struct sockaddr_in)) return -EINVAL; @@ -227,13 +254,12 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) orig_dport = usin->sin_port; fl4 = &inet->cork.fl.u.ip4; rt = ip_route_connect(fl4, nexthop, inet->inet_saddr, - RT_CONN_FLAGS(sk), sk->sk_bound_dev_if, - IPPROTO_TCP, - orig_sport, orig_dport, sk); + sk->sk_bound_dev_if, IPPROTO_TCP, orig_sport, + orig_dport, sk); if (IS_ERR(rt)) { err = PTR_ERR(rt); if (err == -ENETUNREACH) - IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTNOROUTES); + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); return err; } @@ -245,9 +271,17 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) if (!inet_opt || !inet_opt->opt.srr) daddr = fl4->daddr; - if (!inet->inet_saddr) - inet->inet_saddr = fl4->saddr; - sk_rcv_saddr_set(sk, inet->inet_saddr); + tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; + + if (!inet->inet_saddr) { + err = inet_bhash2_update_saddr(sk, &fl4->saddr, AF_INET); + if (err) { + ip_rt_put(rt); + return err; + } + } else { + sk_rcv_saddr_set(sk, inet->inet_saddr); + } if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) { /* Reset inherited state */ @@ -260,9 +294,9 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) inet->inet_dport = usin->sin_port; sk_daddr_set(sk, daddr); - inet_csk(sk)->icsk_ext_hdr_len = 0; + inet_csk(sk)->icsk_ext_hdr_len = psp_sk_overhead(sk); if (inet_opt) - inet_csk(sk)->icsk_ext_hdr_len = inet_opt->opt.optlen; + inet_csk(sk)->icsk_ext_hdr_len += inet_opt->opt.optlen; tp->rx_opt.mss_clamp = TCP_MSS_DEFAULT; @@ -285,6 +319,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) rt = NULL; goto failure; } + tp->tcp_usec_ts = dst_tcp_usec_ts(&rt->dst); /* OK, now commit destination to socket. */ sk->sk_gso_type = SKB_GSO_TCPV4; sk_setup_caps(sk, &rt->dst); @@ -297,12 +332,12 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) inet->inet_daddr, inet->inet_sport, usin->sin_port)); - tp->tsoffset = secure_tcp_ts_off(sock_net(sk), - inet->inet_saddr, - inet->inet_daddr); + WRITE_ONCE(tp->tsoffset, + secure_tcp_ts_off(net, inet->inet_saddr, + inet->inet_daddr)); } - inet->inet_id = prandom_u32(); + atomic_set(&inet->inet_id, get_random_u16()); if (tcp_fastopen_defer_connect(sk, &err)) return err; @@ -322,12 +357,13 @@ failure: * if necessary. */ tcp_set_state(sk, TCP_CLOSE); + inet_bhash2_reset_saddr(sk); ip_rt_put(rt); sk->sk_route_caps = 0; inet->inet_dport = 0; return err; } -EXPORT_SYMBOL(tcp_v4_connect); +EXPORT_IPV6_MOD(tcp_v4_connect); /* * This routine reacts to ICMP_FRAG_NEEDED mtu indications as defined in RFC1191. @@ -351,7 +387,7 @@ void tcp_v4_mtu_reduced(struct sock *sk) * for the case, if this connection will not able to recover. */ if (mtu < dst_mtu(dst) && ip_dont_fragment(sk, dst)) - sk->sk_err_soft = EMSGSIZE; + WRITE_ONCE(sk->sk_err_soft, EMSGSIZE); mtu = dst_mtu(dst); @@ -368,7 +404,7 @@ void tcp_v4_mtu_reduced(struct sock *sk) tcp_simple_retransmit(sk); } /* else let the usual retransmit timer handle it */ } -EXPORT_SYMBOL(tcp_v4_mtu_reduced); +EXPORT_IPV6_MOD(tcp_v4_mtu_reduced); static void do_redirect(struct sk_buff *skb, struct sock *sk) { @@ -402,7 +438,7 @@ void tcp_req_err(struct sock *sk, u32 seq, bool abort) } reqsk_put(req); } -EXPORT_SYMBOL(tcp_req_err); +EXPORT_IPV6_MOD(tcp_req_err); /* TCP-LD (RFC 6069) logic */ void tcp_ld_RTO_revert(struct sock *sk, u32 seq) @@ -426,15 +462,14 @@ void tcp_ld_RTO_revert(struct sock *sk, u32 seq) icsk->icsk_backoff--; icsk->icsk_rto = tp->srtt_us ? __tcp_set_rto(tp) : TCP_TIMEOUT_INIT; - icsk->icsk_rto = inet_csk_rto_backoff(icsk, TCP_RTO_MAX); + icsk->icsk_rto = inet_csk_rto_backoff(icsk, tcp_rto_max(sk)); tcp_mstamp_refresh(tp); delta_us = (u32)(tp->tcp_mstamp - tcp_skb_timestamp_us(skb)); remaining = icsk->icsk_rto - usecs_to_jiffies(delta_us); if (remaining > 0) { - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - remaining, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, remaining, false); } else { /* RTO revert clocked out retransmission. * Will retransmit now. @@ -442,7 +477,7 @@ void tcp_ld_RTO_revert(struct sock *sk, u32 seq) tcp_retransmit_timer(sk); } } -EXPORT_SYMBOL(tcp_ld_RTO_revert); +EXPORT_IPV6_MOD(tcp_ld_RTO_revert); /* * This routine is called by the ICMP module when it gets some @@ -464,24 +499,24 @@ int tcp_v4_err(struct sk_buff *skb, u32 info) { const struct iphdr *iph = (const struct iphdr *)skb->data; struct tcphdr *th = (struct tcphdr *)(skb->data + (iph->ihl << 2)); - struct tcp_sock *tp; - struct inet_sock *inet; + struct net *net = dev_net_rcu(skb->dev); const int type = icmp_hdr(skb)->type; const int code = icmp_hdr(skb)->code; - struct sock *sk; struct request_sock *fastopen; + struct tcp_sock *tp; u32 seq, snd_una; + struct sock *sk; int err; - struct net *net = dev_net(skb->dev); - sk = __inet_lookup_established(net, &tcp_hashinfo, iph->daddr, - th->dest, iph->saddr, ntohs(th->source), - inet_iif(skb), 0); + sk = __inet_lookup_established(net, iph->daddr, th->dest, iph->saddr, + ntohs(th->source), inet_iif(skb), 0); if (!sk) { __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); return -ENOENT; } if (sk->sk_state == TCP_TIME_WAIT) { + /* To increase the counter of ignored icmps for TCP-AO */ + tcp_ao_ignore_icmp(sk, AF_INET, type, code); inet_twsk_put(inet_twsk(sk)); return 0; } @@ -495,6 +530,11 @@ int tcp_v4_err(struct sk_buff *skb, u32 info) return 0; } + if (tcp_ao_ignore_icmp(sk, AF_INET, type, code)) { + sock_put(sk); + return 0; + } + bh_lock_sock(sk); /* If too many ICMPs get dropped on busy * servers this needs to be solved differently. @@ -585,15 +625,10 @@ int tcp_v4_err(struct sk_buff *skb, u32 info) ip_icmp_error(sk, skb, err, th->dest, info, (u8 *)th); - if (!sock_owned_by_user(sk)) { - sk->sk_err = err; - - sk_error_report(sk); - - tcp_done(sk); - } else { - sk->sk_err_soft = err; - } + if (!sock_owned_by_user(sk)) + tcp_done_with_error(sk, err); + else + WRITE_ONCE(sk->sk_err_soft, err); goto out; } @@ -613,12 +648,12 @@ int tcp_v4_err(struct sk_buff *skb, u32 info) * --ANK (980905) */ - inet = inet_sk(sk); - if (!sock_owned_by_user(sk) && inet->recverr) { - sk->sk_err = err; + if (!sock_owned_by_user(sk) && + inet_test_bit(RECVERR, sk)) { + WRITE_ONCE(sk->sk_err, err); sk_error_report(sk); } else { /* Only an error on timeout */ - sk->sk_err_soft = err; + WRITE_ONCE(sk->sk_err_soft, err); } out: @@ -643,7 +678,53 @@ void tcp_v4_send_check(struct sock *sk, struct sk_buff *skb) __tcp_v4_send_check(skb, inet->inet_saddr, inet->inet_daddr); } -EXPORT_SYMBOL(tcp_v4_send_check); +EXPORT_IPV6_MOD(tcp_v4_send_check); + +#define REPLY_OPTIONS_LEN (MAX_TCP_OPTION_SPACE / sizeof(__be32)) + +static bool tcp_v4_ao_sign_reset(const struct sock *sk, struct sk_buff *skb, + const struct tcp_ao_hdr *aoh, + struct ip_reply_arg *arg, struct tcphdr *reply, + __be32 reply_options[REPLY_OPTIONS_LEN]) +{ +#ifdef CONFIG_TCP_AO + int sdif = tcp_v4_sdif(skb); + int dif = inet_iif(skb); + int l3index = sdif ? dif : 0; + bool allocated_traffic_key; + struct tcp_ao_key *key; + char *traffic_key; + bool drop = true; + u32 ao_sne = 0; + u8 keyid; + + rcu_read_lock(); + if (tcp_ao_prepare_reset(sk, skb, aoh, l3index, ntohl(reply->seq), + &key, &traffic_key, &allocated_traffic_key, + &keyid, &ao_sne)) + goto out; + + reply_options[0] = htonl((TCPOPT_AO << 24) | (tcp_ao_len(key) << 16) | + (aoh->rnext_keyid << 8) | keyid); + arg->iov[0].iov_len += tcp_ao_len_aligned(key); + reply->doff = arg->iov[0].iov_len / 4; + + if (tcp_ao_hash_hdr(AF_INET, (char *)&reply_options[1], + key, traffic_key, + (union tcp_ao_addr *)&ip_hdr(skb)->saddr, + (union tcp_ao_addr *)&ip_hdr(skb)->daddr, + reply, ao_sne)) + goto out; + drop = false; +out: + rcu_read_unlock(); + if (allocated_traffic_key) + kfree(traffic_key); + return drop; +#else + return true; +#endif +} /* * This routine will send an RST to the other tcp. @@ -658,30 +739,26 @@ EXPORT_SYMBOL(tcp_v4_send_check); * Exception: precedence violation. We do not implement it in any case. */ -#ifdef CONFIG_TCP_MD5SIG -#define OPTION_BYTES TCPOLEN_MD5SIG_ALIGNED -#else -#define OPTION_BYTES sizeof(__be32) -#endif - -static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) +static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb, + enum sk_rst_reason reason) { const struct tcphdr *th = tcp_hdr(skb); struct { struct tcphdr th; - __be32 opt[OPTION_BYTES / sizeof(__be32)]; + __be32 opt[REPLY_OPTIONS_LEN]; } rep; + const __u8 *md5_hash_location = NULL; + const struct tcp_ao_hdr *aoh; struct ip_reply_arg arg; #ifdef CONFIG_TCP_MD5SIG struct tcp_md5sig_key *key = NULL; - const __u8 *hash_location = NULL; unsigned char newhash[16]; - int genhash; struct sock *sk1 = NULL; #endif u64 transmit_time = 0; struct sock *ctl_sk; struct net *net; + u32 txhash = 0; /* Never send a reset in response to a reset. */ if (th->rst) @@ -712,10 +789,17 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) arg.iov[0].iov_base = (unsigned char *)&rep; arg.iov[0].iov_len = sizeof(rep.th); - net = sk ? sock_net(sk) : dev_net(skb_dst(skb)->dev); + net = sk ? sock_net(sk) : skb_dst_dev_net_rcu(skb); + + /* Invalid TCP option size or twice included auth */ + if (tcp_parse_auth_options(tcp_hdr(skb), &md5_hash_location, &aoh)) + return; + + if (aoh && tcp_v4_ao_sign_reset(sk, skb, aoh, &arg, &rep.th, rep.opt)) + return; + #ifdef CONFIG_TCP_MD5SIG rcu_read_lock(); - hash_location = tcp_parse_md5sig_option(th); if (sk && sk_fullsock(sk)) { const union tcp_md5_addr *addr; int l3index; @@ -726,7 +810,7 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0; addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); - } else if (hash_location) { + } else if (md5_hash_location) { const union tcp_md5_addr *addr; int sdif = tcp_v4_sdif(skb); int dif = inet_iif(skb); @@ -739,8 +823,7 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) * Incoming packet is checked with md5 hash with finding key, * no RST generated if md5 hash doesn't match. */ - sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0, - ip_hdr(skb)->saddr, + sk1 = __inet_lookup_listener(net, NULL, 0, ip_hdr(skb)->saddr, th->source, ip_hdr(skb)->daddr, ntohs(th->source), dif, sdif); /* don't send rst if it can't find key */ @@ -756,11 +839,9 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) if (!key) goto out; - - genhash = tcp_v4_md5_hash_skb(newhash, key, NULL, skb); - if (genhash || memcmp(hash_location, newhash, 16) != 0) + tcp_v4_md5_hash_skb(newhash, key, NULL, skb); + if (memcmp(md5_hash_location, newhash, 16) != 0) goto out; - } if (key) { @@ -798,35 +879,46 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) * routing might fail in this case. No choice here, if we choose to force * input interface, we will misroute in case of asymmetric route. */ - if (sk) { + if (sk) arg.bound_dev_if = sk->sk_bound_dev_if; - if (sk_fullsock(sk)) - trace_tcp_send_reset(sk, skb); - } + + trace_tcp_send_reset(sk, skb, reason); BUILD_BUG_ON(offsetof(struct sock, sk_bound_dev_if) != offsetof(struct inet_timewait_sock, tw_bound_dev_if)); - arg.tos = ip_hdr(skb)->tos; + /* ECN bits of TW reset are cleared */ + arg.tos = ip_hdr(skb)->tos & ~INET_ECN_MASK; arg.uid = sock_net_uid(net, sk && sk_fullsock(sk) ? sk : NULL); local_bh_disable(); - ctl_sk = this_cpu_read(*net->ipv4.tcp_sk); + local_lock_nested_bh(&ipv4_tcp_sk.bh_lock); + ctl_sk = this_cpu_read(ipv4_tcp_sk.sock); + + sock_net_set(ctl_sk, net); if (sk) { ctl_sk->sk_mark = (sk->sk_state == TCP_TIME_WAIT) ? - inet_twsk(sk)->tw_mark : sk->sk_mark; + inet_twsk(sk)->tw_mark : READ_ONCE(sk->sk_mark); ctl_sk->sk_priority = (sk->sk_state == TCP_TIME_WAIT) ? - inet_twsk(sk)->tw_priority : sk->sk_priority; + inet_twsk(sk)->tw_priority : READ_ONCE(sk->sk_priority); transmit_time = tcp_transmit_time(sk); + xfrm_sk_clone_policy(ctl_sk, sk); + txhash = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_txhash : sk->sk_txhash; + } else { + ctl_sk->sk_mark = 0; + ctl_sk->sk_priority = 0; } - ip_send_unicast_reply(ctl_sk, + ip_send_unicast_reply(ctl_sk, sk, skb, &TCP_SKB_CB(skb)->header.h4.opt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, &arg, arg.iov[0].iov_len, - transmit_time); + transmit_time, txhash); - ctl_sk->sk_mark = 0; + xfrm_sk_free_policy(ctl_sk); + sock_net_set(ctl_sk, &init_net); __TCP_INC_STATS(net, TCP_MIB_OUTSEGS); __TCP_INC_STATS(net, TCP_MIB_OUTRSTS); + local_unlock_nested_bh(&ipv4_tcp_sk.bh_lock); local_bh_enable(); #ifdef CONFIG_TCP_MD5SIG @@ -842,17 +934,13 @@ out: static void tcp_v4_send_ack(const struct sock *sk, struct sk_buff *skb, u32 seq, u32 ack, u32 win, u32 tsval, u32 tsecr, int oif, - struct tcp_md5sig_key *key, - int reply_flags, u8 tos) + struct tcp_key *key, + int reply_flags, u8 tos, u32 txhash) { const struct tcphdr *th = tcp_hdr(skb); struct { struct tcphdr th; - __be32 opt[(TCPOLEN_TSTAMP_ALIGNED >> 2) -#ifdef CONFIG_TCP_MD5SIG - + (TCPOLEN_MD5SIG_ALIGNED >> 2) -#endif - ]; + __be32 opt[(MAX_TCP_OPTION_SPACE >> 2)]; } rep; struct net *net = sock_net(sk); struct ip_reply_arg arg; @@ -883,7 +971,7 @@ static void tcp_v4_send_ack(const struct sock *sk, rep.th.window = htons(win); #ifdef CONFIG_TCP_MD5SIG - if (key) { + if (tcp_key_is_md5(key)) { int offset = (tsecr) ? 3 : 0; rep.opt[offset++] = htonl((TCPOPT_NOP << 24) | @@ -894,10 +982,28 @@ static void tcp_v4_send_ack(const struct sock *sk, rep.th.doff = arg.iov[0].iov_len/4; tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[offset], - key, ip_hdr(skb)->saddr, + key->md5_key, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, &rep.th); } #endif +#ifdef CONFIG_TCP_AO + if (tcp_key_is_ao(key)) { + int offset = (tsecr) ? 3 : 0; + + rep.opt[offset++] = htonl((TCPOPT_AO << 24) | + (tcp_ao_len(key->ao_key) << 16) | + (key->ao_key->sndid << 8) | + key->rcv_next); + arg.iov[0].iov_len += tcp_ao_len_aligned(key->ao_key); + rep.th.doff = arg.iov[0].iov_len / 4; + + tcp_ao_hash_hdr(AF_INET, (char *)&rep.opt[offset], + key->ao_key, key->traffic_key, + (union tcp_ao_addr *)&ip_hdr(skb)->saddr, + (union tcp_ao_addr *)&ip_hdr(skb)->daddr, + &rep.th, key->sne); + } +#endif arg.flags = reply_flags; arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr, ip_hdr(skb)->saddr, /* XXX */ @@ -908,38 +1014,86 @@ static void tcp_v4_send_ack(const struct sock *sk, arg.tos = tos; arg.uid = sock_net_uid(net, sk_fullsock(sk) ? sk : NULL); local_bh_disable(); - ctl_sk = this_cpu_read(*net->ipv4.tcp_sk); + local_lock_nested_bh(&ipv4_tcp_sk.bh_lock); + ctl_sk = this_cpu_read(ipv4_tcp_sk.sock); + sock_net_set(ctl_sk, net); ctl_sk->sk_mark = (sk->sk_state == TCP_TIME_WAIT) ? - inet_twsk(sk)->tw_mark : sk->sk_mark; + inet_twsk(sk)->tw_mark : READ_ONCE(sk->sk_mark); ctl_sk->sk_priority = (sk->sk_state == TCP_TIME_WAIT) ? - inet_twsk(sk)->tw_priority : sk->sk_priority; + inet_twsk(sk)->tw_priority : READ_ONCE(sk->sk_priority); transmit_time = tcp_transmit_time(sk); - ip_send_unicast_reply(ctl_sk, + ip_send_unicast_reply(ctl_sk, sk, skb, &TCP_SKB_CB(skb)->header.h4.opt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, &arg, arg.iov[0].iov_len, - transmit_time); + transmit_time, txhash); - ctl_sk->sk_mark = 0; + sock_net_set(ctl_sk, &init_net); __TCP_INC_STATS(net, TCP_MIB_OUTSEGS); + local_unlock_nested_bh(&ipv4_tcp_sk.bh_lock); local_bh_enable(); } -static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb) +static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb, + enum tcp_tw_status tw_status) { struct inet_timewait_sock *tw = inet_twsk(sk); struct tcp_timewait_sock *tcptw = tcp_twsk(sk); + struct tcp_key key = {}; + u8 tos = tw->tw_tos; + + /* Cleaning only ECN bits of TW ACKs of oow data or is paws_reject, + * while not cleaning ECN bits of other TW ACKs to avoid these ACKs + * being placed in a different service queues (Classic rather than L4S) + */ + if (tw_status == TCP_TW_ACK_OOW) + tos &= ~INET_ECN_MASK; + +#ifdef CONFIG_TCP_AO + struct tcp_ao_info *ao_info; + + if (static_branch_unlikely(&tcp_ao_needed.key)) { + /* FIXME: the segment to-be-acked is not verified yet */ + ao_info = rcu_dereference(tcptw->ao_info); + if (ao_info) { + const struct tcp_ao_hdr *aoh; + + if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh)) { + inet_twsk_put(tw); + return; + } + + if (aoh) + key.ao_key = tcp_ao_established_key(sk, ao_info, + aoh->rnext_keyid, -1); + } + } + if (key.ao_key) { + struct tcp_ao_key *rnext_key; + + key.traffic_key = snd_other_key(key.ao_key); + key.sne = READ_ONCE(ao_info->snd_sne); + rnext_key = READ_ONCE(ao_info->rnext_key); + key.rcv_next = rnext_key->rcvid; + key.type = TCP_KEY_AO; +#else + if (0) { +#endif + } else if (static_branch_tcp_md5()) { + key.md5_key = tcp_twsk_md5_key(tcptw); + if (key.md5_key) + key.type = TCP_KEY_MD5; + } tcp_v4_send_ack(sk, skb, - tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt, + tcptw->tw_snd_nxt, READ_ONCE(tcptw->tw_rcv_nxt), tcptw->tw_rcv_wnd >> tw->tw_rcv_wscale, - tcp_time_stamp_raw() + tcptw->tw_ts_offset, - tcptw->tw_ts_recent, - tw->tw_bound_dev_if, - tcp_twsk_md5_key(tcptw), + tcp_tw_tsval(tcptw), + READ_ONCE(tcptw->tw_ts_recent), + tw->tw_bound_dev_if, &key, tw->tw_transparent ? IP_REPLY_ARG_NOSRCCHECK : 0, - tw->tw_tos - ); + tos, + tw->tw_txhash); inet_twsk_put(tw); } @@ -947,8 +1101,7 @@ static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb) static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, struct request_sock *req) { - const union tcp_md5_addr *addr; - int l3index; + struct tcp_key key = {}; /* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV * sk->sk_state == TCP_SYN_RECV -> for Fast Open. @@ -956,22 +1109,71 @@ static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, u32 seq = (sk->sk_state == TCP_LISTEN) ? tcp_rsk(req)->snt_isn + 1 : tcp_sk(sk)->snd_nxt; - /* RFC 7323 2.3 - * The window field (SEG.WND) of every outgoing segment, with the - * exception of <SYN> segments, MUST be right-shifted by - * Rcv.Wind.Shift bits: - */ - addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; - l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0; +#ifdef CONFIG_TCP_AO + if (static_branch_unlikely(&tcp_ao_needed.key) && + tcp_rsk_used_ao(req)) { + const union tcp_md5_addr *addr; + const struct tcp_ao_hdr *aoh; + int l3index; + + /* Invalid TCP option size or twice included auth */ + if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh)) + return; + if (!aoh) + return; + + addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; + l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0; + key.ao_key = tcp_ao_do_lookup(sk, l3index, addr, AF_INET, + aoh->rnext_keyid, -1); + if (unlikely(!key.ao_key)) { + /* Send ACK with any matching MKT for the peer */ + key.ao_key = tcp_ao_do_lookup(sk, l3index, addr, AF_INET, -1, -1); + /* Matching key disappeared (user removed the key?) + * let the handshake timeout. + */ + if (!key.ao_key) { + net_info_ratelimited("TCP-AO key for (%pI4, %d)->(%pI4, %d) suddenly disappeared, won't ACK new connection\n", + addr, + ntohs(tcp_hdr(skb)->source), + &ip_hdr(skb)->daddr, + ntohs(tcp_hdr(skb)->dest)); + return; + } + } + key.traffic_key = kmalloc(tcp_ao_digest_size(key.ao_key), GFP_ATOMIC); + if (!key.traffic_key) + return; + + key.type = TCP_KEY_AO; + key.rcv_next = aoh->keyid; + tcp_v4_ao_calc_key_rsk(key.ao_key, key.traffic_key, req); +#else + if (0) { +#endif + } else if (static_branch_tcp_md5()) { + const union tcp_md5_addr *addr; + int l3index; + + addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; + l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0; + key.md5_key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); + if (key.md5_key) + key.type = TCP_KEY_MD5; + } + + /* Cleaning ECN bits of TW ACKs of oow data or is paws_reject */ tcp_v4_send_ack(sk, skb, seq, tcp_rsk(req)->rcv_nxt, - req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale, - tcp_time_stamp_raw() + tcp_rsk(req)->ts_off, + tcp_synack_window(req) >> inet_rsk(req)->rcv_wscale, + tcp_rsk_tsval(tcp_rsk(req)), req->ts_recent, - 0, - tcp_md5_do_lookup(sk, l3index, addr, AF_INET), + 0, &key, inet_rsk(req)->no_srccheck ? IP_REPLY_ARG_NOSRCCHECK : 0, - ip_hdr(skb)->tos); + ip_hdr(skb)->tos & ~INET_ECN_MASK, + READ_ONCE(tcp_rsk(req)->txhash)); + if (tcp_key_is_ao(&key)) + kfree(key.traffic_key); } /* @@ -986,7 +1188,7 @@ static int tcp_v4_send_synack(const struct sock *sk, struct dst_entry *dst, enum tcp_synack_type synack_type, struct sk_buff *syn_skb) { - const struct inet_request_sock *ireq = inet_rsk(req); + struct inet_request_sock *ireq = inet_rsk(req); struct flowi4 fl4; int err = -1; struct sk_buff *skb; @@ -999,12 +1201,14 @@ static int tcp_v4_send_synack(const struct sock *sk, struct dst_entry *dst, skb = tcp_make_synack(sk, dst, req, foc, synack_type, syn_skb); if (skb) { + tcp_rsk(req)->syn_ect_snt = inet_sk(sk)->tos & INET_ECN_MASK; __tcp_v4_send_check(skb, ireq->ir_loc_addr, ireq->ir_rmt_addr); - tos = sock_net(sk)->ipv4.sysctl_tcp_reflect_tos ? - (tcp_rsk(req)->syn_tos & ~INET_ECN_MASK) | - (inet_sk(sk)->tos & INET_ECN_MASK) : - inet_sk(sk)->tos; + tos = READ_ONCE(inet_sk(sk)->tos); + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos)) + tos = (tcp_rsk(req)->syn_tos & ~INET_ECN_MASK) | + (tos & INET_ECN_MASK); if (!INET_ECN_is_capable(tos) && tcp_bpf_ca_needs_ecn((struct sock *)req)) @@ -1037,8 +1241,8 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req) * We need to maintain these in the sk structure. */ -DEFINE_STATIC_KEY_FALSE(tcp_md5_needed); -EXPORT_SYMBOL(tcp_md5_needed); +DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ); +EXPORT_IPV6_MOD(tcp_md5_needed); static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new) { @@ -1057,7 +1261,7 @@ static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key * /* Find the Key structure for an address. */ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, const union tcp_md5_addr *addr, - int family) + int family, bool any_l3index) { const struct tcp_sock *tp = tcp_sk(sk); struct tcp_md5sig_key *key; @@ -1076,7 +1280,8 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, lockdep_sock_is_held(sk)) { if (key->family != family) continue; - if (key->flags & TCP_MD5SIG_FLAG_IFINDEX && key->l3index != l3index) + if (!any_l3index && key->flags & TCP_MD5SIG_FLAG_IFINDEX && + key->l3index != l3index) continue; if (family == AF_INET) { mask = inet_make_mask(key->prefixlen); @@ -1096,7 +1301,7 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, } return best_match; } -EXPORT_SYMBOL(__tcp_md5_do_lookup); +EXPORT_IPV6_MOD(__tcp_md5_do_lookup); static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk, const union tcp_md5_addr *addr, @@ -1143,12 +1348,27 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk, addr = (const union tcp_md5_addr *)&addr_sk->sk_daddr; return tcp_md5_do_lookup(sk, l3index, addr, AF_INET); } -EXPORT_SYMBOL(tcp_v4_md5_lookup); +EXPORT_IPV6_MOD(tcp_v4_md5_lookup); + +static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_info *md5sig; + + md5sig = kmalloc(sizeof(*md5sig), gfp); + if (!md5sig) + return -ENOMEM; + + sk_gso_disable(sk); + INIT_HLIST_HEAD(&md5sig->head); + rcu_assign_pointer(tp->md5sig_info, md5sig); + return 0; +} /* This can be called on a newly created socket, from other files */ -int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, - int family, u8 prefixlen, int l3index, u8 flags, - const u8 *newkey, u8 newkeylen, gfp_t gfp) +static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, u8 flags, + const u8 *newkey, u8 newkeylen, gfp_t gfp) { /* Add Key to the list */ struct tcp_md5sig_key *key; @@ -1177,23 +1397,10 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); - if (!md5sig) { - md5sig = kmalloc(sizeof(*md5sig), gfp); - if (!md5sig) - return -ENOMEM; - - sk_gso_disable(sk); - INIT_HLIST_HEAD(&md5sig->head); - rcu_assign_pointer(tp->md5sig_info, md5sig); - } key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO); if (!key) return -ENOMEM; - if (!tcp_alloc_md5sig_pool()) { - sock_kfree_s(sk, key, sizeof(*key)); - return -ENOMEM; - } memcpy(key->key, newkey, newkeylen); key->keylen = newkeylen; @@ -1202,12 +1409,69 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, key->l3index = l3index; key->flags = flags; memcpy(&key->addr, addr, - (family == AF_INET6) ? sizeof(struct in6_addr) : - sizeof(struct in_addr)); + (IS_ENABLED(CONFIG_IPV6) && family == AF_INET6) ? sizeof(struct in6_addr) : + sizeof(struct in_addr)); hlist_add_head_rcu(&key->node, &md5sig->head); return 0; } -EXPORT_SYMBOL(tcp_md5_do_add); + +int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, u8 flags, + const u8 *newkey, u8 newkeylen) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { + if (fips_enabled) { + pr_warn_once("TCP-MD5 support is disabled due to FIPS\n"); + return -EOPNOTSUPP; + } + + if (tcp_md5sig_info_add(sk, GFP_KERNEL)) + return -ENOMEM; + + if (!static_branch_inc(&tcp_md5_needed.key)) { + struct tcp_md5sig_info *md5sig; + + md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); + rcu_assign_pointer(tp->md5sig_info, NULL); + kfree_rcu(md5sig, rcu); + return -EUSERS; + } + } + + return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags, + newkey, newkeylen, GFP_KERNEL); +} +EXPORT_IPV6_MOD(tcp_md5_do_add); + +int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, + struct tcp_md5sig_key *key) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { + + if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC))) + return -ENOMEM; + + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) { + struct tcp_md5sig_info *md5sig; + + md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); + net_warn_ratelimited("Too many TCP-MD5 keys in the system\n"); + rcu_assign_pointer(tp->md5sig_info, NULL); + kfree_rcu(md5sig, rcu); + return -EUSERS; + } + } + + return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, + key->flags, key->key, key->keylen, + sk_gfp_mask(sk, GFP_ATOMIC)); +} +EXPORT_IPV6_MOD(tcp_md5_key_copy); int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family, u8 prefixlen, int l3index, u8 flags) @@ -1222,9 +1486,9 @@ int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family, kfree_rcu(key, rcu); return 0; } -EXPORT_SYMBOL(tcp_md5_do_del); +EXPORT_IPV6_MOD(tcp_md5_do_del); -static void tcp_clear_md5_list(struct sock *sk) +void tcp_clear_md5_list(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); struct tcp_md5sig_key *key; @@ -1234,9 +1498,9 @@ static void tcp_clear_md5_list(struct sock *sk) md5sig = rcu_dereference_protected(tp->md5sig_info, 1); hlist_for_each_entry_safe(key, n, &md5sig->head, node) { - hlist_del_rcu(&key->node); + hlist_del(&key->node); atomic_sub(sizeof(*key), &sk->sk_omem_alloc); - kfree_rcu(key, rcu); + kfree(key); } } @@ -1248,6 +1512,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, const union tcp_md5_addr *addr; u8 prefixlen = 32; int l3index = 0; + bool l3flag; u8 flags; if (optlen < sizeof(cmd)) @@ -1260,6 +1525,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, return -EINVAL; flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX; + l3flag = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX; if (optname == TCP_MD5SIG_EXT && cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) { @@ -1294,74 +1560,54 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN) return -EINVAL; + /* Don't allow keys for peers that have a matching TCP-AO key. + * See the comment in tcp_ao_add_cmd() + */ + if (tcp_ao_required(sk, addr, AF_INET, l3flag ? l3index : -1, false)) + return -EKEYREJECTED; + return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags, - cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); + cmd.tcpm_key, cmd.tcpm_keylen); } -static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp, - __be32 daddr, __be32 saddr, - const struct tcphdr *th, int nbytes) +static void tcp_v4_md5_hash_headers(struct md5_ctx *ctx, + __be32 daddr, __be32 saddr, + const struct tcphdr *th, int nbytes) { - struct tcp4_pseudohdr *bp; - struct scatterlist sg; - struct tcphdr *_th; - - bp = hp->scratch; - bp->saddr = saddr; - bp->daddr = daddr; - bp->pad = 0; - bp->protocol = IPPROTO_TCP; - bp->len = cpu_to_be16(nbytes); - - _th = (struct tcphdr *)(bp + 1); - memcpy(_th, th, sizeof(*th)); - _th->check = 0; + struct { + struct tcp4_pseudohdr ip; + struct tcphdr tcp; + } h; - sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th)); - ahash_request_set_crypt(hp->md5_req, &sg, NULL, - sizeof(*bp) + sizeof(*th)); - return crypto_ahash_update(hp->md5_req); + h.ip.saddr = saddr; + h.ip.daddr = daddr; + h.ip.pad = 0; + h.ip.protocol = IPPROTO_TCP; + h.ip.len = cpu_to_be16(nbytes); + h.tcp = *th; + h.tcp.check = 0; + md5_update(ctx, (const u8 *)&h, sizeof(h.ip) + sizeof(h.tcp)); } -static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, - __be32 daddr, __be32 saddr, const struct tcphdr *th) +static noinline_for_stack void +tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, + __be32 daddr, __be32 saddr, const struct tcphdr *th) { - struct tcp_md5sig_pool *hp; - struct ahash_request *req; - - hp = tcp_get_md5sig_pool(); - if (!hp) - goto clear_hash_noput; - req = hp->md5_req; - - if (crypto_ahash_init(req)) - goto clear_hash; - if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2)) - goto clear_hash; - if (tcp_md5_hash_key(hp, key)) - goto clear_hash; - ahash_request_set_crypt(req, NULL, md5_hash, 0); - if (crypto_ahash_final(req)) - goto clear_hash; - - tcp_put_md5sig_pool(); - return 0; + struct md5_ctx ctx; -clear_hash: - tcp_put_md5sig_pool(); -clear_hash_noput: - memset(md5_hash, 0, 16); - return 1; + md5_init(&ctx); + tcp_v4_md5_hash_headers(&ctx, daddr, saddr, th, th->doff << 2); + tcp_md5_hash_key(&ctx, key); + md5_final(&ctx, md5_hash); } -int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, - const struct sock *sk, - const struct sk_buff *skb) +noinline_for_stack void +tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, + const struct sock *sk, const struct sk_buff *skb) { - struct tcp_md5sig_pool *hp; - struct ahash_request *req; const struct tcphdr *th = tcp_hdr(skb); __be32 saddr, daddr; + struct md5_ctx ctx; if (sk) { /* valid for establish/request sockets */ saddr = sk->sk_rcv_saddr; @@ -1372,102 +1618,15 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, daddr = iph->daddr; } - hp = tcp_get_md5sig_pool(); - if (!hp) - goto clear_hash_noput; - req = hp->md5_req; - - if (crypto_ahash_init(req)) - goto clear_hash; - - if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, skb->len)) - goto clear_hash; - if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2)) - goto clear_hash; - if (tcp_md5_hash_key(hp, key)) - goto clear_hash; - ahash_request_set_crypt(req, NULL, md5_hash, 0); - if (crypto_ahash_final(req)) - goto clear_hash; - - tcp_put_md5sig_pool(); - return 0; - -clear_hash: - tcp_put_md5sig_pool(); -clear_hash_noput: - memset(md5_hash, 0, 16); - return 1; + md5_init(&ctx); + tcp_v4_md5_hash_headers(&ctx, daddr, saddr, th, skb->len); + tcp_md5_hash_skb_data(&ctx, skb, th->doff << 2); + tcp_md5_hash_key(&ctx, key); + md5_final(&ctx, md5_hash); } -EXPORT_SYMBOL(tcp_v4_md5_hash_skb); - -#endif - -/* Called with rcu_read_lock() */ -static bool tcp_v4_inbound_md5_hash(const struct sock *sk, - const struct sk_buff *skb, - int dif, int sdif) -{ -#ifdef CONFIG_TCP_MD5SIG - /* - * This gets called for each TCP segment that arrives - * so we want to be efficient. - * We have 3 drop cases: - * o No MD5 hash and one expected. - * o MD5 hash and we're not expecting one. - * o MD5 hash and its wrong. - */ - const __u8 *hash_location = NULL; - struct tcp_md5sig_key *hash_expected; - const struct iphdr *iph = ip_hdr(skb); - const struct tcphdr *th = tcp_hdr(skb); - const union tcp_md5_addr *addr; - unsigned char newhash[16]; - int genhash, l3index; - - /* sdif set, means packet ingressed via a device - * in an L3 domain and dif is set to the l3mdev - */ - l3index = sdif ? dif : 0; - - addr = (union tcp_md5_addr *)&iph->saddr; - hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); - hash_location = tcp_parse_md5sig_option(th); +EXPORT_IPV6_MOD(tcp_v4_md5_hash_skb); - /* We've parsed the options - do we have a hash? */ - if (!hash_expected && !hash_location) - return false; - - if (hash_expected && !hash_location) { - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND); - return true; - } - - if (!hash_expected && hash_location) { - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED); - return true; - } - - /* Okay, so this is hash_expected and hash_location - - * so we need to calculate the checksum. - */ - genhash = tcp_v4_md5_hash_skb(newhash, - hash_expected, - NULL, skb); - - if (genhash || memcmp(hash_location, newhash, 16) != 0) { - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE); - net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n", - &iph->saddr, ntohs(th->source), - &iph->daddr, ntohs(th->dest), - genhash ? " tcp_v4_calc_md5_hash failed" - : "", l3index); - return true; - } - return false; #endif - return false; -} static void tcp_v4_init_req(struct request_sock *req, const struct sock *sk_listener, @@ -1484,7 +1643,8 @@ static void tcp_v4_init_req(struct request_sock *req, static struct dst_entry *tcp_v4_route_req(const struct sock *sk, struct sk_buff *skb, struct flowi *fl, - struct request_sock *req) + struct request_sock *req, + u32 tw_isn) { tcp_v4_init_req(req, sk, skb); @@ -1497,11 +1657,9 @@ static struct dst_entry *tcp_v4_route_req(const struct sock *sk, struct request_sock_ops tcp_request_sock_ops __read_mostly = { .family = PF_INET, .obj_size = sizeof(struct tcp_request_sock), - .rtx_syn_ack = tcp_rtx_synack, .send_ack = tcp_v4_reqsk_send_ack, .destructor = tcp_v4_reqsk_destructor, .send_reset = tcp_v4_send_reset, - .syn_ack_timeout = tcp_syn_ack_timeout, }; const struct tcp_request_sock_ops tcp_request_sock_ipv4_ops = { @@ -1510,6 +1668,11 @@ const struct tcp_request_sock_ops tcp_request_sock_ipv4_ops = { .req_md5_lookup = tcp_v4_md5_lookup, .calc_md5_hash = tcp_v4_md5_hash_skb, #endif +#ifdef CONFIG_TCP_AO + .ao_lookup = tcp_v4_ao_lookup_rsk, + .ao_calc_key = tcp_v4_ao_calc_key_rsk, + .ao_synack_hash = tcp_v4_ao_synack_hash, +#endif #ifdef CONFIG_SYN_COOKIES .cookie_init_seq = cookie_v4_init_sequence, #endif @@ -1532,7 +1695,7 @@ drop: tcp_listendrop(sk); return 0; } -EXPORT_SYMBOL(tcp_v4_conn_request); +EXPORT_IPV6_MOD(tcp_v4_conn_request); /* @@ -1570,10 +1733,6 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, newtp = tcp_sk(newsk); newinet = inet_sk(newsk); ireq = inet_rsk(req); - sk_daddr_set(newsk, ireq->ir_rmt_addr); - sk_rcv_saddr_set(newsk, ireq->ir_loc_addr); - newsk->sk_bound_dev_if = ireq->ir_iif; - newinet->inet_saddr = ireq->ir_loc_addr; inet_opt = rcu_dereference(ireq->ireq_opt); RCU_INIT_POINTER(newinet->inet_opt, inet_opt); newinet->mc_index = inet_iif(skb); @@ -1582,12 +1741,12 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, inet_csk(newsk)->icsk_ext_hdr_len = 0; if (inet_opt) inet_csk(newsk)->icsk_ext_hdr_len = inet_opt->opt.optlen; - newinet->inet_id = prandom_u32(); + atomic_set(&newinet->inet_id, get_random_u16()); /* Set ToS of the new socket based upon the value of incoming SYN. * ECT bits are set later in tcp_init_transfer(). */ - if (sock_net(sk)->ipv4.sysctl_tcp_reflect_tos) + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos)) newinet->tos = tcp_rsk(req)->syn_tos & ~INET_ECN_MASK; if (!dst) { @@ -1611,18 +1770,16 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, /* Copy over the MD5 key from the original socket */ addr = (union tcp_md5_addr *)&newinet->inet_daddr; key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); - if (key) { - /* - * We're using one, so create a matching key - * on the newsk structure. If we fail to get - * memory, then we end up not copying the key - * across. Shucks. - */ - tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags, - key->key, key->keylen, GFP_ATOMIC); + if (key && !tcp_rsk_used_ao(req)) { + if (tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key)) + goto put_and_exit; sk_gso_disable(newsk); } #endif +#ifdef CONFIG_TCP_AO + if (tcp_ao_copy_all_matching(sk, newsk, req, skb, AF_INET)) + goto put_and_exit; /* OOM, release back memory */ +#endif if (__inet_inherit_port(sk, newsk) < 0) goto put_and_exit; @@ -1658,7 +1815,7 @@ put_and_exit: tcp_done(newsk); goto exit; } -EXPORT_SYMBOL(tcp_v4_syn_recv_sock); +EXPORT_IPV6_MOD(tcp_v4_syn_recv_sock); static struct sock *tcp_v4_cookie_check(struct sock *sk, struct sk_buff *skb) { @@ -1698,8 +1855,13 @@ INDIRECT_CALLABLE_DECLARE(struct dst_entry *ipv4_dst_check(struct dst_entry *, */ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb) { + enum skb_drop_reason reason; struct sock *rsk; + reason = psp_sk_rx_policy_check(sk, skb); + if (reason) + goto err_discard; + if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */ struct dst_entry *dst; @@ -1727,9 +1889,10 @@ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb) struct sock *nsk = tcp_v4_cookie_check(sk, skb); if (!nsk) - goto discard; + return 0; if (nsk != sk) { - if (tcp_child_process(sk, nsk, skb)) { + reason = tcp_child_process(sk, nsk, skb); + if (reason) { rsk = nsk; goto reset; } @@ -1738,16 +1901,17 @@ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb) } else sock_rps_save_rxhash(sk, skb); - if (tcp_rcv_state_process(sk, skb)) { + reason = tcp_rcv_state_process(sk, skb); + if (reason) { rsk = sk; goto reset; } return 0; reset: - tcp_v4_send_reset(rsk, skb); + tcp_v4_send_reset(rsk, skb, sk_rst_convert_drop_reason(reason)); discard: - kfree_skb(skb); + sk_skb_reason_drop(sk, skb, reason); /* Be careful here. If this function gets more complicated and * gcc suffers from register pressure on the x86, sk (in %ebx) * might be destroyed here. This current version compiles correctly, @@ -1756,8 +1920,10 @@ discard: return 0; csum_err: + reason = SKB_DROP_REASON_TCP_CSUM; trace_tcp_bad_csum(skb); TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); +err_discard: TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); goto discard; } @@ -1765,6 +1931,7 @@ EXPORT_SYMBOL(tcp_v4_do_rcv); int tcp_v4_early_demux(struct sk_buff *skb) { + struct net *net = dev_net_rcu(skb->dev); const struct iphdr *iph; const struct tcphdr *th; struct sock *sk; @@ -1781,8 +1948,7 @@ int tcp_v4_early_demux(struct sk_buff *skb) if (th->doff < sizeof(struct tcphdr) / 4) return 0; - sk = __inet_lookup_established(dev_net(skb->dev), &tcp_hashinfo, - iph->saddr, th->source, + sk = __inet_lookup_established(net, iph->saddr, th->source, iph->daddr, ntohs(th->dest), skb->skb_iif, inet_sdif(skb)); if (sk) { @@ -1801,9 +1967,10 @@ int tcp_v4_early_demux(struct sk_buff *skb) return 0; } -bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb) +bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb, + enum skb_drop_reason *reason) { - u32 limit, tail_gso_size, tail_gso_segs; + u32 tail_gso_size, tail_gso_segs; struct skb_shared_info *shinfo; const struct tcphdr *th; struct tcphdr *thtail; @@ -1812,7 +1979,9 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb) bool fragstolen; u32 gso_segs; u32 gso_size; + u64 limit; int delta; + int err; /* In case all data was pulled from skb frags (in __pskb_pull_tail()), * we can fix skb->truesize to its real value to avoid future drops. @@ -1822,11 +1991,12 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb) */ skb_condense(skb); - skb_dst_drop(skb); + tcp_cleanup_skb(skb); if (unlikely(tcp_checksum_complete(skb))) { bh_unlock_sock(sk); trace_tcp_bad_csum(skb); + *reason = SKB_DROP_REASON_TCP_CSUM; __TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); __TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); return true; @@ -1851,12 +2021,13 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb) !((TCP_SKB_CB(tail)->tcp_flags & TCP_SKB_CB(skb)->tcp_flags) & TCPHDR_ACK) || ((TCP_SKB_CB(tail)->tcp_flags ^ - TCP_SKB_CB(skb)->tcp_flags) & (TCPHDR_ECE | TCPHDR_CWR)) || -#ifdef CONFIG_TLS_DEVICE - tail->decrypted != skb->decrypted || -#endif + TCP_SKB_CB(skb)->tcp_flags) & + (TCPHDR_ECE | TCPHDR_CWR | TCPHDR_AE)) || + !tcp_skb_can_collapse_rx(tail, skb) || thtail->doff != th->doff || - memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th))) + memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th)) || + /* prior to PSP Rx policy check, retain exact PSP metadata */ + psp_skb_coalesce_diff(tail, skb)) goto no_coalesce; __skb_pull(skb, hdrlen); @@ -1907,28 +2078,45 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb) __skb_push(skb, hdrlen); no_coalesce: + /* sk->sk_backlog.len is reset only at the end of __release_sock(). + * Both sk->sk_backlog.len and sk->sk_rmem_alloc could reach + * sk_rcvbuf in normal conditions. + */ + limit = ((u64)READ_ONCE(sk->sk_rcvbuf)) << 1; + + limit += ((u32)READ_ONCE(sk->sk_sndbuf)) >> 1; + /* Only socket owner can try to collapse/prune rx queues * to reduce memory overhead, so add a little headroom here. * Few sockets backlog are possibly concurrently non empty. */ - limit = READ_ONCE(sk->sk_rcvbuf) + READ_ONCE(sk->sk_sndbuf) + 64*1024; + limit += 64 * 1024; - if (unlikely(sk_add_backlog(sk, skb, limit))) { + limit = min_t(u64, limit, UINT_MAX); + + err = sk_add_backlog(sk, skb, limit); + if (unlikely(err)) { bh_unlock_sock(sk); - __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPBACKLOGDROP); + if (err == -ENOMEM) { + *reason = SKB_DROP_REASON_PFMEMALLOC; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_PFMEMALLOCDROP); + } else { + *reason = SKB_DROP_REASON_SOCKET_BACKLOG; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPBACKLOGDROP); + } return true; } return false; } -EXPORT_SYMBOL(tcp_add_backlog); +EXPORT_IPV6_MOD(tcp_add_backlog); -int tcp_filter(struct sock *sk, struct sk_buff *skb) +int tcp_filter(struct sock *sk, struct sk_buff *skb, enum skb_drop_reason *reason) { struct tcphdr *th = (struct tcphdr *)skb->data; - return sk_filter_trim_cap(sk, skb, th->doff * 4); + return sk_filter_trim_cap(sk, skb, th->doff * 4, reason); } -EXPORT_SYMBOL(tcp_filter); +EXPORT_IPV6_MOD(tcp_filter); static void tcp_v4_restore_cb(struct sk_buff *skb) { @@ -1950,8 +2138,7 @@ static void tcp_v4_fill_cb(struct sk_buff *skb, const struct iphdr *iph, TCP_SKB_CB(skb)->end_seq = (TCP_SKB_CB(skb)->seq + th->syn + th->fin + skb->len - th->doff * 4); TCP_SKB_CB(skb)->ack_seq = ntohl(th->ack_seq); - TCP_SKB_CB(skb)->tcp_flags = tcp_flag_byte(th); - TCP_SKB_CB(skb)->tcp_tw_isn = 0; + TCP_SKB_CB(skb)->tcp_flags = tcp_flags_ntohs(th); TCP_SKB_CB(skb)->ip_dsfield = ipv4_get_dsfield(iph); TCP_SKB_CB(skb)->sacked = 0; TCP_SKB_CB(skb)->has_rxtstamp = @@ -1964,15 +2151,17 @@ static void tcp_v4_fill_cb(struct sk_buff *skb, const struct iphdr *iph, int tcp_v4_rcv(struct sk_buff *skb) { - struct net *net = dev_net(skb->dev); + struct net *net = dev_net_rcu(skb->dev); + enum skb_drop_reason drop_reason; + enum tcp_tw_status tw_status; int sdif = inet_sdif(skb); int dif = inet_iif(skb); const struct iphdr *iph; const struct tcphdr *th; + struct sock *sk = NULL; bool refcounted; - struct sock *sk; - int drop_reason; int ret; + u32 isn; drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; if (skb->pkt_type != PACKET_HOST) @@ -2004,12 +2193,11 @@ int tcp_v4_rcv(struct sk_buff *skb) th = (const struct tcphdr *)skb->data; iph = ip_hdr(skb); lookup: - sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source, + sk = __inet_lookup_skb(skb, __tcp_hdrlen(th), th->source, th->dest, sdif, &refcounted); if (!sk) goto no_tcp_socket; -process: if (sk->sk_state == TCP_TIME_WAIT) goto do_time_wait; @@ -2019,8 +2207,14 @@ process: struct sock *nsk; sk = req->rsk_listener; - if (unlikely(tcp_v4_inbound_md5_hash(sk, skb, dif, sdif))) { - sk_drops_add(sk, skb); + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + else + drop_reason = tcp_inbound_hash(sk, req, skb, + &iph->saddr, &iph->daddr, + AF_INET, dif, sdif); + if (unlikely(drop_reason)) { + sk_drops_skbadd(sk, skb); reqsk_put(req); goto discard_it; } @@ -2046,11 +2240,12 @@ process: } refcounted = true; nsk = NULL; - if (!tcp_filter(sk, skb)) { + if (!tcp_filter(sk, skb, &drop_reason)) { th = (const struct tcphdr *)skb->data; iph = ip_hdr(skb); tcp_v4_fill_cb(skb, iph, th); - nsk = tcp_check_req(sk, skb, req, false, &req_stolen); + nsk = tcp_check_req(sk, skb, req, false, &req_stolen, + &drop_reason); } if (!nsk) { reqsk_put(req); @@ -2066,38 +2261,49 @@ process: } goto discard_and_relse; } + nf_reset_ct(skb); if (nsk == sk) { reqsk_put(req); tcp_v4_restore_cb(skb); - } else if (tcp_child_process(sk, nsk, skb)) { - tcp_v4_send_reset(nsk, skb); - goto discard_and_relse; } else { + drop_reason = tcp_child_process(sk, nsk, skb); + if (drop_reason) { + enum sk_rst_reason rst_reason; + + rst_reason = sk_rst_convert_drop_reason(drop_reason); + tcp_v4_send_reset(nsk, skb, rst_reason); + goto discard_and_relse; + } sock_put(sk); return 0; } } +process: if (static_branch_unlikely(&ip4_min_ttl)) { /* min_ttl can be changed concurrently from do_ip_setsockopt() */ if (unlikely(iph->ttl < READ_ONCE(inet_sk(sk)->min_ttl))) { __NET_INC_STATS(net, LINUX_MIB_TCPMINTTLDROP); + drop_reason = SKB_DROP_REASON_TCP_MINTTL; goto discard_and_relse; } } - if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; goto discard_and_relse; + } - if (tcp_v4_inbound_md5_hash(sk, skb, dif, sdif)) + drop_reason = tcp_inbound_hash(sk, NULL, skb, &iph->saddr, &iph->daddr, + AF_INET, dif, sdif); + if (drop_reason) goto discard_and_relse; nf_reset_ct(skb); - if (tcp_filter(sk, skb)) { - drop_reason = SKB_DROP_REASON_TCP_FILTER; + if (tcp_filter(sk, skb, &drop_reason)) goto discard_and_relse; - } + th = (const struct tcphdr *)skb->data; iph = ip_hdr(skb); tcp_v4_fill_cb(skb, iph, th); @@ -2111,14 +2317,13 @@ process: sk_incoming_cpu_update(sk); - sk_defer_free_flush(sk); bh_lock_sock_nested(sk); tcp_segs_in(tcp_sk(sk), skb); ret = 0; if (!sock_owned_by_user(sk)) { ret = tcp_v4_do_rcv(sk, skb); } else { - if (tcp_add_backlog(sk, skb)) + if (tcp_add_backlog(sk, skb, &drop_reason)) goto discard_and_relse; } bh_unlock_sock(sk); @@ -2144,22 +2349,24 @@ csum_error: bad_packet: __TCP_INC_STATS(net, TCP_MIB_INERRS); } else { - tcp_v4_send_reset(NULL, skb); + tcp_v4_send_reset(NULL, skb, sk_rst_convert_drop_reason(drop_reason)); } discard_it: + SKB_DR_OR(drop_reason, NOT_SPECIFIED); /* Discard frame. */ - kfree_skb_reason(skb, drop_reason); + sk_skb_reason_drop(sk, skb, drop_reason); return 0; discard_and_relse: - sk_drops_add(sk, skb); + sk_drops_skbadd(sk, skb); if (refcounted) sock_put(sk); goto discard_it; do_time_wait: if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; inet_twsk_put(inet_twsk(sk)); goto discard_it; } @@ -2170,11 +2377,12 @@ do_time_wait: inet_twsk_put(inet_twsk(sk)); goto csum_error; } - switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) { + + tw_status = tcp_timewait_state_process(inet_twsk(sk), skb, th, &isn, + &drop_reason); + switch (tw_status) { case TCP_TW_SYN: { - struct sock *sk2 = inet_lookup_listener(dev_net(skb->dev), - &tcp_hashinfo, skb, - __tcp_hdrlen(th), + struct sock *sk2 = inet_lookup_listener(net, skb, __tcp_hdrlen(th), iph->saddr, th->source, iph->daddr, th->dest, inet_iif(skb), @@ -2184,16 +2392,22 @@ do_time_wait: sk = sk2; tcp_v4_restore_cb(skb); refcounted = false; + __this_cpu_write(tcp_tw_isn, isn); goto process; } + + drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb); + if (drop_reason) + break; } /* to ACK */ fallthrough; case TCP_TW_ACK: - tcp_v4_timewait_ack(sk, skb); + case TCP_TW_ACK_OOW: + tcp_v4_timewait_ack(sk, skb, tw_status); break; case TCP_TW_RST: - tcp_v4_send_reset(sk, skb); + tcp_v4_send_reset(sk, skb, SK_RST_REASON_TCP_TIMEWAIT_SOCKET); inet_twsk_deschedule_put(inet_twsk(sk)); goto discard_it; case TCP_TW_SUCCESS:; @@ -2203,8 +2417,6 @@ do_time_wait: static struct timewait_sock_ops tcp_timewait_sock_ops = { .twsk_obj_size = sizeof(struct tcp_timewait_sock), - .twsk_unique = tcp_twsk_unique, - .twsk_destructor= tcp_twsk_destructor, }; void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) @@ -2216,7 +2428,7 @@ void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) sk->sk_rx_dst_ifindex = skb->skb_iif; } } -EXPORT_SYMBOL(inet_sk_rx_dst_set); +EXPORT_IPV6_MOD(inet_sk_rx_dst_set); const struct inet_connection_sock_af_ops ipv4_specific = { .queue_xmit = ip_queue_xmit, @@ -2228,18 +2440,31 @@ const struct inet_connection_sock_af_ops ipv4_specific = { .net_header_len = sizeof(struct iphdr), .setsockopt = ip_setsockopt, .getsockopt = ip_getsockopt, - .addr2sockaddr = inet_csk_addr2sockaddr, - .sockaddr_len = sizeof(struct sockaddr_in), .mtu_reduced = tcp_v4_mtu_reduced, }; -EXPORT_SYMBOL(ipv4_specific); +EXPORT_IPV6_MOD(ipv4_specific); -#ifdef CONFIG_TCP_MD5SIG +#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO) static const struct tcp_sock_af_ops tcp_sock_ipv4_specific = { +#ifdef CONFIG_TCP_MD5SIG .md5_lookup = tcp_v4_md5_lookup, .calc_md5_hash = tcp_v4_md5_hash_skb, .md5_parse = tcp_v4_parse_md5_keys, +#endif +#ifdef CONFIG_TCP_AO + .ao_lookup = tcp_v4_ao_lookup, + .calc_ao_hash = tcp_v4_ao_hash_skb, + .ao_parse = tcp_v4_parse_ao, + .ao_calc_key_sk = tcp_v4_ao_calc_key_sk, +#endif }; + +static void tcp4_destruct_sock(struct sock *sk) +{ + tcp_md5_destruct_sock(sk); + tcp_ao_destroy_sock(sk, false); + inet_sock_destruct(sk); +} #endif /* NOTE: A lot of things set to zero explicitly by call to @@ -2253,17 +2478,33 @@ static int tcp_v4_init_sock(struct sock *sk) icsk->icsk_af_ops = &ipv4_specific; -#ifdef CONFIG_TCP_MD5SIG +#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO) tcp_sk(sk)->af_specific = &tcp_sock_ipv4_specific; + sk->sk_destruct = tcp4_destruct_sock; #endif return 0; } +static void tcp_release_user_frags(struct sock *sk) +{ +#ifdef CONFIG_PAGE_POOL + unsigned long index; + void *netmem; + + xa_for_each(&sk->sk_user_frags, index, netmem) + WARN_ON_ONCE(!napi_pp_put_page((__force netmem_ref)netmem)); +#endif +} + void tcp_v4_destroy_sock(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); + tcp_release_user_frags(sk); + + xa_destroy(&sk->sk_user_frags); + trace_tcp_destroy_sock(sk); tcp_clear_xmit_timers(sk); @@ -2281,15 +2522,6 @@ void tcp_v4_destroy_sock(struct sock *sk) /* Cleans up our, hopefully empty, out_of_order_queue. */ skb_rbtree_purge(&tp->out_of_order_queue); -#ifdef CONFIG_TCP_MD5SIG - /* Clean up the MD5 key list, if any */ - if (tp->md5sig_info) { - tcp_clear_md5_list(sk); - kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu); - tp->md5sig_info = NULL; - } -#endif - /* Clean up a referenced TCP bind bucket. */ if (inet_csk(sk)->icsk_bind_hash) inet_put_port(sk); @@ -2303,7 +2535,7 @@ void tcp_v4_destroy_sock(struct sock *sk) sk_sockets_allocated_dec(sk); } -EXPORT_SYMBOL(tcp_v4_destroy_sock); +EXPORT_IPV6_MOD(tcp_v4_destroy_sock); #ifdef CONFIG_PROC_FS /* Proc filesystem TCP sock list dumping. */ @@ -2324,21 +2556,21 @@ static bool seq_sk_match(struct seq_file *seq, const struct sock *sk) */ static void *listening_get_first(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; st->offset = 0; - for (; st->bucket <= tcp_hashinfo.lhash2_mask; st->bucket++) { + for (; st->bucket <= hinfo->lhash2_mask; st->bucket++) { struct inet_listen_hashbucket *ilb2; - struct inet_connection_sock *icsk; + struct hlist_nulls_node *node; struct sock *sk; - ilb2 = &tcp_hashinfo.lhash2[st->bucket]; - if (hlist_empty(&ilb2->head)) + ilb2 = &hinfo->lhash2[st->bucket]; + if (hlist_nulls_empty(&ilb2->nulls_head)) continue; spin_lock(&ilb2->lock); - inet_lhash2_for_each_icsk(icsk, &ilb2->head) { - sk = (struct sock *)icsk; + sk_nulls_for_each(sk, node, &ilb2->nulls_head) { if (seq_sk_match(seq, sk)) return sk; } @@ -2357,20 +2589,21 @@ static void *listening_get_next(struct seq_file *seq, void *cur) { struct tcp_iter_state *st = seq->private; struct inet_listen_hashbucket *ilb2; - struct inet_connection_sock *icsk; + struct hlist_nulls_node *node; + struct inet_hashinfo *hinfo; struct sock *sk = cur; ++st->num; ++st->offset; - icsk = inet_csk(sk); - inet_lhash2_for_each_icsk_continue(icsk) { - sk = (struct sock *)icsk; + sk = sk_nulls_next(sk); + sk_nulls_for_each_from(sk, node) { if (seq_sk_match(seq, sk)) return sk; } - ilb2 = &tcp_hashinfo.lhash2[st->bucket]; + hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + ilb2 = &hinfo->lhash2[st->bucket]; spin_unlock(&ilb2->lock); ++st->bucket; return listening_get_first(seq); @@ -2392,9 +2625,10 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos) return rc; } -static inline bool empty_bucket(const struct tcp_iter_state *st) +static inline bool empty_bucket(struct inet_hashinfo *hinfo, + const struct tcp_iter_state *st) { - return hlist_nulls_empty(&tcp_hashinfo.ehash[st->bucket].chain); + return hlist_nulls_empty(&hinfo->ehash[st->bucket].chain); } /* @@ -2403,20 +2637,23 @@ static inline bool empty_bucket(const struct tcp_iter_state *st) */ static void *established_get_first(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; st->offset = 0; - for (; st->bucket <= tcp_hashinfo.ehash_mask; ++st->bucket) { + for (; st->bucket <= hinfo->ehash_mask; ++st->bucket) { struct sock *sk; struct hlist_nulls_node *node; - spinlock_t *lock = inet_ehash_lockp(&tcp_hashinfo, st->bucket); + spinlock_t *lock = inet_ehash_lockp(hinfo, st->bucket); + + cond_resched(); /* Lockless fast path for the common case of empty buckets */ - if (empty_bucket(st)) + if (empty_bucket(hinfo, st)) continue; spin_lock_bh(lock); - sk_nulls_for_each(sk, node, &tcp_hashinfo.ehash[st->bucket].chain) { + sk_nulls_for_each(sk, node, &hinfo->ehash[st->bucket].chain) { if (seq_sk_match(seq, sk)) return sk; } @@ -2428,9 +2665,10 @@ static void *established_get_first(struct seq_file *seq) static void *established_get_next(struct seq_file *seq, void *cur) { - struct sock *sk = cur; - struct hlist_nulls_node *node; + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; + struct hlist_nulls_node *node; + struct sock *sk = cur; ++st->num; ++st->offset; @@ -2442,7 +2680,7 @@ static void *established_get_next(struct seq_file *seq, void *cur) return sk; } - spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); ++st->bucket; return established_get_first(seq); } @@ -2480,6 +2718,7 @@ static void *tcp_get_idx(struct seq_file *seq, loff_t pos) static void *tcp_seek_last_pos(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; int bucket = st->bucket; int offset = st->offset; @@ -2488,9 +2727,8 @@ static void *tcp_seek_last_pos(struct seq_file *seq) switch (st->state) { case TCP_SEQ_STATE_LISTENING: - if (st->bucket > tcp_hashinfo.lhash2_mask) + if (st->bucket > hinfo->lhash2_mask) break; - st->state = TCP_SEQ_STATE_LISTENING; rc = listening_get_first(seq); while (offset-- && rc && bucket == st->bucket) rc = listening_get_next(seq, rc); @@ -2500,7 +2738,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq) st->state = TCP_SEQ_STATE_ESTABLISHED; fallthrough; case TCP_SEQ_STATE_ESTABLISHED: - if (st->bucket > tcp_hashinfo.ehash_mask) + if (st->bucket > hinfo->ehash_mask) break; rc = established_get_first(seq); while (offset-- && rc && bucket == st->bucket) @@ -2533,7 +2771,7 @@ out: st->last_pos = *pos; return rc; } -EXPORT_SYMBOL(tcp_seq_start); +EXPORT_IPV6_MOD(tcp_seq_start); void *tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) { @@ -2564,24 +2802,25 @@ out: st->last_pos = *pos; return rc; } -EXPORT_SYMBOL(tcp_seq_next); +EXPORT_IPV6_MOD(tcp_seq_next); void tcp_seq_stop(struct seq_file *seq, void *v) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct tcp_iter_state *st = seq->private; switch (st->state) { case TCP_SEQ_STATE_LISTENING: if (v != SEQ_START_TOKEN) - spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); + spin_unlock(&hinfo->lhash2[st->bucket].lock); break; case TCP_SEQ_STATE_ESTABLISHED: if (v) - spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); break; } } -EXPORT_SYMBOL(tcp_seq_stop); +EXPORT_IPV6_MOD(tcp_seq_stop); static void get_openreq4(const struct request_sock *req, struct seq_file *f, int i) @@ -2602,7 +2841,7 @@ static void get_openreq4(const struct request_sock *req, jiffies_delta_to_clock_t(delta), req->num_timeout, from_kuid_munged(seq_user_ns(f), - sock_i_uid(req->rsk_listener)), + sk_uid(req->rsk_listener)), 0, /* non standard timer */ 0, /* open_requests have no inode */ 0, @@ -2621,20 +2860,22 @@ static void get_tcp4_sock(struct sock *sk, struct seq_file *f, int i) __be32 src = inet->inet_rcv_saddr; __u16 destp = ntohs(inet->inet_dport); __u16 srcp = ntohs(inet->inet_sport); + u8 icsk_pending; int rx_queue; int state; - if (icsk->icsk_pending == ICSK_TIME_RETRANS || - icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT || - icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { + icsk_pending = smp_load_acquire(&icsk->icsk_pending); + if (icsk_pending == ICSK_TIME_RETRANS || + icsk_pending == ICSK_TIME_REO_TIMEOUT || + icsk_pending == ICSK_TIME_LOSS_PROBE) { timer_active = 1; - timer_expires = icsk->icsk_timeout; - } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) { + timer_expires = tcp_timeout_expires(sk); + } else if (icsk_pending == ICSK_TIME_PROBE0) { timer_active = 4; - timer_expires = icsk->icsk_timeout; - } else if (timer_pending(&sk->sk_timer)) { + timer_expires = tcp_timeout_expires(sk); + } else if (timer_pending(&icsk->icsk_keepalive_timer)) { timer_active = 2; - timer_expires = sk->sk_timer.expires; + timer_expires = icsk->icsk_keepalive_timer.expires; } else { timer_active = 0; timer_expires = jiffies; @@ -2657,15 +2898,15 @@ static void get_tcp4_sock(struct sock *sk, struct seq_file *f, int i) rx_queue, timer_active, jiffies_delta_to_clock_t(timer_expires - jiffies), - icsk->icsk_retransmits, - from_kuid_munged(seq_user_ns(f), sock_i_uid(sk)), - icsk->icsk_probes_out, + READ_ONCE(icsk->icsk_retransmits), + from_kuid_munged(seq_user_ns(f), sk_uid(sk)), + READ_ONCE(icsk->icsk_probes_out), sock_i_ino(sk), refcount_read(&sk->sk_refcnt), sk, jiffies_to_clock_t(icsk->icsk_rto), jiffies_to_clock_t(icsk->icsk_ack.ato), (icsk->icsk_ack.quick << 1) | inet_csk_in_pingpong_mode(sk), - tp->snd_cwnd, + tcp_snd_cwnd(tp), state == TCP_LISTEN ? fastopenq->max_qlen : (tcp_in_initial_slowstart(tp) ? -1 : tp->snd_ssthresh)); @@ -2685,7 +2926,7 @@ static void get_timewait4_sock(const struct inet_timewait_sock *tw, seq_printf(f, "%4d: %08X:%04X %08X:%04X" " %02X %08X:%08X %02X:%08lX %08X %5d %8d %d %d %pK", - i, src, srcp, dest, destp, tw->tw_substate, 0, 0, + i, src, srcp, dest, destp, READ_ONCE(tw->tw_substate), 0, 0, 3, jiffies_delta_to_clock_t(delta), 0, 0, 0, 0, refcount_read(&tw->tw_refcnt), tw); } @@ -2718,13 +2959,17 @@ out: } #ifdef CONFIG_BPF_SYSCALL +union bpf_tcp_iter_batch_item { + struct sock *sk; + __u64 cookie; +}; + struct bpf_tcp_iter_state { struct tcp_iter_state state; unsigned int cur_sk; unsigned int end_sk; unsigned int max_sk; - struct sock **batch; - bool st_bucket_done; + union bpf_tcp_iter_batch_item *batch; }; struct bpf_iter__tcp { @@ -2747,21 +2992,32 @@ static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter) { - while (iter->cur_sk < iter->end_sk) - sock_put(iter->batch[iter->cur_sk++]); + union bpf_tcp_iter_batch_item *item; + unsigned int cur_sk = iter->cur_sk; + __u64 cookie; + + /* Remember the cookies of the sockets we haven't seen yet, so we can + * pick up where we left off next time around. + */ + while (cur_sk < iter->end_sk) { + item = &iter->batch[cur_sk++]; + cookie = sock_gen_cookie(item->sk); + sock_gen_put(item->sk); + item->cookie = cookie; + } } static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, - unsigned int new_batch_sz) + unsigned int new_batch_sz, gfp_t flags) { - struct sock **new_batch; + union bpf_tcp_iter_batch_item *new_batch; new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz, - GFP_USER | __GFP_NOWARN); + flags | __GFP_NOWARN); if (!new_batch) return -ENOMEM; - bpf_iter_tcp_put_batch(iter); + memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk); kvfree(iter->batch); iter->batch = new_batch; iter->max_sk = new_batch_sz; @@ -2769,110 +3025,234 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, return 0; } -static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, - struct sock *start_sk) +static struct sock *bpf_iter_tcp_resume_bucket(struct sock *first_sk, + union bpf_tcp_iter_batch_item *cookies, + int n_cookies) +{ + struct hlist_nulls_node *node; + struct sock *sk; + int i; + + for (i = 0; i < n_cookies; i++) { + sk = first_sk; + sk_nulls_for_each_from(sk, node) + if (cookies[i].cookie == atomic64_read(&sk->sk_cookie)) + return sk; + } + + return NULL; +} + +static struct sock *bpf_iter_tcp_resume_listening(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + unsigned int find_cookie = iter->cur_sk; + unsigned int end_cookie = iter->end_sk; + int resume_bucket = st->bucket; + struct sock *sk; + + if (end_cookie && find_cookie == end_cookie) + ++st->bucket; + + sk = listening_get_first(seq); + iter->cur_sk = 0; + iter->end_sk = 0; + + if (sk && st->bucket == resume_bucket && end_cookie) { + sk = bpf_iter_tcp_resume_bucket(sk, &iter->batch[find_cookie], + end_cookie - find_cookie); + if (!sk) { + spin_unlock(&hinfo->lhash2[st->bucket].lock); + ++st->bucket; + sk = listening_get_first(seq); + } + } + + return sk; +} + +static struct sock *bpf_iter_tcp_resume_established(struct seq_file *seq) { + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; struct tcp_iter_state *st = &iter->state; - struct inet_connection_sock *icsk; + unsigned int find_cookie = iter->cur_sk; + unsigned int end_cookie = iter->end_sk; + int resume_bucket = st->bucket; + struct sock *sk; + + if (end_cookie && find_cookie == end_cookie) + ++st->bucket; + + sk = established_get_first(seq); + iter->cur_sk = 0; + iter->end_sk = 0; + + if (sk && st->bucket == resume_bucket && end_cookie) { + sk = bpf_iter_tcp_resume_bucket(sk, &iter->batch[find_cookie], + end_cookie - find_cookie); + if (!sk) { + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); + ++st->bucket; + sk = established_get_first(seq); + } + } + + return sk; +} + +static struct sock *bpf_iter_tcp_resume(struct seq_file *seq) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct sock *sk = NULL; + + switch (st->state) { + case TCP_SEQ_STATE_LISTENING: + sk = bpf_iter_tcp_resume_listening(seq); + if (sk) + break; + st->bucket = 0; + st->state = TCP_SEQ_STATE_ESTABLISHED; + fallthrough; + case TCP_SEQ_STATE_ESTABLISHED: + sk = bpf_iter_tcp_resume_established(seq); + break; + } + + return sk; +} + +static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, + struct sock **start_sk) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct hlist_nulls_node *node; unsigned int expected = 1; struct sock *sk; - sock_hold(start_sk); - iter->batch[iter->end_sk++] = start_sk; + sock_hold(*start_sk); + iter->batch[iter->end_sk++].sk = *start_sk; - icsk = inet_csk(start_sk); - inet_lhash2_for_each_icsk_continue(icsk) { - sk = (struct sock *)icsk; + sk = sk_nulls_next(*start_sk); + *start_sk = NULL; + sk_nulls_for_each_from(sk, node) { if (seq_sk_match(seq, sk)) { if (iter->end_sk < iter->max_sk) { sock_hold(sk); - iter->batch[iter->end_sk++] = sk; + iter->batch[iter->end_sk++].sk = sk; + } else if (!*start_sk) { + /* Remember where we left off. */ + *start_sk = sk; } expected++; } } - spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock); return expected; } static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, - struct sock *start_sk) + struct sock **start_sk) { struct bpf_tcp_iter_state *iter = seq->private; - struct tcp_iter_state *st = &iter->state; struct hlist_nulls_node *node; unsigned int expected = 1; struct sock *sk; - sock_hold(start_sk); - iter->batch[iter->end_sk++] = start_sk; + sock_hold(*start_sk); + iter->batch[iter->end_sk++].sk = *start_sk; - sk = sk_nulls_next(start_sk); + sk = sk_nulls_next(*start_sk); + *start_sk = NULL; sk_nulls_for_each_from(sk, node) { if (seq_sk_match(seq, sk)) { if (iter->end_sk < iter->max_sk) { sock_hold(sk); - iter->batch[iter->end_sk++] = sk; + iter->batch[iter->end_sk++].sk = sk; + } else if (!*start_sk) { + /* Remember where we left off. */ + *start_sk = sk; } expected++; } } - spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket)); return expected; } -static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) +static unsigned int bpf_iter_fill_batch(struct seq_file *seq, + struct sock **start_sk) { struct bpf_tcp_iter_state *iter = seq->private; struct tcp_iter_state *st = &iter->state; + + if (st->state == TCP_SEQ_STATE_LISTENING) + return bpf_iter_tcp_listening_batch(seq, start_sk); + else + return bpf_iter_tcp_established_batch(seq, start_sk); +} + +static void bpf_iter_tcp_unlock_bucket(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + + if (st->state == TCP_SEQ_STATE_LISTENING) + spin_unlock(&hinfo->lhash2[st->bucket].lock); + else + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); +} + +static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) +{ + struct bpf_tcp_iter_state *iter = seq->private; unsigned int expected; - bool resized = false; struct sock *sk; + int err; - /* The st->bucket is done. Directly advance to the next - * bucket instead of having the tcp_seek_last_pos() to skip - * one by one in the current bucket and eventually find out - * it has to advance to the next bucket. - */ - if (iter->st_bucket_done) { - st->offset = 0; - st->bucket++; - if (st->state == TCP_SEQ_STATE_LISTENING && - st->bucket > tcp_hashinfo.lhash2_mask) { - st->state = TCP_SEQ_STATE_ESTABLISHED; - st->bucket = 0; - } - } + sk = bpf_iter_tcp_resume(seq); + if (!sk) + return NULL; /* Done */ -again: - /* Get a new batch */ - iter->cur_sk = 0; - iter->end_sk = 0; - iter->st_bucket_done = false; + expected = bpf_iter_fill_batch(seq, &sk); + if (likely(iter->end_sk == expected)) + goto done; + + /* Batch size was too small. */ + bpf_iter_tcp_unlock_bucket(seq); + bpf_iter_tcp_put_batch(iter); + err = bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2, + GFP_USER); + if (err) + return ERR_PTR(err); - sk = tcp_seek_last_pos(seq); + sk = bpf_iter_tcp_resume(seq); if (!sk) return NULL; /* Done */ - if (st->state == TCP_SEQ_STATE_LISTENING) - expected = bpf_iter_tcp_listening_batch(seq, sk); - else - expected = bpf_iter_tcp_established_batch(seq, sk); - - if (iter->end_sk == expected) { - iter->st_bucket_done = true; - return sk; - } + expected = bpf_iter_fill_batch(seq, &sk); + if (likely(iter->end_sk == expected)) + goto done; - if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) { - resized = true; - goto again; + /* Batch size was still too small. Hold onto the lock while we try + * again with a larger batch to make sure the current bucket's size + * does not change in the meantime. + */ + err = bpf_iter_tcp_realloc_batch(iter, expected, GFP_NOWAIT); + if (err) { + bpf_iter_tcp_unlock_bucket(seq); + return ERR_PTR(err); } - return sk; + expected = bpf_iter_fill_batch(seq, &sk); + WARN_ON_ONCE(iter->end_sk != expected); +done: + bpf_iter_tcp_unlock_bucket(seq); + return iter->batch[0].sk; } static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos) @@ -2902,16 +3282,11 @@ static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) * meta.seq_num is used instead. */ st->num++; - /* Move st->offset to the next sk in the bucket such that - * the future start() will resume at st->offset in - * st->bucket. See tcp_seek_last_pos(). - */ - st->offset++; - sock_put(iter->batch[iter->cur_sk++]); + sock_gen_put(iter->batch[iter->cur_sk++].sk); } if (iter->cur_sk < iter->end_sk) - sk = iter->batch[iter->cur_sk]; + sk = iter->batch[iter->cur_sk].sk; else sk = bpf_iter_tcp_batch(seq); @@ -2928,7 +3303,6 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) struct bpf_iter_meta meta; struct bpf_prog *prog; struct sock *sk = v; - bool slow; uid_t uid; int ret; @@ -2936,7 +3310,7 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) return 0; if (sk_fullsock(sk)) - slow = lock_sock_fast(sk); + lock_sock(sk); if (unlikely(sk_unhashed(sk))) { ret = SEQ_SKIP; @@ -2949,9 +3323,9 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) const struct request_sock *req = v; uid = from_kuid_munged(seq_user_ns(seq), - sock_i_uid(req->rsk_listener)); + sk_uid(req->rsk_listener)); } else { - uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)); + uid = from_kuid_munged(seq_user_ns(seq), sk_uid(sk)); } meta.seq = seq; @@ -2960,7 +3334,7 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) unlock: if (sk_fullsock(sk)) - unlock_sock_fast(sk, slow); + release_sock(sk); return ret; } @@ -2978,10 +3352,8 @@ static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v) (void)tcp_prog_seq_show(prog, &meta, v, 0); } - if (iter->cur_sk < iter->end_sk) { + if (iter->cur_sk < iter->end_sk) bpf_iter_tcp_put_batch(iter); - iter->st_bucket_done = false; - } } static const struct seq_operations bpf_iter_tcp_seq_ops = { @@ -3078,7 +3450,7 @@ struct proto tcp_prot = { .keepalive = tcp_set_keepalive, .recvmsg = tcp_recvmsg, .sendmsg = tcp_sendmsg, - .sendpage = tcp_sendpage, + .splice_eof = tcp_splice_eof, .backlog_rcv = tcp_v4_do_rcv, .release_cb = tcp_release_cb, .hash = inet_hash, @@ -3092,8 +3464,10 @@ struct proto tcp_prot = { .leave_memory_pressure = tcp_leave_memory_pressure, .stream_memory_free = tcp_stream_memory_free, .sockets_allocated = &tcp_sockets_allocated, - .orphan_count = &tcp_orphan_count, - .memory_allocated = &tcp_memory_allocated, + + .memory_allocated = &net_aligned_data.tcp_memory_allocated, + .per_cpu_fw_alloc = &tcp_memory_per_cpu_fw_alloc, + .memory_pressure = &tcp_memory_pressure, .sysctl_mem = sysctl_tcp_mem, .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem), @@ -3103,7 +3477,7 @@ struct proto tcp_prot = { .slab_flags = SLAB_TYPESAFE_BY_RCU, .twsk_prot = &tcp_timewait_sock_ops, .rsk_prot = &tcp_request_sock_ops, - .h.hashinfo = &tcp_hashinfo, + .h.hashinfo = NULL, .no_autobind = true, .diag_destroy = tcp_abort, }; @@ -3111,43 +3485,46 @@ EXPORT_SYMBOL(tcp_prot); static void __net_exit tcp_sk_exit(struct net *net) { - int cpu; - if (net->ipv4.tcp_congestion_control) bpf_module_put(net->ipv4.tcp_congestion_control, net->ipv4.tcp_congestion_control->owner); - - for_each_possible_cpu(cpu) - inet_ctl_sock_destroy(*per_cpu_ptr(net->ipv4.tcp_sk, cpu)); - free_percpu(net->ipv4.tcp_sk); } -static int __net_init tcp_sk_init(struct net *net) +static void __net_init tcp_set_hashinfo(struct net *net) { - int res, cpu, cnt; + struct inet_hashinfo *hinfo; + unsigned int ehash_entries; + struct net *old_net; - net->ipv4.tcp_sk = alloc_percpu(struct sock *); - if (!net->ipv4.tcp_sk) - return -ENOMEM; + if (net_eq(net, &init_net)) + goto fallback; - for_each_possible_cpu(cpu) { - struct sock *sk; + old_net = current->nsproxy->net_ns; + ehash_entries = READ_ONCE(old_net->ipv4.sysctl_tcp_child_ehash_entries); + if (!ehash_entries) + goto fallback; - res = inet_ctl_sock_create(&sk, PF_INET, SOCK_RAW, - IPPROTO_TCP, net); - if (res) - goto fail; - sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); - - /* Please enforce IP_DF and IPID==0 for RST and - * ACK sent in SYN-RECV and TIME-WAIT state. - */ - inet_sk(sk)->pmtudisc = IP_PMTUDISC_DO; - - *per_cpu_ptr(net->ipv4.tcp_sk, cpu) = sk; + ehash_entries = roundup_pow_of_two(ehash_entries); + hinfo = inet_pernet_hashinfo_alloc(&tcp_hashinfo, ehash_entries); + if (!hinfo) { + pr_warn("Failed to allocate TCP ehash (entries: %u) " + "for a netns, fallback to the global one\n", + ehash_entries); +fallback: + hinfo = &tcp_hashinfo; + ehash_entries = tcp_hashinfo.ehash_mask + 1; } - net->ipv4.sysctl_tcp_ecn = 2; + net->ipv4.tcp_death_row.hashinfo = hinfo; + net->ipv4.tcp_death_row.sysctl_max_tw_buckets = ehash_entries / 2; + net->ipv4.sysctl_max_syn_backlog = max(128U, ehash_entries / 128); +} + +static int __net_init tcp_sk_init(struct net *net) +{ + net->ipv4.sysctl_tcp_ecn = TCP_ECN_IN_ECN_OUT_NOECN; + net->ipv4.sysctl_tcp_ecn_option = TCP_ACCECN_OPTION_FULL; + net->ipv4.sysctl_tcp_ecn_option_beacon = TCP_ACCECN_OPTION_BEACON; net->ipv4.sysctl_tcp_ecn_fallback = 1; net->ipv4.sysctl_tcp_base_mss = TCP_BASE_MSS; @@ -3170,13 +3547,12 @@ static int __net_init tcp_sk_init(struct net *net) net->ipv4.sysctl_tcp_fin_timeout = TCP_FIN_TIMEOUT; net->ipv4.sysctl_tcp_notsent_lowat = UINT_MAX; net->ipv4.sysctl_tcp_tw_reuse = 2; + net->ipv4.sysctl_tcp_tw_reuse_delay = 1 * MSEC_PER_SEC; net->ipv4.sysctl_tcp_no_ssthresh_metrics_save = 1; - cnt = tcp_hashinfo.ehash_mask + 1; - net->ipv4.tcp_death_row.sysctl_max_tw_buckets = cnt / 2; - net->ipv4.tcp_death_row.hashinfo = &tcp_hashinfo; + refcount_set(&net->ipv4.tcp_death_row.tw_refcount, 1); + tcp_set_hashinfo(net); - net->ipv4.sysctl_max_syn_backlog = max(128, cnt / 128); net->ipv4.sysctl_tcp_sack = 1; net->ipv4.sysctl_tcp_window_scaling = 1; net->ipv4.sysctl_tcp_timestamps = 1; @@ -3190,16 +3566,20 @@ static int __net_init tcp_sk_init(struct net *net) net->ipv4.sysctl_tcp_adv_win_scale = 1; net->ipv4.sysctl_tcp_frto = 2; net->ipv4.sysctl_tcp_moderate_rcvbuf = 1; + net->ipv4.sysctl_tcp_rcvbuf_low_rtt = USEC_PER_MSEC; /* This limits the percentage of the congestion window which we * will allow a single TSO frame to consume. Building TSO frames * which are too large can cause TCP streams to be bursty. */ net->ipv4.sysctl_tcp_tso_win_divisor = 3; - /* Default TSQ limit of 16 TSO segments */ - net->ipv4.sysctl_tcp_limit_output_bytes = 16 * 65536; - /* rfc5961 challenge ack rate limiting */ - net->ipv4.sysctl_tcp_challenge_ack_limit = 1000; + /* Default TSQ limit of 4 MB */ + net->ipv4.sysctl_tcp_limit_output_bytes = 4 << 20; + + /* rfc5961 challenge ack rate limiting, per net-ns, disabled by default. */ + net->ipv4.sysctl_tcp_challenge_ack_limit = INT_MAX; + net->ipv4.sysctl_tcp_min_tso_segs = 2; + net->ipv4.sysctl_tcp_tso_rtt_log = 9; /* 2^9 = 512 usec */ net->ipv4.sysctl_tcp_min_rtt_wlen = 300; net->ipv4.sysctl_tcp_autocorking = 1; net->ipv4.sysctl_tcp_invalid_ratelimit = HZ/2; @@ -3214,12 +3594,22 @@ static int __net_init tcp_sk_init(struct net *net) sizeof(init_net.ipv4.sysctl_tcp_wmem)); } net->ipv4.sysctl_tcp_comp_sack_delay_ns = NSEC_PER_MSEC; - net->ipv4.sysctl_tcp_comp_sack_slack_ns = 100 * NSEC_PER_USEC; + net->ipv4.sysctl_tcp_comp_sack_slack_ns = 10 * NSEC_PER_USEC; net->ipv4.sysctl_tcp_comp_sack_nr = 44; + net->ipv4.sysctl_tcp_comp_sack_rtt_percent = 33; + net->ipv4.sysctl_tcp_backlog_ack_defer = 1; net->ipv4.sysctl_tcp_fastopen = TFO_CLIENT_ENABLE; net->ipv4.sysctl_tcp_fastopen_blackhole_timeout = 0; atomic_set(&net->ipv4.tfo_active_disable_times, 0); + /* Set default values for PLB */ + net->ipv4.sysctl_tcp_plb_enabled = 0; /* Disabled by default */ + net->ipv4.sysctl_tcp_plb_idle_rehash_rounds = 3; + net->ipv4.sysctl_tcp_plb_rehash_rounds = 12; + net->ipv4.sysctl_tcp_plb_suspend_rto_sec = 60; + /* Default congestion threshold for PLB to mark a round is 50% */ + net->ipv4.sysctl_tcp_plb_cong_thresh = (1 << TCP_PLB_SCALE) / 2; + /* Reno is always built in */ if (!net_eq(net, &init_net) && bpf_try_module_get(init_net.ipv4.tcp_congestion_control, @@ -3228,21 +3618,39 @@ static int __net_init tcp_sk_init(struct net *net) else net->ipv4.tcp_congestion_control = &tcp_reno; - return 0; -fail: - tcp_sk_exit(net); + net->ipv4.sysctl_tcp_syn_linear_timeouts = 4; + net->ipv4.sysctl_tcp_shrink_window = 0; + + net->ipv4.sysctl_tcp_pingpong_thresh = 1; + net->ipv4.sysctl_tcp_rto_min_us = jiffies_to_usecs(TCP_RTO_MIN); + net->ipv4.sysctl_tcp_rto_max_ms = TCP_RTO_MAX_SEC * MSEC_PER_SEC; - return res; + return 0; } static void __net_exit tcp_sk_exit_batch(struct list_head *net_exit_list) { struct net *net; - inet_twsk_purge(&tcp_hashinfo, AF_INET); + /* make sure concurrent calls to tcp_sk_exit_batch from net_cleanup_work + * and failed setup_net error unwinding path are serialized. + * + * tcp_twsk_purge() handles twsk in any dead netns, not just those in + * net_exit_list, the thread that dismantles a particular twsk must + * do so without other thread progressing to refcount_dec_and_test() of + * tcp_death_row.tw_refcount. + */ + mutex_lock(&tcp_exit_batch_mutex); - list_for_each_entry(net, net_exit_list, exit_list) + tcp_twsk_purge(net_exit_list); + + list_for_each_entry(net, net_exit_list, exit_list) { + inet_pernet_hashinfo_free(net->ipv4.tcp_death_row.hashinfo); + WARN_ON_ONCE(!refcount_dec_and_test(&net->ipv4.tcp_death_row.tw_refcount)); tcp_fastopen_ctx_destroy(net); + } + + mutex_unlock(&tcp_exit_batch_mutex); } static struct pernet_operations __net_initdata tcp_sk_ops = { @@ -3266,7 +3674,7 @@ static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux) if (err) return err; - err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ); + err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ, GFP_USER); if (err) { bpf_iter_fini_seq_net(priv_data); return err; @@ -3309,7 +3717,7 @@ static struct bpf_iter_reg tcp_reg_info = { .ctx_arg_info_size = 1, .ctx_arg_info = { { offsetof(struct bpf_iter__tcp, sk_common), - PTR_TO_BTF_ID_OR_NULL }, + PTR_TO_BTF_ID_OR_NULL | PTR_TRUSTED }, }, .get_func_proto = bpf_iter_tcp_get_func_proto, .seq_info = &tcp_seq_info, @@ -3326,6 +3734,26 @@ static void __init bpf_iter_register(void) void __init tcp_v4_init(void) { + int cpu, res; + + for_each_possible_cpu(cpu) { + struct sock *sk; + + res = inet_ctl_sock_create(&sk, PF_INET, SOCK_RAW, + IPPROTO_TCP, &init_net); + if (res) + panic("Failed to create the TCP control socket.\n"); + sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); + + /* Please enforce IP_DF and IPID==0 for RST and + * ACK sent in SYN-RECV and TIME-WAIT state. + */ + inet_sk(sk)->pmtudisc = IP_PMTUDISC_DO; + + sk->sk_clockid = CLOCK_MONOTONIC; + + per_cpu(ipv4_tcp_sk.sock, cpu) = sk; + } if (register_pernet_subsys(&tcp_sk_ops)) panic("Failed to create the TCP control socket.\n"); diff --git a/net/ipv4/tcp_lp.c b/net/ipv4/tcp_lp.c index 82b36ec3f2f8..976b56644a8a 100644 --- a/net/ipv4/tcp_lp.c +++ b/net/ipv4/tcp_lp.c @@ -23,9 +23,9 @@ * Original Author: * Aleksandar Kuzmanovic <akuzma@northwestern.edu> * Available from: - * http://www.ece.rice.edu/~akuzma/Doc/akuzma/TCP-LP.pdf + * https://users.cs.northwestern.edu/~akuzma/doc/TCP-LP-ToN.pdf * Original implementation for 2.4.19: - * http://www-ece.rice.edu/networks/TCP-LP/ + * https://users.cs.northwestern.edu/~akuzma/rice/TCP-LP/linux/tcp-lp-linux.htm * * 2.6.x module Authors: * Wong Hoi Sing, Edison <hswong3i@gmail.com> @@ -113,6 +113,8 @@ static void tcp_lp_init(struct sock *sk) /** * tcp_lp_cong_avoid * @sk: socket to avoid congesting + * @ack: current ack sequence number + * @acked: number of ACKed packets * * Implementation of cong_avoid. * Will only call newReno CA when away from inference. @@ -261,6 +263,7 @@ static void tcp_lp_rtt_sample(struct sock *sk, u32 rtt) /** * tcp_lp_pkts_acked * @sk: socket requiring congestion avoidance calculations + * @sample: ACK sample containing timing and rate information * * Implementation of pkts_acked. * Deal with active drop under Early Congestion Indication. @@ -272,7 +275,7 @@ static void tcp_lp_pkts_acked(struct sock *sk, const struct ack_sample *sample) { struct tcp_sock *tp = tcp_sk(sk); struct lp *lp = inet_csk_ca(sk); - u32 now = tcp_time_stamp(tp); + u32 now = tcp_time_stamp_ts(tp); u32 delta; if (sample->rtt_us > 0) @@ -297,7 +300,7 @@ static void tcp_lp_pkts_acked(struct sock *sk, const struct ack_sample *sample) lp->flag &= ~LP_WITHIN_THR; pr_debug("TCP-LP: %05o|%5u|%5u|%15u|%15u|%15u\n", lp->flag, - tp->snd_cwnd, lp->remote_hz, lp->owd_min, lp->owd_max, + tcp_snd_cwnd(tp), lp->remote_hz, lp->owd_min, lp->owd_max, lp->sowd >> 3); if (lp->flag & LP_WITHIN_THR) @@ -313,12 +316,12 @@ static void tcp_lp_pkts_acked(struct sock *sk, const struct ack_sample *sample) /* happened within inference * drop snd_cwnd into 1 */ if (lp->flag & LP_WITHIN_INF) - tp->snd_cwnd = 1U; + tcp_snd_cwnd_set(tp, 1U); /* happened after inference * cut snd_cwnd into half */ else - tp->snd_cwnd = max(tp->snd_cwnd >> 1U, 1U); + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp) >> 1U, 1U)); /* record this drop time */ lp->last_drop = now; diff --git a/net/ipv4/tcp_metrics.c b/net/ipv4/tcp_metrics.c index 0588b004ddac..45b6ecd16412 100644 --- a/net/ipv4/tcp_metrics.c +++ b/net/ipv4/tcp_metrics.c @@ -40,7 +40,7 @@ struct tcp_fastopen_metrics { struct tcp_metrics_block { struct tcp_metrics_block __rcu *tcpm_next; - possible_net_t tcpm_net; + struct net *tcpm_net; struct inetpeer_addr tcpm_saddr; struct inetpeer_addr tcpm_daddr; unsigned long tcpm_stamp; @@ -51,34 +51,38 @@ struct tcp_metrics_block { struct rcu_head rcu_head; }; -static inline struct net *tm_net(struct tcp_metrics_block *tm) +static inline struct net *tm_net(const struct tcp_metrics_block *tm) { - return read_pnet(&tm->tcpm_net); + /* Paired with the WRITE_ONCE() in tcpm_new() */ + return READ_ONCE(tm->tcpm_net); } static bool tcp_metric_locked(struct tcp_metrics_block *tm, enum tcp_metric_index idx) { - return tm->tcpm_lock & (1 << idx); + /* Paired with WRITE_ONCE() in tcpm_suck_dst() */ + return READ_ONCE(tm->tcpm_lock) & (1 << idx); } -static u32 tcp_metric_get(struct tcp_metrics_block *tm, +static u32 tcp_metric_get(const struct tcp_metrics_block *tm, enum tcp_metric_index idx) { - return tm->tcpm_vals[idx]; + /* Paired with WRITE_ONCE() in tcp_metric_set() */ + return READ_ONCE(tm->tcpm_vals[idx]); } static void tcp_metric_set(struct tcp_metrics_block *tm, enum tcp_metric_index idx, u32 val) { - tm->tcpm_vals[idx] = val; + /* Paired with READ_ONCE() in tcp_metric_get() */ + WRITE_ONCE(tm->tcpm_vals[idx], val); } static bool addr_same(const struct inetpeer_addr *a, const struct inetpeer_addr *b) { - return inetpeer_addr_cmp(a, b) == 0; + return (a->family == b->family) && !inetpeer_addr_cmp(a, b); } struct tcpm_hash_bucket { @@ -89,6 +93,7 @@ static struct tcpm_hash_bucket *tcp_metrics_hash __read_mostly; static unsigned int tcp_metrics_hash_log __read_mostly; static DEFINE_SPINLOCK(tcp_metrics_lock); +static DEFINE_SEQLOCK(fastopen_seqlock); static void tcpm_suck_dst(struct tcp_metrics_block *tm, const struct dst_entry *dst, @@ -97,7 +102,7 @@ static void tcpm_suck_dst(struct tcp_metrics_block *tm, u32 msval; u32 val; - tm->tcpm_stamp = jiffies; + WRITE_ONCE(tm->tcpm_stamp, jiffies); val = 0; if (dst_metric_locked(dst, RTAX_RTT)) @@ -110,30 +115,42 @@ static void tcpm_suck_dst(struct tcp_metrics_block *tm, val |= 1 << TCP_METRIC_CWND; if (dst_metric_locked(dst, RTAX_REORDERING)) val |= 1 << TCP_METRIC_REORDERING; - tm->tcpm_lock = val; + /* Paired with READ_ONCE() in tcp_metric_locked() */ + WRITE_ONCE(tm->tcpm_lock, val); msval = dst_metric_raw(dst, RTAX_RTT); - tm->tcpm_vals[TCP_METRIC_RTT] = msval * USEC_PER_MSEC; + tcp_metric_set(tm, TCP_METRIC_RTT, msval * USEC_PER_MSEC); msval = dst_metric_raw(dst, RTAX_RTTVAR); - tm->tcpm_vals[TCP_METRIC_RTTVAR] = msval * USEC_PER_MSEC; - tm->tcpm_vals[TCP_METRIC_SSTHRESH] = dst_metric_raw(dst, RTAX_SSTHRESH); - tm->tcpm_vals[TCP_METRIC_CWND] = dst_metric_raw(dst, RTAX_CWND); - tm->tcpm_vals[TCP_METRIC_REORDERING] = dst_metric_raw(dst, RTAX_REORDERING); + tcp_metric_set(tm, TCP_METRIC_RTTVAR, msval * USEC_PER_MSEC); + tcp_metric_set(tm, TCP_METRIC_SSTHRESH, + dst_metric_raw(dst, RTAX_SSTHRESH)); + tcp_metric_set(tm, TCP_METRIC_CWND, + dst_metric_raw(dst, RTAX_CWND)); + tcp_metric_set(tm, TCP_METRIC_REORDERING, + dst_metric_raw(dst, RTAX_REORDERING)); if (fastopen_clear) { + write_seqlock(&fastopen_seqlock); tm->tcpm_fastopen.mss = 0; tm->tcpm_fastopen.syn_loss = 0; tm->tcpm_fastopen.try_exp = 0; tm->tcpm_fastopen.cookie.exp = false; tm->tcpm_fastopen.cookie.len = 0; + write_sequnlock(&fastopen_seqlock); } } #define TCP_METRICS_TIMEOUT (60 * 60 * HZ) -static void tcpm_check_stamp(struct tcp_metrics_block *tm, struct dst_entry *dst) +static void tcpm_check_stamp(struct tcp_metrics_block *tm, + const struct dst_entry *dst) { - if (tm && unlikely(time_after(jiffies, tm->tcpm_stamp + TCP_METRICS_TIMEOUT))) + unsigned long limit; + + if (!tm) + return; + limit = READ_ONCE(tm->tcpm_stamp) + TCP_METRICS_TIMEOUT; + if (unlikely(time_after(jiffies, limit))) tcpm_suck_dst(tm, dst, false); } @@ -149,11 +166,11 @@ static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst, unsigned int hash) { struct tcp_metrics_block *tm; - struct net *net; bool reclaim = false; + struct net *net; spin_lock_bh(&tcp_metrics_lock); - net = dev_net(dst->dev); + net = dst_dev_net_rcu(dst); /* While waiting for the spin-lock the cache might have been populated * with this entry and so we have to check again. @@ -174,20 +191,23 @@ static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst, oldest = deref_locked(tcp_metrics_hash[hash].chain); for (tm = deref_locked(oldest->tcpm_next); tm; tm = deref_locked(tm->tcpm_next)) { - if (time_before(tm->tcpm_stamp, oldest->tcpm_stamp)) + if (time_before(READ_ONCE(tm->tcpm_stamp), + READ_ONCE(oldest->tcpm_stamp))) oldest = tm; } tm = oldest; } else { - tm = kmalloc(sizeof(*tm), GFP_ATOMIC); + tm = kzalloc(sizeof(*tm), GFP_ATOMIC); if (!tm) goto out_unlock; } - write_pnet(&tm->tcpm_net, net); + /* Paired with the READ_ONCE() in tm_net() */ + WRITE_ONCE(tm->tcpm_net, net); + tm->tcpm_saddr = *saddr; tm->tcpm_daddr = *daddr; - tcpm_suck_dst(tm, dst, true); + tcpm_suck_dst(tm, dst, reclaim); if (likely(!reclaim)) { tm->tcpm_next = tcp_metrics_hash[hash].chain; @@ -253,7 +273,7 @@ static struct tcp_metrics_block *__tcp_get_metrics_req(struct request_sock *req, return NULL; } - net = dev_net(dst->dev); + net = dst_dev_net_rcu(dst); hash ^= net_hash_mix(net); hash = hash_32(hash, tcp_metrics_hash_log); @@ -298,7 +318,7 @@ static struct tcp_metrics_block *tcp_get_metrics(struct sock *sk, else return NULL; - net = dev_net(dst->dev); + net = dst_dev_net_rcu(dst); hash ^= net_hash_mix(net); hash = hash_32(hash, tcp_metrics_hash_log); @@ -329,7 +349,7 @@ void tcp_update_metrics(struct sock *sk) int m; sk_dst_confirm(sk); - if (net->ipv4.sysctl_tcp_nometrics_save || !dst) + if (READ_ONCE(net->ipv4.sysctl_tcp_nometrics_save) || !dst) return; rcu_read_lock(); @@ -385,29 +405,29 @@ void tcp_update_metrics(struct sock *sk) if (tcp_in_initial_slowstart(tp)) { /* Slow start still did not finish. */ - if (!net->ipv4.sysctl_tcp_no_ssthresh_metrics_save && + if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) && !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) { val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH); - if (val && (tp->snd_cwnd >> 1) > val) + if (val && (tcp_snd_cwnd(tp) >> 1) > val) tcp_metric_set(tm, TCP_METRIC_SSTHRESH, - tp->snd_cwnd >> 1); + tcp_snd_cwnd(tp) >> 1); } if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) { val = tcp_metric_get(tm, TCP_METRIC_CWND); - if (tp->snd_cwnd > val) + if (tcp_snd_cwnd(tp) > val) tcp_metric_set(tm, TCP_METRIC_CWND, - tp->snd_cwnd); + tcp_snd_cwnd(tp)); } } else if (!tcp_in_slow_start(tp) && icsk->icsk_ca_state == TCP_CA_Open) { /* Cong. avoidance phase, cwnd is reliable. */ - if (!net->ipv4.sysctl_tcp_no_ssthresh_metrics_save && + if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) && !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) tcp_metric_set(tm, TCP_METRIC_SSTHRESH, - max(tp->snd_cwnd >> 1, tp->snd_ssthresh)); + max(tcp_snd_cwnd(tp) >> 1, tp->snd_ssthresh)); if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) { val = tcp_metric_get(tm, TCP_METRIC_CWND); - tcp_metric_set(tm, TCP_METRIC_CWND, (val + tp->snd_cwnd) >> 1); + tcp_metric_set(tm, TCP_METRIC_CWND, (val + tcp_snd_cwnd(tp)) >> 1); } } else { /* Else slow start did not finish, cwnd is non-sense, @@ -418,7 +438,7 @@ void tcp_update_metrics(struct sock *sk) tcp_metric_set(tm, TCP_METRIC_CWND, (val + tp->snd_ssthresh) >> 1); } - if (!net->ipv4.sysctl_tcp_no_ssthresh_metrics_save && + if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) && !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) { val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH); if (val && tp->snd_ssthresh > val) @@ -428,12 +448,13 @@ void tcp_update_metrics(struct sock *sk) if (!tcp_metric_locked(tm, TCP_METRIC_REORDERING)) { val = tcp_metric_get(tm, TCP_METRIC_REORDERING); if (val < tp->reordering && - tp->reordering != net->ipv4.sysctl_tcp_reordering) + tp->reordering != + READ_ONCE(net->ipv4.sysctl_tcp_reordering)) tcp_metric_set(tm, TCP_METRIC_REORDERING, tp->reordering); } } - tm->tcpm_stamp = jiffies; + WRITE_ONCE(tm->tcpm_stamp, jiffies); out_unlock: rcu_read_unlock(); } @@ -449,11 +470,15 @@ void tcp_init_metrics(struct sock *sk) u32 val, crtt = 0; /* cached RTT scaled by 8 */ sk_dst_confirm(sk); + /* ssthresh may have been reduced unnecessarily during. + * 3WHS. Restore it back to its initial default. + */ + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; if (!dst) goto reset; rcu_read_lock(); - tm = tcp_get_metrics(sk, dst, true); + tm = tcp_get_metrics(sk, dst, false); if (!tm) { rcu_read_unlock(); goto reset; @@ -462,17 +487,12 @@ void tcp_init_metrics(struct sock *sk) if (tcp_metric_locked(tm, TCP_METRIC_CWND)) tp->snd_cwnd_clamp = tcp_metric_get(tm, TCP_METRIC_CWND); - val = net->ipv4.sysctl_tcp_no_ssthresh_metrics_save ? + val = READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) ? 0 : tcp_metric_get(tm, TCP_METRIC_SSTHRESH); if (val) { tp->snd_ssthresh = val; if (tp->snd_ssthresh > tp->snd_cwnd_clamp) tp->snd_ssthresh = tp->snd_cwnd_clamp; - } else { - /* ssthresh may have been reduced unnecessarily during. - * 3WHS. Restore it back to its initial default. - */ - tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; } val = tcp_metric_get(tm, TCP_METRIC_REORDERING); if (val && tp->reordering != val) @@ -538,8 +558,6 @@ bool tcp_peer_is_proven(struct request_sock *req, struct dst_entry *dst) return ret; } -static DEFINE_SEQLOCK(fastopen_seqlock); - void tcp_fastopen_cache_get(struct sock *sk, u16 *mss, struct tcp_fastopen_cookie *cookie) { @@ -599,8 +617,13 @@ static struct genl_family tcp_metrics_nl_family; static const struct nla_policy tcp_metrics_nl_policy[TCP_METRICS_ATTR_MAX + 1] = { [TCP_METRICS_ATTR_ADDR_IPV4] = { .type = NLA_U32, }, - [TCP_METRICS_ATTR_ADDR_IPV6] = { .type = NLA_BINARY, - .len = sizeof(struct in6_addr), }, + [TCP_METRICS_ATTR_ADDR_IPV6] = + NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + + [TCP_METRICS_ATTR_SADDR_IPV4] = { .type = NLA_U32, }, + [TCP_METRICS_ATTR_SADDR_IPV6] = + NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + /* Following attributes are not received for GET/DEL, * we keep them for reference */ @@ -646,7 +669,7 @@ static int tcp_metrics_fill_info(struct sk_buff *msg, } if (nla_put_msecs(msg, TCP_METRICS_ATTR_AGE, - jiffies - tm->tcpm_stamp, + jiffies - READ_ONCE(tm->tcpm_stamp), TCP_METRICS_ATTR_PAD) < 0) goto nla_put_failure; @@ -657,7 +680,7 @@ static int tcp_metrics_fill_info(struct sk_buff *msg, if (!nest) goto nla_put_failure; for (i = 0; i < TCP_METRIC_MAX_KERNEL + 1; i++) { - u32 val = tm->tcpm_vals[i]; + u32 val = tcp_metric_get(tm, i); if (!val) continue; @@ -748,6 +771,7 @@ static int tcp_metrics_nl_dump(struct sk_buff *skb, unsigned int max_rows = 1U << tcp_metrics_hash_log; unsigned int row, s_row = cb->args[0]; int s_col = cb->args[1], col = s_col; + int res = 0; for (row = s_row; row < max_rows; row++, s_col = 0) { struct tcp_metrics_block *tm; @@ -760,7 +784,8 @@ static int tcp_metrics_nl_dump(struct sk_buff *skb, continue; if (col < s_col) continue; - if (tcp_metrics_dump_info(skb, cb, tm) < 0) { + res = tcp_metrics_dump_info(skb, cb, tm); + if (res < 0) { rcu_read_unlock(); goto done; } @@ -771,7 +796,7 @@ static int tcp_metrics_nl_dump(struct sk_buff *skb, done: cb->args[0] = row; cb->args[1] = col; - return skb->len; + return res; } static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr, @@ -790,8 +815,6 @@ static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr, if (a) { struct in6_addr in6; - if (nla_len(a) != sizeof(struct in6_addr)) - return -EINVAL; in6 = nla_get_in6_addr(a); inetpeer_set_addr_v6(addr, &in6); if (hash) @@ -880,22 +903,25 @@ static void tcp_metrics_flush_all(struct net *net) unsigned int row; for (row = 0; row < max_rows; row++, hb++) { - struct tcp_metrics_block __rcu **pp; + struct tcp_metrics_block __rcu **pp = &hb->chain; bool match; + if (!rcu_access_pointer(*pp)) + continue; + spin_lock_bh(&tcp_metrics_lock); - pp = &hb->chain; for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) { match = net ? net_eq(tm_net(tm), net) : - !refcount_read(&tm_net(tm)->ns.count); + !check_net(tm_net(tm)); if (match) { - *pp = tm->tcpm_next; + rcu_assign_pointer(*pp, tm->tcpm_next); kfree_rcu(tm, rcu_head); } else { pp = &tm->tcpm_next; } } spin_unlock_bh(&tcp_metrics_lock); + cond_resched(); } } @@ -930,7 +956,7 @@ static int tcp_metrics_nl_cmd_del(struct sk_buff *skb, struct genl_info *info) if (addr_same(&tm->tcpm_daddr, &daddr) && (!src || addr_same(&tm->tcpm_saddr, &saddr)) && net_eq(tm_net(tm), net)) { - *pp = tm->tcpm_next; + rcu_assign_pointer(*pp, tm->tcpm_next); kfree_rcu(tm, rcu_head); found = true; } else { @@ -965,12 +991,14 @@ static struct genl_family tcp_metrics_nl_family __ro_after_init = { .maxattr = TCP_METRICS_ATTR_MAX, .policy = tcp_metrics_nl_policy, .netnsok = true, + .parallel_ops = true, .module = THIS_MODULE, .small_ops = tcp_metrics_nl_ops, .n_small_ops = ARRAY_SIZE(tcp_metrics_nl_ops), + .resv_start_op = TCP_METRICS_CMD_DEL + 1, }; -static unsigned int tcpmhash_entries; +static unsigned int tcpmhash_entries __initdata; static int __init set_tcpmhash_entries(char *str) { ssize_t ret; @@ -986,15 +1014,11 @@ static int __init set_tcpmhash_entries(char *str) } __setup("tcpmhash_entries=", set_tcpmhash_entries); -static int __net_init tcp_net_metrics_init(struct net *net) +static void __init tcp_metrics_hash_alloc(void) { + unsigned int slots = tcpmhash_entries; size_t size; - unsigned int slots; - - if (!net_eq(net, &init_net)) - return 0; - slots = tcpmhash_entries; if (!slots) { if (totalram_pages() >= 128 * 1024) slots = 16 * 1024; @@ -1007,9 +1031,7 @@ static int __net_init tcp_net_metrics_init(struct net *net) tcp_metrics_hash = kvzalloc(size, GFP_KERNEL); if (!tcp_metrics_hash) - return -ENOMEM; - - return 0; + panic("Could not allocate the tcp_metrics hash table\n"); } static void __net_exit tcp_net_metrics_exit_batch(struct list_head *net_exit_list) @@ -1018,7 +1040,6 @@ static void __net_exit tcp_net_metrics_exit_batch(struct list_head *net_exit_lis } static __net_initdata struct pernet_operations tcp_net_metrics_ops = { - .init = tcp_net_metrics_init, .exit_batch = tcp_net_metrics_exit_batch, }; @@ -1026,9 +1047,11 @@ void __init tcp_metrics_init(void) { int ret; + tcp_metrics_hash_alloc(); + ret = register_pernet_subsys(&tcp_net_metrics_ops); if (ret < 0) - panic("Could not allocate the tcp_metrics hash table\n"); + panic("Could not register tcp_net_metrics_ops\n"); ret = genl_register_family(&tcp_metrics_nl_family); if (ret < 0) diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index 7c2d3ac2363a..bd5462154f97 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -20,8 +20,11 @@ */ #include <net/tcp.h> +#include <net/tcp_ecn.h> #include <net/xfrm.h> #include <net/busy_poll.h> +#include <net/rstreason.h> +#include <net/psp.h> static bool tcp_in_window(u32 seq, u32 end_seq, u32 s_win, u32 e_win) { @@ -43,7 +46,7 @@ tcp_timewait_check_oow_rate_limit(struct inet_timewait_sock *tw, /* Send ACK. Note, we do not put the bucket, * it will be released by caller. */ - return TCP_TW_ACK; + return TCP_TW_ACK_OOW; } /* We are rate-limiting, so just release the tw sock and drop skb. */ @@ -51,6 +54,19 @@ tcp_timewait_check_oow_rate_limit(struct inet_timewait_sock *tw, return TCP_TW_SUCCESS; } +static void twsk_rcv_nxt_update(struct tcp_timewait_sock *tcptw, u32 seq, + u32 rcv_nxt) +{ +#ifdef CONFIG_TCP_AO + struct tcp_ao_info *ao; + + ao = rcu_dereference(tcptw->ao_info); + if (unlikely(ao && seq < rcv_nxt)) + WRITE_ONCE(ao->rcv_sne, ao->rcv_sne + 1); +#endif + WRITE_ONCE(tcptw->tw_rcv_nxt, seq); +} + /* * * Main purpose of TIME-WAIT state is to close connection gracefully, * when one of ends sits in LAST-ACK or CLOSING retransmitting FIN @@ -83,45 +99,59 @@ tcp_timewait_check_oow_rate_limit(struct inet_timewait_sock *tw, */ enum tcp_tw_status tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, - const struct tcphdr *th) + const struct tcphdr *th, u32 *tw_isn, + enum skb_drop_reason *drop_reason) { - struct tcp_options_received tmp_opt; struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); + u32 rcv_nxt = READ_ONCE(tcptw->tw_rcv_nxt); + struct tcp_options_received tmp_opt; + enum skb_drop_reason psp_drop; bool paws_reject = false; + int ts_recent_stamp; + + /* Instead of dropping immediately, wait to see what value is + * returned. We will accept a non psp-encapsulated syn in the + * case where TCP_TW_SYN is returned. + */ + psp_drop = psp_twsk_rx_policy_check(tw, skb); tmp_opt.saw_tstamp = 0; - if (th->doff > (sizeof(*th) >> 2) && tcptw->tw_ts_recent_stamp) { + ts_recent_stamp = READ_ONCE(tcptw->tw_ts_recent_stamp); + if (th->doff > (sizeof(*th) >> 2) && ts_recent_stamp) { tcp_parse_options(twsk_net(tw), skb, &tmp_opt, 0, NULL); if (tmp_opt.saw_tstamp) { if (tmp_opt.rcv_tsecr) tmp_opt.rcv_tsecr -= tcptw->tw_ts_offset; - tmp_opt.ts_recent = tcptw->tw_ts_recent; - tmp_opt.ts_recent_stamp = tcptw->tw_ts_recent_stamp; + tmp_opt.ts_recent = READ_ONCE(tcptw->tw_ts_recent); + tmp_opt.ts_recent_stamp = ts_recent_stamp; paws_reject = tcp_paws_reject(&tmp_opt, th->rst); } } - if (tw->tw_substate == TCP_FIN_WAIT2) { + if (READ_ONCE(tw->tw_substate) == TCP_FIN_WAIT2) { /* Just repeat all the checks of tcp_rcv_state_process() */ + if (psp_drop) + goto out_put; + /* Out of window, send ACK */ if (paws_reject || !tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, - tcptw->tw_rcv_nxt, - tcptw->tw_rcv_nxt + tcptw->tw_rcv_wnd)) + rcv_nxt, + rcv_nxt + tcptw->tw_rcv_wnd)) return tcp_timewait_check_oow_rate_limit( tw, skb, LINUX_MIB_TCPACKSKIPPEDFINWAIT2); if (th->rst) goto kill; - if (th->syn && !before(TCP_SKB_CB(skb)->seq, tcptw->tw_rcv_nxt)) + if (th->syn && !before(TCP_SKB_CB(skb)->seq, rcv_nxt)) return TCP_TW_RST; /* Dup ACK? */ if (!th->ack || - !after(TCP_SKB_CB(skb)->end_seq, tcptw->tw_rcv_nxt) || + !after(TCP_SKB_CB(skb)->end_seq, rcv_nxt) || TCP_SKB_CB(skb)->end_seq == TCP_SKB_CB(skb)->seq) { inet_twsk_put(tw); return TCP_TW_SUCCESS; @@ -131,15 +161,22 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, * reset. */ if (!th->fin || - TCP_SKB_CB(skb)->end_seq != tcptw->tw_rcv_nxt + 1) + TCP_SKB_CB(skb)->end_seq != rcv_nxt + 1) return TCP_TW_RST; /* FIN arrived, enter true time-wait state. */ - tw->tw_substate = TCP_TIME_WAIT; - tcptw->tw_rcv_nxt = TCP_SKB_CB(skb)->end_seq; + WRITE_ONCE(tw->tw_substate, TCP_TIME_WAIT); + twsk_rcv_nxt_update(tcptw, TCP_SKB_CB(skb)->end_seq, + rcv_nxt); + if (tmp_opt.saw_tstamp) { - tcptw->tw_ts_recent_stamp = ktime_get_seconds(); - tcptw->tw_ts_recent = tmp_opt.rcv_tsval; + u64 ts = tcp_clock_ms(); + + WRITE_ONCE(tw->tw_entry_stamp, ts); + WRITE_ONCE(tcptw->tw_ts_recent_stamp, + div_u64(ts, MSEC_PER_SEC)); + WRITE_ONCE(tcptw->tw_ts_recent, + tmp_opt.rcv_tsval); } inet_twsk_reschedule(tw, TCP_TIMEWAIT_LEN); @@ -164,16 +201,19 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, */ if (!paws_reject && - (TCP_SKB_CB(skb)->seq == tcptw->tw_rcv_nxt && + (TCP_SKB_CB(skb)->seq == rcv_nxt && (TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq || th->rst))) { /* In window segment, it may be only reset or bare ack. */ + if (psp_drop) + goto out_put; + if (th->rst) { /* This is TIME_WAIT assassination, in two flavors. * Oh well... nobody has a sufficient solution to this * protocol bug yet. */ - if (twsk_net(tw)->ipv4.sysctl_tcp_rfc1337 == 0) { + if (!READ_ONCE(twsk_net(tw)->ipv4.sysctl_tcp_rfc1337)) { kill: inet_twsk_deschedule_put(tw); return TCP_TW_SUCCESS; @@ -183,8 +223,10 @@ kill: } if (tmp_opt.saw_tstamp) { - tcptw->tw_ts_recent = tmp_opt.rcv_tsval; - tcptw->tw_ts_recent_stamp = ktime_get_seconds(); + WRITE_ONCE(tcptw->tw_ts_recent, + tmp_opt.rcv_tsval); + WRITE_ONCE(tcptw->tw_ts_recent_stamp, + ktime_get_seconds()); } inet_twsk_put(tw); @@ -209,18 +251,23 @@ kill: */ if (th->syn && !th->rst && !th->ack && !paws_reject && - (after(TCP_SKB_CB(skb)->seq, tcptw->tw_rcv_nxt) || + (after(TCP_SKB_CB(skb)->seq, rcv_nxt) || (tmp_opt.saw_tstamp && - (s32)(tcptw->tw_ts_recent - tmp_opt.rcv_tsval) < 0))) { + (s32)(READ_ONCE(tcptw->tw_ts_recent) - tmp_opt.rcv_tsval) < 0))) { u32 isn = tcptw->tw_snd_nxt + 65535 + 2; if (isn == 0) isn++; - TCP_SKB_CB(skb)->tcp_tw_isn = isn; + *tw_isn = isn; return TCP_TW_SYN; } - if (paws_reject) - __NET_INC_STATS(twsk_net(tw), LINUX_MIB_PAWSESTABREJECTED); + if (psp_drop) + goto out_put; + + if (paws_reject) { + *drop_reason = SKB_DROP_REASON_TCP_RFC7323_TW_PAWS; + __NET_INC_STATS(twsk_net(tw), LINUX_MIB_PAWS_TW_REJECTED); + } if (!th->rst) { /* In this case we must reset the TIMEWAIT timer. @@ -235,10 +282,44 @@ kill: return tcp_timewait_check_oow_rate_limit( tw, skb, LINUX_MIB_TCPACKSKIPPEDTIMEWAIT); } + +out_put: inet_twsk_put(tw); return TCP_TW_SUCCESS; } -EXPORT_SYMBOL(tcp_timewait_state_process); +EXPORT_IPV6_MOD(tcp_timewait_state_process); + +static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw) +{ +#ifdef CONFIG_TCP_MD5SIG + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + + /* + * The timewait bucket does not have the key DB from the + * sock structure. We just make a quick copy of the + * md5 key being used (if indeed we are using one) + * so the timewait ack generating code has the key. + */ + tcptw->tw_md5_key = NULL; + if (!static_branch_unlikely(&tcp_md5_needed.key)) + return; + + key = tp->af_specific->md5_lookup(sk, sk); + if (key) { + tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); + if (!tcptw->tw_md5_key) + return; + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) + goto out_free; + } + return; +out_free: + WARN_ON_ONCE(1); + kfree(tcptw->tw_md5_key); + tcptw->tw_md5_key = NULL; +#endif +} /* * Move a socket to time-wait or dead fin-wait-2 state. @@ -246,29 +327,35 @@ EXPORT_SYMBOL(tcp_timewait_state_process); void tcp_time_wait(struct sock *sk, int state, int timeo) { const struct inet_connection_sock *icsk = inet_csk(sk); - const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); struct inet_timewait_sock *tw; - struct inet_timewait_death_row *tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; - tw = inet_twsk_alloc(sk, tcp_death_row, state); + tw = inet_twsk_alloc(sk, &net->ipv4.tcp_death_row, state); if (tw) { struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); const int rto = (icsk->icsk_rto << 2) - (icsk->icsk_rto >> 1); - struct inet_sock *inet = inet_sk(sk); - tw->tw_transparent = inet->transparent; tw->tw_mark = sk->sk_mark; - tw->tw_priority = sk->sk_priority; + tw->tw_priority = READ_ONCE(sk->sk_priority); tw->tw_rcv_wscale = tp->rx_opt.rcv_wscale; + /* refreshed when we enter true TIME-WAIT state */ + tw->tw_entry_stamp = tcp_time_stamp_ms(tp); tcptw->tw_rcv_nxt = tp->rcv_nxt; tcptw->tw_snd_nxt = tp->snd_nxt; tcptw->tw_rcv_wnd = tcp_receive_window(tp); tcptw->tw_ts_recent = tp->rx_opt.ts_recent; tcptw->tw_ts_recent_stamp = tp->rx_opt.ts_recent_stamp; tcptw->tw_ts_offset = tp->tsoffset; + tw->tw_usec_ts = tp->tcp_usec_ts; tcptw->tw_last_oow_ack_time = 0; tcptw->tw_tx_delay = tp->tcp_tx_delay; + tw->tw_txhash = sk->sk_txhash; + tw->tw_tx_queue_mapping = sk->sk_tx_queue_mapping; +#ifdef CONFIG_SOCK_RX_QUEUE_MAPPING + tw->tw_rx_queue_mapping = sk->sk_rx_queue_mapping; +#endif #if IS_ENABLED(CONFIG_IPV6) if (tw->tw_family == PF_INET6) { struct ipv6_pinfo *np = inet6_sk(sk); @@ -277,31 +364,12 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) tw->tw_v6_rcv_saddr = sk->sk_v6_rcv_saddr; tw->tw_tclass = np->tclass; tw->tw_flowlabel = be32_to_cpu(np->flow_label & IPV6_FLOWLABEL_MASK); - tw->tw_txhash = sk->sk_txhash; tw->tw_ipv6only = sk->sk_ipv6only; } #endif -#ifdef CONFIG_TCP_MD5SIG - /* - * The timewait bucket does not have the key DB from the - * sock structure. We just make a quick copy of the - * md5 key being used (if indeed we are using one) - * so the timewait ack generating code has the key. - */ - do { - tcptw->tw_md5_key = NULL; - if (static_branch_unlikely(&tcp_md5_needed)) { - struct tcp_md5sig_key *key; - - key = tp->af_specific->md5_lookup(sk, sk); - if (key) { - tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); - BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); - } - } - } while (0); -#endif + tcp_time_wait_init(sk, tcptw); + tcp_ao_time_wait(tcptw, tp); /* Get the TIME_WAIT timeout firing. */ if (timeo < rto) @@ -310,23 +378,16 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) if (state == TCP_TIME_WAIT) timeo = TCP_TIMEWAIT_LEN; - /* tw_timer is pinned, so we need to make sure BH are disabled - * in following section, otherwise timer handler could run before - * we complete the initialization. - */ - local_bh_disable(); - inet_twsk_schedule(tw, timeo); /* Linkage updates. * Note that access to tw after this point is illegal. */ - inet_twsk_hashdance(tw, sk, &tcp_hashinfo); - local_bh_enable(); + inet_twsk_hashdance_schedule(tw, sk, net->ipv4.tcp_death_row.hashinfo, timeo); } else { /* Sorry, if we're out of memory, just CLOSE this * socket up. We've got bigger problems than * non-graceful socket closings. */ - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPTIMEWAITOVERFLOW); + NET_INC_STATS(net, LINUX_MIB_TCPTIMEWAITOVERFLOW); } tcp_update_metrics(sk); @@ -337,15 +398,34 @@ EXPORT_SYMBOL(tcp_time_wait); void tcp_twsk_destructor(struct sock *sk) { #ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed)) { + if (static_branch_unlikely(&tcp_md5_needed.key)) { struct tcp_timewait_sock *twsk = tcp_twsk(sk); - if (twsk->tw_md5_key) - kfree_rcu(twsk->tw_md5_key, rcu); + if (twsk->tw_md5_key) { + kfree(twsk->tw_md5_key); + static_branch_slow_dec_deferred(&tcp_md5_needed); + } } #endif + tcp_ao_destroy_sock(sk, true); + psp_twsk_assoc_free(inet_twsk(sk)); +} + +void tcp_twsk_purge(struct list_head *net_exit_list) +{ + bool purged_once = false; + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) { + if (net->ipv4.tcp_death_row.hashinfo->pernet) { + /* Even if tw_refcount == 1, we must clean up kernel reqsk */ + inet_twsk_purge(net->ipv4.tcp_death_row.hashinfo); + } else if (!purged_once) { + inet_twsk_purge(&tcp_hashinfo); + purged_once = true; + } + } } -EXPORT_SYMBOL_GPL(tcp_twsk_destructor); /* Warning : This function is called without sk_listener being locked. * Be sure to read socket fields once, as their value could change under us. @@ -388,12 +468,27 @@ void tcp_openreq_init_rwin(struct request_sock *req, rcv_wnd); ireq->rcv_wscale = rcv_wscale; } -EXPORT_SYMBOL(tcp_openreq_init_rwin); -static void tcp_ecn_openreq_child(struct tcp_sock *tp, - const struct request_sock *req) +static void tcp_ecn_openreq_child(struct sock *sk, + const struct request_sock *req, + const struct sk_buff *skb) { - tp->ecn_flags = inet_rsk(req)->ecn_ok ? TCP_ECN_OK : 0; + const struct tcp_request_sock *treq = tcp_rsk(req); + struct tcp_sock *tp = tcp_sk(sk); + + if (treq->accecn_ok) { + tcp_ecn_mode_set(tp, TCP_ECN_MODE_ACCECN); + tp->syn_ect_snt = treq->syn_ect_snt; + tcp_accecn_third_ack(sk, skb, treq->syn_ect_snt); + tp->saw_accecn_opt = treq->saw_accecn_opt; + tp->prev_ecnfield = treq->syn_ect_rcv; + tp->accecn_opt_demand = 1; + tcp_ecn_received_counters_payload(sk, skb); + } else { + tcp_ecn_mode_set(tp, inet_rsk(req)->ecn_ok ? + TCP_ECN_MODE_RFC3168 : + TCP_ECN_DISABLED); + } } void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst) @@ -423,9 +518,9 @@ void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst) tcp_set_ca_state(sk, TCP_CA_Open); } -EXPORT_SYMBOL_GPL(tcp_ca_openreq_child); +EXPORT_IPV6_MOD_GPL(tcp_ca_openreq_child); -static void smc_check_reset_syn_req(struct tcp_sock *oldtp, +static void smc_check_reset_syn_req(const struct tcp_sock *oldtp, struct request_sock *req, struct tcp_sock *newtp) { @@ -454,7 +549,8 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, const struct inet_request_sock *ireq = inet_rsk(req); struct tcp_request_sock *treq = tcp_rsk(req); struct inet_connection_sock *newicsk; - struct tcp_sock *oldtp, *newtp; + const struct tcp_sock *oldtp; + struct tcp_sock *newtp; u32 seq; if (!newsk) @@ -489,15 +585,14 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, newicsk->icsk_ack.lrcvtime = tcp_jiffies32; newtp->lsndtime = tcp_jiffies32; - newsk->sk_txhash = treq->txhash; + newsk->sk_txhash = READ_ONCE(treq->txhash); newtp->total_retrans = req->num_retrans; tcp_init_xmit_timers(newsk); WRITE_ONCE(newtp->write_seq, newtp->pushed_seq = treq->snt_isn + 1); if (sock_flag(newsk, SOCK_KEEPOPEN)) - inet_csk_reset_keepalive_timer(newsk, - keepalive_time_when(newtp)); + tcp_reset_keepalive_timer(newsk, keepalive_time_when(newtp)); newtp->rx_opt.tstamp_ok = ireq->tstamp_ok; newtp->rx_opt.sack_ok = ireq->sack_ok; @@ -516,35 +611,59 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, newtp->max_window = newtp->snd_wnd; if (newtp->rx_opt.tstamp_ok) { + newtp->tcp_usec_ts = treq->req_usec_ts; newtp->rx_opt.ts_recent = req->ts_recent; newtp->rx_opt.ts_recent_stamp = ktime_get_seconds(); newtp->tcp_header_len = sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED; } else { + newtp->tcp_usec_ts = 0; newtp->rx_opt.ts_recent_stamp = 0; newtp->tcp_header_len = sizeof(struct tcphdr); } if (req->num_timeout) { + newtp->total_rto = req->num_timeout; newtp->undo_marker = treq->snt_isn; - newtp->retrans_stamp = div_u64(treq->snt_synack, - USEC_PER_SEC / TCP_TS_HZ); + if (newtp->tcp_usec_ts) { + newtp->retrans_stamp = treq->snt_synack; + newtp->total_rto_time = (u32)(tcp_clock_us() - + newtp->retrans_stamp) / USEC_PER_MSEC; + } else { + newtp->retrans_stamp = div_u64(treq->snt_synack, + USEC_PER_SEC / TCP_TS_HZ); + newtp->total_rto_time = tcp_clock_ms() - + newtp->retrans_stamp; + } + newtp->total_rto_recoveries = 1; } newtp->tsoffset = treq->ts_off; #ifdef CONFIG_TCP_MD5SIG newtp->md5sig_info = NULL; /*XXX*/ - if (newtp->af_specific->md5_lookup(sk, newsk)) - newtp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED; #endif +#ifdef CONFIG_TCP_AO + newtp->ao_info = NULL; + + if (tcp_rsk_used_ao(req)) { + struct tcp_ao_key *ao_key; + + ao_key = treq->af_specific->ao_lookup(sk, req, tcp_rsk(req)->ao_keyid, -1); + if (ao_key) + newtp->tcp_header_len += tcp_ao_len_aligned(ao_key); + } + #endif if (skb->len >= TCP_MSS_DEFAULT + newtp->tcp_header_len) newicsk->icsk_ack.last_seg_size = skb->len - newtp->tcp_header_len; newtp->rx_opt.mss_clamp = req->mss; - tcp_ecn_openreq_child(newtp, req); + tcp_ecn_openreq_child(newsk, req, skb); newtp->fastopen_req = NULL; RCU_INIT_POINTER(newtp->fastopen_rsk, NULL); + newtp->bpf_chg_cc_inprogress = 0; tcp_bpf_clone(sk, newsk); __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS); + xa_init_flags(&newsk->sk_user_frags, XA_FLAGS_ALLOC1); + return newsk; } EXPORT_SYMBOL(tcp_create_openreq_child); @@ -558,32 +677,44 @@ EXPORT_SYMBOL(tcp_create_openreq_child); * validation and inside tcp_v4_reqsk_send_ack(). Can we do better? * * We don't need to initialize tmp_opt.sack_ok as we don't use the results + * + * Note: If @fastopen is true, this can be called from process context. + * Otherwise, this is from BH context. */ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, struct request_sock *req, - bool fastopen, bool *req_stolen) + bool fastopen, bool *req_stolen, + enum skb_drop_reason *drop_reason) { struct tcp_options_received tmp_opt; struct sock *child; const struct tcphdr *th = tcp_hdr(skb); __be32 flg = tcp_flag_word(th) & (TCP_FLAG_RST|TCP_FLAG_SYN|TCP_FLAG_ACK); + bool tsecr_reject = false; bool paws_reject = false; bool own_req; tmp_opt.saw_tstamp = 0; + tmp_opt.accecn = 0; if (th->doff > (sizeof(struct tcphdr)>>2)) { tcp_parse_options(sock_net(sk), skb, &tmp_opt, 0, NULL); if (tmp_opt.saw_tstamp) { tmp_opt.ts_recent = req->ts_recent; - if (tmp_opt.rcv_tsecr) + if (tmp_opt.rcv_tsecr) { + if (inet_rsk(req)->tstamp_ok && !fastopen) + tsecr_reject = !between(tmp_opt.rcv_tsecr, + tcp_rsk(req)->snt_tsval_first, + READ_ONCE(tcp_rsk(req)->snt_tsval_last)); tmp_opt.rcv_tsecr -= tcp_rsk(req)->ts_off; + } /* We do not store true stamp, but it is not required, * it can be estimated (approximately) * from another data. */ - tmp_opt.ts_recent_stamp = ktime_get_seconds() - ((TCP_TIMEOUT_INIT/HZ)<<req->num_timeout); + tmp_opt.ts_recent_stamp = ktime_get_seconds() - + tcp_reqsk_timeout(req) / HZ; paws_reject = tcp_paws_reject(&tmp_opt, th->rst); } } @@ -619,11 +750,10 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, LINUX_MIB_TCPACKSKIPPEDSYNRECV, &tcp_rsk(req)->last_oow_ack_time) && - !inet_rtx_syn_ack(sk, req)) { + !tcp_rtx_synack(sk, req)) { unsigned long expires = jiffies; - expires += min(TCP_TIMEOUT_INIT << req->num_timeout, - TCP_RTO_MAX); + expires += tcp_reqsk_timeout(req); if (!fastopen) mod_timer_pending(&req->rsk_timer, expires); else @@ -694,31 +824,34 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, tcp_rsk(req)->snt_isn + 1)) return sk; - /* Also, it would be not so bad idea to check rcv_tsecr, which - * is essentially ACK extension and too early or too late values - * should cause reset in unsynchronized states. - */ - /* RFC793: "first check sequence number". */ - if (paws_reject || !tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, - tcp_rsk(req)->rcv_nxt, tcp_rsk(req)->rcv_nxt + req->rsk_rcv_wnd)) { + if (paws_reject || tsecr_reject || + !tcp_in_window(TCP_SKB_CB(skb)->seq, + TCP_SKB_CB(skb)->end_seq, + tcp_rsk(req)->rcv_nxt, + tcp_rsk(req)->rcv_nxt + + tcp_synack_window(req))) { /* Out of window: send ACK and drop. */ if (!(flg & TCP_FLAG_RST) && !tcp_oow_rate_limited(sock_net(sk), skb, LINUX_MIB_TCPACKSKIPPEDSYNRECV, &tcp_rsk(req)->last_oow_ack_time)) req->rsk_ops->send_ack(sk, skb, req); - if (paws_reject) - __NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSESTABREJECTED); + if (paws_reject) { + SKB_DR_SET(*drop_reason, TCP_RFC7323_PAWS); + NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSESTABREJECTED); + } else if (tsecr_reject) { + SKB_DR_SET(*drop_reason, TCP_RFC7323_TSECR); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TSECRREJECTED); + } else { + SKB_DR_SET(*drop_reason, TCP_OVERWINDOW); + } return NULL; } /* In sequence, PAWS is OK. */ - if (tmp_opt.saw_tstamp && !after(TCP_SKB_CB(skb)->seq, tcp_rsk(req)->rcv_nxt)) - req->ts_recent = tmp_opt.rcv_tsval; - if (TCP_SKB_CB(skb)->seq == tcp_rsk(req)->rcv_isn) { /* Truncate SYN, it is out of window starting at tcp_rsk(req)->rcv_isn + 1. */ @@ -729,7 +862,7 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, * "fourth, check the SYN bit" */ if (flg & (TCP_FLAG_RST|TCP_FLAG_SYN)) { - __TCP_INC_STATS(sock_net(sk), TCP_MIB_ATTEMPTFAILS); + TCP_INC_STATS(sock_net(sk), TCP_MIB_ATTEMPTFAILS); goto embryonic_reset; } @@ -742,6 +875,18 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, if (!(flg & TCP_FLAG_ACK)) return NULL; + if (tcp_rsk(req)->accecn_ok && tmp_opt.accecn && + tcp_rsk(req)->saw_accecn_opt < TCP_ACCECN_OPT_COUNTER_SEEN) { + u8 saw_opt = tcp_accecn_option_init(skb, tmp_opt.accecn); + + tcp_rsk(req)->saw_accecn_opt = saw_opt; + if (tcp_rsk(req)->saw_accecn_opt == TCP_ACCECN_OPT_FAIL_SEEN) { + u8 fail_mode = TCP_ACCECN_OPT_FAIL_RECV; + + tcp_rsk(req)->accecn_fail_mode |= fail_mode; + } + } + /* For Fast Open no more processing is needed (sk is the * child socket). */ @@ -749,7 +894,7 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, return sk; /* While TCP_DEFER_ACCEPT is active, drop bare ACK. */ - if (req->num_timeout < inet_csk(sk)->icsk_accept_queue.rskq_defer_accept && + if (req->num_timeout < READ_ONCE(inet_csk(sk)->icsk_accept_queue.rskq_defer_accept) && TCP_SKB_CB(skb)->end_seq == tcp_rsk(req)->rcv_isn + 1) { inet_rsk(req)->acked = 1; __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDEFERACCEPTDROP); @@ -767,6 +912,10 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, if (!child) goto listen_overflow; + if (own_req && tmp_opt.saw_tstamp && + !after(TCP_SKB_CB(skb)->seq, tcp_rsk(req)->rcv_nxt)) + tcp_sk(child)->rx_opt.ts_recent = tmp_opt.rcv_tsval; + if (own_req && rsk_drop_req(req)) { reqsk_queue_removed(&inet_csk(req->rsk_listener)->icsk_accept_queue, req); inet_csk_reqsk_queue_drop_and_put(req->rsk_listener, req); @@ -779,10 +928,11 @@ struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, return inet_csk_complete_hashdance(sk, child, req, own_req); listen_overflow: + SKB_DR_SET(*drop_reason, TCP_LISTEN_OVERFLOW); if (sk != req->rsk_listener) __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMIGRATEREQFAILURE); - if (!sock_net(sk)->ipv4.sysctl_tcp_abort_on_overflow) { + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_abort_on_overflow)) { inet_rsk(req)->acked = 1; return NULL; } @@ -794,7 +944,7 @@ embryonic_reset: * avoid becoming vulnerable to outside attack aiming at * resetting legit local connections. */ - req->rsk_ops->send_reset(sk, skb); + req->rsk_ops->send_reset(sk, skb, SK_RST_REASON_INVALID_SYN); } else if (fastopen) { /* received a valid RST pkt */ reqsk_fastopen_remove(sk, req, true); tcp_reset(sk, skb); @@ -808,7 +958,7 @@ embryonic_reset: } return NULL; } -EXPORT_SYMBOL(tcp_check_req); +EXPORT_IPV6_MOD(tcp_check_req); /* * Queue segment on the new socket if the new socket is active, @@ -822,11 +972,11 @@ EXPORT_SYMBOL(tcp_check_req); * be created. */ -int tcp_child_process(struct sock *parent, struct sock *child, - struct sk_buff *skb) +enum skb_drop_reason tcp_child_process(struct sock *parent, struct sock *child, + struct sk_buff *skb) __releases(&((child)->sk_lock.slock)) { - int ret = 0; + enum skb_drop_reason reason = SKB_NOT_DROPPED_YET; int state = child->sk_state; /* record sk_napi_id and sk_rx_queue_mapping of child. */ @@ -834,7 +984,7 @@ int tcp_child_process(struct sock *parent, struct sock *child, tcp_segs_in(tcp_sk(child), skb); if (!sock_owned_by_user(child)) { - ret = tcp_rcv_state_process(child, skb); + reason = tcp_rcv_state_process(child, skb); /* Wakeup parent, send SIGIO */ if (state == TCP_SYN_RECV && child->sk_state != state) parent->sk_data_ready(parent); @@ -848,6 +998,6 @@ int tcp_child_process(struct sock *parent, struct sock *child, bh_unlock_sock(child); sock_put(child); - return ret; + return reason; } -EXPORT_SYMBOL(tcp_child_process); +EXPORT_IPV6_MOD(tcp_child_process); diff --git a/net/ipv4/tcp_nv.c b/net/ipv4/tcp_nv.c index ab552356bdba..a60662f4bdf9 100644 --- a/net/ipv4/tcp_nv.c +++ b/net/ipv4/tcp_nv.c @@ -197,10 +197,10 @@ static void tcpnv_cong_avoid(struct sock *sk, u32 ack, u32 acked) } if (ca->cwnd_growth_factor < 0) { - cnt = tp->snd_cwnd << -ca->cwnd_growth_factor; + cnt = tcp_snd_cwnd(tp) << -ca->cwnd_growth_factor; tcp_cong_avoid_ai(tp, cnt, acked); } else { - cnt = max(4U, tp->snd_cwnd >> ca->cwnd_growth_factor); + cnt = max(4U, tcp_snd_cwnd(tp) >> ca->cwnd_growth_factor); tcp_cong_avoid_ai(tp, cnt, acked); } } @@ -209,7 +209,7 @@ static u32 tcpnv_recalc_ssthresh(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); - return max((tp->snd_cwnd * nv_loss_dec_factor) >> 10, 2U); + return max((tcp_snd_cwnd(tp) * nv_loss_dec_factor) >> 10, 2U); } static void tcpnv_state(struct sock *sk, u8 new_state) @@ -257,7 +257,7 @@ static void tcpnv_acked(struct sock *sk, const struct ack_sample *sample) return; /* Stop cwnd growth if we were in catch up mode */ - if (ca->nv_catchup && tp->snd_cwnd >= nv_min_cwnd) { + if (ca->nv_catchup && tcp_snd_cwnd(tp) >= nv_min_cwnd) { ca->nv_catchup = 0; ca->nv_allow_cwnd_growth = 0; } @@ -371,7 +371,7 @@ static void tcpnv_acked(struct sock *sk, const struct ack_sample *sample) * if cwnd < max_win, grow cwnd * else leave the same */ - if (tp->snd_cwnd > max_win) { + if (tcp_snd_cwnd(tp) > max_win) { /* there is congestion, check that it is ok * to make a CA decision * 1. We should have at least nv_dec_eval_min_calls @@ -398,20 +398,20 @@ static void tcpnv_acked(struct sock *sk, const struct ack_sample *sample) ca->nv_allow_cwnd_growth = 0; tp->snd_ssthresh = (nv_ssthresh_factor * max_win) >> 3; - if (tp->snd_cwnd - max_win > 2) { + if (tcp_snd_cwnd(tp) - max_win > 2) { /* gap > 2, we do exponential cwnd decrease */ int dec; - dec = max(2U, ((tp->snd_cwnd - max_win) * + dec = max(2U, ((tcp_snd_cwnd(tp) - max_win) * nv_cong_dec_mult) >> 7); - tp->snd_cwnd -= dec; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - dec); } else if (nv_cong_dec_mult > 0) { - tp->snd_cwnd = max_win; + tcp_snd_cwnd_set(tp, max_win); } if (ca->cwnd_growth_factor > 0) ca->cwnd_growth_factor = 0; ca->nv_no_cong_cnt = 0; - } else if (tp->snd_cwnd <= max_win - nv_pad_buffer) { + } else if (tcp_snd_cwnd(tp) <= max_win - nv_pad_buffer) { /* There is no congestion, grow cwnd if allowed*/ if (ca->nv_eval_call_cnt < nv_inc_eval_min_calls) return; @@ -444,8 +444,8 @@ static void tcpnv_acked(struct sock *sk, const struct ack_sample *sample) * (it wasn't before, if it is now is because nv * decreased it). */ - if (tp->snd_cwnd < nv_min_cwnd) - tp->snd_cwnd = nv_min_cwnd; + if (tcp_snd_cwnd(tp) < nv_min_cwnd) + tcp_snd_cwnd_set(tp, nv_min_cwnd); } } diff --git a/net/ipv4/tcp_offload.c b/net/ipv4/tcp_offload.c index 30abde86db45..fdda18b1abda 100644 --- a/net/ipv4/tcp_offload.c +++ b/net/ipv4/tcp_offload.c @@ -9,15 +9,19 @@ #include <linux/indirect_call_wrapper.h> #include <linux/skbuff.h> #include <net/gro.h> +#include <net/gso.h> #include <net/tcp.h> #include <net/protocol.h> -static void tcp_gso_tstamp(struct sk_buff *skb, unsigned int ts_seq, +static void tcp_gso_tstamp(struct sk_buff *skb, struct sk_buff *gso_skb, unsigned int seq, unsigned int mss) { + u32 flags = skb_shinfo(gso_skb)->tx_flags & SKBTX_ANY_TSTAMP; + u32 ts_seq = skb_shinfo(gso_skb)->tskey; + while (skb) { if (before(ts_seq, seq + mss)) { - skb_shinfo(skb)->tx_flags |= SKBTX_SW_TSTAMP; + skb_shinfo(skb)->tx_flags |= flags; skb_shinfo(skb)->tskey = ts_seq; return; } @@ -27,6 +31,70 @@ static void tcp_gso_tstamp(struct sk_buff *skb, unsigned int ts_seq, } } +static void __tcpv4_gso_segment_csum(struct sk_buff *seg, + __be32 *oldip, __be32 newip, + __be16 *oldport, __be16 newport) +{ + struct tcphdr *th; + struct iphdr *iph; + + if (*oldip == newip && *oldport == newport) + return; + + th = tcp_hdr(seg); + iph = ip_hdr(seg); + + inet_proto_csum_replace4(&th->check, seg, *oldip, newip, true); + inet_proto_csum_replace2(&th->check, seg, *oldport, newport, false); + *oldport = newport; + + csum_replace4(&iph->check, *oldip, newip); + *oldip = newip; +} + +static struct sk_buff *__tcpv4_gso_segment_list_csum(struct sk_buff *segs) +{ + const struct tcphdr *th; + const struct iphdr *iph; + struct sk_buff *seg; + struct tcphdr *th2; + struct iphdr *iph2; + + seg = segs; + th = tcp_hdr(seg); + iph = ip_hdr(seg); + th2 = tcp_hdr(seg->next); + iph2 = ip_hdr(seg->next); + + if (!(*(const u32 *)&th->source ^ *(const u32 *)&th2->source) && + iph->daddr == iph2->daddr && iph->saddr == iph2->saddr) + return segs; + + while ((seg = seg->next)) { + th2 = tcp_hdr(seg); + iph2 = ip_hdr(seg); + + __tcpv4_gso_segment_csum(seg, + &iph2->saddr, iph->saddr, + &th2->source, th->source); + __tcpv4_gso_segment_csum(seg, + &iph2->daddr, iph->daddr, + &th2->dest, th->dest); + } + + return segs; +} + +static struct sk_buff *__tcp4_gso_segment_list(struct sk_buff *skb, + netdev_features_t features) +{ + skb = skb_segment_list(skb, features, skb_mac_header_len(skb)); + if (IS_ERR(skb)) + return skb; + + return __tcpv4_gso_segment_list_csum(skb); +} + static struct sk_buff *tcp4_gso_segment(struct sk_buff *skb, netdev_features_t features) { @@ -36,6 +104,15 @@ static struct sk_buff *tcp4_gso_segment(struct sk_buff *skb, if (!pskb_may_pull(skb, sizeof(struct tcphdr))) return ERR_PTR(-EINVAL); + if (skb_shinfo(skb)->gso_type & SKB_GSO_FRAGLIST) { + struct tcphdr *th = tcp_hdr(skb); + + if (skb_pagelen(skb) - th->doff * 4 == skb_shinfo(skb)->gso_size) + return __tcp4_gso_segment_list(skb, features); + + skb->ip_summed = CHECKSUM_NONE; + } + if (unlikely(skb->ip_summed != CHECKSUM_PARTIAL)) { const struct iphdr *iph = ip_hdr(skb); struct tcphdr *th = tcp_hdr(skb); @@ -60,22 +137,26 @@ struct sk_buff *tcp_gso_segment(struct sk_buff *skb, struct tcphdr *th; unsigned int thlen; unsigned int seq; - __be32 delta; unsigned int oldlen; unsigned int mss; struct sk_buff *gso_skb = skb; __sum16 newcheck; bool ooo_okay, copy_destructor; + bool ecn_cwr_mask; + __wsum delta; th = tcp_hdr(skb); thlen = th->doff * 4; if (thlen < sizeof(*th)) goto out; + if (unlikely(skb_checksum_start(skb) != skb_transport_header(skb))) + goto out; + if (!pskb_may_pull(skb, thlen)) goto out; - oldlen = (u16)~skb->len; + oldlen = ~skb->len; __skb_pull(skb, thlen); mss = skb_shinfo(skb)->gso_size; @@ -110,17 +191,18 @@ struct sk_buff *tcp_gso_segment(struct sk_buff *skb, if (skb_is_gso(segs)) mss *= skb_shinfo(segs)->gso_segs; - delta = htonl(oldlen + (thlen + mss)); + delta = (__force __wsum)htonl(oldlen + thlen + mss); skb = segs; th = tcp_hdr(skb); seq = ntohl(th->seq); - if (unlikely(skb_shinfo(gso_skb)->tx_flags & SKBTX_SW_TSTAMP)) - tcp_gso_tstamp(segs, skb_shinfo(gso_skb)->tskey, seq, mss); + if (unlikely(skb_shinfo(gso_skb)->tx_flags & SKBTX_ANY_TSTAMP)) + tcp_gso_tstamp(segs, gso_skb, seq, mss); - newcheck = ~csum_fold((__force __wsum)((__force u32)th->check + - (__force u32)delta)); + newcheck = ~csum_fold(csum_add(csum_unfold(th->check), delta)); + + ecn_cwr_mask = !!(skb_shinfo(gso_skb)->gso_type & SKB_GSO_TCP_ACCECN); while (skb->next) { th->fin = th->psh = 0; @@ -141,7 +223,8 @@ struct sk_buff *tcp_gso_segment(struct sk_buff *skb, th = tcp_hdr(skb); th->seq = htonl(seq); - th->cwr = 0; + + th->cwr &= ecn_cwr_mask; } /* Following permits TCP Small Queues to work well with GSO : @@ -165,11 +248,11 @@ struct sk_buff *tcp_gso_segment(struct sk_buff *skb, WARN_ON_ONCE(refcount_sub_and_test(-delta, &skb->sk->sk_wmem_alloc)); } - delta = htonl(oldlen + (skb_tail_pointer(skb) - - skb_transport_header(skb)) + - skb->data_len); - th->check = ~csum_fold((__force __wsum)((__force u32)th->check + - (__force u32)delta)); + delta = (__force __wsum)htonl(oldlen + + (skb_tail_pointer(skb) - + skb_transport_header(skb)) + + skb->data_len); + th->check = ~csum_fold(csum_add(csum_unfold(th->check), delta)); if (skb->ip_summed == CHECKSUM_PARTIAL) gso_reset_checksum(skb, ~th->check); else @@ -178,91 +261,84 @@ out: return segs; } -struct sk_buff *tcp_gro_receive(struct list_head *head, struct sk_buff *skb) +struct sk_buff *tcp_gro_lookup(struct list_head *head, struct tcphdr *th) { - struct sk_buff *pp = NULL; - struct sk_buff *p; - struct tcphdr *th; struct tcphdr *th2; - unsigned int len; - unsigned int thlen; - __be32 flags; - unsigned int mss = 1; - unsigned int hlen; - unsigned int off; - int flush = 1; - int i; - - off = skb_gro_offset(skb); - hlen = off + sizeof(*th); - th = skb_gro_header_fast(skb, off); - if (skb_gro_header_hard(skb, hlen)) { - th = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!th)) - goto out; - } - - thlen = th->doff * 4; - if (thlen < sizeof(*th)) - goto out; - - hlen = off + thlen; - if (skb_gro_header_hard(skb, hlen)) { - th = skb_gro_header_slow(skb, hlen, off); - if (unlikely(!th)) - goto out; - } - - skb_gro_pull(skb, thlen); - - len = skb_gro_len(skb); - flags = tcp_flag_word(th); + struct sk_buff *p; list_for_each_entry(p, head, list) { if (!NAPI_GRO_CB(p)->same_flow) continue; th2 = tcp_hdr(p); - if (*(u32 *)&th->source ^ *(u32 *)&th2->source) { NAPI_GRO_CB(p)->same_flow = 0; continue; } - goto found; + return p; } - p = NULL; - goto out_check_final; -found: - /* Include the IP ID check below from the inner most IP hdr */ - flush = NAPI_GRO_CB(p)->flush; - flush |= (__force int)(flags & TCP_FLAG_CWR); + return NULL; +} + +struct sk_buff *tcp_gro_receive(struct list_head *head, struct sk_buff *skb, + struct tcphdr *th) +{ + unsigned int thlen = th->doff * 4; + struct sk_buff *pp = NULL; + struct sk_buff *p; + struct tcphdr *th2; + unsigned int len; + __be32 flags; + unsigned int mss = 1; + int flush = 1; + int i; + + len = skb_gro_len(skb); + flags = tcp_flag_word(th); + + p = tcp_gro_lookup(head, th); + if (!p) + goto out_check_final; + + th2 = tcp_hdr(p); + flush = (__force int)(flags & TCP_FLAG_CWR); flush |= (__force int)((flags ^ tcp_flag_word(th2)) & - ~(TCP_FLAG_CWR | TCP_FLAG_FIN | TCP_FLAG_PSH)); + ~(TCP_FLAG_FIN | TCP_FLAG_PSH)); flush |= (__force int)(th->ack_seq ^ th2->ack_seq); for (i = sizeof(*th); i < thlen; i += 4) flush |= *(u32 *)((u8 *)th + i) ^ *(u32 *)((u8 *)th2 + i); - /* When we receive our second frame we can made a decision on if we - * continue this flow as an atomic flow with a fixed ID or if we use - * an incrementing ID. - */ - if (NAPI_GRO_CB(p)->flush_id != 1 || - NAPI_GRO_CB(p)->count != 1 || - !NAPI_GRO_CB(p)->is_atomic) - flush |= NAPI_GRO_CB(p)->flush_id; - else - NAPI_GRO_CB(p)->is_atomic = false; + flush |= gro_receive_network_flush(th, th2, p); mss = skb_shinfo(p)->gso_size; - flush |= (len - 1) >= mss; + /* If skb is a GRO packet, make sure its gso_size matches prior packet mss. + * If it is a single frame, do not aggregate it if its length + * is bigger than our mss. + */ + if (unlikely(skb_is_gso(skb))) + flush |= (mss != skb_shinfo(skb)->gso_size); + else + flush |= (len - 1) >= mss; + flush |= (ntohl(th2->seq) + skb_gro_len(p)) ^ ntohl(th->seq); -#ifdef CONFIG_TLS_DEVICE - flush |= p->decrypted ^ skb->decrypted; -#endif + flush |= skb_cmp_decrypted(p, skb); + + if (unlikely(NAPI_GRO_CB(p)->is_flist)) { + flush |= (__force int)(flags ^ tcp_flag_word(th2)); + flush |= skb->ip_summed != p->ip_summed; + flush |= skb->csum_level != p->csum_level; + flush |= NAPI_GRO_CB(p)->count >= 64; + skb_set_network_header(skb, skb_gro_receive_network_offset(skb)); + + if (flush || skb_gro_receive_list(p, skb)) + mss = 1; + + goto out_check_final; + } if (flush || skb_gro_receive(p, skb)) { mss = 1; @@ -272,7 +348,12 @@ found: tcp_flag_word(th2) |= flags & (TCP_FLAG_FIN | TCP_FLAG_PSH); out_check_final: - flush = len < mss; + /* Force a flush if last segment is smaller than mss. */ + if (unlikely(skb_is_gso(skb))) + flush = len != NAPI_GRO_CB(skb)->count * skb_shinfo(skb)->gso_size; + else + flush = len < mss; + flush |= (__force int)(flags & (TCP_FLAG_URG | TCP_FLAG_PSH | TCP_FLAG_RST | TCP_FLAG_SYN | TCP_FLAG_FIN)); @@ -280,70 +361,118 @@ out_check_final: if (p && (!NAPI_GRO_CB(skb)->same_flow || flush)) pp = p; -out: NAPI_GRO_CB(skb)->flush |= (flush != 0); return pp; } -int tcp_gro_complete(struct sk_buff *skb) +void tcp_gro_complete(struct sk_buff *skb) { struct tcphdr *th = tcp_hdr(skb); + struct skb_shared_info *shinfo; + + if (skb->encapsulation) + skb->inner_transport_header = skb->transport_header; skb->csum_start = (unsigned char *)th - skb->head; skb->csum_offset = offsetof(struct tcphdr, check); skb->ip_summed = CHECKSUM_PARTIAL; - skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; + shinfo = skb_shinfo(skb); + shinfo->gso_segs = NAPI_GRO_CB(skb)->count; if (th->cwr) - skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; + shinfo->gso_type |= SKB_GSO_TCP_ACCECN; +} +EXPORT_SYMBOL(tcp_gro_complete); - if (skb->encapsulation) - skb->inner_transport_header = skb->transport_header; +static void tcp4_check_fraglist_gro(struct list_head *head, struct sk_buff *skb, + struct tcphdr *th) +{ + const struct iphdr *iph; + struct sk_buff *p; + struct sock *sk; + struct net *net; + int iif, sdif; - return 0; + if (likely(!(skb->dev->features & NETIF_F_GRO_FRAGLIST))) + return; + + p = tcp_gro_lookup(head, th); + if (p) { + NAPI_GRO_CB(skb)->is_flist = NAPI_GRO_CB(p)->is_flist; + return; + } + + inet_get_iif_sdif(skb, &iif, &sdif); + iph = skb_gro_network_header(skb); + net = dev_net_rcu(skb->dev); + sk = __inet_lookup_established(net, iph->saddr, th->source, + iph->daddr, ntohs(th->dest), + iif, sdif); + NAPI_GRO_CB(skb)->is_flist = !sk; + if (sk) + sock_gen_put(sk); } -EXPORT_SYMBOL(tcp_gro_complete); INDIRECT_CALLABLE_SCOPE struct sk_buff *tcp4_gro_receive(struct list_head *head, struct sk_buff *skb) { + struct tcphdr *th; + /* Don't bother verifying checksum if we're going to flush anyway. */ if (!NAPI_GRO_CB(skb)->flush && skb_gro_checksum_validate(skb, IPPROTO_TCP, - inet_gro_compute_pseudo)) { - NAPI_GRO_CB(skb)->flush = 1; - return NULL; - } + inet_gro_compute_pseudo)) + goto flush; + + th = tcp_gro_pull_header(skb); + if (!th) + goto flush; + + tcp4_check_fraglist_gro(head, skb, th); + + return tcp_gro_receive(head, skb, th); - return tcp_gro_receive(head, skb); +flush: + NAPI_GRO_CB(skb)->flush = 1; + return NULL; } INDIRECT_CALLABLE_SCOPE int tcp4_gro_complete(struct sk_buff *skb, int thoff) { - const struct iphdr *iph = ip_hdr(skb); + const u16 offset = NAPI_GRO_CB(skb)->network_offsets[skb->encapsulation]; + const struct iphdr *iph = (struct iphdr *)(skb->data + offset); struct tcphdr *th = tcp_hdr(skb); + if (unlikely(NAPI_GRO_CB(skb)->is_flist)) { + skb_shinfo(skb)->gso_type |= SKB_GSO_FRAGLIST | SKB_GSO_TCPV4; + skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; + + __skb_incr_checksum_unnecessary(skb); + + return 0; + } + th->check = ~tcp_v4_check(skb->len - thoff, iph->saddr, iph->daddr, 0); - skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV4; - if (NAPI_GRO_CB(skb)->is_atomic) - skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_FIXEDID; + BUILD_BUG_ON(SKB_GSO_TCP_FIXEDID << 1 != SKB_GSO_TCP_FIXEDID_INNER); + skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV4 | + (NAPI_GRO_CB(skb)->ip_fixedid * SKB_GSO_TCP_FIXEDID); - return tcp_gro_complete(skb); + tcp_gro_complete(skb); + return 0; } -static const struct net_offload tcpv4_offload = { - .callbacks = { - .gso_segment = tcp4_gso_segment, - .gro_receive = tcp4_gro_receive, - .gro_complete = tcp4_gro_complete, - }, -}; - int __init tcpv4_offload_init(void) { - return inet_add_offload(&tcpv4_offload, IPPROTO_TCP); + net_hotdata.tcpv4_offload = (struct net_offload) { + .callbacks = { + .gso_segment = tcp4_gso_segment, + .gro_receive = tcp4_gro_receive, + .gro_complete = tcp4_gro_complete, + }, + }; + return inet_add_offload(&net_hotdata.tcpv4_offload, IPPROTO_TCP); } diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 5079832af5c1..479afb714bdf 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -38,12 +38,17 @@ #define pr_fmt(fmt) "TCP: " fmt #include <net/tcp.h> +#include <net/tcp_ecn.h> #include <net/mptcp.h> +#include <net/smc.h> +#include <net/proto_memory.h> +#include <net/psp.h> #include <linux/compiler.h> #include <linux/gfp.h> #include <linux/module.h> #include <linux/static_key.h> +#include <linux/skbuff_ref.h> #include <trace/events/tcp.h> @@ -82,6 +87,7 @@ static void tcp_event_new_data_sent(struct sock *sk, struct sk_buff *skb) NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPORIGDATASENT, tcp_skb_pcount(skb)); + tcp_check_space(sk); } /* SND.NXT, if window was not shrunk or the amount of shrunk was less than one @@ -142,7 +148,7 @@ void tcp_cwnd_restart(struct sock *sk, s32 delta) { struct tcp_sock *tp = tcp_sk(sk); u32 restart_cwnd = tcp_init_cwnd(tp, __sk_dst_get(sk)); - u32 cwnd = tp->snd_cwnd; + u32 cwnd = tcp_snd_cwnd(tp); tcp_ca_event(sk, CA_EVENT_CWND_RESTART); @@ -151,7 +157,7 @@ void tcp_cwnd_restart(struct sock *sk, s32 delta) while ((delta -= inet_csk(sk)->icsk_rto) > 0 && cwnd > restart_cwnd) cwnd >>= 1; - tp->snd_cwnd = max(cwnd, restart_cwnd); + tcp_snd_cwnd_set(tp, max(cwnd, restart_cwnd)); tp->snd_cwnd_stamp = tcp_jiffies32; tp->snd_cwnd_used = 0; } @@ -166,21 +172,17 @@ static void tcp_event_data_sent(struct tcp_sock *tp, if (tcp_packets_in_flight(tp) == 0) tcp_ca_event(sk, CA_EVENT_TX_START); - /* If this is the first data packet sent in response to the - * previous received data, - * and it is a reply for ato after last received packet, - * increase pingpong count. + tp->lsndtime = now; + + /* If it is a reply for ato after last received + * packet, increase pingpong count. */ - if (before(tp->lsndtime, icsk->icsk_ack.lrcvtime) && - (u32)(now - icsk->icsk_ack.lrcvtime) < icsk->icsk_ack.ato) + if ((u32)(now - icsk->icsk_ack.lrcvtime) < icsk->icsk_ack.ato) inet_csk_inc_pingpong_cnt(sk); - - tp->lsndtime = now; } /* Account for an ACK we sent. */ -static inline void tcp_event_ack_sent(struct sock *sk, unsigned int pkts, - u32 rcv_nxt) +static inline void tcp_event_ack_sent(struct sock *sk, u32 rcv_nxt) { struct tcp_sock *tp = tcp_sk(sk); @@ -194,7 +196,7 @@ static inline void tcp_event_ack_sent(struct sock *sk, unsigned int pkts, if (unlikely(rcv_nxt != tp->rcv_nxt)) return; /* Special ACK sent by DCTCP to reflect ECN */ - tcp_dec_quickack_mode(sk, pkts); + tcp_dec_quickack_mode(sk); inet_csk_clear_xmit_timer(sk, ICSK_TIME_DACK); } @@ -206,16 +208,17 @@ static inline void tcp_event_ack_sent(struct sock *sk, unsigned int pkts, * This MUST be enforced by all callers. */ void tcp_select_initial_window(const struct sock *sk, int __space, __u32 mss, - __u32 *rcv_wnd, __u32 *window_clamp, + __u32 *rcv_wnd, __u32 *__window_clamp, int wscale_ok, __u8 *rcv_wscale, __u32 init_rcv_wnd) { unsigned int space = (__space < 0 ? 0 : __space); + u32 window_clamp = READ_ONCE(*__window_clamp); /* If no clamp set the clamp to the max possible scaled window */ - if (*window_clamp == 0) - (*window_clamp) = (U16_MAX << TCP_MAX_WSCALE); - space = min(*window_clamp, space); + if (window_clamp == 0) + window_clamp = (U16_MAX << TCP_MAX_WSCALE); + space = min(window_clamp, space); /* Quantize space offering to a multiple of mss if possible. */ if (space > mss) @@ -229,10 +232,10 @@ void tcp_select_initial_window(const struct sock *sk, int __space, __u32 mss, * which we interpret as a sign the remote TCP is not * misinterpreting the window field as a signed quantity. */ - if (sock_net(sk)->ipv4.sysctl_tcp_workaround_signed_windows) + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_workaround_signed_windows)) (*rcv_wnd) = min(space, MAX_TCP_WINDOW); else - (*rcv_wnd) = min_t(u32, space, U16_MAX); + (*rcv_wnd) = space; if (init_rcv_wnd) *rcv_wnd = min(*rcv_wnd, init_rcv_wnd * mss); @@ -240,16 +243,17 @@ void tcp_select_initial_window(const struct sock *sk, int __space, __u32 mss, *rcv_wscale = 0; if (wscale_ok) { /* Set window scaling on max possible window */ - space = max_t(u32, space, sock_net(sk)->ipv4.sysctl_tcp_rmem[2]); - space = max_t(u32, space, sysctl_rmem_max); - space = min_t(u32, space, *window_clamp); + space = max_t(u32, space, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2])); + space = max_t(u32, space, READ_ONCE(sysctl_rmem_max)); + space = min_t(u32, space, window_clamp); *rcv_wscale = clamp_t(int, ilog2(space) - 15, 0, TCP_MAX_WSCALE); } /* Set the clamp no higher than max representable value */ - (*window_clamp) = min_t(__u32, U16_MAX << (*rcv_wscale), *window_clamp); + WRITE_ONCE(*__window_clamp, + min_t(__u32, U16_MAX << (*rcv_wscale), window_clamp)); } -EXPORT_SYMBOL(tcp_select_initial_window); +EXPORT_IPV6_MOD(tcp_select_initial_window); /* Chose a new window to advertise, update state in tcp_sock for the * socket, and return result with RFC1323 scaling applied. The return @@ -259,11 +263,22 @@ EXPORT_SYMBOL(tcp_select_initial_window); static u16 tcp_select_window(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); u32 old_win = tp->rcv_wnd; - u32 cur_win = tcp_receive_window(tp); - u32 new_win = __tcp_select_window(sk); + u32 cur_win, new_win; + + /* Make the window 0 if we failed to queue the data because we + * are out of memory. + */ + if (unlikely(inet_csk(sk)->icsk_ack.pending & ICSK_ACK_NOMEM)) { + tp->pred_flags = 0; + tp->rcv_wnd = 0; + tp->rcv_wup = tp->rcv_nxt; + return 0; + } - /* Never shrink the offered window */ + cur_win = tcp_receive_window(tp); + new_win = __tcp_select_window(sk); if (new_win < cur_win) { /* Danger Will Robinson! * Don't update rcv_wup/rcv_wnd here or else @@ -272,11 +287,14 @@ static u16 tcp_select_window(struct sock *sk) * * Relax Will Robinson. */ - if (new_win == 0) - NET_INC_STATS(sock_net(sk), - LINUX_MIB_TCPWANTZEROWINDOWADV); - new_win = ALIGN(cur_win, 1 << tp->rx_opt.rcv_wscale); + if (!READ_ONCE(net->ipv4.sysctl_tcp_shrink_window) || !tp->rx_opt.rcv_wscale) { + /* Never shrink the offered window */ + if (new_win == 0) + NET_INC_STATS(net, LINUX_MIB_TCPWANTZEROWINDOWADV); + new_win = ALIGN(cur_win, 1 << tp->rx_opt.rcv_wscale); + } } + tp->rcv_wnd = new_win; tp->rcv_wup = tp->rcv_nxt; @@ -284,7 +302,7 @@ static u16 tcp_select_window(struct sock *sk) * scaled window. */ if (!tp->rx_opt.rcv_wscale && - sock_net(sk)->ipv4.sysctl_tcp_workaround_signed_windows) + READ_ONCE(net->ipv4.sysctl_tcp_workaround_signed_windows)) new_win = min(new_win, MAX_TCP_WINDOW); else new_win = min(new_win, (65535U << tp->rx_opt.rcv_wscale)); @@ -296,69 +314,14 @@ static u16 tcp_select_window(struct sock *sk) if (new_win == 0) { tp->pred_flags = 0; if (old_win) - NET_INC_STATS(sock_net(sk), - LINUX_MIB_TCPTOZEROWINDOWADV); + NET_INC_STATS(net, LINUX_MIB_TCPTOZEROWINDOWADV); } else if (old_win == 0) { - NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPFROMZEROWINDOWADV); + NET_INC_STATS(net, LINUX_MIB_TCPFROMZEROWINDOWADV); } return new_win; } -/* Packet ECN state for a SYN-ACK */ -static void tcp_ecn_send_synack(struct sock *sk, struct sk_buff *skb) -{ - const struct tcp_sock *tp = tcp_sk(sk); - - TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_CWR; - if (!(tp->ecn_flags & TCP_ECN_OK)) - TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_ECE; - else if (tcp_ca_needs_ecn(sk) || - tcp_bpf_ca_needs_ecn(sk)) - INET_ECN_xmit(sk); -} - -/* Packet ECN state for a SYN. */ -static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) -{ - struct tcp_sock *tp = tcp_sk(sk); - bool bpf_needs_ecn = tcp_bpf_ca_needs_ecn(sk); - bool use_ecn = sock_net(sk)->ipv4.sysctl_tcp_ecn == 1 || - tcp_ca_needs_ecn(sk) || bpf_needs_ecn; - - if (!use_ecn) { - const struct dst_entry *dst = __sk_dst_get(sk); - - if (dst && dst_feature(dst, RTAX_FEATURE_ECN)) - use_ecn = true; - } - - tp->ecn_flags = 0; - - if (use_ecn) { - TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_ECE | TCPHDR_CWR; - tp->ecn_flags = TCP_ECN_OK; - if (tcp_ca_needs_ecn(sk) || bpf_needs_ecn) - INET_ECN_xmit(sk); - } -} - -static void tcp_ecn_clear_syn(struct sock *sk, struct sk_buff *skb) -{ - if (sock_net(sk)->ipv4.sysctl_tcp_ecn_fallback) - /* tp->ecn_flags are cleared at a later point in time when - * SYN ACK is ultimatively being received. - */ - TCP_SKB_CB(skb)->tcp_flags &= ~(TCPHDR_ECE | TCPHDR_CWR); -} - -static void -tcp_ecn_make_synack(const struct request_sock *req, struct tcphdr *th) -{ - if (inet_rsk(req)->ecn_ok) - th->ece = 1; -} - /* Set up ECN state for a packet on a ESTABLISHED socket that is about to * be sent. */ @@ -367,7 +330,15 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, { struct tcp_sock *tp = tcp_sk(sk); - if (tp->ecn_flags & TCP_ECN_OK) { + if (!tcp_ecn_mode_any(tp)) + return; + + if (tcp_ecn_mode_accecn(tp)) { + if (!tcp_accecn_ace_fail_recv(tp)) + INET_ECN_xmit(sk); + tcp_accecn_set_ace(tp, skb, th); + skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ACCECN; + } else { /* Not-retransmitted data segment: set ECT and inject CWR. */ if (skb->len != tcp_header_len && !before(TCP_SKB_CB(skb)->seq, tp->snd_nxt)) { @@ -389,13 +360,15 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, /* Constructs common control bits of non-data skb. If SYN/FIN is present, * auto increment end seqno. */ -static void tcp_init_nondata_skb(struct sk_buff *skb, u32 seq, u8 flags) +static void tcp_init_nondata_skb(struct sk_buff *skb, struct sock *sk, + u32 seq, u16 flags) { skb->ip_summed = CHECKSUM_PARTIAL; TCP_SKB_CB(skb)->tcp_flags = flags; tcp_skb_pcount_set(skb, 1); + psp_enqueue_set_decrypted(sk, skb); TCP_SKB_CB(skb)->seq = seq; if (flags & (TCPHDR_SYN | TCPHDR_FIN)) @@ -415,6 +388,8 @@ static inline bool tcp_urg_mode(const struct tcp_sock *tp) #define OPTION_FAST_OPEN_COOKIE BIT(8) #define OPTION_SMC BIT(9) #define OPTION_MPTCP BIT(10) +#define OPTION_AO BIT(11) +#define OPTION_ACCECN BIT(12) static void smc_options_write(__be32 *ptr, u16 *options) { @@ -436,6 +411,8 @@ struct tcp_out_options { u16 mss; /* 0 to disable */ u8 ws; /* window scale, 0 to disable */ u8 num_sack_blocks; /* number of SACK blocks to include */ + u8 num_accecn_fields:7, /* number of AccECN fields needed */ + use_synack_ecn_bytes:1; /* Use synack_ecn_bytes or not */ u8 hash_size; /* bytes in hash_location */ u8 bpf_opt_len; /* length of BPF hdr option */ __u8 *hash_location; /* temporary pointer, overloaded */ @@ -444,12 +421,13 @@ struct tcp_out_options { struct mptcp_out_options mptcp; }; -static void mptcp_options_write(__be32 *ptr, const struct tcp_sock *tp, +static void mptcp_options_write(struct tcphdr *th, __be32 *ptr, + struct tcp_sock *tp, struct tcp_out_options *opts) { #if IS_ENABLED(CONFIG_MPTCP) if (unlikely(OPTION_MPTCP & opts->options)) - mptcp_write_options(ptr, tp, &opts->mptcp); + mptcp_write_options(th, ptr, tp, &opts->mptcp); #endif } @@ -509,6 +487,7 @@ static void bpf_skops_hdr_opt_len(struct sock *sk, struct sk_buff *skb, sock_owned_by_me(sk); sock_ops.is_fullsock = 1; + sock_ops.is_locked_tcp_sock = 1; sock_ops.sk = sk; } @@ -554,6 +533,7 @@ static void bpf_skops_write_hdr_opt(struct sock *sk, struct sk_buff *skb, sock_owned_by_me(sk); sock_ops.is_fullsock = 1; + sock_ops.is_locked_tcp_sock = 1; sock_ops.sk = sk; } @@ -592,6 +572,49 @@ static void bpf_skops_write_hdr_opt(struct sock *sk, struct sk_buff *skb, } #endif +static __be32 *process_tcp_ao_options(struct tcp_sock *tp, + const struct tcp_request_sock *tcprsk, + struct tcp_out_options *opts, + struct tcp_key *key, __be32 *ptr) +{ +#ifdef CONFIG_TCP_AO + u8 maclen = tcp_ao_maclen(key->ao_key); + + if (tcprsk) { + u8 aolen = maclen + sizeof(struct tcp_ao_hdr); + + *ptr++ = htonl((TCPOPT_AO << 24) | (aolen << 16) | + (tcprsk->ao_keyid << 8) | + (tcprsk->ao_rcv_next)); + } else { + struct tcp_ao_key *rnext_key; + struct tcp_ao_info *ao_info; + + ao_info = rcu_dereference_check(tp->ao_info, + lockdep_sock_is_held(&tp->inet_conn.icsk_inet.sk)); + rnext_key = READ_ONCE(ao_info->rnext_key); + if (WARN_ON_ONCE(!rnext_key)) + return ptr; + *ptr++ = htonl((TCPOPT_AO << 24) | + (tcp_ao_len(key->ao_key) << 16) | + (key->ao_key->sndid << 8) | + (rnext_key->rcvid)); + } + opts->hash_location = (__u8 *)ptr; + ptr += maclen / sizeof(*ptr); + if (unlikely(maclen % sizeof(*ptr))) { + memset(ptr, TCPOPT_NOP, sizeof(*ptr)); + ptr++; + } +#endif + return ptr; +} + +/* Initial values for AccECN option, ordered is based on ECN field bits + * similar to received_ecn_bytes. Used for SYN/ACK AccECN option. + */ +static const u32 synack_ecn_bytes[3] = { 0, 0, 0 }; + /* Write previously computed TCP options to the packet. * * Beware: Something in the Internet is very sensitive to the ordering of @@ -605,19 +628,25 @@ static void bpf_skops_write_hdr_opt(struct sock *sk, struct sk_buff *skb, * At least SACK_PERM as the first option is known to lead to a disaster * (but it may well be that other scenarios fail similarly). */ -static void tcp_options_write(__be32 *ptr, struct tcp_sock *tp, - struct tcp_out_options *opts) +static void tcp_options_write(struct tcphdr *th, struct tcp_sock *tp, + const struct tcp_request_sock *tcprsk, + struct tcp_out_options *opts, + struct tcp_key *key) { + u8 leftover_highbyte = TCPOPT_NOP; /* replace 1st NOP if avail */ + u8 leftover_lowbyte = TCPOPT_NOP; /* replace 2nd NOP in succession */ + __be32 *ptr = (__be32 *)(th + 1); u16 options = opts->options; /* mungable copy */ - if (unlikely(OPTION_MD5 & options)) { + if (tcp_key_is_md5(key)) { *ptr++ = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) | (TCPOPT_MD5SIG << 8) | TCPOLEN_MD5SIG); /* overload cookie hash location */ opts->hash_location = (__u8 *)ptr; ptr += 4; + } else if (tcp_key_is_ao(key)) { + ptr = process_tcp_ao_options(tp, tcprsk, opts, key, ptr); } - if (unlikely(opts->mss)) { *ptr++ = htonl((TCPOPT_MSS << 24) | (TCPOLEN_MSS << 16) | @@ -641,15 +670,75 @@ static void tcp_options_write(__be32 *ptr, struct tcp_sock *tp, *ptr++ = htonl(opts->tsecr); } + if (OPTION_ACCECN & options) { + const u32 *ecn_bytes = opts->use_synack_ecn_bytes ? + synack_ecn_bytes : + tp->received_ecn_bytes; + const u8 ect0_idx = INET_ECN_ECT_0 - 1; + const u8 ect1_idx = INET_ECN_ECT_1 - 1; + const u8 ce_idx = INET_ECN_CE - 1; + u32 e0b; + u32 e1b; + u32 ceb; + u8 len; + + e0b = ecn_bytes[ect0_idx] + TCP_ACCECN_E0B_INIT_OFFSET; + e1b = ecn_bytes[ect1_idx] + TCP_ACCECN_E1B_INIT_OFFSET; + ceb = ecn_bytes[ce_idx] + TCP_ACCECN_CEB_INIT_OFFSET; + len = TCPOLEN_ACCECN_BASE + + opts->num_accecn_fields * TCPOLEN_ACCECN_PERFIELD; + + if (opts->num_accecn_fields == 2) { + *ptr++ = htonl((TCPOPT_ACCECN1 << 24) | (len << 16) | + ((e1b >> 8) & 0xffff)); + *ptr++ = htonl(((e1b & 0xff) << 24) | + (ceb & 0xffffff)); + } else if (opts->num_accecn_fields == 1) { + *ptr++ = htonl((TCPOPT_ACCECN1 << 24) | (len << 16) | + ((e1b >> 8) & 0xffff)); + leftover_highbyte = e1b & 0xff; + leftover_lowbyte = TCPOPT_NOP; + } else if (opts->num_accecn_fields == 0) { + leftover_highbyte = TCPOPT_ACCECN1; + leftover_lowbyte = len; + } else if (opts->num_accecn_fields == 3) { + *ptr++ = htonl((TCPOPT_ACCECN1 << 24) | (len << 16) | + ((e1b >> 8) & 0xffff)); + *ptr++ = htonl(((e1b & 0xff) << 24) | + (ceb & 0xffffff)); + *ptr++ = htonl(((e0b & 0xffffff) << 8) | + TCPOPT_NOP); + } + if (tp) { + tp->accecn_minlen = 0; + tp->accecn_opt_tstamp = tp->tcp_mstamp; + if (tp->accecn_opt_demand) + tp->accecn_opt_demand--; + } + } + if (unlikely(OPTION_SACK_ADVERTISE & options)) { - *ptr++ = htonl((TCPOPT_NOP << 24) | - (TCPOPT_NOP << 16) | + *ptr++ = htonl((leftover_highbyte << 24) | + (leftover_lowbyte << 16) | (TCPOPT_SACK_PERM << 8) | TCPOLEN_SACK_PERM); + leftover_highbyte = TCPOPT_NOP; + leftover_lowbyte = TCPOPT_NOP; } if (unlikely(OPTION_WSCALE & options)) { - *ptr++ = htonl((TCPOPT_NOP << 24) | + u8 highbyte = TCPOPT_NOP; + + /* Do not split the leftover 2-byte to fit into a single + * NOP, i.e., replace this NOP only when 1 byte is leftover + * within leftover_highbyte. + */ + if (unlikely(leftover_highbyte != TCPOPT_NOP && + leftover_lowbyte == TCPOPT_NOP)) { + highbyte = leftover_highbyte; + leftover_highbyte = TCPOPT_NOP; + } + *ptr++ = htonl((highbyte << 24) | (TCPOPT_WINDOW << 16) | (TCPOLEN_WINDOW << 8) | opts->ws); @@ -660,11 +749,13 @@ static void tcp_options_write(__be32 *ptr, struct tcp_sock *tp, tp->duplicate_sack : tp->selective_acks; int this_sack; - *ptr++ = htonl((TCPOPT_NOP << 24) | - (TCPOPT_NOP << 16) | + *ptr++ = htonl((leftover_highbyte << 24) | + (leftover_lowbyte << 16) | (TCPOPT_SACK << 8) | (TCPOLEN_SACK_BASE + (opts->num_sack_blocks * TCPOLEN_SACK_PERBLOCK))); + leftover_highbyte = TCPOPT_NOP; + leftover_lowbyte = TCPOPT_NOP; for (this_sack = 0; this_sack < opts->num_sack_blocks; ++this_sack) { @@ -673,6 +764,14 @@ static void tcp_options_write(__be32 *ptr, struct tcp_sock *tp, } tp->rx_opt.dsack = 0; + } else if (unlikely(leftover_highbyte != TCPOPT_NOP || + leftover_lowbyte != TCPOPT_NOP)) { + *ptr++ = htonl((leftover_highbyte << 24) | + (leftover_lowbyte << 16) | + (TCPOPT_NOP << 8) | + TCPOPT_NOP); + leftover_highbyte = TCPOPT_NOP; + leftover_lowbyte = TCPOPT_NOP; } if (unlikely(OPTION_FAST_OPEN_COOKIE & options)) { @@ -701,37 +800,39 @@ static void tcp_options_write(__be32 *ptr, struct tcp_sock *tp, smc_options_write(ptr, &options); - mptcp_options_write(ptr, tp, opts); + mptcp_options_write(th, ptr, tp, opts); } -static void smc_set_option(const struct tcp_sock *tp, +static void smc_set_option(struct tcp_sock *tp, struct tcp_out_options *opts, unsigned int *remaining) { #if IS_ENABLED(CONFIG_SMC) - if (static_branch_unlikely(&tcp_have_smc)) { - if (tp->syn_smc) { - if (*remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { - opts->options |= OPTION_SMC; - *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; - } + if (static_branch_unlikely(&tcp_have_smc) && tp->syn_smc) { + tp->syn_smc = !!smc_call_hsbpf(1, tp, syn_option); + /* re-check syn_smc */ + if (tp->syn_smc && + *remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { + opts->options |= OPTION_SMC; + *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; } } #endif } static void smc_set_option_cond(const struct tcp_sock *tp, - const struct inet_request_sock *ireq, + struct inet_request_sock *ireq, struct tcp_out_options *opts, unsigned int *remaining) { #if IS_ENABLED(CONFIG_SMC) - if (static_branch_unlikely(&tcp_have_smc)) { - if (tp->syn_smc && ireq->smc_ok) { - if (*remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { - opts->options |= OPTION_SMC; - *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; - } + if (static_branch_unlikely(&tcp_have_smc) && tp->syn_smc && ireq->smc_ok) { + ireq->smc_ok = !!smc_call_hsbpf(1, tp, synack_option, ireq); + /* re-check smc_ok */ + if (ireq->smc_ok && + *remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { + opts->options |= OPTION_SMC; + *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; } } #endif @@ -753,28 +854,104 @@ static void mptcp_set_option_cond(const struct request_sock *req, } } +static u32 tcp_synack_options_combine_saving(struct tcp_out_options *opts) +{ + /* How much there's room for combining with the alignment padding? */ + if ((opts->options & (OPTION_SACK_ADVERTISE | OPTION_TS)) == + OPTION_SACK_ADVERTISE) + return 2; + else if (opts->options & OPTION_WSCALE) + return 1; + return 0; +} + +/* Calculates how long AccECN option will fit to @remaining option space. + * + * AccECN option can sometimes replace NOPs used for alignment of other + * TCP options (up to @max_combine_saving available). + * + * Only solutions with at least @required AccECN fields are accepted. + * + * Returns: The size of the AccECN option excluding space repurposed from + * the alignment of the other options. + */ +static int tcp_options_fit_accecn(struct tcp_out_options *opts, int required, + int remaining) +{ + int size = TCP_ACCECN_MAXSIZE; + int sack_blocks_reduce = 0; + int max_combine_saving; + int rem = remaining; + int align_size; + + if (opts->use_synack_ecn_bytes) + max_combine_saving = tcp_synack_options_combine_saving(opts); + else + max_combine_saving = opts->num_sack_blocks > 0 ? 2 : 0; + opts->num_accecn_fields = TCP_ACCECN_NUMFIELDS; + while (opts->num_accecn_fields >= required) { + /* Pad to dword if cannot combine */ + if ((size & 0x3) > max_combine_saving) + align_size = ALIGN(size, 4); + else + align_size = ALIGN_DOWN(size, 4); + + if (rem >= align_size) { + size = align_size; + break; + } else if (opts->num_accecn_fields == required && + opts->num_sack_blocks > 2 && + required > 0) { + /* Try to fit the option by removing one SACK block */ + opts->num_sack_blocks--; + sack_blocks_reduce++; + rem = rem + TCPOLEN_SACK_PERBLOCK; + + opts->num_accecn_fields = TCP_ACCECN_NUMFIELDS; + size = TCP_ACCECN_MAXSIZE; + continue; + } + + opts->num_accecn_fields--; + size -= TCPOLEN_ACCECN_PERFIELD; + } + if (sack_blocks_reduce > 0) { + if (opts->num_accecn_fields >= required) + size -= sack_blocks_reduce * TCPOLEN_SACK_PERBLOCK; + else + opts->num_sack_blocks += sack_blocks_reduce; + } + if (opts->num_accecn_fields < required) + return 0; + + opts->options |= OPTION_ACCECN; + return size; +} + /* Compute TCP options for SYN packets. This is not the final * network wire format yet. */ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb, struct tcp_out_options *opts, - struct tcp_md5sig_key **md5) + struct tcp_key *key) { struct tcp_sock *tp = tcp_sk(sk); unsigned int remaining = MAX_TCP_OPTION_SPACE; struct tcp_fastopen_request *fastopen = tp->fastopen_req; + bool timestamps; - *md5 = NULL; -#ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed) && - rcu_access_pointer(tp->md5sig_info)) { - *md5 = tp->af_specific->md5_lookup(sk, sk); - if (*md5) { - opts->options |= OPTION_MD5; - remaining -= TCPOLEN_MD5SIG_ALIGNED; + /* Better than switch (key.type) as it has static branches */ + if (tcp_key_is_md5(key)) { + timestamps = false; + opts->options |= OPTION_MD5; + remaining -= TCPOLEN_MD5SIG_ALIGNED; + } else { + timestamps = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_timestamps); + if (tcp_key_is_ao(key)) { + opts->options |= OPTION_AO; + remaining -= tcp_ao_len_aligned(key->ao_key); } } -#endif /* We always get an MSS option. The option bytes which will be seen in * normal data packets should timestamps be used, must be in the MSS @@ -788,18 +965,18 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb, opts->mss = tcp_advertise_mss(sk); remaining -= TCPOLEN_MSS_ALIGNED; - if (likely(sock_net(sk)->ipv4.sysctl_tcp_timestamps && !*md5)) { + if (likely(timestamps)) { opts->options |= OPTION_TS; - opts->tsval = tcp_skb_timestamp(skb) + tp->tsoffset; + opts->tsval = tcp_skb_timestamp_ts(tp->tcp_usec_ts, skb) + tp->tsoffset; opts->tsecr = tp->rx_opt.ts_recent; remaining -= TCPOLEN_TSTAMP_ALIGNED; } - if (likely(sock_net(sk)->ipv4.sysctl_tcp_window_scaling)) { + if (likely(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_window_scaling))) { opts->ws = tp->rx_opt.rcv_wscale; opts->options |= OPTION_WSCALE; remaining -= TCPOLEN_WSCALE_ALIGNED; } - if (likely(sock_net(sk)->ipv4.sysctl_tcp_sack)) { + if (likely(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_sack))) { opts->options |= OPTION_SACK_ADVERTISE; if (unlikely(!(OPTION_TS & opts->options))) remaining -= TCPOLEN_SACKPERM_ALIGNED; @@ -826,11 +1003,27 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb, unsigned int size; if (mptcp_syn_options(sk, skb, &size, &opts->mptcp)) { - opts->options |= OPTION_MPTCP; - remaining -= size; + if (remaining >= size) { + opts->options |= OPTION_MPTCP; + remaining -= size; + } } } + /* Simultaneous open SYN/ACK needs AccECN option but not SYN. + * It is attempted to negotiate the use of AccECN also on the first + * retransmitted SYN, as mentioned in "3.1.4.1. Retransmitted SYNs" + * of AccECN draft. + */ + if (unlikely((TCP_SKB_CB(skb)->tcp_flags & TCPHDR_ACK) && + tcp_ecn_mode_accecn(tp) && + inet_csk(sk)->icsk_retransmits < 2 && + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_option) && + remaining >= TCPOLEN_ACCECN_BASE)) { + opts->use_synack_ecn_bytes = 1; + remaining -= tcp_options_fit_accecn(opts, 0, remaining); + } + bpf_skops_hdr_opt_len(sk, skb, NULL, NULL, 0, opts, &remaining); return MAX_TCP_OPTION_SPACE - remaining; @@ -841,16 +1034,16 @@ static unsigned int tcp_synack_options(const struct sock *sk, struct request_sock *req, unsigned int mss, struct sk_buff *skb, struct tcp_out_options *opts, - const struct tcp_md5sig_key *md5, + const struct tcp_key *key, struct tcp_fastopen_cookie *foc, enum tcp_synack_type synack_type, struct sk_buff *syn_skb) { struct inet_request_sock *ireq = inet_rsk(req); unsigned int remaining = MAX_TCP_OPTION_SPACE; + struct tcp_request_sock *treq = tcp_rsk(req); -#ifdef CONFIG_TCP_MD5SIG - if (md5) { + if (tcp_key_is_md5(key)) { opts->options |= OPTION_MD5; remaining -= TCPOLEN_MD5SIG_ALIGNED; @@ -861,8 +1054,11 @@ static unsigned int tcp_synack_options(const struct sock *sk, */ if (synack_type != TCP_SYNACK_COOKIE) ireq->tstamp_ok &= !ireq->sack_ok; + } else if (tcp_key_is_ao(key)) { + opts->options |= OPTION_AO; + remaining -= tcp_ao_len_aligned(key->ao_key); + ireq->tstamp_ok &= !ireq->sack_ok; } -#endif /* We always send an MSS option. */ opts->mss = mss; @@ -875,7 +1071,14 @@ static unsigned int tcp_synack_options(const struct sock *sk, } if (likely(ireq->tstamp_ok)) { opts->options |= OPTION_TS; - opts->tsval = tcp_skb_timestamp(skb) + tcp_rsk(req)->ts_off; + opts->tsval = tcp_skb_timestamp_ts(tcp_rsk(req)->req_usec_ts, skb) + + tcp_rsk(req)->ts_off; + if (!tcp_rsk(req)->snt_tsval_first) { + if (!opts->tsval) + opts->tsval = ~0U; + tcp_rsk(req)->snt_tsval_first = opts->tsval; + } + WRITE_ONCE(tcp_rsk(req)->snt_tsval_last, opts->tsval); opts->tsecr = req->ts_recent; remaining -= TCPOLEN_TSTAMP_ALIGNED; } @@ -901,6 +1104,13 @@ static unsigned int tcp_synack_options(const struct sock *sk, smc_set_option_cond(tcp_sk(sk), ireq, opts, &remaining); + if (treq->accecn_ok && + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_option) && + req->num_timeout < 1 && remaining >= TCPOLEN_ACCECN_BASE) { + opts->use_synack_ecn_bytes = 1; + remaining -= tcp_options_fit_accecn(opts, 0, remaining); + } + bpf_skops_hdr_opt_len((struct sock *)sk, skb, req, syn_skb, synack_type, opts, &remaining); @@ -912,7 +1122,7 @@ static unsigned int tcp_synack_options(const struct sock *sk, */ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb, struct tcp_out_options *opts, - struct tcp_md5sig_key **md5) + struct tcp_key *key) { struct tcp_sock *tp = tcp_sk(sk); unsigned int size = 0; @@ -920,21 +1130,19 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb opts->options = 0; - *md5 = NULL; -#ifdef CONFIG_TCP_MD5SIG - if (static_branch_unlikely(&tcp_md5_needed) && - rcu_access_pointer(tp->md5sig_info)) { - *md5 = tp->af_specific->md5_lookup(sk, sk); - if (*md5) { - opts->options |= OPTION_MD5; - size += TCPOLEN_MD5SIG_ALIGNED; - } + /* Better than switch (key.type) as it has static branches */ + if (tcp_key_is_md5(key)) { + opts->options |= OPTION_MD5; + size += TCPOLEN_MD5SIG_ALIGNED; + } else if (tcp_key_is_ao(key)) { + opts->options |= OPTION_AO; + size += tcp_ao_len_aligned(key->ao_key); } -#endif if (likely(tp->rx_opt.tstamp_ok)) { opts->options |= OPTION_TS; - opts->tsval = skb ? tcp_skb_timestamp(skb) + tp->tsoffset : 0; + opts->tsval = skb ? tcp_skb_timestamp_ts(tp->tcp_usec_ts, skb) + + tp->tsoffset : 0; opts->tsecr = tp->rx_opt.ts_recent; size += TCPOLEN_TSTAMP_ALIGNED; } @@ -959,17 +1167,32 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb eff_sacks = tp->rx_opt.num_sacks + tp->rx_opt.dsack; if (unlikely(eff_sacks)) { const unsigned int remaining = MAX_TCP_OPTION_SPACE - size; - if (unlikely(remaining < TCPOLEN_SACK_BASE_ALIGNED + - TCPOLEN_SACK_PERBLOCK)) - return size; + if (likely(remaining >= TCPOLEN_SACK_BASE_ALIGNED + + TCPOLEN_SACK_PERBLOCK)) { + opts->num_sack_blocks = + min_t(unsigned int, eff_sacks, + (remaining - TCPOLEN_SACK_BASE_ALIGNED) / + TCPOLEN_SACK_PERBLOCK); + + size += TCPOLEN_SACK_BASE_ALIGNED + + opts->num_sack_blocks * TCPOLEN_SACK_PERBLOCK; + } else { + opts->num_sack_blocks = 0; + } + } else { + opts->num_sack_blocks = 0; + } - opts->num_sack_blocks = - min_t(unsigned int, eff_sacks, - (remaining - TCPOLEN_SACK_BASE_ALIGNED) / - TCPOLEN_SACK_PERBLOCK); + if (tcp_ecn_mode_accecn(tp)) { + int ecn_opt = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_option); - size += TCPOLEN_SACK_BASE_ALIGNED + - opts->num_sack_blocks * TCPOLEN_SACK_PERBLOCK; + if (ecn_opt && tp->saw_accecn_opt && !tcp_accecn_opt_fail_send(tp) && + (ecn_opt >= TCP_ACCECN_OPTION_FULL || tp->accecn_opt_demand || + tcp_accecn_option_beacon_check(sk))) { + opts->use_synack_ecn_bytes = 0; + size += tcp_options_fit_accecn(opts, tp->accecn_minlen, + MAX_TCP_OPTION_SPACE - size); + } } if (unlikely(BPF_SOCK_OPS_TEST_FLAG(tp, @@ -995,15 +1218,15 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb * needs to be reallocated in a driver. * The invariant being skb->truesize subtracted from sk->sk_wmem_alloc * - * Since transmit from skb destructor is forbidden, we use a tasklet + * Since transmit from skb destructor is forbidden, we use a BH work item * to process all sockets that eventually need to send more skbs. - * We use one tasklet per cpu, with its own queue of sockets. + * We use one work item per cpu, with its own queue of sockets. */ -struct tsq_tasklet { - struct tasklet_struct tasklet; +struct tsq_work { + struct work_struct work; struct list_head head; /* queue of tcp sockets */ }; -static DEFINE_PER_CPU(struct tsq_tasklet, tsq_tasklet); +static DEFINE_PER_CPU(struct tsq_work, tsq_work); static void tcp_tsq_write(struct sock *sk) { @@ -1013,7 +1236,7 @@ static void tcp_tsq_write(struct sock *sk) struct tcp_sock *tp = tcp_sk(sk); if (tp->lost_out > tp->retrans_out && - tp->snd_cwnd > tcp_packets_in_flight(tp)) { + tcp_snd_cwnd(tp) > tcp_packets_in_flight(tp)) { tcp_mstamp_refresh(tp); tcp_xmit_retransmit_queue(sk); } @@ -1033,14 +1256,14 @@ static void tcp_tsq_handler(struct sock *sk) bh_unlock_sock(sk); } /* - * One tasklet per cpu tries to send more skbs. - * We run in tasklet context but need to disable irqs when + * One work item per cpu tries to send more skbs. + * We run in BH context but need to disable irqs when * transferring tsq->head because tcp_wfree() might * interrupt us (non NAPI drivers) */ -static void tcp_tasklet_func(struct tasklet_struct *t) +static void tcp_tsq_workfn(struct work_struct *work) { - struct tsq_tasklet *tsq = from_tasklet(tsq, t, tasklet); + struct tsq_work *tsq = container_of(work, struct tsq_work, work); LIST_HEAD(list); unsigned long flags; struct list_head *q, *n; @@ -1067,7 +1290,8 @@ static void tcp_tasklet_func(struct tasklet_struct *t) #define TCP_DEFERRED_ALL (TCPF_TSQ_DEFERRED | \ TCPF_WRITE_TIMER_DEFERRED | \ TCPF_DELACK_TIMER_DEFERRED | \ - TCPF_MTU_REDUCED_DEFERRED) + TCPF_MTU_REDUCED_DEFERRED | \ + TCPF_ACK_DEFERRED) /** * tcp_release_cb - tcp release_sock() callback * @sk: socket @@ -1077,30 +1301,20 @@ static void tcp_tasklet_func(struct tasklet_struct *t) */ void tcp_release_cb(struct sock *sk) { - unsigned long flags, nflags; + unsigned long flags = smp_load_acquire(&sk->sk_tsq_flags); + unsigned long nflags; /* perform an atomic operation only if at least one flag is set */ do { - flags = sk->sk_tsq_flags; if (!(flags & TCP_DEFERRED_ALL)) return; nflags = flags & ~TCP_DEFERRED_ALL; - } while (cmpxchg(&sk->sk_tsq_flags, flags, nflags) != flags); + } while (!try_cmpxchg(&sk->sk_tsq_flags, &flags, nflags)); if (flags & TCPF_TSQ_DEFERRED) { tcp_tsq_write(sk); __sock_put(sk); } - /* Here begins the tricky part : - * We are called from release_sock() with : - * 1) BH disabled - * 2) sk_lock.slock spinlock held - * 3) socket owned by us (sk->sk_lock.owned == 1) - * - * But following code is meant to be called from BH handlers, - * so we should keep BH disabled, but early release socket ownership - */ - sock_release_ownership(sk); if (flags & TCPF_WRITE_TIMER_DEFERRED) { tcp_write_timer_handler(sk); @@ -1114,18 +1328,20 @@ void tcp_release_cb(struct sock *sk) inet_csk(sk)->icsk_af_ops->mtu_reduced(sk); __sock_put(sk); } + if ((flags & TCPF_ACK_DEFERRED) && inet_csk_ack_scheduled(sk)) + tcp_send_ack(sk); } -EXPORT_SYMBOL(tcp_release_cb); +EXPORT_IPV6_MOD(tcp_release_cb); -void __init tcp_tasklet_init(void) +void __init tcp_tsq_work_init(void) { int i; for_each_possible_cpu(i) { - struct tsq_tasklet *tsq = &per_cpu(tsq_tasklet, i); + struct tsq_work *tsq = &per_cpu(tsq_work, i); INIT_LIST_HEAD(&tsq->head); - tasklet_setup(&tsq->tasklet, tcp_tasklet_func); + INIT_WORK(&tsq->work, tcp_tsq_workfn); } } @@ -1139,9 +1355,11 @@ void tcp_wfree(struct sk_buff *skb) struct sock *sk = skb->sk; struct tcp_sock *tp = tcp_sk(sk); unsigned long flags, nval, oval; + struct tsq_work *tsq; + bool empty; /* Keep one reference on sk_wmem_alloc. - * Will be released by sk_free() from here or tcp_tasklet_func() + * Will be released by sk_free() from here or tcp_tsq_workfn() */ WARN_ON(refcount_sub_and_test(skb->truesize - 1, &sk->sk_wmem_alloc)); @@ -1155,28 +1373,23 @@ void tcp_wfree(struct sk_buff *skb) if (refcount_read(&sk->sk_wmem_alloc) >= SKB_TRUESIZE(1) && this_cpu_ksoftirqd() == current) goto out; - for (oval = READ_ONCE(sk->sk_tsq_flags);; oval = nval) { - struct tsq_tasklet *tsq; - bool empty; - + oval = smp_load_acquire(&sk->sk_tsq_flags); + do { if (!(oval & TSQF_THROTTLED) || (oval & TSQF_QUEUED)) goto out; nval = (oval & ~TSQF_THROTTLED) | TSQF_QUEUED; - nval = cmpxchg(&sk->sk_tsq_flags, oval, nval); - if (nval != oval) - continue; + } while (!try_cmpxchg(&sk->sk_tsq_flags, &oval, nval)); - /* queue this socket to tasklet queue */ - local_irq_save(flags); - tsq = this_cpu_ptr(&tsq_tasklet); - empty = list_empty(&tsq->head); - list_add(&tp->tsq_node, &tsq->head); - if (empty) - tasklet_schedule(&tsq->tasklet); - local_irq_restore(flags); - return; - } + /* queue this socket to BH workqueue */ + local_irq_save(flags); + tsq = this_cpu_ptr(&tsq_work); + empty = list_empty(&tsq->head); + list_add(&tp->tsq_node, &tsq->head); + if (empty) + queue_work(system_bh_wq, &tsq->work); + local_irq_restore(flags); + return; out: sk_free(sk); } @@ -1201,7 +1414,7 @@ static void tcp_update_skb_after_send(struct sock *sk, struct sk_buff *skb, struct tcp_sock *tp = tcp_sk(sk); if (sk->sk_pacing_status != SK_PACING_NONE) { - unsigned long rate = sk->sk_pacing_rate; + unsigned long rate = READ_ONCE(sk->sk_pacing_rate); /* Original sch_fq does not pace first 10 MSS * Note that tp->data_segs_out overflows after 2^32 packets, @@ -1244,7 +1457,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, struct tcp_out_options opts; unsigned int tcp_options_size, tcp_header_size; struct sk_buff *oskb = NULL; - struct tcp_md5sig_key *md5; + struct tcp_key key; struct tcphdr *th; u64 prior_wstamp; int err; @@ -1253,7 +1466,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, tp = tcp_sk(sk); prior_wstamp = tp->tcp_wstamp_ns; tp->tcp_wstamp_ns = max(tp->tcp_wstamp_ns, tp->tcp_clock_cache); - skb->skb_mstamp_ns = tp->tcp_wstamp_ns; + skb_set_delivery_time(skb, tp->tcp_wstamp_ns, SKB_CLOCK_MONOTONIC); if (clone_it) { oskb = skb; @@ -1276,11 +1489,11 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, tcb = TCP_SKB_CB(skb); memset(&opts, 0, sizeof(opts)); + tcp_get_current_key(sk, &key); if (unlikely(tcb->tcp_flags & TCPHDR_SYN)) { - tcp_options_size = tcp_syn_options(sk, skb, &opts, &md5); + tcp_options_size = tcp_syn_options(sk, skb, &opts, &key); } else { - tcp_options_size = tcp_established_options(sk, skb, &opts, - &md5); + tcp_options_size = tcp_established_options(sk, skb, &opts, &key); /* Force a PSH flag on all (GSO) packets to expedite GRO flush * at receiver : This slightly improve GRO performance. * Note that we do not force the PSH flag for non GSO packets, @@ -1294,14 +1507,21 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, } tcp_header_size = tcp_options_size + sizeof(struct tcphdr); - /* if no packet is in qdisc/device queue, then allow XPS to select - * another queue. We can be called from tcp_tsq_handler() - * which holds one reference to sk. - * - * TODO: Ideally, in-flight pure ACK packets should not matter here. - * One way to get this would be to set skb->truesize = 2 on them. + /* We set skb->ooo_okay to one if this packet can select + * a different TX queue than prior packets of this flow, + * to avoid self inflicted reorders. + * The 'other' queue decision is based on current cpu number + * if XPS is enabled, or sk->sk_txhash otherwise. + * We can switch to another (and better) queue if: + * 1) No packet with payload is in qdisc/device queues. + * Delays in TX completion can defeat the test + * even if packets were already sent. + * 2) Or rtx queue is empty. + * This mitigates above case if ACK packets for + * all prior packets were already processed. */ - skb->ooo_okay = sk_wmem_alloc_get(sk) < SKB_TRUESIZE(1); + skb->ooo_okay = sk_wmem_alloc_get(sk) < SKB_TRUESIZE(1) || + tcp_rtx_queue_empty(sk); /* If we had to use memory reserve to allocate this skb, * this might cause drops if packet is looped back : @@ -1318,7 +1538,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, skb->destructor = skb_is_tcp_pure_ack(skb) ? __sock_wfree : tcp_wfree; refcount_add(skb->truesize, &sk->sk_wmem_alloc); - skb_set_dst_pending_confirm(skb, sk->sk_dst_pending_confirm); + skb_set_dst_pending_confirm(skb, READ_ONCE(sk->sk_dst_pending_confirm)); /* Build TCP header and checksum it. */ th = (struct tcphdr *)skb->data; @@ -1327,7 +1547,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, th->seq = htonl(tcb->seq); th->ack_seq = htonl(rcv_nxt); *(((__be16 *)th) + 6) = htons(((tcp_header_size >> 2) << 12) | - tcb->tcp_flags); + (tcb->tcp_flags & TCPHDR_FLAGS_MASK)); th->check = 0; th->urg_ptr = 0; @@ -1354,16 +1574,25 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, th->window = htons(min(tp->rcv_wnd, 65535U)); } - tcp_options_write((__be32 *)(th + 1), tp, &opts); + tcp_options_write(th, tp, NULL, &opts, &key); + if (tcp_key_is_md5(&key)) { #ifdef CONFIG_TCP_MD5SIG - /* Calculate the MD5 hash, as we have all we need now */ - if (md5) { + /* Calculate the MD5 hash, as we have all we need now */ sk_gso_disable(sk); tp->af_specific->calc_md5_hash(opts.hash_location, - md5, sk, skb); - } + key.md5_key, sk, skb); #endif + } else if (tcp_key_is_ao(&key)) { + int err; + + err = tcp_ao_transmit_skb(sk, skb, key.ao_key, th, + opts.hash_location); + if (err) { + sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_NOT_SPECIFIED); + return -ENOMEM; + } + } /* BPF prog is the last one writing header option */ bpf_skops_write_hdr_opt(sk, skb, NULL, NULL, 0, &opts); @@ -1373,7 +1602,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, sk, skb); if (likely(tcb->tcp_flags & TCPHDR_ACK)) - tcp_event_ack_sent(sk, tcp_skb_pcount(skb), rcv_nxt); + tcp_event_ack_sent(sk, rcv_nxt); if (skb->len != tcp_header_size) { tcp_event_data_sent(tp, sk); @@ -1433,24 +1662,29 @@ static void tcp_queue_skb(struct sock *sk, struct sk_buff *skb) /* Advance write_seq and place onto the write_queue. */ WRITE_ONCE(tp->write_seq, TCP_SKB_CB(skb)->end_seq); __skb_header_release(skb); + psp_enqueue_set_decrypted(sk, skb); tcp_add_write_queue_tail(sk, skb); sk_wmem_queued_add(sk, skb->truesize); sk_mem_charge(sk, skb->truesize); } /* Initialize TSO segments for a packet. */ -static void tcp_set_skb_tso_segs(struct sk_buff *skb, unsigned int mss_now) +static int tcp_set_skb_tso_segs(struct sk_buff *skb, unsigned int mss_now) { + int tso_segs; + if (skb->len <= mss_now) { /* Avoid the costly divide in the normal * non-TSO case. */ - tcp_skb_pcount_set(skb, 1); TCP_SKB_CB(skb)->tcp_gso_size = 0; - } else { - tcp_skb_pcount_set(skb, DIV_ROUND_UP(skb->len, mss_now)); - TCP_SKB_CB(skb)->tcp_gso_size = mss_now; + tcp_skb_pcount_set(skb, 1); + return 1; } + TCP_SKB_CB(skb)->tcp_gso_size = mss_now; + tso_segs = DIV_ROUND_UP(skb->len, mss_now); + tcp_skb_pcount_set(skb, tso_segs); + return tso_segs; } /* Pcount in the middle of the write queue got changed, we need to do various @@ -1473,11 +1707,6 @@ static void tcp_adjust_pcount(struct sock *sk, const struct sk_buff *skb, int de if (tcp_is_reno(tp) && decr > 0) tp->sacked_out -= min_t(u32, tp->sacked_out, decr); - if (tp->lost_skb_hint && - before(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(tp->lost_skb_hint)->seq) && - (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED)) - tp->lost_cnt_hint -= decr; - tcp_verify_left_out(tp); } @@ -1533,24 +1762,22 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, { struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *buff; - int nsize, old_factor; + int old_factor; long limit; + u16 flags; int nlen; - u8 flags; if (WARN_ON(len > skb->len)) return -EINVAL; - nsize = skb_headlen(skb) - len; - if (nsize < 0) - nsize = 0; + DEBUG_NET_WARN_ON_ONCE(skb_headlen(skb)); /* tcp_sendmsg() can overshoot sk_wmem_queued by one full size skb. * We need some allowance to not penalize applications setting small * SO_SNDBUF values. * Also allow first and last skb in retransmit queue to be split. */ - limit = sk->sk_sndbuf + 2 * SKB_TRUESIZE(GSO_MAX_SIZE); + limit = sk->sk_sndbuf + 2 * SKB_TRUESIZE(GSO_LEGACY_MAX_SIZE); if (unlikely((sk->sk_wmem_queued >> 1) > limit && tcp_queue != TCP_FRAG_IN_WRITE_QUEUE && skb != tcp_rtx_queue_head(sk) && @@ -1563,7 +1790,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, return -ENOMEM; /* Get a new skb... force flag on. */ - buff = tcp_stream_alloc_skb(sk, nsize, gfp, true); + buff = tcp_stream_alloc_skb(sk, gfp, true); if (!buff) return -ENOMEM; /* We'll just try again later. */ skb_copy_decrypted(buff, skb); @@ -1571,7 +1798,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, sk_wmem_queued_add(sk, buff->truesize); sk_mem_charge(sk, buff->truesize); - nlen = skb->len - len - nsize; + nlen = skb->len - len; buff->truesize += nlen; skb->truesize -= nlen; @@ -1589,7 +1816,7 @@ int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, skb_split(skb, buff, len); - buff->tstamp = skb->tstamp; + skb_set_delivery_time(buff, skb->tstamp, SKB_CLOCK_MONOTONIC); tcp_fragment_tstamp(skb, buff); old_factor = tcp_skb_pcount(skb); @@ -1629,13 +1856,7 @@ static int __pskb_trim_head(struct sk_buff *skb, int len) struct skb_shared_info *shinfo; int i, k, eat; - eat = min_t(int, len, skb_headlen(skb)); - if (eat) { - __skb_pull(skb, eat); - len -= eat; - if (!len) - return 0; - } + DEBUG_NET_WARN_ON_ONCE(skb_headlen(skb)); eat = len; k = 0; shinfo = skb_shinfo(skb); @@ -1674,12 +1895,10 @@ int tcp_trim_head(struct sock *sk, struct sk_buff *skb, u32 len) TCP_SKB_CB(skb)->seq += len; - if (delta_truesize) { - skb->truesize -= delta_truesize; - sk_wmem_queued_add(sk, -delta_truesize); - if (!skb_zcopy_pure(skb)) - sk_mem_uncharge(sk, delta_truesize); - } + skb->truesize -= delta_truesize; + sk_wmem_queued_add(sk, -delta_truesize); + if (!skb_zcopy_pure(skb)) + sk_mem_uncharge(sk, delta_truesize); /* Any change of skb->len requires recalculation of tso factor. */ if (tcp_skb_pcount(skb) > 1) @@ -1700,14 +1919,6 @@ static inline int __tcp_mtu_to_mss(struct sock *sk, int pmtu) */ mss_now = pmtu - icsk->icsk_af_ops->net_header_len - sizeof(struct tcphdr); - /* IPv6 adds a frag_hdr in case RTAX_FEATURE_ALLFRAG is set */ - if (icsk->icsk_af_ops->net_frag_header_len) { - const struct dst_entry *dst = __sk_dst_get(sk); - - if (dst && dst_allfrag(dst)) - mss_now -= icsk->icsk_af_ops->net_frag_header_len; - } - /* Clamp it (mss_clamp does not include tcp options) */ if (mss_now > tp->rx_opt.mss_clamp) mss_now = tp->rx_opt.mss_clamp; @@ -1716,7 +1927,8 @@ static inline int __tcp_mtu_to_mss(struct sock *sk, int pmtu) mss_now -= icsk->icsk_ext_hdr_len; /* Then reserve room for full set of TCP options and 8 bytes of data */ - mss_now = max(mss_now, sock_net(sk)->ipv4.sysctl_tcp_min_snd_mss); + mss_now = max(mss_now, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_snd_mss)); return mss_now; } @@ -1727,28 +1939,18 @@ int tcp_mtu_to_mss(struct sock *sk, int pmtu) return __tcp_mtu_to_mss(sk, pmtu) - (tcp_sk(sk)->tcp_header_len - sizeof(struct tcphdr)); } -EXPORT_SYMBOL(tcp_mtu_to_mss); +EXPORT_IPV6_MOD(tcp_mtu_to_mss); /* Inverse of above */ int tcp_mss_to_mtu(struct sock *sk, int mss) { const struct tcp_sock *tp = tcp_sk(sk); const struct inet_connection_sock *icsk = inet_csk(sk); - int mtu; - mtu = mss + + return mss + tp->tcp_header_len + icsk->icsk_ext_hdr_len + icsk->icsk_af_ops->net_header_len; - - /* IPv6 adds a frag_hdr in case RTAX_FEATURE_ALLFRAG is set */ - if (icsk->icsk_af_ops->net_frag_header_len) { - const struct dst_entry *dst = __sk_dst_get(sk); - - if (dst && dst_allfrag(dst)) - mtu += icsk->icsk_af_ops->net_frag_header_len; - } - return mtu; } EXPORT_SYMBOL(tcp_mss_to_mtu); @@ -1759,15 +1961,14 @@ void tcp_mtup_init(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); struct net *net = sock_net(sk); - icsk->icsk_mtup.enabled = net->ipv4.sysctl_tcp_mtu_probing > 1; + icsk->icsk_mtup.enabled = READ_ONCE(net->ipv4.sysctl_tcp_mtu_probing) > 1; icsk->icsk_mtup.search_high = tp->rx_opt.mss_clamp + sizeof(struct tcphdr) + icsk->icsk_af_ops->net_header_len; - icsk->icsk_mtup.search_low = tcp_mss_to_mtu(sk, net->ipv4.sysctl_tcp_base_mss); + icsk->icsk_mtup.search_low = tcp_mss_to_mtu(sk, READ_ONCE(net->ipv4.sysctl_tcp_base_mss)); icsk->icsk_mtup.probe_size = 0; if (icsk->icsk_mtup.enabled) icsk->icsk_mtup.probe_timestamp = tcp_jiffies32; } -EXPORT_SYMBOL(tcp_mtup_init); /* This function synchronize snd mss to current pmtu/exthdr set. @@ -1811,7 +2012,7 @@ unsigned int tcp_sync_mss(struct sock *sk, u32 pmtu) return mss_now; } -EXPORT_SYMBOL(tcp_sync_mss); +EXPORT_IPV6_MOD(tcp_sync_mss); /* Compute the current effective MSS, taking SACKs and IP options, * and even PMTU discovery events into account. @@ -1823,7 +2024,7 @@ unsigned int tcp_current_mss(struct sock *sk) u32 mss_now; unsigned int header_len; struct tcp_out_options opts; - struct tcp_md5sig_key *md5; + struct tcp_key key; mss_now = tp->mss_cache; @@ -1832,8 +2033,8 @@ unsigned int tcp_current_mss(struct sock *sk) if (mtu != inet_csk(sk)->icsk_pmtu_cookie) mss_now = tcp_sync_mss(sk, mtu); } - - header_len = tcp_established_options(sk, NULL, &opts, &md5) + + tcp_get_current_key(sk, &key); + header_len = tcp_established_options(sk, NULL, &opts, &key) + sizeof(struct tcphdr); /* The mss_cache is sized based on tp->tcp_header_len, which assumes * some common options. If this is an odd packet (because we have SACK @@ -1860,9 +2061,9 @@ static void tcp_cwnd_application_limited(struct sock *sk) /* Limited by application or receiver window. */ u32 init_win = tcp_init_cwnd(tp, __sk_dst_get(sk)); u32 win_used = max(tp->snd_cwnd_used, init_win); - if (win_used < tp->snd_cwnd) { + if (win_used < tcp_snd_cwnd(tp)) { tp->snd_ssthresh = tcp_current_ssthresh(sk); - tp->snd_cwnd = (tp->snd_cwnd + win_used) >> 1; + tcp_snd_cwnd_set(tp, (tcp_snd_cwnd(tp) + win_used) >> 1); } tp->snd_cwnd_used = 0; } @@ -1874,15 +2075,20 @@ static void tcp_cwnd_validate(struct sock *sk, bool is_cwnd_limited) const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; struct tcp_sock *tp = tcp_sk(sk); - /* Track the maximum number of outstanding packets in each - * window, and remember whether we were cwnd-limited then. + /* Track the strongest available signal of the degree to which the cwnd + * is fully utilized. If cwnd-limited then remember that fact for the + * current window. If not cwnd-limited then track the maximum number of + * outstanding packets in the current window. (If cwnd-limited then we + * chose to not update tp->max_packets_out to avoid an extra else + * clause with no functional impact.) */ - if (!before(tp->snd_una, tp->max_packets_seq) || - tp->packets_out > tp->max_packets_out || - is_cwnd_limited) { - tp->max_packets_out = tp->packets_out; - tp->max_packets_seq = tp->snd_nxt; + if (!before(tp->snd_una, tp->cwnd_usage_seq) || + is_cwnd_limited || + (!tp->is_cwnd_limited && + tp->packets_out > tp->max_packets_out)) { tp->is_cwnd_limited = is_cwnd_limited; + tp->max_packets_out = tp->packets_out; + tp->cwnd_usage_seq = tp->snd_nxt; } if (tcp_is_cwnd_limited(sk)) { @@ -1894,7 +2100,7 @@ static void tcp_cwnd_validate(struct sock *sk, bool is_cwnd_limited) if (tp->packets_out > tp->snd_cwnd_used) tp->snd_cwnd_used = tp->packets_out; - if (sock_net(sk)->ipv4.sysctl_tcp_slow_start_after_idle && + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_slow_start_after_idle) && (s32)(tcp_jiffies32 - tp->snd_cwnd_stamp) >= inet_csk(sk)->icsk_rto && !ca_ops->cong_control) tcp_cwnd_application_limited(sk); @@ -1951,25 +2157,34 @@ static bool tcp_nagle_check(bool partial, const struct tcp_sock *tp, } /* Return how many segs we'd like on a TSO packet, - * to send one TSO packet per ms + * depending on current pacing rate, and how close the peer is. + * + * Rationale is: + * - For close peers, we rather send bigger packets to reduce + * cpu costs, because occasional losses will be repaired fast. + * - For long distance/rtt flows, we would like to get ACK clocking + * with 1 ACK per ms. + * + * Use min_rtt to help adapt TSO burst size, with smaller min_rtt resulting + * in bigger TSO bursts. We we cut the RTT-based allowance in half + * for every 2^9 usec (aka 512 us) of RTT, so that the RTT-based allowance + * is below 1500 bytes after 6 * ~500 usec = 3ms. */ static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, int min_tso_segs) { - u32 bytes, segs; + unsigned long bytes; + u32 r; - bytes = min_t(unsigned long, - sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift), - sk->sk_gso_max_size - 1 - MAX_TCP_HEADER); + bytes = READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift); - /* Goal is to send at least one packet per ms, - * not one big TSO packet every 100 ms. - * This preserves ACK clocking and is consistent - * with tcp_tso_should_defer() heuristic. - */ - segs = max_t(u32, bytes / mss_now, min_tso_segs); + r = tcp_min_rtt(tcp_sk(sk)) >> READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tso_rtt_log); + if (r < BITS_PER_TYPE(sk->sk_gso_max_size)) + bytes += sk->sk_gso_max_size >> r; - return segs; + bytes = min_t(unsigned long, bytes, sk->sk_gso_max_size); + + return max_t(u32, bytes / mss_now, min_tso_segs); } /* Return the number of segments we want in the skb we are transmitting. @@ -1982,7 +2197,7 @@ static u32 tcp_tso_segs(struct sock *sk, unsigned int mss_now) min_tso = ca_ops->min_tso_segs ? ca_ops->min_tso_segs(sk) : - sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs; + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); tso_segs = tcp_tso_autosize(sk, mss_now, min_tso); return min_t(u32, tso_segs, sk->sk_gso_max_segs); @@ -2023,18 +2238,12 @@ static unsigned int tcp_mss_split_point(const struct sock *sk, /* Can at least one segment of SKB be sent right now, according to the * congestion window rules? If so, return how many segments are allowed. */ -static inline unsigned int tcp_cwnd_test(const struct tcp_sock *tp, - const struct sk_buff *skb) +static u32 tcp_cwnd_test(const struct tcp_sock *tp) { u32 in_flight, cwnd, halfcwnd; - /* Don't be strict about the congestion window for the final FIN. */ - if ((TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) && - tcp_skb_pcount(skb) == 1) - return 1; - in_flight = tcp_packets_in_flight(tp); - cwnd = tp->snd_cwnd; + cwnd = tcp_snd_cwnd(tp); if (in_flight >= cwnd) return 0; @@ -2053,10 +2262,9 @@ static int tcp_init_tso_segs(struct sk_buff *skb, unsigned int mss_now) { int tso_segs = tcp_skb_pcount(skb); - if (!tso_segs || (tso_segs > 1 && tcp_skb_mss(skb) != mss_now)) { - tcp_set_skb_tso_segs(skb, mss_now); - tso_segs = tcp_skb_pcount(skb); - } + if (!tso_segs || (tso_segs > 1 && tcp_skb_mss(skb) != mss_now)) + return tcp_set_skb_tso_segs(skb, mss_now); + return tso_segs; } @@ -2111,14 +2319,12 @@ static int tso_fragment(struct sock *sk, struct sk_buff *skb, unsigned int len, { int nlen = skb->len - len; struct sk_buff *buff; - u8 flags; + u16 flags; /* All of a TSO frame must be composed of paged data. */ - if (skb->len != skb->data_len) - return tcp_fragment(sk, TCP_FRAG_IN_WRITE_QUEUE, - skb, len, mss_now, gfp); + DEBUG_NET_WARN_ON_ONCE(skb->len != skb->data_len); - buff = tcp_stream_alloc_skb(sk, 0, gfp, true); + buff = tcp_stream_alloc_skb(sk, gfp, true); if (unlikely(!buff)) return -ENOMEM; skb_copy_decrypted(buff, skb); @@ -2166,7 +2372,8 @@ static bool tcp_tso_should_defer(struct sock *sk, struct sk_buff *skb, u32 max_segs) { const struct inet_connection_sock *icsk = inet_csk(sk); - u32 send_win, cong_win, limit, in_flight; + u32 send_win, cong_win, limit, in_flight, threshold; + u64 srtt_in_ns, expected_ack, how_far_is_the_ack; struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *head; int win_divisor; @@ -2187,12 +2394,12 @@ static bool tcp_tso_should_defer(struct sock *sk, struct sk_buff *skb, in_flight = tcp_packets_in_flight(tp); BUG_ON(tcp_skb_pcount(skb) <= 1); - BUG_ON(tp->snd_cwnd <= in_flight); + BUG_ON(tcp_snd_cwnd(tp) <= in_flight); send_win = tcp_wnd_end(tp) - TCP_SKB_CB(skb)->seq; /* From in_flight test above, we know that cwnd > in_flight. */ - cong_win = (tp->snd_cwnd - in_flight) * tp->mss_cache; + cong_win = (tcp_snd_cwnd(tp) - in_flight) * tp->mss_cache; limit = min(send_win, cong_win); @@ -2206,7 +2413,7 @@ static bool tcp_tso_should_defer(struct sock *sk, struct sk_buff *skb, win_divisor = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tso_win_divisor); if (win_divisor) { - u32 chunk = min(tp->snd_wnd, tp->snd_cwnd * tp->mss_cache); + u32 chunk = min(tp->snd_wnd, tcp_snd_cwnd(tp) * tp->mss_cache); /* If at least some fraction of a window is available, * just use it. @@ -2228,9 +2435,19 @@ static bool tcp_tso_should_defer(struct sock *sk, struct sk_buff *skb, head = tcp_rtx_queue_head(sk); if (!head) goto send_now; - delta = tp->tcp_clock_cache - head->tstamp; - /* If next ACK is likely to come too late (half srtt), do not defer */ - if ((s64)(delta - (u64)NSEC_PER_USEC * (tp->srtt_us >> 4)) < 0) + + srtt_in_ns = (u64)(NSEC_PER_USEC >> 3) * tp->srtt_us; + /* When is the ACK expected ? */ + expected_ack = head->tstamp + srtt_in_ns; + /* How far from now is the ACK expected ? */ + how_far_is_the_ack = expected_ack - tp->tcp_clock_cache; + + /* If next ACK is likely to come too late, + * ie in more than min(1ms, half srtt), do not defer. + */ + threshold = min(srtt_in_ns >> 1, NSEC_PER_MSEC); + + if ((s64)(how_far_is_the_ack - threshold) > 0) goto send_now; /* Ok, it looks like it is advisable to defer. @@ -2270,7 +2487,7 @@ static inline void tcp_mtu_check_reprobe(struct sock *sk) u32 interval; s32 delta; - interval = net->ipv4.sysctl_tcp_probe_interval; + interval = READ_ONCE(net->ipv4.sysctl_tcp_probe_interval); delta = tcp_jiffies32 - icsk->icsk_mtup.probe_timestamp; if (unlikely(delta >= interval * HZ)) { int mss = tcp_current_mss(sk); @@ -2296,9 +2513,7 @@ static bool tcp_can_coalesce_send_queue_head(struct sock *sk, int len) if (len <= skb->len) break; - if (unlikely(TCP_SKB_CB(skb)->eor) || - tcp_has_tx_tstamp(skb) || - !skb_pure_zcopy_same(skb, next)) + if (tcp_has_tx_tstamp(skb) || !tcp_skb_can_collapse(skb, next)) return false; len -= skb->len; @@ -2307,6 +2522,72 @@ static bool tcp_can_coalesce_send_queue_head(struct sock *sk, int len) return true; } +static int tcp_clone_payload(struct sock *sk, struct sk_buff *to, + int probe_size) +{ + skb_frag_t *lastfrag = NULL, *fragto = skb_shinfo(to)->frags; + int i, todo, len = 0, nr_frags = 0; + const struct sk_buff *skb; + + if (!sk_wmem_schedule(sk, to->truesize + probe_size)) + return -ENOMEM; + + skb_queue_walk(&sk->sk_write_queue, skb) { + const skb_frag_t *fragfrom = skb_shinfo(skb)->frags; + + if (skb_headlen(skb)) + return -EINVAL; + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++, fragfrom++) { + if (len >= probe_size) + goto commit; + todo = min_t(int, skb_frag_size(fragfrom), + probe_size - len); + len += todo; + if (lastfrag && + skb_frag_page(fragfrom) == skb_frag_page(lastfrag) && + skb_frag_off(fragfrom) == skb_frag_off(lastfrag) + + skb_frag_size(lastfrag)) { + skb_frag_size_add(lastfrag, todo); + continue; + } + if (unlikely(nr_frags == MAX_SKB_FRAGS)) + return -E2BIG; + skb_frag_page_copy(fragto, fragfrom); + skb_frag_off_copy(fragto, fragfrom); + skb_frag_size_set(fragto, todo); + nr_frags++; + lastfrag = fragto++; + } + } +commit: + WARN_ON_ONCE(len != probe_size); + for (i = 0; i < nr_frags; i++) + skb_frag_ref(to, i); + + skb_shinfo(to)->nr_frags = nr_frags; + to->truesize += probe_size; + to->len += probe_size; + to->data_len += probe_size; + __skb_header_release(to); + return 0; +} + +/* tcp_mtu_probe() and tcp_grow_skb() can both eat an skb (src) if + * all its payload was moved to another one (dst). + * Make sure to transfer tcp_flags, eor, and tstamp. + */ +static void tcp_eat_one_skb(struct sock *sk, + struct sk_buff *dst, + struct sk_buff *src) +{ + TCP_SKB_CB(dst)->tcp_flags |= TCP_SKB_CB(src)->tcp_flags; + TCP_SKB_CB(dst)->eor = TCP_SKB_CB(src)->eor; + tcp_skb_collapse_tstamp(dst, src); + tcp_unlink_write_queue(src, sk); + tcp_wmem_free_skb(sk, src); +} + /* Create a new MTU probe if we are ready. * MTU probe is regularly attempting to increase the path MTU by * deliberately sending larger packets. This discovers routing @@ -2336,7 +2617,7 @@ static int tcp_mtu_probe(struct sock *sk) if (likely(!icsk->icsk_mtup.enabled || icsk->icsk_mtup.probe_size || inet_csk(sk)->icsk_ca_state != TCP_CA_Open || - tp->snd_cwnd < 11 || + tcp_snd_cwnd(tp) < 11 || tp->rx_opt.num_sacks || tp->rx_opt.dsack)) return -1; @@ -2354,7 +2635,7 @@ static int tcp_mtu_probe(struct sock *sk) * probing process by not resetting search range to its orignal. */ if (probe_size > tcp_mtu_to_mss(sk, icsk->icsk_mtup.search_high) || - interval < net->ipv4.sysctl_tcp_probe_threshold) { + interval < READ_ONCE(net->ipv4.sysctl_tcp_probe_threshold)) { /* Check whether enough time has elaplased for * another round of probing. */ @@ -2372,7 +2653,7 @@ static int tcp_mtu_probe(struct sock *sk) return 0; /* Do we need to wait to drain cwnd? With none in flight, don't stall */ - if (tcp_packets_in_flight(tp) + 2 > tp->snd_cwnd) { + if (tcp_packets_in_flight(tp) + 2 > tcp_snd_cwnd(tp)) { if (!tcp_packets_in_flight(tp)) return -1; else @@ -2383,9 +2664,16 @@ static int tcp_mtu_probe(struct sock *sk) return -1; /* We're allowed to probe. Build it now. */ - nskb = tcp_stream_alloc_skb(sk, probe_size, GFP_ATOMIC, false); + nskb = tcp_stream_alloc_skb(sk, GFP_ATOMIC, false); if (!nskb) return -1; + + /* build the payload, and be prepared to abort if this fails. */ + if (tcp_clone_payload(sk, nskb, probe_size)) { + tcp_skb_tsorted_anchor_cleanup(nskb); + consume_skb(nskb); + return -1; + } sk_wmem_queued_add(sk, nskb->truesize); sk_mem_charge(sk, nskb->truesize); @@ -2403,28 +2691,14 @@ static int tcp_mtu_probe(struct sock *sk) len = 0; tcp_for_write_queue_from_safe(skb, next, sk) { copy = min_t(int, skb->len, probe_size - len); - skb_copy_bits(skb, 0, skb_put(nskb, copy), copy); if (skb->len <= copy) { - /* We've eaten all the data from this skb. - * Throw it away. */ - TCP_SKB_CB(nskb)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags; - /* If this is the last SKB we copy and eor is set - * we need to propagate it to the new skb. - */ - TCP_SKB_CB(nskb)->eor = TCP_SKB_CB(skb)->eor; - tcp_skb_collapse_tstamp(nskb, skb); - tcp_unlink_write_queue(skb, sk); - tcp_wmem_free_skb(sk, skb); + tcp_eat_one_skb(sk, nskb, skb); } else { TCP_SKB_CB(nskb)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags & ~(TCPHDR_FIN|TCPHDR_PSH); - if (!skb_shinfo(skb)->nr_frags) { - skb_pull(skb, copy); - } else { - __pskb_trim_head(skb, copy); - tcp_set_skb_tso_segs(skb, mss_now); - } + __pskb_trim_head(skb, copy); + tcp_set_skb_tso_segs(skb, mss_now); TCP_SKB_CB(skb)->seq += copy; } @@ -2441,7 +2715,7 @@ static int tcp_mtu_probe(struct sock *sk) if (!tcp_transmit_skb(sk, nskb, 1, GFP_ATOMIC)) { /* Decrement cwnd here because we are sending * effectively two packets. */ - tp->snd_cwnd--; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - 1); tcp_event_new_data_sent(sk, nskb); icsk->icsk_mtup.probe_size = tcp_mss_to_mtu(sk, nskb->len); @@ -2473,6 +2747,18 @@ static bool tcp_pacing_check(struct sock *sk) return true; } +static bool tcp_rtx_queue_empty_or_single_skb(const struct sock *sk) +{ + const struct rb_node *node = sk->tcp_rtx_queue.rb_node; + + /* No skb in the rtx queue. */ + if (!node) + return true; + + /* Only one skb in rtx queue. */ + return !node->rb_left && !node->rb_right; +} + /* TCP Small Queues : * Control number of packets in qdisc/devices to two packets / or ~1 ms. * (These limits are doubled for retransmits) @@ -2491,15 +2777,15 @@ static bool tcp_small_queue_check(struct sock *sk, const struct sk_buff *skb, limit = max_t(unsigned long, 2 * skb->truesize, - sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift)); - if (sk->sk_pacing_status == SK_PACING_NONE) - limit = min_t(unsigned long, limit, - sock_net(sk)->ipv4.sysctl_tcp_limit_output_bytes); + READ_ONCE(sk->sk_pacing_rate) >> READ_ONCE(sk->sk_pacing_shift)); + limit = min_t(unsigned long, limit, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_limit_output_bytes)); limit <<= factor; if (static_branch_unlikely(&tcp_tx_delay_enabled) && tcp_sk(sk)->tcp_tx_delay) { - u64 extra_bytes = (u64)sk->sk_pacing_rate * tcp_sk(sk)->tcp_tx_delay; + u64 extra_bytes = (u64)READ_ONCE(sk->sk_pacing_rate) * + tcp_sk(sk)->tcp_tx_delay; /* TSQ is based on skb truesize sum (sk_wmem_alloc), so we * approximate our needs assuming an ~100% skb->truesize overhead. @@ -2510,12 +2796,12 @@ static bool tcp_small_queue_check(struct sock *sk, const struct sk_buff *skb, limit += extra_bytes; } if (refcount_read(&sk->sk_wmem_alloc) > limit) { - /* Always send skb if rtx queue is empty. + /* Always send skb if rtx queue is empty or has one skb. * No need to wait for TX completion to call us back, - * after softirq/tasklet schedule. + * after softirq schedule. * This helps when TX completions are delayed too much. */ - if (tcp_rtx_queue_empty(sk)) + if (tcp_rtx_queue_empty_or_single_skb(sk)) return false; set_bit(TSQ_THROTTLED, &sk->sk_tsq_flags); @@ -2572,6 +2858,35 @@ void tcp_chrono_stop(struct sock *sk, const enum tcp_chrono type) tcp_chrono_set(tp, TCP_CHRONO_BUSY); } +/* First skb in the write queue is smaller than ideal packet size. + * Check if we can move payload from the second skb in the queue. + */ +static void tcp_grow_skb(struct sock *sk, struct sk_buff *skb, int amount) +{ + struct sk_buff *next_skb = skb->next; + unsigned int nlen; + + if (tcp_skb_is_last(sk, skb)) + return; + + if (!tcp_skb_can_collapse(skb, next_skb)) + return; + + nlen = min_t(u32, amount, next_skb->len); + if (!nlen || !skb_shift(skb, next_skb, nlen)) + return; + + TCP_SKB_CB(skb)->end_seq += nlen; + TCP_SKB_CB(next_skb)->seq += nlen; + + if (!next_skb->len) { + /* In case FIN is set, we need to update end_seq */ + TCP_SKB_CB(skb)->end_seq = TCP_SKB_CB(next_skb)->end_seq; + + tcp_eat_one_skb(sk, skb, next_skb); + } +} + /* This routine writes packets to the network. It advances the * send_head. This happens as incoming acks open up the remote * window for us. @@ -2592,14 +2907,18 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *skb; unsigned int tso_segs, sent_pkts; - int cwnd_quota; + u32 cwnd_quota, max_segs; int result; bool is_cwnd_limited = false, is_rwnd_limited = false; - u32 max_segs; sent_pkts = 0; tcp_mstamp_refresh(tp); + + /* AccECN option beacon depends on mstamp, it may change mss */ + if (tcp_ecn_mode_accecn(tp) && tcp_accecn_option_beacon_check(sk)) + mss_now = tcp_current_mss(sk); + if (!push_one) { /* Do MTU probing. */ result = tcp_mtu_probe(sk); @@ -2613,10 +2932,12 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, max_segs = tcp_tso_segs(sk, mss_now); while ((skb = tcp_send_head(sk))) { unsigned int limit; + int missing_bytes; if (unlikely(tp->repair) && tp->repair_queue == TCP_SEND_QUEUE) { /* "skb_mstamp_ns" is used as a start point for the retransmit timer */ - skb->skb_mstamp_ns = tp->tcp_wstamp_ns = tp->tcp_clock_cache; + tp->tcp_wstamp_ns = tp->tcp_clock_cache; + skb_set_delivery_time(skb, tp->tcp_wstamp_ns, SKB_CLOCK_MONOTONIC); list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); tcp_init_tso_segs(skb, mss_now); goto repair; /* Skip network transmission */ @@ -2625,10 +2946,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, if (tcp_pacing_check(sk)) break; - tso_segs = tcp_init_tso_segs(skb, mss_now); - BUG_ON(!tso_segs); - - cwnd_quota = tcp_cwnd_test(tp, skb); + cwnd_quota = tcp_cwnd_test(tp); if (!cwnd_quota) { if (push_one == 2) /* Force out a loss probe pkt. */ @@ -2636,6 +2954,12 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, else break; } + cwnd_quota = min(cwnd_quota, max_segs); + missing_bytes = cwnd_quota * mss_now - skb->len; + if (missing_bytes > 0) + tcp_grow_skb(sk, skb, missing_bytes); + + tso_segs = tcp_set_skb_tso_segs(skb, mss_now); if (unlikely(!tcp_snd_wnd_test(tp, skb, mss_now))) { is_rwnd_limited = true; @@ -2657,9 +2981,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, limit = mss_now; if (tso_segs > 1 && !tcp_urg_mode(tp)) limit = tcp_mss_split_point(sk, skb, mss_now, - min_t(unsigned int, - cwnd_quota, - max_segs), + cwnd_quota, nonagle); if (skb->len > limit && @@ -2698,7 +3020,7 @@ repair: else tcp_chrono_stop(sk, TCP_CHRONO_RWND_LIMITED); - is_cwnd_limited |= (tcp_packets_in_flight(tp) >= tp->snd_cwnd); + is_cwnd_limited |= (tcp_packets_in_flight(tp) >= tcp_snd_cwnd(tp)); if (likely(sent_pkts || is_cwnd_limited)) tcp_cwnd_validate(sk, is_cwnd_limited); @@ -2718,7 +3040,7 @@ bool tcp_schedule_loss_probe(struct sock *sk, bool advancing_rto) { struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); - u32 timeout, rto_delta_us; + u32 timeout, timeout_us, rto_delta_us; int early_retrans; /* Don't do any loss probe on a Fast Open connection before 3WHS @@ -2727,7 +3049,7 @@ bool tcp_schedule_loss_probe(struct sock *sk, bool advancing_rto) if (rcu_access_pointer(tp->fastopen_rsk)) return false; - early_retrans = sock_net(sk)->ipv4.sysctl_tcp_early_retrans; + early_retrans = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_early_retrans); /* Schedule a loss probe in 2*RTT for SACK capable connections * not in loss recovery, that are either limited by cwnd or application. */ @@ -2742,11 +3064,12 @@ bool tcp_schedule_loss_probe(struct sock *sk, bool advancing_rto) * sample is available then probe after TCP_TIMEOUT_INIT. */ if (tp->srtt_us) { - timeout = usecs_to_jiffies(tp->srtt_us >> 2); + timeout_us = tp->srtt_us >> 2; if (tp->packets_out == 1) - timeout += TCP_RTO_MIN; + timeout_us += tcp_rto_min_us(sk); else - timeout += TCP_TIMEOUT_MIN; + timeout_us += TCP_TIMEOUT_MIN_US; + timeout = usecs_to_jiffies(timeout_us); } else { timeout = TCP_TIMEOUT_INIT; } @@ -2758,7 +3081,7 @@ bool tcp_schedule_loss_probe(struct sock *sk, bool advancing_rto) if (rto_delta_us > 0) timeout = min_t(u32, timeout, usecs_to_jiffies(rto_delta_us)); - tcp_reset_xmit_timer(sk, ICSK_TIME_LOSS_PROBE, timeout, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_LOSS_PROBE, timeout, true); return true; } @@ -2806,10 +3129,8 @@ void tcp_send_loss_probe(struct sock *sk) } skb = skb_rb_last(&sk->tcp_rtx_queue); if (unlikely(!skb)) { - WARN_ONCE(tp->packets_out, - "invalid inflight: %u state %u cwnd %u mss %d\n", - tp->packets_out, sk->sk_state, tp->snd_cwnd, mss); - inet_csk(sk)->icsk_pending = 0; + tcp_warn_once(sk, tp->packets_out, "invalid inflight: "); + smp_store_release(&inet_csk(sk)->icsk_pending, 0); return; } @@ -2842,7 +3163,7 @@ probe_sent: NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPLOSSPROBES); /* Reset s.t. tcp_rearm_rto will restart timer from now */ - inet_csk(sk)->icsk_pending = 0; + smp_store_release(&inet_csk(sk)->icsk_pending, 0); rearm_timer: tcp_rearm_rto(sk); } @@ -2934,6 +3255,7 @@ u32 __tcp_select_window(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); /* MSS for the peer's data. Previous versions used mss_clamp * here. I don't know if the value based on our guesses * of peer's MSS is better for the performance. It's more correct @@ -2955,6 +3277,15 @@ u32 __tcp_select_window(struct sock *sk) if (mss <= 0) return 0; } + + /* Only allow window shrink if the sysctl is enabled and we have + * a non-zero scaling factor in effect. + */ + if (READ_ONCE(net->ipv4.sysctl_tcp_shrink_window) && tp->rx_opt.rcv_wscale) + goto shrink_window_allowed; + + /* do not allow window to shrink */ + if (free_space < (full_space >> 1)) { icsk->icsk_ack.quick = 0; @@ -3009,6 +3340,36 @@ u32 __tcp_select_window(struct sock *sk) } return window; + +shrink_window_allowed: + /* new window should always be an exact multiple of scaling factor */ + free_space = round_down(free_space, 1 << tp->rx_opt.rcv_wscale); + + if (free_space < (full_space >> 1)) { + icsk->icsk_ack.quick = 0; + + if (tcp_under_memory_pressure(sk)) + tcp_adjust_rcv_ssthresh(sk); + + /* if free space is too low, return a zero window */ + if (free_space < (allowed_space >> 4) || free_space < mss || + free_space < (1 << tp->rx_opt.rcv_wscale)) + return 0; + } + + if (free_space > tp->rcv_ssthresh) { + free_space = tp->rcv_ssthresh; + /* new window should always be an exact multiple of scaling factor + * + * For this case, we ALIGN "up" (increase free_space) because + * we know free_space is not zero here, it has been reduced from + * the memory-based limit, and rcv_ssthresh is not a hard limit + * (unlike sk_rcvbuf). + */ + free_space = ALIGN(free_space, (1 << tp->rx_opt.rcv_wscale)); + } + + return free_space; } void tcp_skb_collapse_tstamp(struct sk_buff *skb, @@ -3055,7 +3416,6 @@ static bool tcp_collapse_retrans(struct sock *sk, struct sk_buff *skb) TCP_SKB_CB(skb)->eor = TCP_SKB_CB(next_skb)->eor; /* changed transmit queue under us so clear hints */ - tcp_clear_retrans_hints_partial(tp); if (next_skb == tp->retransmit_skb_hint) tp->retransmit_skb_hint = skb; @@ -3074,6 +3434,8 @@ static bool tcp_can_collapse(const struct sock *sk, const struct sk_buff *skb) return false; if (skb_cloned(skb)) return false; + if (!skb_frags_readable(skb)) + return false; /* Some heuristics for collapsing over SACK'd could be invented */ if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) return false; @@ -3091,7 +3453,7 @@ static void tcp_retrans_try_collapse(struct sock *sk, struct sk_buff *to, struct sk_buff *skb = to, *tmp; bool first = true; - if (!sock_net(sk)->ipv4.sysctl_tcp_retrans_collapse) + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_retrans_collapse)) return; if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN) return; @@ -3131,57 +3493,88 @@ int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) struct tcp_sock *tp = tcp_sk(sk); unsigned int cur_mss; int diff, len, err; - + int avail_wnd; /* Inconclusive MTU probe */ if (icsk->icsk_mtup.probe_size) icsk->icsk_mtup.probe_size = 0; - if (skb_still_in_host_queue(sk, skb)) - return -EBUSY; + if (skb_still_in_host_queue(sk, skb)) { + err = -EBUSY; + goto out; + } +start: if (before(TCP_SKB_CB(skb)->seq, tp->snd_una)) { + if (unlikely(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN)) { + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_SYN; + TCP_SKB_CB(skb)->seq++; + goto start; + } if (unlikely(before(TCP_SKB_CB(skb)->end_seq, tp->snd_una))) { WARN_ON_ONCE(1); - return -EINVAL; + err = -EINVAL; + goto out; + } + if (tcp_trim_head(sk, skb, tp->snd_una - TCP_SKB_CB(skb)->seq)) { + err = -ENOMEM; + goto out; } - if (tcp_trim_head(sk, skb, tp->snd_una - TCP_SKB_CB(skb)->seq)) - return -ENOMEM; } - if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk)) - return -EHOSTUNREACH; /* Routing failure or similar. */ + if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk)) { + err = -EHOSTUNREACH; /* Routing failure or similar. */ + goto out; + } cur_mss = tcp_current_mss(sk); + avail_wnd = tcp_wnd_end(tp) - TCP_SKB_CB(skb)->seq; /* If receiver has shrunk his window, and skb is out of * new window, do not retransmit it. The exception is the * case, when window is shrunk to zero. In this case - * our retransmit serves as a zero window probe. + * our retransmit of one segment serves as a zero window probe. */ - if (!before(TCP_SKB_CB(skb)->seq, tcp_wnd_end(tp)) && - TCP_SKB_CB(skb)->seq != tp->snd_una) - return -EAGAIN; + if (avail_wnd <= 0) { + if (TCP_SKB_CB(skb)->seq != tp->snd_una) { + err = -EAGAIN; + goto out; + } + avail_wnd = cur_mss; + } len = cur_mss * segs; + if (len > avail_wnd) { + len = rounddown(avail_wnd, cur_mss); + if (!len) + len = avail_wnd; + } if (skb->len > len) { if (tcp_fragment(sk, TCP_FRAG_IN_RTX_QUEUE, skb, len, - cur_mss, GFP_ATOMIC)) - return -ENOMEM; /* We'll try again later. */ + cur_mss, GFP_ATOMIC)) { + err = -ENOMEM; /* We'll try again later. */ + goto out; + } } else { - if (skb_unclone_keeptruesize(skb, GFP_ATOMIC)) - return -ENOMEM; + if (skb_unclone_keeptruesize(skb, GFP_ATOMIC)) { + err = -ENOMEM; + goto out; + } diff = tcp_skb_pcount(skb); tcp_set_skb_tso_segs(skb, cur_mss); diff -= tcp_skb_pcount(skb); if (diff) tcp_adjust_pcount(sk, skb, diff); - if (skb->len < cur_mss) - tcp_retrans_try_collapse(sk, skb, cur_mss); + avail_wnd = min_t(int, avail_wnd, cur_mss); + if (skb->len < avail_wnd) + tcp_retrans_try_collapse(sk, skb, avail_wnd); } - /* RFC3168, section 6.1.1.1. ECN fallback */ + /* RFC3168, section 6.1.1.1. ECN fallback + * As AccECN uses the same SYN flags (+ AE), this check covers both + * cases. + */ if ((TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN_ECN) == TCPHDR_SYN_ECN) tcp_ecn_clear_syn(sk, skb); @@ -3219,20 +3612,20 @@ int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) err = tcp_transmit_skb(sk, skb, 1, GFP_ATOMIC); } - /* To avoid taking spuriously low RTT samples based on a timestamp - * for a transmit that never happened, always mark EVER_RETRANS - */ - TCP_SKB_CB(skb)->sacked |= TCPCB_EVER_RETRANS; - if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RETRANS_CB_FLAG)) tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RETRANS_CB, TCP_SKB_CB(skb)->seq, segs, err); - if (likely(!err)) { - trace_tcp_retransmit_skb(sk, skb); - } else if (err != -EBUSY) { + if (unlikely(err) && err != -EBUSY) NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPRETRANSFAIL, segs); - } + + /* To avoid taking spuriously low RTT samples based on a timestamp + * for a transmit that never happened, always mark EVER_RETRANS + */ + TCP_SKB_CB(skb)->sacked |= TCPCB_EVER_RETRANS; + +out: + trace_tcp_retransmit_skb(sk, skb, err); return err; } @@ -3253,7 +3646,7 @@ int tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) /* Save stamp of the first (attempted) retransmit. */ if (!tp->retrans_stamp) - tp->retrans_stamp = tcp_skb_timestamp(skb); + tp->retrans_stamp = tcp_skb_timestamp_ts(tp->tcp_usec_ts, skb); if (tp->undo_retrans < 0) tp->undo_retrans = 0; @@ -3292,7 +3685,7 @@ void tcp_xmit_retransmit_queue(struct sock *sk) if (!hole) tp->retransmit_skb_hint = skb; - segs = tp->snd_cwnd - tcp_packets_in_flight(tp); + segs = tcp_snd_cwnd(tp) - tcp_packets_in_flight(tp); if (segs <= 0) break; sacked = TCP_SKB_CB(skb)->sacked; @@ -3336,8 +3729,7 @@ void tcp_xmit_retransmit_queue(struct sock *sk) } if (rearm_timer) tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - inet_csk(sk)->icsk_rto, - TCP_RTO_MAX); + inet_csk(sk)->icsk_rto, true); } /* We allow to exceed memory limits for FIN packets to expedite @@ -3349,17 +3741,22 @@ void tcp_xmit_retransmit_queue(struct sock *sk) */ void sk_forced_mem_schedule(struct sock *sk, int size) { - int amt; + int delta, amt; - if (size <= sk->sk_forward_alloc) + delta = size - sk->sk_forward_alloc; + if (delta <= 0) + return; + + amt = sk_mem_pages(delta); + sk_forward_alloc_add(sk, amt << PAGE_SHIFT); + + if (mem_cgroup_sk_enabled(sk)) + mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL); + + if (sk->sk_bypass_prot_mem) return; - amt = sk_mem_pages(size); - sk->sk_forward_alloc += amt * SK_MEM_QUANTUM; - sk_memory_allocated_add(sk, amt); - if (mem_cgroup_sockets_enabled && sk->sk_memcg) - mem_cgroup_charge_skmem(sk->sk_memcg, amt, - gfp_memcg_charge() | __GFP_NOFAIL); + sk_memory_allocated_add(sk, amt); } /* Send a FIN. The caller locks the socket for us. @@ -3394,7 +3791,9 @@ void tcp_send_fin(struct sock *sk) return; } } else { - skb = alloc_skb_fclone(MAX_TCP_HEADER, sk->sk_allocation); + skb = alloc_skb_fclone(MAX_TCP_HEADER, + sk_gfp_mask(sk, GFP_ATOMIC | + __GFP_NOWARN)); if (unlikely(!skb)) return; @@ -3402,7 +3801,7 @@ void tcp_send_fin(struct sock *sk) skb_reserve(skb, MAX_TCP_HEADER); sk_forced_mem_schedule(sk, skb->truesize); /* FIN eats a sequence byte, write_seq advanced by tcp_queue_skb(). */ - tcp_init_nondata_skb(skb, tp->write_seq, + tcp_init_nondata_skb(skb, sk, tp->write_seq, TCPHDR_ACK | TCPHDR_FIN); tcp_queue_skb(sk, skb); } @@ -3414,7 +3813,8 @@ void tcp_send_fin(struct sock *sk) * was unread data in the receive queue. This behavior is recommended * by RFC 2525, section 2.17. -DaveM */ -void tcp_send_active_reset(struct sock *sk, gfp_t priority) +void tcp_send_active_reset(struct sock *sk, gfp_t priority, + enum sk_rst_reason reason) { struct sk_buff *skb; @@ -3429,7 +3829,7 @@ void tcp_send_active_reset(struct sock *sk, gfp_t priority) /* Reserve space for headers and prepare control bits. */ skb_reserve(skb, MAX_TCP_HEADER); - tcp_init_nondata_skb(skb, tcp_acceptable_seq(sk), + tcp_init_nondata_skb(skb, sk, tcp_acceptable_seq(sk), TCPHDR_ACK | TCPHDR_RST); tcp_mstamp_refresh(tcp_sk(sk)); /* Send it off. */ @@ -3439,7 +3839,7 @@ void tcp_send_active_reset(struct sock *sk, gfp_t priority) /* skb of trace_tcp_send_reset() keeps the skb that caused RST, * skb here is different to the troublesome skb, so use NULL */ - trace_tcp_send_reset(sk, NULL); + trace_tcp_send_reset(sk, NULL, reason); } /* Send a crossed SYN-ACK during socket establishment. @@ -3500,8 +3900,8 @@ struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst, { struct inet_request_sock *ireq = inet_rsk(req); const struct tcp_sock *tp = tcp_sk(sk); - struct tcp_md5sig_key *md5 = NULL; struct tcp_out_options opts; + struct tcp_key key = {}; struct sk_buff *skb; int tcp_header_size; struct tcphdr *th; @@ -3518,7 +3918,7 @@ struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst, switch (synack_type) { case TCP_SYNACK_NORMAL: - skb_set_owner_w(skb, req_to_sk(req)); + skb_set_owner_edemux(skb, req_to_sk(req)); break; case TCP_SYNACK_COOKIE: /* Under synflood, we do not attach skb to a socket, @@ -3541,25 +3941,57 @@ struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst, now = tcp_clock_ns(); #ifdef CONFIG_SYN_COOKIES if (unlikely(synack_type == TCP_SYNACK_COOKIE && ireq->tstamp_ok)) - skb->skb_mstamp_ns = cookie_init_timestamp(req, now); + skb_set_delivery_time(skb, cookie_init_timestamp(req, now), + SKB_CLOCK_MONOTONIC); else #endif { - skb->skb_mstamp_ns = now; + skb_set_delivery_time(skb, now, SKB_CLOCK_MONOTONIC); if (!tcp_rsk(req)->snt_synack) /* Timestamp first SYNACK */ tcp_rsk(req)->snt_synack = tcp_skb_timestamp_us(skb); } -#ifdef CONFIG_TCP_MD5SIG +#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO) rcu_read_lock(); - md5 = tcp_rsk(req)->af_specific->req_md5_lookup(sk, req_to_sk(req)); #endif - skb_set_hash(skb, tcp_rsk(req)->txhash, PKT_HASH_TYPE_L4); + if (tcp_rsk_used_ao(req)) { +#ifdef CONFIG_TCP_AO + struct tcp_ao_key *ao_key = NULL; + u8 keyid = tcp_rsk(req)->ao_keyid; + u8 rnext = tcp_rsk(req)->ao_rcv_next; + + ao_key = tcp_sk(sk)->af_specific->ao_lookup(sk, req_to_sk(req), + keyid, -1); + /* If there is no matching key - avoid sending anything, + * especially usigned segments. It could try harder and lookup + * for another peer-matching key, but the peer has requested + * ao_keyid (RFC5925 RNextKeyID), so let's keep it simple here. + */ + if (unlikely(!ao_key)) { + trace_tcp_ao_synack_no_key(sk, keyid, rnext); + rcu_read_unlock(); + kfree_skb(skb); + net_warn_ratelimited("TCP-AO: the keyid %u from SYN packet is not present - not sending SYNACK\n", + keyid); + return NULL; + } + key.ao_key = ao_key; + key.type = TCP_KEY_AO; +#endif + } else { +#ifdef CONFIG_TCP_MD5SIG + key.md5_key = tcp_rsk(req)->af_specific->req_md5_lookup(sk, + req_to_sk(req)); + if (key.md5_key) + key.type = TCP_KEY_MD5; +#endif + } + skb_set_hash(skb, READ_ONCE(tcp_rsk(req)->txhash), PKT_HASH_TYPE_L4); /* bpf program will be interested in the tcp_flags */ TCP_SKB_CB(skb)->tcp_flags = TCPHDR_SYN | TCPHDR_ACK; - tcp_header_size = tcp_synack_options(sk, req, mss, skb, &opts, md5, - foc, synack_type, - syn_skb) + sizeof(*th); + tcp_header_size = tcp_synack_options(sk, req, mss, skb, &opts, + &key, foc, synack_type, syn_skb) + + sizeof(*th); skb_push(skb, tcp_header_size); skb_reset_transport_header(skb); @@ -3579,27 +4011,36 @@ struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst, /* RFC1323: The window in SYN & SYN/ACK segments is never scaled. */ th->window = htons(min(req->rsk_rcv_wnd, 65535U)); - tcp_options_write((__be32 *)(th + 1), NULL, &opts); + tcp_options_write(th, NULL, tcp_rsk(req), &opts, &key); th->doff = (tcp_header_size >> 2); - __TCP_INC_STATS(sock_net(sk), TCP_MIB_OUTSEGS); + TCP_INC_STATS(sock_net(sk), TCP_MIB_OUTSEGS); -#ifdef CONFIG_TCP_MD5SIG /* Okay, we have all we need - do the md5 hash if needed */ - if (md5) + if (tcp_key_is_md5(&key)) { +#ifdef CONFIG_TCP_MD5SIG tcp_rsk(req)->af_specific->calc_md5_hash(opts.hash_location, - md5, req_to_sk(req), skb); + key.md5_key, req_to_sk(req), skb); +#endif + } else if (tcp_key_is_ao(&key)) { +#ifdef CONFIG_TCP_AO + tcp_rsk(req)->af_specific->ao_synack_hash(opts.hash_location, + key.ao_key, req, skb, + opts.hash_location - (u8 *)th, 0); +#endif + } +#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO) rcu_read_unlock(); #endif bpf_skops_write_hdr_opt((struct sock *)sk, skb, req, syn_skb, synack_type, &opts); - skb->skb_mstamp_ns = now; + skb_set_delivery_time(skb, now, SKB_CLOCK_MONOTONIC); tcp_add_tx_delay(skb, tp); return skb; } -EXPORT_SYMBOL(tcp_make_synack); +EXPORT_IPV6_MOD(tcp_make_synack); static void tcp_ca_dst_init(struct sock *sk, const struct dst_entry *dst) { @@ -3626,23 +4067,22 @@ static void tcp_connect_init(struct sock *sk) const struct dst_entry *dst = __sk_dst_get(sk); struct tcp_sock *tp = tcp_sk(sk); __u8 rcv_wscale; + u16 user_mss; u32 rcv_wnd; /* We'll fix this up when we get a response from the other end. * See tcp_input.c:tcp_rcv_state_process case TCP_SYN_SENT. */ tp->tcp_header_len = sizeof(struct tcphdr); - if (sock_net(sk)->ipv4.sysctl_tcp_timestamps) + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_timestamps)) tp->tcp_header_len += TCPOLEN_TSTAMP_ALIGNED; -#ifdef CONFIG_TCP_MD5SIG - if (tp->af_specific->md5_lookup(sk, sk)) - tp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED; -#endif + tcp_ao_connect_init(sk); /* If user gave his TCP_MAXSEG, record it to clamp */ - if (tp->rx_opt.user_mss) - tp->rx_opt.mss_clamp = tp->rx_opt.user_mss; + user_mss = READ_ONCE(tp->rx_opt.user_mss); + if (user_mss) + tp->rx_opt.mss_clamp = user_mss; tp->max_window = 0; tcp_mtup_init(sk); tcp_sync_mss(sk, dst_mtu(dst)); @@ -3650,7 +4090,7 @@ static void tcp_connect_init(struct sock *sk) tcp_ca_dst_init(sk, dst); if (!tp->window_clamp) - tp->window_clamp = dst_metric(dst, RTAX_WINDOW); + WRITE_ONCE(tp->window_clamp, dst_metric(dst, RTAX_WINDOW)); tp->advmss = tcp_mss_clamp(tp, dst_metric_advmss(dst)); tcp_initialize_rcv_mss(sk); @@ -3658,7 +4098,7 @@ static void tcp_connect_init(struct sock *sk) /* limit the window selection if the user enforce a smaller rx buffer */ if (sk->sk_userlocks & SOCK_RCVBUF_LOCK && (tp->window_clamp > tcp_full_space(sk) || tp->window_clamp == 0)) - tp->window_clamp = tcp_full_space(sk); + WRITE_ONCE(tp->window_clamp, tcp_full_space(sk)); rcv_wnd = tcp_rwnd_init_bpf(sk); if (rcv_wnd == 0) @@ -3668,14 +4108,14 @@ static void tcp_connect_init(struct sock *sk) tp->advmss - (tp->rx_opt.ts_recent_stamp ? tp->tcp_header_len - sizeof(struct tcphdr) : 0), &tp->rcv_wnd, &tp->window_clamp, - sock_net(sk)->ipv4.sysctl_tcp_window_scaling, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_window_scaling), &rcv_wscale, rcv_wnd); tp->rx_opt.rcv_wscale = rcv_wscale; tp->rcv_ssthresh = tp->rcv_wnd; - sk->sk_err = 0; + WRITE_ONCE(sk->sk_err, 0); sock_reset_flag(sk, SOCK_DONE); tp->snd_wnd = 0; tcp_init_wl(tp, 0); @@ -3693,7 +4133,7 @@ static void tcp_connect_init(struct sock *sk) WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); inet_csk(sk)->icsk_rto = tcp_timeout_init(sk); - inet_csk(sk)->icsk_retransmits = 0; + WRITE_ONCE(inet_csk(sk)->icsk_retransmits, 0); tcp_clear_retrans(tp); } @@ -3719,10 +4159,12 @@ static void tcp_connect_queue_skb(struct sock *sk, struct sk_buff *skb) */ static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn) { + struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); struct tcp_fastopen_request *fo = tp->fastopen_req; - int space, err = 0; + struct page_frag *pfrag = sk_page_frag(sk); struct sk_buff *syn_data; + int space, err = 0; tp->rx_opt.mss_clamp = tp->advmss; /* If MSS is not cached */ if (!tcp_fastopen_cookie_check(sk, &tp->rx_opt.mss_clamp, &fo->cookie)) @@ -3733,31 +4175,39 @@ static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn) * private TCP options. The cost is reduced data space in SYN :( */ tp->rx_opt.mss_clamp = tcp_mss_clamp(tp, tp->rx_opt.mss_clamp); + /* Sync mss_cache after updating the mss_clamp */ + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); - space = __tcp_mtu_to_mss(sk, inet_csk(sk)->icsk_pmtu_cookie) - + space = __tcp_mtu_to_mss(sk, icsk->icsk_pmtu_cookie) - MAX_TCP_OPTION_SPACE; space = min_t(size_t, space, fo->size); - /* limit to order-0 allocations */ - space = min_t(size_t, space, SKB_MAX_HEAD(MAX_TCP_HEADER)); - - syn_data = tcp_stream_alloc_skb(sk, space, sk->sk_allocation, false); + if (space && + !skb_page_frag_refill(min_t(size_t, space, PAGE_SIZE), + pfrag, sk->sk_allocation)) + goto fallback; + syn_data = tcp_stream_alloc_skb(sk, sk->sk_allocation, false); if (!syn_data) goto fallback; memcpy(syn_data->cb, syn->cb, sizeof(syn->cb)); if (space) { - int copied = copy_from_iter(skb_put(syn_data, space), space, - &fo->data->msg_iter); - if (unlikely(!copied)) { + space = min_t(size_t, space, pfrag->size - pfrag->offset); + space = tcp_wmem_schedule(sk, space); + } + if (space) { + space = copy_page_from_iter(pfrag->page, pfrag->offset, + space, &fo->data->msg_iter); + if (unlikely(!space)) { tcp_skb_tsorted_anchor_cleanup(syn_data); kfree_skb(syn_data); goto fallback; } - if (copied != space) { - skb_trim(syn_data, copied); - space = copied; - } + skb_fill_page_desc(syn_data, 0, pfrag->page, + pfrag->offset, space); + page_ref_inc(pfrag->page); + pfrag->offset += space; + skb_len_add(syn_data, space); skb_zcopy_set(syn_data, fo->uarg, NULL); } /* No more data pending in inet_wait_for_connect() */ @@ -3771,7 +4221,7 @@ static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn) err = tcp_transmit_skb(sk, syn_data, 1, sk->sk_allocation); - syn->skb_mstamp_ns = syn_data->skb_mstamp_ns; + skb_set_delivery_time(syn, syn_data->skb_mstamp_ns, SKB_CLOCK_MONOTONIC); /* Now full SYN+DATA was cloned and sent (or not), * remove the SYN from the original skb (syn_data) @@ -3812,6 +4262,53 @@ int tcp_connect(struct sock *sk) tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_CONNECT_CB, 0, NULL); +#if defined(CONFIG_TCP_MD5SIG) && defined(CONFIG_TCP_AO) + /* Has to be checked late, after setting daddr/saddr/ops. + * Return error if the peer has both a md5 and a tcp-ao key + * configured as this is ambiguous. + */ + if (unlikely(rcu_dereference_protected(tp->md5sig_info, + lockdep_sock_is_held(sk)))) { + bool needs_ao = !!tp->af_specific->ao_lookup(sk, sk, -1, -1); + bool needs_md5 = !!tp->af_specific->md5_lookup(sk, sk); + struct tcp_ao_info *ao_info; + + ao_info = rcu_dereference_check(tp->ao_info, + lockdep_sock_is_held(sk)); + if (ao_info) { + /* This is an extra check: tcp_ao_required() in + * tcp_v{4,6}_parse_md5_keys() should prevent adding + * md5 keys on ao_required socket. + */ + needs_ao |= ao_info->ao_required; + WARN_ON_ONCE(ao_info->ao_required && needs_md5); + } + if (needs_md5 && needs_ao) + return -EKEYREJECTED; + + /* If we have a matching md5 key and no matching tcp-ao key + * then free up ao_info if allocated. + */ + if (needs_md5) { + tcp_ao_destroy_sock(sk, false); + } else if (needs_ao) { + tcp_clear_md5_list(sk); + kfree(rcu_replace_pointer(tp->md5sig_info, NULL, + lockdep_sock_is_held(sk))); + } + } +#endif +#ifdef CONFIG_TCP_AO + if (unlikely(rcu_dereference_protected(tp->ao_info, + lockdep_sock_is_held(sk)))) { + /* Don't allow connecting if ao is configured but no + * matching key is found. + */ + if (!tp->af_specific->ao_lookup(sk, sk, -1, -1)) + return -EKEYREJECTED; + } +#endif + if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk)) return -EHOSTUNREACH; /* Routing failure or similar. */ @@ -3822,13 +4319,16 @@ int tcp_connect(struct sock *sk) return 0; } - buff = tcp_stream_alloc_skb(sk, 0, sk->sk_allocation, true); + buff = tcp_stream_alloc_skb(sk, sk->sk_allocation, true); if (unlikely(!buff)) return -ENOBUFS; - tcp_init_nondata_skb(buff, tp->write_seq++, TCPHDR_SYN); + /* SYN eats a sequence byte, write_seq updated by + * tcp_connect_queue_skb(). + */ + tcp_init_nondata_skb(buff, sk, tp->write_seq, TCPHDR_SYN); tcp_mstamp_refresh(tp); - tp->retrans_stamp = tcp_time_stamp(tp); + tp->retrans_stamp = tcp_time_stamp_ts(tp); tcp_connect_queue_skb(sk, buff); tcp_ecn_send_syn(sk, buff); tcp_rbtree_insert(&sk->tcp_rtx_queue, buff); @@ -3852,12 +4352,19 @@ int tcp_connect(struct sock *sk) TCP_INC_STATS(sock_net(sk), TCP_MIB_ACTIVEOPENS); /* Timer for repeating the SYN until an answer. */ - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - inet_csk(sk)->icsk_rto, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + inet_csk(sk)->icsk_rto, false); return 0; } EXPORT_SYMBOL(tcp_connect); +u32 tcp_delack_max(const struct sock *sk) +{ + u32 delack_from_rto_min = max(tcp_rto_min(sk), 2) - 1; + + return min(READ_ONCE(inet_csk(sk)->icsk_delack_max), delack_from_rto_min); +} + /* Send out a delayed ack, the caller does the policy checking * to see if we should even be here. See tcp_input.c:tcp_ack_snd_check() * for details. @@ -3893,7 +4400,7 @@ void tcp_send_delayed_ack(struct sock *sk) ato = min(ato, max_ato); } - ato = min_t(u32, ato, inet_csk(sk)->icsk_delack_max); + ato = min_t(u32, ato, tcp_delack_max(sk)); /* Stay within the limit we were given */ timeout = jiffies + ato; @@ -3901,21 +4408,21 @@ void tcp_send_delayed_ack(struct sock *sk) /* Use new timeout only if there wasn't a older one earlier. */ if (icsk->icsk_ack.pending & ICSK_ACK_TIMER) { /* If delack timer is about to expire, send ACK now. */ - if (time_before_eq(icsk->icsk_ack.timeout, jiffies + (ato >> 2))) { + if (time_before_eq(icsk_delack_timeout(icsk), jiffies + (ato >> 2))) { tcp_send_ack(sk); return; } - if (!time_before(timeout, icsk->icsk_ack.timeout)) - timeout = icsk->icsk_ack.timeout; + if (!time_before(timeout, icsk_delack_timeout(icsk))) + timeout = icsk_delack_timeout(icsk); } - icsk->icsk_ack.pending |= ICSK_ACK_SCHED | ICSK_ACK_TIMER; - icsk->icsk_ack.timeout = timeout; + smp_store_release(&icsk->icsk_ack.pending, + icsk->icsk_ack.pending | ICSK_ACK_SCHED | ICSK_ACK_TIMER); sk_reset_timer(sk, &icsk->icsk_delack_timer, timeout); } /* This routine sends an ack and also updates the window. */ -void __tcp_send_ack(struct sock *sk, u32 rcv_nxt) +void __tcp_send_ack(struct sock *sk, u32 rcv_nxt, u16 flags) { struct sk_buff *buff; @@ -3934,17 +4441,18 @@ void __tcp_send_ack(struct sock *sk, u32 rcv_nxt) unsigned long delay; delay = TCP_DELACK_MAX << icsk->icsk_ack.retry; - if (delay < TCP_RTO_MAX) + if (delay < tcp_rto_max(sk)) icsk->icsk_ack.retry++; inet_csk_schedule_ack(sk); icsk->icsk_ack.ato = TCP_ATO_MIN; - inet_csk_reset_xmit_timer(sk, ICSK_TIME_DACK, delay, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_DACK, delay, false); return; } /* Reserve space for headers and prepare control bits. */ skb_reserve(buff, MAX_TCP_HEADER); - tcp_init_nondata_skb(buff, tcp_acceptable_seq(sk), TCPHDR_ACK); + tcp_init_nondata_skb(buff, sk, + tcp_acceptable_seq(sk), TCPHDR_ACK | flags); /* We do not want pure acks influencing TCP Small Queues or fq/pacing * too much. @@ -3959,7 +4467,7 @@ EXPORT_SYMBOL_GPL(__tcp_send_ack); void tcp_send_ack(struct sock *sk) { - __tcp_send_ack(sk, tcp_sk(sk)->rcv_nxt); + __tcp_send_ack(sk, tcp_sk(sk)->rcv_nxt, 0); } /* This routine sends a packet with an out of date sequence @@ -3990,7 +4498,7 @@ static int tcp_xmit_probe_skb(struct sock *sk, int urgent, int mib) * end to send an ack. Don't queue or clone SKB, just * send it. */ - tcp_init_nondata_skb(skb, tp->snd_una - !urgent, TCPHDR_ACK); + tcp_init_nondata_skb(skb, sk, tp->snd_una - !urgent, TCPHDR_ACK); NET_INC_STATS(sock_net(sk), mib); return tcp_transmit_skb(sk, skb, 0, (__force gfp_t)0); } @@ -4064,17 +4572,17 @@ void tcp_send_probe0(struct sock *sk) if (tp->packets_out || tcp_write_queue_empty(sk)) { /* Cancel probe timer, if it is not required. */ - icsk->icsk_probes_out = 0; + WRITE_ONCE(icsk->icsk_probes_out, 0); icsk->icsk_backoff = 0; icsk->icsk_probes_tstamp = 0; return; } - icsk->icsk_probes_out++; + WRITE_ONCE(icsk->icsk_probes_out, icsk->icsk_probes_out + 1); if (err <= 0) { - if (icsk->icsk_backoff < net->ipv4.sysctl_tcp_retries2) + if (icsk->icsk_backoff < READ_ONCE(net->ipv4.sysctl_tcp_retries2)) icsk->icsk_backoff++; - timeout = tcp_probe0_when(sk, TCP_RTO_MAX); + timeout = tcp_probe0_when(sk, tcp_rto_max(sk)); } else { /* If packet was not sent due to local congestion, * Let senders fight for local resources conservatively. @@ -4083,7 +4591,7 @@ void tcp_send_probe0(struct sock *sk) } timeout = tcp_clamp_probe0_to_user_timeout(sk, timeout); - tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, timeout, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, timeout, true); } int tcp_rtx_synack(const struct sock *sk, struct request_sock *req) @@ -4092,16 +4600,24 @@ int tcp_rtx_synack(const struct sock *sk, struct request_sock *req) struct flowi fl; int res; - tcp_rsk(req)->txhash = net_tx_rndhash(); + /* Paired with WRITE_ONCE() in sock_setsockopt() */ + if (READ_ONCE(sk->sk_txrehash) == SOCK_TXREHASH_ENABLED) + WRITE_ONCE(tcp_rsk(req)->txhash, net_tx_rndhash()); res = af_ops->send_synack(sk, NULL, &fl, req, NULL, TCP_SYNACK_NORMAL, NULL); if (!res) { - __TCP_INC_STATS(sock_net(sk), TCP_MIB_RETRANSSEGS); - __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSYNRETRANS); - if (unlikely(tcp_passive_fastopen(sk))) - tcp_sk(sk)->total_retrans++; + TCP_INC_STATS(sock_net(sk), TCP_MIB_RETRANSSEGS); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSYNRETRANS); + if (unlikely(tcp_passive_fastopen(sk))) { + /* sk has const attribute because listeners are lockless. + * However in this case, we are dealing with a passive fastopen + * socket thus we can change total_retrans value. + */ + tcp_sk_rw(sk)->total_retrans++; + } trace_tcp_retransmit_synack(sk, req); + WRITE_ONCE(req->num_retrans, req->num_retrans + 1); } return res; } -EXPORT_SYMBOL(tcp_rtx_synack); +EXPORT_IPV6_MOD(tcp_rtx_synack); diff --git a/net/ipv4/tcp_plb.c b/net/ipv4/tcp_plb.c new file mode 100644 index 000000000000..4bcf7eff95e3 --- /dev/null +++ b/net/ipv4/tcp_plb.c @@ -0,0 +1,109 @@ +/* Protective Load Balancing (PLB) + * + * PLB was designed to reduce link load imbalance across datacenter + * switches. PLB is a host-based optimization; it leverages congestion + * signals from the transport layer to randomly change the path of the + * connection experiencing sustained congestion. PLB prefers to repath + * after idle periods to minimize packet reordering. It repaths by + * changing the IPv6 Flow Label on the packets of a connection, which + * datacenter switches include as part of ECMP/WCMP hashing. + * + * PLB is described in detail in: + * + * Mubashir Adnan Qureshi, Yuchung Cheng, Qianwen Yin, Qiaobin Fu, + * Gautam Kumar, Masoud Moshref, Junhua Yan, Van Jacobson, + * David Wetherall,Abdul Kabbani: + * "PLB: Congestion Signals are Simple and Effective for + * Network Load Balancing" + * In ACM SIGCOMM 2022, Amsterdam Netherlands. + * + */ + +#include <net/tcp.h> + +/* Called once per round-trip to update PLB state for a connection. */ +void tcp_plb_update_state(const struct sock *sk, struct tcp_plb_state *plb, + const int cong_ratio) +{ + struct net *net = sock_net(sk); + + if (!READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled)) + return; + + if (cong_ratio >= 0) { + if (cong_ratio < READ_ONCE(net->ipv4.sysctl_tcp_plb_cong_thresh)) + plb->consec_cong_rounds = 0; + else if (plb->consec_cong_rounds < + READ_ONCE(net->ipv4.sysctl_tcp_plb_rehash_rounds)) + plb->consec_cong_rounds++; + } +} +EXPORT_SYMBOL_GPL(tcp_plb_update_state); + +/* Check whether recent congestion has been persistent enough to warrant + * a load balancing decision that switches the connection to another path. + */ +void tcp_plb_check_rehash(struct sock *sk, struct tcp_plb_state *plb) +{ + struct net *net = sock_net(sk); + u32 max_suspend; + bool forced_rehash = false, idle_rehash = false; + + if (!READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled)) + return; + + forced_rehash = plb->consec_cong_rounds >= + READ_ONCE(net->ipv4.sysctl_tcp_plb_rehash_rounds); + /* If sender goes idle then we check whether to rehash. */ + idle_rehash = READ_ONCE(net->ipv4.sysctl_tcp_plb_idle_rehash_rounds) && + !tcp_sk(sk)->packets_out && + plb->consec_cong_rounds >= + READ_ONCE(net->ipv4.sysctl_tcp_plb_idle_rehash_rounds); + + if (!forced_rehash && !idle_rehash) + return; + + /* Note that tcp_jiffies32 can wrap; we detect wraps by checking for + * cases where the max suspension end is before the actual suspension + * end. We clear pause_until to 0 to indicate there is no recent + * RTO event that constrains PLB rehashing. + */ + max_suspend = 2 * READ_ONCE(net->ipv4.sysctl_tcp_plb_suspend_rto_sec) * HZ; + if (plb->pause_until && + (!before(tcp_jiffies32, plb->pause_until) || + before(tcp_jiffies32 + max_suspend, plb->pause_until))) + plb->pause_until = 0; + + if (plb->pause_until) + return; + + sk_rethink_txhash(sk); + plb->consec_cong_rounds = 0; + tcp_sk(sk)->plb_rehash++; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPPLBREHASH); +} +EXPORT_SYMBOL_GPL(tcp_plb_check_rehash); + +/* Upon RTO, disallow load balancing for a while, to avoid having load + * balancing decisions switch traffic to a black-holed path that was + * previously avoided with a sk_rethink_txhash() call at RTO time. + */ +void tcp_plb_update_state_upon_rto(struct sock *sk, struct tcp_plb_state *plb) +{ + struct net *net = sock_net(sk); + u32 pause; + + if (!READ_ONCE(net->ipv4.sysctl_tcp_plb_enabled)) + return; + + pause = READ_ONCE(net->ipv4.sysctl_tcp_plb_suspend_rto_sec) * HZ; + pause += get_random_u32_below(pause); + plb->pause_until = tcp_jiffies32 + pause; + + /* Reset PLB state upon RTO, since an RTO causes a sk_rethink_txhash() call + * that may switch this connection to a path with completely different + * congestion characteristics. + */ + plb->consec_cong_rounds = 0; +} +EXPORT_SYMBOL_GPL(tcp_plb_update_state_upon_rto); diff --git a/net/ipv4/tcp_rate.c b/net/ipv4/tcp_rate.c index fbab921670cc..a8f6d9d06f2e 100644 --- a/net/ipv4/tcp_rate.c +++ b/net/ipv4/tcp_rate.c @@ -74,27 +74,32 @@ void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb) * * If an ACK (s)acks multiple skbs (e.g., stretched-acks), this function is * called multiple times. We favor the information from the most recently - * sent skb, i.e., the skb with the highest prior_delivered count. + * sent skb, i.e., the skb with the most recently sent time and the highest + * sequence. */ void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, struct rate_sample *rs) { struct tcp_sock *tp = tcp_sk(sk); struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + u64 tx_tstamp; if (!scb->tx.delivered_mstamp) return; + tx_tstamp = tcp_skb_timestamp_us(skb); if (!rs->prior_delivered || - after(scb->tx.delivered, rs->prior_delivered)) { + tcp_skb_sent_after(tx_tstamp, tp->first_tx_mstamp, + scb->end_seq, rs->last_end_seq)) { rs->prior_delivered_ce = scb->tx.delivered_ce; rs->prior_delivered = scb->tx.delivered; rs->prior_mstamp = scb->tx.delivered_mstamp; rs->is_app_limited = scb->tx.is_app_limited; rs->is_retrans = scb->sacked & TCPCB_RETRANS; + rs->last_end_seq = scb->end_seq; /* Record send time of most recently ACKed packet: */ - tp->first_tx_mstamp = tcp_skb_timestamp_us(skb); + tp->first_tx_mstamp = tx_tstamp; /* Find the duration of the "send phase" of this window: */ rs->interval_us = tcp_stamp_us_delta(tp->first_tx_mstamp, scb->tx.first_tx_mstamp); @@ -195,7 +200,7 @@ void tcp_rate_check_app_limited(struct sock *sk) /* Nothing in sending host's qdisc queues or NIC tx queue. */ sk_wmem_alloc_get(sk) < SKB_TRUESIZE(1) && /* We are not limited by CWND. */ - tcp_packets_in_flight(tp) < tp->snd_cwnd && + tcp_packets_in_flight(tp) < tcp_snd_cwnd(tp) && /* All lost packets have been retransmitted. */ tp->lost_out <= tp->retrans_out) tp->app_limited = diff --git a/net/ipv4/tcp_recovery.c b/net/ipv4/tcp_recovery.c index fd113f6226ef..c52fd3254b6e 100644 --- a/net/ipv4/tcp_recovery.c +++ b/net/ipv4/tcp_recovery.c @@ -2,14 +2,9 @@ #include <linux/tcp.h> #include <net/tcp.h> -static bool tcp_rack_sent_after(u64 t1, u64 t2, u32 seq1, u32 seq2) -{ - return t1 > t2 || (t1 == t2 && after(seq1, seq2)); -} - static u32 tcp_rack_reo_wnd(const struct sock *sk) { - struct tcp_sock *tp = tcp_sk(sk); + const struct tcp_sock *tp = tcp_sk(sk); if (!tp->reord_seen) { /* If reordering has not been observed, be aggressive during @@ -19,7 +14,8 @@ static u32 tcp_rack_reo_wnd(const struct sock *sk) return 0; if (tp->sacked_out >= tp->reordering && - !(sock_net(sk)->ipv4.sysctl_tcp_recovery & TCP_RACK_NO_DUPTHRESH)) + !(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_recovery) & + TCP_RACK_NO_DUPTHRESH)) return 0; } @@ -39,7 +35,7 @@ s32 tcp_rack_skb_timeout(struct tcp_sock *tp, struct sk_buff *skb, u32 reo_wnd) tcp_stamp_us_delta(tp->tcp_mstamp, tcp_skb_timestamp_us(skb)); } -/* RACK loss detection (IETF draft draft-ietf-tcpm-rack-01): +/* RACK loss detection (IETF RFC8985): * * Marks a packet lost, if some packet sent later has been (s)acked. * The underlying idea is similar to the traditional dupthresh and FACK @@ -77,9 +73,9 @@ static void tcp_rack_detect_loss(struct sock *sk, u32 *reo_timeout) !(scb->sacked & TCPCB_SACKED_RETRANS)) continue; - if (!tcp_rack_sent_after(tp->rack.mstamp, - tcp_skb_timestamp_us(skb), - tp->rack.end_seq, scb->end_seq)) + if (!tcp_skb_sent_after(tp->rack.mstamp, + tcp_skb_timestamp_us(skb), + tp->rack.end_seq, scb->end_seq)) break; /* A packet is lost if it has not been s/acked beyond @@ -108,7 +104,7 @@ bool tcp_rack_mark_lost(struct sock *sk) tp->rack.advanced = 0; tcp_rack_detect_loss(sk, &timeout); if (timeout) { - timeout = usecs_to_jiffies(timeout) + TCP_TIMEOUT_MIN; + timeout = usecs_to_jiffies(timeout + TCP_TIMEOUT_MIN_US); inet_csk_reset_xmit_timer(sk, ICSK_TIME_REO_TIMEOUT, timeout, inet_csk(sk)->icsk_rto); } @@ -140,8 +136,8 @@ void tcp_rack_advance(struct tcp_sock *tp, u8 sacked, u32 end_seq, } tp->rack.advanced = 1; tp->rack.rtt_us = rtt_us; - if (tcp_rack_sent_after(xmit_time, tp->rack.mstamp, - end_seq, tp->rack.end_seq)) { + if (tcp_skb_sent_after(xmit_time, tp->rack.mstamp, + end_seq, tp->rack.end_seq)) { tp->rack.mstamp = xmit_time; tp->rack.end_seq = end_seq; } @@ -192,7 +188,8 @@ void tcp_rack_update_reo_wnd(struct sock *sk, struct rate_sample *rs) { struct tcp_sock *tp = tcp_sk(sk); - if (sock_net(sk)->ipv4.sysctl_tcp_recovery & TCP_RACK_STATIC_REO_WND || + if ((READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_recovery) & + TCP_RACK_STATIC_REO_WND) || !rs->prior_delivered) return; diff --git a/net/ipv4/tcp_scalable.c b/net/ipv4/tcp_scalable.c index 5842081bc8a2..862b96248a92 100644 --- a/net/ipv4/tcp_scalable.c +++ b/net/ipv4/tcp_scalable.c @@ -27,7 +27,7 @@ static void tcp_scalable_cong_avoid(struct sock *sk, u32 ack, u32 acked) if (!acked) return; } - tcp_cong_avoid_ai(tp, min(tp->snd_cwnd, TCP_SCALABLE_AI_CNT), + tcp_cong_avoid_ai(tp, min(tcp_snd_cwnd(tp), TCP_SCALABLE_AI_CNT), acked); } @@ -35,7 +35,7 @@ static u32 tcp_scalable_ssthresh(struct sock *sk) { const struct tcp_sock *tp = tcp_sk(sk); - return max(tp->snd_cwnd - (tp->snd_cwnd>>TCP_SCALABLE_MD_SCALE), 2U); + return max(tcp_snd_cwnd(tp) - (tcp_snd_cwnd(tp)>>TCP_SCALABLE_MD_SCALE), 2U); } static struct tcp_congestion_ops tcp_scalable __read_mostly = { diff --git a/net/ipv4/tcp_sigpool.c b/net/ipv4/tcp_sigpool.c new file mode 100644 index 000000000000..d8a4f192873a --- /dev/null +++ b/net/ipv4/tcp_sigpool.c @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <crypto/hash.h> +#include <linux/cpu.h> +#include <linux/kref.h> +#include <linux/module.h> +#include <linux/mutex.h> +#include <linux/percpu.h> +#include <linux/workqueue.h> +#include <net/tcp.h> + +static size_t __scratch_size; +struct sigpool_scratch { + local_lock_t bh_lock; + void __rcu *pad; +}; + +static DEFINE_PER_CPU(struct sigpool_scratch, sigpool_scratch) = { + .bh_lock = INIT_LOCAL_LOCK(bh_lock), +}; + +struct sigpool_entry { + struct crypto_ahash *hash; + const char *alg; + struct kref kref; + uint16_t needs_key:1, + reserved:15; +}; + +#define CPOOL_SIZE (PAGE_SIZE / sizeof(struct sigpool_entry)) +static struct sigpool_entry cpool[CPOOL_SIZE]; +static unsigned int cpool_populated; +static DEFINE_MUTEX(cpool_mutex); + +/* Slow-path */ +struct scratches_to_free { + struct rcu_head rcu; + unsigned int cnt; + void *scratches[]; +}; + +static void free_old_scratches(struct rcu_head *head) +{ + struct scratches_to_free *stf; + + stf = container_of(head, struct scratches_to_free, rcu); + while (stf->cnt--) + kfree(stf->scratches[stf->cnt]); + kfree(stf); +} + +/** + * sigpool_reserve_scratch - re-allocates scratch buffer, slow-path + * @size: request size for the scratch/temp buffer + */ +static int sigpool_reserve_scratch(size_t size) +{ + struct scratches_to_free *stf; + size_t stf_sz = struct_size(stf, scratches, num_possible_cpus()); + int cpu, err = 0; + + lockdep_assert_held(&cpool_mutex); + if (__scratch_size >= size) + return 0; + + stf = kmalloc(stf_sz, GFP_KERNEL); + if (!stf) + return -ENOMEM; + stf->cnt = 0; + + size = max(size, __scratch_size); + cpus_read_lock(); + for_each_possible_cpu(cpu) { + void *scratch, *old_scratch; + + scratch = kmalloc_node(size, GFP_KERNEL, cpu_to_node(cpu)); + if (!scratch) { + err = -ENOMEM; + break; + } + + old_scratch = rcu_replace_pointer(per_cpu(sigpool_scratch.pad, cpu), + scratch, lockdep_is_held(&cpool_mutex)); + if (!cpu_online(cpu) || !old_scratch) { + kfree(old_scratch); + continue; + } + stf->scratches[stf->cnt++] = old_scratch; + } + cpus_read_unlock(); + if (!err) + __scratch_size = size; + + call_rcu(&stf->rcu, free_old_scratches); + return err; +} + +static void sigpool_scratch_free(void) +{ + int cpu; + + for_each_possible_cpu(cpu) + kfree(rcu_replace_pointer(per_cpu(sigpool_scratch.pad, cpu), + NULL, lockdep_is_held(&cpool_mutex))); + __scratch_size = 0; +} + +static int __cpool_try_clone(struct crypto_ahash *hash) +{ + struct crypto_ahash *tmp; + + tmp = crypto_clone_ahash(hash); + if (IS_ERR(tmp)) + return PTR_ERR(tmp); + + crypto_free_ahash(tmp); + return 0; +} + +static int __cpool_alloc_ahash(struct sigpool_entry *e, const char *alg) +{ + struct crypto_ahash *cpu0_hash; + int ret; + + e->alg = kstrdup(alg, GFP_KERNEL); + if (!e->alg) + return -ENOMEM; + + cpu0_hash = crypto_alloc_ahash(alg, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(cpu0_hash)) { + ret = PTR_ERR(cpu0_hash); + goto out_free_alg; + } + + e->needs_key = crypto_ahash_get_flags(cpu0_hash) & CRYPTO_TFM_NEED_KEY; + + ret = __cpool_try_clone(cpu0_hash); + if (ret) + goto out_free_cpu0_hash; + e->hash = cpu0_hash; + kref_init(&e->kref); + return 0; + +out_free_cpu0_hash: + crypto_free_ahash(cpu0_hash); +out_free_alg: + kfree(e->alg); + e->alg = NULL; + return ret; +} + +/** + * tcp_sigpool_alloc_ahash - allocates pool for ahash requests + * @alg: name of async hash algorithm + * @scratch_size: reserve a tcp_sigpool::scratch buffer of this size + */ +int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size) +{ + int i, ret; + + /* slow-path */ + mutex_lock(&cpool_mutex); + ret = sigpool_reserve_scratch(scratch_size); + if (ret) + goto out; + for (i = 0; i < cpool_populated; i++) { + if (!cpool[i].alg) + continue; + if (strcmp(cpool[i].alg, alg)) + continue; + + /* pairs with tcp_sigpool_release() */ + if (!kref_get_unless_zero(&cpool[i].kref)) + kref_init(&cpool[i].kref); + ret = i; + goto out; + } + + for (i = 0; i < cpool_populated; i++) { + if (!cpool[i].alg) + break; + } + if (i >= CPOOL_SIZE) { + ret = -ENOSPC; + goto out; + } + + ret = __cpool_alloc_ahash(&cpool[i], alg); + if (!ret) { + ret = i; + if (i == cpool_populated) + cpool_populated++; + } +out: + mutex_unlock(&cpool_mutex); + return ret; +} +EXPORT_SYMBOL_GPL(tcp_sigpool_alloc_ahash); + +static void __cpool_free_entry(struct sigpool_entry *e) +{ + crypto_free_ahash(e->hash); + kfree(e->alg); + memset(e, 0, sizeof(*e)); +} + +static void cpool_cleanup_work_cb(struct work_struct *work) +{ + bool free_scratch = true; + unsigned int i; + + mutex_lock(&cpool_mutex); + for (i = 0; i < cpool_populated; i++) { + if (kref_read(&cpool[i].kref) > 0) { + free_scratch = false; + continue; + } + if (!cpool[i].alg) + continue; + __cpool_free_entry(&cpool[i]); + } + if (free_scratch) + sigpool_scratch_free(); + mutex_unlock(&cpool_mutex); +} + +static DECLARE_WORK(cpool_cleanup_work, cpool_cleanup_work_cb); +static void cpool_schedule_cleanup(struct kref *kref) +{ + schedule_work(&cpool_cleanup_work); +} + +/** + * tcp_sigpool_release - decreases number of users for a pool. If it was + * the last user of the pool, releases any memory that was consumed. + * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() + */ +void tcp_sigpool_release(unsigned int id) +{ + if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) + return; + + /* slow-path */ + kref_put(&cpool[id].kref, cpool_schedule_cleanup); +} +EXPORT_SYMBOL_GPL(tcp_sigpool_release); + +/** + * tcp_sigpool_get - increases number of users (refcounter) for a pool + * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() + */ +void tcp_sigpool_get(unsigned int id) +{ + if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) + return; + kref_get(&cpool[id].kref); +} +EXPORT_SYMBOL_GPL(tcp_sigpool_get); + +int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c) __cond_acquires(RCU_BH) +{ + struct crypto_ahash *hash; + + rcu_read_lock_bh(); + if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) { + rcu_read_unlock_bh(); + return -EINVAL; + } + + hash = crypto_clone_ahash(cpool[id].hash); + if (IS_ERR(hash)) { + rcu_read_unlock_bh(); + return PTR_ERR(hash); + } + + c->req = ahash_request_alloc(hash, GFP_ATOMIC); + if (!c->req) { + crypto_free_ahash(hash); + rcu_read_unlock_bh(); + return -ENOMEM; + } + ahash_request_set_callback(c->req, 0, NULL, NULL); + + /* Pairs with tcp_sigpool_reserve_scratch(), scratch area is + * valid (allocated) until tcp_sigpool_end(). + */ + local_lock_nested_bh(&sigpool_scratch.bh_lock); + c->scratch = rcu_dereference_bh(*this_cpu_ptr(&sigpool_scratch.pad)); + return 0; +} +EXPORT_SYMBOL_GPL(tcp_sigpool_start); + +void tcp_sigpool_end(struct tcp_sigpool *c) __releases(RCU_BH) +{ + struct crypto_ahash *hash = crypto_ahash_reqtfm(c->req); + + local_unlock_nested_bh(&sigpool_scratch.bh_lock); + rcu_read_unlock_bh(); + ahash_request_free(c->req); + crypto_free_ahash(hash); +} +EXPORT_SYMBOL_GPL(tcp_sigpool_end); + +/** + * tcp_sigpool_algo - return algorithm of tcp_sigpool + * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() + * @buf: buffer to return name of algorithm + * @buf_len: size of @buf + */ +size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len) +{ + if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) + return -EINVAL; + + return strscpy(buf, cpool[id].alg, buf_len); +} +EXPORT_SYMBOL_GPL(tcp_sigpool_algo); + +/** + * tcp_sigpool_hash_skb_data - hash data in skb with initialized tcp_sigpool + * @hp: tcp_sigpool pointer + * @skb: buffer to add sign for + * @header_len: TCP header length for this segment + */ +int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp, + const struct sk_buff *skb, + unsigned int header_len) +{ + const unsigned int head_data_len = skb_headlen(skb) > header_len ? + skb_headlen(skb) - header_len : 0; + const struct skb_shared_info *shi = skb_shinfo(skb); + const struct tcphdr *tp = tcp_hdr(skb); + struct ahash_request *req = hp->req; + struct sk_buff *frag_iter; + struct scatterlist sg; + unsigned int i; + + sg_init_table(&sg, 1); + + sg_set_buf(&sg, ((u8 *)tp) + header_len, head_data_len); + ahash_request_set_crypt(req, &sg, NULL, head_data_len); + if (crypto_ahash_update(req)) + return 1; + + for (i = 0; i < shi->nr_frags; ++i) { + const skb_frag_t *f = &shi->frags[i]; + unsigned int offset = skb_frag_off(f); + struct page *page; + + page = skb_frag_page(f) + (offset >> PAGE_SHIFT); + sg_set_page(&sg, page, skb_frag_size(f), offset_in_page(offset)); + ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f)); + if (crypto_ahash_update(req)) + return 1; + } + + skb_walk_frags(skb, frag_iter) + if (tcp_sigpool_hash_skb_data(hp, frag_iter, 0)) + return 1; + + return 0; +} +EXPORT_SYMBOL(tcp_sigpool_hash_skb_data); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Per-CPU pool of crypto requests"); diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index 20cf4a98c69d..160080c9021d 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -22,18 +22,24 @@ #include <linux/module.h> #include <linux/gfp.h> #include <net/tcp.h> +#include <net/rstreason.h> static u32 tcp_clamp_rto_to_user_timeout(const struct sock *sk) { - struct inet_connection_sock *icsk = inet_csk(sk); - u32 elapsed, start_ts; + const struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_sock *tp = tcp_sk(sk); + u32 elapsed, user_timeout; s32 remaining; - start_ts = tcp_sk(sk)->retrans_stamp; - if (!icsk->icsk_user_timeout) + user_timeout = READ_ONCE(icsk->icsk_user_timeout); + if (!user_timeout) return icsk->icsk_rto; - elapsed = tcp_time_stamp(tcp_sk(sk)) - start_ts; - remaining = icsk->icsk_user_timeout - elapsed; + + elapsed = tcp_time_stamp_ts(tp) - tp->retrans_stamp; + if (tp->tcp_usec_ts) + elapsed /= USEC_PER_MSEC; + + remaining = user_timeout - elapsed; if (remaining <= 0) return 1; /* user timeout has passed; fire ASAP */ @@ -42,17 +48,18 @@ static u32 tcp_clamp_rto_to_user_timeout(const struct sock *sk) u32 tcp_clamp_probe0_to_user_timeout(const struct sock *sk, u32 when) { - struct inet_connection_sock *icsk = inet_csk(sk); - u32 remaining; + const struct inet_connection_sock *icsk = inet_csk(sk); + u32 remaining, user_timeout; s32 elapsed; - if (!icsk->icsk_user_timeout || !icsk->icsk_probes_tstamp) + user_timeout = READ_ONCE(icsk->icsk_user_timeout); + if (!user_timeout || !icsk->icsk_probes_tstamp) return when; elapsed = tcp_jiffies32 - icsk->icsk_probes_tstamp; if (unlikely(elapsed < 0)) elapsed = 0; - remaining = msecs_to_jiffies(icsk->icsk_user_timeout) - elapsed; + remaining = msecs_to_jiffies(user_timeout) - elapsed; remaining = max_t(u32, remaining, TCP_TIMEOUT_MIN); return min_t(u32, remaining, when); @@ -67,11 +74,7 @@ u32 tcp_clamp_probe0_to_user_timeout(const struct sock *sk, u32 when) static void tcp_write_err(struct sock *sk) { - sk->sk_err = sk->sk_err_soft ? : ETIMEDOUT; - sk_error_report(sk); - - tcp_write_queue_purge(sk); - tcp_done(sk); + tcp_done_with_error(sk, READ_ONCE(sk->sk_err_soft) ? : ETIMEDOUT); __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONTIMEOUT); } @@ -106,11 +109,11 @@ static int tcp_out_of_resources(struct sock *sk, bool do_reset) /* If peer does not open window for long time, or did not transmit * anything for long time, penalize it. */ - if ((s32)(tcp_jiffies32 - tp->lsndtime) > 2*TCP_RTO_MAX || !do_reset) + if ((s32)(tcp_jiffies32 - tp->lsndtime) > 2*tcp_rto_max(sk) || !do_reset) shift++; /* If some dubious ICMP arrived, penalize even more. */ - if (sk->sk_err_soft) + if (READ_ONCE(sk->sk_err_soft)) shift++; if (tcp_check_oom(sk, shift)) { @@ -121,7 +124,8 @@ static int tcp_out_of_resources(struct sock *sk, bool do_reset) (!tp->snd_wnd && !tp->packets_out)) do_reset = true; if (do_reset) - tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_send_active_reset(sk, GFP_ATOMIC, + SK_RST_REASON_TCP_ABORT_ON_MEMORY); tcp_done(sk); __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONMEMORY); return 1; @@ -143,10 +147,10 @@ static int tcp_out_of_resources(struct sock *sk, bool do_reset) */ static int tcp_orphan_retries(struct sock *sk, bool alive) { - int retries = sock_net(sk)->ipv4.sysctl_tcp_orphan_retries; /* May be zero. */ + int retries = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_orphan_retries); /* May be zero. */ /* We know from an ICMP that something is wrong. */ - if (sk->sk_err_soft && !alive) + if (READ_ONCE(sk->sk_err_soft) && !alive) retries = 0; /* However, if socket sent something recently, select some safe @@ -163,7 +167,7 @@ static void tcp_mtu_probing(struct inet_connection_sock *icsk, struct sock *sk) int mss; /* Black hole detection */ - if (!net->ipv4.sysctl_tcp_mtu_probing) + if (!READ_ONCE(net->ipv4.sysctl_tcp_mtu_probing)) return; if (!icsk->icsk_mtup.enabled) { @@ -171,9 +175,9 @@ static void tcp_mtu_probing(struct inet_connection_sock *icsk, struct sock *sk) icsk->icsk_mtup.probe_timestamp = tcp_jiffies32; } else { mss = tcp_mtu_to_mss(sk, icsk->icsk_mtup.search_low) >> 1; - mss = min(net->ipv4.sysctl_tcp_base_mss, mss); - mss = max(mss, net->ipv4.sysctl_tcp_mtu_probe_floor); - mss = max(mss, net->ipv4.sysctl_tcp_min_snd_mss); + mss = min(READ_ONCE(net->ipv4.sysctl_tcp_base_mss), mss); + mss = max(mss, READ_ONCE(net->ipv4.sysctl_tcp_mtu_probe_floor)); + mss = max(mss, READ_ONCE(net->ipv4.sysctl_tcp_min_snd_mss)); icsk->icsk_mtup.search_low = tcp_mss_to_mtu(sk, mss); } tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); @@ -185,12 +189,12 @@ static unsigned int tcp_model_timeout(struct sock *sk, { unsigned int linear_backoff_thresh, timeout; - linear_backoff_thresh = ilog2(TCP_RTO_MAX / rto_base); + linear_backoff_thresh = ilog2(tcp_rto_max(sk) / rto_base); if (boundary <= linear_backoff_thresh) timeout = ((2 << boundary) - 1) * rto_base; else timeout = ((2 << linear_backoff_thresh) - 1) * rto_base + - (boundary - linear_backoff_thresh) * TCP_RTO_MAX; + (boundary - linear_backoff_thresh) * tcp_rto_max(sk); return jiffies_to_msecs(timeout); } /** @@ -210,12 +214,13 @@ static bool retransmits_timed_out(struct sock *sk, unsigned int boundary, unsigned int timeout) { - unsigned int start_ts; + struct tcp_sock *tp = tcp_sk(sk); + unsigned int start_ts, delta; if (!inet_csk(sk)->icsk_retransmits) return false; - start_ts = tcp_sk(sk)->retrans_stamp; + start_ts = tp->retrans_stamp; if (likely(timeout == 0)) { unsigned int rto_base = TCP_RTO_MIN; @@ -224,7 +229,12 @@ static bool retransmits_timed_out(struct sock *sk, timeout = tcp_model_timeout(sk, boundary, rto_base); } - return (s32)(tcp_time_stamp(tcp_sk(sk)) - start_ts - timeout) >= 0; + if (tp->tcp_usec_ts) { + /* delta maybe off up to a jiffy due to timer granularity. */ + delta = tp->tcp_mstamp - start_ts + jiffies_to_usecs(1); + return (s32)(delta - timeout * USEC_PER_MSEC) >= 0; + } + return (s32)(tcp_time_stamp_ts(tp) - start_ts - timeout) >= 0; } /* A write timeout has occurred. Process the after effects. */ @@ -234,24 +244,31 @@ static int tcp_write_timeout(struct sock *sk) struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); bool expired = false, do_reset; - int retry_until; + int retry_until, max_retransmits; if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) { if (icsk->icsk_retransmits) __dst_negative_advice(sk); - retry_until = icsk->icsk_syn_retries ? : net->ipv4.sysctl_tcp_syn_retries; - expired = icsk->icsk_retransmits >= retry_until; + /* Paired with WRITE_ONCE() in tcp_sock_set_syncnt() */ + retry_until = READ_ONCE(icsk->icsk_syn_retries) ? : + READ_ONCE(net->ipv4.sysctl_tcp_syn_retries); + + max_retransmits = retry_until; + if (sk->sk_state == TCP_SYN_SENT) + max_retransmits += READ_ONCE(net->ipv4.sysctl_tcp_syn_linear_timeouts); + + expired = icsk->icsk_retransmits >= max_retransmits; } else { - if (retransmits_timed_out(sk, net->ipv4.sysctl_tcp_retries1, 0)) { + if (retransmits_timed_out(sk, READ_ONCE(net->ipv4.sysctl_tcp_retries1), 0)) { /* Black hole detection */ tcp_mtu_probing(icsk, sk); __dst_negative_advice(sk); } - retry_until = net->ipv4.sysctl_tcp_retries2; + retry_until = READ_ONCE(net->ipv4.sysctl_tcp_retries2); if (sock_flag(sk, SOCK_DEAD)) { - const bool alive = icsk->icsk_rto < TCP_RTO_MAX; + const bool alive = icsk->icsk_rto < tcp_rto_max(sk); retry_until = tcp_orphan_retries(sk, alive); do_reset = alive || @@ -263,8 +280,9 @@ static int tcp_write_timeout(struct sock *sk) } if (!expired) expired = retransmits_timed_out(sk, retry_until, - icsk->icsk_user_timeout); + READ_ONCE(icsk->icsk_user_timeout)); tcp_fastopen_active_detect_blackhole(sk, expired); + mptcp_active_detect_blackhole(sk, expired); if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RTO_CB_FLAG)) tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RTO_CB, @@ -289,23 +307,32 @@ static int tcp_write_timeout(struct sock *sk) void tcp_delack_timer_handler(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); - sk_mem_reclaim_partial(sk); + if ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) + return; - if (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) || - !(icsk->icsk_ack.pending & ICSK_ACK_TIMER)) - goto out; + /* Handling the sack compression case */ + if (tp->compressed_ack) { + tcp_mstamp_refresh(tp); + tcp_sack_compress_send_ack(sk); + return; + } - if (time_after(icsk->icsk_ack.timeout, jiffies)) { - sk_reset_timer(sk, &icsk->icsk_delack_timer, icsk->icsk_ack.timeout); - goto out; + if (!(icsk->icsk_ack.pending & ICSK_ACK_TIMER)) + return; + + if (time_after(icsk_delack_timeout(icsk), jiffies)) { + sk_reset_timer(sk, &icsk->icsk_delack_timer, + icsk_delack_timeout(icsk)); + return; } icsk->icsk_ack.pending &= ~ICSK_ACK_TIMER; if (inet_csk_ack_scheduled(sk)) { if (!inet_csk_in_pingpong_mode(sk)) { /* Delayed ACK missed: inflate ATO. */ - icsk->icsk_ack.ato = min(icsk->icsk_ack.ato << 1, icsk->icsk_rto); + icsk->icsk_ack.ato = min_t(u32, icsk->icsk_ack.ato << 1, icsk->icsk_rto); } else { /* Delayed ACK missed: leave pingpong mode and * deflate ATO. @@ -313,14 +340,10 @@ void tcp_delack_timer_handler(struct sock *sk) inet_csk_exit_pingpong_mode(sk); icsk->icsk_ack.ato = TCP_ATO_MIN; } - tcp_mstamp_refresh(tcp_sk(sk)); + tcp_mstamp_refresh(tp); tcp_send_ack(sk); __NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKS); } - -out: - if (tcp_under_memory_pressure(sk)) - sk_mem_reclaim(sk); } @@ -336,9 +359,17 @@ out: static void tcp_delack_timer(struct timer_list *t) { struct inet_connection_sock *icsk = - from_timer(icsk, t, icsk_delack_timer); + timer_container_of(icsk, t, icsk_delack_timer); struct sock *sk = &icsk->icsk_inet.sk; + /* Avoid taking socket spinlock if there is no ACK to send. + * The compressed_ack check is racy, but a separate hrtimer + * will take care of it eventually. + */ + if (!(smp_load_acquire(&icsk->icsk_ack.pending) & ICSK_ACK_TIMER) && + !READ_ONCE(tcp_sk(sk)->compressed_ack)) + goto out; + bh_lock_sock(sk); if (!sock_owned_by_user(sk)) { tcp_delack_timer_handler(sk); @@ -349,6 +380,7 @@ static void tcp_delack_timer(struct timer_list *t) sock_hold(sk); } bh_unlock_sock(sk); +out: sock_put(sk); } @@ -360,7 +392,7 @@ static void tcp_probe_timer(struct sock *sk) int max_probes; if (tp->packets_out || !skb) { - icsk->icsk_probes_out = 0; + WRITE_ONCE(icsk->icsk_probes_out, 0); icsk->icsk_probes_tstamp = 0; return; } @@ -373,16 +405,20 @@ static void tcp_probe_timer(struct sock *sk) * corresponding system limit. We also implement similar policy when * we use RTO to probe window in tcp_retransmit_timer(). */ - if (!icsk->icsk_probes_tstamp) + if (!icsk->icsk_probes_tstamp) { icsk->icsk_probes_tstamp = tcp_jiffies32; - else if (icsk->icsk_user_timeout && - (s32)(tcp_jiffies32 - icsk->icsk_probes_tstamp) >= - msecs_to_jiffies(icsk->icsk_user_timeout)) - goto abort; + } else { + u32 user_timeout = READ_ONCE(icsk->icsk_user_timeout); - max_probes = sock_net(sk)->ipv4.sysctl_tcp_retries2; + if (user_timeout && + (s32)(tcp_jiffies32 - icsk->icsk_probes_tstamp) >= + msecs_to_jiffies(user_timeout)) + goto abort; + } + max_probes = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_retries2); if (sock_flag(sk, SOCK_DEAD)) { - const bool alive = inet_csk_rto_backoff(icsk, TCP_RTO_MAX) < TCP_RTO_MAX; + unsigned int rto_max = tcp_rto_max(sk); + const bool alive = inet_csk_rto_backoff(icsk, rto_max) < rto_max; max_probes = tcp_orphan_retries(sk, alive); if (!alive && icsk->icsk_backoff >= max_probes) @@ -399,6 +435,19 @@ abort: tcp_write_err(sk); } } +static void tcp_update_rto_stats(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (!icsk->icsk_retransmits) { + tp->total_rto_recoveries++; + tp->rto_stamp = tcp_time_stamp_ms(tp); + } + WRITE_ONCE(icsk->icsk_retransmits, icsk->icsk_retransmits + 1); + tp->total_rto++; +} + /* * Timer for Fast Open socket to retransmit SYNACK. Note that the * sk here is the child socket, not the parent (listener) socket. @@ -406,11 +455,16 @@ abort: tcp_write_err(sk); static void tcp_fastopen_synack_timer(struct sock *sk, struct request_sock *req) { struct inet_connection_sock *icsk = inet_csk(sk); - int max_retries = icsk->icsk_syn_retries ? : - sock_net(sk)->ipv4.sysctl_tcp_synack_retries + 1; /* add one more retry for fastopen */ struct tcp_sock *tp = tcp_sk(sk); + int max_retries; + + tcp_syn_ack_timeout(req); - req->rsk_ops->syn_ack_timeout(req); + /* Add one more retry for fastopen. + * Paired with WRITE_ONCE() in tcp_sock_set_syncnt() + */ + max_retries = READ_ONCE(icsk->icsk_syn_retries) ? : + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_synack_retries) + 1; if (req->num_timeout >= max_retries) { tcp_write_err(sk); @@ -424,15 +478,44 @@ static void tcp_fastopen_synack_timer(struct sock *sk, struct request_sock *req) * regular retransmit because if the child socket has been accepted * it's not good to give up too easily. */ - inet_rtx_syn_ack(sk, req); + tcp_rtx_synack(sk, req); req->num_timeout++; - icsk->icsk_retransmits++; + tcp_update_rto_stats(sk); if (!tp->retrans_stamp) - tp->retrans_stamp = tcp_time_stamp(tp); - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - TCP_TIMEOUT_INIT << req->num_timeout, TCP_RTO_MAX); + tp->retrans_stamp = tcp_time_stamp_ts(tp); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + req->timeout << req->num_timeout, false); } +static bool tcp_rtx_probe0_timed_out(const struct sock *sk, + const struct sk_buff *skb, + u32 rtx_delta) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + u32 user_timeout = READ_ONCE(icsk->icsk_user_timeout); + const struct tcp_sock *tp = tcp_sk(sk); + int timeout = tcp_rto_max(sk) * 2; + s32 rcv_delta; + + if (user_timeout) { + /* If user application specified a TCP_USER_TIMEOUT, + * it does not want win 0 packets to 'reset the timer' + * while retransmits are not making progress. + */ + if (rtx_delta > user_timeout) + return true; + timeout = min_t(u32, timeout, msecs_to_jiffies(user_timeout)); + } + /* Note: timer interrupt might have been delayed by at least one jiffy, + * and tp->rcv_tstamp might very well have been written recently. + * rcv_delta can thus be negative. + */ + rcv_delta = tcp_timeout_expires(sk) - tp->rcv_tstamp; + if (rcv_delta <= timeout) + return false; + + return msecs_to_jiffies(rtx_delta) > timeout; +} /** * tcp_retransmit_timer() - The TCP retransmit timeout handler @@ -472,8 +555,6 @@ void tcp_retransmit_timer(struct sock *sk) if (WARN_ON_ONCE(!skb)) return; - tp->tlp_high_seq = 0; - if (!tp->snd_wnd && !sock_flag(sk, SOCK_DEAD) && !((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV))) { /* Receiver dastardly shrinks window. Our retransmits @@ -482,23 +563,30 @@ void tcp_retransmit_timer(struct sock *sk) * we cannot allow such beasts to hang infinitely. */ struct inet_sock *inet = inet_sk(sk); + u32 rtx_delta; + + rtx_delta = tcp_time_stamp_ts(tp) - (tp->retrans_stamp ?: + tcp_skb_timestamp_ts(tp->tcp_usec_ts, skb)); + if (tp->tcp_usec_ts) + rtx_delta /= USEC_PER_MSEC; + if (sk->sk_family == AF_INET) { - net_dbg_ratelimited("Peer %pI4:%u/%u unexpectedly shrunk window %u:%u (repaired)\n", - &inet->inet_daddr, - ntohs(inet->inet_dport), - inet->inet_num, - tp->snd_una, tp->snd_nxt); + net_dbg_ratelimited("Probing zero-window on %pI4:%u/%u, seq=%u:%u, recv %ums ago, lasting %ums\n", + &inet->inet_daddr, ntohs(inet->inet_dport), + inet->inet_num, tp->snd_una, tp->snd_nxt, + jiffies_to_msecs(jiffies - tp->rcv_tstamp), + rtx_delta); } #if IS_ENABLED(CONFIG_IPV6) else if (sk->sk_family == AF_INET6) { - net_dbg_ratelimited("Peer %pI6:%u/%u unexpectedly shrunk window %u:%u (repaired)\n", - &sk->sk_v6_daddr, - ntohs(inet->inet_dport), - inet->inet_num, - tp->snd_una, tp->snd_nxt); + net_dbg_ratelimited("Probing zero-window on %pI6:%u/%u, seq=%u:%u, recv %ums ago, lasting %ums\n", + &sk->sk_v6_daddr, ntohs(inet->inet_dport), + inet->inet_num, tp->snd_una, tp->snd_nxt, + jiffies_to_msecs(jiffies - tp->rcv_tstamp), + rtx_delta); } #endif - if (tcp_jiffies32 - tp->rcv_tstamp > TCP_RTO_MAX) { + if (tcp_rtx_probe0_timed_out(sk, skb, rtx_delta)) { tcp_write_err(sk); goto out; } @@ -535,14 +623,14 @@ void tcp_retransmit_timer(struct sock *sk) tcp_enter_loss(sk); - icsk->icsk_retransmits++; + tcp_update_rto_stats(sk); if (tcp_retransmit_skb(sk, tcp_rtx_queue_head(sk), 1) > 0) { /* Retransmission failed because of local congestion, * Let senders fight for local resources conservatively. */ - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - TCP_RESOURCE_PROBE_INTERVAL, - TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + TCP_RESOURCE_PROBE_INTERVAL, + false); goto out; } @@ -561,7 +649,6 @@ void tcp_retransmit_timer(struct sock *sk) * implemented ftp to mars will work nicely. We will have to fix * the 120 second clamps though! */ - icsk->icsk_backoff++; out_reset_timer: /* If stream is thin, use linear timeouts. Since 'icsk_backoff' is @@ -574,25 +661,33 @@ out_reset_timer: * linear-timeout retransmissions into a black hole */ if (sk->sk_state == TCP_ESTABLISHED && - (tp->thin_lto || net->ipv4.sysctl_tcp_thin_linear_timeouts) && + (tp->thin_lto || READ_ONCE(net->ipv4.sysctl_tcp_thin_linear_timeouts)) && tcp_stream_is_thin(tp) && icsk->icsk_retransmits <= TCP_THIN_LINEAR_RETRIES) { icsk->icsk_backoff = 0; - icsk->icsk_rto = min(__tcp_set_rto(tp), TCP_RTO_MAX); - } else { - /* Use normal (exponential) backoff */ - icsk->icsk_rto = min(icsk->icsk_rto << 1, TCP_RTO_MAX); + icsk->icsk_rto = clamp(__tcp_set_rto(tp), + tcp_rto_min(sk), + tcp_rto_max(sk)); + } else if (sk->sk_state != TCP_SYN_SENT || + tp->total_rto > + READ_ONCE(net->ipv4.sysctl_tcp_syn_linear_timeouts)) { + /* Use normal (exponential) backoff unless linear timeouts are + * activated. + */ + icsk->icsk_backoff++; + icsk->icsk_rto = min(icsk->icsk_rto << 1, tcp_rto_max(sk)); } - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - tcp_clamp_rto_to_user_timeout(sk), TCP_RTO_MAX); - if (retransmits_timed_out(sk, net->ipv4.sysctl_tcp_retries1 + 1, 0)) + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + tcp_clamp_rto_to_user_timeout(sk), false); + if (retransmits_timed_out(sk, READ_ONCE(net->ipv4.sysctl_tcp_retries1) + 1, 0)) __sk_dst_reset(sk); out:; } /* Called with bottom-half processing disabled. - Called by tcp_write_timer() */ + * Called by tcp_write_timer() and tcp_release_cb(). + */ void tcp_write_timer_handler(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); @@ -600,13 +695,13 @@ void tcp_write_timer_handler(struct sock *sk) if (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) || !icsk->icsk_pending) - goto out; + return; - if (time_after(icsk->icsk_timeout, jiffies)) { - sk_reset_timer(sk, &icsk->icsk_retransmit_timer, icsk->icsk_timeout); - goto out; + if (time_after(tcp_timeout_expires(sk), jiffies)) { + sk_reset_timer(sk, &sk->tcp_retransmit_timer, + tcp_timeout_expires(sk)); + return; } - tcp_mstamp_refresh(tcp_sk(sk)); event = icsk->icsk_pending; @@ -618,24 +713,23 @@ void tcp_write_timer_handler(struct sock *sk) tcp_send_loss_probe(sk); break; case ICSK_TIME_RETRANS: - icsk->icsk_pending = 0; + smp_store_release(&icsk->icsk_pending, 0); tcp_retransmit_timer(sk); break; case ICSK_TIME_PROBE0: - icsk->icsk_pending = 0; + smp_store_release(&icsk->icsk_pending, 0); tcp_probe_timer(sk); break; } - -out: - sk_mem_reclaim(sk); } static void tcp_write_timer(struct timer_list *t) { - struct inet_connection_sock *icsk = - from_timer(icsk, t, icsk_retransmit_timer); - struct sock *sk = &icsk->icsk_inet.sk; + struct sock *sk = timer_container_of(sk, t, tcp_retransmit_timer); + + /* Avoid locking the socket when there is no pending event. */ + if (!smp_load_acquire(&inet_csk(sk)->icsk_pending)) + goto out; bh_lock_sock(sk); if (!sock_owned_by_user(sk)) { @@ -646,6 +740,7 @@ static void tcp_write_timer(struct timer_list *t) sock_hold(sk); } bh_unlock_sock(sk); +out: sock_put(sk); } @@ -655,7 +750,16 @@ void tcp_syn_ack_timeout(const struct request_sock *req) __NET_INC_STATS(net, LINUX_MIB_TCPTIMEOUTS); } -EXPORT_SYMBOL(tcp_syn_ack_timeout); + +void tcp_reset_keepalive_timer(struct sock *sk, unsigned long len) +{ + sk_reset_timer(sk, &inet_csk(sk)->icsk_keepalive_timer, jiffies + len); +} + +static void tcp_delete_keepalive_timer(struct sock *sk) +{ + sk_stop_timer(sk, &inet_csk(sk)->icsk_keepalive_timer); +} void tcp_set_keepalive(struct sock *sk, int val) { @@ -663,17 +767,17 @@ void tcp_set_keepalive(struct sock *sk, int val) return; if (val && !sock_flag(sk, SOCK_KEEPOPEN)) - inet_csk_reset_keepalive_timer(sk, keepalive_time_when(tcp_sk(sk))); + tcp_reset_keepalive_timer(sk, keepalive_time_when(tcp_sk(sk))); else if (!val) - inet_csk_delete_keepalive_timer(sk); + tcp_delete_keepalive_timer(sk); } -EXPORT_SYMBOL_GPL(tcp_set_keepalive); +EXPORT_IPV6_MOD_GPL(tcp_set_keepalive); - -static void tcp_keepalive_timer (struct timer_list *t) +static void tcp_keepalive_timer(struct timer_list *t) { - struct sock *sk = from_timer(sk, t, sk_timer); - struct inet_connection_sock *icsk = inet_csk(sk); + struct inet_connection_sock *icsk = + timer_container_of(icsk, t, icsk_keepalive_timer); + struct sock *sk = &icsk->icsk_inet.sk; struct tcp_sock *tp = tcp_sk(sk); u32 elapsed; @@ -681,7 +785,7 @@ static void tcp_keepalive_timer (struct timer_list *t) bh_lock_sock(sk); if (sock_owned_by_user(sk)) { /* Try again later. */ - inet_csk_reset_keepalive_timer (sk, HZ/20); + tcp_reset_keepalive_timer(sk, HZ/20); goto out; } @@ -692,7 +796,7 @@ static void tcp_keepalive_timer (struct timer_list *t) tcp_mstamp_refresh(tp); if (sk->sk_state == TCP_FIN_WAIT2 && sock_flag(sk, SOCK_DEAD)) { - if (tp->linger2 >= 0) { + if (READ_ONCE(tp->linger2) >= 0) { const int tmo = tcp_fin_time(sk) - TCP_TIMEWAIT_LEN; if (tmo > 0) { @@ -700,7 +804,7 @@ static void tcp_keepalive_timer (struct timer_list *t) goto out; } } - tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_send_active_reset(sk, GFP_ATOMIC, SK_RST_REASON_TCP_STATE); goto death; } @@ -717,20 +821,23 @@ static void tcp_keepalive_timer (struct timer_list *t) elapsed = keepalive_time_elapsed(tp); if (elapsed >= keepalive_time_when(tp)) { + u32 user_timeout = READ_ONCE(icsk->icsk_user_timeout); + /* If the TCP_USER_TIMEOUT option is enabled, use that * to determine when to timeout instead. */ - if ((icsk->icsk_user_timeout != 0 && - elapsed >= msecs_to_jiffies(icsk->icsk_user_timeout) && + if ((user_timeout != 0 && + elapsed >= msecs_to_jiffies(user_timeout) && icsk->icsk_probes_out > 0) || - (icsk->icsk_user_timeout == 0 && + (user_timeout == 0 && icsk->icsk_probes_out >= keepalive_probes(tp))) { - tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_send_active_reset(sk, GFP_ATOMIC, + SK_RST_REASON_TCP_KEEPALIVE_TIMEOUT); tcp_write_err(sk); goto out; } if (tcp_write_wakeup(sk, LINUX_MIB_TCPKEEPALIVE) <= 0) { - icsk->icsk_probes_out++; + WRITE_ONCE(icsk->icsk_probes_out, icsk->icsk_probes_out + 1); elapsed = keepalive_intvl_when(tp); } else { /* If keepalive was lost due to local congestion, @@ -743,10 +850,8 @@ static void tcp_keepalive_timer (struct timer_list *t) elapsed = keepalive_time_when(tp) - elapsed; } - sk_mem_reclaim(sk); - resched: - inet_csk_reset_keepalive_timer (sk, elapsed); + tcp_reset_keepalive_timer(sk, elapsed); goto out; death: @@ -770,6 +875,7 @@ static enum hrtimer_restart tcp_compressed_ack_kick(struct hrtimer *timer) * LINUX_MIB_TCPACKCOMPRESSED accurate. */ tp->compressed_ack--; + tcp_mstamp_refresh(tp); tcp_send_ack(sk); } } else { @@ -788,11 +894,9 @@ void tcp_init_xmit_timers(struct sock *sk) { inet_csk_init_xmit_timers(sk, &tcp_write_timer, &tcp_delack_timer, &tcp_keepalive_timer); - hrtimer_init(&tcp_sk(sk)->pacing_timer, CLOCK_MONOTONIC, - HRTIMER_MODE_ABS_PINNED_SOFT); - tcp_sk(sk)->pacing_timer.function = tcp_pace_kick; + hrtimer_setup(&tcp_sk(sk)->pacing_timer, tcp_pace_kick, CLOCK_MONOTONIC, + HRTIMER_MODE_ABS_PINNED_SOFT); - hrtimer_init(&tcp_sk(sk)->compressed_ack_timer, CLOCK_MONOTONIC, - HRTIMER_MODE_REL_PINNED_SOFT); - tcp_sk(sk)->compressed_ack_timer.function = tcp_compressed_ack_kick; + hrtimer_setup(&tcp_sk(sk)->compressed_ack_timer, tcp_compressed_ack_kick, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_PINNED_SOFT); } diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c index 7c27aa629af1..2aa442128630 100644 --- a/net/ipv4/tcp_ulp.c +++ b/net/ipv4/tcp_ulp.c @@ -136,6 +136,13 @@ static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) if (icsk->icsk_ulp_ops) goto out_err; + if (sk->sk_socket) + clear_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + + err = -ENOTCONN; + if (!ulp_ops->clone && sk->sk_state == TCP_LISTEN) + goto out_err; + err = ulp_ops->init(sk); if (err) goto out_err; diff --git a/net/ipv4/tcp_vegas.c b/net/ipv4/tcp_vegas.c index c8003c8aad2c..786848ad37ea 100644 --- a/net/ipv4/tcp_vegas.c +++ b/net/ipv4/tcp_vegas.c @@ -159,7 +159,7 @@ EXPORT_SYMBOL_GPL(tcp_vegas_cwnd_event); static inline u32 tcp_vegas_ssthresh(struct tcp_sock *tp) { - return min(tp->snd_ssthresh, tp->snd_cwnd); + return min(tp->snd_ssthresh, tcp_snd_cwnd(tp)); } static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) @@ -217,14 +217,14 @@ static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) * This is: * (actual rate in segments) * baseRTT */ - target_cwnd = (u64)tp->snd_cwnd * vegas->baseRTT; + target_cwnd = (u64)tcp_snd_cwnd(tp) * vegas->baseRTT; do_div(target_cwnd, rtt); /* Calculate the difference between the window we had, * and the window we would like to have. This quantity * is the "Diff" from the Arizona Vegas papers. */ - diff = tp->snd_cwnd * (rtt-vegas->baseRTT) / vegas->baseRTT; + diff = tcp_snd_cwnd(tp) * (rtt-vegas->baseRTT) / vegas->baseRTT; if (diff > gamma && tcp_in_slow_start(tp)) { /* Going too fast. Time to slow down @@ -238,7 +238,8 @@ static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) * truncation robs us of full link * utilization. */ - tp->snd_cwnd = min(tp->snd_cwnd, (u32)target_cwnd+1); + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), + (u32)target_cwnd + 1)); tp->snd_ssthresh = tcp_vegas_ssthresh(tp); } else if (tcp_in_slow_start(tp)) { @@ -254,14 +255,14 @@ static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) /* The old window was too fast, so * we slow down. */ - tp->snd_cwnd--; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - 1); tp->snd_ssthresh = tcp_vegas_ssthresh(tp); } else if (diff < alpha) { /* We don't have enough extra packets * in the network, so speed up. */ - tp->snd_cwnd++; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); } else { /* Sending just as fast as we * should be. @@ -269,10 +270,10 @@ static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) } } - if (tp->snd_cwnd < 2) - tp->snd_cwnd = 2; - else if (tp->snd_cwnd > tp->snd_cwnd_clamp) - tp->snd_cwnd = tp->snd_cwnd_clamp; + if (tcp_snd_cwnd(tp) < 2) + tcp_snd_cwnd_set(tp, 2); + else if (tcp_snd_cwnd(tp) > tp->snd_cwnd_clamp) + tcp_snd_cwnd_set(tp, tp->snd_cwnd_clamp); tp->snd_ssthresh = tcp_current_ssthresh(sk); } diff --git a/net/ipv4/tcp_veno.c b/net/ipv4/tcp_veno.c index cd50a61c9976..366ff6f214b2 100644 --- a/net/ipv4/tcp_veno.c +++ b/net/ipv4/tcp_veno.c @@ -146,11 +146,11 @@ static void tcp_veno_cong_avoid(struct sock *sk, u32 ack, u32 acked) rtt = veno->minrtt; - target_cwnd = (u64)tp->snd_cwnd * veno->basertt; + target_cwnd = (u64)tcp_snd_cwnd(tp) * veno->basertt; target_cwnd <<= V_PARAM_SHIFT; do_div(target_cwnd, rtt); - veno->diff = (tp->snd_cwnd << V_PARAM_SHIFT) - target_cwnd; + veno->diff = (tcp_snd_cwnd(tp) << V_PARAM_SHIFT) - target_cwnd; if (tcp_in_slow_start(tp)) { /* Slow start. */ @@ -164,15 +164,15 @@ static void tcp_veno_cong_avoid(struct sock *sk, u32 ack, u32 acked) /* In the "non-congestive state", increase cwnd * every rtt. */ - tcp_cong_avoid_ai(tp, tp->snd_cwnd, acked); + tcp_cong_avoid_ai(tp, tcp_snd_cwnd(tp), acked); } else { /* In the "congestive state", increase cwnd * every other rtt. */ - if (tp->snd_cwnd_cnt >= tp->snd_cwnd) { + if (tp->snd_cwnd_cnt >= tcp_snd_cwnd(tp)) { if (veno->inc && - tp->snd_cwnd < tp->snd_cwnd_clamp) { - tp->snd_cwnd++; + tcp_snd_cwnd(tp) < tp->snd_cwnd_clamp) { + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); veno->inc = 0; } else veno->inc = 1; @@ -181,10 +181,10 @@ static void tcp_veno_cong_avoid(struct sock *sk, u32 ack, u32 acked) tp->snd_cwnd_cnt += acked; } done: - if (tp->snd_cwnd < 2) - tp->snd_cwnd = 2; - else if (tp->snd_cwnd > tp->snd_cwnd_clamp) - tp->snd_cwnd = tp->snd_cwnd_clamp; + if (tcp_snd_cwnd(tp) < 2) + tcp_snd_cwnd_set(tp, 2); + else if (tcp_snd_cwnd(tp) > tp->snd_cwnd_clamp) + tcp_snd_cwnd_set(tp, tp->snd_cwnd_clamp); } /* Wipe the slate clean for the next rtt. */ /* veno->cntrtt = 0; */ @@ -199,10 +199,10 @@ static u32 tcp_veno_ssthresh(struct sock *sk) if (veno->diff < beta) /* in "non-congestive state", cut cwnd by 1/5 */ - return max(tp->snd_cwnd * 4 / 5, 2U); + return max(tcp_snd_cwnd(tp) * 4 / 5, 2U); else /* in "congestive state", cut cwnd by 1/2 */ - return max(tp->snd_cwnd >> 1U, 2U); + return max(tcp_snd_cwnd(tp) >> 1U, 2U); } static struct tcp_congestion_ops tcp_veno __read_mostly = { diff --git a/net/ipv4/tcp_westwood.c b/net/ipv4/tcp_westwood.c index b2e05c4cea00..c6e97141eef2 100644 --- a/net/ipv4/tcp_westwood.c +++ b/net/ipv4/tcp_westwood.c @@ -244,7 +244,8 @@ static void tcp_westwood_event(struct sock *sk, enum tcp_ca_event event) switch (event) { case CA_EVENT_COMPLETE_CWR: - tp->snd_cwnd = tp->snd_ssthresh = tcp_westwood_bw_rttmin(sk); + tp->snd_ssthresh = tcp_westwood_bw_rttmin(sk); + tcp_snd_cwnd_set(tp, tp->snd_ssthresh); break; case CA_EVENT_LOSS: tp->snd_ssthresh = tcp_westwood_bw_rttmin(sk); diff --git a/net/ipv4/tcp_yeah.c b/net/ipv4/tcp_yeah.c index 07c4c93b9fdb..18b07ff5d20e 100644 --- a/net/ipv4/tcp_yeah.c +++ b/net/ipv4/tcp_yeah.c @@ -71,11 +71,11 @@ static void tcp_yeah_cong_avoid(struct sock *sk, u32 ack, u32 acked) if (!yeah->doing_reno_now) { /* Scalable */ - tcp_cong_avoid_ai(tp, min(tp->snd_cwnd, TCP_SCALABLE_AI_CNT), + tcp_cong_avoid_ai(tp, min(tcp_snd_cwnd(tp), TCP_SCALABLE_AI_CNT), acked); } else { /* Reno */ - tcp_cong_avoid_ai(tp, tp->snd_cwnd, acked); + tcp_cong_avoid_ai(tp, tcp_snd_cwnd(tp), acked); } /* The key players are v_vegas.beg_snd_una and v_beg_snd_nxt. @@ -130,7 +130,7 @@ do_vegas: /* Compute excess number of packets above bandwidth * Avoid doing full 64 bit divide. */ - bw = tp->snd_cwnd; + bw = tcp_snd_cwnd(tp); bw *= rtt - yeah->vegas.baseRTT; do_div(bw, rtt); queue = bw; @@ -138,20 +138,20 @@ do_vegas: if (queue > TCP_YEAH_ALPHA || rtt - yeah->vegas.baseRTT > (yeah->vegas.baseRTT / TCP_YEAH_PHY)) { if (queue > TCP_YEAH_ALPHA && - tp->snd_cwnd > yeah->reno_count) { + tcp_snd_cwnd(tp) > yeah->reno_count) { u32 reduction = min(queue / TCP_YEAH_GAMMA , - tp->snd_cwnd >> TCP_YEAH_EPSILON); + tcp_snd_cwnd(tp) >> TCP_YEAH_EPSILON); - tp->snd_cwnd -= reduction; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - reduction); - tp->snd_cwnd = max(tp->snd_cwnd, - yeah->reno_count); + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), + yeah->reno_count)); - tp->snd_ssthresh = tp->snd_cwnd; + tp->snd_ssthresh = tcp_snd_cwnd(tp); } if (yeah->reno_count <= 2) - yeah->reno_count = max(tp->snd_cwnd>>1, 2U); + yeah->reno_count = max(tcp_snd_cwnd(tp)>>1, 2U); else yeah->reno_count++; @@ -176,7 +176,7 @@ do_vegas: */ yeah->vegas.beg_snd_una = yeah->vegas.beg_snd_nxt; yeah->vegas.beg_snd_nxt = tp->snd_nxt; - yeah->vegas.beg_snd_cwnd = tp->snd_cwnd; + yeah->vegas.beg_snd_cwnd = tcp_snd_cwnd(tp); /* Wipe the slate clean for the next RTT. */ yeah->vegas.cntRTT = 0; @@ -193,16 +193,16 @@ static u32 tcp_yeah_ssthresh(struct sock *sk) if (yeah->doing_reno_now < TCP_YEAH_RHO) { reduction = yeah->lastQ; - reduction = min(reduction, max(tp->snd_cwnd>>1, 2U)); + reduction = min(reduction, max(tcp_snd_cwnd(tp)>>1, 2U)); - reduction = max(reduction, tp->snd_cwnd >> TCP_YEAH_DELTA); + reduction = max(reduction, tcp_snd_cwnd(tp) >> TCP_YEAH_DELTA); } else - reduction = max(tp->snd_cwnd>>1, 2U); + reduction = max(tcp_snd_cwnd(tp)>>1, 2U); yeah->fast_count = 0; yeah->reno_count = max(yeah->reno_count>>1, 2U); - return max_t(int, tp->snd_cwnd - reduction, 2); + return max_t(int, tcp_snd_cwnd(tp) - reduction, 2); } static struct tcp_congestion_ops tcp_yeah __read_mostly = { diff --git a/net/ipv4/tunnel4.c b/net/ipv4/tunnel4.c index 5048c47c79b2..4c1f836aae38 100644 --- a/net/ipv4/tunnel4.c +++ b/net/ipv4/tunnel4.c @@ -294,4 +294,5 @@ static void __exit tunnel4_fini(void) module_init(tunnel4_init); module_exit(tunnel4_fini); +MODULE_DESCRIPTION("IPv4 XFRM tunnel library"); MODULE_LICENSE("GPL"); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 090360939401..ffe074cb5865 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -68,7 +68,7 @@ * YOSHIFUJI Hideaki @USAGI and: Support IPV6_V6ONLY socket option, which * Alexey Kuznetsov: allow both IPv4 and IPv6 sockets to bind * a single port at the same time. - * Derek Atkins <derek@ihtfp.com>: Add Encapulation Support + * Derek Atkins <derek@ihtfp.com>: Add Encapsulation Support * James Chapman : Add L2TP encapsulation type. */ @@ -93,6 +93,7 @@ #include <linux/inet.h> #include <linux/netdevice.h> #include <linux/slab.h> +#include <linux/sock_diag.h> #include <net/tcp_states.h> #include <linux/skbuff.h> #include <linux/proc_fs.h> @@ -100,9 +101,11 @@ #include <net/net_namespace.h> #include <net/icmp.h> #include <net/inet_hashtables.h> +#include <net/ip.h> #include <net/ip_tunnels.h> #include <net/route.h> #include <net/checksum.h> +#include <net/gso.h> #include <net/xfrm.h> #include <trace/events/udp.h> #include <linux/static_key.h> @@ -113,29 +116,35 @@ #include <net/sock_reuseport.h> #include <net/addrconf.h> #include <net/udp_tunnel.h> +#include <net/gro.h> #if IS_ENABLED(CONFIG_IPV6) #include <net/ipv6_stubs.h> #endif +#include <net/rps.h> struct udp_table udp_table __read_mostly; -EXPORT_SYMBOL(udp_table); long sysctl_udp_mem[3] __read_mostly; -EXPORT_SYMBOL(sysctl_udp_mem); +EXPORT_IPV6_MOD(sysctl_udp_mem); -atomic_long_t udp_memory_allocated ____cacheline_aligned_in_smp; -EXPORT_SYMBOL(udp_memory_allocated); +DEFINE_PER_CPU(int, udp_memory_per_cpu_fw_alloc); +EXPORT_PER_CPU_SYMBOL_GPL(udp_memory_per_cpu_fw_alloc); #define MAX_UDP_PORTS 65536 -#define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN) +#define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN_PERNET) + +static struct udp_table *udp_get_table_prot(struct sock *sk) +{ + return sk->sk_prot->h.udp_table ? : sock_net(sk)->ipv4.udp_table; +} static int udp_lib_lport_inuse(struct net *net, __u16 num, const struct udp_hslot *hslot, unsigned long *bitmap, struct sock *sk, unsigned int log) { + kuid_t uid = sk_uid(sk); struct sock *sk2; - kuid_t uid = sock_i_uid(sk); sk_for_each(sk2, &hslot->head) { if (net_eq(sock_net(sk2), net) && @@ -147,7 +156,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num, inet_rcv_saddr_equal(sk, sk2, true)) { if (sk2->sk_reuseport && sk->sk_reuseport && !rcu_access_pointer(sk->sk_reuseport_cb) && - uid_eq(uid, sock_i_uid(sk2))) { + uid_eq(uid, sk_uid(sk2))) { if (!bitmap) return 0; } else { @@ -169,8 +178,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num, struct udp_hslot *hslot2, struct sock *sk) { + kuid_t uid = sk_uid(sk); struct sock *sk2; - kuid_t uid = sock_i_uid(sk); int res = 0; spin_lock(&hslot2->lock); @@ -184,7 +193,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num, inet_rcv_saddr_equal(sk, sk2, true)) { if (sk2->sk_reuseport && sk->sk_reuseport && !rcu_access_pointer(sk->sk_reuseport_cb) && - uid_eq(uid, sock_i_uid(sk2))) { + uid_eq(uid, sk_uid(sk2))) { res = 0; } else { res = 1; @@ -199,7 +208,7 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num, static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot) { struct net *net = sock_net(sk); - kuid_t uid = sock_i_uid(sk); + kuid_t uid = sk_uid(sk); struct sock *sk2; sk_for_each(sk2, &hslot->head) { @@ -209,7 +218,7 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot) ipv6_only_sock(sk2) == ipv6_only_sock(sk) && (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) && (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && - sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && + sk2->sk_reuseport && uid_eq(uid, sk_uid(sk2)) && inet_rcv_saddr_equal(sk, sk2, false)) { return reuseport_add_sock(sk, sk2, inet_rcv_saddr_any(sk)); @@ -230,21 +239,21 @@ static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot) int udp_lib_get_port(struct sock *sk, unsigned short snum, unsigned int hash2_nulladdr) { + struct udp_table *udptable = udp_get_table_prot(sk); struct udp_hslot *hslot, *hslot2; - struct udp_table *udptable = sk->sk_prot->h.udp_table; - int error = 1; struct net *net = sock_net(sk); + int error = -EADDRINUSE; if (!snum) { + DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN); + unsigned short first, last; int low, high, remaining; unsigned int rand; - unsigned short first, last; - DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN); - inet_get_local_port_range(net, &low, &high); + inet_sk_get_local_port_range(sk, &low, &high); remaining = (high - low) + 1; - rand = prandom_u32(); + rand = get_random_u32(); first = reciprocal_scale(rand, remaining) + low; /* * force rand to be an odd multiple of UDP_HTABLE_SIZE @@ -317,6 +326,8 @@ found: goto fail_unlock; } + sock_set_flag(sk, SOCK_RCU_FREE); + sk_add_node_rcu(sk, &hslot->head); hslot->count++; sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); @@ -333,14 +344,14 @@ found: hslot2->count++; spin_unlock(&hslot2->lock); } - sock_set_flag(sk, SOCK_RCU_FREE); + error = 0; fail_unlock: spin_unlock_bh(&hslot->lock); fail: return error; } -EXPORT_SYMBOL(udp_lib_get_port); +EXPORT_IPV6_MOD(udp_lib_get_port); int udp_v4_get_port(struct sock *sk, unsigned short snum) { @@ -354,7 +365,7 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum) return udp_lib_get_port(sk, snum, hash2_nulladdr); } -static int compute_score(struct sock *sk, struct net *net, +static int compute_score(struct sock *sk, const struct net *net, __be32 saddr, __be16 sport, __be32 daddr, unsigned short hnum, int dif, int sdif) @@ -398,36 +409,61 @@ static int compute_score(struct sock *sk, struct net *net, return score; } -static u32 udp_ehashfn(const struct net *net, const __be32 laddr, - const __u16 lport, const __be32 faddr, - const __be16 fport) +u32 udp_ehashfn(const struct net *net, const __be32 laddr, const __u16 lport, + const __be32 faddr, const __be16 fport) { - static u32 udp_ehash_secret __read_mostly; - net_get_random_once(&udp_ehash_secret, sizeof(udp_ehash_secret)); return __inet_ehashfn(laddr, lport, faddr, fport, udp_ehash_secret + net_hash_mix(net)); } +EXPORT_IPV6_MOD(udp_ehashfn); -static struct sock *lookup_reuseport(struct net *net, struct sock *sk, - struct sk_buff *skb, +/** + * udp4_lib_lookup1() - Simplified lookup using primary hash (destination port) + * @net: Network namespace + * @saddr: Source address, network order + * @sport: Source port, network order + * @daddr: Destination address, network order + * @hnum: Destination port, host order + * @dif: Destination interface index + * @sdif: Destination bridge port index, if relevant + * @udptable: Set of UDP hash tables + * + * Simplified lookup to be used as fallback if no sockets are found due to a + * potential race between (receive) address change, and lookup happening before + * the rehash operation. This function ignores SO_REUSEPORT groups while scoring + * result sockets, because if we have one, we don't need the fallback at all. + * + * Called under rcu_read_lock(). + * + * Return: socket with highest matching score if any, NULL if none + */ +static struct sock *udp4_lib_lookup1(const struct net *net, __be32 saddr, __be16 sport, - __be32 daddr, unsigned short hnum) + __be32 daddr, unsigned int hnum, + int dif, int sdif, + const struct udp_table *udptable) { - struct sock *reuse_sk = NULL; - u32 hash; + unsigned int slot = udp_hashfn(net, hnum, udptable->mask); + struct udp_hslot *hslot = &udptable->hash[slot]; + struct sock *sk, *result = NULL; + int score, badness = 0; - if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) { - hash = udp_ehashfn(net, daddr, hnum, saddr, sport); - reuse_sk = reuseport_select_sock(sk, hash, skb, - sizeof(struct udphdr)); + sk_for_each_rcu(sk, &hslot->head) { + score = compute_score(sk, net, + saddr, sport, daddr, hnum, dif, sdif); + if (score > badness) { + result = sk; + badness = score; + } } - return reuse_sk; + + return result; } /* called with rcu_read_lock() */ -static struct sock *udp4_lib_lookup2(struct net *net, +static struct sock *udp4_lib_lookup2(const struct net *net, __be32 saddr, __be16 sport, __be32 daddr, unsigned int hnum, int dif, int sdif, @@ -436,64 +472,226 @@ static struct sock *udp4_lib_lookup2(struct net *net, { struct sock *sk, *result; int score, badness; + bool need_rescore; result = NULL; badness = 0; udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { - score = compute_score(sk, net, saddr, sport, - daddr, hnum, dif, sdif); + need_rescore = false; +rescore: + score = compute_score(need_rescore ? result : sk, net, saddr, + sport, daddr, hnum, dif, sdif); if (score > badness) { - result = lookup_reuseport(net, sk, skb, - saddr, sport, daddr, hnum); + badness = score; + + if (need_rescore) + continue; + + if (sk->sk_state == TCP_ESTABLISHED) { + result = sk; + continue; + } + + result = inet_lookup_reuseport(net, sk, skb, sizeof(struct udphdr), + saddr, sport, daddr, hnum, udp_ehashfn); + if (!result) { + result = sk; + continue; + } + /* Fall back to scoring if group has connections */ - if (result && !reuseport_has_conns(sk, false)) + if (!reuseport_has_conns(sk)) return result; - result = result ? : sk; - badness = score; + /* Reuseport logic returned an error, keep original score. */ + if (IS_ERR(result)) + continue; + + /* compute_score is too long of a function to be + * inlined, and calling it again here yields + * measurable overhead for some + * workloads. Work around it by jumping + * backwards to rescore 'result'. + */ + need_rescore = true; + goto rescore; } } return result; } -static struct sock *udp4_lookup_run_bpf(struct net *net, - struct udp_table *udptable, - struct sk_buff *skb, - __be32 saddr, __be16 sport, - __be32 daddr, u16 hnum, const int dif) +#if IS_ENABLED(CONFIG_BASE_SMALL) +static struct sock *udp4_lib_lookup4(const struct net *net, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned int hnum, + int dif, int sdif, + struct udp_table *udptable) { - struct sock *sk, *reuse_sk; - bool no_reuseport; + return NULL; +} - if (udptable != &udp_table) - return NULL; /* only UDP is supported */ +static void udp_rehash4(struct udp_table *udptable, struct sock *sk, + u16 newhash4) +{ +} - no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_UDP, saddr, sport, - daddr, hnum, dif, &sk); - if (no_reuseport || IS_ERR_OR_NULL(sk)) - return sk; +static void udp_unhash4(struct udp_table *udptable, struct sock *sk) +{ +} +#else /* !CONFIG_BASE_SMALL */ +static struct sock *udp4_lib_lookup4(const struct net *net, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned int hnum, + int dif, int sdif, + struct udp_table *udptable) +{ + const __portpair ports = INET_COMBINED_PORTS(sport, hnum); + const struct hlist_nulls_node *node; + struct udp_hslot *hslot4; + unsigned int hash4, slot; + struct udp_sock *up; + struct sock *sk; - reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); - if (reuse_sk) - sk = reuse_sk; - return sk; + hash4 = udp_ehashfn(net, daddr, hnum, saddr, sport); + slot = hash4 & udptable->mask; + hslot4 = &udptable->hash4[slot]; + INET_ADDR_COOKIE(acookie, saddr, daddr); + +begin: + /* SLAB_TYPESAFE_BY_RCU not used, so we don't need to touch sk_refcnt */ + udp_lrpa_for_each_entry_rcu(up, node, &hslot4->nulls_head) { + sk = (struct sock *)up; + if (inet_match(net, sk, acookie, ports, dif, sdif)) + return sk; + } + + /* if the nulls value we got at the end of this lookup is not the + * expected one, we must restart lookup. We probably met an item that + * was moved to another chain due to rehash. + */ + if (get_nulls_value(node) != slot) + goto begin; + + return NULL; +} + +/* udp_rehash4() only checks hslot4, and hash4_cnt is not processed. */ +static void udp_rehash4(struct udp_table *udptable, struct sock *sk, + u16 newhash4) +{ + struct udp_hslot *hslot4, *nhslot4; + + hslot4 = udp_hashslot4(udptable, udp_sk(sk)->udp_lrpa_hash); + nhslot4 = udp_hashslot4(udptable, newhash4); + udp_sk(sk)->udp_lrpa_hash = newhash4; + + if (hslot4 != nhslot4) { + spin_lock_bh(&hslot4->lock); + hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_lrpa_node); + hslot4->count--; + spin_unlock_bh(&hslot4->lock); + + spin_lock_bh(&nhslot4->lock); + hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_lrpa_node, + &nhslot4->nulls_head); + nhslot4->count++; + spin_unlock_bh(&nhslot4->lock); + } +} + +static void udp_unhash4(struct udp_table *udptable, struct sock *sk) +{ + struct udp_hslot *hslot2, *hslot4; + + if (udp_hashed4(sk)) { + hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); + hslot4 = udp_hashslot4(udptable, udp_sk(sk)->udp_lrpa_hash); + + spin_lock(&hslot4->lock); + hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_lrpa_node); + hslot4->count--; + spin_unlock(&hslot4->lock); + + spin_lock(&hslot2->lock); + udp_hash4_dec(hslot2); + spin_unlock(&hslot2->lock); + } +} + +void udp_lib_hash4(struct sock *sk, u16 hash) +{ + struct udp_hslot *hslot, *hslot2, *hslot4; + struct net *net = sock_net(sk); + struct udp_table *udptable; + + /* Connected udp socket can re-connect to another remote address, which + * will be handled by rehash. Thus no need to redo hash4 here. + */ + if (udp_hashed4(sk)) + return; + + udptable = net->ipv4.udp_table; + hslot = udp_hashslot(udptable, net, udp_sk(sk)->udp_port_hash); + hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); + hslot4 = udp_hashslot4(udptable, hash); + udp_sk(sk)->udp_lrpa_hash = hash; + + spin_lock_bh(&hslot->lock); + if (rcu_access_pointer(sk->sk_reuseport_cb)) + reuseport_detach_sock(sk); + + spin_lock(&hslot4->lock); + hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_lrpa_node, + &hslot4->nulls_head); + hslot4->count++; + spin_unlock(&hslot4->lock); + + spin_lock(&hslot2->lock); + udp_hash4_inc(hslot2); + spin_unlock(&hslot2->lock); + + spin_unlock_bh(&hslot->lock); } +EXPORT_IPV6_MOD(udp_lib_hash4); + +/* call with sock lock */ +void udp4_hash4(struct sock *sk) +{ + struct net *net = sock_net(sk); + unsigned int hash; + + if (sk_unhashed(sk) || sk->sk_rcv_saddr == htonl(INADDR_ANY)) + return; + + hash = udp_ehashfn(net, sk->sk_rcv_saddr, sk->sk_num, + sk->sk_daddr, sk->sk_dport); + + udp_lib_hash4(sk, hash); +} +EXPORT_IPV6_MOD(udp4_hash4); +#endif /* CONFIG_BASE_SMALL */ /* UDP is nearly always wildcards out the wazoo, it makes no sense to try * harder than this. -DaveM */ -struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, +struct sock *__udp4_lib_lookup(const struct net *net, __be32 saddr, __be16 sport, __be32 daddr, __be16 dport, int dif, int sdif, struct udp_table *udptable, struct sk_buff *skb) { unsigned short hnum = ntohs(dport); - unsigned int hash2, slot2; struct udp_hslot *hslot2; struct sock *result, *sk; + unsigned int hash2; hash2 = ipv4_portaddr_hash(net, daddr, hnum); - slot2 = hash2 & udptable->mask; - hslot2 = &udptable->hash2[slot2]; + hslot2 = udp_hashslot2(udptable, hash2); + + if (udp_has_hash4(hslot2)) { + result = udp4_lib_lookup4(net, saddr, sport, daddr, hnum, + dif, sdif, udptable); + if (result) /* udp4_lib_lookup4 return sk or NULL */ + return result; + } /* Lookup connected or non-wildcard socket */ result = udp4_lib_lookup2(net, saddr, sport, @@ -503,9 +701,11 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, goto done; /* Lookup redirect from BPF */ - if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { - sk = udp4_lookup_run_bpf(net, udptable, skb, - saddr, sport, daddr, hnum, dif); + if (static_branch_unlikely(&bpf_sk_lookup_enabled) && + udptable == net->ipv4.udp_table) { + sk = inet_lookup_run_sk_lookup(net, IPPROTO_UDP, skb, sizeof(struct udphdr), + saddr, sport, daddr, hnum, dif, + udp_ehashfn); if (sk) { result = sk; goto done; @@ -518,12 +718,24 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, /* Lookup wildcard sockets */ hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum); - slot2 = hash2 & udptable->mask; - hslot2 = &udptable->hash2[slot2]; + hslot2 = udp_hashslot2(udptable, hash2); result = udp4_lib_lookup2(net, saddr, sport, htonl(INADDR_ANY), hnum, dif, sdif, hslot2, skb); + if (!IS_ERR_OR_NULL(result)) + goto done; + + /* Primary hash (destination port) lookup as fallback for this race: + * 1. __ip4_datagram_connect() sets sk_rcv_saddr + * 2. lookup (this function): new sk_rcv_saddr, hashes not updated yet + * 3. rehash operation updating _secondary and four-tuple_ hashes + * The primary hash doesn't need an update after 1., so, thanks to this + * further step, 1. and 3. don't need to be atomic against the lookup. + */ + result = udp4_lib_lookup1(net, saddr, sport, daddr, hnum, dif, sdif, + udptable); + done: if (IS_ERR(result)) return NULL; @@ -545,24 +757,29 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, struct sock *udp4_lib_lookup_skb(const struct sk_buff *skb, __be16 sport, __be16 dport) { - const struct iphdr *iph = ip_hdr(skb); + const u16 offset = NAPI_GRO_CB(skb)->network_offsets[skb->encapsulation]; + const struct iphdr *iph = (struct iphdr *)(skb->data + offset); + struct net *net = dev_net(skb->dev); + int iif, sdif; - return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, - iph->daddr, dport, inet_iif(skb), - inet_sdif(skb), &udp_table, NULL); + inet_get_iif_sdif(skb, &iif, &sdif); + + return __udp4_lib_lookup(net, iph->saddr, sport, + iph->daddr, dport, iif, + sdif, net->ipv4.udp_table, NULL); } /* Must be called under rcu_read_lock(). * Does increment socket refcount. */ #if IS_ENABLED(CONFIG_NF_TPROXY_IPV4) || IS_ENABLED(CONFIG_NF_SOCKET_IPV4) -struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, +struct sock *udp4_lib_lookup(const struct net *net, __be32 saddr, __be16 sport, __be32 daddr, __be16 dport, int dif) { struct sock *sk; sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport, - dif, 0, &udp_table, NULL); + dif, 0, net->ipv4.udp_table, NULL); if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) sk = NULL; return sk; @@ -570,12 +787,12 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, EXPORT_SYMBOL_GPL(udp4_lib_lookup); #endif -static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, +static inline bool __udp_is_mcast_sock(struct net *net, const struct sock *sk, __be16 loc_port, __be32 loc_addr, __be16 rmt_port, __be32 rmt_addr, int dif, int sdif, unsigned short hnum) { - struct inet_sock *inet = inet_sk(sk); + const struct inet_sock *inet = inet_sk(sk); if (!net_eq(sock_net(sk), net) || udp_sk(sk)->udp_port_hash != hnum || @@ -591,6 +808,13 @@ static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, } DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key); +EXPORT_IPV6_MOD(udp_encap_needed_key); + +#if IS_ENABLED(CONFIG_IPV6) +DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key); +EXPORT_IPV6_MOD(udpv6_encap_needed_key); +#endif + void udp_encap_enable(void) { static_branch_inc(&udp_encap_needed_key); @@ -721,7 +945,7 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) iph->saddr, uh->source, skb->dev->ifindex, inet_sdif(skb), udptable, NULL); - if (!sk || udp_sk(sk)->encap_type) { + if (!sk || READ_ONCE(udp_sk(sk)->encap_type)) { /* No socket for error: try tunnels before discarding */ if (static_branch_unlikely(&udp_encap_needed_key)) { sk = __udp4_lib_err_encap(net, iph, uh, udptable, sk, skb, @@ -757,7 +981,7 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) case ICMP_DEST_UNREACH: if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */ ipv4_sk_update_pmtu(skb, sk, info); - if (inet->pmtudisc != IP_PMTUDISC_DONT) { + if (READ_ONCE(inet->pmtudisc) != IP_PMTUDISC_DONT) { err = EMSGSIZE; harderr = 1; break; @@ -781,9 +1005,12 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) */ if (tunnel) { /* ...not for tunnels though: we don't have a sending socket */ + if (udp_sk(sk)->encap_err_rcv) + udp_sk(sk)->encap_err_rcv(sk, skb, err, uh->dest, info, + (u8 *)(uh+1)); goto out; } - if (!inet->recverr) { + if (!inet_test_bit(RECVERR, sk)) { if (!harderr || sk->sk_state != TCP_ESTABLISHED) goto out; } else @@ -797,7 +1024,7 @@ out: int udp_err(struct sk_buff *skb, u32 info) { - return __udp4_lib_err(skb, info, &udp_table); + return __udp4_lib_err(skb, info, dev_net(skb->dev)->ipv4.udp_table); } /* @@ -809,11 +1036,11 @@ void udp_flush_pending_frames(struct sock *sk) if (up->pending) { up->len = 0; - up->pending = 0; + WRITE_ONCE(up->pending, 0); ip_flush_pending_frames(sk); } } -EXPORT_SYMBOL(udp_flush_pending_frames); +EXPORT_IPV6_MOD(udp_flush_pending_frames); /** * udp4_hwcsum - handle outgoing HW checksumming @@ -913,9 +1140,9 @@ static int udp_send_skb(struct sk_buff *skb, struct flowi4 *fl4, const int hlen = skb_network_header_len(skb) + sizeof(struct udphdr); - if (hlen + cork->gso_size > cork->fragsize) { + if (hlen + min(datalen, cork->gso_size) > cork->fragsize) { kfree_skb(skb); - return -EINVAL; + return -EMSGSIZE; } if (datalen > cork->gso_size * UDP_MAX_SEGMENTS) { kfree_skb(skb); @@ -925,8 +1152,7 @@ static int udp_send_skb(struct sk_buff *skb, struct flowi4 *fl4, kfree_skb(skb); return -EINVAL; } - if (skb->ip_summed != CHECKSUM_PARTIAL || is_udplite || - dst_xfrm(skb_dst(skb))) { + if (is_udplite || dst_xfrm(skb_dst(skb))) { kfree_skb(skb); return -EIO; } @@ -936,8 +1162,10 @@ static int udp_send_skb(struct sk_buff *skb, struct flowi4 *fl4, skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4; skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(datalen, cork->gso_size); + + /* Don't checksum the payload, skb will get segmented */ + goto csum_partial; } - goto csum_partial; } if (is_udplite) /* UDP-Lite */ @@ -966,7 +1194,8 @@ csum_partial: send: err = ip_send_skb(sock_net(sk), skb); if (err) { - if (err == -ENOBUFS && !inet->recverr) { + if (err == -ENOBUFS && + !inet_test_bit(RECVERR, sk)) { UDP_INC_STATS(sock_net(sk), UDP_MIB_SNDBUFERRORS, is_udplite); err = 0; @@ -996,10 +1225,10 @@ int udp_push_pending_frames(struct sock *sk) out: up->len = 0; - up->pending = 0; + WRITE_ONCE(up->pending, 0); return err; } -EXPORT_SYMBOL(udp_push_pending_frames); +EXPORT_IPV6_MOD(udp_push_pending_frames); static int __udp_cmsg_send(struct cmsghdr *cmsg, u16 *gso_size) { @@ -1036,7 +1265,7 @@ int udp_cmsg_send(struct sock *sk, struct msghdr *msg, u16 *gso_size) return need_ip; } -EXPORT_SYMBOL_GPL(udp_cmsg_send); +EXPORT_IPV6_MOD_GPL(udp_cmsg_send); int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) { @@ -1051,13 +1280,14 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) int free = 0; int connected = 0; __be32 daddr, faddr, saddr; + u8 scope; __be16 dport; - u8 tos; int err, is_udplite = IS_UDPLITE(sk); - int corkreq = READ_ONCE(up->corkflag) || msg->msg_flags&MSG_MORE; + int corkreq = udp_test_bit(CORK, sk) || msg->msg_flags & MSG_MORE; int (*getfrag)(void *, char *, int, int, int, struct sk_buff *); struct sk_buff *skb; struct ip_options_data opt_copy; + int uc_index; if (len > 0xFFFF) return -EMSGSIZE; @@ -1072,7 +1302,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) getfrag = is_udplite ? udplite_getfrag : ip_generic_getfrag; fl4 = &inet->cork.fl.u.ip4; - if (up->pending) { + if (READ_ONCE(up->pending)) { /* * There are pending frames. * The socket lock must be held while it's corked. @@ -1120,16 +1350,17 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) if (msg->msg_controllen) { err = udp_cmsg_send(sk, msg, &ipc.gso_size); - if (err > 0) + if (err > 0) { err = ip_cmsg_send(sk, msg, &ipc, sk->sk_family == AF_INET6); + connected = 0; + } if (unlikely(err < 0)) { kfree(ipc.opt); return err; } if (ipc.opt) free = 1; - connected = 0; } if (!ipc.opt) { struct ip_options_rcu *inet_opt; @@ -1146,7 +1377,9 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) if (cgroup_bpf_enabled(CGROUP_UDP4_SENDMSG) && !connected) { err = BPF_CGROUP_RUN_PROG_UDP4_SENDMSG_LOCK(sk, - (struct sockaddr *)usin, &ipc.addr); + (struct sockaddr *)usin, + &msg->msg_namelen, + &ipc.addr); if (err) goto out_free; if (usin) { @@ -1171,38 +1404,35 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) faddr = ipc.opt->opt.faddr; connected = 0; } - tos = get_rttos(&ipc, inet); - if (sock_flag(sk, SOCK_LOCALROUTE) || - (msg->msg_flags & MSG_DONTROUTE) || - (ipc.opt && ipc.opt->opt.is_strictroute)) { - tos |= RTO_ONLINK; + scope = ip_sendmsg_scope(inet, &ipc, msg); + if (scope == RT_SCOPE_LINK) connected = 0; - } + uc_index = READ_ONCE(inet->uc_index); if (ipv4_is_multicast(daddr)) { if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif)) - ipc.oif = inet->mc_index; + ipc.oif = READ_ONCE(inet->mc_index); if (!saddr) - saddr = inet->mc_addr; + saddr = READ_ONCE(inet->mc_addr); connected = 0; } else if (!ipc.oif) { - ipc.oif = inet->uc_index; - } else if (ipv4_is_lbcast(daddr) && inet->uc_index) { + ipc.oif = uc_index; + } else if (ipv4_is_lbcast(daddr) && uc_index) { /* oif is set, packet is to local broadcast and * uc_index is set. oif is most likely set * by sk_bound_dev_if. If uc_index != oif check if the * oif is an L3 master and uc_index is an L3 slave. * If so, we want to allow the send using the uc_index. */ - if (ipc.oif != inet->uc_index && + if (ipc.oif != uc_index && ipc.oif == l3mdev_master_ifindex_by_index(sock_net(sk), - inet->uc_index)) { - ipc.oif = inet->uc_index; + uc_index)) { + ipc.oif = uc_index; } } if (connected) - rt = (struct rtable *)sk_dst_check(sk, 0); + rt = dst_rtable(sk_dst_check(sk, 0)); if (!rt) { struct net *net = sock_net(sk); @@ -1210,11 +1440,11 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) fl4 = &fl4_stack; - flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark, tos, - RT_SCOPE_UNIVERSE, sk->sk_protocol, - flow_flags, - faddr, saddr, dport, inet->inet_sport, - sk->sk_uid); + flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark, + ipc.tos & INET_DSCP_MASK, scope, + sk->sk_protocol, flow_flags, faddr, saddr, + dport, inet->inet_sport, + sk_uid(sk)); security_sk_classify_flow(sk, flowi4_to_flowi_common(fl4)); rt = ip_route_output_flow(net, fl4, sk); @@ -1273,7 +1503,7 @@ back_from_confirm: fl4->saddr = saddr; fl4->fl4_dport = dport; fl4->fl4_sport = inet->inet_sport; - up->pending = AF_INET; + WRITE_ONCE(up->pending, AF_INET); do_append_data: up->len += ulen; @@ -1285,7 +1515,7 @@ do_append_data: else if (!corkreq) err = udp_push_pending_frames(sk); else if (unlikely(skb_queue_empty(&sk->sk_write_queue))) - up->pending = 0; + WRITE_ONCE(up->pending, 0); release_sock(sk); out: @@ -1318,58 +1548,20 @@ do_confirm: } EXPORT_SYMBOL(udp_sendmsg); -int udp_sendpage(struct sock *sk, struct page *page, int offset, - size_t size, int flags) +void udp_splice_eof(struct socket *sock) { - struct inet_sock *inet = inet_sk(sk); + struct sock *sk = sock->sk; struct udp_sock *up = udp_sk(sk); - int ret; - - if (flags & MSG_SENDPAGE_NOTLAST) - flags |= MSG_MORE; - if (!up->pending) { - struct msghdr msg = { .msg_flags = flags|MSG_MORE }; - - /* Call udp_sendmsg to specify destination address which - * sendpage interface can't pass. - * This will succeed only when the socket is connected. - */ - ret = udp_sendmsg(sk, &msg, 0); - if (ret < 0) - return ret; - } + if (!READ_ONCE(up->pending) || udp_test_bit(CORK, sk)) + return; lock_sock(sk); - - if (unlikely(!up->pending)) { - release_sock(sk); - - net_dbg_ratelimited("cork failed\n"); - return -EINVAL; - } - - ret = ip_append_page(sk, &inet->cork.fl.u.ip4, - page, offset, size, flags); - if (ret == -EOPNOTSUPP) { - release_sock(sk); - return sock_no_sendpage(sk->sk_socket, page, offset, - size, flags); - } - if (ret < 0) { - udp_flush_pending_frames(sk); - goto out; - } - - up->len += size; - if (!(READ_ONCE(up->corkflag) || (flags&MSG_MORE))) - ret = udp_push_pending_frames(sk); - if (!ret) - ret = size; -out: + if (up->pending && !udp_test_bit(CORK, sk)) + udp_push_pending_frames(sk); release_sock(sk); - return ret; } +EXPORT_IPV6_MOD_GPL(udp_splice_eof); #define UDP_SKB_IS_STATELESS 0x80000000 @@ -1434,17 +1626,17 @@ static bool udp_skb_has_head_state(struct sk_buff *skb) } /* fully reclaim rmem/fwd memory allocated for skb */ -static void udp_rmem_release(struct sock *sk, int size, int partial, - bool rx_queue_lock_held) +static void udp_rmem_release(struct sock *sk, unsigned int size, + int partial, bool rx_queue_lock_held) { struct udp_sock *up = udp_sk(sk); struct sk_buff_head *sk_queue; - int amt; + unsigned int amt; if (likely(partial)) { up->forward_deficit += size; size = up->forward_deficit; - if (size < (sk->sk_rcvbuf >> 2) && + if (size < READ_ONCE(up->forward_threshold) && !skb_queue_empty(&up->reader_queue)) return; } else { @@ -1459,13 +1651,11 @@ static void udp_rmem_release(struct sock *sk, int size, int partial, if (!rx_queue_lock_held) spin_lock(&sk_queue->lock); - - sk->sk_forward_alloc += size; - amt = (sk->sk_forward_alloc - partial) & ~(SK_MEM_QUANTUM - 1); - sk->sk_forward_alloc -= amt; + amt = (size + sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1); + sk_forward_alloc_add(sk, size - amt); if (amt) - __sk_mem_reduce_allocated(sk, amt >> SK_MEM_QUANTUM_SHIFT); + __sk_mem_reduce_allocated(sk, amt >> PAGE_SHIFT); atomic_sub(size, &sk->sk_rmem_alloc); @@ -1486,7 +1676,7 @@ void udp_skb_destructor(struct sock *sk, struct sk_buff *skb) prefetch(&skb->data); udp_rmem_release(sk, udp_skb_truesize(skb), 1, false); } -EXPORT_SYMBOL(udp_skb_destructor); +EXPORT_IPV6_MOD(udp_skb_destructor); /* as above, but the caller held the rx queue lock, too */ static void udp_skb_dtor_locked(struct sock *sk, struct sk_buff *skb) @@ -1495,44 +1685,49 @@ static void udp_skb_dtor_locked(struct sock *sk, struct sk_buff *skb) udp_rmem_release(sk, udp_skb_truesize(skb), 1, true); } -/* Idea of busylocks is to let producers grab an extra spinlock - * to relieve pressure on the receive_queue spinlock shared by consumer. - * Under flood, this means that only one producer can be in line - * trying to acquire the receive_queue spinlock. - * These busylock can be allocated on a per cpu manner, instead of a - * per socket one (that would consume a cache line per socket) - */ -static int udp_busylocks_log __read_mostly; -static spinlock_t *udp_busylocks __read_mostly; - -static spinlock_t *busylock_acquire(void *ptr) +static int udp_rmem_schedule(struct sock *sk, int size) { - spinlock_t *busy; + int delta; - busy = udp_busylocks + hash_ptr(ptr, udp_busylocks_log); - spin_lock(busy); - return busy; -} + delta = size - sk->sk_forward_alloc; + if (delta > 0 && !__sk_mem_schedule(sk, delta, SK_MEM_RECV)) + return -ENOBUFS; -static void busylock_release(spinlock_t *busy) -{ - if (busy) - spin_unlock(busy); + return 0; } int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb) { struct sk_buff_head *list = &sk->sk_receive_queue; - int rmem, delta, amt, err = -ENOMEM; - spinlock_t *busy = NULL; - int size; + struct udp_prod_queue *udp_prod_queue; + struct sk_buff *next, *to_drop = NULL; + struct llist_node *ll_list; + unsigned int rmem, rcvbuf; + int size, err = -ENOMEM; + int total_size = 0; + int q_size = 0; + int dropcount; + int nb = 0; - /* try to avoid the costly atomic add/sub pair when the receive - * queue is full; always allow at least a packet - */ rmem = atomic_read(&sk->sk_rmem_alloc); - if (rmem > sk->sk_rcvbuf) - goto drop; + rcvbuf = READ_ONCE(sk->sk_rcvbuf); + size = skb->truesize; + + udp_prod_queue = &udp_sk(sk)->udp_prod_queue[numa_node_id()]; + + rmem += atomic_read(&udp_prod_queue->rmem_alloc); + + /* Immediately drop when the receive queue is full. + * Cast to unsigned int performs the boundary check for INT_MAX. + */ + if (rmem + size > rcvbuf) { + if (rcvbuf > INT_MAX >> 1) + goto drop; + + /* Accept the packet if queue is empty. */ + if (rmem) + goto drop; + } /* Under mem pressure, it might be helpful to help udp_recvmsg() * having linear skbs : @@ -1540,61 +1735,85 @@ int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb) * - Less cache line misses at copyout() time * - Less work at consume_skb() (less alien page frag freeing) */ - if (rmem > (sk->sk_rcvbuf >> 1)) { + if (rmem > (rcvbuf >> 1)) { skb_condense(skb); - - busy = busylock_acquire(sk); + size = skb->truesize; } - size = skb->truesize; + udp_set_dev_scratch(skb); - /* we drop only if the receive buf is full and the receive - * queue contains some other skb - */ - rmem = atomic_add_return(size, &sk->sk_rmem_alloc); - if (rmem > (size + (unsigned int)sk->sk_rcvbuf)) - goto uncharge_drop; + atomic_add(size, &udp_prod_queue->rmem_alloc); + + if (!llist_add(&skb->ll_node, &udp_prod_queue->ll_root)) + return 0; + + dropcount = sock_flag(sk, SOCK_RXQ_OVFL) ? sk_drops_read(sk) : 0; spin_lock(&list->lock); - if (size >= sk->sk_forward_alloc) { - amt = sk_mem_pages(size); - delta = amt << SK_MEM_QUANTUM_SHIFT; - if (!__sk_mem_raise_allocated(sk, delta, amt, SK_MEM_RECV)) { - err = -ENOBUFS; - spin_unlock(&list->lock); - goto uncharge_drop; + + ll_list = llist_del_all(&udp_prod_queue->ll_root); + + ll_list = llist_reverse_order(ll_list); + + llist_for_each_entry_safe(skb, next, ll_list, ll_node) { + size = udp_skb_truesize(skb); + total_size += size; + err = udp_rmem_schedule(sk, size); + if (unlikely(err)) { + /* Free the skbs outside of locked section. */ + skb->next = to_drop; + to_drop = skb; + continue; } - sk->sk_forward_alloc += delta; + q_size += size; + sk_forward_alloc_add(sk, -size); + + /* no need to setup a destructor, we will explicitly release the + * forward allocated memory on dequeue + */ + SOCK_SKB_CB(skb)->dropcount = dropcount; + nb++; + __skb_queue_tail(list, skb); } - sk->sk_forward_alloc -= size; + atomic_add(q_size, &sk->sk_rmem_alloc); - /* no need to setup a destructor, we will explicitly release the - * forward allocated memory on dequeue - */ - sock_skb_set_dropcount(sk, skb); - - __skb_queue_tail(list, skb); spin_unlock(&list->lock); - if (!sock_flag(sk, SOCK_DEAD)) - sk->sk_data_ready(sk); + if (!sock_flag(sk, SOCK_DEAD)) { + /* Multiple threads might be blocked in recvmsg(), + * using prepare_to_wait_exclusive(). + */ + while (nb) { + INDIRECT_CALL_1(sk->sk_data_ready, + sock_def_readable, sk); + nb--; + } + } - busylock_release(busy); - return 0; + if (unlikely(to_drop)) { + for (nb = 0; to_drop != NULL; nb++) { + skb = to_drop; + to_drop = skb->next; + skb_mark_not_on_list(skb); + /* TODO: update SNMP values. */ + sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PROTO_MEM); + } + numa_drop_add(&udp_sk(sk)->drop_counters, nb); + } -uncharge_drop: - atomic_sub(skb->truesize, &sk->sk_rmem_alloc); + atomic_sub(total_size, &udp_prod_queue->rmem_alloc); + + return 0; drop: - atomic_inc(&sk->sk_drops); - busylock_release(busy); + udp_drops_inc(sk); return err; } -EXPORT_SYMBOL_GPL(__udp_enqueue_schedule_skb); +EXPORT_IPV6_MOD_GPL(__udp_enqueue_schedule_skb); -void udp_destruct_sock(struct sock *sk) +void udp_destruct_common(struct sock *sk) { /* reclaim completely the forward allocated memory */ struct udp_sock *up = udp_sk(sk); @@ -1607,26 +1826,33 @@ void udp_destruct_sock(struct sock *sk) kfree_skb(skb); } udp_rmem_release(sk, total, 0, true); + kfree(up->udp_prod_queue); +} +EXPORT_IPV6_MOD_GPL(udp_destruct_common); +static void udp_destruct_sock(struct sock *sk) +{ + udp_destruct_common(sk); inet_sock_destruct(sk); } -EXPORT_SYMBOL_GPL(udp_destruct_sock); int udp_init_sock(struct sock *sk) { - skb_queue_head_init(&udp_sk(sk)->reader_queue); + int res = udp_lib_init_sock(sk); + sk->sk_destruct = udp_destruct_sock; - return 0; + set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + return res; } -EXPORT_SYMBOL_GPL(udp_init_sock); void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len) { - if (unlikely(READ_ONCE(sk->sk_peek_off) >= 0)) { - bool slow = lock_sock_fast(sk); - + if (unlikely(READ_ONCE(udp_sk(sk)->peeking_with_offset))) sk_peek_offset_bwd(sk, len); - unlock_sock_fast(sk, slow); + + if (!skb_shared(skb)) { + skb_attempt_defer_free(skb); + return; } if (!skb_unref(skb)) @@ -1639,11 +1865,11 @@ void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len) skb_release_head_state(skb); __consume_stateless_skb(skb); } -EXPORT_SYMBOL_GPL(skb_consume_udp); +EXPORT_IPV6_MOD_GPL(skb_consume_udp); static struct sk_buff *__first_packet_length(struct sock *sk, struct sk_buff_head *rcvq, - int *total) + unsigned int *total) { struct sk_buff *skb; @@ -1653,10 +1879,10 @@ static struct sk_buff *__first_packet_length(struct sock *sk, IS_UDPLITE(sk)); __UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, IS_UDPLITE(sk)); - atomic_inc(&sk->sk_drops); + udp_drops_inc(sk); __skb_unlink(skb, rcvq); *total += skb->truesize; - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_UDP_CSUM); } else { udp_skb_csum_unnecessary_set(skb); break; @@ -1676,8 +1902,8 @@ static int first_packet_length(struct sock *sk) { struct sk_buff_head *rcvq = &udp_sk(sk)->reader_queue; struct sk_buff_head *sk_queue = &sk->sk_receive_queue; + unsigned int total = 0; struct sk_buff *skb; - int total = 0; int res; spin_lock_bh(&rcvq->lock); @@ -1700,21 +1926,19 @@ static int first_packet_length(struct sock *sk) * IOCTL requests applicable to the UDP protocol */ -int udp_ioctl(struct sock *sk, int cmd, unsigned long arg) +int udp_ioctl(struct sock *sk, int cmd, int *karg) { switch (cmd) { case SIOCOUTQ: { - int amount = sk_wmem_alloc_get(sk); - - return put_user(amount, (int __user *)arg); + *karg = sk_wmem_alloc_get(sk); + return 0; } case SIOCINQ: { - int amount = max_t(int, 0, first_packet_length(sk)); - - return put_user(amount, (int __user *)arg); + *karg = max_t(int, 0, first_packet_length(sk)); + return 0; } default: @@ -1723,10 +1947,10 @@ int udp_ioctl(struct sock *sk, int cmd, unsigned long arg) return 0; } -EXPORT_SYMBOL(udp_ioctl); +EXPORT_IPV6_MOD(udp_ioctl); struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, - int noblock, int *off, int *err) + int *off, int *err) { struct sk_buff_head *sk_queue = &sk->sk_receive_queue; struct sk_buff_head *queue; @@ -1735,7 +1959,6 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, int error; queue = &udp_sk(sk)->reader_queue; - flags |= noblock ? MSG_DONTWAIT : 0; timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); do { struct sk_buff *skb; @@ -1747,8 +1970,8 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, error = -EAGAIN; do { spin_lock_bh(&queue->lock); - skb = __skb_try_recv_from_queue(sk, queue, flags, off, - err, &last); + skb = __skb_try_recv_from_queue(queue, flags, off, err, + &last); if (skb) { if (!(flags & MSG_PEEK)) udp_skb_destructor(sk, skb); @@ -1769,8 +1992,8 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, spin_lock(&sk_queue->lock); skb_queue_splice_tail_init(sk_queue, queue); - skb = __skb_try_recv_from_queue(sk, queue, flags, off, - err, &last); + skb = __skb_try_recv_from_queue(queue, flags, off, err, + &last); if (skb && !(flags & MSG_PEEK)) udp_skb_dtor_locked(sk, skb); spin_unlock(&sk_queue->lock); @@ -1796,55 +2019,39 @@ busy_check: } EXPORT_SYMBOL(__skb_recv_udp); -int udp_read_sock(struct sock *sk, read_descriptor_t *desc, - sk_read_actor_t recv_actor) +int udp_read_skb(struct sock *sk, skb_read_actor_t recv_actor) { - int copied = 0; - - while (1) { - struct sk_buff *skb; - int err, used; - - skb = skb_recv_udp(sk, 0, 1, &err); - if (!skb) - return err; + struct sk_buff *skb; + int err; - if (udp_lib_checksum_complete(skb)) { - __UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, - IS_UDPLITE(sk)); - __UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, - IS_UDPLITE(sk)); - atomic_inc(&sk->sk_drops); - kfree_skb(skb); - continue; - } +try_again: + skb = skb_recv_udp(sk, MSG_DONTWAIT, &err); + if (!skb) + return err; - used = recv_actor(desc, skb, 0, skb->len); - if (used <= 0) { - if (!copied) - copied = used; - kfree_skb(skb); - break; - } else if (used <= skb->len) { - copied += used; - } + if (udp_lib_checksum_complete(skb)) { + int is_udplite = IS_UDPLITE(sk); + struct net *net = sock_net(sk); - kfree_skb(skb); - if (!desc->count) - break; + __UDP_INC_STATS(net, UDP_MIB_CSUMERRORS, is_udplite); + __UDP_INC_STATS(net, UDP_MIB_INERRORS, is_udplite); + udp_drops_inc(sk); + kfree_skb_reason(skb, SKB_DROP_REASON_UDP_CSUM); + goto try_again; } - return copied; + WARN_ON_ONCE(!skb_set_owner_sk_safe(skb, sk)); + return recv_actor(sk, skb); } -EXPORT_SYMBOL(udp_read_sock); +EXPORT_IPV6_MOD(udp_read_skb); /* * This should be easy, if there is something there we * return it, otherwise we block. */ -int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, - int flags, int *addr_len) +int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) { struct inet_sock *inet = inet_sk(sk); DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); @@ -1859,7 +2066,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, try_again: off = sk_peek_offset(sk, flags); - skb = __skb_recv_udp(sk, flags, noblock, &off, &err); + skb = __skb_recv_udp(sk, flags, &off, &err); if (!skb) return err; @@ -1898,7 +2105,7 @@ try_again: if (unlikely(err)) { if (!peeking) { - atomic_inc(&sk->sk_drops); + udp_drops_inc(sk); UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); } @@ -1910,7 +2117,7 @@ try_again: UDP_INC_STATS(sock_net(sk), UDP_MIB_INDATAGRAMS, is_udplite); - sock_recv_ts_and_drops(msg, sk, skb); + sock_recv_cmsgs(msg, sk, skb); /* Copy the address. */ if (sin) { @@ -1921,13 +2128,14 @@ try_again: *addr_len = sizeof(*sin); BPF_CGROUP_RUN_PROG_UDP4_RECVMSG_LOCK(sk, - (struct sockaddr *)sin); + (struct sockaddr *)sin, + addr_len); } - if (udp_sk(sk)->gro_enabled) + if (udp_test_bit(GRO_ENABLED, sk)) udp_cmsg_recv(msg, sk, skb); - if (inet->cmsg_flags) + if (inet_cmsg_flags(inet)) ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off); err = copied; @@ -1943,7 +2151,7 @@ csum_copy_err: UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite); UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); } - kfree_skb(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_UDP_CSUM); /* starting over for a new packet, but check if we need to yield */ cond_resched(); @@ -1951,7 +2159,8 @@ csum_copy_err: goto try_again; } -int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +int udp_pre_connect(struct sock *sk, struct sockaddr_unsized *uaddr, + int addr_len) { /* This check is replicated from __ip4_datagram_connect() and * intended to prevent BPF program called below from accessing bytes @@ -1960,9 +2169,22 @@ int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) if (addr_len < sizeof(struct sockaddr_in)) return -EINVAL; - return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr); + return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr, &addr_len); +} +EXPORT_IPV6_MOD(udp_pre_connect); + +static int udp_connect(struct sock *sk, struct sockaddr_unsized *uaddr, + int addr_len) +{ + int res; + + lock_sock(sk); + res = __ip4_datagram_connect(sk, uaddr, addr_len); + if (!res) + udp4_hash4(sk); + release_sock(sk); + return res; } -EXPORT_SYMBOL(udp_pre_connect); int __udp_disconnect(struct sock *sk, int flags) { @@ -1999,14 +2221,15 @@ int udp_disconnect(struct sock *sk, int flags) release_sock(sk); return 0; } -EXPORT_SYMBOL(udp_disconnect); +EXPORT_IPV6_MOD(udp_disconnect); void udp_lib_unhash(struct sock *sk) { if (sk_hashed(sk)) { - struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_table *udptable = udp_get_table_prot(sk); struct udp_hslot *hslot, *hslot2; + sock_rps_delete_flow(sk); hslot = udp_hashslot(udptable, sock_net(sk), udp_sk(sk)->udp_port_hash); hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); @@ -2023,29 +2246,31 @@ void udp_lib_unhash(struct sock *sk) hlist_del_init_rcu(&udp_sk(sk)->udp_portaddr_node); hslot2->count--; spin_unlock(&hslot2->lock); + + udp_unhash4(udptable, sk); } spin_unlock_bh(&hslot->lock); } } -EXPORT_SYMBOL(udp_lib_unhash); +EXPORT_IPV6_MOD(udp_lib_unhash); /* * inet_rcv_saddr was changed, we must rehash secondary hash */ -void udp_lib_rehash(struct sock *sk, u16 newhash) +void udp_lib_rehash(struct sock *sk, u16 newhash, u16 newhash4) { if (sk_hashed(sk)) { - struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_table *udptable = udp_get_table_prot(sk); struct udp_hslot *hslot, *hslot2, *nhslot2; + hslot = udp_hashslot(udptable, sock_net(sk), + udp_sk(sk)->udp_port_hash); hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); nhslot2 = udp_hashslot2(udptable, newhash); udp_sk(sk)->udp_portaddr_hash = newhash; if (hslot2 != nhslot2 || rcu_access_pointer(sk->sk_reuseport_cb)) { - hslot = udp_hashslot(udptable, sock_net(sk), - udp_sk(sk)->udp_port_hash); /* we must lock primary chain too */ spin_lock_bh(&hslot->lock); if (rcu_access_pointer(sk->sk_reuseport_cb)) @@ -2066,16 +2291,43 @@ void udp_lib_rehash(struct sock *sk, u16 newhash) spin_unlock_bh(&hslot->lock); } + + /* Now process hash4 if necessary: + * (1) update hslot4; + * (2) update hslot2->hash4_cnt. + * Note that hslot2/hslot4 should be checked separately, as + * either of them may change with the other unchanged. + */ + if (udp_hashed4(sk)) { + spin_lock_bh(&hslot->lock); + + udp_rehash4(udptable, sk, newhash4); + if (hslot2 != nhslot2) { + spin_lock(&hslot2->lock); + udp_hash4_dec(hslot2); + spin_unlock(&hslot2->lock); + + spin_lock(&nhslot2->lock); + udp_hash4_inc(nhslot2); + spin_unlock(&nhslot2->lock); + } + + spin_unlock_bh(&hslot->lock); + } } } -EXPORT_SYMBOL(udp_lib_rehash); +EXPORT_IPV6_MOD(udp_lib_rehash); void udp_v4_rehash(struct sock *sk) { u16 new_hash = ipv4_portaddr_hash(sock_net(sk), inet_sk(sk)->inet_rcv_saddr, inet_sk(sk)->inet_num); - udp_lib_rehash(sk, new_hash); + u16 new_hash4 = udp_ehashfn(sock_net(sk), + sk->sk_rcv_saddr, sk->sk_num, + sk->sk_daddr, sk->sk_dport); + + udp_lib_rehash(sk, new_hash, new_hash4); } static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) @@ -2093,17 +2345,21 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) rc = __udp_enqueue_schedule_skb(sk, skb); if (rc < 0) { int is_udplite = IS_UDPLITE(sk); + int drop_reason; /* Note that an ENOMEM error is charged twice */ - if (rc == -ENOMEM) + if (rc == -ENOMEM) { UDP_INC_STATS(sock_net(sk), UDP_MIB_RCVBUFERRORS, is_udplite); - else + drop_reason = SKB_DROP_REASON_SOCKET_RCVBUFF; + } else { UDP_INC_STATS(sock_net(sk), UDP_MIB_MEMERRORS, is_udplite); + drop_reason = SKB_DROP_REASON_PROTO_MEM; + } UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); - kfree_skb(skb); - trace_udp_fail_queue_rcv_skb(rc, sk); + trace_udp_fail_queue_rcv_skb(rc, sk, skb); + sk_skb_reason_drop(sk, skb, drop_reason); return -1; } @@ -2120,17 +2376,21 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) */ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) { + enum skb_drop_reason drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; struct udp_sock *up = udp_sk(sk); int is_udplite = IS_UDPLITE(sk); /* * Charge it to the socket, dropping if the queue is full. */ - if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; goto drop; + } nf_reset_ct(skb); - if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) { + if (static_branch_unlikely(&udp_encap_needed_key) && + READ_ONCE(up->encap_type)) { int (*encap_rcv)(struct sock *sk, struct sk_buff *skb); /* @@ -2168,7 +2428,8 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) /* * UDP-Lite specific tests, ignored on UDP sockets */ - if ((up->pcflag & UDPLITE_RECV_CC) && UDP_SKB_CB(skb)->partial_cov) { + if (udp_test_bit(UDPLITE_RECV_CC, sk) && UDP_SKB_CB(skb)->partial_cov) { + u16 pcrlen = READ_ONCE(up->pcrlen); /* * MIB statistics other than incrementing the error count are @@ -2181,7 +2442,7 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) * delivery of packets with coverage values less than a value * provided by the application." */ - if (up->pcrlen == 0) { /* full coverage was set */ + if (pcrlen == 0) { /* full coverage was set */ net_dbg_ratelimited("UDPLite: partial coverage %d while full coverage %d requested\n", UDP_SKB_CB(skb)->cscov, skb->len); goto drop; @@ -2192,9 +2453,9 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) * that it wants x while sender emits packets of smaller size y. * Therefore the above ...()->partial_cov statement is essential. */ - if (UDP_SKB_CB(skb)->cscov < up->pcrlen) { + if (UDP_SKB_CB(skb)->cscov < pcrlen) { net_dbg_ratelimited("UDPLite: coverage %d too small, need min %d\n", - UDP_SKB_CB(skb)->cscov, up->pcrlen); + UDP_SKB_CB(skb)->cscov, pcrlen); goto drop; } } @@ -2204,20 +2465,21 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) udp_lib_checksum_complete(skb)) goto csum_error; - if (sk_filter_trim_cap(sk, skb, sizeof(struct udphdr))) + if (sk_filter_trim_cap(sk, skb, sizeof(struct udphdr), &drop_reason)) goto drop; udp_csum_pull_header(skb); - ipv4_pktinfo_prepare(sk, skb); + ipv4_pktinfo_prepare(sk, skb, true); return __udp_queue_rcv_skb(sk, skb); csum_error: + drop_reason = SKB_DROP_REASON_UDP_CSUM; __UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite); drop: __UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); - atomic_inc(&sk->sk_drops); - kfree_skb(skb); + udp_drops_inc(sk); + sk_skb_reason_drop(sk, skb, drop_reason); return -1; } @@ -2251,13 +2513,13 @@ bool udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) struct dst_entry *old; if (dst_hold_safe(dst)) { - old = xchg((__force struct dst_entry **)&sk->sk_rx_dst, dst); + old = unrcu_pointer(xchg(&sk->sk_rx_dst, RCU_INITIALIZER(dst))); dst_release(old); return old != dst; } return false; } -EXPORT_SYMBOL(udp_sk_rx_dst_set); +EXPORT_IPV6_MOD(udp_sk_rx_dst_set); /* * Multicasts and broadcasts go to each listener. @@ -2285,7 +2547,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, udptable->mask; hash2 = ipv4_portaddr_hash(net, daddr, hnum) & udptable->mask; start_lookup: - hslot = &udptable->hash2[hash2]; + hslot = &udptable->hash2[hash2].hslot; offset = offsetof(typeof(*sk), __sk_common.skc_portaddr_node); } @@ -2301,7 +2563,7 @@ start_lookup: nskb = skb_clone(skb, GFP_ATOMIC); if (unlikely(!nskb)) { - atomic_inc(&sk->sk_drops); + udp_drops_inc(sk); __UDP_INC_STATS(net, UDP_MIB_RCVBUFERRORS, IS_UDPLITE(sk)); __UDP_INC_STATS(net, UDP_MIB_INERRORS, @@ -2376,7 +2638,7 @@ static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh, return 0; } -/* wrapper for udp_queue_rcv_skb tacking care of csum conversion and +/* wrapper for udp_queue_rcv_skb taking care of csum conversion and * return code conversion for ip layer consumption */ static int udp_unicast_rcv_skb(struct sock *sk, struct sk_buff *skb, @@ -2404,7 +2666,7 @@ static int udp_unicast_rcv_skb(struct sock *sk, struct sk_buff *skb, int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, int proto) { - struct sock *sk; + struct sock *sk = NULL; struct udphdr *uh; unsigned short ulen; struct rtable *rt = skb_rtable(skb); @@ -2439,7 +2701,11 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, if (udp4_csum_init(skb, uh, proto)) goto csum_error; - sk = skb_steal_sock(skb, &refcounted); + sk = inet_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest, + &refcounted, udp_ehashfn); + if (IS_ERR(sk)) + goto no_sk; + if (sk) { struct dst_entry *dst = skb_dst(skb); int ret; @@ -2460,7 +2726,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable); if (sk) return udp_unicast_rcv_skb(sk, skb, uh); - +no_sk: if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) goto drop; nf_reset_ct(skb); @@ -2477,7 +2743,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, * Hmm. We got an UDP packet to a port to which we * don't wanna listen. Ignore it. */ - kfree_skb_reason(skb, drop_reason); + sk_skb_reason_drop(sk, skb, drop_reason); return 0; short_packet: @@ -2502,7 +2768,7 @@ csum_error: __UDP_INC_STATS(net, UDP_MIB_CSUMERRORS, proto == IPPROTO_UDPLITE); drop: __UDP_INC_STATS(net, UDP_MIB_INERRORS, proto == IPPROTO_UDPLITE); - kfree_skb_reason(skb, drop_reason); + sk_skb_reason_drop(sk, skb, drop_reason); return 0; } @@ -2514,10 +2780,14 @@ static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, __be16 rmt_port, __be32 rmt_addr, int dif, int sdif) { - struct sock *sk, *result; + struct udp_table *udptable = net->ipv4.udp_table; unsigned short hnum = ntohs(loc_port); - unsigned int slot = udp_hashfn(net, hnum, udp_table.mask); - struct udp_hslot *hslot = &udp_table.hash[slot]; + struct sock *sk, *result; + struct udp_hslot *hslot; + unsigned int slot; + + slot = udp_hashfn(net, hnum, udptable->mask); + hslot = &udptable->hash[slot]; /* Do not bother scanning a too big list */ if (hslot->count > 10) @@ -2545,17 +2815,20 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net, __be16 rmt_port, __be32 rmt_addr, int dif, int sdif) { - unsigned short hnum = ntohs(loc_port); - unsigned int hash2 = ipv4_portaddr_hash(net, loc_addr, hnum); - unsigned int slot2 = hash2 & udp_table.mask; - struct udp_hslot *hslot2 = &udp_table.hash2[slot2]; + struct udp_table *udptable = net->ipv4.udp_table; INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr); - const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum); + unsigned short hnum = ntohs(loc_port); + struct udp_hslot *hslot2; + unsigned int hash2; + __portpair ports; struct sock *sk; + hash2 = ipv4_portaddr_hash(net, loc_addr, hnum); + hslot2 = udp_hashslot2(udptable, hash2); + ports = INET_COMBINED_PORTS(rmt_port, hnum); + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { - if (INET_MATCH(sk, net, acookie, rmt_addr, - loc_addr, ports, dif, sdif)) + if (inet_match(net, sk, acookie, ports, dif, sdif)) return sk; /* Only check first socket in chain */ break; @@ -2563,7 +2836,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net, return NULL; } -int udp_v4_early_demux(struct sk_buff *skb) +enum skb_drop_reason udp_v4_early_demux(struct sk_buff *skb) { struct net *net = dev_net(skb->dev); struct in_device *in_dev = NULL; @@ -2577,7 +2850,7 @@ int udp_v4_early_demux(struct sk_buff *skb) /* validate the packet */ if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) - return 0; + return SKB_NOT_DROPPED_YET; iph = ip_hdr(skb); uh = udp_hdr(skb); @@ -2586,12 +2859,12 @@ int udp_v4_early_demux(struct sk_buff *skb) in_dev = __in_dev_get_rcu(skb->dev); if (!in_dev) - return 0; + return SKB_NOT_DROPPED_YET; ours = ip_check_mc_rcu(in_dev, iph->daddr, iph->saddr, iph->protocol); if (!ours) - return 0; + return SKB_NOT_DROPPED_YET; sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, uh->source, iph->saddr, @@ -2601,11 +2874,12 @@ int udp_v4_early_demux(struct sk_buff *skb) uh->source, iph->saddr, dif, sdif); } - if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt)) - return 0; + if (!sk) + return SKB_NOT_DROPPED_YET; skb->sk = sk; - skb->destructor = sock_efree; + DEBUG_NET_WARN_ON_ONCE(sk_is_refcounted(sk)); + skb->destructor = sock_pfree; dst = rcu_dereference(sk->sk_rx_dst); if (dst) @@ -2625,15 +2899,15 @@ int udp_v4_early_demux(struct sk_buff *skb) if (!inet_sk(sk)->inet_daddr && in_dev) return ip_mc_validate_source(skb, iph->daddr, iph->saddr, - iph->tos & IPTOS_RT_MASK, + ip4h_dscp(iph), skb->dev, in_dev, &itag); } - return 0; + return SKB_NOT_DROPPED_YET; } int udp_rcv(struct sk_buff *skb) { - return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP); + return __udp4_lib_rcv(skb, dev_net(skb->dev)->ipv4.udp_table, IPPROTO_UDP); } void udp_destroy_sock(struct sock *sk) @@ -2652,11 +2926,44 @@ void udp_destroy_sock(struct sock *sk) if (encap_destroy) encap_destroy(sk); } - if (up->encap_enabled) + if (udp_test_bit(ENCAP_ENABLED, sk)) { static_branch_dec(&udp_encap_needed_key); + udp_tunnel_cleanup_gro(sk); + } } } +typedef struct sk_buff *(*udp_gro_receive_t)(struct sock *sk, + struct list_head *head, + struct sk_buff *skb); + +static void set_xfrm_gro_udp_encap_rcv(__u16 encap_type, unsigned short family, + struct sock *sk) +{ +#ifdef CONFIG_XFRM + udp_gro_receive_t new_gro_receive; + + if (udp_test_bit(GRO_ENABLED, sk) && encap_type == UDP_ENCAP_ESPINUDP) { + if (IS_ENABLED(CONFIG_IPV6) && family == AF_INET6) + new_gro_receive = ipv6_stub->xfrm6_gro_udp_encap_rcv; + else + new_gro_receive = xfrm4_gro_udp_encap_rcv; + + if (udp_sk(sk)->gro_receive != new_gro_receive) { + /* + * With IPV6_ADDRFORM the gro callback could change + * after being set, unregister the old one, if valid. + */ + if (udp_sk(sk)->gro_receive) + udp_tunnel_update_gro_rcv(sk, false); + + WRITE_ONCE(udp_sk(sk)->gro_receive, new_gro_receive); + udp_tunnel_update_gro_rcv(sk, true); + } + } +#endif +} + /* * Socket option code for UDP */ @@ -2669,6 +2976,18 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, int err = 0; int is_udplite = IS_UDPLITE(sk); + if (level == SOL_SOCKET) { + err = sk_setsockopt(sk, level, optname, optval, optlen); + + if (optname == SO_RCVBUF || optname == SO_RCVBUFFORCE) { + sockopt_lock_sock(sk); + /* paired with READ_ONCE in udp_rmem_release() */ + WRITE_ONCE(up->forward_threshold, sk->sk_rcvbuf >> 2); + sockopt_release_sock(sk); + } + return err; + } + if (optlen < sizeof(int)) return -EINVAL; @@ -2680,9 +2999,9 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, switch (optname) { case UDP_CORK: if (val != 0) { - WRITE_ONCE(up->corkflag, 1); + udp_set_bit(CORK, sk); } else { - WRITE_ONCE(up->corkflag, 0); + udp_clear_bit(CORK, sk); lock_sock(sk); push_pending_frames(sk); release_sock(sk); @@ -2690,37 +3009,39 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, break; case UDP_ENCAP: + sockopt_lock_sock(sk); switch (val) { case 0: #ifdef CONFIG_XFRM case UDP_ENCAP_ESPINUDP: - case UDP_ENCAP_ESPINUDP_NON_IKE: + set_xfrm_gro_udp_encap_rcv(val, sk->sk_family, sk); #if IS_ENABLED(CONFIG_IPV6) if (sk->sk_family == AF_INET6) - up->encap_rcv = ipv6_stub->xfrm6_udp_encap_rcv; + WRITE_ONCE(up->encap_rcv, + ipv6_stub->xfrm6_udp_encap_rcv); else #endif - up->encap_rcv = xfrm4_udp_encap_rcv; + WRITE_ONCE(up->encap_rcv, + xfrm4_udp_encap_rcv); #endif fallthrough; case UDP_ENCAP_L2TPINUDP: - up->encap_type = val; - lock_sock(sk); - udp_tunnel_encap_enable(sk->sk_socket); - release_sock(sk); + WRITE_ONCE(up->encap_type, val); + udp_tunnel_encap_enable(sk); break; default: err = -ENOPROTOOPT; break; } + sockopt_release_sock(sk); break; case UDP_NO_CHECK6_TX: - up->no_check6_tx = valbool; + udp_set_no_check6_tx(sk, valbool); break; case UDP_NO_CHECK6_RX: - up->no_check6_rx = valbool; + udp_set_no_check6_rx(sk, valbool); break; case UDP_SEGMENT: @@ -2730,14 +3051,14 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, break; case UDP_GRO: - lock_sock(sk); - + sockopt_lock_sock(sk); /* when enabling GRO, accept the related GSO packet type */ if (valbool) - udp_tunnel_encap_enable(sk->sk_socket); - up->gro_enabled = valbool; - up->accept_udp_l4 = valbool; - release_sock(sk); + udp_tunnel_encap_enable(sk); + udp_assign_bit(GRO_ENABLED, sk, valbool); + udp_assign_bit(ACCEPT_L4, sk, valbool); + set_xfrm_gro_udp_encap_rcv(up->encap_type, sk->sk_family, sk); + sockopt_release_sock(sk); break; /* @@ -2752,8 +3073,8 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, val = 8; else if (val > USHRT_MAX) val = USHRT_MAX; - up->pcslen = val; - up->pcflag |= UDPLITE_SEND_CC; + WRITE_ONCE(up->pcslen, val); + udp_set_bit(UDPLITE_SEND_CC, sk); break; /* The receiver specifies a minimum checksum coverage value. To make @@ -2766,8 +3087,8 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, val = 8; else if (val > USHRT_MAX) val = USHRT_MAX; - up->pcrlen = val; - up->pcflag |= UDPLITE_RECV_CC; + WRITE_ONCE(up->pcrlen, val); + udp_set_bit(UDPLITE_RECV_CC, sk); break; default: @@ -2777,12 +3098,12 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname, return err; } -EXPORT_SYMBOL(udp_lib_setsockopt); +EXPORT_IPV6_MOD(udp_lib_setsockopt); int udp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, unsigned int optlen) { - if (level == SOL_UDP || level == SOL_UDPLITE) + if (level == SOL_UDP || level == SOL_UDPLITE || level == SOL_SOCKET) return udp_lib_setsockopt(sk, level, optname, optval, optlen, udp_push_pending_frames); @@ -2798,26 +3119,26 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname, if (get_user(len, optlen)) return -EFAULT; - len = min_t(unsigned int, len, sizeof(int)); - if (len < 0) return -EINVAL; + len = min_t(unsigned int, len, sizeof(int)); + switch (optname) { case UDP_CORK: - val = READ_ONCE(up->corkflag); + val = udp_test_bit(CORK, sk); break; case UDP_ENCAP: - val = up->encap_type; + val = READ_ONCE(up->encap_type); break; case UDP_NO_CHECK6_TX: - val = up->no_check6_tx; + val = udp_get_no_check6_tx(sk); break; case UDP_NO_CHECK6_RX: - val = up->no_check6_rx; + val = udp_get_no_check6_rx(sk); break; case UDP_SEGMENT: @@ -2825,17 +3146,17 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname, break; case UDP_GRO: - val = up->gro_enabled; + val = udp_test_bit(GRO_ENABLED, sk); break; /* The following two cannot be changed on UDP sockets, the return is * always 0 (which corresponds to the full checksum coverage of UDP). */ case UDPLITE_SEND_CSCOV: - val = up->pcslen; + val = READ_ONCE(up->pcslen); break; case UDPLITE_RECV_CSCOV: - val = up->pcrlen; + val = READ_ONCE(up->pcrlen); break; default: @@ -2848,7 +3169,7 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname, return -EFAULT; return 0; } -EXPORT_SYMBOL(udp_lib_getsockopt); +EXPORT_IPV6_MOD(udp_lib_getsockopt); int udp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, int __user *optlen) @@ -2890,11 +3211,12 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait) return mask; } -EXPORT_SYMBOL(udp_poll); +EXPORT_IPV6_MOD(udp_poll); int udp_abort(struct sock *sk, int err) { - lock_sock(sk); + if (!has_current_bpf_ctx()) + lock_sock(sk); /* udp{v6}_destroy_sock() sets it under the sk lock, avoid racing * with close() @@ -2907,18 +3229,19 @@ int udp_abort(struct sock *sk, int err) __udp_disconnect(sk, 0); out: - release_sock(sk); + if (!has_current_bpf_ctx()) + release_sock(sk); return 0; } -EXPORT_SYMBOL_GPL(udp_abort); +EXPORT_IPV6_MOD_GPL(udp_abort); struct proto udp_prot = { .name = "UDP", .owner = THIS_MODULE, .close = udp_lib_close, .pre_connect = udp_pre_connect, - .connect = ip4_datagram_connect, + .connect = udp_connect, .disconnect = udp_disconnect, .ioctl = udp_ioctl, .init = udp_init_sock, @@ -2927,7 +3250,7 @@ struct proto udp_prot = { .getsockopt = udp_getsockopt, .sendmsg = udp_sendmsg, .recvmsg = udp_recvmsg, - .sendpage = udp_sendpage, + .splice_eof = udp_splice_eof, .release_cb = ip4_datagram_release_cb, .hash = udp_lib_hash, .unhash = udp_lib_unhash, @@ -2937,12 +3260,14 @@ struct proto udp_prot = { #ifdef CONFIG_BPF_SYSCALL .psock_update_sk_prot = udp_bpf_update_proto, #endif - .memory_allocated = &udp_memory_allocated, + .memory_allocated = &net_aligned_data.udp_memory_allocated, + .per_cpu_fw_alloc = &udp_memory_per_cpu_fw_alloc, + .sysctl_mem = sysctl_udp_mem, .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), .obj_size = sizeof(struct udp_sock), - .h.udp_table = &udp_table, + .h.udp_table = NULL, .diag_destroy = udp_abort, }; EXPORT_SYMBOL(udp_prot); @@ -2950,31 +3275,52 @@ EXPORT_SYMBOL(udp_prot); /* ------------------------------------------------------------------------ */ #ifdef CONFIG_PROC_FS +static unsigned short seq_file_family(const struct seq_file *seq); +static bool seq_sk_match(struct seq_file *seq, const struct sock *sk) +{ + unsigned short family = seq_file_family(seq); + + /* AF_UNSPEC is used as a match all */ + return ((family == AF_UNSPEC || family == sk->sk_family) && + net_eq(sock_net(sk), seq_file_net(seq))); +} + +#ifdef CONFIG_BPF_SYSCALL +static const struct seq_operations bpf_iter_udp_seq_ops; +#endif +static struct udp_table *udp_get_table_seq(struct seq_file *seq, + struct net *net) +{ + const struct udp_seq_afinfo *afinfo; + +#ifdef CONFIG_BPF_SYSCALL + if (seq->op == &bpf_iter_udp_seq_ops) + return net->ipv4.udp_table; +#endif + + afinfo = pde_data(file_inode(seq->file)); + return afinfo->udp_table ? : net->ipv4.udp_table; +} + static struct sock *udp_get_first(struct seq_file *seq, int start) { - struct sock *sk; - struct udp_seq_afinfo *afinfo; struct udp_iter_state *state = seq->private; struct net *net = seq_file_net(seq); + struct udp_table *udptable; + struct sock *sk; - if (state->bpf_seq_afinfo) - afinfo = state->bpf_seq_afinfo; - else - afinfo = pde_data(file_inode(seq->file)); + udptable = udp_get_table_seq(seq, net); - for (state->bucket = start; state->bucket <= afinfo->udp_table->mask; + for (state->bucket = start; state->bucket <= udptable->mask; ++state->bucket) { - struct udp_hslot *hslot = &afinfo->udp_table->hash[state->bucket]; + struct udp_hslot *hslot = &udptable->hash[state->bucket]; if (hlist_empty(&hslot->head)) continue; spin_lock_bh(&hslot->lock); sk_for_each(sk, &hslot->head) { - if (!net_eq(sock_net(sk), net)) - continue; - if (afinfo->family == AF_UNSPEC || - sk->sk_family == afinfo->family) + if (seq_sk_match(seq, sk)) goto found; } spin_unlock_bh(&hslot->lock); @@ -2986,24 +3332,20 @@ found: static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk) { - struct udp_seq_afinfo *afinfo; struct udp_iter_state *state = seq->private; struct net *net = seq_file_net(seq); - - if (state->bpf_seq_afinfo) - afinfo = state->bpf_seq_afinfo; - else - afinfo = pde_data(file_inode(seq->file)); + struct udp_table *udptable; do { sk = sk_next(sk); - } while (sk && (!net_eq(sock_net(sk), net) || - (afinfo->family != AF_UNSPEC && - sk->sk_family != afinfo->family))); + } while (sk && !seq_sk_match(seq, sk)); if (!sk) { - if (state->bucket <= afinfo->udp_table->mask) - spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock); + udptable = udp_get_table_seq(seq, net); + + if (state->bucket <= udptable->mask) + spin_unlock_bh(&udptable->hash[state->bucket].lock); + return udp_get_first(seq, state->bucket + 1); } return sk; @@ -3026,7 +3368,7 @@ void *udp_seq_start(struct seq_file *seq, loff_t *pos) return *pos ? udp_get_idx(seq, *pos-1) : SEQ_START_TOKEN; } -EXPORT_SYMBOL(udp_seq_start); +EXPORT_IPV6_MOD(udp_seq_start); void *udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) { @@ -3040,22 +3382,19 @@ void *udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) ++*pos; return sk; } -EXPORT_SYMBOL(udp_seq_next); +EXPORT_IPV6_MOD(udp_seq_next); void udp_seq_stop(struct seq_file *seq, void *v) { - struct udp_seq_afinfo *afinfo; struct udp_iter_state *state = seq->private; + struct udp_table *udptable; - if (state->bpf_seq_afinfo) - afinfo = state->bpf_seq_afinfo; - else - afinfo = pde_data(file_inode(seq->file)); + udptable = udp_get_table_seq(seq, seq_file_net(seq)); - if (state->bucket <= afinfo->udp_table->mask) - spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock); + if (state->bucket <= udptable->mask) + spin_unlock_bh(&udptable->hash[state->bucket].lock); } -EXPORT_SYMBOL(udp_seq_stop); +EXPORT_IPV6_MOD(udp_seq_stop); /* ------------------------------------------------------------------------ */ static void udp4_format_sock(struct sock *sp, struct seq_file *f, @@ -3073,10 +3412,10 @@ static void udp4_format_sock(struct sock *sp, struct seq_file *f, sk_wmem_alloc_get(sp), udp_rqueue_get(sp), 0, 0L, 0, - from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)), + from_kuid_munged(seq_user_ns(f), sk_uid(sp)), 0, sock_i_ino(sp), refcount_read(&sp->sk_refcnt), sp, - atomic_read(&sp->sk_drops)); + sk_drops_read(sp)); } int udp4_seq_show(struct seq_file *seq, void *v) @@ -3103,6 +3442,187 @@ struct bpf_iter__udp { int bucket __aligned(8); }; +union bpf_udp_iter_batch_item { + struct sock *sk; + __u64 cookie; +}; + +struct bpf_udp_iter_state { + struct udp_iter_state state; + unsigned int cur_sk; + unsigned int end_sk; + unsigned int max_sk; + union bpf_udp_iter_batch_item *batch; +}; + +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter, + unsigned int new_batch_sz, gfp_t flags); +static struct sock *bpf_iter_udp_resume(struct sock *first_sk, + union bpf_udp_iter_batch_item *cookies, + int n_cookies) +{ + struct sock *sk = NULL; + int i; + + for (i = 0; i < n_cookies; i++) { + sk = first_sk; + udp_portaddr_for_each_entry_from(sk) + if (cookies[i].cookie == atomic64_read(&sk->sk_cookie)) + goto done; + } +done: + return sk; +} + +static struct sock *bpf_iter_udp_batch(struct seq_file *seq) +{ + struct bpf_udp_iter_state *iter = seq->private; + struct udp_iter_state *state = &iter->state; + unsigned int find_cookie, end_cookie; + struct net *net = seq_file_net(seq); + struct udp_table *udptable; + unsigned int batch_sks = 0; + int resume_bucket; + int resizes = 0; + struct sock *sk; + int err = 0; + + resume_bucket = state->bucket; + + /* The current batch is done, so advance the bucket. */ + if (iter->cur_sk == iter->end_sk) + state->bucket++; + + udptable = udp_get_table_seq(seq, net); + +again: + /* New batch for the next bucket. + * Iterate over the hash table to find a bucket with sockets matching + * the iterator attributes, and return the first matching socket from + * the bucket. The remaining matched sockets from the bucket are batched + * before releasing the bucket lock. This allows BPF programs that are + * called in seq_show to acquire the bucket lock if needed. + */ + find_cookie = iter->cur_sk; + end_cookie = iter->end_sk; + iter->cur_sk = 0; + iter->end_sk = 0; + batch_sks = 0; + + for (; state->bucket <= udptable->mask; state->bucket++) { + struct udp_hslot *hslot2 = &udptable->hash2[state->bucket].hslot; + + if (hlist_empty(&hslot2->head)) + goto next_bucket; + + spin_lock_bh(&hslot2->lock); + sk = hlist_entry_safe(hslot2->head.first, struct sock, + __sk_common.skc_portaddr_node); + /* Resume from the first (in iteration order) unseen socket from + * the last batch that still exists in resume_bucket. Most of + * the time this will just be where the last iteration left off + * in resume_bucket unless that socket disappeared between + * reads. + */ + if (state->bucket == resume_bucket) + sk = bpf_iter_udp_resume(sk, &iter->batch[find_cookie], + end_cookie - find_cookie); +fill_batch: + udp_portaddr_for_each_entry_from(sk) { + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++].sk = sk; + } + batch_sks++; + } + } + + /* Allocate a larger batch and try again. */ + if (unlikely(resizes <= 1 && iter->end_sk && + iter->end_sk != batch_sks)) { + resizes++; + + /* First, try with GFP_USER to maximize the chances of + * grabbing more memory. + */ + if (resizes == 1) { + spin_unlock_bh(&hslot2->lock); + err = bpf_iter_udp_realloc_batch(iter, + batch_sks * 3 / 2, + GFP_USER); + if (err) + return ERR_PTR(err); + /* Start over. */ + goto again; + } + + /* Next, hold onto the lock, so the bucket doesn't + * change while we get the rest of the sockets. + */ + err = bpf_iter_udp_realloc_batch(iter, batch_sks, + GFP_NOWAIT); + if (err) { + spin_unlock_bh(&hslot2->lock); + return ERR_PTR(err); + } + + /* Pick up where we left off. */ + sk = iter->batch[iter->end_sk - 1].sk; + sk = hlist_entry_safe(sk->__sk_common.skc_portaddr_node.next, + struct sock, + __sk_common.skc_portaddr_node); + batch_sks = iter->end_sk; + goto fill_batch; + } + + spin_unlock_bh(&hslot2->lock); + + if (iter->end_sk) + break; +next_bucket: + resizes = 0; + } + + WARN_ON_ONCE(iter->end_sk != batch_sks); + return iter->end_sk ? iter->batch[0].sk : NULL; +} + +static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct bpf_udp_iter_state *iter = seq->private; + struct sock *sk; + + /* Whenever seq_next() is called, the iter->cur_sk is + * done with seq_show(), so unref the iter->cur_sk. + */ + if (iter->cur_sk < iter->end_sk) + sock_put(iter->batch[iter->cur_sk++].sk); + + /* After updating iter->cur_sk, check if there are more sockets + * available in the current bucket batch. + */ + if (iter->cur_sk < iter->end_sk) + sk = iter->batch[iter->cur_sk].sk; + else + /* Prepare a new batch. */ + sk = bpf_iter_udp_batch(seq); + + ++*pos; + return sk; +} + +static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos) +{ + /* bpf iter does not support lseek, so it always + * continue from where it was stop()-ped. + */ + if (*pos) + return bpf_iter_udp_batch(seq); + + return SEQ_START_TOKEN; +} + static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, struct udp_sock *udp_sk, uid_t uid, int bucket) { @@ -3123,18 +3643,48 @@ static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v) struct bpf_prog *prog; struct sock *sk = v; uid_t uid; + int ret; if (v == SEQ_START_TOKEN) return 0; - uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)); + lock_sock(sk); + + if (unlikely(sk_unhashed(sk))) { + ret = SEQ_SKIP; + goto unlock; + } + + uid = from_kuid_munged(seq_user_ns(seq), sk_uid(sk)); meta.seq = seq; prog = bpf_iter_get_info(&meta, false); - return udp_prog_seq_show(prog, &meta, v, uid, state->bucket); + ret = udp_prog_seq_show(prog, &meta, v, uid, state->bucket); + +unlock: + release_sock(sk); + return ret; +} + +static void bpf_iter_udp_put_batch(struct bpf_udp_iter_state *iter) +{ + union bpf_udp_iter_batch_item *item; + unsigned int cur_sk = iter->cur_sk; + __u64 cookie; + + /* Remember the cookies of the sockets we haven't seen yet, so we can + * pick up where we left off next time around. + */ + while (cur_sk < iter->end_sk) { + item = &iter->batch[cur_sk++]; + cookie = sock_gen_cookie(item->sk); + sock_put(item->sk); + item->cookie = cookie; + } } static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v) { + struct bpf_udp_iter_state *iter = seq->private; struct bpf_iter_meta meta; struct bpf_prog *prog; @@ -3145,28 +3695,44 @@ static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v) (void)udp_prog_seq_show(prog, &meta, v, 0, 0); } - udp_seq_stop(seq, v); + if (iter->cur_sk < iter->end_sk) + bpf_iter_udp_put_batch(iter); } static const struct seq_operations bpf_iter_udp_seq_ops = { - .start = udp_seq_start, - .next = udp_seq_next, + .start = bpf_iter_udp_seq_start, + .next = bpf_iter_udp_seq_next, .stop = bpf_iter_udp_seq_stop, .show = bpf_iter_udp_seq_show, }; #endif +static unsigned short seq_file_family(const struct seq_file *seq) +{ + const struct udp_seq_afinfo *afinfo; + +#ifdef CONFIG_BPF_SYSCALL + /* BPF iterator: bpf programs to filter sockets. */ + if (seq->op == &bpf_iter_udp_seq_ops) + return AF_UNSPEC; +#endif + + /* Proc fs iterator */ + afinfo = pde_data(file_inode(seq->file)); + return afinfo->family; +} + const struct seq_operations udp_seq_ops = { .start = udp_seq_start, .next = udp_seq_next, .stop = udp_seq_stop, .show = udp4_seq_show, }; -EXPORT_SYMBOL(udp_seq_ops); +EXPORT_IPV6_MOD(udp_seq_ops); static struct udp_seq_afinfo udp4_seq_afinfo = { .family = AF_INET, - .udp_table = &udp_table, + .udp_table = NULL, }; static int __net_init udp4_proc_init_net(struct net *net) @@ -3218,29 +3784,32 @@ __setup("uhash_entries=", set_uhash_entries); void __init udp_table_init(struct udp_table *table, const char *name) { - unsigned int i; + unsigned int i, slot_size; + slot_size = sizeof(struct udp_hslot) + sizeof(struct udp_hslot_main) + + udp_hash4_slot_size(); table->hash = alloc_large_system_hash(name, - 2 * sizeof(struct udp_hslot), + slot_size, uhash_entries, 21, /* one slot per 2 MB */ 0, &table->log, &table->mask, UDP_HTABLE_SIZE_MIN, - 64 * 1024); + UDP_HTABLE_SIZE_MAX); - table->hash2 = table->hash + (table->mask + 1); + table->hash2 = (void *)(table->hash + (table->mask + 1)); for (i = 0; i <= table->mask; i++) { INIT_HLIST_HEAD(&table->hash[i].head); table->hash[i].count = 0; spin_lock_init(&table->hash[i].lock); } for (i = 0; i <= table->mask; i++) { - INIT_HLIST_HEAD(&table->hash2[i].head); - table->hash2[i].count = 0; - spin_lock_init(&table->hash2[i].lock); + INIT_HLIST_HEAD(&table->hash2[i].hslot.head); + table->hash2[i].hslot.count = 0; + spin_lock_init(&table->hash2[i].hslot.lock); } + udp_table_hash4_init(table); } u32 udp_flow_hashrnd(void) @@ -3253,62 +3822,184 @@ u32 udp_flow_hashrnd(void) } EXPORT_SYMBOL(udp_flow_hashrnd); -static void __udp_sysctl_init(struct net *net) +static void __net_init udp_sysctl_init(struct net *net) { - net->ipv4.sysctl_udp_rmem_min = SK_MEM_QUANTUM; - net->ipv4.sysctl_udp_wmem_min = SK_MEM_QUANTUM; + net->ipv4.sysctl_udp_rmem_min = PAGE_SIZE; + net->ipv4.sysctl_udp_wmem_min = PAGE_SIZE; #ifdef CONFIG_NET_L3_MASTER_DEV net->ipv4.sysctl_udp_l3mdev_accept = 0; #endif } -static int __net_init udp_sysctl_init(struct net *net) +static struct udp_table __net_init *udp_pernet_table_alloc(unsigned int hash_entries) { - __udp_sysctl_init(net); + struct udp_table *udptable; + unsigned int slot_size; + int i; + + udptable = kmalloc(sizeof(*udptable), GFP_KERNEL); + if (!udptable) + goto out; + + slot_size = sizeof(struct udp_hslot) + sizeof(struct udp_hslot_main) + + udp_hash4_slot_size(); + udptable->hash = vmalloc_huge(hash_entries * slot_size, + GFP_KERNEL_ACCOUNT); + if (!udptable->hash) + goto free_table; + + udptable->hash2 = (void *)(udptable->hash + hash_entries); + udptable->mask = hash_entries - 1; + udptable->log = ilog2(hash_entries); + + for (i = 0; i < hash_entries; i++) { + INIT_HLIST_HEAD(&udptable->hash[i].head); + udptable->hash[i].count = 0; + spin_lock_init(&udptable->hash[i].lock); + + INIT_HLIST_HEAD(&udptable->hash2[i].hslot.head); + udptable->hash2[i].hslot.count = 0; + spin_lock_init(&udptable->hash2[i].hslot.lock); + } + udp_table_hash4_init(udptable); + + return udptable; + +free_table: + kfree(udptable); +out: + return NULL; +} + +static void __net_exit udp_pernet_table_free(struct net *net) +{ + struct udp_table *udptable = net->ipv4.udp_table; + + if (udptable == &udp_table) + return; + + kvfree(udptable->hash); + kfree(udptable); +} + +static void __net_init udp_set_table(struct net *net) +{ + struct udp_table *udptable; + unsigned int hash_entries; + struct net *old_net; + + if (net_eq(net, &init_net)) + goto fallback; + + old_net = current->nsproxy->net_ns; + hash_entries = READ_ONCE(old_net->ipv4.sysctl_udp_child_hash_entries); + if (!hash_entries) + goto fallback; + + /* Set min to keep the bitmap on stack in udp_lib_get_port() */ + if (hash_entries < UDP_HTABLE_SIZE_MIN_PERNET) + hash_entries = UDP_HTABLE_SIZE_MIN_PERNET; + else + hash_entries = roundup_pow_of_two(hash_entries); + + udptable = udp_pernet_table_alloc(hash_entries); + if (udptable) { + net->ipv4.udp_table = udptable; + } else { + pr_warn("Failed to allocate UDP hash table (entries: %u) " + "for a netns, fallback to the global one\n", + hash_entries); +fallback: + net->ipv4.udp_table = &udp_table; + } +} + +static int __net_init udp_pernet_init(struct net *net) +{ +#if IS_ENABLED(CONFIG_NET_UDP_TUNNEL) + int i; + + /* No tunnel is configured */ + for (i = 0; i < ARRAY_SIZE(net->ipv4.udp_tunnel_gro); ++i) { + INIT_HLIST_HEAD(&net->ipv4.udp_tunnel_gro[i].list); + RCU_INIT_POINTER(net->ipv4.udp_tunnel_gro[i].sk, NULL); + } +#endif + udp_sysctl_init(net); + udp_set_table(net); + return 0; } +static void __net_exit udp_pernet_exit(struct net *net) +{ + udp_pernet_table_free(net); +} + static struct pernet_operations __net_initdata udp_sysctl_ops = { - .init = udp_sysctl_init, + .init = udp_pernet_init, + .exit = udp_pernet_exit, }; #if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta, struct udp_sock *udp_sk, uid_t uid, int bucket) -static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux) +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter, + unsigned int new_batch_sz, gfp_t flags) { - struct udp_iter_state *st = priv_data; - struct udp_seq_afinfo *afinfo; - int ret; + union bpf_udp_iter_batch_item *new_batch; - afinfo = kmalloc(sizeof(*afinfo), GFP_USER | __GFP_NOWARN); - if (!afinfo) + new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch), + flags | __GFP_NOWARN); + if (!new_batch) return -ENOMEM; - afinfo->family = AF_UNSPEC; - afinfo->udp_table = &udp_table; - st->bpf_seq_afinfo = afinfo; + if (flags != GFP_NOWAIT) + bpf_iter_udp_put_batch(iter); + + memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk); + kvfree(iter->batch); + iter->batch = new_batch; + iter->max_sk = new_batch_sz; + + return 0; +} + +#define INIT_BATCH_SZ 16 + +static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux) +{ + struct bpf_udp_iter_state *iter = priv_data; + int ret; + ret = bpf_iter_init_seq_net(priv_data, aux); if (ret) - kfree(afinfo); + return ret; + + ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ, GFP_USER); + if (ret) + bpf_iter_fini_seq_net(priv_data); + + iter->state.bucket = -1; + return ret; } static void bpf_iter_fini_udp(void *priv_data) { - struct udp_iter_state *st = priv_data; + struct bpf_udp_iter_state *iter = priv_data; - kfree(st->bpf_seq_afinfo); bpf_iter_fini_seq_net(priv_data); + kvfree(iter->batch); } static const struct bpf_iter_seq_info udp_seq_info = { .seq_ops = &bpf_iter_udp_seq_ops, .init_seq_private = bpf_iter_init_udp, .fini_seq_private = bpf_iter_fini_udp, - .seq_priv_size = sizeof(struct udp_iter_state), + .seq_priv_size = sizeof(struct bpf_udp_iter_state), }; static struct bpf_iter_reg udp_reg_info = { @@ -3316,7 +4007,7 @@ static struct bpf_iter_reg udp_reg_info = { .ctx_arg_info_size = 1, .ctx_arg_info = { { offsetof(struct bpf_iter__udp, udp_sk), - PTR_TO_BTF_ID_OR_NULL }, + PTR_TO_BTF_ID_OR_NULL | PTR_TRUSTED }, }, .seq_info = &udp_seq_info, }; @@ -3332,7 +4023,6 @@ static void __init bpf_iter_register(void) void __init udp_init(void) { unsigned long limit; - unsigned int i; udp_table_init(&udp_table, "UDP"); limit = nr_free_buffer_pages() / 8; @@ -3341,17 +4031,6 @@ void __init udp_init(void) sysctl_udp_mem[1] = limit; sysctl_udp_mem[2] = sysctl_udp_mem[0] * 2; - __udp_sysctl_init(&init_net); - - /* 16 spinlocks per cpu */ - udp_busylocks_log = ilog2(nr_cpu_ids) + 4; - udp_busylocks = kmalloc(sizeof(spinlock_t) << udp_busylocks_log, - GFP_KERNEL); - if (!udp_busylocks) - panic("UDP: failed to alloc udp_busylocks\n"); - for (i = 0; i < (1U << udp_busylocks_log); i++) - spin_lock_init(udp_busylocks + i); - if (register_pernet_subsys(&udp_sysctl_ops)) panic("UDP: failed to init sysctl parameters.\n"); diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c index bbe6569c9ad3..0735d820e413 100644 --- a/net/ipv4/udp_bpf.c +++ b/net/ipv4/udp_bpf.c @@ -11,14 +11,13 @@ static struct proto *udpv6_prot_saved __read_mostly; static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int noblock, int flags, int *addr_len) + int flags, int *addr_len) { #if IS_ENABLED(CONFIG_IPV6) if (sk->sk_family == AF_INET6) - return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags, - addr_len); + return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); #endif - return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len); + return udp_prot.recvmsg(sk, msg, len, flags, addr_len); } static bool udp_sk_has_data(struct sock *sk) @@ -61,7 +60,7 @@ static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, } static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, int *addr_len) + int flags, int *addr_len) { struct sk_psock *psock; int copied, ret; @@ -69,12 +68,15 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (unlikely(flags & MSG_ERRQUEUE)) return inet_recv_error(sk, msg, len, addr_len); + if (!len) + return 0; + psock = sk_psock_get(sk); if (unlikely(!psock)) - return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + return sk_udp_recvmsg(sk, msg, len, flags, addr_len); if (!psock_has_data(psock)) { - ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); goto out; } @@ -84,12 +86,12 @@ msg_bytes_ready: long timeo; int data; - timeo = sock_rcvtimeo(sk, nonblock); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); data = udp_msg_wait_data(sk, psock, timeo); if (data) { if (psock_has_data(psock)) goto msg_bytes_ready; - ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); goto out; } copied = -EAGAIN; @@ -142,14 +144,14 @@ int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) if (restore) { sk->sk_write_space = psock->saved_write_space; - WRITE_ONCE(sk->sk_prot, psock->sk_proto); + sock_replace_proto(sk, psock->sk_proto); return 0; } if (sk->sk_family == AF_INET6) udp_bpf_check_v6_needs_rebuild(psock->sk_proto); - WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]); + sock_replace_proto(sk, &udp_bpf_prots[family]); return 0; } EXPORT_SYMBOL_GPL(udp_bpf_update_proto); diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c index 1ed8c4d78e5c..6e491c720c90 100644 --- a/net/ipv4/udp_diag.c +++ b/net/ipv4/udp_diag.c @@ -16,9 +16,9 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *req, - struct nlattr *bc, bool net_admin) + bool net_admin) { - if (!inet_diag_bc_sk(bc, sk)) + if (!inet_diag_bc_sk(cb->data, sk)) return 0; return inet_sk_diag_fill(sk, NULL, skb, cb, req, NLM_F_MULTI, @@ -92,12 +92,8 @@ static void udp_dump(struct udp_table *table, struct sk_buff *skb, { bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); struct net *net = sock_net(skb->sk); - struct inet_diag_dump_data *cb_data; int num, s_num, slot, s_slot; - struct nlattr *bc; - cb_data = cb->data; - bc = cb_data->inet_diag_nla_bc; s_slot = cb->args[0]; num = s_num = cb->args[1]; @@ -130,7 +126,7 @@ static void udp_dump(struct udp_table *table, struct sk_buff *skb, r->id.idiag_dport) goto next; - if (sk_diag_dump(sk, skb, cb, r, bc, net_admin) < 0) { + if (sk_diag_dump(sk, skb, cb, r, net_admin) < 0) { spin_unlock_bh(&hslot->lock); goto done; } @@ -147,13 +143,13 @@ done: static void udp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r) { - udp_dump(&udp_table, skb, cb, r); + udp_dump(sock_net(cb->skb->sk)->ipv4.udp_table, skb, cb, r); } static int udp_diag_dump_one(struct netlink_callback *cb, const struct inet_diag_req_v2 *req) { - return udp_dump_one(&udp_table, cb, req); + return udp_dump_one(sock_net(cb->skb->sk)->ipv4.udp_table, cb, req); } static void udp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, @@ -225,7 +221,7 @@ static int __udp_diag_destroy(struct sk_buff *in_skb, static int udp_diag_destroy(struct sk_buff *in_skb, const struct inet_diag_req_v2 *req) { - return __udp_diag_destroy(in_skb, req, &udp_table); + return __udp_diag_destroy(in_skb, req, sock_net(in_skb->sk)->ipv4.udp_table); } static int udplite_diag_destroy(struct sk_buff *in_skb, @@ -237,6 +233,7 @@ static int udplite_diag_destroy(struct sk_buff *in_skb, #endif static const struct inet_diag_handler udp_diag_handler = { + .owner = THIS_MODULE, .dump = udp_diag_dump, .dump_one = udp_diag_dump_one, .idiag_get_info = udp_diag_get_info, @@ -260,6 +257,7 @@ static int udplite_diag_dump_one(struct netlink_callback *cb, } static const struct inet_diag_handler udplite_diag_handler = { + .owner = THIS_MODULE, .dump = udplite_diag_dump, .dump_one = udplite_diag_dump_one, .idiag_get_info = udp_diag_get_info, @@ -296,5 +294,6 @@ static void __exit udp_diag_exit(void) module_init(udp_diag_init); module_exit(udp_diag_exit); MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("UDP socket monitoring via SOCK_DIAG"); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-17 /* AF_INET - IPPROTO_UDP */); MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-136 /* AF_INET - IPPROTO_UDPLITE */); diff --git a/net/ipv4/udp_impl.h b/net/ipv4/udp_impl.h index 2878d8285caf..c7142213fc21 100644 --- a/net/ipv4/udp_impl.h +++ b/net/ipv4/udp_impl.h @@ -1,6 +1,7 @@ /* SPDX-License-Identifier: GPL-2.0 */ #ifndef _UDP4_IMPL_H #define _UDP4_IMPL_H +#include <net/aligned_data.h> #include <net/udp.h> #include <net/udplite.h> #include <net/protocol.h> @@ -17,10 +18,8 @@ int udp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, int udp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, int __user *optlen); -int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, - int flags, int *addr_len); -int udp_sendpage(struct sock *sk, struct page *page, int offset, size_t size, - int flags); +int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len); void udp_destroy_sock(struct sock *sk); #ifdef CONFIG_PROC_FS diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c index 6d1a4bec2614..19d0b5b09ffa 100644 --- a/net/ipv4/udp_offload.c +++ b/net/ipv4/udp_offload.c @@ -8,9 +8,166 @@ #include <linux/skbuff.h> #include <net/gro.h> +#include <net/gso.h> #include <net/udp.h> #include <net/protocol.h> #include <net/inet_common.h> +#include <net/udp_tunnel.h> + +#if IS_ENABLED(CONFIG_NET_UDP_TUNNEL) + +/* + * Dummy GRO tunnel callback, exists mainly to avoid dangling/NULL + * values for the udp tunnel static call. + */ +static struct sk_buff *dummy_gro_rcv(struct sock *sk, + struct list_head *head, + struct sk_buff *skb) +{ + NAPI_GRO_CB(skb)->flush = 1; + return NULL; +} + +typedef struct sk_buff *(*udp_tunnel_gro_rcv_t)(struct sock *sk, + struct list_head *head, + struct sk_buff *skb); + +struct udp_tunnel_type_entry { + udp_tunnel_gro_rcv_t gro_receive; + refcount_t count; +}; + +#define UDP_MAX_TUNNEL_TYPES (IS_ENABLED(CONFIG_GENEVE) + \ + IS_ENABLED(CONFIG_VXLAN) * 2 + \ + IS_ENABLED(CONFIG_NET_FOU) * 2 + \ + IS_ENABLED(CONFIG_XFRM) * 2) + +DEFINE_STATIC_CALL(udp_tunnel_gro_rcv, dummy_gro_rcv); +static DEFINE_STATIC_KEY_FALSE(udp_tunnel_static_call); +static DEFINE_MUTEX(udp_tunnel_gro_type_lock); +static struct udp_tunnel_type_entry udp_tunnel_gro_types[UDP_MAX_TUNNEL_TYPES]; +static unsigned int udp_tunnel_gro_type_nr; +static DEFINE_SPINLOCK(udp_tunnel_gro_lock); + +void udp_tunnel_update_gro_lookup(struct net *net, struct sock *sk, bool add) +{ + bool is_ipv6 = sk->sk_family == AF_INET6; + struct udp_sock *tup, *up = udp_sk(sk); + struct udp_tunnel_gro *udp_tunnel_gro; + + spin_lock(&udp_tunnel_gro_lock); + udp_tunnel_gro = &net->ipv4.udp_tunnel_gro[is_ipv6]; + if (add) + hlist_add_head(&up->tunnel_list, &udp_tunnel_gro->list); + else if (up->tunnel_list.pprev) + hlist_del_init(&up->tunnel_list); + + if (udp_tunnel_gro->list.first && + !udp_tunnel_gro->list.first->next) { + tup = hlist_entry(udp_tunnel_gro->list.first, struct udp_sock, + tunnel_list); + + rcu_assign_pointer(udp_tunnel_gro->sk, (struct sock *)tup); + } else { + RCU_INIT_POINTER(udp_tunnel_gro->sk, NULL); + } + + spin_unlock(&udp_tunnel_gro_lock); +} +EXPORT_SYMBOL_GPL(udp_tunnel_update_gro_lookup); + +void udp_tunnel_update_gro_rcv(struct sock *sk, bool add) +{ + struct udp_tunnel_type_entry *cur = NULL; + struct udp_sock *up = udp_sk(sk); + int i, old_gro_type_nr; + + if (!UDP_MAX_TUNNEL_TYPES || !up->gro_receive) + return; + + mutex_lock(&udp_tunnel_gro_type_lock); + + /* Check if the static call is permanently disabled. */ + if (udp_tunnel_gro_type_nr > UDP_MAX_TUNNEL_TYPES) + goto out; + + for (i = 0; i < udp_tunnel_gro_type_nr; i++) + if (udp_tunnel_gro_types[i].gro_receive == up->gro_receive) + cur = &udp_tunnel_gro_types[i]; + + old_gro_type_nr = udp_tunnel_gro_type_nr; + if (add) { + /* + * Update the matching entry, if found, or add a new one + * if needed + */ + if (cur) { + refcount_inc(&cur->count); + goto out; + } + + if (unlikely(udp_tunnel_gro_type_nr == UDP_MAX_TUNNEL_TYPES)) { + pr_err_once("Too many UDP tunnel types, please increase UDP_MAX_TUNNEL_TYPES\n"); + /* Ensure static call will never be enabled */ + udp_tunnel_gro_type_nr = UDP_MAX_TUNNEL_TYPES + 1; + } else { + cur = &udp_tunnel_gro_types[udp_tunnel_gro_type_nr++]; + refcount_set(&cur->count, 1); + cur->gro_receive = up->gro_receive; + } + } else { + /* + * The stack cleanups only successfully added tunnel, the + * lookup on removal should never fail. + */ + if (WARN_ON_ONCE(!cur)) + goto out; + + if (!refcount_dec_and_test(&cur->count)) + goto out; + + /* Avoid gaps, so that the enable tunnel has always id 0 */ + *cur = udp_tunnel_gro_types[--udp_tunnel_gro_type_nr]; + } + + if (udp_tunnel_gro_type_nr == 1) { + static_call_update(udp_tunnel_gro_rcv, + udp_tunnel_gro_types[0].gro_receive); + static_branch_enable(&udp_tunnel_static_call); + } else if (old_gro_type_nr == 1) { + static_branch_disable(&udp_tunnel_static_call); + static_call_update(udp_tunnel_gro_rcv, dummy_gro_rcv); + } + +out: + mutex_unlock(&udp_tunnel_gro_type_lock); +} +EXPORT_SYMBOL_GPL(udp_tunnel_update_gro_rcv); + +static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk, + struct list_head *head, + struct sk_buff *skb) +{ + if (static_branch_likely(&udp_tunnel_static_call)) { + if (unlikely(gro_recursion_inc_test(skb))) { + NAPI_GRO_CB(skb)->flush |= 1; + return NULL; + } + return static_call(udp_tunnel_gro_rcv)(sk, head, skb); + } + return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb); +} + +#else + +static struct sk_buff *udp_tunnel_gro_rcv(struct sock *sk, + struct list_head *head, + struct sk_buff *skb) +{ + return call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb); +} + +#endif static struct sk_buff *__skb_udp_tunnel_segment(struct sk_buff *skb, netdev_features_t features, @@ -60,7 +217,7 @@ static struct sk_buff *__skb_udp_tunnel_segment(struct sk_buff *skb, remcsum = !!(skb_shinfo(skb)->gso_type & SKB_GSO_TUNNEL_REMCSUM); skb->remcsum_offload = remcsum; - need_ipsec = skb_dst(skb) && dst_xfrm(skb_dst(skb)); + need_ipsec = (skb_dst(skb) && dst_xfrm(skb_dst(skb))) || skb_sec_path(skb); /* Try to offload checksum if possible */ offload_csum = !!(need_csum && !need_ipsec && @@ -246,6 +403,62 @@ static struct sk_buff *__udpv4_gso_segment_list_csum(struct sk_buff *segs) return segs; } +static void __udpv6_gso_segment_csum(struct sk_buff *seg, + struct in6_addr *oldip, + const struct in6_addr *newip, + __be16 *oldport, __be16 newport) +{ + struct udphdr *uh = udp_hdr(seg); + + if (ipv6_addr_equal(oldip, newip) && *oldport == newport) + return; + + if (uh->check) { + inet_proto_csum_replace16(&uh->check, seg, oldip->s6_addr32, + newip->s6_addr32, true); + + inet_proto_csum_replace2(&uh->check, seg, *oldport, newport, + false); + if (!uh->check) + uh->check = CSUM_MANGLED_0; + } + + *oldip = *newip; + *oldport = newport; +} + +static struct sk_buff *__udpv6_gso_segment_list_csum(struct sk_buff *segs) +{ + const struct ipv6hdr *iph; + const struct udphdr *uh; + struct ipv6hdr *iph2; + struct sk_buff *seg; + struct udphdr *uh2; + + seg = segs; + uh = udp_hdr(seg); + iph = ipv6_hdr(seg); + uh2 = udp_hdr(seg->next); + iph2 = ipv6_hdr(seg->next); + + if (!(*(const u32 *)&uh->source ^ *(const u32 *)&uh2->source) && + ipv6_addr_equal(&iph->saddr, &iph2->saddr) && + ipv6_addr_equal(&iph->daddr, &iph2->daddr)) + return segs; + + while ((seg = seg->next)) { + uh2 = udp_hdr(seg); + iph2 = ipv6_hdr(seg); + + __udpv6_gso_segment_csum(seg, &iph2->saddr, &iph->saddr, + &uh2->source, uh->source); + __udpv6_gso_segment_csum(seg, &iph2->daddr, &iph->daddr, + &uh2->dest, uh->dest); + } + + return segs; +} + static struct sk_buff *__udp_gso_segment_list(struct sk_buff *skb, netdev_features_t features, bool is_ipv6) @@ -258,7 +471,10 @@ static struct sk_buff *__udp_gso_segment_list(struct sk_buff *skb, udp_hdr(skb)->len = htons(sizeof(struct udphdr) + mss); - return is_ipv6 ? skb : __udpv4_gso_segment_list_csum(skb); + if (is_ipv6) + return __udpv6_gso_segment_list_csum(skb); + else + return __udpv4_gso_segment_list_csum(skb); } struct sk_buff *__udp_gso_segment(struct sk_buff *gso_skb, @@ -272,25 +488,70 @@ struct sk_buff *__udp_gso_segment(struct sk_buff *gso_skb, bool copy_dtor; __sum16 check; __be16 newlen; - - if (skb_shinfo(gso_skb)->gso_type & SKB_GSO_FRAGLIST) - return __udp_gso_segment_list(gso_skb, features, is_ipv6); + int ret = 0; mss = skb_shinfo(gso_skb)->gso_size; if (gso_skb->len <= sizeof(*uh) + mss) return ERR_PTR(-EINVAL); + if (unlikely(skb_checksum_start(gso_skb) != + skb_transport_header(gso_skb) && + !(skb_shinfo(gso_skb)->gso_type & SKB_GSO_FRAGLIST))) + return ERR_PTR(-EINVAL); + + /* We don't know if egress device can segment and checksum the packet + * when IPv6 extension headers are present. Fall back to software GSO. + */ + if (gso_skb->ip_summed != CHECKSUM_PARTIAL) + features &= ~(NETIF_F_GSO_UDP_L4 | NETIF_F_CSUM_MASK); + + if (skb_gso_ok(gso_skb, features | NETIF_F_GSO_ROBUST)) { + /* Packet is from an untrusted source, reset gso_segs. */ + skb_shinfo(gso_skb)->gso_segs = DIV_ROUND_UP(gso_skb->len - sizeof(*uh), + mss); + return NULL; + } + + if (skb_shinfo(gso_skb)->gso_type & SKB_GSO_FRAGLIST) { + /* Detect modified geometry and pass those to skb_segment. */ + if (skb_pagelen(gso_skb) - sizeof(*uh) == skb_shinfo(gso_skb)->gso_size) + return __udp_gso_segment_list(gso_skb, features, is_ipv6); + + ret = __skb_linearize(gso_skb); + if (ret) + return ERR_PTR(ret); + + /* Setup csum, as fraglist skips this in udp4_gro_receive. */ + gso_skb->csum_start = skb_transport_header(gso_skb) - gso_skb->head; + gso_skb->csum_offset = offsetof(struct udphdr, check); + gso_skb->ip_summed = CHECKSUM_PARTIAL; + + uh = udp_hdr(gso_skb); + if (is_ipv6) + uh->check = ~udp_v6_check(gso_skb->len, + &ipv6_hdr(gso_skb)->saddr, + &ipv6_hdr(gso_skb)->daddr, 0); + else + uh->check = ~udp_v4_check(gso_skb->len, + ip_hdr(gso_skb)->saddr, + ip_hdr(gso_skb)->daddr, 0); + } + skb_pull(gso_skb, sizeof(*uh)); /* clear destructor to avoid skb_segment assigning it to tail */ copy_dtor = gso_skb->destructor == sock_wfree; - if (copy_dtor) + if (copy_dtor) { gso_skb->destructor = NULL; + gso_skb->sk = NULL; + } segs = skb_segment(gso_skb, features); if (IS_ERR_OR_NULL(segs)) { - if (copy_dtor) + if (copy_dtor) { gso_skb->destructor = sock_wfree; + gso_skb->sk = sk; + } return segs; } @@ -349,6 +610,14 @@ struct sk_buff *__udp_gso_segment(struct sk_buff *gso_skb, else uh->check = gso_make_checksum(seg, ~check) ? : CSUM_MANGLED_0; + /* On the TX path, CHECKSUM_NONE and CHECKSUM_UNNECESSARY have the same + * meaning. However, check for bad offloads in the GSO stack expects the + * latter, if the checksum was calculated in software. To vouch for the + * segment skbs we actually need to set it on the gso_skb. + */ + if (gso_skb->ip_summed == CHECKSUM_NONE) + gso_skb->ip_summed = CHECKSUM_UNNECESSARY; + /* update refcount for the packet */ if (copy_dtor) { int delta = sum_truesize - gso_skb->truesize; @@ -425,32 +694,6 @@ out: return segs; } -static int skb_gro_receive_list(struct sk_buff *p, struct sk_buff *skb) -{ - if (unlikely(p->len + skb->len >= 65536)) - return -E2BIG; - - if (NAPI_GRO_CB(p)->last == p) - skb_shinfo(p)->frag_list = skb; - else - NAPI_GRO_CB(p)->last->next = skb; - - skb_pull(skb, skb_gro_offset(skb)); - - NAPI_GRO_CB(p)->last = skb; - NAPI_GRO_CB(p)->count++; - p->data_len += skb->len; - - /* sk owenrship - if any - completely transferred to the aggregated packet */ - skb->destructor = NULL; - p->truesize += skb->truesize; - p->len += skb->len; - - NAPI_GRO_CB(skb)->same_flow = 1; - - return 0; -} - #define UDP_GRO_CNT_MAX 64 static struct sk_buff *udp_gro_receive_segment(struct list_head *head, @@ -462,6 +705,7 @@ static struct sk_buff *udp_gro_receive_segment(struct list_head *head, struct sk_buff *p; unsigned int ulen; int ret = 0; + int flush; /* requires non zero csum, for symmetry with GSO */ if (!uh->check) { @@ -495,13 +739,15 @@ static struct sk_buff *udp_gro_receive_segment(struct list_head *head, return p; } + flush = gro_receive_network_flush(uh, uh2, p); + /* Terminate the flow on len mismatch or if it grow "too much". * Under small packet flood GRO count could elsewhere grow a lot * leading to excessive truesize values. * On len mismatch merge the first packet shorter than gso_size, * otherwise complete the GRO packet. */ - if (ulen > ntohs(uh2->len)) { + if (ulen > ntohs(uh2->len) || flush) { pp = p; } else { if (NAPI_GRO_CB(skb)->is_flist) { @@ -514,6 +760,7 @@ static struct sk_buff *udp_gro_receive_segment(struct list_head *head, NAPI_GRO_CB(skb)->flush = 1; return NULL; } + skb_set_network_header(skb, skb_gro_receive_network_offset(skb)); ret = skb_gro_receive_list(p, skb); } else { skb_gro_postpull_rcsum(skb, uh, @@ -543,16 +790,24 @@ struct sk_buff *udp_gro_receive(struct list_head *head, struct sk_buff *skb, unsigned int off = skb_gro_offset(skb); int flush = 1; - /* we can do L4 aggregation only if the packet can't land in a tunnel - * otherwise we could corrupt the inner stream + /* We can do L4 aggregation only if the packet can't land in a tunnel + * otherwise we could corrupt the inner stream. Detecting such packets + * cannot be foolproof and the aggregation might still happen in some + * cases. Such packets should be caught in udp_unexpected_gso later. */ NAPI_GRO_CB(skb)->is_flist = 0; if (!sk || !udp_sk(sk)->gro_receive) { + /* If the packet was locally encapsulated in a UDP tunnel that + * wasn't detected above, do not GRO. + */ + if (skb->encapsulation) + goto out; + if (skb->dev->features & NETIF_F_GRO_FRAGLIST) - NAPI_GRO_CB(skb)->is_flist = sk ? !udp_sk(sk)->gro_enabled : 1; + NAPI_GRO_CB(skb)->is_flist = sk ? !udp_test_bit(GRO_ENABLED, sk) : 1; if ((!sk && (skb->dev->features & NETIF_F_GRO_UDP_FWD)) || - (sk && udp_sk(sk)->gro_enabled) || NAPI_GRO_CB(skb)->is_flist) + (sk && udp_test_bit(GRO_ENABLED, sk)) || NAPI_GRO_CB(skb)->is_flist) return call_gro_receive(udp_gro_receive_segment, head, skb); /* no GRO, be sure flush the current packet */ @@ -588,7 +843,7 @@ struct sk_buff *udp_gro_receive(struct list_head *head, struct sk_buff *skb, skb_gro_pull(skb, sizeof(struct udphdr)); /* pull encapsulating udp header */ skb_gro_postpull_rcsum(skb, uh, sizeof(struct udphdr)); - pp = call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb); + pp = udp_tunnel_gro_rcv(sk, head, skb); out: skb_gro_flush_final(skb, pp, flush); @@ -600,10 +855,19 @@ static struct sock *udp4_gro_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport) { const struct iphdr *iph = skb_gro_network_header(skb); + struct net *net = dev_net_rcu(skb->dev); + struct sock *sk; + int iif, sdif; - return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, - iph->daddr, dport, inet_iif(skb), - inet_sdif(skb), &udp_table, NULL); + sk = udp_tunnel_sk(net, false); + if (sk && dport == htons(sk->sk_num)) + return sk; + + inet_get_iif_sdif(skb, &iif, &sdif); + + return __udp4_lib_lookup(net, iph->saddr, sport, + iph->daddr, dport, iif, + sdif, net->ipv4.udp_table, NULL); } INDIRECT_CALLABLE_SCOPE @@ -627,8 +891,6 @@ struct sk_buff *udp4_gro_receive(struct list_head *head, struct sk_buff *skb) skb_gro_checksum_try_convert(skb, IPPROTO_UDP, inet_gro_compute_pseudo); skip: - NAPI_GRO_CB(skb)->is_ipv6 = 0; - if (static_branch_unlikely(&udp_encap_needed_key)) sk = udp4_gro_lookup_skb(skb, uh->source, uh->dest); @@ -697,7 +959,8 @@ EXPORT_SYMBOL(udp_gro_complete); INDIRECT_CALLABLE_SCOPE int udp4_gro_complete(struct sk_buff *skb, int nhoff) { - const struct iphdr *iph = ip_hdr(skb); + const u16 offset = NAPI_GRO_CB(skb)->network_offsets[skb->encapsulation]; + const struct iphdr *iph = (struct iphdr *)(skb->data + offset); struct udphdr *uh = (struct udphdr *)(skb->data + nhoff); /* do fraglist only if there is no outer UDP encap (or we already processed it) */ @@ -707,13 +970,7 @@ INDIRECT_CALLABLE_SCOPE int udp4_gro_complete(struct sk_buff *skb, int nhoff) skb_shinfo(skb)->gso_type |= (SKB_GSO_FRAGLIST|SKB_GSO_UDP_L4); skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; - if (skb->ip_summed == CHECKSUM_UNNECESSARY) { - if (skb->csum_level < SKB_MAX_CSUM_LEVEL) - skb->csum_level++; - } else { - skb->ip_summed = CHECKSUM_UNNECESSARY; - skb->csum_level = 0; - } + __skb_incr_checksum_unnecessary(skb); return 0; } @@ -725,15 +982,15 @@ INDIRECT_CALLABLE_SCOPE int udp4_gro_complete(struct sk_buff *skb, int nhoff) return udp_gro_complete(skb, nhoff, udp4_lib_lookup_skb); } -static const struct net_offload udpv4_offload = { - .callbacks = { - .gso_segment = udp4_ufo_fragment, - .gro_receive = udp4_gro_receive, - .gro_complete = udp4_gro_complete, - }, -}; - int __init udpv4_offload_init(void) { - return inet_add_offload(&udpv4_offload, IPPROTO_UDP); + net_hotdata.udpv4_offload = (struct net_offload) { + .callbacks = { + .gso_segment = udp4_ufo_fragment, + .gro_receive = udp4_gro_receive, + .gro_complete = udp4_gro_complete, + }, + }; + + return inet_add_offload(&net_hotdata.udpv4_offload, IPPROTO_UDP); } diff --git a/net/ipv4/udp_tunnel_core.c b/net/ipv4/udp_tunnel_core.c index 8efaf8c3fe2a..b1f667c52cb2 100644 --- a/net/ipv4/udp_tunnel_core.c +++ b/net/ipv4/udp_tunnel_core.c @@ -4,8 +4,10 @@ #include <linux/socket.h> #include <linux/kernel.h> #include <net/dst_metadata.h> +#include <net/flow.h> #include <net/udp.h> #include <net/udp_tunnel.h> +#include <net/inet_dscp.h> int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, struct socket **sockp) @@ -27,7 +29,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, udp_addr.sin_family = AF_INET; udp_addr.sin_addr = cfg->local_ip; udp_addr.sin_port = cfg->local_udp_port; - err = kernel_bind(sock, (struct sockaddr *)&udp_addr, + err = kernel_bind(sock, (struct sockaddr_unsized *)&udp_addr, sizeof(udp_addr)); if (err < 0) goto error; @@ -36,7 +38,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, udp_addr.sin_family = AF_INET; udp_addr.sin_addr = cfg->peer_ip; udp_addr.sin_port = cfg->peer_udp_port; - err = kernel_connect(sock, (struct sockaddr *)&udp_addr, + err = kernel_connect(sock, (struct sockaddr_unsized *)&udp_addr, sizeof(udp_addr), 0); if (err < 0) goto error; @@ -57,13 +59,22 @@ error: } EXPORT_SYMBOL(udp_sock_create4); +static bool sk_saddr_any(struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + return ipv6_addr_any(&sk->sk_v6_rcv_saddr); +#else + return !sk->sk_rcv_saddr; +#endif +} + void setup_udp_tunnel_sock(struct net *net, struct socket *sock, struct udp_tunnel_sock_cfg *cfg) { struct sock *sk = sock->sk; /* Disable multicast loopback */ - inet_sk(sk)->mc_loop = 0; + inet_clear_bit(MC_LOOP, sk); /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */ inet_inc_convert_csum(sk); @@ -72,12 +83,19 @@ void setup_udp_tunnel_sock(struct net *net, struct socket *sock, udp_sk(sk)->encap_type = cfg->encap_type; udp_sk(sk)->encap_rcv = cfg->encap_rcv; + udp_sk(sk)->encap_err_rcv = cfg->encap_err_rcv; udp_sk(sk)->encap_err_lookup = cfg->encap_err_lookup; udp_sk(sk)->encap_destroy = cfg->encap_destroy; udp_sk(sk)->gro_receive = cfg->gro_receive; udp_sk(sk)->gro_complete = cfg->gro_complete; - udp_tunnel_encap_enable(sock); + udp_tunnel_encap_enable(sk); + + udp_tunnel_update_gro_rcv(sk, true); + + if (!sk->sk_dport && !sk->sk_bound_dev_if && sk_saddr_any(sk) && + sk->sk_kern_sock) + udp_tunnel_update_gro_lookup(net, sk, true); } EXPORT_SYMBOL_GPL(setup_udp_tunnel_sock); @@ -117,15 +135,17 @@ void udp_tunnel_notify_add_rx_port(struct socket *sock, unsigned short type) struct udp_tunnel_info ti; struct net_device *dev; + ASSERT_RTNL(); + ti.type = type; ti.sa_family = sk->sk_family; ti.port = inet_sk(sk)->inet_sport; - rcu_read_lock(); - for_each_netdev_rcu(net, dev) { + for_each_netdev(net, dev) { + udp_tunnel_nic_lock(dev); udp_tunnel_nic_add_port(dev, &ti); + udp_tunnel_nic_unlock(dev); } - rcu_read_unlock(); } EXPORT_SYMBOL_GPL(udp_tunnel_notify_add_rx_port); @@ -137,22 +157,24 @@ void udp_tunnel_notify_del_rx_port(struct socket *sock, unsigned short type) struct udp_tunnel_info ti; struct net_device *dev; + ASSERT_RTNL(); + ti.type = type; ti.sa_family = sk->sk_family; ti.port = inet_sk(sk)->inet_sport; - rcu_read_lock(); - for_each_netdev_rcu(net, dev) { + for_each_netdev(net, dev) { + udp_tunnel_nic_lock(dev); udp_tunnel_nic_del_port(dev, &ti); + udp_tunnel_nic_unlock(dev); } - rcu_read_unlock(); } EXPORT_SYMBOL_GPL(udp_tunnel_notify_del_rx_port); void udp_tunnel_xmit_skb(struct rtable *rt, struct sock *sk, struct sk_buff *skb, __be32 src, __be32 dst, __u8 tos, __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port, - bool xnet, bool nocheck) + bool xnet, bool nocheck, u16 ipcb_flags) { struct udphdr *uh; @@ -168,20 +190,23 @@ void udp_tunnel_xmit_skb(struct rtable *rt, struct sock *sk, struct sk_buff *skb udp_set_csum(nocheck, skb, src, dst, skb->len); - iptunnel_xmit(sk, rt, skb, src, dst, IPPROTO_UDP, tos, ttl, df, xnet); + iptunnel_xmit(sk, rt, skb, src, dst, IPPROTO_UDP, tos, ttl, df, xnet, + ipcb_flags); } EXPORT_SYMBOL_GPL(udp_tunnel_xmit_skb); void udp_tunnel_sock_release(struct socket *sock) { rcu_assign_sk_user_data(sock->sk, NULL); + synchronize_rcu(); kernel_sock_shutdown(sock, SHUT_RDWR); sock_release(sock); } EXPORT_SYMBOL_GPL(udp_tunnel_sock_release); struct metadata_dst *udp_tun_rx_dst(struct sk_buff *skb, unsigned short family, - __be16 flags, __be64 tunnel_id, int md_size) + const unsigned long *flags, + __be64 tunnel_id, int md_size) { struct metadata_dst *tun_dst; struct ip_tunnel_info *info; @@ -197,9 +222,59 @@ struct metadata_dst *udp_tun_rx_dst(struct sk_buff *skb, unsigned short family, info->key.tp_src = udp_hdr(skb)->source; info->key.tp_dst = udp_hdr(skb)->dest; if (udp_hdr(skb)->check) - info->key.tun_flags |= TUNNEL_CSUM; + __set_bit(IP_TUNNEL_CSUM_BIT, info->key.tun_flags); return tun_dst; } EXPORT_SYMBOL_GPL(udp_tun_rx_dst); +struct rtable *udp_tunnel_dst_lookup(struct sk_buff *skb, + struct net_device *dev, + struct net *net, int oif, + __be32 *saddr, + const struct ip_tunnel_key *key, + __be16 sport, __be16 dport, u8 tos, + struct dst_cache *dst_cache) +{ + struct rtable *rt = NULL; + struct flowi4 fl4; + +#ifdef CONFIG_DST_CACHE + if (dst_cache) { + rt = dst_cache_get_ip4(dst_cache, saddr); + if (rt) + return rt; + } +#endif + + memset(&fl4, 0, sizeof(fl4)); + fl4.flowi4_mark = skb->mark; + fl4.flowi4_proto = IPPROTO_UDP; + fl4.flowi4_oif = oif; + fl4.daddr = key->u.ipv4.dst; + fl4.saddr = key->u.ipv4.src; + fl4.fl4_dport = dport; + fl4.fl4_sport = sport; + fl4.flowi4_dscp = inet_dsfield_to_dscp(tos); + fl4.flowi4_flags = key->flow_flags; + + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) { + netdev_dbg(dev, "no route to %pI4\n", &fl4.daddr); + return ERR_PTR(-ENETUNREACH); + } + if (rt->dst.dev == dev) { /* is this necessary? */ + netdev_dbg(dev, "circular route to %pI4\n", &fl4.daddr); + ip_rt_put(rt); + return ERR_PTR(-ELOOP); + } +#ifdef CONFIG_DST_CACHE + if (dst_cache) + dst_cache_set_ip4(dst_cache, &rt->dst, fl4.saddr); +#endif + *saddr = fl4.saddr; + return rt; +} +EXPORT_SYMBOL_GPL(udp_tunnel_dst_lookup); + +MODULE_DESCRIPTION("IPv4 Foo over UDP tunnel driver"); MODULE_LICENSE("GPL"); diff --git a/net/ipv4/udp_tunnel_nic.c b/net/ipv4/udp_tunnel_nic.c index b91003538d87..944b3cf25468 100644 --- a/net/ipv4/udp_tunnel_nic.c +++ b/net/ipv4/udp_tunnel_nic.c @@ -29,6 +29,7 @@ struct udp_tunnel_nic_table_entry { * struct udp_tunnel_nic - UDP tunnel port offload state * @work: async work for talking to hardware from process context * @dev: netdev pointer + * @lock: protects all fields * @need_sync: at least one port start changed * @need_replay: space was freed, we need a replay of all ports * @work_pending: @work is currently scheduled @@ -41,13 +42,15 @@ struct udp_tunnel_nic { struct net_device *dev; + struct mutex lock; + u8 need_sync:1; u8 need_replay:1; u8 work_pending:1; unsigned int n_tables; unsigned long missed; - struct udp_tunnel_nic_table_entry **entries; + struct udp_tunnel_nic_table_entry *entries[] __counted_by(n_tables); }; /* We ensure all work structs are done using driver state, but not the code. @@ -298,22 +301,11 @@ __udp_tunnel_nic_device_sync(struct net_device *dev, struct udp_tunnel_nic *utn) static void udp_tunnel_nic_device_sync(struct net_device *dev, struct udp_tunnel_nic *utn) { - const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; - bool may_sleep; - if (!utn->need_sync) return; - /* Drivers which sleep in the callback need to update from - * the workqueue, if we come from the tunnel driver's notification. - */ - may_sleep = info->flags & UDP_TUNNEL_NIC_INFO_MAY_SLEEP; - if (!may_sleep) - __udp_tunnel_nic_device_sync(dev, utn); - if (may_sleep || utn->need_replay) { - queue_work(udp_tunnel_nic_workqueue, &utn->work); - utn->work_pending = 1; - } + queue_work(udp_tunnel_nic_workqueue, &utn->work); + utn->work_pending = 1; } static bool @@ -554,12 +546,12 @@ static void __udp_tunnel_nic_reset_ntf(struct net_device *dev) struct udp_tunnel_nic *utn; unsigned int i, j; - ASSERT_RTNL(); - utn = dev->udp_tunnel_nic; if (!utn) return; + mutex_lock(&utn->lock); + utn->need_sync = false; for (i = 0; i < utn->n_tables; i++) for (j = 0; j < info->tables[i].n_entries; j++) { @@ -569,7 +561,7 @@ static void __udp_tunnel_nic_reset_ntf(struct net_device *dev) entry->flags &= ~(UDP_TUNNEL_NIC_ENTRY_DEL | UDP_TUNNEL_NIC_ENTRY_OP_FAIL); - /* We don't release rtnl across ops */ + /* We don't release utn lock across ops */ WARN_ON(entry->flags & UDP_TUNNEL_NIC_ENTRY_FROZEN); if (!entry->use_cnt) continue; @@ -579,6 +571,8 @@ static void __udp_tunnel_nic_reset_ntf(struct net_device *dev) } __udp_tunnel_nic_device_sync(dev, utn); + + mutex_unlock(&utn->lock); } static size_t @@ -624,6 +618,8 @@ __udp_tunnel_nic_dump_write(struct net_device *dev, unsigned int table, continue; nest = nla_nest_start(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_ENTRY); + if (!nest) + return -EMSGSIZE; if (nla_put_be16(skb, ETHTOOL_A_TUNNEL_UDP_ENTRY_PORT, utn->entries[table][j].port) || @@ -641,6 +637,33 @@ err_cancel: return -EMSGSIZE; } +static void __udp_tunnel_nic_assert_locked(struct net_device *dev) +{ + struct udp_tunnel_nic *utn; + + utn = dev->udp_tunnel_nic; + if (utn) + lockdep_assert_held(&utn->lock); +} + +static void __udp_tunnel_nic_lock(struct net_device *dev) +{ + struct udp_tunnel_nic *utn; + + utn = dev->udp_tunnel_nic; + if (utn) + mutex_lock(&utn->lock); +} + +static void __udp_tunnel_nic_unlock(struct net_device *dev) +{ + struct udp_tunnel_nic *utn; + + utn = dev->udp_tunnel_nic; + if (utn) + mutex_unlock(&utn->lock); +} + static const struct udp_tunnel_nic_ops __udp_tunnel_nic_ops = { .get_port = __udp_tunnel_nic_get_port, .set_port_priv = __udp_tunnel_nic_set_port_priv, @@ -649,6 +672,9 @@ static const struct udp_tunnel_nic_ops __udp_tunnel_nic_ops = { .reset_ntf = __udp_tunnel_nic_reset_ntf, .dump_size = __udp_tunnel_nic_dump_size, .dump_write = __udp_tunnel_nic_dump_write, + .assert_locked = __udp_tunnel_nic_assert_locked, + .lock = __udp_tunnel_nic_lock, + .unlock = __udp_tunnel_nic_unlock, }; static void @@ -708,11 +734,15 @@ static void udp_tunnel_nic_device_sync_work(struct work_struct *work) container_of(work, struct udp_tunnel_nic, work); rtnl_lock(); + mutex_lock(&utn->lock); + utn->work_pending = 0; __udp_tunnel_nic_device_sync(utn->dev, utn); if (utn->need_replay) udp_tunnel_nic_replay(utn->dev, utn); + + mutex_unlock(&utn->lock); rtnl_unlock(); } @@ -723,15 +753,12 @@ udp_tunnel_nic_alloc(const struct udp_tunnel_nic_info *info, struct udp_tunnel_nic *utn; unsigned int i; - utn = kzalloc(sizeof(*utn), GFP_KERNEL); + utn = kzalloc(struct_size(utn, entries, n_tables), GFP_KERNEL); if (!utn) return NULL; utn->n_tables = n_tables; INIT_WORK(&utn->work, udp_tunnel_nic_device_sync_work); - - utn->entries = kmalloc_array(n_tables, sizeof(void *), GFP_KERNEL); - if (!utn->entries) - goto err_free_utn; + mutex_init(&utn->lock); for (i = 0; i < n_tables; i++) { utn->entries[i] = kcalloc(info->tables[i].n_entries, @@ -745,8 +772,6 @@ udp_tunnel_nic_alloc(const struct udp_tunnel_nic_info *info, err_free_prev_entries: while (i--) kfree(utn->entries[i]); - kfree(utn->entries); -err_free_utn: kfree(utn); return NULL; } @@ -757,7 +782,6 @@ static void udp_tunnel_nic_free(struct udp_tunnel_nic *utn) for (i = 0; i < utn->n_tables; i++) kfree(utn->entries[i]); - kfree(utn->entries); kfree(utn); } @@ -826,8 +850,11 @@ static int udp_tunnel_nic_register(struct net_device *dev) dev_hold(dev); dev->udp_tunnel_nic = utn; - if (!(info->flags & UDP_TUNNEL_NIC_INFO_OPEN_ONLY)) + if (!(info->flags & UDP_TUNNEL_NIC_INFO_OPEN_ONLY)) { + udp_tunnel_nic_lock(dev); udp_tunnel_get_rx_info(dev); + udp_tunnel_nic_unlock(dev); + } return 0; } @@ -837,6 +864,8 @@ udp_tunnel_nic_unregister(struct net_device *dev, struct udp_tunnel_nic *utn) { const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + udp_tunnel_nic_lock(dev); + /* For a shared table remove this dev from the list of sharing devices * and if there are other devices just detach. */ @@ -846,8 +875,10 @@ udp_tunnel_nic_unregister(struct net_device *dev, struct udp_tunnel_nic *utn) list_for_each_entry(node, &info->shared->devices, list) if (node->dev == dev) break; - if (node->dev != dev) + if (list_entry_is_head(node, &info->shared->devices, list)) { + udp_tunnel_nic_unlock(dev); return; + } list_del(&node->list); kfree(node); @@ -857,6 +888,7 @@ udp_tunnel_nic_unregister(struct net_device *dev, struct udp_tunnel_nic *utn) if (first) { udp_tunnel_drop_rx_info(dev); utn->dev = first->dev; + udp_tunnel_nic_unlock(dev); goto release_dev; } @@ -867,6 +899,7 @@ udp_tunnel_nic_unregister(struct net_device *dev, struct udp_tunnel_nic *utn) * from the work which we will boot immediately. */ udp_tunnel_nic_flush(dev, utn); + udp_tunnel_nic_unlock(dev); /* Wait for the work to be done using the state, netdev core will * retry unregister until we give up our reference on this device. @@ -897,7 +930,7 @@ udp_tunnel_nic_netdevice_event(struct notifier_block *unused, err = udp_tunnel_nic_register(dev); if (err) - netdev_WARN(dev, "failed to register for UDP tunnel offloads: %d", err); + netdev_warn(dev, "failed to register for UDP tunnel offloads: %d", err); return notifier_from_errno(err); } /* All other events will need the udp_tunnel_nic state */ @@ -915,12 +948,16 @@ udp_tunnel_nic_netdevice_event(struct notifier_block *unused, return NOTIFY_DONE; if (event == NETDEV_UP) { + udp_tunnel_nic_lock(dev); WARN_ON(!udp_tunnel_nic_is_empty(dev, utn)); udp_tunnel_get_rx_info(dev); + udp_tunnel_nic_unlock(dev); return NOTIFY_OK; } if (event == NETDEV_GOING_DOWN) { + udp_tunnel_nic_lock(dev); udp_tunnel_nic_flush(dev, utn); + udp_tunnel_nic_unlock(dev); return NOTIFY_OK; } diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c index cd1cd68adeec..d3e621a11a1a 100644 --- a/net/ipv4/udplite.c +++ b/net/ipv4/udplite.c @@ -17,6 +17,15 @@ struct udp_table udplite_table __read_mostly; EXPORT_SYMBOL(udplite_table); +/* Designate sk as UDP-Lite socket */ +static int udplite_sk_init(struct sock *sk) +{ + udp_init_sock(sk); + pr_warn_once("UDP-Lite is deprecated and scheduled to be removed in 2025, " + "please contact the netdev mailing list\n"); + return 0; +} + static int udplite_rcv(struct sk_buff *skb) { return __udp4_lib_rcv(skb, &udplite_table, IPPROTO_UDPLITE); @@ -46,13 +55,17 @@ struct proto udplite_prot = { .getsockopt = udp_getsockopt, .sendmsg = udp_sendmsg, .recvmsg = udp_recvmsg, - .sendpage = udp_sendpage, .hash = udp_lib_hash, .unhash = udp_lib_unhash, .rehash = udp_v4_rehash, .get_port = udp_v4_get_port, - .memory_allocated = &udp_memory_allocated, + + .memory_allocated = &net_aligned_data.udp_memory_allocated, + .per_cpu_fw_alloc = &udp_memory_per_cpu_fw_alloc, + .sysctl_mem = sysctl_udp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), .obj_size = sizeof(struct udp_sock), .h.udp_table = &udplite_table, }; diff --git a/net/ipv4/xfrm4_input.c b/net/ipv4/xfrm4_input.c index ad2afeef4f10..f28cfd88eaf5 100644 --- a/net/ipv4/xfrm4_input.c +++ b/net/ipv4/xfrm4_input.c @@ -17,6 +17,8 @@ #include <linux/netfilter_ipv4.h> #include <net/ip.h> #include <net/xfrm.h> +#include <net/protocol.h> +#include <net/gro.h> static int xfrm4_rcv_encap_finish2(struct net *net, struct sock *sk, struct sk_buff *skb) @@ -31,7 +33,7 @@ static inline int xfrm4_rcv_encap_finish(struct net *net, struct sock *sk, const struct iphdr *iph = ip_hdr(skb); if (ip_route_input_noref(skb, iph->daddr, iph->saddr, - iph->tos, skb->dev)) + ip4h_dscp(iph), skb->dev)) goto drop; } @@ -56,12 +58,16 @@ int xfrm4_transport_finish(struct sk_buff *skb, int async) return -iph->protocol; #endif - __skb_push(skb, skb->data - skb_network_header(skb)); + __skb_push(skb, -skb_network_offset(skb)); iph->tot_len = htons(skb->len); ip_send_check(iph); if (xo && (xo->flags & XFRM_GRO)) { - skb_mac_header_rebuild(skb); + /* The full l2 header needs to be preserved so that re-injecting the packet at l2 + * works correctly in the presence of vlan tags. + */ + skb_mac_header_rebuild_full(skb, xo->orig_mac_len); + skb_reset_network_header(skb); skb_reset_transport_header(skb); return 0; } @@ -72,24 +78,17 @@ int xfrm4_transport_finish(struct sk_buff *skb, int async) return 0; } -/* If it's a keepalive packet, then just eat it. - * If it's an encapsulated packet, then pass it to the - * IPsec xfrm input. - * Returns 0 if skb passed to xfrm or was dropped. - * Returns >0 if skb should be passed to UDP. - * Returns <0 if skb should be resubmitted (-ret is protocol) - */ -int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) +static int __xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb, bool pull) { struct udp_sock *up = udp_sk(sk); struct udphdr *uh; struct iphdr *iph; int iphlen, len; - __u8 *udpdata; __be32 *udpdata32; - __u16 encap_type = up->encap_type; + u16 encap_type; + encap_type = READ_ONCE(up->encap_type); /* if this is not encapsulated socket, then just return now */ if (!encap_type) return 1; @@ -110,7 +109,7 @@ int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) case UDP_ENCAP_ESPINUDP: /* Check if this is a keepalive packet. If so, eat it. */ if (len == 1 && udpdata[0] == 0xff) { - goto drop; + return -EINVAL; } else if (len > sizeof(struct ip_esp_hdr) && udpdata32[0] != 0) { /* ESP Packet without Non-ESP header */ len = sizeof(struct udphdr); @@ -118,19 +117,6 @@ int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) /* Must be an IKE packet.. pass it through */ return 1; break; - case UDP_ENCAP_ESPINUDP_NON_IKE: - /* Check if this is a keepalive packet. If so, eat it. */ - if (len == 1 && udpdata[0] == 0xff) { - goto drop; - } else if (len > 2 * sizeof(u32) + sizeof(struct ip_esp_hdr) && - udpdata32[0] == 0 && udpdata32[1] == 0) { - - /* ESP Packet with Non-IKE marker */ - len = sizeof(struct udphdr) + 2 * sizeof(u32); - } else - /* Must be an IKE packet.. pass it through */ - return 1; - break; } /* At this point we are sure that this is an ESPinUDP packet, @@ -139,7 +125,7 @@ int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) * protocol to ESP, and then call into the transform receiver. */ if (skb_unclone(skb, GFP_ATOMIC)) - goto drop; + return -EINVAL; /* Now we can update and verify the packet length... */ iph = ip_hdr(skb); @@ -147,24 +133,94 @@ int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) iph->tot_len = htons(ntohs(iph->tot_len) - len); if (skb->len < iphlen + len) { /* packet is too small!?! */ - goto drop; + return -EINVAL; } /* pull the data buffer up to the ESP header and set the * transport header to point to ESP. Keep UDP on the stack * for later. */ - __skb_pull(skb, len); - skb_reset_transport_header(skb); + if (pull) { + __skb_pull(skb, len); + skb_reset_transport_header(skb); + } else { + skb_set_transport_header(skb, len); + } /* process ESP */ - return xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, encap_type); - -drop: - kfree_skb(skb); return 0; } +/* If it's a keepalive packet, then just eat it. + * If it's an encapsulated packet, then pass it to the + * IPsec xfrm input. + * Returns 0 if skb passed to xfrm or was dropped. + * Returns >0 if skb should be passed to UDP. + * Returns <0 if skb should be resubmitted (-ret is protocol) + */ +int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) +{ + int ret; + + ret = __xfrm4_udp_encap_rcv(sk, skb, true); + if (!ret) + return xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, + udp_sk(sk)->encap_type); + + if (ret < 0) { + kfree_skb(skb); + return 0; + } + + return ret; +} +EXPORT_SYMBOL(xfrm4_udp_encap_rcv); + +struct sk_buff *xfrm4_gro_udp_encap_rcv(struct sock *sk, struct list_head *head, + struct sk_buff *skb) +{ + int offset = skb_gro_offset(skb); + const struct net_offload *ops; + struct sk_buff *pp = NULL; + int len, dlen; + __u8 *udpdata; + __be32 *udpdata32; + + len = skb->len - offset; + dlen = offset + min(len, 8); + udpdata = skb_gro_header(skb, dlen, offset); + udpdata32 = (__be32 *)udpdata; + if (unlikely(!udpdata)) + return NULL; + + rcu_read_lock(); + ops = rcu_dereference(inet_offloads[IPPROTO_ESP]); + if (!ops || !ops->callbacks.gro_receive) + goto out; + + /* check if it is a keepalive or IKE packet */ + if (len <= sizeof(struct ip_esp_hdr) || udpdata32[0] == 0) + goto out; + + /* set the transport header to ESP */ + skb_set_transport_header(skb, offset); + + NAPI_GRO_CB(skb)->proto = IPPROTO_UDP; + + pp = call_gro_receive(ops->callbacks.gro_receive, head, skb); + rcu_read_unlock(); + + return pp; + +out: + rcu_read_unlock(); + NAPI_GRO_CB(skb)->same_flow = 0; + NAPI_GRO_CB(skb)->flush = 1; + + return NULL; +} +EXPORT_SYMBOL(xfrm4_gro_udp_encap_rcv); + int xfrm4_rcv(struct sk_buff *skb) { return xfrm4_rcv_spi(skb, ip_hdr(skb)->protocol, 0); diff --git a/net/ipv4/xfrm4_output.c b/net/ipv4/xfrm4_output.c index 3cff51ba72bb..0ae67d537499 100644 --- a/net/ipv4/xfrm4_output.c +++ b/net/ipv4/xfrm4_output.c @@ -31,7 +31,7 @@ static int __xfrm4_output(struct net *net, struct sock *sk, struct sk_buff *skb) int xfrm4_output(struct net *net, struct sock *sk, struct sk_buff *skb) { return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, - net, sk, skb, skb->dev, skb_dst(skb)->dev, + net, sk, skb, skb->dev, skb_dst_dev(skb), __xfrm4_output, !(IPCB(skb)->flags & IPSKB_REROUTED)); } diff --git a/net/ipv4/xfrm4_policy.c b/net/ipv4/xfrm4_policy.c index 9e83bcb6bc99..58faf1ddd2b1 100644 --- a/net/ipv4/xfrm4_policy.c +++ b/net/ipv4/xfrm4_policy.c @@ -14,52 +14,47 @@ #include <linux/inetdevice.h> #include <net/dst.h> #include <net/xfrm.h> +#include <net/flow.h> #include <net/ip.h> #include <net/l3mdev.h> -static struct dst_entry *__xfrm4_dst_lookup(struct net *net, struct flowi4 *fl4, - int tos, int oif, - const xfrm_address_t *saddr, - const xfrm_address_t *daddr, - u32 mark) +static struct dst_entry *__xfrm4_dst_lookup(struct flowi4 *fl4, + const struct xfrm_dst_lookup_params *params) { struct rtable *rt; memset(fl4, 0, sizeof(*fl4)); - fl4->daddr = daddr->a4; - fl4->flowi4_tos = tos; - fl4->flowi4_oif = l3mdev_master_ifindex_by_index(net, oif); - fl4->flowi4_mark = mark; - if (saddr) - fl4->saddr = saddr->a4; - - fl4->flowi4_flags = FLOWI_FLAG_SKIP_NH_OIF; - - rt = __ip_route_output_key(net, fl4); + fl4->daddr = params->daddr->a4; + fl4->flowi4_dscp = params->dscp; + fl4->flowi4_l3mdev = l3mdev_master_ifindex_by_index(params->net, + params->oif); + fl4->flowi4_mark = params->mark; + if (params->saddr) + fl4->saddr = params->saddr->a4; + fl4->flowi4_proto = params->ipproto; + fl4->uli = params->uli; + + rt = __ip_route_output_key(params->net, fl4); if (!IS_ERR(rt)) return &rt->dst; return ERR_CAST(rt); } -static struct dst_entry *xfrm4_dst_lookup(struct net *net, int tos, int oif, - const xfrm_address_t *saddr, - const xfrm_address_t *daddr, - u32 mark) +static struct dst_entry *xfrm4_dst_lookup(const struct xfrm_dst_lookup_params *params) { struct flowi4 fl4; - return __xfrm4_dst_lookup(net, &fl4, tos, oif, saddr, daddr, mark); + return __xfrm4_dst_lookup(&fl4, params); } -static int xfrm4_get_saddr(struct net *net, int oif, - xfrm_address_t *saddr, xfrm_address_t *daddr, - u32 mark) +static int xfrm4_get_saddr(xfrm_address_t *saddr, + const struct xfrm_dst_lookup_params *params) { struct dst_entry *dst; struct flowi4 fl4; - dst = __xfrm4_dst_lookup(net, &fl4, 0, oif, NULL, daddr, mark); + dst = __xfrm4_dst_lookup(&fl4, params); if (IS_ERR(dst)) return -EHOSTUNREACH; @@ -71,13 +66,13 @@ static int xfrm4_get_saddr(struct net *net, int oif, static int xfrm4_fill_dst(struct xfrm_dst *xdst, struct net_device *dev, const struct flowi *fl) { - struct rtable *rt = (struct rtable *)xdst->route; + struct rtable *rt = dst_rtable(xdst->route); const struct flowi4 *fl4 = &fl->u.ip4; xdst->u.rt.rt_iif = fl4->flowi4_iif; xdst->u.dst.dev = dev; - dev_hold_track(dev, &xdst->u.dst.dev_tracker, GFP_ATOMIC); + netdev_hold(dev, &xdst->u.dst.dev_tracker, GFP_ATOMIC); /* Sheit... I remember I did this right. Apparently, * it was magically lost, so this code needs audit */ @@ -93,7 +88,6 @@ static int xfrm4_fill_dst(struct xfrm_dst *xdst, struct net_device *dev, xdst->u.rt.rt_gw6 = rt->rt_gw6; xdst->u.rt.rt_pmtu = rt->rt_pmtu; xdst->u.rt.rt_mtu_locked = rt->rt_mtu_locked; - INIT_LIST_HEAD(&xdst->u.rt.rt_uncached); rt_add_uncached_list(&xdst->u.rt); return 0; @@ -123,27 +117,17 @@ static void xfrm4_dst_destroy(struct dst_entry *dst) struct xfrm_dst *xdst = (struct xfrm_dst *)dst; dst_destroy_metrics_generic(dst); - if (xdst->u.rt.rt_uncached_list) - rt_del_uncached_list(&xdst->u.rt); + rt_del_uncached_list(&xdst->u.rt); xfrm_dst_destroy(xdst); } -static void xfrm4_dst_ifdown(struct dst_entry *dst, struct net_device *dev, - int unregister) -{ - if (!unregister) - return; - - xfrm_dst_ifdown(dst, dev); -} - static struct dst_ops xfrm4_dst_ops_template = { .family = AF_INET, .update_pmtu = xfrm4_update_pmtu, .redirect = xfrm4_redirect, .cow_metrics = dst_cow_metrics_generic, .destroy = xfrm4_dst_destroy, - .ifdown = xfrm4_dst_ifdown, + .ifdown = xfrm_dst_ifdown, .local_out = __ip_local_out, .gc_thresh = 32768, }; @@ -165,7 +149,6 @@ static struct ctl_table xfrm4_policy_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, - { } }; static __net_init int xfrm4_net_sysctl_init(struct net *net) @@ -182,7 +165,8 @@ static __net_init int xfrm4_net_sysctl_init(struct net *net) table[0].data = &net->xfrm.xfrm4_dst_ops.gc_thresh; } - hdr = register_net_sysctl(net, "net/ipv4", table); + hdr = register_net_sysctl_sz(net, "net/ipv4", table, + ARRAY_SIZE(xfrm4_policy_table)); if (!hdr) goto err_reg; @@ -198,7 +182,7 @@ err_alloc: static __net_exit void xfrm4_net_sysctl_exit(struct net *net) { - struct ctl_table *table; + const struct ctl_table *table; if (!net->ipv4.xfrm4_hdr) return; diff --git a/net/ipv4/xfrm4_protocol.c b/net/ipv4/xfrm4_protocol.c index 2fe5860c21d6..4ee624d8e66f 100644 --- a/net/ipv4/xfrm4_protocol.c +++ b/net/ipv4/xfrm4_protocol.c @@ -76,7 +76,7 @@ int xfrm4_rcv_encap(struct sk_buff *skb, int nexthdr, __be32 spi, const struct iphdr *iph = ip_hdr(skb); if (ip_route_input_noref(skb, iph->daddr, iph->saddr, - iph->tos, skb->dev)) + ip4h_dscp(iph), skb->dev)) goto drop; } @@ -304,4 +304,3 @@ void __init xfrm4_protocol_init(void) { xfrm_input_register_afinfo(&xfrm4_input_afinfo); } -EXPORT_SYMBOL(xfrm4_protocol_init); diff --git a/net/ipv4/xfrm4_tunnel.c b/net/ipv4/xfrm4_tunnel.c index 9d4f418f1bf8..8cb266af1393 100644 --- a/net/ipv4/xfrm4_tunnel.c +++ b/net/ipv4/xfrm4_tunnel.c @@ -22,13 +22,17 @@ static int ipip_xfrm_rcv(struct xfrm_state *x, struct sk_buff *skb) return ip_hdr(skb)->protocol; } -static int ipip_init_state(struct xfrm_state *x) +static int ipip_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) { - if (x->props.mode != XFRM_MODE_TUNNEL) + if (x->props.mode != XFRM_MODE_TUNNEL) { + NL_SET_ERR_MSG(extack, "IPv4 tunnel can only be used with tunnel mode"); return -EINVAL; + } - if (x->encap) + if (x->encap) { + NL_SET_ERR_MSG(extack, "IPv4 tunnel is not compatible with encapsulation"); return -EINVAL; + } x->props.header_len = sizeof(struct iphdr); @@ -110,5 +114,6 @@ static void __exit ipip_fini(void) module_init(ipip_init); module_exit(ipip_fini); +MODULE_DESCRIPTION("IPv4 XFRM tunnel driver"); MODULE_LICENSE("GPL"); MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_IPIP); |
