diff options
Diffstat (limited to 'net/tls/tls_strp.c')
| -rw-r--r-- | net/tls/tls_strp.c | 228 |
1 files changed, 176 insertions, 52 deletions
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index 955ac3e0bf4d..98e12f0ff57e 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -2,6 +2,7 @@ /* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */ #include <linux/skbuff.h> +#include <linux/skbuff_ref.h> #include <linux/workqueue.h> #include <net/strparser.h> #include <net/tcp.h> @@ -12,7 +13,7 @@ static struct workqueue_struct *tls_strp_wq; -static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +void tls_strp_abort_strp(struct tls_strparser *strp, int err) { if (strp->stopped) return; @@ -20,7 +21,9 @@ static void tls_strp_abort_strp(struct tls_strparser *strp, int err) strp->stopped = 1; /* Report an error on the lower socket */ - strp->sk->sk_err = -err; + WRITE_ONCE(strp->sk->sk_err, -err); + /* Paired with smp_rmb() in tcp_poll() */ + smp_wmb(); sk_error_report(strp->sk); } @@ -29,34 +32,50 @@ static void tls_strp_anchor_free(struct tls_strparser *strp) struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); - shinfo->frag_list = NULL; + if (!strp->copy_mode) + shinfo->frag_list = NULL; consume_skb(strp->anchor); strp->anchor = NULL; } -/* Create a new skb with the contents of input copied to its page frags */ -static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp) +static struct sk_buff * +tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb, + int offset, int len) { - struct strp_msg *rxm; struct sk_buff *skb; - int i, err, offset; + int i, err; - skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER, + skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER, &err, strp->sk->sk_allocation); if (!skb) return NULL; - offset = strp->stm.offset; for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; - WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset, + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, skb_frag_address(frag), skb_frag_size(frag))); offset += skb_frag_size(frag); } - skb_copy_header(skb, strp->anchor); + skb->len = len; + skb->data_len = len; + skb_copy_header(skb, in_skb); + return skb; +} + +/* Create a new skb with the contents of input copied to its page frags */ +static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp) +{ + struct strp_msg *rxm; + struct sk_buff *skb; + + skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset, + strp->stm.full_len); + if (!skb) + return NULL; + rxm = strp_msg(skb); rxm->offset = 0; return skb; @@ -180,23 +199,29 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) for (i = 0; i < shinfo->nr_frags; i++) __skb_frag_unref(&shinfo->frags[i], false); shinfo->nr_frags = 0; + if (strp->copy_mode) { + kfree_skb_list(shinfo->frag_list); + shinfo->frag_list = NULL; + } strp->copy_mode = 0; + strp->mixed_decrypted = 0; } -static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, - unsigned int offset, size_t in_len) +static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb, + struct sk_buff *in_skb, unsigned int offset, + size_t in_len) { - struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; - struct sk_buff *skb; - skb_frag_t *frag; + unsigned int nfrag = skb->len / PAGE_SIZE; size_t len, chunk; + skb_frag_t *frag; int sz; - if (strp->msg_ready) - return 0; + if (unlikely(nfrag >= skb_shinfo(skb)->nr_frags)) { + DEBUG_NET_WARN_ON_ONCE(1); + return -EMSGSIZE; + } - skb = strp->anchor; - frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + frag = &skb_shinfo(skb)->frags[nfrag]; len = in_len; /* First make sure we got the header */ @@ -208,19 +233,26 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, skb_frag_size(frag), chunk)); - sz = tls_rx_msg_size(strp, strp->anchor); - if (sz < 0) { - desc->error = sz; - return 0; - } - - /* We may have over-read, sz == 0 is guaranteed under-read */ - if (sz > 0) - chunk = min_t(size_t, chunk, sz - skb->len); - skb->len += chunk; skb->data_len += chunk; skb_frag_size_add(frag, chunk); + + sz = tls_rx_msg_size(strp, skb); + if (sz < 0) + return sz; + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (unlikely(sz && sz < skb->len)) { + int over = skb->len - sz; + + WARN_ON_ONCE(over > chunk); + skb->len -= over; + skb->data_len -= over; + skb_frag_size_add(frag, -over); + + chunk -= over; + } + frag++; len -= chunk; offset += chunk; @@ -247,20 +279,103 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, offset += chunk; } - if (strp->stm.full_len == skb->len) { +read_done: + return in_len - len; +} + +static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb, + struct sk_buff *in_skb, unsigned int offset, + size_t in_len) +{ + struct sk_buff *nskb, *first, *last; + struct skb_shared_info *shinfo; + size_t chunk; + int sz; + + if (strp->stm.full_len) + chunk = strp->stm.full_len - skb->len; + else + chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; + chunk = min(chunk, in_len); + + nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk); + if (!nskb) + return -ENOMEM; + + shinfo = skb_shinfo(skb); + if (!shinfo->frag_list) { + shinfo->frag_list = nskb; + nskb->prev = nskb; + } else { + first = shinfo->frag_list; + last = first->prev; + last->next = nskb; + first->prev = nskb; + } + + skb->len += chunk; + skb->data_len += chunk; + + if (!strp->stm.full_len) { + sz = tls_rx_msg_size(strp, skb); + if (sz < 0) + return sz; + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (unlikely(sz && sz < skb->len)) { + int over = skb->len - sz; + + WARN_ON_ONCE(over > chunk); + skb->len -= over; + skb->data_len -= over; + __pskb_trim(nskb, nskb->len - over); + + chunk -= over; + } + + strp->stm.full_len = sz; + } + + return chunk; +} + +static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, + unsigned int offset, size_t in_len) +{ + struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; + struct sk_buff *skb; + int ret; + + if (strp->msg_ready) + return 0; + + skb = strp->anchor; + if (!skb->len) + skb_copy_decrypted(skb, in_skb); + else + strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb); + + if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted) + ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len); + else + ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len); + if (ret < 0) { + desc->error = ret; + ret = 0; + } + + if (strp->stm.full_len && strp->stm.full_len == skb->len) { desc->count = 0; - strp->msg_ready = 1; + WRITE_ONCE(strp->msg_ready, 1); tls_rx_msg_ready(strp); } -read_done: - return in_len - len; + return ret; } static int tls_strp_read_copyin(struct tls_strparser *strp) { - struct socket *sock = strp->sk->sk_socket; read_descriptor_t desc; desc.arg.data = strp; @@ -268,7 +383,7 @@ static int tls_strp_read_copyin(struct tls_strparser *strp) desc.count = 1; /* give more than one skb per call */ /* sk should be locked here, so okay to do read_sock */ - sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin); + tcp_read_sock(strp->sk, &desc, tls_strp_copyin); return desc.error; } @@ -287,7 +402,6 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) return 0; shinfo = skb_shinfo(strp->anchor); - shinfo->frag_list = NULL; /* If we don't know the length go max plus page for cipher overhead */ need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; @@ -303,6 +417,8 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) page, 0, 0); } + shinfo->frag_list = NULL; + strp->copy_mode = 1; strp->stm.offset = 0; @@ -315,15 +431,19 @@ static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) return 0; } -static bool tls_strp_check_no_dup(struct tls_strparser *strp) +static bool tls_strp_check_queue_ok(struct tls_strparser *strp) { unsigned int len = strp->stm.offset + strp->stm.full_len; - struct sk_buff *skb; + struct sk_buff *first, *skb; u32 seq; - skb = skb_shinfo(strp->anchor)->frag_list; - seq = TCP_SKB_CB(skb)->seq; + first = skb_shinfo(strp->anchor)->frag_list; + skb = first; + seq = TCP_SKB_CB(first)->seq; + /* Make sure there's no duplicate data in the queue, + * and the decrypted status matches. + */ while (skb->len < len) { seq += skb->len; len -= skb->len; @@ -331,6 +451,8 @@ static bool tls_strp_check_no_dup(struct tls_strparser *strp) if (TCP_SKB_CB(skb)->seq != seq) return false; + if (skb_cmp_decrypted(first, skb)) + return false; } return true; @@ -359,7 +481,7 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) strp->stm.offset = offset; } -void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) { struct strp_msg *rxm; struct tls_msg *tlm; @@ -368,8 +490,11 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); if (!strp->copy_mode && force_refresh) { - if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) - return; + if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) { + WRITE_ONCE(strp->msg_ready, 0); + memset(&strp->stm, 0, sizeof(strp->stm)); + return false; + } tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); } @@ -379,6 +504,8 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) rxm->offset = strp->stm.offset; tlm = tls_msg(strp->anchor); tlm->control = strp->mark; + + return true; } /* Called with lock held on lower socket */ @@ -396,14 +523,11 @@ static int tls_strp_read_sock(struct tls_strparser *strp) if (inq < strp->stm.full_len) return tls_strp_read_copy(strp, true); + tls_strp_load_anchor_with_queue(strp, inq); if (!strp->stm.full_len) { - tls_strp_load_anchor_with_queue(strp, inq); - sz = tls_rx_msg_size(strp, strp->anchor); - if (sz < 0) { - tls_strp_abort_strp(strp, sz); + if (sz < 0) return sz; - } strp->stm.full_len = sz; @@ -411,10 +535,10 @@ static int tls_strp_read_sock(struct tls_strparser *strp) return tls_strp_read_copy(strp, true); } - if (!tls_strp_check_no_dup(strp)) + if (!tls_strp_check_queue_ok(strp)) return tls_strp_read_copy(strp, false); - strp->msg_ready = 1; + WRITE_ONCE(strp->msg_ready, 1); tls_rx_msg_ready(strp); return 0; @@ -466,7 +590,7 @@ void tls_strp_msg_done(struct tls_strparser *strp) else tls_strp_flush_anchor_copy(strp); - strp->msg_ready = 0; + WRITE_ONCE(strp->msg_ready, 0); memset(&strp->stm, 0, sizeof(strp->stm)); tls_strp_check_rcv(strp); |
