diff options
Diffstat (limited to 'drivers/vfio/device_cdev.c')
-rw-r--r-- | drivers/vfio/device_cdev.c | 38 |
1 files changed, 35 insertions, 3 deletions
diff --git a/drivers/vfio/device_cdev.c b/drivers/vfio/device_cdev.c index 281a8dc3ed49..480cac3a0c27 100644 --- a/drivers/vfio/device_cdev.c +++ b/drivers/vfio/device_cdev.c @@ -60,22 +60,50 @@ static void vfio_df_get_kvm_safe(struct vfio_device_file *df) spin_unlock(&df->kvm_ref_lock); } +static int vfio_df_check_token(struct vfio_device *device, + const struct vfio_device_bind_iommufd *bind) +{ + uuid_t uuid; + + if (!device->ops->match_token_uuid) { + if (bind->flags & VFIO_DEVICE_BIND_FLAG_TOKEN) + return -EINVAL; + return 0; + } + + if (!(bind->flags & VFIO_DEVICE_BIND_FLAG_TOKEN)) + return device->ops->match_token_uuid(device, NULL); + + if (copy_from_user(&uuid, u64_to_user_ptr(bind->token_uuid_ptr), + sizeof(uuid))) + return -EFAULT; + return device->ops->match_token_uuid(device, &uuid); +} + long vfio_df_ioctl_bind_iommufd(struct vfio_device_file *df, struct vfio_device_bind_iommufd __user *arg) { + const u32 VALID_FLAGS = VFIO_DEVICE_BIND_FLAG_TOKEN; struct vfio_device *device = df->device; struct vfio_device_bind_iommufd bind; unsigned long minsz; + u32 user_size; int ret; static_assert(__same_type(arg->out_devid, df->devid)); minsz = offsetofend(struct vfio_device_bind_iommufd, out_devid); - if (copy_from_user(&bind, arg, minsz)) - return -EFAULT; + ret = get_user(user_size, &arg->argsz); + if (ret) + return ret; + if (user_size < minsz) + return -EINVAL; + ret = copy_struct_from_user(&bind, minsz, arg, user_size); + if (ret) + return ret; - if (bind.argsz < minsz || bind.flags || bind.iommufd < 0) + if (bind.iommufd < 0 || bind.flags & ~VALID_FLAGS) return -EINVAL; /* BIND_IOMMUFD only allowed for cdev fds */ @@ -93,6 +121,10 @@ long vfio_df_ioctl_bind_iommufd(struct vfio_device_file *df, goto out_unlock; } + ret = vfio_df_check_token(device, &bind); + if (ret) + goto out_unlock; + df->iommufd = iommufd_ctx_from_fd(bind.iommufd); if (IS_ERR(df->iommufd)) { ret = PTR_ERR(df->iommufd); |