Blob Blame History Raw
From: Leon Romanovsky <leonro@mellanox.com>
Date: Mon, 27 Apr 2020 18:46:22 +0300
Subject: RDMA/mlx5: Combine copy of create QP command in RSS RAW QP
Patch-mainline: v5.8-rc1
Git-commit: 5ce0592b0ee56e41f1a4a164ac2f54dbfbbf5e49
References: jsc#SLE-15175

Change the create QP flow to handle all copy_from_user() operations in
one place.

Link: https://lore.kernel.org/r/20200427154636.381474-23-leon@kernel.org
Reviewed-by: Maor Gottlieb <maorg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 drivers/infiniband/hw/mlx5/qp.c |  156 +++++++++++++++++++++-------------------
 1 file changed, 82 insertions(+), 74 deletions(-)

--- a/drivers/infiniband/hw/mlx5/qp.c
+++ b/drivers/infiniband/hw/mlx5/qp.c
@@ -1624,6 +1624,7 @@ static void destroy_rss_raw_qp_tir(struc
 
 static int create_rss_raw_qp_tir(struct ib_pd *pd, struct mlx5_ib_qp *qp,
 				 struct ib_qp_init_attr *init_attr,
+				 struct mlx5_ib_create_qp_rss *ucmd,
 				 struct ib_udata *udata)
 {
 	struct mlx5_ib_ucontext *mucontext = rdma_udata_to_drv_context(
@@ -1641,46 +1642,26 @@ static int create_rss_raw_qp_tir(struct
 	u32 outer_l4;
 	size_t min_resp_len;
 	u32 tdn = mucontext->tdn;
-	struct mlx5_ib_create_qp_rss ucmd = {};
-	size_t required_cmd_sz;
 	u8 lb_flag = 0;
 
 	min_resp_len = offsetof(typeof(resp), bfreg_index) + sizeof(resp.bfreg_index);
 	if (udata->outlen < min_resp_len)
 		return -EINVAL;
 
-	required_cmd_sz = offsetof(typeof(ucmd), flags) + sizeof(ucmd.flags);
-	if (udata->inlen < required_cmd_sz) {
-		mlx5_ib_dbg(dev, "invalid inlen\n");
-		return -EINVAL;
-	}
-
-	if (udata->inlen > sizeof(ucmd) &&
-	    !ib_is_udata_cleared(udata, sizeof(ucmd),
-				 udata->inlen - sizeof(ucmd))) {
-		mlx5_ib_dbg(dev, "inlen is not supported\n");
-		return -EOPNOTSUPP;
-	}
-
-	if (ib_copy_from_udata(&ucmd, udata, min(sizeof(ucmd), udata->inlen))) {
-		mlx5_ib_dbg(dev, "copy failed\n");
-		return -EFAULT;
-	}
-
-	if (ucmd.comp_mask) {
+	if (ucmd->comp_mask) {
 		mlx5_ib_dbg(dev, "invalid comp mask\n");
 		return -EOPNOTSUPP;
 	}
 
-	if (ucmd.flags & ~(MLX5_QP_FLAG_TUNNEL_OFFLOADS |
+	if (ucmd->flags & ~(MLX5_QP_FLAG_TUNNEL_OFFLOADS |
 			   MLX5_QP_FLAG_TIR_ALLOW_SELF_LB_UC |
 			   MLX5_QP_FLAG_TIR_ALLOW_SELF_LB_MC)) {
 		mlx5_ib_dbg(dev, "invalid flags\n");
 		return -EOPNOTSUPP;
 	}
 
-	if (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_INNER &&
-	    !(ucmd.flags & MLX5_QP_FLAG_TUNNEL_OFFLOADS)) {
+	if (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_INNER &&
+	    !(ucmd->flags & MLX5_QP_FLAG_TUNNEL_OFFLOADS)) {
 		mlx5_ib_dbg(dev, "Tunnel offloads must be set for inner RSS\n");
 		return -EOPNOTSUPP;
 	}
@@ -1717,29 +1698,29 @@ static int create_rss_raw_qp_tir(struct
 
 	hfso = MLX5_ADDR_OF(tirc, tirc, rx_hash_field_selector_outer);
 
-	if (ucmd.flags & MLX5_QP_FLAG_TUNNEL_OFFLOADS)
+	if (ucmd->flags & MLX5_QP_FLAG_TUNNEL_OFFLOADS)
 		MLX5_SET(tirc, tirc, tunneled_offload_en, 1);
 
 	MLX5_SET(tirc, tirc, self_lb_block, lb_flag);
 
-	if (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_INNER)
+	if (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_INNER)
 		hfso = MLX5_ADDR_OF(tirc, tirc, rx_hash_field_selector_inner);
 	else
 		hfso = MLX5_ADDR_OF(tirc, tirc, rx_hash_field_selector_outer);
 
-	switch (ucmd.rx_hash_function) {
+	switch (ucmd->rx_hash_function) {
 	case MLX5_RX_HASH_FUNC_TOEPLITZ:
 	{
 		void *rss_key = MLX5_ADDR_OF(tirc, tirc, rx_hash_toeplitz_key);
 		size_t len = MLX5_FLD_SZ_BYTES(tirc, rx_hash_toeplitz_key);
 
-		if (len != ucmd.rx_key_len) {
+		if (len != ucmd->rx_key_len) {
 			err = -EINVAL;
 			goto err;
 		}
 
 		MLX5_SET(tirc, tirc, rx_hash_fn, MLX5_RX_HASH_FN_TOEPLITZ);
-		memcpy(rss_key, ucmd.rx_hash_key, len);
+		memcpy(rss_key, ucmd->rx_hash_key, len);
 		break;
 	}
 	default:
@@ -1747,7 +1728,7 @@ static int create_rss_raw_qp_tir(struct
 		goto err;
 	}
 
-	if (!ucmd.rx_hash_fields_mask) {
+	if (!ucmd->rx_hash_fields_mask) {
 		/* special case when this TIR serves as steering entry without hashing */
 		if (!init_attr->rwq_ind_tbl->log_ind_tbl_size)
 			goto create_tir;
@@ -1755,29 +1736,31 @@ static int create_rss_raw_qp_tir(struct
 		goto err;
 	}
 
-	if (((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV4) ||
-	     (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV4)) &&
-	     ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV6) ||
-	     (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV6))) {
+	if (((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV4) ||
+	     (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV4)) &&
+	     ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV6) ||
+	     (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV6))) {
 		err = -EINVAL;
 		goto err;
 	}
 
 	/* If none of IPV4 & IPV6 SRC/DST was set - this bit field is ignored */
-	if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV4) ||
-	    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV4))
+	if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV4) ||
+	    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV4))
 		MLX5_SET(rx_hash_field_select, hfso, l3_prot_type,
 			 MLX5_L3_PROT_TYPE_IPV4);
-	else if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV6) ||
-		 (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV6))
+	else if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV6) ||
+		 (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV6))
 		MLX5_SET(rx_hash_field_select, hfso, l3_prot_type,
 			 MLX5_L3_PROT_TYPE_IPV6);
 
-	outer_l4 = ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_TCP) ||
-		    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_TCP)) << 0 |
-		   ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_UDP) ||
-		    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_UDP)) << 1 |
-		   (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_IPSEC_SPI) << 2;
+	outer_l4 = ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_TCP) ||
+		    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_TCP))
+			   << 0 |
+		   ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_UDP) ||
+		    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_UDP))
+			   << 1 |
+		   (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_IPSEC_SPI) << 2;
 
 	/* Check that only one l4 protocol is set */
 	if (outer_l4 & (outer_l4 - 1)) {
@@ -1786,32 +1769,32 @@ static int create_rss_raw_qp_tir(struct
 	}
 
 	/* If none of TCP & UDP SRC/DST was set - this bit field is ignored */
-	if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_TCP) ||
-	    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_TCP))
+	if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_TCP) ||
+	    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_TCP))
 		MLX5_SET(rx_hash_field_select, hfso, l4_prot_type,
 			 MLX5_L4_PROT_TYPE_TCP);
-	else if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_UDP) ||
-		 (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_UDP))
+	else if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_UDP) ||
+		 (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_UDP))
 		MLX5_SET(rx_hash_field_select, hfso, l4_prot_type,
 			 MLX5_L4_PROT_TYPE_UDP);
 
-	if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV4) ||
-	    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV6))
+	if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV4) ||
+	    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_IPV6))
 		selected_fields |= MLX5_HASH_FIELD_SEL_SRC_IP;
 
-	if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV4) ||
-	    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV6))
+	if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV4) ||
+	    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_IPV6))
 		selected_fields |= MLX5_HASH_FIELD_SEL_DST_IP;
 
-	if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_TCP) ||
-	    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_UDP))
+	if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_TCP) ||
+	    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_SRC_PORT_UDP))
 		selected_fields |= MLX5_HASH_FIELD_SEL_L4_SPORT;
 
-	if ((ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_TCP) ||
-	    (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_UDP))
+	if ((ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_TCP) ||
+	    (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_DST_PORT_UDP))
 		selected_fields |= MLX5_HASH_FIELD_SEL_L4_DPORT;
 
-	if (ucmd.rx_hash_fields_mask & MLX5_RX_HASH_IPSEC_SPI)
+	if (ucmd->rx_hash_fields_mask & MLX5_RX_HASH_IPSEC_SPI)
 		selected_fields |= MLX5_HASH_FIELD_SEL_IPSEC_SPI;
 
 	MLX5_SET(rx_hash_field_select, hfso, selected_fields, selected_fields);
@@ -2513,11 +2496,16 @@ static void process_vendor_flag(struct m
 }
 
 static int process_vendor_flags(struct mlx5_ib_dev *dev, struct mlx5_ib_qp *qp,
-				struct mlx5_ib_create_qp *ucmd)
+				void *ucmd, struct ib_qp_init_attr *attr)
 {
 	struct mlx5_core_dev *mdev = dev->mdev;
-	int flags = ucmd->flags;
 	bool cond;
+	int flags;
+
+	if (attr->rwq_ind_tbl)
+		flags = ((struct mlx5_ib_create_qp_rss *)ucmd)->flags;
+	else
+		flags = ((struct mlx5_ib_create_qp *)ucmd)->flags;
 
 	switch (flags & (MLX5_QP_FLAG_TYPE_DCT | MLX5_QP_FLAG_TYPE_DCI)) {
 	case MLX5_QP_FLAG_TYPE_DCI:
@@ -2657,21 +2645,32 @@ static size_t process_udata_size(struct
 				 struct ib_udata *udata)
 {
 	size_t ucmd = sizeof(struct mlx5_ib_create_qp);
+	size_t inlen = udata->inlen;
 
 	if (attr->qp_type == IB_QPT_DRIVER)
-		return (udata->inlen < ucmd) ? 0 : ucmd;
+		return (inlen < ucmd) ? 0 : ucmd;
+
+	if (!attr->rwq_ind_tbl)
+		return ucmd;
+
+	if (inlen < offsetofend(struct mlx5_ib_create_qp_rss, flags))
+		return 0;
+
+	ucmd = sizeof(struct mlx5_ib_create_qp_rss);
+	if (inlen > ucmd && !ib_is_udata_cleared(udata, ucmd, inlen - ucmd))
+		return 0;
 
-	return ucmd;
+	return min(ucmd, inlen);
 }
 
 static int create_raw_qp(struct ib_pd *pd, struct mlx5_ib_qp *qp,
-			 struct ib_qp_init_attr *attr,
-			 struct mlx5_ib_create_qp *ucmd, struct ib_udata *udata)
+			 struct ib_qp_init_attr *attr, void *ucmd,
+			 struct ib_udata *udata)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 
 	if (attr->rwq_ind_tbl)
-		return create_rss_raw_qp_tir(pd, qp, attr, udata);
+		return create_rss_raw_qp_tir(pd, qp, attr, ucmd, udata);
 
 	return create_qp_common(dev, pd, attr, ucmd, udata, qp);
 }
@@ -2707,10 +2706,10 @@ struct ib_qp *mlx5_ib_create_qp(struct i
 				struct ib_qp_init_attr *init_attr,
 				struct ib_udata *udata)
 {
-	struct mlx5_ib_create_qp ucmd = {};
 	struct mlx5_ib_dev *dev;
 	struct mlx5_ib_qp *qp;
 	enum ib_qp_type type;
+	void *ucmd = NULL;
 	u16 xrcdn = 0;
 	int err;
 
@@ -2731,25 +2730,31 @@ struct ib_qp *mlx5_ib_create_qp(struct i
 	if (init_attr->qp_type == IB_QPT_GSI)
 		return mlx5_ib_gsi_create_qp(pd, init_attr);
 
-	if (udata && !init_attr->rwq_ind_tbl) {
+	if (udata) {
 		size_t inlen =
 			process_udata_size(init_attr, udata);
 
 		if (!inlen)
 			return ERR_PTR(-EINVAL);
 
-		err = ib_copy_from_udata(&ucmd, udata, inlen);
+		ucmd = kzalloc(inlen, GFP_KERNEL);
+		if (!ucmd)
+			return ERR_PTR(-ENOMEM);
+
+		err = ib_copy_from_udata(ucmd, udata, inlen);
 		if (err)
-			return ERR_PTR(err);
+			goto free_ucmd;
 	}
 
 	qp = kzalloc(sizeof(*qp), GFP_KERNEL);
-	if (!qp)
-		return ERR_PTR(-ENOMEM);
+	if (!qp) {
+		err = -ENOMEM;
+		goto free_ucmd;
+	}
 
 	qp->type = type;
 	if (udata) {
-		err = process_vendor_flags(dev, qp, &ucmd);
+		err = process_vendor_flags(dev, qp, ucmd, init_attr);
 		if (err)
 			goto free_qp;
 	}
@@ -2766,20 +2771,21 @@ struct ib_qp *mlx5_ib_create_qp(struct i
 
 	switch (qp->type) {
 	case IB_QPT_RAW_PACKET:
-		err = create_raw_qp(pd, qp, init_attr, &ucmd, udata);
+		err = create_raw_qp(pd, qp, init_attr, ucmd, udata);
 		break;
 	case MLX5_IB_QPT_DCT:
-		err = create_dct(pd, qp, init_attr, &ucmd, udata);
+		err = create_dct(pd, qp, init_attr, ucmd, udata);
 		break;
 	default:
-		err = create_qp_common(dev, pd, init_attr,
-				       (udata) ? &ucmd : NULL, udata, qp);
+		err = create_qp_common(dev, pd, init_attr, ucmd, udata, qp);
 	}
 	if (err) {
 		mlx5_ib_dbg(dev, "create_qp_common failed\n");
 		goto free_qp;
 	}
 
+	kfree(ucmd);
+
 	if (is_qp0(init_attr->qp_type))
 		qp->ibqp.qp_num = 0;
 	else if (is_qp1(init_attr->qp_type))
@@ -2793,6 +2799,8 @@ struct ib_qp *mlx5_ib_create_qp(struct i
 
 free_qp:
 	kfree(qp);
+free_ucmd:
+	kfree(ucmd);
 	return ERR_PTR(err);
 }