summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHaoran Jiang <jianghaoran@kylinos.cn>2025-08-05 19:00:22 +0800
committerHuacai Chen <chenhuacai@loongson.cn>2025-08-05 19:00:22 +0800
commitc0fcc955ff827431b541b1aa6bcb82bdce4531f7 (patch)
treee4d83f7d74d1b2922371cb6507b4e536ca64efc8
parentcd39d9e6b7e4c58fa77783e7aedf7ada51d02ea3 (diff)
LoongArch: BPF: Fix the tailcall hierarchy
In specific use cases combining tailcalls and BPF-to-BPF calls, MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt back-propagation from callee to caller. This patch fixes this tailcall issue caused by abusing the tailcall in bpf2bpf feature on LoongArch like the way of "bpf, x64: Fix tailcall hierarchy". Push tail_call_cnt_ptr and tail_call_cnt into the stack, tail_call_cnt_ptr is passed between tailcall and bpf2bpf, uses tail_call_cnt_ptr to increment tail_call_cnt. Fixes: bb035ef0cc91 ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls") Reviewed-by: Geliang Tang <geliang@kernel.org> Reviewed-by: Hengqi Chen <hengqi.chen@gmail.com> Signed-off-by: Haoran Jiang <jianghaoran@kylinos.cn> Signed-off-by: Huacai Chen <chenhuacai@loongson.cn>
-rw-r--r--arch/loongarch/net/bpf_jit.c155
1 files changed, 107 insertions, 48 deletions
diff --git a/arch/loongarch/net/bpf_jit.c b/arch/loongarch/net/bpf_jit.c
index f4f12ed16d2f..4ea8ae4cf0ca 100644
--- a/arch/loongarch/net/bpf_jit.c
+++ b/arch/loongarch/net/bpf_jit.c
@@ -17,10 +17,7 @@
#define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4)
#define REG_TCC LOONGARCH_GPR_A6
-#define TCC_SAVED LOONGARCH_GPR_S5
-
-#define SAVE_RA BIT(0)
-#define SAVE_TCC BIT(1)
+#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
static const int regmap[] = {
/* return value from in-kernel function, and exit value for eBPF program */
@@ -42,32 +39,57 @@ static const int regmap[] = {
[BPF_REG_AX] = LOONGARCH_GPR_T0,
};
-static void mark_call(struct jit_ctx *ctx)
+static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
{
- ctx->flags |= SAVE_RA;
-}
+ const struct bpf_prog *prog = ctx->prog;
+ const bool is_main_prog = !bpf_is_subprog(prog);
-static void mark_tail_call(struct jit_ctx *ctx)
-{
- ctx->flags |= SAVE_TCC;
-}
+ if (is_main_prog) {
+ /*
+ * LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT
+ * if (REG_TCC > T3 )
+ * std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+ * else
+ * std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+ * REG_TCC = LOONGARCH_GPR_SP + store_offset
+ *
+ * std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+ *
+ * The purpose of this code is to first push the TCC into stack,
+ * and then push the address of TCC into stack.
+ * In cases where bpf2bpf and tailcall are used in combination,
+ * the value in REG_TCC may be a count or an address,
+ * these two cases need to be judged and handled separately.
+ */
+ emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+ *store_offset -= sizeof(long);
-static bool seen_call(struct jit_ctx *ctx)
-{
- return (ctx->flags & SAVE_RA);
-}
+ emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
-static bool seen_tail_call(struct jit_ctx *ctx)
-{
- return (ctx->flags & SAVE_TCC);
-}
+ /*
+ * If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count,
+ * push tcc into stack
+ */
+ emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
-static u8 tail_call_reg(struct jit_ctx *ctx)
-{
- if (seen_call(ctx))
- return TCC_SAVED;
+ /* Push the address of TCC into the REG_TCC */
+ emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
- return REG_TCC;
+ emit_uncond_jmp(ctx, 2);
+
+ /*
+ * If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address,
+ * push tcc_ptr into stack
+ */
+ emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+ } else {
+ *store_offset -= sizeof(long);
+ emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+ }
+
+ /* Push tcc_ptr into stack */
+ *store_offset -= sizeof(long);
+ emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
}
/*
@@ -90,6 +112,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
* | $s4 |
* +-------------------------+
* | $s5 |
+ * +-------------------------+
+ * | tcc |
+ * +-------------------------+
+ * | tcc_ptr |
* +-------------------------+ <--BPF_REG_FP
* | prog->aux->stack_depth |
* | (optional) |
@@ -99,12 +125,17 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
static void build_prologue(struct jit_ctx *ctx)
{
int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
+ const struct bpf_prog *prog = ctx->prog;
+ const bool is_main_prog = !bpf_is_subprog(prog);
bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
- /* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
+ /* To store ra, fp, s0, s1, s2, s3, s4, s5 */
stack_adjust += sizeof(long) * 8;
+ /* To store tcc and tcc_ptr */
+ stack_adjust += sizeof(long) * 2;
+
stack_adjust = round_up(stack_adjust, 16);
stack_adjust += bpf_stack_adjust;
@@ -113,11 +144,12 @@ static void build_prologue(struct jit_ctx *ctx)
emit_insn(ctx, nop);
/*
- * First instruction initializes the tail call count (TCC).
- * On tail call we skip this instruction, and the TCC is
- * passed in REG_TCC from the caller.
+ * First instruction initializes the tail call count (TCC)
+ * register to zero. On tail call we skip this instruction,
+ * and the TCC is passed in REG_TCC from the caller.
*/
- emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+ if (is_main_prog)
+ emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
@@ -145,20 +177,13 @@ static void build_prologue(struct jit_ctx *ctx)
store_offset -= sizeof(long);
emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
+ prepare_bpf_tail_call_cnt(ctx, &store_offset);
+
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
if (bpf_stack_adjust)
emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
- /*
- * Program contains calls and tail calls, so REG_TCC need
- * to be saved across calls.
- */
- if (seen_tail_call(ctx) && seen_call(ctx))
- move_reg(ctx, TCC_SAVED, REG_TCC);
- else
- emit_insn(ctx, nop);
-
ctx->stack_size = stack_adjust;
}
@@ -191,6 +216,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
load_offset -= sizeof(long);
emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
+ /*
+ * When push into the stack, follow the order of tcc then tcc_ptr.
+ * When pop from the stack, first pop tcc_ptr then followed by tcc.
+ */
+ load_offset -= 2 * sizeof(long);
+ emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
+ load_offset += sizeof(long);
+ emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
if (!is_tail_call) {
@@ -203,7 +238,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
* Call the next bpf prog and skip the first instruction
* of TCC initialization.
*/
- emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1);
+ emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6);
}
}
@@ -225,7 +260,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
{
int off, tc_ninsn = 0;
- u8 tcc = tail_call_reg(ctx);
+ int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
u8 a1 = LOONGARCH_GPR_A1;
u8 a2 = LOONGARCH_GPR_A2;
u8 t1 = LOONGARCH_GPR_T1;
@@ -252,11 +287,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
goto toofar;
/*
- * if (--TCC < 0)
- * goto out;
+ * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
+ * goto out;
*/
- emit_insn(ctx, addid, REG_TCC, tcc, -1);
- if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
+ emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+ emit_insn(ctx, ldd, t3, REG_TCC, 0);
+ emit_insn(ctx, addid, t3, t3, 1);
+ emit_insn(ctx, std, t3, REG_TCC, 0);
+ emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+ if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
goto toofar;
/*
@@ -467,7 +506,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
u64 func_addr;
bool func_addr_fixed, sign_extend;
int i = insn - ctx->prog->insnsi;
- int ret, jmp_offset;
+ int ret, jmp_offset, tcc_ptr_off;
const u8 code = insn->code;
const u8 cond = BPF_OP(code);
const u8 t1 = LOONGARCH_GPR_T1;
@@ -903,12 +942,16 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
/* function call */
case BPF_JMP | BPF_CALL:
- mark_call(ctx);
ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
&func_addr, &func_addr_fixed);
if (ret < 0)
return ret;
+ if (insn->src_reg == BPF_PSEUDO_CALL) {
+ tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
+ emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+ }
+
move_addr(ctx, t1, func_addr);
emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
@@ -919,7 +962,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
/* tail call */
case BPF_JMP | BPF_TAIL_CALL:
- mark_tail_call(ctx);
if (emit_bpf_tail_call(ctx, i) < 0)
return -EINVAL;
break;
@@ -1412,7 +1454,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
{
int i, ret, save_ret;
int stack_size = 0, nargs = 0;
- int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off;
+ int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off;
bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
void *orig_call = func_addr;
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
@@ -1447,6 +1489,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
*
* FP - sreg_off [ callee saved reg ]
*
+ * FP - tcc_ptr_off [ tail_call_cnt_ptr ]
*/
if (m->nr_args > LOONGARCH_MAX_REG_ARGS)
@@ -1489,6 +1532,12 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
stack_size += 8;
sreg_off = stack_size;
+ /* Room of trampoline frame to store tail_call_cnt_ptr */
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
+ stack_size += 8;
+ tcc_ptr_off = stack_size;
+ }
+
stack_size = round_up(stack_size, 16);
if (is_struct_ops) {
@@ -1519,6 +1568,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size);
}
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+ emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
/* callee saved register S1 to pass start time */
emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
@@ -1565,6 +1617,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
if (flags & BPF_TRAMP_F_CALL_ORIG) {
restore_args(ctx, m->nr_args, args_off);
+
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+ emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
ret = emit_call(ctx, (const u64)orig_call);
if (ret)
goto out;
@@ -1605,6 +1661,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+ emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
if (is_struct_ops) {
/* trampoline called directly */
emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, stack_size - 8);