Blob Blame History Raw
From: Vincent Whitchurch <vincent.whitchurch@axis.com>
Date: Wed, 1 Jun 2022 00:03:18 -0500
Subject: [PATCH] cifs: fix potential deadlock in direct reclaim
Git-commit: cc391b694ff085f62f133e6b8f864d43a8e69dfd
References: bsc#1193629
Patch-mainline: v5.19-rc1

The srv_mutex is used during writeback so cifs should ensure that
allocations done when that mutex is held are done with GFP_NOFS, to
avoid having direct reclaim ending up waiting for the same mutex and
causing a deadlock.  This is detected by lockdep with the splat below:

 ======================================================
 WARNING: possible circular locking dependency detected
 5.18.0 #70 Not tainted
 ------------------------------------------------------
 kswapd0/49 is trying to acquire lock:
 ffff8880195782e0 (&tcp_ses->srv_mutex){+.+.}-{3:3}, at: compound_send_recv

 but task is already holding lock:
 ffffffffa98e66c0 (fs_reclaim){+.+.}-{0:0}, at: balance_pgdat

 which lock already depends on the new lock.

 the existing dependency chain (in reverse order) is:

 -> #1 (fs_reclaim){+.+.}-{0:0}:
        fs_reclaim_acquire
        kmem_cache_alloc_trace
        __request_module
        crypto_alg_mod_lookup
        crypto_alloc_tfm_node
        crypto_alloc_shash
        cifs_alloc_hash
        smb311_crypto_shash_allocate
        smb311_update_preauth_hash
        compound_send_recv
        cifs_send_recv
        SMB2_negotiate
        smb2_negotiate
        cifs_negotiate_protocol
        cifs_get_smb_ses
        cifs_mount
        cifs_smb3_do_mount
        smb3_get_tree
        vfs_get_tree
        path_mount
        __x64_sys_mount
        do_syscall_64
        entry_SYSCALL_64_after_hwframe

 -> #0 (&tcp_ses->srv_mutex){+.+.}-{3:3}:
        __lock_acquire
        lock_acquire
        __mutex_lock
        mutex_lock_nested
        compound_send_recv
        cifs_send_recv
        SMB2_write
        smb2_sync_write
        cifs_write
        cifs_writepage_locked
        cifs_writepage
        shrink_page_list
        shrink_lruvec
        shrink_node
        balance_pgdat
        kswapd
        kthread
        ret_from_fork

 other info that might help us debug this:

  Possible unsafe locking scenario:

        CPU0                    CPU1
        ----                    ----
   lock(fs_reclaim);
                                lock(&tcp_ses->srv_mutex);
                                lock(fs_reclaim);
   lock(&tcp_ses->srv_mutex);

  *** DEADLOCK ***

 1 lock held by kswapd0/49:
  #0: ffffffffa98e66c0 (fs_reclaim){+.+.}-{0:0}, at: balance_pgdat

 stack backtrace:
 CPU: 2 PID: 49 Comm: kswapd0 Not tainted 5.18.0 #70
 Call Trace:
  <TASK>
  dump_stack_lvl
  dump_stack
  print_circular_bug.cold
  check_noncircular
  __lock_acquire
  lock_acquire
  __mutex_lock
  mutex_lock_nested
  compound_send_recv
  cifs_send_recv
  SMB2_write
  smb2_sync_write
  cifs_write
  cifs_writepage_locked
  cifs_writepage
  shrink_page_list
  shrink_lruvec
  shrink_node
  balance_pgdat
  kswapd
  kthread
  ret_from_fork
  </TASK>

Fix this by using the memalloc_nofs_save/restore APIs around the places
where the srv_mutex is held.  Do this in a wrapper function for the
lock/unlock of the srv_mutex, and rename the srv_mutex to avoid missing
call sites in the conversion.

Note that there is another lockdep warning involving internal crypto
locks, which was masked by this problem and is visible after this fix,
see the discussion in this thread:

 https://lore.kernel.org/all/20220523123755.GA13668@axis.com/

Link: https://lore.kernel.org/r/CANT5p=rqcYfYMVHirqvdnnca4Mo+JQSw5Qu12v=kPfpk5yhhmg@mail.gmail.com/
Reported-by: Shyam Prasad N <nspmangalore@gmail.com>
Suggested-by: Lars Persson <larper@axis.com>
Reviewed-by: Ronnie Sahlberg <lsahlber@redhat.com>
Reviewed-by: Enzo Matsumiya <ematsumiya@suse.de>
Signed-off-by: Vincent Whitchurch <vincent.whitchurch@axis.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
Acked-by: Enzo Matsumiya <ematsumiya@suse.de>
---
 fs/cifs/cifs_swn.c    |    4 ++--
 fs/cifs/cifsencrypt.c |    8 ++++----
 fs/cifs/cifsglob.h    |   20 +++++++++++++++++++-
 fs/cifs/connect.c     |   26 +++++++++++++-------------
 fs/cifs/dfs_cache.c   |    4 ++--
 fs/cifs/sess.c        |    6 +++---
 fs/cifs/smb1ops.c     |    6 +++---
 fs/cifs/smb2pdu.c     |    6 +++---
 fs/cifs/smbdirect.c   |    4 ++--
 fs/cifs/transport.c   |   40 ++++++++++++++++++++--------------------
 10 files changed, 71 insertions(+), 53 deletions(-)

--- a/fs/cifs/cifs_swn.c
+++ b/fs/cifs/cifs_swn.c
@@ -467,7 +467,7 @@ static int cifs_swn_reconnect(struct cif
 	int ret = 0;
 
 	/* Store the reconnect address */
-	mutex_lock(&tcon->ses->server->srv_mutex);
+	cifs_server_lock(tcon->ses->server);
 	if (cifs_sockaddr_equal(&tcon->ses->server->dstaddr, addr))
 		goto unlock;
 
@@ -503,7 +503,7 @@ static int cifs_swn_reconnect(struct cif
 	cifs_signal_cifsd_for_reconnect(tcon->ses->server, false);
 
 unlock:
-	mutex_unlock(&tcon->ses->server->srv_mutex);
+	cifs_server_unlock(tcon->ses->server);
 
 	return ret;
 }
--- a/fs/cifs/cifsencrypt.c
+++ b/fs/cifs/cifsencrypt.c
@@ -236,9 +236,9 @@ int cifs_verify_signature(struct smb_rqs
 					cpu_to_le32(expected_sequence_number);
 	cifs_pdu->Signature.Sequence.Reserved = 0;
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 	rc = cifs_calc_signature(rqst, server, what_we_think_sig_should_be);
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	if (rc)
 		return rc;
@@ -707,7 +707,7 @@ setup_ntlmv2_rsp(struct cifs_ses *ses, c
 
 	memcpy(ses->auth_key.response + baselen, tiblob, tilen);
 
-	mutex_lock(&ses->server->srv_mutex);
+	cifs_server_lock(ses->server);
 
 	rc = cifs_alloc_hash("hmac(md5)",
 			     &ses->server->secmech.hmacmd5,
@@ -759,7 +759,7 @@ setup_ntlmv2_rsp(struct cifs_ses *ses, c
 		cifs_dbg(VFS, "%s: Could not generate md5 hash\n", __func__);
 
 unlock:
-	mutex_unlock(&ses->server->srv_mutex);
+	cifs_server_unlock(ses->server);
 setup_ntlmv2_rsp_ret:
 	kfree(tiblob);
 
--- a/fs/cifs/cifsglob.h
+++ b/fs/cifs/cifsglob.h
@@ -16,6 +16,7 @@
 #include <linux/mempool.h>
 #include <linux/workqueue.h>
 #include <linux/utsname.h>
+#include <linux/sched/mm.h>
 #include <linux/netfs.h>
 #include "cifs_fs_sb.h"
 #include "cifsacl.h"
@@ -630,7 +631,8 @@ struct TCP_Server_Info {
 	unsigned int in_flight;  /* number of requests on the wire to server */
 	unsigned int max_in_flight; /* max number of requests that were on wire */
 	spinlock_t req_lock;  /* protect the two values above */
-	struct mutex srv_mutex;
+	struct mutex _srv_mutex;
+	unsigned int nofs_flag;
 	struct task_struct *tsk;
 	char server_GUID[16];
 	__u16 sec_mode;
@@ -746,6 +748,22 @@ struct TCP_Server_Info {
 #endif
 };
 
+static inline void cifs_server_lock(struct TCP_Server_Info *server)
+{
+	unsigned int nofs_flag = memalloc_nofs_save();
+
+	mutex_lock(&server->_srv_mutex);
+	server->nofs_flag = nofs_flag;
+}
+
+static inline void cifs_server_unlock(struct TCP_Server_Info *server)
+{
+	unsigned int nofs_flag = server->nofs_flag;
+
+	mutex_unlock(&server->_srv_mutex);
+	memalloc_nofs_restore(nofs_flag);
+}
+
 struct cifs_credits {
 	unsigned int value;
 	unsigned int instance;
--- a/fs/cifs/connect.c
+++ b/fs/cifs/connect.c
@@ -148,7 +148,7 @@ static void cifs_resolve_server(struct w
 	struct TCP_Server_Info *server = container_of(work,
 					struct TCP_Server_Info, resolve.work);
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 
 	/*
 	 * Resolve the hostname again to make sure that IP address is up-to-date.
@@ -159,7 +159,7 @@ static void cifs_resolve_server(struct w
 				__func__, rc);
 	}
 
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 }
 
 /*
@@ -267,7 +267,7 @@ cifs_abort_connection(struct TCP_Server_
 
 	/* do not want to be sending data on a socket we are freeing */
 	cifs_dbg(FYI, "%s: tearing down socket\n", __func__);
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 	if (server->ssocket) {
 		cifs_dbg(FYI, "State: 0x%x Flags: 0x%lx\n", server->ssocket->state,
 			 server->ssocket->flags);
@@ -296,7 +296,7 @@ cifs_abort_connection(struct TCP_Server_
 		mid->mid_flags |= MID_DELETED;
 	}
 	spin_unlock(&GlobalMid_Lock);
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	cifs_dbg(FYI, "%s: issuing mid callbacks\n", __func__);
 	list_for_each_entry_safe(mid, nmid, &retry_list, qhead) {
@@ -306,9 +306,9 @@ cifs_abort_connection(struct TCP_Server_
 	}
 
 	if (cifs_rdma_enabled(server)) {
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 		smbd_destroy(server);
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 	}
 }
 
@@ -359,7 +359,7 @@ static int __cifs_reconnect(struct TCP_S
 
 	do {
 		try_to_freeze();
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 
 		if (!cifs_swn_set_server_dstaddr(server)) {
 			/* resolve the hostname again to make sure that IP address is up-to-date */
@@ -372,7 +372,7 @@ static int __cifs_reconnect(struct TCP_S
 		else
 			rc = generic_ip_connect(server);
 		if (rc) {
-			mutex_unlock(&server->srv_mutex);
+			cifs_server_unlock(server);
 			cifs_dbg(FYI, "%s: reconnect error %d\n", __func__, rc);
 			msleep(3000);
 		} else {
@@ -383,7 +383,7 @@ static int __cifs_reconnect(struct TCP_S
 				server->tcpStatus = CifsNeedNegotiate;
 			spin_unlock(&cifs_tcp_ses_lock);
 			cifs_swn_reset_server_dstaddr(server);
-			mutex_unlock(&server->srv_mutex);
+			cifs_server_unlock(server);
 			mod_delayed_work(cifsiod_wq, &server->reconnect, 0);
 		}
 	} while (server->tcpStatus == CifsNeedReconnect);
@@ -488,12 +488,12 @@ static int reconnect_dfs_server(struct T
 
 	do {
 		try_to_freeze();
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 
 		rc = reconnect_target_unlocked(server, &tl, &target_hint);
 		if (rc) {
 			/* Failed to reconnect socket */
-			mutex_unlock(&server->srv_mutex);
+			cifs_server_unlock(server);
 			cifs_dbg(FYI, "%s: reconnect error %d\n", __func__, rc);
 			msleep(3000);
 			continue;
@@ -510,7 +510,7 @@ static int reconnect_dfs_server(struct T
 			server->tcpStatus = CifsNeedNegotiate;
 		spin_unlock(&cifs_tcp_ses_lock);
 		cifs_swn_reset_server_dstaddr(server);
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		mod_delayed_work(cifsiod_wq, &server->reconnect, 0);
 	} while (server->tcpStatus == CifsNeedReconnect);
 
@@ -1565,7 +1565,7 @@ cifs_get_tcp_session(struct smb3_fs_cont
 	init_waitqueue_head(&tcp_ses->response_q);
 	init_waitqueue_head(&tcp_ses->request_q);
 	INIT_LIST_HEAD(&tcp_ses->pending_mid_q);
-	mutex_init(&tcp_ses->srv_mutex);
+	mutex_init(&tcp_ses->_srv_mutex);
 	memcpy(tcp_ses->workstation_RFC1001_name,
 		ctx->source_rfc1001_name, RFC1001_NAME_LEN_WITH_NULL);
 	memcpy(tcp_ses->server_RFC1001_name,
--- a/fs/cifs/dfs_cache.c
+++ b/fs/cifs/dfs_cache.c
@@ -1327,9 +1327,9 @@ static bool target_share_equal(struct TC
 		cifs_dbg(VFS, "%s: failed to convert address \'%s\'. skip address matching.\n",
 			 __func__, ip);
 	} else {
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 		match = cifs_match_ipaddr((struct sockaddr *)&server->dstaddr, &sa);
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 	}
 
 	kfree(ip);
--- a/fs/cifs/sess.c
+++ b/fs/cifs/sess.c
@@ -1134,14 +1134,14 @@ sess_establish_session(struct sess_data
 	struct cifs_ses *ses = sess_data->ses;
 	struct TCP_Server_Info *server = sess_data->server;
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 	if (!server->session_estab) {
 		if (server->sign) {
 			server->session_key.response =
 				kmemdup(ses->auth_key.response,
 				ses->auth_key.len, GFP_KERNEL);
 			if (!server->session_key.response) {
-				mutex_unlock(&server->srv_mutex);
+				cifs_server_unlock(server);
 				return -ENOMEM;
 			}
 			server->session_key.len =
@@ -1150,7 +1150,7 @@ sess_establish_session(struct sess_data
 		server->sequence_number = 0x2;
 		server->session_estab = true;
 	}
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	cifs_dbg(FYI, "CIFS session established successfully\n");
 	return 0;
--- a/fs/cifs/smb1ops.c
+++ b/fs/cifs/smb1ops.c
@@ -38,10 +38,10 @@ send_nt_cancel(struct TCP_Server_Info *s
 	in_buf->WordCount = 0;
 	put_bcc(0, in_buf);
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 	rc = cifs_sign_smb(in_buf, server, &mid->sequence_number);
 	if (rc) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		return rc;
 	}
 
@@ -55,7 +55,7 @@ send_nt_cancel(struct TCP_Server_Info *s
 	if (rc < 0)
 		server->sequence_number--;
 
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	cifs_dbg(FYI, "issued NT_CANCEL for mid %u, rc = %d\n",
 		 get_mid(in_buf), rc);
--- a/fs/cifs/smb2pdu.c
+++ b/fs/cifs/smb2pdu.c
@@ -1369,13 +1369,13 @@ SMB2_sess_establish_session(struct SMB2_
 	struct cifs_ses *ses = sess_data->ses;
 	struct TCP_Server_Info *server = sess_data->server;
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 	if (server->ops->generate_signingkey) {
 		rc = server->ops->generate_signingkey(ses, server);
 		if (rc) {
 			cifs_dbg(FYI,
 				"SMB3 session key generation failed\n");
-			mutex_unlock(&server->srv_mutex);
+			cifs_server_unlock(server);
 			return rc;
 		}
 	}
@@ -1383,7 +1383,7 @@ SMB2_sess_establish_session(struct SMB2_
 		server->sequence_number = 0x2;
 		server->session_estab = true;
 	}
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	cifs_dbg(FYI, "SMB2/3 session established successfully\n");
 	return rc;
--- a/fs/cifs/smbdirect.c
+++ b/fs/cifs/smbdirect.c
@@ -1382,9 +1382,9 @@ void smbd_destroy(struct TCP_Server_Info
 	log_rdma_event(INFO, "freeing mr list\n");
 	wake_up_interruptible_all(&info->wait_mr);
 	while (atomic_read(&info->mr_used_count)) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		msleep(1000);
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 	}
 	destroy_mr_list(info);
 
--- a/fs/cifs/transport.c
+++ b/fs/cifs/transport.c
@@ -822,7 +822,7 @@ cifs_call_async(struct TCP_Server_Info *
 	} else
 		instance = exist_credits->instance;
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 
 	/*
 	 * We can't use credits obtained from the previous session to send this
@@ -830,14 +830,14 @@ cifs_call_async(struct TCP_Server_Info *
 	 * return -EAGAIN in such cases to let callers handle it.
 	 */
 	if (instance != server->reconnect_instance) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		add_credits_and_wake_if(server, &credits, optype);
 		return -EAGAIN;
 	}
 
 	mid = server->ops->setup_async_request(server, rqst);
 	if (IS_ERR(mid)) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		add_credits_and_wake_if(server, &credits, optype);
 		return PTR_ERR(mid);
 	}
@@ -868,7 +868,7 @@ cifs_call_async(struct TCP_Server_Info *
 		cifs_delete_mid(mid);
 	}
 
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	if (rc == 0)
 		return 0;
@@ -1109,7 +1109,7 @@ compound_send_recv(const unsigned int xi
 	 * of smb data.
 	 */
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 
 	/*
 	 * All the parts of the compound chain belong obtained credits from the
@@ -1119,7 +1119,7 @@ compound_send_recv(const unsigned int xi
 	 * handle it.
 	 */
 	if (instance != server->reconnect_instance) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		for (j = 0; j < num_rqst; j++)
 			add_credits(server, &credits[j], optype);
 		return -EAGAIN;
@@ -1131,7 +1131,7 @@ compound_send_recv(const unsigned int xi
 			revert_current_mid(server, i);
 			for (j = 0; j < i; j++)
 				cifs_delete_mid(midQ[j]);
-			mutex_unlock(&server->srv_mutex);
+			cifs_server_unlock(server);
 
 			/* Update # of requests on wire to server */
 			for (j = 0; j < num_rqst; j++)
@@ -1163,7 +1163,7 @@ compound_send_recv(const unsigned int xi
 		server->sequence_number -= 2;
 	}
 
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	/*
 	 * If sending failed for some reason or it is an oplock break that we
@@ -1190,9 +1190,9 @@ compound_send_recv(const unsigned int xi
 	if ((ses->ses_status == SES_NEW) || (optype & CIFS_NEG_OP) || (optype & CIFS_SESS_OP)) {
 		spin_unlock(&cifs_tcp_ses_lock);
 
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 		smb311_update_preauth_hash(ses, server, rqst[0].rq_iov, rqst[0].rq_nvec);
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 
 		spin_lock(&cifs_tcp_ses_lock);
 	}
@@ -1266,9 +1266,9 @@ compound_send_recv(const unsigned int xi
 			.iov_len = resp_iov[0].iov_len
 		};
 		spin_unlock(&cifs_tcp_ses_lock);
-		mutex_lock(&server->srv_mutex);
+		cifs_server_lock(server);
 		smb311_update_preauth_hash(ses, server, &iov, 1);
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		spin_lock(&cifs_tcp_ses_lock);
 	}
 	spin_unlock(&cifs_tcp_ses_lock);
@@ -1385,11 +1385,11 @@ SendReceive(const unsigned int xid, stru
 	   and avoid races inside tcp sendmsg code that could cause corruption
 	   of smb data */
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 
 	rc = allocate_mid(ses, in_buf, &midQ);
 	if (rc) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		/* Update # of requests on wire to server */
 		add_credits(server, &credits, 0);
 		return rc;
@@ -1397,7 +1397,7 @@ SendReceive(const unsigned int xid, stru
 
 	rc = cifs_sign_smb(in_buf, server, &midQ->sequence_number);
 	if (rc) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		goto out;
 	}
 
@@ -1411,7 +1411,7 @@ SendReceive(const unsigned int xid, stru
 	if (rc < 0)
 		server->sequence_number -= 2;
 
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	if (rc < 0)
 		goto out;
@@ -1530,18 +1530,18 @@ SendReceiveBlockingLock(const unsigned i
 	   and avoid races inside tcp sendmsg code that could cause corruption
 	   of smb data */
 
-	mutex_lock(&server->srv_mutex);
+	cifs_server_lock(server);
 
 	rc = allocate_mid(ses, in_buf, &midQ);
 	if (rc) {
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		return rc;
 	}
 
 	rc = cifs_sign_smb(in_buf, server, &midQ->sequence_number);
 	if (rc) {
 		cifs_delete_mid(midQ);
-		mutex_unlock(&server->srv_mutex);
+		cifs_server_unlock(server);
 		return rc;
 	}
 
@@ -1554,7 +1554,7 @@ SendReceiveBlockingLock(const unsigned i
 	if (rc < 0)
 		server->sequence_number -= 2;
 
-	mutex_unlock(&server->srv_mutex);
+	cifs_server_unlock(server);
 
 	if (rc < 0) {
 		cifs_delete_mid(midQ);