diff --git a/drivers/iommu/iommufd/io_pagetable.c b/drivers/iommu/iommufd/io_pagetable.c index 38c5fdc6c821..a120c2ae942a 100644 --- a/drivers/iommu/iommufd/io_pagetable.c +++ b/drivers/iommu/iommufd/io_pagetable.c @@ -973,6 +973,7 @@ static void iopt_unfill_domain(struct io_pagetable *iopt, if (iopt_is_dmabuf(pages)) { if (!iopt_dmabuf_revoked(pages)) iopt_area_unmap_domain(area, domain); + iopt_dmabuf_untrack_domain(pages, area, domain); } mutex_unlock(&pages->mutex); @@ -994,6 +995,8 @@ static void iopt_unfill_domain(struct io_pagetable *iopt, WARN_ON(area->storage_domain != domain); area->storage_domain = NULL; iopt_area_unfill_domain(area, pages, domain); + if (iopt_is_dmabuf(pages)) + iopt_dmabuf_untrack_domain(pages, area, domain); mutex_unlock(&pages->mutex); } } @@ -1023,10 +1026,16 @@ static int iopt_fill_domain(struct io_pagetable *iopt, if (!pages) continue; - mutex_lock(&pages->mutex); + guard(mutex)(&pages->mutex); + if (iopt_is_dmabuf(pages)) { + rc = iopt_dmabuf_track_domain(pages, area, domain); + if (rc) + goto out_unfill; + } rc = iopt_area_fill_domain(area, domain); if (rc) { - mutex_unlock(&pages->mutex); + if (iopt_is_dmabuf(pages)) + iopt_dmabuf_untrack_domain(pages, area, domain); goto out_unfill; } if (!area->storage_domain) { @@ -1035,7 +1044,6 @@ static int iopt_fill_domain(struct io_pagetable *iopt, interval_tree_insert(&area->pages_node, &pages->domains_itree); } - mutex_unlock(&pages->mutex); } return 0; @@ -1056,6 +1064,8 @@ static int iopt_fill_domain(struct io_pagetable *iopt, area->storage_domain = NULL; } iopt_area_unfill_domain(area, pages, domain); + if (iopt_is_dmabuf(pages)) + iopt_dmabuf_untrack_domain(pages, area, domain); mutex_unlock(&pages->mutex); } return rc; diff --git a/drivers/iommu/iommufd/io_pagetable.h b/drivers/iommu/iommufd/io_pagetable.h index 892daf4b1f1e..8f8d583e0243 100644 --- a/drivers/iommu/iommufd/io_pagetable.h +++ b/drivers/iommu/iommufd/io_pagetable.h @@ -70,6 +70,16 @@ void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages, void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain); +int iopt_dmabuf_track_domain(struct iopt_pages *pages, struct iopt_area *area, + struct iommu_domain *domain); +void iopt_dmabuf_untrack_domain(struct iopt_pages *pages, + struct iopt_area *area, + struct iommu_domain *domain); +int iopt_dmabuf_track_all_domains(struct iopt_area *area, + struct iopt_pages *pages); +void iopt_dmabuf_untrack_all_domains(struct iopt_area *area, + struct iopt_pages *pages); + static inline unsigned long iopt_area_index(struct iopt_area *area) { return area->pages_node.start; @@ -184,11 +194,18 @@ enum iopt_address_type { IOPT_ADDRESS_DMABUF, }; +struct iopt_pages_dmabuf_track { + struct iommu_domain *domain; + struct iopt_area *area; + struct list_head elm; +}; + struct iopt_pages_dmabuf { struct dma_buf_attachment *attach; struct dma_buf_phys_vec phys; /* Always PAGE_SIZE aligned */ unsigned long start; + struct list_head tracker; }; /* diff --git a/drivers/iommu/iommufd/pages.c b/drivers/iommu/iommufd/pages.c index e8bfd734fb69..a6c874d5eda8 100644 --- a/drivers/iommu/iommufd/pages.c +++ b/drivers/iommu/iommufd/pages.c @@ -1366,8 +1366,19 @@ struct iopt_pages *iopt_alloc_file_pages(struct file *file, unsigned long start, static void iopt_revoke_notify(struct dma_buf_attachment *attach) { struct iopt_pages *pages = attach->importer_priv; + struct iopt_pages_dmabuf_track *track; guard(mutex)(&pages->mutex); + if (iopt_dmabuf_revoked(pages)) + return; + + list_for_each_entry(track, &pages->dmabuf.tracker, elm) { + struct iopt_area *area = track->area; + + iopt_area_unmap_domain_range(area, track->domain, + iopt_area_index(area), + iopt_area_last_index(area)); + } pages->dmabuf.phys.len = 0; } @@ -1468,6 +1479,7 @@ struct iopt_pages *iopt_alloc_dmabuf_pages(struct iommufd_ctx *ictx, pages->account_mode = IOPT_PAGES_ACCOUNT_NONE; pages->type = IOPT_ADDRESS_DMABUF; pages->dmabuf.start = start - start_byte; + INIT_LIST_HEAD(&pages->dmabuf.tracker); rc = iopt_map_dmabuf(ictx, pages, dmabuf); if (rc) { @@ -1478,6 +1490,86 @@ struct iopt_pages *iopt_alloc_dmabuf_pages(struct iommufd_ctx *ictx, return pages; } +int iopt_dmabuf_track_domain(struct iopt_pages *pages, struct iopt_area *area, + struct iommu_domain *domain) +{ + struct iopt_pages_dmabuf_track *track; + + lockdep_assert_held(&pages->mutex); + if (WARN_ON(!iopt_is_dmabuf(pages))) + return -EINVAL; + + list_for_each_entry(track, &pages->dmabuf.tracker, elm) + if (WARN_ON(track->domain == domain && track->area == area)) + return -EINVAL; + + track = kzalloc(sizeof(*track), GFP_KERNEL); + if (!track) + return -ENOMEM; + track->domain = domain; + track->area = area; + list_add_tail(&track->elm, &pages->dmabuf.tracker); + + return 0; +} + +void iopt_dmabuf_untrack_domain(struct iopt_pages *pages, + struct iopt_area *area, + struct iommu_domain *domain) +{ + struct iopt_pages_dmabuf_track *track; + + lockdep_assert_held(&pages->mutex); + WARN_ON(!iopt_is_dmabuf(pages)); + + list_for_each_entry(track, &pages->dmabuf.tracker, elm) { + if (track->domain == domain && track->area == area) { + list_del(&track->elm); + kfree(track); + return; + } + } + WARN_ON(true); +} + +int iopt_dmabuf_track_all_domains(struct iopt_area *area, + struct iopt_pages *pages) +{ + struct iopt_pages_dmabuf_track *track; + struct iommu_domain *domain; + unsigned long index; + int rc; + + list_for_each_entry(track, &pages->dmabuf.tracker, elm) + if (WARN_ON(track->area == area)) + return -EINVAL; + + xa_for_each(&area->iopt->domains, index, domain) { + rc = iopt_dmabuf_track_domain(pages, area, domain); + if (rc) + goto err_untrack; + } + return 0; +err_untrack: + iopt_dmabuf_untrack_all_domains(area, pages); + return rc; +} + +void iopt_dmabuf_untrack_all_domains(struct iopt_area *area, + struct iopt_pages *pages) +{ + struct iopt_pages_dmabuf_track *track; + struct iopt_pages_dmabuf_track *tmp; + + list_for_each_entry_safe(track, tmp, &pages->dmabuf.tracker, + elm) { + if (track->area == area) { + list_del(&track->elm); + kfree(track); + } + } +} + void iopt_release_pages(struct kref *kref) { struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref); @@ -1495,6 +1587,7 @@ void iopt_release_pages(struct kref *kref) dma_buf_detach(dmabuf, pages->dmabuf.attach); dma_buf_put(dmabuf); + WARN_ON(!list_empty(&pages->dmabuf.tracker)); } else if (pages->type == IOPT_ADDRESS_FILE) { fput(pages->file); } @@ -1735,11 +1828,17 @@ int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages) return 0; mutex_lock(&pages->mutex); + if (iopt_is_dmabuf(pages)) { + rc = iopt_dmabuf_track_all_domains(area, pages); + if (rc) + goto out_unlock; + } + if (!iopt_dmabuf_revoked(pages)) { rc = pfn_reader_first(&pfns, pages, iopt_area_index(area), iopt_area_last_index(area)); if (rc) - goto out_unlock; + goto out_untrack; while (!pfn_reader_done(&pfns)) { done_first_end_index = pfns.batch_end_index; @@ -1794,6 +1893,9 @@ int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages) } } pfn_reader_destroy(&pfns); +out_untrack: + if (iopt_is_dmabuf(pages)) + iopt_dmabuf_untrack_all_domains(area, pages); out_unlock: mutex_unlock(&pages->mutex); return rc; @@ -1833,6 +1935,8 @@ void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages) WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb)); interval_tree_remove(&area->pages_node, &pages->domains_itree); iopt_area_unfill_domain(area, pages, area->storage_domain); + if (iopt_is_dmabuf(pages)) + iopt_dmabuf_untrack_all_domains(area, pages); area->storage_domain = NULL; out_unlock: mutex_unlock(&pages->mutex);