Diffstat (limited to 'drivers/vhost/vhost.c')
| -rw-r--r-- | drivers/vhost/vhost.c | 1202 |
1 file changed, 973 insertions(+), 229 deletions(-)
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index d7b8df3edffc..bccdc9eab267 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -27,6 +27,7 @@
 #include <linux/sort.h>
 #include <linux/sched/mm.h>
 #include <linux/sched/signal.h>
+#include <linux/sched/vhost_task.h>
 #include <linux/interval_tree_generic.h>
 #include <linux/nospec.h>
 #include <linux/kcov.h>
@@ -41,6 +42,13 @@ static int max_iotlb_entries = 2048;
 module_param(max_iotlb_entries, int, 0444);
 MODULE_PARM_DESC(max_iotlb_entries,
     "Maximum number of iotlb entries. (default: 2048)");
+static bool fork_from_owner_default = VHOST_FORK_OWNER_TASK;
+
+#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL
+module_param(fork_from_owner_default, bool, 0444);
+MODULE_PARM_DESC(fork_from_owner_default,
+    "Set task mode as the default(default: Y)");
+#endif
 
 enum {
     VHOST_MEMORY_F_LOG = 0x1,
@@ -187,13 +195,15 @@ EXPORT_SYMBOL_GPL(vhost_work_init);
 
 /* Init poll structure */
 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
-             __poll_t mask, struct vhost_dev *dev)
+             __poll_t mask, struct vhost_dev *dev,
+             struct vhost_virtqueue *vq)
 {
     init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
     init_poll_funcptr(&poll->table, vhost_poll_func);
     poll->mask = mask;
     poll->dev = dev;
     poll->wqh = NULL;
+    poll->vq = vq;
 
     vhost_work_init(&poll->work, fn);
 }
@@ -231,54 +241,98 @@ void vhost_poll_stop(struct vhost_poll *poll)
 }
 EXPORT_SYMBOL_GPL(vhost_poll_stop);
 
-void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
+static void vhost_worker_queue(struct vhost_worker *worker,
+                   struct vhost_work *work)
 {
-    struct vhost_flush_struct flush;
+    if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
+        /* We can only add the work to the list after we're
+         * sure it was not in the list.
+         * test_and_set_bit() implies a memory barrier.
+         */
+        llist_add(&work->node, &worker->work_list);
+        worker->ops->wakeup(worker);
+    }
+}
 
-    if (dev->worker) {
-        init_completion(&flush.wait_event);
-        vhost_work_init(&flush.work, vhost_flush_work);
+bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
+{
+    struct vhost_worker *worker;
+    bool queued = false;
 
-        vhost_work_queue(dev, &flush.work);
-        wait_for_completion(&flush.wait_event);
+    rcu_read_lock();
+    worker = rcu_dereference(vq->worker);
+    if (worker) {
+        queued = true;
+        vhost_worker_queue(worker, work);
     }
+    rcu_read_unlock();
+
+    return queued;
 }
-EXPORT_SYMBOL_GPL(vhost_work_flush);
+EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
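
The queueing path above is lock-free: vhost_worker_queue() only adds the work if it atomically wins the VHOST_WORK_QUEUED bit, and vhost_vq_work_queue() resolves the target worker under rcu_read_lock() so the vq-to-worker attachment can change concurrently. A minimal userspace sketch of the once-only enqueue guard, using C11 atomics in place of test_and_set_bit() and llist (illustrative only, not from the patch):

#include <stdatomic.h>
#include <stdio.h>

struct work {
	atomic_flag queued;	/* stands in for VHOST_WORK_QUEUED */
	struct work *next;	/* stands in for the llist node */
};

static struct work *pending;	/* single-threaded stand-in for the llist head */

/* Returns 1 if the work was enqueued, 0 if it was already pending. */
static int work_queue_once(struct work *w)
{
	/* atomic read-modify-write with full ordering, like test_and_set_bit() */
	if (atomic_flag_test_and_set(&w->queued))
		return 0;
	w->next = pending;	/* llist_add() does this step atomically */
	pending = w;
	return 1;
}

int main(void)
{
	struct work w = { .queued = ATOMIC_FLAG_INIT };

	printf("first queue attempt:  %d\n", work_queue_once(&w));	/* 1 */
	printf("second queue attempt: %d\n", work_queue_once(&w));	/* 0 */
	return 0;
}

Until the worker clears the bit while draining its list, duplicate kicks collapse into one pending item, which is why vhost_poll_queue() stays cheap to call repeatedly from wakeup paths.
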
 
-/* Flush any work that has been scheduled. When calling this, don't hold any
- * locks that are also used by the callback. */
-void vhost_poll_flush(struct vhost_poll *poll)
+/**
+ * __vhost_worker_flush - flush a worker
+ * @worker: worker to flush
+ *
+ * The worker's flush_mutex must be held.
+ */
+static void __vhost_worker_flush(struct vhost_worker *worker)
 {
-    vhost_work_flush(poll->dev, &poll->work);
+    struct vhost_flush_struct flush;
+
+    if (!worker->attachment_cnt || worker->killed)
+        return;
+
+    init_completion(&flush.wait_event);
+    vhost_work_init(&flush.work, vhost_flush_work);
+
+    vhost_worker_queue(worker, &flush.work);
+    /*
+     * Drop mutex in case our worker is killed and it needs to take the
+     * mutex to force cleanup.
+     */
+    mutex_unlock(&worker->mutex);
+    wait_for_completion(&flush.wait_event);
+    mutex_lock(&worker->mutex);
 }
-EXPORT_SYMBOL_GPL(vhost_poll_flush);
 
-void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
+static void vhost_worker_flush(struct vhost_worker *worker)
 {
-    if (!dev->worker)
-        return;
+    mutex_lock(&worker->mutex);
+    __vhost_worker_flush(worker);
+    mutex_unlock(&worker->mutex);
+}
 
-    if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
-        /* We can only add the work to the list after we're
-         * sure it was not in the list.
-         * test_and_set_bit() implies a memory barrier.
-         */
-        llist_add(&work->node, &dev->work_list);
-        wake_up_process(dev->worker);
-    }
+void vhost_dev_flush(struct vhost_dev *dev)
+{
+    struct vhost_worker *worker;
+    unsigned long i;
+
+    xa_for_each(&dev->worker_xa, i, worker)
+        vhost_worker_flush(worker);
 }
-EXPORT_SYMBOL_GPL(vhost_work_queue);
+EXPORT_SYMBOL_GPL(vhost_dev_flush);
 
 /* A lockless hint for busy polling code to exit the loop */
-bool vhost_has_work(struct vhost_dev *dev)
+bool vhost_vq_has_work(struct vhost_virtqueue *vq)
 {
-    return !llist_empty(&dev->work_list);
+    struct vhost_worker *worker;
+    bool has_work = false;
+
+    rcu_read_lock();
+    worker = rcu_dereference(vq->worker);
+    if (worker && !llist_empty(&worker->work_list))
+        has_work = true;
+    rcu_read_unlock();
+
+    return has_work;
 }
-EXPORT_SYMBOL_GPL(vhost_has_work);
+EXPORT_SYMBOL_GPL(vhost_vq_has_work);
 
 void vhost_poll_queue(struct vhost_poll *poll)
 {
-    vhost_work_queue(poll->dev, &poll->work);
+    vhost_vq_work_queue(poll->vq, &poll->work);
 }
 EXPORT_SYMBOL_GPL(vhost_poll_queue);
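
vhost_dev_flush() now iterates the device's worker xarray and flushes each worker by queueing a sentinel vhost_flush_work and waiting on its completion; because a worker consumes its list in FIFO batches, the sentinel completing implies everything queued before it has run. A self-contained pthread analogue of that sentinel trick (all names here are illustrative, not from the patch):

#include <pthread.h>
#include <stdio.h>

struct work { void (*fn)(void *); void *arg; };

static struct work ring[64];
static int head, tail, stop;
static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t kick = PTHREAD_COND_INITIALIZER;

static void queue_work(void (*fn)(void *), void *arg)
{
	pthread_mutex_lock(&lock);
	ring[tail++ % 64] = (struct work){ fn, arg };
	pthread_cond_signal(&kick);
	pthread_mutex_unlock(&lock);
}

static void *worker(void *unused)
{
	(void)unused;
	for (;;) {
		struct work w;

		pthread_mutex_lock(&lock);
		while (head == tail && !stop)
			pthread_cond_wait(&kick, &lock);
		if (head == tail) {
			pthread_mutex_unlock(&lock);
			return NULL;
		}
		w = ring[head++ % 64];
		pthread_mutex_unlock(&lock);
		w.fn(w.arg);	/* items run strictly in queue order */
	}
}

struct flush { pthread_mutex_t lock; pthread_cond_t cv; int done; };

/* The sentinel: once it runs, everything queued before it has finished. */
static void flush_fn(void *arg)
{
	struct flush *f = arg;

	pthread_mutex_lock(&f->lock);
	f->done = 1;
	pthread_cond_signal(&f->cv);
	pthread_mutex_unlock(&f->lock);
}

static void say(void *msg) { puts(msg); }

int main(void)
{
	pthread_t t;
	struct flush f = { PTHREAD_MUTEX_INITIALIZER,
			   PTHREAD_COND_INITIALIZER, 0 };

	pthread_create(&t, NULL, worker, NULL);
	queue_work(say, "work 1");
	queue_work(say, "work 2");
	queue_work(flush_fn, &f);	/* like queueing vhost_flush_work */

	pthread_mutex_lock(&f.lock);
	while (!f.done)			/* like wait_for_completion() */
		pthread_cond_wait(&f.cv, &f.lock);
	pthread_mutex_unlock(&f.lock);
	puts("flush done: both earlier works have run");

	pthread_mutex_lock(&lock);
	stop = 1;
	pthread_cond_signal(&kick);
	pthread_mutex_unlock(&lock);
	pthread_join(t, NULL);
	return 0;
}

Note how __vhost_worker_flush() drops the worker mutex around the wait: the kernel version must allow a dying worker to grab that mutex and complete outstanding flushes, a wrinkle the userspace toy above does not need.
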
 
@@ -298,6 +352,18 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
         __vhost_vq_meta_reset(d->vqs[i]);
 }
 
+static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
+{
+    call_ctx->ctx = NULL;
+    memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
+}
+
+bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
+{
+    return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
+}
+EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
+
 static void vhost_vq_reset(struct vhost_dev *dev,
                struct vhost_virtqueue *vq)
 {
@@ -306,6 +372,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
     vq->avail = NULL;
     vq->used = NULL;
     vq->last_avail_idx = 0;
+    vq->next_avail_head = 0;
     vq->avail_idx = 0;
     vq->last_used_idx = 0;
     vq->signalled_used = 0;
@@ -314,25 +381,27 @@ static void vhost_vq_reset(struct vhost_dev *dev,
     vq->log_used = false;
     vq->log_addr = -1ull;
     vq->private_data = NULL;
-    vq->acked_features = 0;
+    virtio_features_zero(vq->acked_features_array);
     vq->acked_backend_features = 0;
     vq->log_base = NULL;
     vq->error_ctx = NULL;
     vq->kick = NULL;
-    vq->call_ctx = NULL;
     vq->log_ctx = NULL;
-    vhost_reset_is_le(vq);
     vhost_disable_cross_endian(vq);
+    vhost_reset_is_le(vq);
     vq->busyloop_timeout = 0;
     vq->umem = NULL;
     vq->iotlb = NULL;
+    rcu_assign_pointer(vq->worker, NULL);
+    vhost_vring_call_reset(&vq->call_ctx);
     __vhost_vq_meta_reset(vq);
 }
 
-static int vhost_worker(void *data)
+static int vhost_run_work_kthread_list(void *data)
 {
-    struct vhost_dev *dev = data;
+    struct vhost_worker *worker = data;
     struct vhost_work *work, *work_next;
+    struct vhost_dev *dev = worker->dev;
     struct llist_node *node;
 
     kthread_use_mm(dev->mm);
@@ -345,8 +414,7 @@ static int vhost_worker(void *data)
             __set_current_state(TASK_RUNNING);
             break;
         }
-
-        node = llist_del_all(&dev->work_list);
+        node = llist_del_all(&worker->work_list);
         if (!node)
             schedule();
 
@@ -356,17 +424,76 @@ static int vhost_worker(void *data)
         llist_for_each_entry_safe(work, work_next, node, node) {
             clear_bit(VHOST_WORK_QUEUED, &work->flags);
             __set_current_state(TASK_RUNNING);
-            kcov_remote_start_common(dev->kcov_handle);
+            kcov_remote_start_common(worker->kcov_handle);
             work->fn(work);
             kcov_remote_stop();
-            if (need_resched())
-                schedule();
+            cond_resched();
         }
     }
     kthread_unuse_mm(dev->mm);
+
     return 0;
 }
 
+static bool vhost_run_work_list(void *data)
+{
+    struct vhost_worker *worker = data;
+    struct vhost_work *work, *work_next;
+    struct llist_node *node;
+
+    node = llist_del_all(&worker->work_list);
+    if (node) {
+        __set_current_state(TASK_RUNNING);
+
+        node = llist_reverse_order(node);
+        /* make sure flag is seen after deletion */
+        smp_wmb();
+        llist_for_each_entry_safe(work, work_next, node, node) {
+            clear_bit(VHOST_WORK_QUEUED, &work->flags);
+            kcov_remote_start_common(worker->kcov_handle);
+            work->fn(work);
+            kcov_remote_stop();
+            cond_resched();
+        }
+    }
+
+    return !!node;
+}
+
+static void vhost_worker_killed(void *data)
+{
+    struct vhost_worker *worker = data;
+    struct vhost_dev *dev = worker->dev;
+    struct vhost_virtqueue *vq;
+    int i, attach_cnt = 0;
+
+    mutex_lock(&worker->mutex);
+    worker->killed = true;
+
+    for (i = 0; i < dev->nvqs; i++) {
+        vq = dev->vqs[i];
+
+        mutex_lock(&vq->mutex);
+        if (worker ==
+            rcu_dereference_check(vq->worker,
+                      lockdep_is_held(&vq->mutex))) {
+            rcu_assign_pointer(vq->worker, NULL);
+            attach_cnt++;
+        }
+        mutex_unlock(&vq->mutex);
+    }
+
+    worker->attachment_cnt -= attach_cnt;
+    if (attach_cnt)
+        synchronize_rcu();
+    /*
+     * Finish vhost_worker_flush calls and any other works that snuck in
+     * before the synchronize_rcu.
+     */
+    vhost_run_work_list(worker);
+    mutex_unlock(&worker->mutex);
+}
+
 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 {
     kfree(vq->indirect);
@@ -375,6 +502,8 @@ static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
     vq->log = NULL;
     kfree(vq->heads);
     vq->heads = NULL;
+    kfree(vq->nheads);
+    vq->nheads = NULL;
 }
 
 /* Helper to allocate iovec buffers for all vqs. */
@@ -392,7 +521,9 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
                       GFP_KERNEL);
         vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
                       GFP_KERNEL);
-        if (!vq->indirect || !vq->log || !vq->heads)
+        vq->nheads = kmalloc_array(dev->iov_limit, sizeof(*vq->nheads),
+                       GFP_KERNEL);
+        if (!vq->indirect || !vq->log || !vq->heads || !vq->nheads)
             goto err_nomem;
     }
     return 0;
@@ -432,8 +563,7 @@ static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
     size_t event __maybe_unused =
            vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-    return sizeof(*vq->avail) +
-           sizeof(*vq->avail->ring) * num + event;
+    return size_add(struct_size(vq->avail, ring, num), event);
 }
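
The ring sizing switch from open-coded multiplication to struct_size()/size_add() matters because num is guest/user controlled: the kernel helpers saturate at SIZE_MAX on overflow, so a later access_ok() check cannot be fooled by a wrapped-around small value. Portable approximations of the two helpers (simplified sketches under that assumption; the real macros also do type checking):

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Saturating add, like the kernel's size_add(): overflow => SIZE_MAX. */
static size_t size_add_sat(size_t a, size_t b)
{
	size_t r = a + b;

	return r < a ? SIZE_MAX : r;
}

/* Saturating hdr + n * elem, like struct_size(p, ring, n) for a flex array. */
static size_t struct_size_sat(size_t hdr, size_t elem, size_t n)
{
	if (elem && n > (SIZE_MAX - hdr) / elem)
		return SIZE_MAX;
	return hdr + n * elem;
}

int main(void)
{
	/* vring avail layout: 16-bit flags + idx header, then n 16-bit entries */
	size_t hdr = 2 * sizeof(uint16_t), elem = sizeof(uint16_t);

	printf("num=256:  %zu bytes\n",
	       size_add_sat(struct_size_sat(hdr, elem, 256), 2));
	printf("num=huge: %zu bytes (saturated, will fail access_ok)\n",
	       size_add_sat(struct_size_sat(hdr, elem, SIZE_MAX / 2), 2));
	return 0;
}
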
 
 static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
@@ -442,8 +572,7 @@ static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
     size_t event __maybe_unused =
            vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-    return sizeof(*vq->used) +
-           sizeof(*vq->used->ring) * num + event;
+    return size_add(struct_size(vq->used, ring, num), event);
 }
 
 static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
@@ -456,7 +585,7 @@ void vhost_dev_init(struct vhost_dev *dev,
             struct vhost_virtqueue **vqs, int nvqs,
             int iov_limit, int weight, int byte_weight,
             bool use_worker,
-            int (*msg_handler)(struct vhost_dev *dev,
+            int (*msg_handler)(struct vhost_dev *dev, u32 asid,
                        struct vhost_iotlb_msg *msg))
 {
     struct vhost_virtqueue *vq;
@@ -469,30 +598,30 @@ void vhost_dev_init(struct vhost_dev *dev,
     dev->umem = NULL;
     dev->iotlb = NULL;
     dev->mm = NULL;
-    dev->worker = NULL;
     dev->iov_limit = iov_limit;
     dev->weight = weight;
     dev->byte_weight = byte_weight;
     dev->use_worker = use_worker;
     dev->msg_handler = msg_handler;
-    init_llist_head(&dev->work_list);
+    dev->fork_owner = fork_from_owner_default;
     init_waitqueue_head(&dev->wait);
     INIT_LIST_HEAD(&dev->read_list);
     INIT_LIST_HEAD(&dev->pending_list);
     spin_lock_init(&dev->iotlb_lock);
-
+    xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
 
     for (i = 0; i < dev->nvqs; ++i) {
         vq = dev->vqs[i];
         vq->log = NULL;
         vq->indirect = NULL;
         vq->heads = NULL;
+        vq->nheads = NULL;
         vq->dev = dev;
         mutex_init(&vq->mutex);
         vhost_vq_reset(dev, vq);
         if (vq->handle_kick)
             vhost_poll_init(&vq->poll, vq->handle_kick,
-                    EPOLLIN, dev);
+                    EPOLLIN, dev, vq);
     }
 }
 EXPORT_SYMBOL_GPL(vhost_dev_init);
@@ -519,14 +648,29 @@ static void vhost_attach_cgroups_work(struct vhost_work *work)
     s->ret = cgroup_attach_task_all(s->owner, current);
 }
 
-static int vhost_attach_cgroups(struct vhost_dev *dev)
+static int vhost_attach_task_to_cgroups(struct vhost_worker *worker)
 {
     struct vhost_attach_cgroups_struct attach;
+    int saved_cnt;
 
     attach.owner = current;
+
     vhost_work_init(&attach.work, vhost_attach_cgroups_work);
-    vhost_work_queue(dev, &attach.work);
-    vhost_work_flush(dev, &attach.work);
+    vhost_worker_queue(worker, &attach.work);
+
+    mutex_lock(&worker->mutex);
+
+    /*
+     * Bypass attachment_cnt check in __vhost_worker_flush:
+     * Temporarily change it to INT_MAX to bypass the check
+     */
+    saved_cnt = worker->attachment_cnt;
+    worker->attachment_cnt = INT_MAX;
+    __vhost_worker_flush(worker);
+    worker->attachment_cnt = saved_cnt;
+
+    mutex_unlock(&worker->mutex);
+
     return attach.ret;
 }
 
@@ -543,10 +687,10 @@ static void vhost_attach_mm(struct vhost_dev *dev)
     if (dev->use_worker) {
         dev->mm = get_task_mm(current);
     } else {
-        /* vDPA device does not use worker thead, so there's
-         * no need to hold the address space for mm. This help
+        /* vDPA device does not use worker thread, so there's
+         * no need to hold the address space for mm. This helps
          * to avoid deadlock in the case of mmap() which may
-         * held the refcnt of the file and depends on release
+         * hold the refcnt of the file and depends on release
          * method to remove vma.
          */
         dev->mm = current->mm;
@@ -567,11 +711,393 @@ static void vhost_detach_mm(struct vhost_dev *dev)
     dev->mm = NULL;
 }
 
+static void vhost_worker_destroy(struct vhost_dev *dev,
+                 struct vhost_worker *worker)
+{
+    if (!worker)
+        return;
+
+    WARN_ON(!llist_empty(&worker->work_list));
+    xa_erase(&dev->worker_xa, worker->id);
+    worker->ops->stop(worker);
+    kfree(worker);
+}
+
+static void vhost_workers_free(struct vhost_dev *dev)
+{
+    struct vhost_worker *worker;
+    unsigned long i;
+
+    if (!dev->use_worker)
+        return;
+
+    for (i = 0; i < dev->nvqs; i++)
+        rcu_assign_pointer(dev->vqs[i]->worker, NULL);
+    /*
+     * Free the default worker we created and cleanup workers userspace
+     * created but couldn't clean up (it forgot or crashed).
+     */
+    xa_for_each(&dev->worker_xa, i, worker)
+        vhost_worker_destroy(dev, worker);
+    xa_destroy(&dev->worker_xa);
+}
+
+static void vhost_task_wakeup(struct vhost_worker *worker)
+{
+    return vhost_task_wake(worker->vtsk);
+}
+
+static void vhost_kthread_wakeup(struct vhost_worker *worker)
+{
+    wake_up_process(worker->kthread_task);
+}
+
+static void vhost_task_do_stop(struct vhost_worker *worker)
+{
+    return vhost_task_stop(worker->vtsk);
+}
+
+static void vhost_kthread_do_stop(struct vhost_worker *worker)
+{
+    kthread_stop(worker->kthread_task);
+}
+
+static int vhost_task_worker_create(struct vhost_worker *worker,
+                    struct vhost_dev *dev, const char *name)
+{
+    struct vhost_task *vtsk;
+    u32 id;
+    int ret;
+
+    vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+                 worker, name);
+    if (IS_ERR(vtsk))
+        return PTR_ERR(vtsk);
+
+    worker->vtsk = vtsk;
+    vhost_task_start(vtsk);
+    ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+    if (ret < 0) {
+        vhost_task_do_stop(worker);
+        return ret;
+    }
+    worker->id = id;
+    return 0;
+}
+
+static int vhost_kthread_worker_create(struct vhost_worker *worker,
+                       struct vhost_dev *dev,
+                       const char *name)
+{
+    struct task_struct *task;
+    u32 id;
+    int ret;
+
+    task = kthread_create(vhost_run_work_kthread_list, worker, "%s", name);
+    if (IS_ERR(task))
+        return PTR_ERR(task);
+
+    worker->kthread_task = task;
+    wake_up_process(task);
+    ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+    if (ret < 0)
+        goto stop_worker;
+
+    ret = vhost_attach_task_to_cgroups(worker);
+    if (ret)
+        goto free_id;
+
+    worker->id = id;
+    return 0;
+
+free_id:
+    xa_erase(&dev->worker_xa, id);
+stop_worker:
+    vhost_kthread_do_stop(worker);
+    return ret;
+}
+
+static const struct vhost_worker_ops kthread_ops = {
+    .create = vhost_kthread_worker_create,
+    .stop = vhost_kthread_do_stop,
+    .wakeup = vhost_kthread_wakeup,
+};
+
+static const struct vhost_worker_ops vhost_task_ops = {
+    .create = vhost_task_worker_create,
+    .stop = vhost_task_do_stop,
+    .wakeup = vhost_task_wakeup,
+};
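
kthread-backed and vhost_task-backed workers now sit behind a small ops vtable (create/stop/wakeup), so the hot queueing path just calls worker->ops->wakeup() without branching on which backing thread type was chosen at owner-setup time. The shape of that dispatch, reduced to standalone C (the type and function names below are illustrative; only the ops-table layout mirrors the patch):

#include <stdio.h>

struct worker;

struct worker_ops {
	int (*create)(struct worker *w, const char *name);
	void (*stop)(struct worker *w);
	void (*wakeup)(struct worker *w);
};

struct worker {
	const struct worker_ops *ops;
	const char *name;
};

static int kthread_create_op(struct worker *w, const char *name)
{
	w->name = name;
	printf("kthread worker '%s' created\n", name);
	return 0;
}

static void kthread_stop_op(struct worker *w)
{
	printf("kthread '%s' stopped\n", w->name);
}

static void kthread_wakeup_op(struct worker *w)
{
	printf("wake_up_process(%s)\n", w->name);
}

static const struct worker_ops demo_kthread_ops = {
	.create = kthread_create_op,
	.stop = kthread_stop_op,
	.wakeup = kthread_wakeup_op,
};

/* A second backend would fill the same slots with vhost_task-style calls. */

int main(void)
{
	struct worker w = { .ops = &demo_kthread_ops };

	w.ops->create(&w, "vhost-1234");
	w.ops->wakeup(&w);	/* hot path: one indirect call, no mode checks */
	w.ops->stop(&w);
	return 0;
}
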
+
+static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
+{
+    struct vhost_worker *worker;
+    char name[TASK_COMM_LEN];
+    int ret;
+    const struct vhost_worker_ops *ops = dev->fork_owner ?
+                         &vhost_task_ops :
+                         &kthread_ops;
+
+    worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
+    if (!worker)
+        return NULL;
+
+    worker->dev = dev;
+    worker->ops = ops;
+    snprintf(name, sizeof(name), "vhost-%d", current->pid);
+
+    mutex_init(&worker->mutex);
+    init_llist_head(&worker->work_list);
+    worker->kcov_handle = kcov_common_handle();
+    ret = ops->create(worker, dev, name);
+    if (ret < 0)
+        goto free_worker;
+
+    return worker;
+
+free_worker:
+    kfree(worker);
+    return NULL;
+}
+
+/* Caller must have device mutex */
+static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+                     struct vhost_worker *worker)
+{
+    struct vhost_worker *old_worker;
+
+    mutex_lock(&worker->mutex);
+    if (worker->killed) {
+        mutex_unlock(&worker->mutex);
+        return;
+    }
+
+    mutex_lock(&vq->mutex);
+
+    old_worker = rcu_dereference_check(vq->worker,
+                       lockdep_is_held(&vq->mutex));
+    rcu_assign_pointer(vq->worker, worker);
+    worker->attachment_cnt++;
+
+    if (!old_worker) {
+        mutex_unlock(&vq->mutex);
+        mutex_unlock(&worker->mutex);
+        return;
+    }
+    mutex_unlock(&vq->mutex);
+    mutex_unlock(&worker->mutex);
+
+    /*
+     * Take the worker mutex to make sure we see the work queued from
+     * device wide flushes which doesn't use RCU for execution.
+     */
+    mutex_lock(&old_worker->mutex);
+    if (old_worker->killed) {
+        mutex_unlock(&old_worker->mutex);
+        return;
+    }
+
+    /*
+     * We don't want to call synchronize_rcu for every vq during setup
+     * because it will slow down VM startup. If we haven't done
+     * VHOST_SET_VRING_KICK and not done the driver specific
+     * SET_ENDPOINT/RUNNING then we can skip the sync since there will
+     * not be any works queued for scsi and net.
+     */
+    mutex_lock(&vq->mutex);
+    if (!vhost_vq_get_backend(vq) && !vq->kick) {
+        mutex_unlock(&vq->mutex);
+
+        old_worker->attachment_cnt--;
+        mutex_unlock(&old_worker->mutex);
+        /*
+         * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
+         * Warn if it adds support for multiple workers but forgets to
+         * handle the early queueing case.
+         */
+        WARN_ON(!old_worker->attachment_cnt &&
+            !llist_empty(&old_worker->work_list));
+        return;
+    }
+    mutex_unlock(&vq->mutex);
+
+    /* Make sure new vq queue/flush/poll calls see the new worker */
+    synchronize_rcu();
+    /* Make sure whatever was queued gets run */
+    __vhost_worker_flush(old_worker);
+    old_worker->attachment_cnt--;
+    mutex_unlock(&old_worker->mutex);
+}
+
+/* Caller must have device mutex */
+static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+                  struct vhost_vring_worker *info)
+{
+    unsigned long index = info->worker_id;
+    struct vhost_dev *dev = vq->dev;
+    struct vhost_worker *worker;
+
+    if (!dev->use_worker)
+        return -EINVAL;
+
+    worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+    if (!worker || worker->id != info->worker_id)
+        return -ENODEV;
+
+    __vhost_vq_attach_worker(vq, worker);
+    return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_new_worker(struct vhost_dev *dev,
+                struct vhost_worker_state *info)
+{
+    struct vhost_worker *worker;
+
+    worker = vhost_worker_create(dev);
+    if (!worker)
+        return -ENOMEM;
+
+    info->worker_id = worker->id;
+    return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_free_worker(struct vhost_dev *dev,
+                 struct vhost_worker_state *info)
+{
+    unsigned long index = info->worker_id;
+    struct vhost_worker *worker;
+
+    worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+    if (!worker || worker->id != info->worker_id)
+        return -ENODEV;
+
+    mutex_lock(&worker->mutex);
+    if (worker->attachment_cnt || worker->killed) {
+        mutex_unlock(&worker->mutex);
+        return -EBUSY;
+    }
+    /*
+     * A flush might have raced and snuck in before attachment_cnt was set
+     * to zero. Make sure flushes are flushed from the queue before
+     * freeing.
+     */
+    __vhost_worker_flush(worker);
+    mutex_unlock(&worker->mutex);
+
+    vhost_worker_destroy(dev, worker);
+    return 0;
+}
+
+static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
+                  struct vhost_virtqueue **vq, u32 *id)
+{
+    u32 __user *idxp = argp;
+    u32 idx;
+    long r;
+
+    r = get_user(idx, idxp);
+    if (r < 0)
+        return r;
+
+    if (idx >= dev->nvqs)
+        return -ENOBUFS;
+
+    idx = array_index_nospec(idx, dev->nvqs);
+
+    *vq = dev->vqs[idx];
+    *id = idx;
+    return 0;
+}
+
+/* Caller must have device mutex */
+long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
+            void __user *argp)
+{
+    struct vhost_vring_worker ring_worker;
+    struct vhost_worker_state state;
+    struct vhost_worker *worker;
+    struct vhost_virtqueue *vq;
+    long ret;
+    u32 idx;
+
+    if (!dev->use_worker)
+        return -EINVAL;
+
+    if (!vhost_dev_has_owner(dev))
+        return -EINVAL;
+
+    ret = vhost_dev_check_owner(dev);
+    if (ret)
+        return ret;
+
+    switch (ioctl) {
+    /* dev worker ioctls */
+    case VHOST_NEW_WORKER:
+        /*
+         * vhost_tasks will account for worker threads under the parent's
+         * NPROC value but kthreads do not. To avoid userspace overflowing
+         * the system with worker threads fork_owner must be true.
+         */
+        if (!dev->fork_owner)
+            return -EFAULT;
+
+        ret = vhost_new_worker(dev, &state);
+        if (!ret && copy_to_user(argp, &state, sizeof(state)))
+            ret = -EFAULT;
+        return ret;
+    case VHOST_FREE_WORKER:
+        if (copy_from_user(&state, argp, sizeof(state)))
+            return -EFAULT;
+        return vhost_free_worker(dev, &state);
+    /* vring worker ioctls */
+    case VHOST_ATTACH_VRING_WORKER:
+    case VHOST_GET_VRING_WORKER:
+        break;
+    default:
+        return -ENOIOCTLCMD;
+    }
+
+    ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
+    if (ret)
+        return ret;
+
+    switch (ioctl) {
+    case VHOST_ATTACH_VRING_WORKER:
+        if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
+            ret = -EFAULT;
+            break;
+        }
+
+        ret = vhost_vq_attach_worker(vq, &ring_worker);
+        break;
+    case VHOST_GET_VRING_WORKER:
+        worker = rcu_dereference_check(vq->worker,
+                           lockdep_is_held(&dev->mutex));
+        if (!worker) {
+            ret = -EINVAL;
+            break;
+        }
+
+        ring_worker.index = idx;
+        ring_worker.worker_id = worker->id;
+
+        if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
+            ret = -EFAULT;
+        break;
+    default:
+        ret = -ENOIOCTLCMD;
+        break;
+    }
+
+    return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
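
From userspace, the flow this ioctl surface enables is: VHOST_NEW_WORKER allocates a worker and returns its id in struct vhost_worker_state, then VHOST_ATTACH_VRING_WORKER binds that id to a specific vring. A hedged sketch of a caller (error handling trimmed; assumes a vhost device fd on a kernel new enough to ship these ioctls in <linux/vhost.h>, with VHOST_SET_OWNER already done):

#include <stdio.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

/* Give virtqueue 1 of an owned vhost dev fd its own worker thread. */
int give_vq_private_worker(int vhost_fd)
{
	struct vhost_worker_state state = {0};
	struct vhost_vring_worker ring_worker = {0};

	/* Fails (EFAULT per the code above) if the device is in kthread mode. */
	if (ioctl(vhost_fd, VHOST_NEW_WORKER, &state) < 0) {
		perror("VHOST_NEW_WORKER");
		return -1;
	}

	ring_worker.index = 1;			/* the vq to move */
	ring_worker.worker_id = state.worker_id;
	if (ioctl(vhost_fd, VHOST_ATTACH_VRING_WORKER, &ring_worker) < 0) {
		perror("VHOST_ATTACH_VRING_WORKER");
		/* a worker that never got attached must be freed explicitly */
		ioctl(vhost_fd, VHOST_FREE_WORKER, &state);
		return -1;
	}
	return 0;
}

Workers that are still attached cannot be freed (vhost_free_worker() returns -EBUSY), and anything userspace leaks is reaped by vhost_workers_free() on release.
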
+
 /* Caller should have device mutex */
 long vhost_dev_set_owner(struct vhost_dev *dev)
 {
-    struct task_struct *worker;
-    int err;
+    struct vhost_worker *worker;
+    int err, i;
 
     /* Is there an owner already? */
     if (vhost_dev_has_owner(dev)) {
@@ -581,36 +1107,33 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 
     vhost_attach_mm(dev);
 
-    dev->kcov_handle = kcov_common_handle();
+    err = vhost_dev_alloc_iovecs(dev);
+    if (err)
+        goto err_iovecs;
+
     if (dev->use_worker) {
-        worker = kthread_create(vhost_worker, dev,
-                    "vhost-%d", current->pid);
-        if (IS_ERR(worker)) {
-            err = PTR_ERR(worker);
+        /*
+         * This should be done last, because vsock can queue work
+         * before VHOST_SET_OWNER so it simplifies the failure path
+         * below since we don't have to worry about vsock queueing
+         * while we free the worker.
+         */
+        worker = vhost_worker_create(dev);
+        if (!worker) {
+            err = -ENOMEM;
             goto err_worker;
         }
-        dev->worker = worker;
-        wake_up_process(worker); /* avoid contributing to loadavg */
-
-        err = vhost_attach_cgroups(dev);
-        if (err)
-            goto err_cgroup;
+        for (i = 0; i < dev->nvqs; i++)
+            __vhost_vq_attach_worker(dev->vqs[i], worker);
     }
 
-    err = vhost_dev_alloc_iovecs(dev);
-    if (err)
-        goto err_cgroup;
-
     return 0;
 
-err_cgroup:
-    if (dev->worker) {
-        kthread_stop(dev->worker);
-        dev->worker = NULL;
-    }
+err_worker:
+    vhost_dev_free_iovecs(dev);
+err_iovecs:
     vhost_detach_mm(dev);
-    dev->kcov_handle = 0;
 err_mm:
     return err;
 }
@@ -635,6 +1158,7 @@ void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
 
     vhost_dev_cleanup(dev);
 
+    dev->fork_owner = fork_from_owner_default;
    dev->umem = umem;
     /* We don't need VQ locks below since vhost_dev_cleanup makes sure
      * VQs aren't running.
@@ -649,15 +1173,15 @@ void vhost_dev_stop(struct vhost_dev *dev)
     int i;
 
     for (i = 0; i < dev->nvqs; ++i) {
-        if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
+        if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
             vhost_poll_stop(&dev->vqs[i]->poll);
-            vhost_poll_flush(&dev->vqs[i]->poll);
-        }
     }
+
+    vhost_dev_flush(dev);
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
-static void vhost_clear_msg(struct vhost_dev *dev)
+void vhost_clear_msg(struct vhost_dev *dev)
 {
     struct vhost_msg_node *node, *n;
 
@@ -675,6 +1199,7 @@ static void vhost_clear_msg(struct vhost_dev *dev)
 
     spin_unlock(&dev->iotlb_lock);
 }
+EXPORT_SYMBOL_GPL(vhost_clear_msg);
 
 void vhost_dev_cleanup(struct vhost_dev *dev)
 {
@@ -685,8 +1210,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
             eventfd_ctx_put(dev->vqs[i]->error_ctx);
         if (dev->vqs[i]->kick)
             fput(dev->vqs[i]->kick);
-        if (dev->vqs[i]->call_ctx)
-            eventfd_ctx_put(dev->vqs[i]->call_ctx);
+        if (dev->vqs[i]->call_ctx.ctx)
+            eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
         vhost_vq_reset(dev, dev->vqs[i]);
     }
     vhost_dev_free_iovecs(dev);
@@ -700,12 +1225,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
     dev->iotlb = NULL;
     vhost_clear_msg(dev);
     wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
-    WARN_ON(!llist_empty(&dev->work_list));
-    if (dev->worker) {
-        kthread_stop(dev->worker);
-        dev->worker = NULL;
-        dev->kcov_handle = 0;
-    }
+    vhost_workers_free(dev);
     vhost_detach_mm(dev);
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -723,10 +1243,16 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
              (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
 }
 
+/* Make sure 64 bit math will not overflow. */
 static bool vhost_overflow(u64 uaddr, u64 size)
 {
-    /* Make sure 64 bit math will not overflow. */
-    return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
+    if (uaddr > ULONG_MAX || size > ULONG_MAX)
+        return true;
+
+    if (!size)
+        return false;
+
+    return uaddr > ULONG_MAX - size + 1;
 }
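
The rewritten vhost_overflow() closes an off-by-one: a range whose last byte lands exactly on ULONG_MAX is legal, and a zero-size range must not underflow the subtraction. The predicate is easy to exercise in isolation (assuming 64-bit unsigned long, as the copy below does):

#include <limits.h>
#include <stdbool.h>
#include <stdio.h>

typedef unsigned long long u64;

/* Mirrors the patched vhost_overflow() for a 64-bit unsigned long. */
static bool vhost_overflow(u64 uaddr, u64 size)
{
	if (uaddr > ULONG_MAX || size > ULONG_MAX)
		return true;
	if (!size)
		return false;
	/* true iff uaddr + size - 1 would exceed ULONG_MAX */
	return uaddr > ULONG_MAX - size + 1;
}

int main(void)
{
	/* last byte is ULONG_MAX: valid (the old check wrongly rejected this) */
	printf("%d\n", vhost_overflow(ULONG_MAX, 1));	/* 0 */
	/* one byte past the end: overflow */
	printf("%d\n", vhost_overflow(ULONG_MAX, 2));	/* 1 */
	/* empty range never overflows */
	printf("%d\n", vhost_overflow(ULONG_MAX, 0));	/* 0 */
	return 0;
}
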
 
 /* Caller should have vq mutex and device mutex.
  */
@@ -822,7 +1348,7 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
                      VHOST_ACCESS_WO);
         if (ret < 0)
             goto out;
-        iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
+        iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
         ret = copy_to_iter(from, size, &t);
         if (ret == size)
             ret = 0;
@@ -861,7 +1387,7 @@ static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
                    (unsigned long long) size);
             goto out;
         }
-        iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
+        iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
         ret = copy_from_iter(to, size, &f);
         if (ret == size)
             ret = 0;
@@ -997,10 +1523,36 @@ static void vhost_dev_unlock_vqs(struct vhost_dev *d)
         mutex_unlock(&d->vqs[i]->mutex);
 }
 
-static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
-                      __virtio16 *idx)
+static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq)
 {
-    return vhost_get_avail(vq, *idx, &vq->avail->idx);
+    __virtio16 idx;
+    int r;
+
+    r = vhost_get_avail(vq, idx, &vq->avail->idx);
+    if (unlikely(r < 0)) {
+        vq_err(vq, "Failed to access available index at %p (%d)\n",
+               &vq->avail->idx, r);
+        return r;
+    }
+
+    /* Check it isn't doing very strange thing with available indexes */
+    vq->avail_idx = vhost16_to_cpu(vq, idx);
+    if (unlikely((u16)(vq->avail_idx - vq->last_avail_idx) > vq->num)) {
+        vq_err(vq, "Invalid available index change from %u to %u",
+               vq->last_avail_idx, vq->avail_idx);
+        return -EINVAL;
+    }
+
+    /* We're done if there is nothing new */
+    if (vq->avail_idx == vq->last_avail_idx)
+        return 0;
+
+    /*
+     * We updated vq->avail_idx so we need a memory barrier between
+     * the index read above and the caller reading avail ring entries.
+     */
+    smp_rmb();
+    return 1;
 }
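
vhost_get_avail_idx() now centralizes the sanity check: both indexes are free-running 16-bit counters, so the distance (u16)(avail_idx - last_avail_idx) must be at most vq->num even when the counters wrap. A small demonstration of why the u16 cast makes the check wrap-safe:

#include <stdint.h>
#include <stdio.h>

/* Returns 1 if the guest-published avail index is plausible. */
static int avail_idx_ok(uint16_t avail, uint16_t last_avail, uint16_t num)
{
	/* free-running counters: subtract in u16, wraparound is harmless */
	return (uint16_t)(avail - last_avail) <= num;
}

int main(void)
{
	/* 3 new buffers across the 0xffff -> 0x0002 wrap: still valid */
	printf("%d\n", avail_idx_ok(0x0002, 0xffff, 256));	/* 1 */
	/* guest claims ~64K new buffers on a 256-entry ring: bogus */
	printf("%d\n", avail_idx_ok(0xfffe, 0x0002, 256));	/* 0 */
	return 0;
}
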
 
 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
@@ -1072,11 +1624,14 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
     return true;
 }
 
-static int vhost_process_iotlb_msg(struct vhost_dev *dev,
+static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
                    struct vhost_iotlb_msg *msg)
 {
     int ret = 0;
 
+    if (asid != 0)
+        return -EINVAL;
+
     mutex_lock(&dev->mutex);
     vhost_dev_lock_vqs(dev);
     switch (msg->type) {
@@ -1123,6 +1678,7 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
     struct vhost_iotlb_msg msg;
     size_t offset;
     int type, ret;
+    u32 asid = 0;
 
     ret = copy_from_iter(&type, sizeof(type), from);
     if (ret != sizeof(type)) {
@@ -1138,7 +1694,16 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
         offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
         break;
     case VHOST_IOTLB_MSG_V2:
-        offset = sizeof(__u32);
+        if (vhost_backend_has_feature(dev->vqs[0],
+                          VHOST_BACKEND_F_IOTLB_ASID)) {
+            ret = copy_from_iter(&asid, sizeof(asid), from);
+            if (ret != sizeof(asid)) {
+                ret = -EINVAL;
+                goto done;
+            }
+            offset = 0;
+        } else
+            offset = sizeof(__u32);
         break;
     default:
         ret = -EINVAL;
@@ -1152,10 +1717,15 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
         goto done;
     }
 
+    if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
+        ret = -EINVAL;
+        goto done;
+    }
+
     if (dev->msg_handler)
-        ret = dev->msg_handler(dev, &msg);
+        ret = dev->msg_handler(dev, asid, &msg);
     else
-        ret = vhost_process_iotlb_msg(dev, &msg);
+        ret = vhost_process_iotlb_msg(dev, asid, &msg);
     if (ret) {
         ret = -EFAULT;
         goto done;
@@ -1283,6 +1853,11 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
              vring_used_t __user *used)
 
 {
+    /* If an IOTLB device is present, the vring addresses are
+     * GIOVAs. Access validation occurs at prefetch time.
+     */
+    if (vq->iotlb)
+        return true;
+
     return access_ok(desc, vhost_get_desc_size(vq, num)) &&
            access_ok(avail, vhost_get_avail_size(vq, num)) &&
            access_ok(used, vhost_get_used_size(vq, num));
@@ -1358,6 +1933,20 @@ bool vhost_log_access_ok(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
 
+static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
+                  void __user *log_base,
+                  bool log_used,
+                  u64 log_addr)
+{
+    /* If an IOTLB device is present, log_addr is a GIOVA that
+     * will never be logged by log_used(). */
+    if (vq->iotlb)
+        return true;
+
+    return !log_used || log_access_ok(log_base, log_addr,
+                      vhost_get_used_size(vq, vq->num));
+}
+
 /* Verify access for write logging. */
 /* Caller should have vq mutex and device mutex */
 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
@@ -1365,8 +1954,7 @@ static bool vq_log_access_ok(struct vhost_virtqueue *vq,
 {
     return vq_memory_access_ok(log_base, vq->umem,
                    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
-        (!vq->log_used || log_access_ok(log_base, vq->log_addr,
-                  vhost_get_used_size(vq, vq->num)));
+        vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
 }
 
 /* Can we start vq? */
@@ -1376,10 +1964,6 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
     if (!vq_log_access_ok(vq, vq->log_base))
         return false;
 
-    /* Access validation occurs at prefetch time with IOTLB */
-    if (vq->iotlb)
-        return true;
-
     return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
@@ -1405,7 +1989,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
     memcpy(newmem, &mem, size);
     if (copy_from_user(newmem->regions, m->regions,
-               mem.nregions * sizeof *m->regions)) {
+               flex_array_size(newmem, regions, mem.nregions))) {
         kvfree(newmem);
         return -EFAULT;
     }
@@ -1509,10 +2093,9 @@ static long vhost_vring_set_addr(struct vhost_dev *d,
             return -EINVAL;
 
         /* Also validate log access for used ring if enabled. */
-        if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
-            !log_access_ok(vq->log_base, a.log_guest_addr,
-                   sizeof *vq->used +
-                   vq->num * sizeof *vq->used->ring))
+        if (!vq_log_used_access_ok(vq, vq->log_base,
+                       a.flags & (0x1 << VHOST_VRING_F_LOG),
+                       a.log_guest_addr))
             return -EINVAL;
     }
 
@@ -1554,21 +2137,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
     struct file *eventfp, *filep = NULL;
     bool pollstart = false, pollstop = false;
     struct eventfd_ctx *ctx = NULL;
-    u32 __user *idxp = argp;
     struct vhost_virtqueue *vq;
     struct vhost_vring_state s;
     struct vhost_vring_file f;
     u32 idx;
     long r;
 
-    r = get_user(idx, idxp);
+    r = vhost_get_vq_from_user(d, argp, &vq, &idx);
     if (r < 0)
         return r;
-    if (idx >= d->nvqs)
-        return -ENOBUFS;
-
-    idx = array_index_nospec(idx, d->nvqs);
-    vq = d->vqs[idx];
 
     if (ioctl == VHOST_SET_VRING_NUM ||
         ioctl == VHOST_SET_VRING_ADDR) {
@@ -1589,17 +2166,26 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
             r = -EFAULT;
             break;
         }
-        if (s.num > 0xffff) {
-            r = -EINVAL;
-            break;
+        if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
+            vq->next_avail_head = vq->last_avail_idx =
+                s.num & 0xffff;
+            vq->last_used_idx = (s.num >> 16) & 0xffff;
+        } else {
+            if (s.num > 0xffff) {
+                r = -EINVAL;
+                break;
+            }
+            vq->next_avail_head = vq->last_avail_idx = s.num;
         }
-        vq->last_avail_idx = s.num;
         /* Forget the cached index value.
          */
         vq->avail_idx = vq->last_avail_idx;
         break;
     case VHOST_GET_VRING_BASE:
         s.index = idx;
-        s.num = vq->last_avail_idx;
+        if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+            s.num = (u32)vq->last_avail_idx |
+                ((u32)vq->last_used_idx << 16);
+        else
+            s.num = vq->last_avail_idx;
         if (copy_to_user(argp, &s, sizeof s))
             r = -EFAULT;
         break;
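
Packed virtqueues keep no used index in guest memory that the kernel could re-derive state from, so VHOST_SET/GET_VRING_BASE now multiplex both counters through the single 32-bit vhost_vring_state.num field: last_avail_idx in bits 0-15 and last_used_idx in bits 16-31. Pack/unpack helpers a VMM might use during save/restore (illustrative; only the bit layout comes from the patch):

#include <stdint.h>
#include <stdio.h>

/* VHOST_SET_VRING_BASE payload for a packed ring: avail state in the
 * low half, used state in the high half of the 32-bit num field. */
static uint32_t pack_vring_base(uint16_t last_avail, uint16_t last_used)
{
	return (uint32_t)last_avail | ((uint32_t)last_used << 16);
}

static void unpack_vring_base(uint32_t num, uint16_t *last_avail,
			      uint16_t *last_used)
{
	*last_avail = num & 0xffff;
	*last_used = (num >> 16) & 0xffff;
}

int main(void)
{
	uint16_t avail, used;
	uint32_t num = pack_vring_base(0x8003, 0x8001);

	unpack_vring_base(num, &avail, &used);
	printf("num=%#x avail=%#x used=%#x\n", num, avail, used);
	return 0;
}
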
@@ -1629,7 +2215,8 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
             r = PTR_ERR(ctx);
             break;
         }
-        swap(ctx, vq->call_ctx);
+
+        swap(ctx, vq->call_ctx.ctx);
         break;
     case VHOST_SET_VRING_ERR:
         if (copy_from_user(&f, argp, sizeof f)) {
@@ -1680,12 +2267,12 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
     mutex_unlock(&vq->mutex);
 
     if (pollstop && vq->handle_kick)
-        vhost_poll_flush(&vq->poll);
+        vhost_dev_flush(vq->poll.dev);
     return r;
 }
 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
 
-int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
+int vhost_init_device_iotlb(struct vhost_dev *d)
 {
     struct vhost_iotlb *niotlb, *oiotlb;
     int i;
@@ -1726,6 +2313,45 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
         goto done;
     }
 
+#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL
+    if (ioctl == VHOST_SET_FORK_FROM_OWNER) {
+        /* Only allow modification before owner is set */
+        if (vhost_dev_has_owner(d)) {
+            r = -EBUSY;
+            goto done;
+        }
+        u8 fork_owner_val;
+
+        if (get_user(fork_owner_val, (u8 __user *)argp)) {
+            r = -EFAULT;
+            goto done;
+        }
+        if (fork_owner_val != VHOST_FORK_OWNER_TASK &&
+            fork_owner_val != VHOST_FORK_OWNER_KTHREAD) {
+            r = -EINVAL;
+            goto done;
+        }
+        d->fork_owner = !!fork_owner_val;
+        r = 0;
+        goto done;
+    }
+    if (ioctl == VHOST_GET_FORK_FROM_OWNER) {
+        u8 fork_owner_val = d->fork_owner;
+
+        if (fork_owner_val != VHOST_FORK_OWNER_TASK &&
+            fork_owner_val != VHOST_FORK_OWNER_KTHREAD) {
+            r = -EINVAL;
+            goto done;
+        }
+        if (put_user(fork_owner_val, (u8 __user *)argp)) {
+            r = -EFAULT;
+            goto done;
+        }
+        r = 0;
+        goto done;
+    }
+#endif
+
     /* You must be the owner to do anything else */
     r = vhost_dev_check_owner(d);
     if (r)
@@ -1786,7 +2412,7 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
 
 /* TODO: This is really inefficient.  We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
- * returning -EFAULT). See Documentation/x86/exception-tables.rst.
+ * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
  */
 static int set_bit_to_user(int nr, void __user *addr)
 {
@@ -1874,7 +2500,7 @@ static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
 
 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
 {
-    struct iovec iov[64];
+    struct iovec *iov = vq->log_iov;
     int i, ret;
 
     if (!vq->iotlb)
@@ -1895,6 +2521,19 @@ static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
     return 0;
 }
 
+/*
+ * vhost_log_write() - Log in dirty page bitmap
+ * @vq:      vhost virtqueue.
+ * @log:     Array of dirty memory in GPA.
+ * @log_num: Size of vhost_log arrary.
+ * @len:     The total length of memory buffer to log in the dirty bitmap.
+ *           Some drivers may only partially use pages shared via the last
+ *           vring descriptor (i.e. vhost-net RX buffer).
+ *           Use (len == U64_MAX) to indicate the driver would log all
+ *           pages of vring descriptors.
+ * @iov:     Array of dirty memory in HVA.
+ * @count:   Size of iovec array.
+ */
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
             unsigned int log_num, u64 len, struct iovec *iov, int count)
 {
@@ -1918,15 +2557,14 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
             r = log_write(vq->log_base, log[i].addr, l);
             if (r < 0)
                 return r;
-            len -= l;
-            if (!len) {
-                if (vq->log_ctx)
-                    eventfd_signal(vq->log_ctx, 1);
-                return 0;
-            }
+
+            if (len != U64_MAX)
+                len -= l;
         }
-        /* Length written exceeds what we have stored. This is a bug. */
-        BUG();
+
+    if (vq->log_ctx)
+        eventfd_signal(vq->log_ctx);
+
     return 0;
 }
 EXPORT_SYMBOL_GPL(vhost_log_write);
@@ -1944,12 +2582,12 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
         log_used(vq, (used - (void __user *)vq->used),
              sizeof vq->used->flags);
         if (vq->log_ctx)
-            eventfd_signal(vq->log_ctx, 1);
+            eventfd_signal(vq->log_ctx);
     }
     return 0;
 }
 
-static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
+static int vhost_update_avail_event(struct vhost_virtqueue *vq)
 {
     if (vhost_put_avail_event(vq))
         return -EFAULT;
@@ -1962,7 +2600,7 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq)
         log_used(vq, (used - (void __user *)vq->used),
              sizeof *vhost_avail_event(vq));
         if (vq->log_ctx)
-            eventfd_signal(vq->log_ctx, 1);
+            eventfd_signal(vq->log_ctx);
     }
     return 0;
 }
@@ -2009,7 +2647,7 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
     struct vhost_dev *dev = vq->dev;
     struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
     struct iovec *_iov;
-    u64 s = 0;
+    u64 s = 0, last = addr + len - 1;
     int ret = 0;
 
     while ((u64)len > s) {
@@ -2019,7 +2657,7 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
             break;
         }
 
-        map = vhost_iotlb_itree_first(umem, addr, addr + len - 1);
+        map = vhost_iotlb_itree_first(umem, addr, last);
         if (map == NULL || map->start > addr) {
             if (umem != dev->iotlb) {
                 ret = -EFAULT;
@@ -2091,12 +2729,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
         vq_err(vq, "Translation failure %d in indirect.\n", ret);
         return ret;
     }
-    iov_iter_init(&from, READ, vq->indirect, ret, len);
-
-    /* We will use the result as an address to read from, so most
-     * architectures only need a compiler barrier here. */
-    read_barrier_depends();
-
+    iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
     count = len / sizeof desc;
     /* Buffers are chained via a 16 bit next field, so
      * we can have at most 2^16 of these. */
@@ -2161,66 +2794,66 @@ static int get_indirect(struct vhost_virtqueue *vq,
     return 0;
 }
 
-/* This looks in the virtqueue and for the first available buffer, and converts
- * it to an iovec for convenient access.  Since descriptors consist of some
- * number of output then some number of input descriptors, it's actually two
- * iovecs, but we pack them into one and note how many of each there were.
+/**
+ * vhost_get_vq_desc_n - Fetch the next available descriptor chain and build iovecs
+ * @vq: target virtqueue
+ * @iov: array that receives the scatter/gather segments
+ * @iov_size: capacity of @iov in elements
+ * @out_num: the number of output segments
+ * @in_num: the number of input segments
+ * @log: optional array to record addr/len for each writable segment; NULL if unused
+ * @log_num: optional output; number of entries written to @log when provided
+ * @ndesc: optional output; number of descriptors consumed from the available ring
+ *         (useful for rollback via vhost_discard_vq_desc)
  *
- * This function returns the descriptor number found, or vq->num (which is
- * never a valid descriptor number) if none was found.  A negative code is
- * returned on error. */
-int vhost_get_vq_desc(struct vhost_virtqueue *vq,
-              struct iovec iov[], unsigned int iov_size,
-              unsigned int *out_num, unsigned int *in_num,
-              struct vhost_log *log, unsigned int *log_num)
+ * Extracts one available descriptor chain from @vq and translates guest
+ * addresses into host iovecs.
+ *
+ * On success, advances @vq->last_avail_idx by 1 and @vq->next_avail_head by
+ * the number of descriptors consumed (also stored via @ndesc when non-NULL).
+ *
+ * Return:
+ * - head index in [0, @vq->num) on success;
+ * - @vq->num if no descriptor is currently available;
+ * - negative errno on failure
+ */
+int vhost_get_vq_desc_n(struct vhost_virtqueue *vq,
+            struct iovec iov[], unsigned int iov_size,
+            unsigned int *out_num, unsigned int *in_num,
+            struct vhost_log *log, unsigned int *log_num,
+            unsigned int *ndesc)
 {
+    bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
     struct vring_desc desc;
     unsigned int i, head, found = 0;
-    u16 last_avail_idx;
-    __virtio16 avail_idx;
+    u16 last_avail_idx = vq->last_avail_idx;
    __virtio16 ring_head;
-    int ret, access;
-
-    /* Check it isn't doing very strange things with descriptor numbers. */
-    last_avail_idx = vq->last_avail_idx;
+    int ret, access, c = 0;
 
     if (vq->avail_idx == vq->last_avail_idx) {
-        if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
-            vq_err(vq, "Failed to access avail idx at %p\n",
-                &vq->avail->idx);
-            return -EFAULT;
-        }
-        vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
-
-        if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
-            vq_err(vq, "Guest moved used index from %u to %u",
-                last_avail_idx, vq->avail_idx);
-            return -EFAULT;
-        }
+        ret = vhost_get_avail_idx(vq);
+        if (unlikely(ret < 0))
+            return ret;
 
-        /* If there's nothing new since last we looked, return
-         * invalid.
-         */
-        if (vq->avail_idx == last_avail_idx)
+        if (!ret)
             return vq->num;
-
-        /* Only get avail ring entries after they have been
-         * exposed by guest.
-         */
-        smp_rmb();
     }
 
-    /* Grab the next descriptor number they're advertising, and increment
-     * the index we've seen. */
-    if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
-        vq_err(vq, "Failed to read head: idx %d address %p\n",
-               last_avail_idx,
-               &vq->avail->ring[last_avail_idx % vq->num]);
-        return -EFAULT;
+    if (in_order)
+        head = vq->next_avail_head & (vq->num - 1);
+    else {
+        /* Grab the next descriptor number they're
+         * advertising, and increment the index we've seen.
+         */
+        if (unlikely(vhost_get_avail_head(vq, &ring_head,
+                          last_avail_idx))) {
+            vq_err(vq, "Failed to read head: idx %d address %p\n",
+                   last_avail_idx,
+                   &vq->avail->ring[last_avail_idx % vq->num]);
+            return -EFAULT;
        }
+        head = vhost16_to_cpu(vq, ring_head);
     }
-    head = vhost16_to_cpu(vq, ring_head);
 
     /* If their number is silly, that's an error. */
     if (unlikely(head >= vq->num)) {
         vq_err(vq, "Guest says index %u > %u is available",
@@ -2263,6 +2896,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
                        "in indirect descriptor at idx %d\n", i);
                 return ret;
             }
+            ++c;
             continue;
         }
 
@@ -2298,22 +2932,56 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
             }
             *out_num += ret;
         }
+        ++c;
     } while ((i = next_desc(vq, &desc)) != -1);
 
     /* On success, increment avail index. */
     vq->last_avail_idx++;
+    vq->next_avail_head += c;
+
+    if (ndesc)
+        *ndesc = c;
 
     /* Assume notifications from guest are disabled at this point,
      * if they aren't we would need to update avail_event index. */
     BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
     return head;
 }
+EXPORT_SYMBOL_GPL(vhost_get_vq_desc_n);
+
+/* This looks in the virtqueue and for the first available buffer, and converts
+ * it to an iovec for convenient access.  Since descriptors consist of some
+ * number of output then some number of input descriptors, it's actually two
+ * iovecs, but we pack them into one and note how many of each there were.
+ *
+ * This function returns the descriptor number found, or vq->num (which is
+ * never a valid descriptor number) if none was found.  A negative code is
+ * returned on error.
+ */
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
+              struct iovec iov[], unsigned int iov_size,
+              unsigned int *out_num, unsigned int *in_num,
+              struct vhost_log *log, unsigned int *log_num)
+{
+    return vhost_get_vq_desc_n(vq, iov, iov_size, out_num, in_num,
+                   log, log_num, NULL);
+}
 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
 
-/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
-void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
+/**
+ * vhost_discard_vq_desc - Reverse the effect of vhost_get_vq_desc_n()
+ * @vq: target virtqueue
+ * @nbufs: number of buffers to roll back
+ * @ndesc: number of descriptors to roll back
+ *
+ * Rewinds the internal consumer cursors after a failed attempt to use
+ * buffers returned by vhost_get_vq_desc_n().
+ */
+void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int nbufs,
+               unsigned int ndesc)
 {
-    vq->last_avail_idx -= n;
+    vq->next_avail_head -= ndesc;
+    vq->last_avail_idx -= nbufs;
 }
 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
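
Callers that can fail mid-operation are expected to remember ndesc: with VIRTIO_F_IN_ORDER, rewinding only last_avail_idx would leave next_avail_head pointing past descriptors that were never really consumed. A hedged kernel-style sketch of the intended call pattern (backend_process() is a made-up placeholder, not from the patch; the vq->iov buffer mirrors how net/scsi call the old API):

/* Illustrative only: consume one buffer, roll back cleanly on failure. */
static int handle_one_buffer(struct vhost_virtqueue *vq)
{
	unsigned int out, in, ndesc;
	int head;

	head = vhost_get_vq_desc_n(vq, vq->iov, ARRAY_SIZE(vq->iov),
				   &out, &in, NULL, NULL, &ndesc);
	if (head < 0)
		return head;		/* translation or ring error */
	if (head == vq->num)
		return -EAGAIN;		/* nothing available right now */

	if (backend_process(vq, head, out, in) < 0) {
		/* undo exactly one buffer built from ndesc descriptors */
		vhost_discard_vq_desc(vq, 1, ndesc);
		return -EIO;
	}
	return 0;
}
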
 
@@ -2325,8 +2993,9 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
         cpu_to_vhost32(vq, head),
         cpu_to_vhost32(vq, len)
     };
+    u16 nheads = 1;
 
-    return vhost_add_used_n(vq, &heads, 1);
+    return vhost_add_used_n(vq, &heads, &nheads, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_add_used);
 
@@ -2362,10 +3031,9 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
     return 0;
 }
 
-/* After we've used one of their buffers, we tell them about it.  We'll then
- * want to notify the guest, using eventfd. */
-int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
-             unsigned count)
+static int vhost_add_used_n_ooo(struct vhost_virtqueue *vq,
+                struct vring_used_elem *heads,
+                unsigned count)
 {
     int start, n, r;
 
@@ -2378,7 +3046,72 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
         heads += n;
         count -= n;
     }
-    r = __vhost_add_used_n(vq, heads, count);
+    return __vhost_add_used_n(vq, heads, count);
+}
+
+static int vhost_add_used_n_in_order(struct vhost_virtqueue *vq,
+                     struct vring_used_elem *heads,
+                     const u16 *nheads,
+                     unsigned count)
+{
+    vring_used_elem_t __user *used;
+    u16 old, new = vq->last_used_idx;
+    int start, i;
+
+    if (!nheads)
+        return -EINVAL;
+
+    start = vq->last_used_idx & (vq->num - 1);
+    used = vq->used->ring + start;
+
+    for (i = 0; i < count; i++) {
+        if (vhost_put_used(vq, &heads[i], start, 1)) {
+            vq_err(vq, "Failed to write used");
+            return -EFAULT;
+        }
+        start += nheads[i];
+        new += nheads[i];
+        if (start >= vq->num)
+            start -= vq->num;
+    }
+
+    if (unlikely(vq->log_used)) {
+        /* Make sure data is seen before log. */
+        smp_wmb();
+        /* Log used ring entry write. */
+        log_used(vq, ((void __user *)used - (void __user *)vq->used),
+             (vq->num - start) * sizeof *used);
+        if (start + count > vq->num)
+            log_used(vq, 0,
+                 (start + count - vq->num) * sizeof *used);
+    }
+
+    old = vq->last_used_idx;
+    vq->last_used_idx = new;
+    /* If the driver never bothers to signal in a very long while,
+     * used index might wrap around. If that happens, invalidate
+     * signalled_used index we stored. TODO: make sure driver
+     * signals at least once in 2^16 and remove this. */
+    if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
+        vq->signalled_used_valid = false;
+    return 0;
+}
+
+/* After we've used one of their buffers, we tell them about it.  We'll then
+ * want to notify the guest, using eventfd. */
+int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
+             u16 *nheads, unsigned count)
+{
+    bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
+    int r;
+
+    if (!in_order || !nheads)
+        r = vhost_add_used_n_ooo(vq, heads, count);
+    else
+        r = vhost_add_used_n_in_order(vq, heads, nheads, count);
+
+    if (r < 0)
+        return r;
 
     /* Make sure buffer is written before we update index. */
     smp_wmb();
@@ -2393,7 +3126,7 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
         log_used(vq, offsetof(struct vring_used, idx),
              sizeof vq->used->idx);
         if (vq->log_ctx)
-            eventfd_signal(vq->log_ctx, 1);
+            eventfd_signal(vq->log_ctx);
     }
     return r;
 }
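
Under VIRTIO_F_IN_ORDER a single used element can retire a whole descriptor batch, so vhost_add_used_n_in_order() advances the used index by nheads[i] per element rather than by one. The counter arithmetic in isolation, including the intended u16 wraparound:

#include <stdint.h>
#include <stdio.h>

/* Advance a used index for 'count' in-order buffers, where buffer i
 * retires nheads[i] descriptors (mirrors the counter math of
 * vhost_add_used_n_in_order; the actual ring writes are omitted). */
static uint16_t advance_used_idx(uint16_t last_used, const uint16_t *nheads,
				 int count)
{
	int i;

	for (i = 0; i < count; i++)
		last_used += nheads[i];	/* u16 wraparound is intended */
	return last_used;
}

int main(void)
{
	/* three buffers built from 1, 4 and 2 descriptors respectively */
	uint16_t nheads[] = { 1, 4, 2 };
	uint16_t new_idx = advance_used_idx(0xfffe, nheads, 3);

	printf("last_used_idx: 0xfffe -> %#x (wrapped)\n", new_idx);
	return 0;
}
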
 
@@ -2440,8 +3173,8 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
     /* Signal the Guest tell them we used something up.
      */
-    if (vq->call_ctx && vhost_notify(dev, vq))
-        eventfd_signal(vq->call_ctx, 1);
+    if (vq->call_ctx.ctx && vhost_notify(dev, vq))
+        eventfd_signal(vq->call_ctx.ctx);
 }
 EXPORT_SYMBOL_GPL(vhost_signal);
 
@@ -2458,35 +3191,33 @@ EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
 /* multi-buffer version of vhost_add_used_and_signal */
 void vhost_add_used_and_signal_n(struct vhost_dev *dev,
                  struct vhost_virtqueue *vq,
-                 struct vring_used_elem *heads, unsigned count)
+                 struct vring_used_elem *heads,
+                 u16 *nheads,
+                 unsigned count)
 {
-    vhost_add_used_n(vq, heads, count);
+    vhost_add_used_n(vq, heads, nheads, count);
     vhost_signal(dev, vq);
 }
 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
 
-/* return true if we're sure that avaiable ring is empty */
+/* return true if we're sure that available ring is empty */
 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
-    __virtio16 avail_idx;
     int r;
 
     if (vq->avail_idx != vq->last_avail_idx)
         return false;
 
-    r = vhost_get_avail_idx(vq, &avail_idx);
-    if (unlikely(r))
-        return false;
-    vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
+    r = vhost_get_avail_idx(vq);
 
-    return vq->avail_idx == vq->last_avail_idx;
+    /* Note: we treat error as non-empty here */
+    return r == 0;
 }
 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
 
 /* OK, now we need to know about added descriptors. */
 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
-    __virtio16 avail_idx;
     int r;
 
     if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
@@ -2500,7 +3231,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
             return false;
         }
     } else {
-        r = vhost_update_avail_event(vq, vq->avail_idx);
+        r = vhost_update_avail_event(vq);
         if (r) {
             vq_err(vq, "Failed to update avail event index at %p: %d\n",
                    vhost_avail_event(vq), r);
@@ -2510,14 +3241,13 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
     /* They could have slipped one in as we were doing that: make
      * sure it's written, then check again. */
     smp_mb();
-    r = vhost_get_avail_idx(vq, &avail_idx);
-    if (r) {
-        vq_err(vq, "Failed to check avail idx at %p: %d\n",
-               &vq->avail->idx, r);
+
+    r = vhost_get_avail_idx(vq);
+    /* Note: we treat error as empty here */
+    if (unlikely(r < 0))
         return false;
-    }
 
-    return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
+    return r;
 }
 EXPORT_SYMBOL_GPL(vhost_enable_notify);
 
@@ -2532,7 +3262,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
     if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
         r = vhost_update_used_flags(vq);
         if (r)
-            vq_err(vq, "Failed to enable notification at %p: %d\n",
+            vq_err(vq, "Failed to disable notification at %p: %d\n",
                    &vq->used->flags, r);
     }
 }
@@ -2541,12 +3271,11 @@ EXPORT_SYMBOL_GPL(vhost_disable_notify);
 /* Create a new message. */
 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
 {
-    struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
+    /* Make sure all padding within the structure is initialized. */
+    struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
     if (!node)
         return NULL;
 
-    /* Make sure all padding within the structure is initialized.
-     */
-    memset(&node->msg, 0, sizeof node->msg);
     node->vq = vq;
     node->msg.type = type;
     return node;
 }
@@ -2581,6 +3310,21 @@ struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
 }
 EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
 
+void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
+{
+    struct vhost_virtqueue *vq;
+    int i;
+
+    mutex_lock(&dev->mutex);
+    for (i = 0; i < dev->nvqs; ++i) {
+        vq = dev->vqs[i];
+        mutex_lock(&vq->mutex);
+        vq->acked_backend_features = features;
+        mutex_unlock(&vq->mutex);
+    }
+    mutex_unlock(&dev->mutex);
+}
+EXPORT_SYMBOL_GPL(vhost_set_backend_features);
 
 static int __init vhost_init(void)
 {
