LoongArch: BPF: Fix the tailcall hierarchy

In specific use cases combining tailcalls and BPF-to-BPF calls,
MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt
back-propagation from callee to caller. This patch fixes this
tailcall issue caused by abusing the tailcall-in-bpf2bpf feature
on LoongArch, following the approach of "bpf, x64: Fix tailcall hierarchy".

Push tail_call_cnt_ptr and tail_call_cnt onto the stack;
tail_call_cnt_ptr is passed between tailcall and bpf2bpf, and is
used to increment tail_call_cnt.

Fixes: bb035ef0cc ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls")
Reviewed-by: Geliang Tang <geliang@kernel.org>
Reviewed-by: Hengqi Chen <hengqi.chen@gmail.com>
Signed-off-by: Haoran Jiang <jianghaoran@kylinos.cn>
Signed-off-by: Huacai Chen <chenhuacai@loongson.cn>
This commit is contained in:
Haoran Jiang 2025-08-05 19:00:22 +08:00 committed by Huacai Chen
parent cd39d9e6b7
commit c0fcc955ff
1 changed file with 107 additions and 48 deletions

View File

@ -17,10 +17,7 @@
#define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4) #define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4)
#define REG_TCC LOONGARCH_GPR_A6 #define REG_TCC LOONGARCH_GPR_A6
#define TCC_SAVED LOONGARCH_GPR_S5 #define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
#define SAVE_RA BIT(0)
#define SAVE_TCC BIT(1)
static const int regmap[] = { static const int regmap[] = {
/* return value from in-kernel function, and exit value for eBPF program */ /* return value from in-kernel function, and exit value for eBPF program */
@ -42,32 +39,57 @@ static const int regmap[] = {
[BPF_REG_AX] = LOONGARCH_GPR_T0, [BPF_REG_AX] = LOONGARCH_GPR_T0,
}; };
static void mark_call(struct jit_ctx *ctx) static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
{ {
ctx->flags |= SAVE_RA; const struct bpf_prog *prog = ctx->prog;
} const bool is_main_prog = !bpf_is_subprog(prog);
static void mark_tail_call(struct jit_ctx *ctx) if (is_main_prog) {
{ /*
ctx->flags |= SAVE_TCC; * LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT
} * if (REG_TCC > T3 )
* std REG_TCC -> LOONGARCH_GPR_SP + store_offset
* else
* std REG_TCC -> LOONGARCH_GPR_SP + store_offset
* REG_TCC = LOONGARCH_GPR_SP + store_offset
*
* std REG_TCC -> LOONGARCH_GPR_SP + store_offset
*
* The purpose of this code is to first push the TCC into stack,
* and then push the address of TCC into stack.
* In cases where bpf2bpf and tailcall are used in combination,
* the value in REG_TCC may be a count or an address,
* these two cases need to be judged and handled separately.
*/
emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
*store_offset -= sizeof(long);
static bool seen_call(struct jit_ctx *ctx) emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
{
return (ctx->flags & SAVE_RA);
}
static bool seen_tail_call(struct jit_ctx *ctx) /*
{ * If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count,
return (ctx->flags & SAVE_TCC); * push tcc into stack
} */
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
static u8 tail_call_reg(struct jit_ctx *ctx) /* Push the address of TCC into the REG_TCC */
{ emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
if (seen_call(ctx))
return TCC_SAVED;
return REG_TCC; emit_uncond_jmp(ctx, 2);
/*
* If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address,
* push tcc_ptr into stack
*/
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
} else {
*store_offset -= sizeof(long);
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
}
/* Push tcc_ptr into stack */
*store_offset -= sizeof(long);
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
} }
/* /*
@ -90,6 +112,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
* | $s4 | * | $s4 |
* +-------------------------+ * +-------------------------+
* | $s5 | * | $s5 |
* +-------------------------+
* | tcc |
* +-------------------------+
* | tcc_ptr |
* +-------------------------+ <--BPF_REG_FP * +-------------------------+ <--BPF_REG_FP
* | prog->aux->stack_depth | * | prog->aux->stack_depth |
* | (optional) | * | (optional) |
@ -99,12 +125,17 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
static void build_prologue(struct jit_ctx *ctx) static void build_prologue(struct jit_ctx *ctx)
{ {
int i, stack_adjust = 0, store_offset, bpf_stack_adjust; int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
const struct bpf_prog *prog = ctx->prog;
const bool is_main_prog = !bpf_is_subprog(prog);
bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16); bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */ /* To store ra, fp, s0, s1, s2, s3, s4, s5 */
stack_adjust += sizeof(long) * 8; stack_adjust += sizeof(long) * 8;
/* To store tcc and tcc_ptr */
stack_adjust += sizeof(long) * 2;
stack_adjust = round_up(stack_adjust, 16); stack_adjust = round_up(stack_adjust, 16);
stack_adjust += bpf_stack_adjust; stack_adjust += bpf_stack_adjust;
@ -113,11 +144,12 @@ static void build_prologue(struct jit_ctx *ctx)
emit_insn(ctx, nop); emit_insn(ctx, nop);
/* /*
* First instruction initializes the tail call count (TCC). * First instruction initializes the tail call count (TCC)
* On tail call we skip this instruction, and the TCC is * register to zero. On tail call we skip this instruction,
* passed in REG_TCC from the caller. * and the TCC is passed in REG_TCC from the caller.
*/ */
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); if (is_main_prog)
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust); emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
@ -145,20 +177,13 @@ static void build_prologue(struct jit_ctx *ctx)
store_offset -= sizeof(long); store_offset -= sizeof(long);
emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset); emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
prepare_bpf_tail_call_cnt(ctx, &store_offset);
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust); emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
if (bpf_stack_adjust) if (bpf_stack_adjust)
emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust); emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
/*
* Program contains calls and tail calls, so REG_TCC need
* to be saved across calls.
*/
if (seen_tail_call(ctx) && seen_call(ctx))
move_reg(ctx, TCC_SAVED, REG_TCC);
else
emit_insn(ctx, nop);
ctx->stack_size = stack_adjust; ctx->stack_size = stack_adjust;
} }
@ -191,6 +216,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
load_offset -= sizeof(long); load_offset -= sizeof(long);
emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset); emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
/*
* When push into the stack, follow the order of tcc then tcc_ptr.
* When pop from the stack, first pop tcc_ptr then followed by tcc.
*/
load_offset -= 2 * sizeof(long);
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
load_offset += sizeof(long);
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust); emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
if (!is_tail_call) { if (!is_tail_call) {
@ -203,7 +238,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
* Call the next bpf prog and skip the first instruction * Call the next bpf prog and skip the first instruction
* of TCC initialization. * of TCC initialization.
*/ */
emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1); emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6);
} }
} }
@ -225,7 +260,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn) static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
{ {
int off, tc_ninsn = 0; int off, tc_ninsn = 0;
u8 tcc = tail_call_reg(ctx); int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
u8 a1 = LOONGARCH_GPR_A1; u8 a1 = LOONGARCH_GPR_A1;
u8 a2 = LOONGARCH_GPR_A2; u8 a2 = LOONGARCH_GPR_A2;
u8 t1 = LOONGARCH_GPR_T1; u8 t1 = LOONGARCH_GPR_T1;
@ -252,11 +287,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
goto toofar; goto toofar;
/* /*
* if (--TCC < 0) * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
* goto out; * goto out;
*/ */
emit_insn(ctx, addid, REG_TCC, tcc, -1); emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0) emit_insn(ctx, ldd, t3, REG_TCC, 0);
emit_insn(ctx, addid, t3, t3, 1);
emit_insn(ctx, std, t3, REG_TCC, 0);
emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
goto toofar; goto toofar;
/* /*
@ -467,7 +506,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
u64 func_addr; u64 func_addr;
bool func_addr_fixed, sign_extend; bool func_addr_fixed, sign_extend;
int i = insn - ctx->prog->insnsi; int i = insn - ctx->prog->insnsi;
int ret, jmp_offset; int ret, jmp_offset, tcc_ptr_off;
const u8 code = insn->code; const u8 code = insn->code;
const u8 cond = BPF_OP(code); const u8 cond = BPF_OP(code);
const u8 t1 = LOONGARCH_GPR_T1; const u8 t1 = LOONGARCH_GPR_T1;
@ -903,12 +942,16 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
/* function call */ /* function call */
case BPF_JMP | BPF_CALL: case BPF_JMP | BPF_CALL:
mark_call(ctx);
ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
&func_addr, &func_addr_fixed); &func_addr, &func_addr_fixed);
if (ret < 0) if (ret < 0)
return ret; return ret;
if (insn->src_reg == BPF_PSEUDO_CALL) {
tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
}
move_addr(ctx, t1, func_addr); move_addr(ctx, t1, func_addr);
emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0); emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
@ -919,7 +962,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
/* tail call */ /* tail call */
case BPF_JMP | BPF_TAIL_CALL: case BPF_JMP | BPF_TAIL_CALL:
mark_tail_call(ctx);
if (emit_bpf_tail_call(ctx, i) < 0) if (emit_bpf_tail_call(ctx, i) < 0)
return -EINVAL; return -EINVAL;
break; break;
@ -1412,7 +1454,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
{ {
int i, ret, save_ret; int i, ret, save_ret;
int stack_size = 0, nargs = 0; int stack_size = 0, nargs = 0;
int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off; int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off;
bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT; bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
void *orig_call = func_addr; void *orig_call = func_addr;
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
@ -1447,6 +1489,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
* *
* FP - sreg_off [ callee saved reg ] * FP - sreg_off [ callee saved reg ]
* *
* FP - tcc_ptr_off [ tail_call_cnt_ptr ]
*/ */
if (m->nr_args > LOONGARCH_MAX_REG_ARGS) if (m->nr_args > LOONGARCH_MAX_REG_ARGS)
@ -1489,6 +1532,12 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
stack_size += 8; stack_size += 8;
sreg_off = stack_size; sreg_off = stack_size;
/* Room of trampoline frame to store tail_call_cnt_ptr */
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
stack_size += 8;
tcc_ptr_off = stack_size;
}
stack_size = round_up(stack_size, 16); stack_size = round_up(stack_size, 16);
if (is_struct_ops) { if (is_struct_ops) {
@ -1519,6 +1568,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size); emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size);
} }
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
/* callee saved register S1 to pass start time */ /* callee saved register S1 to pass start time */
emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off); emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
@ -1565,6 +1617,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
if (flags & BPF_TRAMP_F_CALL_ORIG) { if (flags & BPF_TRAMP_F_CALL_ORIG) {
restore_args(ctx, m->nr_args, args_off); restore_args(ctx, m->nr_args, args_off);
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
ret = emit_call(ctx, (const u64)orig_call); ret = emit_call(ctx, (const u64)orig_call);
if (ret) if (ret)
goto out; goto out;
@ -1605,6 +1661,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off); emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
if (is_struct_ops) { if (is_struct_ops) {
/* trampoline called directly */ /* trampoline called directly */
emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, stack_size - 8); emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, stack_size - 8);