Blob Blame History Raw
From: Jiri Pirko <jiri@nvidia.com>
Date: Mon, 25 Jul 2022 10:29:15 +0200
Subject: net: devlink: move net check into
 devlinks_xa_for_each_registered_get()
Patch-mainline: v6.0-rc1
Git-commit: 294c4f57cfe3303ee2f050d1728c76a401e573a7
References: jsc#PED-1549

Benefit from having devlinks iterator helper
devlinks_xa_for_each_registered_get() and move the net pointer
check inside.

Suggested-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Jiri Pirko <jiri@nvidia.com>
Reviewed-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 net/core/devlink.c |  135 +++++++++++++++--------------------------------------
 1 file changed, 39 insertions(+), 96 deletions(-)

--- a/net/core/devlink.c
+++ b/net/core/devlink.c
@@ -290,7 +290,7 @@ void devl_unlock(struct devlink *devlink
 EXPORT_SYMBOL_GPL(devl_unlock);
 
 static struct devlink *
-devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter,
+devlinks_xa_find_get(struct net *net, unsigned long *indexp, xa_mark_t filter,
 		     void * (*xa_find_fn)(struct xarray *, unsigned long *,
 					  unsigned long, xa_mark_t))
 {
@@ -305,33 +305,40 @@ retry:
 	xa_find_fn = xa_find_after;
 	if (!devlink_try_get(devlink))
 		goto retry;
+	if (!net_eq(devlink_net(devlink), net)) {
+		devlink_put(devlink);
+		goto retry;
+	}
 unlock:
 	rcu_read_unlock();
 	return devlink;
 }
 
-static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp,
+static struct devlink *devlinks_xa_find_get_first(struct net *net,
+						  unsigned long *indexp,
 						  xa_mark_t filter)
 {
-	return devlinks_xa_find_get(indexp, filter, xa_find);
+	return devlinks_xa_find_get(net, indexp, filter, xa_find);
 }
 
-static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp,
+static struct devlink *devlinks_xa_find_get_next(struct net *net,
+						 unsigned long *indexp,
 						 xa_mark_t filter)
 {
-	return devlinks_xa_find_get(indexp, filter, xa_find_after);
+	return devlinks_xa_find_get(net, indexp, filter, xa_find_after);
 }
 
 /* Iterate over devlink pointers which were possible to get reference to.
  * devlink_put() needs to be called for each iterated devlink pointer
  * in loop body in order to release the reference.
  */
-#define devlinks_xa_for_each_get(index, devlink, filter)			\
-	for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter);	\
-	     devlink; devlink = devlinks_xa_find_get_next(&index, filter))
+#define devlinks_xa_for_each_get(net, index, devlink, filter)			\
+	for (index = 0,								\
+	     devlink = devlinks_xa_find_get_first(net, &index, filter);		\
+	     devlink; devlink = devlinks_xa_find_get_next(net, &index, filter))
 
-#define devlinks_xa_for_each_registered_get(index, devlink)			\
-	devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED)
+#define devlinks_xa_for_each_registered_get(net, index, devlink)		\
+	devlinks_xa_for_each_get(net, index, devlink, DEVLINK_REGISTERED)
 
 static struct devlink *devlink_get_from_attrs(struct net *net,
 					      struct nlattr **attrs)
@@ -347,10 +354,9 @@ static struct devlink *devlink_get_from_
 	busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
 	devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
 
-	devlinks_xa_for_each_registered_get(index, devlink) {
+	devlinks_xa_for_each_registered_get(net, index, devlink) {
 		if (strcmp(devlink->dev->bus->name, busname) == 0 &&
-		    strcmp(dev_name(devlink->dev), devname) == 0 &&
-		    net_eq(devlink_net(devlink), net))
+		    strcmp(dev_name(devlink->dev), devname) == 0)
 			return devlink;
 		devlink_put(devlink);
 	}
@@ -1377,10 +1383,7 @@ static int devlink_nl_cmd_rate_get_dumpi
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
 			enum devlink_command cmd = DEVLINK_CMD_RATE_NEW;
@@ -1401,7 +1404,6 @@ static int devlink_nl_cmd_rate_get_dumpi
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -1477,12 +1479,7 @@ static int devlink_nl_cmd_get_dumpit(str
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
-			devlink_put(devlink);
-			continue;
-		}
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		if (idx < start) {
 			idx++;
 			devlink_put(devlink);
@@ -1537,10 +1534,7 @@ static int devlink_nl_cmd_port_get_dumpi
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(devlink_port, &devlink->port_list, list) {
 			if (idx < start) {
@@ -1560,7 +1554,6 @@ static int devlink_nl_cmd_port_get_dumpi
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -2270,10 +2263,7 @@ static int devlink_nl_cmd_linecard_get_d
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		mutex_lock(&devlink->linecards_lock);
 		list_for_each_entry(linecard, &devlink->linecard_list, list) {
 			if (idx < start) {
@@ -2296,7 +2286,6 @@ static int devlink_nl_cmd_linecard_get_d
 			idx++;
 		}
 		mutex_unlock(&devlink->linecards_lock);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -2539,10 +2528,7 @@ static int devlink_nl_cmd_sb_get_dumpit(
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			if (idx < start) {
@@ -2562,7 +2548,6 @@ static int devlink_nl_cmd_sb_get_dumpit(
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -2688,7 +2673,7 @@ static int devlink_nl_cmd_sb_pool_get_du
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_pool_get)
 			goto retry;
@@ -2906,9 +2891,8 @@ static int devlink_nl_cmd_sb_port_pool_g
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
-		    !devlink->ops->sb_port_pool_get)
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+		if (!devlink->ops->sb_port_pool_get)
 			goto retry;
 
 		devl_lock(devlink);
@@ -3152,9 +3136,8 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpi
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
-		    !devlink->ops->sb_tc_pool_bind_get)
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+		if (!devlink->ops->sb_tc_pool_bind_get)
 			goto retry;
 
 		devl_lock(devlink);
@@ -5236,10 +5219,7 @@ static int devlink_nl_cmd_param_get_dump
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(param_item, &devlink->param_list, list) {
 			if (idx < start) {
@@ -5261,7 +5241,6 @@ static int devlink_nl_cmd_param_get_dump
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -5468,10 +5447,7 @@ static int devlink_nl_cmd_port_param_get
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(devlink_port, &devlink->port_list, list) {
 			list_for_each_entry(param_item,
@@ -5498,7 +5474,6 @@ static int devlink_nl_cmd_port_param_get
 			}
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -6049,13 +6024,9 @@ static int devlink_nl_cmd_region_get_dum
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink,
 							       &idx, start);
-retry:
 		devlink_put(devlink);
 		if (err)
 			goto out;
@@ -6580,10 +6551,7 @@ static int devlink_nl_cmd_info_get_dumpi
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		if (idx < start || !devlink->ops->info_get)
 			goto inc;
 
@@ -6601,7 +6569,6 @@ static int devlink_nl_cmd_info_get_dumpi
 		}
 inc:
 		idx++;
-retry:
 		devlink_put(devlink);
 	}
 	mutex_unlock(&devlink_mutex);
@@ -7757,10 +7724,7 @@ devlink_nl_cmd_health_reporter_get_dumpi
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry_rep;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		mutex_lock(&devlink->reporters_lock);
 		list_for_each_entry(reporter, &devlink->reporter_list,
 				    list) {
@@ -7780,14 +7744,10 @@ devlink_nl_cmd_health_reporter_get_dumpi
 			idx++;
 		}
 		mutex_unlock(&devlink->reporters_lock);
-retry_rep:
 		devlink_put(devlink);
 	}
 
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry_port;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(port, &devlink->port_list, list) {
 			mutex_lock(&port->reporters_lock);
@@ -7812,7 +7772,6 @@ retry_rep:
 			mutex_unlock(&port->reporters_lock);
 		}
 		devl_unlock(devlink);
-retry_port:
 		devlink_put(devlink);
 	}
 out:
@@ -8351,10 +8310,7 @@ static int devlink_nl_cmd_trap_get_dumpi
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(trap_item, &devlink->trap_list, list) {
 			if (idx < start) {
@@ -8374,7 +8330,6 @@ static int devlink_nl_cmd_trap_get_dumpi
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -8575,10 +8530,7 @@ static int devlink_nl_cmd_trap_group_get
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(group_item, &devlink->trap_group_list,
 				    list) {
@@ -8599,7 +8551,6 @@ static int devlink_nl_cmd_trap_group_get
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -8886,10 +8837,7 @@ static int devlink_nl_cmd_trap_policer_g
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
 		devl_lock(devlink);
 		list_for_each_entry(policer_item, &devlink->trap_policer_list,
 				    list) {
@@ -8910,7 +8858,6 @@ static int devlink_nl_cmd_trap_policer_g
 			idx++;
 		}
 		devl_unlock(devlink);
-retry:
 		devlink_put(devlink);
 	}
 out:
@@ -12375,10 +12322,7 @@ static void __net_exit devlink_pernet_pr
 	 * all devlink instances from this namespace into init_net.
 	 */
 	mutex_lock(&devlink_mutex);
-	devlinks_xa_for_each_registered_get(index, devlink) {
-		if (!net_eq(devlink_net(devlink), net))
-			goto retry;
-
+	devlinks_xa_for_each_registered_get(net, index, devlink) {
 		WARN_ON(!(devlink->features & DEVLINK_F_RELOAD));
 		err = devlink_reload(devlink, &init_net,
 				     DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
@@ -12386,7 +12330,6 @@ static void __net_exit devlink_pernet_pr
 				     &actions_performed, NULL);
 		if (err && err != -EOPNOTSUPP)
 			pr_warn("Failed to reload devlink instance into init_net\n");
-retry:
 		devlink_put(devlink);
 	}
 	mutex_unlock(&devlink_mutex);