Diffstat (limited to 'net/tls')
 net/tls/tls_device.c          | 105
 net/tls/tls_device_fallback.c |  16
 net/tls/tls_main.c            |  60
 net/tls/tls_sw.c              | 123
 4 files changed, 199 insertions(+), 105 deletions(-)
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 135a7ee9db03..ca54a7c7ec81 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -52,8 +52,11 @@ static DEFINE_SPINLOCK(tls_device_lock);
 
 static void tls_device_free_ctx(struct tls_context *ctx)
 {
-	if (ctx->tx_conf == TLS_HW)
+	if (ctx->tx_conf == TLS_HW) {
 		kfree(tls_offload_ctx_tx(ctx));
+		kfree(ctx->tx.rec_seq);
+		kfree(ctx->tx.iv);
+	}
 
 	if (ctx->rx_conf == TLS_HW)
 		kfree(tls_offload_ctx_rx(ctx));
@@ -86,22 +89,6 @@ static void tls_device_gc_task(struct work_struct *work)
 	}
 }
 
-static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
-			      struct net_device *netdev)
-{
-	if (sk->sk_destruct != tls_device_sk_destruct) {
-		refcount_set(&ctx->refcount, 1);
-		dev_hold(netdev);
-		ctx->netdev = netdev;
-		spin_lock_irq(&tls_device_lock);
-		list_add_tail(&ctx->list, &tls_device_list);
-		spin_unlock_irq(&tls_device_lock);
-
-		ctx->sk_destruct = sk->sk_destruct;
-		sk->sk_destruct = tls_device_sk_destruct;
-	}
-}
-
 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
 {
 	unsigned long flags;
@@ -196,7 +183,7 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
  * socket and no in-flight SKBs associated with this
  * socket, so it is safe to free all the resources.
  */
-void tls_device_sk_destruct(struct sock *sk)
+static void tls_device_sk_destruct(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
@@ -214,7 +201,13 @@ void tls_device_sk_destruct(struct sock *sk)
 	if (refcount_dec_and_test(&tls_ctx->refcount))
 		tls_device_queue_ctx_destruction(tls_ctx);
 }
-EXPORT_SYMBOL(tls_device_sk_destruct);
+
+void tls_device_free_resources_tx(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+
+	tls_free_partial_record(sk, tls_ctx);
+}
 
 static void tls_append_frag(struct tls_record_info *record,
 			    struct page_frag *pfrag,
@@ -548,14 +541,11 @@ static int tls_device_push_pending_record(struct sock *sk, int flags)
 
 void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
 {
-	int rc = 0;
-
 	if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
 		gfp_t sk_allocation = sk->sk_allocation;
 
 		sk->sk_allocation = GFP_ATOMIC;
-		rc = tls_push_partial_record(sk, ctx,
-					     MSG_DONTWAIT | MSG_NOSIGNAL);
+		tls_push_partial_record(sk, ctx, MSG_DONTWAIT | MSG_NOSIGNAL);
 		sk->sk_allocation = sk_allocation;
 	}
 }
@@ -574,7 +564,7 @@ void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
 
 	rx_ctx = tls_offload_ctx_rx(tls_ctx);
 	resync_req = atomic64_read(&rx_ctx->resync_req);
-	req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
+	req_seq = (resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
 	is_req_pending = resync_req;
 
 	if (unlikely(is_req_pending) && req_seq == seq &&
@@ -587,7 +577,7 @@ void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
 static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
 {
 	struct strp_msg *rxm = strp_msg(skb);
-	int err = 0, offset = rxm->offset, copy, nsg;
+	int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
 	struct sk_buff *skb_iter, *unused;
 	struct scatterlist sg[1];
 	char *orig_buf, *buf;
@@ -618,25 +608,42 @@ static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
 	else
 		err = 0;
 
-	copy = min_t(int, skb_pagelen(skb) - offset,
-		     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+	data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
 
-	if (skb->decrypted)
-		skb_store_bits(skb, offset, buf, copy);
+	if (skb_pagelen(skb) > offset) {
+		copy = min_t(int, skb_pagelen(skb) - offset, data_len);
 
-	offset += copy;
-	buf += copy;
+		if (skb->decrypted)
+			skb_store_bits(skb, offset, buf, copy);
 
+		offset += copy;
+		buf += copy;
+	}
+
+	pos = skb_pagelen(skb);
 	skb_walk_frags(skb, skb_iter) {
-		copy = min_t(int, skb_iter->len,
-			     rxm->full_len - offset + rxm->offset -
-			     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+		int frag_pos;
+
+		/* Practically all frags must belong to msg if reencrypt
+		 * is needed with current strparser and coalescing logic,
+		 * but strparser may "get optimized", so let's be safe.
+		 */
+		if (pos + skb_iter->len <= offset)
+			goto done_with_frag;
+		if (pos >= data_len + rxm->offset)
+			break;
+
+		frag_pos = offset - pos;
+		copy = min_t(int, skb_iter->len - frag_pos,
+			     data_len + rxm->offset - offset);
 
 		if (skb_iter->decrypted)
-			skb_store_bits(skb_iter, offset, buf, copy);
+			skb_store_bits(skb_iter, frag_pos, buf, copy);
 
 		offset += copy;
 		buf += copy;
+done_with_frag:
+		pos += skb_iter->len;
 	}
 
 free_buf:
@@ -672,6 +679,22 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
 		tls_device_reencrypt(sk, skb);
 }
 
+static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
+			      struct net_device *netdev)
+{
+	if (sk->sk_destruct != tls_device_sk_destruct) {
+		refcount_set(&ctx->refcount, 1);
+		dev_hold(netdev);
+		ctx->netdev = netdev;
+		spin_lock_irq(&tls_device_lock);
+		list_add_tail(&ctx->list, &tls_device_list);
+		spin_unlock_irq(&tls_device_lock);
+
+		ctx->sk_destruct = sk->sk_destruct;
+		sk->sk_destruct = tls_device_sk_destruct;
+	}
+}
+
 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 {
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
@@ -855,8 +878,6 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 	}
 
 	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
-		pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
-				   __func__, netdev->name);
 		rc = -ENOTSUPP;
 		goto release_netdev;
 	}
@@ -884,17 +905,16 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
 					     &ctx->crypto_recv.info,
 					     tcp_sk(sk)->copied_seq);
-	if (rc) {
-		pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
-				   __func__);
+	if (rc)
 		goto free_sw_resources;
-	}
 
 	tls_device_attach(ctx, sk, netdev);
 	goto release_netdev;
 
 free_sw_resources:
+	up_read(&device_offload_lock);
 	tls_sw_free_resources_rx(sk);
+	down_read(&device_offload_lock);
 release_ctx:
 	ctx->priv_ctx_rx = NULL;
 release_netdev:
@@ -929,8 +949,6 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
 	}
 out:
 	up_read(&device_offload_lock);
-	kfree(tls_ctx->rx.rec_seq);
-	kfree(tls_ctx->rx.iv);
 	tls_sw_release_resources_rx(sk);
 }
 
@@ -1015,4 +1033,5 @@ void __exit tls_device_cleanup(void)
 {
 	unregister_netdevice_notifier(&tls_dev_notifier);
 	flush_work(&tls_device_gc_work);
+	clean_acked_data_flush();
 }
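The largest tls_device.c change is the reencrypt copy loop, which now clamps every copy to the window [rxm->offset, rxm->offset + data_len) while walking first the linear area and then the frags (data_len excludes the GCM authentication tag). As a standalone illustration of that clamping, here is a minimal sketch in plain C; it is not kernel code, and the "segment" struct and helper name are made up. Segments are assumed to be contiguous and visited in order, as skb frags are in the loop above.

#include <string.h>

struct segment {
	unsigned char *data;
	int len;
};

/* Copy the part of the window [offset, msg_offset + data_len) that falls
 * inside one segment, where "pos" is the absolute start of this segment.
 * Returns the number of bytes copied; the caller advances offset and buf
 * by that amount and pos by seg->len before moving to the next segment.
 */
static int copy_into_segment(struct segment *seg, int pos, int offset,
			     int msg_offset, int data_len,
			     const unsigned char *buf)
{
	int frag_pos, copy;

	if (pos + seg->len <= offset)		/* segment ends before the window */
		return 0;
	if (pos >= msg_offset + data_len)	/* segment starts after the payload */
		return 0;

	frag_pos = offset - pos;		/* where the window enters this segment */
	copy = seg->len - frag_pos;
	if (copy > msg_offset + data_len - offset)
		copy = msg_offset + data_len - offset;

	memcpy(seg->data + frag_pos, buf, copy);
	return copy;
}

The two guard checks mirror the goto done_with_frag / break pair in the hunk above: skip fragments that end before the copy window begins, and stop once the payload (without the tag) has been covered.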
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 54c3a758f2a7..c3a5fe624b4e 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -194,18 +194,26 @@ static void update_chksum(struct sk_buff *skb, int headln)
 
 static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
 {
+	struct sock *sk = skb->sk;
+	int delta;
+
 	skb_copy_header(nskb, skb);
 
 	skb_put(nskb, skb->len);
 	memcpy(nskb->data, skb->data, headln);
-	update_chksum(nskb, headln);
 
 	nskb->destructor = skb->destructor;
-	nskb->sk = skb->sk;
+	nskb->sk = sk;
 	skb->destructor = NULL;
 	skb->sk = NULL;
-	refcount_add(nskb->truesize - skb->truesize,
-		     &nskb->sk->sk_wmem_alloc);
+
+	update_chksum(nskb, headln);
+
+	delta = nskb->truesize - skb->truesize;
+	if (likely(delta < 0))
+		WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc));
+	else if (delta)
+		refcount_add(delta, &sk->sk_wmem_alloc);
 }
 
 /* This function may be called after the user socket is already
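The fallback hunk replaces the unconditional refcount_add() with a signed-delta adjustment: the likely() annotation in the patch indicates the re-encrypted clone usually has a smaller truesize than the original skb, and adding a negative delta to a refcount_t would underflow it. A rough equivalent of the sign handling with plain C11 atomics (illustrative only; this is not the kernel's refcount_t API):

#include <stdatomic.h>

static void apply_truesize_delta(atomic_int *wmem_alloc, int delta)
{
	if (delta < 0)
		atomic_fetch_sub(wmem_alloc, -delta);	/* new skb is smaller */
	else if (delta)
		atomic_fetch_add(wmem_alloc, delta);	/* new skb is larger */
}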
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index df921a2904b9..fc81ae18cc44 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -208,6 +208,26 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 	return tls_push_sg(sk, ctx, sg, offset, flags);
 }
 
+bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
+{
+	struct scatterlist *sg;
+
+	sg = ctx->partially_sent_record;
+	if (!sg)
+		return false;
+
+	while (1) {
+		put_page(sg_page(sg));
+		sk_mem_uncharge(sk, sg->length);
+
+		if (sg_is_last(sg))
+			break;
+		sg++;
+	}
+	ctx->partially_sent_record = NULL;
+	return true;
+}
+
 static void tls_write_space(struct sock *sk)
 {
 	struct tls_context *ctx = tls_get_ctx(sk);
@@ -267,13 +287,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
 		tls_sw_free_resources_tx(sk);
+#ifdef CONFIG_TLS_DEVICE
+	} else if (ctx->tx_conf == TLS_HW) {
+		tls_device_free_resources_tx(sk);
+#endif
 	}
 
-	if (ctx->rx_conf == TLS_SW) {
-		kfree(ctx->rx.rec_seq);
-		kfree(ctx->rx.iv);
+	if (ctx->rx_conf == TLS_SW)
 		tls_sw_free_resources_rx(sk);
-	}
 
 #ifdef CONFIG_TLS_DEVICE
 	if (ctx->rx_conf == TLS_HW)
@@ -469,27 +490,32 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 
 	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128:
+		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
+		break;
 	case TLS_CIPHER_AES_GCM_256: {
-		optsize = crypto_info->cipher_type == TLS_CIPHER_AES_GCM_128 ?
-			sizeof(struct tls12_crypto_info_aes_gcm_128) :
-			sizeof(struct tls12_crypto_info_aes_gcm_256);
-		if (optlen != optsize) {
-			rc = -EINVAL;
-			goto err_crypto_info;
-		}
-		rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
-				    optlen - sizeof(*crypto_info));
-		if (rc) {
-			rc = -EFAULT;
-			goto err_crypto_info;
-		}
+		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
 		break;
 	}
+	case TLS_CIPHER_AES_CCM_128:
+		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
+		break;
 	default:
 		rc = -EINVAL;
 		goto err_crypto_info;
 	}
 
+	if (optlen != optsize) {
+		rc = -EINVAL;
+		goto err_crypto_info;
+	}
+
+	rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
+			    optlen - sizeof(*crypto_info));
+	if (rc) {
+		rc = -EFAULT;
+		goto err_crypto_info;
+	}
+
 	if (tx) {
 #ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload(sk, ctx);
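With the setsockopt path reworked above, each cipher maps to exactly one struct size and the length check and copy_from_user() happen once, which is what lets the new TLS_CIPHER_AES_CCM_128 case slot in. A hypothetical userspace caller for the new cipher could look like the sketch below (dummy key material, no error handling). It assumes a <linux/tls.h> that already carries the AES-CCM definitions added by this series, and defines SOL_TLS/TCP_ULP numerically only if the libc headers lack them.

#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

#ifndef SOL_TLS
#define SOL_TLS 282
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

static int enable_ccm_tx(int sock, const unsigned char *key,
			 const unsigned char *salt, const unsigned char *iv,
			 const unsigned char *rec_seq)
{
	struct tls12_crypto_info_aes_ccm_128 ci;

	memset(&ci, 0, sizeof(ci));
	ci.info.version = TLS_1_2_VERSION;
	ci.info.cipher_type = TLS_CIPHER_AES_CCM_128;
	memcpy(ci.key, key, TLS_CIPHER_AES_CCM_128_KEY_SIZE);
	memcpy(ci.salt, salt, TLS_CIPHER_AES_CCM_128_SALT_SIZE);
	memcpy(ci.iv, iv, TLS_CIPHER_AES_CCM_128_IV_SIZE);
	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);

	/* The kTLS ULP must be attached before TLS_TX/TLS_RX can be set. */
	if (setsockopt(sock, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")))
		return -1;

	/* do_tls_setsockopt_conf() now checks optlen against the exact
	 * per-cipher struct size, so pass sizeof(ci) unchanged.
	 */
	return setsockopt(sock, SOL_TLS, TLS_TX, &ci, sizeof(ci));
}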
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 425351ac2a9b..d93f83f77864 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -42,8 +42,6 @@
 #include <net/strparser.h>
 #include <net/tls.h>
 
-#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
-
 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
 		     unsigned int recursion_level)
 {
@@ -121,23 +119,25 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
 }
 
 static int padding_length(struct tls_sw_context_rx *ctx,
-			  struct tls_context *tls_ctx, struct sk_buff *skb)
+			  struct tls_prot_info *prot, struct sk_buff *skb)
 {
 	struct strp_msg *rxm = strp_msg(skb);
 	int sub = 0;
 
 	/* Determine zero-padding length */
-	if (tls_ctx->prot_info.version == TLS_1_3_VERSION) {
+	if (prot->version == TLS_1_3_VERSION) {
 		char content_type = 0;
 		int err;
 		int back = 17;
 
 		while (content_type == 0) {
-			if (back > rxm->full_len)
+			if (back > rxm->full_len - prot->prepend_size)
 				return -EBADMSG;
 			err = skb_copy_bits(skb,
 					    rxm->offset + rxm->full_len - back,
 					    &content_type, 1);
+			if (err)
+				return err;
 			if (content_type)
 				break;
 			sub++;
@@ -172,9 +172,17 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 		tls_err_abort(skb->sk, err);
 	} else {
 		struct strp_msg *rxm = strp_msg(skb);
-		rxm->full_len -= padding_length(ctx, tls_ctx, skb);
-		rxm->offset += prot->prepend_size;
-		rxm->full_len -= prot->overhead_size;
+		int pad;
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0) {
+			ctx->async_wait.err = pad;
+			tls_err_abort(skb->sk, pad);
+		} else {
+			rxm->full_len -= pad;
+			rxm->offset += prot->prepend_size;
+			rxm->full_len -= prot->overhead_size;
+		}
 	}
 
 	/* After using skb->sk to propagate sk through crypto async callback
@@ -225,7 +233,7 @@ static int tls_do_decryption(struct sock *sk,
 	/* Using skb->sk to push sk through to crypto async callback
 	 * handler. This allows propagating errors up to the socket
 	 * if needed. It _must_ be cleared in the async handler
-	 * before kfree_skb is called. We _know_ skb->sk is NULL
+	 * before consume_skb is called. We _know_ skb->sk is NULL
	 * because it is a clone from strparser.
	 */
	skb->sk = sk;
@@ -479,11 +487,18 @@ static int tls_do_encryption(struct sock *sk,
 	struct tls_rec *rec = ctx->open_rec;
 	struct sk_msg *msg_en = &rec->msg_encrypted;
 	struct scatterlist *sge = sk_msg_elem(msg_en, start);
-	int rc;
+	int rc, iv_offset = 0;
+
+	/* For CCM based ciphers, first byte of IV is a constant */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
+		iv_offset = 1;
+	}
+
+	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
+	       prot->iv_size + prot->salt_size);
 
-	memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data));
-	xor_iv_with_seq(prot->version, rec->iv_data,
-			tls_ctx->tx.rec_seq);
+	xor_iv_with_seq(prot->version, rec->iv_data, tls_ctx->tx.rec_seq);
 
 	sge->offset += prot->prepend_size;
 	sge->length -= prot->prepend_size;
@@ -1344,6 +1359,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	struct scatterlist *sgout = NULL;
 	const int data_len = rxm->full_len -
 			     prot->overhead_size + prot->tail_size;
+	int iv_offset = 0;
 
 	if (*zc && (out_iov || out_sg)) {
 		if (out_iov)
@@ -1386,18 +1402,25 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	aad = (u8 *)(sgout + n_sgout);
 	iv = aad + prot->aad_size;
 
+	/* For CCM based ciphers, first byte of nonce+iv is always '2' */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		iv[0] = 2;
+		iv_offset = 1;
+	}
+
 	/* Prepare IV */
 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			    iv + iv_offset + prot->salt_size,
 			    prot->iv_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
 	}
 	if (prot->version == TLS_1_3_VERSION)
-		memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv));
+		memcpy(iv + iv_offset, tls_ctx->rx.iv,
+		       crypto_aead_ivsize(ctx->aead_recv));
 	else
-		memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
 
 	xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
 
@@ -1465,7 +1488,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	int version = prot->version;
 	struct strp_msg *rxm = strp_msg(skb);
-	int err = 0;
+	int pad, err = 0;
 
 	if (!ctx->decrypted) {
 #ifdef CONFIG_TLS_DEVICE
@@ -1484,9 +1507,15 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 				return err;
 			}
 		}
+	} else {
+		*zc = false;
 	}
 
-	rxm->full_len -= padding_length(ctx, tls_ctx, skb);
+	pad = padding_length(ctx, prot, skb);
+	if (pad < 0)
+		return pad;
+
+	rxm->full_len -= pad;
 	rxm->offset += prot->prepend_size;
 	rxm->full_len -= prot->overhead_size;
 	tls_advance_record_sn(sk, &tls_ctx->rx, version);
@@ -1522,7 +1551,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 			rxm->full_len -= len;
 			return false;
 		}
-		kfree_skb(skb);
+		consume_skb(skb);
 	}
 
 	/* Finished with message */
@@ -1631,7 +1660,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 
 		if (!is_peek) {
 			skb_unlink(skb, &ctx->rx_list);
-			kfree_skb(skb);
+			consume_skb(skb);
 		}
 
 		skb = next_skb;
@@ -2050,20 +2079,7 @@ void tls_sw_free_resources_tx(struct sock *sk)
 	/* Free up un-sent records in tx_list. First, free
 	 * the partially sent record if any at head of tx_list.
 	 */
-	if (tls_ctx->partially_sent_record) {
-		struct scatterlist *sg = tls_ctx->partially_sent_record;
-
-		while (1) {
-			put_page(sg_page(sg));
-			sk_mem_uncharge(sk, sg->length);
-
-			if (sg_is_last(sg))
-				break;
-			sg++;
-		}
-
-		tls_ctx->partially_sent_record = NULL;
-
+	if (tls_free_partial_record(sk, tls_ctx)) {
 		rec = list_first_entry(&ctx->tx_list,
 				       struct tls_rec, list);
 		list_del(&rec->list);
@@ -2089,6 +2105,9 @@ void tls_sw_release_resources_rx(struct sock *sk)
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
+	kfree(tls_ctx->rx.rec_seq);
+	kfree(tls_ctx->rx.iv);
+
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
@@ -2152,14 +2171,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	struct tls_crypto_info *crypto_info;
 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
 	struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
+	struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
 	struct cipher_context *cctx;
 	struct crypto_aead **aead;
 	struct strp_callbacks cb;
-	u16 nonce_size, tag_size, iv_size, rec_seq_size;
+	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
 	struct crypto_tfm *tfm;
-	char *iv, *rec_seq, *key, *salt;
+	char *iv, *rec_seq, *key, *salt, *cipher_name;
 	size_t keysize;
 	int rc = 0;
@@ -2224,6 +2244,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
 		key = gcm_128_info->key;
 		salt = gcm_128_info->salt;
+		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
+		cipher_name = "gcm(aes)";
 		break;
 	}
 	case TLS_CIPHER_AES_GCM_256: {
@@ -2239,6 +2261,25 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
 		key = gcm_256_info->key;
 		salt = gcm_256_info->salt;
+		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
+		cipher_name = "gcm(aes)";
+		break;
+	}
+	case TLS_CIPHER_AES_CCM_128: {
+		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
+		rec_seq =
+		 ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
+		ccm_128_info =
+			(struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
+		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
+		key = ccm_128_info->key;
+		salt = ccm_128_info->salt;
+		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
+		cipher_name = "ccm(aes)";
 		break;
 	}
 	default:
@@ -2268,16 +2309,16 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	prot->overhead_size = prot->prepend_size +
 			      prot->tag_size + prot->tail_size;
 	prot->iv_size = iv_size;
-	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			   GFP_KERNEL);
+	prot->salt_size = salt_size;
+	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
 	if (!cctx->iv) {
 		rc = -ENOMEM;
 		goto free_priv;
 	}
 	/* Note: 128 & 256 bit salt are the same size */
-	memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
 	prot->rec_seq_size = rec_seq_size;
+	memcpy(cctx->iv, salt, salt_size);
+	memcpy(cctx->iv + salt_size, iv, iv_size);
 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
 	if (!cctx->rec_seq) {
 		rc = -ENOMEM;
@@ -2285,7 +2326,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	}
 
 	if (!*aead) {
-		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+		*aead = crypto_alloc_aead(cipher_name, 0, 0);
 		if (IS_ERR(*aead)) {
 			rc = PTR_ERR(*aead);
 			*aead = NULL;
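Several of the tls_sw.c hunks above make padding_length() failures propagate (pad < 0 now aborts the record) instead of being silently folded into full_len. For TLS 1.3 the helper scans the decrypted record backwards for the first non-zero byte, which is the real content type; everything behind it is zero padding. A standalone sketch of that scan follows; it is illustrative only and omits the prepend/tag-size accounting the kernel function does on skb data.

/* Return the number of zero padding bytes at the end of a decrypted
 * TLS 1.3 inner plaintext, and report the content-type byte that
 * precedes the padding.  Returns -1 if the record is all zeros.
 */
static int tls13_padding_length(const unsigned char *plaintext, int len,
				unsigned char *content_type)
{
	int pad = 0;

	while (len - pad > 0) {
		unsigned char c = plaintext[len - 1 - pad];

		if (c != 0) {
			*content_type = c;
			return pad;
		}
		pad++;
	}
	return -1;	/* malformed record: no content type found */
}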