From: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Date: Sun, 4 Sep 2022 22:41:28 +0200
Subject: bpf: Add helper macro bpf_for_each_reg_in_vstate
Patch-mainline: v6.1-rc1
Git-commit: b239da34203f49c40b5d656220c39647c3ff0b3c
References: bsc#1215863 CVE-2023-39191
X-Info: adds bpf_for_each_reg_in_vstate(), dependency of commit f8064ab90d664 "bpf: Invalidate slices on destruction of dynptrs on stack"
For a lot of use cases in future patches, we will want to modify the
state of registers part of some same 'group' (e.g. same ref_obj_id). It
won't just be limited to releasing reference state, but setting a type
flag dynamically based on certain actions, etc.
Hence, we need a way to easily pass a callback to the function that
iterates over all registers in current bpf_verifier_state in all frames
upto (and including) the curframe.
While in C++ we would be able to easily use a lambda to pass state and
the callback together, sadly we aren't using C++ in the kernel. The next
best thing to avoid defining a function for each case seems like
statement expressions in GNU C. The kernel already uses them heavily,
hence they can passed to the macro in the style of a lambda. The
statement expression will then be substituted in the for loop bodies.
Variables __state and __reg are set to current bpf_func_state and reg
for each invocation of the expression inside the passed in verifier
state.
Then, convert mark_ptr_or_null_regs, clear_all_pkt_pointers,
release_reference, find_good_pkt_pointers, find_equal_scalars to
use bpf_for_each_reg_in_vstate.
Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Link: https://lore.kernel.org/r/20220904204145.3089-16-memxor@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Shung-Hsi Yu <shung-hsi.yu@suse.com>
---
include/linux/bpf_verifier.h | 21 ++++++
kernel/bpf/verifier.c | 135 ++++++++-----------------------------------
2 files changed, 49 insertions(+), 107 deletions(-)
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -337,6 +337,27 @@ struct bpf_verifier_state {
iter < frame->allocated_stack / BPF_REG_SIZE; \
iter++, reg = bpf_get_spilled_reg(iter, frame))
+/* Invoke __expr over regsiters in __vst, setting __state and __reg */
+#define bpf_for_each_reg_in_vstate(__vst, __state, __reg, __expr) \
+ ({ \
+ struct bpf_verifier_state *___vstate = __vst; \
+ int ___i, ___j; \
+ for (___i = 0; ___i <= ___vstate->curframe; ___i++) { \
+ struct bpf_reg_state *___regs; \
+ __state = ___vstate->frame[___i]; \
+ ___regs = __state->regs; \
+ for (___j = 0; ___j < MAX_BPF_REG; ___j++) { \
+ __reg = &___regs[___j]; \
+ (void)(__expr); \
+ } \
+ bpf_for_each_spilled_reg(___j, __state, __reg) { \
+ if (!__reg) \
+ continue; \
+ (void)(__expr); \
+ } \
+ } \
+ })
+
/* linked list of verifier states used to prune search */
struct bpf_verifier_state_list {
struct bpf_verifier_state state;
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -6466,31 +6466,15 @@ static int check_func_proto(const struct
/* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
* are now invalid, so turn them into unknown SCALAR_VALUE.
*/
-static void __clear_all_pkt_pointers(struct bpf_verifier_env *env,
- struct bpf_func_state *state)
+static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
{
- struct bpf_reg_state *regs = state->regs, *reg;
- int i;
+ struct bpf_func_state *state;
+ struct bpf_reg_state *reg;
- for (i = 0; i < MAX_BPF_REG; i++)
- if (reg_is_pkt_pointer_any(®s[i]))
- mark_reg_unknown(env, regs, i);
-
- bpf_for_each_spilled_reg(i, state, reg) {
- if (!reg)
- continue;
+ bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
if (reg_is_pkt_pointer_any(reg))
__mark_reg_unknown(env, reg);
- }
-}
-
-static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
-{
- struct bpf_verifier_state *vstate = env->cur_state;
- int i;
-
- for (i = 0; i <= vstate->curframe; i++)
- __clear_all_pkt_pointers(env, vstate->frame[i]);
+ }));
}
enum {
@@ -6519,41 +6503,24 @@ static void mark_pkt_end(struct bpf_veri
reg->range = AT_PKT_END;
}
-static void release_reg_references(struct bpf_verifier_env *env,
- struct bpf_func_state *state,
- int ref_obj_id)
-{
- struct bpf_reg_state *regs = state->regs, *reg;
- int i;
-
- for (i = 0; i < MAX_BPF_REG; i++)
- if (regs[i].ref_obj_id == ref_obj_id)
- mark_reg_unknown(env, regs, i);
-
- bpf_for_each_spilled_reg(i, state, reg) {
- if (!reg)
- continue;
- if (reg->ref_obj_id == ref_obj_id)
- __mark_reg_unknown(env, reg);
- }
-}
-
/* The pointer with the specified id has released its reference to kernel
* resources. Identify all copies of the same pointer and clear the reference.
*/
static int release_reference(struct bpf_verifier_env *env,
int ref_obj_id)
{
- struct bpf_verifier_state *vstate = env->cur_state;
+ struct bpf_func_state *state;
+ struct bpf_reg_state *reg;
int err;
- int i;
err = release_reference_state(cur_func(env), ref_obj_id);
if (err)
return err;
- for (i = 0; i <= vstate->curframe; i++)
- release_reg_references(env, vstate->frame[i], ref_obj_id);
+ bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
+ if (reg->ref_obj_id == ref_obj_id)
+ __mark_reg_unknown(env, reg);
+ }));
return 0;
}
@@ -9193,34 +9160,14 @@ static int check_alu_op(struct bpf_verif
return 0;
}
-static void __find_good_pkt_pointers(struct bpf_func_state *state,
- struct bpf_reg_state *dst_reg,
- enum bpf_reg_type type, int new_range)
-{
- struct bpf_reg_state *reg;
- int i;
-
- for (i = 0; i < MAX_BPF_REG; i++) {
- reg = &state->regs[i];
- if (reg->type == type && reg->id == dst_reg->id)
- /* keep the maximum range already checked */
- reg->range = max(reg->range, new_range);
- }
-
- bpf_for_each_spilled_reg(i, state, reg) {
- if (!reg)
- continue;
- if (reg->type == type && reg->id == dst_reg->id)
- reg->range = max(reg->range, new_range);
- }
-}
-
static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
struct bpf_reg_state *dst_reg,
enum bpf_reg_type type,
bool range_right_open)
{
- int new_range, i;
+ struct bpf_func_state *state;
+ struct bpf_reg_state *reg;
+ int new_range;
if (dst_reg->off < 0 ||
(dst_reg->off == 0 && range_right_open))
@@ -9285,9 +9232,11 @@ static void find_good_pkt_pointers(struc
* the range won't allow anything.
* dst_reg->off is known < MAX_PACKET_OFF, therefore it fits in a u16.
*/
- for (i = 0; i <= vstate->curframe; i++)
- __find_good_pkt_pointers(vstate->frame[i], dst_reg, type,
- new_range);
+ bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+ if (reg->type == type && reg->id == dst_reg->id)
+ /* keep the maximum range already checked */
+ reg->range = max(reg->range, new_range);
+ }));
}
static int is_branch32_taken(struct bpf_reg_state *reg, u32 val, u8 opcode)
@@ -9776,7 +9725,7 @@ static void mark_ptr_or_null_reg(struct
if (!reg_may_point_to_spin_lock(reg)) {
/* For not-NULL ptr, reg->ref_obj_id will be reset
- * in release_reg_references().
+ * in release_reference().
*
* reg->id is still used by spin_lock ptr. Other
* than spin_lock ptr type, reg->id can be reset.
@@ -9786,22 +9735,6 @@ static void mark_ptr_or_null_reg(struct
}
}
-static void __mark_ptr_or_null_regs(struct bpf_func_state *state, u32 id,
- bool is_null)
-{
- struct bpf_reg_state *reg;
- int i;
-
- for (i = 0; i < MAX_BPF_REG; i++)
- mark_ptr_or_null_reg(state, &state->regs[i], id, is_null);
-
- bpf_for_each_spilled_reg(i, state, reg) {
- if (!reg)
- continue;
- mark_ptr_or_null_reg(state, reg, id, is_null);
- }
-}
-
/* The logic is similar to find_good_pkt_pointers(), both could eventually
* be folded together at some point.
*/
@@ -9809,10 +9742,9 @@ static void mark_ptr_or_null_regs(struct
bool is_null)
{
struct bpf_func_state *state = vstate->frame[vstate->curframe];
- struct bpf_reg_state *regs = state->regs;
+ struct bpf_reg_state *regs = state->regs, *reg;
u32 ref_obj_id = regs[regno].ref_obj_id;
u32 id = regs[regno].id;
- int i;
if (ref_obj_id && ref_obj_id == id && is_null)
/* regs[regno] is in the " == NULL" branch.
@@ -9821,8 +9753,9 @@ static void mark_ptr_or_null_regs(struct
*/
WARN_ON_ONCE(release_reference_state(state, id));
- for (i = 0; i <= vstate->curframe; i++)
- __mark_ptr_or_null_regs(vstate->frame[i], id, is_null);
+ bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+ mark_ptr_or_null_reg(state, reg, id, is_null);
+ }));
}
static bool try_match_pkt_pointers(const struct bpf_insn *insn,
@@ -9935,23 +9868,11 @@ static void find_equal_scalars(struct bp
{
struct bpf_func_state *state;
struct bpf_reg_state *reg;
- int i, j;
- for (i = 0; i <= vstate->curframe; i++) {
- state = vstate->frame[i];
- for (j = 0; j < MAX_BPF_REG; j++) {
- reg = &state->regs[j];
- if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
- *reg = *known_reg;
- }
-
- bpf_for_each_spilled_reg(j, state, reg) {
- if (!reg)
- continue;
- if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
- *reg = *known_reg;
- }
- }
+ bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+ if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
+ *reg = *known_reg;
+ }));
}
static int check_cond_jmp_op(struct bpf_verifier_env *env,