diff options
Diffstat (limited to 'net/vmw_vsock')
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 84 | ||||
-rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 17 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport.c | 20 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 3 | ||||
-rw-r--r-- | net/vmw_vsock/vmci_transport.c | 4 |
5 files changed, 101 insertions, 27 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 2e7a3034e965..ead6a3c14b87 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -407,6 +407,8 @@ EXPORT_SYMBOL_GPL(vsock_enqueue_accept); static bool vsock_use_local_transport(unsigned int remote_cid) { + lockdep_assert_held(&vsock_register_mutex); + if (!transport_local) return false; @@ -464,6 +466,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) remote_flags = vsk->remote_addr.svm_flags; + mutex_lock(&vsock_register_mutex); + switch (sk->sk_type) { case SOCK_DGRAM: new_transport = transport_dgram; @@ -479,12 +483,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) new_transport = transport_h2g; break; default: - return -ESOCKTNOSUPPORT; + ret = -ESOCKTNOSUPPORT; + goto err; } if (vsk->transport) { - if (vsk->transport == new_transport) - return 0; + if (vsk->transport == new_transport) { + ret = 0; + goto err; + } /* transport->release() must be called with sock lock acquired. * This path can only be taken during vsock_connect(), where we @@ -508,8 +515,16 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) /* We increase the module refcnt to prevent the transport unloading * while there are open sockets assigned to it. */ - if (!new_transport || !try_module_get(new_transport->module)) - return -ENODEV; + if (!new_transport || !try_module_get(new_transport->module)) { + ret = -ENODEV; + goto err; + } + + /* It's safe to release the mutex after a successful try_module_get(). + * Whichever transport `new_transport` points at, it won't go away until + * the last module_put() below or in vsock_deassign_transport(). + */ + mutex_unlock(&vsock_register_mutex); if (sk->sk_type == SOCK_SEQPACKET) { if (!new_transport->seqpacket_allow || @@ -528,12 +543,31 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) vsk->transport = new_transport; return 0; +err: + mutex_unlock(&vsock_register_mutex); + return ret; } EXPORT_SYMBOL_GPL(vsock_assign_transport); +/* + * Provide safe access to static transport_{h2g,g2h,dgram,local} callbacks. + * Otherwise we may race with module removal. Do not use on `vsk->transport`. + */ +static u32 vsock_registered_transport_cid(const struct vsock_transport **transport) +{ + u32 cid = VMADDR_CID_ANY; + + mutex_lock(&vsock_register_mutex); + if (*transport) + cid = (*transport)->get_local_cid(); + mutex_unlock(&vsock_register_mutex); + + return cid; +} + bool vsock_find_cid(unsigned int cid) { - if (transport_g2h && cid == transport_g2h->get_local_cid()) + if (cid == vsock_registered_transport_cid(&transport_g2h)) return true; if (transport_h2g && cid == VMADDR_CID_HOST) @@ -994,11 +1028,6 @@ static int vsock_getname(struct socket *sock, vm_addr = &vsk->local_addr; } - if (!vm_addr) { - err = -EINVAL; - goto out; - } - /* sys_getsockname() and sys_getpeername() pass us a * MAX_SOCK_ADDR-sized buffer and don't set addr_len. Unfortunately * that macro is defined in socket.c instead of .h, so we hardcode its @@ -1389,6 +1418,28 @@ static int vsock_do_ioctl(struct socket *sock, unsigned int cmd, vsk = vsock_sk(sk); switch (cmd) { + case SIOCINQ: { + ssize_t n_bytes; + + if (!vsk->transport) { + ret = -EOPNOTSUPP; + break; + } + + if (sock_type_connectible(sk->sk_type) && + sk->sk_state == TCP_LISTEN) { + ret = -EINVAL; + break; + } + + n_bytes = vsock_stream_has_data(vsk); + if (n_bytes < 0) { + ret = n_bytes; + break; + } + ret = put_user(n_bytes, arg); + break; + } case SIOCOUTQ: { ssize_t n_bytes; @@ -2536,18 +2587,19 @@ static long vsock_dev_do_ioctl(struct file *filp, unsigned int cmd, void __user *ptr) { u32 __user *p = ptr; - u32 cid = VMADDR_CID_ANY; int retval = 0; + u32 cid; switch (cmd) { case IOCTL_VM_SOCKETS_GET_LOCAL_CID: /* To be compatible with the VMCI behavior, we prioritize the * guest CID instead of well-know host CID (VMADDR_CID_HOST). */ - if (transport_g2h) - cid = transport_g2h->get_local_cid(); - else if (transport_h2g) - cid = transport_h2g->get_local_cid(); + cid = vsock_registered_transport_cid(&transport_g2h); + if (cid == VMADDR_CID_ANY) + cid = vsock_registered_transport_cid(&transport_h2g); + if (cid == VMADDR_CID_ANY) + cid = vsock_registered_transport_cid(&transport_local); if (put_user(cid, p) != 0) retval = -EFAULT; diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 31342ab502b4..432fcbbd14d4 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -694,15 +694,26 @@ out: static s64 hvs_stream_has_data(struct vsock_sock *vsk) { struct hvsock *hvs = vsk->trans; + bool need_refill; s64 ret; if (hvs->recv_data_len > 0) - return 1; + return hvs->recv_data_len; switch (hvs_channel_readable_payload(hvs->chan)) { case 1: - ret = 1; - break; + need_refill = !hvs->recv_desc; + if (!need_refill) + return -EIO; + + hvs->recv_desc = hv_pkt_iter_first(hvs->chan); + if (!hvs->recv_desc) + return -ENOBUFS; + + ret = hvs_update_recv_data(hvs); + if (ret) + return ret; + return hvs->recv_data_len; case 0: vsk->peer_shutdown |= SEND_SHUTDOWN; ret = 0; diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index f0e48e6911fc..b6569b0ca2bb 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -307,7 +307,7 @@ out_rcu: static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) { - int total_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE + VIRTIO_VSOCK_SKB_HEADROOM; + int total_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; struct scatterlist pkt, *p; struct virtqueue *vq; struct sk_buff *skb; @@ -316,7 +316,7 @@ static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) vq = vsock->vqs[VSOCK_VQ_RX]; do { - skb = virtio_vsock_alloc_skb(total_len, GFP_KERNEL); + skb = virtio_vsock_alloc_linear_skb(total_len, GFP_KERNEL); if (!skb) break; @@ -624,8 +624,9 @@ static void virtio_transport_rx_work(struct work_struct *work) do { virtqueue_disable_cb(vq); for (;;) { + unsigned int len, payload_len; + struct virtio_vsock_hdr *hdr; struct sk_buff *skb; - unsigned int len; if (!virtio_transport_more_replies(vsock)) { /* Stop rx until the device processes already @@ -642,13 +643,22 @@ static void virtio_transport_rx_work(struct work_struct *work) vsock->rx_buf_nr--; /* Drop short/long packets */ - if (unlikely(len < sizeof(struct virtio_vsock_hdr) || + if (unlikely(len < sizeof(*hdr) || len > virtio_vsock_skb_len(skb))) { kfree_skb(skb); continue; } - virtio_vsock_skb_rx_put(skb); + hdr = virtio_vsock_hdr(skb); + payload_len = le32_to_cpu(hdr->len); + if (unlikely(payload_len > len - sizeof(*hdr))) { + kfree_skb(skb); + continue; + } + + if (payload_len) + virtio_vsock_skb_put(skb, payload_len); + virtio_transport_deliver_tap_pkt(skb); virtio_transport_recv_pkt(&virtio_transport, skb); } diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 1b5d9896edae..fe92e5fa95b4 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -109,7 +109,8 @@ static int virtio_transport_fill_skb(struct sk_buff *skb, return __zerocopy_sg_from_iter(info->msg, NULL, skb, &info->msg->msg_iter, len, NULL); - return memcpy_from_msg(skb_put(skb, len), info->msg, len); + virtio_vsock_skb_put(skb, len); + return skb_copy_datagram_from_iter(skb, 0, &info->msg->msg_iter, len); } static void virtio_transport_init_hdr(struct sk_buff *skb, diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index b370070194fa..7eccd6708d66 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -119,6 +119,8 @@ vmci_transport_packet_init(struct vmci_transport_packet *pkt, u16 proto, struct vmci_handle handle) { + memset(pkt, 0, sizeof(*pkt)); + /* We register the stream control handler as an any cid handle so we * must always send from a source address of VMADDR_CID_ANY */ @@ -131,8 +133,6 @@ vmci_transport_packet_init(struct vmci_transport_packet *pkt, pkt->type = type; pkt->src_port = src->svm_port; pkt->dst_port = dst->svm_port; - memset(&pkt->proto, 0, sizeof(pkt->proto)); - memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2)); switch (pkt->type) { case VMCI_TRANSPORT_PACKET_TYPE_INVALID: |