summaryrefslogtreecommitdiff
path: root/arch/riscv/net/bpf_jit_comp64.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/riscv/net/bpf_jit_comp64.c')
-rw-r--r--arch/riscv/net/bpf_jit_comp64.c255
1 files changed, 162 insertions, 93 deletions
diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
index c648864c8cd1..8423f4ddf8f5 100644
--- a/arch/riscv/net/bpf_jit_comp64.c
+++ b/arch/riscv/net/bpf_jit_comp64.c
@@ -13,6 +13,8 @@
#include <asm/patch.h>
#include "bpf_jit.h"
+#define RV_FENTRY_NINSNS 2
+
#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
@@ -241,7 +243,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
if (!is_tail_call)
emit_mv(RV_REG_A0, RV_REG_A5, ctx);
emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
- is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
+ is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
ctx);
}
@@ -578,7 +580,8 @@ static int add_exception_handler(const struct bpf_insn *insn,
unsigned long pc;
off_t offset;
- if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
+ if (!ctx->insns || !ctx->prog->aux->extable ||
+ (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
return 0;
if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
@@ -618,32 +621,7 @@ static int add_exception_handler(const struct bpf_insn *insn,
return 0;
}
-static int gen_call_or_nops(void *target, void *ip, u32 *insns)
-{
- s64 rvoff;
- int i, ret;
- struct rv_jit_context ctx;
-
- ctx.ninsns = 0;
- ctx.insns = (u16 *)insns;
-
- if (!target) {
- for (i = 0; i < 4; i++)
- emit(rv_nop(), &ctx);
- return 0;
- }
-
- rvoff = (s64)(target - (ip + 4));
- emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
- ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
- if (ret)
- return ret;
- emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);
-
- return 0;
-}
-
-static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
+static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
{
s64 rvoff;
struct rv_jit_context ctx;
@@ -658,38 +636,35 @@ static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
}
rvoff = (s64)(target - ip);
- return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
+ return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
}
int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
void *old_addr, void *new_addr)
{
- u32 old_insns[4], new_insns[4];
+ u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
bool is_call = poke_type == BPF_MOD_CALL;
- int (*gen_insns)(void *target, void *ip, u32 *insns);
- int ninsns = is_call ? 4 : 2;
int ret;
- if (!is_bpf_text_address((unsigned long)ip))
+ if (!is_kernel_text((unsigned long)ip) &&
+ !is_bpf_text_address((unsigned long)ip))
return -ENOTSUPP;
- gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;
-
- ret = gen_insns(old_addr, ip, old_insns);
+ ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
if (ret)
return ret;
- if (memcmp(ip, old_insns, ninsns * 4))
+ if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
return -EFAULT;
- ret = gen_insns(new_addr, ip, new_insns);
+ ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
if (ret)
return ret;
cpus_read_lock();
mutex_lock(&text_mutex);
- if (memcmp(ip, new_insns, ninsns * 4))
- ret = patch_text(ip, new_insns, ninsns);
+ if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
+ ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
mutex_unlock(&text_mutex);
cpus_read_unlock();
@@ -787,8 +762,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
int i, ret, offset;
int *branches_off = NULL;
int stack_size = 0, nregs = m->nr_args;
- int retaddr_off, fp_off, retval_off, args_off;
- int nregs_off, ip_off, run_ctx_off, sreg_off;
+ int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
@@ -796,13 +770,27 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
bool save_ret;
u32 insn;
- /* Generated trampoline stack layout:
+ /* Two types of generated trampoline stack layout:
*
- * FP - 8 [ RA of parent func ] return address of parent
+ * 1. trampoline called from function entry
+ * --------------------------------------
+ * FP + 8 [ RA to parent func ] return address to parent
* function
- * FP - retaddr_off [ RA of traced func ] return address of traced
+ * FP + 0 [ FP of parent func ] frame pointer of parent
* function
- * FP - fp_off [ FP of parent func ]
+ * FP - 8 [ T0 to traced func ] return address of traced
+ * function
+ * FP - 16 [ FP of traced func ] frame pointer of traced
+ * function
+ * --------------------------------------
+ *
+ * 2. trampoline called directly
+ * --------------------------------------
+ * FP - 8 [ RA to caller func ] return address to caller
+ * function
+ * FP - 16 [ FP of caller func ] frame pointer of caller
+ * function
+ * --------------------------------------
*
* FP - retval_off [ return value ] BPF_TRAMP_F_CALL_ORIG or
* BPF_TRAMP_F_RET_FENTRY_RET
@@ -833,14 +821,8 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
if (nregs > 8)
return -ENOTSUPP;
- /* room for parent function return address */
- stack_size += 8;
-
- stack_size += 8;
- retaddr_off = stack_size;
-
- stack_size += 8;
- fp_off = stack_size;
+ /* room of trampoline frame to store return address and frame pointer */
+ stack_size += 16;
save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
if (save_ret) {
@@ -867,12 +849,29 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
stack_size = round_up(stack_size, 16);
- emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
-
- emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
- emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);
-
- emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+ if (func_addr) {
+ /* For the trampoline called from function entry,
+ * the frame of traced function and the frame of
+ * trampoline need to be considered.
+ */
+ emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
+ emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
+ emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
+ emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
+
+ emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
+ emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
+ emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
+ emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+ } else {
+ /* For the trampoline called directly, just handle
+ * the frame of trampoline.
+ */
+ emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
+ emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
+ emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
+ emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+ }
/* callee saved register S1 to pass start time */
emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
@@ -890,7 +889,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
/* skip to actual body of traced function */
if (flags & BPF_TRAMP_F_SKIP_FRAME)
- orig_call += 16;
+ orig_call += RV_FENTRY_NINSNS * 4;
if (flags & BPF_TRAMP_F_CALL_ORIG) {
emit_imm(RV_REG_A0, (const s64)im, ctx);
@@ -967,17 +966,30 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
- if (flags & BPF_TRAMP_F_SKIP_FRAME)
- /* return address of parent function */
- emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
- else
- /* return address of traced function */
- emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
+ if (func_addr) {
+ /* trampoline called from function entry */
+ emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
+ emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
+ emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
- emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
- emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
+ emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
+ emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
+ emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
- emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+ if (flags & BPF_TRAMP_F_SKIP_FRAME)
+ /* return to parent function */
+ emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+ else
+ /* return to traced function */
+ emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
+ } else {
+ /* trampoline called directly */
+ emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
+ emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
+ emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
+
+ emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+ }
ret = ctx->ninsns;
out:
@@ -1035,7 +1047,19 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
emit_zext_32(rd, ctx);
break;
}
- emit_mv(rd, rs, ctx);
+ switch (insn->off) {
+ case 0:
+ emit_mv(rd, rs, ctx);
+ break;
+ case 8:
+ case 16:
+ emit_slli(RV_REG_T1, rs, 64 - insn->off, ctx);
+ emit_srai(rd, RV_REG_T1, 64 - insn->off, ctx);
+ break;
+ case 32:
+ emit_addiw(rd, rs, 0, ctx);
+ break;
+ }
if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
@@ -1083,13 +1107,19 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
break;
case BPF_ALU | BPF_DIV | BPF_X:
case BPF_ALU64 | BPF_DIV | BPF_X:
- emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
+ if (off)
+ emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
+ else
+ emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
case BPF_ALU | BPF_MOD | BPF_X:
case BPF_ALU64 | BPF_MOD | BPF_X:
- emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
+ if (off)
+ emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
+ else
+ emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
@@ -1138,6 +1168,7 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
break;
case BPF_ALU | BPF_END | BPF_FROM_BE:
+ case BPF_ALU64 | BPF_END | BPF_FROM_LE:
emit_li(RV_REG_T2, 0, ctx);
emit_andi(RV_REG_T1, rd, 0xff, ctx);
@@ -1260,16 +1291,24 @@ out_be:
case BPF_ALU | BPF_DIV | BPF_K:
case BPF_ALU64 | BPF_DIV | BPF_K:
emit_imm(RV_REG_T1, imm, ctx);
- emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
- rv_divuw(rd, rd, RV_REG_T1), ctx);
+ if (off)
+ emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
+ rv_divw(rd, rd, RV_REG_T1), ctx);
+ else
+ emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
+ rv_divuw(rd, rd, RV_REG_T1), ctx);
if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
case BPF_ALU | BPF_MOD | BPF_K:
case BPF_ALU64 | BPF_MOD | BPF_K:
emit_imm(RV_REG_T1, imm, ctx);
- emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
- rv_remuw(rd, rd, RV_REG_T1), ctx);
+ if (off)
+ emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
+ rv_remw(rd, rd, RV_REG_T1), ctx);
+ else
+ emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
+ rv_remuw(rd, rd, RV_REG_T1), ctx);
if (!is64 && !aux->verifier_zext)
emit_zext_32(rd, ctx);
break;
@@ -1303,7 +1342,11 @@ out_be:
/* JUMP off */
case BPF_JMP | BPF_JA:
- rvoff = rv_offset(i, off, ctx);
+ case BPF_JMP32 | BPF_JA:
+ if (BPF_CLASS(code) == BPF_JMP)
+ rvoff = rv_offset(i, off, ctx);
+ else
+ rvoff = rv_offset(i, imm, ctx);
ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
if (ret)
return ret;
@@ -1475,7 +1518,7 @@ out_be:
return 1;
}
- /* LDX: dst = *(size *)(src + off) */
+ /* LDX: dst = *(unsigned size *)(src + off) */
case BPF_LDX | BPF_MEM | BPF_B:
case BPF_LDX | BPF_MEM | BPF_H:
case BPF_LDX | BPF_MEM | BPF_W:
@@ -1484,14 +1527,28 @@ out_be:
case BPF_LDX | BPF_PROBE_MEM | BPF_H:
case BPF_LDX | BPF_PROBE_MEM | BPF_W:
case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+ /* LDSX: dst = *(signed size *)(src + off) */
+ case BPF_LDX | BPF_MEMSX | BPF_B:
+ case BPF_LDX | BPF_MEMSX | BPF_H:
+ case BPF_LDX | BPF_MEMSX | BPF_W:
+ case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
+ case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
+ case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
{
int insn_len, insns_start;
+ bool sign_ext;
+
+ sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
+ BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
switch (BPF_SIZE(code)) {
case BPF_B:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
- emit(rv_lbu(rd, off, rs), ctx);
+ if (sign_ext)
+ emit(rv_lb(rd, off, rs), ctx);
+ else
+ emit(rv_lbu(rd, off, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
@@ -1499,15 +1556,19 @@ out_be:
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
insns_start = ctx->ninsns;
- emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
+ if (sign_ext)
+ emit(rv_lb(rd, 0, RV_REG_T1), ctx);
+ else
+ emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
- if (insn_is_zext(&insn[1]))
- return 1;
break;
case BPF_H:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
- emit(rv_lhu(rd, off, rs), ctx);
+ if (sign_ext)
+ emit(rv_lh(rd, off, rs), ctx);
+ else
+ emit(rv_lhu(rd, off, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
@@ -1515,15 +1576,19 @@ out_be:
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
insns_start = ctx->ninsns;
- emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
+ if (sign_ext)
+ emit(rv_lh(rd, 0, RV_REG_T1), ctx);
+ else
+ emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
- if (insn_is_zext(&insn[1]))
- return 1;
break;
case BPF_W:
if (is_12b_int(off)) {
insns_start = ctx->ninsns;
- emit(rv_lwu(rd, off, rs), ctx);
+ if (sign_ext)
+ emit(rv_lw(rd, off, rs), ctx);
+ else
+ emit(rv_lwu(rd, off, rs), ctx);
insn_len = ctx->ninsns - insns_start;
break;
}
@@ -1531,10 +1596,11 @@ out_be:
emit_imm(RV_REG_T1, off, ctx);
emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
insns_start = ctx->ninsns;
- emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
+ if (sign_ext)
+ emit(rv_lw(rd, 0, RV_REG_T1), ctx);
+ else
+ emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
insn_len = ctx->ninsns - insns_start;
- if (insn_is_zext(&insn[1]))
- return 1;
break;
case BPF_DW:
if (is_12b_int(off)) {
@@ -1555,6 +1621,9 @@ out_be:
ret = add_exception_handler(insn, ctx, rd, insn_len);
if (ret)
return ret;
+
+ if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
+ return 1;
break;
}
/* speculation barrier */
@@ -1691,8 +1760,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
store_offset = stack_adjust - 8;
- /* reserve 4 nop insns */
- for (i = 0; i < 4; i++)
+ /* nops reserved for auipc+jalr pair */
+ for (i = 0; i < RV_FENTRY_NINSNS; i++)
emit(rv_nop(), ctx);
/* First instruction is always setting the tail-call-counter