vfio/type1: handle DMA map/unmap up to the addressable limit

Before this commit, it was possible to create end-of-address-space
mappings, but unmapping them via VFIO_IOMMU_UNMAP_DMA, replaying them
for newly added iommu domains, and querying their dirty pages via
VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP were broken by comparisons
against (iova + size) expressions, which overflow to zero.
Additionally, there appears to be a page pinning leak in the
vfio_iommu_type1_release() path, since vfio_unmap_unpin()'s loop body
where unmap_unpin_*() are called will never be entered due to overflow
of (iova + size) to zero.
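
As an illustration (not code from this patch; the values and the addr
variable below are hypothetical), consider a mapping that ends exactly
at the top of a 64-bit DMA address space. The exclusive end wraps to
zero, so exclusive-end comparisons and loop bounds silently skip the
whole range:

	/* hypothetical mapping covering the last 64 KiB of the space */
	dma_addr_t iova = 0xffffffffffff0000;
	size_t size = 0x10000;
	dma_addr_t end = iova + size;	/* wraps to 0 */

	/* never true for any unsigned addr, so the range is never matched */
	if (addr < iova + size)
		...;

	/* never entered, so nothing is unmapped or unpinned */
	while (iova < end)
		...;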

This commit handles DMA map/unmap operations up to the addressable
limit by comparing against inclusive end-of-range limits and by
changing iteration to perform relative traversals across range sizes
rather than absolute traversals across addresses.
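
Concretely, the resulting patterns look roughly like the following
simplified sketch (not a verbatim excerpt from the patch):

	/* compare against an inclusive last address ... */
	/* callers guarantee size != 0, see the WARN_ON(!size) added below */
	if (start + size - 1 < dma->iova)	/* was: start + size <= dma->iova */
		...;

	/*
	 * ... and walk a relative offset within the range, so the loop
	 * bound never needs to compute iova + size:
	 */
	size_t pos = 0;

	while (pos < dma->size) {
		dma_addr_t iova = dma->iova + pos;
		...
		pos += PAGE_SIZE;
	}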

vfio_link_dma() inserts a zero-sized vfio_dma into the rb-tree, and is
only used for that purpose, so discard the size from consideration for
the insertion point.
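
The net effect on the insertion comparison, mirroring the
vfio_link_dma() hunk below:

	WARN_ON(new->size != 0);
	...
	if (new->iova <= dma->iova)	/* was: new->iova + new->size <= dma->iova */
		link = &(*link)->rb_left;
	else
		link = &(*link)->rb_right;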

Tested-by: Alejandro Jimenez <alejandro.j.jimenez@oracle.com>
Fixes: 73fa0d10d0 ("vfio: Type1 IOMMU implementation")
Reviewed-by: Jason Gunthorpe <jgg@nvidia.com>
Reviewed-by: Alejandro Jimenez <alejandro.j.jimenez@oracle.com>
Signed-off-by: Alex Mastro <amastro@fb.com>
Link: https://lore.kernel.org/r/20251028-fix-unmap-v6-3-2542b96bcc8e@fb.com
Signed-off-by: Alex Williamson <alex@shazbot.org>

 1 file changed, 42 insertions(+), 35 deletions(-)

--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c

@@ -168,12 +168,14 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 {
 	struct rb_node *node = iommu->dma_list.rb_node;
 
+	WARN_ON(!size);
+
 	while (node) {
 		struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 
-		if (start + size <= dma->iova)
+		if (start + size - 1 < dma->iova)
 			node = node->rb_left;
-		else if (start >= dma->iova + dma->size)
+		else if (start > dma->iova + dma->size - 1)
 			node = node->rb_right;
 		else
 			return dma;
@@ -183,16 +185,19 @@ static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
 }
 
 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
-						dma_addr_t start, size_t size)
+						dma_addr_t start,
+						dma_addr_t end)
 {
 	struct rb_node *res = NULL;
 	struct rb_node *node = iommu->dma_list.rb_node;
 	struct vfio_dma *dma_res = NULL;
 
+	WARN_ON(end < start);
+
 	while (node) {
 		struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 
-		if (start < dma->iova + dma->size) {
+		if (start <= dma->iova + dma->size - 1) {
 			res = node;
 			dma_res = dma;
 			if (start >= dma->iova)
@@ -202,7 +207,7 @@ static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
 			node = node->rb_right;
 		}
 	}
-	if (res && size && dma_res->iova >= start + size)
+	if (res && dma_res->iova > end)
 		res = NULL;
 	return res;
 }
@@ -212,11 +217,13 @@ static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
 	struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
 	struct vfio_dma *dma;
 
+	WARN_ON(new->size != 0);
+
 	while (*link) {
 		parent = *link;
 		dma = rb_entry(parent, struct vfio_dma, node);
 
-		if (new->iova + new->size <= dma->iova)
+		if (new->iova <= dma->iova)
 			link = &(*link)->rb_left;
 		else
 			link = &(*link)->rb_right;
@@ -1141,12 +1148,12 @@ static size_t unmap_unpin_slow(struct vfio_domain *domain,
 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 			     bool do_accounting)
 {
-	dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
 	struct vfio_domain *domain, *d;
 	LIST_HEAD(unmapped_region_list);
 	struct iommu_iotlb_gather iotlb_gather;
 	int unmapped_region_cnt = 0;
 	long unlocked = 0;
+	size_t pos = 0;
 
 	if (!dma->size)
 		return 0;
@@ -1170,13 +1177,14 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 	}
 
 	iommu_iotlb_gather_init(&iotlb_gather);
-	while (iova < end) {
+	while (pos < dma->size) {
 		size_t unmapped, len;
 		phys_addr_t phys, next;
+		dma_addr_t iova = dma->iova + pos;
 
 		phys = iommu_iova_to_phys(domain->domain, iova);
 		if (WARN_ON(!phys)) {
-			iova += PAGE_SIZE;
+			pos += PAGE_SIZE;
 			continue;
 		}
@@ -1185,7 +1193,7 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 		 * may require hardware cache flushing, try to find the
 		 * largest contiguous physical memory chunk to unmap.
 		 */
-		for (len = PAGE_SIZE; iova + len < end; len += PAGE_SIZE) {
+		for (len = PAGE_SIZE; pos + len < dma->size; len += PAGE_SIZE) {
 			next = iommu_iova_to_phys(domain->domain, iova + len);
 			if (next != phys + len)
 				break;
@@ -1206,7 +1214,7 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 				break;
 		}
 
-		iova += unmapped;
+		pos += unmapped;
 	}
 
 	dma->iommu_mapped = false;
@@ -1298,7 +1306,7 @@ static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
 }
 
 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
-				  dma_addr_t iova, size_t size, size_t pgsize)
+				  dma_addr_t iova, dma_addr_t iova_end, size_t pgsize)
 {
 	struct vfio_dma *dma;
 	struct rb_node *n;
@@ -1315,8 +1323,8 @@ static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
 	if (dma && dma->iova != iova)
 		return -EINVAL;
 
-	dma = vfio_find_dma(iommu, iova + size - 1, 0);
-	if (dma && dma->iova + dma->size != iova + size)
+	dma = vfio_find_dma(iommu, iova_end, 1);
+	if (dma && dma->iova + dma->size - 1 != iova_end)
 		return -EINVAL;
 
 	for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
@@ -1325,7 +1333,7 @@ static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
 		if (dma->iova < iova)
 			continue;
 
-		if (dma->iova > iova + size - 1)
+		if (dma->iova > iova_end)
 			break;
 
 		ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
@@ -1418,7 +1426,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 	if (unmap_all) {
 		if (iova || size)
 			goto unlock;
-		size = SIZE_MAX;
+		iova_end = ~(dma_addr_t)0;
 	} else {
 		if (!size || size & (pgsize - 1))
 			goto unlock;
@@ -1473,17 +1481,17 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 		if (dma && dma->iova != iova)
 			goto unlock;
 
-		dma = vfio_find_dma(iommu, iova_end, 0);
-		if (dma && dma->iova + dma->size != iova + size)
+		dma = vfio_find_dma(iommu, iova_end, 1);
+		if (dma && dma->iova + dma->size - 1 != iova_end)
 			goto unlock;
 	}
 
 	ret = 0;
-	n = first_n = vfio_find_dma_first_node(iommu, iova, size);
+	n = first_n = vfio_find_dma_first_node(iommu, iova, iova_end);
 
 	while (n) {
 		dma = rb_entry(n, struct vfio_dma, node);
-		if (dma->iova >= iova + size)
+		if (dma->iova > iova_end)
 			break;
 
 		if (!iommu->v2 && iova > dma->iova)
@@ -1813,12 +1821,12 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 
 	for (; n; n = rb_next(n)) {
 		struct vfio_dma *dma;
-		dma_addr_t iova;
+		size_t pos = 0;
 
 		dma = rb_entry(n, struct vfio_dma, node);
-		iova = dma->iova;
 
-		while (iova < dma->iova + dma->size) {
+		while (pos < dma->size) {
+			dma_addr_t iova = dma->iova + pos;
 			phys_addr_t phys;
 			size_t size;
 
@@ -1834,14 +1842,14 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 				phys = iommu_iova_to_phys(d->domain, iova);
 
 				if (WARN_ON(!phys)) {
-					iova += PAGE_SIZE;
+					pos += PAGE_SIZE;
 					continue;
 				}
 
 				size = PAGE_SIZE;
 				p = phys + size;
 				i = iova + size;
-				while (i < dma->iova + dma->size &&
+				while (pos + size < dma->size &&
 				       p == iommu_iova_to_phys(d->domain, i)) {
 					size += PAGE_SIZE;
 					p += PAGE_SIZE;
@@ -1849,9 +1857,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 				}
 			} else {
 				unsigned long pfn;
-				unsigned long vaddr = dma->vaddr +
-						      (iova - dma->iova);
-				size_t n = dma->iova + dma->size - iova;
+				unsigned long vaddr = dma->vaddr + pos;
+				size_t n = dma->size - pos;
 				long npage;
 
 				npage = vfio_pin_pages_remote(dma, vaddr,
@@ -1882,7 +1889,7 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 				goto unwind;
 			}
 
-			iova += size;
+			pos += size;
 		}
 	}
 
@@ -1899,29 +1906,29 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 unwind:
 	for (; n; n = rb_prev(n)) {
 		struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
-		dma_addr_t iova;
+		size_t pos = 0;
 
 		if (dma->iommu_mapped) {
 			iommu_unmap(domain->domain, dma->iova, dma->size);
 			continue;
 		}
 
-		iova = dma->iova;
-		while (iova < dma->iova + dma->size) {
+		while (pos < dma->size) {
+			dma_addr_t iova = dma->iova + pos;
 			phys_addr_t phys, p;
 			size_t size;
 			dma_addr_t i;
 
 			phys = iommu_iova_to_phys(domain->domain, iova);
 			if (!phys) {
-				iova += PAGE_SIZE;
+				pos += PAGE_SIZE;
 				continue;
 			}
 
 			size = PAGE_SIZE;
 			p = phys + size;
 			i = iova + size;
-			while (i < dma->iova + dma->size &&
+			while (pos + size < dma->size &&
 			       p == iommu_iova_to_phys(domain->domain, i)) {
 				size += PAGE_SIZE;
 				p += PAGE_SIZE;
@@ -3059,7 +3066,7 @@ static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
 
 		if (iommu->dirty_page_tracking)
 			ret = vfio_iova_dirty_bitmap(range.bitmap.data,
-						     iommu, iova, size,
+						     iommu, iova, iova_end,
 						     range.bitmap.pgsize);
 		else
 			ret = -EINVAL;