Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--  drivers/vhost/vhost.c | 233
1 file changed, 79 insertions(+), 154 deletions(-)
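
This diff converts vhost.c from its private vhost_umem/interval-tree bookkeeping to the generic vhost_iotlb API, and adds an optional per-device msg_handler callback for IOTLB messages. For orientation, a sketch of the vhost_iotlb entry points as inferred from the call sites below; the authoritative declarations live in include/linux/vhost_iotlb.h and exact parameter types may differ:

    /* Inferred from call sites in this diff -- a sketch, not the header. */
    struct vhost_iotlb *vhost_iotlb_alloc(unsigned int limit, unsigned int flags);
    void vhost_iotlb_free(struct vhost_iotlb *iotlb);
    int vhost_iotlb_add_range(struct vhost_iotlb *iotlb, u64 start, u64 last,
                              u64 addr, unsigned int perm);
    void vhost_iotlb_del_range(struct vhost_iotlb *iotlb, u64 start, u64 last);
    struct vhost_iotlb_map *vhost_iotlb_itree_first(struct vhost_iotlb *iotlb,
                                                    u64 start, u64 last);
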
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index f44340b41494..d450e16c5c25 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -50,10 +50,6 @@ enum {
#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
-INTERVAL_TREE_DEFINE(struct vhost_umem_node,
- rb, __u64, __subtree_last,
- START, LAST, static inline, vhost_umem_interval_tree);
-
#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{
@@ -457,7 +453,9 @@ static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
void vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue **vqs, int nvqs,
- int iov_limit, int weight, int byte_weight)
+ int iov_limit, int weight, int byte_weight,
+ int (*msg_handler)(struct vhost_dev *dev,
+ struct vhost_iotlb_msg *msg))
{
struct vhost_virtqueue *vq;
int i;
@@ -473,6 +471,7 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
+ dev->msg_handler = msg_handler;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
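
The extra vhost_dev_init() parameter lets a device supply its own IOTLB message handler; passing NULL keeps the default vhost_process_iotlb_msg() path (see the dispatch change in vhost_chr_write_iter() further down). A minimal, hypothetical caller -- the handler name, body, and weight arguments are illustrative, not taken from this diff:

    static int my_msg_handler(struct vhost_dev *dev, struct vhost_iotlb_msg *msg)
    {
            /* dev->iotlb may be NULL until the device IOTLB is enabled. */
            if (msg->type == VHOST_IOTLB_UPDATE && dev->iotlb)
                    return vhost_iotlb_add_range(dev->iotlb, msg->iova,
                                                 msg->iova + msg->size - 1,
                                                 msg->uaddr, msg->perm);
            return -EINVAL;
    }

    vhost_dev_init(&d->dev, vqs, nvqs, UIO_MAXIOV, 0, 0, my_msg_handler);
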
@@ -581,21 +580,25 @@ err_mm:
}
EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
-struct vhost_umem *vhost_dev_reset_owner_prepare(void)
+static struct vhost_iotlb *iotlb_alloc(void)
+{
+ return vhost_iotlb_alloc(max_iotlb_entries,
+ VHOST_IOTLB_FLAG_RETIRE);
+}
+
+struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
{
- return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL);
+ return iotlb_alloc();
}
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
/* Caller should have device mutex */
-void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
+void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
{
int i;
vhost_dev_cleanup(dev);
- /* Restore memory to default empty mapping. */
- INIT_LIST_HEAD(&umem->umem_list);
dev->umem = umem;
/* We don't need VQ locks below since vhost_dev_cleanup makes sure
* VQs aren't running.
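
iotlb_alloc() centralizes the allocation policy: the table is capped at max_iotlb_entries, and VHOST_IOTLB_FLAG_RETIRE asks the generic code to evict the oldest mapping when the cap is reached -- the behaviour the removed vhost_new_umem_range() used to open-code. Roughly, as a sketch of the generic side assuming the usual limit/nmaps/list fields (this logic lives in the shared iotlb code, not in this file):

    if (iotlb->limit && iotlb->nmaps == iotlb->limit &&
        (iotlb->flags & VHOST_IOTLB_FLAG_RETIRE)) {
            map = list_first_entry(&iotlb->list, typeof(*map), link);
            vhost_iotlb_map_free(iotlb, map);   /* drop the oldest entry */
    }
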
@@ -618,28 +621,6 @@ void vhost_dev_stop(struct vhost_dev *dev)
}
EXPORT_SYMBOL_GPL(vhost_dev_stop);
-static void vhost_umem_free(struct vhost_umem *umem,
- struct vhost_umem_node *node)
-{
- vhost_umem_interval_tree_remove(node, &umem->umem_tree);
- list_del(&node->link);
- kfree(node);
- umem->numem--;
-}
-
-static void vhost_umem_clean(struct vhost_umem *umem)
-{
- struct vhost_umem_node *node, *tmp;
-
- if (!umem)
- return;
-
- list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
- vhost_umem_free(umem, node);
-
- kvfree(umem);
-}
-
static void vhost_clear_msg(struct vhost_dev *dev)
{
struct vhost_msg_node *node, *n;
@@ -677,9 +658,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
eventfd_ctx_put(dev->log_ctx);
dev->log_ctx = NULL;
/* No one will access memory at this point */
- vhost_umem_clean(dev->umem);
+ vhost_iotlb_free(dev->umem);
dev->umem = NULL;
- vhost_umem_clean(dev->iotlb);
+ vhost_iotlb_free(dev->iotlb);
dev->iotlb = NULL;
vhost_clear_msg(dev);
wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
@@ -715,27 +696,26 @@ static bool vhost_overflow(u64 uaddr, u64 size)
}
/* Caller should have vq mutex and device mutex. */
-static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
+static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
int log_all)
{
- struct vhost_umem_node *node;
+ struct vhost_iotlb_map *map;
if (!umem)
return false;
- list_for_each_entry(node, &umem->umem_list, link) {
- unsigned long a = node->userspace_addr;
+ list_for_each_entry(map, &umem->list, link) {
+ unsigned long a = map->addr;
- if (vhost_overflow(node->userspace_addr, node->size))
+ if (vhost_overflow(map->addr, map->size))
return false;
- if (!access_ok((void __user *)a,
- node->size))
+ if (!access_ok((void __user *)a, map->size))
return false;
else if (log_all && !log_access_ok(log_base,
- node->start,
- node->size))
+ map->start,
+ map->size))
return false;
}
return true;
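
The bulk of this hunk (and of those that follow) is mechanical renaming. Summarizing the correspondence implied by the substitutions:

    /* Old private struct/field          New generic equivalent
     * struct vhost_umem             ->  struct vhost_iotlb
     * struct vhost_umem_node        ->  struct vhost_iotlb_map
     * node->userspace_addr          ->  map->addr
     * node->start / last / size     ->  map->start / last / size
     * node->perm                    ->  map->perm
     * umem->umem_list               ->  umem->list
     */
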
@@ -745,17 +725,17 @@ static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
u64 addr, unsigned int size,
int type)
{
- const struct vhost_umem_node *node = vq->meta_iotlb[type];
+ const struct vhost_iotlb_map *map = vq->meta_iotlb[type];
- if (!node)
+ if (!map)
return NULL;
- return (void *)(uintptr_t)(node->userspace_addr + addr - node->start);
+ return (void *)(uintptr_t)(map->addr + addr - map->start);
}
/* Can we switch to this memory table? */
/* Caller should have device mutex but not vq mutex */
-static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
+static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
int log_all)
{
int i;
@@ -1020,47 +1000,6 @@ static inline int vhost_get_desc(struct vhost_virtqueue *vq,
return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
}
-static int vhost_new_umem_range(struct vhost_umem *umem,
- u64 start, u64 size, u64 end,
- u64 userspace_addr, int perm)
-{
- struct vhost_umem_node *tmp, *node;
-
- if (!size)
- return -EFAULT;
-
- node = kmalloc(sizeof(*node), GFP_ATOMIC);
- if (!node)
- return -ENOMEM;
-
- if (umem->numem == max_iotlb_entries) {
- tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
- vhost_umem_free(umem, tmp);
- }
-
- node->start = start;
- node->size = size;
- node->last = end;
- node->userspace_addr = userspace_addr;
- node->perm = perm;
- INIT_LIST_HEAD(&node->link);
- list_add_tail(&node->link, &umem->umem_list);
- vhost_umem_interval_tree_insert(node, &umem->umem_tree);
- umem->numem++;
-
- return 0;
-}
-
-static void vhost_del_umem_range(struct vhost_umem *umem,
- u64 start, u64 end)
-{
- struct vhost_umem_node *node;
-
- while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
- start, end)))
- vhost_umem_free(umem, node);
-}
-
static void vhost_iotlb_notify_vq(struct vhost_dev *d,
struct vhost_iotlb_msg *msg)
{
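
Note the argument-order change hiding in the removal above: the private vhost_new_umem_range() took (start, size, end) while the generic vhost_iotlb_add_range() takes an inclusive (start, last) pair and derives the size itself. The equivalence, spelled out:

    /* These two calls describe the same mapping: */
    vhost_new_umem_range(umem, start, size, start + size - 1, uaddr, perm);
    vhost_iotlb_add_range(iotlb, start, start + size - 1, uaddr, perm);
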
@@ -1117,9 +1056,9 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
break;
}
vhost_vq_meta_reset(dev);
- if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
- msg->iova + msg->size - 1,
- msg->uaddr, msg->perm)) {
+ if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
+ msg->iova + msg->size - 1,
+ msg->uaddr, msg->perm)) {
ret = -ENOMEM;
break;
}
@@ -1131,8 +1070,8 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
break;
}
vhost_vq_meta_reset(dev);
- vhost_del_umem_range(dev->iotlb, msg->iova,
- msg->iova + msg->size - 1);
+ vhost_iotlb_del_range(dev->iotlb, msg->iova,
+ msg->iova + msg->size - 1);
break;
default:
ret = -EINVAL;
@@ -1178,7 +1117,12 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
ret = -EINVAL;
goto done;
}
- if (vhost_process_iotlb_msg(dev, &msg)) {
+
+ if (dev->msg_handler)
+ ret = dev->msg_handler(dev, &msg);
+ else
+ ret = vhost_process_iotlb_msg(dev, &msg);
+ if (ret) {
ret = -EFAULT;
goto done;
}
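
One subtlety in the hunk above: whichever path runs, a non-zero result is collapsed to -EFAULT before returning to userspace, so a custom msg_handler cannot surface a more specific errno through this write path. Restated with comments:

    if (dev->msg_handler)
            ret = dev->msg_handler(dev, &msg);        /* device override */
    else
            ret = vhost_process_iotlb_msg(dev, &msg); /* default path */
    if (ret) {
            ret = -EFAULT;  /* the handler's own errno is discarded */
            goto done;
    }
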
@@ -1311,44 +1255,42 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
}
static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
- const struct vhost_umem_node *node,
+ const struct vhost_iotlb_map *map,
int type)
{
int access = (type == VHOST_ADDR_USED) ?
VHOST_ACCESS_WO : VHOST_ACCESS_RO;
- if (likely(node->perm & access))
- vq->meta_iotlb[type] = node;
+ if (likely(map->perm & access))
+ vq->meta_iotlb[type] = map;
}
static bool iotlb_access_ok(struct vhost_virtqueue *vq,
int access, u64 addr, u64 len, int type)
{
- const struct vhost_umem_node *node;
- struct vhost_umem *umem = vq->iotlb;
+ const struct vhost_iotlb_map *map;
+ struct vhost_iotlb *umem = vq->iotlb;
u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
if (vhost_vq_meta_fetch(vq, addr, len, type))
return true;
while (len > s) {
- node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
- addr,
- last);
- if (node == NULL || node->start > addr) {
+ map = vhost_iotlb_itree_first(umem, addr, last);
+ if (map == NULL || map->start > addr) {
vhost_iotlb_miss(vq, addr, access);
return false;
- } else if (!(node->perm & access)) {
+ } else if (!(map->perm & access)) {
/* Report the possible access violation by
* requesting another translation from userspace.
*/
return false;
}
- size = node->size - addr + node->start;
+ size = map->size - addr + map->start;
if (orig_addr == addr && size >= len)
- vhost_vq_meta_update(vq, node, type);
+ vhost_vq_meta_update(vq, map, type);
s += size;
addr += size;
@@ -1364,12 +1306,12 @@ int vq_meta_prefetch(struct vhost_virtqueue *vq)
if (!vq->iotlb)
return 1;
- return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
+ return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
- iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
+ iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
vhost_get_avail_size(vq, num),
VHOST_ADDR_AVAIL) &&
- iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
+ iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
vhost_get_used_size(vq, num), VHOST_ADDR_USED);
}
EXPORT_SYMBOL_GPL(vq_meta_prefetch);
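
vq_meta_prefetch() switches from the uapi VHOST_ACCESS_* constants to the iotlb-local VHOST_MAP_* ones. The two sets are numerically aligned, which is what keeps the perm checks working when map->perm was filled from msg->perm; their values, as I read the headers (verify against linux/vhost_types.h and linux/vhost_iotlb.h):

    /* uapi constants              generic iotlb constants */
    /* VHOST_ACCESS_RO  0x1  <=>  VHOST_MAP_RO  0x1 */
    /* VHOST_ACCESS_WO  0x2  <=>  VHOST_MAP_WO  0x2 */
    /* VHOST_ACCESS_RW  0x3  <=>  VHOST_MAP_RW  0x3 */
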
@@ -1408,25 +1350,11 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
}
EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
-static struct vhost_umem *vhost_umem_alloc(void)
-{
- struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL);
-
- if (!umem)
- return NULL;
-
- umem->umem_tree = RB_ROOT_CACHED;
- umem->numem = 0;
- INIT_LIST_HEAD(&umem->umem_list);
-
- return umem;
-}
-
static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
struct vhost_memory mem, *newmem;
struct vhost_memory_region *region;
- struct vhost_umem *newumem, *oldumem;
+ struct vhost_iotlb *newumem, *oldumem;
unsigned long size = offsetof(struct vhost_memory, regions);
int i;
@@ -1448,7 +1376,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
return -EFAULT;
}
- newumem = vhost_umem_alloc();
+ newumem = iotlb_alloc();
if (!newumem) {
kvfree(newmem);
return -ENOMEM;
@@ -1457,13 +1385,12 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
for (region = newmem->regions;
region < newmem->regions + mem.nregions;
region++) {
- if (vhost_new_umem_range(newumem,
- region->guest_phys_addr,
- region->memory_size,
- region->guest_phys_addr +
- region->memory_size - 1,
- region->userspace_addr,
- VHOST_ACCESS_RW))
+ if (vhost_iotlb_add_range(newumem,
+ region->guest_phys_addr,
+ region->guest_phys_addr +
+ region->memory_size - 1,
+ region->userspace_addr,
+ VHOST_MAP_RW))
goto err;
}
@@ -1481,11 +1408,11 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
}
kvfree(newmem);
- vhost_umem_clean(oldumem);
+ vhost_iotlb_free(oldumem);
return 0;
err:
- vhost_umem_clean(newumem);
+ vhost_iotlb_free(newumem);
kvfree(newmem);
return -EFAULT;
}
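
With vhost_set_memory() on the generic table, each guest memory region becomes an ordinary read-write mapping keyed by guest physical address. For a concrete feel, with hypothetical numbers:

    /* Hypothetical region: 1 GiB of guest RAM at GPA 0x40000000, mapped
     * at HVA 0x7f0000000000, becomes the inclusive GPA interval
     * [0x40000000, 0x7fffffff] translating to that HVA with VHOST_MAP_RW. */
    vhost_iotlb_add_range(newumem, 0x40000000ULL,
                          0x40000000ULL + 0x40000000ULL - 1,
                          0x7f0000000000ULL, VHOST_MAP_RW);
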
@@ -1726,10 +1653,10 @@ EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
{
- struct vhost_umem *niotlb, *oiotlb;
+ struct vhost_iotlb *niotlb, *oiotlb;
int i;
- niotlb = vhost_umem_alloc();
+ niotlb = iotlb_alloc();
if (!niotlb)
return -ENOMEM;
@@ -1745,7 +1672,7 @@ int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
mutex_unlock(&vq->mutex);
}
- vhost_umem_clean(oiotlb);
+ vhost_iotlb_free(oiotlb);
return 0;
}
@@ -1875,8 +1802,8 @@ static int log_write(void __user *log_base,
static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
{
- struct vhost_umem *umem = vq->umem;
- struct vhost_umem_node *u;
+ struct vhost_iotlb *umem = vq->umem;
+ struct vhost_iotlb_map *u;
u64 start, end, l, min;
int r;
bool hit = false;
@@ -1886,16 +1813,15 @@ static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
/* More than one GPA can be mapped into a single HVA. So
* iterate over all possible umems here to be safe.
*/
- list_for_each_entry(u, &umem->umem_list, link) {
- if (u->userspace_addr > hva - 1 + len ||
- u->userspace_addr - 1 + u->size < hva)
+ list_for_each_entry(u, &umem->list, link) {
+ if (u->addr > hva - 1 + len ||
+ u->addr - 1 + u->size < hva)
continue;
- start = max(u->userspace_addr, hva);
- end = min(u->userspace_addr - 1 + u->size,
- hva - 1 + len);
+ start = max(u->addr, hva);
+ end = min(u->addr - 1 + u->size, hva - 1 + len);
l = end - start + 1;
r = log_write(vq->log_base,
- u->start + start - u->userspace_addr,
+ u->start + start - u->addr,
l);
if (r < 0)
return r;
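
log_write_hva() keeps its linear scan because it translates in the reverse direction (HVA back to GPA for dirty logging) and one HVA can be aliased by several GPA ranges; the interval tree is keyed by GPA, so it cannot help here. The overlap test and clamping from the hunk above, with the overflow-avoiding '- 1 + len' idiom annotated:

    /* Overlap of [hva, hva + len - 1] with [u->addr, u->addr + u->size - 1];
     * written as 'hva - 1 + len' so the sum cannot wrap a u64 at the top end. */
    if (u->addr > hva - 1 + len || u->addr - 1 + u->size < hva)
            continue;                              /* no overlap */
    start = max(u->addr, hva);                     /* clamp to intersection */
    end = min(u->addr - 1 + u->size, hva - 1 + len);
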
@@ -2046,9 +1972,9 @@ EXPORT_SYMBOL_GPL(vhost_vq_init_access);
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size, int access)
{
- const struct vhost_umem_node *node;
+ const struct vhost_iotlb_map *map;
struct vhost_dev *dev = vq->dev;
- struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
+ struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
struct iovec *_iov;
u64 s = 0;
int ret = 0;
@@ -2060,25 +1986,24 @@ static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
break;
}
- node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
- addr, addr + len - 1);
- if (node == NULL || node->start > addr) {
+ map = vhost_iotlb_itree_first(umem, addr, addr + len - 1);
+ if (map == NULL || map->start > addr) {
if (umem != dev->iotlb) {
ret = -EFAULT;
break;
}
ret = -EAGAIN;
break;
- } else if (!(node->perm & access)) {
+ } else if (!(map->perm & access)) {
ret = -EPERM;
break;
}
_iov = iov + ret;
- size = node->size - addr + node->start;
+ size = map->size - addr + map->start;
_iov->iov_len = min((u64)len - s, size);
_iov->iov_base = (void __user *)(unsigned long)
- (node->userspace_addr + addr - node->start);
+ (map->addr + addr - map->start);
s += size;
addr += size;
++ret;
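
The per-iteration size in translate_desc(), map->size - addr + map->start, is simply the number of bytes left in the current mapping starting at addr. A worked example with hypothetical numbers:

    /* Map covers [0x1000, 0x2fff] (size 0x2000); translating from addr 0x1800:
     * size = 0x2000 - 0x1800 + 0x1000 = 0x1800 bytes, i.e. [0x1800, 0x2fff]. */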