Diffstat (limited to 'drivers/vhost/vhost.c')
 -rw-r--r--   drivers/vhost/vhost.c | 2151
 1 file changed, 1531 insertions(+), 620 deletions(-)
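For orientation before the diff itself: among other cleanups, the changes below move vhost from a single per-device worker kthread to per-virtqueue vhost_worker objects that userspace can manage through new ioctls (VHOST_NEW_WORKER, VHOST_FREE_WORKER, VHOST_ATTACH_VRING_WORKER, VHOST_GET_VRING_WORKER), with kthread-versus-vhost_task mode selected via dev->fork_owner. The sketch that follows is not part of the diff; it is a minimal illustration of how a userspace backend might drive those ioctls, assuming vhost_fd is an open vhost character device (e.g. /dev/vhost-net) on which VHOST_SET_OWNER has already succeeded and the default task-based worker mode is in effect.

/*
 * Minimal userspace sketch (illustrative only): create an extra vhost
 * worker and bind one virtqueue to it using the ioctls added in the
 * diff below.  Assumes VHOST_SET_OWNER already succeeded on vhost_fd;
 * the handler in the diff rejects VHOST_NEW_WORKER unless the device
 * is in the default task-based worker mode (dev->fork_owner).
 */
#include <stdio.h>
#include <string.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int attach_private_worker(int vhost_fd, unsigned int vring_idx)
{
	struct vhost_worker_state state;
	struct vhost_vring_worker vring_worker;

	/* Ask the kernel for a new worker; it returns the worker id. */
	memset(&state, 0, sizeof(state));
	if (ioctl(vhost_fd, VHOST_NEW_WORKER, &state) < 0) {
		perror("VHOST_NEW_WORKER");
		return -1;
	}

	/* Attach the chosen virtqueue to that worker. */
	memset(&vring_worker, 0, sizeof(vring_worker));
	vring_worker.index = vring_idx;
	vring_worker.worker_id = state.worker_id;
	if (ioctl(vhost_fd, VHOST_ATTACH_VRING_WORKER, &vring_worker) < 0) {
		perror("VHOST_ATTACH_VRING_WORKER");
		return -1;
	}

	return 0;
}

A worker created this way can later be released with VHOST_FREE_WORKER (passing the same vhost_worker_state); as the handler in the diff shows, that call is refused with -EBUSY while any virtqueue is still attached.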
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 9cb3f722dce1..bccdc9eab267 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* Copyright (C) 2009 Red Hat, Inc. * Copyright (C) 2006 Rusty Russell IBM Corporation * @@ -6,8 +7,6 @@ * Inspiration, some code, and most witty comments come from * Documentation/virtual/lguest/lguest.c, by Rusty Russell * - * This work is licensed under the terms of the GNU GPL, version 2. - * * Generic code for virtio server in host kernel. */ @@ -15,7 +14,6 @@ #include <linux/vhost.h> #include <linux/uio.h> #include <linux/mm.h> -#include <linux/mmu_context.h> #include <linux/miscdevice.h> #include <linux/mutex.h> #include <linux/poll.h> @@ -29,7 +27,10 @@ #include <linux/sort.h> #include <linux/sched/mm.h> #include <linux/sched/signal.h> +#include <linux/sched/vhost_task.h> #include <linux/interval_tree_generic.h> +#include <linux/nospec.h> +#include <linux/kcov.h> #include "vhost.h" @@ -41,6 +42,13 @@ static int max_iotlb_entries = 2048; module_param(max_iotlb_entries, int, 0444); MODULE_PARM_DESC(max_iotlb_entries, "Maximum number of iotlb entries. (default: 2048)"); +static bool fork_from_owner_default = VHOST_FORK_OWNER_TASK; + +#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL +module_param(fork_from_owner_default, bool, 0444); +MODULE_PARM_DESC(fork_from_owner_default, + "Set task mode as the default(default: Y)"); +#endif enum { VHOST_MEMORY_F_LOG = 0x1, @@ -49,10 +57,6 @@ enum { #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) -INTERVAL_TREE_DEFINE(struct vhost_umem_node, - rb, __u64, __subtree_last, - START, LAST, static inline, vhost_umem_interval_tree); - #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) { @@ -169,11 +173,16 @@ static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync, void *key) { struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait); + struct vhost_work *work = &poll->work; - if (!((unsigned long)key & poll->mask)) + if (!(key_to_poll(key) & poll->mask)) return 0; - vhost_poll_queue(poll); + if (!poll->dev->use_worker) + work->fn(work); + else + vhost_poll_queue(poll); + return 0; } @@ -181,19 +190,20 @@ void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) { clear_bit(VHOST_WORK_QUEUED, &work->flags); work->fn = fn; - init_waitqueue_head(&work->done); } EXPORT_SYMBOL_GPL(vhost_work_init); /* Init poll structure */ void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, - unsigned long mask, struct vhost_dev *dev) + __poll_t mask, struct vhost_dev *dev, + struct vhost_virtqueue *vq) { init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup); init_poll_funcptr(&poll->table, vhost_poll_func); poll->mask = mask; poll->dev = dev; poll->wqh = NULL; + poll->vq = vq; vhost_work_init(&poll->work, fn); } @@ -203,22 +213,20 @@ EXPORT_SYMBOL_GPL(vhost_poll_init); * keep a reference to a file until after vhost_poll_stop is called. 
*/ int vhost_poll_start(struct vhost_poll *poll, struct file *file) { - unsigned long mask; - int ret = 0; + __poll_t mask; if (poll->wqh) return 0; - mask = file->f_op->poll(file, &poll->table); + mask = vfs_poll(file, &poll->table); if (mask) - vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask); - if (mask & POLLERR) { - if (poll->wqh) - remove_wait_queue(poll->wqh, &poll->wait); - ret = -EINVAL; + vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask)); + if (mask & EPOLLERR) { + vhost_poll_stop(poll); + return -EINVAL; } - return ret; + return 0; } EXPORT_SYMBOL_GPL(vhost_poll_start); @@ -233,54 +241,98 @@ void vhost_poll_stop(struct vhost_poll *poll) } EXPORT_SYMBOL_GPL(vhost_poll_stop); -void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) +static void vhost_worker_queue(struct vhost_worker *worker, + struct vhost_work *work) { - struct vhost_flush_struct flush; + if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { + /* We can only add the work to the list after we're + * sure it was not in the list. + * test_and_set_bit() implies a memory barrier. + */ + llist_add(&work->node, &worker->work_list); + worker->ops->wakeup(worker); + } +} - if (dev->worker) { - init_completion(&flush.wait_event); - vhost_work_init(&flush.work, vhost_flush_work); +bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work) +{ + struct vhost_worker *worker; + bool queued = false; - vhost_work_queue(dev, &flush.work); - wait_for_completion(&flush.wait_event); + rcu_read_lock(); + worker = rcu_dereference(vq->worker); + if (worker) { + queued = true; + vhost_worker_queue(worker, work); } + rcu_read_unlock(); + + return queued; +} +EXPORT_SYMBOL_GPL(vhost_vq_work_queue); + +/** + * __vhost_worker_flush - flush a worker + * @worker: worker to flush + * + * The worker's flush_mutex must be held. + */ +static void __vhost_worker_flush(struct vhost_worker *worker) +{ + struct vhost_flush_struct flush; + + if (!worker->attachment_cnt || worker->killed) + return; + + init_completion(&flush.wait_event); + vhost_work_init(&flush.work, vhost_flush_work); + + vhost_worker_queue(worker, &flush.work); + /* + * Drop mutex in case our worker is killed and it needs to take the + * mutex to force cleanup. + */ + mutex_unlock(&worker->mutex); + wait_for_completion(&flush.wait_event); + mutex_lock(&worker->mutex); } -EXPORT_SYMBOL_GPL(vhost_work_flush); -/* Flush any work that has been scheduled. When calling this, don't hold any - * locks that are also used by the callback. */ -void vhost_poll_flush(struct vhost_poll *poll) +static void vhost_worker_flush(struct vhost_worker *worker) { - vhost_work_flush(poll->dev, &poll->work); + mutex_lock(&worker->mutex); + __vhost_worker_flush(worker); + mutex_unlock(&worker->mutex); } -EXPORT_SYMBOL_GPL(vhost_poll_flush); -void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) +void vhost_dev_flush(struct vhost_dev *dev) { - if (!dev->worker) - return; + struct vhost_worker *worker; + unsigned long i; - if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { - /* We can only add the work to the list after we're - * sure it was not in the list. - * test_and_set_bit() implies a memory barrier. 
- */ - llist_add(&work->node, &dev->work_list); - wake_up_process(dev->worker); - } + xa_for_each(&dev->worker_xa, i, worker) + vhost_worker_flush(worker); } -EXPORT_SYMBOL_GPL(vhost_work_queue); +EXPORT_SYMBOL_GPL(vhost_dev_flush); /* A lockless hint for busy polling code to exit the loop */ -bool vhost_has_work(struct vhost_dev *dev) +bool vhost_vq_has_work(struct vhost_virtqueue *vq) { - return !llist_empty(&dev->work_list); + struct vhost_worker *worker; + bool has_work = false; + + rcu_read_lock(); + worker = rcu_dereference(vq->worker); + if (worker && !llist_empty(&worker->work_list)) + has_work = true; + rcu_read_unlock(); + + return has_work; } -EXPORT_SYMBOL_GPL(vhost_has_work); +EXPORT_SYMBOL_GPL(vhost_vq_has_work); void vhost_poll_queue(struct vhost_poll *poll) { - vhost_work_queue(poll->dev, &poll->work); + vhost_vq_work_queue(poll->vq, &poll->work); } EXPORT_SYMBOL_GPL(vhost_poll_queue); @@ -300,6 +352,18 @@ static void vhost_vq_meta_reset(struct vhost_dev *d) __vhost_vq_meta_reset(d->vqs[i]); } +static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx) +{ + call_ctx->ctx = NULL; + memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer)); +} + +bool vhost_vq_is_setup(struct vhost_virtqueue *vq) +{ + return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq); +} +EXPORT_SYMBOL_GPL(vhost_vq_is_setup); + static void vhost_vq_reset(struct vhost_dev *dev, struct vhost_virtqueue *vq) { @@ -308,6 +372,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->avail = NULL; vq->used = NULL; vq->last_avail_idx = 0; + vq->next_avail_head = 0; vq->avail_idx = 0; vq->last_used_idx = 0; vq->signalled_used = 0; @@ -316,31 +381,30 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->log_used = false; vq->log_addr = -1ull; vq->private_data = NULL; - vq->acked_features = 0; + virtio_features_zero(vq->acked_features_array); + vq->acked_backend_features = 0; vq->log_base = NULL; vq->error_ctx = NULL; - vq->error = NULL; vq->kick = NULL; - vq->call_ctx = NULL; - vq->call = NULL; vq->log_ctx = NULL; - vhost_reset_is_le(vq); vhost_disable_cross_endian(vq); + vhost_reset_is_le(vq); vq->busyloop_timeout = 0; vq->umem = NULL; vq->iotlb = NULL; + rcu_assign_pointer(vq->worker, NULL); + vhost_vring_call_reset(&vq->call_ctx); __vhost_vq_meta_reset(vq); } -static int vhost_worker(void *data) +static int vhost_run_work_kthread_list(void *data) { - struct vhost_dev *dev = data; + struct vhost_worker *worker = data; struct vhost_work *work, *work_next; + struct vhost_dev *dev = worker->dev; struct llist_node *node; - mm_segment_t oldfs = get_fs(); - set_fs(USER_DS); - use_mm(dev->mm); + kthread_use_mm(dev->mm); for (;;) { /* mb paired w/ kthread_stop */ @@ -350,8 +414,7 @@ static int vhost_worker(void *data) __set_current_state(TASK_RUNNING); break; } - - node = llist_del_all(&dev->work_list); + node = llist_del_all(&worker->work_list); if (!node) schedule(); @@ -361,16 +424,76 @@ static int vhost_worker(void *data) llist_for_each_entry_safe(work, work_next, node, node) { clear_bit(VHOST_WORK_QUEUED, &work->flags); __set_current_state(TASK_RUNNING); + kcov_remote_start_common(worker->kcov_handle); work->fn(work); - if (need_resched()) - schedule(); + kcov_remote_stop(); + cond_resched(); } } - unuse_mm(dev->mm); - set_fs(oldfs); + kthread_unuse_mm(dev->mm); + return 0; } +static bool vhost_run_work_list(void *data) +{ + struct vhost_worker *worker = data; + struct vhost_work *work, *work_next; + struct llist_node *node; + + node = llist_del_all(&worker->work_list); + if 
(node) { + __set_current_state(TASK_RUNNING); + + node = llist_reverse_order(node); + /* make sure flag is seen after deletion */ + smp_wmb(); + llist_for_each_entry_safe(work, work_next, node, node) { + clear_bit(VHOST_WORK_QUEUED, &work->flags); + kcov_remote_start_common(worker->kcov_handle); + work->fn(work); + kcov_remote_stop(); + cond_resched(); + } + } + + return !!node; +} + +static void vhost_worker_killed(void *data) +{ + struct vhost_worker *worker = data; + struct vhost_dev *dev = worker->dev; + struct vhost_virtqueue *vq; + int i, attach_cnt = 0; + + mutex_lock(&worker->mutex); + worker->killed = true; + + for (i = 0; i < dev->nvqs; i++) { + vq = dev->vqs[i]; + + mutex_lock(&vq->mutex); + if (worker == + rcu_dereference_check(vq->worker, + lockdep_is_held(&vq->mutex))) { + rcu_assign_pointer(vq->worker, NULL); + attach_cnt++; + } + mutex_unlock(&vq->mutex); + } + + worker->attachment_cnt -= attach_cnt; + if (attach_cnt) + synchronize_rcu(); + /* + * Finish vhost_worker_flush calls and any other works that snuck in + * before the synchronize_rcu. + */ + vhost_run_work_list(worker); + mutex_unlock(&worker->mutex); +} + static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) { kfree(vq->indirect); @@ -379,6 +502,8 @@ static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) vq->log = NULL; kfree(vq->heads); vq->heads = NULL; + kfree(vq->nheads); + vq->nheads = NULL; } /* Helper to allocate iovec buffers for all vqs. */ @@ -389,11 +514,16 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev) for (i = 0; i < dev->nvqs; ++i) { vq = dev->vqs[i]; - vq->indirect = kmalloc(sizeof *vq->indirect * UIO_MAXIOV, - GFP_KERNEL); - vq->log = kmalloc(sizeof *vq->log * UIO_MAXIOV, GFP_KERNEL); - vq->heads = kmalloc(sizeof *vq->heads * UIO_MAXIOV, GFP_KERNEL); - if (!vq->indirect || !vq->log || !vq->heads) + vq->indirect = kmalloc_array(UIO_MAXIOV, + sizeof(*vq->indirect), + GFP_KERNEL); + vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log), + GFP_KERNEL); + vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads), + GFP_KERNEL); + vq->nheads = kmalloc_array(dev->iov_limit, sizeof(*vq->nheads), + GFP_KERNEL); + if (!vq->indirect || !vq->log || !vq->heads || !vq->nheads) goto err_nomem; } return 0; @@ -412,8 +542,51 @@ static void vhost_dev_free_iovecs(struct vhost_dev *dev) vhost_vq_free_iovecs(dev->vqs[i]); } +bool vhost_exceeds_weight(struct vhost_virtqueue *vq, + int pkts, int total_len) +{ + struct vhost_dev *dev = vq->dev; + + if ((dev->byte_weight && total_len >= dev->byte_weight) || + pkts >= dev->weight) { + vhost_poll_queue(&vq->poll); + return true; + } + + return false; +} +EXPORT_SYMBOL_GPL(vhost_exceeds_weight); + +static size_t vhost_get_avail_size(struct vhost_virtqueue *vq, + unsigned int num) +{ + size_t event __maybe_unused = + vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; + + return size_add(struct_size(vq->avail, ring, num), event); +} + +static size_t vhost_get_used_size(struct vhost_virtqueue *vq, + unsigned int num) +{ + size_t event __maybe_unused = + vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 
2 : 0; + + return size_add(struct_size(vq->used, ring, num), event); +} + +static size_t vhost_get_desc_size(struct vhost_virtqueue *vq, + unsigned int num) +{ + return sizeof(*vq->desc) * num; +} + void vhost_dev_init(struct vhost_dev *dev, - struct vhost_virtqueue **vqs, int nvqs) + struct vhost_virtqueue **vqs, int nvqs, + int iov_limit, int weight, int byte_weight, + bool use_worker, + int (*msg_handler)(struct vhost_dev *dev, u32 asid, + struct vhost_iotlb_msg *msg)) { struct vhost_virtqueue *vq; int i; @@ -422,29 +595,33 @@ void vhost_dev_init(struct vhost_dev *dev, dev->nvqs = nvqs; mutex_init(&dev->mutex); dev->log_ctx = NULL; - dev->log_file = NULL; dev->umem = NULL; dev->iotlb = NULL; dev->mm = NULL; - dev->worker = NULL; - init_llist_head(&dev->work_list); + dev->iov_limit = iov_limit; + dev->weight = weight; + dev->byte_weight = byte_weight; + dev->use_worker = use_worker; + dev->msg_handler = msg_handler; + dev->fork_owner = fork_from_owner_default; init_waitqueue_head(&dev->wait); INIT_LIST_HEAD(&dev->read_list); INIT_LIST_HEAD(&dev->pending_list); spin_lock_init(&dev->iotlb_lock); - + xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC); for (i = 0; i < dev->nvqs; ++i) { vq = dev->vqs[i]; vq->log = NULL; vq->indirect = NULL; vq->heads = NULL; + vq->nheads = NULL; vq->dev = dev; mutex_init(&vq->mutex); vhost_vq_reset(dev, vq); if (vq->handle_kick) vhost_poll_init(&vq->poll, vq->handle_kick, - POLLIN, dev); + EPOLLIN, dev, vq); } } EXPORT_SYMBOL_GPL(vhost_dev_init); @@ -471,14 +648,29 @@ static void vhost_attach_cgroups_work(struct vhost_work *work) s->ret = cgroup_attach_task_all(s->owner, current); } -static int vhost_attach_cgroups(struct vhost_dev *dev) +static int vhost_attach_task_to_cgroups(struct vhost_worker *worker) { struct vhost_attach_cgroups_struct attach; + int saved_cnt; attach.owner = current; + vhost_work_init(&attach.work, vhost_attach_cgroups_work); - vhost_work_queue(dev, &attach.work); - vhost_work_flush(dev, &attach.work); + vhost_worker_queue(worker, &attach.work); + + mutex_lock(&worker->mutex); + + /* + * Bypass attachment_cnt check in __vhost_worker_flush: + * Temporarily change it to INT_MAX to bypass the check + */ + saved_cnt = worker->attachment_cnt; + worker->attachment_cnt = INT_MAX; + __vhost_worker_flush(worker); + worker->attachment_cnt = saved_cnt; + + mutex_unlock(&worker->mutex); + return attach.ret; } @@ -489,11 +681,423 @@ bool vhost_dev_has_owner(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_dev_has_owner); +static void vhost_attach_mm(struct vhost_dev *dev) +{ + /* No owner, become one */ + if (dev->use_worker) { + dev->mm = get_task_mm(current); + } else { + /* vDPA device does not use worker thread, so there's + * no need to hold the address space for mm. This helps + * to avoid deadlock in the case of mmap() which may + * hold the refcnt of the file and depends on release + * method to remove vma. 
+ */ + dev->mm = current->mm; + mmgrab(dev->mm); + } +} + +static void vhost_detach_mm(struct vhost_dev *dev) +{ + if (!dev->mm) + return; + + if (dev->use_worker) + mmput(dev->mm); + else + mmdrop(dev->mm); + + dev->mm = NULL; +} + +static void vhost_worker_destroy(struct vhost_dev *dev, + struct vhost_worker *worker) +{ + if (!worker) + return; + + WARN_ON(!llist_empty(&worker->work_list)); + xa_erase(&dev->worker_xa, worker->id); + worker->ops->stop(worker); + kfree(worker); +} + +static void vhost_workers_free(struct vhost_dev *dev) +{ + struct vhost_worker *worker; + unsigned long i; + + if (!dev->use_worker) + return; + + for (i = 0; i < dev->nvqs; i++) + rcu_assign_pointer(dev->vqs[i]->worker, NULL); + /* + * Free the default worker we created and cleanup workers userspace + * created but couldn't clean up (it forgot or crashed). + */ + xa_for_each(&dev->worker_xa, i, worker) + vhost_worker_destroy(dev, worker); + xa_destroy(&dev->worker_xa); +} + +static void vhost_task_wakeup(struct vhost_worker *worker) +{ + return vhost_task_wake(worker->vtsk); +} + +static void vhost_kthread_wakeup(struct vhost_worker *worker) +{ + wake_up_process(worker->kthread_task); +} + +static void vhost_task_do_stop(struct vhost_worker *worker) +{ + return vhost_task_stop(worker->vtsk); +} + +static void vhost_kthread_do_stop(struct vhost_worker *worker) +{ + kthread_stop(worker->kthread_task); +} + +static int vhost_task_worker_create(struct vhost_worker *worker, + struct vhost_dev *dev, const char *name) +{ + struct vhost_task *vtsk; + u32 id; + int ret; + + vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed, + worker, name); + if (IS_ERR(vtsk)) + return PTR_ERR(vtsk); + + worker->vtsk = vtsk; + vhost_task_start(vtsk); + ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL); + if (ret < 0) { + vhost_task_do_stop(worker); + return ret; + } + worker->id = id; + return 0; +} + +static int vhost_kthread_worker_create(struct vhost_worker *worker, + struct vhost_dev *dev, const char *name) +{ + struct task_struct *task; + u32 id; + int ret; + + task = kthread_create(vhost_run_work_kthread_list, worker, "%s", name); + if (IS_ERR(task)) + return PTR_ERR(task); + + worker->kthread_task = task; + wake_up_process(task); + ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL); + if (ret < 0) + goto stop_worker; + + ret = vhost_attach_task_to_cgroups(worker); + if (ret) + goto free_id; + + worker->id = id; + return 0; + +free_id: + xa_erase(&dev->worker_xa, id); +stop_worker: + vhost_kthread_do_stop(worker); + return ret; +} + +static const struct vhost_worker_ops kthread_ops = { + .create = vhost_kthread_worker_create, + .stop = vhost_kthread_do_stop, + .wakeup = vhost_kthread_wakeup, +}; + +static const struct vhost_worker_ops vhost_task_ops = { + .create = vhost_task_worker_create, + .stop = vhost_task_do_stop, + .wakeup = vhost_task_wakeup, +}; + +static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev) +{ + struct vhost_worker *worker; + char name[TASK_COMM_LEN]; + int ret; + const struct vhost_worker_ops *ops = dev->fork_owner ? 
&vhost_task_ops : + &kthread_ops; + + worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT); + if (!worker) + return NULL; + + worker->dev = dev; + worker->ops = ops; + snprintf(name, sizeof(name), "vhost-%d", current->pid); + + mutex_init(&worker->mutex); + init_llist_head(&worker->work_list); + worker->kcov_handle = kcov_common_handle(); + ret = ops->create(worker, dev, name); + if (ret < 0) + goto free_worker; + + return worker; + +free_worker: + kfree(worker); + return NULL; +} + +/* Caller must have device mutex */ +static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq, + struct vhost_worker *worker) +{ + struct vhost_worker *old_worker; + + mutex_lock(&worker->mutex); + if (worker->killed) { + mutex_unlock(&worker->mutex); + return; + } + + mutex_lock(&vq->mutex); + + old_worker = rcu_dereference_check(vq->worker, + lockdep_is_held(&vq->mutex)); + rcu_assign_pointer(vq->worker, worker); + worker->attachment_cnt++; + + if (!old_worker) { + mutex_unlock(&vq->mutex); + mutex_unlock(&worker->mutex); + return; + } + mutex_unlock(&vq->mutex); + mutex_unlock(&worker->mutex); + + /* + * Take the worker mutex to make sure we see the work queued from + * device wide flushes which doesn't use RCU for execution. + */ + mutex_lock(&old_worker->mutex); + if (old_worker->killed) { + mutex_unlock(&old_worker->mutex); + return; + } + + /* + * We don't want to call synchronize_rcu for every vq during setup + * because it will slow down VM startup. If we haven't done + * VHOST_SET_VRING_KICK and not done the driver specific + * SET_ENDPOINT/RUNNING then we can skip the sync since there will + * not be any works queued for scsi and net. + */ + mutex_lock(&vq->mutex); + if (!vhost_vq_get_backend(vq) && !vq->kick) { + mutex_unlock(&vq->mutex); + + old_worker->attachment_cnt--; + mutex_unlock(&old_worker->mutex); + /* + * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID. + * Warn if it adds support for multiple workers but forgets to + * handle the early queueing case. 
+ */ + WARN_ON(!old_worker->attachment_cnt && + !llist_empty(&old_worker->work_list)); + return; + } + mutex_unlock(&vq->mutex); + + /* Make sure new vq queue/flush/poll calls see the new worker */ + synchronize_rcu(); + /* Make sure whatever was queued gets run */ + __vhost_worker_flush(old_worker); + old_worker->attachment_cnt--; + mutex_unlock(&old_worker->mutex); +} + + /* Caller must have device mutex */ +static int vhost_vq_attach_worker(struct vhost_virtqueue *vq, + struct vhost_vring_worker *info) +{ + unsigned long index = info->worker_id; + struct vhost_dev *dev = vq->dev; + struct vhost_worker *worker; + + if (!dev->use_worker) + return -EINVAL; + + worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT); + if (!worker || worker->id != info->worker_id) + return -ENODEV; + + __vhost_vq_attach_worker(vq, worker); + return 0; +} + +/* Caller must have device mutex */ +static int vhost_new_worker(struct vhost_dev *dev, + struct vhost_worker_state *info) +{ + struct vhost_worker *worker; + + worker = vhost_worker_create(dev); + if (!worker) + return -ENOMEM; + + info->worker_id = worker->id; + return 0; +} + +/* Caller must have device mutex */ +static int vhost_free_worker(struct vhost_dev *dev, + struct vhost_worker_state *info) +{ + unsigned long index = info->worker_id; + struct vhost_worker *worker; + + worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT); + if (!worker || worker->id != info->worker_id) + return -ENODEV; + + mutex_lock(&worker->mutex); + if (worker->attachment_cnt || worker->killed) { + mutex_unlock(&worker->mutex); + return -EBUSY; + } + /* + * A flush might have raced and snuck in before attachment_cnt was set + * to zero. Make sure flushes are flushed from the queue before + * freeing. + */ + __vhost_worker_flush(worker); + mutex_unlock(&worker->mutex); + + vhost_worker_destroy(dev, worker); + return 0; +} + +static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp, + struct vhost_virtqueue **vq, u32 *id) +{ + u32 __user *idxp = argp; + u32 idx; + long r; + + r = get_user(idx, idxp); + if (r < 0) + return r; + + if (idx >= dev->nvqs) + return -ENOBUFS; + + idx = array_index_nospec(idx, dev->nvqs); + + *vq = dev->vqs[idx]; + *id = idx; + return 0; +} + +/* Caller must have device mutex */ +long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl, + void __user *argp) +{ + struct vhost_vring_worker ring_worker; + struct vhost_worker_state state; + struct vhost_worker *worker; + struct vhost_virtqueue *vq; + long ret; + u32 idx; + + if (!dev->use_worker) + return -EINVAL; + + if (!vhost_dev_has_owner(dev)) + return -EINVAL; + + ret = vhost_dev_check_owner(dev); + if (ret) + return ret; + + switch (ioctl) { + /* dev worker ioctls */ + case VHOST_NEW_WORKER: + /* + * vhost_tasks will account for worker threads under the parent's + * NPROC value but kthreads do not. To avoid userspace overflowing + * the system with worker threads fork_owner must be true. 
+ */ + if (!dev->fork_owner) + return -EFAULT; + + ret = vhost_new_worker(dev, &state); + if (!ret && copy_to_user(argp, &state, sizeof(state))) + ret = -EFAULT; + return ret; + case VHOST_FREE_WORKER: + if (copy_from_user(&state, argp, sizeof(state))) + return -EFAULT; + return vhost_free_worker(dev, &state); + /* vring worker ioctls */ + case VHOST_ATTACH_VRING_WORKER: + case VHOST_GET_VRING_WORKER: + break; + default: + return -ENOIOCTLCMD; + } + + ret = vhost_get_vq_from_user(dev, argp, &vq, &idx); + if (ret) + return ret; + + switch (ioctl) { + case VHOST_ATTACH_VRING_WORKER: + if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) { + ret = -EFAULT; + break; + } + + ret = vhost_vq_attach_worker(vq, &ring_worker); + break; + case VHOST_GET_VRING_WORKER: + worker = rcu_dereference_check(vq->worker, + lockdep_is_held(&dev->mutex)); + if (!worker) { + ret = -EINVAL; + break; + } + + ring_worker.index = idx; + ring_worker.worker_id = worker->id; + + if (copy_to_user(argp, &ring_worker, sizeof(ring_worker))) + ret = -EFAULT; + break; + default: + ret = -ENOIOCTLCMD; + break; + } + + return ret; +} +EXPORT_SYMBOL_GPL(vhost_worker_ioctl); + /* Caller should have device mutex */ long vhost_dev_set_owner(struct vhost_dev *dev) { - struct task_struct *worker; - int err; + struct vhost_worker *worker; + int err, i; /* Is there an owner already? */ if (vhost_dev_has_owner(dev)) { @@ -501,53 +1105,60 @@ long vhost_dev_set_owner(struct vhost_dev *dev) goto err_mm; } - /* No owner, become one */ - dev->mm = get_task_mm(current); - worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid); - if (IS_ERR(worker)) { - err = PTR_ERR(worker); - goto err_worker; - } - - dev->worker = worker; - wake_up_process(worker); /* avoid contributing to loadavg */ - - err = vhost_attach_cgroups(dev); - if (err) - goto err_cgroup; + vhost_attach_mm(dev); err = vhost_dev_alloc_iovecs(dev); if (err) - goto err_cgroup; + goto err_iovecs; + + if (dev->use_worker) { + /* + * This should be done last, because vsock can queue work + * before VHOST_SET_OWNER so it simplifies the failure path + * below since we don't have to worry about vsock queueing + * while we free the worker. + */ + worker = vhost_worker_create(dev); + if (!worker) { + err = -ENOMEM; + goto err_worker; + } + + for (i = 0; i < dev->nvqs; i++) + __vhost_vq_attach_worker(dev->vqs[i], worker); + } return 0; -err_cgroup: - kthread_stop(worker); - dev->worker = NULL; + err_worker: - if (dev->mm) - mmput(dev->mm); - dev->mm = NULL; + vhost_dev_free_iovecs(dev); +err_iovecs: + vhost_detach_mm(dev); err_mm: return err; } EXPORT_SYMBOL_GPL(vhost_dev_set_owner); -struct vhost_umem *vhost_dev_reset_owner_prepare(void) +static struct vhost_iotlb *iotlb_alloc(void) { - return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL); + return vhost_iotlb_alloc(max_iotlb_entries, + VHOST_IOTLB_FLAG_RETIRE); +} + +struct vhost_iotlb *vhost_dev_reset_owner_prepare(void) +{ + return iotlb_alloc(); } EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); /* Caller should have device mutex */ -void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem) +void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem) { int i; - vhost_dev_cleanup(dev, true); + vhost_dev_cleanup(dev); - /* Restore memory to default empty mapping. */ - INIT_LIST_HEAD(&umem->umem_list); + dev->fork_owner = fork_from_owner_default; dev->umem = umem; /* We don't need VQ locks below since vhost_dev_cleanup makes sure * VQs aren't running. 
@@ -562,37 +1173,15 @@ void vhost_dev_stop(struct vhost_dev *dev) int i; for (i = 0; i < dev->nvqs; ++i) { - if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) { + if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) vhost_poll_stop(&dev->vqs[i]->poll); - vhost_poll_flush(&dev->vqs[i]->poll); - } } -} -EXPORT_SYMBOL_GPL(vhost_dev_stop); -static void vhost_umem_free(struct vhost_umem *umem, - struct vhost_umem_node *node) -{ - vhost_umem_interval_tree_remove(node, &umem->umem_tree); - list_del(&node->link); - kfree(node); - umem->numem--; -} - -static void vhost_umem_clean(struct vhost_umem *umem) -{ - struct vhost_umem_node *node, *tmp; - - if (!umem) - return; - - list_for_each_entry_safe(node, tmp, &umem->umem_list, link) - vhost_umem_free(umem, node); - - kvfree(umem); + vhost_dev_flush(dev); } +EXPORT_SYMBOL_GPL(vhost_dev_stop); -static void vhost_clear_msg(struct vhost_dev *dev) +void vhost_clear_msg(struct vhost_dev *dev) { struct vhost_msg_node *node, *n; @@ -610,117 +1199,109 @@ static void vhost_clear_msg(struct vhost_dev *dev) spin_unlock(&dev->iotlb_lock); } +EXPORT_SYMBOL_GPL(vhost_clear_msg); -/* Caller should have device mutex if and only if locked is set */ -void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) +void vhost_dev_cleanup(struct vhost_dev *dev) { int i; for (i = 0; i < dev->nvqs; ++i) { if (dev->vqs[i]->error_ctx) eventfd_ctx_put(dev->vqs[i]->error_ctx); - if (dev->vqs[i]->error) - fput(dev->vqs[i]->error); if (dev->vqs[i]->kick) fput(dev->vqs[i]->kick); - if (dev->vqs[i]->call_ctx) - eventfd_ctx_put(dev->vqs[i]->call_ctx); - if (dev->vqs[i]->call) - fput(dev->vqs[i]->call); + if (dev->vqs[i]->call_ctx.ctx) + eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx); vhost_vq_reset(dev, dev->vqs[i]); } vhost_dev_free_iovecs(dev); if (dev->log_ctx) eventfd_ctx_put(dev->log_ctx); dev->log_ctx = NULL; - if (dev->log_file) - fput(dev->log_file); - dev->log_file = NULL; /* No one will access memory at this point */ - vhost_umem_clean(dev->umem); + vhost_iotlb_free(dev->umem); dev->umem = NULL; - vhost_umem_clean(dev->iotlb); + vhost_iotlb_free(dev->iotlb); dev->iotlb = NULL; vhost_clear_msg(dev); - wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM); - WARN_ON(!llist_empty(&dev->work_list)); - if (dev->worker) { - kthread_stop(dev->worker); - dev->worker = NULL; - } - if (dev->mm) - mmput(dev->mm); - dev->mm = NULL; + wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); + vhost_workers_free(dev); + vhost_detach_mm(dev); } EXPORT_SYMBOL_GPL(vhost_dev_cleanup); -static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) +static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz) { u64 a = addr / VHOST_PAGE_SIZE / 8; /* Make sure 64 bit math will not overflow. */ if (a > ULONG_MAX - (unsigned long)log_base || a + (unsigned long)log_base > ULONG_MAX) - return 0; + return false; - return access_ok(VERIFY_WRITE, log_base + a, + return access_ok(log_base + a, (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); } +/* Make sure 64 bit math will not overflow. */ static bool vhost_overflow(u64 uaddr, u64 size) { - /* Make sure 64 bit math will not overflow. */ - return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size; + if (uaddr > ULONG_MAX || size > ULONG_MAX) + return true; + + if (!size) + return false; + + return uaddr > ULONG_MAX - size + 1; } /* Caller should have vq mutex and device mutex. 
*/ -static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, - int log_all) +static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem, + int log_all) { - struct vhost_umem_node *node; + struct vhost_iotlb_map *map; if (!umem) - return 0; + return false; - list_for_each_entry(node, &umem->umem_list, link) { - unsigned long a = node->userspace_addr; + list_for_each_entry(map, &umem->list, link) { + unsigned long a = map->addr; - if (vhost_overflow(node->userspace_addr, node->size)) - return 0; + if (vhost_overflow(map->addr, map->size)) + return false; - if (!access_ok(VERIFY_WRITE, (void __user *)a, - node->size)) - return 0; + if (!access_ok((void __user *)a, map->size)) + return false; else if (log_all && !log_access_ok(log_base, - node->start, - node->size)) - return 0; + map->start, + map->size)) + return false; } - return 1; + return true; } static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq, u64 addr, unsigned int size, int type) { - const struct vhost_umem_node *node = vq->meta_iotlb[type]; + const struct vhost_iotlb_map *map = vq->meta_iotlb[type]; - if (!node) + if (!map) return NULL; - return (void *)(uintptr_t)(node->userspace_addr + addr - node->start); + return (void __user *)(uintptr_t)(map->addr + addr - map->start); } /* Can we switch to this memory table? */ /* Caller should have device mutex but not vq mutex */ -static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, - int log_all) +static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem, + int log_all) { int i; for (i = 0; i < d->nvqs; ++i) { - int ok; + bool ok; bool log; mutex_lock(&d->vqs[i]->mutex); @@ -730,12 +1311,12 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, ok = vq_memory_access_ok(d->vqs[i]->log_base, umem, log); else - ok = 1; + ok = true; mutex_unlock(&d->vqs[i]->mutex); if (!ok) - return 0; + return false; } - return 1; + return true; } static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, @@ -757,7 +1338,7 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to, struct iov_iter t; void __user *uaddr = vhost_vq_meta_fetch(vq, (u64)(uintptr_t)to, size, - VHOST_ADDR_DESC); + VHOST_ADDR_USED); if (uaddr) return __copy_to_user(uaddr, from, size); @@ -767,7 +1348,7 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to, VHOST_ACCESS_WO); if (ret < 0) goto out; - iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size); + iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size); ret = copy_to_iter(from, size, &t); if (ret == size) ret = 0; @@ -806,7 +1387,7 @@ static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to, (unsigned long long) size); goto out; } - iov_iter_init(&f, READ, vq->iotlb_iov, ret, size); + iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size); ret = copy_from_iter(to, size, &f); if (ret == size) ret = 0; @@ -848,7 +1429,7 @@ static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq, * not happen in this case. 
*/ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, - void *addr, unsigned int size, + void __user *addr, unsigned int size, int type) { void __user *uaddr = vhost_vq_meta_fetch(vq, @@ -861,7 +1442,7 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, #define vhost_put_user(vq, x, ptr) \ ({ \ - int ret = -EFAULT; \ + int ret; \ if (!vq->iotlb) { \ ret = __put_user(x, ptr); \ } else { \ @@ -876,6 +1457,34 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, ret; \ }) +static inline int vhost_put_avail_event(struct vhost_virtqueue *vq) +{ + return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx), + vhost_avail_event(vq)); +} + +static inline int vhost_put_used(struct vhost_virtqueue *vq, + struct vring_used_elem *head, int idx, + int count) +{ + return vhost_copy_to_user(vq, vq->used->ring + idx, head, + count * sizeof(*head)); +} + +static inline int vhost_put_used_flags(struct vhost_virtqueue *vq) + +{ + return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags), + &vq->used->flags); +} + +static inline int vhost_put_used_idx(struct vhost_virtqueue *vq) + +{ + return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx), + &vq->used->idx); +} + #define vhost_get_user(vq, x, ptr, type) \ ({ \ int ret; \ @@ -904,7 +1513,7 @@ static void vhost_dev_lock_vqs(struct vhost_dev *d) { int i = 0; for (i = 0; i < d->nvqs; ++i) - mutex_lock(&d->vqs[i]->mutex); + mutex_lock_nested(&d->vqs[i]->mutex, i); } static void vhost_dev_unlock_vqs(struct vhost_dev *d) @@ -914,41 +1523,67 @@ static void vhost_dev_unlock_vqs(struct vhost_dev *d) mutex_unlock(&d->vqs[i]->mutex); } -static int vhost_new_umem_range(struct vhost_umem *umem, - u64 start, u64 size, u64 end, - u64 userspace_addr, int perm) +static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq) { - struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC); + __virtio16 idx; + int r; - if (!node) - return -ENOMEM; + r = vhost_get_avail(vq, idx, &vq->avail->idx); + if (unlikely(r < 0)) { + vq_err(vq, "Failed to access available index at %p (%d)\n", + &vq->avail->idx, r); + return r; + } - if (umem->numem == max_iotlb_entries) { - tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link); - vhost_umem_free(umem, tmp); + /* Check it isn't doing very strange thing with available indexes */ + vq->avail_idx = vhost16_to_cpu(vq, idx); + if (unlikely((u16)(vq->avail_idx - vq->last_avail_idx) > vq->num)) { + vq_err(vq, "Invalid available index change from %u to %u", + vq->last_avail_idx, vq->avail_idx); + return -EINVAL; } - node->start = start; - node->size = size; - node->last = end; - node->userspace_addr = userspace_addr; - node->perm = perm; - INIT_LIST_HEAD(&node->link); - list_add_tail(&node->link, &umem->umem_list); - vhost_umem_interval_tree_insert(node, &umem->umem_tree); - umem->numem++; + /* We're done if there is nothing new */ + if (vq->avail_idx == vq->last_avail_idx) + return 0; - return 0; + /* + * We updated vq->avail_idx so we need a memory barrier between + * the index read above and the caller reading avail ring entries. 
+ */ + smp_rmb(); + return 1; +} + +static inline int vhost_get_avail_head(struct vhost_virtqueue *vq, + __virtio16 *head, int idx) +{ + return vhost_get_avail(vq, *head, + &vq->avail->ring[idx & (vq->num - 1)]); +} + +static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq, + __virtio16 *flags) +{ + return vhost_get_avail(vq, *flags, &vq->avail->flags); } -static void vhost_del_umem_range(struct vhost_umem *umem, - u64 start, u64 end) +static inline int vhost_get_used_event(struct vhost_virtqueue *vq, + __virtio16 *event) { - struct vhost_umem_node *node; + return vhost_get_avail(vq, *event, vhost_used_event(vq)); +} - while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, - start, end))) - vhost_umem_free(umem, node); +static inline int vhost_get_used_idx(struct vhost_virtqueue *vq, + __virtio16 *idx) +{ + return vhost_get_used(vq, *idx, &vq->used->idx); +} + +static inline int vhost_get_desc(struct vhost_virtqueue *vq, + struct vring_desc *desc, int idx) +{ + return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc)); } static void vhost_iotlb_notify_vq(struct vhost_dev *d, @@ -961,7 +1596,7 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d, list_for_each_entry_safe(node, n, &d->pending_list, node) { struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb; if (msg->iova <= vq_msg->iova && - msg->iova + msg->size - 1 > vq_msg->iova && + msg->iova + msg->size - 1 >= vq_msg->iova && vq_msg->type == VHOST_IOTLB_MISS) { vhost_poll_queue(&node->vq->poll); list_del(&node->node); @@ -972,28 +1607,32 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d, spin_unlock(&d->iotlb_lock); } -static int umem_access_ok(u64 uaddr, u64 size, int access) +static bool umem_access_ok(u64 uaddr, u64 size, int access) { unsigned long a = uaddr; /* Make sure 64 bit math will not overflow. 
*/ if (vhost_overflow(uaddr, size)) - return -EFAULT; + return false; if ((access & VHOST_ACCESS_RO) && - !access_ok(VERIFY_READ, (void __user *)a, size)) - return -EFAULT; + !access_ok((void __user *)a, size)) + return false; if ((access & VHOST_ACCESS_WO) && - !access_ok(VERIFY_WRITE, (void __user *)a, size)) - return -EFAULT; - return 0; + !access_ok((void __user *)a, size)) + return false; + return true; } -static int vhost_process_iotlb_msg(struct vhost_dev *dev, +static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid, struct vhost_iotlb_msg *msg) { int ret = 0; + if (asid != 0) + return -EINVAL; + + mutex_lock(&dev->mutex); vhost_dev_lock_vqs(dev); switch (msg->type) { case VHOST_IOTLB_UPDATE: @@ -1001,23 +1640,27 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, ret = -EFAULT; break; } - if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) { + if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) { ret = -EFAULT; break; } vhost_vq_meta_reset(dev); - if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size, - msg->iova + msg->size - 1, - msg->uaddr, msg->perm)) { + if (vhost_iotlb_add_range(dev->iotlb, msg->iova, + msg->iova + msg->size - 1, + msg->uaddr, msg->perm)) { ret = -ENOMEM; break; } vhost_iotlb_notify_vq(dev, msg); break; case VHOST_IOTLB_INVALIDATE: + if (!dev->iotlb) { + ret = -EFAULT; + break; + } vhost_vq_meta_reset(dev); - vhost_del_umem_range(dev->iotlb, msg->iova, - msg->iova + msg->size - 1); + vhost_iotlb_del_range(dev->iotlb, msg->iova, + msg->iova + msg->size - 1); break; default: ret = -EINVAL; @@ -1025,47 +1668,85 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, } vhost_dev_unlock_vqs(dev); + mutex_unlock(&dev->mutex); + return ret; } ssize_t vhost_chr_write_iter(struct vhost_dev *dev, struct iov_iter *from) { - struct vhost_msg_node node; - unsigned size = sizeof(struct vhost_msg); - size_t ret; - int err; + struct vhost_iotlb_msg msg; + size_t offset; + int type, ret; + u32 asid = 0; - if (iov_iter_count(from) < size) - return 0; - ret = copy_from_iter(&node.msg, size, from); - if (ret != size) + ret = copy_from_iter(&type, sizeof(type), from); + if (ret != sizeof(type)) { + ret = -EINVAL; goto done; + } - switch (node.msg.type) { + switch (type) { case VHOST_IOTLB_MSG: - err = vhost_process_iotlb_msg(dev, &node.msg.iotlb); - if (err) - ret = err; + /* There maybe a hole after type for V1 message type, + * so skip it here. + */ + offset = offsetof(struct vhost_msg, iotlb) - sizeof(int); + break; + case VHOST_IOTLB_MSG_V2: + if (vhost_backend_has_feature(dev->vqs[0], + VHOST_BACKEND_F_IOTLB_ASID)) { + ret = copy_from_iter(&asid, sizeof(asid), from); + if (ret != sizeof(asid)) { + ret = -EINVAL; + goto done; + } + offset = 0; + } else + offset = sizeof(__u32); break; default: ret = -EINVAL; - break; + goto done; + } + + iov_iter_advance(from, offset); + ret = copy_from_iter(&msg, sizeof(msg), from); + if (ret != sizeof(msg)) { + ret = -EINVAL; + goto done; + } + + if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) { + ret = -EINVAL; + goto done; } + if (dev->msg_handler) + ret = dev->msg_handler(dev, asid, &msg); + else + ret = vhost_process_iotlb_msg(dev, asid, &msg); + if (ret) { + ret = -EFAULT; + goto done; + } + + ret = (type == VHOST_IOTLB_MSG) ? 
sizeof(struct vhost_msg) : + sizeof(struct vhost_msg_v2); done: return ret; } EXPORT_SYMBOL(vhost_chr_write_iter); -unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev, +__poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev, poll_table *wait) { - unsigned int mask = 0; + __poll_t mask = 0; poll_wait(file, &dev->wait, wait); if (!list_empty(&dev->read_list)) - mask |= POLLIN | POLLRDNORM; + mask |= EPOLLIN | EPOLLRDNORM; return mask; } @@ -1110,13 +1791,28 @@ ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to, finish_wait(&dev->wait, &wait); if (node) { - ret = copy_to_iter(&node->msg, size, to); + struct vhost_iotlb_msg *msg; + void *start = &node->msg; + + switch (node->msg.type) { + case VHOST_IOTLB_MSG: + size = sizeof(node->msg); + msg = &node->msg.iotlb; + break; + case VHOST_IOTLB_MSG_V2: + size = sizeof(node->msg_v2); + msg = &node->msg_v2.iotlb; + break; + default: + BUG(); + break; + } - if (ret != size || node->msg.type != VHOST_IOTLB_MISS) { + ret = copy_to_iter(start, size, to); + if (ret != size || msg->type != VHOST_IOTLB_MISS) { kfree(node); return ret; } - vhost_enqueue_msg(dev, &dev->pending_list, node); } @@ -1129,12 +1825,19 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access) struct vhost_dev *dev = vq->dev; struct vhost_msg_node *node; struct vhost_iotlb_msg *msg; + bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2); - node = vhost_new_msg(vq, VHOST_IOTLB_MISS); + node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG); if (!node) return -ENOMEM; - msg = &node->msg.iotlb; + if (v2) { + node->msg_v2.type = VHOST_IOTLB_MSG_V2; + msg = &node->msg_v2.iotlb; + } else { + msg = &node->msg.iotlb; + } + msg->type = VHOST_IOTLB_MISS; msg->iova = iova; msg->perm = access; @@ -1144,60 +1847,59 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access) return 0; } -static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, - struct vring_desc __user *desc, - struct vring_avail __user *avail, - struct vring_used __user *used) +static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, + vring_desc_t __user *desc, + vring_avail_t __user *avail, + vring_used_t __user *used) { - size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; + /* If an IOTLB device is present, the vring addresses are + * GIOVAs. Access validation occurs at prefetch time. */ + if (vq->iotlb) + return true; - return access_ok(VERIFY_READ, desc, num * sizeof *desc) && - access_ok(VERIFY_READ, avail, - sizeof *avail + num * sizeof *avail->ring + s) && - access_ok(VERIFY_WRITE, used, - sizeof *used + num * sizeof *used->ring + s); + return access_ok(desc, vhost_get_desc_size(vq, num)) && + access_ok(avail, vhost_get_avail_size(vq, num)) && + access_ok(used, vhost_get_used_size(vq, num)); } static void vhost_vq_meta_update(struct vhost_virtqueue *vq, - const struct vhost_umem_node *node, + const struct vhost_iotlb_map *map, int type) { int access = (type == VHOST_ADDR_USED) ? 
VHOST_ACCESS_WO : VHOST_ACCESS_RO; - if (likely(node->perm & access)) - vq->meta_iotlb[type] = node; + if (likely(map->perm & access)) + vq->meta_iotlb[type] = map; } -static int iotlb_access_ok(struct vhost_virtqueue *vq, - int access, u64 addr, u64 len, int type) +static bool iotlb_access_ok(struct vhost_virtqueue *vq, + int access, u64 addr, u64 len, int type) { - const struct vhost_umem_node *node; - struct vhost_umem *umem = vq->iotlb; - u64 s = 0, size, orig_addr = addr; + const struct vhost_iotlb_map *map; + struct vhost_iotlb *umem = vq->iotlb; + u64 s = 0, size, orig_addr = addr, last = addr + len - 1; if (vhost_vq_meta_fetch(vq, addr, len, type)) return true; while (len > s) { - node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, - addr, - addr + len - 1); - if (node == NULL || node->start > addr) { + map = vhost_iotlb_itree_first(umem, addr, last); + if (map == NULL || map->start > addr) { vhost_iotlb_miss(vq, addr, access); return false; - } else if (!(node->perm & access)) { + } else if (!(map->perm & access)) { /* Report the possible access violation by * request another translation from userspace. */ return false; } - size = node->size - addr + node->start; + size = map->size - addr + map->start; if (orig_addr == addr && size >= len) - vhost_vq_meta_update(vq, node, type); + vhost_vq_meta_update(vq, map, type); s += size; addr += size; @@ -1206,83 +1908,71 @@ static int iotlb_access_ok(struct vhost_virtqueue *vq, return true; } -int vq_iotlb_prefetch(struct vhost_virtqueue *vq) +int vq_meta_prefetch(struct vhost_virtqueue *vq) { - size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; unsigned int num = vq->num; if (!vq->iotlb) return 1; - return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc, - num * sizeof(*vq->desc), VHOST_ADDR_DESC) && - iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail, - sizeof *vq->avail + - num * sizeof(*vq->avail->ring) + s, + return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc, + vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) && + iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail, + vhost_get_avail_size(vq, num), VHOST_ADDR_AVAIL) && - iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used, - sizeof *vq->used + - num * sizeof(*vq->used->ring) + s, - VHOST_ADDR_USED); + iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used, + vhost_get_used_size(vq, num), VHOST_ADDR_USED); } -EXPORT_SYMBOL_GPL(vq_iotlb_prefetch); +EXPORT_SYMBOL_GPL(vq_meta_prefetch); /* Can we log writes? */ /* Caller should have device mutex but not vq mutex */ -int vhost_log_access_ok(struct vhost_dev *dev) +bool vhost_log_access_ok(struct vhost_dev *dev) { return memory_access_ok(dev, dev->umem, 1); } EXPORT_SYMBOL_GPL(vhost_log_access_ok); +static bool vq_log_used_access_ok(struct vhost_virtqueue *vq, + void __user *log_base, + bool log_used, + u64 log_addr) +{ + /* If an IOTLB device is present, log_addr is a GIOVA that + * will never be logged by log_used(). */ + if (vq->iotlb) + return true; + + return !log_used || log_access_ok(log_base, log_addr, + vhost_get_used_size(vq, vq->num)); +} + /* Verify access for write logging. */ /* Caller should have vq mutex and device mutex */ -static int vq_log_access_ok(struct vhost_virtqueue *vq, - void __user *log_base) +static bool vq_log_access_ok(struct vhost_virtqueue *vq, + void __user *log_base) { - size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 
2 : 0; - return vq_memory_access_ok(log_base, vq->umem, vhost_has_feature(vq, VHOST_F_LOG_ALL)) && - (!vq->log_used || log_access_ok(log_base, vq->log_addr, - sizeof *vq->used + - vq->num * sizeof *vq->used->ring + s)); + vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr); } /* Can we start vq? */ /* Caller should have vq mutex and device mutex */ -int vhost_vq_access_ok(struct vhost_virtqueue *vq) -{ - if (vq->iotlb) { - /* When device IOTLB was used, the access validation - * will be validated during prefetching. - */ - return 1; - } - return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && - vq_log_access_ok(vq, vq->log_base); -} -EXPORT_SYMBOL_GPL(vhost_vq_access_ok); - -static struct vhost_umem *vhost_umem_alloc(void) +bool vhost_vq_access_ok(struct vhost_virtqueue *vq) { - struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL); - - if (!umem) - return NULL; - - umem->umem_tree = RB_ROOT; - umem->numem = 0; - INIT_LIST_HEAD(&umem->umem_list); + if (!vq_log_access_ok(vq, vq->log_base)) + return false; - return umem; + return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used); } +EXPORT_SYMBOL_GPL(vhost_vq_access_ok); static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) { struct vhost_memory mem, *newmem; struct vhost_memory_region *region; - struct vhost_umem *newumem, *oldumem; + struct vhost_iotlb *newumem, *oldumem; unsigned long size = offsetof(struct vhost_memory, regions); int i; @@ -1292,18 +1982,19 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) return -EOPNOTSUPP; if (mem.nregions > max_mem_regions) return -E2BIG; - newmem = kvzalloc(size + mem.nregions * sizeof(*m->regions), GFP_KERNEL); + newmem = kvzalloc(struct_size(newmem, regions, mem.nregions), + GFP_KERNEL); if (!newmem) return -ENOMEM; memcpy(newmem, &mem, size); if (copy_from_user(newmem->regions, m->regions, - mem.nregions * sizeof *m->regions)) { + flex_array_size(newmem, regions, mem.nregions))) { kvfree(newmem); return -EFAULT; } - newumem = vhost_umem_alloc(); + newumem = iotlb_alloc(); if (!newumem) { kvfree(newmem); return -ENOMEM; @@ -1312,13 +2003,12 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) for (region = newmem->regions; region < newmem->regions + mem.nregions; region++) { - if (vhost_new_umem_range(newumem, - region->guest_phys_addr, - region->memory_size, - region->guest_phys_addr + - region->memory_size - 1, - region->userspace_addr, - VHOST_ACCESS_RW)) + if (vhost_iotlb_add_range(newumem, + region->guest_phys_addr, + region->guest_phys_addr + + region->memory_size - 1, + region->userspace_addr, + VHOST_MAP_RW)) goto err; } @@ -1336,56 +2026,135 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) } kvfree(newmem); - vhost_umem_clean(oldumem); + vhost_iotlb_free(oldumem); return 0; err: - vhost_umem_clean(newumem); + vhost_iotlb_free(newumem); kvfree(newmem); return -EFAULT; } -long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) +static long vhost_vring_set_num(struct vhost_dev *d, + struct vhost_virtqueue *vq, + void __user *argp) +{ + struct vhost_vring_state s; + + /* Resizing ring with an active backend? + * You don't want to do that. 
*/ + if (vq->private_data) + return -EBUSY; + + if (copy_from_user(&s, argp, sizeof s)) + return -EFAULT; + + if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) + return -EINVAL; + vq->num = s.num; + + return 0; +} + +static long vhost_vring_set_addr(struct vhost_dev *d, + struct vhost_virtqueue *vq, + void __user *argp) +{ + struct vhost_vring_addr a; + + if (copy_from_user(&a, argp, sizeof a)) + return -EFAULT; + if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) + return -EOPNOTSUPP; + + /* For 32bit, verify that the top 32bits of the user + data are set to zero. */ + if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr || + (u64)(unsigned long)a.used_user_addr != a.used_user_addr || + (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) + return -EFAULT; + + /* Make sure it's safe to cast pointers to vring types. */ + BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE); + BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE); + if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) || + (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) || + (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) + return -EINVAL; + + /* We only verify access here if backend is configured. + * If it is not, we don't as size might not have been setup. + * We will verify when backend is configured. */ + if (vq->private_data) { + if (!vq_access_ok(vq, vq->num, + (void __user *)(unsigned long)a.desc_user_addr, + (void __user *)(unsigned long)a.avail_user_addr, + (void __user *)(unsigned long)a.used_user_addr)) + return -EINVAL; + + /* Also validate log access for used ring if enabled. */ + if (!vq_log_used_access_ok(vq, vq->log_base, + a.flags & (0x1 << VHOST_VRING_F_LOG), + a.log_guest_addr)) + return -EINVAL; + } + + vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG)); + vq->desc = (void __user *)(unsigned long)a.desc_user_addr; + vq->avail = (void __user *)(unsigned long)a.avail_user_addr; + vq->log_addr = a.log_guest_addr; + vq->used = (void __user *)(unsigned long)a.used_user_addr; + + return 0; +} + +static long vhost_vring_set_num_addr(struct vhost_dev *d, + struct vhost_virtqueue *vq, + unsigned int ioctl, + void __user *argp) +{ + long r; + + mutex_lock(&vq->mutex); + + switch (ioctl) { + case VHOST_SET_VRING_NUM: + r = vhost_vring_set_num(d, vq, argp); + break; + case VHOST_SET_VRING_ADDR: + r = vhost_vring_set_addr(d, vq, argp); + break; + default: + BUG(); + } + + mutex_unlock(&vq->mutex); + + return r; +} +long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) { struct file *eventfp, *filep = NULL; bool pollstart = false, pollstop = false; struct eventfd_ctx *ctx = NULL; - u32 __user *idxp = argp; struct vhost_virtqueue *vq; struct vhost_vring_state s; struct vhost_vring_file f; - struct vhost_vring_addr a; u32 idx; long r; - r = get_user(idx, idxp); + r = vhost_get_vq_from_user(d, argp, &vq, &idx); if (r < 0) return r; - if (idx >= d->nvqs) - return -ENOBUFS; - vq = d->vqs[idx]; + if (ioctl == VHOST_SET_VRING_NUM || + ioctl == VHOST_SET_VRING_ADDR) { + return vhost_vring_set_num_addr(d, vq, ioctl, argp); + } mutex_lock(&vq->mutex); switch (ioctl) { - case VHOST_SET_VRING_NUM: - /* Resizing ring with an active backend? - * You don't want to do that. 
*/ - if (vq->private_data) { - r = -EBUSY; - break; - } - if (copy_from_user(&s, argp, sizeof s)) { - r = -EFAULT; - break; - } - if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) { - r = -EINVAL; - break; - } - vq->num = s.num; - break; case VHOST_SET_VRING_BASE: /* Moving base with an active backend? * You don't want to do that. */ @@ -1397,82 +2166,35 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) r = -EFAULT; break; } - if (s.num > 0xffff) { - r = -EINVAL; - break; + if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) { + vq->next_avail_head = vq->last_avail_idx = + s.num & 0xffff; + vq->last_used_idx = (s.num >> 16) & 0xffff; + } else { + if (s.num > 0xffff) { + r = -EINVAL; + break; + } + vq->next_avail_head = vq->last_avail_idx = s.num; } - vq->last_avail_idx = s.num; /* Forget the cached index value. */ vq->avail_idx = vq->last_avail_idx; break; case VHOST_GET_VRING_BASE: s.index = idx; - s.num = vq->last_avail_idx; + if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) + s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16); + else + s.num = vq->last_avail_idx; if (copy_to_user(argp, &s, sizeof s)) r = -EFAULT; break; - case VHOST_SET_VRING_ADDR: - if (copy_from_user(&a, argp, sizeof a)) { - r = -EFAULT; - break; - } - if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) { - r = -EOPNOTSUPP; - break; - } - /* For 32bit, verify that the top 32bits of the user - data are set to zero. */ - if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr || - (u64)(unsigned long)a.used_user_addr != a.used_user_addr || - (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) { - r = -EFAULT; - break; - } - - /* Make sure it's safe to cast pointers to vring types. */ - BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE); - BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE); - if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) || - (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) || - (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) { - r = -EINVAL; - break; - } - - /* We only verify access here if backend is configured. - * If it is not, we don't as size might not have been setup. - * We will verify when backend is configured. */ - if (vq->private_data) { - if (!vq_access_ok(vq, vq->num, - (void __user *)(unsigned long)a.desc_user_addr, - (void __user *)(unsigned long)a.avail_user_addr, - (void __user *)(unsigned long)a.used_user_addr)) { - r = -EINVAL; - break; - } - - /* Also validate log access for used ring if enabled. */ - if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) && - !log_access_ok(vq->log_base, a.log_guest_addr, - sizeof *vq->used + - vq->num * sizeof *vq->used->ring)) { - r = -EINVAL; - break; - } - } - - vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG)); - vq->desc = (void __user *)(unsigned long)a.desc_user_addr; - vq->avail = (void __user *)(unsigned long)a.avail_user_addr; - vq->log_addr = a.log_guest_addr; - vq->used = (void __user *)(unsigned long)a.used_user_addr; - break; case VHOST_SET_VRING_KICK: if (copy_from_user(&f, argp, sizeof f)) { r = -EFAULT; break; } - eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); + eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd); if (IS_ERR(eventfp)) { r = PTR_ERR(eventfp); break; @@ -1488,38 +2210,25 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) r = -EFAULT; break; } - eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); - if (IS_ERR(eventfp)) { - r = PTR_ERR(eventfp); + ctx = f.fd == VHOST_FILE_UNBIND ? 
NULL : eventfd_ctx_fdget(f.fd); + if (IS_ERR(ctx)) { + r = PTR_ERR(ctx); break; } - if (eventfp != vq->call) { - filep = vq->call; - ctx = vq->call_ctx; - vq->call = eventfp; - vq->call_ctx = eventfp ? - eventfd_ctx_fileget(eventfp) : NULL; - } else - filep = eventfp; + + swap(ctx, vq->call_ctx.ctx); break; case VHOST_SET_VRING_ERR: if (copy_from_user(&f, argp, sizeof f)) { r = -EFAULT; break; } - eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); - if (IS_ERR(eventfp)) { - r = PTR_ERR(eventfp); + ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd); + if (IS_ERR(ctx)) { + r = PTR_ERR(ctx); break; } - if (eventfp != vq->error) { - filep = vq->error; - vq->error = eventfp; - ctx = vq->error_ctx; - vq->error_ctx = eventfp ? - eventfd_ctx_fileget(eventfp) : NULL; - } else - filep = eventfp; + swap(ctx, vq->error_ctx); break; case VHOST_SET_VRING_ENDIAN: r = vhost_set_vring_endian(vq, argp); @@ -1547,7 +2256,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) if (pollstop && vq->handle_kick) vhost_poll_stop(&vq->poll); - if (ctx) + if (!IS_ERR_OR_NULL(ctx)) eventfd_ctx_put(ctx); if (filep) fput(filep); @@ -1558,17 +2267,17 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) mutex_unlock(&vq->mutex); if (pollstop && vq->handle_kick) - vhost_poll_flush(&vq->poll); + vhost_dev_flush(vq->poll.dev); return r; } EXPORT_SYMBOL_GPL(vhost_vring_ioctl); -int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled) +int vhost_init_device_iotlb(struct vhost_dev *d) { - struct vhost_umem *niotlb, *oiotlb; + struct vhost_iotlb *niotlb, *oiotlb; int i; - niotlb = vhost_umem_alloc(); + niotlb = iotlb_alloc(); if (!niotlb) return -ENOMEM; @@ -1576,12 +2285,15 @@ int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled) d->iotlb = niotlb; for (i = 0; i < d->nvqs; ++i) { - mutex_lock(&d->vqs[i]->mutex); - d->vqs[i]->iotlb = niotlb; - mutex_unlock(&d->vqs[i]->mutex); + struct vhost_virtqueue *vq = d->vqs[i]; + + mutex_lock(&vq->mutex); + vq->iotlb = niotlb; + __vhost_vq_meta_reset(vq); + mutex_unlock(&vq->mutex); } - vhost_umem_clean(oiotlb); + vhost_iotlb_free(oiotlb); return 0; } @@ -1590,8 +2302,7 @@ EXPORT_SYMBOL_GPL(vhost_init_device_iotlb); /* Caller must have device mutex */ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) { - struct file *eventfp, *filep = NULL; - struct eventfd_ctx *ctx = NULL; + struct eventfd_ctx *ctx; u64 p; long r; int i, fd; @@ -1602,6 +2313,45 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) goto done; } +#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL + if (ioctl == VHOST_SET_FORK_FROM_OWNER) { + /* Only allow modification before owner is set */ + if (vhost_dev_has_owner(d)) { + r = -EBUSY; + goto done; + } + u8 fork_owner_val; + + if (get_user(fork_owner_val, (u8 __user *)argp)) { + r = -EFAULT; + goto done; + } + if (fork_owner_val != VHOST_FORK_OWNER_TASK && + fork_owner_val != VHOST_FORK_OWNER_KTHREAD) { + r = -EINVAL; + goto done; + } + d->fork_owner = !!fork_owner_val; + r = 0; + goto done; + } + if (ioctl == VHOST_GET_FORK_FROM_OWNER) { + u8 fork_owner_val = d->fork_owner; + + if (fork_owner_val != VHOST_FORK_OWNER_TASK && + fork_owner_val != VHOST_FORK_OWNER_KTHREAD) { + r = -EINVAL; + goto done; + } + if (put_user(fork_owner_val, (u8 __user *)argp)) { + r = -EFAULT; + goto done; + } + r = 0; + goto done; + } +#endif + /* You must be the owner to do anything else */ r = vhost_dev_check_owner(d); if (r) @@ -1637,19 +2387,12 @@ long 
vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) r = get_user(fd, (int __user *)argp); if (r < 0) break; - eventfp = fd == -1 ? NULL : eventfd_fget(fd); - if (IS_ERR(eventfp)) { - r = PTR_ERR(eventfp); + ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd); + if (IS_ERR(ctx)) { + r = PTR_ERR(ctx); break; } - if (eventfp != d->log_file) { - filep = d->log_file; - d->log_file = eventfp; - ctx = d->log_ctx; - d->log_ctx = eventfp ? - eventfd_ctx_fileget(eventfp) : NULL; - } else - filep = eventfp; + swap(ctx, d->log_ctx); for (i = 0; i < d->nvqs; ++i) { mutex_lock(&d->vqs[i]->mutex); d->vqs[i]->log_ctx = d->log_ctx; @@ -1657,8 +2400,6 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) } if (ctx) eventfd_ctx_put(ctx); - if (filep) - fput(filep); break; default: r = -ENOIOCTLCMD; @@ -1671,7 +2412,7 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl); /* TODO: This is really inefficient. We need something like get_user() * (instruction directly accesses the data, with an exception table entry - * returning -EFAULT). See Documentation/x86/exception-tables.txt. + * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst. */ static int set_bit_to_user(int nr, void __user *addr) { @@ -1681,15 +2422,14 @@ static int set_bit_to_user(int nr, void __user *addr) int bit = nr + (log % PAGE_SIZE) * 8; int r; - r = get_user_pages_fast(log, 1, 1, &page); + r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page); if (r < 0) return r; BUG_ON(r != 1); base = kmap_atomic(page); set_bit(bit, base); kunmap_atomic(base); - set_page_dirty_lock(page); - put_page(page); + unpin_user_pages_dirty_lock(&page, 1, true); return 0; } @@ -1719,27 +2459,112 @@ static int log_write(void __user *log_base, return r; } +static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len) +{ + struct vhost_iotlb *umem = vq->umem; + struct vhost_iotlb_map *u; + u64 start, end, l, min; + int r; + bool hit = false; + + while (len) { + min = len; + /* More than one GPAs can be mapped into a single HVA. So + * iterate all possible umems here to be safe. + */ + list_for_each_entry(u, &umem->list, link) { + if (u->addr > hva - 1 + len || + u->addr - 1 + u->size < hva) + continue; + start = max(u->addr, hva); + end = min(u->addr - 1 + u->size, hva - 1 + len); + l = end - start + 1; + r = log_write(vq->log_base, + u->start + start - u->addr, + l); + if (r < 0) + return r; + hit = true; + min = min(l, min); + } + + if (!hit) + return -EFAULT; + + len -= min; + hva += min; + } + + return 0; +} + +static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len) +{ + struct iovec *iov = vq->log_iov; + int i, ret; + + if (!vq->iotlb) + return log_write(vq->log_base, vq->log_addr + used_offset, len); + + ret = translate_desc(vq, (uintptr_t)vq->used + used_offset, + len, iov, 64, VHOST_ACCESS_WO); + if (ret < 0) + return ret; + + for (i = 0; i < ret; i++) { + ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base, + iov[i].iov_len); + if (ret) + return ret; + } + + return 0; +} + +/* + * vhost_log_write() - Log in dirty page bitmap + * @vq: vhost virtqueue. + * @log: Array of dirty memory in GPA. + * @log_num: Size of vhost_log arrary. + * @len: The total length of memory buffer to log in the dirty bitmap. + * Some drivers may only partially use pages shared via the last + * vring descriptor (i.e. vhost-net RX buffer). + * Use (len == U64_MAX) to indicate the driver would log all + * pages of vring descriptors. + * @iov: Array of dirty memory in HVA. 
+ * @count: Size of iovec array. + */ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, - unsigned int log_num, u64 len) + unsigned int log_num, u64 len, struct iovec *iov, int count) { int i, r; /* Make sure data written is seen before log. */ smp_wmb(); + + if (vq->iotlb) { + for (i = 0; i < count; i++) { + r = log_write_hva(vq, (uintptr_t)iov[i].iov_base, + iov[i].iov_len); + if (r < 0) + return r; + } + return 0; + } + for (i = 0; i < log_num; ++i) { u64 l = min(log[i].len, len); r = log_write(vq->log_base, log[i].addr, l); if (r < 0) return r; - len -= l; - if (!len) { - if (vq->log_ctx) - eventfd_signal(vq->log_ctx, 1); - return 0; - } + + if (len != U64_MAX) + len -= l; } - /* Length written exceeds what we have stored. This is a bug. */ - BUG(); + + if (vq->log_ctx) + eventfd_signal(vq->log_ctx); + return 0; } EXPORT_SYMBOL_GPL(vhost_log_write); @@ -1747,27 +2572,24 @@ EXPORT_SYMBOL_GPL(vhost_log_write); static int vhost_update_used_flags(struct vhost_virtqueue *vq) { void __user *used; - if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags), - &vq->used->flags) < 0) + if (vhost_put_used_flags(vq)) return -EFAULT; if (unlikely(vq->log_used)) { /* Make sure the flag is seen before log. */ smp_wmb(); /* Log used flag write. */ used = &vq->used->flags; - log_write(vq->log_base, vq->log_addr + - (used - (void __user *)vq->used), - sizeof vq->used->flags); + log_used(vq, (used - (void __user *)vq->used), + sizeof vq->used->flags); if (vq->log_ctx) - eventfd_signal(vq->log_ctx, 1); + eventfd_signal(vq->log_ctx); } return 0; } -static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) +static int vhost_update_avail_event(struct vhost_virtqueue *vq) { - if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx), - vhost_avail_event(vq))) + if (vhost_put_avail_event(vq)) return -EFAULT; if (unlikely(vq->log_used)) { void __user *used; @@ -1775,11 +2597,10 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) smp_wmb(); /* Log avail event write */ used = vhost_avail_event(vq); - log_write(vq->log_base, vq->log_addr + - (used - (void __user *)vq->used), - sizeof *vhost_avail_event(vq)); + log_used(vq, (used - (void __user *)vq->used), + sizeof *vhost_avail_event(vq)); if (vq->log_ctx) - eventfd_signal(vq->log_ctx, 1); + eventfd_signal(vq->log_ctx); } return 0; } @@ -1800,11 +2621,11 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq) goto err; vq->signalled_used_valid = false; if (!vq->iotlb && - !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { + !access_ok(&vq->used->idx, sizeof vq->used->idx)) { r = -EFAULT; goto err; } - r = vhost_get_used(vq, last_used_idx, &vq->used->idx); + r = vhost_get_used_idx(vq, &last_used_idx); if (r) { vq_err(vq, "Can't access used idx at %p\n", &vq->used->idx); @@ -1822,11 +2643,11 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access); static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, struct iovec iov[], int iov_size, int access) { - const struct vhost_umem_node *node; + const struct vhost_iotlb_map *map; struct vhost_dev *dev = vq->dev; - struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem; + struct vhost_iotlb *umem = dev->iotlb ? 
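The log_write()/log_write_hva()/log_used() helpers above all end up setting one bit per dirty 4 KiB guest page in the userspace bitmap registered with VHOST_SET_LOG_BASE; set_bit_to_user() then pins the page backing the relevant bitmap byte and sets the bit. A simplified in-memory sketch of that arithmetic (illustrative only; it writes a local buffer rather than a pinned user page):

#include <stdint.h>
#include <stdio.h>

#define LOG_PAGE_SIZE 4096ULL	/* one bitmap bit per 4 KiB of guest memory */

/* Mark every guest page touched by [addr, addr + len) as dirty. */
static void log_dirty(uint8_t *bitmap, uint64_t addr, uint64_t len)
{
	if (!len)
		return;
	for (uint64_t page = addr / LOG_PAGE_SIZE;
	     page <= (addr + len - 1) / LOG_PAGE_SIZE; page++)
		bitmap[page / 8] |= 1u << (page % 8);
}

int main(void)
{
	uint8_t bitmap[16] = { 0 };

	log_dirty(bitmap, 0x1ff0, 0x20);		/* straddles guest pages 1 and 2 */
	printf("bitmap[0] = 0x%02x\n", bitmap[0]);	/* 0x06: bits 1 and 2 set */
	return 0;
}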
dev->iotlb : dev->umem; struct iovec *_iov; - u64 s = 0; + u64 s = 0, last = addr + len - 1; int ret = 0; while ((u64)len > s) { @@ -1836,25 +2657,24 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, break; } - node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, - addr, addr + len - 1); - if (node == NULL || node->start > addr) { + map = vhost_iotlb_itree_first(umem, addr, last); + if (map == NULL || map->start > addr) { if (umem != dev->iotlb) { ret = -EFAULT; break; } ret = -EAGAIN; break; - } else if (!(node->perm & access)) { + } else if (!(map->perm & access)) { ret = -EPERM; break; } _iov = iov + ret; - size = node->size - addr + node->start; + size = map->size - addr + map->start; _iov->iov_len = min((u64)len - s, size); _iov->iov_base = (void __user *)(unsigned long) - (node->userspace_addr + addr - node->start); + (map->addr + addr - map->start); s += size; addr += size; ++ret; @@ -1877,12 +2697,7 @@ static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc) return -1U; /* Check they're not leading us off end of descriptors. */ - next = vhost16_to_cpu(vq, desc->next); - /* Make sure compiler knows to grab that: we don't want it changing! */ - /* We will use the result as an index in an array, so most - * architectures only need a compiler barrier here. */ - read_barrier_depends(); - + next = vhost16_to_cpu(vq, READ_ONCE(desc->next)); return next; } @@ -1914,12 +2729,7 @@ static int get_indirect(struct vhost_virtqueue *vq, vq_err(vq, "Translation failure %d in indirect.\n", ret); return ret; } - iov_iter_init(&from, READ, vq->indirect, ret, len); - - /* We will use the result as an address to read from, so most - * architectures only need a compiler barrier here. */ - read_barrier_depends(); - + iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len); count = len / sizeof desc; /* Buffers are chained via a 16 bit next field, so * we can have at most 2^16 of these. */ @@ -1965,7 +2775,7 @@ static int get_indirect(struct vhost_virtqueue *vq, /* If this is an input descriptor, increment that count. */ if (access == VHOST_ACCESS_WO) { *in_num += ret; - if (unlikely(log)) { + if (unlikely(log && ret)) { log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); log[*log_num].len = vhost32_to_cpu(vq, desc.len); ++*log_num; @@ -1984,67 +2794,66 @@ static int get_indirect(struct vhost_virtqueue *vq, return 0; } -/* This looks in the virtqueue and for the first available buffer, and converts - * it to an iovec for convenient access. Since descriptors consist of some - * number of output then some number of input descriptors, it's actually two - * iovecs, but we pack them into one and note how many of each there were. +/** + * vhost_get_vq_desc_n - Fetch the next available descriptor chain and build iovecs + * @vq: target virtqueue + * @iov: array that receives the scatter/gather segments + * @iov_size: capacity of @iov in elements + * @out_num: the number of output segments + * @in_num: the number of input segments + * @log: optional array to record addr/len for each writable segment; NULL if unused + * @log_num: optional output; number of entries written to @log when provided + * @ndesc: optional output; number of descriptors consumed from the available ring + * (useful for rollback via vhost_discard_vq_desc) * - * This function returns the descriptor number found, or vq->num (which is - * never a valid descriptor number) if none was found. A negative code is - * returned on error. 
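translate_desc() above now resolves guest addresses through the vhost_iotlb interval tree (vhost_iotlb_itree_first()), splitting a guest range into host iovec segments and failing on a hole (-EFAULT, or -EAGAIN to trigger an IOTLB miss when an IOTLB is in use) or on a permission mismatch (-EPERM). The splitting itself is sketched below over a plain array of mappings (illustrative only; hypothetical types, a linear search instead of an interval tree, and no permission handling):

#include <errno.h>
#include <stdint.h>
#include <stdio.h>

struct mapping { uint64_t start, size, host; };	/* guest start, length, host address */
struct seg { uint64_t host, len; };

/* Split guest range [addr, addr + len) into host segments; -EFAULT on a hole. */
static int translate(const struct mapping *map, int nmap,
		     uint64_t addr, uint64_t len, struct seg *out, int max)
{
	int n = 0;

	while (len) {
		const struct mapping *m = NULL;

		for (int i = 0; i < nmap; i++)
			if (addr >= map[i].start && addr - map[i].start < map[i].size)
				m = &map[i];
		if (!m)
			return -EFAULT;		/* hole in the mapping table */
		if (n == max)
			return -ENOBUFS;

		uint64_t chunk = m->start + m->size - addr;
		if (chunk > len)
			chunk = len;
		out[n].host = m->host + (addr - m->start);
		out[n++].len = chunk;
		addr += chunk;
		len -= chunk;
	}
	return n;
}

int main(void)
{
	const struct mapping map[] = {
		{ 0x0000, 0x1000, 0x100000 },	/* guest page 0 -> host 0x100000 */
		{ 0x1000, 0x1000, 0x200000 },	/* guest page 1 -> host 0x200000 */
	};
	struct seg seg[4];
	int n = translate(map, 2, 0xf00, 0x200, seg, 4);

	for (int i = 0; i < n; i++)
		printf("seg %d: host 0x%llx len 0x%llx\n", i,
		       (unsigned long long)seg[i].host, (unsigned long long)seg[i].len);
	return 0;
}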
*/ -int vhost_get_vq_desc(struct vhost_virtqueue *vq, - struct iovec iov[], unsigned int iov_size, - unsigned int *out_num, unsigned int *in_num, - struct vhost_log *log, unsigned int *log_num) + * Extracts one available descriptor chain from @vq and translates guest addresses + * into host iovecs. + * + * On success, advances @vq->last_avail_idx by 1 and @vq->next_avail_head by the + * number of descriptors consumed (also stored via @ndesc when non-NULL). + * + * Return: + * - head index in [0, @vq->num) on success; + * - @vq->num if no descriptor is currently available; + * - negative errno on failure + */ +int vhost_get_vq_desc_n(struct vhost_virtqueue *vq, + struct iovec iov[], unsigned int iov_size, + unsigned int *out_num, unsigned int *in_num, + struct vhost_log *log, unsigned int *log_num, + unsigned int *ndesc) { + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); struct vring_desc desc; unsigned int i, head, found = 0; - u16 last_avail_idx; - __virtio16 avail_idx; + u16 last_avail_idx = vq->last_avail_idx; __virtio16 ring_head; - int ret, access; - - /* Check it isn't doing very strange things with descriptor numbers. */ - last_avail_idx = vq->last_avail_idx; + int ret, access, c = 0; if (vq->avail_idx == vq->last_avail_idx) { - if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) { - vq_err(vq, "Failed to access avail idx at %p\n", - &vq->avail->idx); - return -EFAULT; - } - vq->avail_idx = vhost16_to_cpu(vq, avail_idx); - - if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) { - vq_err(vq, "Guest moved used index from %u to %u", - last_avail_idx, vq->avail_idx); - return -EFAULT; - } + ret = vhost_get_avail_idx(vq); + if (unlikely(ret < 0)) + return ret; - /* If there's nothing new since last we looked, return - * invalid. - */ - if (vq->avail_idx == last_avail_idx) + if (!ret) return vq->num; - - /* Only get avail ring entries after they have been - * exposed by guest. - */ - smp_rmb(); } - /* Grab the next descriptor number they're advertising, and increment - * the index we've seen. */ - if (unlikely(vhost_get_avail(vq, ring_head, - &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { - vq_err(vq, "Failed to read head: idx %d address %p\n", - last_avail_idx, - &vq->avail->ring[last_avail_idx % vq->num]); - return -EFAULT; + if (in_order) + head = vq->next_avail_head & (vq->num - 1); + else { + /* Grab the next descriptor number they're + * advertising, and increment the index we've seen. */ + if (unlikely(vhost_get_avail_head(vq, &ring_head, + last_avail_idx))) { + vq_err(vq, "Failed to read head: idx %d address %p\n", + last_avail_idx, + &vq->avail->ring[last_avail_idx % vq->num]); + return -EFAULT; + } + head = vhost16_to_cpu(vq, ring_head); } - head = vhost16_to_cpu(vq, ring_head); - /* If their number is silly, that's an error. */ if (unlikely(head >= vq->num)) { vq_err(vq, "Guest says index %u > %u is available", @@ -2071,8 +2880,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, i, vq->num, head); return -EINVAL; } - ret = vhost_copy_from_user(vq, &desc, vq->desc + i, - sizeof desc); + ret = vhost_get_desc(vq, &desc, i); if (unlikely(ret)) { vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", i, vq->desc + i); @@ -2088,6 +2896,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, "in indirect descriptor at idx %d\n", i); return ret; } + ++c; continue; } @@ -2108,7 +2917,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, /* If this is an input descriptor, * increment that count. 
*/ *in_num += ret; - if (unlikely(log)) { + if (unlikely(log && ret)) { log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); log[*log_num].len = vhost32_to_cpu(vq, desc.len); ++*log_num; @@ -2123,22 +2932,56 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, } *out_num += ret; } + ++c; } while ((i = next_desc(vq, &desc)) != -1); /* On success, increment avail index. */ vq->last_avail_idx++; + vq->next_avail_head += c; + + if (ndesc) + *ndesc = c; /* Assume notifications from guest are disabled at this point, * if they aren't we would need to update avail_event index. */ BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY)); return head; } +EXPORT_SYMBOL_GPL(vhost_get_vq_desc_n); + +/* This looks in the virtqueue and for the first available buffer, and converts + * it to an iovec for convenient access. Since descriptors consist of some + * number of output then some number of input descriptors, it's actually two + * iovecs, but we pack them into one and note how many of each there were. + * + * This function returns the descriptor number found, or vq->num (which is + * never a valid descriptor number) if none was found. A negative code is + * returned on error. + */ +int vhost_get_vq_desc(struct vhost_virtqueue *vq, + struct iovec iov[], unsigned int iov_size, + unsigned int *out_num, unsigned int *in_num, + struct vhost_log *log, unsigned int *log_num) +{ + return vhost_get_vq_desc_n(vq, iov, iov_size, out_num, in_num, + log, log_num, NULL); +} EXPORT_SYMBOL_GPL(vhost_get_vq_desc); -/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */ -void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n) +/** + * vhost_discard_vq_desc - Reverse the effect of vhost_get_vq_desc_n() + * @vq: target virtqueue + * @nbufs: number of buffers to roll back + * @ndesc: number of descriptors to roll back + * + * Rewinds the internal consumer cursors after a failed attempt to use buffers + * returned by vhost_get_vq_desc_n(). + */ +void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int nbufs, + unsigned int ndesc) { - vq->last_avail_idx -= n; + vq->next_avail_head -= ndesc; + vq->last_avail_idx -= nbufs; } EXPORT_SYMBOL_GPL(vhost_discard_vq_desc); @@ -2150,8 +2993,9 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len) cpu_to_vhost32(vq, head), cpu_to_vhost32(vq, len) }; + u16 nheads = 1; - return vhost_add_used_n(vq, &heads, 1); + return vhost_add_used_n(vq, &heads, &nheads, 1); } EXPORT_SYMBOL_GPL(vhost_add_used); @@ -2159,22 +3003,13 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, unsigned count) { - struct vring_used_elem __user *used; + vring_used_elem_t __user *used; u16 old, new; int start; start = vq->last_used_idx & (vq->num - 1); used = vq->used->ring + start; - if (count == 1) { - if (vhost_put_user(vq, heads[0].id, &used->id)) { - vq_err(vq, "Failed to write used id"); - return -EFAULT; - } - if (vhost_put_user(vq, heads[0].len, &used->len)) { - vq_err(vq, "Failed to write used len"); - return -EFAULT; - } - } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) { + if (vhost_put_used(vq, heads, start, count)) { vq_err(vq, "Failed to write used"); return -EFAULT; } @@ -2182,10 +3017,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, /* Make sure data is seen before log. */ smp_wmb(); /* Log used ring entry write. 
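vhost_get_vq_desc_n() above advances two cursors: last_avail_idx moves by one per buffer, while next_avail_head moves by the number of descriptors actually consumed, and vhost_discard_vq_desc() now takes both counts so either cursor can be rolled back after a failure. A toy model of that accounting (illustrative only, not kernel code; 16-bit wrap-around is intentional, as in the kernel):

#include <stdint.h>
#include <stdio.h>

struct vq_cursors {
	uint16_t last_avail_idx;	/* one step per buffer (descriptor chain) */
	uint16_t next_avail_head;	/* one step per descriptor */
};

static void consume(struct vq_cursors *c, unsigned int ndesc)
{
	c->last_avail_idx += 1;		/* one buffer ...                       */
	c->next_avail_head += ndesc;	/* ... made of ndesc chained descriptors */
}

/* Mirrors vhost_discard_vq_desc(vq, nbufs, ndesc). */
static void discard(struct vq_cursors *c, unsigned int nbufs, unsigned int ndesc)
{
	c->last_avail_idx -= nbufs;
	c->next_avail_head -= ndesc;
}

int main(void)
{
	struct vq_cursors c = { 0, 0 };

	consume(&c, 3);			/* a 3-descriptor chain */
	consume(&c, 1);			/* a single descriptor  */
	printf("avail=%u head=%u\n", c.last_avail_idx, c.next_avail_head);
	discard(&c, 1, 1);		/* back out the last buffer */
	printf("avail=%u head=%u\n", c.last_avail_idx, c.next_avail_head);
	return 0;
}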
*/ - log_write(vq->log_base, - vq->log_addr + - ((void __user *)used - (void __user *)vq->used), - count * sizeof *used); + log_used(vq, ((void __user *)used - (void __user *)vq->used), + count * sizeof *used); } old = vq->last_used_idx; new = (vq->last_used_idx += count); @@ -2198,10 +3031,9 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, return 0; } -/* After we've used one of their buffers, we tell them about it. We'll then - * want to notify the guest, using eventfd. */ -int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, - unsigned count) +static int vhost_add_used_n_ooo(struct vhost_virtqueue *vq, + struct vring_used_elem *heads, + unsigned count) { int start, n, r; @@ -2214,22 +3046,87 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, heads += n; count -= n; } - r = __vhost_add_used_n(vq, heads, count); + return __vhost_add_used_n(vq, heads, count); +} + +static int vhost_add_used_n_in_order(struct vhost_virtqueue *vq, + struct vring_used_elem *heads, + const u16 *nheads, + unsigned count) +{ + vring_used_elem_t __user *used; + u16 old, new = vq->last_used_idx; + int start, i; + + if (!nheads) + return -EINVAL; + + start = vq->last_used_idx & (vq->num - 1); + used = vq->used->ring + start; + + for (i = 0; i < count; i++) { + if (vhost_put_used(vq, &heads[i], start, 1)) { + vq_err(vq, "Failed to write used"); + return -EFAULT; + } + start += nheads[i]; + new += nheads[i]; + if (start >= vq->num) + start -= vq->num; + } + + if (unlikely(vq->log_used)) { + /* Make sure data is seen before log. */ + smp_wmb(); + /* Log used ring entry write. */ + log_used(vq, ((void __user *)used - (void __user *)vq->used), + (vq->num - start) * sizeof *used); + if (start + count > vq->num) + log_used(vq, 0, + (start + count - vq->num) * sizeof *used); + } + + old = vq->last_used_idx; + vq->last_used_idx = new; + /* If the driver never bothers to signal in a very long while, + * used index might wrap around. If that happens, invalidate + * signalled_used index we stored. TODO: make sure driver + * signals at least once in 2^16 and remove this. */ + if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old))) + vq->signalled_used_valid = false; + return 0; +} + +/* After we've used one of their buffers, we tell them about it. We'll then + * want to notify the guest, using eventfd. */ +int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, + u16 *nheads, unsigned count) +{ + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); + int r; + + if (!in_order || !nheads) + r = vhost_add_used_n_ooo(vq, heads, count); + else + r = vhost_add_used_n_in_order(vq, heads, nheads, count); + + if (r < 0) + return r; /* Make sure buffer is written before we update index. */ smp_wmb(); - if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx), - &vq->used->idx)) { + if (vhost_put_used_idx(vq)) { vq_err(vq, "Failed to increment used idx"); return -EFAULT; } if (unlikely(vq->log_used)) { + /* Make sure used idx is seen before log. */ + smp_wmb(); /* Log used index update. 
*/ - log_write(vq->log_base, - vq->log_addr + offsetof(struct vring_used, idx), - sizeof vq->used->idx); + log_used(vq, offsetof(struct vring_used, idx), + sizeof vq->used->idx); if (vq->log_ctx) - eventfd_signal(vq->log_ctx, 1); + eventfd_signal(vq->log_ctx); } return r; } @@ -2251,7 +3148,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { __virtio16 flags; - if (vhost_get_avail(vq, flags, &vq->avail->flags)) { + if (vhost_get_avail_flags(vq, &flags)) { vq_err(vq, "Failed to get flags"); return true; } @@ -2265,7 +3162,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (unlikely(!v)) return true; - if (vhost_get_avail(vq, event, vhost_used_event(vq))) { + if (vhost_get_used_event(vq, &event)) { vq_err(vq, "Failed to get used event idx"); return true; } @@ -2276,8 +3173,8 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq) { /* Signal the Guest tell them we used something up. */ - if (vq->call_ctx && vhost_notify(dev, vq)) - eventfd_signal(vq->call_ctx, 1); + if (vq->call_ctx.ctx && vhost_notify(dev, vq)) + eventfd_signal(vq->call_ctx.ctx); } EXPORT_SYMBOL_GPL(vhost_signal); @@ -2294,35 +3191,33 @@ EXPORT_SYMBOL_GPL(vhost_add_used_and_signal); /* multi-buffer version of vhost_add_used_and_signal */ void vhost_add_used_and_signal_n(struct vhost_dev *dev, struct vhost_virtqueue *vq, - struct vring_used_elem *heads, unsigned count) + struct vring_used_elem *heads, + u16 *nheads, + unsigned count) { - vhost_add_used_n(vq, heads, count); + vhost_add_used_n(vq, heads, nheads, count); vhost_signal(dev, vq); } EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n); -/* return true if we're sure that avaiable ring is empty */ +/* return true if we're sure that available ring is empty */ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) { - __virtio16 avail_idx; int r; if (vq->avail_idx != vq->last_avail_idx) return false; - r = vhost_get_avail(vq, avail_idx, &vq->avail->idx); - if (unlikely(r)) - return false; - vq->avail_idx = vhost16_to_cpu(vq, avail_idx); + r = vhost_get_avail_idx(vq); - return vq->avail_idx == vq->last_avail_idx; + /* Note: we treat error as non-empty here */ + return r == 0; } EXPORT_SYMBOL_GPL(vhost_vq_avail_empty); /* OK, now we need to know about added descriptors. */ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) { - __virtio16 avail_idx; int r; if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY)) @@ -2336,7 +3231,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) return false; } } else { - r = vhost_update_avail_event(vq, vq->avail_idx); + r = vhost_update_avail_event(vq); if (r) { vq_err(vq, "Failed to update avail event index at %p: %d\n", vhost_avail_event(vq), r); @@ -2346,14 +3241,13 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) /* They could have slipped one in as we were doing that: make * sure it's written, then check again. 
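When VIRTIO_RING_F_EVENT_IDX is negotiated, vhost_notify() above signals the call eventfd only if the guest's used_event index falls inside the window of used entries added since the last signal; the wrap-safe 16-bit comparison is the standard vring_need_event() test from the virtio spec, and the same modular trick is what invalidates signalled_used in vhost_add_used_n(). A standalone sketch (illustrative only):

#include <stdint.h>
#include <stdio.h>

/* Wrap-safe test from the virtio spec: signal iff event_idx lies in the
 * half-open window (old_idx, new_idx] of 16-bit used-index values. */
static int need_event(uint16_t event_idx, uint16_t new_idx, uint16_t old_idx)
{
	return (uint16_t)(new_idx - event_idx - 1) < (uint16_t)(new_idx - old_idx);
}

int main(void)
{
	/* Guest asked to be notified once entry 5 has been used. */
	printf("%d\n", need_event(5, 6, 4));		/* 1: 5 is inside (4, 6]   */
	printf("%d\n", need_event(5, 4, 2));		/* 0: not reached yet      */
	/* Still correct across 16-bit wrap-around. */
	printf("%d\n", need_event(0xfffe, 2, 0xfffc));	/* 1 */
	return 0;
}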
*/ smp_mb(); - r = vhost_get_avail(vq, avail_idx, &vq->avail->idx); - if (r) { - vq_err(vq, "Failed to check avail idx at %p: %d\n", - &vq->avail->idx, r); + + r = vhost_get_avail_idx(vq); + /* Note: we treat error as empty here */ + if (unlikely(r < 0)) return false; - } - return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx; + return r; } EXPORT_SYMBOL_GPL(vhost_enable_notify); @@ -2368,7 +3262,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { r = vhost_update_used_flags(vq); if (r) - vq_err(vq, "Failed to enable notification at %p: %d\n", + vq_err(vq, "Failed to disable notification at %p: %d\n", &vq->used->flags, r); } } @@ -2377,9 +3271,11 @@ EXPORT_SYMBOL_GPL(vhost_disable_notify); /* Create a new message. */ struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) { - struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL); + /* Make sure all padding within the structure is initialized. */ + struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL); if (!node) return NULL; + node->vq = vq; node->msg.type = type; return node; @@ -2393,7 +3289,7 @@ void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head, list_add_tail(&node->node, head); spin_unlock(&dev->iotlb_lock); - wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM); + wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); } EXPORT_SYMBOL_GPL(vhost_enqueue_msg); @@ -2414,6 +3310,21 @@ struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, } EXPORT_SYMBOL_GPL(vhost_dequeue_msg); +void vhost_set_backend_features(struct vhost_dev *dev, u64 features) +{ + struct vhost_virtqueue *vq; + int i; + + mutex_lock(&dev->mutex); + for (i = 0; i < dev->nvqs; ++i) { + vq = dev->vqs[i]; + mutex_lock(&vq->mutex); + vq->acked_backend_features = features; + mutex_unlock(&vq->mutex); + } + mutex_unlock(&dev->mutex); +} +EXPORT_SYMBOL_GPL(vhost_set_backend_features); static int __init vhost_init(void) { |
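vhost_new_msg() above switches from kmalloc() to kzalloc() so that compiler-inserted padding in the message node is zeroed before the message can later be copied to userspace. A small userspace illustration of why that matters, with calloc() standing in for kzalloc() (illustrative only; the struct layout is made up):

#include <stdio.h>
#include <stdlib.h>

struct msg {
	char type;	/* the compiler inserts padding after this member */
	long payload;
};

int main(void)
{
	/* The struct is larger than the sum of its members; the gap is padding
	 * that a plain kmalloc() would leave holding stale kernel memory. */
	printf("members: %zu bytes, struct: %zu bytes, padding: %zu bytes\n",
	       sizeof(char) + sizeof(long), sizeof(struct msg),
	       sizeof(struct msg) - (sizeof(char) + sizeof(long)));

	/* calloc(), like kzalloc(), clears the whole allocation, padding included. */
	struct msg *z = calloc(1, sizeof(*z));
	if (!z)
		return 1;
	printf("byte right after 'type': %d\n", ((unsigned char *)z)[1]);
	free(z);
	return 0;
}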
