diff options
Diffstat (limited to 'net/vmw_vsock/vmci_transport.c')
| -rw-r--r-- | net/vmw_vsock/vmci_transport.c | 208 |
1 files changed, 91 insertions, 117 deletions
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index c361ce782412..7eccd6708d66 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -1,16 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * VMware vSockets Driver * * Copyright (C) 2007-2013 VMware, Inc. All rights reserved. - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the Free - * Software Foundation version 2 and no later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. */ #include <linux/types.h> @@ -65,6 +57,7 @@ static bool vmci_transport_old_proto_override(bool *old_pkt_proto); static u16 vmci_transport_new_proto_supported_versions(void); static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto, bool old_pkt_proto); +static bool vmci_check_transport(struct vsock_sock *vsk); struct vmci_transport_recv_pkt_info { struct work_struct work; @@ -82,14 +75,7 @@ static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; static int PROTOCOL_OVERRIDE = -1; -#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN 128 -#define VMCI_TRANSPORT_DEFAULT_QP_SIZE 262144 -#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX 262144 - -/* The default peer timeout indicates how long we will wait for a peer response - * to a control message. - */ -#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) +static struct vsock_transport vmci_transport; /* forward declaration */ /* Helper function to convert from a VMCI error code to a VSock error code. */ @@ -133,6 +119,8 @@ vmci_transport_packet_init(struct vmci_transport_packet *pkt, u16 proto, struct vmci_handle handle) { + memset(pkt, 0, sizeof(*pkt)); + /* We register the stream control handler as an any cid handle so we * must always send from a source address of VMADDR_CID_ANY */ @@ -145,8 +133,6 @@ vmci_transport_packet_init(struct vmci_transport_packet *pkt, pkt->type = type; pkt->src_port = src->svm_port; pkt->dst_port = dst->svm_port; - memset(&pkt->proto, 0, sizeof(pkt->proto)); - memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2)); switch (pkt->type) { case VMCI_TRANSPORT_PACKET_TYPE_INVALID: @@ -584,8 +570,7 @@ vmci_transport_queue_pair_alloc(struct vmci_qp **qpair, peer, flags, VMCI_NO_PRIVILEGE_FLAGS); out: if (err < 0) { - pr_err("Could not attach to queue pair with %d\n", - err); + pr_err_once("Could not attach to queue pair with %d\n", err); err = vmci_transport_error_to_vsock_error(err); } @@ -664,7 +649,7 @@ static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg) static bool vmci_transport_stream_allow(u32 cid, u32 port) { static const u32 non_socket_contexts[] = { - VMADDR_CID_RESERVED, + VMADDR_CID_LOCAL, }; int i; @@ -848,7 +833,7 @@ static void vmci_transport_handle_detach(struct sock *sk) sk->sk_state = TCP_CLOSE; sk->sk_err = ECONNRESET; - sk->sk_error_report(sk); + sk_error_report(sk); return; } sk->sk_state = TCP_CLOSE; @@ -899,7 +884,8 @@ static void vmci_transport_qp_resumed_cb(u32 sub_id, const struct vmci_event_data *e_data, void *client_data) { - vsock_for_each_connected_socket(vmci_transport_handle_detach); + vsock_for_each_connected_socket(&vmci_transport, + vmci_transport_handle_detach); } static void vmci_transport_recv_pkt_work(struct work_struct *work) @@ -961,13 +947,11 @@ static int vmci_transport_recv_listen(struct sock *sk, bool old_request = false; bool old_pkt_proto = false; - err = 0; - /* Because we are in the listen state, we could be receiving a packet * for ourself or any previous connection requests that we received. * If it's the latter, we try to find a socket in our list of pending * connections and, if we do, call the appropriate handler for the - * state that that socket is in. Otherwise we try to service the + * state that socket is in. Otherwise we try to service the * connection request. */ pending = vmci_transport_get_pending(sk, pkt); @@ -1021,8 +1005,7 @@ static int vmci_transport_recv_listen(struct sock *sk, return -ECONNREFUSED; } - pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, - sk->sk_type, 0); + pending = vsock_create_connected(sk); if (!pending) { vmci_transport_send_reset(sk, pkt); return -ENOMEM; @@ -1035,14 +1018,24 @@ static int vmci_transport_recv_listen(struct sock *sk, vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, pkt->src_port); + err = vsock_assign_transport(vpending, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. + */ + if (err || !vmci_check_transport(vpending)) { + vmci_transport_send_reset(sk, pkt); + sock_put(pending); + return err; + } + /* If the proposed size fits within our min/max, accept it. Otherwise * propose our own size. */ - if (pkt->u.size >= vmci_trans(vpending)->queue_pair_min_size && - pkt->u.size <= vmci_trans(vpending)->queue_pair_max_size) { + if (pkt->u.size >= vpending->buffer_min_size && + pkt->u.size <= vpending->buffer_max_size) { qp_size = pkt->u.size; } else { - qp_size = vmci_trans(vpending)->queue_pair_size; + qp_size = vpending->buffer_size; } /* Figure out if we are using old or new requests based on the @@ -1106,12 +1099,12 @@ static int vmci_transport_recv_listen(struct sock *sk, } vsock_add_pending(sk, pending); - sk->sk_ack_backlog++; + sk_acceptq_added(sk); pending->sk_state = TCP_SYN_SENT; vmci_trans(vpending)->produce_size = vmci_trans(vpending)->consume_size = qp_size; - vmci_trans(vpending)->queue_pair_size = qp_size; + vpending->buffer_size = qp_size; vmci_trans(vpending)->notify_ops->process_request(pending); @@ -1258,7 +1251,7 @@ vmci_transport_recv_connecting_server(struct sock *listener, vsock_remove_pending(listener, pending); vsock_enqueue_accept(listener, pending); - /* Callers of accept() will be be waiting on the listening socket, not + /* Callers of accept() will be waiting on the listening socket, not * the pending socket. */ listener->sk_data_ready(listener); @@ -1375,7 +1368,7 @@ destroy: sk->sk_state = TCP_CLOSE; sk->sk_err = skerr; - sk->sk_error_report(sk); + sk_error_report(sk); return err; } @@ -1405,8 +1398,8 @@ static int vmci_transport_recv_connecting_client_negotiate( vsk->ignore_connecting_rst = false; /* Verify that we're OK with the proposed queue pair size */ - if (pkt->u.size < vmci_trans(vsk)->queue_pair_min_size || - pkt->u.size > vmci_trans(vsk)->queue_pair_max_size) { + if (pkt->u.size < vsk->buffer_min_size || + pkt->u.size > vsk->buffer_max_size) { err = -EINVAL; goto destroy; } @@ -1511,8 +1504,7 @@ vmci_transport_recv_connecting_client_invalid(struct sock *sk, vsk->sent_request = false; vsk->ignore_connecting_rst = true; - err = vmci_transport_send_conn_request( - sk, vmci_trans(vsk)->queue_pair_size); + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); if (err < 0) err = vmci_transport_error_to_vsock_error(err); else @@ -1596,21 +1588,6 @@ static int vmci_transport_socket_init(struct vsock_sock *vsk, INIT_LIST_HEAD(&vmci_trans(vsk)->elem); vmci_trans(vsk)->sk = &vsk->sk; spin_lock_init(&vmci_trans(vsk)->lock); - if (psk) { - vmci_trans(vsk)->queue_pair_size = - vmci_trans(psk)->queue_pair_size; - vmci_trans(vsk)->queue_pair_min_size = - vmci_trans(psk)->queue_pair_min_size; - vmci_trans(vsk)->queue_pair_max_size = - vmci_trans(psk)->queue_pair_max_size; - } else { - vmci_trans(vsk)->queue_pair_size = - VMCI_TRANSPORT_DEFAULT_QP_SIZE; - vmci_trans(vsk)->queue_pair_min_size = - VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN; - vmci_trans(vsk)->queue_pair_max_size = - VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX; - } return 0; } @@ -1651,6 +1628,10 @@ static void vmci_transport_cleanup(struct work_struct *work) static void vmci_transport_destruct(struct vsock_sock *vsk) { + /* transport can be NULL if we hit a failure at init() time */ + if (!vmci_trans(vsk)) + return; + /* Ensure that the detach callback doesn't use the sk/vsk * we are about to destruct. */ @@ -1730,7 +1711,11 @@ static int vmci_transport_dgram_enqueue( if (!dg) return -ENOMEM; - memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len); + err = memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len); + if (err) { + kfree(dg); + return err; + } dg->dst = vmci_make_handle(remote_addr->svm_cid, remote_addr->svm_port); @@ -1751,19 +1736,16 @@ static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk, int flags) { int err; - int noblock; struct vmci_datagram *dg; size_t payload_len; struct sk_buff *skb; - noblock = flags & MSG_DONTWAIT; - if (flags & MSG_OOB || flags & MSG_ERRQUEUE) return -EOPNOTSUPP; /* Retrieve the head sk_buff from the socket's receive queue. */ err = 0; - skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err); + skb = skb_recv_datagram(&vsk->sk, flags, &err); if (!skb) return err; @@ -1822,8 +1804,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) if (vmci_transport_old_proto_override(&old_pkt_proto) && old_pkt_proto) { - err = vmci_transport_send_conn_request( - sk, vmci_trans(vsk)->queue_pair_size); + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); if (err < 0) { sk->sk_state = TCP_CLOSE; return err; @@ -1831,8 +1812,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) } else { int supported_proto_versions = vmci_transport_new_proto_supported_versions(); - err = vmci_transport_send_conn_request2( - sk, vmci_trans(vsk)->queue_pair_size, + err = vmci_transport_send_conn_request2(sk, vsk->buffer_size, supported_proto_versions); if (err < 0) { sk->sk_state = TCP_CLOSE; @@ -1851,10 +1831,17 @@ static ssize_t vmci_transport_stream_dequeue( size_t len, int flags) { + ssize_t err; + if (flags & MSG_PEEK) - return vmci_qpair_peekv(vmci_trans(vsk)->qpair, msg, len, 0); + err = vmci_qpair_peekv(vmci_trans(vsk)->qpair, msg, len, 0); else - return vmci_qpair_dequev(vmci_trans(vsk)->qpair, msg, len, 0); + err = vmci_qpair_dequev(vmci_trans(vsk)->qpair, msg, len, 0); + + if (err < 0) + err = -ENOMEM; + + return err; } static ssize_t vmci_transport_stream_enqueue( @@ -1862,7 +1849,13 @@ static ssize_t vmci_transport_stream_enqueue( struct msghdr *msg, size_t len) { - return vmci_qpair_enquev(vmci_trans(vsk)->qpair, msg, len, 0); + ssize_t err; + + err = vmci_qpair_enquev(vmci_trans(vsk)->qpair, msg, len, 0); + if (err < 0) + err = -ENOMEM; + + return err; } static s64 vmci_transport_stream_has_data(struct vsock_sock *vsk) @@ -1885,46 +1878,6 @@ static bool vmci_transport_stream_is_active(struct vsock_sock *vsk) return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle); } -static u64 vmci_transport_get_buffer_size(struct vsock_sock *vsk) -{ - return vmci_trans(vsk)->queue_pair_size; -} - -static u64 vmci_transport_get_min_buffer_size(struct vsock_sock *vsk) -{ - return vmci_trans(vsk)->queue_pair_min_size; -} - -static u64 vmci_transport_get_max_buffer_size(struct vsock_sock *vsk) -{ - return vmci_trans(vsk)->queue_pair_max_size; -} - -static void vmci_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) -{ - if (val < vmci_trans(vsk)->queue_pair_min_size) - vmci_trans(vsk)->queue_pair_min_size = val; - if (val > vmci_trans(vsk)->queue_pair_max_size) - vmci_trans(vsk)->queue_pair_max_size = val; - vmci_trans(vsk)->queue_pair_size = val; -} - -static void vmci_transport_set_min_buffer_size(struct vsock_sock *vsk, - u64 val) -{ - if (val > vmci_trans(vsk)->queue_pair_size) - vmci_trans(vsk)->queue_pair_size = val; - vmci_trans(vsk)->queue_pair_min_size = val; -} - -static void vmci_transport_set_max_buffer_size(struct vsock_sock *vsk, - u64 val) -{ - if (val < vmci_trans(vsk)->queue_pair_size) - vmci_trans(vsk)->queue_pair_size = val; - vmci_trans(vsk)->queue_pair_max_size = val; -} - static int vmci_transport_notify_poll_in( struct vsock_sock *vsk, size_t target, @@ -2080,7 +2033,8 @@ static u32 vmci_transport_get_local_cid(void) return vmci_get_context_id(); } -static const struct vsock_transport vmci_transport = { +static struct vsock_transport vmci_transport = { + .module = THIS_MODULE, .init = vmci_transport_socket_init, .destruct = vmci_transport_destruct, .release = vmci_transport_release, @@ -2107,15 +2061,26 @@ static const struct vsock_transport vmci_transport = { .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue, .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue, .shutdown = vmci_transport_shutdown, - .set_buffer_size = vmci_transport_set_buffer_size, - .set_min_buffer_size = vmci_transport_set_min_buffer_size, - .set_max_buffer_size = vmci_transport_set_max_buffer_size, - .get_buffer_size = vmci_transport_get_buffer_size, - .get_min_buffer_size = vmci_transport_get_min_buffer_size, - .get_max_buffer_size = vmci_transport_get_max_buffer_size, .get_local_cid = vmci_transport_get_local_cid, }; +static bool vmci_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &vmci_transport; +} + +static void vmci_vsock_transport_cb(bool is_host) +{ + int features; + + if (is_host) + features = VSOCK_TRANSPORT_F_H2G; + else + features = VSOCK_TRANSPORT_F_G2H; + + vsock_core_register(&vmci_transport, features); +} + static int __init vmci_transport_init(void) { int err; @@ -2132,7 +2097,6 @@ static int __init vmci_transport_init(void) pr_err("Unable to create datagram handle. (%d)\n", err); return vmci_transport_error_to_vsock_error(err); } - err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED, vmci_transport_qp_resumed_cb, NULL, &vmci_transport_qp_resumed_sub_id); @@ -2143,12 +2107,21 @@ static int __init vmci_transport_init(void) goto err_destroy_stream_handle; } - err = vsock_core_init(&vmci_transport); + /* Register only with dgram feature, other features (H2G, G2H) will be + * registered when the first host or guest becomes active. + */ + err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM); if (err < 0) goto err_unsubscribe; + err = vmci_register_vsock_callback(vmci_vsock_transport_cb); + if (err < 0) + goto err_unregister; + return 0; +err_unregister: + vsock_core_unregister(&vmci_transport); err_unsubscribe: vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); err_destroy_stream_handle: @@ -2174,7 +2147,8 @@ static void __exit vmci_transport_exit(void) vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; } - vsock_core_exit(); + vmci_register_vsock_callback(NULL); + vsock_core_unregister(&vmci_transport); } module_exit(vmci_transport_exit); |
