Diffstat (limited to 'net/sunrpc/xprtsock.c')
-rw-r--r--	net/sunrpc/xprtsock.c	251
1 file changed, 166 insertions(+), 85 deletions(-)
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index 9f010369100a..2e1fe6013361 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -47,7 +47,7 @@
 #include <net/checksum.h>
 #include <net/udp.h>
 #include <net/tcp.h>
-#include <net/tls.h>
+#include <net/tls_prot.h>
 #include <net/handshake.h>
 
 #include <linux/bvec.h>
@@ -62,6 +62,7 @@
 #include "sunrpc.h"
 
 static void xs_close(struct rpc_xprt *xprt);
+static void xs_reset_srcport(struct sock_xprt *transport);
 static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock);
 static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
 		struct socket *sock);
@@ -159,7 +160,6 @@ static struct ctl_table xs_tunables_table[] = {
 		.mode		= 0644,
 		.proc_handler	= proc_dointvec_jiffies,
 	},
-	{ },
 };
 
 /*
@@ -358,44 +358,61 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)
 
 static int
 xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
-		     struct cmsghdr *cmsg, int ret)
-{
-	if (cmsg->cmsg_level == SOL_TLS &&
-	    cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
-		u8 content_type = *((u8 *)CMSG_DATA(cmsg));
-
-		switch (content_type) {
-		case TLS_RECORD_TYPE_DATA:
-			/* TLS sets EOR at the end of each application data
-			 * record, even though there might be more frames
-			 * waiting to be decrypted.
-			 */
-			msg->msg_flags &= ~MSG_EOR;
-			break;
-		case TLS_RECORD_TYPE_ALERT:
-			ret = -ENOTCONN;
-			break;
-		default:
-			ret = -EAGAIN;
-		}
+		     unsigned int *msg_flags, struct cmsghdr *cmsg, int ret)
+{
+	u8 content_type = tls_get_record_type(sock->sk, cmsg);
+	u8 level, description;
+
+	switch (content_type) {
+	case 0:
+		break;
+	case TLS_RECORD_TYPE_DATA:
+		/* TLS sets EOR at the end of each application data
+		 * record, even though there might be more frames
+		 * waiting to be decrypted.
+		 */
+		*msg_flags &= ~MSG_EOR;
+		break;
+	case TLS_RECORD_TYPE_ALERT:
+		tls_alert_recv(sock->sk, msg, &level, &description);
+		ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+			-EACCES : -EAGAIN;
+		break;
+	default:
+		/* discard this record type */
+		ret = -EAGAIN;
 	}
 	return ret;
 }
 
 static int
-xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags)
+xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags)
 {
 	union {
 		struct cmsghdr	cmsg;
 		u8		buf[CMSG_SPACE(sizeof(u8))];
 	} u;
+	u8 alert[2];
+	struct kvec alert_kvec = {
+		.iov_base = alert,
+		.iov_len = sizeof(alert),
+	};
+	struct msghdr msg = {
+		.msg_flags = *msg_flags,
+		.msg_control = &u,
+		.msg_controllen = sizeof(u),
+	};
 	int ret;
 
-	msg->msg_control = &u;
-	msg->msg_controllen = sizeof(u);
-	ret = sock_recvmsg(sock, msg, flags);
-	if (msg->msg_controllen != sizeof(u))
-		ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret);
+	iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1,
+		      alert_kvec.iov_len);
+	ret = sock_recvmsg(sock, &msg, flags);
+	if (ret > 0) {
+		if (tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT)
+			iov_iter_revert(&msg.msg_iter, ret);
+		ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg,
+					   -EAGAIN);
+	}
 	return ret;
 }
 
@@ -405,7 +422,13 @@ xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
 	ssize_t ret;
 	if (seek != 0)
 		iov_iter_advance(&msg->msg_iter, seek);
-	ret = xs_sock_recv_cmsg(sock, msg, flags);
+	ret = sock_recvmsg(sock, msg, flags);
+	/* Handle TLS inband control message lazily */
+	if (msg->msg_flags & MSG_CTRUNC) {
+		msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR);
+		if (ret == 0 || ret == -EIO)
+			ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags);
+	}
 	return ret > 0 ? ret + seek : ret;
 }
 
@@ -431,7 +454,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
 		size_t count)
 {
 	iov_iter_discard(&msg->msg_iter, ITER_DEST, count);
-	return xs_sock_recv_cmsg(sock, msg, flags);
+	return xs_sock_recvmsg(sock, msg, flags, 0);
 }
 
 #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
@@ -777,6 +800,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport)
 	}
 	if (ret == -ESHUTDOWN)
 		kernel_sock_shutdown(transport->sock, SHUT_RDWR);
+	else if (ret == -EACCES)
+		xprt_wake_pending_tasks(&transport->xprt, -EACCES);
 	else
 		xs_poll_check_readable(transport);
 out:
@@ -878,6 +903,17 @@ static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf)
 	return xdr_alloc_bvec(buf, rpc_task_gfp_mask());
 }
 
+static void xs_stream_abort_send_request(struct rpc_rqst *req)
+{
+	struct rpc_xprt *xprt = req->rq_xprt;
+	struct sock_xprt *transport =
+		container_of(xprt, struct sock_xprt, xprt);
+
+	if (transport->xmit.offset != 0 &&
+	    !test_bit(XPRT_CLOSE_WAIT, &xprt->state))
+		xprt_force_disconnect(xprt);
+}
+
 /*
  * Determine if the previous message in the stream was aborted before it
  * could complete transmission.
@@ -1176,11 +1212,13 @@ static void xs_sock_reset_state_flags(struct rpc_xprt *xprt)
 {
 	struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
 
+	transport->xprt_err = 0;
 	clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
 	clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state);
 	clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state);
 	clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state);
 	clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state);
+	clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
 }
 
 static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr)
@@ -1261,6 +1299,7 @@ static void xs_reset_transport(struct sock_xprt *transport)
 	transport->file = NULL;
 
 	sk->sk_user_data = NULL;
+	sk->sk_sndtimeo = 0;
 	xs_restore_old_callbacks(transport, sk);
 	xprt_clear_connected(xprt);
@@ -1292,6 +1331,8 @@ static void xs_close(struct rpc_xprt *xprt)
 
 	dprintk("RPC:       xs_close xprt %p\n", xprt);
 
+	if (transport->sock)
+		tls_handshake_close(transport->sock);
 	xs_reset_transport(transport);
 	xprt->reestablish_timeout = 0;
 }
@@ -1557,8 +1598,10 @@ static void xs_tcp_state_change(struct sock *sk)
 		break;
 	case TCP_CLOSE:
 		if (test_and_clear_bit(XPRT_SOCK_CONNECTING,
-					&transport->sock_state))
+					&transport->sock_state)) {
+			xs_reset_srcport(transport);
 			xprt_clear_connecting(xprt);
+		}
 		clear_bit(XPRT_CLOSING, &xprt->state);
 		/* Trigger the socket release */
 		xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT);
@@ -1714,6 +1757,11 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port)
 	xs_update_peer_port(xprt);
 }
 
+static void xs_reset_srcport(struct sock_xprt *transport)
+{
+	transport->srcport = 0;
+}
+
 static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock)
 {
 	if (transport->srcport == 0 && transport->xprt.reuseport)
@@ -1797,8 +1845,8 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock)
 	memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen);
 	do {
 		rpc_set_port((struct sockaddr *)&myaddr, port);
-		err = kernel_bind(sock, (struct sockaddr *)&myaddr,
-				transport->xprt.addrlen);
+		err = kernel_bind(sock, (struct sockaddr_unsized *)&myaddr,
+				  transport->xprt.addrlen);
 		if (err == 0) {
 			if (transport->xprt.reuseport)
 				transport->srcport = port;
@@ -1913,6 +1961,9 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
 		goto out;
 	}
 
+	if (protocol == IPPROTO_TCP)
+		sk_net_refcnt_upgrade(sock->sk);
+
 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
 	if (IS_ERR(filp))
 		return ERR_CAST(filp);
@@ -1954,7 +2005,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt,
 
 	xs_stream_start_connect(transport);
 
-	return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0);
+	return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), xprt->addrlen, 0);
 }
 
 /**
@@ -2230,9 +2281,13 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
 		struct socket *sock)
 {
 	struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+	struct net *net = sock_net(sock->sk);
+	unsigned long connect_timeout;
+	unsigned long syn_retries;
 	unsigned int keepidle;
 	unsigned int keepcnt;
 	unsigned int timeo;
+	unsigned long t;
 
 	spin_lock(&xprt->transport_lock);
 	keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ);
@@ -2250,6 +2305,35 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
 
 	/* TCP user timeout (see RFC5482) */
 	tcp_sock_set_user_timeout(sock->sk, timeo);
+
+	/* Connect timeout */
+	connect_timeout = max_t(unsigned long,
+				DIV_ROUND_UP(xprt->connect_timeout, HZ), 1);
+	syn_retries = max_t(unsigned long,
+			    READ_ONCE(net->ipv4.sysctl_tcp_syn_retries), 1);
+	for (t = 0; t <= syn_retries && (1UL << t) < connect_timeout; t++)
+		;
+	if (t <= syn_retries)
+		tcp_sock_set_syncnt(sock->sk, t - 1);
+}
+
+static void xs_tcp_do_set_connect_timeout(struct rpc_xprt *xprt,
+					  unsigned long connect_timeout)
+{
+	struct sock_xprt *transport =
+		container_of(xprt, struct sock_xprt, xprt);
+	struct rpc_timeout to;
+	unsigned long initval;
+
+	memcpy(&to, xprt->timeout, sizeof(to));
+	/* Arbitrary lower limit */
+	initval = max_t(unsigned long, connect_timeout, XS_TCP_INIT_REEST_TO);
+	to.to_initval = initval;
+	to.to_maxval = initval;
+	to.to_retries = 0;
+	memcpy(&transport->tcp_timeout, &to, sizeof(transport->tcp_timeout));
+	xprt->timeout = &transport->tcp_timeout;
+	xprt->connect_timeout = connect_timeout;
 }
 
 static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt,
@@ -2257,25 +2341,12 @@
 		unsigned long reconnect_timeout)
 {
 	struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
-	struct rpc_timeout to;
-	unsigned long initval;
 
 	spin_lock(&xprt->transport_lock);
 	if (reconnect_timeout < xprt->max_reconnect_timeout)
 		xprt->max_reconnect_timeout = reconnect_timeout;
-	if (connect_timeout < xprt->connect_timeout) {
-		memcpy(&to, xprt->timeout, sizeof(to));
-		initval = DIV_ROUND_UP(connect_timeout, to.to_retries + 1);
-		/* Arbitrary lower limit */
-		if (initval < XS_TCP_INIT_REEST_TO << 1)
-			initval = XS_TCP_INIT_REEST_TO << 1;
-		to.to_initval = initval;
-		to.to_maxval = initval;
-		memcpy(&transport->tcp_timeout, &to,
-			sizeof(transport->tcp_timeout));
-		xprt->timeout = &transport->tcp_timeout;
-		xprt->connect_timeout = connect_timeout;
-	}
+	if (connect_timeout < xprt->connect_timeout)
+		xs_tcp_do_set_connect_timeout(xprt, connect_timeout);
 	set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
 	spin_unlock(&xprt->transport_lock);
 }
@@ -2334,7 +2405,8 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
 
 	/* Tell the socket layer to start connecting... */
 	set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state);
-	return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
+	return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt),
+			      xprt->addrlen, O_NONBLOCK);
 }
 
 /**
@@ -2395,6 +2467,13 @@ static void xs_tcp_setup_socket(struct work_struct *work)
 		transport->srcport = 0;
 		status = -EAGAIN;
 		break;
+	case -EPERM:
+		/* Happens, for instance, if a BPF program is preventing
+		 * the connect. Remap the error so upper layers can better
+		 * deal with it.
+		 */
+		status = -ECONNREFUSED;
+		fallthrough;
 	case -EINVAL:
 		/* Happens, for instance, if the user specified a link
 		 * local IPv6 address without a scope-id.
@@ -2406,6 +2485,7 @@
 	case -EHOSTUNREACH:
 	case -EADDRINUSE:
 	case -ENOBUFS:
+	case -ENOTCONN:
 		break;
 	default:
 		printk("%s: connect returned unhandled error %d\n",
@@ -2518,7 +2598,15 @@ static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
 	struct sock_xprt *lower_transport =
 		container_of(lower_xprt, struct sock_xprt, xprt);
 
-	lower_transport->xprt_err = status ? -EACCES : 0;
+	switch (status) {
+	case 0:
+	case -EACCES:
+	case -ETIMEDOUT:
+		lower_transport->xprt_err = status;
+		break;
+	default:
+		lower_transport->xprt_err = -EACCES;
+	}
 	complete(&lower_transport->handshake_done);
 	xprt_put(lower_xprt);
 }
@@ -2560,11 +2648,10 @@ static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_par
 	rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
 						       XS_TLS_HANDSHAKE_TO);
 	if (rc <= 0) {
-		if (!tls_handshake_cancel(sk)) {
-			if (rc == 0)
-				rc = -ETIMEDOUT;
-			goto out_put_xprt;
-		}
+		tls_handshake_cancel(sk);
+		if (rc == 0)
+			rc = -ETIMEDOUT;
+		goto out_put_xprt;
 	}
 
 	rc = lower_transport->xprt_err;
@@ -2617,6 +2704,7 @@ static void xs_tcp_tls_setup_socket(struct work_struct *work)
 		.xprtsec	= {
 			.policy		= RPC_XPRTSEC_NONE,
 		},
+		.stats		= upper_clnt->cl_stats,
 	};
 	unsigned int pflags = current->flags;
 	struct rpc_clnt *lower_clnt;
@@ -2645,6 +2733,10 @@
 	rcu_read_lock();
 	lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
 	rcu_read_unlock();
+
+	if (wait_on_bit_lock(&lower_xprt->state, XPRT_LOCKED, TASK_KILLABLE))
+		goto out_unlock;
+
 	status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
 	if (status) {
 		trace_rpc_tls_not_started(upper_clnt, upper_xprt);
@@ -2654,20 +2746,15 @@
 	status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
 	if (status)
 		goto out_close;
-
+	xprt_release_write(lower_xprt, NULL);
 	trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
-	if (!xprt_test_and_set_connected(upper_xprt)) {
-		upper_xprt->connect_cookie++;
-		clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
-		xprt_clear_connecting(upper_xprt);
-
-		upper_xprt->stat.connect_count++;
-		upper_xprt->stat.connect_time += (long)jiffies -
-			upper_xprt->stat.connect_start;
-		xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
-	}
 	rpc_shutdown_client(lower_clnt);
+
+	/* Check for ingress data that arrived before the socket's
+	 * ->data_ready callback was set up.
+	 */
+	xs_poll_check_readable(upper_transport);
+
 out_unlock:
 	current_restore_flags(pflags, PF_MEMALLOC);
 	upper_transport->clnt = NULL;
@@ -2675,6 +2762,7 @@ out_unlock:
 	return;
 
 out_close:
+	xprt_release_write(lower_xprt, NULL);
 	rpc_shutdown_client(lower_clnt);
 
 	/* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
@@ -2739,18 +2827,13 @@ static void xs_wake_error(struct sock_xprt *transport)
 {
 	int sockerr;
 
-	if (!test_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
-		return;
-	mutex_lock(&transport->recv_mutex);
-	if (transport->sock == NULL)
-		goto out;
 	if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
-		goto out;
+		return;
 	sockerr = xchg(&transport->xprt_err, 0);
-	if (sockerr < 0)
+	if (sockerr < 0) {
 		xprt_wake_pending_tasks(&transport->xprt, sockerr);
-out:
-	mutex_unlock(&transport->recv_mutex);
+		xs_tcp_force_close(&transport->xprt);
+	}
 }
 
 static void xs_wake_pending(struct sock_xprt *transport)
@@ -2958,20 +3041,11 @@ static int bc_send_request(struct rpc_rqst *req)
 	return len;
 }
 
-/*
- * The close routine. Since this is client initiated, we do nothing
- */
-
 static void bc_close(struct rpc_xprt *xprt)
 {
 	xprt_disconnect_done(xprt);
 }
 
-/*
- * The xprt destroy routine. Again, because this connection is client
- * initiated, we do nothing
- */
-
 static void bc_destroy(struct rpc_xprt *xprt)
 {
 	dprintk("RPC:       bc_destroy xprt %p\n", xprt);
@@ -2992,6 +3066,7 @@ static const struct rpc_xprt_ops xs_local_ops = {
 	.buf_free		= rpc_free,
 	.prepare_request	= xs_stream_prepare_request,
 	.send_request		= xs_local_send_request,
+	.abort_send_request	= xs_stream_abort_send_request,
 	.wait_for_reply_request	= xprt_wait_for_reply_request_def,
 	.close			= xs_close,
 	.destroy		= xs_destroy,
@@ -3039,6 +3114,7 @@ static const struct rpc_xprt_ops xs_tcp_ops = {
 	.buf_free		= rpc_free,
 	.prepare_request	= xs_stream_prepare_request,
 	.send_request		= xs_tcp_send_request,
+	.abort_send_request	= xs_stream_abort_send_request,
 	.wait_for_reply_request	= xprt_wait_for_reply_request_def,
 	.close			= xs_tcp_shutdown,
 	.destroy		= xs_destroy,
@@ -3328,8 +3404,13 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
 	xprt->timeout = &xs_tcp_default_timeout;
 
 	xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+	if (args->reconnect_timeout)
+		xprt->max_reconnect_timeout = args->reconnect_timeout;
+
 	xprt->connect_timeout = xprt->timeout->to_initval *
 		(xprt->timeout->to_retries + 1);
+	if (args->connect_timeout)
+		xs_tcp_do_set_connect_timeout(xprt, args->connect_timeout);
 
 	INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
 	INIT_WORK(&transport->error_worker, xs_error_handle);
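
A note on the TLS control-message handling above: on a kTLS socket, the kernel reports each record's TLS content type as ancillary data (SOL_TLS / TLS_GET_RECORD_TYPE), and a non-data record received without enough cmsg space surfaces as MSG_CTRUNC with a 0 or -EIO return, which is the condition the reworked xs_sock_recvmsg() checks before falling back to xs_sock_recv_cmsg(). The following userspace sketch is not part of the patch; it is modeled on the example in Documentation/networking/tls.rst and assumes an already-established kTLS socket fd:

	#include <stdio.h>
	#include <sys/types.h>
	#include <sys/socket.h>
	#include <sys/uio.h>
	#include <linux/tls.h>

	#ifndef SOL_TLS
	#define SOL_TLS 282	/* value from include/linux/socket.h */
	#endif

	/* Receive from a kTLS socket and report the TLS record type that
	 * the kernel attaches as a control message. Without the control
	 * buffer, a non-data record cannot be delivered cleanly, which is
	 * what the patch detects via MSG_CTRUNC and handles lazily.
	 */
	static ssize_t recv_tls(int fd, void *buf, size_t len)
	{
		char cbuf[CMSG_SPACE(sizeof(unsigned char))];
		struct iovec iov = { .iov_base = buf, .iov_len = len };
		struct msghdr msg = {
			.msg_iov = &iov,
			.msg_iovlen = 1,
			.msg_control = cbuf,
			.msg_controllen = sizeof(cbuf),
		};
		ssize_t ret = recvmsg(fd, &msg, 0);
		struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);

		if (ret >= 0 && cmsg && cmsg->cmsg_level == SOL_TLS &&
		    cmsg->cmsg_type == TLS_GET_RECORD_TYPE)
			/* 21 == alert, 22 == handshake, 23 == app data */
			printf("TLS record type: %u\n",
			       (unsigned)*(unsigned char *)CMSG_DATA(cmsg));
		return ret;
	}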
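
The alert branch in the new xs_sock_process_cmsg() leans on the fact that a TLS alert record carries exactly two bytes, level and description (RFC 8446, section 6), and that only a fatal level should tear the transport down. A minimal sketch of that classification, mirroring the patch's -EACCES/-EAGAIN convention (the constant values follow the RFC; the helper name is ours):

	#include <errno.h>
	#include <stdint.h>

	enum tls_alert_level {
		TLS_ALERT_LEVEL_WARNING = 1,	/* RFC 8446, section 6 */
		TLS_ALERT_LEVEL_FATAL   = 2,
	};

	/* Fatal alerts make the transport unusable: -EACCES wakes pending
	 * RPCs with an error. Warning-level alerts are skipped and the
	 * read retried, hence -EAGAIN.
	 */
	static int classify_tls_alert(const uint8_t alert[2])
	{
		uint8_t level = alert[0];	/* alert[1] is the description */

		return level == TLS_ALERT_LEVEL_FATAL ? -EACCES : -EAGAIN;
	}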
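
The new connect-timeout block in xs_tcp_set_socket_timeouts() converts xprt->connect_timeout into a TCP SYN retransmission count: with the usual 1-second initial SYN timeout doubling on every retry, a SYN count of n gives up after roughly 2^(n+1) - 1 seconds, so the loop picks the smallest t with 2^t >= timeout and passes t - 1 to tcp_sock_set_syncnt(). A standalone model of that derivation (hypothetical helper, assuming the kernel's clamping of both inputs to at least 1):

	#include <stdio.h>

	/* Returns the SYN count for a given connect timeout in seconds,
	 * or -1 when the timeout exceeds what the sysctl'd retry limit
	 * can cover (the kernel then leaves the default alone).
	 */
	static int syncnt_for_timeout(unsigned long timeout_secs,
				      unsigned long syn_retries)
	{
		unsigned long t;

		if (timeout_secs < 2)	/* give up after the initial 1s RTO */
			return 0;
		/* Smallest t such that 2^t >= timeout_secs */
		for (t = 0; t <= syn_retries && (1UL << t) < timeout_secs; t++)
			;
		return t <= syn_retries ? (int)t - 1 : -1;
	}

	int main(void)
	{
		/* 15s timeout with the default net.ipv4.tcp_syn_retries = 6:
		 * 2^4 = 16 >= 15, so syncnt = 3 (SYN plus retransmits at
		 * 1s, 3s, and 7s, giving up around the 15-second mark).
		 */
		printf("syncnt = %d\n", syncnt_for_timeout(15, 6));
		return 0;
	}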
