From: Jakub Sitnicki <jakub@cloudflare.com>
Date: Fri, 17 Jul 2020 12:35:23 +0200
Subject: bpf: Introduce SK_LOOKUP program type with a dedicated attach point
Patch-mainline: v5.9-rc1
Git-commit: e9ddbb7707ff5891616240026062b8c1e29864ca
References: bsc#1177028

Add a new program type BPF_PROG_TYPE_SK_LOOKUP with a dedicated attach type
BPF_SK_LOOKUP. The new program type will be invoked by the transport layer
when looking up a listening socket for a new connection request
(connection-oriented protocols), or when looking up an unconnected socket
for a packet (connection-less protocols).

When called, an SK_LOOKUP BPF program can select a socket that will receive
the packet. This serves as a mechanism to overcome the limits of what the
bind() API can express. Two use cases driving this work are:

 (1) steer packets destined to an IP range, on fixed port to a socket

     192.0.2.0/24, port 80 -> NGINX socket

 (2) steer packets destined to an IP address, on any port to a socket

     198.51.100.1, any port -> L7 proxy socket

In its run-time context the program receives information about the packet
that triggered the socket lookup, namely the IP version, the L4 protocol
identifier, and the address 4-tuple. The context can be extended in the
future to include the ingress interface identifier.

To select a socket, the BPF program fetches it from a map holding socket
references, such as SOCKMAP or SOCKHASH, and calls the
bpf_sk_assign(ctx, sk, ...) helper to record the selection. The transport
layer then uses the selected socket as the result of the socket lookup.

In its basic form, SK_LOOKUP acts as a filter and hence must return either
SK_PASS or SK_DROP. If the program returns SK_PASS, the transport layer
looks up a socket to receive the packet, or uses the one selected by the
program if available, while SK_DROP informs the transport layer that the
lookup should fail.
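
As an illustration, a minimal sketch of such a program for use case (2),
steering every port on 198.51.100.1 to a single socket, could look as
follows. It is not part of this patch; the map name, the SEC() annotation,
and the SOCKMAP lookup from this program type (enabled later in the
series) are assumptions:

    #include <linux/bpf.h>
    #include <sys/socket.h>
    #include <bpf/bpf_endian.h>
    #include <bpf/bpf_helpers.h>

    struct {
            __uint(type, BPF_MAP_TYPE_SOCKMAP);
            __uint(max_entries, 1);
            __type(key, __u32);
            __type(value, __u64);
    } proxy_sock SEC(".maps");

    SEC("sk_lookup")
    int steer_to_proxy(struct bpf_sk_lookup *ctx)
    {
            struct bpf_sock *sk;
            __u32 key = 0;
            long err;

            /* Match only IPv4 packets destined to 198.51.100.1, any port. */
            if (ctx->family != AF_INET ||
                ctx->local_ip4 != bpf_htonl(0xc6336401))
                    return SK_PASS; /* no selection, regular lookup proceeds */

            sk = bpf_map_lookup_elem(&proxy_sock, &key);
            if (!sk)
                    return SK_PASS;

            err = bpf_sk_assign(ctx, sk, 0);
            bpf_sk_release(sk);
            return err ? SK_DROP : SK_PASS;
    }

    char _license[] SEC("license") = "GPL";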

This patch only enables the user to attach an SK_LOOKUP program to a
network namespace. Subsequent patches hook it up to run on the local
delivery path in the ipv4 and ipv6 stacks.
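
For completeness, attaching from user space goes through the bpf_link API
with a network namespace fd as the attach target. A sketch, assuming a
libbpf version that knows the BPF_SK_LOOKUP attach type:

    #include <fcntl.h>
    #include <unistd.h>
    #include <bpf/bpf.h>

    /* Attach a loaded SK_LOOKUP program to the current netns.
     * Returns a link fd on success, -1 on error.
     */
    int attach_sk_lookup(int prog_fd)
    {
            int netns_fd, link_fd;

            netns_fd = open("/proc/self/ns/net", O_RDONLY);
            if (netns_fd < 0)
                    return -1;

            link_fd = bpf_link_create(prog_fd, netns_fd, BPF_SK_LOOKUP, NULL);
            close(netns_fd);
            return link_fd;
    }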

Suggested-by: Marek Majkowski <marek@cloudflare.com>
Signed-off-by: Jakub Sitnicki <jakub@cloudflare.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20200717103536.397595-3-jakub@cloudflare.com
Acked-by: Gary Lin <glin@suse.com>
---
 include/linux/bpf-netns.h  |    3 
 include/linux/bpf.h        |    1 
 include/linux/bpf_types.h  |    2 
 include/linux/filter.h     |   17 ++++
 include/uapi/linux/bpf.h   |   77 +++++++++++++++++++
 kernel/bpf/net_namespace.c |    5 +
 kernel/bpf/syscall.c       |    9 ++
 kernel/bpf/verifier.c      |   13 ++-
 net/core/filter.c          |  180 +++++++++++++++++++++++++++++++++++++++++++++
 scripts/bpf_helpers_doc.py |    9 ++
 10 files changed, 312 insertions(+), 4 deletions(-)

--- a/include/linux/bpf-netns.h
+++ b/include/linux/bpf-netns.h
@@ -8,6 +8,7 @@
 enum netns_bpf_attach_type {
 	NETNS_BPF_INVALID = -1,
 	NETNS_BPF_FLOW_DISSECTOR = 0,
+	NETNS_BPF_SK_LOOKUP,
 	MAX_NETNS_BPF_ATTACH_TYPE
 };
 
@@ -17,6 +18,8 @@ to_netns_bpf_attach_type(enum bpf_attach
 	switch (attach_type) {
 	case BPF_FLOW_DISSECTOR:
 		return NETNS_BPF_FLOW_DISSECTOR;
+	case BPF_SK_LOOKUP:
+		return NETNS_BPF_SK_LOOKUP;
 	default:
 		return NETNS_BPF_INVALID;
 	}
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -249,6 +249,7 @@ enum bpf_arg_type {
 	ARG_PTR_TO_INT,		/* pointer to int */
 	ARG_PTR_TO_LONG,	/* pointer to long */
 	ARG_PTR_TO_SOCKET,	/* pointer to bpf_sock (fullsock) */
+	ARG_PTR_TO_SOCKET_OR_NULL,	/* pointer to bpf_sock (fullsock) or NULL */
 	ARG_PTR_TO_BTF_ID,	/* pointer to in-kernel struct */
 	ARG_PTR_TO_ALLOC_MEM,	/* pointer to dynamically allocated memory */
 	ARG_PTR_TO_ALLOC_MEM_OR_NULL,	/* pointer to dynamically allocated memory or NULL */
--- a/include/linux/bpf_types.h
+++ b/include/linux/bpf_types.h
@@ -64,6 +64,8 @@ BPF_PROG_TYPE(BPF_PROG_TYPE_LIRC_MODE2,
 #ifdef CONFIG_INET
 BPF_PROG_TYPE(BPF_PROG_TYPE_SK_REUSEPORT, sk_reuseport,
 	      struct sk_reuseport_md, struct sk_reuseport_kern)
+BPF_PROG_TYPE(BPF_PROG_TYPE_SK_LOOKUP, sk_lookup,
+	      struct bpf_sk_lookup, struct bpf_sk_lookup_kern)
 #endif
 #if defined(CONFIG_BPF_JIT)
 BPF_PROG_TYPE(BPF_PROG_TYPE_STRUCT_OPS, bpf_struct_ops,
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -1276,4 +1276,21 @@ struct bpf_sockopt_kern {
 
 int copy_bpf_fprog_from_user(struct sock_fprog *dst, void __user *src, int len);
 
+struct bpf_sk_lookup_kern {
+	u16		family;
+	u16		protocol;
+	struct {
+		__be32 saddr;
+		__be32 daddr;
+	} v4;
+	struct {
+		const struct in6_addr *saddr;
+		const struct in6_addr *daddr;
+	} v6;
+	__be16		sport;
+	u16		dport;
+	struct sock	*selected_sk;
+	bool		no_reuseport;
+};
+
 #endif /* __LINUX_FILTER_H__ */
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -189,6 +189,7 @@ enum bpf_prog_type {
 	BPF_PROG_TYPE_STRUCT_OPS,
 	BPF_PROG_TYPE_EXT,
 	BPF_PROG_TYPE_LSM,
+	BPF_PROG_TYPE_SK_LOOKUP,
 };
 
 enum bpf_attach_type {
@@ -228,6 +229,7 @@ enum bpf_attach_type {
 	BPF_XDP_DEVMAP,
 	BPF_CGROUP_INET_SOCK_RELEASE,
 	BPF_XDP_CPUMAP,
+	BPF_SK_LOOKUP,
 	__MAX_BPF_ATTACH_TYPE
 };
 
@@ -3069,6 +3071,10 @@ union bpf_attr {
  *
  * long bpf_sk_assign(struct sk_buff *skb, struct bpf_sock *sk, u64 flags)
  *	Description
+ *		Helper is overloaded depending on BPF program type. This
+ *		description applies to **BPF_PROG_TYPE_SCHED_CLS** and
+ *		**BPF_PROG_TYPE_SCHED_ACT** programs.
+ *
  *		Assign the *sk* to the *skb*. When combined with appropriate
  *		routing configuration to receive the packet towards the socket,
  *		will cause *skb* to be delivered to the specified socket.
@@ -3094,6 +3100,56 @@ union bpf_attr {
  *		**-ESOCKTNOSUPPORT** if the socket type is not supported
  *		(reuseport).
  *
+ * long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
+ *	Description
+ *		Helper is overloaded depending on BPF program type. This
+ *		description applies to **BPF_PROG_TYPE_SK_LOOKUP** programs.
+ *
+ *		Select the *sk* as a result of a socket lookup.
+ *
+ *		For the operation to succeed, the passed socket must be
+ *		compatible with the packet description provided by the *ctx* object.
+ *
+ *		The L4 protocol (**IPPROTO_TCP** or **IPPROTO_UDP**) must
+ *		be an exact match, while the IP family (**AF_INET** or
+ *		**AF_INET6**) must be compatible: IPv6 sockets that are
+ *		not v6-only can be selected for IPv4 packets.
+ *
+ *		Only TCP listeners and UDP unconnected sockets can be
+ *		selected. *sk* can also be NULL to reset any previous
+ *		selection.
+ *
+ *		The *flags* argument can be a combination of the following values:
+ *
+ *		* **BPF_SK_LOOKUP_F_REPLACE** to override the previous
+ *		  socket selection, potentially done by a BPF program
+ *		  that ran before us.
+ *
+ *		* **BPF_SK_LOOKUP_F_NO_REUSEPORT** to skip
+ *		  load-balancing within reuseport group for the socket
+ *		  being selected.
+ *
+ *		On success *ctx->sk* will point to the selected socket.
+ *
+ *	Return
+ *		0 on success, or a negative errno in case of failure.
+ *
+ *		* **-EAFNOSUPPORT** if socket family (*sk->family*) is
+ *		  not compatible with packet family (*ctx->family*).
+ *
+ *		* **-EEXIST** if a socket has already been selected,
+ *		  potentially by another program, and the
+ *		  **BPF_SK_LOOKUP_F_REPLACE** flag was not specified.
+ *
+ *		* **-EINVAL** if unsupported flags were specified.
+ *
+ *		* **-EPROTOTYPE** if socket L4 protocol
+ *		  (*sk->protocol*) doesn't match packet protocol
+ *		  (*ctx->protocol*).
+ *
+ *		* **-ESOCKTNOSUPPORT** if the socket is not in an allowed
+ *		  state (TCP listening or UDP unconnected).
+ *
  * u64 bpf_ktime_get_boot_ns(void)
  * 	Description
  * 		Return the time elapsed since system boot, in nanoseconds.
@@ -3607,6 +3663,12 @@ enum {
 	BPF_RINGBUF_HDR_SZ		= 8,
 };
 
+/* BPF_FUNC_sk_assign flags in bpf_sk_lookup context. */
+enum {
+	BPF_SK_LOOKUP_F_REPLACE		= (1ULL << 0),
+	BPF_SK_LOOKUP_F_NO_REUSEPORT	= (1ULL << 1),
+};
+
 /* Mode for BPF_FUNC_skb_adjust_room helper. */
 enum bpf_adj_room_mode {
 	BPF_ADJ_ROOM_NET,
@@ -4349,4 +4411,19 @@ struct bpf_pidns_info {
 	__u32 pid;
 	__u32 tgid;
 };
+
+/* User accessible data for SK_LOOKUP programs. Add new fields at the end. */
+struct bpf_sk_lookup {
+	__bpf_md_ptr(struct bpf_sock *, sk); /* Selected socket */
+
+	__u32 family;		/* Protocol family (AF_INET, AF_INET6) */
+	__u32 protocol;		/* IP protocol (IPPROTO_TCP, IPPROTO_UDP) */
+	__u32 remote_ip4;	/* Network byte order */
+	__u32 remote_ip6[4];	/* Network byte order */
+	__u32 remote_port;	/* Network byte order */
+	__u32 local_ip4;	/* Network byte order */
+	__u32 local_ip6[4];	/* Network byte order */
+	__u32 local_port;	/* Host byte order */
+};
+
 #endif /* _UAPI__LINUX_BPF_H__ */
--- a/kernel/bpf/net_namespace.c
+++ b/kernel/bpf/net_namespace.c
@@ -373,6 +373,8 @@ static int netns_bpf_max_progs(enum netn
 	switch (type) {
 	case NETNS_BPF_FLOW_DISSECTOR:
 		return 1;
+	case NETNS_BPF_SK_LOOKUP:
+		return 64;
 	default:
 		return 0;
 	}
@@ -403,6 +405,9 @@ static int netns_bpf_link_attach(struct
 	case NETNS_BPF_FLOW_DISSECTOR:
 		err = flow_dissector_bpf_prog_attach_check(net, link->prog);
 		break;
+	case NETNS_BPF_SK_LOOKUP:
+		err = 0; /* nothing to check */
+		break;
 	default:
 		err = -EINVAL;
 		break;
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -2019,6 +2019,10 @@ bpf_prog_load_check_attach(enum bpf_prog
 		default:
 			return -EINVAL;
 		}
+	case BPF_PROG_TYPE_SK_LOOKUP:
+		if (expected_attach_type == BPF_SK_LOOKUP)
+			return 0;
+		return -EINVAL;
 	case BPF_PROG_TYPE_EXT:
 		if (expected_attach_type)
 			return -EINVAL;
@@ -2753,6 +2757,7 @@ static int bpf_prog_attach_check_attach_
 	case BPF_PROG_TYPE_CGROUP_SOCK:
 	case BPF_PROG_TYPE_CGROUP_SOCK_ADDR:
 	case BPF_PROG_TYPE_CGROUP_SOCKOPT:
+	case BPF_PROG_TYPE_SK_LOOKUP:
 		return attach_type == prog->expected_attach_type ? 0 : -EINVAL;
 	case BPF_PROG_TYPE_CGROUP_SKB:
 		if (!capable(CAP_NET_ADMIN))
@@ -2814,6 +2819,8 @@ attach_type_to_prog_type(enum bpf_attach
 		return BPF_PROG_TYPE_CGROUP_SOCKOPT;
 	case BPF_TRACE_ITER:
 		return BPF_PROG_TYPE_TRACING;
+	case BPF_SK_LOOKUP:
+		return BPF_PROG_TYPE_SK_LOOKUP;
 	default:
 		return BPF_PROG_TYPE_UNSPEC;
 	}
@@ -2950,6 +2957,7 @@ static int bpf_prog_query(const union bp
 	case BPF_LIRC_MODE2:
 		return lirc_prog_query(attr, uattr);
 	case BPF_FLOW_DISSECTOR:
+	case BPF_SK_LOOKUP:
 		return netns_bpf_prog_query(attr, uattr);
 	default:
 		return -EINVAL;
@@ -3888,6 +3896,7 @@ static int link_create(union bpf_attr *a
 		ret = tracing_bpf_link_attach(attr, prog);
 		break;
 	case BPF_PROG_TYPE_FLOW_DISSECTOR:
+	case BPF_PROG_TYPE_SK_LOOKUP:
 		ret = netns_bpf_link_create(attr, prog);
 		break;
 	default:
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3878,10 +3878,14 @@ static int check_func_arg(struct bpf_ver
 			}
 			meta->ref_obj_id = reg->ref_obj_id;
 		}
-	} else if (arg_type == ARG_PTR_TO_SOCKET) {
+	} else if (arg_type == ARG_PTR_TO_SOCKET ||
+		   arg_type == ARG_PTR_TO_SOCKET_OR_NULL) {
 		expected_type = PTR_TO_SOCKET;
-		if (type != expected_type)
-			goto err_type;
+		if (!(register_is_null(reg) &&
+		      arg_type == ARG_PTR_TO_SOCKET_OR_NULL)) {
+			if (type != expected_type)
+				goto err_type;
+		}
 	} else if (arg_type == ARG_PTR_TO_BTF_ID) {
 		expected_type = PTR_TO_BTF_ID;
 		if (type != expected_type)
@@ -7354,6 +7358,9 @@ static int check_return_code(struct bpf_
 			return -ENOTSUPP;
 		}
 		break;
+	case BPF_PROG_TYPE_SK_LOOKUP:
+		range = tnum_range(SK_DROP, SK_PASS);
+		break;
 	case BPF_PROG_TYPE_EXT:
 		/* freplace program can return anything as its return value
 		 * depends on the to-be-replaced kernel func or bpf program.
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -9266,6 +9266,186 @@ const struct bpf_verifier_ops sk_reusepo
 
 const struct bpf_prog_ops sk_reuseport_prog_ops = {
 };
+
+BPF_CALL_3(bpf_sk_lookup_assign, struct bpf_sk_lookup_kern *, ctx,
+	   struct sock *, sk, u64, flags)
+{
+	if (unlikely(flags & ~(BPF_SK_LOOKUP_F_REPLACE |
+			       BPF_SK_LOOKUP_F_NO_REUSEPORT)))
+		return -EINVAL;
+	if (unlikely(sk && sk_is_refcounted(sk)))
+		return -ESOCKTNOSUPPORT; /* reject non-RCU freed sockets */
+	if (unlikely(sk && sk->sk_state == TCP_ESTABLISHED))
+		return -ESOCKTNOSUPPORT; /* reject connected sockets */
+
+	/* Check if socket is suitable for packet L3/L4 protocol */
+	if (sk && sk->sk_protocol != ctx->protocol)
+		return -EPROTOTYPE;
+	if (sk && sk->sk_family != ctx->family &&
+	    (sk->sk_family == AF_INET || ipv6_only_sock(sk)))
+		return -EAFNOSUPPORT;
+
+	if (ctx->selected_sk && !(flags & BPF_SK_LOOKUP_F_REPLACE))
+		return -EEXIST;
+
+	/* Select socket as lookup result */
+	ctx->selected_sk = sk;
+	ctx->no_reuseport = flags & BPF_SK_LOOKUP_F_NO_REUSEPORT;
+	return 0;
+}
+
+static const struct bpf_func_proto bpf_sk_lookup_assign_proto = {
+	.func		= bpf_sk_lookup_assign,
+	.gpl_only	= false,
+	.ret_type	= RET_INTEGER,
+	.arg1_type	= ARG_PTR_TO_CTX,
+	.arg2_type	= ARG_PTR_TO_SOCKET_OR_NULL,
+	.arg3_type	= ARG_ANYTHING,
+};
+
+static const struct bpf_func_proto *
+sk_lookup_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
+{
+	switch (func_id) {
+	case BPF_FUNC_perf_event_output:
+		return &bpf_event_output_data_proto;
+	case BPF_FUNC_sk_assign:
+		return &bpf_sk_lookup_assign_proto;
+	case BPF_FUNC_sk_release:
+		return &bpf_sk_release_proto;
+	default:
+		return bpf_base_func_proto(func_id);
+	}
+}
+
+static bool sk_lookup_is_valid_access(int off, int size,
+				      enum bpf_access_type type,
+				      const struct bpf_prog *prog,
+				      struct bpf_insn_access_aux *info)
+{
+	if (off < 0 || off >= sizeof(struct bpf_sk_lookup))
+		return false;
+	if (off % size != 0)
+		return false;
+	if (type != BPF_READ)
+		return false;
+
+	switch (off) {
+	case offsetof(struct bpf_sk_lookup, sk):
+		info->reg_type = PTR_TO_SOCKET_OR_NULL;
+		return size == sizeof(__u64);
+
+	case bpf_ctx_range(struct bpf_sk_lookup, family):
+	case bpf_ctx_range(struct bpf_sk_lookup, protocol):
+	case bpf_ctx_range(struct bpf_sk_lookup, remote_ip4):
+	case bpf_ctx_range(struct bpf_sk_lookup, local_ip4):
+	case bpf_ctx_range_till(struct bpf_sk_lookup, remote_ip6[0], remote_ip6[3]):
+	case bpf_ctx_range_till(struct bpf_sk_lookup, local_ip6[0], local_ip6[3]):
+	case bpf_ctx_range(struct bpf_sk_lookup, remote_port):
+	case bpf_ctx_range(struct bpf_sk_lookup, local_port):
+		bpf_ctx_record_field_size(info, sizeof(__u32));
+		return bpf_ctx_narrow_access_ok(off, size, sizeof(__u32));
+
+	default:
+		return false;
+	}
+}
+
+static u32 sk_lookup_convert_ctx_access(enum bpf_access_type type,
+					const struct bpf_insn *si,
+					struct bpf_insn *insn_buf,
+					struct bpf_prog *prog,
+					u32 *target_size)
+{
+	struct bpf_insn *insn = insn_buf;
+
+	switch (si->off) {
+	case offsetof(struct bpf_sk_lookup, sk):
+		*insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, si->src_reg,
+				      offsetof(struct bpf_sk_lookup_kern, selected_sk));
+		break;
+
+	case offsetof(struct bpf_sk_lookup, family):
+		*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
+				      bpf_target_off(struct bpf_sk_lookup_kern,
+						     family, 2, target_size));
+		break;
+
+	case offsetof(struct bpf_sk_lookup, protocol):
+		*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
+				      bpf_target_off(struct bpf_sk_lookup_kern,
+						     protocol, 2, target_size));
+		break;
+
+	case offsetof(struct bpf_sk_lookup, remote_ip4):
+		*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg,
+				      bpf_target_off(struct bpf_sk_lookup_kern,
+						     v4.saddr, 4, target_size));
+		break;
+
+	case offsetof(struct bpf_sk_lookup, local_ip4):
+		*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg,
+				      bpf_target_off(struct bpf_sk_lookup_kern,
+						     v4.daddr, 4, target_size));
+		break;
+
+	case bpf_ctx_range_till(struct bpf_sk_lookup,
+				remote_ip6[0], remote_ip6[3]): {
+#if IS_ENABLED(CONFIG_IPV6)
+		int off = si->off;
+
+		off -= offsetof(struct bpf_sk_lookup, remote_ip6[0]);
+		off += bpf_target_off(struct in6_addr, s6_addr32[0], 4, target_size);
+		*insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, si->src_reg,
+				      offsetof(struct bpf_sk_lookup_kern, v6.saddr));
+		*insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1);
+		*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, off);
+#else
+		*insn++ = BPF_MOV32_IMM(si->dst_reg, 0);
+#endif
+		break;
+	}
+	case bpf_ctx_range_till(struct bpf_sk_lookup,
+				local_ip6[0], local_ip6[3]): {
+#if IS_ENABLED(CONFIG_IPV6)
+		int off = si->off;
+
+		off -= offsetof(struct bpf_sk_lookup, local_ip6[0]);
+		off += bpf_target_off(struct in6_addr, s6_addr32[0], 4, target_size);
+		*insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, si->src_reg,
+				      offsetof(struct bpf_sk_lookup_kern, v6.daddr));
+		*insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1);
+		*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, off);
+#else
+		*insn++ = BPF_MOV32_IMM(si->dst_reg, 0);
+#endif
+		break;
+	}
+	case offsetof(struct bpf_sk_lookup, remote_port):
+		*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
+				      bpf_target_off(struct bpf_sk_lookup_kern,
+						     sport, 2, target_size));
+		break;
+
+	case offsetof(struct bpf_sk_lookup, local_port):
+		*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
+				      bpf_target_off(struct bpf_sk_lookup_kern,
+						     dport, 2, target_size));
+		break;
+	}
+
+	return insn - insn_buf;
+}
+
+const struct bpf_prog_ops sk_lookup_prog_ops = {
+};
+
+const struct bpf_verifier_ops sk_lookup_verifier_ops = {
+	.get_func_proto		= sk_lookup_func_proto,
+	.is_valid_access	= sk_lookup_is_valid_access,
+	.convert_ctx_access	= sk_lookup_convert_ctx_access,
+};
+
 #endif /* CONFIG_INET */
 
 DEFINE_BPF_DISPATCHER(xdp)
--- a/scripts/bpf_helpers_doc.py
+++ b/scripts/bpf_helpers_doc.py
@@ -404,6 +404,7 @@ class PrinterHelpers(Printer):
 
     type_fwds = [
             'struct bpf_fib_lookup',
+            'struct bpf_sk_lookup',
             'struct bpf_perf_event_data',
             'struct bpf_perf_event_value',
             'struct bpf_pidns_info',
@@ -450,6 +451,7 @@ class PrinterHelpers(Printer):
             'struct bpf_perf_event_data',
             'struct bpf_perf_event_value',
             'struct bpf_pidns_info',
+            'struct bpf_sk_lookup',
             'struct bpf_sock',
             'struct bpf_sock_addr',
             'struct bpf_sock_ops',
@@ -487,6 +489,11 @@ class PrinterHelpers(Printer):
             'struct sk_msg_buff': 'struct sk_msg_md',
             'struct xdp_buff': 'struct xdp_md',
     }
+    # Helpers overloaded for different context types.
+    overloaded_helpers = [
+        'bpf_get_socket_cookie',
+        'bpf_sk_assign',
+    ]
 
     def print_header(self):
         header = '''\
@@ -543,7 +550,7 @@ class PrinterHelpers(Printer):
         for i, a in enumerate(proto['args']):
             t = a['type']
             n = a['name']
-            if proto['name'] == 'bpf_get_socket_cookie' and i == 0:
+            if proto['name'] in self.overloaded_helpers and i == 0:
                     t = 'void'
                     n = 'ctx'
             one_arg = '{}{}'.format(comma, self.map_type(t))