diff options
Diffstat (limited to 'net/tls/tls_strp.c')
| -rw-r--r-- | net/tls/tls_strp.c | 41 |
1 files changed, 25 insertions, 16 deletions
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index f37f4a0fcd3c..98e12f0ff57e 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -2,6 +2,7 @@ /* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */ #include <linux/skbuff.h> +#include <linux/skbuff_ref.h> #include <linux/workqueue.h> #include <net/strparser.h> #include <net/tcp.h> @@ -12,7 +13,7 @@ static struct workqueue_struct *tls_strp_wq; -static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +void tls_strp_abort_strp(struct tls_strparser *strp, int err) { if (strp->stopped) return; @@ -210,11 +211,17 @@ static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb, struct sk_buff *in_skb, unsigned int offset, size_t in_len) { + unsigned int nfrag = skb->len / PAGE_SIZE; size_t len, chunk; skb_frag_t *frag; int sz; - frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + if (unlikely(nfrag >= skb_shinfo(skb)->nr_frags)) { + DEBUG_NET_WARN_ON_ONCE(1); + return -EMSGSIZE; + } + + frag = &skb_shinfo(skb)->frags[nfrag]; len = in_len; /* First make sure we got the header */ @@ -360,7 +367,7 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, if (strp->stm.full_len && strp->stm.full_len == skb->len) { desc->count = 0; - strp->msg_ready = 1; + WRITE_ONCE(strp->msg_ready, 1); tls_rx_msg_ready(strp); } @@ -369,7 +376,6 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, static int tls_strp_read_copyin(struct tls_strparser *strp) { - struct socket *sock = strp->sk->sk_socket; read_descriptor_t desc; desc.arg.data = strp; @@ -377,7 +383,7 @@ static int tls_strp_read_copyin(struct tls_strparser *strp) desc.count = 1; /* give more than one skb per call */ /* sk should be locked here, so okay to do read_sock */ - sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin); + tcp_read_sock(strp->sk, &desc, tls_strp_copyin); return desc.error; } @@ -396,7 +402,6 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) return 0; shinfo = skb_shinfo(strp->anchor); - shinfo->frag_list = NULL; /* If we don't know the length go max plus page for cipher overhead */ need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; @@ -412,6 +417,8 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) page, 0, 0); } + shinfo->frag_list = NULL; + strp->copy_mode = 1; strp->stm.offset = 0; @@ -474,7 +481,7 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) strp->stm.offset = offset; } -void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) { struct strp_msg *rxm; struct tls_msg *tlm; @@ -483,8 +490,11 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); if (!strp->copy_mode && force_refresh) { - if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) - return; + if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) { + WRITE_ONCE(strp->msg_ready, 0); + memset(&strp->stm, 0, sizeof(strp->stm)); + return false; + } tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); } @@ -494,6 +504,8 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) rxm->offset = strp->stm.offset; tlm = tls_msg(strp->anchor); tlm->control = strp->mark; + + return true; } /* Called with lock held on lower socket */ @@ -511,14 +523,11 @@ static int tls_strp_read_sock(struct tls_strparser *strp) if (inq < strp->stm.full_len) return tls_strp_read_copy(strp, true); + tls_strp_load_anchor_with_queue(strp, inq); if (!strp->stm.full_len) { - tls_strp_load_anchor_with_queue(strp, inq); - sz = tls_rx_msg_size(strp, strp->anchor); - if (sz < 0) { - tls_strp_abort_strp(strp, sz); + if (sz < 0) return sz; - } strp->stm.full_len = sz; @@ -529,7 +538,7 @@ static int tls_strp_read_sock(struct tls_strparser *strp) if (!tls_strp_check_queue_ok(strp)) return tls_strp_read_copy(strp, false); - strp->msg_ready = 1; + WRITE_ONCE(strp->msg_ready, 1); tls_rx_msg_ready(strp); return 0; @@ -581,7 +590,7 @@ void tls_strp_msg_done(struct tls_strparser *strp) else tls_strp_flush_anchor_copy(strp); - strp->msg_ready = 0; + WRITE_ONCE(strp->msg_ready, 0); memset(&strp->stm, 0, sizeof(strp->stm)); tls_strp_check_rcv(strp); |
