diff options
Diffstat (limited to 'net/tls/tls_sw.c')
| -rw-r--r-- | net/tls/tls_sw.c | 2279 |
1 files changed, 1515 insertions, 764 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 86b9527c4826..9937d4c810f2 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -35,14 +35,48 @@ * SOFTWARE. */ +#include <linux/bug.h> #include <linux/sched/signal.h> #include <linux/module.h> +#include <linux/kernel.h> +#include <linux/splice.h> #include <crypto/aead.h> #include <net/strparser.h> #include <net/tls.h> +#include <trace/events/sock.h> -#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE +#include "tls.h" + +struct tls_decrypt_arg { + struct_group(inargs, + bool zc; + bool async; + bool async_done; + u8 tail; + ); + + struct sk_buff *skb; +}; + +struct tls_decrypt_ctx { + struct sock *sk; + u8 iv[TLS_MAX_IV_SIZE]; + u8 aad[TLS_MAX_AAD_SIZE]; + u8 tail; + bool free_sgout; + struct scatterlist sg[]; +}; + +noinline void tls_err_abort(struct sock *sk, int err) +{ + WARN_ON_ONCE(err >= 0); + /* sk->sk_err should contain a positive error code. */ + WRITE_ONCE(sk->sk_err, -err); + /* Paired with smp_rmb() in tcp_poll() */ + smp_wmb(); + sk_error_report(sk); +} static int __skb_nsg(struct sk_buff *skb, int offset, int len, unsigned int recursion_level) @@ -120,41 +154,78 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len) return __skb_nsg(skb, offset, len, 0); } -static void tls_decrypt_done(struct crypto_async_request *req, int err) +static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb, + struct tls_decrypt_arg *darg) { - struct aead_request *aead_req = (struct aead_request *)req; + struct strp_msg *rxm = strp_msg(skb); + struct tls_msg *tlm = tls_msg(skb); + int sub = 0; + + /* Determine zero-padding length */ + if (prot->version == TLS_1_3_VERSION) { + int offset = rxm->full_len - TLS_TAG_SIZE - 1; + char content_type = darg->zc ? darg->tail : 0; + int err; + + while (content_type == 0) { + if (offset < prot->prepend_size) + return -EBADMSG; + err = skb_copy_bits(skb, rxm->offset + offset, + &content_type, 1); + if (err) + return err; + if (content_type) + break; + sub++; + offset--; + } + tlm->control = content_type; + } + return sub; +} + +static void tls_decrypt_done(void *data, int err) +{ + struct aead_request *aead_req = data; + struct crypto_aead *aead = crypto_aead_reqtfm(aead_req); struct scatterlist *sgout = aead_req->dst; - struct scatterlist *sgin = aead_req->src; struct tls_sw_context_rx *ctx; + struct tls_decrypt_ctx *dctx; struct tls_context *tls_ctx; struct scatterlist *sg; - struct sk_buff *skb; unsigned int pages; - int pending; + struct sock *sk; + int aead_size; + + /* If requests get too backlogged crypto API returns -EBUSY and calls + * ->complete(-EINPROGRESS) immediately followed by ->complete(0) + * to make waiting for backlog to flush with crypto_wait_req() easier. + * First wait converts -EBUSY -> -EINPROGRESS, and the second one + * -EINPROGRESS -> 0. + * We have a single struct crypto_async_request per direction, this + * scheme doesn't help us, so just ignore the first ->complete(). + */ + if (err == -EINPROGRESS) + return; + + aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead); + aead_size = ALIGN(aead_size, __alignof__(*dctx)); + dctx = (void *)((u8 *)aead_req + aead_size); - skb = (struct sk_buff *)req->data; - tls_ctx = tls_get_ctx(skb->sk); + sk = dctx->sk; + tls_ctx = tls_get_ctx(sk); ctx = tls_sw_ctx_rx(tls_ctx); /* Propagate if there was an err */ if (err) { + if (err == -EBADMSG) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); ctx->async_wait.err = err; - tls_err_abort(skb->sk, err); - } else { - struct strp_msg *rxm = strp_msg(skb); - - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; + tls_err_abort(sk, err); } - /* After using skb->sk to propagate sk through crypto async callback - * we need to NULL it again. - */ - skb->sk = NULL; - - /* Free the destination pages if skb was not decrypted inplace */ - if (sgout != sgin) { + if (dctx->free_sgout) { /* Skip the first S/G entry as it points to AAD */ for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { if (!sg) @@ -165,59 +236,70 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) kfree(aead_req); - pending = atomic_dec_return(&ctx->decrypt_pending); - - if (!pending && READ_ONCE(ctx->async_notify)) + if (atomic_dec_and_test(&ctx->decrypt_pending)) complete(&ctx->async_wait.completion); } +static int tls_decrypt_async_wait(struct tls_sw_context_rx *ctx) +{ + if (!atomic_dec_and_test(&ctx->decrypt_pending)) + crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + atomic_inc(&ctx->decrypt_pending); + + return ctx->async_wait.err; +} + static int tls_do_decryption(struct sock *sk, - struct sk_buff *skb, struct scatterlist *sgin, struct scatterlist *sgout, char *iv_recv, size_t data_len, struct aead_request *aead_req, - bool async) + struct tls_decrypt_arg *darg) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); int ret; aead_request_set_tfm(aead_req, ctx->aead_recv); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_ad(aead_req, prot->aad_size); aead_request_set_crypt(aead_req, sgin, sgout, - data_len + tls_ctx->rx.tag_size, + data_len + prot->tag_size, (u8 *)iv_recv); - if (async) { - /* Using skb->sk to push sk through to crypto async callback - * handler. This allows propagating errors up to the socket - * if needed. It _must_ be cleared in the async handler - * before kfree_skb is called. We _know_ skb->sk is NULL - * because it is a clone from strparser. - */ - skb->sk = sk; + if (darg->async) { aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - tls_decrypt_done, skb); + tls_decrypt_done, aead_req); + DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1); atomic_inc(&ctx->decrypt_pending); } else { + DECLARE_CRYPTO_WAIT(wait); + aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - crypto_req_done, &ctx->async_wait); + crypto_req_done, &wait); + ret = crypto_aead_decrypt(aead_req); + if (ret == -EINPROGRESS || ret == -EBUSY) + ret = crypto_wait_req(ret, &wait); + return ret; } ret = crypto_aead_decrypt(aead_req); - if (ret == -EINPROGRESS) { - if (async) - return ret; + if (ret == -EINPROGRESS) + return 0; - ret = crypto_wait_req(ret, &ctx->async_wait); + if (ret == -EBUSY) { + ret = tls_decrypt_async_wait(ctx); + darg->async_done = true; + /* all completions have run, we're not doing async anymore */ + darg->async = false; + return ret; } - if (async) - atomic_dec(&ctx->decrypt_pending); + atomic_dec(&ctx->decrypt_pending); + darg->async = false; return ret; } @@ -225,12 +307,13 @@ static int tls_do_decryption(struct sock *sk, static void tls_trim_both_msgs(struct sock *sk, int target_size) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; sk_msg_trim(sk, &rec->msg_plaintext, target_size); if (target_size > 0) - target_size += tls_ctx->tx.overhead_size; + target_size += prot->overhead_size; sk_msg_trim(sk, &rec->msg_encrypted, target_size); } @@ -247,6 +330,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len) static int tls_clone_plaintext_msg(struct sock *sk, int required) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_pl = &rec->msg_plaintext; @@ -262,7 +346,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) /* Skip initial bytes in msg_en's data to be able to use * same offset of both plain and encrypted data. */ - skip = tls_ctx->tx.prepend_size + msg_pl->sg.size; + skip = prot->prepend_size + msg_pl->sg.size; return sk_msg_clone(sk, msg_pl, msg_en, skip, len); } @@ -270,6 +354,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) static struct tls_rec *tls_get_rec(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct sk_msg *msg_pl, *msg_en; struct tls_rec *rec; @@ -288,15 +373,15 @@ static struct tls_rec *tls_get_rec(struct sock *sk) sk_msg_init(msg_en); sg_init_table(rec->sg_aead_in, 2); - sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, - sizeof(rec->aad_space)); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); sg_unmark_end(&rec->sg_aead_in[1]); sg_init_table(rec->sg_aead_out, 2); - sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, - sizeof(rec->aad_space)); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); sg_unmark_end(&rec->sg_aead_out[1]); + rec->sk = sk; + return rec; } @@ -373,29 +458,34 @@ int tls_tx_records(struct sock *sk, int flags) tx_err: if (rc < 0 && rc != -EAGAIN) - tls_err_abort(sk, EBADMSG); + tls_err_abort(sk, rc); return rc; } -static void tls_encrypt_done(struct crypto_async_request *req, int err) +static void tls_encrypt_done(void *data, int err) { - struct aead_request *aead_req = (struct aead_request *)req; - struct sock *sk = req->data; - struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_sw_context_tx *ctx; + struct tls_context *tls_ctx; + struct tls_prot_info *prot; + struct tls_rec *rec = data; struct scatterlist *sge; struct sk_msg *msg_en; - struct tls_rec *rec; - bool ready = false; - int pending; + struct sock *sk; + + if (err == -EINPROGRESS) /* see the comment in tls_decrypt_done() */ + return; - rec = container_of(aead_req, struct tls_rec, aead_req); msg_en = &rec->msg_encrypted; + sk = rec->sk; + tls_ctx = tls_get_ctx(sk); + prot = &tls_ctx->prot_info; + ctx = tls_sw_ctx_tx(tls_ctx); + sge = sk_msg_elem(msg_en, msg_en->sg.curr); - sge->offset -= tls_ctx->tx.prepend_size; - sge->length += tls_ctx->tx.prepend_size; + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; /* Check if error is previously set on socket */ if (err || sk->sk_err) { @@ -403,7 +493,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) /* If err is already set on socket, return the same code */ if (sk->sk_err) { - ctx->async_wait.err = sk->sk_err; + ctx->async_wait.err = -sk->sk_err; } else { ctx->async_wait.err = err; tls_err_abort(sk, err); @@ -419,21 +509,25 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) /* If received record is at head of tx_list, schedule tx */ first_rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); - if (rec == first_rec) - ready = true; + if (rec == first_rec) { + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, + &ctx->tx_bitmask)) + schedule_delayed_work(&ctx->tx_work.work, 1); + } } - pending = atomic_dec_return(&ctx->encrypt_pending); - - if (!pending && READ_ONCE(ctx->async_notify)) + if (atomic_dec_and_test(&ctx->encrypt_pending)) complete(&ctx->async_wait.completion); +} - if (!ready) - return; +static int tls_encrypt_async_wait(struct tls_sw_context_tx *ctx) +{ + if (!atomic_dec_and_test(&ctx->encrypt_pending)) + crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + atomic_inc(&ctx->encrypt_pending); - /* Schedule the transmission */ - if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) - schedule_delayed_work(&ctx->tx_work.work, 1); + return ctx->async_wait.err; } static int tls_do_encryption(struct sock *sk, @@ -442,34 +536,58 @@ static int tls_do_encryption(struct sock *sk, struct aead_request *aead_req, size_t data_len, u32 start) { + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_en = &rec->msg_encrypted; struct scatterlist *sge = sk_msg_elem(msg_en, start); - int rc; + int rc, iv_offset = 0; + + /* For CCM based ciphers, first byte of IV is a constant */ + switch (prot->cipher_type) { + case TLS_CIPHER_AES_CCM_128: + rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + case TLS_CIPHER_SM4_CCM: + rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + } - sge->offset += tls_ctx->tx.prepend_size; - sge->length -= tls_ctx->tx.prepend_size; + memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, + prot->iv_size + prot->salt_size); + + tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset, + tls_ctx->tx.rec_seq); + + sge->offset += prot->prepend_size; + sge->length -= prot->prepend_size; msg_en->sg.curr = start; aead_request_set_tfm(aead_req, ctx->aead_send); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_ad(aead_req, prot->aad_size); aead_request_set_crypt(aead_req, rec->sg_aead_in, rec->sg_aead_out, - data_len, tls_ctx->tx.iv); + data_len, rec->iv_data); aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, - tls_encrypt_done, sk); + tls_encrypt_done, rec); /* Add the record in tx_list */ list_add_tail((struct list_head *)&rec->list, &ctx->tx_list); + DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->encrypt_pending) < 1); atomic_inc(&ctx->encrypt_pending); rc = crypto_aead_encrypt(aead_req); + if (rc == -EBUSY) { + rc = tls_encrypt_async_wait(ctx); + rc = rc ?: -EINPROGRESS; + } if (!rc || rc != -EINPROGRESS) { atomic_dec(&ctx->encrypt_pending); - sge->offset -= tls_ctx->tx.prepend_size; - sge->length += tls_ctx->tx.prepend_size; + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; } if (!rc) { @@ -481,7 +599,7 @@ static int tls_do_encryption(struct sock *sk, /* Unhook the record from context if encryption is not failure */ ctx->open_rec = NULL; - tls_advance_record_sn(sk, &tls_ctx->tx); + tls_advance_record_sn(sk, prot, &tls_ctx->tx); return rc; } @@ -607,9 +725,10 @@ static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec, *tmp = NULL; - u32 i, split_point, uninitialized_var(orig_end); + u32 i, split_point, orig_end; struct sk_msg *msg_pl, *msg_en; struct aead_request *req; bool split; @@ -623,14 +742,34 @@ static int tls_push_record(struct sock *sk, int flags, split_point = msg_pl->apply_bytes; split = split_point && split_point < msg_pl->sg.size; + if (unlikely((!split && + msg_pl->sg.size + + prot->overhead_size > msg_en->sg.size) || + (split && + split_point + + prot->overhead_size > msg_en->sg.size))) { + split = true; + split_point = msg_en->sg.size; + } if (split) { rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, - split_point, tls_ctx->tx.overhead_size, + split_point, prot->overhead_size, &orig_end); if (rc < 0) return rc; + /* This can happen if above tls_split_open_record allocates + * a single large encryption buffer instead of two smaller + * ones. In this case adjust pointers and continue without + * split. + */ + if (!msg_pl->sg.size) { + tls_merge_open_record(sk, rec, tmp, orig_end); + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + split = false; + } sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + prot->overhead_size); } rec->tx_flags = flags; @@ -638,11 +777,26 @@ static int tls_push_record(struct sock *sk, int flags, i = msg_pl->sg.end; sk_msg_iter_var_prev(i); - sg_mark_end(sk_msg_elem(msg_pl, i)); + + rec->content_type = record_type; + if (prot->version == TLS_1_3_VERSION) { + /* Add content type to end of message. No padding added */ + sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); + sg_mark_end(&rec->sg_content_type); + sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1, + &rec->sg_content_type); + } else { + sg_mark_end(sk_msg_elem(msg_pl, i)); + } + + if (msg_pl->sg.end < msg_pl->sg.start) { + sg_chain(&msg_pl->sg.data[msg_pl->sg.start], + MAX_SKB_FRAGS - msg_pl->sg.start + 1, + msg_pl->sg.data); + } i = msg_pl->sg.start; - sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ? - &msg_en->sg.data[i] : &msg_pl->sg.data[i]); + sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); i = msg_en->sg.end; sk_msg_iter_var_prev(i); @@ -651,32 +805,33 @@ static int tls_push_record(struct sock *sk, int flags, i = msg_en->sg.start; sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); - tls_make_aad(rec->aad_space, msg_pl->sg.size, - tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, - record_type); + tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, + tls_ctx->tx.rec_seq, record_type, prot); tls_fill_prepend(tls_ctx, page_address(sg_page(&msg_en->sg.data[i])) + - msg_en->sg.data[i].offset, msg_pl->sg.size, + msg_en->sg.data[i].offset, + msg_pl->sg.size + prot->tail_size, record_type); tls_ctx->pending_open_record_frags = false; - rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i); + rc = tls_do_encryption(sk, tls_ctx, ctx, req, + msg_pl->sg.size + prot->tail_size, i); if (rc < 0) { if (rc != -EINPROGRESS) { - tls_err_abort(sk, EBADMSG); + tls_err_abort(sk, -EBADMSG); if (split) { tls_ctx->pending_open_record_frags = true; tls_merge_open_record(sk, rec, tmp, orig_end); } } + ctx->async_capable = 1; return rc; } else if (split) { msg_pl = &tmp->msg_plaintext; msg_en = &tmp->msg_encrypted; - sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); tls_ctx->pending_open_record_frags = true; ctx->open_rec = tmp; } @@ -686,7 +841,7 @@ static int tls_push_record(struct sock *sk, int flags, static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, bool full_record, u8 record_type, - size_t *copied, int flags) + ssize_t *copied, int flags) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); @@ -694,23 +849,42 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, struct sk_psock *psock; struct sock *sk_redir; struct tls_rec *rec; - bool enospc, policy; + bool enospc, policy, redir_ingress; int err = 0, send; u32 delta = 0; policy = !(flags & MSG_SENDPAGE_NOPOLICY); psock = sk_psock_get(sk); - if (!psock || !policy) - return tls_push_record(sk, flags, record_type); + if (!psock || !policy) { + err = tls_push_record(sk, flags, record_type); + if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + err = -sk->sk_err; + } + if (psock) + sk_psock_put(sk, psock); + return err; + } more_data: enospc = sk_msg_full(msg); if (psock->eval == __SK_NONE) { delta = msg->sg.size; psock->eval = sk_psock_msg_verdict(sk, psock, msg); - if (delta < msg->sg.size) - delta -= msg->sg.size; - else - delta = 0; + delta -= msg->sg.size; + + if ((s32)delta > 0) { + /* It indicates that we executed bpf_msg_pop_data(), + * causing the plaintext data size to decrease. + * Therefore the encrypted data size also needs to + * correspondingly decrease. We only need to subtract + * delta to calculate the new ciphertext length since + * ktls does not support block encryption. + */ + struct sk_msg *enc = &ctx->open_rec->msg_encrypted; + + sk_msg_trim(sk, enc, enc->sg.size - delta); + } } if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && !enospc && !full_record) { @@ -725,13 +899,15 @@ more_data: switch (psock->eval) { case __SK_PASS: err = tls_push_record(sk, flags, record_type); - if (err < 0) { + if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) { *copied -= sk_msg_free(sk, msg); tls_free_open_rec(sk); + err = -sk->sk_err; goto out_err; } break; case __SK_REDIRECT: + redir_ingress = psock->redir_ingress; sk_redir = psock->sk_redir; memcpy(&msg_redir, msg, sizeof(*msg)); if (msg->apply_bytes < send) @@ -741,9 +917,17 @@ more_data: sk_msg_return_zero(sk, msg, send); msg->sg.size -= send; release_sock(sk); - err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags); + err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, + &msg_redir, send, flags); lock_sock(sk); if (err < 0) { + /* Regardless of whether the data represented by + * msg_redir is sent successfully, we have already + * uncharged it via sk_msg_return_zero(). The + * msg->sg.size represents the remaining unprocessed + * data, which needs to be uncharged here. + */ + sk_mem_uncharge(sk, msg->sg.size); *copied -= sk_msg_free_nocharge(sk, &msg_redir); msg->sg.size = 0; } @@ -807,17 +991,50 @@ static int tls_sw_push_pending_record(struct sock *sk, int flags) &copied, flags); } -int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +static int tls_sw_sendmsg_splice(struct sock *sk, struct msghdr *msg, + struct sk_msg *msg_pl, size_t try_to_copy, + ssize_t *copied) +{ + struct page *page = NULL, **pages = &page; + + do { + ssize_t part; + size_t off; + + part = iov_iter_extract_pages(&msg->msg_iter, &pages, + try_to_copy, 1, 0, &off); + if (part <= 0) + return part ?: -EIO; + + if (WARN_ON_ONCE(!sendpage_ok(page))) { + iov_iter_revert(&msg->msg_iter, part); + return -EIO; + } + + sk_msg_page_add(msg_pl, page, part, off); + msg_pl->sg.copybreak = 0; + msg_pl->sg.curr = msg_pl->sg.end; + sk_mem_charge(sk, part); + *copied += part; + try_to_copy -= part; + } while (try_to_copy && !sk_msg_full(msg_pl)); + + return 0; +} + +static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg, + size_t size) { long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send); - bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; + bool async_capable = ctx->async_capable; unsigned char record_type = TLS_RECORD_TYPE_DATA; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool eor = !(msg->msg_flags & MSG_MORE); - size_t try_to_copy, copied = 0; + size_t try_to_copy; + ssize_t copied = 0; struct sk_msg *msg_pl, *msg_en; struct tls_rec *rec; int required_size; @@ -828,25 +1045,16 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) int orig_size; int ret = 0; - if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) - return -ENOTSUPP; - - lock_sock(sk); - - /* Wait till there is any pending write on socket */ - if (unlikely(sk->sk_write_pending)) { - ret = wait_on_pending_writer(sk, &timeo); - if (unlikely(ret)) - goto send_end; - } + if (!eor && (msg->msg_flags & MSG_EOR)) + return -EINVAL; if (unlikely(msg->msg_controllen)) { - ret = tls_proccess_cmsg(sk, msg, &record_type); + ret = tls_process_cmsg(sk, msg, &record_type); if (ret) { if (ret == -EINPROGRESS) num_async++; else if (ret != -EAGAIN) - goto send_end; + goto end; } } @@ -871,14 +1079,14 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) orig_size = msg_pl->sg.size; full_record = false; try_to_copy = msg_data_left(msg); - record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; + record_room = tls_ctx->tx_max_payload_len - msg_pl->sg.size; if (try_to_copy >= record_room) { try_to_copy = record_room; full_record = true; } required_size = msg_pl->sg.size + try_to_copy + - tls_ctx->tx.overhead_size; + prot->overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -897,6 +1105,24 @@ alloc_encrypted: full_record = true; } + if (try_to_copy && (msg->msg_flags & MSG_SPLICE_PAGES)) { + ret = tls_sw_sendmsg_splice(sk, msg, msg_pl, + try_to_copy, &copied); + if (ret < 0) + goto send_end; + tls_ctx->pending_open_record_frags = true; + + if (sk_msg_full(msg_pl)) { + full_record = true; + sk_msg_trim(sk, msg_en, + msg_pl->sg.size + prot->overhead_size); + } + + if (full_record || eor) + goto copied; + continue; + } + if (!is_kvec && (full_record || eor) && !async_capable) { u32 first = msg_pl->sg.end; @@ -905,8 +1131,6 @@ alloc_encrypted: if (ret) goto fallback_to_reg_send; - rec->inplace_crypto = 0; - num_zc++; copied += try_to_copy; @@ -919,11 +1143,22 @@ alloc_encrypted: num_async++; else if (ret == -ENOMEM) goto wait_for_memory; - else if (ret == -ENOSPC) + else if (ctx->open_rec && ret == -ENOSPC) { + if (msg_pl->cork_bytes) { + ret = 0; + goto send_end; + } goto rollback_iter; - else if (ret != -EAGAIN) + } else if (ret != -EAGAIN) goto send_end; } + + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, msg->msg_flags); + } + continue; rollback_iter: copied -= try_to_copy; @@ -947,8 +1182,8 @@ fallback_to_reg_send: */ try_to_copy -= required_size - msg_pl->sg.size; full_record = true; - sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, + msg_pl->sg.size + prot->overhead_size); } if (try_to_copy) { @@ -963,6 +1198,7 @@ fallback_to_reg_send: */ tls_ctx->pending_open_record_frags = true; copied += try_to_copy; +copied: if (full_record || eor) { ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, record_type, &copied, @@ -978,6 +1214,12 @@ fallback_to_reg_send: goto send_end; } } + + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, msg->msg_flags); + } } continue; @@ -988,29 +1230,25 @@ wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { trim_sgl: - tls_trim_both_msgs(sk, orig_size); + if (ctx->open_rec) + tls_trim_both_msgs(sk, orig_size); goto send_end; } - if (msg_en->sg.size < required_size) + if (ctx->open_rec && msg_en->sg.size < required_size) goto alloc_encrypted; } +send_end: if (!num_async) { - goto send_end; - } else if (num_zc) { - /* Wait for pending encryptions to get completed */ - smp_store_mb(ctx->async_notify, true); - - if (atomic_read(&ctx->encrypt_pending)) - crypto_wait_req(-EINPROGRESS, &ctx->async_wait); - else - reinit_completion(&ctx->async_wait.completion); - - WRITE_ONCE(ctx->async_notify, false); + goto end; + } else if (num_zc || eor) { + int err; - if (ctx->async_wait.err) { - ret = ctx->async_wait.err; + /* Wait for pending encryptions to get completed */ + err = tls_encrypt_async_wait(ctx); + if (err) { + ret = err; copied = 0; } } @@ -1021,204 +1259,162 @@ trim_sgl: tls_tx_records(sk, msg->msg_flags); } -send_end: +end: ret = sk_stream_error(sk, msg->msg_flags, ret); + return copied > 0 ? copied : ret; +} +int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + int ret; + + if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_CMSG_COMPAT | MSG_SPLICE_PAGES | MSG_EOR | + MSG_SENDPAGE_NOPOLICY)) + return -EOPNOTSUPP; + + ret = mutex_lock_interruptible(&tls_ctx->tx_lock); + if (ret) + return ret; + lock_sock(sk); + ret = tls_sw_sendmsg_locked(sk, msg, size); release_sock(sk); - return copied ? copied : ret; + mutex_unlock(&tls_ctx->tx_lock); + return ret; } -static int tls_sw_do_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags) +/* + * Handle unexpected EOF during splice without SPLICE_F_MORE set. + */ +void tls_sw_splice_eof(struct socket *sock) { - long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + struct sock *sk = sock->sk; struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - unsigned char record_type = TLS_RECORD_TYPE_DATA; - struct sk_msg *msg_pl; struct tls_rec *rec; - int num_async = 0; - size_t copied = 0; - bool full_record; - int record_room; + struct sk_msg *msg_pl; + ssize_t copied = 0; + bool retrying = false; int ret = 0; - bool eor; - eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); - sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); - - /* Wait till there is any pending write on socket */ - if (unlikely(sk->sk_write_pending)) { - ret = wait_on_pending_writer(sk, &timeo); - if (unlikely(ret)) - goto sendpage_end; - } - - /* Call the sk_stream functions to manage the sndbuf mem. */ - while (size > 0) { - size_t copy, required_size; - - if (sk->sk_err) { - ret = -sk->sk_err; - goto sendpage_end; - } - - if (ctx->open_rec) - rec = ctx->open_rec; - else - rec = ctx->open_rec = tls_get_rec(sk); - if (!rec) { - ret = -ENOMEM; - goto sendpage_end; - } - - msg_pl = &rec->msg_plaintext; - - full_record = false; - record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; - copied = 0; - copy = size; - if (copy >= record_room) { - copy = record_room; - full_record = true; - } - - required_size = msg_pl->sg.size + copy + - tls_ctx->tx.overhead_size; - - if (!sk_stream_memory_free(sk)) - goto wait_for_sndbuf; -alloc_payload: - ret = tls_alloc_encrypted_msg(sk, required_size); - if (ret) { - if (ret != -ENOSPC) - goto wait_for_memory; - - /* Adjust copy according to the amount that was - * actually allocated. The difference is due - * to max sg elements limit - */ - copy -= required_size - msg_pl->sg.size; - full_record = true; - } - - sk_msg_page_add(msg_pl, page, copy, offset); - sk_mem_charge(sk, copy); - - offset += copy; - size -= copy; - copied += copy; + if (!ctx->open_rec) + return; - tls_ctx->pending_open_record_frags = true; - if (full_record || eor || sk_msg_full(msg_pl)) { - rec->inplace_crypto = 0; - ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, - record_type, &copied, flags); - if (ret) { - if (ret == -EINPROGRESS) - num_async++; - else if (ret == -ENOMEM) - goto wait_for_memory; - else if (ret != -EAGAIN) { - if (ret == -ENOSPC) - ret = 0; - goto sendpage_end; - } - } - } - continue; -wait_for_sndbuf: - set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); -wait_for_memory: - ret = sk_stream_wait_memory(sk, &timeo); - if (ret) { - tls_trim_both_msgs(sk, msg_pl->sg.size); - goto sendpage_end; - } + mutex_lock(&tls_ctx->tx_lock); + lock_sock(sk); - goto alloc_payload; - } +retry: + /* same checks as in tls_sw_push_pending_record() */ + rec = ctx->open_rec; + if (!rec) + goto unlock; - if (num_async) { - /* Transmit if any encryptions have completed */ - if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { - cancel_delayed_work(&ctx->tx_work.work); - tls_tx_records(sk, flags); - } + msg_pl = &rec->msg_plaintext; + if (msg_pl->sg.size == 0) + goto unlock; + + /* Check the BPF advisor and perform transmission. */ + ret = bpf_exec_tx_verdict(msg_pl, sk, false, TLS_RECORD_TYPE_DATA, + &copied, 0); + switch (ret) { + case 0: + case -EAGAIN: + if (retrying) + goto unlock; + retrying = true; + goto retry; + case -EINPROGRESS: + break; + default: + goto unlock; } -sendpage_end: - ret = sk_stream_error(sk, flags, ret); - return copied ? copied : ret; -} -int tls_sw_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags) -{ - int ret; + /* Wait for pending encryptions to get completed */ + if (tls_encrypt_async_wait(ctx)) + goto unlock; - if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | - MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) - return -ENOTSUPP; + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, 0); + } - lock_sock(sk); - ret = tls_sw_do_sendpage(sk, page, offset, size, flags); +unlock: release_sock(sk); - return ret; + mutex_unlock(&tls_ctx->tx_lock); } -static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, - int flags, long timeo, int *err) +static int +tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, + bool released) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct sk_buff *skb; DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret = 0; + long timeo; - while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) { - if (sk->sk_err) { - *err = sock_error(sk); - return NULL; + /* a rekey is pending, let userspace deal with it */ + if (unlikely(ctx->key_update_pending)) + return -EKEYEXPIRED; + + timeo = sock_rcvtimeo(sk, nonblock); + + while (!tls_strp_msg_ready(ctx)) { + if (!sk_psock_queue_empty(psock)) + return 0; + + if (sk->sk_err) + return sock_error(sk); + + if (ret < 0) + return ret; + + if (!skb_queue_empty(&sk->sk_receive_queue)) { + tls_strp_check_rcv(&ctx->strp); + if (tls_strp_msg_ready(ctx)) + break; } if (sk->sk_shutdown & RCV_SHUTDOWN) - return NULL; + return 0; if (sock_flag(sk, SOCK_DONE)) - return NULL; + return 0; - if ((flags & MSG_DONTWAIT) || !timeo) { - *err = -EAGAIN; - return NULL; - } + if (!timeo) + return -EAGAIN; + released = true; add_wait_queue(sk_sleep(sk), &wait); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); - sk_wait_event(sk, &timeo, - ctx->recv_pkt != skb || - !sk_psock_queue_empty(psock), - &wait); + ret = sk_wait_event(sk, &timeo, + tls_strp_msg_ready(ctx) || + !sk_psock_queue_empty(psock), + &wait); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); remove_wait_queue(sk_sleep(sk), &wait); /* Handle signals */ - if (signal_pending(current)) { - *err = sock_intr_errno(timeo); - return NULL; - } + if (signal_pending(current)) + return sock_intr_errno(timeo); } - return skb; + if (unlikely(!tls_strp_msg_load(&ctx->strp, released))) + return tls_rx_rec_wait(sk, psock, nonblock, false); + + return 1; } -static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, +static int tls_setup_from_iter(struct iov_iter *from, int length, int *pages_used, - unsigned int *size_used, struct scatterlist *to, int to_max_pages) { int rc = 0, i = 0, num_elem = *pages_used, maxpages; struct page *pages[MAX_SKB_FRAGS]; - unsigned int size = *size_used; + unsigned int size = 0; ssize_t copied, use; size_t offset; @@ -1229,7 +1425,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, rc = -EFAULT; goto out; } - copied = iov_iter_get_pages(from, pages, + copied = iov_iter_get_pages2(from, pages, length, maxpages, &offset); if (copied <= 0) { @@ -1237,8 +1433,6 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, goto out; } - iov_iter_advance(from, copied); - length -= copied; size += copied; while (copied) { @@ -1261,227 +1455,434 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, sg_mark_end(&to[num_elem - 1]); out: if (rc) - iov_iter_revert(from, size - *size_used); - *size_used = size; + iov_iter_revert(from, size); *pages_used = num_elem; return rc; } +static struct sk_buff * +tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb, + unsigned int full_len) +{ + struct strp_msg *clr_rxm; + struct sk_buff *clr_skb; + int err; + + clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER, + &err, sk->sk_allocation); + if (!clr_skb) + return NULL; + + skb_copy_header(clr_skb, skb); + clr_skb->len = full_len; + clr_skb->data_len = full_len; + + clr_rxm = strp_msg(clr_skb); + clr_rxm->offset = 0; + + return clr_skb; +} + +/* Decrypt handlers + * + * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers. + * They must transform the darg in/out argument are as follows: + * | Input | Output + * ------------------------------------------------------------------- + * zc | Zero-copy decrypt allowed | Zero-copy performed + * async | Async decrypt allowed | Async crypto used / in progress + * skb | * | Output skb + * + * If ZC decryption was performed darg.skb will point to the input skb. + */ + /* This function decrypts the input skb into either out_iov or in out_sg - * or in skb buffers itself. The input parameter 'zc' indicates if + * or in skb buffers itself. The input parameter 'darg->zc' indicates if * zero-copy mode needs to be tried or not. With zero-copy mode, either * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are * NULL, then the decryption happens inside skb buffers itself, i.e. - * zero-copy gets disabled and 'zc' is updated. + * zero-copy gets disabled and 'darg->zc' is updated. */ - -static int decrypt_internal(struct sock *sk, struct sk_buff *skb, - struct iov_iter *out_iov, - struct scatterlist *out_sg, - int *chunk, bool *zc, bool async) +static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, + struct scatterlist *out_sg, + struct tls_decrypt_arg *darg) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm = strp_msg(skb); - int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; + struct tls_prot_info *prot = &tls_ctx->prot_info; + int n_sgin, n_sgout, aead_size, err, pages = 0; + struct sk_buff *skb = tls_strp_msg(ctx); + const struct strp_msg *rxm = strp_msg(skb); + const struct tls_msg *tlm = tls_msg(skb); struct aead_request *aead_req; - struct sk_buff *unused; - u8 *aad, *iv, *mem = NULL; struct scatterlist *sgin = NULL; struct scatterlist *sgout = NULL; - const int data_len = rxm->full_len - tls_ctx->rx.overhead_size; + const int data_len = rxm->full_len - prot->overhead_size; + int tail_pages = !!prot->tail_size; + struct tls_decrypt_ctx *dctx; + struct sk_buff *clear_skb; + int iv_offset = 0; + u8 *mem; + + n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); + if (n_sgin < 1) + return n_sgin ?: -EBADMSG; + + if (darg->zc && (out_iov || out_sg)) { + clear_skb = NULL; - if (*zc && (out_iov || out_sg)) { if (out_iov) - n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; + n_sgout = 1 + tail_pages + + iov_iter_npages_cap(out_iov, INT_MAX, data_len); else n_sgout = sg_nents(out_sg); - n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); } else { - n_sgout = 0; - *zc = false; - n_sgin = skb_cow_data(skb, 0, &unused); - } + darg->zc = false; - if (n_sgin < 1) - return -EBADMSG; + clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len); + if (!clear_skb) + return -ENOMEM; + + n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags; + } /* Increment to accommodate AAD */ n_sgin = n_sgin + 1; - nsg = n_sgin + n_sgout; - - aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); - mem_size = aead_size + (nsg * sizeof(struct scatterlist)); - mem_size = mem_size + TLS_AAD_SPACE_SIZE; - mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); - /* Allocate a single block of memory which contains - * aead_req || sgin[] || sgout[] || aad || iv. - * This order achieves correct alignment for aead_req, sgin, sgout. + * aead_req || tls_decrypt_ctx. + * Both structs are variable length. */ - mem = kmalloc(mem_size, sk->sk_allocation); - if (!mem) - return -ENOMEM; + aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); + aead_size = ALIGN(aead_size, __alignof__(*dctx)); + mem = kmalloc(aead_size + struct_size(dctx, sg, size_add(n_sgin, n_sgout)), + sk->sk_allocation); + if (!mem) { + err = -ENOMEM; + goto exit_free_skb; + } /* Segment the allocated memory */ aead_req = (struct aead_request *)mem; - sgin = (struct scatterlist *)(mem + aead_size); - sgout = sgin + n_sgin; - aad = (u8 *)(sgout + n_sgout); - iv = aad + TLS_AAD_SPACE_SIZE; + dctx = (struct tls_decrypt_ctx *)(mem + aead_size); + dctx->sk = sk; + sgin = &dctx->sg[0]; + sgout = &dctx->sg[n_sgin]; + + /* For CCM based ciphers, first byte of nonce+iv is a constant */ + switch (prot->cipher_type) { + case TLS_CIPHER_AES_CCM_128: + dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + case TLS_CIPHER_SM4_CCM: + dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + } /* Prepare IV */ - err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, - iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - tls_ctx->rx.iv_size); - if (err < 0) { - kfree(mem); - return err; + if (prot->version == TLS_1_3_VERSION || + prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { + memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, + prot->iv_size + prot->salt_size); + } else { + err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, + &dctx->iv[iv_offset] + prot->salt_size, + prot->iv_size); + if (err < 0) + goto exit_free; + memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size); } - memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq); /* Prepare AAD */ - tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size, - tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, - ctx->control); + tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size + + prot->tail_size, + tls_ctx->rx.rec_seq, tlm->control, prot); /* Prepare sgin */ sg_init_table(sgin, n_sgin); - sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(&sgin[0], dctx->aad, prot->aad_size); err = skb_to_sgvec(skb, &sgin[1], - rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); - if (err < 0) { - kfree(mem); - return err; + rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); + if (err < 0) + goto exit_free; + + if (clear_skb) { + sg_init_table(sgout, n_sgout); + sg_set_buf(&sgout[0], dctx->aad, prot->aad_size); + + err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size, + data_len + prot->tail_size); + if (err < 0) + goto exit_free; + } else if (out_iov) { + sg_init_table(sgout, n_sgout); + sg_set_buf(&sgout[0], dctx->aad, prot->aad_size); + + err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1], + (n_sgout - 1 - tail_pages)); + if (err < 0) + goto exit_free_pages; + + if (prot->tail_size) { + sg_unmark_end(&sgout[pages]); + sg_set_buf(&sgout[pages + 1], &dctx->tail, + prot->tail_size); + sg_mark_end(&sgout[pages + 1]); + } + } else if (out_sg) { + memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); } + dctx->free_sgout = !!pages; - if (n_sgout) { - if (out_iov) { - sg_init_table(sgout, n_sgout); - sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); + /* Prepare and submit AEAD request */ + err = tls_do_decryption(sk, sgin, sgout, dctx->iv, + data_len + prot->tail_size, aead_req, darg); + if (err) { + if (darg->async_done) + goto exit_free_skb; + goto exit_free_pages; + } - *chunk = 0; - err = tls_setup_from_iter(sk, out_iov, data_len, - &pages, chunk, &sgout[1], - (n_sgout - 1)); - if (err < 0) - goto fallback_to_reg_recv; - } else if (out_sg) { - memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); - } else { - goto fallback_to_reg_recv; + darg->skb = clear_skb ?: tls_strp_msg(ctx); + clear_skb = NULL; + + if (unlikely(darg->async)) { + err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold); + if (err) { + err = tls_decrypt_async_wait(ctx); + darg->async = false; } - } else { -fallback_to_reg_recv: - sgout = sgin; - pages = 0; - *chunk = data_len; - *zc = false; + return err; } - /* Prepare and submit AEAD request */ - err = tls_do_decryption(sk, skb, sgin, sgout, iv, - data_len, aead_req, async); - if (err == -EINPROGRESS) - return err; + if (unlikely(darg->async_done)) + return 0; + if (prot->tail_size) + darg->tail = dctx->tail; + +exit_free_pages: /* Release the pages in case iov was mapped to pages */ for (; pages > 0; pages--) put_page(sg_page(&sgout[pages])); - +exit_free: kfree(mem); +exit_free_skb: + consume_skb(clear_skb); return err; } -static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, - struct iov_iter *dest, int *chunk, bool *zc, - bool async) +static int +tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx, + struct msghdr *msg, struct tls_decrypt_arg *darg) { - struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct strp_msg *rxm = strp_msg(skb); - int err = 0; + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int pad, err; -#ifdef CONFIG_TLS_DEVICE - err = tls_device_decrypted(sk, skb); - if (err < 0) + err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg); + if (err < 0) { + if (err == -EBADMSG) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); return err; -#endif - if (!ctx->decrypted) { - err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async); - if (err < 0) { - if (err == -EINPROGRESS) - tls_advance_record_sn(sk, &tls_ctx->rx); + } + /* keep going even for ->async, the code below is TLS 1.3 */ - return err; - } - } else { - *zc = false; + /* If opportunistic TLS 1.3 ZC failed retry without ZC */ + if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && + darg->tail != TLS_RECORD_TYPE_DATA)) { + darg->zc = false; + if (!darg->tail) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY); + return tls_decrypt_sw(sk, tls_ctx, msg, darg); } - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; - tls_advance_record_sn(sk, &tls_ctx->rx); - ctx->decrypted = true; - ctx->saved_data_ready(sk); + pad = tls_padding_length(prot, darg->skb, darg); + if (pad < 0) { + if (darg->skb != tls_strp_msg(ctx)) + consume_skb(darg->skb); + return pad; + } - return err; + rxm = strp_msg(darg->skb); + rxm->full_len -= pad; + + return 0; } -int decrypt_skb(struct sock *sk, struct sk_buff *skb, - struct scatterlist *sgout) +static int +tls_decrypt_device(struct sock *sk, struct msghdr *msg, + struct tls_context *tls_ctx, struct tls_decrypt_arg *darg) { - bool zc = true; - int chunk; + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int pad, err; + + if (tls_ctx->rx_conf != TLS_HW) + return 0; + + err = tls_device_decrypted(sk, tls_ctx); + if (err <= 0) + return err; + + pad = tls_padding_length(prot, tls_strp_msg(ctx), darg); + if (pad < 0) + return pad; + + darg->async = false; + darg->skb = tls_strp_msg(ctx); + /* ->zc downgrade check, in case TLS 1.3 gets here */ + darg->zc &= !(prot->version == TLS_1_3_VERSION && + tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA); - return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false); + rxm = strp_msg(darg->skb); + rxm->full_len -= pad; + + if (!darg->zc) { + /* Non-ZC case needs a real skb */ + darg->skb = tls_strp_msg_detach(ctx); + if (!darg->skb) + return -ENOMEM; + } else { + unsigned int off, len; + + /* In ZC case nobody cares about the output skb. + * Just copy the data here. Note the skb is not fully trimmed. + */ + off = rxm->offset + prot->prepend_size; + len = rxm->full_len - prot->overhead_size; + + err = skb_copy_datagram_msg(darg->skb, off, msg, len); + if (err) + return err; + } + return 1; } -static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, - unsigned int len) +static int tls_check_pending_rekey(struct sock *sk, struct tls_context *ctx, + struct sk_buff *skb) +{ + const struct strp_msg *rxm = strp_msg(skb); + const struct tls_msg *tlm = tls_msg(skb); + char hs_type; + int err; + + if (likely(tlm->control != TLS_RECORD_TYPE_HANDSHAKE)) + return 0; + + if (rxm->full_len < 1) + return 0; + + err = skb_copy_bits(skb, rxm->offset, &hs_type, 1); + if (err < 0) { + DEBUG_NET_WARN_ON_ONCE(1); + return err; + } + + if (hs_type == TLS_HANDSHAKE_KEYUPDATE) { + struct tls_sw_context_rx *rx_ctx = ctx->priv_ctx_rx; + + WRITE_ONCE(rx_ctx->key_update_pending, true); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYRECEIVED); + } + + return 0; +} + +static int tls_rx_one_record(struct sock *sk, struct msghdr *msg, + struct tls_decrypt_arg *darg) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int err; - if (skb) { - struct strp_msg *rxm = strp_msg(skb); + err = tls_decrypt_device(sk, msg, tls_ctx, darg); + if (!err) + err = tls_decrypt_sw(sk, tls_ctx, msg, darg); + if (err < 0) + return err; + + rxm = strp_msg(darg->skb); + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; + tls_advance_record_sn(sk, prot, &tls_ctx->rx); + + return tls_check_pending_rekey(sk, tls_ctx, darg->skb); +} - if (len < rxm->full_len) { - rxm->offset += len; - rxm->full_len -= len; - return false; +int decrypt_skb(struct sock *sk, struct scatterlist *sgout) +{ + struct tls_decrypt_arg darg = { .zc = true, }; + + return tls_decrypt_sg(sk, NULL, sgout, &darg); +} + +/* All records returned from a recvmsg() call must have the same type. + * 0 is not a valid content type. Use it as "no type reported, yet". + */ +static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, + u8 *control) +{ + int err; + + if (!*control) { + *control = tlm->control; + if (!*control) + return -EBADMSG; + + err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, + sizeof(*control), control); + if (*control != TLS_RECORD_TYPE_DATA) { + if (err || msg->msg_flags & MSG_CTRUNC) + return -EIO; } - kfree_skb(skb); + } else if (*control != tlm->control) { + return 0; } - /* Finished with message */ - ctx->recv_pkt = NULL; - __strp_unpause(&ctx->strp); + return 1; +} - return true; +static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) +{ + tls_strp_msg_done(&ctx->strp); } /* This function traverses the rx_list in tls receive context to copies the - * decrypted data records into the buffer provided by caller zero copy is not + * decrypted records into the buffer provided by caller zero copy is not * true. Further, the records are removed from the rx_list if it is not a peek * case and the record has been consumed completely. */ static int process_rx_list(struct tls_sw_context_rx *ctx, struct msghdr *msg, + u8 *control, size_t skip, size_t len, - bool zc, - bool is_peek) + bool is_peek, + bool *more) { struct sk_buff *skb = skb_peek(&ctx->rx_list); + struct tls_msg *tlm; ssize_t copied = 0; + int err; while (skip && skb) { struct strp_msg *rxm = strp_msg(skb); + tlm = tls_msg(skb); + + err = tls_record_content_type(msg, tlm, control); + if (err <= 0) + goto more; if (skip < rxm->full_len) break; @@ -1495,12 +1896,16 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, struct strp_msg *rxm = strp_msg(skb); int chunk = min_t(unsigned int, rxm->full_len - skip, len); - if (!zc || (rxm->full_len - skip) > len) { - int err = skb_copy_datagram_msg(skb, rxm->offset + skip, - msg, chunk); - if (err < 0) - return err; - } + tlm = tls_msg(skb); + + err = tls_record_content_type(msg, tlm, control); + if (err <= 0) + goto more; + + err = skb_copy_datagram_msg(skb, rxm->offset + skip, + msg, chunk); + if (err < 0) + goto more; len = len - chunk; copied = copied + chunk; @@ -1526,209 +1931,324 @@ static int process_rx_list(struct tls_sw_context_rx *ctx, next_skb = skb_peek_next(skb, &ctx->rx_list); if (!is_peek) { - skb_unlink(skb, &ctx->rx_list); - kfree_skb(skb); + __skb_unlink(skb, &ctx->rx_list); + consume_skb(skb); } skb = next_skb; } + err = 0; + +out: + return copied ? : err; +more: + if (more) + *more = true; + goto out; +} + +static bool +tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, + size_t len_left, size_t decrypted, ssize_t done, + size_t *flushed_at) +{ + size_t max_rec; + + if (len_left <= decrypted) + return false; + + max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; + if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) + return false; + + *flushed_at = done; + return sk_flush_backlog(sk); +} + +static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx, + bool nonblock) +{ + long timeo; + int ret; + + timeo = sock_rcvtimeo(sk, nonblock); + + while (unlikely(ctx->reader_present)) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + ctx->reader_contended = 1; + + add_wait_queue(&ctx->wq, &wait); + ret = sk_wait_event(sk, &timeo, + !READ_ONCE(ctx->reader_present), &wait); + remove_wait_queue(&ctx->wq, &wait); + + if (timeo <= 0) + return -EAGAIN; + if (signal_pending(current)) + return sock_intr_errno(timeo); + if (ret < 0) + return ret; + } + + WRITE_ONCE(ctx->reader_present, 1); + + return 0; +} + +static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, + bool nonblock) +{ + int err; + + lock_sock(sk); + err = tls_rx_reader_acquire(sk, ctx, nonblock); + if (err) + release_sock(sk); + return err; +} + +static void tls_rx_reader_release(struct sock *sk, struct tls_sw_context_rx *ctx) +{ + if (unlikely(ctx->reader_contended)) { + if (wq_has_sleeper(&ctx->wq)) + wake_up(&ctx->wq); + else + ctx->reader_contended = 0; + + WARN_ON_ONCE(!ctx->reader_present); + } + + WRITE_ONCE(ctx->reader_present, 0); +} - return copied; +static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx) +{ + tls_rx_reader_release(sk, ctx); + release_sock(sk); } int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, - int nonblock, int flags, int *addr_len) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + ssize_t decrypted = 0, async_copy_bytes = 0; struct sk_psock *psock; unsigned char control = 0; - ssize_t decrypted = 0; + size_t flushed_at = 0; struct strp_msg *rxm; - struct sk_buff *skb; + struct tls_msg *tlm; ssize_t copied = 0; - bool cmsg = false; - int target, err = 0; - long timeo; + ssize_t peeked = 0; + bool async = false; + int target, err; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_peek = flags & MSG_PEEK; - int num_async = 0; - - flags |= nonblock; + bool rx_more = false; + bool released = true; + bool bpf_strp_enabled; + bool zc_capable; if (unlikely(flags & MSG_ERRQUEUE)) return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT); + if (err < 0) + return err; psock = sk_psock_get(sk); - lock_sock(sk); + bpf_strp_enabled = sk_psock_strp_enabled(psock); + + /* If crypto failed the connection is broken */ + err = ctx->async_wait.err; + if (err) + goto end; /* Process pending decrypted records. It must be non-zero-copy */ - err = process_rx_list(ctx, msg, 0, len, false, is_peek); - if (err < 0) { - tls_err_abort(sk, err); + err = process_rx_list(ctx, msg, &control, 0, len, is_peek, &rx_more); + if (err < 0) goto end; - } else { - copied = err; - } + /* process_rx_list() will set @control if it processed any records */ + copied = err; + if (len <= copied || rx_more || + (control && control != TLS_RECORD_TYPE_DATA)) + goto end; + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); len = len - copied; - if (len) { - target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); - } else { - goto recv_end; - } - do { - bool retain_skb = false; - bool async = false; - bool zc = false; - int to_decrypt; - int chunk = 0; - - skb = tls_wait_data(sk, psock, flags, timeo, &err); - if (!skb) { - if (psock) { - int ret = __tcp_bpf_recvmsg(sk, psock, - msg, len, flags); + zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && + ctx->zc_capable; + decrypted = 0; + while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) { + struct tls_decrypt_arg darg; + int to_decrypt, chunk; - if (ret > 0) { - decrypted += ret; - len -= ret; + err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, + released); + if (err <= 0) { + if (psock) { + chunk = sk_msg_recvmsg(sk, psock, msg, len, + flags); + if (chunk > 0) { + decrypted += chunk; + len -= chunk; continue; } } goto recv_end; } - rxm = strp_msg(skb); + memset(&darg.inargs, 0, sizeof(darg.inargs)); - if (!cmsg) { - int cerr; + rxm = strp_msg(tls_strp_msg(ctx)); + tlm = tls_msg(tls_strp_msg(ctx)); - cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, - sizeof(ctx->control), &ctx->control); - cmsg = true; - control = ctx->control; - if (ctx->control != TLS_RECORD_TYPE_DATA) { - if (cerr || msg->msg_flags & MSG_CTRUNC) { - err = -EIO; - goto recv_end; - } - } - } else if (control != ctx->control) { - goto recv_end; - } + to_decrypt = rxm->full_len - prot->overhead_size; - to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size; + if (zc_capable && to_decrypt <= len && + tlm->control == TLS_RECORD_TYPE_DATA) + darg.zc = true; - if (to_decrypt <= len && !is_kvec && !is_peek) - zc = true; + /* Do not use async mode if record is non-data */ + if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled) + darg.async = ctx->async_capable; + else + darg.async = false; - err = decrypt_skb_update(sk, skb, &msg->msg_iter, - &chunk, &zc, ctx->async_capable); - if (err < 0 && err != -EINPROGRESS) { - tls_err_abort(sk, EBADMSG); + err = tls_rx_one_record(sk, msg, &darg); + if (err < 0) { + tls_err_abort(sk, -EBADMSG); goto recv_end; } - if (err == -EINPROGRESS) { - async = true; - num_async++; - goto pick_next_record; - } else { - if (!zc) { - if (rxm->full_len > len) { - retain_skb = true; - chunk = len; - } else { - chunk = rxm->full_len; - } + async |= darg.async; + + /* If the type of records being processed is not known yet, + * set it to record type just dequeued. If it is already known, + * but does not match the record type just dequeued, go to end. + * We always get record type here since for tls1.2, record type + * is known just after record is dequeued from stream parser. + * For tls1.3, we disable async. + */ + err = tls_record_content_type(msg, tls_msg(darg.skb), &control); + if (err <= 0) { + DEBUG_NET_WARN_ON_ONCE(darg.zc); + tls_rx_rec_done(ctx); +put_on_rx_list_err: + __skb_queue_tail(&ctx->rx_list, darg.skb); + goto recv_end; + } - err = skb_copy_datagram_msg(skb, rxm->offset, - msg, chunk); - if (err < 0) - goto recv_end; + /* periodically flush backlog, and feed strparser */ + released = tls_read_flush_backlog(sk, prot, len, to_decrypt, + decrypted + copied, + &flushed_at); + + /* TLS 1.3 may have updated the length by more than overhead */ + rxm = strp_msg(darg.skb); + chunk = rxm->full_len; + tls_rx_rec_done(ctx); + + if (!darg.zc) { + bool partially_consumed = chunk > len; + struct sk_buff *skb = darg.skb; + + DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor); + + if (async) { + /* TLS 1.2-only, to_decrypt must be text len */ + chunk = min_t(int, to_decrypt, len); + async_copy_bytes += chunk; +put_on_rx_list: + decrypted += chunk; + len -= chunk; + __skb_queue_tail(&ctx->rx_list, skb); + if (unlikely(control != TLS_RECORD_TYPE_DATA)) + break; + continue; + } - if (!is_peek) { - rxm->offset = rxm->offset + chunk; - rxm->full_len = rxm->full_len - chunk; + if (bpf_strp_enabled) { + released = true; + err = sk_psock_tls_strp_read(psock, skb); + if (err != __SK_PASS) { + rxm->offset = rxm->offset + rxm->full_len; + rxm->full_len = 0; + if (err == __SK_DROP) + consume_skb(skb); + continue; } } - } -pick_next_record: - if (chunk > len) - chunk = len; + if (partially_consumed) + chunk = len; - decrypted += chunk; - len -= chunk; + err = skb_copy_datagram_msg(skb, rxm->offset, + msg, chunk); + if (err < 0) + goto put_on_rx_list_err; - /* For async or peek case, queue the current skb */ - if (async || is_peek || retain_skb) { - skb_queue_tail(&ctx->rx_list, skb); - skb = NULL; - } + if (is_peek) { + peeked += chunk; + goto put_on_rx_list; + } - if (tls_sw_advance_skb(sk, skb, chunk)) { - /* Return full control message to - * userspace before trying to parse - * another message type - */ - msg->msg_flags |= MSG_EOR; - if (ctx->control != TLS_RECORD_TYPE_DATA) - goto recv_end; - } else { - break; + if (partially_consumed) { + rxm->offset += chunk; + rxm->full_len -= chunk; + goto put_on_rx_list; + } + + consume_skb(skb); } - /* If we have a new message from strparser, continue now. */ - if (decrypted >= target && !ctx->recv_pkt) + decrypted += chunk; + len -= chunk; + + /* Return full control message to userspace before trying + * to parse another message type + */ + msg->msg_flags |= MSG_EOR; + if (control != TLS_RECORD_TYPE_DATA) break; - } while (len); + } recv_end: - if (num_async) { + if (async) { + int ret; + /* Wait for all previously submitted records to be decrypted */ - smp_store_mb(ctx->async_notify, true); - if (atomic_read(&ctx->decrypt_pending)) { - err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); - if (err) { - /* one of async decrypt failed */ - tls_err_abort(sk, err); - copied = 0; - decrypted = 0; - goto end; - } - } else { - reinit_completion(&ctx->async_wait.completion); + ret = tls_decrypt_async_wait(ctx); + __skb_queue_purge(&ctx->async_hold); + + if (ret) { + if (err >= 0 || err == -EINPROGRESS) + err = ret; + goto end; } - WRITE_ONCE(ctx->async_notify, false); /* Drain records from the rx_list & copy if required */ - if (is_peek || is_kvec) - err = process_rx_list(ctx, msg, copied, - decrypted, false, is_peek); + if (is_peek) + err = process_rx_list(ctx, msg, &control, copied + peeked, + decrypted - peeked, is_peek, NULL); else - err = process_rx_list(ctx, msg, 0, - decrypted, true, is_peek); - if (err < 0) { - tls_err_abort(sk, err); - copied = 0; - goto end; - } + err = process_rx_list(ctx, msg, &control, 0, + async_copy_bytes, is_peek, NULL); - WARN_ON(decrypted != err); + /* we could have copied less than we wanted, and possibly nothing */ + decrypted += max(err, 0) - async_copy_bytes; } copied += decrypted; end: - release_sock(sk); + tls_rx_reader_unlock(sk, ctx); if (psock) sk_psock_put(sk, psock); return copied ? : err; @@ -1742,52 +2262,166 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct strp_msg *rxm = NULL; struct sock *sk = sock->sk; + struct tls_msg *tlm; struct sk_buff *skb; ssize_t copied = 0; - int err = 0; - long timeo; int chunk; - bool zc = false; + int err; - lock_sock(sk); - - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK); + if (err < 0) + return err; - skb = tls_wait_data(sk, NULL, flags, timeo, &err); - if (!skb) - goto splice_read_end; + if (!skb_queue_empty(&ctx->rx_list)) { + skb = __skb_dequeue(&ctx->rx_list); + } else { + struct tls_decrypt_arg darg; - /* splice does not support reading control messages */ - if (ctx->control != TLS_RECORD_TYPE_DATA) { - err = -ENOTSUPP; - goto splice_read_end; - } + err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, + true); + if (err <= 0) + goto splice_read_end; - if (!ctx->decrypted) { - err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); + memset(&darg.inargs, 0, sizeof(darg.inargs)); + err = tls_rx_one_record(sk, NULL, &darg); if (err < 0) { - tls_err_abort(sk, EBADMSG); + tls_err_abort(sk, -EBADMSG); goto splice_read_end; } - ctx->decrypted = true; + + tls_rx_rec_done(ctx); + skb = darg.skb; } + rxm = strp_msg(skb); + tlm = tls_msg(skb); + + /* splice does not support reading control messages */ + if (tlm->control != TLS_RECORD_TYPE_DATA) { + err = -EINVAL; + goto splice_requeue; + } chunk = min_t(unsigned int, rxm->full_len, len); copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); if (copied < 0) - goto splice_read_end; + goto splice_requeue; + + if (chunk < rxm->full_len) { + rxm->offset += len; + rxm->full_len -= len; + goto splice_requeue; + } - if (likely(!(flags & MSG_PEEK))) - tls_sw_advance_skb(sk, skb, copied); + consume_skb(skb); splice_read_end: - release_sock(sk); + tls_rx_reader_unlock(sk, ctx); + return copied ? : err; + +splice_requeue: + __skb_queue_head(&ctx->rx_list, skb); + goto splice_read_end; +} + +int tls_sw_read_sock(struct sock *sk, read_descriptor_t *desc, + sk_read_actor_t read_actor) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm = NULL; + struct sk_buff *skb = NULL; + struct sk_psock *psock; + size_t flushed_at = 0; + bool released = true; + struct tls_msg *tlm; + ssize_t copied = 0; + ssize_t decrypted; + int err, used; + + psock = sk_psock_get(sk); + if (psock) { + sk_psock_put(sk, psock); + return -EINVAL; + } + err = tls_rx_reader_acquire(sk, ctx, true); + if (err < 0) + return err; + + /* If crypto failed the connection is broken */ + err = ctx->async_wait.err; + if (err) + goto read_sock_end; + + decrypted = 0; + do { + if (!skb_queue_empty(&ctx->rx_list)) { + skb = __skb_dequeue(&ctx->rx_list); + rxm = strp_msg(skb); + tlm = tls_msg(skb); + } else { + struct tls_decrypt_arg darg; + + err = tls_rx_rec_wait(sk, NULL, true, released); + if (err <= 0) + goto read_sock_end; + + memset(&darg.inargs, 0, sizeof(darg.inargs)); + + err = tls_rx_one_record(sk, NULL, &darg); + if (err < 0) { + tls_err_abort(sk, -EBADMSG); + goto read_sock_end; + } + + released = tls_read_flush_backlog(sk, prot, INT_MAX, + 0, decrypted, + &flushed_at); + skb = darg.skb; + rxm = strp_msg(skb); + tlm = tls_msg(skb); + decrypted += rxm->full_len; + + tls_rx_rec_done(ctx); + } + + /* read_sock does not support reading control messages */ + if (tlm->control != TLS_RECORD_TYPE_DATA) { + err = -EINVAL; + goto read_sock_requeue; + } + + used = read_actor(desc, skb, rxm->offset, rxm->full_len); + if (used <= 0) { + if (!copied) + err = used; + goto read_sock_requeue; + } + copied += used; + if (used < rxm->full_len) { + rxm->offset += used; + rxm->full_len -= used; + if (!desc->count) + goto read_sock_requeue; + } else { + consume_skb(skb); + if (!desc->count) + skb = NULL; + } + } while (skb); + +read_sock_end: + tls_rx_reader_release(sk, ctx); return copied ? : err; + +read_sock_requeue: + __skb_queue_head(&ctx->rx_list, skb); + goto read_sock_end; } -bool tls_sw_stream_read(const struct sock *sk) +bool tls_sw_sock_is_readable(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); @@ -1800,42 +2434,45 @@ bool tls_sw_stream_read(const struct sock *sk) ingress_empty = list_empty(&psock->ingress_msg); rcu_read_unlock(); - return !ingress_empty || ctx->recv_pkt; + return !ingress_empty || tls_strp_msg_ready(ctx) || + !skb_queue_empty(&ctx->rx_list); } -static int tls_read_size(struct strparser *strp, struct sk_buff *skb) +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; - struct strp_msg *rxm = strp_msg(skb); + struct tls_prot_info *prot = &tls_ctx->prot_info; + char header[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE]; size_t cipher_overhead; size_t data_len = 0; int ret; /* Verify that we have a full TLS header, or wait for more data */ - if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) + if (strp->stm.offset + prot->prepend_size > skb->len) return 0; /* Sanity-check size of on-stack buffer. */ - if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { + if (WARN_ON(prot->prepend_size > sizeof(header))) { ret = -EINVAL; goto read_failure; } /* Linearize header to local buffer */ - ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); - + ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size); if (ret < 0) goto read_failure; - ctx->control = header[0]; + strp->mark = header[0]; data_len = ((header[4] & 0xFF) | (header[3] << 8)); - cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; + cipher_overhead = prot->tag_size; + if (prot->version != TLS_1_3_VERSION && + prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) + cipher_overhead += prot->iv_size; - if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { + if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + + prot->tail_size) { ret = -EMSGSIZE; goto read_failure; } @@ -1844,34 +2481,27 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) goto read_failure; } - if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) || - header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) { + /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */ + if (header[1] != TLS_1_2_VERSION_MINOR || + header[2] != TLS_1_2_VERSION_MAJOR) { ret = -EINVAL; goto read_failure; } -#ifdef CONFIG_TLS_DEVICE - handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, - *(u64*)tls_ctx->rx.rec_seq); -#endif + tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, + TCP_SKB_CB(skb)->seq + strp->stm.offset); return data_len + TLS_HEADER_SIZE; read_failure: - tls_err_abort(strp->sk, ret); - + tls_strp_abort_strp(strp, ret); return ret; } -static void tls_queue(struct strparser *strp, struct sk_buff *skb) +void tls_rx_msg_ready(struct tls_strparser *strp) { - struct tls_context *tls_ctx = tls_get_ctx(strp->sk); - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - - ctx->decrypted = false; - - ctx->recv_pkt = skb; - strp_pause(strp); + struct tls_sw_context_rx *ctx; + ctx = container_of(strp, struct tls_sw_context_rx, strp); ctx->saved_data_ready(strp->sk); } @@ -1880,49 +2510,48 @@ static void tls_data_ready(struct sock *sk) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct sk_psock *psock; + gfp_t alloc_save; - strp_data_ready(&ctx->strp); + trace_sk_data_ready(sk); + + alloc_save = sk->sk_allocation; + sk->sk_allocation = GFP_ATOMIC; + tls_strp_data_ready(&ctx->strp); + sk->sk_allocation = alloc_save; psock = sk_psock_get(sk); - if (psock && !list_empty(&psock->ingress_msg)) { - ctx->saved_data_ready(sk); + if (psock) { + if (!list_empty(&psock->ingress_msg)) + ctx->saved_data_ready(sk); sk_psock_put(sk, psock); } } -void tls_sw_free_resources_tx(struct sock *sk) +void tls_sw_cancel_work_tx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask); + set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask); + cancel_delayed_work_sync(&ctx->tx_work.work); +} + +void tls_sw_release_resources_tx(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec, *tmp; /* Wait for any pending async encryptions to complete */ - smp_store_mb(ctx->async_notify, true); - if (atomic_read(&ctx->encrypt_pending)) - crypto_wait_req(-EINPROGRESS, &ctx->async_wait); - - cancel_delayed_work_sync(&ctx->tx_work.work); + tls_encrypt_async_wait(ctx); - /* Tx whatever records we can transmit and abandon the rest */ tls_tx_records(sk, -1); /* Free up un-sent records in tx_list. First, free * the partially sent record if any at head of tx_list. */ if (tls_ctx->partially_sent_record) { - struct scatterlist *sg = tls_ctx->partially_sent_record; - - while (1) { - put_page(sg_page(sg)); - sk_mem_uncharge(sk, sg->length); - - if (sg_is_last(sg)) - break; - sg++; - } - - tls_ctx->partially_sent_record = NULL; - + tls_free_partial_record(sk, tls_ctx); rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); list_del(&rec->list); @@ -1939,6 +2568,11 @@ void tls_sw_free_resources_tx(struct sock *sk) crypto_free_aead(ctx->aead_send); tls_free_open_rec(sk); +} + +void tls_sw_free_ctx_tx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); kfree(ctx); } @@ -1949,30 +2583,43 @@ void tls_sw_release_resources_rx(struct sock *sk) struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); if (ctx->aead_recv) { - kfree_skb(ctx->recv_pkt); - ctx->recv_pkt = NULL; - skb_queue_purge(&ctx->rx_list); + __skb_queue_purge(&ctx->rx_list); crypto_free_aead(ctx->aead_recv); - strp_stop(&ctx->strp); - write_lock_bh(&sk->sk_callback_lock); - sk->sk_data_ready = ctx->saved_data_ready; - write_unlock_bh(&sk->sk_callback_lock); - release_sock(sk); - strp_done(&ctx->strp); - lock_sock(sk); + tls_strp_stop(&ctx->strp); + /* If tls_sw_strparser_arm() was not called (cleanup paths) + * we still want to tls_strp_stop(), but sk->sk_data_ready was + * never swapped. + */ + if (ctx->saved_data_ready) { + write_lock_bh(&sk->sk_callback_lock); + sk->sk_data_ready = ctx->saved_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + } } } -void tls_sw_free_resources_rx(struct sock *sk) +void tls_sw_strparser_done(struct tls_context *tls_ctx) { - struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - tls_sw_release_resources_rx(sk); + tls_strp_done(&ctx->strp); +} + +void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); kfree(ctx); } +void tls_sw_free_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + + tls_sw_release_resources_rx(sk); + tls_sw_free_ctx_rx(tls_ctx); +} + /* The work handler to transmitt the encrypted records in tx_list */ static void tx_work_handler(struct work_struct *work) { @@ -1981,158 +2628,266 @@ static void tx_work_handler(struct work_struct *work) struct tx_work, work); struct sock *sk = tx_work->sk; struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_sw_context_tx *ctx; + + if (unlikely(!tls_ctx)) + return; + + ctx = tls_sw_ctx_tx(tls_ctx); + if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask)) + return; if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) return; - lock_sock(sk); - tls_tx_records(sk, -1); - release_sock(sk); + if (mutex_trylock(&tls_ctx->tx_lock)) { + lock_sock(sk); + tls_tx_records(sk, -1); + release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); + } else if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + /* Someone is holding the tx_lock, they will likely run Tx + * and cancel the work on their way out of the lock section. + * Schedule a long delay just in case. + */ + schedule_delayed_work(&ctx->tx_work.work, msecs_to_jiffies(10)); + } } -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) +static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx) { - struct tls_crypto_info *crypto_info; - struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; + struct tls_rec *rec; + + rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list); + if (!rec) + return false; + + return READ_ONCE(rec->tx_ready); +} + +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) +{ + struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); + + /* Schedule the transmission if tx list is ready */ + if (tls_is_tx_ready(tx_ctx) && + !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); +} + +void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); + + write_lock_bh(&sk->sk_callback_lock); + rx_ctx->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = tls_data_ready; + write_unlock_bh(&sk->sk_callback_lock); +} + +void tls_update_rx_zc_capable(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); + + rx_ctx->zc_capable = tls_ctx->rx_no_pad || + tls_ctx->prot_info.version != TLS_1_3_VERSION; +} + +static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct sock *sk) +{ + struct tls_sw_context_tx *sw_ctx_tx; + + if (!ctx->priv_ctx_tx) { + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); + if (!sw_ctx_tx) + return NULL; + } else { + sw_ctx_tx = ctx->priv_ctx_tx; + } + + crypto_init_wait(&sw_ctx_tx->async_wait); + atomic_set(&sw_ctx_tx->encrypt_pending, 1); + INIT_LIST_HEAD(&sw_ctx_tx->tx_list); + INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); + sw_ctx_tx->tx_work.sk = sk; + + return sw_ctx_tx; +} + +static struct tls_sw_context_rx *init_ctx_rx(struct tls_context *ctx) +{ + struct tls_sw_context_rx *sw_ctx_rx; + + if (!ctx->priv_ctx_rx) { + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); + if (!sw_ctx_rx) + return NULL; + } else { + sw_ctx_rx = ctx->priv_ctx_rx; + } + + crypto_init_wait(&sw_ctx_rx->async_wait); + atomic_set(&sw_ctx_rx->decrypt_pending, 1); + init_waitqueue_head(&sw_ctx_rx->wq); + skb_queue_head_init(&sw_ctx_rx->rx_list); + skb_queue_head_init(&sw_ctx_rx->async_hold); + + return sw_ctx_rx; +} + +int init_prot_info(struct tls_prot_info *prot, + const struct tls_crypto_info *crypto_info, + const struct tls_cipher_desc *cipher_desc) +{ + u16 nonce_size = cipher_desc->nonce; + + if (crypto_info->version == TLS_1_3_VERSION) { + nonce_size = 0; + prot->aad_size = TLS_HEADER_SIZE; + prot->tail_size = 1; + } else { + prot->aad_size = TLS_AAD_SPACE_SIZE; + prot->tail_size = 0; + } + + /* Sanity-check the sizes for stack allocations. */ + if (nonce_size > TLS_MAX_IV_SIZE || prot->aad_size > TLS_MAX_AAD_SIZE) + return -EINVAL; + + prot->version = crypto_info->version; + prot->cipher_type = crypto_info->cipher_type; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = cipher_desc->tag; + prot->overhead_size = prot->prepend_size + prot->tag_size + prot->tail_size; + prot->iv_size = cipher_desc->iv; + prot->salt_size = cipher_desc->salt; + prot->rec_seq_size = cipher_desc->rec_seq; + + return 0; +} + +static void tls_finish_key_update(struct sock *sk, struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *ctx = tls_ctx->priv_ctx_rx; + + WRITE_ONCE(ctx->key_update_pending, false); + /* wake-up pre-existing poll() */ + ctx->saved_data_ready(sk); +} + +int tls_set_sw_offload(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info) +{ + struct tls_crypto_info *crypto_info, *src_crypto_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; + const struct tls_cipher_desc *cipher_desc; + char *iv, *rec_seq, *key, *salt; struct cipher_context *cctx; + struct tls_prot_info *prot; struct crypto_aead **aead; - struct strp_callbacks cb; - u16 nonce_size, tag_size, iv_size, rec_seq_size; + struct tls_context *ctx; struct crypto_tfm *tfm; - char *iv, *rec_seq; int rc = 0; - if (!ctx) { - rc = -EINVAL; - goto out; - } + ctx = tls_get_ctx(sk); + prot = &ctx->prot_info; - if (tx) { - if (!ctx->priv_ctx_tx) { - sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); - if (!sw_ctx_tx) { - rc = -ENOMEM; - goto out; - } - ctx->priv_ctx_tx = sw_ctx_tx; + /* new_crypto_info != NULL means rekey */ + if (!new_crypto_info) { + if (tx) { + ctx->priv_ctx_tx = init_ctx_tx(ctx, sk); + if (!ctx->priv_ctx_tx) + return -ENOMEM; } else { - sw_ctx_tx = - (struct tls_sw_context_tx *)ctx->priv_ctx_tx; - } - } else { - if (!ctx->priv_ctx_rx) { - sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); - if (!sw_ctx_rx) { - rc = -ENOMEM; - goto out; - } - ctx->priv_ctx_rx = sw_ctx_rx; - } else { - sw_ctx_rx = - (struct tls_sw_context_rx *)ctx->priv_ctx_rx; + ctx->priv_ctx_rx = init_ctx_rx(ctx); + if (!ctx->priv_ctx_rx) + return -ENOMEM; } } if (tx) { - crypto_init_wait(&sw_ctx_tx->async_wait); + sw_ctx_tx = ctx->priv_ctx_tx; crypto_info = &ctx->crypto_send.info; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; - INIT_LIST_HEAD(&sw_ctx_tx->tx_list); - INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); - sw_ctx_tx->tx_work.sk = sk; } else { - crypto_init_wait(&sw_ctx_rx->async_wait); + sw_ctx_rx = ctx->priv_ctx_rx; crypto_info = &ctx->crypto_recv.info; cctx = &ctx->rx; - skb_queue_head_init(&sw_ctx_rx->rx_list); aead = &sw_ctx_rx->aead_recv; } - switch (crypto_info->cipher_type) { - case TLS_CIPHER_AES_GCM_128: { - nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; - tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; - iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; - iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; - rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; - rec_seq = - ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; - gcm_128_info = - (struct tls12_crypto_info_aes_gcm_128 *)crypto_info; - break; - } - default: - rc = -EINVAL; - goto free_priv; - } + src_crypto_info = new_crypto_info ?: crypto_info; - /* Sanity-check the IV size for stack allocations. */ - if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) { + cipher_desc = get_cipher_desc(src_crypto_info->cipher_type); + if (!cipher_desc) { rc = -EINVAL; goto free_priv; } - cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; - cctx->tag_size = tag_size; - cctx->overhead_size = cctx->prepend_size + cctx->tag_size; - cctx->iv_size = iv_size; - cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - GFP_KERNEL); - if (!cctx->iv) { - rc = -ENOMEM; + rc = init_prot_info(prot, src_crypto_info, cipher_desc); + if (rc) goto free_priv; - } - memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - cctx->rec_seq_size = rec_seq_size; - cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); - if (!cctx->rec_seq) { - rc = -ENOMEM; - goto free_iv; - } + + iv = crypto_info_iv(src_crypto_info, cipher_desc); + key = crypto_info_key(src_crypto_info, cipher_desc); + salt = crypto_info_salt(src_crypto_info, cipher_desc); + rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc); if (!*aead) { - *aead = crypto_alloc_aead("gcm(aes)", 0, 0); + *aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0); if (IS_ERR(*aead)) { rc = PTR_ERR(*aead); *aead = NULL; - goto free_rec_seq; + goto free_priv; } } ctx->push_pending_record = tls_sw_push_pending_record; - rc = crypto_aead_setkey(*aead, gcm_128_info->key, - TLS_CIPHER_AES_GCM_128_KEY_SIZE); - if (rc) - goto free_aead; + /* setkey is the last operation that could fail during a + * rekey. if it succeeds, we can start modifying the + * context. + */ + rc = crypto_aead_setkey(*aead, key, cipher_desc->key); + if (rc) { + if (new_crypto_info) + goto out; + else + goto free_aead; + } - rc = crypto_aead_setauthsize(*aead, cctx->tag_size); - if (rc) - goto free_aead; + if (!new_crypto_info) { + rc = crypto_aead_setauthsize(*aead, prot->tag_size); + if (rc) + goto free_aead; + } - if (sw_ctx_rx) { + if (!tx && !new_crypto_info) { tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); - sw_ctx_rx->async_capable = - tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; - /* Set up strparser */ - memset(&cb, 0, sizeof(cb)); - cb.rcv_msg = tls_queue; - cb.parse_msg = tls_read_size; + tls_update_rx_zc_capable(ctx); + sw_ctx_rx->async_capable = + src_crypto_info->version != TLS_1_3_VERSION && + !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); - strp_init(&sw_ctx_rx->strp, sk, &cb); + rc = tls_strp_init(&sw_ctx_rx->strp, sk); + if (rc) + goto free_aead; + } - write_lock_bh(&sk->sk_callback_lock); - sw_ctx_rx->saved_data_ready = sk->sk_data_ready; - sk->sk_data_ready = tls_data_ready; - write_unlock_bh(&sk->sk_callback_lock); + memcpy(cctx->iv, salt, cipher_desc->salt); + memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); + memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq); - strp_check_rcv(&sw_ctx_rx->strp); + if (new_crypto_info) { + unsafe_memcpy(crypto_info, new_crypto_info, + cipher_desc->crypto_info, + /* size was checked in do_tls_setsockopt_conf */); + memzero_explicit(new_crypto_info, cipher_desc->crypto_info); + if (!tx) + tls_finish_key_update(sk, ctx); } goto out; @@ -2140,19 +2895,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) free_aead: crypto_free_aead(*aead); *aead = NULL; -free_rec_seq: - kfree(cctx->rec_seq); - cctx->rec_seq = NULL; -free_iv: - kfree(cctx->iv); - cctx->iv = NULL; free_priv: - if (tx) { - kfree(ctx->priv_ctx_tx); - ctx->priv_ctx_tx = NULL; - } else { - kfree(ctx->priv_ctx_rx); - ctx->priv_ctx_rx = NULL; + if (!new_crypto_info) { + if (tx) { + kfree(ctx->priv_ctx_tx); + ctx->priv_ctx_tx = NULL; + } else { + kfree(ctx->priv_ctx_rx); + ctx->priv_ctx_rx = NULL; + } } out: return rc; |
