From: Jason Gunthorpe <jgg@mellanox.com>
Date: Sun, 16 Sep 2018 20:48:08 +0300
Subject: RDMA/umem: Use umem->owning_mm inside ODP
Patch-mainline: v4.20-rc1
Git-commit: f27a0d50a4bc2861b472c2e3740d63a29d1ac460
References: bsc#1103992 FATE#326009

Since ODP had a single struct mmu_notifier located in the ucontext, it
could only handle a single MM at a time, and this prevented it from using
the new owning_mm system.

With the prior rework it is now simple to let ODP track multiple MMs per
ucontext, so finish the job: the per_mm is now allocated on an mm-by-mm
basis and freed when the last umem is dropped from the ucontext.
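
The resulting lifecycle is easy to model outside the kernel. Below is a
minimal userspace analogue, not part of the patch itself; only the
get_per_mm()/put_per_mm() names and the lookup-or-allocate shape mirror
the hunks further down. Each umem takes a reference on the per_mm entry
matching its owning mm, and the entry is torn down when the last umem
drops it:

  #include <stdio.h>
  #include <stdlib.h>

  struct per_mm {
          void *mm;               /* key: the owning mm */
          unsigned int count;     /* number of umems attached to this mm */
          struct per_mm *next;    /* the ucontext's per_mm list */
  };

  static struct per_mm *per_mm_list;

  /* Look up the entry for owning_mm, allocating it on first use. */
  static struct per_mm *get_per_mm(void *owning_mm)
  {
          struct per_mm *p;

          for (p = per_mm_list; p; p = p->next)
                  if (p->mm == owning_mm)
                          goto found;

          p = calloc(1, sizeof(*p));
          if (!p)
                  return NULL;
          p->mm = owning_mm;
          p->next = per_mm_list;
          per_mm_list = p;
  found:
          p->count++;
          return p;
  }

  /* Drop one reference; free the entry when the last umem is gone. */
  static void put_per_mm(struct per_mm *p)
  {
          struct per_mm **link;

          if (--p->count)
                  return;
          for (link = &per_mm_list; *link; link = &(*link)->next) {
                  if (*link == p) {
                          *link = p->next;
                          break;
                  }
          }
          free(p);
  }

  int main(void)
  {
          int mm_a, mm_b;         /* stand-ins for two distinct mm_structs */
          struct per_mm *x = get_per_mm(&mm_a);
          struct per_mm *y = get_per_mm(&mm_a);   /* same mm: entry reused */
          struct per_mm *z = get_per_mm(&mm_b);   /* other mm: new entry   */

          printf("x == y: %d, x == z: %d\n", x == y, x == z);
          put_per_mm(y);
          put_per_mm(x);          /* last user of mm_a: entry is freed */
          put_per_mm(z);
          return 0;
  }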

As a side effect the new saner locking removes the lockdep splat about
nesting the umem_rwsem between mmu_notifier_unregister and
ib_umem_odp_release.

It also makes ODP work with multiple processes, across fork, etc.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 drivers/infiniband/core/umem_odp.c   |  301 ++++++++++++++++++-----------------
 drivers/infiniband/core/uverbs_cmd.c |    8 
 drivers/infiniband/hw/mlx5/main.c    |    7 
 drivers/infiniband/hw/mlx5/odp.c     |    2 
 include/rdma/ib_umem_odp.h           |   20 ++
 include/rdma/ib_verbs.h              |   22 --
 6 files changed, 191 insertions(+), 169 deletions(-)

--- a/drivers/infiniband/core/umem_odp.c
+++ b/drivers/infiniband/core/umem_odp.c
@@ -250,10 +250,135 @@ static const struct mmu_notifier_ops ib_
 	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 };
 
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
-				      unsigned long addr, size_t size)
+static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
+{
+	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+	struct ib_umem *umem = &umem_odp->umem;
+
+	down_write(&per_mm->umem_rwsem);
+	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
+		rbt_ib_umem_insert(&umem_odp->interval_tree,
+				   &per_mm->umem_tree);
+
+	if (likely(!atomic_read(&per_mm->notifier_count)))
+		umem_odp->mn_counters_active = true;
+	else
+		list_add(&umem_odp->no_private_counters,
+			 &per_mm->no_private_counters);
+	up_write(&per_mm->umem_rwsem);
+}
+
+static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
+{
+	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+	struct ib_umem *umem = &umem_odp->umem;
+
+	down_write(&per_mm->umem_rwsem);
+	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
+		rbt_ib_umem_remove(&umem_odp->interval_tree,
+				   &per_mm->umem_tree);
+	if (!umem_odp->mn_counters_active) {
+		list_del(&umem_odp->no_private_counters);
+		complete_all(&umem_odp->notifier_completion);
+	}
+
+	up_write(&per_mm->umem_rwsem);
+}
+
+static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
+					       struct mm_struct *mm)
 {
 	struct ib_ucontext_per_mm *per_mm;
+	int ret;
+
+	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
+	if (!per_mm)
+		return ERR_PTR(-ENOMEM);
+
+	per_mm->context = ctx;
+	per_mm->mm = mm;
+	per_mm->umem_tree = RB_ROOT;
+	init_rwsem(&per_mm->umem_rwsem);
+	INIT_LIST_HEAD(&per_mm->no_private_counters);
+
+	rcu_read_lock();
+	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
+	rcu_read_unlock();
+
+	WARN_ON(mm != current->mm);
+
+	per_mm->mn.ops = &ib_umem_notifiers;
+	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
+	if (ret) {
+		dev_err(&ctx->device->dev,
+			"Failed to register mmu_notifier %d\n", ret);
+		goto out_pid;
+	}
+
+	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
+	return per_mm;
+
+out_pid:
+	put_pid(per_mm->tgid);
+	kfree(per_mm);
+	return ERR_PTR(ret);
+}
+
+static int get_per_mm(struct ib_umem_odp *umem_odp)
+{
+	struct ib_ucontext *ctx = umem_odp->umem.context;
+	struct ib_ucontext_per_mm *per_mm;
+
+	/*
+	 * Generally speaking we expect only one or two per_mm in this list,
+	 * so no reason to optimize this search today.
+	 */
+	mutex_lock(&ctx->per_mm_list_lock);
+	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
+		if (per_mm->mm == umem_odp->umem.owning_mm)
+			goto found;
+	}
+
+	per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
+	if (IS_ERR(per_mm)) {
+		mutex_unlock(&ctx->per_mm_list_lock);
+		return PTR_ERR(per_mm);
+	}
+
+found:
+	umem_odp->per_mm = per_mm;
+	per_mm->odp_mrs_count++;
+	mutex_unlock(&ctx->per_mm_list_lock);
+
+	return 0;
+}
+
+void put_per_mm(struct ib_umem_odp *umem_odp)
+{
+	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+	struct ib_ucontext *ctx = umem_odp->umem.context;
+	bool need_free;
+
+	mutex_lock(&ctx->per_mm_list_lock);
+	umem_odp->per_mm = NULL;
+	per_mm->odp_mrs_count--;
+	need_free = per_mm->odp_mrs_count == 0;
+	if (need_free)
+		list_del(&per_mm->ucontext_list);
+	mutex_unlock(&ctx->per_mm_list_lock);
+
+	if (!need_free)
+		return;
+
+	mmu_notifier_unregister(&per_mm->mn, per_mm->mm);
+	put_pid(per_mm->tgid);
+	kfree(per_mm);
+}
+
+struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
+				      unsigned long addr, size_t size)
+{
+	struct ib_ucontext *ctx = per_mm->context;
 	struct ib_umem_odp *odp_data;
 	struct ib_umem *umem;
 	int pages = size >> PAGE_SHIFT;
@@ -263,13 +388,13 @@ struct ib_umem_odp *ib_alloc_odp_umem(st
 	if (!odp_data)
 		return ERR_PTR(-ENOMEM);
 	umem = &odp_data->umem;
-	umem->context    = context;
+	umem->context    = ctx;
 	umem->length     = size;
 	umem->address    = addr;
 	umem->page_shift = PAGE_SHIFT;
 	umem->writable   = 1;
 	umem->is_odp = 1;
-	odp_data->per_mm = per_mm = &context->per_mm;
+	odp_data->per_mm = per_mm;
 
 	mutex_init(&odp_data->umem_mutex);
 	init_completion(&odp_data->notifier_completion);
@@ -286,15 +411,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(st
 		goto out_page_list;
 	}
 
-	down_write(&per_mm->umem_rwsem);
+	/*
+	 * Caller must ensure that the umem_odp that the per_mm came from
+	 * cannot be freed during the call to ib_alloc_odp_umem.
+	 */
+	mutex_lock(&ctx->per_mm_list_lock);
 	per_mm->odp_mrs_count++;
-	rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree);
-	if (likely(!atomic_read(&per_mm->notifier_count)))
-		odp_data->mn_counters_active = true;
-	else
-		list_add(&odp_data->no_private_counters,
-			 &per_mm->no_private_counters);
-	up_write(&per_mm->umem_rwsem);
+	mutex_unlock(&ctx->per_mm_list_lock);
+	add_umem_to_per_mm(odp_data);
 
 	return odp_data;
 
@@ -308,15 +432,13 @@ EXPORT_SYMBOL(ib_alloc_odp_umem);
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
-	struct ib_ucontext *context = umem_odp->umem.context;
 	struct ib_umem *umem = &umem_odp->umem;
-	struct ib_ucontext_per_mm *per_mm;
+	/*
+	 * NOTE: This must called in a process context where umem->owning_mm
+	 * == current->mm
+	 */
+	struct mm_struct *mm = umem->owning_mm;
 	int ret_val;
-	struct pid *our_pid;
-	struct mm_struct *mm = get_task_mm(current);
-
-	if (!mm)
-		return -EINVAL;
 
 	if (access & IB_ACCESS_HUGETLB) {
 		struct vm_area_struct *vma;
@@ -336,16 +458,6 @@ int ib_umem_odp_get(struct ib_umem_odp *
 		umem->hugetlb = 0;
 	}
 
-	/* Prevent creating ODP MRs in child processes */
-	rcu_read_lock();
-	our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
-	rcu_read_unlock();
-	put_pid(our_pid);
-	if (context->tgid != our_pid) {
-		ret_val = -EINVAL;
-		goto out_mm;
-	}
-
 	mutex_init(&umem_odp->umem_mutex);
 
 	init_completion(&umem_odp->notifier_completion);
@@ -354,10 +466,8 @@ int ib_umem_odp_get(struct ib_umem_odp *
 		umem_odp->page_list =
 			vzalloc(array_size(sizeof(*umem_odp->page_list),
 					   ib_umem_num_pages(umem)));
-		if (!umem_odp->page_list) {
-			ret_val = -ENOMEM;
-			goto out_mm;
-		}
+		if (!umem_odp->page_list)
+			return -ENOMEM;
 
 		umem_odp->dma_list =
 			vzalloc(array_size(sizeof(*umem_odp->dma_list),
@@ -368,67 +478,23 @@ int ib_umem_odp_get(struct ib_umem_odp *
 		}
 	}
 
-	/*
-	 * When using MMU notifiers, we will get a
-	 * notification before the "current" task (and MM) is
-	 * destroyed. We use the umem_rwsem semaphore to synchronize.
-	 */
-	umem_odp->per_mm = per_mm = &context->per_mm;
-
-	down_write(&per_mm->umem_rwsem);
-	per_mm->odp_mrs_count++;
-	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-		rbt_ib_umem_insert(&umem_odp->interval_tree,
-				   &per_mm->umem_tree);
-	if (likely(!atomic_read(&per_mm->notifier_count)) ||
-	    per_mm->odp_mrs_count == 1)
-		umem_odp->mn_counters_active = true;
-	else
-		list_add(&umem_odp->no_private_counters,
-			 &per_mm->no_private_counters);
-	downgrade_write(&per_mm->umem_rwsem);
-
-	if (per_mm->odp_mrs_count == 1) {
-		/*
-		 * Note that at this point, no MMU notifier is running
-		 * for this per_mm!
-		 */
-		atomic_set(&per_mm->notifier_count, 0);
-		INIT_HLIST_NODE(&per_mm->mn.hlist);
-		per_mm->mn.ops = &ib_umem_notifiers;
-		ret_val = mmu_notifier_register(&per_mm->mn, mm);
-		if (ret_val) {
-			pr_err("Failed to register mmu_notifier %d\n", ret_val);
-			ret_val = -EBUSY;
-			goto out_mutex;
-		}
-	}
-
-	up_read(&per_mm->umem_rwsem);
+	ret_val = get_per_mm(umem_odp);
+	if (ret_val)
+		goto out_dma_list;
+	add_umem_to_per_mm(umem_odp);
 
-	/*
-	 * Note that doing an mmput can cause a notifier for the relevant mm.
-	 * If the notifier is called while we hold the umem_rwsem, this will
-	 * cause a deadlock. Therefore, we release the reference only after we
-	 * released the semaphore.
-	 */
-	mmput(mm);
 	return 0;
 
-out_mutex:
-	up_read(&per_mm->umem_rwsem);
+out_dma_list:
 	vfree(umem_odp->dma_list);
 out_page_list:
 	vfree(umem_odp->page_list);
-out_mm:
-	mmput(mm);
 	return ret_val;
 }
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
 	struct ib_umem *umem = &umem_odp->umem;
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 
 	/*
 	 * Ensure that no more pages are mapped in the umem.
@@ -439,54 +505,8 @@ void ib_umem_odp_release(struct ib_umem_
 	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
 				    ib_umem_end(umem));
 
-	down_write(&per_mm->umem_rwsem);
-	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-		rbt_ib_umem_remove(&umem_odp->interval_tree,
-				   &per_mm->umem_tree);
-	per_mm->odp_mrs_count--;
-	if (!umem_odp->mn_counters_active) {
-		list_del(&umem_odp->no_private_counters);
-		complete_all(&umem_odp->notifier_completion);
-	}
-
-	/*
-	 * Downgrade the lock to a read lock. This ensures that the notifiers
-	 * (who lock the mutex for reading) will be able to finish, and we
-	 * will be able to enventually obtain the mmu notifiers SRCU. Note
-	 * that since we are doing it atomically, no other user could register
-	 * and unregister while we do the check.
-	 */
-	downgrade_write(&per_mm->umem_rwsem);
-	if (!per_mm->odp_mrs_count) {
-		struct task_struct *owning_process = NULL;
-		struct mm_struct *owning_mm        = NULL;
-
-		owning_process =
-			get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
-		if (owning_process == NULL)
-			/*
-			 * The process is already dead, notifier were removed
-			 * already.
-			 */
-			goto out;
-
-		owning_mm = get_task_mm(owning_process);
-		if (owning_mm == NULL)
-			/*
-			 * The process' mm is already dead, notifier were
-			 * removed already.
-			 */
-			goto out_put_task;
-		mmu_notifier_unregister(&per_mm->mn, owning_mm);
-
-		mmput(owning_mm);
-
-out_put_task:
-		put_task_struct(owning_process);
-	}
-out:
-	up_read(&per_mm->umem_rwsem);
-
+	remove_umem_from_per_mm(umem_odp);
+	put_per_mm(umem_odp);
 	vfree(umem_odp->dma_list);
 	vfree(umem_odp->page_list);
 }
@@ -604,7 +624,7 @@ int ib_umem_odp_map_dma_pages(struct ib_
 {
 	struct ib_umem *umem = &umem_odp->umem;
 	struct task_struct *owning_process  = NULL;
-	struct mm_struct   *owning_mm       = NULL;
+	struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
 	struct page       **local_page_list = NULL;
 	u64 page_mask, off;
 	int j, k, ret = 0, start_idx, npages = 0, page_shift;
@@ -628,15 +648,14 @@ int ib_umem_odp_map_dma_pages(struct ib_
 	user_virt = user_virt & page_mask;
 	bcnt += off; /* Charge for the first page offset as well. */
 
-	owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
-	if (owning_process == NULL) {
+	/*
+	 * owning_process is allowed to be NULL, this means somehow the mm is
+	 * existing beyond the lifetime of the originating process.. Presumably
+	 * mmget_not_zero will fail in this case.
+	 */
+	owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
+	if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) {
 		ret = -EINVAL;
-		goto out_no_task;
-	}
-
-	owning_mm = get_task_mm(owning_process);
-	if (owning_mm == NULL) {
-		ret = -ENOENT;
 		goto out_put_task;
 	}
 
@@ -708,8 +727,8 @@ int ib_umem_odp_map_dma_pages(struct ib_
 
 	mmput(owning_mm);
 out_put_task:
-	put_task_struct(owning_process);
-out_no_task:
+	if (owning_process)
+		put_task_struct(owning_process);
 	free_page((unsigned long)local_page_list);
 	return ret;
 }
--- a/drivers/infiniband/core/uverbs_cmd.c
+++ b/drivers/infiniband/core/uverbs_cmd.c
@@ -124,12 +124,8 @@ ssize_t ib_uverbs_get_context(struct ib_
 	ucontext->cleanup_retryable = false;
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	ucontext->per_mm.umem_tree = RB_ROOT;
-	init_rwsem(&ucontext->per_mm.umem_rwsem);
-	ucontext->per_mm.odp_mrs_count = 0;
-	INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters);
-	ucontext->per_mm.context = ucontext;
-
+	mutex_init(&ucontext->per_mm_list_lock);
+	INIT_LIST_HEAD(&ucontext->per_mm_list);
 	if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
 		ucontext->invalidate_range = NULL;
 
--- a/drivers/infiniband/hw/mlx5/main.c
+++ b/drivers/infiniband/hw/mlx5/main.c
@@ -1861,6 +1861,13 @@ static int mlx5_ib_dealloc_ucontext(stru
 	struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
 	struct mlx5_bfreg_info *bfregi;
 
+#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
+	/* All umem's must be destroyed before destroying the ucontext. */
+	mutex_lock(&ibcontext->per_mm_list_lock);
+	WARN_ON(!list_empty(&ibcontext->per_mm_list));
+	mutex_unlock(&ibcontext->per_mm_list_lock);
+#endif
+
 	if (context->devx_uid)
 		mlx5_ib_devx_destroy(dev, context);
 
--- a/drivers/infiniband/hw/mlx5/odp.c
+++ b/drivers/infiniband/hw/mlx5/odp.c
@@ -393,7 +393,7 @@ next_mr:
 		if (nentries)
 			nentries++;
 	} else {
-		odp = ib_alloc_odp_umem(odp_mr->umem.context, addr,
+		odp = ib_alloc_odp_umem(odp_mr->per_mm, addr,
 					MLX5_IMR_MTT_SIZE);
 		if (IS_ERR(odp)) {
 			mutex_unlock(&odp_mr->umem_mutex);
--- a/include/rdma/ib_umem_odp.h
+++ b/include/rdma/ib_umem_odp.h
@@ -91,8 +91,26 @@ static inline struct ib_umem_odp *to_ib_
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 
+struct ib_ucontext_per_mm {
+	struct ib_ucontext *context;
+	struct mm_struct *mm;
+	struct pid *tgid;
+
+	struct rb_root umem_tree;
+	/* Protects umem_tree */
+	struct rw_semaphore umem_rwsem;
+	atomic_t notifier_count;
+
+	struct mmu_notifier mn;
+	/* A list of umems that don't have private mmu notifier counters yet. */
+	struct list_head no_private_counters;
+	unsigned int odp_mrs_count;
+
+	struct list_head ucontext_list;
+};
+
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
+struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
 				      unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
 
--- a/include/rdma/ib_verbs.h
+++ b/include/rdma/ib_verbs.h
@@ -1488,25 +1488,6 @@ struct ib_rdmacg_object {
 #endif
 };
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-struct ib_ucontext_per_mm {
-	struct ib_ucontext *context;
-
-	struct rb_root      umem_tree;
-	/*
-	 * Protects .umem_rbroot and tree, as well as odp_mrs_count and
-	 * mmu notifiers registration.
-	 */
-	struct rw_semaphore umem_rwsem;
-
-	struct mmu_notifier mn;
-	atomic_t notifier_count;
-	/* A list of umems that don't have private mmu notifier counters yet. */
-	struct list_head no_private_counters;
-	unsigned int odp_mrs_count;
-};
-#endif
-
 struct ib_ucontext {
 	struct ib_device       *device;
 	struct ib_uverbs_file  *ufile;
@@ -1523,7 +1504,8 @@ struct ib_ucontext {
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 	void (*invalidate_range)(struct ib_umem_odp *umem_odp,
 				 unsigned long start, unsigned long end);
-	struct ib_ucontext_per_mm per_mm;
+	struct mutex per_mm_list_lock;
+	struct list_head per_mm_list;
 #endif
 
 	struct ib_rdmacg_object	cg_obj;