summaryrefslogtreecommitdiff
path: root/net/ipv6/ip6mr.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv6/ip6mr.c')
-rw-r--r--net/ipv6/ip6mr.c673
1 files changed, 466 insertions, 207 deletions
diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c
index 7cf73e60e619..e047a4680ab0 100644
--- a/net/ipv6/ip6mr.c
+++ b/net/ipv6/ip6mr.c
@@ -62,7 +62,12 @@ struct ip6mr_result {
Note that the changes are semaphored via rtnl_lock.
*/
-static DEFINE_RWLOCK(mrt_lock);
+static DEFINE_SPINLOCK(mrt_lock);
+
+static struct net_device *vif_dev_read(const struct vif_device *vif)
+{
+ return rcu_dereference(vif->dev);
+}
/* Multicast router control variables */
@@ -85,11 +90,13 @@ static void ip6mr_free_table(struct mr_table *mrt);
static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
struct net_device *dev, struct sk_buff *skb,
struct mfc6_cache *cache);
-static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
+static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt,
mifi_t mifi, int assert);
static void mr6_netlink_event(struct mr_table *mrt, struct mfc6_cache *mfc,
int cmd);
-static void mrt6msg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt);
+static void mrt6msg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt);
+static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh,
+ struct netlink_ext_ack *extack);
static int ip6mr_rtm_dumproute(struct sk_buff *skb,
struct netlink_callback *cb);
static void mroute_clean_tables(struct mr_table *mrt, int flags);
@@ -118,7 +125,7 @@ static struct mr_table *ip6mr_mr_table_iter(struct net *net,
return ret;
}
-static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
+static struct mr_table *__ip6mr_get_table(struct net *net, u32 id)
{
struct mr_table *mrt;
@@ -129,6 +136,16 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
return NULL;
}
+static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
+{
+ struct mr_table *mrt;
+
+ rcu_read_lock();
+ mrt = __ip6mr_get_table(net, id);
+ rcu_read_unlock();
+ return mrt;
+}
+
static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6,
struct mr_table **mrt)
{
@@ -170,7 +187,7 @@ static int ip6mr_rule_action(struct fib_rule *rule, struct flowi *flp,
arg->table = fib_rule_get_table(rule, arg);
- mrt = ip6mr_get_table(rule->fr_net, arg->table);
+ mrt = __ip6mr_get_table(rule->fr_net, arg->table);
if (!mrt)
return -EAGAIN;
res->mrt = mrt;
@@ -235,7 +252,7 @@ static int __net_init ip6mr_rules_init(struct net *net)
goto err1;
}
- err = fib_default_rule_add(ops, 0x7fff, RT6_TABLE_DFLT, 0);
+ err = fib_default_rule_add(ops, 0x7fff, RT6_TABLE_DFLT);
if (err < 0)
goto err2;
@@ -243,7 +260,9 @@ static int __net_init ip6mr_rules_init(struct net *net)
return 0;
err2:
+ rtnl_lock();
ip6mr_free_table(mrt);
+ rtnl_unlock();
err1:
fib_rules_unregister(ops);
return err;
@@ -253,13 +272,12 @@ static void __net_exit ip6mr_rules_exit(struct net *net)
{
struct mr_table *mrt, *next;
- rtnl_lock();
+ ASSERT_RTNL();
list_for_each_entry_safe(mrt, next, &net->ipv6.mr6_tables, list) {
list_del(&mrt->list);
ip6mr_free_table(mrt);
}
fib_rules_unregister(net->ipv6.mr6_rules_ops);
- rtnl_unlock();
}
static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb,
@@ -268,7 +286,7 @@ static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb,
return fib_rules_dump(net, nb, RTNL_FAMILY_IP6MR, extack);
}
-static unsigned int ip6mr_rules_seq_read(struct net *net)
+static unsigned int ip6mr_rules_seq_read(const struct net *net)
{
return fib_rules_seq_read(net, RTNL_FAMILY_IP6MR);
}
@@ -296,6 +314,8 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
return net->ipv6.mrt6;
}
+#define __ip6mr_get_table ip6mr_get_table
+
static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6,
struct mr_table **mrt)
{
@@ -316,10 +336,9 @@ static int __net_init ip6mr_rules_init(struct net *net)
static void __net_exit ip6mr_rules_exit(struct net *net)
{
- rtnl_lock();
+ ASSERT_RTNL();
ip6mr_free_table(net->ipv6.mrt6);
net->ipv6.mrt6 = NULL;
- rtnl_unlock();
}
static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb,
@@ -328,7 +347,7 @@ static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb,
return 0;
}
-static unsigned int ip6mr_rules_seq_read(struct net *net)
+static unsigned int ip6mr_rules_seq_read(const struct net *net)
{
return 0;
}
@@ -375,7 +394,7 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id)
{
struct mr_table *mrt;
- mrt = ip6mr_get_table(net, id);
+ mrt = __ip6mr_get_table(net, id);
if (mrt)
return mrt;
@@ -385,7 +404,11 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id)
static void ip6mr_free_table(struct mr_table *mrt)
{
- del_timer_sync(&mrt->ipmr_expire_timer);
+ struct net *net = read_pnet(&mrt->net);
+
+ WARN_ON_ONCE(!mr_can_free_table(net));
+
+ timer_shutdown_sync(&mrt->ipmr_expire_timer);
mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC |
MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC);
rhltable_destroy(&mrt->mfc_hash);
@@ -398,26 +421,28 @@ static void ip6mr_free_table(struct mr_table *mrt)
*/
static void *ip6mr_vif_seq_start(struct seq_file *seq, loff_t *pos)
- __acquires(mrt_lock)
+ __acquires(RCU)
{
struct mr_vif_iter *iter = seq->private;
struct net *net = seq_file_net(seq);
struct mr_table *mrt;
- mrt = ip6mr_get_table(net, RT6_TABLE_DFLT);
- if (!mrt)
+ rcu_read_lock();
+ mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT);
+ if (!mrt) {
+ rcu_read_unlock();
return ERR_PTR(-ENOENT);
+ }
iter->mrt = mrt;
- read_lock(&mrt_lock);
return mr_vif_seq_start(seq, pos);
}
static void ip6mr_vif_seq_stop(struct seq_file *seq, void *v)
- __releases(mrt_lock)
+ __releases(RCU)
{
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
}
static int ip6mr_vif_seq_show(struct seq_file *seq, void *v)
@@ -430,7 +455,11 @@ static int ip6mr_vif_seq_show(struct seq_file *seq, void *v)
"Interface BytesIn PktsIn BytesOut PktsOut Flags\n");
} else {
const struct vif_device *vif = v;
- const char *name = vif->dev ? vif->dev->name : "none";
+ const struct net_device *vif_dev;
+ const char *name;
+
+ vif_dev = vif_dev_read(vif);
+ name = vif_dev ? vif_dev->name : "none";
seq_printf(seq,
"%2td %-10s %8ld %7ld %8ld %7ld %05X\n",
@@ -481,9 +510,9 @@ static int ipmr_mfc_seq_show(struct seq_file *seq, void *v)
if (it->cache != &mrt->mfc_unres_queue) {
seq_printf(seq, " %8lu %8lu %8lu",
- mfc->_c.mfc_un.res.pkt,
- mfc->_c.mfc_un.res.bytes,
- mfc->_c.mfc_un.res.wrong_if);
+ atomic_long_read(&mfc->_c.mfc_un.res.pkt),
+ atomic_long_read(&mfc->_c.mfc_un.res.bytes),
+ atomic_long_read(&mfc->_c.mfc_un.res.wrong_if));
for (n = mfc->_c.mfc_un.res.minvif;
n < mfc->_c.mfc_un.res.maxvif; n++) {
if (VIF_EXISTS(mrt, n) &&
@@ -549,13 +578,11 @@ static int pim6_rcv(struct sk_buff *skb)
if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
goto drop;
- reg_vif_num = mrt->mroute_reg_vif_num;
- read_lock(&mrt_lock);
+ /* Pairs with WRITE_ONCE() in mif6_add()/mif6_delete() */
+ reg_vif_num = READ_ONCE(mrt->mroute_reg_vif_num);
if (reg_vif_num >= 0)
- reg_dev = mrt->vif_table[reg_vif_num].dev;
- dev_hold(reg_dev);
- read_unlock(&mrt_lock);
+ reg_dev = vif_dev_read(&mrt->vif_table[reg_vif_num]);
if (!reg_dev)
goto drop;
@@ -570,7 +597,6 @@ static int pim6_rcv(struct sk_buff *skb)
netif_rx(skb);
- dev_put(reg_dev);
return 0;
drop:
kfree_skb(skb);
@@ -600,16 +626,17 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb,
if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
goto tx_err;
- read_lock(&mrt_lock);
- dev->stats.tx_bytes += skb->len;
- dev->stats.tx_packets++;
- ip6mr_cache_report(mrt, skb, mrt->mroute_reg_vif_num, MRT6MSG_WHOLEPKT);
- read_unlock(&mrt_lock);
+ DEV_STATS_ADD(dev, tx_bytes, skb->len);
+ DEV_STATS_INC(dev, tx_packets);
+ rcu_read_lock();
+ ip6mr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num),
+ MRT6MSG_WHOLEPKT);
+ rcu_read_unlock();
kfree_skb(skb);
return NETDEV_TX_OK;
tx_err:
- dev->stats.tx_errors++;
+ DEV_STATS_INC(dev, tx_errors);
kfree_skb(skb);
return NETDEV_TX_OK;
}
@@ -631,7 +658,7 @@ static void reg_vif_setup(struct net_device *dev)
dev->flags = IFF_NOARP;
dev->netdev_ops = &reg_vif_netdev_ops;
dev->needs_free_netdev = true;
- dev->features |= NETIF_F_NETNS_LOCAL;
+ dev->netns_immutable = true;
}
static struct net_device *ip6mr_reg_vif(struct net *net, struct mr_table *mrt)
@@ -670,10 +697,11 @@ failure:
static int call_ip6mr_vif_entry_notifiers(struct net *net,
enum fib_event_type event_type,
struct vif_device *vif,
+ struct net_device *vif_dev,
mifi_t vif_index, u32 tb_id)
{
return mr_call_vif_notifiers(net, RTNL_FAMILY_IP6MR, event_type,
- vif, vif_index, tb_id,
+ vif, vif_dev, vif_index, tb_id,
&net->ipv6.ipmr_seq);
}
@@ -698,23 +726,21 @@ static int mif6_delete(struct mr_table *mrt, int vifi, int notify,
v = &mrt->vif_table[vifi];
- if (VIF_EXISTS(mrt, vifi))
- call_ip6mr_vif_entry_notifiers(read_pnet(&mrt->net),
- FIB_EVENT_VIF_DEL, v, vifi,
- mrt->id);
-
- write_lock_bh(&mrt_lock);
- dev = v->dev;
- v->dev = NULL;
-
- if (!dev) {
- write_unlock_bh(&mrt_lock);
+ dev = rtnl_dereference(v->dev);
+ if (!dev)
return -EADDRNOTAVAIL;
- }
+
+ call_ip6mr_vif_entry_notifiers(read_pnet(&mrt->net),
+ FIB_EVENT_VIF_DEL, v, dev,
+ vifi, mrt->id);
+ spin_lock(&mrt_lock);
+ RCU_INIT_POINTER(v->dev, NULL);
#ifdef CONFIG_IPV6_PIMSM_V2
- if (vifi == mrt->mroute_reg_vif_num)
- mrt->mroute_reg_vif_num = -1;
+ if (vifi == mrt->mroute_reg_vif_num) {
+ /* Pairs with READ_ONCE() in ip6mr_cache_report() and reg_vif_xmit() */
+ WRITE_ONCE(mrt->mroute_reg_vif_num, -1);
+ }
#endif
if (vifi + 1 == mrt->maxvif) {
@@ -723,16 +749,16 @@ static int mif6_delete(struct mr_table *mrt, int vifi, int notify,
if (VIF_EXISTS(mrt, tmp))
break;
}
- mrt->maxvif = tmp + 1;
+ WRITE_ONCE(mrt->maxvif, tmp + 1);
}
- write_unlock_bh(&mrt_lock);
+ spin_unlock(&mrt_lock);
dev_set_allmulti(dev, -1);
in6_dev = __in6_dev_get(dev);
if (in6_dev) {
- in6_dev->cnf.mc_forwarding--;
+ atomic_dec(&in6_dev->cnf.mc_forwarding);
inet6_netconf_notify_devconf(dev_net(dev), RTM_NEWNETCONF,
NETCONFA_MC_FORWARDING,
dev->ifindex, &in6_dev->cnf);
@@ -741,7 +767,7 @@ static int mif6_delete(struct mr_table *mrt, int vifi, int notify,
if ((v->flags & MIFF_REGISTER) && !notify)
unregister_netdevice_queue(dev, head);
- dev_put_track(dev, &v->dev_tracker);
+ netdev_put(dev, &v->dev_tracker);
return 0;
}
@@ -813,7 +839,7 @@ static void ipmr_do_expire_process(struct mr_table *mrt)
static void ipmr_expire_process(struct timer_list *t)
{
- struct mr_table *mrt = from_timer(mrt, t, ipmr_expire_timer);
+ struct mr_table *mrt = timer_container_of(mrt, t, ipmr_expire_timer);
if (!spin_trylock(&mfc_unres_lock)) {
mod_timer(&mrt->ipmr_expire_timer, jiffies + 1);
@@ -826,7 +852,7 @@ static void ipmr_expire_process(struct timer_list *t)
spin_unlock(&mfc_unres_lock);
}
-/* Fill oifs list. It is called under write locked mrt_lock. */
+/* Fill oifs list. It is called under locked mrt_lock. */
static void ip6mr_update_thresholds(struct mr_table *mrt,
struct mr_mfc *cache,
@@ -848,7 +874,7 @@ static void ip6mr_update_thresholds(struct mr_table *mrt,
cache->mfc_un.res.maxvif = vifi + 1;
}
}
- cache->mfc_un.res.lastuse = jiffies;
+ WRITE_ONCE(cache->mfc_un.res.lastuse, jiffies);
}
static int mif6_add(struct net *net, struct mr_table *mrt,
@@ -900,7 +926,7 @@ static int mif6_add(struct net *net, struct mr_table *mrt,
in6_dev = __in6_dev_get(dev);
if (in6_dev) {
- in6_dev->cnf.mc_forwarding++;
+ atomic_inc(&in6_dev->cnf.mc_forwarding);
inet6_netconf_notify_devconf(dev_net(dev), RTM_NEWNETCONF,
NETCONFA_MC_FORWARDING,
dev->ifindex, &in6_dev->cnf);
@@ -912,18 +938,18 @@ static int mif6_add(struct net *net, struct mr_table *mrt,
MIFF_REGISTER);
/* And finish update writing critical data */
- write_lock_bh(&mrt_lock);
- v->dev = dev;
+ spin_lock(&mrt_lock);
+ rcu_assign_pointer(v->dev, dev);
netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC);
#ifdef CONFIG_IPV6_PIMSM_V2
if (v->flags & MIFF_REGISTER)
- mrt->mroute_reg_vif_num = vifi;
+ WRITE_ONCE(mrt->mroute_reg_vif_num, vifi);
#endif
if (vifi + 1 > mrt->maxvif)
- mrt->maxvif = vifi + 1;
- write_unlock_bh(&mrt_lock);
+ WRITE_ONCE(mrt->maxvif, vifi + 1);
+ spin_unlock(&mrt_lock);
call_ip6mr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD,
- v, vifi, mrt->id);
+ v, dev, vifi, mrt->id);
return 0;
}
@@ -1020,18 +1046,21 @@ static void ip6mr_cache_resolve(struct net *net, struct mr_table *mrt,
((struct nlmsgerr *)nlmsg_data(nlh))->error = -EMSGSIZE;
}
rtnl_unicast(skb, net, NETLINK_CB(skb).portid);
- } else
+ } else {
+ rcu_read_lock();
ip6_mr_forward(net, mrt, skb->dev, skb, c);
+ rcu_read_unlock();
+ }
}
}
/*
* Bounce a cache query up to pim6sd and netlink.
*
- * Called under mrt_lock.
+ * Called under rcu_read_lock()
*/
-static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
+static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt,
mifi_t mifi, int assert)
{
struct sock *mroute6_sk;
@@ -1040,7 +1069,7 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
int ret;
#ifdef CONFIG_IPV6_PIMSM_V2
- if (assert == MRT6MSG_WHOLEPKT)
+ if (assert == MRT6MSG_WHOLEPKT || assert == MRT6MSG_WRMIFWHOLE)
skb = skb_realloc_headroom(pkt, -skb_network_offset(pkt)
+sizeof(*msg));
else
@@ -1056,20 +1085,23 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
skb->ip_summed = CHECKSUM_UNNECESSARY;
#ifdef CONFIG_IPV6_PIMSM_V2
- if (assert == MRT6MSG_WHOLEPKT) {
+ if (assert == MRT6MSG_WHOLEPKT || assert == MRT6MSG_WRMIFWHOLE) {
/* Ugly, but we have no choice with this interface.
Duplicate old header, fix length etc.
And all this only to mangle msg->im6_msgtype and
to set msg->im6_mbz to "mbz" :-)
*/
- skb_push(skb, -skb_network_offset(pkt));
+ __skb_pull(skb, skb_network_offset(pkt));
skb_push(skb, sizeof(*msg));
skb_reset_transport_header(skb);
msg = (struct mrt6msg *)skb_transport_header(skb);
msg->im6_mbz = 0;
- msg->im6_msgtype = MRT6MSG_WHOLEPKT;
- msg->im6_mif = mrt->mroute_reg_vif_num;
+ msg->im6_msgtype = assert;
+ if (assert == MRT6MSG_WRMIFWHOLE)
+ msg->im6_mif = mifi;
+ else
+ msg->im6_mif = READ_ONCE(mrt->mroute_reg_vif_num);
msg->im6_pad = 0;
msg->im6_src = ipv6_hdr(pkt)->saddr;
msg->im6_dst = ipv6_hdr(pkt)->daddr;
@@ -1104,10 +1136,8 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
skb->ip_summed = CHECKSUM_UNNECESSARY;
}
- rcu_read_lock();
mroute6_sk = rcu_dereference(mrt->mroute_sk);
if (!mroute6_sk) {
- rcu_read_unlock();
kfree_skb(skb);
return -EINVAL;
}
@@ -1116,7 +1146,7 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
/* Deliver to user space multicast routing algorithms */
ret = sock_queue_rcv_skb(mroute6_sk, skb);
- rcu_read_unlock();
+
if (ret < 0) {
net_warn_ratelimited("mroute6: pending queue full, dropping entries\n");
kfree_skb(skb);
@@ -1240,7 +1270,7 @@ static int ip6mr_device_event(struct notifier_block *this,
ip6mr_for_each_table(mrt, net) {
v = &mrt->vif_table[0];
for (ct = 0; ct < mrt->maxvif; ct++, v++) {
- if (v->dev == dev)
+ if (rcu_access_pointer(v->dev) == dev)
mif6_delete(mrt, ct, 1, NULL);
}
}
@@ -1248,18 +1278,16 @@ static int ip6mr_device_event(struct notifier_block *this,
return NOTIFY_DONE;
}
-static unsigned int ip6mr_seq_read(struct net *net)
+static unsigned int ip6mr_seq_read(const struct net *net)
{
- ASSERT_RTNL();
-
- return net->ipv6.ipmr_seq + ip6mr_rules_seq_read(net);
+ return READ_ONCE(net->ipv6.ipmr_seq) + ip6mr_rules_seq_read(net);
}
static int ip6mr_dump(struct net *net, struct notifier_block *nb,
struct netlink_ext_ack *extack)
{
return mr_dump(net, nb, RTNL_FAMILY_IP6MR, ip6mr_rules_dump,
- ip6mr_mr_table_iter, &mrt_lock, extack);
+ ip6mr_mr_table_iter, extack);
}
static struct notifier_block ip6_mr_notifier = {
@@ -1323,7 +1351,9 @@ static int __net_init ip6mr_net_init(struct net *net)
proc_cache_fail:
remove_proc_entry("ip6_mr_vif", net->proc_net);
proc_vif_fail:
+ rtnl_lock();
ip6mr_rules_exit(net);
+ rtnl_unlock();
#endif
ip6mr_rules_fail:
ip6mr_notifier_exit(net);
@@ -1336,23 +1366,36 @@ static void __net_exit ip6mr_net_exit(struct net *net)
remove_proc_entry("ip6_mr_cache", net->proc_net);
remove_proc_entry("ip6_mr_vif", net->proc_net);
#endif
- ip6mr_rules_exit(net);
ip6mr_notifier_exit(net);
}
+static void __net_exit ip6mr_net_exit_batch(struct list_head *net_list)
+{
+ struct net *net;
+
+ rtnl_lock();
+ list_for_each_entry(net, net_list, exit_list)
+ ip6mr_rules_exit(net);
+ rtnl_unlock();
+}
+
static struct pernet_operations ip6mr_net_ops = {
.init = ip6mr_net_init,
.exit = ip6mr_net_exit,
+ .exit_batch = ip6mr_net_exit_batch,
+};
+
+static const struct rtnl_msg_handler ip6mr_rtnl_msg_handlers[] __initconst_or_module = {
+ {.owner = THIS_MODULE, .protocol = RTNL_FAMILY_IP6MR,
+ .msgtype = RTM_GETROUTE,
+ .doit = ip6mr_rtm_getroute, .dumpit = ip6mr_rtm_dumproute},
};
int __init ip6_mr_init(void)
{
int err;
- mrt_cachep = kmem_cache_create("ip6_mrt_cache",
- sizeof(struct mfc6_cache),
- 0, SLAB_HWCACHE_ALIGN,
- NULL);
+ mrt_cachep = KMEM_CACHE(mfc6_cache, SLAB_HWCACHE_ALIGN);
if (!mrt_cachep)
return -ENOMEM;
@@ -1370,9 +1413,8 @@ int __init ip6_mr_init(void)
goto add_proto_fail;
}
#endif
- err = rtnl_register_module(THIS_MODULE, RTNL_FAMILY_IP6MR, RTM_GETROUTE,
- NULL, ip6mr_rtm_dumproute, 0);
- if (err == 0)
+ err = rtnl_register_many(ip6mr_rtnl_msg_handlers);
+ if (!err)
return 0;
#ifdef CONFIG_IPV6_PIMSM_V2
@@ -1387,9 +1429,9 @@ reg_pernet_fail:
return err;
}
-void ip6_mr_cleanup(void)
+void __init ip6_mr_cleanup(void)
{
- rtnl_unregister(RTNL_FAMILY_IP6MR, RTM_GETROUTE);
+ rtnl_unregister_many(ip6mr_rtnl_msg_handlers);
#ifdef CONFIG_IPV6_PIMSM_V2
inet6_del_protocol(&pim6_protocol, IPPROTO_PIM);
#endif
@@ -1422,12 +1464,12 @@ static int ip6mr_mfc_add(struct net *net, struct mr_table *mrt,
&mfc->mf6cc_mcastgrp.sin6_addr, parent);
rcu_read_unlock();
if (c) {
- write_lock_bh(&mrt_lock);
+ spin_lock(&mrt_lock);
c->_c.mfc_parent = mfc->mf6cc_parent;
ip6mr_update_thresholds(mrt, &c->_c, ttls);
if (!mrtsock)
c->_c.mfc_flags |= MFC_STATIC;
- write_unlock_bh(&mrt_lock);
+ spin_unlock(&mrt_lock);
call_ip6mr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_REPLACE,
c, mrt->id);
mr6_netlink_event(mrt, c, RTM_NEWROUTE);
@@ -1474,7 +1516,7 @@ static int ip6mr_mfc_add(struct net *net, struct mr_table *mrt,
}
}
if (list_empty(&mrt->mfc_unres_queue))
- del_timer(&mrt->ipmr_expire_timer);
+ timer_delete(&mrt->ipmr_expire_timer);
spin_unlock_bh(&mfc_unres_lock);
if (found) {
@@ -1545,15 +1587,15 @@ static int ip6mr_sk_init(struct mr_table *mrt, struct sock *sk)
struct net *net = sock_net(sk);
rtnl_lock();
- write_lock_bh(&mrt_lock);
+ spin_lock(&mrt_lock);
if (rtnl_dereference(mrt->mroute_sk)) {
err = -EADDRINUSE;
} else {
rcu_assign_pointer(mrt->mroute_sk, sk);
sock_set_flag(sk, SOCK_RCU_FREE);
- net->ipv6.devconf_all->mc_forwarding++;
+ atomic_inc(&net->ipv6.devconf_all->mc_forwarding);
}
- write_unlock_bh(&mrt_lock);
+ spin_unlock(&mrt_lock);
if (!err)
inet6_netconf_notify_devconf(net, RTM_NEWNETCONF,
@@ -1567,25 +1609,30 @@ static int ip6mr_sk_init(struct mr_table *mrt, struct sock *sk)
int ip6mr_sk_done(struct sock *sk)
{
- int err = -EACCES;
struct net *net = sock_net(sk);
+ struct ipv6_devconf *devconf;
struct mr_table *mrt;
+ int err = -EACCES;
if (sk->sk_type != SOCK_RAW ||
inet_sk(sk)->inet_num != IPPROTO_ICMPV6)
return err;
+ devconf = net->ipv6.devconf_all;
+ if (!devconf || !atomic_read(&devconf->mc_forwarding))
+ return err;
+
rtnl_lock();
ip6mr_for_each_table(mrt, net) {
if (sk == rtnl_dereference(mrt->mroute_sk)) {
- write_lock_bh(&mrt_lock);
+ spin_lock(&mrt_lock);
RCU_INIT_POINTER(mrt->mroute_sk, NULL);
/* Note that mroute_sk had SOCK_RCU_FREE set,
* so the RCU grace period before sk freeing
* is guaranteed by sk_destruct()
*/
- net->ipv6.devconf_all->mc_forwarding--;
- write_unlock_bh(&mrt_lock);
+ atomic_dec(&devconf->mc_forwarding);
+ spin_unlock(&mrt_lock);
inet6_netconf_notify_devconf(net, RTM_NEWNETCONF,
NETCONFA_MC_FORWARDING,
NETCONFA_IFINDEX_ALL,
@@ -1740,18 +1787,22 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval,
#ifdef CONFIG_IPV6_PIMSM_V2
case MRT6_PIM:
{
+ bool do_wrmifwhole;
int v;
if (optlen != sizeof(v))
return -EINVAL;
if (copy_from_sockptr(&v, optval, sizeof(v)))
return -EFAULT;
+
+ do_wrmifwhole = (v == MRT6MSG_WRMIFWHOLE);
v = !!v;
rtnl_lock();
ret = 0;
if (v != mrt->mroute_do_pim) {
mrt->mroute_do_pim = v;
mrt->mroute_do_assert = v;
+ mrt->mroute_do_wrvifwhole = do_wrmifwhole;
}
rtnl_unlock();
return ret;
@@ -1797,8 +1848,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval,
* Getsock opt support for the multicast routing system.
*/
-int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval,
- int __user *optlen)
+int ip6_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval,
+ sockptr_t optlen)
{
int olr;
int val;
@@ -1829,16 +1880,16 @@ int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval,
return -ENOPROTOOPT;
}
- if (get_user(olr, optlen))
+ if (copy_from_sockptr(&olr, optlen, sizeof(int)))
return -EFAULT;
olr = min_t(int, olr, sizeof(int));
if (olr < 0)
return -EINVAL;
- if (put_user(olr, optlen))
+ if (copy_to_sockptr(optlen, &olr, sizeof(int)))
return -EFAULT;
- if (copy_to_user(optval, &val, olr))
+ if (copy_to_sockptr(optval, &val, olr))
return -EFAULT;
return 0;
}
@@ -1846,11 +1897,10 @@ int ip6_mroute_getsockopt(struct sock *sk, int optname, char __user *optval,
/*
* The IP multicast ioctl support routines.
*/
-
-int ip6mr_ioctl(struct sock *sk, int cmd, void __user *arg)
+int ip6mr_ioctl(struct sock *sk, int cmd, void *arg)
{
- struct sioc_sg_req6 sr;
- struct sioc_mif_req6 vr;
+ struct sioc_sg_req6 *sr;
+ struct sioc_mif_req6 *vr;
struct vif_device *vif;
struct mfc6_cache *c;
struct net *net = sock_net(sk);
@@ -1862,40 +1912,33 @@ int ip6mr_ioctl(struct sock *sk, int cmd, void __user *arg)
switch (cmd) {
case SIOCGETMIFCNT_IN6:
- if (copy_from_user(&vr, arg, sizeof(vr)))
- return -EFAULT;
- if (vr.mifi >= mrt->maxvif)
+ vr = (struct sioc_mif_req6 *)arg;
+ if (vr->mifi >= mrt->maxvif)
return -EINVAL;
- vr.mifi = array_index_nospec(vr.mifi, mrt->maxvif);
- read_lock(&mrt_lock);
- vif = &mrt->vif_table[vr.mifi];
- if (VIF_EXISTS(mrt, vr.mifi)) {
- vr.icount = vif->pkt_in;
- vr.ocount = vif->pkt_out;
- vr.ibytes = vif->bytes_in;
- vr.obytes = vif->bytes_out;
- read_unlock(&mrt_lock);
-
- if (copy_to_user(arg, &vr, sizeof(vr)))
- return -EFAULT;
+ vr->mifi = array_index_nospec(vr->mifi, mrt->maxvif);
+ rcu_read_lock();
+ vif = &mrt->vif_table[vr->mifi];
+ if (VIF_EXISTS(mrt, vr->mifi)) {
+ vr->icount = READ_ONCE(vif->pkt_in);
+ vr->ocount = READ_ONCE(vif->pkt_out);
+ vr->ibytes = READ_ONCE(vif->bytes_in);
+ vr->obytes = READ_ONCE(vif->bytes_out);
+ rcu_read_unlock();
return 0;
}
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
return -EADDRNOTAVAIL;
case SIOCGETSGCNT_IN6:
- if (copy_from_user(&sr, arg, sizeof(sr)))
- return -EFAULT;
+ sr = (struct sioc_sg_req6 *)arg;
rcu_read_lock();
- c = ip6mr_cache_find(mrt, &sr.src.sin6_addr, &sr.grp.sin6_addr);
+ c = ip6mr_cache_find(mrt, &sr->src.sin6_addr,
+ &sr->grp.sin6_addr);
if (c) {
- sr.pktcnt = c->_c.mfc_un.res.pkt;
- sr.bytecnt = c->_c.mfc_un.res.bytes;
- sr.wrong_if = c->_c.mfc_un.res.wrong_if;
+ sr->pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt);
+ sr->bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes);
+ sr->wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if);
rcu_read_unlock();
-
- if (copy_to_user(arg, &sr, sizeof(sr)))
- return -EFAULT;
return 0;
}
rcu_read_unlock();
@@ -1942,20 +1985,20 @@ int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg)
if (vr.mifi >= mrt->maxvif)
return -EINVAL;
vr.mifi = array_index_nospec(vr.mifi, mrt->maxvif);
- read_lock(&mrt_lock);
+ rcu_read_lock();
vif = &mrt->vif_table[vr.mifi];
if (VIF_EXISTS(mrt, vr.mifi)) {
- vr.icount = vif->pkt_in;
- vr.ocount = vif->pkt_out;
- vr.ibytes = vif->bytes_in;
- vr.obytes = vif->bytes_out;
- read_unlock(&mrt_lock);
+ vr.icount = READ_ONCE(vif->pkt_in);
+ vr.ocount = READ_ONCE(vif->pkt_out);
+ vr.ibytes = READ_ONCE(vif->bytes_in);
+ vr.obytes = READ_ONCE(vif->bytes_out);
+ rcu_read_unlock();
if (copy_to_user(arg, &vr, sizeof(vr)))
return -EFAULT;
return 0;
}
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
return -EADDRNOTAVAIL;
case SIOCGETSGCNT_IN6:
if (copy_from_user(&sr, arg, sizeof(sr)))
@@ -1964,9 +2007,9 @@ int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg)
rcu_read_lock();
c = ip6mr_cache_find(mrt, &sr.src.sin6_addr, &sr.grp.sin6_addr);
if (c) {
- sr.pktcnt = c->_c.mfc_un.res.pkt;
- sr.bytecnt = c->_c.mfc_un.res.bytes;
- sr.wrong_if = c->_c.mfc_un.res.wrong_if;
+ sr.pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt);
+ sr.bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes);
+ sr.wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if);
rcu_read_unlock();
if (copy_to_user(arg, &sr, sizeof(sr)))
@@ -1985,8 +2028,6 @@ static inline int ip6mr_forward2_finish(struct net *net, struct sock *sk, struct
{
IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)),
IPSTATS_MIB_OUTFORWDATAGRAMS);
- IP6_ADD_STATS(net, ip6_dst_idev(skb_dst(skb)),
- IPSTATS_MIB_OUTOCTETS, skb->len);
return dst_output(net, sk, skb);
}
@@ -1994,26 +2035,27 @@ static inline int ip6mr_forward2_finish(struct net *net, struct sock *sk, struct
* Processing handlers for ip6mr_forward
*/
-static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
- struct sk_buff *skb, int vifi)
+static int ip6mr_prepare_xmit(struct net *net, struct mr_table *mrt,
+ struct sk_buff *skb, int vifi)
{
- struct ipv6hdr *ipv6h;
struct vif_device *vif = &mrt->vif_table[vifi];
- struct net_device *dev;
+ struct net_device *vif_dev;
+ struct ipv6hdr *ipv6h;
struct dst_entry *dst;
struct flowi6 fl6;
- if (!vif->dev)
- goto out_free;
+ vif_dev = vif_dev_read(vif);
+ if (!vif_dev)
+ return -1;
#ifdef CONFIG_IPV6_PIMSM_V2
if (vif->flags & MIFF_REGISTER) {
- vif->pkt_out++;
- vif->bytes_out += skb->len;
- vif->dev->stats.tx_bytes += skb->len;
- vif->dev->stats.tx_packets++;
+ WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1);
+ WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len);
+ DEV_STATS_ADD(vif_dev, tx_bytes, skb->len);
+ DEV_STATS_INC(vif_dev, tx_packets);
ip6mr_cache_report(mrt, skb, vifi, MRT6MSG_WHOLEPKT);
- goto out_free;
+ return -1;
}
#endif
@@ -2027,7 +2069,7 @@ static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
dst = ip6_route_output(net, NULL, &fl6);
if (dst->error) {
dst_release(dst);
- goto out_free;
+ return -1;
}
skb_dst_drop(skb);
@@ -2044,41 +2086,66 @@ static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
* not mrouter) cannot join to more than one interface - it will
* result in receiving multiple packets.
*/
- dev = vif->dev;
- skb->dev = dev;
- vif->pkt_out++;
- vif->bytes_out += skb->len;
+ skb->dev = vif_dev;
+ WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1);
+ WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len);
/* We are about to write */
/* XXX: extension headers? */
- if (skb_cow(skb, sizeof(*ipv6h) + LL_RESERVED_SPACE(dev)))
- goto out_free;
+ if (skb_cow(skb, sizeof(*ipv6h) + LL_RESERVED_SPACE(vif_dev)))
+ return -1;
ipv6h = ipv6_hdr(skb);
ipv6h->hop_limit--;
+ return 0;
+}
+
+static void ip6mr_forward2(struct net *net, struct mr_table *mrt,
+ struct sk_buff *skb, int vifi)
+{
+ struct net_device *indev = skb->dev;
+
+ if (ip6mr_prepare_xmit(net, mrt, skb, vifi))
+ goto out_free;
IP6CB(skb)->flags |= IP6SKB_FORWARDED;
- return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD,
- net, NULL, skb, skb->dev, dev,
- ip6mr_forward2_finish);
+ NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD,
+ net, NULL, skb, indev, skb->dev,
+ ip6mr_forward2_finish);
+ return;
out_free:
kfree_skb(skb);
- return 0;
}
+static void ip6mr_output2(struct net *net, struct mr_table *mrt,
+ struct sk_buff *skb, int vifi)
+{
+ if (ip6mr_prepare_xmit(net, mrt, skb, vifi))
+ goto out_free;
+
+ ip6_output(net, NULL, skb);
+ return;
+
+out_free:
+ kfree_skb(skb);
+}
+
+/* Called with rcu_read_lock() */
static int ip6mr_find_vif(struct mr_table *mrt, struct net_device *dev)
{
int ct;
- for (ct = mrt->maxvif - 1; ct >= 0; ct--) {
- if (mrt->vif_table[ct].dev == dev)
+ /* Pairs with WRITE_ONCE() in mif6_delete()/mif6_add() */
+ for (ct = READ_ONCE(mrt->maxvif) - 1; ct >= 0; ct--) {
+ if (rcu_access_pointer(mrt->vif_table[ct].dev) == dev)
break;
}
return ct;
}
+/* Called under rcu_read_lock() */
static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
struct net_device *dev, struct sk_buff *skb,
struct mfc6_cache *c)
@@ -2088,9 +2155,9 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
int true_vifi = ip6mr_find_vif(mrt, dev);
vif = c->_c.mfc_parent;
- c->_c.mfc_un.res.pkt++;
- c->_c.mfc_un.res.bytes += skb->len;
- c->_c.mfc_un.res.lastuse = jiffies;
+ atomic_long_inc(&c->_c.mfc_un.res.pkt);
+ atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes);
+ WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies);
if (ipv6_addr_any(&c->mf6c_origin) && true_vifi >= 0) {
struct mfc6_cache *cache_proxy;
@@ -2098,21 +2165,17 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
/* For an (*,G) entry, we only check that the incoming
* interface is part of the static tree.
*/
- rcu_read_lock();
cache_proxy = mr_mfc_find_any_parent(mrt, vif);
if (cache_proxy &&
- cache_proxy->_c.mfc_un.res.ttls[true_vifi] < 255) {
- rcu_read_unlock();
+ cache_proxy->_c.mfc_un.res.ttls[true_vifi] < 255)
goto forward;
- }
- rcu_read_unlock();
}
/*
* Wrong interface: drop packet and (maybe) send PIM assert.
*/
- if (mrt->vif_table[vif].dev != dev) {
- c->_c.mfc_un.res.wrong_if++;
+ if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) {
+ atomic_long_inc(&c->_c.mfc_un.res.wrong_if);
if (true_vifi >= 0 && mrt->mroute_do_assert &&
/* pimsm uses asserts, when switching from RPT to SPT,
@@ -2127,13 +2190,18 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
MFC_ASSERT_THRESH)) {
c->_c.mfc_un.res.last_assert = jiffies;
ip6mr_cache_report(mrt, skb, true_vifi, MRT6MSG_WRONGMIF);
+ if (mrt->mroute_do_wrvifwhole)
+ ip6mr_cache_report(mrt, skb, true_vifi,
+ MRT6MSG_WRMIFWHOLE);
}
goto dont_forward;
}
forward:
- mrt->vif_table[vif].pkt_in++;
- mrt->vif_table[vif].bytes_in += skb->len;
+ WRITE_ONCE(mrt->vif_table[vif].pkt_in,
+ mrt->vif_table[vif].pkt_in + 1);
+ WRITE_ONCE(mrt->vif_table[vif].bytes_in,
+ mrt->vif_table[vif].bytes_in + skb->len);
/*
* Forward the frame
@@ -2176,6 +2244,56 @@ dont_forward:
kfree_skb(skb);
}
+/* Called under rcu_read_lock() */
+static void ip6_mr_output_finish(struct net *net, struct mr_table *mrt,
+ struct net_device *dev, struct sk_buff *skb,
+ struct mfc6_cache *c)
+{
+ int psend = -1;
+ int ct;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ atomic_long_inc(&c->_c.mfc_un.res.pkt);
+ atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes);
+ WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies);
+
+ /* Forward the frame */
+ if (ipv6_addr_any(&c->mf6c_origin) &&
+ ipv6_addr_any(&c->mf6c_mcastgrp)) {
+ if (ipv6_hdr(skb)->hop_limit >
+ c->_c.mfc_un.res.ttls[c->_c.mfc_parent]) {
+ /* It's an (*,*) entry and the packet is not coming from
+ * the upstream: forward the packet to the upstream
+ * only.
+ */
+ psend = c->_c.mfc_parent;
+ goto last_forward;
+ }
+ goto dont_forward;
+ }
+ for (ct = c->_c.mfc_un.res.maxvif - 1;
+ ct >= c->_c.mfc_un.res.minvif; ct--) {
+ if (ipv6_hdr(skb)->hop_limit > c->_c.mfc_un.res.ttls[ct]) {
+ if (psend != -1) {
+ struct sk_buff *skb2;
+
+ skb2 = skb_clone(skb, GFP_ATOMIC);
+ if (skb2)
+ ip6mr_output2(net, mrt, skb2, psend);
+ }
+ psend = ct;
+ }
+ }
+last_forward:
+ if (psend != -1) {
+ ip6mr_output2(net, mrt, skb, psend);
+ return;
+ }
+
+dont_forward:
+ kfree_skb(skb);
+}
/*
* Multicast packets for forwarding arrive here
@@ -2183,21 +2301,20 @@ dont_forward:
int ip6_mr_input(struct sk_buff *skb)
{
+ struct net_device *dev = skb->dev;
+ struct net *net = dev_net_rcu(dev);
struct mfc6_cache *cache;
- struct net *net = dev_net(skb->dev);
struct mr_table *mrt;
struct flowi6 fl6 = {
- .flowi6_iif = skb->dev->ifindex,
+ .flowi6_iif = dev->ifindex,
.flowi6_mark = skb->mark,
};
int err;
- struct net_device *dev;
/* skb->dev passed in is the master dev for vrfs.
* Get the proper interface that does have a vif associated with it.
*/
- dev = skb->dev;
- if (netif_is_l3_master(skb->dev)) {
+ if (netif_is_l3_master(dev)) {
dev = dev_get_by_index_rcu(net, IPCB(skb)->iif);
if (!dev) {
kfree_skb(skb);
@@ -2211,7 +2328,6 @@ int ip6_mr_input(struct sk_buff *skb)
return err;
}
- read_lock(&mrt_lock);
cache = ip6mr_cache_find(mrt,
&ipv6_hdr(skb)->saddr, &ipv6_hdr(skb)->daddr);
if (!cache) {
@@ -2232,20 +2348,71 @@ int ip6_mr_input(struct sk_buff *skb)
vif = ip6mr_find_vif(mrt, dev);
if (vif >= 0) {
int err = ip6mr_cache_unresolved(mrt, vif, skb, dev);
- read_unlock(&mrt_lock);
return err;
}
- read_unlock(&mrt_lock);
kfree_skb(skb);
return -ENODEV;
}
ip6_mr_forward(net, mrt, dev, skb, cache);
- read_unlock(&mrt_lock);
+ return 0;
+}
+
+int ip6_mr_output(struct net *net, struct sock *sk, struct sk_buff *skb)
+{
+ struct net_device *dev = skb_dst(skb)->dev;
+ struct flowi6 fl6 = (struct flowi6) {
+ .flowi6_iif = LOOPBACK_IFINDEX,
+ .flowi6_mark = skb->mark,
+ };
+ struct mfc6_cache *cache;
+ struct mr_table *mrt;
+ int err;
+ int vif;
+
+ guard(rcu)();
+
+ if (IP6CB(skb)->flags & IP6SKB_FORWARDED)
+ goto ip6_output;
+ if (!(IP6CB(skb)->flags & IP6SKB_MCROUTE))
+ goto ip6_output;
+
+ err = ip6mr_fib_lookup(net, &fl6, &mrt);
+ if (err < 0) {
+ kfree_skb(skb);
+ return err;
+ }
+
+ cache = ip6mr_cache_find(mrt,
+ &ipv6_hdr(skb)->saddr, &ipv6_hdr(skb)->daddr);
+ if (!cache) {
+ vif = ip6mr_find_vif(mrt, dev);
+ if (vif >= 0)
+ cache = ip6mr_cache_find_any(mrt,
+ &ipv6_hdr(skb)->daddr,
+ vif);
+ }
+
+ /* No usable cache entry */
+ if (!cache) {
+ vif = ip6mr_find_vif(mrt, dev);
+ if (vif >= 0)
+ return ip6mr_cache_unresolved(mrt, vif, skb, dev);
+ goto ip6_output;
+ }
+
+ /* Wrong interface */
+ vif = cache->_c.mfc_parent;
+ if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev)
+ goto ip6_output;
+ ip6_mr_output_finish(net, mrt, dev, skb, cache);
return 0;
+
+ip6_output:
+ return ip6_output(net, sk, skb);
}
int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm,
@@ -2254,13 +2421,15 @@ int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm,
int err;
struct mr_table *mrt;
struct mfc6_cache *cache;
- struct rt6_info *rt = (struct rt6_info *)skb_dst(skb);
+ struct rt6_info *rt = dst_rt6_info(skb_dst(skb));
- mrt = ip6mr_get_table(net, RT6_TABLE_DFLT);
- if (!mrt)
+ rcu_read_lock();
+ mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT);
+ if (!mrt) {
+ rcu_read_unlock();
return -ENOENT;
+ }
- read_lock(&mrt_lock);
cache = ip6mr_cache_find(mrt, &rt->rt6i_src.addr, &rt->rt6i_dst.addr);
if (!cache && skb->dev) {
int vif = ip6mr_find_vif(mrt, skb->dev);
@@ -2278,14 +2447,14 @@ int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm,
dev = skb->dev;
if (!dev || (vif = ip6mr_find_vif(mrt, dev)) < 0) {
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
return -ENODEV;
}
/* really correct? */
skb2 = alloc_skb(sizeof(struct ipv6hdr), GFP_ATOMIC);
if (!skb2) {
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
return -ENOMEM;
}
@@ -2308,13 +2477,13 @@ int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm,
iph->daddr = rt->rt6i_dst.addr;
err = ip6mr_cache_unresolved(mrt, vif, skb2, dev);
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
return err;
}
err = mr_fill_mroute(mrt, skb, &cache->_c, rtm);
- read_unlock(&mrt_lock);
+ rcu_read_unlock();
return err;
}
@@ -2412,8 +2581,7 @@ static void mr6_netlink_event(struct mr_table *mrt, struct mfc6_cache *mfc,
errout:
kfree_skb(skb);
- if (err < 0)
- rtnl_set_sk_err(net, RTNLGRP_IPV6_MROUTE, err);
+ rtnl_set_sk_err(net, RTNLGRP_IPV6_MROUTE, err);
}
static size_t mrt6msg_netlink_msgsize(size_t payloadlen)
@@ -2433,7 +2601,7 @@ static size_t mrt6msg_netlink_msgsize(size_t payloadlen)
return len;
}
-static void mrt6msg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt)
+static void mrt6msg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt)
{
struct net *net = read_pnet(&mrt->net);
struct nlmsghdr *nlh;
@@ -2481,10 +2649,101 @@ errout:
rtnl_set_sk_err(net, RTNLGRP_IPV6_MROUTE_R, -ENOBUFS);
}
+static const struct nla_policy ip6mr_getroute_policy[RTA_MAX + 1] = {
+ [RTA_SRC] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
+ [RTA_DST] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
+ [RTA_TABLE] = { .type = NLA_U32 },
+};
+
+static int ip6mr_rtm_valid_getroute_req(struct sk_buff *skb,
+ const struct nlmsghdr *nlh,
+ struct nlattr **tb,
+ struct netlink_ext_ack *extack)
+{
+ struct rtmsg *rtm;
+ int err;
+
+ err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, ip6mr_getroute_policy,
+ extack);
+ if (err)
+ return err;
+
+ rtm = nlmsg_data(nlh);
+ if ((rtm->rtm_src_len && rtm->rtm_src_len != 128) ||
+ (rtm->rtm_dst_len && rtm->rtm_dst_len != 128) ||
+ rtm->rtm_tos || rtm->rtm_table || rtm->rtm_protocol ||
+ rtm->rtm_scope || rtm->rtm_type || rtm->rtm_flags) {
+ NL_SET_ERR_MSG_MOD(extack,
+ "Invalid values in header for multicast route get request");
+ return -EINVAL;
+ }
+
+ if ((tb[RTA_SRC] && !rtm->rtm_src_len) ||
+ (tb[RTA_DST] && !rtm->rtm_dst_len)) {
+ NL_SET_ERR_MSG_MOD(extack, "rtm_src_len and rtm_dst_len must be 128 for IPv6");
+ return -EINVAL;
+ }
+
+ return 0;
+}
+
+static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh,
+ struct netlink_ext_ack *extack)
+{
+ struct net *net = sock_net(in_skb->sk);
+ struct in6_addr src = {}, grp = {};
+ struct nlattr *tb[RTA_MAX + 1];
+ struct mfc6_cache *cache;
+ struct mr_table *mrt;
+ struct sk_buff *skb;
+ u32 tableid;
+ int err;
+
+ err = ip6mr_rtm_valid_getroute_req(in_skb, nlh, tb, extack);
+ if (err < 0)
+ return err;
+
+ if (tb[RTA_SRC])
+ src = nla_get_in6_addr(tb[RTA_SRC]);
+ if (tb[RTA_DST])
+ grp = nla_get_in6_addr(tb[RTA_DST]);
+ tableid = nla_get_u32_default(tb[RTA_TABLE], 0);
+
+ mrt = __ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT);
+ if (!mrt) {
+ NL_SET_ERR_MSG_MOD(extack, "MR table does not exist");
+ return -ENOENT;
+ }
+
+ /* entries are added/deleted only under RTNL */
+ rcu_read_lock();
+ cache = ip6mr_cache_find(mrt, &src, &grp);
+ rcu_read_unlock();
+ if (!cache) {
+ NL_SET_ERR_MSG_MOD(extack, "MR cache entry not found");
+ return -ENOENT;
+ }
+
+ skb = nlmsg_new(mr6_msgsize(false, mrt->maxvif), GFP_KERNEL);
+ if (!skb)
+ return -ENOBUFS;
+
+ err = ip6mr_fill_mroute(mrt, skb, NETLINK_CB(in_skb).portid,
+ nlh->nlmsg_seq, cache, RTM_NEWROUTE, 0);
+ if (err < 0) {
+ kfree_skb(skb);
+ return err;
+ }
+
+ return rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
+}
+
static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
{
const struct nlmsghdr *nlh = cb->nlh;
- struct fib_dump_filter filter = {};
+ struct fib_dump_filter filter = {
+ .rtnl_held = true,
+ };
int err;
if (cb->strict_check) {
@@ -2497,7 +2756,7 @@ static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
if (filter.table_id) {
struct mr_table *mrt;
- mrt = ip6mr_get_table(sock_net(skb->sk), filter.table_id);
+ mrt = __ip6mr_get_table(sock_net(skb->sk), filter.table_id);
if (!mrt) {
if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IP6MR)
return skb->len;