diff options
Diffstat (limited to 'net/vmw_vsock')
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 224 | ||||
-rw-r--r-- | net/vmw_vsock/diag.c | 1 | ||||
-rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 7 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport.c | 174 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 95 | ||||
-rw-r--r-- | net/vmw_vsock/vsock_bpf.c | 21 | ||||
-rw-r--r-- | net/vmw_vsock/vsock_loopback.c | 6 |
7 files changed, 384 insertions, 144 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 54ba7316f808..7e3db87ae433 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -112,16 +112,19 @@ #include <net/sock.h> #include <net/af_vsock.h> #include <uapi/linux/vm_sockets.h> +#include <uapi/asm-generic/ioctls.h> static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr); static void vsock_sk_destruct(struct sock *sk); static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); +static void vsock_close(struct sock *sk, long timeout); /* Protocol family. */ struct proto vsock_proto = { .name = "AF_VSOCK", .owner = THIS_MODULE, .obj_size = sizeof(struct vsock_sock), + .close = vsock_close, #ifdef CONFIG_BPF_SYSCALL .psock_update_sk_prot = vsock_bpf_update_proto, #endif @@ -334,7 +337,10 @@ EXPORT_SYMBOL_GPL(vsock_find_connected_socket); void vsock_remove_sock(struct vsock_sock *vsk) { - vsock_remove_bound(vsk); + /* Transport reassignment must not remove the binding. */ + if (sock_flag(sk_vsock(vsk), SOCK_DEAD)) + vsock_remove_bound(vsk); + vsock_remove_connected(vsk); } EXPORT_SYMBOL_GPL(vsock_remove_sock); @@ -488,6 +494,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) */ vsk->transport->release(vsk); vsock_deassign_transport(vsk); + + /* transport's release() and destruct() can touch some socket + * state, since we are reassigning the socket to a new transport + * during vsock_connect(), let's reset these fields to have a + * clean state. + */ + sock_reset_flag(sk, SOCK_DONE); + sk->sk_state = TCP_CLOSE; + vsk->peer_shutdown = 0; } /* We increase the module refcnt to prevent the transport unloading @@ -796,45 +811,53 @@ static bool sock_type_connectible(u16 type) static void __vsock_release(struct sock *sk, int level) { - if (sk) { - struct sock *pending; - struct vsock_sock *vsk; + struct vsock_sock *vsk; + struct sock *pending; - vsk = vsock_sk(sk); - pending = NULL; /* Compiler warning. */ + vsk = vsock_sk(sk); + pending = NULL; /* Compiler warning. */ - /* When "level" is SINGLE_DEPTH_NESTING, use the nested - * version to avoid the warning "possible recursive locking - * detected". When "level" is 0, lock_sock_nested(sk, level) - * is the same as lock_sock(sk). - */ - lock_sock_nested(sk, level); + /* When "level" is SINGLE_DEPTH_NESTING, use the nested + * version to avoid the warning "possible recursive locking + * detected". When "level" is 0, lock_sock_nested(sk, level) + * is the same as lock_sock(sk). + */ + lock_sock_nested(sk, level); - if (vsk->transport) - vsk->transport->release(vsk); - else if (sock_type_connectible(sk->sk_type)) - vsock_remove_sock(vsk); + /* Indicate to vsock_remove_sock() that the socket is being released and + * can be removed from the bound_table. Unlike transport reassignment + * case, where the socket must remain bound despite vsock_remove_sock() + * being called from the transport release() callback. + */ + sock_set_flag(sk, SOCK_DEAD); - sock_orphan(sk); - sk->sk_shutdown = SHUTDOWN_MASK; + if (vsk->transport) + vsk->transport->release(vsk); + else if (sock_type_connectible(sk->sk_type)) + vsock_remove_sock(vsk); - skb_queue_purge(&sk->sk_receive_queue); + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; - /* Clean up any sockets that never were accepted. */ - while ((pending = vsock_dequeue_accept(sk)) != NULL) { - __vsock_release(pending, SINGLE_DEPTH_NESTING); - sock_put(pending); - } + skb_queue_purge(&sk->sk_receive_queue); - release_sock(sk); - sock_put(sk); + /* Clean up any sockets that never were accepted. */ + while ((pending = vsock_dequeue_accept(sk)) != NULL) { + __vsock_release(pending, SINGLE_DEPTH_NESTING); + sock_put(pending); } + + release_sock(sk); + sock_put(sk); } static void vsock_sk_destruct(struct sock *sk) { struct vsock_sock *vsk = vsock_sk(sk); + /* Flush MSG_ZEROCOPY leftovers. */ + __skb_queue_purge(&sk->sk_error_queue); + vsock_deassign_transport(vsk); /* When clearing these addresses, there's no need to set the family and @@ -866,6 +889,9 @@ EXPORT_SYMBOL_GPL(vsock_create_connected); s64 vsock_stream_has_data(struct vsock_sock *vsk) { + if (WARN_ON(!vsk->transport)) + return 0; + return vsk->transport->stream_has_data(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_data); @@ -874,6 +900,9 @@ s64 vsock_connectible_has_data(struct vsock_sock *vsk) { struct sock *sk = sk_vsock(vsk); + if (WARN_ON(!vsk->transport)) + return 0; + if (sk->sk_type == SOCK_SEQPACKET) return vsk->transport->seqpacket_has_data(vsk); else @@ -883,6 +912,9 @@ EXPORT_SYMBOL_GPL(vsock_connectible_has_data); s64 vsock_stream_has_space(struct vsock_sock *vsk) { + if (WARN_ON(!vsk->transport)) + return 0; + return vsk->transport->stream_has_space(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_space); @@ -897,9 +929,22 @@ void vsock_data_ready(struct sock *sk) } EXPORT_SYMBOL_GPL(vsock_data_ready); +/* Dummy callback required by sockmap. + * See unconditional call of saved_close() in sock_map_close(). + */ +static void vsock_close(struct sock *sk, long timeout) +{ +} + static int vsock_release(struct socket *sock) { - __vsock_release(sock->sk, 0); + struct sock *sk = sock->sk; + + if (!sk) + return 0; + + sk->sk_prot->close(sk, 0); + __vsock_release(sk, 0); sock->sk = NULL; sock->state = SS_FREE; @@ -1050,6 +1095,9 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, mask |= EPOLLRDHUP; } + if (sk_is_readable(sk)) + mask |= EPOLLIN | EPOLLRDNORM; + if (sock->type == SOCK_DGRAM) { /* For datagram sockets we can read if there is something in * the queue and write as long as the socket isn't shutdown for @@ -1141,6 +1189,9 @@ static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor) { struct vsock_sock *vsk = vsock_sk(sk); + if (WARN_ON_ONCE(!vsk->transport)) + return -ENODEV; + return vsk->transport->read_skb(vsk, read_actor); } @@ -1270,28 +1321,82 @@ out: return err; } +int __vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct vsock_sock *vsk = vsock_sk(sk); + + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); +} + int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags) { #ifdef CONFIG_BPF_SYSCALL + struct sock *sk = sock->sk; const struct proto *prot; -#endif - struct vsock_sock *vsk; - struct sock *sk; - - sk = sock->sk; - vsk = vsock_sk(sk); -#ifdef CONFIG_BPF_SYSCALL prot = READ_ONCE(sk->sk_prot); if (prot != &vsock_proto) return prot->recvmsg(sk, msg, len, flags, NULL); #endif - return vsk->transport->dgram_dequeue(vsk, msg, len, flags); + return __vsock_dgram_recvmsg(sock, msg, len, flags); } EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg); +static int vsock_do_ioctl(struct socket *sock, unsigned int cmd, + int __user *arg) +{ + struct sock *sk = sock->sk; + struct vsock_sock *vsk; + int ret; + + vsk = vsock_sk(sk); + + switch (cmd) { + case SIOCOUTQ: { + ssize_t n_bytes; + + if (!vsk->transport || !vsk->transport->unsent_bytes) { + ret = -EOPNOTSUPP; + break; + } + + if (sock_type_connectible(sk->sk_type) && sk->sk_state == TCP_LISTEN) { + ret = -EINVAL; + break; + } + + n_bytes = vsk->transport->unsent_bytes(vsk); + if (n_bytes < 0) { + ret = n_bytes; + break; + } + + ret = put_user(n_bytes, arg); + break; + } + default: + ret = -ENOIOCTLCMD; + } + + return ret; +} + +static int vsock_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + int ret; + + lock_sock(sock->sk); + ret = vsock_do_ioctl(sock, cmd, (int __user *)arg); + release_sock(sock->sk); + + return ret; +} + static const struct proto_ops vsock_dgram_ops = { .family = PF_VSOCK, .owner = THIS_MODULE, @@ -1302,7 +1407,7 @@ static const struct proto_ops vsock_dgram_ops = { .accept = sock_no_accept, .getname = vsock_getname, .poll = vsock_poll, - .ioctl = sock_no_ioctl, + .ioctl = vsock_ioctl, .listen = sock_no_listen, .shutdown = vsock_shutdown, .sendmsg = vsock_dgram_sendmsg, @@ -1427,6 +1532,11 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr, if (err < 0) goto out; + /* sk_err might have been set as a result of an earlier + * (failed) connect attempt. + */ + sk->sk_err = 0; + /* Mark sock as connecting and set the error code to in * progress in case this is a non-blocking connect. */ @@ -1500,8 +1610,8 @@ out: return err; } -static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, - bool kern) +static int vsock_accept(struct socket *sock, struct socket *newsock, + struct proto_accept_arg *arg) { struct sock *listener; int err; @@ -1528,7 +1638,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, /* Wait for children sockets to appear; these are the new sockets * created upon connection establishment. */ - timeout = sock_rcvtimeo(listener, flags & O_NONBLOCK); + timeout = sock_rcvtimeo(listener, arg->flags & O_NONBLOCK); prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); while ((connected = vsock_dequeue_accept(listener)) == NULL && @@ -2174,15 +2284,12 @@ out: } int -vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, - int flags) +__vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) { struct sock *sk; struct vsock_sock *vsk; const struct vsock_transport *transport; -#ifdef CONFIG_BPF_SYSCALL - const struct proto *prot; -#endif int err; sk = sock->sk; @@ -2233,14 +2340,6 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, goto out; } -#ifdef CONFIG_BPF_SYSCALL - prot = READ_ONCE(sk->sk_prot); - if (prot != &vsock_proto) { - release_sock(sk); - return prot->recvmsg(sk, msg, len, flags, NULL); - } -#endif - if (sk->sk_type == SOCK_STREAM) err = __vsock_stream_recvmsg(sk, msg, len, flags); else @@ -2250,6 +2349,22 @@ out: release_sock(sk); return err; } + +int +vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ +#ifdef CONFIG_BPF_SYSCALL + struct sock *sk = sock->sk; + const struct proto *prot; + + prot = READ_ONCE(sk->sk_prot); + if (prot != &vsock_proto) + return prot->recvmsg(sk, msg, len, flags, NULL); +#endif + + return __vsock_connectible_recvmsg(sock, msg, len, flags); +} EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg); static int vsock_set_rcvlowat(struct sock *sk, int val) @@ -2286,7 +2401,7 @@ static const struct proto_ops vsock_stream_ops = { .accept = vsock_accept, .getname = vsock_getname, .poll = vsock_poll, - .ioctl = sock_no_ioctl, + .ioctl = vsock_ioctl, .listen = vsock_listen, .shutdown = vsock_shutdown, .setsockopt = vsock_connectible_setsockopt, @@ -2308,7 +2423,7 @@ static const struct proto_ops vsock_seqpacket_ops = { .accept = vsock_accept, .getname = vsock_getname, .poll = vsock_poll, - .ioctl = sock_no_ioctl, + .ioctl = vsock_ioctl, .listen = vsock_listen, .shutdown = vsock_shutdown, .setsockopt = vsock_connectible_setsockopt, @@ -2357,6 +2472,7 @@ static int vsock_create(struct net *net, struct socket *sock, if (sock->type == SOCK_DGRAM) { ret = vsock_assign_transport(vsk, NULL); if (ret < 0) { + sock->sk = NULL; sock_put(sk); return ret; } diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c index 2e29994f92ff..ab87ef66c1e8 100644 --- a/net/vmw_vsock/diag.c +++ b/net/vmw_vsock/diag.c @@ -157,6 +157,7 @@ static int vsock_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) } static const struct sock_diag_handler vsock_diag_handler = { + .owner = THIS_MODULE, .family = AF_VSOCK, .dump = vsock_diag_handler_dump, }; diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index e2157e387217..31342ab502b4 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -13,12 +13,12 @@ #include <linux/hyperv.h> #include <net/sock.h> #include <net/af_vsock.h> -#include <asm/hyperv-tlfs.h> +#include <hyperv/hvhdk.h> /* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some * stricter requirements on the hv_sock ring buffer size of six 4K pages. - * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this - * limitation; but, keep the defaults the same for compat. + * HV_HYP_PAGE_SIZE is defined as 4K. Newer hosts don't have this limitation; + * but, keep the defaults the same for compat. */ #define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6) #define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6) @@ -549,6 +549,7 @@ static void hvs_destruct(struct vsock_sock *vsk) vmbus_hvsock_device_unregister(chan); kfree(hvs); + vsk->trans = NULL; } static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr) diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 1748268e0694..f0e48e6911fc 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -94,6 +94,63 @@ out_rcu: return ret; } +/* Caller need to hold vsock->tx_lock on vq */ +static int virtio_transport_send_skb(struct sk_buff *skb, struct virtqueue *vq, + struct virtio_vsock *vsock, gfp_t gfp) +{ + int ret, in_sg = 0, out_sg = 0; + struct scatterlist **sgs; + + sgs = vsock->out_sgs; + sg_init_one(sgs[out_sg], virtio_vsock_hdr(skb), + sizeof(*virtio_vsock_hdr(skb))); + out_sg++; + + if (!skb_is_nonlinear(skb)) { + if (skb->len > 0) { + sg_init_one(sgs[out_sg], skb->data, skb->len); + out_sg++; + } + } else { + struct skb_shared_info *si; + int i; + + /* If skb is nonlinear, then its buffer must contain + * only header and nothing more. Data is stored in + * the fragged part. + */ + WARN_ON_ONCE(skb_headroom(skb) != sizeof(*virtio_vsock_hdr(skb))); + + si = skb_shinfo(skb); + + for (i = 0; i < si->nr_frags; i++) { + skb_frag_t *skb_frag = &si->frags[i]; + void *va; + + /* We will use 'page_to_virt()' for the userspace page + * here, because virtio or dma-mapping layers will call + * 'virt_to_phys()' later to fill the buffer descriptor. + * We don't touch memory at "virtual" address of this page. + */ + va = page_to_virt(skb_frag_page(skb_frag)); + sg_init_one(sgs[out_sg], + va + skb_frag_off(skb_frag), + skb_frag_size(skb_frag)); + out_sg++; + } + } + + ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, gfp); + /* Usually this means that there is no more space available in + * the vq + */ + if (ret < 0) + return ret; + + virtio_transport_deliver_tap_pkt(skb); + return 0; +} + static void virtio_transport_send_pkt_work(struct work_struct *work) { @@ -111,60 +168,17 @@ virtio_transport_send_pkt_work(struct work_struct *work) vq = vsock->vqs[VSOCK_VQ_TX]; for (;;) { - int ret, in_sg = 0, out_sg = 0; - struct scatterlist **sgs; struct sk_buff *skb; bool reply; + int ret; skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue); if (!skb) break; - virtio_transport_deliver_tap_pkt(skb); reply = virtio_vsock_skb_reply(skb); - sgs = vsock->out_sgs; - sg_init_one(sgs[out_sg], virtio_vsock_hdr(skb), - sizeof(*virtio_vsock_hdr(skb))); - out_sg++; - - if (!skb_is_nonlinear(skb)) { - if (skb->len > 0) { - sg_init_one(sgs[out_sg], skb->data, skb->len); - out_sg++; - } - } else { - struct skb_shared_info *si; - int i; - - /* If skb is nonlinear, then its buffer must contain - * only header and nothing more. Data is stored in - * the fragged part. - */ - WARN_ON_ONCE(skb_headroom(skb) != sizeof(*virtio_vsock_hdr(skb))); - si = skb_shinfo(skb); - - for (i = 0; i < si->nr_frags; i++) { - skb_frag_t *skb_frag = &si->frags[i]; - void *va; - - /* We will use 'page_to_virt()' for the userspace page - * here, because virtio or dma-mapping layers will call - * 'virt_to_phys()' later to fill the buffer descriptor. - * We don't touch memory at "virtual" address of this page. - */ - va = page_to_virt(skb_frag_page(skb_frag)); - sg_init_one(sgs[out_sg], - va + skb_frag_off(skb_frag), - skb_frag_size(skb_frag)); - out_sg++; - } - } - - ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, GFP_KERNEL); - /* Usually this means that there is no more space available in - * the vq - */ + ret = virtio_transport_send_skb(skb, vq, vsock, GFP_KERNEL); if (ret < 0) { virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb); break; @@ -194,6 +208,28 @@ out: queue_work(virtio_vsock_workqueue, &vsock->rx_work); } +/* Caller need to hold RCU for vsock. + * Returns 0 if the packet is successfully put on the vq. + */ +static int virtio_transport_send_skb_fast_path(struct virtio_vsock *vsock, struct sk_buff *skb) +{ + struct virtqueue *vq = vsock->vqs[VSOCK_VQ_TX]; + int ret; + + /* Inside RCU, can't sleep! */ + ret = mutex_trylock(&vsock->tx_lock); + if (unlikely(ret == 0)) + return -EBUSY; + + ret = virtio_transport_send_skb(skb, vq, vsock, GFP_ATOMIC); + if (ret == 0) + virtqueue_kick(vq); + + mutex_unlock(&vsock->tx_lock); + + return ret; +} + static int virtio_transport_send_pkt(struct sk_buff *skb) { @@ -217,11 +253,20 @@ virtio_transport_send_pkt(struct sk_buff *skb) goto out_rcu; } - if (virtio_vsock_skb_reply(skb)) - atomic_inc(&vsock->queued_replies); + /* If send_pkt_queue is empty, we can safely bypass this queue + * because packet order is maintained and (try) to put the packet + * on the virtqueue using virtio_transport_send_skb_fast_path. + * If this fails we simply put the packet on the intermediate + * queue and schedule the worker. + */ + if (!skb_queue_empty_lockless(&vsock->send_pkt_queue) || + virtio_transport_send_skb_fast_path(vsock, skb)) { + if (virtio_vsock_skb_reply(skb)) + atomic_inc(&vsock->queued_replies); - virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb); - queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb); + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + } out_rcu: rcu_read_unlock(); @@ -310,7 +355,7 @@ static void virtio_transport_tx_work(struct work_struct *work) virtqueue_disable_cb(vq); while ((skb = virtqueue_get_buf(vq, &len)) != NULL) { - consume_skb(skb); + virtio_transport_consume_skb_sent(skb, true); added = true; } } while (!virtqueue_enable_cb(vq)); @@ -539,6 +584,8 @@ static struct virtio_transport virtio_transport = { .notify_buffer_size = virtio_transport_notify_buffer_size, .notify_set_rcvlowat = virtio_transport_notify_set_rcvlowat, + .unsent_bytes = virtio_transport_unsent_bytes, + .read_skb = virtio_transport_read_skb, }, @@ -616,20 +663,21 @@ out: static int virtio_vsock_vqs_init(struct virtio_vsock *vsock) { struct virtio_device *vdev = vsock->vdev; - static const char * const names[] = { - "rx", - "tx", - "event", - }; - vq_callback_t *callbacks[] = { - virtio_vsock_rx_done, - virtio_vsock_tx_done, - virtio_vsock_event_done, + struct virtqueue_info vqs_info[] = { + { "rx", virtio_vsock_rx_done }, + { "tx", virtio_vsock_tx_done }, + { "event", virtio_vsock_event_done }, }; int ret; - ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, callbacks, names, - NULL); + mutex_lock(&vsock->rx_lock); + vsock->rx_buf_nr = 0; + vsock->rx_buf_max_nr = 0; + mutex_unlock(&vsock->rx_lock); + + atomic_set(&vsock->queued_replies, 0); + + ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, vqs_info, NULL); if (ret < 0) return ret; @@ -738,9 +786,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev) vsock->vdev = vdev; - vsock->rx_buf_nr = 0; - vsock->rx_buf_max_nr = 0; - atomic_set(&vsock->queued_replies, 0); mutex_init(&vsock->tx_lock); mutex_init(&vsock->rx_lock); @@ -858,7 +903,6 @@ static struct virtio_driver virtio_vsock_driver = { .feature_table = features, .feature_table_size = ARRAY_SIZE(features), .driver.name = KBUILD_MODNAME, - .driver.owner = THIS_MODULE, .id_table = id_table, .probe = virtio_vsock_probe, .remove = virtio_vsock_remove, diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 16ff976a86e3..7f7de6d88096 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -26,6 +26,9 @@ /* Threshold for detecting small packets to copy */ #define GOOD_COPY_LEN 128 +static void virtio_transport_cancel_close_work(struct vsock_sock *vsk, + bool cancel_timeout); + static const struct virtio_transport * virtio_transport_get_ops(struct vsock_sock *vsk) { @@ -400,6 +403,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, if (virtio_transport_init_zcopy_skb(vsk, skb, info->msg, can_zcopy)) { + kfree_skb(skb); ret = -ENOMEM; break; } @@ -463,6 +467,26 @@ void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff * } EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); +void virtio_transport_consume_skb_sent(struct sk_buff *skb, bool consume) +{ + struct sock *s = skb->sk; + + if (s && skb->len) { + struct vsock_sock *vs = vsock_sk(s); + struct virtio_vsock_sock *vvs; + + vvs = vs->trans; + + spin_lock_bh(&vvs->tx_lock); + vvs->bytes_unsent -= skb->len; + spin_unlock_bh(&vvs->tx_lock); + } + + if (consume) + consume_skb(skb); +} +EXPORT_SYMBOL_GPL(virtio_transport_consume_skb_sent); + u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) { u32 ret; @@ -475,6 +499,7 @@ u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) if (ret > credit) ret = credit; vvs->tx_cnt += ret; + vvs->bytes_unsent += ret; spin_unlock_bh(&vvs->tx_lock); return ret; @@ -488,6 +513,7 @@ void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) spin_lock_bh(&vvs->tx_lock); vvs->tx_cnt -= credit; + vvs->bytes_unsent -= credit; spin_unlock_bh(&vvs->tx_lock); } EXPORT_SYMBOL_GPL(virtio_transport_put_credit); @@ -1086,10 +1112,26 @@ void virtio_transport_destruct(struct vsock_sock *vsk) { struct virtio_vsock_sock *vvs = vsk->trans; + virtio_transport_cancel_close_work(vsk, true); + kfree(vvs); + vsk->trans = NULL; } EXPORT_SYMBOL_GPL(virtio_transport_destruct); +ssize_t virtio_transport_unsent_bytes(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + size_t ret; + + spin_lock_bh(&vvs->tx_lock); + ret = vvs->bytes_unsent; + spin_unlock_bh(&vvs->tx_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(virtio_transport_unsent_bytes); + static int virtio_transport_reset(struct vsock_sock *vsk, struct sk_buff *skb) { @@ -1167,17 +1209,11 @@ static void virtio_transport_wait_close(struct sock *sk, long timeout) } } -static void virtio_transport_do_close(struct vsock_sock *vsk, - bool cancel_timeout) +static void virtio_transport_cancel_close_work(struct vsock_sock *vsk, + bool cancel_timeout) { struct sock *sk = sk_vsock(vsk); - sock_set_flag(sk, SOCK_DONE); - vsk->peer_shutdown = SHUTDOWN_MASK; - if (vsock_stream_has_data(vsk) <= 0) - sk->sk_state = TCP_CLOSING; - sk->sk_state_change(sk); - if (vsk->close_work_scheduled && (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { vsk->close_work_scheduled = false; @@ -1189,6 +1225,20 @@ static void virtio_transport_do_close(struct vsock_sock *vsk, } } +static void virtio_transport_do_close(struct vsock_sock *vsk, + bool cancel_timeout) +{ + struct sock *sk = sk_vsock(vsk); + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + sk->sk_state_change(sk); + + virtio_transport_cancel_close_work(vsk, cancel_timeout); +} + static void virtio_transport_close_timeout(struct work_struct *work) { struct vsock_sock *vsk = @@ -1477,6 +1527,14 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, return -ENOMEM; } + /* __vsock_release() might have already flushed accept_queue. + * Subsequent enqueues would lead to a memory leak. + */ + if (sk->sk_shutdown == SHUTDOWN_MASK) { + virtio_transport_reset_no_sock(t, skb); + return -ESHUTDOWN; + } + child = vsock_create_connected(sk); if (!child) { virtio_transport_reset_no_sock(t, skb); @@ -1583,8 +1641,11 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, lock_sock(sk); - /* Check if sk has been closed before lock_sock */ - if (sock_flag(sk, SOCK_DONE)) { + /* Check if sk has been closed or assigned to another transport before + * lock_sock (note: listener sockets are not assigned to any transport) + */ + if (sock_flag(sk, SOCK_DONE) || + (sk->sk_state != TCP_LISTEN && vsk->transport != &t->transport)) { (void)virtio_transport_reset_no_sock(t, skb); release_sock(sk); sock_put(sk); @@ -1672,6 +1733,7 @@ int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_acto { struct virtio_vsock_sock *vvs = vsk->trans; struct sock *sk = sk_vsock(vsk); + struct virtio_vsock_hdr *hdr; struct sk_buff *skb; int off = 0; int err; @@ -1681,10 +1743,19 @@ int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_acto * works for types other than dgrams. */ skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err); + if (!skb) { + spin_unlock_bh(&vvs->rx_lock); + return err; + } + + hdr = virtio_vsock_hdr(skb); + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) + vvs->msg_count--; + + virtio_transport_dec_rx_pkt(vvs, le32_to_cpu(hdr->len)); spin_unlock_bh(&vvs->rx_lock); - if (!skb) - return err; + virtio_transport_send_credit_update(vsk); return recv_actor(sk, skb); } diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c index a3c97546ab84..07b96d56f3a5 100644 --- a/net/vmw_vsock/vsock_bpf.c +++ b/net/vmw_vsock/vsock_bpf.c @@ -64,9 +64,9 @@ static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int int err; if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) - err = vsock_connectible_recvmsg(sock, msg, len, flags); + err = __vsock_connectible_recvmsg(sock, msg, len, flags); else if (sk->sk_type == SOCK_DGRAM) - err = vsock_dgram_recvmsg(sock, msg, len, flags); + err = __vsock_dgram_recvmsg(sock, msg, len, flags); else err = -EPROTOTYPE; @@ -77,6 +77,7 @@ static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, int *addr_len) { struct sk_psock *psock; + struct vsock_sock *vsk; int copied; psock = sk_psock_get(sk); @@ -84,6 +85,13 @@ static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg, return __vsock_recvmsg(sk, msg, len, flags); lock_sock(sk); + vsk = vsock_sk(sk); + + if (WARN_ON_ONCE(!vsk->transport)) { + copied = -ENODEV; + goto out; + } + if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) { release_sock(sk); sk_psock_put(sk, psock); @@ -108,20 +116,13 @@ static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg, copied = sk_msg_recvmsg(sk, psock, msg, len, flags); } +out: release_sock(sk); sk_psock_put(sk, psock); return copied; } -/* Copy of original proto with updated sock_map methods */ -static struct proto vsock_bpf_prot = { - .close = sock_map_close, - .recvmsg = vsock_bpf_recvmsg, - .sock_is_readable = sk_msg_is_readable, - .unhash = sock_map_unhash, -}; - static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base) { *prot = *base; diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c index 6dea6119f5b2..6e78927a598e 100644 --- a/net/vmw_vsock/vsock_loopback.c +++ b/net/vmw_vsock/vsock_loopback.c @@ -98,6 +98,8 @@ static struct virtio_transport loopback_transport = { .notify_buffer_size = virtio_transport_notify_buffer_size, .notify_set_rcvlowat = virtio_transport_notify_set_rcvlowat, + .unsent_bytes = virtio_transport_unsent_bytes, + .read_skb = virtio_transport_read_skb, }, @@ -123,6 +125,10 @@ static void vsock_loopback_work(struct work_struct *work) spin_unlock_bh(&vsock->pkt_queue.lock); while ((skb = __skb_dequeue(&pkts))) { + /* Decrement the bytes_unsent counter without deallocating skb + * It is freed by the receiver. + */ + virtio_transport_consume_skb_sent(skb, false); virtio_transport_deliver_tap_pkt(skb); virtio_transport_recv_pkt(&loopback_transport, skb); } |