summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls.h6
-rw-r--r--net/tls/tls_device.c29
-rw-r--r--net/tls/tls_main.c71
-rw-r--r--net/tls/tls_proc.c10
-rw-r--r--net/tls/tls_strp.c25
-rw-r--r--net/tls/tls_sw.c46
6 files changed, 145 insertions, 42 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h
index 774859b63f0d..2f86baeb71fc 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -128,8 +128,9 @@ struct tls_rec {
char aad_space[TLS_AAD_SPACE_SIZE];
u8 iv_data[TLS_MAX_IV_SIZE];
+
+ /* Must be last --ends in a flexible-array member. */
struct aead_request aead_req;
- u8 aead_req_ctx[];
};
int __net_init tls_proc_init(struct net *net);
@@ -141,6 +142,7 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
void tls_err_abort(struct sock *sk, int err);
+void tls_strp_abort_strp(struct tls_strparser *strp, int err);
int init_prot_info(struct tls_prot_info *prot,
const struct tls_crypto_info *crypto_info,
@@ -196,7 +198,7 @@ void tls_strp_msg_done(struct tls_strparser *strp);
int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb);
void tls_rx_msg_ready(struct tls_strparser *strp);
-void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh);
+bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh);
int tls_strp_msg_cow(struct tls_sw_context_rx *ctx);
struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx);
int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst);
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index f672a62a9a52..82ea407e520a 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -123,17 +123,19 @@ static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
- struct dst_entry *dst = sk_dst_get(sk);
- struct net_device *netdev = NULL;
+ struct net_device *dev, *lowest_dev = NULL;
+ struct dst_entry *dst;
- if (likely(dst)) {
- netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
- dev_hold(netdev);
+ rcu_read_lock();
+ dst = __sk_dst_get(sk);
+ dev = dst ? dst_dev_rcu(dst) : NULL;
+ if (likely(dev)) {
+ lowest_dev = netdev_sk_get_lowest_dev(dev, sk);
+ dev_hold(lowest_dev);
}
+ rcu_read_unlock();
- dst_release(dst);
-
- return netdev;
+ return lowest_dev;
}
static void destroy_record(struct tls_record_info *record)
@@ -371,7 +373,8 @@ static int tls_do_allocation(struct sock *sk,
if (!offload_ctx->open_record) {
if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
sk->sk_allocation))) {
- READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
+ if (!sk->sk_bypass_prot_mem)
+ READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
sk_stream_moderate_sndbuf(sk);
return -ENOMEM;
}
@@ -459,7 +462,7 @@ static int tls_push_data(struct sock *sk,
/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
* we need to leave room for an authentication tag.
*/
- max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
+ max_open_record_len = tls_ctx->tx_max_payload_len +
prot->prepend_size;
do {
rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
@@ -721,8 +724,10 @@ tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
/* shouldn't get to wraparound:
* too long in async stage, something bad happened
*/
- if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
+ if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX)) {
+ tls_offload_rx_resync_async_request_cancel(resync_async);
return false;
+ }
/* asynchronous stage: log all headers seq such that
* req_seq <= seq <= end_seq, and wait for real resync request
@@ -1410,7 +1415,7 @@ int __init tls_device_init(void)
if (!dummy_page)
return -ENOMEM;
- destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0);
+ destruct_wq = alloc_workqueue("ktls_device_destruct", WQ_PERCPU, 0);
if (!destruct_wq) {
err = -ENOMEM;
goto err_free_dummy;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index a3ccb3135e51..56ce0bc8317b 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -255,12 +255,9 @@ int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
if (msg->msg_flags & MSG_MORE)
return -EINVAL;
- rc = tls_handle_open_record(sk, msg->msg_flags);
- if (rc)
- return rc;
-
*record_type = *(unsigned char *)CMSG_DATA(cmsg);
- rc = 0;
+
+ rc = tls_handle_open_record(sk, msg->msg_flags);
break;
default:
return -EINVAL;
@@ -544,6 +541,28 @@ static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
return 0;
}
+static int do_tls_getsockopt_tx_payload_len(struct sock *sk, char __user *optval,
+ int __user *optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ u16 payload_len = ctx->tx_max_payload_len;
+ int len;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+
+ if (len < sizeof(payload_len))
+ return -EINVAL;
+
+ if (put_user(sizeof(payload_len), optlen))
+ return -EFAULT;
+
+ if (copy_to_user(optval, &payload_len, sizeof(payload_len)))
+ return -EFAULT;
+
+ return 0;
+}
+
static int do_tls_getsockopt(struct sock *sk, int optname,
char __user *optval, int __user *optlen)
{
@@ -563,6 +582,9 @@ static int do_tls_getsockopt(struct sock *sk, int optname,
case TLS_RX_EXPECT_NO_PAD:
rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
break;
+ case TLS_TX_MAX_PAYLOAD_LEN:
+ rc = do_tls_getsockopt_tx_payload_len(sk, optval, optlen);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -812,6 +834,32 @@ static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
return rc;
}
+static int do_tls_setsockopt_tx_payload_len(struct sock *sk, sockptr_t optval,
+ unsigned int optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ struct tls_sw_context_tx *sw_ctx = tls_sw_ctx_tx(ctx);
+ u16 value;
+ bool tls_13 = ctx->prot_info.version == TLS_1_3_VERSION;
+
+ if (sw_ctx && sw_ctx->open_rec)
+ return -EBUSY;
+
+ if (sockptr_is_null(optval) || optlen != sizeof(value))
+ return -EINVAL;
+
+ if (copy_from_sockptr(&value, optval, sizeof(value)))
+ return -EFAULT;
+
+ if (value < TLS_MIN_RECORD_SIZE_LIM - (tls_13 ? 1 : 0) ||
+ value > TLS_MAX_PAYLOAD_SIZE)
+ return -EINVAL;
+
+ ctx->tx_max_payload_len = value;
+
+ return 0;
+}
+
static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
unsigned int optlen)
{
@@ -833,6 +881,11 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
case TLS_RX_EXPECT_NO_PAD:
rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
break;
+ case TLS_TX_MAX_PAYLOAD_LEN:
+ lock_sock(sk);
+ rc = do_tls_setsockopt_tx_payload_len(sk, optval, optlen);
+ release_sock(sk);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -1022,6 +1075,7 @@ static int tls_init(struct sock *sk)
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
+ ctx->tx_max_payload_len = TLS_MAX_PAYLOAD_SIZE;
update_sk_prot(sk, ctx);
out:
write_unlock_bh(&sk->sk_callback_lock);
@@ -1111,6 +1165,12 @@ static int tls_get_info(struct sock *sk, struct sk_buff *skb, bool net_admin)
goto nla_failure;
}
+ err = nla_put_u16(skb, TLS_INFO_TX_MAX_PAYLOAD_LEN,
+ ctx->tx_max_payload_len);
+
+ if (err)
+ goto nla_failure;
+
rcu_read_unlock();
nla_nest_end(skb, start);
return 0;
@@ -1132,6 +1192,7 @@ static size_t tls_get_info_size(const struct sock *sk, bool net_admin)
nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */
nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */
+ nla_total_size(sizeof(u16)) + /* TLS_INFO_TX_MAX_PAYLOAD_LEN */
0;
return size;
diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c
index 367666aa07b8..4012c4372d4c 100644
--- a/net/tls/tls_proc.c
+++ b/net/tls/tls_proc.c
@@ -27,17 +27,19 @@ static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsTxRekeyOk", LINUX_MIB_TLSTXREKEYOK),
SNMP_MIB_ITEM("TlsTxRekeyError", LINUX_MIB_TLSTXREKEYERROR),
SNMP_MIB_ITEM("TlsRxRekeyReceived", LINUX_MIB_TLSRXREKEYRECEIVED),
- SNMP_MIB_SENTINEL
};
static int tls_statistics_seq_show(struct seq_file *seq, void *v)
{
- unsigned long buf[LINUX_MIB_TLSMAX] = {};
+ unsigned long buf[ARRAY_SIZE(tls_mib_list)];
+ const int cnt = ARRAY_SIZE(tls_mib_list);
struct net *net = seq->private;
int i;
- snmp_get_cpu_field_batch(buf, tls_mib_list, net->mib.tls_statistics);
- for (i = 0; tls_mib_list[i].name; i++)
+ memset(buf, 0, sizeof(buf));
+ snmp_get_cpu_field_batch_cnt(buf, tls_mib_list, cnt,
+ net->mib.tls_statistics);
+ for (i = 0; i < cnt; i++)
seq_printf(seq, "%-32s\t%lu\n", tls_mib_list[i].name, buf[i]);
return 0;
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index 095cf31bae0b..98e12f0ff57e 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -13,7 +13,7 @@
static struct workqueue_struct *tls_strp_wq;
-static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
+void tls_strp_abort_strp(struct tls_strparser *strp, int err)
{
if (strp->stopped)
return;
@@ -211,11 +211,17 @@ static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
struct sk_buff *in_skb, unsigned int offset,
size_t in_len)
{
+ unsigned int nfrag = skb->len / PAGE_SIZE;
size_t len, chunk;
skb_frag_t *frag;
int sz;
- frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
+ if (unlikely(nfrag >= skb_shinfo(skb)->nr_frags)) {
+ DEBUG_NET_WARN_ON_ONCE(1);
+ return -EMSGSIZE;
+ }
+
+ frag = &skb_shinfo(skb)->frags[nfrag];
len = in_len;
/* First make sure we got the header */
@@ -475,7 +481,7 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
strp->stm.offset = offset;
}
-void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
+bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
{
struct strp_msg *rxm;
struct tls_msg *tlm;
@@ -484,8 +490,11 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
if (!strp->copy_mode && force_refresh) {
- if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
- return;
+ if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) {
+ WRITE_ONCE(strp->msg_ready, 0);
+ memset(&strp->stm, 0, sizeof(strp->stm));
+ return false;
+ }
tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
}
@@ -495,6 +504,8 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
rxm->offset = strp->stm.offset;
tlm = tls_msg(strp->anchor);
tlm->control = strp->mark;
+
+ return true;
}
/* Called with lock held on lower socket */
@@ -515,10 +526,8 @@ static int tls_strp_read_sock(struct tls_strparser *strp)
tls_strp_load_anchor_with_queue(strp, inq);
if (!strp->stm.full_len) {
sz = tls_rx_msg_size(strp, strp->anchor);
- if (sz < 0) {
- tls_strp_abort_strp(strp, sz);
+ if (sz < 0)
return sz;
- }
strp->stm.full_len = sz;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 549d1ea01a72..9937d4c810f2 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1054,7 +1054,7 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
if (ret == -EINPROGRESS)
num_async++;
else if (ret != -EAGAIN)
- goto send_end;
+ goto end;
}
}
@@ -1079,7 +1079,7 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
orig_size = msg_pl->sg.size;
full_record = false;
try_to_copy = msg_data_left(msg);
- record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+ record_room = tls_ctx->tx_max_payload_len - msg_pl->sg.size;
if (try_to_copy >= record_room) {
try_to_copy = record_room;
full_record = true;
@@ -1112,8 +1112,11 @@ alloc_encrypted:
goto send_end;
tls_ctx->pending_open_record_frags = true;
- if (sk_msg_full(msg_pl))
+ if (sk_msg_full(msg_pl)) {
full_record = true;
+ sk_msg_trim(sk, msg_en,
+ msg_pl->sg.size + prot->overhead_size);
+ }
if (full_record || eor)
goto copied;
@@ -1149,6 +1152,13 @@ alloc_encrypted:
} else if (ret != -EAGAIN)
goto send_end;
}
+
+ /* Transmit if any encryptions have completed */
+ if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+ cancel_delayed_work(&ctx->tx_work.work);
+ tls_tx_records(sk, msg->msg_flags);
+ }
+
continue;
rollback_iter:
copied -= try_to_copy;
@@ -1204,6 +1214,12 @@ copied:
goto send_end;
}
}
+
+ /* Transmit if any encryptions have completed */
+ if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+ cancel_delayed_work(&ctx->tx_work.work);
+ tls_tx_records(sk, msg->msg_flags);
+ }
}
continue;
@@ -1223,8 +1239,9 @@ trim_sgl:
goto alloc_encrypted;
}
+send_end:
if (!num_async) {
- goto send_end;
+ goto end;
} else if (num_zc || eor) {
int err;
@@ -1242,7 +1259,7 @@ trim_sgl:
tls_tx_records(sk, msg->msg_flags);
}
-send_end:
+end:
ret = sk_stream_error(sk, msg->msg_flags, ret);
return copied > 0 ? copied : ret;
}
@@ -1384,7 +1401,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
return sock_intr_errno(timeo);
}
- tls_strp_msg_load(&ctx->strp, released);
+ if (unlikely(!tls_strp_msg_load(&ctx->strp, released)))
+ return tls_rx_rec_wait(sk, psock, nonblock, false);
return 1;
}
@@ -1636,8 +1654,10 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
if (unlikely(darg->async)) {
err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
- if (err)
- __skb_queue_tail(&ctx->async_hold, darg->skb);
+ if (err) {
+ err = tls_decrypt_async_wait(ctx);
+ darg->async = false;
+ }
return err;
}
@@ -1807,6 +1827,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
return tls_decrypt_sg(sk, NULL, sgout, &darg);
}
+/* All records returned from a recvmsg() call must have the same type.
+ * 0 is not a valid content type. Use it as "no type reported, yet".
+ */
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
u8 *control)
{
@@ -2050,8 +2073,10 @@ int tls_sw_recvmsg(struct sock *sk,
if (err < 0)
goto end;
+ /* process_rx_list() will set @control if it processed any records */
copied = err;
- if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more)
+ if (len <= copied || rx_more ||
+ (control && control != TLS_RECORD_TYPE_DATA))
goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
@@ -2468,8 +2493,7 @@ int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
return data_len + TLS_HEADER_SIZE;
read_failure:
- tls_err_abort(strp->sk, ret);
-
+ tls_strp_abort_strp(strp, ret);
return ret;
}