Blob Blame History Raw
From: Yuval Mintz <yuvalm@mellanox.com>
Date: Wed, 28 Feb 2018 23:29:30 +0200
Subject: ip6mr: Make mroute_sk rcu-based
Patch-mainline: v4.17-rc1
Git-commit: 8571ab479a6e1ef46ead5ebee567e128a422767c
References: bsc#1112374

In ipmr the mr_table socket is handled under RCU. Introduce the same
for ip6mr.

Signed-off-by: Yuval Mintz <yuvalm@mellanox.com>
Acked-by: Nikolay Aleksandrov <nikolay@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 include/linux/mroute6.h |    6 +++---
 net/ipv6/ip6_output.c   |    2 +-
 net/ipv6/ip6mr.c        |   45 +++++++++++++++++++++++++++------------------
 3 files changed, 31 insertions(+), 22 deletions(-)

--- a/include/linux/mroute6.h
+++ b/include/linux/mroute6.h
@@ -110,12 +110,12 @@ extern int ip6mr_get_route(struct net *n
 			   struct rtmsg *rtm, u32 portid);
 
 #ifdef CONFIG_IPV6_MROUTE
-extern struct sock *mroute6_socket(struct net *net, struct sk_buff *skb);
+bool mroute6_is_socket(struct net *net, struct sk_buff *skb);
 extern int ip6mr_sk_done(struct sock *sk);
 #else
-static inline struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
+static inline bool mroute6_is_socket(struct net *net, struct sk_buff *skb)
 {
-	return NULL;
+	return false;
 }
 static inline int ip6mr_sk_done(struct sock *sk)
 {
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -74,7 +74,7 @@ static int ip6_finish_output2(struct net
 		struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
 
 		if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) &&
-		    ((mroute6_socket(net, skb) &&
+		    ((mroute6_is_socket(net, skb) &&
 		     !(IP6CB(skb)->flags & IP6SKB_FORWARDED)) ||
 		     ipv6_chk_mcast_addr(dev, &ipv6_hdr(skb)->daddr,
 					 &ipv6_hdr(skb)->saddr))) {
--- a/net/ipv6/ip6mr.c
+++ b/net/ipv6/ip6mr.c
@@ -58,7 +58,7 @@ struct mr6_table {
 	struct list_head	list;
 	possible_net_t		net;
 	u32			id;
-	struct sock		*mroute6_sk;
+	struct sock __rcu	*mroute6_sk;
 	struct timer_list	ipmr_expire_timer;
 	struct list_head	mfc6_unres_queue;
 	struct list_head	mfc6_cache_array[MFC6_LINES];
@@ -1124,6 +1124,7 @@ static void ip6mr_cache_resolve(struct n
 static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt,
 			      mifi_t mifi, int assert)
 {
+	struct sock *mroute6_sk;
 	struct sk_buff *skb;
 	struct mrt6msg *msg;
 	int ret;
@@ -1193,15 +1194,17 @@ static int ip6mr_cache_report(struct mr6
 	skb->ip_summed = CHECKSUM_UNNECESSARY;
 	}
 
-	if (!mrt->mroute6_sk) {
+	rcu_read_lock();
+	mroute6_sk = rcu_dereference(mrt->mroute6_sk);
+	if (!mroute6_sk) {
+		rcu_read_unlock();
 		kfree_skb(skb);
 		return -EINVAL;
 	}
 
-	/*
-	 *	Deliver to user space multicast routing algorithms
-	 */
-	ret = sock_queue_rcv_skb(mrt->mroute6_sk, skb);
+	/* Deliver to user space multicast routing algorithms */
+	ret = sock_queue_rcv_skb(mroute6_sk, skb);
+	rcu_read_unlock();
 	if (ret < 0) {
 		net_warn_ratelimited("mroute6: pending queue full, dropping entries\n");
 		kfree_skb(skb);
@@ -1581,11 +1584,11 @@ static int ip6mr_sk_init(struct mr6_tabl
 
 	rtnl_lock();
 	write_lock_bh(&mrt_lock);
-	if (likely(mrt->mroute6_sk == NULL)) {
-		mrt->mroute6_sk = sk;
-		net->ipv6.devconf_all->mc_forwarding++;
-	} else {
+	if (rtnl_dereference(mrt->mroute6_sk)) {
 		err = -EADDRINUSE;
+	} else {
+		rcu_assign_pointer(mrt->mroute6_sk, sk);
+		net->ipv6.devconf_all->mc_forwarding++;
 	}
 	write_unlock_bh(&mrt_lock);
 
@@ -1607,9 +1610,9 @@ int ip6mr_sk_done(struct sock *sk)
 
 	rtnl_lock();
 	ip6mr_for_each_table(mrt, net) {
-		if (sk == mrt->mroute6_sk) {
+		if (sk == rtnl_dereference(mrt->mroute6_sk)) {
 			write_lock_bh(&mrt_lock);
-			mrt->mroute6_sk = NULL;
+			RCU_INIT_POINTER(mrt->mroute6_sk, NULL);
 			net->ipv6.devconf_all->mc_forwarding--;
 			write_unlock_bh(&mrt_lock);
 			inet6_netconf_notify_devconf(net, RTM_NEWNETCONF,
@@ -1623,11 +1626,12 @@ int ip6mr_sk_done(struct sock *sk)
 		}
 	}
 	rtnl_unlock();
+	synchronize_rcu();
 
 	return err;
 }
 
-struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
+bool mroute6_is_socket(struct net *net, struct sk_buff *skb)
 {
 	struct mr6_table *mrt;
 	struct flowi6 fl6 = {
@@ -1639,8 +1643,9 @@ struct sock *mroute6_socket(struct net *
 	if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
 		return NULL;
 
-	return mrt->mroute6_sk;
+	return rcu_access_pointer(mrt->mroute6_sk);
 }
+EXPORT_SYMBOL(mroute6_is_socket);
 
 /*
  *	Socket options and virtual interface manipulation. The whole
@@ -1667,7 +1672,8 @@ int ip6_mroute_setsockopt(struct sock *s
 		return -ENOENT;
 
 	if (optname != MRT6_INIT) {
-		if (sk != mrt->mroute6_sk && !ns_capable(net->user_ns, CAP_NET_ADMIN))
+		if (sk != rcu_access_pointer(mrt->mroute6_sk) &&
+		    !ns_capable(net->user_ns, CAP_NET_ADMIN))
 			return -EACCES;
 	}
 
@@ -1689,7 +1695,8 @@ int ip6_mroute_setsockopt(struct sock *s
 		if (vif.mif6c_mifi >= MAXMIFS)
 			return -ENFILE;
 		rtnl_lock();
-		ret = mif6_add(net, mrt, &vif, sk == mrt->mroute6_sk);
+		ret = mif6_add(net, mrt, &vif,
+			       sk == rtnl_dereference(mrt->mroute6_sk));
 		rtnl_unlock();
 		return ret;
 
@@ -1723,7 +1730,9 @@ int ip6_mroute_setsockopt(struct sock *s
 			ret = ip6mr_mfc_delete(mrt, &mfc, parent);
 		else
 			ret = ip6mr_mfc_add(net, mrt, &mfc,
-					    sk == mrt->mroute6_sk, parent);
+					    sk ==
+					    rtnl_dereference(mrt->mroute6_sk),
+					    parent);
 		rtnl_unlock();
 		return ret;
 
@@ -1775,7 +1784,7 @@ int ip6_mroute_setsockopt(struct sock *s
 		/* "pim6reg%u" should not exceed 16 bytes (IFNAMSIZ) */
 		if (v != RT_TABLE_DEFAULT && v >= 100000000)
 			return -EINVAL;
-		if (sk == mrt->mroute6_sk)
+		if (sk == rcu_access_pointer(mrt->mroute6_sk))
 			return -EBUSY;
 
 		rtnl_lock();