diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls.h | 4 | ||||
-rw-r--r-- | net/tls/tls_device.c | 20 | ||||
-rw-r--r-- | net/tls/tls_proc.c | 10 | ||||
-rw-r--r-- | net/tls/tls_strp.c | 14 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 3 |
5 files changed, 30 insertions, 21 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h index 4e077068e6d9..2f86baeb71fc 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -128,8 +128,9 @@ struct tls_rec { char aad_space[TLS_AAD_SPACE_SIZE]; u8 iv_data[TLS_MAX_IV_SIZE]; + + /* Must be last --ends in a flexible-array member. */ struct aead_request aead_req; - u8 aead_req_ctx[]; }; int __net_init tls_proc_init(struct net *net); @@ -141,6 +142,7 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx); int wait_on_pending_writer(struct sock *sk, long *timeo); void tls_err_abort(struct sock *sk, int err); +void tls_strp_abort_strp(struct tls_strparser *strp, int err); int init_prot_info(struct tls_prot_info *prot, const struct tls_crypto_info *crypto_info, diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index f672a62a9a52..a64ae15b1a60 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -123,17 +123,19 @@ static void tls_device_queue_ctx_destruction(struct tls_context *ctx) /* We assume that the socket is already connected */ static struct net_device *get_netdev_for_sock(struct sock *sk) { - struct dst_entry *dst = sk_dst_get(sk); - struct net_device *netdev = NULL; + struct net_device *dev, *lowest_dev = NULL; + struct dst_entry *dst; - if (likely(dst)) { - netdev = netdev_sk_get_lowest_dev(dst->dev, sk); - dev_hold(netdev); + rcu_read_lock(); + dst = __sk_dst_get(sk); + dev = dst ? dst_dev_rcu(dst) : NULL; + if (likely(dev)) { + lowest_dev = netdev_sk_get_lowest_dev(dev, sk); + dev_hold(lowest_dev); } + rcu_read_unlock(); - dst_release(dst); - - return netdev; + return lowest_dev; } static void destroy_record(struct tls_record_info *record) @@ -1410,7 +1412,7 @@ int __init tls_device_init(void) if (!dummy_page) return -ENOMEM; - destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0); + destruct_wq = alloc_workqueue("ktls_device_destruct", WQ_PERCPU, 0); if (!destruct_wq) { err = -ENOMEM; goto err_free_dummy; diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c index 367666aa07b8..4012c4372d4c 100644 --- a/net/tls/tls_proc.c +++ b/net/tls/tls_proc.c @@ -27,17 +27,19 @@ static const struct snmp_mib tls_mib_list[] = { SNMP_MIB_ITEM("TlsTxRekeyOk", LINUX_MIB_TLSTXREKEYOK), SNMP_MIB_ITEM("TlsTxRekeyError", LINUX_MIB_TLSTXREKEYERROR), SNMP_MIB_ITEM("TlsRxRekeyReceived", LINUX_MIB_TLSRXREKEYRECEIVED), - SNMP_MIB_SENTINEL }; static int tls_statistics_seq_show(struct seq_file *seq, void *v) { - unsigned long buf[LINUX_MIB_TLSMAX] = {}; + unsigned long buf[ARRAY_SIZE(tls_mib_list)]; + const int cnt = ARRAY_SIZE(tls_mib_list); struct net *net = seq->private; int i; - snmp_get_cpu_field_batch(buf, tls_mib_list, net->mib.tls_statistics); - for (i = 0; tls_mib_list[i].name; i++) + memset(buf, 0, sizeof(buf)); + snmp_get_cpu_field_batch_cnt(buf, tls_mib_list, cnt, + net->mib.tls_statistics); + for (i = 0; i < cnt; i++) seq_printf(seq, "%-32s\t%lu\n", tls_mib_list[i].name, buf[i]); return 0; diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index d71643b494a1..98e12f0ff57e 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -13,7 +13,7 @@ static struct workqueue_struct *tls_strp_wq; -static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +void tls_strp_abort_strp(struct tls_strparser *strp, int err) { if (strp->stopped) return; @@ -211,11 +211,17 @@ static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb, struct sk_buff *in_skb, unsigned int offset, size_t in_len) { + unsigned int nfrag = skb->len / PAGE_SIZE; size_t len, chunk; skb_frag_t *frag; int sz; - frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + if (unlikely(nfrag >= skb_shinfo(skb)->nr_frags)) { + DEBUG_NET_WARN_ON_ONCE(1); + return -EMSGSIZE; + } + + frag = &skb_shinfo(skb)->frags[nfrag]; len = in_len; /* First make sure we got the header */ @@ -520,10 +526,8 @@ static int tls_strp_read_sock(struct tls_strparser *strp) tls_strp_load_anchor_with_queue(strp, inq); if (!strp->stm.full_len) { sz = tls_rx_msg_size(strp, strp->anchor); - if (sz < 0) { - tls_strp_abort_strp(strp, sz); + if (sz < 0) return sz; - } strp->stm.full_len = sz; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index bac65d0d4e3e..daac9fd4be7e 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -2474,8 +2474,7 @@ int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) return data_len + TLS_HEADER_SIZE; read_failure: - tls_err_abort(strp->sk, ret); - + tls_strp_abort_strp(strp, ret); return ret; } |