diff options
Diffstat (limited to 'net/mctp')
| -rw-r--r-- | net/mctp/Kconfig | 24 | ||||
| -rw-r--r-- | net/mctp/Makefile | 6 | ||||
| -rw-r--r-- | net/mctp/af_mctp.c | 911 | ||||
| -rw-r--r-- | net/mctp/device.c | 561 | ||||
| -rw-r--r-- | net/mctp/neigh.c | 353 | ||||
| -rw-r--r-- | net/mctp/route.c | 1790 | ||||
| -rw-r--r-- | net/mctp/test/route-test.c | 1598 | ||||
| -rw-r--r-- | net/mctp/test/sock-test.c | 396 | ||||
| -rw-r--r-- | net/mctp/test/utils.c | 284 | ||||
| -rw-r--r-- | net/mctp/test/utils.h | 74 |
10 files changed, 5997 insertions, 0 deletions
diff --git a/net/mctp/Kconfig b/net/mctp/Kconfig new file mode 100644 index 000000000000..d8d3413a37f7 --- /dev/null +++ b/net/mctp/Kconfig @@ -0,0 +1,24 @@ + +menuconfig MCTP + depends on NET + bool "MCTP core protocol support" + help + Management Component Transport Protocol (MCTP) is an in-system + protocol for communicating between management controllers and + their managed devices (peripherals, host processors, etc.). The + protocol is defined by DMTF specification DSP0236. + + This option enables core MCTP support. For communicating with other + devices, you'll want to enable a driver for a specific hardware + channel. + +config MCTP_TEST + bool "MCTP core tests" if !KUNIT_ALL_TESTS + select MCTP_FLOWS + depends on MCTP=y && KUNIT=y + default KUNIT_ALL_TESTS + +config MCTP_FLOWS + bool + depends on MCTP + select SKB_EXTENSIONS diff --git a/net/mctp/Makefile b/net/mctp/Makefile new file mode 100644 index 000000000000..6cd55233e685 --- /dev/null +++ b/net/mctp/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_MCTP) += mctp.o +mctp-objs := af_mctp.o device.o route.o neigh.o + +# tests +obj-$(CONFIG_MCTP_TEST) += test/utils.o diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c new file mode 100644 index 000000000000..209a963112e3 --- /dev/null +++ b/net/mctp/af_mctp.c @@ -0,0 +1,911 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include <linux/compat.h> +#include <linux/if_arp.h> +#include <linux/net.h> +#include <linux/mctp.h> +#include <linux/module.h> +#include <linux/socket.h> + +#include <net/mctp.h> +#include <net/mctpdevice.h> +#include <net/sock.h> + +#define CREATE_TRACE_POINTS +#include <trace/events/mctp.h> + +/* socket implementation */ + +static void mctp_sk_expire_keys(struct timer_list *timer); + +static int mctp_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + sock->sk = NULL; + sk->sk_prot->close(sk, 0); + } + + return 0; +} + +/* Generic sockaddr checks, padding checks only so far */ +static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr) +{ + return !addr->__smctp_pad0 && !addr->__smctp_pad1; +} + +static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr) +{ + return !addr->__smctp_pad0[0] && + !addr->__smctp_pad0[1] && + !addr->__smctp_pad0[2]; +} + +static int mctp_bind(struct socket *sock, struct sockaddr_unsized *addr, int addrlen) +{ + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(&msk->sk); + struct sockaddr_mctp *smctp; + int rc; + + if (addrlen < sizeof(*smctp)) + return -EINVAL; + + if (addr->sa_family != AF_MCTP) + return -EAFNOSUPPORT; + + if (!capable(CAP_NET_BIND_SERVICE)) + return -EACCES; + + /* it's a valid sockaddr for MCTP, cast and do protocol checks */ + smctp = (struct sockaddr_mctp *)addr; + + if (!mctp_sockaddr_is_ok(smctp)) + return -EINVAL; + + lock_sock(sk); + + if (sk_hashed(sk)) { + rc = -EADDRINUSE; + goto out_release; + } + + msk->bind_local_addr = smctp->smctp_addr.s_addr; + + /* MCTP_NET_ANY with a specific EID is resolved to the default net + * at bind() time. + * For bind_addr=MCTP_ADDR_ANY it is handled specially at route + * lookup time. + */ + if (smctp->smctp_network == MCTP_NET_ANY && + msk->bind_local_addr != MCTP_ADDR_ANY) { + msk->bind_net = mctp_default_net(net); + } else { + msk->bind_net = smctp->smctp_network; + } + + /* ignore the IC bit */ + smctp->smctp_type &= 0x7f; + + if (msk->bind_peer_set) { + if (msk->bind_type != smctp->smctp_type) { + /* Prior connect() had a different type */ + rc = -EINVAL; + goto out_release; + } + + if (msk->bind_net == MCTP_NET_ANY) { + /* Restrict to the network passed to connect() */ + msk->bind_net = msk->bind_peer_net; + } + + if (msk->bind_net != msk->bind_peer_net) { + /* connect() had a different net to bind() */ + rc = -EINVAL; + goto out_release; + } + } else { + msk->bind_type = smctp->smctp_type; + } + + rc = sk->sk_prot->hash(sk); + +out_release: + release_sock(sk); + + return rc; +} + +/* Used to set a specific peer prior to bind. Not used for outbound + * connections (Tag Owner set) since MCTP is a datagram protocol. + */ +static int mctp_connect(struct socket *sock, struct sockaddr_unsized *addr, + int addrlen, int flags) +{ + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(&msk->sk); + struct sockaddr_mctp *smctp; + int rc; + + if (addrlen != sizeof(*smctp)) + return -EINVAL; + + if (addr->sa_family != AF_MCTP) + return -EAFNOSUPPORT; + + /* It's a valid sockaddr for MCTP, cast and do protocol checks */ + smctp = (struct sockaddr_mctp *)addr; + + if (!mctp_sockaddr_is_ok(smctp)) + return -EINVAL; + + /* Can't bind by tag */ + if (smctp->smctp_tag) + return -EINVAL; + + /* IC bit must be unset */ + if (smctp->smctp_type & 0x80) + return -EINVAL; + + lock_sock(sk); + + if (sk_hashed(sk)) { + /* bind() already */ + rc = -EADDRINUSE; + goto out_release; + } + + if (msk->bind_peer_set) { + /* connect() already */ + rc = -EADDRINUSE; + goto out_release; + } + + msk->bind_peer_set = true; + msk->bind_peer_addr = smctp->smctp_addr.s_addr; + msk->bind_type = smctp->smctp_type; + if (smctp->smctp_network == MCTP_NET_ANY) + msk->bind_peer_net = mctp_default_net(net); + else + msk->bind_peer_net = smctp->smctp_network; + + rc = 0; + +out_release: + release_sock(sk); + return rc; +} + +static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); + int rc, addrlen = msg->msg_namelen; + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct mctp_skb_cb *cb; + struct sk_buff *skb = NULL; + struct mctp_dst dst; + int hlen; + + if (addr) { + const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER | + MCTP_TAG_PREALLOC; + + if (addrlen < sizeof(struct sockaddr_mctp)) + return -EINVAL; + if (addr->smctp_family != AF_MCTP) + return -EINVAL; + if (!mctp_sockaddr_is_ok(addr)) + return -EINVAL; + if (addr->smctp_tag & ~tagbits) + return -EINVAL; + /* can't preallocate a non-owned tag */ + if (addr->smctp_tag & MCTP_TAG_PREALLOC && + !(addr->smctp_tag & MCTP_TAG_OWNER)) + return -EINVAL; + + } else { + /* TODO: connect()ed sockets */ + return -EDESTADDRREQ; + } + + if (!capable(CAP_NET_RAW)) + return -EACCES; + + if (addr->smctp_network == MCTP_NET_ANY) + addr->smctp_network = mctp_default_net(sock_net(sk)); + + /* direct addressing */ + if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) { + DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, + extaddr, msg->msg_name); + + if (!mctp_sockaddr_ext_is_ok(extaddr)) + return -EINVAL; + + rc = mctp_dst_from_extaddr(&dst, sock_net(sk), + extaddr->smctp_ifindex, + extaddr->smctp_halen, + extaddr->smctp_haddr); + if (rc) + return rc; + + } else { + rc = mctp_route_lookup(sock_net(sk), addr->smctp_network, + addr->smctp_addr.s_addr, &dst); + if (rc) + return rc; + } + + hlen = LL_RESERVED_SPACE(dst.dev->dev) + sizeof(struct mctp_hdr); + + skb = sock_alloc_send_skb(sk, hlen + 1 + len, + msg->msg_flags & MSG_DONTWAIT, &rc); + if (!skb) + goto err_release_dst; + + skb_reserve(skb, hlen); + + /* set type as first byte in payload */ + *(u8 *)skb_put(skb, 1) = addr->smctp_type; + + rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len); + if (rc < 0) + goto err_free; + + /* set up cb */ + cb = __mctp_cb(skb); + cb->net = addr->smctp_network; + + rc = mctp_local_output(sk, &dst, skb, addr->smctp_addr.s_addr, + addr->smctp_tag); + + mctp_dst_release(&dst); + return rc ? : len; + +err_free: + kfree_skb(skb); +err_release_dst: + mctp_dst_release(&dst); + return rc; +} + +static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct sk_buff *skb; + size_t msglen; + u8 type; + int rc; + + if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK)) + return -EOPNOTSUPP; + + skb = skb_recv_datagram(sk, flags, &rc); + if (!skb) + return rc; + + if (!skb->len) { + rc = 0; + goto out_free; + } + + /* extract message type, remove from data */ + type = *((u8 *)skb->data); + msglen = skb->len - 1; + + if (len < msglen) + msg->msg_flags |= MSG_TRUNC; + else + len = msglen; + + rc = skb_copy_datagram_msg(skb, 1, msg, len); + if (rc < 0) + goto out_free; + + sock_recv_cmsgs(msg, sk, skb); + + if (addr) { + struct mctp_skb_cb *cb = mctp_cb(skb); + /* TODO: expand mctp_skb_cb for header fields? */ + struct mctp_hdr *hdr = mctp_hdr(skb); + + addr = msg->msg_name; + addr->smctp_family = AF_MCTP; + addr->__smctp_pad0 = 0; + addr->smctp_network = cb->net; + addr->smctp_addr.s_addr = hdr->src; + addr->smctp_type = type; + addr->smctp_tag = hdr->flags_seq_tag & + (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + addr->__smctp_pad1 = 0; + msg->msg_namelen = sizeof(*addr); + + if (msk->addr_ext) { + DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae, + msg->msg_name); + msg->msg_namelen = sizeof(*ae); + ae->smctp_ifindex = cb->ifindex; + ae->smctp_halen = cb->halen; + memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0)); + memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr)); + memcpy(ae->smctp_haddr, cb->haddr, cb->halen); + } + } + + rc = len; + + if (flags & MSG_TRUNC) + rc = msglen; + +out_free: + skb_free_datagram(sk, skb); + return rc; +} + +/* We're done with the key; invalidate, stop reassembly, and remove from lists. + */ +static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net, + unsigned long flags, unsigned long reason) +__releases(&key->lock) +__must_hold(&net->mctp.keys_lock) +{ + struct sk_buff *skb; + + trace_mctp_key_release(key, reason); + skb = key->reasm_head; + key->reasm_head = NULL; + key->reasm_dead = true; + key->valid = false; + mctp_dev_release_key(key->dev, key); + spin_unlock_irqrestore(&key->lock, flags); + + if (!hlist_unhashed(&key->hlist)) { + hlist_del_init(&key->hlist); + hlist_del_init(&key->sklist); + /* unref for the lists */ + mctp_key_unref(key); + } + + kfree_skb(skb); +} + +static int mctp_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); + int val; + + if (level != SOL_MCTP) + return -EINVAL; + + if (optname == MCTP_OPT_ADDR_EXT) { + if (optlen != sizeof(int)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + msk->addr_ext = val; + return 0; + } + + return -ENOPROTOOPT; +} + +static int mctp_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); + int len, val; + + if (level != SOL_MCTP) + return -EINVAL; + + if (get_user(len, optlen)) + return -EFAULT; + + if (optname == MCTP_OPT_ADDR_EXT) { + if (len != sizeof(int)) + return -EINVAL; + val = !!msk->addr_ext; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; + } + + return -ENOPROTOOPT; +} + +/* helpers for reading/writing the tag ioc, handling compatibility across the + * two versions, and some basic API error checking + */ +static int mctp_ioctl_tag_copy_from_user(unsigned long arg, + struct mctp_ioc_tag_ctl2 *ctl, + bool tagv2) +{ + struct mctp_ioc_tag_ctl ctl_compat; + unsigned long size; + void *ptr; + int rc; + + if (tagv2) { + size = sizeof(*ctl); + ptr = ctl; + } else { + size = sizeof(ctl_compat); + ptr = &ctl_compat; + } + + rc = copy_from_user(ptr, (void __user *)arg, size); + if (rc) + return -EFAULT; + + if (!tagv2) { + /* compat, using defaults for new fields */ + ctl->net = MCTP_INITIAL_DEFAULT_NET; + ctl->peer_addr = ctl_compat.peer_addr; + ctl->local_addr = MCTP_ADDR_ANY; + ctl->flags = ctl_compat.flags; + ctl->tag = ctl_compat.tag; + } + + if (ctl->flags) + return -EINVAL; + + if (ctl->local_addr != MCTP_ADDR_ANY && + ctl->local_addr != MCTP_ADDR_NULL) + return -EINVAL; + + return 0; +} + +static int mctp_ioctl_tag_copy_to_user(unsigned long arg, + struct mctp_ioc_tag_ctl2 *ctl, + bool tagv2) +{ + struct mctp_ioc_tag_ctl ctl_compat; + unsigned long size; + void *ptr; + int rc; + + if (tagv2) { + ptr = ctl; + size = sizeof(*ctl); + } else { + ctl_compat.peer_addr = ctl->peer_addr; + ctl_compat.tag = ctl->tag; + ctl_compat.flags = ctl->flags; + + ptr = &ctl_compat; + size = sizeof(ctl_compat); + } + + rc = copy_to_user((void __user *)arg, ptr, size); + if (rc) + return -EFAULT; + + return 0; +} + +static int mctp_ioctl_alloctag(struct mctp_sock *msk, bool tagv2, + unsigned long arg) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_sk_key *key = NULL; + struct mctp_ioc_tag_ctl2 ctl; + unsigned long flags; + u8 tag; + int rc; + + rc = mctp_ioctl_tag_copy_from_user(arg, &ctl, tagv2); + if (rc) + return rc; + + if (ctl.tag) + return -EINVAL; + + key = mctp_alloc_local_tag(msk, ctl.net, MCTP_ADDR_ANY, + ctl.peer_addr, true, &tag); + if (IS_ERR(key)) + return PTR_ERR(key); + + ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC; + rc = mctp_ioctl_tag_copy_to_user(arg, &ctl, tagv2); + if (rc) { + unsigned long fl2; + /* Unwind our key allocation: the keys list lock needs to be + * taken before the individual key locks, and we need a valid + * flags value (fl2) to pass to __mctp_key_remove, hence the + * second spin_lock_irqsave() rather than a plain spin_lock(). + */ + spin_lock_irqsave(&net->mctp.keys_lock, flags); + spin_lock_irqsave(&key->lock, fl2); + __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_DROPPED); + mctp_key_unref(key); + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + return rc; + } + + mctp_key_unref(key); + return 0; +} + +static int mctp_ioctl_droptag(struct mctp_sock *msk, bool tagv2, + unsigned long arg) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_ioc_tag_ctl2 ctl; + unsigned long flags, fl2; + struct mctp_sk_key *key; + struct hlist_node *tmp; + int rc; + u8 tag; + + rc = mctp_ioctl_tag_copy_from_user(arg, &ctl, tagv2); + if (rc) + return rc; + + /* Must be a local tag, TO set, preallocated */ + if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC)) + return -EINVAL; + + tag = ctl.tag & MCTP_TAG_MASK; + rc = -EINVAL; + + if (ctl.peer_addr == MCTP_ADDR_NULL) + ctl.peer_addr = MCTP_ADDR_ANY; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + /* we do an irqsave here, even though we know the irq state, + * so we have the flags to pass to __mctp_key_remove + */ + spin_lock_irqsave(&key->lock, fl2); + if (key->manual_alloc && + ctl.net == key->net && + ctl.peer_addr == key->peer_addr && + tag == key->tag) { + __mctp_key_remove(key, net, fl2, + MCTP_TRACE_KEY_DROPPED); + rc = 0; + } else { + spin_unlock_irqrestore(&key->lock, fl2); + } + } + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + return rc; +} + +static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); + bool tagv2 = false; + + switch (cmd) { + case SIOCMCTPALLOCTAG2: + case SIOCMCTPALLOCTAG: + tagv2 = cmd == SIOCMCTPALLOCTAG2; + return mctp_ioctl_alloctag(msk, tagv2, arg); + case SIOCMCTPDROPTAG: + case SIOCMCTPDROPTAG2: + tagv2 = cmd == SIOCMCTPDROPTAG2; + return mctp_ioctl_droptag(msk, tagv2, arg); + } + + return -EINVAL; +} + +#ifdef CONFIG_COMPAT +static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + + switch (cmd) { + /* These have compatible ptr layouts */ + case SIOCMCTPALLOCTAG: + case SIOCMCTPDROPTAG: + return mctp_ioctl(sock, cmd, (unsigned long)argp); + } + + return -ENOIOCTLCMD; +} +#endif + +static const struct proto_ops mctp_dgram_ops = { + .family = PF_MCTP, + .release = mctp_release, + .bind = mctp_bind, + .connect = mctp_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = mctp_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = mctp_setsockopt, + .getsockopt = mctp_getsockopt, + .sendmsg = mctp_sendmsg, + .recvmsg = mctp_recvmsg, + .mmap = sock_no_mmap, +#ifdef CONFIG_COMPAT + .compat_ioctl = mctp_compat_ioctl, +#endif +}; + +static void mctp_sk_expire_keys(struct timer_list *timer) +{ + struct mctp_sock *msk = container_of(timer, struct mctp_sock, + key_expiry); + struct net *net = sock_net(&msk->sk); + unsigned long next_expiry, flags, fl2; + struct mctp_sk_key *key; + struct hlist_node *tmp; + bool next_expiry_valid = false; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + /* don't expire. manual_alloc is immutable, no locking + * required. + */ + if (key->manual_alloc) + continue; + + spin_lock_irqsave(&key->lock, fl2); + if (!time_after_eq(key->expiry, jiffies)) { + __mctp_key_remove(key, net, fl2, + MCTP_TRACE_KEY_TIMEOUT); + continue; + } + + if (next_expiry_valid) { + if (time_before(key->expiry, next_expiry)) + next_expiry = key->expiry; + } else { + next_expiry = key->expiry; + next_expiry_valid = true; + } + spin_unlock_irqrestore(&key->lock, fl2); + } + + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + if (next_expiry_valid) + mod_timer(timer, next_expiry); +} + +static int mctp_sk_init(struct sock *sk) +{ + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + INIT_HLIST_HEAD(&msk->keys); + timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); + msk->bind_peer_set = false; + return 0; +} + +static void mctp_sk_close(struct sock *sk, long timeout) +{ + sk_common_release(sk); +} + +static int mctp_sk_hash(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct sock *existing; + struct mctp_sock *msk; + mctp_eid_t remote; + u32 hash; + int rc; + + msk = container_of(sk, struct mctp_sock, sk); + + if (msk->bind_peer_set) + remote = msk->bind_peer_addr; + else + remote = MCTP_ADDR_ANY; + hash = mctp_bind_hash(msk->bind_type, msk->bind_local_addr, remote); + + mutex_lock(&net->mctp.bind_lock); + + /* Prevent duplicate binds. */ + sk_for_each(existing, &net->mctp.binds[hash]) { + struct mctp_sock *mex = + container_of(existing, struct mctp_sock, sk); + + bool same_peer = (mex->bind_peer_set && msk->bind_peer_set && + mex->bind_peer_addr == msk->bind_peer_addr) || + (!mex->bind_peer_set && !msk->bind_peer_set); + + if (mex->bind_type == msk->bind_type && + mex->bind_local_addr == msk->bind_local_addr && same_peer && + mex->bind_net == msk->bind_net) { + rc = -EADDRINUSE; + goto out; + } + } + + /* Bind lookup runs under RCU, remain live during that. */ + sock_set_flag(sk, SOCK_RCU_FREE); + + sk_add_node_rcu(sk, &net->mctp.binds[hash]); + rc = 0; + +out: + mutex_unlock(&net->mctp.bind_lock); + return rc; +} + +static void mctp_sk_unhash(struct sock *sk) +{ + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(sk); + unsigned long flags, fl2; + struct mctp_sk_key *key; + struct hlist_node *tmp; + + /* remove from any type-based binds */ + mutex_lock(&net->mctp.bind_lock); + sk_del_node_init_rcu(sk); + mutex_unlock(&net->mctp.bind_lock); + + /* remove tag allocations */ + spin_lock_irqsave(&net->mctp.keys_lock, flags); + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + spin_lock_irqsave(&key->lock, fl2); + __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED); + } + sock_set_flag(sk, SOCK_DEAD); + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + /* Since there are no more tag allocations (we have removed all of the + * keys), stop any pending expiry events. the timer cannot be re-queued + * as the sk is no longer observable + */ + timer_delete_sync(&msk->key_expiry); +} + +static void mctp_sk_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_receive_queue); +} + +static struct proto mctp_proto = { + .name = "MCTP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct mctp_sock), + .init = mctp_sk_init, + .close = mctp_sk_close, + .hash = mctp_sk_hash, + .unhash = mctp_sk_unhash, +}; + +static int mctp_pf_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + const struct proto_ops *ops; + struct proto *proto; + struct sock *sk; + int rc; + + if (protocol) + return -EPROTONOSUPPORT; + + /* only datagram sockets are supported */ + if (sock->type != SOCK_DGRAM) + return -ESOCKTNOSUPPORT; + + proto = &mctp_proto; + ops = &mctp_dgram_ops; + + sock->state = SS_UNCONNECTED; + sock->ops = ops; + + sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + sk->sk_destruct = mctp_sk_destruct; + + rc = 0; + if (sk->sk_prot->init) + rc = sk->sk_prot->init(sk); + + if (rc) + goto err_sk_put; + + return 0; + +err_sk_put: + sock_orphan(sk); + sock_put(sk); + return rc; +} + +static struct net_proto_family mctp_pf = { + .family = PF_MCTP, + .create = mctp_pf_create, + .owner = THIS_MODULE, +}; + +static __init int mctp_init(void) +{ + int rc; + + /* ensure our uapi tag definitions match the header format */ + BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO); + BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK); + + pr_info("mctp: management component transport protocol core\n"); + + rc = sock_register(&mctp_pf); + if (rc) + return rc; + + rc = proto_register(&mctp_proto, 0); + if (rc) + goto err_unreg_sock; + + rc = mctp_routes_init(); + if (rc) + goto err_unreg_proto; + + rc = mctp_neigh_init(); + if (rc) + goto err_unreg_routes; + + rc = mctp_device_init(); + if (rc) + goto err_unreg_neigh; + + return 0; + +err_unreg_neigh: + mctp_neigh_exit(); +err_unreg_routes: + mctp_routes_exit(); +err_unreg_proto: + proto_unregister(&mctp_proto); +err_unreg_sock: + sock_unregister(PF_MCTP); + + return rc; +} + +static __exit void mctp_exit(void) +{ + mctp_device_exit(); + mctp_neigh_exit(); + mctp_routes_exit(); + proto_unregister(&mctp_proto); + sock_unregister(PF_MCTP); +} + +subsys_initcall(mctp_init); +module_exit(mctp_exit); + +MODULE_DESCRIPTION("MCTP core"); +MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>"); + +MODULE_ALIAS_NETPROTO(PF_MCTP); + +#if IS_ENABLED(CONFIG_MCTP_TEST) +#include "test/sock-test.c" +#endif diff --git a/net/mctp/device.c b/net/mctp/device.c new file mode 100644 index 000000000000..4d404edd7446 --- /dev/null +++ b/net/mctp/device.c @@ -0,0 +1,561 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) - device implementation. + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include <linux/if_arp.h> +#include <linux/if_link.h> +#include <linux/mctp.h> +#include <linux/netdevice.h> +#include <linux/rcupdate.h> +#include <linux/rtnetlink.h> + +#include <net/addrconf.h> +#include <net/netlink.h> +#include <net/mctp.h> +#include <net/mctpdevice.h> +#include <net/sock.h> + +struct mctp_dump_cb { + unsigned long ifindex; + size_t a_idx; +}; + +/* unlocked: caller must hold rcu_read_lock. + * Returned mctp_dev has its refcount incremented, or NULL if unset. + */ +struct mctp_dev *__mctp_dev_get(const struct net_device *dev) +{ + struct mctp_dev *mdev = rcu_dereference(dev->mctp_ptr); + + /* RCU guarantees that any mdev is still live. + * Zero refcount implies a pending free, return NULL. + */ + if (mdev) + if (!refcount_inc_not_zero(&mdev->refs)) + return NULL; + return mdev; +} + +/* Returned mctp_dev does not have refcount incremented. The returned pointer + * remains live while rtnl_lock is held, as that prevents mctp_unregister() + */ +struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev) +{ + return rtnl_dereference(dev->mctp_ptr); +} + +static int mctp_addrinfo_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(1) // IFA_LOCAL + + nla_total_size(1) // IFA_ADDRESS + ; +} + +/* flag should be NLM_F_MULTI for dump calls */ +static int mctp_fill_addrinfo(struct sk_buff *skb, + struct mctp_dev *mdev, mctp_eid_t eid, + int msg_type, u32 portid, u32 seq, int flag) +{ + struct ifaddrmsg *hdr; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, + msg_type, sizeof(*hdr), flag); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->ifa_family = AF_MCTP; + hdr->ifa_prefixlen = 0; + hdr->ifa_flags = 0; + hdr->ifa_scope = 0; + hdr->ifa_index = mdev->dev->ifindex; + + if (nla_put_u8(skb, IFA_LOCAL, eid)) + goto cancel; + + if (nla_put_u8(skb, IFA_ADDRESS, eid)) + goto cancel; + + nlmsg_end(skb, nlh); + + return 0; + +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mctp_dump_dev_addrinfo(struct mctp_dev *mdev, struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct mctp_dump_cb *mcb = (void *)cb->ctx; + u32 portid, seq; + int rc = 0; + + portid = NETLINK_CB(cb->skb).portid; + seq = cb->nlh->nlmsg_seq; + for (; mcb->a_idx < mdev->num_addrs; mcb->a_idx++) { + rc = mctp_fill_addrinfo(skb, mdev, mdev->addrs[mcb->a_idx], + RTM_NEWADDR, portid, seq, NLM_F_MULTI); + if (rc < 0) + break; + } + + return rc; +} + +static int mctp_dump_addrinfo(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct mctp_dump_cb *mcb = (void *)cb->ctx; + struct net *net = sock_net(skb->sk); + struct net_device *dev; + struct ifaddrmsg *hdr; + struct mctp_dev *mdev; + int ifindex = 0, rc; + + /* Filter by ifindex if a header is provided */ + hdr = nlmsg_payload(cb->nlh, sizeof(*hdr)); + if (hdr) { + ifindex = hdr->ifa_index; + } else { + if (cb->strict_check) { + NL_SET_ERR_MSG(cb->extack, "mctp: Invalid header for addr dump request"); + return -EINVAL; + } + } + + rcu_read_lock(); + for_each_netdev_dump(net, dev, mcb->ifindex) { + if (ifindex && ifindex != dev->ifindex) + continue; + mdev = __mctp_dev_get(dev); + if (!mdev) + continue; + rc = mctp_dump_dev_addrinfo(mdev, skb, cb); + mctp_dev_put(mdev); + if (rc < 0) + break; + mcb->a_idx = 0; + } + rcu_read_unlock(); + + return skb->len; +} + +static void mctp_addr_notify(struct mctp_dev *mdev, mctp_eid_t eid, int msg_type, + struct sk_buff *req_skb, struct nlmsghdr *req_nlh) +{ + u32 portid = NETLINK_CB(req_skb).portid; + struct net *net = dev_net(mdev->dev); + struct sk_buff *skb; + int rc = -ENOBUFS; + + skb = nlmsg_new(mctp_addrinfo_size(), GFP_KERNEL); + if (!skb) + goto out; + + rc = mctp_fill_addrinfo(skb, mdev, eid, msg_type, + portid, req_nlh->nlmsg_seq, 0); + if (rc < 0) { + WARN_ON_ONCE(rc == -EMSGSIZE); + goto out; + } + + rtnl_notify(skb, net, portid, RTNLGRP_MCTP_IFADDR, req_nlh, GFP_KERNEL); + return; +out: + kfree_skb(skb); + rtnl_set_sk_err(net, RTNLGRP_MCTP_IFADDR, rc); +} + +static const struct nla_policy ifa_mctp_policy[IFA_MAX + 1] = { + [IFA_ADDRESS] = { .type = NLA_U8 }, + [IFA_LOCAL] = { .type = NLA_U8 }, +}; + +static int mctp_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[IFA_MAX + 1]; + struct net_device *dev; + struct mctp_addr *addr; + struct mctp_dev *mdev; + struct ifaddrmsg *ifm; + unsigned long flags; + u8 *tmp_addrs; + int rc; + + rc = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_mctp_policy, + extack); + if (rc < 0) + return rc; + + ifm = nlmsg_data(nlh); + + if (tb[IFA_LOCAL]) + addr = nla_data(tb[IFA_LOCAL]); + else if (tb[IFA_ADDRESS]) + addr = nla_data(tb[IFA_ADDRESS]); + else + return -EINVAL; + + /* find device */ + dev = __dev_get_by_index(net, ifm->ifa_index); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + if (!mctp_address_unicast(addr->s_addr)) + return -EINVAL; + + /* Prevent duplicates. Under RTNL so don't need to lock for reading */ + if (memchr(mdev->addrs, addr->s_addr, mdev->num_addrs)) + return -EEXIST; + + tmp_addrs = kmalloc(mdev->num_addrs + 1, GFP_KERNEL); + if (!tmp_addrs) + return -ENOMEM; + memcpy(tmp_addrs, mdev->addrs, mdev->num_addrs); + tmp_addrs[mdev->num_addrs] = addr->s_addr; + + /* Lock to write */ + spin_lock_irqsave(&mdev->addrs_lock, flags); + mdev->num_addrs++; + swap(mdev->addrs, tmp_addrs); + spin_unlock_irqrestore(&mdev->addrs_lock, flags); + + kfree(tmp_addrs); + + mctp_addr_notify(mdev, addr->s_addr, RTM_NEWADDR, skb, nlh); + mctp_route_add_local(mdev, addr->s_addr); + + return 0; +} + +static int mctp_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[IFA_MAX + 1]; + struct net_device *dev; + struct mctp_addr *addr; + struct mctp_dev *mdev; + struct ifaddrmsg *ifm; + unsigned long flags; + u8 *pos; + int rc; + + rc = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_mctp_policy, + extack); + if (rc < 0) + return rc; + + ifm = nlmsg_data(nlh); + + if (tb[IFA_LOCAL]) + addr = nla_data(tb[IFA_LOCAL]); + else if (tb[IFA_ADDRESS]) + addr = nla_data(tb[IFA_ADDRESS]); + else + return -EINVAL; + + /* find device */ + dev = __dev_get_by_index(net, ifm->ifa_index); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + pos = memchr(mdev->addrs, addr->s_addr, mdev->num_addrs); + if (!pos) + return -ENOENT; + + rc = mctp_route_remove_local(mdev, addr->s_addr); + // we can ignore -ENOENT in the case a route was already removed + if (rc < 0 && rc != -ENOENT) + return rc; + + spin_lock_irqsave(&mdev->addrs_lock, flags); + memmove(pos, pos + 1, mdev->num_addrs - 1 - (pos - mdev->addrs)); + mdev->num_addrs--; + spin_unlock_irqrestore(&mdev->addrs_lock, flags); + + mctp_addr_notify(mdev, addr->s_addr, RTM_DELADDR, skb, nlh); + + return 0; +} + +void mctp_dev_hold(struct mctp_dev *mdev) +{ + refcount_inc(&mdev->refs); +} + +void mctp_dev_put(struct mctp_dev *mdev) +{ + if (mdev && refcount_dec_and_test(&mdev->refs)) { + kfree(mdev->addrs); + dev_put(mdev->dev); + kfree_rcu(mdev, rcu); + } +} + +void mctp_dev_release_key(struct mctp_dev *dev, struct mctp_sk_key *key) + __must_hold(&key->lock) +{ + if (!dev) + return; + if (dev->ops && dev->ops->release_flow) + dev->ops->release_flow(dev, key); + key->dev = NULL; + mctp_dev_put(dev); +} + +void mctp_dev_set_key(struct mctp_dev *dev, struct mctp_sk_key *key) + __must_hold(&key->lock) +{ + mctp_dev_hold(dev); + key->dev = dev; +} + +static struct mctp_dev *mctp_add_dev(struct net_device *dev) +{ + struct mctp_dev *mdev; + + ASSERT_RTNL(); + + mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); + if (!mdev) + return ERR_PTR(-ENOMEM); + + spin_lock_init(&mdev->addrs_lock); + + mdev->net = mctp_default_net(dev_net(dev)); + + /* associate to net_device */ + refcount_set(&mdev->refs, 1); + rcu_assign_pointer(dev->mctp_ptr, mdev); + + dev_hold(dev); + mdev->dev = dev; + + return mdev; +} + +static int mctp_fill_link_af(struct sk_buff *skb, + const struct net_device *dev, u32 ext_filter_mask) +{ + struct mctp_dev *mdev; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODATA; + if (nla_put_u32(skb, IFLA_MCTP_NET, mdev->net)) + return -EMSGSIZE; + if (nla_put_u8(skb, IFLA_MCTP_PHYS_BINDING, mdev->binding)) + return -EMSGSIZE; + return 0; +} + +static size_t mctp_get_link_af_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + struct mctp_dev *mdev; + unsigned int ret; + + /* caller holds RCU */ + mdev = __mctp_dev_get(dev); + if (!mdev) + return 0; + ret = nla_total_size(4); /* IFLA_MCTP_NET */ + ret += nla_total_size(1); /* IFLA_MCTP_PHYS_BINDING */ + mctp_dev_put(mdev); + return ret; +} + +static const struct nla_policy ifla_af_mctp_policy[IFLA_MCTP_MAX + 1] = { + [IFLA_MCTP_NET] = { .type = NLA_U32 }, +}; + +static int mctp_set_link_af(struct net_device *dev, const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_MCTP_MAX + 1]; + struct mctp_dev *mdev; + int rc; + + rc = nla_parse_nested(tb, IFLA_MCTP_MAX, attr, ifla_af_mctp_policy, + NULL); + if (rc) + return rc; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return 0; + + if (tb[IFLA_MCTP_NET]) + WRITE_ONCE(mdev->net, nla_get_u32(tb[IFLA_MCTP_NET])); + + return 0; +} + +/* Matches netdev types that should have MCTP handling */ +static bool mctp_known(struct net_device *dev) +{ + /* only register specific types (inc. NONE for TUN devices) */ + return dev->type == ARPHRD_MCTP || + dev->type == ARPHRD_LOOPBACK || + dev->type == ARPHRD_NONE; +} + +static void mctp_unregister(struct net_device *dev) +{ + struct mctp_dev *mdev; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return; + + RCU_INIT_POINTER(mdev->dev->mctp_ptr, NULL); + + mctp_route_remove_dev(mdev); + mctp_neigh_remove_dev(mdev); + + mctp_dev_put(mdev); +} + +static int mctp_register(struct net_device *dev) +{ + struct mctp_dev *mdev; + + /* Already registered? */ + if (rtnl_dereference(dev->mctp_ptr)) + return 0; + + /* only register specific types */ + if (!mctp_known(dev)) + return 0; + + mdev = mctp_add_dev(dev); + if (IS_ERR(mdev)) + return PTR_ERR(mdev); + + return 0; +} + +static int mctp_dev_notify(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + int rc; + + switch (event) { + case NETDEV_REGISTER: + rc = mctp_register(dev); + if (rc) + return notifier_from_errno(rc); + break; + case NETDEV_UNREGISTER: + mctp_unregister(dev); + break; + } + + return NOTIFY_OK; +} + +static int mctp_register_netdevice(struct net_device *dev, + const struct mctp_netdev_ops *ops, + enum mctp_phys_binding binding) +{ + struct mctp_dev *mdev; + + mdev = mctp_add_dev(dev); + if (IS_ERR(mdev)) + return PTR_ERR(mdev); + + mdev->ops = ops; + mdev->binding = binding; + + return register_netdevice(dev); +} + +int mctp_register_netdev(struct net_device *dev, + const struct mctp_netdev_ops *ops, + enum mctp_phys_binding binding) +{ + int rc; + + rtnl_lock(); + rc = mctp_register_netdevice(dev, ops, binding); + rtnl_unlock(); + + return rc; +} +EXPORT_SYMBOL_GPL(mctp_register_netdev); + +void mctp_unregister_netdev(struct net_device *dev) +{ + unregister_netdev(dev); +} +EXPORT_SYMBOL_GPL(mctp_unregister_netdev); + +static struct rtnl_af_ops mctp_af_ops = { + .family = AF_MCTP, + .fill_link_af = mctp_fill_link_af, + .get_link_af_size = mctp_get_link_af_size, + .set_link_af = mctp_set_link_af, +}; + +static struct notifier_block mctp_dev_nb = { + .notifier_call = mctp_dev_notify, + .priority = ADDRCONF_NOTIFY_PRIORITY, +}; + +static const struct rtnl_msg_handler mctp_device_rtnl_msg_handlers[] = { + {.owner = THIS_MODULE, .protocol = PF_MCTP, .msgtype = RTM_NEWADDR, + .doit = mctp_rtm_newaddr}, + {.owner = THIS_MODULE, .protocol = PF_MCTP, .msgtype = RTM_DELADDR, + .doit = mctp_rtm_deladdr}, + {.owner = THIS_MODULE, .protocol = PF_MCTP, .msgtype = RTM_GETADDR, + .dumpit = mctp_dump_addrinfo}, +}; + +int __init mctp_device_init(void) +{ + int err; + + register_netdevice_notifier(&mctp_dev_nb); + + err = rtnl_af_register(&mctp_af_ops); + if (err) + goto err_notifier; + + err = rtnl_register_many(mctp_device_rtnl_msg_handlers); + if (err) + goto err_af; + + return 0; +err_af: + rtnl_af_unregister(&mctp_af_ops); +err_notifier: + unregister_netdevice_notifier(&mctp_dev_nb); + return err; +} + +void __exit mctp_device_exit(void) +{ + rtnl_unregister_many(mctp_device_rtnl_msg_handlers); + rtnl_af_unregister(&mctp_af_ops); + unregister_netdevice_notifier(&mctp_dev_nb); +} diff --git a/net/mctp/neigh.c b/net/mctp/neigh.c new file mode 100644 index 000000000000..05b899f22d90 --- /dev/null +++ b/net/mctp/neigh.c @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) - routing + * implementation. + * + * This is currently based on a simple routing table, with no dst cache. The + * number of routes should stay fairly small, so the lookup cost is small. + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include <linux/idr.h> +#include <linux/mctp.h> +#include <linux/netdevice.h> +#include <linux/rtnetlink.h> +#include <linux/skbuff.h> + +#include <net/mctp.h> +#include <net/mctpdevice.h> +#include <net/netlink.h> +#include <net/sock.h> + +static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid, + enum mctp_neigh_source source, + size_t lladdr_len, const void *lladdr) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh; + int rc; + + mutex_lock(&net->mctp.neigh_lock); + if (mctp_neigh_lookup(mdev, eid, NULL) == 0) { + rc = -EEXIST; + goto out; + } + + if (lladdr_len > sizeof(neigh->ha)) { + rc = -EINVAL; + goto out; + } + + neigh = kzalloc(sizeof(*neigh), GFP_KERNEL); + if (!neigh) { + rc = -ENOMEM; + goto out; + } + INIT_LIST_HEAD(&neigh->list); + neigh->dev = mdev; + mctp_dev_hold(neigh->dev); + neigh->eid = eid; + neigh->source = source; + memcpy(neigh->ha, lladdr, lladdr_len); + + list_add_rcu(&neigh->list, &net->mctp.neighbours); + rc = 0; +out: + mutex_unlock(&net->mctp.neigh_lock); + return rc; +} + +static void __mctp_neigh_free(struct rcu_head *rcu) +{ + struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu); + + mctp_dev_put(neigh->dev); + kfree(neigh); +} + +/* Removes all neighbour entries referring to a device */ +void mctp_neigh_remove_dev(struct mctp_dev *mdev) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh, *tmp; + + mutex_lock(&net->mctp.neigh_lock); + list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) { + if (neigh->dev == mdev) { + list_del_rcu(&neigh->list); + /* TODO: immediate RTM_DELNEIGH */ + call_rcu(&neigh->rcu, __mctp_neigh_free); + } + } + + mutex_unlock(&net->mctp.neigh_lock); +} + +static int mctp_neigh_remove(struct mctp_dev *mdev, mctp_eid_t eid, + enum mctp_neigh_source source) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh, *tmp; + bool dropped = false; + + mutex_lock(&net->mctp.neigh_lock); + list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) { + if (neigh->dev == mdev && neigh->eid == eid && + neigh->source == source) { + list_del_rcu(&neigh->list); + /* TODO: immediate RTM_DELNEIGH */ + call_rcu(&neigh->rcu, __mctp_neigh_free); + dropped = true; + } + } + + mutex_unlock(&net->mctp.neigh_lock); + return dropped ? 0 : -ENOENT; +} + +static const struct nla_policy nd_mctp_policy[NDA_MAX + 1] = { + [NDA_DST] = { .type = NLA_U8 }, + [NDA_LLADDR] = { .type = NLA_BINARY, .len = MAX_ADDR_LEN }, +}; + +static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct net_device *dev; + struct mctp_dev *mdev; + struct ndmsg *ndm; + struct nlattr *tb[NDA_MAX + 1]; + int rc; + mctp_eid_t eid; + void *lladdr; + int lladdr_len; + + rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy, + extack); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "lladdr too large?"); + return rc; + } + + if (!tb[NDA_DST]) { + NL_SET_ERR_MSG(extack, "Neighbour EID must be specified"); + return -EINVAL; + } + + if (!tb[NDA_LLADDR]) { + NL_SET_ERR_MSG(extack, "Neighbour lladdr must be specified"); + return -EINVAL; + } + + eid = nla_get_u8(tb[NDA_DST]); + if (!mctp_address_unicast(eid)) { + NL_SET_ERR_MSG(extack, "Invalid neighbour EID"); + return -EINVAL; + } + + lladdr = nla_data(tb[NDA_LLADDR]); + lladdr_len = nla_len(tb[NDA_LLADDR]); + + ndm = nlmsg_data(nlh); + + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + if (lladdr_len != dev->addr_len) { + NL_SET_ERR_MSG(extack, "Wrong lladdr length"); + return -EINVAL; + } + + return mctp_neigh_add(mdev, eid, MCTP_NEIGH_STATIC, + lladdr_len, lladdr); +} + +static int mctp_rtm_delneigh(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[NDA_MAX + 1]; + struct net_device *dev; + struct mctp_dev *mdev; + struct ndmsg *ndm; + int rc; + mctp_eid_t eid; + + rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy, + extack); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "incorrect format"); + return rc; + } + + if (!tb[NDA_DST]) { + NL_SET_ERR_MSG(extack, "Neighbour EID must be specified"); + return -EINVAL; + } + eid = nla_get_u8(tb[NDA_DST]); + + ndm = nlmsg_data(nlh); + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + return mctp_neigh_remove(mdev, eid, MCTP_NEIGH_STATIC); +} + +static int mctp_fill_neigh(struct sk_buff *skb, u32 portid, u32 seq, int event, + unsigned int flags, struct mctp_neigh *neigh) +{ + struct net_device *dev = neigh->dev->dev; + struct nlmsghdr *nlh; + struct ndmsg *hdr; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->ndm_family = AF_MCTP; + hdr->ndm_ifindex = dev->ifindex; + hdr->ndm_state = 0; // TODO other state bits? + if (neigh->source == MCTP_NEIGH_STATIC) + hdr->ndm_state |= NUD_PERMANENT; + hdr->ndm_flags = 0; + hdr->ndm_type = RTN_UNICAST; // TODO: is loopback RTN_LOCAL? + + if (nla_put_u8(skb, NDA_DST, neigh->eid)) + goto cancel; + + if (nla_put(skb, NDA_LLADDR, dev->addr_len, neigh->ha)) + goto cancel; + + nlmsg_end(skb, nlh); + + return 0; +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mctp_rtm_getneigh(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int rc, idx, req_ifindex; + struct mctp_neigh *neigh; + struct ndmsg *ndmsg; + struct { + int idx; + } *cbctx = (void *)cb->ctx; + + ndmsg = nlmsg_payload(cb->nlh, sizeof(*ndmsg)); + if (!ndmsg) + return -EINVAL; + + req_ifindex = ndmsg->ndm_ifindex; + + idx = 0; + rcu_read_lock(); + list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) { + if (idx < cbctx->idx) + goto cont; + + rc = 0; + if (req_ifindex == 0 || req_ifindex == neigh->dev->dev->ifindex) + rc = mctp_fill_neigh(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWNEIGH, NLM_F_MULTI, neigh); + + if (rc) + break; +cont: + idx++; + } + rcu_read_unlock(); + + cbctx->idx = idx; + return skb->len; +} + +int mctp_neigh_lookup(struct mctp_dev *mdev, mctp_eid_t eid, void *ret_hwaddr) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh; + int rc = -EHOSTUNREACH; // TODO: or ENOENT? + + rcu_read_lock(); + list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) { + if (mdev == neigh->dev && eid == neigh->eid) { + if (ret_hwaddr) + memcpy(ret_hwaddr, neigh->ha, + sizeof(neigh->ha)); + rc = 0; + break; + } + } + rcu_read_unlock(); + return rc; +} + +/* namespace registration */ +static int __net_init mctp_neigh_net_init(struct net *net) +{ + struct netns_mctp *ns = &net->mctp; + + INIT_LIST_HEAD(&ns->neighbours); + mutex_init(&ns->neigh_lock); + return 0; +} + +static void __net_exit mctp_neigh_net_exit(struct net *net) +{ + struct netns_mctp *ns = &net->mctp; + struct mctp_neigh *neigh; + + list_for_each_entry(neigh, &ns->neighbours, list) + call_rcu(&neigh->rcu, __mctp_neigh_free); +} + +/* net namespace implementation */ + +static struct pernet_operations mctp_net_ops = { + .init = mctp_neigh_net_init, + .exit = mctp_neigh_net_exit, +}; + +static const struct rtnl_msg_handler mctp_neigh_rtnl_msg_handlers[] = { + {THIS_MODULE, PF_MCTP, RTM_NEWNEIGH, mctp_rtm_newneigh, NULL, 0}, + {THIS_MODULE, PF_MCTP, RTM_DELNEIGH, mctp_rtm_delneigh, NULL, 0}, + {THIS_MODULE, PF_MCTP, RTM_GETNEIGH, NULL, mctp_rtm_getneigh, 0}, +}; + +int __init mctp_neigh_init(void) +{ + int err; + + err = register_pernet_subsys(&mctp_net_ops); + if (err) + return err; + + err = rtnl_register_many(mctp_neigh_rtnl_msg_handlers); + if (err) + unregister_pernet_subsys(&mctp_net_ops); + + return err; +} + +void mctp_neigh_exit(void) +{ + rtnl_unregister_many(mctp_neigh_rtnl_msg_handlers); + unregister_pernet_subsys(&mctp_net_ops); +} diff --git a/net/mctp/route.c b/net/mctp/route.c new file mode 100644 index 000000000000..2ac4011a953f --- /dev/null +++ b/net/mctp/route.c @@ -0,0 +1,1790 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) - routing + * implementation. + * + * This is currently based on a simple routing table, with no dst cache. The + * number of routes should stay fairly small, so the lookup cost is small. + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include <linux/idr.h> +#include <linux/kconfig.h> +#include <linux/mctp.h> +#include <linux/netdevice.h> +#include <linux/rtnetlink.h> +#include <linux/skbuff.h> + +#include <kunit/static_stub.h> + +#include <uapi/linux/if_arp.h> + +#include <net/mctp.h> +#include <net/mctpdevice.h> +#include <net/netlink.h> +#include <net/sock.h> + +#include <trace/events/mctp.h> + +static const unsigned int mctp_message_maxlen = 64 * 1024; +static const unsigned long mctp_key_lifetime = 6 * CONFIG_HZ; + +static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev); + +/* route output callbacks */ +static int mctp_dst_discard(struct mctp_dst *dst, struct sk_buff *skb) +{ + kfree_skb(skb); + return 0; +} + +static struct mctp_sock *mctp_lookup_bind_details(struct net *net, + struct sk_buff *skb, + u8 type, u8 dest, + u8 src, bool allow_net_any) +{ + struct mctp_skb_cb *cb = mctp_cb(skb); + struct sock *sk; + u8 hash; + + WARN_ON_ONCE(!rcu_read_lock_held()); + + hash = mctp_bind_hash(type, dest, src); + + sk_for_each_rcu(sk, &net->mctp.binds[hash]) { + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + if (!allow_net_any && msk->bind_net == MCTP_NET_ANY) + continue; + + if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) + continue; + + if (msk->bind_type != type) + continue; + + if (msk->bind_peer_set && + !mctp_address_matches(msk->bind_peer_addr, src)) + continue; + + if (!mctp_address_matches(msk->bind_local_addr, dest)) + continue; + + return msk; + } + + return NULL; +} + +static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +{ + struct mctp_sock *msk; + struct mctp_hdr *mh; + u8 type; + + /* TODO: look up in skb->cb? */ + mh = mctp_hdr(skb); + + if (!skb_headlen(skb)) + return NULL; + + type = (*(u8 *)skb->data) & 0x7f; + + /* Look for binds in order of widening scope. A given destination or + * source address also implies matching on a particular network. + * + * - Matching destination and source + * - Matching destination + * - Matching source + * - Matching network, any address + * - Any network or address + */ + + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, true); + if (msk) + return msk; + + return NULL; +} + +/* A note on the key allocations. + * + * struct net->mctp.keys contains our set of currently-allocated keys for + * MCTP tag management. The lookup tuple for these is the peer EID, + * local EID and MCTP tag. + * + * In some cases, the peer EID may be MCTP_EID_ANY: for example, when a + * broadcast message is sent, we may receive responses from any peer EID. + * Because the broadcast dest address is equivalent to ANY, we create + * a key with (local = local-eid, peer = ANY). This allows a match on the + * incoming broadcast responses from any peer. + * + * We perform lookups when packets are received, and when tags are allocated + * in two scenarios: + * + * - when a packet is sent, with a locally-owned tag: we need to find an + * unused tag value for the (local, peer) EID pair. + * + * - when a tag is manually allocated: we need to find an unused tag value + * for the peer EID, but don't have a specific local EID at that stage. + * + * in the latter case, on successful allocation, we end up with a tag with + * (local = ANY, peer = peer-eid). + * + * So, the key set allows both a local EID of ANY, as well as a peer EID of + * ANY in the lookup tuple. Both may be ANY if we prealloc for a broadcast. + * The matching (in mctp_key_match()) during lookup allows the match value to + * be ANY in either the dest or source addresses. + * + * When allocating (+ inserting) a tag, we need to check for conflicts amongst + * the existing tag set. This requires macthing either exactly on the local + * and peer addresses, or either being ANY. + */ + +static bool mctp_key_match(struct mctp_sk_key *key, unsigned int net, + mctp_eid_t local, mctp_eid_t peer, u8 tag) +{ + if (key->net != net) + return false; + + if (!mctp_address_matches(key->local_addr, local)) + return false; + + if (!mctp_address_matches(key->peer_addr, peer)) + return false; + + if (key->tag != tag) + return false; + + return true; +} + +/* returns a key (with key->lock held, and refcounted), or NULL if no such + * key exists. + */ +static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb, + unsigned int netid, mctp_eid_t peer, + unsigned long *irqflags) + __acquires(&key->lock) +{ + struct mctp_sk_key *key, *ret; + unsigned long flags; + struct mctp_hdr *mh; + u8 tag; + + mh = mctp_hdr(skb); + tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + ret = NULL; + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + hlist_for_each_entry(key, &net->mctp.keys, hlist) { + if (!mctp_key_match(key, netid, mh->dest, peer, tag)) + continue; + + spin_lock(&key->lock); + if (key->valid) { + refcount_inc(&key->refs); + ret = key; + break; + } + spin_unlock(&key->lock); + } + + if (ret) { + spin_unlock(&net->mctp.keys_lock); + *irqflags = flags; + } else { + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + } + + return ret; +} + +static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, + unsigned int net, + mctp_eid_t local, mctp_eid_t peer, + u8 tag, gfp_t gfp) +{ + struct mctp_sk_key *key; + + key = kzalloc(sizeof(*key), gfp); + if (!key) + return NULL; + + key->net = net; + key->peer_addr = peer; + key->local_addr = local; + key->tag = tag; + key->sk = &msk->sk; + key->valid = true; + spin_lock_init(&key->lock); + refcount_set(&key->refs, 1); + sock_hold(key->sk); + + return key; +} + +void mctp_key_unref(struct mctp_sk_key *key) +{ + unsigned long flags; + + if (!refcount_dec_and_test(&key->refs)) + return; + + /* even though no refs exist here, the lock allows us to stay + * consistent with the locking requirement of mctp_dev_release_key + */ + spin_lock_irqsave(&key->lock, flags); + mctp_dev_release_key(key->dev, key); + spin_unlock_irqrestore(&key->lock, flags); + + sock_put(key->sk); + kfree(key); +} + +static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_sk_key *tmp; + unsigned long flags; + int rc = 0; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + if (sock_flag(&msk->sk, SOCK_DEAD)) { + rc = -EINVAL; + goto out_unlock; + } + + hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { + if (mctp_key_match(tmp, key->net, key->local_addr, + key->peer_addr, key->tag)) { + spin_lock(&tmp->lock); + if (tmp->valid) + rc = -EEXIST; + spin_unlock(&tmp->lock); + if (rc) + break; + } + } + + if (!rc) { + refcount_inc(&key->refs); + key->expiry = jiffies + mctp_key_lifetime; + timer_reduce(&msk->key_expiry, key->expiry); + + hlist_add_head(&key->hlist, &net->mctp.keys); + hlist_add_head(&key->sklist, &msk->keys); + } + +out_unlock: + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + return rc; +} + +/* Helper for mctp_route_input(). + * We're done with the key; unlock and unref the key. + * For the usual case of automatic expiry we remove the key from lists. + * In the case that manual allocation is set on a key we release the lock + * and local ref, reset reassembly, but don't remove from lists. + */ +static void __mctp_key_done_in(struct mctp_sk_key *key, struct net *net, + unsigned long flags, unsigned long reason) +__releases(&key->lock) +{ + struct sk_buff *skb; + + trace_mctp_key_release(key, reason); + skb = key->reasm_head; + key->reasm_head = NULL; + + if (!key->manual_alloc) { + key->reasm_dead = true; + key->valid = false; + mctp_dev_release_key(key->dev, key); + } + spin_unlock_irqrestore(&key->lock, flags); + + if (!key->manual_alloc) { + spin_lock_irqsave(&net->mctp.keys_lock, flags); + if (!hlist_unhashed(&key->hlist)) { + hlist_del_init(&key->hlist); + hlist_del_init(&key->sklist); + mctp_key_unref(key); + } + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + } + + /* and one for the local reference */ + mctp_key_unref(key); + + kfree_skb(skb); +} + +#ifdef CONFIG_MCTP_FLOWS +static void mctp_skb_set_flow(struct sk_buff *skb, struct mctp_sk_key *key) +{ + struct mctp_flow *flow; + + flow = skb_ext_add(skb, SKB_EXT_MCTP); + if (!flow) + return; + + refcount_inc(&key->refs); + flow->key = key; +} + +static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev) +{ + struct mctp_sk_key *key; + struct mctp_flow *flow; + + flow = skb_ext_find(skb, SKB_EXT_MCTP); + if (!flow) + return; + + key = flow->key; + + if (key->dev) { + WARN_ON(key->dev != dev); + return; + } + + mctp_dev_set_key(dev, key); +} +#else +static void mctp_skb_set_flow(struct sk_buff *skb, struct mctp_sk_key *key) {} +static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev) {} +#endif + +/* takes ownership of skb, both in success and failure cases */ +static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) +{ + struct mctp_hdr *hdr = mctp_hdr(skb); + u8 exp_seq, this_seq; + + this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) + & MCTP_HDR_SEQ_MASK; + + if (!key->reasm_head) { + /* Since we're manipulating the shared frag_list, ensure it + * isn't shared with any other SKBs. In the cloned case, + * this will free the skb; callers can no longer access it + * safely. + */ + key->reasm_head = skb_unshare(skb, GFP_ATOMIC); + if (!key->reasm_head) + return -ENOMEM; + + key->reasm_tailp = &(skb_shinfo(key->reasm_head)->frag_list); + key->last_seq = this_seq; + return 0; + } + + exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK; + + if (this_seq != exp_seq) + goto err_free; + + if (key->reasm_head->len + skb->len > mctp_message_maxlen) + goto err_free; + + skb->next = NULL; + skb->sk = NULL; + *key->reasm_tailp = skb; + key->reasm_tailp = &skb->next; + + key->last_seq = this_seq; + + key->reasm_head->data_len += skb->len; + key->reasm_head->len += skb->len; + key->reasm_head->truesize += skb->truesize; + + return 0; + +err_free: + kfree_skb(skb); + return -EINVAL; +} + +static int mctp_dst_input(struct mctp_dst *dst, struct sk_buff *skb) +{ + struct mctp_sk_key *key, *any_key = NULL; + struct net *net = dev_net(skb->dev); + struct mctp_sock *msk; + struct mctp_hdr *mh; + unsigned int netid; + unsigned long f; + u8 tag, flags; + int rc; + + msk = NULL; + rc = -EINVAL; + + /* We may be receiving a locally-routed packet; drop source sk + * accounting. + * + * From here, we will either queue the skb - either to a frag_queue, or + * to a receiving socket. When that succeeds, we clear the skb pointer; + * a non-NULL skb on exit will be otherwise unowned, and hence + * kfree_skb()-ed. + */ + skb_orphan(skb); + + if (skb->pkt_type == PACKET_OUTGOING) + skb->pkt_type = PACKET_LOOPBACK; + + /* ensure we have enough data for a header and a type */ + if (skb->len < sizeof(struct mctp_hdr) + 1) + goto out; + + /* grab header, advance data ptr */ + mh = mctp_hdr(skb); + netid = mctp_cb(skb)->net; + skb_pull(skb, sizeof(struct mctp_hdr)); + + if (mh->ver != 1) + goto out; + + flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM); + tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + rcu_read_lock(); + + /* lookup socket / reasm context, exactly matching (src,dest,tag). + * we hold a ref on the key, and key->lock held. + */ + key = mctp_lookup_key(net, skb, netid, mh->src, &f); + + if (flags & MCTP_HDR_FLAG_SOM) { + if (key) { + msk = container_of(key->sk, struct mctp_sock, sk); + } else { + /* first response to a broadcast? do a more general + * key lookup to find the socket, but don't use this + * key for reassembly - we'll create a more specific + * one for future packets if required (ie, !EOM). + * + * this lookup requires key->peer to be MCTP_ADDR_ANY, + * it doesn't match just any key->peer. + */ + any_key = mctp_lookup_key(net, skb, netid, + MCTP_ADDR_ANY, &f); + if (any_key) { + msk = container_of(any_key->sk, + struct mctp_sock, sk); + spin_unlock_irqrestore(&any_key->lock, f); + } + } + + if (!key && !msk && (tag & MCTP_HDR_FLAG_TO)) + msk = mctp_lookup_bind(net, skb); + + if (!msk) { + rc = -ENOENT; + goto out_unlock; + } + + /* single-packet message? deliver to socket, clean up any + * pending key. + */ + if (flags & MCTP_HDR_FLAG_EOM) { + rc = sock_queue_rcv_skb(&msk->sk, skb); + if (!rc) + skb = NULL; + if (key) { + /* we've hit a pending reassembly; not much we + * can do but drop it + */ + __mctp_key_done_in(key, net, f, + MCTP_TRACE_KEY_REPLIED); + key = NULL; + } + goto out_unlock; + } + + /* broadcast response or a bind() - create a key for further + * packets for this message + */ + if (!key) { + key = mctp_key_alloc(msk, netid, mh->dest, mh->src, + tag, GFP_ATOMIC); + if (!key) { + rc = -ENOMEM; + goto out_unlock; + } + + /* we can queue without the key lock here, as the + * key isn't observable yet + */ + mctp_frag_queue(key, skb); + skb = NULL; + + /* if the key_add fails, we've raced with another + * SOM packet with the same src, dest and tag. There's + * no way to distinguish future packets, so all we + * can do is drop. + */ + rc = mctp_key_add(key, msk); + if (!rc) + trace_mctp_key_acquire(key); + + /* we don't need to release key->lock on exit, so + * clean up here and suppress the unlock via + * setting to NULL + */ + mctp_key_unref(key); + key = NULL; + + } else { + if (key->reasm_head || key->reasm_dead) { + /* duplicate start? drop everything */ + __mctp_key_done_in(key, net, f, + MCTP_TRACE_KEY_INVALIDATED); + rc = -EEXIST; + key = NULL; + } else { + rc = mctp_frag_queue(key, skb); + skb = NULL; + } + } + + } else if (key) { + /* this packet continues a previous message; reassemble + * using the message-specific key + */ + + /* we need to be continuing an existing reassembly... */ + if (!key->reasm_head) { + rc = -EINVAL; + } else { + rc = mctp_frag_queue(key, skb); + skb = NULL; + } + + if (rc) + goto out_unlock; + + /* end of message? deliver to socket, and we're done with + * the reassembly/response key + */ + if (flags & MCTP_HDR_FLAG_EOM) { + rc = sock_queue_rcv_skb(key->sk, key->reasm_head); + if (!rc) + key->reasm_head = NULL; + __mctp_key_done_in(key, net, f, MCTP_TRACE_KEY_REPLIED); + key = NULL; + } + + } else { + /* not a start, no matching key */ + rc = -ENOENT; + } + +out_unlock: + rcu_read_unlock(); + if (key) { + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); + } + if (any_key) + mctp_key_unref(any_key); +out: + kfree_skb(skb); + return rc; +} + +static int mctp_dst_output(struct mctp_dst *dst, struct sk_buff *skb) +{ + char daddr_buf[MAX_ADDR_LEN]; + char *daddr = NULL; + int rc; + + skb->protocol = htons(ETH_P_MCTP); + skb->pkt_type = PACKET_OUTGOING; + skb->dev = dst->dev->dev; + + if (skb->len > dst->mtu) { + kfree_skb(skb); + return -EMSGSIZE; + } + + /* direct route; use the hwaddr we stashed in sendmsg */ + if (dst->halen) { + if (dst->halen != skb->dev->addr_len) { + /* sanity check, sendmsg should have already caught this */ + kfree_skb(skb); + return -EMSGSIZE; + } + daddr = dst->haddr; + } else { + /* If lookup fails let the device handle daddr==NULL */ + if (mctp_neigh_lookup(dst->dev, dst->nexthop, daddr_buf) == 0) + daddr = daddr_buf; + } + + rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol), + daddr, skb->dev->dev_addr, skb->len); + if (rc < 0) { + kfree_skb(skb); + return -EHOSTUNREACH; + } + + mctp_flow_prepare_output(skb, dst->dev); + + rc = dev_queue_xmit(skb); + if (rc) + rc = net_xmit_errno(rc); + + return rc; +} + +/* route alloc/release */ +static void mctp_route_release(struct mctp_route *rt) +{ + if (refcount_dec_and_test(&rt->refs)) { + if (rt->dst_type == MCTP_ROUTE_DIRECT) + mctp_dev_put(rt->dev); + kfree_rcu(rt, rcu); + } +} + +/* returns a route with the refcount at 1 */ +static struct mctp_route *mctp_route_alloc(void) +{ + struct mctp_route *rt; + + rt = kzalloc(sizeof(*rt), GFP_KERNEL); + if (!rt) + return NULL; + + INIT_LIST_HEAD(&rt->list); + refcount_set(&rt->refs, 1); + rt->output = mctp_dst_discard; + + return rt; +} + +unsigned int mctp_default_net(struct net *net) +{ + return READ_ONCE(net->mctp.default_net); +} + +int mctp_default_net_set(struct net *net, unsigned int index) +{ + if (index == 0) + return -EINVAL; + WRITE_ONCE(net->mctp.default_net, index); + return 0; +} + +/* tag management */ +static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key, + struct mctp_sock *msk) +{ + struct netns_mctp *mns = &net->mctp; + + lockdep_assert_held(&mns->keys_lock); + + key->expiry = jiffies + mctp_key_lifetime; + timer_reduce(&msk->key_expiry, key->expiry); + + /* we hold the net->key_lock here, allowing updates to both + * then net and sk + */ + hlist_add_head_rcu(&key->hlist, &mns->keys); + hlist_add_head_rcu(&key->sklist, &msk->keys); + refcount_inc(&key->refs); +} + +/* Allocate a locally-owned tag value for (local, peer), and reserve + * it for the socket msk + */ +struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk, + unsigned int netid, + mctp_eid_t local, mctp_eid_t peer, + bool manual, u8 *tagp) +{ + struct net *net = sock_net(&msk->sk); + struct netns_mctp *mns = &net->mctp; + struct mctp_sk_key *key, *tmp; + unsigned long flags; + u8 tagbits; + + /* for NULL destination EIDs, we may get a response from any peer */ + if (peer == MCTP_ADDR_NULL) + peer = MCTP_ADDR_ANY; + + /* be optimistic, alloc now */ + key = mctp_key_alloc(msk, netid, local, peer, 0, GFP_KERNEL); + if (!key) + return ERR_PTR(-ENOMEM); + + /* 8 possible tag values */ + tagbits = 0xff; + + spin_lock_irqsave(&mns->keys_lock, flags); + + /* Walk through the existing keys, looking for potential conflicting + * tags. If we find a conflict, clear that bit from tagbits + */ + hlist_for_each_entry(tmp, &mns->keys, hlist) { + /* We can check the lookup fields (*_addr, tag) without the + * lock held, they don't change over the lifetime of the key. + */ + + /* tags are net-specific */ + if (tmp->net != netid) + continue; + + /* if we don't own the tag, it can't conflict */ + if (tmp->tag & MCTP_HDR_FLAG_TO) + continue; + + /* Since we're avoiding conflicting entries, match peer and + * local addresses, including with a wildcard on ANY. See + * 'A note on key allocations' for background. + */ + if (peer != MCTP_ADDR_ANY && + !mctp_address_matches(tmp->peer_addr, peer)) + continue; + + if (local != MCTP_ADDR_ANY && + !mctp_address_matches(tmp->local_addr, local)) + continue; + + spin_lock(&tmp->lock); + /* key must still be valid. If we find a match, clear the + * potential tag value + */ + if (tmp->valid) + tagbits &= ~(1 << tmp->tag); + spin_unlock(&tmp->lock); + + if (!tagbits) + break; + } + + if (tagbits) { + key->tag = __ffs(tagbits); + mctp_reserve_tag(net, key, msk); + trace_mctp_key_acquire(key); + + key->manual_alloc = manual; + *tagp = key->tag; + } + + spin_unlock_irqrestore(&mns->keys_lock, flags); + + if (!tagbits) { + mctp_key_unref(key); + return ERR_PTR(-EBUSY); + } + + return key; +} + +static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk, + unsigned int netid, + mctp_eid_t daddr, + u8 req_tag, u8 *tagp) +{ + struct net *net = sock_net(&msk->sk); + struct netns_mctp *mns = &net->mctp; + struct mctp_sk_key *key, *tmp; + unsigned long flags; + + req_tag &= ~(MCTP_TAG_PREALLOC | MCTP_TAG_OWNER); + key = NULL; + + spin_lock_irqsave(&mns->keys_lock, flags); + + hlist_for_each_entry(tmp, &mns->keys, hlist) { + if (tmp->net != netid) + continue; + + if (tmp->tag != req_tag) + continue; + + if (!mctp_address_matches(tmp->peer_addr, daddr)) + continue; + + if (!tmp->manual_alloc) + continue; + + spin_lock(&tmp->lock); + if (tmp->valid) { + key = tmp; + refcount_inc(&key->refs); + spin_unlock(&tmp->lock); + break; + } + spin_unlock(&tmp->lock); + } + spin_unlock_irqrestore(&mns->keys_lock, flags); + + if (!key) + return ERR_PTR(-ENOENT); + + if (tagp) + *tagp = key->tag; + + return key; +} + +/* routing lookups */ +static unsigned int mctp_route_netid(struct mctp_route *rt) +{ + return rt->dst_type == MCTP_ROUTE_DIRECT ? + READ_ONCE(rt->dev->net) : rt->gateway.net; +} + +static bool mctp_rt_match_eid(struct mctp_route *rt, + unsigned int net, mctp_eid_t eid) +{ + return mctp_route_netid(rt) == net && + rt->min <= eid && rt->max >= eid; +} + +/* compares match, used for duplicate prevention */ +static bool mctp_rt_compare_exact(struct mctp_route *rt1, + struct mctp_route *rt2) +{ + ASSERT_RTNL(); + return mctp_route_netid(rt1) == mctp_route_netid(rt2) && + rt1->min == rt2->min && + rt1->max == rt2->max; +} + +/* must only be called on a direct route, as the final output hop */ +static void mctp_dst_from_route(struct mctp_dst *dst, mctp_eid_t eid, + unsigned int mtu, struct mctp_route *route) +{ + mctp_dev_hold(route->dev); + dst->nexthop = eid; + dst->dev = route->dev; + dst->mtu = READ_ONCE(dst->dev->dev->mtu); + if (mtu) + dst->mtu = min(dst->mtu, mtu); + dst->halen = 0; + dst->output = route->output; +} + +int mctp_dst_from_extaddr(struct mctp_dst *dst, struct net *net, int ifindex, + unsigned char halen, const unsigned char *haddr) +{ + struct net_device *netdev; + struct mctp_dev *dev; + int rc = -ENOENT; + + if (halen > sizeof(dst->haddr)) + return -EINVAL; + + rcu_read_lock(); + + netdev = dev_get_by_index_rcu(net, ifindex); + if (!netdev) + goto out_unlock; + + if (netdev->addr_len != halen) { + rc = -EINVAL; + goto out_unlock; + } + + dev = __mctp_dev_get(netdev); + if (!dev) + goto out_unlock; + + dst->dev = dev; + dst->mtu = READ_ONCE(netdev->mtu); + dst->halen = halen; + dst->output = mctp_dst_output; + dst->nexthop = 0; + memcpy(dst->haddr, haddr, halen); + + rc = 0; + +out_unlock: + rcu_read_unlock(); + return rc; +} + +void mctp_dst_release(struct mctp_dst *dst) +{ + mctp_dev_put(dst->dev); +} + +static struct mctp_route *mctp_route_lookup_single(struct net *net, + unsigned int dnet, + mctp_eid_t daddr) +{ + struct mctp_route *rt; + + list_for_each_entry_rcu(rt, &net->mctp.routes, list) { + if (mctp_rt_match_eid(rt, dnet, daddr)) + return rt; + } + + return NULL; +} + +/* populates *dst on successful lookup, if set */ +int mctp_route_lookup(struct net *net, unsigned int dnet, + mctp_eid_t daddr, struct mctp_dst *dst) +{ + const unsigned int max_depth = 32; + unsigned int depth, mtu = 0; + int rc = -EHOSTUNREACH; + + rcu_read_lock(); + + for (depth = 0; depth < max_depth; depth++) { + struct mctp_route *rt; + + rt = mctp_route_lookup_single(net, dnet, daddr); + if (!rt) + break; + + /* clamp mtu to the smallest in the path, allowing 0 + * to specify no restrictions + */ + if (mtu && rt->mtu) + mtu = min(mtu, rt->mtu); + else + mtu = mtu ?: rt->mtu; + + if (rt->dst_type == MCTP_ROUTE_DIRECT) { + if (dst) + mctp_dst_from_route(dst, daddr, mtu, rt); + rc = 0; + break; + + } else if (rt->dst_type == MCTP_ROUTE_GATEWAY) { + daddr = rt->gateway.eid; + } + } + + rcu_read_unlock(); + + return rc; +} + +static int mctp_route_lookup_null(struct net *net, struct net_device *dev, + struct mctp_dst *dst) +{ + int rc = -EHOSTUNREACH; + struct mctp_route *rt; + + rcu_read_lock(); + + list_for_each_entry_rcu(rt, &net->mctp.routes, list) { + if (rt->dst_type != MCTP_ROUTE_DIRECT || rt->type != RTN_LOCAL) + continue; + + if (rt->dev->dev != dev) + continue; + + mctp_dst_from_route(dst, 0, 0, rt); + rc = 0; + break; + } + + rcu_read_unlock(); + + return rc; +} + +static int mctp_do_fragment_route(struct mctp_dst *dst, struct sk_buff *skb, + unsigned int mtu, u8 tag) +{ + const unsigned int hlen = sizeof(struct mctp_hdr); + struct mctp_hdr *hdr, *hdr2; + unsigned int pos, size, headroom; + struct sk_buff *skb2; + int rc; + u8 seq; + + hdr = mctp_hdr(skb); + seq = 0; + rc = 0; + + if (mtu < hlen + 1) { + kfree_skb(skb); + return -EMSGSIZE; + } + + /* keep same headroom as the original skb */ + headroom = skb_headroom(skb); + + /* we've got the header */ + skb_pull(skb, hlen); + + for (pos = 0; pos < skb->len;) { + /* size of message payload */ + size = min(mtu - hlen, skb->len - pos); + + skb2 = alloc_skb(headroom + hlen + size, GFP_KERNEL); + if (!skb2) { + rc = -ENOMEM; + break; + } + + /* generic skb copy */ + skb2->protocol = skb->protocol; + skb2->priority = skb->priority; + skb2->dev = skb->dev; + memcpy(skb2->cb, skb->cb, sizeof(skb2->cb)); + + if (skb->sk) + skb_set_owner_w(skb2, skb->sk); + + /* establish packet */ + skb_reserve(skb2, headroom); + skb_reset_network_header(skb2); + skb_put(skb2, hlen + size); + skb2->transport_header = skb2->network_header + hlen; + + /* copy header fields, calculate SOM/EOM flags & seq */ + hdr2 = mctp_hdr(skb2); + hdr2->ver = hdr->ver; + hdr2->dest = hdr->dest; + hdr2->src = hdr->src; + hdr2->flags_seq_tag = tag & + (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + if (pos == 0) + hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM; + + if (pos + size == skb->len) + hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM; + + hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT; + + /* copy message payload */ + skb_copy_bits(skb, pos, skb_transport_header(skb2), size); + + /* we need to copy the extensions, for MCTP flow data */ + skb_ext_copy(skb2, skb); + + /* do route */ + rc = dst->output(dst, skb2); + if (rc) + break; + + seq = (seq + 1) & MCTP_HDR_SEQ_MASK; + pos += size; + } + + consume_skb(skb); + return rc; +} + +int mctp_local_output(struct sock *sk, struct mctp_dst *dst, + struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag) +{ + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct mctp_sk_key *key; + struct mctp_hdr *hdr; + unsigned long flags; + unsigned int netid; + unsigned int mtu; + mctp_eid_t saddr; + int rc; + u8 tag; + + KUNIT_STATIC_STUB_REDIRECT(mctp_local_output, sk, dst, skb, daddr, + req_tag); + + rc = -ENODEV; + + spin_lock_irqsave(&dst->dev->addrs_lock, flags); + if (dst->dev->num_addrs == 0) { + rc = -EHOSTUNREACH; + } else { + /* use the outbound interface's first address as our source */ + saddr = dst->dev->addrs[0]; + rc = 0; + } + spin_unlock_irqrestore(&dst->dev->addrs_lock, flags); + netid = READ_ONCE(dst->dev->net); + + if (rc) + goto out_release; + + if (req_tag & MCTP_TAG_OWNER) { + if (req_tag & MCTP_TAG_PREALLOC) + key = mctp_lookup_prealloc_tag(msk, netid, daddr, + req_tag, &tag); + else + key = mctp_alloc_local_tag(msk, netid, saddr, daddr, + false, &tag); + + if (IS_ERR(key)) { + rc = PTR_ERR(key); + goto out_release; + } + mctp_skb_set_flow(skb, key); + /* done with the key in this scope */ + mctp_key_unref(key); + tag |= MCTP_HDR_FLAG_TO; + } else { + key = NULL; + tag = req_tag & MCTP_TAG_MASK; + } + + skb->pkt_type = PACKET_OUTGOING; + skb->protocol = htons(ETH_P_MCTP); + skb->priority = 0; + skb_reset_transport_header(skb); + skb_push(skb, sizeof(struct mctp_hdr)); + skb_reset_network_header(skb); + skb->dev = dst->dev->dev; + + /* set up common header fields */ + hdr = mctp_hdr(skb); + hdr->ver = 1; + hdr->dest = daddr; + hdr->src = saddr; + + mtu = dst->mtu; + + if (skb->len + sizeof(struct mctp_hdr) <= mtu) { + hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | + MCTP_HDR_FLAG_EOM | tag; + rc = dst->output(dst, skb); + } else { + rc = mctp_do_fragment_route(dst, skb, mtu, tag); + } + + /* route output functions consume the skb, even on error */ + skb = NULL; + +out_release: + kfree_skb(skb); + return rc; +} + +/* route management */ + +/* mctp_route_add(): Add the provided route, previously allocated via + * mctp_route_alloc(). On success, takes ownership of @rt, which includes a + * hold on rt->dev for usage in the route table. On failure a caller will want + * to mctp_route_release(). + * + * We expect that the caller has set rt->type, rt->dst_type, rt->min, rt->max, + * rt->mtu and either rt->dev (with a reference held appropriately) or + * rt->gateway. Other fields will be populated. + */ +static int mctp_route_add(struct net *net, struct mctp_route *rt) +{ + struct mctp_route *ert; + + if (!mctp_address_unicast(rt->min) || !mctp_address_unicast(rt->max)) + return -EINVAL; + + if (rt->dst_type == MCTP_ROUTE_DIRECT && !rt->dev) + return -EINVAL; + + if (rt->dst_type == MCTP_ROUTE_GATEWAY && !rt->gateway.eid) + return -EINVAL; + + switch (rt->type) { + case RTN_LOCAL: + rt->output = mctp_dst_input; + break; + case RTN_UNICAST: + rt->output = mctp_dst_output; + break; + default: + return -EINVAL; + } + + ASSERT_RTNL(); + + /* Prevent duplicate identical routes. */ + list_for_each_entry(ert, &net->mctp.routes, list) { + if (mctp_rt_compare_exact(rt, ert)) { + return -EEXIST; + } + } + + list_add_rcu(&rt->list, &net->mctp.routes); + + return 0; +} + +static int mctp_route_remove(struct net *net, unsigned int netid, + mctp_eid_t daddr_start, unsigned int daddr_extent, + unsigned char type) +{ + struct mctp_route *rt, *tmp; + mctp_eid_t daddr_end; + bool dropped; + + if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255) + return -EINVAL; + + daddr_end = daddr_start + daddr_extent; + dropped = false; + + ASSERT_RTNL(); + + list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) { + if (mctp_route_netid(rt) == netid && + rt->min == daddr_start && rt->max == daddr_end && + rt->type == type) { + list_del_rcu(&rt->list); + /* TODO: immediate RTM_DELROUTE */ + mctp_route_release(rt); + dropped = true; + } + } + + return dropped ? 0 : -ENOENT; +} + +int mctp_route_add_local(struct mctp_dev *mdev, mctp_eid_t addr) +{ + struct mctp_route *rt; + int rc; + + rt = mctp_route_alloc(); + if (!rt) + return -ENOMEM; + + rt->min = addr; + rt->max = addr; + rt->dst_type = MCTP_ROUTE_DIRECT; + rt->dev = mdev; + rt->type = RTN_LOCAL; + + mctp_dev_hold(rt->dev); + + rc = mctp_route_add(dev_net(mdev->dev), rt); + if (rc) + mctp_route_release(rt); + + return rc; +} + +int mctp_route_remove_local(struct mctp_dev *mdev, mctp_eid_t addr) +{ + return mctp_route_remove(dev_net(mdev->dev), mdev->net, + addr, 0, RTN_LOCAL); +} + +/* removes all entries for a given device */ +void mctp_route_remove_dev(struct mctp_dev *mdev) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_route *rt, *tmp; + + ASSERT_RTNL(); + list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) { + if (rt->dst_type == MCTP_ROUTE_DIRECT && rt->dev == mdev) { + list_del_rcu(&rt->list); + /* TODO: immediate RTM_DELROUTE */ + mctp_route_release(rt); + } + } +} + +/* Incoming packet-handling */ + +static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, + struct net_device *orig_dev) +{ + struct net *net = dev_net(dev); + struct mctp_dev *mdev; + struct mctp_skb_cb *cb; + struct mctp_dst dst; + struct mctp_hdr *mh; + int rc; + + rcu_read_lock(); + mdev = __mctp_dev_get(dev); + rcu_read_unlock(); + if (!mdev) { + /* basic non-data sanity checks */ + goto err_drop; + } + + if (!pskb_may_pull(skb, sizeof(struct mctp_hdr))) + goto err_drop; + + skb_reset_transport_header(skb); + skb_reset_network_header(skb); + + /* We have enough for a header; decode and route */ + mh = mctp_hdr(skb); + if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX) + goto err_drop; + + /* source must be valid unicast or null; drop reserved ranges and + * broadcast + */ + if (!(mctp_address_unicast(mh->src) || mctp_address_null(mh->src))) + goto err_drop; + + /* dest address: as above, but allow broadcast */ + if (!(mctp_address_unicast(mh->dest) || mctp_address_null(mh->dest) || + mctp_address_broadcast(mh->dest))) + goto err_drop; + + /* MCTP drivers must populate halen/haddr */ + if (dev->type == ARPHRD_MCTP) { + cb = mctp_cb(skb); + } else { + cb = __mctp_cb(skb); + cb->halen = 0; + } + cb->net = READ_ONCE(mdev->net); + cb->ifindex = dev->ifindex; + + rc = mctp_route_lookup(net, cb->net, mh->dest, &dst); + + /* NULL EID, but addressed to our physical address */ + if (rc && mh->dest == MCTP_ADDR_NULL && skb->pkt_type == PACKET_HOST) + rc = mctp_route_lookup_null(net, dev, &dst); + + if (rc) + goto err_drop; + + dst.output(&dst, skb); + mctp_dst_release(&dst); + mctp_dev_put(mdev); + + return NET_RX_SUCCESS; + +err_drop: + kfree_skb(skb); + mctp_dev_put(mdev); + return NET_RX_DROP; +} + +static struct packet_type mctp_packet_type = { + .type = cpu_to_be16(ETH_P_MCTP), + .func = mctp_pkttype_receive, +}; + +/* netlink interface */ + +static const struct nla_policy rta_mctp_policy[RTA_MAX + 1] = { + [RTA_DST] = { .type = NLA_U8 }, + [RTA_METRICS] = { .type = NLA_NESTED }, + [RTA_OIF] = { .type = NLA_U32 }, + [RTA_GATEWAY] = NLA_POLICY_EXACT_LEN(sizeof(struct mctp_fq_addr)), +}; + +static const struct nla_policy rta_metrics_policy[RTAX_MAX + 1] = { + [RTAX_MTU] = { .type = NLA_U32 }, +}; + +/* base parsing; common to both _lookup and _populate variants. + * + * For gateway routes (which have a RTA_GATEWAY, and no RTA_OIF), we populate + * *gatweayp. for direct routes (RTA_OIF, no RTA_GATEWAY), we populate *mdev. + */ +static int mctp_route_nlparse_common(struct net *net, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + struct nlattr **tb, struct rtmsg **rtm, + struct mctp_dev **mdev, + struct mctp_fq_addr *gatewayp, + mctp_eid_t *daddr_start) +{ + struct mctp_fq_addr *gateway = NULL; + unsigned int ifindex = 0; + struct net_device *dev; + int rc; + + rc = nlmsg_parse(nlh, sizeof(struct rtmsg), tb, RTA_MAX, + rta_mctp_policy, extack); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "incorrect format"); + return rc; + } + + if (!tb[RTA_DST]) { + NL_SET_ERR_MSG(extack, "dst EID missing"); + return -EINVAL; + } + *daddr_start = nla_get_u8(tb[RTA_DST]); + + if (tb[RTA_OIF]) + ifindex = nla_get_u32(tb[RTA_OIF]); + + if (tb[RTA_GATEWAY]) + gateway = nla_data(tb[RTA_GATEWAY]); + + if (ifindex && gateway) { + NL_SET_ERR_MSG(extack, + "cannot specify both ifindex and gateway"); + return -EINVAL; + + } else if (ifindex) { + dev = __dev_get_by_index(net, ifindex); + if (!dev) { + NL_SET_ERR_MSG(extack, "bad ifindex"); + return -ENODEV; + } + *mdev = mctp_dev_get_rtnl(dev); + if (!*mdev) + return -ENODEV; + gatewayp->eid = 0; + + } else if (gateway) { + if (!mctp_address_unicast(gateway->eid)) { + NL_SET_ERR_MSG(extack, "bad gateway"); + return -EINVAL; + } + + gatewayp->eid = gateway->eid; + gatewayp->net = gateway->net != MCTP_NET_ANY ? + gateway->net : + READ_ONCE(net->mctp.default_net); + *mdev = NULL; + + } else { + NL_SET_ERR_MSG(extack, "no route output provided"); + return -EINVAL; + } + + *rtm = nlmsg_data(nlh); + if ((*rtm)->rtm_family != AF_MCTP) { + NL_SET_ERR_MSG(extack, "route family must be AF_MCTP"); + return -EINVAL; + } + + if ((*rtm)->rtm_type != RTN_UNICAST) { + NL_SET_ERR_MSG(extack, "rtm_type must be RTN_UNICAST"); + return -EINVAL; + } + + return 0; +} + +/* Route parsing for lookup operations; we only need the "route target" + * components (ie., network and dest-EID range). + */ +static int mctp_route_nlparse_lookup(struct net *net, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + unsigned char *type, unsigned int *netid, + mctp_eid_t *daddr_start, + unsigned int *daddr_extent) +{ + struct nlattr *tb[RTA_MAX + 1]; + struct mctp_fq_addr gw; + struct mctp_dev *mdev; + struct rtmsg *rtm; + int rc; + + rc = mctp_route_nlparse_common(net, nlh, extack, tb, &rtm, + &mdev, &gw, daddr_start); + if (rc) + return rc; + + if (mdev) { + *netid = mdev->net; + } else if (gw.eid) { + *netid = gw.net; + } else { + /* bug: _nlparse_common should not allow this */ + return -1; + } + + *type = rtm->rtm_type; + *daddr_extent = rtm->rtm_dst_len; + + return 0; +} + +/* Full route parse for RTM_NEWROUTE: populate @rt. On success, + * MCTP_ROUTE_DIRECT routes (ie, those with a direct dev) will hold a reference + * to that dev. + */ +static int mctp_route_nlparse_populate(struct net *net, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + struct mctp_route *rt) +{ + struct nlattr *tbx[RTAX_MAX + 1]; + struct nlattr *tb[RTA_MAX + 1]; + unsigned int daddr_extent; + struct mctp_fq_addr gw; + mctp_eid_t daddr_start; + struct mctp_dev *dev; + struct rtmsg *rtm; + u32 mtu = 0; + int rc; + + rc = mctp_route_nlparse_common(net, nlh, extack, tb, &rtm, + &dev, &gw, &daddr_start); + if (rc) + return rc; + + daddr_extent = rtm->rtm_dst_len; + + if (daddr_extent > 0xff || daddr_extent + daddr_start >= 255) { + NL_SET_ERR_MSG(extack, "invalid eid range"); + return -EINVAL; + } + + if (tb[RTA_METRICS]) { + rc = nla_parse_nested(tbx, RTAX_MAX, tb[RTA_METRICS], + rta_metrics_policy, NULL); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "incorrect RTA_METRICS format"); + return rc; + } + if (tbx[RTAX_MTU]) + mtu = nla_get_u32(tbx[RTAX_MTU]); + } + + rt->type = rtm->rtm_type; + rt->min = daddr_start; + rt->max = daddr_start + daddr_extent; + rt->mtu = mtu; + if (gw.eid) { + rt->dst_type = MCTP_ROUTE_GATEWAY; + rt->gateway.eid = gw.eid; + rt->gateway.net = gw.net; + } else { + rt->dst_type = MCTP_ROUTE_DIRECT; + rt->dev = dev; + mctp_dev_hold(rt->dev); + } + + return 0; +} + +static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct mctp_route *rt; + int rc; + + rt = mctp_route_alloc(); + if (!rt) + return -ENOMEM; + + rc = mctp_route_nlparse_populate(net, nlh, extack, rt); + if (rc < 0) + goto err_free; + + if (rt->dst_type == MCTP_ROUTE_DIRECT && + rt->dev->dev->flags & IFF_LOOPBACK) { + NL_SET_ERR_MSG(extack, "no routes to loopback"); + rc = -EINVAL; + goto err_free; + } + + rc = mctp_route_add(net, rt); + if (!rc) + return 0; + +err_free: + mctp_route_release(rt); + return rc; +} + +static int mctp_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + unsigned int netid, daddr_extent; + unsigned char type = RTN_UNSPEC; + mctp_eid_t daddr_start; + int rc; + + rc = mctp_route_nlparse_lookup(net, nlh, extack, &type, &netid, + &daddr_start, &daddr_extent); + if (rc < 0) + return rc; + + /* we only have unicast routes */ + if (type != RTN_UNICAST) + return -EINVAL; + + rc = mctp_route_remove(net, netid, daddr_start, daddr_extent, type); + return rc; +} + +static int mctp_fill_rtinfo(struct sk_buff *skb, struct mctp_route *rt, + u32 portid, u32 seq, int event, unsigned int flags) +{ + struct nlmsghdr *nlh; + struct rtmsg *hdr; + void *metrics; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->rtm_family = AF_MCTP; + + /* we use the _len fields as a number of EIDs, rather than + * a number of bits in the address + */ + hdr->rtm_dst_len = rt->max - rt->min; + hdr->rtm_src_len = 0; + hdr->rtm_tos = 0; + hdr->rtm_table = RT_TABLE_DEFAULT; + hdr->rtm_protocol = RTPROT_STATIC; /* everything is user-defined */ + hdr->rtm_type = rt->type; + + if (nla_put_u8(skb, RTA_DST, rt->min)) + goto cancel; + + metrics = nla_nest_start_noflag(skb, RTA_METRICS); + if (!metrics) + goto cancel; + + if (rt->mtu) { + if (nla_put_u32(skb, RTAX_MTU, rt->mtu)) + goto cancel; + } + + nla_nest_end(skb, metrics); + + if (rt->dst_type == MCTP_ROUTE_DIRECT) { + hdr->rtm_scope = RT_SCOPE_LINK; + if (nla_put_u32(skb, RTA_OIF, rt->dev->dev->ifindex)) + goto cancel; + } else if (rt->dst_type == MCTP_ROUTE_GATEWAY) { + hdr->rtm_scope = RT_SCOPE_UNIVERSE; + if (nla_put(skb, RTA_GATEWAY, + sizeof(rt->gateway), &rt->gateway)) + goto cancel; + } + + nlmsg_end(skb, nlh); + + return 0; + +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mctp_dump_rtinfo(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct mctp_route *rt; + int s_idx, idx; + + /* TODO: allow filtering on route data, possibly under + * cb->strict_check + */ + + /* TODO: change to struct overlay */ + s_idx = cb->args[0]; + idx = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(rt, &net->mctp.routes, list) { + if (idx++ < s_idx) + continue; + if (mctp_fill_rtinfo(skb, rt, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWROUTE, NLM_F_MULTI) < 0) + break; + } + + rcu_read_unlock(); + cb->args[0] = idx; + + return skb->len; +} + +/* net namespace implementation */ +static int __net_init mctp_routes_net_init(struct net *net) +{ + struct netns_mctp *ns = &net->mctp; + + INIT_LIST_HEAD(&ns->routes); + hash_init(ns->binds); + mutex_init(&ns->bind_lock); + INIT_HLIST_HEAD(&ns->keys); + spin_lock_init(&ns->keys_lock); + WARN_ON(mctp_default_net_set(net, MCTP_INITIAL_DEFAULT_NET)); + return 0; +} + +static void __net_exit mctp_routes_net_exit(struct net *net) +{ + struct mctp_route *rt; + + rcu_read_lock(); + list_for_each_entry_rcu(rt, &net->mctp.routes, list) + mctp_route_release(rt); + rcu_read_unlock(); +} + +static struct pernet_operations mctp_net_ops = { + .init = mctp_routes_net_init, + .exit = mctp_routes_net_exit, +}; + +static const struct rtnl_msg_handler mctp_route_rtnl_msg_handlers[] = { + {THIS_MODULE, PF_MCTP, RTM_NEWROUTE, mctp_newroute, NULL, 0}, + {THIS_MODULE, PF_MCTP, RTM_DELROUTE, mctp_delroute, NULL, 0}, + {THIS_MODULE, PF_MCTP, RTM_GETROUTE, NULL, mctp_dump_rtinfo, 0}, +}; + +int __init mctp_routes_init(void) +{ + int err; + + dev_add_pack(&mctp_packet_type); + + err = register_pernet_subsys(&mctp_net_ops); + if (err) + goto err_pernet; + + err = rtnl_register_many(mctp_route_rtnl_msg_handlers); + if (err) + goto err_rtnl; + + return 0; + +err_rtnl: + unregister_pernet_subsys(&mctp_net_ops); +err_pernet: + dev_remove_pack(&mctp_packet_type); + return err; +} + +void mctp_routes_exit(void) +{ + rtnl_unregister_many(mctp_route_rtnl_msg_handlers); + unregister_pernet_subsys(&mctp_net_ops); + dev_remove_pack(&mctp_packet_type); +} + +#if IS_ENABLED(CONFIG_MCTP_TEST) +#include "test/route-test.c" +#endif diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c new file mode 100644 index 000000000000..75ea96c10e49 --- /dev/null +++ b/net/mctp/test/route-test.c @@ -0,0 +1,1598 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <kunit/test.h> + +/* keep clangd happy when compiled outside of the route.c include */ +#include <net/mctp.h> +#include <net/mctpdevice.h> + +#include "utils.h" + +#define mctp_test_create_skb_data(h, d) \ + __mctp_test_create_skb_data(h, d, sizeof(*d)) + +struct mctp_frag_test { + unsigned int mtu; + unsigned int msgsize; + unsigned int n_frags; +}; + +static void mctp_test_fragment(struct kunit *test) +{ + const struct mctp_frag_test *params; + int rc, i, n, mtu, msgsize; + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct sk_buff *skb; + struct mctp_hdr hdr; + u8 seq; + + params = test->param_value; + mtu = params->mtu; + msgsize = params->msgsize; + + hdr.ver = 1; + hdr.src = 8; + hdr.dest = 10; + hdr.flags_seq_tag = MCTP_HDR_FLAG_TO; + + skb = mctp_test_create_skb(&hdr, msgsize); + KUNIT_ASSERT_TRUE(test, skb); + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + mctp_test_dst_setup(test, &dst, dev, mtu); + + rc = mctp_do_fragment_route(&dst, skb, mtu, MCTP_TAG_OWNER); + KUNIT_EXPECT_FALSE(test, rc); + + n = dev->pkts.qlen; + KUNIT_EXPECT_EQ(test, n, params->n_frags); + + for (i = 0;; i++) { + struct mctp_hdr *hdr2; + struct sk_buff *skb2; + u8 tag_mask, seq2; + bool first, last; + + first = i == 0; + last = i == (n - 1); + + skb2 = skb_dequeue(&dev->pkts); + if (!skb2) + break; + + hdr2 = mctp_hdr(skb2); + + tag_mask = MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO; + + KUNIT_EXPECT_EQ(test, hdr2->ver, hdr.ver); + KUNIT_EXPECT_EQ(test, hdr2->src, hdr.src); + KUNIT_EXPECT_EQ(test, hdr2->dest, hdr.dest); + KUNIT_EXPECT_EQ(test, hdr2->flags_seq_tag & tag_mask, + hdr.flags_seq_tag & tag_mask); + + KUNIT_EXPECT_EQ(test, + !!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_SOM), first); + KUNIT_EXPECT_EQ(test, + !!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_EOM), last); + + seq2 = (hdr2->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) & + MCTP_HDR_SEQ_MASK; + + if (first) { + seq = seq2; + } else { + seq++; + KUNIT_EXPECT_EQ(test, seq2, seq & MCTP_HDR_SEQ_MASK); + } + + if (!last) + KUNIT_EXPECT_EQ(test, skb2->len, mtu); + else + KUNIT_EXPECT_LE(test, skb2->len, mtu); + + kfree_skb(skb2); + } + + mctp_dst_release(&dst); + mctp_test_destroy_dev(dev); +} + +static const struct mctp_frag_test mctp_frag_tests[] = { + {.mtu = 68, .msgsize = 63, .n_frags = 1}, + {.mtu = 68, .msgsize = 64, .n_frags = 1}, + {.mtu = 68, .msgsize = 65, .n_frags = 2}, + {.mtu = 68, .msgsize = 66, .n_frags = 2}, + {.mtu = 68, .msgsize = 127, .n_frags = 2}, + {.mtu = 68, .msgsize = 128, .n_frags = 2}, + {.mtu = 68, .msgsize = 129, .n_frags = 3}, + {.mtu = 68, .msgsize = 130, .n_frags = 3}, +}; + +static void mctp_frag_test_to_desc(const struct mctp_frag_test *t, char *desc) +{ + sprintf(desc, "mtu %d len %d -> %d frags", + t->msgsize, t->mtu, t->n_frags); +} + +KUNIT_ARRAY_PARAM(mctp_frag, mctp_frag_tests, mctp_frag_test_to_desc); + +struct mctp_rx_input_test { + struct mctp_hdr hdr; + bool input; +}; + +static void mctp_test_rx_input(struct kunit *test) +{ + const struct mctp_rx_input_test *params; + struct mctp_test_route *rt; + struct mctp_test_dev *dev; + struct sk_buff *skb; + + params = test->param_value; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + rt = mctp_test_create_route_direct(&init_net, dev->mdev, 8, 68); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt); + + skb = mctp_test_create_skb(¶ms->hdr, 1); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + mctp_pkttype_receive(skb, dev->ndev, &mctp_packet_type, NULL); + + KUNIT_EXPECT_EQ(test, !!dev->pkts.qlen, params->input); + + mctp_test_route_destroy(test, rt); + mctp_test_destroy_dev(dev); +} + +#define RX_HDR(_ver, _src, _dest, _fst) \ + { .ver = _ver, .src = _src, .dest = _dest, .flags_seq_tag = _fst } + +/* we have a route for EID 8 only */ +static const struct mctp_rx_input_test mctp_rx_input_tests[] = { + { .hdr = RX_HDR(1, 10, 8, 0), .input = true }, + { .hdr = RX_HDR(1, 10, 9, 0), .input = false }, /* no input route */ + { .hdr = RX_HDR(2, 10, 8, 0), .input = false }, /* invalid version */ +}; + +static void mctp_rx_input_test_to_desc(const struct mctp_rx_input_test *t, + char *desc) +{ + sprintf(desc, "{%x,%x,%x,%x}", t->hdr.ver, t->hdr.src, t->hdr.dest, + t->hdr.flags_seq_tag); +} + +KUNIT_ARRAY_PARAM(mctp_rx_input, mctp_rx_input_tests, + mctp_rx_input_test_to_desc); + +/* set up a local dev, route on EID 8, and a socket listening on type 0 */ +static void __mctp_route_test_init(struct kunit *test, + struct mctp_test_dev **devp, + struct mctp_dst *dst, + struct socket **sockp, + unsigned int netid) +{ + struct sockaddr_mctp addr = {0}; + struct mctp_test_dev *dev; + struct socket *sock; + int rc; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + if (netid != MCTP_NET_ANY) + WRITE_ONCE(dev->mdev->net, netid); + + mctp_test_dst_setup(test, dst, dev, 68); + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + addr.smctp_family = AF_MCTP; + addr.smctp_network = netid; + addr.smctp_addr.s_addr = 8; + addr.smctp_type = 0; + rc = kernel_bind(sock, (struct sockaddr_unsized *)&addr, sizeof(addr)); + KUNIT_ASSERT_EQ(test, rc, 0); + + *devp = dev; + *sockp = sock; +} + +static void __mctp_route_test_fini(struct kunit *test, + struct mctp_test_dev *dev, + struct mctp_dst *dst, + struct socket *sock) +{ + sock_release(sock); + mctp_dst_release(dst); + mctp_test_destroy_dev(dev); +} + +struct mctp_route_input_sk_test { + struct mctp_hdr hdr; + u8 type; + bool deliver; +}; + +static void mctp_test_route_input_sk(struct kunit *test) +{ + const struct mctp_route_input_sk_test *params; + struct sk_buff *skb, *skb2; + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct socket *sock; + int rc; + + params = test->param_value; + + __mctp_route_test_init(test, &dev, &dst, &sock, MCTP_NET_ANY); + + skb = mctp_test_create_skb_data(¶ms->hdr, ¶ms->type); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + mctp_test_skb_set_dev(skb, dev); + + rc = mctp_dst_input(&dst, skb); + + if (params->deliver) { + KUNIT_EXPECT_EQ(test, rc, 0); + + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2); + KUNIT_EXPECT_EQ(test, skb2->len, 1); + + skb_free_datagram(sock->sk, skb2); + + } else { + KUNIT_EXPECT_NE(test, rc, 0); + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NULL(test, skb2); + } + + __mctp_route_test_fini(test, dev, &dst, sock); +} + +#define FL_S (MCTP_HDR_FLAG_SOM) +#define FL_E (MCTP_HDR_FLAG_EOM) +#define FL_TO (MCTP_HDR_FLAG_TO) +#define FL_T(t) ((t) & MCTP_HDR_TAG_MASK) + +static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = { + { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 0, .deliver = true }, + { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 1, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, FL_E | FL_TO), .type = 0, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, FL_TO), .type = 0, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false }, +}; + +static void mctp_route_input_sk_to_desc(const struct mctp_route_input_sk_test *t, + char *desc) +{ + sprintf(desc, "{%x,%x,%x,%x} type %d", t->hdr.ver, t->hdr.src, + t->hdr.dest, t->hdr.flags_seq_tag, t->type); +} + +KUNIT_ARRAY_PARAM(mctp_route_input_sk, mctp_route_input_sk_tests, + mctp_route_input_sk_to_desc); + +struct mctp_route_input_sk_reasm_test { + const char *name; + struct mctp_hdr hdrs[4]; + int n_hdrs; + int rx_len; +}; + +static void mctp_test_route_input_sk_reasm(struct kunit *test) +{ + const struct mctp_route_input_sk_reasm_test *params; + struct sk_buff *skb, *skb2; + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct socket *sock; + int i, rc; + u8 c; + + params = test->param_value; + + __mctp_route_test_init(test, &dev, &dst, &sock, MCTP_NET_ANY); + + for (i = 0; i < params->n_hdrs; i++) { + c = i; + skb = mctp_test_create_skb_data(¶ms->hdrs[i], &c); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + mctp_test_skb_set_dev(skb, dev); + + rc = mctp_dst_input(&dst, skb); + } + + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + + if (params->rx_len) { + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2); + KUNIT_EXPECT_EQ(test, skb2->len, params->rx_len); + skb_free_datagram(sock->sk, skb2); + + } else { + KUNIT_EXPECT_NULL(test, skb2); + } + + __mctp_route_test_fini(test, dev, &dst, sock); +} + +#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_TO | (f) | ((s) << MCTP_HDR_SEQ_SHIFT)) + +static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = { + { + .name = "single packet", + .hdrs = { + RX_FRAG(FL_S | FL_E, 0), + }, + .n_hdrs = 1, + .rx_len = 1, + }, + { + .name = "single packet, offset seq", + .hdrs = { + RX_FRAG(FL_S | FL_E, 1), + }, + .n_hdrs = 1, + .rx_len = 1, + }, + { + .name = "start & end packets", + .hdrs = { + RX_FRAG(FL_S, 0), + RX_FRAG(FL_E, 1), + }, + .n_hdrs = 2, + .rx_len = 2, + }, + { + .name = "start & end packets, offset seq", + .hdrs = { + RX_FRAG(FL_S, 1), + RX_FRAG(FL_E, 2), + }, + .n_hdrs = 2, + .rx_len = 2, + }, + { + .name = "start & end packets, out of order", + .hdrs = { + RX_FRAG(FL_E, 1), + RX_FRAG(FL_S, 0), + }, + .n_hdrs = 2, + .rx_len = 0, + }, + { + .name = "start, middle & end packets", + .hdrs = { + RX_FRAG(FL_S, 0), + RX_FRAG(0, 1), + RX_FRAG(FL_E, 2), + }, + .n_hdrs = 3, + .rx_len = 3, + }, + { + .name = "missing seq", + .hdrs = { + RX_FRAG(FL_S, 0), + RX_FRAG(FL_E, 2), + }, + .n_hdrs = 2, + .rx_len = 0, + }, + { + .name = "seq wrap", + .hdrs = { + RX_FRAG(FL_S, 3), + RX_FRAG(FL_E, 0), + }, + .n_hdrs = 2, + .rx_len = 2, + }, +}; + +static void mctp_route_input_sk_reasm_to_desc( + const struct mctp_route_input_sk_reasm_test *t, + char *desc) +{ + sprintf(desc, "%s", t->name); +} + +KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests, + mctp_route_input_sk_reasm_to_desc); + +struct mctp_route_input_sk_keys_test { + const char *name; + mctp_eid_t key_peer_addr; + mctp_eid_t key_local_addr; + u8 key_tag; + struct mctp_hdr hdr; + bool deliver; +}; + +/* test packet rx in the presence of various key configurations */ +static void mctp_test_route_input_sk_keys(struct kunit *test) +{ + const struct mctp_route_input_sk_keys_test *params; + struct sk_buff *skb, *skb2; + struct mctp_test_dev *dev; + struct mctp_sk_key *key; + struct netns_mctp *mns; + struct mctp_sock *msk; + struct socket *sock; + unsigned long flags; + struct mctp_dst dst; + unsigned int net; + int rc; + u8 c; + + params = test->param_value; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + net = READ_ONCE(dev->mdev->net); + + mctp_test_dst_setup(test, &dst, dev, 68); + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + msk = container_of(sock->sk, struct mctp_sock, sk); + mns = &sock_net(sock->sk)->mctp; + + /* set the incoming tag according to test params */ + key = mctp_key_alloc(msk, net, params->key_local_addr, + params->key_peer_addr, params->key_tag, + GFP_KERNEL); + + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key); + + spin_lock_irqsave(&mns->keys_lock, flags); + mctp_reserve_tag(&init_net, key, msk); + spin_unlock_irqrestore(&mns->keys_lock, flags); + + /* create packet and route */ + c = 0; + skb = mctp_test_create_skb_data(¶ms->hdr, &c); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + mctp_test_skb_set_dev(skb, dev); + + rc = mctp_dst_input(&dst, skb); + + /* (potentially) receive message */ + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + + if (params->deliver) + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2); + else + KUNIT_EXPECT_PTR_EQ(test, skb2, NULL); + + if (skb2) + skb_free_datagram(sock->sk, skb2); + + mctp_key_unref(key); + __mctp_route_test_fini(test, dev, &dst, sock); +} + +static const struct mctp_route_input_sk_keys_test mctp_route_input_sk_keys_tests[] = { + { + .name = "direct match", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)), + .deliver = true, + }, + { + .name = "flipped src/dest", + .key_peer_addr = 8, + .key_local_addr = 9, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)), + .deliver = false, + }, + { + .name = "peer addr mismatch", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T(1)), + .deliver = false, + }, + { + .name = "tag value mismatch", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(2)), + .deliver = false, + }, + { + .name = "TO mismatch", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO), + .deliver = false, + }, + { + .name = "broadcast response", + .key_peer_addr = MCTP_ADDR_ANY, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 11, 8, FL_S | FL_E | FL_T(1)), + .deliver = true, + }, + { + .name = "any local match", + .key_peer_addr = 12, + .key_local_addr = MCTP_ADDR_ANY, + .key_tag = 1, + .hdr = RX_HDR(1, 12, 8, FL_S | FL_E | FL_T(1)), + .deliver = true, + }, +}; + +static void mctp_route_input_sk_keys_to_desc( + const struct mctp_route_input_sk_keys_test *t, + char *desc) +{ + sprintf(desc, "%s", t->name); +} + +KUNIT_ARRAY_PARAM(mctp_route_input_sk_keys, mctp_route_input_sk_keys_tests, + mctp_route_input_sk_keys_to_desc); + +struct test_net { + unsigned int netid; + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct socket *sock; + struct sk_buff *skb; + struct mctp_sk_key *key; + struct { + u8 type; + unsigned int data; + } msg; +}; + +static void +mctp_test_route_input_multiple_nets_bind_init(struct kunit *test, + struct test_net *t) +{ + struct mctp_hdr hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO); + + t->msg.data = t->netid; + + __mctp_route_test_init(test, &t->dev, &t->dst, &t->sock, t->netid); + + t->skb = mctp_test_create_skb_data(&hdr, &t->msg); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->skb); + mctp_test_skb_set_dev(t->skb, t->dev); +} + +static void +mctp_test_route_input_multiple_nets_bind_fini(struct kunit *test, + struct test_net *t) +{ + __mctp_route_test_fini(test, t->dev, &t->dst, t->sock); +} + +/* Test that skbs from different nets (otherwise identical) get routed to their + * corresponding socket via the sockets' bind() + */ +static void mctp_test_route_input_multiple_nets_bind(struct kunit *test) +{ + struct sk_buff *rx_skb1, *rx_skb2; + struct test_net t1, t2; + int rc; + + t1.netid = 1; + t2.netid = 2; + + t1.msg.type = 0; + t2.msg.type = 0; + + mctp_test_route_input_multiple_nets_bind_init(test, &t1); + mctp_test_route_input_multiple_nets_bind_init(test, &t2); + + rc = mctp_dst_input(&t1.dst, t1.skb); + KUNIT_ASSERT_EQ(test, rc, 0); + rc = mctp_dst_input(&t2.dst, t2.skb); + KUNIT_ASSERT_EQ(test, rc, 0); + + rx_skb1 = skb_recv_datagram(t1.sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb1); + KUNIT_EXPECT_EQ(test, rx_skb1->len, sizeof(t1.msg)); + KUNIT_EXPECT_EQ(test, + *(unsigned int *)skb_pull(rx_skb1, sizeof(t1.msg.data)), + t1.netid); + kfree_skb(rx_skb1); + + rx_skb2 = skb_recv_datagram(t2.sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb2); + KUNIT_EXPECT_EQ(test, rx_skb2->len, sizeof(t2.msg)); + KUNIT_EXPECT_EQ(test, + *(unsigned int *)skb_pull(rx_skb2, sizeof(t2.msg.data)), + t2.netid); + kfree_skb(rx_skb2); + + mctp_test_route_input_multiple_nets_bind_fini(test, &t1); + mctp_test_route_input_multiple_nets_bind_fini(test, &t2); +} + +static void +mctp_test_route_input_multiple_nets_key_init(struct kunit *test, + struct test_net *t) +{ + struct mctp_hdr hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)); + struct mctp_sock *msk; + struct netns_mctp *mns; + unsigned long flags; + + t->msg.data = t->netid; + + __mctp_route_test_init(test, &t->dev, &t->dst, &t->sock, t->netid); + + msk = container_of(t->sock->sk, struct mctp_sock, sk); + + t->key = mctp_key_alloc(msk, t->netid, hdr.dest, hdr.src, 1, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->key); + + mns = &sock_net(t->sock->sk)->mctp; + spin_lock_irqsave(&mns->keys_lock, flags); + mctp_reserve_tag(&init_net, t->key, msk); + spin_unlock_irqrestore(&mns->keys_lock, flags); + + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->key); + t->skb = mctp_test_create_skb_data(&hdr, &t->msg); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->skb); + mctp_test_skb_set_dev(t->skb, t->dev); +} + +static void +mctp_test_route_input_multiple_nets_key_fini(struct kunit *test, + struct test_net *t) +{ + mctp_key_unref(t->key); + __mctp_route_test_fini(test, t->dev, &t->dst, t->sock); +} + +/* test that skbs from different nets (otherwise identical) get routed to their + * corresponding socket via the sk_key + */ +static void mctp_test_route_input_multiple_nets_key(struct kunit *test) +{ + struct sk_buff *rx_skb1, *rx_skb2; + struct test_net t1, t2; + int rc; + + t1.netid = 1; + t2.netid = 2; + + /* use type 1 which is not bound */ + t1.msg.type = 1; + t2.msg.type = 1; + + mctp_test_route_input_multiple_nets_key_init(test, &t1); + mctp_test_route_input_multiple_nets_key_init(test, &t2); + + rc = mctp_dst_input(&t1.dst, t1.skb); + KUNIT_ASSERT_EQ(test, rc, 0); + rc = mctp_dst_input(&t2.dst, t2.skb); + KUNIT_ASSERT_EQ(test, rc, 0); + + rx_skb1 = skb_recv_datagram(t1.sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb1); + KUNIT_EXPECT_EQ(test, rx_skb1->len, sizeof(t1.msg)); + KUNIT_EXPECT_EQ(test, + *(unsigned int *)skb_pull(rx_skb1, sizeof(t1.msg.data)), + t1.netid); + kfree_skb(rx_skb1); + + rx_skb2 = skb_recv_datagram(t2.sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb2); + KUNIT_EXPECT_EQ(test, rx_skb2->len, sizeof(t2.msg)); + KUNIT_EXPECT_EQ(test, + *(unsigned int *)skb_pull(rx_skb2, sizeof(t2.msg.data)), + t2.netid); + kfree_skb(rx_skb2); + + mctp_test_route_input_multiple_nets_key_fini(test, &t1); + mctp_test_route_input_multiple_nets_key_fini(test, &t2); +} + +/* Input route to socket, using a single-packet message, where sock delivery + * fails. Ensure we're handling the failure appropriately. + */ +static void mctp_test_route_input_sk_fail_single(struct kunit *test) +{ + const struct mctp_hdr hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO); + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct socket *sock; + struct sk_buff *skb; + int rc; + + __mctp_route_test_init(test, &dev, &dst, &sock, MCTP_NET_ANY); + + /* No rcvbuf space, so delivery should fail. __sock_set_rcvbuf will + * clamp the minimum to SOCK_MIN_RCVBUF, so we open-code this. + */ + lock_sock(sock->sk); + WRITE_ONCE(sock->sk->sk_rcvbuf, 0); + release_sock(sock->sk); + + skb = mctp_test_create_skb(&hdr, 10); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + skb_get(skb); + + mctp_test_skb_set_dev(skb, dev); + + /* do route input, which should fail */ + rc = mctp_dst_input(&dst, skb); + KUNIT_EXPECT_NE(test, rc, 0); + + /* we should hold the only reference to skb */ + KUNIT_EXPECT_EQ(test, refcount_read(&skb->users), 1); + kfree_skb(skb); + + __mctp_route_test_fini(test, dev, &dst, sock); +} + +/* Input route to socket, using a fragmented message, where sock delivery fails. + */ +static void mctp_test_route_input_sk_fail_frag(struct kunit *test) +{ + const struct mctp_hdr hdrs[2] = { RX_FRAG(FL_S, 0), RX_FRAG(FL_E, 1) }; + struct mctp_test_dev *dev; + struct sk_buff *skbs[2]; + struct mctp_dst dst; + struct socket *sock; + unsigned int i; + int rc; + + __mctp_route_test_init(test, &dev, &dst, &sock, MCTP_NET_ANY); + + lock_sock(sock->sk); + WRITE_ONCE(sock->sk->sk_rcvbuf, 0); + release_sock(sock->sk); + + for (i = 0; i < ARRAY_SIZE(skbs); i++) { + skbs[i] = mctp_test_create_skb(&hdrs[i], 10); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skbs[i]); + skb_get(skbs[i]); + + mctp_test_skb_set_dev(skbs[i], dev); + } + + /* first route input should succeed, we're only queueing to the + * frag list + */ + rc = mctp_dst_input(&dst, skbs[0]); + KUNIT_EXPECT_EQ(test, rc, 0); + + /* final route input should fail to deliver to the socket */ + rc = mctp_dst_input(&dst, skbs[1]); + KUNIT_EXPECT_NE(test, rc, 0); + + /* we should hold the only reference to both skbs */ + KUNIT_EXPECT_EQ(test, refcount_read(&skbs[0]->users), 1); + kfree_skb(skbs[0]); + + KUNIT_EXPECT_EQ(test, refcount_read(&skbs[1]->users), 1); + kfree_skb(skbs[1]); + + __mctp_route_test_fini(test, dev, &dst, sock); +} + +/* Input route to socket, using a fragmented message created from clones. + */ +static void mctp_test_route_input_cloned_frag(struct kunit *test) +{ + /* 5 packet fragments, forming 2 complete messages */ + const struct mctp_hdr hdrs[5] = { + RX_FRAG(FL_S, 0), + RX_FRAG(0, 1), + RX_FRAG(FL_E, 2), + RX_FRAG(FL_S, 0), + RX_FRAG(FL_E, 1), + }; + const size_t data_len = 3; /* arbitrary */ + u8 compare[3 * ARRAY_SIZE(hdrs)]; + u8 flat[3 * ARRAY_SIZE(hdrs)]; + struct mctp_test_dev *dev; + struct sk_buff *skb[5]; + struct sk_buff *rx_skb; + struct mctp_dst dst; + struct socket *sock; + size_t total; + void *p; + int rc; + + total = data_len + sizeof(struct mctp_hdr); + + __mctp_route_test_init(test, &dev, &dst, &sock, MCTP_NET_ANY); + + /* Create a single skb initially with concatenated packets */ + skb[0] = mctp_test_create_skb(&hdrs[0], 5 * total); + mctp_test_skb_set_dev(skb[0], dev); + memset(skb[0]->data, 0 * 0x11, skb[0]->len); + memcpy(skb[0]->data, &hdrs[0], sizeof(struct mctp_hdr)); + + /* Extract and populate packets */ + for (int i = 1; i < 5; i++) { + skb[i] = skb_clone(skb[i - 1], GFP_ATOMIC); + KUNIT_ASSERT_TRUE(test, skb[i]); + p = skb_pull(skb[i], total); + KUNIT_ASSERT_TRUE(test, p); + skb_reset_network_header(skb[i]); + memcpy(skb[i]->data, &hdrs[i], sizeof(struct mctp_hdr)); + memset(&skb[i]->data[sizeof(struct mctp_hdr)], i * 0x11, data_len); + } + for (int i = 0; i < 5; i++) + skb_trim(skb[i], total); + + /* SOM packets have a type byte to match the socket */ + skb[0]->data[4] = 0; + skb[3]->data[4] = 0; + + skb_dump("pkt1 ", skb[0], false); + skb_dump("pkt2 ", skb[1], false); + skb_dump("pkt3 ", skb[2], false); + skb_dump("pkt4 ", skb[3], false); + skb_dump("pkt5 ", skb[4], false); + + for (int i = 0; i < 5; i++) { + KUNIT_EXPECT_EQ(test, refcount_read(&skb[i]->users), 1); + /* Take a reference so we can check refcounts at the end */ + skb_get(skb[i]); + } + + /* Feed the fragments into MCTP core */ + for (int i = 0; i < 5; i++) { + rc = mctp_dst_input(&dst, skb[i]); + KUNIT_EXPECT_EQ(test, rc, 0); + } + + /* Receive first reassembled message */ + rx_skb = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_EQ(test, rc, 0); + KUNIT_EXPECT_EQ(test, rx_skb->len, 3 * data_len); + rc = skb_copy_bits(rx_skb, 0, flat, rx_skb->len); + for (int i = 0; i < rx_skb->len; i++) + compare[i] = (i / data_len) * 0x11; + /* Set type byte */ + compare[0] = 0; + + KUNIT_EXPECT_MEMEQ(test, flat, compare, rx_skb->len); + KUNIT_EXPECT_EQ(test, refcount_read(&rx_skb->users), 1); + kfree_skb(rx_skb); + + /* Receive second reassembled message */ + rx_skb = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_EQ(test, rc, 0); + KUNIT_EXPECT_EQ(test, rx_skb->len, 2 * data_len); + rc = skb_copy_bits(rx_skb, 0, flat, rx_skb->len); + for (int i = 0; i < rx_skb->len; i++) + compare[i] = (i / data_len + 3) * 0x11; + /* Set type byte */ + compare[0] = 0; + + KUNIT_EXPECT_MEMEQ(test, flat, compare, rx_skb->len); + KUNIT_EXPECT_EQ(test, refcount_read(&rx_skb->users), 1); + kfree_skb(rx_skb); + + /* Check input skb refcounts */ + for (int i = 0; i < 5; i++) { + KUNIT_EXPECT_EQ(test, refcount_read(&skb[i]->users), 1); + kfree_skb(skb[i]); + } + + __mctp_route_test_fini(test, dev, &dst, sock); +} + +#if IS_ENABLED(CONFIG_MCTP_FLOWS) + +static void mctp_test_flow_init(struct kunit *test, + struct mctp_test_dev **devp, + struct mctp_dst *dst, + struct socket **sock, + struct sk_buff **skbp, + unsigned int len) +{ + struct mctp_test_dev *dev; + struct sk_buff *skb; + + /* we have a slightly odd routing setup here; the test route + * is for EID 8, which is our local EID. We don't do a routing + * lookup, so that's fine - all we require is a path through + * mctp_local_output, which will call dst->output on whatever + * route we provide + */ + __mctp_route_test_init(test, &dev, dst, sock, MCTP_NET_ANY); + + /* Assign a single EID. ->addrs is freed on mctp netdev release */ + dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL); + dev->mdev->num_addrs = 1; + dev->mdev->addrs[0] = 8; + + skb = alloc_skb(len + sizeof(struct mctp_hdr) + 1, GFP_KERNEL); + KUNIT_ASSERT_TRUE(test, skb); + __mctp_cb(skb); + skb_reserve(skb, sizeof(struct mctp_hdr) + 1); + memset(skb_put(skb, len), 0, len); + + + *devp = dev; + *skbp = skb; +} + +static void mctp_test_flow_fini(struct kunit *test, + struct mctp_test_dev *dev, + struct mctp_dst *dst, + struct socket *sock) +{ + __mctp_route_test_fini(test, dev, dst, sock); +} + +/* test that an outgoing skb has the correct MCTP extension data set */ +static void mctp_test_packet_flow(struct kunit *test) +{ + struct sk_buff *skb, *skb2; + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct mctp_flow *flow; + struct socket *sock; + u8 dst_eid = 8; + int n, rc; + + mctp_test_flow_init(test, &dev, &dst, &sock, &skb, 30); + + rc = mctp_local_output(sock->sk, &dst, skb, dst_eid, MCTP_TAG_OWNER); + KUNIT_ASSERT_EQ(test, rc, 0); + + n = dev->pkts.qlen; + KUNIT_ASSERT_EQ(test, n, 1); + + skb2 = skb_dequeue(&dev->pkts); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2); + + flow = skb_ext_find(skb2, SKB_EXT_MCTP); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow->key); + KUNIT_ASSERT_PTR_EQ(test, flow->key->sk, sock->sk); + + kfree_skb(skb2); + mctp_test_flow_fini(test, dev, &dst, sock); +} + +/* test that outgoing skbs, after fragmentation, all have the correct MCTP + * extension data set. + */ +static void mctp_test_fragment_flow(struct kunit *test) +{ + struct mctp_flow *flows[2]; + struct sk_buff *tx_skbs[2]; + struct mctp_test_dev *dev; + struct mctp_dst dst; + struct sk_buff *skb; + struct socket *sock; + u8 dst_eid = 8; + int n, rc; + + mctp_test_flow_init(test, &dev, &dst, &sock, &skb, 100); + + rc = mctp_local_output(sock->sk, &dst, skb, dst_eid, MCTP_TAG_OWNER); + KUNIT_ASSERT_EQ(test, rc, 0); + + n = dev->pkts.qlen; + KUNIT_ASSERT_EQ(test, n, 2); + + /* both resulting packets should have the same flow data */ + tx_skbs[0] = skb_dequeue(&dev->pkts); + tx_skbs[1] = skb_dequeue(&dev->pkts); + + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[0]); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[1]); + + flows[0] = skb_ext_find(tx_skbs[0], SKB_EXT_MCTP); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]->key); + KUNIT_ASSERT_PTR_EQ(test, flows[0]->key->sk, sock->sk); + + flows[1] = skb_ext_find(tx_skbs[1], SKB_EXT_MCTP); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[1]); + KUNIT_ASSERT_PTR_EQ(test, flows[1]->key, flows[0]->key); + + kfree_skb(tx_skbs[0]); + kfree_skb(tx_skbs[1]); + mctp_test_flow_fini(test, dev, &dst, sock); +} + +#else +static void mctp_test_packet_flow(struct kunit *test) +{ + kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y"); +} + +static void mctp_test_fragment_flow(struct kunit *test) +{ + kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y"); +} +#endif + +/* Test that outgoing skbs cause a suitable tag to be created */ +static void mctp_test_route_output_key_create(struct kunit *test) +{ + const u8 dst_eid = 26, src_eid = 15; + const unsigned int netid = 50; + struct mctp_test_dev *dev; + struct mctp_sk_key *key; + struct netns_mctp *mns; + unsigned long flags; + struct socket *sock; + struct sk_buff *skb; + struct mctp_dst dst; + bool empty, single; + const int len = 2; + int rc; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + WRITE_ONCE(dev->mdev->net, netid); + + mctp_test_dst_setup(test, &dst, dev, 68); + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL); + dev->mdev->num_addrs = 1; + dev->mdev->addrs[0] = src_eid; + + skb = alloc_skb(sizeof(struct mctp_hdr) + 1 + len, GFP_KERNEL); + KUNIT_ASSERT_TRUE(test, skb); + __mctp_cb(skb); + skb_reserve(skb, sizeof(struct mctp_hdr) + 1 + len); + memset(skb_put(skb, len), 0, len); + + mns = &sock_net(sock->sk)->mctp; + + /* We assume we're starting from an empty keys list, which requires + * preceding tests to clean up correctly! + */ + spin_lock_irqsave(&mns->keys_lock, flags); + empty = hlist_empty(&mns->keys); + spin_unlock_irqrestore(&mns->keys_lock, flags); + KUNIT_ASSERT_TRUE(test, empty); + + rc = mctp_local_output(sock->sk, &dst, skb, dst_eid, MCTP_TAG_OWNER); + KUNIT_ASSERT_EQ(test, rc, 0); + + key = NULL; + single = false; + spin_lock_irqsave(&mns->keys_lock, flags); + if (!hlist_empty(&mns->keys)) { + key = hlist_entry(mns->keys.first, struct mctp_sk_key, hlist); + single = hlist_is_singular_node(&key->hlist, &mns->keys); + } + spin_unlock_irqrestore(&mns->keys_lock, flags); + + KUNIT_ASSERT_NOT_NULL(test, key); + KUNIT_ASSERT_TRUE(test, single); + + KUNIT_EXPECT_EQ(test, key->net, netid); + KUNIT_EXPECT_EQ(test, key->local_addr, src_eid); + KUNIT_EXPECT_EQ(test, key->peer_addr, dst_eid); + /* key has incoming tag, so inverse of what we sent */ + KUNIT_EXPECT_FALSE(test, key->tag & MCTP_TAG_OWNER); + + sock_release(sock); + mctp_dst_release(&dst); + mctp_test_destroy_dev(dev); +} + +static void mctp_test_route_extaddr_input(struct kunit *test) +{ + static const unsigned char haddr[] = { 0xaa, 0x55 }; + struct mctp_skb_cb *cb, *cb2; + const unsigned int len = 40; + struct mctp_test_dev *dev; + struct sk_buff *skb, *skb2; + struct mctp_dst dst; + struct mctp_hdr hdr; + struct socket *sock; + int rc; + + hdr.ver = 1; + hdr.src = 10; + hdr.dest = 8; + hdr.flags_seq_tag = FL_S | FL_E | FL_TO; + + __mctp_route_test_init(test, &dev, &dst, &sock, MCTP_NET_ANY); + + skb = mctp_test_create_skb(&hdr, len); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + /* set our hardware addressing data */ + cb = mctp_cb(skb); + memcpy(cb->haddr, haddr, sizeof(haddr)); + cb->halen = sizeof(haddr); + + mctp_test_skb_set_dev(skb, dev); + + rc = mctp_dst_input(&dst, skb); + KUNIT_ASSERT_EQ(test, rc, 0); + + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2); + KUNIT_ASSERT_EQ(test, skb2->len, len); + + cb2 = mctp_cb(skb2); + + /* Received SKB should have the hardware addressing as set above. + * We're likely to have the same actual cb here (ie., cb == cb2), + * but it's the comparison that we care about + */ + KUNIT_EXPECT_EQ(test, cb2->halen, sizeof(haddr)); + KUNIT_EXPECT_MEMEQ(test, cb2->haddr, haddr, sizeof(haddr)); + + kfree_skb(skb2); + __mctp_route_test_fini(test, dev, &dst, sock); +} + +static void mctp_test_route_gw_lookup(struct kunit *test) +{ + struct mctp_test_route *rt1, *rt2; + struct mctp_dst dst = { 0 }; + struct mctp_test_dev *dev; + int rc; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + /* 8 (local) -> 10 (gateway) via 9 (direct) */ + rt1 = mctp_test_create_route_direct(&init_net, dev->mdev, 9, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt1); + rt2 = mctp_test_create_route_gw(&init_net, dev->mdev->net, 10, 9, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt2); + + rc = mctp_route_lookup(&init_net, dev->mdev->net, 10, &dst); + KUNIT_EXPECT_EQ(test, rc, 0); + KUNIT_EXPECT_PTR_EQ(test, dst.dev, dev->mdev); + KUNIT_EXPECT_EQ(test, dst.mtu, dev->ndev->mtu); + KUNIT_EXPECT_EQ(test, dst.nexthop, 9); + KUNIT_EXPECT_EQ(test, dst.halen, 0); + + mctp_dst_release(&dst); + + mctp_test_route_destroy(test, rt2); + mctp_test_route_destroy(test, rt1); + mctp_test_destroy_dev(dev); +} + +static void mctp_test_route_gw_loop(struct kunit *test) +{ + struct mctp_test_route *rt1, *rt2; + struct mctp_dst dst = { 0 }; + struct mctp_test_dev *dev; + int rc; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + /* two routes using each other as the gw */ + rt1 = mctp_test_create_route_gw(&init_net, dev->mdev->net, 9, 10, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt1); + rt2 = mctp_test_create_route_gw(&init_net, dev->mdev->net, 10, 9, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt2); + + /* this should fail, rather than infinite-loop */ + rc = mctp_route_lookup(&init_net, dev->mdev->net, 10, &dst); + KUNIT_EXPECT_NE(test, rc, 0); + + mctp_test_route_destroy(test, rt2); + mctp_test_route_destroy(test, rt1); + mctp_test_destroy_dev(dev); +} + +struct mctp_route_gw_mtu_test { + /* working away from the local stack */ + unsigned int dev, neigh, gw, dst; + unsigned int exp; +}; + +static void mctp_route_gw_mtu_to_desc(const struct mctp_route_gw_mtu_test *t, + char *desc) +{ + sprintf(desc, "dev %d, neigh %d, gw %d, dst %d -> %d", + t->dev, t->neigh, t->gw, t->dst, t->exp); +} + +static const struct mctp_route_gw_mtu_test mctp_route_gw_mtu_tests[] = { + /* no route-specific MTUs */ + { 68, 0, 0, 0, 68 }, + { 100, 0, 0, 0, 100 }, + /* one route MTU (smaller than dev mtu), others unrestricted */ + { 100, 68, 0, 0, 68 }, + { 100, 0, 68, 0, 68 }, + { 100, 0, 0, 68, 68 }, + /* smallest applied, regardless of order */ + { 100, 99, 98, 68, 68 }, + { 99, 100, 98, 68, 68 }, + { 98, 99, 100, 68, 68 }, + { 68, 98, 99, 100, 68 }, +}; + +KUNIT_ARRAY_PARAM(mctp_route_gw_mtu, mctp_route_gw_mtu_tests, + mctp_route_gw_mtu_to_desc); + +static void mctp_test_route_gw_mtu(struct kunit *test) +{ + const struct mctp_route_gw_mtu_test *mtus = test->param_value; + struct mctp_test_route *rt1, *rt2, *rt3; + struct mctp_dst dst = { 0 }; + struct mctp_test_dev *dev; + struct mctp_dev *mdev; + unsigned int netid; + int rc; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + dev->ndev->mtu = mtus->dev; + mdev = dev->mdev; + netid = mdev->net; + + /* 8 (local) -> 11 (dst) via 10 (gw) via 9 (neigh) */ + rt1 = mctp_test_create_route_direct(&init_net, mdev, 9, mtus->neigh); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt1); + + rt2 = mctp_test_create_route_gw(&init_net, netid, 10, 9, mtus->gw); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt2); + + rt3 = mctp_test_create_route_gw(&init_net, netid, 11, 10, mtus->dst); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt3); + + rc = mctp_route_lookup(&init_net, dev->mdev->net, 11, &dst); + KUNIT_EXPECT_EQ(test, rc, 0); + KUNIT_EXPECT_EQ(test, dst.mtu, mtus->exp); + + mctp_dst_release(&dst); + + mctp_test_route_destroy(test, rt3); + mctp_test_route_destroy(test, rt2); + mctp_test_route_destroy(test, rt1); + mctp_test_destroy_dev(dev); +} + +#define MCTP_TEST_LLADDR_LEN 2 +struct mctp_test_llhdr { + unsigned int magic; + unsigned char src[MCTP_TEST_LLADDR_LEN]; + unsigned char dst[MCTP_TEST_LLADDR_LEN]; +}; + +static const unsigned int mctp_test_llhdr_magic = 0x5c78339c; + +static int test_dev_header_create(struct sk_buff *skb, struct net_device *dev, + unsigned short type, const void *daddr, + const void *saddr, unsigned int len) +{ + struct kunit *test = current->kunit_test; + struct mctp_test_llhdr *hdr; + + hdr = skb_push(skb, sizeof(*hdr)); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, hdr); + skb_reset_mac_header(skb); + + hdr->magic = mctp_test_llhdr_magic; + memcpy(&hdr->src, saddr, sizeof(hdr->src)); + memcpy(&hdr->dst, daddr, sizeof(hdr->dst)); + + return 0; +} + +/* Test the dst_output path for a gateway-routed skb: we should have it + * lookup the nexthop EID in the neighbour table, and call into + * header_ops->create to resolve that to a lladdr. Our mock header_ops->create + * will just set a synthetic link-layer header, which we check after transmit. + */ +static void mctp_test_route_gw_output(struct kunit *test) +{ + const unsigned char haddr_self[MCTP_TEST_LLADDR_LEN] = { 0xaa, 0x03 }; + const unsigned char haddr_peer[MCTP_TEST_LLADDR_LEN] = { 0xaa, 0x02 }; + const struct header_ops ops = { + .create = test_dev_header_create, + }; + struct mctp_neigh neigh = { 0 }; + struct mctp_test_llhdr *ll_hdr; + struct mctp_dst dst = { 0 }; + struct mctp_hdr hdr = { 0 }; + struct mctp_test_dev *dev; + struct sk_buff *skb; + unsigned char *buf; + int i, rc; + + dev = mctp_test_create_dev_lladdr(sizeof(haddr_self), haddr_self); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + dev->ndev->header_ops = &ops; + + dst.dev = dev->mdev; + __mctp_dev_get(dst.dev->dev); + dst.mtu = 68; + dst.nexthop = 9; + + /* simple mctp_neigh_add for the gateway (not dest!) endpoint */ + INIT_LIST_HEAD(&neigh.list); + neigh.dev = dev->mdev; + mctp_dev_hold(dev->mdev); + neigh.eid = 9; + neigh.source = MCTP_NEIGH_STATIC; + memcpy(neigh.ha, haddr_peer, sizeof(haddr_peer)); + list_add_rcu(&neigh.list, &init_net.mctp.neighbours); + + hdr.ver = 1; + hdr.src = 8; + hdr.dest = 10; + hdr.flags_seq_tag = FL_S | FL_E | FL_TO; + + /* construct enough for a future link-layer header, the provided + * mctp header, and 4 bytes of data + */ + skb = alloc_skb(sizeof(*ll_hdr) + sizeof(hdr) + 4, GFP_KERNEL); + skb->dev = dev->ndev; + __mctp_cb(skb); + + skb_reserve(skb, sizeof(*ll_hdr)); + + memcpy(skb_put(skb, sizeof(hdr)), &hdr, sizeof(hdr)); + buf = skb_put(skb, 4); + for (i = 0; i < 4; i++) + buf[i] = i; + + /* extra ref over the dev_xmit */ + skb_get(skb); + + rc = mctp_dst_output(&dst, skb); + KUNIT_EXPECT_EQ(test, rc, 0); + + mctp_dst_release(&dst); + list_del_rcu(&neigh.list); + mctp_dev_put(dev->mdev); + + /* check that we have our header created with the correct neighbour */ + ll_hdr = (void *)skb_mac_header(skb); + KUNIT_EXPECT_EQ(test, ll_hdr->magic, mctp_test_llhdr_magic); + KUNIT_EXPECT_MEMEQ(test, ll_hdr->src, haddr_self, sizeof(haddr_self)); + KUNIT_EXPECT_MEMEQ(test, ll_hdr->dst, haddr_peer, sizeof(haddr_peer)); + kfree_skb(skb); +} + +struct mctp_bind_lookup_test { + /* header of incoming message */ + struct mctp_hdr hdr; + u8 ty; + /* mctp network of incoming interface (smctp_network) */ + unsigned int net; + + /* expected socket, matches .name in lookup_binds, NULL for dropped */ + const char *expect; +}; + +/* Single-packet TO-set message */ +#define LK(src, dst) RX_HDR(1, (src), (dst), FL_S | FL_E | FL_TO) + +/* Input message test cases for bind lookup tests. + * + * 10 and 11 are local EIDs. + * 20 and 21 are remote EIDs. + */ +static const struct mctp_bind_lookup_test mctp_bind_lookup_tests[] = { + /* both local-eid and remote-eid binds, remote eid is preferenced */ + { .hdr = LK(20, 10), .ty = 1, .net = 1, .expect = "remote20" }, + + { .hdr = LK(20, 255), .ty = 1, .net = 1, .expect = "remote20" }, + { .hdr = LK(20, 0), .ty = 1, .net = 1, .expect = "remote20" }, + { .hdr = LK(0, 255), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 11), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 0), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 10), .ty = 1, .net = 1, .expect = "local10" }, + { .hdr = LK(21, 10), .ty = 1, .net = 1, .expect = "local10" }, + { .hdr = LK(21, 11), .ty = 1, .net = 1, .expect = "remote21local11" }, + + /* both src and dest set to eid=99. unusual, but accepted + * by MCTP stack currently. + */ + { .hdr = LK(99, 99), .ty = 1, .net = 1, .expect = "any" }, + + /* unbound smctp_type */ + { .hdr = LK(20, 10), .ty = 3, .net = 1, .expect = NULL }, + + /* smctp_network tests */ + + { .hdr = LK(0, 0), .ty = 1, .net = 7, .expect = "any" }, + { .hdr = LK(21, 10), .ty = 1, .net = 2, .expect = "any" }, + + /* remote EID 20 matches, but MCTP_NET_ANY in "remote20" resolved + * to net=1, so lookup doesn't match "remote20" + */ + { .hdr = LK(20, 10), .ty = 1, .net = 3, .expect = "any" }, + + { .hdr = LK(21, 10), .ty = 1, .net = 3, .expect = "remote21net3" }, + { .hdr = LK(21, 10), .ty = 1, .net = 4, .expect = "remote21net4" }, + { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" }, + + { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" }, + + { .hdr = LK(99, 10), .ty = 1, .net = 8, .expect = "local10net8" }, + + { .hdr = LK(99, 10), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(0, 0), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(99, 99), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(20, 10), .ty = 1, .net = 9, .expect = "anynet9" }, +}; + +/* Binds to create during the lookup tests */ +static const struct mctp_test_bind_setup lookup_binds[] = { + /* any address and net, type 1 */ + { .name = "any", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, }, + /* local eid 10, net 1 (resolved from MCTP_NET_ANY) */ + { .name = "local10", .bind_addr = 10, + .bind_net = MCTP_NET_ANY, .bind_type = 1, }, + /* local eid 10, net 8 */ + { .name = "local10net8", .bind_addr = 10, + .bind_net = 8, .bind_type = 1, }, + /* any EID, net 9 */ + { .name = "anynet9", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 9, .bind_type = 1, }, + + /* remote eid 20, net 1, any local eid */ + { .name = "remote20", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 20, .peer_net = MCTP_NET_ANY, }, + + /* remote eid 20, net 1, local eid 11 */ + { .name = "remote21local11", .bind_addr = 11, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = MCTP_NET_ANY, }, + + /* remote eid 21, specific net=3 for connect() */ + { .name = "remote21net3", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 3, }, + + /* remote eid 21, net 4 for bind, specific net=4 for connect() */ + { .name = "remote21net4", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 4, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 4, }, + + /* remote eid 21, net 5 for bind, specific net=5 for connect() */ + { .name = "remote21net5", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 5, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 5, }, +}; + +static void mctp_bind_lookup_desc(const struct mctp_bind_lookup_test *t, + char *desc) +{ + snprintf(desc, KUNIT_PARAM_DESC_SIZE, + "{src %d dst %d ty %d net %d expect %s}", + t->hdr.src, t->hdr.dest, t->ty, t->net, t->expect); +} + +KUNIT_ARRAY_PARAM(mctp_bind_lookup, mctp_bind_lookup_tests, + mctp_bind_lookup_desc); + +static void mctp_test_bind_lookup(struct kunit *test) +{ + const struct mctp_bind_lookup_test *rx; + struct socket *socks[ARRAY_SIZE(lookup_binds)]; + struct sk_buff *skb_pkt = NULL, *skb_sock = NULL; + struct socket *sock_ty0, *sock_expect = NULL; + struct mctp_test_dev *dev; + struct mctp_dst dst; + int rc; + + rx = test->param_value; + + __mctp_route_test_init(test, &dev, &dst, &sock_ty0, rx->net); + /* Create all binds */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) { + mctp_test_bind_run(test, &lookup_binds[i], + &rc, &socks[i]); + KUNIT_ASSERT_EQ(test, rc, 0); + + /* Record the expected receive socket */ + if (rx->expect && + strcmp(rx->expect, lookup_binds[i].name) == 0) { + KUNIT_ASSERT_NULL(test, sock_expect); + sock_expect = socks[i]; + } + } + KUNIT_ASSERT_EQ(test, !!sock_expect, !!rx->expect); + + /* Create test message */ + skb_pkt = mctp_test_create_skb_data(&rx->hdr, &rx->ty); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb_pkt); + mctp_test_skb_set_dev(skb_pkt, dev); + + rc = mctp_dst_input(&dst, skb_pkt); + if (rx->expect) { + /* Test the message is received on the expected socket */ + KUNIT_EXPECT_EQ(test, rc, 0); + skb_sock = skb_recv_datagram(sock_expect->sk, + MSG_DONTWAIT, &rc); + if (!skb_sock) { + /* Find which socket received it instead */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) { + skb_sock = skb_recv_datagram(socks[i]->sk, + MSG_DONTWAIT, &rc); + if (skb_sock) { + KUNIT_FAIL(test, + "received on incorrect socket '%s', expect '%s'", + lookup_binds[i].name, + rx->expect); + goto cleanup; + } + } + KUNIT_FAIL(test, "no message received"); + } + } else { + KUNIT_EXPECT_NE(test, rc, 0); + } + +cleanup: + kfree_skb(skb_sock); + + /* Drop all binds */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) + sock_release(socks[i]); + + __mctp_route_test_fini(test, dev, &dst, sock_ty0); +} + +static struct kunit_case mctp_test_cases[] = { + KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params), + KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params), + KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params), + KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm, + mctp_route_input_sk_reasm_gen_params), + KUNIT_CASE_PARAM(mctp_test_route_input_sk_keys, + mctp_route_input_sk_keys_gen_params), + KUNIT_CASE(mctp_test_route_input_sk_fail_single), + KUNIT_CASE(mctp_test_route_input_sk_fail_frag), + KUNIT_CASE(mctp_test_route_input_multiple_nets_bind), + KUNIT_CASE(mctp_test_route_input_multiple_nets_key), + KUNIT_CASE(mctp_test_packet_flow), + KUNIT_CASE(mctp_test_fragment_flow), + KUNIT_CASE(mctp_test_route_output_key_create), + KUNIT_CASE(mctp_test_route_input_cloned_frag), + KUNIT_CASE(mctp_test_route_extaddr_input), + KUNIT_CASE(mctp_test_route_gw_lookup), + KUNIT_CASE(mctp_test_route_gw_loop), + KUNIT_CASE_PARAM(mctp_test_route_gw_mtu, mctp_route_gw_mtu_gen_params), + KUNIT_CASE(mctp_test_route_gw_output), + KUNIT_CASE_PARAM(mctp_test_bind_lookup, mctp_bind_lookup_gen_params), + {} +}; + +static struct kunit_suite mctp_test_suite = { + .name = "mctp-route", + .test_cases = mctp_test_cases, +}; + +kunit_test_suite(mctp_test_suite); diff --git a/net/mctp/test/sock-test.c b/net/mctp/test/sock-test.c new file mode 100644 index 000000000000..b0942deb5019 --- /dev/null +++ b/net/mctp/test/sock-test.c @@ -0,0 +1,396 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <kunit/static_stub.h> +#include <kunit/test.h> + +#include <linux/socket.h> +#include <linux/spinlock.h> + +#include "utils.h" + +static const u8 dev_default_lladdr[] = { 0x01, 0x02 }; + +/* helper for simple sock setup: single device, with dev_default_lladdr as its + * hardware address, assigned with a local EID 8, and a route to EID 9 + */ +static void __mctp_sock_test_init(struct kunit *test, + struct mctp_test_dev **devp, + struct mctp_test_route **rtp, + struct socket **sockp) +{ + struct mctp_test_route *rt; + struct mctp_test_dev *dev; + struct socket *sock; + unsigned long flags; + u8 *addrs; + int rc; + + dev = mctp_test_create_dev_lladdr(sizeof(dev_default_lladdr), + dev_default_lladdr); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + addrs = kmalloc(1, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, addrs); + addrs[0] = 8; + + spin_lock_irqsave(&dev->mdev->addrs_lock, flags); + dev->mdev->num_addrs = 1; + swap(addrs, dev->mdev->addrs); + spin_unlock_irqrestore(&dev->mdev->addrs_lock, flags); + + kfree(addrs); + + rt = mctp_test_create_route_direct(dev_net(dev->ndev), dev->mdev, 9, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt); + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + *devp = dev; + *rtp = rt; + *sockp = sock; +} + +static void __mctp_sock_test_fini(struct kunit *test, + struct mctp_test_dev *dev, + struct mctp_test_route *rt, + struct socket *sock) +{ + sock_release(sock); + mctp_test_route_destroy(test, rt); + mctp_test_destroy_dev(dev); +} + +struct mctp_test_sock_local_output_config { + struct mctp_test_dev *dev; + size_t halen; + u8 haddr[MAX_ADDR_LEN]; + bool invoked; + int rc; +}; + +static int mctp_test_sock_local_output(struct sock *sk, + struct mctp_dst *dst, + struct sk_buff *skb, + mctp_eid_t daddr, u8 req_tag) +{ + struct kunit *test = kunit_get_current_test(); + struct mctp_test_sock_local_output_config *cfg = test->priv; + + KUNIT_EXPECT_PTR_EQ(test, dst->dev, cfg->dev->mdev); + KUNIT_EXPECT_EQ(test, dst->halen, cfg->halen); + KUNIT_EXPECT_MEMEQ(test, dst->haddr, cfg->haddr, dst->halen); + + cfg->invoked = true; + + kfree_skb(skb); + + return cfg->rc; +} + +static void mctp_test_sock_sendmsg_extaddr(struct kunit *test) +{ + struct sockaddr_mctp_ext addr = { + .smctp_base = { + .smctp_family = AF_MCTP, + .smctp_tag = MCTP_TAG_OWNER, + .smctp_network = MCTP_NET_ANY, + }, + }; + struct mctp_test_sock_local_output_config cfg = { 0 }; + u8 haddr[] = { 0xaa, 0x01 }; + u8 buf[4] = { 0, 1, 2, 3 }; + struct mctp_test_route *rt; + struct msghdr msg = { 0 }; + struct mctp_test_dev *dev; + struct mctp_sock *msk; + struct socket *sock; + ssize_t send_len; + struct kvec vec = { + .iov_base = buf, + .iov_len = sizeof(buf), + }; + + __mctp_sock_test_init(test, &dev, &rt, &sock); + + /* Expect to see the dst configured up with the addressing data we + * provide in the struct sockaddr_mctp_ext + */ + cfg.dev = dev; + cfg.halen = sizeof(haddr); + memcpy(cfg.haddr, haddr, sizeof(haddr)); + + test->priv = &cfg; + + kunit_activate_static_stub(test, mctp_local_output, + mctp_test_sock_local_output); + + /* enable and configure direct addressing */ + msk = container_of(sock->sk, struct mctp_sock, sk); + msk->addr_ext = true; + + addr.smctp_ifindex = dev->ndev->ifindex; + addr.smctp_halen = sizeof(haddr); + memcpy(addr.smctp_haddr, haddr, sizeof(haddr)); + + msg.msg_name = &addr; + msg.msg_namelen = sizeof(addr); + + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &vec, 1, sizeof(buf)); + send_len = mctp_sendmsg(sock, &msg, sizeof(buf)); + KUNIT_EXPECT_EQ(test, send_len, sizeof(buf)); + KUNIT_EXPECT_TRUE(test, cfg.invoked); + + __mctp_sock_test_fini(test, dev, rt, sock); +} + +static void mctp_test_sock_recvmsg_extaddr(struct kunit *test) +{ + struct sockaddr_mctp_ext recv_addr = { 0 }; + u8 rcv_buf[1], rcv_data[] = { 0, 1 }; + u8 haddr[] = { 0xaa, 0x02 }; + struct mctp_test_route *rt; + struct mctp_test_dev *dev; + struct mctp_skb_cb *cb; + struct mctp_sock *msk; + struct sk_buff *skb; + struct mctp_hdr hdr; + struct socket *sock; + struct msghdr msg; + ssize_t recv_len; + int rc; + struct kvec vec = { + .iov_base = rcv_buf, + .iov_len = sizeof(rcv_buf), + }; + + __mctp_sock_test_init(test, &dev, &rt, &sock); + + /* enable extended addressing on recv */ + msk = container_of(sock->sk, struct mctp_sock, sk); + msk->addr_ext = true; + + /* base incoming header, using a nul-EID dest */ + hdr.ver = 1; + hdr.dest = 0; + hdr.src = 9; + hdr.flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM | + MCTP_HDR_FLAG_TO; + + skb = mctp_test_create_skb_data(&hdr, &rcv_data); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + mctp_test_skb_set_dev(skb, dev); + + /* set incoming extended address data */ + cb = mctp_cb(skb); + cb->halen = sizeof(haddr); + cb->ifindex = dev->ndev->ifindex; + memcpy(cb->haddr, haddr, sizeof(haddr)); + + /* Deliver to socket. The route input path pulls the network header, + * leaving skb data at type byte onwards. recvmsg will consume the + * type for addr.smctp_type + */ + skb_pull(skb, sizeof(hdr)); + rc = sock_queue_rcv_skb(sock->sk, skb); + KUNIT_ASSERT_EQ(test, rc, 0); + + msg.msg_name = &recv_addr; + msg.msg_namelen = sizeof(recv_addr); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, sizeof(rcv_buf)); + + recv_len = mctp_recvmsg(sock, &msg, sizeof(rcv_buf), + MSG_DONTWAIT | MSG_TRUNC); + + KUNIT_EXPECT_EQ(test, recv_len, sizeof(rcv_buf)); + + /* expect our extended address to be populated from hdr and cb */ + KUNIT_EXPECT_EQ(test, msg.msg_namelen, sizeof(recv_addr)); + KUNIT_EXPECT_EQ(test, recv_addr.smctp_base.smctp_family, AF_MCTP); + KUNIT_EXPECT_EQ(test, recv_addr.smctp_ifindex, dev->ndev->ifindex); + KUNIT_EXPECT_EQ(test, recv_addr.smctp_halen, sizeof(haddr)); + KUNIT_EXPECT_MEMEQ(test, recv_addr.smctp_haddr, haddr, sizeof(haddr)); + + __mctp_sock_test_fini(test, dev, rt, sock); +} + +static const struct mctp_test_bind_setup bind_addrany_netdefault_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = MCTP_NET_ANY, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1, +}; + +/* 1 is default net */ +static const struct mctp_test_bind_setup bind_addr8_net1_type1 = { + .bind_addr = 8, .bind_net = 1, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net1_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1, +}; + +/* 2 is an arbitrary net */ +static const struct mctp_test_bind_setup bind_addr8_net2_type1 = { + .bind_addr = 8, .bind_net = 2, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addr8_netdefault_type1 = { + .bind_addr = 8, .bind_net = MCTP_NET_ANY, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type2 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 2, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type1_peer9 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1, + .have_peer = true, .peer_addr = 9, .peer_net = 2, +}; + +struct mctp_bind_pair_test { + const struct mctp_test_bind_setup *bind1; + const struct mctp_test_bind_setup *bind2; + int error; +}; + +/* Pairs of binds and whether they will conflict */ +static const struct mctp_bind_pair_test mctp_bind_pair_tests[] = { + /* Both ADDR_ANY, conflict */ + { &bind_addrany_netdefault_type1, &bind_addrany_netdefault_type1, + EADDRINUSE }, + /* Same specific EID, conflict */ + { &bind_addr8_netdefault_type1, &bind_addr8_netdefault_type1, + EADDRINUSE }, + /* ADDR_ANY vs specific EID, OK */ + { &bind_addrany_netdefault_type1, &bind_addr8_netdefault_type1, 0 }, + /* ADDR_ANY different types, OK */ + { &bind_addrany_net2_type2, &bind_addrany_net2_type1, 0 }, + /* ADDR_ANY different nets, OK */ + { &bind_addrany_net2_type1, &bind_addrany_netdefault_type1, 0 }, + + /* specific EID, NET_ANY (resolves to default) + * vs specific EID, explicit default net 1, conflict + */ + { &bind_addr8_netdefault_type1, &bind_addr8_net1_type1, EADDRINUSE }, + + /* specific EID, net 1 vs specific EID, net 2, ok */ + { &bind_addr8_net1_type1, &bind_addr8_net2_type1, 0 }, + + /* ANY_ADDR, NET_ANY (doesn't resolve to default) + * vs ADDR_ANY, explicit default net 1, OK + */ + { &bind_addrany_netdefault_type1, &bind_addrany_net1_type1, 0 }, + + /* specific remote peer doesn't conflict with any-peer bind */ + { &bind_addrany_net2_type1_peer9, &bind_addrany_net2_type1, 0 }, + + /* bind() NET_ANY is allowed with a connect() net */ + { &bind_addrany_net2_type1_peer9, &bind_addrany_netdefault_type1, 0 }, +}; + +static void mctp_bind_pair_desc(const struct mctp_bind_pair_test *t, char *desc) +{ + char peer1[25] = {0}, peer2[25] = {0}; + + if (t->bind1->have_peer) + snprintf(peer1, sizeof(peer1), ", peer %d net %d", + t->bind1->peer_addr, t->bind1->peer_net); + if (t->bind2->have_peer) + snprintf(peer2, sizeof(peer2), ", peer %d net %d", + t->bind2->peer_addr, t->bind2->peer_net); + + snprintf(desc, KUNIT_PARAM_DESC_SIZE, + "{bind(addr %d, type %d, net %d%s)} {bind(addr %d, type %d, net %d%s)} -> error %d", + t->bind1->bind_addr, t->bind1->bind_type, + t->bind1->bind_net, peer1, + t->bind2->bind_addr, t->bind2->bind_type, + t->bind2->bind_net, peer2, t->error); +} + +KUNIT_ARRAY_PARAM(mctp_bind_pair, mctp_bind_pair_tests, mctp_bind_pair_desc); + +static void mctp_test_bind_invalid(struct kunit *test) +{ + struct socket *sock; + int rc; + + /* bind() fails if the bind() vs connect() networks mismatch. */ + const struct mctp_test_bind_setup bind_connect_net_mismatch = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1, + .have_peer = true, .peer_addr = 9, .peer_net = 2, + }; + mctp_test_bind_run(test, &bind_connect_net_mismatch, &rc, &sock); + KUNIT_EXPECT_EQ(test, -rc, EINVAL); + sock_release(sock); +} + +static int +mctp_test_bind_conflicts_inner(struct kunit *test, + const struct mctp_test_bind_setup *bind1, + const struct mctp_test_bind_setup *bind2) +{ + struct socket *sock1 = NULL, *sock2 = NULL, *sock3 = NULL; + int bind_errno; + + /* Bind to first address, always succeeds */ + mctp_test_bind_run(test, bind1, &bind_errno, &sock1); + KUNIT_EXPECT_EQ(test, bind_errno, 0); + + /* A second identical bind always fails */ + mctp_test_bind_run(test, bind1, &bind_errno, &sock2); + KUNIT_EXPECT_EQ(test, -bind_errno, EADDRINUSE); + + /* A different bind, result is returned */ + mctp_test_bind_run(test, bind2, &bind_errno, &sock3); + + if (sock1) + sock_release(sock1); + if (sock2) + sock_release(sock2); + if (sock3) + sock_release(sock3); + + return bind_errno; +} + +static void mctp_test_bind_conflicts(struct kunit *test) +{ + const struct mctp_bind_pair_test *pair; + int bind_errno; + + pair = test->param_value; + + bind_errno = + mctp_test_bind_conflicts_inner(test, pair->bind1, pair->bind2); + KUNIT_EXPECT_EQ(test, -bind_errno, pair->error); + + /* swapping the calls, the second bind should still fail */ + bind_errno = + mctp_test_bind_conflicts_inner(test, pair->bind2, pair->bind1); + KUNIT_EXPECT_EQ(test, -bind_errno, pair->error); +} + +static void mctp_test_assumptions(struct kunit *test) +{ + /* check assumption of default net from bind_addr8_net1_type1 */ + KUNIT_ASSERT_EQ(test, mctp_default_net(&init_net), 1); +} + +static struct kunit_case mctp_test_cases[] = { + KUNIT_CASE(mctp_test_assumptions), + KUNIT_CASE(mctp_test_sock_sendmsg_extaddr), + KUNIT_CASE(mctp_test_sock_recvmsg_extaddr), + KUNIT_CASE_PARAM(mctp_test_bind_conflicts, mctp_bind_pair_gen_params), + KUNIT_CASE(mctp_test_bind_invalid), + {} +}; + +static struct kunit_suite mctp_test_suite = { + .name = "mctp-sock", + .test_cases = mctp_test_cases, +}; + +kunit_test_suite(mctp_test_suite); diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c new file mode 100644 index 000000000000..37f1ba62a2ab --- /dev/null +++ b/net/mctp/test/utils.c @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/netdevice.h> +#include <linux/mctp.h> +#include <linux/if_arp.h> + +#include <net/mctp.h> +#include <net/mctpdevice.h> +#include <net/pkt_sched.h> + +#include "utils.h" + +static netdev_tx_t mctp_test_dev_tx(struct sk_buff *skb, + struct net_device *ndev) +{ + struct mctp_test_dev *dev = netdev_priv(ndev); + + skb_queue_tail(&dev->pkts, skb); + + return NETDEV_TX_OK; +} + +static const struct net_device_ops mctp_test_netdev_ops = { + .ndo_start_xmit = mctp_test_dev_tx, +}; + +static void mctp_test_dev_setup(struct net_device *ndev) +{ + ndev->type = ARPHRD_MCTP; + ndev->mtu = MCTP_DEV_TEST_MTU; + ndev->hard_header_len = 0; + ndev->tx_queue_len = 0; + ndev->flags = IFF_NOARP; + ndev->netdev_ops = &mctp_test_netdev_ops; + ndev->needs_free_netdev = true; +} + +static struct mctp_test_dev *__mctp_test_create_dev(unsigned short lladdr_len, + const unsigned char *lladdr) +{ + struct mctp_test_dev *dev; + struct net_device *ndev; + int rc; + + if (WARN_ON(lladdr_len > MAX_ADDR_LEN)) + return NULL; + + ndev = alloc_netdev(sizeof(*dev), "mctptest%d", NET_NAME_ENUM, + mctp_test_dev_setup); + if (!ndev) + return NULL; + + dev = netdev_priv(ndev); + dev->ndev = ndev; + ndev->addr_len = lladdr_len; + dev_addr_set(ndev, lladdr); + skb_queue_head_init(&dev->pkts); + + rc = register_netdev(ndev); + if (rc) { + free_netdev(ndev); + return NULL; + } + + rcu_read_lock(); + dev->mdev = __mctp_dev_get(ndev); + dev->mdev->net = mctp_default_net(dev_net(ndev)); + rcu_read_unlock(); + + /* bring the device up; we want to be able to TX immediately */ + rtnl_lock(); + dev_open(ndev, NULL); + rtnl_unlock(); + + return dev; +} + +struct mctp_test_dev *mctp_test_create_dev(void) +{ + return __mctp_test_create_dev(0, NULL); +} + +struct mctp_test_dev *mctp_test_create_dev_lladdr(unsigned short lladdr_len, + const unsigned char *lladdr) +{ + return __mctp_test_create_dev(lladdr_len, lladdr); +} + +void mctp_test_destroy_dev(struct mctp_test_dev *dev) +{ + skb_queue_purge(&dev->pkts); + mctp_dev_put(dev->mdev); + unregister_netdev(dev->ndev); +} + +static int mctp_test_dst_output(struct mctp_dst *dst, struct sk_buff *skb) +{ + skb->dev = dst->dev->dev; + dev_queue_xmit(skb); + + return 0; +} + +/* local version of mctp_route_alloc() */ +static struct mctp_test_route *mctp_route_test_alloc(void) +{ + struct mctp_test_route *rt; + + rt = kzalloc(sizeof(*rt), GFP_KERNEL); + if (!rt) + return NULL; + + INIT_LIST_HEAD(&rt->rt.list); + refcount_set(&rt->rt.refs, 1); + rt->rt.output = mctp_test_dst_output; + + return rt; +} + +struct mctp_test_route *mctp_test_create_route_direct(struct net *net, + struct mctp_dev *dev, + mctp_eid_t eid, + unsigned int mtu) +{ + struct mctp_test_route *rt; + + rt = mctp_route_test_alloc(); + if (!rt) + return NULL; + + rt->rt.min = eid; + rt->rt.max = eid; + rt->rt.mtu = mtu; + rt->rt.type = RTN_UNSPEC; + rt->rt.dst_type = MCTP_ROUTE_DIRECT; + if (dev) + mctp_dev_hold(dev); + rt->rt.dev = dev; + + list_add_rcu(&rt->rt.list, &net->mctp.routes); + + return rt; +} + +struct mctp_test_route *mctp_test_create_route_gw(struct net *net, + unsigned int netid, + mctp_eid_t eid, + mctp_eid_t gw, + unsigned int mtu) +{ + struct mctp_test_route *rt; + + rt = mctp_route_test_alloc(); + if (!rt) + return NULL; + + rt->rt.min = eid; + rt->rt.max = eid; + rt->rt.mtu = mtu; + rt->rt.type = RTN_UNSPEC; + rt->rt.dst_type = MCTP_ROUTE_GATEWAY; + rt->rt.gateway.eid = gw; + rt->rt.gateway.net = netid; + + list_add_rcu(&rt->rt.list, &net->mctp.routes); + + return rt; +} + +/* Convenience function for our test dst; release with mctp_dst_release() */ +void mctp_test_dst_setup(struct kunit *test, struct mctp_dst *dst, + struct mctp_test_dev *dev, unsigned int mtu) +{ + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, dev); + + memset(dst, 0, sizeof(*dst)); + + dst->dev = dev->mdev; + __mctp_dev_get(dst->dev->dev); + dst->mtu = mtu; + dst->output = mctp_test_dst_output; +} + +void mctp_test_route_destroy(struct kunit *test, struct mctp_test_route *rt) +{ + unsigned int refs; + + rtnl_lock(); + list_del_rcu(&rt->rt.list); + rtnl_unlock(); + + if (rt->rt.dst_type == MCTP_ROUTE_DIRECT && rt->rt.dev) + mctp_dev_put(rt->rt.dev); + + refs = refcount_read(&rt->rt.refs); + KUNIT_ASSERT_EQ_MSG(test, refs, 1, "route ref imbalance"); + + kfree_rcu(&rt->rt, rcu); +} + +void mctp_test_skb_set_dev(struct sk_buff *skb, struct mctp_test_dev *dev) +{ + struct mctp_skb_cb *cb; + + cb = mctp_cb(skb); + cb->net = READ_ONCE(dev->mdev->net); + skb->dev = dev->ndev; +} + +struct sk_buff *mctp_test_create_skb(const struct mctp_hdr *hdr, + unsigned int data_len) +{ + size_t hdr_len = sizeof(*hdr); + struct sk_buff *skb; + unsigned int i; + u8 *buf; + + skb = alloc_skb(hdr_len + data_len, GFP_KERNEL); + if (!skb) + return NULL; + + __mctp_cb(skb); + memcpy(skb_put(skb, hdr_len), hdr, hdr_len); + + buf = skb_put(skb, data_len); + for (i = 0; i < data_len; i++) + buf[i] = i & 0xff; + + return skb; +} + +struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, + const void *data, size_t data_len) +{ + size_t hdr_len = sizeof(*hdr); + struct sk_buff *skb; + + skb = alloc_skb(hdr_len + data_len, GFP_KERNEL); + if (!skb) + return NULL; + + __mctp_cb(skb); + memcpy(skb_put(skb, hdr_len), hdr, hdr_len); + memcpy(skb_put(skb, data_len), data, data_len); + + return skb; +} + +void mctp_test_bind_run(struct kunit *test, + const struct mctp_test_bind_setup *setup, + int *ret_bind_errno, struct socket **sock) +{ + struct sockaddr_mctp addr; + int rc; + + *ret_bind_errno = -EIO; + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + /* connect() if requested */ + if (setup->have_peer) { + memset(&addr, 0x0, sizeof(addr)); + addr.smctp_family = AF_MCTP; + addr.smctp_network = setup->peer_net; + addr.smctp_addr.s_addr = setup->peer_addr; + /* connect() type must match bind() type */ + addr.smctp_type = setup->bind_type; + rc = kernel_connect(*sock, (struct sockaddr_unsized *)&addr, + sizeof(addr), 0); + KUNIT_EXPECT_EQ(test, rc, 0); + } + + /* bind() */ + memset(&addr, 0x0, sizeof(addr)); + addr.smctp_family = AF_MCTP; + addr.smctp_network = setup->bind_net; + addr.smctp_addr.s_addr = setup->bind_addr; + addr.smctp_type = setup->bind_type; + + *ret_bind_errno = + kernel_bind(*sock, (struct sockaddr_unsized *)&addr, + sizeof(addr)); +} diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h new file mode 100644 index 000000000000..4cc90c9da4d1 --- /dev/null +++ b/net/mctp/test/utils.h @@ -0,0 +1,74 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#ifndef __NET_MCTP_TEST_UTILS_H +#define __NET_MCTP_TEST_UTILS_H + +#include <uapi/linux/netdevice.h> + +#include <net/mctp.h> +#include <net/mctpdevice.h> + +#include <kunit/test.h> + +#define MCTP_DEV_TEST_MTU 68 + +struct mctp_test_dev { + struct net_device *ndev; + struct mctp_dev *mdev; + + unsigned short lladdr_len; + unsigned char lladdr[MAX_ADDR_LEN]; + + struct sk_buff_head pkts; +}; + +struct mctp_test_dev; + +struct mctp_test_route { + struct mctp_route rt; +}; + +struct mctp_test_bind_setup { + mctp_eid_t bind_addr; + int bind_net; + u8 bind_type; + + bool have_peer; + mctp_eid_t peer_addr; + int peer_net; + + /* optional name. Used for comparison in "lookup" tests */ + const char *name; +}; + +struct mctp_test_dev *mctp_test_create_dev(void); +struct mctp_test_dev *mctp_test_create_dev_lladdr(unsigned short lladdr_len, + const unsigned char *lladdr); +void mctp_test_destroy_dev(struct mctp_test_dev *dev); + +struct mctp_test_route *mctp_test_create_route_direct(struct net *net, + struct mctp_dev *dev, + mctp_eid_t eid, + unsigned int mtu); +struct mctp_test_route *mctp_test_create_route_gw(struct net *net, + unsigned int netid, + mctp_eid_t eid, + mctp_eid_t gw, + unsigned int mtu); +void mctp_test_dst_setup(struct kunit *test, struct mctp_dst *dst, + struct mctp_test_dev *dev, unsigned int mtu); +void mctp_test_route_destroy(struct kunit *test, struct mctp_test_route *rt); +void mctp_test_skb_set_dev(struct sk_buff *skb, struct mctp_test_dev *dev); +struct sk_buff *mctp_test_create_skb(const struct mctp_hdr *hdr, + unsigned int data_len); +struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, + const void *data, size_t data_len); + +#define mctp_test_create_skb_data(h, d) \ + __mctp_test_create_skb_data(h, d, sizeof(*d)) + +void mctp_test_bind_run(struct kunit *test, + const struct mctp_test_bind_setup *setup, + int *ret_bind_errno, struct socket **sock); + +#endif /* __NET_MCTP_TEST_UTILS_H */ |
