Blob Blame History Raw
From: Leon Romanovsky <leonro@nvidia.com>
Date: Sat, 14 Aug 2021 12:57:28 +0300
Subject: devlink: Count struct devlink consumers
Patch-mainline: v5.15-rc1
Git-commit: 437ebfd90a2567aab19dce47bafc81ebd8a63324
References: jsc#SLE-19253

The struct devlink itself is protected by internal lock and doesn't
need global lock during operation. That global lock is used to protect
addition/removal new devlink instances from the global list in use by
all devlink consumers in the system.

The future conversion of linked list to be xarray will allow us to
actually delete that lock, but first we need to count all struct devlink
users.

The reference counting provides us a way to ensure that no new user
space commands success to grab devlink instance which is going to be
destroyed makes it is safe to access it without lock.

Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 include/net/devlink.h |    2 
 net/core/devlink.c    |  205 +++++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 172 insertions(+), 35 deletions(-)

--- a/include/net/devlink.h
+++ b/include/net/devlink.h
@@ -56,6 +56,8 @@ struct devlink {
 			    */
 	u8 reload_failed:1,
 	   reload_enabled:1;
+	refcount_t refcount;
+	struct completion comp;
 	char priv[0] __aligned(NETDEV_ALIGN);
 };
 
--- a/net/core/devlink.c
+++ b/net/core/devlink.c
@@ -108,10 +108,22 @@ struct net *devlink_net(const struct dev
 }
 EXPORT_SYMBOL_GPL(devlink_net);
 
+static void devlink_put(struct devlink *devlink)
+{
+	if (refcount_dec_and_test(&devlink->refcount))
+		complete(&devlink->comp);
+}
+
+static bool __must_check devlink_try_get(struct devlink *devlink)
+{
+	return refcount_inc_not_zero(&devlink->refcount);
+}
+
 static struct devlink *devlink_get_from_attrs(struct net *net,
 					      struct nlattr **attrs)
 {
 	struct devlink *devlink;
+	bool found = false;
 	char *busname;
 	char *devname;
 
@@ -126,16 +138,16 @@ static struct devlink *devlink_get_from_
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (strcmp(devlink->dev->bus->name, busname) == 0 &&
 		    strcmp(dev_name(devlink->dev), devname) == 0 &&
-		    net_eq(devlink_net(devlink), net))
-			return devlink;
+		    net_eq(devlink_net(devlink), net)) {
+			found = true;
+			break;
+		}
 	}
 
-	return ERR_PTR(-ENODEV);
-}
+	if (!found || !devlink_try_get(devlink))
+		devlink = ERR_PTR(-ENODEV);
 
-static struct devlink *devlink_get_from_info(struct genl_info *info)
-{
-	return devlink_get_from_attrs(genl_info_net(info), info->attrs);
+	return devlink;
 }
 
 static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink,
@@ -486,7 +498,7 @@ static int devlink_nl_pre_doit(const str
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlink = devlink_get_from_info(info);
+	devlink = devlink_get_from_attrs(genl_info_net(info), info->attrs);
 	if (IS_ERR(devlink)) {
 		mutex_unlock(&devlink_mutex);
 		return PTR_ERR(devlink);
@@ -529,6 +541,7 @@ static int devlink_nl_pre_doit(const str
 unlock:
 	if (~ops->internal_flags & DEVLINK_NL_FLAG_NO_LOCK)
 		mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 	return err;
 }
@@ -541,6 +554,7 @@ static void devlink_nl_post_doit(const s
 	devlink = info->user_ptr[0];
 	if (~ops->internal_flags & DEVLINK_NL_FLAG_NO_LOCK)
 		mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 }
 
@@ -1078,8 +1092,12 @@ static int devlink_nl_cmd_rate_get_dumpi
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
 			enum devlink_command cmd = DEVLINK_CMD_RATE_NEW;
@@ -1094,11 +1112,14 @@ static int devlink_nl_cmd_rate_get_dumpi
 						   NLM_F_MULTI, NULL);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -1173,15 +1194,24 @@ static int devlink_nl_cmd_get_dumpit(str
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
+			devlink_put(devlink);
+			continue;
+		}
+
 		if (idx < start) {
 			idx++;
+			devlink_put(devlink);
 			continue;
 		}
+
 		err = devlink_nl_fill(msg, devlink, DEVLINK_CMD_NEW,
 				      NETLINK_CB(cb->skb).portid,
 				      cb->nlh->nlmsg_seq, NLM_F_MULTI);
+		devlink_put(devlink);
 		if (err)
 			goto out;
 		idx++;
@@ -1226,8 +1256,12 @@ static int devlink_nl_cmd_port_get_dumpi
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_port, &devlink->port_list, list) {
 			if (idx < start) {
@@ -1241,11 +1275,14 @@ static int devlink_nl_cmd_port_get_dumpi
 						   NLM_F_MULTI, cb->extack);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -1884,8 +1921,12 @@ static int devlink_nl_cmd_sb_get_dumpit(
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			if (idx < start) {
@@ -1899,11 +1940,14 @@ static int devlink_nl_cmd_sb_get_dumpit(
 						 NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -2028,9 +2072,13 @@ static int devlink_nl_cmd_sb_pool_get_du
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
+		if (!devlink_try_get(devlink))
+			continue;
+
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_pool_get)
-			continue;
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			err = __sb_pool_get_dumpit(msg, start, &idx, devlink,
@@ -2041,10 +2089,13 @@ static int devlink_nl_cmd_sb_pool_get_du
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -2241,9 +2292,13 @@ static int devlink_nl_cmd_sb_port_pool_g
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
+		if (!devlink_try_get(devlink))
+			continue;
+
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_port_pool_get)
-			continue;
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			err = __sb_port_pool_get_dumpit(msg, start, &idx,
@@ -2254,10 +2309,13 @@ static int devlink_nl_cmd_sb_port_pool_g
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -2482,9 +2540,12 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpi
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
+		if (!devlink_try_get(devlink))
+			continue;
+
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_tc_pool_bind_get)
-			continue;
+			goto retry;
 
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
@@ -2497,10 +2558,13 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpi
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -4552,8 +4616,12 @@ static int devlink_nl_cmd_param_get_dump
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(param_item, &devlink->param_list, list) {
 			if (idx < start) {
@@ -4569,11 +4637,14 @@ static int devlink_nl_cmd_param_get_dump
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -4820,8 +4891,12 @@ static int devlink_nl_cmd_port_param_get
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_port, &devlink->port_list, list) {
 			list_for_each_entry(param_item,
@@ -4841,12 +4916,15 @@ static int devlink_nl_cmd_port_param_get
 					err = 0;
 				} else if (err) {
 					mutex_unlock(&devlink->lock);
+					devlink_put(devlink);
 					goto out;
 				}
 				idx++;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -5385,14 +5463,20 @@ static int devlink_nl_cmd_region_get_dum
 	struct devlink *devlink;
 	int start = cb->args[0];
 	int idx = 0;
-	int err;
+	int err = 0;
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink,
 							       &idx, start);
+retry:
+		devlink_put(devlink);
 		if (err)
 			goto out;
 	}
@@ -5755,6 +5839,7 @@ static int devlink_nl_cmd_region_read_du
 	nla_nest_end(skb, chunks_attr);
 	genlmsg_end(skb, hdr);
 	mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 
 	return skb->len;
@@ -5763,6 +5848,7 @@ nla_put_failure:
 	genlmsg_cancel(skb, hdr);
 out_unlock:
 	mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 out_dev:
 	mutex_unlock(&devlink_mutex);
 	return err;
@@ -5914,17 +6000,14 @@ static int devlink_nl_cmd_info_get_dumpi
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
-		if (idx < start) {
-			idx++;
-			continue;
-		}
 
-		if (!devlink->ops->info_get) {
-			idx++;
-			continue;
-		}
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
+		if (idx < start || !devlink->ops->info_get)
+			goto inc;
 
 		mutex_lock(&devlink->lock);
 		err = devlink_nl_info_fill(msg, devlink, DEVLINK_CMD_INFO_GET,
@@ -5934,9 +6017,14 @@ static int devlink_nl_cmd_info_get_dumpi
 		mutex_unlock(&devlink->lock);
 		if (err == -EOPNOTSUPP)
 			err = 0;
-		else if (err)
+		else if (err) {
+			devlink_put(devlink);
 			break;
+		}
+inc:
 		idx++;
+retry:
+		devlink_put(devlink);
 	}
 	mutex_unlock(&devlink_mutex);
 
@@ -7021,6 +7109,7 @@ devlink_health_reporter_get_from_cb(stru
 		goto unlock;
 
 	reporter = devlink_health_reporter_get_from_attrs(devlink, attrs);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 	return reporter;
 unlock:
@@ -7092,8 +7181,12 @@ devlink_nl_cmd_health_reporter_get_dumpi
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry_rep;
+
 		mutex_lock(&devlink->reporters_lock);
 		list_for_each_entry(reporter, &devlink->reporter_list,
 				    list) {
@@ -7107,16 +7200,23 @@ devlink_nl_cmd_health_reporter_get_dumpi
 				NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->reporters_lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->reporters_lock);
+retry_rep:
+		devlink_put(devlink);
 	}
 
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry_port;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(port, &devlink->port_list, list) {
 			mutex_lock(&port->reporters_lock);
@@ -7133,6 +7233,7 @@ devlink_nl_cmd_health_reporter_get_dumpi
 				if (err) {
 					mutex_unlock(&port->reporters_lock);
 					mutex_unlock(&devlink->lock);
+					devlink_put(devlink);
 					goto out;
 				}
 				idx++;
@@ -7140,6 +7241,8 @@ devlink_nl_cmd_health_reporter_get_dumpi
 			mutex_unlock(&port->reporters_lock);
 		}
 		mutex_unlock(&devlink->lock);
+retry_port:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -7673,8 +7776,12 @@ static int devlink_nl_cmd_trap_get_dumpi
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(trap_item, &devlink->trap_list, list) {
 			if (idx < start) {
@@ -7688,11 +7795,14 @@ static int devlink_nl_cmd_trap_get_dumpi
 						   NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -7892,8 +8002,12 @@ static int devlink_nl_cmd_trap_group_get
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(group_item, &devlink->trap_group_list,
 				    list) {
@@ -7908,11 +8022,14 @@ static int devlink_nl_cmd_trap_group_get
 							 NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -8198,8 +8315,12 @@ static int devlink_nl_cmd_trap_policer_g
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+			goto retry;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(policer_item, &devlink->trap_policer_list,
 				    list) {
@@ -8214,11 +8335,14 @@ static int devlink_nl_cmd_trap_policer_g
 							   NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+retry:
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -8801,6 +8925,9 @@ struct devlink *devlink_alloc_ns(const s
 	INIT_LIST_HEAD(&devlink->trap_policer_list);
 	mutex_init(&devlink->lock);
 	mutex_init(&devlink->reporters_lock);
+	refcount_set(&devlink->refcount, 1);
+	init_completion(&devlink->comp);
+
 	return devlink;
 }
 EXPORT_SYMBOL_GPL(devlink_alloc_ns);
@@ -8827,6 +8954,9 @@ EXPORT_SYMBOL_GPL(devlink_register);
  */
 void devlink_unregister(struct devlink *devlink)
 {
+	devlink_put(devlink);
+	wait_for_completion(&devlink->comp);
+
 	mutex_lock(&devlink_mutex);
 	WARN_ON(devlink_reload_supported(devlink->ops) &&
 		devlink->reload_enabled);
@@ -11374,9 +11504,12 @@ static void __net_exit devlink_pernet_pr
 	 */
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), net))
+		if (!devlink_try_get(devlink))
 			continue;
 
+		if (!net_eq(devlink_net(devlink), net))
+			goto retry;
+
 		WARN_ON(!devlink_reload_supported(devlink->ops));
 		err = devlink_reload(devlink, &init_net,
 				     DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
@@ -11384,6 +11517,8 @@ static void __net_exit devlink_pernet_pr
 				     &actions_performed, NULL);
 		if (err && err != -EOPNOTSUPP)
 			pr_warn("Failed to reload devlink instance into init_net\n");
+retry:
+		devlink_put(devlink);
 	}
 	mutex_unlock(&devlink_mutex);
 }