summaryrefslogtreecommitdiff
path: root/include/net/psp
diff options
context:
space:
mode:
Diffstat (limited to 'include/net/psp')
-rw-r--r--include/net/psp/functions.h209
-rw-r--r--include/net/psp/types.h184
2 files changed, 393 insertions, 0 deletions
diff --git a/include/net/psp/functions.h b/include/net/psp/functions.h
new file mode 100644
index 000000000000..ef7743664da3
--- /dev/null
+++ b/include/net/psp/functions.h
@@ -0,0 +1,209 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __NET_PSP_HELPERS_H
+#define __NET_PSP_HELPERS_H
+
+#include <linux/skbuff.h>
+#include <linux/rcupdate.h>
+#include <linux/udp.h>
+#include <net/sock.h>
+#include <net/tcp.h>
+#include <net/psp/types.h>
+
+struct inet_timewait_sock;
+
+/* Driver-facing API */
+struct psp_dev *
+psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
+ struct psp_dev_caps *psd_caps, void *priv_ptr);
+void psp_dev_unregister(struct psp_dev *psd);
+bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
+ u8 ver, __be16 sport);
+int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);
+
+/* Kernel-facing API */
+void psp_assoc_put(struct psp_assoc *pas);
+
+static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
+{
+ return pas->drv_data;
+}
+
+#if IS_ENABLED(CONFIG_INET_PSP)
+unsigned int psp_key_size(u32 version);
+void psp_sk_assoc_free(struct sock *sk);
+void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
+void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
+void psp_reply_set_decrypted(struct sk_buff *skb);
+
+static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
+{
+ return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
+}
+
+static inline void
+psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+
+ pas = psp_sk_assoc(sk);
+ if (pas && pas->tx.spi)
+ skb->decrypted = 1;
+}
+
+static inline unsigned long
+__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
+ unsigned long diffs)
+{
+ struct psp_skb_ext *a, *b;
+
+ a = skb_ext_find(one, SKB_EXT_PSP);
+ b = skb_ext_find(two, SKB_EXT_PSP);
+
+ diffs |= (!!a) ^ (!!b);
+ if (!diffs && unlikely(a))
+ diffs |= memcmp(a, b, sizeof(*a));
+ return diffs;
+}
+
+static inline bool
+psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
+{
+ bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
+ u32 end_seq = TCP_SKB_CB(skb)->end_seq;
+ u32 seq = TCP_SKB_CB(skb)->seq;
+ bool pure_fin;
+
+ pure_fin = fin && end_seq - seq == 1;
+
+ return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
+}
+
+static inline bool
+psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
+{
+ return pse && pas->rx.spi == pse->spi &&
+ pas->generation == pse->generation &&
+ pas->version == pse->version &&
+ pas->dev_id == pse->dev_id;
+}
+
+static inline enum skb_drop_reason
+__psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
+{
+ struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
+
+ if (!pas)
+ return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
+
+ if (likely(psp_pse_matches_pas(pse, pas))) {
+ if (unlikely(!pas->peer_tx))
+ pas->peer_tx = 1;
+
+ return 0;
+ }
+
+ if (!pse) {
+ if (!pas->tx.spi ||
+ (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
+ return 0;
+ }
+
+ return SKB_DROP_REASON_PSP_INPUT;
+}
+
+static inline enum skb_drop_reason
+psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
+{
+ return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
+}
+
+static inline enum skb_drop_reason
+psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
+{
+ return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
+}
+
+static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk)
+{
+ struct psp_assoc *pas;
+ int state;
+
+ state = READ_ONCE(sk->sk_state);
+ if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV)
+ return NULL;
+
+ pas = state == TCP_TIME_WAIT ?
+ rcu_dereference(inet_twsk(sk)->psp_assoc) :
+ rcu_dereference(sk->psp_assoc);
+ return pas;
+}
+
+static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
+{
+ if (!skb->decrypted || !skb->sk)
+ return NULL;
+
+ return psp_sk_get_assoc_rcu(skb->sk);
+}
+
+static inline unsigned int psp_sk_overhead(const struct sock *sk)
+{
+ int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
+ bool has_psp = rcu_access_pointer(sk->psp_assoc);
+
+ return has_psp ? psp_encap : 0;
+}
+#else
+static inline void psp_sk_assoc_free(struct sock *sk) { }
+static inline void
+psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { }
+static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { }
+static inline void
+psp_reply_set_decrypted(struct sk_buff *skb) { }
+
+static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
+{
+ return NULL;
+}
+
+static inline void
+psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }
+
+static inline unsigned long
+__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
+ unsigned long diffs)
+{
+ return diffs;
+}
+
+static inline enum skb_drop_reason
+psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
+{
+ return 0;
+}
+
+static inline enum skb_drop_reason
+psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
+{
+ return 0;
+}
+
+static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
+{
+ return NULL;
+}
+
+static inline unsigned int psp_sk_overhead(const struct sock *sk)
+{
+ return 0;
+}
+#endif
+
+static inline unsigned long
+psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
+{
+ return __psp_skb_coalesce_diff(one, two, 0);
+}
+
+#endif /* __NET_PSP_HELPERS_H */
diff --git a/include/net/psp/types.h b/include/net/psp/types.h
new file mode 100644
index 000000000000..31cee64b7c86
--- /dev/null
+++ b/include/net/psp/types.h
@@ -0,0 +1,184 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __NET_PSP_H
+#define __NET_PSP_H
+
+#include <linux/mutex.h>
+#include <linux/refcount.h>
+
+struct netlink_ext_ack;
+
+#define PSP_DEFAULT_UDP_PORT 1000
+
+struct psphdr {
+ u8 nexthdr;
+ u8 hdrlen;
+ u8 crypt_offset;
+ u8 verfl;
+ __be32 spi;
+ __be64 iv;
+ __be64 vc[]; /* optional */
+};
+
+#define PSP_ENCAP_HLEN (sizeof(struct udphdr) + sizeof(struct psphdr))
+
+#define PSP_SPI_KEY_ID GENMASK(30, 0)
+#define PSP_SPI_KEY_PHASE BIT(31)
+
+#define PSPHDR_CRYPT_OFFSET GENMASK(5, 0)
+
+#define PSPHDR_VERFL_SAMPLE BIT(7)
+#define PSPHDR_VERFL_DROP BIT(6)
+#define PSPHDR_VERFL_VERSION GENMASK(5, 2)
+#define PSPHDR_VERFL_VIRT BIT(1)
+#define PSPHDR_VERFL_ONE BIT(0)
+
+#define PSP_HDRLEN_NOOPT ((sizeof(struct psphdr) - 8) / 8)
+
+/**
+ * struct psp_dev_config - PSP device configuration
+ * @versions: PSP versions enabled on the device
+ */
+struct psp_dev_config {
+ u32 versions;
+};
+
+/**
+ * struct psp_dev - PSP device struct
+ * @main_netdev: original netdevice of this PSP device
+ * @ops: driver callbacks
+ * @caps: device capabilities
+ * @drv_priv: driver priv pointer
+ * @lock: instance lock, protects all fields
+ * @refcnt: reference count for the instance
+ * @id: instance id
+ * @generation: current generation of the device key
+ * @config: current device configuration
+ * @active_assocs: list of registered associations
+ * @prev_assocs: associations which use old (but still usable)
+ * device key
+ * @stale_assocs: associations which use a rotated out key
+ *
+ * @rcu: RCU head for freeing the structure
+ */
+struct psp_dev {
+ struct net_device *main_netdev;
+
+ struct psp_dev_ops *ops;
+ struct psp_dev_caps *caps;
+ void *drv_priv;
+
+ struct mutex lock;
+ refcount_t refcnt;
+
+ u32 id;
+
+ u8 generation;
+
+ struct psp_dev_config config;
+
+ struct list_head active_assocs;
+ struct list_head prev_assocs;
+ struct list_head stale_assocs;
+
+ struct rcu_head rcu;
+};
+
+#define PSP_GEN_VALID_MASK 0x7f
+
+/**
+ * struct psp_dev_caps - PSP device capabilities
+ */
+struct psp_dev_caps {
+ /**
+ * @versions: mask of supported PSP versions
+ * Set this field to 0 to indicate PSP is not supported at all.
+ */
+ u32 versions;
+
+ /**
+ * @assoc_drv_spc: size of driver-specific state in Tx assoc
+ * Determines the size of struct psp_assoc::drv_data
+ */
+ u32 assoc_drv_spc;
+};
+
+#define PSP_MAX_KEY 32
+
+#define PSP_HDR_SIZE 16 /* We don't support optional fields, yet */
+#define PSP_TRL_SIZE 16 /* AES-GCM/GMAC trailer size */
+
+struct psp_skb_ext {
+ __be32 spi;
+ u16 dev_id;
+ u8 generation;
+ u8 version;
+};
+
+struct psp_key_parsed {
+ __be32 spi;
+ u8 key[PSP_MAX_KEY];
+};
+
+struct psp_assoc {
+ struct psp_dev *psd;
+
+ u16 dev_id;
+ u8 generation;
+ u8 version;
+ u8 peer_tx;
+
+ u32 upgrade_seq;
+
+ struct psp_key_parsed tx;
+ struct psp_key_parsed rx;
+
+ refcount_t refcnt;
+ struct rcu_head rcu;
+ struct work_struct work;
+ struct list_head assocs_list;
+
+ u8 drv_data[] __aligned(8);
+};
+
+/**
+ * struct psp_dev_ops - netdev driver facing PSP callbacks
+ */
+struct psp_dev_ops {
+ /**
+ * @set_config: set configuration of a PSP device
+ * Driver can inspect @psd->config for the previous configuration.
+ * Core will update @psd->config with @config on success.
+ */
+ int (*set_config)(struct psp_dev *psd, struct psp_dev_config *conf,
+ struct netlink_ext_ack *extack);
+
+ /**
+ * @key_rotate: rotate the device key
+ */
+ int (*key_rotate)(struct psp_dev *psd, struct netlink_ext_ack *extack);
+
+ /**
+ * @rx_spi_alloc: allocate an Rx SPI+key pair
+ * Allocate an Rx SPI and resulting derived key.
+ * This key should remain valid until key rotation.
+ */
+ int (*rx_spi_alloc)(struct psp_dev *psd, u32 version,
+ struct psp_key_parsed *assoc,
+ struct netlink_ext_ack *extack);
+
+ /**
+ * @tx_key_add: add a Tx key to the device
+ * Install an association in the device. Core will allocate space
+ * for the driver to use at drv_data.
+ */
+ int (*tx_key_add)(struct psp_dev *psd, struct psp_assoc *pas,
+ struct netlink_ext_ack *extack);
+ /**
+ * @tx_key_del: remove a Tx key from the device
+ * Remove an association from the device.
+ */
+ void (*tx_key_del)(struct psp_dev *psd, struct psp_assoc *pas);
+};
+
+#endif /* __NET_PSP_H */