Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
| -rw-r--r-- | net/vmw_vsock/af_vsock.c | 1790 |
1 file changed, 1289 insertions, 501 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 593071dabd1c..adcba1b7bf74 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -1,16 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * VMware vSockets Driver * * Copyright (C) 2007-2013 VMware, Inc. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the Free - * Software Foundation version 2 and no later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. */ /* Implementation notes: @@ -36,19 +28,20 @@ * not support simultaneous connects (two "client" sockets connecting). * * - "Server" sockets are referred to as listener sockets throughout this - * implementation because they are in the SS_LISTEN state. When a connection - * request is received (the second kind of socket mentioned above), we create a - * new socket and refer to it as a pending socket. These pending sockets are - * placed on the pending connection list of the listener socket. When future - * packets are received for the address the listener socket is bound to, we - * check if the source of the packet is from one that has an existing pending - * connection. If it does, we process the packet for the pending socket. When - * that socket reaches the connected state, it is removed from the listener - * socket's pending list and enqueued in the listener socket's accept queue. - * Callers of accept(2) will accept connected sockets from the listener socket's - * accept queue. If the socket cannot be accepted for some reason then it is - * marked rejected. Once the connection is accepted, it is owned by the user - * process and the responsibility for cleanup falls with that user process. + * implementation because they are in the TCP_LISTEN state. When a + * connection request is received (the second kind of socket mentioned above), + * we create a new socket and refer to it as a pending socket. These pending + * sockets are placed on the pending connection list of the listener socket. + * When future packets are received for the address the listener socket is + * bound to, we check if the source of the packet is from one that has an + * existing pending connection. If it does, we process the packet for the + * pending socket. When that socket reaches the connected state, it is removed + * from the listener socket's pending list and enqueued in the listener + * socket's accept queue. Callers of accept(2) will accept connected sockets + * from the listener socket's accept queue. If the socket cannot be accepted + * for some reason then it is marked rejected. Once the connection is + * accepted, it is owned by the user process and the responsibility for cleanup + * falls with that user process. * * - It is possible that these pending sockets will never reach the connected * state; in fact, we may never receive another packet after the connection @@ -60,6 +53,14 @@ * function will also cleanup rejected sockets, those that reach the connected * state but leave it before they have been accepted. 
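The listener flow described in these notes is driven entirely by the standard socket calls. As an illustration, a minimal userspace sketch (error handling elided; port 1234 and the helper name are arbitrary):

    #include <sys/socket.h>
    #include <unistd.h>
    #include <linux/vm_sockets.h>

    int vsock_accept_one(void)
    {
            struct sockaddr_vm addr = {
                    .svm_family = AF_VSOCK,
                    .svm_cid    = VMADDR_CID_ANY,   /* any local CID */
                    .svm_port   = 1234,             /* arbitrary example port */
            };
            int fd = socket(AF_VSOCK, SOCK_STREAM, 0);

            bind(fd, (struct sockaddr *)&addr, sizeof(addr));
            listen(fd, 8);                  /* socket enters the listening state */

            /* accept() pops a connected child off the listener's accept queue */
            int conn = accept(fd, NULL, NULL);

            close(fd);
            return conn;
    }
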
* + * - Lock ordering for pending or accept queue sockets is: + * + * lock_sock(listener); + * lock_sock_nested(pending, SINGLE_DEPTH_NESTING); + * + * Using explicit nested locking keeps lockdep happy since normally only one + * lock of a given class may be taken at a time. + * * - Sockets created by user action will be cleaned up when the user process * calls close(2), causing our release implementation to be called. Our release * implementation will perform some cleanup then drop the last reference so our @@ -73,14 +74,26 @@ * argument, we must ensure the reference count is increased to ensure the * socket isn't freed before the function is run; the deferred function will * then drop the reference. + * + * - sk->sk_state uses the TCP state constants because they are widely used by + * other address families and exposed to userspace tools like ss(8): + * + * TCP_CLOSE - unconnected + * TCP_SYN_SENT - connecting + * TCP_ESTABLISHED - connected + * TCP_CLOSING - disconnecting + * TCP_LISTEN - listening */ +#include <linux/compat.h> #include <linux/types.h> #include <linux/bitops.h> #include <linux/cred.h> +#include <linux/errqueue.h> #include <linux/init.h> #include <linux/io.h> #include <linux/kernel.h> +#include <linux/sched/signal.h> #include <linux/kmod.h> #include <linux/list.h> #include <linux/miscdevice.h> @@ -88,6 +101,7 @@ #include <linux/mutex.h> #include <linux/net.h> #include <linux/poll.h> +#include <linux/random.h> #include <linux/skbuff.h> #include <linux/smp.h> #include <linux/socket.h> @@ -96,18 +110,24 @@ #include <linux/wait.h> #include <linux/workqueue.h> #include <net/sock.h> - -#include "af_vsock.h" +#include <net/af_vsock.h> +#include <uapi/linux/vm_sockets.h> +#include <uapi/asm-generic/ioctls.h> static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr); static void vsock_sk_destruct(struct sock *sk); static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); +static void vsock_close(struct sock *sk, long timeout); /* Protocol family. */ -static struct proto vsock_proto = { +struct proto vsock_proto = { .name = "AF_VSOCK", .owner = THIS_MODULE, .obj_size = sizeof(struct vsock_sock), + .close = vsock_close, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = vsock_bpf_update_proto, +#endif }; /* The default peer timeout indicates how long we will wait for a peer response @@ -115,21 +135,20 @@ static struct proto vsock_proto = { */ #define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) -#define SS_LISTEN 255 - -static const struct vsock_transport *transport; +#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128 + +/* Transport used for host->guest communication */ +static const struct vsock_transport *transport_h2g; +/* Transport used for guest->host communication */ +static const struct vsock_transport *transport_g2h; +/* Transport used for DGRAM communication */ +static const struct vsock_transport *transport_dgram; +/* Transport used for local communication */ +static const struct vsock_transport *transport_local; static DEFINE_MUTEX(vsock_register_mutex); -/**** EXPORTS ****/ - -/* Get the ID of the local context. This is transport dependent. 
*/ - -int vm_sockets_get_local_cid(void) -{ - return transport->get_local_cid(); -} -EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); - /**** UTILS ****/ /* Each bound VSocket is stored in the bind hash table and each connected @@ -146,7 +165,6 @@ EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets. The hash function * mods with VSOCK_HASH_SIZE to ensure this. */ -#define VSOCK_HASH_SIZE 251 #define MAX_PORT_RETRIES 24 #define VSOCK_HASH(addr) ((addr)->svm_port % VSOCK_HASH_SIZE) @@ -161,9 +179,12 @@ EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid); #define vsock_connected_sockets_vsk(vsk) \ vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr) -static struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; -static struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; -static DEFINE_SPINLOCK(vsock_table_lock); +struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; +EXPORT_SYMBOL_GPL(vsock_bind_table); +struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; +EXPORT_SYMBOL_GPL(vsock_connected_table); +DEFINE_SPINLOCK(vsock_table_lock); +EXPORT_SYMBOL_GPL(vsock_table_lock); /* Autobind this socket to the local address if necessary. */ static int vsock_auto_bind(struct vsock_sock *vsk) @@ -218,9 +239,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) { struct vsock_sock *vsk; - list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) - if (addr->svm_port == vsk->local_addr.svm_port) + list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) { + if (vsock_addr_equals_addr(addr, &vsk->local_addr)) + return sk_vsock(vsk); + + if (addr->svm_port == vsk->local_addr.svm_port && + (vsk->local_addr.svm_cid == VMADDR_CID_ANY || + addr->svm_cid == VMADDR_CID_ANY)) return sk_vsock(vsk); + } return NULL; } @@ -241,16 +268,6 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, return NULL; } -static bool __vsock_in_bound_table(struct vsock_sock *vsk) -{ - return !list_empty(&vsk->bound_table); -} - -static bool __vsock_in_connected_table(struct vsock_sock *vsk) -{ - return !list_empty(&vsk->connected_table); -} - static void vsock_insert_unbound(struct vsock_sock *vsk) { spin_lock_bh(&vsock_table_lock); @@ -272,7 +289,8 @@ EXPORT_SYMBOL_GPL(vsock_insert_connected); void vsock_remove_bound(struct vsock_sock *vsk) { spin_lock_bh(&vsock_table_lock); - __vsock_remove_bound(vsk); + if (__vsock_in_bound_table(vsk)) + __vsock_remove_bound(vsk); spin_unlock_bh(&vsock_table_lock); } EXPORT_SYMBOL_GPL(vsock_remove_bound); @@ -280,7 +298,8 @@ EXPORT_SYMBOL_GPL(vsock_remove_bound); void vsock_remove_connected(struct vsock_sock *vsk) { spin_lock_bh(&vsock_table_lock); - __vsock_remove_connected(vsk); + if (__vsock_in_connected_table(vsk)) + __vsock_remove_connected(vsk); spin_unlock_bh(&vsock_table_lock); } EXPORT_SYMBOL_GPL(vsock_remove_connected); @@ -316,29 +335,18 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, } EXPORT_SYMBOL_GPL(vsock_find_connected_socket); -static bool vsock_in_bound_table(struct vsock_sock *vsk) -{ - bool ret; - - spin_lock_bh(&vsock_table_lock); - ret = __vsock_in_bound_table(vsk); - spin_unlock_bh(&vsock_table_lock); - - return ret; -} - -static bool vsock_in_connected_table(struct vsock_sock *vsk) +void vsock_remove_sock(struct vsock_sock *vsk) { - bool ret; - - spin_lock_bh(&vsock_table_lock); - ret = __vsock_in_connected_table(vsk); - spin_unlock_bh(&vsock_table_lock); + /* Transport reassignment must not remove the binding. 
*/ + if (sock_flag(sk_vsock(vsk), SOCK_DEAD)) + vsock_remove_bound(vsk); - return ret; + vsock_remove_connected(vsk); } +EXPORT_SYMBOL_GPL(vsock_remove_sock); -void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) +void vsock_for_each_connected_socket(struct vsock_transport *transport, + void (*fn)(struct sock *sk)) { int i; @@ -347,8 +355,12 @@ void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) { struct vsock_sock *vsk; list_for_each_entry(vsk, &vsock_connected_table[i], - connected_table); + connected_table) { + if (vsk->transport != transport) + continue; + fn(sk_vsock(vsk)); + } } spin_unlock_bh(&vsock_table_lock); @@ -393,6 +405,181 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected) } EXPORT_SYMBOL_GPL(vsock_enqueue_accept); +static bool vsock_use_local_transport(unsigned int remote_cid) +{ + lockdep_assert_held(&vsock_register_mutex); + + if (!transport_local) + return false; + + if (remote_cid == VMADDR_CID_LOCAL) + return true; + + if (transport_g2h) { + return remote_cid == transport_g2h->get_local_cid(); + } else { + return remote_cid == VMADDR_CID_HOST; + } +} + +static void vsock_deassign_transport(struct vsock_sock *vsk) +{ + if (!vsk->transport) + return; + + vsk->transport->destruct(vsk); + module_put(vsk->transport->module); + vsk->transport = NULL; +} + +/* Assign a transport to a socket and call the .init transport callback. + * + * Note: for connection oriented socket this must be called when vsk->remote_addr + * is set (e.g. during the connect() or when a connection request on a listener + * socket is received). + * The vsk->remote_addr is used to decide which transport to use: + * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if + * g2h is not loaded, will use local transport; + * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field + * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport; + * - remote CID > VMADDR_CID_HOST will use host->guest transport; + */ +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + const struct vsock_transport *new_transport; + struct sock *sk = sk_vsock(vsk); + unsigned int remote_cid = vsk->remote_addr.svm_cid; + __u8 remote_flags; + int ret; + + /* If the packet is coming with the source and destination CIDs higher + * than VMADDR_CID_HOST, then a vsock channel where all the packets are + * forwarded to the host should be established. Then the host will + * need to forward the packets to the guest. + * + * The flag is set on the (listen) receive path (psk is not NULL). On + * the connect path the flag can be set by the user space application. 
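Since the flag can come from user space on the connect path, an application targeting a sibling guest could set it as sketched below; the helper name, CID 52 and port 1234 are made-up examples:

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>

    /* Hypothetical helper: fd is an already-created AF_VSOCK stream socket. */
    static int connect_via_host(int fd)
    {
            struct sockaddr_vm peer = {
                    .svm_family = AF_VSOCK,
                    .svm_cid    = 52,                  /* example sibling-guest CID (> VMADDR_CID_HOST) */
                    .svm_port   = 1234,                /* arbitrary example port */
                    .svm_flags  = VMADDR_FLAG_TO_HOST, /* pick the guest->host transport */
            };

            return connect(fd, (struct sockaddr *)&peer, sizeof(peer));
    }
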
+ */ + if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST && + vsk->remote_addr.svm_cid > VMADDR_CID_HOST) + vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST; + + remote_flags = vsk->remote_addr.svm_flags; + + mutex_lock(&vsock_register_mutex); + + switch (sk->sk_type) { + case SOCK_DGRAM: + new_transport = transport_dgram; + break; + case SOCK_STREAM: + case SOCK_SEQPACKET: + if (vsock_use_local_transport(remote_cid)) + new_transport = transport_local; + else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g || + (remote_flags & VMADDR_FLAG_TO_HOST)) + new_transport = transport_g2h; + else + new_transport = transport_h2g; + break; + default: + ret = -ESOCKTNOSUPPORT; + goto err; + } + + if (vsk->transport && vsk->transport == new_transport) { + ret = 0; + goto err; + } + + /* We increase the module refcnt to prevent the transport unloading + * while there are open sockets assigned to it. + */ + if (!new_transport || !try_module_get(new_transport->module)) { + ret = -ENODEV; + goto err; + } + + /* It's safe to release the mutex after a successful try_module_get(). + * Whichever transport `new_transport` points at, it won't go away until + * the last module_put() below or in vsock_deassign_transport(). + */ + mutex_unlock(&vsock_register_mutex); + + if (vsk->transport) { + /* transport->release() must be called with sock lock acquired. + * This path can only be taken during vsock_connect(), where we + * have already held the sock lock. In the other cases, this + * function is called on a new socket which is not assigned to + * any transport. + */ + vsk->transport->release(vsk); + vsock_deassign_transport(vsk); + + /* transport's release() and destruct() can touch some socket + * state, since we are reassigning the socket to a new transport + * during vsock_connect(), let's reset these fields to have a + * clean state. + */ + sock_reset_flag(sk, SOCK_DONE); + sk->sk_state = TCP_CLOSE; + vsk->peer_shutdown = 0; + } + + if (sk->sk_type == SOCK_SEQPACKET) { + if (!new_transport->seqpacket_allow || + !new_transport->seqpacket_allow(remote_cid)) { + module_put(new_transport->module); + return -ESOCKTNOSUPPORT; + } + } + + ret = new_transport->init(vsk, psk); + if (ret) { + module_put(new_transport->module); + return ret; + } + + vsk->transport = new_transport; + + return 0; +err: + mutex_unlock(&vsock_register_mutex); + return ret; +} +EXPORT_SYMBOL_GPL(vsock_assign_transport); + +/* + * Provide safe access to static transport_{h2g,g2h,dgram,local} callbacks. + * Otherwise we may race with module removal. Do not use on `vsk->transport`. 
+ */ +static u32 vsock_registered_transport_cid(const struct vsock_transport **transport) +{ + u32 cid = VMADDR_CID_ANY; + + mutex_lock(&vsock_register_mutex); + if (*transport) + cid = (*transport)->get_local_cid(); + mutex_unlock(&vsock_register_mutex); + + return cid; +} + +bool vsock_find_cid(unsigned int cid) +{ + if (cid == vsock_registered_transport_cid(&transport_g2h)) + return true; + + if (transport_h2g && cid == VMADDR_CID_HOST) + return true; + + if (transport_local && cid == VMADDR_CID_LOCAL) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(vsock_find_cid); + static struct sock *vsock_dequeue_accept(struct sock *listener) { struct vsock_sock *vlistener; @@ -429,26 +616,33 @@ static bool vsock_is_pending(struct sock *sk) static int vsock_send_shutdown(struct sock *sk, int mode) { - return transport->shutdown(vsock_sk(sk), mode); + struct vsock_sock *vsk = vsock_sk(sk); + + if (!vsk->transport) + return -ENODEV; + + return vsk->transport->shutdown(vsk, mode); } -void vsock_pending_work(struct work_struct *work) +static void vsock_pending_work(struct work_struct *work) { struct sock *sk; struct sock *listener; struct vsock_sock *vsk; bool cleanup; - vsk = container_of(work, struct vsock_sock, dwork.work); + vsk = container_of(work, struct vsock_sock, pending_work.work); sk = sk_vsock(vsk); listener = vsk->listener; cleanup = true; lock_sock(listener); - lock_sock(sk); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); if (vsock_is_pending(sk)) { vsock_remove_pending(listener, sk); + + sk_acceptq_removed(listener); } else if (!vsk->rejected) { /* We are not on the pending list and accept() did not reject * us, so we must have been accepted by our user process. We @@ -459,16 +653,13 @@ void vsock_pending_work(struct work_struct *work) goto out; } - listener->sk_ack_backlog--; - /* We need to remove ourself from the global connected sockets list so * incoming packets can't find this socket, and to reduce the reference * count. */ - if (vsock_in_connected_table(vsk)) - vsock_remove_connected(vsk); + vsock_remove_connected(vsk); - sk->sk_state = SS_FREE; + sk->sk_state = TCP_CLOSE; out: release_sock(sk); @@ -479,16 +670,18 @@ out: sock_put(sk); sock_put(listener); } -EXPORT_SYMBOL_GPL(vsock_pending_work); /**** SOCKET OPERATIONS ****/ -static int __vsock_bind_stream(struct vsock_sock *vsk, - struct sockaddr_vm *addr) +static int __vsock_bind_connectible(struct vsock_sock *vsk, + struct sockaddr_vm *addr) { - static u32 port = LAST_RESERVED_PORT + 1; + static u32 port; struct sockaddr_vm new_addr; + if (!port) + port = get_random_u32_above(LAST_RESERVED_PORT); + vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port); if (addr->svm_port == VMADDR_PORT_ANY) { @@ -496,7 +689,8 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, unsigned int i; for (i = 0; i < MAX_PORT_RETRIES; i++) { - if (port <= LAST_RESERVED_PORT) + if (port == VMADDR_PORT_ANY || + port <= LAST_RESERVED_PORT) port = LAST_RESERVED_PORT + 1; new_addr.svm_port = port++; @@ -524,9 +718,10 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port); - /* Remove stream sockets from the unbound list and add them to the hash - * table for easy lookup by its address. The unbound list is simply an - * extra entry at the end of the hash table, a trick used by AF_UNIX. + /* Remove connection oriented sockets from the unbound list and add them + * to the hash table for easy lookup by its address. 
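One consequence of the autobind logic above: user space can bind to VMADDR_PORT_ANY and read back the port the kernel picked; a sketch (helper name is hypothetical):

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>

    /* Hypothetical helper: returns the ephemeral port the kernel picked. */
    static unsigned int bind_ephemeral(int fd)
    {
            struct sockaddr_vm local = {
                    .svm_family = AF_VSOCK,
                    .svm_cid    = VMADDR_CID_ANY,
                    .svm_port   = VMADDR_PORT_ANY,  /* request autobind */
            };
            socklen_t len = sizeof(local);

            bind(fd, (struct sockaddr *)&local, sizeof(local));
            getsockname(fd, (struct sockaddr *)&local, &len);
            return local.svm_port;  /* randomly seeded, above LAST_RESERVED_PORT */
    }
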
The unbound list + * is simply an extra entry at the end of the hash table, a trick used + * by AF_UNIX. */ __vsock_remove_bound(vsk); __vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk); @@ -537,13 +732,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, static int __vsock_bind_dgram(struct vsock_sock *vsk, struct sockaddr_vm *addr) { - return transport->dgram_bind(vsk, addr); + return vsk->transport->dgram_bind(vsk, addr); } static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) { struct vsock_sock *vsk = vsock_sk(sk); - u32 cid; int retval; /* First ensure this socket isn't already bound. */ @@ -553,16 +747,16 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) /* Now bind to the provided address or select appropriate values if * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that * like AF_INET prevents binding to a non-local IP address (in most - * cases), we only allow binding to the local CID. + * cases), we only allow binding to a local CID. */ - cid = transport->get_local_cid(); - if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY) + if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid)) return -EADDRNOTAVAIL; switch (sk->sk_socket->type) { case SOCK_STREAM: + case SOCK_SEQPACKET: spin_lock_bh(&vsock_table_lock); - retval = __vsock_bind_stream(vsk, addr); + retval = __vsock_bind_connectible(vsk, addr); spin_unlock_bh(&vsock_table_lock); break; @@ -578,17 +772,20 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) return retval; } -struct sock *__vsock_create(struct net *net, - struct socket *sock, - struct sock *parent, - gfp_t priority, - unsigned short type) +static void vsock_connect_timeout(struct work_struct *work); + +static struct sock *__vsock_create(struct net *net, + struct socket *sock, + struct sock *parent, + gfp_t priority, + unsigned short type, + int kern) { struct sock *sk; struct vsock_sock *psk; struct vsock_sock *vsk; - sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto); + sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern); if (!sk) return NULL; @@ -607,7 +804,6 @@ struct sock *__vsock_create(struct net *net, sk->sk_destruct = vsock_sk_destruct; sk->sk_backlog_rcv = vsock_queue_rcv_skb; - sk->sk_state = 0; sock_reset_flag(sk, SOCK_DONE); INIT_LIST_HEAD(&vsk->bound_table); @@ -619,71 +815,85 @@ struct sock *__vsock_create(struct net *net, vsk->sent_request = false; vsk->ignore_connecting_rst = false; vsk->peer_shutdown = 0; + INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout); + INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work); psk = parent ? 
vsock_sk(parent) : NULL; if (parent) { vsk->trusted = psk->trusted; vsk->owner = get_cred(psk->owner); vsk->connect_timeout = psk->connect_timeout; + vsk->buffer_size = psk->buffer_size; + vsk->buffer_min_size = psk->buffer_min_size; + vsk->buffer_max_size = psk->buffer_max_size; + security_sk_clone(parent, sk); } else { - vsk->trusted = capable(CAP_NET_ADMIN); + vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN); vsk->owner = get_current_cred(); vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; + vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE; + vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE; + vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE; } - if (transport->init(vsk, psk) < 0) { - sk_free(sk); - return NULL; - } - - if (sock) - vsock_insert_unbound(vsk); - return sk; } -EXPORT_SYMBOL_GPL(__vsock_create); -static void __vsock_release(struct sock *sk) +static bool sock_type_connectible(u16 type) { - if (sk) { - struct sk_buff *skb; - struct sock *pending; - struct vsock_sock *vsk; + return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET); +} - vsk = vsock_sk(sk); - pending = NULL; /* Compiler warning. */ +static void __vsock_release(struct sock *sk, int level) +{ + struct vsock_sock *vsk; + struct sock *pending; - if (vsock_in_bound_table(vsk)) - vsock_remove_bound(vsk); + vsk = vsock_sk(sk); + pending = NULL; /* Compiler warning. */ - if (vsock_in_connected_table(vsk)) - vsock_remove_connected(vsk); + /* When "level" is SINGLE_DEPTH_NESTING, use the nested + * version to avoid the warning "possible recursive locking + * detected". When "level" is 0, lock_sock_nested(sk, level) + * is the same as lock_sock(sk). + */ + lock_sock_nested(sk, level); - transport->release(vsk); + /* Indicate to vsock_remove_sock() that the socket is being released and + * can be removed from the bound_table. Unlike transport reassignment + * case, where the socket must remain bound despite vsock_remove_sock() + * being called from the transport release() callback. + */ + sock_set_flag(sk, SOCK_DEAD); - lock_sock(sk); - sock_orphan(sk); - sk->sk_shutdown = SHUTDOWN_MASK; + if (vsk->transport) + vsk->transport->release(vsk); + else if (sock_type_connectible(sk->sk_type)) + vsock_remove_sock(vsk); - while ((skb = skb_dequeue(&sk->sk_receive_queue))) - kfree_skb(skb); + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; - /* Clean up any sockets that never were accepted. */ - while ((pending = vsock_dequeue_accept(sk)) != NULL) { - __vsock_release(pending); - sock_put(pending); - } + skb_queue_purge(&sk->sk_receive_queue); - release_sock(sk); - sock_put(sk); + /* Clean up any sockets that never were accepted. */ + while ((pending = vsock_dequeue_accept(sk)) != NULL) { + __vsock_release(pending, SINGLE_DEPTH_NESTING); + sock_put(pending); } + + release_sock(sk); + sock_put(sk); } static void vsock_sk_destruct(struct sock *sk) { struct vsock_sock *vsk = vsock_sk(sk); - transport->destruct(vsk); + /* Flush MSG_ZEROCOPY leftovers. */ + __skb_queue_purge(&sk->sk_error_queue); + + vsock_deassign_transport(vsk); /* When clearing these addresses, there's no need to set the family and * possibly register the address family with the kernel. 
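Because a child socket inherits buffer_size/buffer_min_size/buffer_max_size from its parent, tuning the listener once covers every accepted connection; a sketch using the AF_VSOCK-level option, whose value is a u64 (helper name hypothetical):

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>

    /* Hypothetical helper: tune the listener; accepted children inherit it. */
    static void set_listener_buffer(int listen_fd)
    {
            unsigned long long buf = 256 * 1024;  /* the default shown above */

            setsockopt(listen_fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
                       &buf, sizeof(buf));
    }
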
@@ -705,21 +915,71 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) return err; } +struct sock *vsock_create_connected(struct sock *parent) +{ + return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL, + parent->sk_type, 0); +} +EXPORT_SYMBOL_GPL(vsock_create_connected); + s64 vsock_stream_has_data(struct vsock_sock *vsk) { - return transport->stream_has_data(vsk); + if (WARN_ON(!vsk->transport)) + return 0; + + return vsk->transport->stream_has_data(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_data); +s64 vsock_connectible_has_data(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + + if (WARN_ON(!vsk->transport)) + return 0; + + if (sk->sk_type == SOCK_SEQPACKET) + return vsk->transport->seqpacket_has_data(vsk); + else + return vsock_stream_has_data(vsk); +} +EXPORT_SYMBOL_GPL(vsock_connectible_has_data); + s64 vsock_stream_has_space(struct vsock_sock *vsk) { - return transport->stream_has_space(vsk); + if (WARN_ON(!vsk->transport)) + return 0; + + return vsk->transport->stream_has_space(vsk); } EXPORT_SYMBOL_GPL(vsock_stream_has_space); +void vsock_data_ready(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= sk->sk_rcvlowat || + sock_flag(sk, SOCK_DONE)) + sk->sk_data_ready(sk); +} +EXPORT_SYMBOL_GPL(vsock_data_ready); + +/* Dummy callback required by sockmap. + * See unconditional call of saved_close() in sock_map_close(). + */ +static void vsock_close(struct sock *sk, long timeout) +{ +} + static int vsock_release(struct socket *sock) { - __vsock_release(sock->sk); + struct sock *sk = sock->sk; + + if (!sk) + return 0; + + sk->sk_prot->close(sk, 0); + __vsock_release(sk, 0); sock->sk = NULL; sock->state = SS_FREE; @@ -727,7 +987,7 @@ static int vsock_release(struct socket *sock) } static int -vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +vsock_bind(struct socket *sock, struct sockaddr_unsized *addr, int addr_len) { int err; struct sock *sk; @@ -746,7 +1006,7 @@ vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len) } static int vsock_getname(struct socket *sock, - struct sockaddr *addr, int *addr_len, int peer) + struct sockaddr *addr, int peer) { int err; struct sock *sk; @@ -769,25 +1029,48 @@ static int vsock_getname(struct socket *sock, vm_addr = &vsk->local_addr; } - if (!vm_addr) { - err = -EINVAL; - goto out; - } - - /* sys_getsockname() and sys_getpeername() pass us a - * MAX_SOCK_ADDR-sized buffer and don't set addr_len. Unfortunately - * that macro is defined in socket.c instead of .h, so we hardcode its - * value here. - */ - BUILD_BUG_ON(sizeof(*vm_addr) > 128); + BUILD_BUG_ON(sizeof(*vm_addr) > sizeof(struct sockaddr_storage)); memcpy(addr, vm_addr, sizeof(*vm_addr)); - *addr_len = sizeof(*vm_addr); + err = sizeof(*vm_addr); out: release_sock(sk); return err; } +void vsock_linger(struct sock *sk) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + ssize_t (*unsent)(struct vsock_sock *vsk); + struct vsock_sock *vsk = vsock_sk(sk); + long timeout; + + if (!sock_flag(sk, SOCK_LINGER)) + return; + + timeout = sk->sk_lingertime; + if (!timeout) + return; + + /* Transports must implement `unsent_bytes` if they want to support + * SOCK_LINGER through `vsock_linger()` since we use it to check when + * the socket can be closed. 
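From user space this linger behaviour is opted into with the ordinary SO_LINGER option, and it only has effect when the transport provides unsent_bytes(); a sketch (helper name hypothetical):

    #include <sys/socket.h>
    #include <unistd.h>

    /* Hypothetical helper: make close() wait for unsent data (up to 5s). */
    static void close_lingering(int fd)
    {
            struct linger lin = { .l_onoff = 1, .l_linger = 5 };

            setsockopt(fd, SOL_SOCKET, SO_LINGER, &lin, sizeof(lin));
            close(fd);  /* returns once unsent bytes hit 0, or on timeout/signal */
    }
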
+ */ + unsent = vsk->transport->unsent_bytes; + if (!unsent) + return; + + add_wait_queue(sk_sleep(sk), &wait); + + do { + if (sk_wait_event(sk, &timeout, unsent(vsk) == 0, &wait)) + break; + } while (!signal_pending(current) && timeout); + + remove_wait_queue(sk_sleep(sk), &wait); +} +EXPORT_SYMBOL_GPL(vsock_linger); + static int vsock_shutdown(struct socket *sock, int mode) { int err; @@ -804,17 +1087,19 @@ static int vsock_shutdown(struct socket *sock, int mode) if ((mode & ~SHUTDOWN_MASK) || !mode) return -EINVAL; - /* If this is a STREAM socket and it is not connected then bail out - * immediately. If it is a DGRAM socket then we must first kick the - * socket so that it wakes up from any sleeping calls, for example - * recv(), and then afterwards return the error. + /* If this is a connection oriented socket and it is not connected then + * bail out immediately. If it is a DGRAM socket then we must first + * kick the socket so that it wakes up from any sleeping calls, for + * example recv(), and then afterwards return the error. */ sk = sock->sk; + + lock_sock(sk); if (sock->state == SS_UNCONNECTED) { err = -ENOTCONN; - if (sk->sk_type == SOCK_STREAM) - return err; + if (sock_type_connectible(sk->sk_type)) + goto out; } else { sock->state = SS_DISCONNECTING; err = 0; @@ -823,25 +1108,25 @@ static int vsock_shutdown(struct socket *sock, int mode) /* Receive and send shutdowns are treated alike. */ mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN); if (mode) { - lock_sock(sk); sk->sk_shutdown |= mode; sk->sk_state_change(sk); - release_sock(sk); - if (sk->sk_type == SOCK_STREAM) { + if (sock_type_connectible(sk->sk_type)) { sock_reset_flag(sk, SOCK_DONE); vsock_send_shutdown(sk, mode); } } +out: + release_sock(sk); return err; } -static unsigned int vsock_poll(struct file *file, struct socket *sock, +static __poll_t vsock_poll(struct file *file, struct socket *sock, poll_table *wait) { struct sock *sk; - unsigned int mask; + __poll_t mask; struct vsock_sock *vsk; sk = sock->sk; @@ -850,58 +1135,66 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock, poll_wait(file, sk_sleep(sk), wait); mask = 0; - if (sk->sk_err) + if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) /* Signify that there has been an error on this socket. */ - mask |= POLLERR; + mask |= EPOLLERR; /* INET sockets treat local write shutdown and peer write shutdown as a - * case of POLLHUP set. + * case of EPOLLHUP set. */ if ((sk->sk_shutdown == SHUTDOWN_MASK) || ((sk->sk_shutdown & SEND_SHUTDOWN) && (vsk->peer_shutdown & SEND_SHUTDOWN))) { - mask |= POLLHUP; + mask |= EPOLLHUP; } if (sk->sk_shutdown & RCV_SHUTDOWN || vsk->peer_shutdown & SEND_SHUTDOWN) { - mask |= POLLRDHUP; + mask |= EPOLLRDHUP; } + if (sk_is_readable(sk)) + mask |= EPOLLIN | EPOLLRDNORM; + if (sock->type == SOCK_DGRAM) { /* For datagram sockets we can read if there is something in * the queue and write as long as the socket isn't shutdown for * sending. 
*/ - if (!skb_queue_empty(&sk->sk_receive_queue) || + if (!skb_queue_empty_lockless(&sk->sk_receive_queue) || (sk->sk_shutdown & RCV_SHUTDOWN)) { - mask |= POLLIN | POLLRDNORM; + mask |= EPOLLIN | EPOLLRDNORM; } if (!(sk->sk_shutdown & SEND_SHUTDOWN)) - mask |= POLLOUT | POLLWRNORM | POLLWRBAND; + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + } else if (sock_type_connectible(sk->sk_type)) { + const struct vsock_transport *transport; - } else if (sock->type == SOCK_STREAM) { lock_sock(sk); + transport = vsk->transport; + /* Listening sockets that have connections in their accept * queue can be read. */ - if (sk->sk_state == SS_LISTEN + if (sk->sk_state == TCP_LISTEN && !vsock_is_accept_queue_empty(sk)) - mask |= POLLIN | POLLRDNORM; + mask |= EPOLLIN | EPOLLRDNORM; /* If there is something in the queue then we can read. */ - if (transport->stream_is_active(vsk) && + if (transport && transport->stream_is_active(vsk) && !(sk->sk_shutdown & RCV_SHUTDOWN)) { bool data_ready_now = false; + int target = sock_rcvlowat(sk, 0, INT_MAX); int ret = transport->notify_poll_in( - vsk, 1, &data_ready_now); + vsk, target, &data_ready_now); if (ret < 0) { - mask |= POLLERR; + mask |= EPOLLERR; } else { if (data_ready_now) - mask |= POLLIN | POLLRDNORM; + mask |= EPOLLIN | EPOLLRDNORM; } } @@ -912,35 +1205,35 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock, */ if (sk->sk_shutdown & RCV_SHUTDOWN || vsk->peer_shutdown & SEND_SHUTDOWN) { - mask |= POLLIN | POLLRDNORM; + mask |= EPOLLIN | EPOLLRDNORM; } /* Connected sockets that can produce data can be written. */ - if (sk->sk_state == SS_CONNECTED) { + if (transport && sk->sk_state == TCP_ESTABLISHED) { if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { bool space_avail_now = false; int ret = transport->notify_poll_out( vsk, 1, &space_avail_now); if (ret < 0) { - mask |= POLLERR; + mask |= EPOLLERR; } else { if (space_avail_now) - /* Remove POLLWRBAND since INET + /* Remove EPOLLWRBAND since INET * sockets are not setting it. */ - mask |= POLLOUT | POLLWRNORM; + mask |= EPOLLOUT | EPOLLWRNORM; } } } /* Simulate INET socket poll behaviors, which sets - * POLLOUT|POLLWRNORM when peer is closed and nothing to read, + * EPOLLOUT|EPOLLWRNORM when peer is closed and nothing to read, * but local send is not shutdown. 
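Since notify_poll_in() is handed sock_rcvlowat() as its target, SO_RCVLOWAT now controls when EPOLLIN is reported; a sketch (helper name hypothetical):

    #include <sys/socket.h>

    /* Hypothetical helper: EPOLLIN then fires only once >= 4 KiB is queued. */
    static void set_poll_threshold(int fd)
    {
            int lowat = 4096;

            setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT, &lowat, sizeof(lowat));
    }
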
*/ - if (sk->sk_state == SS_UNCONNECTED) { + if (sk->sk_state == TCP_CLOSE || sk->sk_state == TCP_CLOSING) { if (!(sk->sk_shutdown & SEND_SHUTDOWN)) - mask |= POLLOUT | POLLWRNORM; + mask |= EPOLLOUT | EPOLLWRNORM; } @@ -950,13 +1243,24 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock, return mask; } -static int vsock_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock, - struct msghdr *msg, size_t len) +static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (WARN_ON_ONCE(!vsk->transport)) + return -ENODEV; + + return vsk->transport->read_skb(vsk, read_actor); +} + +static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) { int err; struct sock *sk; struct vsock_sock *vsk; struct sockaddr_vm *remote_addr; + const struct vsock_transport *transport; if (msg->msg_flags & MSG_OOB) return -EOPNOTSUPP; @@ -968,6 +1272,8 @@ static int vsock_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock, lock_sock(sk); + transport = vsk->transport; + err = vsock_auto_bind(vsk); if (err) goto out; @@ -1014,7 +1320,7 @@ static int vsock_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock, goto out; } - err = transport->dgram_enqueue(vsk, remote_addr, msg->msg_iov, len); + err = transport->dgram_enqueue(vsk, remote_addr, msg, len); out: release_sock(sk); @@ -1022,7 +1328,7 @@ out: } static int vsock_dgram_connect(struct socket *sock, - struct sockaddr *addr, int addr_len, int flags) + struct sockaddr_unsized *addr, int addr_len, int flags) { int err; struct sock *sk; @@ -1049,8 +1355,8 @@ static int vsock_dgram_connect(struct socket *sock, if (err) goto out; - if (!transport->dgram_allow(remote_addr->svm_cid, - remote_addr->svm_port)) { + if (!vsk->transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { err = -EINVAL; goto out; } @@ -1058,16 +1364,117 @@ static int vsock_dgram_connect(struct socket *sock, memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr)); sock->state = SS_CONNECTED; + /* sock map disallows redirection of non-TCP sockets with sk_state != + * TCP_ESTABLISHED (see sock_map_redirect_allowed()), so we set + * TCP_ESTABLISHED here to allow redirection of connected vsock dgrams. + * + * This doesn't seem to be abnormal state for datagram sockets, as the + * same approach can be seen in other datagram socket types as well + * (such as unix sockets).
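For reference, the connected-datagram case just described looks like this from user space, assuming a transport that supports datagrams (e.g. VMCI); helper name and port are made up:

    #include <sys/socket.h>
    #include <linux/vm_sockets.h>

    /* Hypothetical helper: a "connected" vsock datagram socket. */
    static int dgram_ping_host(void)
    {
            struct sockaddr_vm host = {
                    .svm_family = AF_VSOCK,
                    .svm_cid    = VMADDR_CID_HOST,
                    .svm_port   = 9999,  /* arbitrary example port */
            };
            int fd = socket(AF_VSOCK, SOCK_DGRAM, 0);

            /* after this the socket is internally marked established */
            connect(fd, (struct sockaddr *)&host, sizeof(host));
            return send(fd, "ping", 4, 0);
    }
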
+ */ + sk->sk_state = TCP_ESTABLISHED; + out: release_sock(sk); return err; } -static int vsock_dgram_recvmsg(struct kiocb *kiocb, struct socket *sock, - struct msghdr *msg, size_t len, int flags) +int __vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct vsock_sock *vsk = vsock_sk(sk); + + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); +} + +int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ +#ifdef CONFIG_BPF_SYSCALL + struct sock *sk = sock->sk; + const struct proto *prot; + + prot = READ_ONCE(sk->sk_prot); + if (prot != &vsock_proto) + return prot->recvmsg(sk, msg, len, flags, NULL); +#endif + + return __vsock_dgram_recvmsg(sock, msg, len, flags); +} +EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg); + +static int vsock_do_ioctl(struct socket *sock, unsigned int cmd, + int __user *arg) +{ + struct sock *sk = sock->sk; + struct vsock_sock *vsk; + int ret; + + vsk = vsock_sk(sk); + + switch (cmd) { + case SIOCINQ: { + ssize_t n_bytes; + + if (!vsk->transport) { + ret = -EOPNOTSUPP; + break; + } + + if (sock_type_connectible(sk->sk_type) && + sk->sk_state == TCP_LISTEN) { + ret = -EINVAL; + break; + } + + n_bytes = vsock_stream_has_data(vsk); + if (n_bytes < 0) { + ret = n_bytes; + break; + } + ret = put_user(n_bytes, arg); + break; + } + case SIOCOUTQ: { + ssize_t n_bytes; + + if (!vsk->transport || !vsk->transport->unsent_bytes) { + ret = -EOPNOTSUPP; + break; + } + + if (sock_type_connectible(sk->sk_type) && sk->sk_state == TCP_LISTEN) { + ret = -EINVAL; + break; + } + + n_bytes = vsk->transport->unsent_bytes(vsk); + if (n_bytes < 0) { + ret = n_bytes; + break; + } + + ret = put_user(n_bytes, arg); + break; + } + default: + ret = -ENOIOCTLCMD; + } + + return ret; +} + +static int vsock_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) { - return transport->dgram_dequeue(kiocb, vsock_sk(sock->sk), msg, len, - flags); + int ret; + + lock_sock(sock->sk); + ret = vsock_do_ioctl(sock, cmd, (int __user *)arg); + release_sock(sock->sk); + + return ret; } static const struct proto_ops vsock_dgram_ops = { @@ -1080,43 +1487,54 @@ static const struct proto_ops vsock_dgram_ops = { .accept = sock_no_accept, .getname = vsock_getname, .poll = vsock_poll, - .ioctl = sock_no_ioctl, + .ioctl = vsock_ioctl, .listen = sock_no_listen, .shutdown = vsock_shutdown, - .setsockopt = sock_no_setsockopt, - .getsockopt = sock_no_getsockopt, .sendmsg = vsock_dgram_sendmsg, .recvmsg = vsock_dgram_recvmsg, .mmap = sock_no_mmap, - .sendpage = sock_no_sendpage, + .read_skb = vsock_read_skb, }; +static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) +{ + const struct vsock_transport *transport = vsk->transport; + + if (!transport || !transport->cancel_pkt) + return -EOPNOTSUPP; + + return transport->cancel_pkt(vsk); +} + static void vsock_connect_timeout(struct work_struct *work) { struct sock *sk; struct vsock_sock *vsk; - vsk = container_of(work, struct vsock_sock, dwork.work); + vsk = container_of(work, struct vsock_sock, connect_work.work); sk = sk_vsock(vsk); lock_sock(sk); - if (sk->sk_state == SS_CONNECTING && + if (sk->sk_state == TCP_SYN_SENT && (sk->sk_shutdown != SHUTDOWN_MASK)) { - sk->sk_state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; + sk->sk_socket->state = SS_UNCONNECTED; sk->sk_err = ETIMEDOUT; - sk->sk_error_report(sk); + sk_error_report(sk); + vsock_transport_cancel_pkt(vsk); } release_sock(sk); sock_put(sk); } -static int vsock_stream_connect(struct 
socket *sock, struct sockaddr *addr, - int addr_len, int flags) +static int vsock_connect(struct socket *sock, struct sockaddr_unsized *addr, + int addr_len, int flags) { int err; struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; struct sockaddr_vm *remote_addr; long timeout; DEFINE_WAIT(wait); @@ -1143,37 +1561,62 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, * non-blocking call. */ err = -EALREADY; + if (flags & O_NONBLOCK) + goto out; break; default: - if ((sk->sk_state == SS_LISTEN) || + if ((sk->sk_state == TCP_LISTEN) || vsock_addr_cast(addr, addr_len, &remote_addr) != 0) { err = -EINVAL; goto out; } + /* Set the remote address that we are connecting to. */ + memcpy(&vsk->remote_addr, remote_addr, + sizeof(vsk->remote_addr)); + + err = vsock_assign_transport(vsk, NULL); + if (err) + goto out; + + transport = vsk->transport; + /* The hypervisor and well-known contexts do not have socket * endpoints. */ - if (!transport->stream_allow(remote_addr->svm_cid, + if (!transport || + !transport->stream_allow(remote_addr->svm_cid, remote_addr->svm_port)) { err = -ENETUNREACH; goto out; } - /* Set the remote address that we are connecting to. */ - memcpy(&vsk->remote_addr, remote_addr, - sizeof(vsk->remote_addr)); + if (vsock_msgzerocopy_allow(transport)) { + set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + } else if (sock_flag(sk, SOCK_ZEROCOPY)) { + /* If this option was set before 'connect()', + * when transport was unknown, check that this + * feature is supported here. + */ + err = -EOPNOTSUPP; + goto out; + } err = vsock_auto_bind(vsk); if (err) goto out; - sk->sk_state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; err = transport->connect(vsk); if (err < 0) goto out; + /* sk_err might have been set as a result of an earlier + * (failed) connect attempt. + */ + sk->sk_err = 0; + /* Mark sock as connecting and set the error code to in * progress in case this is a non-blocking connect. */ @@ -1188,7 +1631,11 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, timeout = vsk->connect_timeout; prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); - while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) { + /* If the socket is already closing or it is in an error state, there + * is no point in waiting. + */ + while (sk->sk_state != TCP_ESTABLISHED && + sk->sk_state != TCP_CLOSING && sk->sk_err == 0) { if (flags & O_NONBLOCK) { /* If we're not going to block, we schedule a timeout * function to generate a timeout on the connection @@ -1197,9 +1644,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, * timeout fires. */ sock_hold(sk); - INIT_DELAYED_WORK(&vsk->dwork, - vsock_connect_timeout); - schedule_delayed_work(&vsk->dwork, timeout); + + /* If the timeout function is already scheduled, + * reschedule it, then ungrab the socket refcount to + * keep it balanced. + */ + if (mod_delayed_work(system_percpu_wq, &vsk->connect_work, + timeout)) + sock_put(sk); /* Skip ahead to preserve error code set above. */ goto out_wait; @@ -1209,12 +1661,41 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, timeout = schedule_timeout(timeout); lock_sock(sk); - if (signal_pending(current)) { - err = sock_intr_errno(timeout); - goto out_wait_error; - } else if (timeout == 0) { - err = -ETIMEDOUT; - goto out_wait_error; + /* Connection established. Whatever happens to socket once we + * release it, that's not connect()'s concern. 
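The non-blocking branch above returns EINPROGRESS and arms connect_work, so user space completes the handshake with the classic poll-for-writable pattern; a sketch (helper name hypothetical; 2000 ms mirrors the 2 * HZ default connect timeout):

    #include <errno.h>
    #include <fcntl.h>
    #include <poll.h>
    #include <sys/socket.h>

    /* Hypothetical helper: finish a non-blocking vsock connect(). */
    static int connect_nonblock(int fd, const struct sockaddr *peer, socklen_t len)
    {
            int err = 0;
            socklen_t elen = sizeof(err);
            struct pollfd pfd = { .fd = fd, .events = POLLOUT };

            fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
            if (connect(fd, peer, len) == 0)
                    return 0;
            if (errno != EINPROGRESS)
                    return -1;

            poll(&pfd, 1, 2000);
            getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &elen);
            return err ? -1 : 0;  /* err == ETIMEDOUT if connect_work fired */
    }
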
No need to go + * into signal and timeout handling. Call it a day. + * + * Note that allowing to "reset" an already established socket + * here is racy and insecure. + */ + if (sk->sk_state == TCP_ESTABLISHED) + break; + + /* If connection was _not_ established and a signal/timeout came + * to be, we want the socket's state reset. User space may want + * to retry. + * + * sk_state != TCP_ESTABLISHED implies that socket is not on + * vsock_connected_table. We keep the binding and the transport + * assigned. + */ + if (signal_pending(current) || timeout == 0) { + err = timeout == 0 ? -ETIMEDOUT : sock_intr_errno(timeout); + + /* Listener might have already responded with + * VIRTIO_VSOCK_OP_RESPONSE. Its handling expects our + * sk_state == TCP_SYN_SENT, which hereby we break. + * In such case VIRTIO_VSOCK_OP_RST will follow. + */ + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + + /* Try to cancel VIRTIO_VSOCK_OP_REQUEST skb sent out by + * transport->connect(). + */ + vsock_transport_cancel_pkt(vsk); + + goto out_wait; } prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); @@ -1222,23 +1703,21 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, if (sk->sk_err) { err = -sk->sk_err; - goto out_wait_error; - } else + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + } else { err = 0; + } out_wait: finish_wait(sk_sleep(sk), &wait); out: release_sock(sk); return err; - -out_wait_error: - sk->sk_state = SS_UNCONNECTED; - sock->state = SS_UNCONNECTED; - goto out_wait; } -static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) +static int vsock_accept(struct socket *sock, struct socket *newsock, + struct proto_accept_arg *arg) { struct sock *listener; int err; @@ -1252,12 +1731,12 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) lock_sock(listener); - if (sock->type != SOCK_STREAM) { + if (!sock_type_connectible(sock->type)) { err = -EOPNOTSUPP; goto out; } - if (listener->sk_state != SS_LISTEN) { + if (listener->sk_state != TCP_LISTEN) { err = -EINVAL; goto out; } @@ -1265,33 +1744,35 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) /* Wait for children sockets to appear; these are the new sockets * created upon connection establishment. 
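Stepping back to the SIOCINQ/SIOCOUTQ handlers added further up: they expose the receive and unsent byte counts through the usual ioctls; a sketch (helper name hypothetical):

    #include <sys/ioctl.h>
    #include <linux/sockios.h>

    /* Hypothetical helper: inspect queue depths on a connected socket. */
    static void query_queues(int fd)
    {
            int inq = 0, outq = 0;

            ioctl(fd, SIOCINQ, &inq);    /* bytes ready to read */
            ioctl(fd, SIOCOUTQ, &outq);  /* bytes the transport has not sent yet */
    }
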
*/ - timeout = sock_sndtimeo(listener, flags & O_NONBLOCK); + timeout = sock_rcvtimeo(listener, arg->flags & O_NONBLOCK); prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); while ((connected = vsock_dequeue_accept(listener)) == NULL && listener->sk_err == 0) { release_sock(listener); timeout = schedule_timeout(timeout); + finish_wait(sk_sleep(listener), &wait); lock_sock(listener); if (signal_pending(current)) { err = sock_intr_errno(timeout); - goto out_wait; + goto out; } else if (timeout == 0) { err = -EAGAIN; - goto out_wait; + goto out; } prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); } + finish_wait(sk_sleep(listener), &wait); if (listener->sk_err) err = -listener->sk_err; if (connected) { - listener->sk_ack_backlog--; + sk_acceptq_removed(listener); - lock_sock(connected); + lock_sock_nested(connected, SINGLE_DEPTH_NESTING); vconnected = vsock_sk(connected); /* If the listener socket has received an error, then we should @@ -1303,19 +1784,18 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) */ if (err) { vconnected->rejected = true; - release_sock(connected); - sock_put(connected); - goto out_wait; + } else { + newsock->state = SS_CONNECTED; + sock_graft(connected, newsock); + if (vsock_msgzerocopy_allow(vconnected->transport)) + set_bit(SOCK_SUPPORT_ZC, + &connected->sk_socket->flags); } - newsock->state = SS_CONNECTED; - sock_graft(connected, newsock); release_sock(connected); sock_put(connected); } -out_wait: - finish_wait(sk_sleep(listener), &wait); out: release_sock(listener); return err; @@ -1331,7 +1811,7 @@ static int vsock_listen(struct socket *sock, int backlog) lock_sock(sk); - if (sock->type != SOCK_STREAM) { + if (!sock_type_connectible(sk->sk_type)) { err = -EOPNOTSUPP; goto out; } @@ -1349,7 +1829,7 @@ static int vsock_listen(struct socket *sock, int backlog) } sk->sk_max_ack_backlog = backlog; - sk->sk_state = SS_LISTEN; + sk->sk_state = TCP_LISTEN; err = 0; @@ -1358,18 +1838,36 @@ out: return err; } -static int vsock_stream_setsockopt(struct socket *sock, - int level, - int optname, - char __user *optval, - unsigned int optlen) +static void vsock_update_buffer_size(struct vsock_sock *vsk, + const struct vsock_transport *transport, + u64 val) +{ + if (val > vsk->buffer_max_size) + val = vsk->buffer_max_size; + + if (val < vsk->buffer_min_size) + val = vsk->buffer_min_size; + + if (val != vsk->buffer_size && + transport && transport->notify_buffer_size) + transport->notify_buffer_size(vsk, &val); + + vsk->buffer_size = val; +} + +static int vsock_connectible_setsockopt(struct socket *sock, + int level, + int optname, + sockptr_t optval, + unsigned int optlen) { int err; struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; u64 val; - if (level != AF_VSOCK) + if (level != AF_VSOCK && level != SOL_SOCKET) return -ENOPROTOOPT; #define COPY_IN(_v) \ @@ -1378,7 +1876,7 @@ static int vsock_stream_setsockopt(struct socket *sock, err = -EINVAL; \ goto exit; \ } \ - if (copy_from_user(&_v, optval, sizeof(_v)) != 0) { \ + if (copy_from_sockptr(&_v, optval, sizeof(_v)) != 0) { \ err = -EFAULT; \ goto exit; \ } \ @@ -1390,29 +1888,65 @@ static int vsock_stream_setsockopt(struct socket *sock, lock_sock(sk); + transport = vsk->transport; + + if (level == SOL_SOCKET) { + int zerocopy; + + if (optname != SO_ZEROCOPY) { + release_sock(sk); + return sock_setsockopt(sock, level, optname, optval, optlen); + } + + /* Use 'int' type here, because variable to + * set this option usually has this 
type. + */ + COPY_IN(zerocopy); + + if (zerocopy < 0 || zerocopy > 1) { + err = -EINVAL; + goto exit; + } + + if (transport && !vsock_msgzerocopy_allow(transport)) { + err = -EOPNOTSUPP; + goto exit; + } + + sock_valbool_flag(sk, SOCK_ZEROCOPY, zerocopy); + goto exit; + } + switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: COPY_IN(val); - transport->set_buffer_size(vsk, val); + vsock_update_buffer_size(vsk, transport, val); break; case SO_VM_SOCKETS_BUFFER_MAX_SIZE: COPY_IN(val); - transport->set_max_buffer_size(vsk, val); + vsk->buffer_max_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); break; case SO_VM_SOCKETS_BUFFER_MIN_SIZE: COPY_IN(val); - transport->set_min_buffer_size(vsk, val); + vsk->buffer_min_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); break; - case SO_VM_SOCKETS_CONNECT_TIMEOUT: { - struct timeval tv; - COPY_IN(tv); + case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW: + case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: { + struct __kernel_sock_timeval tv; + + err = sock_copy_user_timeval(&tv, optval, optlen, + optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD); + if (err) + break; if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC && tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) { vsk->connect_timeout = tv.tv_sec * HZ + - DIV_ROUND_UP(tv.tv_usec, (1000000 / HZ)); + DIV_ROUND_UP((unsigned long)tv.tv_usec, (USEC_PER_SEC / HZ)); if (vsk->connect_timeout == 0) vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; @@ -1435,88 +1969,79 @@ exit: return err; } -static int vsock_stream_getsockopt(struct socket *sock, - int level, int optname, - char __user *optval, - int __user *optlen) +static int vsock_connectible_getsockopt(struct socket *sock, + int level, int optname, + char __user *optval, + int __user *optlen) { - int err; + struct sock *sk = sock->sk; + struct vsock_sock *vsk = vsock_sk(sk); + + union { + u64 val64; + struct old_timeval32 tm32; + struct __kernel_old_timeval tm; + struct __kernel_sock_timeval stm; + } v; + + int lv = sizeof(v.val64); int len; - struct sock *sk; - struct vsock_sock *vsk; - u64 val; if (level != AF_VSOCK) return -ENOPROTOOPT; - err = get_user(len, optlen); - if (err != 0) - return err; + if (get_user(len, optlen)) + return -EFAULT; -#define COPY_OUT(_v) \ - do { \ - if (len < sizeof(_v)) \ - return -EINVAL; \ - \ - len = sizeof(_v); \ - if (copy_to_user(optval, &_v, len) != 0) \ - return -EFAULT; \ - \ - } while (0) - - err = 0; - sk = sock->sk; - vsk = vsock_sk(sk); + memset(&v, 0, sizeof(v)); switch (optname) { case SO_VM_SOCKETS_BUFFER_SIZE: - val = transport->get_buffer_size(vsk); - COPY_OUT(val); + v.val64 = vsk->buffer_size; break; case SO_VM_SOCKETS_BUFFER_MAX_SIZE: - val = transport->get_max_buffer_size(vsk); - COPY_OUT(val); + v.val64 = vsk->buffer_max_size; break; case SO_VM_SOCKETS_BUFFER_MIN_SIZE: - val = transport->get_min_buffer_size(vsk); - COPY_OUT(val); + v.val64 = vsk->buffer_min_size; break; - case SO_VM_SOCKETS_CONNECT_TIMEOUT: { - struct timeval tv; - tv.tv_sec = vsk->connect_timeout / HZ; - tv.tv_usec = - (vsk->connect_timeout - - tv.tv_sec * HZ) * (1000000 / HZ); - COPY_OUT(tv); + case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW: + case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: + lv = sock_get_timeout(vsk->connect_timeout, &v, + optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD); break; - } + default: return -ENOPROTOOPT; } - err = put_user(len, optlen); - if (err != 0) + if (len < lv) + return -EINVAL; + if (len > lv) + len = lv; + if (copy_to_user(optval, &v, len)) return -EFAULT; -#undef COPY_OUT + if (put_user(len, 
optlen)) + return -EFAULT; return 0; } -static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, - struct msghdr *msg, size_t len) +static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) { struct sock *sk; struct vsock_sock *vsk; + const struct vsock_transport *transport; ssize_t total_written; long timeout; int err; struct vsock_transport_send_notify_data send_data; - - DEFINE_WAIT(wait); + DEFINE_WAIT_FUNC(wait, woken_wake_function); sk = sock->sk; vsk = vsock_sk(sk); @@ -1528,9 +2053,13 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, lock_sock(sk); - /* Callers should not provide a destination with stream sockets. */ + transport = vsk->transport; + + /* Callers should not provide a destination with connection oriented + * sockets. + */ if (msg->msg_namelen) { - err = sk->sk_state == SS_CONNECTED ? -EISCONN : -EOPNOTSUPP; + err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP; goto out; } @@ -1541,7 +2070,7 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, goto out; } - if (sk->sk_state != SS_CONNECTED || + if (!transport || sk->sk_state != TCP_ESTABLISHED || !vsock_addr_bound(&vsk->local_addr)) { err = -ENOTCONN; goto out; @@ -1552,6 +2081,12 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, goto out; } + if (msg->msg_flags & MSG_ZEROCOPY && + !vsock_msgzerocopy_allow(transport)) { + err = -EOPNOTSUPP; + goto out; + } + /* Wait for room in the produce queue to enqueue our user's data. */ timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); @@ -1559,11 +2094,10 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, if (err < 0) goto out; - prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); - while (total_written < len) { ssize_t written; + add_wait_queue(sk_sleep(sk), &wait); while (vsock_stream_has_space(vsk) == 0 && sk->sk_err == 0 && !(sk->sk_shutdown & SEND_SHUTDOWN) && @@ -1572,27 +2106,30 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, /* Don't wait for non-blocking sockets. 
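The MSG_ZEROCOPY check here pairs with the SOL_SOCKET/SO_ZEROCOPY handling in setsockopt above; end to end, user space opts in roughly as follows (helper name hypothetical):

    #include <sys/socket.h>

    /* Hypothetical helper: opt a connected vsock socket into zerocopy sends. */
    static ssize_t send_zerocopy(int fd, const void *buf, size_t len)
    {
            int one = 1;

            setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one));
            /* completions are reaped with recvmsg(fd, ..., MSG_ERRQUEUE); this
             * is why vsock_poll() above also watches sk_error_queue */
            return send(fd, buf, len, MSG_ZEROCOPY);
    }
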
*/ if (timeout == 0) { err = -EAGAIN; - goto out_wait; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; } err = transport->notify_send_pre_block(vsk, &send_data); - if (err < 0) - goto out_wait; + if (err < 0) { + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } release_sock(sk); - timeout = schedule_timeout(timeout); + timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout); lock_sock(sk); if (signal_pending(current)) { err = sock_intr_errno(timeout); - goto out_wait; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; } else if (timeout == 0) { err = -EAGAIN; - goto out_wait; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; } - - prepare_to_wait(sk_sleep(sk), &wait, - TASK_INTERRUPTIBLE); } + remove_wait_queue(sk_sleep(sk), &wait); /* These checks occur both as part of and after the loop * conditional since we need to check before and after @@ -1600,16 +2137,16 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, */ if (sk->sk_err) { err = -sk->sk_err; - goto out_wait; + goto out_err; } else if ((sk->sk_shutdown & SEND_SHUTDOWN) || (vsk->peer_shutdown & RCV_SHUTDOWN)) { err = -EPIPE; - goto out_wait; + goto out_err; } err = transport->notify_send_pre_enqueue(vsk, &send_data); if (err < 0) - goto out_wait; + goto out_err; /* Note that enqueue will only write as many bytes as are free * in the produce queue, so we don't need to ensure len is @@ -1617,12 +2154,17 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, * responsibility to check how many bytes we were able to send. */ - written = transport->stream_enqueue( - vsk, msg->msg_iov, - len - total_written); + if (sk->sk_type == SOCK_SEQPACKET) { + written = transport->seqpacket_enqueue(vsk, + msg, len - total_written); + } else { + written = transport->stream_enqueue(vsk, + msg, len - total_written); + } + if (written < 0) { - err = -ENOMEM; - goto out_wait; + err = written; + goto out_err; } total_written += written; @@ -1630,78 +2172,109 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, err = transport->notify_send_post_enqueue( vsk, written, &send_data); if (err < 0) - goto out_wait; + goto out_err; } -out_wait: - if (total_written > 0) - err = total_written; - finish_wait(sk_sleep(sk), &wait); +out_err: + if (total_written > 0) { + /* Return number of written bytes only if: + * 1) SOCK_STREAM socket. + * 2) SOCK_SEQPACKET socket when whole buffer is sent. 
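The all-or-nothing return rule for SOCK_SEQPACKET noted above gives user space record semantics; a sketch on a connected seqpacket socket (helper name hypothetical):

    #include <sys/socket.h>

    /* Hypothetical helper: record semantics on a connected SOCK_SEQPACKET fd. */
    static void echo_record(int fd)
    {
            char rec[64] = "hello";
            char rbuf[128];

            /* returns 64 only if the whole record was queued */
            send(fd, rec, sizeof(rec), 0);
            /* one recv() consumes exactly one record; boundaries are preserved */
            recv(fd, rbuf, sizeof(rbuf), 0);
    }
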
 
-static int
-vsock_stream_recvmsg(struct kiocb *kiocb,
-		     struct socket *sock,
-		     struct msghdr *msg, size_t len, int flags)
+static int vsock_connectible_wait_data(struct sock *sk,
+				       struct wait_queue_entry *wait,
+				       long timeout,
+				       struct vsock_transport_recv_notify_data *recv_data,
+				       size_t target)
 {
-	struct sock *sk;
+	const struct vsock_transport *transport;
 	struct vsock_sock *vsk;
+	s64 data;
 	int err;
-	size_t target;
-	ssize_t copied;
-	long timeout;
-	struct vsock_transport_recv_notify_data recv_data;
-
-	DEFINE_WAIT(wait);
 
-	sk = sock->sk;
 	vsk = vsock_sk(sk);
 	err = 0;
+	transport = vsk->transport;
 
-	msg->msg_namelen = 0;
+	while (1) {
+		prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
+		data = vsock_connectible_has_data(vsk);
+		if (data != 0)
+			break;
+
+		if (sk->sk_err != 0 ||
+		    (sk->sk_shutdown & RCV_SHUTDOWN) ||
+		    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+			break;
+		}
 
-	lock_sock(sk);
+		/* Don't wait for non-blocking sockets. */
+		if (timeout == 0) {
+			err = -EAGAIN;
+			break;
+		}
 
-	if (sk->sk_state != SS_CONNECTED) {
-		/* Recvmsg is supposed to return 0 if a peer performs an
-		 * orderly shutdown. Differentiate between that case and when a
-		 * peer has not connected or a local shutdown occured with the
-		 * SOCK_DONE flag.
-		 */
-		if (sock_flag(sk, SOCK_DONE))
-			err = 0;
-		else
-			err = -ENOTCONN;
+		if (recv_data) {
+			err = transport->notify_recv_pre_block(vsk, target, recv_data);
+			if (err < 0)
+				break;
+		}
 
-		goto out;
-	}
+		release_sock(sk);
+		timeout = schedule_timeout(timeout);
+		lock_sock(sk);
 
-	if (flags & MSG_OOB) {
-		err = -EOPNOTSUPP;
-		goto out;
+		if (signal_pending(current)) {
+			err = sock_intr_errno(timeout);
+			break;
+		} else if (timeout == 0) {
+			err = -EAGAIN;
+			break;
+		}
 	}
 
-	/* We don't check peer_shutdown flag here since peer may actually shut
-	 * down, but there can be data in the queue that a local socket can
-	 * receive.
-	 */
-	if (sk->sk_shutdown & RCV_SHUTDOWN) {
-		err = 0;
-		goto out;
-	}
+	finish_wait(sk_sleep(sk), wait);
 
-	/* It is valid on Linux to pass in a zero-length receive buffer. This
-	 * is not an error. We may as well bail out now.
+	if (err)
+		return err;
+
+	/* Internal transport error when checking for available
+	 * data. XXX This should be changed to a connection
+	 * reset in a later change.
 	 */
-	if (!len) {
-		err = 0;
-		goto out;
-	}
+	if (data < 0)
+		return -ENOMEM;
+
+	return data;
+}
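vsock_connectible_wait_data() centralizes the sleep-until-readable logic shared by the two receive paths that follow. On the send side, note the MSG_ZEROCOPY gate earlier in this diff: when the bound transport cannot do zerocopy the call fails with EOPNOTSUPP instead of silently copying, and completion notifications travel over the error queue (see the MSG_ERRQUEUE branch further down, which goes through sock_recv_errqueue() with SOL_VSOCK/VSOCK_RECVERR). A hedged userspace sketch of opportunistic zerocopy; the SO_ZEROCOPY opt-in mirrors what other address families require and is an assumption here, not something this patch shows:

    /* Sketch: try MSG_ZEROCOPY, fall back to a plain copying send if
     * the transport rejects it with EOPNOTSUPP.
     */
    #include <errno.h>
    #include <sys/socket.h>
    #include <sys/types.h>

    #ifndef SO_ZEROCOPY
    #define SO_ZEROCOPY 60
    #endif
    #ifndef MSG_ZEROCOPY
    #define MSG_ZEROCOPY 0x4000000
    #endif

    static ssize_t send_maybe_zerocopy(int fd, const void *buf, size_t len)
    {
            int one = 1;
            ssize_t n;

            /* Opt in once per socket; harmless if already set. */
            if (setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one)) < 0)
                    return send(fd, buf, len, 0);

            n = send(fd, buf, len, MSG_ZEROCOPY);
            if (n < 0 && errno == EOPNOTSUPP)
                    n = send(fd, buf, len, 0);  /* transport refused zerocopy */

            return n;
    }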
-			 */
+		err = vsock_connectible_wait_data(sk, &wait, timeout,
+						  &recv_data, target);
+		if (err <= 0)
+			break;
 
-			err = -ENOMEM;
-			goto out_wait;
-		} else if (ready > 0) {
-			ssize_t read;
+		err = transport->notify_recv_pre_dequeue(vsk, target,
+							 &recv_data);
+		if (err < 0)
+			break;
 
-			err = transport->notify_recv_pre_dequeue(
-					vsk, target, &recv_data);
-			if (err < 0)
-				break;
+		read = transport->stream_dequeue(vsk, msg, len - copied, flags);
+		if (read < 0) {
+			err = read;
+			break;
+		}
 
-			read = transport->stream_dequeue(
-					vsk, msg->msg_iov,
-					len - copied, flags);
-			if (read < 0) {
-				err = -ENOMEM;
-				break;
-			}
+		copied += read;
 
-			copied += read;
+		err = transport->notify_recv_post_dequeue(vsk, target, read,
+							  !(flags & MSG_PEEK), &recv_data);
+		if (err < 0)
+			goto out;
 
-			err = transport->notify_recv_post_dequeue(
-					vsk, target, read,
-					!(flags & MSG_PEEK), &recv_data);
-			if (err < 0)
-				goto out_wait;
+		if (read >= target || flags & MSG_PEEK)
+			break;
 
-			if (read >= target || flags & MSG_PEEK)
-				break;
+		target -= read;
+	}
 
-			target -= read;
-		} else {
-			if (sk->sk_err != 0 || (sk->sk_shutdown & RCV_SHUTDOWN)
-			    || (vsk->peer_shutdown & SEND_SHUTDOWN)) {
-				break;
-			}
-			/* Don't wait for non-blocking sockets. */
-			if (timeout == 0) {
-				err = -EAGAIN;
-				break;
-			}
+	if (sk->sk_err)
+		err = -sk->sk_err;
+	else if (sk->sk_shutdown & RCV_SHUTDOWN)
+		err = 0;
 
-			err = transport->notify_recv_pre_block(
-					vsk, target, &recv_data);
-			if (err < 0)
-				break;
+	if (copied > 0)
+		err = copied;
 
-			release_sock(sk);
-			timeout = schedule_timeout(timeout);
-			lock_sock(sk);
+out:
+	return err;
+}
 
-			if (signal_pending(current)) {
-				err = sock_intr_errno(timeout);
-				break;
-			} else if (timeout == 0) {
-				err = -EAGAIN;
-				break;
-			}
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	const struct vsock_transport *transport;
+	struct vsock_sock *vsk;
+	ssize_t msg_len;
+	long timeout;
+	int err = 0;
+	DEFINE_WAIT(wait);
 
-			prepare_to_wait(sk_sleep(sk), &wait,
-					TASK_INTERRUPTIBLE);
-		}
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
+
+	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
+	err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
+	if (err <= 0)
+		goto out;
+
+	msg_len = transport->seqpacket_dequeue(vsk, msg, flags);
+
+	if (msg_len < 0) {
+		err = msg_len;
+		goto out;
 	}
 
-	if (sk->sk_err)
+	if (sk->sk_err) {
 		err = -sk->sk_err;
-	else if (sk->sk_shutdown & RCV_SHUTDOWN)
+	} else if (sk->sk_shutdown & RCV_SHUTDOWN) {
 		err = 0;
+	} else {
+		/* User sets MSG_TRUNC, so return real length of
+		 * packet.
+		 */
+		if (flags & MSG_TRUNC)
+			err = msg_len;
+		else
+			err = len - msg_data_left(msg);
 
-	if (copied > 0) {
-		/* We only do these additional bookkeeping/notification steps
-		 * if we actually copied something out of the queue pair
-		 * instead of just peeking ahead.
+		/* Always set MSG_TRUNC if real length of packet is
+		 * bigger than user's buffer.
 		 */
+		if (msg_len > len)
+			msg->msg_flags |= MSG_TRUNC;
+	}
 
-		if (!(flags & MSG_PEEK)) {
-			/* If the other side has shutdown for sending and there
-			 * is nothing more to read, then modify the socket
-			 * state.
-			 */
-			if (vsk->peer_shutdown & SEND_SHUTDOWN) {
-				if (vsock_stream_has_data(vsk) <= 0) {
-					sk->sk_state = SS_UNCONNECTED;
-					sock_set_flag(sk, SOCK_DONE);
-					sk->sk_state_change(sk);
-				}
-			}
-		}
-		err = copied;
+out:
+	return err;
+}
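__vsock_seqpacket_recvmsg() above gives SOCK_SEQPACKET the datagram-style truncation contract: passing MSG_TRUNC in flags returns the real record length, and msg_flags has MSG_TRUNC set whenever the record did not fit the buffer. A small userspace sketch of detecting a truncated record; recv_record() is illustrative only:

    /* Sketch: receive one seqpacket record and report truncation. */
    #include <stdio.h>
    #include <string.h>
    #include <sys/socket.h>
    #include <sys/uio.h>

    static ssize_t recv_record(int fd, void *buf, size_t len)
    {
            struct iovec iov = { .iov_base = buf, .iov_len = len };
            struct msghdr msg;
            ssize_t n;

            memset(&msg, 0, sizeof(msg));
            msg.msg_iov = &iov;
            msg.msg_iovlen = 1;

            n = recvmsg(fd, &msg, 0);
            if (n >= 0 && (msg.msg_flags & MSG_TRUNC))
                    fprintf(stderr, "record truncated to %zd bytes\n", n);

            return n;
    }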
+
+int
+__vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
+			    int flags)
+{
+	struct sock *sk;
+	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
+	int err;
+
+	sk = sock->sk;
+
+	if (unlikely(flags & MSG_ERRQUEUE))
+		return sock_recv_errqueue(sk, msg, len, SOL_VSOCK, VSOCK_RECVERR);
+
+	vsk = vsock_sk(sk);
+	err = 0;
+
+	lock_sock(sk);
+
+	transport = vsk->transport;
+
+	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
+		/* Recvmsg is supposed to return 0 if a peer performs an
+		 * orderly shutdown. Differentiate between that case and when a
+		 * peer has not connected or a local shutdown occurred with the
+		 * SOCK_DONE flag.
+		 */
+		if (sock_flag(sk, SOCK_DONE))
+			err = 0;
+		else
+			err = -ENOTCONN;
+
+		goto out;
 	}
 
-out_wait:
-	finish_wait(sk_sleep(sk), &wait);
+	if (flags & MSG_OOB) {
+		err = -EOPNOTSUPP;
+		goto out;
+	}
+
+	/* We don't check peer_shutdown flag here since peer may actually shut
+	 * down, but there can be data in the queue that a local socket can
+	 * receive.
+	 */
+	if (sk->sk_shutdown & RCV_SHUTDOWN) {
+		err = 0;
+		goto out;
+	}
+
+	/* It is valid on Linux to pass in a zero-length receive buffer. This
+	 * is not an error. We may as well bail out now.
+	 */
+	if (!len) {
+		err = 0;
+		goto out;
+	}
+
+	if (sk->sk_type == SOCK_STREAM)
+		err = __vsock_stream_recvmsg(sk, msg, len, flags);
+	else
+		err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
+
 out:
 	release_sock(sk);
 	return err;
 }
 
+int
+vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
+			  int flags)
+{
+#ifdef CONFIG_BPF_SYSCALL
+	struct sock *sk = sock->sk;
+	const struct proto *prot;
+
+	prot = READ_ONCE(sk->sk_prot);
+	if (prot != &vsock_proto)
+		return prot->recvmsg(sk, msg, len, flags, NULL);
+#endif
+
+	return __vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg);
+
+static int vsock_set_rcvlowat(struct sock *sk, int val)
+{
+	const struct vsock_transport *transport;
+	struct vsock_sock *vsk;
+
+	vsk = vsock_sk(sk);
+
+	if (val > vsk->buffer_size)
+		return -EINVAL;
+
+	transport = vsk->transport;
+
+	if (transport && transport->notify_set_rcvlowat) {
+		int err;
+
+		err = transport->notify_set_rcvlowat(vsk, val);
+		if (err)
+			return err;
+	}
+
+	WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
+	return 0;
+}
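vsock_set_rcvlowat() above caps SO_RCVLOWAT at the socket's vsock buffer size (anything larger fails with EINVAL), lets the bound transport veto or prepare for the new watermark, and stores a zero value as 1 via the GNU `val ? : 1` ternary. From userspace it is the standard socket option; the 4 KiB value below is an arbitrary illustration:

    /* Sketch: only wake readers once at least 4 KiB is queued. */
    #include <sys/socket.h>

    static int set_low_watermark(int fd)
    {
            int bytes = 4096;   /* must not exceed the vsock buffer size */

            return setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
                              &bytes, sizeof(bytes));
    }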
+
 static const struct proto_ops vsock_stream_ops = {
 	.family = PF_VSOCK,
 	.owner = THIS_MODULE,
 	.release = vsock_release,
 	.bind = vsock_bind,
-	.connect = vsock_stream_connect,
+	.connect = vsock_connect,
+	.socketpair = sock_no_socketpair,
+	.accept = vsock_accept,
+	.getname = vsock_getname,
+	.poll = vsock_poll,
+	.ioctl = vsock_ioctl,
+	.listen = vsock_listen,
+	.shutdown = vsock_shutdown,
+	.setsockopt = vsock_connectible_setsockopt,
+	.getsockopt = vsock_connectible_getsockopt,
+	.sendmsg = vsock_connectible_sendmsg,
+	.recvmsg = vsock_connectible_recvmsg,
+	.mmap = sock_no_mmap,
+	.set_rcvlowat = vsock_set_rcvlowat,
+	.read_skb = vsock_read_skb,
+};
+
+static const struct proto_ops vsock_seqpacket_ops = {
+	.family = PF_VSOCK,
+	.owner = THIS_MODULE,
+	.release = vsock_release,
+	.bind = vsock_bind,
+	.connect = vsock_connect,
 	.socketpair = sock_no_socketpair,
 	.accept = vsock_accept,
 	.getname = vsock_getname,
 	.poll = vsock_poll,
-	.ioctl = sock_no_ioctl,
+	.ioctl = vsock_ioctl,
 	.listen = vsock_listen,
 	.shutdown = vsock_shutdown,
-	.setsockopt = vsock_stream_setsockopt,
-	.getsockopt = vsock_stream_getsockopt,
-	.sendmsg = vsock_stream_sendmsg,
-	.recvmsg = vsock_stream_recvmsg,
+	.setsockopt = vsock_connectible_setsockopt,
+	.getsockopt = vsock_connectible_getsockopt,
+	.sendmsg = vsock_connectible_sendmsg,
+	.recvmsg = vsock_connectible_recvmsg,
 	.mmap = sock_no_mmap,
-	.sendpage = sock_no_sendpage,
+	.read_skb = vsock_read_skb,
 };
 
 static int vsock_create(struct net *net, struct socket *sock,
 			int protocol, int kern)
 {
+	struct vsock_sock *vsk;
+	struct sock *sk;
+	int ret;
+
 	if (!sock)
 		return -EINVAL;
 
@@ -1865,13 +2560,39 @@ static int vsock_create(struct net *net, struct socket *sock,
 	case SOCK_STREAM:
 		sock->ops = &vsock_stream_ops;
 		break;
+	case SOCK_SEQPACKET:
+		sock->ops = &vsock_seqpacket_ops;
+		break;
 	default:
 		return -ESOCKTNOSUPPORT;
 	}
 
 	sock->state = SS_UNCONNECTED;
 
-	return __vsock_create(net, sock, NULL, GFP_KERNEL, 0) ? 0 : -ENOMEM;
+	sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
+	if (!sk)
+		return -ENOMEM;
+
+	vsk = vsock_sk(sk);
+
+	if (sock->type == SOCK_DGRAM) {
+		ret = vsock_assign_transport(vsk, NULL);
+		if (ret < 0) {
+			sock->sk = NULL;
+			sock_put(sk);
+			return ret;
+		}
+	}
+
+	/* SOCK_DGRAM doesn't have 'setsockopt' callback set in its
+	 * proto_ops, so there is no handler for custom logic.
+	 */
+	if (sock_type_connectible(sock->type))
+		set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags);
+
+	vsock_insert_unbound(vsk);
+
+	return 0;
 }
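With vsock_create() now wiring SOCK_SEQPACKET to vsock_seqpacket_ops, userspace can open the new socket flavour directly; on kernels without this change, socket() fails with ESOCKTNOSUPPORT. A sketch using an assumed CID/port pair (3 and 1234 are placeholders, not values from this patch):

    /* Sketch: connect a SOCK_SEQPACKET vsock socket to a peer service. */
    #include <string.h>
    #include <sys/socket.h>
    #include <unistd.h>
    #include <linux/vm_sockets.h>

    static int open_seqpacket(unsigned int cid, unsigned int port)
    {
            struct sockaddr_vm addr;
            int fd = socket(AF_VSOCK, SOCK_SEQPACKET, 0);

            if (fd < 0)
                    return -1;

            memset(&addr, 0, sizeof(addr));
            addr.svm_family = AF_VSOCK;
            addr.svm_cid = cid;     /* e.g. 3 for a guest, placeholder */
            addr.svm_port = port;   /* e.g. 1234, placeholder */

            if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
                    close(fd);
                    return -1;
            }
            return fd;
    }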
 
 static const struct net_proto_family vsock_family_ops = {
@@ -1885,16 +2606,25 @@ static long vsock_dev_do_ioctl(struct file *filp,
 {
 	u32 __user *p = ptr;
 	int retval = 0;
+	u32 cid;
 
 	switch (cmd) {
 	case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
-		if (put_user(transport->get_local_cid(), p) != 0)
+		/* To be compatible with the VMCI behavior, we prioritize the
+		 * guest CID instead of well-known host CID (VMADDR_CID_HOST).
+		 */
+		cid = vsock_registered_transport_cid(&transport_g2h);
+		if (cid == VMADDR_CID_ANY)
+			cid = vsock_registered_transport_cid(&transport_h2g);
+		if (cid == VMADDR_CID_ANY)
+			cid = vsock_registered_transport_cid(&transport_local);
+
+		if (put_user(cid, p) != 0)
 			retval = -EFAULT;
 		break;
 
 	default:
-		pr_err("Unknown ioctl %d\n", cmd);
-		retval = -EINVAL;
+		retval = -ENOIOCTLCMD;
 	}
 
 	return retval;
@@ -1928,23 +2658,24 @@ static struct miscdevice vsock_device = {
 	.fops	= &vsock_device_ops,
 };
 
-static int __vsock_core_init(void)
+static int __init vsock_init(void)
 {
-	int err;
+	int err = 0;
 
 	vsock_init_tables();
 
+	vsock_proto.owner = THIS_MODULE;
 	vsock_device.minor = MISC_DYNAMIC_MINOR;
 	err = misc_register(&vsock_device);
 	if (err) {
 		pr_err("Failed to register misc device\n");
-		return -ENOENT;
+		goto err_reset_transport;
 	}
 
 	err = proto_register(&vsock_proto, 1);	/* we want our slab */
 	if (err) {
 		pr_err("Cannot register vsock protocol\n");
-		goto err_misc_deregister;
+		goto err_deregister_misc;
 	}
 
 	err = sock_register(&vsock_family_ops);
@@ -1954,54 +2685,111 @@ static int __vsock_core_init(void)
 		goto err_unregister_proto;
 	}
 
+	vsock_bpf_build_proto();
+
 	return 0;
 
 err_unregister_proto:
 	proto_unregister(&vsock_proto);
-err_misc_deregister:
+err_deregister_misc:
 	misc_deregister(&vsock_device);
+err_reset_transport:
 	return err;
 }
 
-int vsock_core_init(const struct vsock_transport *t)
+static void __exit vsock_exit(void)
+{
+	misc_deregister(&vsock_device);
+	sock_unregister(AF_VSOCK);
+	proto_unregister(&vsock_proto);
+}
+
+const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
 {
-	int retval = mutex_lock_interruptible(&vsock_register_mutex);
-	if (retval)
-		return retval;
+	return vsk->transport;
+}
+EXPORT_SYMBOL_GPL(vsock_core_get_transport);
 
-	if (transport) {
-		retval = -EBUSY;
-		goto out;
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+	const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local;
+	int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+	if (err)
+		return err;
+
+	t_h2g = transport_h2g;
+	t_g2h = transport_g2h;
+	t_dgram = transport_dgram;
+	t_local = transport_local;
+
+	if (features & VSOCK_TRANSPORT_F_H2G) {
+		if (t_h2g) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_h2g = t;
 	}
 
-	transport = t;
-	retval = __vsock_core_init();
-	if (retval)
-		transport = NULL;
+	if (features & VSOCK_TRANSPORT_F_G2H) {
+		if (t_g2h) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_g2h = t;
+	}
 
-out:
+	if (features & VSOCK_TRANSPORT_F_DGRAM) {
+		if (t_dgram) {
+			err = -EBUSY;
+			goto err_busy;
		}
+		t_dgram = t;
+	}
+
+	if (features & VSOCK_TRANSPORT_F_LOCAL) {
+		if (t_local) {
+			err = -EBUSY;
+			goto err_busy;
+		}
+		t_local = t;
+	}
+
+	transport_h2g = t_h2g;
+	transport_g2h = t_g2h;
+	transport_dgram = t_dgram;
+	transport_local = t_local;
+
+err_busy:
 	mutex_unlock(&vsock_register_mutex);
-	return retval;
+	return err;
 }
-EXPORT_SYMBOL_GPL(vsock_core_init);
+EXPORT_SYMBOL_GPL(vsock_core_register);
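vsock_core_register() above replaces the old single-transport vsock_core_init(): each registration claims one or more roles (H2G, G2H, DGRAM, LOCAL) and the whole call fails with -EBUSY if any requested role is already taken. A hedged kernel-side sketch of how a hypothetical transport module would use the pair; my_transport and its callbacks are assumptions, not part of this patch:

    /* Sketch: claim the host->guest role, release it on module exit. */
    #include <linux/module.h>
    #include <net/af_vsock.h>

    static struct vsock_transport my_transport = {
            /* .init, .connect, .stream_enqueue, ... transport callbacks */
    };

    static int __init my_transport_init(void)
    {
            /* Fails with -EBUSY if another H2G transport is loaded. */
            return vsock_core_register(&my_transport, VSOCK_TRANSPORT_F_H2G);
    }

    static void __exit my_transport_exit(void)
    {
            vsock_core_unregister(&my_transport);
    }

    module_init(my_transport_init);
    module_exit(my_transport_exit);
    MODULE_LICENSE("GPL");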
 
-void vsock_core_exit(void)
+void vsock_core_unregister(const struct vsock_transport *t)
 {
 	mutex_lock(&vsock_register_mutex);
 
-	misc_deregister(&vsock_device);
-	sock_unregister(AF_VSOCK);
-	proto_unregister(&vsock_proto);
+	if (transport_h2g == t)
+		transport_h2g = NULL;
+
+	if (transport_g2h == t)
+		transport_g2h = NULL;
 
-	/* We do not want the assignment below re-ordered. */
-	mb();
-	transport = NULL;
+	if (transport_dgram == t)
+		transport_dgram = NULL;
+
+	if (transport_local == t)
+		transport_local = NULL;
 
 	mutex_unlock(&vsock_register_mutex);
 }
-EXPORT_SYMBOL_GPL(vsock_core_exit);
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
+
+module_init(vsock_init);
+module_exit(vsock_exit);
 
 MODULE_AUTHOR("VMware, Inc.");
 MODULE_DESCRIPTION("VMware Virtual Socket Family");
-MODULE_VERSION("1.0.0.0-k");
+MODULE_VERSION("1.0.2.0-k");
 MODULE_LICENSE("GPL v2");
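For reference, the CID priority implemented in vsock_dev_do_ioctl() above (guest-to-host first, then host-to-guest, then the local transport) is what the standard query below observes:

    /* Sketch: read the local CID from the vsock misc device. */
    #include <fcntl.h>
    #include <stdio.h>
    #include <sys/ioctl.h>
    #include <sys/socket.h>
    #include <unistd.h>
    #include <linux/vm_sockets.h>

    int main(void)
    {
            unsigned int cid = 0;
            int fd = open("/dev/vsock", O_RDONLY);

            if (fd < 0)
                    return 1;

            if (ioctl(fd, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid) == 0)
                    printf("local CID: %u\n", cid);

            close(fd);
            return 0;
    }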