diff options
Diffstat (limited to 'net/sunrpc/xprtsock.c')
| -rw-r--r-- | net/sunrpc/xprtsock.c | 608 |
1 files changed, 555 insertions, 53 deletions
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index aaa5b2741b79..2e1fe6013361 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -47,17 +47,22 @@ #include <net/checksum.h> #include <net/udp.h> #include <net/tcp.h> +#include <net/tls_prot.h> +#include <net/handshake.h> + #include <linux/bvec.h> #include <linux/highmem.h> #include <linux/uio.h> #include <linux/sched/mm.h> +#include <trace/events/sock.h> #include <trace/events/sunrpc.h> #include "socklib.h" #include "sunrpc.h" static void xs_close(struct rpc_xprt *xprt); +static void xs_reset_srcport(struct sock_xprt *transport); static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock); static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, struct socket *sock); @@ -77,7 +82,7 @@ static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO; /* * We can register our own files under /proc/sys/sunrpc by - * calling register_sysctl_table() again. The files in that + * calling register_sysctl() again. The files in that * directory become the union of all files registered there. * * We simply need to make sure that we don't collide with @@ -95,6 +100,7 @@ static struct ctl_table_header *sunrpc_table_header; static struct xprt_class xs_local_transport; static struct xprt_class xs_udp_transport; static struct xprt_class xs_tcp_transport; +static struct xprt_class xs_tcp_tls_transport; static struct xprt_class xs_bc_tcp_transport; /* @@ -154,16 +160,6 @@ static struct ctl_table xs_tunables_table[] = { .mode = 0644, .proc_handler = proc_dointvec_jiffies, }, - { }, -}; - -static struct ctl_table sunrpc_table[] = { - { - .procname = "sunrpc", - .mode = 0555, - .child = xs_tunables_table - }, - { }, }; /* @@ -195,6 +191,11 @@ static struct ctl_table sunrpc_table[] = { */ #define XS_IDLE_DISC_TO (5U * 60 * HZ) +/* + * TLS handshake timeout. + */ +#define XS_TLS_HANDSHAKE_TO (10U * HZ) + #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # undef RPC_DEBUG_DATA # define RPCDBG_FACILITY RPCDBG_TRANS @@ -261,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) switch (sap->sa_family) { case AF_LOCAL: sun = xs_addr_un(xprt); - strscpy(buf, sun->sun_path, sizeof(buf)); + if (sun->sun_path[0]) { + strscpy(buf, sun->sun_path, sizeof(buf)); + } else { + buf[0] = '@'; + strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1); + } xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); break; @@ -350,6 +356,66 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) return want; } +static int +xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, + unsigned int *msg_flags, struct cmsghdr *cmsg, int ret) +{ + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; + + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + *msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -EACCES : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; + } + return ret; +} + +static int +xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags) +{ + union { + struct cmsghdr cmsg; + u8 buf[CMSG_SPACE(sizeof(u8))]; + } u; + u8 alert[2]; + struct kvec alert_kvec = { + .iov_base = alert, + .iov_len = sizeof(alert), + }; + struct msghdr msg = { + .msg_flags = *msg_flags, + .msg_control = &u, + .msg_controllen = sizeof(u), + }; + int ret; + + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1, + alert_kvec.iov_len); + ret = sock_recvmsg(sock, &msg, flags); + if (ret > 0) { + if (tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT) + iov_iter_revert(&msg.msg_iter, ret); + ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg, + -EAGAIN); + } + return ret; +} + static ssize_t xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) { @@ -357,6 +423,12 @@ xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) if (seek != 0) iov_iter_advance(&msg->msg_iter, seek); ret = sock_recvmsg(sock, msg, flags); + /* Handle TLS inband control message lazily */ + if (msg->msg_flags & MSG_CTRUNC) { + msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR); + if (ret == 0 || ret == -EIO) + ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags); + } return ret > 0 ? ret + seek : ret; } @@ -382,7 +454,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, size_t count) { iov_iter_discard(&msg->msg_iter, ITER_DEST, count); - return sock_recvmsg(sock, msg, flags); + return xs_sock_recvmsg(sock, msg, flags, 0); } #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE @@ -703,6 +775,8 @@ static void xs_poll_check_readable(struct sock_xprt *transport) { clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; if (!xs_poll_socket_readable(transport)) return; if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) @@ -726,6 +800,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport) } if (ret == -ESHUTDOWN) kernel_sock_shutdown(transport->sock, SHUT_RDWR); + else if (ret == -EACCES) + xprt_wake_pending_tasks(&transport->xprt, -EACCES); else xs_poll_check_readable(transport); out: @@ -827,6 +903,17 @@ static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf) return xdr_alloc_bvec(buf, rpc_task_gfp_mask()); } +static void xs_stream_abort_send_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + + if (transport->xmit.offset != 0 && + !test_bit(XPRT_CLOSE_WAIT, &xprt->state)) + xprt_force_disconnect(xprt); +} + /* * Determine if the previous message in the stream was aborted before it * could complete transmission. @@ -1125,11 +1212,13 @@ static void xs_sock_reset_state_flags(struct rpc_xprt *xprt) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + transport->xprt_err = 0; clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state); clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state); clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state); clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state); + clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); } static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr) @@ -1199,6 +1288,8 @@ static void xs_reset_transport(struct sock_xprt *transport) if (atomic_read(&transport->xprt.swapper)) sk_clear_memalloc(sk); + tls_handshake_cancel(sk); + kernel_sock_shutdown(sock, SHUT_RDWR); mutex_lock(&transport->recv_mutex); @@ -1208,6 +1299,7 @@ static void xs_reset_transport(struct sock_xprt *transport) transport->file = NULL; sk->sk_user_data = NULL; + sk->sk_sndtimeo = 0; xs_restore_old_callbacks(transport, sk); xprt_clear_connected(xprt); @@ -1239,6 +1331,8 @@ static void xs_close(struct rpc_xprt *xprt) dprintk("RPC: xs_close xprt %p\n", xprt); + if (transport->sock) + tls_handshake_close(transport->sock); xs_reset_transport(transport); xprt->reestablish_timeout = 0; } @@ -1378,6 +1472,8 @@ static void xs_data_ready(struct sock *sk) { struct rpc_xprt *xprt; + trace_sk_data_ready(sk); + xprt = xprt_from_sock(sk); if (xprt != NULL) { struct sock_xprt *transport = container_of(xprt, @@ -1386,6 +1482,10 @@ static void xs_data_ready(struct sock *sk) trace_xs_data_ready(xprt); transport->old_data_ready(sk); + + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; + /* Any data means we had a useful conversation, so * then we don't need to delay the next reconnect */ @@ -1498,8 +1598,10 @@ static void xs_tcp_state_change(struct sock *sk) break; case TCP_CLOSE: if (test_and_clear_bit(XPRT_SOCK_CONNECTING, - &transport->sock_state)) + &transport->sock_state)) { + xs_reset_srcport(transport); xprt_clear_connecting(xprt); + } clear_bit(XPRT_CLOSING, &xprt->state); /* Trigger the socket release */ xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); @@ -1655,6 +1757,11 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port) xs_update_peer_port(xprt); } +static void xs_reset_srcport(struct sock_xprt *transport) +{ + transport->srcport = 0; +} + static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock) { if (transport->srcport == 0 && transport->xprt.reuseport) @@ -1738,8 +1845,8 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock) memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen); do { rpc_set_port((struct sockaddr *)&myaddr, port); - err = kernel_bind(sock, (struct sockaddr *)&myaddr, - transport->xprt.addrlen); + err = kernel_bind(sock, (struct sockaddr_unsized *)&myaddr, + transport->xprt.addrlen); if (err == 0) { if (transport->xprt.reuseport) transport->srcport = port; @@ -1854,6 +1961,9 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt, goto out; } + if (protocol == IPPROTO_TCP) + sk_net_refcnt_upgrade(sock->sk); + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); if (IS_ERR(filp)) return ERR_CAST(filp); @@ -1895,7 +2005,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, xs_stream_start_connect(transport); - return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); + return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), xprt->addrlen, 0); } /** @@ -2155,6 +2265,7 @@ static void xs_tcp_shutdown(struct rpc_xprt *xprt) switch (skst) { case TCP_FIN_WAIT1: case TCP_FIN_WAIT2: + case TCP_LAST_ACK: break; case TCP_ESTABLISHED: case TCP_CLOSE_WAIT: @@ -2170,9 +2281,13 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, struct socket *sock) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct net *net = sock_net(sock->sk); + unsigned long connect_timeout; + unsigned long syn_retries; unsigned int keepidle; unsigned int keepcnt; unsigned int timeo; + unsigned long t; spin_lock(&xprt->transport_lock); keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ); @@ -2190,6 +2305,35 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, /* TCP user timeout (see RFC5482) */ tcp_sock_set_user_timeout(sock->sk, timeo); + + /* Connect timeout */ + connect_timeout = max_t(unsigned long, + DIV_ROUND_UP(xprt->connect_timeout, HZ), 1); + syn_retries = max_t(unsigned long, + READ_ONCE(net->ipv4.sysctl_tcp_syn_retries), 1); + for (t = 0; t <= syn_retries && (1UL << t) < connect_timeout; t++) + ; + if (t <= syn_retries) + tcp_sock_set_syncnt(sock->sk, t - 1); +} + +static void xs_tcp_do_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout) +{ + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct rpc_timeout to; + unsigned long initval; + + memcpy(&to, xprt->timeout, sizeof(to)); + /* Arbitrary lower limit */ + initval = max_t(unsigned long, connect_timeout, XS_TCP_INIT_REEST_TO); + to.to_initval = initval; + to.to_maxval = initval; + to.to_retries = 0; + memcpy(&transport->tcp_timeout, &to, sizeof(transport->tcp_timeout)); + xprt->timeout = &transport->tcp_timeout; + xprt->connect_timeout = connect_timeout; } static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt, @@ -2197,25 +2341,12 @@ static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt, unsigned long reconnect_timeout) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - struct rpc_timeout to; - unsigned long initval; spin_lock(&xprt->transport_lock); if (reconnect_timeout < xprt->max_reconnect_timeout) xprt->max_reconnect_timeout = reconnect_timeout; - if (connect_timeout < xprt->connect_timeout) { - memcpy(&to, xprt->timeout, sizeof(to)); - initval = DIV_ROUND_UP(connect_timeout, to.to_retries + 1); - /* Arbitrary lower limit */ - if (initval < XS_TCP_INIT_REEST_TO << 1) - initval = XS_TCP_INIT_REEST_TO << 1; - to.to_initval = initval; - to.to_maxval = initval; - memcpy(&transport->tcp_timeout, &to, - sizeof(transport->tcp_timeout)); - xprt->timeout = &transport->tcp_timeout; - xprt->connect_timeout = connect_timeout; - } + if (connect_timeout < xprt->connect_timeout) + xs_tcp_do_set_connect_timeout(xprt, connect_timeout); set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); spin_unlock(&xprt->transport_lock); } @@ -2274,7 +2405,8 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) /* Tell the socket layer to start connecting... */ set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); - return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); + return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), + xprt->addrlen, O_NONBLOCK); } /** @@ -2335,6 +2467,13 @@ static void xs_tcp_setup_socket(struct work_struct *work) transport->srcport = 0; status = -EAGAIN; break; + case -EPERM: + /* Happens, for instance, if a BPF program is preventing + * the connect. Remap the error so upper layers can better + * deal with it. + */ + status = -ECONNREFUSED; + fallthrough; case -EINVAL: /* Happens, for instance, if the user specified a link * local IPv6 address without a scope-id. @@ -2346,6 +2485,7 @@ static void xs_tcp_setup_socket(struct work_struct *work) case -EHOSTUNREACH: case -EADDRINUSE: case -ENOBUFS: + case -ENOTCONN: break; default: printk("%s: connect returned unhandled error %d\n", @@ -2365,6 +2505,275 @@ out_unlock: current_restore_flags(pflags, PF_MEMALLOC); } +/* + * Transfer the connected socket to @upper_transport, then mark that + * xprt CONNECTED. + */ +static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt, + struct sock_xprt *upper_transport) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + + if (!upper_transport->inet) { + struct socket *sock = lower_transport->sock; + struct sock *sk = sock->sk; + + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. + */ + if (xs_addr(upper_xprt)->sa_family == PF_INET6) + ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC); + + xs_tcp_set_socket_timeouts(upper_xprt, sock); + tcp_sock_set_nodelay(sk); + + lock_sock(sk); + + /* @sk is already connected, so it now has the RPC callbacks. + * Reach into @lower_transport to save the original ones. + */ + upper_transport->old_data_ready = lower_transport->old_data_ready; + upper_transport->old_state_change = lower_transport->old_state_change; + upper_transport->old_write_space = lower_transport->old_write_space; + upper_transport->old_error_report = lower_transport->old_error_report; + sk->sk_user_data = upper_xprt; + + /* socket options */ + sock_reset_flag(sk, SOCK_LINGER); + + xprt_clear_connected(upper_xprt); + + upper_transport->sock = sock; + upper_transport->inet = sk; + upper_transport->file = lower_transport->file; + + release_sock(sk); + + /* Reset lower_transport before shutting down its clnt */ + mutex_lock(&lower_transport->recv_mutex); + lower_transport->inet = NULL; + lower_transport->sock = NULL; + lower_transport->file = NULL; + + xprt_clear_connected(lower_xprt); + xs_sock_reset_connection_flags(lower_xprt); + xs_stream_reset_connect(lower_transport); + mutex_unlock(&lower_transport->recv_mutex); + } + + if (!xprt_bound(upper_xprt)) + return -ENOTCONN; + + xs_set_memalloc(upper_xprt); + + if (!xprt_test_and_set_connected(upper_xprt)) { + upper_xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + + upper_xprt->stat.connect_count++; + upper_xprt->stat.connect_time += (long)jiffies - + upper_xprt->stat.connect_start; + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + } + return 0; +} + +/** + * xs_tls_handshake_done - TLS handshake completion handler + * @data: address of xprt to wake + * @status: status of handshake + * @peerid: serial number of key containing the remote's identity + * + */ +static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid) +{ + struct rpc_xprt *lower_xprt = data; + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + + switch (status) { + case 0: + case -EACCES: + case -ETIMEDOUT: + lower_transport->xprt_err = status; + break; + default: + lower_transport->xprt_err = -EACCES; + } + complete(&lower_transport->handshake_done); + xprt_put(lower_xprt); +} + +static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct tls_handshake_args args = { + .ta_sock = lower_transport->sock, + .ta_done = xs_tls_handshake_done, + .ta_data = xprt_get(lower_xprt), + .ta_peername = lower_xprt->servername, + }; + struct sock *sk = lower_transport->inet; + int rc; + + init_completion(&lower_transport->handshake_done); + set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + lower_transport->xprt_err = -ETIMEDOUT; + switch (xprtsec->policy) { + case RPC_XPRTSEC_TLS_ANON: + rc = tls_client_hello_anon(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + case RPC_XPRTSEC_TLS_X509: + args.ta_my_cert = xprtsec->cert_serial; + args.ta_my_privkey = xprtsec->privkey_serial; + rc = tls_client_hello_x509(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + default: + rc = -EACCES; + goto out_put_xprt; + } + + rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done, + XS_TLS_HANDSHAKE_TO); + if (rc <= 0) { + tls_handshake_cancel(sk); + if (rc == 0) + rc = -ETIMEDOUT; + goto out_put_xprt; + } + + rc = lower_transport->xprt_err; + +out: + xs_stream_reset_connect(lower_transport); + clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + return rc; + +out_put_xprt: + xprt_put(lower_xprt); + goto out; +} + +/** + * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket + * @work: queued work item + * + * Invoked by a work queue tasklet. + * + * For RPC-with-TLS, there is a two-stage connection process. + * + * The "upper-layer xprt" is visible to the RPC consumer. Once it has + * been marked connected, the consumer knows that a TCP connection and + * a TLS session have been established. + * + * A "lower-layer xprt", created in this function, handles the mechanics + * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and + * then driving the TLS handshake. Once all that is complete, the upper + * layer xprt is marked connected. + */ +static void xs_tcp_tls_setup_socket(struct work_struct *work) +{ + struct sock_xprt *upper_transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_clnt *upper_clnt = upper_transport->clnt; + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + struct rpc_create_args args = { + .net = upper_xprt->xprt_net, + .protocol = upper_xprt->prot, + .address = (struct sockaddr *)&upper_xprt->addr, + .addrsize = upper_xprt->addrlen, + .timeout = upper_clnt->cl_timeout, + .servername = upper_xprt->servername, + .program = upper_clnt->cl_program, + .prognumber = upper_clnt->cl_prog, + .version = upper_clnt->cl_vers, + .authflavor = RPC_AUTH_TLS, + .cred = upper_clnt->cl_cred, + .xprtsec = { + .policy = RPC_XPRTSEC_NONE, + }, + .stats = upper_clnt->cl_stats, + }; + unsigned int pflags = current->flags; + struct rpc_clnt *lower_clnt; + struct rpc_xprt *lower_xprt; + int status; + + if (atomic_read(&upper_xprt->swapper)) + current->flags |= PF_MEMALLOC; + + xs_stream_start_connect(upper_transport); + + /* This implicitly sends an RPC_AUTH_TLS probe */ + lower_clnt = rpc_create(&args); + if (IS_ERR(lower_clnt)) { + trace_rpc_tls_unavailable(upper_clnt, upper_xprt); + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt)); + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + goto out_unlock; + } + + /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on + * the lower xprt. + */ + rcu_read_lock(); + lower_xprt = rcu_dereference(lower_clnt->cl_xprt); + rcu_read_unlock(); + + if (wait_on_bit_lock(&lower_xprt->state, XPRT_LOCKED, TASK_KILLABLE)) + goto out_unlock; + + status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec); + if (status) { + trace_rpc_tls_not_started(upper_clnt, upper_xprt); + goto out_close; + } + + status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport); + if (status) + goto out_close; + xprt_release_write(lower_xprt, NULL); + trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0); + rpc_shutdown_client(lower_clnt); + + /* Check for ingress data that arrived before the socket's + * ->data_ready callback was set up. + */ + xs_poll_check_readable(upper_transport); + +out_unlock: + current_restore_flags(pflags, PF_MEMALLOC); + upper_transport->clnt = NULL; + xprt_unlock_connect(upper_xprt, upper_transport); + return; + +out_close: + xprt_release_write(lower_xprt, NULL); + rpc_shutdown_client(lower_clnt); + + /* xprt_force_disconnect() wakes tasks with a fixed tk_status code. + * Wake them first here to ensure they get our tk_status code. + */ + xprt_wake_pending_tasks(upper_xprt, status); + xs_tcp_force_close(upper_xprt); + xprt_clear_connecting(upper_xprt); + goto out_unlock; +} + /** * xs_connect - connect a socket to a remote endpoint * @xprt: pointer to transport structure @@ -2396,6 +2805,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) } else dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); + transport->clnt = task->tk_client; queue_delayed_work(xprtiod_workqueue, &transport->connect_worker, delay); @@ -2417,18 +2827,13 @@ static void xs_wake_error(struct sock_xprt *transport) { int sockerr; - if (!test_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state)) - return; - mutex_lock(&transport->recv_mutex); - if (transport->sock == NULL) - goto out; if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state)) - goto out; + return; sockerr = xchg(&transport->xprt_err, 0); - if (sockerr < 0) + if (sockerr < 0) { xprt_wake_pending_tasks(&transport->xprt, sockerr); -out: - mutex_unlock(&transport->recv_mutex); + xs_tcp_force_close(&transport->xprt); + } } static void xs_wake_pending(struct sock_xprt *transport) @@ -2636,20 +3041,11 @@ static int bc_send_request(struct rpc_rqst *req) return len; } -/* - * The close routine. Since this is client initiated, we do nothing - */ - static void bc_close(struct rpc_xprt *xprt) { xprt_disconnect_done(xprt); } -/* - * The xprt destroy routine. Again, because this connection is client - * initiated, we do nothing - */ - static void bc_destroy(struct rpc_xprt *xprt) { dprintk("RPC: bc_destroy xprt %p\n", xprt); @@ -2670,6 +3066,7 @@ static const struct rpc_xprt_ops xs_local_ops = { .buf_free = rpc_free, .prepare_request = xs_stream_prepare_request, .send_request = xs_local_send_request, + .abort_send_request = xs_stream_abort_send_request, .wait_for_reply_request = xprt_wait_for_reply_request_def, .close = xs_close, .destroy = xs_destroy, @@ -2717,6 +3114,7 @@ static const struct rpc_xprt_ops xs_tcp_ops = { .buf_free = rpc_free, .prepare_request = xs_stream_prepare_request, .send_request = xs_tcp_send_request, + .abort_send_request = xs_stream_abort_send_request, .wait_for_reply_request = xprt_wait_for_reply_request_def, .close = xs_tcp_shutdown, .destroy = xs_destroy, @@ -2863,7 +3261,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) switch (sun->sun_family) { case AF_LOCAL: - if (sun->sun_path[0] != '/') { + if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') { dprintk("RPC: bad AF_LOCAL address: %s\n", sun->sun_path); ret = ERR_PTR(-EINVAL); @@ -3006,8 +3404,13 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt->timeout = &xs_tcp_default_timeout; xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + if (args->reconnect_timeout) + xprt->max_reconnect_timeout = args->reconnect_timeout; + xprt->connect_timeout = xprt->timeout->to_initval * (xprt->timeout->to_retries + 1); + if (args->connect_timeout) + xs_tcp_do_set_connect_timeout(xprt, args->connect_timeout); INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); INIT_WORK(&transport->error_worker, xs_error_handle); @@ -3050,6 +3453,94 @@ out_err: } /** + * xs_setup_tcp_tls - Set up transport to use a TCP with TLS + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xs_tcp_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + + switch (args->xprtsec.policy) { + case RPC_XPRTSEC_TLS_ANON: + case RPC_XPRTSEC_TLS_X509: + xprt->xprtsec = args->xprtsec; + INIT_DELAYED_WORK(&transport->connect_worker, + xs_tcp_tls_setup_socket); + break; + default: + ret = ERR_PTR(-EACCES); + goto out_err; + } + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +/** * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket * @args: rpc transport creation arguments * @@ -3158,6 +3649,15 @@ static struct xprt_class xs_tcp_transport = { .netid = { "tcp", "tcp6", "" }, }; +static struct xprt_class xs_tcp_tls_transport = { + .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list), + .name = "tcp-with-tls", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_TCP_TLS, + .setup = xs_setup_tcp_tls, + .netid = { "tcp", "tcp6", "" }, +}; + static struct xprt_class xs_bc_tcp_transport = { .list = LIST_HEAD_INIT(xs_bc_tcp_transport.list), .name = "tcp NFSv4.1 backchannel", @@ -3174,11 +3674,12 @@ static struct xprt_class xs_bc_tcp_transport = { int init_socket_xprt(void) { if (!sunrpc_table_header) - sunrpc_table_header = register_sysctl_table(sunrpc_table); + sunrpc_table_header = register_sysctl("sunrpc", xs_tunables_table); xprt_register_transport(&xs_local_transport); xprt_register_transport(&xs_udp_transport); xprt_register_transport(&xs_tcp_transport); + xprt_register_transport(&xs_tcp_tls_transport); xprt_register_transport(&xs_bc_tcp_transport); return 0; @@ -3198,6 +3699,7 @@ void cleanup_socket_xprt(void) xprt_unregister_transport(&xs_local_transport); xprt_unregister_transport(&xs_udp_transport); xprt_unregister_transport(&xs_tcp_transport); + xprt_unregister_transport(&xs_tcp_tls_transport); xprt_unregister_transport(&xs_bc_tcp_transport); } |
