Diffstat (limited to 'drivers/vhost/vhost.c')
 drivers/vhost/vhost.c | 1202 ++++++++++++++++++++++++++++++++++++------------
 1 file changed, 973 insertions(+), 229 deletions(-)
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index d7b8df3edffc..bccdc9eab267 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -27,6 +27,7 @@
#include <linux/sort.h>
#include <linux/sched/mm.h>
#include <linux/sched/signal.h>
+#include <linux/sched/vhost_task.h>
#include <linux/interval_tree_generic.h>
#include <linux/nospec.h>
#include <linux/kcov.h>
@@ -41,6 +42,13 @@ static int max_iotlb_entries = 2048;
module_param(max_iotlb_entries, int, 0444);
MODULE_PARM_DESC(max_iotlb_entries,
"Maximum number of iotlb entries. (default: 2048)");
+static bool fork_from_owner_default = VHOST_FORK_OWNER_TASK;
+
+#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL
+module_param(fork_from_owner_default, bool, 0444);
+MODULE_PARM_DESC(fork_from_owner_default,
+ "Set task mode as the default(default: Y)");
+#endif
enum {
VHOST_MEMORY_F_LOG = 0x1,
@@ -187,13 +195,15 @@ EXPORT_SYMBOL_GPL(vhost_work_init);
/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
- __poll_t mask, struct vhost_dev *dev)
+ __poll_t mask, struct vhost_dev *dev,
+ struct vhost_virtqueue *vq)
{
init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
init_poll_funcptr(&poll->table, vhost_poll_func);
poll->mask = mask;
poll->dev = dev;
poll->wqh = NULL;
+ poll->vq = vq;
vhost_work_init(&poll->work, fn);
}
@@ -231,54 +241,98 @@ void vhost_poll_stop(struct vhost_poll *poll)
}
EXPORT_SYMBOL_GPL(vhost_poll_stop);
-void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
+static void vhost_worker_queue(struct vhost_worker *worker,
+ struct vhost_work *work)
{
- struct vhost_flush_struct flush;
+ if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
+ /* We can only add the work to the list after we're
+ * sure it was not in the list.
+ * test_and_set_bit() implies a memory barrier.
+ */
+ llist_add(&work->node, &worker->work_list);
+ worker->ops->wakeup(worker);
+ }
+}
- if (dev->worker) {
- init_completion(&flush.wait_event);
- vhost_work_init(&flush.work, vhost_flush_work);
+bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
+{
+ struct vhost_worker *worker;
+ bool queued = false;
- vhost_work_queue(dev, &flush.work);
- wait_for_completion(&flush.wait_event);
+ rcu_read_lock();
+ worker = rcu_dereference(vq->worker);
+ if (worker) {
+ queued = true;
+ vhost_worker_queue(worker, work);
}
+ rcu_read_unlock();
+
+ return queued;
}
-EXPORT_SYMBOL_GPL(vhost_work_flush);
+EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
-/* Flush any work that has been scheduled. When calling this, don't hold any
- * locks that are also used by the callback. */
-void vhost_poll_flush(struct vhost_poll *poll)
+/**
+ * __vhost_worker_flush - flush a worker
+ * @worker: worker to flush
+ *
+ * The worker's flush_mutex must be held.
+ */
+static void __vhost_worker_flush(struct vhost_worker *worker)
{
- vhost_work_flush(poll->dev, &poll->work);
+ struct vhost_flush_struct flush;
+
+ if (!worker->attachment_cnt || worker->killed)
+ return;
+
+ init_completion(&flush.wait_event);
+ vhost_work_init(&flush.work, vhost_flush_work);
+
+ vhost_worker_queue(worker, &flush.work);
+ /*
+ * Drop mutex in case our worker is killed and it needs to take the
+ * mutex to force cleanup.
+ */
+ mutex_unlock(&worker->mutex);
+ wait_for_completion(&flush.wait_event);
+ mutex_lock(&worker->mutex);
}
-EXPORT_SYMBOL_GPL(vhost_poll_flush);
-void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
+static void vhost_worker_flush(struct vhost_worker *worker)
{
- if (!dev->worker)
- return;
+ mutex_lock(&worker->mutex);
+ __vhost_worker_flush(worker);
+ mutex_unlock(&worker->mutex);
+}
- if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
- /* We can only add the work to the list after we're
- * sure it was not in the list.
- * test_and_set_bit() implies a memory barrier.
- */
- llist_add(&work->node, &dev->work_list);
- wake_up_process(dev->worker);
- }
+void vhost_dev_flush(struct vhost_dev *dev)
+{
+ struct vhost_worker *worker;
+ unsigned long i;
+
+ xa_for_each(&dev->worker_xa, i, worker)
+ vhost_worker_flush(worker);
}
-EXPORT_SYMBOL_GPL(vhost_work_queue);
+EXPORT_SYMBOL_GPL(vhost_dev_flush);
/* A lockless hint for busy polling code to exit the loop */
-bool vhost_has_work(struct vhost_dev *dev)
+bool vhost_vq_has_work(struct vhost_virtqueue *vq)
{
- return !llist_empty(&dev->work_list);
+ struct vhost_worker *worker;
+ bool has_work = false;
+
+ rcu_read_lock();
+ worker = rcu_dereference(vq->worker);
+ if (worker && !llist_empty(&worker->work_list))
+ has_work = true;
+ rcu_read_unlock();
+
+ return has_work;
}
-EXPORT_SYMBOL_GPL(vhost_has_work);
+EXPORT_SYMBOL_GPL(vhost_vq_has_work);
void vhost_poll_queue(struct vhost_poll *poll)
{
- vhost_work_queue(poll->dev, &poll->work);
+ vhost_vq_work_queue(poll->vq, &poll->work);
}
EXPORT_SYMBOL_GPL(vhost_poll_queue);
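For context, and not part of the patch itself: a minimal sketch of how an in-kernel backend would use the reworked queue/flush API. The backend type and function names (my_backend_vq, my_tx_work, etc.) are invented; the sketch only relies on vhost_work_init(), vhost_vq_work_queue() and vhost_dev_flush() as exported above.

/* Illustrative only -- not part of this patch. */
#include "vhost.h"

struct my_backend_vq {
	struct vhost_virtqueue vq;
	struct vhost_work work;
};

static void my_tx_work(struct vhost_work *work)
{
	struct my_backend_vq *bvq = container_of(work, struct my_backend_vq, work);

	/* ... process the ring behind bvq->vq ... */
}

static void my_backend_init_work(struct my_backend_vq *bvq)
{
	vhost_work_init(&bvq->work, my_tx_work);
}

static void my_backend_kick(struct my_backend_vq *bvq)
{
	/* Returns false if no worker is currently attached to the vq. */
	if (!vhost_vq_work_queue(&bvq->vq, &bvq->work))
		pr_debug("no worker attached, work dropped\n");
}

static void my_backend_stop(struct vhost_dev *dev)
{
	/* Walks dev->worker_xa and waits for each worker to drain its list. */
	vhost_dev_flush(dev);
}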
@@ -298,6 +352,18 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
__vhost_vq_meta_reset(d->vqs[i]);
}
+static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
+{
+ call_ctx->ctx = NULL;
+ memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
+}
+
+bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
+{
+ return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
+}
+EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
+
static void vhost_vq_reset(struct vhost_dev *dev,
struct vhost_virtqueue *vq)
{
@@ -306,6 +372,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->avail = NULL;
vq->used = NULL;
vq->last_avail_idx = 0;
+ vq->next_avail_head = 0;
vq->avail_idx = 0;
vq->last_used_idx = 0;
vq->signalled_used = 0;
@@ -314,25 +381,27 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->log_used = false;
vq->log_addr = -1ull;
vq->private_data = NULL;
- vq->acked_features = 0;
+ virtio_features_zero(vq->acked_features_array);
vq->acked_backend_features = 0;
vq->log_base = NULL;
vq->error_ctx = NULL;
vq->kick = NULL;
- vq->call_ctx = NULL;
vq->log_ctx = NULL;
- vhost_reset_is_le(vq);
vhost_disable_cross_endian(vq);
+ vhost_reset_is_le(vq);
vq->busyloop_timeout = 0;
vq->umem = NULL;
vq->iotlb = NULL;
+ rcu_assign_pointer(vq->worker, NULL);
+ vhost_vring_call_reset(&vq->call_ctx);
__vhost_vq_meta_reset(vq);
}
-static int vhost_worker(void *data)
+static int vhost_run_work_kthread_list(void *data)
{
- struct vhost_dev *dev = data;
+ struct vhost_worker *worker = data;
struct vhost_work *work, *work_next;
+ struct vhost_dev *dev = worker->dev;
struct llist_node *node;
kthread_use_mm(dev->mm);
@@ -345,8 +414,7 @@ static int vhost_worker(void *data)
__set_current_state(TASK_RUNNING);
break;
}
-
- node = llist_del_all(&dev->work_list);
+ node = llist_del_all(&worker->work_list);
if (!node)
schedule();
@@ -356,17 +424,76 @@ static int vhost_worker(void *data)
llist_for_each_entry_safe(work, work_next, node, node) {
clear_bit(VHOST_WORK_QUEUED, &work->flags);
__set_current_state(TASK_RUNNING);
- kcov_remote_start_common(dev->kcov_handle);
+ kcov_remote_start_common(worker->kcov_handle);
work->fn(work);
kcov_remote_stop();
- if (need_resched())
- schedule();
+ cond_resched();
}
}
kthread_unuse_mm(dev->mm);
+
return 0;
}
+static bool vhost_run_work_list(void *data)
+{
+ struct vhost_worker *worker = data;
+ struct vhost_work *work, *work_next;
+ struct llist_node *node;
+
+ node = llist_del_all(&worker->work_list);
+ if (node) {
+ __set_current_state(TASK_RUNNING);
+
+ node = llist_reverse_order(node);
+ /* make sure flag is seen after deletion */
+ smp_wmb();
+ llist_for_each_entry_safe(work, work_next, node, node) {
+ clear_bit(VHOST_WORK_QUEUED, &work->flags);
+ kcov_remote_start_common(worker->kcov_handle);
+ work->fn(work);
+ kcov_remote_stop();
+ cond_resched();
+ }
+ }
+
+ return !!node;
+}
+
+static void vhost_worker_killed(void *data)
+{
+ struct vhost_worker *worker = data;
+ struct vhost_dev *dev = worker->dev;
+ struct vhost_virtqueue *vq;
+ int i, attach_cnt = 0;
+
+ mutex_lock(&worker->mutex);
+ worker->killed = true;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ vq = dev->vqs[i];
+
+ mutex_lock(&vq->mutex);
+ if (worker ==
+ rcu_dereference_check(vq->worker,
+ lockdep_is_held(&vq->mutex))) {
+ rcu_assign_pointer(vq->worker, NULL);
+ attach_cnt++;
+ }
+ mutex_unlock(&vq->mutex);
+ }
+
+ worker->attachment_cnt -= attach_cnt;
+ if (attach_cnt)
+ synchronize_rcu();
+ /*
+ * Finish vhost_worker_flush calls and any other works that snuck in
+ * before the synchronize_rcu.
+ */
+ vhost_run_work_list(worker);
+ mutex_unlock(&worker->mutex);
+}
+
static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
kfree(vq->indirect);
@@ -375,6 +502,8 @@ static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
vq->log = NULL;
kfree(vq->heads);
vq->heads = NULL;
+ kfree(vq->nheads);
+ vq->nheads = NULL;
}
/* Helper to allocate iovec buffers for all vqs. */
@@ -392,7 +521,9 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
GFP_KERNEL);
vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
GFP_KERNEL);
- if (!vq->indirect || !vq->log || !vq->heads)
+ vq->nheads = kmalloc_array(dev->iov_limit, sizeof(*vq->nheads),
+ GFP_KERNEL);
+ if (!vq->indirect || !vq->log || !vq->heads || !vq->nheads)
goto err_nomem;
}
return 0;
@@ -432,8 +563,7 @@ static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
size_t event __maybe_unused =
vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- return sizeof(*vq->avail) +
- sizeof(*vq->avail->ring) * num + event;
+ return size_add(struct_size(vq->avail, ring, num), event);
}
static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
@@ -442,8 +572,7 @@ static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
size_t event __maybe_unused =
vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- return sizeof(*vq->used) +
- sizeof(*vq->used->ring) * num + event;
+ return size_add(struct_size(vq->used, ring, num), event);
}
static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
@@ -456,7 +585,7 @@ void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs,
int iov_limit, int weight, int byte_weight,
bool use_worker,
- int (*msg_handler)(struct vhost_dev *dev,
+ int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg))
{
struct vhost_virtqueue *vq;
@@ -469,30 +598,30 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->umem = NULL;
dev->iotlb = NULL;
dev->mm = NULL;
- dev->worker = NULL;
dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
dev->use_worker = use_worker;
dev->msg_handler = msg_handler;
- init_llist_head(&dev->work_list);
+ dev->fork_owner = fork_from_owner_default;
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
-
+ xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
for (i = 0; i < dev->nvqs; ++i) {
vq = dev->vqs[i];
vq->log = NULL;
vq->indirect = NULL;
vq->heads = NULL;
+ vq->nheads = NULL;
vq->dev = dev;
mutex_init(&vq->mutex);
vhost_vq_reset(dev, vq);
if (vq->handle_kick)
vhost_poll_init(&vq->poll, vq->handle_kick,
- EPOLLIN, dev);
+ EPOLLIN, dev, vq);
}
}
EXPORT_SYMBOL_GPL(vhost_dev_init);
@@ -519,14 +648,29 @@ static void vhost_attach_cgroups_work(struct vhost_work *work)
s->ret = cgroup_attach_task_all(s->owner, current);
}
-static int vhost_attach_cgroups(struct vhost_dev *dev)
+static int vhost_attach_task_to_cgroups(struct vhost_worker *worker)
{
struct vhost_attach_cgroups_struct attach;
+ int saved_cnt;
attach.owner = current;
+
vhost_work_init(&attach.work, vhost_attach_cgroups_work);
- vhost_work_queue(dev, &attach.work);
- vhost_work_flush(dev, &attach.work);
+ vhost_worker_queue(worker, &attach.work);
+
+ mutex_lock(&worker->mutex);
+
+ /*
+ * Bypass attachment_cnt check in __vhost_worker_flush:
+ * Temporarily change it to INT_MAX to bypass the check
+ */
+ saved_cnt = worker->attachment_cnt;
+ worker->attachment_cnt = INT_MAX;
+ __vhost_worker_flush(worker);
+ worker->attachment_cnt = saved_cnt;
+
+ mutex_unlock(&worker->mutex);
+
return attach.ret;
}
@@ -543,10 +687,10 @@ static void vhost_attach_mm(struct vhost_dev *dev)
if (dev->use_worker) {
dev->mm = get_task_mm(current);
} else {
- /* vDPA device does not use worker thead, so there's
- * no need to hold the address space for mm. This help
+ /* vDPA device does not use worker thread, so there's
+ * no need to hold the address space for mm. This helps
* to avoid deadlock in the case of mmap() which may
- * held the refcnt of the file and depends on release
+ * hold the refcnt of the file and depends on release
* method to remove vma.
*/
dev->mm = current->mm;
@@ -567,11 +711,393 @@ static void vhost_detach_mm(struct vhost_dev *dev)
dev->mm = NULL;
}
+static void vhost_worker_destroy(struct vhost_dev *dev,
+ struct vhost_worker *worker)
+{
+ if (!worker)
+ return;
+
+ WARN_ON(!llist_empty(&worker->work_list));
+ xa_erase(&dev->worker_xa, worker->id);
+ worker->ops->stop(worker);
+ kfree(worker);
+}
+
+static void vhost_workers_free(struct vhost_dev *dev)
+{
+ struct vhost_worker *worker;
+ unsigned long i;
+
+ if (!dev->use_worker)
+ return;
+
+ for (i = 0; i < dev->nvqs; i++)
+ rcu_assign_pointer(dev->vqs[i]->worker, NULL);
+ /*
+ * Free the default worker we created and cleanup workers userspace
+ * created but couldn't clean up (it forgot or crashed).
+ */
+ xa_for_each(&dev->worker_xa, i, worker)
+ vhost_worker_destroy(dev, worker);
+ xa_destroy(&dev->worker_xa);
+}
+
+static void vhost_task_wakeup(struct vhost_worker *worker)
+{
+ return vhost_task_wake(worker->vtsk);
+}
+
+static void vhost_kthread_wakeup(struct vhost_worker *worker)
+{
+ wake_up_process(worker->kthread_task);
+}
+
+static void vhost_task_do_stop(struct vhost_worker *worker)
+{
+ return vhost_task_stop(worker->vtsk);
+}
+
+static void vhost_kthread_do_stop(struct vhost_worker *worker)
+{
+ kthread_stop(worker->kthread_task);
+}
+
+static int vhost_task_worker_create(struct vhost_worker *worker,
+ struct vhost_dev *dev, const char *name)
+{
+ struct vhost_task *vtsk;
+ u32 id;
+ int ret;
+
+ vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+ worker, name);
+ if (IS_ERR(vtsk))
+ return PTR_ERR(vtsk);
+
+ worker->vtsk = vtsk;
+ vhost_task_start(vtsk);
+ ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+ if (ret < 0) {
+ vhost_task_do_stop(worker);
+ return ret;
+ }
+ worker->id = id;
+ return 0;
+}
+
+static int vhost_kthread_worker_create(struct vhost_worker *worker,
+ struct vhost_dev *dev, const char *name)
+{
+ struct task_struct *task;
+ u32 id;
+ int ret;
+
+ task = kthread_create(vhost_run_work_kthread_list, worker, "%s", name);
+ if (IS_ERR(task))
+ return PTR_ERR(task);
+
+ worker->kthread_task = task;
+ wake_up_process(task);
+ ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+ if (ret < 0)
+ goto stop_worker;
+
+ ret = vhost_attach_task_to_cgroups(worker);
+ if (ret)
+ goto free_id;
+
+ worker->id = id;
+ return 0;
+
+free_id:
+ xa_erase(&dev->worker_xa, id);
+stop_worker:
+ vhost_kthread_do_stop(worker);
+ return ret;
+}
+
+static const struct vhost_worker_ops kthread_ops = {
+ .create = vhost_kthread_worker_create,
+ .stop = vhost_kthread_do_stop,
+ .wakeup = vhost_kthread_wakeup,
+};
+
+static const struct vhost_worker_ops vhost_task_ops = {
+ .create = vhost_task_worker_create,
+ .stop = vhost_task_do_stop,
+ .wakeup = vhost_task_wakeup,
+};
+
+static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
+{
+ struct vhost_worker *worker;
+ char name[TASK_COMM_LEN];
+ int ret;
+ const struct vhost_worker_ops *ops = dev->fork_owner ? &vhost_task_ops :
+ &kthread_ops;
+
+ worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
+ if (!worker)
+ return NULL;
+
+ worker->dev = dev;
+ worker->ops = ops;
+ snprintf(name, sizeof(name), "vhost-%d", current->pid);
+
+ mutex_init(&worker->mutex);
+ init_llist_head(&worker->work_list);
+ worker->kcov_handle = kcov_common_handle();
+ ret = ops->create(worker, dev, name);
+ if (ret < 0)
+ goto free_worker;
+
+ return worker;
+
+free_worker:
+ kfree(worker);
+ return NULL;
+}
+
+/* Caller must have device mutex */
+static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+ struct vhost_worker *worker)
+{
+ struct vhost_worker *old_worker;
+
+ mutex_lock(&worker->mutex);
+ if (worker->killed) {
+ mutex_unlock(&worker->mutex);
+ return;
+ }
+
+ mutex_lock(&vq->mutex);
+
+ old_worker = rcu_dereference_check(vq->worker,
+ lockdep_is_held(&vq->mutex));
+ rcu_assign_pointer(vq->worker, worker);
+ worker->attachment_cnt++;
+
+ if (!old_worker) {
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&worker->mutex);
+ return;
+ }
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&worker->mutex);
+
+ /*
+ * Take the worker mutex to make sure we see the work queued from
+ * device-wide flushes, which don't use RCU for execution.
+ */
+ mutex_lock(&old_worker->mutex);
+ if (old_worker->killed) {
+ mutex_unlock(&old_worker->mutex);
+ return;
+ }
+
+ /*
+ * We don't want to call synchronize_rcu for every vq during setup
+ * because it will slow down VM startup. If we haven't done
+ * VHOST_SET_VRING_KICK and not done the driver specific
+ * SET_ENDPOINT/RUNNING then we can skip the sync since there will
+ * not be any works queued for scsi and net.
+ */
+ mutex_lock(&vq->mutex);
+ if (!vhost_vq_get_backend(vq) && !vq->kick) {
+ mutex_unlock(&vq->mutex);
+
+ old_worker->attachment_cnt--;
+ mutex_unlock(&old_worker->mutex);
+ /*
+ * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
+ * Warn if it adds support for multiple workers but forgets to
+ * handle the early queueing case.
+ */
+ WARN_ON(!old_worker->attachment_cnt &&
+ !llist_empty(&old_worker->work_list));
+ return;
+ }
+ mutex_unlock(&vq->mutex);
+
+ /* Make sure new vq queue/flush/poll calls see the new worker */
+ synchronize_rcu();
+ /* Make sure whatever was queued gets run */
+ __vhost_worker_flush(old_worker);
+ old_worker->attachment_cnt--;
+ mutex_unlock(&old_worker->mutex);
+}
+
+ /* Caller must have device mutex */
+static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+ struct vhost_vring_worker *info)
+{
+ unsigned long index = info->worker_id;
+ struct vhost_dev *dev = vq->dev;
+ struct vhost_worker *worker;
+
+ if (!dev->use_worker)
+ return -EINVAL;
+
+ worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+ if (!worker || worker->id != info->worker_id)
+ return -ENODEV;
+
+ __vhost_vq_attach_worker(vq, worker);
+ return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_new_worker(struct vhost_dev *dev,
+ struct vhost_worker_state *info)
+{
+ struct vhost_worker *worker;
+
+ worker = vhost_worker_create(dev);
+ if (!worker)
+ return -ENOMEM;
+
+ info->worker_id = worker->id;
+ return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_free_worker(struct vhost_dev *dev,
+ struct vhost_worker_state *info)
+{
+ unsigned long index = info->worker_id;
+ struct vhost_worker *worker;
+
+ worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+ if (!worker || worker->id != info->worker_id)
+ return -ENODEV;
+
+ mutex_lock(&worker->mutex);
+ if (worker->attachment_cnt || worker->killed) {
+ mutex_unlock(&worker->mutex);
+ return -EBUSY;
+ }
+ /*
+ * A flush might have raced and snuck in before attachment_cnt was set
+ * to zero. Make sure flushes are flushed from the queue before
+ * freeing.
+ */
+ __vhost_worker_flush(worker);
+ mutex_unlock(&worker->mutex);
+
+ vhost_worker_destroy(dev, worker);
+ return 0;
+}
+
+static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
+ struct vhost_virtqueue **vq, u32 *id)
+{
+ u32 __user *idxp = argp;
+ u32 idx;
+ long r;
+
+ r = get_user(idx, idxp);
+ if (r < 0)
+ return r;
+
+ if (idx >= dev->nvqs)
+ return -ENOBUFS;
+
+ idx = array_index_nospec(idx, dev->nvqs);
+
+ *vq = dev->vqs[idx];
+ *id = idx;
+ return 0;
+}
+
+/* Caller must have device mutex */
+long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
+ void __user *argp)
+{
+ struct vhost_vring_worker ring_worker;
+ struct vhost_worker_state state;
+ struct vhost_worker *worker;
+ struct vhost_virtqueue *vq;
+ long ret;
+ u32 idx;
+
+ if (!dev->use_worker)
+ return -EINVAL;
+
+ if (!vhost_dev_has_owner(dev))
+ return -EINVAL;
+
+ ret = vhost_dev_check_owner(dev);
+ if (ret)
+ return ret;
+
+ switch (ioctl) {
+ /* dev worker ioctls */
+ case VHOST_NEW_WORKER:
+ /*
+ * vhost_tasks will account for worker threads under the parent's
+ * NPROC value but kthreads do not. To avoid userspace overflowing
+ * the system with worker threads, fork_owner must be true.
+ */
+ if (!dev->fork_owner)
+ return -EFAULT;
+
+ ret = vhost_new_worker(dev, &state);
+ if (!ret && copy_to_user(argp, &state, sizeof(state)))
+ ret = -EFAULT;
+ return ret;
+ case VHOST_FREE_WORKER:
+ if (copy_from_user(&state, argp, sizeof(state)))
+ return -EFAULT;
+ return vhost_free_worker(dev, &state);
+ /* vring worker ioctls */
+ case VHOST_ATTACH_VRING_WORKER:
+ case VHOST_GET_VRING_WORKER:
+ break;
+ default:
+ return -ENOIOCTLCMD;
+ }
+
+ ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
+ if (ret)
+ return ret;
+
+ switch (ioctl) {
+ case VHOST_ATTACH_VRING_WORKER:
+ if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
+ ret = -EFAULT;
+ break;
+ }
+
+ ret = vhost_vq_attach_worker(vq, &ring_worker);
+ break;
+ case VHOST_GET_VRING_WORKER:
+ worker = rcu_dereference_check(vq->worker,
+ lockdep_is_held(&dev->mutex));
+ if (!worker) {
+ ret = -EINVAL;
+ break;
+ }
+
+ ring_worker.index = idx;
+ ring_worker.worker_id = worker->id;
+
+ if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
+ ret = -EFAULT;
+ break;
+ default:
+ ret = -ENOIOCTLCMD;
+ break;
+ }
+
+ return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
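For reference, and not part of the patch: a hedged userspace sketch of the worker ioctl flow handled above, assuming a vhost device fd on which VHOST_SET_OWNER has already been called and that struct vhost_worker_state, struct vhost_vring_worker, VHOST_NEW_WORKER and VHOST_ATTACH_VRING_WORKER come from the matching <linux/vhost.h> UAPI headers.

/* Illustrative only: create an extra worker and attach it to vring 1. */
#include <stdio.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int attach_extra_worker(int vhost_fd)
{
	struct vhost_worker_state state = {};
	struct vhost_vring_worker vring_worker = {};

	/* The kernel allocates a worker and returns its id (vhost_new_worker()). */
	if (ioctl(vhost_fd, VHOST_NEW_WORKER, &state) < 0) {
		perror("VHOST_NEW_WORKER");
		return -1;
	}

	vring_worker.index = 1;                   /* virtqueue to move */
	vring_worker.worker_id = state.worker_id;

	/* Routed to vhost_vq_attach_worker() via vhost_worker_ioctl() above. */
	if (ioctl(vhost_fd, VHOST_ATTACH_VRING_WORKER, &vring_worker) < 0) {
		perror("VHOST_ATTACH_VRING_WORKER");
		return -1;
	}
	return 0;
}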
+
/* Caller should have device mutex */
long vhost_dev_set_owner(struct vhost_dev *dev)
{
- struct task_struct *worker;
- int err;
+ struct vhost_worker *worker;
+ int err, i;
/* Is there an owner already? */
if (vhost_dev_has_owner(dev)) {
@@ -581,36 +1107,33 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
vhost_attach_mm(dev);
- dev->kcov_handle = kcov_common_handle();
+ err = vhost_dev_alloc_iovecs(dev);
+ if (err)
+ goto err_iovecs;
+
if (dev->use_worker) {
- worker = kthread_create(vhost_worker, dev,
- "vhost-%d", current->pid);
- if (IS_ERR(worker)) {
- err = PTR_ERR(worker);
+ /*
+ * This should be done last, because vsock can queue work
+ * before VHOST_SET_OWNER so it simplifies the failure path
+ * below since we don't have to worry about vsock queueing
+ * while we free the worker.
+ */
+ worker = vhost_worker_create(dev);
+ if (!worker) {
+ err = -ENOMEM;
goto err_worker;
}
- dev->worker = worker;
- wake_up_process(worker); /* avoid contributing to loadavg */
-
- err = vhost_attach_cgroups(dev);
- if (err)
- goto err_cgroup;
+ for (i = 0; i < dev->nvqs; i++)
+ __vhost_vq_attach_worker(dev->vqs[i], worker);
}
- err = vhost_dev_alloc_iovecs(dev);
- if (err)
- goto err_cgroup;
-
return 0;
-err_cgroup:
- if (dev->worker) {
- kthread_stop(dev->worker);
- dev->worker = NULL;
- }
+
err_worker:
+ vhost_dev_free_iovecs(dev);
+err_iovecs:
vhost_detach_mm(dev);
- dev->kcov_handle = 0;
err_mm:
return err;
}
@@ -635,6 +1158,7 @@ void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
vhost_dev_cleanup(dev);
+ dev->fork_owner = fork_from_owner_default;
dev->umem = umem;
/* We don't need VQ locks below since vhost_dev_cleanup makes sure
* VQs aren't running.
@@ -649,15 +1173,15 @@ void vhost_dev_stop(struct vhost_dev *dev)
int i;
for (i = 0; i < dev->nvqs; ++i) {
- if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
+ if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
vhost_poll_stop(&dev->vqs[i]->poll);
- vhost_poll_flush(&dev->vqs[i]->poll);
- }
}
+
+ vhost_dev_flush(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_stop);
-static void vhost_clear_msg(struct vhost_dev *dev)
+void vhost_clear_msg(struct vhost_dev *dev)
{
struct vhost_msg_node *node, *n;
@@ -675,6 +1199,7 @@ static void vhost_clear_msg(struct vhost_dev *dev)
spin_unlock(&dev->iotlb_lock);
}
+EXPORT_SYMBOL_GPL(vhost_clear_msg);
void vhost_dev_cleanup(struct vhost_dev *dev)
{
@@ -685,8 +1210,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
eventfd_ctx_put(dev->vqs[i]->error_ctx);
if (dev->vqs[i]->kick)
fput(dev->vqs[i]->kick);
- if (dev->vqs[i]->call_ctx)
- eventfd_ctx_put(dev->vqs[i]->call_ctx);
+ if (dev->vqs[i]->call_ctx.ctx)
+ eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
vhost_vq_reset(dev, dev->vqs[i]);
}
vhost_dev_free_iovecs(dev);
@@ -700,12 +1225,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
dev->iotlb = NULL;
vhost_clear_msg(dev);
wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
- WARN_ON(!llist_empty(&dev->work_list));
- if (dev->worker) {
- kthread_stop(dev->worker);
- dev->worker = NULL;
- dev->kcov_handle = 0;
- }
+ vhost_workers_free(dev);
vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -723,10 +1243,16 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}
+/* Make sure 64 bit math will not overflow. */
static bool vhost_overflow(u64 uaddr, u64 size)
{
- /* Make sure 64 bit math will not overflow. */
- return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
+ if (uaddr > ULONG_MAX || size > ULONG_MAX)
+ return true;
+
+ if (!size)
+ return false;
+
+ return uaddr > ULONG_MAX - size + 1;
}
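A quick illustration, not from the patch, of the boundary cases the rewritten check is meant to cover. The userspace mirror below assumes a 64-bit host where ULONG_MAX equals UINT64_MAX (the explicit ULONG_MAX comparisons in the kernel version matter on 32-bit builds).

/* Illustrative only: userspace mirror of the new vhost_overflow() logic. */
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

static bool overflows(uint64_t uaddr, uint64_t size)
{
	if (!size)
		return false;                     /* the new size == 0 special case */
	return uaddr > UINT64_MAX - size + 1;     /* is the last byte past ULONG_MAX? */
}

int main(void)
{
	assert(!overflows(UINT64_MAX - 3, 4));    /* region ends exactly at ULONG_MAX */
	assert(overflows(UINT64_MAX - 2, 4));     /* last byte would wrap past ULONG_MAX */
	assert(!overflows(UINT64_MAX, 0));        /* zero length never overflows */
	return 0;
}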
/* Caller should have vq mutex and device mutex. */
@@ -822,7 +1348,7 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
VHOST_ACCESS_WO);
if (ret < 0)
goto out;
- iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
+ iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
ret = copy_to_iter(from, size, &t);
if (ret == size)
ret = 0;
@@ -861,7 +1387,7 @@ static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
(unsigned long long) size);
goto out;
}
- iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
+ iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
ret = copy_from_iter(to, size, &f);
if (ret == size)
ret = 0;
@@ -997,10 +1523,36 @@ static void vhost_dev_unlock_vqs(struct vhost_dev *d)
mutex_unlock(&d->vqs[i]->mutex);
}
-static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
- __virtio16 *idx)
+static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq)
{
- return vhost_get_avail(vq, *idx, &vq->avail->idx);
+ __virtio16 idx;
+ int r;
+
+ r = vhost_get_avail(vq, idx, &vq->avail->idx);
+ if (unlikely(r < 0)) {
+ vq_err(vq, "Failed to access available index at %p (%d)\n",
+ &vq->avail->idx, r);
+ return r;
+ }
+
+ /* Check it isn't doing very strange things with available indexes */
+ vq->avail_idx = vhost16_to_cpu(vq, idx);
+ if (unlikely((u16)(vq->avail_idx - vq->last_avail_idx) > vq->num)) {
+ vq_err(vq, "Invalid available index change from %u to %u",
+ vq->last_avail_idx, vq->avail_idx);
+ return -EINVAL;
+ }
+
+ /* We're done if there is nothing new */
+ if (vq->avail_idx == vq->last_avail_idx)
+ return 0;
+
+ /*
+ * We updated vq->avail_idx so we need a memory barrier between
+ * the index read above and the caller reading avail ring entries.
+ */
+ smp_rmb();
+ return 1;
}
static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
@@ -1072,11 +1624,14 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
return true;
}
-static int vhost_process_iotlb_msg(struct vhost_dev *dev,
+static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg)
{
int ret = 0;
+ if (asid != 0)
+ return -EINVAL;
+
mutex_lock(&dev->mutex);
vhost_dev_lock_vqs(dev);
switch (msg->type) {
@@ -1123,6 +1678,7 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct vhost_iotlb_msg msg;
size_t offset;
int type, ret;
+ u32 asid = 0;
ret = copy_from_iter(&type, sizeof(type), from);
if (ret != sizeof(type)) {
@@ -1138,7 +1694,16 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
break;
case VHOST_IOTLB_MSG_V2:
- offset = sizeof(__u32);
+ if (vhost_backend_has_feature(dev->vqs[0],
+ VHOST_BACKEND_F_IOTLB_ASID)) {
+ ret = copy_from_iter(&asid, sizeof(asid), from);
+ if (ret != sizeof(asid)) {
+ ret = -EINVAL;
+ goto done;
+ }
+ offset = 0;
+ } else
+ offset = sizeof(__u32);
break;
default:
ret = -EINVAL;
@@ -1152,10 +1717,15 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
goto done;
}
+ if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
+ ret = -EINVAL;
+ goto done;
+ }
+
if (dev->msg_handler)
- ret = dev->msg_handler(dev, &msg);
+ ret = dev->msg_handler(dev, asid, &msg);
else
- ret = vhost_process_iotlb_msg(dev, &msg);
+ ret = vhost_process_iotlb_msg(dev, asid, &msg);
if (ret) {
ret = -EFAULT;
goto done;
@@ -1283,6 +1853,11 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
vring_used_t __user *used)
{
+ /* If an IOTLB device is present, the vring addresses are
+ * GIOVAs. Access validation occurs at prefetch time. */
+ if (vq->iotlb)
+ return true;
+
return access_ok(desc, vhost_get_desc_size(vq, num)) &&
access_ok(avail, vhost_get_avail_size(vq, num)) &&
access_ok(used, vhost_get_used_size(vq, num));
@@ -1358,6 +1933,20 @@ bool vhost_log_access_ok(struct vhost_dev *dev)
}
EXPORT_SYMBOL_GPL(vhost_log_access_ok);
+static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
+ void __user *log_base,
+ bool log_used,
+ u64 log_addr)
+{
+ /* If an IOTLB device is present, log_addr is a GIOVA that
+ * will never be logged by log_used(). */
+ if (vq->iotlb)
+ return true;
+
+ return !log_used || log_access_ok(log_base, log_addr,
+ vhost_get_used_size(vq, vq->num));
+}
+
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static bool vq_log_access_ok(struct vhost_virtqueue *vq,
@@ -1365,8 +1954,7 @@ static bool vq_log_access_ok(struct vhost_virtqueue *vq,
{
return vq_memory_access_ok(log_base, vq->umem,
vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
- (!vq->log_used || log_access_ok(log_base, vq->log_addr,
- vhost_get_used_size(vq, vq->num)));
+ vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
}
/* Can we start vq? */
@@ -1376,10 +1964,6 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
if (!vq_log_access_ok(vq, vq->log_base))
return false;
- /* Access validation occurs at prefetch time with IOTLB */
- if (vq->iotlb)
- return true;
-
return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
}
EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
@@ -1405,7 +1989,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
memcpy(newmem, &mem, size);
if (copy_from_user(newmem->regions, m->regions,
- mem.nregions * sizeof *m->regions)) {
+ flex_array_size(newmem, regions, mem.nregions))) {
kvfree(newmem);
return -EFAULT;
}
@@ -1509,10 +2093,9 @@ static long vhost_vring_set_addr(struct vhost_dev *d,
return -EINVAL;
/* Also validate log access for used ring if enabled. */
- if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
- !log_access_ok(vq->log_base, a.log_guest_addr,
- sizeof *vq->used +
- vq->num * sizeof *vq->used->ring))
+ if (!vq_log_used_access_ok(vq, vq->log_base,
+ a.flags & (0x1 << VHOST_VRING_F_LOG),
+ a.log_guest_addr))
return -EINVAL;
}
@@ -1554,21 +2137,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
struct file *eventfp, *filep = NULL;
bool pollstart = false, pollstop = false;
struct eventfd_ctx *ctx = NULL;
- u32 __user *idxp = argp;
struct vhost_virtqueue *vq;
struct vhost_vring_state s;
struct vhost_vring_file f;
u32 idx;
long r;
- r = get_user(idx, idxp);
+ r = vhost_get_vq_from_user(d, argp, &vq, &idx);
if (r < 0)
return r;
- if (idx >= d->nvqs)
- return -ENOBUFS;
-
- idx = array_index_nospec(idx, d->nvqs);
- vq = d->vqs[idx];
if (ioctl == VHOST_SET_VRING_NUM ||
ioctl == VHOST_SET_VRING_ADDR) {
@@ -1589,17 +2166,26 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
r = -EFAULT;
break;
}
- if (s.num > 0xffff) {
- r = -EINVAL;
- break;
+ if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
+ vq->next_avail_head = vq->last_avail_idx =
+ s.num & 0xffff;
+ vq->last_used_idx = (s.num >> 16) & 0xffff;
+ } else {
+ if (s.num > 0xffff) {
+ r = -EINVAL;
+ break;
+ }
+ vq->next_avail_head = vq->last_avail_idx = s.num;
}
- vq->last_avail_idx = s.num;
/* Forget the cached index value. */
vq->avail_idx = vq->last_avail_idx;
break;
case VHOST_GET_VRING_BASE:
s.index = idx;
- s.num = vq->last_avail_idx;
+ if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
+ s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16);
+ else
+ s.num = vq->last_avail_idx;
if (copy_to_user(argp, &s, sizeof s))
r = -EFAULT;
break;
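For reference, and not part of the patch: with VIRTIO_F_RING_PACKED the single 32-bit num field now carries both ring indices. Below is a hedged userspace sketch of decoding it, assuming an already configured vhost fd and the standard struct vhost_vring_state from <linux/vhost.h>.

/* Illustrative only: split the packed-ring state returned above. */
#include <stdint.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int read_packed_base(int vhost_fd, unsigned int qidx,
			    uint16_t *last_avail_idx, uint16_t *last_used_idx)
{
	struct vhost_vring_state s = { .index = qidx };

	if (ioctl(vhost_fd, VHOST_GET_VRING_BASE, &s) < 0)
		return -1;

	*last_avail_idx = s.num & 0xffff;         /* low 16 bits */
	*last_used_idx  = (s.num >> 16) & 0xffff; /* high 16 bits */

	/* Feed the same encoding back via VHOST_SET_VRING_BASE on restore. */
	return 0;
}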
@@ -1629,7 +2215,8 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
r = PTR_ERR(ctx);
break;
}
- swap(ctx, vq->call_ctx);
+
+ swap(ctx, vq->call_ctx.ctx);
break;
case VHOST_SET_VRING_ERR:
if (copy_from_user(&f, argp, sizeof f)) {
@@ -1680,12 +2267,12 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
mutex_unlock(&vq->mutex);
if (pollstop && vq->handle_kick)
- vhost_poll_flush(&vq->poll);
+ vhost_dev_flush(vq->poll.dev);
return r;
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
-int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
+int vhost_init_device_iotlb(struct vhost_dev *d)
{
struct vhost_iotlb *niotlb, *oiotlb;
int i;
@@ -1726,6 +2313,45 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
goto done;
}
+#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL
+ if (ioctl == VHOST_SET_FORK_FROM_OWNER) {
+ /* Only allow modification before owner is set */
+ if (vhost_dev_has_owner(d)) {
+ r = -EBUSY;
+ goto done;
+ }
+ u8 fork_owner_val;
+
+ if (get_user(fork_owner_val, (u8 __user *)argp)) {
+ r = -EFAULT;
+ goto done;
+ }
+ if (fork_owner_val != VHOST_FORK_OWNER_TASK &&
+ fork_owner_val != VHOST_FORK_OWNER_KTHREAD) {
+ r = -EINVAL;
+ goto done;
+ }
+ d->fork_owner = !!fork_owner_val;
+ r = 0;
+ goto done;
+ }
+ if (ioctl == VHOST_GET_FORK_FROM_OWNER) {
+ u8 fork_owner_val = d->fork_owner;
+
+ if (fork_owner_val != VHOST_FORK_OWNER_TASK &&
+ fork_owner_val != VHOST_FORK_OWNER_KTHREAD) {
+ r = -EINVAL;
+ goto done;
+ }
+ if (put_user(fork_owner_val, (u8 __user *)argp)) {
+ r = -EFAULT;
+ goto done;
+ }
+ r = 0;
+ goto done;
+ }
+#endif
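Not part of the patch: a hedged sketch of how userspace would opt a device back into kthread workers, assuming the kernel was built with CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL and that VHOST_SET_FORK_FROM_OWNER and VHOST_FORK_OWNER_KTHREAD are exposed by the matching <linux/vhost.h>. Per the -EBUSY check above, it must run before VHOST_SET_OWNER.

/* Illustrative only: must run before VHOST_SET_OWNER (else -EBUSY above). */
#include <stdint.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int use_kthread_workers(int vhost_fd)
{
	uint8_t mode = VHOST_FORK_OWNER_KTHREAD;

	if (ioctl(vhost_fd, VHOST_SET_FORK_FROM_OWNER, &mode) < 0)
		return -1;

	/* From here on, vhost_worker_create() picks kthread_ops. */
	return ioctl(vhost_fd, VHOST_SET_OWNER);
}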
+
/* You must be the owner to do anything else */
r = vhost_dev_check_owner(d);
if (r)
@@ -1786,7 +2412,7 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
/* TODO: This is really inefficient. We need something like get_user()
* (instruction directly accesses the data, with an exception table entry
- * returning -EFAULT). See Documentation/x86/exception-tables.rst.
+ * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
*/
static int set_bit_to_user(int nr, void __user *addr)
{
@@ -1874,7 +2500,7 @@ static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
{
- struct iovec iov[64];
+ struct iovec *iov = vq->log_iov;
int i, ret;
if (!vq->iotlb)
@@ -1895,6 +2521,19 @@ static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
return 0;
}
+/*
+ * vhost_log_write() - Log in dirty page bitmap
+ * @vq: vhost virtqueue.
+ * @log: Array of dirty memory in GPA.
+ * @log_num: Size of the vhost_log array.
+ * @len: The total length of memory buffer to log in the dirty bitmap.
+ * Some drivers may only partially use pages shared via the last
+ * vring descriptor (i.e. vhost-net RX buffer).
+ * Use (len == U64_MAX) to indicate the driver would log all
+ * pages of vring descriptors.
+ * @iov: Array of dirty memory in HVA.
+ * @count: Size of iovec array.
+ */
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
unsigned int log_num, u64 len, struct iovec *iov, int count)
{
@@ -1918,15 +2557,14 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
r = log_write(vq->log_base, log[i].addr, l);
if (r < 0)
return r;
- len -= l;
- if (!len) {
- if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
- return 0;
- }
+
+ if (len != U64_MAX)
+ len -= l;
}
- /* Length written exceeds what we have stored. This is a bug. */
- BUG();
+
+ if (vq->log_ctx)
+ eventfd_signal(vq->log_ctx);
+
return 0;
}
EXPORT_SYMBOL_GPL(vhost_log_write);
@@ -1944,12 +2582,12 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
log_used(vq, (used - (void __user *)vq->used),
sizeof vq->used->flags);
if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
+ eventfd_signal(vq->log_ctx);
}
return 0;
}
-static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
+static int vhost_update_avail_event(struct vhost_virtqueue *vq)
{
if (vhost_put_avail_event(vq))
return -EFAULT;
@@ -1962,7 +2600,7 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
log_used(vq, (used - (void __user *)vq->used),
sizeof *vhost_avail_event(vq));
if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
+ eventfd_signal(vq->log_ctx);
}
return 0;
}
@@ -2009,7 +2647,7 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct vhost_dev *dev = vq->dev;
struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
struct iovec *_iov;
- u64 s = 0;
+ u64 s = 0, last = addr + len - 1;
int ret = 0;
while ((u64)len > s) {
@@ -2019,7 +2657,7 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
break;
}
- map = vhost_iotlb_itree_first(umem, addr, addr + len - 1);
+ map = vhost_iotlb_itree_first(umem, addr, last);
if (map == NULL || map->start > addr) {
if (umem != dev->iotlb) {
ret = -EFAULT;
@@ -2091,12 +2729,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
vq_err(vq, "Translation failure %d in indirect.\n", ret);
return ret;
}
- iov_iter_init(&from, READ, vq->indirect, ret, len);
-
- /* We will use the result as an address to read from, so most
- * architectures only need a compiler barrier here. */
- read_barrier_depends();
-
+ iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
count = len / sizeof desc;
/* Buffers are chained via a 16 bit next field, so
* we can have at most 2^16 of these. */
@@ -2161,66 +2794,66 @@ static int get_indirect(struct vhost_virtqueue *vq,
return 0;
}
-/* This looks in the virtqueue and for the first available buffer, and converts
- * it to an iovec for convenient access. Since descriptors consist of some
- * number of output then some number of input descriptors, it's actually two
- * iovecs, but we pack them into one and note how many of each there were.
+/**
+ * vhost_get_vq_desc_n - Fetch the next available descriptor chain and build iovecs
+ * @vq: target virtqueue
+ * @iov: array that receives the scatter/gather segments
+ * @iov_size: capacity of @iov in elements
+ * @out_num: the number of output segments
+ * @in_num: the number of input segments
+ * @log: optional array to record addr/len for each writable segment; NULL if unused
+ * @log_num: optional output; number of entries written to @log when provided
+ * @ndesc: optional output; number of descriptors consumed from the available ring
+ * (useful for rollback via vhost_discard_vq_desc)
*
- * This function returns the descriptor number found, or vq->num (which is
- * never a valid descriptor number) if none was found. A negative code is
- * returned on error. */
-int vhost_get_vq_desc(struct vhost_virtqueue *vq,
- struct iovec iov[], unsigned int iov_size,
- unsigned int *out_num, unsigned int *in_num,
- struct vhost_log *log, unsigned int *log_num)
+ * Extracts one available descriptor chain from @vq and translates guest addresses
+ * into host iovecs.
+ *
+ * On success, advances @vq->last_avail_idx by 1 and @vq->next_avail_head by the
+ * number of descriptors consumed (also stored via @ndesc when non-NULL).
+ *
+ * Return:
+ * - head index in [0, @vq->num) on success;
+ * - @vq->num if no descriptor is currently available;
+ * - negative errno on failure
+ */
+int vhost_get_vq_desc_n(struct vhost_virtqueue *vq,
+ struct iovec iov[], unsigned int iov_size,
+ unsigned int *out_num, unsigned int *in_num,
+ struct vhost_log *log, unsigned int *log_num,
+ unsigned int *ndesc)
{
+ bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
struct vring_desc desc;
unsigned int i, head, found = 0;
- u16 last_avail_idx;
- __virtio16 avail_idx;
+ u16 last_avail_idx = vq->last_avail_idx;
__virtio16 ring_head;
- int ret, access;
-
- /* Check it isn't doing very strange things with descriptor numbers. */
- last_avail_idx = vq->last_avail_idx;
+ int ret, access, c = 0;
if (vq->avail_idx == vq->last_avail_idx) {
- if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
- vq_err(vq, "Failed to access avail idx at %p\n",
- &vq->avail->idx);
- return -EFAULT;
- }
- vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
-
- if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
- vq_err(vq, "Guest moved used index from %u to %u",
- last_avail_idx, vq->avail_idx);
- return -EFAULT;
- }
+ ret = vhost_get_avail_idx(vq);
+ if (unlikely(ret < 0))
+ return ret;
- /* If there's nothing new since last we looked, return
- * invalid.
- */
- if (vq->avail_idx == last_avail_idx)
+ if (!ret)
return vq->num;
-
- /* Only get avail ring entries after they have been
- * exposed by guest.
- */
- smp_rmb();
}
- /* Grab the next descriptor number they're advertising, and increment
- * the index we've seen. */
- if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
- vq_err(vq, "Failed to read head: idx %d address %p\n",
- last_avail_idx,
- &vq->avail->ring[last_avail_idx % vq->num]);
- return -EFAULT;
+ if (in_order)
+ head = vq->next_avail_head & (vq->num - 1);
+ else {
+ /* Grab the next descriptor number they're
+ * advertising, and increment the index we've seen. */
+ if (unlikely(vhost_get_avail_head(vq, &ring_head,
+ last_avail_idx))) {
+ vq_err(vq, "Failed to read head: idx %d address %p\n",
+ last_avail_idx,
+ &vq->avail->ring[last_avail_idx % vq->num]);
+ return -EFAULT;
+ }
+ head = vhost16_to_cpu(vq, ring_head);
}
- head = vhost16_to_cpu(vq, ring_head);
-
/* If their number is silly, that's an error. */
if (unlikely(head >= vq->num)) {
vq_err(vq, "Guest says index %u > %u is available",
@@ -2263,6 +2896,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
"in indirect descriptor at idx %d\n", i);
return ret;
}
+ ++c;
continue;
}
@@ -2298,22 +2932,56 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
}
*out_num += ret;
}
+ ++c;
} while ((i = next_desc(vq, &desc)) != -1);
/* On success, increment avail index. */
vq->last_avail_idx++;
+ vq->next_avail_head += c;
+
+ if (ndesc)
+ *ndesc = c;
/* Assume notifications from guest are disabled at this point,
* if they aren't we would need to update avail_event index. */
BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
return head;
}
+EXPORT_SYMBOL_GPL(vhost_get_vq_desc_n);
+
+/* This looks in the virtqueue and for the first available buffer, and converts
+ * it to an iovec for convenient access. Since descriptors consist of some
+ * number of output then some number of input descriptors, it's actually two
+ * iovecs, but we pack them into one and note how many of each there were.
+ *
+ * This function returns the descriptor number found, or vq->num (which is
+ * never a valid descriptor number) if none was found. A negative code is
+ * returned on error.
+ */
+int vhost_get_vq_desc(struct vhost_virtqueue *vq,
+ struct iovec iov[], unsigned int iov_size,
+ unsigned int *out_num, unsigned int *in_num,
+ struct vhost_log *log, unsigned int *log_num)
+{
+ return vhost_get_vq_desc_n(vq, iov, iov_size, out_num, in_num,
+ log, log_num, NULL);
+}
EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
-/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
-void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
+/**
+ * vhost_discard_vq_desc - Reverse the effect of vhost_get_vq_desc_n()
+ * @vq: target virtqueue
+ * @nbufs: number of buffers to roll back
+ * @ndesc: number of descriptors to roll back
+ *
+ * Rewinds the internal consumer cursors after a failed attempt to use buffers
+ * returned by vhost_get_vq_desc_n().
+ */
+void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int nbufs,
+ unsigned int ndesc)
{
- vq->last_avail_idx -= n;
+ vq->next_avail_head -= ndesc;
+ vq->last_avail_idx -= nbufs;
}
EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
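Not part of the patch: a minimal kernel-side sketch of the rollback contract documented above. my_try_one_buf() and my_backend_can_process() are invented names; the sketch only uses vhost_get_vq_desc_n(), vhost_discard_vq_desc() and the vq->iov array.

/* Illustrative only: fetch one chain, roll it back on a transient error.
 * my_backend_can_process() is a hypothetical backend-specific check.
 */
static int my_try_one_buf(struct vhost_virtqueue *vq)
{
	unsigned int out, in, ndesc;
	int head;

	head = vhost_get_vq_desc_n(vq, vq->iov, ARRAY_SIZE(vq->iov),
				   &out, &in, NULL, NULL, &ndesc);
	if (head < 0)
		return head;            /* translation or access error */
	if (head == vq->num)
		return -EAGAIN;         /* nothing available right now */

	if (!my_backend_can_process(vq, out, in)) {
		/* Undo: one buffer consumed, ndesc descriptors walked. */
		vhost_discard_vq_desc(vq, 1, ndesc);
		return -ENOBUFS;
	}

	/* ... consume the buffer, then vhost_add_used(vq, head, len) ... */
	return 0;
}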
@@ -2325,8 +2993,9 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
cpu_to_vhost32(vq, head),
cpu_to_vhost32(vq, len)
};
+ u16 nheads = 1;
- return vhost_add_used_n(vq, &heads, 1);
+ return vhost_add_used_n(vq, &heads, &nheads, 1);
}
EXPORT_SYMBOL_GPL(vhost_add_used);
@@ -2362,10 +3031,9 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
return 0;
}
-/* After we've used one of their buffers, we tell them about it. We'll then
- * want to notify the guest, using eventfd. */
-int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
- unsigned count)
+static int vhost_add_used_n_ooo(struct vhost_virtqueue *vq,
+ struct vring_used_elem *heads,
+ unsigned count)
{
int start, n, r;
@@ -2378,7 +3046,72 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
heads += n;
count -= n;
}
- r = __vhost_add_used_n(vq, heads, count);
+ return __vhost_add_used_n(vq, heads, count);
+}
+
+static int vhost_add_used_n_in_order(struct vhost_virtqueue *vq,
+ struct vring_used_elem *heads,
+ const u16 *nheads,
+ unsigned count)
+{
+ vring_used_elem_t __user *used;
+ u16 old, new = vq->last_used_idx;
+ int start, i;
+
+ if (!nheads)
+ return -EINVAL;
+
+ start = vq->last_used_idx & (vq->num - 1);
+ used = vq->used->ring + start;
+
+ for (i = 0; i < count; i++) {
+ if (vhost_put_used(vq, &heads[i], start, 1)) {
+ vq_err(vq, "Failed to write used");
+ return -EFAULT;
+ }
+ start += nheads[i];
+ new += nheads[i];
+ if (start >= vq->num)
+ start -= vq->num;
+ }
+
+ if (unlikely(vq->log_used)) {
+ /* Make sure data is seen before log. */
+ smp_wmb();
+ /* Log used ring entry write. */
+ log_used(vq, ((void __user *)used - (void __user *)vq->used),
+ (vq->num - start) * sizeof *used);
+ if (start + count > vq->num)
+ log_used(vq, 0,
+ (start + count - vq->num) * sizeof *used);
+ }
+
+ old = vq->last_used_idx;
+ vq->last_used_idx = new;
+ /* If the driver never bothers to signal in a very long while,
+ * used index might wrap around. If that happens, invalidate
+ * signalled_used index we stored. TODO: make sure driver
+ * signals at least once in 2^16 and remove this. */
+ if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
+ vq->signalled_used_valid = false;
+ return 0;
+}
+
+/* After we've used one of their buffers, we tell them about it. We'll then
+ * want to notify the guest, using eventfd. */
+int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
+ u16 *nheads, unsigned count)
+{
+ bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
+ int r;
+
+ if (!in_order || !nheads)
+ r = vhost_add_used_n_ooo(vq, heads, count);
+ else
+ r = vhost_add_used_n_in_order(vq, heads, nheads, count);
+
+ if (r < 0)
+ return r;
/* Make sure buffer is written before we update index. */
smp_wmb();
@@ -2393,7 +3126,7 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
log_used(vq, offsetof(struct vring_used, idx),
sizeof vq->used->idx);
if (vq->log_ctx)
- eventfd_signal(vq->log_ctx, 1);
+ eventfd_signal(vq->log_ctx);
}
return r;
}
@@ -2440,8 +3173,8 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
/* Signal the Guest tell them we used something up. */
- if (vq->call_ctx && vhost_notify(dev, vq))
- eventfd_signal(vq->call_ctx, 1);
+ if (vq->call_ctx.ctx && vhost_notify(dev, vq))
+ eventfd_signal(vq->call_ctx.ctx);
}
EXPORT_SYMBOL_GPL(vhost_signal);
@@ -2458,35 +3191,33 @@ EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
/* multi-buffer version of vhost_add_used_and_signal */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
struct vhost_virtqueue *vq,
- struct vring_used_elem *heads, unsigned count)
+ struct vring_used_elem *heads,
+ u16 *nheads,
+ unsigned count)
{
- vhost_add_used_n(vq, heads, count);
+ vhost_add_used_n(vq, heads, nheads, count);
vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
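Not part of the patch: a hedged sketch of a batched completion path using the new nheads argument, assuming VIRTIO_F_IN_ORDER was negotiated and that the backend recorded per-chain descriptor counts in the vq->nheads array added earlier in this patch. my_complete_batch() is an invented name.

/* Illustrative only: complete a batch of in-order buffers in one call.
 * vq->heads[i] holds {id, len} for each finished chain and vq->nheads[i]
 * the number of descriptors that chain consumed, both recorded while the
 * buffers were fetched with vhost_get_vq_desc_n().
 */
static void my_complete_batch(struct vhost_dev *dev, struct vhost_virtqueue *vq,
			      unsigned int done)
{
	vhost_add_used_and_signal_n(dev, vq, vq->heads, vq->nheads, done);
}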
-/* return true if we're sure that avaiable ring is empty */
+/* return true if we're sure that available ring is empty */
bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
- __virtio16 avail_idx;
int r;
if (vq->avail_idx != vq->last_avail_idx)
return false;
- r = vhost_get_avail_idx(vq, &avail_idx);
- if (unlikely(r))
- return false;
- vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
+ r = vhost_get_avail_idx(vq);
- return vq->avail_idx == vq->last_avail_idx;
+ /* Note: we treat error as non-empty here */
+ return r == 0;
}
EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
/* OK, now we need to know about added descriptors. */
bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
- __virtio16 avail_idx;
int r;
if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
@@ -2500,7 +3231,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
return false;
}
} else {
- r = vhost_update_avail_event(vq, vq->avail_idx);
+ r = vhost_update_avail_event(vq);
if (r) {
vq_err(vq, "Failed to update avail event index at %p: %d\n",
vhost_avail_event(vq), r);
@@ -2510,14 +3241,13 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
/* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */
smp_mb();
- r = vhost_get_avail_idx(vq, &avail_idx);
- if (r) {
- vq_err(vq, "Failed to check avail idx at %p: %d\n",
- &vq->avail->idx, r);
+
+ r = vhost_get_avail_idx(vq);
+ /* Note: we treat error as empty here */
+ if (unlikely(r < 0))
return false;
- }
- return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
+ return r;
}
EXPORT_SYMBOL_GPL(vhost_enable_notify);
@@ -2532,7 +3262,7 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
r = vhost_update_used_flags(vq);
if (r)
- vq_err(vq, "Failed to enable notification at %p: %d\n",
+ vq_err(vq, "Failed to disable notification at %p: %d\n",
&vq->used->flags, r);
}
}
@@ -2541,12 +3271,11 @@ EXPORT_SYMBOL_GPL(vhost_disable_notify);
/* Create a new message. */
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
{
- struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
+ /* Make sure all padding within the structure is initialized. */
+ struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
if (!node)
return NULL;
- /* Make sure all padding within the structure is initialized. */
- memset(&node->msg, 0, sizeof node->msg);
node->vq = vq;
node->msg.type = type;
return node;
@@ -2581,6 +3310,21 @@ struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
}
EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
+void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
+{
+ struct vhost_virtqueue *vq;
+ int i;
+
+ mutex_lock(&dev->mutex);
+ for (i = 0; i < dev->nvqs; ++i) {
+ vq = dev->vqs[i];
+ mutex_lock(&vq->mutex);
+ vq->acked_backend_features = features;
+ mutex_unlock(&vq->mutex);
+ }
+ mutex_unlock(&dev->mutex);
+}
+EXPORT_SYMBOL_GPL(vhost_set_backend_features);
static int __init vhost_init(void)
{