Blob Blame History Raw
From: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Date: Thu, 28 Oct 2021 12:04:57 +0530
Subject: libbpf: Ensure that BPF syscall fds are never 0, 1, or 2
Patch-mainline: v5.16-rc1
Git-commit: 549a63238603103fa33cecd49487cf6c0f52e503
References: jsc#PED-1377

Add a simple wrapper for passing an fd and getting a new one >= 3 if it
is one of 0, 1, or 2. There are two primary reasons to make this change:
First, libbpf relies on the assumption a certain BPF fd is never 0 (e.g.
most recently noticed in [0]). Second, Alexei pointed out in [1] that
some environments reset stdin, stdout, and stderr if they notice an
invalid fd at these numbers. To protect against both these cases, switch
all internal BPF syscall wrappers in libbpf to always return an fd >= 3.
We only need to modify the syscall wrappers and not other code that
assumes a valid fd by doing >= 0, to avoid pointless churn, and because
it is still a valid assumption. The cost paid is two additional syscalls
if fd is in range [0, 2].

  [0]: e31eec77e4ab ("bpf: selftests: Fix fd cleanup in get_branch_snapshot")
  [1]: https://lore.kernel.org/bpf/CAADnVQKVKY8o_3aU8Gzke443+uHa-eGoM0h7W4srChMXU1S4Bg@mail.gmail.com

Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Song Liu <songliubraving@fb.com>
Acked-by: Andrii Nakryiko <andrii@kernel.org>
Link: https://lore.kernel.org/bpf/20211028063501.2239335-5-memxor@gmail.com
Acked-by: Shung-Hsi Yu <shung-hsi.yu@suse.com>
---
 tools/lib/bpf/bpf.c             |   35 ++++++++++++++++++++++-------------
 tools/lib/bpf/libbpf_internal.h |   24 ++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 13 deletions(-)

--- a/tools/lib/bpf/bpf.c
+++ b/tools/lib/bpf/bpf.c
@@ -65,13 +65,22 @@ static inline int sys_bpf(enum bpf_cmd c
 	return syscall(__NR_bpf, cmd, attr, size);
 }
 
+static inline int sys_bpf_fd(enum bpf_cmd cmd, union bpf_attr *attr,
+			     unsigned int size)
+{
+	int fd;
+
+	fd = sys_bpf(cmd, attr, size);
+	return ensure_good_fd(fd);
+}
+
 static inline int sys_bpf_prog_load(union bpf_attr *attr, unsigned int size)
 {
 	int retries = 5;
 	int fd;
 
 	do {
-		fd = sys_bpf(BPF_PROG_LOAD, attr, size);
+		fd = sys_bpf_fd(BPF_PROG_LOAD, attr, size);
 	} while (fd < 0 && errno == EAGAIN && retries-- > 0);
 
 	return fd;
@@ -104,7 +113,7 @@ int libbpf__bpf_create_map_xattr(const s
 		attr.inner_map_fd = create_attr->inner_map_fd;
 	attr.map_extra = create_attr->map_extra;
 
-	fd = sys_bpf(BPF_MAP_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_MAP_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -206,7 +215,7 @@ int bpf_create_map_in_map_node(enum bpf_
 		attr.numa_node = node;
 	}
 
-	fd = sys_bpf(BPF_MAP_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_MAP_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -634,7 +643,7 @@ int bpf_obj_get(const char *pathname)
 	memset(&attr, 0, sizeof(attr));
 	attr.pathname = ptr_to_u64((void *)pathname);
 
-	fd = sys_bpf(BPF_OBJ_GET, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_OBJ_GET, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -745,7 +754,7 @@ int bpf_link_create(int prog_fd, int tar
 		break;
 	}
 proceed:
-	fd = sys_bpf(BPF_LINK_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_LINK_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -788,7 +797,7 @@ int bpf_iter_create(int link_fd)
 	memset(&attr, 0, sizeof(attr));
 	attr.iter_create.link_fd = link_fd;
 
-	fd = sys_bpf(BPF_ITER_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_ITER_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -946,7 +955,7 @@ int bpf_prog_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.prog_id = id;
 
-	fd = sys_bpf(BPF_PROG_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_PROG_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -958,7 +967,7 @@ int bpf_map_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.map_id = id;
 
-	fd = sys_bpf(BPF_MAP_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_MAP_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -970,7 +979,7 @@ int bpf_btf_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.btf_id = id;
 
-	fd = sys_bpf(BPF_BTF_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_BTF_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -982,7 +991,7 @@ int bpf_link_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.link_id = id;
 
-	fd = sys_bpf(BPF_LINK_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_LINK_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -1013,7 +1022,7 @@ int bpf_raw_tracepoint_open(const char *
 	attr.raw_tracepoint.name = ptr_to_u64(name);
 	attr.raw_tracepoint.prog_fd = prog_fd;
 
-	fd = sys_bpf(BPF_RAW_TRACEPOINT_OPEN, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_RAW_TRACEPOINT_OPEN, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -1033,7 +1042,7 @@ retry:
 		attr.btf_log_buf = ptr_to_u64(log_buf);
 	}
 
-	fd = sys_bpf(BPF_BTF_LOAD, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_BTF_LOAD, &attr, sizeof(attr));
 
 	if (fd < 0 && !do_log && log_buf && log_buf_size) {
 		do_log = true;
@@ -1075,7 +1084,7 @@ int bpf_enable_stats(enum bpf_stats_type
 	memset(&attr, 0, sizeof(attr));
 	attr.enable_stats.type = type;
 
-	fd = sys_bpf(BPF_ENABLE_STATS, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_ENABLE_STATS, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
--- a/tools/lib/bpf/libbpf_internal.h
+++ b/tools/lib/bpf/libbpf_internal.h
@@ -13,6 +13,8 @@
 #include <limits.h>
 #include <errno.h>
 #include <linux/err.h>
+#include <fcntl.h>
+#include <unistd.h>
 #include "libbpf_legacy.h"
 #include "relo_core.h"
 
@@ -491,4 +493,26 @@ static inline bool is_ldimm64_insn(struc
 	return insn->code == (BPF_LD | BPF_IMM | BPF_DW);
 }
 
+/* if fd is stdin, stdout, or stderr, dup to a fd greater than 2
+ * Takes ownership of the fd passed in, and closes it if calling
+ * fcntl(fd, F_DUPFD_CLOEXEC, 3).
+ */
+static inline int ensure_good_fd(int fd)
+{
+	int old_fd = fd, saved_errno;
+
+	if (fd < 0)
+		return fd;
+	if (fd < 3) {
+		fd = fcntl(fd, F_DUPFD_CLOEXEC, 3);
+		saved_errno = errno;
+		close(old_fd);
+		if (fd < 0) {
+			pr_warn("failed to dup FD %d to FD > 2: %d\n", old_fd, -saved_errno);
+			errno = saved_errno;
+		}
+	}
+	return fd;
+}
+
 #endif /* __LIBBPF_LIBBPF_INTERNAL_H */