summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r--drivers/net/wireguard/device.c11
-rw-r--r--drivers/net/wireguard/receive.c7
-rw-r--r--drivers/net/wireguard/send.c16
-rw-r--r--drivers/net/wireguard/socket.c1
4 files changed, 22 insertions, 13 deletions
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 16b19824b9ad..cdc96968b0f4 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -203,9 +203,9 @@ err_peer:
err:
++dev->stats.tx_errors;
if (skb->protocol == htons(ETH_P_IP))
- icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
+ icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
else if (skb->protocol == htons(ETH_P_IPV6))
- icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
+ icmpv6_ndo_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
kfree_skb(skb);
return ret;
}
@@ -258,6 +258,8 @@ static void wg_setup(struct net_device *dev)
enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM |
NETIF_F_SG | NETIF_F_GSO |
NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA };
+ const int overhead = MESSAGE_MINIMUM_LENGTH + sizeof(struct udphdr) +
+ max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
dev->netdev_ops = &netdev_ops;
dev->hard_header_len = 0;
@@ -271,9 +273,8 @@ static void wg_setup(struct net_device *dev)
dev->features |= WG_NETDEV_FEATURES;
dev->hw_features |= WG_NETDEV_FEATURES;
dev->hw_enc_features |= WG_NETDEV_FEATURES;
- dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH -
- sizeof(struct udphdr) -
- max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
+ dev->mtu = ETH_DATA_LEN - overhead;
+ dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead;
SET_NETDEV_DEVTYPE(dev, &device_type);
diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c
index 9c6bab9c981f..4a153894cee2 100644
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -118,10 +118,13 @@ static void wg_receive_handshake_packet(struct wg_device *wg,
under_load = skb_queue_len(&wg->incoming_handshakes) >=
MAX_QUEUED_INCOMING_HANDSHAKES / 8;
- if (under_load)
+ if (under_load) {
last_under_load = ktime_get_coarse_boottime_ns();
- else if (last_under_load)
+ } else if (last_under_load) {
under_load = !wg_birthdate_has_expired(last_under_load, 1);
+ if (!under_load)
+ last_under_load = 0;
+ }
mac_state = wg_cookie_validate_packet(&wg->cookie_checker, skb,
under_load);
if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c
index c13260563446..7348c10cbae3 100644
--- a/drivers/net/wireguard/send.c
+++ b/drivers/net/wireguard/send.c
@@ -143,16 +143,22 @@ static void keep_key_fresh(struct wg_peer *peer)
static unsigned int calculate_skb_padding(struct sk_buff *skb)
{
+ unsigned int padded_size, last_unit = skb->len;
+
+ if (unlikely(!PACKET_CB(skb)->mtu))
+ return ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE) - last_unit;
+
/* We do this modulo business with the MTU, just in case the networking
* layer gives us a packet that's bigger than the MTU. In that case, we
* wouldn't want the final subtraction to overflow in the case of the
- * padded_size being clamped.
+ * padded_size being clamped. Fortunately, that's very rarely the case,
+ * so we optimize for that not happening.
*/
- unsigned int last_unit = skb->len % PACKET_CB(skb)->mtu;
- unsigned int padded_size = ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE);
+ if (unlikely(last_unit > PACKET_CB(skb)->mtu))
+ last_unit %= PACKET_CB(skb)->mtu;
- if (padded_size > PACKET_CB(skb)->mtu)
- padded_size = PACKET_CB(skb)->mtu;
+ padded_size = min(PACKET_CB(skb)->mtu,
+ ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE));
return padded_size - last_unit;
}
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index 262f3b5c819d..b0d6541582d3 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -432,7 +432,6 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
mutex_unlock(&wg->socket_update_lock);
synchronize_rcu();
- synchronize_net();
sock_free(old4);
sock_free(old6);
}