summaryrefslogtreecommitdiff
path: root/arch/riscv/net/bpf_jit_comp64.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/riscv/net/bpf_jit_comp64.c')
-rw-r--r--arch/riscv/net/bpf_jit_comp64.c18
1 files changed, 12 insertions, 6 deletions
diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
index ecd3ae6f4116..8581693e62d3 100644
--- a/arch/riscv/net/bpf_jit_comp64.c
+++ b/arch/riscv/net/bpf_jit_comp64.c
@@ -245,7 +245,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
/* Set return value. */
if (!is_tail_call)
- emit_mv(RV_REG_A0, RV_REG_A5, ctx);
+ emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
ctx);
@@ -759,8 +759,10 @@ static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_of
if (ret)
return ret;
- if (save_ret)
- emit_sd(RV_REG_FP, -retval_off, regmap[BPF_REG_0], ctx);
+ if (save_ret) {
+ emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
+ emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
+ }
/* update branch with beqz */
if (ctx->insns) {
@@ -853,7 +855,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
if (save_ret) {
- stack_size += 8;
+ stack_size += 16; /* Save both A5 (BPF R0) and A0 */
retval_off = stack_size;
}
@@ -957,6 +959,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
if (ret)
goto out;
emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
+ emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
im->ip_after_call = ctx->insns + ctx->ninsns;
/* 2 nops reserved for auipc+jalr pair */
emit(rv_nop(), ctx);
@@ -988,8 +991,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
if (flags & BPF_TRAMP_F_RESTORE_REGS)
restore_args(nregs, args_off, ctx);
- if (save_ret)
+ if (save_ret) {
emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
+ emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx);
+ }
emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
@@ -1515,7 +1520,8 @@ out_be:
if (ret)
return ret;
- emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
+ if (insn->src_reg != BPF_PSEUDO_CALL)
+ emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
break;
}
/* tail call */