Diffstat (limited to 'drivers/vhost/vhost.c'):
 -rw-r--r--  drivers/vhost/vhost.c | 705
 1 file changed, 555 insertions(+), 150 deletions(-)
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c71d573f1c94..bccdc9eab267 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -22,6 +22,7 @@
 #include <linux/slab.h>
 #include <linux/vmalloc.h>
 #include <linux/kthread.h>
+#include <linux/cgroup.h>
 #include <linux/module.h>
 #include <linux/sort.h>
 #include <linux/sched/mm.h>
@@ -41,6 +42,13 @@ static int max_iotlb_entries = 2048;
 module_param(max_iotlb_entries, int, 0444);
 MODULE_PARM_DESC(max_iotlb_entries,
 	"Maximum number of iotlb entries. (default: 2048)");
+static bool fork_from_owner_default = VHOST_FORK_OWNER_TASK;
+
+#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL
+module_param(fork_from_owner_default, bool, 0444);
+MODULE_PARM_DESC(fork_from_owner_default,
+		 "Set task mode as the default(default: Y)");
+#endif
 
 enum {
 	VHOST_MEMORY_F_LOG = 0x1,
@@ -242,7 +250,7 @@ static void vhost_worker_queue(struct vhost_worker *worker,
 		 * test_and_set_bit() implies a memory barrier.
 		 */
 		llist_add(&work->node, &worker->work_list);
-		vhost_task_wake(worker->vtsk);
+		worker->ops->wakeup(worker);
 	}
 }
 
@@ -263,34 +271,37 @@ bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
 }
 EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
 
-void vhost_vq_flush(struct vhost_virtqueue *vq)
-{
-	struct vhost_flush_struct flush;
-
-	init_completion(&flush.wait_event);
-	vhost_work_init(&flush.work, vhost_flush_work);
-
-	if (vhost_vq_work_queue(vq, &flush.work))
-		wait_for_completion(&flush.wait_event);
-}
-EXPORT_SYMBOL_GPL(vhost_vq_flush);
-
 /**
- * vhost_worker_flush - flush a worker
+ * __vhost_worker_flush - flush a worker
  * @worker: worker to flush
  *
- * This does not use RCU to protect the worker, so the device or worker
- * mutex must be held.
+ * The worker's flush_mutex must be held.
  */
-static void vhost_worker_flush(struct vhost_worker *worker)
+static void __vhost_worker_flush(struct vhost_worker *worker)
 {
 	struct vhost_flush_struct flush;
 
+	if (!worker->attachment_cnt || worker->killed)
+		return;
+
 	init_completion(&flush.wait_event);
 	vhost_work_init(&flush.work, vhost_flush_work);
 
 	vhost_worker_queue(worker, &flush.work);
+	/*
+	 * Drop mutex in case our worker is killed and it needs to take the
+	 * mutex to force cleanup.
+	 */
+	mutex_unlock(&worker->mutex);
 	wait_for_completion(&flush.wait_event);
+	mutex_lock(&worker->mutex);
+}
+
+static void vhost_worker_flush(struct vhost_worker *worker)
+{
+	mutex_lock(&worker->mutex);
+	__vhost_worker_flush(worker);
+	mutex_unlock(&worker->mutex);
 }
 
 void vhost_dev_flush(struct vhost_dev *dev)
@@ -298,15 +309,8 @@ void vhost_dev_flush(struct vhost_dev *dev)
 	struct vhost_worker *worker;
 	unsigned long i;
 
-	xa_for_each(&dev->worker_xa, i, worker) {
-		mutex_lock(&worker->mutex);
-		if (!worker->attachment_cnt) {
-			mutex_unlock(&worker->mutex);
-			continue;
-		}
+	xa_for_each(&dev->worker_xa, i, worker)
 		vhost_worker_flush(worker);
-		mutex_unlock(&worker->mutex);
-	}
 }
 EXPORT_SYMBOL_GPL(vhost_dev_flush);
 
@@ -368,6 +372,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->avail = NULL;
 	vq->used = NULL;
 	vq->last_avail_idx = 0;
+	vq->next_avail_head = 0;
 	vq->avail_idx = 0;
 	vq->last_used_idx = 0;
 	vq->signalled_used = 0;
@@ -376,7 +381,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->log_used = false;
 	vq->log_addr = -1ull;
 	vq->private_data = NULL;
-	vq->acked_features = 0;
+	virtio_features_zero(vq->acked_features_array);
 	vq->acked_backend_features = 0;
 	vq->log_base = NULL;
 	vq->error_ctx = NULL;
@@ -392,7 +397,45 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	__vhost_vq_meta_reset(vq);
 }
 
-static bool vhost_worker(void *data)
+static int vhost_run_work_kthread_list(void *data)
+{
+	struct vhost_worker *worker = data;
+	struct vhost_work *work, *work_next;
+	struct vhost_dev *dev = worker->dev;
+	struct llist_node *node;
+
+	kthread_use_mm(dev->mm);
+
+	for (;;) {
+		/* mb paired w/ kthread_stop */
+		set_current_state(TASK_INTERRUPTIBLE);
+
+		if (kthread_should_stop()) {
+			__set_current_state(TASK_RUNNING);
+			break;
+		}
+		node = llist_del_all(&worker->work_list);
+		if (!node)
+			schedule();
+
+		node = llist_reverse_order(node);
+		/* make sure flag is seen after deletion */
+		smp_wmb();
+		llist_for_each_entry_safe(work, work_next, node, node) {
+			clear_bit(VHOST_WORK_QUEUED, &work->flags);
+			__set_current_state(TASK_RUNNING);
+			kcov_remote_start_common(worker->kcov_handle);
+			work->fn(work);
+			kcov_remote_stop();
+			cond_resched();
+		}
+	}
+	kthread_unuse_mm(dev->mm);
+
+	return 0;
+}
+
+static bool vhost_run_work_list(void *data)
 {
 	struct vhost_worker *worker = data;
 	struct vhost_work *work, *work_next;
@@ -417,6 +460,40 @@ static bool vhost_worker(void *data)
 	return !!node;
 }
 
+static void vhost_worker_killed(void *data)
+{
+	struct vhost_worker *worker = data;
+	struct vhost_dev *dev = worker->dev;
+	struct vhost_virtqueue *vq;
+	int i, attach_cnt = 0;
+
+	mutex_lock(&worker->mutex);
+	worker->killed = true;
+
+	for (i = 0; i < dev->nvqs; i++) {
+		vq = dev->vqs[i];
+
+		mutex_lock(&vq->mutex);
+		if (worker ==
+		    rcu_dereference_check(vq->worker,
+					  lockdep_is_held(&vq->mutex))) {
+			rcu_assign_pointer(vq->worker, NULL);
+			attach_cnt++;
+		}
+		mutex_unlock(&vq->mutex);
+	}
+
+	worker->attachment_cnt -= attach_cnt;
+	if (attach_cnt)
+		synchronize_rcu();
+	/*
+	 * Finish vhost_worker_flush calls and any other works that snuck in
+	 * before the synchronize_rcu.
+	 */
+	vhost_run_work_list(worker);
+	mutex_unlock(&worker->mutex);
+}
+
 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 {
 	kfree(vq->indirect);
@@ -425,6 +502,8 @@ static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 	vq->log = NULL;
 	kfree(vq->heads);
 	vq->heads = NULL;
+	kfree(vq->nheads);
+	vq->nheads = NULL;
 }
 
 /* Helper to allocate iovec buffers for all vqs.
  */
@@ -442,7 +521,9 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
 				      GFP_KERNEL);
 		vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
 					  GFP_KERNEL);
-		if (!vq->indirect || !vq->log || !vq->heads)
+		vq->nheads = kmalloc_array(dev->iov_limit, sizeof(*vq->nheads),
+					   GFP_KERNEL);
+		if (!vq->indirect || !vq->log || !vq->heads || !vq->nheads)
 			goto err_nomem;
 	}
 	return 0;
@@ -522,6 +603,7 @@ void vhost_dev_init(struct vhost_dev *dev,
 	dev->byte_weight = byte_weight;
 	dev->use_worker = use_worker;
 	dev->msg_handler = msg_handler;
+	dev->fork_owner = fork_from_owner_default;
 	init_waitqueue_head(&dev->wait);
 	INIT_LIST_HEAD(&dev->read_list);
 	INIT_LIST_HEAD(&dev->pending_list);
@@ -533,6 +615,7 @@ void vhost_dev_init(struct vhost_dev *dev,
 		vq->log = NULL;
 		vq->indirect = NULL;
 		vq->heads = NULL;
+		vq->nheads = NULL;
 		vq->dev = dev;
 		mutex_init(&vq->mutex);
 		vhost_vq_reset(dev, vq);
@@ -551,6 +634,46 @@ long vhost_dev_check_owner(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
 
+struct vhost_attach_cgroups_struct {
+	struct vhost_work work;
+	struct task_struct *owner;
+	int ret;
+};
+
+static void vhost_attach_cgroups_work(struct vhost_work *work)
+{
+	struct vhost_attach_cgroups_struct *s;
+
+	s = container_of(work, struct vhost_attach_cgroups_struct, work);
+	s->ret = cgroup_attach_task_all(s->owner, current);
+}
+
+static int vhost_attach_task_to_cgroups(struct vhost_worker *worker)
+{
+	struct vhost_attach_cgroups_struct attach;
+	int saved_cnt;
+
+	attach.owner = current;
+
+	vhost_work_init(&attach.work, vhost_attach_cgroups_work);
+	vhost_worker_queue(worker, &attach.work);
+
+	mutex_lock(&worker->mutex);
+
+	/*
+	 * Bypass attachment_cnt check in __vhost_worker_flush:
+	 * Temporarily change it to INT_MAX to bypass the check
+	 */
+	saved_cnt = worker->attachment_cnt;
+	worker->attachment_cnt = INT_MAX;
+	__vhost_worker_flush(worker);
+	worker->attachment_cnt = saved_cnt;
+
+	mutex_unlock(&worker->mutex);
+
+	return attach.ret;
+}
+
 /* Caller should have device mutex */
 bool vhost_dev_has_owner(struct vhost_dev *dev)
 {
@@ -564,10 +687,10 @@ static void vhost_attach_mm(struct vhost_dev *dev)
 	if (dev->use_worker) {
 		dev->mm = get_task_mm(current);
 	} else {
-		/* vDPA device does not use worker thead, so there's
-		 * no need to hold the address space for mm. This help
+		/* vDPA device does not use worker thread, so there's
+		 * no need to hold the address space for mm. This helps
 		 * to avoid deadlock in the case of mmap() which may
-		 * held the refcnt of the file and depends on release
+		 * hold the refcnt of the file and depends on release
 		 * method to remove vma. */
 		dev->mm = current->mm;
@@ -596,7 +719,7 @@ static void vhost_worker_destroy(struct vhost_dev *dev,
 
 	WARN_ON(!llist_empty(&worker->work_list));
 	xa_erase(&dev->worker_xa, worker->id);
-	vhost_task_stop(worker->vtsk);
+	worker->ops->stop(worker);
 	kfree(worker);
 }
 
@@ -619,40 +742,117 @@ static void vhost_workers_free(struct vhost_dev *dev)
 	xa_destroy(&dev->worker_xa);
 }
 
+static void vhost_task_wakeup(struct vhost_worker *worker)
+{
+	return vhost_task_wake(worker->vtsk);
+}
+
+static void vhost_kthread_wakeup(struct vhost_worker *worker)
+{
+	wake_up_process(worker->kthread_task);
+}
+
+static void vhost_task_do_stop(struct vhost_worker *worker)
+{
+	return vhost_task_stop(worker->vtsk);
+}
+
+static void vhost_kthread_do_stop(struct vhost_worker *worker)
+{
+	kthread_stop(worker->kthread_task);
+}
+
+static int vhost_task_worker_create(struct vhost_worker *worker,
+				    struct vhost_dev *dev, const char *name)
+{
+	struct vhost_task *vtsk;
+	u32 id;
+	int ret;
+
+	vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+				 worker, name);
+	if (IS_ERR(vtsk))
+		return PTR_ERR(vtsk);
+
+	worker->vtsk = vtsk;
+	vhost_task_start(vtsk);
+	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+	if (ret < 0) {
+		vhost_task_do_stop(worker);
+		return ret;
+	}
+	worker->id = id;
+	return 0;
+}
+
+static int vhost_kthread_worker_create(struct vhost_worker *worker,
+				       struct vhost_dev *dev, const char *name)
+{
+	struct task_struct *task;
+	u32 id;
+	int ret;
+
+	task = kthread_create(vhost_run_work_kthread_list, worker, "%s", name);
+	if (IS_ERR(task))
+		return PTR_ERR(task);
+
+	worker->kthread_task = task;
+	wake_up_process(task);
+	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+	if (ret < 0)
+		goto stop_worker;
+
+	ret = vhost_attach_task_to_cgroups(worker);
+	if (ret)
+		goto free_id;
+
+	worker->id = id;
+	return 0;
+
+free_id:
+	xa_erase(&dev->worker_xa, id);
+stop_worker:
+	vhost_kthread_do_stop(worker);
+	return ret;
+}
+
+static const struct vhost_worker_ops kthread_ops = {
+	.create = vhost_kthread_worker_create,
+	.stop = vhost_kthread_do_stop,
+	.wakeup = vhost_kthread_wakeup,
+};
+
+static const struct vhost_worker_ops vhost_task_ops = {
+	.create = vhost_task_worker_create,
+	.stop = vhost_task_do_stop,
+	.wakeup = vhost_task_wakeup,
+};
+
 static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 {
 	struct vhost_worker *worker;
-	struct vhost_task *vtsk;
 	char name[TASK_COMM_LEN];
 	int ret;
-	u32 id;
+	const struct vhost_worker_ops *ops = dev->fork_owner ? &vhost_task_ops :
+					       &kthread_ops;
 
 	worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
 	if (!worker)
 		return NULL;
 
+	worker->dev = dev;
+	worker->ops = ops;
 	snprintf(name, sizeof(name), "vhost-%d", current->pid);
 
-	vtsk = vhost_task_create(vhost_worker, worker, name);
-	if (!vtsk)
-		goto free_worker;
-
 	mutex_init(&worker->mutex);
 	init_llist_head(&worker->work_list);
 	worker->kcov_handle = kcov_common_handle();
-	worker->vtsk = vtsk;
-
-	vhost_task_start(vtsk);
-
-	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+	ret = ops->create(worker, dev, name);
 	if (ret < 0)
-		goto stop_worker;
-	worker->id = id;
+		goto free_worker;
 
 	return worker;
 
-stop_worker:
-	vhost_task_stop(vtsk);
 free_worker:
 	kfree(worker);
 	return NULL;
@@ -664,32 +864,49 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
 {
 	struct vhost_worker *old_worker;
 
-	old_worker = rcu_dereference_check(vq->worker,
-					   lockdep_is_held(&vq->dev->mutex));
-
 	mutex_lock(&worker->mutex);
-	worker->attachment_cnt++;
-	mutex_unlock(&worker->mutex);
+	if (worker->killed) {
+		mutex_unlock(&worker->mutex);
+		return;
+	}
+
+	mutex_lock(&vq->mutex);
+
+	old_worker = rcu_dereference_check(vq->worker,
+					   lockdep_is_held(&vq->mutex));
 	rcu_assign_pointer(vq->worker, worker);
+	worker->attachment_cnt++;
 
-	if (!old_worker)
+	if (!old_worker) {
+		mutex_unlock(&vq->mutex);
+		mutex_unlock(&worker->mutex);
 		return;
+	}
+	mutex_unlock(&vq->mutex);
+	mutex_unlock(&worker->mutex);
+
 	/*
 	 * Take the worker mutex to make sure we see the work queued from
 	 * device wide flushes which doesn't use RCU for execution.
 	 */
 	mutex_lock(&old_worker->mutex);
-	old_worker->attachment_cnt--;
+	if (old_worker->killed) {
+		mutex_unlock(&old_worker->mutex);
+		return;
+	}
+
 	/*
 	 * We don't want to call synchronize_rcu for every vq during setup
 	 * because it will slow down VM startup. If we haven't done
 	 * VHOST_SET_VRING_KICK and not done the driver specific
-	 * SET_ENDPOINT/RUNNUNG then we can skip the sync since there will
+	 * SET_ENDPOINT/RUNNING then we can skip the sync since there will
 	 * not be any works queued for scsi and net.
 	 */
 	mutex_lock(&vq->mutex);
 	if (!vhost_vq_get_backend(vq) && !vq->kick) {
 		mutex_unlock(&vq->mutex);
+
+		old_worker->attachment_cnt--;
 		mutex_unlock(&old_worker->mutex);
 		/*
 		 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
@@ -705,7 +922,8 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
 	/* Make sure new vq queue/flush/poll calls see the new worker */
 	synchronize_rcu();
 	/* Make sure whatever was queued gets run */
-	vhost_worker_flush(old_worker);
+	__vhost_worker_flush(old_worker);
+	old_worker->attachment_cnt--;
 	mutex_unlock(&old_worker->mutex);
 }
 
@@ -754,10 +972,16 @@ static int vhost_free_worker(struct vhost_dev *dev,
 		return -ENODEV;
 
 	mutex_lock(&worker->mutex);
-	if (worker->attachment_cnt) {
+	if (worker->attachment_cnt || worker->killed) {
 		mutex_unlock(&worker->mutex);
 		return -EBUSY;
 	}
+	/*
+	 * A flush might have raced and snuck in before attachment_cnt was set
+	 * to zero. Make sure flushes are flushed from the queue before
+	 * freeing.
+	 */
+	__vhost_worker_flush(worker);
 	mutex_unlock(&worker->mutex);
 
 	vhost_worker_destroy(dev, worker);
@@ -809,6 +1033,14 @@ long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
 	switch (ioctl) {
 	/* dev worker ioctls */
 	case VHOST_NEW_WORKER:
+		/*
+		 * vhost_tasks will account for worker threads under the parent's
+		 * NPROC value but kthreads do not. To avoid userspace overflowing
+		 * the system with worker threads fork_owner must be true.
+		 */
+		if (!dev->fork_owner)
+			return -EFAULT;
+
 		ret = vhost_new_worker(dev, &state);
 		if (!ret && copy_to_user(argp, &state, sizeof(state)))
 			ret = -EFAULT;
@@ -926,6 +1158,7 @@ void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
 
 	vhost_dev_cleanup(dev);
 
+	dev->fork_owner = fork_from_owner_default;
 	dev->umem = umem;
 	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
 	 * VQs aren't running.
@@ -1290,10 +1523,36 @@ static void vhost_dev_unlock_vqs(struct vhost_dev *d)
 		mutex_unlock(&d->vqs[i]->mutex);
 }
 
-static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
-				      __virtio16 *idx)
+static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq)
 {
-	return vhost_get_avail(vq, *idx, &vq->avail->idx);
+	__virtio16 idx;
+	int r;
+
+	r = vhost_get_avail(vq, idx, &vq->avail->idx);
+	if (unlikely(r < 0)) {
+		vq_err(vq, "Failed to access available index at %p (%d)\n",
+		       &vq->avail->idx, r);
+		return r;
+	}
+
+	/* Check it isn't doing very strange thing with available indexes */
+	vq->avail_idx = vhost16_to_cpu(vq, idx);
+	if (unlikely((u16)(vq->avail_idx - vq->last_avail_idx) > vq->num)) {
+		vq_err(vq, "Invalid available index change from %u to %u",
+		       vq->last_avail_idx, vq->avail_idx);
+		return -EINVAL;
+	}
+
+	/* We're done if there is nothing new */
+	if (vq->avail_idx == vq->last_avail_idx)
+		return 0;
+
+	/*
+	 * We updated vq->avail_idx so we need a memory barrier between
+	 * the index read above and the caller reading avail ring entries.
+	 */
+	smp_rmb();
+	return 1;
 }
 
 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
@@ -1458,9 +1717,7 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
 		goto done;
 	}
 
-	if ((msg.type == VHOST_IOTLB_UPDATE ||
-	     msg.type == VHOST_IOTLB_INVALIDATE) &&
-	     msg.size == 0) {
+	if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
 		ret = -EINVAL;
 		goto done;
 	}
@@ -1910,14 +2167,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 			break;
 		}
 		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
-			vq->last_avail_idx = s.num & 0xffff;
+			vq->next_avail_head = vq->last_avail_idx =
+						s.num & 0xffff;
 			vq->last_used_idx = (s.num >> 16) & 0xffff;
 		} else {
 			if (s.num > 0xffff) {
 				r = -EINVAL;
 				break;
 			}
-			vq->last_avail_idx = s.num;
+			vq->next_avail_head = vq->last_avail_idx = s.num;
 		}
 		/* Forget the cached index value. */
 		vq->avail_idx = vq->last_avail_idx;
@@ -2055,6 +2313,45 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 		goto done;
 	}
 
+#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL
+	if (ioctl == VHOST_SET_FORK_FROM_OWNER) {
+		/* Only allow modification before owner is set */
+		if (vhost_dev_has_owner(d)) {
+			r = -EBUSY;
+			goto done;
+		}
+		u8 fork_owner_val;
+
+		if (get_user(fork_owner_val, (u8 __user *)argp)) {
+			r = -EFAULT;
+			goto done;
+		}
+		if (fork_owner_val != VHOST_FORK_OWNER_TASK &&
+		    fork_owner_val != VHOST_FORK_OWNER_KTHREAD) {
+			r = -EINVAL;
+			goto done;
+		}
+		d->fork_owner = !!fork_owner_val;
+		r = 0;
+		goto done;
+	}
+	if (ioctl == VHOST_GET_FORK_FROM_OWNER) {
+		u8 fork_owner_val = d->fork_owner;
+
+		if (fork_owner_val != VHOST_FORK_OWNER_TASK &&
+		    fork_owner_val != VHOST_FORK_OWNER_KTHREAD) {
+			r = -EINVAL;
+			goto done;
+		}
+		if (put_user(fork_owner_val, (u8 __user *)argp)) {
+			r = -EFAULT;
+			goto done;
+		}
+		r = 0;
+		goto done;
+	}
+#endif
+
 	/* You must be the owner to do anything else */
 	r = vhost_dev_check_owner(d);
 	if (r)
@@ -2224,6 +2521,19 @@ static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
 	return 0;
 }
 
+/*
+ * vhost_log_write() - Log in dirty page bitmap
+ * @vq: vhost virtqueue.
+ * @log: Array of dirty memory in GPA.
+ * @log_num: Size of vhost_log arrary.
+ * @len: The total length of memory buffer to log in the dirty bitmap.
+ *       Some drivers may only partially use pages shared via the last
+ *       vring descriptor (i.e. vhost-net RX buffer).
+ *       Use (len == U64_MAX) to indicate the driver would log all
+ *       pages of vring descriptors.
+ * @iov: Array of dirty memory in HVA.
+ * @count: Size of iovec array.
+ */
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
 		    unsigned int log_num, u64 len, struct iovec *iov, int count)
 {
@@ -2247,15 +2557,14 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
 		r = log_write(vq->log_base, log[i].addr, l);
 		if (r < 0)
 			return r;
-		len -= l;
-		if (!len) {
-			if (vq->log_ctx)
-				eventfd_signal(vq->log_ctx, 1);
-			return 0;
-		}
+
+		if (len != U64_MAX)
+			len -= l;
 	}
-	/* Length written exceeds what we have stored. This is a bug. */
-	BUG();
+
+	if (vq->log_ctx)
+		eventfd_signal(vq->log_ctx);
+
 	return 0;
 }
 EXPORT_SYMBOL_GPL(vhost_log_write);
@@ -2273,7 +2582,7 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
 		log_used(vq, (used - (void __user *)vq->used),
 			 sizeof vq->used->flags);
 		if (vq->log_ctx)
-			eventfd_signal(vq->log_ctx, 1);
+			eventfd_signal(vq->log_ctx);
 	}
 	return 0;
 }
@@ -2291,7 +2600,7 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq)
 		log_used(vq, (used - (void __user *)vq->used),
 			 sizeof *vhost_avail_event(vq));
 		if (vq->log_ctx)
-			eventfd_signal(vq->log_ctx, 1);
+			eventfd_signal(vq->log_ctx);
 	}
 	return 0;
 }
@@ -2485,66 +2794,66 @@ static int get_indirect(struct vhost_virtqueue *vq,
 	return 0;
 }
 
-/* This looks in the virtqueue and for the first available buffer, and converts
- * it to an iovec for convenient access. Since descriptors consist of some
- * number of output then some number of input descriptors, it's actually two
- * iovecs, but we pack them into one and note how many of each there were.
+/**
+ * vhost_get_vq_desc_n - Fetch the next available descriptor chain and build iovecs
+ * @vq: target virtqueue
+ * @iov: array that receives the scatter/gather segments
+ * @iov_size: capacity of @iov in elements
+ * @out_num: the number of output segments
+ * @in_num: the number of input segments
+ * @log: optional array to record addr/len for each writable segment; NULL if unused
+ * @log_num: optional output; number of entries written to @log when provided
+ * @ndesc: optional output; number of descriptors consumed from the available ring
+ *         (useful for rollback via vhost_discard_vq_desc)
 *
- * This function returns the descriptor number found, or vq->num (which is
- * never a valid descriptor number) if none was found. A negative code is
- * returned on error. */
-int vhost_get_vq_desc(struct vhost_virtqueue *vq,
-		      struct iovec iov[], unsigned int iov_size,
-		      unsigned int *out_num, unsigned int *in_num,
-		      struct vhost_log *log, unsigned int *log_num)
+ * Extracts one available descriptor chain from @vq and translates guest addresses
+ * into host iovecs.
+ *
+ * On success, advances @vq->last_avail_idx by 1 and @vq->next_avail_head by the
+ * number of descriptors consumed (also stored via @ndesc when non-NULL).
+ *
+ * Return:
+ * - head index in [0, @vq->num) on success;
+ * - @vq->num if no descriptor is currently available;
+ * - negative errno on failure
+ */
+int vhost_get_vq_desc_n(struct vhost_virtqueue *vq,
+			struct iovec iov[], unsigned int iov_size,
+			unsigned int *out_num, unsigned int *in_num,
+			struct vhost_log *log, unsigned int *log_num,
+			unsigned int *ndesc)
 {
+	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
 	struct vring_desc desc;
 	unsigned int i, head, found = 0;
-	u16 last_avail_idx;
-	__virtio16 avail_idx;
+	u16 last_avail_idx = vq->last_avail_idx;
 	__virtio16 ring_head;
-	int ret, access;
-
-	/* Check it isn't doing very strange things with descriptor numbers. */
-	last_avail_idx = vq->last_avail_idx;
+	int ret, access, c = 0;
 
 	if (vq->avail_idx == vq->last_avail_idx) {
-		if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
-			vq_err(vq, "Failed to access avail idx at %p\n",
-				&vq->avail->idx);
-			return -EFAULT;
-		}
-		vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
-
-		if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
-			vq_err(vq, "Guest moved used index from %u to %u",
-				last_avail_idx, vq->avail_idx);
-			return -EFAULT;
-		}
+		ret = vhost_get_avail_idx(vq);
+		if (unlikely(ret < 0))
+			return ret;
 
-		/* If there's nothing new since last we looked, return
-		 * invalid.
-		 */
-		if (vq->avail_idx == last_avail_idx)
+		if (!ret)
 			return vq->num;
-
-		/* Only get avail ring entries after they have been
-		 * exposed by guest.
-		 */
-		smp_rmb();
 	}
 
-	/* Grab the next descriptor number they're advertising, and increment
-	 * the index we've seen. */
-	if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
-		vq_err(vq, "Failed to read head: idx %d address %p\n",
-		       last_avail_idx,
-		       &vq->avail->ring[last_avail_idx % vq->num]);
-		return -EFAULT;
+	if (in_order)
+		head = vq->next_avail_head & (vq->num - 1);
+	else {
+		/* Grab the next descriptor number they're
+		 * advertising, and increment the index we've seen. */
+		if (unlikely(vhost_get_avail_head(vq, &ring_head,
+						  last_avail_idx))) {
+			vq_err(vq, "Failed to read head: idx %d address %p\n",
+			       last_avail_idx,
+			       &vq->avail->ring[last_avail_idx % vq->num]);
+			return -EFAULT;
+		}
+		head = vhost16_to_cpu(vq, ring_head);
 	}
-	head = vhost16_to_cpu(vq, ring_head);
-
 	/* If their number is silly, that's an error. */
 	if (unlikely(head >= vq->num)) {
 		vq_err(vq, "Guest says index %u > %u is available",
@@ -2587,6 +2896,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 					"in indirect descriptor at idx %d\n", i);
 				return ret;
 			}
+			++c;
 			continue;
 		}
 
@@ -2622,22 +2932,56 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 			}
 			*out_num += ret;
 		}
+		++c;
 	} while ((i = next_desc(vq, &desc)) != -1);
 
 	/* On success, increment avail index. */
 	vq->last_avail_idx++;
+	vq->next_avail_head += c;
+
+	if (ndesc)
+		*ndesc = c;
 
 	/* Assume notifications from guest are disabled at this point,
 	 * if they aren't we would need to update avail_event index. */
 	BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
 	return head;
 }
+EXPORT_SYMBOL_GPL(vhost_get_vq_desc_n);
+
+/* This looks in the virtqueue and for the first available buffer, and converts
+ * it to an iovec for convenient access. Since descriptors consist of some
+ * number of output then some number of input descriptors, it's actually two
+ * iovecs, but we pack them into one and note how many of each there were.
+ *
+ * This function returns the descriptor number found, or vq->num (which is
+ * never a valid descriptor number) if none was found. A negative code is
+ * returned on error.
+ */
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
+		      struct iovec iov[], unsigned int iov_size,
+		      unsigned int *out_num, unsigned int *in_num,
+		      struct vhost_log *log, unsigned int *log_num)
+{
+	return vhost_get_vq_desc_n(vq, iov, iov_size, out_num, in_num,
+				   log, log_num, NULL);
+}
 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
 
-/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
-void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
+/**
+ * vhost_discard_vq_desc - Reverse the effect of vhost_get_vq_desc_n()
+ * @vq: target virtqueue
+ * @nbufs: number of buffers to roll back
+ * @ndesc: number of descriptors to roll back
+ *
+ * Rewinds the internal consumer cursors after a failed attempt to use buffers
+ * returned by vhost_get_vq_desc_n().
+ */
+void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int nbufs,
+			   unsigned int ndesc)
 {
-	vq->last_avail_idx -= n;
+	vq->next_avail_head -= ndesc;
+	vq->last_avail_idx -= nbufs;
 }
 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
 
@@ -2649,8 +2993,9 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 		cpu_to_vhost32(vq, head),
 		cpu_to_vhost32(vq, len)
 	};
+	u16 nheads = 1;
 
-	return vhost_add_used_n(vq, &heads, 1);
+	return vhost_add_used_n(vq, &heads, &nheads, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_add_used);
 
@@ -2686,10 +3031,9 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
 	return 0;
 }
 
-/* After we've used one of their buffers, we tell them about it. We'll then
- * want to notify the guest, using eventfd. */
-int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
-		     unsigned count)
+static int vhost_add_used_n_ooo(struct vhost_virtqueue *vq,
+				struct vring_used_elem *heads,
+				unsigned count)
 {
 	int start, n, r;
 
@@ -2702,7 +3046,72 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
 		heads += n;
 		count -= n;
 	}
-	r = __vhost_add_used_n(vq, heads, count);
+	return __vhost_add_used_n(vq, heads, count);
+}
+
+static int vhost_add_used_n_in_order(struct vhost_virtqueue *vq,
+				     struct vring_used_elem *heads,
+				     const u16 *nheads,
+				     unsigned count)
+{
+	vring_used_elem_t __user *used;
+	u16 old, new = vq->last_used_idx;
+	int start, i;
+
+	if (!nheads)
+		return -EINVAL;
+
+	start = vq->last_used_idx & (vq->num - 1);
+	used = vq->used->ring + start;
+
+	for (i = 0; i < count; i++) {
+		if (vhost_put_used(vq, &heads[i], start, 1)) {
+			vq_err(vq, "Failed to write used");
+			return -EFAULT;
+		}
+		start += nheads[i];
+		new += nheads[i];
+		if (start >= vq->num)
+			start -= vq->num;
+	}
+
+	if (unlikely(vq->log_used)) {
+		/* Make sure data is seen before log. */
+		smp_wmb();
+		/* Log used ring entry write. */
+		log_used(vq, ((void __user *)used - (void __user *)vq->used),
+			 (vq->num - start) * sizeof *used);
+		if (start + count > vq->num)
+			log_used(vq, 0,
+				 (start + count - vq->num) * sizeof *used);
+	}
+
+	old = vq->last_used_idx;
+	vq->last_used_idx = new;
+	/* If the driver never bothers to signal in a very long while,
+	 * used index might wrap around. If that happens, invalidate
+	 * signalled_used index we stored. TODO: make sure driver
+	 * signals at least once in 2^16 and remove this. */
+	if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
+		vq->signalled_used_valid = false;
+	return 0;
+}
+
+/* After we've used one of their buffers, we tell them about it. We'll then
+ * want to notify the guest, using eventfd. */
+int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
+		     u16 *nheads, unsigned count)
+{
+	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
+	int r;
+
+	if (!in_order || !nheads)
+		r = vhost_add_used_n_ooo(vq, heads, count);
+	else
+		r = vhost_add_used_n_in_order(vq, heads, nheads, count);
+
+	if (r < 0)
+		return r;
 
 	/* Make sure buffer is written before we update index. */
 	smp_wmb();
@@ -2717,7 +3126,7 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
 		log_used(vq, offsetof(struct vring_used, idx),
 			 sizeof vq->used->idx);
 		if (vq->log_ctx)
-			eventfd_signal(vq->log_ctx, 1);
+			eventfd_signal(vq->log_ctx);
 	}
 	return r;
 }
@@ -2765,7 +3174,7 @@ void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
 	/* Signal the Guest tell them we used something up. */
 	if (vq->call_ctx.ctx && vhost_notify(dev, vq))
-		eventfd_signal(vq->call_ctx.ctx, 1);
+		eventfd_signal(vq->call_ctx.ctx);
 }
 EXPORT_SYMBOL_GPL(vhost_signal);
 
@@ -2782,35 +3191,33 @@ EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
 /* multi-buffer version of vhost_add_used_and_signal */
 void vhost_add_used_and_signal_n(struct vhost_dev *dev,
 				 struct vhost_virtqueue *vq,
-				 struct vring_used_elem *heads, unsigned count)
+				 struct vring_used_elem *heads,
+				 u16 *nheads,
+				 unsigned count)
 {
-	vhost_add_used_n(vq, heads, count);
+	vhost_add_used_n(vq, heads, nheads, count);
 	vhost_signal(dev, vq);
 }
 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
 
-/* return true if we're sure that avaiable ring is empty */
+/* return true if we're sure that available ring is empty */
 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
-	__virtio16 avail_idx;
 	int r;
 
 	if (vq->avail_idx != vq->last_avail_idx)
 		return false;
 
-	r = vhost_get_avail_idx(vq, &avail_idx);
-	if (unlikely(r))
-		return false;
-	vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
+	r = vhost_get_avail_idx(vq);
 
-	return vq->avail_idx == vq->last_avail_idx;
+	/* Note: we treat error as non-empty here */
+	return r == 0;
 }
 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
 
 /* OK, now we need to know about added descriptors. */
 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
-	__virtio16 avail_idx;
 	int r;
 
 	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
@@ -2834,15 +3241,13 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 	/* They could have slipped one in as we were doing that: make
 	 * sure it's written, then check again. */
 	smp_mb();
-	r = vhost_get_avail_idx(vq, &avail_idx);
-	if (r) {
-		vq_err(vq, "Failed to check avail idx at %p: %d\n",
-		       &vq->avail->idx, r);
+
+	r = vhost_get_avail_idx(vq);
+	/* Note: we treat error as empty here */
+	if (unlikely(r < 0))
 		return false;
-	}
-	vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
 
-	return vq->avail_idx != vq->last_avail_idx;
+	return r;
 }
 EXPORT_SYMBOL_GPL(vhost_enable_notify);
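
Note on the VHOST_SET_FORK_FROM_OWNER path added above: dev->fork_owner can only be changed before an owner is attached, so userspace has to issue the ioctl before VHOST_SET_OWNER. The sketch below is a minimal, hypothetical usage example; it assumes the matching uapi definitions from this series (VHOST_SET_FORK_FROM_OWNER, VHOST_FORK_OWNER_TASK and VHOST_FORK_OWNER_KTHREAD) are present in <linux/vhost.h>, and that the kernel was built with CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL; otherwise the ioctl is not available.

    #include <fcntl.h>
    #include <stdio.h>
    #include <sys/ioctl.h>
    #include <linux/vhost.h>

    /* Hypothetical helper: open vhost-net and request kthread-based workers. */
    static int open_vhost_net_kthread_mode(void)
    {
            int fd = open("/dev/vhost-net", O_RDWR);
            __u8 mode = VHOST_FORK_OWNER_KTHREAD;

            if (fd < 0)
                    return -1;
            /* Must run before VHOST_SET_OWNER; afterwards it fails with -EBUSY. */
            if (ioctl(fd, VHOST_SET_FORK_FROM_OWNER, &mode))
                    perror("VHOST_SET_FORK_FROM_OWNER");
            if (ioctl(fd, VHOST_SET_OWNER, NULL))
                    perror("VHOST_SET_OWNER");
            return fd;
    }

Note that with VHOST_FORK_OWNER_KTHREAD selected, the diff above makes VHOST_NEW_WORKER fail with -EFAULT, since kthread workers are not accounted against the owner's NPROC limit.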

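With VIRTIO_F_IN_ORDER negotiated, one buffer returned by vhost_get_vq_desc_n() can cover several descriptors, which is why vhost_discard_vq_desc() now takes both a buffer count and a descriptor count. The fragment below is a rough, hypothetical caller-side sketch (not from this patch or any in-tree backend) of how the fetch/rollback pair is meant to be used; backend_ready() stands in for whatever backpressure check a real backend would make.

    /* Hypothetical backend step: fetch one buffer, undo the fetch if it
     * cannot be processed yet. Uses only fields of struct vhost_virtqueue. */
    static int try_one_buffer(struct vhost_virtqueue *vq)
    {
            unsigned int out, in, ndesc;
            int head;

            head = vhost_get_vq_desc_n(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                       &out, &in, NULL, NULL, &ndesc);
            if (head < 0)
                    return head;            /* fault or malformed ring */
            if (head == vq->num)
                    return 0;               /* nothing available right now */

            if (!backend_ready()) {         /* hypothetical backpressure check */
                    /* One buffer, ndesc descriptors: exactly undoes the fetch. */
                    vhost_discard_vq_desc(vq, 1, ndesc);
                    return -EAGAIN;
            }

            /*
             * ...process the out/in iovecs, then complete the buffer with
             * vhost_add_used_n()/vhost_add_used_and_signal_n(), passing a
             * matching nheads[] entry per buffer when VIRTIO_F_IN_ORDER is set.
             */
            return 1;
    }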