diff options
Diffstat (limited to 'net/mptcp/pm_userspace.c')
| -rw-r--r-- | net/mptcp/pm_userspace.c | 698 |
1 files changed, 698 insertions, 0 deletions
diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c new file mode 100644 index 000000000000..8cbc1920afb4 --- /dev/null +++ b/net/mptcp/pm_userspace.c @@ -0,0 +1,698 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2022, Intel Corporation. + */ + +#include "protocol.h" +#include "mib.h" +#include "mptcp_pm_gen.h" + +#define mptcp_for_each_userspace_pm_addr(__msk, __entry) \ + list_for_each_entry(__entry, \ + &((__msk)->pm.userspace_pm_local_addr_list), list) + +void mptcp_userspace_pm_free_local_addr_list(struct mptcp_sock *msk) +{ + struct mptcp_pm_addr_entry *entry, *tmp; + struct sock *sk = (struct sock *)msk; + LIST_HEAD(free_list); + + spin_lock_bh(&msk->pm.lock); + list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list); + spin_unlock_bh(&msk->pm.lock); + + list_for_each_entry_safe(entry, tmp, &free_list, list) { + sock_kfree_s(sk, entry, sizeof(*entry)); + } +} + +static struct mptcp_pm_addr_entry * +mptcp_userspace_pm_lookup_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + struct mptcp_pm_addr_entry *entry; + + mptcp_for_each_userspace_pm_addr(msk, entry) { + if (mptcp_addresses_equal(&entry->addr, addr, false)) + return entry; + } + return NULL; +} + +static int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *entry, + bool needs_id) +{ + DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); + struct sock *sk = (struct sock *)msk; + struct mptcp_pm_addr_entry *e; + bool addr_match = false; + bool id_match = false; + int ret = -EINVAL; + + bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); + + spin_lock_bh(&msk->pm.lock); + mptcp_for_each_userspace_pm_addr(msk, e) { + addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true); + if (addr_match && entry->addr.id == 0 && needs_id) + entry->addr.id = e->addr.id; + id_match = (e->addr.id == entry->addr.id); + if (addr_match || id_match) + break; + __set_bit(e->addr.id, id_bitmap); + } + + if (!addr_match && !id_match) { + /* Memory for the entry is allocated from the + * sock option buffer. + */ + e = sock_kmemdup(sk, entry, sizeof(*entry), GFP_ATOMIC); + if (!e) { + ret = -ENOMEM; + goto append_err; + } + + if (!e->addr.id && needs_id) + e->addr.id = find_next_zero_bit(id_bitmap, + MPTCP_PM_MAX_ADDR_ID + 1, + 1); + list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list); + msk->pm.local_addr_used++; + ret = e->addr.id; + } else if (addr_match && id_match) { + ret = entry->addr.id; + } + +append_err: + spin_unlock_bh(&msk->pm.lock); + return ret; +} + +/* If the subflow is closed from the other peer (not via a + * subflow destroy command then), we want to keep the entry + * not to assign the same ID to another address and to be + * able to send RM_ADDR after the removal of the subflow. + */ +static int mptcp_userspace_pm_delete_local_addr(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *addr) +{ + struct sock *sk = (struct sock *)msk; + struct mptcp_pm_addr_entry *entry; + + entry = mptcp_userspace_pm_lookup_addr(msk, &addr->addr); + if (!entry) + return -EINVAL; + + /* TODO: a refcount is needed because the entry can + * be used multiple times (e.g. fullmesh mode). + */ + list_del_rcu(&entry->list); + sock_kfree_s(sk, entry, sizeof(*entry)); + msk->pm.local_addr_used--; + return 0; +} + +static struct mptcp_pm_addr_entry * +mptcp_userspace_pm_lookup_addr_by_id(struct mptcp_sock *msk, unsigned int id) +{ + struct mptcp_pm_addr_entry *entry; + + mptcp_for_each_userspace_pm_addr(msk, entry) { + if (entry->addr.id == id) + return entry; + } + return NULL; +} + +int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *skc) +{ + __be16 msk_sport = ((struct inet_sock *) + inet_sk((struct sock *)msk))->inet_sport; + struct mptcp_pm_addr_entry *entry; + + spin_lock_bh(&msk->pm.lock); + entry = mptcp_userspace_pm_lookup_addr(msk, &skc->addr); + spin_unlock_bh(&msk->pm.lock); + if (entry) + return entry->addr.id; + + if (skc->addr.port == msk_sport) + skc->addr.port = 0; + + return mptcp_userspace_pm_append_new_local_addr(msk, skc, true); +} + +bool mptcp_userspace_pm_is_backup(struct mptcp_sock *msk, + struct mptcp_addr_info *skc) +{ + struct mptcp_pm_addr_entry *entry; + bool backup; + + spin_lock_bh(&msk->pm.lock); + entry = mptcp_userspace_pm_lookup_addr(msk, skc); + backup = entry && !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); + spin_unlock_bh(&msk->pm.lock); + + return backup; +} + +static struct mptcp_sock *mptcp_userspace_pm_get_sock(const struct genl_info *info) +{ + struct mptcp_sock *msk; + struct nlattr *token; + + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_TOKEN)) + return NULL; + + token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + msk = mptcp_token_get_sock(genl_info_net(info), nla_get_u32(token)); + if (!msk) { + NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + return NULL; + } + + if (!mptcp_pm_is_userspace(msk)) { + NL_SET_ERR_MSG_ATTR(info->extack, token, + "userspace PM not selected"); + sock_put((struct sock *)msk); + return NULL; + } + + return msk; +} + +int mptcp_pm_nl_announce_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct mptcp_pm_addr_entry addr_val; + struct mptcp_sock *msk; + struct nlattr *addr; + int err = -EINVAL; + struct sock *sk; + + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR)) + return err; + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return err; + + sk = (struct sock *)msk; + + addr = info->attrs[MPTCP_PM_ATTR_ADDR]; + err = mptcp_pm_parse_entry(addr, info, true, &addr_val); + if (err < 0) + goto announce_err; + + if (addr_val.addr.id == 0) { + NL_SET_ERR_MSG_ATTR(info->extack, addr, "invalid addr id"); + err = -EINVAL; + goto announce_err; + } + + if (!(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + NL_SET_ERR_MSG_ATTR(info->extack, addr, "invalid addr flags"); + err = -EINVAL; + goto announce_err; + } + + err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val, false); + if (err < 0) { + NL_SET_ERR_MSG_ATTR(info->extack, addr, + "did not match address and id"); + goto announce_err; + } + + lock_sock(sk); + spin_lock_bh(&msk->pm.lock); + + if (mptcp_pm_alloc_anno_list(msk, &addr_val.addr)) { + msk->pm.add_addr_signaled++; + mptcp_pm_announce_addr(msk, &addr_val.addr, false); + mptcp_pm_addr_send_ack(msk); + } + + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + + err = 0; + announce_err: + sock_put(sk); + return err; +} + +static int mptcp_userspace_pm_remove_id_zero_address(struct mptcp_sock *msk) +{ + struct mptcp_rm_list list = { .nr = 0 }; + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + bool has_id_0 = false; + int err = -EINVAL; + + lock_sock(sk); + mptcp_for_each_subflow(msk, subflow) { + if (READ_ONCE(subflow->local_id) == 0) { + has_id_0 = true; + break; + } + } + if (!has_id_0) + goto remove_err; + + list.ids[list.nr++] = 0; + + spin_lock_bh(&msk->pm.lock); + mptcp_pm_remove_addr(msk, &list); + spin_unlock_bh(&msk->pm.lock); + + err = 0; + +remove_err: + release_sock(sk); + return err; +} + +void mptcp_pm_remove_addr_entry(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *entry) +{ + struct mptcp_rm_list alist = { .nr = 0 }; + int anno_nr = 0; + + /* only delete if either announced or matching a subflow */ + if (mptcp_remove_anno_list_by_saddr(msk, &entry->addr)) + anno_nr++; + else if (!mptcp_lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) + return; + + alist.ids[alist.nr++] = entry->addr.id; + + spin_lock_bh(&msk->pm.lock); + msk->pm.add_addr_signaled -= anno_nr; + mptcp_pm_remove_addr(msk, &alist); + spin_unlock_bh(&msk->pm.lock); +} + +int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct mptcp_pm_addr_entry *match; + struct mptcp_sock *msk; + struct nlattr *id; + int err = -EINVAL; + struct sock *sk; + u8 id_val; + + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_LOC_ID)) + return err; + + id = info->attrs[MPTCP_PM_ATTR_LOC_ID]; + id_val = nla_get_u8(id); + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return err; + + sk = (struct sock *)msk; + + if (id_val == 0) { + err = mptcp_userspace_pm_remove_id_zero_address(msk); + goto out; + } + + lock_sock(sk); + + spin_lock_bh(&msk->pm.lock); + match = mptcp_userspace_pm_lookup_addr_by_id(msk, id_val); + if (!match) { + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + goto out; + } + + list_del_rcu(&match->list); + spin_unlock_bh(&msk->pm.lock); + + mptcp_pm_remove_addr_entry(msk, match); + + release_sock(sk); + + kfree_rcu_mightsleep(match); + /* Adjust sk_omem_alloc like sock_kfree_s() does, to match + * with allocation of this memory by sock_kmemdup() + */ + atomic_sub(sizeof(*match), &sk->sk_omem_alloc); + + err = 0; +out: + if (err) + NL_SET_ERR_MSG_ATTR_FMT(info->extack, id, + "address with id %u not found", + id_val); + + sock_put(sk); + return err; +} + +int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct mptcp_pm_addr_entry entry = { 0 }; + struct mptcp_addr_info addr_r; + struct nlattr *raddr, *laddr; + struct mptcp_pm_local local; + struct mptcp_sock *msk; + int err = -EINVAL; + struct sock *sk; + + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR) || + GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR_REMOTE)) + return err; + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return err; + + sk = (struct sock *)msk; + + laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; + err = mptcp_pm_parse_entry(laddr, info, true, &entry); + if (err < 0) + goto create_err; + + if (entry.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { + NL_SET_ERR_MSG_ATTR(info->extack, laddr, "invalid addr flags"); + err = -EINVAL; + goto create_err; + } + entry.flags |= MPTCP_PM_ADDR_FLAG_SUBFLOW; + + raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; + err = mptcp_pm_parse_addr(raddr, info, &addr_r); + if (err < 0) + goto create_err; + + if (!mptcp_pm_addr_families_match(sk, &entry.addr, &addr_r)) { + GENL_SET_ERR_MSG(info, "families mismatch"); + err = -EINVAL; + goto create_err; + } + + err = mptcp_userspace_pm_append_new_local_addr(msk, &entry, false); + if (err < 0) { + NL_SET_ERR_MSG_ATTR(info->extack, laddr, + "did not match address and id"); + goto create_err; + } + + local.addr = entry.addr; + local.flags = entry.flags; + local.ifindex = entry.ifindex; + + lock_sock(sk); + err = __mptcp_subflow_connect(sk, &local, &addr_r); + release_sock(sk); + + if (err) + GENL_SET_ERR_MSG_FMT(info, "connect error: %d", err); + + spin_lock_bh(&msk->pm.lock); + if (err) + mptcp_userspace_pm_delete_local_addr(msk, &entry); + else + msk->pm.extra_subflows++; + spin_unlock_bh(&msk->pm.lock); + + create_err: + sock_put(sk); + return err; +} + +static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk, + const struct mptcp_addr_info *local, + const struct mptcp_addr_info *remote) +{ + struct mptcp_subflow_context *subflow; + + if (local->family != remote->family) + return NULL; + + mptcp_for_each_subflow(msk, subflow) { + const struct inet_sock *issk; + struct sock *ssk; + + ssk = mptcp_subflow_tcp_sock(subflow); + + if (local->family != ssk->sk_family) + continue; + + issk = inet_sk(ssk); + + switch (ssk->sk_family) { + case AF_INET: + if (issk->inet_saddr != local->addr.s_addr || + issk->inet_daddr != remote->addr.s_addr) + continue; + break; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + case AF_INET6: { + if (!ipv6_addr_equal(&local->addr6, &issk->pinet6->saddr) || + !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr)) + continue; + break; + } +#endif + default: + continue; + } + + if (issk->inet_sport == local->port && + issk->inet_dport == remote->port) + return ssk; + } + + return NULL; +} + +int mptcp_pm_nl_subflow_destroy_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct mptcp_pm_addr_entry addr_l; + struct mptcp_addr_info addr_r; + struct nlattr *raddr, *laddr; + struct mptcp_sock *msk; + struct sock *sk, *ssk; + int err = -EINVAL; + + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR) || + GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR_REMOTE)) + return err; + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return err; + + sk = (struct sock *)msk; + + laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; + err = mptcp_pm_parse_entry(laddr, info, true, &addr_l); + if (err < 0) + goto destroy_err; + + raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; + err = mptcp_pm_parse_addr(raddr, info, &addr_r); + if (err < 0) + goto destroy_err; + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (addr_l.addr.family == AF_INET && ipv6_addr_v4mapped(&addr_r.addr6)) { + ipv6_addr_set_v4mapped(addr_l.addr.addr.s_addr, &addr_l.addr.addr6); + addr_l.addr.family = AF_INET6; + } + if (addr_r.family == AF_INET && ipv6_addr_v4mapped(&addr_l.addr.addr6)) { + ipv6_addr_set_v4mapped(addr_r.addr.s_addr, &addr_r.addr6); + addr_r.family = AF_INET6; + } +#endif + if (addr_l.addr.family != addr_r.family) { + GENL_SET_ERR_MSG(info, "address families do not match"); + err = -EINVAL; + goto destroy_err; + } + + if (!addr_l.addr.port) { + NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local port"); + err = -EINVAL; + goto destroy_err; + } + + if (!addr_r.port) { + NL_SET_ERR_MSG_ATTR(info->extack, raddr, "missing remote port"); + err = -EINVAL; + goto destroy_err; + } + + lock_sock(sk); + ssk = mptcp_nl_find_ssk(msk, &addr_l.addr, &addr_r); + if (!ssk) { + GENL_SET_ERR_MSG(info, "subflow not found"); + err = -ESRCH; + goto release_sock; + } + + spin_lock_bh(&msk->pm.lock); + mptcp_userspace_pm_delete_local_addr(msk, &addr_l); + spin_unlock_bh(&msk->pm.lock); + mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); + mptcp_close_ssk(sk, ssk, mptcp_subflow_ctx(ssk)); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); +release_sock: + release_sock(sk); + +destroy_err: + sock_put(sk); + return err; +} + +int mptcp_userspace_pm_set_flags(struct mptcp_pm_addr_entry *local, + struct genl_info *info) +{ + struct mptcp_addr_info rem = { .family = AF_UNSPEC, }; + struct mptcp_pm_addr_entry *entry; + struct nlattr *attr, *attr_rem; + struct mptcp_sock *msk; + int ret = -EINVAL; + struct sock *sk; + u8 bkup = 0; + + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR_REMOTE)) + return ret; + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return ret; + + sk = (struct sock *)msk; + + attr = info->attrs[MPTCP_PM_ATTR_ADDR]; + if (local->addr.family == AF_UNSPEC) { + NL_SET_ERR_MSG_ATTR(info->extack, attr, + "invalid local address family"); + ret = -EINVAL; + goto set_flags_err; + } + + attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; + ret = mptcp_pm_parse_addr(attr_rem, info, &rem); + if (ret < 0) + goto set_flags_err; + + if (rem.family == AF_UNSPEC) { + NL_SET_ERR_MSG_ATTR(info->extack, attr_rem, + "invalid remote address family"); + ret = -EINVAL; + goto set_flags_err; + } + + if (local->flags & MPTCP_PM_ADDR_FLAG_BACKUP) + bkup = 1; + + spin_lock_bh(&msk->pm.lock); + entry = mptcp_userspace_pm_lookup_addr(msk, &local->addr); + if (entry) { + if (bkup) + entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP; + else + entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP; + } + spin_unlock_bh(&msk->pm.lock); + + lock_sock(sk); + ret = mptcp_pm_mp_prio_send_ack(msk, &local->addr, &rem, bkup); + release_sock(sk); + + /* mptcp_pm_mp_prio_send_ack() only fails in one case */ + if (ret < 0) + GENL_SET_ERR_MSG(info, "subflow not found"); + +set_flags_err: + sock_put(sk); + return ret; +} + +int mptcp_userspace_pm_dump_addr(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct id_bitmap { + DECLARE_BITMAP(map, MPTCP_PM_MAX_ADDR_ID + 1); + } *bitmap; + const struct genl_info *info = genl_info_dump(cb); + struct mptcp_pm_addr_entry *entry; + struct mptcp_sock *msk; + int ret = -EINVAL; + struct sock *sk; + + BUILD_BUG_ON(sizeof(struct id_bitmap) > sizeof(cb->ctx)); + + bitmap = (struct id_bitmap *)cb->ctx; + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return ret; + + sk = (struct sock *)msk; + + lock_sock(sk); + spin_lock_bh(&msk->pm.lock); + mptcp_for_each_userspace_pm_addr(msk, entry) { + if (test_bit(entry->addr.id, bitmap->map)) + continue; + + if (mptcp_pm_genl_fill_addr(msg, cb, entry) < 0) + break; + + __set_bit(entry->addr.id, bitmap->map); + } + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + ret = msg->len; + + sock_put(sk); + return ret; +} + +int mptcp_userspace_pm_get_addr(u8 id, struct mptcp_pm_addr_entry *addr, + struct genl_info *info) +{ + struct mptcp_pm_addr_entry *entry; + struct mptcp_sock *msk; + int ret = -EINVAL; + struct sock *sk; + + msk = mptcp_userspace_pm_get_sock(info); + if (!msk) + return ret; + + sk = (struct sock *)msk; + + lock_sock(sk); + spin_lock_bh(&msk->pm.lock); + entry = mptcp_userspace_pm_lookup_addr_by_id(msk, id); + if (entry) { + *addr = *entry; + ret = 0; + } + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + + sock_put(sk); + return ret; +} + +static struct mptcp_pm_ops mptcp_pm_userspace = { + .name = "userspace", + .owner = THIS_MODULE, +}; + +void __init mptcp_pm_userspace_register(void) +{ + mptcp_pm_register(&mptcp_pm_userspace); +} |
