diff options
Diffstat (limited to 'net/rxrpc/af_rxrpc.c')
| -rw-r--r-- | net/rxrpc/af_rxrpc.c | 876 |
1 files changed, 567 insertions, 309 deletions
diff --git a/net/rxrpc/af_rxrpc.c b/net/rxrpc/af_rxrpc.c index e61aa6001c65..0c2c68c4b07e 100644 --- a/net/rxrpc/af_rxrpc.c +++ b/net/rxrpc/af_rxrpc.c @@ -1,25 +1,25 @@ +// SPDX-License-Identifier: GPL-2.0-or-later /* AF_RXRPC implementation * * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. * Written by David Howells (dhowells@redhat.com) - * - * This program is free software; you can redistribute it and/or - * modify it under the terms of the GNU General Public License - * as published by the Free Software Foundation; either version - * 2 of the License, or (at your option) any later version. */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + #include <linux/module.h> #include <linux/kernel.h> #include <linux/net.h> #include <linux/slab.h> #include <linux/skbuff.h> +#include <linux/random.h> #include <linux/poll.h> #include <linux/proc_fs.h> #include <linux/key-type.h> #include <net/net_namespace.h> #include <net/sock.h> #include <net/af_rxrpc.h> +#define CREATE_TRACE_POINTS #include "ar-internal.h" MODULE_DESCRIPTION("RxRPC network protocol"); @@ -28,22 +28,18 @@ MODULE_LICENSE("GPL"); MODULE_ALIAS_NETPROTO(PF_RXRPC); unsigned int rxrpc_debug; // = RXRPC_DEBUG_KPROTO; -module_param_named(debug, rxrpc_debug, uint, S_IWUSR | S_IRUGO); +module_param_named(debug, rxrpc_debug, uint, 0644); MODULE_PARM_DESC(debug, "RxRPC debugging mask"); -static int sysctl_rxrpc_max_qlen __read_mostly = 10; - static struct proto rxrpc_proto; static const struct proto_ops rxrpc_rpc_ops; -/* local epoch for detecting local-end reset */ -__be32 rxrpc_epoch; - /* current debugging ID */ atomic_t rxrpc_debug_id; +EXPORT_SYMBOL(rxrpc_debug_id); /* count of skbs currently in use */ -atomic_t rxrpc_n_skbs; +atomic_t rxrpc_n_rx_skbs; struct workqueue_struct *rxrpc_workqueue; @@ -54,7 +50,7 @@ static void rxrpc_sock_destructor(struct sock *); */ static inline int rxrpc_writable(struct sock *sk) { - return atomic_read(&sk->sk_wmem_alloc) < (size_t) sk->sk_sndbuf; + return refcount_read(&sk->sk_wmem_alloc) < (size_t) sk->sk_sndbuf; } /* @@ -67,9 +63,9 @@ static void rxrpc_write_space(struct sock *sk) if (rxrpc_writable(sk)) { struct socket_wq *wq = rcu_dereference(sk->sk_wq); - if (wq_has_sleeper(wq)) + if (skwq_has_sleeper(wq)) wake_up_interruptible(&wq->wait); - sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + sk_wake_async_rcu(sk, SOCK_WAKE_SPACE, POLL_OUT); } rcu_read_unlock(); } @@ -81,6 +77,8 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx, struct sockaddr_rxrpc *srx, int len) { + unsigned int tail; + if (len < sizeof(struct sockaddr_rxrpc)) return -EINVAL; @@ -95,37 +93,46 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx, srx->transport_len > len) return -EINVAL; - if (srx->transport.family != rx->proto) - return -EAFNOSUPPORT; - switch (srx->transport.family) { case AF_INET: - _debug("INET: %x @ %pI4", - ntohs(srx->transport.sin.sin_port), - &srx->transport.sin.sin_addr); - if (srx->transport_len > 8) - memset((void *)&srx->transport + 8, 0, - srx->transport_len - 8); + if (rx->family != AF_INET && + rx->family != AF_INET6) + return -EAFNOSUPPORT; + if (srx->transport_len < sizeof(struct sockaddr_in)) + return -EINVAL; + tail = offsetof(struct sockaddr_rxrpc, transport.sin.__pad); break; +#ifdef CONFIG_AF_RXRPC_IPV6 case AF_INET6: + if (rx->family != AF_INET6) + return -EAFNOSUPPORT; + if (srx->transport_len < sizeof(struct sockaddr_in6)) + return -EINVAL; + tail = offsetof(struct sockaddr_rxrpc, transport) + + sizeof(struct sockaddr_in6); + break; +#endif + default: return -EAFNOSUPPORT; } + if (tail < len) + memset((void *)srx + tail, 0, len - tail); + _debug("INET: %pISp", &srx->transport); return 0; } /* * bind a local address to an RxRPC socket */ -static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len) +static int rxrpc_bind(struct socket *sock, struct sockaddr_unsized *saddr, int len) { - struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *) saddr; - struct sock *sk = sock->sk; + struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)saddr; struct rxrpc_local *local; - struct rxrpc_sock *rx = rxrpc_sk(sk), *prx; - __be16 service_id; + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + u16 service_id; int ret; _enter("%p,%p,%d", rx, saddr, len); @@ -133,39 +140,52 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len) ret = rxrpc_validate_address(rx, srx, len); if (ret < 0) goto error; + service_id = srx->srx_service; lock_sock(&rx->sk); - if (rx->sk.sk_state != RXRPC_UNCONNECTED) { - ret = -EINVAL; - goto error_unlock; - } - - memcpy(&rx->srx, srx, sizeof(rx->srx)); - - /* find a local transport endpoint if we don't have one already */ - local = rxrpc_lookup_local(&rx->srx); - if (IS_ERR(local)) { - ret = PTR_ERR(local); - goto error_unlock; - } + switch (rx->sk.sk_state) { + case RXRPC_UNBOUND: + rx->srx = *srx; + local = rxrpc_lookup_local(sock_net(&rx->sk), &rx->srx); + if (IS_ERR(local)) { + ret = PTR_ERR(local); + goto error_unlock; + } - rx->local = local; - if (srx->srx_service) { - service_id = htons(srx->srx_service); - write_lock_bh(&local->services_lock); - list_for_each_entry(prx, &local->services, listen_link) { - if (prx->service_id == service_id) + if (service_id) { + write_lock(&local->services_lock); + if (local->service) goto service_in_use; + rx->local = local; + local->service = rx; + write_unlock(&local->services_lock); + + rx->sk.sk_state = RXRPC_SERVER_BOUND; + } else { + rx->local = local; + rx->sk.sk_state = RXRPC_CLIENT_BOUND; } + break; - rx->service_id = service_id; - list_add_tail(&rx->listen_link, &local->services); - write_unlock_bh(&local->services_lock); + case RXRPC_SERVER_BOUND: + ret = -EINVAL; + if (service_id == 0) + goto error_unlock; + ret = -EADDRINUSE; + if (service_id == rx->srx.srx_service) + goto error_unlock; + ret = -EINVAL; + srx->srx_service = rx->srx.srx_service; + if (memcmp(srx, &rx->srx, sizeof(*srx)) != 0) + goto error_unlock; + rx->second_service = service_id; + rx->sk.sk_state = RXRPC_SERVER_BOUND2; + break; - rx->sk.sk_state = RXRPC_SERVER_BOUND; - } else { - rx->sk.sk_state = RXRPC_CLIENT_BOUND; + default: + ret = -EINVAL; + goto error_unlock; } release_sock(&rx->sk); @@ -173,8 +193,10 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len) return 0; service_in_use: + write_unlock(&local->services_lock); + rxrpc_unuse_local(local, rxrpc_local_unuse_bind); + rxrpc_put_local(local, rxrpc_local_put_bind); ret = -EADDRINUSE; - write_unlock_bh(&local->services_lock); error_unlock: release_sock(&rx->sk); error: @@ -189,6 +211,7 @@ static int rxrpc_listen(struct socket *sock, int backlog) { struct sock *sk = sock->sk; struct rxrpc_sock *rx = rxrpc_sk(sk); + unsigned int max, old; int ret; _enter("%p,%d", rx, backlog); @@ -196,19 +219,37 @@ static int rxrpc_listen(struct socket *sock, int backlog) lock_sock(&rx->sk); switch (rx->sk.sk_state) { - case RXRPC_UNCONNECTED: + case RXRPC_UNBOUND: ret = -EADDRNOTAVAIL; break; - case RXRPC_CLIENT_BOUND: - case RXRPC_CLIENT_CONNECTED: - default: - ret = -EBUSY; - break; case RXRPC_SERVER_BOUND: + case RXRPC_SERVER_BOUND2: ASSERT(rx->local != NULL); + max = READ_ONCE(rxrpc_max_backlog); + ret = -EINVAL; + if (backlog == INT_MAX) + backlog = max; + else if (backlog < 0 || backlog > max) + break; + old = sk->sk_max_ack_backlog; sk->sk_max_ack_backlog = backlog; - rx->sk.sk_state = RXRPC_SERVER_LISTENING; - ret = 0; + ret = rxrpc_service_prealloc(rx, GFP_KERNEL); + if (ret == 0) + rx->sk.sk_state = RXRPC_SERVER_LISTENING; + else + sk->sk_max_ack_backlog = old; + break; + case RXRPC_SERVER_LISTENING: + if (backlog == 0) { + rx->sk.sk_state = RXRPC_SERVER_LISTEN_DISABLED; + sk->sk_max_ack_backlog = 0; + rxrpc_discard_prealloc(rx); + ret = 0; + break; + } + fallthrough; + default: + ret = -EBUSY; break; } @@ -217,167 +258,234 @@ static int rxrpc_listen(struct socket *sock, int backlog) return ret; } -/* - * find a transport by address +/** + * rxrpc_kernel_lookup_peer - Obtain remote transport endpoint for an address + * @sock: The socket through which it will be accessed + * @srx: The network address + * @gfp: Allocation flags + * + * Lookup or create a remote transport endpoint record for the specified + * address. + * + * Return: The peer record found with a reference, %NULL if no record is found + * or a negative error code if the address is invalid or unsupported. */ -static struct rxrpc_transport *rxrpc_name_to_transport(struct socket *sock, - struct sockaddr *addr, - int addr_len, int flags, - gfp_t gfp) +struct rxrpc_peer *rxrpc_kernel_lookup_peer(struct socket *sock, + struct sockaddr_rxrpc *srx, gfp_t gfp) { - struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *) addr; - struct rxrpc_transport *trans; struct rxrpc_sock *rx = rxrpc_sk(sock->sk); - struct rxrpc_peer *peer; - - _enter("%p,%p,%d,%d", rx, addr, addr_len, flags); + int ret; - ASSERT(rx->local != NULL); - ASSERT(rx->sk.sk_state > RXRPC_UNCONNECTED); + ret = rxrpc_validate_address(rx, srx, sizeof(*srx)); + if (ret < 0) + return ERR_PTR(ret); - if (rx->srx.transport_type != srx->transport_type) - return ERR_PTR(-ESOCKTNOSUPPORT); - if (rx->srx.transport.family != srx->transport.family) - return ERR_PTR(-EAFNOSUPPORT); + return rxrpc_lookup_peer(rx->local, srx, gfp); +} +EXPORT_SYMBOL(rxrpc_kernel_lookup_peer); - /* find a remote transport endpoint from the local one */ - peer = rxrpc_get_peer(srx, gfp); - if (IS_ERR(peer)) - return ERR_CAST(peer); +/** + * rxrpc_kernel_get_peer - Get a reference on a peer + * @peer: The peer to get a reference on (may be NULL). + * + * Get a reference for a remote peer record (if not NULL). + * + * Return: The @peer argument. + */ +struct rxrpc_peer *rxrpc_kernel_get_peer(struct rxrpc_peer *peer) +{ + return peer ? rxrpc_get_peer(peer, rxrpc_peer_get_application) : NULL; +} +EXPORT_SYMBOL(rxrpc_kernel_get_peer); - /* find a transport */ - trans = rxrpc_get_transport(rx->local, peer, gfp); - rxrpc_put_peer(peer); - _leave(" = %p", trans); - return trans; +/** + * rxrpc_kernel_put_peer - Allow a kernel app to drop a peer reference + * @peer: The peer to drop a ref on + * + * Drop a reference on a peer record. + */ +void rxrpc_kernel_put_peer(struct rxrpc_peer *peer) +{ + rxrpc_put_peer(peer, rxrpc_peer_put_application); } +EXPORT_SYMBOL(rxrpc_kernel_put_peer); /** * rxrpc_kernel_begin_call - Allow a kernel service to begin a call * @sock: The socket on which to make the call - * @srx: The address of the peer to contact (defaults to socket setting) + * @peer: The peer to contact * @key: The security context to use (defaults to socket setting) * @user_call_ID: The ID to use + * @tx_total_len: Total length of data to transmit during the call (or -1) + * @hard_timeout: The maximum lifespan of the call in sec + * @gfp: The allocation constraints + * @notify_rx: Where to send notifications instead of socket queue + * @service_id: The ID of the service to contact + * @upgrade: Request service upgrade for call + * @interruptibility: The call is interruptible, or can be canceled. + * @debug_id: The debug ID for tracing to be assigned to the call * * Allow a kernel service to begin a call on the nominated socket. This just * sets up all the internal tracking structures and allocates connection and - * call IDs as appropriate. The call to be used is returned. + * call IDs as appropriate. * * The default socket destination address and security may be overridden by * supplying @srx and @key. + * + * Return: The new call or an error code. */ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, - struct sockaddr_rxrpc *srx, + struct rxrpc_peer *peer, struct key *key, unsigned long user_call_ID, - gfp_t gfp) + s64 tx_total_len, + u32 hard_timeout, + gfp_t gfp, + rxrpc_notify_rx_t notify_rx, + u16 service_id, + bool upgrade, + enum rxrpc_interruptibility interruptibility, + unsigned int debug_id) { - struct rxrpc_conn_bundle *bundle; - struct rxrpc_transport *trans; + struct rxrpc_conn_parameters cp; + struct rxrpc_call_params p; struct rxrpc_call *call; struct rxrpc_sock *rx = rxrpc_sk(sock->sk); - __be16 service_id; _enter(",,%x,%lx", key_serial(key), user_call_ID); - lock_sock(&rx->sk); + if (WARN_ON_ONCE(peer->local != rx->local)) + return ERR_PTR(-EIO); - if (srx) { - trans = rxrpc_name_to_transport(sock, (struct sockaddr *) srx, - sizeof(*srx), 0, gfp); - if (IS_ERR(trans)) { - call = ERR_CAST(trans); - trans = NULL; - goto out_notrans; - } - } else { - trans = rx->trans; - if (!trans) { - call = ERR_PTR(-ENOTCONN); - goto out_notrans; - } - atomic_inc(&trans->usage); - } - - service_id = rx->service_id; - if (srx) - service_id = htons(srx->srx_service); + lock_sock(&rx->sk); if (!key) key = rx->key; - if (key && !key->payload.data) + if (key && !key->payload.data[0]) key = NULL; /* a no-security key */ - bundle = rxrpc_get_bundle(rx, trans, key, service_id, gfp); - if (IS_ERR(bundle)) { - call = ERR_CAST(bundle); - goto out; + memset(&p, 0, sizeof(p)); + p.user_call_ID = user_call_ID; + p.tx_total_len = tx_total_len; + p.interruptibility = interruptibility; + p.kernel = true; + p.timeouts.hard = hard_timeout; + + memset(&cp, 0, sizeof(cp)); + cp.local = rx->local; + cp.peer = peer; + cp.key = key; + cp.security_level = rx->min_sec_level; + cp.exclusive = false; + cp.upgrade = upgrade; + cp.service_id = service_id; + call = rxrpc_new_client_call(rx, &cp, &p, gfp, debug_id); + /* The socket has been unlocked. */ + if (!IS_ERR(call)) { + call->notify_rx = notify_rx; + mutex_unlock(&call->user_mutex); } - call = rxrpc_get_client_call(rx, trans, bundle, user_call_ID, true, - gfp); - rxrpc_put_bundle(trans, bundle); -out: - rxrpc_put_transport(trans); -out_notrans: - release_sock(&rx->sk); _leave(" = %p", call); return call; } - EXPORT_SYMBOL(rxrpc_kernel_begin_call); +/* + * Dummy function used to stop the notifier talking to recvmsg(). + */ +static void rxrpc_dummy_notify_rx(struct sock *sk, struct rxrpc_call *rxcall, + unsigned long call_user_ID) +{ +} + /** - * rxrpc_kernel_end_call - Allow a kernel service to end a call it was using + * rxrpc_kernel_shutdown_call - Allow a kernel service to shut down a call it was using + * @sock: The socket the call is on * @call: The call to end * - * Allow a kernel service to end a call it was using. The call must be + * Allow a kernel service to shut down a call it was using. The call must be * complete before this is called (the call should be aborted if necessary). */ -void rxrpc_kernel_end_call(struct rxrpc_call *call) +void rxrpc_kernel_shutdown_call(struct socket *sock, struct rxrpc_call *call) +{ + _enter("%d{%d}", call->debug_id, refcount_read(&call->ref)); + + mutex_lock(&call->user_mutex); + if (!test_bit(RXRPC_CALL_RELEASED, &call->flags)) { + rxrpc_release_call(rxrpc_sk(sock->sk), call); + + /* Make sure we're not going to call back into a kernel service */ + if (call->notify_rx) { + spin_lock_irq(&call->notify_lock); + call->notify_rx = rxrpc_dummy_notify_rx; + spin_unlock_irq(&call->notify_lock); + } + } + mutex_unlock(&call->user_mutex); +} +EXPORT_SYMBOL(rxrpc_kernel_shutdown_call); + +/** + * rxrpc_kernel_put_call - Release a reference to a call + * @sock: The socket the call is on + * @call: The call to put + * + * Drop the application's ref on an rxrpc call. + */ +void rxrpc_kernel_put_call(struct socket *sock, struct rxrpc_call *call) { - _enter("%d{%d}", call->debug_id, atomic_read(&call->usage)); - rxrpc_remove_user_ID(call->socket, call); - rxrpc_put_call(call); + rxrpc_put_call(call, rxrpc_call_put_kernel); } +EXPORT_SYMBOL(rxrpc_kernel_put_call); -EXPORT_SYMBOL(rxrpc_kernel_end_call); +/** + * rxrpc_kernel_check_life - Check to see whether a call is still alive + * @sock: The socket the call is on + * @call: The call to check + * + * Allow a kernel service to find out whether a call is still alive - whether + * it has completed successfully and all received data has been consumed. + * + * Return: %true if the call is still ongoing and %false if it has completed. + */ +bool rxrpc_kernel_check_life(const struct socket *sock, + const struct rxrpc_call *call) +{ + if (!rxrpc_call_is_complete(call)) + return true; + if (call->completion != RXRPC_CALL_SUCCEEDED) + return false; + return !skb_queue_empty(&call->recvmsg_queue); +} +EXPORT_SYMBOL(rxrpc_kernel_check_life); /** - * rxrpc_kernel_intercept_rx_messages - Intercept received RxRPC messages - * @sock: The socket to intercept received messages on - * @interceptor: The function to pass the messages to + * rxrpc_kernel_set_notifications - Set table of callback operations + * @sock: The socket to install table upon + * @app_ops: Callback operation table to set * - * Allow a kernel service to intercept messages heading for the Rx queue on an - * RxRPC socket. They get passed to the specified function instead. - * @interceptor should free the socket buffers it is given. @interceptor is - * called with the socket receive queue spinlock held and softirqs disabled - - * this ensures that the messages will be delivered in the right order. + * Allow a kernel service to set a table of event notifications on a socket. */ -void rxrpc_kernel_intercept_rx_messages(struct socket *sock, - rxrpc_interceptor_t interceptor) +void rxrpc_kernel_set_notifications(struct socket *sock, + const struct rxrpc_kernel_ops *app_ops) { struct rxrpc_sock *rx = rxrpc_sk(sock->sk); - _enter(""); - rx->interceptor = interceptor; + rx->app_ops = app_ops; } - -EXPORT_SYMBOL(rxrpc_kernel_intercept_rx_messages); +EXPORT_SYMBOL(rxrpc_kernel_set_notifications); /* * connect an RxRPC socket * - this just targets it at a specific destination; no actual connection * negotiation takes place */ -static int rxrpc_connect(struct socket *sock, struct sockaddr *addr, +static int rxrpc_connect(struct socket *sock, struct sockaddr_unsized *addr, int addr_len, int flags) { - struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *) addr; - struct sock *sk = sock->sk; - struct rxrpc_transport *trans; - struct rxrpc_local *local; - struct rxrpc_sock *rx = rxrpc_sk(sk); + struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)addr; + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); int ret; _enter("%p,%p,%d,%d", rx, addr, addr_len, flags); @@ -390,46 +498,29 @@ static int rxrpc_connect(struct socket *sock, struct sockaddr *addr, lock_sock(&rx->sk); + ret = -EISCONN; + if (test_bit(RXRPC_SOCK_CONNECTED, &rx->flags)) + goto error; + switch (rx->sk.sk_state) { - case RXRPC_UNCONNECTED: - /* find a local transport endpoint if we don't have one already */ - ASSERTCMP(rx->local, ==, NULL); - rx->srx.srx_family = AF_RXRPC; - rx->srx.srx_service = 0; - rx->srx.transport_type = srx->transport_type; - rx->srx.transport_len = sizeof(sa_family_t); - rx->srx.transport.family = srx->transport.family; - local = rxrpc_lookup_local(&rx->srx); - if (IS_ERR(local)) { - release_sock(&rx->sk); - return PTR_ERR(local); - } - rx->local = local; - rx->sk.sk_state = RXRPC_CLIENT_BOUND; + case RXRPC_UNBOUND: + rx->sk.sk_state = RXRPC_CLIENT_UNBOUND; + break; + case RXRPC_CLIENT_UNBOUND: case RXRPC_CLIENT_BOUND: break; - case RXRPC_CLIENT_CONNECTED: - release_sock(&rx->sk); - return -EISCONN; default: - release_sock(&rx->sk); - return -EBUSY; /* server sockets can't connect as well */ - } - - trans = rxrpc_name_to_transport(sock, addr, addr_len, flags, - GFP_KERNEL); - if (IS_ERR(trans)) { - release_sock(&rx->sk); - _leave(" = %ld", PTR_ERR(trans)); - return PTR_ERR(trans); + ret = -EBUSY; + goto error; } - rx->trans = trans; - rx->service_id = htons(srx->srx_service); - rx->sk.sk_state = RXRPC_CLIENT_CONNECTED; + rx->connect_srx = *srx; + set_bit(RXRPC_SOCK_CONNECTED, &rx->flags); + ret = 0; +error: release_sock(&rx->sk); - return 0; + return ret; } /* @@ -441,10 +532,9 @@ static int rxrpc_connect(struct socket *sock, struct sockaddr *addr, * - sends a call data packet * - may send an abort (abort code in control data) */ -static int rxrpc_sendmsg(struct kiocb *iocb, struct socket *sock, - struct msghdr *m, size_t len) +static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len) { - struct rxrpc_transport *trans; + struct rxrpc_local *local; struct rxrpc_sock *rx = rxrpc_sk(sock->sk); int ret; @@ -461,60 +551,87 @@ static int rxrpc_sendmsg(struct kiocb *iocb, struct socket *sock, } } - trans = NULL; lock_sock(&rx->sk); - if (m->msg_name) { - ret = -EISCONN; - trans = rxrpc_name_to_transport(sock, m->msg_name, - m->msg_namelen, 0, GFP_KERNEL); - if (IS_ERR(trans)) { - ret = PTR_ERR(trans); - trans = NULL; - goto out; - } - } else { - trans = rx->trans; - if (trans) - atomic_inc(&trans->usage); - } - switch (rx->sk.sk_state) { - case RXRPC_SERVER_LISTENING: - if (!m->msg_name) { - ret = rxrpc_server_sendmsg(iocb, rx, m, len); + case RXRPC_UNBOUND: + case RXRPC_CLIENT_UNBOUND: + rx->srx.srx_family = AF_RXRPC; + rx->srx.srx_service = 0; + rx->srx.transport_type = SOCK_DGRAM; + rx->srx.transport.family = rx->family; + switch (rx->family) { + case AF_INET: + rx->srx.transport_len = sizeof(struct sockaddr_in); + break; +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + rx->srx.transport_len = sizeof(struct sockaddr_in6); break; +#endif + default: + ret = -EAFNOSUPPORT; + goto error_unlock; } - case RXRPC_SERVER_BOUND: + local = rxrpc_lookup_local(sock_net(sock->sk), &rx->srx); + if (IS_ERR(local)) { + ret = PTR_ERR(local); + goto error_unlock; + } + + rx->local = local; + rx->sk.sk_state = RXRPC_CLIENT_BOUND; + fallthrough; + case RXRPC_CLIENT_BOUND: - if (!m->msg_name) { - ret = -ENOTCONN; - break; + if (!m->msg_name && + test_bit(RXRPC_SOCK_CONNECTED, &rx->flags)) { + m->msg_name = &rx->connect_srx; + m->msg_namelen = sizeof(rx->connect_srx); } - case RXRPC_CLIENT_CONNECTED: - ret = rxrpc_client_sendmsg(iocb, rx, trans, m, len); - break; + fallthrough; + case RXRPC_SERVER_BOUND: + case RXRPC_SERVER_LISTENING: + if (m->msg_flags & MSG_OOB) + ret = rxrpc_sendmsg_oob(rx, m, len); + else + ret = rxrpc_do_sendmsg(rx, m, len); + /* The socket has been unlocked */ + goto out; default: - ret = -ENOTCONN; - break; + ret = -EINVAL; + goto error_unlock; } -out: +error_unlock: release_sock(&rx->sk); - if (trans) - rxrpc_put_transport(trans); +out: _leave(" = %d", ret); return ret; } +int rxrpc_sock_set_min_security_level(struct sock *sk, unsigned int val) +{ + if (sk->sk_state != RXRPC_UNBOUND) + return -EISCONN; + if (val > RXRPC_SECURITY_MAX) + return -EINVAL; + lock_sock(sk); + rxrpc_sk(sk)->min_sec_level = val; + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(rxrpc_sock_set_min_security_level); + /* * set RxRPC socket options */ static int rxrpc_setsockopt(struct socket *sock, int level, int optname, - char __user *optval, unsigned int optlen) + sockptr_t optval, unsigned int optlen) { struct rxrpc_sock *rx = rxrpc_sk(sock->sk); - unsigned int min_sec_level; + unsigned int min_sec_level, val; + u16 service_upgrade[2]; int ret; _enter(",%d,%d,,%d", level, optname, optlen); @@ -529,9 +646,9 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname, if (optlen != 0) goto error; ret = -EISCONN; - if (rx->sk.sk_state != RXRPC_UNCONNECTED) + if (rx->sk.sk_state != RXRPC_UNBOUND) goto error; - set_bit(RXRPC_SOCK_EXCLUSIVE_CONN, &rx->flags); + rx->exclusive = true; goto success; case RXRPC_SECURITY_KEY: @@ -539,7 +656,7 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname, if (rx->key) goto error; ret = -EISCONN; - if (rx->sk.sk_state != RXRPC_UNCONNECTED) + if (rx->sk.sk_state != RXRPC_UNBOUND) goto error; ret = rxrpc_request_key(rx, optval, optlen); goto error; @@ -549,7 +666,7 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname, if (rx->key) goto error; ret = -EISCONN; - if (rx->sk.sk_state != RXRPC_UNCONNECTED) + if (rx->sk.sk_state != RXRPC_UNBOUND) goto error; ret = rxrpc_server_keyring(rx, optval, optlen); goto error; @@ -559,11 +676,12 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname, if (optlen != sizeof(unsigned int)) goto error; ret = -EISCONN; - if (rx->sk.sk_state != RXRPC_UNCONNECTED) + if (rx->sk.sk_state != RXRPC_UNBOUND) goto error; - ret = get_user(min_sec_level, - (unsigned int __user *) optval); - if (ret < 0) + ret = copy_safe_from_sockptr(&min_sec_level, + sizeof(min_sec_level), + optval, optlen); + if (ret) goto error; ret = -EINVAL; if (min_sec_level > RXRPC_SECURITY_MAX) @@ -571,6 +689,48 @@ static int rxrpc_setsockopt(struct socket *sock, int level, int optname, rx->min_sec_level = min_sec_level; goto success; + case RXRPC_UPGRADEABLE_SERVICE: + ret = -EINVAL; + if (optlen != sizeof(service_upgrade) || + rx->service_upgrade.from != 0) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_SERVER_BOUND2) + goto error; + ret = -EFAULT; + if (copy_from_sockptr(service_upgrade, optval, + sizeof(service_upgrade)) != 0) + goto error; + ret = -EINVAL; + if ((service_upgrade[0] != rx->srx.srx_service || + service_upgrade[1] != rx->second_service) && + (service_upgrade[0] != rx->second_service || + service_upgrade[1] != rx->srx.srx_service)) + goto error; + rx->service_upgrade.from = service_upgrade[0]; + rx->service_upgrade.to = service_upgrade[1]; + goto success; + + case RXRPC_MANAGE_RESPONSE: + ret = -EINVAL; + if (optlen != sizeof(unsigned int)) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_UNBOUND) + goto error; + ret = copy_safe_from_sockptr(&val, sizeof(val), + optval, optlen); + if (ret) + goto error; + ret = -EINVAL; + if (val > 1) + goto error; + if (val) + set_bit(RXRPC_SOCK_MANAGE_RESPONSE, &rx->flags); + else + clear_bit(RXRPC_SOCK_MANAGE_RESPONSE, &rx->flags); + goto success; + default: break; } @@ -584,27 +744,56 @@ error: } /* + * Get socket options. + */ +static int rxrpc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *_optlen) +{ + int optlen; + + if (level != SOL_RXRPC) + return -EOPNOTSUPP; + + if (get_user(optlen, _optlen)) + return -EFAULT; + + switch (optname) { + case RXRPC_SUPPORTED_CMSG: + if (optlen < sizeof(int)) + return -ETOOSMALL; + if (put_user(RXRPC__SUPPORTED - 1, (int __user *)optval) || + put_user(sizeof(int), _optlen)) + return -EFAULT; + return 0; + + default: + return -EOPNOTSUPP; + } +} + +/* * permit an RxRPC socket to be polled */ -static unsigned int rxrpc_poll(struct file *file, struct socket *sock, +static __poll_t rxrpc_poll(struct file *file, struct socket *sock, poll_table *wait) { - unsigned int mask; struct sock *sk = sock->sk; + struct rxrpc_sock *rx = rxrpc_sk(sk); + __poll_t mask; - sock_poll_wait(file, sk_sleep(sk), wait); + sock_poll_wait(file, sock, wait); mask = 0; /* the socket is readable if there are any messages waiting on the Rx * queue */ - if (!skb_queue_empty(&sk->sk_receive_queue)) - mask |= POLLIN | POLLRDNORM; + if (!list_empty(&rx->recvmsg_q)) + mask |= EPOLLIN | EPOLLRDNORM; /* the socket is writable if there is space to add new data to the * socket; there is no guarantee that any particular call in progress * on the socket may have space in the Tx ACK window */ if (rxrpc_writable(sk)) - mask |= POLLOUT | POLLWRNORM; + mask |= EPOLLOUT | EPOLLWRNORM; return mask; } @@ -615,16 +804,15 @@ static unsigned int rxrpc_poll(struct file *file, struct socket *sock, static int rxrpc_create(struct net *net, struct socket *sock, int protocol, int kern) { + struct rxrpc_net *rxnet; struct rxrpc_sock *rx; struct sock *sk; _enter("%p,%d", sock, protocol); - if (!net_eq(net, &init_net)) - return -EAFNOSUPPORT; - - /* we support transport protocol UDP only */ - if (protocol != PF_INET) + /* we support transport protocol UDP/UDP6 only */ + if (protocol != PF_INET && + IS_ENABLED(CONFIG_AF_RXRPC_IPV6) && protocol != PF_INET6) return -EPROTONOSUPPORT; if (sock->type != SOCK_DGRAM) @@ -633,40 +821,99 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol, sock->ops = &rxrpc_rpc_ops; sock->state = SS_UNCONNECTED; - sk = sk_alloc(net, PF_RXRPC, GFP_KERNEL, &rxrpc_proto); + sk = sk_alloc(net, PF_RXRPC, GFP_KERNEL, &rxrpc_proto, kern); if (!sk) return -ENOMEM; sock_init_data(sock, sk); - sk->sk_state = RXRPC_UNCONNECTED; + sock_set_flag(sk, SOCK_RCU_FREE); + sk->sk_state = RXRPC_UNBOUND; sk->sk_write_space = rxrpc_write_space; - sk->sk_max_ack_backlog = sysctl_rxrpc_max_qlen; + sk->sk_max_ack_backlog = 0; sk->sk_destruct = rxrpc_sock_destructor; rx = rxrpc_sk(sk); - rx->proto = protocol; + rx->family = protocol; rx->calls = RB_ROOT; - INIT_LIST_HEAD(&rx->listen_link); - INIT_LIST_HEAD(&rx->secureq); - INIT_LIST_HEAD(&rx->acceptq); + spin_lock_init(&rx->incoming_lock); + skb_queue_head_init(&rx->recvmsg_oobq); + rx->pending_oobq = RB_ROOT; + INIT_LIST_HEAD(&rx->sock_calls); + INIT_LIST_HEAD(&rx->to_be_accepted); + INIT_LIST_HEAD(&rx->recvmsg_q); + spin_lock_init(&rx->recvmsg_lock); rwlock_init(&rx->call_lock); memset(&rx->srx, 0, sizeof(rx->srx)); + rxnet = rxrpc_net(sock_net(&rx->sk)); + timer_reduce(&rxnet->peer_keepalive_timer, jiffies + 1); + _leave(" = 0 [%p]", rx); return 0; } /* + * Kill all the calls on a socket and shut it down. + */ +static int rxrpc_shutdown(struct socket *sock, int flags) +{ + struct sock *sk = sock->sk; + struct rxrpc_sock *rx = rxrpc_sk(sk); + int ret = 0; + + _enter("%p,%d", sk, flags); + + if (flags != SHUT_RDWR) + return -EOPNOTSUPP; + if (sk->sk_state == RXRPC_CLOSE) + return -ESHUTDOWN; + + lock_sock(sk); + + if (sk->sk_state < RXRPC_CLOSE) { + spin_lock_irq(&rx->recvmsg_lock); + sk->sk_state = RXRPC_CLOSE; + sk->sk_shutdown = SHUTDOWN_MASK; + spin_unlock_irq(&rx->recvmsg_lock); + } else { + ret = -ESHUTDOWN; + } + + rxrpc_discard_prealloc(rx); + + release_sock(sk); + return ret; +} + +/* + * Purge the out-of-band queue. + */ +static void rxrpc_purge_oob_queue(struct sock *sk) +{ + struct rxrpc_sock *rx = rxrpc_sk(sk); + struct sk_buff *skb; + + while ((skb = skb_dequeue(&rx->recvmsg_oobq))) + rxrpc_kernel_free_oob(skb); + while (!RB_EMPTY_ROOT(&rx->pending_oobq)) { + skb = rb_entry(rx->pending_oobq.rb_node, struct sk_buff, rbnode); + rb_erase(&skb->rbnode, &rx->pending_oobq); + rxrpc_kernel_free_oob(skb); + } +} + +/* * RxRPC socket destructor */ static void rxrpc_sock_destructor(struct sock *sk) { _enter("%p", sk); + rxrpc_purge_oob_queue(sk); rxrpc_purge_queue(&sk->sk_receive_queue); - WARN_ON(atomic_read(&sk->sk_wmem_alloc)); + WARN_ON(refcount_read(&sk->sk_wmem_alloc)); WARN_ON(!sk_unhashed(sk)); WARN_ON(sk->sk_socket); @@ -683,47 +930,45 @@ static int rxrpc_release_sock(struct sock *sk) { struct rxrpc_sock *rx = rxrpc_sk(sk); - _enter("%p{%d,%d}", sk, sk->sk_state, atomic_read(&sk->sk_refcnt)); + _enter("%p{%d,%d}", sk, sk->sk_state, refcount_read(&sk->sk_refcnt)); /* declare the socket closed for business */ sock_orphan(sk); sk->sk_shutdown = SHUTDOWN_MASK; - spin_lock_bh(&sk->sk_receive_queue.lock); - sk->sk_state = RXRPC_CLOSE; - spin_unlock_bh(&sk->sk_receive_queue.lock); + /* We want to kill off all connections from a service socket + * as fast as possible because we can't share these; client + * sockets, on the other hand, can share an endpoint. + */ + switch (sk->sk_state) { + case RXRPC_SERVER_BOUND: + case RXRPC_SERVER_BOUND2: + case RXRPC_SERVER_LISTENING: + case RXRPC_SERVER_LISTEN_DISABLED: + rx->local->service_closed = true; + break; + } - ASSERTCMP(rx->listen_link.next, !=, LIST_POISON1); + spin_lock_irq(&rx->recvmsg_lock); + sk->sk_state = RXRPC_CLOSE; + spin_unlock_irq(&rx->recvmsg_lock); - if (!list_empty(&rx->listen_link)) { - write_lock_bh(&rx->local->services_lock); - list_del(&rx->listen_link); - write_unlock_bh(&rx->local->services_lock); + if (rx->local && rx->local->service == rx) { + write_lock(&rx->local->services_lock); + rx->local->service = NULL; + write_unlock(&rx->local->services_lock); } /* try to flush out this socket */ + rxrpc_discard_prealloc(rx); rxrpc_release_calls_on_socket(rx); flush_workqueue(rxrpc_workqueue); + rxrpc_purge_oob_queue(sk); rxrpc_purge_queue(&sk->sk_receive_queue); - if (rx->conn) { - rxrpc_put_connection(rx->conn); - rx->conn = NULL; - } - - if (rx->bundle) { - rxrpc_put_bundle(rx->trans, rx->bundle); - rx->bundle = NULL; - } - if (rx->trans) { - rxrpc_put_transport(rx->trans); - rx->trans = NULL; - } - if (rx->local) { - rxrpc_put_local(rx->local); - rx->local = NULL; - } - + rxrpc_unuse_local(rx->local, rxrpc_local_unuse_release_sock); + rxrpc_put_local(rx->local, rxrpc_local_put_release_sock); + rx->local = NULL; key_put(rx->key); rx->key = NULL; key_put(rx->securities); @@ -755,7 +1000,7 @@ static int rxrpc_release(struct socket *sock) * RxRPC network protocol */ static const struct proto_ops rxrpc_rpc_ops = { - .family = PF_UNIX, + .family = PF_RXRPC, .owner = THIS_MODULE, .release = rxrpc_release, .bind = rxrpc_bind, @@ -766,20 +1011,19 @@ static const struct proto_ops rxrpc_rpc_ops = { .poll = rxrpc_poll, .ioctl = sock_no_ioctl, .listen = rxrpc_listen, - .shutdown = sock_no_shutdown, + .shutdown = rxrpc_shutdown, .setsockopt = rxrpc_setsockopt, - .getsockopt = sock_no_getsockopt, + .getsockopt = rxrpc_getsockopt, .sendmsg = rxrpc_sendmsg, .recvmsg = rxrpc_recvmsg, .mmap = sock_no_mmap, - .sendpage = sock_no_sendpage, }; static struct proto rxrpc_proto = { .name = "RXRPC", .owner = THIS_MODULE, .obj_size = sizeof(struct rxrpc_sock), - .max_header = sizeof(struct rxrpc_header), + .max_header = sizeof(struct rxrpc_wire_header), }; static const struct net_proto_family rxrpc_family_ops = { @@ -795,56 +1039,68 @@ static int __init af_rxrpc_init(void) { int ret = -1; - BUILD_BUG_ON(sizeof(struct rxrpc_skb_priv) > FIELD_SIZEOF(struct sk_buff, cb)); - - rxrpc_epoch = htonl(get_seconds()); + BUILD_BUG_ON(sizeof(struct rxrpc_skb_priv) > sizeof_field(struct sk_buff, cb)); ret = -ENOMEM; + rxrpc_gen_version_string(); rxrpc_call_jar = kmem_cache_create( "rxrpc_call_jar", sizeof(struct rxrpc_call), 0, SLAB_HWCACHE_ALIGN, NULL); if (!rxrpc_call_jar) { - printk(KERN_NOTICE "RxRPC: Failed to allocate call jar\n"); + pr_notice("Failed to allocate call jar\n"); goto error_call_jar; } - rxrpc_workqueue = alloc_workqueue("krxrpcd", 0, 1); + rxrpc_workqueue = alloc_ordered_workqueue("krxrpcd", WQ_HIGHPRI | WQ_MEM_RECLAIM); if (!rxrpc_workqueue) { - printk(KERN_NOTICE "RxRPC: Failed to allocate work queue\n"); + pr_notice("Failed to allocate work queue\n"); goto error_work_queue; } + ret = rxrpc_init_security(); + if (ret < 0) { + pr_crit("Cannot initialise security\n"); + goto error_security; + } + + ret = register_pernet_device(&rxrpc_net_ops); + if (ret) + goto error_pernet; + ret = proto_register(&rxrpc_proto, 1); if (ret < 0) { - printk(KERN_CRIT "RxRPC: Cannot register protocol\n"); + pr_crit("Cannot register protocol\n"); goto error_proto; } ret = sock_register(&rxrpc_family_ops); if (ret < 0) { - printk(KERN_CRIT "RxRPC: Cannot register socket family\n"); + pr_crit("Cannot register socket family\n"); goto error_sock; } ret = register_key_type(&key_type_rxrpc); if (ret < 0) { - printk(KERN_CRIT "RxRPC: Cannot register client key type\n"); + pr_crit("Cannot register client key type\n"); goto error_key_type; } ret = register_key_type(&key_type_rxrpc_s); if (ret < 0) { - printk(KERN_CRIT "RxRPC: Cannot register server key type\n"); + pr_crit("Cannot register server key type\n"); goto error_key_type_s; } -#ifdef CONFIG_PROC_FS - proc_create("rxrpc_calls", 0, init_net.proc_net, &rxrpc_call_seq_fops); - proc_create("rxrpc_conns", 0, init_net.proc_net, - &rxrpc_connection_seq_fops); -#endif + ret = rxrpc_sysctl_init(); + if (ret < 0) { + pr_crit("Cannot register sysctls\n"); + goto error_sysctls; + } + return 0; +error_sysctls: + unregister_key_type(&key_type_rxrpc_s); error_key_type_s: unregister_key_type(&key_type_rxrpc); error_key_type: @@ -852,6 +1108,10 @@ error_key_type: error_sock: proto_unregister(&rxrpc_proto); error_proto: + unregister_pernet_device(&rxrpc_net_ops); +error_pernet: + rxrpc_exit_security(); +error_security: destroy_workqueue(rxrpc_workqueue); error_work_queue: kmem_cache_destroy(rxrpc_call_jar); @@ -865,23 +1125,21 @@ error_call_jar: static void __exit af_rxrpc_exit(void) { _enter(""); + rxrpc_sysctl_exit(); unregister_key_type(&key_type_rxrpc_s); unregister_key_type(&key_type_rxrpc); sock_unregister(PF_RXRPC); proto_unregister(&rxrpc_proto); - rxrpc_destroy_all_calls(); - rxrpc_destroy_all_connections(); - rxrpc_destroy_all_transports(); - rxrpc_destroy_all_peers(); - rxrpc_destroy_all_locals(); + unregister_pernet_device(&rxrpc_net_ops); + ASSERTCMP(atomic_read(&rxrpc_n_rx_skbs), ==, 0); - ASSERTCMP(atomic_read(&rxrpc_n_skbs), ==, 0); + /* Make sure the local and peer records pinned by any dying connections + * are released. + */ + rcu_barrier(); - _debug("flush scheduled work"); - flush_workqueue(rxrpc_workqueue); - remove_proc_entry("rxrpc_conns", init_net.proc_net); - remove_proc_entry("rxrpc_calls", init_net.proc_net); destroy_workqueue(rxrpc_workqueue); + rxrpc_exit_security(); kmem_cache_destroy(rxrpc_call_jar); _leave(""); } |
