Diffstat (limited to 'net/xfrm/xfrm_state.c')
-rw-r--r--  net/xfrm/xfrm_state.c | 447
 1 file changed, 412 insertions(+), 35 deletions(-)
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index bda5327bf34d..ad2202fa82f3 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -34,6 +34,8 @@
#define xfrm_state_deref_prot(table, net) \
rcu_dereference_protected((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock))
+#define xfrm_state_deref_check(table, net) \
+ rcu_dereference_check((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock))
static void xfrm_state_gc_task(struct work_struct *work);
@@ -49,6 +51,7 @@ static struct kmem_cache *xfrm_state_cache __ro_after_init;
static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
static HLIST_HEAD(xfrm_state_gc_list);
+static HLIST_HEAD(xfrm_state_dev_gc_list);
static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x)
{
@@ -61,6 +64,8 @@ static inline unsigned int xfrm_dst_hash(struct net *net,
u32 reqid,
unsigned short family)
{
+ lockdep_assert_held(&net->xfrm.xfrm_state_lock);
+
return __xfrm_dst_hash(daddr, saddr, reqid, family, net->xfrm.state_hmask);
}
@@ -69,6 +74,8 @@ static inline unsigned int xfrm_src_hash(struct net *net,
const xfrm_address_t *saddr,
unsigned short family)
{
+ lockdep_assert_held(&net->xfrm.xfrm_state_lock);
+
return __xfrm_src_hash(daddr, saddr, family, net->xfrm.state_hmask);
}
@@ -76,11 +83,15 @@ static inline unsigned int
xfrm_spi_hash(struct net *net, const xfrm_address_t *daddr,
__be32 spi, u8 proto, unsigned short family)
{
+ lockdep_assert_held(&net->xfrm.xfrm_state_lock);
+
return __xfrm_spi_hash(daddr, spi, proto, family, net->xfrm.state_hmask);
}
static unsigned int xfrm_seq_hash(struct net *net, u32 seq)
{
+ lockdep_assert_held(&net->xfrm.xfrm_state_lock);
+
return __xfrm_seq_hash(seq, net->xfrm.state_hmask);
}
@@ -214,6 +225,7 @@ static DEFINE_SPINLOCK(xfrm_state_afinfo_lock);
static struct xfrm_state_afinfo __rcu *xfrm_state_afinfo[NPROTO];
static DEFINE_SPINLOCK(xfrm_state_gc_lock);
+static DEFINE_SPINLOCK(xfrm_state_dev_gc_lock);
int __xfrm_state_delete(struct xfrm_state *x);
@@ -465,6 +477,11 @@ static const struct xfrm_mode xfrm4_mode_map[XFRM_MODE_MAX] = {
.flags = XFRM_MODE_FLAG_TUNNEL,
.family = AF_INET,
},
+ [XFRM_MODE_IPTFS] = {
+ .encap = XFRM_MODE_IPTFS,
+ .flags = XFRM_MODE_FLAG_TUNNEL,
+ .family = AF_INET,
+ },
};
static const struct xfrm_mode xfrm6_mode_map[XFRM_MODE_MAX] = {
@@ -486,6 +503,11 @@ static const struct xfrm_mode xfrm6_mode_map[XFRM_MODE_MAX] = {
.flags = XFRM_MODE_FLAG_TUNNEL,
.family = AF_INET6,
},
+ [XFRM_MODE_IPTFS] = {
+ .encap = XFRM_MODE_IPTFS,
+ .flags = XFRM_MODE_FLAG_TUNNEL,
+ .family = AF_INET6,
+ },
};
static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family)
@@ -513,6 +535,60 @@ static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family)
return NULL;
}
+static const struct xfrm_mode_cbs __rcu *xfrm_mode_cbs_map[XFRM_MODE_MAX];
+static DEFINE_SPINLOCK(xfrm_mode_cbs_map_lock);
+
+int xfrm_register_mode_cbs(u8 mode, const struct xfrm_mode_cbs *mode_cbs)
+{
+ if (mode >= XFRM_MODE_MAX)
+ return -EINVAL;
+
+ spin_lock_bh(&xfrm_mode_cbs_map_lock);
+ rcu_assign_pointer(xfrm_mode_cbs_map[mode], mode_cbs);
+ spin_unlock_bh(&xfrm_mode_cbs_map_lock);
+
+ return 0;
+}
+EXPORT_SYMBOL(xfrm_register_mode_cbs);
+
+void xfrm_unregister_mode_cbs(u8 mode)
+{
+ if (mode >= XFRM_MODE_MAX)
+ return;
+
+ spin_lock_bh(&xfrm_mode_cbs_map_lock);
+ RCU_INIT_POINTER(xfrm_mode_cbs_map[mode], NULL);
+ spin_unlock_bh(&xfrm_mode_cbs_map_lock);
+ synchronize_rcu();
+}
+EXPORT_SYMBOL(xfrm_unregister_mode_cbs);
+
+static const struct xfrm_mode_cbs *xfrm_get_mode_cbs(u8 mode)
+{
+ const struct xfrm_mode_cbs *cbs;
+ bool try_load = true;
+
+ if (mode >= XFRM_MODE_MAX)
+ return NULL;
+
+retry:
+ rcu_read_lock();
+
+ cbs = rcu_dereference(xfrm_mode_cbs_map[mode]);
+ if (cbs && !try_module_get(cbs->owner))
+ cbs = NULL;
+
+ rcu_read_unlock();
+
+ if (mode == XFRM_MODE_IPTFS && !cbs && try_load) {
+ request_module("xfrm-iptfs");
+ try_load = false;
+ goto retry;
+ }
+
+ return cbs;
+}
+
void xfrm_state_free(struct xfrm_state *x)
{
kmem_cache_free(xfrm_state_cache, x);
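For context, the registration API added above is what an encapsulation module is expected to call at load/unload time: xfrm_register_mode_cbs() publishes the callback table through an RCU-protected slot, xfrm_unregister_mode_cbs() clears it and waits for readers, and xfrm_get_mode_cbs() retries once via request_module("xfrm-iptfs") if the IPTFS slot is still empty. The sketch below shows how a mode module might wire this up; it is illustrative only: the xfrm_iptfs_* function names and the iptfs_mode_cbs object are assumptions, not part of this patch.

/* Illustrative module init/exit pairing for the API above; iptfs_mode_cbs
 * is assumed to be defined elsewhere in the module with .owner = THIS_MODULE
 * so that xfrm_get_mode_cbs() can pin the module via try_module_get().
 */
static int __init xfrm_iptfs_init(void)
{
	return xfrm_register_mode_cbs(XFRM_MODE_IPTFS, &iptfs_mode_cbs);
}

static void __exit xfrm_iptfs_exit(void)
{
	/* Clears the RCU-protected slot and calls synchronize_rcu() before returning. */
	xfrm_unregister_mode_cbs(XFRM_MODE_IPTFS);
}

module_init(xfrm_iptfs_init);
module_exit(xfrm_iptfs_exit);
MODULE_ALIAS("xfrm-iptfs");	/* matches the request_module() call above */
MODULE_LICENSE("GPL");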
@@ -521,6 +597,8 @@ EXPORT_SYMBOL(xfrm_state_free);
static void ___xfrm_state_destroy(struct xfrm_state *x)
{
+ if (x->mode_cbs && x->mode_cbs->destroy_state)
+ x->mode_cbs->destroy_state(x);
hrtimer_cancel(&x->mtimer);
del_timer_sync(&x->rtimer);
kfree(x->aead);
@@ -570,7 +648,7 @@ static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
int err = 0;
spin_lock(&x->lock);
- xfrm_dev_state_update_curlft(x);
+ xfrm_dev_state_update_stats(x);
if (x->km.state == XFRM_STATE_DEAD)
goto out;
@@ -663,6 +741,7 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
refcount_set(&x->refcnt, 1);
atomic_set(&x->tunnel_users, 0);
INIT_LIST_HEAD(&x->km.all);
+ INIT_HLIST_NODE(&x->state_cache);
INIT_HLIST_NODE(&x->bydst);
INIT_HLIST_NODE(&x->bysrc);
INIT_HLIST_NODE(&x->byspi);
@@ -677,12 +756,49 @@ struct xfrm_state *xfrm_state_alloc(struct net *net)
x->lft.hard_packet_limit = XFRM_INF;
x->replay_maxage = 0;
x->replay_maxdiff = 0;
+ x->pcpu_num = UINT_MAX;
spin_lock_init(&x->lock);
+ x->mode_data = NULL;
}
return x;
}
EXPORT_SYMBOL(xfrm_state_alloc);
+#ifdef CONFIG_XFRM_OFFLOAD
+void xfrm_dev_state_delete(struct xfrm_state *x)
+{
+ struct xfrm_dev_offload *xso = &x->xso;
+ struct net_device *dev = READ_ONCE(xso->dev);
+
+ if (dev) {
+ dev->xfrmdev_ops->xdo_dev_state_delete(x);
+ spin_lock_bh(&xfrm_state_dev_gc_lock);
+ hlist_add_head(&x->dev_gclist, &xfrm_state_dev_gc_list);
+ spin_unlock_bh(&xfrm_state_dev_gc_lock);
+ }
+}
+EXPORT_SYMBOL_GPL(xfrm_dev_state_delete);
+
+void xfrm_dev_state_free(struct xfrm_state *x)
+{
+ struct xfrm_dev_offload *xso = &x->xso;
+ struct net_device *dev = READ_ONCE(xso->dev);
+
+ if (dev && dev->xfrmdev_ops) {
+ spin_lock_bh(&xfrm_state_dev_gc_lock);
+ if (!hlist_unhashed(&x->dev_gclist))
+ hlist_del(&x->dev_gclist);
+ spin_unlock_bh(&xfrm_state_dev_gc_lock);
+
+ if (dev->xfrmdev_ops->xdo_dev_state_free)
+ dev->xfrmdev_ops->xdo_dev_state_free(x);
+ WRITE_ONCE(xso->dev, NULL);
+ xso->type = XFRM_DEV_OFFLOAD_UNSPECIFIED;
+ netdev_put(dev, &xso->dev_tracker);
+ }
+}
+#endif
+
void __xfrm_state_destroy(struct xfrm_state *x, bool sync)
{
WARN_ON(x->km.state != XFRM_STATE_DEAD);
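The two offload helpers above split hardware teardown into stages: xfrm_dev_state_delete() asks the driver to stop using the SA and parks the state on xfrm_state_dev_gc_list, while xfrm_dev_state_free() (reached later, e.g. from the flush path further down) unhashes it from that list, lets the driver release its resources, clears xso->dev and drops the tracked netdev reference. From a driver's point of view this is the callback pair sketched below; the foo_* names and bodies are placeholders, only the xdo_* slots and their call sites are taken from the code.

/* Placeholder driver callbacks illustrating the two-step teardown; the
 * signatures mirror the call sites in xfrm_dev_state_delete()/_free()
 * and the xdo_dev_state_add(x, NULL) call in the acquire path below.
 */
static int foo_xdo_dev_state_add(struct xfrm_state *x,
				 struct netlink_ext_ack *extack)
{
	return 0;	/* program the SA into hardware */
}

static void foo_xdo_dev_state_delete(struct xfrm_state *x)
{
	/* Stop using the SA in hardware; the core keeps the state on
	 * xfrm_state_dev_gc_list until xfrm_dev_state_free() runs. */
}

static void foo_xdo_dev_state_free(struct xfrm_state *x)
{
	/* Release per-SA driver resources; the core then clears xso->dev
	 * and drops the dev_tracker reference. */
}

static const struct xfrmdev_ops foo_xfrmdev_ops = {
	.xdo_dev_state_add	= foo_xdo_dev_state_add,
	.xdo_dev_state_delete	= foo_xdo_dev_state_delete,
	.xdo_dev_state_free	= foo_xdo_dev_state_free,
};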
@@ -706,15 +822,22 @@ int __xfrm_state_delete(struct xfrm_state *x)
if (x->km.state != XFRM_STATE_DEAD) {
x->km.state = XFRM_STATE_DEAD;
+
spin_lock(&net->xfrm.xfrm_state_lock);
list_del(&x->km.all);
hlist_del_rcu(&x->bydst);
hlist_del_rcu(&x->bysrc);
if (x->km.seq)
hlist_del_rcu(&x->byseq);
+ if (!hlist_unhashed(&x->state_cache))
+ hlist_del_rcu(&x->state_cache);
+ if (!hlist_unhashed(&x->state_cache_input))
+ hlist_del_rcu(&x->state_cache_input);
+
if (x->id.spi)
hlist_del_rcu(&x->byspi);
net->xfrm.state_num--;
+ xfrm_nat_keepalive_state_updated(x);
spin_unlock(&net->xfrm.xfrm_state_lock);
if (x->encap_sk)
@@ -848,6 +971,9 @@ EXPORT_SYMBOL(xfrm_state_flush);
int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid)
{
+ struct xfrm_state *x;
+ struct hlist_node *tmp;
+ struct xfrm_dev_offload *xso;
int i, err = 0, cnt = 0;
spin_lock_bh(&net->xfrm.xfrm_state_lock);
@@ -857,8 +983,6 @@ int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_vali
err = -ESRCH;
for (i = 0; i <= net->xfrm.state_hmask; i++) {
- struct xfrm_state *x;
- struct xfrm_dev_offload *xso;
restart:
hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
xso = &x->xso;
@@ -868,6 +992,8 @@ restart:
spin_unlock_bh(&net->xfrm.xfrm_state_lock);
err = xfrm_state_delete(x);
+ xfrm_dev_state_free(x);
+
xfrm_audit_state_delete(x, err ? 0 : 1,
task_valid);
xfrm_state_put(x);
@@ -884,6 +1010,24 @@ restart:
out:
spin_unlock_bh(&net->xfrm.xfrm_state_lock);
+
+ spin_lock_bh(&xfrm_state_dev_gc_lock);
+restart_gc:
+ hlist_for_each_entry_safe(x, tmp, &xfrm_state_dev_gc_list, dev_gclist) {
+ xso = &x->xso;
+
+ if (xso->dev == dev) {
+ spin_unlock_bh(&xfrm_state_dev_gc_lock);
+ xfrm_dev_state_free(x);
+ spin_lock_bh(&xfrm_state_dev_gc_lock);
+ goto restart_gc;
+ }
+
+ }
+ spin_unlock_bh(&xfrm_state_dev_gc_lock);
+
+ xfrm_flush_gc();
+
return err;
}
EXPORT_SYMBOL(xfrm_dev_state_flush);
@@ -974,16 +1118,38 @@ xfrm_init_tempstate(struct xfrm_state *x, const struct flowi *fl,
x->props.family = tmpl->encap_family;
}
-static struct xfrm_state *__xfrm_state_lookup_all(struct net *net, u32 mark,
+struct xfrm_hash_state_ptrs {
+ const struct hlist_head *bydst;
+ const struct hlist_head *bysrc;
+ const struct hlist_head *byspi;
+ unsigned int hmask;
+};
+
+static void xfrm_hash_ptrs_get(const struct net *net, struct xfrm_hash_state_ptrs *ptrs)
+{
+ unsigned int sequence;
+
+ do {
+ sequence = read_seqcount_begin(&net->xfrm.xfrm_state_hash_generation);
+
+ ptrs->bydst = xfrm_state_deref_check(net->xfrm.state_bydst, net);
+ ptrs->bysrc = xfrm_state_deref_check(net->xfrm.state_bysrc, net);
+ ptrs->byspi = xfrm_state_deref_check(net->xfrm.state_byspi, net);
+ ptrs->hmask = net->xfrm.state_hmask;
+ } while (read_seqcount_retry(&net->xfrm.xfrm_state_hash_generation, sequence));
+}
+
+static struct xfrm_state *__xfrm_state_lookup_all(const struct xfrm_hash_state_ptrs *state_ptrs,
+ u32 mark,
const xfrm_address_t *daddr,
__be32 spi, u8 proto,
unsigned short family,
struct xfrm_dev_offload *xdo)
{
- unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
+ unsigned int h = __xfrm_spi_hash(daddr, spi, proto, family, state_ptrs->hmask);
struct xfrm_state *x;
- hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) {
+ hlist_for_each_entry_rcu(x, state_ptrs->byspi + h, byspi) {
#ifdef CONFIG_XFRM_OFFLOAD
if (xdo->type == XFRM_DEV_OFFLOAD_PACKET) {
if (x->xso.type != XFRM_DEV_OFFLOAD_PACKET)
@@ -1017,15 +1183,16 @@ static struct xfrm_state *__xfrm_state_lookup_all(struct net *net, u32 mark,
return NULL;
}
-static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
+static struct xfrm_state *__xfrm_state_lookup(const struct xfrm_hash_state_ptrs *state_ptrs,
+ u32 mark,
const xfrm_address_t *daddr,
__be32 spi, u8 proto,
unsigned short family)
{
- unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
+ unsigned int h = __xfrm_spi_hash(daddr, spi, proto, family, state_ptrs->hmask);
struct xfrm_state *x;
- hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) {
+ hlist_for_each_entry_rcu(x, state_ptrs->byspi + h, byspi) {
if (x->props.family != family ||
x->id.spi != spi ||
x->id.proto != proto ||
@@ -1042,15 +1209,63 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
return NULL;
}
-static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
+struct xfrm_state *xfrm_input_state_lookup(struct net *net, u32 mark,
+ const xfrm_address_t *daddr,
+ __be32 spi, u8 proto,
+ unsigned short family)
+{
+ struct xfrm_hash_state_ptrs state_ptrs;
+ struct hlist_head *state_cache_input;
+ struct xfrm_state *x = NULL;
+
+ state_cache_input = raw_cpu_ptr(net->xfrm.state_cache_input);
+
+ rcu_read_lock();
+ hlist_for_each_entry_rcu(x, state_cache_input, state_cache_input) {
+ if (x->props.family != family ||
+ x->id.spi != spi ||
+ x->id.proto != proto ||
+ !xfrm_addr_equal(&x->id.daddr, daddr, family))
+ continue;
+
+ if ((mark & x->mark.m) != x->mark.v)
+ continue;
+ if (!xfrm_state_hold_rcu(x))
+ continue;
+ goto out;
+ }
+
+ xfrm_hash_ptrs_get(net, &state_ptrs);
+
+ x = __xfrm_state_lookup(&state_ptrs, mark, daddr, spi, proto, family);
+
+ if (x && x->km.state == XFRM_STATE_VALID) {
+ spin_lock_bh(&net->xfrm.xfrm_state_lock);
+ if (hlist_unhashed(&x->state_cache_input)) {
+ hlist_add_head_rcu(&x->state_cache_input, state_cache_input);
+ } else {
+ hlist_del_rcu(&x->state_cache_input);
+ hlist_add_head_rcu(&x->state_cache_input, state_cache_input);
+ }
+ spin_unlock_bh(&net->xfrm.xfrm_state_lock);
+ }
+
+out:
+ rcu_read_unlock();
+ return x;
+}
+EXPORT_SYMBOL(xfrm_input_state_lookup);
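xfrm_input_state_lookup() first walks the per-CPU state_cache_input list and only falls back to the SPI hash via __xfrm_state_lookup(), re-caching a VALID result for the next packet handled on this CPU. A receive-path caller would use it roughly as sketched below; the surrounding demo function and the way mark/family are obtained are assumptions, only the lookup call and the xfrm_state_put() release follow from the code above.

/* Hedged usage sketch; demo_rx_lookup() and its parameters are invented.
 * The returned state is refcounted and must be dropped with xfrm_state_put().
 */
static int demo_rx_lookup(struct net *net, struct sk_buff *skb,
			  const xfrm_address_t *daddr, __be32 spi,
			  u8 proto, unsigned short family)
{
	struct xfrm_state *x;

	x = xfrm_input_state_lookup(net, skb->mark, daddr, spi, proto, family);
	if (!x)
		return -ENOENT;

	/* ... process the packet against x ... */

	xfrm_state_put(x);
	return 0;
}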
+
+static struct xfrm_state *__xfrm_state_lookup_byaddr(const struct xfrm_hash_state_ptrs *state_ptrs,
+ u32 mark,
const xfrm_address_t *daddr,
const xfrm_address_t *saddr,
u8 proto, unsigned short family)
{
- unsigned int h = xfrm_src_hash(net, daddr, saddr, family);
+ unsigned int h = __xfrm_src_hash(daddr, saddr, family, state_ptrs->hmask);
struct xfrm_state *x;
- hlist_for_each_entry_rcu(x, net->xfrm.state_bysrc + h, bysrc) {
+ hlist_for_each_entry_rcu(x, state_ptrs->bysrc + h, bysrc) {
if (x->props.family != family ||
x->id.proto != proto ||
!xfrm_addr_equal(&x->id.daddr, daddr, family) ||
@@ -1070,14 +1285,17 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
static inline struct xfrm_state *
__xfrm_state_locate(struct xfrm_state *x, int use_spi, int family)
{
+ struct xfrm_hash_state_ptrs state_ptrs;
struct net *net = xs_net(x);
u32 mark = x->mark.v & x->mark.m;
+ xfrm_hash_ptrs_get(net, &state_ptrs);
+
if (use_spi)
- return __xfrm_state_lookup(net, mark, &x->id.daddr,
+ return __xfrm_state_lookup(&state_ptrs, mark, &x->id.daddr,
x->id.spi, x->id.proto, family);
else
- return __xfrm_state_lookup_byaddr(net, mark,
+ return __xfrm_state_lookup_byaddr(&state_ptrs, mark,
&x->id.daddr,
&x->props.saddr,
x->id.proto, family);
@@ -1096,6 +1314,12 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
struct xfrm_state **best, int *acq_in_progress,
int *error)
{
+ /* We need the cpu id just as a lookup key,
+ * we don't require it to be stable.
+ */
+ unsigned int pcpu_id = get_cpu();
+ put_cpu();
+
/* Resolution logic:
* 1. There is a valid state with matching selector. Done.
* 2. Valid state with inappropriate selector. Skip.
@@ -1115,13 +1339,18 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
&fl->u.__fl_common))
return;
+ if (x->pcpu_num != UINT_MAX && x->pcpu_num != pcpu_id)
+ return;
+
if (!*best ||
+ ((*best)->pcpu_num == UINT_MAX && x->pcpu_num == pcpu_id) ||
(*best)->km.dying > x->km.dying ||
((*best)->km.dying == x->km.dying &&
(*best)->curlft.add_time < x->curlft.add_time))
*best = x;
} else if (x->km.state == XFRM_STATE_ACQ) {
- *acq_in_progress = 1;
+ if (!*best || x->pcpu_num == pcpu_id)
+ *acq_in_progress = 1;
} else if (x->km.state == XFRM_STATE_ERROR ||
x->km.state == XFRM_STATE_EXPIRED) {
if ((!x->sel.family ||
@@ -1140,6 +1369,7 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
unsigned short family, u32 if_id)
{
static xfrm_address_t saddr_wildcard = { };
+ struct xfrm_hash_state_ptrs state_ptrs;
struct net *net = xp_net(pol);
unsigned int h, h_wildcard;
struct xfrm_state *x, *x0, *to_put;
@@ -1150,14 +1380,64 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
unsigned short encap_family = tmpl->encap_family;
unsigned int sequence;
struct km_event c;
+ unsigned int pcpu_id;
+ bool cached = false;
+
+ /* We need the cpu id just as a lookup key,
+ * we don't require it to be stable.
+ */
+ pcpu_id = get_cpu();
+ put_cpu();
to_put = NULL;
sequence = read_seqcount_begin(&net->xfrm.xfrm_state_hash_generation);
rcu_read_lock();
- h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family);
- hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) {
+ hlist_for_each_entry_rcu(x, &pol->state_cache_list, state_cache) {
+ if (x->props.family == encap_family &&
+ x->props.reqid == tmpl->reqid &&
+ (mark & x->mark.m) == x->mark.v &&
+ x->if_id == if_id &&
+ !(x->props.flags & XFRM_STATE_WILDRECV) &&
+ xfrm_state_addr_check(x, daddr, saddr, encap_family) &&
+ tmpl->mode == x->props.mode &&
+ tmpl->id.proto == x->id.proto &&
+ (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
+ xfrm_state_look_at(pol, x, fl, encap_family,
+ &best, &acquire_in_progress, &error);
+ }
+
+ if (best)
+ goto cached;
+
+ hlist_for_each_entry_rcu(x, &pol->state_cache_list, state_cache) {
+ if (x->props.family == encap_family &&
+ x->props.reqid == tmpl->reqid &&
+ (mark & x->mark.m) == x->mark.v &&
+ x->if_id == if_id &&
+ !(x->props.flags & XFRM_STATE_WILDRECV) &&
+ xfrm_addr_equal(&x->id.daddr, daddr, encap_family) &&
+ tmpl->mode == x->props.mode &&
+ tmpl->id.proto == x->id.proto &&
+ (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
+ xfrm_state_look_at(pol, x, fl, family,
+ &best, &acquire_in_progress, &error);
+ }
+
+cached:
+ cached = true;
+ if (best)
+ goto found;
+ else if (error)
+ best = NULL;
+ else if (acquire_in_progress) /* XXX: acquire_in_progress should not happen */
+ WARN_ON(1);
+
+ xfrm_hash_ptrs_get(net, &state_ptrs);
+
+ h = __xfrm_dst_hash(daddr, saddr, tmpl->reqid, encap_family, state_ptrs.hmask);
+ hlist_for_each_entry_rcu(x, state_ptrs.bydst + h, bydst) {
#ifdef CONFIG_XFRM_OFFLOAD
if (pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) {
if (x->xso.type != XFRM_DEV_OFFLOAD_PACKET)
@@ -1190,8 +1470,9 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
if (best || acquire_in_progress)
goto found;
- h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family);
- hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h_wildcard, bydst) {
+ h_wildcard = __xfrm_dst_hash(daddr, &saddr_wildcard, tmpl->reqid,
+ encap_family, state_ptrs.hmask);
+ hlist_for_each_entry_rcu(x, state_ptrs.bydst + h_wildcard, bydst) {
#ifdef CONFIG_XFRM_OFFLOAD
if (pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) {
if (x->xso.type != XFRM_DEV_OFFLOAD_PACKET)
@@ -1223,10 +1504,13 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
}
found:
- x = best;
+ if (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) ||
+ (best && (best->pcpu_num == pcpu_id)))
+ x = best;
+
if (!x && !error && !acquire_in_progress) {
if (tmpl->id.spi &&
- (x0 = __xfrm_state_lookup_all(net, mark, daddr,
+ (x0 = __xfrm_state_lookup_all(&state_ptrs, mark, daddr,
tmpl->id.spi, tmpl->id.proto,
encap_family,
&pol->xdo)) != NULL) {
@@ -1255,6 +1539,8 @@ found:
xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family);
memcpy(&x->mark, &pol->mark, sizeof(x->mark));
x->if_id = if_id;
+ if ((pol->flags & XFRM_POLICY_CPU_ACQUIRE) && best)
+ x->pcpu_num = pcpu_id;
error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid);
if (error) {
@@ -1273,8 +1559,7 @@ found:
xso->dev = xdo->dev;
xso->real_dev = xdo->real_dev;
xso->flags = XFRM_DEV_OFFLOAD_FLAG_ACQ;
- netdev_tracker_alloc(xso->dev, &xso->dev_tracker,
- GFP_ATOMIC);
+ netdev_hold(xso->dev, &xso->dev_tracker, GFP_ATOMIC);
error = xso->dev->xfrmdev_ops->xdo_dev_state_add(x, NULL);
if (error) {
xso->dir = 0;
@@ -1292,7 +1577,9 @@ found:
if (km_query(x, tmpl, pol) == 0) {
spin_lock_bh(&net->xfrm.xfrm_state_lock);
x->km.state = XFRM_STATE_ACQ;
+ x->dir = XFRM_SA_DIR_OUT;
list_add(&x->km.all, &net->xfrm.state_all);
+ h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family);
XFRM_STATE_INSERT(bydst, &x->bydst,
net->xfrm.state_bydst + h,
x->xso.type);
@@ -1300,6 +1587,7 @@ found:
XFRM_STATE_INSERT(bysrc, &x->bysrc,
net->xfrm.state_bysrc + h,
x->xso.type);
+ INIT_HLIST_NODE(&x->state_cache);
if (x->id.spi) {
h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family);
XFRM_STATE_INSERT(byspi, &x->byspi,
@@ -1333,6 +1621,11 @@ found:
x = NULL;
error = -ESRCH;
}
+
+ /* Use the already installed 'fallback' while the CPU-specific
+ * SA acquire is handled. */
+ if (best)
+ x = best;
}
out:
if (x) {
@@ -1343,6 +1636,15 @@ out:
} else {
*err = acquire_in_progress ? -EAGAIN : error;
}
+
+ if (x && x->km.state == XFRM_STATE_VALID && !cached &&
+ (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) || x->pcpu_num == pcpu_id)) {
+ spin_lock_bh(&net->xfrm.xfrm_state_lock);
+ if (hlist_unhashed(&x->state_cache))
+ hlist_add_head_rcu(&x->state_cache, &pol->state_cache_list);
+ spin_unlock_bh(&net->xfrm.xfrm_state_lock);
+ }
+
rcu_read_unlock();
if (to_put)
xfrm_state_put(to_put);
@@ -1452,6 +1754,7 @@ static void __xfrm_state_insert(struct xfrm_state *x)
net->xfrm.state_num++;
xfrm_hash_grow_check(net, x->bydst.next != NULL);
+ xfrm_nat_keepalive_state_updated(x);
}
/* net->xfrm.xfrm_state_lock is held */
@@ -1464,12 +1767,14 @@ static void __xfrm_state_bump_genids(struct xfrm_state *xnew)
unsigned int h;
u32 mark = xnew->mark.v & xnew->mark.m;
u32 if_id = xnew->if_id;
+ u32 cpu_id = xnew->pcpu_num;
h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family);
hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
if (x->props.family == family &&
x->props.reqid == reqid &&
x->if_id == if_id &&
+ x->pcpu_num == cpu_id &&
(mark & x->mark.m) == x->mark.v &&
xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) &&
xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family))
@@ -1492,7 +1797,7 @@ EXPORT_SYMBOL(xfrm_state_insert);
static struct xfrm_state *__find_acq_core(struct net *net,
const struct xfrm_mark *m,
unsigned short family, u8 mode,
- u32 reqid, u32 if_id, u8 proto,
+ u32 reqid, u32 if_id, u32 pcpu_num, u8 proto,
const xfrm_address_t *daddr,
const xfrm_address_t *saddr,
int create)
@@ -1509,6 +1814,7 @@ static struct xfrm_state *__find_acq_core(struct net *net,
x->id.spi != 0 ||
x->id.proto != proto ||
(mark & x->mark.m) != x->mark.v ||
+ x->pcpu_num != pcpu_num ||
!xfrm_addr_equal(&x->id.daddr, daddr, family) ||
!xfrm_addr_equal(&x->props.saddr, saddr, family))
continue;
@@ -1542,6 +1848,7 @@ static struct xfrm_state *__find_acq_core(struct net *net,
break;
}
+ x->pcpu_num = pcpu_num;
x->km.state = XFRM_STATE_ACQ;
x->id.proto = proto;
x->props.family = family;
@@ -1570,7 +1877,7 @@ static struct xfrm_state *__find_acq_core(struct net *net,
return x;
}
-static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
+static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num);
int xfrm_state_add(struct xfrm_state *x)
{
@@ -1596,7 +1903,7 @@ int xfrm_state_add(struct xfrm_state *x)
}
if (use_spi && x->km.seq) {
- x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq);
+ x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq, x->pcpu_num);
if (x1 && ((x1->id.proto != x->id.proto) ||
!xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) {
to_put = x1;
@@ -1606,7 +1913,7 @@ int xfrm_state_add(struct xfrm_state *x)
if (use_spi && !x1)
x1 = __find_acq_core(net, &x->mark, family, x->props.mode,
- x->props.reqid, x->if_id, x->id.proto,
+ x->props.reqid, x->if_id, x->pcpu_num, x->id.proto,
&x->id.daddr, &x->props.saddr, 0);
__xfrm_state_bump_genids(x);
@@ -1731,6 +2038,7 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig,
x->props.flags = orig->props.flags;
x->props.extra_flags = orig->props.extra_flags;
+ x->pcpu_num = orig->pcpu_num;
x->if_id = orig->if_id;
x->tfcpad = orig->tfcpad;
x->replay_maxdiff = orig->replay_maxdiff;
@@ -1744,6 +2052,13 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig,
x->lastused = orig->lastused;
x->new_mapping = 0;
x->new_mapping_sport = 0;
+ x->dir = orig->dir;
+
+ x->mode_cbs = orig->mode_cbs;
+ if (x->mode_cbs && x->mode_cbs->clone_state) {
+ if (x->mode_cbs->clone_state(x, orig))
+ goto error;
+ }
return x;
@@ -1864,8 +2179,14 @@ int xfrm_state_update(struct xfrm_state *x)
}
if (x1->km.state == XFRM_STATE_ACQ) {
+ if (x->dir && x1->dir != x->dir)
+ goto out;
+
__xfrm_state_insert(x);
x = NULL;
+ } else {
+ if (x1->dir != x->dir)
+ goto out;
}
err = 0;
@@ -1935,7 +2256,7 @@ EXPORT_SYMBOL(xfrm_state_update);
int xfrm_state_check_expire(struct xfrm_state *x)
{
- xfrm_dev_state_update_curlft(x);
+ xfrm_dev_state_update_stats(x);
if (!READ_ONCE(x->curlft.use_time))
WRITE_ONCE(x->curlft.use_time, ktime_get_real_seconds());
@@ -1957,14 +2278,30 @@ int xfrm_state_check_expire(struct xfrm_state *x)
}
EXPORT_SYMBOL(xfrm_state_check_expire);
+void xfrm_state_update_stats(struct net *net)
+{
+ struct xfrm_state *x;
+ int i;
+
+ spin_lock_bh(&net->xfrm.xfrm_state_lock);
+ for (i = 0; i <= net->xfrm.state_hmask; i++) {
+ hlist_for_each_entry(x, net->xfrm.state_bydst + i, bydst)
+ xfrm_dev_state_update_stats(x);
+ }
+ spin_unlock_bh(&net->xfrm.xfrm_state_lock);
+}
+
struct xfrm_state *
xfrm_state_lookup(struct net *net, u32 mark, const xfrm_address_t *daddr, __be32 spi,
u8 proto, unsigned short family)
{
+ struct xfrm_hash_state_ptrs state_ptrs;
struct xfrm_state *x;
rcu_read_lock();
- x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family);
+ xfrm_hash_ptrs_get(net, &state_ptrs);
+
+ x = __xfrm_state_lookup(&state_ptrs, mark, daddr, spi, proto, family);
rcu_read_unlock();
return x;
}
@@ -1975,10 +2312,14 @@ xfrm_state_lookup_byaddr(struct net *net, u32 mark,
const xfrm_address_t *daddr, const xfrm_address_t *saddr,
u8 proto, unsigned short family)
{
+ struct xfrm_hash_state_ptrs state_ptrs;
struct xfrm_state *x;
spin_lock_bh(&net->xfrm.xfrm_state_lock);
- x = __xfrm_state_lookup_byaddr(net, mark, daddr, saddr, proto, family);
+
+ xfrm_hash_ptrs_get(net, &state_ptrs);
+
+ x = __xfrm_state_lookup_byaddr(&state_ptrs, mark, daddr, saddr, proto, family);
spin_unlock_bh(&net->xfrm.xfrm_state_lock);
return x;
}
@@ -1986,13 +2327,14 @@ EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
struct xfrm_state *
xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid,
- u32 if_id, u8 proto, const xfrm_address_t *daddr,
+ u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr,
const xfrm_address_t *saddr, int create, unsigned short family)
{
struct xfrm_state *x;
spin_lock_bh(&net->xfrm.xfrm_state_lock);
- x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create);
+ x = __find_acq_core(net, mark, family, mode, reqid, if_id, pcpu_num,
+ proto, daddr, saddr, create);
spin_unlock_bh(&net->xfrm.xfrm_state_lock);
return x;
@@ -2051,6 +2393,7 @@ static int __xfrm6_state_sort_cmp(const void *p)
#endif
case XFRM_MODE_TUNNEL:
case XFRM_MODE_BEET:
+ case XFRM_MODE_IPTFS:
return 4;
}
return 5;
@@ -2077,6 +2420,7 @@ static int __xfrm6_tmpl_sort_cmp(const void *p)
#endif
case XFRM_MODE_TUNNEL:
case XFRM_MODE_BEET:
+ case XFRM_MODE_IPTFS:
return 3;
}
return 4;
@@ -2127,7 +2471,7 @@ xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,
/* Silly enough, but I'm lazy to build resolution list */
-static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
+static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num)
{
unsigned int h = xfrm_seq_hash(net, seq);
struct xfrm_state *x;
@@ -2135,6 +2479,7 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s
hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) {
if (x->km.seq == seq &&
(mark & x->mark.m) == x->mark.v &&
+ x->pcpu_num == pcpu_num &&
x->km.state == XFRM_STATE_ACQ) {
xfrm_state_hold(x);
return x;
@@ -2144,12 +2489,12 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s
return NULL;
}
-struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
+struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num)
{
struct xfrm_state *x;
spin_lock_bh(&net->xfrm.xfrm_state_lock);
- x = __xfrm_find_acq_byseq(net, mark, seq);
+ x = __xfrm_find_acq_byseq(net, mark, seq, pcpu_num);
spin_unlock_bh(&net->xfrm.xfrm_state_lock);
return x;
}
@@ -2765,6 +3110,9 @@ u32 xfrm_state_mtu(struct xfrm_state *x, int mtu)
case XFRM_MODE_TUNNEL:
break;
default:
+ if (x->mode_cbs && x->mode_cbs->get_inner_mtu)
+ return x->mode_cbs->get_inner_mtu(x, mtu);
+
WARN_ON_ONCE(1);
break;
}
@@ -2850,6 +3198,27 @@ int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload,
goto error;
}
+ if (x->nat_keepalive_interval) {
+ if (x->dir != XFRM_SA_DIR_OUT) {
+ NL_SET_ERR_MSG(extack, "NAT keepalive is only supported for outbound SAs");
+ err = -EINVAL;
+ goto error;
+ }
+
+ if (!x->encap || x->encap->encap_type != UDP_ENCAP_ESPINUDP) {
+ NL_SET_ERR_MSG(extack,
+ "NAT keepalive is only supported for UDP encapsulation");
+ err = -EINVAL;
+ goto error;
+ }
+ }
+
+ x->mode_cbs = xfrm_get_mode_cbs(x->props.mode);
+ if (x->mode_cbs) {
+ if (x->mode_cbs->init_state)
+ err = x->mode_cbs->init_state(x);
+ module_put(x->mode_cbs->owner);
+ }
error:
return err;
}
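Here, at the end of __xfrm_init_state(), the per-mode callbacks are resolved and init_state() runs once, after which the module reference taken in xfrm_get_mode_cbs() is dropped again. Together with destroy_state() (called from ___xfrm_state_destroy()) and clone_state() (called from xfrm_state_clone()), this gives a mode implementation a place to manage a private per-SA context via x->mode_data. A minimal, hypothetical callback pair could look like this; struct iptfs_ctx and its contents are invented for the sketch.

/* Hypothetical per-SA context hung off x->mode_data (NULL after
 * xfrm_state_alloc()); struct iptfs_ctx is invented for illustration.
 */
struct iptfs_ctx {
	u32 pkt_size;			/* example per-SA tunable */
};

static int iptfs_init_state(struct xfrm_state *x)
{
	struct iptfs_ctx *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	x->mode_data = ctx;
	return 0;
}

static void iptfs_destroy_state(struct xfrm_state *x)
{
	/* Runs from ___xfrm_state_destroy() before the state is freed. */
	kfree(x->mode_data);
}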
@@ -2893,6 +3262,11 @@ int __net_init xfrm_state_init(struct net *net)
net->xfrm.state_byseq = xfrm_hash_alloc(sz);
if (!net->xfrm.state_byseq)
goto out_byseq;
+
+ net->xfrm.state_cache_input = alloc_percpu(struct hlist_head);
+ if (!net->xfrm.state_cache_input)
+ goto out_state_cache_input;
+
net->xfrm.state_hmask = ((sz / sizeof(struct hlist_head)) - 1);
net->xfrm.state_num = 0;
@@ -2902,6 +3276,8 @@ int __net_init xfrm_state_init(struct net *net)
&net->xfrm.xfrm_state_lock);
return 0;
+out_state_cache_input:
+ xfrm_hash_free(net->xfrm.state_byseq, sz);
out_byseq:
xfrm_hash_free(net->xfrm.state_byspi, sz);
out_byspi:
@@ -2931,6 +3307,7 @@ void xfrm_state_fini(struct net *net)
xfrm_hash_free(net->xfrm.state_bysrc, sz);
WARN_ON(!hlist_empty(net->xfrm.state_bydst));
xfrm_hash_free(net->xfrm.state_bydst, sz);
+ free_percpu(net->xfrm.state_cache_input);
}
#ifdef CONFIG_AUDITSYSCALL