summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_main.c47
-rw-r--r--net/tls/tls_sw.c44
2 files changed, 69 insertions, 22 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index acfba9f1ba72..6bc2879ba637 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -61,7 +61,7 @@ static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
-static struct proto_ops tls_sw_proto_ops;
+static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
const struct proto *base);
@@ -71,6 +71,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)
WRITE_ONCE(sk->sk_prot,
&tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
+ WRITE_ONCE(sk->sk_socket->ops,
+ &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}
int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -669,8 +671,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
if (tx) {
ctx->sk_write_space = sk->sk_write_space;
sk->sk_write_space = tls_write_space;
- } else {
- sk->sk_socket->ops = &tls_sw_proto_ops;
}
goto out;
@@ -728,6 +728,39 @@ struct tls_context *tls_ctx_create(struct sock *sk)
return ctx;
}
+static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+ const struct proto_ops *base)
+{
+ ops[TLS_BASE][TLS_BASE] = *base;
+
+ ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+ ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
+
+ ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE];
+ ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read;
+
+ ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE];
+ ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read;
+
+#ifdef CONFIG_TLS_DEVICE
+ ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+ ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL;
+
+ ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ];
+ ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL;
+
+ ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ];
+
+ ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ];
+
+ ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ];
+ ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL;
+#endif
+#ifdef CONFIG_TLS_TOE
+ ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
+#endif
+}
+
static void tls_build_proto(struct sock *sk)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
@@ -739,6 +772,8 @@ static void tls_build_proto(struct sock *sk)
mutex_lock(&tcpv6_prot_mutex);
if (likely(prot != saved_tcpv6_prot)) {
build_protos(tls_prots[TLSV6], prot);
+ build_proto_ops(tls_proto_ops[TLSV6],
+ sk->sk_socket->ops);
smp_store_release(&saved_tcpv6_prot, prot);
}
mutex_unlock(&tcpv6_prot_mutex);
@@ -749,6 +784,8 @@ static void tls_build_proto(struct sock *sk)
mutex_lock(&tcpv4_prot_mutex);
if (likely(prot != saved_tcpv4_prot)) {
build_protos(tls_prots[TLSV4], prot);
+ build_proto_ops(tls_proto_ops[TLSV4],
+ sk->sk_socket->ops);
smp_store_release(&saved_tcpv4_prot, prot);
}
mutex_unlock(&tcpv4_prot_mutex);
@@ -959,10 +996,6 @@ static int __init tls_register(void)
if (err)
return err;
- tls_sw_proto_ops = inet_stream_ops;
- tls_sw_proto_ops.splice_read = tls_sw_splice_read;
- tls_sw_proto_ops.sendpage_locked = tls_sw_sendpage_locked;
-
tls_device_init();
tcp_register_ulp(&tcp_tls_ulp_ops);
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index d81564078557..dfe623a4e72f 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -521,7 +521,7 @@ static int tls_do_encryption(struct sock *sk,
memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
prot->iv_size + prot->salt_size);
- xor_iv_with_seq(prot, rec->iv_data, tls_ctx->tx.rec_seq);
+ xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
sge->offset += prot->prepend_size;
sge->length -= prot->prepend_size;
@@ -1499,7 +1499,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
else
memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
- xor_iv_with_seq(prot, iv, tls_ctx->rx.rec_seq);
+ xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
/* Prepare AAD */
tls_make_aad(aad, rxm->full_len - prot->overhead_size +
@@ -2005,6 +2005,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct sock *sk = sock->sk;
struct sk_buff *skb;
ssize_t copied = 0;
+ bool from_queue;
int err = 0;
long timeo;
int chunk;
@@ -2014,25 +2015,28 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
- skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
- if (!skb)
- goto splice_read_end;
-
- if (!ctx->decrypted) {
- err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
-
- /* splice does not support reading control messages */
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
- err = -EINVAL;
+ from_queue = !skb_queue_empty(&ctx->rx_list);
+ if (from_queue) {
+ skb = __skb_dequeue(&ctx->rx_list);
+ } else {
+ skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
+ &err);
+ if (!skb)
goto splice_read_end;
- }
+ err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto splice_read_end;
}
- ctx->decrypted = 1;
}
+
+ /* splice does not support reading control messages */
+ if (ctx->control != TLS_RECORD_TYPE_DATA) {
+ err = -EINVAL;
+ goto splice_read_end;
+ }
+
rxm = strp_msg(skb);
chunk = min_t(unsigned int, rxm->full_len, len);
@@ -2040,7 +2044,17 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
if (copied < 0)
goto splice_read_end;
- tls_sw_advance_skb(sk, skb, copied);
+ if (!from_queue) {
+ ctx->recv_pkt = NULL;
+ __strp_unpause(&ctx->strp);
+ }
+ if (chunk < rxm->full_len) {
+ __skb_queue_head(&ctx->rx_list, skb);
+ rxm->offset += len;
+ rxm->full_len -= len;
+ } else {
+ consume_skb(skb);
+ }
splice_read_end:
release_sock(sk);