Blob Blame History Raw
From: Jason Gunthorpe <jgg@mellanox.com>
Date: Sun, 16 Sep 2018 20:44:45 +0300
Subject: RDMA/umem: Do not use current->tgid to track the mm_struct
Patch-mainline: v4.20-rc1
Git-commit: d4b4dd1b9706e48c370f88d3adfe713e43423cc9
References: bsc#1103992 FATE#326009

Using the tgid stored in the ucontext to find the mm is just wrong: the
process that calls into reg_mr is the process associated with the umem,
and that does not have to be the same process that created the context.

When this code was first written mmgrab() didn't exist; however, these
days we can just directly hold the mm_struct pointer in the umem and have
no ambiguity when it comes to releasing the umem as to which mm it was
associated with.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 drivers/infiniband/core/umem.c |   77 +++++++++++++++++++----------------------
 include/rdma/ib_umem.h         |    3 -
 2 files changed, 37 insertions(+), 43 deletions(-)

--- a/drivers/infiniband/core/umem.c
+++ b/drivers/infiniband/core/umem.c
@@ -86,6 +86,7 @@ struct ib_umem *ib_umem_get(struct ib_uc
 	struct vm_area_struct **vma_list;
 	unsigned long lock_limit;
 	unsigned long cur_base;
+	struct mm_struct *mm;
 	unsigned long npages;
 	int ret;
 	int i;
@@ -124,6 +125,8 @@ struct ib_umem *ib_umem_get(struct ib_uc
 		return umem;
 	}
 
+	umem->owning_mm = mm = current->mm;
+	mmgrab(mm);
 	umem->odp_data = NULL;
 
 	/* We assume the memory is from hugetlb until proved otherwise */
@@ -132,7 +135,7 @@ struct ib_umem *ib_umem_get(struct ib_uc
 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
 	if (!page_list) {
 		ret = -ENOMEM;
-		goto umem_kfree;
+		goto umem_kfree_drop;
 	}
 
 	/*
@@ -147,14 +150,14 @@ struct ib_umem *ib_umem_get(struct ib_uc
 
 	lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
 
-	down_write(&current->mm->mmap_sem);
-	current->mm->pinned_vm += npages;
-	if ((current->mm->pinned_vm > lock_limit) && !capable(CAP_IPC_LOCK)) {
-		up_write(&current->mm->mmap_sem);
+	down_write(&mm->mmap_sem);
+	mm->pinned_vm += npages;
+	if ((mm->pinned_vm > lock_limit) && !capable(CAP_IPC_LOCK)) {
+		up_write(&mm->mmap_sem);
 		ret = -ENOMEM;
 		goto vma;
 	}
-	up_write(&current->mm->mmap_sem);
+	up_write(&mm->mmap_sem);
 
 	cur_base = addr & PAGE_MASK;
 
@@ -172,14 +175,14 @@ struct ib_umem *ib_umem_get(struct ib_uc
 
 	sg_list_start = umem->sg_head.sgl;
 
-	down_read(&current->mm->mmap_sem);
+	down_read(&mm->mmap_sem);
 	while (npages) {
 		ret = get_user_pages_longterm(cur_base,
 				     min_t(unsigned long, npages,
 					   PAGE_SIZE / sizeof (struct page *)),
 				     gup_flags, page_list, vma_list);
 		if (ret < 0) {
-			up_read(&current->mm->mmap_sem);
+			up_read(&mm->mmap_sem);
 			goto umem_release;
 		}
 
@@ -197,7 +200,7 @@ struct ib_umem *ib_umem_get(struct ib_uc
 		/* preparing for next loop */
 		sg_list_start = sg;
 	}
-	up_read(&current->mm->mmap_sem);
+	up_read(&mm->mmap_sem);
 
 	umem->nmap = ib_dma_map_sg_attrs(context->device,
 				  umem->sg_head.sgl,
@@ -223,6 +226,9 @@ out:
 	if (vma_list)
 		free_page((unsigned long) vma_list);
 	free_page((unsigned long) page_list);
+umem_kfree_drop:
+	if (ret)
+		mmdrop(umem->owning_mm);
 umem_kfree:
 	if (ret)
 		kfree(umem);
@@ -230,15 +236,21 @@ umem_kfree:
 }
 EXPORT_SYMBOL(ib_umem_get);
 
-static void ib_umem_account(struct work_struct *work)
+static void __ib_umem_release_tail(struct ib_umem *umem)
+{
+	mmdrop(umem->owning_mm);
+	kfree(umem);
+}
+
+static void ib_umem_release_defer(struct work_struct *work)
 {
 	struct ib_umem *umem = container_of(work, struct ib_umem, work);
 
-	down_write(&umem->mm->mmap_sem);
-	umem->mm->pinned_vm -= umem->diff;
-	up_write(&umem->mm->mmap_sem);
-	mmput(umem->mm);
-	kfree(umem);
+	down_write(&umem->owning_mm->mmap_sem);
+	umem->owning_mm->pinned_vm -= ib_umem_num_pages(umem);
+	up_write(&umem->owning_mm->mmap_sem);
+
+	__ib_umem_release_tail(umem);
 }
 
 /**
@@ -248,9 +260,6 @@ static void ib_umem_account(struct work_
 void ib_umem_release(struct ib_umem *umem)
 {
 	struct ib_ucontext *context = umem->context;
-	struct mm_struct *mm;
-	struct task_struct *task;
-	unsigned long diff;
 
 	if (umem->odp_data) {
 		ib_umem_odp_release(umem);
@@ -259,41 +268,27 @@ void ib_umem_release(struct ib_umem *ume
 
 	__ib_umem_release(umem->context->device, umem, 1);
 
-	task = get_pid_task(umem->context->tgid, PIDTYPE_PID);
-	if (!task)
-		goto out;
-	mm = get_task_mm(task);
-	put_task_struct(task);
-	if (!mm)
-		goto out;
-
-	diff = ib_umem_num_pages(umem);
-
 	/*
 	 * We may be called with the mm's mmap_sem already held.  This
 	 * can happen when a userspace munmap() is the call that drops
 	 * the last reference to our file and calls our release
 	 * method.  If there are memory regions to destroy, we'll end
 	 * up here and not be able to take the mmap_sem.  In that case
-	 * we defer the vm_locked accounting to the system workqueue.
+	 * we defer the vm_locked accounting to a workqueue.
 	 */
 	if (context->closing) {
-		if (!down_write_trylock(&mm->mmap_sem)) {
-			INIT_WORK(&umem->work, ib_umem_account);
-			umem->mm   = mm;
-			umem->diff = diff;
-
+		if (!down_write_trylock(&umem->owning_mm->mmap_sem)) {
+			INIT_WORK(&umem->work, ib_umem_release_defer);
 			queue_work(ib_wq, &umem->work);
 			return;
 		}
-	} else
-		down_write(&mm->mmap_sem);
+	} else {
+		down_write(&umem->owning_mm->mmap_sem);
+	}
+	umem->owning_mm->pinned_vm -= ib_umem_num_pages(umem);
+	up_write(&umem->owning_mm->mmap_sem);
 
-	mm->pinned_vm -= diff;
-	up_write(&mm->mmap_sem);
-	mmput(mm);
-out:
-	kfree(umem);
+	__ib_umem_release_tail(umem);
 }
 EXPORT_SYMBOL(ib_umem_release);
 
--- a/include/rdma/ib_umem.h
+++ b/include/rdma/ib_umem.h
@@ -42,14 +42,13 @@ struct ib_umem_odp;
 
 struct ib_umem {
 	struct ib_ucontext     *context;
+	struct mm_struct       *owning_mm;
 	size_t			length;
 	unsigned long		address;
 	int			page_shift;
 	int                     writable;
 	int                     hugetlb;
 	struct work_struct	work;
-	struct mm_struct       *mm;
-	unsigned long		diff;
 	struct ib_umem_odp     *odp_data;
 	struct sg_table sg_head;
 	int             nmap;