summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_device.c55
-rw-r--r--net/tls/tls_device_fallback.c16
-rw-r--r--net/tls/tls_main.c29
-rw-r--r--net/tls/tls_sw.c18
4 files changed, 82 insertions, 36 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 135a7ee9db03..14dedb24fa7b 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -52,8 +52,11 @@ static DEFINE_SPINLOCK(tls_device_lock);
static void tls_device_free_ctx(struct tls_context *ctx)
{
- if (ctx->tx_conf == TLS_HW)
+ if (ctx->tx_conf == TLS_HW) {
kfree(tls_offload_ctx_tx(ctx));
+ kfree(ctx->tx.rec_seq);
+ kfree(ctx->tx.iv);
+ }
if (ctx->rx_conf == TLS_HW)
kfree(tls_offload_ctx_rx(ctx));
@@ -216,6 +219,13 @@ void tls_device_sk_destruct(struct sock *sk)
}
EXPORT_SYMBOL(tls_device_sk_destruct);
+void tls_device_free_resources_tx(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+
+ tls_free_partial_record(sk, tls_ctx);
+}
+
static void tls_append_frag(struct tls_record_info *record,
struct page_frag *pfrag,
int size)
@@ -587,7 +597,7 @@ void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
{
struct strp_msg *rxm = strp_msg(skb);
- int err = 0, offset = rxm->offset, copy, nsg;
+ int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
struct sk_buff *skb_iter, *unused;
struct scatterlist sg[1];
char *orig_buf, *buf;
@@ -618,25 +628,42 @@ static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
else
err = 0;
- copy = min_t(int, skb_pagelen(skb) - offset,
- rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
- if (skb->decrypted)
- skb_store_bits(skb, offset, buf, copy);
+ if (skb_pagelen(skb) > offset) {
+ copy = min_t(int, skb_pagelen(skb) - offset, data_len);
- offset += copy;
- buf += copy;
+ if (skb->decrypted)
+ skb_store_bits(skb, offset, buf, copy);
+
+ offset += copy;
+ buf += copy;
+ }
+ pos = skb_pagelen(skb);
skb_walk_frags(skb, skb_iter) {
- copy = min_t(int, skb_iter->len,
- rxm->full_len - offset + rxm->offset -
- TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ int frag_pos;
+
+ /* Practically all frags must belong to msg if reencrypt
+ * is needed with current strparser and coalescing logic,
+ * but strparser may "get optimized", so let's be safe.
+ */
+ if (pos + skb_iter->len <= offset)
+ goto done_with_frag;
+ if (pos >= data_len + rxm->offset)
+ break;
+
+ frag_pos = offset - pos;
+ copy = min_t(int, skb_iter->len - frag_pos,
+ data_len + rxm->offset - offset);
if (skb_iter->decrypted)
- skb_store_bits(skb_iter, offset, buf, copy);
+ skb_store_bits(skb_iter, frag_pos, buf, copy);
offset += copy;
buf += copy;
+done_with_frag:
+ pos += skb_iter->len;
}
free_buf:
@@ -894,7 +921,9 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
goto release_netdev;
free_sw_resources:
+ up_read(&device_offload_lock);
tls_sw_free_resources_rx(sk);
+ down_read(&device_offload_lock);
release_ctx:
ctx->priv_ctx_rx = NULL;
release_netdev:
@@ -929,8 +958,6 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
}
out:
up_read(&device_offload_lock);
- kfree(tls_ctx->rx.rec_seq);
- kfree(tls_ctx->rx.iv);
tls_sw_release_resources_rx(sk);
}
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 54c3a758f2a7..c3a5fe624b4e 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -194,18 +194,26 @@ static void update_chksum(struct sk_buff *skb, int headln)
static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
{
+ struct sock *sk = skb->sk;
+ int delta;
+
skb_copy_header(nskb, skb);
skb_put(nskb, skb->len);
memcpy(nskb->data, skb->data, headln);
- update_chksum(nskb, headln);
nskb->destructor = skb->destructor;
- nskb->sk = skb->sk;
+ nskb->sk = sk;
skb->destructor = NULL;
skb->sk = NULL;
- refcount_add(nskb->truesize - skb->truesize,
- &nskb->sk->sk_wmem_alloc);
+
+ update_chksum(nskb, headln);
+
+ delta = nskb->truesize - skb->truesize;
+ if (likely(delta < 0))
+ WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc));
+ else if (delta)
+ refcount_add(delta, &sk->sk_wmem_alloc);
}
/* This function may be called after the user socket is already
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index df921a2904b9..478603f43964 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -208,6 +208,26 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
return tls_push_sg(sk, ctx, sg, offset, flags);
}
+bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
+{
+ struct scatterlist *sg;
+
+ sg = ctx->partially_sent_record;
+ if (!sg)
+ return false;
+
+ while (1) {
+ put_page(sg_page(sg));
+ sk_mem_uncharge(sk, sg->length);
+
+ if (sg_is_last(sg))
+ break;
+ sg++;
+ }
+ ctx->partially_sent_record = NULL;
+ return true;
+}
+
static void tls_write_space(struct sock *sk)
{
struct tls_context *ctx = tls_get_ctx(sk);
@@ -267,13 +287,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv);
tls_sw_free_resources_tx(sk);
+#ifdef CONFIG_TLS_DEVICE
+ } else if (ctx->tx_conf == TLS_HW) {
+ tls_device_free_resources_tx(sk);
+#endif
}
- if (ctx->rx_conf == TLS_SW) {
- kfree(ctx->rx.rec_seq);
- kfree(ctx->rx.iv);
+ if (ctx->rx_conf == TLS_SW)
tls_sw_free_resources_rx(sk);
- }
#ifdef CONFIG_TLS_DEVICE
if (ctx->rx_conf == TLS_HW)
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 20b191227969..29d6af43dd24 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -2052,20 +2052,7 @@ void tls_sw_free_resources_tx(struct sock *sk)
/* Free up un-sent records in tx_list. First, free
* the partially sent record if any at head of tx_list.
*/
- if (tls_ctx->partially_sent_record) {
- struct scatterlist *sg = tls_ctx->partially_sent_record;
-
- while (1) {
- put_page(sg_page(sg));
- sk_mem_uncharge(sk, sg->length);
-
- if (sg_is_last(sg))
- break;
- sg++;
- }
-
- tls_ctx->partially_sent_record = NULL;
-
+ if (tls_free_partial_record(sk, tls_ctx)) {
rec = list_first_entry(&ctx->tx_list,
struct tls_rec, list);
list_del(&rec->list);
@@ -2091,6 +2078,9 @@ void tls_sw_release_resources_rx(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ kfree(tls_ctx->rx.rec_seq);
+ kfree(tls_ctx->rx.iv);
+
if (ctx->aead_recv) {
kfree_skb(ctx->recv_pkt);
ctx->recv_pkt = NULL;