]> Git Repo - linux.git/blob - drivers/iommu/iommufd/pages.c
Linux 6.14-rc3
[linux.git] / drivers / iommu / iommufd / pages.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
 * PFNs have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
 * means removing an iommu_domain or releasing an in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/file.h>
49 #include <linux/highmem.h>
50 #include <linux/iommu.h>
51 #include <linux/iommufd.h>
52 #include <linux/kthread.h>
53 #include <linux/overflow.h>
54 #include <linux/slab.h>
55 #include <linux/sched/mm.h>
56
57 #include "double_span.h"
58 #include "io_pagetable.h"
59
60 #ifndef CONFIG_IOMMUFD_TEST
61 #define TEMP_MEMORY_LIMIT 65536
62 #else
63 #define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
64 #endif
65 #define BATCH_BACKUP_SIZE 32
66
67 /*
68  * More memory makes pin_user_pages() and the batching more efficient, but as
69  * this is only a performance optimization don't try too hard to get it. A 64k
70  * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
71  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
72  * stack memory as a backup contingency. If backup_len is given this cannot
73  * fail.
74  */
75 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
76 {
77         void *res;
78
79         if (WARN_ON(*size == 0))
80                 return NULL;
81
82         if (*size < backup_len)
83                 return backup;
84
85         if (!backup && iommufd_should_fail())
86                 return NULL;
87
88         *size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
89         res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
90         if (res)
91                 return res;
92         *size = PAGE_SIZE;
93         if (backup_len) {
94                 res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
95                 if (res)
96                         return res;
97                 *size = backup_len;
98                 return backup;
99         }
100         return kmalloc(*size, GFP_KERNEL);
101 }
102
/*
 * Recompute the combined state of a double span iterator from its two
 * per-tree span iterators. The spans are examined in array order:
 *
 * - a span iterator that is done makes the combined iterator done
 *   (is_used = -1);
 * - a span that is a hole narrows the running last_hole bound and the scan
 *   moves on to the next span;
 * - the first span that is in use is reported: is_used becomes its 1-based
 *   position, and the used range is clipped to end no later than the
 *   last_hole of every earlier span that was a hole.
 *
 * If every span is a hole, is_used = 0 and the reported hole starts at
 * spans[0].start_hole and ends at the smaller of the two spans' last_hole.
 */
void interval_tree_double_span_iter_update(
	struct interval_tree_double_span_iter *iter)
{
	/* Smallest last_hole among the earlier spans that are holes */
	unsigned long last_hole = ULONG_MAX;
	unsigned int i;

	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
		if (interval_tree_span_iter_done(&iter->spans[i])) {
			iter->is_used = -1;
			return;
		}

		if (iter->spans[i].is_hole) {
			last_hole = min(last_hole, iter->spans[i].last_hole);
			continue;
		}

		/* First used span wins; do not run past an earlier hole */
		iter->is_used = i + 1;
		iter->start_used = iter->spans[i].start_used;
		iter->last_used = min(iter->spans[i].last_used, last_hole);
		return;
	}

	/* Every span is a hole */
	iter->is_used = 0;
	iter->start_hole = iter->spans[0].start_hole;
	iter->last_hole =
		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
}
131
132 void interval_tree_double_span_iter_first(
133         struct interval_tree_double_span_iter *iter,
134         struct rb_root_cached *itree1, struct rb_root_cached *itree2,
135         unsigned long first_index, unsigned long last_index)
136 {
137         unsigned int i;
138
139         iter->itrees[0] = itree1;
140         iter->itrees[1] = itree2;
141         for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
142                 interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
143                                               first_index, last_index);
144         interval_tree_double_span_iter_update(iter);
145 }
146
147 void interval_tree_double_span_iter_next(
148         struct interval_tree_double_span_iter *iter)
149 {
150         unsigned int i;
151
152         if (iter->is_used == -1 ||
153             iter->last_hole == iter->spans[0].last_index) {
154                 iter->is_used = -1;
155                 return;
156         }
157
158         for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
159                 interval_tree_span_iter_advance(
160                         &iter->spans[i], iter->itrees[i], iter->last_hole + 1);
161         interval_tree_double_span_iter_update(iter);
162 }
163
164 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
165 {
166         int rc;
167
168         rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
169         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
170                 WARN_ON(rc || pages->npinned > pages->npages);
171 }
172
173 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
174 {
175         int rc;
176
177         rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
178         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
179                 WARN_ON(rc || pages->npinned > pages->npages);
180 }
181
/*
 * Error-unwind helper: unpin the pages in page_list that cover the inclusive
 * index range [start_index, last_index] and remove them from
 * pages->npinned.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	const unsigned long count = last_index + 1 - start_index;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
192
193 /*
194  * index is the number of PAGE_SIZE units from the start of the area's
195  * iopt_pages. If the iova is sub page-size then the area has an iova that
196  * covers a portion of the first and last pages in the range.
197  */
198 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
199                                              unsigned long index)
200 {
201         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
202                 WARN_ON(index < iopt_area_index(area) ||
203                         index > iopt_area_last_index(area));
204         index -= iopt_area_index(area);
205         if (index == 0)
206                 return iopt_area_iova(area);
207         return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
208 }
209
210 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
211                                                   unsigned long index)
212 {
213         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
214                 WARN_ON(index < iopt_area_index(area) ||
215                         index > iopt_area_last_index(area));
216         if (index == iopt_area_last_index(area))
217                 return iopt_area_last_iova(area);
218         return iopt_area_iova(area) - area->page_offset +
219                (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
220 }
221
222 static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
223                                size_t size)
224 {
225         size_t ret;
226
227         ret = iommu_unmap(domain, iova, size);
228         /*
229          * It is a logic error in this code or a driver bug if the IOMMU unmaps
230          * something other than exactly as requested. This implies that the
231          * iommu driver may not fail unmap for reasons beyond bad agruments.
232          * Particularly, the iommu driver may not do a memory allocation on the
233          * unmap path.
234          */
235         WARN_ON(ret != size);
236 }
237
/*
 * Unmap from domain the whole IOVA span the area uses for pages
 * [start_index, last_index], including any partial first or last page.
 */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first = iopt_area_index_to_iova(area, start_index);
	unsigned long last = iopt_area_index_to_iova_last(area, last_index);

	iommu_unmap_nofail(domain, first, last - first + 1);
}
249
250 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
251                                                      unsigned long index)
252 {
253         struct interval_tree_node *node;
254
255         node = interval_tree_iter_first(&pages->domains_itree, index, index);
256         if (!node)
257                 return NULL;
258         return container_of(node, struct iopt_area, pages_node);
259 }
260
/*
 * A simple data structure to hold a vector of PFNs, optimized for contiguous
 * PFNs. This is used as a temporary holding memory for shuttling pfns from one
 * place to another. Generally everything is made more efficient if operations
 * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
 * better cache locality, etc
 *
 * The batch is a list of runs: run i covers npfns[i] consecutive PFNs
 * starting at pfns[i].
 */
struct pfn_batch {
	/* First PFN of each run; entries [0, end) are valid */
	unsigned long *pfns;
	/* Length of each run; stored in the same allocation, after pfns[] */
	u32 *npfns;
	/* Capacity of pfns[] and npfns[], in entries */
	unsigned int array_size;
	/* Number of runs currently stored */
	unsigned int end;
	/* Sum of npfns[0..end) */
	unsigned int total_pfns;
};
275
276 static void batch_clear(struct pfn_batch *batch)
277 {
278         batch->total_pfns = 0;
279         batch->end = 0;
280         batch->pfns[0] = 0;
281         batch->npfns[0] = 0;
282 }
283
284 /*
285  * Carry means we carry a portion of the final hugepage over to the front of the
286  * batch
287  */
288 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
289 {
290         if (!keep_pfns)
291                 return batch_clear(batch);
292
293         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
294                 WARN_ON(!batch->end ||
295                         batch->npfns[batch->end - 1] < keep_pfns);
296
297         batch->total_pfns = keep_pfns;
298         batch->pfns[0] = batch->pfns[batch->end - 1] +
299                          (batch->npfns[batch->end - 1] - keep_pfns);
300         batch->npfns[0] = keep_pfns;
301         batch->end = 1;
302 }
303
304 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
305 {
306         if (!batch->total_pfns)
307                 return;
308         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
309                 WARN_ON(batch->total_pfns != batch->npfns[0]);
310         skip_pfns = min(batch->total_pfns, skip_pfns);
311         batch->pfns[0] += skip_pfns;
312         batch->npfns[0] -= skip_pfns;
313         batch->total_pfns -= skip_pfns;
314 }
315
316 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
317                         size_t backup_len)
318 {
319         const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
320         size_t size = max_pages * elmsz;
321
322         batch->pfns = temp_kmalloc(&size, backup, backup_len);
323         if (!batch->pfns)
324                 return -ENOMEM;
325         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
326                 return -EINVAL;
327         batch->array_size = size / elmsz;
328         batch->npfns = (u32 *)(batch->pfns + batch->array_size);
329         batch_clear(batch);
330         return 0;
331 }
332
333 static int batch_init(struct pfn_batch *batch, size_t max_pages)
334 {
335         return __batch_init(batch, max_pages, NULL, 0);
336 }
337
338 static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
339                               void *backup, size_t backup_len)
340 {
341         __batch_init(batch, max_pages, backup, backup_len);
342 }
343
344 static void batch_destroy(struct pfn_batch *batch, void *backup)
345 {
346         if (batch->pfns != backup)
347                 kfree(batch->pfns);
348 }
349
350 static bool batch_add_pfn_num(struct pfn_batch *batch, unsigned long pfn,
351                               u32 nr)
352 {
353         const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
354         unsigned int end = batch->end;
355
356         if (end && pfn == batch->pfns[end - 1] + batch->npfns[end - 1] &&
357             nr <= MAX_NPFNS - batch->npfns[end - 1]) {
358                 batch->npfns[end - 1] += nr;
359         } else if (end < batch->array_size) {
360                 batch->pfns[end] = pfn;
361                 batch->npfns[end] = nr;
362                 batch->end++;
363         } else {
364                 return false;
365         }
366
367         batch->total_pfns += nr;
368         return true;
369 }
370
371 static void batch_remove_pfn_num(struct pfn_batch *batch, unsigned long nr)
372 {
373         batch->npfns[batch->end - 1] -= nr;
374         if (batch->npfns[batch->end - 1] == 0)
375                 batch->end--;
376         batch->total_pfns -= nr;
377 }
378
379 /* true if the pfn was added, false otherwise */
380 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
381 {
382         return batch_add_pfn_num(batch, pfn, 1);
383 }
384
385 /*
386  * Fill the batch with pfns from the domain. When the batch is full, or it
387  * reaches last_index, the function will return. The caller should use
388  * batch->total_pfns to determine the starting point for the next iteration.
389  */
390 static void batch_from_domain(struct pfn_batch *batch,
391                               struct iommu_domain *domain,
392                               struct iopt_area *area, unsigned long start_index,
393                               unsigned long last_index)
394 {
395         unsigned int page_offset = 0;
396         unsigned long iova;
397         phys_addr_t phys;
398
399         iova = iopt_area_index_to_iova(area, start_index);
400         if (start_index == iopt_area_index(area))
401                 page_offset = area->page_offset;
402         while (start_index <= last_index) {
403                 /*
404                  * This is pretty slow, it would be nice to get the page size
405                  * back from the driver, or have the driver directly fill the
406                  * batch.
407                  */
408                 phys = iommu_iova_to_phys(domain, iova) - page_offset;
409                 if (!batch_add_pfn(batch, PHYS_PFN(phys)))
410                         return;
411                 iova += PAGE_SIZE - page_offset;
412                 page_offset = 0;
413                 start_index++;
414         }
415 }
416
417 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
418                                            struct iopt_area *area,
419                                            unsigned long start_index,
420                                            unsigned long last_index,
421                                            struct page **out_pages)
422 {
423         unsigned int page_offset = 0;
424         unsigned long iova;
425         phys_addr_t phys;
426
427         iova = iopt_area_index_to_iova(area, start_index);
428         if (start_index == iopt_area_index(area))
429                 page_offset = area->page_offset;
430         while (start_index <= last_index) {
431                 phys = iommu_iova_to_phys(domain, iova) - page_offset;
432                 *(out_pages++) = pfn_to_page(PHYS_PFN(phys));
433                 iova += PAGE_SIZE - page_offset;
434                 page_offset = 0;
435                 start_index++;
436         }
437         return out_pages;
438 }
439
440 /* Continues reading a domain until we reach a discontinuity in the pfns. */
441 static void batch_from_domain_continue(struct pfn_batch *batch,
442                                        struct iommu_domain *domain,
443                                        struct iopt_area *area,
444                                        unsigned long start_index,
445                                        unsigned long last_index)
446 {
447         unsigned int array_size = batch->array_size;
448
449         batch->array_size = batch->end;
450         batch_from_domain(batch, domain, area, start_index, last_index);
451         batch->array_size = array_size;
452 }
453
454 /*
455  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
456  * mode permits splitting a mapped area up, and then one of the splits is
457  * unmapped. Doing this normally would cause us to violate our invariant of
458  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
459  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
460  * PAGE_SIZE units, not larger or smaller.
461  */
462 static int batch_iommu_map_small(struct iommu_domain *domain,
463                                  unsigned long iova, phys_addr_t paddr,
464                                  size_t size, int prot)
465 {
466         unsigned long start_iova = iova;
467         int rc;
468
469         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
470                 WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
471                         size % PAGE_SIZE);
472
473         while (size) {
474                 rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
475                                GFP_KERNEL_ACCOUNT);
476                 if (rc)
477                         goto err_unmap;
478                 iova += PAGE_SIZE;
479                 paddr += PAGE_SIZE;
480                 size -= PAGE_SIZE;
481         }
482         return 0;
483
484 err_unmap:
485         if (start_iova != iova)
486                 iommu_unmap_nofail(domain, start_iova, iova - start_iova);
487         return rc;
488 }
489
/*
 * Map the PFN runs held in the batch into domain, starting at the IOVA the
 * area uses for page start_index.
 *
 * Each run is mapped with one iommu_map() call. If the io_pagetable has
 * disable_large_pages set, batch_iommu_map_small() maps it one page at a
 * time instead. A run's IOVA range is clamped so it never extends past the
 * area's last IOVA.
 *
 * Returns 0 on success. On failure the runs already mapped by earlier loop
 * iterations are unmapped again, and the mapping error is returned.
 *
 * NOTE(review): last_iova + 1 wraps to 0 if the area ends at ULONG_MAX
 * (the file header warns that last_iova + 1 can overflow). In that case
 * min() would yield 0 and next_iova - iova would be huge. Confirm that
 * areas can never end at ULONG_MAX.
 */
static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
			   struct iopt_area *area, unsigned long start_index)
{
	bool disable_large_pages = area->iopt->disable_large_pages;
	unsigned long last_iova = iopt_area_last_iova(area);
	unsigned int page_offset = 0;
	unsigned long start_iova;
	unsigned long next_iova;
	unsigned int cur = 0;
	unsigned long iova;
	int rc;

	/* The first index might be a partial page */
	if (start_index == iopt_area_index(area))
		page_offset = area->page_offset;
	next_iova = iova = start_iova =
		iopt_area_index_to_iova(area, start_index);
	while (cur < batch->end) {
		/*
		 * End (exclusive) of this run's IOVA range: the run length
		 * minus any leading partial-page offset, clamped to the end
		 * of the area.
		 */
		next_iova = min(last_iova + 1,
				next_iova + batch->npfns[cur] * PAGE_SIZE -
					page_offset);
		if (disable_large_pages)
			rc = batch_iommu_map_small(
				domain, iova,
				PFN_PHYS(batch->pfns[cur]) + page_offset,
				next_iova - iova, area->iommu_prot);
		else
			rc = iommu_map(domain, iova,
				       PFN_PHYS(batch->pfns[cur]) + page_offset,
				       next_iova - iova, area->iommu_prot,
				       GFP_KERNEL_ACCOUNT);
		if (rc)
			goto err_unmap;
		iova = next_iova;
		/* Only the first run can start part way into a page */
		page_offset = 0;
		cur++;
	}
	return 0;
err_unmap:
	/* Unmap the runs that were fully mapped before the failing one */
	if (start_iova != iova)
		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
	return rc;
}
533
/*
 * Append to the batch the PFNs stored as value entries in xa at
 * [start_index, last_index], walking the xarray under rcu_read_lock().
 * Stops early if the batch fills up. Every index in the range is expected
 * to hold a value entry; anything else triggers WARN_ON().
 */
static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
			      unsigned long start_index,
			      unsigned long last_index)
{
	XA_STATE(xas, xa, start_index);
	void *entry;

	rcu_read_lock();
	while (true) {
		entry = xas_next(&xas);
		/* Re-read on a retry entry without advancing start_index */
		if (xas_retry(&xas, entry))
			continue;
		WARN_ON(!xa_is_value(entry));
		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
		    start_index == last_index)
			break;
		start_index++;
	}
	rcu_read_unlock();
}
554
/*
 * Like batch_from_xarray(), but runs under the xarray lock and removes each
 * entry from xa once its PFN has been added to the batch. If the batch fills
 * up, the entry that did not fit is left in the xarray.
 */
static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
				    unsigned long start_index,
				    unsigned long last_index)
{
	XA_STATE(xas, xa, start_index);
	void *entry;

	xas_lock(&xas);
	while (true) {
		entry = xas_next(&xas);
		/* Re-read on a retry entry without advancing start_index */
		if (xas_retry(&xas, entry))
			continue;
		WARN_ON(!xa_is_value(entry));
		if (!batch_add_pfn(batch, xa_to_value(entry)))
			break;
		/* The PFN now lives in the batch; drop it from the xarray */
		xas_store(&xas, NULL);
		if (start_index == last_index)
			break;
		start_index++;
	}
	xas_unlock(&xas);
}
577
578 static void clear_xarray(struct xarray *xa, unsigned long start_index,
579                          unsigned long last_index)
580 {
581         XA_STATE(xas, xa, start_index);
582         void *entry;
583
584         xas_lock(&xas);
585         xas_for_each(&xas, entry, last_index)
586                 xas_store(&xas, NULL);
587         xas_unlock(&xas);
588 }
589
/*
 * Store the PFN of each page in pages[] as a value entry at the consecutive
 * indexes [start_index, last_index] of xa. Node memory is allocated with
 * GFP_KERNEL outside the lock by the xas_nomem() retry loop.
 *
 * Returns 0 on success or a negative errno. On failure, any entries this
 * call already stored are cleared again.
 */
static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
			   unsigned long last_index, struct page **pages)
{
	struct page **end_pages = pages + (last_index - start_index) + 1;
	/* Fault injection point: halfway through the range */
	struct page **half_pages = pages + (end_pages - pages) / 2;
	XA_STATE(xas, xa, start_index);

	do {
		void *old;

		xas_lock(&xas);
		while (pages != end_pages) {
			/*
			 * xarray does not participate in fault injection, so
			 * inject a failure here to exercise the unwind path.
			 */
			if (pages == half_pages && iommufd_should_fail()) {
				xas_set_err(&xas, -EINVAL);
				xas_unlock(&xas);
				/* aka xas_destroy() */
				xas_nomem(&xas, GFP_KERNEL);
				goto err_clear;
			}

			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
			if (xas_error(&xas))
				break;
			/* The slot is expected to be empty */
			WARN_ON(old);
			pages++;
			xas_next(&xas);
		}
		xas_unlock(&xas);
	} while (xas_nomem(&xas, GFP_KERNEL));

err_clear:
	if (xas_error(&xas)) {
		/* Entries before the failing index were stored; remove them */
		if (xas.xa_index != start_index)
			clear_xarray(xa, start_index, xas.xa_index - 1);
		return xas_error(&xas);
	}
	return 0;
}
629
630 static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
631                              size_t npages)
632 {
633         struct page **end = pages + npages;
634
635         for (; pages != end; pages++)
636                 if (!batch_add_pfn(batch, page_to_pfn(*pages)))
637                         break;
638 }
639
/*
 * Append up to npages PFNs to the batch, taken from the folio array at
 * *folios_p and starting *offset_p pages into the first folio.
 *
 * Each folio is assumed to already hold one pin (e.g. from
 * memfd_pin_folios() in pin_memfd_pages()). When a folio contributes nr > 1
 * PFNs, folio_add_pins() takes nr - 1 more, so every PFN in the batch holds
 * its own pin.
 *
 * On return *folios_p and *offset_p describe the first folio/offset that was
 * not consumed, so a later call can resume there. A full batch simply stops
 * the loop and is not an error. Returns 0, or the folio_add_pins() error;
 * in that case the folio's PFNs are removed from the batch again.
 */
static int batch_from_folios(struct pfn_batch *batch, struct folio ***folios_p,
			     unsigned long *offset_p, unsigned long npages)
{
	int rc = 0;
	struct folio **folios = *folios_p;
	unsigned long offset = *offset_p;

	while (npages) {
		struct folio *folio = *folios;
		unsigned long nr = folio_nr_pages(folio) - offset;
		unsigned long pfn = page_to_pfn(folio_page(folio, offset));

		nr = min(nr, npages);
		npages -= nr;

		if (!batch_add_pfn_num(batch, pfn, nr))
			break;
		/* The folio's existing pin covers one PFN; pin the rest */
		if (nr > 1) {
			rc = folio_add_pins(folio, nr - 1);
			if (rc) {
				batch_remove_pfn_num(batch, nr);
				goto out;
			}
		}

		folios++;
		offset = 0;
	}

out:
	*folios_p = folios;
	*offset_p = offset;
	return rc;
}
674
/*
 * Unpin npages PFNs of the batch, starting first_page_off PFNs into it, and
 * subtract them from pages->npinned. Each run is released with
 * unpin_user_page_range_dirty_lock(), passing pages->writable as the
 * make-dirty flag.
 */
static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
			unsigned int first_page_off, size_t npages)
{
	unsigned int cur = 0;

	/* Skip the runs that lie entirely before first_page_off */
	while (first_page_off) {
		if (batch->npfns[cur] > first_page_off)
			break;
		first_page_off -= batch->npfns[cur];
		cur++;
	}

	/* Unpin from each run up to its end or the remaining count */
	while (npages) {
		size_t to_unpin = min_t(size_t, npages,
					batch->npfns[cur] - first_page_off);

		unpin_user_page_range_dirty_lock(
			pfn_to_page(batch->pfns[cur] + first_page_off),
			to_unpin, pages->writable);
		iopt_pages_sub_npinned(pages, to_unpin);
		cur++;
		first_page_off = 0;
		npages -= to_unpin;
	}
}
700
701 static void copy_data_page(struct page *page, void *data, unsigned long offset,
702                            size_t length, unsigned int flags)
703 {
704         void *mem;
705
706         mem = kmap_local_page(page);
707         if (flags & IOMMUFD_ACCESS_RW_WRITE) {
708                 memcpy(mem + offset, data, length);
709                 set_page_dirty_lock(page);
710         } else {
711                 memcpy(data, mem + offset, length);
712         }
713         kunmap_local(mem);
714 }
715
716 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
717                               unsigned long offset, unsigned long length,
718                               unsigned int flags)
719 {
720         unsigned long copied = 0;
721         unsigned int npage = 0;
722         unsigned int cur = 0;
723
724         while (cur < batch->end) {
725                 unsigned long bytes = min(length, PAGE_SIZE - offset);
726
727                 copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
728                                offset, bytes, flags);
729                 offset = 0;
730                 length -= bytes;
731                 data += bytes;
732                 copied += bytes;
733                 npage++;
734                 if (npage == batch->npfns[cur]) {
735                         npage = 0;
736                         cur++;
737                 }
738                 if (!length)
739                         break;
740         }
741         return copied;
742 }
743
/*
 * pfn_reader_user is just the pin_user_pages() path, plus memfd_pin_folios()
 * when the iopt_pages is backed by a file.
 */
struct pfn_reader_user {
	/* Output array of pinned pages */
	struct page **upages;
	/* Size of upages in bytes */
	size_t upages_len;
	/*
	 * Presumably the [start, end) page index range currently described
	 * by upages; assigned outside this view
	 */
	unsigned long upages_start;
	unsigned long upages_end;
	/* FOLL_LONGTERM, plus FOLL_WRITE when the pages are writable */
	unsigned int gup_flags;
	/*
	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
	 * neither
	 */
	int locked;

	/* The following are only valid if file != NULL. */
	struct file *file;
	/* Folios pinned by memfd_pin_folios() */
	struct folio **ufolios;
	/* Size of ufolios in bytes */
	size_t ufolios_len;
	/* Page offset into *ufolios_next where the pinned range starts */
	unsigned long ufolios_offset;
	/* Next folio in ufolios to consume */
	struct folio **ufolios_next;
};
764
765 static void pfn_reader_user_init(struct pfn_reader_user *user,
766                                  struct iopt_pages *pages)
767 {
768         user->upages = NULL;
769         user->upages_len = 0;
770         user->upages_start = 0;
771         user->upages_end = 0;
772         user->locked = -1;
773         user->gup_flags = FOLL_LONGTERM;
774         if (pages->writable)
775                 user->gup_flags |= FOLL_WRITE;
776
777         user->file = (pages->type == IOPT_ADDRESS_FILE) ? pages->file : NULL;
778         user->ufolios = NULL;
779         user->ufolios_len = 0;
780         user->ufolios_next = NULL;
781         user->ufolios_offset = 0;
782 }
783
784 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
785                                     struct iopt_pages *pages)
786 {
787         if (user->locked != -1) {
788                 if (user->locked)
789                         mmap_read_unlock(pages->source_mm);
790                 if (!user->file && pages->source_mm != current->mm)
791                         mmput(pages->source_mm);
792                 user->locked = -1;
793         }
794
795         kfree(user->upages);
796         user->upages = NULL;
797         kfree(user->ufolios);
798         user->ufolios = NULL;
799 }
800
/*
 * Pin the file pages backing @npages pages that start at byte offset @start
 * in user->file, using memfd_pin_folios() to fill user->ufolios.
 *
 * On success user->ufolios_next and user->ufolios_offset (in pages) point at
 * the first requested page so the folios can be consumed by
 * batch_from_folios(). If user->upages is non-NULL every pinned page is also
 * stored there; a folio that supplies npin > 1 pages gets npin - 1 extra pins
 * from folio_add_pins() so each stored page carries its own pin.
 *
 * Returns the number of pages pinned, which may be less than @npages (e.g.
 * when the size of the ufolios array limits how many folios are returned);
 * the non-positive result of memfd_pin_folios() if it pinned nothing; or the
 * error from folio_add_pins().
 */
static long pin_memfd_pages(struct pfn_reader_user *user, unsigned long start,
			    unsigned long npages)
{
	unsigned long i;
	unsigned long offset;
	unsigned long npages_out = 0;
	struct page **upages = user->upages;
	/* Inclusive byte offset of the last byte of the requested range */
	unsigned long end = start + (npages << PAGE_SHIFT) - 1;
	/* Capacity of the ufolios array, in folios */
	long nfolios = user->ufolios_len / sizeof(*user->ufolios);

	/*
	 * todo: memfd_pin_folios should return the last pinned offset so
	 * we can compute npages pinned, and avoid looping over folios here
	 * if upages == NULL.
	 */
	nfolios = memfd_pin_folios(user->file, start, end, user->ufolios,
				   nfolios, &offset);
	if (nfolios <= 0)
		return nfolios;

	/* Offset into the first folio, converted from bytes to pages */
	offset >>= PAGE_SHIFT;
	user->ufolios_next = user->ufolios;
	user->ufolios_offset = offset;

	for (i = 0; i < nfolios; i++) {
		struct folio *folio = user->ufolios[i];
		unsigned long nr = folio_nr_pages(folio);
		/* Pages of this folio that fall inside the requested range */
		unsigned long npin = min(nr - offset, npages);

		npages -= npin;
		npages_out += npin;

		if (upages) {
			if (npin == 1) {
				*upages++ = folio_page(folio, offset);
			} else {
				int rc = folio_add_pins(folio, npin - 1);

				/*
				 * NOTE(review): this returns without releasing
				 * the folio pins taken by memfd_pin_folios()
				 * above or the pages already stored in upages;
				 * confirm the caller does not leak them.
				 */
				if (rc)
					return rc;

				while (npin--)
					*upages++ = folio_page(folio, offset++);
			}
		}

		/* Only the first folio can start part way through */
		offset = 0;
	}

	return npages_out;
}
852
/*
 * Pin source pages for the index range [start_index, last_index] and record
 * the pinned window as [user->upages_start, user->upages_end). Only a prefix
 * of the range may be pinned: the count is bounded by the size of the
 * temporary buffer and by what the pin call returns, so callers must cope
 * with a short pin.
 *
 * The PFN source is one of:
 *  - user->file: pin_memfd_pages() fills user->ufolios
 *  - a user VA in current->mm: pin_user_pages_fast()
 *  - a user VA in another mm: pin_user_pages_remote() under the mmap read
 *    lock, whose state is tracked in user->locked across calls
 *
 * On first use this allocates the temporary buffer (unless the caller already
 * set user->upages) and, for a remote user VA, takes an mm reference; these
 * are released by pfn_reader_user_destroy().
 *
 * Returns 0 after pinning at least one page, otherwise a negative errno.
 */
static int pfn_reader_user_pin(struct pfn_reader_user *user,
			       struct iopt_pages *pages,
			       unsigned long start_index,
			       unsigned long last_index)
{
	bool remote_mm = pages->source_mm != current->mm;
	unsigned long npages = last_index - start_index + 1;
	unsigned long start;
	unsigned long unum;
	uintptr_t uptr;
	long rc;

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
	    WARN_ON(last_index < start_index))
		return -EINVAL;

	if (!user->file && !user->upages) {
		/* All undone in pfn_reader_destroy() */
		user->upages_len = npages * sizeof(*user->upages);
		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
		if (!user->upages)
			return -ENOMEM;
	}

	if (user->file && !user->ufolios) {
		user->ufolios_len = npages * sizeof(*user->ufolios);
		user->ufolios = temp_kmalloc(&user->ufolios_len, NULL, 0);
		if (!user->ufolios)
			return -ENOMEM;
	}

	if (user->locked == -1) {
		/*
		 * The majority of usages will run the map task within the mm
		 * providing the pages, so we can optimize into
		 * get_user_pages_fast()
		 */
		if (!user->file && remote_mm) {
			if (!mmget_not_zero(pages->source_mm))
				return -EFAULT;
		}
		user->locked = 0;
	}

	/* Never pin more entries than the temporary buffer can hold */
	unum = user->file ? user->ufolios_len / sizeof(*user->ufolios) :
			    user->upages_len / sizeof(*user->upages);
	npages = min_t(unsigned long, npages, unum);

	if (iommufd_should_fail())
		return -EFAULT;

	if (user->file) {
		start = pages->start + (start_index * PAGE_SIZE);
		rc = pin_memfd_pages(user, start, npages);
	} else if (!remote_mm) {
		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
					 user->upages);
	} else {
		uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
		if (!user->locked) {
			mmap_read_lock(pages->source_mm);
			user->locked = 1;
		}
		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
					   user->gup_flags, user->upages,
					   &user->locked);
	}
	if (rc <= 0) {
		/* Pinning nothing would make no progress; treat as a fault */
		if (WARN_ON(!rc))
			return -EFAULT;
		return rc;
	}
	/* Track the new pins; limit accounting is done separately */
	iopt_pages_add_npinned(pages, rc);
	user->upages_start = start_index;
	user->upages_end = start_index + rc;
	return 0;
}
931
932 /* This is the "modern" and faster accounting method used by io_uring */
933 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
934 {
935         unsigned long lock_limit;
936         unsigned long cur_pages;
937         unsigned long new_pages;
938
939         lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
940                      PAGE_SHIFT;
941
942         cur_pages = atomic_long_read(&pages->source_user->locked_vm);
943         do {
944                 new_pages = cur_pages + npages;
945                 if (new_pages > lock_limit)
946                         return -ENOMEM;
947         } while (!atomic_long_try_cmpxchg(&pages->source_user->locked_vm,
948                                           &cur_pages, new_pages));
949         return 0;
950 }
951
952 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
953 {
954         if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
955                 return;
956         atomic_long_sub(npages, &pages->source_user->locked_vm);
957 }
958
959 /* This is the accounting method used for compatibility with VFIO */
960 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
961                                bool inc, struct pfn_reader_user *user)
962 {
963         bool do_put = false;
964         int rc;
965
966         if (user && user->locked) {
967                 mmap_read_unlock(pages->source_mm);
968                 user->locked = 0;
969                 /* If we had the lock then we also have a get */
970
971         } else if ((!user || (!user->upages && !user->ufolios)) &&
972                    pages->source_mm != current->mm) {
973                 if (!mmget_not_zero(pages->source_mm))
974                         return -EINVAL;
975                 do_put = true;
976         }
977
978         mmap_write_lock(pages->source_mm);
979         rc = __account_locked_vm(pages->source_mm, npages, inc,
980                                  pages->source_task, false);
981         mmap_write_unlock(pages->source_mm);
982
983         if (do_put)
984                 mmput(pages->source_mm);
985         return rc;
986 }
987
988 int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
989                              bool inc, struct pfn_reader_user *user)
990 {
991         int rc = 0;
992
993         switch (pages->account_mode) {
994         case IOPT_PAGES_ACCOUNT_NONE:
995                 break;
996         case IOPT_PAGES_ACCOUNT_USER:
997                 if (inc)
998                         rc = incr_user_locked_vm(pages, npages);
999                 else
1000                         decr_user_locked_vm(pages, npages);
1001                 break;
1002         case IOPT_PAGES_ACCOUNT_MM:
1003                 rc = update_mm_locked_vm(pages, npages, inc, user);
1004                 break;
1005         }
1006         if (rc)
1007                 return rc;
1008
1009         pages->last_npinned = pages->npinned;
1010         if (inc)
1011                 atomic64_add(npages, &pages->source_mm->pinned_vm);
1012         else
1013                 atomic64_sub(npages, &pages->source_mm->pinned_vm);
1014         return 0;
1015 }
1016
1017 static void update_unpinned(struct iopt_pages *pages)
1018 {
1019         if (WARN_ON(pages->npinned > pages->last_npinned))
1020                 return;
1021         if (pages->npinned == pages->last_npinned)
1022                 return;
1023         iopt_pages_update_pinned(pages, pages->last_npinned - pages->npinned,
1024                                  false, NULL);
1025 }
1026
1027 /*
1028  * Changes in the number of pages pinned is done after the pages have been read
1029  * and processed. If the user lacked the limit then the error unwind will unpin
1030  * everything that was just pinned. This is because it is expensive to calculate
1031  * how many pages we have already pinned within a range to generate an accurate
1032  * prediction in advance of doing the work to actually pin them.
1033  */
1034 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
1035                                          struct iopt_pages *pages)
1036 {
1037         unsigned long npages;
1038         bool inc;
1039
1040         lockdep_assert_held(&pages->mutex);
1041
1042         if (pages->npinned == pages->last_npinned)
1043                 return 0;
1044
1045         if (pages->npinned < pages->last_npinned) {
1046                 npages = pages->last_npinned - pages->npinned;
1047                 inc = false;
1048         } else {
1049                 if (iommufd_should_fail())
1050                         return -ENOMEM;
1051                 npages = pages->npinned - pages->last_npinned;
1052                 inc = true;
1053         }
1054         return iopt_pages_update_pinned(pages, npages, inc, user);
1055 }
1056
/*
 * PFNs are stored in three places, in order of preference:
 * - The iopt_pages xarray. This is only populated if there is an
 *   iopt_pages_access
 * - The iommu_domain under an area
 * - The original PFN source, i.e. the user VA in pages->source_mm or the
 *   backing file
 *
 * This iterator reads the PFNs, loading each range from the most preferred
 * storage that holds it, following the order above.
 */
struct pfn_reader {
	struct iopt_pages *pages;
	/* Walks access_itree and domains_itree to find each index's storage */
	struct interval_tree_double_span_iter span;
	/* PFNs for the indexes [batch_start_index, batch_end_index) */
	struct pfn_batch batch;
	unsigned long batch_start_index;
	unsigned long batch_end_index;
	/* Inclusive last index to read */
	unsigned long last_index;

	/* Pinning state for indexes not held by any other storage */
	struct pfn_reader_user user;
};
1077
1078 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
1079 {
1080         return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
1081 }
1082
1083 /*
1084  * The batch can contain a mixture of pages that are still in use and pages that
1085  * need to be unpinned. Unpin only pages that are not held anywhere else.
1086  */
1087 static void pfn_reader_unpin(struct pfn_reader *pfns)
1088 {
1089         unsigned long last = pfns->batch_end_index - 1;
1090         unsigned long start = pfns->batch_start_index;
1091         struct interval_tree_double_span_iter span;
1092         struct iopt_pages *pages = pfns->pages;
1093
1094         lockdep_assert_held(&pages->mutex);
1095
1096         interval_tree_for_each_double_span(&span, &pages->access_itree,
1097                                            &pages->domains_itree, start, last) {
1098                 if (span.is_used)
1099                         continue;
1100
1101                 batch_unpin(&pfns->batch, pages, span.start_hole - start,
1102                             span.last_hole - span.start_hole + 1);
1103         }
1104 }
1105
1106 /* Process a single span to load it from the proper storage */
1107 static int pfn_reader_fill_span(struct pfn_reader *pfns)
1108 {
1109         struct interval_tree_double_span_iter *span = &pfns->span;
1110         unsigned long start_index = pfns->batch_end_index;
1111         struct pfn_reader_user *user = &pfns->user;
1112         unsigned long npages;
1113         struct iopt_area *area;
1114         int rc;
1115
1116         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1117             WARN_ON(span->last_used < start_index))
1118                 return -EINVAL;
1119
1120         if (span->is_used == 1) {
1121                 batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
1122                                   start_index, span->last_used);
1123                 return 0;
1124         }
1125
1126         if (span->is_used == 2) {
1127                 /*
1128                  * Pull as many pages from the first domain we find in the
1129                  * target span. If it is too small then we will be called again
1130                  * and we'll find another area.
1131                  */
1132                 area = iopt_pages_find_domain_area(pfns->pages, start_index);
1133                 if (WARN_ON(!area))
1134                         return -EINVAL;
1135
1136                 /* The storage_domain cannot change without the pages mutex */
1137                 batch_from_domain(
1138                         &pfns->batch, area->storage_domain, area, start_index,
1139                         min(iopt_area_last_index(area), span->last_used));
1140                 return 0;
1141         }
1142
1143         if (start_index >= pfns->user.upages_end) {
1144                 rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1145                                          span->last_hole);
1146                 if (rc)
1147                         return rc;
1148         }
1149
1150         npages = user->upages_end - start_index;
1151         start_index -= user->upages_start;
1152         rc = 0;
1153
1154         if (!user->file)
1155                 batch_from_pages(&pfns->batch, user->upages + start_index,
1156                                  npages);
1157         else
1158                 rc = batch_from_folios(&pfns->batch, &user->ufolios_next,
1159                                        &user->ufolios_offset, npages);
1160         return rc;
1161 }
1162
1163 static bool pfn_reader_done(struct pfn_reader *pfns)
1164 {
1165         return pfns->batch_start_index == pfns->last_index + 1;
1166 }
1167
/*
 * Advance to the next batch: clear the batch and fill it with PFNs for the
 * indexes following the previous batch. Consecutive spans are appended until
 * last_index is reached or a fill adds nothing because the batch is full.
 *
 * Afterwards pfns->batch holds the PFNs for [batch_start_index,
 * batch_end_index). Returns 0 or a negative errno.
 */
static int pfn_reader_next(struct pfn_reader *pfns)
{
	int rc;

	batch_clear(&pfns->batch);
	pfns->batch_start_index = pfns->batch_end_index;

	while (pfns->batch_end_index != pfns->last_index + 1) {
		/* PFNs in the batch before this fill */
		unsigned int npfns = pfns->batch.total_pfns;

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
			return -EINVAL;

		rc = pfn_reader_fill_span(pfns);
		if (rc)
			return rc;

		if (WARN_ON(!pfns->batch.total_pfns))
			return -EINVAL;

		pfns->batch_end_index =
			pfns->batch_start_index + pfns->batch.total_pfns;
		/* Move to the next span once the current one is consumed */
		if (pfns->batch_end_index == pfns->span.last_used + 1)
			interval_tree_double_span_iter_next(&pfns->span);

		/* Batch is full */
		if (npfns == pfns->batch.total_pfns)
			return 0;
	}
	return 0;
}
1200
1201 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1202                            unsigned long start_index, unsigned long last_index)
1203 {
1204         int rc;
1205
1206         lockdep_assert_held(&pages->mutex);
1207
1208         pfns->pages = pages;
1209         pfns->batch_start_index = start_index;
1210         pfns->batch_end_index = start_index;
1211         pfns->last_index = last_index;
1212         pfn_reader_user_init(&pfns->user, pages);
1213         rc = batch_init(&pfns->batch, last_index - start_index + 1);
1214         if (rc)
1215                 return rc;
1216         interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1217                                              &pages->domains_itree, start_index,
1218                                              last_index);
1219         return 0;
1220 }
1221
1222 /*
1223  * There are many assertions regarding the state of pages->npinned vs
1224  * pages->last_pinned, for instance something like unmapping a domain must only
1225  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1226  * the pins are updated. This is fine for success flows, but error flows
1227  * sometimes need to release the pins held inside the pfn_reader before going on
1228  * to complete unmapping and releasing pins held in domains.
1229  */
1230 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1231 {
1232         struct iopt_pages *pages = pfns->pages;
1233         struct pfn_reader_user *user = &pfns->user;
1234
1235         if (user->upages_end > pfns->batch_end_index) {
1236                 /* Any pages not transferred to the batch are just unpinned */
1237
1238                 unsigned long npages = user->upages_end - pfns->batch_end_index;
1239                 unsigned long start_index = pfns->batch_end_index -
1240                                             user->upages_start;
1241
1242                 if (!user->file) {
1243                         unpin_user_pages(user->upages + start_index, npages);
1244                 } else {
1245                         long n = user->ufolios_len / sizeof(*user->ufolios);
1246
1247                         unpin_folios(user->ufolios_next,
1248                                      user->ufolios + n - user->ufolios_next);
1249                 }
1250                 iopt_pages_sub_npinned(pages, npages);
1251                 user->upages_end = pfns->batch_end_index;
1252         }
1253         if (pfns->batch_start_index != pfns->batch_end_index) {
1254                 pfn_reader_unpin(pfns);
1255                 pfns->batch_start_index = pfns->batch_end_index;
1256         }
1257 }
1258
1259 static void pfn_reader_destroy(struct pfn_reader *pfns)
1260 {
1261         struct iopt_pages *pages = pfns->pages;
1262
1263         pfn_reader_release_pins(pfns);
1264         pfn_reader_user_destroy(&pfns->user, pfns->pages);
1265         batch_destroy(&pfns->batch, NULL);
1266         WARN_ON(pages->last_npinned != pages->npinned);
1267 }
1268
1269 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1270                             unsigned long start_index, unsigned long last_index)
1271 {
1272         int rc;
1273
1274         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1275             WARN_ON(last_index < start_index))
1276                 return -EINVAL;
1277
1278         rc = pfn_reader_init(pfns, pages, start_index, last_index);
1279         if (rc)
1280                 return rc;
1281         rc = pfn_reader_next(pfns);
1282         if (rc) {
1283                 pfn_reader_destroy(pfns);
1284                 return rc;
1285         }
1286         return 0;
1287 }
1288
1289 static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
1290                                            unsigned long length,
1291                                            bool writable)
1292 {
1293         struct iopt_pages *pages;
1294
1295         /*
1296          * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1297          * below from overflow
1298          */
1299         if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1300                 return ERR_PTR(-EINVAL);
1301
1302         pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1303         if (!pages)
1304                 return ERR_PTR(-ENOMEM);
1305
1306         kref_init(&pages->kref);
1307         xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1308         mutex_init(&pages->mutex);
1309         pages->source_mm = current->mm;
1310         mmgrab(pages->source_mm);
1311         pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
1312         pages->access_itree = RB_ROOT_CACHED;
1313         pages->domains_itree = RB_ROOT_CACHED;
1314         pages->writable = writable;
1315         if (capable(CAP_IPC_LOCK))
1316                 pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1317         else
1318                 pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1319         pages->source_task = current->group_leader;
1320         get_task_struct(current->group_leader);
1321         pages->source_user = get_uid(current_user());
1322         return pages;
1323 }
1324
1325 struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
1326                                          unsigned long length, bool writable)
1327 {
1328         struct iopt_pages *pages;
1329         unsigned long end;
1330         void __user *uptr_down =
1331                 (void __user *) ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1332
1333         if (check_add_overflow((unsigned long)uptr, length, &end))
1334                 return ERR_PTR(-EOVERFLOW);
1335
1336         pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
1337         if (IS_ERR(pages))
1338                 return pages;
1339         pages->uptr = uptr_down;
1340         pages->type = IOPT_ADDRESS_USER;
1341         return pages;
1342 }
1343
1344 struct iopt_pages *iopt_alloc_file_pages(struct file *file, unsigned long start,
1345                                          unsigned long length, bool writable)
1346
1347 {
1348         struct iopt_pages *pages;
1349         unsigned long start_down = ALIGN_DOWN(start, PAGE_SIZE);
1350         unsigned long end;
1351
1352         if (length && check_add_overflow(start, length - 1, &end))
1353                 return ERR_PTR(-EOVERFLOW);
1354
1355         pages = iopt_alloc_pages(start - start_down, length, writable);
1356         if (IS_ERR(pages))
1357                 return pages;
1358         pages->file = get_file(file);
1359         pages->start = start_down;
1360         pages->type = IOPT_ADDRESS_FILE;
1361         return pages;
1362 }
1363
1364 void iopt_release_pages(struct kref *kref)
1365 {
1366         struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1367
1368         WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1369         WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1370         WARN_ON(pages->npinned);
1371         WARN_ON(!xa_empty(&pages->pinned_pfns));
1372         mmdrop(pages->source_mm);
1373         mutex_destroy(&pages->mutex);
1374         put_task_struct(pages->source_task);
1375         free_uid(pages->source_user);
1376         if (pages->type == IOPT_ADDRESS_FILE)
1377                 fput(pages->file);
1378         kfree(pages);
1379 }
1380
/*
 * Unmap @domain and unpin the PFNs for the hole [start_index, last_index] of
 * @area, i.e. indexes not held by another domain or by an access. The PFNs
 * to unpin are read back from @domain.
 *
 * @batch: carries PFNs between calls. When the unmap has to read ahead past
 *   @last_index to find a place to cut the mapping, the extra PFNs stay in
 *   the batch as a "carry" for later spans.
 * @unmapped_end_index: in/out, the first index of @area not yet unmapped
 *   from @domain.
 * @real_last_index: last index of the whole unfill; reads never go past it.
 */
static void
iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
		       struct iopt_pages *pages, struct iommu_domain *domain,
		       unsigned long start_index, unsigned long last_index,
		       unsigned long *unmapped_end_index,
		       unsigned long real_last_index)
{
	while (start_index <= last_index) {
		unsigned long batch_last_index;

		if (*unmapped_end_index <= last_index) {
			/*
			 * Part of the range is still mapped: append its PFNs,
			 * read from the domain, after any carried PFNs.
			 */
			unsigned long start =
				max(start_index, *unmapped_end_index);

			/* Carried PFNs must cover [start_index, unmapped_end) */
			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
			    batch->total_pfns)
				WARN_ON(*unmapped_end_index -
						batch->total_pfns !=
					start_index);
			batch_from_domain(batch, domain, area, start,
					  last_index);
			batch_last_index = start_index + batch->total_pfns - 1;
		} else {
			/* Already unmapped by read-ahead: PFNs are carried */
			batch_last_index = last_index;
		}

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
			WARN_ON(batch_last_index > real_last_index);

		/*
		 * unmaps must always 'cut' at a place where the pfns are not
		 * contiguous to pair with the maps that always install
		 * contiguous pages. Thus, if we have to stop unpinning in the
		 * middle of the domains we need to keep reading pfns until we
		 * find a cut point to do the unmap. The pfns we read are
		 * carried over and either skipped or integrated into the next
		 * batch.
		 */
		if (batch_last_index == last_index &&
		    last_index != real_last_index)
			batch_from_domain_continue(batch, domain, area,
						   last_index + 1,
						   real_last_index);

		if (*unmapped_end_index <= batch_last_index) {
			iopt_area_unmap_domain_range(
				area, domain, *unmapped_end_index,
				start_index + batch->total_pfns - 1);
			*unmapped_end_index = start_index + batch->total_pfns;
		}

		/* unpin must follow unmap */
		batch_unpin(batch, pages, 0,
			    batch_last_index - start_index + 1);
		start_index = batch_last_index + 1;

		/*
		 * Keep the PFNs past batch_last_index that were unmapped but
		 * not unpinned as the carry for the next iteration.
		 */
		batch_clear_carry(batch,
				  *unmapped_end_index - batch_last_index - 1);
	}
}
1441
/*
 * Unmap @area's IOVA from @domain for indexes [iopt_area_index(area),
 * last_index], unpin the PFNs there that are not still held by another domain
 * or by an access, and then release the accounting for the unpinned pages.
 */
static void __iopt_area_unfill_domain(struct iopt_area *area,
				      struct iopt_pages *pages,
				      struct iommu_domain *domain,
				      unsigned long last_index)
{
	struct interval_tree_double_span_iter span;
	unsigned long start_index = iopt_area_index(area);
	unsigned long unmapped_end_index = start_index;
	u64 backup[BATCH_BACKUP_SIZE];
	struct pfn_batch batch;

	lockdep_assert_held(&pages->mutex);

	/*
	 * For security we must not unpin something that is still DMA mapped,
	 * so this must unmap any IOVA before we go ahead and unpin the pages.
	 * This creates a complexity where we need to skip over unpinning pages
	 * held in the xarray, but continue to unmap from the domain.
	 *
	 * The domain unmap cannot stop in the middle of a contiguous range of
	 * PFNs. To solve this problem the unpinning step will read ahead to the
	 * end of any contiguous span, unmap that whole span, and then only
	 * unpin the leading part that does not have any accesses. The residual
	 * PFNs that were unmapped but not unpinned are called a "carry" in the
	 * batch as they are moved to the front of the PFN list and continue on
	 * to the next iteration(s).
	 */
	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
	interval_tree_for_each_double_span(&span, &pages->domains_itree,
					   &pages->access_itree, start_index,
					   last_index) {
		if (span.is_used) {
			/*
			 * Still used by another domain or an access: nothing
			 * is unpinned here, so let the batch skip past any
			 * carried PFNs for these indexes.
			 */
			batch_skip_carry(&batch,
					 span.last_used - span.start_used + 1);
			continue;
		}
		iopt_area_unpin_domain(&batch, area, pages, domain,
				       span.start_hole, span.last_hole,
				       &unmapped_end_index, last_index);
	}
	/*
	 * If the range ends in a span that is still in use (e.g. by an
	 * access), do the residual unmap without any unpins.
	 */
	if (unmapped_end_index != last_index + 1)
		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
					     last_index);
	/* Every carried PFN must have been consumed by now */
	WARN_ON(batch.total_pfns);
	batch_destroy(&batch, backup);
	update_unpinned(pages);
}
1493
/*
 * Undo a partially completed fill of @domain: unmap and unpin the indexes
 * [iopt_area_index(area), end_index). Does nothing if no index was filled.
 */
static void iopt_area_unfill_partial_domain(struct iopt_area *area,
					    struct iopt_pages *pages,
					    struct iommu_domain *domain,
					    unsigned long end_index)
{
	if (end_index == iopt_area_index(area))
		return;

	__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
}
1502
/**
 * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
 * @area: The IOVA range to unmap
 * @domain: The domain to unmap
 *
 * Removes the whole IOVA range covered by @area from @domain. The caller must
 * know that unpinning is not required, usually because there are other
 * domains in the iopt.
 */
void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
{
	iommu_unmap_nofail(domain, iopt_area_iova(area), iopt_area_length(area));
}
1516
/**
 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
 * @area: IOVA area to use
 * @pages: page supplier for the area (area->pages is NULL)
 * @domain: Domain to unmap from
 *
 * The domain should be removed from the domains_itree before calling. The
 * whole area is always unmapped from @domain, but the PFNs may not be
 * unpinned if there are still accesses.
 */
void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
			     struct iommu_domain *domain)
{
	unsigned long last_index = iopt_area_last_index(area);

	__iopt_area_unfill_domain(area, pages, domain, last_index);
}
1533
1534 /**
1535  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1536  * @area: IOVA area to use
1537  * @domain: Domain to load PFNs into
1538  *
1539  * Read the pfns from the area's underlying iopt_pages and map them into the
1540  * given domain. Called when attaching a new domain to an io_pagetable.
1541  */
1542 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1543 {
1544         unsigned long done_end_index;
1545         struct pfn_reader pfns;
1546         int rc;
1547
1548         lockdep_assert_held(&area->pages->mutex);
1549
1550         rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1551                               iopt_area_last_index(area));
1552         if (rc)
1553                 return rc;
1554
1555         while (!pfn_reader_done(&pfns)) {
1556                 done_end_index = pfns.batch_start_index;
1557                 rc = batch_to_domain(&pfns.batch, domain, area,
1558                                      pfns.batch_start_index);
1559                 if (rc)
1560                         goto out_unmap;
1561                 done_end_index = pfns.batch_end_index;
1562
1563                 rc = pfn_reader_next(&pfns);
1564                 if (rc)
1565                         goto out_unmap;
1566         }
1567
1568         rc = pfn_reader_update_pinned(&pfns);
1569         if (rc)
1570                 goto out_unmap;
1571         goto out_destroy;
1572
1573 out_unmap:
1574         pfn_reader_release_pins(&pfns);
1575         iopt_area_unfill_partial_domain(area, area->pages, domain,
1576                                         done_end_index);
1577 out_destroy:
1578         pfn_reader_destroy(&pfns);
1579         return rc;
1580 }
1581
/**
 * iopt_area_fill_domains() - Install PFNs into the area's domains
 * @area: The area to act on
 * @pages: The pages associated with the area (area->pages is NULL)
 *
 * Called during area creation. The area is freshly created and not inserted in
 * the domains_itree yet. PFNs are read and loaded into every domain held in the
 * area's io_pagetable and the area is installed in the domains_itree.
 *
 * On success the domain with id 0 becomes the area's storage_domain.
 *
 * On failure all domains are left unchanged.
 *
 * Context: Caller must hold area->iopt->domains_rwsem; pages->mutex is taken
 * internally.
 * Return: 0 on success (including when the io_pagetable has no domains) or a
 * negative errno.
 */
int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
{
	/*
	 * Exclusive end index reached by the domains that have already mapped
	 * the current batch.
	 */
	unsigned long done_first_end_index;
	/* Exclusive end index that every domain has mapped */
	unsigned long done_all_end_index;
	struct iommu_domain *domain;
	unsigned long unmap_index;
	struct pfn_reader pfns;
	/* Domain id being mapped; read by the error path below */
	unsigned long index;
	int rc;

	lockdep_assert_held(&area->iopt->domains_rwsem);

	if (xa_empty(&area->iopt->domains))
		return 0;

	mutex_lock(&pages->mutex);
	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
			      iopt_area_last_index(area));
	if (rc)
		goto out_unlock;

	while (!pfn_reader_done(&pfns)) {
		done_first_end_index = pfns.batch_end_index;
		done_all_end_index = pfns.batch_start_index;
		/* Map the same batch into every domain before reading more */
		xa_for_each(&area->iopt->domains, index, domain) {
			rc = batch_to_domain(&pfns.batch, domain, area,
					     pfns.batch_start_index);
			if (rc)
				goto out_unmap;
		}
		done_all_end_index = done_first_end_index;

		rc = pfn_reader_next(&pfns);
		if (rc)
			goto out_unmap;
	}
	rc = pfn_reader_update_pinned(&pfns);
	if (rc)
		goto out_unmap;

	area->storage_domain = xa_load(&area->iopt->domains, 0);
	interval_tree_insert(&area->pages_node, &pages->domains_itree);
	goto out_destroy;

out_unmap:
	pfn_reader_release_pins(&pfns);
	xa_for_each(&area->iopt->domains, unmap_index, domain) {
		unsigned long end_index;

		/*
		 * If batch_to_domain() failed, domains with an id below the
		 * failing one (index) already hold the current batch, and the
		 * others only hold the completed batches. For failures after
		 * the mapping loop both end indexes are equal, so the choice
		 * does not matter.
		 */
		if (unmap_index < index)
			end_index = done_first_end_index;
		else
			end_index = done_all_end_index;

		/*
		 * The area is not yet part of the domains_itree so we have to
		 * manage the unpinning specially. The last domain does the
		 * unpin, every other domain is just unmapped.
		 */
		if (unmap_index != area->iopt->next_domain_id - 1) {
			/* end_index is exclusive; skip when nothing was mapped */
			if (end_index != iopt_area_index(area))
				iopt_area_unmap_domain_range(
					area, domain, iopt_area_index(area),
					end_index - 1);
		} else {
			iopt_area_unfill_partial_domain(area, pages, domain,
							end_index);
		}
	}
out_destroy:
	pfn_reader_destroy(&pfns);
out_unlock:
	mutex_unlock(&pages->mutex);
	return rc;
}
1668
1669 /**
1670  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1671  * @area: The area to act on
1672  * @pages: The pages associated with the area (area->pages is NULL)
1673  *
1674  * Called during area destruction. This unmaps the iova's covered by all the
1675  * area's domains and releases the PFNs.
1676  */
1677 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1678 {
1679         struct io_pagetable *iopt = area->iopt;
1680         struct iommu_domain *domain;
1681         unsigned long index;
1682
1683         lockdep_assert_held(&iopt->domains_rwsem);
1684
1685         mutex_lock(&pages->mutex);
1686         if (!area->storage_domain)
1687                 goto out_unlock;
1688
1689         xa_for_each(&iopt->domains, index, domain)
1690                 if (domain != area->storage_domain)
1691                         iopt_area_unmap_domain_range(
1692                                 area, domain, iopt_area_index(area),
1693                                 iopt_area_last_index(area));
1694
1695         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1696                 WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1697         interval_tree_remove(&area->pages_node, &pages->domains_itree);
1698         iopt_area_unfill_domain(area, pages, area->storage_domain);
1699         area->storage_domain = NULL;
1700 out_unlock:
1701         mutex_unlock(&pages->mutex);
1702 }
1703
1704 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1705                                     struct iopt_pages *pages,
1706                                     unsigned long start_index,
1707                                     unsigned long end_index)
1708 {
1709         while (start_index <= end_index) {
1710                 batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1711                                         end_index);
1712                 batch_unpin(batch, pages, 0, batch->total_pfns);
1713                 start_index += batch->total_pfns;
1714                 batch_clear(batch);
1715         }
1716 }
1717
1718 /**
1719  * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1720  * @pages: The pages to act on
1721  * @start_index: Starting PFN index
1722  * @last_index: Last PFN index
1723  *
1724  * Called when an iopt_pages_access is removed, removes pages from the itree.
1725  * The access should already be removed from the access_itree.
1726  */
1727 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1728                               unsigned long start_index,
1729                               unsigned long last_index)
1730 {
1731         struct interval_tree_double_span_iter span;
1732         u64 backup[BATCH_BACKUP_SIZE];
1733         struct pfn_batch batch;
1734         bool batch_inited = false;
1735
1736         lockdep_assert_held(&pages->mutex);
1737
1738         interval_tree_for_each_double_span(&span, &pages->access_itree,
1739                                            &pages->domains_itree, start_index,
1740                                            last_index) {
1741                 if (!span.is_used) {
1742                         if (!batch_inited) {
1743                                 batch_init_backup(&batch,
1744                                                   last_index - start_index + 1,
1745                                                   backup, sizeof(backup));
1746                                 batch_inited = true;
1747                         }
1748                         iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1749                                                 span.last_hole);
1750                 } else if (span.is_used == 2) {
1751                         /* Covered by a domain */
1752                         clear_xarray(&pages->pinned_pfns, span.start_used,
1753                                      span.last_used);
1754                 }
1755                 /* Otherwise covered by an existing access */
1756         }
1757         if (batch_inited)
1758                 batch_destroy(&batch, backup);
1759         update_unpinned(pages);
1760 }
1761
1762 /**
1763  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1764  * @pages: The pages to act on
1765  * @start_index: The first page index in the range
1766  * @last_index: The last page index in the range
1767  * @out_pages: The output array to return the pages
1768  *
1769  * This can be called if the caller is holding a refcount on an
1770  * iopt_pages_access that is known to have already been filled. It quickly reads
1771  * the pages directly from the xarray.
1772  *
1773  * This is part of the SW iommu interface to read pages for in-kernel use.
1774  */
1775 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1776                                  unsigned long start_index,
1777                                  unsigned long last_index,
1778                                  struct page **out_pages)
1779 {
1780         XA_STATE(xas, &pages->pinned_pfns, start_index);
1781         void *entry;
1782
1783         rcu_read_lock();
1784         while (start_index <= last_index) {
1785                 entry = xas_next(&xas);
1786                 if (xas_retry(&xas, entry))
1787                         continue;
1788                 WARN_ON(!xa_is_value(entry));
1789                 *(out_pages++) = pfn_to_page(xa_to_value(entry));
1790                 start_index++;
1791         }
1792         rcu_read_unlock();
1793 }
1794
1795 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1796                                        unsigned long start_index,
1797                                        unsigned long last_index,
1798                                        struct page **out_pages)
1799 {
1800         while (start_index != last_index + 1) {
1801                 unsigned long domain_last;
1802                 struct iopt_area *area;
1803
1804                 area = iopt_pages_find_domain_area(pages, start_index);
1805                 if (WARN_ON(!area))
1806                         return -EINVAL;
1807
1808                 domain_last = min(iopt_area_last_index(area), last_index);
1809                 out_pages = raw_pages_from_domain(area->storage_domain, area,
1810                                                   start_index, domain_last,
1811                                                   out_pages);
1812                 start_index = domain_last + 1;
1813         }
1814         return 0;
1815 }
1816
1817 static int iopt_pages_fill(struct iopt_pages *pages,
1818                            struct pfn_reader_user *user,
1819                            unsigned long start_index,
1820                            unsigned long last_index,
1821                            struct page **out_pages)
1822 {
1823         unsigned long cur_index = start_index;
1824         int rc;
1825
1826         while (cur_index != last_index + 1) {
1827                 user->upages = out_pages + (cur_index - start_index);
1828                 rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1829                 if (rc)
1830                         goto out_unpin;
1831                 cur_index = user->upages_end;
1832         }
1833         return 0;
1834
1835 out_unpin:
1836         if (start_index != cur_index)
1837                 iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1838                                      out_pages);
1839         return rc;
1840 }
1841
1842 /**
1843  * iopt_pages_fill_xarray() - Read PFNs
1844  * @pages: The pages to act on
1845  * @start_index: The first page index in the range
1846  * @last_index: The last page index in the range
1847  * @out_pages: The output array to return the pages, may be NULL
1848  *
1849  * This populates the xarray and returns the pages in out_pages. As the slow
1850  * path this is able to copy pages from other storage tiers into the xarray.
1851  *
1852  * On failure the xarray is left unchanged.
1853  *
1854  * This is part of the SW iommu interface to read pages for in-kernel use.
1855  */
1856 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1857                            unsigned long last_index, struct page **out_pages)
1858 {
1859         struct interval_tree_double_span_iter span;
1860         unsigned long xa_end = start_index;
1861         struct pfn_reader_user user;
1862         int rc;
1863
1864         lockdep_assert_held(&pages->mutex);
1865
1866         pfn_reader_user_init(&user, pages);
1867         user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1868         interval_tree_for_each_double_span(&span, &pages->access_itree,
1869                                            &pages->domains_itree, start_index,
1870                                            last_index) {
1871                 struct page **cur_pages;
1872
1873                 if (span.is_used == 1) {
1874                         cur_pages = out_pages + (span.start_used - start_index);
1875                         iopt_pages_fill_from_xarray(pages, span.start_used,
1876                                                     span.last_used, cur_pages);
1877                         continue;
1878                 }
1879
1880                 if (span.is_used == 2) {
1881                         cur_pages = out_pages + (span.start_used - start_index);
1882                         iopt_pages_fill_from_domain(pages, span.start_used,
1883                                                     span.last_used, cur_pages);
1884                         rc = pages_to_xarray(&pages->pinned_pfns,
1885                                              span.start_used, span.last_used,
1886                                              cur_pages);
1887                         if (rc)
1888                                 goto out_clean_xa;
1889                         xa_end = span.last_used + 1;
1890                         continue;
1891                 }
1892
1893                 /* hole */
1894                 cur_pages = out_pages + (span.start_hole - start_index);
1895                 rc = iopt_pages_fill(pages, &user, span.start_hole,
1896                                      span.last_hole, cur_pages);
1897                 if (rc)
1898                         goto out_clean_xa;
1899                 rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1900                                      span.last_hole, cur_pages);
1901                 if (rc) {
1902                         iopt_pages_err_unpin(pages, span.start_hole,
1903                                              span.last_hole, cur_pages);
1904                         goto out_clean_xa;
1905                 }
1906                 xa_end = span.last_hole + 1;
1907         }
1908         rc = pfn_reader_user_update_pinned(&user, pages);
1909         if (rc)
1910                 goto out_clean_xa;
1911         user.upages = NULL;
1912         pfn_reader_user_destroy(&user, pages);
1913         return 0;
1914
1915 out_clean_xa:
1916         if (start_index != xa_end)
1917                 iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1918         user.upages = NULL;
1919         pfn_reader_user_destroy(&user, pages);
1920         return rc;
1921 }
1922
1923 /*
1924  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1925  * do every scenario and is fully consistent with what an iommu_domain would
1926  * see.
1927  */
1928 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1929                               unsigned long start_index,
1930                               unsigned long last_index, unsigned long offset,
1931                               void *data, unsigned long length,
1932                               unsigned int flags)
1933 {
1934         struct pfn_reader pfns;
1935         int rc;
1936
1937         mutex_lock(&pages->mutex);
1938
1939         rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1940         if (rc)
1941                 goto out_unlock;
1942
1943         while (!pfn_reader_done(&pfns)) {
1944                 unsigned long done;
1945
1946                 done = batch_rw(&pfns.batch, data, offset, length, flags);
1947                 data += done;
1948                 length -= done;
1949                 offset = 0;
1950                 pfn_reader_unpin(&pfns);
1951
1952                 rc = pfn_reader_next(&pfns);
1953                 if (rc)
1954                         goto out_destroy;
1955         }
1956         if (WARN_ON(length != 0))
1957                 rc = -EINVAL;
1958 out_destroy:
1959         pfn_reader_destroy(&pfns);
1960 out_unlock:
1961         mutex_unlock(&pages->mutex);
1962         return rc;
1963 }
1964
1965 /*
1966  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1967  * memory allocations or interval tree searches.
1968  */
1969 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1970                               unsigned long offset, void *data,
1971                               unsigned long length, unsigned int flags)
1972 {
1973         struct page *page = NULL;
1974         int rc;
1975
1976         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1977             WARN_ON(pages->type != IOPT_ADDRESS_USER))
1978                 return -EINVAL;
1979
1980         if (!mmget_not_zero(pages->source_mm))
1981                 return iopt_pages_rw_slow(pages, index, index, offset, data,
1982                                           length, flags);
1983
1984         if (iommufd_should_fail()) {
1985                 rc = -EINVAL;
1986                 goto out_mmput;
1987         }
1988
1989         mmap_read_lock(pages->source_mm);
1990         rc = pin_user_pages_remote(
1991                 pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1992                 1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1993                 NULL);
1994         mmap_read_unlock(pages->source_mm);
1995         if (rc != 1) {
1996                 if (WARN_ON(rc >= 0))
1997                         rc = -EINVAL;
1998                 goto out_mmput;
1999         }
2000         copy_data_page(page, data, offset, length, flags);
2001         unpin_user_page(page);
2002         rc = 0;
2003
2004 out_mmput:
2005         mmput(pages->source_mm);
2006         return rc;
2007 }
2008
/**
 * iopt_pages_rw_access() - Copy to/from a linear slice of the pages
 * @pages: pages to act on
 * @start_byte: First byte of pages to copy to/from
 * @data: Kernel buffer to get/put the data
 * @length: Number of bytes to copy
 * @flags: IOMMUFD_ACCESS_RW_* flags
 *
 * Copy between @data and the bytes [@start_byte, @start_byte + @length) of
 * @pages. The copy method depends on the situation:
 *  - File-backed pages always use the pfn_reader slow path.
 *  - Without IOMMUFD_ACCESS_RW_KTHREAD, when the caller is not running in
 *    pages->source_mm (or the test-only slow-path flag is set), a copy within a
 *    single page remotely pins that page. A larger copy uses the slow path.
 *  - Otherwise copy_to_user()/copy_from_user() is used on pages->uptr. When
 *    needed, source_mm is adopted temporarily with kthread_use_mm().
 *
 * NOTE(review): a zero @length would make last_index wrap around. Callers are
 * presumably expected to reject it; confirm.
 *
 * Return: 0 on success, -EPERM for a write to non-writable pages, -EFAULT if
 * the user copy faults, or an error from the slow/single-page paths.
 */
int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
			 void *data, unsigned long length, unsigned int flags)
{
	unsigned long start_index = start_byte / PAGE_SIZE;
	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
	/* True when the caller is not running in the mm that owns uptr */
	bool change_mm = current->mm != pages->source_mm;
	int rc = 0;

	/* Test hook to force the non-local-mm paths */
	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
		change_mm = true;

	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
		return -EPERM;

	/* File-backed pages have no user VA to copy through */
	if (pages->type == IOPT_ADDRESS_FILE)
		return iopt_pages_rw_slow(pages, start_index, last_index,
					  start_byte % PAGE_SIZE, data, length,
					  flags);

	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
	    WARN_ON(pages->type != IOPT_ADDRESS_USER))
		return -EINVAL;

	/* Only IOMMUFD_ACCESS_RW_KTHREAD callers may switch to source_mm below */
	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
		if (start_index == last_index)
			return iopt_pages_rw_page(pages, start_index,
						  start_byte % PAGE_SIZE, data,
						  length, flags);
		return iopt_pages_rw_slow(pages, start_index, last_index,
					  start_byte % PAGE_SIZE, data, length,
					  flags);
	}

	/*
	 * Try to copy using copy_to_user(). We do this as a fast path and
	 * ignore any pinning inconsistencies, unlike a real DMA path.
	 */
	if (change_mm) {
		/* source_mm can no longer be referenced: use the slow path */
		if (!mmget_not_zero(pages->source_mm))
			return iopt_pages_rw_slow(pages, start_index,
						  last_index,
						  start_byte % PAGE_SIZE, data,
						  length, flags);
		kthread_use_mm(pages->source_mm);
	}

	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
		if (copy_to_user(pages->uptr + start_byte, data, length))
			rc = -EFAULT;
	} else {
		if (copy_from_user(data, pages->uptr + start_byte, length))
			rc = -EFAULT;
	}

	if (change_mm) {
		kthread_unuse_mm(pages->source_mm);
		mmput(pages->source_mm);
	}

	return rc;
}
2082
2083 static struct iopt_pages_access *
2084 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
2085                             unsigned long last)
2086 {
2087         struct interval_tree_node *node;
2088
2089         lockdep_assert_held(&pages->mutex);
2090
2091         /* There can be overlapping ranges in this interval tree */
2092         for (node = interval_tree_iter_first(&pages->access_itree, index, last);
2093              node; node = interval_tree_iter_next(node, index, last))
2094                 if (node->start == index && node->last == last)
2095                         return container_of(node, struct iopt_pages_access,
2096                                             node);
2097         return NULL;
2098 }
2099
2100 /**
2101  * iopt_area_add_access() - Record an in-knerel access for PFNs
2102  * @area: The source of PFNs
2103  * @start_index: First page index
2104  * @last_index: Inclusive last page index
2105  * @out_pages: Output list of struct page's representing the PFNs
2106  * @flags: IOMMUFD_ACCESS_RW_* flags
2107  *
2108  * Record that an in-kernel access will be accessing the pages, ensure they are
2109  * pinned, and return the PFNs as a simple list of 'struct page *'.
2110  *
2111  * This should be undone through a matching call to iopt_area_remove_access()
2112  */
2113 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
2114                           unsigned long last_index, struct page **out_pages,
2115                           unsigned int flags)
2116 {
2117         struct iopt_pages *pages = area->pages;
2118         struct iopt_pages_access *access;
2119         int rc;
2120
2121         if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
2122                 return -EPERM;
2123
2124         mutex_lock(&pages->mutex);
2125         access = iopt_pages_get_exact_access(pages, start_index, last_index);
2126         if (access) {
2127                 area->num_accesses++;
2128                 access->users++;
2129                 iopt_pages_fill_from_xarray(pages, start_index, last_index,
2130                                             out_pages);
2131                 mutex_unlock(&pages->mutex);
2132                 return 0;
2133         }
2134
2135         access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
2136         if (!access) {
2137                 rc = -ENOMEM;
2138                 goto err_unlock;
2139         }
2140
2141         rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
2142         if (rc)
2143                 goto err_free;
2144
2145         access->node.start = start_index;
2146         access->node.last = last_index;
2147         access->users = 1;
2148         area->num_accesses++;
2149         interval_tree_insert(&access->node, &pages->access_itree);
2150         mutex_unlock(&pages->mutex);
2151         return 0;
2152
2153 err_free:
2154         kfree(access);
2155 err_unlock:
2156         mutex_unlock(&pages->mutex);
2157         return rc;
2158 }
2159
2160 /**
2161  * iopt_area_remove_access() - Release an in-kernel access for PFNs
2162  * @area: The source of PFNs
2163  * @start_index: First page index
2164  * @last_index: Inclusive last page index
2165  *
2166  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
2167  * must stop using the PFNs before calling this.
2168  */
2169 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
2170                              unsigned long last_index)
2171 {
2172         struct iopt_pages *pages = area->pages;
2173         struct iopt_pages_access *access;
2174
2175         mutex_lock(&pages->mutex);
2176         access = iopt_pages_get_exact_access(pages, start_index, last_index);
2177         if (WARN_ON(!access))
2178                 goto out_unlock;
2179
2180         WARN_ON(area->num_accesses == 0 || access->users == 0);
2181         area->num_accesses--;
2182         access->users--;
2183         if (access->users)
2184                 goto out_unlock;
2185
2186         interval_tree_remove(&access->node, &pages->access_itree);
2187         iopt_pages_unfill_xarray(pages, start_index, last_index);
2188         kfree(access);
2189 out_unlock:
2190         mutex_unlock(&pages->mutex);
2191 }
This page took 0.155283 seconds and 4 git commands to generate.