Blob Blame History Raw
From: Jason Gunthorpe <jgg@mellanox.com>
Date: Sun, 16 Sep 2018 20:48:05 +0300
Subject: RDMA/umem: Make ib_umem_odp into a sub structure of ib_umem
Patch-mainline: v4.20-rc1
Git-commit: 41b4deeaa123e62e1037af7a0be547af2e0e05f1
References: bsc#1103992 FATE#326009

These two structures are linked together, use the container_of pattern
instead of a double allocation to make the code simpler and easier to
follow.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 drivers/infiniband/core/umem.c        |   36 +++++++++------
 drivers/infiniband/core/umem_odp.c    |   77 +++++++++++++---------------------
 drivers/infiniband/core/umem_rbtree.c |    4 -
 drivers/infiniband/hw/mlx5/odp.c      |   26 +++++------
 include/rdma/ib_umem_odp.h            |   11 +---
 5 files changed, 71 insertions(+), 83 deletions(-)

--- a/drivers/infiniband/core/umem.c
+++ b/drivers/infiniband/core/umem.c
@@ -108,34 +108,39 @@ struct ib_umem *ib_umem_get(struct ib_uc
 	if (!can_do_mlock())
 		return ERR_PTR(-EPERM);
 
-	umem = kzalloc(sizeof *umem, GFP_KERNEL);
-	if (!umem)
-		return ERR_PTR(-ENOMEM);
+	if (access & IB_ACCESS_ON_DEMAND) {
+		umem = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
+		if (!umem)
+			return ERR_PTR(-ENOMEM);
+		umem->odp_data = to_ib_umem_odp(umem);
+	} else {
+		umem = kzalloc(sizeof(*umem), GFP_KERNEL);
+		if (!umem)
+			return ERR_PTR(-ENOMEM);
+	}
 
 	umem->context    = context;
 	umem->length     = size;
 	umem->address    = addr;
 	umem->page_shift = PAGE_SHIFT;
 	umem->writable   = ib_access_writable(access);
+	umem->owning_mm = mm = current->mm;
+	mmgrab(mm);
 
 	if (access & IB_ACCESS_ON_DEMAND) {
-		ret = ib_umem_odp_get(context, umem, access);
+		ret = ib_umem_odp_get(to_ib_umem_odp(umem), access);
 		if (ret)
 			goto umem_kfree;
 		return umem;
 	}
 
-	umem->owning_mm = mm = current->mm;
-	mmgrab(mm);
-	umem->odp_data = NULL;
-
 	/* We assume the memory is from hugetlb until proved otherwise */
 	umem->hugetlb   = 1;
 
 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
 	if (!page_list) {
 		ret = -ENOMEM;
-		goto umem_kfree_drop;
+		goto umem_kfree;
 	}
 
 	/*
@@ -226,12 +231,11 @@ out:
 	if (vma_list)
 		free_page((unsigned long) vma_list);
 	free_page((unsigned long) page_list);
-umem_kfree_drop:
-	if (ret)
-		mmdrop(umem->owning_mm);
 umem_kfree:
-	if (ret)
+	if (ret) {
+		mmdrop(umem->owning_mm);
 		kfree(umem);
+	}
 	return ret ? ERR_PTR(ret) : umem;
 }
 EXPORT_SYMBOL(ib_umem_get);
@@ -239,7 +243,10 @@ EXPORT_SYMBOL(ib_umem_get);
 static void __ib_umem_release_tail(struct ib_umem *umem)
 {
 	mmdrop(umem->owning_mm);
-	kfree(umem);
+	if (umem->odp_data)
+		kfree(to_ib_umem_odp(umem));
+	else
+		kfree(umem);
 }
 
 static void ib_umem_release_defer(struct work_struct *work)
@@ -263,6 +270,7 @@ void ib_umem_release(struct ib_umem *ume
 
 	if (umem->odp_data) {
 		ib_umem_odp_release(to_ib_umem_odp(umem));
+		__ib_umem_release_tail(umem);
 		return;
 	}
 
--- a/drivers/infiniband/core/umem_odp.c
+++ b/drivers/infiniband/core/umem_odp.c
@@ -126,7 +126,7 @@ static void ib_ucontext_notifier_end_acc
 static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
 					       u64 start, u64 end, void *cookie)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 
 	/*
 	 * Increase the number of notifiers running, to
@@ -164,7 +164,7 @@ static int invalidate_page_trampoline(st
 				      u64 end, void *cookie)
 {
 	ib_umem_notifier_start_account(item);
-	item->umem->context->invalidate_range(item, start, start + PAGE_SIZE);
+	item->umem.context->invalidate_range(item, start, start + PAGE_SIZE);
 	ib_umem_notifier_end_account(item);
 	return 0;
 }
@@ -191,7 +191,7 @@ static int invalidate_range_start_trampo
 					     u64 start, u64 end, void *cookie)
 {
 	ib_umem_notifier_start_account(item);
-	item->umem->context->invalidate_range(item, start, end);
+	item->umem.context->invalidate_range(item, start, end);
 	return 0;
 }
 
@@ -248,28 +248,21 @@ static const struct mmu_notifier_ops ib_
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 				      unsigned long addr, size_t size)
 {
-	struct ib_umem *umem;
 	struct ib_umem_odp *odp_data;
+	struct ib_umem *umem;
 	int pages = size >> PAGE_SHIFT;
 	int ret;
 
-	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
-	if (!umem)
+	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
+	if (!odp_data)
 		return ERR_PTR(-ENOMEM);
-
+	umem = &odp_data->umem;
 	umem->context    = context;
 	umem->length     = size;
 	umem->address    = addr;
 	umem->page_shift = PAGE_SHIFT;
 	umem->writable   = 1;
 
-	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
-	if (!odp_data) {
-		ret = -ENOMEM;
-		goto out_umem;
-	}
-	odp_data->umem = umem;
-
 	mutex_init(&odp_data->umem_mutex);
 	init_completion(&odp_data->notifier_completion);
 
@@ -303,15 +296,14 @@ out_page_list:
 	vfree(odp_data->page_list);
 out_odp_data:
 	kfree(odp_data);
-out_umem:
-	kfree(umem);
 	return ERR_PTR(ret);
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);
 
-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
-		    int access)
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
+	struct ib_ucontext *context = umem_odp->umem.context;
+	struct ib_umem *umem = &umem_odp->umem;
 	int ret_val;
 	struct pid *our_pid;
 	struct mm_struct *mm = get_task_mm(current);
@@ -347,28 +339,23 @@ int ib_umem_odp_get(struct ib_ucontext *
 		goto out_mm;
 	}
 
-	umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
-	if (!umem->odp_data) {
-		ret_val = -ENOMEM;
-		goto out_mm;
-	}
-	umem->odp_data->umem = umem;
-
-	mutex_init(&umem->odp_data->umem_mutex);
+	mutex_init(&umem_odp->umem_mutex);
 
-	init_completion(&umem->odp_data->notifier_completion);
+	init_completion(&umem_odp->notifier_completion);
 
 	if (ib_umem_num_pages(umem)) {
-		umem->odp_data->page_list = vzalloc(ib_umem_num_pages(umem) *
-					    sizeof(*umem->odp_data->page_list));
-		if (!umem->odp_data->page_list) {
+		umem_odp->page_list =
+			vzalloc(array_size(sizeof(*umem_odp->page_list),
+					   ib_umem_num_pages(umem)));
+		if (!umem_odp->page_list) {
 			ret_val = -ENOMEM;
-			goto out_odp_data;
+			goto out_mm;
 		}
 
-		umem->odp_data->dma_list = vzalloc(ib_umem_num_pages(umem) *
-					  sizeof(*umem->odp_data->dma_list));
-		if (!umem->odp_data->dma_list) {
+		umem_odp->dma_list =
+			vzalloc(array_size(sizeof(*umem_odp->dma_list),
+					   ib_umem_num_pages(umem)));
+		if (!umem_odp->dma_list) {
 			ret_val = -ENOMEM;
 			goto out_page_list;
 		}
@@ -382,13 +369,13 @@ int ib_umem_odp_get(struct ib_ucontext *
 	down_write(&context->umem_rwsem);
 	context->odp_mrs_count++;
 	if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-		rbt_ib_umem_insert(&umem->odp_data->interval_tree,
+		rbt_ib_umem_insert(&umem_odp->interval_tree,
 				   &context->umem_tree);
 	if (likely(!atomic_read(&context->notifier_count)) ||
 	    context->odp_mrs_count == 1)
-		umem->odp_data->mn_counters_active = true;
+		umem_odp->mn_counters_active = true;
 	else
-		list_add(&umem->odp_data->no_private_counters,
+		list_add(&umem_odp->no_private_counters,
 			 &context->no_private_counters);
 	downgrade_write(&context->umem_rwsem);
 
@@ -421,11 +408,9 @@ int ib_umem_odp_get(struct ib_ucontext *
 
 out_mutex:
 	up_read(&context->umem_rwsem);
-	vfree(umem->odp_data->dma_list);
+	vfree(umem_odp->dma_list);
 out_page_list:
-	vfree(umem->odp_data->page_list);
-out_odp_data:
-	kfree(umem->odp_data);
+	vfree(umem_odp->page_list);
 out_mm:
 	mmput(mm);
 	return ret_val;
@@ -433,7 +418,7 @@ out_mm:
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	struct ib_ucontext *context = umem->context;
 
 	/*
@@ -495,8 +480,6 @@ out:
 
 	vfree(umem_odp->dma_list);
 	vfree(umem_odp->page_list);
-	kfree(umem_odp);
-	kfree(umem);
 }
 
 /*
@@ -524,7 +507,7 @@ static int ib_umem_odp_map_dma_single_pa
 		u64 access_mask,
 		unsigned long current_seq)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	struct ib_device *dev = umem->context->device;
 	dma_addr_t dma_addr;
 	int stored_page = 0;
@@ -610,7 +593,7 @@ int ib_umem_odp_map_dma_pages(struct ib_
 			      u64 bcnt, u64 access_mask,
 			      unsigned long current_seq)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	struct task_struct *owning_process  = NULL;
 	struct mm_struct   *owning_mm       = NULL;
 	struct page       **local_page_list = NULL;
@@ -726,7 +709,7 @@ EXPORT_SYMBOL(ib_umem_odp_map_dma_pages)
 void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
 				 u64 bound)
 {
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 	int idx;
 	u64 addr;
 	struct ib_device *dev = umem->context->device;
--- a/drivers/infiniband/core/umem_rbtree.c
+++ b/drivers/infiniband/core/umem_rbtree.c
@@ -50,7 +50,7 @@ static inline u64 node_start(struct umem
 	struct ib_umem_odp *umem_odp =
 			container_of(n, struct ib_umem_odp, interval_tree);
 
-	return ib_umem_start(umem_odp->umem);
+	return ib_umem_start(&umem_odp->umem);
 }
 
 /* Note that the representation of the intervals in the interval tree
@@ -63,7 +63,7 @@ static inline u64 node_last(struct umem_
 	struct ib_umem_odp *umem_odp =
 			container_of(n, struct ib_umem_odp, interval_tree);
 
-	return ib_umem_end(umem_odp->umem) - 1;
+	return ib_umem_end(&umem_odp->umem) - 1;
 }
 
 INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
--- a/drivers/infiniband/hw/mlx5/odp.c
+++ b/drivers/infiniband/hw/mlx5/odp.c
@@ -64,7 +64,7 @@ static int check_parent(struct ib_umem_o
 static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
 {
 	struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
-	struct ib_ucontext *ctx = odp->umem->context;
+	struct ib_ucontext *ctx = odp->umem.context;
 	struct rb_node *rb;
 
 	down_read(&ctx->umem_rwsem);
@@ -102,7 +102,7 @@ static struct ib_umem_odp *odp_lookup(st
 		if (!rb)
 			goto not_found;
 		odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
-		if (ib_umem_start(odp->umem) > start + length)
+		if (ib_umem_start(&odp->umem) > start + length)
 			goto not_found;
 	}
 not_found:
@@ -137,7 +137,7 @@ void mlx5_odp_populate_klm(struct mlx5_k
 	for (i = 0; i < nentries; i++, pklm++) {
 		pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
 		va = (offset + i) * MLX5_IMR_MTT_SIZE;
-		if (odp && odp->umem->address == va) {
+		if (odp && odp->umem.address == va) {
 			struct mlx5_ib_mr *mtt = odp->private;
 
 			pklm->key = cpu_to_be32(mtt->ibmr.lkey);
@@ -153,13 +153,13 @@ void mlx5_odp_populate_klm(struct mlx5_k
 static void mr_leaf_free_action(struct work_struct *work)
 {
 	struct ib_umem_odp *odp = container_of(work, struct ib_umem_odp, work);
-	int idx = ib_umem_start(odp->umem) >> MLX5_IMR_MTT_SHIFT;
+	int idx = ib_umem_start(&odp->umem) >> MLX5_IMR_MTT_SHIFT;
 	struct mlx5_ib_mr *mr = odp->private, *imr = mr->parent;
 
 	mr->parent = NULL;
 	synchronize_srcu(&mr->dev->mr_srcu);
 
-	ib_umem_release(odp->umem);
+	ib_umem_release(&odp->umem);
 	if (imr->live)
 		mlx5_ib_update_xlt(imr, idx, 1, 0,
 				   MLX5_IB_UPD_XLT_INDIRECT |
@@ -185,7 +185,7 @@ void mlx5_ib_invalidate_range(struct ib_
 		pr_err("invalidation called on NULL umem or non-ODP umem\n");
 		return;
 	}
-	umem = umem_odp->umem;
+	umem = &umem_odp->umem;
 
 	mr = umem_odp->private;
 
@@ -392,16 +392,16 @@ next_mr:
 			return ERR_CAST(odp);
 		}
 
-		mtt = implicit_mr_alloc(mr->ibmr.pd, odp->umem, 0,
+		mtt = implicit_mr_alloc(mr->ibmr.pd, &odp->umem, 0,
 					mr->access_flags);
 		if (IS_ERR(mtt)) {
 			mutex_unlock(&mr->umem->odp_data->umem_mutex);
-			ib_umem_release(odp->umem);
+			ib_umem_release(&odp->umem);
 			return ERR_CAST(mtt);
 		}
 
 		odp->private = mtt;
-		mtt->umem = odp->umem;
+		mtt->umem = &odp->umem;
 		mtt->mmkey.iova = addr;
 		mtt->parent = mr;
 		INIT_WORK(&odp->work, mr_leaf_free_action);
@@ -418,7 +418,7 @@ next_mr:
 	addr += MLX5_IMR_MTT_SIZE;
 	if (unlikely(addr < io_virt + bcnt)) {
 		odp = odp_next(odp);
-		if (odp && odp->umem->address != addr)
+		if (odp && odp->umem.address != addr)
 			odp = NULL;
 		goto next_mr;
 	}
@@ -465,7 +465,7 @@ static int mr_leaf_free(struct ib_umem_o
 			void *cookie)
 {
 	struct mlx5_ib_mr *mr = umem_odp->private, *imr = cookie;
-	struct ib_umem *umem = umem_odp->umem;
+	struct ib_umem *umem = &umem_odp->umem;
 
 	if (mr->parent != imr)
 		return 0;
@@ -518,7 +518,7 @@ static int pagefault_mr(struct mlx5_ib_d
 	}
 
 next_mr:
-	size = min_t(size_t, bcnt, ib_umem_end(odp->umem) - io_virt);
+	size = min_t(size_t, bcnt, ib_umem_end(&odp->umem) - io_virt);
 
 	page_shift = mr->umem->page_shift;
 	page_mask = ~(BIT(page_shift) - 1);
@@ -577,7 +577,7 @@ next_mr:
 
 		io_virt += size;
 		next = odp_next(odp);
-		if (unlikely(!next || next->umem->address != io_virt)) {
+		if (unlikely(!next || next->umem.address != io_virt)) {
 			mlx5_ib_dbg(dev, "next implicit leaf removed at 0x%llx. got %p\n",
 				    io_virt, next);
 			return -EAGAIN;
--- a/include/rdma/ib_umem_odp.h
+++ b/include/rdma/ib_umem_odp.h
@@ -43,6 +43,7 @@ struct umem_odp_node {
 };
 
 struct ib_umem_odp {
+	struct ib_umem umem;
 	/*
 	 * An array of the pages included in the on-demand paging umem.
 	 * Indices of pages that are currently not mapped into the device will
@@ -72,7 +73,6 @@ struct ib_umem_odp {
 	/* A linked list of umems that don't have private mmu notifier
 	 * counters yet. */
 	struct list_head no_private_counters;
-	struct ib_umem		*umem;
 
 	/* Tree tracking */
 	struct umem_odp_node	interval_tree;
@@ -84,13 +84,12 @@ struct ib_umem_odp {
 
 static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)
 {
-	return umem->odp_data;
+	return container_of(umem, struct ib_umem_odp, umem);
 }
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 
-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
-		    int access);
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
 				      unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
@@ -158,9 +157,7 @@ static inline int ib_umem_mmu_notifier_r
 
 #else /* CONFIG_INFINIBAND_ON_DEMAND_PAGING */
 
-static inline int ib_umem_odp_get(struct ib_ucontext *context,
-				  struct ib_umem *umem,
-				  int access)
+static inline int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
 	return -EINVAL;
 }