]> Git Repo - J-linux.git/blob - drivers/iommu/iommufd/selftest.c
Merge tag 'vfs-6.13-rc7.fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/vfs/vfs
[J-linux.git] / drivers / iommu / iommufd / selftest.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * Kernel side components to support tools/testing/selftests/iommu
5  */
6 #include <linux/anon_inodes.h>
7 #include <linux/debugfs.h>
8 #include <linux/fault-inject.h>
9 #include <linux/file.h>
10 #include <linux/iommu.h>
11 #include <linux/platform_device.h>
12 #include <linux/slab.h>
13 #include <linux/xarray.h>
14 #include <uapi/linux/iommufd.h>
15
16 #include "../iommu-priv.h"
17 #include "io_pagetable.h"
18 #include "iommufd_private.h"
19 #include "iommufd_test.h"
20
/* Fault-injection attribute for the selftest; its consumers are not visible in this part of the file. */
static DECLARE_FAULT_ATTR(fail_iommufd);
/* NOTE(review): presumably the selftest's debugfs directory; it is populated outside this view. */
static struct dentry *dbgfs_root;
/* NOTE(review): presumably the platform device backing the mock IOMMU; it is set outside this view. */
static struct platform_device *selftest_iommu_dev;
/* Forward declarations: both ops tables are defined further down in this file. */
static const struct iommu_ops mock_ops;
static struct iommu_domain_ops domain_nested_ops;

/*
 * Not static, so other translation units can reference it.
 * NOTE(review): the unit of this limit and who enforces it are not visible
 * here - confirm against the users before changing the value.
 */
size_t iommufd_test_memory_limit = 65536;
28
/* A bus_type bundled with a notifier_block for the mock bus. */
struct mock_bus_type {
        struct bus_type bus;
        /* NOTE(review): registration and use of this notifier are outside this view. */
        struct notifier_block nb;
};
33
/*
 * The single mock bus instance. mock_dev_create() puts every mock device on
 * this bus and mock_probe_device() only accepts devices that sit on it.
 */
static struct mock_bus_type iommufd_mock_bus_type = {
        .bus = {
                .name = "iommufd_mock",
        },
};
39
/* Allocator for mock_dev.id: taken in mock_dev_create(), returned in mock_dev_release(). */
static DEFINE_IDA(mock_dev_ida);
41
/* Constants describing the mock page table and the metadata kept in its xarray entries. */
enum {
        /* Bit in mock_iommu_domain.flags: dirty tracking is enabled */
        MOCK_DIRTY_TRACK = 1,
        /* Granule of the mock page table: half of a CPU page */
        MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,
        /* Optional second page size: 512 granules */
        MOCK_HUGE_PAGE_SIZE = 512 * MOCK_IO_PAGE_SIZE,

        /*
         * Like a real page table alignment requires the low bits of the address
         * to be zero. xarray also requires the high bit to be zero, so we store
         * the pfns shifted. The upper bits are used for metadata.
         */
        MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,

        /* First bit above the pfn field; the metadata flags start here */
        _MOCK_PFN_START = MOCK_PFN_MASK + 1,
        /* Note that START and LAST deliberately share the same bit value */
        MOCK_PFN_START_IOVA = _MOCK_PFN_START,
        MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
        /* Entry is marked dirty (tested and cleared by mock_test_and_clear_dirty()) */
        MOCK_PFN_DIRTY_IOVA = _MOCK_PFN_START << 1,
        /* Entry was mapped with a page size other than MOCK_IO_PAGE_SIZE */
        MOCK_PFN_HUGE_IOVA = _MOCK_PFN_START << 2,
};
60
61 /*
62  * Syzkaller has trouble randomizing the correct iova to use since it is linked
63  * to the map ioctl's output, and it has no ide about that. So, simplify things.
64  * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
65  * value. This has a much smaller randomization space and syzkaller can hit it.
66  */
67 static unsigned long __iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
68                                                   u64 *iova)
69 {
70         struct syz_layout {
71                 __u32 nth_area;
72                 __u32 offset;
73         };
74         struct syz_layout *syz = (void *)iova;
75         unsigned int nth = syz->nth_area;
76         struct iopt_area *area;
77
78         down_read(&iopt->iova_rwsem);
79         for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
80              area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
81                 if (nth == 0) {
82                         up_read(&iopt->iova_rwsem);
83                         return iopt_area_iova(area) + syz->offset;
84                 }
85                 nth--;
86         }
87         up_read(&iopt->iova_rwsem);
88
89         return 0;
90 }
91
92 static unsigned long iommufd_test_syz_conv_iova(struct iommufd_access *access,
93                                                 u64 *iova)
94 {
95         unsigned long ret;
96
97         mutex_lock(&access->ioas_lock);
98         if (!access->ioas) {
99                 mutex_unlock(&access->ioas_lock);
100                 return 0;
101         }
102         ret = __iommufd_test_syz_conv_iova(&access->ioas->iopt, iova);
103         mutex_unlock(&access->ioas_lock);
104         return ret;
105 }
106
107 void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
108                                    unsigned int ioas_id, u64 *iova, u32 *flags)
109 {
110         struct iommufd_ioas *ioas;
111
112         if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
113                 return;
114         *flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;
115
116         ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
117         if (IS_ERR(ioas))
118                 return;
119         *iova = __iommufd_test_syz_conv_iova(&ioas->iopt, iova);
120         iommufd_put_object(ucmd->ictx, &ioas->obj);
121 }
122
/* Paging domain of the mock IOMMU; its "page table" is an xarray. */
struct mock_iommu_domain {
        /* MOCK_DIRTY_TRACK is set here while dirty tracking is enabled */
        unsigned long flags;
        /* Embedded core domain; to_mock_domain() converts back from it */
        struct iommu_domain domain;
        /* Indexed by iova / MOCK_IO_PAGE_SIZE; values hold the pfn plus MOCK_PFN_* flags */
        struct xarray pfns;
};
128
129 static inline struct mock_iommu_domain *
130 to_mock_domain(struct iommu_domain *domain)
131 {
132         return container_of(domain, struct mock_iommu_domain, domain);
133 }
134
/* Nested domain of the mock IOMMU. */
struct mock_iommu_domain_nested {
        /* Embedded core domain; to_mock_nested() converts back from it */
        struct iommu_domain domain;
        /* Set only when the domain is allocated through a vIOMMU */
        struct mock_viommu *mock_viommu;
        /* Paging domain this nested domain sits on top of */
        struct mock_iommu_domain *parent;
        /* Fake IOTLB: filled from the user config at allocation, zeroed by invalidation */
        u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
};
141
142 static inline struct mock_iommu_domain_nested *
143 to_mock_nested(struct iommu_domain *domain)
144 {
145         return container_of(domain, struct mock_iommu_domain_nested, domain);
146 }
147
/* vIOMMU object of the mock driver. */
struct mock_viommu {
        /* Embedded core object; to_mock_viommu() converts back from it */
        struct iommufd_viommu core;
        /*
         * Used as the parent of nested domains allocated through this vIOMMU.
         * NOTE(review): the assignment of this field is not visible in this
         * part of the file.
         */
        struct mock_iommu_domain *s2_parent;
};
152
153 static inline struct mock_viommu *to_mock_viommu(struct iommufd_viommu *viommu)
154 {
155         return container_of(viommu, struct mock_viommu, core);
156 }
157
/* Discriminator for the union inside struct selftest_obj. */
enum selftest_obj_type {
        /* The idev member of the union is valid */
        TYPE_IDEV,
};
161
/* A fake device placed on the mock bus. */
struct mock_dev {
        /* Embedded core device; to_mock_dev() converts back from it */
        struct device dev;
        /* MOCK_FLAGS_DEVICE_* bits supplied to mock_dev_create() */
        unsigned long flags;
        /* Id taken from mock_dev_ida; also used to build the device name */
        int id;
        /* Fake device cache; entries are zeroed by vIOMMU cache invalidation */
        u32 cache[MOCK_DEV_CACHE_NUM];
};
168
169 static inline struct mock_dev *to_mock_dev(struct device *dev)
170 {
171         return container_of(dev, struct mock_dev, dev);
172 }
173
/* iommufd object created by the selftest; @type selects the valid union member. */
struct selftest_obj {
        /* Embedded core object; to_selftest_obj() converts back from it */
        struct iommufd_object obj;
        enum selftest_obj_type type;

        union {
                /* Valid when type == TYPE_IDEV. NOTE(review): the users of these fields are outside this view. */
                struct {
                        struct iommufd_device *idev;
                        struct iommufd_ctx *ictx;
                        struct mock_dev *mock_dev;
                } idev;
        };
};
186
187 static inline struct selftest_obj *to_selftest_obj(struct iommufd_object *obj)
188 {
189         return container_of(obj, struct selftest_obj, obj);
190 }
191
192 static int mock_domain_nop_attach(struct iommu_domain *domain,
193                                   struct device *dev)
194 {
195         struct mock_dev *mdev = to_mock_dev(dev);
196
197         if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
198                 return -EINVAL;
199
200         return 0;
201 }
202
/* Ops of the blocking domain: only attach is implemented. */
static const struct iommu_domain_ops mock_blocking_ops = {
        .attach_dev = mock_domain_nop_attach,
};
206
/* Static blocking domain; mock_ops uses it as both default_domain and blocked_domain. */
static struct iommu_domain mock_blocking_domain = {
        .type = IOMMU_DOMAIN_BLOCKED,
        .ops = &mock_blocking_ops,
};
211
212 static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
213 {
214         struct iommu_test_hw_info *info;
215
216         info = kzalloc(sizeof(*info), GFP_KERNEL);
217         if (!info)
218                 return ERR_PTR(-ENOMEM);
219
220         info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
221         *length = sizeof(*info);
222         *type = IOMMU_HW_INFO_TYPE_SELFTEST;
223
224         return info;
225 }
226
227 static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
228                                           bool enable)
229 {
230         struct mock_iommu_domain *mock = to_mock_domain(domain);
231         unsigned long flags = mock->flags;
232
233         if (enable && !domain->dirty_ops)
234                 return -EINVAL;
235
236         /* No change? */
237         if (!(enable ^ !!(flags & MOCK_DIRTY_TRACK)))
238                 return 0;
239
240         flags = (enable ? flags | MOCK_DIRTY_TRACK : flags & ~MOCK_DIRTY_TRACK);
241
242         mock->flags = flags;
243         return 0;
244 }
245
246 static bool mock_test_and_clear_dirty(struct mock_iommu_domain *mock,
247                                       unsigned long iova, size_t page_size,
248                                       unsigned long flags)
249 {
250         unsigned long cur, end = iova + page_size - 1;
251         bool dirty = false;
252         void *ent, *old;
253
254         for (cur = iova; cur < end; cur += MOCK_IO_PAGE_SIZE) {
255                 ent = xa_load(&mock->pfns, cur / MOCK_IO_PAGE_SIZE);
256                 if (!ent || !(xa_to_value(ent) & MOCK_PFN_DIRTY_IOVA))
257                         continue;
258
259                 dirty = true;
260                 /* Clear dirty */
261                 if (!(flags & IOMMU_DIRTY_NO_CLEAR)) {
262                         unsigned long val;
263
264                         val = xa_to_value(ent) & ~MOCK_PFN_DIRTY_IOVA;
265                         old = xa_store(&mock->pfns, cur / MOCK_IO_PAGE_SIZE,
266                                        xa_mk_value(val), GFP_KERNEL);
267                         WARN_ON_ONCE(ent != old);
268                 }
269         }
270
271         return dirty;
272 }
273
/*
 * read_and_clear_dirty op: walk the range starting at @iova and record every
 * page that mock_test_and_clear_dirty() reports as dirty into @dirty.
 *
 * @flags is passed through to mock_test_and_clear_dirty().
 *
 * Returns -EINVAL when dirty tracking is not enabled on the domain and the
 * caller supplied a bitmap, otherwise 0. The body is a do/while, so the first
 * step is taken before the range end is compared.
 */
static int mock_domain_read_and_clear_dirty(struct iommu_domain *domain,
                                            unsigned long iova, size_t size,
                                            unsigned long flags,
                                            struct iommu_dirty_bitmap *dirty)
{
        struct mock_iommu_domain *mock = to_mock_domain(domain);
        unsigned long end = iova + size;
        void *ent;

        if (!(mock->flags & MOCK_DIRTY_TRACK) && dirty->bitmap)
                return -EINVAL;

        do {
                /* Step size; raised below when the entry is tagged as huge */
                unsigned long pgsize = MOCK_IO_PAGE_SIZE;
                unsigned long head;

                ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
                if (!ent) {
                        /* Nothing stored here; "continue" jumps to the while condition */
                        iova += pgsize;
                        continue;
                }

                if (xa_to_value(ent) & MOCK_PFN_HUGE_IOVA)
                        pgsize = MOCK_HUGE_PAGE_SIZE;
                /* Start of the pgsize-aligned page that contains iova */
                head = iova & ~(pgsize - 1);

                /* Clear dirty */
                if (mock_test_and_clear_dirty(mock, head, pgsize, flags))
                        iommu_dirty_bitmap_record(dirty, iova, pgsize);
                iova += pgsize;
        } while (iova < end);

        return 0;
}
308
/* Dirty tracking ops; mock_domain_alloc_paging_flags() installs them when dirty tracking is requested. */
static const struct iommu_dirty_ops dirty_ops = {
        .set_dirty_tracking = mock_domain_set_dirty_tracking,
        .read_and_clear_dirty = mock_domain_read_and_clear_dirty,
};
313
314 static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
315 {
316         struct mock_dev *mdev = to_mock_dev(dev);
317         struct mock_iommu_domain *mock;
318
319         mock = kzalloc(sizeof(*mock), GFP_KERNEL);
320         if (!mock)
321                 return NULL;
322         mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
323         mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
324         mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
325         if (dev && mdev->flags & MOCK_FLAGS_DEVICE_HUGE_IOVA)
326                 mock->domain.pgsize_bitmap |= MOCK_HUGE_PAGE_SIZE;
327         mock->domain.ops = mock_ops.default_domain_ops;
328         mock->domain.type = IOMMU_DOMAIN_UNMANAGED;
329         xa_init(&mock->pfns);
330         return &mock->domain;
331 }
332
333 static struct mock_iommu_domain_nested *
334 __mock_domain_alloc_nested(const struct iommu_user_data *user_data)
335 {
336         struct mock_iommu_domain_nested *mock_nested;
337         struct iommu_hwpt_selftest user_cfg;
338         int rc, i;
339
340         if (user_data->type != IOMMU_HWPT_DATA_SELFTEST)
341                 return ERR_PTR(-EOPNOTSUPP);
342
343         rc = iommu_copy_struct_from_user(&user_cfg, user_data,
344                                          IOMMU_HWPT_DATA_SELFTEST, iotlb);
345         if (rc)
346                 return ERR_PTR(rc);
347
348         mock_nested = kzalloc(sizeof(*mock_nested), GFP_KERNEL);
349         if (!mock_nested)
350                 return ERR_PTR(-ENOMEM);
351         mock_nested->domain.ops = &domain_nested_ops;
352         mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
353         for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
354                 mock_nested->iotlb[i] = user_cfg.iotlb;
355         return mock_nested;
356 }
357
358 static struct iommu_domain *
359 mock_domain_alloc_nested(struct device *dev, struct iommu_domain *parent,
360                          u32 flags, const struct iommu_user_data *user_data)
361 {
362         struct mock_iommu_domain_nested *mock_nested;
363         struct mock_iommu_domain *mock_parent;
364
365         if (flags)
366                 return ERR_PTR(-EOPNOTSUPP);
367         if (!parent || parent->ops != mock_ops.default_domain_ops)
368                 return ERR_PTR(-EINVAL);
369
370         mock_parent = to_mock_domain(parent);
371         if (!mock_parent)
372                 return ERR_PTR(-EINVAL);
373
374         mock_nested = __mock_domain_alloc_nested(user_data);
375         if (IS_ERR(mock_nested))
376                 return ERR_CAST(mock_nested);
377         mock_nested->parent = mock_parent;
378         return &mock_nested->domain;
379 }
380
381 static struct iommu_domain *
382 mock_domain_alloc_paging_flags(struct device *dev, u32 flags,
383                                const struct iommu_user_data *user_data)
384 {
385         bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
386         const u32 PAGING_FLAGS = IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
387                                  IOMMU_HWPT_ALLOC_NEST_PARENT;
388         bool no_dirty_ops = to_mock_dev(dev)->flags &
389                             MOCK_FLAGS_DEVICE_NO_DIRTY;
390         struct iommu_domain *domain;
391
392         if (user_data)
393                 return ERR_PTR(-EOPNOTSUPP);
394         if ((flags & ~PAGING_FLAGS) || (has_dirty_flag && no_dirty_ops))
395                 return ERR_PTR(-EOPNOTSUPP);
396
397         domain = mock_domain_alloc_paging(dev);
398         if (!domain)
399                 return ERR_PTR(-ENOMEM);
400         if (has_dirty_flag)
401                 domain->dirty_ops = &dirty_ops;
402         return domain;
403 }
404
405 static void mock_domain_free(struct iommu_domain *domain)
406 {
407         struct mock_iommu_domain *mock = to_mock_domain(domain);
408
409         WARN_ON(!xa_empty(&mock->pfns));
410         kfree(mock);
411 }
412
/*
 * map_pages op: record @pgcount pages of @pgsize bytes in the pfns xarray.
 *
 * One entry is stored per MOCK_IO_PAGE_SIZE granule, at index
 * iova / MOCK_IO_PAGE_SIZE, holding paddr / MOCK_IO_PAGE_SIZE plus flags:
 * the first granule carries MOCK_PFN_START_IOVA, the final granule carries
 * MOCK_PFN_LAST_IOVA, and every granule carries MOCK_PFN_HUGE_IOVA when
 * @pgsize differs from MOCK_IO_PAGE_SIZE.
 *
 * *mapped is advanced by MOCK_IO_PAGE_SIZE for each stored granule; it is
 * not initialised here and is not rewound on failure.
 *
 * Returns 0, -ENOENT when iommufd_should_fail() fires, or the xa_store()
 * error after erasing the entries stored by this call.
 */
static int mock_domain_map_pages(struct iommu_domain *domain,
                                 unsigned long iova, phys_addr_t paddr,
                                 size_t pgsize, size_t pgcount, int prot,
                                 gfp_t gfp, size_t *mapped)
{
        struct mock_iommu_domain *mock = to_mock_domain(domain);
        /* Flags for the next granule; starts with the "first granule" marker */
        unsigned long flags = MOCK_PFN_START_IOVA;
        /* Kept so a failed store can undo everything from the beginning */
        unsigned long start_iova = iova;

        /*
         * xarray does not reliably work with fault injection because it does a
         * retry allocation, so put our own failure point.
         */
        if (iommufd_should_fail())
                return -ENOENT;

        WARN_ON(iova % MOCK_IO_PAGE_SIZE);
        WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
        for (; pgcount; pgcount--) {
                size_t cur;

                for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
                        void *old;

                        /* Last granule of the last page */
                        if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
                                flags = MOCK_PFN_LAST_IOVA;
                        if (pgsize != MOCK_IO_PAGE_SIZE) {
                                flags |= MOCK_PFN_HUGE_IOVA;
                        }
                        old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
                                       xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
                                                   flags),
                                       gfp);
                        if (xa_is_err(old)) {
                                /* Roll back the granules stored so far */
                                for (; start_iova != iova;
                                     start_iova += MOCK_IO_PAGE_SIZE)
                                        xa_erase(&mock->pfns,
                                                 start_iova /
                                                         MOCK_IO_PAGE_SIZE);
                                return xa_err(old);
                        }
                        /* A previous entry here means the IOVA was already mapped */
                        WARN_ON(old);
                        iova += MOCK_IO_PAGE_SIZE;
                        paddr += MOCK_IO_PAGE_SIZE;
                        *mapped += MOCK_IO_PAGE_SIZE;
                        /* Only the first granule keeps the start marker */
                        flags = 0;
                }
        }
        return 0;
}
463
/*
 * unmap_pages op: erase @pgcount pages of @pgsize bytes from the pfns xarray,
 * one MOCK_IO_PAGE_SIZE granule at a time.
 *
 * @iotlb_gather is not used. Returns the number of bytes walked, which is
 * counted whether or not an entry was present at each granule.
 */
static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
                                      unsigned long iova, size_t pgsize,
                                      size_t pgcount,
                                      struct iommu_iotlb_gather *iotlb_gather)
{
        struct mock_iommu_domain *mock = to_mock_domain(domain);
        /* True until the first granule of this call has been checked */
        bool first = true;
        size_t ret = 0;
        void *ent;

        WARN_ON(iova % MOCK_IO_PAGE_SIZE);
        WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);

        for (; pgcount; pgcount--) {
                size_t cur;

                for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
                        ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);

                        /*
                         * iommufd generates unmaps that must be a strict
                         * superset of the maps performed. So every
                         * starting/ending IOVA should have been an iova passed
                         * to map.
                         *
                         * This simple logic doesn't work when the HUGE_PAGE is
                         * turned on since the core code will automatically
                         * switch between the two page sizes creating a break in
                         * the unmap calls. The break can land in the middle of
                         * contiguous IOVA.
                         */
                        if (!(domain->pgsize_bitmap & MOCK_HUGE_PAGE_SIZE)) {
                                if (first) {
                                        WARN_ON(ent && !(xa_to_value(ent) &
                                                         MOCK_PFN_START_IOVA));
                                        first = false;
                                }
                                /* Last granule of the last page */
                                if (pgcount == 1 &&
                                    cur + MOCK_IO_PAGE_SIZE == pgsize)
                                        WARN_ON(ent && !(xa_to_value(ent) &
                                                         MOCK_PFN_LAST_IOVA));
                        }

                        iova += MOCK_IO_PAGE_SIZE;
                        ret += MOCK_IO_PAGE_SIZE;
                }
        }
        return ret;
}
513
514 static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
515                                             dma_addr_t iova)
516 {
517         struct mock_iommu_domain *mock = to_mock_domain(domain);
518         void *ent;
519
520         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
521         ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
522         WARN_ON(!ent);
523         return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
524 }
525
526 static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
527 {
528         struct mock_dev *mdev = to_mock_dev(dev);
529
530         switch (cap) {
531         case IOMMU_CAP_CACHE_COHERENCY:
532                 return true;
533         case IOMMU_CAP_DIRTY_TRACKING:
534                 return !(mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY);
535         default:
536                 break;
537         }
538
539         return false;
540 }
541
/*
 * IOPF queue used by the dev_enable_feat/dev_disable_feat ops, which fail
 * with -ENODEV while it is NULL. It is assigned outside this view.
 */
static struct iopf_queue *mock_iommu_iopf_queue;

/* The single mock IOMMU instance, returned by mock_probe_device(). */
static struct mock_iommu_device {
        struct iommu_device iommu_dev;
        /* Completed by mock_viommu_destroy() when users drops to zero */
        struct completion complete;
        /* Incremented by mock_viommu_alloc(), decremented by mock_viommu_destroy() */
        refcount_t users;
} mock_iommu;
549
550 static struct iommu_device *mock_probe_device(struct device *dev)
551 {
552         if (dev->bus != &iommufd_mock_bus_type.bus)
553                 return ERR_PTR(-ENODEV);
554         return &mock_iommu.iommu_dev;
555 }
556
/* page_response op: intentionally empty, the mock driver has nothing to do with the response. */
static void mock_domain_page_response(struct device *dev, struct iopf_fault *evt,
                                      struct iommu_page_response *msg)
{
}
561
562 static int mock_dev_enable_feat(struct device *dev, enum iommu_dev_features feat)
563 {
564         if (feat != IOMMU_DEV_FEAT_IOPF || !mock_iommu_iopf_queue)
565                 return -ENODEV;
566
567         return iopf_queue_add_device(mock_iommu_iopf_queue, dev);
568 }
569
570 static int mock_dev_disable_feat(struct device *dev, enum iommu_dev_features feat)
571 {
572         if (feat != IOMMU_DEV_FEAT_IOPF || !mock_iommu_iopf_queue)
573                 return -ENODEV;
574
575         iopf_queue_remove_device(mock_iommu_iopf_queue, dev);
576
577         return 0;
578 }
579
580 static void mock_viommu_destroy(struct iommufd_viommu *viommu)
581 {
582         struct mock_iommu_device *mock_iommu = container_of(
583                 viommu->iommu_dev, struct mock_iommu_device, iommu_dev);
584
585         if (refcount_dec_and_test(&mock_iommu->users))
586                 complete(&mock_iommu->complete);
587
588         /* iommufd core frees mock_viommu and viommu */
589 }
590
591 static struct iommu_domain *
592 mock_viommu_alloc_domain_nested(struct iommufd_viommu *viommu, u32 flags,
593                                 const struct iommu_user_data *user_data)
594 {
595         struct mock_viommu *mock_viommu = to_mock_viommu(viommu);
596         struct mock_iommu_domain_nested *mock_nested;
597
598         if (flags & ~IOMMU_HWPT_FAULT_ID_VALID)
599                 return ERR_PTR(-EOPNOTSUPP);
600
601         mock_nested = __mock_domain_alloc_nested(user_data);
602         if (IS_ERR(mock_nested))
603                 return ERR_CAST(mock_nested);
604         mock_nested->mock_viommu = mock_viommu;
605         mock_nested->parent = mock_viommu->s2_parent;
606         return &mock_nested->domain;
607 }
608
609 static int mock_viommu_cache_invalidate(struct iommufd_viommu *viommu,
610                                         struct iommu_user_data_array *array)
611 {
612         struct iommu_viommu_invalidate_selftest *cmds;
613         struct iommu_viommu_invalidate_selftest *cur;
614         struct iommu_viommu_invalidate_selftest *end;
615         int rc;
616
617         /* A zero-length array is allowed to validate the array type */
618         if (array->entry_num == 0 &&
619             array->type == IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST) {
620                 array->entry_num = 0;
621                 return 0;
622         }
623
624         cmds = kcalloc(array->entry_num, sizeof(*cmds), GFP_KERNEL);
625         if (!cmds)
626                 return -ENOMEM;
627         cur = cmds;
628         end = cmds + array->entry_num;
629
630         static_assert(sizeof(*cmds) == 3 * sizeof(u32));
631         rc = iommu_copy_struct_from_full_user_array(
632                 cmds, sizeof(*cmds), array,
633                 IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST);
634         if (rc)
635                 goto out;
636
637         while (cur != end) {
638                 struct mock_dev *mdev;
639                 struct device *dev;
640                 int i;
641
642                 if (cur->flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
643                         rc = -EOPNOTSUPP;
644                         goto out;
645                 }
646
647                 if (cur->cache_id > MOCK_DEV_CACHE_ID_MAX) {
648                         rc = -EINVAL;
649                         goto out;
650                 }
651
652                 xa_lock(&viommu->vdevs);
653                 dev = iommufd_viommu_find_dev(viommu,
654                                               (unsigned long)cur->vdev_id);
655                 if (!dev) {
656                         xa_unlock(&viommu->vdevs);
657                         rc = -EINVAL;
658                         goto out;
659                 }
660                 mdev = container_of(dev, struct mock_dev, dev);
661
662                 if (cur->flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
663                         /* Invalidate all cache entries and ignore cache_id */
664                         for (i = 0; i < MOCK_DEV_CACHE_NUM; i++)
665                                 mdev->cache[i] = 0;
666                 } else {
667                         mdev->cache[cur->cache_id] = 0;
668                 }
669                 xa_unlock(&viommu->vdevs);
670
671                 cur++;
672         }
673 out:
674         array->entry_num = cur - cmds;
675         kfree(cmds);
676         return rc;
677 }
678
/* vIOMMU ops handed to iommufd_viommu_alloc() by mock_viommu_alloc(). */
static struct iommufd_viommu_ops mock_viommu_ops = {
        .destroy = mock_viommu_destroy,
        .alloc_domain_nested = mock_viommu_alloc_domain_nested,
        .cache_invalidate = mock_viommu_cache_invalidate,
};
684
685 static struct iommufd_viommu *mock_viommu_alloc(struct device *dev,
686                                                 struct iommu_domain *domain,
687                                                 struct iommufd_ctx *ictx,
688                                                 unsigned int viommu_type)
689 {
690         struct mock_iommu_device *mock_iommu =
691                 iommu_get_iommu_dev(dev, struct mock_iommu_device, iommu_dev);
692         struct mock_viommu *mock_viommu;
693
694         if (viommu_type != IOMMU_VIOMMU_TYPE_SELFTEST)
695                 return ERR_PTR(-EOPNOTSUPP);
696
697         mock_viommu = iommufd_viommu_alloc(ictx, struct mock_viommu, core,
698                                            &mock_viommu_ops);
699         if (IS_ERR(mock_viommu))
700                 return ERR_CAST(mock_viommu);
701
702         refcount_inc(&mock_iommu->users);
703         return &mock_viommu->core;
704 }
705
/*
 * Driver ops of the mock IOMMU. The anonymous default_domain_ops table is
 * also used as an identity check: other code in this file compares a
 * domain's ops pointer against it to recognise mock paging domains.
 */
static const struct iommu_ops mock_ops = {
        /*
         * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
         * because it is zero.
         */
        .default_domain = &mock_blocking_domain,
        .blocked_domain = &mock_blocking_domain,
        .owner = THIS_MODULE,
        .pgsize_bitmap = MOCK_IO_PAGE_SIZE,
        .hw_info = mock_domain_hw_info,
        .domain_alloc_paging = mock_domain_alloc_paging,
        .domain_alloc_paging_flags = mock_domain_alloc_paging_flags,
        .domain_alloc_nested = mock_domain_alloc_nested,
        .capable = mock_domain_capable,
        .device_group = generic_device_group,
        .probe_device = mock_probe_device,
        .page_response = mock_domain_page_response,
        .dev_enable_feat = mock_dev_enable_feat,
        .dev_disable_feat = mock_dev_disable_feat,
        .user_pasid_table = true,
        .viommu_alloc = mock_viommu_alloc,
        /* Ops of paging domains created by mock_domain_alloc_paging() */
        .default_domain_ops =
                &(struct iommu_domain_ops){
                        .free = mock_domain_free,
                        .attach_dev = mock_domain_nop_attach,
                        .map_pages = mock_domain_map_pages,
                        .unmap_pages = mock_domain_unmap_pages,
                        .iova_to_phys = mock_domain_iova_to_phys,
                },
};
736
/* free op for nested domains: release the containing object. */
static void mock_domain_free_nested(struct iommu_domain *domain)
{
        struct mock_iommu_domain_nested *victim = to_mock_nested(domain);

        kfree(victim);
}
741
742 static int
743 mock_domain_cache_invalidate_user(struct iommu_domain *domain,
744                                   struct iommu_user_data_array *array)
745 {
746         struct mock_iommu_domain_nested *mock_nested = to_mock_nested(domain);
747         struct iommu_hwpt_invalidate_selftest inv;
748         u32 processed = 0;
749         int i = 0, j;
750         int rc = 0;
751
752         if (array->type != IOMMU_HWPT_INVALIDATE_DATA_SELFTEST) {
753                 rc = -EINVAL;
754                 goto out;
755         }
756
757         for ( ; i < array->entry_num; i++) {
758                 rc = iommu_copy_struct_from_user_array(&inv, array,
759                                                        IOMMU_HWPT_INVALIDATE_DATA_SELFTEST,
760                                                        i, iotlb_id);
761                 if (rc)
762                         break;
763
764                 if (inv.flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
765                         rc = -EOPNOTSUPP;
766                         break;
767                 }
768
769                 if (inv.iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX) {
770                         rc = -EINVAL;
771                         break;
772                 }
773
774                 if (inv.flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
775                         /* Invalidate all mock iotlb entries and ignore iotlb_id */
776                         for (j = 0; j < MOCK_NESTED_DOMAIN_IOTLB_NUM; j++)
777                                 mock_nested->iotlb[j] = 0;
778                 } else {
779                         mock_nested->iotlb[inv.iotlb_id] = 0;
780                 }
781
782                 processed++;
783         }
784
785 out:
786         array->entry_num = processed;
787         return rc;
788 }
789
/*
 * Ops of nested domains. get_md_pagetable_nested() also compares against
 * this table to recognise mock nested domains.
 */
static struct iommu_domain_ops domain_nested_ops = {
        .free = mock_domain_free_nested,
        .attach_dev = mock_domain_nop_attach,
        .cache_invalidate_user = mock_domain_cache_invalidate_user,
};
795
796 static inline struct iommufd_hw_pagetable *
797 __get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, u32 hwpt_type)
798 {
799         struct iommufd_object *obj;
800
801         obj = iommufd_get_object(ucmd->ictx, mockpt_id, hwpt_type);
802         if (IS_ERR(obj))
803                 return ERR_CAST(obj);
804         return container_of(obj, struct iommufd_hw_pagetable, obj);
805 }
806
807 static inline struct iommufd_hw_pagetable *
808 get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
809                  struct mock_iommu_domain **mock)
810 {
811         struct iommufd_hw_pagetable *hwpt;
812
813         hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_PAGING);
814         if (IS_ERR(hwpt))
815                 return hwpt;
816         if (hwpt->domain->type != IOMMU_DOMAIN_UNMANAGED ||
817             hwpt->domain->ops != mock_ops.default_domain_ops) {
818                 iommufd_put_object(ucmd->ictx, &hwpt->obj);
819                 return ERR_PTR(-EINVAL);
820         }
821         *mock = to_mock_domain(hwpt->domain);
822         return hwpt;
823 }
824
825 static inline struct iommufd_hw_pagetable *
826 get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
827                         struct mock_iommu_domain_nested **mock_nested)
828 {
829         struct iommufd_hw_pagetable *hwpt;
830
831         hwpt = __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_NESTED);
832         if (IS_ERR(hwpt))
833                 return hwpt;
834         if (hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
835             hwpt->domain->ops != &domain_nested_ops) {
836                 iommufd_put_object(ucmd->ictx, &hwpt->obj);
837                 return ERR_PTR(-EINVAL);
838         }
839         *mock_nested = to_mock_nested(hwpt->domain);
840         return hwpt;
841 }
842
843 static void mock_dev_release(struct device *dev)
844 {
845         struct mock_dev *mdev = to_mock_dev(dev);
846
847         ida_free(&mock_dev_ida, mdev->id);
848         kfree(mdev);
849 }
850
851 static struct mock_dev *mock_dev_create(unsigned long dev_flags)
852 {
853         struct mock_dev *mdev;
854         int rc, i;
855
856         if (dev_flags &
857             ~(MOCK_FLAGS_DEVICE_NO_DIRTY | MOCK_FLAGS_DEVICE_HUGE_IOVA))
858                 return ERR_PTR(-EINVAL);
859
860         mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
861         if (!mdev)
862                 return ERR_PTR(-ENOMEM);
863
864         device_initialize(&mdev->dev);
865         mdev->flags = dev_flags;
866         mdev->dev.release = mock_dev_release;
867         mdev->dev.bus = &iommufd_mock_bus_type.bus;
868         for (i = 0; i < MOCK_DEV_CACHE_NUM; i++)
869                 mdev->cache[i] = IOMMU_TEST_DEV_CACHE_DEFAULT;
870
871         rc = ida_alloc(&mock_dev_ida, GFP_KERNEL);
872         if (rc < 0)
873                 goto err_put;
874         mdev->id = rc;
875
876         rc = dev_set_name(&mdev->dev, "iommufd_mock%u", mdev->id);
877         if (rc)
878                 goto err_put;
879
880         rc = device_add(&mdev->dev);
881         if (rc)
882                 goto err_put;
883         return mdev;
884
885 err_put:
886         put_device(&mdev->dev);
887         return ERR_PTR(rc);
888 }
889
/* Tear down a device made by mock_dev_create() by unregistering it */
890 static void mock_dev_destroy(struct mock_dev *mdev)
891 {
892         device_unregister(&mdev->dev);
893 }
894
/*
 * Report whether @dev is a selftest mock device. Mock devices are recognised
 * by their release callback, which mock_dev_create() sets to
 * mock_dev_release().
 */
895 bool iommufd_selftest_is_mock_dev(struct device *dev)
896 {
897         return dev->release == mock_dev_release;
898 }
899
900 /* Create an hw_pagetable with the mock domain so we can test the domain ops */
901 static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
902                                     struct iommu_test_cmd *cmd)
903 {
904         struct iommufd_device *idev;
905         struct selftest_obj *sobj;
906         u32 pt_id = cmd->id;
907         u32 dev_flags = 0;
908         u32 idev_id;
909         int rc;
910
911         sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
912         if (IS_ERR(sobj))
913                 return PTR_ERR(sobj);
914
915         sobj->idev.ictx = ucmd->ictx;
916         sobj->type = TYPE_IDEV;
917
918         if (cmd->op == IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS)
919                 dev_flags = cmd->mock_domain_flags.dev_flags;
920
921         sobj->idev.mock_dev = mock_dev_create(dev_flags);
922         if (IS_ERR(sobj->idev.mock_dev)) {
923                 rc = PTR_ERR(sobj->idev.mock_dev);
924                 goto out_sobj;
925         }
926
927         idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
928                                    &idev_id);
929         if (IS_ERR(idev)) {
930                 rc = PTR_ERR(idev);
931                 goto out_mdev;
932         }
933         sobj->idev.idev = idev;
934
935         rc = iommufd_device_attach(idev, &pt_id);
936         if (rc)
937                 goto out_unbind;
938
939         /* Userspace must destroy the device_id to destroy the object */
940         cmd->mock_domain.out_hwpt_id = pt_id;
941         cmd->mock_domain.out_stdev_id = sobj->obj.id;
942         cmd->mock_domain.out_idev_id = idev_id;
943         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
944         if (rc)
945                 goto out_detach;
946         iommufd_object_finalize(ucmd->ictx, &sobj->obj);
947         return 0;
948
949 out_detach:
950         iommufd_device_detach(idev);
951 out_unbind:
952         iommufd_device_unbind(idev);
953 out_mdev:
954         mock_dev_destroy(sobj->idev.mock_dev);
955 out_sobj:
956         iommufd_object_abort(ucmd->ictx, &sobj->obj);
957         return rc;
958 }
959
960 /* Replace the mock domain with a manually allocated hw_pagetable */
961 static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
962                                             unsigned int device_id, u32 pt_id,
963                                             struct iommu_test_cmd *cmd)
964 {
965         struct iommufd_object *dev_obj;
966         struct selftest_obj *sobj;
967         int rc;
968
969         /*
970          * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
971          * it doesn't race with detach, which is not allowed.
972          */
973         dev_obj =
974                 iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
975         if (IS_ERR(dev_obj))
976                 return PTR_ERR(dev_obj);
977
978         sobj = to_selftest_obj(dev_obj);
979         if (sobj->type != TYPE_IDEV) {
980                 rc = -EINVAL;
981                 goto out_dev_obj;
982         }
983
984         rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
985         if (rc)
986                 goto out_dev_obj;
987
988         cmd->mock_domain_replace.pt_id = pt_id;
989         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
990
991 out_dev_obj:
992         iommufd_put_object(ucmd->ictx, dev_obj);
993         return rc;
994 }
995
996 /* Add an additional reserved IOVA to the IOAS */
997 static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
998                                      unsigned int mockpt_id,
999                                      unsigned long start, size_t length)
1000 {
1001         struct iommufd_ioas *ioas;
1002         int rc;
1003
1004         ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
1005         if (IS_ERR(ioas))
1006                 return PTR_ERR(ioas);
1007         down_write(&ioas->iopt.iova_rwsem);
1008         rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
1009         up_write(&ioas->iopt.iova_rwsem);
1010         iommufd_put_object(ucmd->ictx, &ioas->obj);
1011         return rc;
1012 }
1013
1014 /* Check that every pfn under each iova matches the pfn under a user VA */
1015 static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
1016                                     unsigned int mockpt_id, unsigned long iova,
1017                                     size_t length, void __user *uptr)
1018 {
1019         struct iommufd_hw_pagetable *hwpt;
1020         struct mock_iommu_domain *mock;
1021         uintptr_t end;
1022         int rc;
1023
1024         if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
1025             (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
1026             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
1027                 return -EINVAL;
1028
1029         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
1030         if (IS_ERR(hwpt))
1031                 return PTR_ERR(hwpt);
1032
1033         for (; length; length -= MOCK_IO_PAGE_SIZE) {
1034                 struct page *pages[1];
1035                 unsigned long pfn;
1036                 long npages;
1037                 void *ent;
1038
1039                 npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
1040                                              pages);
1041                 if (npages < 0) {
1042                         rc = npages;
1043                         goto out_put;
1044                 }
1045                 if (WARN_ON(npages != 1)) {
1046                         rc = -EFAULT;
1047                         goto out_put;
1048                 }
1049                 pfn = page_to_pfn(pages[0]);
1050                 put_page(pages[0]);
1051
1052                 ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
1053                 if (!ent ||
1054                     (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
1055                             pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
1056                         rc = -EINVAL;
1057                         goto out_put;
1058                 }
1059                 iova += MOCK_IO_PAGE_SIZE;
1060                 uptr += MOCK_IO_PAGE_SIZE;
1061         }
1062         rc = 0;
1063
1064 out_put:
1065         iommufd_put_object(ucmd->ictx, &hwpt->obj);
1066         return rc;
1067 }
1068
1069 /* Check that the page ref count matches, to look for missing pin/unpins */
1070 static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
1071                                       void __user *uptr, size_t length,
1072                                       unsigned int refs)
1073 {
1074         uintptr_t end;
1075
1076         if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
1077             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
1078                 return -EINVAL;
1079
1080         for (; length; length -= PAGE_SIZE) {
1081                 struct page *pages[1];
1082                 long npages;
1083
1084                 npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
1085                 if (npages < 0)
1086                         return npages;
1087                 if (WARN_ON(npages != 1))
1088                         return -EFAULT;
1089                 if (!PageCompound(pages[0])) {
1090                         unsigned int count;
1091
1092                         count = page_ref_count(pages[0]);
1093                         if (count / GUP_PIN_COUNTING_BIAS != refs) {
1094                                 put_page(pages[0]);
1095                                 return -EIO;
1096                         }
1097                 }
1098                 put_page(pages[0]);
1099                 uptr += PAGE_SIZE;
1100         }
1101         return 0;
1102 }
1103
1104 static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd,
1105                                        u32 mockpt_id, unsigned int iotlb_id,
1106                                        u32 iotlb)
1107 {
1108         struct mock_iommu_domain_nested *mock_nested;
1109         struct iommufd_hw_pagetable *hwpt;
1110         int rc = 0;
1111
1112         hwpt = get_md_pagetable_nested(ucmd, mockpt_id, &mock_nested);
1113         if (IS_ERR(hwpt))
1114                 return PTR_ERR(hwpt);
1115
1116         mock_nested = to_mock_nested(hwpt->domain);
1117
1118         if (iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX ||
1119             mock_nested->iotlb[iotlb_id] != iotlb)
1120                 rc = -EINVAL;
1121         iommufd_put_object(ucmd->ictx, &hwpt->obj);
1122         return rc;
1123 }
1124
1125 static int iommufd_test_dev_check_cache(struct iommufd_ucmd *ucmd, u32 idev_id,
1126                                         unsigned int cache_id, u32 cache)
1127 {
1128         struct iommufd_device *idev;
1129         struct mock_dev *mdev;
1130         int rc = 0;
1131
1132         idev = iommufd_get_device(ucmd, idev_id);
1133         if (IS_ERR(idev))
1134                 return PTR_ERR(idev);
1135         mdev = container_of(idev->dev, struct mock_dev, dev);
1136
1137         if (cache_id > MOCK_DEV_CACHE_ID_MAX || mdev->cache[cache_id] != cache)
1138                 rc = -EINVAL;
1139         iommufd_put_object(ucmd->ictx, &idev->obj);
1140         return rc;
1141 }
1142
/*
 * Per-file state behind an access fd handed out by
 * iommufd_test_create_access(). It is stored as the file's private_data and
 * freed by iommufd_test_staccess_release().
 */
1143 struct selftest_access {
             /* Set once creation has fully succeeded; NULL until then */
1144         struct iommufd_access *access;
             /* Anonymous inode file that owns this structure */
1145         struct file *file;
             /* Held while walking or changing items, and across pinning */
1146         struct mutex lock;
             /* List of struct selftest_access_item, linked by items_elm */
1147         struct list_head items;
             /* Source of the id given to the next selftest_access_item */
1148         unsigned int next_id;
             /* NOTE(review): not referenced in this part of the file */
1149         bool destroying;
1150 };
1151
/*
 * One range pinned through iommufd_test_access_pages(). It stays on the
 * owning selftest_access's items list until it is destroyed by id or an unmap
 * covers it.
 */
1152 struct selftest_access_item {
             /* Link in selftest_access.items */
1153         struct list_head items_elm;
             /* Start and size of the pinned range, as passed to pin/unpin */
1154         unsigned long iova;
1155         size_t length;
             /* Handle reported to userspace as out_access_pages_id */
1156         unsigned int id;
1157 };
1158
1159 static const struct file_operations iommfd_test_staccess_fops;
1160
1161 static struct selftest_access *iommufd_access_get(int fd)
1162 {
1163         struct file *file;
1164
1165         file = fget(fd);
1166         if (!file)
1167                 return ERR_PTR(-EBADFD);
1168
1169         if (file->f_op != &iommfd_test_staccess_fops) {
1170                 fput(file);
1171                 return ERR_PTR(-EBADFD);
1172         }
1173         return file->private_data;
1174 }
1175
1176 static void iommufd_test_access_unmap(void *data, unsigned long iova,
1177                                       unsigned long length)
1178 {
1179         unsigned long iova_last = iova + length - 1;
1180         struct selftest_access *staccess = data;
1181         struct selftest_access_item *item;
1182         struct selftest_access_item *tmp;
1183
1184         mutex_lock(&staccess->lock);
1185         list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
1186                 if (iova > item->iova + item->length - 1 ||
1187                     iova_last < item->iova)
1188                         continue;
1189                 list_del(&item->items_elm);
1190                 iommufd_access_unpin_pages(staccess->access, item->iova,
1191                                            item->length);
1192                 kfree(item);
1193         }
1194         mutex_unlock(&staccess->lock);
1195 }
1196
1197 static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
1198                                             unsigned int access_id,
1199                                             unsigned int item_id)
1200 {
1201         struct selftest_access_item *item;
1202         struct selftest_access *staccess;
1203
1204         staccess = iommufd_access_get(access_id);
1205         if (IS_ERR(staccess))
1206                 return PTR_ERR(staccess);
1207
1208         mutex_lock(&staccess->lock);
1209         list_for_each_entry(item, &staccess->items, items_elm) {
1210                 if (item->id == item_id) {
1211                         list_del(&item->items_elm);
1212                         iommufd_access_unpin_pages(staccess->access, item->iova,
1213                                                    item->length);
1214                         mutex_unlock(&staccess->lock);
1215                         kfree(item);
1216                         fput(staccess->file);
1217                         return 0;
1218                 }
1219         }
1220         mutex_unlock(&staccess->lock);
1221         fput(staccess->file);
1222         return -ENOENT;
1223 }
1224
1225 static int iommufd_test_staccess_release(struct inode *inode,
1226                                          struct file *filep)
1227 {
1228         struct selftest_access *staccess = filep->private_data;
1229
1230         if (staccess->access) {
1231                 iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
1232                 iommufd_access_destroy(staccess->access);
1233         }
1234         mutex_destroy(&staccess->lock);
1235         kfree(staccess);
1236         return 0;
1237 }
1238
/*
 * Ops used when an access is created with
 * MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES. iommufd_test_access_pages() only
 * accepts accesses that use this table.
 */
1239 static const struct iommufd_access_ops selftest_access_ops_pin = {
1240         .needs_pin_pages = 1,
1241         .unmap = iommufd_test_access_unmap,
1242 };
1243
/* Ops used when an access is created without the NEEDS_PIN_PAGES flag */
1244 static const struct iommufd_access_ops selftest_access_ops = {
1245         .unmap = iommufd_test_access_unmap,
1246 };
1247
/*
 * File operations of the access fd. iommufd_access_get() also compares a
 * file's f_op against this table to recognise selftest access files.
 */
1248 static const struct file_operations iommfd_test_staccess_fops = {
1249         .release = iommufd_test_staccess_release,
1250 };
1251
1252 static struct selftest_access *iommufd_test_alloc_access(void)
1253 {
1254         struct selftest_access *staccess;
1255         struct file *filep;
1256
1257         staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
1258         if (!staccess)
1259                 return ERR_PTR(-ENOMEM);
1260         INIT_LIST_HEAD(&staccess->items);
1261         mutex_init(&staccess->lock);
1262
1263         filep = anon_inode_getfile("[iommufd_test_staccess]",
1264                                    &iommfd_test_staccess_fops, staccess,
1265                                    O_RDWR);
1266         if (IS_ERR(filep)) {
1267                 kfree(staccess);
1268                 return ERR_CAST(filep);
1269         }
1270         staccess->file = filep;
1271         return staccess;
1272 }
1273
1274 static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
1275                                       unsigned int ioas_id, unsigned int flags)
1276 {
1277         struct iommu_test_cmd *cmd = ucmd->cmd;
1278         struct selftest_access *staccess;
1279         struct iommufd_access *access;
1280         u32 id;
1281         int fdno;
1282         int rc;
1283
1284         if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
1285                 return -EOPNOTSUPP;
1286
1287         staccess = iommufd_test_alloc_access();
1288         if (IS_ERR(staccess))
1289                 return PTR_ERR(staccess);
1290
1291         fdno = get_unused_fd_flags(O_CLOEXEC);
1292         if (fdno < 0) {
1293                 rc = -ENOMEM;
1294                 goto out_free_staccess;
1295         }
1296
1297         access = iommufd_access_create(
1298                 ucmd->ictx,
1299                 (flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
1300                         &selftest_access_ops_pin :
1301                         &selftest_access_ops,
1302                 staccess, &id);
1303         if (IS_ERR(access)) {
1304                 rc = PTR_ERR(access);
1305                 goto out_put_fdno;
1306         }
1307         rc = iommufd_access_attach(access, ioas_id);
1308         if (rc)
1309                 goto out_destroy;
1310         cmd->create_access.out_access_fd = fdno;
1311         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1312         if (rc)
1313                 goto out_destroy;
1314
1315         staccess->access = access;
1316         fd_install(fdno, staccess->file);
1317         return 0;
1318
1319 out_destroy:
1320         iommufd_access_destroy(access);
1321 out_put_fdno:
1322         put_unused_fd(fdno);
1323 out_free_staccess:
1324         fput(staccess->file);
1325         return rc;
1326 }
1327
1328 static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
1329                                             unsigned int access_id,
1330                                             unsigned int ioas_id)
1331 {
1332         struct selftest_access *staccess;
1333         int rc;
1334
1335         staccess = iommufd_access_get(access_id);
1336         if (IS_ERR(staccess))
1337                 return PTR_ERR(staccess);
1338
1339         rc = iommufd_access_replace(staccess->access, ioas_id);
1340         fput(staccess->file);
1341         return rc;
1342 }
1343
1344 /* Check that the pages in a page array match the pages in the user VA */
1345 static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
1346                                     size_t npages)
1347 {
1348         for (; npages; npages--) {
1349                 struct page *tmp_pages[1];
1350                 long rc;
1351
1352                 rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
1353                 if (rc < 0)
1354                         return rc;
1355                 if (WARN_ON(rc != 1))
1356                         return -EFAULT;
1357                 put_page(tmp_pages[0]);
1358                 if (tmp_pages[0] != *pages)
1359                         return -EBADE;
1360                 pages++;
1361                 uptr += PAGE_SIZE;
1362         }
1363         return 0;
1364 }
1365
1366 static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
1367                                      unsigned int access_id, unsigned long iova,
1368                                      size_t length, void __user *uptr,
1369                                      u32 flags)
1370 {
1371         struct iommu_test_cmd *cmd = ucmd->cmd;
1372         struct selftest_access_item *item;
1373         struct selftest_access *staccess;
1374         struct page **pages;
1375         size_t npages;
1376         int rc;
1377
1378         /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
1379         if (length > 16*1024*1024)
1380                 return -ENOMEM;
1381
1382         if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
1383                 return -EOPNOTSUPP;
1384
1385         staccess = iommufd_access_get(access_id);
1386         if (IS_ERR(staccess))
1387                 return PTR_ERR(staccess);
1388
1389         if (staccess->access->ops != &selftest_access_ops_pin) {
1390                 rc = -EOPNOTSUPP;
1391                 goto out_put;
1392         }
1393
1394         if (flags & MOCK_FLAGS_ACCESS_SYZ)
1395                 iova = iommufd_test_syz_conv_iova(staccess->access,
1396                                         &cmd->access_pages.iova);
1397
1398         npages = (ALIGN(iova + length, PAGE_SIZE) -
1399                   ALIGN_DOWN(iova, PAGE_SIZE)) /
1400                  PAGE_SIZE;
1401         pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
1402         if (!pages) {
1403                 rc = -ENOMEM;
1404                 goto out_put;
1405         }
1406
1407         /*
1408          * Drivers will need to think very carefully about this locking. The
1409          * core code can do multiple unmaps instantaneously after
1410          * iommufd_access_pin_pages() and *all* the unmaps must not return until
1411          * the range is unpinned. This simple implementation puts a global lock
1412          * around the pin, which may not suit drivers that want this to be a
1413          * performance path. drivers that get this wrong will trigger WARN_ON
1414          * races and cause EDEADLOCK failures to userspace.
1415          */
1416         mutex_lock(&staccess->lock);
1417         rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
1418                                       flags & MOCK_FLAGS_ACCESS_WRITE);
1419         if (rc)
1420                 goto out_unlock;
1421
1422         /* For syzkaller allow uptr to be NULL to skip this check */
1423         if (uptr) {
1424                 rc = iommufd_test_check_pages(
1425                         uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
1426                         npages);
1427                 if (rc)
1428                         goto out_unaccess;
1429         }
1430
1431         item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
1432         if (!item) {
1433                 rc = -ENOMEM;
1434                 goto out_unaccess;
1435         }
1436
1437         item->iova = iova;
1438         item->length = length;
1439         item->id = staccess->next_id++;
1440         list_add_tail(&item->items_elm, &staccess->items);
1441
1442         cmd->access_pages.out_access_pages_id = item->id;
1443         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1444         if (rc)
1445                 goto out_free_item;
1446         goto out_unlock;
1447
1448 out_free_item:
1449         list_del(&item->items_elm);
1450         kfree(item);
1451 out_unaccess:
1452         iommufd_access_unpin_pages(staccess->access, iova, length);
1453 out_unlock:
1454         mutex_unlock(&staccess->lock);
1455         kvfree(pages);
1456 out_put:
1457         fput(staccess->file);
1458         return rc;
1459 }
1460
1461 static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
1462                                   unsigned int access_id, unsigned long iova,
1463                                   size_t length, void __user *ubuf,
1464                                   unsigned int flags)
1465 {
1466         struct iommu_test_cmd *cmd = ucmd->cmd;
1467         struct selftest_access *staccess;
1468         void *tmp;
1469         int rc;
1470
1471         /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
1472         if (length > 16*1024*1024)
1473                 return -ENOMEM;
1474
1475         if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
1476                       MOCK_FLAGS_ACCESS_SYZ))
1477                 return -EOPNOTSUPP;
1478
1479         staccess = iommufd_access_get(access_id);
1480         if (IS_ERR(staccess))
1481                 return PTR_ERR(staccess);
1482
1483         tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
1484         if (!tmp) {
1485                 rc = -ENOMEM;
1486                 goto out_put;
1487         }
1488
1489         if (flags & MOCK_ACCESS_RW_WRITE) {
1490                 if (copy_from_user(tmp, ubuf, length)) {
1491                         rc = -EFAULT;
1492                         goto out_free;
1493                 }
1494         }
1495
1496         if (flags & MOCK_FLAGS_ACCESS_SYZ)
1497                 iova = iommufd_test_syz_conv_iova(staccess->access,
1498                                 &cmd->access_rw.iova);
1499
1500         rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
1501         if (rc)
1502                 goto out_free;
1503         if (!(flags & MOCK_ACCESS_RW_WRITE)) {
1504                 if (copy_to_user(ubuf, tmp, length)) {
1505                         rc = -EFAULT;
1506                         goto out_free;
1507                 }
1508         }
1509
1510 out_free:
1511         kvfree(tmp);
1512 out_put:
1513         fput(staccess->file);
1514         return rc;
1515 }
1516 static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
1517 static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
1518               __IOMMUFD_ACCESS_RW_SLOW_PATH);
1519
1520 static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
1521                               unsigned long iova, size_t length,
1522                               unsigned long page_size, void __user *uptr,
1523                               u32 flags)
1524 {
1525         unsigned long i, max;
1526         struct iommu_test_cmd *cmd = ucmd->cmd;
1527         struct iommufd_hw_pagetable *hwpt;
1528         struct mock_iommu_domain *mock;
1529         int rc, count = 0;
1530         void *tmp;
1531
1532         if (!page_size || !length || iova % page_size || length % page_size ||
1533             !uptr)
1534                 return -EINVAL;
1535
1536         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
1537         if (IS_ERR(hwpt))
1538                 return PTR_ERR(hwpt);
1539
1540         if (!(mock->flags & MOCK_DIRTY_TRACK)) {
1541                 rc = -EINVAL;
1542                 goto out_put;
1543         }
1544
1545         max = length / page_size;
1546         tmp = kvzalloc(DIV_ROUND_UP(max, BITS_PER_LONG) * sizeof(unsigned long),
1547                        GFP_KERNEL_ACCOUNT);
1548         if (!tmp) {
1549                 rc = -ENOMEM;
1550                 goto out_put;
1551         }
1552
1553         if (copy_from_user(tmp, uptr,DIV_ROUND_UP(max, BITS_PER_BYTE))) {
1554                 rc = -EFAULT;
1555                 goto out_free;
1556         }
1557
1558         for (i = 0; i < max; i++) {
1559                 unsigned long cur = iova + i * page_size;
1560                 void *ent, *old;
1561
1562                 if (!test_bit(i, (unsigned long *)tmp))
1563                         continue;
1564
1565                 ent = xa_load(&mock->pfns, cur / page_size);
1566                 if (ent) {
1567                         unsigned long val;
1568
1569                         val = xa_to_value(ent) | MOCK_PFN_DIRTY_IOVA;
1570                         old = xa_store(&mock->pfns, cur / page_size,
1571                                        xa_mk_value(val), GFP_KERNEL);
1572                         WARN_ON_ONCE(ent != old);
1573                         count++;
1574                 }
1575         }
1576
1577         cmd->dirty.out_nr_dirty = count;
1578         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
1579 out_free:
1580         kvfree(tmp);
1581 out_put:
1582         iommufd_put_object(ucmd->ictx, &hwpt->obj);
1583         return rc;
1584 }
1585
1586 static int iommufd_test_trigger_iopf(struct iommufd_ucmd *ucmd,
1587                                      struct iommu_test_cmd *cmd)
1588 {
1589         struct iopf_fault event = { };
1590         struct iommufd_device *idev;
1591
1592         idev = iommufd_get_device(ucmd, cmd->trigger_iopf.dev_id);
1593         if (IS_ERR(idev))
1594                 return PTR_ERR(idev);
1595
1596         event.fault.prm.flags = IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE;
1597         if (cmd->trigger_iopf.pasid != IOMMU_NO_PASID)
1598                 event.fault.prm.flags |= IOMMU_FAULT_PAGE_REQUEST_PASID_VALID;
1599         event.fault.type = IOMMU_FAULT_PAGE_REQ;
1600         event.fault.prm.addr = cmd->trigger_iopf.addr;
1601         event.fault.prm.pasid = cmd->trigger_iopf.pasid;
1602         event.fault.prm.grpid = cmd->trigger_iopf.grpid;
1603         event.fault.prm.perm = cmd->trigger_iopf.perm;
1604
1605         iommu_report_device_fault(idev->dev, &event);
1606         iommufd_put_object(ucmd->ictx, &idev->obj);
1607
1608         return 0;
1609 }
1610
1611 void iommufd_selftest_destroy(struct iommufd_object *obj)
1612 {
1613         struct selftest_obj *sobj = to_selftest_obj(obj);
1614
1615         switch (sobj->type) {
1616         case TYPE_IDEV:
1617                 iommufd_device_detach(sobj->idev.idev);
1618                 iommufd_device_unbind(sobj->idev.idev);
1619                 mock_dev_destroy(sobj->idev.mock_dev);
1620                 break;
1621         }
1622 }
1623
1624 int iommufd_test(struct iommufd_ucmd *ucmd)
1625 {
1626         struct iommu_test_cmd *cmd = ucmd->cmd;
1627
1628         switch (cmd->op) {
1629         case IOMMU_TEST_OP_ADD_RESERVED:
1630                 return iommufd_test_add_reserved(ucmd, cmd->id,
1631                                                  cmd->add_reserved.start,
1632                                                  cmd->add_reserved.length);
1633         case IOMMU_TEST_OP_MOCK_DOMAIN:
1634         case IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS:
1635                 return iommufd_test_mock_domain(ucmd, cmd);
1636         case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
1637                 return iommufd_test_mock_domain_replace(
1638                         ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
1639         case IOMMU_TEST_OP_MD_CHECK_MAP:
1640                 return iommufd_test_md_check_pa(
1641                         ucmd, cmd->id, cmd->check_map.iova,
1642                         cmd->check_map.length,
1643                         u64_to_user_ptr(cmd->check_map.uptr));
1644         case IOMMU_TEST_OP_MD_CHECK_REFS:
1645                 return iommufd_test_md_check_refs(
1646                         ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
1647                         cmd->check_refs.length, cmd->check_refs.refs);
1648         case IOMMU_TEST_OP_MD_CHECK_IOTLB:
1649                 return iommufd_test_md_check_iotlb(ucmd, cmd->id,
1650                                                    cmd->check_iotlb.id,
1651                                                    cmd->check_iotlb.iotlb);
1652         case IOMMU_TEST_OP_DEV_CHECK_CACHE:
1653                 return iommufd_test_dev_check_cache(ucmd, cmd->id,
1654                                                     cmd->check_dev_cache.id,
1655                                                     cmd->check_dev_cache.cache);
1656         case IOMMU_TEST_OP_CREATE_ACCESS:
1657                 return iommufd_test_create_access(ucmd, cmd->id,
1658                                                   cmd->create_access.flags);
1659         case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
1660                 return iommufd_test_access_replace_ioas(
1661                         ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
1662         case IOMMU_TEST_OP_ACCESS_PAGES:
1663                 return iommufd_test_access_pages(
1664                         ucmd, cmd->id, cmd->access_pages.iova,
1665                         cmd->access_pages.length,
1666                         u64_to_user_ptr(cmd->access_pages.uptr),
1667                         cmd->access_pages.flags);
1668         case IOMMU_TEST_OP_ACCESS_RW:
1669                 return iommufd_test_access_rw(
1670                         ucmd, cmd->id, cmd->access_rw.iova,
1671                         cmd->access_rw.length,
1672                         u64_to_user_ptr(cmd->access_rw.uptr),
1673                         cmd->access_rw.flags);
1674         case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
1675                 return iommufd_test_access_item_destroy(
1676                         ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
1677         case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
1678                 /* Protect _batch_init(), can not be less than elmsz */
1679                 if (cmd->memory_limit.limit <
1680                     sizeof(unsigned long) + sizeof(u32))
1681                         return -EINVAL;
1682                 iommufd_test_memory_limit = cmd->memory_limit.limit;
1683                 return 0;
1684         case IOMMU_TEST_OP_DIRTY:
1685                 return iommufd_test_dirty(ucmd, cmd->id, cmd->dirty.iova,
1686                                           cmd->dirty.length,
1687                                           cmd->dirty.page_size,
1688                                           u64_to_user_ptr(cmd->dirty.uptr),
1689                                           cmd->dirty.flags);
1690         case IOMMU_TEST_OP_TRIGGER_IOPF:
1691                 return iommufd_test_trigger_iopf(ucmd, cmd);
1692         default:
1693                 return -EOPNOTSUPP;
1694         }
1695 }
1696
1697 bool iommufd_should_fail(void)
1698 {
1699         return should_fail(&fail_iommufd, 1);
1700 }
1701
1702 int __init iommufd_test_init(void)
1703 {
1704         struct platform_device_info pdevinfo = {
1705                 .name = "iommufd_selftest_iommu",
1706         };
1707         int rc;
1708
1709         dbgfs_root =
1710                 fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);
1711
1712         selftest_iommu_dev = platform_device_register_full(&pdevinfo);
1713         if (IS_ERR(selftest_iommu_dev)) {
1714                 rc = PTR_ERR(selftest_iommu_dev);
1715                 goto err_dbgfs;
1716         }
1717
1718         rc = bus_register(&iommufd_mock_bus_type.bus);
1719         if (rc)
1720                 goto err_platform;
1721
1722         rc = iommu_device_sysfs_add(&mock_iommu.iommu_dev,
1723                                     &selftest_iommu_dev->dev, NULL, "%s",
1724                                     dev_name(&selftest_iommu_dev->dev));
1725         if (rc)
1726                 goto err_bus;
1727
1728         rc = iommu_device_register_bus(&mock_iommu.iommu_dev, &mock_ops,
1729                                   &iommufd_mock_bus_type.bus,
1730                                   &iommufd_mock_bus_type.nb);
1731         if (rc)
1732                 goto err_sysfs;
1733
1734         refcount_set(&mock_iommu.users, 1);
1735         init_completion(&mock_iommu.complete);
1736
1737         mock_iommu_iopf_queue = iopf_queue_alloc("mock-iopfq");
1738
1739         return 0;
1740
1741 err_sysfs:
1742         iommu_device_sysfs_remove(&mock_iommu.iommu_dev);
1743 err_bus:
1744         bus_unregister(&iommufd_mock_bus_type.bus);
1745 err_platform:
1746         platform_device_unregister(selftest_iommu_dev);
1747 err_dbgfs:
1748         debugfs_remove_recursive(dbgfs_root);
1749         return rc;
1750 }
1751
1752 static void iommufd_test_wait_for_users(void)
1753 {
1754         if (refcount_dec_and_test(&mock_iommu.users))
1755                 return;
1756         /*
1757          * Time out waiting for iommu device user count to become 0.
1758          *
1759          * Note that this is just making an example here, since the selftest is
1760          * built into the iommufd module, i.e. it only unplugs the iommu device
1761          * when unloading the module. So, it is expected that this WARN_ON will
1762          * not trigger, as long as any iommufd FDs are open.
1763          */
1764         WARN_ON(!wait_for_completion_timeout(&mock_iommu.complete,
1765                                              msecs_to_jiffies(10000)));
1766 }
1767
1768 void iommufd_test_exit(void)
1769 {
1770         if (mock_iommu_iopf_queue) {
1771                 iopf_queue_free(mock_iommu_iopf_queue);
1772                 mock_iommu_iopf_queue = NULL;
1773         }
1774
1775         iommufd_test_wait_for_users();
1776         iommu_device_sysfs_remove(&mock_iommu.iommu_dev);
1777         iommu_device_unregister_bus(&mock_iommu.iommu_dev,
1778                                     &iommufd_mock_bus_type.bus,
1779                                     &iommufd_mock_bus_type.nb);
1780         bus_unregister(&iommufd_mock_bus_type.bus);
1781         platform_device_unregister(selftest_iommu_dev);
1782         debugfs_remove_recursive(dbgfs_root);
1783 }
This page took 0.126685 seconds and 4 git commands to generate.