diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/Kconfig | 18 | ||||
-rw-r--r-- | drivers/vhost/net.c | 231 | ||||
-rw-r--r-- | drivers/vhost/scsi.c | 214 | ||||
-rw-r--r-- | drivers/vhost/vdpa.c | 10 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 407 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 34 | ||||
-rw-r--r-- | drivers/vhost/vringh.c | 137 | ||||
-rw-r--r-- | drivers/vhost/vsock.c | 15 |
8 files changed, 776 insertions, 290 deletions
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig index 020d4fbb947c..bc0f38574497 100644 --- a/drivers/vhost/Kconfig +++ b/drivers/vhost/Kconfig @@ -95,4 +95,22 @@ config VHOST_CROSS_ENDIAN_LEGACY If unsure, say "N". +config VHOST_ENABLE_FORK_OWNER_CONTROL + bool "Enable VHOST_ENABLE_FORK_OWNER_CONTROL" + default y + help + This option enables two IOCTLs: VHOST_SET_FORK_FROM_OWNER and + VHOST_GET_FORK_FROM_OWNER. These allow userspace applications + to modify the vhost worker mode for vhost devices. + + Also expose module parameter 'fork_from_owner_default' to allow users + to configure the default mode for vhost workers. + + By default, `VHOST_ENABLE_FORK_OWNER_CONTROL` is set to `y`, + users can change the worker thread mode as needed. + If this config is disabled (n),the related IOCTLs and parameters will + be unavailable. + + If unsure, say "Y". + endif diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index b9b9e9d40951..6edac0c1ba9b 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -69,12 +69,15 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;" #define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN) -enum { - VHOST_NET_FEATURES = VHOST_FEATURES | - (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) | - (1ULL << VIRTIO_NET_F_MRG_RXBUF) | - (1ULL << VIRTIO_F_ACCESS_PLATFORM) | - (1ULL << VIRTIO_F_RING_RESET) +static const u64 vhost_net_features[VIRTIO_FEATURES_DWORDS] = { + VHOST_FEATURES | + (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) | + (1ULL << VIRTIO_NET_F_MRG_RXBUF) | + (1ULL << VIRTIO_F_ACCESS_PLATFORM) | + (1ULL << VIRTIO_F_RING_RESET) | + (1ULL << VIRTIO_F_IN_ORDER), + VIRTIO_BIT(VIRTIO_NET_F_GUEST_UDP_TUNNEL_GSO) | + VIRTIO_BIT(VIRTIO_NET_F_HOST_UDP_TUNNEL_GSO), }; enum { @@ -374,7 +377,8 @@ static void vhost_zerocopy_signal_used(struct vhost_net *net, while (j) { add = min(UIO_MAXIOV - nvq->done_idx, j); vhost_add_used_and_signal_n(vq->dev, vq, - &vq->heads[nvq->done_idx], add); + &vq->heads[nvq->done_idx], + NULL, add); nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV; j -= add; } @@ -449,7 +453,8 @@ static int vhost_net_enable_vq(struct vhost_net *n, return vhost_poll_start(poll, sock->file); } -static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq) +static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq, + unsigned int count) { struct vhost_virtqueue *vq = &nvq->vq; struct vhost_dev *dev = vq->dev; @@ -457,7 +462,8 @@ static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq) if (!nvq->done_idx) return; - vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx); + vhost_add_used_and_signal_n(dev, vq, vq->heads, + vq->nheads, count); nvq->done_idx = 0; } @@ -466,6 +472,8 @@ static void vhost_tx_batch(struct vhost_net *net, struct socket *sock, struct msghdr *msghdr) { + struct vhost_virtqueue *vq = &nvq->vq; + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); struct tun_msg_ctl ctl = { .type = TUN_MSG_PTR, .num = nvq->batched_xdp, @@ -473,6 +481,11 @@ static void vhost_tx_batch(struct vhost_net *net, }; int i, err; + if (in_order) { + vq->heads[0].len = 0; + vq->nheads[0] = nvq->done_idx; + } + if (nvq->batched_xdp == 0) goto signal_used; @@ -494,7 +507,7 @@ static void vhost_tx_batch(struct vhost_net *net, } signal_used: - vhost_net_signal_used(nvq); + vhost_net_signal_used(nvq, in_order ? 1 : nvq->done_idx); nvq->batched_xdp = 0; } @@ -668,7 +681,6 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, struct socket *sock = vhost_vq_get_backend(vq); struct virtio_net_hdr *gso; struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp]; - struct tun_xdp_hdr *hdr; size_t len = iov_iter_count(from); int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0; int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); @@ -691,15 +703,13 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, if (unlikely(!buf)) return -ENOMEM; - copied = copy_from_iter(buf + offsetof(struct tun_xdp_hdr, gso), - sock_hlen, from); - if (copied != sock_hlen) { + copied = copy_from_iter(buf + pad - sock_hlen, len, from); + if (copied != len) { ret = -EFAULT; goto err; } - hdr = buf; - gso = &hdr->gso; + gso = buf + pad - sock_hlen; if (!sock_hlen) memset(buf, 0, pad); @@ -718,16 +728,11 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, } } - len -= sock_hlen; - copied = copy_from_iter(buf + pad, len, from); - if (copied != len) { - ret = -EFAULT; - goto err; - } + /* pad contains sock_hlen */ + memcpy(buf, buf + pad - sock_hlen, sock_hlen); xdp_init_buff(xdp, buflen, NULL); - xdp_prepare_buff(xdp, buf, pad, len, true); - hdr->buflen = buflen; + xdp_prepare_buff(xdp, buf, pad, len - sock_hlen, true); ++nvq->batched_xdp; @@ -755,10 +760,11 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) int err; int sent_pkts = 0; bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX); + bool busyloop_intr; + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); do { - bool busyloop_intr = false; - + busyloop_intr = false; if (nvq->done_idx == VHOST_NET_BATCH) vhost_tx_batch(net, nvq, sock, &msg); @@ -769,13 +775,10 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) break; /* Nothing new? Wait for eventfd to tell us they refilled. */ if (head == vq->num) { - if (unlikely(busyloop_intr)) { - vhost_poll_queue(&vq->poll); - } else if (unlikely(vhost_enable_notify(&net->dev, - vq))) { - vhost_disable_notify(&net->dev, vq); - continue; - } + /* Kicks are disabled at this point, break loop and + * process any remaining batched packets. Queue will + * be re-enabled afterwards. + */ break; } @@ -795,11 +798,13 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) break; } - /* We can't build XDP buff, go for single - * packet path but let's flush batched - * packets. - */ - vhost_tx_batch(net, nvq, sock, &msg); + if (nvq->batched_xdp) { + /* We can't build XDP buff, go for single + * packet path but let's flush batched + * packets. + */ + vhost_tx_batch(net, nvq, sock, &msg); + } msg.msg_control = NULL; } else { if (tx_can_batch(vq, total_len)) @@ -820,12 +825,31 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock) pr_debug("Truncated TX packet: len %d != %zd\n", err, len); done: - vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); - vq->heads[nvq->done_idx].len = 0; + if (in_order) { + vq->heads[0].id = cpu_to_vhost32(vq, head); + } else { + vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head); + vq->heads[nvq->done_idx].len = 0; + } ++nvq->done_idx; } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len))); + /* Kicks are still disabled, dispatch any remaining batched msgs. */ vhost_tx_batch(net, nvq, sock, &msg); + + if (unlikely(busyloop_intr)) + /* If interrupted while doing busy polling, requeue the + * handler to be fair handle_rx as well as other tasks + * waiting on cpu. + */ + vhost_poll_queue(&vq->poll); + else + /* All of our work has been completed; however, before + * leaving the TX handler, do one last check for work, + * and requeue handler if necessary. If there is no work, + * queue will be reenabled. + */ + vhost_net_busy_poll_try_queue(net, vq); } static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock) @@ -985,7 +1009,7 @@ static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk) } static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, - bool *busyloop_intr) + bool *busyloop_intr, unsigned int count) { struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX]; struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX]; @@ -995,7 +1019,7 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, if (!len && rvq->busyloop_timeout) { /* Flush batched heads first */ - vhost_net_signal_used(rnvq); + vhost_net_signal_used(rnvq, count); /* Both tx vq and rx socket were polled here */ vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, true); @@ -1007,7 +1031,7 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, /* This is a multi-buffer version of vhost_get_desc, that works if * vq has read descriptors only. - * @vq - the relevant virtqueue + * @nvq - the relevant vhost_net virtqueue * @datalen - data length we'll be reading * @iovcount - returned count of io vectors we fill * @log - vhost log @@ -1015,14 +1039,17 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk, * @quota - headcount quota, 1 for big buffer * returns number of buffer heads allocated, negative on error */ -static int get_rx_bufs(struct vhost_virtqueue *vq, +static int get_rx_bufs(struct vhost_net_virtqueue *nvq, struct vring_used_elem *heads, + u16 *nheads, int datalen, unsigned *iovcount, struct vhost_log *log, unsigned *log_num, unsigned int quota) { + struct vhost_virtqueue *vq = &nvq->vq; + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); unsigned int out, in; int seg = 0; int headcount = 0; @@ -1059,14 +1086,16 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, nlogs += *log_num; log += *log_num; } - heads[headcount].id = cpu_to_vhost32(vq, d); len = iov_length(vq->iov + seg, in); - heads[headcount].len = cpu_to_vhost32(vq, len); - datalen -= len; + if (!in_order) { + heads[headcount].id = cpu_to_vhost32(vq, d); + heads[headcount].len = cpu_to_vhost32(vq, len); + } ++headcount; + datalen -= len; seg += in; } - heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen); + *iovcount = seg; if (unlikely(log)) *log_num = nlogs; @@ -1076,6 +1105,15 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, r = UIO_MAXIOV + 1; goto err; } + + if (!in_order) + heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen); + else { + heads[0].len = cpu_to_vhost32(vq, len + datalen); + heads[0].id = cpu_to_vhost32(vq, d); + nheads[0] = headcount; + } + return headcount; err: vhost_discard_vq_desc(vq, headcount); @@ -1088,6 +1126,8 @@ static void handle_rx(struct vhost_net *net) { struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX]; struct vhost_virtqueue *vq = &nvq->vq; + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); + unsigned int count = 0; unsigned in, log; struct vhost_log *vq_log; struct msghdr msg = { @@ -1135,12 +1175,13 @@ static void handle_rx(struct vhost_net *net) do { sock_len = vhost_net_rx_peek_head_len(net, sock->sk, - &busyloop_intr); + &busyloop_intr, count); if (!sock_len) break; sock_len += sock_hlen; vhost_len = sock_len + vhost_hlen; - headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx, + headcount = get_rx_bufs(nvq, vq->heads + count, + vq->nheads + count, vhost_len, &in, vq_log, &log, likely(mergeable) ? UIO_MAXIOV : 1); /* On error, stop handling until the next kick. */ @@ -1216,8 +1257,11 @@ static void handle_rx(struct vhost_net *net) goto out; } nvq->done_idx += headcount; - if (nvq->done_idx > VHOST_NET_BATCH) - vhost_net_signal_used(nvq); + count += in_order ? 1 : headcount; + if (nvq->done_idx > VHOST_NET_BATCH) { + vhost_net_signal_used(nvq, count); + count = 0; + } if (unlikely(vq_log)) vhost_log_write(vq, vq_log, log, vhost_len, vq->iov, in); @@ -1229,7 +1273,7 @@ static void handle_rx(struct vhost_net *net) else if (!sock_len) vhost_net_enable_vq(net, vq); out: - vhost_net_signal_used(nvq); + vhost_net_signal_used(nvq, count); mutex_unlock(&vq->mutex); } @@ -1602,16 +1646,23 @@ done: return err; } -static int vhost_net_set_features(struct vhost_net *n, u64 features) +static int vhost_net_set_features(struct vhost_net *n, const u64 *features) { size_t vhost_hlen, sock_hlen, hdr_len; int i; - hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) | - (1ULL << VIRTIO_F_VERSION_1))) ? - sizeof(struct virtio_net_hdr_mrg_rxbuf) : - sizeof(struct virtio_net_hdr); - if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) { + hdr_len = virtio_features_test_bit(features, VIRTIO_NET_F_MRG_RXBUF) || + virtio_features_test_bit(features, VIRTIO_F_VERSION_1) ? + sizeof(struct virtio_net_hdr_mrg_rxbuf) : + sizeof(struct virtio_net_hdr); + + if (virtio_features_test_bit(features, + VIRTIO_NET_F_HOST_UDP_TUNNEL_GSO) || + virtio_features_test_bit(features, + VIRTIO_NET_F_GUEST_UDP_TUNNEL_GSO)) + hdr_len = sizeof(struct virtio_net_hdr_v1_hash_tunnel); + + if (virtio_features_test_bit(features, VHOST_NET_F_VIRTIO_NET_HDR)) { /* vhost provides vnet_hdr */ vhost_hlen = hdr_len; sock_hlen = 0; @@ -1621,18 +1672,19 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features) sock_hlen = hdr_len; } mutex_lock(&n->dev.mutex); - if ((features & (1 << VHOST_F_LOG_ALL)) && + if (virtio_features_test_bit(features, VHOST_F_LOG_ALL) && !vhost_log_access_ok(&n->dev)) goto out_unlock; - if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) { + if (virtio_features_test_bit(features, VIRTIO_F_ACCESS_PLATFORM)) { if (vhost_init_device_iotlb(&n->dev)) goto out_unlock; } for (i = 0; i < VHOST_NET_VQ_MAX; ++i) { mutex_lock(&n->vqs[i].vq.mutex); - n->vqs[i].vq.acked_features = features; + virtio_features_copy(n->vqs[i].vq.acked_features_array, + features); n->vqs[i].vhost_hlen = vhost_hlen; n->vqs[i].sock_hlen = sock_hlen; mutex_unlock(&n->vqs[i].vq.mutex); @@ -1669,12 +1721,13 @@ out: static long vhost_net_ioctl(struct file *f, unsigned int ioctl, unsigned long arg) { + u64 all_features[VIRTIO_FEATURES_DWORDS]; struct vhost_net *n = f->private_data; void __user *argp = (void __user *)arg; u64 __user *featurep = argp; struct vhost_vring_file backend; - u64 features; - int r; + u64 features, count, copied; + int r, i; switch (ioctl) { case VHOST_NET_SET_BACKEND: @@ -1682,16 +1735,60 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl, return -EFAULT; return vhost_net_set_backend(n, backend.index, backend.fd); case VHOST_GET_FEATURES: - features = VHOST_NET_FEATURES; + features = vhost_net_features[0]; if (copy_to_user(featurep, &features, sizeof features)) return -EFAULT; return 0; case VHOST_SET_FEATURES: if (copy_from_user(&features, featurep, sizeof features)) return -EFAULT; - if (features & ~VHOST_NET_FEATURES) + if (features & ~vhost_net_features[0]) return -EOPNOTSUPP; - return vhost_net_set_features(n, features); + + virtio_features_from_u64(all_features, features); + return vhost_net_set_features(n, all_features); + case VHOST_GET_FEATURES_ARRAY: + if (copy_from_user(&count, featurep, sizeof(count))) + return -EFAULT; + + /* Copy the net features, up to the user-provided buffer size */ + argp += sizeof(u64); + copied = min(count, VIRTIO_FEATURES_DWORDS); + if (copy_to_user(argp, vhost_net_features, + copied * sizeof(u64))) + return -EFAULT; + + /* Zero the trailing space provided by user-space, if any */ + if (clear_user(argp, size_mul(count - copied, sizeof(u64)))) + return -EFAULT; + return 0; + case VHOST_SET_FEATURES_ARRAY: + if (copy_from_user(&count, featurep, sizeof(count))) + return -EFAULT; + + virtio_features_zero(all_features); + argp += sizeof(u64); + copied = min(count, VIRTIO_FEATURES_DWORDS); + if (copy_from_user(all_features, argp, copied * sizeof(u64))) + return -EFAULT; + + /* + * Any feature specified by user-space above + * VIRTIO_FEATURES_MAX is not supported by definition. + */ + for (i = copied; i < count; ++i) { + if (copy_from_user(&features, featurep + 1 + i, + sizeof(features))) + return -EFAULT; + if (features) + return -EOPNOTSUPP; + } + + for (i = 0; i < VIRTIO_FEATURES_DWORDS; i++) + if (all_features[i] & ~vhost_net_features[i]) + return -EOPNOTSUPP; + + return vhost_net_set_features(n, all_features); case VHOST_GET_BACKEND_FEATURES: features = VHOST_NET_BACKEND_FEATURES; if (copy_to_user(featurep, &features, sizeof(features))) diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c index 26bcf3a7f70c..abf51332a5c5 100644 --- a/drivers/vhost/scsi.c +++ b/drivers/vhost/scsi.c @@ -71,7 +71,7 @@ static int vhost_scsi_set_inline_sg_cnt(const char *buf, if (ret) return ret; - if (ret > VHOST_SCSI_PREALLOC_SGLS) { + if (cnt > VHOST_SCSI_PREALLOC_SGLS) { pr_err("Max inline_sg_cnt is %u\n", VHOST_SCSI_PREALLOC_SGLS); return -EINVAL; } @@ -133,6 +133,11 @@ struct vhost_scsi_cmd { struct se_cmd tvc_se_cmd; /* Sense buffer that will be mapped into outgoing status */ unsigned char tvc_sense_buf[TRANSPORT_SENSE_BUFFER]; + /* + * Dirty write descriptors of this command. + */ + struct vhost_log *tvc_log; + unsigned int tvc_log_num; /* Completed commands list, serviced from vhost worker thread */ struct llist_node tvc_completion_list; /* Used to track inflight cmd */ @@ -147,7 +152,7 @@ struct vhost_scsi_nexus { struct vhost_scsi_tpg { /* Vhost port target portal group tag for TCM */ u16 tport_tpgt; - /* Used to track number of TPG Port/Lun Links wrt to explict I_T Nexus shutdown */ + /* Used to track number of TPG Port/Lun Links wrt to explicit I_T Nexus shutdown */ int tv_tpg_port_count; /* Used for vhost_scsi device reference to tpg_nexus, protected by tv_tpg_mutex */ int tv_tpg_vhost_count; @@ -258,6 +263,12 @@ struct vhost_scsi_tmf { struct iovec resp_iov; int in_iovs; int vq_desc; + + /* + * Dirty write descriptors of this command. + */ + struct vhost_log *tmf_log; + unsigned int tmf_log_num; }; /* @@ -300,12 +311,12 @@ static void vhost_scsi_init_inflight(struct vhost_scsi *vs, mutex_lock(&vq->mutex); - /* store old infight */ + /* store old inflight */ idx = vs->vqs[i].inflight_idx; if (old_inflight) old_inflight[i] = &vs->vqs[i].inflights[idx]; - /* setup new infight */ + /* setup new inflight */ vs->vqs[i].inflight_idx = idx ^ 1; new_inflight = &vs->vqs[i].inflights[idx ^ 1]; kref_init(&new_inflight->kref); @@ -362,6 +373,45 @@ static int vhost_scsi_check_prot_fabric_only(struct se_portal_group *se_tpg) return tpg->tv_fabric_prot_type; } +static int vhost_scsi_copy_cmd_log(struct vhost_virtqueue *vq, + struct vhost_scsi_cmd *cmd, + struct vhost_log *log, + unsigned int log_num) +{ + if (!cmd->tvc_log) + cmd->tvc_log = kmalloc_array(vq->dev->iov_limit, + sizeof(*cmd->tvc_log), + GFP_KERNEL); + + if (unlikely(!cmd->tvc_log)) { + vq_err(vq, "Failed to alloc tvc_log\n"); + return -ENOMEM; + } + + memcpy(cmd->tvc_log, log, sizeof(*cmd->tvc_log) * log_num); + cmd->tvc_log_num = log_num; + + return 0; +} + +static void vhost_scsi_log_write(struct vhost_virtqueue *vq, + struct vhost_log *log, + unsigned int log_num) +{ + if (likely(!vhost_has_feature(vq, VHOST_F_LOG_ALL))) + return; + + if (likely(!log_num || !log)) + return; + + /* + * vhost-scsi doesn't support VIRTIO_F_ACCESS_PLATFORM. + * No requirement for vq->iotlb case. + */ + WARN_ON_ONCE(unlikely(vq->iotlb)); + vhost_log_write(vq, log, log_num, U64_MAX, NULL, 0); +} + static void vhost_scsi_release_cmd_res(struct se_cmd *se_cmd) { struct vhost_scsi_cmd *tv_cmd = container_of(se_cmd, @@ -408,6 +458,10 @@ static void vhost_scsi_release_tmf_res(struct vhost_scsi_tmf *tmf) { struct vhost_scsi_inflight *inflight = tmf->inflight; + /* + * tmf->tmf_log is default NULL unless VHOST_F_LOG_ALL is set. + */ + kfree(tmf->tmf_log); kfree(tmf); vhost_scsi_put_inflight(inflight); } @@ -517,6 +571,8 @@ vhost_scsi_do_evt_work(struct vhost_scsi *vs, struct vhost_scsi_evt *evt) struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq; struct virtio_scsi_event *event = &evt->event; struct virtio_scsi_event __user *eventp; + struct vhost_log *vq_log; + unsigned int log_num; unsigned out, in; int head, ret; @@ -527,9 +583,19 @@ vhost_scsi_do_evt_work(struct vhost_scsi *vs, struct vhost_scsi_evt *evt) again: vhost_disable_notify(&vs->dev, vq); + + vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ? + vq->log : NULL; + + /* + * Reset 'log_num' since vhost_get_vq_desc() may reset it only + * after certain condition checks. + */ + log_num = 0; + head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), &out, &in, - NULL, NULL); + vq_log, &log_num); if (head < 0) { vs->vs_events_missed = true; return; @@ -559,6 +625,8 @@ again: vhost_add_used_and_signal(&vs->dev, vq, head, 0); else vq_err(vq, "Faulted on vhost_scsi_send_event\n"); + + vhost_scsi_log_write(vq, vq_log, log_num); } static void vhost_scsi_complete_events(struct vhost_scsi *vs, bool drop) @@ -660,6 +728,9 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work) } else pr_err("Faulted on virtio_scsi_cmd_resp\n"); + vhost_scsi_log_write(cmd->tvc_vq, cmd->tvc_log, + cmd->tvc_log_num); + vhost_scsi_release_cmd_res(se_cmd); } @@ -676,6 +747,7 @@ vhost_scsi_get_cmd(struct vhost_virtqueue *vq, u64 scsi_tag) struct vhost_scsi_virtqueue, vq); struct vhost_scsi_cmd *cmd; struct scatterlist *sgl, *prot_sgl; + struct vhost_log *log; int tag; tag = sbitmap_get(&svq->scsi_tags); @@ -687,9 +759,11 @@ vhost_scsi_get_cmd(struct vhost_virtqueue *vq, u64 scsi_tag) cmd = &svq->scsi_cmds[tag]; sgl = cmd->sgl; prot_sgl = cmd->prot_sgl; + log = cmd->tvc_log; memset(cmd, 0, sizeof(*cmd)); cmd->sgl = sgl; cmd->prot_sgl = prot_sgl; + cmd->tvc_log = log; cmd->tvc_se_cmd.map_tag = tag; cmd->inflight = vhost_scsi_get_inflight(vq); @@ -1063,13 +1137,17 @@ vhost_scsi_send_bad_target(struct vhost_scsi *vs, static int vhost_scsi_get_desc(struct vhost_scsi *vs, struct vhost_virtqueue *vq, - struct vhost_scsi_ctx *vc) + struct vhost_scsi_ctx *vc, + struct vhost_log *log, unsigned int *log_num) { int ret = -ENXIO; + if (likely(log_num)) + *log_num = 0; + vc->head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), &vc->out, &vc->in, - NULL, NULL); + log, log_num); pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n", vc->head, vc->out, vc->in); @@ -1148,10 +1226,8 @@ vhost_scsi_get_req(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc, /* validated at handler entry */ vs_tpg = vhost_vq_get_backend(vq); tpg = READ_ONCE(vs_tpg[*vc->target]); - if (unlikely(!tpg)) { - vq_err(vq, "Target 0x%x does not exist\n", *vc->target); + if (unlikely(!tpg)) goto out; - } } if (tpgp) @@ -1171,7 +1247,7 @@ vhost_scsi_setup_resp_iovs(struct vhost_scsi_cmd *cmd, struct iovec *in_iovs, if (!in_iovs_cnt) return 0; /* - * Initiator's normally just put the virtio_scsi_cmd_resp in the first + * Initiators normally just put the virtio_scsi_cmd_resp in the first * iov, but just in case they wedged in some data with it we check for * greater than or equal to the response struct. */ @@ -1221,6 +1297,8 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) u8 task_attr; bool t10_pi = vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI); u8 *cdb; + struct vhost_log *vq_log; + unsigned int log_num; mutex_lock(&vq->mutex); /* @@ -1236,8 +1314,11 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) vhost_disable_notify(&vs->dev, vq); + vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ? + vq->log : NULL; + do { - ret = vhost_scsi_get_desc(vs, vq, &vc); + ret = vhost_scsi_get_desc(vs, vq, &vc, vq_log, &log_num); if (ret) goto err; @@ -1374,7 +1455,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) cmd = vhost_scsi_get_cmd(vq, tag); if (IS_ERR(cmd)) { ret = PTR_ERR(cmd); - vq_err(vq, "vhost_scsi_get_tag failed %dd\n", ret); + vq_err(vq, "vhost_scsi_get_tag failed %d\n", ret); goto err; } cmd->tvc_vq = vq; @@ -1386,6 +1467,14 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) goto err; } + if (unlikely(vq_log && log_num)) { + ret = vhost_scsi_copy_cmd_log(vq, cmd, vq_log, log_num); + if (unlikely(ret)) { + vhost_scsi_release_cmd_res(&cmd->tvc_se_cmd); + goto err; + } + } + pr_debug("vhost_scsi got command opcode: %#02x, lun: %d\n", cdb[0], lun); pr_debug("cmd: %p exp_data_len: %d, prot_bytes: %d data_direction:" @@ -1421,11 +1510,14 @@ err: */ if (ret == -ENXIO) break; - else if (ret == -EIO) + else if (ret == -EIO) { vhost_scsi_send_bad_target(vs, vq, &vc, TYPE_IO_CMD); - else if (ret == -ENOMEM) + vhost_scsi_log_write(vq, vq_log, log_num); + } else if (ret == -ENOMEM) { vhost_scsi_send_status(vs, vq, &vc, SAM_STAT_TASK_SET_FULL); + vhost_scsi_log_write(vq, vq_log, log_num); + } } while (likely(!vhost_exceeds_weight(vq, ++c, 0))); out: mutex_unlock(&vq->mutex); @@ -1467,6 +1559,8 @@ static void vhost_scsi_tmf_resp_work(struct vhost_work *work) mutex_lock(&tmf->svq->vq.mutex); vhost_scsi_send_tmf_resp(tmf->vhost, &tmf->svq->vq, tmf->in_iovs, tmf->vq_desc, &tmf->resp_iov, resp_code); + vhost_scsi_log_write(&tmf->svq->vq, tmf->tmf_log, + tmf->tmf_log_num); mutex_unlock(&tmf->svq->vq.mutex); vhost_scsi_release_tmf_res(tmf); @@ -1490,7 +1584,8 @@ static void vhost_scsi_handle_tmf(struct vhost_scsi *vs, struct vhost_scsi_tpg *tpg, struct vhost_virtqueue *vq, struct virtio_scsi_ctrl_tmf_req *vtmf, - struct vhost_scsi_ctx *vc) + struct vhost_scsi_ctx *vc, + struct vhost_log *log, unsigned int log_num) { struct vhost_scsi_virtqueue *svq = container_of(vq, struct vhost_scsi_virtqueue, vq); @@ -1518,6 +1613,19 @@ vhost_scsi_handle_tmf(struct vhost_scsi *vs, struct vhost_scsi_tpg *tpg, tmf->in_iovs = vc->in; tmf->inflight = vhost_scsi_get_inflight(vq); + if (unlikely(log && log_num)) { + tmf->tmf_log = kmalloc_array(log_num, sizeof(*tmf->tmf_log), + GFP_KERNEL); + if (tmf->tmf_log) { + memcpy(tmf->tmf_log, log, sizeof(*tmf->tmf_log) * log_num); + tmf->tmf_log_num = log_num; + } else { + pr_err("vhost_scsi tmf log allocation error\n"); + vhost_scsi_release_tmf_res(tmf); + goto send_reject; + } + } + if (target_submit_tmr(&tmf->se_cmd, tpg->tpg_nexus->tvn_se_sess, NULL, vhost_buf_to_lun(vtmf->lun), NULL, TMR_LUN_RESET, GFP_KERNEL, 0, @@ -1531,6 +1639,7 @@ vhost_scsi_handle_tmf(struct vhost_scsi *vs, struct vhost_scsi_tpg *tpg, send_reject: vhost_scsi_send_tmf_resp(vs, vq, vc->in, vc->head, &vq->iov[vc->out], VIRTIO_SCSI_S_FUNCTION_REJECTED); + vhost_scsi_log_write(vq, log, log_num); } static void @@ -1567,6 +1676,8 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) struct vhost_scsi_ctx vc; size_t typ_size; int ret, c = 0; + struct vhost_log *vq_log; + unsigned int log_num; mutex_lock(&vq->mutex); /* @@ -1580,8 +1691,11 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) vhost_disable_notify(&vs->dev, vq); + vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ? + vq->log : NULL; + do { - ret = vhost_scsi_get_desc(vs, vq, &vc); + ret = vhost_scsi_get_desc(vs, vq, &vc, vq_log, &log_num); if (ret) goto err; @@ -1645,9 +1759,12 @@ vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) goto err; if (v_req.type == VIRTIO_SCSI_T_TMF) - vhost_scsi_handle_tmf(vs, tpg, vq, &v_req.tmf, &vc); - else + vhost_scsi_handle_tmf(vs, tpg, vq, &v_req.tmf, &vc, + vq_log, log_num); + else { vhost_scsi_send_an_resp(vs, vq, &vc); + vhost_scsi_log_write(vq, vq_log, log_num); + } err: /* * ENXIO: No more requests, or read error, wait for next kick @@ -1657,11 +1774,13 @@ err: */ if (ret == -ENXIO) break; - else if (ret == -EIO) + else if (ret == -EIO) { vhost_scsi_send_bad_target(vs, vq, &vc, v_req.type == VIRTIO_SCSI_T_TMF ? TYPE_CTRL_TMF : TYPE_CTRL_AN); + vhost_scsi_log_write(vq, vq_log, log_num); + } } while (likely(!vhost_exceeds_weight(vq, ++c, 0))); out: mutex_unlock(&vq->mutex); @@ -1756,6 +1875,24 @@ static void vhost_scsi_flush(struct vhost_scsi *vs) wait_for_completion(&vs->old_inflight[i]->comp); } +static void vhost_scsi_destroy_vq_log(struct vhost_virtqueue *vq) +{ + struct vhost_scsi_virtqueue *svq = container_of(vq, + struct vhost_scsi_virtqueue, vq); + struct vhost_scsi_cmd *tv_cmd; + unsigned int i; + + if (!svq->scsi_cmds) + return; + + for (i = 0; i < svq->max_cmds; i++) { + tv_cmd = &svq->scsi_cmds[i]; + kfree(tv_cmd->tvc_log); + tv_cmd->tvc_log = NULL; + tv_cmd->tvc_log_num = 0; + } +} + static void vhost_scsi_destroy_vq_cmds(struct vhost_virtqueue *vq) { struct vhost_scsi_virtqueue *svq = container_of(vq, @@ -1775,6 +1912,7 @@ static void vhost_scsi_destroy_vq_cmds(struct vhost_virtqueue *vq) sbitmap_free(&svq->scsi_tags); kfree(svq->upages); + vhost_scsi_destroy_vq_log(vq); kfree(svq->scsi_cmds); svq->scsi_cmds = NULL; } @@ -2084,6 +2222,7 @@ err_dev: static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features) { struct vhost_virtqueue *vq; + bool is_log, was_log; int i; if (features & ~VHOST_SCSI_FEATURES) @@ -2096,12 +2235,39 @@ static int vhost_scsi_set_features(struct vhost_scsi *vs, u64 features) return -EFAULT; } + if (!vs->dev.nvqs) + goto out; + + is_log = features & (1 << VHOST_F_LOG_ALL); + /* + * All VQs should have same feature. + */ + was_log = vhost_has_feature(&vs->vqs[0].vq, VHOST_F_LOG_ALL); + for (i = 0; i < vs->dev.nvqs; i++) { vq = &vs->vqs[i].vq; mutex_lock(&vq->mutex); vq->acked_features = features; mutex_unlock(&vq->mutex); } + + /* + * If VHOST_F_LOG_ALL is removed, free tvc_log after + * vq->acked_features is committed. + */ + if (!is_log && was_log) { + for (i = VHOST_SCSI_VQ_IO; i < vs->dev.nvqs; i++) { + if (!vs->vqs[i].scsi_cmds) + continue; + + vq = &vs->vqs[i].vq; + mutex_lock(&vq->mutex); + vhost_scsi_destroy_vq_log(vq); + mutex_unlock(&vq->mutex); + } + } + +out: mutex_unlock(&vs->dev.mutex); return 0; } @@ -2441,7 +2607,7 @@ static int vhost_scsi_make_nexus(struct vhost_scsi_tpg *tpg, return -ENOMEM; } /* - * Since we are running in 'demo mode' this call with generate a + * Since we are running in 'demo mode' this call will generate a * struct se_node_acl for the vhost_scsi struct se_portal_group with * the SCSI Initiator port name of the passed configfs group 'name'. */ @@ -2747,7 +2913,7 @@ static ssize_t vhost_scsi_wwn_version_show(struct config_item *item, char *page) { return sysfs_emit(page, "TCM_VHOST fabric module %s on %s/%s" - "on "UTS_RELEASE"\n", VHOST_SCSI_VERSION, utsname()->sysname, + " on "UTS_RELEASE"\n", VHOST_SCSI_VERSION, utsname()->sysname, utsname()->machine); } @@ -2815,13 +2981,13 @@ out_vhost_scsi_deregister: vhost_scsi_deregister(); out: return ret; -}; +} static void vhost_scsi_exit(void) { target_unregister_template(&vhost_scsi_ops); vhost_scsi_deregister(); -}; +} MODULE_DESCRIPTION("VHOST_SCSI series fabric driver"); MODULE_ALIAS("tcm_vhost"); diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c index 5a49b5a6d496..af1e1fdfd9ed 100644 --- a/drivers/vhost/vdpa.c +++ b/drivers/vhost/vdpa.c @@ -212,11 +212,11 @@ static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid) if (!vq->call_ctx.ctx) return; - vq->call_ctx.producer.irq = irq; - ret = irq_bypass_register_producer(&vq->call_ctx.producer); + ret = irq_bypass_register_producer(&vq->call_ctx.producer, + vq->call_ctx.ctx, irq); if (unlikely(ret)) - dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret = %d\n", - qid, vq->call_ctx.producer.token, ret); + dev_info(&v->dev, "vq %u, irq bypass producer (eventfd %p) registration fails, ret = %d\n", + qid, vq->call_ctx.ctx, ret); } static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid) @@ -712,7 +712,6 @@ static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd, if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_DRIVER_OK) vhost_vdpa_unsetup_vq_irq(v, idx); - vq->call_ctx.producer.token = NULL; } break; } @@ -753,7 +752,6 @@ static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd, cb.callback = vhost_vdpa_virtqueue_cb; cb.private = vq; cb.trigger = vq->call_ctx.ctx; - vq->call_ctx.producer.token = vq->call_ctx.ctx; if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_DRIVER_OK) vhost_vdpa_setup_vq_irq(v, idx); diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 63612faeab72..23286e4d7b49 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -22,6 +22,7 @@ #include <linux/slab.h> #include <linux/vmalloc.h> #include <linux/kthread.h> +#include <linux/cgroup.h> #include <linux/module.h> #include <linux/sort.h> #include <linux/sched/mm.h> @@ -41,6 +42,13 @@ static int max_iotlb_entries = 2048; module_param(max_iotlb_entries, int, 0444); MODULE_PARM_DESC(max_iotlb_entries, "Maximum number of iotlb entries. (default: 2048)"); +static bool fork_from_owner_default = VHOST_FORK_OWNER_TASK; + +#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL +module_param(fork_from_owner_default, bool, 0444); +MODULE_PARM_DESC(fork_from_owner_default, + "Set task mode as the default(default: Y)"); +#endif enum { VHOST_MEMORY_F_LOG = 0x1, @@ -242,7 +250,7 @@ static void vhost_worker_queue(struct vhost_worker *worker, * test_and_set_bit() implies a memory barrier. */ llist_add(&work->node, &worker->work_list); - vhost_task_wake(worker->vtsk); + worker->ops->wakeup(worker); } } @@ -364,6 +372,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->avail = NULL; vq->used = NULL; vq->last_avail_idx = 0; + vq->next_avail_head = 0; vq->avail_idx = 0; vq->last_used_idx = 0; vq->signalled_used = 0; @@ -372,7 +381,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->log_used = false; vq->log_addr = -1ull; vq->private_data = NULL; - vq->acked_features = 0; + virtio_features_zero(vq->acked_features_array); vq->acked_backend_features = 0; vq->log_base = NULL; vq->error_ctx = NULL; @@ -388,6 +397,44 @@ static void vhost_vq_reset(struct vhost_dev *dev, __vhost_vq_meta_reset(vq); } +static int vhost_run_work_kthread_list(void *data) +{ + struct vhost_worker *worker = data; + struct vhost_work *work, *work_next; + struct vhost_dev *dev = worker->dev; + struct llist_node *node; + + kthread_use_mm(dev->mm); + + for (;;) { + /* mb paired w/ kthread_stop */ + set_current_state(TASK_INTERRUPTIBLE); + + if (kthread_should_stop()) { + __set_current_state(TASK_RUNNING); + break; + } + node = llist_del_all(&worker->work_list); + if (!node) + schedule(); + + node = llist_reverse_order(node); + /* make sure flag is seen after deletion */ + smp_wmb(); + llist_for_each_entry_safe(work, work_next, node, node) { + clear_bit(VHOST_WORK_QUEUED, &work->flags); + __set_current_state(TASK_RUNNING); + kcov_remote_start_common(worker->kcov_handle); + work->fn(work); + kcov_remote_stop(); + cond_resched(); + } + } + kthread_unuse_mm(dev->mm); + + return 0; +} + static bool vhost_run_work_list(void *data) { struct vhost_worker *worker = data; @@ -455,6 +502,8 @@ static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) vq->log = NULL; kfree(vq->heads); vq->heads = NULL; + kfree(vq->nheads); + vq->nheads = NULL; } /* Helper to allocate iovec buffers for all vqs. */ @@ -472,7 +521,9 @@ static long vhost_dev_alloc_iovecs(struct vhost_dev *dev) GFP_KERNEL); vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads), GFP_KERNEL); - if (!vq->indirect || !vq->log || !vq->heads) + vq->nheads = kmalloc_array(dev->iov_limit, sizeof(*vq->nheads), + GFP_KERNEL); + if (!vq->indirect || !vq->log || !vq->heads || !vq->nheads) goto err_nomem; } return 0; @@ -552,6 +603,7 @@ void vhost_dev_init(struct vhost_dev *dev, dev->byte_weight = byte_weight; dev->use_worker = use_worker; dev->msg_handler = msg_handler; + dev->fork_owner = fork_from_owner_default; init_waitqueue_head(&dev->wait); INIT_LIST_HEAD(&dev->read_list); INIT_LIST_HEAD(&dev->pending_list); @@ -581,6 +633,46 @@ long vhost_dev_check_owner(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_dev_check_owner); +struct vhost_attach_cgroups_struct { + struct vhost_work work; + struct task_struct *owner; + int ret; +}; + +static void vhost_attach_cgroups_work(struct vhost_work *work) +{ + struct vhost_attach_cgroups_struct *s; + + s = container_of(work, struct vhost_attach_cgroups_struct, work); + s->ret = cgroup_attach_task_all(s->owner, current); +} + +static int vhost_attach_task_to_cgroups(struct vhost_worker *worker) +{ + struct vhost_attach_cgroups_struct attach; + int saved_cnt; + + attach.owner = current; + + vhost_work_init(&attach.work, vhost_attach_cgroups_work); + vhost_worker_queue(worker, &attach.work); + + mutex_lock(&worker->mutex); + + /* + * Bypass attachment_cnt check in __vhost_worker_flush: + * Temporarily change it to INT_MAX to bypass the check + */ + saved_cnt = worker->attachment_cnt; + worker->attachment_cnt = INT_MAX; + __vhost_worker_flush(worker); + worker->attachment_cnt = saved_cnt; + + mutex_unlock(&worker->mutex); + + return attach.ret; +} + /* Caller should have device mutex */ bool vhost_dev_has_owner(struct vhost_dev *dev) { @@ -594,10 +686,10 @@ static void vhost_attach_mm(struct vhost_dev *dev) if (dev->use_worker) { dev->mm = get_task_mm(current); } else { - /* vDPA device does not use worker thead, so there's - * no need to hold the address space for mm. This help + /* vDPA device does not use worker thread, so there's + * no need to hold the address space for mm. This helps * to avoid deadlock in the case of mmap() which may - * held the refcnt of the file and depends on release + * hold the refcnt of the file and depends on release * method to remove vma. */ dev->mm = current->mm; @@ -626,7 +718,7 @@ static void vhost_worker_destroy(struct vhost_dev *dev, WARN_ON(!llist_empty(&worker->work_list)); xa_erase(&dev->worker_xa, worker->id); - vhost_task_stop(worker->vtsk); + worker->ops->stop(worker); kfree(worker); } @@ -649,42 +741,115 @@ static void vhost_workers_free(struct vhost_dev *dev) xa_destroy(&dev->worker_xa); } +static void vhost_task_wakeup(struct vhost_worker *worker) +{ + return vhost_task_wake(worker->vtsk); +} + +static void vhost_kthread_wakeup(struct vhost_worker *worker) +{ + wake_up_process(worker->kthread_task); +} + +static void vhost_task_do_stop(struct vhost_worker *worker) +{ + return vhost_task_stop(worker->vtsk); +} + +static void vhost_kthread_do_stop(struct vhost_worker *worker) +{ + kthread_stop(worker->kthread_task); +} + +static int vhost_task_worker_create(struct vhost_worker *worker, + struct vhost_dev *dev, const char *name) +{ + struct vhost_task *vtsk; + u32 id; + int ret; + + vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed, + worker, name); + if (IS_ERR(vtsk)) + return PTR_ERR(vtsk); + + worker->vtsk = vtsk; + vhost_task_start(vtsk); + ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL); + if (ret < 0) { + vhost_task_do_stop(worker); + return ret; + } + worker->id = id; + return 0; +} + +static int vhost_kthread_worker_create(struct vhost_worker *worker, + struct vhost_dev *dev, const char *name) +{ + struct task_struct *task; + u32 id; + int ret; + + task = kthread_create(vhost_run_work_kthread_list, worker, "%s", name); + if (IS_ERR(task)) + return PTR_ERR(task); + + worker->kthread_task = task; + wake_up_process(task); + ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL); + if (ret < 0) + goto stop_worker; + + ret = vhost_attach_task_to_cgroups(worker); + if (ret) + goto stop_worker; + + worker->id = id; + return 0; + +stop_worker: + vhost_kthread_do_stop(worker); + return ret; +} + +static const struct vhost_worker_ops kthread_ops = { + .create = vhost_kthread_worker_create, + .stop = vhost_kthread_do_stop, + .wakeup = vhost_kthread_wakeup, +}; + +static const struct vhost_worker_ops vhost_task_ops = { + .create = vhost_task_worker_create, + .stop = vhost_task_do_stop, + .wakeup = vhost_task_wakeup, +}; + static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev) { struct vhost_worker *worker; - struct vhost_task *vtsk; char name[TASK_COMM_LEN]; int ret; - u32 id; + const struct vhost_worker_ops *ops = dev->fork_owner ? &vhost_task_ops : + &kthread_ops; worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT); if (!worker) return NULL; worker->dev = dev; + worker->ops = ops; snprintf(name, sizeof(name), "vhost-%d", current->pid); - vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed, - worker, name); - if (IS_ERR(vtsk)) - goto free_worker; - mutex_init(&worker->mutex); init_llist_head(&worker->work_list); worker->kcov_handle = kcov_common_handle(); - worker->vtsk = vtsk; - - vhost_task_start(vtsk); - - ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL); + ret = ops->create(worker, dev, name); if (ret < 0) - goto stop_worker; - worker->id = id; + goto free_worker; return worker; -stop_worker: - vhost_task_stop(vtsk); free_worker: kfree(worker); return NULL; @@ -731,7 +896,7 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq, * We don't want to call synchronize_rcu for every vq during setup * because it will slow down VM startup. If we haven't done * VHOST_SET_VRING_KICK and not done the driver specific - * SET_ENDPOINT/RUNNUNG then we can skip the sync since there will + * SET_ENDPOINT/RUNNING then we can skip the sync since there will * not be any works queued for scsi and net. */ mutex_lock(&vq->mutex); @@ -865,6 +1030,14 @@ long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl, switch (ioctl) { /* dev worker ioctls */ case VHOST_NEW_WORKER: + /* + * vhost_tasks will account for worker threads under the parent's + * NPROC value but kthreads do not. To avoid userspace overflowing + * the system with worker threads fork_owner must be true. + */ + if (!dev->fork_owner) + return -EFAULT; + ret = vhost_new_worker(dev, &state); if (!ret && copy_to_user(argp, &state, sizeof(state))) ret = -EFAULT; @@ -982,6 +1155,7 @@ void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem) vhost_dev_cleanup(dev); + dev->fork_owner = fork_from_owner_default; dev->umem = umem; /* We don't need VQ locks below since vhost_dev_cleanup makes sure * VQs aren't running. @@ -1990,14 +2164,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg break; } if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) { - vq->last_avail_idx = s.num & 0xffff; + vq->next_avail_head = vq->last_avail_idx = + s.num & 0xffff; vq->last_used_idx = (s.num >> 16) & 0xffff; } else { if (s.num > 0xffff) { r = -EINVAL; break; } - vq->last_avail_idx = s.num; + vq->next_avail_head = vq->last_avail_idx = s.num; } /* Forget the cached index value. */ vq->avail_idx = vq->last_avail_idx; @@ -2135,6 +2310,45 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) goto done; } +#ifdef CONFIG_VHOST_ENABLE_FORK_OWNER_CONTROL + if (ioctl == VHOST_SET_FORK_FROM_OWNER) { + /* Only allow modification before owner is set */ + if (vhost_dev_has_owner(d)) { + r = -EBUSY; + goto done; + } + u8 fork_owner_val; + + if (get_user(fork_owner_val, (u8 __user *)argp)) { + r = -EFAULT; + goto done; + } + if (fork_owner_val != VHOST_FORK_OWNER_TASK && + fork_owner_val != VHOST_FORK_OWNER_KTHREAD) { + r = -EINVAL; + goto done; + } + d->fork_owner = !!fork_owner_val; + r = 0; + goto done; + } + if (ioctl == VHOST_GET_FORK_FROM_OWNER) { + u8 fork_owner_val = d->fork_owner; + + if (fork_owner_val != VHOST_FORK_OWNER_TASK && + fork_owner_val != VHOST_FORK_OWNER_KTHREAD) { + r = -EINVAL; + goto done; + } + if (put_user(fork_owner_val, (u8 __user *)argp)) { + r = -EFAULT; + goto done; + } + r = 0; + goto done; + } +#endif + /* You must be the owner to do anything else */ r = vhost_dev_check_owner(d); if (r) @@ -2304,6 +2518,19 @@ static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len) return 0; } +/* + * vhost_log_write() - Log in dirty page bitmap + * @vq: vhost virtqueue. + * @log: Array of dirty memory in GPA. + * @log_num: Size of vhost_log arrary. + * @len: The total length of memory buffer to log in the dirty bitmap. + * Some drivers may only partially use pages shared via the last + * vring descriptor (i.e. vhost-net RX buffer). + * Use (len == U64_MAX) to indicate the driver would log all + * pages of vring descriptors. + * @iov: Array of dirty memory in HVA. + * @count: Size of iovec array. + */ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, unsigned int log_num, u64 len, struct iovec *iov, int count) { @@ -2327,15 +2554,14 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, r = log_write(vq->log_base, log[i].addr, l); if (r < 0) return r; - len -= l; - if (!len) { - if (vq->log_ctx) - eventfd_signal(vq->log_ctx); - return 0; - } + + if (len != U64_MAX) + len -= l; } - /* Length written exceeds what we have stored. This is a bug. */ - BUG(); + + if (vq->log_ctx) + eventfd_signal(vq->log_ctx); + return 0; } EXPORT_SYMBOL_GPL(vhost_log_write); @@ -2578,11 +2804,12 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, unsigned int *out_num, unsigned int *in_num, struct vhost_log *log, unsigned int *log_num) { + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); struct vring_desc desc; unsigned int i, head, found = 0; u16 last_avail_idx = vq->last_avail_idx; __virtio16 ring_head; - int ret, access; + int ret, access, c = 0; if (vq->avail_idx == vq->last_avail_idx) { ret = vhost_get_avail_idx(vq); @@ -2593,17 +2820,21 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, return vq->num; } - /* Grab the next descriptor number they're advertising, and increment - * the index we've seen. */ - if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) { - vq_err(vq, "Failed to read head: idx %d address %p\n", - last_avail_idx, - &vq->avail->ring[last_avail_idx % vq->num]); - return -EFAULT; + if (in_order) + head = vq->next_avail_head & (vq->num - 1); + else { + /* Grab the next descriptor number they're + * advertising, and increment the index we've seen. */ + if (unlikely(vhost_get_avail_head(vq, &ring_head, + last_avail_idx))) { + vq_err(vq, "Failed to read head: idx %d address %p\n", + last_avail_idx, + &vq->avail->ring[last_avail_idx % vq->num]); + return -EFAULT; + } + head = vhost16_to_cpu(vq, ring_head); } - head = vhost16_to_cpu(vq, ring_head); - /* If their number is silly, that's an error. */ if (unlikely(head >= vq->num)) { vq_err(vq, "Guest says index %u > %u is available", @@ -2646,6 +2877,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, "in indirect descriptor at idx %d\n", i); return ret; } + ++c; continue; } @@ -2681,10 +2913,12 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, } *out_num += ret; } + ++c; } while ((i = next_desc(vq, &desc)) != -1); /* On success, increment avail index. */ vq->last_avail_idx++; + vq->next_avail_head += c; /* Assume notifications from guest are disabled at this point, * if they aren't we would need to update avail_event index. */ @@ -2708,8 +2942,9 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len) cpu_to_vhost32(vq, head), cpu_to_vhost32(vq, len) }; + u16 nheads = 1; - return vhost_add_used_n(vq, &heads, 1); + return vhost_add_used_n(vq, &heads, &nheads, 1); } EXPORT_SYMBOL_GPL(vhost_add_used); @@ -2745,10 +2980,9 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, return 0; } -/* After we've used one of their buffers, we tell them about it. We'll then - * want to notify the guest, using eventfd. */ -int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, - unsigned count) +static int vhost_add_used_n_ooo(struct vhost_virtqueue *vq, + struct vring_used_elem *heads, + unsigned count) { int start, n, r; @@ -2761,7 +2995,72 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, heads += n; count -= n; } - r = __vhost_add_used_n(vq, heads, count); + return __vhost_add_used_n(vq, heads, count); +} + +static int vhost_add_used_n_in_order(struct vhost_virtqueue *vq, + struct vring_used_elem *heads, + const u16 *nheads, + unsigned count) +{ + vring_used_elem_t __user *used; + u16 old, new = vq->last_used_idx; + int start, i; + + if (!nheads) + return -EINVAL; + + start = vq->last_used_idx & (vq->num - 1); + used = vq->used->ring + start; + + for (i = 0; i < count; i++) { + if (vhost_put_used(vq, &heads[i], start, 1)) { + vq_err(vq, "Failed to write used"); + return -EFAULT; + } + start += nheads[i]; + new += nheads[i]; + if (start >= vq->num) + start -= vq->num; + } + + if (unlikely(vq->log_used)) { + /* Make sure data is seen before log. */ + smp_wmb(); + /* Log used ring entry write. */ + log_used(vq, ((void __user *)used - (void __user *)vq->used), + (vq->num - start) * sizeof *used); + if (start + count > vq->num) + log_used(vq, 0, + (start + count - vq->num) * sizeof *used); + } + + old = vq->last_used_idx; + vq->last_used_idx = new; + /* If the driver never bothers to signal in a very long while, + * used index might wrap around. If that happens, invalidate + * signalled_used index we stored. TODO: make sure driver + * signals at least once in 2^16 and remove this. */ + if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old))) + vq->signalled_used_valid = false; + return 0; +} + +/* After we've used one of their buffers, we tell them about it. We'll then + * want to notify the guest, using eventfd. */ +int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, + u16 *nheads, unsigned count) +{ + bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER); + int r; + + if (!in_order || !nheads) + r = vhost_add_used_n_ooo(vq, heads, count); + else + r = vhost_add_used_n_in_order(vq, heads, nheads, count); + + if (r < 0) + return r; /* Make sure buffer is written before we update index. */ smp_wmb(); @@ -2841,14 +3140,16 @@ EXPORT_SYMBOL_GPL(vhost_add_used_and_signal); /* multi-buffer version of vhost_add_used_and_signal */ void vhost_add_used_and_signal_n(struct vhost_dev *dev, struct vhost_virtqueue *vq, - struct vring_used_elem *heads, unsigned count) + struct vring_used_elem *heads, + u16 *nheads, + unsigned count) { - vhost_add_used_n(vq, heads, count); + vhost_add_used_n(vq, heads, nheads, count); vhost_signal(dev, vq); } EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n); -/* return true if we're sure that avaiable ring is empty */ +/* return true if we're sure that available ring is empty */ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) { int r; diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index bb75a292d50c..621a6d9a8791 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -26,7 +26,18 @@ struct vhost_work { unsigned long flags; }; +struct vhost_worker; +struct vhost_dev; + +struct vhost_worker_ops { + int (*create)(struct vhost_worker *worker, struct vhost_dev *dev, + const char *name); + void (*stop)(struct vhost_worker *worker); + void (*wakeup)(struct vhost_worker *worker); +}; + struct vhost_worker { + struct task_struct *kthread_task; struct vhost_task *vtsk; struct vhost_dev *dev; /* Used to serialize device wide flushing with worker swapping. */ @@ -36,6 +47,7 @@ struct vhost_worker { u32 id; int attachment_cnt; bool killed; + const struct vhost_worker_ops *ops; }; /* Poll a file (eventfd or socket) */ @@ -103,6 +115,8 @@ struct vhost_virtqueue { * Values are limited to 0x7fff, and the high bit is used as * a wrap counter when using VIRTIO_F_RING_PACKED. */ u16 last_avail_idx; + /* Next avail ring head when VIRTIO_F_IN_ORDER is negoitated */ + u16 next_avail_head; /* Caches available index value from user. */ u16 avail_idx; @@ -129,11 +143,12 @@ struct vhost_virtqueue { struct iovec iotlb_iov[64]; struct iovec *indirect; struct vring_used_elem *heads; + u16 *nheads; /* Protected by virtqueue mutex. */ struct vhost_iotlb *umem; struct vhost_iotlb *iotlb; void *private_data; - u64 acked_features; + VIRTIO_DECLARE_FEATURES(acked_features); u64 acked_backend_features; /* Log write descriptors */ void __user *log_base; @@ -176,6 +191,16 @@ struct vhost_dev { int byte_weight; struct xarray worker_xa; bool use_worker; + /* + * If fork_owner is true we use vhost_tasks to create + * the worker so all settings/limits like cgroups, NPROC, + * scheduler, etc are inherited from the owner. If false, + * we use kthreads and only attach to the same cgroups + * as the owner for compat with older kernels. + * here we use true as default value. + * The default value is set by fork_from_owner_default + */ + bool fork_owner; int (*msg_handler)(struct vhost_dev *dev, u32 asid, struct vhost_iotlb_msg *msg); }; @@ -213,11 +238,12 @@ bool vhost_vq_is_setup(struct vhost_virtqueue *vq); int vhost_vq_init_access(struct vhost_virtqueue *); int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads, - unsigned count); + u16 *nheads, unsigned count); void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *, unsigned int id, int len); void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *, - struct vring_used_elem *heads, unsigned count); + struct vring_used_elem *heads, u16 *nheads, + unsigned count); void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *); void vhost_disable_notify(struct vhost_dev *, struct vhost_virtqueue *); bool vhost_vq_avail_empty(struct vhost_dev *, struct vhost_virtqueue *); @@ -291,7 +317,7 @@ static inline void *vhost_vq_get_backend(struct vhost_virtqueue *vq) static inline bool vhost_has_feature(struct vhost_virtqueue *vq, int bit) { - return vq->acked_features & (1ULL << bit); + return virtio_features_test_bit(vq->acked_features_array, bit); } static inline bool vhost_backend_has_feature(struct vhost_virtqueue *vq, int bit) diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c index 73e153f9b449..9f27c3f6091b 100644 --- a/drivers/vhost/vringh.c +++ b/drivers/vhost/vringh.c @@ -225,10 +225,9 @@ static int resize_iovec(struct vringh_kiov *iov, gfp_t gfp) flag = (iov->max_num & VRINGH_IOV_ALLOCATED); if (flag) - new = krealloc_array(iov->iov, new_num, - sizeof(struct iovec), gfp); + new = krealloc_array(iov->iov, new_num, sizeof(*new), gfp); else { - new = kmalloc_array(new_num, sizeof(struct iovec), gfp); + new = kmalloc_array(new_num, sizeof(*new), gfp); if (new) { memcpy(new, iov->iov, iov->max_num * sizeof(struct iovec)); @@ -781,22 +780,6 @@ ssize_t vringh_iov_push_user(struct vringh_iov *wiov, EXPORT_SYMBOL(vringh_iov_push_user); /** - * vringh_abandon_user - we've decided not to handle the descriptor(s). - * @vrh: the vring. - * @num: the number of descriptors to put back (ie. num - * vringh_get_user() to undo). - * - * The next vringh_get_user() will return the old descriptor(s) again. - */ -void vringh_abandon_user(struct vringh *vrh, unsigned int num) -{ - /* We only update vring_avail_event(vr) when we want to be notified, - * so we haven't changed that yet. */ - vrh->last_avail_idx -= num; -} -EXPORT_SYMBOL(vringh_abandon_user); - -/** * vringh_complete_user - we've finished with descriptor, publish it. * @vrh: the vring. * @head: the head as filled in by vringh_getdesc_user. @@ -901,20 +884,6 @@ static inline int putused_kern(const struct vringh *vrh, return 0; } -static inline int xfer_kern(const struct vringh *vrh, void *src, - void *dst, size_t len) -{ - memcpy(dst, src, len); - return 0; -} - -static inline int kern_xfer(const struct vringh *vrh, void *dst, - void *src, size_t len) -{ - memcpy(dst, src, len); - return 0; -} - /** * vringh_init_kern - initialize a vringh for a kernelspace vring. * @vrh: the vringh to initialize. @@ -1000,51 +969,6 @@ int vringh_getdesc_kern(struct vringh *vrh, EXPORT_SYMBOL(vringh_getdesc_kern); /** - * vringh_iov_pull_kern - copy bytes from vring_iov. - * @riov: the riov as passed to vringh_getdesc_kern() (updated as we consume) - * @dst: the place to copy. - * @len: the maximum length to copy. - * - * Returns the bytes copied <= len or a negative errno. - */ -ssize_t vringh_iov_pull_kern(struct vringh_kiov *riov, void *dst, size_t len) -{ - return vringh_iov_xfer(NULL, riov, dst, len, xfer_kern); -} -EXPORT_SYMBOL(vringh_iov_pull_kern); - -/** - * vringh_iov_push_kern - copy bytes into vring_iov. - * @wiov: the wiov as passed to vringh_getdesc_kern() (updated as we consume) - * @src: the place to copy from. - * @len: the maximum length to copy. - * - * Returns the bytes copied <= len or a negative errno. - */ -ssize_t vringh_iov_push_kern(struct vringh_kiov *wiov, - const void *src, size_t len) -{ - return vringh_iov_xfer(NULL, wiov, (void *)src, len, kern_xfer); -} -EXPORT_SYMBOL(vringh_iov_push_kern); - -/** - * vringh_abandon_kern - we've decided not to handle the descriptor(s). - * @vrh: the vring. - * @num: the number of descriptors to put back (ie. num - * vringh_get_kern() to undo). - * - * The next vringh_get_kern() will return the old descriptor(s) again. - */ -void vringh_abandon_kern(struct vringh *vrh, unsigned int num) -{ - /* We only update vring_avail_event(vr) when we want to be notified, - * so we haven't changed that yet. */ - vrh->last_avail_idx -= num; -} -EXPORT_SYMBOL(vringh_abandon_kern); - -/** * vringh_complete_kern - we've finished with descriptor, publish it. * @vrh: the vring. * @head: the head as filled in by vringh_getdesc_kern. @@ -1291,11 +1215,10 @@ static inline int getu16_iotlb(const struct vringh *vrh, if (ret) return ret; } else { - void *kaddr = kmap_local_page(ivec.iov.bvec[0].bv_page); - void *from = kaddr + ivec.iov.bvec[0].bv_offset; + __virtio16 *from = bvec_kmap_local(&ivec.iov.bvec[0]); - tmp = READ_ONCE(*(__virtio16 *)from); - kunmap_local(kaddr); + tmp = READ_ONCE(*from); + kunmap_local(from); } *val = vringh16_to_cpu(vrh, tmp); @@ -1330,11 +1253,10 @@ static inline int putu16_iotlb(const struct vringh *vrh, if (ret) return ret; } else { - void *kaddr = kmap_local_page(ivec.iov.bvec[0].bv_page); - void *to = kaddr + ivec.iov.bvec[0].bv_offset; + __virtio16 *to = bvec_kmap_local(&ivec.iov.bvec[0]); - WRITE_ONCE(*(__virtio16 *)to, tmp); - kunmap_local(kaddr); + WRITE_ONCE(*to, tmp); + kunmap_local(to); } return 0; @@ -1538,23 +1460,6 @@ ssize_t vringh_iov_push_iotlb(struct vringh *vrh, EXPORT_SYMBOL(vringh_iov_push_iotlb); /** - * vringh_abandon_iotlb - we've decided not to handle the descriptor(s). - * @vrh: the vring. - * @num: the number of descriptors to put back (ie. num - * vringh_get_iotlb() to undo). - * - * The next vringh_get_iotlb() will return the old descriptor(s) again. - */ -void vringh_abandon_iotlb(struct vringh *vrh, unsigned int num) -{ - /* We only update vring_avail_event(vr) when we want to be notified, - * so we haven't changed that yet. - */ - vrh->last_avail_idx -= num; -} -EXPORT_SYMBOL(vringh_abandon_iotlb); - -/** * vringh_complete_iotlb - we've finished with descriptor, publish it. * @vrh: the vring. * @head: the head as filled in by vringh_getdesc_iotlb. @@ -1575,32 +1480,6 @@ int vringh_complete_iotlb(struct vringh *vrh, u16 head, u32 len) EXPORT_SYMBOL(vringh_complete_iotlb); /** - * vringh_notify_enable_iotlb - we want to know if something changes. - * @vrh: the vring. - * - * This always enables notifications, but returns false if there are - * now more buffers available in the vring. - */ -bool vringh_notify_enable_iotlb(struct vringh *vrh) -{ - return __vringh_notify_enable(vrh, getu16_iotlb, putu16_iotlb); -} -EXPORT_SYMBOL(vringh_notify_enable_iotlb); - -/** - * vringh_notify_disable_iotlb - don't tell us if something changes. - * @vrh: the vring. - * - * This is our normal running state: we disable and then only enable when - * we're going to sleep. - */ -void vringh_notify_disable_iotlb(struct vringh *vrh) -{ - __vringh_notify_disable(vrh, putu16_iotlb); -} -EXPORT_SYMBOL(vringh_notify_disable_iotlb); - -/** * vringh_need_notify_iotlb - must we tell the other side about used buffers? * @vrh: the vring we've called vringh_complete_iotlb() on. * diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c index 802153e23073..ae01457ea2cd 100644 --- a/drivers/vhost/vsock.c +++ b/drivers/vhost/vsock.c @@ -344,6 +344,10 @@ vhost_vsock_alloc_skb(struct vhost_virtqueue *vq, len = iov_length(vq->iov, out); + if (len < VIRTIO_VSOCK_SKB_HEADROOM || + len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE + VIRTIO_VSOCK_SKB_HEADROOM) + return NULL; + /* len contains both payload and hdr */ skb = virtio_vsock_alloc_skb(len, GFP_KERNEL); if (!skb) @@ -367,18 +371,15 @@ vhost_vsock_alloc_skb(struct vhost_virtqueue *vq, return skb; /* The pkt is too big or the length in the header is invalid */ - if (payload_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE || - payload_len + sizeof(*hdr) > len) { + if (payload_len + sizeof(*hdr) > len) { kfree_skb(skb); return NULL; } - virtio_vsock_skb_rx_put(skb); + virtio_vsock_skb_put(skb, payload_len); - nbytes = copy_from_iter(skb->data, payload_len, &iov_iter); - if (nbytes != payload_len) { - vq_err(vq, "Expected %zu byte payload, got %zu bytes\n", - payload_len, nbytes); + if (skb_copy_datagram_from_iter(skb, 0, &iov_iter, payload_len)) { + vq_err(vq, "Failed to copy %zu byte payload\n", payload_len); kfree_skb(skb); return NULL; } |