Diffstat (limited to 'net/sunrpc/xprtsock.c')
-rw-r--r--  net/sunrpc/xprtsock.c  608
1 file changed, 555 insertions(+), 53 deletions(-)
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index aaa5b2741b79..2e1fe6013361 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -47,17 +47,22 @@
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>
+#include <net/tls_prot.h>
+#include <net/handshake.h>
+
#include <linux/bvec.h>
#include <linux/highmem.h>
#include <linux/uio.h>
#include <linux/sched/mm.h>
+#include <trace/events/sock.h>
#include <trace/events/sunrpc.h>
#include "socklib.h"
#include "sunrpc.h"
static void xs_close(struct rpc_xprt *xprt);
+static void xs_reset_srcport(struct sock_xprt *transport);
static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock);
static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
struct socket *sock);
@@ -77,7 +82,7 @@ static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO;
/*
* We can register our own files under /proc/sys/sunrpc by
- * calling register_sysctl_table() again. The files in that
+ * calling register_sysctl() again. The files in that
* directory become the union of all files registered there.
*
* We simply need to make sure that we don't collide with
@@ -95,6 +100,7 @@ static struct ctl_table_header *sunrpc_table_header;
static struct xprt_class xs_local_transport;
static struct xprt_class xs_udp_transport;
static struct xprt_class xs_tcp_transport;
+static struct xprt_class xs_tcp_tls_transport;
static struct xprt_class xs_bc_tcp_transport;
/*
@@ -154,16 +160,6 @@ static struct ctl_table xs_tunables_table[] = {
.mode = 0644,
.proc_handler = proc_dointvec_jiffies,
},
- { },
-};
-
-static struct ctl_table sunrpc_table[] = {
- {
- .procname = "sunrpc",
- .mode = 0555,
- .child = xs_tunables_table
- },
- { },
};
/*
@@ -195,6 +191,11 @@ static struct ctl_table sunrpc_table[] = {
*/
#define XS_IDLE_DISC_TO (5U * 60 * HZ)
+/*
+ * TLS handshake timeout.
+ */
+#define XS_TLS_HANDSHAKE_TO (10U * HZ)
+
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
# undef RPC_DEBUG_DATA
# define RPCDBG_FACILITY RPCDBG_TRANS
@@ -261,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt)
switch (sap->sa_family) {
case AF_LOCAL:
sun = xs_addr_un(xprt);
- strscpy(buf, sun->sun_path, sizeof(buf));
+ if (sun->sun_path[0]) {
+ strscpy(buf, sun->sun_path, sizeof(buf));
+ } else {
+ buf[0] = '@';
+ strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1);
+ }
xprt->address_strings[RPC_DISPLAY_ADDR] =
kstrdup(buf, GFP_KERNEL);
break;
@@ -350,6 +356,66 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)
return want;
}
+static int
+xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
+ unsigned int *msg_flags, struct cmsghdr *cmsg, int ret)
+{
+ u8 content_type = tls_get_record_type(sock->sk, cmsg);
+ u8 level, description;
+
+ switch (content_type) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ *msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(sock->sk, msg, &level, &description);
+ ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+ -EACCES : -EAGAIN;
+ break;
+ default:
+ /* discard this record type */
+ ret = -EAGAIN;
+ }
+ return ret;
+}
+
+static int
+xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags)
+{
+ union {
+ struct cmsghdr cmsg;
+ u8 buf[CMSG_SPACE(sizeof(u8))];
+ } u;
+ u8 alert[2];
+ struct kvec alert_kvec = {
+ .iov_base = alert,
+ .iov_len = sizeof(alert),
+ };
+ struct msghdr msg = {
+ .msg_flags = *msg_flags,
+ .msg_control = &u,
+ .msg_controllen = sizeof(u),
+ };
+ int ret;
+
+ iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1,
+ alert_kvec.iov_len);
+ ret = sock_recvmsg(sock, &msg, flags);
+ if (ret > 0) {
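+ /* If an alert record was received, rewind the iterator so
+  * the alert payload can be parsed below.
+  */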
+ if (tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT)
+ iov_iter_revert(&msg.msg_iter, ret);
+ ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg,
+ -EAGAIN);
+ }
+ return ret;
+}
+
static ssize_t
xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
{
@@ -357,6 +423,12 @@ xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
if (seek != 0)
iov_iter_advance(&msg->msg_iter, seek);
ret = sock_recvmsg(sock, msg, flags);
+ /* Handle TLS inband control message lazily */
+ if (msg->msg_flags & MSG_CTRUNC) {
+ msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR);
+ if (ret == 0 || ret == -EIO)
+ ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags);
+ }
return ret > 0 ? ret + seek : ret;
}
@@ -382,7 +454,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
size_t count)
{
iov_iter_discard(&msg->msg_iter, ITER_DEST, count);
- return sock_recvmsg(sock, msg, flags);
+ return xs_sock_recvmsg(sock, msg, flags, 0);
}
#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
@@ -703,6 +775,8 @@ static void xs_poll_check_readable(struct sock_xprt *transport)
{
clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
+ if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
+ return;
if (!xs_poll_socket_readable(transport))
return;
if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state))
@@ -726,6 +800,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport)
}
if (ret == -ESHUTDOWN)
kernel_sock_shutdown(transport->sock, SHUT_RDWR);
+ else if (ret == -EACCES)
+ xprt_wake_pending_tasks(&transport->xprt, -EACCES);
else
xs_poll_check_readable(transport);
out:
@@ -827,6 +903,17 @@ static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf)
return xdr_alloc_bvec(buf, rpc_task_gfp_mask());
}
+static void xs_stream_abort_send_request(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct sock_xprt *transport =
+ container_of(xprt, struct sock_xprt, xprt);
+
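+ /* A record that was partially transmitted cannot simply be
+  * dropped; force a disconnect so the stream can be resynced.
+  */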
+ if (transport->xmit.offset != 0 &&
+ !test_bit(XPRT_CLOSE_WAIT, &xprt->state))
+ xprt_force_disconnect(xprt);
+}
+
/*
* Determine if the previous message in the stream was aborted before it
* could complete transmission.
@@ -1125,11 +1212,13 @@ static void xs_sock_reset_state_flags(struct rpc_xprt *xprt)
{
struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ transport->xprt_err = 0;
clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state);
clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state);
clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state);
clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state);
+ clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
}
static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr)
@@ -1199,6 +1288,8 @@ static void xs_reset_transport(struct sock_xprt *transport)
if (atomic_read(&transport->xprt.swapper))
sk_clear_memalloc(sk);
+ tls_handshake_cancel(sk);
+
kernel_sock_shutdown(sock, SHUT_RDWR);
mutex_lock(&transport->recv_mutex);
@@ -1208,6 +1299,7 @@ static void xs_reset_transport(struct sock_xprt *transport)
transport->file = NULL;
sk->sk_user_data = NULL;
+ sk->sk_sndtimeo = 0;
xs_restore_old_callbacks(transport, sk);
xprt_clear_connected(xprt);
@@ -1239,6 +1331,8 @@ static void xs_close(struct rpc_xprt *xprt)
dprintk("RPC: xs_close xprt %p\n", xprt);
+ if (transport->sock)
+ tls_handshake_close(transport->sock);
xs_reset_transport(transport);
xprt->reestablish_timeout = 0;
}
@@ -1378,6 +1472,8 @@ static void xs_data_ready(struct sock *sk)
{
struct rpc_xprt *xprt;
+ trace_sk_data_ready(sk);
+
xprt = xprt_from_sock(sk);
if (xprt != NULL) {
struct sock_xprt *transport = container_of(xprt,
@@ -1386,6 +1482,10 @@ static void xs_data_ready(struct sock *sk)
trace_xs_data_ready(xprt);
transport->old_data_ready(sk);
+
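+ /* While a TLS handshake is in progress, the incoming data
+  * belongs to the handshake daemon, not to the RPC client.
+  */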
+ if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
+ return;
+
/* Any data means we had a useful conversation, so
* then we don't need to delay the next reconnect
*/
@@ -1498,8 +1598,10 @@ static void xs_tcp_state_change(struct sock *sk)
break;
case TCP_CLOSE:
if (test_and_clear_bit(XPRT_SOCK_CONNECTING,
- &transport->sock_state))
+ &transport->sock_state)) {
+ xs_reset_srcport(transport);
xprt_clear_connecting(xprt);
+ }
clear_bit(XPRT_CLOSING, &xprt->state);
/* Trigger the socket release */
xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT);
@@ -1655,6 +1757,11 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port)
xs_update_peer_port(xprt);
}
+static void xs_reset_srcport(struct sock_xprt *transport)
+{
+ transport->srcport = 0;
+}
+
static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock)
{
if (transport->srcport == 0 && transport->xprt.reuseport)
@@ -1738,8 +1845,8 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock)
memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen);
do {
rpc_set_port((struct sockaddr *)&myaddr, port);
- err = kernel_bind(sock, (struct sockaddr *)&myaddr,
- transport->xprt.addrlen);
+ err = kernel_bind(sock, (struct sockaddr_unsized *)&myaddr,
+ transport->xprt.addrlen);
if (err == 0) {
if (transport->xprt.reuseport)
transport->srcport = port;
@@ -1854,6 +1961,9 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
goto out;
}
+ if (protocol == IPPROTO_TCP)
+ sk_net_refcnt_upgrade(sock->sk);
+
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
if (IS_ERR(filp))
return ERR_CAST(filp);
@@ -1895,7 +2005,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt,
xs_stream_start_connect(transport);
- return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0);
+ return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), xprt->addrlen, 0);
}
/**
@@ -2155,6 +2265,7 @@ static void xs_tcp_shutdown(struct rpc_xprt *xprt)
switch (skst) {
case TCP_FIN_WAIT1:
case TCP_FIN_WAIT2:
+ case TCP_LAST_ACK:
break;
case TCP_ESTABLISHED:
case TCP_CLOSE_WAIT:
@@ -2170,9 +2281,13 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
struct socket *sock)
{
struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct net *net = sock_net(sock->sk);
+ unsigned long connect_timeout;
+ unsigned long syn_retries;
unsigned int keepidle;
unsigned int keepcnt;
unsigned int timeo;
+ unsigned long t;
spin_lock(&xprt->transport_lock);
keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ);
@@ -2190,6 +2305,35 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
/* TCP user timeout (see RFC5482) */
tcp_sock_set_user_timeout(sock->sk, timeo);
+
+ /* Connect timeout */
+ connect_timeout = max_t(unsigned long,
+ DIV_ROUND_UP(xprt->connect_timeout, HZ), 1);
+ syn_retries = max_t(unsigned long,
+ READ_ONCE(net->ipv4.sysctl_tcp_syn_retries), 1);
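+ /* Pick the smallest SYN count whose exponential backoff
+  * (roughly 2^t seconds overall) covers the connect timeout;
+  * otherwise leave the sysctl default in place.
+  */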
+ for (t = 0; t <= syn_retries && (1UL << t) < connect_timeout; t++)
+ ;
+ if (t <= syn_retries)
+ tcp_sock_set_syncnt(sock->sk, t - 1);
+}
+
+static void xs_tcp_do_set_connect_timeout(struct rpc_xprt *xprt,
+ unsigned long connect_timeout)
+{
+ struct sock_xprt *transport =
+ container_of(xprt, struct sock_xprt, xprt);
+ struct rpc_timeout to;
+ unsigned long initval;
+
+ memcpy(&to, xprt->timeout, sizeof(to));
+ /* Arbitrary lower limit */
+ initval = max_t(unsigned long, connect_timeout, XS_TCP_INIT_REEST_TO);
+ to.to_initval = initval;
+ to.to_maxval = initval;
+ to.to_retries = 0;
+ memcpy(&transport->tcp_timeout, &to, sizeof(transport->tcp_timeout));
+ xprt->timeout = &transport->tcp_timeout;
+ xprt->connect_timeout = connect_timeout;
}
static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt,
@@ -2197,25 +2341,12 @@ static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt,
unsigned long reconnect_timeout)
{
struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
- struct rpc_timeout to;
- unsigned long initval;
spin_lock(&xprt->transport_lock);
if (reconnect_timeout < xprt->max_reconnect_timeout)
xprt->max_reconnect_timeout = reconnect_timeout;
- if (connect_timeout < xprt->connect_timeout) {
- memcpy(&to, xprt->timeout, sizeof(to));
- initval = DIV_ROUND_UP(connect_timeout, to.to_retries + 1);
- /* Arbitrary lower limit */
- if (initval < XS_TCP_INIT_REEST_TO << 1)
- initval = XS_TCP_INIT_REEST_TO << 1;
- to.to_initval = initval;
- to.to_maxval = initval;
- memcpy(&transport->tcp_timeout, &to,
- sizeof(transport->tcp_timeout));
- xprt->timeout = &transport->tcp_timeout;
- xprt->connect_timeout = connect_timeout;
- }
+ if (connect_timeout < xprt->connect_timeout)
+ xs_tcp_do_set_connect_timeout(xprt, connect_timeout);
set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
spin_unlock(&xprt->transport_lock);
}
@@ -2274,7 +2405,8 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
/* Tell the socket layer to start connecting... */
set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state);
- return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
+ return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt),
+ xprt->addrlen, O_NONBLOCK);
}
/**
@@ -2335,6 +2467,13 @@ static void xs_tcp_setup_socket(struct work_struct *work)
transport->srcport = 0;
status = -EAGAIN;
break;
+ case -EPERM:
+ /* Happens, for instance, if a BPF program is preventing
+ * the connect. Remap the error so upper layers can better
+ * deal with it.
+ */
+ status = -ECONNREFUSED;
+ fallthrough;
case -EINVAL:
/* Happens, for instance, if the user specified a link
* local IPv6 address without a scope-id.
@@ -2346,6 +2485,7 @@ static void xs_tcp_setup_socket(struct work_struct *work)
case -EHOSTUNREACH:
case -EADDRINUSE:
case -ENOBUFS:
+ case -ENOTCONN:
break;
default:
printk("%s: connect returned unhandled error %d\n",
@@ -2365,6 +2505,275 @@ out_unlock:
current_restore_flags(pflags, PF_MEMALLOC);
}
+/*
+ * Transfer the connected socket to @upper_transport, then mark that
+ * xprt CONNECTED.
+ */
+static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt,
+ struct sock_xprt *upper_transport)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+
+ if (!upper_transport->inet) {
+ struct socket *sock = lower_transport->sock;
+ struct sock *sk = sock->sk;
+
+ /* Avoid temporary addresses; they are bad for long-lived
+ * connections such as NFS mounts.
+ * RFC4941, section 3.6 suggests that:
+ * Individual applications, which have specific
+ * knowledge about the normal duration of connections,
+ * MAY override this as appropriate.
+ */
+ if (xs_addr(upper_xprt)->sa_family == PF_INET6)
+ ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC);
+
+ xs_tcp_set_socket_timeouts(upper_xprt, sock);
+ tcp_sock_set_nodelay(sk);
+
+ lock_sock(sk);
+
+ /* @sk is already connected, so it now has the RPC callbacks.
+ * Reach into @lower_transport to save the original ones.
+ */
+ upper_transport->old_data_ready = lower_transport->old_data_ready;
+ upper_transport->old_state_change = lower_transport->old_state_change;
+ upper_transport->old_write_space = lower_transport->old_write_space;
+ upper_transport->old_error_report = lower_transport->old_error_report;
+ sk->sk_user_data = upper_xprt;
+
+ /* socket options */
+ sock_reset_flag(sk, SOCK_LINGER);
+
+ xprt_clear_connected(upper_xprt);
+
+ upper_transport->sock = sock;
+ upper_transport->inet = sk;
+ upper_transport->file = lower_transport->file;
+
+ release_sock(sk);
+
+ /* Reset lower_transport before shutting down its clnt */
+ mutex_lock(&lower_transport->recv_mutex);
+ lower_transport->inet = NULL;
+ lower_transport->sock = NULL;
+ lower_transport->file = NULL;
+
+ xprt_clear_connected(lower_xprt);
+ xs_sock_reset_connection_flags(lower_xprt);
+ xs_stream_reset_connect(lower_transport);
+ mutex_unlock(&lower_transport->recv_mutex);
+ }
+
+ if (!xprt_bound(upper_xprt))
+ return -ENOTCONN;
+
+ xs_set_memalloc(upper_xprt);
+
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ return 0;
+}
+
+/**
+ * xs_tls_handshake_done - TLS handshake completion handler
+ * @data: address of xprt to wake
+ * @status: status of handshake
+ * @peerid: serial number of key containing the remote's identity
+ *
+ */
+static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
+{
+ struct rpc_xprt *lower_xprt = data;
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+
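+ /* Any other handshake failure is treated as an
+  * authentication failure.
+  */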
+ switch (status) {
+ case 0:
+ case -EACCES:
+ case -ETIMEDOUT:
+ lower_transport->xprt_err = status;
+ break;
+ default:
+ lower_transport->xprt_err = -EACCES;
+ }
+ complete(&lower_transport->handshake_done);
+ xprt_put(lower_xprt);
+}
+
+static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct tls_handshake_args args = {
+ .ta_sock = lower_transport->sock,
+ .ta_done = xs_tls_handshake_done,
+ .ta_data = xprt_get(lower_xprt),
+ .ta_peername = lower_xprt->servername,
+ };
+ struct sock *sk = lower_transport->inet;
+ int rc;
+
+ init_completion(&lower_transport->handshake_done);
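+ /* The handshake daemon owns the socket until the handshake
+  * completes, so suppress normal RPC receive processing.
+  */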
+ set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ lower_transport->xprt_err = -ETIMEDOUT;
+ switch (xprtsec->policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ rc = tls_client_hello_anon(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ case RPC_XPRTSEC_TLS_X509:
+ args.ta_my_cert = xprtsec->cert_serial;
+ args.ta_my_privkey = xprtsec->privkey_serial;
+ rc = tls_client_hello_x509(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ default:
+ rc = -EACCES;
+ goto out_put_xprt;
+ }
+
+ rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
+ XS_TLS_HANDSHAKE_TO);
+ if (rc <= 0) {
+ tls_handshake_cancel(sk);
+ if (rc == 0)
+ rc = -ETIMEDOUT;
+ goto out_put_xprt;
+ }
+
+ rc = lower_transport->xprt_err;
+
+out:
+ xs_stream_reset_connect(lower_transport);
+ clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ return rc;
+
+out_put_xprt:
+ xprt_put(lower_xprt);
+ goto out;
+}
+
+/**
+ * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket
+ * @work: queued work item
+ *
+ * Invoked by a work queue tasklet.
+ *
+ * For RPC-with-TLS, there is a two-stage connection process.
+ *
+ * The "upper-layer xprt" is visible to the RPC consumer. Once it has
+ * been marked connected, the consumer knows that a TCP connection and
+ * a TLS session have been established.
+ *
+ * A "lower-layer xprt", created in this function, handles the mechanics
+ * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
+ * then driving the TLS handshake. Once all that is complete, the upper
+ * layer xprt is marked connected.
+ */
+static void xs_tcp_tls_setup_socket(struct work_struct *work)
+{
+ struct sock_xprt *upper_transport =
+ container_of(work, struct sock_xprt, connect_worker.work);
+ struct rpc_clnt *upper_clnt = upper_transport->clnt;
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+ struct rpc_create_args args = {
+ .net = upper_xprt->xprt_net,
+ .protocol = upper_xprt->prot,
+ .address = (struct sockaddr *)&upper_xprt->addr,
+ .addrsize = upper_xprt->addrlen,
+ .timeout = upper_clnt->cl_timeout,
+ .servername = upper_xprt->servername,
+ .program = upper_clnt->cl_program,
+ .prognumber = upper_clnt->cl_prog,
+ .version = upper_clnt->cl_vers,
+ .authflavor = RPC_AUTH_TLS,
+ .cred = upper_clnt->cl_cred,
+ .xprtsec = {
+ .policy = RPC_XPRTSEC_NONE,
+ },
+ .stats = upper_clnt->cl_stats,
+ };
+ unsigned int pflags = current->flags;
+ struct rpc_clnt *lower_clnt;
+ struct rpc_xprt *lower_xprt;
+ int status;
+
+ if (atomic_read(&upper_xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+
+ xs_stream_start_connect(upper_transport);
+
+ /* This implicitly sends an RPC_AUTH_TLS probe */
+ lower_clnt = rpc_create(&args);
+ if (IS_ERR(lower_clnt)) {
+ trace_rpc_tls_unavailable(upper_clnt, upper_xprt);
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+ xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ goto out_unlock;
+ }
+
+ /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
+ * the lower xprt.
+ */
+ rcu_read_lock();
+ lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
+ rcu_read_unlock();
+
+ if (wait_on_bit_lock(&lower_xprt->state, XPRT_LOCKED, TASK_KILLABLE))
+ goto out_unlock;
+
+ status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
+ if (status) {
+ trace_rpc_tls_not_started(upper_clnt, upper_xprt);
+ goto out_close;
+ }
+
+ status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
+ if (status)
+ goto out_close;
+ xprt_release_write(lower_xprt, NULL);
+ trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
+ rpc_shutdown_client(lower_clnt);
+
+ /* Check for ingress data that arrived before the socket's
+ * ->data_ready callback was set up.
+ */
+ xs_poll_check_readable(upper_transport);
+
+out_unlock:
+ current_restore_flags(pflags, PF_MEMALLOC);
+ upper_transport->clnt = NULL;
+ xprt_unlock_connect(upper_xprt, upper_transport);
+ return;
+
+out_close:
+ xprt_release_write(lower_xprt, NULL);
+ rpc_shutdown_client(lower_clnt);
+
+ /* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
+ * Wake them first here to ensure they get our tk_status code.
+ */
+ xprt_wake_pending_tasks(upper_xprt, status);
+ xs_tcp_force_close(upper_xprt);
+ xprt_clear_connecting(upper_xprt);
+ goto out_unlock;
+}
+
/**
* xs_connect - connect a socket to a remote endpoint
* @xprt: pointer to transport structure
@@ -2396,6 +2805,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
} else
dprintk("RPC: xs_connect scheduled xprt %p\n", xprt);
+ transport->clnt = task->tk_client;
queue_delayed_work(xprtiod_workqueue,
&transport->connect_worker,
delay);
@@ -2417,18 +2827,13 @@ static void xs_wake_error(struct sock_xprt *transport)
{
int sockerr;
- if (!test_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
- return;
- mutex_lock(&transport->recv_mutex);
- if (transport->sock == NULL)
- goto out;
if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
- goto out;
+ return;
sockerr = xchg(&transport->xprt_err, 0);
- if (sockerr < 0)
+ if (sockerr < 0) {
xprt_wake_pending_tasks(&transport->xprt, sockerr);
-out:
- mutex_unlock(&transport->recv_mutex);
+ xs_tcp_force_close(&transport->xprt);
+ }
}
static void xs_wake_pending(struct sock_xprt *transport)
@@ -2636,20 +3041,11 @@ static int bc_send_request(struct rpc_rqst *req)
return len;
}
-/*
- * The close routine. Since this is client initiated, we do nothing
- */
-
static void bc_close(struct rpc_xprt *xprt)
{
xprt_disconnect_done(xprt);
}
-/*
- * The xprt destroy routine. Again, because this connection is client
- * initiated, we do nothing
- */
-
static void bc_destroy(struct rpc_xprt *xprt)
{
dprintk("RPC: bc_destroy xprt %p\n", xprt);
@@ -2670,6 +3066,7 @@ static const struct rpc_xprt_ops xs_local_ops = {
.buf_free = rpc_free,
.prepare_request = xs_stream_prepare_request,
.send_request = xs_local_send_request,
+ .abort_send_request = xs_stream_abort_send_request,
.wait_for_reply_request = xprt_wait_for_reply_request_def,
.close = xs_close,
.destroy = xs_destroy,
@@ -2717,6 +3114,7 @@ static const struct rpc_xprt_ops xs_tcp_ops = {
.buf_free = rpc_free,
.prepare_request = xs_stream_prepare_request,
.send_request = xs_tcp_send_request,
+ .abort_send_request = xs_stream_abort_send_request,
.wait_for_reply_request = xprt_wait_for_reply_request_def,
.close = xs_tcp_shutdown,
.destroy = xs_destroy,
@@ -2863,7 +3261,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args)
switch (sun->sun_family) {
case AF_LOCAL:
- if (sun->sun_path[0] != '/') {
+ if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') {
dprintk("RPC: bad AF_LOCAL address: %s\n",
sun->sun_path);
ret = ERR_PTR(-EINVAL);
@@ -3006,8 +3404,13 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
xprt->timeout = &xs_tcp_default_timeout;
xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ if (args->reconnect_timeout)
+ xprt->max_reconnect_timeout = args->reconnect_timeout;
+
xprt->connect_timeout = xprt->timeout->to_initval *
(xprt->timeout->to_retries + 1);
+ if (args->connect_timeout)
+ xs_tcp_do_set_connect_timeout(xprt, args->connect_timeout);
INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
INIT_WORK(&transport->error_worker, xs_error_handle);
@@ -3050,6 +3453,94 @@ out_err:
}
/**
+ * xs_setup_tcp_tls - Set up transport to use a TCP socket with TLS
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args)
+{
+ struct sockaddr *addr = args->dstaddr;
+ struct rpc_xprt *xprt;
+ struct sock_xprt *transport;
+ struct rpc_xprt *ret;
+ unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;
+
+ if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
+ max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;
+
+ xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+ max_slot_table_size);
+ if (IS_ERR(xprt))
+ return xprt;
+ transport = container_of(xprt, struct sock_xprt, xprt);
+
+ xprt->prot = IPPROTO_TCP;
+ xprt->xprt_class = &xs_tcp_transport;
+ xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+
+ xprt->bind_timeout = XS_BIND_TO;
+ xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+ xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+ xprt->ops = &xs_tcp_ops;
+ xprt->timeout = &xs_tcp_default_timeout;
+
+ xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ xprt->connect_timeout = xprt->timeout->to_initval *
+ (xprt->timeout->to_retries + 1);
+
+ INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
+ INIT_WORK(&transport->error_worker, xs_error_handle);
+
+ switch (args->xprtsec.policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ case RPC_XPRTSEC_TLS_X509:
+ xprt->xprtsec = args->xprtsec;
+ INIT_DELAYED_WORK(&transport->connect_worker,
+ xs_tcp_tls_setup_socket);
+ break;
+ default:
+ ret = ERR_PTR(-EACCES);
+ goto out_err;
+ }
+
+ switch (addr->sa_family) {
+ case AF_INET:
+ if (((struct sockaddr_in *)addr)->sin_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
+ break;
+ case AF_INET6:
+ if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
+ xprt_set_bound(xprt);
+
+ xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
+ break;
+ default:
+ ret = ERR_PTR(-EAFNOSUPPORT);
+ goto out_err;
+ }
+
+ if (xprt_bound(xprt))
+ dprintk("RPC: set up xprt to %s (port %s) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PORT],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+ else
+ dprintk("RPC: set up xprt to %s (autobind) via %s\n",
+ xprt->address_strings[RPC_DISPLAY_ADDR],
+ xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+ if (try_module_get(THIS_MODULE))
+ return xprt;
+ ret = ERR_PTR(-EINVAL);
+out_err:
+ xs_xprt_free(xprt);
+ return ret;
+}
+
+/**
* xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket
* @args: rpc transport creation arguments
*
@@ -3158,6 +3649,15 @@ static struct xprt_class xs_tcp_transport = {
.netid = { "tcp", "tcp6", "" },
};
+static struct xprt_class xs_tcp_tls_transport = {
+ .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list),
+ .name = "tcp-with-tls",
+ .owner = THIS_MODULE,
+ .ident = XPRT_TRANSPORT_TCP_TLS,
+ .setup = xs_setup_tcp_tls,
+ .netid = { "tcp", "tcp6", "" },
+};
+
static struct xprt_class xs_bc_tcp_transport = {
.list = LIST_HEAD_INIT(xs_bc_tcp_transport.list),
.name = "tcp NFSv4.1 backchannel",
@@ -3174,11 +3674,12 @@ static struct xprt_class xs_bc_tcp_transport = {
int init_socket_xprt(void)
{
if (!sunrpc_table_header)
- sunrpc_table_header = register_sysctl_table(sunrpc_table);
+ sunrpc_table_header = register_sysctl("sunrpc", xs_tunables_table);
xprt_register_transport(&xs_local_transport);
xprt_register_transport(&xs_udp_transport);
xprt_register_transport(&xs_tcp_transport);
+ xprt_register_transport(&xs_tcp_tls_transport);
xprt_register_transport(&xs_bc_tcp_transport);
return 0;
@@ -3198,6 +3699,7 @@ void cleanup_socket_xprt(void)
xprt_unregister_transport(&xs_local_transport);
xprt_unregister_transport(&xs_udp_transport);
xprt_unregister_transport(&xs_tcp_transport);
+ xprt_unregister_transport(&xs_tcp_tls_transport);
xprt_unregister_transport(&xs_bc_tcp_transport);
}