]> Git Repo - qemu.git/blob - contrib/libvhost-user/libvhost-user.c
e53b1953df1ca59bb9dd066ba12485a2f5cda9b9
[qemu.git] / contrib / libvhost-user / libvhost-user.c
1 /*
2  * Vhost User library
3  *
4  * Copyright IBM, Corp. 2007
5  * Copyright (c) 2016 Red Hat, Inc.
6  *
7  * Authors:
8  *  Anthony Liguori <[email protected]>
9  *  Marc-AndrĂ© Lureau <[email protected]>
10  *  Victor Kaplansky <[email protected]>
11  *
12  * This work is licensed under the terms of the GNU GPL, version 2 or
13  * later.  See the COPYING file in the top-level directory.
14  */
15
16 /* this code avoids GLib dependency */
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <stdarg.h>
21 #include <errno.h>
22 #include <string.h>
23 #include <assert.h>
24 #include <inttypes.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/eventfd.h>
28 #include <sys/mman.h>
29 #include "qemu/compiler.h"
30
31 #if defined(__linux__)
32 #include <sys/syscall.h>
33 #include <fcntl.h>
34 #include <sys/ioctl.h>
35 #include <linux/vhost.h>
36
37 #ifdef __NR_userfaultfd
38 #include <linux/userfaultfd.h>
39 #endif
40
41 #endif
42
43 #include "qemu/atomic.h"
44
45 #include "libvhost-user.h"
46
47 /* usually provided by GLib */
48 #ifndef MIN
49 #define MIN(x, y) ({                            \
50             typeof(x) _min1 = (x);              \
51             typeof(y) _min2 = (y);              \
52             (void) (&_min1 == &_min2);          \
53             _min1 < _min2 ? _min1 : _min2; })
54 #endif
55
56 #define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
57
58 /* The version of the protocol we support */
59 #define VHOST_USER_VERSION 1
60 #define LIBVHOST_USER_DEBUG 0
61
62 #define DPRINT(...)                             \
63     do {                                        \
64         if (LIBVHOST_USER_DEBUG) {              \
65             fprintf(stderr, __VA_ARGS__);        \
66         }                                       \
67     } while (0)
68
69 static const char *
70 vu_request_to_string(unsigned int req)
71 {
72 #define REQ(req) [req] = #req
73     static const char *vu_request_str[] = {
74         REQ(VHOST_USER_NONE),
75         REQ(VHOST_USER_GET_FEATURES),
76         REQ(VHOST_USER_SET_FEATURES),
77         REQ(VHOST_USER_SET_OWNER),
78         REQ(VHOST_USER_RESET_OWNER),
79         REQ(VHOST_USER_SET_MEM_TABLE),
80         REQ(VHOST_USER_SET_LOG_BASE),
81         REQ(VHOST_USER_SET_LOG_FD),
82         REQ(VHOST_USER_SET_VRING_NUM),
83         REQ(VHOST_USER_SET_VRING_ADDR),
84         REQ(VHOST_USER_SET_VRING_BASE),
85         REQ(VHOST_USER_GET_VRING_BASE),
86         REQ(VHOST_USER_SET_VRING_KICK),
87         REQ(VHOST_USER_SET_VRING_CALL),
88         REQ(VHOST_USER_SET_VRING_ERR),
89         REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
90         REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
91         REQ(VHOST_USER_GET_QUEUE_NUM),
92         REQ(VHOST_USER_SET_VRING_ENABLE),
93         REQ(VHOST_USER_SEND_RARP),
94         REQ(VHOST_USER_NET_SET_MTU),
95         REQ(VHOST_USER_SET_SLAVE_REQ_FD),
96         REQ(VHOST_USER_IOTLB_MSG),
97         REQ(VHOST_USER_SET_VRING_ENDIAN),
98         REQ(VHOST_USER_GET_CONFIG),
99         REQ(VHOST_USER_SET_CONFIG),
100         REQ(VHOST_USER_POSTCOPY_ADVISE),
101         REQ(VHOST_USER_POSTCOPY_LISTEN),
102         REQ(VHOST_USER_MAX),
103     };
104 #undef REQ
105
106     if (req < VHOST_USER_MAX) {
107         return vu_request_str[req];
108     } else {
109         return "unknown";
110     }
111 }
112
113 static void
114 vu_panic(VuDev *dev, const char *msg, ...)
115 {
116     char *buf = NULL;
117     va_list ap;
118
119     va_start(ap, msg);
120     if (vasprintf(&buf, msg, ap) < 0) {
121         buf = NULL;
122     }
123     va_end(ap);
124
125     dev->broken = true;
126     dev->panic(dev, buf);
127     free(buf);
128
129     /* FIXME: find a way to call virtio_error? */
130 }
131
132 /* Translate guest physical address to our virtual address.  */
133 void *
134 vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
135 {
136     int i;
137
138     if (*plen == 0) {
139         return NULL;
140     }
141
142     /* Find matching memory region.  */
143     for (i = 0; i < dev->nregions; i++) {
144         VuDevRegion *r = &dev->regions[i];
145
146         if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
147             if ((guest_addr + *plen) > (r->gpa + r->size)) {
148                 *plen = r->gpa + r->size - guest_addr;
149             }
150             return (void *)(uintptr_t)
151                 guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
152         }
153     }
154
155     return NULL;
156 }
157
158 /* Translate qemu virtual address to our virtual address.  */
159 static void *
160 qva_to_va(VuDev *dev, uint64_t qemu_addr)
161 {
162     int i;
163
164     /* Find matching memory region.  */
165     for (i = 0; i < dev->nregions; i++) {
166         VuDevRegion *r = &dev->regions[i];
167
168         if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
169             return (void *)(uintptr_t)
170                 qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
171         }
172     }
173
174     return NULL;
175 }
176
177 static void
178 vmsg_close_fds(VhostUserMsg *vmsg)
179 {
180     int i;
181
182     for (i = 0; i < vmsg->fd_num; i++) {
183         close(vmsg->fds[i]);
184     }
185 }
186
187 static bool
188 vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
189 {
190     char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
191     struct iovec iov = {
192         .iov_base = (char *)vmsg,
193         .iov_len = VHOST_USER_HDR_SIZE,
194     };
195     struct msghdr msg = {
196         .msg_iov = &iov,
197         .msg_iovlen = 1,
198         .msg_control = control,
199         .msg_controllen = sizeof(control),
200     };
201     size_t fd_size;
202     struct cmsghdr *cmsg;
203     int rc;
204
205     do {
206         rc = recvmsg(conn_fd, &msg, 0);
207     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
208
209     if (rc < 0) {
210         vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
211         return false;
212     }
213
214     vmsg->fd_num = 0;
215     for (cmsg = CMSG_FIRSTHDR(&msg);
216          cmsg != NULL;
217          cmsg = CMSG_NXTHDR(&msg, cmsg))
218     {
219         if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
220             fd_size = cmsg->cmsg_len - CMSG_LEN(0);
221             vmsg->fd_num = fd_size / sizeof(int);
222             memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
223             break;
224         }
225     }
226
227     if (vmsg->size > sizeof(vmsg->payload)) {
228         vu_panic(dev,
229                  "Error: too big message request: %d, size: vmsg->size: %u, "
230                  "while sizeof(vmsg->payload) = %zu\n",
231                  vmsg->request, vmsg->size, sizeof(vmsg->payload));
232         goto fail;
233     }
234
235     if (vmsg->size) {
236         do {
237             rc = read(conn_fd, &vmsg->payload, vmsg->size);
238         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
239
240         if (rc <= 0) {
241             vu_panic(dev, "Error while reading: %s", strerror(errno));
242             goto fail;
243         }
244
245         assert(rc == vmsg->size);
246     }
247
248     return true;
249
250 fail:
251     vmsg_close_fds(vmsg);
252
253     return false;
254 }
255
256 static bool
257 vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
258 {
259     int rc;
260     uint8_t *p = (uint8_t *)vmsg;
261     char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
262     struct iovec iov = {
263         .iov_base = (char *)vmsg,
264         .iov_len = VHOST_USER_HDR_SIZE,
265     };
266     struct msghdr msg = {
267         .msg_iov = &iov,
268         .msg_iovlen = 1,
269         .msg_control = control,
270     };
271     struct cmsghdr *cmsg;
272
273     memset(control, 0, sizeof(control));
274     assert(vmsg->fd_num <= VHOST_MEMORY_MAX_NREGIONS);
275     if (vmsg->fd_num > 0) {
276         size_t fdsize = vmsg->fd_num * sizeof(int);
277         msg.msg_controllen = CMSG_SPACE(fdsize);
278         cmsg = CMSG_FIRSTHDR(&msg);
279         cmsg->cmsg_len = CMSG_LEN(fdsize);
280         cmsg->cmsg_level = SOL_SOCKET;
281         cmsg->cmsg_type = SCM_RIGHTS;
282         memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
283     } else {
284         msg.msg_controllen = 0;
285     }
286
287     /* Set the version in the flags when sending the reply */
288     vmsg->flags &= ~VHOST_USER_VERSION_MASK;
289     vmsg->flags |= VHOST_USER_VERSION;
290     vmsg->flags |= VHOST_USER_REPLY_MASK;
291
292     do {
293         rc = sendmsg(conn_fd, &msg, 0);
294     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
295
296     do {
297         if (vmsg->data) {
298             rc = write(conn_fd, vmsg->data, vmsg->size);
299         } else {
300             rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
301         }
302     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
303
304     if (rc <= 0) {
305         vu_panic(dev, "Error while writing: %s", strerror(errno));
306         return false;
307     }
308
309     return true;
310 }
311
312 /* Kick the log_call_fd if required. */
313 static void
314 vu_log_kick(VuDev *dev)
315 {
316     if (dev->log_call_fd != -1) {
317         DPRINT("Kicking the QEMU's log...\n");
318         if (eventfd_write(dev->log_call_fd, 1) < 0) {
319             vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
320         }
321     }
322 }
323
324 static void
325 vu_log_page(uint8_t *log_table, uint64_t page)
326 {
327     DPRINT("Logged dirty guest page: %"PRId64"\n", page);
328     atomic_or(&log_table[page / 8], 1 << (page % 8));
329 }
330
331 static void
332 vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
333 {
334     uint64_t page;
335
336     if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
337         !dev->log_table || !length) {
338         return;
339     }
340
341     assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
342
343     page = address / VHOST_LOG_PAGE;
344     while (page * VHOST_LOG_PAGE < address + length) {
345         vu_log_page(dev->log_table, page);
346         page += VHOST_LOG_PAGE;
347     }
348
349     vu_log_kick(dev);
350 }
351
352 static void
353 vu_kick_cb(VuDev *dev, int condition, void *data)
354 {
355     int index = (intptr_t)data;
356     VuVirtq *vq = &dev->vq[index];
357     int sock = vq->kick_fd;
358     eventfd_t kick_data;
359     ssize_t rc;
360
361     rc = eventfd_read(sock, &kick_data);
362     if (rc == -1) {
363         vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
364         dev->remove_watch(dev, dev->vq[index].kick_fd);
365     } else {
366         DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
367                kick_data, vq->handler, index);
368         if (vq->handler) {
369             vq->handler(dev, index);
370         }
371     }
372 }
373
374 static bool
375 vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
376 {
377     vmsg->payload.u64 =
378         1ULL << VHOST_F_LOG_ALL |
379         1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
380
381     if (dev->iface->get_features) {
382         vmsg->payload.u64 |= dev->iface->get_features(dev);
383     }
384
385     vmsg->size = sizeof(vmsg->payload.u64);
386     vmsg->fd_num = 0;
387
388     DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
389
390     return true;
391 }
392
393 static void
394 vu_set_enable_all_rings(VuDev *dev, bool enabled)
395 {
396     int i;
397
398     for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
399         dev->vq[i].enable = enabled;
400     }
401 }
402
403 static bool
404 vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
405 {
406     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
407
408     dev->features = vmsg->payload.u64;
409
410     if (!(dev->features & VHOST_USER_F_PROTOCOL_FEATURES)) {
411         vu_set_enable_all_rings(dev, true);
412     }
413
414     if (dev->iface->set_features) {
415         dev->iface->set_features(dev, dev->features);
416     }
417
418     return false;
419 }
420
421 static bool
422 vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
423 {
424     return false;
425 }
426
427 static void
428 vu_close_log(VuDev *dev)
429 {
430     if (dev->log_table) {
431         if (munmap(dev->log_table, dev->log_size) != 0) {
432             perror("close log munmap() error");
433         }
434
435         dev->log_table = NULL;
436     }
437     if (dev->log_call_fd != -1) {
438         close(dev->log_call_fd);
439         dev->log_call_fd = -1;
440     }
441 }
442
443 static bool
444 vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
445 {
446     vu_set_enable_all_rings(dev, false);
447
448     return false;
449 }
450
451 static bool
452 vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
453 {
454     int i;
455     VhostUserMemory *memory = &vmsg->payload.memory;
456
457     for (i = 0; i < dev->nregions; i++) {
458         VuDevRegion *r = &dev->regions[i];
459         void *m = (void *) (uintptr_t) r->mmap_addr;
460
461         if (m) {
462             munmap(m, r->size + r->mmap_offset);
463         }
464     }
465     dev->nregions = memory->nregions;
466
467     DPRINT("Nregions: %d\n", memory->nregions);
468     for (i = 0; i < dev->nregions; i++) {
469         void *mmap_addr;
470         VhostUserMemoryRegion *msg_region = &memory->regions[i];
471         VuDevRegion *dev_region = &dev->regions[i];
472
473         DPRINT("Region %d\n", i);
474         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
475                msg_region->guest_phys_addr);
476         DPRINT("    memory_size:     0x%016"PRIx64"\n",
477                msg_region->memory_size);
478         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
479                msg_region->userspace_addr);
480         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
481                msg_region->mmap_offset);
482
483         dev_region->gpa = msg_region->guest_phys_addr;
484         dev_region->size = msg_region->memory_size;
485         dev_region->qva = msg_region->userspace_addr;
486         dev_region->mmap_offset = msg_region->mmap_offset;
487
488         /* We don't use offset argument of mmap() since the
489          * mapped address has to be page aligned, and we use huge
490          * pages.  */
491         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
492                          PROT_READ | PROT_WRITE, MAP_SHARED,
493                          vmsg->fds[i], 0);
494
495         if (mmap_addr == MAP_FAILED) {
496             vu_panic(dev, "region mmap error: %s", strerror(errno));
497         } else {
498             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
499             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
500                    dev_region->mmap_addr);
501         }
502
503         close(vmsg->fds[i]);
504     }
505
506     return false;
507 }
508
509 static bool
510 vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
511 {
512     int fd;
513     uint64_t log_mmap_size, log_mmap_offset;
514     void *rc;
515
516     if (vmsg->fd_num != 1 ||
517         vmsg->size != sizeof(vmsg->payload.log)) {
518         vu_panic(dev, "Invalid log_base message");
519         return true;
520     }
521
522     fd = vmsg->fds[0];
523     log_mmap_offset = vmsg->payload.log.mmap_offset;
524     log_mmap_size = vmsg->payload.log.mmap_size;
525     DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
526     DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
527
528     rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
529               log_mmap_offset);
530     close(fd);
531     if (rc == MAP_FAILED) {
532         perror("log mmap error");
533     }
534
535     if (dev->log_table) {
536         munmap(dev->log_table, dev->log_size);
537     }
538     dev->log_table = rc;
539     dev->log_size = log_mmap_size;
540
541     vmsg->size = sizeof(vmsg->payload.u64);
542     vmsg->fd_num = 0;
543
544     return true;
545 }
546
547 static bool
548 vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
549 {
550     if (vmsg->fd_num != 1) {
551         vu_panic(dev, "Invalid log_fd message");
552         return false;
553     }
554
555     if (dev->log_call_fd != -1) {
556         close(dev->log_call_fd);
557     }
558     dev->log_call_fd = vmsg->fds[0];
559     DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
560
561     return false;
562 }
563
564 static bool
565 vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
566 {
567     unsigned int index = vmsg->payload.state.index;
568     unsigned int num = vmsg->payload.state.num;
569
570     DPRINT("State.index: %d\n", index);
571     DPRINT("State.num:   %d\n", num);
572     dev->vq[index].vring.num = num;
573
574     return false;
575 }
576
577 static bool
578 vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
579 {
580     struct vhost_vring_addr *vra = &vmsg->payload.addr;
581     unsigned int index = vra->index;
582     VuVirtq *vq = &dev->vq[index];
583
584     DPRINT("vhost_vring_addr:\n");
585     DPRINT("    index:  %d\n", vra->index);
586     DPRINT("    flags:  %d\n", vra->flags);
587     DPRINT("    desc_user_addr:   0x%016llx\n", vra->desc_user_addr);
588     DPRINT("    used_user_addr:   0x%016llx\n", vra->used_user_addr);
589     DPRINT("    avail_user_addr:  0x%016llx\n", vra->avail_user_addr);
590     DPRINT("    log_guest_addr:   0x%016llx\n", vra->log_guest_addr);
591
592     vq->vring.flags = vra->flags;
593     vq->vring.desc = qva_to_va(dev, vra->desc_user_addr);
594     vq->vring.used = qva_to_va(dev, vra->used_user_addr);
595     vq->vring.avail = qva_to_va(dev, vra->avail_user_addr);
596     vq->vring.log_guest_addr = vra->log_guest_addr;
597
598     DPRINT("Setting virtq addresses:\n");
599     DPRINT("    vring_desc  at %p\n", vq->vring.desc);
600     DPRINT("    vring_used  at %p\n", vq->vring.used);
601     DPRINT("    vring_avail at %p\n", vq->vring.avail);
602
603     if (!(vq->vring.desc && vq->vring.used && vq->vring.avail)) {
604         vu_panic(dev, "Invalid vring_addr message");
605         return false;
606     }
607
608     vq->used_idx = vq->vring.used->idx;
609
610     if (vq->last_avail_idx != vq->used_idx) {
611         bool resume = dev->iface->queue_is_processed_in_order &&
612             dev->iface->queue_is_processed_in_order(dev, index);
613
614         DPRINT("Last avail index != used index: %u != %u%s\n",
615                vq->last_avail_idx, vq->used_idx,
616                resume ? ", resuming" : "");
617
618         if (resume) {
619             vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
620         }
621     }
622
623     return false;
624 }
625
626 static bool
627 vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
628 {
629     unsigned int index = vmsg->payload.state.index;
630     unsigned int num = vmsg->payload.state.num;
631
632     DPRINT("State.index: %d\n", index);
633     DPRINT("State.num:   %d\n", num);
634     dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
635
636     return false;
637 }
638
639 static bool
640 vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
641 {
642     unsigned int index = vmsg->payload.state.index;
643
644     DPRINT("State.index: %d\n", index);
645     vmsg->payload.state.num = dev->vq[index].last_avail_idx;
646     vmsg->size = sizeof(vmsg->payload.state);
647
648     dev->vq[index].started = false;
649     if (dev->iface->queue_set_started) {
650         dev->iface->queue_set_started(dev, index, false);
651     }
652
653     if (dev->vq[index].call_fd != -1) {
654         close(dev->vq[index].call_fd);
655         dev->vq[index].call_fd = -1;
656     }
657     if (dev->vq[index].kick_fd != -1) {
658         dev->remove_watch(dev, dev->vq[index].kick_fd);
659         close(dev->vq[index].kick_fd);
660         dev->vq[index].kick_fd = -1;
661     }
662
663     return true;
664 }
665
666 static bool
667 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
668 {
669     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
670
671     if (index >= VHOST_MAX_NR_VIRTQUEUE) {
672         vmsg_close_fds(vmsg);
673         vu_panic(dev, "Invalid queue index: %u", index);
674         return false;
675     }
676
677     if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK ||
678         vmsg->fd_num != 1) {
679         vmsg_close_fds(vmsg);
680         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
681         return false;
682     }
683
684     return true;
685 }
686
687 static bool
688 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
689 {
690     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
691
692     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
693
694     if (!vu_check_queue_msg_file(dev, vmsg)) {
695         return false;
696     }
697
698     if (dev->vq[index].kick_fd != -1) {
699         dev->remove_watch(dev, dev->vq[index].kick_fd);
700         close(dev->vq[index].kick_fd);
701         dev->vq[index].kick_fd = -1;
702     }
703
704     if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
705         dev->vq[index].kick_fd = vmsg->fds[0];
706         DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
707     }
708
709     dev->vq[index].started = true;
710     if (dev->iface->queue_set_started) {
711         dev->iface->queue_set_started(dev, index, true);
712     }
713
714     if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
715         dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
716                        vu_kick_cb, (void *)(long)index);
717
718         DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
719                dev->vq[index].kick_fd, index);
720     }
721
722     return false;
723 }
724
725 void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
726                           vu_queue_handler_cb handler)
727 {
728     int qidx = vq - dev->vq;
729
730     vq->handler = handler;
731     if (vq->kick_fd >= 0) {
732         if (handler) {
733             dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
734                            vu_kick_cb, (void *)(long)qidx);
735         } else {
736             dev->remove_watch(dev, vq->kick_fd);
737         }
738     }
739 }
740
741 static bool
742 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
743 {
744     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
745
746     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
747
748     if (!vu_check_queue_msg_file(dev, vmsg)) {
749         return false;
750     }
751
752     if (dev->vq[index].call_fd != -1) {
753         close(dev->vq[index].call_fd);
754         dev->vq[index].call_fd = -1;
755     }
756
757     if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
758         dev->vq[index].call_fd = vmsg->fds[0];
759     }
760
761     DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);
762
763     return false;
764 }
765
766 static bool
767 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
768 {
769     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
770
771     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
772
773     if (!vu_check_queue_msg_file(dev, vmsg)) {
774         return false;
775     }
776
777     if (dev->vq[index].err_fd != -1) {
778         close(dev->vq[index].err_fd);
779         dev->vq[index].err_fd = -1;
780     }
781
782     if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
783         dev->vq[index].err_fd = vmsg->fds[0];
784     }
785
786     return false;
787 }
788
789 static bool
790 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
791 {
792     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
793                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
794
795     if (dev->iface->get_protocol_features) {
796         features |= dev->iface->get_protocol_features(dev);
797     }
798
799     vmsg->payload.u64 = features;
800     vmsg->size = sizeof(vmsg->payload.u64);
801     vmsg->fd_num = 0;
802
803     return true;
804 }
805
806 static bool
807 vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
808 {
809     uint64_t features = vmsg->payload.u64;
810
811     DPRINT("u64: 0x%016"PRIx64"\n", features);
812
813     dev->protocol_features = vmsg->payload.u64;
814
815     if (dev->iface->set_protocol_features) {
816         dev->iface->set_protocol_features(dev, features);
817     }
818
819     return false;
820 }
821
822 static bool
823 vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
824 {
825     DPRINT("Function %s() not implemented yet.\n", __func__);
826     return false;
827 }
828
829 static bool
830 vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
831 {
832     unsigned int index = vmsg->payload.state.index;
833     unsigned int enable = vmsg->payload.state.num;
834
835     DPRINT("State.index: %d\n", index);
836     DPRINT("State.enable:   %d\n", enable);
837
838     if (index >= VHOST_MAX_NR_VIRTQUEUE) {
839         vu_panic(dev, "Invalid vring_enable index: %u", index);
840         return false;
841     }
842
843     dev->vq[index].enable = enable;
844     return false;
845 }
846
847 static bool
848 vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
849 {
850     if (vmsg->fd_num != 1) {
851         vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
852         return false;
853     }
854
855     if (dev->slave_fd != -1) {
856         close(dev->slave_fd);
857     }
858     dev->slave_fd = vmsg->fds[0];
859     DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
860
861     return false;
862 }
863
864 static bool
865 vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
866 {
867     int ret = -1;
868
869     if (dev->iface->get_config) {
870         ret = dev->iface->get_config(dev, vmsg->payload.config.region,
871                                      vmsg->payload.config.size);
872     }
873
874     if (ret) {
875         /* resize to zero to indicate an error to master */
876         vmsg->size = 0;
877     }
878
879     return true;
880 }
881
882 static bool
883 vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
884 {
885     int ret = -1;
886
887     if (dev->iface->set_config) {
888         ret = dev->iface->set_config(dev, vmsg->payload.config.region,
889                                      vmsg->payload.config.offset,
890                                      vmsg->payload.config.size,
891                                      vmsg->payload.config.flags);
892         if (ret) {
893             vu_panic(dev, "Set virtio configuration space failed");
894         }
895     }
896
897     return false;
898 }
899
900 static bool
901 vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
902 {
903     dev->postcopy_ufd = -1;
904 #ifdef UFFDIO_API
905     struct uffdio_api api_struct;
906
907     dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
908     vmsg->size = 0;
909 #endif
910
911     if (dev->postcopy_ufd == -1) {
912         vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
913         goto out;
914     }
915
916 #ifdef UFFDIO_API
917     api_struct.api = UFFD_API;
918     api_struct.features = 0;
919     if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
920         vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
921         close(dev->postcopy_ufd);
922         dev->postcopy_ufd = -1;
923         goto out;
924     }
925     /* TODO: Stash feature flags somewhere */
926 #endif
927
928 out:
929     /* Return a ufd to the QEMU */
930     vmsg->fd_num = 1;
931     vmsg->fds[0] = dev->postcopy_ufd;
932     return true; /* = send a reply */
933 }
934
935 static bool
936 vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
937 {
938     vmsg->payload.u64 = -1;
939     vmsg->size = sizeof(vmsg->payload.u64);
940
941     if (dev->nregions) {
942         vu_panic(dev, "Regions already registered at postcopy-listen");
943         return true;
944     }
945     dev->postcopy_listening = true;
946
947     vmsg->flags = VHOST_USER_VERSION |  VHOST_USER_REPLY_MASK;
948     vmsg->payload.u64 = 0; /* Success */
949     return true;
950 }
951 static bool
952 vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
953 {
954     int do_reply = 0;
955
956     /* Print out generic part of the request. */
957     DPRINT("================ Vhost user message ================\n");
958     DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
959            vmsg->request);
960     DPRINT("Flags:   0x%x\n", vmsg->flags);
961     DPRINT("Size:    %d\n", vmsg->size);
962
963     if (vmsg->fd_num) {
964         int i;
965         DPRINT("Fds:");
966         for (i = 0; i < vmsg->fd_num; i++) {
967             DPRINT(" %d", vmsg->fds[i]);
968         }
969         DPRINT("\n");
970     }
971
972     if (dev->iface->process_msg &&
973         dev->iface->process_msg(dev, vmsg, &do_reply)) {
974         return do_reply;
975     }
976
977     switch (vmsg->request) {
978     case VHOST_USER_GET_FEATURES:
979         return vu_get_features_exec(dev, vmsg);
980     case VHOST_USER_SET_FEATURES:
981         return vu_set_features_exec(dev, vmsg);
982     case VHOST_USER_GET_PROTOCOL_FEATURES:
983         return vu_get_protocol_features_exec(dev, vmsg);
984     case VHOST_USER_SET_PROTOCOL_FEATURES:
985         return vu_set_protocol_features_exec(dev, vmsg);
986     case VHOST_USER_SET_OWNER:
987         return vu_set_owner_exec(dev, vmsg);
988     case VHOST_USER_RESET_OWNER:
989         return vu_reset_device_exec(dev, vmsg);
990     case VHOST_USER_SET_MEM_TABLE:
991         return vu_set_mem_table_exec(dev, vmsg);
992     case VHOST_USER_SET_LOG_BASE:
993         return vu_set_log_base_exec(dev, vmsg);
994     case VHOST_USER_SET_LOG_FD:
995         return vu_set_log_fd_exec(dev, vmsg);
996     case VHOST_USER_SET_VRING_NUM:
997         return vu_set_vring_num_exec(dev, vmsg);
998     case VHOST_USER_SET_VRING_ADDR:
999         return vu_set_vring_addr_exec(dev, vmsg);
1000     case VHOST_USER_SET_VRING_BASE:
1001         return vu_set_vring_base_exec(dev, vmsg);
1002     case VHOST_USER_GET_VRING_BASE:
1003         return vu_get_vring_base_exec(dev, vmsg);
1004     case VHOST_USER_SET_VRING_KICK:
1005         return vu_set_vring_kick_exec(dev, vmsg);
1006     case VHOST_USER_SET_VRING_CALL:
1007         return vu_set_vring_call_exec(dev, vmsg);
1008     case VHOST_USER_SET_VRING_ERR:
1009         return vu_set_vring_err_exec(dev, vmsg);
1010     case VHOST_USER_GET_QUEUE_NUM:
1011         return vu_get_queue_num_exec(dev, vmsg);
1012     case VHOST_USER_SET_VRING_ENABLE:
1013         return vu_set_vring_enable_exec(dev, vmsg);
1014     case VHOST_USER_SET_SLAVE_REQ_FD:
1015         return vu_set_slave_req_fd(dev, vmsg);
1016     case VHOST_USER_GET_CONFIG:
1017         return vu_get_config(dev, vmsg);
1018     case VHOST_USER_SET_CONFIG:
1019         return vu_set_config(dev, vmsg);
1020     case VHOST_USER_NONE:
1021         break;
1022     case VHOST_USER_POSTCOPY_ADVISE:
1023         return vu_set_postcopy_advise(dev, vmsg);
1024     case VHOST_USER_POSTCOPY_LISTEN:
1025         return vu_set_postcopy_listen(dev, vmsg);
1026     default:
1027         vmsg_close_fds(vmsg);
1028         vu_panic(dev, "Unhandled request: %d", vmsg->request);
1029     }
1030
1031     return false;
1032 }
1033
1034 bool
1035 vu_dispatch(VuDev *dev)
1036 {
1037     VhostUserMsg vmsg = { 0, };
1038     int reply_requested;
1039     bool success = false;
1040
1041     if (!vu_message_read(dev, dev->sock, &vmsg)) {
1042         goto end;
1043     }
1044
1045     reply_requested = vu_process_message(dev, &vmsg);
1046     if (!reply_requested) {
1047         success = true;
1048         goto end;
1049     }
1050
1051     if (!vu_message_write(dev, dev->sock, &vmsg)) {
1052         goto end;
1053     }
1054
1055     success = true;
1056
1057 end:
1058     free(vmsg.data);
1059     return success;
1060 }
1061
1062 void
1063 vu_deinit(VuDev *dev)
1064 {
1065     int i;
1066
1067     for (i = 0; i < dev->nregions; i++) {
1068         VuDevRegion *r = &dev->regions[i];
1069         void *m = (void *) (uintptr_t) r->mmap_addr;
1070         if (m != MAP_FAILED) {
1071             munmap(m, r->size + r->mmap_offset);
1072         }
1073     }
1074     dev->nregions = 0;
1075
1076     for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
1077         VuVirtq *vq = &dev->vq[i];
1078
1079         if (vq->call_fd != -1) {
1080             close(vq->call_fd);
1081             vq->call_fd = -1;
1082         }
1083
1084         if (vq->kick_fd != -1) {
1085             close(vq->kick_fd);
1086             vq->kick_fd = -1;
1087         }
1088
1089         if (vq->err_fd != -1) {
1090             close(vq->err_fd);
1091             vq->err_fd = -1;
1092         }
1093     }
1094
1095
1096     vu_close_log(dev);
1097     if (dev->slave_fd != -1) {
1098         close(dev->slave_fd);
1099         dev->slave_fd = -1;
1100     }
1101
1102     if (dev->sock != -1) {
1103         close(dev->sock);
1104     }
1105 }
1106
1107 void
1108 vu_init(VuDev *dev,
1109         int socket,
1110         vu_panic_cb panic,
1111         vu_set_watch_cb set_watch,
1112         vu_remove_watch_cb remove_watch,
1113         const VuDevIface *iface)
1114 {
1115     int i;
1116
1117     assert(socket >= 0);
1118     assert(set_watch);
1119     assert(remove_watch);
1120     assert(iface);
1121     assert(panic);
1122
1123     memset(dev, 0, sizeof(*dev));
1124
1125     dev->sock = socket;
1126     dev->panic = panic;
1127     dev->set_watch = set_watch;
1128     dev->remove_watch = remove_watch;
1129     dev->iface = iface;
1130     dev->log_call_fd = -1;
1131     dev->slave_fd = -1;
1132     for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
1133         dev->vq[i] = (VuVirtq) {
1134             .call_fd = -1, .kick_fd = -1, .err_fd = -1,
1135             .notification = true,
1136         };
1137     }
1138 }
1139
1140 VuVirtq *
1141 vu_get_queue(VuDev *dev, int qidx)
1142 {
1143     assert(qidx < VHOST_MAX_NR_VIRTQUEUE);
1144     return &dev->vq[qidx];
1145 }
1146
1147 bool
1148 vu_queue_enabled(VuDev *dev, VuVirtq *vq)
1149 {
1150     return vq->enable;
1151 }
1152
1153 bool
1154 vu_queue_started(const VuDev *dev, const VuVirtq *vq)
1155 {
1156     return vq->started;
1157 }
1158
1159 static inline uint16_t
1160 vring_avail_flags(VuVirtq *vq)
1161 {
1162     return vq->vring.avail->flags;
1163 }
1164
1165 static inline uint16_t
1166 vring_avail_idx(VuVirtq *vq)
1167 {
1168     vq->shadow_avail_idx = vq->vring.avail->idx;
1169
1170     return vq->shadow_avail_idx;
1171 }
1172
1173 static inline uint16_t
1174 vring_avail_ring(VuVirtq *vq, int i)
1175 {
1176     return vq->vring.avail->ring[i];
1177 }
1178
1179 static inline uint16_t
1180 vring_get_used_event(VuVirtq *vq)
1181 {
1182     return vring_avail_ring(vq, vq->vring.num);
1183 }
1184
1185 static int
1186 virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
1187 {
1188     uint16_t num_heads = vring_avail_idx(vq) - idx;
1189
1190     /* Check it isn't doing very strange things with descriptor numbers. */
1191     if (num_heads > vq->vring.num) {
1192         vu_panic(dev, "Guest moved used index from %u to %u",
1193                  idx, vq->shadow_avail_idx);
1194         return -1;
1195     }
1196     if (num_heads) {
1197         /* On success, callers read a descriptor at vq->last_avail_idx.
1198          * Make sure descriptor read does not bypass avail index read. */
1199         smp_rmb();
1200     }
1201
1202     return num_heads;
1203 }
1204
1205 static bool
1206 virtqueue_get_head(VuDev *dev, VuVirtq *vq,
1207                    unsigned int idx, unsigned int *head)
1208 {
1209     /* Grab the next descriptor number they're advertising, and increment
1210      * the index we've seen. */
1211     *head = vring_avail_ring(vq, idx % vq->vring.num);
1212
1213     /* If their number is silly, that's a fatal mistake. */
1214     if (*head >= vq->vring.num) {
1215         vu_panic(dev, "Guest says index %u is available", head);
1216         return false;
1217     }
1218
1219     return true;
1220 }
1221
1222 static int
1223 virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
1224                              uint64_t addr, size_t len)
1225 {
1226     struct vring_desc *ori_desc;
1227     uint64_t read_len;
1228
1229     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
1230         return -1;
1231     }
1232
1233     if (len == 0) {
1234         return -1;
1235     }
1236
1237     while (len) {
1238         read_len = len;
1239         ori_desc = vu_gpa_to_va(dev, &read_len, addr);
1240         if (!ori_desc) {
1241             return -1;
1242         }
1243
1244         memcpy(desc, ori_desc, read_len);
1245         len -= read_len;
1246         addr += read_len;
1247         desc += read_len;
1248     }
1249
1250     return 0;
1251 }
1252
1253 enum {
1254     VIRTQUEUE_READ_DESC_ERROR = -1,
1255     VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
1256     VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
1257 };
1258
1259 static int
1260 virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
1261                          int i, unsigned int max, unsigned int *next)
1262 {
1263     /* If this descriptor says it doesn't chain, we're done. */
1264     if (!(desc[i].flags & VRING_DESC_F_NEXT)) {
1265         return VIRTQUEUE_READ_DESC_DONE;
1266     }
1267
1268     /* Check they're not leading us off end of descriptors. */
1269     *next = desc[i].next;
1270     /* Make sure compiler knows to grab that: we don't want it changing! */
1271     smp_wmb();
1272
1273     if (*next >= max) {
1274         vu_panic(dev, "Desc next is %u", next);
1275         return VIRTQUEUE_READ_DESC_ERROR;
1276     }
1277
1278     return VIRTQUEUE_READ_DESC_MORE;
1279 }
1280
1281 void
1282 vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
1283                          unsigned int *out_bytes,
1284                          unsigned max_in_bytes, unsigned max_out_bytes)
1285 {
1286     unsigned int idx;
1287     unsigned int total_bufs, in_total, out_total;
1288     int rc;
1289
1290     idx = vq->last_avail_idx;
1291
1292     total_bufs = in_total = out_total = 0;
1293     if (unlikely(dev->broken) ||
1294         unlikely(!vq->vring.avail)) {
1295         goto done;
1296     }
1297
1298     while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
1299         unsigned int max, desc_len, num_bufs, indirect = 0;
1300         uint64_t desc_addr, read_len;
1301         struct vring_desc *desc;
1302         struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
1303         unsigned int i;
1304
1305         max = vq->vring.num;
1306         num_bufs = total_bufs;
1307         if (!virtqueue_get_head(dev, vq, idx++, &i)) {
1308             goto err;
1309         }
1310         desc = vq->vring.desc;
1311
1312         if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1313             if (desc[i].len % sizeof(struct vring_desc)) {
1314                 vu_panic(dev, "Invalid size for indirect buffer table");
1315                 goto err;
1316             }
1317
1318             /* If we've got too many, that implies a descriptor loop. */
1319             if (num_bufs >= max) {
1320                 vu_panic(dev, "Looped descriptor");
1321                 goto err;
1322             }
1323
1324             /* loop over the indirect descriptor table */
1325             indirect = 1;
1326             desc_addr = desc[i].addr;
1327             desc_len = desc[i].len;
1328             max = desc_len / sizeof(struct vring_desc);
1329             read_len = desc_len;
1330             desc = vu_gpa_to_va(dev, &read_len, desc_addr);
1331             if (unlikely(desc && read_len != desc_len)) {
1332                 /* Failed to use zero copy */
1333                 desc = NULL;
1334                 if (!virtqueue_read_indirect_desc(dev, desc_buf,
1335                                                   desc_addr,
1336                                                   desc_len)) {
1337                     desc = desc_buf;
1338                 }
1339             }
1340             if (!desc) {
1341                 vu_panic(dev, "Invalid indirect buffer table");
1342                 goto err;
1343             }
1344             num_bufs = i = 0;
1345         }
1346
1347         do {
1348             /* If we've got too many, that implies a descriptor loop. */
1349             if (++num_bufs > max) {
1350                 vu_panic(dev, "Looped descriptor");
1351                 goto err;
1352             }
1353
1354             if (desc[i].flags & VRING_DESC_F_WRITE) {
1355                 in_total += desc[i].len;
1356             } else {
1357                 out_total += desc[i].len;
1358             }
1359             if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
1360                 goto done;
1361             }
1362             rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
1363         } while (rc == VIRTQUEUE_READ_DESC_MORE);
1364
1365         if (rc == VIRTQUEUE_READ_DESC_ERROR) {
1366             goto err;
1367         }
1368
1369         if (!indirect) {
1370             total_bufs = num_bufs;
1371         } else {
1372             total_bufs++;
1373         }
1374     }
1375     if (rc < 0) {
1376         goto err;
1377     }
1378 done:
1379     if (in_bytes) {
1380         *in_bytes = in_total;
1381     }
1382     if (out_bytes) {
1383         *out_bytes = out_total;
1384     }
1385     return;
1386
1387 err:
1388     in_total = out_total = 0;
1389     goto done;
1390 }
1391
1392 bool
1393 vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
1394                      unsigned int out_bytes)
1395 {
1396     unsigned int in_total, out_total;
1397
1398     vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
1399                              in_bytes, out_bytes);
1400
1401     return in_bytes <= in_total && out_bytes <= out_total;
1402 }
1403
1404 /* Fetch avail_idx from VQ memory only when we really need to know if
1405  * guest has added some buffers. */
1406 bool
1407 vu_queue_empty(VuDev *dev, VuVirtq *vq)
1408 {
1409     if (unlikely(dev->broken) ||
1410         unlikely(!vq->vring.avail)) {
1411         return true;
1412     }
1413
1414     if (vq->shadow_avail_idx != vq->last_avail_idx) {
1415         return false;
1416     }
1417
1418     return vring_avail_idx(vq) == vq->last_avail_idx;
1419 }
1420
1421 static inline
1422 bool has_feature(uint64_t features, unsigned int fbit)
1423 {
1424     assert(fbit < 64);
1425     return !!(features & (1ULL << fbit));
1426 }
1427
1428 static inline
1429 bool vu_has_feature(VuDev *dev,
1430                     unsigned int fbit)
1431 {
1432     return has_feature(dev->features, fbit);
1433 }
1434
1435 static bool
1436 vring_notify(VuDev *dev, VuVirtq *vq)
1437 {
1438     uint16_t old, new;
1439     bool v;
1440
1441     /* We need to expose used array entries before checking used event. */
1442     smp_mb();
1443
1444     /* Always notify when queue is empty (when feature acknowledge) */
1445     if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
1446         !vq->inuse && vu_queue_empty(dev, vq)) {
1447         return true;
1448     }
1449
1450     if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1451         return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
1452     }
1453
1454     v = vq->signalled_used_valid;
1455     vq->signalled_used_valid = true;
1456     old = vq->signalled_used;
1457     new = vq->signalled_used = vq->used_idx;
1458     return !v || vring_need_event(vring_get_used_event(vq), new, old);
1459 }
1460
1461 void
1462 vu_queue_notify(VuDev *dev, VuVirtq *vq)
1463 {
1464     if (unlikely(dev->broken) ||
1465         unlikely(!vq->vring.avail)) {
1466         return;
1467     }
1468
1469     if (!vring_notify(dev, vq)) {
1470         DPRINT("skipped notify...\n");
1471         return;
1472     }
1473
1474     if (eventfd_write(vq->call_fd, 1) < 0) {
1475         vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
1476     }
1477 }
1478
1479 static inline void
1480 vring_used_flags_set_bit(VuVirtq *vq, int mask)
1481 {
1482     uint16_t *flags;
1483
1484     flags = (uint16_t *)((char*)vq->vring.used +
1485                          offsetof(struct vring_used, flags));
1486     *flags |= mask;
1487 }
1488
1489 static inline void
1490 vring_used_flags_unset_bit(VuVirtq *vq, int mask)
1491 {
1492     uint16_t *flags;
1493
1494     flags = (uint16_t *)((char*)vq->vring.used +
1495                          offsetof(struct vring_used, flags));
1496     *flags &= ~mask;
1497 }
1498
1499 static inline void
1500 vring_set_avail_event(VuVirtq *vq, uint16_t val)
1501 {
1502     if (!vq->notification) {
1503         return;
1504     }
1505
1506     *((uint16_t *) &vq->vring.used->ring[vq->vring.num]) = val;
1507 }
1508
1509 void
1510 vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
1511 {
1512     vq->notification = enable;
1513     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1514         vring_set_avail_event(vq, vring_avail_idx(vq));
1515     } else if (enable) {
1516         vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
1517     } else {
1518         vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
1519     }
1520     if (enable) {
1521         /* Expose avail event/used flags before caller checks the avail idx. */
1522         smp_mb();
1523     }
1524 }
1525
1526 static void
1527 virtqueue_map_desc(VuDev *dev,
1528                    unsigned int *p_num_sg, struct iovec *iov,
1529                    unsigned int max_num_sg, bool is_write,
1530                    uint64_t pa, size_t sz)
1531 {
1532     unsigned num_sg = *p_num_sg;
1533
1534     assert(num_sg <= max_num_sg);
1535
1536     if (!sz) {
1537         vu_panic(dev, "virtio: zero sized buffers are not allowed");
1538         return;
1539     }
1540
1541     while (sz) {
1542         uint64_t len = sz;
1543
1544         if (num_sg == max_num_sg) {
1545             vu_panic(dev, "virtio: too many descriptors in indirect table");
1546             return;
1547         }
1548
1549         iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
1550         if (iov[num_sg].iov_base == NULL) {
1551             vu_panic(dev, "virtio: invalid address for buffers");
1552             return;
1553         }
1554         iov[num_sg].iov_len = len;
1555         num_sg++;
1556         sz -= len;
1557         pa += len;
1558     }
1559
1560     *p_num_sg = num_sg;
1561 }
1562
1563 /* Round number down to multiple */
1564 #define ALIGN_DOWN(n, m) ((n) / (m) * (m))
1565
1566 /* Round number up to multiple */
1567 #define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
1568
1569 static void *
1570 virtqueue_alloc_element(size_t sz,
1571                                      unsigned out_num, unsigned in_num)
1572 {
1573     VuVirtqElement *elem;
1574     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
1575     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
1576     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
1577
1578     assert(sz >= sizeof(VuVirtqElement));
1579     elem = malloc(out_sg_end);
1580     elem->out_num = out_num;
1581     elem->in_num = in_num;
1582     elem->in_sg = (void *)elem + in_sg_ofs;
1583     elem->out_sg = (void *)elem + out_sg_ofs;
1584     return elem;
1585 }
1586
1587 void *
1588 vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
1589 {
1590     unsigned int i, head, max, desc_len;
1591     uint64_t desc_addr, read_len;
1592     VuVirtqElement *elem;
1593     unsigned out_num, in_num;
1594     struct iovec iov[VIRTQUEUE_MAX_SIZE];
1595     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
1596     struct vring_desc *desc;
1597     int rc;
1598
1599     if (unlikely(dev->broken) ||
1600         unlikely(!vq->vring.avail)) {
1601         return NULL;
1602     }
1603
1604     if (vu_queue_empty(dev, vq)) {
1605         return NULL;
1606     }
1607     /* Needed after virtio_queue_empty(), see comment in
1608      * virtqueue_num_heads(). */
1609     smp_rmb();
1610
1611     /* When we start there are none of either input nor output. */
1612     out_num = in_num = 0;
1613
1614     max = vq->vring.num;
1615     if (vq->inuse >= vq->vring.num) {
1616         vu_panic(dev, "Virtqueue size exceeded");
1617         return NULL;
1618     }
1619
1620     if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
1621         return NULL;
1622     }
1623
1624     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1625         vring_set_avail_event(vq, vq->last_avail_idx);
1626     }
1627
1628     i = head;
1629     desc = vq->vring.desc;
1630     if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1631         if (desc[i].len % sizeof(struct vring_desc)) {
1632             vu_panic(dev, "Invalid size for indirect buffer table");
1633         }
1634
1635         /* loop over the indirect descriptor table */
1636         desc_addr = desc[i].addr;
1637         desc_len = desc[i].len;
1638         max = desc_len / sizeof(struct vring_desc);
1639         read_len = desc_len;
1640         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
1641         if (unlikely(desc && read_len != desc_len)) {
1642             /* Failed to use zero copy */
1643             desc = NULL;
1644             if (!virtqueue_read_indirect_desc(dev, desc_buf,
1645                                               desc_addr,
1646                                               desc_len)) {
1647                 desc = desc_buf;
1648             }
1649         }
1650         if (!desc) {
1651             vu_panic(dev, "Invalid indirect buffer table");
1652             return NULL;
1653         }
1654         i = 0;
1655     }
1656
1657     /* Collect all the descriptors */
1658     do {
1659         if (desc[i].flags & VRING_DESC_F_WRITE) {
1660             virtqueue_map_desc(dev, &in_num, iov + out_num,
1661                                VIRTQUEUE_MAX_SIZE - out_num, true,
1662                                desc[i].addr, desc[i].len);
1663         } else {
1664             if (in_num) {
1665                 vu_panic(dev, "Incorrect order for descriptors");
1666                 return NULL;
1667             }
1668             virtqueue_map_desc(dev, &out_num, iov,
1669                                VIRTQUEUE_MAX_SIZE, false,
1670                                desc[i].addr, desc[i].len);
1671         }
1672
1673         /* If we've got too many, that implies a descriptor loop. */
1674         if ((in_num + out_num) > max) {
1675             vu_panic(dev, "Looped descriptor");
1676         }
1677         rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
1678     } while (rc == VIRTQUEUE_READ_DESC_MORE);
1679
1680     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
1681         return NULL;
1682     }
1683
1684     /* Now copy what we have collected and mapped */
1685     elem = virtqueue_alloc_element(sz, out_num, in_num);
1686     elem->index = head;
1687     for (i = 0; i < out_num; i++) {
1688         elem->out_sg[i] = iov[i];
1689     }
1690     for (i = 0; i < in_num; i++) {
1691         elem->in_sg[i] = iov[out_num + i];
1692     }
1693
1694     vq->inuse++;
1695
1696     return elem;
1697 }
1698
1699 bool
1700 vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
1701 {
1702     if (num > vq->inuse) {
1703         return false;
1704     }
1705     vq->last_avail_idx -= num;
1706     vq->inuse -= num;
1707     return true;
1708 }
1709
1710 static inline
1711 void vring_used_write(VuDev *dev, VuVirtq *vq,
1712                       struct vring_used_elem *uelem, int i)
1713 {
1714     struct vring_used *used = vq->vring.used;
1715
1716     used->ring[i] = *uelem;
1717     vu_log_write(dev, vq->vring.log_guest_addr +
1718                  offsetof(struct vring_used, ring[i]),
1719                  sizeof(used->ring[i]));
1720 }
1721
1722
1723 static void
1724 vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
1725                   const VuVirtqElement *elem,
1726                   unsigned int len)
1727 {
1728     struct vring_desc *desc = vq->vring.desc;
1729     unsigned int i, max, min, desc_len;
1730     uint64_t desc_addr, read_len;
1731     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
1732     unsigned num_bufs = 0;
1733
1734     max = vq->vring.num;
1735     i = elem->index;
1736
1737     if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1738         if (desc[i].len % sizeof(struct vring_desc)) {
1739             vu_panic(dev, "Invalid size for indirect buffer table");
1740         }
1741
1742         /* loop over the indirect descriptor table */
1743         desc_addr = desc[i].addr;
1744         desc_len = desc[i].len;
1745         max = desc_len / sizeof(struct vring_desc);
1746         read_len = desc_len;
1747         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
1748         if (unlikely(desc && read_len != desc_len)) {
1749             /* Failed to use zero copy */
1750             desc = NULL;
1751             if (!virtqueue_read_indirect_desc(dev, desc_buf,
1752                                               desc_addr,
1753                                               desc_len)) {
1754                 desc = desc_buf;
1755             }
1756         }
1757         if (!desc) {
1758             vu_panic(dev, "Invalid indirect buffer table");
1759             return;
1760         }
1761         i = 0;
1762     }
1763
1764     do {
1765         if (++num_bufs > max) {
1766             vu_panic(dev, "Looped descriptor");
1767             return;
1768         }
1769
1770         if (desc[i].flags & VRING_DESC_F_WRITE) {
1771             min = MIN(desc[i].len, len);
1772             vu_log_write(dev, desc[i].addr, min);
1773             len -= min;
1774         }
1775
1776     } while (len > 0 &&
1777              (virtqueue_read_next_desc(dev, desc, i, max, &i)
1778               == VIRTQUEUE_READ_DESC_MORE));
1779 }
1780
1781 void
1782 vu_queue_fill(VuDev *dev, VuVirtq *vq,
1783               const VuVirtqElement *elem,
1784               unsigned int len, unsigned int idx)
1785 {
1786     struct vring_used_elem uelem;
1787
1788     if (unlikely(dev->broken) ||
1789         unlikely(!vq->vring.avail)) {
1790         return;
1791     }
1792
1793     vu_log_queue_fill(dev, vq, elem, len);
1794
1795     idx = (idx + vq->used_idx) % vq->vring.num;
1796
1797     uelem.id = elem->index;
1798     uelem.len = len;
1799     vring_used_write(dev, vq, &uelem, idx);
1800 }
1801
1802 static inline
1803 void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
1804 {
1805     vq->vring.used->idx = val;
1806     vu_log_write(dev,
1807                  vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
1808                  sizeof(vq->vring.used->idx));
1809
1810     vq->used_idx = val;
1811 }
1812
1813 void
1814 vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
1815 {
1816     uint16_t old, new;
1817
1818     if (unlikely(dev->broken) ||
1819         unlikely(!vq->vring.avail)) {
1820         return;
1821     }
1822
1823     /* Make sure buffer is written before we update index. */
1824     smp_wmb();
1825
1826     old = vq->used_idx;
1827     new = old + count;
1828     vring_used_idx_set(dev, vq, new);
1829     vq->inuse -= count;
1830     if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
1831         vq->signalled_used_valid = false;
1832     }
1833 }
1834
1835 void
1836 vu_queue_push(VuDev *dev, VuVirtq *vq,
1837               const VuVirtqElement *elem, unsigned int len)
1838 {
1839     vu_queue_fill(dev, vq, elem, len, 0);
1840     vu_queue_flush(dev, vq, 1);
1841 }
This page took 0.118429 seconds and 2 git commands to generate.