diff options
-rw-r--r-- | drivers/net/netdevsim/Makefile | 4 | ||||
-rw-r--r-- | drivers/net/netdevsim/netdev.c | 43 | ||||
-rw-r--r-- | drivers/net/netdevsim/netdevsim.h | 27 | ||||
-rw-r--r-- | drivers/net/netdevsim/psp.c | 225 | ||||
-rw-r--r-- | net/core/skbuff.c | 1 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/.gitignore | 1 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/Makefile | 10 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/config | 1 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/hw/lib/py/__init__.py | 4 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/lib/py/__init__.py | 4 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/lib/py/env.py | 4 | ||||
-rwxr-xr-x | tools/testing/selftests/drivers/net/psp.py | 627 | ||||
-rw-r--r-- | tools/testing/selftests/drivers/net/psp_responder.c | 483 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/__init__.py | 2 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/ksft.py | 10 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/ynl.py | 5 |
16 files changed, 1440 insertions, 11 deletions
diff --git a/drivers/net/netdevsim/Makefile b/drivers/net/netdevsim/Makefile index f8de93bc5f5b..14a553e000ec 100644 --- a/drivers/net/netdevsim/Makefile +++ b/drivers/net/netdevsim/Makefile @@ -18,6 +18,10 @@ ifneq ($(CONFIG_PSAMPLE),) netdevsim-objs += psample.o endif +ifneq ($(CONFIG_INET_PSP),) +netdevsim-objs += psp.o +endif + ifneq ($(CONFIG_MACSEC),) netdevsim-objs += macsec.o endif diff --git a/drivers/net/netdevsim/netdev.c b/drivers/net/netdevsim/netdev.c index 0178219f0db5..ebc3833e95b4 100644 --- a/drivers/net/netdevsim/netdev.c +++ b/drivers/net/netdevsim/netdev.c @@ -103,28 +103,42 @@ static int nsim_napi_rx(struct net_device *tx_dev, struct net_device *rx_dev, static int nsim_forward_skb(struct net_device *tx_dev, struct net_device *rx_dev, struct sk_buff *skb, - struct nsim_rq *rq) + struct nsim_rq *rq, + struct skb_ext *psp_ext) { - return __dev_forward_skb(rx_dev, skb) ?: - nsim_napi_rx(tx_dev, rx_dev, rq, skb); + int ret; + + ret = __dev_forward_skb(rx_dev, skb); + if (ret) + return ret; + + nsim_psp_handle_ext(skb, psp_ext); + + return nsim_napi_rx(tx_dev, rx_dev, rq, skb); } static netdev_tx_t nsim_start_xmit(struct sk_buff *skb, struct net_device *dev) { struct netdevsim *ns = netdev_priv(dev); + struct skb_ext *psp_ext = NULL; struct net_device *peer_dev; unsigned int len = skb->len; struct netdevsim *peer_ns; struct netdev_config *cfg; struct nsim_rq *rq; int rxq; + int dr; rcu_read_lock(); if (!nsim_ipsec_tx(ns, skb)) - goto out_drop_free; + goto out_drop_any; peer_ns = rcu_dereference(ns->peer); if (!peer_ns) + goto out_drop_any; + + dr = nsim_do_psp(skb, ns, peer_ns, &psp_ext); + if (dr) goto out_drop_free; peer_dev = peer_ns->netdev; @@ -141,7 +155,8 @@ static netdev_tx_t nsim_start_xmit(struct sk_buff *skb, struct net_device *dev) skb_linearize(skb); skb_tx_timestamp(skb); - if (unlikely(nsim_forward_skb(dev, peer_dev, skb, rq) == NET_RX_DROP)) + if (unlikely(nsim_forward_skb(dev, peer_dev, + skb, rq, psp_ext) == NET_RX_DROP)) goto out_drop_cnt; if (!hrtimer_active(&rq->napi_timer)) @@ -151,8 +166,10 @@ static netdev_tx_t nsim_start_xmit(struct sk_buff *skb, struct net_device *dev) dev_dstats_tx_add(dev, len); return NETDEV_TX_OK; +out_drop_any: + dr = SKB_DROP_REASON_NOT_SPECIFIED; out_drop_free: - dev_kfree_skb(skb); + kfree_skb_reason(skb, dr); out_drop_cnt: rcu_read_unlock(); dev_dstats_tx_dropped(dev); @@ -1002,6 +1019,7 @@ static void nsim_queue_uninit(struct netdevsim *ns) static int nsim_init_netdevsim(struct netdevsim *ns) { + struct netdevsim *peer; struct mock_phc *phc; int err; @@ -1036,6 +1054,10 @@ static int nsim_init_netdevsim(struct netdevsim *ns) goto err_ipsec_teardown; rtnl_unlock(); + err = nsim_psp_init(ns); + if (err) + goto err_unregister_netdev; + if (IS_ENABLED(CONFIG_DEBUG_NET)) { ns->nb.notifier_call = netdev_debug_event; if (register_netdevice_notifier_dev_net(ns->netdev, &ns->nb, @@ -1045,6 +1067,13 @@ static int nsim_init_netdevsim(struct netdevsim *ns) return 0; +err_unregister_netdev: + rtnl_lock(); + peer = rtnl_dereference(ns->peer); + if (peer) + RCU_INIT_POINTER(peer->peer, NULL); + RCU_INIT_POINTER(ns->peer, NULL); + unregister_netdevice(ns->netdev); err_ipsec_teardown: nsim_ipsec_teardown(ns); nsim_macsec_teardown(ns); @@ -1132,6 +1161,8 @@ void nsim_destroy(struct netdevsim *ns) unregister_netdevice_notifier_dev_net(ns->netdev, &ns->nb, &ns->nn); + nsim_psp_uninit(ns); + rtnl_lock(); peer = rtnl_dereference(ns->peer); if (peer) diff --git a/drivers/net/netdevsim/netdevsim.h b/drivers/net/netdevsim/netdevsim.h index bddd24c1389d..02c1c97b7008 100644 --- a/drivers/net/netdevsim/netdevsim.h +++ b/drivers/net/netdevsim/netdevsim.h @@ -108,6 +108,12 @@ struct netdevsim { int rq_reset_mode; + struct { + struct psp_dev *dev; + u32 spi; + u32 assoc_cnt; + } psp; + struct nsim_bus_dev *nsim_bus_dev; struct bpf_prog *bpf_offloaded; @@ -421,6 +427,27 @@ static inline void nsim_macsec_teardown(struct netdevsim *ns) } #endif +#if IS_ENABLED(CONFIG_INET_PSP) +int nsim_psp_init(struct netdevsim *ns); +void nsim_psp_uninit(struct netdevsim *ns); +void nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext); +enum skb_drop_reason +nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns, + struct netdevsim *peer_ns, struct skb_ext **psp_ext); +#else +static inline int nsim_psp_init(struct netdevsim *ns) { return 0; } +static inline void nsim_psp_uninit(struct netdevsim *ns) {} +static inline enum skb_drop_reason +nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns, + struct netdevsim *peer_ns, struct skb_ext **psp_ext) +{ + return 0; +} + +static inline void +nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext) {} +#endif + struct nsim_bus_dev { struct device dev; struct list_head list; diff --git a/drivers/net/netdevsim/psp.c b/drivers/net/netdevsim/psp.c new file mode 100644 index 000000000000..332b5b744f01 --- /dev/null +++ b/drivers/net/netdevsim/psp.c @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/ip.h> +#include <linux/skbuff.h> +#include <net/ip6_checksum.h> +#include <net/psp.h> +#include <net/sock.h> + +#include "netdevsim.h" + +void nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext) +{ + if (psp_ext) + __skb_ext_set(skb, SKB_EXT_PSP, psp_ext); +} + +enum skb_drop_reason +nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns, + struct netdevsim *peer_ns, struct skb_ext **psp_ext) +{ + enum skb_drop_reason rc = 0; + struct psp_assoc *pas; + struct net *net; + void **ptr; + + rcu_read_lock(); + pas = psp_skb_get_assoc_rcu(skb); + if (!pas) { + rc = SKB_NOT_DROPPED_YET; + goto out_unlock; + } + + if (!skb_transport_header_was_set(skb)) { + rc = SKB_DROP_REASON_PSP_OUTPUT; + goto out_unlock; + } + + ptr = psp_assoc_drv_data(pas); + if (*ptr != ns) { + rc = SKB_DROP_REASON_PSP_OUTPUT; + goto out_unlock; + } + + net = sock_net(skb->sk); + if (!psp_dev_encapsulate(net, skb, pas->tx.spi, pas->version, 0)) { + rc = SKB_DROP_REASON_PSP_OUTPUT; + goto out_unlock; + } + + /* Now pretend we just received this frame */ + if (peer_ns->psp.dev->config.versions & (1 << pas->version)) { + bool strip_icv = false; + u8 generation; + + /* We cheat a bit and put the generation in the key. + * In real life if generation was too old, then decryption would + * fail. Here, we just make it so a bad key causes a bad + * generation too, and psp_sk_rx_policy_check() will fail. + */ + generation = pas->tx.key[0]; + + skb_ext_reset(skb); + skb->mac_len = ETH_HLEN; + if (psp_dev_rcv(skb, peer_ns->psp.dev->id, generation, + strip_icv)) { + rc = SKB_DROP_REASON_PSP_OUTPUT; + goto out_unlock; + } + + *psp_ext = skb->extensions; + refcount_inc(&(*psp_ext)->refcnt); + skb->decrypted = 1; + } else { + struct ipv6hdr *ip6h __maybe_unused; + struct iphdr *iph; + struct udphdr *uh; + __wsum csum; + + /* Do not decapsulate. Receive the skb with the udp and psp + * headers still there as if this is a normal udp packet. + * psp_dev_encapsulate() sets udp checksum to 0, so we need to + * provide a valid checksum here, so the skb isn't dropped. + */ + uh = udp_hdr(skb); + csum = skb_checksum(skb, skb_transport_offset(skb), + ntohs(uh->len), 0); + + switch (skb->protocol) { + case htons(ETH_P_IP): + iph = ip_hdr(skb); + uh->check = udp_v4_check(ntohs(uh->len), iph->saddr, + iph->daddr, csum); + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + ip6h = ipv6_hdr(skb); + uh->check = udp_v6_check(ntohs(uh->len), &ip6h->saddr, + &ip6h->daddr, csum); + break; +#endif + } + + uh->check = uh->check ?: CSUM_MANGLED_0; + skb->ip_summed = CHECKSUM_NONE; + } + +out_unlock: + rcu_read_unlock(); + return rc; +} + +static int +nsim_psp_set_config(struct psp_dev *psd, struct psp_dev_config *conf, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static int +nsim_rx_spi_alloc(struct psp_dev *psd, u32 version, + struct psp_key_parsed *assoc, + struct netlink_ext_ack *extack) +{ + struct netdevsim *ns = psd->drv_priv; + unsigned int new; + int i; + + new = ++ns->psp.spi & PSP_SPI_KEY_ID; + if (psd->generation & 1) + new |= PSP_SPI_KEY_PHASE; + + assoc->spi = cpu_to_be32(new); + assoc->key[0] = psd->generation; + for (i = 1; i < PSP_MAX_KEY; i++) + assoc->key[i] = ns->psp.spi + i; + + return 0; +} + +static int nsim_assoc_add(struct psp_dev *psd, struct psp_assoc *pas, + struct netlink_ext_ack *extack) +{ + struct netdevsim *ns = psd->drv_priv; + void **ptr = psp_assoc_drv_data(pas); + + /* Copy drv_priv from psd to assoc */ + *ptr = psd->drv_priv; + ns->psp.assoc_cnt++; + + return 0; +} + +static int nsim_key_rotate(struct psp_dev *psd, struct netlink_ext_ack *extack) +{ + return 0; +} + +static void nsim_assoc_del(struct psp_dev *psd, struct psp_assoc *pas) +{ + struct netdevsim *ns = psd->drv_priv; + void **ptr = psp_assoc_drv_data(pas); + + *ptr = NULL; + ns->psp.assoc_cnt--; +} + +static struct psp_dev_ops nsim_psp_ops = { + .set_config = nsim_psp_set_config, + .rx_spi_alloc = nsim_rx_spi_alloc, + .tx_key_add = nsim_assoc_add, + .tx_key_del = nsim_assoc_del, + .key_rotate = nsim_key_rotate, +}; + +static struct psp_dev_caps nsim_psp_caps = { + .versions = 1 << PSP_VERSION_HDR0_AES_GCM_128 | + 1 << PSP_VERSION_HDR0_AES_GMAC_128 | + 1 << PSP_VERSION_HDR0_AES_GCM_256 | + 1 << PSP_VERSION_HDR0_AES_GMAC_256, + .assoc_drv_spc = sizeof(void *), +}; + +void nsim_psp_uninit(struct netdevsim *ns) +{ + if (!IS_ERR(ns->psp.dev)) + psp_dev_unregister(ns->psp.dev); + WARN_ON(ns->psp.assoc_cnt); +} + +static ssize_t +nsim_psp_rereg_write(struct file *file, const char __user *data, size_t count, + loff_t *ppos) +{ + struct netdevsim *ns = file->private_data; + int err; + + nsim_psp_uninit(ns); + + ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops, + &nsim_psp_caps, ns); + err = PTR_ERR_OR_ZERO(ns->psp.dev); + return err ?: count; +} + +static const struct file_operations nsim_psp_rereg_fops = { + .open = simple_open, + .write = nsim_psp_rereg_write, + .llseek = generic_file_llseek, + .owner = THIS_MODULE, +}; + +int nsim_psp_init(struct netdevsim *ns) +{ + struct dentry *ddir = ns->nsim_dev_port->ddir; + int err; + + ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops, + &nsim_psp_caps, ns); + err = PTR_ERR_OR_ZERO(ns->psp.dev); + if (err) + return err; + + debugfs_create_file("psp_rereg", 0200, ddir, ns, &nsim_psp_rereg_fops); + return 0; +} diff --git a/net/core/skbuff.c b/net/core/skbuff.c index daaf6da43cc9..618afd59afff 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -7048,6 +7048,7 @@ void *__skb_ext_set(struct sk_buff *skb, enum skb_ext_id id, skb->active_extensions = 1 << id; return skb_ext_get_ptr(ext, id); } +EXPORT_SYMBOL_NS_GPL(__skb_ext_set, "NETDEV_INTERNAL"); /** * skb_ext_add - allocate space for given extension, COW if needed diff --git a/tools/testing/selftests/drivers/net/.gitignore b/tools/testing/selftests/drivers/net/.gitignore index d634d8395d90..585ecb4d5dc4 100644 --- a/tools/testing/selftests/drivers/net/.gitignore +++ b/tools/testing/selftests/drivers/net/.gitignore @@ -1,2 +1,3 @@ # SPDX-License-Identifier: GPL-2.0-only napi_id_helper +psp_responder diff --git a/tools/testing/selftests/drivers/net/Makefile b/tools/testing/selftests/drivers/net/Makefile index 984ece05f7f9..bd3af9a34e2f 100644 --- a/tools/testing/selftests/drivers/net/Makefile +++ b/tools/testing/selftests/drivers/net/Makefile @@ -19,6 +19,7 @@ TEST_PROGS := \ netcons_sysdata.sh \ netpoll_basic.py \ ping.py \ + psp.py \ queues.py \ stats.py \ shaper.py \ @@ -26,4 +27,13 @@ TEST_PROGS := \ xdp.py \ # end of TEST_PROGS +# YNL files, must be before "include ..lib.mk" +YNL_GEN_FILES := psp_responder +TEST_GEN_FILES += $(YNL_GEN_FILES) + include ../../lib.mk + +# YNL build +YNL_GENS := psp + +include ../../net/ynl.mk diff --git a/tools/testing/selftests/drivers/net/config b/tools/testing/selftests/drivers/net/config index f51b77cd0219..601431248d5b 100644 --- a/tools/testing/selftests/drivers/net/config +++ b/tools/testing/selftests/drivers/net/config @@ -1,6 +1,7 @@ CONFIG_CONFIGFS_FS=y CONFIG_DEBUG_INFO_BTF=y CONFIG_DEBUG_INFO_BTF_MODULES=n +CONFIG_INET_PSP=y CONFIG_IPV6=y CONFIG_NETDEVSIM=m CONFIG_NETCONSOLE=m diff --git a/tools/testing/selftests/drivers/net/hw/lib/py/__init__.py b/tools/testing/selftests/drivers/net/hw/lib/py/__init__.py index 1462a339a74b..0ceb297e7757 100644 --- a/tools/testing/selftests/drivers/net/hw/lib/py/__init__.py +++ b/tools/testing/selftests/drivers/net/hw/lib/py/__init__.py @@ -13,7 +13,7 @@ try: # Import one by one to avoid pylint false positives from net.lib.py import EthtoolFamily, NetdevFamily, NetshaperFamily, \ - NlError, RtnlFamily, DevlinkFamily + NlError, RtnlFamily, DevlinkFamily, PSPFamily from net.lib.py import CmdExitFailure from net.lib.py import bkg, cmd, defer, ethtool, fd_read_timeout, ip, \ rand_port, tool, wait_port_listen @@ -22,7 +22,7 @@ try: from net.lib.py import ksft_disruptive, ksft_exit, ksft_pr, ksft_run, \ ksft_setup from net.lib.py import ksft_eq, ksft_ge, ksft_in, ksft_is, ksft_lt, \ - ksft_ne, ksft_not_in, ksft_raises, ksft_true + ksft_ne, ksft_not_in, ksft_raises, ksft_true, ksft_gt, ksft_not_none from net.lib.py import NetNSEnter from drivers.net.lib.py import GenerateTraffic from drivers.net.lib.py import NetDrvEnv, NetDrvEpEnv diff --git a/tools/testing/selftests/drivers/net/lib/py/__init__.py b/tools/testing/selftests/drivers/net/lib/py/__init__.py index a07b56a75c8a..2a645415c4ca 100644 --- a/tools/testing/selftests/drivers/net/lib/py/__init__.py +++ b/tools/testing/selftests/drivers/net/lib/py/__init__.py @@ -12,7 +12,7 @@ try: # Import one by one to avoid pylint false positives from net.lib.py import EthtoolFamily, NetdevFamily, NetshaperFamily, \ - NlError, RtnlFamily, DevlinkFamily + NlError, RtnlFamily, DevlinkFamily, PSPFamily from net.lib.py import CmdExitFailure from net.lib.py import bkg, cmd, bpftool, bpftrace, defer, ethtool, \ fd_read_timeout, ip, rand_port, tool, wait_port_listen, wait_file @@ -21,7 +21,7 @@ try: from net.lib.py import ksft_disruptive, ksft_exit, ksft_pr, ksft_run, \ ksft_setup from net.lib.py import ksft_eq, ksft_ge, ksft_in, ksft_is, ksft_lt, \ - ksft_ne, ksft_not_in, ksft_raises, ksft_true + ksft_ne, ksft_not_in, ksft_raises, ksft_true, ksft_gt, ksft_not_none except ModuleNotFoundError as e: ksft_pr("Failed importing `net` library from kernel sources") ksft_pr(str(e)) diff --git a/tools/testing/selftests/drivers/net/lib/py/env.py b/tools/testing/selftests/drivers/net/lib/py/env.py index c1f3b608c6d8..01be3d9b9720 100644 --- a/tools/testing/selftests/drivers/net/lib/py/env.py +++ b/tools/testing/selftests/drivers/net/lib/py/env.py @@ -245,6 +245,10 @@ class NetDrvEpEnv(NetDrvEnvBase): if not self.addr_v[ipver] or not self.remote_addr_v[ipver]: raise KsftSkipEx(f"Test requires IPv{ipver} connectivity") + def require_nsim(self): + if self._ns is None: + raise KsftXfailEx("Test only works on netdevsim") + def _require_cmd(self, comm, key, host=None): cached = self._required_cmd.get(comm, {}) if cached.get(key) is None: diff --git a/tools/testing/selftests/drivers/net/psp.py b/tools/testing/selftests/drivers/net/psp.py new file mode 100755 index 000000000000..4ae7a785ff10 --- /dev/null +++ b/tools/testing/selftests/drivers/net/psp.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: GPL-2.0 + +"""Test suite for PSP capable drivers.""" + +import errno +import fcntl +import socket +import struct +import termios +import time + +from lib.py import defer +from lib.py import ksft_run, ksft_exit, ksft_pr +from lib.py import ksft_true, ksft_eq, ksft_ne, ksft_gt, ksft_raises +from lib.py import ksft_not_none +from lib.py import KsftSkipEx +from lib.py import NetDrvEpEnv, PSPFamily, NlError +from lib.py import bkg, rand_port, wait_port_listen + + +def _get_outq(s): + one = b'\0' * 4 + outq = fcntl.ioctl(s.fileno(), termios.TIOCOUTQ, one) + return struct.unpack("I", outq)[0] + + +def _send_with_ack(cfg, msg): + cfg.comm_sock.send(msg) + response = cfg.comm_sock.recv(4) + if response != b'ack\0': + raise RuntimeError("Unexpected server response", response) + + +def _remote_read_len(cfg): + cfg.comm_sock.send(b'read len\0') + return int(cfg.comm_sock.recv(1024)[:-1].decode('utf-8')) + + +def _make_clr_conn(cfg, ipver=None): + _send_with_ack(cfg, b'conn clr\0') + remote_addr = cfg.remote_addr_v[ipver] if ipver else cfg.remote_addr + s = socket.create_connection((remote_addr, cfg.comm_port), ) + return s + + +def _make_psp_conn(cfg, version=0, ipver=None): + _send_with_ack(cfg, b'conn psp\0' + struct.pack('BB', version, version)) + remote_addr = cfg.remote_addr_v[ipver] if ipver else cfg.remote_addr + s = socket.create_connection((remote_addr, cfg.comm_port), ) + return s + + +def _close_conn(cfg, s): + _send_with_ack(cfg, b'data close\0') + s.close() + + +def _close_psp_conn(cfg, s): + _close_conn(cfg, s) + + +def _spi_xchg(s, rx): + s.send(struct.pack('I', rx['spi']) + rx['key']) + tx = s.recv(4 + len(rx['key'])) + return { + 'spi': struct.unpack('I', tx[:4])[0], + 'key': tx[4:] + } + + +def _send_careful(cfg, s, rounds): + data = b'0123456789' * 200 + for i in range(rounds): + n = 0 + for _ in range(10): # allow 10 retries + try: + n += s.send(data[n:], socket.MSG_DONTWAIT) + if n == len(data): + break + except BlockingIOError: + time.sleep(0.05) + else: + rlen = _remote_read_len(cfg) + outq = _get_outq(s) + report = f'sent: {i * len(data) + n} remote len: {rlen} outq: {outq}' + raise RuntimeError(report) + + return len(data) * rounds + + +def _check_data_rx(cfg, exp_len): + read_len = -1 + for _ in range(30): + cfg.comm_sock.send(b'read len\0') + read_len = int(cfg.comm_sock.recv(1024)[:-1].decode('utf-8')) + if read_len == exp_len: + break + time.sleep(0.01) + ksft_eq(read_len, exp_len) + + +def _check_data_outq(s, exp_len, force_wait=False): + outq = 0 + for _ in range(10): + outq = _get_outq(s) + if not force_wait and outq == exp_len: + break + time.sleep(0.01) + ksft_eq(outq, exp_len) + +# +# Test case boiler plate +# + +def _init_psp_dev(cfg): + if not hasattr(cfg, 'psp_dev_id'): + # Figure out which local device we are testing against + for dev in cfg.pspnl.dev_get({}, dump=True): + if dev['ifindex'] == cfg.ifindex: + cfg.psp_info = dev + cfg.psp_dev_id = cfg.psp_info['id'] + break + else: + raise KsftSkipEx("No PSP devices found") + + # Enable PSP if necessary + cap = cfg.psp_info['psp-versions-cap'] + ena = cfg.psp_info['psp-versions-ena'] + if cap != ena: + cfg.pspnl.dev_set({'id': cfg.psp_dev_id, 'psp-versions-ena': cap}) + defer(cfg.pspnl.dev_set, {'id': cfg.psp_dev_id, + 'psp-versions-ena': ena }) + +# +# Test cases +# + +def dev_list_devices(cfg): + """ Dump all devices """ + _init_psp_dev(cfg) + + devices = cfg.pspnl.dev_get({}, dump=True) + + found = False + for dev in devices: + found |= dev['id'] == cfg.psp_dev_id + ksft_true(found) + + +def dev_get_device(cfg): + """ Get the device we intend to use """ + _init_psp_dev(cfg) + + dev = cfg.pspnl.dev_get({'id': cfg.psp_dev_id}) + ksft_eq(dev['id'], cfg.psp_dev_id) + + +def dev_get_device_bad(cfg): + """ Test getting device which doesn't exist """ + raised = False + try: + cfg.pspnl.dev_get({'id': 1234567}) + except NlError as e: + ksft_eq(e.nl_msg.error, -errno.ENODEV) + raised = True + ksft_true(raised) + + +def dev_rotate(cfg): + """ Test key rotation """ + _init_psp_dev(cfg) + + rot = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id}) + ksft_eq(rot['id'], cfg.psp_dev_id) + rot = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id}) + ksft_eq(rot['id'], cfg.psp_dev_id) + + +def dev_rotate_spi(cfg): + """ Test key rotation and SPI check """ + _init_psp_dev(cfg) + + top_a = top_b = 0 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + assoc_a = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + top_a = assoc_a['rx-key']['spi'] >> 31 + s.close() + rot = cfg.pspnl.key_rotate({"id": cfg.psp_dev_id}) + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + ksft_eq(rot['id'], cfg.psp_dev_id) + assoc_b = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + top_b = assoc_b['rx-key']['spi'] >> 31 + s.close() + ksft_ne(top_a, top_b) + + +def assoc_basic(cfg): + """ Test creating associations """ + _init_psp_dev(cfg) + + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + ksft_eq(assoc['dev-id'], cfg.psp_dev_id) + ksft_gt(assoc['rx-key']['spi'], 0) + ksft_eq(len(assoc['rx-key']['key']), 16) + + assoc = cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": 0, + "tx-key": assoc['rx-key'], + "sock-fd": s.fileno()}) + ksft_eq(len(assoc), 0) + s.close() + + +def assoc_bad_dev(cfg): + """ Test creating associations with bad device ID """ + _init_psp_dev(cfg) + + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + with ksft_raises(NlError) as cm: + cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id + 1234567, + "sock-fd": s.fileno()}) + ksft_eq(cm.exception.nl_msg.error, -errno.ENODEV) + + +def assoc_sk_only_conn(cfg): + """ Test creating associations based on socket """ + _init_psp_dev(cfg) + + with _make_clr_conn(cfg) as s: + assoc = cfg.pspnl.rx_assoc({"version": 0, + "sock-fd": s.fileno()}) + ksft_eq(assoc['dev-id'], cfg.psp_dev_id) + cfg.pspnl.tx_assoc({"version": 0, + "tx-key": assoc['rx-key'], + "sock-fd": s.fileno()}) + _close_conn(cfg, s) + + +def assoc_sk_only_mismatch(cfg): + """ Test creating associations based on socket (dev mismatch) """ + _init_psp_dev(cfg) + + with _make_clr_conn(cfg) as s: + with ksft_raises(NlError) as cm: + cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id + 1234567, + "sock-fd": s.fileno()}) + the_exception = cm.exception + ksft_eq(the_exception.nl_msg.extack['bad-attr'], ".dev-id") + ksft_eq(the_exception.nl_msg.error, -errno.EINVAL) + + +def assoc_sk_only_mismatch_tx(cfg): + """ Test creating associations based on socket (dev mismatch) """ + _init_psp_dev(cfg) + + with _make_clr_conn(cfg) as s: + with ksft_raises(NlError) as cm: + assoc = cfg.pspnl.rx_assoc({"version": 0, + "sock-fd": s.fileno()}) + cfg.pspnl.tx_assoc({"version": 0, + "tx-key": assoc['rx-key'], + "dev-id": cfg.psp_dev_id + 1234567, + "sock-fd": s.fileno()}) + the_exception = cm.exception + ksft_eq(the_exception.nl_msg.extack['bad-attr'], ".dev-id") + ksft_eq(the_exception.nl_msg.error, -errno.EINVAL) + + +def assoc_sk_only_unconn(cfg): + """ Test creating associations based on socket (unconnected, should fail) """ + _init_psp_dev(cfg) + + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + with ksft_raises(NlError) as cm: + cfg.pspnl.rx_assoc({"version": 0, + "sock-fd": s.fileno()}) + the_exception = cm.exception + ksft_eq(the_exception.nl_msg.extack['miss-type'], "dev-id") + ksft_eq(the_exception.nl_msg.error, -errno.EINVAL) + + +def assoc_version_mismatch(cfg): + """ Test creating associations where Rx and Tx PSP versions do not match """ + _init_psp_dev(cfg) + + versions = list(cfg.psp_info['psp-versions-cap']) + if len(versions) < 2: + raise KsftSkipEx("Not enough PSP versions supported by the device for the test") + + # Translate versions to integers + versions = [cfg.pspnl.consts["version"].entries[v].value for v in versions] + + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + rx = cfg.pspnl.rx_assoc({"version": versions[0], + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + + for version in versions[1:]: + with ksft_raises(NlError) as cm: + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": version, + "tx-key": rx['rx-key'], + "sock-fd": s.fileno()}) + the_exception = cm.exception + ksft_eq(the_exception.nl_msg.error, -errno.EINVAL) + + +def assoc_twice(cfg): + """ Test reusing Tx assoc for two sockets """ + _init_psp_dev(cfg) + + def rx_assoc_check(s): + assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + ksft_eq(assoc['dev-id'], cfg.psp_dev_id) + ksft_gt(assoc['rx-key']['spi'], 0) + ksft_eq(len(assoc['rx-key']['key']), 16) + + return assoc + + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + assoc = rx_assoc_check(s) + tx = cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": 0, + "tx-key": assoc['rx-key'], + "sock-fd": s.fileno()}) + ksft_eq(len(tx), 0) + + # Use the same Tx assoc second time + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s2: + rx_assoc_check(s2) + tx = cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": 0, + "tx-key": assoc['rx-key'], + "sock-fd": s2.fileno()}) + ksft_eq(len(tx), 0) + + s.close() + + +def _data_basic_send(cfg, version, ipver): + """ Test basic data send """ + _init_psp_dev(cfg) + + # Version 0 is required by spec, don't let it skip + if version: + name = cfg.pspnl.consts["version"].entries_by_val[version].name + if name not in cfg.psp_info['psp-versions-cap']: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + with ksft_raises(NlError) as cm: + cfg.pspnl.rx_assoc({"version": version, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + ksft_eq(cm.exception.nl_msg.error, -errno.EOPNOTSUPP) + raise KsftSkipEx("PSP version not supported", name) + + s = _make_psp_conn(cfg, version, ipver) + + rx_assoc = cfg.pspnl.rx_assoc({"version": version, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + rx = rx_assoc['rx-key'] + tx = _spi_xchg(s, rx) + + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": version, + "tx-key": tx, + "sock-fd": s.fileno()}) + + data_len = _send_careful(cfg, s, 100) + _check_data_rx(cfg, data_len) + _close_psp_conn(cfg, s) + + +def __bad_xfer_do(cfg, s, tx, version='hdr0-aes-gcm-128'): + # Make sure we accept the ACK for the SPI before we seal with the bad assoc + _check_data_outq(s, 0) + + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": version, + "tx-key": tx, + "sock-fd": s.fileno()}) + + data_len = _send_careful(cfg, s, 20) + _check_data_outq(s, data_len, force_wait=True) + _check_data_rx(cfg, 0) + _close_psp_conn(cfg, s) + + +def data_send_bad_key(cfg): + """ Test send data with bad key """ + _init_psp_dev(cfg) + + s = _make_psp_conn(cfg) + + rx_assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + rx = rx_assoc['rx-key'] + tx = _spi_xchg(s, rx) + tx['key'] = (tx['key'][0] ^ 0xff).to_bytes(1, 'little') + tx['key'][1:] + __bad_xfer_do(cfg, s, tx) + + +def data_send_disconnect(cfg): + """ Test socket close after sending data """ + _init_psp_dev(cfg) + + with _make_psp_conn(cfg) as s: + assoc = cfg.pspnl.rx_assoc({"version": 0, + "sock-fd": s.fileno()}) + tx = _spi_xchg(s, assoc['rx-key']) + cfg.pspnl.tx_assoc({"version": 0, + "tx-key": tx, + "sock-fd": s.fileno()}) + + data_len = _send_careful(cfg, s, 100) + _check_data_rx(cfg, data_len) + + s.shutdown(socket.SHUT_RDWR) + s.close() + + +def _data_mss_adjust(cfg, ipver): + _init_psp_dev(cfg) + + # First figure out what the MSS would be without any adjustments + s = _make_clr_conn(cfg, ipver) + s.send(b"0123456789abcdef" * 1024) + _check_data_rx(cfg, 16 * 1024) + mss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG) + _close_conn(cfg, s) + + s = _make_psp_conn(cfg, 0, ipver) + try: + rx_assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + rx = rx_assoc['rx-key'] + tx = _spi_xchg(s, rx) + + rxmss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG) + ksft_eq(mss, rxmss) + + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": 0, + "tx-key": tx, + "sock-fd": s.fileno()}) + + txmss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG) + ksft_eq(mss, txmss + 40) + + data_len = _send_careful(cfg, s, 100) + _check_data_rx(cfg, data_len) + _check_data_outq(s, 0) + + txmss = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG) + ksft_eq(mss, txmss + 40) + finally: + _close_psp_conn(cfg, s) + + +def data_stale_key(cfg): + """ Test send on a double-rotated key """ + _init_psp_dev(cfg) + + s = _make_psp_conn(cfg) + try: + rx_assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + rx = rx_assoc['rx-key'] + tx = _spi_xchg(s, rx) + + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": 0, + "tx-key": tx, + "sock-fd": s.fileno()}) + + data_len = _send_careful(cfg, s, 100) + _check_data_rx(cfg, data_len) + _check_data_outq(s, 0) + + cfg.pspnl.key_rotate({"id": cfg.psp_dev_id}) + cfg.pspnl.key_rotate({"id": cfg.psp_dev_id}) + + s.send(b'0123456789' * 200) + _check_data_outq(s, 2000, force_wait=True) + finally: + _close_psp_conn(cfg, s) + + +def __nsim_psp_rereg(cfg): + # The PSP dev ID will change, remember what was there before + before = set([x['id'] for x in cfg.pspnl.dev_get({}, dump=True)]) + + cfg._ns.nsims[0].dfs_write('psp_rereg', '1') + + after = set([x['id'] for x in cfg.pspnl.dev_get({}, dump=True)]) + + new_devs = list(after - before) + ksft_eq(len(new_devs), 1) + cfg.psp_dev_id = list(after - before)[0] + + +def removal_device_rx(cfg): + """ Test removing a netdev / PSD with active Rx assoc """ + + # We could technically devlink reload real devices, too + # but that kills the control socket. So test this on + # netdevsim only for now + cfg.require_nsim() + + s = _make_clr_conn(cfg) + try: + rx_assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + ksft_not_none(rx_assoc) + + __nsim_psp_rereg(cfg) + finally: + _close_conn(cfg, s) + + +def removal_device_bi(cfg): + """ Test removing a netdev / PSD with active Rx/Tx assoc """ + + # We could technically devlink reload real devices, too + # but that kills the control socket. So test this on + # netdevsim only for now + cfg.require_nsim() + + s = _make_clr_conn(cfg) + try: + rx_assoc = cfg.pspnl.rx_assoc({"version": 0, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": 0, + "tx-key": rx_assoc['rx-key'], + "sock-fd": s.fileno()}) + __nsim_psp_rereg(cfg) + finally: + _close_conn(cfg, s) + + +def psp_ip_ver_test_builder(name, test_func, psp_ver, ipver): + """Build test cases for each combo of PSP version and IP version""" + def test_case(cfg): + cfg.require_ipver(ipver) + test_case.__name__ = f"{name}_v{psp_ver}_ip{ipver}" + test_func(cfg, psp_ver, ipver) + return test_case + + +def ipver_test_builder(name, test_func, ipver): + """Build test cases for each IP version""" + def test_case(cfg): + cfg.require_ipver(ipver) + test_case.__name__ = f"{name}_ip{ipver}" + test_func(cfg, ipver) + return test_case + + +def main() -> None: + """ Ksft boiler plate main """ + + with NetDrvEpEnv(__file__) as cfg: + cfg.pspnl = PSPFamily() + + # Set up responder and communication sock + responder = cfg.remote.deploy("psp_responder") + + cfg.comm_port = rand_port() + srv = None + try: + with bkg(responder + f" -p {cfg.comm_port}", host=cfg.remote, + exit_wait=True) as srv: + wait_port_listen(cfg.comm_port, host=cfg.remote) + + cfg.comm_sock = socket.create_connection((cfg.remote_addr, + cfg.comm_port), + timeout=1) + + cases = [ + psp_ip_ver_test_builder( + "data_basic_send", _data_basic_send, version, ipver + ) + for version in range(0, 4) + for ipver in ("4", "6") + ] + cases += [ + ipver_test_builder("data_mss_adjust", _data_mss_adjust, ipver) + for ipver in ("4", "6") + ] + + ksft_run(cases=cases, globs=globals(), + case_pfx={"dev_", "data_", "assoc_", "removal_"}, + args=(cfg, )) + + cfg.comm_sock.send(b"exit\0") + cfg.comm_sock.close() + finally: + if srv and (srv.stdout or srv.stderr): + ksft_pr("") + ksft_pr(f"Responder logs ({srv.ret}):") + if srv and srv.stdout: + ksft_pr("STDOUT:\n# " + srv.stdout.strip().replace("\n", "\n# ")) + if srv and srv.stderr: + ksft_pr("STDERR:\n# " + srv.stderr.strip().replace("\n", "\n# ")) + ksft_exit() + + +if __name__ == "__main__": + main() diff --git a/tools/testing/selftests/drivers/net/psp_responder.c b/tools/testing/selftests/drivers/net/psp_responder.c new file mode 100644 index 000000000000..f309e0d73cbf --- /dev/null +++ b/tools/testing/selftests/drivers/net/psp_responder.c @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <stdio.h> +#include <string.h> +#include <sys/poll.h> +#include <sys/socket.h> +#include <sys/time.h> +#include <netinet/in.h> +#include <unistd.h> + +#include <ynl.h> + +#include "psp-user.h" + +#define dbg(msg...) \ +do { \ + if (opts->verbose) \ + fprintf(stderr, "DEBUG: " msg); \ +} while (0) + +static bool should_quit; + +struct opts { + int port; + int devid; + bool verbose; +}; + +enum accept_cfg { + ACCEPT_CFG_NONE = 0, + ACCEPT_CFG_CLEAR, + ACCEPT_CFG_PSP, +}; + +static struct { + unsigned char tx; + unsigned char rx; +} psp_vers; + +static int conn_setup_psp(struct ynl_sock *ys, struct opts *opts, int data_sock) +{ + struct psp_rx_assoc_rsp *rsp; + struct psp_rx_assoc_req *req; + struct psp_tx_assoc_rsp *tsp; + struct psp_tx_assoc_req *teq; + char info[300]; + int key_len; + ssize_t sz; + __u32 spi; + + dbg("create PSP connection\n"); + + // Rx assoc alloc + req = psp_rx_assoc_req_alloc(); + + psp_rx_assoc_req_set_sock_fd(req, data_sock); + psp_rx_assoc_req_set_version(req, psp_vers.rx); + + rsp = psp_rx_assoc(ys, req); + psp_rx_assoc_req_free(req); + + if (!rsp) { + perror("ERROR: failed to Rx assoc"); + return -1; + } + + // SPI exchange + key_len = rsp->rx_key._len.key; + memcpy(info, &rsp->rx_key.spi, sizeof(spi)); + memcpy(&info[sizeof(spi)], rsp->rx_key.key, key_len); + sz = sizeof(spi) + key_len; + + send(data_sock, info, sz, MSG_WAITALL); + psp_rx_assoc_rsp_free(rsp); + + sz = recv(data_sock, info, sz, MSG_WAITALL); + if (sz < 0) { + perror("ERROR: failed to read PSP key from sock"); + return -1; + } + memcpy(&spi, info, sizeof(spi)); + + // Setup Tx assoc + teq = psp_tx_assoc_req_alloc(); + + psp_tx_assoc_req_set_sock_fd(teq, data_sock); + psp_tx_assoc_req_set_version(teq, psp_vers.tx); + psp_tx_assoc_req_set_tx_key_spi(teq, spi); + psp_tx_assoc_req_set_tx_key_key(teq, &info[sizeof(spi)], key_len); + + tsp = psp_tx_assoc(ys, teq); + psp_tx_assoc_req_free(teq); + if (!tsp) { + perror("ERROR: failed to Tx assoc"); + return -1; + } + psp_tx_assoc_rsp_free(tsp); + + return 0; +} + +static void send_ack(int sock) +{ + send(sock, "ack", 4, MSG_WAITALL); +} + +static void send_err(int sock) +{ + send(sock, "err", 4, MSG_WAITALL); +} + +static void send_str(int sock, int value) +{ + char buf[128]; + int ret; + + ret = snprintf(buf, sizeof(buf), "%d", value); + send(sock, buf, ret + 1, MSG_WAITALL); +} + +static void +run_session(struct ynl_sock *ys, struct opts *opts, + int server_sock, int comm_sock) +{ + enum accept_cfg accept_cfg = ACCEPT_CFG_NONE; + struct pollfd pfds[3]; + size_t data_read = 0; + int data_sock = -1; + + while (true) { + bool race_close = false; + int nfds; + + memset(pfds, 0, sizeof(pfds)); + + pfds[0].fd = server_sock; + pfds[0].events = POLLIN; + + pfds[1].fd = comm_sock; + pfds[1].events = POLLIN; + + nfds = 2; + if (data_sock >= 0) { + pfds[2].fd = data_sock; + pfds[2].events = POLLIN; + nfds++; + } + + dbg(" ...\n"); + if (poll(pfds, nfds, -1) < 0) { + perror("poll"); + break; + } + + /* data sock */ + if (pfds[2].revents & POLLIN) { + char buf[8192]; + ssize_t n; + + n = recv(data_sock, buf, sizeof(buf), 0); + if (n <= 0) { + if (n < 0) + perror("data read"); + close(data_sock); + data_sock = -1; + dbg("data sock closed\n"); + } else { + data_read += n; + dbg("data read %zd\n", data_read); + } + } + + /* comm sock */ + if (pfds[1].revents & POLLIN) { + static char buf[4096]; + static ssize_t off; + bool consumed; + ssize_t n; + + n = recv(comm_sock, &buf[off], sizeof(buf) - off, 0); + if (n <= 0) { + if (n < 0) + perror("comm read"); + return; + } + + off += n; + n = off; + +#define __consume(sz) \ + ({ \ + if (n == (sz)) { \ + off = 0; \ + } else { \ + off -= (sz); \ + memmove(buf, &buf[(sz)], off); \ + } \ + }) + +#define cmd(_name) \ + ({ \ + ssize_t sz = sizeof(_name); \ + bool match = n >= sz && !memcmp(buf, _name, sz); \ + \ + if (match) { \ + dbg("command: " _name "\n"); \ + __consume(sz); \ + } \ + consumed |= match; \ + match; \ + }) + + do { + consumed = false; + + if (cmd("read len")) + send_str(comm_sock, data_read); + + if (cmd("data echo")) { + if (data_sock >= 0) + send(data_sock, "echo", 5, + MSG_WAITALL); + else + fprintf(stderr, "WARN: echo but no data sock\n"); + send_ack(comm_sock); + } + if (cmd("data close")) { + if (data_sock >= 0) { + close(data_sock); + data_sock = -1; + send_ack(comm_sock); + } else { + race_close = true; + } + } + if (cmd("conn psp")) { + if (accept_cfg != ACCEPT_CFG_NONE) + fprintf(stderr, "WARN: old conn config still set!\n"); + accept_cfg = ACCEPT_CFG_PSP; + send_ack(comm_sock); + /* next two bytes are versions */ + if (off >= 2) { + memcpy(&psp_vers, buf, 2); + __consume(2); + } else { + fprintf(stderr, "WARN: short conn psp command!\n"); + } + } + if (cmd("conn clr")) { + if (accept_cfg != ACCEPT_CFG_NONE) + fprintf(stderr, "WARN: old conn config still set!\n"); + accept_cfg = ACCEPT_CFG_CLEAR; + send_ack(comm_sock); + } + if (cmd("exit")) + should_quit = true; +#undef cmd + + if (!consumed) { + fprintf(stderr, "WARN: unknown cmd: [%zd] %s\n", + off, buf); + } + } while (consumed && off); + } + + /* server sock */ + if (pfds[0].revents & POLLIN) { + if (data_sock >= 0) { + fprintf(stderr, "WARN: new data sock but old one still here\n"); + close(data_sock); + data_sock = -1; + } + data_sock = accept(server_sock, NULL, NULL); + if (data_sock < 0) { + perror("accept"); + continue; + } + data_read = 0; + + if (accept_cfg == ACCEPT_CFG_CLEAR) { + dbg("new data sock: clear\n"); + /* nothing to do */ + } else if (accept_cfg == ACCEPT_CFG_PSP) { + dbg("new data sock: psp\n"); + conn_setup_psp(ys, opts, data_sock); + } else { + fprintf(stderr, "WARN: new data sock but no config\n"); + } + accept_cfg = ACCEPT_CFG_NONE; + } + + if (race_close) { + if (data_sock >= 0) { + /* indeed, ordering problem, handle the close */ + close(data_sock); + data_sock = -1; + send_ack(comm_sock); + } else { + fprintf(stderr, "WARN: close but no data sock\n"); + send_err(comm_sock); + } + } + } + dbg("session ending\n"); +} + +static int spawn_server(struct opts *opts) +{ + struct sockaddr_in6 addr; + int fd; + + fd = socket(AF_INET6, SOCK_STREAM, 0); + if (fd < 0) { + perror("can't open socket"); + return -1; + } + + memset(&addr, 0, sizeof(addr)); + + addr.sin6_family = AF_INET6; + addr.sin6_addr = in6addr_any; + addr.sin6_port = htons(opts->port); + + if (bind(fd, (struct sockaddr *)&addr, sizeof(addr))) { + perror("can't bind socket"); + return -1; + } + + if (listen(fd, 5)) { + perror("can't listen"); + return -1; + } + + return fd; +} + +static int run_responder(struct ynl_sock *ys, struct opts *opts) +{ + int server_sock, comm; + + server_sock = spawn_server(opts); + if (server_sock < 0) + return 4; + + while (!should_quit) { + comm = accept(server_sock, NULL, NULL); + if (comm < 0) { + perror("accept failed"); + } else { + run_session(ys, opts, server_sock, comm); + close(comm); + } + } + + return 0; +} + +static void usage(const char *name, const char *miss) +{ + if (miss) + fprintf(stderr, "Missing argument: %s\n", miss); + + fprintf(stderr, "Usage: %s -p port [-v] [-d psp-dev-id]\n", name); + exit(EXIT_FAILURE); +} + +static void parse_cmd_opts(int argc, char **argv, struct opts *opts) +{ + int opt; + + while ((opt = getopt(argc, argv, "vp:d:")) != -1) { + switch (opt) { + case 'v': + opts->verbose = 1; + break; + case 'p': + opts->port = atoi(optarg); + break; + case 'd': + opts->devid = atoi(optarg); + break; + default: + usage(argv[0], NULL); + } + } +} + +static int psp_dev_set_ena(struct ynl_sock *ys, __u32 dev_id, __u32 versions) +{ + struct psp_dev_set_req *sreq; + struct psp_dev_set_rsp *srsp; + + fprintf(stderr, "Set PSP enable on device %d to 0x%x\n", + dev_id, versions); + + sreq = psp_dev_set_req_alloc(); + + psp_dev_set_req_set_id(sreq, dev_id); + psp_dev_set_req_set_psp_versions_ena(sreq, versions); + + srsp = psp_dev_set(ys, sreq); + psp_dev_set_req_free(sreq); + if (!srsp) + return 10; + + psp_dev_set_rsp_free(srsp); + return 0; +} + +int main(int argc, char **argv) +{ + struct psp_dev_get_list *dev_list; + bool devid_found = false; + __u32 ver_ena, ver_cap; + struct opts opts = {}; + struct ynl_error yerr; + struct ynl_sock *ys; + int first_id = 0; + int ret; + + parse_cmd_opts(argc, argv, &opts); + if (!opts.port) + usage(argv[0], "port"); // exits + + ys = ynl_sock_create(&ynl_psp_family, &yerr); + if (!ys) { + fprintf(stderr, "YNL: %s\n", yerr.msg); + return 1; + } + + dev_list = psp_dev_get_dump(ys); + if (ynl_dump_empty(dev_list)) { + if (ys->err.code) + goto err_close; + fprintf(stderr, "No PSP devices\n"); + goto err_close_silent; + } + + ynl_dump_foreach(dev_list, d) { + if (opts.devid) { + devid_found = true; + ver_ena = d->psp_versions_ena; + ver_cap = d->psp_versions_cap; + } else if (!first_id) { + first_id = d->id; + ver_ena = d->psp_versions_ena; + ver_cap = d->psp_versions_cap; + } else { + fprintf(stderr, "Multiple PSP devices found\n"); + goto err_close_silent; + } + } + psp_dev_get_list_free(dev_list); + + if (opts.devid && !devid_found) { + fprintf(stderr, "PSP device %d requested on cmdline, not found\n", + opts.devid); + goto err_close_silent; + } else if (!opts.devid) { + opts.devid = first_id; + } + + if (ver_ena != ver_cap) { + ret = psp_dev_set_ena(ys, opts.devid, ver_cap); + if (ret) + goto err_close; + } + + ret = run_responder(ys, &opts); + + if (ver_ena != ver_cap && psp_dev_set_ena(ys, opts.devid, ver_ena)) + fprintf(stderr, "WARN: failed to set the PSP versions back\n"); + + ynl_sock_destroy(ys); + + return ret; + +err_close: + fprintf(stderr, "YNL: %s\n", ys->err.msg); +err_close_silent: + ynl_sock_destroy(ys); + return 2; +} diff --git a/tools/testing/selftests/net/lib/py/__init__.py b/tools/testing/selftests/net/lib/py/__init__.py index 02be28dcc089..997b85cc216a 100644 --- a/tools/testing/selftests/net/lib/py/__init__.py +++ b/tools/testing/selftests/net/lib/py/__init__.py @@ -6,4 +6,4 @@ from .netns import NetNS, NetNSEnter from .nsim import * from .utils import * from .ynl import NlError, YnlFamily, EthtoolFamily, NetdevFamily, RtnlFamily, RtnlAddrFamily -from .ynl import NetshaperFamily, DevlinkFamily +from .ynl import NetshaperFamily, DevlinkFamily, PSPFamily diff --git a/tools/testing/selftests/net/lib/py/ksft.py b/tools/testing/selftests/net/lib/py/ksft.py index 8e35ed12ed9e..83b1574f7719 100644 --- a/tools/testing/selftests/net/lib/py/ksft.py +++ b/tools/testing/selftests/net/lib/py/ksft.py @@ -72,6 +72,11 @@ def ksft_true(a, comment=""): _fail("Check failed", a, "does not eval to True", comment) +def ksft_not_none(a, comment=""): + if a is None: + _fail("Check failed", a, "is None", comment) + + def ksft_in(a, b, comment=""): if a not in b: _fail("Check failed", a, "not in", b, comment) @@ -92,6 +97,11 @@ def ksft_ge(a, b, comment=""): _fail("Check failed", a, "<", b, comment) +def ksft_gt(a, b, comment=""): + if a <= b: + _fail("Check failed", a, "<=", b, comment) + + def ksft_lt(a, b, comment=""): if a >= b: _fail("Check failed", a, ">=", b, comment) diff --git a/tools/testing/selftests/net/lib/py/ynl.py b/tools/testing/selftests/net/lib/py/ynl.py index 2b3a61ea3bfa..32c223e93b2c 100644 --- a/tools/testing/selftests/net/lib/py/ynl.py +++ b/tools/testing/selftests/net/lib/py/ynl.py @@ -61,3 +61,8 @@ class DevlinkFamily(YnlFamily): def __init__(self, recv_size=0): super().__init__((SPEC_PATH / Path('devlink.yaml')).as_posix(), schema='', recv_size=recv_size) + +class PSPFamily(YnlFamily): + def __init__(self, recv_size=0): + super().__init__((SPEC_PATH / Path('psp.yaml')).as_posix(), + schema='', recv_size=recv_size) |