summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/memcontrol.h45
-rw-r--r--include/net/proto_memory.h4
-rw-r--r--include/net/sock.h46
-rw-r--r--include/net/tcp.h4
-rw-r--r--mm/memcontrol.c40
-rw-r--r--net/core/sock.c38
-rw-r--r--net/ipv4/inet_connection_sock.c19
-rw-r--r--net/ipv4/tcp_output.c5
-rw-r--r--net/mptcp/protocol.h4
-rw-r--r--net/mptcp/subflow.c11
10 files changed, 142 insertions, 74 deletions
diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 785173aa0739..fb27e3d2fdac 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -1596,14 +1596,16 @@ static inline void mem_cgroup_flush_foreign(struct bdi_writeback *wb)
#endif /* CONFIG_CGROUP_WRITEBACK */
struct sock;
-bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
- gfp_t gfp_mask);
-void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
#ifdef CONFIG_MEMCG
extern struct static_key_false memcg_sockets_enabled_key;
#define mem_cgroup_sockets_enabled static_branch_unlikely(&memcg_sockets_enabled_key)
+
void mem_cgroup_sk_alloc(struct sock *sk);
void mem_cgroup_sk_free(struct sock *sk);
+void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk);
+bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages,
+ gfp_t gfp_mask);
+void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages);
#if BITS_PER_LONG < 64
static inline void mem_cgroup_set_socket_pressure(struct mem_cgroup *memcg)
@@ -1640,32 +1642,37 @@ static inline u64 mem_cgroup_get_socket_pressure(struct mem_cgroup *memcg)
}
#endif
-static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
-{
-#ifdef CONFIG_MEMCG_V1
- if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
- return !!memcg->tcpmem_pressure;
-#endif /* CONFIG_MEMCG_V1 */
- do {
- if (time_before64(get_jiffies_64(), mem_cgroup_get_socket_pressure(memcg)))
- return true;
- } while ((memcg = parent_mem_cgroup(memcg)));
- return false;
-}
-
int alloc_shrinker_info(struct mem_cgroup *memcg);
void free_shrinker_info(struct mem_cgroup *memcg);
void set_shrinker_bit(struct mem_cgroup *memcg, int nid, int shrinker_id);
void reparent_shrinker_deferred(struct mem_cgroup *memcg);
#else
#define mem_cgroup_sockets_enabled 0
-static inline void mem_cgroup_sk_alloc(struct sock *sk) { };
-static inline void mem_cgroup_sk_free(struct sock *sk) { };
-static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
+
+static inline void mem_cgroup_sk_alloc(struct sock *sk)
+{
+}
+
+static inline void mem_cgroup_sk_free(struct sock *sk)
+{
+}
+
+static inline void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk)
+{
+}
+
+static inline bool mem_cgroup_sk_charge(const struct sock *sk,
+ unsigned int nr_pages,
+ gfp_t gfp_mask)
{
return false;
}
+static inline void mem_cgroup_sk_uncharge(const struct sock *sk,
+ unsigned int nr_pages)
+{
+}
+
static inline void set_shrinker_bit(struct mem_cgroup *memcg,
int nid, int shrinker_id)
{
diff --git a/include/net/proto_memory.h b/include/net/proto_memory.h
index a6ab2f4f5e28..8e91a8fa31b5 100644
--- a/include/net/proto_memory.h
+++ b/include/net/proto_memory.h
@@ -31,8 +31,8 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
if (!sk->sk_prot->memory_pressure)
return false;
- if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
- mem_cgroup_under_socket_pressure(sk->sk_memcg))
+ if (mem_cgroup_sk_enabled(sk) &&
+ mem_cgroup_sk_under_memory_pressure(sk))
return true;
return !!READ_ONCE(*sk->sk_prot->memory_pressure);
diff --git a/include/net/sock.h b/include/net/sock.h
index c8a4b283df6f..1c49ea13af4a 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -443,7 +443,9 @@ struct sock {
__cacheline_group_begin(sock_read_rxtx);
int sk_err;
struct socket *sk_socket;
+#ifdef CONFIG_MEMCG
struct mem_cgroup *sk_memcg;
+#endif
#ifdef CONFIG_XFRM
struct xfrm_policy __rcu *sk_policy[2];
#endif
@@ -2594,6 +2596,50 @@ static inline gfp_t gfp_memcg_charge(void)
return in_softirq() ? GFP_ATOMIC : GFP_KERNEL;
}
+#ifdef CONFIG_MEMCG
+static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
+{
+ return sk->sk_memcg;
+}
+
+static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
+{
+ return mem_cgroup_sockets_enabled && mem_cgroup_from_sk(sk);
+}
+
+static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
+{
+ struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
+
+#ifdef CONFIG_MEMCG_V1
+ if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+ return !!memcg->tcpmem_pressure;
+#endif /* CONFIG_MEMCG_V1 */
+
+ do {
+ if (time_before64(get_jiffies_64(), mem_cgroup_get_socket_pressure(memcg)))
+ return true;
+ } while ((memcg = parent_mem_cgroup(memcg)));
+
+ return false;
+}
+#else
+static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
+{
+ return NULL;
+}
+
+static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
+{
+ return false;
+}
+
+static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
+{
+ return false;
+}
+#endif
+
static inline long sock_rcvtimeo(const struct sock *sk, bool noblock)
{
return noblock ? 0 : READ_ONCE(sk->sk_rcvtimeo);
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 526a26e7a150..2936b8175950 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -275,8 +275,8 @@ extern unsigned long tcp_memory_pressure;
/* optimized version of sk_under_memory_pressure() for TCP sockets */
static inline bool tcp_under_memory_pressure(const struct sock *sk)
{
- if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
- mem_cgroup_under_socket_pressure(sk->sk_memcg))
+ if (mem_cgroup_sk_enabled(sk) &&
+ mem_cgroup_sk_under_memory_pressure(sk))
return true;
return READ_ONCE(tcp_memory_pressure);
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 8dd7fbed5a94..df3e9205c9e6 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -5020,22 +5020,42 @@ out:
void mem_cgroup_sk_free(struct sock *sk)
{
- if (sk->sk_memcg)
- css_put(&sk->sk_memcg->css);
+ struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
+
+ if (memcg)
+ css_put(&memcg->css);
+}
+
+void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk)
+{
+ struct mem_cgroup *memcg;
+
+ if (sk->sk_memcg == newsk->sk_memcg)
+ return;
+
+ mem_cgroup_sk_free(newsk);
+
+ memcg = mem_cgroup_from_sk(sk);
+ if (memcg)
+ css_get(&memcg->css);
+
+ newsk->sk_memcg = sk->sk_memcg;
}
/**
- * mem_cgroup_charge_skmem - charge socket memory
- * @memcg: memcg to charge
+ * mem_cgroup_sk_charge - charge socket memory
+ * @sk: socket in memcg to charge
* @nr_pages: number of pages to charge
* @gfp_mask: reclaim mode
*
* Charges @nr_pages to @memcg. Returns %true if the charge fit within
* @memcg's configured limit, %false if it doesn't.
*/
-bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
- gfp_t gfp_mask)
+bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages,
+ gfp_t gfp_mask)
{
+ struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
+
if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
return memcg1_charge_skmem(memcg, nr_pages, gfp_mask);
@@ -5048,12 +5068,14 @@ bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
}
/**
- * mem_cgroup_uncharge_skmem - uncharge socket memory
- * @memcg: memcg to uncharge
+ * mem_cgroup_sk_uncharge - uncharge socket memory
+ * @sk: socket in memcg to uncharge
* @nr_pages: number of pages to uncharge
*/
-void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
+void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages)
{
+ struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
+
if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
memcg1_uncharge_skmem(memcg, nr_pages);
return;
diff --git a/net/core/sock.c b/net/core/sock.c
index 7c26ec8dce63..ab6953d295df 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1032,7 +1032,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
bool charged;
int pages;
- if (!mem_cgroup_sockets_enabled || !sk->sk_memcg || !sk_has_account(sk))
+ if (!mem_cgroup_sk_enabled(sk) || !sk_has_account(sk))
return -EOPNOTSUPP;
if (!bytes)
@@ -1041,8 +1041,8 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
pages = sk_mem_pages(bytes);
/* pre-charge to memcg */
- charged = mem_cgroup_charge_skmem(sk->sk_memcg, pages,
- GFP_KERNEL | __GFP_RETRY_MAYFAIL);
+ charged = mem_cgroup_sk_charge(sk, pages,
+ GFP_KERNEL | __GFP_RETRY_MAYFAIL);
if (!charged)
return -ENOMEM;
@@ -1054,7 +1054,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
*/
if (allocated > sk_prot_mem_limits(sk, 1)) {
sk_memory_allocated_sub(sk, pages);
- mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
+ mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
}
sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
@@ -2512,8 +2512,10 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
sock_reset_flag(newsk, SOCK_DONE);
+#ifdef CONFIG_MEMCG
/* sk->sk_memcg will be populated at accept() time */
newsk->sk_memcg = NULL;
+#endif
cgroup_sk_clone(&newsk->sk_cgrp_data);
@@ -3263,16 +3265,16 @@ EXPORT_SYMBOL(sk_wait_data);
*/
int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
{
- struct mem_cgroup *memcg = mem_cgroup_sockets_enabled ? sk->sk_memcg : NULL;
+ bool memcg_enabled = false, charged = false;
struct proto *prot = sk->sk_prot;
- bool charged = true;
long allocated;
sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);
- if (memcg) {
- charged = mem_cgroup_charge_skmem(memcg, amt, gfp_memcg_charge());
+ if (mem_cgroup_sk_enabled(sk)) {
+ memcg_enabled = true;
+ charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge());
if (!charged)
goto suppress_allocation;
}
@@ -3346,21 +3348,19 @@ suppress_allocation:
*/
if (sk->sk_wmem_queued + size >= sk->sk_sndbuf) {
/* Force charge with __GFP_NOFAIL */
- if (memcg && !charged) {
- mem_cgroup_charge_skmem(memcg, amt,
- gfp_memcg_charge() | __GFP_NOFAIL);
- }
+ if (memcg_enabled && !charged)
+ mem_cgroup_sk_charge(sk, amt,
+ gfp_memcg_charge() | __GFP_NOFAIL);
return 1;
}
}
- if (kind == SK_MEM_SEND || (kind == SK_MEM_RECV && charged))
- trace_sock_exceed_buf_limit(sk, prot, allocated, kind);
+ trace_sock_exceed_buf_limit(sk, prot, allocated, kind);
sk_memory_allocated_sub(sk, amt);
- if (memcg && charged)
- mem_cgroup_uncharge_skmem(memcg, amt);
+ if (charged)
+ mem_cgroup_sk_uncharge(sk, amt);
return 0;
}
@@ -3398,8 +3398,8 @@ void __sk_mem_reduce_allocated(struct sock *sk, int amount)
{
sk_memory_allocated_sub(sk, amount);
- if (mem_cgroup_sockets_enabled && sk->sk_memcg)
- mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
+ if (mem_cgroup_sk_enabled(sk))
+ mem_cgroup_sk_uncharge(sk, amount);
if (sk_under_global_memory_pressure(sk) &&
(sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
@@ -4454,7 +4454,9 @@ static int __init sock_struct_check(void)
CACHELINE_ASSERT_GROUP_MEMBER(struct sock, sock_read_rxtx, sk_err);
CACHELINE_ASSERT_GROUP_MEMBER(struct sock, sock_read_rxtx, sk_socket);
+#ifdef CONFIG_MEMCG
CACHELINE_ASSERT_GROUP_MEMBER(struct sock, sock_read_rxtx, sk_memcg);
+#endif
CACHELINE_ASSERT_GROUP_MEMBER(struct sock, sock_write_rxtx, sk_lock);
CACHELINE_ASSERT_GROUP_MEMBER(struct sock, sock_write_rxtx, sk_reserved_mem);
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index 1e2df51427fe..0ef1eacd539d 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -706,9 +706,9 @@ struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg)
spin_unlock_bh(&queue->fastopenq.lock);
}
-out:
release_sock(sk);
- if (newsk && mem_cgroup_sockets_enabled) {
+
+ if (mem_cgroup_sockets_enabled) {
gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
int amt = 0;
@@ -718,7 +718,7 @@ out:
lock_sock(newsk);
mem_cgroup_sk_alloc(newsk);
- if (newsk->sk_memcg) {
+ if (mem_cgroup_from_sk(newsk)) {
/* The socket has not been accepted yet, no need
* to look at newsk->sk_wmem_queued.
*/
@@ -727,23 +727,22 @@ out:
}
if (amt)
- mem_cgroup_charge_skmem(newsk->sk_memcg, amt, gfp);
+ mem_cgroup_sk_charge(newsk, amt, gfp);
kmem_cache_charge(newsk, gfp);
release_sock(newsk);
}
+
if (req)
reqsk_put(req);
- if (newsk)
- inet_init_csk_locks(newsk);
-
+ inet_init_csk_locks(newsk);
return newsk;
+
out_err:
- newsk = NULL;
- req = NULL;
+ release_sock(sk);
arg->err = error;
- goto out;
+ return NULL;
}
EXPORT_SYMBOL(inet_csk_accept);
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index caf11920a878..dfbac0876d96 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3578,9 +3578,8 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
sk_memory_allocated_add(sk, amt);
- if (mem_cgroup_sockets_enabled && sk->sk_memcg)
- mem_cgroup_charge_skmem(sk->sk_memcg, amt,
- gfp_memcg_charge() | __GFP_NOFAIL);
+ if (mem_cgroup_sk_enabled(sk))
+ mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL);
}
/* Send a FIN. The caller locks the socket for us.
diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h
index b15d7fab5c4b..a1787a1344ac 100644
--- a/net/mptcp/protocol.h
+++ b/net/mptcp/protocol.h
@@ -788,9 +788,7 @@ static inline bool mptcp_epollin_ready(const struct sock *sk)
* as it can always coalesce them
*/
return (data_avail >= sk->sk_rcvlowat) ||
- (mem_cgroup_sockets_enabled && sk->sk_memcg &&
- mem_cgroup_under_socket_pressure(sk->sk_memcg)) ||
- READ_ONCE(tcp_memory_pressure);
+ tcp_under_memory_pressure(sk);
}
int mptcp_set_rcvlowat(struct sock *sk, int val);
diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
index 3f1b62a9fe88..c8a7e4b59db1 100644
--- a/net/mptcp/subflow.c
+++ b/net/mptcp/subflow.c
@@ -1717,19 +1717,14 @@ static void mptcp_attach_cgroup(struct sock *parent, struct sock *child)
/* only the additional subflows created by kworkers have to be modified */
if (cgroup_id(sock_cgroup_ptr(parent_skcd)) !=
cgroup_id(sock_cgroup_ptr(child_skcd))) {
-#ifdef CONFIG_MEMCG
- struct mem_cgroup *memcg = parent->sk_memcg;
-
- mem_cgroup_sk_free(child);
- if (memcg && css_tryget(&memcg->css))
- child->sk_memcg = memcg;
-#endif /* CONFIG_MEMCG */
-
cgroup_sk_free(child_skcd);
*child_skcd = *parent_skcd;
cgroup_sk_clone(child_skcd);
}
#endif /* CONFIG_SOCK_CGROUP_DATA */
+
+ if (mem_cgroup_sockets_enabled)
+ mem_cgroup_sk_inherit(parent, child);
}
static void mptcp_subflow_ops_override(struct sock *ssk)