summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--drivers/iommu/iommufd/device.c3
-rw-r--r--drivers/iommu/iommufd/iommufd_private.h3
-rw-r--r--drivers/iommu/iommufd/main.c4
3 files changed, 7 insertions, 3 deletions
diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 65fbd098f9e9..4c842368289f 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -711,6 +711,8 @@ iommufd_hw_pagetable_detach(struct iommufd_device *idev, ioasid_t pasid)
iopt_remove_reserved_iova(&hwpt_paging->ioas->iopt, idev->dev);
mutex_unlock(&igroup->lock);
+ iommufd_hw_pagetable_put(idev->ictx, hwpt);
+
/* Caller must destroy hwpt */
return hwpt;
}
@@ -1057,7 +1059,6 @@ void iommufd_device_detach(struct iommufd_device *idev, ioasid_t pasid)
hwpt = iommufd_hw_pagetable_detach(idev, pasid);
if (!hwpt)
return;
- iommufd_hw_pagetable_put(idev->ictx, hwpt);
refcount_dec(&idev->obj.users);
}
EXPORT_SYMBOL_NS_GPL(iommufd_device_detach, "IOMMUFD");
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 0da2a81eedfa..627f9b78483a 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -454,9 +454,8 @@ static inline void iommufd_hw_pagetable_put(struct iommufd_ctx *ictx,
if (hwpt->obj.type == IOMMUFD_OBJ_HWPT_PAGING) {
struct iommufd_hwpt_paging *hwpt_paging = to_hwpt_paging(hwpt);
- lockdep_assert_not_held(&hwpt_paging->ioas->mutex);
-
if (hwpt_paging->auto_domain) {
+ lockdep_assert_not_held(&hwpt_paging->ioas->mutex);
iommufd_object_put_and_try_destroy(ictx, &hwpt->obj);
return;
}
diff --git a/drivers/iommu/iommufd/main.c b/drivers/iommu/iommufd/main.c
index 88be2e157245..ce775fbbae94 100644
--- a/drivers/iommu/iommufd/main.c
+++ b/drivers/iommu/iommufd/main.c
@@ -122,6 +122,10 @@ void iommufd_object_abort(struct iommufd_ctx *ictx, struct iommufd_object *obj)
old = xas_store(&xas, NULL);
xa_unlock(&ictx->objects);
WARN_ON(old != XA_ZERO_ENTRY);
+
+ if (WARN_ON(!refcount_dec_and_test(&obj->users)))
+ return;
+
kfree(obj);
}