summaryrefslogtreecommitdiff
path: root/net/sunrpc/xprtsock.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/sunrpc/xprtsock.c')
-rw-r--r--net/sunrpc/xprtsock.c251
1 files changed, 166 insertions, 85 deletions
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index 9f010369100a..2e1fe6013361 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -47,7 +47,7 @@
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>
-#include <net/tls.h>
+#include <net/tls_prot.h>
#include <net/handshake.h>
#include <linux/bvec.h>
@@ -62,6 +62,7 @@
#include "sunrpc.h"
static void xs_close(struct rpc_xprt *xprt);
+static void xs_reset_srcport(struct sock_xprt *transport);
static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock);
static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
struct socket *sock);
@@ -159,7 +160,6 @@ static struct ctl_table xs_tunables_table[] = {
.mode = 0644,
.proc_handler = proc_dointvec_jiffies,
},
- { },
};
/*
@@ -358,44 +358,61 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)
static int
xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
- struct cmsghdr *cmsg, int ret)
-{
- if (cmsg->cmsg_level == SOL_TLS &&
- cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
- u8 content_type = *((u8 *)CMSG_DATA(cmsg));
-
- switch (content_type) {
- case TLS_RECORD_TYPE_DATA:
- /* TLS sets EOR at the end of each application data
- * record, even though there might be more frames
- * waiting to be decrypted.
- */
- msg->msg_flags &= ~MSG_EOR;
- break;
- case TLS_RECORD_TYPE_ALERT:
- ret = -ENOTCONN;
- break;
- default:
- ret = -EAGAIN;
- }
+ unsigned int *msg_flags, struct cmsghdr *cmsg, int ret)
+{
+ u8 content_type = tls_get_record_type(sock->sk, cmsg);
+ u8 level, description;
+
+ switch (content_type) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ *msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(sock->sk, msg, &level, &description);
+ ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+ -EACCES : -EAGAIN;
+ break;
+ default:
+ /* discard this record type */
+ ret = -EAGAIN;
}
return ret;
}
static int
-xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags)
+xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags)
{
union {
struct cmsghdr cmsg;
u8 buf[CMSG_SPACE(sizeof(u8))];
} u;
+ u8 alert[2];
+ struct kvec alert_kvec = {
+ .iov_base = alert,
+ .iov_len = sizeof(alert),
+ };
+ struct msghdr msg = {
+ .msg_flags = *msg_flags,
+ .msg_control = &u,
+ .msg_controllen = sizeof(u),
+ };
int ret;
- msg->msg_control = &u;
- msg->msg_controllen = sizeof(u);
- ret = sock_recvmsg(sock, msg, flags);
- if (msg->msg_controllen != sizeof(u))
- ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret);
+ iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1,
+ alert_kvec.iov_len);
+ ret = sock_recvmsg(sock, &msg, flags);
+ if (ret > 0) {
+ if (tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT)
+ iov_iter_revert(&msg.msg_iter, ret);
+ ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg,
+ -EAGAIN);
+ }
return ret;
}
@@ -405,7 +422,13 @@ xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
ssize_t ret;
if (seek != 0)
iov_iter_advance(&msg->msg_iter, seek);
- ret = xs_sock_recv_cmsg(sock, msg, flags);
+ ret = sock_recvmsg(sock, msg, flags);
+ /* Handle TLS inband control message lazily */
+ if (msg->msg_flags & MSG_CTRUNC) {
+ msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR);
+ if (ret == 0 || ret == -EIO)
+ ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags);
+ }
return ret > 0 ? ret + seek : ret;
}
@@ -431,7 +454,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
size_t count)
{
iov_iter_discard(&msg->msg_iter, ITER_DEST, count);
- return xs_sock_recv_cmsg(sock, msg, flags);
+ return xs_sock_recvmsg(sock, msg, flags, 0);
}
#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
@@ -777,6 +800,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport)
}
if (ret == -ESHUTDOWN)
kernel_sock_shutdown(transport->sock, SHUT_RDWR);
+ else if (ret == -EACCES)
+ xprt_wake_pending_tasks(&transport->xprt, -EACCES);
else
xs_poll_check_readable(transport);
out:
@@ -878,6 +903,17 @@ static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf)
return xdr_alloc_bvec(buf, rpc_task_gfp_mask());
}
+static void xs_stream_abort_send_request(struct rpc_rqst *req)
+{
+ struct rpc_xprt *xprt = req->rq_xprt;
+ struct sock_xprt *transport =
+ container_of(xprt, struct sock_xprt, xprt);
+
+ if (transport->xmit.offset != 0 &&
+ !test_bit(XPRT_CLOSE_WAIT, &xprt->state))
+ xprt_force_disconnect(xprt);
+}
+
/*
* Determine if the previous message in the stream was aborted before it
* could complete transmission.
@@ -1176,11 +1212,13 @@ static void xs_sock_reset_state_flags(struct rpc_xprt *xprt)
{
struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ transport->xprt_err = 0;
clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state);
clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state);
clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state);
clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state);
+ clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
}
static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr)
@@ -1261,6 +1299,7 @@ static void xs_reset_transport(struct sock_xprt *transport)
transport->file = NULL;
sk->sk_user_data = NULL;
+ sk->sk_sndtimeo = 0;
xs_restore_old_callbacks(transport, sk);
xprt_clear_connected(xprt);
@@ -1292,6 +1331,8 @@ static void xs_close(struct rpc_xprt *xprt)
dprintk("RPC: xs_close xprt %p\n", xprt);
+ if (transport->sock)
+ tls_handshake_close(transport->sock);
xs_reset_transport(transport);
xprt->reestablish_timeout = 0;
}
@@ -1557,8 +1598,10 @@ static void xs_tcp_state_change(struct sock *sk)
break;
case TCP_CLOSE:
if (test_and_clear_bit(XPRT_SOCK_CONNECTING,
- &transport->sock_state))
+ &transport->sock_state)) {
+ xs_reset_srcport(transport);
xprt_clear_connecting(xprt);
+ }
clear_bit(XPRT_CLOSING, &xprt->state);
/* Trigger the socket release */
xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT);
@@ -1714,6 +1757,11 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port)
xs_update_peer_port(xprt);
}
+static void xs_reset_srcport(struct sock_xprt *transport)
+{
+ transport->srcport = 0;
+}
+
static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock)
{
if (transport->srcport == 0 && transport->xprt.reuseport)
@@ -1797,8 +1845,8 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock)
memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen);
do {
rpc_set_port((struct sockaddr *)&myaddr, port);
- err = kernel_bind(sock, (struct sockaddr *)&myaddr,
- transport->xprt.addrlen);
+ err = kernel_bind(sock, (struct sockaddr_unsized *)&myaddr,
+ transport->xprt.addrlen);
if (err == 0) {
if (transport->xprt.reuseport)
transport->srcport = port;
@@ -1913,6 +1961,9 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
goto out;
}
+ if (protocol == IPPROTO_TCP)
+ sk_net_refcnt_upgrade(sock->sk);
+
filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
if (IS_ERR(filp))
return ERR_CAST(filp);
@@ -1954,7 +2005,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt,
xs_stream_start_connect(transport);
- return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0);
+ return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), xprt->addrlen, 0);
}
/**
@@ -2230,9 +2281,13 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
struct socket *sock)
{
struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct net *net = sock_net(sock->sk);
+ unsigned long connect_timeout;
+ unsigned long syn_retries;
unsigned int keepidle;
unsigned int keepcnt;
unsigned int timeo;
+ unsigned long t;
spin_lock(&xprt->transport_lock);
keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ);
@@ -2250,6 +2305,35 @@ static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt,
/* TCP user timeout (see RFC5482) */
tcp_sock_set_user_timeout(sock->sk, timeo);
+
+ /* Connect timeout */
+ connect_timeout = max_t(unsigned long,
+ DIV_ROUND_UP(xprt->connect_timeout, HZ), 1);
+ syn_retries = max_t(unsigned long,
+ READ_ONCE(net->ipv4.sysctl_tcp_syn_retries), 1);
+ for (t = 0; t <= syn_retries && (1UL << t) < connect_timeout; t++)
+ ;
+ if (t <= syn_retries)
+ tcp_sock_set_syncnt(sock->sk, t - 1);
+}
+
+static void xs_tcp_do_set_connect_timeout(struct rpc_xprt *xprt,
+ unsigned long connect_timeout)
+{
+ struct sock_xprt *transport =
+ container_of(xprt, struct sock_xprt, xprt);
+ struct rpc_timeout to;
+ unsigned long initval;
+
+ memcpy(&to, xprt->timeout, sizeof(to));
+ /* Arbitrary lower limit */
+ initval = max_t(unsigned long, connect_timeout, XS_TCP_INIT_REEST_TO);
+ to.to_initval = initval;
+ to.to_maxval = initval;
+ to.to_retries = 0;
+ memcpy(&transport->tcp_timeout, &to, sizeof(transport->tcp_timeout));
+ xprt->timeout = &transport->tcp_timeout;
+ xprt->connect_timeout = connect_timeout;
}
static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt,
@@ -2257,25 +2341,12 @@ static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt,
unsigned long reconnect_timeout)
{
struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
- struct rpc_timeout to;
- unsigned long initval;
spin_lock(&xprt->transport_lock);
if (reconnect_timeout < xprt->max_reconnect_timeout)
xprt->max_reconnect_timeout = reconnect_timeout;
- if (connect_timeout < xprt->connect_timeout) {
- memcpy(&to, xprt->timeout, sizeof(to));
- initval = DIV_ROUND_UP(connect_timeout, to.to_retries + 1);
- /* Arbitrary lower limit */
- if (initval < XS_TCP_INIT_REEST_TO << 1)
- initval = XS_TCP_INIT_REEST_TO << 1;
- to.to_initval = initval;
- to.to_maxval = initval;
- memcpy(&transport->tcp_timeout, &to,
- sizeof(transport->tcp_timeout));
- xprt->timeout = &transport->tcp_timeout;
- xprt->connect_timeout = connect_timeout;
- }
+ if (connect_timeout < xprt->connect_timeout)
+ xs_tcp_do_set_connect_timeout(xprt, connect_timeout);
set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state);
spin_unlock(&xprt->transport_lock);
}
@@ -2334,7 +2405,8 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
/* Tell the socket layer to start connecting... */
set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state);
- return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
+ return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt),
+ xprt->addrlen, O_NONBLOCK);
}
/**
@@ -2395,6 +2467,13 @@ static void xs_tcp_setup_socket(struct work_struct *work)
transport->srcport = 0;
status = -EAGAIN;
break;
+ case -EPERM:
+ /* Happens, for instance, if a BPF program is preventing
+ * the connect. Remap the error so upper layers can better
+ * deal with it.
+ */
+ status = -ECONNREFUSED;
+ fallthrough;
case -EINVAL:
/* Happens, for instance, if the user specified a link
* local IPv6 address without a scope-id.
@@ -2406,6 +2485,7 @@ static void xs_tcp_setup_socket(struct work_struct *work)
case -EHOSTUNREACH:
case -EADDRINUSE:
case -ENOBUFS:
+ case -ENOTCONN:
break;
default:
printk("%s: connect returned unhandled error %d\n",
@@ -2518,7 +2598,15 @@ static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
struct sock_xprt *lower_transport =
container_of(lower_xprt, struct sock_xprt, xprt);
- lower_transport->xprt_err = status ? -EACCES : 0;
+ switch (status) {
+ case 0:
+ case -EACCES:
+ case -ETIMEDOUT:
+ lower_transport->xprt_err = status;
+ break;
+ default:
+ lower_transport->xprt_err = -EACCES;
+ }
complete(&lower_transport->handshake_done);
xprt_put(lower_xprt);
}
@@ -2560,11 +2648,10 @@ static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_par
rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
XS_TLS_HANDSHAKE_TO);
if (rc <= 0) {
- if (!tls_handshake_cancel(sk)) {
- if (rc == 0)
- rc = -ETIMEDOUT;
- goto out_put_xprt;
- }
+ tls_handshake_cancel(sk);
+ if (rc == 0)
+ rc = -ETIMEDOUT;
+ goto out_put_xprt;
}
rc = lower_transport->xprt_err;
@@ -2617,6 +2704,7 @@ static void xs_tcp_tls_setup_socket(struct work_struct *work)
.xprtsec = {
.policy = RPC_XPRTSEC_NONE,
},
+ .stats = upper_clnt->cl_stats,
};
unsigned int pflags = current->flags;
struct rpc_clnt *lower_clnt;
@@ -2645,6 +2733,10 @@ static void xs_tcp_tls_setup_socket(struct work_struct *work)
rcu_read_lock();
lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
rcu_read_unlock();
+
+ if (wait_on_bit_lock(&lower_xprt->state, XPRT_LOCKED, TASK_KILLABLE))
+ goto out_unlock;
+
status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
if (status) {
trace_rpc_tls_not_started(upper_clnt, upper_xprt);
@@ -2654,20 +2746,15 @@ static void xs_tcp_tls_setup_socket(struct work_struct *work)
status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
if (status)
goto out_close;
-
+ xprt_release_write(lower_xprt, NULL);
trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
- if (!xprt_test_and_set_connected(upper_xprt)) {
- upper_xprt->connect_cookie++;
- clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
- xprt_clear_connecting(upper_xprt);
-
- upper_xprt->stat.connect_count++;
- upper_xprt->stat.connect_time += (long)jiffies -
- upper_xprt->stat.connect_start;
- xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
- }
rpc_shutdown_client(lower_clnt);
+ /* Check for ingress data that arrived before the socket's
+ * ->data_ready callback was set up.
+ */
+ xs_poll_check_readable(upper_transport);
+
out_unlock:
current_restore_flags(pflags, PF_MEMALLOC);
upper_transport->clnt = NULL;
@@ -2675,6 +2762,7 @@ out_unlock:
return;
out_close:
+ xprt_release_write(lower_xprt, NULL);
rpc_shutdown_client(lower_clnt);
/* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
@@ -2739,18 +2827,13 @@ static void xs_wake_error(struct sock_xprt *transport)
{
int sockerr;
- if (!test_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
- return;
- mutex_lock(&transport->recv_mutex);
- if (transport->sock == NULL)
- goto out;
if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state))
- goto out;
+ return;
sockerr = xchg(&transport->xprt_err, 0);
- if (sockerr < 0)
+ if (sockerr < 0) {
xprt_wake_pending_tasks(&transport->xprt, sockerr);
-out:
- mutex_unlock(&transport->recv_mutex);
+ xs_tcp_force_close(&transport->xprt);
+ }
}
static void xs_wake_pending(struct sock_xprt *transport)
@@ -2958,20 +3041,11 @@ static int bc_send_request(struct rpc_rqst *req)
return len;
}
-/*
- * The close routine. Since this is client initiated, we do nothing
- */
-
static void bc_close(struct rpc_xprt *xprt)
{
xprt_disconnect_done(xprt);
}
-/*
- * The xprt destroy routine. Again, because this connection is client
- * initiated, we do nothing
- */
-
static void bc_destroy(struct rpc_xprt *xprt)
{
dprintk("RPC: bc_destroy xprt %p\n", xprt);
@@ -2992,6 +3066,7 @@ static const struct rpc_xprt_ops xs_local_ops = {
.buf_free = rpc_free,
.prepare_request = xs_stream_prepare_request,
.send_request = xs_local_send_request,
+ .abort_send_request = xs_stream_abort_send_request,
.wait_for_reply_request = xprt_wait_for_reply_request_def,
.close = xs_close,
.destroy = xs_destroy,
@@ -3039,6 +3114,7 @@ static const struct rpc_xprt_ops xs_tcp_ops = {
.buf_free = rpc_free,
.prepare_request = xs_stream_prepare_request,
.send_request = xs_tcp_send_request,
+ .abort_send_request = xs_stream_abort_send_request,
.wait_for_reply_request = xprt_wait_for_reply_request_def,
.close = xs_tcp_shutdown,
.destroy = xs_destroy,
@@ -3328,8 +3404,13 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
xprt->timeout = &xs_tcp_default_timeout;
xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+ if (args->reconnect_timeout)
+ xprt->max_reconnect_timeout = args->reconnect_timeout;
+
xprt->connect_timeout = xprt->timeout->to_initval *
(xprt->timeout->to_retries + 1);
+ if (args->connect_timeout)
+ xs_tcp_do_set_connect_timeout(xprt, args->connect_timeout);
INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
INIT_WORK(&transport->error_worker, xs_error_handle);