Blob Blame History Raw
From 2e66fcaab858ff6755685697060c2946af2bf672 Mon Sep 17 00:00:00 2001
From: Alex Deucher <alexander.deucher@amd.com>
Date: Fri, 22 Apr 2022 17:19:29 -0400
Subject: drm/amdgpu/psp: move shared buffer frees into single function
Git-commit: da40bf8f9376370b5bc2fda07aadaaddc308b1eb
Patch-mainline: v5.19-rc1
References: jsc#PED-1166 jsc#PED-1168 jsc#PED-1170 jsc#PED-1218 jsc#PED-1220 jsc#PED-1222 jsc#PED-1223 jsc#PED-1225

So we can properly clean up if any of the TAs or TMR fails
to properly initialize or terminate.  This avoids any
memory leaks in the error case.

Reviewed-by: Hawking Zhang <Hawking.Zhang@amd.com>
Signed-off-by: Alex Deucher <alexander.deucher@amd.com>
Acked-by: Patrik Jakobsson <pjakobsson@suse.de>
---
 drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c | 114 ++++++++++++------------
 1 file changed, 55 insertions(+), 59 deletions(-)

diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c
index 1ef2aba2ac3f..b1b6f5dd35dd 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c
@@ -153,6 +153,36 @@ static int psp_early_init(void *handle)
 	return 0;
 }
 
+static void psp_free_shared_bufs(struct psp_context *psp)
+{
+	void *tmr_buf;
+	void **pptr;
+
+	/* free TMR memory buffer */
+	pptr = amdgpu_sriov_vf(psp->adev) ? &tmr_buf : NULL;
+	amdgpu_bo_free_kernel(&psp->tmr_bo, &psp->tmr_mc_addr, pptr);
+
+	/* free xgmi shared memory */
+	psp_ta_free_shared_buf(&psp->xgmi_context.context.mem_context);
+
+	/* free ras shared memory */
+	psp_ta_free_shared_buf(&psp->ras_context.context.mem_context);
+
+	/* free hdcp shared memory */
+	psp_ta_free_shared_buf(&psp->hdcp_context.context.mem_context);
+
+	/* free dtm shared memory */
+	psp_ta_free_shared_buf(&psp->dtm_context.context.mem_context);
+
+	/* free rap shared memory */
+	psp_ta_free_shared_buf(&psp->rap_context.context.mem_context);
+
+	/* free securedisplay shared memory */
+	psp_ta_free_shared_buf(&psp->securedisplay_context.context.mem_context);
+
+
+}
+
 static void psp_memory_training_fini(struct psp_context *psp)
 {
 	struct psp_memory_training_context *ctx = &psp->mem_train_ctx;
@@ -747,17 +777,7 @@ static int psp_tmr_unload(struct psp_context *psp)
 
 static int psp_tmr_terminate(struct psp_context *psp)
 {
-	int ret;
-	void *tmr_buf;
-	void **pptr;
-
-	ret = psp_tmr_unload(psp);
-
-	/* free TMR memory buffer */
-	pptr = amdgpu_sriov_vf(psp->adev) ? &tmr_buf : NULL;
-	amdgpu_bo_free_kernel(&psp->tmr_bo, &psp->tmr_mc_addr, pptr);
-
-	return ret;
+	return psp_tmr_unload(psp);
 }
 
 int psp_get_fw_attestation_records_addr(struct psp_context *psp,
@@ -1102,9 +1122,6 @@ int psp_xgmi_terminate(struct psp_context *psp)
 
 	psp->xgmi_context.context.initialized = false;
 
-	/* free xgmi shared memory */
-	psp_ta_free_shared_buf(&psp->xgmi_context.context.mem_context);
-
 	return ret;
 }
 
@@ -1465,9 +1482,6 @@ int psp_ras_terminate(struct psp_context *psp)
 
 	psp->ras_context.context.initialized = false;
 
-	/* free ras shared memory */
-	psp_ta_free_shared_buf(&psp->ras_context.context.mem_context);
-
 	return ret;
 }
 
@@ -1650,23 +1664,13 @@ static int psp_hdcp_terminate(struct psp_context *psp)
 	if (amdgpu_sriov_vf(psp->adev))
 		return 0;
 
-	if (!psp->hdcp_context.context.initialized) {
-		if (psp->hdcp_context.context.mem_context.shared_buf) {
-			ret = 0;
-			goto out;
-		} else {
-			return 0;
-		}
-	}
+	if (!psp->hdcp_context.context.initialized)
+		return 0;
 
 	ret = psp_ta_unload(psp, &psp->hdcp_context.context);
 
 	psp->hdcp_context.context.initialized = false;
 
-out:
-	/* free hdcp shared memory */
-	psp_ta_free_shared_buf(&psp->hdcp_context.context.mem_context);
-
 	return ret;
 }
 // HDCP end
@@ -1727,23 +1731,13 @@ static int psp_dtm_terminate(struct psp_context *psp)
 	if (amdgpu_sriov_vf(psp->adev))
 		return 0;
 
-	if (!psp->dtm_context.context.initialized) {
-		if (psp->dtm_context.context.mem_context.shared_buf) {
-			ret = 0;
-			goto out;
-		} else {
-			return 0;
-		}
-	}
+	if (!psp->dtm_context.context.initialized)
+		return 0;
 
 	ret = psp_ta_unload(psp, &psp->dtm_context.context);
 
 	psp->dtm_context.context.initialized = false;
 
-out:
-	/* free dtm shared memory */
-	psp_ta_free_shared_buf(&psp->dtm_context.context.mem_context);
-
 	return ret;
 }
 // DTM end
@@ -1785,6 +1779,8 @@ static int psp_rap_initialize(struct psp_context *psp)
 	ret = psp_rap_invoke(psp, TA_CMD_RAP__INITIALIZE, &status);
 	if (ret || status != TA_RAP_STATUS__SUCCESS) {
 		psp_rap_terminate(psp);
+		/* free rap shared memory */
+		psp_ta_free_shared_buf(&psp->rap_context.context.mem_context);
 
 		dev_warn(psp->adev->dev, "RAP TA initialize fail (%d) status %d.\n",
 			 ret, status);
@@ -1806,9 +1802,6 @@ static int psp_rap_terminate(struct psp_context *psp)
 
 	psp->rap_context.context.initialized = false;
 
-	/* free rap shared memory */
-	psp_ta_free_shared_buf(&psp->rap_context.context.mem_context);
-
 	return ret;
 }
 
@@ -1889,6 +1882,8 @@ static int psp_securedisplay_initialize(struct psp_context *psp)
 	ret = psp_securedisplay_invoke(psp, TA_SECUREDISPLAY_COMMAND__QUERY_TA);
 	if (ret) {
 		psp_securedisplay_terminate(psp);
+		/* free securedisplay shared memory */
+		psp_ta_free_shared_buf(&psp->securedisplay_context.context.mem_context);
 		dev_err(psp->adev->dev, "SECUREDISPLAY TA initialize fail.\n");
 		return -EINVAL;
 	}
@@ -1919,9 +1914,6 @@ static int psp_securedisplay_terminate(struct psp_context *psp)
 
 	psp->securedisplay_context.context.initialized = false;
 
-	/* free securedisplay shared memory */
-	psp_ta_free_shared_buf(&psp->securedisplay_context.context.mem_context);
-
 	return ret;
 }
 
@@ -2524,16 +2516,18 @@ static int psp_hw_fini(void *handle)
 	}
 
 	psp_asd_terminate(psp);
-
 	psp_tmr_terminate(psp);
+
 	psp_ring_destroy(psp, PSP_RING_TYPE__KM);
 
+	psp_free_shared_bufs(psp);
+
 	return 0;
 }
 
 static int psp_suspend(void *handle)
 {
-	int ret;
+	int ret = 0;
 	struct amdgpu_device *adev = (struct amdgpu_device *)handle;
 	struct psp_context *psp = &adev->psp;
 
@@ -2542,7 +2536,7 @@ static int psp_suspend(void *handle)
 		ret = psp_xgmi_terminate(psp);
 		if (ret) {
 			DRM_ERROR("Failed to terminate xgmi ta\n");
-			return ret;
+			goto out;
 		}
 	}
 
@@ -2550,49 +2544,51 @@ static int psp_suspend(void *handle)
 		ret = psp_ras_terminate(psp);
 		if (ret) {
 			DRM_ERROR("Failed to terminate ras ta\n");
-			return ret;
+			goto out;
 		}
 		ret = psp_hdcp_terminate(psp);
 		if (ret) {
 			DRM_ERROR("Failed to terminate hdcp ta\n");
-			return ret;
+			goto out;
 		}
 		ret = psp_dtm_terminate(psp);
 		if (ret) {
 			DRM_ERROR("Failed to terminate dtm ta\n");
-			return ret;
+			goto out;
 		}
 		ret = psp_rap_terminate(psp);
 		if (ret) {
 			DRM_ERROR("Failed to terminate rap ta\n");
-			return ret;
+			goto out;
 		}
 		ret = psp_securedisplay_terminate(psp);
 		if (ret) {
 			DRM_ERROR("Failed to terminate securedisplay ta\n");
-			return ret;
+			goto out;
 		}
 	}
 
 	ret = psp_asd_terminate(psp);
 	if (ret) {
 		DRM_ERROR("Failed to terminate asd\n");
-		return ret;
+		goto out;
 	}
 
 	ret = psp_tmr_terminate(psp);
 	if (ret) {
 		DRM_ERROR("Failed to terminate tmr\n");
-		return ret;
+		goto out;
 	}
 
 	ret = psp_ring_stop(psp, PSP_RING_TYPE__KM);
 	if (ret) {
 		DRM_ERROR("PSP ring stop failed\n");
-		return ret;
 	}
 
-	return 0;
+out:
+	psp_free_shared_bufs(psp);
+
+	return ret;
 }
 
 static int psp_resume(void *handle)
-- 
2.38.1