summaryrefslogtreecommitdiff
path: root/net/vmw_vsock/af_vsock.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r--net/vmw_vsock/af_vsock.c1790
1 files changed, 1289 insertions, 501 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 593071dabd1c..adcba1b7bf74 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1,16 +1,8 @@
+// SPDX-License-Identifier: GPL-2.0-only
/*
* VMware vSockets Driver
*
* Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the Free
- * Software Foundation version 2 and no later version.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
*/
/* Implementation notes:
@@ -36,19 +28,20 @@
* not support simultaneous connects (two "client" sockets connecting).
*
* - "Server" sockets are referred to as listener sockets throughout this
- * implementation because they are in the SS_LISTEN state. When a connection
- * request is received (the second kind of socket mentioned above), we create a
- * new socket and refer to it as a pending socket. These pending sockets are
- * placed on the pending connection list of the listener socket. When future
- * packets are received for the address the listener socket is bound to, we
- * check if the source of the packet is from one that has an existing pending
- * connection. If it does, we process the packet for the pending socket. When
- * that socket reaches the connected state, it is removed from the listener
- * socket's pending list and enqueued in the listener socket's accept queue.
- * Callers of accept(2) will accept connected sockets from the listener socket's
- * accept queue. If the socket cannot be accepted for some reason then it is
- * marked rejected. Once the connection is accepted, it is owned by the user
- * process and the responsibility for cleanup falls with that user process.
+ * implementation because they are in the TCP_LISTEN state. When a
+ * connection request is received (the second kind of socket mentioned above),
+ * we create a new socket and refer to it as a pending socket. These pending
+ * sockets are placed on the pending connection list of the listener socket.
+ * When future packets are received for the address the listener socket is
+ * bound to, we check if the source of the packet is from one that has an
+ * existing pending connection. If it does, we process the packet for the
+ * pending socket. When that socket reaches the connected state, it is removed
+ * from the listener socket's pending list and enqueued in the listener
+ * socket's accept queue. Callers of accept(2) will accept connected sockets
+ * from the listener socket's accept queue. If the socket cannot be accepted
+ * for some reason then it is marked rejected. Once the connection is
+ * accepted, it is owned by the user process and the responsibility for cleanup
+ * falls with that user process.
*
* - It is possible that these pending sockets will never reach the connected
* state; in fact, we may never receive another packet after the connection
@@ -60,6 +53,14 @@
* function will also cleanup rejected sockets, those that reach the connected
* state but leave it before they have been accepted.
*
+ * - Lock ordering for pending or accept queue sockets is:
+ *
+ * lock_sock(listener);
+ * lock_sock_nested(pending, SINGLE_DEPTH_NESTING);
+ *
+ * Using explicit nested locking keeps lockdep happy since normally only one
+ * lock of a given class may be taken at a time.
+ *
* - Sockets created by user action will be cleaned up when the user process
* calls close(2), causing our release implementation to be called. Our release
* implementation will perform some cleanup then drop the last reference so our
@@ -73,14 +74,26 @@
* argument, we must ensure the reference count is increased to ensure the
* socket isn't freed before the function is run; the deferred function will
* then drop the reference.
+ *
+ * - sk->sk_state uses the TCP state constants because they are widely used by
+ * other address families and exposed to userspace tools like ss(8):
+ *
+ * TCP_CLOSE - unconnected
+ * TCP_SYN_SENT - connecting
+ * TCP_ESTABLISHED - connected
+ * TCP_CLOSING - disconnecting
+ * TCP_LISTEN - listening
*/
+#include <linux/compat.h>
#include <linux/types.h>
#include <linux/bitops.h>
#include <linux/cred.h>
+#include <linux/errqueue.h>
#include <linux/init.h>
#include <linux/io.h>
#include <linux/kernel.h>
+#include <linux/sched/signal.h>
#include <linux/kmod.h>
#include <linux/list.h>
#include <linux/miscdevice.h>
@@ -88,6 +101,7 @@
#include <linux/mutex.h>
#include <linux/net.h>
#include <linux/poll.h>
+#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/smp.h>
#include <linux/socket.h>
@@ -96,18 +110,24 @@
#include <linux/wait.h>
#include <linux/workqueue.h>
#include <net/sock.h>
-
-#include "af_vsock.h"
+#include <net/af_vsock.h>
+#include <uapi/linux/vm_sockets.h>
+#include <uapi/asm-generic/ioctls.h>
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
+static void vsock_close(struct sock *sk, long timeout);
/* Protocol family. */
-static struct proto vsock_proto = {
+struct proto vsock_proto = {
.name = "AF_VSOCK",
.owner = THIS_MODULE,
.obj_size = sizeof(struct vsock_sock),
+ .close = vsock_close,
+#ifdef CONFIG_BPF_SYSCALL
+ .psock_update_sk_prot = vsock_bpf_update_proto,
+#endif
};
/* The default peer timeout indicates how long we will wait for a peer response
@@ -115,21 +135,20 @@ static struct proto vsock_proto = {
*/
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
-#define SS_LISTEN 255
-
-static const struct vsock_transport *transport;
+#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
+
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
+/* Transport used for local communication */
+static const struct vsock_transport *transport_local;
static DEFINE_MUTEX(vsock_register_mutex);
-/**** EXPORTS ****/
-
-/* Get the ID of the local context. This is transport dependent. */
-
-int vm_sockets_get_local_cid(void)
-{
- return transport->get_local_cid();
-}
-EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
-
/**** UTILS ****/
/* Each bound VSocket is stored in the bind hash table and each connected
@@ -146,7 +165,6 @@ EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
* vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets. The hash function
* mods with VSOCK_HASH_SIZE to ensure this.
*/
-#define VSOCK_HASH_SIZE 251
#define MAX_PORT_RETRIES 24
#define VSOCK_HASH(addr) ((addr)->svm_port % VSOCK_HASH_SIZE)
@@ -161,9 +179,12 @@ EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
#define vsock_connected_sockets_vsk(vsk) \
vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr)
-static struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1];
-static struct list_head vsock_connected_table[VSOCK_HASH_SIZE];
-static DEFINE_SPINLOCK(vsock_table_lock);
+struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1];
+EXPORT_SYMBOL_GPL(vsock_bind_table);
+struct list_head vsock_connected_table[VSOCK_HASH_SIZE];
+EXPORT_SYMBOL_GPL(vsock_connected_table);
+DEFINE_SPINLOCK(vsock_table_lock);
+EXPORT_SYMBOL_GPL(vsock_table_lock);
/* Autobind this socket to the local address if necessary. */
static int vsock_auto_bind(struct vsock_sock *vsk)
@@ -218,9 +239,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
{
struct vsock_sock *vsk;
- list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
- if (addr->svm_port == vsk->local_addr.svm_port)
+ list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
+ if (vsock_addr_equals_addr(addr, &vsk->local_addr))
+ return sk_vsock(vsk);
+
+ if (addr->svm_port == vsk->local_addr.svm_port &&
+ (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
+ addr->svm_cid == VMADDR_CID_ANY))
return sk_vsock(vsk);
+ }
return NULL;
}
@@ -241,16 +268,6 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
return NULL;
}
-static bool __vsock_in_bound_table(struct vsock_sock *vsk)
-{
- return !list_empty(&vsk->bound_table);
-}
-
-static bool __vsock_in_connected_table(struct vsock_sock *vsk)
-{
- return !list_empty(&vsk->connected_table);
-}
-
static void vsock_insert_unbound(struct vsock_sock *vsk)
{
spin_lock_bh(&vsock_table_lock);
@@ -272,7 +289,8 @@ EXPORT_SYMBOL_GPL(vsock_insert_connected);
void vsock_remove_bound(struct vsock_sock *vsk)
{
spin_lock_bh(&vsock_table_lock);
- __vsock_remove_bound(vsk);
+ if (__vsock_in_bound_table(vsk))
+ __vsock_remove_bound(vsk);
spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_remove_bound);
@@ -280,7 +298,8 @@ EXPORT_SYMBOL_GPL(vsock_remove_bound);
void vsock_remove_connected(struct vsock_sock *vsk)
{
spin_lock_bh(&vsock_table_lock);
- __vsock_remove_connected(vsk);
+ if (__vsock_in_connected_table(vsk))
+ __vsock_remove_connected(vsk);
spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_remove_connected);
@@ -316,29 +335,18 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
}
EXPORT_SYMBOL_GPL(vsock_find_connected_socket);
-static bool vsock_in_bound_table(struct vsock_sock *vsk)
-{
- bool ret;
-
- spin_lock_bh(&vsock_table_lock);
- ret = __vsock_in_bound_table(vsk);
- spin_unlock_bh(&vsock_table_lock);
-
- return ret;
-}
-
-static bool vsock_in_connected_table(struct vsock_sock *vsk)
+void vsock_remove_sock(struct vsock_sock *vsk)
{
- bool ret;
-
- spin_lock_bh(&vsock_table_lock);
- ret = __vsock_in_connected_table(vsk);
- spin_unlock_bh(&vsock_table_lock);
+ /* Transport reassignment must not remove the binding. */
+ if (sock_flag(sk_vsock(vsk), SOCK_DEAD))
+ vsock_remove_bound(vsk);
- return ret;
+ vsock_remove_connected(vsk);
}
+EXPORT_SYMBOL_GPL(vsock_remove_sock);
-void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
+void vsock_for_each_connected_socket(struct vsock_transport *transport,
+ void (*fn)(struct sock *sk))
{
int i;
@@ -347,8 +355,12 @@ void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
struct vsock_sock *vsk;
list_for_each_entry(vsk, &vsock_connected_table[i],
- connected_table);
+ connected_table) {
+ if (vsk->transport != transport)
+ continue;
+
fn(sk_vsock(vsk));
+ }
}
spin_unlock_bh(&vsock_table_lock);
@@ -393,6 +405,181 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
}
EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
+static bool vsock_use_local_transport(unsigned int remote_cid)
+{
+ lockdep_assert_held(&vsock_register_mutex);
+
+ if (!transport_local)
+ return false;
+
+ if (remote_cid == VMADDR_CID_LOCAL)
+ return true;
+
+ if (transport_g2h) {
+ return remote_cid == transport_g2h->get_local_cid();
+ } else {
+ return remote_cid == VMADDR_CID_HOST;
+ }
+}
+
+static void vsock_deassign_transport(struct vsock_sock *vsk)
+{
+ if (!vsk->transport)
+ return;
+
+ vsk->transport->destruct(vsk);
+ module_put(vsk->transport->module);
+ vsk->transport = NULL;
+}
+
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for connection oriented socket this must be called when vsk->remote_addr
+ * is set (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
+ * g2h is not loaded, will use local transport;
+ * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field
+ * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport;
+ * - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+ const struct vsock_transport *new_transport;
+ struct sock *sk = sk_vsock(vsk);
+ unsigned int remote_cid = vsk->remote_addr.svm_cid;
+ __u8 remote_flags;
+ int ret;
+
+ /* If the packet is coming with the source and destination CIDs higher
+ * than VMADDR_CID_HOST, then a vsock channel where all the packets are
+ * forwarded to the host should be established. Then the host will
+ * need to forward the packets to the guest.
+ *
+ * The flag is set on the (listen) receive path (psk is not NULL). On
+ * the connect path the flag can be set by the user space application.
+ */
+ if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
+ vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
+ vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
+
+ remote_flags = vsk->remote_addr.svm_flags;
+
+ mutex_lock(&vsock_register_mutex);
+
+ switch (sk->sk_type) {
+ case SOCK_DGRAM:
+ new_transport = transport_dgram;
+ break;
+ case SOCK_STREAM:
+ case SOCK_SEQPACKET:
+ if (vsock_use_local_transport(remote_cid))
+ new_transport = transport_local;
+ else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
+ (remote_flags & VMADDR_FLAG_TO_HOST))
+ new_transport = transport_g2h;
+ else
+ new_transport = transport_h2g;
+ break;
+ default:
+ ret = -ESOCKTNOSUPPORT;
+ goto err;
+ }
+
+ if (vsk->transport && vsk->transport == new_transport) {
+ ret = 0;
+ goto err;
+ }
+
+ /* We increase the module refcnt to prevent the transport unloading
+ * while there are open sockets assigned to it.
+ */
+ if (!new_transport || !try_module_get(new_transport->module)) {
+ ret = -ENODEV;
+ goto err;
+ }
+
+ /* It's safe to release the mutex after a successful try_module_get().
+ * Whichever transport `new_transport` points at, it won't go away until
+ * the last module_put() below or in vsock_deassign_transport().
+ */
+ mutex_unlock(&vsock_register_mutex);
+
+ if (vsk->transport) {
+ /* transport->release() must be called with sock lock acquired.
+ * This path can only be taken during vsock_connect(), where we
+ * have already held the sock lock. In the other cases, this
+ * function is called on a new socket which is not assigned to
+ * any transport.
+ */
+ vsk->transport->release(vsk);
+ vsock_deassign_transport(vsk);
+
+ /* transport's release() and destruct() can touch some socket
+ * state, since we are reassigning the socket to a new transport
+ * during vsock_connect(), let's reset these fields to have a
+ * clean state.
+ */
+ sock_reset_flag(sk, SOCK_DONE);
+ sk->sk_state = TCP_CLOSE;
+ vsk->peer_shutdown = 0;
+ }
+
+ if (sk->sk_type == SOCK_SEQPACKET) {
+ if (!new_transport->seqpacket_allow ||
+ !new_transport->seqpacket_allow(remote_cid)) {
+ module_put(new_transport->module);
+ return -ESOCKTNOSUPPORT;
+ }
+ }
+
+ ret = new_transport->init(vsk, psk);
+ if (ret) {
+ module_put(new_transport->module);
+ return ret;
+ }
+
+ vsk->transport = new_transport;
+
+ return 0;
+err:
+ mutex_unlock(&vsock_register_mutex);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
+/*
+ * Provide safe access to static transport_{h2g,g2h,dgram,local} callbacks.
+ * Otherwise we may race with module removal. Do not use on `vsk->transport`.
+ */
+static u32 vsock_registered_transport_cid(const struct vsock_transport **transport)
+{
+ u32 cid = VMADDR_CID_ANY;
+
+ mutex_lock(&vsock_register_mutex);
+ if (*transport)
+ cid = (*transport)->get_local_cid();
+ mutex_unlock(&vsock_register_mutex);
+
+ return cid;
+}
+
+bool vsock_find_cid(unsigned int cid)
+{
+ if (cid == vsock_registered_transport_cid(&transport_g2h))
+ return true;
+
+ if (transport_h2g && cid == VMADDR_CID_HOST)
+ return true;
+
+ if (transport_local && cid == VMADDR_CID_LOCAL)
+ return true;
+
+ return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
static struct sock *vsock_dequeue_accept(struct sock *listener)
{
struct vsock_sock *vlistener;
@@ -429,26 +616,33 @@ static bool vsock_is_pending(struct sock *sk)
static int vsock_send_shutdown(struct sock *sk, int mode)
{
- return transport->shutdown(vsock_sk(sk), mode);
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ if (!vsk->transport)
+ return -ENODEV;
+
+ return vsk->transport->shutdown(vsk, mode);
}
-void vsock_pending_work(struct work_struct *work)
+static void vsock_pending_work(struct work_struct *work)
{
struct sock *sk;
struct sock *listener;
struct vsock_sock *vsk;
bool cleanup;
- vsk = container_of(work, struct vsock_sock, dwork.work);
+ vsk = container_of(work, struct vsock_sock, pending_work.work);
sk = sk_vsock(vsk);
listener = vsk->listener;
cleanup = true;
lock_sock(listener);
- lock_sock(sk);
+ lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
if (vsock_is_pending(sk)) {
vsock_remove_pending(listener, sk);
+
+ sk_acceptq_removed(listener);
} else if (!vsk->rejected) {
/* We are not on the pending list and accept() did not reject
* us, so we must have been accepted by our user process. We
@@ -459,16 +653,13 @@ void vsock_pending_work(struct work_struct *work)
goto out;
}
- listener->sk_ack_backlog--;
-
/* We need to remove ourself from the global connected sockets list so
* incoming packets can't find this socket, and to reduce the reference
* count.
*/
- if (vsock_in_connected_table(vsk))
- vsock_remove_connected(vsk);
+ vsock_remove_connected(vsk);
- sk->sk_state = SS_FREE;
+ sk->sk_state = TCP_CLOSE;
out:
release_sock(sk);
@@ -479,16 +670,18 @@ out:
sock_put(sk);
sock_put(listener);
}
-EXPORT_SYMBOL_GPL(vsock_pending_work);
/**** SOCKET OPERATIONS ****/
-static int __vsock_bind_stream(struct vsock_sock *vsk,
- struct sockaddr_vm *addr)
+static int __vsock_bind_connectible(struct vsock_sock *vsk,
+ struct sockaddr_vm *addr)
{
- static u32 port = LAST_RESERVED_PORT + 1;
+ static u32 port;
struct sockaddr_vm new_addr;
+ if (!port)
+ port = get_random_u32_above(LAST_RESERVED_PORT);
+
vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port);
if (addr->svm_port == VMADDR_PORT_ANY) {
@@ -496,7 +689,8 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
unsigned int i;
for (i = 0; i < MAX_PORT_RETRIES; i++) {
- if (port <= LAST_RESERVED_PORT)
+ if (port == VMADDR_PORT_ANY ||
+ port <= LAST_RESERVED_PORT)
port = LAST_RESERVED_PORT + 1;
new_addr.svm_port = port++;
@@ -524,9 +718,10 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);
- /* Remove stream sockets from the unbound list and add them to the hash
- * table for easy lookup by its address. The unbound list is simply an
- * extra entry at the end of the hash table, a trick used by AF_UNIX.
+ /* Remove connection oriented sockets from the unbound list and add them
+ * to the hash table for easy lookup by its address. The unbound list
+ * is simply an extra entry at the end of the hash table, a trick used
+ * by AF_UNIX.
*/
__vsock_remove_bound(vsk);
__vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);
@@ -537,13 +732,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
static int __vsock_bind_dgram(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
- return transport->dgram_bind(vsk, addr);
+ return vsk->transport->dgram_bind(vsk, addr);
}
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
{
struct vsock_sock *vsk = vsock_sk(sk);
- u32 cid;
int retval;
/* First ensure this socket isn't already bound. */
@@ -553,16 +747,16 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
/* Now bind to the provided address or select appropriate values if
* none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that
* like AF_INET prevents binding to a non-local IP address (in most
- * cases), we only allow binding to the local CID.
+ * cases), we only allow binding to a local CID.
*/
- cid = transport->get_local_cid();
- if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+ if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
return -EADDRNOTAVAIL;
switch (sk->sk_socket->type) {
case SOCK_STREAM:
+ case SOCK_SEQPACKET:
spin_lock_bh(&vsock_table_lock);
- retval = __vsock_bind_stream(vsk, addr);
+ retval = __vsock_bind_connectible(vsk, addr);
spin_unlock_bh(&vsock_table_lock);
break;
@@ -578,17 +772,20 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
return retval;
}
-struct sock *__vsock_create(struct net *net,
- struct socket *sock,
- struct sock *parent,
- gfp_t priority,
- unsigned short type)
+static void vsock_connect_timeout(struct work_struct *work);
+
+static struct sock *__vsock_create(struct net *net,
+ struct socket *sock,
+ struct sock *parent,
+ gfp_t priority,
+ unsigned short type,
+ int kern)
{
struct sock *sk;
struct vsock_sock *psk;
struct vsock_sock *vsk;
- sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto);
+ sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern);
if (!sk)
return NULL;
@@ -607,7 +804,6 @@ struct sock *__vsock_create(struct net *net,
sk->sk_destruct = vsock_sk_destruct;
sk->sk_backlog_rcv = vsock_queue_rcv_skb;
- sk->sk_state = 0;
sock_reset_flag(sk, SOCK_DONE);
INIT_LIST_HEAD(&vsk->bound_table);
@@ -619,71 +815,85 @@ struct sock *__vsock_create(struct net *net,
vsk->sent_request = false;
vsk->ignore_connecting_rst = false;
vsk->peer_shutdown = 0;
+ INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout);
+ INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work);
psk = parent ? vsock_sk(parent) : NULL;
if (parent) {
vsk->trusted = psk->trusted;
vsk->owner = get_cred(psk->owner);
vsk->connect_timeout = psk->connect_timeout;
+ vsk->buffer_size = psk->buffer_size;
+ vsk->buffer_min_size = psk->buffer_min_size;
+ vsk->buffer_max_size = psk->buffer_max_size;
+ security_sk_clone(parent, sk);
} else {
- vsk->trusted = capable(CAP_NET_ADMIN);
+ vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN);
vsk->owner = get_current_cred();
vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
+ vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
+ vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
+ vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
}
- if (transport->init(vsk, psk) < 0) {
- sk_free(sk);
- return NULL;
- }
-
- if (sock)
- vsock_insert_unbound(vsk);
-
return sk;
}
-EXPORT_SYMBOL_GPL(__vsock_create);
-static void __vsock_release(struct sock *sk)
+static bool sock_type_connectible(u16 type)
{
- if (sk) {
- struct sk_buff *skb;
- struct sock *pending;
- struct vsock_sock *vsk;
+ return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET);
+}
- vsk = vsock_sk(sk);
- pending = NULL; /* Compiler warning. */
+static void __vsock_release(struct sock *sk, int level)
+{
+ struct vsock_sock *vsk;
+ struct sock *pending;
- if (vsock_in_bound_table(vsk))
- vsock_remove_bound(vsk);
+ vsk = vsock_sk(sk);
+ pending = NULL; /* Compiler warning. */
- if (vsock_in_connected_table(vsk))
- vsock_remove_connected(vsk);
+ /* When "level" is SINGLE_DEPTH_NESTING, use the nested
+ * version to avoid the warning "possible recursive locking
+ * detected". When "level" is 0, lock_sock_nested(sk, level)
+ * is the same as lock_sock(sk).
+ */
+ lock_sock_nested(sk, level);
- transport->release(vsk);
+ /* Indicate to vsock_remove_sock() that the socket is being released and
+ * can be removed from the bound_table. Unlike transport reassignment
+ * case, where the socket must remain bound despite vsock_remove_sock()
+ * being called from the transport release() callback.
+ */
+ sock_set_flag(sk, SOCK_DEAD);
- lock_sock(sk);
- sock_orphan(sk);
- sk->sk_shutdown = SHUTDOWN_MASK;
+ if (vsk->transport)
+ vsk->transport->release(vsk);
+ else if (sock_type_connectible(sk->sk_type))
+ vsock_remove_sock(vsk);
- while ((skb = skb_dequeue(&sk->sk_receive_queue)))
- kfree_skb(skb);
+ sock_orphan(sk);
+ sk->sk_shutdown = SHUTDOWN_MASK;
- /* Clean up any sockets that never were accepted. */
- while ((pending = vsock_dequeue_accept(sk)) != NULL) {
- __vsock_release(pending);
- sock_put(pending);
- }
+ skb_queue_purge(&sk->sk_receive_queue);
- release_sock(sk);
- sock_put(sk);
+ /* Clean up any sockets that never were accepted. */
+ while ((pending = vsock_dequeue_accept(sk)) != NULL) {
+ __vsock_release(pending, SINGLE_DEPTH_NESTING);
+ sock_put(pending);
}
+
+ release_sock(sk);
+ sock_put(sk);
}
static void vsock_sk_destruct(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
- transport->destruct(vsk);
+ /* Flush MSG_ZEROCOPY leftovers. */
+ __skb_queue_purge(&sk->sk_error_queue);
+
+ vsock_deassign_transport(vsk);
/* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel.
@@ -705,21 +915,71 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
return err;
}
+struct sock *vsock_create_connected(struct sock *parent)
+{
+ return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
+ parent->sk_type, 0);
+}
+EXPORT_SYMBOL_GPL(vsock_create_connected);
+
s64 vsock_stream_has_data(struct vsock_sock *vsk)
{
- return transport->stream_has_data(vsk);
+ if (WARN_ON(!vsk->transport))
+ return 0;
+
+ return vsk->transport->stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);
+s64 vsock_connectible_has_data(struct vsock_sock *vsk)
+{
+ struct sock *sk = sk_vsock(vsk);
+
+ if (WARN_ON(!vsk->transport))
+ return 0;
+
+ if (sk->sk_type == SOCK_SEQPACKET)
+ return vsk->transport->seqpacket_has_data(vsk);
+ else
+ return vsock_stream_has_data(vsk);
+}
+EXPORT_SYMBOL_GPL(vsock_connectible_has_data);
+
s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
- return transport->stream_has_space(vsk);
+ if (WARN_ON(!vsk->transport))
+ return 0;
+
+ return vsk->transport->stream_has_space(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);
+void vsock_data_ready(struct sock *sk)
+{
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ if (vsock_stream_has_data(vsk) >= sk->sk_rcvlowat ||
+ sock_flag(sk, SOCK_DONE))
+ sk->sk_data_ready(sk);
+}
+EXPORT_SYMBOL_GPL(vsock_data_ready);
+
+/* Dummy callback required by sockmap.
+ * See unconditional call of saved_close() in sock_map_close().
+ */
+static void vsock_close(struct sock *sk, long timeout)
+{
+}
+
static int vsock_release(struct socket *sock)
{
- __vsock_release(sock->sk);
+ struct sock *sk = sock->sk;
+
+ if (!sk)
+ return 0;
+
+ sk->sk_prot->close(sk, 0);
+ __vsock_release(sk, 0);
sock->sk = NULL;
sock->state = SS_FREE;
@@ -727,7 +987,7 @@ static int vsock_release(struct socket *sock)
}
static int
-vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
+vsock_bind(struct socket *sock, struct sockaddr_unsized *addr, int addr_len)
{
int err;
struct sock *sk;
@@ -746,7 +1006,7 @@ vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
}
static int vsock_getname(struct socket *sock,
- struct sockaddr *addr, int *addr_len, int peer)
+ struct sockaddr *addr, int peer)
{
int err;
struct sock *sk;
@@ -769,25 +1029,48 @@ static int vsock_getname(struct socket *sock,
vm_addr = &vsk->local_addr;
}
- if (!vm_addr) {
- err = -EINVAL;
- goto out;
- }
-
- /* sys_getsockname() and sys_getpeername() pass us a
- * MAX_SOCK_ADDR-sized buffer and don't set addr_len. Unfortunately
- * that macro is defined in socket.c instead of .h, so we hardcode its
- * value here.
- */
- BUILD_BUG_ON(sizeof(*vm_addr) > 128);
+ BUILD_BUG_ON(sizeof(*vm_addr) > sizeof(struct sockaddr_storage));
memcpy(addr, vm_addr, sizeof(*vm_addr));
- *addr_len = sizeof(*vm_addr);
+ err = sizeof(*vm_addr);
out:
release_sock(sk);
return err;
}
+void vsock_linger(struct sock *sk)
+{
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
+ ssize_t (*unsent)(struct vsock_sock *vsk);
+ struct vsock_sock *vsk = vsock_sk(sk);
+ long timeout;
+
+ if (!sock_flag(sk, SOCK_LINGER))
+ return;
+
+ timeout = sk->sk_lingertime;
+ if (!timeout)
+ return;
+
+ /* Transports must implement `unsent_bytes` if they want to support
+ * SOCK_LINGER through `vsock_linger()` since we use it to check when
+ * the socket can be closed.
+ */
+ unsent = vsk->transport->unsent_bytes;
+ if (!unsent)
+ return;
+
+ add_wait_queue(sk_sleep(sk), &wait);
+
+ do {
+ if (sk_wait_event(sk, &timeout, unsent(vsk) == 0, &wait))
+ break;
+ } while (!signal_pending(current) && timeout);
+
+ remove_wait_queue(sk_sleep(sk), &wait);
+}
+EXPORT_SYMBOL_GPL(vsock_linger);
+
static int vsock_shutdown(struct socket *sock, int mode)
{
int err;
@@ -804,17 +1087,19 @@ static int vsock_shutdown(struct socket *sock, int mode)
if ((mode & ~SHUTDOWN_MASK) || !mode)
return -EINVAL;
- /* If this is a STREAM socket and it is not connected then bail out
- * immediately. If it is a DGRAM socket then we must first kick the
- * socket so that it wakes up from any sleeping calls, for example
- * recv(), and then afterwards return the error.
+ /* If this is a connection oriented socket and it is not connected then
+ * bail out immediately. If it is a DGRAM socket then we must first
+ * kick the socket so that it wakes up from any sleeping calls, for
+ * example recv(), and then afterwards return the error.
*/
sk = sock->sk;
+
+ lock_sock(sk);
if (sock->state == SS_UNCONNECTED) {
err = -ENOTCONN;
- if (sk->sk_type == SOCK_STREAM)
- return err;
+ if (sock_type_connectible(sk->sk_type))
+ goto out;
} else {
sock->state = SS_DISCONNECTING;
err = 0;
@@ -823,25 +1108,25 @@ static int vsock_shutdown(struct socket *sock, int mode)
/* Receive and send shutdowns are treated alike. */
mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN);
if (mode) {
- lock_sock(sk);
sk->sk_shutdown |= mode;
sk->sk_state_change(sk);
- release_sock(sk);
- if (sk->sk_type == SOCK_STREAM) {
+ if (sock_type_connectible(sk->sk_type)) {
sock_reset_flag(sk, SOCK_DONE);
vsock_send_shutdown(sk, mode);
}
}
+out:
+ release_sock(sk);
return err;
}
-static unsigned int vsock_poll(struct file *file, struct socket *sock,
+static __poll_t vsock_poll(struct file *file, struct socket *sock,
poll_table *wait)
{
struct sock *sk;
- unsigned int mask;
+ __poll_t mask;
struct vsock_sock *vsk;
sk = sock->sk;
@@ -850,58 +1135,66 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock,
poll_wait(file, sk_sleep(sk), wait);
mask = 0;
- if (sk->sk_err)
+ if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue))
/* Signify that there has been an error on this socket. */
- mask |= POLLERR;
+ mask |= EPOLLERR;
/* INET sockets treat local write shutdown and peer write shutdown as a
- * case of POLLHUP set.
+ * case of EPOLLHUP set.
*/
if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
((sk->sk_shutdown & SEND_SHUTDOWN) &&
(vsk->peer_shutdown & SEND_SHUTDOWN))) {
- mask |= POLLHUP;
+ mask |= EPOLLHUP;
}
if (sk->sk_shutdown & RCV_SHUTDOWN ||
vsk->peer_shutdown & SEND_SHUTDOWN) {
- mask |= POLLRDHUP;
+ mask |= EPOLLRDHUP;
}
+ if (sk_is_readable(sk))
+ mask |= EPOLLIN | EPOLLRDNORM;
+
if (sock->type == SOCK_DGRAM) {
/* For datagram sockets we can read if there is something in
* the queue and write as long as the socket isn't shutdown for
* sending.
*/
- if (!skb_queue_empty(&sk->sk_receive_queue) ||
+ if (!skb_queue_empty_lockless(&sk->sk_receive_queue) ||
(sk->sk_shutdown & RCV_SHUTDOWN)) {
- mask |= POLLIN | POLLRDNORM;
+ mask |= EPOLLIN | EPOLLRDNORM;
}
if (!(sk->sk_shutdown & SEND_SHUTDOWN))
- mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
+ mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
+
+ } else if (sock_type_connectible(sk->sk_type)) {
+ const struct vsock_transport *transport;
- } else if (sock->type == SOCK_STREAM) {
lock_sock(sk);
+ transport = vsk->transport;
+
/* Listening sockets that have connections in their accept
* queue can be read.
*/
- if (sk->sk_state == SS_LISTEN
+ if (sk->sk_state == TCP_LISTEN
&& !vsock_is_accept_queue_empty(sk))
- mask |= POLLIN | POLLRDNORM;
+ mask |= EPOLLIN | EPOLLRDNORM;
/* If there is something in the queue then we can read. */
- if (transport->stream_is_active(vsk) &&
+ if (transport && transport->stream_is_active(vsk) &&
!(sk->sk_shutdown & RCV_SHUTDOWN)) {
bool data_ready_now = false;
+ int target = sock_rcvlowat(sk, 0, INT_MAX);
int ret = transport->notify_poll_in(
- vsk, 1, &data_ready_now);
+ vsk, target, &data_ready_now);
if (ret < 0) {
- mask |= POLLERR;
+ mask |= EPOLLERR;
} else {
if (data_ready_now)
- mask |= POLLIN | POLLRDNORM;
+ mask |= EPOLLIN | EPOLLRDNORM;
}
}
@@ -912,35 +1205,35 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock,
*/
if (sk->sk_shutdown & RCV_SHUTDOWN ||
vsk->peer_shutdown & SEND_SHUTDOWN) {
- mask |= POLLIN | POLLRDNORM;
+ mask |= EPOLLIN | EPOLLRDNORM;
}
/* Connected sockets that can produce data can be written. */
- if (sk->sk_state == SS_CONNECTED) {
+ if (transport && sk->sk_state == TCP_ESTABLISHED) {
if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
bool space_avail_now = false;
int ret = transport->notify_poll_out(
vsk, 1, &space_avail_now);
if (ret < 0) {
- mask |= POLLERR;
+ mask |= EPOLLERR;
} else {
if (space_avail_now)
- /* Remove POLLWRBAND since INET
+ /* Remove EPOLLWRBAND since INET
* sockets are not setting it.
*/
- mask |= POLLOUT | POLLWRNORM;
+ mask |= EPOLLOUT | EPOLLWRNORM;
}
}
}
/* Simulate INET socket poll behaviors, which sets
- * POLLOUT|POLLWRNORM when peer is closed and nothing to read,
+ * EPOLLOUT|EPOLLWRNORM when peer is closed and nothing to read,
* but local send is not shutdown.
*/
- if (sk->sk_state == SS_UNCONNECTED) {
+ if (sk->sk_state == TCP_CLOSE || sk->sk_state == TCP_CLOSING) {
if (!(sk->sk_shutdown & SEND_SHUTDOWN))
- mask |= POLLOUT | POLLWRNORM;
+ mask |= EPOLLOUT | EPOLLWRNORM;
}
@@ -950,13 +1243,24 @@ static unsigned int vsock_poll(struct file *file, struct socket *sock,
return mask;
}
-static int vsock_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,
- struct msghdr *msg, size_t len)
+static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor)
+{
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ if (WARN_ON_ONCE(!vsk->transport))
+ return -ENODEV;
+
+ return vsk->transport->read_skb(vsk, read_actor);
+}
+
+static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
{
int err;
struct sock *sk;
struct vsock_sock *vsk;
struct sockaddr_vm *remote_addr;
+ const struct vsock_transport *transport;
if (msg->msg_flags & MSG_OOB)
return -EOPNOTSUPP;
@@ -968,6 +1272,8 @@ static int vsock_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,
lock_sock(sk);
+ transport = vsk->transport;
+
err = vsock_auto_bind(vsk);
if (err)
goto out;
@@ -1014,7 +1320,7 @@ static int vsock_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,
goto out;
}
- err = transport->dgram_enqueue(vsk, remote_addr, msg->msg_iov, len);
+ err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
out:
release_sock(sk);
@@ -1022,7 +1328,7 @@ out:
}
static int vsock_dgram_connect(struct socket *sock,
- struct sockaddr *addr, int addr_len, int flags)
+ struct sockaddr_unsized *addr, int addr_len, int flags)
{
int err;
struct sock *sk;
@@ -1049,8 +1355,8 @@ static int vsock_dgram_connect(struct socket *sock,
if (err)
goto out;
- if (!transport->dgram_allow(remote_addr->svm_cid,
- remote_addr->svm_port)) {
+ if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
+ remote_addr->svm_port)) {
err = -EINVAL;
goto out;
}
@@ -1058,16 +1364,117 @@ static int vsock_dgram_connect(struct socket *sock,
memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
sock->state = SS_CONNECTED;
+ /* sock map disallows redirection of non-TCP sockets with sk_state !=
+ * TCP_ESTABLISHED (see sock_map_redirect_allowed()), so we set
+ * TCP_ESTABLISHED here to allow redirection of connected vsock dgrams.
+ *
+ * This doesn't seem to be abnormal state for datagram sockets, as the
+ * same approach can be see in other datagram socket types as well
+ * (such as unix sockets).
+ */
+ sk->sk_state = TCP_ESTABLISHED;
+
out:
release_sock(sk);
return err;
}
-static int vsock_dgram_recvmsg(struct kiocb *kiocb, struct socket *sock,
- struct msghdr *msg, size_t len, int flags)
+int __vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
+ size_t len, int flags)
+{
+ struct sock *sk = sock->sk;
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
+}
+
+int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
+ size_t len, int flags)
+{
+#ifdef CONFIG_BPF_SYSCALL
+ struct sock *sk = sock->sk;
+ const struct proto *prot;
+
+ prot = READ_ONCE(sk->sk_prot);
+ if (prot != &vsock_proto)
+ return prot->recvmsg(sk, msg, len, flags, NULL);
+#endif
+
+ return __vsock_dgram_recvmsg(sock, msg, len, flags);
+}
+EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg);
+
+static int vsock_do_ioctl(struct socket *sock, unsigned int cmd,
+ int __user *arg)
+{
+ struct sock *sk = sock->sk;
+ struct vsock_sock *vsk;
+ int ret;
+
+ vsk = vsock_sk(sk);
+
+ switch (cmd) {
+ case SIOCINQ: {
+ ssize_t n_bytes;
+
+ if (!vsk->transport) {
+ ret = -EOPNOTSUPP;
+ break;
+ }
+
+ if (sock_type_connectible(sk->sk_type) &&
+ sk->sk_state == TCP_LISTEN) {
+ ret = -EINVAL;
+ break;
+ }
+
+ n_bytes = vsock_stream_has_data(vsk);
+ if (n_bytes < 0) {
+ ret = n_bytes;
+ break;
+ }
+ ret = put_user(n_bytes, arg);
+ break;
+ }
+ case SIOCOUTQ: {
+ ssize_t n_bytes;
+
+ if (!vsk->transport || !vsk->transport->unsent_bytes) {
+ ret = -EOPNOTSUPP;
+ break;
+ }
+
+ if (sock_type_connectible(sk->sk_type) && sk->sk_state == TCP_LISTEN) {
+ ret = -EINVAL;
+ break;
+ }
+
+ n_bytes = vsk->transport->unsent_bytes(vsk);
+ if (n_bytes < 0) {
+ ret = n_bytes;
+ break;
+ }
+
+ ret = put_user(n_bytes, arg);
+ break;
+ }
+ default:
+ ret = -ENOIOCTLCMD;
+ }
+
+ return ret;
+}
+
+static int vsock_ioctl(struct socket *sock, unsigned int cmd,
+ unsigned long arg)
{
- return transport->dgram_dequeue(kiocb, vsock_sk(sock->sk), msg, len,
- flags);
+ int ret;
+
+ lock_sock(sock->sk);
+ ret = vsock_do_ioctl(sock, cmd, (int __user *)arg);
+ release_sock(sock->sk);
+
+ return ret;
}
static const struct proto_ops vsock_dgram_ops = {
@@ -1080,43 +1487,54 @@ static const struct proto_ops vsock_dgram_ops = {
.accept = sock_no_accept,
.getname = vsock_getname,
.poll = vsock_poll,
- .ioctl = sock_no_ioctl,
+ .ioctl = vsock_ioctl,
.listen = sock_no_listen,
.shutdown = vsock_shutdown,
- .setsockopt = sock_no_setsockopt,
- .getsockopt = sock_no_getsockopt,
.sendmsg = vsock_dgram_sendmsg,
.recvmsg = vsock_dgram_recvmsg,
.mmap = sock_no_mmap,
- .sendpage = sock_no_sendpage,
+ .read_skb = vsock_read_skb,
};
+static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
+{
+ const struct vsock_transport *transport = vsk->transport;
+
+ if (!transport || !transport->cancel_pkt)
+ return -EOPNOTSUPP;
+
+ return transport->cancel_pkt(vsk);
+}
+
static void vsock_connect_timeout(struct work_struct *work)
{
struct sock *sk;
struct vsock_sock *vsk;
- vsk = container_of(work, struct vsock_sock, dwork.work);
+ vsk = container_of(work, struct vsock_sock, connect_work.work);
sk = sk_vsock(vsk);
lock_sock(sk);
- if (sk->sk_state == SS_CONNECTING &&
+ if (sk->sk_state == TCP_SYN_SENT &&
(sk->sk_shutdown != SHUTDOWN_MASK)) {
- sk->sk_state = SS_UNCONNECTED;
+ sk->sk_state = TCP_CLOSE;
+ sk->sk_socket->state = SS_UNCONNECTED;
sk->sk_err = ETIMEDOUT;
- sk->sk_error_report(sk);
+ sk_error_report(sk);
+ vsock_transport_cancel_pkt(vsk);
}
release_sock(sk);
sock_put(sk);
}
-static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
- int addr_len, int flags)
+static int vsock_connect(struct socket *sock, struct sockaddr_unsized *addr,
+ int addr_len, int flags)
{
int err;
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
struct sockaddr_vm *remote_addr;
long timeout;
DEFINE_WAIT(wait);
@@ -1143,37 +1561,62 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
* non-blocking call.
*/
err = -EALREADY;
+ if (flags & O_NONBLOCK)
+ goto out;
break;
default:
- if ((sk->sk_state == SS_LISTEN) ||
+ if ((sk->sk_state == TCP_LISTEN) ||
vsock_addr_cast(addr, addr_len, &remote_addr) != 0) {
err = -EINVAL;
goto out;
}
+ /* Set the remote address that we are connecting to. */
+ memcpy(&vsk->remote_addr, remote_addr,
+ sizeof(vsk->remote_addr));
+
+ err = vsock_assign_transport(vsk, NULL);
+ if (err)
+ goto out;
+
+ transport = vsk->transport;
+
/* The hypervisor and well-known contexts do not have socket
* endpoints.
*/
- if (!transport->stream_allow(remote_addr->svm_cid,
+ if (!transport ||
+ !transport->stream_allow(remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -ENETUNREACH;
goto out;
}
- /* Set the remote address that we are connecting to. */
- memcpy(&vsk->remote_addr, remote_addr,
- sizeof(vsk->remote_addr));
+ if (vsock_msgzerocopy_allow(transport)) {
+ set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
+ } else if (sock_flag(sk, SOCK_ZEROCOPY)) {
+ /* If this option was set before 'connect()',
+ * when transport was unknown, check that this
+ * feature is supported here.
+ */
+ err = -EOPNOTSUPP;
+ goto out;
+ }
err = vsock_auto_bind(vsk);
if (err)
goto out;
- sk->sk_state = SS_CONNECTING;
+ sk->sk_state = TCP_SYN_SENT;
err = transport->connect(vsk);
if (err < 0)
goto out;
+ /* sk_err might have been set as a result of an earlier
+ * (failed) connect attempt.
+ */
+ sk->sk_err = 0;
+
/* Mark sock as connecting and set the error code to in
* progress in case this is a non-blocking connect.
*/
@@ -1188,7 +1631,11 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
timeout = vsk->connect_timeout;
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
- while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) {
+ /* If the socket is already closing or it is in an error state, there
+ * is no point in waiting.
+ */
+ while (sk->sk_state != TCP_ESTABLISHED &&
+ sk->sk_state != TCP_CLOSING && sk->sk_err == 0) {
if (flags & O_NONBLOCK) {
/* If we're not going to block, we schedule a timeout
* function to generate a timeout on the connection
@@ -1197,9 +1644,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
* timeout fires.
*/
sock_hold(sk);
- INIT_DELAYED_WORK(&vsk->dwork,
- vsock_connect_timeout);
- schedule_delayed_work(&vsk->dwork, timeout);
+
+ /* If the timeout function is already scheduled,
+ * reschedule it, then ungrab the socket refcount to
+ * keep it balanced.
+ */
+ if (mod_delayed_work(system_percpu_wq, &vsk->connect_work,
+ timeout))
+ sock_put(sk);
/* Skip ahead to preserve error code set above. */
goto out_wait;
@@ -1209,12 +1661,41 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
timeout = schedule_timeout(timeout);
lock_sock(sk);
- if (signal_pending(current)) {
- err = sock_intr_errno(timeout);
- goto out_wait_error;
- } else if (timeout == 0) {
- err = -ETIMEDOUT;
- goto out_wait_error;
+ /* Connection established. Whatever happens to socket once we
+ * release it, that's not connect()'s concern. No need to go
+ * into signal and timeout handling. Call it a day.
+ *
+ * Note that allowing to "reset" an already established socket
+ * here is racy and insecure.
+ */
+ if (sk->sk_state == TCP_ESTABLISHED)
+ break;
+
+ /* If connection was _not_ established and a signal/timeout came
+ * to be, we want the socket's state reset. User space may want
+ * to retry.
+ *
+ * sk_state != TCP_ESTABLISHED implies that socket is not on
+ * vsock_connected_table. We keep the binding and the transport
+ * assigned.
+ */
+ if (signal_pending(current) || timeout == 0) {
+ err = timeout == 0 ? -ETIMEDOUT : sock_intr_errno(timeout);
+
+ /* Listener might have already responded with
+ * VIRTIO_VSOCK_OP_RESPONSE. Its handling expects our
+ * sk_state == TCP_SYN_SENT, which hereby we break.
+ * In such case VIRTIO_VSOCK_OP_RST will follow.
+ */
+ sk->sk_state = TCP_CLOSE;
+ sock->state = SS_UNCONNECTED;
+
+ /* Try to cancel VIRTIO_VSOCK_OP_REQUEST skb sent out by
+ * transport->connect().
+ */
+ vsock_transport_cancel_pkt(vsk);
+
+ goto out_wait;
}
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
@@ -1222,23 +1703,21 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
if (sk->sk_err) {
err = -sk->sk_err;
- goto out_wait_error;
- } else
+ sk->sk_state = TCP_CLOSE;
+ sock->state = SS_UNCONNECTED;
+ } else {
err = 0;
+ }
out_wait:
finish_wait(sk_sleep(sk), &wait);
out:
release_sock(sk);
return err;
-
-out_wait_error:
- sk->sk_state = SS_UNCONNECTED;
- sock->state = SS_UNCONNECTED;
- goto out_wait;
}
-static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
+static int vsock_accept(struct socket *sock, struct socket *newsock,
+ struct proto_accept_arg *arg)
{
struct sock *listener;
int err;
@@ -1252,12 +1731,12 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
lock_sock(listener);
- if (sock->type != SOCK_STREAM) {
+ if (!sock_type_connectible(sock->type)) {
err = -EOPNOTSUPP;
goto out;
}
- if (listener->sk_state != SS_LISTEN) {
+ if (listener->sk_state != TCP_LISTEN) {
err = -EINVAL;
goto out;
}
@@ -1265,33 +1744,35 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
/* Wait for children sockets to appear; these are the new sockets
* created upon connection establishment.
*/
- timeout = sock_sndtimeo(listener, flags & O_NONBLOCK);
+ timeout = sock_rcvtimeo(listener, arg->flags & O_NONBLOCK);
prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
while ((connected = vsock_dequeue_accept(listener)) == NULL &&
listener->sk_err == 0) {
release_sock(listener);
timeout = schedule_timeout(timeout);
+ finish_wait(sk_sleep(listener), &wait);
lock_sock(listener);
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
- goto out_wait;
+ goto out;
} else if (timeout == 0) {
err = -EAGAIN;
- goto out_wait;
+ goto out;
}
prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
}
+ finish_wait(sk_sleep(listener), &wait);
if (listener->sk_err)
err = -listener->sk_err;
if (connected) {
- listener->sk_ack_backlog--;
+ sk_acceptq_removed(listener);
- lock_sock(connected);
+ lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
vconnected = vsock_sk(connected);
/* If the listener socket has received an error, then we should
@@ -1303,19 +1784,18 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
*/
if (err) {
vconnected->rejected = true;
- release_sock(connected);
- sock_put(connected);
- goto out_wait;
+ } else {
+ newsock->state = SS_CONNECTED;
+ sock_graft(connected, newsock);
+ if (vsock_msgzerocopy_allow(vconnected->transport))
+ set_bit(SOCK_SUPPORT_ZC,
+ &connected->sk_socket->flags);
}
- newsock->state = SS_CONNECTED;
- sock_graft(connected, newsock);
release_sock(connected);
sock_put(connected);
}
-out_wait:
- finish_wait(sk_sleep(listener), &wait);
out:
release_sock(listener);
return err;
@@ -1331,7 +1811,7 @@ static int vsock_listen(struct socket *sock, int backlog)
lock_sock(sk);
- if (sock->type != SOCK_STREAM) {
+ if (!sock_type_connectible(sk->sk_type)) {
err = -EOPNOTSUPP;
goto out;
}
@@ -1349,7 +1829,7 @@ static int vsock_listen(struct socket *sock, int backlog)
}
sk->sk_max_ack_backlog = backlog;
- sk->sk_state = SS_LISTEN;
+ sk->sk_state = TCP_LISTEN;
err = 0;
@@ -1358,18 +1838,36 @@ out:
return err;
}
-static int vsock_stream_setsockopt(struct socket *sock,
- int level,
- int optname,
- char __user *optval,
- unsigned int optlen)
+static void vsock_update_buffer_size(struct vsock_sock *vsk,
+ const struct vsock_transport *transport,
+ u64 val)
+{
+ if (val > vsk->buffer_max_size)
+ val = vsk->buffer_max_size;
+
+ if (val < vsk->buffer_min_size)
+ val = vsk->buffer_min_size;
+
+ if (val != vsk->buffer_size &&
+ transport && transport->notify_buffer_size)
+ transport->notify_buffer_size(vsk, &val);
+
+ vsk->buffer_size = val;
+}
+
+static int vsock_connectible_setsockopt(struct socket *sock,
+ int level,
+ int optname,
+ sockptr_t optval,
+ unsigned int optlen)
{
int err;
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
u64 val;
- if (level != AF_VSOCK)
+ if (level != AF_VSOCK && level != SOL_SOCKET)
return -ENOPROTOOPT;
#define COPY_IN(_v) \
@@ -1378,7 +1876,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
err = -EINVAL; \
goto exit; \
} \
- if (copy_from_user(&_v, optval, sizeof(_v)) != 0) { \
+ if (copy_from_sockptr(&_v, optval, sizeof(_v)) != 0) { \
err = -EFAULT; \
goto exit; \
} \
@@ -1390,29 +1888,65 @@ static int vsock_stream_setsockopt(struct socket *sock,
lock_sock(sk);
+ transport = vsk->transport;
+
+ if (level == SOL_SOCKET) {
+ int zerocopy;
+
+ if (optname != SO_ZEROCOPY) {
+ release_sock(sk);
+ return sock_setsockopt(sock, level, optname, optval, optlen);
+ }
+
+ /* Use 'int' type here, because variable to
+ * set this option usually has this type.
+ */
+ COPY_IN(zerocopy);
+
+ if (zerocopy < 0 || zerocopy > 1) {
+ err = -EINVAL;
+ goto exit;
+ }
+
+ if (transport && !vsock_msgzerocopy_allow(transport)) {
+ err = -EOPNOTSUPP;
+ goto exit;
+ }
+
+ sock_valbool_flag(sk, SOCK_ZEROCOPY, zerocopy);
+ goto exit;
+ }
+
switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
COPY_IN(val);
- transport->set_buffer_size(vsk, val);
+ vsock_update_buffer_size(vsk, transport, val);
break;
case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
COPY_IN(val);
- transport->set_max_buffer_size(vsk, val);
+ vsk->buffer_max_size = val;
+ vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
break;
case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
COPY_IN(val);
- transport->set_min_buffer_size(vsk, val);
+ vsk->buffer_min_size = val;
+ vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
break;
- case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
- struct timeval tv;
- COPY_IN(tv);
+ case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW:
+ case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: {
+ struct __kernel_sock_timeval tv;
+
+ err = sock_copy_user_timeval(&tv, optval, optlen,
+ optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
+ if (err)
+ break;
if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC &&
tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) {
vsk->connect_timeout = tv.tv_sec * HZ +
- DIV_ROUND_UP(tv.tv_usec, (1000000 / HZ));
+ DIV_ROUND_UP((unsigned long)tv.tv_usec, (USEC_PER_SEC / HZ));
if (vsk->connect_timeout == 0)
vsk->connect_timeout =
VSOCK_DEFAULT_CONNECT_TIMEOUT;
@@ -1435,88 +1969,79 @@ exit:
return err;
}
-static int vsock_stream_getsockopt(struct socket *sock,
- int level, int optname,
- char __user *optval,
- int __user *optlen)
+static int vsock_connectible_getsockopt(struct socket *sock,
+ int level, int optname,
+ char __user *optval,
+ int __user *optlen)
{
- int err;
+ struct sock *sk = sock->sk;
+ struct vsock_sock *vsk = vsock_sk(sk);
+
+ union {
+ u64 val64;
+ struct old_timeval32 tm32;
+ struct __kernel_old_timeval tm;
+ struct __kernel_sock_timeval stm;
+ } v;
+
+ int lv = sizeof(v.val64);
int len;
- struct sock *sk;
- struct vsock_sock *vsk;
- u64 val;
if (level != AF_VSOCK)
return -ENOPROTOOPT;
- err = get_user(len, optlen);
- if (err != 0)
- return err;
+ if (get_user(len, optlen))
+ return -EFAULT;
-#define COPY_OUT(_v) \
- do { \
- if (len < sizeof(_v)) \
- return -EINVAL; \
- \
- len = sizeof(_v); \
- if (copy_to_user(optval, &_v, len) != 0) \
- return -EFAULT; \
- \
- } while (0)
-
- err = 0;
- sk = sock->sk;
- vsk = vsock_sk(sk);
+ memset(&v, 0, sizeof(v));
switch (optname) {
case SO_VM_SOCKETS_BUFFER_SIZE:
- val = transport->get_buffer_size(vsk);
- COPY_OUT(val);
+ v.val64 = vsk->buffer_size;
break;
case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
- val = transport->get_max_buffer_size(vsk);
- COPY_OUT(val);
+ v.val64 = vsk->buffer_max_size;
break;
case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
- val = transport->get_min_buffer_size(vsk);
- COPY_OUT(val);
+ v.val64 = vsk->buffer_min_size;
break;
- case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
- struct timeval tv;
- tv.tv_sec = vsk->connect_timeout / HZ;
- tv.tv_usec =
- (vsk->connect_timeout -
- tv.tv_sec * HZ) * (1000000 / HZ);
- COPY_OUT(tv);
+ case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW:
+ case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD:
+ lv = sock_get_timeout(vsk->connect_timeout, &v,
+ optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
break;
- }
+
default:
return -ENOPROTOOPT;
}
- err = put_user(len, optlen);
- if (err != 0)
+ if (len < lv)
+ return -EINVAL;
+ if (len > lv)
+ len = lv;
+ if (copy_to_user(optval, &v, len))
return -EFAULT;
-#undef COPY_OUT
+ if (put_user(len, optlen))
+ return -EFAULT;
return 0;
}
-static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
- struct msghdr *msg, size_t len)
+static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
+ size_t len)
{
struct sock *sk;
struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
ssize_t total_written;
long timeout;
int err;
struct vsock_transport_send_notify_data send_data;
-
- DEFINE_WAIT(wait);
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
sk = sock->sk;
vsk = vsock_sk(sk);
@@ -1528,9 +2053,13 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
lock_sock(sk);
- /* Callers should not provide a destination with stream sockets. */
+ transport = vsk->transport;
+
+ /* Callers should not provide a destination with connection oriented
+ * sockets.
+ */
if (msg->msg_namelen) {
- err = sk->sk_state == SS_CONNECTED ? -EISCONN : -EOPNOTSUPP;
+ err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
goto out;
}
@@ -1541,7 +2070,7 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
goto out;
}
- if (sk->sk_state != SS_CONNECTED ||
+ if (!transport || sk->sk_state != TCP_ESTABLISHED ||
!vsock_addr_bound(&vsk->local_addr)) {
err = -ENOTCONN;
goto out;
@@ -1552,6 +2081,12 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
goto out;
}
+ if (msg->msg_flags & MSG_ZEROCOPY &&
+ !vsock_msgzerocopy_allow(transport)) {
+ err = -EOPNOTSUPP;
+ goto out;
+ }
+
/* Wait for room in the produce queue to enqueue our user's data. */
timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
@@ -1559,11 +2094,10 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
if (err < 0)
goto out;
- prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
-
while (total_written < len) {
ssize_t written;
+ add_wait_queue(sk_sleep(sk), &wait);
while (vsock_stream_has_space(vsk) == 0 &&
sk->sk_err == 0 &&
!(sk->sk_shutdown & SEND_SHUTDOWN) &&
@@ -1572,27 +2106,30 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
/* Don't wait for non-blocking sockets. */
if (timeout == 0) {
err = -EAGAIN;
- goto out_wait;
+ remove_wait_queue(sk_sleep(sk), &wait);
+ goto out_err;
}
err = transport->notify_send_pre_block(vsk, &send_data);
- if (err < 0)
- goto out_wait;
+ if (err < 0) {
+ remove_wait_queue(sk_sleep(sk), &wait);
+ goto out_err;
+ }
release_sock(sk);
- timeout = schedule_timeout(timeout);
+ timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout);
lock_sock(sk);
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
- goto out_wait;
+ remove_wait_queue(sk_sleep(sk), &wait);
+ goto out_err;
} else if (timeout == 0) {
err = -EAGAIN;
- goto out_wait;
+ remove_wait_queue(sk_sleep(sk), &wait);
+ goto out_err;
}
-
- prepare_to_wait(sk_sleep(sk), &wait,
- TASK_INTERRUPTIBLE);
}
+ remove_wait_queue(sk_sleep(sk), &wait);
/* These checks occur both as part of and after the loop
* conditional since we need to check before and after
@@ -1600,16 +2137,16 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
*/
if (sk->sk_err) {
err = -sk->sk_err;
- goto out_wait;
+ goto out_err;
} else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
(vsk->peer_shutdown & RCV_SHUTDOWN)) {
err = -EPIPE;
- goto out_wait;
+ goto out_err;
}
err = transport->notify_send_pre_enqueue(vsk, &send_data);
if (err < 0)
- goto out_wait;
+ goto out_err;
/* Note that enqueue will only write as many bytes as are free
* in the produce queue, so we don't need to ensure len is
@@ -1617,12 +2154,17 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
* responsibility to check how many bytes we were able to send.
*/
- written = transport->stream_enqueue(
- vsk, msg->msg_iov,
- len - total_written);
+ if (sk->sk_type == SOCK_SEQPACKET) {
+ written = transport->seqpacket_enqueue(vsk,
+ msg, len - total_written);
+ } else {
+ written = transport->stream_enqueue(vsk,
+ msg, len - total_written);
+ }
+
if (written < 0) {
- err = -ENOMEM;
- goto out_wait;
+ err = written;
+ goto out_err;
}
total_written += written;
@@ -1630,78 +2172,109 @@ static int vsock_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
err = transport->notify_send_post_enqueue(
vsk, written, &send_data);
if (err < 0)
- goto out_wait;
+ goto out_err;
}
-out_wait:
- if (total_written > 0)
- err = total_written;
- finish_wait(sk_sleep(sk), &wait);
+out_err:
+ if (total_written > 0) {
+ /* Return number of written bytes only if:
+ * 1) SOCK_STREAM socket.
+ * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+ */
+ if (sk->sk_type == SOCK_STREAM || total_written == len)
+ err = total_written;
+ }
out:
+ if (sk->sk_type == SOCK_STREAM)
+ err = sk_stream_error(sk, msg->msg_flags, err);
+
release_sock(sk);
return err;
}
-
-static int
-vsock_stream_recvmsg(struct kiocb *kiocb,
- struct socket *sock,
- struct msghdr *msg, size_t len, int flags)
+static int vsock_connectible_wait_data(struct sock *sk,
+ struct wait_queue_entry *wait,
+ long timeout,
+ struct vsock_transport_recv_notify_data *recv_data,
+ size_t target)
{
- struct sock *sk;
+ const struct vsock_transport *transport;
struct vsock_sock *vsk;
+ s64 data;
int err;
- size_t target;
- ssize_t copied;
- long timeout;
- struct vsock_transport_recv_notify_data recv_data;
-
- DEFINE_WAIT(wait);
- sk = sock->sk;
vsk = vsock_sk(sk);
err = 0;
+ transport = vsk->transport;
- msg->msg_namelen = 0;
+ while (1) {
+ prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
+ data = vsock_connectible_has_data(vsk);
+ if (data != 0)
+ break;
+
+ if (sk->sk_err != 0 ||
+ (sk->sk_shutdown & RCV_SHUTDOWN) ||
+ (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+ break;
+ }
- lock_sock(sk);
+ /* Don't wait for non-blocking sockets. */
+ if (timeout == 0) {
+ err = -EAGAIN;
+ break;
+ }
- if (sk->sk_state != SS_CONNECTED) {
- /* Recvmsg is supposed to return 0 if a peer performs an
- * orderly shutdown. Differentiate between that case and when a
- * peer has not connected or a local shutdown occured with the
- * SOCK_DONE flag.
- */
- if (sock_flag(sk, SOCK_DONE))
- err = 0;
- else
- err = -ENOTCONN;
+ if (recv_data) {
+ err = transport->notify_recv_pre_block(vsk, target, recv_data);
+ if (err < 0)
+ break;
+ }
- goto out;
- }
+ release_sock(sk);
+ timeout = schedule_timeout(timeout);
+ lock_sock(sk);
- if (flags & MSG_OOB) {
- err = -EOPNOTSUPP;
- goto out;
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeout);
+ break;
+ } else if (timeout == 0) {
+ err = -EAGAIN;
+ break;
+ }
}
- /* We don't check peer_shutdown flag here since peer may actually shut
- * down, but there can be data in the queue that a local socket can
- * receive.
- */
- if (sk->sk_shutdown & RCV_SHUTDOWN) {
- err = 0;
- goto out;
- }
+ finish_wait(sk_sleep(sk), wait);
- /* It is valid on Linux to pass in a zero-length receive buffer. This
- * is not an error. We may as well bail out now.
+ if (err)
+ return err;
+
+ /* Internal transport error when checking for available
+ * data. XXX This should be changed to a connection
+ * reset in a later change.
*/
- if (!len) {
- err = 0;
- goto out;
- }
+ if (data < 0)
+ return -ENOMEM;
+
+ return data;
+}
+
+static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+ size_t len, int flags)
+{
+ struct vsock_transport_recv_notify_data recv_data;
+ const struct vsock_transport *transport;
+ struct vsock_sock *vsk;
+ ssize_t copied;
+ size_t target;
+ long timeout;
+ int err;
+
+ DEFINE_WAIT(wait);
+
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
/* We must not copy less than target bytes into the user's buffer
* before returning successfully, so we wait for the consume queue to
@@ -1721,137 +2294,259 @@ vsock_stream_recvmsg(struct kiocb *kiocb,
if (err < 0)
goto out;
- prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
while (1) {
- s64 ready = vsock_stream_has_data(vsk);
+ ssize_t read;
- if (ready < 0) {
- /* Invalid queue pair content. XXX This should be
- * changed to a connection reset in a later change.
- */
+ err = vsock_connectible_wait_data(sk, &wait, timeout,
+ &recv_data, target);
+ if (err <= 0)
+ break;
- err = -ENOMEM;
- goto out_wait;
- } else if (ready > 0) {
- ssize_t read;
+ err = transport->notify_recv_pre_dequeue(vsk, target,
+ &recv_data);
+ if (err < 0)
+ break;
- err = transport->notify_recv_pre_dequeue(
- vsk, target, &recv_data);
- if (err < 0)
- break;
+ read = transport->stream_dequeue(vsk, msg, len - copied, flags);
+ if (read < 0) {
+ err = read;
+ break;
+ }
- read = transport->stream_dequeue(
- vsk, msg->msg_iov,
- len - copied, flags);
- if (read < 0) {
- err = -ENOMEM;
- break;
- }
+ copied += read;
- copied += read;
+ err = transport->notify_recv_post_dequeue(vsk, target, read,
+ !(flags & MSG_PEEK), &recv_data);
+ if (err < 0)
+ goto out;
- err = transport->notify_recv_post_dequeue(
- vsk, target, read,
- !(flags & MSG_PEEK), &recv_data);
- if (err < 0)
- goto out_wait;
+ if (read >= target || flags & MSG_PEEK)
+ break;
- if (read >= target || flags & MSG_PEEK)
- break;
+ target -= read;
+ }
- target -= read;
- } else {
- if (sk->sk_err != 0 || (sk->sk_shutdown & RCV_SHUTDOWN)
- || (vsk->peer_shutdown & SEND_SHUTDOWN)) {
- break;
- }
- /* Don't wait for non-blocking sockets. */
- if (timeout == 0) {
- err = -EAGAIN;
- break;
- }
+ if (sk->sk_err)
+ err = -sk->sk_err;
+ else if (sk->sk_shutdown & RCV_SHUTDOWN)
+ err = 0;
- err = transport->notify_recv_pre_block(
- vsk, target, &recv_data);
- if (err < 0)
- break;
+ if (copied > 0)
+ err = copied;
- release_sock(sk);
- timeout = schedule_timeout(timeout);
- lock_sock(sk);
+out:
+ return err;
+}
- if (signal_pending(current)) {
- err = sock_intr_errno(timeout);
- break;
- } else if (timeout == 0) {
- err = -EAGAIN;
- break;
- }
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+ size_t len, int flags)
+{
+ const struct vsock_transport *transport;
+ struct vsock_sock *vsk;
+ ssize_t msg_len;
+ long timeout;
+ int err = 0;
+ DEFINE_WAIT(wait);
- prepare_to_wait(sk_sleep(sk), &wait,
- TASK_INTERRUPTIBLE);
- }
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
+
+ timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
+ err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
+ if (err <= 0)
+ goto out;
+
+ msg_len = transport->seqpacket_dequeue(vsk, msg, flags);
+
+ if (msg_len < 0) {
+ err = msg_len;
+ goto out;
}
- if (sk->sk_err)
+ if (sk->sk_err) {
err = -sk->sk_err;
- else if (sk->sk_shutdown & RCV_SHUTDOWN)
+ } else if (sk->sk_shutdown & RCV_SHUTDOWN) {
err = 0;
+ } else {
+ /* User sets MSG_TRUNC, so return real length of
+ * packet.
+ */
+ if (flags & MSG_TRUNC)
+ err = msg_len;
+ else
+ err = len - msg_data_left(msg);
- if (copied > 0) {
- /* We only do these additional bookkeeping/notification steps
- * if we actually copied something out of the queue pair
- * instead of just peeking ahead.
+ /* Always set MSG_TRUNC if real length of packet is
+ * bigger than user's buffer.
*/
+ if (msg_len > len)
+ msg->msg_flags |= MSG_TRUNC;
+ }
- if (!(flags & MSG_PEEK)) {
- /* If the other side has shutdown for sending and there
- * is nothing more to read, then modify the socket
- * state.
- */
- if (vsk->peer_shutdown & SEND_SHUTDOWN) {
- if (vsock_stream_has_data(vsk) <= 0) {
- sk->sk_state = SS_UNCONNECTED;
- sock_set_flag(sk, SOCK_DONE);
- sk->sk_state_change(sk);
- }
- }
- }
- err = copied;
+out:
+ return err;
+}
+
+int
+__vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
+ int flags)
+{
+ struct sock *sk;
+ struct vsock_sock *vsk;
+ const struct vsock_transport *transport;
+ int err;
+
+ sk = sock->sk;
+
+ if (unlikely(flags & MSG_ERRQUEUE))
+ return sock_recv_errqueue(sk, msg, len, SOL_VSOCK, VSOCK_RECVERR);
+
+ vsk = vsock_sk(sk);
+ err = 0;
+
+ lock_sock(sk);
+
+ transport = vsk->transport;
+
+ if (!transport || sk->sk_state != TCP_ESTABLISHED) {
+ /* Recvmsg is supposed to return 0 if a peer performs an
+ * orderly shutdown. Differentiate between that case and when a
+ * peer has not connected or a local shutdown occurred with the
+ * SOCK_DONE flag.
+ */
+ if (sock_flag(sk, SOCK_DONE))
+ err = 0;
+ else
+ err = -ENOTCONN;
+
+ goto out;
}
-out_wait:
- finish_wait(sk_sleep(sk), &wait);
+ if (flags & MSG_OOB) {
+ err = -EOPNOTSUPP;
+ goto out;
+ }
+
+ /* We don't check peer_shutdown flag here since peer may actually shut
+ * down, but there can be data in the queue that a local socket can
+ * receive.
+ */
+ if (sk->sk_shutdown & RCV_SHUTDOWN) {
+ err = 0;
+ goto out;
+ }
+
+ /* It is valid on Linux to pass in a zero-length receive buffer. This
+ * is not an error. We may as well bail out now.
+ */
+ if (!len) {
+ err = 0;
+ goto out;
+ }
+
+ if (sk->sk_type == SOCK_STREAM)
+ err = __vsock_stream_recvmsg(sk, msg, len, flags);
+ else
+ err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
+
out:
release_sock(sk);
return err;
}
+int
+vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
+ int flags)
+{
+#ifdef CONFIG_BPF_SYSCALL
+ struct sock *sk = sock->sk;
+ const struct proto *prot;
+
+ prot = READ_ONCE(sk->sk_prot);
+ if (prot != &vsock_proto)
+ return prot->recvmsg(sk, msg, len, flags, NULL);
+#endif
+
+ return __vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg);
+
+static int vsock_set_rcvlowat(struct sock *sk, int val)
+{
+ const struct vsock_transport *transport;
+ struct vsock_sock *vsk;
+
+ vsk = vsock_sk(sk);
+
+ if (val > vsk->buffer_size)
+ return -EINVAL;
+
+ transport = vsk->transport;
+
+ if (transport && transport->notify_set_rcvlowat) {
+ int err;
+
+ err = transport->notify_set_rcvlowat(vsk, val);
+ if (err)
+ return err;
+ }
+
+ WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
+ return 0;
+}
+
static const struct proto_ops vsock_stream_ops = {
.family = PF_VSOCK,
.owner = THIS_MODULE,
.release = vsock_release,
.bind = vsock_bind,
- .connect = vsock_stream_connect,
+ .connect = vsock_connect,
+ .socketpair = sock_no_socketpair,
+ .accept = vsock_accept,
+ .getname = vsock_getname,
+ .poll = vsock_poll,
+ .ioctl = vsock_ioctl,
+ .listen = vsock_listen,
+ .shutdown = vsock_shutdown,
+ .setsockopt = vsock_connectible_setsockopt,
+ .getsockopt = vsock_connectible_getsockopt,
+ .sendmsg = vsock_connectible_sendmsg,
+ .recvmsg = vsock_connectible_recvmsg,
+ .mmap = sock_no_mmap,
+ .set_rcvlowat = vsock_set_rcvlowat,
+ .read_skb = vsock_read_skb,
+};
+
+static const struct proto_ops vsock_seqpacket_ops = {
+ .family = PF_VSOCK,
+ .owner = THIS_MODULE,
+ .release = vsock_release,
+ .bind = vsock_bind,
+ .connect = vsock_connect,
.socketpair = sock_no_socketpair,
.accept = vsock_accept,
.getname = vsock_getname,
.poll = vsock_poll,
- .ioctl = sock_no_ioctl,
+ .ioctl = vsock_ioctl,
.listen = vsock_listen,
.shutdown = vsock_shutdown,
- .setsockopt = vsock_stream_setsockopt,
- .getsockopt = vsock_stream_getsockopt,
- .sendmsg = vsock_stream_sendmsg,
- .recvmsg = vsock_stream_recvmsg,
+ .setsockopt = vsock_connectible_setsockopt,
+ .getsockopt = vsock_connectible_getsockopt,
+ .sendmsg = vsock_connectible_sendmsg,
+ .recvmsg = vsock_connectible_recvmsg,
.mmap = sock_no_mmap,
- .sendpage = sock_no_sendpage,
+ .read_skb = vsock_read_skb,
};
static int vsock_create(struct net *net, struct socket *sock,
int protocol, int kern)
{
+ struct vsock_sock *vsk;
+ struct sock *sk;
+ int ret;
+
if (!sock)
return -EINVAL;
@@ -1865,13 +2560,39 @@ static int vsock_create(struct net *net, struct socket *sock,
case SOCK_STREAM:
sock->ops = &vsock_stream_ops;
break;
+ case SOCK_SEQPACKET:
+ sock->ops = &vsock_seqpacket_ops;
+ break;
default:
return -ESOCKTNOSUPPORT;
}
sock->state = SS_UNCONNECTED;
- return __vsock_create(net, sock, NULL, GFP_KERNEL, 0) ? 0 : -ENOMEM;
+ sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
+ if (!sk)
+ return -ENOMEM;
+
+ vsk = vsock_sk(sk);
+
+ if (sock->type == SOCK_DGRAM) {
+ ret = vsock_assign_transport(vsk, NULL);
+ if (ret < 0) {
+ sock->sk = NULL;
+ sock_put(sk);
+ return ret;
+ }
+ }
+
+ /* SOCK_DGRAM doesn't have 'setsockopt' callback set in its
+ * proto_ops, so there is no handler for custom logic.
+ */
+ if (sock_type_connectible(sock->type))
+ set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags);
+
+ vsock_insert_unbound(vsk);
+
+ return 0;
}
static const struct net_proto_family vsock_family_ops = {
@@ -1885,16 +2606,25 @@ static long vsock_dev_do_ioctl(struct file *filp,
{
u32 __user *p = ptr;
int retval = 0;
+ u32 cid;
switch (cmd) {
case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
- if (put_user(transport->get_local_cid(), p) != 0)
+ /* To be compatible with the VMCI behavior, we prioritize the
+ * guest CID instead of well-know host CID (VMADDR_CID_HOST).
+ */
+ cid = vsock_registered_transport_cid(&transport_g2h);
+ if (cid == VMADDR_CID_ANY)
+ cid = vsock_registered_transport_cid(&transport_h2g);
+ if (cid == VMADDR_CID_ANY)
+ cid = vsock_registered_transport_cid(&transport_local);
+
+ if (put_user(cid, p) != 0)
retval = -EFAULT;
break;
default:
- pr_err("Unknown ioctl %d\n", cmd);
- retval = -EINVAL;
+ retval = -ENOIOCTLCMD;
}
return retval;
@@ -1928,23 +2658,24 @@ static struct miscdevice vsock_device = {
.fops = &vsock_device_ops,
};
-static int __vsock_core_init(void)
+static int __init vsock_init(void)
{
- int err;
+ int err = 0;
vsock_init_tables();
+ vsock_proto.owner = THIS_MODULE;
vsock_device.minor = MISC_DYNAMIC_MINOR;
err = misc_register(&vsock_device);
if (err) {
pr_err("Failed to register misc device\n");
- return -ENOENT;
+ goto err_reset_transport;
}
err = proto_register(&vsock_proto, 1); /* we want our slab */
if (err) {
pr_err("Cannot register vsock protocol\n");
- goto err_misc_deregister;
+ goto err_deregister_misc;
}
err = sock_register(&vsock_family_ops);
@@ -1954,54 +2685,111 @@ static int __vsock_core_init(void)
goto err_unregister_proto;
}
+ vsock_bpf_build_proto();
+
return 0;
err_unregister_proto:
proto_unregister(&vsock_proto);
-err_misc_deregister:
+err_deregister_misc:
misc_deregister(&vsock_device);
+err_reset_transport:
return err;
}
-int vsock_core_init(const struct vsock_transport *t)
+static void __exit vsock_exit(void)
+{
+ misc_deregister(&vsock_device);
+ sock_unregister(AF_VSOCK);
+ proto_unregister(&vsock_proto);
+}
+
+const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
{
- int retval = mutex_lock_interruptible(&vsock_register_mutex);
- if (retval)
- return retval;
+ return vsk->transport;
+}
+EXPORT_SYMBOL_GPL(vsock_core_get_transport);
- if (transport) {
- retval = -EBUSY;
- goto out;
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+ const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local;
+ int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+ if (err)
+ return err;
+
+ t_h2g = transport_h2g;
+ t_g2h = transport_g2h;
+ t_dgram = transport_dgram;
+ t_local = transport_local;
+
+ if (features & VSOCK_TRANSPORT_F_H2G) {
+ if (t_h2g) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_h2g = t;
}
- transport = t;
- retval = __vsock_core_init();
- if (retval)
- transport = NULL;
+ if (features & VSOCK_TRANSPORT_F_G2H) {
+ if (t_g2h) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_g2h = t;
+ }
-out:
+ if (features & VSOCK_TRANSPORT_F_DGRAM) {
+ if (t_dgram) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_dgram = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_LOCAL) {
+ if (t_local) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_local = t;
+ }
+
+ transport_h2g = t_h2g;
+ transport_g2h = t_g2h;
+ transport_dgram = t_dgram;
+ transport_local = t_local;
+
+err_busy:
mutex_unlock(&vsock_register_mutex);
- return retval;
+ return err;
}
-EXPORT_SYMBOL_GPL(vsock_core_init);
+EXPORT_SYMBOL_GPL(vsock_core_register);
-void vsock_core_exit(void)
+void vsock_core_unregister(const struct vsock_transport *t)
{
mutex_lock(&vsock_register_mutex);
- misc_deregister(&vsock_device);
- sock_unregister(AF_VSOCK);
- proto_unregister(&vsock_proto);
+ if (transport_h2g == t)
+ transport_h2g = NULL;
+
+ if (transport_g2h == t)
+ transport_g2h = NULL;
- /* We do not want the assignment below re-ordered. */
- mb();
- transport = NULL;
+ if (transport_dgram == t)
+ transport_dgram = NULL;
+
+ if (transport_local == t)
+ transport_local = NULL;
mutex_unlock(&vsock_register_mutex);
}
-EXPORT_SYMBOL_GPL(vsock_core_exit);
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
+
+module_init(vsock_init);
+module_exit(vsock_exit);
MODULE_AUTHOR("VMware, Inc.");
MODULE_DESCRIPTION("VMware Virtual Socket Family");
-MODULE_VERSION("1.0.0.0-k");
+MODULE_VERSION("1.0.2.0-k");
MODULE_LICENSE("GPL v2");