Blob Blame History Raw
From: Jiri Slaby <jslaby@suse.cz>
Subject: kABI: do not check external trampolines for signature
Patch-mainline: never, kabi
References: kabi bsc#1207894 bsc#1211243

Commit 2105a92748e8 (static_call,x86: Robustify trampoline patching)
added a signature to trampolines. This silently broke external modules.

So do not check the signature for external modules.

Signed-off-by: Jiri Slaby <jslaby@suse.cz>
---
 arch/arm64/kernel/patching.c  |    8 +++++-
 arch/x86/kernel/static_call.c |   16 ++++++++----
 include/linux/static_call.h   |   28 ++++++++++++++++++----
 include/linux/tracepoint.h    |   20 +++++++++++----
 kernel/static_call.c          |   26 +++++++++++++-------
 kernel/trace/trace_events.c   |   22 ++++++++---------
 kernel/tracepoint.c           |   53 ++++++++++++++++++++++++++++++------------
 7 files changed, 122 insertions(+), 51 deletions(-)

--- a/arch/arm64/kernel/patching.c
+++ b/arch/arm64/kernel/patching.c
@@ -115,7 +115,8 @@ static void *strip_cfi_jt(void *addr)
 	return addr;
 }
 
-void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
+void __arch_static_call_transform(void *site, void *tramp, void *func,
+				  bool tail, bool checked)
 {
 	/*
 	 * -0x8	<literal>
@@ -159,6 +160,11 @@ void arch_static_call_transform(void *si
 		caches_clean_inval_pou((u64)tramp - 8, sizeof(insns));
 }
 
+void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
+{
+       __arch_static_call_transform(site, tramp, func, tail, false);
+}
+
 int __kprobes aarch64_insn_patch_text_nosync(void *addr, u32 insn)
 {
 	u32 *tp = addr;
--- a/arch/x86/kernel/static_call.c
+++ b/arch/x86/kernel/static_call.c
@@ -68,11 +68,11 @@ static void __ref __static_call_transfor
 	text_poke_bp(insn, code, size, emulate);
 }
 
-static void __static_call_validate(void *insn, bool tail, bool tramp)
+static void __static_call_validate(void *insn, bool tail, bool tramp, bool checked)
 {
 	u8 opcode = *(u8 *)insn;
 
-	if (tramp && memcmp(insn+5, tramp_ud, 3)) {
+	if (tramp && checked && memcmp(insn+5, tramp_ud, 3)) {
 		pr_err("trampoline signature fail");
 		BUG();
 	}
@@ -110,22 +110,28 @@ static inline enum insn_type __sc_insn(b
 	return 2*tail + null;
 }
 
-void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
+void __arch_static_call_transform(void *site, void *tramp, void *func, bool tail, bool checked)
 {
 	mutex_lock(&text_mutex);
 
 	if (tramp) {
-		__static_call_validate(tramp, true, true);
+		__static_call_validate(tramp, true, true, checked);
 		__static_call_transform(tramp, __sc_insn(!func, true), func, false);
 	}
 
 	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
-		__static_call_validate(site, tail, false);
+		__static_call_validate(site, tail, false, checked);
 		__static_call_transform(site, __sc_insn(!func, tail), func, false);
 	}
 
 	mutex_unlock(&text_mutex);
 }
+EXPORT_SYMBOL_GPL(__arch_static_call_transform);
+
+void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
+{
+	__arch_static_call_transform(site, tramp, func, tail, false);
+}
 EXPORT_SYMBOL_GPL(arch_static_call_transform);
 
 #ifdef CONFIG_RETPOLINE
--- a/include/linux/static_call.h
+++ b/include/linux/static_call.h
@@ -141,6 +141,8 @@
 /*
  * Either @site or @tramp can be NULL.
  */
+extern void __arch_static_call_transform(void *site, void *tramp, void *func, bool tail,
+					 bool checked);
 extern void arch_static_call_transform(void *site, void *tramp, void *func, bool tail);
 
 #define STATIC_CALL_TRAMP_ADDR(name) &STATIC_CALL_TRAMP(name)
@@ -152,8 +154,8 @@ extern void arch_static_call_transform(v
 #define static_call_update(name, func)					\
 ({									\
 	typeof(&STATIC_CALL_TYPE(name)) __F = (func);			\
-	__static_call_update(&STATIC_CALL_KEY(name),			\
-			     STATIC_CALL_TRAMP_ADDR(name), __F);	\
+	____static_call_update(&STATIC_CALL_KEY(name),			\
+			     STATIC_CALL_TRAMP_ADDR(name), __F, true);	\
 })
 
 #define static_call_query(name) (READ_ONCE(STATIC_CALL_KEY(name).func))
@@ -175,6 +177,8 @@ struct static_call_tramp_key {
 };
 
 extern void __static_call_update(struct static_call_key *key, void *tramp, void *func);
+extern void ____static_call_update(struct static_call_key *key, void *tramp, void *func,
+				   bool checked);
 extern int static_call_mod_init(struct module *mod);
 extern int static_call_text_reserved(void *start, void *end);
 
@@ -235,14 +239,21 @@ static inline int static_call_init(void)
 #define static_call_cond(name)	(void)__static_call(name)
 
 static inline
-void __static_call_update(struct static_call_key *key, void *tramp, void *func)
+void ____static_call_update(struct static_call_key *key, void *tramp,
+			    void *func, bool checked)
 {
 	cpus_read_lock();
 	WRITE_ONCE(key->func, func);
-	arch_static_call_transform(NULL, tramp, func, false);
+	__arch_static_call_transform(NULL, tramp, func, false, checked);
 	cpus_read_unlock();
 }
 
+static inline
+void __static_call_update(struct static_call_key *key, void *tramp, void *func)
+{
+	____static_call_update(key, tramp, func, false);
+}
+
 static inline int static_call_text_reserved(void *start, void *end)
 {
 	return 0;
@@ -312,11 +323,18 @@ static inline void __static_call_nop(voi
 #define static_call_cond(name)	(void)__static_call_cond(name)
 
 static inline
-void __static_call_update(struct static_call_key *key, void *tramp, void *func)
+void ____static_call_update(struct static_call_key *key, void *tramp,
+			    void *func, bool checked)
 {
 	WRITE_ONCE(key->func, func);
 }
 
+static inline
+void __static_call_update(struct static_call_key *key, void *tramp, void *func)
+{
+	____static_call_update(key, tramp, func, false);
+}
+
 static inline int static_call_text_reserved(void *start, void *end)
 {
 	return 0;
--- a/include/linux/tracepoint.h
+++ b/include/linux/tracepoint.h
@@ -36,6 +36,11 @@ struct trace_eval_map {
 extern struct srcu_struct tracepoint_srcu;
 
 extern int
+__tracepoint_probe_register(struct tracepoint *tp, void *probe, void *data, bool checked);
+extern int
+__tracepoint_probe_register_prio(struct tracepoint *tp, void *probe, void *data, int prio,
+				 bool checked);
+extern int
 tracepoint_probe_register(struct tracepoint *tp, void *probe, void *data);
 extern int
 tracepoint_probe_register_prio(struct tracepoint *tp, void *probe, void *data,
@@ -44,6 +49,9 @@ extern int
 tracepoint_probe_register_prio_may_exist(struct tracepoint *tp, void *probe, void *data,
 					 int prio);
 extern int
+__tracepoint_probe_unregister(struct tracepoint *tp, void *probe, void *data,
+			      bool checked);
+extern int
 tracepoint_probe_unregister(struct tracepoint *tp, void *probe, void *data);
 static inline int
 tracepoint_probe_register_may_exist(struct tracepoint *tp, void *probe,
@@ -256,21 +264,21 @@ static inline struct tracepoint *tracepo
 	static inline int						\
 	register_trace_##name(void (*probe)(data_proto), void *data)	\
 	{								\
-		return tracepoint_probe_register(&__tracepoint_##name,	\
-						(void *)probe, data);	\
+		return __tracepoint_probe_register(&__tracepoint_##name,	\
+						(void *)probe, data, true);	\
 	}								\
 	static inline int						\
 	register_trace_prio_##name(void (*probe)(data_proto), void *data,\
 				   int prio)				\
 	{								\
-		return tracepoint_probe_register_prio(&__tracepoint_##name, \
-					      (void *)probe, data, prio); \
+		return __tracepoint_probe_register_prio(&__tracepoint_##name, \
+					      (void *)probe, data, prio, true); \
 	}								\
 	static inline int						\
 	unregister_trace_##name(void (*probe)(data_proto), void *data)	\
 	{								\
-		return tracepoint_probe_unregister(&__tracepoint_##name,\
-						(void *)probe, data);	\
+		return __tracepoint_probe_unregister(&__tracepoint_##name,\
+						(void *)probe, data, true);	\
 	}								\
 	static inline void						\
 	check_trace_callback_type_##name(void (*cb)(data_proto))	\
--- a/kernel/static_call.c
+++ b/kernel/static_call.c
@@ -120,7 +120,8 @@ static inline struct static_call_site *s
 	return (struct static_call_site *)(key->type & ~1);
 }
 
-void __static_call_update(struct static_call_key *key, void *tramp, void *func)
+void ____static_call_update(struct static_call_key *key, void *tramp, void *func,
+			    bool checked)
 {
 	struct static_call_site *site, *stop;
 	struct static_call_mod *site_mod, first;
@@ -133,7 +134,7 @@ void __static_call_update(struct static_
 
 	key->func = func;
 
-	arch_static_call_transform(NULL, tramp, func, false);
+	__arch_static_call_transform(NULL, tramp, func, false, checked);
 
 	/*
 	 * If uninitialized, we'll not update the callsites, but they still
@@ -195,8 +196,8 @@ void __static_call_update(struct static_
 				continue;
 			}
 
-			arch_static_call_transform(site_addr, NULL, func,
-						   static_call_is_tail(site));
+			__arch_static_call_transform(site_addr, NULL, func,
+						     static_call_is_tail(site), checked);
 		}
 	}
 
@@ -204,11 +205,17 @@ done:
 	static_call_unlock();
 	cpus_read_unlock();
 }
+EXPORT_SYMBOL_GPL(____static_call_update);
+
+void __static_call_update(struct static_call_key *key, void *tramp, void *func)
+{
+	____static_call_update(key, tramp, func, false);
+}
 EXPORT_SYMBOL_GPL(__static_call_update);
 
 static int __static_call_init(struct module *mod,
 			      struct static_call_site *start,
-			      struct static_call_site *stop)
+			      struct static_call_site *stop, bool checked)
 {
 	struct static_call_site *site;
 	struct static_call_key *key, *prev_key = NULL;
@@ -272,8 +279,8 @@ static int __static_call_init(struct mod
 		}
 
 do_transform:
-		arch_static_call_transform(site_addr, NULL, key->func,
-				static_call_is_tail(site));
+		__arch_static_call_transform(site_addr, NULL, key->func,
+				static_call_is_tail(site), checked);
 	}
 
 	return 0;
@@ -386,7 +393,8 @@ static int static_call_add_module(struct
 		site->key = key - (long)&site->key;
 	}
 
-	return __static_call_init(mod, start, stop);
+	return __static_call_init(mod, start, stop,
+				  !test_bit(TAINT_OOT_MODULE, &mod->taints));
 }
 
 static void static_call_del_module(struct module *mod)
@@ -481,7 +489,7 @@ int __init static_call_init(void)
 	cpus_read_lock();
 	static_call_lock();
 	ret = __static_call_init(NULL, __start_static_call_sites,
-				 __stop_static_call_sites);
+				 __stop_static_call_sites, true);
 	static_call_unlock();
 	cpus_read_unlock();
 
--- a/kernel/trace/trace_events.c
+++ b/kernel/trace/trace_events.c
@@ -515,24 +515,24 @@ int trace_event_reg(struct trace_event_c
 	WARN_ON(!(call->flags & TRACE_EVENT_FL_TRACEPOINT));
 	switch (type) {
 	case TRACE_REG_REGISTER:
-		return tracepoint_probe_register(call->tp,
-						 call->class->probe,
-						 file);
+		return __tracepoint_probe_register(call->tp,
+						   call->class->probe,
+						   file, true);
 	case TRACE_REG_UNREGISTER:
-		tracepoint_probe_unregister(call->tp,
+		__tracepoint_probe_unregister(call->tp,
 					    call->class->probe,
-					    file);
+					    file, true);
 		return 0;
 
 #ifdef CONFIG_PERF_EVENTS
 	case TRACE_REG_PERF_REGISTER:
-		return tracepoint_probe_register(call->tp,
-						 call->class->perf_probe,
-						 call);
+		return __tracepoint_probe_register(call->tp,
+						   call->class->perf_probe,
+						   call, true);
 	case TRACE_REG_PERF_UNREGISTER:
-		tracepoint_probe_unregister(call->tp,
-					    call->class->perf_probe,
-					    call);
+		__tracepoint_probe_unregister(call->tp,
+					      call->class->perf_probe,
+					      call, true);
 		return 0;
 	case TRACE_REG_PERF_OPEN:
 	case TRACE_REG_PERF_CLOSE:
--- a/kernel/tracepoint.c
+++ b/kernel/tracepoint.c
@@ -305,7 +305,9 @@ static enum tp_func_state nr_func_state(
 	return TP_FUNC_N;	/* 3 or more */
 }
 
-static void tracepoint_update_call(struct tracepoint *tp, struct tracepoint_func *tp_funcs)
+static void tracepoint_update_call(struct tracepoint *tp,
+				   struct tracepoint_func *tp_funcs,
+				   bool checked)
 {
 	void *func = tp->iterator;
 
@@ -314,7 +316,8 @@ static void tracepoint_update_call(struc
 		return;
 	if (nr_func_state(tp_funcs) == TP_FUNC_1)
 		func = tp_funcs[0].func;
-	__static_call_update(tp->static_call_key, tp->static_call_tramp, func);
+	____static_call_update(tp->static_call_key, tp->static_call_tramp, func,
+			     checked);
 }
 
 /*
@@ -322,7 +325,7 @@ static void tracepoint_update_call(struc
  */
 static int tracepoint_add_func(struct tracepoint *tp,
 			       struct tracepoint_func *func, int prio,
-			       bool warn)
+			       bool warn, bool checked)
 {
 	struct tracepoint_func *old, *tp_funcs;
 	int ret;
@@ -355,14 +358,14 @@ static int tracepoint_add_func(struct tr
 		 */
 		tp_rcu_cond_sync(TP_TRANSITION_SYNC_1_0_1);
 		/* Set static call to first function */
-		tracepoint_update_call(tp, tp_funcs);
+		tracepoint_update_call(tp, tp_funcs, checked);
 		/* Both iterator and static call handle NULL tp->funcs */
 		rcu_assign_pointer(tp->funcs, tp_funcs);
 		static_key_enable(&tp->key);
 		break;
 	case TP_FUNC_2:		/* 1->2 */
 		/* Set iterator static call */
-		tracepoint_update_call(tp, tp_funcs);
+		tracepoint_update_call(tp, tp_funcs, checked);
 		/*
 		 * Iterator callback installed before updating tp->funcs.
 		 * Requires ordering between RCU assign/dereference and
@@ -394,7 +397,7 @@ static int tracepoint_add_func(struct tr
  * by preempt_disable around the call site.
  */
 static int tracepoint_remove_func(struct tracepoint *tp,
-		struct tracepoint_func *func)
+		struct tracepoint_func *func, bool checked)
 {
 	struct tracepoint_func *old, *tp_funcs;
 
@@ -416,7 +419,7 @@ static int tracepoint_remove_func(struct
 
 		static_key_disable(&tp->key);
 		/* Set iterator static call */
-		tracepoint_update_call(tp, tp_funcs);
+		tracepoint_update_call(tp, tp_funcs, checked);
 		/* Both iterator and static call handle NULL tp->funcs */
 		rcu_assign_pointer(tp->funcs, NULL);
 		/*
@@ -438,7 +441,7 @@ static int tracepoint_remove_func(struct
 			tp_rcu_get_state(TP_TRANSITION_SYNC_N_2_1);
 		tp_rcu_cond_sync(TP_TRANSITION_SYNC_N_2_1);
 		/* Set static call to first function */
-		tracepoint_update_call(tp, tp_funcs);
+		tracepoint_update_call(tp, tp_funcs, checked);
 		break;
 	case TP_FUNC_2:		/* N->N-1 (N>2) */
 		fallthrough;
@@ -479,7 +482,7 @@ int tracepoint_probe_register_prio_may_e
 	tp_func.func = probe;
 	tp_func.data = data;
 	tp_func.prio = prio;
-	ret = tracepoint_add_func(tp, &tp_func, prio, false);
+	ret = tracepoint_add_func(tp, &tp_func, prio, false, false);
 	mutex_unlock(&tracepoints_mutex);
 	return ret;
 }
@@ -498,8 +501,9 @@ EXPORT_SYMBOL_GPL(tracepoint_probe_regis
  * performed either with a tracepoint module going notifier, or from
  * within module exit functions.
  */
-int tracepoint_probe_register_prio(struct tracepoint *tp, void *probe,
-				   void *data, int prio)
+int __tracepoint_probe_register_prio(struct tracepoint *tp, void *probe,
+					    void *data, int prio,
+					    bool checked)
 {
 	struct tracepoint_func tp_func;
 	int ret;
@@ -508,12 +512,25 @@ int tracepoint_probe_register_prio(struc
 	tp_func.func = probe;
 	tp_func.data = data;
 	tp_func.prio = prio;
-	ret = tracepoint_add_func(tp, &tp_func, prio, true);
+	ret = tracepoint_add_func(tp, &tp_func, prio, true, checked);
 	mutex_unlock(&tracepoints_mutex);
 	return ret;
 }
+EXPORT_SYMBOL_GPL(__tracepoint_probe_register_prio);
+
+int tracepoint_probe_register_prio(struct tracepoint *tp, void *probe,
+				   void *data, int prio)
+{
+	return __tracepoint_probe_register_prio(tp, probe, data, prio, false);
+}
 EXPORT_SYMBOL_GPL(tracepoint_probe_register_prio);
 
+int __tracepoint_probe_register(struct tracepoint *tp, void *probe, void *data, bool checked)
+{
+	return __tracepoint_probe_register_prio(tp, probe, data, TRACEPOINT_DEFAULT_PRIO, checked);
+}
+EXPORT_SYMBOL_GPL(__tracepoint_probe_register);
+
 /**
  * tracepoint_probe_register -  Connect a probe to a tracepoint
  * @tp: tracepoint
@@ -540,7 +557,8 @@ EXPORT_SYMBOL_GPL(tracepoint_probe_regis
  *
  * Returns 0 if ok, error value on error.
  */
-int tracepoint_probe_unregister(struct tracepoint *tp, void *probe, void *data)
+int __tracepoint_probe_unregister(struct tracepoint *tp, void *probe,
+					 void *data, bool checked)
 {
 	struct tracepoint_func tp_func;
 	int ret;
@@ -548,10 +566,17 @@ int tracepoint_probe_unregister(struct t
 	mutex_lock(&tracepoints_mutex);
 	tp_func.func = probe;
 	tp_func.data = data;
-	ret = tracepoint_remove_func(tp, &tp_func);
+	ret = tracepoint_remove_func(tp, &tp_func, checked);
 	mutex_unlock(&tracepoints_mutex);
 	return ret;
 }
+EXPORT_SYMBOL_GPL(__tracepoint_probe_unregister);
+
+int tracepoint_probe_unregister(struct tracepoint *tp, void *probe,
+				void *data)
+{
+	return __tracepoint_probe_unregister(tp, probe, data, false);
+}
 EXPORT_SYMBOL_GPL(tracepoint_probe_unregister);
 
 static void for_each_tracepoint_range(