diff options
Diffstat (limited to 'net/tls/tls_main.c')
| -rw-r--r-- | net/tls/tls_main.c | 590 |
1 files changed, 305 insertions, 285 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 3735cb00905d..56ce0bc8317b 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -58,23 +58,67 @@ enum { TLS_NUM_PROTS, }; -#define CIPHER_SIZE_DESC(cipher) [cipher] = { \ +#define CHECK_CIPHER_DESC(cipher,ci) \ + static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE); \ + static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE); \ + static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ + static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ + static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ + static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \ + static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \ + static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE); + +#define __CIPHER_DESC(ci) \ + .iv_offset = offsetof(struct ci, iv), \ + .key_offset = offsetof(struct ci, key), \ + .salt_offset = offsetof(struct ci, salt), \ + .rec_seq_offset = offsetof(struct ci, rec_seq), \ + .crypto_info = sizeof(struct ci) + +#define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ + .nonce = cipher ## _IV_SIZE, \ .iv = cipher ## _IV_SIZE, \ .key = cipher ## _KEY_SIZE, \ .salt = cipher ## _SALT_SIZE, \ .tag = cipher ## _TAG_SIZE, \ .rec_seq = cipher ## _REC_SEQ_SIZE, \ + .cipher_name = algname, \ + .offloadable = _offloadable, \ + __CIPHER_DESC(ci), \ } -const struct tls_cipher_size_desc tls_cipher_size_desc[] = { - CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128), - CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256), - CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128), - CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305), - CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM), - CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM), +#define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ + .nonce = 0, \ + .iv = cipher ## _IV_SIZE, \ + .key = cipher ## _KEY_SIZE, \ + .salt = cipher ## _SALT_SIZE, \ + .tag = cipher ## _TAG_SIZE, \ + .rec_seq = cipher ## _REC_SEQ_SIZE, \ + .cipher_name = algname, \ + .offloadable = _offloadable, \ + __CIPHER_DESC(ci), \ +} + +const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = { + CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true), + CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true), + CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false), + CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false), + CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false), + CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false), + CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false), + CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false), }; +CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128); +CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256); +CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128); +CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305); +CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm); +CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm); +CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128); +CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256); + static const struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); static const struct proto *saved_tcpv4_prot; @@ -96,8 +140,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx) int wait_on_pending_writer(struct sock *sk, long *timeo) { - int rc = 0; DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret, rc = 0; add_wait_queue(sk_sleep(sk), &wait); while (1) { @@ -111,8 +155,13 @@ int wait_on_pending_writer(struct sock *sk, long *timeo) break; } - if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait)) + ret = sk_wait_event(sk, timeo, + !READ_ONCE(sk->sk_write_pending), &wait); + if (ret) { + if (ret < 0) + rc = ret; break; + } } remove_wait_queue(sk_sleep(sk), &wait); return rc; @@ -124,7 +173,10 @@ int tls_push_sg(struct sock *sk, u16 first_offset, int flags) { - int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST; + struct bio_vec bvec; + struct msghdr msg = { + .msg_flags = MSG_SPLICE_PAGES | flags, + }; int ret = 0; struct page *p; size_t size; @@ -133,16 +185,16 @@ int tls_push_sg(struct sock *sk, size = sg->length - offset; offset += sg->offset; - ctx->in_tcp_sendpages = true; + ctx->splicing_pages = true; while (1) { - if (sg_is_last(sg)) - sendpage_flags = flags; - /* is sending application-limited? */ tcp_rate_check_app_limited(sk); p = sg_page(sg); retry: - ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags); + bvec_set_page(&bvec, p, size, offset); + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); + + ret = tcp_sendmsg_locked(sk, &msg, size); if (ret != size) { if (ret > 0) { @@ -154,7 +206,7 @@ retry: offset -= sg->offset; ctx->partially_sent_offset = offset; ctx->partially_sent_record = (void *)sg; - ctx->in_tcp_sendpages = false; + ctx->splicing_pages = false; return ret; } @@ -168,7 +220,7 @@ retry: size = sg->length; } - ctx->in_tcp_sendpages = false; + ctx->splicing_pages = false; return 0; } @@ -203,12 +255,9 @@ int tls_process_cmsg(struct sock *sk, struct msghdr *msg, if (msg->msg_flags & MSG_MORE) return -EINVAL; - rc = tls_handle_open_record(sk, msg->msg_flags); - if (rc) - return rc; - *record_type = *(unsigned char *)CMSG_DATA(cmsg); - rc = 0; + + rc = tls_handle_open_record(sk, msg->msg_flags); break; default: return -EINVAL; @@ -246,11 +295,11 @@ static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); - /* If in_tcp_sendpages call lower protocol write space handler + /* If splicing_pages call lower protocol write space handler * to ensure we wake up any waiting operations there. For example - * if do_tcp_sendpages where to call sk_wait_event. + * if splicing pages where to call sk_wait_event. */ - if (ctx->in_tcp_sendpages) { + if (ctx->splicing_pages) { ctx->sk_write_space(sk); return; } @@ -297,8 +346,6 @@ static void tls_sk_proto_cleanup(struct sock *sk, /* We need these for tls_sw_fallback handling of other packets */ if (ctx->tx_conf == TLS_SW) { - kfree(ctx->tx.rec_seq); - kfree(ctx->tx.iv); tls_sw_release_resources_tx(sk); TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); } else if (ctx->tx_conf == TLS_HW) { @@ -351,10 +398,45 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) tls_ctx_free(sk, ctx); } +static __poll_t tls_sk_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) +{ + struct tls_sw_context_rx *ctx; + struct tls_context *tls_ctx; + struct sock *sk = sock->sk; + struct sk_psock *psock; + __poll_t mask = 0; + u8 shutdown; + int state; + + mask = tcp_poll(file, sock, wait); + + state = inet_sk_state_load(sk); + shutdown = READ_ONCE(sk->sk_shutdown); + if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN)) + return mask; + + tls_ctx = tls_get_ctx(sk); + ctx = tls_sw_ctx_rx(tls_ctx); + psock = sk_psock_get(sk); + + if ((skb_queue_empty_lockless(&ctx->rx_list) && + !tls_strp_msg_ready(ctx) && + sk_psock_queue_empty(psock)) || + READ_ONCE(ctx->key_update_pending)) + mask &= ~(EPOLLIN | EPOLLRDNORM); + + if (psock) + sk_psock_put(sk, psock); + + return mask; +} + static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, int __user *optlen, int tx) { int rc = 0; + const struct tls_cipher_desc *cipher_desc; struct tls_context *ctx = tls_get_ctx(sk); struct tls_crypto_info *crypto_info; struct cipher_context *cctx; @@ -393,188 +475,19 @@ static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, goto out; } - switch (crypto_info->cipher_type) { - case TLS_CIPHER_AES_GCM_128: { - struct tls12_crypto_info_aes_gcm_128 * - crypto_info_aes_gcm_128 = - container_of(crypto_info, - struct tls12_crypto_info_aes_gcm_128, - info); - - if (len != sizeof(*crypto_info_aes_gcm_128)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(crypto_info_aes_gcm_128->iv, - cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - TLS_CIPHER_AES_GCM_128_IV_SIZE); - memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq, - TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, - crypto_info_aes_gcm_128, - sizeof(*crypto_info_aes_gcm_128))) - rc = -EFAULT; - break; - } - case TLS_CIPHER_AES_GCM_256: { - struct tls12_crypto_info_aes_gcm_256 * - crypto_info_aes_gcm_256 = - container_of(crypto_info, - struct tls12_crypto_info_aes_gcm_256, - info); - - if (len != sizeof(*crypto_info_aes_gcm_256)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(crypto_info_aes_gcm_256->iv, - cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE, - TLS_CIPHER_AES_GCM_256_IV_SIZE); - memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq, - TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, - crypto_info_aes_gcm_256, - sizeof(*crypto_info_aes_gcm_256))) - rc = -EFAULT; - break; - } - case TLS_CIPHER_AES_CCM_128: { - struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 = - container_of(crypto_info, - struct tls12_crypto_info_aes_ccm_128, info); - - if (len != sizeof(*aes_ccm_128)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(aes_ccm_128->iv, - cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE, - TLS_CIPHER_AES_CCM_128_IV_SIZE); - memcpy(aes_ccm_128->rec_seq, cctx->rec_seq, - TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128))) - rc = -EFAULT; - break; - } - case TLS_CIPHER_CHACHA20_POLY1305: { - struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 = - container_of(crypto_info, - struct tls12_crypto_info_chacha20_poly1305, - info); - - if (len != sizeof(*chacha20_poly1305)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(chacha20_poly1305->iv, - cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE, - TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE); - memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq, - TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, chacha20_poly1305, - sizeof(*chacha20_poly1305))) - rc = -EFAULT; - break; + cipher_desc = get_cipher_desc(crypto_info->cipher_type); + if (!cipher_desc || len != cipher_desc->crypto_info) { + rc = -EINVAL; + goto out; } - case TLS_CIPHER_SM4_GCM: { - struct tls12_crypto_info_sm4_gcm *sm4_gcm_info = - container_of(crypto_info, - struct tls12_crypto_info_sm4_gcm, info); - if (len != sizeof(*sm4_gcm_info)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(sm4_gcm_info->iv, - cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE, - TLS_CIPHER_SM4_GCM_IV_SIZE); - memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq, - TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info))) - rc = -EFAULT; - break; - } - case TLS_CIPHER_SM4_CCM: { - struct tls12_crypto_info_sm4_ccm *sm4_ccm_info = - container_of(crypto_info, - struct tls12_crypto_info_sm4_ccm, info); + memcpy(crypto_info_iv(crypto_info, cipher_desc), + cctx->iv + cipher_desc->salt, cipher_desc->iv); + memcpy(crypto_info_rec_seq(crypto_info, cipher_desc), + cctx->rec_seq, cipher_desc->rec_seq); - if (len != sizeof(*sm4_ccm_info)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(sm4_ccm_info->iv, - cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE, - TLS_CIPHER_SM4_CCM_IV_SIZE); - memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq, - TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info))) - rc = -EFAULT; - break; - } - case TLS_CIPHER_ARIA_GCM_128: { - struct tls12_crypto_info_aria_gcm_128 * - crypto_info_aria_gcm_128 = - container_of(crypto_info, - struct tls12_crypto_info_aria_gcm_128, - info); - - if (len != sizeof(*crypto_info_aria_gcm_128)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(crypto_info_aria_gcm_128->iv, - cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE, - TLS_CIPHER_ARIA_GCM_128_IV_SIZE); - memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq, - TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, - crypto_info_aria_gcm_128, - sizeof(*crypto_info_aria_gcm_128))) - rc = -EFAULT; - break; - } - case TLS_CIPHER_ARIA_GCM_256: { - struct tls12_crypto_info_aria_gcm_256 * - crypto_info_aria_gcm_256 = - container_of(crypto_info, - struct tls12_crypto_info_aria_gcm_256, - info); - - if (len != sizeof(*crypto_info_aria_gcm_256)) { - rc = -EINVAL; - goto out; - } - lock_sock(sk); - memcpy(crypto_info_aria_gcm_256->iv, - cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE, - TLS_CIPHER_ARIA_GCM_256_IV_SIZE); - memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq, - TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE); - release_sock(sk); - if (copy_to_user(optval, - crypto_info_aria_gcm_256, - sizeof(*crypto_info_aria_gcm_256))) - rc = -EFAULT; - break; - } - default: - rc = -EINVAL; - } + if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info)) + rc = -EFAULT; out: return rc; @@ -614,11 +527,9 @@ static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, if (len < sizeof(value)) return -EINVAL; - lock_sock(sk); value = -EINVAL; if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) value = ctx->rx_no_pad; - release_sock(sk); if (value < 0) return value; @@ -630,11 +541,35 @@ static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, return 0; } +static int do_tls_getsockopt_tx_payload_len(struct sock *sk, char __user *optval, + int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + u16 payload_len = ctx->tx_max_payload_len; + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < sizeof(payload_len)) + return -EINVAL; + + if (put_user(sizeof(payload_len), optlen)) + return -EFAULT; + + if (copy_to_user(optval, &payload_len, sizeof(payload_len))) + return -EFAULT; + + return 0; +} + static int do_tls_getsockopt(struct sock *sk, int optname, char __user *optval, int __user *optlen) { int rc = 0; + lock_sock(sk); + switch (optname) { case TLS_TX: case TLS_RX: @@ -647,10 +582,16 @@ static int do_tls_getsockopt(struct sock *sk, int optname, case TLS_RX_EXPECT_NO_PAD: rc = do_tls_getsockopt_no_pad(sk, optval, optlen); break; + case TLS_TX_MAX_PAYLOAD_LEN: + rc = do_tls_getsockopt_tx_payload_len(sk, optval, optlen); + break; default: rc = -ENOPROTOOPT; break; } + + release_sock(sk); + return rc; } @@ -666,13 +607,41 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, return do_tls_getsockopt(sk, optname, optval, optlen); } +static int validate_crypto_info(const struct tls_crypto_info *crypto_info, + const struct tls_crypto_info *alt_crypto_info) +{ + if (crypto_info->version != TLS_1_2_VERSION && + crypto_info->version != TLS_1_3_VERSION) + return -EINVAL; + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_ARIA_GCM_128: + case TLS_CIPHER_ARIA_GCM_256: + if (crypto_info->version != TLS_1_2_VERSION) + return -EINVAL; + break; + } + + /* Ensure that TLS version and ciphers are same in both directions */ + if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { + if (alt_crypto_info->version != crypto_info->version || + alt_crypto_info->cipher_type != crypto_info->cipher_type) + return -EINVAL; + } + + return 0; +} + static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, unsigned int optlen, int tx) { - struct tls_crypto_info *crypto_info; - struct tls_crypto_info *alt_crypto_info; + struct tls_crypto_info *crypto_info, *alt_crypto_info; + struct tls_crypto_info *old_crypto_info = NULL; struct tls_context *ctx = tls_get_ctx(sk); - size_t optsize; + const struct tls_cipher_desc *cipher_desc; + union tls_crypto_context *crypto_ctx; + union tls_crypto_context tmp = {}; + bool update = false; int rc = 0; int conf; @@ -680,16 +649,30 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, return -EINVAL; if (tx) { - crypto_info = &ctx->crypto_send.info; + crypto_ctx = &ctx->crypto_send; alt_crypto_info = &ctx->crypto_recv.info; } else { - crypto_info = &ctx->crypto_recv.info; + crypto_ctx = &ctx->crypto_recv; alt_crypto_info = &ctx->crypto_send.info; } - /* Currently we don't support set crypto info more than one time */ - if (TLS_CRYPTO_INFO_READY(crypto_info)) - return -EBUSY; + crypto_info = &crypto_ctx->info; + + if (TLS_CRYPTO_INFO_READY(crypto_info)) { + /* Currently we only support setting crypto info more + * than one time for TLS 1.3 + */ + if (crypto_info->version != TLS_1_3_VERSION) { + TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR + : LINUX_MIB_TLSRXREKEYERROR); + return -EBUSY; + } + + update = true; + old_crypto_info = crypto_info; + crypto_info = &tmp.info; + crypto_ctx = &tmp; + } rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); if (rc) { @@ -697,62 +680,24 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, goto err_crypto_info; } - /* check version */ - if (crypto_info->version != TLS_1_2_VERSION && - crypto_info->version != TLS_1_3_VERSION) { - rc = -EINVAL; - goto err_crypto_info; - } - - /* Ensure that TLS version and ciphers are same in both directions */ - if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { - if (alt_crypto_info->version != crypto_info->version || - alt_crypto_info->cipher_type != crypto_info->cipher_type) { + if (update) { + /* Ensure that TLS version and ciphers are not modified */ + if (crypto_info->version != old_crypto_info->version || + crypto_info->cipher_type != old_crypto_info->cipher_type) rc = -EINVAL; - goto err_crypto_info; - } + } else { + rc = validate_crypto_info(crypto_info, alt_crypto_info); } + if (rc) + goto err_crypto_info; - switch (crypto_info->cipher_type) { - case TLS_CIPHER_AES_GCM_128: - optsize = sizeof(struct tls12_crypto_info_aes_gcm_128); - break; - case TLS_CIPHER_AES_GCM_256: { - optsize = sizeof(struct tls12_crypto_info_aes_gcm_256); - break; - } - case TLS_CIPHER_AES_CCM_128: - optsize = sizeof(struct tls12_crypto_info_aes_ccm_128); - break; - case TLS_CIPHER_CHACHA20_POLY1305: - optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305); - break; - case TLS_CIPHER_SM4_GCM: - optsize = sizeof(struct tls12_crypto_info_sm4_gcm); - break; - case TLS_CIPHER_SM4_CCM: - optsize = sizeof(struct tls12_crypto_info_sm4_ccm); - break; - case TLS_CIPHER_ARIA_GCM_128: - if (crypto_info->version != TLS_1_2_VERSION) { - rc = -EINVAL; - goto err_crypto_info; - } - optsize = sizeof(struct tls12_crypto_info_aria_gcm_128); - break; - case TLS_CIPHER_ARIA_GCM_256: - if (crypto_info->version != TLS_1_2_VERSION) { - rc = -EINVAL; - goto err_crypto_info; - } - optsize = sizeof(struct tls12_crypto_info_aria_gcm_256); - break; - default: + cipher_desc = get_cipher_desc(crypto_info->cipher_type); + if (!cipher_desc) { rc = -EINVAL; goto err_crypto_info; } - if (optlen != optsize) { + if (optlen != cipher_desc->crypto_info) { rc = -EINVAL; goto err_crypto_info; } @@ -766,17 +711,23 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, } if (tx) { - rc = tls_set_device_offload(sk, ctx); + rc = tls_set_device_offload(sk); conf = TLS_HW; if (!rc) { TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); } else { - rc = tls_set_sw_offload(sk, ctx, 1); + rc = tls_set_sw_offload(sk, 1, + update ? crypto_info : NULL); if (rc) goto err_crypto_info; - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); + } else { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + } conf = TLS_SW; } } else { @@ -786,14 +737,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); } else { - rc = tls_set_sw_offload(sk, ctx, 0); + rc = tls_set_sw_offload(sk, 0, + update ? crypto_info : NULL); if (rc) goto err_crypto_info; - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); + } else { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + } conf = TLS_SW; } - tls_sw_strparser_arm(sk, ctx); + if (!update) + tls_sw_strparser_arm(sk, ctx); } if (tx) @@ -801,6 +759,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, else ctx->rx_conf = conf; update_sk_prot(sk, ctx); + + if (update) + return 0; + if (tx) { ctx->sk_write_space = sk->sk_write_space; sk->sk_write_space = tls_write_space; @@ -812,7 +774,11 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, return 0; err_crypto_info: - memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); + if (update) { + TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR + : LINUX_MIB_TLSRXREKEYERROR); + } + memzero_explicit(crypto_ctx, sizeof(*crypto_ctx)); return rc; } @@ -868,6 +834,32 @@ static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, return rc; } +static int do_tls_setsockopt_tx_payload_len(struct sock *sk, sockptr_t optval, + unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *sw_ctx = tls_sw_ctx_tx(ctx); + u16 value; + bool tls_13 = ctx->prot_info.version == TLS_1_3_VERSION; + + if (sw_ctx && sw_ctx->open_rec) + return -EBUSY; + + if (sockptr_is_null(optval) || optlen != sizeof(value)) + return -EINVAL; + + if (copy_from_sockptr(&value, optval, sizeof(value))) + return -EFAULT; + + if (value < TLS_MIN_RECORD_SIZE_LIM - (tls_13 ? 1 : 0) || + value > TLS_MAX_PAYLOAD_SIZE) + return -EINVAL; + + ctx->tx_max_payload_len = value; + + return 0; +} + static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, unsigned int optlen) { @@ -889,6 +881,11 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, case TLS_RX_EXPECT_NO_PAD: rc = do_tls_setsockopt_no_pad(sk, optval, optlen); break; + case TLS_TX_MAX_PAYLOAD_LEN: + lock_sock(sk); + rc = do_tls_setsockopt_tx_payload_len(sk, optval, optlen); + release_sock(sk); + break; default: rc = -ENOPROTOOPT; break; @@ -908,6 +905,11 @@ static int tls_setsockopt(struct sock *sk, int level, int optname, return do_tls_setsockopt(sk, optname, optval, optlen); } +static int tls_disconnect(struct sock *sk, int flags) +{ + return -EOPNOTSUPP; +} + struct tls_context *tls_ctx_create(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); @@ -918,9 +920,17 @@ struct tls_context *tls_ctx_create(struct sock *sk) return NULL; mutex_init(&ctx->tx_lock); - rcu_assign_pointer(icsk->icsk_ulp_data, ctx); ctx->sk_proto = READ_ONCE(sk->sk_prot); ctx->sk = sk; + /* Release semantic of rcu_assign_pointer() ensures that + * ctx->sk_proto is visible before changing sk->sk_prot in + * update_sk_prot(), and prevents reading uninitialized value in + * tls_{getsockopt, setsockopt}. Note that we do not need a + * read barrier in tls_{getsockopt,setsockopt} as there is an + * address dependency between sk->sk_proto->{getsockopt,setsockopt} + * and ctx->sk_proto. + */ + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); return ctx; } @@ -930,27 +940,28 @@ static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG] ops[TLS_BASE][TLS_BASE] = *base; ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; - ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked; + ops[TLS_SW ][TLS_BASE].splice_eof = tls_sw_splice_eof; ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; + ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; + ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; + ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; + ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; #ifdef CONFIG_TLS_DEVICE ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; - ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL; ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; - ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL; ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; - ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL; #endif #ifdef CONFIG_TLS_TOE ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; @@ -994,11 +1005,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], prot[TLS_BASE][TLS_BASE] = *base; prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; + prot[TLS_BASE][TLS_BASE].disconnect = tls_disconnect; prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; - prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; + prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof; prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; @@ -1013,11 +1025,11 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], #ifdef CONFIG_TLS_DEVICE prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; - prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage; + prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof; prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; - prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; + prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof; prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; @@ -1063,6 +1075,7 @@ static int tls_init(struct sock *sk) ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; + ctx->tx_max_payload_len = TLS_MAX_PAYLOAD_SIZE; update_sk_prot(sk, ctx); out: write_unlock_bh(&sk->sk_callback_lock); @@ -1104,7 +1117,7 @@ static u16 tls_user_config(struct tls_context *ctx, bool tx) return 0; } -static int tls_get_info(const struct sock *sk, struct sk_buff *skb) +static int tls_get_info(struct sock *sk, struct sk_buff *skb, bool net_admin) { u16 version, cipher_type; struct tls_context *ctx; @@ -1152,6 +1165,12 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb) goto nla_failure; } + err = nla_put_u16(skb, TLS_INFO_TX_MAX_PAYLOAD_LEN, + ctx->tx_max_payload_len); + + if (err) + goto nla_failure; + rcu_read_unlock(); nla_nest_end(skb, start); return 0; @@ -1162,7 +1181,7 @@ nla_failure: return err; } -static size_t tls_get_info_size(const struct sock *sk) +static size_t tls_get_info_size(const struct sock *sk, bool net_admin) { size_t size = 0; @@ -1173,6 +1192,7 @@ static size_t tls_get_info_size(const struct sock *sk) nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_TX_MAX_PAYLOAD_LEN */ 0; return size; |
