diff options
Diffstat (limited to 'net/sunrpc/xprtsock.c')
-rw-r--r-- | net/sunrpc/xprtsock.c | 434 |
1 files changed, 430 insertions, 4 deletions
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 5f9030b81c9e..9f010369100a 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -47,6 +47,9 @@ #include <net/checksum.h> #include <net/udp.h> #include <net/tcp.h> +#include <net/tls.h> +#include <net/handshake.h> + #include <linux/bvec.h> #include <linux/highmem.h> #include <linux/uio.h> @@ -96,6 +99,7 @@ static struct ctl_table_header *sunrpc_table_header; static struct xprt_class xs_local_transport; static struct xprt_class xs_udp_transport; static struct xprt_class xs_tcp_transport; +static struct xprt_class xs_tcp_tls_transport; static struct xprt_class xs_bc_tcp_transport; /* @@ -187,6 +191,11 @@ static struct ctl_table xs_tunables_table[] = { */ #define XS_IDLE_DISC_TO (5U * 60 * HZ) +/* + * TLS handshake timeout. + */ +#define XS_TLS_HANDSHAKE_TO (10U * HZ) + #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # undef RPC_DEBUG_DATA # define RPCDBG_FACILITY RPCDBG_TRANS @@ -253,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) switch (sap->sa_family) { case AF_LOCAL: sun = xs_addr_un(xprt); - strscpy(buf, sun->sun_path, sizeof(buf)); + if (sun->sun_path[0]) { + strscpy(buf, sun->sun_path, sizeof(buf)); + } else { + buf[0] = '@'; + strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1); + } xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); break; @@ -342,13 +356,56 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) return want; } +static int +xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, + struct cmsghdr *cmsg, int ret) +{ + if (cmsg->cmsg_level == SOL_TLS && + cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { + u8 content_type = *((u8 *)CMSG_DATA(cmsg)); + + switch (content_type) { + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + ret = -ENOTCONN; + break; + default: + ret = -EAGAIN; + } + } + return ret; +} + +static int +xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags) +{ + union { + struct cmsghdr cmsg; + u8 buf[CMSG_SPACE(sizeof(u8))]; + } u; + int ret; + + msg->msg_control = &u; + msg->msg_controllen = sizeof(u); + ret = sock_recvmsg(sock, msg, flags); + if (msg->msg_controllen != sizeof(u)) + ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret); + return ret; +} + static ssize_t xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) { ssize_t ret; if (seek != 0) iov_iter_advance(&msg->msg_iter, seek); - ret = sock_recvmsg(sock, msg, flags); + ret = xs_sock_recv_cmsg(sock, msg, flags); return ret > 0 ? ret + seek : ret; } @@ -374,7 +431,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, size_t count) { iov_iter_discard(&msg->msg_iter, ITER_DEST, count); - return sock_recvmsg(sock, msg, flags); + return xs_sock_recv_cmsg(sock, msg, flags); } #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE @@ -695,6 +752,8 @@ static void xs_poll_check_readable(struct sock_xprt *transport) { clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; if (!xs_poll_socket_readable(transport)) return; if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) @@ -1191,6 +1250,8 @@ static void xs_reset_transport(struct sock_xprt *transport) if (atomic_read(&transport->xprt.swapper)) sk_clear_memalloc(sk); + tls_handshake_cancel(sk); + kernel_sock_shutdown(sock, SHUT_RDWR); mutex_lock(&transport->recv_mutex); @@ -1380,6 +1441,10 @@ static void xs_data_ready(struct sock *sk) trace_xs_data_ready(xprt); transport->old_data_ready(sk); + + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; + /* Any data means we had a useful conversation, so * then we don't need to delay the next reconnect */ @@ -2360,6 +2425,267 @@ out_unlock: current_restore_flags(pflags, PF_MEMALLOC); } +/* + * Transfer the connected socket to @upper_transport, then mark that + * xprt CONNECTED. + */ +static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt, + struct sock_xprt *upper_transport) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + + if (!upper_transport->inet) { + struct socket *sock = lower_transport->sock; + struct sock *sk = sock->sk; + + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. + */ + if (xs_addr(upper_xprt)->sa_family == PF_INET6) + ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC); + + xs_tcp_set_socket_timeouts(upper_xprt, sock); + tcp_sock_set_nodelay(sk); + + lock_sock(sk); + + /* @sk is already connected, so it now has the RPC callbacks. + * Reach into @lower_transport to save the original ones. + */ + upper_transport->old_data_ready = lower_transport->old_data_ready; + upper_transport->old_state_change = lower_transport->old_state_change; + upper_transport->old_write_space = lower_transport->old_write_space; + upper_transport->old_error_report = lower_transport->old_error_report; + sk->sk_user_data = upper_xprt; + + /* socket options */ + sock_reset_flag(sk, SOCK_LINGER); + + xprt_clear_connected(upper_xprt); + + upper_transport->sock = sock; + upper_transport->inet = sk; + upper_transport->file = lower_transport->file; + + release_sock(sk); + + /* Reset lower_transport before shutting down its clnt */ + mutex_lock(&lower_transport->recv_mutex); + lower_transport->inet = NULL; + lower_transport->sock = NULL; + lower_transport->file = NULL; + + xprt_clear_connected(lower_xprt); + xs_sock_reset_connection_flags(lower_xprt); + xs_stream_reset_connect(lower_transport); + mutex_unlock(&lower_transport->recv_mutex); + } + + if (!xprt_bound(upper_xprt)) + return -ENOTCONN; + + xs_set_memalloc(upper_xprt); + + if (!xprt_test_and_set_connected(upper_xprt)) { + upper_xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + + upper_xprt->stat.connect_count++; + upper_xprt->stat.connect_time += (long)jiffies - + upper_xprt->stat.connect_start; + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + } + return 0; +} + +/** + * xs_tls_handshake_done - TLS handshake completion handler + * @data: address of xprt to wake + * @status: status of handshake + * @peerid: serial number of key containing the remote's identity + * + */ +static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid) +{ + struct rpc_xprt *lower_xprt = data; + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + + lower_transport->xprt_err = status ? -EACCES : 0; + complete(&lower_transport->handshake_done); + xprt_put(lower_xprt); +} + +static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct tls_handshake_args args = { + .ta_sock = lower_transport->sock, + .ta_done = xs_tls_handshake_done, + .ta_data = xprt_get(lower_xprt), + .ta_peername = lower_xprt->servername, + }; + struct sock *sk = lower_transport->inet; + int rc; + + init_completion(&lower_transport->handshake_done); + set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + lower_transport->xprt_err = -ETIMEDOUT; + switch (xprtsec->policy) { + case RPC_XPRTSEC_TLS_ANON: + rc = tls_client_hello_anon(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + case RPC_XPRTSEC_TLS_X509: + args.ta_my_cert = xprtsec->cert_serial; + args.ta_my_privkey = xprtsec->privkey_serial; + rc = tls_client_hello_x509(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + default: + rc = -EACCES; + goto out_put_xprt; + } + + rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done, + XS_TLS_HANDSHAKE_TO); + if (rc <= 0) { + if (!tls_handshake_cancel(sk)) { + if (rc == 0) + rc = -ETIMEDOUT; + goto out_put_xprt; + } + } + + rc = lower_transport->xprt_err; + +out: + xs_stream_reset_connect(lower_transport); + clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + return rc; + +out_put_xprt: + xprt_put(lower_xprt); + goto out; +} + +/** + * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket + * @work: queued work item + * + * Invoked by a work queue tasklet. + * + * For RPC-with-TLS, there is a two-stage connection process. + * + * The "upper-layer xprt" is visible to the RPC consumer. Once it has + * been marked connected, the consumer knows that a TCP connection and + * a TLS session have been established. + * + * A "lower-layer xprt", created in this function, handles the mechanics + * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and + * then driving the TLS handshake. Once all that is complete, the upper + * layer xprt is marked connected. + */ +static void xs_tcp_tls_setup_socket(struct work_struct *work) +{ + struct sock_xprt *upper_transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_clnt *upper_clnt = upper_transport->clnt; + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + struct rpc_create_args args = { + .net = upper_xprt->xprt_net, + .protocol = upper_xprt->prot, + .address = (struct sockaddr *)&upper_xprt->addr, + .addrsize = upper_xprt->addrlen, + .timeout = upper_clnt->cl_timeout, + .servername = upper_xprt->servername, + .program = upper_clnt->cl_program, + .prognumber = upper_clnt->cl_prog, + .version = upper_clnt->cl_vers, + .authflavor = RPC_AUTH_TLS, + .cred = upper_clnt->cl_cred, + .xprtsec = { + .policy = RPC_XPRTSEC_NONE, + }, + }; + unsigned int pflags = current->flags; + struct rpc_clnt *lower_clnt; + struct rpc_xprt *lower_xprt; + int status; + + if (atomic_read(&upper_xprt->swapper)) + current->flags |= PF_MEMALLOC; + + xs_stream_start_connect(upper_transport); + + /* This implicitly sends an RPC_AUTH_TLS probe */ + lower_clnt = rpc_create(&args); + if (IS_ERR(lower_clnt)) { + trace_rpc_tls_unavailable(upper_clnt, upper_xprt); + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt)); + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + goto out_unlock; + } + + /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on + * the lower xprt. + */ + rcu_read_lock(); + lower_xprt = rcu_dereference(lower_clnt->cl_xprt); + rcu_read_unlock(); + status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec); + if (status) { + trace_rpc_tls_not_started(upper_clnt, upper_xprt); + goto out_close; + } + + status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport); + if (status) + goto out_close; + + trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0); + if (!xprt_test_and_set_connected(upper_xprt)) { + upper_xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + + upper_xprt->stat.connect_count++; + upper_xprt->stat.connect_time += (long)jiffies - + upper_xprt->stat.connect_start; + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + } + rpc_shutdown_client(lower_clnt); + +out_unlock: + current_restore_flags(pflags, PF_MEMALLOC); + upper_transport->clnt = NULL; + xprt_unlock_connect(upper_xprt, upper_transport); + return; + +out_close: + rpc_shutdown_client(lower_clnt); + + /* xprt_force_disconnect() wakes tasks with a fixed tk_status code. + * Wake them first here to ensure they get our tk_status code. + */ + xprt_wake_pending_tasks(upper_xprt, status); + xs_tcp_force_close(upper_xprt); + xprt_clear_connecting(upper_xprt); + goto out_unlock; +} + /** * xs_connect - connect a socket to a remote endpoint * @xprt: pointer to transport structure @@ -2391,6 +2717,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) } else dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); + transport->clnt = task->tk_client; queue_delayed_work(xprtiod_workqueue, &transport->connect_worker, delay); @@ -2858,7 +3185,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) switch (sun->sun_family) { case AF_LOCAL: - if (sun->sun_path[0] != '/') { + if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') { dprintk("RPC: bad AF_LOCAL address: %s\n", sun->sun_path); ret = ERR_PTR(-EINVAL); @@ -3045,6 +3372,94 @@ out_err: } /** + * xs_setup_tcp_tls - Set up transport to use a TCP with TLS + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xs_tcp_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + + switch (args->xprtsec.policy) { + case RPC_XPRTSEC_TLS_ANON: + case RPC_XPRTSEC_TLS_X509: + xprt->xprtsec = args->xprtsec; + INIT_DELAYED_WORK(&transport->connect_worker, + xs_tcp_tls_setup_socket); + break; + default: + ret = ERR_PTR(-EACCES); + goto out_err; + } + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +/** * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket * @args: rpc transport creation arguments * @@ -3153,6 +3568,15 @@ static struct xprt_class xs_tcp_transport = { .netid = { "tcp", "tcp6", "" }, }; +static struct xprt_class xs_tcp_tls_transport = { + .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list), + .name = "tcp-with-tls", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_TCP_TLS, + .setup = xs_setup_tcp_tls, + .netid = { "tcp", "tcp6", "" }, +}; + static struct xprt_class xs_bc_tcp_transport = { .list = LIST_HEAD_INIT(xs_bc_tcp_transport.list), .name = "tcp NFSv4.1 backchannel", @@ -3174,6 +3598,7 @@ int init_socket_xprt(void) xprt_register_transport(&xs_local_transport); xprt_register_transport(&xs_udp_transport); xprt_register_transport(&xs_tcp_transport); + xprt_register_transport(&xs_tcp_tls_transport); xprt_register_transport(&xs_bc_tcp_transport); return 0; @@ -3193,6 +3618,7 @@ void cleanup_socket_xprt(void) xprt_unregister_transport(&xs_local_transport); xprt_unregister_transport(&xs_udp_transport); xprt_unregister_transport(&xs_tcp_transport); + xprt_unregister_transport(&xs_tcp_tls_transport); xprt_unregister_transport(&xs_bc_tcp_transport); } |