diff options
Diffstat (limited to 'net/sunrpc')
74 files changed, 28281 insertions, 16338 deletions
diff --git a/net/sunrpc/.kunitconfig b/net/sunrpc/.kunitconfig new file mode 100644 index 000000000000..eb02b906c295 --- /dev/null +++ b/net/sunrpc/.kunitconfig @@ -0,0 +1,29 @@ +CONFIG_KUNIT=y +CONFIG_UBSAN=y +CONFIG_STACKTRACE=y +CONFIG_NET=y +CONFIG_NETWORK_FILESYSTEMS=y +CONFIG_INET=y +CONFIG_FILE_LOCKING=y +CONFIG_MULTIUSER=y +CONFIG_CRYPTO=y +CONFIG_CRYPTO_CBC=y +CONFIG_CRYPTO_CTS=y +CONFIG_CRYPTO_ECB=y +CONFIG_CRYPTO_HMAC=y +CONFIG_CRYPTO_CMAC=y +CONFIG_CRYPTO_MD5=y +CONFIG_CRYPTO_SHA1=y +CONFIG_CRYPTO_SHA256=y +CONFIG_CRYPTO_SHA512=y +CONFIG_CRYPTO_DES=y +CONFIG_CRYPTO_AES=y +CONFIG_CRYPTO_CAMELLIA=y +CONFIG_NFS_FS=y +CONFIG_SUNRPC=y +CONFIG_SUNRPC_GSS=y +CONFIG_RPCSEC_GSS_KRB5=y +CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1=y +CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA=y +CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2=y +CONFIG_RPCSEC_GSS_KRB5_KUNIT_TEST=y diff --git a/net/sunrpc/Kconfig b/net/sunrpc/Kconfig index 241b54f30204..a570e7adf270 100644 --- a/net/sunrpc/Kconfig +++ b/net/sunrpc/Kconfig @@ -1,27 +1,17 @@ +# SPDX-License-Identifier: GPL-2.0-only config SUNRPC tristate + depends on MULTIUSER config SUNRPC_GSS tristate select OID_REGISTRY + depends on MULTIUSER config SUNRPC_BACKCHANNEL bool depends on SUNRPC -config SUNRPC_XPRT_RDMA - tristate - depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS - default SUNRPC && INFINIBAND - help - This option allows the NFS client and server to support - an RDMA-enabled transport. - - To compile RPC client RDMA transport support as a module, - choose M here: the module will be called xprtrdma. - - If unsure, say N. - config SUNRPC_SWAP bool depends on SUNRPC @@ -29,11 +19,10 @@ config SUNRPC_SWAP config RPCSEC_GSS_KRB5 tristate "Secure RPC: Kerberos V mechanism" depends on SUNRPC && CRYPTO - depends on CRYPTO_MD5 && CRYPTO_DES && CRYPTO_CBC && CRYPTO_CTS - depends on CRYPTO_ECB && CRYPTO_HMAC && CRYPTO_SHA1 && CRYPTO_AES - depends on CRYPTO_ARC4 default y select SUNRPC_GSS + select CRYPTO_SKCIPHER + select CRYPTO_HASH help Choose Y here to enable Secure RPC using the Kerberos version 5 GSS-API mechanism (RFC 1964). @@ -45,9 +34,63 @@ config RPCSEC_GSS_KRB5 If unsure, say Y. +config RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1 + bool "Enable Kerberos enctypes based on AES and SHA-1" + depends on RPCSEC_GSS_KRB5 + depends on CRYPTO_CBC && CRYPTO_CTS + depends on CRYPTO_HMAC && CRYPTO_SHA1 + depends on CRYPTO_AES + default y + help + Choose Y to enable the use of Kerberos 5 encryption types + that utilize Advanced Encryption Standard (AES) ciphers and + SHA-1 digests. These include aes128-cts-hmac-sha1-96 and + aes256-cts-hmac-sha1-96. + +config RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA + bool "Enable Kerberos encryption types based on Camellia and CMAC" + depends on RPCSEC_GSS_KRB5 + depends on CRYPTO_CBC && CRYPTO_CTS && CRYPTO_CAMELLIA + depends on CRYPTO_CMAC + default n + help + Choose Y to enable the use of Kerberos 5 encryption types + that utilize Camellia ciphers (RFC 3713) and CMAC digests + (NIST Special Publication 800-38B). These include + camellia128-cts-cmac and camellia256-cts-cmac. + +config RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2 + bool "Enable Kerberos enctypes based on AES and SHA-2" + depends on RPCSEC_GSS_KRB5 + depends on CRYPTO_CBC && CRYPTO_CTS + depends on CRYPTO_HMAC && CRYPTO_SHA256 && CRYPTO_SHA512 + depends on CRYPTO_AES + default n + help + Choose Y to enable the use of Kerberos 5 encryption types + that utilize Advanced Encryption Standard (AES) ciphers and + SHA-2 digests. These include aes128-cts-hmac-sha256-128 and + aes256-cts-hmac-sha384-192. + +config RPCSEC_GSS_KRB5_KUNIT_TEST + tristate "KUnit tests for RPCSEC GSS Kerberos" if !KUNIT_ALL_TESTS + depends on RPCSEC_GSS_KRB5 && KUNIT + default KUNIT_ALL_TESTS + help + This builds the KUnit tests for RPCSEC GSS Kerberos 5. + + KUnit tests run during boot and output the results to the debug + log in TAP format (https://testanything.org/). Only useful for + kernel devs running KUnit test harness and are not for inclusion + into a production build. + + For more information on KUnit and unit tests in general, refer + to the KUnit documentation in Documentation/dev-tools/kunit/. + config SUNRPC_DEBUG bool "RPC: Enable dprintk debugging" depends on SUNRPC && SYSCTL + select DEBUG_FS help This option enables a sysctl-based debugging interface that is be used by the 'rpcdebug' utility to turn on or off @@ -57,3 +100,32 @@ config SUNRPC_DEBUG but makes troubleshooting NFS issues significantly harder. If unsure, say Y. + +config SUNRPC_DEBUG_TRACE + bool "RPC: Send dfprintk() output to the trace buffer" + depends on SUNRPC_DEBUG && TRACING + default n + help + dprintk() output can be voluminous, which can overwhelm the + kernel's logging facility as it must be sent to the console. + This option causes dprintk() output to go to the trace buffer + instead of the kernel log. + + This will cause warnings about trace_printk() being used to be + logged at boot time, so say N unless you are debugging a problem + with sunrpc-based clients or services. + +config SUNRPC_XPRT_RDMA + tristate "RPC-over-RDMA transport" + depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS + default SUNRPC && INFINIBAND + select SG_POOL + help + This option allows the NFS client and server to use RDMA + transports (InfiniBand, iWARP, or RoCE). + + To compile this support as a module, choose M. The module + will be called rpcrdma.ko. + + If unsure, or you know there is no RDMA capability on your + hardware platform, say N. diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile index 8209a0411bca..f89c10fe7e6a 100644 --- a/net/sunrpc/Makefile +++ b/net/sunrpc/Makefile @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: GPL-2.0 # # Makefile for Linux kernel SUN RPC # @@ -8,11 +9,13 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_gss/ obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/ sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ - auth.o auth_null.o auth_unix.o auth_generic.o \ + auth.o auth_null.o auth_tls.o auth_unix.o \ svc.o svcsock.o svcauth.o svcauth_unix.o \ addr.o rpcb_clnt.o timer.o xdr.o \ - sunrpc_syms.o cache.o rpc_pipe.o \ - svc_xprt.o -sunrpc-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel_rqst.o bc_svc.o + sunrpc_syms.o cache.o rpc_pipe.o sysfs.o \ + svc_xprt.o \ + xprtmultipath.o +sunrpc-$(CONFIG_SUNRPC_DEBUG) += debugfs.o +sunrpc-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel_rqst.o sunrpc-$(CONFIG_PROC_FS) += stats.o sunrpc-$(CONFIG_SYSCTL) += sysctl.o diff --git a/net/sunrpc/addr.c b/net/sunrpc/addr.c index a622ad64acd8..97ff11973c49 100644 --- a/net/sunrpc/addr.c +++ b/net/sunrpc/addr.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * Copyright 2009, Oracle. All rights reserved. * @@ -81,11 +82,11 @@ static size_t rpc_ntop6(const struct sockaddr *sap, rc = snprintf(scopebuf, sizeof(scopebuf), "%c%u", IPV6_SCOPE_DELIMITER, sin6->sin6_scope_id); - if (unlikely((size_t)rc > sizeof(scopebuf))) + if (unlikely((size_t)rc >= sizeof(scopebuf))) return 0; len += rc; - if (unlikely(len > buflen)) + if (unlikely(len >= buflen)) return 0; strcat(buf, scopebuf); @@ -161,8 +162,10 @@ static int rpc_parse_scope_id(struct net *net, const char *buf, const size_t buflen, const char *delim, struct sockaddr_in6 *sin6) { - char *p; + char p[IPV6_SCOPE_ID_LEN + 1]; size_t len; + u32 scope_id = 0; + struct net_device *dev; if ((buf + buflen) == delim) return 1; @@ -174,29 +177,23 @@ static int rpc_parse_scope_id(struct net *net, const char *buf, return 0; len = (buf + buflen) - delim - 1; - p = kstrndup(delim + 1, len, GFP_KERNEL); - if (p) { - unsigned long scope_id = 0; - struct net_device *dev; - - dev = dev_get_by_name(net, p); - if (dev != NULL) { - scope_id = dev->ifindex; - dev_put(dev); - } else { - if (strict_strtoul(p, 10, &scope_id) == 0) { - kfree(p); - return 0; - } - } - - kfree(p); - - sin6->sin6_scope_id = scope_id; - return 1; + if (len > IPV6_SCOPE_ID_LEN) + return 0; + + memcpy(p, delim + 1, len); + p[len] = 0; + + dev = dev_get_by_name(net, p); + if (dev != NULL) { + scope_id = dev->ifindex; + dev_put(dev); + } else { + if (kstrtou32(p, 10, &scope_id) != 0) + return 0; } - return 0; + sin6->sin6_scope_id = scope_id; + return 1; } static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, @@ -287,10 +284,10 @@ char *rpc_sockaddr2uaddr(const struct sockaddr *sap, gfp_t gfp_flags) } if (snprintf(portbuf, sizeof(portbuf), - ".%u.%u", port >> 8, port & 0xff) > (int)sizeof(portbuf)) + ".%u.%u", port >> 8, port & 0xff) >= (int)sizeof(portbuf)) return NULL; - if (strlcat(addrbuf, portbuf, sizeof(addrbuf)) > sizeof(addrbuf)) + if (strlcat(addrbuf, portbuf, sizeof(addrbuf)) >= sizeof(addrbuf)) return NULL; return kstrdup(addrbuf, gfp_flags); @@ -304,7 +301,7 @@ char *rpc_sockaddr2uaddr(const struct sockaddr *sap, gfp_t gfp_flags) * @sap: buffer into which to plant socket address * @salen: size of buffer * - * @uaddr does not have to be '\0'-terminated, but strict_strtoul() and + * @uaddr does not have to be '\0'-terminated, but kstrtou8() and * rpc_pton() require proper string termination to be successful. * * Returns the size of the socket address if successful; otherwise @@ -315,7 +312,7 @@ size_t rpc_uaddr2sockaddr(struct net *net, const char *uaddr, const size_t salen) { char *c, buf[RPCBIND_MAXUADDRLEN + sizeof('\0')]; - unsigned long portlo, porthi; + u8 portlo, porthi; unsigned short port; if (uaddr_len > RPCBIND_MAXUADDRLEN) @@ -327,18 +324,14 @@ size_t rpc_uaddr2sockaddr(struct net *net, const char *uaddr, c = strrchr(buf, '.'); if (unlikely(c == NULL)) return 0; - if (unlikely(strict_strtoul(c + 1, 10, &portlo) != 0)) - return 0; - if (unlikely(portlo > 255)) + if (unlikely(kstrtou8(c + 1, 10, &portlo) != 0)) return 0; *c = '\0'; c = strrchr(buf, '.'); if (unlikely(c == NULL)) return 0; - if (unlikely(strict_strtoul(c + 1, 10, &porthi) != 0)) - return 0; - if (unlikely(porthi > 255)) + if (unlikely(kstrtou8(c + 1, 10, &porthi) != 0)) return 0; port = (unsigned short)((porthi << 8) | portlo); diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index ed2fdd210c0b..5a827afd8e3b 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/auth.c * @@ -8,6 +9,7 @@ #include <linux/types.h> #include <linux/sched.h> +#include <linux/cred.h> #include <linux/module.h> #include <linux/slab.h> #include <linux/errno.h> @@ -16,9 +18,7 @@ #include <linux/sunrpc/gss_api.h> #include <linux/spinlock.h> -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif +#include <trace/events/sunrpc.h> #define RPC_CREDCACHE_DEFAULT_HASHBITS (4) struct rpc_cred_cache { @@ -29,16 +29,29 @@ struct rpc_cred_cache { static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; -static DEFINE_SPINLOCK(rpc_authflavor_lock); -static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { - &authnull_ops, /* AUTH_NULL */ - &authunix_ops, /* AUTH_UNIX */ - NULL, /* others can be loadable modules */ +static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = { + [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops, + [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops, + [RPC_AUTH_TLS] = (const struct rpc_authops __force __rcu *)&authtls_ops, }; static LIST_HEAD(cred_unused); static unsigned long number_cred_unused; +static struct cred machine_cred = { + .usage = ATOMIC_INIT(1), +}; + +/* + * Return the machine_cred pointer to be used whenever + * the a generic machine credential is needed. + */ +const struct cred *rpc_machine_cred(void) +{ + return &machine_cred; +} +EXPORT_SYMBOL_GPL(rpc_machine_cred); + #define MAX_HASHTABLE_BITS (14) static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) { @@ -48,12 +61,10 @@ static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) if (!val) goto out_inval; - ret = strict_strtoul(val, 0, &num); - if (ret == -EINVAL) + ret = kstrtoul(val, 0, &num); + if (ret) goto out_inval; - nbits = fls(num); - if (num > (1U << nbits)) - nbits++; + nbits = fls(num - 1); if (nbits > MAX_HASHTABLE_BITS || nbits < 2) goto out_inval; *(unsigned int *)kp->arg = nbits; @@ -67,12 +78,12 @@ static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp) unsigned int nbits; nbits = *(unsigned int *)kp->arg; - return sprintf(buffer, "%u", 1U << nbits); + return sprintf(buffer, "%u\n", 1U << nbits); } #define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int); -static struct kernel_param_ops param_ops_hashtbl_sz = { +static const struct kernel_param_ops param_ops_hashtbl_sz = { .set = param_set_hashtbl_sz, .get = param_get_hashtbl_sz, }; @@ -80,6 +91,10 @@ static struct kernel_param_ops param_ops_hashtbl_sz = { module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644); MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size"); +static unsigned long auth_max_cred_cachesize = ULONG_MAX; +module_param(auth_max_cred_cachesize, ulong, 0644); +MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size"); + static u32 pseudoflavor_to_flavor(u32 flavor) { if (flavor > RPC_AUTH_MAXFLAVOR) @@ -90,39 +105,65 @@ pseudoflavor_to_flavor(u32 flavor) { int rpcauth_register(const struct rpc_authops *ops) { + const struct rpc_authops *old; rpc_authflavor_t flavor; - int ret = -EPERM; if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) return -EINVAL; - spin_lock(&rpc_authflavor_lock); - if (auth_flavors[flavor] == NULL) { - auth_flavors[flavor] = ops; - ret = 0; - } - spin_unlock(&rpc_authflavor_lock); - return ret; + old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops); + if (old == NULL || old == ops) + return 0; + return -EPERM; } EXPORT_SYMBOL_GPL(rpcauth_register); int rpcauth_unregister(const struct rpc_authops *ops) { + const struct rpc_authops *old; rpc_authflavor_t flavor; - int ret = -EPERM; if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) return -EINVAL; - spin_lock(&rpc_authflavor_lock); - if (auth_flavors[flavor] == ops) { - auth_flavors[flavor] = NULL; - ret = 0; - } - spin_unlock(&rpc_authflavor_lock); - return ret; + + old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL); + if (old == ops || old == NULL) + return 0; + return -EPERM; } EXPORT_SYMBOL_GPL(rpcauth_unregister); +static const struct rpc_authops * +rpcauth_get_authops(rpc_authflavor_t flavor) +{ + const struct rpc_authops *ops; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return NULL; + + rcu_read_lock(); + ops = rcu_dereference(auth_flavors[flavor]); + if (ops == NULL) { + rcu_read_unlock(); + request_module("rpc-auth-%u", flavor); + rcu_read_lock(); + ops = rcu_dereference(auth_flavors[flavor]); + if (ops == NULL) + goto out; + } + if (!try_module_get(ops->owner)) + ops = NULL; +out: + rcu_read_unlock(); + return ops; +} + +static void +rpcauth_put_authops(const struct rpc_authops *ops) +{ + module_put(ops->owner); +} + /** * rpcauth_get_pseudoflavor - check if security flavor is supported * @flavor: a security flavor @@ -135,25 +176,16 @@ EXPORT_SYMBOL_GPL(rpcauth_unregister); rpc_authflavor_t rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) { - const struct rpc_authops *ops; + const struct rpc_authops *ops = rpcauth_get_authops(flavor); rpc_authflavor_t pseudoflavor; - ops = auth_flavors[flavor]; - if (ops == NULL) - request_module("rpc-auth-%u", flavor); - spin_lock(&rpc_authflavor_lock); - ops = auth_flavors[flavor]; - if (ops == NULL || !try_module_get(ops->owner)) { - spin_unlock(&rpc_authflavor_lock); + if (!ops) return RPC_AUTH_MAXFLAVOR; - } - spin_unlock(&rpc_authflavor_lock); - pseudoflavor = flavor; if (ops->info2flavor != NULL) pseudoflavor = ops->info2flavor(info); - module_put(ops->owner); + rpcauth_put_authops(ops); return pseudoflavor; } EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); @@ -173,104 +205,33 @@ rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) const struct rpc_authops *ops; int result; - if (flavor >= RPC_AUTH_MAXFLAVOR) - return -EINVAL; - - ops = auth_flavors[flavor]; + ops = rpcauth_get_authops(flavor); if (ops == NULL) - request_module("rpc-auth-%u", flavor); - spin_lock(&rpc_authflavor_lock); - ops = auth_flavors[flavor]; - if (ops == NULL || !try_module_get(ops->owner)) { - spin_unlock(&rpc_authflavor_lock); return -ENOENT; - } - spin_unlock(&rpc_authflavor_lock); result = -ENOENT; if (ops->flavor2info != NULL) result = ops->flavor2info(pseudoflavor, info); - module_put(ops->owner); + rpcauth_put_authops(ops); return result; } EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); -/** - * rpcauth_list_flavors - discover registered flavors and pseudoflavors - * @array: array to fill in - * @size: size of "array" - * - * Returns the number of array items filled in, or a negative errno. - * - * The returned array is not sorted by any policy. Callers should not - * rely on the order of the items in the returned array. - */ -int -rpcauth_list_flavors(rpc_authflavor_t *array, int size) -{ - rpc_authflavor_t flavor; - int result = 0; - - spin_lock(&rpc_authflavor_lock); - for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { - const struct rpc_authops *ops = auth_flavors[flavor]; - rpc_authflavor_t pseudos[4]; - int i, len; - - if (result >= size) { - result = -ENOMEM; - break; - } - - if (ops == NULL) - continue; - if (ops->list_pseudoflavors == NULL) { - array[result++] = ops->au_flavor; - continue; - } - len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos)); - if (len < 0) { - result = len; - break; - } - for (i = 0; i < len; i++) { - if (result >= size) { - result = -ENOMEM; - break; - } - array[result++] = pseudos[i]; - } - } - spin_unlock(&rpc_authflavor_lock); - - dprintk("RPC: %s returns %d\n", __func__, result); - return result; -} -EXPORT_SYMBOL_GPL(rpcauth_list_flavors); - struct rpc_auth * -rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) +rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { - struct rpc_auth *auth; + struct rpc_auth *auth = ERR_PTR(-EINVAL); const struct rpc_authops *ops; - u32 flavor = pseudoflavor_to_flavor(pseudoflavor); + u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); - auth = ERR_PTR(-EINVAL); - if (flavor >= RPC_AUTH_MAXFLAVOR) + ops = rpcauth_get_authops(flavor); + if (ops == NULL) goto out; - if ((ops = auth_flavors[flavor]) == NULL) - request_module("rpc-auth-%u", flavor); - spin_lock(&rpc_authflavor_lock); - ops = auth_flavors[flavor]; - if (ops == NULL || !try_module_get(ops->owner)) { - spin_unlock(&rpc_authflavor_lock); - goto out; - } - spin_unlock(&rpc_authflavor_lock); - auth = ops->create(clnt, pseudoflavor); - module_put(ops->owner); + auth = ops->create(args, clnt); + + rpcauth_put_authops(ops); if (IS_ERR(auth)) return auth; if (clnt->cl_auth) @@ -285,32 +246,37 @@ EXPORT_SYMBOL_GPL(rpcauth_create); void rpcauth_release(struct rpc_auth *auth) { - if (!atomic_dec_and_test(&auth->au_count)) + if (!refcount_dec_and_test(&auth->au_count)) return; auth->au_ops->destroy(auth); } static DEFINE_SPINLOCK(rpc_credcache_lock); -static void +/* + * On success, the caller is responsible for freeing the reference + * held by the hashtable + */ +static bool rpcauth_unhash_cred_locked(struct rpc_cred *cred) { + if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + return false; hlist_del_rcu(&cred->cr_hash); - smp_mb__before_clear_bit(); - clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + return true; } -static int +static bool rpcauth_unhash_cred(struct rpc_cred *cred) { spinlock_t *cache_lock; - int ret; + bool ret; + if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + return false; cache_lock = &cred->cr_auth->au_credcache->lock; spin_lock(cache_lock); - ret = atomic_read(&cred->cr_count) == 0; - if (ret) - rpcauth_unhash_cred_locked(cred); + ret = rpcauth_unhash_cred_locked(cred); spin_unlock(cache_lock); return ret; } @@ -342,6 +308,15 @@ out_nocache: } EXPORT_SYMBOL_GPL(rpcauth_init_credcache); +char * +rpcauth_stringify_acceptor(struct rpc_cred *cred) +{ + if (!cred->cr_ops->crstringify_acceptor) + return NULL; + return cred->cr_ops->crstringify_acceptor(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor); + /* * Destroy a list of credentials */ @@ -357,6 +332,44 @@ void rpcauth_destroy_credlist(struct list_head *head) } } +static void +rpcauth_lru_add_locked(struct rpc_cred *cred) +{ + if (!list_empty(&cred->cr_lru)) + return; + number_cred_unused++; + list_add_tail(&cred->cr_lru, &cred_unused); +} + +static void +rpcauth_lru_add(struct rpc_cred *cred) +{ + if (!list_empty(&cred->cr_lru)) + return; + spin_lock(&rpc_credcache_lock); + rpcauth_lru_add_locked(cred); + spin_unlock(&rpc_credcache_lock); +} + +static void +rpcauth_lru_remove_locked(struct rpc_cred *cred) +{ + if (list_empty(&cred->cr_lru)) + return; + number_cred_unused--; + list_del_init(&cred->cr_lru); +} + +static void +rpcauth_lru_remove(struct rpc_cred *cred) +{ + if (list_empty(&cred->cr_lru)) + return; + spin_lock(&rpc_credcache_lock); + rpcauth_lru_remove_locked(cred); + spin_unlock(&rpc_credcache_lock); +} + /* * Clear the RPC credential cache, and delete those credentials * that are not referenced. @@ -376,13 +389,10 @@ rpcauth_clear_credcache(struct rpc_cred_cache *cache) head = &cache->hashtable[i]; while (!hlist_empty(head)) { cred = hlist_entry(head->first, struct rpc_cred, cr_hash); - get_rpccred(cred); - if (!list_empty(&cred->cr_lru)) { - list_del(&cred->cr_lru); - number_cred_unused--; - } - list_add_tail(&cred->cr_lru, &free); rpcauth_unhash_cred_locked(cred); + /* Note: We now hold a reference to cred */ + rpcauth_lru_remove_locked(cred); + list_add_tail(&cred->cr_lru, &free); } } spin_unlock(&cache->lock); @@ -413,62 +423,88 @@ EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); /* * Remove stale credentials. Avoid sleeping inside the loop. */ -static int +static long rpcauth_prune_expired(struct list_head *free, int nr_to_scan) { - spinlock_t *cache_lock; struct rpc_cred *cred, *next; unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; + long freed = 0; list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { if (nr_to_scan-- == 0) break; + if (refcount_read(&cred->cr_count) > 1) { + rpcauth_lru_remove_locked(cred); + continue; + } /* * Enforce a 60 second garbage collection moratorium * Note that the cred_unused list must be time-ordered. */ - if (time_in_range(cred->cr_expire, expired, jiffies) && - test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) - return 0; - - list_del_init(&cred->cr_lru); - number_cred_unused--; - if (atomic_read(&cred->cr_count) != 0) + if (time_in_range(cred->cr_expire, expired, jiffies)) + continue; + if (!rpcauth_unhash_cred(cred)) continue; - cache_lock = &cred->cr_auth->au_credcache->lock; - spin_lock(cache_lock); - if (atomic_read(&cred->cr_count) == 0) { - get_rpccred(cred); - list_add_tail(&cred->cr_lru, free); - rpcauth_unhash_cred_locked(cred); - } - spin_unlock(cache_lock); + rpcauth_lru_remove_locked(cred); + freed++; + list_add_tail(&cred->cr_lru, free); } - return (number_cred_unused / 100) * sysctl_vfs_cache_pressure; + return freed ? freed : SHRINK_STOP; } -/* - * Run memory cache shrinker. - */ -static int -rpcauth_cache_shrinker(struct shrinker *shrink, struct shrink_control *sc) +static unsigned long +rpcauth_cache_do_shrink(int nr_to_scan) { LIST_HEAD(free); - int res; - int nr_to_scan = sc->nr_to_scan; - gfp_t gfp_mask = sc->gfp_mask; + unsigned long freed; - if ((gfp_mask & GFP_KERNEL) != GFP_KERNEL) - return (nr_to_scan == 0) ? 0 : -1; - if (list_empty(&cred_unused)) - return 0; spin_lock(&rpc_credcache_lock); - res = rpcauth_prune_expired(&free, nr_to_scan); + freed = rpcauth_prune_expired(&free, nr_to_scan); spin_unlock(&rpc_credcache_lock); rpcauth_destroy_credlist(&free); - return res; + + return freed; +} + +/* + * Run memory cache shrinker. + */ +static unsigned long +rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc) + +{ + if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL) + return SHRINK_STOP; + + /* nothing left, don't come back */ + if (list_empty(&cred_unused)) + return SHRINK_STOP; + + return rpcauth_cache_do_shrink(sc->nr_to_scan); +} + +static unsigned long +rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) + +{ + return number_cred_unused; +} + +static void +rpcauth_cache_enforce_limit(void) +{ + unsigned long diff; + unsigned int nr_to_scan; + + if (number_cred_unused <= auth_max_cred_cachesize) + return; + diff = number_cred_unused - auth_max_cred_cachesize; + nr_to_scan = 100; + if (diff < nr_to_scan) + nr_to_scan = diff; + rpcauth_cache_do_shrink(nr_to_scan); } /* @@ -476,7 +512,7 @@ rpcauth_cache_shrinker(struct shrinker *shrink, struct shrink_control *sc) */ struct rpc_cred * rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, - int flags) + int flags, gfp_t gfp) { LIST_HEAD(free); struct rpc_cred_cache *cache = auth->au_credcache; @@ -484,27 +520,22 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, *entry, *new; unsigned int nr; - nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits); + nr = auth->au_ops->hash_cred(acred, cache->hashbits); rcu_read_lock(); hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) { if (!entry->cr_ops->crmatch(acred, entry, flags)) continue; - spin_lock(&cache->lock); - if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) { - spin_unlock(&cache->lock); - continue; - } cred = get_rpccred(entry); - spin_unlock(&cache->lock); - break; + if (cred) + break; } rcu_read_unlock(); if (cred != NULL) goto found; - new = auth->au_ops->crcreate(auth, acred, flags); + new = auth->au_ops->crcreate(auth, acred, flags, gfp); if (IS_ERR(new)) { cred = new; goto out; @@ -515,15 +546,18 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, if (!entry->cr_ops->crmatch(acred, entry, flags)) continue; cred = get_rpccred(entry); - break; + if (cred) + break; } if (cred == NULL) { cred = new; set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + refcount_inc(&cred->cr_count); hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); } else list_add_tail(&new->cr_lru, &free); spin_unlock(&cache->lock); + rpcauth_cache_enforce_limit(); found: if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && cred->cr_ops->cr_init != NULL && @@ -547,18 +581,12 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags) struct rpc_cred *ret; const struct cred *cred = current_cred(); - dprintk("RPC: looking up %s cred\n", - auth->au_ops->au_name); - memset(&acred, 0, sizeof(acred)); - acred.uid = cred->fsuid; - acred.gid = cred->fsgid; - acred.group_info = get_group_info(((struct cred *)cred)->group_info); - + acred.cred = cred; ret = auth->au_ops->lookup_cred(auth, &acred, flags); - put_group_info(acred.group_info); return ret; } +EXPORT_SYMBOL_GPL(rpcauth_lookupcred); void rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, @@ -566,37 +594,44 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, { INIT_HLIST_NODE(&cred->cr_hash); INIT_LIST_HEAD(&cred->cr_lru); - atomic_set(&cred->cr_count, 1); + refcount_set(&cred->cr_count, 1); cred->cr_auth = auth; + cred->cr_flags = 0; cred->cr_ops = ops; cred->cr_expire = jiffies; -#ifdef RPC_DEBUG - cred->cr_magic = RPCAUTH_CRED_MAGIC; -#endif - cred->cr_uid = acred->uid; + cred->cr_cred = get_cred(acred->cred); } EXPORT_SYMBOL_GPL(rpcauth_init_cred); -struct rpc_cred * -rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags) +static struct rpc_cred * +rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) { - dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid, - cred->cr_auth->au_ops->au_name, cred); - return get_rpccred(cred); + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .cred = get_task_cred(&init_task), + }; + struct rpc_cred *ret; + + if (RPC_IS_ASYNC(task)) + lookupflags |= RPCAUTH_LOOKUP_ASYNC; + ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags); + put_cred(acred.cred); + return ret; } -EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred); static struct rpc_cred * -rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) +rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags) { struct rpc_auth *auth = task->tk_client->cl_auth; struct auth_cred acred = { - .uid = GLOBAL_ROOT_UID, - .gid = GLOBAL_ROOT_GID, + .principal = task->tk_client->cl_principal, + .cred = init_task.cred, }; - dprintk("RPC: %5u looking up %s cred\n", - task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); + if (!acred.principal) + return NULL; + if (RPC_IS_ASYNC(task)) + lookupflags |= RPCAUTH_LOOKUP_ASYNC; return auth->au_ops->lookup_cred(auth, &acred, lookupflags); } @@ -605,30 +640,42 @@ rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) { struct rpc_auth *auth = task->tk_client->cl_auth; - dprintk("RPC: %5u looking up %s cred\n", - task->tk_pid, auth->au_ops->au_name); return rpcauth_lookupcred(auth, lookupflags); } static int -rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) +rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags) { struct rpc_rqst *req = task->tk_rqstp; - struct rpc_cred *new; + struct rpc_cred *new = NULL; int lookupflags = 0; + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .cred = cred, + }; if (flags & RPC_TASK_ASYNC) - lookupflags |= RPCAUTH_LOOKUP_NEW; - if (cred != NULL) - new = cred->cr_ops->crbind(task, cred, lookupflags); - else if (flags & RPC_TASK_ROOTCREDS) + lookupflags |= RPCAUTH_LOOKUP_NEW | RPCAUTH_LOOKUP_ASYNC; + if (task->tk_op_cred) + /* Task must use exactly this rpc_cred */ + new = get_rpccred(task->tk_op_cred); + else if (cred != NULL && cred != &machine_cred) + new = auth->au_ops->lookup_cred(auth, &acred, lookupflags); + else if (cred == &machine_cred) + new = rpcauth_bind_machine_cred(task, lookupflags); + + /* If machine cred couldn't be bound, try a root cred */ + if (new) + ; + else if (cred == &machine_cred) new = rpcauth_bind_root_cred(task, lookupflags); + else if (flags & RPC_TASK_NULLCREDS) + new = authnull_ops.lookup_cred(NULL, NULL, 0); else new = rpcauth_bind_new_cred(task, lookupflags); if (IS_ERR(new)) return PTR_ERR(new); - if (req->rq_cred != NULL) - put_rpccred(req->rq_cred); + put_rpccred(req->rq_cred); req->rq_cred = new; return 0; } @@ -636,108 +683,145 @@ rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) void put_rpccred(struct rpc_cred *cred) { - /* Fast path for unhashed credentials */ - if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) { - if (atomic_dec_and_test(&cred->cr_count)) - cred->cr_ops->crdestroy(cred); + if (cred == NULL) return; + rcu_read_lock(); + if (refcount_dec_and_test(&cred->cr_count)) + goto destroy; + if (refcount_read(&cred->cr_count) != 1 || + !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + goto out; + if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { + cred->cr_expire = jiffies; + rpcauth_lru_add(cred); + /* Race breaker */ + if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))) + rpcauth_lru_remove(cred); + } else if (rpcauth_unhash_cred(cred)) { + rpcauth_lru_remove(cred); + if (refcount_dec_and_test(&cred->cr_count)) + goto destroy; } - - if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) - return; - if (!list_empty(&cred->cr_lru)) { - number_cred_unused--; - list_del_init(&cred->cr_lru); - } - if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { - if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { - cred->cr_expire = jiffies; - list_add_tail(&cred->cr_lru, &cred_unused); - number_cred_unused++; - goto out_nodestroy; - } - if (!rpcauth_unhash_cred(cred)) { - /* We were hashed and someone looked us up... */ - goto out_nodestroy; - } - } - spin_unlock(&rpc_credcache_lock); - cred->cr_ops->crdestroy(cred); +out: + rcu_read_unlock(); return; -out_nodestroy: - spin_unlock(&rpc_credcache_lock); +destroy: + rcu_read_unlock(); + cred->cr_ops->crdestroy(cred); } EXPORT_SYMBOL_GPL(put_rpccred); -__be32 * -rpcauth_marshcred(struct rpc_task *task, __be32 *p) +/** + * rpcauth_marshcred - Append RPC credential to end of @xdr + * @task: controlling RPC task + * @xdr: xdr_stream containing initial portion of RPC Call header + * + * On success, an appropriate verifier is added to @xdr, @xdr is + * updated to point past the verifier, and zero is returned. + * Otherwise, @xdr is in an undefined state and a negative errno + * is returned. + */ +int rpcauth_marshcred(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; - - dprintk("RPC: %5u marshaling %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - return cred->cr_ops->crmarshal(task, p); + return ops->crmarshal(task, xdr); } -__be32 * -rpcauth_checkverf(struct rpc_task *task, __be32 *p) +/** + * rpcauth_wrap_req_encode - XDR encode the RPC procedure + * @task: controlling RPC task + * @xdr: stream where on-the-wire bytes are to be marshalled + * + * On success, @xdr contains the encoded and wrapped message. + * Otherwise, @xdr is in an undefined state. + */ +int rpcauth_wrap_req_encode(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; - - dprintk("RPC: %5u validating %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + kxdreproc_t encode = task->tk_msg.rpc_proc->p_encode; - return cred->cr_ops->crvalidate(task, p); + encode(task->tk_rqstp, xdr, task->tk_msg.rpc_argp); + return 0; } +EXPORT_SYMBOL_GPL(rpcauth_wrap_req_encode); -static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *data, void *obj) +/** + * rpcauth_wrap_req - XDR encode and wrap the RPC procedure + * @task: controlling RPC task + * @xdr: stream where on-the-wire bytes are to be marshalled + * + * On success, @xdr contains the encoded and wrapped message, + * and zero is returned. Otherwise, @xdr is in an undefined + * state and a negative errno is returned. + */ +int rpcauth_wrap_req(struct rpc_task *task, struct xdr_stream *xdr) { - struct xdr_stream xdr; + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); - encode(rqstp, &xdr, obj); + return ops->crwrap_req(task, xdr); } +/** + * rpcauth_checkverf - Validate verifier in RPC Reply header + * @task: controlling RPC task + * @xdr: xdr_stream containing RPC Reply header + * + * Return values: + * %0: Verifier is valid. @xdr now points past the verifier. + * %-EIO: Verifier is corrupted or message ended early. + * %-EACCES: Verifier is intact but not valid. + * %-EPROTONOSUPPORT: Server does not support the requested auth type. + * + * When a negative errno is returned, @xdr is left in an undefined + * state. + */ int -rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp, - __be32 *data, void *obj) +rpcauth_checkverf(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", - task->tk_pid, cred->cr_ops->cr_name, cred); - if (cred->cr_ops->crwrap_req) - return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); - /* By default, we encode the arguments normally. */ - rpcauth_wrap_req_encode(encode, rqstp, data, obj); - return 0; + return ops->crvalidate(task, xdr); } -static int -rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, - __be32 *data, void *obj) +/** + * rpcauth_unwrap_resp_decode - Invoke XDR decode function + * @task: controlling RPC task + * @xdr: stream where the Reply message resides + * + * Returns zero on success; otherwise a negative errno is returned. + */ +int +rpcauth_unwrap_resp_decode(struct rpc_task *task, struct xdr_stream *xdr) { - struct xdr_stream xdr; + kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; - xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); - return decode(rqstp, &xdr, obj); + return decode(task->tk_rqstp, xdr, task->tk_msg.rpc_resp); } +EXPORT_SYMBOL_GPL(rpcauth_unwrap_resp_decode); +/** + * rpcauth_unwrap_resp - Invoke unwrap and decode function for the cred + * @task: controlling RPC task + * @xdr: stream where the Reply message resides + * + * Returns zero on success; otherwise a negative errno is returned. + */ int -rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, - __be32 *data, void *obj) +rpcauth_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr) +{ + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; + + return ops->crunwrap_resp(task, xdr); +} + +bool +rpcauth_xmit_need_reencode(struct rpc_task *task) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; - dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", - task->tk_pid, cred->cr_ops->cr_name, cred); - if (cred->cr_ops->crunwrap_resp) - return cred->cr_ops->crunwrap_resp(task, decode, rqstp, - data, obj); - /* By default, we decode the arguments normally. */ - return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); + if (!cred || !cred->cr_ops->crneed_reencode) + return false; + return cred->cr_ops->crneed_reencode(task); } int @@ -753,8 +837,6 @@ rpcauth_refreshcred(struct rpc_task *task) goto out; cred = task->tk_rqstp->rq_cred; } - dprintk("RPC: %5u refreshing %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); err = cred->cr_ops->crrefresh(task); out: @@ -768,8 +850,6 @@ rpcauth_invalcred(struct rpc_task *task) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; - dprintk("RPC: %5u invalidating %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); if (cred) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); } @@ -783,10 +863,7 @@ rpcauth_uptodatecred(struct rpc_task *task) test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; } -static struct shrinker rpc_cred_shrinker = { - .shrink = rpcauth_cache_shrinker, - .seeks = DEFAULT_SEEKS, -}; +static struct shrinker *rpc_cred_shrinker; int __init rpcauth_init_module(void) { @@ -795,10 +872,17 @@ int __init rpcauth_init_module(void) err = rpc_init_authunix(); if (err < 0) goto out1; - err = rpc_init_generic_auth(); - if (err < 0) + rpc_cred_shrinker = shrinker_alloc(0, "sunrpc_cred"); + if (!rpc_cred_shrinker) { + err = -ENOMEM; goto out2; - register_shrinker(&rpc_cred_shrinker); + } + + rpc_cred_shrinker->count_objects = rpcauth_cache_shrink_count; + rpc_cred_shrinker->scan_objects = rpcauth_cache_shrink_scan; + + shrinker_register(rpc_cred_shrinker); + return 0; out2: rpc_destroy_authunix(); @@ -809,6 +893,5 @@ out1: void rpcauth_remove_module(void) { rpc_destroy_authunix(); - rpc_destroy_generic_auth(); - unregister_shrinker(&rpc_cred_shrinker); + shrinker_free(rpc_cred_shrinker); } diff --git a/net/sunrpc/auth_generic.c b/net/sunrpc/auth_generic.c deleted file mode 100644 index b6badafc6494..000000000000 --- a/net/sunrpc/auth_generic.c +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Generic RPC credential - * - * Copyright (C) 2008, Trond Myklebust <Trond.Myklebust@netapp.com> - */ - -#include <linux/err.h> -#include <linux/slab.h> -#include <linux/types.h> -#include <linux/module.h> -#include <linux/sched.h> -#include <linux/sunrpc/auth.h> -#include <linux/sunrpc/clnt.h> -#include <linux/sunrpc/debug.h> -#include <linux/sunrpc/sched.h> - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - -#define RPC_MACHINE_CRED_USERID GLOBAL_ROOT_UID -#define RPC_MACHINE_CRED_GROUPID GLOBAL_ROOT_GID - -struct generic_cred { - struct rpc_cred gc_base; - struct auth_cred acred; -}; - -static struct rpc_auth generic_auth; -static const struct rpc_credops generic_credops; - -/* - * Public call interface - */ -struct rpc_cred *rpc_lookup_cred(void) -{ - return rpcauth_lookupcred(&generic_auth, 0); -} -EXPORT_SYMBOL_GPL(rpc_lookup_cred); - -/* - * Public call interface for looking up machine creds. - */ -struct rpc_cred *rpc_lookup_machine_cred(const char *service_name) -{ - struct auth_cred acred = { - .uid = RPC_MACHINE_CRED_USERID, - .gid = RPC_MACHINE_CRED_GROUPID, - .principal = service_name, - .machine_cred = 1, - }; - - dprintk("RPC: looking up machine cred for service %s\n", - service_name); - return generic_auth.au_ops->lookup_cred(&generic_auth, &acred, 0); -} -EXPORT_SYMBOL_GPL(rpc_lookup_machine_cred); - -static struct rpc_cred *generic_bind_cred(struct rpc_task *task, - struct rpc_cred *cred, int lookupflags) -{ - struct rpc_auth *auth = task->tk_client->cl_auth; - struct auth_cred *acred = &container_of(cred, struct generic_cred, gc_base)->acred; - - return auth->au_ops->lookup_cred(auth, acred, lookupflags); -} - -/* - * Lookup generic creds for current process - */ -static struct rpc_cred * -generic_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) -{ - return rpcauth_lookup_credcache(&generic_auth, acred, flags); -} - -static struct rpc_cred * -generic_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) -{ - struct generic_cred *gcred; - - gcred = kmalloc(sizeof(*gcred), GFP_KERNEL); - if (gcred == NULL) - return ERR_PTR(-ENOMEM); - - rpcauth_init_cred(&gcred->gc_base, acred, &generic_auth, &generic_credops); - gcred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; - - gcred->acred.uid = acred->uid; - gcred->acred.gid = acred->gid; - gcred->acred.group_info = acred->group_info; - if (gcred->acred.group_info != NULL) - get_group_info(gcred->acred.group_info); - gcred->acred.machine_cred = acred->machine_cred; - gcred->acred.principal = acred->principal; - - dprintk("RPC: allocated %s cred %p for uid %d gid %d\n", - gcred->acred.machine_cred ? "machine" : "generic", - gcred, - from_kuid(&init_user_ns, acred->uid), - from_kgid(&init_user_ns, acred->gid)); - return &gcred->gc_base; -} - -static void -generic_free_cred(struct rpc_cred *cred) -{ - struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); - - dprintk("RPC: generic_free_cred %p\n", gcred); - if (gcred->acred.group_info != NULL) - put_group_info(gcred->acred.group_info); - kfree(gcred); -} - -static void -generic_free_cred_callback(struct rcu_head *head) -{ - struct rpc_cred *cred = container_of(head, struct rpc_cred, cr_rcu); - generic_free_cred(cred); -} - -static void -generic_destroy_cred(struct rpc_cred *cred) -{ - call_rcu(&cred->cr_rcu, generic_free_cred_callback); -} - -static int -machine_cred_match(struct auth_cred *acred, struct generic_cred *gcred, int flags) -{ - if (!gcred->acred.machine_cred || - gcred->acred.principal != acred->principal || - !uid_eq(gcred->acred.uid, acred->uid) || - !gid_eq(gcred->acred.gid, acred->gid)) - return 0; - return 1; -} - -/* - * Match credentials against current process creds. - */ -static int -generic_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) -{ - struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); - int i; - - if (acred->machine_cred) - return machine_cred_match(acred, gcred, flags); - - if (!uid_eq(gcred->acred.uid, acred->uid) || - !gid_eq(gcred->acred.gid, acred->gid) || - gcred->acred.machine_cred != 0) - goto out_nomatch; - - /* Optimisation in the case where pointers are identical... */ - if (gcred->acred.group_info == acred->group_info) - goto out_match; - - /* Slow path... */ - if (gcred->acred.group_info->ngroups != acred->group_info->ngroups) - goto out_nomatch; - for (i = 0; i < gcred->acred.group_info->ngroups; i++) { - if (!gid_eq(GROUP_AT(gcred->acred.group_info, i), - GROUP_AT(acred->group_info, i))) - goto out_nomatch; - } -out_match: - return 1; -out_nomatch: - return 0; -} - -int __init rpc_init_generic_auth(void) -{ - return rpcauth_init_credcache(&generic_auth); -} - -void rpc_destroy_generic_auth(void) -{ - rpcauth_destroy_credcache(&generic_auth); -} - -static const struct rpc_authops generic_auth_ops = { - .owner = THIS_MODULE, - .au_name = "Generic", - .lookup_cred = generic_lookup_cred, - .crcreate = generic_create_cred, -}; - -static struct rpc_auth generic_auth = { - .au_ops = &generic_auth_ops, - .au_count = ATOMIC_INIT(0), -}; - -static const struct rpc_credops generic_credops = { - .cr_name = "Generic cred", - .crdestroy = generic_destroy_cred, - .crbind = generic_bind_cred, - .crmatch = generic_match, -}; diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile index 14e9e53e63d5..452f67deebc6 100644 --- a/net/sunrpc/auth_gss/Makefile +++ b/net/sunrpc/auth_gss/Makefile @@ -1,14 +1,17 @@ +# SPDX-License-Identifier: GPL-2.0 # # Makefile for Linux kernel rpcsec_gss implementation # obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o -auth_rpcgss-y := auth_gss.o gss_generic_token.o \ +auth_rpcgss-y := auth_gss.o \ gss_mech_switch.o svcauth_gss.o \ - gss_rpc_upcall.o gss_rpc_xdr.o + gss_rpc_upcall.o gss_rpc_xdr.o trace.o obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o rpcsec_gss_krb5-y := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \ - gss_krb5_seqnum.o gss_krb5_wrap.o gss_krb5_crypto.o gss_krb5_keys.o + gss_krb5_wrap.o gss_krb5_crypto.o gss_krb5_keys.o + +obj-$(CONFIG_RPCSEC_GSS_KRB5_KUNIT_TEST) += gss_krb5_test.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index fc2f78d6a9b4..5c095cb8cb20 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: BSD-3-Clause /* * linux/net/sunrpc/auth_gss/auth_gss.c * @@ -8,34 +9,8 @@ * * Dug Song <dugsong@monkey.org> * Andy Adamson <andros@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ - #include <linux/module.h> #include <linux/init.h> #include <linux/types.h> @@ -45,15 +20,20 @@ #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/auth.h> #include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/gss_krb5.h> #include <linux/sunrpc/svcauth_gss.h> #include <linux/sunrpc/gss_err.h> #include <linux/workqueue.h> #include <linux/sunrpc/rpc_pipe_fs.h> #include <linux/sunrpc/gss_api.h> -#include <asm/uaccess.h> +#include <linux/uaccess.h> +#include <linux/hashtable.h> +#include "auth_gss_internal.h" #include "../netns.h" +#include <trace/events/rpcgss.h> + static const struct rpc_authops authgss_ops; static const struct rpc_credops gss_credops; @@ -62,34 +42,69 @@ static const struct rpc_credops gss_nullops; #define GSS_RETRY_EXPIRED 5 static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED; -#ifdef RPC_DEBUG +#define GSS_KEY_EXPIRE_TIMEO 240 +static unsigned int gss_key_expire_timeo = GSS_KEY_EXPIRE_TIMEO; + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif +/* + * This compile-time check verifies that we will not exceed the + * slack space allotted by the client and server auth_gss code + * before they call gss_wrap(). + */ +#define GSS_KRB5_MAX_SLACK_NEEDED \ + (GSS_KRB5_TOK_HDR_LEN /* gss token header */ \ + + GSS_KRB5_MAX_CKSUM_LEN /* gss token checksum */ \ + + GSS_KRB5_MAX_BLOCKSIZE /* confounder */ \ + + GSS_KRB5_MAX_BLOCKSIZE /* possible padding */ \ + + GSS_KRB5_TOK_HDR_LEN /* encrypted hdr in v2 token */ \ + + GSS_KRB5_MAX_CKSUM_LEN /* encryption hmac */ \ + + XDR_UNIT * 2 /* RPC verifier */ \ + + GSS_KRB5_TOK_HDR_LEN \ + + GSS_KRB5_MAX_CKSUM_LEN) + #define GSS_CRED_SLACK (RPC_MAX_AUTH_SIZE * 2) /* length of a krb5 verifier (48), plus data added before arguments when * using integrity (two 4-byte integers): */ #define GSS_VERF_SLACK 100 +static DEFINE_HASHTABLE(gss_auth_hash_table, 4); +static DEFINE_SPINLOCK(gss_auth_hash_lock); + +struct gss_pipe { + struct rpc_pipe_dir_object pdo; + struct rpc_pipe *pipe; + struct rpc_clnt *clnt; + const char *name; + struct kref kref; +}; + struct gss_auth { struct kref kref; + struct hlist_node hash; struct rpc_auth rpc_auth; struct gss_api_mech *mech; enum rpc_gss_svc service; struct rpc_clnt *client; + struct net *net; + netns_tracker ns_tracker; /* * There are two upcall pipes; dentry[1], named "gssd", is used * for the new text-based upcall; dentry[0] is named after the * mechanism (for example, "krb5") and exists for * backwards-compatibility with older gssd's. */ - struct rpc_pipe *pipe[2]; + struct gss_pipe *gss_pipe[2]; + const char *target_name; }; /* pipe_version >= 0 if and only if someone has a pipe open. */ static DEFINE_SPINLOCK(pipe_version_lock); static struct rpc_wait_queue pipe_version_rpc_waitqueue; static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue); +static void gss_put_auth(struct gss_auth *gss_auth); static void gss_free_ctx(struct gss_cl_ctx *); static const struct rpc_pipe_ops gss_upcall_ops_v0; @@ -98,14 +113,14 @@ static const struct rpc_pipe_ops gss_upcall_ops_v1; static inline struct gss_cl_ctx * gss_get_ctx(struct gss_cl_ctx *ctx) { - atomic_inc(&ctx->count); + refcount_inc(&ctx->count); return ctx; } static inline void gss_put_ctx(struct gss_cl_ctx *ctx) { - if (atomic_dec_and_test(&ctx->count)) + if (refcount_dec_and_test(&ctx->count)) gss_free_ctx(ctx); } @@ -124,39 +139,10 @@ gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx) gss_get_ctx(ctx); rcu_assign_pointer(gss_cred->gc_ctx, ctx); set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags); } -static const void * -simple_get_bytes(const void *p, const void *end, void *res, size_t len) -{ - const void *q = (const void *)((const char *)p + len); - if (unlikely(q > end || q < p)) - return ERR_PTR(-EFAULT); - memcpy(res, p, len); - return q; -} - -static inline const void * -simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) -{ - const void *q; - unsigned int len; - - p = simple_get_bytes(p, end, &len, sizeof(len)); - if (IS_ERR(p)) - return p; - q = (const void *)((const char *)p + len); - if (unlikely(q > end || q < p)) - return ERR_PTR(-EFAULT); - dest->data = kmemdup(p, len, GFP_NOFS); - if (unlikely(dest->data == NULL)) - return ERR_PTR(-ENOMEM); - dest->len = len; - return q; -} - static struct gss_cl_ctx * gss_cred_get_ctx(struct rpc_cred *cred) { @@ -164,8 +150,9 @@ gss_cred_get_ctx(struct rpc_cred *cred) struct gss_cl_ctx *ctx = NULL; rcu_read_lock(); - if (gss_cred->gc_ctx) - ctx = gss_get_ctx(gss_cred->gc_ctx); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (ctx) + gss_get_ctx(ctx); rcu_read_unlock(); return ctx; } @@ -175,12 +162,12 @@ gss_alloc_context(void) { struct gss_cl_ctx *ctx; - ctx = kzalloc(sizeof(*ctx), GFP_NOFS); + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); if (ctx != NULL) { ctx->gc_proc = RPC_GSS_PROC_DATA; ctx->gc_seq = 1; /* NetApp 6.4R1 doesn't accept seq. no. 0 */ spin_lock_init(&ctx->gc_seq_lock); - atomic_set(&ctx->count,1); + refcount_set(&ctx->count,1); } return ctx; } @@ -238,24 +225,41 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct p = ERR_PTR(-EFAULT); goto err; } - ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_NOFS); + ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_KERNEL); if (ret < 0) { + trace_rpcgss_import_ctx(ret); p = ERR_PTR(ret); goto err; } - dprintk("RPC: %s Success. gc_expiry %lu now %lu timeout %u\n", - __func__, ctx->gc_expiry, now, timeout); - return q; + + /* is there any trailing data? */ + if (q == end) { + p = q; + goto done; + } + + /* pull in acceptor name (if there is one) */ + p = simple_get_netobj(q, end, &ctx->gc_acceptor); + if (IS_ERR(p)) + goto err; +done: + trace_rpcgss_context(window_size, ctx->gc_expiry, now, timeout, + ctx->gc_acceptor.len, ctx->gc_acceptor.data); err: - dprintk("RPC: %s returns error %ld\n", __func__, -PTR_ERR(p)); return p; } -#define UPCALL_BUF_LEN 128 +/* XXX: Need some documentation about why UPCALL_BUF_LEN is so small. + * Is user space expecting no more than UPCALL_BUF_LEN bytes? + * Note that there are now _two_ NI_MAXHOST sized data items + * being passed in this string. + */ +#define UPCALL_BUF_LEN 256 struct gss_upcall_msg { - atomic_t count; + refcount_t count; kuid_t uid; + const char *service_name; struct rpc_pipe_msg msg; struct list_head list; struct gss_auth *auth; @@ -294,29 +298,31 @@ static void put_pipe_version(struct net *net) static void gss_release_msg(struct gss_upcall_msg *gss_msg) { - struct net *net = rpc_net_ns(gss_msg->auth->client); - if (!atomic_dec_and_test(&gss_msg->count)) + struct net *net = gss_msg->auth->net; + if (!refcount_dec_and_test(&gss_msg->count)) return; put_pipe_version(net); BUG_ON(!list_empty(&gss_msg->list)); if (gss_msg->ctx != NULL) gss_put_ctx(gss_msg->ctx); rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue); + gss_put_auth(gss_msg->auth); + kfree_const(gss_msg->service_name); kfree(gss_msg); } static struct gss_upcall_msg * -__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid) +__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth) { struct gss_upcall_msg *pos; list_for_each_entry(pos, &pipe->in_downcall, list) { if (!uid_eq(pos->uid, uid)) continue; - atomic_inc(&pos->count); - dprintk("RPC: %s found msg %p\n", __func__, pos); + if (pos->auth->service != auth->service) + continue; + refcount_inc(&pos->count); return pos; } - dprintk("RPC: %s found nothing\n", __func__); return NULL; } @@ -331,9 +337,9 @@ gss_add_msg(struct gss_upcall_msg *gss_msg) struct gss_upcall_msg *old; spin_lock(&pipe->lock); - old = __gss_find_upcall(pipe, gss_msg->uid); + old = __gss_find_upcall(pipe, gss_msg->uid, gss_msg->auth); if (old == NULL) { - atomic_inc(&gss_msg->count); + refcount_inc(&gss_msg->count); list_add(&gss_msg->list, &pipe->in_downcall); } else gss_msg = old; @@ -347,7 +353,7 @@ __gss_unhash_msg(struct gss_upcall_msg *gss_msg) list_del_init(&gss_msg->list); rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); wake_up_all(&gss_msg->waitqueue); - atomic_dec(&gss_msg->count); + refcount_dec(&gss_msg->count); } static void @@ -396,104 +402,182 @@ gss_upcall_callback(struct rpc_task *task) gss_release_msg(gss_msg); } -static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg) +static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg, + const struct cred *cred) { - uid_t uid = from_kuid(&init_user_ns, gss_msg->uid); + struct user_namespace *userns = cred->user_ns; + + uid_t uid = from_kuid_munged(userns, gss_msg->uid); memcpy(gss_msg->databuf, &uid, sizeof(uid)); gss_msg->msg.data = gss_msg->databuf; gss_msg->msg.len = sizeof(uid); - BUG_ON(sizeof(uid) > UPCALL_BUF_LEN); + + BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf)); } -static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, - struct rpc_clnt *clnt, - const char *service_name) +static ssize_t +gss_v0_upcall(struct file *file, struct rpc_pipe_msg *msg, + char __user *buf, size_t buflen) +{ + struct gss_upcall_msg *gss_msg = container_of(msg, + struct gss_upcall_msg, + msg); + if (msg->copied == 0) + gss_encode_v0_msg(gss_msg, file->f_cred); + return rpc_pipe_generic_upcall(file, msg, buf, buflen); +} + +static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, + const char *service_name, + const char *target_name, + const struct cred *cred) { + struct user_namespace *userns = cred->user_ns; struct gss_api_mech *mech = gss_msg->auth->mech; char *p = gss_msg->databuf; - int len = 0; - - gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ", - mech->gm_name, - from_kuid(&init_user_ns, gss_msg->uid)); - p += gss_msg->msg.len; - if (clnt->cl_principal) { - len = sprintf(p, "target=%s ", clnt->cl_principal); + size_t buflen = sizeof(gss_msg->databuf); + int len; + + len = scnprintf(p, buflen, "mech=%s uid=%d", mech->gm_name, + from_kuid_munged(userns, gss_msg->uid)); + buflen -= len; + p += len; + gss_msg->msg.len = len; + + /* + * target= is a full service principal that names the remote + * identity that we are authenticating to. + */ + if (target_name) { + len = scnprintf(p, buflen, " target=%s", target_name); + buflen -= len; p += len; gss_msg->msg.len += len; } - if (service_name != NULL) { - len = sprintf(p, "service=%s ", service_name); + + /* + * gssd uses service= and srchost= to select a matching key from + * the system's keytab to use as the source principal. + * + * service= is the service name part of the source principal, + * or "*" (meaning choose any). + * + * srchost= is the hostname part of the source principal. When + * not provided, gssd uses the local hostname. + */ + if (service_name) { + char *c = strchr(service_name, '@'); + + if (!c) + len = scnprintf(p, buflen, " service=%s", + service_name); + else + len = scnprintf(p, buflen, + " service=%.*s srchost=%s", + (int)(c - service_name), + service_name, c + 1); + buflen -= len; p += len; gss_msg->msg.len += len; } + if (mech->gm_upcall_enctypes) { - len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes); + len = scnprintf(p, buflen, " enctypes=%s", + mech->gm_upcall_enctypes); + buflen -= len; p += len; gss_msg->msg.len += len; } - len = sprintf(p, "\n"); + trace_rpcgss_upcall_msg(gss_msg->databuf); + len = scnprintf(p, buflen, "\n"); + if (len == 0) + goto out_overflow; gss_msg->msg.len += len; - gss_msg->msg.data = gss_msg->databuf; - BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN); + return 0; +out_overflow: + WARN_ON_ONCE(1); + return -ENOMEM; } -static void gss_encode_msg(struct gss_upcall_msg *gss_msg, - struct rpc_clnt *clnt, - const char *service_name) +static ssize_t +gss_v1_upcall(struct file *file, struct rpc_pipe_msg *msg, + char __user *buf, size_t buflen) { - struct net *net = rpc_net_ns(clnt); - struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); - - if (sn->pipe_version == 0) - gss_encode_v0_msg(gss_msg); - else /* pipe_version == 1 */ - gss_encode_v1_msg(gss_msg, clnt, service_name); + struct gss_upcall_msg *gss_msg = container_of(msg, + struct gss_upcall_msg, + msg); + int err; + if (msg->copied == 0) { + err = gss_encode_v1_msg(gss_msg, + gss_msg->service_name, + gss_msg->auth->target_name, + file->f_cred); + if (err) + return err; + } + return rpc_pipe_generic_upcall(file, msg, buf, buflen); } static struct gss_upcall_msg * -gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt, +gss_alloc_msg(struct gss_auth *gss_auth, kuid_t uid, const char *service_name) { struct gss_upcall_msg *gss_msg; int vers; + int err = -ENOMEM; - gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS); + gss_msg = kzalloc(sizeof(*gss_msg), GFP_KERNEL); if (gss_msg == NULL) - return ERR_PTR(-ENOMEM); - vers = get_pipe_version(rpc_net_ns(clnt)); - if (vers < 0) { - kfree(gss_msg); - return ERR_PTR(vers); - } - gss_msg->pipe = gss_auth->pipe[vers]; + goto err; + vers = get_pipe_version(gss_auth->net); + err = vers; + if (err < 0) + goto err_free_msg; + gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe; INIT_LIST_HEAD(&gss_msg->list); rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq"); init_waitqueue_head(&gss_msg->waitqueue); - atomic_set(&gss_msg->count, 1); + refcount_set(&gss_msg->count, 1); gss_msg->uid = uid; gss_msg->auth = gss_auth; - gss_encode_msg(gss_msg, clnt, service_name); + kref_get(&gss_auth->kref); + if (service_name) { + gss_msg->service_name = kstrdup_const(service_name, GFP_KERNEL); + if (!gss_msg->service_name) { + err = -ENOMEM; + goto err_put_pipe_version; + } + } return gss_msg; +err_put_pipe_version: + put_pipe_version(gss_auth->net); +err_free_msg: + kfree(gss_msg); +err: + return ERR_PTR(err); } static struct gss_upcall_msg * -gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred) +gss_setup_upcall(struct gss_auth *gss_auth, struct rpc_cred *cred) { struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_upcall_msg *gss_new, *gss_msg; - kuid_t uid = cred->cr_uid; + kuid_t uid = cred->cr_cred->fsuid; - gss_new = gss_alloc_msg(gss_auth, clnt, uid, gss_cred->gc_principal); + gss_new = gss_alloc_msg(gss_auth, uid, gss_cred->gc_principal); if (IS_ERR(gss_new)) return gss_new; gss_msg = gss_add_msg(gss_new); if (gss_msg == gss_new) { - int res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg); + int res; + refcount_inc(&gss_msg->count); + res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg); if (res) { gss_unhash_msg(gss_new); + refcount_dec(&gss_msg->count); + gss_release_msg(gss_new); gss_msg = ERR_PTR(res); } } else @@ -503,14 +587,7 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr static void warn_gssd(void) { - static unsigned long ratelimit; - unsigned long now = jiffies; - - if (time_after(now, ratelimit)) { - printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n" - "Please check user daemon is running.\n"); - ratelimit = now + 15*HZ; - } + dprintk("AUTH_GSS upcall failed. Please check user daemon is running.\n"); } static inline int @@ -525,16 +602,15 @@ gss_refresh_upcall(struct rpc_task *task) struct rpc_pipe *pipe; int err = 0; - dprintk("RPC: %5u %s for uid %u\n", - task->tk_pid, __func__, from_kuid(&init_user_ns, cred->cr_uid)); - gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred); + gss_msg = gss_setup_upcall(gss_auth, cred); if (PTR_ERR(gss_msg) == -EAGAIN) { /* XXX: warning on the first, under the assumption we * shouldn't normally hit this case on a refresh. */ warn_gssd(); - task->tk_timeout = 15*HZ; - rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL); - return -EAGAIN; + rpc_sleep_on_timeout(&pipe_version_rpc_waitqueue, + task, NULL, jiffies + (15 * HZ)); + err = -EAGAIN; + goto out; } if (IS_ERR(gss_msg)) { err = PTR_ERR(gss_msg); @@ -545,10 +621,9 @@ gss_refresh_upcall(struct rpc_task *task) if (gss_cred->gc_upcall != NULL) rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL); else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) { - task->tk_timeout = 0; gss_cred->gc_upcall = gss_msg; /* gss_upcall_callback will release the reference to gss_upcall_msg */ - atomic_inc(&gss_msg->count); + refcount_inc(&gss_msg->count); rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback); } else { gss_handle_downcall_result(gss_cred, gss_msg); @@ -557,39 +632,35 @@ gss_refresh_upcall(struct rpc_task *task) spin_unlock(&pipe->lock); gss_release_msg(gss_msg); out: - dprintk("RPC: %5u %s for uid %u result %d\n", - task->tk_pid, __func__, - from_kuid(&init_user_ns, cred->cr_uid), err); + trace_rpcgss_upcall_result(from_kuid(&init_user_ns, + cred->cr_cred->fsuid), err); return err; } static inline int gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred) { - struct net *net = rpc_net_ns(gss_auth->client); + struct net *net = gss_auth->net; struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); struct rpc_pipe *pipe; struct rpc_cred *cred = &gss_cred->gc_base; struct gss_upcall_msg *gss_msg; - unsigned long timeout; DEFINE_WAIT(wait); int err; - dprintk("RPC: %s for uid %u\n", - __func__, from_kuid(&init_user_ns, cred->cr_uid)); retry: err = 0; - /* Default timeout is 15s unless we know that gssd is not running */ - timeout = 15 * HZ; - if (!sn->gssd_running) - timeout = HZ >> 2; - gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred); + /* if gssd is down, just skip upcalling altogether */ + if (!gssd_running(net)) { + warn_gssd(); + err = -EACCES; + goto out; + } + gss_msg = gss_setup_upcall(gss_auth, cred); if (PTR_ERR(gss_msg) == -EAGAIN) { err = wait_event_interruptible_timeout(pipe_version_waitqueue, - sn->pipe_version >= 0, timeout); + sn->pipe_version >= 0, 15 * HZ); if (sn->pipe_version < 0) { - if (err == 0) - sn->gssd_running = 0; warn_gssd(); err = -EACCES; } @@ -615,20 +686,37 @@ retry: } schedule(); } - if (gss_msg->ctx) + if (gss_msg->ctx) { + trace_rpcgss_ctx_init(gss_cred); gss_cred_set_ctx(cred, gss_msg->ctx); - else + } else { err = gss_msg->msg.errno; + } spin_unlock(&pipe->lock); out_intr: finish_wait(&gss_msg->waitqueue, &wait); gss_release_msg(gss_msg); out: - dprintk("RPC: %s for uid %u result %d\n", - __func__, from_kuid(&init_user_ns, cred->cr_uid), err); + trace_rpcgss_upcall_result(from_kuid(&init_user_ns, + cred->cr_cred->fsuid), err); return err; } +static struct gss_upcall_msg * +gss_find_downcall(struct rpc_pipe *pipe, kuid_t uid) +{ + struct gss_upcall_msg *pos; + list_for_each_entry(pos, &pipe->in_downcall, list) { + if (!uid_eq(pos->uid, uid)) + continue; + if (!rpc_msg_is_inflight(&pos->msg)) + continue; + refcount_inc(&pos->count); + return pos; + } + return NULL; +} + #define MSG_BUF_MAXSIZE 1024 static ssize_t @@ -646,7 +734,7 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) if (mlen > MSG_BUF_MAXSIZE) goto out; err = -ENOMEM; - buf = kmalloc(mlen, GFP_NOFS); + buf = kmalloc(mlen, GFP_KERNEL); if (!buf) goto out; @@ -661,7 +749,7 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) goto err; } - uid = make_kuid(&init_user_ns, id); + uid = make_kuid(current_user_ns(), id); if (!uid_valid(uid)) { err = -EINVAL; goto err; @@ -675,7 +763,7 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) err = -ENOENT; /* Find a matching upcall */ spin_lock(&pipe->lock); - gss_msg = __gss_find_upcall(pipe, uid); + gss_msg = gss_find_downcall(pipe, uid); if (gss_msg == NULL) { spin_unlock(&pipe->lock); goto err_put_ctx; @@ -701,7 +789,7 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) default: printk(KERN_CRIT "%s: bad return from " "gss_fill_context: %zd\n", __func__, err); - BUG(); + gss_msg->msg.errno = -EIO; } goto err_release_msg; } @@ -718,7 +806,6 @@ err_put_ctx: err: kfree(buf); out: - dprintk("RPC: %s returning %Zd\n", __func__, err); return err; } @@ -770,7 +857,7 @@ restart: if (!list_empty(&gss_msg->msg.list)) continue; gss_msg->msg.errno = -EPIPE; - atomic_inc(&gss_msg->count); + refcount_inc(&gss_msg->count); __gss_unhash_msg(gss_msg); spin_unlock(&pipe->lock); gss_release_msg(gss_msg); @@ -787,169 +874,243 @@ gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg); if (msg->errno < 0) { - dprintk("RPC: %s releasing msg %p\n", - __func__, gss_msg); - atomic_inc(&gss_msg->count); + refcount_inc(&gss_msg->count); gss_unhash_msg(gss_msg); if (msg->errno == -ETIMEDOUT) warn_gssd(); gss_release_msg(gss_msg); } + gss_release_msg(gss_msg); } -static void gss_pipes_dentries_destroy(struct rpc_auth *auth) +static void gss_pipe_dentry_destroy(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) { - struct gss_auth *gss_auth; + struct gss_pipe *gss_pipe = pdo->pdo_data; - gss_auth = container_of(auth, struct gss_auth, rpc_auth); - if (gss_auth->pipe[0]->dentry) - rpc_unlink(gss_auth->pipe[0]->dentry); - if (gss_auth->pipe[1]->dentry) - rpc_unlink(gss_auth->pipe[1]->dentry); + rpc_unlink(gss_pipe->pipe); } -static int gss_pipes_dentries_create(struct rpc_auth *auth) +static int gss_pipe_dentry_create(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) { - int err; - struct gss_auth *gss_auth; - struct rpc_clnt *clnt; + struct gss_pipe *p = pdo->pdo_data; - gss_auth = container_of(auth, struct gss_auth, rpc_auth); - clnt = gss_auth->client; - - gss_auth->pipe[1]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry, - "gssd", - clnt, gss_auth->pipe[1]); - if (IS_ERR(gss_auth->pipe[1]->dentry)) - return PTR_ERR(gss_auth->pipe[1]->dentry); - gss_auth->pipe[0]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry, - gss_auth->mech->gm_name, - clnt, gss_auth->pipe[0]); - if (IS_ERR(gss_auth->pipe[0]->dentry)) { - err = PTR_ERR(gss_auth->pipe[0]->dentry); - goto err_unlink_pipe_1; + return rpc_mkpipe_dentry(dir, p->name, p->clnt, p->pipe); +} + +static const struct rpc_pipe_dir_object_ops gss_pipe_dir_object_ops = { + .create = gss_pipe_dentry_create, + .destroy = gss_pipe_dentry_destroy, +}; + +static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct gss_pipe *p; + int err = -ENOMEM; + + p = kmalloc(sizeof(*p), GFP_KERNEL); + if (p == NULL) + goto err; + p->pipe = rpc_mkpipe_data(upcall_ops, RPC_PIPE_WAIT_FOR_OPEN); + if (IS_ERR(p->pipe)) { + err = PTR_ERR(p->pipe); + goto err_free_gss_pipe; } - return 0; + p->name = name; + p->clnt = clnt; + kref_init(&p->kref); + rpc_init_pipe_dir_object(&p->pdo, + &gss_pipe_dir_object_ops, + p); + return p; +err_free_gss_pipe: + kfree(p); +err: + return ERR_PTR(err); +} -err_unlink_pipe_1: - rpc_unlink(gss_auth->pipe[1]->dentry); - return err; +struct gss_alloc_pdo { + struct rpc_clnt *clnt; + const char *name; + const struct rpc_pipe_ops *upcall_ops; +}; + +static int gss_pipe_match_pdo(struct rpc_pipe_dir_object *pdo, void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + if (pdo->pdo_ops != &gss_pipe_dir_object_ops) + return 0; + gss_pipe = container_of(pdo, struct gss_pipe, pdo); + if (strcmp(gss_pipe->name, args->name) != 0) + return 0; + if (!kref_get_unless_zero(&gss_pipe->kref)) + return 0; + return 1; } -static void gss_pipes_dentries_destroy_net(struct rpc_clnt *clnt, - struct rpc_auth *auth) +static struct rpc_pipe_dir_object *gss_pipe_alloc_pdo(void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + gss_pipe = gss_pipe_alloc(args->clnt, args->name, args->upcall_ops); + if (!IS_ERR(gss_pipe)) + return &gss_pipe->pdo; + return NULL; +} + +static struct gss_pipe *gss_pipe_get(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) { struct net *net = rpc_net_ns(clnt); - struct super_block *sb; + struct rpc_pipe_dir_object *pdo; + struct gss_alloc_pdo args = { + .clnt = clnt, + .name = name, + .upcall_ops = upcall_ops, + }; - sb = rpc_get_sb_net(net); - if (sb) { - if (clnt->cl_dentry) - gss_pipes_dentries_destroy(auth); - rpc_put_sb_net(net); - } + pdo = rpc_find_or_alloc_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + gss_pipe_match_pdo, + gss_pipe_alloc_pdo, + &args); + if (pdo != NULL) + return container_of(pdo, struct gss_pipe, pdo); + return ERR_PTR(-ENOMEM); } -static int gss_pipes_dentries_create_net(struct rpc_clnt *clnt, - struct rpc_auth *auth) +static void __gss_pipe_free(struct gss_pipe *p) { + struct rpc_clnt *clnt = p->clnt; struct net *net = rpc_net_ns(clnt); - struct super_block *sb; - int err = 0; - sb = rpc_get_sb_net(net); - if (sb) { - if (clnt->cl_dentry) - err = gss_pipes_dentries_create(auth); - rpc_put_sb_net(net); - } - return err; + rpc_remove_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + &p->pdo); + rpc_destroy_pipe_data(p->pipe); + kfree(p); +} + +static void __gss_pipe_release(struct kref *kref) +{ + struct gss_pipe *p = container_of(kref, struct gss_pipe, kref); + + __gss_pipe_free(p); +} + +static void gss_pipe_free(struct gss_pipe *p) +{ + if (p != NULL) + kref_put(&p->kref, __gss_pipe_release); } /* * NOTE: we have the opportunity to use different * parameters based on the input flavor (which must be a pseudoflavor) */ -static struct rpc_auth * -gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +static struct gss_auth * +gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { + rpc_authflavor_t flavor = args->pseudoflavor; struct gss_auth *gss_auth; + struct gss_pipe *gss_pipe; struct rpc_auth * auth; int err = -ENOMEM; /* XXX? */ - dprintk("RPC: creating GSS authenticator for client %p\n", clnt); - if (!try_module_get(THIS_MODULE)) return ERR_PTR(err); if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) goto out_dec; + INIT_HLIST_NODE(&gss_auth->hash); + gss_auth->target_name = NULL; + if (args->target_name) { + gss_auth->target_name = kstrdup(args->target_name, GFP_KERNEL); + if (gss_auth->target_name == NULL) + goto err_free; + } gss_auth->client = clnt; + gss_auth->net = get_net_track(rpc_net_ns(clnt), &gss_auth->ns_tracker, + GFP_KERNEL); err = -EINVAL; gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); - if (!gss_auth->mech) { - dprintk("RPC: Pseudoflavor %d not found!\n", flavor); - goto err_free; - } + if (!gss_auth->mech) + goto err_put_net; gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor); if (gss_auth->service == 0) goto err_put_mech; + if (!gssd_running(gss_auth->net)) + goto err_put_mech; auth = &gss_auth->rpc_auth; auth->au_cslack = GSS_CRED_SLACK >> 2; - auth->au_rslack = GSS_VERF_SLACK >> 2; + BUILD_BUG_ON(GSS_KRB5_MAX_SLACK_NEEDED > RPC_MAX_AUTH_SIZE); + auth->au_rslack = GSS_KRB5_MAX_SLACK_NEEDED >> 2; + auth->au_verfsize = GSS_VERF_SLACK >> 2; + auth->au_ralign = GSS_VERF_SLACK >> 2; + __set_bit(RPCAUTH_AUTH_UPDATE_SLACK, &auth->au_flags); auth->au_ops = &authgss_ops; auth->au_flavor = flavor; - atomic_set(&auth->au_count, 1); + if (gss_pseudoflavor_to_datatouch(gss_auth->mech, flavor)) + __set_bit(RPCAUTH_AUTH_DATATOUCH, &auth->au_flags); + refcount_set(&auth->au_count, 1); kref_init(&gss_auth->kref); + err = rpcauth_init_credcache(auth); + if (err) + goto err_put_mech; /* * Note: if we created the old pipe first, then someone who * examined the directory at the right moment might conclude * that we supported only the old pipe. So we instead create * the new pipe first. */ - gss_auth->pipe[1] = rpc_mkpipe_data(&gss_upcall_ops_v1, - RPC_PIPE_WAIT_FOR_OPEN); - if (IS_ERR(gss_auth->pipe[1])) { - err = PTR_ERR(gss_auth->pipe[1]); - goto err_put_mech; + gss_pipe = gss_pipe_get(clnt, "gssd", &gss_upcall_ops_v1); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_credcache; } + gss_auth->gss_pipe[1] = gss_pipe; - gss_auth->pipe[0] = rpc_mkpipe_data(&gss_upcall_ops_v0, - RPC_PIPE_WAIT_FOR_OPEN); - if (IS_ERR(gss_auth->pipe[0])) { - err = PTR_ERR(gss_auth->pipe[0]); + gss_pipe = gss_pipe_get(clnt, gss_auth->mech->gm_name, + &gss_upcall_ops_v0); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); goto err_destroy_pipe_1; } - err = gss_pipes_dentries_create_net(clnt, auth); - if (err) - goto err_destroy_pipe_0; - err = rpcauth_init_credcache(auth); - if (err) - goto err_unlink_pipes; + gss_auth->gss_pipe[0] = gss_pipe; - return auth; -err_unlink_pipes: - gss_pipes_dentries_destroy_net(clnt, auth); -err_destroy_pipe_0: - rpc_destroy_pipe_data(gss_auth->pipe[0]); + return gss_auth; err_destroy_pipe_1: - rpc_destroy_pipe_data(gss_auth->pipe[1]); + gss_pipe_free(gss_auth->gss_pipe[1]); +err_destroy_credcache: + rpcauth_destroy_credcache(auth); err_put_mech: gss_mech_put(gss_auth->mech); +err_put_net: + put_net_track(gss_auth->net, &gss_auth->ns_tracker); err_free: + kfree(gss_auth->target_name); kfree(gss_auth); out_dec: module_put(THIS_MODULE); + trace_rpcgss_createauth(flavor, err); return ERR_PTR(err); } static void gss_free(struct gss_auth *gss_auth) { - gss_pipes_dentries_destroy_net(gss_auth->client, &gss_auth->rpc_auth); - rpc_destroy_pipe_data(gss_auth->pipe[0]); - rpc_destroy_pipe_data(gss_auth->pipe[1]); + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_pipe_free(gss_auth->gss_pipe[1]); gss_mech_put(gss_auth->mech); + put_net_track(gss_auth->net, &gss_auth->ns_tracker); + kfree(gss_auth->target_name); kfree(gss_auth); module_put(THIS_MODULE); @@ -964,49 +1125,172 @@ gss_free_callback(struct kref *kref) } static void +gss_put_auth(struct gss_auth *gss_auth) +{ + kref_put(&gss_auth->kref, gss_free_callback); +} + +static void gss_destroy(struct rpc_auth *auth) { - struct gss_auth *gss_auth; + struct gss_auth *gss_auth = container_of(auth, + struct gss_auth, rpc_auth); - dprintk("RPC: destroying GSS authenticator %p flavor %d\n", - auth, auth->au_flavor); + if (hash_hashed(&gss_auth->hash)) { + spin_lock(&gss_auth_hash_lock); + hash_del(&gss_auth->hash); + spin_unlock(&gss_auth_hash_lock); + } + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_auth->gss_pipe[0] = NULL; + gss_pipe_free(gss_auth->gss_pipe[1]); + gss_auth->gss_pipe[1] = NULL; rpcauth_destroy_credcache(auth); - gss_auth = container_of(auth, struct gss_auth, rpc_auth); - kref_put(&gss_auth->kref, gss_free_callback); + gss_put_auth(gss_auth); +} + +/* + * Auths may be shared between rpc clients that were cloned from a + * common client with the same xprt, if they also share the flavor and + * target_name. + * + * The auth is looked up from the oldest parent sharing the same + * cl_xprt, and the auth itself references only that common parent + * (which is guaranteed to last as long as any of its descendants). + */ +static struct gss_auth * +gss_auth_find_or_add_hashed(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt, + struct gss_auth *new) +{ + struct gss_auth *gss_auth; + unsigned long hashval = (unsigned long)clnt; + + spin_lock(&gss_auth_hash_lock); + hash_for_each_possible(gss_auth_hash_table, + gss_auth, + hash, + hashval) { + if (gss_auth->client != clnt) + continue; + if (gss_auth->rpc_auth.au_flavor != args->pseudoflavor) + continue; + if (gss_auth->target_name != args->target_name) { + if (gss_auth->target_name == NULL) + continue; + if (args->target_name == NULL) + continue; + if (strcmp(gss_auth->target_name, args->target_name)) + continue; + } + if (!refcount_inc_not_zero(&gss_auth->rpc_auth.au_count)) + continue; + goto out; + } + if (new) + hash_add(gss_auth_hash_table, &new->hash, hashval); + gss_auth = new; +out: + spin_unlock(&gss_auth_hash_lock); + return gss_auth; +} + +static struct gss_auth * +gss_create_hashed(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct gss_auth *new; + + gss_auth = gss_auth_find_or_add_hashed(args, clnt, NULL); + if (gss_auth != NULL) + goto out; + new = gss_create_new(args, clnt); + if (IS_ERR(new)) + return new; + gss_auth = gss_auth_find_or_add_hashed(args, clnt, new); + if (gss_auth != new) + gss_destroy(&new->rpc_auth); +out: + return gss_auth; +} + +static struct rpc_auth * +gss_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch); + + while (clnt != clnt->cl_parent) { + struct rpc_clnt *parent = clnt->cl_parent; + /* Find the original parent for this transport */ + if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps) + break; + clnt = parent; + } + + gss_auth = gss_create_hashed(args, clnt); + if (IS_ERR(gss_auth)) + return ERR_CAST(gss_auth); + return &gss_auth->rpc_auth; +} + +static struct gss_cred * +gss_dup_cred(struct gss_auth *gss_auth, struct gss_cred *gss_cred) +{ + struct gss_cred *new; + + /* Make a copy of the cred so that we can reference count it */ + new = kzalloc(sizeof(*gss_cred), GFP_KERNEL); + if (new) { + struct auth_cred acred = { + .cred = gss_cred->gc_base.cr_cred, + }; + struct gss_cl_ctx *ctx = + rcu_dereference_protected(gss_cred->gc_ctx, 1); + + rpcauth_init_cred(&new->gc_base, &acred, + &gss_auth->rpc_auth, + &gss_nullops); + new->gc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + new->gc_service = gss_cred->gc_service; + new->gc_principal = gss_cred->gc_principal; + kref_get(&gss_auth->kref); + rcu_assign_pointer(new->gc_ctx, ctx); + gss_get_ctx(ctx); + } + return new; } /* - * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call + * gss_send_destroy_context will cause the RPCSEC_GSS to send a NULL RPC call * to the server with the GSS control procedure field set to * RPC_GSS_PROC_DESTROY. This should normally cause the server to release * all RPCSEC_GSS state associated with that context. */ -static int -gss_destroying_context(struct rpc_cred *cred) +static void +gss_send_destroy_context(struct rpc_cred *cred) { struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); + struct gss_cred *new; struct rpc_task *task; - if (gss_cred->gc_ctx == NULL || - test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0) - return 0; - - gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY; - cred->cr_ops = &gss_nullops; - - /* Take a reference to ensure the cred will be destroyed either - * by the RPC call or by the put_rpccred() below */ - get_rpccred(cred); + new = gss_dup_cred(gss_auth, gss_cred); + if (new) { + ctx->gc_proc = RPC_GSS_PROC_DESTROY; - task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT); - if (!IS_ERR(task)) - rpc_put_task(task); + trace_rpcgss_ctx_destroy(gss_cred); + task = rpc_call_null(gss_auth->client, &new->gc_base, + RPC_TASK_ASYNC); + if (!IS_ERR(task)) + rpc_put_task(task); - put_rpccred(cred); - return 1; + put_rpccred(&new->gc_base); + } } /* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure @@ -1015,10 +1299,9 @@ gss_destroying_context(struct rpc_cred *cred) static void gss_do_free_ctx(struct gss_cl_ctx *ctx) { - dprintk("RPC: %s\n", __func__); - gss_delete_sec_context(&ctx->gc_gss_ctx); kfree(ctx->gc_wire_ctx.data); + kfree(ctx->gc_acceptor.data); kfree(ctx); } @@ -1038,7 +1321,6 @@ gss_free_ctx(struct gss_cl_ctx *ctx) static void gss_free_cred(struct gss_cred *gss_cred) { - dprintk("RPC: %s cred=%p\n", __func__, gss_cred); kfree(gss_cred); } @@ -1054,45 +1336,48 @@ gss_destroy_nullcred(struct rpc_cred *cred) { struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); - struct gss_cl_ctx *ctx = gss_cred->gc_ctx; + struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); RCU_INIT_POINTER(gss_cred->gc_ctx, NULL); + put_cred(cred->cr_cred); call_rcu(&cred->cr_rcu, gss_free_cred_callback); if (ctx) gss_put_ctx(ctx); - kref_put(&gss_auth->kref, gss_free_callback); + gss_put_auth(gss_auth); } static void gss_destroy_cred(struct rpc_cred *cred) { - - if (gss_destroying_context(cred)) - return; + if (test_and_clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) + gss_send_destroy_context(cred); gss_destroy_nullcred(cred); } +static int +gss_hash_cred(struct auth_cred *acred, unsigned int hashbits) +{ + return hash_64(from_kuid(&init_user_ns, acred->cred->fsuid), hashbits); +} + /* * Lookup RPCSEC_GSS cred for the current process */ -static struct rpc_cred * -gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +static struct rpc_cred *gss_lookup_cred(struct rpc_auth *auth, + struct auth_cred *acred, int flags) { - return rpcauth_lookup_credcache(auth, acred, flags); + return rpcauth_lookup_credcache(auth, acred, flags, + rpc_task_gfp_mask()); } static struct rpc_cred * -gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t gfp) { struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); struct gss_cred *cred = NULL; int err = -ENOMEM; - dprintk("RPC: %s for uid %d, flavor %d\n", - __func__, from_kuid(&init_user_ns, acred->uid), - auth->au_flavor); - - if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS))) + if (!(cred = kzalloc(sizeof(*cred), gfp))) goto out_err; rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops); @@ -1102,14 +1387,11 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) */ cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW; cred->gc_service = gss_auth->service; - cred->gc_principal = NULL; - if (acred->machine_cred) - cred->gc_principal = acred->principal; + cred->gc_principal = acred->principal; kref_get(&gss_auth->kref); return &cred->gc_base; out_err: - dprintk("RPC: %s failed with error %d\n", __func__, err); return ERR_PTR(err); } @@ -1126,87 +1408,195 @@ gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred) return err; } +static char * +gss_stringify_acceptor(struct rpc_cred *cred) +{ + char *string = NULL; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + unsigned int len; + struct xdr_netobj *acceptor; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx) + goto out; + + len = ctx->gc_acceptor.len; + rcu_read_unlock(); + + /* no point if there's no string */ + if (!len) + return NULL; +realloc: + string = kmalloc(len + 1, GFP_KERNEL); + if (!string) + return NULL; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + + /* did the ctx disappear or was it replaced by one with no acceptor? */ + if (!ctx || !ctx->gc_acceptor.len) { + kfree(string); + string = NULL; + goto out; + } + + acceptor = &ctx->gc_acceptor; + + /* + * Did we find a new acceptor that's longer than the original? Allocate + * a longer buffer and try again. + */ + if (len < acceptor->len) { + len = acceptor->len; + rcu_read_unlock(); + kfree(string); + goto realloc; + } + + memcpy(string, acceptor->data, acceptor->len); + string[acceptor->len] = '\0'; +out: + rcu_read_unlock(); + return string; +} + +/* + * Returns -EACCES if GSS context is NULL or will expire within the + * timeout (miliseconds) + */ +static int +gss_key_timeout(struct rpc_cred *rc) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + unsigned long timeout = jiffies + (gss_key_expire_timeo * HZ); + int ret = 0; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx || time_after(timeout, ctx->gc_expiry)) + ret = -EACCES; + rcu_read_unlock(); + + return ret; +} + static int gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags) { struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + int ret; if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags)) goto out; /* Don't match with creds that have expired. */ - if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry)) + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx || time_after(jiffies, ctx->gc_expiry)) { + rcu_read_unlock(); return 0; + } + rcu_read_unlock(); if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags)) return 0; out: if (acred->principal != NULL) { if (gss_cred->gc_principal == NULL) return 0; - return strcmp(acred->principal, gss_cred->gc_principal) == 0; + ret = strcmp(acred->principal, gss_cred->gc_principal) == 0; + } else { + if (gss_cred->gc_principal != NULL) + return 0; + ret = uid_eq(rc->cr_cred->fsuid, acred->cred->fsuid); } - if (gss_cred->gc_principal != NULL) - return 0; - return uid_eq(rc->cr_uid, acred->uid); + return ret; } /* -* Marshal credentials. -* Maybe we should keep a cached credential for performance reasons. -*/ -static __be32 * -gss_marshal(struct rpc_task *task, __be32 *p) + * Marshal credentials. + * + * The expensive part is computing the verifier. We can't cache a + * pre-computed version of the verifier because the seqno, which + * is different every time, is included in the MIC. + */ +static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_cred *cred = req->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - __be32 *cred_len; + __be32 *p, *cred_len; u32 maj_stat = 0; struct xdr_netobj mic; struct kvec iov; struct xdr_buf verf_buf; + int status; + u32 seqno; - dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + /* Credential */ - *p++ = htonl(RPC_AUTH_GSS); + p = xdr_reserve_space(xdr, 7 * sizeof(*p) + + ctx->gc_wire_ctx.len); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_gss; cred_len = p++; spin_lock(&ctx->gc_seq_lock); - req->rq_seqno = ctx->gc_seq++; + seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ; + xprt_rqst_add_seqno(req, seqno); spin_unlock(&ctx->gc_seq_lock); - - *p++ = htonl((u32) RPC_GSS_VERSION); - *p++ = htonl((u32) ctx->gc_proc); - *p++ = htonl((u32) req->rq_seqno); - *p++ = htonl((u32) gss_cred->gc_service); + if (*req->rq_seqnos == MAXSEQ) + goto expired; + trace_rpcgss_seqno(task); + + *p++ = cpu_to_be32(RPC_GSS_VERSION); + *p++ = cpu_to_be32(ctx->gc_proc); + *p++ = cpu_to_be32(*req->rq_seqnos); + *p++ = cpu_to_be32(gss_cred->gc_service); p = xdr_encode_netobj(p, &ctx->gc_wire_ctx); - *cred_len = htonl((p - (cred_len + 1)) << 2); + *cred_len = cpu_to_be32((p - (cred_len + 1)) << 2); + + /* Verifier */ /* We compute the checksum for the verifier over the xdr-encoded bytes * starting with the xid and ending at the end of the credential: */ - iov.iov_base = xprt_skip_transport_header(req->rq_xprt, - req->rq_snd_buf.head[0].iov_base); + iov.iov_base = req->rq_snd_buf.head[0].iov_base; iov.iov_len = (u8 *)p - (u8 *)iov.iov_base; xdr_buf_from_iov(&iov, &verf_buf); - /* set verifier flavor*/ - *p++ = htonl(RPC_AUTH_GSS); - + p = xdr_reserve_space(xdr, sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_gss; mic.data = (u8 *)(p + 1); maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic); - if (maj_stat == GSS_S_CONTEXT_EXPIRED) { - clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); - } else if (maj_stat != 0) { - printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat); - goto out_put_ctx; - } - p = xdr_encode_opaque(p, NULL, mic.len); - gss_put_ctx(ctx); - return p; -out_put_ctx: + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + goto expired; + else if (maj_stat != 0) + goto bad_mic; + if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0) + goto marshal_failed; + status = 0; +out: gss_put_ctx(ctx); - return NULL; + return status; +expired: + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + status = -EKEYEXPIRED; + goto out; +marshal_failed: + status = -EMSGSIZE; + goto out; +bad_mic: + trace_rpcgss_get_mic(task, maj_stat); + status = -EIO; + goto out; } static int gss_renew_cred(struct rpc_task *task) @@ -1217,15 +1607,15 @@ static int gss_renew_cred(struct rpc_task *task) gc_base); struct rpc_auth *auth = oldcred->cr_auth; struct auth_cred acred = { - .uid = oldcred->cr_uid, + .cred = oldcred->cr_cred, .principal = gss_cred->gc_principal, - .machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0), }; struct rpc_cred *new; new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW); if (IS_ERR(new)) return PTR_ERR(new); + task->tk_rqstp->rq_cred = new; put_rpccred(oldcred); return 0; @@ -1236,7 +1626,7 @@ static int gss_cred_is_negative_entry(struct rpc_cred *cred) if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) { unsigned long now = jiffies; unsigned long begin, expire; - struct gss_cred *gss_cred; + struct gss_cred *gss_cred; gss_cred = container_of(cred, struct gss_cred, gc_base); begin = gss_cred->gc_upcall_timestamp; @@ -1278,111 +1668,121 @@ out: static int gss_refresh_null(struct rpc_task *task) { - return -EACCES; + return 0; } -static __be32 * -gss_validate(struct rpc_task *task, __be32 *p) +static u32 +gss_validate_seqno_mic(struct gss_cl_ctx *ctx, u32 seqno, __be32 *seq, __be32 *p, u32 len) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; - struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - __be32 seq; - struct kvec iov; - struct xdr_buf verf_buf; + struct kvec iov; + struct xdr_buf verf_buf; struct xdr_netobj mic; - u32 flav,len; - u32 maj_stat; - - dprintk("RPC: %5u %s\n", task->tk_pid, __func__); - - flav = ntohl(*p++); - if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE) - goto out_bad; - if (flav != RPC_AUTH_GSS) - goto out_bad; - seq = htonl(task->tk_rqstp->rq_seqno); - iov.iov_base = &seq; - iov.iov_len = sizeof(seq); + + *seq = cpu_to_be32(seqno); + iov.iov_base = seq; + iov.iov_len = 4; xdr_buf_from_iov(&iov, &verf_buf); mic.data = (u8 *)p; mic.len = len; + return gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); +} - maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); +static int +gss_validate(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *p, *seq = NULL; + u32 len, maj_stat; + int status; + int i = 1; /* don't recheck the first item */ + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + goto validate_failed; + if (*p++ != rpc_auth_gss) + goto validate_failed; + len = be32_to_cpup(p); + if (len > RPC_MAX_AUTH_SIZE) + goto validate_failed; + p = xdr_inline_decode(xdr, len); + if (!p) + goto validate_failed; + + seq = kmalloc(4, GFP_KERNEL); + if (!seq) + goto validate_failed; + maj_stat = gss_validate_seqno_mic(ctx, task->tk_rqstp->rq_seqnos[0], seq, p, len); + /* RFC 2203 5.3.3.1 - compute the checksum of each sequence number in the cache */ + while (unlikely(maj_stat == GSS_S_BAD_SIG && i < task->tk_rqstp->rq_seqno_count)) + maj_stat = gss_validate_seqno_mic(ctx, task->tk_rqstp->rq_seqnos[i++], seq, p, len); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); - if (maj_stat) { - dprintk("RPC: %5u %s: gss_verify_mic returned error 0x%08x\n", - task->tk_pid, __func__, maj_stat); - goto out_bad; - } + if (maj_stat) + goto bad_mic; + /* We leave it to unwrap to calculate au_rslack. For now we just * calculate the length of the verifier: */ - cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; - gss_put_ctx(ctx); - dprintk("RPC: %5u %s: gss_verify_mic succeeded.\n", - task->tk_pid, __func__); - return p + XDR_QUADLEN(len); -out_bad: + if (test_bit(RPCAUTH_AUTH_UPDATE_SLACK, &cred->cr_auth->au_flags)) + cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; + status = 0; +out: gss_put_ctx(ctx); - dprintk("RPC: %5u %s failed.\n", task->tk_pid, __func__); - return NULL; -} - -static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) -{ - struct xdr_stream xdr; + kfree(seq); + return status; - xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p); - encode(rqstp, &xdr, obj); +validate_failed: + status = -EIO; + goto out; +bad_mic: + trace_rpcgss_verify_mic(task, maj_stat); + status = -EACCES; + goto out; } -static inline int +static noinline_for_stack int gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) + struct rpc_task *task, struct xdr_stream *xdr) { - struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; - struct xdr_buf integ_buf; - __be32 *integ_len = NULL; + struct rpc_rqst *rqstp = task->tk_rqstp; + struct xdr_buf integ_buf, *snd_buf = &rqstp->rq_snd_buf; struct xdr_netobj mic; - u32 offset; - __be32 *q; - struct kvec *iov; - u32 maj_stat = 0; - int status = -EIO; + __be32 *p, *integ_len; + u32 offset, maj_stat; + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto wrap_failed; integ_len = p++; - offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; - *p++ = htonl(rqstp->rq_seqno); + *p = cpu_to_be32(*rqstp->rq_seqnos); - gss_wrap_req_encode(encode, rqstp, p, obj); + if (rpcauth_wrap_req_encode(task, xdr)) + goto wrap_failed; + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; if (xdr_buf_subsegment(snd_buf, &integ_buf, offset, snd_buf->len - offset)) - return status; - *integ_len = htonl(integ_buf.len); + goto wrap_failed; + *integ_len = cpu_to_be32(integ_buf.len); - /* guess whether we're in the head or the tail: */ - if (snd_buf->page_len || snd_buf->tail[0].iov_len) - iov = snd_buf->tail; - else - iov = snd_buf->head; - p = iov->iov_base + iov->iov_len; + p = xdr_reserve_space(xdr, 0); + if (!p) + goto wrap_failed; mic.data = (u8 *)(p + 1); - maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic); - status = -EIO; /* XXX? */ if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); else if (maj_stat) - return status; - q = xdr_encode_opaque(p, NULL, mic.len); - - offset = (u8 *)q - (u8 *)p; - iov->iov_len += offset; - snd_buf->len += offset; + goto bad_mic; + /* Check that the trailing MIC fit in the buffer, after the fact */ + if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0) + goto wrap_failed; return 0; +wrap_failed: + return -EMSGSIZE; +bad_mic: + trace_rpcgss_get_mic(task, maj_stat); + return -EIO; } static void @@ -1393,6 +1793,7 @@ priv_release_snd_buf(struct rpc_rqst *rqstp) for (i=0; i < rqstp->rq_enc_pages_num; i++) __free_page(rqstp->rq_enc_pages[i]); kfree(rqstp->rq_enc_pages); + rqstp->rq_release_snd_buf = NULL; } static int @@ -1401,21 +1802,25 @@ alloc_enc_pages(struct rpc_rqst *rqstp) struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; int first, last, i; + if (rqstp->rq_release_snd_buf) + rqstp->rq_release_snd_buf(rqstp); + if (snd_buf->page_len == 0) { rqstp->rq_enc_pages_num = 0; return 0; } - first = snd_buf->page_base >> PAGE_CACHE_SHIFT; - last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT; + first = snd_buf->page_base >> PAGE_SHIFT; + last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_SHIFT; rqstp->rq_enc_pages_num = last - first + 1 + 1; rqstp->rq_enc_pages - = kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *), - GFP_NOFS); + = kmalloc_array(rqstp->rq_enc_pages_num, + sizeof(struct page *), + GFP_KERNEL); if (!rqstp->rq_enc_pages) goto out; for (i=0; i < rqstp->rq_enc_pages_num; i++) { - rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS); + rqstp->rq_enc_pages[i] = alloc_page(GFP_KERNEL); if (rqstp->rq_enc_pages[i] == NULL) goto out_free; } @@ -1428,224 +1833,350 @@ out: return -EAGAIN; } -static inline int +static noinline_for_stack int gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) + struct rpc_task *task, struct xdr_stream *xdr) { + struct rpc_rqst *rqstp = task->tk_rqstp; struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; - u32 offset; - u32 maj_stat; + u32 pad, offset, maj_stat; int status; - __be32 *opaque_len; + __be32 *p, *opaque_len; struct page **inpages; int first; - int pad; struct kvec *iov; - char *tmp; + status = -EIO; + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto wrap_failed; opaque_len = p++; - offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; - *p++ = htonl(rqstp->rq_seqno); + *p = cpu_to_be32(*rqstp->rq_seqnos); - gss_wrap_req_encode(encode, rqstp, p, obj); + if (rpcauth_wrap_req_encode(task, xdr)) + goto wrap_failed; status = alloc_enc_pages(rqstp); - if (status) - return status; - first = snd_buf->page_base >> PAGE_CACHE_SHIFT; + if (unlikely(status)) + goto wrap_failed; + first = snd_buf->page_base >> PAGE_SHIFT; inpages = snd_buf->pages + first; snd_buf->pages = rqstp->rq_enc_pages; - snd_buf->page_base -= first << PAGE_CACHE_SHIFT; + snd_buf->page_base -= first << PAGE_SHIFT; /* - * Give the tail its own page, in case we need extra space in the - * head when wrapping: + * Move the tail into its own page, in case gss_wrap needs + * more space in the head when wrapping. * - * call_allocate() allocates twice the slack space required - * by the authentication flavor to rq_callsize. - * For GSS, slack is GSS_CRED_SLACK. + * Still... Why can't gss_wrap just slide the tail down? */ if (snd_buf->page_len || snd_buf->tail[0].iov_len) { + char *tmp; + tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]); memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len); snd_buf->tail[0].iov_base = tmp; } + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); /* slack space should prevent this ever happening: */ - BUG_ON(snd_buf->len > snd_buf->buflen); - status = -EIO; + if (unlikely(snd_buf->len > snd_buf->buflen)) { + status = -EIO; + goto wrap_failed; + } /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was * done anyway, so it's safe to put the request on the wire: */ if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); else if (maj_stat) - return status; + goto bad_wrap; - *opaque_len = htonl(snd_buf->len - offset); - /* guess whether we're in the head or the tail: */ + *opaque_len = cpu_to_be32(snd_buf->len - offset); + /* guess whether the pad goes into the head or the tail: */ if (snd_buf->page_len || snd_buf->tail[0].iov_len) iov = snd_buf->tail; else iov = snd_buf->head; p = iov->iov_base + iov->iov_len; - pad = 3 - ((snd_buf->len - offset - 1) & 3); + pad = xdr_pad_size(snd_buf->len - offset); memset(p, 0, pad); iov->iov_len += pad; snd_buf->len += pad; return 0; +wrap_failed: + return status; +bad_wrap: + trace_rpcgss_wrap(task, maj_stat); + return -EIO; } -static int -gss_wrap_req(struct rpc_task *task, - kxdreproc_t encode, void *rqstp, __be32 *p, void *obj) +static int gss_wrap_req(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - int status = -EIO; + int status; - dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + status = -EIO; if (ctx->gc_proc != RPC_GSS_PROC_DATA) { /* The spec seems a little ambiguous here, but I think that not * wrapping context destruction requests makes the most sense. */ - gss_wrap_req_encode(encode, rqstp, p, obj); - status = 0; + status = rpcauth_wrap_req_encode(task, xdr); goto out; } switch (gss_cred->gc_service) { case RPC_GSS_SVC_NONE: - gss_wrap_req_encode(encode, rqstp, p, obj); - status = 0; + status = rpcauth_wrap_req_encode(task, xdr); break; case RPC_GSS_SVC_INTEGRITY: - status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj); + status = gss_wrap_req_integ(cred, ctx, task, xdr); break; case RPC_GSS_SVC_PRIVACY: - status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj); + status = gss_wrap_req_priv(cred, ctx, task, xdr); break; + default: + status = -EIO; } out: gss_put_ctx(ctx); - dprintk("RPC: %5u %s returning %d\n", task->tk_pid, __func__, status); return status; } -static inline int -gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - struct rpc_rqst *rqstp, __be32 **p) +/** + * gss_update_rslack - Possibly update RPC receive buffer size estimates + * @task: rpc_task for incoming RPC Reply being unwrapped + * @cred: controlling rpc_cred for @task + * @before: XDR words needed before each RPC Reply message + * @after: XDR words needed following each RPC Reply message + * + */ +static void gss_update_rslack(struct rpc_task *task, struct rpc_cred *cred, + unsigned int before, unsigned int after) +{ + struct rpc_auth *auth = cred->cr_auth; + + if (test_and_clear_bit(RPCAUTH_AUTH_UPDATE_SLACK, &auth->au_flags)) { + auth->au_ralign = auth->au_verfsize + before; + auth->au_rslack = auth->au_verfsize + after; + trace_rpcgss_update_slack(task, auth); + } +} + +static int +gss_unwrap_resp_auth(struct rpc_task *task, struct rpc_cred *cred) +{ + gss_update_rslack(task, cred, 0, 0); + return 0; +} + +/* + * RFC 2203, Section 5.3.2.2 + * + * struct rpc_gss_integ_data { + * opaque databody_integ<>; + * opaque checksum<>; + * }; + * + * struct rpc_gss_data_t { + * unsigned int seq_num; + * proc_req_arg_t arg; + * }; + */ +static noinline_for_stack int +gss_unwrap_resp_integ(struct rpc_task *task, struct rpc_cred *cred, + struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp, + struct xdr_stream *xdr) { - struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; - struct xdr_buf integ_buf; + struct xdr_buf gss_data, *rcv_buf = &rqstp->rq_rcv_buf; + u32 len, offset, seqno, maj_stat; struct xdr_netobj mic; - u32 data_offset, mic_offset; - u32 integ_len; - u32 maj_stat; - int status = -EIO; + int ret; - integ_len = ntohl(*(*p)++); - if (integ_len & 3) - return status; - data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; - mic_offset = integ_len + data_offset; - if (mic_offset > rcv_buf->len) - return status; - if (ntohl(*(*p)++) != rqstp->rq_seqno) - return status; + ret = -EIO; + mic.data = NULL; + + /* opaque databody_integ<>; */ + if (xdr_stream_decode_u32(xdr, &len)) + goto unwrap_failed; + if (len & 3) + goto unwrap_failed; + offset = rcv_buf->len - xdr_stream_remaining(xdr); + if (xdr_stream_decode_u32(xdr, &seqno)) + goto unwrap_failed; + if (seqno != *rqstp->rq_seqnos) + goto bad_seqno; + if (xdr_buf_subsegment(rcv_buf, &gss_data, offset, len)) + goto unwrap_failed; - if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset, - mic_offset - data_offset)) - return status; + /* + * The xdr_stream now points to the beginning of the + * upper layer payload, to be passed below to + * rpcauth_unwrap_resp_decode(). The checksum, which + * follows the upper layer payload in @rcv_buf, is + * located and parsed without updating the xdr_stream. + */ - if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset)) - return status; + /* opaque checksum<>; */ + offset += len; + if (xdr_decode_word(rcv_buf, offset, &len)) + goto unwrap_failed; + offset += sizeof(__be32); + if (offset + len > rcv_buf->len) + goto unwrap_failed; + mic.len = len; + mic.data = kmalloc(len, GFP_KERNEL); + if (ZERO_OR_NULL_PTR(mic.data)) + goto unwrap_failed; + if (read_bytes_from_xdr_buf(rcv_buf, offset, mic.data, mic.len)) + goto unwrap_failed; - maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic); + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &gss_data, &mic); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); if (maj_stat != GSS_S_COMPLETE) - return status; - return 0; -} + goto bad_mic; -static inline int -gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - struct rpc_rqst *rqstp, __be32 **p) -{ - struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; - u32 offset; - u32 opaque_len; - u32 maj_stat; - int status = -EIO; + gss_update_rslack(task, cred, 2, 2 + 1 + XDR_QUADLEN(mic.len)); + ret = 0; + +out: + kfree(mic.data); + return ret; - opaque_len = ntohl(*(*p)++); - offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; +unwrap_failed: + trace_rpcgss_unwrap_failed(task); + goto out; +bad_seqno: + trace_rpcgss_bad_seqno(task, *rqstp->rq_seqnos, seqno); + goto out; +bad_mic: + trace_rpcgss_verify_mic(task, maj_stat); + goto out; +} + +static noinline_for_stack int +gss_unwrap_resp_priv(struct rpc_task *task, struct rpc_cred *cred, + struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp, + struct xdr_stream *xdr) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + struct kvec *head = rqstp->rq_rcv_buf.head; + u32 offset, opaque_len, maj_stat; + __be32 *p; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (unlikely(!p)) + goto unwrap_failed; + opaque_len = be32_to_cpup(p++); + offset = (u8 *)(p) - (u8 *)head->iov_base; if (offset + opaque_len > rcv_buf->len) - return status; - /* remove padding: */ - rcv_buf->len = offset + opaque_len; + goto unwrap_failed; - maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf); + maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, + offset + opaque_len, rcv_buf); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); if (maj_stat != GSS_S_COMPLETE) - return status; - if (ntohl(*(*p)++) != rqstp->rq_seqno) - return status; + goto bad_unwrap; + /* gss_unwrap decrypted the sequence number */ + if (be32_to_cpup(p++) != *rqstp->rq_seqnos) + goto bad_seqno; + + /* gss_unwrap redacts the opaque blob from the head iovec. + * rcv_buf has changed, thus the stream needs to be reset. + */ + xdr_init_decode(xdr, rcv_buf, p, rqstp); + + gss_update_rslack(task, cred, 2 + ctx->gc_gss_ctx->align, + 2 + ctx->gc_gss_ctx->slack); return 0; +unwrap_failed: + trace_rpcgss_unwrap_failed(task); + return -EIO; +bad_seqno: + trace_rpcgss_bad_seqno(task, *rqstp->rq_seqnos, be32_to_cpup(--p)); + return -EIO; +bad_unwrap: + trace_rpcgss_unwrap(task, maj_stat); + return -EIO; } -static int -gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) +static bool +gss_seq_is_newer(u32 new, u32 old) { - struct xdr_stream xdr; + return (s32)(new - old) > 0; +} - xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p); - return decode(rqstp, &xdr, obj); +static bool +gss_xmit_need_reencode(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *cred = req->rq_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + u32 win, seq_xmit = 0; + bool ret = true; + + if (!ctx) + goto out; + + if (gss_seq_is_newer(*req->rq_seqnos, READ_ONCE(ctx->gc_seq))) + goto out_ctx; + + seq_xmit = READ_ONCE(ctx->gc_seq_xmit); + while (gss_seq_is_newer(*req->rq_seqnos, seq_xmit)) { + u32 tmp = seq_xmit; + + seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, *req->rq_seqnos); + if (seq_xmit == tmp) { + ret = false; + goto out_ctx; + } + } + + win = ctx->gc_win; + if (win > 0) + ret = !gss_seq_is_newer(*req->rq_seqnos, seq_xmit - win); + +out_ctx: + gss_put_ctx(ctx); +out: + trace_rpcgss_need_reencode(task, seq_xmit, ret); + return ret; } static int -gss_unwrap_resp(struct rpc_task *task, - kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj) +gss_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct rpc_rqst *rqstp = task->tk_rqstp; + struct rpc_cred *cred = rqstp->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - __be32 *savedp = p; - struct kvec *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head; - int savedlen = head->iov_len; - int status = -EIO; + int status = -EIO; if (ctx->gc_proc != RPC_GSS_PROC_DATA) goto out_decode; switch (gss_cred->gc_service) { case RPC_GSS_SVC_NONE: + status = gss_unwrap_resp_auth(task, cred); break; case RPC_GSS_SVC_INTEGRITY: - status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p); - if (status) - goto out; + status = gss_unwrap_resp_integ(task, cred, ctx, rqstp, xdr); break; case RPC_GSS_SVC_PRIVACY: - status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p); - if (status) - goto out; + status = gss_unwrap_resp_priv(task, cred, ctx, rqstp, xdr); break; } - /* take into account extra slack for integrity and privacy cases: */ - cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp) - + (savedlen - head->iov_len); + if (status) + goto out; + out_decode: - status = gss_unwrap_req_decode(decode, rqstp, p, obj); + status = rpcauth_unwrap_resp_decode(task, xdr); out: gss_put_ctx(ctx); - dprintk("RPC: %5u %s returning %d\n", - task->tk_pid, __func__, status); return status; } @@ -1655,42 +2186,42 @@ static const struct rpc_authops authgss_ops = { .au_name = "RPCSEC_GSS", .create = gss_create, .destroy = gss_destroy, + .hash_cred = gss_hash_cred, .lookup_cred = gss_lookup_cred, .crcreate = gss_create_cred, - .pipes_create = gss_pipes_dentries_create, - .pipes_destroy = gss_pipes_dentries_destroy, - .list_pseudoflavors = gss_mech_list_pseudoflavors, .info2flavor = gss_mech_info2flavor, .flavor2info = gss_mech_flavor2info, }; static const struct rpc_credops gss_credops = { - .cr_name = "AUTH_GSS", - .crdestroy = gss_destroy_cred, - .cr_init = gss_cred_init, - .crbind = rpcauth_generic_bind_cred, - .crmatch = gss_match, - .crmarshal = gss_marshal, - .crrefresh = gss_refresh, - .crvalidate = gss_validate, - .crwrap_req = gss_wrap_req, - .crunwrap_resp = gss_unwrap_resp, + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_cred, + .cr_init = gss_cred_init, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, + .crkey_timeout = gss_key_timeout, + .crstringify_acceptor = gss_stringify_acceptor, + .crneed_reencode = gss_xmit_need_reencode, }; static const struct rpc_credops gss_nullops = { - .cr_name = "AUTH_GSS", - .crdestroy = gss_destroy_nullcred, - .crbind = rpcauth_generic_bind_cred, - .crmatch = gss_match, - .crmarshal = gss_marshal, - .crrefresh = gss_refresh_null, - .crvalidate = gss_validate, - .crwrap_req = gss_wrap_req, - .crunwrap_resp = gss_unwrap_resp, + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_nullcred, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh_null, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, + .crstringify_acceptor = gss_stringify_acceptor, }; static const struct rpc_pipe_ops gss_upcall_ops_v0 = { - .upcall = rpc_pipe_generic_upcall, + .upcall = gss_v0_upcall, .downcall = gss_pipe_downcall, .destroy_msg = gss_pipe_destroy_msg, .open_pipe = gss_pipe_open_v0, @@ -1698,7 +2229,7 @@ static const struct rpc_pipe_ops gss_upcall_ops_v0 = { }; static const struct rpc_pipe_ops gss_upcall_ops_v1 = { - .upcall = rpc_pipe_generic_upcall, + .upcall = gss_v1_upcall, .downcall = gss_pipe_downcall, .destroy_msg = gss_pipe_destroy_msg, .open_pipe = gss_pipe_open_v1, @@ -1755,6 +2286,7 @@ static void __exit exit_rpcsec_gss(void) } MODULE_ALIAS("rpc-auth-6"); +MODULE_DESCRIPTION("Sun RPC Kerberos RPCSEC_GSS client authentication"); MODULE_LICENSE("GPL"); module_param_named(expired_cred_retry_delay, gss_expired_cred_retry_delay, @@ -1762,5 +2294,12 @@ module_param_named(expired_cred_retry_delay, MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until " "the RPC engine retries an expired credential"); +module_param_named(key_expire_timeo, + gss_key_expire_timeo, + uint, 0644); +MODULE_PARM_DESC(key_expire_timeo, "Time (in seconds) at the end of a " + "credential keys lifetime where the NFS layer cleans up " + "prior to key expiration"); + module_init(init_rpcsec_gss) module_exit(exit_rpcsec_gss) diff --git a/net/sunrpc/auth_gss/auth_gss_internal.h b/net/sunrpc/auth_gss/auth_gss_internal.h new file mode 100644 index 000000000000..4ebc1b7043d9 --- /dev/null +++ b/net/sunrpc/auth_gss/auth_gss_internal.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: BSD-3-Clause +/* + * linux/net/sunrpc/auth_gss/auth_gss_internal.h + * + * Internal definitions for RPCSEC_GSS client authentication + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + */ +#include <linux/err.h> +#include <linux/string.h> +#include <linux/sunrpc/xdr.h> + +static inline const void * +simple_get_bytes(const void *p, const void *end, void *res, size_t len) +{ + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + memcpy(res, p, len); + return q; +} + +static inline const void * +simple_get_netobj_noprof(const void *p, const void *end, struct xdr_netobj *dest) +{ + const void *q; + unsigned int len; + + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + if (len) { + dest->data = kmemdup_noprof(p, len, GFP_KERNEL); + if (unlikely(dest->data == NULL)) + return ERR_PTR(-ENOMEM); + } else + dest->data = NULL; + dest->len = len; + return q; +} + +#define simple_get_netobj(...) alloc_hooks(simple_get_netobj_noprof(__VA_ARGS__)) diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c deleted file mode 100644 index c586e92bcf76..000000000000 --- a/net/sunrpc/auth_gss/gss_generic_token.c +++ /dev/null @@ -1,234 +0,0 @@ -/* - * linux/net/sunrpc/gss_generic_token.c - * - * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/generic/util_token.c - * - * Copyright (c) 2000 The Regents of the University of Michigan. - * All rights reserved. - * - * Andy Adamson <andros@umich.edu> - */ - -/* - * Copyright 1993 by OpenVision Technologies, Inc. - * - * Permission to use, copy, modify, distribute, and sell this software - * and its documentation for any purpose is hereby granted without fee, - * provided that the above copyright notice appears in all copies and - * that both that copyright notice and this permission notice appear in - * supporting documentation, and that the name of OpenVision not be used - * in advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. OpenVision makes no - * representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. - * - * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, - * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO - * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR - * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF - * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR - * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR - * PERFORMANCE OF THIS SOFTWARE. - */ - -#include <linux/types.h> -#include <linux/module.h> -#include <linux/string.h> -#include <linux/sunrpc/sched.h> -#include <linux/sunrpc/gss_asn1.h> - - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - - -/* TWRITE_STR from gssapiP_generic.h */ -#define TWRITE_STR(ptr, str, len) \ - memcpy((ptr), (char *) (str), (len)); \ - (ptr) += (len); - -/* XXXX this code currently makes the assumption that a mech oid will - never be longer than 127 bytes. This assumption is not inherent in - the interfaces, so the code can be fixed if the OSI namespace - balloons unexpectedly. */ - -/* Each token looks like this: - -0x60 tag for APPLICATION 0, SEQUENCE - (constructed, definite-length) - <length> possible multiple bytes, need to parse/generate - 0x06 tag for OBJECT IDENTIFIER - <moid_length> compile-time constant string (assume 1 byte) - <moid_bytes> compile-time constant string - <inner_bytes> the ANY containing the application token - bytes 0,1 are the token type - bytes 2,n are the token data - -For the purposes of this abstraction, the token "header" consists of -the sequence tag and length octets, the mech OID DER encoding, and the -first two inner bytes, which indicate the token type. The token -"body" consists of everything else. - -*/ - -static int -der_length_size( int length) -{ - if (length < (1<<7)) - return 1; - else if (length < (1<<8)) - return 2; -#if (SIZEOF_INT == 2) - else - return 3; -#else - else if (length < (1<<16)) - return 3; - else if (length < (1<<24)) - return 4; - else - return 5; -#endif -} - -static void -der_write_length(unsigned char **buf, int length) -{ - if (length < (1<<7)) { - *(*buf)++ = (unsigned char) length; - } else { - *(*buf)++ = (unsigned char) (der_length_size(length)+127); -#if (SIZEOF_INT > 2) - if (length >= (1<<24)) - *(*buf)++ = (unsigned char) (length>>24); - if (length >= (1<<16)) - *(*buf)++ = (unsigned char) ((length>>16)&0xff); -#endif - if (length >= (1<<8)) - *(*buf)++ = (unsigned char) ((length>>8)&0xff); - *(*buf)++ = (unsigned char) (length&0xff); - } -} - -/* returns decoded length, or < 0 on failure. Advances buf and - decrements bufsize */ - -static int -der_read_length(unsigned char **buf, int *bufsize) -{ - unsigned char sf; - int ret; - - if (*bufsize < 1) - return -1; - sf = *(*buf)++; - (*bufsize)--; - if (sf & 0x80) { - if ((sf &= 0x7f) > ((*bufsize)-1)) - return -1; - if (sf > SIZEOF_INT) - return -1; - ret = 0; - for (; sf; sf--) { - ret = (ret<<8) + (*(*buf)++); - (*bufsize)--; - } - } else { - ret = sf; - } - - return ret; -} - -/* returns the length of a token, given the mech oid and the body size */ - -int -g_token_size(struct xdr_netobj *mech, unsigned int body_size) -{ - /* set body_size to sequence contents size */ - body_size += 2 + (int) mech->len; /* NEED overflow check */ - return 1 + der_length_size(body_size) + body_size; -} - -EXPORT_SYMBOL_GPL(g_token_size); - -/* fills in a buffer with the token header. The buffer is assumed to - be the right size. buf is advanced past the token header */ - -void -g_make_token_header(struct xdr_netobj *mech, int body_size, unsigned char **buf) -{ - *(*buf)++ = 0x60; - der_write_length(buf, 2 + mech->len + body_size); - *(*buf)++ = 0x06; - *(*buf)++ = (unsigned char) mech->len; - TWRITE_STR(*buf, mech->data, ((int) mech->len)); -} - -EXPORT_SYMBOL_GPL(g_make_token_header); - -/* - * Given a buffer containing a token, reads and verifies the token, - * leaving buf advanced past the token header, and setting body_size - * to the number of remaining bytes. Returns 0 on success, - * G_BAD_TOK_HEADER for a variety of errors, and G_WRONG_MECH if the - * mechanism in the token does not match the mech argument. buf and - * *body_size are left unmodified on error. - */ -u32 -g_verify_token_header(struct xdr_netobj *mech, int *body_size, - unsigned char **buf_in, int toksize) -{ - unsigned char *buf = *buf_in; - int seqsize; - struct xdr_netobj toid; - int ret = 0; - - if ((toksize-=1) < 0) - return G_BAD_TOK_HEADER; - if (*buf++ != 0x60) - return G_BAD_TOK_HEADER; - - if ((seqsize = der_read_length(&buf, &toksize)) < 0) - return G_BAD_TOK_HEADER; - - if (seqsize != toksize) - return G_BAD_TOK_HEADER; - - if ((toksize-=1) < 0) - return G_BAD_TOK_HEADER; - if (*buf++ != 0x06) - return G_BAD_TOK_HEADER; - - if ((toksize-=1) < 0) - return G_BAD_TOK_HEADER; - toid.len = *buf++; - - if ((toksize-=toid.len) < 0) - return G_BAD_TOK_HEADER; - toid.data = buf; - buf+=toid.len; - - if (! g_OID_equal(&toid, mech)) - ret = G_WRONG_MECH; - - /* G_WRONG_MECH is not returned immediately because it's more important - to return G_BAD_TOK_HEADER if the token header is in fact bad */ - - if ((toksize-=2) < 0) - return G_BAD_TOK_HEADER; - - if (ret) - return ret; - - if (!ret) { - *buf_in = buf; - *body_size = toksize; - } - - return ret; -} - -EXPORT_SYMBOL_GPL(g_verify_token_header); - diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c index 0f43e894bc0a..16dcf115de1e 100644 --- a/net/sunrpc/auth_gss/gss_krb5_crypto.c +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -34,24 +34,74 @@ * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. */ +#include <crypto/hash.h> +#include <crypto/skcipher.h> +#include <crypto/utils.h> #include <linux/err.h> #include <linux/types.h> #include <linux/mm.h> #include <linux/scatterlist.h> -#include <linux/crypto.h> #include <linux/highmem.h> #include <linux/pagemap.h> #include <linux/random.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/sunrpc/xdr.h> +#include <kunit/visibility.h> -#ifdef RPC_DEBUG +#include "gss_krb5_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif +/** + * krb5_make_confounder - Generate a confounder string + * @p: memory location into which to write the string + * @conflen: string length to write, in octets + * + * RFCs 1964 and 3961 mention only "a random confounder" without going + * into detail about its function or cryptographic requirements. The + * assumed purpose is to prevent repeated encryption of a plaintext with + * the same key from generating the same ciphertext. It is also used to + * pad minimum plaintext length to at least a single cipher block. + * + * However, in situations like the GSS Kerberos 5 mechanism, where the + * encryption IV is always all zeroes, the confounder also effectively + * functions like an IV. Thus, not only must it be unique from message + * to message, but it must also be difficult to predict. Otherwise an + * attacker can correlate the confounder to previous or future values, + * making the encryption easier to break. + * + * Given that the primary consumer of this encryption mechanism is a + * network storage protocol, a type of traffic that often carries + * predictable payloads (eg, all zeroes when reading unallocated blocks + * from a file), our confounder generation has to be cryptographically + * strong. + */ +void krb5_make_confounder(u8 *p, int conflen) +{ + get_random_bytes(p, conflen); +} + +/** + * krb5_encrypt - simple encryption of an RPCSEC GSS payload + * @tfm: initialized cipher transform + * @iv: pointer to an IV + * @in: plaintext to encrypt + * @out: OUT: ciphertext + * @length: length of input and output buffers, in bytes + * + * @iv may be NULL to force the use of an all-zero IV. + * The buffer containing the IV must be as large as the + * cipher's ivsize. + * + * Return values: + * %0: @in successfully encrypted into @out + * negative errno: @in not encrypted + */ u32 krb5_encrypt( - struct crypto_blkcipher *tfm, + struct crypto_sync_skcipher *tfm, void * iv, void * in, void * out, @@ -60,334 +110,120 @@ krb5_encrypt( u32 ret = -EINVAL; struct scatterlist sg[1]; u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; - struct blkcipher_desc desc = { .tfm = tfm, .info = local_iv }; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - if (length % crypto_blkcipher_blocksize(tfm) != 0) + if (length % crypto_sync_skcipher_blocksize(tfm) != 0) goto out; - if (crypto_blkcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n", - crypto_blkcipher_ivsize(tfm)); + crypto_sync_skcipher_ivsize(tfm)); goto out; } if (iv) - memcpy(local_iv, iv, crypto_blkcipher_ivsize(tfm)); + memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm)); memcpy(out, in, length); sg_init_one(sg, out, length); - ret = crypto_blkcipher_encrypt_iv(&desc, sg, sg, length); -out: - dprintk("RPC: krb5_encrypt returns %d\n", ret); - return ret; -} - -u32 -krb5_decrypt( - struct crypto_blkcipher *tfm, - void * iv, - void * in, - void * out, - int length) -{ - u32 ret = -EINVAL; - struct scatterlist sg[1]; - u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; - struct blkcipher_desc desc = { .tfm = tfm, .info = local_iv }; - - if (length % crypto_blkcipher_blocksize(tfm) != 0) - goto out; + skcipher_request_set_sync_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, length, local_iv); - if (crypto_blkcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { - dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", - crypto_blkcipher_ivsize(tfm)); - goto out; - } - if (iv) - memcpy(local_iv,iv, crypto_blkcipher_ivsize(tfm)); - - memcpy(out, in, length); - sg_init_one(sg, out, length); - - ret = crypto_blkcipher_decrypt_iv(&desc, sg, sg, length); + ret = crypto_skcipher_encrypt(req); + skcipher_request_zero(req); out: - dprintk("RPC: gss_k5decrypt returns %d\n",ret); + dprintk("RPC: krb5_encrypt returns %d\n", ret); return ret; } static int checksummer(struct scatterlist *sg, void *data) { - struct hash_desc *desc = data; - - return crypto_hash_update(desc, sg, sg->length); -} - -static int -arcfour_hmac_md5_usage_to_salt(unsigned int usage, u8 salt[4]) -{ - unsigned int ms_usage; - - switch (usage) { - case KG_USAGE_SIGN: - ms_usage = 15; - break; - case KG_USAGE_SEAL: - ms_usage = 13; - break; - default: - return -EINVAL; - } - salt[0] = (ms_usage >> 0) & 0xff; - salt[1] = (ms_usage >> 8) & 0xff; - salt[2] = (ms_usage >> 16) & 0xff; - salt[3] = (ms_usage >> 24) & 0xff; - - return 0; -} - -static u32 -make_checksum_hmac_md5(struct krb5_ctx *kctx, char *header, int hdrlen, - struct xdr_buf *body, int body_offset, u8 *cksumkey, - unsigned int usage, struct xdr_netobj *cksumout) -{ - struct hash_desc desc; - struct scatterlist sg[1]; - int err; - u8 checksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - u8 rc4salt[4]; - struct crypto_hash *md5; - struct crypto_hash *hmac_md5; - - if (cksumkey == NULL) - return GSS_S_FAILURE; - - if (cksumout->len < kctx->gk5e->cksumlength) { - dprintk("%s: checksum buffer length, %u, too small for %s\n", - __func__, cksumout->len, kctx->gk5e->name); - return GSS_S_FAILURE; - } - - if (arcfour_hmac_md5_usage_to_salt(usage, rc4salt)) { - dprintk("%s: invalid usage value %u\n", __func__, usage); - return GSS_S_FAILURE; - } - - md5 = crypto_alloc_hash("md5", 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(md5)) - return GSS_S_FAILURE; - - hmac_md5 = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(hmac_md5)) { - crypto_free_hash(md5); - return GSS_S_FAILURE; - } - - desc.tfm = md5; - desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; - - err = crypto_hash_init(&desc); - if (err) - goto out; - sg_init_one(sg, rc4salt, 4); - err = crypto_hash_update(&desc, sg, 4); - if (err) - goto out; - - sg_init_one(sg, header, hdrlen); - err = crypto_hash_update(&desc, sg, hdrlen); - if (err) - goto out; - err = xdr_process_buf(body, body_offset, body->len - body_offset, - checksummer, &desc); - if (err) - goto out; - err = crypto_hash_final(&desc, checksumdata); - if (err) - goto out; + struct ahash_request *req = data; - desc.tfm = hmac_md5; - desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + ahash_request_set_crypt(req, sg, NULL, sg->length); - err = crypto_hash_init(&desc); - if (err) - goto out; - err = crypto_hash_setkey(hmac_md5, cksumkey, kctx->gk5e->keylength); - if (err) - goto out; - - sg_init_one(sg, checksumdata, crypto_hash_digestsize(md5)); - err = crypto_hash_digest(&desc, sg, crypto_hash_digestsize(md5), - checksumdata); - if (err) - goto out; - - memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); - cksumout->len = kctx->gk5e->cksumlength; -out: - crypto_free_hash(md5); - crypto_free_hash(hmac_md5); - return err ? GSS_S_FAILURE : 0; + return crypto_ahash_update(req); } -/* - * checksum the plaintext data and hdrlen bytes of the token header - * The checksum is performed over the first 8 bytes of the - * gss token header and then over the data body +/** + * gss_krb5_checksum - Compute the MAC for a GSS Wrap or MIC token + * @tfm: an initialized hash transform + * @header: pointer to a buffer containing the token header, or NULL + * @hdrlen: number of octets in @header + * @body: xdr_buf containing an RPC message (body.len is the message length) + * @body_offset: byte offset into @body to start checksumming + * @cksumout: OUT: a buffer to be filled in with the computed HMAC + * + * Usually expressed as H = HMAC(K, message)[1..h] . + * + * Caller provides the truncation length of the output token (h) in + * cksumout.len. + * + * Return values: + * %GSS_S_COMPLETE: Digest computed, @cksumout filled in + * %GSS_S_FAILURE: Call failed */ u32 -make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen, - struct xdr_buf *body, int body_offset, u8 *cksumkey, - unsigned int usage, struct xdr_netobj *cksumout) +gss_krb5_checksum(struct crypto_ahash *tfm, char *header, int hdrlen, + const struct xdr_buf *body, int body_offset, + struct xdr_netobj *cksumout) { - struct hash_desc desc; - struct scatterlist sg[1]; - int err; - u8 checksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - unsigned int checksumlen; - - if (kctx->gk5e->ctype == CKSUMTYPE_HMAC_MD5_ARCFOUR) - return make_checksum_hmac_md5(kctx, header, hdrlen, - body, body_offset, - cksumkey, usage, cksumout); - - if (cksumout->len < kctx->gk5e->cksumlength) { - dprintk("%s: checksum buffer length, %u, too small for %s\n", - __func__, cksumout->len, kctx->gk5e->name); - return GSS_S_FAILURE; - } + struct ahash_request *req; + int err = -ENOMEM; + u8 *checksumdata; - desc.tfm = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(desc.tfm)) + checksumdata = kmalloc(crypto_ahash_digestsize(tfm), GFP_KERNEL); + if (!checksumdata) return GSS_S_FAILURE; - desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; - checksumlen = crypto_hash_digestsize(desc.tfm); - - if (cksumkey != NULL) { - err = crypto_hash_setkey(desc.tfm, cksumkey, - kctx->gk5e->keylength); - if (err) - goto out; - } - - err = crypto_hash_init(&desc); - if (err) - goto out; - sg_init_one(sg, header, hdrlen); - err = crypto_hash_update(&desc, sg, hdrlen); + req = ahash_request_alloc(tfm, GFP_KERNEL); + if (!req) + goto out_free_cksum; + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + err = crypto_ahash_init(req); if (err) - goto out; + goto out_free_ahash; + + /* + * Per RFC 4121 Section 4.2.4, the checksum is performed over the + * data body first, then over the octets in "header". + */ err = xdr_process_buf(body, body_offset, body->len - body_offset, - checksummer, &desc); - if (err) - goto out; - err = crypto_hash_final(&desc, checksumdata); + checksummer, req); if (err) - goto out; + goto out_free_ahash; + if (header) { + struct scatterlist sg[1]; - switch (kctx->gk5e->ctype) { - case CKSUMTYPE_RSA_MD5: - err = kctx->gk5e->encrypt(kctx->seq, NULL, checksumdata, - checksumdata, checksumlen); + sg_init_one(sg, header, hdrlen); + ahash_request_set_crypt(req, sg, NULL, hdrlen); + err = crypto_ahash_update(req); if (err) - goto out; - memcpy(cksumout->data, - checksumdata + checksumlen - kctx->gk5e->cksumlength, - kctx->gk5e->cksumlength); - break; - case CKSUMTYPE_HMAC_SHA1_DES3: - memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); - break; - default: - BUG(); - break; - } - cksumout->len = kctx->gk5e->cksumlength; -out: - crypto_free_hash(desc.tfm); - return err ? GSS_S_FAILURE : 0; -} - -/* - * checksum the plaintext data and hdrlen bytes of the token header - * Per rfc4121, sec. 4.2.4, the checksum is performed over the data - * body then over the first 16 octets of the MIC token - * Inclusion of the header data in the calculation of the - * checksum is optional. - */ -u32 -make_checksum_v2(struct krb5_ctx *kctx, char *header, int hdrlen, - struct xdr_buf *body, int body_offset, u8 *cksumkey, - unsigned int usage, struct xdr_netobj *cksumout) -{ - struct hash_desc desc; - struct scatterlist sg[1]; - int err; - u8 checksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - unsigned int checksumlen; - - if (kctx->gk5e->keyed_cksum == 0) { - dprintk("%s: expected keyed hash for %s\n", - __func__, kctx->gk5e->name); - return GSS_S_FAILURE; - } - if (cksumkey == NULL) { - dprintk("%s: no key supplied for %s\n", - __func__, kctx->gk5e->name); - return GSS_S_FAILURE; + goto out_free_ahash; } - desc.tfm = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(desc.tfm)) - return GSS_S_FAILURE; - checksumlen = crypto_hash_digestsize(desc.tfm); - desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; - - err = crypto_hash_setkey(desc.tfm, cksumkey, kctx->gk5e->keylength); + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); if (err) - goto out; + goto out_free_ahash; - err = crypto_hash_init(&desc); - if (err) - goto out; - err = xdr_process_buf(body, body_offset, body->len - body_offset, - checksummer, &desc); - if (err) - goto out; - if (header != NULL) { - sg_init_one(sg, header, hdrlen); - err = crypto_hash_update(&desc, sg, hdrlen); - if (err) - goto out; - } - err = crypto_hash_final(&desc, checksumdata); - if (err) - goto out; + memcpy(cksumout->data, checksumdata, + min_t(int, cksumout->len, crypto_ahash_digestsize(tfm))); - cksumout->len = kctx->gk5e->cksumlength; - - switch (kctx->gk5e->ctype) { - case CKSUMTYPE_HMAC_SHA1_96_AES128: - case CKSUMTYPE_HMAC_SHA1_96_AES256: - /* note that this truncates the hash */ - memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); - break; - default: - BUG(); - break; - } -out: - crypto_free_hash(desc.tfm); - return err ? GSS_S_FAILURE : 0; +out_free_ahash: + ahash_request_free(req); +out_free_cksum: + kfree_sensitive(checksumdata); + return err ? GSS_S_FAILURE : GSS_S_COMPLETE; } +EXPORT_SYMBOL_IF_KUNIT(gss_krb5_checksum); struct encryptor_desc { u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; - struct blkcipher_desc desc; + struct skcipher_request *req; int pos; struct xdr_buf *outbuf; struct page **pages; @@ -402,6 +238,8 @@ encryptor(struct scatterlist *sg, void *data) { struct encryptor_desc *desc = data; struct xdr_buf *outbuf = desc->outbuf; + struct crypto_sync_skcipher *tfm = + crypto_sync_skcipher_reqtfm(desc->req); struct page *in_page; int thislen = desc->fraglen + sg->length; int fraglen, ret; @@ -414,7 +252,7 @@ encryptor(struct scatterlist *sg, void *data) page_pos = desc->pos - outbuf->head[0].iov_len; if (page_pos >= 0 && page_pos < outbuf->page_len) { /* pages are not in place: */ - int i = (page_pos + outbuf->page_base) >> PAGE_CACHE_SHIFT; + int i = (page_pos + outbuf->page_base) >> PAGE_SHIFT; in_page = desc->pages[i]; } else { in_page = sg_page(sg); @@ -427,7 +265,7 @@ encryptor(struct scatterlist *sg, void *data) desc->fraglen += sg->length; desc->pos += sg->length; - fraglen = thislen & (crypto_blkcipher_blocksize(desc->desc.tfm) - 1); + fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1); thislen -= fraglen; if (thislen == 0) @@ -436,8 +274,10 @@ encryptor(struct scatterlist *sg, void *data) sg_mark_end(&desc->infrags[desc->fragno - 1]); sg_mark_end(&desc->outfrags[desc->fragno - 1]); - ret = crypto_blkcipher_encrypt_iv(&desc->desc, desc->outfrags, - desc->infrags, thislen); + skcipher_request_set_crypt(desc->req, desc->infrags, desc->outfrags, + thislen, desc->iv); + + ret = crypto_skcipher_encrypt(desc->req); if (ret) return ret; @@ -458,35 +298,9 @@ encryptor(struct scatterlist *sg, void *data) return 0; } -int -gss_encrypt_xdr_buf(struct crypto_blkcipher *tfm, struct xdr_buf *buf, - int offset, struct page **pages) -{ - int ret; - struct encryptor_desc desc; - - BUG_ON((buf->len - offset) % crypto_blkcipher_blocksize(tfm) != 0); - - memset(desc.iv, 0, sizeof(desc.iv)); - desc.desc.tfm = tfm; - desc.desc.info = desc.iv; - desc.desc.flags = 0; - desc.pos = offset; - desc.outbuf = buf; - desc.pages = pages; - desc.fragno = 0; - desc.fraglen = 0; - - sg_init_table(desc.infrags, 4); - sg_init_table(desc.outfrags, 4); - - ret = xdr_process_buf(buf, offset, buf->len - offset, encryptor, &desc); - return ret; -} - struct decryptor_desc { u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; - struct blkcipher_desc desc; + struct skcipher_request *req; struct scatterlist frags[4]; int fragno; int fraglen; @@ -497,6 +311,8 @@ decryptor(struct scatterlist *sg, void *data) { struct decryptor_desc *desc = data; int thislen = desc->fraglen + sg->length; + struct crypto_sync_skcipher *tfm = + crypto_sync_skcipher_reqtfm(desc->req); int fraglen, ret; /* Worst case is 4 fragments: head, end of page 1, start @@ -507,7 +323,7 @@ decryptor(struct scatterlist *sg, void *data) desc->fragno++; desc->fraglen += sg->length; - fraglen = thislen & (crypto_blkcipher_blocksize(desc->desc.tfm) - 1); + fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1); thislen -= fraglen; if (thislen == 0) @@ -515,8 +331,10 @@ decryptor(struct scatterlist *sg, void *data) sg_mark_end(&desc->frags[desc->fragno - 1]); - ret = crypto_blkcipher_decrypt_iv(&desc->desc, desc->frags, - desc->frags, thislen); + skcipher_request_set_crypt(desc->req, desc->frags, desc->frags, + thislen, desc->iv); + + ret = crypto_skcipher_decrypt(desc->req); if (ret) return ret; @@ -534,27 +352,6 @@ decryptor(struct scatterlist *sg, void *data) return 0; } -int -gss_decrypt_xdr_buf(struct crypto_blkcipher *tfm, struct xdr_buf *buf, - int offset) -{ - struct decryptor_desc desc; - - /* XXXJBF: */ - BUG_ON((buf->len - offset) % crypto_blkcipher_blocksize(tfm) != 0); - - memset(desc.iv, 0, sizeof(desc.iv)); - desc.desc.tfm = tfm; - desc.desc.info = desc.iv; - desc.desc.flags = 0; - desc.fragno = 0; - desc.fraglen = 0; - - sg_init_table(desc.frags, 4); - - return xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc); -} - /* * This function makes the assumption that it was ultimately called * from gss_wrap(). @@ -580,7 +377,6 @@ xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen) if (shiftlen == 0) return 0; - BUILD_BUG_ON(GSS_KRB5_MAX_SLACK_NEEDED > RPC_MAX_AUTH_SIZE); BUG_ON(shiftlen > RPC_MAX_AUTH_SIZE); p = buf->head[0].iov_base + base; @@ -594,20 +390,23 @@ xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen) } static u32 -gss_krb5_cts_crypt(struct crypto_blkcipher *cipher, struct xdr_buf *buf, +gss_krb5_cts_crypt(struct crypto_sync_skcipher *cipher, struct xdr_buf *buf, u32 offset, u8 *iv, struct page **pages, int encrypt) { u32 ret; struct scatterlist sg[1]; - struct blkcipher_desc desc = { .tfm = cipher, .info = iv }; - u8 data[GSS_KRB5_MAX_BLOCKSIZE * 2]; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, cipher); + u8 *data; struct page **save_pages; u32 len = buf->len - offset; - if (len > ARRAY_SIZE(data)) { + if (len > GSS_KRB5_MAX_BLOCKSIZE * 2) { WARN_ON(0); return -ENOMEM; } + data = kmalloc(GSS_KRB5_MAX_BLOCKSIZE * 2, GFP_KERNEL); + if (!data) + return -ENOMEM; /* * For encryption, we want to read from the cleartext @@ -625,54 +424,188 @@ gss_krb5_cts_crypt(struct crypto_blkcipher *cipher, struct xdr_buf *buf, sg_init_one(sg, data, len); + skcipher_request_set_sync_tfm(req, cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, len, iv); + if (encrypt) - ret = crypto_blkcipher_encrypt_iv(&desc, sg, sg, len); + ret = crypto_skcipher_encrypt(req); else - ret = crypto_blkcipher_decrypt_iv(&desc, sg, sg, len); + ret = crypto_skcipher_decrypt(req); + + skcipher_request_zero(req); if (ret) goto out; ret = write_bytes_to_xdr_buf(buf, offset, data, len); +#if IS_ENABLED(CONFIG_KUNIT) + /* + * CBC-CTS does not define an output IV but RFC 3962 defines it as the + * penultimate block of ciphertext, so copy that into the IV buffer + * before returning. + */ + if (encrypt) + memcpy(iv, data, crypto_sync_skcipher_ivsize(cipher)); +#endif + out: + kfree(data); return ret; } +/** + * krb5_cbc_cts_encrypt - encrypt in CBC mode with CTS + * @cts_tfm: CBC cipher with CTS + * @cbc_tfm: base CBC cipher + * @offset: starting byte offset for plaintext + * @buf: OUT: output buffer + * @pages: plaintext + * @iv: output CBC initialization vector, or NULL + * @ivsize: size of @iv, in octets + * + * To provide confidentiality, encrypt using cipher block chaining + * with ciphertext stealing. Message integrity is handled separately. + * + * Return values: + * %0: encryption successful + * negative errno: encryption could not be completed + */ +VISIBLE_IF_KUNIT +int krb5_cbc_cts_encrypt(struct crypto_sync_skcipher *cts_tfm, + struct crypto_sync_skcipher *cbc_tfm, + u32 offset, struct xdr_buf *buf, struct page **pages, + u8 *iv, unsigned int ivsize) +{ + u32 blocksize, nbytes, nblocks, cbcbytes; + struct encryptor_desc desc; + int err; + + blocksize = crypto_sync_skcipher_blocksize(cts_tfm); + nbytes = buf->len - offset; + nblocks = (nbytes + blocksize - 1) / blocksize; + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + /* Handle block-sized chunks of plaintext with CBC. */ + if (cbcbytes) { + SYNC_SKCIPHER_REQUEST_ON_STACK(req, cbc_tfm); + + desc.pos = offset; + desc.fragno = 0; + desc.fraglen = 0; + desc.pages = pages; + desc.outbuf = buf; + desc.req = req; + + skcipher_request_set_sync_tfm(req, cbc_tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + err = xdr_process_buf(buf, offset, cbcbytes, encryptor, &desc); + skcipher_request_zero(req); + if (err) + return err; + } + + /* Remaining plaintext is handled with CBC-CTS. */ + err = gss_krb5_cts_crypt(cts_tfm, buf, offset + cbcbytes, + desc.iv, pages, 1); + if (err) + return err; + + if (unlikely(iv)) + memcpy(iv, desc.iv, ivsize); + return 0; +} +EXPORT_SYMBOL_IF_KUNIT(krb5_cbc_cts_encrypt); + +/** + * krb5_cbc_cts_decrypt - decrypt in CBC mode with CTS + * @cts_tfm: CBC cipher with CTS + * @cbc_tfm: base CBC cipher + * @offset: starting byte offset for plaintext + * @buf: OUT: output buffer + * + * Return values: + * %0: decryption successful + * negative errno: decryption could not be completed + */ +VISIBLE_IF_KUNIT +int krb5_cbc_cts_decrypt(struct crypto_sync_skcipher *cts_tfm, + struct crypto_sync_skcipher *cbc_tfm, + u32 offset, struct xdr_buf *buf) +{ + u32 blocksize, nblocks, cbcbytes; + struct decryptor_desc desc; + int err; + + blocksize = crypto_sync_skcipher_blocksize(cts_tfm); + nblocks = (buf->len + blocksize - 1) / blocksize; + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + /* Handle block-sized chunks of plaintext with CBC. */ + if (cbcbytes) { + SYNC_SKCIPHER_REQUEST_ON_STACK(req, cbc_tfm); + + desc.fragno = 0; + desc.fraglen = 0; + desc.req = req; + + skcipher_request_set_sync_tfm(req, cbc_tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + + sg_init_table(desc.frags, 4); + + err = xdr_process_buf(buf, 0, cbcbytes, decryptor, &desc); + skcipher_request_zero(req); + if (err) + return err; + } + + /* Remaining plaintext is handled with CBC-CTS. */ + return gss_krb5_cts_crypt(cts_tfm, buf, cbcbytes, desc.iv, NULL, 0); +} +EXPORT_SYMBOL_IF_KUNIT(krb5_cbc_cts_decrypt); + u32 gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, - struct xdr_buf *buf, int ec, struct page **pages) + struct xdr_buf *buf, struct page **pages) { u32 err; struct xdr_netobj hmac; - u8 *cksumkey; u8 *ecptr; - struct crypto_blkcipher *cipher, *aux_cipher; - int blocksize; + struct crypto_sync_skcipher *cipher, *aux_cipher; + struct crypto_ahash *ahash; struct page **save_pages; - int nblocks, nbytes; - struct encryptor_desc desc; - u32 cbcbytes; - unsigned int usage; + unsigned int conflen; if (kctx->initiate) { cipher = kctx->initiator_enc; aux_cipher = kctx->initiator_enc_aux; - cksumkey = kctx->initiator_integ; - usage = KG_USAGE_INITIATOR_SEAL; + ahash = kctx->initiator_integ; } else { cipher = kctx->acceptor_enc; aux_cipher = kctx->acceptor_enc_aux; - cksumkey = kctx->acceptor_integ; - usage = KG_USAGE_ACCEPTOR_SEAL; + ahash = kctx->acceptor_integ; } - blocksize = crypto_blkcipher_blocksize(cipher); + conflen = crypto_sync_skcipher_blocksize(cipher); /* hide the gss token header and insert the confounder */ offset += GSS_KRB5_TOK_HDR_LEN; - if (xdr_extend_head(buf, offset, kctx->gk5e->conflen)) + if (xdr_extend_head(buf, offset, conflen)) return GSS_S_FAILURE; - gss_krb5_make_confounder(buf->head[0].iov_base + offset, kctx->gk5e->conflen); + krb5_make_confounder(buf->head[0].iov_base + offset, conflen); offset -= GSS_KRB5_TOK_HDR_LEN; if (buf->tail[0].iov_base != NULL) { @@ -684,18 +617,12 @@ gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, ecptr = buf->tail[0].iov_base; } - memset(ecptr, 'X', ec); - buf->tail[0].iov_len += ec; - buf->len += ec; - /* copy plaintext gss token header after filler (if any) */ - memcpy(ecptr + ec, buf->head[0].iov_base + offset, - GSS_KRB5_TOK_HDR_LEN); + memcpy(ecptr, buf->head[0].iov_base + offset, GSS_KRB5_TOK_HDR_LEN); buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN; buf->len += GSS_KRB5_TOK_HDR_LEN; - /* Do the HMAC */ - hmac.len = GSS_KRB5_MAX_CKSUM_LEN; + hmac.len = kctx->gk5e->cksumlength; hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len; /* @@ -708,141 +635,73 @@ gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, save_pages = buf->pages; buf->pages = pages; - err = make_checksum_v2(kctx, NULL, 0, buf, - offset + GSS_KRB5_TOK_HDR_LEN, - cksumkey, usage, &hmac); + err = gss_krb5_checksum(ahash, NULL, 0, buf, + offset + GSS_KRB5_TOK_HDR_LEN, &hmac); buf->pages = save_pages; if (err) return GSS_S_FAILURE; - nbytes = buf->len - offset - GSS_KRB5_TOK_HDR_LEN; - nblocks = (nbytes + blocksize - 1) / blocksize; - cbcbytes = 0; - if (nblocks > 2) - cbcbytes = (nblocks - 2) * blocksize; - - memset(desc.iv, 0, sizeof(desc.iv)); - - if (cbcbytes) { - desc.pos = offset + GSS_KRB5_TOK_HDR_LEN; - desc.fragno = 0; - desc.fraglen = 0; - desc.pages = pages; - desc.outbuf = buf; - desc.desc.info = desc.iv; - desc.desc.flags = 0; - desc.desc.tfm = aux_cipher; - - sg_init_table(desc.infrags, 4); - sg_init_table(desc.outfrags, 4); - - err = xdr_process_buf(buf, offset + GSS_KRB5_TOK_HDR_LEN, - cbcbytes, encryptor, &desc); - if (err) - goto out_err; - } - - /* Make sure IV carries forward from any CBC results. */ - err = gss_krb5_cts_crypt(cipher, buf, - offset + GSS_KRB5_TOK_HDR_LEN + cbcbytes, - desc.iv, pages, 1); - if (err) { - err = GSS_S_FAILURE; - goto out_err; - } + err = krb5_cbc_cts_encrypt(cipher, aux_cipher, + offset + GSS_KRB5_TOK_HDR_LEN, + buf, pages, NULL, 0); + if (err) + return GSS_S_FAILURE; /* Now update buf to account for HMAC */ buf->tail[0].iov_len += kctx->gk5e->cksumlength; buf->len += kctx->gk5e->cksumlength; -out_err: - if (err) - err = GSS_S_FAILURE; - return err; + return GSS_S_COMPLETE; } u32 -gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, - u32 *headskip, u32 *tailskip) +gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len, + struct xdr_buf *buf, u32 *headskip, u32 *tailskip) { - struct xdr_buf subbuf; - u32 ret = 0; - u8 *cksum_key; - struct crypto_blkcipher *cipher, *aux_cipher; + struct crypto_sync_skcipher *cipher, *aux_cipher; + struct crypto_ahash *ahash; struct xdr_netobj our_hmac_obj; u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN]; u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN]; - int nblocks, blocksize, cbcbytes; - struct decryptor_desc desc; - unsigned int usage; + struct xdr_buf subbuf; + u32 ret = 0; if (kctx->initiate) { cipher = kctx->acceptor_enc; aux_cipher = kctx->acceptor_enc_aux; - cksum_key = kctx->acceptor_integ; - usage = KG_USAGE_ACCEPTOR_SEAL; + ahash = kctx->acceptor_integ; } else { cipher = kctx->initiator_enc; aux_cipher = kctx->initiator_enc_aux; - cksum_key = kctx->initiator_integ; - usage = KG_USAGE_INITIATOR_SEAL; + ahash = kctx->initiator_integ; } - blocksize = crypto_blkcipher_blocksize(cipher); - /* create a segment skipping the header and leaving out the checksum */ xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN, - (buf->len - offset - GSS_KRB5_TOK_HDR_LEN - + (len - offset - GSS_KRB5_TOK_HDR_LEN - kctx->gk5e->cksumlength)); - nblocks = (subbuf.len + blocksize - 1) / blocksize; - - cbcbytes = 0; - if (nblocks > 2) - cbcbytes = (nblocks - 2) * blocksize; - - memset(desc.iv, 0, sizeof(desc.iv)); - - if (cbcbytes) { - desc.fragno = 0; - desc.fraglen = 0; - desc.desc.info = desc.iv; - desc.desc.flags = 0; - desc.desc.tfm = aux_cipher; - - sg_init_table(desc.frags, 4); - - ret = xdr_process_buf(&subbuf, 0, cbcbytes, decryptor, &desc); - if (ret) - goto out_err; - } - - /* Make sure IV carries forward from any CBC results. */ - ret = gss_krb5_cts_crypt(cipher, &subbuf, cbcbytes, desc.iv, NULL, 0); + ret = krb5_cbc_cts_decrypt(cipher, aux_cipher, 0, &subbuf); if (ret) goto out_err; - - /* Calculate our hmac over the plaintext data */ - our_hmac_obj.len = sizeof(our_hmac); + our_hmac_obj.len = kctx->gk5e->cksumlength; our_hmac_obj.data = our_hmac; - - ret = make_checksum_v2(kctx, NULL, 0, &subbuf, 0, - cksum_key, usage, &our_hmac_obj); + ret = gss_krb5_checksum(ahash, NULL, 0, &subbuf, 0, &our_hmac_obj); if (ret) goto out_err; /* Get the packet's hmac value */ - ret = read_bytes_from_xdr_buf(buf, buf->len - kctx->gk5e->cksumlength, + ret = read_bytes_from_xdr_buf(buf, len - kctx->gk5e->cksumlength, pkt_hmac, kctx->gk5e->cksumlength); if (ret) goto out_err; - if (memcmp(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) { + if (crypto_memneq(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) { ret = GSS_S_BAD_SIG; goto out_err; } - *headskip = kctx->gk5e->conflen; + *headskip = crypto_sync_skcipher_blocksize(cipher); *tailskip = kctx->gk5e->cksumlength; out_err: if (ret && ret != GSS_S_BAD_SIG) @@ -850,144 +709,247 @@ out_err: return ret; } -/* - * Compute Kseq given the initial session key and the checksum. - * Set the key of the given cipher. +/** + * krb5_etm_checksum - Compute a MAC for a GSS Wrap token + * @cipher: an initialized cipher transform + * @tfm: an initialized hash transform + * @body: xdr_buf containing an RPC message (body.len is the message length) + * @body_offset: byte offset into @body to start checksumming + * @cksumout: OUT: a buffer to be filled in with the computed HMAC + * + * Usually expressed as H = HMAC(K, IV | ciphertext)[1..h] . + * + * Caller provides the truncation length of the output token (h) in + * cksumout.len. + * + * Return values: + * %GSS_S_COMPLETE: Digest computed, @cksumout filled in + * %GSS_S_FAILURE: Call failed */ -int -krb5_rc4_setup_seq_key(struct krb5_ctx *kctx, struct crypto_blkcipher *cipher, - unsigned char *cksum) +VISIBLE_IF_KUNIT +u32 krb5_etm_checksum(struct crypto_sync_skcipher *cipher, + struct crypto_ahash *tfm, const struct xdr_buf *body, + int body_offset, struct xdr_netobj *cksumout) { - struct crypto_hash *hmac; - struct hash_desc desc; + unsigned int ivsize = crypto_sync_skcipher_ivsize(cipher); + struct ahash_request *req; struct scatterlist sg[1]; - u8 Kseq[GSS_KRB5_MAX_KEYLEN]; - u32 zeroconstant = 0; - int err; - - dprintk("%s: entered\n", __func__); + u8 *iv, *checksumdata; + int err = -ENOMEM; - hmac = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(hmac)) { - dprintk("%s: error %ld, allocating hash '%s'\n", - __func__, PTR_ERR(hmac), kctx->gk5e->cksum_name); - return PTR_ERR(hmac); - } - - desc.tfm = hmac; - desc.flags = 0; + checksumdata = kmalloc(crypto_ahash_digestsize(tfm), GFP_KERNEL); + if (!checksumdata) + return GSS_S_FAILURE; + /* For RPCSEC, the "initial cipher state" is always all zeroes. */ + iv = kzalloc(ivsize, GFP_KERNEL); + if (!iv) + goto out_free_mem; + + req = ahash_request_alloc(tfm, GFP_KERNEL); + if (!req) + goto out_free_mem; + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + err = crypto_ahash_init(req); + if (err) + goto out_free_ahash; - err = crypto_hash_init(&desc); + sg_init_one(sg, iv, ivsize); + ahash_request_set_crypt(req, sg, NULL, ivsize); + err = crypto_ahash_update(req); if (err) - goto out_err; + goto out_free_ahash; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, req); + if (err) + goto out_free_ahash; - /* Compute intermediate Kseq from session key */ - err = crypto_hash_setkey(hmac, kctx->Ksess, kctx->gk5e->keylength); + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); if (err) - goto out_err; + goto out_free_ahash; + memcpy(cksumout->data, checksumdata, cksumout->len); + +out_free_ahash: + ahash_request_free(req); +out_free_mem: + kfree(iv); + kfree_sensitive(checksumdata); + return err ? GSS_S_FAILURE : GSS_S_COMPLETE; +} +EXPORT_SYMBOL_IF_KUNIT(krb5_etm_checksum); + +/** + * krb5_etm_encrypt - Encrypt using the RFC 8009 rules + * @kctx: Kerberos context + * @offset: starting offset of the payload, in bytes + * @buf: OUT: send buffer to contain the encrypted payload + * @pages: plaintext payload + * + * The main difference with aes_encrypt is that "The HMAC is + * calculated over the cipher state concatenated with the AES + * output, instead of being calculated over the confounder and + * plaintext. This allows the message receiver to verify the + * integrity of the message before decrypting the message." + * + * RFC 8009 Section 5: + * + * encryption function: as follows, where E() is AES encryption in + * CBC-CS3 mode, and h is the size of truncated HMAC (128 bits or + * 192 bits as described above). + * + * N = random value of length 128 bits (the AES block size) + * IV = cipher state + * C = E(Ke, N | plaintext, IV) + * H = HMAC(Ki, IV | C) + * ciphertext = C | H[1..h] + * + * This encryption formula provides AEAD EtM with key separation. + * + * Return values: + * %GSS_S_COMPLETE: Encryption successful + * %GSS_S_FAILURE: Encryption failed + */ +u32 +krb5_etm_encrypt(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages) +{ + struct crypto_sync_skcipher *cipher, *aux_cipher; + struct crypto_ahash *ahash; + struct xdr_netobj hmac; + unsigned int conflen; + u8 *ecptr; + u32 err; - sg_init_table(sg, 1); - sg_set_buf(sg, &zeroconstant, 4); + if (kctx->initiate) { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + ahash = kctx->initiator_integ; + } else { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + ahash = kctx->acceptor_integ; + } + conflen = crypto_sync_skcipher_blocksize(cipher); - err = crypto_hash_digest(&desc, sg, 4, Kseq); - if (err) - goto out_err; + offset += GSS_KRB5_TOK_HDR_LEN; + if (xdr_extend_head(buf, offset, conflen)) + return GSS_S_FAILURE; + krb5_make_confounder(buf->head[0].iov_base + offset, conflen); + offset -= GSS_KRB5_TOK_HDR_LEN; - /* Compute final Kseq from the checksum and intermediate Kseq */ - err = crypto_hash_setkey(hmac, Kseq, kctx->gk5e->keylength); - if (err) - goto out_err; + if (buf->tail[0].iov_base) { + ecptr = buf->tail[0].iov_base + buf->tail[0].iov_len; + } else { + buf->tail[0].iov_base = buf->head[0].iov_base + + buf->head[0].iov_len; + buf->tail[0].iov_len = 0; + ecptr = buf->tail[0].iov_base; + } - sg_set_buf(sg, cksum, 8); + memcpy(ecptr, buf->head[0].iov_base + offset, GSS_KRB5_TOK_HDR_LEN); + buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN; + buf->len += GSS_KRB5_TOK_HDR_LEN; - err = crypto_hash_digest(&desc, sg, 8, Kseq); + err = krb5_cbc_cts_encrypt(cipher, aux_cipher, + offset + GSS_KRB5_TOK_HDR_LEN, + buf, pages, NULL, 0); if (err) - goto out_err; + return GSS_S_FAILURE; - err = crypto_blkcipher_setkey(cipher, Kseq, kctx->gk5e->keylength); + hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len; + hmac.len = kctx->gk5e->cksumlength; + err = krb5_etm_checksum(cipher, ahash, + buf, offset + GSS_KRB5_TOK_HDR_LEN, &hmac); if (err) goto out_err; + buf->tail[0].iov_len += kctx->gk5e->cksumlength; + buf->len += kctx->gk5e->cksumlength; - err = 0; + return GSS_S_COMPLETE; out_err: - crypto_free_hash(hmac); - dprintk("%s: returning %d\n", __func__, err); - return err; + return GSS_S_FAILURE; } -/* - * Compute Kcrypt given the initial session key and the plaintext seqnum. - * Set the key of cipher kctx->enc. +/** + * krb5_etm_decrypt - Decrypt using the RFC 8009 rules + * @kctx: Kerberos context + * @offset: starting offset of the ciphertext, in bytes + * @len: size of ciphertext to unwrap + * @buf: ciphertext to unwrap + * @headskip: OUT: the enctype's confounder length, in octets + * @tailskip: OUT: the enctype's HMAC length, in octets + * + * RFC 8009 Section 5: + * + * decryption function: as follows, where D() is AES decryption in + * CBC-CS3 mode, and h is the size of truncated HMAC. + * + * (C, H) = ciphertext + * (Note: H is the last h bits of the ciphertext.) + * IV = cipher state + * if H != HMAC(Ki, IV | C)[1..h] + * stop, report error + * (N, P) = D(Ke, C, IV) + * + * Return values: + * %GSS_S_COMPLETE: Decryption successful + * %GSS_S_BAD_SIG: computed HMAC != received HMAC + * %GSS_S_FAILURE: Decryption failed */ -int -krb5_rc4_setup_enc_key(struct krb5_ctx *kctx, struct crypto_blkcipher *cipher, - s32 seqnum) +u32 +krb5_etm_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len, + struct xdr_buf *buf, u32 *headskip, u32 *tailskip) { - struct crypto_hash *hmac; - struct hash_desc desc; - struct scatterlist sg[1]; - u8 Kcrypt[GSS_KRB5_MAX_KEYLEN]; - u8 zeroconstant[4] = {0}; - u8 seqnumarray[4]; - int err, i; - - dprintk("%s: entered, seqnum %u\n", __func__, seqnum); - - hmac = crypto_alloc_hash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(hmac)) { - dprintk("%s: error %ld, allocating hash '%s'\n", - __func__, PTR_ERR(hmac), kctx->gk5e->cksum_name); - return PTR_ERR(hmac); - } - - desc.tfm = hmac; - desc.flags = 0; - - err = crypto_hash_init(&desc); - if (err) - goto out_err; - - /* Compute intermediate Kcrypt from session key */ - for (i = 0; i < kctx->gk5e->keylength; i++) - Kcrypt[i] = kctx->Ksess[i] ^ 0xf0; + struct crypto_sync_skcipher *cipher, *aux_cipher; + u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj our_hmac_obj; + struct crypto_ahash *ahash; + struct xdr_buf subbuf; + u32 ret = 0; - err = crypto_hash_setkey(hmac, Kcrypt, kctx->gk5e->keylength); - if (err) - goto out_err; + if (kctx->initiate) { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + ahash = kctx->acceptor_integ; + } else { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + ahash = kctx->initiator_integ; + } - sg_init_table(sg, 1); - sg_set_buf(sg, zeroconstant, 4); + /* Extract the ciphertext into @subbuf. */ + xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN, + (len - offset - GSS_KRB5_TOK_HDR_LEN - + kctx->gk5e->cksumlength)); - err = crypto_hash_digest(&desc, sg, 4, Kcrypt); - if (err) + our_hmac_obj.data = our_hmac; + our_hmac_obj.len = kctx->gk5e->cksumlength; + ret = krb5_etm_checksum(cipher, ahash, &subbuf, 0, &our_hmac_obj); + if (ret) goto out_err; - - /* Compute final Kcrypt from the seqnum and intermediate Kcrypt */ - err = crypto_hash_setkey(hmac, Kcrypt, kctx->gk5e->keylength); - if (err) + ret = read_bytes_from_xdr_buf(buf, len - kctx->gk5e->cksumlength, + pkt_hmac, kctx->gk5e->cksumlength); + if (ret) goto out_err; - - seqnumarray[0] = (unsigned char) ((seqnum >> 24) & 0xff); - seqnumarray[1] = (unsigned char) ((seqnum >> 16) & 0xff); - seqnumarray[2] = (unsigned char) ((seqnum >> 8) & 0xff); - seqnumarray[3] = (unsigned char) ((seqnum >> 0) & 0xff); - - sg_set_buf(sg, seqnumarray, 4); - - err = crypto_hash_digest(&desc, sg, 4, Kcrypt); - if (err) + if (crypto_memneq(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) { + ret = GSS_S_BAD_SIG; goto out_err; + } - err = crypto_blkcipher_setkey(cipher, Kcrypt, kctx->gk5e->keylength); - if (err) + ret = krb5_cbc_cts_decrypt(cipher, aux_cipher, 0, &subbuf); + if (ret) { + ret = GSS_S_FAILURE; goto out_err; + } - err = 0; + *headskip = crypto_sync_skcipher_blocksize(cipher); + *tailskip = kctx->gk5e->cksumlength; + return GSS_S_COMPLETE; out_err: - crypto_free_hash(hmac); - dprintk("%s: returning %d\n", __func__, err); - return err; + if (ret != GSS_S_BAD_SIG) + ret = GSS_S_FAILURE; + return ret; } - diff --git a/net/sunrpc/auth_gss/gss_krb5_internal.h b/net/sunrpc/auth_gss/gss_krb5_internal.h new file mode 100644 index 000000000000..8769e9e705bf --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_internal.h @@ -0,0 +1,195 @@ +/* SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause */ +/* + * SunRPC GSS Kerberos 5 mechanism internal definitions + * + * Copyright (c) 2022 Oracle and/or its affiliates. + */ + +#ifndef _NET_SUNRPC_AUTH_GSS_KRB5_INTERNAL_H +#define _NET_SUNRPC_AUTH_GSS_KRB5_INTERNAL_H + +/* + * The RFCs often specify payload lengths in bits. This helper + * converts a specified bit-length to the number of octets/bytes. + */ +#define BITS2OCTETS(x) ((x) / 8) + +struct krb5_ctx; + +struct gss_krb5_enctype { + const u32 etype; /* encryption (key) type */ + const u32 ctype; /* checksum type */ + const char *name; /* "friendly" name */ + const char *encrypt_name; /* crypto encrypt name */ + const char *aux_cipher; /* aux encrypt cipher name */ + const char *cksum_name; /* crypto checksum name */ + const u16 signalg; /* signing algorithm */ + const u16 sealalg; /* sealing algorithm */ + const u32 cksumlength; /* checksum length */ + const u32 keyed_cksum; /* is it a keyed cksum? */ + const u32 keybytes; /* raw key len, in bytes */ + const u32 keylength; /* protocol key length, in octets */ + const u32 Kc_length; /* checksum subkey length, in octets */ + const u32 Ke_length; /* encryption subkey length, in octets */ + const u32 Ki_length; /* integrity subkey length, in octets */ + + int (*derive_key)(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *in, + struct xdr_netobj *out, + const struct xdr_netobj *label, + gfp_t gfp_mask); + u32 (*encrypt)(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages); + u32 (*decrypt)(struct krb5_ctx *kctx, u32 offset, u32 len, + struct xdr_buf *buf, u32 *headskip, u32 *tailskip); + u32 (*get_mic)(struct krb5_ctx *kctx, struct xdr_buf *text, + struct xdr_netobj *token); + u32 (*verify_mic)(struct krb5_ctx *kctx, struct xdr_buf *message_buffer, + struct xdr_netobj *read_token); + u32 (*wrap)(struct krb5_ctx *kctx, int offset, + struct xdr_buf *buf, struct page **pages); + u32 (*unwrap)(struct krb5_ctx *kctx, int offset, int len, + struct xdr_buf *buf, unsigned int *slack, + unsigned int *align); +}; + +/* krb5_ctx flags definitions */ +#define KRB5_CTX_FLAG_INITIATOR 0x00000001 +#define KRB5_CTX_FLAG_ACCEPTOR_SUBKEY 0x00000004 + +struct krb5_ctx { + int initiate; /* 1 = initiating, 0 = accepting */ + u32 enctype; + u32 flags; + const struct gss_krb5_enctype *gk5e; /* enctype-specific info */ + struct crypto_sync_skcipher *enc; + struct crypto_sync_skcipher *seq; + struct crypto_sync_skcipher *acceptor_enc; + struct crypto_sync_skcipher *initiator_enc; + struct crypto_sync_skcipher *acceptor_enc_aux; + struct crypto_sync_skcipher *initiator_enc_aux; + struct crypto_ahash *acceptor_sign; + struct crypto_ahash *initiator_sign; + struct crypto_ahash *initiator_integ; + struct crypto_ahash *acceptor_integ; + u8 Ksess[GSS_KRB5_MAX_KEYLEN]; /* session key */ + u8 cksum[GSS_KRB5_MAX_KEYLEN]; + atomic_t seq_send; + atomic64_t seq_send64; + time64_t endtime; + struct xdr_netobj mech_used; +}; + +/* + * GSS Kerberos 5 mechanism Per-Message calls. + */ + +u32 gss_krb5_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token); + +u32 gss_krb5_verify_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *message_buffer, + struct xdr_netobj *read_token); + +u32 gss_krb5_wrap_v2(struct krb5_ctx *kctx, int offset, + struct xdr_buf *buf, struct page **pages); + +u32 gss_krb5_unwrap_v2(struct krb5_ctx *kctx, int offset, int len, + struct xdr_buf *buf, unsigned int *slack, + unsigned int *align); + +/* + * Implementation internal functions + */ + +/* Key Derivation Functions */ + +int krb5_derive_key_v2(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *label, + gfp_t gfp_mask); + +int krb5_kdf_hmac_sha2(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *in_constant, + gfp_t gfp_mask); + +int krb5_kdf_feedback_cmac(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *in_constant, + gfp_t gfp_mask); + +/** + * krb5_derive_key - Derive a subkey from a protocol key + * @kctx: Kerberos 5 context + * @inkey: base protocol key + * @outkey: OUT: derived key + * @usage: key usage value + * @seed: key usage seed (one octet) + * @gfp_mask: memory allocation control flags + * + * Caller sets @outkey->len to the desired length of the derived key. + * + * On success, returns 0 and fills in @outkey. A negative errno value + * is returned on failure. + */ +static inline int krb5_derive_key(struct krb5_ctx *kctx, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + u32 usage, u8 seed, gfp_t gfp_mask) +{ + const struct gss_krb5_enctype *gk5e = kctx->gk5e; + u8 label_data[GSS_KRB5_K5CLENGTH]; + struct xdr_netobj label = { + .len = sizeof(label_data), + .data = label_data, + }; + __be32 *p = (__be32 *)label_data; + + *p = cpu_to_be32(usage); + label_data[4] = seed; + return gk5e->derive_key(gk5e, inkey, outkey, &label, gfp_mask); +} + +void krb5_make_confounder(u8 *p, int conflen); + +u32 gss_krb5_checksum(struct crypto_ahash *tfm, char *header, int hdrlen, + const struct xdr_buf *body, int body_offset, + struct xdr_netobj *cksumout); + +u32 krb5_encrypt(struct crypto_sync_skcipher *key, void *iv, void *in, + void *out, int length); + +int xdr_extend_head(struct xdr_buf *buf, unsigned int base, + unsigned int shiftlen); + +u32 gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages); + +u32 gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len, + struct xdr_buf *buf, u32 *plainoffset, u32 *plainlen); + +u32 krb5_etm_encrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, + struct page **pages); + +u32 krb5_etm_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len, + struct xdr_buf *buf, u32 *headskip, u32 *tailskip); + +#if IS_ENABLED(CONFIG_KUNIT) +void krb5_nfold(u32 inbits, const u8 *in, u32 outbits, u8 *out); +const struct gss_krb5_enctype *gss_krb5_lookup_enctype(u32 etype); +int krb5_cbc_cts_encrypt(struct crypto_sync_skcipher *cts_tfm, + struct crypto_sync_skcipher *cbc_tfm, u32 offset, + struct xdr_buf *buf, struct page **pages, + u8 *iv, unsigned int ivsize); +int krb5_cbc_cts_decrypt(struct crypto_sync_skcipher *cts_tfm, + struct crypto_sync_skcipher *cbc_tfm, + u32 offset, struct xdr_buf *buf); +u32 krb5_etm_checksum(struct crypto_sync_skcipher *cipher, + struct crypto_ahash *tfm, const struct xdr_buf *body, + int body_offset, struct xdr_netobj *cksumout); +#endif + +#endif /* _NET_SUNRPC_AUTH_GSS_KRB5_INTERNAL_H */ diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c index 76e42e6be755..4eb19c3a54c7 100644 --- a/net/sunrpc/auth_gss/gss_krb5_keys.c +++ b/net/sunrpc/auth_gss/gss_krb5_keys.c @@ -54,25 +54,35 @@ * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. */ +#include <crypto/skcipher.h> #include <linux/err.h> #include <linux/types.h> -#include <linux/crypto.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/sunrpc/xdr.h> +#include <linux/lcm.h> +#include <crypto/hash.h> +#include <kunit/visibility.h> -#ifdef RPC_DEBUG +#include "gss_krb5_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif -/* +/** + * krb5_nfold - n-fold function + * @inbits: number of bits in @in + * @in: buffer containing input to fold + * @outbits: number of bits in the output buffer + * @out: buffer to hold the result + * * This is the n-fold function as described in rfc3961, sec 5.1 * Taken from MIT Kerberos and modified. */ - -static void krb5_nfold(u32 inbits, const u8 *in, - u32 outbits, u8 *out) +VISIBLE_IF_KUNIT +void krb5_nfold(u32 inbits, const u8 *in, u32 outbits, u8 *out) { - int a, b, c, lcm; + unsigned long ulcm; int byte, i, msbit; /* the code below is more readable if I make these bytes @@ -82,17 +92,7 @@ static void krb5_nfold(u32 inbits, const u8 *in, outbits >>= 3; /* first compute lcm(n,k) */ - - a = outbits; - b = inbits; - - while (b != 0) { - c = b; - b = a%b; - a = c; - } - - lcm = outbits*inbits/a; + ulcm = lcm(inbits, outbits); /* now do the real work */ @@ -101,7 +101,7 @@ static void krb5_nfold(u32 inbits, const u8 *in, /* this will end up cycling through k lcm(k,n)/k times, which is correct */ - for (i = lcm-1; i >= 0; i--) { + for (i = ulcm-1; i >= 0; i--) { /* compute the msbit in k which gets added into this byte */ msbit = ( /* first, start with the msbit in the first, @@ -141,41 +141,36 @@ static void krb5_nfold(u32 inbits, const u8 *in, } } } +EXPORT_SYMBOL_IF_KUNIT(krb5_nfold); /* * This is the DK (derive_key) function as described in rfc3961, sec 5.1 * Taken from MIT Kerberos and modified. */ - -u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, - const struct xdr_netobj *inkey, - struct xdr_netobj *outkey, - const struct xdr_netobj *in_constant, - gfp_t gfp_mask) +static int krb5_DK(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, u8 *rawkey, + const struct xdr_netobj *in_constant, gfp_t gfp_mask) { size_t blocksize, keybytes, keylength, n; - unsigned char *inblockdata, *outblockdata, *rawkey; + unsigned char *inblockdata, *outblockdata; struct xdr_netobj inblock, outblock; - struct crypto_blkcipher *cipher; - u32 ret = EINVAL; + struct crypto_sync_skcipher *cipher; + int ret = -EINVAL; - blocksize = gk5e->blocksize; keybytes = gk5e->keybytes; keylength = gk5e->keylength; - if ((inkey->len != keylength) || (outkey->len != keylength)) + if (inkey->len != keylength) goto err_return; - cipher = crypto_alloc_blkcipher(gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + cipher = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); if (IS_ERR(cipher)) goto err_return; - if (crypto_blkcipher_setkey(cipher, inkey->data, inkey->len)) - goto err_return; - - /* allocate and set up buffers */ + blocksize = crypto_sync_skcipher_blocksize(cipher); + if (crypto_sync_skcipher_setkey(cipher, inkey->data, inkey->len)) + goto err_free_cipher; - ret = ENOMEM; + ret = -ENOMEM; inblockdata = kmalloc(blocksize, gfp_mask); if (inblockdata == NULL) goto err_free_cipher; @@ -184,10 +179,6 @@ u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, if (outblockdata == NULL) goto err_free_in; - rawkey = kmalloc(keybytes, gfp_mask); - if (rawkey == NULL) - goto err_free_out; - inblock.data = (char *) inblockdata; inblock.len = blocksize; @@ -207,8 +198,8 @@ u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, n = 0; while (n < keybytes) { - (*(gk5e->encrypt))(cipher, NULL, inblock.data, - outblock.data, inblock.len); + krb5_encrypt(cipher, NULL, inblock.data, outblock.data, + inblock.len); if ((keybytes - n) <= outblock.len) { memcpy(rawkey + n, outblock.data, (keybytes - n)); @@ -220,117 +211,336 @@ u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, n += outblock.len; } - /* postprocess the key */ - - inblock.data = (char *) rawkey; - inblock.len = keybytes; - - BUG_ON(gk5e->mk_key == NULL); - ret = (*(gk5e->mk_key))(gk5e, &inblock, outkey); - if (ret) { - dprintk("%s: got %d from mk_key function for '%s'\n", - __func__, ret, gk5e->encrypt_name); - goto err_free_raw; - } - - /* clean memory, free resources and exit */ - ret = 0; -err_free_raw: - memset(rawkey, 0, keybytes); - kfree(rawkey); -err_free_out: - memset(outblockdata, 0, blocksize); - kfree(outblockdata); + kfree_sensitive(outblockdata); err_free_in: - memset(inblockdata, 0, blocksize); - kfree(inblockdata); + kfree_sensitive(inblockdata); err_free_cipher: - crypto_free_blkcipher(cipher); + crypto_free_sync_skcipher(cipher); err_return: return ret; } -#define smask(step) ((1<<step)-1) -#define pstep(x, step) (((x)&smask(step))^(((x)>>step)&smask(step))) -#define parity_char(x) pstep(pstep(pstep((x), 4), 2), 1) - -static void mit_des_fixup_key_parity(u8 key[8]) -{ - int i; - for (i = 0; i < 8; i++) { - key[i] &= 0xfe; - key[i] |= 1^parity_char(key[i]); - } -} - /* - * This is the des3 key derivation postprocess function + * This is the identity function, with some sanity checking. */ -u32 gss_krb5_des3_make_key(const struct gss_krb5_enctype *gk5e, - struct xdr_netobj *randombits, - struct xdr_netobj *key) +static int krb5_random_to_key_v2(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) { - int i; - u32 ret = EINVAL; + int ret = -EINVAL; - if (key->len != 24) { + if (key->len != 16 && key->len != 32) { dprintk("%s: key->len is %d\n", __func__, key->len); goto err_out; } - if (randombits->len != 21) { + if (randombits->len != 16 && randombits->len != 32) { dprintk("%s: randombits->len is %d\n", __func__, randombits->len); goto err_out; } - - /* take the seven bytes, move them around into the top 7 bits of the - 8 key bytes, then compute the parity bits. Do this three times. */ - - for (i = 0; i < 3; i++) { - memcpy(key->data + i*8, randombits->data + i*7, 7); - key->data[i*8+7] = (((key->data[i*8]&1)<<1) | - ((key->data[i*8+1]&1)<<2) | - ((key->data[i*8+2]&1)<<3) | - ((key->data[i*8+3]&1)<<4) | - ((key->data[i*8+4]&1)<<5) | - ((key->data[i*8+5]&1)<<6) | - ((key->data[i*8+6]&1)<<7)); - - mit_des_fixup_key_parity(key->data + i*8); + if (randombits->len != key->len) { + dprintk("%s: randombits->len is %d, key->len is %d\n", + __func__, randombits->len, key->len); + goto err_out; } + memcpy(key->data, randombits->data, key->len); ret = 0; err_out: return ret; } +/** + * krb5_derive_key_v2 - Derive a subkey for an RFC 3962 enctype + * @gk5e: Kerberos 5 enctype profile + * @inkey: base protocol key + * @outkey: OUT: derived key + * @label: subkey usage label + * @gfp_mask: memory allocation control flags + * + * Caller sets @outkey->len to the desired length of the derived key. + * + * On success, returns 0 and fills in @outkey. A negative errno value + * is returned on failure. + */ +int krb5_derive_key_v2(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *label, + gfp_t gfp_mask) +{ + struct xdr_netobj inblock; + int ret; + + inblock.len = gk5e->keybytes; + inblock.data = kmalloc(inblock.len, gfp_mask); + if (!inblock.data) + return -ENOMEM; + + ret = krb5_DK(gk5e, inkey, inblock.data, label, gfp_mask); + if (!ret) + ret = krb5_random_to_key_v2(gk5e, &inblock, outkey); + + kfree_sensitive(inblock.data); + return ret; +} + /* - * This is the aes key derivation postprocess function + * K(i) = CMAC(key, K(i-1) | i | constant | 0x00 | k) + * + * i: A block counter is used with a length of 4 bytes, represented + * in big-endian order. + * + * constant: The label input to the KDF is the usage constant supplied + * to the key derivation function + * + * k: The length of the output key in bits, represented as a 4-byte + * string in big-endian order. + * + * Caller fills in K(i-1) in @step, and receives the result K(i) + * in the same buffer. */ -u32 gss_krb5_aes_make_key(const struct gss_krb5_enctype *gk5e, - struct xdr_netobj *randombits, - struct xdr_netobj *key) +static int +krb5_cmac_Ki(struct crypto_shash *tfm, const struct xdr_netobj *constant, + u32 outlen, u32 count, struct xdr_netobj *step) { - u32 ret = EINVAL; + __be32 k = cpu_to_be32(outlen * 8); + SHASH_DESC_ON_STACK(desc, tfm); + __be32 i = cpu_to_be32(count); + u8 zero = 0; + int ret; + + desc->tfm = tfm; + ret = crypto_shash_init(desc); + if (ret) + goto out_err; + + ret = crypto_shash_update(desc, step->data, step->len); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, (u8 *)&i, sizeof(i)); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, constant->data, constant->len); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, &zero, sizeof(zero)); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, (u8 *)&k, sizeof(k)); + if (ret) + goto out_err; + ret = crypto_shash_final(desc, step->data); + if (ret) + goto out_err; + +out_err: + shash_desc_zero(desc); + return ret; +} - if (key->len != 16 && key->len != 32) { - dprintk("%s: key->len is %d\n", __func__, key->len); - goto err_out; - } - if (randombits->len != 16 && randombits->len != 32) { - dprintk("%s: randombits->len is %d\n", - __func__, randombits->len); - goto err_out; +/** + * krb5_kdf_feedback_cmac - Derive a subkey for a Camellia/CMAC-based enctype + * @gk5e: Kerberos 5 enctype parameters + * @inkey: base protocol key + * @outkey: OUT: derived key + * @constant: subkey usage label + * @gfp_mask: memory allocation control flags + * + * RFC 6803 Section 3: + * + * "We use a key derivation function from the family specified in + * [SP800-108], Section 5.2, 'KDF in Feedback Mode'." + * + * n = ceiling(k / 128) + * K(0) = zeros + * K(i) = CMAC(key, K(i-1) | i | constant | 0x00 | k) + * DR(key, constant) = k-truncate(K(1) | K(2) | ... | K(n)) + * KDF-FEEDBACK-CMAC(key, constant) = random-to-key(DR(key, constant)) + * + * Caller sets @outkey->len to the desired length of the derived key (k). + * + * On success, returns 0 and fills in @outkey. A negative errno value + * is returned on failure. + */ +int +krb5_kdf_feedback_cmac(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *constant, + gfp_t gfp_mask) +{ + struct xdr_netobj step = { .data = NULL }; + struct xdr_netobj DR = { .data = NULL }; + unsigned int blocksize, offset; + struct crypto_shash *tfm; + int n, count, ret; + + /* + * This implementation assumes the CMAC used for an enctype's + * key derivation is the same as the CMAC used for its + * checksumming. This happens to be true for enctypes that + * are currently supported by this implementation. + */ + tfm = crypto_alloc_shash(gk5e->cksum_name, 0, 0); + if (IS_ERR(tfm)) { + ret = PTR_ERR(tfm); + goto out; } - if (randombits->len != key->len) { - dprintk("%s: randombits->len is %d, key->len is %d\n", - __func__, randombits->len, key->len); - goto err_out; + ret = crypto_shash_setkey(tfm, inkey->data, inkey->len); + if (ret) + goto out_free_tfm; + + blocksize = crypto_shash_digestsize(tfm); + n = (outkey->len + blocksize - 1) / blocksize; + + /* K(0) is all zeroes */ + ret = -ENOMEM; + step.len = blocksize; + step.data = kzalloc(step.len, gfp_mask); + if (!step.data) + goto out_free_tfm; + + DR.len = blocksize * n; + DR.data = kmalloc(DR.len, gfp_mask); + if (!DR.data) + goto out_free_tfm; + + /* XXX: Does not handle partial-block key sizes */ + for (offset = 0, count = 1; count <= n; count++) { + ret = krb5_cmac_Ki(tfm, constant, outkey->len, count, &step); + if (ret) + goto out_free_tfm; + + memcpy(DR.data + offset, step.data, blocksize); + offset += blocksize; } - memcpy(key->data, randombits->data, key->len); + + /* k-truncate and random-to-key */ + memcpy(outkey->data, DR.data, outkey->len); ret = 0; -err_out: + +out_free_tfm: + crypto_free_shash(tfm); +out: + kfree_sensitive(step.data); + kfree_sensitive(DR.data); return ret; } +/* + * K1 = HMAC-SHA(key, 0x00000001 | label | 0x00 | k) + * + * key: The source of entropy from which subsequent keys are derived. + * + * label: An octet string describing the intended usage of the + * derived key. + * + * k: Length in bits of the key to be outputted, expressed in + * big-endian binary representation in 4 bytes. + */ +static int +krb5_hmac_K1(struct crypto_shash *tfm, const struct xdr_netobj *label, + u32 outlen, struct xdr_netobj *K1) +{ + __be32 k = cpu_to_be32(outlen * 8); + SHASH_DESC_ON_STACK(desc, tfm); + __be32 one = cpu_to_be32(1); + u8 zero = 0; + int ret; + + desc->tfm = tfm; + ret = crypto_shash_init(desc); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, (u8 *)&one, sizeof(one)); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, label->data, label->len); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, &zero, sizeof(zero)); + if (ret) + goto out_err; + ret = crypto_shash_update(desc, (u8 *)&k, sizeof(k)); + if (ret) + goto out_err; + ret = crypto_shash_final(desc, K1->data); + if (ret) + goto out_err; + +out_err: + shash_desc_zero(desc); + return ret; +} + +/** + * krb5_kdf_hmac_sha2 - Derive a subkey for an AES/SHA2-based enctype + * @gk5e: Kerberos 5 enctype policy parameters + * @inkey: base protocol key + * @outkey: OUT: derived key + * @label: subkey usage label + * @gfp_mask: memory allocation control flags + * + * RFC 8009 Section 3: + * + * "We use a key derivation function from Section 5.1 of [SP800-108], + * which uses the HMAC algorithm as the PRF." + * + * function KDF-HMAC-SHA2(key, label, [context,] k): + * k-truncate(K1) + * + * Caller sets @outkey->len to the desired length of the derived key. + * + * On success, returns 0 and fills in @outkey. A negative errno value + * is returned on failure. + */ +int +krb5_kdf_hmac_sha2(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *label, + gfp_t gfp_mask) +{ + struct crypto_shash *tfm; + struct xdr_netobj K1 = { + .data = NULL, + }; + int ret; + + /* + * This implementation assumes the HMAC used for an enctype's + * key derivation is the same as the HMAC used for its + * checksumming. This happens to be true for enctypes that + * are currently supported by this implementation. + */ + tfm = crypto_alloc_shash(gk5e->cksum_name, 0, 0); + if (IS_ERR(tfm)) { + ret = PTR_ERR(tfm); + goto out; + } + ret = crypto_shash_setkey(tfm, inkey->data, inkey->len); + if (ret) + goto out_free_tfm; + + K1.len = crypto_shash_digestsize(tfm); + K1.data = kmalloc(K1.len, gfp_mask); + if (!K1.data) { + ret = -ENOMEM; + goto out_free_tfm; + } + + ret = krb5_hmac_K1(tfm, label, outkey->len, &K1); + if (ret) + goto out_free_tfm; + + /* k-truncate and random-to-key */ + memcpy(outkey->data, K1.data, outkey->len); + +out_free_tfm: + kfree_sensitive(K1.data); + crypto_free_shash(tfm); +out: + return ret; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c index 0d3c158ef8fa..3366505bc669 100644 --- a/net/sunrpc/auth_gss/gss_krb5_mech.c +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: BSD-3-Clause /* * linux/net/sunrpc/gss_krb5_mech.c * @@ -6,34 +7,10 @@ * * Andy Adamson <andros@umich.edu> * J. Bruce Fields <bfields@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * */ +#include <crypto/hash.h> +#include <crypto/skcipher.h> #include <linux/err.h> #include <linux/module.h> #include <linux/init.h> @@ -42,599 +19,413 @@ #include <linux/sunrpc/auth.h> #include <linux/sunrpc/gss_krb5.h> #include <linux/sunrpc/xdr.h> -#include <linux/crypto.h> -#include <linux/sunrpc/gss_krb5_enctypes.h> +#include <kunit/visibility.h> -#ifdef RPC_DEBUG +#include "auth_gss_internal.h" +#include "gss_krb5_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif -static struct gss_api_mech gss_kerberos_mech; /* forward declaration */ +static struct gss_api_mech gss_kerberos_mech; static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = { +#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1) /* - * DES (All DES enctypes are mapped to the same gss functionality) - */ - { - .etype = ENCTYPE_DES_CBC_RAW, - .ctype = CKSUMTYPE_RSA_MD5, - .name = "des-cbc-crc", - .encrypt_name = "cbc(des)", - .cksum_name = "md5", - .encrypt = krb5_encrypt, - .decrypt = krb5_decrypt, - .mk_key = NULL, - .signalg = SGN_ALG_DES_MAC_MD5, - .sealalg = SEAL_ALG_DES, - .keybytes = 7, - .keylength = 8, - .blocksize = 8, - .conflen = 8, - .cksumlength = 8, - .keyed_cksum = 0, - }, - /* - * RC4-HMAC - */ - { - .etype = ENCTYPE_ARCFOUR_HMAC, - .ctype = CKSUMTYPE_HMAC_MD5_ARCFOUR, - .name = "rc4-hmac", - .encrypt_name = "ecb(arc4)", - .cksum_name = "hmac(md5)", - .encrypt = krb5_encrypt, - .decrypt = krb5_decrypt, - .mk_key = NULL, - .signalg = SGN_ALG_HMAC_MD5, - .sealalg = SEAL_ALG_MICROSOFT_RC4, - .keybytes = 16, - .keylength = 16, - .blocksize = 1, - .conflen = 8, - .cksumlength = 8, - .keyed_cksum = 1, - }, - /* - * 3DES - */ - { - .etype = ENCTYPE_DES3_CBC_RAW, - .ctype = CKSUMTYPE_HMAC_SHA1_DES3, - .name = "des3-hmac-sha1", - .encrypt_name = "cbc(des3_ede)", - .cksum_name = "hmac(sha1)", - .encrypt = krb5_encrypt, - .decrypt = krb5_decrypt, - .mk_key = gss_krb5_des3_make_key, - .signalg = SGN_ALG_HMAC_SHA1_DES3_KD, - .sealalg = SEAL_ALG_DES3KD, - .keybytes = 21, - .keylength = 24, - .blocksize = 8, - .conflen = 8, - .cksumlength = 20, - .keyed_cksum = 1, - }, - /* - * AES128 + * AES-128 with SHA-1 (RFC 3962) */ { .etype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, .ctype = CKSUMTYPE_HMAC_SHA1_96_AES128, .name = "aes128-cts", .encrypt_name = "cts(cbc(aes))", + .aux_cipher = "cbc(aes)", .cksum_name = "hmac(sha1)", - .encrypt = krb5_encrypt, - .decrypt = krb5_decrypt, - .mk_key = gss_krb5_aes_make_key, - .encrypt_v2 = gss_krb5_aes_encrypt, - .decrypt_v2 = gss_krb5_aes_decrypt, + .derive_key = krb5_derive_key_v2, + .encrypt = gss_krb5_aes_encrypt, + .decrypt = gss_krb5_aes_decrypt, + + .get_mic = gss_krb5_get_mic_v2, + .verify_mic = gss_krb5_verify_mic_v2, + .wrap = gss_krb5_wrap_v2, + .unwrap = gss_krb5_unwrap_v2, + .signalg = -1, .sealalg = -1, .keybytes = 16, - .keylength = 16, - .blocksize = 16, - .conflen = 16, - .cksumlength = 12, + .keylength = BITS2OCTETS(128), + .Kc_length = BITS2OCTETS(128), + .Ke_length = BITS2OCTETS(128), + .Ki_length = BITS2OCTETS(128), + .cksumlength = BITS2OCTETS(96), .keyed_cksum = 1, }, /* - * AES256 + * AES-256 with SHA-1 (RFC 3962) */ { .etype = ENCTYPE_AES256_CTS_HMAC_SHA1_96, .ctype = CKSUMTYPE_HMAC_SHA1_96_AES256, .name = "aes256-cts", .encrypt_name = "cts(cbc(aes))", + .aux_cipher = "cbc(aes)", .cksum_name = "hmac(sha1)", - .encrypt = krb5_encrypt, - .decrypt = krb5_decrypt, - .mk_key = gss_krb5_aes_make_key, - .encrypt_v2 = gss_krb5_aes_encrypt, - .decrypt_v2 = gss_krb5_aes_decrypt, + .derive_key = krb5_derive_key_v2, + .encrypt = gss_krb5_aes_encrypt, + .decrypt = gss_krb5_aes_decrypt, + + .get_mic = gss_krb5_get_mic_v2, + .verify_mic = gss_krb5_verify_mic_v2, + .wrap = gss_krb5_wrap_v2, + .unwrap = gss_krb5_unwrap_v2, + .signalg = -1, .sealalg = -1, .keybytes = 32, - .keylength = 32, - .blocksize = 16, - .conflen = 16, - .cksumlength = 12, + .keylength = BITS2OCTETS(256), + .Kc_length = BITS2OCTETS(256), + .Ke_length = BITS2OCTETS(256), + .Ki_length = BITS2OCTETS(256), + .cksumlength = BITS2OCTETS(96), .keyed_cksum = 1, }, -}; - -static const int num_supported_enctypes = - ARRAY_SIZE(supported_gss_krb5_enctypes); - -static int -supported_gss_krb5_enctype(int etype) -{ - int i; - for (i = 0; i < num_supported_enctypes; i++) - if (supported_gss_krb5_enctypes[i].etype == etype) - return 1; - return 0; -} - -static const struct gss_krb5_enctype * -get_gss_krb5_enctype(int etype) -{ - int i; - for (i = 0; i < num_supported_enctypes; i++) - if (supported_gss_krb5_enctypes[i].etype == etype) - return &supported_gss_krb5_enctypes[i]; - return NULL; -} +#endif -static const void * -simple_get_bytes(const void *p, const void *end, void *res, int len) -{ - const void *q = (const void *)((const char *)p + len); - if (unlikely(q > end || q < p)) - return ERR_PTR(-EFAULT); - memcpy(res, p, len); - return q; -} +#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA) + /* + * Camellia-128 with CMAC (RFC 6803) + */ + { + .etype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .ctype = CKSUMTYPE_CMAC_CAMELLIA128, + .name = "camellia128-cts-cmac", + .encrypt_name = "cts(cbc(camellia))", + .aux_cipher = "cbc(camellia)", + .cksum_name = "cmac(camellia)", + .cksumlength = BITS2OCTETS(128), + .keyed_cksum = 1, + .keylength = BITS2OCTETS(128), + .Kc_length = BITS2OCTETS(128), + .Ke_length = BITS2OCTETS(128), + .Ki_length = BITS2OCTETS(128), + + .derive_key = krb5_kdf_feedback_cmac, + .encrypt = gss_krb5_aes_encrypt, + .decrypt = gss_krb5_aes_decrypt, + + .get_mic = gss_krb5_get_mic_v2, + .verify_mic = gss_krb5_verify_mic_v2, + .wrap = gss_krb5_wrap_v2, + .unwrap = gss_krb5_unwrap_v2, + }, + /* + * Camellia-256 with CMAC (RFC 6803) + */ + { + .etype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .ctype = CKSUMTYPE_CMAC_CAMELLIA256, + .name = "camellia256-cts-cmac", + .encrypt_name = "cts(cbc(camellia))", + .aux_cipher = "cbc(camellia)", + .cksum_name = "cmac(camellia)", + .cksumlength = BITS2OCTETS(128), + .keyed_cksum = 1, + .keylength = BITS2OCTETS(256), + .Kc_length = BITS2OCTETS(256), + .Ke_length = BITS2OCTETS(256), + .Ki_length = BITS2OCTETS(256), + + .derive_key = krb5_kdf_feedback_cmac, + .encrypt = gss_krb5_aes_encrypt, + .decrypt = gss_krb5_aes_decrypt, + + .get_mic = gss_krb5_get_mic_v2, + .verify_mic = gss_krb5_verify_mic_v2, + .wrap = gss_krb5_wrap_v2, + .unwrap = gss_krb5_unwrap_v2, + }, +#endif -static const void * -simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) -{ - const void *q; - unsigned int len; +#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2) + /* + * AES-128 with SHA-256 (RFC 8009) + */ + { + .etype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .ctype = CKSUMTYPE_HMAC_SHA256_128_AES128, + .name = "aes128-cts-hmac-sha256-128", + .encrypt_name = "cts(cbc(aes))", + .aux_cipher = "cbc(aes)", + .cksum_name = "hmac(sha256)", + .cksumlength = BITS2OCTETS(128), + .keyed_cksum = 1, + .keylength = BITS2OCTETS(128), + .Kc_length = BITS2OCTETS(128), + .Ke_length = BITS2OCTETS(128), + .Ki_length = BITS2OCTETS(128), + + .derive_key = krb5_kdf_hmac_sha2, + .encrypt = krb5_etm_encrypt, + .decrypt = krb5_etm_decrypt, + + .get_mic = gss_krb5_get_mic_v2, + .verify_mic = gss_krb5_verify_mic_v2, + .wrap = gss_krb5_wrap_v2, + .unwrap = gss_krb5_unwrap_v2, + }, + /* + * AES-256 with SHA-384 (RFC 8009) + */ + { + .etype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .ctype = CKSUMTYPE_HMAC_SHA384_192_AES256, + .name = "aes256-cts-hmac-sha384-192", + .encrypt_name = "cts(cbc(aes))", + .aux_cipher = "cbc(aes)", + .cksum_name = "hmac(sha384)", + .cksumlength = BITS2OCTETS(192), + .keyed_cksum = 1, + .keylength = BITS2OCTETS(256), + .Kc_length = BITS2OCTETS(192), + .Ke_length = BITS2OCTETS(256), + .Ki_length = BITS2OCTETS(192), + + .derive_key = krb5_kdf_hmac_sha2, + .encrypt = krb5_etm_encrypt, + .decrypt = krb5_etm_decrypt, + + .get_mic = gss_krb5_get_mic_v2, + .verify_mic = gss_krb5_verify_mic_v2, + .wrap = gss_krb5_wrap_v2, + .unwrap = gss_krb5_unwrap_v2, + }, +#endif +}; - p = simple_get_bytes(p, end, &len, sizeof(len)); - if (IS_ERR(p)) - return p; - q = (const void *)((const char *)p + len); - if (unlikely(q > end || q < p)) - return ERR_PTR(-EFAULT); - res->data = kmemdup(p, len, GFP_NOFS); - if (unlikely(res->data == NULL)) - return ERR_PTR(-ENOMEM); - res->len = len; - return q; -} +/* + * The list of advertised enctypes is specified in order of most + * preferred to least. + */ +static char gss_krb5_enctype_priority_list[64]; -static inline const void * -get_key(const void *p, const void *end, - struct krb5_ctx *ctx, struct crypto_blkcipher **res) +static void gss_krb5_prepare_enctype_priority_list(void) { - struct xdr_netobj key; - int alg; - - p = simple_get_bytes(p, end, &alg, sizeof(alg)); - if (IS_ERR(p)) - goto out_err; - - switch (alg) { - case ENCTYPE_DES_CBC_CRC: - case ENCTYPE_DES_CBC_MD4: - case ENCTYPE_DES_CBC_MD5: - /* Map all these key types to ENCTYPE_DES_CBC_RAW */ - alg = ENCTYPE_DES_CBC_RAW; - break; - } - - if (!supported_gss_krb5_enctype(alg)) { - printk(KERN_WARNING "gss_kerberos_mech: unsupported " - "encryption key algorithm %d\n", alg); - p = ERR_PTR(-EINVAL); - goto out_err; - } - p = simple_get_netobj(p, end, &key); - if (IS_ERR(p)) - goto out_err; - - *res = crypto_alloc_blkcipher(ctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(*res)) { - printk(KERN_WARNING "gss_kerberos_mech: unable to initialize " - "crypto algorithm %s\n", ctx->gk5e->encrypt_name); - *res = NULL; - goto out_err_free_key; - } - if (crypto_blkcipher_setkey(*res, key.data, key.len)) { - printk(KERN_WARNING "gss_kerberos_mech: error setting key for " - "crypto algorithm %s\n", ctx->gk5e->encrypt_name); - goto out_err_free_tfm; + static const u32 gss_krb5_enctypes[] = { +#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA2) + ENCTYPE_AES256_CTS_HMAC_SHA384_192, + ENCTYPE_AES128_CTS_HMAC_SHA256_128, +#endif +#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_CAMELLIA) + ENCTYPE_CAMELLIA256_CTS_CMAC, + ENCTYPE_CAMELLIA128_CTS_CMAC, +#endif +#if defined(CONFIG_RPCSEC_GSS_KRB5_ENCTYPES_AES_SHA1) + ENCTYPE_AES256_CTS_HMAC_SHA1_96, + ENCTYPE_AES128_CTS_HMAC_SHA1_96, +#endif + }; + size_t total, i; + char buf[16]; + char *sep; + int n; + + sep = ""; + gss_krb5_enctype_priority_list[0] = '\0'; + for (total = 0, i = 0; i < ARRAY_SIZE(gss_krb5_enctypes); i++) { + n = sprintf(buf, "%s%u", sep, gss_krb5_enctypes[i]); + if (n < 0) + break; + if (total + n >= sizeof(gss_krb5_enctype_priority_list)) + break; + strcat(gss_krb5_enctype_priority_list, buf); + sep = ","; + total += n; } - - kfree(key.data); - return p; - -out_err_free_tfm: - crypto_free_blkcipher(*res); -out_err_free_key: - kfree(key.data); - p = ERR_PTR(-EINVAL); -out_err: - return p; } -static int -gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) +/** + * gss_krb5_lookup_enctype - Retrieve profile information for a given enctype + * @etype: ENCTYPE value + * + * Returns a pointer to a gss_krb5_enctype structure, or NULL if no + * matching etype is found. + */ +VISIBLE_IF_KUNIT +const struct gss_krb5_enctype *gss_krb5_lookup_enctype(u32 etype) { - int tmp; - - p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); - if (IS_ERR(p)) - goto out_err; + size_t i; - /* Old format supports only DES! Any other enctype uses new format */ - ctx->enctype = ENCTYPE_DES_CBC_RAW; - - ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); - if (ctx->gk5e == NULL) { - p = ERR_PTR(-EINVAL); - goto out_err; - } - - /* The downcall format was designed before we completely understood - * the uses of the context fields; so it includes some stuff we - * just give some minimal sanity-checking, and some we ignore - * completely (like the next twenty bytes): */ - if (unlikely(p + 20 > end || p + 20 < p)) { - p = ERR_PTR(-EFAULT); - goto out_err; - } - p += 20; - p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); - if (IS_ERR(p)) - goto out_err; - if (tmp != SGN_ALG_DES_MAC_MD5) { - p = ERR_PTR(-ENOSYS); - goto out_err; - } - p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); - if (IS_ERR(p)) - goto out_err; - if (tmp != SEAL_ALG_DES) { - p = ERR_PTR(-ENOSYS); - goto out_err; - } - p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); - if (IS_ERR(p)) - goto out_err; - p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send)); - if (IS_ERR(p)) - goto out_err; - p = simple_get_netobj(p, end, &ctx->mech_used); - if (IS_ERR(p)) - goto out_err; - p = get_key(p, end, ctx, &ctx->enc); - if (IS_ERR(p)) - goto out_err_free_mech; - p = get_key(p, end, ctx, &ctx->seq); - if (IS_ERR(p)) - goto out_err_free_key1; - if (p != end) { - p = ERR_PTR(-EFAULT); - goto out_err_free_key2; - } - - return 0; - -out_err_free_key2: - crypto_free_blkcipher(ctx->seq); -out_err_free_key1: - crypto_free_blkcipher(ctx->enc); -out_err_free_mech: - kfree(ctx->mech_used.data); -out_err: - return PTR_ERR(p); + for (i = 0; i < ARRAY_SIZE(supported_gss_krb5_enctypes); i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return &supported_gss_krb5_enctypes[i]; + return NULL; } +EXPORT_SYMBOL_IF_KUNIT(gss_krb5_lookup_enctype); -static struct crypto_blkcipher * -context_v2_alloc_cipher(struct krb5_ctx *ctx, const char *cname, u8 *key) +static struct crypto_sync_skcipher * +gss_krb5_alloc_cipher_v2(const char *cname, const struct xdr_netobj *key) { - struct crypto_blkcipher *cp; + struct crypto_sync_skcipher *tfm; - cp = crypto_alloc_blkcipher(cname, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(cp)) { - dprintk("gss_kerberos_mech: unable to initialize " - "crypto algorithm %s\n", cname); + tfm = crypto_alloc_sync_skcipher(cname, 0, 0); + if (IS_ERR(tfm)) return NULL; - } - if (crypto_blkcipher_setkey(cp, key, ctx->gk5e->keylength)) { - dprintk("gss_kerberos_mech: error setting key for " - "crypto algorithm %s\n", cname); - crypto_free_blkcipher(cp); + if (crypto_sync_skcipher_setkey(tfm, key->data, key->len)) { + crypto_free_sync_skcipher(tfm); return NULL; } - return cp; -} - -static inline void -set_cdata(u8 cdata[GSS_KRB5_K5CLENGTH], u32 usage, u8 seed) -{ - cdata[0] = (usage>>24)&0xff; - cdata[1] = (usage>>16)&0xff; - cdata[2] = (usage>>8)&0xff; - cdata[3] = usage&0xff; - cdata[4] = seed; -} - -static int -context_derive_keys_des3(struct krb5_ctx *ctx, gfp_t gfp_mask) -{ - struct xdr_netobj c, keyin, keyout; - u8 cdata[GSS_KRB5_K5CLENGTH]; - u32 err; - - c.len = GSS_KRB5_K5CLENGTH; - c.data = cdata; - - keyin.data = ctx->Ksess; - keyin.len = ctx->gk5e->keylength; - keyout.len = ctx->gk5e->keylength; - - /* seq uses the raw key */ - ctx->seq = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, - ctx->Ksess); - if (ctx->seq == NULL) - goto out_err; - - ctx->enc = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, - ctx->Ksess); - if (ctx->enc == NULL) - goto out_free_seq; - - /* derive cksum */ - set_cdata(cdata, KG_USAGE_SIGN, KEY_USAGE_SEED_CHECKSUM); - keyout.data = ctx->cksum; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving cksum key\n", - __func__, err); - goto out_free_enc; - } - - return 0; - -out_free_enc: - crypto_free_blkcipher(ctx->enc); -out_free_seq: - crypto_free_blkcipher(ctx->seq); -out_err: - return -EINVAL; + return tfm; } -/* - * Note that RC4 depends on deriving keys using the sequence - * number or the checksum of a token. Therefore, the final keys - * cannot be calculated until the token is being constructed! - */ -static int -context_derive_keys_rc4(struct krb5_ctx *ctx) +static struct crypto_ahash * +gss_krb5_alloc_hash_v2(struct krb5_ctx *kctx, const struct xdr_netobj *key) { - struct crypto_hash *hmac; - char sigkeyconstant[] = "signaturekey"; - int slen = strlen(sigkeyconstant) + 1; /* include null terminator */ - struct hash_desc desc; - struct scatterlist sg[1]; - int err; - - dprintk("RPC: %s: entered\n", __func__); - /* - * derive cksum (aka Ksign) key - */ - hmac = crypto_alloc_hash(ctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(hmac)) { - dprintk("%s: error %ld allocating hash '%s'\n", - __func__, PTR_ERR(hmac), ctx->gk5e->cksum_name); - err = PTR_ERR(hmac); - goto out_err; - } - - err = crypto_hash_setkey(hmac, ctx->Ksess, ctx->gk5e->keylength); - if (err) - goto out_err_free_hmac; - - sg_init_table(sg, 1); - sg_set_buf(sg, sigkeyconstant, slen); - - desc.tfm = hmac; - desc.flags = 0; + struct crypto_ahash *tfm; - err = crypto_hash_init(&desc); - if (err) - goto out_err_free_hmac; - - err = crypto_hash_digest(&desc, sg, slen, ctx->cksum); - if (err) - goto out_err_free_hmac; - /* - * allocate hash, and blkciphers for data and seqnum encryption - */ - ctx->enc = crypto_alloc_blkcipher(ctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(ctx->enc)) { - err = PTR_ERR(ctx->enc); - goto out_err_free_hmac; - } - - ctx->seq = crypto_alloc_blkcipher(ctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(ctx->seq)) { - crypto_free_blkcipher(ctx->enc); - err = PTR_ERR(ctx->seq); - goto out_err_free_hmac; + tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + return NULL; + if (crypto_ahash_setkey(tfm, key->data, key->len)) { + crypto_free_ahash(tfm); + return NULL; } - - dprintk("RPC: %s: returning success\n", __func__); - - err = 0; - -out_err_free_hmac: - crypto_free_hash(hmac); -out_err: - dprintk("RPC: %s: returning %d\n", __func__, err); - return err; + return tfm; } static int -context_derive_keys_new(struct krb5_ctx *ctx, gfp_t gfp_mask) +gss_krb5_import_ctx_v2(struct krb5_ctx *ctx, gfp_t gfp_mask) { - struct xdr_netobj c, keyin, keyout; - u8 cdata[GSS_KRB5_K5CLENGTH]; - u32 err; - - c.len = GSS_KRB5_K5CLENGTH; - c.data = cdata; - - keyin.data = ctx->Ksess; - keyin.len = ctx->gk5e->keylength; - keyout.len = ctx->gk5e->keylength; + struct xdr_netobj keyin = { + .len = ctx->gk5e->keylength, + .data = ctx->Ksess, + }; + struct xdr_netobj keyout; + int ret = -EINVAL; + + keyout.data = kmalloc(GSS_KRB5_MAX_KEYLEN, gfp_mask); + if (!keyout.data) + return -ENOMEM; /* initiator seal encryption */ - set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); - keyout.data = ctx->initiator_seal; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving initiator_seal key\n", - __func__, err); - goto out_err; - } - ctx->initiator_enc = context_v2_alloc_cipher(ctx, - ctx->gk5e->encrypt_name, - ctx->initiator_seal); + keyout.len = ctx->gk5e->Ke_length; + if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SEAL, + KEY_USAGE_SEED_ENCRYPTION, gfp_mask)) + goto out; + ctx->initiator_enc = gss_krb5_alloc_cipher_v2(ctx->gk5e->encrypt_name, + &keyout); if (ctx->initiator_enc == NULL) - goto out_err; + goto out; + if (ctx->gk5e->aux_cipher) { + ctx->initiator_enc_aux = + gss_krb5_alloc_cipher_v2(ctx->gk5e->aux_cipher, + &keyout); + if (ctx->initiator_enc_aux == NULL) + goto out_free; + } /* acceptor seal encryption */ - set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); - keyout.data = ctx->acceptor_seal; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving acceptor_seal key\n", - __func__, err); - goto out_free_initiator_enc; - } - ctx->acceptor_enc = context_v2_alloc_cipher(ctx, - ctx->gk5e->encrypt_name, - ctx->acceptor_seal); + if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_ACCEPTOR_SEAL, + KEY_USAGE_SEED_ENCRYPTION, gfp_mask)) + goto out_free; + ctx->acceptor_enc = gss_krb5_alloc_cipher_v2(ctx->gk5e->encrypt_name, + &keyout); if (ctx->acceptor_enc == NULL) - goto out_free_initiator_enc; + goto out_free; + if (ctx->gk5e->aux_cipher) { + ctx->acceptor_enc_aux = + gss_krb5_alloc_cipher_v2(ctx->gk5e->aux_cipher, + &keyout); + if (ctx->acceptor_enc_aux == NULL) + goto out_free; + } /* initiator sign checksum */ - set_cdata(cdata, KG_USAGE_INITIATOR_SIGN, KEY_USAGE_SEED_CHECKSUM); - keyout.data = ctx->initiator_sign; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving initiator_sign key\n", - __func__, err); - goto out_free_acceptor_enc; - } + keyout.len = ctx->gk5e->Kc_length; + if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SIGN, + KEY_USAGE_SEED_CHECKSUM, gfp_mask)) + goto out_free; + ctx->initiator_sign = gss_krb5_alloc_hash_v2(ctx, &keyout); + if (ctx->initiator_sign == NULL) + goto out_free; /* acceptor sign checksum */ - set_cdata(cdata, KG_USAGE_ACCEPTOR_SIGN, KEY_USAGE_SEED_CHECKSUM); - keyout.data = ctx->acceptor_sign; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving acceptor_sign key\n", - __func__, err); - goto out_free_acceptor_enc; - } + if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_ACCEPTOR_SIGN, + KEY_USAGE_SEED_CHECKSUM, gfp_mask)) + goto out_free; + ctx->acceptor_sign = gss_krb5_alloc_hash_v2(ctx, &keyout); + if (ctx->acceptor_sign == NULL) + goto out_free; /* initiator seal integrity */ - set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_INTEGRITY); - keyout.data = ctx->initiator_integ; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving initiator_integ key\n", - __func__, err); - goto out_free_acceptor_enc; - } + keyout.len = ctx->gk5e->Ki_length; + if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_INITIATOR_SEAL, + KEY_USAGE_SEED_INTEGRITY, gfp_mask)) + goto out_free; + ctx->initiator_integ = gss_krb5_alloc_hash_v2(ctx, &keyout); + if (ctx->initiator_integ == NULL) + goto out_free; /* acceptor seal integrity */ - set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_INTEGRITY); - keyout.data = ctx->acceptor_integ; - err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); - if (err) { - dprintk("%s: Error %d deriving acceptor_integ key\n", - __func__, err); - goto out_free_acceptor_enc; - } - - switch (ctx->enctype) { - case ENCTYPE_AES128_CTS_HMAC_SHA1_96: - case ENCTYPE_AES256_CTS_HMAC_SHA1_96: - ctx->initiator_enc_aux = - context_v2_alloc_cipher(ctx, "cbc(aes)", - ctx->initiator_seal); - if (ctx->initiator_enc_aux == NULL) - goto out_free_acceptor_enc; - ctx->acceptor_enc_aux = - context_v2_alloc_cipher(ctx, "cbc(aes)", - ctx->acceptor_seal); - if (ctx->acceptor_enc_aux == NULL) { - crypto_free_blkcipher(ctx->initiator_enc_aux); - goto out_free_acceptor_enc; - } - } - - return 0; + if (krb5_derive_key(ctx, &keyin, &keyout, KG_USAGE_ACCEPTOR_SEAL, + KEY_USAGE_SEED_INTEGRITY, gfp_mask)) + goto out_free; + ctx->acceptor_integ = gss_krb5_alloc_hash_v2(ctx, &keyout); + if (ctx->acceptor_integ == NULL) + goto out_free; + + ret = 0; +out: + kfree_sensitive(keyout.data); + return ret; -out_free_acceptor_enc: - crypto_free_blkcipher(ctx->acceptor_enc); -out_free_initiator_enc: - crypto_free_blkcipher(ctx->initiator_enc); -out_err: - return -EINVAL; +out_free: + crypto_free_ahash(ctx->acceptor_integ); + crypto_free_ahash(ctx->initiator_integ); + crypto_free_ahash(ctx->acceptor_sign); + crypto_free_ahash(ctx->initiator_sign); + crypto_free_sync_skcipher(ctx->acceptor_enc_aux); + crypto_free_sync_skcipher(ctx->acceptor_enc); + crypto_free_sync_skcipher(ctx->initiator_enc_aux); + crypto_free_sync_skcipher(ctx->initiator_enc); + goto out; } static int gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, gfp_t gfp_mask) { + u64 seq_send64; int keylen; + u32 time32; + int ret; p = simple_get_bytes(p, end, &ctx->flags, sizeof(ctx->flags)); if (IS_ERR(p)) goto out_err; ctx->initiate = ctx->flags & KRB5_CTX_FLAG_INITIATOR; - p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + p = simple_get_bytes(p, end, &time32, sizeof(time32)); if (IS_ERR(p)) goto out_err; - p = simple_get_bytes(p, end, &ctx->seq_send64, sizeof(ctx->seq_send64)); + /* unsigned 32-bit time overflows in year 2106 */ + ctx->endtime = (time64_t)time32; + p = simple_get_bytes(p, end, &seq_send64, sizeof(seq_send64)); if (IS_ERR(p)) goto out_err; + atomic64_set(&ctx->seq_send64, seq_send64); /* set seq_send for use by "older" enctypes */ - ctx->seq_send = ctx->seq_send64; - if (ctx->seq_send64 != ctx->seq_send) { - dprintk("%s: seq_send64 %lx, seq_send %x overflow?\n", __func__, - (unsigned long)ctx->seq_send64, ctx->seq_send); + atomic_set(&ctx->seq_send, seq_send64); + if (seq_send64 != atomic_read(&ctx->seq_send)) { + dprintk("%s: seq_send64 %llx, seq_send %x overflow?\n", __func__, + seq_send64, atomic_read(&ctx->seq_send)); p = ERR_PTR(-EINVAL); goto out_err; } p = simple_get_bytes(p, end, &ctx->enctype, sizeof(ctx->enctype)); if (IS_ERR(p)) goto out_err; - /* Map ENCTYPE_DES3_CBC_SHA1 to ENCTYPE_DES3_CBC_RAW */ - if (ctx->enctype == ENCTYPE_DES3_CBC_SHA1) - ctx->enctype = ENCTYPE_DES3_CBC_RAW; - ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + ctx->gk5e = gss_krb5_lookup_enctype(ctx->enctype); if (ctx->gk5e == NULL) { dprintk("gss_kerberos_mech: unsupported krb5 enctype %u\n", ctx->enctype); @@ -660,27 +451,23 @@ gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, } ctx->mech_used.len = gss_kerberos_mech.gm_oid.len; - switch (ctx->enctype) { - case ENCTYPE_DES3_CBC_RAW: - return context_derive_keys_des3(ctx, gfp_mask); - case ENCTYPE_ARCFOUR_HMAC: - return context_derive_keys_rc4(ctx); - case ENCTYPE_AES128_CTS_HMAC_SHA1_96: - case ENCTYPE_AES256_CTS_HMAC_SHA1_96: - return context_derive_keys_new(ctx, gfp_mask); - default: - return -EINVAL; + ret = gss_krb5_import_ctx_v2(ctx, gfp_mask); + if (ret) { + p = ERR_PTR(ret); + goto out_free; } + return 0; + +out_free: + kfree(ctx->mech_used.data); out_err: return PTR_ERR(p); } static int -gss_import_sec_context_kerberos(const void *p, size_t len, - struct gss_ctx *ctx_id, - time_t *endtime, - gfp_t gfp_mask) +gss_krb5_import_sec_context(const void *p, size_t len, struct gss_ctx *ctx_id, + time64_t *endtime, gfp_t gfp_mask) { const void *end = (const void *)((const char *)p + len); struct krb5_ctx *ctx; @@ -690,43 +477,129 @@ gss_import_sec_context_kerberos(const void *p, size_t len, if (ctx == NULL) return -ENOMEM; - if (len == 85) - ret = gss_import_v1_context(p, end, ctx); - else - ret = gss_import_v2_context(p, end, ctx, gfp_mask); - - if (ret == 0) { - ctx_id->internal_ctx_id = ctx; - if (endtime) - *endtime = ctx->endtime; - } else + ret = gss_import_v2_context(p, end, ctx, gfp_mask); + memzero_explicit(&ctx->Ksess, sizeof(ctx->Ksess)); + if (ret) { kfree(ctx); + return ret; + } - dprintk("RPC: %s: returning %d\n", __func__, ret); - return ret; + ctx_id->internal_ctx_id = ctx; + if (endtime) + *endtime = ctx->endtime; + return 0; } static void -gss_delete_sec_context_kerberos(void *internal_ctx) { +gss_krb5_delete_sec_context(void *internal_ctx) +{ struct krb5_ctx *kctx = internal_ctx; - crypto_free_blkcipher(kctx->seq); - crypto_free_blkcipher(kctx->enc); - crypto_free_blkcipher(kctx->acceptor_enc); - crypto_free_blkcipher(kctx->initiator_enc); - crypto_free_blkcipher(kctx->acceptor_enc_aux); - crypto_free_blkcipher(kctx->initiator_enc_aux); + crypto_free_sync_skcipher(kctx->seq); + crypto_free_sync_skcipher(kctx->enc); + crypto_free_sync_skcipher(kctx->acceptor_enc); + crypto_free_sync_skcipher(kctx->initiator_enc); + crypto_free_sync_skcipher(kctx->acceptor_enc_aux); + crypto_free_sync_skcipher(kctx->initiator_enc_aux); + crypto_free_ahash(kctx->acceptor_sign); + crypto_free_ahash(kctx->initiator_sign); + crypto_free_ahash(kctx->acceptor_integ); + crypto_free_ahash(kctx->initiator_integ); kfree(kctx->mech_used.data); kfree(kctx); } +/** + * gss_krb5_get_mic - get_mic for the Kerberos GSS mechanism + * @gctx: GSS context + * @text: plaintext to checksum + * @token: buffer into which to write the computed checksum + * + * Return values: + * %GSS_S_COMPLETE - success, and @token is filled in + * %GSS_S_FAILURE - checksum could not be generated + * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid + */ +static u32 gss_krb5_get_mic(struct gss_ctx *gctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + return kctx->gk5e->get_mic(kctx, text, token); +} + +/** + * gss_krb5_verify_mic - verify_mic for the Kerberos GSS mechanism + * @gctx: GSS context + * @message_buffer: plaintext to check + * @read_token: received checksum to check + * + * Return values: + * %GSS_S_COMPLETE - computed and received checksums match + * %GSS_S_DEFECTIVE_TOKEN - received checksum is not valid + * %GSS_S_BAD_SIG - computed and received checksums do not match + * %GSS_S_FAILURE - received checksum could not be checked + * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid + */ +static u32 gss_krb5_verify_mic(struct gss_ctx *gctx, + struct xdr_buf *message_buffer, + struct xdr_netobj *read_token) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + return kctx->gk5e->verify_mic(kctx, message_buffer, read_token); +} + +/** + * gss_krb5_wrap - gss_wrap for the Kerberos GSS mechanism + * @gctx: initialized GSS context + * @offset: byte offset in @buf to start writing the cipher text + * @buf: OUT: send buffer + * @pages: plaintext to wrap + * + * Return values: + * %GSS_S_COMPLETE - success, @buf has been updated + * %GSS_S_FAILURE - @buf could not be wrapped + * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid + */ +static u32 gss_krb5_wrap(struct gss_ctx *gctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + return kctx->gk5e->wrap(kctx, offset, buf, pages); +} + +/** + * gss_krb5_unwrap - gss_unwrap for the Kerberos GSS mechanism + * @gctx: initialized GSS context + * @offset: starting byte offset into @buf + * @len: size of ciphertext to unwrap + * @buf: ciphertext to unwrap + * + * Return values: + * %GSS_S_COMPLETE - success, @buf has been updated + * %GSS_S_DEFECTIVE_TOKEN - received blob is not valid + * %GSS_S_BAD_SIG - computed and received checksums do not match + * %GSS_S_FAILURE - @buf could not be unwrapped + * %GSS_S_CONTEXT_EXPIRED - Kerberos context is no longer valid + */ +static u32 gss_krb5_unwrap(struct gss_ctx *gctx, int offset, + int len, struct xdr_buf *buf) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + return kctx->gk5e->unwrap(kctx, offset, len, buf, + &gctx->slack, &gctx->align); +} + static const struct gss_api_ops gss_kerberos_ops = { - .gss_import_sec_context = gss_import_sec_context_kerberos, - .gss_get_mic = gss_get_mic_kerberos, - .gss_verify_mic = gss_verify_mic_kerberos, - .gss_wrap = gss_wrap_kerberos, - .gss_unwrap = gss_unwrap_kerberos, - .gss_delete_sec_context = gss_delete_sec_context_kerberos, + .gss_import_sec_context = gss_krb5_import_sec_context, + .gss_get_mic = gss_krb5_get_mic, + .gss_verify_mic = gss_krb5_verify_mic, + .gss_wrap = gss_krb5_wrap, + .gss_unwrap = gss_krb5_unwrap, + .gss_delete_sec_context = gss_krb5_delete_sec_context, }; static struct pf_desc gss_kerberos_pfs[] = { @@ -741,12 +614,14 @@ static struct pf_desc gss_kerberos_pfs[] = { .qop = GSS_C_QOP_DEFAULT, .service = RPC_GSS_SVC_INTEGRITY, .name = "krb5i", + .datatouch = true, }, [2] = { .pseudoflavor = RPC_AUTH_GSS_KRB5P, .qop = GSS_C_QOP_DEFAULT, .service = RPC_GSS_SVC_PRIVACY, .name = "krb5p", + .datatouch = true, }, }; @@ -765,13 +640,14 @@ static struct gss_api_mech gss_kerberos_mech = { .gm_ops = &gss_kerberos_ops, .gm_pf_num = ARRAY_SIZE(gss_kerberos_pfs), .gm_pfs = gss_kerberos_pfs, - .gm_upcall_enctypes = KRB5_SUPPORTED_ENCTYPES, + .gm_upcall_enctypes = gss_krb5_enctype_priority_list, }; static int __init init_kerberos_module(void) { int status; + gss_krb5_prepare_enctype_priority_list(); status = gss_mech_register(&gss_kerberos_mech); if (status) printk("Failed to register kerberos gss mechanism!\n"); @@ -783,6 +659,7 @@ static void __exit cleanup_kerberos_module(void) gss_mech_unregister(&gss_kerberos_mech); } +MODULE_DESCRIPTION("Sun RPC Kerberos 5 module"); MODULE_LICENSE("GPL"); module_init(init_kerberos_module); module_exit(cleanup_kerberos_module); diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c index 62ae3273186c..ce540df9bce4 100644 --- a/net/sunrpc/auth_gss/gss_krb5_seal.c +++ b/net/sunrpc/auth_gss/gss_krb5_seal.c @@ -63,38 +63,19 @@ #include <linux/sunrpc/gss_krb5.h> #include <linux/random.h> #include <linux/crypto.h> +#include <linux/atomic.h> -#ifdef RPC_DEBUG +#include "gss_krb5_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif -DEFINE_SPINLOCK(krb5_seq_lock); - -static char * -setup_token(struct krb5_ctx *ctx, struct xdr_netobj *token) -{ - __be16 *ptr, *krb5_hdr; - int body_size = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; - - token->len = g_token_size(&ctx->mech_used, body_size); - - ptr = (__be16 *)token->data; - g_make_token_header(&ctx->mech_used, body_size, (unsigned char **)&ptr); - - /* ptr now at start of header described in rfc 1964, section 1.2.1: */ - krb5_hdr = ptr; - *ptr++ = KG_TOK_MIC_MSG; - *ptr++ = cpu_to_le16(ctx->gk5e->signalg); - *ptr++ = SEAL_ALG_NONE; - *ptr++ = 0xffff; - - return (char *)krb5_hdr; -} - static void * setup_token_v2(struct krb5_ctx *ctx, struct xdr_netobj *token) { - __be16 *ptr, *krb5_hdr; + u16 *ptr; + void *krb5_hdr; u8 *p, flags = 0x00; if ((ctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) @@ -103,74 +84,35 @@ setup_token_v2(struct krb5_ctx *ctx, struct xdr_netobj *token) flags |= 0x04; /* Per rfc 4121, sec 4.2.6.1, there is no header, - * just start the token */ - krb5_hdr = ptr = (__be16 *)token->data; + * just start the token. + */ + krb5_hdr = (u16 *)token->data; + ptr = krb5_hdr; *ptr++ = KG2_TOK_MIC; p = (u8 *)ptr; *p++ = flags; *p++ = 0xff; - ptr = (__be16 *)p; - *ptr++ = 0xffff; + ptr = (u16 *)p; *ptr++ = 0xffff; + *ptr = 0xffff; token->len = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; return krb5_hdr; } -static u32 -gss_get_mic_v1(struct krb5_ctx *ctx, struct xdr_buf *text, - struct xdr_netobj *token) -{ - char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), - .data = cksumdata}; - void *ptr; - s32 now; - u32 seq_send; - u8 *cksumkey; - - dprintk("RPC: %s\n", __func__); - BUG_ON(ctx == NULL); - - now = get_seconds(); - - ptr = setup_token(ctx, token); - - if (ctx->gk5e->keyed_cksum) - cksumkey = ctx->cksum; - else - cksumkey = NULL; - - if (make_checksum(ctx, ptr, 8, text, 0, cksumkey, - KG_USAGE_SIGN, &md5cksum)) - return GSS_S_FAILURE; - - memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); - - spin_lock(&krb5_seq_lock); - seq_send = ctx->seq_send++; - spin_unlock(&krb5_seq_lock); - - if (krb5_make_seq_num(ctx, ctx->seq, ctx->initiate ? 0 : 0xff, - seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8)) - return GSS_S_FAILURE; - - return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; -} - -static u32 -gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, - struct xdr_netobj *token) +u32 +gss_krb5_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token) { - char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - struct xdr_netobj cksumobj = { .len = sizeof(cksumdata), - .data = cksumdata}; + struct crypto_ahash *tfm = ctx->initiate ? + ctx->initiator_sign : ctx->acceptor_sign; + struct xdr_netobj cksumobj = { + .len = ctx->gk5e->cksumlength, + }; + __be64 seq_send_be64; void *krb5_hdr; - s32 now; - u64 seq_send; - u8 *cksumkey; - unsigned int cksum_usage; + time64_t now; dprintk("RPC: %s\n", __func__); @@ -178,46 +120,14 @@ gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, /* Set up the sequence number. Now 64-bits in clear * text and w/o direction indicator */ - spin_lock(&krb5_seq_lock); - seq_send = ctx->seq_send64++; - spin_unlock(&krb5_seq_lock); - *((u64 *)(krb5_hdr + 8)) = cpu_to_be64(seq_send); - - if (ctx->initiate) { - cksumkey = ctx->initiator_sign; - cksum_usage = KG_USAGE_INITIATOR_SIGN; - } else { - cksumkey = ctx->acceptor_sign; - cksum_usage = KG_USAGE_ACCEPTOR_SIGN; - } - - if (make_checksum_v2(ctx, krb5_hdr, GSS_KRB5_TOK_HDR_LEN, - text, 0, cksumkey, cksum_usage, &cksumobj)) - return GSS_S_FAILURE; - - memcpy(krb5_hdr + GSS_KRB5_TOK_HDR_LEN, cksumobj.data, cksumobj.len); + seq_send_be64 = cpu_to_be64(atomic64_fetch_inc(&ctx->seq_send64)); + memcpy(krb5_hdr + 8, (char *) &seq_send_be64, 8); - now = get_seconds(); + cksumobj.data = krb5_hdr + GSS_KRB5_TOK_HDR_LEN; + if (gss_krb5_checksum(tfm, krb5_hdr, GSS_KRB5_TOK_HDR_LEN, + text, 0, &cksumobj)) + return GSS_S_FAILURE; + now = ktime_get_real_seconds(); return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; } - -u32 -gss_get_mic_kerberos(struct gss_ctx *gss_ctx, struct xdr_buf *text, - struct xdr_netobj *token) -{ - struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; - - switch (ctx->enctype) { - default: - BUG(); - case ENCTYPE_DES_CBC_RAW: - case ENCTYPE_DES3_CBC_RAW: - case ENCTYPE_ARCFOUR_HMAC: - return gss_get_mic_v1(ctx, text, token); - case ENCTYPE_AES128_CTS_HMAC_SHA1_96: - case ENCTYPE_AES256_CTS_HMAC_SHA1_96: - return gss_get_mic_v2(ctx, text, token); - } -} - diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c deleted file mode 100644 index 62ac90c62cb1..000000000000 --- a/net/sunrpc/auth_gss/gss_krb5_seqnum.c +++ /dev/null @@ -1,166 +0,0 @@ -/* - * linux/net/sunrpc/gss_krb5_seqnum.c - * - * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/util_seqnum.c - * - * Copyright (c) 2000 The Regents of the University of Michigan. - * All rights reserved. - * - * Andy Adamson <andros@umich.edu> - */ - -/* - * Copyright 1993 by OpenVision Technologies, Inc. - * - * Permission to use, copy, modify, distribute, and sell this software - * and its documentation for any purpose is hereby granted without fee, - * provided that the above copyright notice appears in all copies and - * that both that copyright notice and this permission notice appear in - * supporting documentation, and that the name of OpenVision not be used - * in advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. OpenVision makes no - * representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. - * - * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, - * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO - * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR - * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF - * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR - * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR - * PERFORMANCE OF THIS SOFTWARE. - */ - -#include <linux/types.h> -#include <linux/sunrpc/gss_krb5.h> -#include <linux/crypto.h> - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - -static s32 -krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, - unsigned char *cksum, unsigned char *buf) -{ - struct crypto_blkcipher *cipher; - unsigned char plain[8]; - s32 code; - - dprintk("RPC: %s:\n", __func__); - cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(cipher)) - return PTR_ERR(cipher); - - plain[0] = (unsigned char) ((seqnum >> 24) & 0xff); - plain[1] = (unsigned char) ((seqnum >> 16) & 0xff); - plain[2] = (unsigned char) ((seqnum >> 8) & 0xff); - plain[3] = (unsigned char) ((seqnum >> 0) & 0xff); - plain[4] = direction; - plain[5] = direction; - plain[6] = direction; - plain[7] = direction; - - code = krb5_rc4_setup_seq_key(kctx, cipher, cksum); - if (code) - goto out; - - code = krb5_encrypt(cipher, cksum, plain, buf, 8); -out: - crypto_free_blkcipher(cipher); - return code; -} -s32 -krb5_make_seq_num(struct krb5_ctx *kctx, - struct crypto_blkcipher *key, - int direction, - u32 seqnum, - unsigned char *cksum, unsigned char *buf) -{ - unsigned char plain[8]; - - if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) - return krb5_make_rc4_seq_num(kctx, direction, seqnum, - cksum, buf); - - plain[0] = (unsigned char) (seqnum & 0xff); - plain[1] = (unsigned char) ((seqnum >> 8) & 0xff); - plain[2] = (unsigned char) ((seqnum >> 16) & 0xff); - plain[3] = (unsigned char) ((seqnum >> 24) & 0xff); - - plain[4] = direction; - plain[5] = direction; - plain[6] = direction; - plain[7] = direction; - - return krb5_encrypt(key, cksum, plain, buf, 8); -} - -static s32 -krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, - unsigned char *buf, int *direction, s32 *seqnum) -{ - struct crypto_blkcipher *cipher; - unsigned char plain[8]; - s32 code; - - dprintk("RPC: %s:\n", __func__); - cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(cipher)) - return PTR_ERR(cipher); - - code = krb5_rc4_setup_seq_key(kctx, cipher, cksum); - if (code) - goto out; - - code = krb5_decrypt(cipher, cksum, buf, plain, 8); - if (code) - goto out; - - if ((plain[4] != plain[5]) || (plain[4] != plain[6]) - || (plain[4] != plain[7])) { - code = (s32)KG_BAD_SEQ; - goto out; - } - - *direction = plain[4]; - - *seqnum = ((plain[0] << 24) | (plain[1] << 16) | - (plain[2] << 8) | (plain[3])); -out: - crypto_free_blkcipher(cipher); - return code; -} - -s32 -krb5_get_seq_num(struct krb5_ctx *kctx, - unsigned char *cksum, - unsigned char *buf, - int *direction, u32 *seqnum) -{ - s32 code; - unsigned char plain[8]; - struct crypto_blkcipher *key = kctx->seq; - - dprintk("RPC: krb5_get_seq_num:\n"); - - if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) - return krb5_get_rc4_seq_num(kctx, cksum, buf, - direction, seqnum); - - if ((code = krb5_decrypt(key, cksum, buf, plain, 8))) - return code; - - if ((plain[4] != plain[5]) || (plain[4] != plain[6]) || - (plain[4] != plain[7])) - return (s32)KG_BAD_SEQ; - - *direction = plain[4]; - - *seqnum = ((plain[0]) | - (plain[1] << 8) | (plain[2] << 16) | (plain[3] << 24)); - - return 0; -} diff --git a/net/sunrpc/auth_gss/gss_krb5_test.c b/net/sunrpc/auth_gss/gss_krb5_test.c new file mode 100644 index 000000000000..a5bff02cd7ba --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_test.c @@ -0,0 +1,1859 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2022 Oracle and/or its affiliates. + * + * KUnit test of SunRPC's GSS Kerberos mechanism. Subsystem + * name is "rpcsec_gss_krb5". + */ + +#include <kunit/test.h> +#include <kunit/visibility.h> + +#include <linux/kernel.h> +#include <crypto/hash.h> + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/gss_krb5.h> + +#include "gss_krb5_internal.h" + +MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING"); + +struct gss_krb5_test_param { + const char *desc; + u32 enctype; + u32 nfold; + u32 constant; + const struct xdr_netobj *base_key; + const struct xdr_netobj *Ke; + const struct xdr_netobj *usage; + const struct xdr_netobj *plaintext; + const struct xdr_netobj *confounder; + const struct xdr_netobj *expected_result; + const struct xdr_netobj *expected_hmac; + const struct xdr_netobj *next_iv; +}; + +static inline void gss_krb5_get_desc(const struct gss_krb5_test_param *param, + char *desc) +{ + strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE); +} + +static void kdf_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + const struct gss_krb5_enctype *gk5e; + struct xdr_netobj derivedkey; + int err; + + /* Arrange */ + gk5e = gss_krb5_lookup_enctype(param->enctype); + if (!gk5e) + kunit_skip(test, "Encryption type is not available"); + + derivedkey.data = kunit_kzalloc(test, param->expected_result->len, + GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, derivedkey.data); + derivedkey.len = param->expected_result->len; + + /* Act */ + err = gk5e->derive_key(gk5e, param->base_key, &derivedkey, + param->usage, GFP_KERNEL); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data, + derivedkey.data, derivedkey.len), 0, + "key mismatch"); +} + +static void checksum_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + struct xdr_buf buf = { + .head[0].iov_len = param->plaintext->len, + .len = param->plaintext->len, + }; + const struct gss_krb5_enctype *gk5e; + struct xdr_netobj Kc, checksum; + struct crypto_ahash *tfm; + int err; + + /* Arrange */ + gk5e = gss_krb5_lookup_enctype(param->enctype); + if (!gk5e) + kunit_skip(test, "Encryption type is not available"); + + Kc.len = gk5e->Kc_length; + Kc.data = kunit_kzalloc(test, Kc.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Kc.data); + err = gk5e->derive_key(gk5e, param->base_key, &Kc, + param->usage, GFP_KERNEL); + KUNIT_ASSERT_EQ(test, err, 0); + + tfm = crypto_alloc_ahash(gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tfm); + err = crypto_ahash_setkey(tfm, Kc.data, Kc.len); + KUNIT_ASSERT_EQ(test, err, 0); + + buf.head[0].iov_base = kunit_kzalloc(test, buf.head[0].iov_len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, buf.head[0].iov_base); + memcpy(buf.head[0].iov_base, param->plaintext->data, buf.head[0].iov_len); + + checksum.len = gk5e->cksumlength; + checksum.data = kunit_kzalloc(test, checksum.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, checksum.data); + + /* Act */ + err = gss_krb5_checksum(tfm, NULL, 0, &buf, 0, &checksum); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data, + checksum.data, checksum.len), 0, + "checksum mismatch"); + + crypto_free_ahash(tfm); +} + +#define DEFINE_HEX_XDR_NETOBJ(name, hex_array...) \ + static const u8 name ## _data[] = { hex_array }; \ + static const struct xdr_netobj name = { \ + .data = (u8 *)name##_data, \ + .len = sizeof(name##_data), \ + } + +#define DEFINE_STR_XDR_NETOBJ(name, string) \ + static const u8 name ## _str[] = string; \ + static const struct xdr_netobj name = { \ + .data = (u8 *)name##_str, \ + .len = sizeof(name##_str) - 1, \ + } + +/* + * RFC 3961 Appendix A.1. n-fold + * + * The n-fold function is defined in section 5.1 of RFC 3961. + * + * This test material is copyright (C) The Internet Society (2005). + */ + +DEFINE_HEX_XDR_NETOBJ(nfold_test1_plaintext, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test1_expected_result, + 0xbe, 0x07, 0x26, 0x31, 0x27, 0x6b, 0x19, 0x55 +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test2_plaintext, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test2_expected_result, + 0x78, 0xa0, 0x7b, 0x6c, 0xaf, 0x85, 0xfa +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test3_plaintext, + 0x52, 0x6f, 0x75, 0x67, 0x68, 0x20, 0x43, 0x6f, + 0x6e, 0x73, 0x65, 0x6e, 0x73, 0x75, 0x73, 0x2c, + 0x20, 0x61, 0x6e, 0x64, 0x20, 0x52, 0x75, 0x6e, + 0x6e, 0x69, 0x6e, 0x67, 0x20, 0x43, 0x6f, 0x64, + 0x65 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test3_expected_result, + 0xbb, 0x6e, 0xd3, 0x08, 0x70, 0xb7, 0xf0, 0xe0 +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test4_plaintext, + 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test4_expected_result, + 0x59, 0xe4, 0xa8, 0xca, 0x7c, 0x03, 0x85, 0xc3, + 0xc3, 0x7b, 0x3f, 0x6d, 0x20, 0x00, 0x24, 0x7c, + 0xb6, 0xe6, 0xbd, 0x5b, 0x3e +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test5_plaintext, + 0x4d, 0x41, 0x53, 0x53, 0x41, 0x43, 0x48, 0x56, + 0x53, 0x45, 0x54, 0x54, 0x53, 0x20, 0x49, 0x4e, + 0x53, 0x54, 0x49, 0x54, 0x56, 0x54, 0x45, 0x20, + 0x4f, 0x46, 0x20, 0x54, 0x45, 0x43, 0x48, 0x4e, + 0x4f, 0x4c, 0x4f, 0x47, 0x59 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test5_expected_result, + 0xdb, 0x3b, 0x0d, 0x8f, 0x0b, 0x06, 0x1e, 0x60, + 0x32, 0x82, 0xb3, 0x08, 0xa5, 0x08, 0x41, 0x22, + 0x9a, 0xd7, 0x98, 0xfa, 0xb9, 0x54, 0x0c, 0x1b +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test6_plaintext, + 0x51 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test6_expected_result, + 0x51, 0x8a, 0x54, 0xa2, 0x15, 0xa8, 0x45, 0x2a, + 0x51, 0x8a, 0x54, 0xa2, 0x15, 0xa8, 0x45, 0x2a, + 0x51, 0x8a, 0x54, 0xa2, 0x15 +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test7_plaintext, + 0x62, 0x61 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test7_expected_result, + 0xfb, 0x25, 0xd5, 0x31, 0xae, 0x89, 0x74, 0x49, + 0x9f, 0x52, 0xfd, 0x92, 0xea, 0x98, 0x57, 0xc4, + 0xba, 0x24, 0xcf, 0x29, 0x7e +); + +DEFINE_HEX_XDR_NETOBJ(nfold_test_kerberos, + 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test8_expected_result, + 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test9_expected_result, + 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73, + 0x7b, 0x9b, 0x5b, 0x2b, 0x93, 0x13, 0x2b, 0x93 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test10_expected_result, + 0x83, 0x72, 0xc2, 0x36, 0x34, 0x4e, 0x5f, 0x15, + 0x50, 0xcd, 0x07, 0x47, 0xe1, 0x5d, 0x62, 0xca, + 0x7a, 0x5a, 0x3b, 0xce, 0xa4 +); +DEFINE_HEX_XDR_NETOBJ(nfold_test11_expected_result, + 0x6b, 0x65, 0x72, 0x62, 0x65, 0x72, 0x6f, 0x73, + 0x7b, 0x9b, 0x5b, 0x2b, 0x93, 0x13, 0x2b, 0x93, + 0x5c, 0x9b, 0xdc, 0xda, 0xd9, 0x5c, 0x98, 0x99, + 0xc4, 0xca, 0xe4, 0xde, 0xe6, 0xd6, 0xca, 0xe4 +); + +static const struct gss_krb5_test_param rfc3961_nfold_test_params[] = { + { + .desc = "64-fold(\"012345\")", + .nfold = 64, + .plaintext = &nfold_test1_plaintext, + .expected_result = &nfold_test1_expected_result, + }, + { + .desc = "56-fold(\"password\")", + .nfold = 56, + .plaintext = &nfold_test2_plaintext, + .expected_result = &nfold_test2_expected_result, + }, + { + .desc = "64-fold(\"Rough Consensus, and Running Code\")", + .nfold = 64, + .plaintext = &nfold_test3_plaintext, + .expected_result = &nfold_test3_expected_result, + }, + { + .desc = "168-fold(\"password\")", + .nfold = 168, + .plaintext = &nfold_test4_plaintext, + .expected_result = &nfold_test4_expected_result, + }, + { + .desc = "192-fold(\"MASSACHVSETTS INSTITVTE OF TECHNOLOGY\")", + .nfold = 192, + .plaintext = &nfold_test5_plaintext, + .expected_result = &nfold_test5_expected_result, + }, + { + .desc = "168-fold(\"Q\")", + .nfold = 168, + .plaintext = &nfold_test6_plaintext, + .expected_result = &nfold_test6_expected_result, + }, + { + .desc = "168-fold(\"ba\")", + .nfold = 168, + .plaintext = &nfold_test7_plaintext, + .expected_result = &nfold_test7_expected_result, + }, + { + .desc = "64-fold(\"kerberos\")", + .nfold = 64, + .plaintext = &nfold_test_kerberos, + .expected_result = &nfold_test8_expected_result, + }, + { + .desc = "128-fold(\"kerberos\")", + .nfold = 128, + .plaintext = &nfold_test_kerberos, + .expected_result = &nfold_test9_expected_result, + }, + { + .desc = "168-fold(\"kerberos\")", + .nfold = 168, + .plaintext = &nfold_test_kerberos, + .expected_result = &nfold_test10_expected_result, + }, + { + .desc = "256-fold(\"kerberos\")", + .nfold = 256, + .plaintext = &nfold_test_kerberos, + .expected_result = &nfold_test11_expected_result, + }, +}; + +/* Creates the function rfc3961_nfold_gen_params */ +KUNIT_ARRAY_PARAM(rfc3961_nfold, rfc3961_nfold_test_params, gss_krb5_get_desc); + +static void rfc3961_nfold_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + u8 *result; + + /* Arrange */ + result = kunit_kzalloc(test, 4096, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, result); + + /* Act */ + krb5_nfold(param->plaintext->len * 8, param->plaintext->data, + param->expected_result->len * 8, result); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data, + result, param->expected_result->len), 0, + "result mismatch"); +} + +static struct kunit_case rfc3961_test_cases[] = { + { + .name = "RFC 3961 n-fold", + .run_case = rfc3961_nfold_case, + .generate_params = rfc3961_nfold_gen_params, + }, + {} +}; + +static struct kunit_suite rfc3961_suite = { + .name = "RFC 3961 tests", + .test_cases = rfc3961_test_cases, +}; + +/* + * From RFC 3962 Appendix B: Sample Test Vectors + * + * Some test vectors for CBC with ciphertext stealing, using an + * initial vector of all-zero. + * + * This test material is copyright (C) The Internet Society (2005). + */ + +DEFINE_HEX_XDR_NETOBJ(rfc3962_encryption_key, + 0x63, 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x20, + 0x74, 0x65, 0x72, 0x69, 0x79, 0x61, 0x6b, 0x69 +); + +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test1_plaintext, + 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, + 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65, + 0x20 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test1_expected_result, + 0xc6, 0x35, 0x35, 0x68, 0xf2, 0xbf, 0x8c, 0xb4, + 0xd8, 0xa5, 0x80, 0x36, 0x2d, 0xa7, 0xff, 0x7f, + 0x97 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test1_next_iv, + 0xc6, 0x35, 0x35, 0x68, 0xf2, 0xbf, 0x8c, 0xb4, + 0xd8, 0xa5, 0x80, 0x36, 0x2d, 0xa7, 0xff, 0x7f +); + +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test2_plaintext, + 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, + 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65, + 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, + 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test2_expected_result, + 0xfc, 0x00, 0x78, 0x3e, 0x0e, 0xfd, 0xb2, 0xc1, + 0xd4, 0x45, 0xd4, 0xc8, 0xef, 0xf7, 0xed, 0x22, + 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0, + 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test2_next_iv, + 0xfc, 0x00, 0x78, 0x3e, 0x0e, 0xfd, 0xb2, 0xc1, + 0xd4, 0x45, 0xd4, 0xc8, 0xef, 0xf7, 0xed, 0x22 +); + +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test3_plaintext, + 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, + 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65, + 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, + 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test3_expected_result, + 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5, + 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8, + 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0, + 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test3_next_iv, + 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5, + 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8 +); + +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test4_plaintext, + 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, + 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65, + 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, + 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43, + 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x2c, 0x20, + 0x70, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x2c +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test4_expected_result, + 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0, + 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84, + 0xb3, 0xff, 0xfd, 0x94, 0x0c, 0x16, 0xa1, 0x8c, + 0x1b, 0x55, 0x49, 0xd2, 0xf8, 0x38, 0x02, 0x9e, + 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5, + 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test4_next_iv, + 0xb3, 0xff, 0xfd, 0x94, 0x0c, 0x16, 0xa1, 0x8c, + 0x1b, 0x55, 0x49, 0xd2, 0xf8, 0x38, 0x02, 0x9e +); + +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test5_plaintext, + 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, + 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65, + 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, + 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43, + 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x2c, 0x20, + 0x70, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x2c, 0x20 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test5_expected_result, + 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0, + 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84, + 0x9d, 0xad, 0x8b, 0xbb, 0x96, 0xc4, 0xcd, 0xc0, + 0x3b, 0xc1, 0x03, 0xe1, 0xa1, 0x94, 0xbb, 0xd8, + 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5, + 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test5_next_iv, + 0x9d, 0xad, 0x8b, 0xbb, 0x96, 0xc4, 0xcd, 0xc0, + 0x3b, 0xc1, 0x03, 0xe1, 0xa1, 0x94, 0xbb, 0xd8 +); + +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test6_plaintext, + 0x49, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, + 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x74, 0x68, 0x65, + 0x20, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, + 0x20, 0x47, 0x61, 0x75, 0x27, 0x73, 0x20, 0x43, + 0x68, 0x69, 0x63, 0x6b, 0x65, 0x6e, 0x2c, 0x20, + 0x70, 0x6c, 0x65, 0x61, 0x73, 0x65, 0x2c, 0x20, + 0x61, 0x6e, 0x64, 0x20, 0x77, 0x6f, 0x6e, 0x74, + 0x6f, 0x6e, 0x20, 0x73, 0x6f, 0x75, 0x70, 0x2e +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test6_expected_result, + 0x97, 0x68, 0x72, 0x68, 0xd6, 0xec, 0xcc, 0xc0, + 0xc0, 0x7b, 0x25, 0xe2, 0x5e, 0xcf, 0xe5, 0x84, + 0x39, 0x31, 0x25, 0x23, 0xa7, 0x86, 0x62, 0xd5, + 0xbe, 0x7f, 0xcb, 0xcc, 0x98, 0xeb, 0xf5, 0xa8, + 0x48, 0x07, 0xef, 0xe8, 0x36, 0xee, 0x89, 0xa5, + 0x26, 0x73, 0x0d, 0xbc, 0x2f, 0x7b, 0xc8, 0x40, + 0x9d, 0xad, 0x8b, 0xbb, 0x96, 0xc4, 0xcd, 0xc0, + 0x3b, 0xc1, 0x03, 0xe1, 0xa1, 0x94, 0xbb, 0xd8 +); +DEFINE_HEX_XDR_NETOBJ(rfc3962_enc_test6_next_iv, + 0x48, 0x07, 0xef, 0xe8, 0x36, 0xee, 0x89, 0xa5, + 0x26, 0x73, 0x0d, 0xbc, 0x2f, 0x7b, 0xc8, 0x40 +); + +static const struct gss_krb5_test_param rfc3962_encrypt_test_params[] = { + { + .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 1", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &rfc3962_enc_test1_plaintext, + .expected_result = &rfc3962_enc_test1_expected_result, + .next_iv = &rfc3962_enc_test1_next_iv, + }, + { + .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 2", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &rfc3962_enc_test2_plaintext, + .expected_result = &rfc3962_enc_test2_expected_result, + .next_iv = &rfc3962_enc_test2_next_iv, + }, + { + .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 3", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &rfc3962_enc_test3_plaintext, + .expected_result = &rfc3962_enc_test3_expected_result, + .next_iv = &rfc3962_enc_test3_next_iv, + }, + { + .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 4", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &rfc3962_enc_test4_plaintext, + .expected_result = &rfc3962_enc_test4_expected_result, + .next_iv = &rfc3962_enc_test4_next_iv, + }, + { + .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 5", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &rfc3962_enc_test5_plaintext, + .expected_result = &rfc3962_enc_test5_expected_result, + .next_iv = &rfc3962_enc_test5_next_iv, + }, + { + .desc = "Encrypt with aes128-cts-hmac-sha1-96 case 6", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &rfc3962_enc_test6_plaintext, + .expected_result = &rfc3962_enc_test6_expected_result, + .next_iv = &rfc3962_enc_test6_next_iv, + }, +}; + +/* Creates the function rfc3962_encrypt_gen_params */ +KUNIT_ARRAY_PARAM(rfc3962_encrypt, rfc3962_encrypt_test_params, + gss_krb5_get_desc); + +/* + * This tests the implementation of the encryption part of the mechanism. + * It does not apply a confounder or test the result of HMAC over the + * plaintext. + */ +static void rfc3962_encrypt_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + struct crypto_sync_skcipher *cts_tfm, *cbc_tfm; + const struct gss_krb5_enctype *gk5e; + struct xdr_buf buf; + void *iv, *text; + u32 err; + + /* Arrange */ + gk5e = gss_krb5_lookup_enctype(param->enctype); + if (!gk5e) + kunit_skip(test, "Encryption type is not available"); + + cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm); + err = crypto_sync_skcipher_setkey(cbc_tfm, param->Ke->data, param->Ke->len); + KUNIT_ASSERT_EQ(test, err, 0); + + cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm); + err = crypto_sync_skcipher_setkey(cts_tfm, param->Ke->data, param->Ke->len); + KUNIT_ASSERT_EQ(test, err, 0); + + iv = kunit_kzalloc(test, crypto_sync_skcipher_ivsize(cts_tfm), GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, iv); + + text = kunit_kzalloc(test, param->plaintext->len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text); + + memcpy(text, param->plaintext->data, param->plaintext->len); + memset(&buf, 0, sizeof(buf)); + buf.head[0].iov_base = text; + buf.head[0].iov_len = param->plaintext->len; + buf.len = buf.head[0].iov_len; + + /* Act */ + err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, + iv, crypto_sync_skcipher_ivsize(cts_tfm)); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, + param->expected_result->len, buf.len, + "ciphertext length mismatch"); + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data, + text, param->expected_result->len), 0, + "ciphertext mismatch"); + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->next_iv->data, iv, + param->next_iv->len), 0, + "IV mismatch"); + + crypto_free_sync_skcipher(cts_tfm); + crypto_free_sync_skcipher(cbc_tfm); +} + +static struct kunit_case rfc3962_test_cases[] = { + { + .name = "RFC 3962 encryption", + .run_case = rfc3962_encrypt_case, + .generate_params = rfc3962_encrypt_gen_params, + }, + {} +}; + +static struct kunit_suite rfc3962_suite = { + .name = "RFC 3962 suite", + .test_cases = rfc3962_test_cases, +}; + +/* + * From RFC 6803 Section 10. Test vectors + * + * Sample results for key derivation + * + * Copyright (c) 2012 IETF Trust and the persons identified as the + * document authors. All rights reserved. + */ + +DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_basekey, + 0x57, 0xd0, 0x29, 0x72, 0x98, 0xff, 0xd9, 0xd3, + 0x5d, 0xe5, 0xa4, 0x7f, 0xb4, 0xbd, 0xe2, 0x4b +); +DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_Kc, + 0xd1, 0x55, 0x77, 0x5a, 0x20, 0x9d, 0x05, 0xf0, + 0x2b, 0x38, 0xd4, 0x2a, 0x38, 0x9e, 0x5a, 0x56 +); +DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_Ke, + 0x64, 0xdf, 0x83, 0xf8, 0x5a, 0x53, 0x2f, 0x17, + 0x57, 0x7d, 0x8c, 0x37, 0x03, 0x57, 0x96, 0xab +); +DEFINE_HEX_XDR_NETOBJ(camellia128_cts_cmac_Ki, + 0x3e, 0x4f, 0xbd, 0xf3, 0x0f, 0xb8, 0x25, 0x9c, + 0x42, 0x5c, 0xb6, 0xc9, 0x6f, 0x1f, 0x46, 0x35 +); + +DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_basekey, + 0xb9, 0xd6, 0x82, 0x8b, 0x20, 0x56, 0xb7, 0xbe, + 0x65, 0x6d, 0x88, 0xa1, 0x23, 0xb1, 0xfa, 0xc6, + 0x82, 0x14, 0xac, 0x2b, 0x72, 0x7e, 0xcf, 0x5f, + 0x69, 0xaf, 0xe0, 0xc4, 0xdf, 0x2a, 0x6d, 0x2c +); +DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_Kc, + 0xe4, 0x67, 0xf9, 0xa9, 0x55, 0x2b, 0xc7, 0xd3, + 0x15, 0x5a, 0x62, 0x20, 0xaf, 0x9c, 0x19, 0x22, + 0x0e, 0xee, 0xd4, 0xff, 0x78, 0xb0, 0xd1, 0xe6, + 0xa1, 0x54, 0x49, 0x91, 0x46, 0x1a, 0x9e, 0x50 +); +DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_Ke, + 0x41, 0x2a, 0xef, 0xc3, 0x62, 0xa7, 0x28, 0x5f, + 0xc3, 0x96, 0x6c, 0x6a, 0x51, 0x81, 0xe7, 0x60, + 0x5a, 0xe6, 0x75, 0x23, 0x5b, 0x6d, 0x54, 0x9f, + 0xbf, 0xc9, 0xab, 0x66, 0x30, 0xa4, 0xc6, 0x04 +); +DEFINE_HEX_XDR_NETOBJ(camellia256_cts_cmac_Ki, + 0xfa, 0x62, 0x4f, 0xa0, 0xe5, 0x23, 0x99, 0x3f, + 0xa3, 0x88, 0xae, 0xfd, 0xc6, 0x7e, 0x67, 0xeb, + 0xcd, 0x8c, 0x08, 0xe8, 0xa0, 0x24, 0x6b, 0x1d, + 0x73, 0xb0, 0xd1, 0xdd, 0x9f, 0xc5, 0x82, 0xb0 +); + +DEFINE_HEX_XDR_NETOBJ(usage_checksum, + 0x00, 0x00, 0x00, 0x02, KEY_USAGE_SEED_CHECKSUM +); +DEFINE_HEX_XDR_NETOBJ(usage_encryption, + 0x00, 0x00, 0x00, 0x02, KEY_USAGE_SEED_ENCRYPTION +); +DEFINE_HEX_XDR_NETOBJ(usage_integrity, + 0x00, 0x00, 0x00, 0x02, KEY_USAGE_SEED_INTEGRITY +); + +static const struct gss_krb5_test_param rfc6803_kdf_test_params[] = { + { + .desc = "Derive Kc subkey for camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .base_key = &camellia128_cts_cmac_basekey, + .usage = &usage_checksum, + .expected_result = &camellia128_cts_cmac_Kc, + }, + { + .desc = "Derive Ke subkey for camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .base_key = &camellia128_cts_cmac_basekey, + .usage = &usage_encryption, + .expected_result = &camellia128_cts_cmac_Ke, + }, + { + .desc = "Derive Ki subkey for camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .base_key = &camellia128_cts_cmac_basekey, + .usage = &usage_integrity, + .expected_result = &camellia128_cts_cmac_Ki, + }, + { + .desc = "Derive Kc subkey for camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .base_key = &camellia256_cts_cmac_basekey, + .usage = &usage_checksum, + .expected_result = &camellia256_cts_cmac_Kc, + }, + { + .desc = "Derive Ke subkey for camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .base_key = &camellia256_cts_cmac_basekey, + .usage = &usage_encryption, + .expected_result = &camellia256_cts_cmac_Ke, + }, + { + .desc = "Derive Ki subkey for camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .base_key = &camellia256_cts_cmac_basekey, + .usage = &usage_integrity, + .expected_result = &camellia256_cts_cmac_Ki, + }, +}; + +/* Creates the function rfc6803_kdf_gen_params */ +KUNIT_ARRAY_PARAM(rfc6803_kdf, rfc6803_kdf_test_params, gss_krb5_get_desc); + +/* + * From RFC 6803 Section 10. Test vectors + * + * Sample checksums. + * + * Copyright (c) 2012 IETF Trust and the persons identified as the + * document authors. All rights reserved. + * + * XXX: These tests are likely to fail on EBCDIC or Unicode platforms. + */ +DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test1_plaintext, + "abcdefghijk"); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test1_basekey, + 0x1d, 0xc4, 0x6a, 0x8d, 0x76, 0x3f, 0x4f, 0x93, + 0x74, 0x2b, 0xcb, 0xa3, 0x38, 0x75, 0x76, 0xc3 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test1_usage, + 0x00, 0x00, 0x00, 0x07, KEY_USAGE_SEED_CHECKSUM +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test1_expected_result, + 0x11, 0x78, 0xe6, 0xc5, 0xc4, 0x7a, 0x8c, 0x1a, + 0xe0, 0xc4, 0xb9, 0xc7, 0xd4, 0xeb, 0x7b, 0x6b +); + +DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test2_plaintext, + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test2_basekey, + 0x50, 0x27, 0xbc, 0x23, 0x1d, 0x0f, 0x3a, 0x9d, + 0x23, 0x33, 0x3f, 0x1c, 0xa6, 0xfd, 0xbe, 0x7c +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test2_usage, + 0x00, 0x00, 0x00, 0x08, KEY_USAGE_SEED_CHECKSUM +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test2_expected_result, + 0xd1, 0xb3, 0x4f, 0x70, 0x04, 0xa7, 0x31, 0xf2, + 0x3a, 0x0c, 0x00, 0xbf, 0x6c, 0x3f, 0x75, 0x3a +); + +DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test3_plaintext, + "123456789"); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test3_basekey, + 0xb6, 0x1c, 0x86, 0xcc, 0x4e, 0x5d, 0x27, 0x57, + 0x54, 0x5a, 0xd4, 0x23, 0x39, 0x9f, 0xb7, 0x03, + 0x1e, 0xca, 0xb9, 0x13, 0xcb, 0xb9, 0x00, 0xbd, + 0x7a, 0x3c, 0x6d, 0xd8, 0xbf, 0x92, 0x01, 0x5b +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test3_usage, + 0x00, 0x00, 0x00, 0x09, KEY_USAGE_SEED_CHECKSUM +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test3_expected_result, + 0x87, 0xa1, 0x2c, 0xfd, 0x2b, 0x96, 0x21, 0x48, + 0x10, 0xf0, 0x1c, 0x82, 0x6e, 0x77, 0x44, 0xb1 +); + +DEFINE_STR_XDR_NETOBJ(rfc6803_checksum_test4_plaintext, + "!@#$%^&*()!@#$%^&*()!@#$%^&*()"); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test4_basekey, + 0x32, 0x16, 0x4c, 0x5b, 0x43, 0x4d, 0x1d, 0x15, + 0x38, 0xe4, 0xcf, 0xd9, 0xbe, 0x80, 0x40, 0xfe, + 0x8c, 0x4a, 0xc7, 0xac, 0xc4, 0xb9, 0x3d, 0x33, + 0x14, 0xd2, 0x13, 0x36, 0x68, 0x14, 0x7a, 0x05 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test4_usage, + 0x00, 0x00, 0x00, 0x0a, KEY_USAGE_SEED_CHECKSUM +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_checksum_test4_expected_result, + 0x3f, 0xa0, 0xb4, 0x23, 0x55, 0xe5, 0x2b, 0x18, + 0x91, 0x87, 0x29, 0x4a, 0xa2, 0x52, 0xab, 0x64 +); + +static const struct gss_krb5_test_param rfc6803_checksum_test_params[] = { + { + .desc = "camellia128-cts-cmac checksum test 1", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .base_key = &rfc6803_checksum_test1_basekey, + .usage = &rfc6803_checksum_test1_usage, + .plaintext = &rfc6803_checksum_test1_plaintext, + .expected_result = &rfc6803_checksum_test1_expected_result, + }, + { + .desc = "camellia128-cts-cmac checksum test 2", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .base_key = &rfc6803_checksum_test2_basekey, + .usage = &rfc6803_checksum_test2_usage, + .plaintext = &rfc6803_checksum_test2_plaintext, + .expected_result = &rfc6803_checksum_test2_expected_result, + }, + { + .desc = "camellia256-cts-cmac checksum test 3", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .base_key = &rfc6803_checksum_test3_basekey, + .usage = &rfc6803_checksum_test3_usage, + .plaintext = &rfc6803_checksum_test3_plaintext, + .expected_result = &rfc6803_checksum_test3_expected_result, + }, + { + .desc = "camellia256-cts-cmac checksum test 4", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .base_key = &rfc6803_checksum_test4_basekey, + .usage = &rfc6803_checksum_test4_usage, + .plaintext = &rfc6803_checksum_test4_plaintext, + .expected_result = &rfc6803_checksum_test4_expected_result, + }, +}; + +/* Creates the function rfc6803_checksum_gen_params */ +KUNIT_ARRAY_PARAM(rfc6803_checksum, rfc6803_checksum_test_params, + gss_krb5_get_desc); + +/* + * From RFC 6803 Section 10. Test vectors + * + * Sample encryptions (all using the default cipher state) + * + * Copyright (c) 2012 IETF Trust and the persons identified as the + * document authors. All rights reserved. + * + * Key usage values are from errata 4326 against RFC 6803. + */ + +static const struct xdr_netobj rfc6803_enc_empty_plaintext = { + .len = 0, +}; + +DEFINE_STR_XDR_NETOBJ(rfc6803_enc_1byte_plaintext, "1"); +DEFINE_STR_XDR_NETOBJ(rfc6803_enc_9byte_plaintext, "9 bytesss"); +DEFINE_STR_XDR_NETOBJ(rfc6803_enc_13byte_plaintext, "13 bytes byte"); +DEFINE_STR_XDR_NETOBJ(rfc6803_enc_30byte_plaintext, + "30 bytes bytes bytes bytes byt" +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test1_confounder, + 0xb6, 0x98, 0x22, 0xa1, 0x9a, 0x6b, 0x09, 0xc0, + 0xeb, 0xc8, 0x55, 0x7d, 0x1f, 0x1b, 0x6c, 0x0a +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test1_basekey, + 0x1d, 0xc4, 0x6a, 0x8d, 0x76, 0x3f, 0x4f, 0x93, + 0x74, 0x2b, 0xcb, 0xa3, 0x38, 0x75, 0x76, 0xc3 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test1_expected_result, + 0xc4, 0x66, 0xf1, 0x87, 0x10, 0x69, 0x92, 0x1e, + 0xdb, 0x7c, 0x6f, 0xde, 0x24, 0x4a, 0x52, 0xdb, + 0x0b, 0xa1, 0x0e, 0xdc, 0x19, 0x7b, 0xdb, 0x80, + 0x06, 0x65, 0x8c, 0xa3, 0xcc, 0xce, 0x6e, 0xb8 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test2_confounder, + 0x6f, 0x2f, 0xc3, 0xc2, 0xa1, 0x66, 0xfd, 0x88, + 0x98, 0x96, 0x7a, 0x83, 0xde, 0x95, 0x96, 0xd9 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test2_basekey, + 0x50, 0x27, 0xbc, 0x23, 0x1d, 0x0f, 0x3a, 0x9d, + 0x23, 0x33, 0x3f, 0x1c, 0xa6, 0xfd, 0xbe, 0x7c +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test2_expected_result, + 0x84, 0x2d, 0x21, 0xfd, 0x95, 0x03, 0x11, 0xc0, + 0xdd, 0x46, 0x4a, 0x3f, 0x4b, 0xe8, 0xd6, 0xda, + 0x88, 0xa5, 0x6d, 0x55, 0x9c, 0x9b, 0x47, 0xd3, + 0xf9, 0xa8, 0x50, 0x67, 0xaf, 0x66, 0x15, 0x59, + 0xb8 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test3_confounder, + 0xa5, 0xb4, 0xa7, 0x1e, 0x07, 0x7a, 0xee, 0xf9, + 0x3c, 0x87, 0x63, 0xc1, 0x8f, 0xdb, 0x1f, 0x10 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test3_basekey, + 0xa1, 0xbb, 0x61, 0xe8, 0x05, 0xf9, 0xba, 0x6d, + 0xde, 0x8f, 0xdb, 0xdd, 0xc0, 0x5c, 0xde, 0xa0 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test3_expected_result, + 0x61, 0x9f, 0xf0, 0x72, 0xe3, 0x62, 0x86, 0xff, + 0x0a, 0x28, 0xde, 0xb3, 0xa3, 0x52, 0xec, 0x0d, + 0x0e, 0xdf, 0x5c, 0x51, 0x60, 0xd6, 0x63, 0xc9, + 0x01, 0x75, 0x8c, 0xcf, 0x9d, 0x1e, 0xd3, 0x3d, + 0x71, 0xdb, 0x8f, 0x23, 0xaa, 0xbf, 0x83, 0x48, + 0xa0 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test4_confounder, + 0x19, 0xfe, 0xe4, 0x0d, 0x81, 0x0c, 0x52, 0x4b, + 0x5b, 0x22, 0xf0, 0x18, 0x74, 0xc6, 0x93, 0xda +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test4_basekey, + 0x2c, 0xa2, 0x7a, 0x5f, 0xaf, 0x55, 0x32, 0x24, + 0x45, 0x06, 0x43, 0x4e, 0x1c, 0xef, 0x66, 0x76 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test4_expected_result, + 0xb8, 0xec, 0xa3, 0x16, 0x7a, 0xe6, 0x31, 0x55, + 0x12, 0xe5, 0x9f, 0x98, 0xa7, 0xc5, 0x00, 0x20, + 0x5e, 0x5f, 0x63, 0xff, 0x3b, 0xb3, 0x89, 0xaf, + 0x1c, 0x41, 0xa2, 0x1d, 0x64, 0x0d, 0x86, 0x15, + 0xc9, 0xed, 0x3f, 0xbe, 0xb0, 0x5a, 0xb6, 0xac, + 0xb6, 0x76, 0x89, 0xb5, 0xea +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test5_confounder, + 0xca, 0x7a, 0x7a, 0xb4, 0xbe, 0x19, 0x2d, 0xab, + 0xd6, 0x03, 0x50, 0x6d, 0xb1, 0x9c, 0x39, 0xe2 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test5_basekey, + 0x78, 0x24, 0xf8, 0xc1, 0x6f, 0x83, 0xff, 0x35, + 0x4c, 0x6b, 0xf7, 0x51, 0x5b, 0x97, 0x3f, 0x43 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test5_expected_result, + 0xa2, 0x6a, 0x39, 0x05, 0xa4, 0xff, 0xd5, 0x81, + 0x6b, 0x7b, 0x1e, 0x27, 0x38, 0x0d, 0x08, 0x09, + 0x0c, 0x8e, 0xc1, 0xf3, 0x04, 0x49, 0x6e, 0x1a, + 0xbd, 0xcd, 0x2b, 0xdc, 0xd1, 0xdf, 0xfc, 0x66, + 0x09, 0x89, 0xe1, 0x17, 0xa7, 0x13, 0xdd, 0xbb, + 0x57, 0xa4, 0x14, 0x6c, 0x15, 0x87, 0xcb, 0xa4, + 0x35, 0x66, 0x65, 0x59, 0x1d, 0x22, 0x40, 0x28, + 0x2f, 0x58, 0x42, 0xb1, 0x05, 0xa5 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test6_confounder, + 0x3c, 0xbb, 0xd2, 0xb4, 0x59, 0x17, 0x94, 0x10, + 0x67, 0xf9, 0x65, 0x99, 0xbb, 0x98, 0x92, 0x6c +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test6_basekey, + 0xb6, 0x1c, 0x86, 0xcc, 0x4e, 0x5d, 0x27, 0x57, + 0x54, 0x5a, 0xd4, 0x23, 0x39, 0x9f, 0xb7, 0x03, + 0x1e, 0xca, 0xb9, 0x13, 0xcb, 0xb9, 0x00, 0xbd, + 0x7a, 0x3c, 0x6d, 0xd8, 0xbf, 0x92, 0x01, 0x5b +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test6_expected_result, + 0x03, 0x88, 0x6d, 0x03, 0x31, 0x0b, 0x47, 0xa6, + 0xd8, 0xf0, 0x6d, 0x7b, 0x94, 0xd1, 0xdd, 0x83, + 0x7e, 0xcc, 0xe3, 0x15, 0xef, 0x65, 0x2a, 0xff, + 0x62, 0x08, 0x59, 0xd9, 0x4a, 0x25, 0x92, 0x66 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test7_confounder, + 0xde, 0xf4, 0x87, 0xfc, 0xeb, 0xe6, 0xde, 0x63, + 0x46, 0xd4, 0xda, 0x45, 0x21, 0xbb, 0xa2, 0xd2 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test7_basekey, + 0x1b, 0x97, 0xfe, 0x0a, 0x19, 0x0e, 0x20, 0x21, + 0xeb, 0x30, 0x75, 0x3e, 0x1b, 0x6e, 0x1e, 0x77, + 0xb0, 0x75, 0x4b, 0x1d, 0x68, 0x46, 0x10, 0x35, + 0x58, 0x64, 0x10, 0x49, 0x63, 0x46, 0x38, 0x33 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test7_expected_result, + 0x2c, 0x9c, 0x15, 0x70, 0x13, 0x3c, 0x99, 0xbf, + 0x6a, 0x34, 0xbc, 0x1b, 0x02, 0x12, 0x00, 0x2f, + 0xd1, 0x94, 0x33, 0x87, 0x49, 0xdb, 0x41, 0x35, + 0x49, 0x7a, 0x34, 0x7c, 0xfc, 0xd9, 0xd1, 0x8a, + 0x12 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test8_confounder, + 0xad, 0x4f, 0xf9, 0x04, 0xd3, 0x4e, 0x55, 0x53, + 0x84, 0xb1, 0x41, 0x00, 0xfc, 0x46, 0x5f, 0x88 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test8_basekey, + 0x32, 0x16, 0x4c, 0x5b, 0x43, 0x4d, 0x1d, 0x15, + 0x38, 0xe4, 0xcf, 0xd9, 0xbe, 0x80, 0x40, 0xfe, + 0x8c, 0x4a, 0xc7, 0xac, 0xc4, 0xb9, 0x3d, 0x33, + 0x14, 0xd2, 0x13, 0x36, 0x68, 0x14, 0x7a, 0x05 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test8_expected_result, + 0x9c, 0x6d, 0xe7, 0x5f, 0x81, 0x2d, 0xe7, 0xed, + 0x0d, 0x28, 0xb2, 0x96, 0x35, 0x57, 0xa1, 0x15, + 0x64, 0x09, 0x98, 0x27, 0x5b, 0x0a, 0xf5, 0x15, + 0x27, 0x09, 0x91, 0x3f, 0xf5, 0x2a, 0x2a, 0x9c, + 0x8e, 0x63, 0xb8, 0x72, 0xf9, 0x2e, 0x64, 0xc8, + 0x39 +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test9_confounder, + 0xcf, 0x9b, 0xca, 0x6d, 0xf1, 0x14, 0x4e, 0x0c, + 0x0a, 0xf9, 0xb8, 0xf3, 0x4c, 0x90, 0xd5, 0x14 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test9_basekey, + 0xb0, 0x38, 0xb1, 0x32, 0xcd, 0x8e, 0x06, 0x61, + 0x22, 0x67, 0xfa, 0xb7, 0x17, 0x00, 0x66, 0xd8, + 0x8a, 0xec, 0xcb, 0xa0, 0xb7, 0x44, 0xbf, 0xc6, + 0x0d, 0xc8, 0x9b, 0xca, 0x18, 0x2d, 0x07, 0x15 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test9_expected_result, + 0xee, 0xec, 0x85, 0xa9, 0x81, 0x3c, 0xdc, 0x53, + 0x67, 0x72, 0xab, 0x9b, 0x42, 0xde, 0xfc, 0x57, + 0x06, 0xf7, 0x26, 0xe9, 0x75, 0xdd, 0xe0, 0x5a, + 0x87, 0xeb, 0x54, 0x06, 0xea, 0x32, 0x4c, 0xa1, + 0x85, 0xc9, 0x98, 0x6b, 0x42, 0xaa, 0xbe, 0x79, + 0x4b, 0x84, 0x82, 0x1b, 0xee +); + +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test10_confounder, + 0x64, 0x4d, 0xef, 0x38, 0xda, 0x35, 0x00, 0x72, + 0x75, 0x87, 0x8d, 0x21, 0x68, 0x55, 0xe2, 0x28 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test10_basekey, + 0xcc, 0xfc, 0xd3, 0x49, 0xbf, 0x4c, 0x66, 0x77, + 0xe8, 0x6e, 0x4b, 0x02, 0xb8, 0xea, 0xb9, 0x24, + 0xa5, 0x46, 0xac, 0x73, 0x1c, 0xf9, 0xbf, 0x69, + 0x89, 0xb9, 0x96, 0xe7, 0xd6, 0xbf, 0xbb, 0xa7 +); +DEFINE_HEX_XDR_NETOBJ(rfc6803_enc_test10_expected_result, + 0x0e, 0x44, 0x68, 0x09, 0x85, 0x85, 0x5f, 0x2d, + 0x1f, 0x18, 0x12, 0x52, 0x9c, 0xa8, 0x3b, 0xfd, + 0x8e, 0x34, 0x9d, 0xe6, 0xfd, 0x9a, 0xda, 0x0b, + 0xaa, 0xa0, 0x48, 0xd6, 0x8e, 0x26, 0x5f, 0xeb, + 0xf3, 0x4a, 0xd1, 0x25, 0x5a, 0x34, 0x49, 0x99, + 0xad, 0x37, 0x14, 0x68, 0x87, 0xa6, 0xc6, 0x84, + 0x57, 0x31, 0xac, 0x7f, 0x46, 0x37, 0x6a, 0x05, + 0x04, 0xcd, 0x06, 0x57, 0x14, 0x74 +); + +static const struct gss_krb5_test_param rfc6803_encrypt_test_params[] = { + { + .desc = "Encrypt empty plaintext with camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .constant = 0, + .base_key = &rfc6803_enc_test1_basekey, + .plaintext = &rfc6803_enc_empty_plaintext, + .confounder = &rfc6803_enc_test1_confounder, + .expected_result = &rfc6803_enc_test1_expected_result, + }, + { + .desc = "Encrypt 1 byte with camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .constant = 1, + .base_key = &rfc6803_enc_test2_basekey, + .plaintext = &rfc6803_enc_1byte_plaintext, + .confounder = &rfc6803_enc_test2_confounder, + .expected_result = &rfc6803_enc_test2_expected_result, + }, + { + .desc = "Encrypt 9 bytes with camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .constant = 2, + .base_key = &rfc6803_enc_test3_basekey, + .plaintext = &rfc6803_enc_9byte_plaintext, + .confounder = &rfc6803_enc_test3_confounder, + .expected_result = &rfc6803_enc_test3_expected_result, + }, + { + .desc = "Encrypt 13 bytes with camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .constant = 3, + .base_key = &rfc6803_enc_test4_basekey, + .plaintext = &rfc6803_enc_13byte_plaintext, + .confounder = &rfc6803_enc_test4_confounder, + .expected_result = &rfc6803_enc_test4_expected_result, + }, + { + .desc = "Encrypt 30 bytes with camellia128-cts-cmac", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .constant = 4, + .base_key = &rfc6803_enc_test5_basekey, + .plaintext = &rfc6803_enc_30byte_plaintext, + .confounder = &rfc6803_enc_test5_confounder, + .expected_result = &rfc6803_enc_test5_expected_result, + }, + { + .desc = "Encrypt empty plaintext with camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .constant = 0, + .base_key = &rfc6803_enc_test6_basekey, + .plaintext = &rfc6803_enc_empty_plaintext, + .confounder = &rfc6803_enc_test6_confounder, + .expected_result = &rfc6803_enc_test6_expected_result, + }, + { + .desc = "Encrypt 1 byte with camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .constant = 1, + .base_key = &rfc6803_enc_test7_basekey, + .plaintext = &rfc6803_enc_1byte_plaintext, + .confounder = &rfc6803_enc_test7_confounder, + .expected_result = &rfc6803_enc_test7_expected_result, + }, + { + .desc = "Encrypt 9 bytes with camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .constant = 2, + .base_key = &rfc6803_enc_test8_basekey, + .plaintext = &rfc6803_enc_9byte_plaintext, + .confounder = &rfc6803_enc_test8_confounder, + .expected_result = &rfc6803_enc_test8_expected_result, + }, + { + .desc = "Encrypt 13 bytes with camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .constant = 3, + .base_key = &rfc6803_enc_test9_basekey, + .plaintext = &rfc6803_enc_13byte_plaintext, + .confounder = &rfc6803_enc_test9_confounder, + .expected_result = &rfc6803_enc_test9_expected_result, + }, + { + .desc = "Encrypt 30 bytes with camellia256-cts-cmac", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .constant = 4, + .base_key = &rfc6803_enc_test10_basekey, + .plaintext = &rfc6803_enc_30byte_plaintext, + .confounder = &rfc6803_enc_test10_confounder, + .expected_result = &rfc6803_enc_test10_expected_result, + }, +}; + +/* Creates the function rfc6803_encrypt_gen_params */ +KUNIT_ARRAY_PARAM(rfc6803_encrypt, rfc6803_encrypt_test_params, + gss_krb5_get_desc); + +static void rfc6803_encrypt_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + struct crypto_sync_skcipher *cts_tfm, *cbc_tfm; + const struct gss_krb5_enctype *gk5e; + struct xdr_netobj Ke, Ki, checksum; + u8 usage_data[GSS_KRB5_K5CLENGTH]; + struct xdr_netobj usage = { + .data = usage_data, + .len = sizeof(usage_data), + }; + struct crypto_ahash *ahash_tfm; + unsigned int blocksize; + struct xdr_buf buf; + void *text; + size_t len; + u32 err; + + /* Arrange */ + gk5e = gss_krb5_lookup_enctype(param->enctype); + if (!gk5e) + kunit_skip(test, "Encryption type is not available"); + + memset(usage_data, 0, sizeof(usage_data)); + usage.data[3] = param->constant; + + Ke.len = gk5e->Ke_length; + Ke.data = kunit_kzalloc(test, Ke.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ke.data); + usage.data[4] = KEY_USAGE_SEED_ENCRYPTION; + err = gk5e->derive_key(gk5e, param->base_key, &Ke, &usage, GFP_KERNEL); + KUNIT_ASSERT_EQ(test, err, 0); + + cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm); + err = crypto_sync_skcipher_setkey(cbc_tfm, Ke.data, Ke.len); + KUNIT_ASSERT_EQ(test, err, 0); + + cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm); + err = crypto_sync_skcipher_setkey(cts_tfm, Ke.data, Ke.len); + KUNIT_ASSERT_EQ(test, err, 0); + blocksize = crypto_sync_skcipher_blocksize(cts_tfm); + + len = param->confounder->len + param->plaintext->len + blocksize; + text = kunit_kzalloc(test, len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text); + memcpy(text, param->confounder->data, param->confounder->len); + memcpy(text + param->confounder->len, param->plaintext->data, + param->plaintext->len); + + memset(&buf, 0, sizeof(buf)); + buf.head[0].iov_base = text; + buf.head[0].iov_len = param->confounder->len + param->plaintext->len; + buf.len = buf.head[0].iov_len; + + checksum.len = gk5e->cksumlength; + checksum.data = kunit_kzalloc(test, checksum.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, checksum.data); + + Ki.len = gk5e->Ki_length; + Ki.data = kunit_kzalloc(test, Ki.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ki.data); + usage.data[4] = KEY_USAGE_SEED_INTEGRITY; + err = gk5e->derive_key(gk5e, param->base_key, &Ki, + &usage, GFP_KERNEL); + KUNIT_ASSERT_EQ(test, err, 0); + ahash_tfm = crypto_alloc_ahash(gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, ahash_tfm); + err = crypto_ahash_setkey(ahash_tfm, Ki.data, Ki.len); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Act */ + err = gss_krb5_checksum(ahash_tfm, NULL, 0, &buf, 0, &checksum); + KUNIT_ASSERT_EQ(test, err, 0); + + err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, NULL, 0); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, param->expected_result->len, + buf.len + checksum.len, + "ciphertext length mismatch"); + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data, + buf.head[0].iov_base, buf.len), 0, + "encrypted result mismatch"); + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data + + (param->expected_result->len - checksum.len), + checksum.data, checksum.len), 0, + "HMAC mismatch"); + + crypto_free_ahash(ahash_tfm); + crypto_free_sync_skcipher(cts_tfm); + crypto_free_sync_skcipher(cbc_tfm); +} + +static struct kunit_case rfc6803_test_cases[] = { + { + .name = "RFC 6803 key derivation", + .run_case = kdf_case, + .generate_params = rfc6803_kdf_gen_params, + }, + { + .name = "RFC 6803 checksum", + .run_case = checksum_case, + .generate_params = rfc6803_checksum_gen_params, + }, + { + .name = "RFC 6803 encryption", + .run_case = rfc6803_encrypt_case, + .generate_params = rfc6803_encrypt_gen_params, + }, + {} +}; + +static struct kunit_suite rfc6803_suite = { + .name = "RFC 6803 suite", + .test_cases = rfc6803_test_cases, +}; + +/* + * From RFC 8009 Appendix A. Test Vectors + * + * Sample results for SHA-2 enctype key derivation + * + * This test material is copyright (c) 2016 IETF Trust and the + * persons identified as the document authors. All rights reserved. + */ + +DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_basekey, + 0x37, 0x05, 0xd9, 0x60, 0x80, 0xc1, 0x77, 0x28, + 0xa0, 0xe8, 0x00, 0xea, 0xb6, 0xe0, 0xd2, 0x3c +); +DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_Kc, + 0xb3, 0x1a, 0x01, 0x8a, 0x48, 0xf5, 0x47, 0x76, + 0xf4, 0x03, 0xe9, 0xa3, 0x96, 0x32, 0x5d, 0xc3 +); +DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_Ke, + 0x9b, 0x19, 0x7d, 0xd1, 0xe8, 0xc5, 0x60, 0x9d, + 0x6e, 0x67, 0xc3, 0xe3, 0x7c, 0x62, 0xc7, 0x2e +); +DEFINE_HEX_XDR_NETOBJ(aes128_cts_hmac_sha256_128_Ki, + 0x9f, 0xda, 0x0e, 0x56, 0xab, 0x2d, 0x85, 0xe1, + 0x56, 0x9a, 0x68, 0x86, 0x96, 0xc2, 0x6a, 0x6c +); + +DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_basekey, + 0x6d, 0x40, 0x4d, 0x37, 0xfa, 0xf7, 0x9f, 0x9d, + 0xf0, 0xd3, 0x35, 0x68, 0xd3, 0x20, 0x66, 0x98, + 0x00, 0xeb, 0x48, 0x36, 0x47, 0x2e, 0xa8, 0xa0, + 0x26, 0xd1, 0x6b, 0x71, 0x82, 0x46, 0x0c, 0x52 +); +DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_Kc, + 0xef, 0x57, 0x18, 0xbe, 0x86, 0xcc, 0x84, 0x96, + 0x3d, 0x8b, 0xbb, 0x50, 0x31, 0xe9, 0xf5, 0xc4, + 0xba, 0x41, 0xf2, 0x8f, 0xaf, 0x69, 0xe7, 0x3d +); +DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_Ke, + 0x56, 0xab, 0x22, 0xbe, 0xe6, 0x3d, 0x82, 0xd7, + 0xbc, 0x52, 0x27, 0xf6, 0x77, 0x3f, 0x8e, 0xa7, + 0xa5, 0xeb, 0x1c, 0x82, 0x51, 0x60, 0xc3, 0x83, + 0x12, 0x98, 0x0c, 0x44, 0x2e, 0x5c, 0x7e, 0x49 +); +DEFINE_HEX_XDR_NETOBJ(aes256_cts_hmac_sha384_192_Ki, + 0x69, 0xb1, 0x65, 0x14, 0xe3, 0xcd, 0x8e, 0x56, + 0xb8, 0x20, 0x10, 0xd5, 0xc7, 0x30, 0x12, 0xb6, + 0x22, 0xc4, 0xd0, 0x0f, 0xfc, 0x23, 0xed, 0x1f +); + +static const struct gss_krb5_test_param rfc8009_kdf_test_params[] = { + { + .desc = "Derive Kc subkey for aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .usage = &usage_checksum, + .expected_result = &aes128_cts_hmac_sha256_128_Kc, + }, + { + .desc = "Derive Ke subkey for aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .usage = &usage_encryption, + .expected_result = &aes128_cts_hmac_sha256_128_Ke, + }, + { + .desc = "Derive Ki subkey for aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .usage = &usage_integrity, + .expected_result = &aes128_cts_hmac_sha256_128_Ki, + }, + { + .desc = "Derive Kc subkey for aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .usage = &usage_checksum, + .expected_result = &aes256_cts_hmac_sha384_192_Kc, + }, + { + .desc = "Derive Ke subkey for aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .usage = &usage_encryption, + .expected_result = &aes256_cts_hmac_sha384_192_Ke, + }, + { + .desc = "Derive Ki subkey for aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .usage = &usage_integrity, + .expected_result = &aes256_cts_hmac_sha384_192_Ki, + }, +}; + +/* Creates the function rfc8009_kdf_gen_params */ +KUNIT_ARRAY_PARAM(rfc8009_kdf, rfc8009_kdf_test_params, gss_krb5_get_desc); + +/* + * From RFC 8009 Appendix A. Test Vectors + * + * These sample checksums use the above sample key derivation results, + * including use of the same base-key and key usage values. + * + * This test material is copyright (c) 2016 IETF Trust and the + * persons identified as the document authors. All rights reserved. + */ + +DEFINE_HEX_XDR_NETOBJ(rfc8009_checksum_plaintext, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_checksum_test1_expected_result, + 0xd7, 0x83, 0x67, 0x18, 0x66, 0x43, 0xd6, 0x7b, + 0x41, 0x1c, 0xba, 0x91, 0x39, 0xfc, 0x1d, 0xee +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_checksum_test2_expected_result, + 0x45, 0xee, 0x79, 0x15, 0x67, 0xee, 0xfc, 0xa3, + 0x7f, 0x4a, 0xc1, 0xe0, 0x22, 0x2d, 0xe8, 0x0d, + 0x43, 0xc3, 0xbf, 0xa0, 0x66, 0x99, 0x67, 0x2a +); + +static const struct gss_krb5_test_param rfc8009_checksum_test_params[] = { + { + .desc = "Checksum with aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .usage = &usage_checksum, + .plaintext = &rfc8009_checksum_plaintext, + .expected_result = &rfc8009_checksum_test1_expected_result, + }, + { + .desc = "Checksum with aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .usage = &usage_checksum, + .plaintext = &rfc8009_checksum_plaintext, + .expected_result = &rfc8009_checksum_test2_expected_result, + }, +}; + +/* Creates the function rfc8009_checksum_gen_params */ +KUNIT_ARRAY_PARAM(rfc8009_checksum, rfc8009_checksum_test_params, + gss_krb5_get_desc); + +/* + * From RFC 8009 Appendix A. Test Vectors + * + * Sample encryptions (all using the default cipher state): + * -------------------------------------------------------- + * + * These sample encryptions use the above sample key derivation results, + * including use of the same base-key and key usage values. + * + * This test material is copyright (c) 2016 IETF Trust and the + * persons identified as the document authors. All rights reserved. + */ + +static const struct xdr_netobj rfc8009_enc_empty_plaintext = { + .len = 0, +}; +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_short_plaintext, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_block_plaintext, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_long_plaintext, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14 +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test1_confounder, + 0x7e, 0x58, 0x95, 0xea, 0xf2, 0x67, 0x24, 0x35, + 0xba, 0xd8, 0x17, 0xf5, 0x45, 0xa3, 0x71, 0x48 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test1_expected_result, + 0xef, 0x85, 0xfb, 0x89, 0x0b, 0xb8, 0x47, 0x2f, + 0x4d, 0xab, 0x20, 0x39, 0x4d, 0xca, 0x78, 0x1d +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test1_expected_hmac, + 0xad, 0x87, 0x7e, 0xda, 0x39, 0xd5, 0x0c, 0x87, + 0x0c, 0x0d, 0x5a, 0x0a, 0x8e, 0x48, 0xc7, 0x18 +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test2_confounder, + 0x7b, 0xca, 0x28, 0x5e, 0x2f, 0xd4, 0x13, 0x0f, + 0xb5, 0x5b, 0x1a, 0x5c, 0x83, 0xbc, 0x5b, 0x24 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test2_expected_result, + 0x84, 0xd7, 0xf3, 0x07, 0x54, 0xed, 0x98, 0x7b, + 0xab, 0x0b, 0xf3, 0x50, 0x6b, 0xeb, 0x09, 0xcf, + 0xb5, 0x54, 0x02, 0xce, 0xf7, 0xe6 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test2_expected_hmac, + 0x87, 0x7c, 0xe9, 0x9e, 0x24, 0x7e, 0x52, 0xd1, + 0x6e, 0xd4, 0x42, 0x1d, 0xfd, 0xf8, 0x97, 0x6c +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test3_confounder, + 0x56, 0xab, 0x21, 0x71, 0x3f, 0xf6, 0x2c, 0x0a, + 0x14, 0x57, 0x20, 0x0f, 0x6f, 0xa9, 0x94, 0x8f +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test3_expected_result, + 0x35, 0x17, 0xd6, 0x40, 0xf5, 0x0d, 0xdc, 0x8a, + 0xd3, 0x62, 0x87, 0x22, 0xb3, 0x56, 0x9d, 0x2a, + 0xe0, 0x74, 0x93, 0xfa, 0x82, 0x63, 0x25, 0x40, + 0x80, 0xea, 0x65, 0xc1, 0x00, 0x8e, 0x8f, 0xc2 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test3_expected_hmac, + 0x95, 0xfb, 0x48, 0x52, 0xe7, 0xd8, 0x3e, 0x1e, + 0x7c, 0x48, 0xc3, 0x7e, 0xeb, 0xe6, 0xb0, 0xd3 +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test4_confounder, + 0xa7, 0xa4, 0xe2, 0x9a, 0x47, 0x28, 0xce, 0x10, + 0x66, 0x4f, 0xb6, 0x4e, 0x49, 0xad, 0x3f, 0xac +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test4_expected_result, + 0x72, 0x0f, 0x73, 0xb1, 0x8d, 0x98, 0x59, 0xcd, + 0x6c, 0xcb, 0x43, 0x46, 0x11, 0x5c, 0xd3, 0x36, + 0xc7, 0x0f, 0x58, 0xed, 0xc0, 0xc4, 0x43, 0x7c, + 0x55, 0x73, 0x54, 0x4c, 0x31, 0xc8, 0x13, 0xbc, + 0xe1, 0xe6, 0xd0, 0x72, 0xc1 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test4_expected_hmac, + 0x86, 0xb3, 0x9a, 0x41, 0x3c, 0x2f, 0x92, 0xca, + 0x9b, 0x83, 0x34, 0xa2, 0x87, 0xff, 0xcb, 0xfc +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test5_confounder, + 0xf7, 0x64, 0xe9, 0xfa, 0x15, 0xc2, 0x76, 0x47, + 0x8b, 0x2c, 0x7d, 0x0c, 0x4e, 0x5f, 0x58, 0xe4 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test5_expected_result, + 0x41, 0xf5, 0x3f, 0xa5, 0xbf, 0xe7, 0x02, 0x6d, + 0x91, 0xfa, 0xf9, 0xbe, 0x95, 0x91, 0x95, 0xa0 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test5_expected_hmac, + 0x58, 0x70, 0x72, 0x73, 0xa9, 0x6a, 0x40, 0xf0, + 0xa0, 0x19, 0x60, 0x62, 0x1a, 0xc6, 0x12, 0x74, + 0x8b, 0x9b, 0xbf, 0xbe, 0x7e, 0xb4, 0xce, 0x3c +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test6_confounder, + 0xb8, 0x0d, 0x32, 0x51, 0xc1, 0xf6, 0x47, 0x14, + 0x94, 0x25, 0x6f, 0xfe, 0x71, 0x2d, 0x0b, 0x9a +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test6_expected_result, + 0x4e, 0xd7, 0xb3, 0x7c, 0x2b, 0xca, 0xc8, 0xf7, + 0x4f, 0x23, 0xc1, 0xcf, 0x07, 0xe6, 0x2b, 0xc7, + 0xb7, 0x5f, 0xb3, 0xf6, 0x37, 0xb9 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test6_expected_hmac, + 0xf5, 0x59, 0xc7, 0xf6, 0x64, 0xf6, 0x9e, 0xab, + 0x7b, 0x60, 0x92, 0x23, 0x75, 0x26, 0xea, 0x0d, + 0x1f, 0x61, 0xcb, 0x20, 0xd6, 0x9d, 0x10, 0xf2 +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test7_confounder, + 0x53, 0xbf, 0x8a, 0x0d, 0x10, 0x52, 0x65, 0xd4, + 0xe2, 0x76, 0x42, 0x86, 0x24, 0xce, 0x5e, 0x63 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test7_expected_result, + 0xbc, 0x47, 0xff, 0xec, 0x79, 0x98, 0xeb, 0x91, + 0xe8, 0x11, 0x5c, 0xf8, 0xd1, 0x9d, 0xac, 0x4b, + 0xbb, 0xe2, 0xe1, 0x63, 0xe8, 0x7d, 0xd3, 0x7f, + 0x49, 0xbe, 0xca, 0x92, 0x02, 0x77, 0x64, 0xf6 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test7_expected_hmac, + 0x8c, 0xf5, 0x1f, 0x14, 0xd7, 0x98, 0xc2, 0x27, + 0x3f, 0x35, 0xdf, 0x57, 0x4d, 0x1f, 0x93, 0x2e, + 0x40, 0xc4, 0xff, 0x25, 0x5b, 0x36, 0xa2, 0x66 +); + +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test8_confounder, + 0x76, 0x3e, 0x65, 0x36, 0x7e, 0x86, 0x4f, 0x02, + 0xf5, 0x51, 0x53, 0xc7, 0xe3, 0xb5, 0x8a, 0xf1 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test8_expected_result, + 0x40, 0x01, 0x3e, 0x2d, 0xf5, 0x8e, 0x87, 0x51, + 0x95, 0x7d, 0x28, 0x78, 0xbc, 0xd2, 0xd6, 0xfe, + 0x10, 0x1c, 0xcf, 0xd5, 0x56, 0xcb, 0x1e, 0xae, + 0x79, 0xdb, 0x3c, 0x3e, 0xe8, 0x64, 0x29, 0xf2, + 0xb2, 0xa6, 0x02, 0xac, 0x86 +); +DEFINE_HEX_XDR_NETOBJ(rfc8009_enc_test8_expected_hmac, + 0xfe, 0xf6, 0xec, 0xb6, 0x47, 0xd6, 0x29, 0x5f, + 0xae, 0x07, 0x7a, 0x1f, 0xeb, 0x51, 0x75, 0x08, + 0xd2, 0xc1, 0x6b, 0x41, 0x92, 0xe0, 0x1f, 0x62 +); + +static const struct gss_krb5_test_param rfc8009_encrypt_test_params[] = { + { + .desc = "Encrypt empty plaintext with aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .plaintext = &rfc8009_enc_empty_plaintext, + .confounder = &rfc8009_enc_test1_confounder, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .expected_result = &rfc8009_enc_test1_expected_result, + .expected_hmac = &rfc8009_enc_test1_expected_hmac, + }, + { + .desc = "Encrypt short plaintext with aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .plaintext = &rfc8009_enc_short_plaintext, + .confounder = &rfc8009_enc_test2_confounder, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .expected_result = &rfc8009_enc_test2_expected_result, + .expected_hmac = &rfc8009_enc_test2_expected_hmac, + }, + { + .desc = "Encrypt block plaintext with aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .plaintext = &rfc8009_enc_block_plaintext, + .confounder = &rfc8009_enc_test3_confounder, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .expected_result = &rfc8009_enc_test3_expected_result, + .expected_hmac = &rfc8009_enc_test3_expected_hmac, + }, + { + .desc = "Encrypt long plaintext with aes128-cts-hmac-sha256-128", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .plaintext = &rfc8009_enc_long_plaintext, + .confounder = &rfc8009_enc_test4_confounder, + .base_key = &aes128_cts_hmac_sha256_128_basekey, + .expected_result = &rfc8009_enc_test4_expected_result, + .expected_hmac = &rfc8009_enc_test4_expected_hmac, + }, + { + .desc = "Encrypt empty plaintext with aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .plaintext = &rfc8009_enc_empty_plaintext, + .confounder = &rfc8009_enc_test5_confounder, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .expected_result = &rfc8009_enc_test5_expected_result, + .expected_hmac = &rfc8009_enc_test5_expected_hmac, + }, + { + .desc = "Encrypt short plaintext with aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .plaintext = &rfc8009_enc_short_plaintext, + .confounder = &rfc8009_enc_test6_confounder, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .expected_result = &rfc8009_enc_test6_expected_result, + .expected_hmac = &rfc8009_enc_test6_expected_hmac, + }, + { + .desc = "Encrypt block plaintext with aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .plaintext = &rfc8009_enc_block_plaintext, + .confounder = &rfc8009_enc_test7_confounder, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .expected_result = &rfc8009_enc_test7_expected_result, + .expected_hmac = &rfc8009_enc_test7_expected_hmac, + }, + { + .desc = "Encrypt long plaintext with aes256-cts-hmac-sha384-192", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .plaintext = &rfc8009_enc_long_plaintext, + .confounder = &rfc8009_enc_test8_confounder, + .base_key = &aes256_cts_hmac_sha384_192_basekey, + .expected_result = &rfc8009_enc_test8_expected_result, + .expected_hmac = &rfc8009_enc_test8_expected_hmac, + }, +}; + +/* Creates the function rfc8009_encrypt_gen_params */ +KUNIT_ARRAY_PARAM(rfc8009_encrypt, rfc8009_encrypt_test_params, + gss_krb5_get_desc); + +static void rfc8009_encrypt_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + struct crypto_sync_skcipher *cts_tfm, *cbc_tfm; + const struct gss_krb5_enctype *gk5e; + struct xdr_netobj Ke, Ki, checksum; + u8 usage_data[GSS_KRB5_K5CLENGTH]; + struct xdr_netobj usage = { + .data = usage_data, + .len = sizeof(usage_data), + }; + struct crypto_ahash *ahash_tfm; + struct xdr_buf buf; + void *text; + size_t len; + u32 err; + + /* Arrange */ + gk5e = gss_krb5_lookup_enctype(param->enctype); + if (!gk5e) + kunit_skip(test, "Encryption type is not available"); + + *(__be32 *)usage.data = cpu_to_be32(2); + + Ke.len = gk5e->Ke_length; + Ke.data = kunit_kzalloc(test, Ke.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ke.data); + usage.data[4] = KEY_USAGE_SEED_ENCRYPTION; + err = gk5e->derive_key(gk5e, param->base_key, &Ke, + &usage, GFP_KERNEL); + KUNIT_ASSERT_EQ(test, err, 0); + + cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm); + err = crypto_sync_skcipher_setkey(cbc_tfm, Ke.data, Ke.len); + KUNIT_ASSERT_EQ(test, err, 0); + + cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm); + err = crypto_sync_skcipher_setkey(cts_tfm, Ke.data, Ke.len); + KUNIT_ASSERT_EQ(test, err, 0); + + len = param->confounder->len + param->plaintext->len; + text = kunit_kzalloc(test, len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text); + memcpy(text, param->confounder->data, param->confounder->len); + memcpy(text + param->confounder->len, param->plaintext->data, + param->plaintext->len); + + memset(&buf, 0, sizeof(buf)); + buf.head[0].iov_base = text; + buf.head[0].iov_len = param->confounder->len + param->plaintext->len; + buf.len = buf.head[0].iov_len; + + checksum.len = gk5e->cksumlength; + checksum.data = kunit_kzalloc(test, checksum.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, checksum.data); + + Ki.len = gk5e->Ki_length; + Ki.data = kunit_kzalloc(test, Ki.len, GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, Ki.data); + usage.data[4] = KEY_USAGE_SEED_INTEGRITY; + err = gk5e->derive_key(gk5e, param->base_key, &Ki, + &usage, GFP_KERNEL); + KUNIT_ASSERT_EQ(test, err, 0); + + ahash_tfm = crypto_alloc_ahash(gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, ahash_tfm); + err = crypto_ahash_setkey(ahash_tfm, Ki.data, Ki.len); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Act */ + err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, NULL, 0); + KUNIT_ASSERT_EQ(test, err, 0); + err = krb5_etm_checksum(cts_tfm, ahash_tfm, &buf, 0, &checksum); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, + param->expected_result->len, buf.len, + "ciphertext length mismatch"); + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->expected_result->data, + buf.head[0].iov_base, + param->expected_result->len), 0, + "ciphertext mismatch"); + KUNIT_EXPECT_EQ_MSG(test, memcmp(param->expected_hmac->data, + checksum.data, + checksum.len), 0, + "HMAC mismatch"); + + crypto_free_ahash(ahash_tfm); + crypto_free_sync_skcipher(cts_tfm); + crypto_free_sync_skcipher(cbc_tfm); +} + +static struct kunit_case rfc8009_test_cases[] = { + { + .name = "RFC 8009 key derivation", + .run_case = kdf_case, + .generate_params = rfc8009_kdf_gen_params, + }, + { + .name = "RFC 8009 checksum", + .run_case = checksum_case, + .generate_params = rfc8009_checksum_gen_params, + }, + { + .name = "RFC 8009 encryption", + .run_case = rfc8009_encrypt_case, + .generate_params = rfc8009_encrypt_gen_params, + }, + {} +}; + +static struct kunit_suite rfc8009_suite = { + .name = "RFC 8009 suite", + .test_cases = rfc8009_test_cases, +}; + +/* + * Encryption self-tests + */ + +DEFINE_STR_XDR_NETOBJ(encrypt_selftest_plaintext, + "This is the plaintext for the encryption self-test."); + +static const struct gss_krb5_test_param encrypt_selftest_params[] = { + { + .desc = "aes128-cts-hmac-sha1-96 encryption self-test", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &encrypt_selftest_plaintext, + }, + { + .desc = "aes256-cts-hmac-sha1-96 encryption self-test", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA1_96, + .Ke = &rfc3962_encryption_key, + .plaintext = &encrypt_selftest_plaintext, + }, + { + .desc = "camellia128-cts-cmac encryption self-test", + .enctype = ENCTYPE_CAMELLIA128_CTS_CMAC, + .Ke = &camellia128_cts_cmac_Ke, + .plaintext = &encrypt_selftest_plaintext, + }, + { + .desc = "camellia256-cts-cmac encryption self-test", + .enctype = ENCTYPE_CAMELLIA256_CTS_CMAC, + .Ke = &camellia256_cts_cmac_Ke, + .plaintext = &encrypt_selftest_plaintext, + }, + { + .desc = "aes128-cts-hmac-sha256-128 encryption self-test", + .enctype = ENCTYPE_AES128_CTS_HMAC_SHA256_128, + .Ke = &aes128_cts_hmac_sha256_128_Ke, + .plaintext = &encrypt_selftest_plaintext, + }, + { + .desc = "aes256-cts-hmac-sha384-192 encryption self-test", + .enctype = ENCTYPE_AES256_CTS_HMAC_SHA384_192, + .Ke = &aes256_cts_hmac_sha384_192_Ke, + .plaintext = &encrypt_selftest_plaintext, + }, +}; + +/* Creates the function encrypt_selftest_gen_params */ +KUNIT_ARRAY_PARAM(encrypt_selftest, encrypt_selftest_params, + gss_krb5_get_desc); + +/* + * Encrypt and decrypt plaintext, and ensure the input plaintext + * matches the output plaintext. A confounder is not added in this + * case. + */ +static void encrypt_selftest_case(struct kunit *test) +{ + const struct gss_krb5_test_param *param = test->param_value; + struct crypto_sync_skcipher *cts_tfm, *cbc_tfm; + const struct gss_krb5_enctype *gk5e; + struct xdr_buf buf; + void *text; + int err; + + /* Arrange */ + gk5e = gss_krb5_lookup_enctype(param->enctype); + if (!gk5e) + kunit_skip(test, "Encryption type is not available"); + + cbc_tfm = crypto_alloc_sync_skcipher(gk5e->aux_cipher, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cbc_tfm); + err = crypto_sync_skcipher_setkey(cbc_tfm, param->Ke->data, param->Ke->len); + KUNIT_ASSERT_EQ(test, err, 0); + + cts_tfm = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, cts_tfm); + err = crypto_sync_skcipher_setkey(cts_tfm, param->Ke->data, param->Ke->len); + KUNIT_ASSERT_EQ(test, err, 0); + + text = kunit_kzalloc(test, roundup(param->plaintext->len, + crypto_sync_skcipher_blocksize(cbc_tfm)), + GFP_KERNEL); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, text); + + memcpy(text, param->plaintext->data, param->plaintext->len); + memset(&buf, 0, sizeof(buf)); + buf.head[0].iov_base = text; + buf.head[0].iov_len = param->plaintext->len; + buf.len = buf.head[0].iov_len; + + /* Act */ + err = krb5_cbc_cts_encrypt(cts_tfm, cbc_tfm, 0, &buf, NULL, NULL, 0); + KUNIT_ASSERT_EQ(test, err, 0); + err = krb5_cbc_cts_decrypt(cts_tfm, cbc_tfm, 0, &buf); + KUNIT_ASSERT_EQ(test, err, 0); + + /* Assert */ + KUNIT_EXPECT_EQ_MSG(test, + param->plaintext->len, buf.len, + "length mismatch"); + KUNIT_EXPECT_EQ_MSG(test, + memcmp(param->plaintext->data, + buf.head[0].iov_base, buf.len), 0, + "plaintext mismatch"); + + crypto_free_sync_skcipher(cts_tfm); + crypto_free_sync_skcipher(cbc_tfm); +} + +static struct kunit_case encryption_test_cases[] = { + { + .name = "Encryption self-tests", + .run_case = encrypt_selftest_case, + .generate_params = encrypt_selftest_gen_params, + }, + {} +}; + +static struct kunit_suite encryption_test_suite = { + .name = "Encryption test suite", + .test_cases = encryption_test_cases, +}; + +kunit_test_suites(&rfc3961_suite, + &rfc3962_suite, + &rfc6803_suite, + &rfc8009_suite, + &encryption_test_suite); + +MODULE_DESCRIPTION("Test RPCSEC GSS Kerberos 5 functions"); +MODULE_LICENSE("GPL"); diff --git a/net/sunrpc/auth_gss/gss_krb5_unseal.c b/net/sunrpc/auth_gss/gss_krb5_unseal.c index 6cd930f3678f..ef0e6af9fc95 100644 --- a/net/sunrpc/auth_gss/gss_krb5_unseal.c +++ b/net/sunrpc/auth_gss/gss_krb5_unseal.c @@ -60,106 +60,34 @@ #include <linux/types.h> #include <linux/jiffies.h> #include <linux/sunrpc/gss_krb5.h> -#include <linux/crypto.h> -#ifdef RPC_DEBUG +#include "gss_krb5_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif - -/* read_token is a mic token, and message_buffer is the data that the mic was - * supposedly taken over. */ - -static u32 -gss_verify_mic_v1(struct krb5_ctx *ctx, - struct xdr_buf *message_buffer, struct xdr_netobj *read_token) -{ - int signalg; - int sealalg; - char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), - .data = cksumdata}; - s32 now; - int direction; - u32 seqnum; - unsigned char *ptr = (unsigned char *)read_token->data; - int bodysize; - u8 *cksumkey; - - dprintk("RPC: krb5_read_token\n"); - - if (g_verify_token_header(&ctx->mech_used, &bodysize, &ptr, - read_token->len)) - return GSS_S_DEFECTIVE_TOKEN; - - if ((ptr[0] != ((KG_TOK_MIC_MSG >> 8) & 0xff)) || - (ptr[1] != (KG_TOK_MIC_MSG & 0xff))) - return GSS_S_DEFECTIVE_TOKEN; - - /* XXX sanity-check bodysize?? */ - - signalg = ptr[2] + (ptr[3] << 8); - if (signalg != ctx->gk5e->signalg) - return GSS_S_DEFECTIVE_TOKEN; - - sealalg = ptr[4] + (ptr[5] << 8); - if (sealalg != SEAL_ALG_NONE) - return GSS_S_DEFECTIVE_TOKEN; - - if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) - return GSS_S_DEFECTIVE_TOKEN; - - if (ctx->gk5e->keyed_cksum) - cksumkey = ctx->cksum; - else - cksumkey = NULL; - - if (make_checksum(ctx, ptr, 8, message_buffer, 0, - cksumkey, KG_USAGE_SIGN, &md5cksum)) - return GSS_S_FAILURE; - - if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, - ctx->gk5e->cksumlength)) - return GSS_S_BAD_SIG; - - /* it got through unscathed. Make sure the context is unexpired */ - - now = get_seconds(); - - if (now > ctx->endtime) - return GSS_S_CONTEXT_EXPIRED; - - /* do sequencing checks */ - - if (krb5_get_seq_num(ctx, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, - &direction, &seqnum)) - return GSS_S_FAILURE; - - if ((ctx->initiate && direction != 0xff) || - (!ctx->initiate && direction != 0)) - return GSS_S_BAD_SIG; - - return GSS_S_COMPLETE; -} - -static u32 -gss_verify_mic_v2(struct krb5_ctx *ctx, - struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +u32 +gss_krb5_verify_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *message_buffer, + struct xdr_netobj *read_token) { + struct crypto_ahash *tfm = ctx->initiate ? + ctx->acceptor_sign : ctx->initiator_sign; char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - struct xdr_netobj cksumobj = {.len = sizeof(cksumdata), - .data = cksumdata}; - s32 now; - u64 seqnum; + struct xdr_netobj cksumobj = { + .len = ctx->gk5e->cksumlength, + .data = cksumdata, + }; u8 *ptr = read_token->data; - u8 *cksumkey; + __be16 be16_ptr; + time64_t now; u8 flags; int i; - unsigned int cksum_usage; dprintk("RPC: %s\n", __func__); - if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_MIC) + memcpy(&be16_ptr, (char *) ptr, 2); + if (be16_to_cpu(be16_ptr) != KG2_TOK_MIC) return GSS_S_DEFECTIVE_TOKEN; flags = ptr[2]; @@ -176,16 +104,8 @@ gss_verify_mic_v2(struct krb5_ctx *ctx, if (ptr[i] != 0xff) return GSS_S_DEFECTIVE_TOKEN; - if (ctx->initiate) { - cksumkey = ctx->acceptor_sign; - cksum_usage = KG_USAGE_ACCEPTOR_SIGN; - } else { - cksumkey = ctx->initiator_sign; - cksum_usage = KG_USAGE_INITIATOR_SIGN; - } - - if (make_checksum_v2(ctx, ptr, GSS_KRB5_TOK_HDR_LEN, message_buffer, 0, - cksumkey, cksum_usage, &cksumobj)) + if (gss_krb5_checksum(tfm, ptr, GSS_KRB5_TOK_HDR_LEN, + message_buffer, 0, &cksumobj)) return GSS_S_FAILURE; if (memcmp(cksumobj.data, ptr + GSS_KRB5_TOK_HDR_LEN, @@ -193,34 +113,14 @@ gss_verify_mic_v2(struct krb5_ctx *ctx, return GSS_S_BAD_SIG; /* it got through unscathed. Make sure the context is unexpired */ - now = get_seconds(); + now = ktime_get_real_seconds(); if (now > ctx->endtime) return GSS_S_CONTEXT_EXPIRED; - /* do sequencing checks */ - - seqnum = be64_to_cpup((__be64 *)ptr + 8); + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ return GSS_S_COMPLETE; } - -u32 -gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, - struct xdr_buf *message_buffer, - struct xdr_netobj *read_token) -{ - struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; - - switch (ctx->enctype) { - default: - BUG(); - case ENCTYPE_DES_CBC_RAW: - case ENCTYPE_DES3_CBC_RAW: - case ENCTYPE_ARCFOUR_HMAC: - return gss_verify_mic_v1(ctx, message_buffer, read_token); - case ENCTYPE_AES128_CTS_HMAC_SHA1_96: - case ENCTYPE_AES256_CTS_HMAC_SHA1_96: - return gss_verify_mic_v2(ctx, message_buffer, read_token); - } -} - diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c index 1da52d1406fc..b3e1738ff6bf 100644 --- a/net/sunrpc/auth_gss/gss_krb5_wrap.c +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -28,358 +28,18 @@ * SUCH DAMAGES. */ +#include <crypto/skcipher.h> #include <linux/types.h> #include <linux/jiffies.h> #include <linux/sunrpc/gss_krb5.h> -#include <linux/random.h> #include <linux/pagemap.h> -#include <linux/crypto.h> -#ifdef RPC_DEBUG +#include "gss_krb5_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif -static inline int -gss_krb5_padding(int blocksize, int length) -{ - return blocksize - (length % blocksize); -} - -static inline void -gss_krb5_add_padding(struct xdr_buf *buf, int offset, int blocksize) -{ - int padding = gss_krb5_padding(blocksize, buf->len - offset); - char *p; - struct kvec *iov; - - if (buf->page_len || buf->tail[0].iov_len) - iov = &buf->tail[0]; - else - iov = &buf->head[0]; - p = iov->iov_base + iov->iov_len; - iov->iov_len += padding; - buf->len += padding; - memset(p, padding, padding); -} - -static inline int -gss_krb5_remove_padding(struct xdr_buf *buf, int blocksize) -{ - u8 *ptr; - u8 pad; - size_t len = buf->len; - - if (len <= buf->head[0].iov_len) { - pad = *(u8 *)(buf->head[0].iov_base + len - 1); - if (pad > buf->head[0].iov_len) - return -EINVAL; - buf->head[0].iov_len -= pad; - goto out; - } else - len -= buf->head[0].iov_len; - if (len <= buf->page_len) { - unsigned int last = (buf->page_base + len - 1) - >>PAGE_CACHE_SHIFT; - unsigned int offset = (buf->page_base + len - 1) - & (PAGE_CACHE_SIZE - 1); - ptr = kmap_atomic(buf->pages[last]); - pad = *(ptr + offset); - kunmap_atomic(ptr); - goto out; - } else - len -= buf->page_len; - BUG_ON(len > buf->tail[0].iov_len); - pad = *(u8 *)(buf->tail[0].iov_base + len - 1); -out: - /* XXX: NOTE: we do not adjust the page lengths--they represent - * a range of data in the real filesystem page cache, and we need - * to know that range so the xdr code can properly place read data. - * However adjusting the head length, as we do above, is harmless. - * In the case of a request that fits into a single page, the server - * also uses length and head length together to determine the original - * start of the request to copy the request for deferal; so it's - * easier on the server if we adjust head and tail length in tandem. - * It's not really a problem that we don't fool with the page and - * tail lengths, though--at worst badly formed xdr might lead the - * server to attempt to parse the padding. - * XXX: Document all these weird requirements for gss mechanism - * wrap/unwrap functions. */ - if (pad > blocksize) - return -EINVAL; - if (buf->len > pad) - buf->len -= pad; - else - return -EINVAL; - return 0; -} - -void -gss_krb5_make_confounder(char *p, u32 conflen) -{ - static u64 i = 0; - u64 *q = (u64 *)p; - - /* rfc1964 claims this should be "random". But all that's really - * necessary is that it be unique. And not even that is necessary in - * our case since our "gssapi" implementation exists only to support - * rpcsec_gss, so we know that the only buffers we will ever encrypt - * already begin with a unique sequence number. Just to hedge my bets - * I'll make a half-hearted attempt at something unique, but ensuring - * uniqueness would mean worrying about atomicity and rollover, and I - * don't care enough. */ - - /* initialize to random value */ - if (i == 0) { - i = prandom_u32(); - i = (i << 32) | prandom_u32(); - } - - switch (conflen) { - case 16: - *q++ = i++; - /* fall through */ - case 8: - *q++ = i++; - break; - default: - BUG(); - } -} - -/* Assumptions: the head and tail of inbuf are ours to play with. - * The pages, however, may be real pages in the page cache and we replace - * them with scratch pages from **pages before writing to them. */ -/* XXX: obviously the above should be documentation of wrap interface, - * and shouldn't be in this kerberos-specific file. */ - -/* XXX factor out common code with seal/unseal. */ - -static u32 -gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, - struct xdr_buf *buf, struct page **pages) -{ - char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), - .data = cksumdata}; - int blocksize = 0, plainlen; - unsigned char *ptr, *msg_start; - s32 now; - int headlen; - struct page **tmp_pages; - u32 seq_send; - u8 *cksumkey; - u32 conflen = kctx->gk5e->conflen; - - dprintk("RPC: %s\n", __func__); - - now = get_seconds(); - - blocksize = crypto_blkcipher_blocksize(kctx->enc); - gss_krb5_add_padding(buf, offset, blocksize); - BUG_ON((buf->len - offset) % blocksize); - plainlen = conflen + buf->len - offset; - - headlen = g_token_size(&kctx->mech_used, - GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength + plainlen) - - (buf->len - offset); - - ptr = buf->head[0].iov_base + offset; - /* shift data to make room for header. */ - xdr_extend_head(buf, offset, headlen); - - /* XXX Would be cleverer to encrypt while copying. */ - BUG_ON((buf->len - offset - headlen) % blocksize); - - g_make_token_header(&kctx->mech_used, - GSS_KRB5_TOK_HDR_LEN + - kctx->gk5e->cksumlength + plainlen, &ptr); - - - /* ptr now at header described in rfc 1964, section 1.2.1: */ - ptr[0] = (unsigned char) ((KG_TOK_WRAP_MSG >> 8) & 0xff); - ptr[1] = (unsigned char) (KG_TOK_WRAP_MSG & 0xff); - - msg_start = ptr + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength; - - *(__be16 *)(ptr + 2) = cpu_to_le16(kctx->gk5e->signalg); - memset(ptr + 4, 0xff, 4); - *(__be16 *)(ptr + 4) = cpu_to_le16(kctx->gk5e->sealalg); - - gss_krb5_make_confounder(msg_start, conflen); - - if (kctx->gk5e->keyed_cksum) - cksumkey = kctx->cksum; - else - cksumkey = NULL; - - /* XXXJBF: UGH!: */ - tmp_pages = buf->pages; - buf->pages = pages; - if (make_checksum(kctx, ptr, 8, buf, offset + headlen - conflen, - cksumkey, KG_USAGE_SEAL, &md5cksum)) - return GSS_S_FAILURE; - buf->pages = tmp_pages; - - memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); - - spin_lock(&krb5_seq_lock); - seq_send = kctx->seq_send++; - spin_unlock(&krb5_seq_lock); - - /* XXX would probably be more efficient to compute checksum - * and encrypt at the same time: */ - if ((krb5_make_seq_num(kctx, kctx->seq, kctx->initiate ? 0 : 0xff, - seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8))) - return GSS_S_FAILURE; - - if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { - struct crypto_blkcipher *cipher; - int err; - cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(cipher)) - return GSS_S_FAILURE; - - krb5_rc4_setup_enc_key(kctx, cipher, seq_send); - - err = gss_encrypt_xdr_buf(cipher, buf, - offset + headlen - conflen, pages); - crypto_free_blkcipher(cipher); - if (err) - return GSS_S_FAILURE; - } else { - if (gss_encrypt_xdr_buf(kctx->enc, buf, - offset + headlen - conflen, pages)) - return GSS_S_FAILURE; - } - - return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; -} - -static u32 -gss_unwrap_kerberos_v1(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) -{ - int signalg; - int sealalg; - char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; - struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), - .data = cksumdata}; - s32 now; - int direction; - s32 seqnum; - unsigned char *ptr; - int bodysize; - void *data_start, *orig_start; - int data_len; - int blocksize; - u32 conflen = kctx->gk5e->conflen; - int crypt_offset; - u8 *cksumkey; - - dprintk("RPC: gss_unwrap_kerberos\n"); - - ptr = (u8 *)buf->head[0].iov_base + offset; - if (g_verify_token_header(&kctx->mech_used, &bodysize, &ptr, - buf->len - offset)) - return GSS_S_DEFECTIVE_TOKEN; - - if ((ptr[0] != ((KG_TOK_WRAP_MSG >> 8) & 0xff)) || - (ptr[1] != (KG_TOK_WRAP_MSG & 0xff))) - return GSS_S_DEFECTIVE_TOKEN; - - /* XXX sanity-check bodysize?? */ - - /* get the sign and seal algorithms */ - - signalg = ptr[2] + (ptr[3] << 8); - if (signalg != kctx->gk5e->signalg) - return GSS_S_DEFECTIVE_TOKEN; - - sealalg = ptr[4] + (ptr[5] << 8); - if (sealalg != kctx->gk5e->sealalg) - return GSS_S_DEFECTIVE_TOKEN; - - if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) - return GSS_S_DEFECTIVE_TOKEN; - - /* - * Data starts after token header and checksum. ptr points - * to the beginning of the token header - */ - crypt_offset = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) - - (unsigned char *)buf->head[0].iov_base; - - /* - * Need plaintext seqnum to derive encryption key for arcfour-hmac - */ - if (krb5_get_seq_num(kctx, ptr + GSS_KRB5_TOK_HDR_LEN, - ptr + 8, &direction, &seqnum)) - return GSS_S_BAD_SIG; - - if ((kctx->initiate && direction != 0xff) || - (!kctx->initiate && direction != 0)) - return GSS_S_BAD_SIG; - - if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { - struct crypto_blkcipher *cipher; - int err; - - cipher = crypto_alloc_blkcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); - if (IS_ERR(cipher)) - return GSS_S_FAILURE; - - krb5_rc4_setup_enc_key(kctx, cipher, seqnum); - - err = gss_decrypt_xdr_buf(cipher, buf, crypt_offset); - crypto_free_blkcipher(cipher); - if (err) - return GSS_S_DEFECTIVE_TOKEN; - } else { - if (gss_decrypt_xdr_buf(kctx->enc, buf, crypt_offset)) - return GSS_S_DEFECTIVE_TOKEN; - } - - if (kctx->gk5e->keyed_cksum) - cksumkey = kctx->cksum; - else - cksumkey = NULL; - - if (make_checksum(kctx, ptr, 8, buf, crypt_offset, - cksumkey, KG_USAGE_SEAL, &md5cksum)) - return GSS_S_FAILURE; - - if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, - kctx->gk5e->cksumlength)) - return GSS_S_BAD_SIG; - - /* it got through unscathed. Make sure the context is unexpired */ - - now = get_seconds(); - - if (now > kctx->endtime) - return GSS_S_CONTEXT_EXPIRED; - - /* do sequencing checks */ - - /* Copy the data back to the right position. XXX: Would probably be - * better to copy and encrypt at the same time. */ - - blocksize = crypto_blkcipher_blocksize(kctx->enc); - data_start = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) + - conflen; - orig_start = buf->head[0].iov_base + offset; - data_len = (buf->head[0].iov_base + buf->head[0].iov_len) - data_start; - memmove(orig_start, data_start, data_len); - buf->head[0].iov_len -= (data_start - orig_start); - buf->len -= (data_start - orig_start); - - if (gss_krb5_remove_padding(buf, blocksize)) - return GSS_S_DEFECTIVE_TOKEN; - - return GSS_S_COMPLETE; -} - /* * We can shift data by up to LOCAL_BUF_LEN bytes in a pass. If we need * to do more than that, we shift repeatedly. Kevin Coffman reports @@ -430,29 +90,25 @@ static void rotate_left(u32 base, struct xdr_buf *buf, unsigned int shift) _rotate_left(&subbuf, shift); } -static u32 -gss_wrap_kerberos_v2(struct krb5_ctx *kctx, u32 offset, - struct xdr_buf *buf, struct page **pages) +u32 +gss_krb5_wrap_v2(struct krb5_ctx *kctx, int offset, + struct xdr_buf *buf, struct page **pages) { - int blocksize; - u8 *ptr, *plainhdr; - s32 now; + u8 *ptr; + time64_t now; u8 flags = 0x00; - __be16 *be16ptr, ec = 0; + __be16 *be16ptr; __be64 *be64ptr; u32 err; dprintk("RPC: %s\n", __func__); - if (kctx->gk5e->encrypt_v2 == NULL) - return GSS_S_FAILURE; - /* make room for gss token header */ if (xdr_extend_head(buf, offset, GSS_KRB5_TOK_HDR_LEN)) return GSS_S_FAILURE; /* construct gss token header */ - ptr = plainhdr = buf->head[0].iov_base + offset; + ptr = buf->head[0].iov_base + offset; *ptr++ = (unsigned char) ((KG2_TOK_WRAP>>8) & 0xff); *ptr++ = (unsigned char) (KG2_TOK_WRAP & 0xff); @@ -467,29 +123,27 @@ gss_wrap_kerberos_v2(struct krb5_ctx *kctx, u32 offset, *ptr++ = 0xff; be16ptr = (__be16 *)ptr; - blocksize = crypto_blkcipher_blocksize(kctx->acceptor_enc); - *be16ptr++ = cpu_to_be16(ec); + *be16ptr++ = 0; /* "inner" token header always uses 0 for RRC */ - *be16ptr++ = cpu_to_be16(0); + *be16ptr++ = 0; be64ptr = (__be64 *)be16ptr; - spin_lock(&krb5_seq_lock); - *be64ptr = cpu_to_be64(kctx->seq_send64++); - spin_unlock(&krb5_seq_lock); + *be64ptr = cpu_to_be64(atomic64_fetch_inc(&kctx->seq_send64)); - err = (*kctx->gk5e->encrypt_v2)(kctx, offset, buf, ec, pages); + err = (*kctx->gk5e->encrypt)(kctx, offset, buf, pages); if (err) return err; - now = get_seconds(); + now = ktime_get_real_seconds(); return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; } -static u32 -gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) +u32 +gss_krb5_unwrap_v2(struct krb5_ctx *kctx, int offset, int len, + struct xdr_buf *buf, unsigned int *slack, + unsigned int *align) { - s32 now; - u64 seqnum; + time64_t now; u8 *ptr; u8 flags = 0x00; u16 ec, rrc; @@ -501,9 +155,6 @@ gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) dprintk("RPC: %s\n", __func__); - if (kctx->gk5e->decrypt_v2 == NULL) - return GSS_S_FAILURE; - ptr = buf->head[0].iov_base + offset; if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_WRAP) @@ -525,13 +176,16 @@ gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) ec = be16_to_cpup((__be16 *)(ptr + 4)); rrc = be16_to_cpup((__be16 *)(ptr + 6)); - seqnum = be64_to_cpup((__be64 *)(ptr + 8)); + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ if (rrc != 0) rotate_left(offset + 16, buf, rrc); - err = (*kctx->gk5e->decrypt_v2)(kctx, offset, buf, - &headskip, &tailskip); + err = (*kctx->gk5e->decrypt)(kctx, offset, len, buf, + &headskip, &tailskip); if (err) return GSS_S_FAILURE; @@ -540,7 +194,7 @@ gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) * it against the original */ err = read_bytes_from_xdr_buf(buf, - buf->len - GSS_KRB5_TOK_HDR_LEN - tailskip, + len - GSS_KRB5_TOK_HDR_LEN - tailskip, decrypted_hdr, GSS_KRB5_TOK_HDR_LEN); if (err) { dprintk("%s: error %u getting decrypted_hdr\n", __func__, err); @@ -555,7 +209,7 @@ gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) /* do sequencing checks */ /* it got through unscathed. Make sure the context is unexpired */ - now = get_seconds(); + now = ktime_get_real_seconds(); if (now > kctx->endtime) return GSS_S_CONTEXT_EXPIRED; @@ -566,53 +220,18 @@ gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) * Note that buf->head[0].iov_len may indicate the available * head buffer space rather than that actually occupied. */ - movelen = min_t(unsigned int, buf->head[0].iov_len, buf->len); + movelen = min_t(unsigned int, buf->head[0].iov_len, len); movelen -= offset + GSS_KRB5_TOK_HDR_LEN + headskip; BUG_ON(offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen > buf->head[0].iov_len); memmove(ptr, ptr + GSS_KRB5_TOK_HDR_LEN + headskip, movelen); buf->head[0].iov_len -= GSS_KRB5_TOK_HDR_LEN + headskip; - buf->len -= GSS_KRB5_TOK_HDR_LEN + headskip; + buf->len = len - (GSS_KRB5_TOK_HDR_LEN + headskip); - /* Trim off the checksum blob */ - xdr_buf_trim(buf, GSS_KRB5_TOK_HDR_LEN + tailskip); - return GSS_S_COMPLETE; -} - -u32 -gss_wrap_kerberos(struct gss_ctx *gctx, int offset, - struct xdr_buf *buf, struct page **pages) -{ - struct krb5_ctx *kctx = gctx->internal_ctx_id; - - switch (kctx->enctype) { - default: - BUG(); - case ENCTYPE_DES_CBC_RAW: - case ENCTYPE_DES3_CBC_RAW: - case ENCTYPE_ARCFOUR_HMAC: - return gss_wrap_kerberos_v1(kctx, offset, buf, pages); - case ENCTYPE_AES128_CTS_HMAC_SHA1_96: - case ENCTYPE_AES256_CTS_HMAC_SHA1_96: - return gss_wrap_kerberos_v2(kctx, offset, buf, pages); - } -} + /* Trim off the trailing "extra count" and checksum blob */ + xdr_buf_trim(buf, ec + GSS_KRB5_TOK_HDR_LEN + tailskip); -u32 -gss_unwrap_kerberos(struct gss_ctx *gctx, int offset, struct xdr_buf *buf) -{ - struct krb5_ctx *kctx = gctx->internal_ctx_id; - - switch (kctx->enctype) { - default: - BUG(); - case ENCTYPE_DES_CBC_RAW: - case ENCTYPE_DES3_CBC_RAW: - case ENCTYPE_ARCFOUR_HMAC: - return gss_unwrap_kerberos_v1(kctx, offset, buf); - case ENCTYPE_AES128_CTS_HMAC_SHA1_96: - case ENCTYPE_AES256_CTS_HMAC_SHA1_96: - return gss_unwrap_kerberos_v2(kctx, offset, buf); - } + *align = XDR_QUADLEN(GSS_KRB5_TOK_HDR_LEN + headskip); + *slack = *align + XDR_QUADLEN(ec + GSS_KRB5_TOK_HDR_LEN + tailskip); + return GSS_S_COMPLETE; } - diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c index 27ce26240932..c84d0cf61980 100644 --- a/net/sunrpc/auth_gss/gss_mech_switch.c +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: BSD-3-Clause /* * linux/net/sunrpc/gss_mech_switch.c * @@ -5,32 +6,6 @@ * All rights reserved. * * J. Bruce Fields <bfields@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * */ #include <linux/types.h> @@ -38,15 +13,15 @@ #include <linux/module.h> #include <linux/oid_registry.h> #include <linux/sunrpc/msg_prot.h> -#include <linux/sunrpc/gss_asn1.h> #include <linux/sunrpc/auth_gss.h> #include <linux/sunrpc/svcauth_gss.h> #include <linux/sunrpc/gss_err.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/gss_api.h> #include <linux/sunrpc/clnt.h> +#include <trace/events/rpcgss.h> -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif @@ -61,6 +36,8 @@ gss_mech_free(struct gss_api_mech *gm) for (i = 0; i < gm->gm_pf_num; i++) { pf = &gm->gm_pfs[i]; + if (pf->domain) + auth_domain_put(pf->domain); kfree(pf->auth_domain_name); pf->auth_domain_name = NULL; } @@ -83,6 +60,7 @@ make_auth_domain_name(char *name) static int gss_mech_svc_setup(struct gss_api_mech *gm) { + struct auth_domain *dom; struct pf_desc *pf; int i, status; @@ -92,10 +70,13 @@ gss_mech_svc_setup(struct gss_api_mech *gm) status = -ENOMEM; if (pf->auth_domain_name == NULL) goto out; - status = svcauth_gss_register_pseudoflavor(pf->pseudoflavor, - pf->auth_domain_name); - if (status) + dom = svcauth_gss_register_pseudoflavor( + pf->pseudoflavor, pf->auth_domain_name); + if (IS_ERR(dom)) { + status = PTR_ERR(dom); goto out; + } + pf->domain = dom; } return 0; out: @@ -117,7 +98,7 @@ int gss_mech_register(struct gss_api_mech *gm) if (status) return status; spin_lock(®istered_mechs_lock); - list_add(&gm->gm_list, ®istered_mechs); + list_add_rcu(&gm->gm_list, ®istered_mechs); spin_unlock(®istered_mechs_lock); dprintk("RPC: registered gss mechanism %s\n", gm->gm_name); return 0; @@ -132,7 +113,7 @@ EXPORT_SYMBOL_GPL(gss_mech_register); void gss_mech_unregister(struct gss_api_mech *gm) { spin_lock(®istered_mechs_lock); - list_del(&gm->gm_list); + list_del_rcu(&gm->gm_list); spin_unlock(®istered_mechs_lock); dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name); gss_mech_free(gm); @@ -151,15 +132,15 @@ _gss_mech_get_by_name(const char *name) { struct gss_api_mech *pos, *gm = NULL; - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { if (0 == strcmp(name, pos->gm_name)) { if (try_module_get(pos->gm_owner)) gm = pos; break; } } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); return gm; } @@ -183,11 +164,10 @@ struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) if (sprint_oid(obj->data, obj->len, buf, sizeof(buf)) < 0) return NULL; - dprintk("RPC: %s(%s)\n", __func__, buf); request_module("rpc-auth-gss-%s", buf); - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { if (obj->len == pos->gm_oid.len) { if (0 == memcmp(obj->data, pos->gm_oid.data, obj->len)) { if (try_module_get(pos->gm_owner)) @@ -196,7 +176,9 @@ struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) } } } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); + if (!gm) + trace_rpcgss_oid_to_mech(buf); return gm; } @@ -216,17 +198,15 @@ static struct gss_api_mech *_gss_mech_get_by_pseudoflavor(u32 pseudoflavor) { struct gss_api_mech *gm = NULL, *pos; - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { - if (!mech_supports_pseudoflavor(pos, pseudoflavor)) { - module_put(pos->gm_owner); + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { + if (!mech_supports_pseudoflavor(pos, pseudoflavor)) continue; - } if (try_module_get(pos->gm_owner)) gm = pos; break; } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); return gm; } @@ -245,35 +225,6 @@ gss_mech_get_by_pseudoflavor(u32 pseudoflavor) } /** - * gss_mech_list_pseudoflavors - Discover registered GSS pseudoflavors - * @array: array to fill in - * @size: size of "array" - * - * Returns the number of array items filled in, or a negative errno. - * - * The returned array is not sorted by any policy. Callers should not - * rely on the order of the items in the returned array. - */ -int gss_mech_list_pseudoflavors(rpc_authflavor_t *array_ptr, int size) -{ - struct gss_api_mech *pos = NULL; - int j, i = 0; - - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { - for (j = 0; j < pos->gm_pf_num; j++) { - if (i >= size) { - spin_unlock(®istered_mechs_lock); - return -ENOMEM; - } - array_ptr[i++] = pos->gm_pfs[j].pseudoflavor; - } - } - spin_unlock(®istered_mechs_lock); - return i; -} - -/** * gss_svc_to_pseudoflavor - map a GSS service number to a pseudoflavor * @gm: GSS mechanism handle * @qop: GSS quality-of-protection value @@ -363,6 +314,18 @@ gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor) } EXPORT_SYMBOL(gss_pseudoflavor_to_service); +bool +gss_pseudoflavor_to_datatouch(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return gm->gm_pfs[i].datatouch; + } + return false; +} + char * gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service) { @@ -389,7 +352,7 @@ int gss_import_sec_context(const void *input_token, size_t bufsize, struct gss_api_mech *mech, struct gss_ctx **ctx_id, - time_t *endtime, + time64_t *endtime, gfp_t gfp_mask) { if (!(*ctx_id = kzalloc(sizeof(**ctx_id), gfp_mask))) @@ -453,10 +416,11 @@ gss_wrap(struct gss_ctx *ctx_id, u32 gss_unwrap(struct gss_ctx *ctx_id, int offset, + int len, struct xdr_buf *buf) { return ctx_id->mech_type->gm_ops - ->gss_unwrap(ctx_id, offset, buf); + ->gss_unwrap(ctx_id, offset, len, buf); } diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.c b/net/sunrpc/auth_gss/gss_rpc_upcall.c index d304f41260f2..f549e4c05def 100644 --- a/net/sunrpc/auth_gss/gss_rpc_upcall.c +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.c @@ -1,21 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ /* * linux/net/sunrpc/gss_rpc_upcall.c * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #include <linux/types.h> @@ -55,15 +42,15 @@ enum { #define PROC(proc, name) \ [GSSX_##proc] = { \ .p_proc = GSSX_##proc, \ - .p_encode = (kxdreproc_t)gssx_enc_##name, \ - .p_decode = (kxdrdproc_t)gssx_dec_##name, \ + .p_encode = gssx_enc_##name, \ + .p_decode = gssx_dec_##name, \ .p_arglen = GSSX_ARG_##name##_sz, \ .p_replen = GSSX_RES_##name##_sz, \ .p_statidx = GSSX_##proc, \ .p_name = #proc, \ } -static struct rpc_procinfo gssp_procedures[] = { +static const struct rpc_procinfo gssp_procedures[] = { PROC(INDICATE_MECHS, indicate_mechs), PROC(GET_CALL_CONTEXT, get_call_context), PROC(IMPORT_AND_CANON_NAME, import_and_canon_name), @@ -111,6 +98,7 @@ static int gssp_rpc_create(struct net *net, struct rpc_clnt **_clnt) * done without the correct namespace: */ .flags = RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_CONNECTED | RPC_CLNT_CREATE_NO_IDLE_TIMEOUT }; struct rpc_clnt *clnt; @@ -120,7 +108,7 @@ static int gssp_rpc_create(struct net *net, struct rpc_clnt **_clnt) if (IS_ERR(clnt)) { dprintk("RPC: failed to create AF_LOCAL gssproxy " "client (errno %ld).\n", PTR_ERR(clnt)); - result = -PTR_ERR(clnt); + result = PTR_ERR(clnt); *_clnt = NULL; goto out; } @@ -137,7 +125,6 @@ void init_gssp_clnt(struct sunrpc_net *sn) { mutex_init(&sn->gssp_lock); sn->gssp_clnt = NULL; - init_waitqueue_head(&sn->gssp_wq); } int set_gssp_clnt(struct net *net) @@ -154,7 +141,6 @@ int set_gssp_clnt(struct net *net) sn->gssp_clnt = clnt; } mutex_unlock(&sn->gssp_lock); - wake_up(&sn->gssp_wq); return ret; } @@ -175,7 +161,7 @@ static struct rpc_clnt *get_gssp_clnt(struct sunrpc_net *sn) mutex_lock(&sn->gssp_lock); clnt = sn->gssp_clnt; if (clnt) - atomic_inc(&clnt->cl_count); + refcount_inc(&clnt->cl_count); mutex_unlock(&sn->gssp_lock); return clnt; } @@ -213,6 +199,62 @@ static int gssp_call(struct net *net, struct rpc_message *msg) return status; } +static void gssp_free_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + unsigned int i; + + for (i = 0; i < arg->npages && arg->pages[i]; i++) + __free_page(arg->pages[i]); + + kfree(arg->pages); +} + +static int gssp_alloc_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + unsigned int i; + + arg->npages = DIV_ROUND_UP(NGROUPS_MAX * 4, PAGE_SIZE); + arg->pages = kcalloc(arg->npages, sizeof(struct page *), GFP_KERNEL); + if (!arg->pages) + return -ENOMEM; + for (i = 0; i < arg->npages; i++) { + arg->pages[i] = alloc_page(GFP_KERNEL); + if (!arg->pages[i]) { + gssp_free_receive_pages(arg); + return -ENOMEM; + } + } + return 0; +} + +static char *gssp_stringify(struct xdr_netobj *netobj) +{ + return kmemdup_nul(netobj->data, netobj->len, GFP_KERNEL); +} + +static void gssp_hostbased_service(char **principal) +{ + char *c; + + if (!*principal) + return; + + /* terminate and remove realm part */ + c = strchr(*principal, '@'); + if (c) { + *c = '\0'; + + /* change service-hostname delimiter */ + c = strchr(*principal, '/'); + if (c) + *c = '@'; + } + if (!c) { + /* not a service principal */ + kfree(*principal); + *principal = NULL; + } +} /* * Public functions @@ -242,6 +284,7 @@ int gssp_accept_sec_context_upcall(struct net *net, */ .exported_context_token.len = GSSX_max_output_handle_sz, .mech.len = GSS_OID_MAX_LEN, + .targ_name.display_name.len = GSSX_max_princ_sz, .src_name.display_name.len = GSSX_max_princ_sz }; struct gssx_res_accept_sec_context res = { @@ -255,16 +298,21 @@ int gssp_accept_sec_context_upcall(struct net *net, .rpc_cred = NULL, /* FIXME ? */ }; struct xdr_netobj client_name = { 0 , NULL }; + struct xdr_netobj target_name = { 0, NULL }; int ret; if (data->in_handle.len != 0) arg.context_handle = &ctxh; res.output_token->len = GSSX_max_output_token_sz; - /* use nfs/ for targ_name ? */ + ret = gssp_alloc_receive_pages(&arg); + if (ret) + return ret; ret = gssp_call(net, &msg); + gssp_free_receive_pages(&arg); + /* we need to fetch all data even in case of error so * that we can free special strctures is they have been allocated */ data->major_status = res.status.major_status; @@ -272,9 +320,13 @@ int gssp_accept_sec_context_upcall(struct net *net, if (res.context_handle) { data->out_handle = rctxh.exported_context_token; data->mech_oid.len = rctxh.mech.len; - memcpy(data->mech_oid.data, rctxh.mech.data, + if (rctxh.mech.data) { + memcpy(data->mech_oid.data, rctxh.mech.data, data->mech_oid.len); + kfree(rctxh.mech.data); + } client_name = rctxh.src_name.display_name; + target_name = rctxh.targ_name.display_name; } if (res.options.count == 1) { @@ -296,29 +348,22 @@ int gssp_accept_sec_context_upcall(struct net *net, } /* convert to GSS_NT_HOSTBASED_SERVICE form and set into creds */ - if (data->found_creds && client_name.data != NULL) { - char *c; - - data->creds.cr_principal = kstrndup(client_name.data, - client_name.len, GFP_KERNEL); - if (data->creds.cr_principal) { - /* terminate and remove realm part */ - c = strchr(data->creds.cr_principal, '@'); - if (c) { - *c = '\0'; - - /* change service-hostname delimiter */ - c = strchr(data->creds.cr_principal, '/'); - if (c) *c = '@'; - } - if (!c) { - /* not a service principal */ - kfree(data->creds.cr_principal); - data->creds.cr_principal = NULL; - } + if (data->found_creds) { + if (client_name.data) { + data->creds.cr_raw_principal = + gssp_stringify(&client_name); + data->creds.cr_principal = + gssp_stringify(&client_name); + gssp_hostbased_service(&data->creds.cr_principal); + } + if (target_name.data) { + data->creds.cr_targ_princ = + gssp_stringify(&target_name); + gssp_hostbased_service(&data->creds.cr_targ_princ); } } kfree(client_name.data); + kfree(target_name.data); return ret; } @@ -328,18 +373,18 @@ void gssp_free_upcall_data(struct gssp_upcall_data *data) kfree(data->in_handle.data); kfree(data->out_handle.data); kfree(data->out_token.data); - kfree(data->mech_oid.data); free_svc_cred(&data->creds); } /* * Initialization stuff */ - +static unsigned int gssp_version1_counts[ARRAY_SIZE(gssp_procedures)]; static const struct rpc_version gssp_version1 = { .number = GSSPROXY_VERS_1, .nrprocs = ARRAY_SIZE(gssp_procedures), .procs = gssp_procedures, + .counts = gssp_version1_counts, }; static const struct rpc_version *gssp_version[] = { diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.h b/net/sunrpc/auth_gss/gss_rpc_upcall.h index 1e542aded90a..31e96344167e 100644 --- a/net/sunrpc/auth_gss/gss_rpc_upcall.h +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.h @@ -1,21 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ /* * linux/net/sunrpc/gss_rpc_upcall.h * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #ifndef _GSS_RPC_UPCALL_H @@ -45,4 +32,5 @@ void gssp_free_upcall_data(struct gssp_upcall_data *data); void init_gssp_clnt(struct sunrpc_net *); int set_gssp_clnt(struct net *); void clear_gssp_clnt(struct sunrpc_net *); + #endif /* _GSS_RPC_UPCALL_H */ diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c index 357f613df7ff..7d2cdc2bd374 100644 --- a/net/sunrpc/auth_gss/gss_rpc_xdr.c +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c @@ -1,21 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ /* * GSS Proxy upcall module * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #include <linux/sunrpc/svcauth.h> @@ -44,7 +31,7 @@ static int gssx_dec_bool(struct xdr_stream *xdr, u32 *v) } static int gssx_enc_buffer(struct xdr_stream *xdr, - gssx_buffer *buf) + const gssx_buffer *buf) { __be32 *p; @@ -56,7 +43,7 @@ static int gssx_enc_buffer(struct xdr_stream *xdr, } static int gssx_enc_in_token(struct xdr_stream *xdr, - struct gssp_in_token *in) + const struct gssp_in_token *in) { __be32 *p; @@ -130,7 +117,7 @@ static int gssx_dec_option(struct xdr_stream *xdr, } static int dummy_enc_opt_array(struct xdr_stream *xdr, - struct gssx_option_array *oa) + const struct gssx_option_array *oa) { __be32 *p; @@ -166,14 +153,15 @@ static int dummy_dec_opt_array(struct xdr_stream *xdr, return 0; } -static int get_s32(void **p, void *max, s32 *res) +static int get_host_u32(struct xdr_stream *xdr, u32 *res) { - void *base = *p; - void *next = (void *)((char *)base + sizeof(s32)); - if (unlikely(next > max || next < base)) + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (!p) return -EINVAL; - memcpy(res, base, sizeof(s32)); - *p = next; + /* Contents of linux creds are all host-endian: */ + memcpy(res, p, sizeof(u32)); return 0; } @@ -182,9 +170,9 @@ static int gssx_dec_linux_creds(struct xdr_stream *xdr, { u32 length; __be32 *p; - void *q, *end; - s32 tmp; - int N, i, err; + u32 tmp; + u32 N; + int i, err; p = xdr_inline_decode(xdr, 4); if (unlikely(p == NULL)) @@ -192,33 +180,28 @@ static int gssx_dec_linux_creds(struct xdr_stream *xdr, length = be32_to_cpup(p); - /* FIXME: we do not want to use the scratch buffer for this one - * may need to use functions that allows us to access an io vector - * directly */ - p = xdr_inline_decode(xdr, length); - if (unlikely(p == NULL)) + if (length > (3 + NGROUPS_MAX) * sizeof(u32)) return -ENOSPC; - q = p; - end = q + length; - /* uid */ - err = get_s32(&q, end, &tmp); + err = get_host_u32(xdr, &tmp); if (err) return err; creds->cr_uid = make_kuid(&init_user_ns, tmp); /* gid */ - err = get_s32(&q, end, &tmp); + err = get_host_u32(xdr, &tmp); if (err) return err; creds->cr_gid = make_kgid(&init_user_ns, tmp); /* number of additional gid's */ - err = get_s32(&q, end, &tmp); + err = get_host_u32(xdr, &tmp); if (err) return err; N = tmp; + if ((3 + N) * sizeof(u32) != length) + return -EINVAL; creds->cr_group_info = groups_alloc(N); if (creds->cr_group_info == NULL) return -ENOMEM; @@ -226,15 +209,16 @@ static int gssx_dec_linux_creds(struct xdr_stream *xdr, /* gid's */ for (i = 0; i < N; i++) { kgid_t kgid; - err = get_s32(&q, end, &tmp); + err = get_host_u32(xdr, &tmp); if (err) goto out_free_groups; err = -EINVAL; kgid = make_kgid(&init_user_ns, tmp); if (!gid_valid(kgid)) goto out_free_groups; - GROUP_AT(creds->cr_group_info, i) = kgid; + creds->cr_group_info->gid[i] = kgid; } + groups_sort(creds->cr_group_info); return 0; out_free_groups: @@ -264,10 +248,10 @@ static int gssx_dec_option_array(struct xdr_stream *xdr, if (!oa->data) return -ENOMEM; - creds = kmalloc(sizeof(struct svc_cred), GFP_KERNEL); + creds = kzalloc(sizeof(struct svc_cred), GFP_KERNEL); if (!creds) { - kfree(oa->data); - return -ENOMEM; + err = -ENOMEM; + goto free_oa; } oa->data[0].option.data = CREDS_VALUE; @@ -281,29 +265,40 @@ static int gssx_dec_option_array(struct xdr_stream *xdr, /* option buffer */ p = xdr_inline_decode(xdr, 4); - if (unlikely(p == NULL)) - return -ENOSPC; + if (unlikely(p == NULL)) { + err = -ENOSPC; + goto free_creds; + } length = be32_to_cpup(p); p = xdr_inline_decode(xdr, length); - if (unlikely(p == NULL)) - return -ENOSPC; + if (unlikely(p == NULL)) { + err = -ENOSPC; + goto free_creds; + } if (length == sizeof(CREDS_VALUE) && memcmp(p, CREDS_VALUE, sizeof(CREDS_VALUE)) == 0) { /* We have creds here. parse them */ err = gssx_dec_linux_creds(xdr, creds); if (err) - return err; + goto free_creds; oa->data[0].value.len = 1; /* presence */ } else { /* consume uninteresting buffer */ err = gssx_dec_buffer(xdr, &dummy); if (err) - return err; + goto free_creds; } } return 0; + +free_creds: + kfree(creds); +free_oa: + kfree(oa->data); + oa->data = NULL; + return err; } static int gssx_dec_status(struct xdr_stream *xdr, @@ -352,7 +347,7 @@ static int gssx_dec_status(struct xdr_stream *xdr, } static int gssx_enc_call_ctx(struct xdr_stream *xdr, - struct gssx_call_ctx *ctx) + const struct gssx_call_ctx *ctx) { struct gssx_option opt; __be32 *p; @@ -430,7 +425,7 @@ static int dummy_enc_nameattr_array(struct xdr_stream *xdr, static int dummy_dec_nameattr_array(struct xdr_stream *xdr, struct gssx_name_attr_array *naa) { - struct gssx_name_attr dummy; + struct gssx_name_attr dummy = { .attr = {.len = 0} }; u32 count, i; __be32 *p; @@ -493,12 +488,13 @@ static int gssx_enc_name(struct xdr_stream *xdr, return err; } + static int gssx_dec_name(struct xdr_stream *xdr, struct gssx_name *name) { - struct xdr_netobj dummy_netobj; - struct gssx_name_attr_array dummy_name_attr_array; - struct gssx_option_array dummy_option_array; + struct xdr_netobj dummy_netobj = { .len = 0 }; + struct gssx_name_attr_array dummy_name_attr_array = { .count = 0 }; + struct gssx_option_array dummy_option_array = { .count = 0 }; int err; /* name->display_name */ @@ -562,6 +558,8 @@ static int gssx_enc_cred(struct xdr_stream *xdr, /* cred->elements */ err = dummy_enc_credel_array(xdr, &cred->elements); + if (err) + return err; /* cred->cred_handle_reference */ err = gssx_enc_buffer(xdr, &cred->cred_handle_reference); @@ -734,8 +732,9 @@ static int gssx_enc_cb(struct xdr_stream *xdr, struct gssx_cb *cb) void gssx_enc_accept_sec_context(struct rpc_rqst *req, struct xdr_stream *xdr, - struct gssx_arg_accept_sec_context *arg) + const void *data) { + const struct gssx_arg_accept_sec_context *arg = data; int err; err = gssx_enc_call_ctx(xdr, &arg->call_ctx); @@ -743,22 +742,20 @@ void gssx_enc_accept_sec_context(struct rpc_rqst *req, goto done; /* arg->context_handle */ - if (arg->context_handle) { + if (arg->context_handle) err = gssx_enc_ctx(xdr, arg->context_handle); - if (err) - goto done; - } else { + else err = gssx_enc_bool(xdr, 0); - } + if (err) + goto done; /* arg->cred_handle */ - if (arg->cred_handle) { + if (arg->cred_handle) err = gssx_enc_cred(xdr, arg->cred_handle); - if (err) - goto done; - } else { + else err = gssx_enc_bool(xdr, 0); - } + if (err) + goto done; /* arg->input_token */ err = gssx_enc_in_token(xdr, &arg->input_token); @@ -766,13 +763,12 @@ void gssx_enc_accept_sec_context(struct rpc_rqst *req, goto done; /* arg->input_cb */ - if (arg->input_cb) { + if (arg->input_cb) err = gssx_enc_cb(xdr, arg->input_cb); - if (err) - goto done; - } else { + else err = gssx_enc_bool(xdr, 0); - } + if (err) + goto done; err = gssx_enc_bool(xdr, arg->ret_deleg_cred); if (err) @@ -783,6 +779,9 @@ void gssx_enc_accept_sec_context(struct rpc_rqst *req, /* arg->options */ err = dummy_enc_opt_array(xdr, &arg->options); + xdr_inline_pages(&req->rq_rcv_buf, + PAGE_SIZE/2 /* pretty arbitrary */, + arg->pages, 0 /* page base */, arg->npages * PAGE_SIZE); done: if (err) dprintk("RPC: gssx_enc_accept_sec_context: %d\n", err); @@ -790,24 +789,31 @@ done: int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, struct xdr_stream *xdr, - struct gssx_res_accept_sec_context *res) + void *data) { + struct gssx_res_accept_sec_context *res = data; u32 value_follows; int err; + struct folio *scratch; + + scratch = folio_alloc(GFP_KERNEL, 0); + if (!scratch) + return -ENOMEM; + xdr_set_scratch_folio(xdr, scratch); /* res->status */ err = gssx_dec_status(xdr, &res->status); if (err) - return err; + goto out_free; /* res->context_handle */ err = gssx_dec_bool(xdr, &value_follows); if (err) - return err; + goto out_free; if (value_follows) { err = gssx_dec_ctx(xdr, res->context_handle); if (err) - return err; + goto out_free; } else { res->context_handle = NULL; } @@ -815,11 +821,11 @@ int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, /* res->output_token */ err = gssx_dec_bool(xdr, &value_follows); if (err) - return err; + goto out_free; if (value_follows) { err = gssx_dec_buffer(xdr, res->output_token); if (err) - return err; + goto out_free; } else { res->output_token = NULL; } @@ -827,14 +833,17 @@ int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, /* res->delegated_cred_handle */ err = gssx_dec_bool(xdr, &value_follows); if (err) - return err; + goto out_free; if (value_follows) { /* we do not support upcall servers sending this data. */ - return -EINVAL; + err = -EINVAL; + goto out_free; } /* res->options */ err = gssx_dec_option_array(xdr, &res->options); +out_free: + folio_put(scratch); return err; } diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.h b/net/sunrpc/auth_gss/gss_rpc_xdr.h index 1c98b27d870c..3f17411b7e65 100644 --- a/net/sunrpc/auth_gss/gss_rpc_xdr.h +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.h @@ -1,21 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ /* * GSS Proxy upcall module * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #ifndef _LINUX_GSS_RPC_XDR_H @@ -25,7 +12,7 @@ #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/xprtsock.h> -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif @@ -147,6 +134,8 @@ struct gssx_arg_accept_sec_context { struct gssx_cb *input_cb; u32 ret_deleg_cred; struct gssx_option_array options; + struct page **pages; + unsigned int npages; }; struct gssx_res_accept_sec_context { @@ -177,10 +166,10 @@ struct gssx_res_accept_sec_context { #define gssx_dec_init_sec_context NULL void gssx_enc_accept_sec_context(struct rpc_rqst *req, struct xdr_stream *xdr, - struct gssx_arg_accept_sec_context *args); + const void *data); int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, struct xdr_stream *xdr, - struct gssx_res_accept_sec_context *res); + void *data); #define gssx_enc_release_handle NULL #define gssx_dec_release_handle NULL #define gssx_enc_get_mic NULL @@ -240,7 +229,8 @@ int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, 2 * GSSX_max_princ_sz + \ 8 + 8 + 4 + 4 + 4) #define GSSX_max_output_token_sz 1024 -#define GSSX_max_creds_sz (4 + 4 + 4 + NGROUPS_MAX * 4) +/* grouplist not included; we allocate separate pages for that: */ +#define GSSX_max_creds_sz (4 + 4 + 4 /* + NGROUPS_MAX*4 */) #define GSSX_RES_accept_sec_context_sz (GSSX_default_status_sz + \ GSSX_default_ctx_sz + \ GSSX_max_output_token_sz + \ @@ -259,6 +249,4 @@ int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, #define GSSX_ARG_wrap_size_limit_sz 0 #define GSSX_RES_wrap_size_limit_sz 0 - - #endif /* _LINUX_GSS_RPC_XDR_H */ diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c index d0347d148b34..a8ec30759a18 100644 --- a/net/sunrpc/auth_gss/svcauth_gss.c +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0 /* * Neil Brown <neilb@cse.unsw.edu.au> * J. Bruce Fields <bfields@umich.edu> @@ -48,12 +49,35 @@ #include <linux/sunrpc/svcauth.h> #include <linux/sunrpc/svcauth_gss.h> #include <linux/sunrpc/cache.h> +#include <linux/sunrpc/gss_krb5.h> + +#include <trace/events/rpcgss.h> + #include "gss_rpc_upcall.h" +/* + * Unfortunately there isn't a maximum checksum size exported via the + * GSS API. Manufacture one based on GSS mechanisms supported by this + * implementation. + */ +#define GSS_MAX_CKSUMSIZE (GSS_KRB5_TOK_HDR_LEN + GSS_KRB5_MAX_CKSUM_LEN) + +/* + * This value may be increased in the future to accommodate other + * usage of the scratch buffer. + */ +#define GSS_SCRATCH_SIZE GSS_MAX_CKSUMSIZE + +struct gss_svc_data { + /* decoded gss client cred: */ + struct rpc_gss_wire_cred clcred; + u32 gsd_databody_offset; + struct rsc *rsci; -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif + /* for temporary results */ + __be32 gsd_seq_num; + u8 gsd_scratch[GSS_SCRATCH_SIZE]; +}; /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests * into replies. @@ -76,6 +100,7 @@ struct rsi { struct xdr_netobj in_handle, in_token; struct xdr_netobj out_handle, out_token; int major_status, minor_status; + struct rcu_head rcu_head; }; static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old); @@ -89,13 +114,21 @@ static void rsi_free(struct rsi *rsii) kfree(rsii->out_token.data); } -static void rsi_put(struct kref *ref) +static void rsi_free_rcu(struct rcu_head *head) { - struct rsi *rsii = container_of(ref, struct rsi, h.ref); + struct rsi *rsii = container_of(head, struct rsi, rcu_head); + rsi_free(rsii); kfree(rsii); } +static void rsi_put(struct kref *ref) +{ + struct rsi *rsii = container_of(ref, struct rsi, h.ref); + + call_rcu(&rsii->rcu_head, rsi_free_rcu); +} + static inline int rsi_hash(struct rsi *item) { return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS) @@ -171,6 +204,11 @@ static struct cache_head *rsi_alloc(void) return NULL; } +static int rsi_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return sunrpc_cache_pipe_upcall_timeout(cd, h); +} + static void rsi_request(struct cache_detail *cd, struct cache_head *h, char **bpp, int *blen) @@ -180,6 +218,8 @@ static void rsi_request(struct cache_detail *cd, qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len); qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len); (*bpp)[-1] = '\n'; + WARN_ONCE(*blen < 0, + "RPCSEC/GSS credential too large - please use gssproxy\n"); } static int rsi_parse(struct cache_detail *cd, @@ -190,7 +230,7 @@ static int rsi_parse(struct cache_detail *cd, char *ep; int len; struct rsi rsii, *rsip = NULL; - time_t expiry; + time64_t expiry; int status = -EINVAL; memset(&rsii, 0, sizeof(rsii)); @@ -217,11 +257,11 @@ static int rsi_parse(struct cache_detail *cd, rsii.h.flags = 0; /* expiry */ - expiry = get_expiry(&mesg); - status = -EINVAL; - if (expiry == 0) + status = get_expiry(&mesg, &expiry); + if (status) goto out; + status = -EINVAL; /* major/minor */ len = qword_get(&mesg, buf, mlen); if (len <= 0) @@ -264,11 +304,12 @@ out: return status; } -static struct cache_detail rsi_cache_template = { +static const struct cache_detail rsi_cache_template = { .owner = THIS_MODULE, .hash_size = RSI_HASHMAX, .name = "auth.rpcsec.init", .cache_put = rsi_put, + .cache_upcall = rsi_upcall, .cache_request = rsi_request, .cache_parse = rsi_parse, .match = rsi_match, @@ -282,7 +323,7 @@ static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item) struct cache_head *ch; int hash = rsi_hash(item); - ch = sunrpc_cache_lookup(cd, &item->h, hash); + ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash); if (ch) return container_of(ch, struct rsi, h); else @@ -317,7 +358,7 @@ static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct r struct gss_svc_seq_data { /* highest seq number seen so far: */ - int sd_max; + u32 sd_max; /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of * sd_win is nonzero iff sequence number i has been seen already: */ unsigned long sd_win[GSS_SEQ_WIN/BITS_PER_LONG]; @@ -330,6 +371,7 @@ struct rsc { struct svc_cred cred; struct gss_svc_seq_data seqdata; struct gss_ctx *mechctx; + struct rcu_head rcu_head; }; static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old); @@ -343,12 +385,22 @@ static void rsc_free(struct rsc *rsci) free_svc_cred(&rsci->cred); } +static void rsc_free_rcu(struct rcu_head *head) +{ + struct rsc *rsci = container_of(head, struct rsc, rcu_head); + + kfree(rsci->handle.data); + kfree(rsci); +} + static void rsc_put(struct kref *ref) { struct rsc *rsci = container_of(ref, struct rsc, h.ref); - rsc_free(rsci); - kfree(rsci); + if (rsci->mechctx) + gss_delete_sec_context(&rsci->mechctx); + free_svc_cred(&rsci->cred); + call_rcu(&rsci->rcu_head, rsc_free_rcu); } static inline int @@ -404,6 +456,11 @@ rsc_alloc(void) return NULL; } +static int rsc_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return -EINVAL; +} + static int rsc_parse(struct cache_detail *cd, char *mesg, int mlen) { @@ -412,7 +469,7 @@ static int rsc_parse(struct cache_detail *cd, int id; int len, rv; struct rsc rsci, *rscp = NULL; - time_t expiry; + time64_t expiry; int status = -EINVAL; struct gss_api_mech *gm = NULL; @@ -426,11 +483,11 @@ static int rsc_parse(struct cache_detail *cd, rsci.h.flags = 0; /* expiry */ - expiry = get_expiry(&mesg); - status = -EINVAL; - if (expiry == 0) + status = get_expiry(&mesg, &expiry); + if (status) goto out; + status = -EINVAL; rscp = rsc_lookup(cd, &rsci); if (!rscp) goto out; @@ -453,16 +510,18 @@ static int rsc_parse(struct cache_detail *cd, * treatment so are checked for validity here.) */ /* uid */ - rsci.cred.cr_uid = make_kuid(&init_user_ns, id); + rsci.cred.cr_uid = make_kuid(current_user_ns(), id); /* gid */ if (get_int(&mesg, &id)) goto out; - rsci.cred.cr_gid = make_kgid(&init_user_ns, id); + rsci.cred.cr_gid = make_kgid(current_user_ns(), id); /* number of additional gid's */ if (get_int(&mesg, &N)) goto out; + if (N < 0 || N > NGROUPS_MAX) + goto out; status = -ENOMEM; rsci.cred.cr_group_info = groups_alloc(N); if (rsci.cred.cr_group_info == NULL) @@ -474,11 +533,12 @@ static int rsc_parse(struct cache_detail *cd, kgid_t kgid; if (get_int(&mesg, &id)) goto out; - kgid = make_kgid(&init_user_ns, id); + kgid = make_kgid(current_user_ns(), id); if (!gid_valid(kgid)) goto out; - GROUP_AT(rsci.cred.cr_group_info, i) = kgid; + rsci.cred.cr_group_info->gid[i] = kgid; } + groups_sort(rsci.cred.cr_group_info); /* mech name */ len = qword_get(&mesg, buf, mlen); @@ -522,11 +582,12 @@ out: return status; } -static struct cache_detail rsc_cache_template = { +static const struct cache_detail rsc_cache_template = { .owner = THIS_MODULE, .hash_size = RSC_HASHMAX, .name = "auth.rpcsec.context", .cache_put = rsc_put, + .cache_upcall = rsc_upcall, .cache_parse = rsc_parse, .match = rsc_match, .init = rsc_init, @@ -539,7 +600,7 @@ static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item) struct cache_head *ch; int hash = rsc_hash(item); - ch = sunrpc_cache_lookup(cd, &item->h, hash); + ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash); if (ch) return container_of(ch, struct rsc, h); else @@ -578,16 +639,29 @@ gss_svc_searchbyctx(struct cache_detail *cd, struct xdr_netobj *handle) return found; } -/* Implements sequence number algorithm as specified in RFC 2203. */ -static int -gss_check_seq_num(struct rsc *rsci, int seq_num) +/** + * gss_check_seq_num - GSS sequence number window check + * @rqstp: RPC Call to use when reporting errors + * @rsci: cached GSS context state (updated on return) + * @seq_num: sequence number to check + * + * Implements sequence number algorithm as specified in + * RFC 2203, Section 5.3.3.1. "Context Management". + * + * Return values: + * %true: @rqstp's GSS sequence number is inside the window + * %false: @rqstp's GSS sequence number is outside the window + */ +static bool gss_check_seq_num(const struct svc_rqst *rqstp, struct rsc *rsci, + u32 seq_num) { struct gss_svc_seq_data *sd = &rsci->seqdata; + bool result = false; spin_lock(&sd->sd_lock); if (seq_num > sd->sd_max) { if (seq_num >= sd->sd_max + GSS_SEQ_WIN) { - memset(sd->sd_win,0,sizeof(sd->sd_win)); + memset(sd->sd_win, 0, sizeof(sd->sd_win)); sd->sd_max = seq_num; } else while (sd->sd_max < seq_num) { sd->sd_max++; @@ -595,151 +669,115 @@ gss_check_seq_num(struct rsc *rsci, int seq_num) } __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win); goto ok; - } else if (seq_num <= sd->sd_max - GSS_SEQ_WIN) { - goto drop; + } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) { + goto toolow; } - /* sd_max - GSS_SEQ_WIN < seq_num <= sd_max */ if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) - goto drop; + goto alreadyseen; + ok: + result = true; +out: spin_unlock(&sd->sd_lock); - return 1; -drop: - spin_unlock(&sd->sd_lock); - return 0; -} - -static inline u32 round_up_to_quad(u32 i) -{ - return (i + 3 ) & ~3; -} - -static inline int -svc_safe_getnetobj(struct kvec *argv, struct xdr_netobj *o) -{ - int l; - - if (argv->iov_len < 4) - return -1; - o->len = svc_getnl(argv); - l = round_up_to_quad(o->len); - if (argv->iov_len < l) - return -1; - o->data = argv->iov_base; - argv->iov_base += l; - argv->iov_len -= l; - return 0; -} + return result; -static inline int -svc_safe_putnetobj(struct kvec *resv, struct xdr_netobj *o) -{ - u8 *p; - - if (resv->iov_len + 4 > PAGE_SIZE) - return -1; - svc_putnl(resv, o->len); - p = resv->iov_base + resv->iov_len; - resv->iov_len += round_up_to_quad(o->len); - if (resv->iov_len > PAGE_SIZE) - return -1; - memcpy(p, o->data, o->len); - memset(p + o->len, 0, round_up_to_quad(o->len) - o->len); - return 0; +toolow: + trace_rpcgss_svc_seqno_low(rqstp, seq_num, + sd->sd_max - GSS_SEQ_WIN, + sd->sd_max); + goto out; +alreadyseen: + trace_rpcgss_svc_seqno_seen(rqstp, seq_num); + goto out; } /* - * Verify the checksum on the header and return SVC_OK on success. - * Otherwise, return SVC_DROP (in the case of a bad sequence number) - * or return SVC_DENIED and indicate error in authp. + * Decode and verify a Call's verifier field. For RPC_AUTH_GSS Calls, + * the body of this field contains a variable length checksum. + * + * GSS-specific auth_stat values are mandated by RFC 2203 Section + * 5.3.3.3. */ static int -gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci, - __be32 *rpcstart, struct rpc_gss_wire_cred *gc, __be32 *authp) +svcauth_gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci, + __be32 *rpcstart, struct rpc_gss_wire_cred *gc) { + struct xdr_stream *xdr = &rqstp->rq_arg_stream; struct gss_ctx *ctx_id = rsci->mechctx; + u32 flavor, maj_stat; struct xdr_buf rpchdr; struct xdr_netobj checksum; - u32 flavor = 0; - struct kvec *argv = &rqstp->rq_arg.head[0]; struct kvec iov; - /* data to compute the checksum over: */ + /* + * Compute the checksum of the incoming Call from the + * XID field to credential field: + */ iov.iov_base = rpcstart; - iov.iov_len = (u8 *)argv->iov_base - (u8 *)rpcstart; + iov.iov_len = (u8 *)xdr->p - (u8 *)rpcstart; xdr_buf_from_iov(&iov, &rpchdr); - *authp = rpc_autherr_badverf; - if (argv->iov_len < 4) + /* Call's verf field: */ + if (xdr_stream_decode_opaque_auth(xdr, &flavor, + (void **)&checksum.data, + &checksum.len) < 0) { + rqstp->rq_auth_stat = rpc_autherr_badverf; return SVC_DENIED; - flavor = svc_getnl(argv); - if (flavor != RPC_AUTH_GSS) - return SVC_DENIED; - if (svc_safe_getnetobj(argv, &checksum)) + } + if (flavor != RPC_AUTH_GSS || checksum.len < XDR_UNIT) { + rqstp->rq_auth_stat = rpc_autherr_badverf; return SVC_DENIED; + } - if (rqstp->rq_deferred) /* skip verification of revisited request */ + if (rqstp->rq_deferred) return SVC_OK; - if (gss_verify_mic(ctx_id, &rpchdr, &checksum) != GSS_S_COMPLETE) { - *authp = rpcsec_gsserr_credproblem; + maj_stat = gss_verify_mic(ctx_id, &rpchdr, &checksum); + if (maj_stat != GSS_S_COMPLETE) { + trace_rpcgss_svc_mic(rqstp, maj_stat); + rqstp->rq_auth_stat = rpcsec_gsserr_credproblem; return SVC_DENIED; } if (gc->gc_seq > MAXSEQ) { - dprintk("RPC: svcauth_gss: discarding request with " - "large sequence number %d\n", gc->gc_seq); - *authp = rpcsec_gsserr_ctxproblem; + trace_rpcgss_svc_seqno_large(rqstp, gc->gc_seq); + rqstp->rq_auth_stat = rpcsec_gsserr_ctxproblem; return SVC_DENIED; } - if (!gss_check_seq_num(rsci, gc->gc_seq)) { - dprintk("RPC: svcauth_gss: discarding request with " - "old sequence number %d\n", gc->gc_seq); + if (!gss_check_seq_num(rqstp, rsci, gc->gc_seq)) return SVC_DROP; - } return SVC_OK; } -static int -gss_write_null_verf(struct svc_rqst *rqstp) -{ - __be32 *p; - - svc_putnl(rqstp->rq_res.head, RPC_AUTH_NULL); - p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; - /* don't really need to check if head->iov_len > PAGE_SIZE ... */ - *p++ = 0; - if (!xdr_ressize_check(rqstp, p)) - return -1; - return 0; -} - -static int -gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq) +/* + * Construct and encode a Reply's verifier field. The verifier's body + * field contains a variable-length checksum of the GSS sequence + * number. + */ +static bool +svcauth_gss_encode_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq) { - __be32 xdr_seq; + struct gss_svc_data *gsd = rqstp->rq_auth_data; u32 maj_stat; struct xdr_buf verf_data; - struct xdr_netobj mic; - __be32 *p; + struct xdr_netobj checksum; struct kvec iov; - svc_putnl(rqstp->rq_res.head, RPC_AUTH_GSS); - xdr_seq = htonl(seq); - - iov.iov_base = &xdr_seq; - iov.iov_len = sizeof(xdr_seq); + gsd->gsd_seq_num = cpu_to_be32(seq); + iov.iov_base = &gsd->gsd_seq_num; + iov.iov_len = XDR_UNIT; xdr_buf_from_iov(&iov, &verf_data); - p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; - mic.data = (u8 *)(p + 1); - maj_stat = gss_get_mic(ctx_id, &verf_data, &mic); + + checksum.data = gsd->gsd_scratch; + maj_stat = gss_get_mic(ctx_id, &verf_data, &checksum); if (maj_stat != GSS_S_COMPLETE) - return -1; - *p++ = htonl(mic.len); - memset((u8 *)p + mic.len, 0, round_up_to_quad(mic.len) - mic.len); - p += XDR_QUADLEN(mic.len); - if (!xdr_ressize_check(rqstp, p)) - return -1; - return 0; + goto bad_mic; + + return xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream, RPC_AUTH_GSS, + checksum.data, checksum.len) > 0; + +bad_mic: + trace_rpcgss_svc_get_mic(rqstp, maj_stat); + return false; } struct gss_domain { @@ -769,7 +807,7 @@ u32 svcauth_gss_flavor(struct auth_domain *dom) EXPORT_SYMBOL_GPL(svcauth_gss_flavor); -int +struct auth_domain * svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name) { struct gss_domain *new; @@ -786,161 +824,159 @@ svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name) new->h.flavour = &svcauthops_gss; new->pseudoflavor = pseudoflavor; - stat = 0; test = auth_domain_lookup(name, &new->h); - if (test != &new->h) { /* Duplicate registration */ + if (test != &new->h) { + pr_warn("svc: duplicate registration of gss pseudo flavour %s.\n", + name); + stat = -EADDRINUSE; auth_domain_put(test); - kfree(new->h.name); - goto out_free_dom; + goto out_free_name; } - return 0; + return test; +out_free_name: + kfree(new->h.name); out_free_dom: kfree(new); out: - return stat; + return ERR_PTR(stat); } - EXPORT_SYMBOL_GPL(svcauth_gss_register_pseudoflavor); -static inline int -read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj) -{ - __be32 raw; - int status; - - status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); - if (status) - return status; - *obj = ntohl(raw); - return 0; -} - -/* It would be nice if this bit of code could be shared with the client. - * Obstacles: - * The client shouldn't malloc(), would have to pass in own memory. - * The server uses base of head iovec as read pointer, while the - * client uses separate pointer. */ -static int -unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +/* + * RFC 2203, Section 5.3.2.2 + * + * struct rpc_gss_integ_data { + * opaque databody_integ<>; + * opaque checksum<>; + * }; + * + * struct rpc_gss_data_t { + * unsigned int seq_num; + * proc_req_arg_t arg; + * }; + */ +static noinline_for_stack int +svcauth_gss_unwrap_integ(struct svc_rqst *rqstp, u32 seq, struct gss_ctx *ctx) { - int stat = -EINVAL; - u32 integ_len, maj_stat; - struct xdr_netobj mic; - struct xdr_buf integ_buf; + struct gss_svc_data *gsd = rqstp->rq_auth_data; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; + u32 len, offset, seq_num, maj_stat; + struct xdr_buf *buf = xdr->buf; + struct xdr_buf databody_integ; + struct xdr_netobj checksum; /* Did we already verify the signature on the original pass through? */ if (rqstp->rq_deferred) return 0; - integ_len = svc_getnl(&buf->head[0]); - if (integ_len & 3) - return stat; - if (integ_len > buf->len) - return stat; - if (xdr_buf_subsegment(buf, &integ_buf, 0, integ_len)) - BUG(); - /* copy out mic... */ - if (read_u32_from_xdr_buf(buf, integ_len, &mic.len)) - BUG(); - if (mic.len > RPC_MAX_AUTH_SIZE) - return stat; - mic.data = kmalloc(mic.len, GFP_KERNEL); - if (!mic.data) - return stat; - if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len)) - goto out; - maj_stat = gss_verify_mic(ctx, &integ_buf, &mic); + if (xdr_stream_decode_u32(xdr, &len) < 0) + goto unwrap_failed; + if (len & 3) + goto unwrap_failed; + offset = xdr_stream_pos(xdr); + if (xdr_buf_subsegment(buf, &databody_integ, offset, len)) + goto unwrap_failed; + + /* + * The xdr_stream now points to the @seq_num field. The next + * XDR data item is the @arg field, which contains the clear + * text RPC program payload. The checksum, which follows the + * @arg field, is located and decoded without updating the + * xdr_stream. + */ + + offset += len; + if (xdr_decode_word(buf, offset, &checksum.len)) + goto unwrap_failed; + if (checksum.len > sizeof(gsd->gsd_scratch)) + goto unwrap_failed; + checksum.data = gsd->gsd_scratch; + if (read_bytes_from_xdr_buf(buf, offset + XDR_UNIT, checksum.data, + checksum.len)) + goto unwrap_failed; + + maj_stat = gss_verify_mic(ctx, &databody_integ, &checksum); if (maj_stat != GSS_S_COMPLETE) - goto out; - if (svc_getnl(&buf->head[0]) != seq) - goto out; - /* trim off the mic at the end before returning */ - xdr_buf_trim(buf, mic.len + 4); - stat = 0; -out: - kfree(mic.data); - return stat; -} + goto bad_mic; -static inline int -total_buf_len(struct xdr_buf *buf) -{ - return buf->head[0].iov_len + buf->page_len + buf->tail[0].iov_len; -} + /* The received seqno is protected by the checksum. */ + if (xdr_stream_decode_u32(xdr, &seq_num) < 0) + goto unwrap_failed; + if (seq_num != seq) + goto bad_seqno; -static void -fix_priv_head(struct xdr_buf *buf, int pad) -{ - if (buf->page_len == 0) { - /* We need to adjust head and buf->len in tandem in this - * case to make svc_defer() work--it finds the original - * buffer start using buf->len - buf->head[0].iov_len. */ - buf->head[0].iov_len -= pad; - } + xdr_truncate_decode(xdr, XDR_UNIT + checksum.len); + return 0; + +unwrap_failed: + trace_rpcgss_svc_unwrap_failed(rqstp); + return -EINVAL; +bad_seqno: + trace_rpcgss_svc_seqno_bad(rqstp, seq, seq_num); + return -EINVAL; +bad_mic: + trace_rpcgss_svc_mic(rqstp, maj_stat); + return -EINVAL; } -static int -unwrap_priv_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +/* + * RFC 2203, Section 5.3.2.3 + * + * struct rpc_gss_priv_data { + * opaque databody_priv<> + * }; + * + * struct rpc_gss_data_t { + * unsigned int seq_num; + * proc_req_arg_t arg; + * }; + */ +static noinline_for_stack int +svcauth_gss_unwrap_priv(struct svc_rqst *rqstp, u32 seq, struct gss_ctx *ctx) { - u32 priv_len, maj_stat; - int pad, saved_len, remaining_len, offset; - - rqstp->rq_splice_ok = 0; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; + u32 len, maj_stat, seq_num, offset; + struct xdr_buf *buf = xdr->buf; + unsigned int saved_len; - priv_len = svc_getnl(&buf->head[0]); + if (xdr_stream_decode_u32(xdr, &len) < 0) + goto unwrap_failed; if (rqstp->rq_deferred) { /* Already decrypted last time through! The sequence number * check at out_seq is unnecessary but harmless: */ goto out_seq; } - /* buf->len is the number of bytes from the original start of the - * request to the end, where head[0].iov_len is just the bytes - * not yet read from the head, so these two values are different: */ - remaining_len = total_buf_len(buf); - if (priv_len > remaining_len) - return -EINVAL; - pad = remaining_len - priv_len; - buf->len -= pad; - fix_priv_head(buf, pad); + if (len > xdr_stream_remaining(xdr)) + goto unwrap_failed; + offset = xdr_stream_pos(xdr); - /* Maybe it would be better to give gss_unwrap a length parameter: */ saved_len = buf->len; - buf->len = priv_len; - maj_stat = gss_unwrap(ctx, 0, buf); - pad = priv_len - buf->len; - buf->len = saved_len; - buf->len -= pad; - /* The upper layers assume the buffer is aligned on 4-byte boundaries. - * In the krb5p case, at least, the data ends up offset, so we need to - * move it around. */ - /* XXX: This is very inefficient. It would be better to either do - * this while we encrypt, or maybe in the receive code, if we can peak - * ahead and work out the service and mechanism there. */ - offset = buf->head[0].iov_len % 4; - if (offset) { - buf->buflen = RPCSVC_MAXPAYLOAD; - xdr_shift_buf(buf, offset); - fix_priv_head(buf, pad); - } + maj_stat = gss_unwrap(ctx, offset, offset + len, buf); if (maj_stat != GSS_S_COMPLETE) - return -EINVAL; + goto bad_unwrap; + xdr->nwords -= XDR_QUADLEN(saved_len - buf->len); + out_seq: - if (svc_getnl(&buf->head[0]) != seq) - return -EINVAL; + /* gss_unwrap() decrypted the sequence number. */ + if (xdr_stream_decode_u32(xdr, &seq_num) < 0) + goto unwrap_failed; + if (seq_num != seq) + goto bad_seqno; return 0; -} -struct gss_svc_data { - /* decoded gss client cred: */ - struct rpc_gss_wire_cred clcred; - /* save a pointer to the beginning of the encoded verifier, - * for use in encryption/checksumming in svcauth_gss_release: */ - __be32 *verf_start; - struct rsc *rsci; -}; +unwrap_failed: + trace_rpcgss_svc_unwrap_failed(rqstp); + return -EINVAL; +bad_seqno: + trace_rpcgss_svc_seqno_bad(rqstp, seq, seq_num); + return -EINVAL; +bad_unwrap: + trace_rpcgss_svc_unwrap(rqstp, maj_stat); + return -EINVAL; +} -static int +static enum svc_auth_status svcauth_gss_set_client(struct svc_rqst *rqstp) { struct gss_svc_data *svcdata = rqstp->rq_auth_data; @@ -948,6 +984,8 @@ svcauth_gss_set_client(struct svc_rqst *rqstp) struct rpc_gss_wire_cred *gc = &svcdata->clcred; int stat; + rqstp->rq_auth_stat = rpc_autherr_badcred; + /* * A gss export can be specified either by: * export *(sec=krb5,rw) @@ -963,129 +1001,146 @@ svcauth_gss_set_client(struct svc_rqst *rqstp) stat = svcauth_unix_set_client(rqstp); if (stat == SVC_DROP || stat == SVC_CLOSE) return stat; + + rqstp->rq_auth_stat = rpc_auth_ok; return SVC_OK; } -static inline int -gss_write_init_verf(struct cache_detail *cd, struct svc_rqst *rqstp, - struct xdr_netobj *out_handle, int *major_status) +static bool +svcauth_gss_proc_init_verf(struct cache_detail *cd, struct svc_rqst *rqstp, + struct xdr_netobj *out_handle, int *major_status, + u32 seq_num) { + struct xdr_stream *xdr = &rqstp->rq_res_stream; struct rsc *rsci; - int rc; + bool rc; if (*major_status != GSS_S_COMPLETE) - return gss_write_null_verf(rqstp); + goto null_verifier; rsci = gss_svc_searchbyctx(cd, out_handle); if (rsci == NULL) { *major_status = GSS_S_NO_CONTEXT; - return gss_write_null_verf(rqstp); + goto null_verifier; } - rc = gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN); + + rc = svcauth_gss_encode_verf(rqstp, rsci->mechctx, seq_num); cache_put(&rsci->h, cd); return rc; -} - -static inline int -gss_read_common_verf(struct rpc_gss_wire_cred *gc, - struct kvec *argv, __be32 *authp, - struct xdr_netobj *in_handle) -{ - /* Read the verifier; should be NULL: */ - *authp = rpc_autherr_badverf; - if (argv->iov_len < 2 * 4) - return SVC_DENIED; - if (svc_getnl(argv) != RPC_AUTH_NULL) - return SVC_DENIED; - if (svc_getnl(argv) != 0) - return SVC_DENIED; - /* Martial context handle and token for upcall: */ - *authp = rpc_autherr_badcred; - if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) - return SVC_DENIED; - if (dup_netobj(in_handle, &gc->gc_ctx)) - return SVC_CLOSE; - *authp = rpc_autherr_badverf; - return 0; +null_verifier: + return xdr_stream_encode_opaque_auth(xdr, RPC_AUTH_NULL, NULL, 0) > 0; } -static inline int -gss_read_verf(struct rpc_gss_wire_cred *gc, - struct kvec *argv, __be32 *authp, - struct xdr_netobj *in_handle, - struct xdr_netobj *in_token) +static void gss_free_in_token_pages(struct gssp_in_token *in_token) { - struct xdr_netobj tmpobj; - int res; + int i; - res = gss_read_common_verf(gc, argv, authp, in_handle); - if (res) - return res; - - if (svc_safe_getnetobj(argv, &tmpobj)) { - kfree(in_handle->data); - return SVC_DENIED; - } - if (dup_netobj(in_token, &tmpobj)) { - kfree(in_handle->data); - return SVC_CLOSE; - } - - return 0; + i = 0; + while (in_token->pages[i]) + put_page(in_token->pages[i++]); + kfree(in_token->pages); + in_token->pages = NULL; } -/* Ok this is really heavily depending on a set of semantics in - * how rqstp is set up by svc_recv and pages laid down by the - * server when reading a request. We are basically guaranteed that - * the token lays all down linearly across a set of pages, starting - * at iov_base in rq_arg.head[0] which happens to be the first of a - * set of pages stored in rq_pages[]. - * rq_arg.head[0].iov_base will provide us the page_base to pass - * to the upcall. - */ -static inline int -gss_read_proxy_verf(struct svc_rqst *rqstp, - struct rpc_gss_wire_cred *gc, __be32 *authp, - struct xdr_netobj *in_handle, - struct gssp_in_token *in_token) +static int gss_read_proxy_verf(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, + struct xdr_netobj *in_handle, + struct gssp_in_token *in_token) { - struct kvec *argv = &rqstp->rq_arg.head[0]; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; + unsigned int length, pgto_offs, pgfrom_offs; + int pages, i, pgto, pgfrom; + size_t to_offs, from_offs; u32 inlen; - int res; - res = gss_read_common_verf(gc, argv, authp, in_handle); - if (res) - return res; - - inlen = svc_getnl(argv); - if (inlen > (argv->iov_len + rqstp->rq_arg.page_len)) - return SVC_DENIED; + if (dup_netobj(in_handle, &gc->gc_ctx)) + return SVC_CLOSE; - in_token->pages = rqstp->rq_pages; - in_token->page_base = (ulong)argv->iov_base & ~PAGE_MASK; + /* + * RFC 2203 Section 5.2.2 + * + * struct rpc_gss_init_arg { + * opaque gss_token<>; + * }; + */ + if (xdr_stream_decode_u32(xdr, &inlen) < 0) + goto out_denied_free; + if (inlen > xdr_stream_remaining(xdr)) + goto out_denied_free; + + pages = DIV_ROUND_UP(inlen, PAGE_SIZE); + in_token->pages = kcalloc(pages + 1, sizeof(struct page *), GFP_KERNEL); + if (!in_token->pages) + goto out_denied_free; + in_token->page_base = 0; in_token->page_len = inlen; + for (i = 0; i < pages; i++) { + in_token->pages[i] = alloc_page(GFP_KERNEL); + if (!in_token->pages[i]) { + gss_free_in_token_pages(in_token); + goto out_denied_free; + } + } + length = min_t(unsigned int, inlen, (char *)xdr->end - (char *)xdr->p); + memcpy(page_address(in_token->pages[0]), xdr->p, length); + inlen -= length; + + to_offs = length; + from_offs = rqstp->rq_arg.page_base; + while (inlen) { + pgto = to_offs >> PAGE_SHIFT; + pgfrom = from_offs >> PAGE_SHIFT; + pgto_offs = to_offs & ~PAGE_MASK; + pgfrom_offs = from_offs & ~PAGE_MASK; + + length = min_t(unsigned int, inlen, + min_t(unsigned int, PAGE_SIZE - pgto_offs, + PAGE_SIZE - pgfrom_offs)); + memcpy(page_address(in_token->pages[pgto]) + pgto_offs, + page_address(rqstp->rq_arg.pages[pgfrom]) + pgfrom_offs, + length); + + to_offs += length; + from_offs += length; + inlen -= length; + } return 0; + +out_denied_free: + kfree(in_handle->data); + return SVC_DENIED; } -static inline int -gss_write_resv(struct kvec *resv, size_t size_limit, - struct xdr_netobj *out_handle, struct xdr_netobj *out_token, - int major_status, int minor_status) -{ - if (resv->iov_len + 4 > size_limit) - return -1; - svc_putnl(resv, RPC_SUCCESS); - if (svc_safe_putnetobj(resv, out_handle)) - return -1; - if (resv->iov_len + 3 * 4 > size_limit) - return -1; - svc_putnl(resv, major_status); - svc_putnl(resv, minor_status); - svc_putnl(resv, GSS_SEQ_WIN); - if (svc_safe_putnetobj(resv, out_token)) - return -1; - return 0; +/* + * RFC 2203, Section 5.2.3.1. + * + * struct rpc_gss_init_res { + * opaque handle<>; + * unsigned int gss_major; + * unsigned int gss_minor; + * unsigned int seq_window; + * opaque gss_token<>; + * }; + */ +static bool +svcxdr_encode_gss_init_res(struct xdr_stream *xdr, + struct xdr_netobj *handle, + struct xdr_netobj *gss_token, + unsigned int major_status, + unsigned int minor_status, u32 seq_num) +{ + if (xdr_stream_encode_opaque(xdr, handle->data, handle->len) < 0) + return false; + if (xdr_stream_encode_u32(xdr, major_status) < 0) + return false; + if (xdr_stream_encode_u32(xdr, minor_status) < 0) + return false; + if (xdr_stream_encode_u32(xdr, seq_num) < 0) + return false; + if (xdr_stream_encode_opaque(xdr, gss_token->data, gss_token->len) < 0) + return false; + return true; } /* @@ -1095,20 +1150,44 @@ gss_write_resv(struct kvec *resv, size_t size_limit, * the upcall results are available, write the verifier and result. * Otherwise, drop the request pending an answer to the upcall. */ -static int svcauth_gss_legacy_init(struct svc_rqst *rqstp, - struct rpc_gss_wire_cred *gc, __be32 *authp) +static int +svcauth_gss_legacy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc) { - struct kvec *argv = &rqstp->rq_arg.head[0]; - struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; struct rsi *rsip, rsikey; + __be32 *p; + u32 len; int ret; - struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, sunrpc_net_id); + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); memset(&rsikey, 0, sizeof(rsikey)); - ret = gss_read_verf(gc, argv, authp, - &rsikey.in_handle, &rsikey.in_token); - if (ret) - return ret; + if (dup_netobj(&rsikey.in_handle, &gc->gc_ctx)) + return SVC_CLOSE; + + /* + * RFC 2203 Section 5.2.2 + * + * struct rpc_gss_init_arg { + * opaque gss_token<>; + * }; + */ + if (xdr_stream_decode_u32(xdr, &len) < 0) { + kfree(rsikey.in_handle.data); + return SVC_DENIED; + } + p = xdr_inline_decode(xdr, len); + if (!p) { + kfree(rsikey.in_handle.data); + return SVC_DENIED; + } + rsikey.in_token.data = kmalloc(len, GFP_KERNEL); + if (ZERO_OR_NULL_PTR(rsikey.in_token.data)) { + kfree(rsikey.in_handle.data); + return SVC_CLOSE; + } + memcpy(rsikey.in_token.data, p, len); + rsikey.in_token.len = len; /* Perform upcall, or find upcall result: */ rsip = rsi_lookup(sn->rsi_cache, &rsikey); @@ -1120,13 +1199,14 @@ static int svcauth_gss_legacy_init(struct svc_rqst *rqstp, return SVC_CLOSE; ret = SVC_CLOSE; - /* Got an answer to the upcall; use it: */ - if (gss_write_init_verf(sn->rsc_cache, rqstp, - &rsip->out_handle, &rsip->major_status)) + if (!svcauth_gss_proc_init_verf(sn->rsc_cache, rqstp, &rsip->out_handle, + &rsip->major_status, GSS_SEQ_WIN)) + goto out; + if (!svcxdr_set_accept_stat(rqstp)) goto out; - if (gss_write_resv(resv, PAGE_SIZE, - &rsip->out_handle, &rsip->out_token, - rsip->major_status, rsip->minor_status)) + if (!svcxdr_encode_gss_init_res(&rqstp->rq_res_stream, &rsip->out_handle, + &rsip->out_token, rsip->major_status, + rsip->minor_status, GSS_SEQ_WIN)) goto out; ret = SVC_COMPLETE; @@ -1143,8 +1223,8 @@ static int gss_proxy_save_rsc(struct cache_detail *cd, static atomic64_t ctxhctr; long long ctxh; struct gss_api_mech *gm = NULL; - time_t expiry; - int status = -EINVAL; + time64_t expiry; + int status; memset(&rsci, 0, sizeof(rsci)); /* context handle */ @@ -1167,9 +1247,9 @@ static int gss_proxy_save_rsc(struct cache_detail *cd, if (!ud->found_creds) { /* userspace seem buggy, we should always get at least a * mapping to nobody */ - dprintk("RPC: No creds found, marking Negative!\n"); - set_bit(CACHE_NEGATIVE, &rsci.h.flags); + goto out; } else { + struct timespec64 boot; /* steal creds */ rsci.cred = ud->creds; @@ -1180,6 +1260,7 @@ static int gss_proxy_save_rsc(struct cache_detail *cd, gm = gss_mech_get_by_OID(&ud->mech_oid); if (!gm) goto out; + rsci.cred.cr_gss_mech = gm; status = -EINVAL; /* mech-specific data: */ @@ -1189,13 +1270,15 @@ static int gss_proxy_save_rsc(struct cache_detail *cd, &expiry, GFP_KERNEL); if (status) goto out; + + getboottime64(&boot); + expiry -= boot.tv_sec; } rsci.h.expiry_time = expiry; rscp = rsc_update(cd, &rsci, rscp); status = 0; out: - gss_mech_put(gm); rsc_free(&rsci); if (rscp) cache_put(&rscp->h, cd); @@ -1205,20 +1288,18 @@ out: } static int svcauth_gss_proxy_init(struct svc_rqst *rqstp, - struct rpc_gss_wire_cred *gc, __be32 *authp) + struct rpc_gss_wire_cred *gc) { - struct kvec *resv = &rqstp->rq_res.head[0]; struct xdr_netobj cli_handle; struct gssp_upcall_data ud; uint64_t handle; int status; int ret; - struct net *net = rqstp->rq_xprt->xpt_net; + struct net *net = SVC_NET(rqstp); struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); memset(&ud, 0, sizeof(ud)); - ret = gss_read_proxy_verf(rqstp, gc, authp, - &ud.in_handle, &ud.in_token); + ret = gss_read_proxy_verf(rqstp, gc, &ud.in_handle, &ud.in_token); if (ret) return ret; @@ -1229,8 +1310,7 @@ static int svcauth_gss_proxy_init(struct svc_rqst *rqstp, if (status) goto out; - dprintk("RPC: svcauth_gss: gss major status = %d\n", - ud.major_status); + trace_rpcgss_svc_accept_upcall(rqstp, ud.major_status, ud.minor_status); switch (ud.major_status) { case GSS_S_CONTINUE_NEEDED: @@ -1244,89 +1324,84 @@ static int svcauth_gss_proxy_init(struct svc_rqst *rqstp, cli_handle.len = sizeof(handle); break; default: - ret = SVC_CLOSE; goto out; } - /* Got an answer to the upcall; use it: */ - if (gss_write_init_verf(sn->rsc_cache, rqstp, - &cli_handle, &ud.major_status)) + if (!svcauth_gss_proc_init_verf(sn->rsc_cache, rqstp, &cli_handle, + &ud.major_status, GSS_SEQ_WIN)) + goto out; + if (!svcxdr_set_accept_stat(rqstp)) goto out; - if (gss_write_resv(resv, PAGE_SIZE, - &cli_handle, &ud.out_token, - ud.major_status, ud.minor_status)) + if (!svcxdr_encode_gss_init_res(&rqstp->rq_res_stream, &cli_handle, + &ud.out_token, ud.major_status, + ud.minor_status, GSS_SEQ_WIN)) goto out; ret = SVC_COMPLETE; out: + gss_free_in_token_pages(&ud.in_token); gssp_free_upcall_data(&ud); return ret; } -DEFINE_SPINLOCK(use_gssp_lock); - -static bool use_gss_proxy(struct net *net) +/* + * Try to set the sn->use_gss_proxy variable to a new value. We only allow + * it to be changed if it's currently undefined (-1). If it's any other value + * then return -EBUSY unless the type wouldn't have changed anyway. + */ +static int set_gss_proxy(struct net *net, int type) { struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret; - if (sn->use_gss_proxy != -1) - return sn->use_gss_proxy; - spin_lock(&use_gssp_lock); - /* - * If you wanted gss-proxy, you should have said so before - * starting to accept requests: - */ - sn->use_gss_proxy = 0; - spin_unlock(&use_gssp_lock); + WARN_ON_ONCE(type != 0 && type != 1); + ret = cmpxchg(&sn->use_gss_proxy, -1, type); + if (ret != -1 && ret != type) + return -EBUSY; return 0; } -#ifdef CONFIG_PROC_FS - -static int set_gss_proxy(struct net *net, int type) +static bool use_gss_proxy(struct net *net) { struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); - int ret = 0; - WARN_ON_ONCE(type != 0 && type != 1); - spin_lock(&use_gssp_lock); - if (sn->use_gss_proxy == -1 || sn->use_gss_proxy == type) - sn->use_gss_proxy = type; - else - ret = -EBUSY; - spin_unlock(&use_gssp_lock); - wake_up(&sn->gssp_wq); - return ret; + /* If use_gss_proxy is still undefined, then try to disable it */ + if (sn->use_gss_proxy == -1) + set_gss_proxy(net, 0); + return sn->use_gss_proxy; } -static inline bool gssp_ready(struct sunrpc_net *sn) +static noinline_for_stack int +svcauth_gss_proc_init(struct svc_rqst *rqstp, struct rpc_gss_wire_cred *gc) { - switch (sn->use_gss_proxy) { - case -1: - return false; - case 0: - return true; - case 1: - return sn->gssp_clnt; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; + u32 flavor, len; + void *body; + + /* Call's verf field: */ + if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0) + return SVC_GARBAGE; + if (flavor != RPC_AUTH_NULL || len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badverf; + return SVC_DENIED; } - WARN_ON_ONCE(1); - return false; -} -static int wait_for_gss_proxy(struct net *net, struct file *file) -{ - struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } - if (file->f_flags & O_NONBLOCK && !gssp_ready(sn)) - return -EAGAIN; - return wait_event_interruptible(sn->gssp_wq, gssp_ready(sn)); + if (!use_gss_proxy(SVC_NET(rqstp))) + return svcauth_gss_legacy_init(rqstp, gc); + return svcauth_gss_proxy_init(rqstp, gc); } +#ifdef CONFIG_PROC_FS static ssize_t write_gssp(struct file *file, const char __user *buf, size_t count, loff_t *ppos) { - struct net *net = PDE_DATA(file_inode(file)); + struct net *net = pde_data(file_inode(file)); char tbuf[20]; unsigned long i; int res; @@ -1342,10 +1417,10 @@ static ssize_t write_gssp(struct file *file, const char __user *buf, return res; if (i != 1) return -EINVAL; - res = set_gss_proxy(net, 1); + res = set_gssp_clnt(net); if (res) return res; - res = set_gssp_clnt(net); + res = set_gss_proxy(net, 1); if (res) return res; return count; @@ -1354,17 +1429,13 @@ static ssize_t write_gssp(struct file *file, const char __user *buf, static ssize_t read_gssp(struct file *file, char __user *buf, size_t count, loff_t *ppos) { - struct net *net = PDE_DATA(file_inode(file)); + struct net *net = pde_data(file_inode(file)); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); unsigned long p = *ppos; char tbuf[10]; size_t len; - int ret; - - ret = wait_for_gss_proxy(net, file); - if (ret) - return ret; - snprintf(tbuf, sizeof(tbuf), "%d\n", use_gss_proxy(net)); + snprintf(tbuf, sizeof(tbuf), "%d\n", sn->use_gss_proxy); len = strlen(tbuf); if (p >= len) return 0; @@ -1377,10 +1448,10 @@ static ssize_t read_gssp(struct file *file, char __user *buf, return len; } -static const struct file_operations use_gss_proxy_ops = { - .open = nonseekable_open, - .write = write_gssp, - .read = read_gssp, +static const struct proc_ops use_gss_proxy_proc_ops = { + .proc_open = nonseekable_open, + .proc_write = write_gssp, + .proc_read = read_gssp, }; static int create_use_gss_proxy_proc_entry(struct net *net) @@ -1389,9 +1460,9 @@ static int create_use_gss_proxy_proc_entry(struct net *net) struct proc_dir_entry **p = &sn->use_gssp_proc; sn->use_gss_proxy = -1; - *p = proc_create_data("use-gss-proxy", S_IFREG|S_IRUSR|S_IWUSR, + *p = proc_create_data("use-gss-proxy", S_IFREG | 0600, sn->proc_net_rpc, - &use_gss_proxy_ops, net); + &use_gss_proxy_proc_ops, net); if (!*p) return -ENOMEM; init_gssp_clnt(sn); @@ -1403,10 +1474,60 @@ static void destroy_use_gss_proxy_proc_entry(struct net *net) struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); if (sn->use_gssp_proc) { - remove_proc_entry("use-gss-proxy", sn->proc_net_rpc); + remove_proc_entry("use-gss-proxy", sn->proc_net_rpc); clear_gssp_clnt(sn); } } + +static ssize_t read_gss_krb5_enctypes(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + struct rpcsec_gss_oid oid = { + .len = 9, + .data = "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02", + }; + struct gss_api_mech *mech; + ssize_t ret; + + mech = gss_mech_get_by_OID(&oid); + if (!mech) + return 0; + if (!mech->gm_upcall_enctypes) { + gss_mech_put(mech); + return 0; + } + + ret = simple_read_from_buffer(buf, count, ppos, + mech->gm_upcall_enctypes, + strlen(mech->gm_upcall_enctypes)); + gss_mech_put(mech); + return ret; +} + +static const struct proc_ops gss_krb5_enctypes_proc_ops = { + .proc_open = nonseekable_open, + .proc_read = read_gss_krb5_enctypes, +}; + +static int create_krb5_enctypes_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + sn->gss_krb5_enctypes = + proc_create_data("gss_krb5_enctypes", S_IFREG | 0444, + sn->proc_net_rpc, &gss_krb5_enctypes_proc_ops, + net); + return sn->gss_krb5_enctypes ? 0 : -ENOMEM; +} + +static void destroy_krb5_enctypes_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (sn->gss_krb5_enctypes) + remove_proc_entry("gss_krb5_enctypes", sn->proc_net_rpc); +} + #else /* CONFIG_PROC_FS */ static int create_use_gss_proxy_proc_entry(struct net *net) @@ -1416,86 +1537,129 @@ static int create_use_gss_proxy_proc_entry(struct net *net) static void destroy_use_gss_proxy_proc_entry(struct net *net) {} +static int create_krb5_enctypes_proc_entry(struct net *net) +{ + return 0; +} + +static void destroy_krb5_enctypes_proc_entry(struct net *net) {} + #endif /* CONFIG_PROC_FS */ /* - * Accept an rpcsec packet. - * If context establishment, punt to user space - * If data exchange, verify/decrypt - * If context destruction, handle here - * In the context establishment and destruction case we encode - * response here and return SVC_COMPLETE. + * The Call's credential body should contain a struct rpc_gss_cred_t. + * + * RFC 2203 Section 5 + * + * struct rpc_gss_cred_t { + * union switch (unsigned int version) { + * case RPCSEC_GSS_VERS_1: + * struct { + * rpc_gss_proc_t gss_proc; + * unsigned int seq_num; + * rpc_gss_service_t service; + * opaque handle<>; + * } rpc_gss_cred_vers_1_t; + * } + * }; */ -static int -svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) +static bool +svcauth_gss_decode_credbody(struct xdr_stream *xdr, + struct rpc_gss_wire_cred *gc, + __be32 **rpcstart) +{ + ssize_t handle_len; + u32 body_len; + __be32 *p; + + p = xdr_inline_decode(xdr, XDR_UNIT); + if (!p) + return false; + /* + * start of rpc packet is 7 u32's back from here: + * xid direction rpcversion prog vers proc flavour + */ + *rpcstart = p - 7; + body_len = be32_to_cpup(p); + if (body_len > RPC_MAX_AUTH_SIZE) + return false; + + /* struct rpc_gss_cred_t */ + if (xdr_stream_decode_u32(xdr, &gc->gc_v) < 0) + return false; + if (xdr_stream_decode_u32(xdr, &gc->gc_proc) < 0) + return false; + if (xdr_stream_decode_u32(xdr, &gc->gc_seq) < 0) + return false; + if (xdr_stream_decode_u32(xdr, &gc->gc_svc) < 0) + return false; + handle_len = xdr_stream_decode_opaque_inline(xdr, + (void **)&gc->gc_ctx.data, + body_len); + if (handle_len < 0) + return false; + if (body_len != XDR_UNIT * 5 + xdr_align_size(handle_len)) + return false; + + gc->gc_ctx.len = handle_len; + return true; +} + +/** + * svcauth_gss_accept - Decode and validate incoming RPC_AUTH_GSS credential + * @rqstp: RPC transaction + * + * Return values: + * %SVC_OK: Success + * %SVC_COMPLETE: GSS context lifetime event + * %SVC_DENIED: Credential or verifier is not valid + * %SVC_GARBAGE: Failed to decode credential or verifier + * %SVC_CLOSE: Temporary failure + * + * The rqstp->rq_auth_stat field is also set (see RFCs 2203 and 5531). + */ +static enum svc_auth_status +svcauth_gss_accept(struct svc_rqst *rqstp) { - struct kvec *argv = &rqstp->rq_arg.head[0]; - struct kvec *resv = &rqstp->rq_res.head[0]; - u32 crlen; struct gss_svc_data *svcdata = rqstp->rq_auth_data; + __be32 *rpcstart; struct rpc_gss_wire_cred *gc; struct rsc *rsci = NULL; - __be32 *rpcstart; - __be32 *reject_stat = resv->iov_base + resv->iov_len; int ret; - struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, sunrpc_net_id); - - dprintk("RPC: svcauth_gss: argv->iov_len = %zd\n", - argv->iov_len); + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); - *authp = rpc_autherr_badcred; + rqstp->rq_auth_stat = rpc_autherr_failed; if (!svcdata) svcdata = kmalloc(sizeof(*svcdata), GFP_KERNEL); if (!svcdata) goto auth_err; rqstp->rq_auth_data = svcdata; - svcdata->verf_start = NULL; + svcdata->gsd_databody_offset = 0; svcdata->rsci = NULL; gc = &svcdata->clcred; - /* start of rpc packet is 7 u32's back from here: - * xid direction rpcversion prog vers proc flavour - */ - rpcstart = argv->iov_base; - rpcstart -= 7; - - /* credential is: - * version(==1), proc(0,1,2,3), seq, service (1,2,3), handle - * at least 5 u32s, and is preceded by length, so that makes 6. - */ - - if (argv->iov_len < 5 * 4) - goto auth_err; - crlen = svc_getnl(argv); - if (svc_getnl(argv) != RPC_GSS_VERSION) + rqstp->rq_auth_stat = rpc_autherr_badcred; + if (!svcauth_gss_decode_credbody(&rqstp->rq_arg_stream, gc, &rpcstart)) goto auth_err; - gc->gc_proc = svc_getnl(argv); - gc->gc_seq = svc_getnl(argv); - gc->gc_svc = svc_getnl(argv); - if (svc_safe_getnetobj(argv, &gc->gc_ctx)) - goto auth_err; - if (crlen != round_up_to_quad(gc->gc_ctx.len) + 5 * 4) - goto auth_err; - - if ((gc->gc_proc != RPC_GSS_PROC_DATA) && (rqstp->rq_proc != 0)) + if (gc->gc_v != RPC_GSS_VERSION) goto auth_err; - *authp = rpc_autherr_badverf; switch (gc->gc_proc) { case RPC_GSS_PROC_INIT: case RPC_GSS_PROC_CONTINUE_INIT: - if (use_gss_proxy(SVC_NET(rqstp))) - return svcauth_gss_proxy_init(rqstp, gc, authp); - else - return svcauth_gss_legacy_init(rqstp, gc, authp); - case RPC_GSS_PROC_DATA: + if (rqstp->rq_proc != 0) + goto auth_err; + return svcauth_gss_proc_init(rqstp, gc); case RPC_GSS_PROC_DESTROY: - /* Look up the context, and check the verifier: */ - *authp = rpcsec_gsserr_credproblem; + if (rqstp->rq_proc != 0) + goto auth_err; + fallthrough; + case RPC_GSS_PROC_DATA: + rqstp->rq_auth_stat = rpcsec_gsserr_credproblem; rsci = gss_svc_searchbyctx(sn->rsc_cache, &gc->gc_ctx); if (!rsci) goto auth_err; - switch (gss_verify_header(rqstp, rsci, rpcstart, gc, authp)) { + switch (svcauth_gss_verify_header(rqstp, rsci, rpcstart, gc)) { case SVC_OK: break; case SVC_DENIED: @@ -1505,47 +1669,50 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) } break; default: - *authp = rpc_autherr_rejectedcred; + if (rqstp->rq_proc != 0) + goto auth_err; + rqstp->rq_auth_stat = rpc_autherr_rejectedcred; goto auth_err; } /* now act upon the command: */ switch (gc->gc_proc) { case RPC_GSS_PROC_DESTROY: - if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + if (!svcauth_gss_encode_verf(rqstp, rsci->mechctx, gc->gc_seq)) goto auth_err; - rsci->h.expiry_time = get_seconds(); - set_bit(CACHE_NEGATIVE, &rsci->h.flags); - if (resv->iov_len + 4 > PAGE_SIZE) - goto drop; - svc_putnl(resv, RPC_SUCCESS); + if (!svcxdr_set_accept_stat(rqstp)) + goto auth_err; + /* Delete the entry from the cache_list and call cache_put */ + sunrpc_cache_unhash(sn->rsc_cache, &rsci->h); goto complete; case RPC_GSS_PROC_DATA: - *authp = rpcsec_gsserr_ctxproblem; - svcdata->verf_start = resv->iov_base + resv->iov_len; - if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + rqstp->rq_auth_stat = rpcsec_gsserr_ctxproblem; + if (!svcauth_gss_encode_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + if (!svcxdr_set_accept_stat(rqstp)) goto auth_err; + svcdata->gsd_databody_offset = xdr_stream_pos(&rqstp->rq_res_stream); rqstp->rq_cred = rsci->cred; get_group_info(rsci->cred.cr_group_info); - *authp = rpc_autherr_badcred; + rqstp->rq_auth_stat = rpc_autherr_badcred; switch (gc->gc_svc) { case RPC_GSS_SVC_NONE: break; case RPC_GSS_SVC_INTEGRITY: - /* placeholders for length and seq. number: */ - svc_putnl(resv, 0); - svc_putnl(resv, 0); - if (unwrap_integ_data(rqstp, &rqstp->rq_arg, - gc->gc_seq, rsci->mechctx)) + /* placeholders for body length and seq. number: */ + xdr_reserve_space(&rqstp->rq_res_stream, XDR_UNIT * 2); + if (svcauth_gss_unwrap_integ(rqstp, gc->gc_seq, + rsci->mechctx)) goto garbage_args; + svcxdr_set_auth_slack(rqstp, RPC_MAX_AUTH_SIZE); break; case RPC_GSS_SVC_PRIVACY: - /* placeholders for length and seq. number: */ - svc_putnl(resv, 0); - svc_putnl(resv, 0); - if (unwrap_priv_data(rqstp, &rqstp->rq_arg, - gc->gc_seq, rsci->mechctx)) + /* placeholders for body length and seq. number: */ + xdr_reserve_space(&rqstp->rq_res_stream, XDR_UNIT * 2); + if (svcauth_gss_unwrap_priv(rqstp, gc->gc_seq, + rsci->mechctx)) goto garbage_args; + svcxdr_set_auth_slack(rqstp, RPC_MAX_AUTH_SIZE * 2); break; default: goto auth_err; @@ -1557,124 +1724,146 @@ svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) GSS_C_QOP_DEFAULT, gc->gc_svc); ret = SVC_OK; + trace_rpcgss_svc_authenticate(rqstp, gc); goto out; } garbage_args: ret = SVC_GARBAGE; goto out; auth_err: - /* Restore write pointer to its original value: */ - xdr_ressize_check(rqstp, reject_stat); + xdr_truncate_encode(&rqstp->rq_res_stream, XDR_UNIT * 2); ret = SVC_DENIED; goto out; complete: ret = SVC_COMPLETE; goto out; drop: - ret = SVC_DROP; + ret = SVC_CLOSE; out: if (rsci) cache_put(&rsci->h, sn->rsc_cache); return ret; } -static __be32 * -svcauth_gss_prepare_to_wrap(struct xdr_buf *resbuf, struct gss_svc_data *gsd) +static u32 +svcauth_gss_prepare_to_wrap(struct svc_rqst *rqstp, struct gss_svc_data *gsd) { - __be32 *p; - u32 verf_len; + u32 offset; - p = gsd->verf_start; - gsd->verf_start = NULL; + /* Release can be called twice, but we only wrap once. */ + offset = gsd->gsd_databody_offset; + gsd->gsd_databody_offset = 0; - /* If the reply stat is nonzero, don't wrap: */ - if (*(p-1) != rpc_success) - return NULL; - /* Skip the verifier: */ - p += 1; - verf_len = ntohl(*p++); - p += XDR_QUADLEN(verf_len); - /* move accept_stat to right place: */ - memcpy(p, p + 2, 4); - /* Also don't wrap if the accept stat is nonzero: */ - if (*p != rpc_success) { - resbuf->head[0].iov_len -= 2 * 4; - return NULL; - } - p++; - return p; + /* AUTH_ERROR replies are not wrapped. */ + if (rqstp->rq_auth_stat != rpc_auth_ok) + return 0; + + /* Also don't wrap if the accept_stat is nonzero: */ + if (*rqstp->rq_accept_statp != rpc_success) + return 0; + + return offset; } -static inline int -svcauth_gss_wrap_resp_integ(struct svc_rqst *rqstp) +/* + * RFC 2203, Section 5.3.2.2 + * + * struct rpc_gss_integ_data { + * opaque databody_integ<>; + * opaque checksum<>; + * }; + * + * struct rpc_gss_data_t { + * unsigned int seq_num; + * proc_req_arg_t arg; + * }; + * + * The RPC Reply message has already been XDR-encoded. rq_res_stream + * is now positioned so that the checksum can be written just past + * the RPC Reply message. + */ +static int svcauth_gss_wrap_integ(struct svc_rqst *rqstp) { - struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct gss_svc_data *gsd = rqstp->rq_auth_data; + struct xdr_stream *xdr = &rqstp->rq_res_stream; struct rpc_gss_wire_cred *gc = &gsd->clcred; - struct xdr_buf *resbuf = &rqstp->rq_res; - struct xdr_buf integ_buf; - struct xdr_netobj mic; - struct kvec *resv; - __be32 *p; - int integ_offset, integ_len; - int stat = -EINVAL; + struct xdr_buf *buf = xdr->buf; + struct xdr_buf databody_integ; + struct xdr_netobj checksum; + u32 offset, maj_stat; - p = svcauth_gss_prepare_to_wrap(resbuf, gsd); - if (p == NULL) + offset = svcauth_gss_prepare_to_wrap(rqstp, gsd); + if (!offset) goto out; - integ_offset = (u8 *)(p + 1) - (u8 *)resbuf->head[0].iov_base; - integ_len = resbuf->len - integ_offset; - BUG_ON(integ_len % 4); - *p++ = htonl(integ_len); - *p++ = htonl(gc->gc_seq); - if (xdr_buf_subsegment(resbuf, &integ_buf, integ_offset, - integ_len)) - BUG(); - if (resbuf->tail[0].iov_base == NULL) { - if (resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE > PAGE_SIZE) - goto out_err; - resbuf->tail[0].iov_base = resbuf->head[0].iov_base - + resbuf->head[0].iov_len; - resbuf->tail[0].iov_len = 0; - resv = &resbuf->tail[0]; - } else { - resv = &resbuf->tail[0]; - } - mic.data = (u8 *)resv->iov_base + resv->iov_len + 4; - if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic)) - goto out_err; - svc_putnl(resv, mic.len); - memset(mic.data + mic.len, 0, - round_up_to_quad(mic.len) - mic.len); - resv->iov_len += XDR_QUADLEN(mic.len) << 2; - /* not strictly required: */ - resbuf->len += XDR_QUADLEN(mic.len) << 2; - BUG_ON(resv->iov_len > PAGE_SIZE); + + if (xdr_buf_subsegment(buf, &databody_integ, offset + XDR_UNIT, + buf->len - offset - XDR_UNIT)) + goto wrap_failed; + /* Buffer space for these has already been reserved in + * svcauth_gss_accept(). */ + if (xdr_encode_word(buf, offset, databody_integ.len)) + goto wrap_failed; + if (xdr_encode_word(buf, offset + XDR_UNIT, gc->gc_seq)) + goto wrap_failed; + + checksum.data = gsd->gsd_scratch; + maj_stat = gss_get_mic(gsd->rsci->mechctx, &databody_integ, &checksum); + if (maj_stat != GSS_S_COMPLETE) + goto bad_mic; + + if (xdr_stream_encode_opaque(xdr, checksum.data, checksum.len) < 0) + goto wrap_failed; + xdr_commit_encode(xdr); + out: - stat = 0; -out_err: - return stat; + return 0; + +bad_mic: + trace_rpcgss_svc_get_mic(rqstp, maj_stat); + return -EINVAL; +wrap_failed: + trace_rpcgss_svc_wrap_failed(rqstp); + return -EINVAL; } -static inline int -svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) +/* + * RFC 2203, Section 5.3.2.3 + * + * struct rpc_gss_priv_data { + * opaque databody_priv<> + * }; + * + * struct rpc_gss_data_t { + * unsigned int seq_num; + * proc_req_arg_t arg; + * }; + * + * gss_wrap() expands the size of the RPC message payload in the + * response buffer. The main purpose of svcauth_gss_wrap_priv() + * is to ensure there is adequate space in the response buffer to + * avoid overflow during the wrap. + */ +static int svcauth_gss_wrap_priv(struct svc_rqst *rqstp) { - struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct gss_svc_data *gsd = rqstp->rq_auth_data; struct rpc_gss_wire_cred *gc = &gsd->clcred; - struct xdr_buf *resbuf = &rqstp->rq_res; - struct page **inpages = NULL; - __be32 *p, *len; - int offset; - int pad; - - p = svcauth_gss_prepare_to_wrap(resbuf, gsd); - if (p == NULL) + struct xdr_buf *buf = &rqstp->rq_res; + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + u32 offset, pad, maj_stat; + __be32 *p; + + offset = svcauth_gss_prepare_to_wrap(rqstp, gsd); + if (!offset) return 0; - len = p++; - offset = (u8 *)p - (u8 *)resbuf->head[0].iov_base; - *p++ = htonl(gc->gc_seq); - inpages = resbuf->pages; - /* XXX: Would be better to write some xdr helper functions for - * nfs{2,3,4}xdr.c that place the data right, instead of copying: */ + + /* + * Buffer space for this field has already been reserved + * in svcauth_gss_accept(). Note that the GSS sequence + * number is encrypted along with the RPC reply payload. + */ + if (xdr_encode_word(buf, offset + XDR_UNIT, gc->gc_seq)) + goto wrap_failed; /* * If there is currently tail data, make sure there is @@ -1683,17 +1872,17 @@ svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) * there is RPC_MAX_AUTH_SIZE slack space available in * both the head and tail. */ - if (resbuf->tail[0].iov_base) { - BUG_ON(resbuf->tail[0].iov_base >= resbuf->head[0].iov_base - + PAGE_SIZE); - BUG_ON(resbuf->tail[0].iov_base < resbuf->head[0].iov_base); - if (resbuf->tail[0].iov_len + resbuf->head[0].iov_len + if (tail->iov_base) { + if (tail->iov_base >= head->iov_base + PAGE_SIZE) + goto wrap_failed; + if (tail->iov_base < head->iov_base) + goto wrap_failed; + if (tail->iov_len + head->iov_len + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE) - return -ENOMEM; - memmove(resbuf->tail[0].iov_base + RPC_MAX_AUTH_SIZE, - resbuf->tail[0].iov_base, - resbuf->tail[0].iov_len); - resbuf->tail[0].iov_base += RPC_MAX_AUTH_SIZE; + goto wrap_failed; + memmove(tail->iov_base + RPC_MAX_AUTH_SIZE, tail->iov_base, + tail->iov_len); + tail->iov_base += RPC_MAX_AUTH_SIZE; } /* * If there is no current tail data, make sure there is @@ -1702,52 +1891,70 @@ svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) * is RPC_MAX_AUTH_SIZE slack space available in both the * head and tail. */ - if (resbuf->tail[0].iov_base == NULL) { - if (resbuf->head[0].iov_len + 2*RPC_MAX_AUTH_SIZE > PAGE_SIZE) - return -ENOMEM; - resbuf->tail[0].iov_base = resbuf->head[0].iov_base - + resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE; - resbuf->tail[0].iov_len = 0; + if (!tail->iov_base) { + if (head->iov_len + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE) + goto wrap_failed; + tail->iov_base = head->iov_base + + head->iov_len + RPC_MAX_AUTH_SIZE; + tail->iov_len = 0; } - if (gss_wrap(gsd->rsci->mechctx, offset, resbuf, inpages)) - return -ENOMEM; - *len = htonl(resbuf->len - offset); - pad = 3 - ((resbuf->len - offset - 1)&3); - p = (__be32 *)(resbuf->tail[0].iov_base + resbuf->tail[0].iov_len); + + maj_stat = gss_wrap(gsd->rsci->mechctx, offset + XDR_UNIT, buf, + buf->pages); + if (maj_stat != GSS_S_COMPLETE) + goto bad_wrap; + + /* Wrapping can change the size of databody_priv. */ + if (xdr_encode_word(buf, offset, buf->len - offset - XDR_UNIT)) + goto wrap_failed; + pad = xdr_pad_size(buf->len - offset - XDR_UNIT); + p = (__be32 *)(tail->iov_base + tail->iov_len); memset(p, 0, pad); - resbuf->tail[0].iov_len += pad; - resbuf->len += pad; + tail->iov_len += pad; + buf->len += pad; + return 0; +wrap_failed: + trace_rpcgss_svc_wrap_failed(rqstp); + return -EINVAL; +bad_wrap: + trace_rpcgss_svc_wrap(rqstp, maj_stat); + return -ENOMEM; } +/** + * svcauth_gss_release - Wrap payload and release resources + * @rqstp: RPC transaction context + * + * Return values: + * %0: the Reply is ready to be sent + * %-ENOMEM: failed to allocate memory + * %-EINVAL: encoding error + */ static int svcauth_gss_release(struct svc_rqst *rqstp) { - struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; - struct rpc_gss_wire_cred *gc = &gsd->clcred; - struct xdr_buf *resbuf = &rqstp->rq_res; - int stat = -EINVAL; - struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, sunrpc_net_id); + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + struct gss_svc_data *gsd = rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc; + int stat; - if (gc->gc_proc != RPC_GSS_PROC_DATA) + if (!gsd) goto out; - /* Release can be called twice, but we only wrap once. */ - if (gsd->verf_start == NULL) + gc = &gsd->clcred; + if (gc->gc_proc != RPC_GSS_PROC_DATA) goto out; - /* normally not set till svc_send, but we need it here: */ - /* XXX: what for? Do we mess it up the moment we call svc_putu32 - * or whatever? */ - resbuf->len = total_buf_len(resbuf); + switch (gc->gc_svc) { case RPC_GSS_SVC_NONE: break; case RPC_GSS_SVC_INTEGRITY: - stat = svcauth_gss_wrap_resp_integ(rqstp); + stat = svcauth_gss_wrap_integ(rqstp); if (stat) goto out_err; break; case RPC_GSS_SVC_PRIVACY: - stat = svcauth_gss_wrap_resp_priv(rqstp); + stat = svcauth_gss_wrap_priv(rqstp); if (stat) goto out_err; break; @@ -1769,22 +1976,34 @@ out_err: if (rqstp->rq_cred.cr_group_info) put_group_info(rqstp->rq_cred.cr_group_info); rqstp->rq_cred.cr_group_info = NULL; - if (gsd->rsci) + if (gsd && gsd->rsci) { cache_put(&gsd->rsci->h, sn->rsc_cache); - gsd->rsci = NULL; - + gsd->rsci = NULL; + } return stat; } static void -svcauth_gss_domain_release(struct auth_domain *dom) +svcauth_gss_domain_release_rcu(struct rcu_head *head) { + struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head); struct gss_domain *gd = container_of(dom, struct gss_domain, h); kfree(dom->name); kfree(gd); } +static void +svcauth_gss_domain_release(struct auth_domain *dom) +{ + call_rcu(&dom->rcu_head, svcauth_gss_domain_release_rcu); +} + +static rpc_authflavor_t svcauth_gss_pseudoflavor(struct svc_rqst *rqstp) +{ + return svcauth_gss_flavor(rqstp->rq_gssclient); +} + static struct auth_ops svcauthops_gss = { .name = "rpcsec_gss", .owner = THIS_MODULE, @@ -1793,6 +2012,7 @@ static struct auth_ops svcauthops_gss = { .release = svcauth_gss_release, .domain_release = svcauth_gss_domain_release, .set_client = svcauth_gss_set_client, + .pseudoflavor = svcauth_gss_pseudoflavor, }; static int rsi_cache_create_net(struct net *net) @@ -1867,9 +2087,17 @@ gss_svc_init_net(struct net *net) rv = create_use_gss_proxy_proc_entry(net); if (rv) goto out2; + + rv = create_krb5_enctypes_proc_entry(net); + if (rv) + goto out3; + return 0; -out2: + +out3: destroy_use_gss_proxy_proc_entry(net); +out2: + rsi_cache_destroy_net(net); out1: rsc_cache_destroy_net(net); return rv; @@ -1878,6 +2106,7 @@ out1: void gss_svc_shutdown_net(struct net *net) { + destroy_krb5_enctypes_proc_entry(net); destroy_use_gss_proxy_proc_entry(net); rsi_cache_destroy_net(net); rsc_cache_destroy_net(net); diff --git a/net/sunrpc/auth_gss/trace.c b/net/sunrpc/auth_gss/trace.c new file mode 100644 index 000000000000..76685abba60f --- /dev/null +++ b/net/sunrpc/auth_gss/trace.c @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2018, 2019 Oracle. All rights reserved. + */ + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/svc.h> +#include <linux/sunrpc/svc_xprt.h> +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/gss_err.h> + +#define CREATE_TRACE_POINTS +#include <trace/events/rpcgss.h> diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c index a5c36c01707b..41a633a4049e 100644 --- a/net/sunrpc/auth_null.c +++ b/net/sunrpc/auth_null.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0 /* * linux/net/sunrpc/auth_null.c * @@ -10,7 +11,7 @@ #include <linux/module.h> #include <linux/sunrpc/clnt.h> -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif @@ -18,9 +19,9 @@ static struct rpc_auth null_auth; static struct rpc_cred null_cred; static struct rpc_auth * -nul_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +nul_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { - atomic_inc(&null_auth.au_count); + refcount_inc(&null_auth.au_count); return &null_auth; } @@ -58,15 +59,21 @@ nul_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) /* * Marshal credential. */ -static __be32 * -nul_marshal(struct rpc_task *task, __be32 *p) +static int +nul_marshal(struct rpc_task *task, struct xdr_stream *xdr) { - *p++ = htonl(RPC_AUTH_NULL); - *p++ = 0; - *p++ = htonl(RPC_AUTH_NULL); - *p++ = 0; - - return p; + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (!p) + return -EMSGSIZE; + /* Credential */ + *p++ = rpc_auth_null; + *p++ = xdr_zero; + /* Verifier */ + *p++ = rpc_auth_null; + *p = xdr_zero; + return 0; } /* @@ -79,25 +86,19 @@ nul_refresh(struct rpc_task *task) return 0; } -static __be32 * -nul_validate(struct rpc_task *task, __be32 *p) +static int +nul_validate(struct rpc_task *task, struct xdr_stream *xdr) { - rpc_authflavor_t flavor; - u32 size; - - flavor = ntohl(*p++); - if (flavor != RPC_AUTH_NULL) { - printk("RPC: bad verf flavor: %u\n", flavor); - return NULL; - } - - size = ntohl(*p++); - if (size != 0) { - printk("RPC: bad verf size: %u\n", size); - return NULL; - } - - return p; + __be32 *p; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + return -EIO; + if (*p++ != rpc_auth_null) + return -EIO; + if (*p != xdr_zero) + return -EIO; + return 0; } const struct rpc_authops authnull_ops = { @@ -111,22 +112,25 @@ const struct rpc_authops authnull_ops = { static struct rpc_auth null_auth = { - .au_cslack = 4, - .au_rslack = 2, + .au_cslack = NUL_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, + .au_ralign = NUL_REPLYSLACK, .au_ops = &authnull_ops, .au_flavor = RPC_AUTH_NULL, - .au_count = ATOMIC_INIT(0), + .au_count = REFCOUNT_INIT(1), }; static const struct rpc_credops null_credops = { .cr_name = "AUTH_NULL", .crdestroy = nul_destroy_cred, - .crbind = rpcauth_generic_bind_cred, .crmatch = nul_match, .crmarshal = nul_marshal, + .crwrap_req = rpcauth_wrap_req_encode, .crrefresh = nul_refresh, .crvalidate = nul_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, }; static @@ -134,9 +138,6 @@ struct rpc_cred null_cred = { .cr_lru = LIST_HEAD_INIT(null_cred.cr_lru), .cr_auth = &null_auth, .cr_ops = &null_credops, - .cr_count = ATOMIC_INIT(1), + .cr_count = REFCOUNT_INIT(2), .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, -#ifdef RPC_DEBUG - .cr_magic = RPCAUTH_CRED_MAGIC, -#endif }; diff --git a/net/sunrpc/auth_tls.c b/net/sunrpc/auth_tls.c new file mode 100644 index 000000000000..87f570fd3b00 --- /dev/null +++ b/net/sunrpc/auth_tls.c @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2021, 2022 Oracle. All rights reserved. + * + * The AUTH_TLS credential is used only to probe a remote peer + * for RPC-over-TLS support. + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sunrpc/clnt.h> + +static const char *starttls_token = "STARTTLS"; +static const size_t starttls_len = 8; + +static struct rpc_auth tls_auth; +static struct rpc_cred tls_cred; + +static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + const void *obj) +{ +} + +static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + void *obj) +{ + return 0; +} + +static const struct rpc_procinfo rpcproc_tls_probe = { + .p_encode = tls_encode_probe, + .p_decode = tls_decode_probe, +}; + +static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data) +{ + task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; + rpc_call_start(task); +} + +static void rpc_tls_probe_call_done(struct rpc_task *task, void *data) +{ +} + +static const struct rpc_call_ops rpc_tls_probe_ops = { + .rpc_call_prepare = rpc_tls_probe_call_prepare, + .rpc_call_done = rpc_tls_probe_call_done, +}; + +static int tls_probe(struct rpc_clnt *clnt) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_tls_probe, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = &msg, + .rpc_op_cred = &tls_cred, + .callback_ops = &rpc_tls_probe_ops, + .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN, + }; + struct rpc_task *task; + int status; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} + +static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt) +{ + refcount_inc(&tls_auth.au_count); + return &tls_auth; +} + +static void tls_destroy(struct rpc_auth *auth) +{ +} + +static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth, + struct auth_cred *acred, int flags) +{ + return get_rpccred(&tls_cred); +} + +static void tls_destroy_cred(struct rpc_cred *cred) +{ +} + +static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) +{ + return 1; +} + +static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * XDR_UNIT); + if (!p) + return -EMSGSIZE; + /* Credential */ + *p++ = rpc_auth_tls; + *p++ = xdr_zero; + /* Verifier */ + *p++ = rpc_auth_null; + *p = xdr_zero; + return 0; +} + +static int tls_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); + return 0; +} + +static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr) +{ + __be32 *p; + void *str; + + p = xdr_inline_decode(xdr, XDR_UNIT); + if (!p) + return -EIO; + if (*p != rpc_auth_null) + return -EIO; + if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len) + return -EPROTONOSUPPORT; + if (memcmp(str, starttls_token, starttls_len)) + return -EPROTONOSUPPORT; + return 0; +} + +const struct rpc_authops authtls_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_TLS, + .au_name = "NULL", + .create = tls_create, + .destroy = tls_destroy, + .lookup_cred = tls_lookup_cred, + .ping = tls_probe, +}; + +static struct rpc_auth tls_auth = { + .au_cslack = NUL_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, + .au_ralign = NUL_REPLYSLACK, + .au_ops = &authtls_ops, + .au_flavor = RPC_AUTH_TLS, + .au_count = REFCOUNT_INIT(1), +}; + +static const struct rpc_credops tls_credops = { + .cr_name = "AUTH_TLS", + .crdestroy = tls_destroy_cred, + .crmatch = tls_match, + .crmarshal = tls_marshal, + .crwrap_req = rpcauth_wrap_req_encode, + .crrefresh = tls_refresh, + .crvalidate = tls_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, +}; + +static struct rpc_cred tls_cred = { + .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru), + .cr_auth = &tls_auth, + .cr_ops = &tls_credops, + .cr_count = REFCOUNT_INIT(2), + .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, +}; diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c index dc37021fc3e5..1e091d3fa607 100644 --- a/net/sunrpc/auth_unix.c +++ b/net/sunrpc/auth_unix.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0 /* * linux/net/sunrpc/auth_unix.c * @@ -10,96 +11,60 @@ #include <linux/types.h> #include <linux/sched.h> #include <linux/module.h> +#include <linux/mempool.h> #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/auth.h> #include <linux/user_namespace.h> -#define NFS_NGROUPS 16 -struct unx_cred { - struct rpc_cred uc_base; - kgid_t uc_gid; - kgid_t uc_gids[NFS_NGROUPS]; -}; -#define uc_uid uc_base.cr_uid - -#define UNX_WRITESLACK (21 + (UNX_MAXNODENAME >> 2)) - -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif static struct rpc_auth unix_auth; static const struct rpc_credops unix_credops; +static mempool_t *unix_pool; static struct rpc_auth * -unx_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { - dprintk("RPC: creating UNIX authenticator for client %p\n", - clnt); - atomic_inc(&unix_auth.au_count); + refcount_inc(&unix_auth.au_count); return &unix_auth; } static void unx_destroy(struct rpc_auth *auth) { - dprintk("RPC: destroying UNIX authenticator %p\n", auth); - rpcauth_clear_credcache(auth->au_credcache); } /* * Lookup AUTH_UNIX creds for current process */ -static struct rpc_cred * -unx_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +static struct rpc_cred *unx_lookup_cred(struct rpc_auth *auth, + struct auth_cred *acred, int flags) { - return rpcauth_lookup_credcache(auth, acred, flags); -} - -static struct rpc_cred * -unx_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) -{ - struct unx_cred *cred; - unsigned int groups = 0; - unsigned int i; - - dprintk("RPC: allocating UNIX cred for uid %d gid %d\n", - from_kuid(&init_user_ns, acred->uid), - from_kgid(&init_user_ns, acred->gid)); - - if (!(cred = kmalloc(sizeof(*cred), GFP_NOFS))) - return ERR_PTR(-ENOMEM); - - rpcauth_init_cred(&cred->uc_base, acred, auth, &unix_credops); - cred->uc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; - - if (acred->group_info != NULL) - groups = acred->group_info->ngroups; - if (groups > NFS_NGROUPS) - groups = NFS_NGROUPS; - - cred->uc_gid = acred->gid; - for (i = 0; i < groups; i++) - cred->uc_gids[i] = GROUP_AT(acred->group_info, i); - if (i < NFS_NGROUPS) - cred->uc_gids[i] = INVALID_GID; - - return &cred->uc_base; -} - -static void -unx_free_cred(struct unx_cred *unx_cred) -{ - dprintk("RPC: unx_free_cred %p\n", unx_cred); - kfree(unx_cred); + struct rpc_cred *ret; + + ret = kmalloc(sizeof(*ret), rpc_task_gfp_mask()); + if (!ret) { + if (!(flags & RPCAUTH_LOOKUP_ASYNC)) + return ERR_PTR(-ENOMEM); + ret = mempool_alloc(unix_pool, GFP_NOWAIT); + if (!ret) + return ERR_PTR(-ENOMEM); + } + rpcauth_init_cred(ret, acred, auth, &unix_credops); + ret->cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + return ret; } static void unx_free_cred_callback(struct rcu_head *head) { - struct unx_cred *unx_cred = container_of(head, struct unx_cred, uc_base.cr_rcu); - unx_free_cred(unx_cred); + struct rpc_cred *rpc_cred = container_of(head, struct rpc_cred, cr_rcu); + + put_cred(rpc_cred->cr_cred); + mempool_free(rpc_cred, unix_pool); } static void @@ -109,30 +74,32 @@ unx_destroy_cred(struct rpc_cred *cred) } /* - * Match credentials against current process creds. - * The root_override argument takes care of cases where the caller may - * request root creds (e.g. for NFS swapping). + * Match credentials against current the auth_cred. */ static int -unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags) +unx_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) { - struct unx_cred *cred = container_of(rcred, struct unx_cred, uc_base); unsigned int groups = 0; unsigned int i; + if (cred->cr_cred == acred->cred) + return 1; + + if (!uid_eq(cred->cr_cred->fsuid, acred->cred->fsuid) || !gid_eq(cred->cr_cred->fsgid, acred->cred->fsgid)) + return 0; - if (!uid_eq(cred->uc_uid, acred->uid) || !gid_eq(cred->uc_gid, acred->gid)) + if (acred->cred->group_info != NULL) + groups = acred->cred->group_info->ngroups; + if (groups > UNX_NGROUPS) + groups = UNX_NGROUPS; + if (cred->cr_cred->group_info == NULL) + return groups == 0; + if (groups != cred->cr_cred->group_info->ngroups) return 0; - if (acred->group_info != NULL) - groups = acred->group_info->ngroups; - if (groups > NFS_NGROUPS) - groups = NFS_NGROUPS; for (i = 0; i < groups ; i++) - if (!gid_eq(cred->uc_gids[i], GROUP_AT(acred->group_info, i))) + if (!gid_eq(cred->cr_cred->group_info->gid[i], acred->cred->group_info->gid[i])) return 0; - if (groups < NFS_NGROUPS && gid_valid(cred->uc_gids[groups])) - return 0; return 1; } @@ -140,35 +107,56 @@ unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags) * Marshal credentials. * Maybe we should keep a cached credential for performance reasons. */ -static __be32 * -unx_marshal(struct rpc_task *task, __be32 *p) +static int +unx_marshal(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_clnt *clnt = task->tk_client; - struct unx_cred *cred = container_of(task->tk_rqstp->rq_cred, struct unx_cred, uc_base); - __be32 *base, *hold; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + __be32 *p, *cred_len, *gidarr_len; int i; + struct group_info *gi = cred->cr_cred->group_info; + struct user_namespace *userns = clnt->cl_cred ? + clnt->cl_cred->user_ns : &init_user_ns; + + /* Credential */ + + p = xdr_reserve_space(xdr, 3 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_unix; + cred_len = p++; + *p++ = xdr_zero; /* stamp */ + if (xdr_stream_encode_opaque(xdr, clnt->cl_nodename, + clnt->cl_nodelen) < 0) + goto marshal_failed; + p = xdr_reserve_space(xdr, 3 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = cpu_to_be32(from_kuid_munged(userns, cred->cr_cred->fsuid)); + *p++ = cpu_to_be32(from_kgid_munged(userns, cred->cr_cred->fsgid)); + + gidarr_len = p++; + if (gi) + for (i = 0; i < UNX_NGROUPS && i < gi->ngroups; i++) + *p++ = cpu_to_be32(from_kgid_munged(userns, gi->gid[i])); + *gidarr_len = cpu_to_be32(p - gidarr_len - 1); + *cred_len = cpu_to_be32((p - cred_len - 1) << 2); + p = xdr_reserve_space(xdr, (p - gidarr_len - 1) << 2); + if (!p) + goto marshal_failed; + + /* Verifier */ + + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_null; + *p = xdr_zero; - *p++ = htonl(RPC_AUTH_UNIX); - base = p++; - *p++ = htonl(jiffies/HZ); - - /* - * Copy the UTS nodename captured when the client was created. - */ - p = xdr_encode_array(p, clnt->cl_nodename, clnt->cl_nodelen); - - *p++ = htonl((u32) from_kuid(&init_user_ns, cred->uc_uid)); - *p++ = htonl((u32) from_kgid(&init_user_ns, cred->uc_gid)); - hold = p++; - for (i = 0; i < 16 && gid_valid(cred->uc_gids[i]); i++) - *p++ = htonl((u32) from_kgid(&init_user_ns, cred->uc_gids[i])); - *hold = htonl(p - hold - 1); /* gid array length */ - *base = htonl((p - base - 1) << 2); /* cred length */ - - *p++ = htonl(RPC_AUTH_NULL); - *p++ = htonl(0); + return 0; - return p; +marshal_failed: + return -EMSGSIZE; } /* @@ -181,39 +169,46 @@ unx_refresh(struct rpc_task *task) return 0; } -static __be32 * -unx_validate(struct rpc_task *task, __be32 *p) +static int +unx_validate(struct rpc_task *task, struct xdr_stream *xdr) { - rpc_authflavor_t flavor; - u32 size; - - flavor = ntohl(*p++); - if (flavor != RPC_AUTH_NULL && - flavor != RPC_AUTH_UNIX && - flavor != RPC_AUTH_SHORT) { - printk("RPC: bad verf flavor: %u\n", flavor); - return NULL; + struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth; + __be32 *p; + u32 size; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + return -EIO; + switch (*p++) { + case rpc_auth_null: + case rpc_auth_unix: + case rpc_auth_short: + break; + default: + return -EIO; } - - size = ntohl(*p++); - if (size > RPC_MAX_AUTH_SIZE) { - printk("RPC: giant verf size: %u\n", size); - return NULL; - } - task->tk_rqstp->rq_cred->cr_auth->au_rslack = (size >> 2) + 2; - p += (size >> 2); - - return p; + size = be32_to_cpup(p); + if (size > RPC_MAX_AUTH_SIZE) + return -EIO; + p = xdr_inline_decode(xdr, size); + if (!p) + return -EIO; + + auth->au_verfsize = XDR_QUADLEN(size) + 2; + auth->au_rslack = XDR_QUADLEN(size) + 2; + auth->au_ralign = XDR_QUADLEN(size) + 2; + return 0; } int __init rpc_init_authunix(void) { - return rpcauth_init_credcache(&unix_auth); + unix_pool = mempool_create_kmalloc_pool(16, sizeof(struct rpc_cred)); + return unix_pool ? 0 : -ENOMEM; } void rpc_destroy_authunix(void) { - rpcauth_destroy_credcache(&unix_auth); + mempool_destroy(unix_pool); } const struct rpc_authops authunix_ops = { @@ -223,25 +218,26 @@ const struct rpc_authops authunix_ops = { .create = unx_create, .destroy = unx_destroy, .lookup_cred = unx_lookup_cred, - .crcreate = unx_create_cred, }; static struct rpc_auth unix_auth = { - .au_cslack = UNX_WRITESLACK, - .au_rslack = 2, /* assume AUTH_NULL verf */ + .au_cslack = UNX_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, .au_ops = &authunix_ops, .au_flavor = RPC_AUTH_UNIX, - .au_count = ATOMIC_INIT(0), + .au_count = REFCOUNT_INIT(1), }; static const struct rpc_credops unix_credops = { .cr_name = "AUTH_UNIX", .crdestroy = unx_destroy_cred, - .crbind = rpcauth_generic_bind_cred, .crmatch = unx_match, .crmarshal = unx_marshal, + .crwrap_req = rpcauth_wrap_req_encode, .crrefresh = unx_refresh, .crvalidate = unx_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, }; diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c index 890a29912d5a..caa94cf57123 100644 --- a/net/sunrpc/backchannel_rqst.c +++ b/net/sunrpc/backchannel_rqst.c @@ -1,23 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0-only /****************************************************************************** (c) 2007 Network Appliance, Inc. All Rights Reserved. (c) 2009 NetApp. All Rights Reserved. -NetApp provides this source code under the GPL v2 License. -The GPL v2 license is available at -http://opensource.org/licenses/gpl-license.php. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************/ @@ -27,27 +13,24 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include <linux/export.h> #include <linux/sunrpc/bc_xprt.h> -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) #define RPCDBG_FACILITY RPCDBG_TRANS #endif +#define BC_MAX_SLOTS 64U + +unsigned int xprt_bc_max_slots(struct rpc_xprt *xprt) +{ + return BC_MAX_SLOTS; +} + /* * Helper routines that track the number of preallocation elements * on the transport. */ static inline int xprt_need_to_requeue(struct rpc_xprt *xprt) { - return xprt->bc_alloc_count > 0; -} - -static inline void xprt_inc_alloc_count(struct rpc_xprt *xprt, unsigned int n) -{ - xprt->bc_alloc_count += n; -} - -static inline int xprt_dec_alloc_count(struct rpc_xprt *xprt, unsigned int n) -{ - return xprt->bc_alloc_count -= n; + return xprt->bc_alloc_count < xprt->bc_alloc_max; } /* @@ -60,20 +43,71 @@ static void xprt_free_allocation(struct rpc_rqst *req) dprintk("RPC: free allocations for req= %p\n", req); WARN_ON_ONCE(test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); - xbufp = &req->rq_private_buf; + xbufp = &req->rq_rcv_buf; free_page((unsigned long)xbufp->head[0].iov_base); xbufp = &req->rq_snd_buf; free_page((unsigned long)xbufp->head[0].iov_base); - list_del(&req->rq_bc_pa_list); kfree(req); } +static void xprt_bc_reinit_xdr_buf(struct xdr_buf *buf) +{ + buf->head[0].iov_len = PAGE_SIZE; + buf->tail[0].iov_len = 0; + buf->pages = NULL; + buf->page_len = 0; + buf->flags = 0; + buf->len = 0; + buf->buflen = PAGE_SIZE; +} + +static int xprt_alloc_xdr_buf(struct xdr_buf *buf, gfp_t gfp_flags) +{ + struct page *page; + /* Preallocate one XDR receive buffer */ + page = alloc_page(gfp_flags); + if (page == NULL) + return -ENOMEM; + xdr_buf_init(buf, page_address(page), PAGE_SIZE); + return 0; +} + +static struct rpc_rqst *xprt_alloc_bc_req(struct rpc_xprt *xprt) +{ + gfp_t gfp_flags = GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN; + struct rpc_rqst *req; + + /* Pre-allocate one backchannel rpc_rqst */ + req = kzalloc(sizeof(*req), gfp_flags); + if (req == NULL) + return NULL; + + req->rq_xprt = xprt; + + /* Preallocate one XDR receive buffer */ + if (xprt_alloc_xdr_buf(&req->rq_rcv_buf, gfp_flags) < 0) { + printk(KERN_ERR "Failed to create bc receive xbuf\n"); + goto out_free; + } + req->rq_rcv_buf.len = PAGE_SIZE; + + /* Preallocate one XDR send buffer */ + if (xprt_alloc_xdr_buf(&req->rq_snd_buf, gfp_flags) < 0) { + printk(KERN_ERR "Failed to create bc snd xbuf\n"); + goto out_free; + } + return req; +out_free: + xprt_free_allocation(req); + return NULL; +} + /* * Preallocate up to min_reqs structures and related buffers for use * by the backchannel. This function can be called multiple times * when creating new sessions that use the same rpc_xprt. The * preallocated buffers are added to the pool of resources used by - * the rpc_xprt. Anyone of these resources may be used used by an + * the rpc_xprt. Any one of these resources may be used by an * incoming callback request. It's up to the higher levels in the * stack to enforce that the maximum number of session slots is not * being exceeded. @@ -88,14 +122,23 @@ static void xprt_free_allocation(struct rpc_rqst *req) */ int xprt_setup_backchannel(struct rpc_xprt *xprt, unsigned int min_reqs) { - struct page *page_rcv = NULL, *page_snd = NULL; - struct xdr_buf *xbufp = NULL; - struct rpc_rqst *req, *tmp; + if (!xprt->ops->bc_setup) + return 0; + return xprt->ops->bc_setup(xprt, min_reqs); +} +EXPORT_SYMBOL_GPL(xprt_setup_backchannel); + +int xprt_setup_bc(struct rpc_xprt *xprt, unsigned int min_reqs) +{ + struct rpc_rqst *req; struct list_head tmp_list; int i; dprintk("RPC: setup backchannel transport\n"); + if (min_reqs > BC_MAX_SLOTS) + min_reqs = BC_MAX_SLOTS; + /* * We use a temporary list to keep track of the preallocated * buffers. Once we're done building the list we splice it @@ -107,7 +150,7 @@ int xprt_setup_backchannel(struct rpc_xprt *xprt, unsigned int min_reqs) INIT_LIST_HEAD(&tmp_list); for (i = 0; i < min_reqs; i++) { /* Pre-allocate one backchannel rpc_rqst */ - req = kzalloc(sizeof(struct rpc_rqst), GFP_KERNEL); + req = xprt_alloc_bc_req(xprt); if (req == NULL) { printk(KERN_ERR "Failed to create bc rpc_rqst\n"); goto out_free; @@ -116,50 +159,17 @@ int xprt_setup_backchannel(struct rpc_xprt *xprt, unsigned int min_reqs) /* Add the allocated buffer to the tmp list */ dprintk("RPC: adding req= %p\n", req); list_add(&req->rq_bc_pa_list, &tmp_list); - - req->rq_xprt = xprt; - INIT_LIST_HEAD(&req->rq_list); - INIT_LIST_HEAD(&req->rq_bc_list); - - /* Preallocate one XDR receive buffer */ - page_rcv = alloc_page(GFP_KERNEL); - if (page_rcv == NULL) { - printk(KERN_ERR "Failed to create bc receive xbuf\n"); - goto out_free; - } - xbufp = &req->rq_rcv_buf; - xbufp->head[0].iov_base = page_address(page_rcv); - xbufp->head[0].iov_len = PAGE_SIZE; - xbufp->tail[0].iov_base = NULL; - xbufp->tail[0].iov_len = 0; - xbufp->page_len = 0; - xbufp->len = PAGE_SIZE; - xbufp->buflen = PAGE_SIZE; - - /* Preallocate one XDR send buffer */ - page_snd = alloc_page(GFP_KERNEL); - if (page_snd == NULL) { - printk(KERN_ERR "Failed to create bc snd xbuf\n"); - goto out_free; - } - - xbufp = &req->rq_snd_buf; - xbufp->head[0].iov_base = page_address(page_snd); - xbufp->head[0].iov_len = 0; - xbufp->tail[0].iov_base = NULL; - xbufp->tail[0].iov_len = 0; - xbufp->page_len = 0; - xbufp->len = 0; - xbufp->buflen = PAGE_SIZE; } /* * Add the temporary list to the backchannel preallocation list */ - spin_lock_bh(&xprt->bc_pa_lock); + spin_lock(&xprt->bc_pa_lock); list_splice(&tmp_list, &xprt->bc_pa_list); - xprt_inc_alloc_count(xprt, min_reqs); - spin_unlock_bh(&xprt->bc_pa_lock); + xprt->bc_alloc_count += min_reqs; + xprt->bc_alloc_max += min_reqs; + atomic_add(min_reqs, &xprt->bc_slot_count); + spin_unlock(&xprt->bc_pa_lock); dprintk("RPC: setup backchannel transport done\n"); return 0; @@ -168,18 +178,22 @@ out_free: /* * Memory allocation failed, free the temporary list */ - list_for_each_entry_safe(req, tmp, &tmp_list, rq_bc_pa_list) + while (!list_empty(&tmp_list)) { + req = list_first_entry(&tmp_list, + struct rpc_rqst, + rq_bc_pa_list); + list_del(&req->rq_bc_pa_list); xprt_free_allocation(req); + } dprintk("RPC: setup backchannel transport failed\n"); return -ENOMEM; } -EXPORT_SYMBOL_GPL(xprt_setup_backchannel); /** * xprt_destroy_backchannel - Destroys the backchannel preallocated structures. * @xprt: the transport holding the preallocated strucures - * @max_reqs the maximum number of preallocated structures to destroy + * @max_reqs: the maximum number of preallocated structures to destroy * * Since these structures may have been allocated by multiple calls * to xprt_setup_backchannel, we only destroy up to the maximum number @@ -187,6 +201,13 @@ EXPORT_SYMBOL_GPL(xprt_setup_backchannel); */ void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs) { + if (xprt->ops->bc_destroy) + xprt->ops->bc_destroy(xprt, max_reqs); +} +EXPORT_SYMBOL_GPL(xprt_destroy_backchannel); + +void xprt_destroy_bc(struct rpc_xprt *xprt, unsigned int max_reqs) +{ struct rpc_rqst *req = NULL, *tmp = NULL; dprintk("RPC: destroy backchannel transport\n"); @@ -195,10 +216,13 @@ void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs) goto out; spin_lock_bh(&xprt->bc_pa_lock); - xprt_dec_alloc_count(xprt, max_reqs); + xprt->bc_alloc_max -= min(max_reqs, xprt->bc_alloc_max); list_for_each_entry_safe(req, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { dprintk("RPC: req=%p\n", req); + list_del(&req->rq_bc_pa_list); xprt_free_allocation(req); + xprt->bc_alloc_count--; + atomic_dec(&xprt->bc_slot_count); if (--max_reqs == 0) break; } @@ -208,42 +232,31 @@ out: dprintk("RPC: backchannel list empty= %s\n", list_empty(&xprt->bc_pa_list) ? "true" : "false"); } -EXPORT_SYMBOL_GPL(xprt_destroy_backchannel); -/* - * One or more rpc_rqst structure have been preallocated during the - * backchannel setup. Buffer space for the send and private XDR buffers - * has been preallocated as well. Use xprt_alloc_bc_request to allocate - * to this request. Use xprt_free_bc_request to return it. - * - * We know that we're called in soft interrupt context, grab the spin_lock - * since there is no need to grab the bottom half spin_lock. - * - * Return an available rpc_rqst, otherwise NULL if non are available. - */ -struct rpc_rqst *xprt_alloc_bc_request(struct rpc_xprt *xprt) +static struct rpc_rqst *xprt_get_bc_request(struct rpc_xprt *xprt, __be32 xid, + struct rpc_rqst *new) { - struct rpc_rqst *req; + struct rpc_rqst *req = NULL; dprintk("RPC: allocate a backchannel request\n"); - spin_lock(&xprt->bc_pa_lock); - if (!list_empty(&xprt->bc_pa_list)) { - req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, - rq_bc_pa_list); - list_del(&req->rq_bc_pa_list); - } else { - req = NULL; + if (list_empty(&xprt->bc_pa_list)) { + if (!new) + goto not_found; + if (atomic_read(&xprt->bc_slot_count) >= BC_MAX_SLOTS) + goto not_found; + list_add_tail(&new->rq_bc_pa_list, &xprt->bc_pa_list); + xprt->bc_alloc_count++; + atomic_inc(&xprt->bc_slot_count); } - spin_unlock(&xprt->bc_pa_lock); - - if (req != NULL) { - set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); - req->rq_reply_bytes_recvd = 0; - req->rq_bytes_sent = 0; - memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, + rq_bc_pa_list); + req->rq_reply_bytes_recvd = 0; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(req->rq_private_buf)); - } + req->rq_xid = xid; + req->rq_connect_cookie = xprt->connect_cookie; dprintk("RPC: backchannel req=%p\n", req); +not_found: return req; } @@ -255,14 +268,36 @@ void xprt_free_bc_request(struct rpc_rqst *req) { struct rpc_xprt *xprt = req->rq_xprt; + xprt->ops->bc_free_rqst(req); +} + +void xprt_free_bc_rqst(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + dprintk("RPC: free backchannel req=%p\n", req); - smp_mb__before_clear_bit(); - WARN_ON_ONCE(!test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); + req->rq_connect_cookie = xprt->connect_cookie - 1; + smp_mb__before_atomic(); clear_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); - if (!xprt_need_to_requeue(xprt)) { + /* + * Return it to the list of preallocations so that it + * may be reused by a new callback request. + */ + spin_lock_bh(&xprt->bc_pa_lock); + if (xprt_need_to_requeue(xprt)) { + xprt_bc_reinit_xdr_buf(&req->rq_snd_buf); + xprt_bc_reinit_xdr_buf(&req->rq_rcv_buf); + req->rq_rcv_buf.len = PAGE_SIZE; + list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list); + xprt->bc_alloc_count++; + atomic_inc(&xprt->bc_slot_count); + req = NULL; + } + spin_unlock_bh(&xprt->bc_pa_lock); + if (req != NULL) { /* * The last remaining session was destroyed while this * entry was in use. Free the entry and don't attempt @@ -271,15 +306,66 @@ void xprt_free_bc_request(struct rpc_rqst *req) */ dprintk("RPC: Last session removed req=%p\n", req); xprt_free_allocation(req); - return; } + xprt_put(xprt); +} - /* - * Return it to the list of preallocations so that it - * may be reused by a new callback request. - */ - spin_lock_bh(&xprt->bc_pa_lock); - list_add(&req->rq_bc_pa_list, &xprt->bc_pa_list); - spin_unlock_bh(&xprt->bc_pa_lock); +/* + * One or more rpc_rqst structure have been preallocated during the + * backchannel setup. Buffer space for the send and private XDR buffers + * has been preallocated as well. Use xprt_alloc_bc_request to allocate + * to this request. Use xprt_free_bc_request to return it. + * + * We know that we're called in soft interrupt context, grab the spin_lock + * since there is no need to grab the bottom half spin_lock. + * + * Return an available rpc_rqst, otherwise NULL if non are available. + */ +struct rpc_rqst *xprt_lookup_bc_request(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *req, *new = NULL; + + do { + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) { + if (req->rq_connect_cookie != xprt->connect_cookie) + continue; + if (req->rq_xid == xid) + goto found; + } + req = xprt_get_bc_request(xprt, xid, new); +found: + spin_unlock(&xprt->bc_pa_lock); + if (new) { + if (req != new) + xprt_free_allocation(new); + break; + } else if (req) + break; + new = xprt_alloc_bc_req(xprt); + } while (new); + return req; } +/* + * Add callback request to callback list. Wake a thread + * on the first pool (usually the only pool) to handle it. + */ +void xprt_complete_bc_request(struct rpc_rqst *req, uint32_t copied) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct svc_serv *bc_serv = xprt->bc_serv; + + spin_lock(&xprt->bc_pa_lock); + list_del(&req->rq_bc_pa_list); + xprt->bc_alloc_count--; + spin_unlock(&xprt->bc_pa_lock); + + req->rq_private_buf.len = copied; + set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); + + dprintk("RPC: add callback request to list\n"); + xprt_get(xprt); + lwq_enqueue(&req->rq_bc_list, &bc_serv->sv_cb_list); + svc_pool_wake_idle_thread(&bc_serv->sv_pools[0]); +} diff --git a/net/sunrpc/bc_svc.c b/net/sunrpc/bc_svc.c deleted file mode 100644 index 15c7a8a1c24f..000000000000 --- a/net/sunrpc/bc_svc.c +++ /dev/null @@ -1,63 +0,0 @@ -/****************************************************************************** - -(c) 2007 Network Appliance, Inc. All Rights Reserved. -(c) 2009 NetApp. All Rights Reserved. - -NetApp provides this source code under the GPL v2 License. -The GPL v2 license is available at -http://opensource.org/licenses/gpl-license.php. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -******************************************************************************/ - -/* - * The NFSv4.1 callback service helper routines. - * They implement the transport level processing required to send the - * reply over an existing open connection previously established by the client. - */ - -#include <linux/module.h> - -#include <linux/sunrpc/xprt.h> -#include <linux/sunrpc/sched.h> -#include <linux/sunrpc/bc_xprt.h> - -#define RPCDBG_FACILITY RPCDBG_SVCDSP - -/* Empty callback ops */ -static const struct rpc_call_ops nfs41_callback_ops = { -}; - - -/* - * Send the callback reply - */ -int bc_send(struct rpc_rqst *req) -{ - struct rpc_task *task; - int ret; - - dprintk("RPC: bc_send req= %p\n", req); - task = rpc_run_bc_task(req, &nfs41_callback_ops); - if (IS_ERR(task)) - ret = PTR_ERR(task); - else { - WARN_ON_ONCE(atomic_read(&task->tk_count) != 1); - ret = task->tk_status; - rpc_put_task(task); - } - dprintk("RPC: bc_send ret= %d\n", ret); - return ret; -} - diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c index a72de074172d..131090f31e6a 100644 --- a/net/sunrpc/cache.c +++ b/net/sunrpc/cache.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * net/sunrpc/cache.c * @@ -5,9 +6,6 @@ * used by sunrpc clients and servers. * * Copyright (C) 2002 Neil Brown <neilb@cse.unsw.edu.au> - * - * Released under terms in GPL version 2. See COPYING. - * */ #include <linux/types.h> @@ -20,7 +18,8 @@ #include <linux/list.h> #include <linux/module.h> #include <linux/ctype.h> -#include <asm/uaccess.h> +#include <linux/string_helpers.h> +#include <linux/uaccess.h> #include <linux/poll.h> #include <linux/seq_file.h> #include <linux/proc_fs.h> @@ -33,46 +32,76 @@ #include <linux/sunrpc/cache.h> #include <linux/sunrpc/stats.h> #include <linux/sunrpc/rpc_pipe_fs.h> +#include <trace/events/sunrpc.h> + #include "netns.h" +#include "fail.h" #define RPCDBG_FACILITY RPCDBG_CACHE static bool cache_defer_req(struct cache_req *req, struct cache_head *item); static void cache_revisit_request(struct cache_head *item); -static void cache_init(struct cache_head *h) +static void cache_init(struct cache_head *h, struct cache_detail *detail) { - time_t now = seconds_since_boot(); - h->next = NULL; + time64_t now = seconds_since_boot(); + INIT_HLIST_NODE(&h->cache_list); h->flags = 0; kref_init(&h->ref); h->expiry_time = now + CACHE_NEW_EXPIRY; + if (now <= detail->flush_time) + /* ensure it isn't already expired */ + now = detail->flush_time + 1; h->last_refresh = now; } -struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, - struct cache_head *key, int hash) +static void cache_fresh_unlocked(struct cache_head *head, + struct cache_detail *detail); + +static struct cache_head *sunrpc_cache_find_rcu(struct cache_detail *detail, + struct cache_head *key, + int hash) { - struct cache_head **head, **hp; - struct cache_head *new = NULL, *freeme = NULL; + struct hlist_head *head = &detail->hash_table[hash]; + struct cache_head *tmp; - head = &detail->hash_table[hash]; + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, cache_list) { + if (!detail->match(tmp, key)) + continue; + if (test_bit(CACHE_VALID, &tmp->flags) && + cache_is_expired(detail, tmp)) + continue; + tmp = cache_get_rcu(tmp); + rcu_read_unlock(); + return tmp; + } + rcu_read_unlock(); + return NULL; +} - read_lock(&detail->hash_lock); +static void sunrpc_begin_cache_remove_entry(struct cache_head *ch, + struct cache_detail *cd) +{ + /* Must be called under cd->hash_lock */ + hlist_del_init_rcu(&ch->cache_list); + set_bit(CACHE_CLEANED, &ch->flags); + cd->entries --; +} - for (hp=head; *hp != NULL ; hp = &(*hp)->next) { - struct cache_head *tmp = *hp; - if (detail->match(tmp, key)) { - if (cache_is_expired(detail, tmp)) - /* This entry is expired, we will discard it. */ - break; - cache_get(tmp); - read_unlock(&detail->hash_lock); - return tmp; - } - } - read_unlock(&detail->hash_lock); - /* Didn't find anything, insert an empty entry */ +static void sunrpc_end_cache_remove_entry(struct cache_head *ch, + struct cache_detail *cd) +{ + cache_fresh_unlocked(ch, cd); + cache_put(ch, cd); +} + +static struct cache_head *sunrpc_cache_add_entry(struct cache_detail *detail, + struct cache_head *key, + int hash) +{ + struct cache_head *new, *tmp, *freeme = NULL; + struct hlist_head *head = &detail->hash_table[hash]; new = detail->alloc(); if (!new) @@ -81,47 +110,65 @@ struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, * we might get lose if we need to * cache_put it soon. */ - cache_init(new); + cache_init(new, detail); detail->init(new, key); - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); /* check if entry appeared while we slept */ - for (hp=head; *hp != NULL ; hp = &(*hp)->next) { - struct cache_head *tmp = *hp; - if (detail->match(tmp, key)) { - if (cache_is_expired(detail, tmp)) { - *hp = tmp->next; - tmp->next = NULL; - detail->entries --; - freeme = tmp; - break; - } - cache_get(tmp); - write_unlock(&detail->hash_lock); - cache_put(new, detail); - return tmp; + hlist_for_each_entry_rcu(tmp, head, cache_list, + lockdep_is_held(&detail->hash_lock)) { + if (!detail->match(tmp, key)) + continue; + if (test_bit(CACHE_VALID, &tmp->flags) && + cache_is_expired(detail, tmp)) { + sunrpc_begin_cache_remove_entry(tmp, detail); + trace_cache_entry_expired(detail, tmp); + freeme = tmp; + break; } + cache_get(tmp); + spin_unlock(&detail->hash_lock); + cache_put(new, detail); + return tmp; } - new->next = *head; - *head = new; + + hlist_add_head_rcu(&new->cache_list, head); detail->entries++; + if (detail->nextcheck > new->expiry_time) + detail->nextcheck = new->expiry_time + 1; cache_get(new); - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); if (freeme) - cache_put(freeme, detail); + sunrpc_end_cache_remove_entry(freeme, detail); return new; } -EXPORT_SYMBOL_GPL(sunrpc_cache_lookup); +struct cache_head *sunrpc_cache_lookup_rcu(struct cache_detail *detail, + struct cache_head *key, int hash) +{ + struct cache_head *ret; + + ret = sunrpc_cache_find_rcu(detail, key, hash); + if (ret) + return ret; + /* Didn't find anything, insert an empty entry */ + return sunrpc_cache_add_entry(detail, key, hash); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_lookup_rcu); static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch); -static void cache_fresh_locked(struct cache_head *head, time_t expiry) +static void cache_fresh_locked(struct cache_head *head, time64_t expiry, + struct cache_detail *detail) { + time64_t now = seconds_since_boot(); + if (now <= detail->flush_time) + /* ensure it isn't immediately treated as expired */ + now = detail->flush_time + 1; head->expiry_time = expiry; - head->last_refresh = seconds_since_boot(); + head->last_refresh = now; smp_wmb(); /* paired with smp_rmb() in cache_is_valid() */ set_bit(CACHE_VALID, &head->flags); } @@ -135,6 +182,25 @@ static void cache_fresh_unlocked(struct cache_head *head, } } +static void cache_make_negative(struct cache_detail *detail, + struct cache_head *h) +{ + set_bit(CACHE_NEGATIVE, &h->flags); + trace_cache_entry_make_negative(detail, h); +} + +static void cache_entry_update(struct cache_detail *detail, + struct cache_head *h, + struct cache_head *new) +{ + if (!test_bit(CACHE_NEGATIVE, &new->flags)) { + detail->update(h, new); + trace_cache_entry_update(detail, h); + } else { + cache_make_negative(detail, h); + } +} + struct cache_head *sunrpc_cache_update(struct cache_detail *detail, struct cache_head *new, struct cache_head *old, int hash) { @@ -142,22 +208,18 @@ struct cache_head *sunrpc_cache_update(struct cache_detail *detail, * If 'old' is not VALID, we update it directly, * otherwise we need to replace it */ - struct cache_head **head; struct cache_head *tmp; if (!test_bit(CACHE_VALID, &old->flags)) { - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); if (!test_bit(CACHE_VALID, &old->flags)) { - if (test_bit(CACHE_NEGATIVE, &new->flags)) - set_bit(CACHE_NEGATIVE, &old->flags); - else - detail->update(old, new); - cache_fresh_locked(old, new->expiry_time); - write_unlock(&detail->hash_lock); + cache_entry_update(detail, old, new); + cache_fresh_locked(old, new->expiry_time, detail); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(old, detail); return old; } - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); } /* We need to insert a new entry */ tmp = detail->alloc(); @@ -165,22 +227,17 @@ struct cache_head *sunrpc_cache_update(struct cache_detail *detail, cache_put(old, detail); return NULL; } - cache_init(tmp); + cache_init(tmp, detail); detail->init(tmp, old); - head = &detail->hash_table[hash]; - write_lock(&detail->hash_lock); - if (test_bit(CACHE_NEGATIVE, &new->flags)) - set_bit(CACHE_NEGATIVE, &tmp->flags); - else - detail->update(tmp, new); - tmp->next = *head; - *head = tmp; + spin_lock(&detail->hash_lock); + cache_entry_update(detail, tmp, new); + hlist_add_head(&tmp->cache_list, &detail->hash_table[hash]); detail->entries++; cache_get(tmp); - cache_fresh_locked(tmp, new->expiry_time); - cache_fresh_locked(old, 0); - write_unlock(&detail->hash_lock); + cache_fresh_locked(tmp, new->expiry_time, detail); + cache_fresh_locked(old, 0, detail); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(tmp, detail); cache_fresh_unlocked(old, detail); cache_put(old, detail); @@ -188,13 +245,6 @@ struct cache_head *sunrpc_cache_update(struct cache_detail *detail, } EXPORT_SYMBOL_GPL(sunrpc_cache_update); -static int cache_make_upcall(struct cache_detail *cd, struct cache_head *h) -{ - if (cd->cache_upcall) - return cd->cache_upcall(cd, h); - return sunrpc_cache_pipe_upcall(cd, h); -} - static inline int cache_is_valid(struct cache_head *h) { if (!test_bit(CACHE_VALID, &h->flags)) @@ -220,37 +270,24 @@ static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h { int rv; - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); rv = cache_is_valid(h); if (rv == -EAGAIN) { - set_bit(CACHE_NEGATIVE, &h->flags); - cache_fresh_locked(h, seconds_since_boot()+CACHE_NEW_EXPIRY); + cache_make_negative(detail, h); + cache_fresh_locked(h, seconds_since_boot()+CACHE_NEW_EXPIRY, + detail); rv = -ENOENT; } - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(h, detail); return rv; } -/* - * This is the generic cache management routine for all - * the authentication caches. - * It checks the currency of a cache item and will (later) - * initiate an upcall to fill it if needed. - * - * - * Returns 0 if the cache_head can be used, or cache_puts it and returns - * -EAGAIN if upcall is pending and request has been queued - * -ETIMEDOUT if upcall failed or request could not be queue or - * upcall completed but item is still invalid (implying that - * the cache item has been replaced with a newer one). - * -ENOENT if cache entry was negative - */ -int cache_check(struct cache_detail *detail, +int cache_check_rcu(struct cache_detail *detail, struct cache_head *h, struct cache_req *rqstp) { int rv; - long refresh_age, age; + time64_t refresh_age, age; /* First decide return status as best we can */ rv = cache_is_valid(h); @@ -264,17 +301,15 @@ int cache_check(struct cache_detail *detail, rv = -ENOENT; } else if (rv == -EAGAIN || (h->expiry_time != 0 && age > refresh_age/2)) { - dprintk("RPC: Want update, refage=%ld, age=%ld\n", + dprintk("RPC: Want update, refage=%lld, age=%lld\n", refresh_age, age); - if (!test_and_set_bit(CACHE_PENDING, &h->flags)) { - switch (cache_make_upcall(detail, h)) { - case -EINVAL: - rv = try_to_negate_entry(detail, h); - break; - case -EAGAIN: - cache_fresh_unlocked(h, detail); - break; - } + switch (detail->cache_upcall(detail, h)) { + case -EINVAL: + rv = try_to_negate_entry(detail, h); + break; + case -EAGAIN: + cache_fresh_unlocked(h, detail); + break; } } @@ -289,6 +324,31 @@ int cache_check(struct cache_detail *detail, rv = -ETIMEDOUT; } } + + return rv; +} +EXPORT_SYMBOL_GPL(cache_check_rcu); + +/* + * This is the generic cache management routine for all + * the authentication caches. + * It checks the currency of a cache item and will (later) + * initiate an upcall to fill it if needed. + * + * + * Returns 0 if the cache_head can be used, or cache_puts it and returns + * -EAGAIN if upcall is pending and request has been queued + * -ETIMEDOUT if upcall failed or request could not be queue or + * upcall completed but item is still invalid (implying that + * the cache item has been replaced with a newer one). + * -ENOENT if cache entry was negative + */ +int cache_check(struct cache_detail *detail, + struct cache_head *h, struct cache_req *rqstp) +{ + int rv; + + rv = cache_check_rcu(detail, h, rqstp); if (rv) cache_put(h, detail); return rv; @@ -337,19 +397,19 @@ static struct delayed_work cache_cleaner; void sunrpc_init_cache_detail(struct cache_detail *cd) { - rwlock_init(&cd->hash_lock); + spin_lock_init(&cd->hash_lock); INIT_LIST_HEAD(&cd->queue); spin_lock(&cache_list_lock); cd->nextcheck = 0; cd->entries = 0; - atomic_set(&cd->readers, 0); + atomic_set(&cd->writers, 0); cd->last_close = 0; cd->last_warn = -1; list_add(&cd->others, &cache_list); spin_unlock(&cache_list_lock); /* start the cleaning process */ - schedule_delayed_work(&cache_cleaner, 0); + queue_delayed_work(system_power_efficient_wq, &cache_cleaner, 0); } EXPORT_SYMBOL_GPL(sunrpc_init_cache_detail); @@ -357,24 +417,16 @@ void sunrpc_destroy_cache_detail(struct cache_detail *cd) { cache_purge(cd); spin_lock(&cache_list_lock); - write_lock(&cd->hash_lock); - if (cd->entries || atomic_read(&cd->inuse)) { - write_unlock(&cd->hash_lock); - spin_unlock(&cache_list_lock); - goto out; - } + spin_lock(&cd->hash_lock); if (current_detail == cd) current_detail = NULL; list_del_init(&cd->others); - write_unlock(&cd->hash_lock); + spin_unlock(&cd->hash_lock); spin_unlock(&cache_list_lock); if (list_empty(&cache_list)) { /* module must be being unloaded so its safe to kill the worker */ cancel_delayed_work_sync(&cache_cleaner); } - return; -out: - printk(KERN_ERR "nfsd: failed to unregister %s cache\n", cd->name); } EXPORT_SYMBOL_GPL(sunrpc_destroy_cache_detail); @@ -412,48 +464,45 @@ static int cache_clean(void) } } + spin_lock(¤t_detail->hash_lock); + /* find a non-empty bucket in the table */ - while (current_detail && - current_index < current_detail->hash_size && - current_detail->hash_table[current_index] == NULL) + while (current_index < current_detail->hash_size && + hlist_empty(¤t_detail->hash_table[current_index])) current_index++; /* find a cleanable entry in the bucket and clean it, or set to next bucket */ - - if (current_detail && current_index < current_detail->hash_size) { - struct cache_head *ch, **cp; + if (current_index < current_detail->hash_size) { + struct cache_head *ch = NULL; struct cache_detail *d; - - write_lock(¤t_detail->hash_lock); + struct hlist_head *head; + struct hlist_node *tmp; /* Ok, now to clean this strand */ - - cp = & current_detail->hash_table[current_index]; - for (ch = *cp ; ch ; cp = & ch->next, ch = *cp) { + head = ¤t_detail->hash_table[current_index]; + hlist_for_each_entry_safe(ch, tmp, head, cache_list) { if (current_detail->nextcheck > ch->expiry_time) current_detail->nextcheck = ch->expiry_time+1; if (!cache_is_expired(current_detail, ch)) continue; - *cp = ch->next; - ch->next = NULL; - current_detail->entries--; + sunrpc_begin_cache_remove_entry(ch, current_detail); + trace_cache_entry_expired(current_detail, ch); rv = 1; break; } - write_unlock(¤t_detail->hash_lock); + spin_unlock(¤t_detail->hash_lock); d = current_detail; if (!ch) current_index ++; spin_unlock(&cache_list_lock); - if (ch) { - set_bit(CACHE_CLEANED, &ch->flags); - cache_fresh_unlocked(ch, d); - cache_put(ch, d); - } - } else + if (ch) + sunrpc_end_cache_remove_entry(ch, d); + } else { + spin_unlock(¤t_detail->hash_lock); spin_unlock(&cache_list_lock); + } return rv; } @@ -463,15 +512,17 @@ static int cache_clean(void) */ static void do_cache_clean(struct work_struct *work) { - int delay = 5; - if (cache_clean() == -1) - delay = round_jiffies_relative(30*HZ); + int delay; if (list_empty(&cache_list)) - delay = 0; + return; - if (delay) - schedule_delayed_work(&cache_cleaner, delay); + if (cache_clean() == -1) + delay = round_jiffies_relative(30*HZ); + else + delay = 5; + + queue_delayed_work(system_power_efficient_wq, &cache_cleaner, delay); } @@ -491,10 +542,29 @@ EXPORT_SYMBOL_GPL(cache_flush); void cache_purge(struct cache_detail *detail) { - detail->flush_time = LONG_MAX; - detail->nextcheck = seconds_since_boot(); - cache_flush(); - detail->flush_time = 1; + struct cache_head *ch = NULL; + struct hlist_head *head = NULL; + int i = 0; + + spin_lock(&detail->hash_lock); + if (!detail->entries) { + spin_unlock(&detail->hash_lock); + return; + } + + dprintk("RPC: %d entries in %s cache\n", detail->entries, detail->name); + for (i = 0; i < detail->hash_size; i++) { + head = &detail->hash_table[i]; + while (!hlist_empty(head)) { + ch = hlist_entry(head->first, struct cache_head, + cache_list); + sunrpc_begin_cache_remove_entry(ch, detail); + spin_unlock(&detail->hash_lock); + sunrpc_end_cache_remove_entry(ch, detail); + spin_lock(&detail->hash_lock); + } + } + spin_unlock(&detail->hash_lock); } EXPORT_SYMBOL_GPL(cache_purge); @@ -619,7 +689,7 @@ static void cache_limit_defers(void) /* Consider removing either the first or the last */ if (cache_defer_cnt > DFR_MAX) { - if (net_random() & 1) + if (get_random_u32_below(2)) discard = list_entry(cache_defer_list.next, struct cache_deferred_req, recent); else @@ -632,16 +702,30 @@ static void cache_limit_defers(void) discard->revisit(discard, 1); } +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) +static inline bool cache_defer_immediately(void) +{ + return !fail_sunrpc.ignore_cache_wait && + should_fail(&fail_sunrpc.attr, 1); +} +#else +static inline bool cache_defer_immediately(void) +{ + return false; +} +#endif + /* Return true if and only if a deferred request is queued. */ static bool cache_defer_req(struct cache_req *req, struct cache_head *item) { struct cache_deferred_req *dreq; - if (req->thread_wait) { + if (!cache_defer_immediately()) { cache_wait_req(req, item); if (!test_bit(CACHE_PENDING, &item->flags)) return false; } + dreq = req->defer(req); if (dreq == NULL) return false; @@ -659,11 +743,10 @@ static bool cache_defer_req(struct cache_req *req, struct cache_head *item) static void cache_revisit_request(struct cache_head *item) { struct cache_deferred_req *dreq; - struct list_head pending; struct hlist_node *tmp; int hash = DFR_HASH(item); + LIST_HEAD(pending); - INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash) @@ -684,10 +767,8 @@ static void cache_revisit_request(struct cache_head *item) void cache_clean_deferred(void *owner) { struct cache_deferred_req *dreq, *tmp; - struct list_head pending; - + LIST_HEAD(pending); - INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { @@ -708,7 +789,7 @@ void cache_clean_deferred(void *owner) /* * communicate with user-space * - * We have a magic /proc file - /proc/sunrpc/<cachename>/channel. + * We have a magic /proc file - /proc/net/rpc/<cachename>/channel. * On read, you get a full request, or block. * On write, an update request is processed. * Poll works if anything to read, and always allows write. @@ -722,7 +803,6 @@ void cache_clean_deferred(void *owner) */ static DEFINE_SPINLOCK(queue_lock); -static DEFINE_MUTEX(queue_io_mutex); struct cache_queue { struct list_head list; @@ -748,7 +828,7 @@ static int cache_request(struct cache_detail *detail, detail->cache_request(detail, crq->item, &bp, &len); if (len < 0) - return -EAGAIN; + return -E2BIG; return PAGE_SIZE - len; } @@ -763,7 +843,7 @@ static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, if (count == 0) return 0; - mutex_lock(&inode->i_mutex); /* protect against multiple concurrent + inode_lock(inode); /* protect against multiple concurrent * readers on this file */ again: spin_lock(&queue_lock); @@ -776,7 +856,7 @@ static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, } if (rp->q.list.next == &cd->queue) { spin_unlock(&queue_lock); - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); WARN_ON_ONCE(rp->offset); return 0; } @@ -830,7 +910,7 @@ static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, } if (err == -EAGAIN) goto again; - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return err ? err : count; } @@ -850,44 +930,26 @@ static ssize_t cache_do_downcall(char *kaddr, const char __user *buf, return ret; } -static ssize_t cache_slow_downcall(const char __user *buf, - size_t count, struct cache_detail *cd) -{ - static char write_buf[8192]; /* protected by queue_io_mutex */ - ssize_t ret = -EINVAL; - - if (count >= sizeof(write_buf)) - goto out; - mutex_lock(&queue_io_mutex); - ret = cache_do_downcall(write_buf, buf, count, cd); - mutex_unlock(&queue_io_mutex); -out: - return ret; -} - static ssize_t cache_downcall(struct address_space *mapping, const char __user *buf, size_t count, struct cache_detail *cd) { - struct page *page; - char *kaddr; + char *write_buf; ssize_t ret = -ENOMEM; - if (count >= PAGE_CACHE_SIZE) - goto out_slow; + if (count >= 32768) { /* 32k is max userland buffer, lets check anyway */ + ret = -EINVAL; + goto out; + } - page = find_or_create_page(mapping, 0, GFP_KERNEL); - if (!page) - goto out_slow; + write_buf = kvmalloc(count + 1, GFP_KERNEL); + if (!write_buf) + goto out; - kaddr = kmap(page); - ret = cache_do_downcall(kaddr, buf, count, cd); - kunmap(page); - unlock_page(page); - page_cache_release(page); + ret = cache_do_downcall(write_buf, buf, count, cd); + kvfree(write_buf); +out: return ret; -out_slow: - return cache_slow_downcall(buf, count, cd); } static ssize_t cache_write(struct file *filp, const char __user *buf, @@ -901,26 +963,26 @@ static ssize_t cache_write(struct file *filp, const char __user *buf, if (!cd->cache_parse) goto out; - mutex_lock(&inode->i_mutex); + inode_lock(inode); ret = cache_downcall(mapping, buf, count, cd); - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); out: return ret; } static DECLARE_WAIT_QUEUE_HEAD(queue_wait); -static unsigned int cache_poll(struct file *filp, poll_table *wait, +static __poll_t cache_poll(struct file *filp, poll_table *wait, struct cache_detail *cd) { - unsigned int mask; + __poll_t mask; struct cache_reader *rp = filp->private_data; struct cache_queue *cq; poll_wait(filp, &queue_wait, wait); /* alway allow write */ - mask = POLL_OUT | POLLWRNORM; + mask = EPOLLOUT | EPOLLWRNORM; if (!rp) return mask; @@ -930,7 +992,7 @@ static unsigned int cache_poll(struct file *filp, poll_table *wait, for (cq= &rp->q; &cq->list != &cd->queue; cq = list_entry(cq->list.next, struct cache_queue, list)) if (!cq->reader) { - mask |= POLLIN | POLLRDNORM; + mask |= EPOLLIN | EPOLLRDNORM; break; } spin_unlock(&queue_lock); @@ -982,11 +1044,13 @@ static int cache_open(struct inode *inode, struct file *filp, } rp->offset = 0; rp->q.reader = 1; - atomic_inc(&cd->readers); + spin_lock(&queue_lock); list_add(&rp->q.list, &cd->queue); spin_unlock(&queue_lock); } + if (filp->f_mode & FMODE_WRITE) + atomic_inc(&cd->writers); filp->private_data = rp; return 0; } @@ -1015,8 +1079,10 @@ static int cache_release(struct inode *inode, struct file *filp, filp->private_data = NULL; kfree(rp); + } + if (filp->f_mode & FMODE_WRITE) { + atomic_dec(&cd->writers); cd->last_close = seconds_since_boot(); - atomic_dec(&cd->readers); } module_put(cd->owner); return 0; @@ -1028,9 +1094,8 @@ static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch) { struct cache_queue *cq, *tmp; struct cache_request *cr; - struct list_head dequeued; + LIST_HEAD(dequeued); - INIT_LIST_HEAD(&dequeued); spin_lock(&queue_lock); list_for_each_entry_safe(cq, tmp, &detail->queue, list) if (!cq->reader) { @@ -1067,30 +1132,17 @@ void qword_add(char **bpp, int *lp, char *str) { char *bp = *bpp; int len = *lp; - char c; + int ret; if (len < 0) return; - while ((c=*str++) && len) - switch(c) { - case ' ': - case '\t': - case '\n': - case '\\': - if (len >= 4) { - *bp++ = '\\'; - *bp++ = '0' + ((c & 0300)>>6); - *bp++ = '0' + ((c & 0070)>>3); - *bp++ = '0' + ((c & 0007)>>0); - } - len -= 4; - break; - default: - *bp++ = c; - len--; - } - if (c || len <1) len = -1; - else { + ret = string_escape_str(str, bp, len, ESCAPE_OCTAL, "\\ \n\t"); + if (ret >= len) { + bp += len; + len = -1; + } else { + bp += ret; + len -= ret; *bp++ = ' '; len--; } @@ -1111,9 +1163,7 @@ void qword_addhex(char **bpp, int *lp, char *buf, int blen) *bp++ = 'x'; len -= 2; while (blen && len >= 2) { - unsigned char c = *buf++; - *bp++ = '0' + ((c&0xf0)>>4) + (c>=0xa0)*('a'-'9'-1); - *bp++ = '0' + (c&0x0f) + ((c&0x0f)>=0x0a)*('a'-'9'-1); + bp = hex_byte_pack(bp, *buf++); len -= 2; blen--; } @@ -1139,7 +1189,7 @@ static void warn_no_listener(struct cache_detail *detail) static bool cache_listeners_exist(struct cache_detail *detail) { - if (atomic_read(&detail->readers)) + if (atomic_read(&detail->writers)) return true; if (detail->last_close == 0) /* This cache was never opened */ @@ -1160,20 +1210,12 @@ static bool cache_listeners_exist(struct cache_detail *detail) * * Each request is at most one page long. */ -int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) +static int cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) { - char *buf; struct cache_request *crq; int ret = 0; - if (!detail->cache_request) - return -EINVAL; - - if (!cache_listeners_exist(detail)) { - warn_no_listener(detail); - return -EINVAL; - } if (test_bit(CACHE_CLEANED, &h->flags)) /* Too late to make an upcall */ return -EAGAIN; @@ -1189,14 +1231,15 @@ int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) } crq->q.reader = 0; - crq->item = cache_get(h); crq->buf = buf; crq->len = 0; crq->readers = 0; spin_lock(&queue_lock); - if (test_bit(CACHE_PENDING, &h->flags)) + if (test_bit(CACHE_PENDING, &h->flags)) { + crq->item = cache_get(h); list_add_tail(&crq->q.list, &detail->queue); - else + trace_cache_entry_upcall(detail, h); + } else /* Lost a race, no longer PENDING, so don't enqueue */ ret = -EAGAIN; spin_unlock(&queue_lock); @@ -1207,8 +1250,27 @@ int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) } return ret; } + +int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) +{ + if (test_and_set_bit(CACHE_PENDING, &h->flags)) + return 0; + return cache_pipe_upcall(detail, h); +} EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall); +int sunrpc_cache_pipe_upcall_timeout(struct cache_detail *detail, + struct cache_head *h) +{ + if (!cache_listeners_exist(detail)) { + warn_no_listener(detail); + trace_cache_entry_no_listener(detail, h); + return -EINVAL; + } + return sunrpc_cache_pipe_upcall(detail, h); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall_timeout); + /* * parse a message from user-space and pass it * to an appropriate cache @@ -1232,7 +1294,7 @@ int qword_get(char **bpp, char *dest, int bufsize) if (bp[0] == '\\' && bp[1] == 'x') { /* HEX STRING */ bp += 2; - while (len < bufsize) { + while (len < bufsize - 1) { int h, l; h = hex_to_bin(bp[0]); @@ -1278,32 +1340,25 @@ EXPORT_SYMBOL_GPL(qword_get); /* - * support /proc/sunrpc/cache/$CACHENAME/content + * support /proc/net/rpc/$CACHENAME/content * as a seqfile. * We call ->cache_show passing NULL for the item to * get a header, then pass each real item in the cache */ -struct handle { - struct cache_detail *cd; -}; - -static void *c_start(struct seq_file *m, loff_t *pos) - __acquires(cd->hash_lock) +static void *__cache_seq_start(struct seq_file *m, loff_t *pos) { loff_t n = *pos; unsigned int hash, entry; struct cache_head *ch; - struct cache_detail *cd = ((struct handle*)m->private)->cd; + struct cache_detail *cd = m->private; - - read_lock(&cd->hash_lock); if (!n--) return SEQ_START_TOKEN; hash = n >> 32; entry = n & ((1LL<<32) - 1); - for (ch=cd->hash_table[hash]; ch; ch=ch->next) + hlist_for_each_entry_rcu(ch, &cd->hash_table[hash], cache_list) if (!entry--) return ch; n &= ~((1LL<<32) - 1); @@ -1311,100 +1366,119 @@ static void *c_start(struct seq_file *m, loff_t *pos) hash++; n += 1LL<<32; } while(hash < cd->hash_size && - cd->hash_table[hash]==NULL); + hlist_empty(&cd->hash_table[hash])); if (hash >= cd->hash_size) return NULL; *pos = n+1; - return cd->hash_table[hash]; + return hlist_entry_safe(rcu_dereference_raw( + hlist_first_rcu(&cd->hash_table[hash])), + struct cache_head, cache_list); } -static void *c_next(struct seq_file *m, void *p, loff_t *pos) +static void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) { struct cache_head *ch = p; int hash = (*pos >> 32); - struct cache_detail *cd = ((struct handle*)m->private)->cd; + struct cache_detail *cd = m->private; if (p == SEQ_START_TOKEN) hash = 0; - else if (ch->next == NULL) { + else if (ch->cache_list.next == NULL) { hash++; *pos += 1LL<<32; } else { ++*pos; - return ch->next; + return hlist_entry_safe(rcu_dereference_raw( + hlist_next_rcu(&ch->cache_list)), + struct cache_head, cache_list); } *pos &= ~((1LL<<32) - 1); while (hash < cd->hash_size && - cd->hash_table[hash] == NULL) { + hlist_empty(&cd->hash_table[hash])) { hash++; *pos += 1LL<<32; } if (hash >= cd->hash_size) return NULL; ++*pos; - return cd->hash_table[hash]; + return hlist_entry_safe(rcu_dereference_raw( + hlist_first_rcu(&cd->hash_table[hash])), + struct cache_head, cache_list); } -static void c_stop(struct seq_file *m, void *p) - __releases(cd->hash_lock) +void *cache_seq_start_rcu(struct seq_file *m, loff_t *pos) + __acquires(RCU) { - struct cache_detail *cd = ((struct handle*)m->private)->cd; - read_unlock(&cd->hash_lock); + rcu_read_lock(); + return __cache_seq_start(m, pos); } +EXPORT_SYMBOL_GPL(cache_seq_start_rcu); + +void *cache_seq_next_rcu(struct seq_file *file, void *p, loff_t *pos) +{ + return cache_seq_next(file, p, pos); +} +EXPORT_SYMBOL_GPL(cache_seq_next_rcu); + +void cache_seq_stop_rcu(struct seq_file *m, void *p) + __releases(RCU) +{ + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(cache_seq_stop_rcu); static int c_show(struct seq_file *m, void *p) { struct cache_head *cp = p; - struct cache_detail *cd = ((struct handle*)m->private)->cd; + struct cache_detail *cd = m->private; if (p == SEQ_START_TOKEN) return cd->cache_show(m, cd, NULL); ifdebug(CACHE) - seq_printf(m, "# expiry=%ld refcnt=%d flags=%lx\n", + seq_printf(m, "# expiry=%lld refcnt=%d flags=%lx\n", convert_to_wallclock(cp->expiry_time), - atomic_read(&cp->ref.refcount), cp->flags); - cache_get(cp); - if (cache_check(cd, cp, NULL)) - /* cache_check does a cache_put on failure */ - seq_printf(m, "# "); - else { - if (cache_is_expired(cd, cp)) - seq_printf(m, "# "); - cache_put(cp, cd); - } + kref_read(&cp->ref), cp->flags); + + if (cache_check_rcu(cd, cp, NULL)) + seq_puts(m, "# "); + else if (cache_is_expired(cd, cp)) + seq_puts(m, "# "); return cd->cache_show(m, cd, cp); } static const struct seq_operations cache_content_op = { - .start = c_start, - .next = c_next, - .stop = c_stop, + .start = cache_seq_start_rcu, + .next = cache_seq_next_rcu, + .stop = cache_seq_stop_rcu, .show = c_show, }; static int content_open(struct inode *inode, struct file *file, struct cache_detail *cd) { - struct handle *han; + struct seq_file *seq; + int err; if (!cd || !try_module_get(cd->owner)) return -EACCES; - han = __seq_open_private(file, &cache_content_op, sizeof(*han)); - if (han == NULL) { + + err = seq_open(file, &cache_content_op); + if (err) { module_put(cd->owner); - return -ENOMEM; + return err; } - han->cd = cd; + seq = file->private_data; + seq->private = cd; return 0; } static int content_release(struct inode *inode, struct file *file, struct cache_detail *cd) { - int ret = seq_release_private(inode, file); + int ret = seq_release(inode, file); module_put(cd->owner); return ret; } @@ -1429,20 +1503,11 @@ static ssize_t read_flush(struct file *file, char __user *buf, struct cache_detail *cd) { char tbuf[22]; - unsigned long p = *ppos; size_t len; - snprintf(tbuf, sizeof(tbuf), "%lu\n", convert_to_wallclock(cd->flush_time)); - len = strlen(tbuf); - if (p >= len) - return 0; - len -= p; - if (len > count) - len = count; - if (copy_to_user(buf, (void*)(tbuf+p), len)) - return -EFAULT; - *ppos += len; - return len; + len = snprintf(tbuf, sizeof(tbuf), "%llu\n", + convert_to_wallclock(cd->flush_time)); + return simple_read_from_buffer(buf, count, ppos, tbuf, len); } static ssize_t write_flush(struct file *file, const char __user *buf, @@ -1450,7 +1515,8 @@ static ssize_t write_flush(struct file *file, const char __user *buf, struct cache_detail *cd) { char tbuf[20]; - char *bp, *ep; + char *ep; + time64_t now; if (*ppos || count > sizeof(tbuf)-1) return -EINVAL; @@ -1460,12 +1526,29 @@ static ssize_t write_flush(struct file *file, const char __user *buf, simple_strtoul(tbuf, &ep, 0); if (*ep && *ep != '\n') return -EINVAL; + /* Note that while we check that 'buf' holds a valid number, + * we always ignore the value and just flush everything. + * Making use of the number leads to races. + */ + + now = seconds_since_boot(); + /* Always flush everything, so behave like cache_purge() + * Do this by advancing flush_time to the current time, + * or by one second if it has already reached the current time. + * Newly added cache entries will always have ->last_refresh greater + * that ->flush_time, so they don't get flushed prematurely. + */ + + if (cd->flush_time >= now) + now = cd->flush_time + 1; - bp = tbuf; - cd->flush_time = get_expiry(&bp); - cd->nextcheck = seconds_since_boot(); + cd->flush_time = now; + cd->nextcheck = now; cache_flush(); + if (cd->flush) + cd->flush(); + *ppos += count; return count; } @@ -1473,7 +1556,7 @@ static ssize_t write_flush(struct file *file, const char __user *buf, static ssize_t cache_read_procfs(struct file *filp, char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE_DATA(file_inode(filp)); + struct cache_detail *cd = pde_data(file_inode(filp)); return cache_read(filp, buf, count, ppos, cd); } @@ -1481,14 +1564,14 @@ static ssize_t cache_read_procfs(struct file *filp, char __user *buf, static ssize_t cache_write_procfs(struct file *filp, const char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE_DATA(file_inode(filp)); + struct cache_detail *cd = pde_data(file_inode(filp)); return cache_write(filp, buf, count, ppos, cd); } -static unsigned int cache_poll_procfs(struct file *filp, poll_table *wait) +static __poll_t cache_poll_procfs(struct file *filp, poll_table *wait) { - struct cache_detail *cd = PDE_DATA(file_inode(filp)); + struct cache_detail *cd = pde_data(file_inode(filp)); return cache_poll(filp, wait, cd); } @@ -1497,67 +1580,65 @@ static long cache_ioctl_procfs(struct file *filp, unsigned int cmd, unsigned long arg) { struct inode *inode = file_inode(filp); - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return cache_ioctl(inode, filp, cmd, arg, cd); } static int cache_open_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return cache_open(inode, filp, cd); } static int cache_release_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return cache_release(inode, filp, cd); } -static const struct file_operations cache_file_operations_procfs = { - .owner = THIS_MODULE, - .llseek = no_llseek, - .read = cache_read_procfs, - .write = cache_write_procfs, - .poll = cache_poll_procfs, - .unlocked_ioctl = cache_ioctl_procfs, /* for FIONREAD */ - .open = cache_open_procfs, - .release = cache_release_procfs, +static const struct proc_ops cache_channel_proc_ops = { + .proc_read = cache_read_procfs, + .proc_write = cache_write_procfs, + .proc_poll = cache_poll_procfs, + .proc_ioctl = cache_ioctl_procfs, /* for FIONREAD */ + .proc_open = cache_open_procfs, + .proc_release = cache_release_procfs, }; static int content_open_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return content_open(inode, filp, cd); } static int content_release_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return content_release(inode, filp, cd); } -static const struct file_operations content_file_operations_procfs = { - .open = content_open_procfs, - .read = seq_read, - .llseek = seq_lseek, - .release = content_release_procfs, +static const struct proc_ops content_proc_ops = { + .proc_open = content_open_procfs, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_release = content_release_procfs, }; static int open_flush_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return open_flush(inode, filp, cd); } static int release_flush_procfs(struct inode *inode, struct file *filp) { - struct cache_detail *cd = PDE_DATA(inode); + struct cache_detail *cd = pde_data(inode); return release_flush(inode, filp, cd); } @@ -1565,7 +1646,7 @@ static int release_flush_procfs(struct inode *inode, struct file *filp) static ssize_t read_flush_procfs(struct file *filp, char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE_DATA(file_inode(filp)); + struct cache_detail *cd = pde_data(file_inode(filp)); return read_flush(filp, buf, count, ppos, cd); } @@ -1574,83 +1655,61 @@ static ssize_t write_flush_procfs(struct file *filp, const char __user *buf, size_t count, loff_t *ppos) { - struct cache_detail *cd = PDE_DATA(file_inode(filp)); + struct cache_detail *cd = pde_data(file_inode(filp)); return write_flush(filp, buf, count, ppos, cd); } -static const struct file_operations cache_flush_operations_procfs = { - .open = open_flush_procfs, - .read = read_flush_procfs, - .write = write_flush_procfs, - .release = release_flush_procfs, - .llseek = no_llseek, +static const struct proc_ops cache_flush_proc_ops = { + .proc_open = open_flush_procfs, + .proc_read = read_flush_procfs, + .proc_write = write_flush_procfs, + .proc_release = release_flush_procfs, }; -static void remove_cache_proc_entries(struct cache_detail *cd, struct net *net) +static void remove_cache_proc_entries(struct cache_detail *cd) { - struct sunrpc_net *sn; - - if (cd->u.procfs.proc_ent == NULL) - return; - if (cd->u.procfs.flush_ent) - remove_proc_entry("flush", cd->u.procfs.proc_ent); - if (cd->u.procfs.channel_ent) - remove_proc_entry("channel", cd->u.procfs.proc_ent); - if (cd->u.procfs.content_ent) - remove_proc_entry("content", cd->u.procfs.proc_ent); - cd->u.procfs.proc_ent = NULL; - sn = net_generic(net, sunrpc_net_id); - remove_proc_entry(cd->name, sn->proc_net_rpc); + if (cd->procfs) { + proc_remove(cd->procfs); + cd->procfs = NULL; + } } -#ifdef CONFIG_PROC_FS static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) { struct proc_dir_entry *p; struct sunrpc_net *sn; + if (!IS_ENABLED(CONFIG_PROC_FS)) + return 0; + sn = net_generic(net, sunrpc_net_id); - cd->u.procfs.proc_ent = proc_mkdir(cd->name, sn->proc_net_rpc); - if (cd->u.procfs.proc_ent == NULL) + cd->procfs = proc_mkdir(cd->name, sn->proc_net_rpc); + if (cd->procfs == NULL) goto out_nomem; - cd->u.procfs.channel_ent = NULL; - cd->u.procfs.content_ent = NULL; - p = proc_create_data("flush", S_IFREG|S_IRUSR|S_IWUSR, - cd->u.procfs.proc_ent, - &cache_flush_operations_procfs, cd); - cd->u.procfs.flush_ent = p; + p = proc_create_data("flush", S_IFREG | 0600, + cd->procfs, &cache_flush_proc_ops, cd); if (p == NULL) goto out_nomem; if (cd->cache_request || cd->cache_parse) { - p = proc_create_data("channel", S_IFREG|S_IRUSR|S_IWUSR, - cd->u.procfs.proc_ent, - &cache_file_operations_procfs, cd); - cd->u.procfs.channel_ent = p; + p = proc_create_data("channel", S_IFREG | 0600, cd->procfs, + &cache_channel_proc_ops, cd); if (p == NULL) goto out_nomem; } if (cd->cache_show) { - p = proc_create_data("content", S_IFREG|S_IRUSR, - cd->u.procfs.proc_ent, - &content_file_operations_procfs, cd); - cd->u.procfs.content_ent = p; + p = proc_create_data("content", S_IFREG | 0400, cd->procfs, + &content_proc_ops, cd); if (p == NULL) goto out_nomem; } return 0; out_nomem: - remove_cache_proc_entries(cd, net); + remove_cache_proc_entries(cd); return -ENOMEM; } -#else /* CONFIG_PROC_FS */ -static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) -{ - return 0; -} -#endif void __init cache_initialize(void) { @@ -1671,25 +1730,29 @@ EXPORT_SYMBOL_GPL(cache_register_net); void cache_unregister_net(struct cache_detail *cd, struct net *net) { - remove_cache_proc_entries(cd, net); + remove_cache_proc_entries(cd); sunrpc_destroy_cache_detail(cd); } EXPORT_SYMBOL_GPL(cache_unregister_net); -struct cache_detail *cache_create_net(struct cache_detail *tmpl, struct net *net) +struct cache_detail *cache_create_net(const struct cache_detail *tmpl, struct net *net) { struct cache_detail *cd; + int i; cd = kmemdup(tmpl, sizeof(struct cache_detail), GFP_KERNEL); if (cd == NULL) return ERR_PTR(-ENOMEM); - cd->hash_table = kzalloc(cd->hash_size * sizeof(struct cache_head *), + cd->hash_table = kcalloc(cd->hash_size, sizeof(struct hlist_head), GFP_KERNEL); if (cd->hash_table == NULL) { kfree(cd); return ERR_PTR(-ENOMEM); } + + for (i = 0; i < cd->hash_size; i++) + INIT_HLIST_HEAD(&cd->hash_table[i]); cd->net = net; return cd; } @@ -1718,7 +1781,7 @@ static ssize_t cache_write_pipefs(struct file *filp, const char __user *buf, return cache_write(filp, buf, count, ppos, cd); } -static unsigned int cache_poll_pipefs(struct file *filp, poll_table *wait) +static __poll_t cache_poll_pipefs(struct file *filp, poll_table *wait) { struct cache_detail *cd = RPC_I(file_inode(filp))->private; @@ -1750,7 +1813,6 @@ static int cache_release_pipefs(struct inode *inode, struct file *filp) const struct file_operations cache_file_operations_pipefs = { .owner = THIS_MODULE, - .llseek = no_llseek, .read = cache_read_pipefs, .write = cache_write_pipefs, .poll = cache_poll_pipefs, @@ -1816,7 +1878,6 @@ const struct file_operations cache_flush_operations_pipefs = { .read = read_flush_pipefs, .write = write_flush_pipefs, .release = release_flush_pipefs, - .llseek = no_llseek, }; int sunrpc_cache_register_pipefs(struct dentry *parent, @@ -1826,15 +1887,28 @@ int sunrpc_cache_register_pipefs(struct dentry *parent, struct dentry *dir = rpc_create_cache_dir(parent, name, umode, cd); if (IS_ERR(dir)) return PTR_ERR(dir); - cd->u.pipefs.dir = dir; + cd->pipefs = dir; return 0; } EXPORT_SYMBOL_GPL(sunrpc_cache_register_pipefs); void sunrpc_cache_unregister_pipefs(struct cache_detail *cd) { - rpc_remove_cache_dir(cd->u.pipefs.dir); - cd->u.pipefs.dir = NULL; + if (cd->pipefs) { + rpc_remove_cache_dir(cd->pipefs); + cd->pipefs = NULL; + } } EXPORT_SYMBOL_GPL(sunrpc_cache_unregister_pipefs); +void sunrpc_cache_unhash(struct cache_detail *cd, struct cache_head *h) +{ + spin_lock(&cd->hash_lock); + if (!hlist_unhashed(&h->cache_list)){ + sunrpc_begin_cache_remove_entry(h, cd); + spin_unlock(&cd->hash_lock); + sunrpc_end_cache_remove_entry(h, cd); + } else + spin_unlock(&cd->hash_lock); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_unhash); diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index 9963584605c0..58442ae1c2da 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/clnt.c * @@ -25,12 +26,12 @@ #include <linux/namei.h> #include <linux/mount.h> #include <linux/slab.h> +#include <linux/rcupdate.h> #include <linux/utsname.h> #include <linux/workqueue.h> #include <linux/in.h> #include <linux/in6.h> #include <linux/un.h> -#include <linux/rcupdate.h> #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/addr.h> @@ -40,45 +41,38 @@ #include <trace/events/sunrpc.h> #include "sunrpc.h" +#include "sysfs.h" #include "netns.h" -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_CALL #endif -#define dprint_status(t) \ - dprintk("RPC: %5u %s (status %d)\n", t->tk_pid, \ - __func__, t->tk_status) - -/* - * All RPC clients are linked into this list - */ - static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); - static void call_start(struct rpc_task *task); static void call_reserve(struct rpc_task *task); static void call_reserveresult(struct rpc_task *task); static void call_allocate(struct rpc_task *task); +static void call_encode(struct rpc_task *task); static void call_decode(struct rpc_task *task); static void call_bind(struct rpc_task *task); static void call_bind_status(struct rpc_task *task); static void call_transmit(struct rpc_task *task); -#if defined(CONFIG_SUNRPC_BACKCHANNEL) -static void call_bc_transmit(struct rpc_task *task); -#endif /* CONFIG_SUNRPC_BACKCHANNEL */ static void call_status(struct rpc_task *task); static void call_transmit_status(struct rpc_task *task); static void call_refresh(struct rpc_task *task); static void call_refreshresult(struct rpc_task *task); -static void call_timeout(struct rpc_task *task); static void call_connect(struct rpc_task *task); static void call_connect_status(struct rpc_task *task); -static __be32 *rpc_encode_header(struct rpc_task *task); -static __be32 *rpc_verify_header(struct rpc_task *task); +static int rpc_encode_header(struct rpc_task *task, + struct xdr_stream *xdr); +static int rpc_decode_header(struct rpc_task *task, + struct xdr_stream *xdr); static int rpc_ping(struct rpc_clnt *clnt); +static int rpc_ping_noreply(struct rpc_clnt *clnt); +static void rpc_check_timeout(struct rpc_task *task); static void rpc_register_client(struct rpc_clnt *clnt) { @@ -102,12 +96,7 @@ static void rpc_unregister_client(struct rpc_clnt *clnt) static void __rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) { - if (clnt->cl_dentry) { - if (clnt->cl_auth && clnt->cl_auth->au_ops->pipes_destroy) - clnt->cl_auth->au_ops->pipes_destroy(clnt->cl_auth); - rpc_remove_client_dir(clnt->cl_dentry); - } - clnt->cl_dentry = NULL; + rpc_remove_client_dir(clnt); } static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) @@ -117,88 +106,82 @@ static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) pipefs_sb = rpc_get_sb_net(net); if (pipefs_sb) { - __rpc_clnt_remove_pipedir(clnt); + if (pipefs_sb == clnt->pipefs_sb) + __rpc_clnt_remove_pipedir(clnt); rpc_put_sb_net(net); } } -static struct dentry *rpc_setup_pipedir_sb(struct super_block *sb, - struct rpc_clnt *clnt, - const char *dir_name) +static int rpc_setup_pipedir_sb(struct super_block *sb, + struct rpc_clnt *clnt) { static uint32_t clntid; + const char *dir_name = clnt->cl_program->pipe_dir_name; char name[15]; - struct dentry *dir, *dentry; + struct dentry *dir; + int err; dir = rpc_d_lookup_sb(sb, dir_name); if (dir == NULL) { pr_info("RPC: pipefs directory doesn't exist: %s\n", dir_name); - return dir; + return -ENOENT; } for (;;) { snprintf(name, sizeof(name), "clnt%x", (unsigned int)clntid++); name[sizeof(name) - 1] = '\0'; - dentry = rpc_create_client_dir(dir, name, clnt); - if (!IS_ERR(dentry)) + err = rpc_create_client_dir(dir, name, clnt); + if (!err) break; - if (dentry == ERR_PTR(-EEXIST)) + if (err == -EEXIST) continue; printk(KERN_INFO "RPC: Couldn't create pipefs entry" - " %s/%s, error %ld\n", - dir_name, name, PTR_ERR(dentry)); + " %s/%s, error %d\n", + dir_name, name, err); break; } dput(dir); - return dentry; + return err; } static int -rpc_setup_pipedir(struct rpc_clnt *clnt, const char *dir_name, - struct super_block *pipefs_sb) +rpc_setup_pipedir(struct super_block *pipefs_sb, struct rpc_clnt *clnt) { - struct dentry *dentry; + clnt->pipefs_sb = pipefs_sb; - clnt->cl_dentry = NULL; - if (dir_name == NULL) - return 0; - dentry = rpc_setup_pipedir_sb(pipefs_sb, clnt, dir_name); - if (IS_ERR(dentry)) - return PTR_ERR(dentry); - clnt->cl_dentry = dentry; + if (clnt->cl_program->pipe_dir_name != NULL) { + int err = rpc_setup_pipedir_sb(pipefs_sb, clnt); + if (err && err != -ENOENT) + return err; + } return 0; } -static inline int rpc_clnt_skip_event(struct rpc_clnt *clnt, unsigned long event) +static int rpc_clnt_skip_event(struct rpc_clnt *clnt, unsigned long event) { - if (((event == RPC_PIPEFS_MOUNT) && clnt->cl_dentry) || - ((event == RPC_PIPEFS_UMOUNT) && !clnt->cl_dentry)) - return 1; - if ((event == RPC_PIPEFS_MOUNT) && atomic_read(&clnt->cl_count) == 0) + if (clnt->cl_program->pipe_dir_name == NULL) return 1; + + switch (event) { + case RPC_PIPEFS_MOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry != NULL) + return 1; + if (refcount_read(&clnt->cl_count) == 0) + return 1; + break; + case RPC_PIPEFS_UMOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry == NULL) + return 1; + break; + } return 0; } static int __rpc_clnt_handle_event(struct rpc_clnt *clnt, unsigned long event, struct super_block *sb) { - struct dentry *dentry; - int err = 0; - switch (event) { case RPC_PIPEFS_MOUNT: - dentry = rpc_setup_pipedir_sb(sb, clnt, - clnt->cl_program->pipe_dir_name); - if (!dentry) - return -ENOENT; - if (IS_ERR(dentry)) - return PTR_ERR(dentry); - clnt->cl_dentry = dentry; - if (clnt->cl_auth->au_ops->pipes_create) { - err = clnt->cl_auth->au_ops->pipes_create(clnt->cl_auth); - if (err) - __rpc_clnt_remove_pipedir(clnt); - } - break; + return rpc_setup_pipedir_sb(sb, clnt); case RPC_PIPEFS_UMOUNT: __rpc_clnt_remove_pipedir(clnt); break; @@ -206,7 +189,7 @@ static int __rpc_clnt_handle_event(struct rpc_clnt *clnt, unsigned long event, printk(KERN_ERR "%s: unknown event: %ld\n", __func__, event); return -ENOTSUPP; } - return err; + return 0; } static int __rpc_pipefs_event(struct rpc_clnt *clnt, unsigned long event, @@ -230,8 +213,6 @@ static struct rpc_clnt *rpc_get_client_for_event(struct net *net, int event) spin_lock(&sn->rpc_client_lock); list_for_each_entry(clnt, &sn->all_clients, cl_clients) { - if (clnt->cl_program->pipe_dir_name == NULL) - continue; if (rpc_clnt_skip_event(clnt, event)) continue; spin_unlock(&sn->rpc_client_lock); @@ -271,26 +252,53 @@ void rpc_clients_notifier_unregister(void) return rpc_pipefs_notifier_unregister(&rpc_clients_block); } +static struct rpc_xprt *rpc_clnt_set_transport(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + const struct rpc_timeout *timeout) +{ + struct rpc_xprt *old; + + spin_lock(&clnt->cl_lock); + old = rcu_dereference_protected(clnt->cl_xprt, + lockdep_is_held(&clnt->cl_lock)); + + clnt->cl_timeout = timeout; + rcu_assign_pointer(clnt->cl_xprt, xprt); + spin_unlock(&clnt->cl_lock); + + return old; +} + static void rpc_clnt_set_nodename(struct rpc_clnt *clnt, const char *nodename) { - clnt->cl_nodelen = strlen(nodename); - if (clnt->cl_nodelen > UNX_MAXNODENAME) - clnt->cl_nodelen = UNX_MAXNODENAME; - memcpy(clnt->cl_nodename, nodename, clnt->cl_nodelen); + ssize_t copied; + + copied = strscpy(clnt->cl_nodename, + nodename, sizeof(clnt->cl_nodename)); + + clnt->cl_nodelen = copied < 0 + ? sizeof(clnt->cl_nodename) - 1 + : copied; } -static int rpc_client_register(const struct rpc_create_args *args, - struct rpc_clnt *clnt) +static int rpc_client_register(struct rpc_clnt *clnt, + rpc_authflavor_t pseudoflavor, + const char *client_name) { - const struct rpc_program *program = args->program; + struct rpc_auth_create_args auth_args = { + .pseudoflavor = pseudoflavor, + .target_name = client_name, + }; struct rpc_auth *auth; struct net *net = rpc_net_ns(clnt); struct super_block *pipefs_sb; int err; + rpc_clnt_debugfs_register(clnt); + pipefs_sb = rpc_get_sb_net(net); if (pipefs_sb) { - err = rpc_setup_pipedir(clnt, program->pipe_dir_name, pipefs_sb); + err = rpc_setup_pipedir(pipefs_sb, clnt); if (err) goto out; } @@ -299,34 +307,61 @@ static int rpc_client_register(const struct rpc_create_args *args, if (pipefs_sb) rpc_put_sb_net(net); - auth = rpcauth_create(args->authflavor, clnt); + auth = rpcauth_create(&auth_args, clnt); if (IS_ERR(auth)) { dprintk("RPC: Couldn't create auth handle (flavor %u)\n", - args->authflavor); + pseudoflavor); err = PTR_ERR(auth); goto err_auth; } return 0; err_auth: pipefs_sb = rpc_get_sb_net(net); + rpc_unregister_client(clnt); __rpc_clnt_remove_pipedir(clnt); out: if (pipefs_sb) rpc_put_sb_net(net); + rpc_sysfs_client_destroy(clnt); + rpc_clnt_debugfs_unregister(clnt); return err; } -static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, struct rpc_xprt *xprt) +static DEFINE_IDA(rpc_clids); + +void rpc_cleanup_clids(void) +{ + ida_destroy(&rpc_clids); +} + +static int rpc_alloc_clid(struct rpc_clnt *clnt) +{ + int clid; + + clid = ida_alloc(&rpc_clids, GFP_KERNEL); + if (clid < 0) + return clid; + clnt->cl_clid = clid; + return 0; +} + +static void rpc_free_clid(struct rpc_clnt *clnt) +{ + ida_free(&rpc_clids, clnt->cl_clid); +} + +static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, + struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, + struct rpc_clnt *parent) { const struct rpc_program *program = args->program; const struct rpc_version *version; - struct rpc_clnt *clnt = NULL; + struct rpc_clnt *clnt = NULL; + const struct rpc_timeout *timeout; + const char *nodename = args->nodename; int err; - /* sanity check the name before trying to print it */ - dprintk("RPC: creating %s client for %s (xprt %p)\n", - program->name, args->servername, xprt); - err = rpciod_up(); if (err) goto out_no_rpciod; @@ -342,16 +377,21 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, stru clnt = kzalloc(sizeof(*clnt), GFP_KERNEL); if (!clnt) goto out_err; - clnt->cl_parent = clnt; + clnt->cl_parent = parent ? : clnt; + clnt->cl_xprtsec = args->xprtsec; - rcu_assign_pointer(clnt->cl_xprt, xprt); + err = rpc_alloc_clid(clnt); + if (err) + goto out_no_clid; + + clnt->cl_cred = get_cred(args->cred); clnt->cl_procinfo = version->procs; clnt->cl_maxproc = version->nrprocs; - clnt->cl_protname = program->name; clnt->cl_prog = args->prognumber ? : program->number; clnt->cl_vers = version->number; - clnt->cl_stats = program->stats; + clnt->cl_stats = args->stats ? : program->stats; clnt->cl_metrics = rpc_alloc_iostats(clnt); + rpc_init_pipe_dir_head(&clnt->cl_pipedir_objects); err = -ENOMEM; if (clnt->cl_metrics == NULL) goto out_no_stats; @@ -359,48 +399,114 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, stru INIT_LIST_HEAD(&clnt->cl_tasks); spin_lock_init(&clnt->cl_lock); - if (!xprt_bound(xprt)) - clnt->cl_autobind = 1; - - clnt->cl_timeout = xprt->timeout; + timeout = xprt->timeout; if (args->timeout != NULL) { memcpy(&clnt->cl_timeout_default, args->timeout, sizeof(clnt->cl_timeout_default)); - clnt->cl_timeout = &clnt->cl_timeout_default; + timeout = &clnt->cl_timeout_default; } + rpc_clnt_set_transport(clnt, xprt, timeout); + xprt->main = true; + xprt_iter_init(&clnt->cl_xpi, xps); + xprt_switch_put(xps); + clnt->cl_rtt = &clnt->cl_rtt_default; rpc_init_rtt(&clnt->cl_rtt_default, clnt->cl_timeout->to_initval); - clnt->cl_principal = NULL; - if (args->client_name) { - clnt->cl_principal = kstrdup(args->client_name, GFP_KERNEL); - if (!clnt->cl_principal) - goto out_no_principal; - } - atomic_set(&clnt->cl_count, 1); + refcount_set(&clnt->cl_count, 1); + if (nodename == NULL) + nodename = utsname()->nodename; /* save the nodename */ - rpc_clnt_set_nodename(clnt, utsname()->nodename); + rpc_clnt_set_nodename(clnt, nodename); - err = rpc_client_register(args, clnt); + rpc_sysfs_client_setup(clnt, xps, rpc_net_ns(clnt)); + err = rpc_client_register(clnt, args->authflavor, args->client_name); if (err) goto out_no_path; + if (parent) + refcount_inc(&parent->cl_count); + + trace_rpc_clnt_new(clnt, xprt, args); return clnt; out_no_path: - kfree(clnt->cl_principal); -out_no_principal: rpc_free_iostats(clnt->cl_metrics); out_no_stats: + put_cred(clnt->cl_cred); + rpc_free_clid(clnt); +out_no_clid: kfree(clnt); out_err: rpciod_down(); out_no_rpciod: + xprt_switch_put(xps); xprt_put(xprt); + trace_rpc_clnt_new_err(program->name, args->servername, err); return ERR_PTR(err); } +static struct rpc_clnt *rpc_create_xprt(struct rpc_create_args *args, + struct rpc_xprt *xprt) +{ + struct rpc_clnt *clnt = NULL; + struct rpc_xprt_switch *xps; + + if (args->bc_xprt && args->bc_xprt->xpt_bc_xps) { + WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC)); + xps = args->bc_xprt->xpt_bc_xps; + xprt_switch_get(xps); + } else { + xps = xprt_switch_alloc(xprt, GFP_KERNEL); + if (xps == NULL) { + xprt_put(xprt); + return ERR_PTR(-ENOMEM); + } + if (xprt->bc_xprt) { + xprt_switch_get(xps); + xprt->bc_xprt->xpt_bc_xps = xps; + } + } + clnt = rpc_new_client(args, xps, xprt, NULL); + if (IS_ERR(clnt)) + return clnt; + + if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { + int err = rpc_ping(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } else if (args->flags & RPC_CLNT_CREATE_CONNECTED) { + int err = rpc_ping_noreply(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } + + clnt->cl_softrtry = 1; + if (args->flags & (RPC_CLNT_CREATE_HARDRTRY|RPC_CLNT_CREATE_SOFTERR)) { + clnt->cl_softrtry = 0; + if (args->flags & RPC_CLNT_CREATE_SOFTERR) + clnt->cl_softerr = 1; + } + + if (args->flags & RPC_CLNT_CREATE_AUTOBIND) + clnt->cl_autobind = 1; + if (args->flags & RPC_CLNT_CREATE_NO_RETRANS_TIMEOUT) + clnt->cl_noretranstimeo = 1; + if (args->flags & RPC_CLNT_CREATE_DISCRTRY) + clnt->cl_discrtry = 1; + if (!(args->flags & RPC_CLNT_CREATE_QUIET)) + clnt->cl_chatty = 1; + if (args->flags & RPC_CLNT_CREATE_NETUNREACH_FATAL) + clnt->cl_netunreach_fatal = 1; + + return clnt; +} + /** * rpc_create - create an RPC client and transport with one call * @args: rpc_clnt create argument structure @@ -414,7 +520,6 @@ out_no_rpciod: struct rpc_clnt *rpc_create(struct rpc_create_args *args) { struct rpc_xprt *xprt; - struct rpc_clnt *clnt; struct xprt_create xprtargs = { .net = args->net, .ident = args->protocol, @@ -423,8 +528,22 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) .addrlen = args->addrsize, .servername = args->servername, .bc_xprt = args->bc_xprt, + .xprtsec = args->xprtsec, + .connect_timeout = args->connect_timeout, + .reconnect_timeout = args->reconnect_timeout, }; - char servername[48]; + char servername[RPC_MAXNETNAMELEN]; + struct rpc_clnt *clnt; + int i; + + if (args->bc_xprt) { + WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC)); + xprt = args->bc_xprt->xpt_bc_xprt; + if (xprt) { + xprt_get(xprt); + return rpc_create_xprt(args, xprt); + } + } if (args->flags & RPC_CLNT_CREATE_INFINITE_SLOTS) xprtargs.flags |= XPRT_CREATE_INFINITE_SLOTS; @@ -445,8 +564,12 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) servername[0] = '\0'; switch (args->address->sa_family) { case AF_LOCAL: - snprintf(servername, sizeof(servername), "%s", - sun->sun_path); + if (sun->sun_path[0]) + snprintf(servername, sizeof(servername), "%s", + sun->sun_path); + else + snprintf(servername, sizeof(servername), "@%s", + sun->sun_path+1); break; case AF_INET: snprintf(servername, sizeof(servername), "%pI4", @@ -477,30 +600,18 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) xprt->resvport = 1; if (args->flags & RPC_CLNT_CREATE_NONPRIVPORT) xprt->resvport = 0; + xprt->reuseport = 0; + if (args->flags & RPC_CLNT_CREATE_REUSEPORT) + xprt->reuseport = 1; - clnt = rpc_new_client(args, xprt); - if (IS_ERR(clnt)) + clnt = rpc_create_xprt(args, xprt); + if (IS_ERR(clnt) || args->nconnect <= 1) return clnt; - if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { - int err = rpc_ping(clnt); - if (err != 0) { - rpc_shutdown_client(clnt); - return ERR_PTR(err); - } + for (i = 0; i < args->nconnect - 1; i++) { + if (rpc_clnt_add_xprt(clnt, &xprtargs, NULL, NULL) < 0) + break; } - - clnt->cl_softrtry = 1; - if (args->flags & RPC_CLNT_CREATE_HARDRTRY) - clnt->cl_softrtry = 0; - - if (args->flags & RPC_CLNT_CREATE_AUTOBIND) - clnt->cl_autobind = 1; - if (args->flags & RPC_CLNT_CREATE_DISCRTRY) - clnt->cl_discrtry = 1; - if (!(args->flags & RPC_CLNT_CREATE_QUIET)) - clnt->cl_chatty = 1; - return clnt; } EXPORT_SYMBOL_GPL(rpc_create); @@ -513,6 +624,7 @@ EXPORT_SYMBOL_GPL(rpc_create); static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args, struct rpc_clnt *clnt) { + struct rpc_xprt_switch *xps; struct rpc_xprt *xprt; struct rpc_clnt *new; int err; @@ -520,29 +632,34 @@ static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args, err = -ENOMEM; rcu_read_lock(); xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); rcu_read_unlock(); - if (xprt == NULL) - goto out_err; - args->servername = xprt->servername; - - new = rpc_new_client(args, xprt); - if (IS_ERR(new)) { - err = PTR_ERR(new); + if (xprt == NULL || xps == NULL) { + xprt_put(xprt); + xprt_switch_put(xps); goto out_err; } + args->servername = xprt->servername; + args->nodename = clnt->cl_nodename; - atomic_inc(&clnt->cl_count); - new->cl_parent = clnt; + new = rpc_new_client(args, xps, xprt, clnt); + if (IS_ERR(new)) + return new; /* Turn off autobind on clones */ new->cl_autobind = 0; new->cl_softrtry = clnt->cl_softrtry; + new->cl_softerr = clnt->cl_softerr; + new->cl_noretranstimeo = clnt->cl_noretranstimeo; new->cl_discrtry = clnt->cl_discrtry; new->cl_chatty = clnt->cl_chatty; + new->cl_netunreach_fatal = clnt->cl_netunreach_fatal; + new->cl_principal = clnt->cl_principal; + new->cl_max_connect = clnt->cl_max_connect; return new; out_err: - dprintk("RPC: %s: returned error %d\n", __func__, err); + trace_rpc_clnt_clone_err(clnt, err); return ERR_PTR(err); } @@ -560,7 +677,8 @@ struct rpc_clnt *rpc_clone_client(struct rpc_clnt *clnt) .prognumber = clnt->cl_prog, .version = clnt->cl_vers, .authflavor = clnt->cl_auth->au_flavor, - .client_name = clnt->cl_principal, + .cred = clnt->cl_cred, + .stats = clnt->cl_stats, }; return __rpc_clone_client(&args, clnt); } @@ -582,12 +700,172 @@ rpc_clone_client_set_auth(struct rpc_clnt *clnt, rpc_authflavor_t flavor) .prognumber = clnt->cl_prog, .version = clnt->cl_vers, .authflavor = flavor, - .client_name = clnt->cl_principal, + .cred = clnt->cl_cred, + .stats = clnt->cl_stats, }; return __rpc_clone_client(&args, clnt); } EXPORT_SYMBOL_GPL(rpc_clone_client_set_auth); +/** + * rpc_switch_client_transport: switch the RPC transport on the fly + * @clnt: pointer to a struct rpc_clnt + * @args: pointer to the new transport arguments + * @timeout: pointer to the new timeout parameters + * + * This function allows the caller to switch the RPC transport for the + * rpc_clnt structure 'clnt' to allow it to connect to a mirrored NFS + * server, for instance. It assumes that the caller has ensured that + * there are no active RPC tasks by using some form of locking. + * + * Returns zero if "clnt" is now using the new xprt. Otherwise a + * negative errno is returned, and "clnt" continues to use the old + * xprt. + */ +int rpc_switch_client_transport(struct rpc_clnt *clnt, + struct xprt_create *args, + const struct rpc_timeout *timeout) +{ + const struct rpc_timeout *old_timeo; + rpc_authflavor_t pseudoflavor; + struct rpc_xprt_switch *xps, *oldxps; + struct rpc_xprt *xprt, *old; + struct rpc_clnt *parent; + int err; + + args->xprtsec = clnt->cl_xprtsec; + xprt = xprt_create_transport(args); + if (IS_ERR(xprt)) + return PTR_ERR(xprt); + + xps = xprt_switch_alloc(xprt, GFP_KERNEL); + if (xps == NULL) { + xprt_put(xprt); + return -ENOMEM; + } + + pseudoflavor = clnt->cl_auth->au_flavor; + + old_timeo = clnt->cl_timeout; + old = rpc_clnt_set_transport(clnt, xprt, timeout); + oldxps = xprt_iter_xchg_switch(&clnt->cl_xpi, xps); + + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); + rpc_sysfs_client_destroy(clnt); + rpc_clnt_debugfs_unregister(clnt); + + /* + * A new transport was created. "clnt" therefore + * becomes the root of a new cl_parent tree. clnt's + * children, if it has any, still point to the old xprt. + */ + parent = clnt->cl_parent; + clnt->cl_parent = clnt; + + /* + * The old rpc_auth cache cannot be re-used. GSS + * contexts in particular are between a single + * client and server. + */ + err = rpc_client_register(clnt, pseudoflavor, NULL); + if (err) + goto out_revert; + + synchronize_rcu(); + if (parent != clnt) + rpc_release_client(parent); + xprt_switch_put(oldxps); + xprt_put(old); + trace_rpc_clnt_replace_xprt(clnt); + return 0; + +out_revert: + xps = xprt_iter_xchg_switch(&clnt->cl_xpi, oldxps); + rpc_clnt_set_transport(clnt, old, old_timeo); + clnt->cl_parent = parent; + rpc_client_register(clnt, pseudoflavor, NULL); + xprt_switch_put(xps); + xprt_put(xprt); + trace_rpc_clnt_replace_xprt_err(clnt); + return err; +} +EXPORT_SYMBOL_GPL(rpc_switch_client_transport); + +static struct rpc_xprt_switch *rpc_clnt_xprt_switch_get(struct rpc_clnt *clnt) +{ + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); + + return xps; +} + +static +int _rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi, + void func(struct rpc_xprt_iter *xpi, struct rpc_xprt_switch *xps)) +{ + struct rpc_xprt_switch *xps; + + xps = rpc_clnt_xprt_switch_get(clnt); + if (xps == NULL) + return -EAGAIN; + func(xpi, xps); + xprt_switch_put(xps); + return 0; +} + +static +int rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi) +{ + return _rpc_clnt_xprt_iter_init(clnt, xpi, xprt_iter_init_listall); +} + +static +int rpc_clnt_xprt_iter_offline_init(struct rpc_clnt *clnt, + struct rpc_xprt_iter *xpi) +{ + return _rpc_clnt_xprt_iter_init(clnt, xpi, xprt_iter_init_listoffline); +} + +/** + * rpc_clnt_iterate_for_each_xprt - Apply a function to all transports + * @clnt: pointer to client + * @fn: function to apply + * @data: void pointer to function data + * + * Iterates through the list of RPC transports currently attached to the + * client and applies the function fn(clnt, xprt, data). + * + * On error, the iteration stops, and the function returns the error value. + */ +int rpc_clnt_iterate_for_each_xprt(struct rpc_clnt *clnt, + int (*fn)(struct rpc_clnt *, struct rpc_xprt *, void *), + void *data) +{ + struct rpc_xprt_iter xpi; + int ret; + + ret = rpc_clnt_xprt_iter_init(clnt, &xpi); + if (ret) + return ret; + for (;;) { + struct rpc_xprt *xprt = xprt_iter_get_next(&xpi); + + if (!xprt) + break; + ret = fn(clnt, xprt, data); + xprt_put(xprt); + if (ret < 0) + break; + } + xprt_iter_destroy(&xpi); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_iterate_for_each_xprt); + /* * Kill all tasks for the given client. * XXX: kill their descendants as well? @@ -599,25 +877,68 @@ void rpc_killall_tasks(struct rpc_clnt *clnt) if (list_empty(&clnt->cl_tasks)) return; - dprintk("RPC: killing all tasks for client %p\n", clnt); + + /* + * Spin lock all_tasks to prevent changes... + */ + trace_rpc_clnt_killall(clnt); + spin_lock(&clnt->cl_lock); + list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) + rpc_signal_task(rovr); + spin_unlock(&clnt->cl_lock); +} +EXPORT_SYMBOL_GPL(rpc_killall_tasks); + +/** + * rpc_cancel_tasks - try to cancel a set of RPC tasks + * @clnt: Pointer to RPC client + * @error: RPC task error value to set + * @fnmatch: Pointer to selector function + * @data: User data + * + * Uses @fnmatch to define a set of RPC tasks that are to be cancelled. + * The argument @error must be a negative error value. + */ +unsigned long rpc_cancel_tasks(struct rpc_clnt *clnt, int error, + bool (*fnmatch)(const struct rpc_task *, + const void *), + const void *data) +{ + struct rpc_task *task; + unsigned long count = 0; + + if (list_empty(&clnt->cl_tasks)) + return 0; /* * Spin lock all_tasks to prevent changes... */ spin_lock(&clnt->cl_lock); - list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) { - if (!RPC_IS_ACTIVATED(rovr)) + list_for_each_entry(task, &clnt->cl_tasks, tk_task) { + if (!RPC_IS_ACTIVATED(task)) continue; - if (!(rovr->tk_flags & RPC_TASK_KILLED)) { - rovr->tk_flags |= RPC_TASK_KILLED; - rpc_exit(rovr, -EIO); - if (RPC_IS_QUEUED(rovr)) - rpc_wake_up_queued_task(rovr->tk_waitqueue, - rovr); - } + if (!fnmatch(task, data)) + continue; + rpc_task_try_cancel(task, error); + count++; } spin_unlock(&clnt->cl_lock); + return count; } -EXPORT_SYMBOL_GPL(rpc_killall_tasks); +EXPORT_SYMBOL_GPL(rpc_cancel_tasks); + +static int rpc_clnt_disconnect_xprt(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, void *dummy) +{ + if (xprt_connected(xprt)) + xprt_force_disconnect(xprt); + return 0; +} + +void rpc_clnt_disconnect(struct rpc_clnt *clnt) +{ + rpc_clnt_iterate_for_each_xprt(clnt, rpc_clnt_disconnect_xprt, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_disconnect); /* * Properly shut down an RPC client, terminating all outstanding @@ -627,16 +948,19 @@ void rpc_shutdown_client(struct rpc_clnt *clnt) { might_sleep(); - dprintk_rcu("RPC: shutting down %s client for %s\n", - clnt->cl_protname, - rcu_dereference(clnt->cl_xprt)->servername); + trace_rpc_clnt_shutdown(clnt); + clnt->cl_shutdown = 1; while (!list_empty(&clnt->cl_tasks)) { rpc_killall_tasks(clnt); wait_event_timeout(destroy_wait, list_empty(&clnt->cl_tasks), 1*HZ); } + /* wait for tasks still in workqueue or waitqueue */ + wait_event_timeout(destroy_wait, + atomic_read(&clnt->cl_task_count) == 0, 1 * HZ); + rpc_release_client(clnt); } EXPORT_SYMBOL_GPL(rpc_shutdown_client); @@ -644,45 +968,62 @@ EXPORT_SYMBOL_GPL(rpc_shutdown_client); /* * Free an RPC client */ -static void +static void rpc_free_client_work(struct work_struct *work) +{ + struct rpc_clnt *clnt = container_of(work, struct rpc_clnt, cl_work); + + trace_rpc_clnt_free(clnt); + + /* These might block on processes that might allocate memory, + * so they cannot be called in rpciod, so they are handled separately + * here. + */ + rpc_sysfs_client_destroy(clnt); + rpc_clnt_debugfs_unregister(clnt); + rpc_free_clid(clnt); + rpc_clnt_remove_pipedir(clnt); + xprt_put(rcu_dereference_raw(clnt->cl_xprt)); + + kfree(clnt); + rpciod_down(); +} +static struct rpc_clnt * rpc_free_client(struct rpc_clnt *clnt) { - dprintk_rcu("RPC: destroying %s client for %s\n", - clnt->cl_protname, - rcu_dereference(clnt->cl_xprt)->servername); + struct rpc_clnt *parent = NULL; + + trace_rpc_clnt_release(clnt); if (clnt->cl_parent != clnt) - rpc_release_client(clnt->cl_parent); - rpc_clnt_remove_pipedir(clnt); + parent = clnt->cl_parent; rpc_unregister_client(clnt); rpc_free_iostats(clnt->cl_metrics); - kfree(clnt->cl_principal); clnt->cl_metrics = NULL; - xprt_put(rcu_dereference_raw(clnt->cl_xprt)); - rpciod_down(); - kfree(clnt); + xprt_iter_destroy(&clnt->cl_xpi); + put_cred(clnt->cl_cred); + + INIT_WORK(&clnt->cl_work, rpc_free_client_work); + schedule_work(&clnt->cl_work); + return parent; } /* * Free an RPC client */ -static void +static struct rpc_clnt * rpc_free_auth(struct rpc_clnt *clnt) { - if (clnt->cl_auth == NULL) { - rpc_free_client(clnt); - return; - } - /* * Note: RPCSEC_GSS may need to send NULL RPC calls in order to * release remaining GSS contexts. This mechanism ensures * that it can do so safely. */ - atomic_inc(&clnt->cl_count); - rpcauth_release(clnt->cl_auth); - clnt->cl_auth = NULL; - if (atomic_dec_and_test(&clnt->cl_count)) - rpc_free_client(clnt); + if (clnt->cl_auth != NULL) { + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = NULL; + } + if (refcount_dec_and_test(&clnt->cl_count)) + return rpc_free_client(clnt); + return NULL; } /* @@ -691,12 +1032,13 @@ rpc_free_auth(struct rpc_clnt *clnt) void rpc_release_client(struct rpc_clnt *clnt) { - dprintk("RPC: rpc_release_client(%p)\n", clnt); - - if (list_empty(&clnt->cl_tasks)) - wake_up(&destroy_wait); - if (atomic_dec_and_test(&clnt->cl_count)) - rpc_free_auth(clnt); + do { + if (list_empty(&clnt->cl_tasks)) + wake_up(&destroy_wait); + if (refcount_dec_not_one(&clnt->cl_count)) + break; + clnt = rpc_free_auth(clnt); + } while (clnt != NULL); } EXPORT_SYMBOL_GPL(rpc_release_client); @@ -719,7 +1061,9 @@ struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old, .prognumber = program->number, .version = vers, .authflavor = old->cl_auth->au_flavor, - .client_name = old->cl_principal, + .cred = old->cl_cred, + .stats = old->cl_stats, + .timeout = old->cl_timeout, }; struct rpc_clnt *clnt; int err; @@ -737,53 +1081,116 @@ out: } EXPORT_SYMBOL_GPL(rpc_bind_new_program); +struct rpc_xprt * +rpc_task_get_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + if (!xprt) + return NULL; + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + atomic_long_inc(&xps->xps_queuelen); + rcu_read_unlock(); + atomic_long_inc(&xprt->queuelen); + + return xprt; +} + +static void +rpc_task_release_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + atomic_long_dec(&xprt->queuelen); + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + atomic_long_dec(&xps->xps_queuelen); + rcu_read_unlock(); + + xprt_put(xprt); +} + +void rpc_task_release_transport(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + if (xprt) { + task->tk_xprt = NULL; + if (task->tk_client) + rpc_task_release_xprt(task->tk_client, xprt); + else + xprt_put(xprt); + } +} +EXPORT_SYMBOL_GPL(rpc_task_release_transport); + void rpc_task_release_client(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; + rpc_task_release_transport(task); if (clnt != NULL) { /* Remove from client task list */ spin_lock(&clnt->cl_lock); list_del(&task->tk_task); spin_unlock(&clnt->cl_lock); task->tk_client = NULL; + atomic_dec(&clnt->cl_task_count); rpc_release_client(clnt); } } +static struct rpc_xprt * +rpc_task_get_first_xprt(struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + rcu_read_unlock(); + return rpc_task_get_xprt(clnt, xprt); +} + +static struct rpc_xprt * +rpc_task_get_next_xprt(struct rpc_clnt *clnt) +{ + return rpc_task_get_xprt(clnt, xprt_iter_get_next(&clnt->cl_xpi)); +} + static -void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt) +void rpc_task_set_transport(struct rpc_task *task, struct rpc_clnt *clnt) { - if (clnt != NULL) { - rpc_task_release_client(task); - task->tk_client = clnt; - atomic_inc(&clnt->cl_count); - if (clnt->cl_softrtry) - task->tk_flags |= RPC_TASK_SOFT; - if (sk_memalloc_socks()) { - struct rpc_xprt *xprt; - - rcu_read_lock(); - xprt = rcu_dereference(clnt->cl_xprt); - if (xprt->swapper) - task->tk_flags |= RPC_TASK_SWAPPER; - rcu_read_unlock(); - } - /* Add to the client's list of all tasks */ - spin_lock(&clnt->cl_lock); - list_add_tail(&task->tk_task, &clnt->cl_tasks); - spin_unlock(&clnt->cl_lock); + if (task->tk_xprt) { + if (!(test_bit(XPRT_OFFLINE, &task->tk_xprt->state) && + (task->tk_flags & RPC_TASK_MOVEABLE))) + return; + xprt_release(task); + xprt_put(task->tk_xprt); } + if (task->tk_flags & RPC_TASK_NO_ROUND_ROBIN) + task->tk_xprt = rpc_task_get_first_xprt(clnt); + else + task->tk_xprt = rpc_task_get_next_xprt(clnt); } -void rpc_task_reset_client(struct rpc_task *task, struct rpc_clnt *clnt) +static +void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt) { - rpc_task_release_client(task); - rpc_task_set_client(task, clnt); + rpc_task_set_transport(task, clnt); + task->tk_client = clnt; + refcount_inc(&clnt->cl_count); + if (clnt->cl_softrtry) + task->tk_flags |= RPC_TASK_SOFT; + if (clnt->cl_softerr) + task->tk_flags |= RPC_TASK_TIMEOUT; + if (clnt->cl_noretranstimeo) + task->tk_flags |= RPC_TASK_NO_RETRANS_TIMEOUT; + if (clnt->cl_netunreach_fatal) + task->tk_flags |= RPC_TASK_NETUNREACH_FATAL; + atomic_inc(&clnt->cl_task_count); } -EXPORT_SYMBOL_GPL(rpc_task_reset_client); - static void rpc_task_set_rpc_message(struct rpc_task *task, const struct rpc_message *msg) @@ -792,8 +1199,9 @@ rpc_task_set_rpc_message(struct rpc_task *task, const struct rpc_message *msg) task->tk_msg.rpc_proc = msg->rpc_proc; task->tk_msg.rpc_argp = msg->rpc_argp; task->tk_msg.rpc_resp = msg->rpc_resp; - if (msg->rpc_cred != NULL) - task->tk_msg.rpc_cred = get_rpccred(msg->rpc_cred); + task->tk_msg.rpc_cred = msg->rpc_cred; + if (!(task->tk_flags & RPC_TASK_CRED_NOREF)) + get_cred(task->tk_msg.rpc_cred); } } @@ -819,7 +1227,10 @@ struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data) task = rpc_new_task(task_setup_data); if (IS_ERR(task)) - goto out; + return task; + + if (!RPC_IS_ASYNC(task)) + task->tk_flags |= RPC_TASK_CRED_NOREF; rpc_task_set_client(task, task_setup_data->rpc_client); rpc_task_set_rpc_message(task, task_setup_data->rpc_message); @@ -829,7 +1240,6 @@ struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data) atomic_inc(&task->tk_count); rpc_execute(task); -out: return task; } EXPORT_SYMBOL_GPL(rpc_run_task); @@ -897,19 +1307,22 @@ rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags, EXPORT_SYMBOL_GPL(rpc_call_async); #if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_encode(struct rpc_task *task); + /** * rpc_run_bc_task - Allocate a new RPC task for backchannel use, then run * rpc_execute against it * @req: RPC request - * @tk_ops: RPC call ops + * @timeout: timeout values to use for this task */ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req, - const struct rpc_call_ops *tk_ops) + struct rpc_timeout *timeout) { struct rpc_task *task; - struct xdr_buf *xbufp = &req->rq_snd_buf; struct rpc_task_setup task_setup_data = { - .callback_ops = tk_ops, + .callback_ops = &rpc_default_ops, + .flags = RPC_TASK_SOFTCONN | + RPC_TASK_NO_RETRANS_TIMEOUT, }; dprintk("RPC: rpc_run_bc_task req= %p\n", req); @@ -919,28 +1332,41 @@ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req, task = rpc_new_task(&task_setup_data); if (IS_ERR(task)) { xprt_free_bc_request(req); - goto out; + return task; } - task->tk_rqstp = req; - /* - * Set up the xdr_buf length. - * This also indicates that the buffer is XDR encoded already. - */ - xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + - xbufp->tail[0].iov_len; + xprt_init_bc_request(req, task, timeout); - task->tk_action = call_bc_transmit; + task->tk_action = call_bc_encode; atomic_inc(&task->tk_count); WARN_ON_ONCE(atomic_read(&task->tk_count) != 2); rpc_execute(task); -out: dprintk("RPC: rpc_run_bc_task: task= %p\n", task); return task; } #endif /* CONFIG_SUNRPC_BACKCHANNEL */ +/** + * rpc_prepare_reply_pages - Prepare to receive a reply data payload into pages + * @req: RPC request to prepare + * @pages: vector of struct page pointers + * @base: offset in first page where receive should start, in bytes + * @len: expected size of the upper layer data payload, in bytes + * @hdrsize: expected size of upper layer reply header, in XDR words + * + */ +void rpc_prepare_reply_pages(struct rpc_rqst *req, struct page **pages, + unsigned int base, unsigned int len, + unsigned int hdrsize) +{ + hdrsize += RPC_REPHDRSIZE + req->rq_cred->cr_auth->au_ralign; + + xdr_inline_pages(&req->rq_rcv_buf, hdrsize << 2, pages, base, len); + trace_rpc_xdr_reply_pages(req->rq_task, &req->rq_rcv_buf); +} +EXPORT_SYMBOL_GPL(rpc_prepare_reply_pages); + void rpc_call_start(struct rpc_task *task) { @@ -1016,7 +1442,7 @@ static const struct sockaddr_in6 rpc_in6addr_loopback = { * negative errno is returned. */ static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen, - struct sockaddr *buf, int buflen) + struct sockaddr *buf) { struct socket *sock; int err; @@ -1031,30 +1457,30 @@ static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen, switch (sap->sa_family) { case AF_INET: err = kernel_bind(sock, - (struct sockaddr *)&rpc_inaddr_loopback, + (struct sockaddr_unsized *)&rpc_inaddr_loopback, sizeof(rpc_inaddr_loopback)); break; case AF_INET6: err = kernel_bind(sock, - (struct sockaddr *)&rpc_in6addr_loopback, + (struct sockaddr_unsized *)&rpc_in6addr_loopback, sizeof(rpc_in6addr_loopback)); break; default: err = -EAFNOSUPPORT; - goto out; + goto out_release; } if (err < 0) { dprintk("RPC: can't bind UDP socket (%d)\n", err); goto out_release; } - err = kernel_connect(sock, sap, salen, 0); + err = kernel_connect(sock, (struct sockaddr_unsized *)sap, salen, 0); if (err < 0) { dprintk("RPC: can't connect UDP socket (%d)\n", err); goto out_release; } - err = kernel_getsockname(sock, buf, &buflen); + err = kernel_getsockname(sock, buf); if (err < 0) { dprintk("RPC: getsockname failed (%d)\n", err); goto out_release; @@ -1095,6 +1521,7 @@ static int rpc_anyaddr(int family, struct sockaddr *buf, size_t buflen) return -EINVAL; memcpy(buf, &rpc_in6addr_loopback, sizeof(rpc_in6addr_loopback)); + break; default: dprintk("RPC: %s: address family not supported\n", __func__); @@ -1137,7 +1564,7 @@ int rpc_localaddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t buflen) rcu_read_unlock(); rpc_set_port(sap, 0); - err = rpc_sockname(net, sap, salen, buf, buflen); + err = rpc_sockname(net, sap, salen, buf); put_net(net); if (err != 0) /* Couldn't discover local address, return ANYADDR */ @@ -1160,22 +1587,6 @@ rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize EXPORT_SYMBOL_GPL(rpc_setbufsize); /** - * rpc_protocol - Get transport protocol number for an RPC client - * @clnt: RPC client to query - * - */ -int rpc_protocol(struct rpc_clnt *clnt) -{ - int protocol; - - rcu_read_lock(); - protocol = rcu_dereference(clnt->cl_xprt)->prot; - rcu_read_unlock(); - return protocol; -} -EXPORT_SYMBOL_GPL(rpc_protocol); - -/** * rpc_net_ns - Get the network namespace for this RPC client * @clnt: RPC client to query * @@ -1212,19 +1623,34 @@ size_t rpc_max_payload(struct rpc_clnt *clnt) EXPORT_SYMBOL_GPL(rpc_max_payload); /** - * rpc_get_timeout - Get timeout for transport in units of HZ + * rpc_max_bc_payload - Get maximum backchannel payload size, in bytes * @clnt: RPC client to query */ -unsigned long rpc_get_timeout(struct rpc_clnt *clnt) +size_t rpc_max_bc_payload(struct rpc_clnt *clnt) { - unsigned long ret; + struct rpc_xprt *xprt; + size_t ret; rcu_read_lock(); - ret = rcu_dereference(clnt->cl_xprt)->timeout->to_initval; + xprt = rcu_dereference(clnt->cl_xprt); + ret = xprt->ops->bc_maxpayload(xprt); rcu_read_unlock(); return ret; } -EXPORT_SYMBOL_GPL(rpc_get_timeout); +EXPORT_SYMBOL_GPL(rpc_max_bc_payload); + +unsigned int rpc_num_bc_slots(struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + unsigned int ret; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + ret = xprt->ops->bc_num_slots(xprt); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_num_bc_slots); /** * rpc_force_rebind - force transport to check that remote port is unchanged @@ -1241,21 +1667,14 @@ void rpc_force_rebind(struct rpc_clnt *clnt) } EXPORT_SYMBOL_GPL(rpc_force_rebind); -/* - * Restart an (async) RPC call from the call_prepare state. - * Usually called from within the exit handler. - */ -int -rpc_restart_call_prepare(struct rpc_task *task) +static int +__rpc_restart_call(struct rpc_task *task, void (*action)(struct rpc_task *)) { - if (RPC_ASSASSINATED(task)) - return 0; - task->tk_action = call_start; - if (task->tk_ops->rpc_call_prepare != NULL) - task->tk_action = rpc_prepare_task; + task->tk_status = 0; + task->tk_rpc_status = 0; + task->tk_action = action; return 1; } -EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); /* * Restart an (async) RPC call. Usually called from within the @@ -1264,15 +1683,25 @@ EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); int rpc_restart_call(struct rpc_task *task) { - if (RPC_ASSASSINATED(task)) - return 0; - task->tk_action = call_start; - return 1; + return __rpc_restart_call(task, call_start); } EXPORT_SYMBOL_GPL(rpc_restart_call); -#ifdef RPC_DEBUG -static const char *rpc_proc_name(const struct rpc_task *task) +/* + * Restart an (async) RPC call from the call_prepare state. + * Usually called from within the exit handler. + */ +int +rpc_restart_call_prepare(struct rpc_task *task) +{ + if (task->tk_ops->rpc_call_prepare != NULL) + return __rpc_restart_call(task, rpc_prepare_task); + return rpc_restart_call(task); +} +EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); + +const char +*rpc_proc_name(const struct rpc_task *task) { const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; @@ -1284,7 +1713,20 @@ static const char *rpc_proc_name(const struct rpc_task *task) } else return "no proc"; } -#endif + +static void +__rpc_call_rpcerror(struct rpc_task *task, int tk_status, int rpc_status) +{ + trace_rpc_call_rpcerror(task, tk_status, rpc_status); + rpc_task_set_rpc_status(task, rpc_status); + rpc_exit(task, tk_status); +} + +static void +rpc_call_rpcerror(struct rpc_task *task, int status) +{ + __rpc_call_rpcerror(task, status, status); +} /* * 0. Initial state @@ -1296,16 +1738,21 @@ static void call_start(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; + int idx = task->tk_msg.rpc_proc->p_statidx; + + trace_rpc_request(task); - dprintk("RPC: %5u call_start %s%d proc %s (%s)\n", task->tk_pid, - clnt->cl_protname, clnt->cl_vers, - rpc_proc_name(task), - (RPC_IS_ASYNC(task) ? "async" : "sync")); + if (task->tk_client->cl_shutdown) { + rpc_call_rpcerror(task, -EIO); + return; + } - /* Increment call count */ - task->tk_msg.rpc_proc->p_count++; + /* Increment call count (version might not be valid for ping) */ + if (clnt->cl_program->version[clnt->cl_vers]) + clnt->cl_program->version[clnt->cl_vers]->counts[idx]++; clnt->cl_stats->rpccnt++; task->tk_action = call_reserve; + rpc_task_set_transport(task, clnt); } /* @@ -1314,8 +1761,6 @@ call_start(struct rpc_task *task) static void call_reserve(struct rpc_task *task) { - dprint_status(task); - task->tk_status = 0; task->tk_action = call_reserveresult; xprt_reserve(task); @@ -1331,8 +1776,6 @@ call_reserveresult(struct rpc_task *task) { int status = task->tk_status; - dprint_status(task); - /* * After a call to xprt_reserve(), we must have either * a request slot or else an error status. @@ -1341,39 +1784,28 @@ call_reserveresult(struct rpc_task *task) if (status >= 0) { if (task->tk_rqstp) { task->tk_action = call_refresh; + + /* Add to the client's list of all tasks */ + spin_lock(&task->tk_client->cl_lock); + if (list_empty(&task->tk_task)) + list_add_tail(&task->tk_task, &task->tk_client->cl_tasks); + spin_unlock(&task->tk_client->cl_lock); return; } - - printk(KERN_ERR "%s: status=%d, but no request slot, exiting\n", - __func__, status); - rpc_exit(task, -EIO); + rpc_call_rpcerror(task, -EIO); return; } - /* - * Even though there was an error, we may have acquired - * a request slot somehow. Make sure not to leak it. - */ - if (task->tk_rqstp) { - printk(KERN_ERR "%s: status=%d, request allocated anyway\n", - __func__, status); - xprt_release(task); - } - switch (status) { case -ENOMEM: rpc_delay(task, HZ >> 2); + fallthrough; case -EAGAIN: /* woken up; retry */ task->tk_action = call_retry_reserve; return; - case -EIO: /* probably a shutdown */ - break; default: - printk(KERN_ERR "%s: unrecognized error %d, exiting\n", - __func__, status); - break; + rpc_call_rpcerror(task, status); } - rpc_exit(task, status); } /* @@ -1382,8 +1814,6 @@ call_reserveresult(struct rpc_task *task) static void call_retry_reserve(struct rpc_task *task) { - dprint_status(task); - task->tk_status = 0; task->tk_action = call_reserveresult; xprt_retry_reserve(task); @@ -1395,8 +1825,6 @@ call_retry_reserve(struct rpc_task *task) static void call_refresh(struct rpc_task *task) { - dprint_status(task); - task->tk_action = call_refreshresult; task->tk_status = 0; task->tk_client->cl_stats->rpcauthrefresh++; @@ -1411,30 +1839,36 @@ call_refreshresult(struct rpc_task *task) { int status = task->tk_status; - dprint_status(task); - task->tk_status = 0; task->tk_action = call_refresh; switch (status) { case 0: - if (rpcauth_uptodatecred(task)) + if (rpcauth_uptodatecred(task)) { task->tk_action = call_allocate; - return; + return; + } + /* Use rate-limiting and a max number of retries if refresh + * had status 0 but failed to update the cred. + */ + fallthrough; case -ETIMEDOUT: rpc_delay(task, 3*HZ); - case -EKEYEXPIRED: + fallthrough; case -EAGAIN: status = -EACCES; if (!task->tk_cred_retry) break; task->tk_cred_retry--; - dprintk("RPC: %5u %s: retry refresh creds\n", - task->tk_pid, __func__); + trace_rpc_retry_refresh_status(task); + return; + case -EKEYEXPIRED: + break; + case -ENOMEM: + rpc_delay(task, HZ >> 4); return; } - dprintk("RPC: %5u %s: refresh creds failed with error %d\n", - task->tk_pid, __func__, status); - rpc_exit(task, status); + trace_rpc_refresh_status(task); + rpc_call_rpcerror(task, status); } /* @@ -1444,41 +1878,42 @@ call_refreshresult(struct rpc_task *task) static void call_allocate(struct rpc_task *task) { - unsigned int slack = task->tk_rqstp->rq_cred->cr_auth->au_cslack; + const struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth; struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - struct rpc_procinfo *proc = task->tk_msg.rpc_proc; - - dprint_status(task); + const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + int status; task->tk_status = 0; - task->tk_action = call_bind; + task->tk_action = call_encode; if (req->rq_buffer) return; - if (proc->p_proc != 0) { - BUG_ON(proc->p_arglen == 0); - if (proc->p_decode != NULL) - BUG_ON(proc->p_replen == 0); - } - /* * Calculate the size (in quads) of the RPC call * and reply headers, and convert both values * to byte sizes. */ - req->rq_callsize = RPC_CALLHDRSIZE + (slack << 1) + proc->p_arglen; + req->rq_callsize = RPC_CALLHDRSIZE + (auth->au_cslack << 1) + + proc->p_arglen; req->rq_callsize <<= 2; - req->rq_rcvsize = RPC_REPHDRSIZE + slack + proc->p_replen; + /* + * Note: the reply buffer must at minimum allocate enough space + * for the 'struct accepted_reply' from RFC5531. + */ + req->rq_rcvsize = RPC_REPHDRSIZE + auth->au_rslack + \ + max_t(size_t, proc->p_replen, 2); req->rq_rcvsize <<= 2; - req->rq_buffer = xprt->ops->buf_alloc(task, - req->rq_callsize + req->rq_rcvsize); - if (req->rq_buffer != NULL) + status = xprt->ops->buf_alloc(task); + trace_rpc_buf_alloc(task, status); + if (status == 0) return; - - dprintk("RPC: %5u rpc_buffer allocation failed\n", task->tk_pid); + if (status != -ENOMEM) { + rpc_call_rpcerror(task, status); + return; + } if (RPC_IS_ASYNC(task) || !fatal_signal_pending(current)) { task->tk_action = call_allocate; @@ -1486,66 +1921,105 @@ call_allocate(struct rpc_task *task) return; } - rpc_exit(task, -ERESTARTSYS); + rpc_call_rpcerror(task, -ERESTARTSYS); } -static inline int +static int rpc_task_need_encode(struct rpc_task *task) { - return task->tk_rqstp->rq_snd_buf.len == 0; + return test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) == 0 && + (!(task->tk_flags & RPC_TASK_SENT) || + !(task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) || + xprt_request_need_retransmit(task)); } -static inline void -rpc_task_force_reencode(struct rpc_task *task) +static void +rpc_xdr_encode(struct rpc_task *task) { - task->tk_rqstp->rq_snd_buf.len = 0; - task->tk_rqstp->rq_bytes_sent = 0; -} + struct rpc_rqst *req = task->tk_rqstp; + struct xdr_stream xdr; + + xdr_buf_init(&req->rq_snd_buf, + req->rq_buffer, + req->rq_callsize); + xdr_buf_init(&req->rq_rcv_buf, + req->rq_rbuffer, + req->rq_rcvsize); + + req->rq_reply_bytes_recvd = 0; + req->rq_snd_buf.head[0].iov_len = 0; + xdr_init_encode(&xdr, &req->rq_snd_buf, + req->rq_snd_buf.head[0].iov_base, req); + if (rpc_encode_header(task, &xdr)) + return; -static inline void -rpc_xdr_buf_init(struct xdr_buf *buf, void *start, size_t len) -{ - buf->head[0].iov_base = start; - buf->head[0].iov_len = len; - buf->tail[0].iov_len = 0; - buf->page_len = 0; - buf->flags = 0; - buf->len = 0; - buf->buflen = len; + task->tk_status = rpcauth_wrap_req(task, &xdr); } /* * 3. Encode arguments of an RPC call */ static void -rpc_xdr_encode(struct rpc_task *task) +call_encode(struct rpc_task *task) { - struct rpc_rqst *req = task->tk_rqstp; - kxdreproc_t encode; - __be32 *p; - - dprint_status(task); - - rpc_xdr_buf_init(&req->rq_snd_buf, - req->rq_buffer, - req->rq_callsize); - rpc_xdr_buf_init(&req->rq_rcv_buf, - (char *)req->rq_buffer + req->rq_callsize, - req->rq_rcvsize); - - p = rpc_encode_header(task); - if (p == NULL) { - printk(KERN_INFO "RPC: couldn't encode RPC header, exit EIO\n"); - rpc_exit(task, -EIO); + if (!rpc_task_need_encode(task)) + goto out; + + /* Dequeue task from the receive queue while we're encoding */ + xprt_request_dequeue_xprt(task); + /* Encode here so that rpcsec_gss can use correct sequence number. */ + rpc_xdr_encode(task); + /* Add task to reply queue before transmission to avoid races */ + if (task->tk_status == 0 && rpc_reply_expected(task)) + task->tk_status = xprt_request_enqueue_receive(task); + /* Did the encode result in an error condition? */ + if (task->tk_status != 0) { + /* Was the error nonfatal? */ + switch (task->tk_status) { + case -EAGAIN: + case -ENOMEM: + rpc_delay(task, HZ >> 4); + break; + case -EKEYEXPIRED: + if (!task->tk_cred_retry) { + rpc_call_rpcerror(task, task->tk_status); + } else { + task->tk_action = call_refresh; + task->tk_cred_retry--; + trace_rpc_retry_refresh_status(task); + } + break; + default: + rpc_call_rpcerror(task, task->tk_status); + } return; } - encode = task->tk_msg.rpc_proc->p_encode; - if (encode == NULL) - return; + xprt_request_enqueue_transmit(task); +out: + task->tk_action = call_transmit; + /* Check that the connection is OK */ + if (!xprt_bound(task->tk_xprt)) + task->tk_action = call_bind; + else if (!xprt_connected(task->tk_xprt)) + task->tk_action = call_connect; +} + +/* + * Helpers to check if the task was already transmitted, and + * to take action when that is the case. + */ +static bool +rpc_task_transmitted(struct rpc_task *task) +{ + return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); +} - task->tk_status = rpcauth_wrap_req(task, encode, req, p, - task->tk_msg.rpc_argp); +static void +rpc_task_handle_transmitted(struct rpc_task *task) +{ + xprt_end_transmit(task); + task->tk_action = call_transmit_status; } /* @@ -1556,14 +2030,21 @@ call_bind(struct rpc_task *task) { struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; - dprint_status(task); + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } - task->tk_action = call_connect; - if (!xprt_bound(xprt)) { - task->tk_action = call_bind_status; - task->tk_timeout = xprt->bind_timeout; - xprt->ops->rpcbind(task); + if (xprt_bound(xprt)) { + task->tk_action = call_connect; + return; } + + task->tk_action = call_bind_status; + if (!xprt_prepare_transmit(task)) + return; + + xprt->ops->rpcbind(task); } /* @@ -1572,58 +2053,62 @@ call_bind(struct rpc_task *task) static void call_bind_status(struct rpc_task *task) { + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; int status = -EIO; - if (task->tk_status >= 0) { - dprint_status(task); - task->tk_status = 0; - task->tk_action = call_connect; + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); return; } - trace_rpc_bind_status(task); + if (task->tk_status >= 0) + goto out_next; + if (xprt_bound(xprt)) { + task->tk_status = 0; + goto out_next; + } + switch (task->tk_status) { case -ENOMEM: - dprintk("RPC: %5u rpcbind out of memory\n", task->tk_pid); rpc_delay(task, HZ >> 2); goto retry_timeout; case -EACCES: - dprintk("RPC: %5u remote rpcbind: RPC program/version " - "unavailable\n", task->tk_pid); + trace_rpcb_prog_unavail_err(task); /* fail immediately if this is an RPC ping */ if (task->tk_msg.rpc_proc->p_proc == 0) { status = -EOPNOTSUPP; break; } - if (task->tk_rebind_retry == 0) - break; - task->tk_rebind_retry--; rpc_delay(task, 3*HZ); goto retry_timeout; + case -ENOBUFS: + rpc_delay(task, HZ >> 2); + goto retry_timeout; + case -EAGAIN: + goto retry_timeout; case -ETIMEDOUT: - dprintk("RPC: %5u rpcbind request timed out\n", - task->tk_pid); + trace_rpcb_timeout_err(task); goto retry_timeout; case -EPFNOSUPPORT: /* server doesn't support any rpcbind version we know of */ - dprintk("RPC: %5u unrecognized remote rpcbind service\n", - task->tk_pid); + trace_rpcb_bind_version_err(task); break; case -EPROTONOSUPPORT: - dprintk("RPC: %5u remote rpcbind version unavailable, retrying\n", - task->tk_pid); - task->tk_status = 0; - task->tk_action = call_bind; - return; + trace_rpcb_bind_version_err(task); + goto retry_timeout; + case -ENETDOWN: + case -ENETUNREACH: + if (task->tk_flags & RPC_TASK_NETUNREACH_FATAL) + break; + fallthrough; case -ECONNREFUSED: /* connection problems */ case -ECONNRESET: + case -ECONNABORTED: case -ENOTCONN: case -EHOSTDOWN: case -EHOSTUNREACH: - case -ENETUNREACH: case -EPIPE: - dprintk("RPC: %5u remote rpcbind unreachable: %d\n", - task->tk_pid, task->tk_status); + trace_rpcb_unreachable_err(task); if (!RPC_IS_SOFTCONN(task)) { rpc_delay(task, 5*HZ); goto retry_timeout; @@ -1631,15 +2116,18 @@ call_bind_status(struct rpc_task *task) status = task->tk_status; break; default: - dprintk("RPC: %5u unrecognized rpcbind error (%d)\n", - task->tk_pid, -task->tk_status); + trace_rpcb_unrecognized_err(task); } - rpc_exit(task, status); + rpc_call_rpcerror(task, status); + return; +out_next: + task->tk_action = call_connect; return; - retry_timeout: - task->tk_action = call_timeout; + task->tk_status = 0; + task->tk_action = call_bind; + rpc_check_timeout(task); } /* @@ -1650,17 +2138,26 @@ call_connect(struct rpc_task *task) { struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; - dprintk("RPC: %5u call_connect xprt %p %s connected\n", - task->tk_pid, xprt, - (xprt_connected(xprt) ? "is" : "is not")); + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } - task->tk_action = call_transmit; - if (!xprt_connected(xprt)) { - task->tk_action = call_connect_status; - if (task->tk_status < 0) - return; - xprt_connect(task); + if (xprt_connected(xprt)) { + task->tk_action = call_transmit; + return; } + + task->tk_action = call_connect_status; + if (task->tk_status < 0) + return; + if (task->tk_flags & RPC_TASK_NOCONNECT) { + rpc_call_rpcerror(task, -ENOTCONN); + return; + } + if (!xprt_prepare_transmit(task)) + return; + xprt_connect(task); } /* @@ -1669,31 +2166,96 @@ call_connect(struct rpc_task *task) static void call_connect_status(struct rpc_task *task) { + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; struct rpc_clnt *clnt = task->tk_client; int status = task->tk_status; - dprint_status(task); + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + trace_rpc_connect_status(task); - trace_rpc_connect_status(task, status); + if (task->tk_status == 0) { + clnt->cl_stats->netreconn++; + goto out_next; + } + if (xprt_connected(xprt)) { + task->tk_status = 0; + goto out_next; + } + + task->tk_status = 0; switch (status) { - /* if soft mounted, test if we've timed out */ - case -ETIMEDOUT: - task->tk_action = call_timeout; - return; + case -ENETDOWN: + case -ENETUNREACH: + if (task->tk_flags & RPC_TASK_NETUNREACH_FATAL) + break; + fallthrough; case -ECONNREFUSED: case -ECONNRESET: - case -ENETUNREACH: + /* A positive refusal suggests a rebind is needed. */ + if (clnt->cl_autobind) { + rpc_force_rebind(clnt); + if (RPC_IS_SOFTCONN(task)) + break; + goto out_retry; + } + fallthrough; + case -ECONNABORTED: + case -EHOSTUNREACH: + case -EPIPE: + case -EPROTO: + xprt_conditional_disconnect(task->tk_rqstp->rq_xprt, + task->tk_rqstp->rq_connect_cookie); if (RPC_IS_SOFTCONN(task)) break; /* retry with existing socket, after a delay */ - case 0: + rpc_delay(task, 3*HZ); + fallthrough; + case -EADDRINUSE: + case -ENOTCONN: case -EAGAIN: - task->tk_status = 0; - clnt->cl_stats->netreconn++; - task->tk_action = call_transmit; - return; + case -ETIMEDOUT: + if (!(task->tk_flags & RPC_TASK_NO_ROUND_ROBIN) && + (task->tk_flags & RPC_TASK_MOVEABLE) && + test_bit(XPRT_REMOVE, &xprt->state)) { + struct rpc_xprt *saved = task->tk_xprt; + struct rpc_xprt_switch *xps; + + xps = rpc_clnt_xprt_switch_get(clnt); + if (xps->xps_nxprts > 1) { + long value; + + xprt_release(task); + value = atomic_long_dec_return(&xprt->queuelen); + if (value == 0) + rpc_xprt_switch_remove_xprt(xps, saved, + true); + xprt_put(saved); + task->tk_xprt = NULL; + task->tk_action = call_start; + } + xprt_switch_put(xps); + if (!task->tk_xprt) + goto out; + } + goto out_retry; + case -ENOBUFS: + rpc_delay(task, HZ >> 2); + goto out_retry; } - rpc_exit(task, status); + rpc_call_rpcerror(task, status); + return; +out_next: + task->tk_action = call_transmit; + return; +out_retry: + /* Check for timeouts before looping back to call_bind */ + task->tk_action = call_bind; +out: + rpc_check_timeout(task); } /* @@ -1702,40 +2264,23 @@ call_connect_status(struct rpc_task *task) static void call_transmit(struct rpc_task *task) { - dprint_status(task); - - task->tk_action = call_status; - if (task->tk_status < 0) - return; - task->tk_status = xprt_prepare_transmit(task); - if (task->tk_status != 0) + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); return; + } + task->tk_action = call_transmit_status; - /* Encode here so that rpcsec_gss can use correct sequence number. */ - if (rpc_task_need_encode(task)) { - rpc_xdr_encode(task); - /* Did the encode result in an error condition? */ - if (task->tk_status != 0) { - /* Was the error nonfatal? */ - if (task->tk_status == -EAGAIN) - rpc_delay(task, HZ >> 4); - else - rpc_exit(task, task->tk_status); + if (!xprt_prepare_transmit(task)) + return; + task->tk_status = 0; + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { + if (!xprt_connected(task->tk_xprt)) { + task->tk_status = -ENOTCONN; return; } + xprt_transmit(task); } - xprt_transmit(task); - if (task->tk_status < 0) - return; - /* - * On success, ensure that we call xprt_end_transmit() before sleeping - * in order to allow access to the socket to other RPC requests. - */ - call_transmit_status(task); - if (rpc_reply_expected(task)) - return; - task->tk_action = rpc_exit_task; - rpc_wake_up_queued_task(&task->tk_rqstp->rq_xprt->pending, task); + xprt_end_transmit(task); } /* @@ -1750,19 +2295,18 @@ call_transmit_status(struct rpc_task *task) * Common case: success. Force the compiler to put this * test first. */ - if (task->tk_status == 0) { - xprt_end_transmit(task); - rpc_task_force_reencode(task); + if (rpc_task_transmitted(task)) { + task->tk_status = 0; + xprt_request_wait_receive(task); return; } switch (task->tk_status) { - case -EAGAIN: - break; default: - dprint_status(task); - xprt_end_transmit(task); - rpc_task_force_reencode(task); + break; + case -EBADMSG: + task->tk_status = 0; + task->tk_action = call_encode; break; /* * Special cases: if we've been waiting on the @@ -1770,23 +2314,53 @@ call_transmit_status(struct rpc_task *task) * socket just returned a connection error, * then hold onto the transport lock. */ - case -ECONNREFUSED: + case -ENOMEM: + case -ENOBUFS: + rpc_delay(task, HZ>>2); + fallthrough; + case -EBADSLT: + case -EAGAIN: + task->tk_action = call_transmit; + task->tk_status = 0; + break; case -EHOSTDOWN: + case -ENETDOWN: case -EHOSTUNREACH: case -ENETUNREACH: + case -EPERM: + break; + case -ECONNREFUSED: if (RPC_IS_SOFTCONN(task)) { - xprt_end_transmit(task); - rpc_exit(task, task->tk_status); - break; + if (!task->tk_msg.rpc_proc->p_proc) + trace_xprt_ping(task->tk_xprt, + task->tk_status); + rpc_call_rpcerror(task, task->tk_status); + return; } + fallthrough; case -ECONNRESET: + case -ECONNABORTED: + case -EADDRINUSE: case -ENOTCONN: case -EPIPE: - rpc_task_force_reencode(task); + task->tk_action = call_bind; + task->tk_status = 0; + break; } + rpc_check_timeout(task); } #if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_transmit(struct rpc_task *task); +static void call_bc_transmit_status(struct rpc_task *task); + +static void +call_bc_encode(struct rpc_task *task) +{ + xprt_request_enqueue_transmit(task); + task->tk_action = call_bc_transmit; +} + /* * 5b. Send the backchannel RPC reply. On error, drop the reply. In * addition, disconnect on connectivity errors. @@ -1794,36 +2368,46 @@ call_transmit_status(struct rpc_task *task) static void call_bc_transmit(struct rpc_task *task) { - struct rpc_rqst *req = task->tk_rqstp; - - task->tk_status = xprt_prepare_transmit(task); - if (task->tk_status == -EAGAIN) { - /* - * Could not reserve the transport. Try again after the - * transport is released. - */ + task->tk_action = call_bc_transmit_status; + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { + if (!xprt_prepare_transmit(task)) + return; task->tk_status = 0; - task->tk_action = call_bc_transmit; - return; + xprt_transmit(task); } + xprt_end_transmit(task); +} - task->tk_action = rpc_exit_task; - if (task->tk_status < 0) { - printk(KERN_NOTICE "RPC: Could not send backchannel reply " - "error: %d\n", task->tk_status); - return; - } +static void +call_bc_transmit_status(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (rpc_task_transmitted(task)) + task->tk_status = 0; - xprt_transmit(task); - xprt_end_transmit(task); - dprint_status(task); switch (task->tk_status) { case 0: /* Success */ - break; + case -ENETDOWN: case -EHOSTDOWN: case -EHOSTUNREACH: case -ENETUNREACH: + case -ECONNRESET: + case -ECONNREFUSED: + case -EADDRINUSE: + case -ENOTCONN: + case -EPIPE: + break; + case -ENOMEM: + case -ENOBUFS: + rpc_delay(task, HZ>>2); + fallthrough; + case -EBADSLT: + case -EAGAIN: + task->tk_status = 0; + task->tk_action = call_bc_transmit; + return; case -ETIMEDOUT: /* * Problem reaching the server. Disconnect and let the @@ -1842,12 +2426,11 @@ call_bc_transmit(struct rpc_task *task) * We were unable to reply and will have to drop the * request. The server should reconnect and retransmit. */ - WARN_ON_ONCE(task->tk_status == -EAGAIN); printk(KERN_NOTICE "RPC: Could not send backchannel reply " "error: %d\n", task->tk_status); break; } - rpc_wake_up_queued_task(&req->rq_xprt->pending, task); + task->tk_action = rpc_exit_task; } #endif /* CONFIG_SUNRPC_BACKCHANNEL */ @@ -1858,13 +2441,10 @@ static void call_status(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; - struct rpc_rqst *req = task->tk_rqstp; int status; - if (req->rq_reply_bytes_recvd > 0 && !req->rq_bytes_sent) - task->tk_status = req->rq_reply_bytes_recvd; - - dprint_status(task); + if (!task->tk_msg.rpc_proc->p_proc) + trace_xprt_ping(task->tk_xprt, task->tk_status); status = task->tk_status; if (status >= 0) { @@ -1875,88 +2455,115 @@ call_status(struct rpc_task *task) trace_rpc_call_status(task); task->tk_status = 0; switch(status) { + case -ENETDOWN: + case -ENETUNREACH: + if (task->tk_flags & RPC_TASK_NETUNREACH_FATAL) + goto out_exit; + fallthrough; case -EHOSTDOWN: case -EHOSTUNREACH: - case -ENETUNREACH: + case -EPERM: + if (RPC_IS_SOFTCONN(task)) + goto out_exit; /* * Delay any retries for 3 seconds, then handle as if it * were a timeout. */ rpc_delay(task, 3*HZ); + fallthrough; case -ETIMEDOUT: - task->tk_action = call_timeout; - if (task->tk_client->cl_discrtry) - xprt_conditional_disconnect(req->rq_xprt, - req->rq_connect_cookie); break; - case -ECONNRESET: case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: + case -ENOTCONN: rpc_force_rebind(clnt); + break; + case -EADDRINUSE: rpc_delay(task, 3*HZ); + fallthrough; case -EPIPE: - case -ENOTCONN: - task->tk_action = call_bind; - break; case -EAGAIN: - task->tk_action = call_transmit; + break; + case -ENFILE: + case -ENOBUFS: + case -ENOMEM: + rpc_delay(task, HZ>>2); break; case -EIO: /* shutdown or soft timeout */ - rpc_exit(task, status); - break; + goto out_exit; default: if (clnt->cl_chatty) printk("%s: RPC call returned error %d\n", - clnt->cl_protname, -status); - rpc_exit(task, status); + clnt->cl_program->name, -status); + goto out_exit; } + task->tk_action = call_encode; + rpc_check_timeout(task); + return; +out_exit: + rpc_call_rpcerror(task, status); +} + +static bool +rpc_check_connected(const struct rpc_rqst *req) +{ + /* No allocated request or transport? return true */ + if (!req || !req->rq_xprt) + return true; + return xprt_connected(req->rq_xprt); } -/* - * 6a. Handle RPC timeout - * We do not release the request slot, so we keep using the - * same XID for all retransmits. - */ static void -call_timeout(struct rpc_task *task) +rpc_check_timeout(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; - if (xprt_adjust_timeout(task->tk_rqstp) == 0) { - dprintk("RPC: %5u call_timeout (minor)\n", task->tk_pid); - goto retry; - } + if (RPC_SIGNALLED(task)) + return; + + if (xprt_adjust_timeout(task->tk_rqstp) == 0) + return; - dprintk("RPC: %5u call_timeout (major)\n", task->tk_pid); + trace_rpc_timeout_status(task); task->tk_timeouts++; - if (RPC_IS_SOFTCONN(task)) { - rpc_exit(task, -ETIMEDOUT); + if (RPC_IS_SOFTCONN(task) && !rpc_check_connected(task->tk_rqstp)) { + rpc_call_rpcerror(task, -ETIMEDOUT); return; } + if (RPC_IS_SOFT(task)) { + /* + * Once a "no retrans timeout" soft tasks (a.k.a NFSv4) has + * been sent, it should time out only if the transport + * connection gets terminally broken. + */ + if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) && + rpc_check_connected(task->tk_rqstp)) + return; + if (clnt->cl_chatty) { - rcu_read_lock(); - printk(KERN_NOTICE "%s: server %s not responding, timed out\n", - clnt->cl_protname, - rcu_dereference(clnt->cl_xprt)->servername); - rcu_read_unlock(); + pr_notice_ratelimited( + "%s: server %s not responding, timed out\n", + clnt->cl_program->name, + task->tk_xprt->servername); } if (task->tk_flags & RPC_TASK_TIMEOUT) - rpc_exit(task, -ETIMEDOUT); + rpc_call_rpcerror(task, -ETIMEDOUT); else - rpc_exit(task, -EIO); + __rpc_call_rpcerror(task, -EIO, -ETIMEDOUT); return; } if (!(task->tk_flags & RPC_CALL_MAJORSEEN)) { task->tk_flags |= RPC_CALL_MAJORSEEN; if (clnt->cl_chatty) { - rcu_read_lock(); - printk(KERN_NOTICE "%s: server %s not responding, still trying\n", - clnt->cl_protname, - rcu_dereference(clnt->cl_xprt)->servername); - rcu_read_unlock(); + pr_notice_ratelimited( + "%s: server %s not responding, still trying\n", + clnt->cl_program->name, + task->tk_xprt->servername); } } rpc_force_rebind(clnt); @@ -1965,11 +2572,6 @@ call_timeout(struct rpc_task *task) * event? RFC2203 requires the server to drop all such requests. */ rpcauth_invalcred(task); - -retry: - clnt->cl_stats->rpcretrans++; - task->tk_action = call_bind; - task->tk_status = 0; } /* @@ -1980,297 +2582,764 @@ call_decode(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; struct rpc_rqst *req = task->tk_rqstp; - kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; - __be32 *p; + struct xdr_stream xdr; + int err; - dprint_status(task); + if (!task->tk_msg.rpc_proc->p_decode) { + task->tk_action = rpc_exit_task; + return; + } if (task->tk_flags & RPC_CALL_MAJORSEEN) { if (clnt->cl_chatty) { - rcu_read_lock(); - printk(KERN_NOTICE "%s: server %s OK\n", - clnt->cl_protname, - rcu_dereference(clnt->cl_xprt)->servername); - rcu_read_unlock(); + pr_notice_ratelimited("%s: server %s OK\n", + clnt->cl_program->name, + task->tk_xprt->servername); } task->tk_flags &= ~RPC_CALL_MAJORSEEN; } /* - * Ensure that we see all writes made by xprt_complete_rqst() + * Did we ever call xprt_complete_rqst()? If not, we should assume + * the message is incomplete. + */ + err = -EAGAIN; + if (!req->rq_reply_bytes_recvd) + goto out; + + /* Ensure that we see all writes made by xprt_complete_rqst() * before it changed req->rq_reply_bytes_recvd. */ smp_rmb(); + req->rq_rcv_buf.len = req->rq_private_buf.len; + trace_rpc_xdr_recvfrom(task, &req->rq_rcv_buf); /* Check that the softirq receive buffer is valid */ WARN_ON(memcmp(&req->rq_rcv_buf, &req->rq_private_buf, sizeof(req->rq_rcv_buf)) != 0); - if (req->rq_rcv_buf.len < 12) { - if (!RPC_IS_SOFT(task)) { - task->tk_action = call_bind; - clnt->cl_stats->rpcretrans++; - goto out_retry; - } - dprintk("RPC: %s: too small RPC reply size (%d bytes)\n", - clnt->cl_protname, task->tk_status); - task->tk_action = call_timeout; - goto out_retry; - } - - p = rpc_verify_header(task); - if (IS_ERR(p)) { - if (p == ERR_PTR(-EAGAIN)) - goto out_retry; + xdr_init_decode(&xdr, &req->rq_rcv_buf, + req->rq_rcv_buf.head[0].iov_base, req); + err = rpc_decode_header(task, &xdr); +out: + switch (err) { + case 0: + task->tk_action = rpc_exit_task; + task->tk_status = rpcauth_unwrap_resp(task, &xdr); + xdr_finish_decode(&xdr); return; - } - - task->tk_action = rpc_exit_task; - - if (decode) { - task->tk_status = rpcauth_unwrap_resp(task, decode, req, p, - task->tk_msg.rpc_resp); - } - dprintk("RPC: %5u call_decode result %d\n", task->tk_pid, - task->tk_status); - return; -out_retry: - task->tk_status = 0; - /* Note: rpc_verify_header() may have freed the RPC slot */ - if (task->tk_rqstp == req) { - req->rq_reply_bytes_recvd = req->rq_rcv_buf.len = 0; + case -EAGAIN: + task->tk_status = 0; if (task->tk_client->cl_discrtry) xprt_conditional_disconnect(req->rq_xprt, - req->rq_connect_cookie); + req->rq_connect_cookie); + task->tk_action = call_encode; + rpc_check_timeout(task); + break; + case -EKEYREJECTED: + task->tk_action = call_reserve; + rpc_check_timeout(task); + rpcauth_invalcred(task); + /* Ensure we obtain a new XID if we retry! */ + xprt_release(task); } } -static __be32 * -rpc_encode_header(struct rpc_task *task) +static int +rpc_encode_header(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_clnt *clnt = task->tk_client; struct rpc_rqst *req = task->tk_rqstp; - __be32 *p = req->rq_svec[0].iov_base; - - /* FIXME: check buffer size? */ - - p = xprt_skip_transport_header(req->rq_xprt, p); - *p++ = req->rq_xid; /* XID */ - *p++ = htonl(RPC_CALL); /* CALL */ - *p++ = htonl(RPC_VERSION); /* RPC version */ - *p++ = htonl(clnt->cl_prog); /* program number */ - *p++ = htonl(clnt->cl_vers); /* program version */ - *p++ = htonl(task->tk_msg.rpc_proc->p_proc); /* procedure */ - p = rpcauth_marshcred(task, p); - req->rq_slen = xdr_adjust_iovec(&req->rq_svec[0], p); - return p; + __be32 *p; + int error; + + error = -EMSGSIZE; + p = xdr_reserve_space(xdr, RPC_CALLHDRSIZE << 2); + if (!p) + goto out_fail; + *p++ = req->rq_xid; + *p++ = rpc_call; + *p++ = cpu_to_be32(RPC_VERSION); + *p++ = cpu_to_be32(clnt->cl_prog); + *p++ = cpu_to_be32(clnt->cl_vers); + *p = cpu_to_be32(task->tk_msg.rpc_proc->p_proc); + + error = rpcauth_marshcred(task, xdr); + if (error < 0) + goto out_fail; + return 0; +out_fail: + trace_rpc_bad_callhdr(task); + rpc_call_rpcerror(task, error); + return error; } -static __be32 * -rpc_verify_header(struct rpc_task *task) +static noinline int +rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_clnt *clnt = task->tk_client; - struct kvec *iov = &task->tk_rqstp->rq_rcv_buf.head[0]; - int len = task->tk_rqstp->rq_rcv_buf.len >> 2; - __be32 *p = iov->iov_base; - u32 n; - int error = -EACCES; - - if ((task->tk_rqstp->rq_rcv_buf.len & 3) != 0) { - /* RFC-1014 says that the representation of XDR data must be a - * multiple of four bytes - * - if it isn't pointer subtraction in the NFS client may give - * undefined results - */ - dprintk("RPC: %5u %s: XDR representation not a multiple of" - " 4 bytes: 0x%x\n", task->tk_pid, __func__, - task->tk_rqstp->rq_rcv_buf.len); - goto out_eio; - } - if ((len -= 3) < 0) - goto out_overflow; - - p += 1; /* skip XID */ - if ((n = ntohl(*p++)) != RPC_REPLY) { - dprintk("RPC: %5u %s: not an RPC reply: %x\n", - task->tk_pid, __func__, n); - goto out_garbage; - } + int error; + __be32 *p; - if ((n = ntohl(*p++)) != RPC_MSG_ACCEPTED) { - if (--len < 0) - goto out_overflow; - switch ((n = ntohl(*p++))) { - case RPC_AUTH_ERROR: - break; - case RPC_MISMATCH: - dprintk("RPC: %5u %s: RPC call version mismatch!\n", - task->tk_pid, __func__); - error = -EPROTONOSUPPORT; - goto out_err; - default: - dprintk("RPC: %5u %s: RPC call rejected, " - "unknown error: %x\n", - task->tk_pid, __func__, n); - goto out_eio; - } - if (--len < 0) - goto out_overflow; - switch ((n = ntohl(*p++))) { - case RPC_AUTH_REJECTEDCRED: - case RPC_AUTH_REJECTEDVERF: - case RPCSEC_GSS_CREDPROBLEM: - case RPCSEC_GSS_CTXPROBLEM: + /* RFC-1014 says that the representation of XDR data must be a + * multiple of four bytes + * - if it isn't pointer subtraction in the NFS client may give + * undefined results + */ + if (task->tk_rqstp->rq_rcv_buf.len & 3) + goto out_unparsable; + + p = xdr_inline_decode(xdr, 3 * sizeof(*p)); + if (!p) + goto out_unparsable; + p++; /* skip XID */ + if (*p++ != rpc_reply) + goto out_unparsable; + if (*p++ != rpc_msg_accepted) + goto out_msg_denied; + + error = rpcauth_checkverf(task, xdr); + if (error) { + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + if (!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) { + rpcauth_invalcred(task); if (!task->tk_cred_retry) - break; + goto out_err; task->tk_cred_retry--; - dprintk("RPC: %5u %s: retry stale creds\n", - task->tk_pid, __func__); - rpcauth_invalcred(task); - /* Ensure we obtain a new XID! */ - xprt_release(task); - task->tk_action = call_reserve; - goto out_retry; - case RPC_AUTH_BADCRED: - case RPC_AUTH_BADVERF: - /* possibly garbled cred/verf? */ - if (!task->tk_garb_retry) - break; - task->tk_garb_retry--; - dprintk("RPC: %5u %s: retry garbled creds\n", - task->tk_pid, __func__); - task->tk_action = call_bind; - goto out_retry; - case RPC_AUTH_TOOWEAK: - rcu_read_lock(); - printk(KERN_NOTICE "RPC: server %s requires stronger " - "authentication.\n", - rcu_dereference(clnt->cl_xprt)->servername); - rcu_read_unlock(); - break; - default: - dprintk("RPC: %5u %s: unknown auth error: %x\n", - task->tk_pid, __func__, n); - error = -EIO; + trace_rpc__stale_creds(task); + return -EKEYREJECTED; } - dprintk("RPC: %5u %s: call rejected %d\n", - task->tk_pid, __func__, n); - goto out_err; + goto out_verifier; } - if (!(p = rpcauth_checkverf(task, p))) { - dprintk("RPC: %5u %s: auth check failed\n", - task->tk_pid, __func__); - goto out_garbage; /* bad verifier, retry */ - } - len = p - (__be32 *)iov->iov_base - 1; - if (len < 0) - goto out_overflow; - switch ((n = ntohl(*p++))) { - case RPC_SUCCESS: - return p; - case RPC_PROG_UNAVAIL: - dprintk_rcu("RPC: %5u %s: program %u is unsupported " - "by server %s\n", task->tk_pid, __func__, - (unsigned int)clnt->cl_prog, - rcu_dereference(clnt->cl_xprt)->servername); + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p) { + case rpc_success: + return 0; + case rpc_prog_unavail: + trace_rpc__prog_unavail(task); error = -EPFNOSUPPORT; goto out_err; - case RPC_PROG_MISMATCH: - dprintk_rcu("RPC: %5u %s: program %u, version %u unsupported " - "by server %s\n", task->tk_pid, __func__, - (unsigned int)clnt->cl_prog, - (unsigned int)clnt->cl_vers, - rcu_dereference(clnt->cl_xprt)->servername); + case rpc_prog_mismatch: + trace_rpc__prog_mismatch(task); error = -EPROTONOSUPPORT; goto out_err; - case RPC_PROC_UNAVAIL: - dprintk_rcu("RPC: %5u %s: proc %s unsupported by program %u, " - "version %u on server %s\n", - task->tk_pid, __func__, - rpc_proc_name(task), - clnt->cl_prog, clnt->cl_vers, - rcu_dereference(clnt->cl_xprt)->servername); + case rpc_proc_unavail: + trace_rpc__proc_unavail(task); error = -EOPNOTSUPP; goto out_err; - case RPC_GARBAGE_ARGS: - dprintk("RPC: %5u %s: server saw garbage\n", - task->tk_pid, __func__); - break; /* retry */ + case rpc_garbage_args: + case rpc_system_err: + trace_rpc__garbage_args(task); + error = -EIO; + break; default: - dprintk("RPC: %5u %s: server accept status: %x\n", - task->tk_pid, __func__, n); - /* Also retry */ + goto out_unparsable; } out_garbage: clnt->cl_stats->rpcgarbage++; if (task->tk_garb_retry) { task->tk_garb_retry--; - dprintk("RPC: %5u %s: retrying\n", - task->tk_pid, __func__); - task->tk_action = call_bind; -out_retry: - return ERR_PTR(-EAGAIN); + task->tk_action = call_encode; + return -EAGAIN; } -out_eio: - error = -EIO; out_err: - rpc_exit(task, error); - dprintk("RPC: %5u %s: call failed with error %d\n", task->tk_pid, - __func__, error); - return ERR_PTR(error); -out_overflow: - dprintk("RPC: %5u %s: server reply was truncated.\n", task->tk_pid, - __func__); + rpc_call_rpcerror(task, error); + return error; + +out_unparsable: + trace_rpc__unparsable(task); + error = -EIO; goto out_garbage; + +out_verifier: + trace_rpc_bad_verifier(task); + switch (error) { + case -EPROTONOSUPPORT: + goto out_err; + case -EACCES: + /* possible RPCSEC_GSS out-of-sequence event (RFC2203), + * reset recv state and keep waiting, don't retransmit + */ + task->tk_rqstp->rq_reply_bytes_recvd = 0; + task->tk_status = xprt_request_enqueue_receive(task); + task->tk_action = call_transmit_status; + return -EBADMSG; + default: + goto out_garbage; + } + +out_msg_denied: + error = -EACCES; + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p++) { + case rpc_auth_error: + break; + case rpc_mismatch: + trace_rpc__mismatch(task); + error = -EPROTONOSUPPORT; + goto out_err; + default: + goto out_unparsable; + } + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p++) { + case rpc_autherr_rejectedcred: + case rpc_autherr_rejectedverf: + case rpcsec_gsserr_credproblem: + case rpcsec_gsserr_ctxproblem: + rpcauth_invalcred(task); + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + trace_rpc__stale_creds(task); + return -EKEYREJECTED; + case rpc_autherr_badcred: + case rpc_autherr_badverf: + /* possibly garbled cred/verf? */ + if (!task->tk_garb_retry) + break; + task->tk_garb_retry--; + trace_rpc__bad_creds(task); + task->tk_action = call_encode; + return -EAGAIN; + case rpc_autherr_tooweak: + trace_rpc__auth_tooweak(task); + pr_warn("RPC: server %s requires stronger authentication.\n", + task->tk_xprt->servername); + break; + default: + goto out_unparsable; + } + goto out_err; } -static void rpcproc_encode_null(void *rqstp, struct xdr_stream *xdr, void *obj) +static void rpcproc_encode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + const void *obj) { } -static int rpcproc_decode_null(void *rqstp, struct xdr_stream *xdr, void *obj) +static int rpcproc_decode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + void *obj) { return 0; } -static struct rpc_procinfo rpcproc_null = { +static const struct rpc_procinfo rpcproc_null = { .p_encode = rpcproc_encode_null, .p_decode = rpcproc_decode_null, }; -static int rpc_ping(struct rpc_clnt *clnt) +static const struct rpc_procinfo rpcproc_null_noreply = { + .p_encode = rpcproc_encode_null, +}; + +static void +rpc_null_call_prepare(struct rpc_task *task, void *data) +{ + task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; + rpc_call_start(task); +} + +static const struct rpc_call_ops rpc_null_ops = { + .rpc_call_prepare = rpc_null_call_prepare, + .rpc_call_done = rpc_default_callback, +}; + +static +struct rpc_task *rpc_call_null_helper(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, struct rpc_cred *cred, int flags, + const struct rpc_call_ops *ops, void *data) { struct rpc_message msg = { .rpc_proc = &rpcproc_null, }; - int err; - msg.rpc_cred = authnull_ops.lookup_cred(NULL, NULL, 0); - err = rpc_call_sync(clnt, &msg, RPC_TASK_SOFT | RPC_TASK_SOFTCONN); - put_rpccred(msg.rpc_cred); - return err; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_xprt = xprt, + .rpc_message = &msg, + .rpc_op_cred = cred, + .callback_ops = ops ?: &rpc_null_ops, + .callback_data = data, + .flags = flags | RPC_TASK_SOFT | RPC_TASK_SOFTCONN | + RPC_TASK_NULLCREDS, + }; + + return rpc_run_task(&task_setup_data); } struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred, int flags) { + return rpc_call_null_helper(clnt, NULL, cred, flags, NULL, NULL); +} +EXPORT_SYMBOL_GPL(rpc_call_null); + +static int rpc_ping(struct rpc_clnt *clnt) +{ + struct rpc_task *task; + int status; + + if (clnt->cl_auth->au_ops->ping) + return clnt->cl_auth->au_ops->ping(clnt); + + task = rpc_call_null_helper(clnt, NULL, NULL, 0, NULL, NULL); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} + +static int rpc_ping_noreply(struct rpc_clnt *clnt) +{ struct rpc_message msg = { - .rpc_proc = &rpcproc_null, - .rpc_cred = cred, + .rpc_proc = &rpcproc_null_noreply, }; struct rpc_task_setup task_setup_data = { .rpc_client = clnt, .rpc_message = &msg, - .callback_ops = &rpc_default_ops, - .flags = flags, + .callback_ops = &rpc_null_ops, + .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN | RPC_TASK_NULLCREDS, }; - return rpc_run_task(&task_setup_data); + struct rpc_task *task; + int status; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; } -EXPORT_SYMBOL_GPL(rpc_call_null); -#ifdef RPC_DEBUG -static void rpc_show_header(void) +struct rpc_cb_add_xprt_calldata { + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; +}; + +static void rpc_cb_add_xprt_done(struct rpc_task *task, void *calldata) +{ + struct rpc_cb_add_xprt_calldata *data = calldata; + + if (task->tk_status == 0) + rpc_xprt_switch_add_xprt(data->xps, data->xprt); +} + +static void rpc_cb_add_xprt_release(void *calldata) +{ + struct rpc_cb_add_xprt_calldata *data = calldata; + + xprt_put(data->xprt); + xprt_switch_put(data->xps); + kfree(data); +} + +static const struct rpc_call_ops rpc_cb_add_xprt_call_ops = { + .rpc_call_prepare = rpc_null_call_prepare, + .rpc_call_done = rpc_cb_add_xprt_done, + .rpc_release = rpc_cb_add_xprt_release, +}; + +/** + * rpc_clnt_test_and_add_xprt - Test and add a new transport to a rpc_clnt + * @clnt: pointer to struct rpc_clnt + * @xps: pointer to struct rpc_xprt_switch, + * @xprt: pointer struct rpc_xprt + * @in_max_connect: pointer to the max_connect value for the passed in xprt transport + */ +int rpc_clnt_test_and_add_xprt(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xps, struct rpc_xprt *xprt, + void *in_max_connect) +{ + struct rpc_cb_add_xprt_calldata *data; + struct rpc_task *task; + int max_connect = clnt->cl_max_connect; + + if (in_max_connect) + max_connect = *(int *)in_max_connect; + if (xps->xps_nunique_destaddr_xprts + 1 > max_connect) { + rcu_read_lock(); + pr_warn("SUNRPC: reached max allowed number (%d) did not add " + "transport to server: %s\n", max_connect, + rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); + rcu_read_unlock(); + return -EINVAL; + } + + data = kmalloc(sizeof(*data), GFP_KERNEL); + if (!data) + return -ENOMEM; + data->xps = xprt_switch_get(xps); + data->xprt = xprt_get(xprt); + if (rpc_xprt_switch_has_addr(data->xps, (struct sockaddr *)&xprt->addr)) { + rpc_cb_add_xprt_release(data); + goto success; + } + + task = rpc_call_null_helper(clnt, xprt, NULL, RPC_TASK_ASYNC, + &rpc_cb_add_xprt_call_ops, data); + if (IS_ERR(task)) + return PTR_ERR(task); + + data->xps->xps_nunique_destaddr_xprts++; + rpc_put_task(task); +success: + return 1; +} +EXPORT_SYMBOL_GPL(rpc_clnt_test_and_add_xprt); + +static int rpc_clnt_add_xprt_helper(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + struct rpc_add_xprt_test *data) +{ + struct rpc_task *task; + int status = -EADDRINUSE; + + /* Test the connection */ + task = rpc_call_null_helper(clnt, xprt, NULL, 0, NULL, NULL); + if (IS_ERR(task)) + return PTR_ERR(task); + + status = task->tk_status; + rpc_put_task(task); + + if (status < 0) + return status; + + /* rpc_xprt_switch and rpc_xprt are deferrenced by add_xprt_test() */ + data->add_xprt_test(clnt, xprt, data->data); + + return 0; +} + +/** + * rpc_clnt_setup_test_and_add_xprt() + * + * This is an rpc_clnt_add_xprt setup() function which returns 1 so: + * 1) caller of the test function must dereference the rpc_xprt_switch + * and the rpc_xprt. + * 2) test function must call rpc_xprt_switch_add_xprt, usually in + * the rpc_call_done routine. + * + * Upon success (return of 1), the test function adds the new + * transport to the rpc_clnt xprt switch + * + * @clnt: struct rpc_clnt to get the new transport + * @xps: the rpc_xprt_switch to hold the new transport + * @xprt: the rpc_xprt to test + * @data: a struct rpc_add_xprt_test pointer that holds the test function + * and test function call data + */ +int rpc_clnt_setup_test_and_add_xprt(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, + void *data) +{ + int status = -EADDRINUSE; + + xprt = xprt_get(xprt); + xprt_switch_get(xps); + + if (rpc_xprt_switch_has_addr(xps, (struct sockaddr *)&xprt->addr)) + goto out_err; + + status = rpc_clnt_add_xprt_helper(clnt, xprt, data); + if (status < 0) + goto out_err; + + status = 1; +out_err: + xprt_put(xprt); + xprt_switch_put(xps); + if (status < 0) + pr_info("RPC: rpc_clnt_test_xprt failed: %d addr %s not " + "added\n", status, + xprt->address_strings[RPC_DISPLAY_ADDR]); + /* so that rpc_clnt_add_xprt does not call rpc_xprt_switch_add_xprt */ + return status; +} +EXPORT_SYMBOL_GPL(rpc_clnt_setup_test_and_add_xprt); + +/** + * rpc_clnt_add_xprt - Add a new transport to a rpc_clnt + * @clnt: pointer to struct rpc_clnt + * @xprtargs: pointer to struct xprt_create + * @setup: callback to test and/or set up the connection + * @data: pointer to setup function data + * + * Creates a new transport using the parameters set in args and + * adds it to clnt. + * If ping is set, then test that connectivity succeeds before + * adding the new transport. + * + */ +int rpc_clnt_add_xprt(struct rpc_clnt *clnt, + struct xprt_create *xprtargs, + int (*setup)(struct rpc_clnt *, + struct rpc_xprt_switch *, + struct rpc_xprt *, + void *), + void *data) +{ + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; + unsigned long connect_timeout; + unsigned long reconnect_timeout; + unsigned char resvport, reuseport; + int ret = 0, ident; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + xprt = xprt_iter_xprt(&clnt->cl_xpi); + if (xps == NULL || xprt == NULL) { + rcu_read_unlock(); + xprt_switch_put(xps); + return -EAGAIN; + } + resvport = xprt->resvport; + reuseport = xprt->reuseport; + connect_timeout = xprt->connect_timeout; + reconnect_timeout = xprt->max_reconnect_timeout; + ident = xprt->xprt_class->ident; + rcu_read_unlock(); + + if (!xprtargs->ident) + xprtargs->ident = ident; + xprtargs->xprtsec = clnt->cl_xprtsec; + xprt = xprt_create_transport(xprtargs); + if (IS_ERR(xprt)) { + ret = PTR_ERR(xprt); + goto out_put_switch; + } + xprt->resvport = resvport; + xprt->reuseport = reuseport; + + if (xprtargs->connect_timeout) + connect_timeout = xprtargs->connect_timeout; + if (xprtargs->reconnect_timeout) + reconnect_timeout = xprtargs->reconnect_timeout; + if (xprt->ops->set_connect_timeout != NULL) + xprt->ops->set_connect_timeout(xprt, + connect_timeout, + reconnect_timeout); + + rpc_xprt_switch_set_roundrobin(xps); + if (setup) { + ret = setup(clnt, xps, xprt, data); + if (ret != 0) + goto out_put_xprt; + } + rpc_xprt_switch_add_xprt(xps, xprt); +out_put_xprt: + xprt_put(xprt); +out_put_switch: + xprt_switch_put(xps); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_add_xprt); + +static int rpc_xprt_probe_trunked(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + struct rpc_add_xprt_test *data) { + struct rpc_xprt *main_xprt; + int status = 0; + + xprt_get(xprt); + + rcu_read_lock(); + main_xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + status = rpc_cmp_addr_port((struct sockaddr *)&xprt->addr, + (struct sockaddr *)&main_xprt->addr); + rcu_read_unlock(); + xprt_put(main_xprt); + if (status || !test_bit(XPRT_OFFLINE, &xprt->state)) + goto out; + + status = rpc_clnt_add_xprt_helper(clnt, xprt, data); +out: + xprt_put(xprt); + return status; +} + +/* rpc_clnt_probe_trunked_xprt -- probe offlined transport for session trunking + * @clnt rpc_clnt structure + * + * For each offlined transport found in the rpc_clnt structure call + * the function rpc_xprt_probe_trunked() which will determine if this + * transport still belongs to the trunking group. + */ +void rpc_clnt_probe_trunked_xprts(struct rpc_clnt *clnt, + struct rpc_add_xprt_test *data) +{ + struct rpc_xprt_iter xpi; + int ret; + + ret = rpc_clnt_xprt_iter_offline_init(clnt, &xpi); + if (ret) + return; + for (;;) { + struct rpc_xprt *xprt = xprt_iter_get_next(&xpi); + + if (!xprt) + break; + ret = rpc_xprt_probe_trunked(clnt, xprt, data); + xprt_put(xprt); + if (ret < 0) + break; + xprt_iter_rewind(&xpi); + } + xprt_iter_destroy(&xpi); +} +EXPORT_SYMBOL_GPL(rpc_clnt_probe_trunked_xprts); + +static int rpc_xprt_offline(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *data) +{ + struct rpc_xprt *main_xprt; + struct rpc_xprt_switch *xps; + int err = 0; + + xprt_get(xprt); + + rcu_read_lock(); + main_xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + err = rpc_cmp_addr_port((struct sockaddr *)&xprt->addr, + (struct sockaddr *)&main_xprt->addr); + rcu_read_unlock(); + xprt_put(main_xprt); + if (err) + goto out; + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + err = -EINTR; + goto out; + } + xprt_set_offline_locked(xprt, xps); + + xprt_release_write(xprt, NULL); +out: + xprt_put(xprt); + xprt_switch_put(xps); + return err; +} + +/* rpc_clnt_manage_trunked_xprts -- offline trunked transports + * @clnt rpc_clnt structure + * + * For each active transport found in the rpc_clnt structure call + * the function rpc_xprt_offline() which will identify trunked transports + * and will mark them offline. + */ +void rpc_clnt_manage_trunked_xprts(struct rpc_clnt *clnt) +{ + rpc_clnt_iterate_for_each_xprt(clnt, rpc_xprt_offline, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_manage_trunked_xprts); + +struct connect_timeout_data { + unsigned long connect_timeout; + unsigned long reconnect_timeout; +}; + +static int +rpc_xprt_set_connect_timeout(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *data) +{ + struct connect_timeout_data *timeo = data; + + if (xprt->ops->set_connect_timeout) + xprt->ops->set_connect_timeout(xprt, + timeo->connect_timeout, + timeo->reconnect_timeout); + return 0; +} + +void +rpc_set_connect_timeout(struct rpc_clnt *clnt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct connect_timeout_data timeout = { + .connect_timeout = connect_timeout, + .reconnect_timeout = reconnect_timeout, + }; + rpc_clnt_iterate_for_each_xprt(clnt, + rpc_xprt_set_connect_timeout, + &timeout); +} +EXPORT_SYMBOL_GPL(rpc_set_connect_timeout); + +void rpc_clnt_xprt_set_online(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + xps = rpc_clnt_xprt_switch_get(clnt); + xprt_set_online_locked(xprt, xps); + xprt_switch_put(xps); +} + +void rpc_clnt_xprt_switch_add_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + if (rpc_clnt_xprt_switch_has_addr(clnt, + (const struct sockaddr *)&xprt->addr)) { + return rpc_clnt_xprt_set_online(clnt, xprt); + } + + xps = rpc_clnt_xprt_switch_get(clnt); + rpc_xprt_switch_add_xprt(xps, xprt); + xprt_switch_put(xps); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_add_xprt); + +void rpc_clnt_xprt_switch_remove_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + rpc_xprt_switch_remove_xprt(rcu_dereference(clnt->cl_xpi.xpi_xpswitch), + xprt, 0); + xps->xps_nunique_destaddr_xprts--; + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_remove_xprt); + +bool rpc_clnt_xprt_switch_has_addr(struct rpc_clnt *clnt, + const struct sockaddr *sap) +{ + struct rpc_xprt_switch *xps; + bool ret; + + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + ret = rpc_xprt_switch_has_addr(xps, sap); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_has_addr); + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +static void rpc_show_header(struct rpc_clnt *clnt) +{ + printk(KERN_INFO "clnt[%pISpc] RPC tasks[%d]\n", + (struct sockaddr *)&clnt->cl_xprt->addr, + atomic_read(&clnt->cl_task_count)); printk(KERN_INFO "-pid- flgs status -client- --rqstp- " "-timeout ---ops--\n"); } @@ -2285,8 +3354,8 @@ static void rpc_show_task(const struct rpc_clnt *clnt, printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%ps q:%s\n", task->tk_pid, task->tk_flags, task->tk_status, - clnt, task->tk_rqstp, task->tk_timeout, task->tk_ops, - clnt->cl_protname, clnt->cl_vers, rpc_proc_name(task), + clnt, task->tk_rqstp, rpc_task_timeout(task), task->tk_ops, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), task->tk_action, rpc_waitq); } @@ -2302,7 +3371,7 @@ void rpc_show_tasks(struct net *net) spin_lock(&clnt->cl_lock); list_for_each_entry(task, &clnt->cl_tasks, tk_task) { if (!header) { - rpc_show_header(); + rpc_show_header(clnt); header++; } rpc_show_task(clnt, task); @@ -2312,3 +3381,45 @@ void rpc_show_tasks(struct net *net) spin_unlock(&sn->rpc_client_lock); } #endif + +#if IS_ENABLED(CONFIG_SUNRPC_SWAP) +static int +rpc_clnt_swap_activate_callback(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *dummy) +{ + return xprt_enable_swap(xprt); +} + +int +rpc_clnt_swap_activate(struct rpc_clnt *clnt) +{ + while (clnt != clnt->cl_parent) + clnt = clnt->cl_parent; + if (atomic_inc_return(&clnt->cl_swapper) == 1) + return rpc_clnt_iterate_for_each_xprt(clnt, + rpc_clnt_swap_activate_callback, NULL); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_clnt_swap_activate); + +static int +rpc_clnt_swap_deactivate_callback(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *dummy) +{ + xprt_disable_swap(xprt); + return 0; +} + +void +rpc_clnt_swap_deactivate(struct rpc_clnt *clnt) +{ + while (clnt != clnt->cl_parent) + clnt = clnt->cl_parent; + if (atomic_dec_if_positive(&clnt->cl_swapper) == 0) + rpc_clnt_iterate_for_each_xprt(clnt, + rpc_clnt_swap_deactivate_callback, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_swap_deactivate); +#endif /* CONFIG_SUNRPC_SWAP */ diff --git a/net/sunrpc/debugfs.c b/net/sunrpc/debugfs.c new file mode 100644 index 000000000000..32417db340de --- /dev/null +++ b/net/sunrpc/debugfs.c @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * debugfs interface for sunrpc + * + * (c) 2014 Jeff Layton <jlayton@primarydata.com> + */ + +#include <linux/debugfs.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/clnt.h> + +#include "netns.h" +#include "fail.h" + +static struct dentry *topdir; +static struct dentry *rpc_clnt_dir; +static struct dentry *rpc_xprt_dir; + +static int +tasks_show(struct seq_file *f, void *v) +{ + u32 xid = 0; + struct rpc_task *task = v; + struct rpc_clnt *clnt = task->tk_client; + const char *rpc_waitq = "none"; + + if (RPC_IS_QUEUED(task)) + rpc_waitq = rpc_qname(task->tk_waitqueue); + + if (task->tk_rqstp) + xid = be32_to_cpu(task->tk_rqstp->rq_xid); + + seq_printf(f, "%5u %04x %6d 0x%x 0x%x %8ld %ps %sv%u %s a:%ps q:%s\n", + task->tk_pid, task->tk_flags, task->tk_status, + clnt->cl_clid, xid, rpc_task_timeout(task), task->tk_ops, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), + task->tk_action, rpc_waitq); + return 0; +} + +static void * +tasks_start(struct seq_file *f, loff_t *ppos) + __acquires(&clnt->cl_lock) +{ + struct rpc_clnt *clnt = f->private; + loff_t pos = *ppos; + struct rpc_task *task; + + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, &clnt->cl_tasks, tk_task) + if (pos-- == 0) + return task; + return NULL; +} + +static void * +tasks_next(struct seq_file *f, void *v, loff_t *pos) +{ + struct rpc_clnt *clnt = f->private; + struct rpc_task *task = v; + struct list_head *next = task->tk_task.next; + + ++*pos; + + /* If there's another task on list, return it */ + if (next == &clnt->cl_tasks) + return NULL; + return list_entry(next, struct rpc_task, tk_task); +} + +static void +tasks_stop(struct seq_file *f, void *v) + __releases(&clnt->cl_lock) +{ + struct rpc_clnt *clnt = f->private; + spin_unlock(&clnt->cl_lock); + seq_printf(f, "clnt[%pISpc] RPC tasks[%d]\n", + (struct sockaddr *)&clnt->cl_xprt->addr, + atomic_read(&clnt->cl_task_count)); +} + +static const struct seq_operations tasks_seq_operations = { + .start = tasks_start, + .next = tasks_next, + .stop = tasks_stop, + .show = tasks_show, +}; + +static int tasks_open(struct inode *inode, struct file *filp) +{ + int ret = seq_open(filp, &tasks_seq_operations); + if (!ret) { + struct seq_file *seq = filp->private_data; + struct rpc_clnt *clnt = seq->private = inode->i_private; + + if (!refcount_inc_not_zero(&clnt->cl_count)) { + seq_release(inode, filp); + ret = -EINVAL; + } + } + + return ret; +} + +static int +tasks_release(struct inode *inode, struct file *filp) +{ + struct seq_file *seq = filp->private_data; + struct rpc_clnt *clnt = seq->private; + + rpc_release_client(clnt); + return seq_release(inode, filp); +} + +static const struct file_operations tasks_fops = { + .owner = THIS_MODULE, + .open = tasks_open, + .read = seq_read, + .llseek = seq_lseek, + .release = tasks_release, +}; + +static int do_xprt_debugfs(struct rpc_clnt *clnt, struct rpc_xprt *xprt, void *numv) +{ + int len; + char name[24]; /* enough for "../../rpc_xprt/ + 8 hex digits + NULL */ + char link[9]; /* enough for 8 hex digits + NULL */ + int *nump = numv; + + if (IS_ERR_OR_NULL(xprt->debugfs)) + return 0; + len = snprintf(name, sizeof(name), "../../rpc_xprt/%s", + xprt->debugfs->d_name.name); + if (len >= sizeof(name)) + return -1; + if (*nump == 0) + strcpy(link, "xprt"); + else { + len = snprintf(link, sizeof(link), "xprt%d", *nump); + if (len >= sizeof(link)) + return -1; + } + debugfs_create_symlink(link, clnt->cl_debugfs, name); + (*nump)++; + return 0; +} + +void +rpc_clnt_debugfs_register(struct rpc_clnt *clnt) +{ + int len; + char name[9]; /* enough for 8 hex digits + NULL */ + int xprtnum = 0; + + len = snprintf(name, sizeof(name), "%x", clnt->cl_clid); + if (len >= sizeof(name)) + return; + + /* make the per-client dir */ + clnt->cl_debugfs = debugfs_create_dir(name, rpc_clnt_dir); + + /* make tasks file */ + debugfs_create_file("tasks", S_IFREG | 0400, clnt->cl_debugfs, clnt, + &tasks_fops); + + rpc_clnt_iterate_for_each_xprt(clnt, do_xprt_debugfs, &xprtnum); +} + +void +rpc_clnt_debugfs_unregister(struct rpc_clnt *clnt) +{ + debugfs_remove_recursive(clnt->cl_debugfs); + clnt->cl_debugfs = NULL; +} + +static int +xprt_info_show(struct seq_file *f, void *v) +{ + struct rpc_xprt *xprt = f->private; + + seq_printf(f, "netid: %s\n", xprt->address_strings[RPC_DISPLAY_NETID]); + seq_printf(f, "addr: %s\n", xprt->address_strings[RPC_DISPLAY_ADDR]); + seq_printf(f, "port: %s\n", xprt->address_strings[RPC_DISPLAY_PORT]); + seq_printf(f, "state: 0x%lx\n", xprt->state); + seq_printf(f, "netns: %u\n", xprt->xprt_net->ns.inum); + + if (xprt->ops->get_srcaddr) { + int ret, buflen; + char buf[INET6_ADDRSTRLEN]; + + buflen = ARRAY_SIZE(buf); + ret = xprt->ops->get_srcaddr(xprt, buf, buflen); + if (ret < 0) + ret = sprintf(buf, "<closed>"); + seq_printf(f, "saddr: %.*s\n", ret, buf); + } + return 0; +} + +static int +xprt_info_open(struct inode *inode, struct file *filp) +{ + int ret; + struct rpc_xprt *xprt = inode->i_private; + + ret = single_open(filp, xprt_info_show, xprt); + + if (!ret) { + if (!xprt_get(xprt)) { + single_release(inode, filp); + ret = -EINVAL; + } + } + return ret; +} + +static int +xprt_info_release(struct inode *inode, struct file *filp) +{ + struct rpc_xprt *xprt = inode->i_private; + + xprt_put(xprt); + return single_release(inode, filp); +} + +static const struct file_operations xprt_info_fops = { + .owner = THIS_MODULE, + .open = xprt_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = xprt_info_release, +}; + +void +rpc_xprt_debugfs_register(struct rpc_xprt *xprt) +{ + int len, id; + static atomic_t cur_id; + char name[9]; /* 8 hex digits + NULL term */ + + id = (unsigned int)atomic_inc_return(&cur_id); + + len = snprintf(name, sizeof(name), "%x", id); + if (len >= sizeof(name)) + return; + + /* make the per-client dir */ + xprt->debugfs = debugfs_create_dir(name, rpc_xprt_dir); + + /* make tasks file */ + debugfs_create_file("info", S_IFREG | 0400, xprt->debugfs, xprt, + &xprt_info_fops); +} + +void +rpc_xprt_debugfs_unregister(struct rpc_xprt *xprt) +{ + debugfs_remove_recursive(xprt->debugfs); + xprt->debugfs = NULL; +} + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) +struct fail_sunrpc_attr fail_sunrpc = { + .attr = FAULT_ATTR_INITIALIZER, +}; +EXPORT_SYMBOL_GPL(fail_sunrpc); + +static void fail_sunrpc_init(void) +{ + struct dentry *dir; + + dir = fault_create_debugfs_attr("fail_sunrpc", NULL, + &fail_sunrpc.attr); + + debugfs_create_bool("ignore-client-disconnect", S_IFREG | 0600, dir, + &fail_sunrpc.ignore_client_disconnect); + + debugfs_create_bool("ignore-server-disconnect", S_IFREG | 0600, dir, + &fail_sunrpc.ignore_server_disconnect); + + debugfs_create_bool("ignore-cache-wait", S_IFREG | 0600, dir, + &fail_sunrpc.ignore_cache_wait); +} +#else +static void fail_sunrpc_init(void) +{ +} +#endif + +void __exit +sunrpc_debugfs_exit(void) +{ + debugfs_remove_recursive(topdir); + topdir = NULL; + rpc_clnt_dir = NULL; + rpc_xprt_dir = NULL; +} + +void __init +sunrpc_debugfs_init(void) +{ + topdir = debugfs_create_dir("sunrpc", NULL); + + rpc_clnt_dir = debugfs_create_dir("rpc_clnt", topdir); + + rpc_xprt_dir = debugfs_create_dir("rpc_xprt", topdir); + + fail_sunrpc_init(); +} diff --git a/net/sunrpc/fail.h b/net/sunrpc/fail.h new file mode 100644 index 000000000000..4b4b500df428 --- /dev/null +++ b/net/sunrpc/fail.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (C) 2021, Oracle. All rights reserved. + */ + +#ifndef _NET_SUNRPC_FAIL_H_ +#define _NET_SUNRPC_FAIL_H_ + +#include <linux/fault-inject.h> + +#if IS_ENABLED(CONFIG_FAULT_INJECTION) + +struct fail_sunrpc_attr { + struct fault_attr attr; + + bool ignore_client_disconnect; + bool ignore_server_disconnect; + bool ignore_cache_wait; +}; + +extern struct fail_sunrpc_attr fail_sunrpc; + +#endif /* CONFIG_FAULT_INJECTION */ + +#endif /* _NET_SUNRPC_FAIL_H_ */ diff --git a/net/sunrpc/netns.h b/net/sunrpc/netns.h index 74d948f5d5a1..4efb5f28d881 100644 --- a/net/sunrpc/netns.h +++ b/net/sunrpc/netns.h @@ -1,3 +1,4 @@ +/* SPDX-License-Identifier: GPL-2.0 */ #ifndef __SUNRPC_NETNS_H__ #define __SUNRPC_NETNS_H__ @@ -14,6 +15,7 @@ struct sunrpc_net { struct cache_detail *rsi_cache; struct super_block *pipefs_sb; + struct rpc_pipe *gssd_dummy; struct mutex pipefs_sb_lock; struct list_head all_clients; @@ -23,19 +25,18 @@ struct sunrpc_net { struct rpc_clnt *rpcb_local_clnt4; spinlock_t rpcb_clnt_lock; unsigned int rpcb_users; + unsigned int rpcb_is_af_local : 1; struct mutex gssp_lock; - wait_queue_head_t gssp_wq; struct rpc_clnt *gssp_clnt; int use_gss_proxy; int pipe_version; atomic_t pipe_users; struct proc_dir_entry *use_gssp_proc; - - unsigned int gssd_running; + struct proc_dir_entry *gss_krb5_enctypes; }; -extern int sunrpc_net_id; +extern unsigned int sunrpc_net_id; int ip_map_cache_create(struct net *); void ip_map_cache_destroy(struct net *); diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c index 406859cc68aa..379daefc4847 100644 --- a/net/sunrpc/rpc_pipe.c +++ b/net/sunrpc/rpc_pipe.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * net/sunrpc/rpc_pipe.c * @@ -13,10 +14,12 @@ #include <linux/string.h> #include <linux/pagemap.h> #include <linux/mount.h> +#include <linux/fs_context.h> #include <linux/namei.h> #include <linux/fsnotify.h> #include <linux/kernel.h> #include <linux/rcupdate.h> +#include <linux/utsname.h> #include <asm/ioctls.h> #include <linux/poll.h> @@ -38,7 +41,7 @@ #define NET_NAME(net) ((net == &init_net) ? " (init_net)" : "") static struct file_system_type rpc_pipe_fs_type; - +static const struct rpc_pipe_ops gssd_dummy_pipe_ops; static struct kmem_cache *rpc_inode_cachep __read_mostly; @@ -48,7 +51,7 @@ static BLOCKING_NOTIFIER_HEAD(rpc_pipefs_notifier_list); int rpc_pipefs_notifier_register(struct notifier_block *nb) { - return blocking_notifier_chain_cond_register(&rpc_pipefs_notifier_list, nb); + return blocking_notifier_chain_register(&rpc_pipefs_notifier_list, nb); } EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_register); @@ -93,7 +96,7 @@ rpc_timeout_upcall_queue(struct work_struct *work) } dentry = dget(pipe->dentry); spin_unlock(&pipe->lock); - rpc_purge_list(dentry ? &RPC_I(dentry->d_inode)->waitq : NULL, + rpc_purge_list(dentry ? &RPC_I(d_inode(dentry))->waitq : NULL, &free_list, destroy_msg, -ETIMEDOUT); dput(dentry); } @@ -151,7 +154,7 @@ rpc_queue_upcall(struct rpc_pipe *pipe, struct rpc_pipe_msg *msg) dentry = dget(pipe->dentry); spin_unlock(&pipe->lock); if (dentry) { - wake_up(&RPC_I(dentry->d_inode)->waitq); + wake_up(&RPC_I(d_inode(dentry))->waitq); dput(dentry); } return res; @@ -165,13 +168,14 @@ rpc_inode_setowner(struct inode *inode, void *private) } static void -rpc_close_pipes(struct inode *inode) +rpc_close_pipes(struct dentry *dentry) { + struct inode *inode = dentry->d_inode; struct rpc_pipe *pipe = RPC_I(inode)->pipe; int need_release; LIST_HEAD(free_list); - mutex_lock(&inode->i_mutex); + inode_lock(inode); spin_lock(&pipe->lock); need_release = pipe->nreaders != 0 || pipe->nwriters != 0; pipe->nreaders = 0; @@ -187,43 +191,33 @@ rpc_close_pipes(struct inode *inode) cancel_delayed_work_sync(&pipe->queue_timeout); rpc_inode_setowner(inode, NULL); RPC_I(inode)->pipe = NULL; - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); } static struct inode * rpc_alloc_inode(struct super_block *sb) { struct rpc_inode *rpci; - rpci = (struct rpc_inode *)kmem_cache_alloc(rpc_inode_cachep, GFP_KERNEL); + rpci = alloc_inode_sb(sb, rpc_inode_cachep, GFP_KERNEL); if (!rpci) return NULL; return &rpci->vfs_inode; } static void -rpc_i_callback(struct rcu_head *head) +rpc_free_inode(struct inode *inode) { - struct inode *inode = container_of(head, struct inode, i_rcu); kmem_cache_free(rpc_inode_cachep, RPC_I(inode)); } -static void -rpc_destroy_inode(struct inode *inode) -{ - call_rcu(&inode->i_rcu, rpc_i_callback); -} - static int rpc_pipe_open(struct inode *inode, struct file *filp) { - struct net *net = inode->i_sb->s_fs_info; - struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); struct rpc_pipe *pipe; int first_open; int res = -ENXIO; - mutex_lock(&inode->i_mutex); - sn->gssd_running = 1; + inode_lock(inode); pipe = RPC_I(inode)->pipe; if (pipe == NULL) goto out; @@ -239,7 +233,7 @@ rpc_pipe_open(struct inode *inode, struct file *filp) pipe->nwriters++; res = 0; out: - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return res; } @@ -250,7 +244,7 @@ rpc_pipe_release(struct inode *inode, struct file *filp) struct rpc_pipe_msg *msg; int last_close; - mutex_lock(&inode->i_mutex); + inode_lock(inode); pipe = RPC_I(inode)->pipe; if (pipe == NULL) goto out; @@ -280,7 +274,7 @@ rpc_pipe_release(struct inode *inode, struct file *filp) if (last_close && pipe->ops->release_pipe) pipe->ops->release_pipe(inode); out: - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return 0; } @@ -292,7 +286,7 @@ rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset) struct rpc_pipe_msg *msg; int res = 0; - mutex_lock(&inode->i_mutex); + inode_lock(inode); pipe = RPC_I(inode)->pipe; if (pipe == NULL) { res = -EPIPE; @@ -324,7 +318,7 @@ rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset) pipe->ops->destroy_msg(msg); } out_unlock: - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return res; } @@ -334,29 +328,29 @@ rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *of struct inode *inode = file_inode(filp); int res; - mutex_lock(&inode->i_mutex); + inode_lock(inode); res = -EPIPE; if (RPC_I(inode)->pipe != NULL) res = RPC_I(inode)->pipe->ops->downcall(filp, buf, len); - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return res; } -static unsigned int +static __poll_t rpc_pipe_poll(struct file *filp, struct poll_table_struct *wait) { struct inode *inode = file_inode(filp); struct rpc_inode *rpci = RPC_I(inode); - unsigned int mask = POLLOUT | POLLWRNORM; + __poll_t mask = EPOLLOUT | EPOLLWRNORM; poll_wait(filp, &rpci->waitq, wait); - mutex_lock(&inode->i_mutex); + inode_lock(inode); if (rpci->pipe == NULL) - mask |= POLLERR | POLLHUP; + mask |= EPOLLERR | EPOLLHUP; else if (filp->private_data || !list_empty(&rpci->pipe->pipe)) - mask |= POLLIN | POLLRDNORM; - mutex_unlock(&inode->i_mutex); + mask |= EPOLLIN | EPOLLRDNORM; + inode_unlock(inode); return mask; } @@ -369,10 +363,10 @@ rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) switch (cmd) { case FIONREAD: - mutex_lock(&inode->i_mutex); + inode_lock(inode); pipe = RPC_I(inode)->pipe; if (pipe == NULL) { - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return -EPIPE; } spin_lock(&pipe->lock); @@ -383,7 +377,7 @@ rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) len += msg->len - msg->copied; } spin_unlock(&pipe->lock); - mutex_unlock(&inode->i_mutex); + inode_unlock(inode); return put_user(len, (int __user *)arg); default: return -EINVAL; @@ -392,7 +386,6 @@ rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) static const struct file_operations rpc_pipe_fops = { .owner = THIS_MODULE, - .llseek = no_llseek, .read = rpc_pipe_read, .write = rpc_pipe_write, .poll = rpc_pipe_poll, @@ -409,7 +402,7 @@ rpc_show_info(struct seq_file *m, void *v) rcu_read_lock(); seq_printf(m, "RPC server: %s\n", rcu_dereference(clnt->cl_xprt)->servername); - seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_protname, + seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_program->name, clnt->cl_prog, clnt->cl_vers); seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO)); @@ -430,7 +423,7 @@ rpc_info_open(struct inode *inode, struct file *file) spin_lock(&file->f_path.dentry->d_lock); if (!d_unhashed(file->f_path.dentry)) clnt = RPC_I(inode)->private; - if (clnt != NULL && atomic_inc_not_zero(&clnt->cl_count)) { + if (clnt != NULL && refcount_inc_not_zero(&clnt->cl_count)) { spin_unlock(&file->f_path.dentry->d_lock); m->private = clnt; } else { @@ -471,32 +464,6 @@ struct rpc_filelist { umode_t mode; }; -static int rpc_delete_dentry(const struct dentry *dentry) -{ - return 1; -} - -static const struct dentry_operations rpc_dentry_operations = { - .d_delete = rpc_delete_dentry, -}; - -/* - * Lookup the data. This is trivial - if the dentry didn't already - * exist, we know it is negative. - */ -static struct dentry * -rpc_lookup(struct inode *dir, struct dentry *dentry, unsigned int flags) -{ - if (dentry->d_name.len > NAME_MAX) - return ERR_PTR(-ENAMETOOLONG); - d_add(dentry, NULL); - return NULL; -} - -static const struct inode_operations rpc_dir_inode_operations = { - .lookup = rpc_lookup, -}; - static struct inode * rpc_get_inode(struct super_block *sb, umode_t mode) { @@ -505,72 +472,19 @@ rpc_get_inode(struct super_block *sb, umode_t mode) return NULL; inode->i_ino = get_next_ino(); inode->i_mode = mode; - inode->i_atime = inode->i_mtime = inode->i_ctime = CURRENT_TIME; + simple_inode_init_ts(inode); switch (mode & S_IFMT) { case S_IFDIR: inode->i_fop = &simple_dir_operations; - inode->i_op = &rpc_dir_inode_operations; + inode->i_op = &simple_dir_inode_operations; inc_nlink(inode); + break; default: break; } return inode; } -static int __rpc_create_common(struct inode *dir, struct dentry *dentry, - umode_t mode, - const struct file_operations *i_fop, - void *private) -{ - struct inode *inode; - - d_drop(dentry); - inode = rpc_get_inode(dir->i_sb, mode); - if (!inode) - goto out_err; - inode->i_ino = iunique(dir->i_sb, 100); - if (i_fop) - inode->i_fop = i_fop; - if (private) - rpc_inode_setowner(inode, private); - d_add(dentry, inode); - return 0; -out_err: - printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %s\n", - __FILE__, __func__, dentry->d_name.name); - dput(dentry); - return -ENOMEM; -} - -static int __rpc_create(struct inode *dir, struct dentry *dentry, - umode_t mode, - const struct file_operations *i_fop, - void *private) -{ - int err; - - err = __rpc_create_common(dir, dentry, S_IFREG | mode, i_fop, private); - if (err) - return err; - fsnotify_create(dir, dentry); - return 0; -} - -static int __rpc_mkdir(struct inode *dir, struct dentry *dentry, - umode_t mode, - const struct file_operations *i_fop, - void *private) -{ - int err; - - err = __rpc_create_common(dir, dentry, S_IFDIR | mode, i_fop, private); - if (err) - return err; - inc_nlink(dir); - fsnotify_mkdir(dir, dentry); - return 0; -} - static void init_pipe(struct rpc_pipe *pipe) { @@ -607,131 +521,58 @@ struct rpc_pipe *rpc_mkpipe_data(const struct rpc_pipe_ops *ops, int flags) } EXPORT_SYMBOL_GPL(rpc_mkpipe_data); -static int __rpc_mkpipe_dentry(struct inode *dir, struct dentry *dentry, - umode_t mode, - const struct file_operations *i_fop, - void *private, - struct rpc_pipe *pipe) +static int rpc_new_file(struct dentry *parent, + const char *name, + umode_t mode, + const struct file_operations *i_fop, + void *private) { - struct rpc_inode *rpci; - int err; - - err = __rpc_create_common(dir, dentry, S_IFIFO | mode, i_fop, private); - if (err) - return err; - rpci = RPC_I(dentry->d_inode); - rpci->private = private; - rpci->pipe = pipe; - fsnotify_create(dir, dentry); - return 0; -} - -static int __rpc_rmdir(struct inode *dir, struct dentry *dentry) -{ - int ret; - - dget(dentry); - ret = simple_rmdir(dir, dentry); - d_delete(dentry); - dput(dentry); - return ret; -} - -int rpc_rmdir(struct dentry *dentry) -{ - struct dentry *parent; - struct inode *dir; - int error; - - parent = dget_parent(dentry); - dir = parent->d_inode; - mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); - error = __rpc_rmdir(dir, dentry); - mutex_unlock(&dir->i_mutex); - dput(parent); - return error; -} -EXPORT_SYMBOL_GPL(rpc_rmdir); - -static int __rpc_unlink(struct inode *dir, struct dentry *dentry) -{ - int ret; - - dget(dentry); - ret = simple_unlink(dir, dentry); - d_delete(dentry); - dput(dentry); - return ret; -} - -static int __rpc_rmpipe(struct inode *dir, struct dentry *dentry) -{ - struct inode *inode = dentry->d_inode; + struct dentry *dentry = simple_start_creating(parent, name); + struct inode *dir = parent->d_inode; + struct inode *inode; - rpc_close_pipes(inode); - return __rpc_unlink(dir, dentry); -} + if (IS_ERR(dentry)) + return PTR_ERR(dentry); -static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent, - const char *name) -{ - struct qstr q = QSTR_INIT(name, strlen(name)); - struct dentry *dentry = d_hash_and_lookup(parent, &q); - if (!dentry) { - dentry = d_alloc(parent, &q); - if (!dentry) - return ERR_PTR(-ENOMEM); + inode = rpc_get_inode(dir->i_sb, S_IFREG | mode); + if (unlikely(!inode)) { + simple_done_creating(dentry); + return -ENOMEM; } - if (dentry->d_inode == NULL) - return dentry; - dput(dentry); - return ERR_PTR(-EEXIST); + inode->i_ino = iunique(dir->i_sb, 100); + if (i_fop) + inode->i_fop = i_fop; + rpc_inode_setowner(inode, private); + d_make_persistent(dentry, inode); + fsnotify_create(dir, dentry); + simple_done_creating(dentry); + return 0; } -/* - * FIXME: This probably has races. - */ -static void __rpc_depopulate(struct dentry *parent, - const struct rpc_filelist *files, - int start, int eof) +static struct dentry *rpc_new_dir(struct dentry *parent, + const char *name, + umode_t mode) { + struct dentry *dentry = simple_start_creating(parent, name); struct inode *dir = parent->d_inode; - struct dentry *dentry; - struct qstr name; - int i; + struct inode *inode; - for (i = start; i < eof; i++) { - name.name = files[i].name; - name.len = strlen(files[i].name); - dentry = d_hash_and_lookup(parent, &name); + if (IS_ERR(dentry)) + return dentry; - if (dentry == NULL) - continue; - if (dentry->d_inode == NULL) - goto next; - switch (dentry->d_inode->i_mode & S_IFMT) { - default: - BUG(); - case S_IFREG: - __rpc_unlink(dir, dentry); - break; - case S_IFDIR: - __rpc_rmdir(dir, dentry); - } -next: - dput(dentry); + inode = rpc_get_inode(dir->i_sb, S_IFDIR | mode); + if (unlikely(!inode)) { + simple_done_creating(dentry); + return ERR_PTR(-ENOMEM); } -} -static void rpc_depopulate(struct dentry *parent, - const struct rpc_filelist *files, - int start, int eof) -{ - struct inode *dir = parent->d_inode; + inode->i_ino = iunique(dir->i_sb, 100); + inc_nlink(dir); + d_make_persistent(dentry, inode); + fsnotify_mkdir(dir, dentry); + simple_done_creating(dentry); - mutex_lock_nested(&dir->i_mutex, I_MUTEX_CHILD); - __rpc_depopulate(parent, files, start, eof); - mutex_unlock(&dir->i_mutex); + return dentry; // borrowed } static int rpc_populate(struct dentry *parent, @@ -739,94 +580,42 @@ static int rpc_populate(struct dentry *parent, int start, int eof, void *private) { - struct inode *dir = parent->d_inode; struct dentry *dentry; int i, err; - mutex_lock(&dir->i_mutex); for (i = start; i < eof; i++) { - dentry = __rpc_lookup_create_exclusive(parent, files[i].name); - err = PTR_ERR(dentry); - if (IS_ERR(dentry)) - goto out_bad; switch (files[i].mode & S_IFMT) { default: BUG(); case S_IFREG: - err = __rpc_create(dir, dentry, + err = rpc_new_file(parent, + files[i].name, files[i].mode, files[i].i_fop, private); + if (err) + goto out_bad; break; case S_IFDIR: - err = __rpc_mkdir(dir, dentry, - files[i].mode, - NULL, - private); + dentry = rpc_new_dir(parent, + files[i].name, + files[i].mode); + if (IS_ERR(dentry)) { + err = PTR_ERR(dentry); + goto out_bad; + } } - if (err != 0) - goto out_bad; } - mutex_unlock(&dir->i_mutex); return 0; out_bad: - __rpc_depopulate(parent, files, start, eof); - mutex_unlock(&dir->i_mutex); - printk(KERN_WARNING "%s: %s failed to populate directory %s\n", - __FILE__, __func__, parent->d_name.name); + printk(KERN_WARNING "%s: %s failed to populate directory %pd\n", + __FILE__, __func__, parent); return err; } -static struct dentry *rpc_mkdir_populate(struct dentry *parent, - const char *name, umode_t mode, void *private, - int (*populate)(struct dentry *, void *), void *args_populate) -{ - struct dentry *dentry; - struct inode *dir = parent->d_inode; - int error; - - mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); - dentry = __rpc_lookup_create_exclusive(parent, name); - if (IS_ERR(dentry)) - goto out; - error = __rpc_mkdir(dir, dentry, mode, NULL, private); - if (error != 0) - goto out_err; - if (populate != NULL) { - error = populate(dentry, args_populate); - if (error) - goto err_rmdir; - } -out: - mutex_unlock(&dir->i_mutex); - return dentry; -err_rmdir: - __rpc_rmdir(dir, dentry); -out_err: - dentry = ERR_PTR(error); - goto out; -} - -static int rpc_rmdir_depopulate(struct dentry *dentry, - void (*depopulate)(struct dentry *)) -{ - struct dentry *parent; - struct inode *dir; - int error; - - parent = dget_parent(dentry); - dir = parent->d_inode; - mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); - if (depopulate != NULL) - depopulate(dentry); - error = __rpc_rmdir(dir, dentry); - mutex_unlock(&dir->i_mutex); - dput(parent); - return error; -} - /** - * rpc_mkpipe - make an rpc_pipefs file for kernel<->userspace communication + * rpc_mkpipe_dentry - make an rpc_pipefs file for kernel<->userspace + * communication * @parent: dentry of directory to create new "pipe" in * @name: name of pipe * @private: private data to associate with the pipe, for the caller's use @@ -843,87 +632,221 @@ static int rpc_rmdir_depopulate(struct dentry *dentry, * The @private argument passed here will be available to all these methods * from the file pointer, via RPC_I(file_inode(file))->private. */ -struct dentry *rpc_mkpipe_dentry(struct dentry *parent, const char *name, +int rpc_mkpipe_dentry(struct dentry *parent, const char *name, void *private, struct rpc_pipe *pipe) { + struct inode *dir = d_inode(parent); struct dentry *dentry; - struct inode *dir = parent->d_inode; - umode_t umode = S_IFIFO | S_IRUSR | S_IWUSR; + struct inode *inode; + struct rpc_inode *rpci; + umode_t umode = S_IFIFO | 0600; int err; if (pipe->ops->upcall == NULL) - umode &= ~S_IRUGO; + umode &= ~0444; if (pipe->ops->downcall == NULL) - umode &= ~S_IWUGO; + umode &= ~0222; - mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); - dentry = __rpc_lookup_create_exclusive(parent, name); - if (IS_ERR(dentry)) - goto out; - err = __rpc_mkpipe_dentry(dir, dentry, umode, &rpc_pipe_fops, - private, pipe); - if (err) - goto out_err; -out: - mutex_unlock(&dir->i_mutex); - return dentry; -out_err: - dentry = ERR_PTR(err); - printk(KERN_WARNING "%s: %s() failed to create pipe %s/%s (errno = %d)\n", - __FILE__, __func__, parent->d_name.name, name, - err); - goto out; + dentry = simple_start_creating(parent, name); + if (IS_ERR(dentry)) { + err = PTR_ERR(dentry); + goto failed; + } + + inode = rpc_get_inode(dir->i_sb, umode); + if (unlikely(!inode)) { + simple_done_creating(dentry); + err = -ENOMEM; + goto failed; + } + inode->i_ino = iunique(dir->i_sb, 100); + inode->i_fop = &rpc_pipe_fops; + rpci = RPC_I(inode); + rpci->private = private; + rpci->pipe = pipe; + rpc_inode_setowner(inode, private); + pipe->dentry = dentry; // borrowed + d_make_persistent(dentry, inode); + fsnotify_create(dir, dentry); + simple_done_creating(dentry); + return 0; + +failed: + pr_warn("%s() failed to create pipe %pd/%s (errno = %d)\n", + __func__, parent, name, err); + return err; } EXPORT_SYMBOL_GPL(rpc_mkpipe_dentry); /** * rpc_unlink - remove a pipe - * @dentry: dentry for the pipe, as returned from rpc_mkpipe + * @pipe: the pipe to be removed * * After this call, lookups will no longer find the pipe, and any * attempts to read or write using preexisting opens of the pipe will * return -EPIPE. */ -int -rpc_unlink(struct dentry *dentry) +void +rpc_unlink(struct rpc_pipe *pipe) { - struct dentry *parent; - struct inode *dir; - int error = 0; - - parent = dget_parent(dentry); - dir = parent->d_inode; - mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); - error = __rpc_rmpipe(dir, dentry); - mutex_unlock(&dir->i_mutex); - dput(parent); - return error; + if (pipe->dentry) { + simple_recursive_removal(pipe->dentry, rpc_close_pipes); + pipe->dentry = NULL; + } } EXPORT_SYMBOL_GPL(rpc_unlink); -enum { - RPCAUTH_info, - RPCAUTH_EOF -}; +/** + * rpc_init_pipe_dir_head - initialise a struct rpc_pipe_dir_head + * @pdh: pointer to struct rpc_pipe_dir_head + */ +void rpc_init_pipe_dir_head(struct rpc_pipe_dir_head *pdh) +{ + INIT_LIST_HEAD(&pdh->pdh_entries); + pdh->pdh_dentry = NULL; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_head); -static const struct rpc_filelist authfiles[] = { - [RPCAUTH_info] = { - .name = "info", - .i_fop = &rpc_info_operations, - .mode = S_IFREG | S_IRUSR, - }, -}; +/** + * rpc_init_pipe_dir_object - initialise a struct rpc_pipe_dir_object + * @pdo: pointer to struct rpc_pipe_dir_object + * @pdo_ops: pointer to const struct rpc_pipe_dir_object_ops + * @pdo_data: pointer to caller-defined data + */ +void rpc_init_pipe_dir_object(struct rpc_pipe_dir_object *pdo, + const struct rpc_pipe_dir_object_ops *pdo_ops, + void *pdo_data) +{ + INIT_LIST_HEAD(&pdo->pdo_head); + pdo->pdo_ops = pdo_ops; + pdo->pdo_data = pdo_data; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_object); + +static int +rpc_add_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (pdh->pdh_dentry) + ret = pdo->pdo_ops->create(pdh->pdh_dentry, pdo); + if (ret == 0) + list_add_tail(&pdo->pdo_head, &pdh->pdh_entries); + return ret; +} + +static void +rpc_remove_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (pdh->pdh_dentry) + pdo->pdo_ops->destroy(pdh->pdh_dentry, pdo); + list_del_init(&pdo->pdo_head); +} + +/** + * rpc_add_pipe_dir_object - associate a rpc_pipe_dir_object to a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +int +rpc_add_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + ret = rpc_add_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } + return ret; +} +EXPORT_SYMBOL_GPL(rpc_add_pipe_dir_object); -static int rpc_clntdir_populate(struct dentry *dentry, void *private) +/** + * rpc_remove_pipe_dir_object - remove a rpc_pipe_dir_object from a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +void +rpc_remove_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) { - return rpc_populate(dentry, - authfiles, RPCAUTH_info, RPCAUTH_EOF, - private); + if (!list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + rpc_remove_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } } +EXPORT_SYMBOL_GPL(rpc_remove_pipe_dir_object); -static void rpc_clntdir_depopulate(struct dentry *dentry) +/** + * rpc_find_or_alloc_pipe_dir_object + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @match: match struct rpc_pipe_dir_object to data + * @alloc: allocate a new struct rpc_pipe_dir_object + * @data: user defined data for match() and alloc() + * + */ +struct rpc_pipe_dir_object * +rpc_find_or_alloc_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + int (*match)(struct rpc_pipe_dir_object *, void *), + struct rpc_pipe_dir_object *(*alloc)(void *), + void *data) { - rpc_depopulate(dentry, authfiles, RPCAUTH_info, RPCAUTH_EOF); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe_dir_object *pdo; + + mutex_lock(&sn->pipefs_sb_lock); + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) { + if (!match(pdo, data)) + continue; + goto out; + } + pdo = alloc(data); + if (!pdo) + goto out; + rpc_add_pipe_dir_object_locked(net, pdh, pdo); +out: + mutex_unlock(&sn->pipefs_sb_lock); + return pdo; +} +EXPORT_SYMBOL_GPL(rpc_find_or_alloc_pipe_dir_object); + +static void +rpc_create_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->create(dir, pdo); +} + +static void +rpc_destroy_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->destroy(dir, pdo); } /** @@ -937,63 +860,82 @@ static void rpc_clntdir_depopulate(struct dentry *dentry) * information about the client, together with any "pipes" that may * later be created using rpc_mkpipe(). */ -struct dentry *rpc_create_client_dir(struct dentry *dentry, - const char *name, - struct rpc_clnt *rpc_client) +int rpc_create_client_dir(struct dentry *dentry, + const char *name, + struct rpc_clnt *rpc_client) { - return rpc_mkdir_populate(dentry, name, S_IRUGO | S_IXUGO, NULL, - rpc_clntdir_populate, rpc_client); + struct dentry *ret; + int err; + + ret = rpc_new_dir(dentry, name, 0555); + if (IS_ERR(ret)) + return PTR_ERR(ret); + err = rpc_new_file(ret, "info", S_IFREG | 0400, + &rpc_info_operations, rpc_client); + if (err) { + pr_warn("%s failed to populate directory %pd\n", + __func__, ret); + simple_recursive_removal(ret, NULL); + return err; + } + rpc_client->cl_pipedir_objects.pdh_dentry = ret; + rpc_create_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + return 0; } /** * rpc_remove_client_dir - Remove a directory created with rpc_create_client_dir() - * @dentry: dentry for the pipe + * @rpc_client: rpc_client for the pipe */ -int rpc_remove_client_dir(struct dentry *dentry) +int rpc_remove_client_dir(struct rpc_clnt *rpc_client) { - return rpc_rmdir_depopulate(dentry, rpc_clntdir_depopulate); + struct dentry *dentry = rpc_client->cl_pipedir_objects.pdh_dentry; + + if (dentry == NULL) + return 0; + rpc_destroy_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + rpc_client->cl_pipedir_objects.pdh_dentry = NULL; + simple_recursive_removal(dentry, NULL); + return 0; } static const struct rpc_filelist cache_pipefs_files[3] = { [0] = { .name = "channel", .i_fop = &cache_file_operations_pipefs, - .mode = S_IFREG|S_IRUSR|S_IWUSR, + .mode = S_IFREG | 0600, }, [1] = { .name = "content", .i_fop = &content_file_operations_pipefs, - .mode = S_IFREG|S_IRUSR, + .mode = S_IFREG | 0400, }, [2] = { .name = "flush", .i_fop = &cache_flush_operations_pipefs, - .mode = S_IFREG|S_IRUSR|S_IWUSR, + .mode = S_IFREG | 0600, }, }; -static int rpc_cachedir_populate(struct dentry *dentry, void *private) -{ - return rpc_populate(dentry, - cache_pipefs_files, 0, 3, - private); -} - -static void rpc_cachedir_depopulate(struct dentry *dentry) -{ - rpc_depopulate(dentry, cache_pipefs_files, 0, 3); -} - struct dentry *rpc_create_cache_dir(struct dentry *parent, const char *name, umode_t umode, struct cache_detail *cd) { - return rpc_mkdir_populate(parent, name, umode, NULL, - rpc_cachedir_populate, cd); + struct dentry *dentry; + + dentry = rpc_new_dir(parent, name, umode); + if (!IS_ERR(dentry)) { + int error = rpc_populate(dentry, cache_pipefs_files, 0, 3, cd); + if (error) { + simple_recursive_removal(dentry, NULL); + return ERR_PTR(error); + } + } + return dentry; } void rpc_remove_cache_dir(struct dentry *dentry) { - rpc_rmdir_depopulate(dentry, rpc_cachedir_depopulate); + simple_recursive_removal(dentry, NULL); } /* @@ -1001,7 +943,7 @@ void rpc_remove_cache_dir(struct dentry *dentry) */ static const struct super_operations s_ops = { .alloc_inode = rpc_alloc_inode, - .destroy_inode = rpc_destroy_inode, + .free_inode = rpc_free_inode, .statfs = simple_statfs, }; @@ -1025,35 +967,35 @@ enum { static const struct rpc_filelist files[] = { [RPCAUTH_lockd] = { .name = "lockd", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_mount] = { .name = "mount", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_nfs] = { .name = "nfs", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_portmap] = { .name = "portmap", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_statd] = { .name = "statd", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_nfsd4_cb] = { .name = "nfsd4_cb", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_cache] = { .name = "cache", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, [RPCAUTH_nfsd] = { .name = "nfsd", - .mode = S_IFDIR | S_IRUGO | S_IXUGO, + .mode = S_IFDIR | 0555, }, }; @@ -1063,18 +1005,28 @@ static const struct rpc_filelist files[] = { struct dentry *rpc_d_lookup_sb(const struct super_block *sb, const unsigned char *dir_name) { - struct qstr dir = QSTR_INIT(dir_name, strlen(dir_name)); - return d_hash_and_lookup(sb->s_root, &dir); + return try_lookup_noperm(&QSTR(dir_name), sb->s_root); } EXPORT_SYMBOL_GPL(rpc_d_lookup_sb); -void rpc_pipefs_init_net(struct net *net) +int rpc_pipefs_init_net(struct net *net) { struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + sn->gssd_dummy = rpc_mkpipe_data(&gssd_dummy_pipe_ops, 0); + if (IS_ERR(sn->gssd_dummy)) + return PTR_ERR(sn->gssd_dummy); + mutex_init(&sn->pipefs_sb_lock); - sn->gssd_running = 1; sn->pipe_version = -1; + return 0; +} + +void rpc_pipefs_exit_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_destroy_pipe_data(sn->gssd_dummy); } /* @@ -1104,56 +1056,133 @@ void rpc_put_sb_net(const struct net *net) } EXPORT_SYMBOL_GPL(rpc_put_sb_net); +static ssize_t +dummy_downcall(struct file *filp, const char __user *src, size_t len) +{ + return -EINVAL; +} + +static const struct rpc_pipe_ops gssd_dummy_pipe_ops = { + .upcall = rpc_pipe_generic_upcall, + .downcall = dummy_downcall, +}; + +/* + * Here we present a bogus "info" file to keep rpc.gssd happy. We don't expect + * that it will ever use this info to handle an upcall, but rpc.gssd expects + * that this file will be there and have a certain format. + */ +static int +rpc_dummy_info_show(struct seq_file *m, void *v) +{ + seq_printf(m, "RPC server: %s\n", utsname()->nodename); + seq_printf(m, "service: foo (1) version 0\n"); + seq_printf(m, "address: 127.0.0.1\n"); + seq_printf(m, "protocol: tcp\n"); + seq_printf(m, "port: 0\n"); + return 0; +} +DEFINE_SHOW_ATTRIBUTE(rpc_dummy_info); + +/** + * rpc_gssd_dummy_populate - create a dummy gssd pipe + * @root: root of the rpc_pipefs filesystem + * @pipe_data: pipe data created when netns is initialized + * + * Create a dummy set of directories and a pipe that gssd can hold open to + * indicate that it is up and running. + */ +static int +rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data) +{ + struct dentry *gssd_dentry, *clnt_dentry; + int err; + + gssd_dentry = rpc_new_dir(root, "gssd", 0555); + if (IS_ERR(gssd_dentry)) + return -ENOENT; + + clnt_dentry = rpc_new_dir(gssd_dentry, "clntXX", 0555); + if (IS_ERR(clnt_dentry)) + return -ENOENT; + + err = rpc_new_file(clnt_dentry, "info", 0400, + &rpc_dummy_info_fops, NULL); + if (!err) + err = rpc_mkpipe_dentry(clnt_dentry, "gssd", NULL, pipe_data); + return err; +} + static int -rpc_fill_super(struct super_block *sb, void *data, int silent) +rpc_fill_super(struct super_block *sb, struct fs_context *fc) { struct inode *inode; struct dentry *root; - struct net *net = data; + struct net *net = sb->s_fs_info; struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); int err; - sb->s_blocksize = PAGE_CACHE_SIZE; - sb->s_blocksize_bits = PAGE_CACHE_SHIFT; + sb->s_blocksize = PAGE_SIZE; + sb->s_blocksize_bits = PAGE_SHIFT; sb->s_magic = RPCAUTH_GSSMAGIC; sb->s_op = &s_ops; - sb->s_d_op = &rpc_dentry_operations; + sb->s_d_flags = DCACHE_DONTCACHE; sb->s_time_gran = 1; - inode = rpc_get_inode(sb, S_IFDIR | S_IRUGO | S_IXUGO); + inode = rpc_get_inode(sb, S_IFDIR | 0555); sb->s_root = root = d_make_root(inode); if (!root) return -ENOMEM; if (rpc_populate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF, NULL)) return -ENOMEM; - dprintk("RPC: sending pipefs MOUNT notification for net %p%s\n", - net, NET_NAME(net)); + + err = rpc_gssd_dummy_populate(root, sn->gssd_dummy); + if (err) + return err; + + dprintk("RPC: sending pipefs MOUNT notification for net %x%s\n", + net->ns.inum, NET_NAME(net)); mutex_lock(&sn->pipefs_sb_lock); sn->pipefs_sb = sb; err = blocking_notifier_call_chain(&rpc_pipefs_notifier_list, RPC_PIPEFS_MOUNT, sb); - if (err) - goto err_depopulate; - sb->s_fs_info = get_net(net); - mutex_unlock(&sn->pipefs_sb_lock); - return 0; - -err_depopulate: - blocking_notifier_call_chain(&rpc_pipefs_notifier_list, - RPC_PIPEFS_UMOUNT, - sb); - sn->pipefs_sb = NULL; - __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); mutex_unlock(&sn->pipefs_sb_lock); return err; } -static struct dentry * -rpc_mount(struct file_system_type *fs_type, - int flags, const char *dev_name, void *data) +bool +gssd_running(struct net *net) { - return mount_ns(fs_type, flags, current->nsproxy->net_ns, rpc_fill_super); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe *pipe = sn->gssd_dummy; + + return pipe->nreaders || pipe->nwriters; +} +EXPORT_SYMBOL_GPL(gssd_running); + +static int rpc_fs_get_tree(struct fs_context *fc) +{ + return get_tree_keyed(fc, rpc_fill_super, get_net(fc->net_ns)); +} + +static void rpc_fs_free_fc(struct fs_context *fc) +{ + if (fc->s_fs_info) + put_net(fc->s_fs_info); +} + +static const struct fs_context_operations rpc_fs_context_ops = { + .free = rpc_fs_free_fc, + .get_tree = rpc_fs_get_tree, +}; + +static int rpc_init_fs_context(struct fs_context *fc) +{ + put_user_ns(fc->user_ns); + fc->user_ns = get_user_ns(fc->net_ns->user_ns); + fc->ops = &rpc_fs_context_ops; + return 0; } static void rpc_kill_sb(struct super_block *sb) @@ -1167,21 +1196,21 @@ static void rpc_kill_sb(struct super_block *sb) goto out; } sn->pipefs_sb = NULL; - dprintk("RPC: sending pipefs UMOUNT notification for net %p%s\n", - net, NET_NAME(net)); + dprintk("RPC: sending pipefs UMOUNT notification for net %x%s\n", + net->ns.inum, NET_NAME(net)); blocking_notifier_call_chain(&rpc_pipefs_notifier_list, RPC_PIPEFS_UMOUNT, sb); mutex_unlock(&sn->pipefs_sb_lock); - put_net(net); out: - kill_litter_super(sb); + kill_anon_super(sb); + put_net(net); } static struct file_system_type rpc_pipe_fs_type = { .owner = THIS_MODULE, .name = "rpc_pipefs", - .mount = rpc_mount, + .init_fs_context = rpc_init_fs_context, .kill_sb = rpc_kill_sb, }; MODULE_ALIAS_FS("rpc_pipefs"); @@ -1205,7 +1234,7 @@ int register_rpc_pipefs(void) rpc_inode_cachep = kmem_cache_create("rpc_inode_cache", sizeof(struct rpc_inode), 0, (SLAB_HWCACHE_ALIGN|SLAB_RECLAIM_ACCOUNT| - SLAB_MEM_SPREAD), + SLAB_ACCOUNT), init_once); if (!rpc_inode_cachep) return -ENOMEM; @@ -1227,6 +1256,6 @@ err_notifier: void unregister_rpc_pipefs(void) { rpc_clients_notifier_unregister(); - kmem_cache_destroy(rpc_inode_cachep); unregister_filesystem(&rpc_pipe_fs_type); + kmem_cache_destroy(rpc_inode_cachep); } diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c index 3df764dc330c..53bcca365fb1 100644 --- a/net/sunrpc/rpcb_clnt.c +++ b/net/sunrpc/rpcb_clnt.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * In-kernel rpcbind client supporting versions 2, 3, and 4 of the rpcbind * protocol @@ -30,13 +31,12 @@ #include <linux/sunrpc/sched.h> #include <linux/sunrpc/xprtsock.h> -#include "netns.h" +#include <trace/events/sunrpc.h> -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_BIND -#endif +#include "netns.h" #define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock" +#define RPCBIND_SOCK_ABSTRACT_NAME "\0/run/rpcbind.sock" #define RPCBIND_PROGRAM (100000u) #define RPCBIND_PORT (111u) @@ -128,13 +128,13 @@ struct rpcbind_args { int r_status; }; -static struct rpc_procinfo rpcb_procedures2[]; -static struct rpc_procinfo rpcb_procedures3[]; -static struct rpc_procinfo rpcb_procedures4[]; +static const struct rpc_procinfo rpcb_procedures2[]; +static const struct rpc_procinfo rpcb_procedures3[]; +static const struct rpc_procinfo rpcb_procedures4[]; struct rpcb_info { u32 rpc_vers; - struct rpc_procinfo * rpc_proc; + const struct rpc_procinfo *rpc_proc; }; static const struct rpcb_info rpcb_next_version[]; @@ -204,40 +204,48 @@ void rpcb_put_local(struct net *net) } static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt, - struct rpc_clnt *clnt4) + struct rpc_clnt *clnt4, + bool is_af_local) { struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); /* Protected by rpcb_create_local_mutex */ sn->rpcb_local_clnt = clnt; sn->rpcb_local_clnt4 = clnt4; - smp_wmb(); + sn->rpcb_is_af_local = is_af_local ? 1 : 0; + smp_wmb(); sn->rpcb_users = 1; - dprintk("RPC: created new rpcb local clients (rpcb_local_clnt: " - "%p, rpcb_local_clnt4: %p) for net %p%s\n", - sn->rpcb_local_clnt, sn->rpcb_local_clnt4, - net, (net == &init_net) ? " (init_net)" : ""); } +/* Evaluate to actual length of the `sockaddr_un' structure. */ +# define SUN_LEN(ptr) (offsetof(struct sockaddr_un, sun_path) \ + + 1 + strlen((ptr)->sun_path + 1)) + /* * Returns zero on success, otherwise a negative errno value * is returned. */ -static int rpcb_create_local_unix(struct net *net) +static int rpcb_create_af_local(struct net *net, + const struct sockaddr_un *addr) { - static const struct sockaddr_un rpcb_localaddr_rpcbind = { - .sun_family = AF_LOCAL, - .sun_path = RPCBIND_SOCK_PATHNAME, - }; struct rpc_create_args args = { .net = net, .protocol = XPRT_TRANSPORT_LOCAL, - .address = (struct sockaddr *)&rpcb_localaddr_rpcbind, - .addrsize = sizeof(rpcb_localaddr_rpcbind), + .address = (struct sockaddr *)addr, + .addrsize = SUN_LEN(addr), .servername = "localhost", .program = &rpcb_program, .version = RPCBVERS_2, .authflavor = RPC_AUTH_NULL, + .cred = current_cred(), + /* + * We turn off the idle timeout to prevent the kernel + * from automatically disconnecting the socket. + * Otherwise, we'd have to cache the mount namespace + * of the caller and somehow pass that to the socket + * reconnect code. + */ + .flags = RPC_CLNT_CREATE_NO_IDLE_TIMEOUT, }; struct rpc_clnt *clnt, *clnt4; int result = 0; @@ -249,26 +257,40 @@ static int rpcb_create_local_unix(struct net *net) */ clnt = rpc_create(&args); if (IS_ERR(clnt)) { - dprintk("RPC: failed to create AF_LOCAL rpcbind " - "client (errno %ld).\n", PTR_ERR(clnt)); result = PTR_ERR(clnt); goto out; } clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); - if (IS_ERR(clnt4)) { - dprintk("RPC: failed to bind second program to " - "rpcbind v4 client (errno %ld).\n", - PTR_ERR(clnt4)); + if (IS_ERR(clnt4)) clnt4 = NULL; - } - rpcb_set_local(net, clnt, clnt4); + rpcb_set_local(net, clnt, clnt4, true); out: return result; } +static int rpcb_create_local_abstract(struct net *net) +{ + static const struct sockaddr_un rpcb_localaddr_abstract = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_ABSTRACT_NAME, + }; + + return rpcb_create_af_local(net, &rpcb_localaddr_abstract); +} + +static int rpcb_create_local_unix(struct net *net) +{ + static const struct sockaddr_un rpcb_localaddr_unix = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_PATHNAME, + }; + + return rpcb_create_af_local(net, &rpcb_localaddr_unix); +} + /* * Returns zero on success, otherwise a negative errno value * is returned. @@ -289,6 +311,7 @@ static int rpcb_create_local_net(struct net *net) .program = &rpcb_program, .version = RPCBVERS_2, .authflavor = RPC_AUTH_UNIX, + .cred = current_cred(), .flags = RPC_CLNT_CREATE_NOPING, }; struct rpc_clnt *clnt, *clnt4; @@ -296,8 +319,6 @@ static int rpcb_create_local_net(struct net *net) clnt = rpc_create(&args); if (IS_ERR(clnt)) { - dprintk("RPC: failed to create local rpcbind " - "client (errno %ld).\n", PTR_ERR(clnt)); result = PTR_ERR(clnt); goto out; } @@ -308,14 +329,10 @@ static int rpcb_create_local_net(struct net *net) * v4 upcalls. */ clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); - if (IS_ERR(clnt4)) { - dprintk("RPC: failed to bind second program to " - "rpcbind v4 client (errno %ld).\n", - PTR_ERR(clnt4)); + if (IS_ERR(clnt4)) clnt4 = NULL; - } - rpcb_set_local(net, clnt, clnt4); + rpcb_set_local(net, clnt, clnt4, false); out: return result; @@ -337,7 +354,8 @@ int rpcb_create_local(struct net *net) if (rpcb_get_local(net)) goto out; - if (rpcb_create_local_unix(net) != 0) + if (rpcb_create_local_abstract(net) != 0 && + rpcb_create_local_unix(net) != 0) result = rpcb_create_local_net(net); out: @@ -345,19 +363,25 @@ out: return result; } -static struct rpc_clnt *rpcb_create(struct net *net, const char *hostname, +static struct rpc_clnt *rpcb_create(struct net *net, const char *nodename, + const char *hostname, struct sockaddr *srvaddr, size_t salen, - int proto, u32 version) + int proto, u32 version, + const struct cred *cred, + const struct rpc_timeout *timeo) { struct rpc_create_args args = { .net = net, .protocol = proto, .address = srvaddr, .addrsize = salen, + .timeout = timeo, .servername = hostname, + .nodename = nodename, .program = &rpcb_program, .version = version, .authflavor = RPC_AUTH_UNIX, + .cred = cred, .flags = (RPC_CLNT_CREATE_NOPING | RPC_CLNT_CREATE_NONPRIVPORT), }; @@ -376,18 +400,18 @@ static struct rpc_clnt *rpcb_create(struct net *net, const char *hostname, return rpc_create(&args); } -static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg) +static int rpcb_register_call(struct sunrpc_net *sn, struct rpc_clnt *clnt, struct rpc_message *msg, bool is_set) { - int result, error = 0; + int flags = RPC_TASK_NOCONNECT; + int error, result = 0; + if (is_set || !sn->rpcb_is_af_local) + flags = RPC_TASK_SOFTCONN; msg->rpc_resp = &result; - error = rpc_call_sync(clnt, msg, RPC_TASK_SOFTCONN); - if (error < 0) { - dprintk("RPC: failed to contact local rpcbind " - "server (errno %d).\n", -error); + error = rpc_call_sync(clnt, msg, flags); + if (error < 0) return error; - } if (!result) return -EACCES; @@ -439,16 +463,17 @@ int rpcb_register(struct net *net, u32 prog, u32 vers, int prot, unsigned short .rpc_argp = &map, }; struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + bool is_set = false; - dprintk("RPC: %sregistering (%u, %u, %d, %u) with local " - "rpcbind\n", (port ? "" : "un"), - prog, vers, prot, port); + trace_pmap_register(prog, vers, prot, port); msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET]; - if (port) + if (port != 0) { msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET]; + is_set = true; + } - return rpcb_register_call(sn->rpcb_local_clnt, &msg); + return rpcb_register_call(sn, sn->rpcb_local_clnt, &msg, is_set); } /* @@ -461,20 +486,18 @@ static int rpcb_register_inet4(struct sunrpc_net *sn, const struct sockaddr_in *sin = (const struct sockaddr_in *)sap; struct rpcbind_args *map = msg->rpc_argp; unsigned short port = ntohs(sin->sin_port); + bool is_set = false; int result; map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); - dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " - "local rpcbind\n", (port ? "" : "un"), - map->r_prog, map->r_vers, - map->r_addr, map->r_netid); - msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; - if (port) + if (port != 0) { msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } - result = rpcb_register_call(sn->rpcb_local_clnt4, msg); + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); kfree(map->r_addr); return result; } @@ -489,20 +512,18 @@ static int rpcb_register_inet6(struct sunrpc_net *sn, const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sap; struct rpcbind_args *map = msg->rpc_argp; unsigned short port = ntohs(sin6->sin6_port); + bool is_set = false; int result; map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); - dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " - "local rpcbind\n", (port ? "" : "un"), - map->r_prog, map->r_vers, - map->r_addr, map->r_netid); - msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; - if (port) + if (port != 0) { msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } - result = rpcb_register_call(sn->rpcb_local_clnt4, msg); + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); kfree(map->r_addr); return result; } @@ -512,14 +533,12 @@ static int rpcb_unregister_all_protofamilies(struct sunrpc_net *sn, { struct rpcbind_args *map = msg->rpc_argp; - dprintk("RPC: unregistering [%u, %u, '%s'] with " - "local rpcbind\n", - map->r_prog, map->r_vers, map->r_netid); + trace_rpcb_unregister(map->r_prog, map->r_vers, map->r_netid); map->r_addr = ""; msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; - return rpcb_register_call(sn->rpcb_local_clnt4, msg); + return rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, false); } /** @@ -586,6 +605,8 @@ int rpcb_v4_register(struct net *net, const u32 program, const u32 version, if (address == NULL) return rpcb_unregister_all_protofamilies(sn, &msg); + trace_rpcb_register(map.r_prog, map.r_vers, map.r_addr, map.r_netid); + switch (address->sa_family) { case AF_INET: return rpcb_register_inet4(sn, address, &msg); @@ -596,7 +617,8 @@ int rpcb_v4_register(struct net *net, const u32 program, const u32 version, return -EAFNOSUPPORT; } -static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbind_args *map, struct rpc_procinfo *proc) +static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, + struct rpcbind_args *map, const struct rpc_procinfo *proc) { struct rpc_message msg = { .rpc_proc = proc, @@ -624,10 +646,10 @@ static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbi static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt) { struct rpc_clnt *parent = clnt->cl_parent; - struct rpc_xprt *xprt = rcu_dereference(clnt->cl_xprt); + struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch); while (parent != clnt) { - if (rcu_dereference(parent->cl_xprt) != xprt) + if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps) break; if (clnt->cl_autobind) break; @@ -647,7 +669,7 @@ static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt) void rpcb_getport_async(struct rpc_task *task) { struct rpc_clnt *clnt; - struct rpc_procinfo *proc; + const struct rpc_procinfo *proc; u32 bind_version; struct rpc_xprt *xprt; struct rpc_clnt *rpcb_clnt; @@ -659,23 +681,16 @@ void rpcb_getport_async(struct rpc_task *task) int status; rcu_read_lock(); - do { - clnt = rpcb_find_transport_owner(task->tk_client); - xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); - } while (xprt == NULL); + clnt = rpcb_find_transport_owner(task->tk_client); rcu_read_unlock(); - - dprintk("RPC: %5u %s(%s, %u, %u, %d)\n", - task->tk_pid, __func__, - xprt->servername, clnt->cl_prog, clnt->cl_vers, xprt->prot); + xprt = xprt_get(task->tk_xprt); /* Put self on the wait queue to ensure we get notified if * some other task is already attempting to bind the port */ - rpc_sleep_on(&xprt->binding, task, NULL); + rpc_sleep_on_timeout(&xprt->binding, task, + NULL, jiffies + xprt->bind_timeout); if (xprt_test_and_set_binding(xprt)) { - dprintk("RPC: %5u %s: waiting for another binder\n", - task->tk_pid, __func__); xprt_put(xprt); return; } @@ -683,8 +698,6 @@ void rpcb_getport_async(struct rpc_task *task) /* Someone else may have bound if we slept */ if (xprt_bound(xprt)) { status = 0; - dprintk("RPC: %5u %s: already bound\n", - task->tk_pid, __func__); goto bailout_nofree; } @@ -703,35 +716,30 @@ void rpcb_getport_async(struct rpc_task *task) break; default: status = -EAFNOSUPPORT; - dprintk("RPC: %5u %s: bad address family\n", - task->tk_pid, __func__); goto bailout_nofree; } if (proc == NULL) { xprt->bind_index = 0; status = -EPFNOSUPPORT; - dprintk("RPC: %5u %s: no more getport versions available\n", - task->tk_pid, __func__); goto bailout_nofree; } - dprintk("RPC: %5u %s: trying rpcbind version %u\n", - task->tk_pid, __func__, bind_version); + trace_rpcb_getport(clnt, task, bind_version); - rpcb_clnt = rpcb_create(xprt->xprt_net, xprt->servername, sap, salen, - xprt->prot, bind_version); + rpcb_clnt = rpcb_create(xprt->xprt_net, + clnt->cl_nodename, + xprt->servername, sap, salen, + xprt->prot, bind_version, + clnt->cl_cred, + task->tk_client->cl_timeout); if (IS_ERR(rpcb_clnt)) { status = PTR_ERR(rpcb_clnt); - dprintk("RPC: %5u %s: rpcb_create failed, error %ld\n", - task->tk_pid, __func__, PTR_ERR(rpcb_clnt)); goto bailout_nofree; } - map = kzalloc(sizeof(struct rpcbind_args), GFP_ATOMIC); + map = kzalloc(sizeof(struct rpcbind_args), rpc_task_gfp_mask()); if (!map) { status = -ENOMEM; - dprintk("RPC: %5u %s: no memory available\n", - task->tk_pid, __func__); goto bailout_release_client; } map->r_prog = clnt->cl_prog; @@ -745,7 +753,11 @@ void rpcb_getport_async(struct rpc_task *task) case RPCBVERS_4: case RPCBVERS_3: map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID]; - map->r_addr = rpc_sockaddr2uaddr(sap, GFP_ATOMIC); + map->r_addr = rpc_sockaddr2uaddr(sap, rpc_task_gfp_mask()); + if (!map->r_addr) { + status = -ENOMEM; + goto bailout_free_args; + } map->r_owner = ""; break; case RPCBVERS_2: @@ -759,8 +771,6 @@ void rpcb_getport_async(struct rpc_task *task) rpc_release_client(rpcb_clnt); if (IS_ERR(child)) { /* rpcb_map_release() has freed the arguments */ - dprintk("RPC: %5u %s: rpc_run_task failed\n", - task->tk_pid, __func__); return; } @@ -768,6 +778,8 @@ void rpcb_getport_async(struct rpc_task *task) rpc_put_task(child); return; +bailout_free_args: + kfree(map); bailout_release_client: rpc_release_client(rpcb_clnt); bailout_nofree: @@ -784,34 +796,34 @@ static void rpcb_getport_done(struct rpc_task *child, void *data) { struct rpcbind_args *map = data; struct rpc_xprt *xprt = map->r_xprt; - int status = child->tk_status; + + map->r_status = child->tk_status; /* Garbage reply: retry with a lesser rpcbind version */ - if (status == -EIO) - status = -EPROTONOSUPPORT; + if (map->r_status == -EIO) + map->r_status = -EPROTONOSUPPORT; /* rpcbind server doesn't support this rpcbind protocol version */ - if (status == -EPROTONOSUPPORT) + if (map->r_status == -EPROTONOSUPPORT) xprt->bind_index++; - if (status < 0) { + if (map->r_status < 0) { /* rpcbind server not available on remote host? */ - xprt->ops->set_port(xprt, 0); + map->r_port = 0; + } else if (map->r_port == 0) { /* Requested RPC service wasn't registered on remote host */ - xprt->ops->set_port(xprt, 0); - status = -EACCES; + map->r_status = -EACCES; } else { /* Succeeded */ + map->r_status = 0; + } + + trace_rpcb_setport(child, map->r_status, map->r_port); + if (map->r_port) { xprt->ops->set_port(xprt, map->r_port); xprt_set_bound(xprt); - status = 0; } - - dprintk("RPC: %5u rpcb_getport_done(status %d, port %u)\n", - child->tk_pid, status, map->r_port); - - map->r_status = status; } /* @@ -819,15 +831,11 @@ static void rpcb_getport_done(struct rpc_task *child, void *data) */ static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr, - const struct rpcbind_args *rpcb) + const void *data) { + const struct rpcbind_args *rpcb = data; __be32 *p; - dprintk("RPC: %5u encoding PMAP_%s call (%u, %u, %d, %u)\n", - req->rq_task->tk_pid, - req->rq_task->tk_msg.rpc_proc->p_name, - rpcb->r_prog, rpcb->r_vers, rpcb->r_prot, rpcb->r_port); - p = xdr_reserve_space(xdr, RPCB_mappingargs_sz << 2); *p++ = cpu_to_be32(rpcb->r_prog); *p++ = cpu_to_be32(rpcb->r_vers); @@ -836,8 +844,9 @@ static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr, } static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr, - struct rpcbind_args *rpcb) + void *data) { + struct rpcbind_args *rpcb = data; unsigned long port; __be32 *p; @@ -848,8 +857,6 @@ static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr, return -EIO; port = be32_to_cpup(p); - dprintk("RPC: %5u PMAP_%s result: %lu\n", req->rq_task->tk_pid, - req->rq_task->tk_msg.rpc_proc->p_name, port); if (unlikely(port > USHRT_MAX)) return -EIO; @@ -858,8 +865,9 @@ static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr, } static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr, - unsigned int *boolp) + void *data) { + unsigned int *boolp = data; __be32 *p; p = xdr_inline_decode(xdr, 4); @@ -869,11 +877,6 @@ static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr, *boolp = 0; if (*p != xdr_zero) *boolp = 1; - - dprintk("RPC: %5u RPCB_%s call %s\n", - req->rq_task->tk_pid, - req->rq_task->tk_msg.rpc_proc->p_name, - (*boolp ? "succeeded" : "failed")); return 0; } @@ -893,16 +896,11 @@ static void encode_rpcb_string(struct xdr_stream *xdr, const char *string, } static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, - const struct rpcbind_args *rpcb) + const void *data) { + const struct rpcbind_args *rpcb = data; __be32 *p; - dprintk("RPC: %5u encoding RPCB_%s call (%u, %u, '%s', '%s')\n", - req->rq_task->tk_pid, - req->rq_task->tk_msg.rpc_proc->p_name, - rpcb->r_prog, rpcb->r_vers, - rpcb->r_netid, rpcb->r_addr); - p = xdr_reserve_space(xdr, (RPCB_program_sz + RPCB_version_sz) << 2); *p++ = cpu_to_be32(rpcb->r_prog); *p = cpu_to_be32(rpcb->r_vers); @@ -913,8 +911,9 @@ static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, } static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, - struct rpcbind_args *rpcb) + void *data) { + struct rpcbind_args *rpcb = data; struct sockaddr_storage address; struct sockaddr *sap = (struct sockaddr *)&address; __be32 *p; @@ -931,11 +930,8 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, * If the returned universal address is a null string, * the requested RPC service was not registered. */ - if (len == 0) { - dprintk("RPC: %5u RPCB reply: program not registered\n", - req->rq_task->tk_pid); + if (len == 0) return 0; - } if (unlikely(len > RPCBIND_MAXUADDRLEN)) goto out_fail; @@ -943,8 +939,6 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, p = xdr_inline_decode(xdr, len); if (unlikely(p == NULL)) goto out_fail; - dprintk("RPC: %5u RPCB_%s reply: %s\n", req->rq_task->tk_pid, - req->rq_task->tk_msg.rpc_proc->p_name, (char *)p); if (rpc_uaddr2sockaddr(req->rq_xprt->xprt_net, (char *)p, len, sap, sizeof(address)) == 0) @@ -954,9 +948,6 @@ static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, return 0; out_fail: - dprintk("RPC: %5u malformed RPCB_%s reply\n", - req->rq_task->tk_pid, - req->rq_task->tk_msg.rpc_proc->p_name); return -EIO; } @@ -965,11 +956,11 @@ out_fail: * since the Linux kernel RPC code requires only these. */ -static struct rpc_procinfo rpcb_procedures2[] = { +static const struct rpc_procinfo rpcb_procedures2[] = { [RPCBPROC_SET] = { .p_proc = RPCBPROC_SET, - .p_encode = (kxdreproc_t)rpcb_enc_mapping, - .p_decode = (kxdrdproc_t)rpcb_dec_set, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_set, .p_arglen = RPCB_mappingargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_SET, @@ -978,8 +969,8 @@ static struct rpc_procinfo rpcb_procedures2[] = { }, [RPCBPROC_UNSET] = { .p_proc = RPCBPROC_UNSET, - .p_encode = (kxdreproc_t)rpcb_enc_mapping, - .p_decode = (kxdrdproc_t)rpcb_dec_set, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_set, .p_arglen = RPCB_mappingargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_UNSET, @@ -988,8 +979,8 @@ static struct rpc_procinfo rpcb_procedures2[] = { }, [RPCBPROC_GETPORT] = { .p_proc = RPCBPROC_GETPORT, - .p_encode = (kxdreproc_t)rpcb_enc_mapping, - .p_decode = (kxdrdproc_t)rpcb_dec_getport, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_getport, .p_arglen = RPCB_mappingargs_sz, .p_replen = RPCB_getportres_sz, .p_statidx = RPCBPROC_GETPORT, @@ -998,11 +989,11 @@ static struct rpc_procinfo rpcb_procedures2[] = { }, }; -static struct rpc_procinfo rpcb_procedures3[] = { +static const struct rpc_procinfo rpcb_procedures3[] = { [RPCBPROC_SET] = { .p_proc = RPCBPROC_SET, - .p_encode = (kxdreproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrdproc_t)rpcb_dec_set, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_SET, @@ -1011,8 +1002,8 @@ static struct rpc_procinfo rpcb_procedures3[] = { }, [RPCBPROC_UNSET] = { .p_proc = RPCBPROC_UNSET, - .p_encode = (kxdreproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrdproc_t)rpcb_dec_set, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_UNSET, @@ -1021,8 +1012,8 @@ static struct rpc_procinfo rpcb_procedures3[] = { }, [RPCBPROC_GETADDR] = { .p_proc = RPCBPROC_GETADDR, - .p_encode = (kxdreproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrdproc_t)rpcb_dec_getaddr, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_getaddr, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_getaddrres_sz, .p_statidx = RPCBPROC_GETADDR, @@ -1031,11 +1022,11 @@ static struct rpc_procinfo rpcb_procedures3[] = { }, }; -static struct rpc_procinfo rpcb_procedures4[] = { +static const struct rpc_procinfo rpcb_procedures4[] = { [RPCBPROC_SET] = { .p_proc = RPCBPROC_SET, - .p_encode = (kxdreproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrdproc_t)rpcb_dec_set, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_SET, @@ -1044,8 +1035,8 @@ static struct rpc_procinfo rpcb_procedures4[] = { }, [RPCBPROC_UNSET] = { .p_proc = RPCBPROC_UNSET, - .p_encode = (kxdreproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrdproc_t)rpcb_dec_set, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_setres_sz, .p_statidx = RPCBPROC_UNSET, @@ -1054,8 +1045,8 @@ static struct rpc_procinfo rpcb_procedures4[] = { }, [RPCBPROC_GETADDR] = { .p_proc = RPCBPROC_GETADDR, - .p_encode = (kxdreproc_t)rpcb_enc_getaddr, - .p_decode = (kxdrdproc_t)rpcb_dec_getaddr, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_getaddr, .p_arglen = RPCB_getaddrargs_sz, .p_replen = RPCB_getaddrres_sz, .p_statidx = RPCBPROC_GETADDR, @@ -1088,22 +1079,28 @@ static const struct rpcb_info rpcb_next_version6[] = { }, }; +static unsigned int rpcb_version2_counts[ARRAY_SIZE(rpcb_procedures2)]; static const struct rpc_version rpcb_version2 = { .number = RPCBVERS_2, .nrprocs = ARRAY_SIZE(rpcb_procedures2), - .procs = rpcb_procedures2 + .procs = rpcb_procedures2, + .counts = rpcb_version2_counts, }; +static unsigned int rpcb_version3_counts[ARRAY_SIZE(rpcb_procedures3)]; static const struct rpc_version rpcb_version3 = { .number = RPCBVERS_3, .nrprocs = ARRAY_SIZE(rpcb_procedures3), - .procs = rpcb_procedures3 + .procs = rpcb_procedures3, + .counts = rpcb_version3_counts, }; +static unsigned int rpcb_version4_counts[ARRAY_SIZE(rpcb_procedures4)]; static const struct rpc_version rpcb_version4 = { .number = RPCBVERS_4, .nrprocs = ARRAY_SIZE(rpcb_procedures4), - .procs = rpcb_procedures4 + .procs = rpcb_procedures4, + .counts = rpcb_version4_counts, }; static const struct rpc_version *rpcb_version[] = { diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c index 93a7a4e94d80..016f16ca5779 100644 --- a/net/sunrpc/sched.c +++ b/net/sunrpc/sched.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/sched.c * @@ -19,15 +20,13 @@ #include <linux/spinlock.h> #include <linux/mutex.h> #include <linux/freezer.h> +#include <linux/sched/mm.h> #include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/metrics.h> #include "sunrpc.h" -#ifdef RPC_DEBUG -#define RPCDBG_FACILITY RPCDBG_SCHED -#endif - #define CREATE_TRACE_POINTS #include <trace/events/sunrpc.h> @@ -44,7 +43,7 @@ static mempool_t *rpc_buffer_mempool __read_mostly; static void rpc_async_schedule(struct work_struct *); static void rpc_release_task(struct rpc_task *task); -static void __rpc_queue_timer_fn(unsigned long ptr); +static void __rpc_queue_timer_fn(struct work_struct *); /* * RPC tasks sit here while waiting for conditions to improve. @@ -54,7 +53,38 @@ static struct rpc_wait_queue delay_queue; /* * rpciod-related stuff */ -struct workqueue_struct *rpciod_workqueue; +struct workqueue_struct *rpciod_workqueue __read_mostly; +struct workqueue_struct *xprtiod_workqueue __read_mostly; +EXPORT_SYMBOL_GPL(xprtiod_workqueue); + +gfp_t rpc_task_gfp_mask(void) +{ + if (current->flags & PF_WQ_WORKER) + return GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN; + return GFP_KERNEL; +} +EXPORT_SYMBOL_GPL(rpc_task_gfp_mask); + +bool rpc_task_set_rpc_status(struct rpc_task *task, int rpc_status) +{ + if (cmpxchg(&task->tk_rpc_status, 0, rpc_status) == 0) + return true; + return false; +} + +unsigned long +rpc_task_timeout(const struct rpc_task *task) +{ + unsigned long timeout = READ_ONCE(task->tk_timeout); + + if (timeout != 0) { + unsigned long now = jiffies; + if (time_before(now, timeout)) + return timeout - now; + } + return 0; +} +EXPORT_SYMBOL_GPL(rpc_task_timeout); /* * Disable the timer for a given RPC task. Should be called with @@ -64,118 +94,121 @@ struct workqueue_struct *rpciod_workqueue; static void __rpc_disable_timer(struct rpc_wait_queue *queue, struct rpc_task *task) { - if (task->tk_timeout == 0) + if (list_empty(&task->u.tk_wait.timer_list)) return; - dprintk("RPC: %5u disabling timer\n", task->tk_pid); task->tk_timeout = 0; list_del(&task->u.tk_wait.timer_list); if (list_empty(&queue->timer_list.list)) - del_timer(&queue->timer_list.timer); + cancel_delayed_work(&queue->timer_list.dwork); } static void rpc_set_queue_timer(struct rpc_wait_queue *queue, unsigned long expires) { + unsigned long now = jiffies; queue->timer_list.expires = expires; - mod_timer(&queue->timer_list.timer, expires); + if (time_before_eq(expires, now)) + expires = 0; + else + expires -= now; + mod_delayed_work(rpciod_workqueue, &queue->timer_list.dwork, expires); } /* * Set up a timer for the current task. */ static void -__rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task) +__rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task, + unsigned long timeout) { - if (!task->tk_timeout) - return; - - dprintk("RPC: %5u setting alarm for %lu ms\n", - task->tk_pid, task->tk_timeout * 1000 / HZ); - - task->u.tk_wait.expires = jiffies + task->tk_timeout; - if (list_empty(&queue->timer_list.list) || time_before(task->u.tk_wait.expires, queue->timer_list.expires)) - rpc_set_queue_timer(queue, task->u.tk_wait.expires); + task->tk_timeout = timeout; + if (list_empty(&queue->timer_list.list) || time_before(timeout, queue->timer_list.expires)) + rpc_set_queue_timer(queue, timeout); list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list); } -static void rpc_rotate_queue_owner(struct rpc_wait_queue *queue) -{ - struct list_head *q = &queue->tasks[queue->priority]; - struct rpc_task *task; - - if (!list_empty(q)) { - task = list_first_entry(q, struct rpc_task, u.tk_wait.list); - if (task->tk_owner == queue->owner) - list_move_tail(&task->u.tk_wait.list, q); - } -} - static void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) { if (queue->priority != priority) { - /* Fairness: rotate the list when changing priority */ - rpc_rotate_queue_owner(queue); queue->priority = priority; + queue->nr = 1U << priority; } } -static void rpc_set_waitqueue_owner(struct rpc_wait_queue *queue, pid_t pid) -{ - queue->owner = pid; - queue->nr = RPC_BATCH_COUNT; -} - static void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) { rpc_set_waitqueue_priority(queue, queue->maxpriority); - rpc_set_waitqueue_owner(queue, 0); } /* - * Add new request to a priority queue. + * Add a request to a queue list */ -static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, - struct rpc_task *task, - unsigned char queue_priority) +static void +__rpc_list_enqueue_task(struct list_head *q, struct rpc_task *task) { - struct list_head *q; struct rpc_task *t; - INIT_LIST_HEAD(&task->u.tk_wait.links); - if (unlikely(queue_priority > queue->maxpriority)) - queue_priority = queue->maxpriority; - if (queue_priority > queue->priority) - rpc_set_waitqueue_priority(queue, queue_priority); - q = &queue->tasks[queue_priority]; list_for_each_entry(t, q, u.tk_wait.list) { if (t->tk_owner == task->tk_owner) { - list_add_tail(&task->u.tk_wait.list, &t->u.tk_wait.links); + list_add_tail(&task->u.tk_wait.links, + &t->u.tk_wait.links); + /* Cache the queue head in task->u.tk_wait.list */ + task->u.tk_wait.list.next = q; + task->u.tk_wait.list.prev = NULL; return; } } + INIT_LIST_HEAD(&task->u.tk_wait.links); list_add_tail(&task->u.tk_wait.list, q); } /* + * Remove request from a queue list + */ +static void +__rpc_list_dequeue_task(struct rpc_task *task) +{ + struct list_head *q; + struct rpc_task *t; + + if (task->u.tk_wait.list.prev == NULL) { + list_del(&task->u.tk_wait.links); + return; + } + if (!list_empty(&task->u.tk_wait.links)) { + t = list_first_entry(&task->u.tk_wait.links, + struct rpc_task, + u.tk_wait.links); + /* Assume __rpc_list_enqueue_task() cached the queue head */ + q = t->u.tk_wait.list.next; + list_add_tail(&t->u.tk_wait.list, q); + list_del(&task->u.tk_wait.links); + } + list_del(&task->u.tk_wait.list); +} + +/* + * Add new request to a priority queue. + */ +static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) +{ + if (unlikely(queue_priority > queue->maxpriority)) + queue_priority = queue->maxpriority; + __rpc_list_enqueue_task(&queue->tasks[queue_priority], task); +} + +/* * Add new request to wait queue. - * - * Swapper tasks always get inserted at the head of the queue. - * This should avoid many nasty memory deadlocks and hopefully - * improve overall performance. - * Everyone else gets appended to the queue to ensure proper FIFO behavior. */ static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task, unsigned char queue_priority) { - WARN_ON_ONCE(RPC_IS_QUEUED(task)); - if (RPC_IS_QUEUED(task)) - return; - + INIT_LIST_HEAD(&task->u.tk_wait.timer_list); if (RPC_IS_PRIORITY(queue)) __rpc_add_wait_queue_priority(queue, task, queue_priority); - else if (RPC_IS_SWAPPER(task)) - list_add(&task->u.tk_wait.list, &queue->tasks[0]); else list_add_tail(&task->u.tk_wait.list, &queue->tasks[0]); task->tk_waitqueue = queue; @@ -183,9 +216,6 @@ static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, /* barrier matches the read in rpc_wake_up_task_queue_locked() */ smp_wmb(); rpc_set_queued(task); - - dprintk("RPC: %5u added to queue %p \"%s\"\n", - task->tk_pid, queue, rpc_qname(queue)); } /* @@ -193,13 +223,7 @@ static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, */ static void __rpc_remove_wait_queue_priority(struct rpc_task *task) { - struct rpc_task *t; - - if (!list_empty(&task->u.tk_wait.links)) { - t = list_entry(task->u.tk_wait.links.next, struct rpc_task, u.tk_wait.list); - list_move(&t->u.tk_wait.list, &task->u.tk_wait.list); - list_splice_init(&task->u.tk_wait.links, &t->u.tk_wait.links); - } + __rpc_list_dequeue_task(task); } /* @@ -211,10 +235,9 @@ static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_tas __rpc_disable_timer(queue, task); if (RPC_IS_PRIORITY(queue)) __rpc_remove_wait_queue_priority(task); - list_del(&task->u.tk_wait.list); + else + list_del(&task->u.tk_wait.list); queue->qlen--; - dprintk("RPC: %5u removed from queue %p \"%s\"\n", - task->tk_pid, queue, rpc_qname(queue)); } static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname, unsigned char nr_queues) @@ -227,7 +250,8 @@ static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const c queue->maxpriority = nr_queues - 1; rpc_reset_waitqueue_priority(queue); queue->qlen = 0; - setup_timer(&queue->timer_list.timer, __rpc_queue_timer_fn, (unsigned long)queue); + queue->timer_list.expires = 0; + INIT_DELAYED_WORK(&queue->timer_list.dwork, __rpc_queue_timer_fn); INIT_LIST_HEAD(&queue->timer_list.list); rpc_assign_waitqueue_name(queue, qname); } @@ -246,24 +270,32 @@ EXPORT_SYMBOL_GPL(rpc_init_wait_queue); void rpc_destroy_wait_queue(struct rpc_wait_queue *queue) { - del_timer_sync(&queue->timer_list.timer); + cancel_delayed_work_sync(&queue->timer_list.dwork); } EXPORT_SYMBOL_GPL(rpc_destroy_wait_queue); -static int rpc_wait_bit_killable(void *word) +static int rpc_wait_bit_killable(struct wait_bit_key *key, int mode) { - if (fatal_signal_pending(current)) + schedule(); + if (signal_pending_state(mode, current)) return -ERESTARTSYS; - freezable_schedule_unsafe(); return 0; } -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) || IS_ENABLED(CONFIG_TRACEPOINTS) static void rpc_task_set_debuginfo(struct rpc_task *task) { - static atomic_t rpc_pid; + struct rpc_clnt *clnt = task->tk_client; - task->tk_pid = atomic_inc_return(&rpc_pid); + /* Might be a task carrying a reverse-direction operation */ + if (!clnt) { + static atomic_t rpc_pid; + + task->tk_pid = atomic_inc_return(&rpc_pid); + return; + } + + task->tk_pid = atomic_inc_return(&clnt->cl_pid); } #else static inline void rpc_task_set_debuginfo(struct rpc_task *task) @@ -273,10 +305,9 @@ static inline void rpc_task_set_debuginfo(struct rpc_task *task) static void rpc_set_active(struct rpc_task *task) { - trace_rpc_task_begin(task->tk_client, task, NULL); - rpc_task_set_debuginfo(task); set_bit(RPC_TASK_ACTIVE, &task->tk_runstate); + trace_rpc_task_begin(task, NULL); } /* @@ -291,7 +322,7 @@ static int rpc_complete_task(struct rpc_task *task) unsigned long flags; int ret; - trace_rpc_task_complete(task->tk_client, task, NULL); + trace_rpc_task_complete(task, NULL); spin_lock_irqsave(&wq->lock, flags); clear_bit(RPC_TASK_ACTIVE, &task->tk_runstate); @@ -309,14 +340,12 @@ static int rpc_complete_task(struct rpc_task *task) * to enforce taking of the wq->lock and hence avoid races with * rpc_complete_task(). */ -int __rpc_wait_for_completion_task(struct rpc_task *task, int (*action)(void *)) +int rpc_wait_for_completion_task(struct rpc_task *task) { - if (action == NULL) - action = rpc_wait_bit_killable; return out_of_line_wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE, - action, TASK_KILLABLE); + rpc_wait_bit_killable, TASK_KILLABLE|TASK_FREEZABLE_UNSAFE); } -EXPORT_SYMBOL_GPL(__rpc_wait_for_completion_task); +EXPORT_SYMBOL_GPL(rpc_wait_for_completion_task); /* * Make an RPC task runnable. @@ -329,7 +358,8 @@ EXPORT_SYMBOL_GPL(__rpc_wait_for_completion_task); * lockless RPC_IS_QUEUED() test) before we've had a chance to test * the RPC_TASK_RUNNING flag. */ -static void rpc_make_runnable(struct rpc_task *task) +static void rpc_make_runnable(struct workqueue_struct *wq, + struct rpc_task *task) { bool need_wakeup = !rpc_test_and_set_running(task); @@ -338,9 +368,11 @@ static void rpc_make_runnable(struct rpc_task *task) return; if (RPC_IS_ASYNC(task)) { INIT_WORK(&task->u.tk_work, rpc_async_schedule); - queue_work(rpciod_workqueue, &task->u.tk_work); - } else + queue_work(wq, &task->u.tk_work); + } else { + smp_mb__after_atomic(); wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED); + } } /* @@ -349,100 +381,175 @@ static void rpc_make_runnable(struct rpc_task *task) * NB: An RPC task will only receive interrupt-driven events as long * as it's on a wait queue. */ -static void __rpc_sleep_on_priority(struct rpc_wait_queue *q, +static void __rpc_do_sleep_on_priority(struct rpc_wait_queue *q, struct rpc_task *task, - rpc_action action, unsigned char queue_priority) { - dprintk("RPC: %5u sleep_on(queue \"%s\" time %lu)\n", - task->tk_pid, rpc_qname(q), jiffies); - - trace_rpc_task_sleep(task->tk_client, task, q); + trace_rpc_task_sleep(task, q); __rpc_add_wait_queue(q, task, queue_priority); +} + +static void __rpc_sleep_on_priority(struct rpc_wait_queue *q, + struct rpc_task *task, + unsigned char queue_priority) +{ + if (WARN_ON_ONCE(RPC_IS_QUEUED(task))) + return; + __rpc_do_sleep_on_priority(q, task, queue_priority); +} - WARN_ON_ONCE(task->tk_callback != NULL); - task->tk_callback = action; - __rpc_add_timer(q, task); +static void __rpc_sleep_on_priority_timeout(struct rpc_wait_queue *q, + struct rpc_task *task, unsigned long timeout, + unsigned char queue_priority) +{ + if (WARN_ON_ONCE(RPC_IS_QUEUED(task))) + return; + if (time_is_after_jiffies(timeout)) { + __rpc_do_sleep_on_priority(q, task, queue_priority); + __rpc_add_timer(q, task, timeout); + } else + task->tk_status = -ETIMEDOUT; } -void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, - rpc_action action) +static void rpc_set_tk_callback(struct rpc_task *task, rpc_action action) +{ + if (action && !WARN_ON_ONCE(task->tk_callback != NULL)) + task->tk_callback = action; +} + +static bool rpc_sleep_check_activated(struct rpc_task *task) { /* We shouldn't ever put an inactive task to sleep */ - WARN_ON_ONCE(!RPC_IS_ACTIVATED(task)); - if (!RPC_IS_ACTIVATED(task)) { + if (WARN_ON_ONCE(!RPC_IS_ACTIVATED(task))) { task->tk_status = -EIO; rpc_put_task_async(task); - return; + return false; } + return true; +} +void rpc_sleep_on_timeout(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action, unsigned long timeout) +{ + if (!rpc_sleep_check_activated(task)) + return; + + rpc_set_tk_callback(task, action); + + /* + * Protect the queue operations. + */ + spin_lock(&q->lock); + __rpc_sleep_on_priority_timeout(q, task, timeout, task->tk_priority); + spin_unlock(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_timeout); + +void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action) +{ + if (!rpc_sleep_check_activated(task)) + return; + + rpc_set_tk_callback(task, action); + + WARN_ON_ONCE(task->tk_timeout != 0); /* * Protect the queue operations. */ - spin_lock_bh(&q->lock); - __rpc_sleep_on_priority(q, task, action, task->tk_priority); - spin_unlock_bh(&q->lock); + spin_lock(&q->lock); + __rpc_sleep_on_priority(q, task, task->tk_priority); + spin_unlock(&q->lock); } EXPORT_SYMBOL_GPL(rpc_sleep_on); +void rpc_sleep_on_priority_timeout(struct rpc_wait_queue *q, + struct rpc_task *task, unsigned long timeout, int priority) +{ + if (!rpc_sleep_check_activated(task)) + return; + + priority -= RPC_PRIORITY_LOW; + /* + * Protect the queue operations. + */ + spin_lock(&q->lock); + __rpc_sleep_on_priority_timeout(q, task, timeout, priority); + spin_unlock(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_priority_timeout); + void rpc_sleep_on_priority(struct rpc_wait_queue *q, struct rpc_task *task, - rpc_action action, int priority) + int priority) { - /* We shouldn't ever put an inactive task to sleep */ - WARN_ON_ONCE(!RPC_IS_ACTIVATED(task)); - if (!RPC_IS_ACTIVATED(task)) { - task->tk_status = -EIO; - rpc_put_task_async(task); + if (!rpc_sleep_check_activated(task)) return; - } + WARN_ON_ONCE(task->tk_timeout != 0); + priority -= RPC_PRIORITY_LOW; /* * Protect the queue operations. */ - spin_lock_bh(&q->lock); - __rpc_sleep_on_priority(q, task, action, priority - RPC_PRIORITY_LOW); - spin_unlock_bh(&q->lock); + spin_lock(&q->lock); + __rpc_sleep_on_priority(q, task, priority); + spin_unlock(&q->lock); } EXPORT_SYMBOL_GPL(rpc_sleep_on_priority); /** - * __rpc_do_wake_up_task - wake up a single rpc_task + * __rpc_do_wake_up_task_on_wq - wake up a single rpc_task + * @wq: workqueue on which to run task * @queue: wait queue * @task: task to be woken up * * Caller must hold queue->lock, and have cleared the task queued flag. */ -static void __rpc_do_wake_up_task(struct rpc_wait_queue *queue, struct rpc_task *task) +static void __rpc_do_wake_up_task_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, + struct rpc_task *task) { - dprintk("RPC: %5u __rpc_wake_up_task (now %lu)\n", - task->tk_pid, jiffies); - /* Has the task been executed yet? If not, we cannot wake it up! */ if (!RPC_IS_ACTIVATED(task)) { printk(KERN_ERR "RPC: Inactive task (%p) being woken up!\n", task); return; } - trace_rpc_task_wakeup(task->tk_client, task, queue); + trace_rpc_task_wakeup(task, queue); __rpc_remove_wait_queue(queue, task); - rpc_make_runnable(task); - - dprintk("RPC: __rpc_wake_up_task done\n"); + rpc_make_runnable(wq, task); } /* * Wake up a queued task while the queue lock is being held */ -static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, struct rpc_task *task) +static struct rpc_task * +rpc_wake_up_task_on_wq_queue_action_locked(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, struct rpc_task *task, + bool (*action)(struct rpc_task *, void *), void *data) { if (RPC_IS_QUEUED(task)) { smp_rmb(); - if (task->tk_waitqueue == queue) - __rpc_do_wake_up_task(queue, task); + if (task->tk_waitqueue == queue) { + if (action == NULL || action(task, data)) { + __rpc_do_wake_up_task_on_wq(wq, queue, task); + return task; + } + } } + return NULL; +} + +/* + * Wake up a queued task while the queue lock is being held + */ +static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, + struct rpc_task *task) +{ + rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue, + task, NULL, NULL); } /* @@ -450,12 +557,48 @@ static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, struct r */ void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task) { - spin_lock_bh(&queue->lock); + if (!RPC_IS_QUEUED(task)) + return; + spin_lock(&queue->lock); rpc_wake_up_task_queue_locked(queue, task); - spin_unlock_bh(&queue->lock); + spin_unlock(&queue->lock); } EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task); +static bool rpc_task_action_set_status(struct rpc_task *task, void *status) +{ + task->tk_status = *(int *)status; + return true; +} + +static void +rpc_wake_up_task_queue_set_status_locked(struct rpc_wait_queue *queue, + struct rpc_task *task, int status) +{ + rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue, + task, rpc_task_action_set_status, &status); +} + +/** + * rpc_wake_up_queued_task_set_status - wake up a task and set task->tk_status + * @queue: pointer to rpc_wait_queue + * @task: pointer to rpc_task + * @status: integer error value + * + * If @task is queued on @queue, then it is woken up, and @task->tk_status is + * set to the value of @status. + */ +void +rpc_wake_up_queued_task_set_status(struct rpc_wait_queue *queue, + struct rpc_task *task, int status) +{ + if (!RPC_IS_QUEUED(task)) + return; + spin_lock(&queue->lock); + rpc_wake_up_task_queue_set_status_locked(queue, task, status); + spin_unlock(&queue->lock); +} + /* * Wake up the next task on a priority queue. */ @@ -465,20 +608,22 @@ static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *q struct rpc_task *task; /* + * Service the privileged queue. + */ + q = &queue->tasks[RPC_NR_PRIORITY - 1]; + if (queue->maxpriority > RPC_PRIORITY_PRIVILEGED && !list_empty(q)) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; + } + + /* * Service a batch of tasks from a single owner. */ q = &queue->tasks[queue->priority]; - if (!list_empty(q)) { - task = list_entry(q->next, struct rpc_task, u.tk_wait.list); - if (queue->owner == task->tk_owner) { - if (--queue->nr) - goto out; - list_move_tail(&task->u.tk_wait.list, q); - } - /* - * Check if we need to switch queues. - */ - goto new_owner; + if (!list_empty(q) && queue->nr) { + queue->nr--; + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; } /* @@ -490,7 +635,7 @@ static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *q else q = q - 1; if (!list_empty(q)) { - task = list_entry(q->next, struct rpc_task, u.tk_wait.list); + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); goto new_queue; } } while (q != &queue->tasks[queue->priority]); @@ -500,8 +645,6 @@ static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *q new_queue: rpc_set_waitqueue_priority(queue, (unsigned int)(q - &queue->tasks[0])); -new_owner: - rpc_set_waitqueue_owner(queue, task->tk_owner); out: return task; } @@ -518,25 +661,30 @@ static struct rpc_task *__rpc_find_next_queued(struct rpc_wait_queue *queue) /* * Wake up the first task on the wait queue. */ -struct rpc_task *rpc_wake_up_first(struct rpc_wait_queue *queue, +struct rpc_task *rpc_wake_up_first_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, bool (*func)(struct rpc_task *, void *), void *data) { struct rpc_task *task = NULL; - dprintk("RPC: wake_up_first(%p \"%s\")\n", - queue, rpc_qname(queue)); - spin_lock_bh(&queue->lock); + spin_lock(&queue->lock); task = __rpc_find_next_queued(queue); - if (task != NULL) { - if (func(task, data)) - rpc_wake_up_task_queue_locked(queue, task); - else - task = NULL; - } - spin_unlock_bh(&queue->lock); + if (task != NULL) + task = rpc_wake_up_task_on_wq_queue_action_locked(wq, queue, + task, func, data); + spin_unlock(&queue->lock); return task; } + +/* + * Wake up the first task on the wait queue. + */ +struct rpc_task *rpc_wake_up_first(struct rpc_wait_queue *queue, + bool (*func)(struct rpc_task *, void *), void *data) +{ + return rpc_wake_up_first_on_wq(rpciod_workqueue, queue, func, data); +} EXPORT_SYMBOL_GPL(rpc_wake_up_first); static bool rpc_wake_up_next_func(struct rpc_task *task, void *data) @@ -554,6 +702,23 @@ struct rpc_task *rpc_wake_up_next(struct rpc_wait_queue *queue) EXPORT_SYMBOL_GPL(rpc_wake_up_next); /** + * rpc_wake_up_locked - wake up all rpc_tasks + * @queue: rpc_wait_queue on which the tasks are sleeping + * + */ +static void rpc_wake_up_locked(struct rpc_wait_queue *queue) +{ + struct rpc_task *task; + + for (;;) { + task = __rpc_find_next_queued(queue); + if (task == NULL) + break; + rpc_wake_up_task_queue_locked(queue, task); + } +} + +/** * rpc_wake_up - wake up all rpc_tasks * @queue: rpc_wait_queue on which the tasks are sleeping * @@ -561,25 +726,28 @@ EXPORT_SYMBOL_GPL(rpc_wake_up_next); */ void rpc_wake_up(struct rpc_wait_queue *queue) { - struct list_head *head; + spin_lock(&queue->lock); + rpc_wake_up_locked(queue); + spin_unlock(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up); + +/** + * rpc_wake_up_status_locked - wake up all rpc_tasks and set their status value. + * @queue: rpc_wait_queue on which the tasks are sleeping + * @status: status value to set + */ +static void rpc_wake_up_status_locked(struct rpc_wait_queue *queue, int status) +{ + struct rpc_task *task; - spin_lock_bh(&queue->lock); - head = &queue->tasks[queue->maxpriority]; for (;;) { - while (!list_empty(head)) { - struct rpc_task *task; - task = list_first_entry(head, - struct rpc_task, - u.tk_wait.list); - rpc_wake_up_task_queue_locked(queue, task); - } - if (head == &queue->tasks[0]) + task = __rpc_find_next_queued(queue); + if (task == NULL) break; - head--; + rpc_wake_up_task_queue_set_status_locked(queue, task, status); } - spin_unlock_bh(&queue->lock); } -EXPORT_SYMBOL_GPL(rpc_wake_up); /** * rpc_wake_up_status - wake up all rpc_tasks and set their status value. @@ -590,39 +758,26 @@ EXPORT_SYMBOL_GPL(rpc_wake_up); */ void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) { - struct list_head *head; - - spin_lock_bh(&queue->lock); - head = &queue->tasks[queue->maxpriority]; - for (;;) { - while (!list_empty(head)) { - struct rpc_task *task; - task = list_first_entry(head, - struct rpc_task, - u.tk_wait.list); - task->tk_status = status; - rpc_wake_up_task_queue_locked(queue, task); - } - if (head == &queue->tasks[0]) - break; - head--; - } - spin_unlock_bh(&queue->lock); + spin_lock(&queue->lock); + rpc_wake_up_status_locked(queue, status); + spin_unlock(&queue->lock); } EXPORT_SYMBOL_GPL(rpc_wake_up_status); -static void __rpc_queue_timer_fn(unsigned long ptr) +static void __rpc_queue_timer_fn(struct work_struct *work) { - struct rpc_wait_queue *queue = (struct rpc_wait_queue *)ptr; + struct rpc_wait_queue *queue = container_of(work, + struct rpc_wait_queue, + timer_list.dwork.work); struct rpc_task *task, *n; unsigned long expires, now, timeo; spin_lock(&queue->lock); expires = now = jiffies; list_for_each_entry_safe(task, n, &queue->timer_list.list, u.tk_wait.timer_list) { - timeo = task->u.tk_wait.expires; + timeo = task->tk_timeout; if (time_after_eq(now, timeo)) { - dprintk("RPC: %5u timeout\n", task->tk_pid); + trace_rpc_task_timeout(task, task->tk_action); task->tk_status = -ETIMEDOUT; rpc_wake_up_task_queue_locked(queue, task); continue; @@ -637,7 +792,8 @@ static void __rpc_queue_timer_fn(unsigned long ptr) static void __rpc_atrun(struct rpc_task *task) { - task->tk_status = 0; + if (task->tk_status == -ETIMEDOUT) + task->tk_status = 0; } /* @@ -645,8 +801,7 @@ static void __rpc_atrun(struct rpc_task *task) */ void rpc_delay(struct rpc_task *task, unsigned long delay) { - task->tk_timeout = delay; - rpc_sleep_on(&delay_queue, task, __rpc_atrun); + rpc_sleep_on_timeout(&delay_queue, task, __rpc_atrun, jiffies + delay); } EXPORT_SYMBOL_GPL(rpc_delay); @@ -664,7 +819,6 @@ rpc_init_task_statistics(struct rpc_task *task) /* Initialize retry counters */ task->tk_garb_retry = 2; task->tk_cred_retry = 2; - task->tk_rebind_retry = 2; /* starting timestamp */ task->tk_start = ktime_get(); @@ -674,8 +828,7 @@ static void rpc_reset_task_statistics(struct rpc_task *task) { task->tk_timeouts = 0; - task->tk_flags &= ~(RPC_CALL_MAJORSEEN|RPC_TASK_KILLED|RPC_TASK_SENT); - + task->tk_flags &= ~(RPC_CALL_MAJORSEEN|RPC_TASK_SENT); rpc_init_task_statistics(task); } @@ -684,11 +837,16 @@ rpc_reset_task_statistics(struct rpc_task *task) */ void rpc_exit_task(struct rpc_task *task) { + trace_rpc_task_end(task, task->tk_action); task->tk_action = NULL; + if (task->tk_ops->rpc_count_stats) + task->tk_ops->rpc_count_stats(task, task->tk_calldata); + else if (task->tk_client) + rpc_count_iostats(task, task->tk_client->cl_metrics); if (task->tk_ops->rpc_call_done != NULL) { + trace_rpc_task_call_done(task, task->tk_ops->rpc_call_done); task->tk_ops->rpc_call_done(task, task->tk_calldata); if (task->tk_action != NULL) { - WARN_ON(RPC_ASSASSINATED(task)); /* Always release the RPC slot and buffer memory */ xprt_release(task); rpc_reset_task_statistics(task); @@ -696,12 +854,37 @@ void rpc_exit_task(struct rpc_task *task) } } +void rpc_signal_task(struct rpc_task *task) +{ + struct rpc_wait_queue *queue; + + if (!RPC_IS_ACTIVATED(task)) + return; + + if (!rpc_task_set_rpc_status(task, -ERESTARTSYS)) + return; + trace_rpc_task_signalled(task, task->tk_action); + queue = READ_ONCE(task->tk_waitqueue); + if (queue) + rpc_wake_up_queued_task(queue, task); +} + +void rpc_task_try_cancel(struct rpc_task *task, int error) +{ + struct rpc_wait_queue *queue; + + if (!rpc_task_set_rpc_status(task, error)) + return; + queue = READ_ONCE(task->tk_waitqueue); + if (queue) + rpc_wake_up_queued_task(queue, task); +} + void rpc_exit(struct rpc_task *task, int status) { task->tk_status = status; task->tk_action = rpc_exit_task; - if (RPC_IS_QUEUED(task)) - rpc_wake_up_queued_task(task->tk_waitqueue, task); + rpc_wake_up_queued_task(task->tk_waitqueue, task); } EXPORT_SYMBOL_GPL(rpc_exit); @@ -711,6 +894,15 @@ void rpc_release_calldata(const struct rpc_call_ops *ops, void *calldata) ops->rpc_release(calldata); } +static bool xprt_needs_memalloc(struct rpc_xprt *xprt, struct rpc_task *tk) +{ + if (!xprt) + return false; + if (!atomic_read(&xprt->swapper)) + return false; + return test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == tk; +} + /* * This is the RPC `scheduler' (or rather, the finite state machine). */ @@ -719,9 +911,7 @@ static void __rpc_execute(struct rpc_task *task) struct rpc_wait_queue *queue; int task_is_async = RPC_IS_ASYNC(task); int status = 0; - - dprintk("RPC: %5u __rpc_execute flags=0x%x\n", - task->tk_pid, task->tk_flags); + unsigned long pflags = current->flags; WARN_ON_ONCE(RPC_IS_QUEUED(task)); if (RPC_IS_QUEUED(task)) @@ -731,29 +921,39 @@ static void __rpc_execute(struct rpc_task *task) void (*do_action)(struct rpc_task *); /* - * Execute any pending callback first. + * Perform the next FSM step or a pending callback. + * + * tk_action may be NULL if the task has been killed. */ - do_action = task->tk_callback; - task->tk_callback = NULL; - if (do_action == NULL) { - /* - * Perform the next FSM step. - * tk_action may be NULL if the task has been killed. - * In particular, note that rpc_killall_tasks may - * do this at any time, so beware when dereferencing. - */ - do_action = task->tk_action; - if (do_action == NULL) - break; + do_action = task->tk_action; + /* Tasks with an RPC error status should exit */ + if (do_action && do_action != rpc_exit_task && + (status = READ_ONCE(task->tk_rpc_status)) != 0) { + task->tk_status = status; + do_action = rpc_exit_task; } - trace_rpc_task_run_action(task->tk_client, task, task->tk_action); + /* Callbacks override all actions */ + if (task->tk_callback) { + do_action = task->tk_callback; + task->tk_callback = NULL; + } + if (!do_action) + break; + if (RPC_IS_SWAPPER(task) || + xprt_needs_memalloc(task->tk_xprt, task)) + current->flags |= PF_MEMALLOC; + + trace_rpc_task_run_action(task, do_action); do_action(task); /* * Lockless check for whether task is sleeping or not. */ - if (!RPC_IS_QUEUED(task)) + if (!RPC_IS_QUEUED(task)) { + cond_resched(); continue; + } + /* * The queue->lock protects against races with * rpc_make_runnable(). @@ -764,39 +964,43 @@ static void __rpc_execute(struct rpc_task *task) * rpc_task pointer may still be dereferenced. */ queue = task->tk_waitqueue; - spin_lock_bh(&queue->lock); + spin_lock(&queue->lock); if (!RPC_IS_QUEUED(task)) { - spin_unlock_bh(&queue->lock); + spin_unlock(&queue->lock); + continue; + } + /* Wake up any task that has an exit status */ + if (READ_ONCE(task->tk_rpc_status) != 0) { + rpc_wake_up_task_queue_locked(queue, task); + spin_unlock(&queue->lock); continue; } rpc_clear_running(task); - spin_unlock_bh(&queue->lock); + spin_unlock(&queue->lock); if (task_is_async) - return; + goto out; /* sync task: sleep here */ - dprintk("RPC: %5u sync task going to sleep\n", task->tk_pid); + trace_rpc_task_sync_sleep(task, task->tk_action); status = out_of_line_wait_on_bit(&task->tk_runstate, RPC_TASK_QUEUED, rpc_wait_bit_killable, - TASK_KILLABLE); - if (status == -ERESTARTSYS) { + TASK_KILLABLE|TASK_FREEZABLE); + if (status < 0) { /* * When a sync task receives a signal, it exits with * -ERESTARTSYS. In order to catch any callbacks that * clean up after sleeping on some queue, we don't * break the loop here, but go around once more. */ - dprintk("RPC: %5u got signal\n", task->tk_pid); - task->tk_flags |= RPC_TASK_KILLED; - rpc_exit(task, -ERESTARTSYS); + rpc_signal_task(task); } - dprintk("RPC: %5u sync task resuming\n", task->tk_pid); + trace_rpc_task_sync_wake(task, task->tk_action); } - dprintk("RPC: %5u return %d, status %d\n", task->tk_pid, status, - task->tk_status); /* Release all resources associated with the task */ rpc_release_task(task); +out: + current_restore_flags(pflags, PF_MEMALLOC); } /* @@ -813,83 +1017,83 @@ void rpc_execute(struct rpc_task *task) bool is_async = RPC_IS_ASYNC(task); rpc_set_active(task); - rpc_make_runnable(task); - if (!is_async) + rpc_make_runnable(rpciod_workqueue, task); + if (!is_async) { + unsigned int pflags = memalloc_nofs_save(); __rpc_execute(task); + memalloc_nofs_restore(pflags); + } } static void rpc_async_schedule(struct work_struct *work) { - current->flags |= PF_FSTRANS; + unsigned int pflags = memalloc_nofs_save(); + __rpc_execute(container_of(work, struct rpc_task, u.tk_work)); - current->flags &= ~PF_FSTRANS; + memalloc_nofs_restore(pflags); } /** - * rpc_malloc - allocate an RPC buffer - * @task: RPC task that will use this buffer - * @size: requested byte size + * rpc_malloc - allocate RPC buffer resources + * @task: RPC task + * + * A single memory region is allocated, which is split between the + * RPC call and RPC reply that this task is being used for. When + * this RPC is retired, the memory is released by calling rpc_free. * * To prevent rpciod from hanging, this allocator never sleeps, - * returning NULL if the request cannot be serviced immediately. - * The caller can arrange to sleep in a way that is safe for rpciod. + * returning -ENOMEM and suppressing warning if the request cannot + * be serviced immediately. The caller can arrange to sleep in a + * way that is safe for rpciod. * * Most requests are 'small' (under 2KiB) and can be serviced from a * mempool, ensuring that NFS reads and writes can always proceed, * and that there is good locality of reference for these buffers. - * - * In order to avoid memory starvation triggering more writebacks of - * NFS requests, we avoid using GFP_KERNEL. */ -void *rpc_malloc(struct rpc_task *task, size_t size) +int rpc_malloc(struct rpc_task *task) { + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize + rqst->rq_rcvsize; struct rpc_buffer *buf; - gfp_t gfp = GFP_NOWAIT; - - if (RPC_IS_SWAPPER(task)) - gfp |= __GFP_MEMALLOC; + gfp_t gfp = rpc_task_gfp_mask(); size += sizeof(struct rpc_buffer); - if (size <= RPC_BUFFER_MAXSIZE) - buf = mempool_alloc(rpc_buffer_mempool, gfp); - else + if (size <= RPC_BUFFER_MAXSIZE) { + buf = kmem_cache_alloc(rpc_buffer_slabp, gfp); + /* Reach for the mempool if dynamic allocation fails */ + if (!buf && RPC_IS_ASYNC(task)) + buf = mempool_alloc(rpc_buffer_mempool, GFP_NOWAIT); + } else buf = kmalloc(size, gfp); if (!buf) - return NULL; + return -ENOMEM; buf->len = size; - dprintk("RPC: %5u allocated buffer of size %zu at %p\n", - task->tk_pid, size, buf); - return &buf->data; + rqst->rq_buffer = buf->data; + rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize; + return 0; } -EXPORT_SYMBOL_GPL(rpc_malloc); /** - * rpc_free - free buffer allocated via rpc_malloc - * @buffer: buffer to free + * rpc_free - free RPC buffer resources allocated via rpc_malloc + * @task: RPC task * */ -void rpc_free(void *buffer) +void rpc_free(struct rpc_task *task) { + void *buffer = task->tk_rqstp->rq_buffer; size_t size; struct rpc_buffer *buf; - if (!buffer) - return; - buf = container_of(buffer, struct rpc_buffer, data); size = buf->len; - dprintk("RPC: freeing buffer of size %zu at %p\n", - size, buf); - if (size <= RPC_BUFFER_MAXSIZE) mempool_free(buf, rpc_buffer_mempool); else kfree(buf); } -EXPORT_SYMBOL_GPL(rpc_free); /* * Creation and deletion of RPC task structures @@ -909,19 +1113,25 @@ static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *ta /* Initialize workqueue for async tasks */ task->tk_workqueue = task_setup_data->workqueue; + task->tk_xprt = rpc_task_get_xprt(task_setup_data->rpc_client, + xprt_get(task_setup_data->rpc_xprt)); + + task->tk_op_cred = get_rpccred(task_setup_data->rpc_op_cred); + if (task->tk_ops->rpc_call_prepare != NULL) task->tk_action = rpc_prepare_task; rpc_init_task_statistics(task); - - dprintk("RPC: new task initialized, procpid %u\n", - task_pid_nr(current)); } -static struct rpc_task * -rpc_alloc_task(void) +static struct rpc_task *rpc_alloc_task(void) { - return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOIO); + struct rpc_task *task; + + task = kmem_cache_alloc(rpc_task_slabp, rpc_task_gfp_mask()); + if (task) + return task; + return mempool_alloc(rpc_task_mempool, GFP_NOWAIT); } /* @@ -936,7 +1146,7 @@ struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data) task = rpc_alloc_task(); if (task == NULL) { rpc_release_calldata(setup_data->callback_ops, - setup_data->callback_data); + setup_data->callback_data); return ERR_PTR(-ENOMEM); } flags = RPC_TASK_DYNAMIC; @@ -944,7 +1154,6 @@ struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data) rpc_init_task(task, setup_data); task->tk_flags |= flags; - dprintk("RPC: allocated task %p\n", task); return task; } @@ -971,24 +1180,27 @@ static void rpc_free_task(struct rpc_task *task) { unsigned short tk_flags = task->tk_flags; + put_rpccred(task->tk_op_cred); rpc_release_calldata(task->tk_ops, task->tk_calldata); - if (tk_flags & RPC_TASK_DYNAMIC) { - dprintk("RPC: %5u freeing task\n", task->tk_pid); + if (tk_flags & RPC_TASK_DYNAMIC) mempool_free(task, rpc_task_mempool); - } } static void rpc_async_release(struct work_struct *work) { + unsigned int pflags = memalloc_nofs_save(); + rpc_free_task(container_of(work, struct rpc_task, u.tk_work)); + memalloc_nofs_restore(pflags); } static void rpc_release_resources_task(struct rpc_task *task) { xprt_release(task); if (task->tk_msg.rpc_cred) { - put_rpccred(task->tk_msg.rpc_cred); + if (!(task->tk_flags & RPC_TASK_CRED_NOREF)) + put_cred(task->tk_msg.rpc_cred); task->tk_msg.rpc_cred = NULL; } rpc_task_release_client(task); @@ -1026,8 +1238,6 @@ EXPORT_SYMBOL_GPL(rpc_put_task_async); static void rpc_release_task(struct rpc_task *task) { - dprintk("RPC: %5u release task\n", task->tk_pid); - WARN_ON_ONCE(RPC_IS_QUEUED(task)); rpc_release_resources_task(task); @@ -1068,10 +1278,21 @@ static int rpciod_start(void) /* * Create the rpciod thread and wait for it to start. */ - dprintk("RPC: creating workqueue rpciod\n"); - wq = alloc_workqueue("rpciod", WQ_MEM_RECLAIM, 1); + wq = alloc_workqueue("rpciod", WQ_MEM_RECLAIM | WQ_UNBOUND, 0); + if (!wq) + goto out_failed; rpciod_workqueue = wq; - return rpciod_workqueue != NULL; + wq = alloc_workqueue("xprtiod", WQ_UNBOUND | WQ_MEM_RECLAIM, 0); + if (!wq) + goto free_rpciod; + xprtiod_workqueue = wq; + return 1; +free_rpciod: + wq = rpciod_workqueue; + rpciod_workqueue = NULL; + destroy_workqueue(wq); +out_failed: + return 0; } static void rpciod_stop(void) @@ -1080,25 +1301,23 @@ static void rpciod_stop(void) if (rpciod_workqueue == NULL) return; - dprintk("RPC: destroying workqueue rpciod\n"); wq = rpciod_workqueue; rpciod_workqueue = NULL; destroy_workqueue(wq); + wq = xprtiod_workqueue; + xprtiod_workqueue = NULL; + destroy_workqueue(wq); } void rpc_destroy_mempool(void) { rpciod_stop(); - if (rpc_buffer_mempool) - mempool_destroy(rpc_buffer_mempool); - if (rpc_task_mempool) - mempool_destroy(rpc_task_mempool); - if (rpc_task_slabp) - kmem_cache_destroy(rpc_task_slabp); - if (rpc_buffer_slabp) - kmem_cache_destroy(rpc_buffer_slabp); + mempool_destroy(rpc_buffer_mempool); + mempool_destroy(rpc_task_mempool); + kmem_cache_destroy(rpc_task_slabp); + kmem_cache_destroy(rpc_buffer_slabp); rpc_destroy_wait_queue(&delay_queue); } diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c index 0a648c502fc3..d8d8842c7de5 100644 --- a/net/sunrpc/socklib.c +++ b/net/sunrpc/socklib.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/socklib.c * @@ -13,134 +14,106 @@ #include <linux/types.h> #include <linux/pagemap.h> #include <linux/udp.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/sunrpc/sched.h> #include <linux/sunrpc/xdr.h> #include <linux/export.h> +#include "socklib.h" -/** - * xdr_skb_read_bits - copy some data bits from skb to internal buffer - * @desc: sk_buff copy helper - * @to: copy destination - * @len: number of bytes to copy - * - * Possibly called several times to iterate over an sk_buff and copy - * data out of it. +/* + * Helper structure for copying from an sk_buff. */ -size_t xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) -{ - if (len > desc->count) - len = desc->count; - if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len))) - return 0; - desc->count -= len; - desc->offset += len; - return len; -} -EXPORT_SYMBOL_GPL(xdr_skb_read_bits); +struct xdr_skb_reader { + struct sk_buff *skb; + unsigned int offset; + bool need_checksum; + size_t count; + __wsum csum; +}; /** - * xdr_skb_read_and_csum_bits - copy and checksum from skb to buffer + * xdr_skb_read_bits - copy some data bits from skb to internal buffer * @desc: sk_buff copy helper * @to: copy destination * @len: number of bytes to copy * - * Same as skb_read_bits, but calculate a checksum at the same time. + * Possibly called several times to iterate over an sk_buff and copy data out of + * it. */ -static size_t xdr_skb_read_and_csum_bits(struct xdr_skb_reader *desc, void *to, size_t len) +static size_t +xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) { - unsigned int pos; - __wsum csum2; - - if (len > desc->count) - len = desc->count; - pos = desc->offset; - csum2 = skb_copy_and_csum_bits(desc->skb, pos, to, len, 0); - desc->csum = csum_block_add(desc->csum, csum2, pos); + len = min(len, desc->count); + + if (desc->need_checksum) { + __wsum csum; + + csum = skb_copy_and_csum_bits(desc->skb, desc->offset, to, len); + desc->csum = csum_block_add(desc->csum, csum, desc->offset); + } else { + if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len))) + return 0; + } + desc->count -= len; desc->offset += len; return len; } -/** - * xdr_partial_copy_from_skb - copy data out of an skb - * @xdr: target XDR buffer - * @base: starting offset - * @desc: sk_buff copy helper - * @copy_actor: virtual method for copying data - * - */ -ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor) +static ssize_t +xdr_partial_copy_from_skb(struct xdr_buf *xdr, struct xdr_skb_reader *desc) { - struct page **ppage = xdr->pages; - unsigned int len, pglen = xdr->page_len; - ssize_t copied = 0; - size_t ret; - - len = xdr->head[0].iov_len; - if (base < len) { - len -= base; - ret = copy_actor(desc, (char *)xdr->head[0].iov_base + base, len); - copied += ret; - if (ret != len || !desc->count) - goto out; - base = 0; - } else - base -= len; - - if (unlikely(pglen == 0)) - goto copy_tail; - if (unlikely(base >= pglen)) { - base -= pglen; - goto copy_tail; - } - if (base || xdr->page_base) { - pglen -= base; - base += xdr->page_base; - ppage += base >> PAGE_CACHE_SHIFT; - base &= ~PAGE_CACHE_MASK; - } - do { + struct page **ppage = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + unsigned int poff = xdr->page_base & ~PAGE_MASK; + unsigned int pglen = xdr->page_len; + ssize_t copied = 0; + size_t ret; + + if (xdr->head[0].iov_len == 0) + return 0; + + ret = xdr_skb_read_bits(desc, xdr->head[0].iov_base, + xdr->head[0].iov_len); + if (ret != xdr->head[0].iov_len || !desc->count) + return ret; + copied += ret; + + while (pglen) { + unsigned int len = min(PAGE_SIZE - poff, pglen); char *kaddr; /* ACL likes to be lazy in allocating pages - ACLs * are small by default but can get huge. */ - if (unlikely(*ppage == NULL)) { - *ppage = alloc_page(GFP_ATOMIC); + if ((xdr->flags & XDRBUF_SPARSE_PAGES) && *ppage == NULL) { + *ppage = alloc_page(GFP_NOWAIT); if (unlikely(*ppage == NULL)) { if (copied == 0) - copied = -ENOMEM; - goto out; + return -ENOMEM; + return copied; } } - len = PAGE_CACHE_SIZE; kaddr = kmap_atomic(*ppage); - if (base) { - len -= base; - if (pglen < len) - len = pglen; - ret = copy_actor(desc, kaddr + base, len); - base = 0; - } else { - if (pglen < len) - len = pglen; - ret = copy_actor(desc, kaddr, len); - } + ret = xdr_skb_read_bits(desc, kaddr + poff, len); flush_dcache_page(*ppage); kunmap_atomic(kaddr); + copied += ret; if (ret != len || !desc->count) - goto out; + return copied; ppage++; - } while ((pglen -= len) != 0); -copy_tail: - len = xdr->tail[0].iov_len; - if (base < len) - copied += copy_actor(desc, (char *)xdr->tail[0].iov_base + base, len - base); -out: + pglen -= len; + poff = 0; + } + + if (xdr->tail[0].iov_len) { + copied += xdr_skb_read_bits(desc, xdr->tail[0].iov_base, + xdr->tail[0].iov_len); + } + return copied; } -EXPORT_SYMBOL_GPL(xdr_partial_copy_from_skb); /** * csum_partial_copy_to_xdr - checksum and copy data @@ -152,17 +125,22 @@ EXPORT_SYMBOL_GPL(xdr_partial_copy_from_skb); */ int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) { - struct xdr_skb_reader desc; - - desc.skb = skb; - desc.offset = sizeof(struct udphdr); - desc.count = skb->len - desc.offset; + struct xdr_skb_reader desc = { + .skb = skb, + .count = skb->len - desc.offset, + }; - if (skb_csum_unnecessary(skb)) - goto no_checksum; + if (skb_csum_unnecessary(skb)) { + if (xdr_partial_copy_from_skb(xdr, &desc) < 0) + return -1; + if (desc.count) + return -1; + return 0; + } + desc.need_checksum = true; desc.csum = csum_partial(skb->data, desc.offset, skb->csum); - if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_and_csum_bits) < 0) + if (xdr_partial_copy_from_skb(xdr, &desc) < 0) return -1; if (desc.offset != skb->len) { __wsum csum2; @@ -173,14 +151,128 @@ int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) return -1; if (csum_fold(desc.csum)) return -1; - if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE)) - netdev_rx_csum_fault(skb->dev); - return 0; -no_checksum: - if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) - return -1; - if (desc.count) - return -1; + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(skb->dev, skb); return 0; } -EXPORT_SYMBOL_GPL(csum_partial_copy_to_xdr); + +static inline int xprt_sendmsg(struct socket *sock, struct msghdr *msg, + size_t seek) +{ + if (seek) + iov_iter_advance(&msg->msg_iter, seek); + return sock_sendmsg(sock, msg); +} + +static int xprt_send_kvec(struct socket *sock, struct msghdr *msg, + struct kvec *vec, size_t seek) +{ + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, 1, vec->iov_len); + return xprt_sendmsg(sock, msg, seek); +} + +static int xprt_send_pagedata(struct socket *sock, struct msghdr *msg, + struct xdr_buf *xdr, size_t base) +{ + iov_iter_bvec(&msg->msg_iter, ITER_SOURCE, xdr->bvec, xdr_buf_pagecount(xdr), + xdr->page_len + xdr->page_base); + return xprt_sendmsg(sock, msg, base + xdr->page_base); +} + +/* Common case: + * - stream transport + * - sending from byte 0 of the message + * - the message is wholly contained in @xdr's head iovec + */ +static int xprt_send_rm_and_kvec(struct socket *sock, struct msghdr *msg, + rpc_fraghdr marker, struct kvec *vec, + size_t base) +{ + struct kvec iov[2] = { + [0] = { + .iov_base = &marker, + .iov_len = sizeof(marker) + }, + [1] = *vec, + }; + size_t len = iov[0].iov_len + iov[1].iov_len; + + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, iov, 2, len); + return xprt_sendmsg(sock, msg, base); +} + +/** + * xprt_sock_sendmsg - write an xdr_buf directly to a socket + * @sock: open socket to send on + * @msg: socket message metadata + * @xdr: xdr_buf containing this request + * @base: starting position in the buffer + * @marker: stream record marker field + * @sent_p: return the total number of bytes successfully queued for sending + * + * Return values: + * On success, returns zero and fills in @sent_p. + * %-ENOTSOCK if @sock is not a struct socket. + */ +int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg, + struct xdr_buf *xdr, unsigned int base, + rpc_fraghdr marker, unsigned int *sent_p) +{ + unsigned int rmsize = marker ? sizeof(marker) : 0; + unsigned int remainder = rmsize + xdr->len - base; + unsigned int want; + int err = 0; + + *sent_p = 0; + + if (unlikely(!sock)) + return -ENOTSOCK; + + msg->msg_flags |= MSG_MORE; + want = xdr->head[0].iov_len + rmsize; + if (base < want) { + unsigned int len = want - base; + + remainder -= len; + if (remainder == 0) + msg->msg_flags &= ~MSG_MORE; + if (rmsize) + err = xprt_send_rm_and_kvec(sock, msg, marker, + &xdr->head[0], base); + else + err = xprt_send_kvec(sock, msg, &xdr->head[0], base); + if (remainder == 0 || err != len) + goto out; + *sent_p += err; + base = 0; + } else { + base -= want; + } + + if (base < xdr->page_len) { + unsigned int len = xdr->page_len - base; + + remainder -= len; + if (remainder == 0) + msg->msg_flags &= ~MSG_MORE; + err = xprt_send_pagedata(sock, msg, xdr, base); + if (remainder == 0 || err != len) + goto out; + *sent_p += err; + base = 0; + } else { + base -= xdr->page_len; + } + + if (base >= xdr->tail[0].iov_len) + return 0; + msg->msg_flags &= ~MSG_MORE; + err = xprt_send_kvec(sock, msg, &xdr->tail[0], base); +out: + if (err > 0) { + *sent_p += err; + err = 0; + } + return err; +} diff --git a/net/sunrpc/socklib.h b/net/sunrpc/socklib.h new file mode 100644 index 000000000000..c48114ad6f00 --- /dev/null +++ b/net/sunrpc/socklib.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (C) 1995-1997 Olaf Kirch <okir@monad.swb.de> + * Copyright (C) 2020, Oracle. + */ + +#ifndef _NET_SUNRPC_SOCKLIB_H_ +#define _NET_SUNRPC_SOCKLIB_H_ + +int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb); +int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg, + struct xdr_buf *xdr, unsigned int base, + rpc_fraghdr marker, unsigned int *sent_p); + +#endif /* _NET_SUNRPC_SOCKLIB_H_ */ diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c index 21b75cb08c03..383860cb1d5b 100644 --- a/net/sunrpc/stats.c +++ b/net/sunrpc/stats.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/stats.c * @@ -24,6 +25,8 @@ #include <linux/sunrpc/metrics.h> #include <linux/rcupdate.h> +#include <trace/events/sunrpc.h> + #include "netns.h" #define RPCDBG_FACILITY RPCDBG_MISC @@ -55,8 +58,7 @@ static int rpc_proc_show(struct seq_file *seq, void *v) { seq_printf(seq, "proc%u %u", vers->number, vers->nrprocs); for (j = 0; j < vers->nrprocs; j++) - seq_printf(seq, " %u", - vers->procs[j].p_count); + seq_printf(seq, " %u", vers->counts[j]); seq_putc(seq, '\n'); } return 0; @@ -64,25 +66,25 @@ static int rpc_proc_show(struct seq_file *seq, void *v) { static int rpc_proc_open(struct inode *inode, struct file *file) { - return single_open(file, rpc_proc_show, PDE_DATA(inode)); + return single_open(file, rpc_proc_show, pde_data(inode)); } -static const struct file_operations rpc_proc_fops = { - .owner = THIS_MODULE, - .open = rpc_proc_open, - .read = seq_read, - .llseek = seq_lseek, - .release = single_release, +static const struct proc_ops rpc_proc_ops = { + .proc_open = rpc_proc_open, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_release = single_release, }; /* * Get RPC server stats */ -void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp) { +void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp) +{ const struct svc_program *prog = statp->program; - const struct svc_procedure *proc; const struct svc_version *vers; - unsigned int i, j; + unsigned int i, j, k; + unsigned long count; seq_printf(seq, "net %u %u %u %u\n", @@ -99,11 +101,16 @@ void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp) { statp->rpcbadclnt); for (i = 0; i < prog->pg_nvers; i++) { - if (!(vers = prog->pg_vers[i]) || !(proc = vers->vs_proc)) + vers = prog->pg_vers[i]; + if (!vers) continue; seq_printf(seq, "proc%d %u", i, vers->vs_nproc); - for (j = 0; j < vers->vs_nproc; j++, proc++) - seq_printf(seq, " %u", proc->pc_count); + for (j = 0; j < vers->vs_nproc; j++) { + count = 0; + for_each_possible_cpu(k) + count += per_cpu(vers->vs_count[j], k); + seq_printf(seq, " %lu", count); + } seq_putc(seq, '\n'); } } @@ -116,7 +123,15 @@ EXPORT_SYMBOL_GPL(svc_seq_show); */ struct rpc_iostats *rpc_alloc_iostats(struct rpc_clnt *clnt) { - return kcalloc(clnt->cl_maxproc, sizeof(struct rpc_iostats), GFP_KERNEL); + struct rpc_iostats *stats; + int i; + + stats = kcalloc(clnt->cl_maxproc, sizeof(*stats), GFP_KERNEL); + if (stats) { + for (i = 0; i < clnt->cl_maxproc; i++) + spin_lock_init(&stats[i].om_lock); + } + return stats; } EXPORT_SYMBOL_GPL(rpc_alloc_iostats); @@ -132,42 +147,65 @@ void rpc_free_iostats(struct rpc_iostats *stats) EXPORT_SYMBOL_GPL(rpc_free_iostats); /** - * rpc_count_iostats - tally up per-task stats + * rpc_count_iostats_metrics - tally up per-task stats * @task: completed rpc_task - * @stats: array of stat structures - * - * Relies on the caller for serialization. + * @op_metrics: stat structure for OP that will accumulate stats from @task */ -void rpc_count_iostats(const struct rpc_task *task, struct rpc_iostats *stats) +void rpc_count_iostats_metrics(const struct rpc_task *task, + struct rpc_iostats *op_metrics) { struct rpc_rqst *req = task->tk_rqstp; - struct rpc_iostats *op_metrics; - ktime_t delta; + ktime_t backlog, execute, now; - if (!stats || !req) + if (!op_metrics || !req) return; - op_metrics = &stats[task->tk_msg.rpc_proc->p_statidx]; + now = ktime_get(); + spin_lock(&op_metrics->om_lock); op_metrics->om_ops++; - op_metrics->om_ntrans += req->rq_ntrans; + /* kernel API: om_ops must never become larger than om_ntrans */ + op_metrics->om_ntrans += max(req->rq_ntrans, 1); op_metrics->om_timeouts += task->tk_timeouts; op_metrics->om_bytes_sent += req->rq_xmit_bytes_sent; op_metrics->om_bytes_recv += req->rq_reply_bytes_recvd; - delta = ktime_sub(req->rq_xtime, task->tk_start); - op_metrics->om_queue = ktime_add(op_metrics->om_queue, delta); + backlog = 0; + if (ktime_to_ns(req->rq_xtime)) { + backlog = ktime_sub(req->rq_xtime, task->tk_start); + op_metrics->om_queue = ktime_add(op_metrics->om_queue, backlog); + } op_metrics->om_rtt = ktime_add(op_metrics->om_rtt, req->rq_rtt); - delta = ktime_sub(ktime_get(), task->tk_start); - op_metrics->om_execute = ktime_add(op_metrics->om_execute, delta); + execute = ktime_sub(now, task->tk_start); + op_metrics->om_execute = ktime_add(op_metrics->om_execute, execute); + if (task->tk_status < 0) + op_metrics->om_error_status++; + + spin_unlock(&op_metrics->om_lock); + + trace_rpc_stats_latency(req->rq_task, backlog, req->rq_rtt, execute); +} +EXPORT_SYMBOL_GPL(rpc_count_iostats_metrics); + +/** + * rpc_count_iostats - tally up per-task stats + * @task: completed rpc_task + * @stats: array of stat structures + * + * Uses the statidx from @task + */ +void rpc_count_iostats(const struct rpc_task *task, struct rpc_iostats *stats) +{ + rpc_count_iostats_metrics(task, + &stats[task->tk_msg.rpc_proc->p_statidx]); } EXPORT_SYMBOL_GPL(rpc_count_iostats); static void _print_name(struct seq_file *seq, unsigned int op, - struct rpc_procinfo *procs) + const struct rpc_procinfo *procs) { if (procs[op].p_name) seq_printf(seq, "\t%12s: ", procs[op].p_name); @@ -177,60 +215,89 @@ static void _print_name(struct seq_file *seq, unsigned int op, seq_printf(seq, "\t%12u: ", op); } -void rpc_print_iostats(struct seq_file *seq, struct rpc_clnt *clnt) +static void _add_rpc_iostats(struct rpc_iostats *a, struct rpc_iostats *b) +{ + a->om_ops += b->om_ops; + a->om_ntrans += b->om_ntrans; + a->om_timeouts += b->om_timeouts; + a->om_bytes_sent += b->om_bytes_sent; + a->om_bytes_recv += b->om_bytes_recv; + a->om_queue = ktime_add(a->om_queue, b->om_queue); + a->om_rtt = ktime_add(a->om_rtt, b->om_rtt); + a->om_execute = ktime_add(a->om_execute, b->om_execute); + a->om_error_status += b->om_error_status; +} + +static void _print_rpc_iostats(struct seq_file *seq, struct rpc_iostats *stats, + int op, const struct rpc_procinfo *procs) +{ + _print_name(seq, op, procs); + seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %llu %lu\n", + stats->om_ops, + stats->om_ntrans, + stats->om_timeouts, + stats->om_bytes_sent, + stats->om_bytes_recv, + ktime_to_ms(stats->om_queue), + ktime_to_ms(stats->om_rtt), + ktime_to_ms(stats->om_execute), + stats->om_error_status); +} + +static int do_print_stats(struct rpc_clnt *clnt, struct rpc_xprt *xprt, void *seqv) +{ + struct seq_file *seq = seqv; + + xprt->ops->print_stats(xprt, seq); + return 0; +} + +void rpc_clnt_show_stats(struct seq_file *seq, struct rpc_clnt *clnt) { - struct rpc_iostats *stats = clnt->cl_metrics; - struct rpc_xprt *xprt; unsigned int op, maxproc = clnt->cl_maxproc; - if (!stats) + if (!clnt->cl_metrics) return; seq_printf(seq, "\tRPC iostats version: %s ", RPC_IOSTATS_VERS); seq_printf(seq, "p/v: %u/%u (%s)\n", - clnt->cl_prog, clnt->cl_vers, clnt->cl_protname); + clnt->cl_prog, clnt->cl_vers, clnt->cl_program->name); - rcu_read_lock(); - xprt = rcu_dereference(clnt->cl_xprt); - if (xprt) - xprt->ops->print_stats(xprt, seq); - rcu_read_unlock(); + rpc_clnt_iterate_for_each_xprt(clnt, do_print_stats, seq); seq_printf(seq, "\tper-op statistics\n"); for (op = 0; op < maxproc; op++) { - struct rpc_iostats *metrics = &stats[op]; - _print_name(seq, op, clnt->cl_procinfo); - seq_printf(seq, "%lu %lu %lu %Lu %Lu %Lu %Lu %Lu\n", - metrics->om_ops, - metrics->om_ntrans, - metrics->om_timeouts, - metrics->om_bytes_sent, - metrics->om_bytes_recv, - ktime_to_ms(metrics->om_queue), - ktime_to_ms(metrics->om_rtt), - ktime_to_ms(metrics->om_execute)); + struct rpc_iostats stats = {}; + struct rpc_clnt *next = clnt; + do { + _add_rpc_iostats(&stats, &next->cl_metrics[op]); + if (next == next->cl_parent) + break; + next = next->cl_parent; + } while (next); + _print_rpc_iostats(seq, &stats, op, clnt->cl_procinfo); } } -EXPORT_SYMBOL_GPL(rpc_print_iostats); +EXPORT_SYMBOL_GPL(rpc_clnt_show_stats); /* * Register/unregister RPC proc files */ static inline struct proc_dir_entry * do_register(struct net *net, const char *name, void *data, - const struct file_operations *fops) + const struct proc_ops *proc_ops) { struct sunrpc_net *sn; dprintk("RPC: registering /proc/net/rpc/%s\n", name); sn = net_generic(net, sunrpc_net_id); - return proc_create_data(name, 0, sn->proc_net_rpc, fops, data); + return proc_create_data(name, 0, sn->proc_net_rpc, proc_ops, data); } struct proc_dir_entry * rpc_proc_register(struct net *net, struct rpc_stat *statp) { - return do_register(net, statp->program->name, statp, &rpc_proc_fops); + return do_register(net, statp->program->name, statp, &rpc_proc_ops); } EXPORT_SYMBOL_GPL(rpc_proc_register); @@ -245,9 +312,9 @@ rpc_proc_unregister(struct net *net, const char *name) EXPORT_SYMBOL_GPL(rpc_proc_unregister); struct proc_dir_entry * -svc_proc_register(struct net *net, struct svc_stat *statp, const struct file_operations *fops) +svc_proc_register(struct net *net, struct svc_stat *statp, const struct proc_ops *proc_ops) { - return do_register(net, statp->program->pg_name, statp, fops); + return do_register(net, statp->program->pg_name, net, proc_ops); } EXPORT_SYMBOL_GPL(svc_proc_register); @@ -279,4 +346,3 @@ void rpc_proc_exit(struct net *net) dprintk("RPC: unregistering /proc/net/rpc\n"); remove_proc_entry("rpc", net->proc_net); } - diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h index 14c9f6d1c5ff..e3c6e3b63f0b 100644 --- a/net/sunrpc/sunrpc.h +++ b/net/sunrpc/sunrpc.h @@ -1,22 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ /****************************************************************************** (c) 2008 NetApp. All Rights Reserved. -NetApp provides this source code under the GPL v2 License. -The GPL v2 license is available at -http://opensource.org/licenses/gpl-license.php. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************/ @@ -37,17 +23,24 @@ struct rpc_buffer { char data[]; }; -static inline int rpc_reply_expected(struct rpc_task *task) +static inline int sock_is_loopback(struct sock *sk) { - return (task->tk_msg.rpc_proc != NULL) && - (task->tk_msg.rpc_proc->p_decode != NULL); + struct dst_entry *dst; + int loopback = 0; + rcu_read_lock(); + dst = rcu_dereference(sk->sk_dst_cache); + if (dst && dst->dev && + (dst->dev->features & NETIF_F_LOOPBACK)) + loopback = 1; + rcu_read_unlock(); + return loopback; } -int svc_send_common(struct socket *sock, struct xdr_buf *xdr, - struct page *headpage, unsigned long headoffset, - struct page *tailpage, unsigned long tailoffset); - +struct svc_serv; +struct svc_rqst; int rpc_clients_notifier_register(void); void rpc_clients_notifier_unregister(void); +void auth_domain_cleanup(void); +void svc_sock_update_bufs(struct svc_serv *serv); +enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp); #endif /* _NET_SUNRPC_SUNRPC_H */ - diff --git a/net/sunrpc/sunrpc_syms.c b/net/sunrpc/sunrpc_syms.c index 3d6498af9adc..bab6cab29405 100644 --- a/net/sunrpc/sunrpc_syms.c +++ b/net/sunrpc/sunrpc_syms.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/sunrpc_syms.c * @@ -22,9 +23,11 @@ #include <linux/sunrpc/rpc_pipe_fs.h> #include <linux/sunrpc/xprtsock.h> +#include "sunrpc.h" +#include "sysfs.h" #include "netns.h" -int sunrpc_net_id; +unsigned int sunrpc_net_id; EXPORT_SYMBOL_GPL(sunrpc_net_id); static __net_init int sunrpc_init_net(struct net *net) @@ -44,12 +47,17 @@ static __net_init int sunrpc_init_net(struct net *net) if (err) goto err_unixgid; - rpc_pipefs_init_net(net); + err = rpc_pipefs_init_net(net); + if (err) + goto err_pipefs; + INIT_LIST_HEAD(&sn->all_clients); spin_lock_init(&sn->rpc_client_lock); spin_lock_init(&sn->rpcb_clnt_lock); return 0; +err_pipefs: + unix_gid_cache_destroy(net); err_unixgid: ip_map_cache_destroy(net); err_ipmap: @@ -60,9 +68,13 @@ err_proc: static __net_exit void sunrpc_exit_net(struct net *net) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_pipefs_exit_net(net); unix_gid_cache_destroy(net); ip_map_cache_destroy(net); rpc_proc_exit(net); + WARN_ON_ONCE(!list_empty(&sn->all_clients)); } static struct pernet_operations sunrpc_net_ops = { @@ -91,13 +103,21 @@ init_sunrpc(void) err = register_rpc_pipefs(); if (err) goto out4; -#ifdef RPC_DEBUG + + err = rpc_sysfs_init(); + if (err) + goto out5; + + sunrpc_debugfs_init(); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) rpc_register_sysctl(); #endif svc_init_xprt_sock(); /* svc sock transport */ init_socket_xprt(); /* clnt sock transport */ return 0; +out5: + unregister_rpc_pipefs(); out4: unregister_pernet_subsys(&sunrpc_net_ops); out3: @@ -111,17 +131,24 @@ out: static void __exit cleanup_sunrpc(void) { + rpc_sysfs_exit(); + rpc_cleanup_clids(); + xprt_cleanup_ids(); + xprt_multipath_cleanup_ids(); rpcauth_remove_module(); cleanup_socket_xprt(); svc_cleanup_xprt_sock(); + sunrpc_debugfs_exit(); unregister_rpc_pipefs(); rpc_destroy_mempool(); unregister_pernet_subsys(&sunrpc_net_ops); -#ifdef RPC_DEBUG + auth_domain_cleanup(); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) rpc_unregister_sysctl(); #endif rcu_barrier(); /* Wait for completion of call_rcu()'s */ } +MODULE_DESCRIPTION("Sun RPC core"); MODULE_LICENSE("GPL"); fs_initcall(init_sunrpc); /* Ensure we're initialised before nfs */ module_exit(cleanup_sunrpc); diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c index b974571126fe..4704dce7284e 100644 --- a/net/sunrpc/svc.c +++ b/net/sunrpc/svc.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/svc.c * @@ -11,7 +12,7 @@ */ #include <linux/linkage.h> -#include <linux/sched.h> +#include <linux/sched/signal.h> #include <linux/errno.h> #include <linux/net.h> #include <linux/in.h> @@ -28,11 +29,16 @@ #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/bc_xprt.h> +#include <trace/events/sunrpc.h> + +#include "fail.h" +#include "sunrpc.h" + #define RPCDBG_FACILITY RPCDBG_SVCDSP static void svc_unregister(const struct svc_serv *serv, struct net *net); -#define svc_serv_is_pooled(serv) ((serv)->sv_function) +#define SVC_POOL_DEFAULT SVC_POOL_GLOBAL /* * Mode for mapping cpus to pools. @@ -44,13 +50,13 @@ enum { SVC_POOL_PERCPU, /* one pool per cpu */ SVC_POOL_PERNODE /* one pool per numa node */ }; -#define SVC_POOL_DEFAULT SVC_POOL_GLOBAL /* * Structure for mapping cpus to pools and vice versa. * Setup once during sunrpc initialisation. */ -static struct svc_pool_map { + +struct svc_pool_map { int count; /* How many svc_servs use us */ int mode; /* Note: int not enum to avoid * warnings about "enumeration value @@ -58,64 +64,109 @@ static struct svc_pool_map { unsigned int npools; unsigned int *pool_to; /* maps pool id to cpu or node */ unsigned int *to_pool; /* maps cpu or node to pool id */ -} svc_pool_map = { - .count = 0, +}; + +static struct svc_pool_map svc_pool_map = { .mode = SVC_POOL_DEFAULT }; + static DEFINE_MUTEX(svc_pool_map_mutex);/* protects svc_pool_map.count only */ static int -param_set_pool_mode(const char *val, struct kernel_param *kp) +__param_set_pool_mode(const char *val, struct svc_pool_map *m) { - int *ip = (int *)kp->arg; - struct svc_pool_map *m = &svc_pool_map; - int err; + int err, mode; mutex_lock(&svc_pool_map_mutex); - err = -EBUSY; - if (m->count) - goto out; - err = 0; if (!strncmp(val, "auto", 4)) - *ip = SVC_POOL_AUTO; + mode = SVC_POOL_AUTO; else if (!strncmp(val, "global", 6)) - *ip = SVC_POOL_GLOBAL; + mode = SVC_POOL_GLOBAL; else if (!strncmp(val, "percpu", 6)) - *ip = SVC_POOL_PERCPU; + mode = SVC_POOL_PERCPU; else if (!strncmp(val, "pernode", 7)) - *ip = SVC_POOL_PERNODE; + mode = SVC_POOL_PERNODE; else err = -EINVAL; + if (err) + goto out; + + if (m->count == 0) + m->mode = mode; + else if (mode != m->mode) + err = -EBUSY; out: mutex_unlock(&svc_pool_map_mutex); return err; } static int -param_get_pool_mode(char *buf, struct kernel_param *kp) +param_set_pool_mode(const char *val, const struct kernel_param *kp) +{ + struct svc_pool_map *m = kp->arg; + + return __param_set_pool_mode(val, m); +} + +int sunrpc_set_pool_mode(const char *val) +{ + return __param_set_pool_mode(val, &svc_pool_map); +} +EXPORT_SYMBOL(sunrpc_set_pool_mode); + +/** + * sunrpc_get_pool_mode - get the current pool_mode for the host + * @buf: where to write the current pool_mode + * @size: size of @buf + * + * Grab the current pool_mode from the svc_pool_map and write + * the resulting string to @buf. Returns the number of characters + * written to @buf (a'la snprintf()). + */ +int +sunrpc_get_pool_mode(char *buf, size_t size) { - int *ip = (int *)kp->arg; + struct svc_pool_map *m = &svc_pool_map; - switch (*ip) + switch (m->mode) { case SVC_POOL_AUTO: - return strlcpy(buf, "auto", 20); + return snprintf(buf, size, "auto"); case SVC_POOL_GLOBAL: - return strlcpy(buf, "global", 20); + return snprintf(buf, size, "global"); case SVC_POOL_PERCPU: - return strlcpy(buf, "percpu", 20); + return snprintf(buf, size, "percpu"); case SVC_POOL_PERNODE: - return strlcpy(buf, "pernode", 20); + return snprintf(buf, size, "pernode"); default: - return sprintf(buf, "%d", *ip); + return snprintf(buf, size, "%d", m->mode); } } +EXPORT_SYMBOL(sunrpc_get_pool_mode); + +static int +param_get_pool_mode(char *buf, const struct kernel_param *kp) +{ + char str[16]; + int len; + + len = sunrpc_get_pool_mode(str, ARRAY_SIZE(str)); + + /* Ensure we have room for newline and NUL */ + len = min_t(int, len, ARRAY_SIZE(str) - 2); + + /* tack on the newline */ + str[len] = '\n'; + str[len + 1] = '\0'; + + return sysfs_emit(buf, "%s", str); +} module_param_call(pool_mode, param_set_pool_mode, param_get_pool_mode, - &svc_pool_map.mode, 0644); + &svc_pool_map, 0644); /* * Detect best pool mapping mode heuristically, @@ -189,7 +240,7 @@ svc_pool_map_init_percpu(struct svc_pool_map *m) return err; for_each_online_cpu(cpu) { - BUG_ON(pidx > maxpools); + BUG_ON(pidx >= maxpools); m->to_pool[cpu] = pidx; m->pool_to[pidx] = cpu; pidx++; @@ -231,8 +282,10 @@ svc_pool_map_init_pernode(struct svc_pool_map *m) /* * Add a reference to the global map of cpus to pools (and - * vice versa). Initialise the map if we're the first user. - * Returns the number of pools. + * vice versa) if pools are in use. + * Initialise the map if we're the first user. + * Returns the number of pools. If this is '1', no reference + * was taken. */ static unsigned int svc_pool_map_get(void) @@ -241,7 +294,6 @@ svc_pool_map_get(void) int npools = -1; mutex_lock(&svc_pool_map_mutex); - if (m->count++) { mutex_unlock(&svc_pool_map_mutex); return m->npools; @@ -259,24 +311,20 @@ svc_pool_map_get(void) break; } - if (npools < 0) { + if (npools <= 0) { /* default, or memory allocation failure */ npools = 1; m->mode = SVC_POOL_GLOBAL; } m->npools = npools; - mutex_unlock(&svc_pool_map_mutex); - return m->npools; + return npools; } - /* * Drop a reference to the global map of cpus to pools. * When the last reference is dropped, the map data is - * freed; this allows the sysadmin to change the pool - * mode using the pool_mode module option without - * rebooting or re-loading sunrpc.ko. + * freed; this allows the sysadmin to change the pool. */ static void svc_pool_map_put(void) @@ -284,7 +332,6 @@ svc_pool_map_put(void) struct svc_pool_map *m = &svc_pool_map; mutex_lock(&svc_pool_map_mutex); - if (!--m->count) { kfree(m->to_pool); m->to_pool = NULL; @@ -292,11 +339,9 @@ svc_pool_map_put(void) m->pool_to = NULL; m->npools = 0; } - mutex_unlock(&svc_pool_map_mutex); } - static int svc_pool_map_get_node(unsigned int pidx) { const struct svc_pool_map *m = &svc_pool_map; @@ -307,7 +352,7 @@ static int svc_pool_map_get_node(unsigned int pidx) if (m->mode == SVC_POOL_PERNODE) return m->pool_to[pidx]; } - return NUMA_NO_NODE; + return numa_mem_id(); } /* * Set the given thread's cpus_allowed mask so that it @@ -341,36 +386,39 @@ svc_pool_map_set_cpumask(struct task_struct *task, unsigned int pidx) } } -/* - * Use the mapping mode to choose a pool for a given CPU. - * Used when enqueueing an incoming RPC. Always returns - * a non-NULL pool pointer. +/** + * svc_pool_for_cpu - Select pool to run a thread on this cpu + * @serv: An RPC service + * + * Use the active CPU and the svc_pool_map's mode setting to + * select the svc thread pool to use. Once initialized, the + * svc_pool_map does not change. + * + * Return value: + * A pointer to an svc_pool */ -struct svc_pool * -svc_pool_for_cpu(struct svc_serv *serv, int cpu) +struct svc_pool *svc_pool_for_cpu(struct svc_serv *serv) { struct svc_pool_map *m = &svc_pool_map; + int cpu = raw_smp_processor_id(); unsigned int pidx = 0; - /* - * An uninitialised map happens in a pure client when - * lockd is brought up, so silently treat it the - * same as SVC_POOL_GLOBAL. - */ - if (svc_serv_is_pooled(serv)) { - switch (m->mode) { - case SVC_POOL_PERCPU: - pidx = m->to_pool[cpu]; - break; - case SVC_POOL_PERNODE: - pidx = m->to_pool[cpu_to_node(cpu)]; - break; - } + if (serv->sv_nrpools <= 1) + return serv->sv_pools; + + switch (m->mode) { + case SVC_POOL_PERCPU: + pidx = m->to_pool[cpu]; + break; + case SVC_POOL_PERNODE: + pidx = m->to_pool[cpu_to_node(cpu)]; + break; } + return &serv->sv_pools[pidx % serv->sv_nrpools]; } -int svc_rpcb_setup(struct svc_serv *serv, struct net *net) +static int svc_rpcb_setup(struct svc_serv *serv, struct net *net) { int err; @@ -382,25 +430,24 @@ int svc_rpcb_setup(struct svc_serv *serv, struct net *net) svc_unregister(serv, net); return 0; } -EXPORT_SYMBOL_GPL(svc_rpcb_setup); void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net) { svc_unregister(serv, net); rpcb_put_local(net); } -EXPORT_SYMBOL_GPL(svc_rpcb_cleanup); static int svc_uses_rpcbind(struct svc_serv *serv) { - struct svc_program *progp; - unsigned int i; + unsigned int p, i; + + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; - for (progp = serv->sv_program; progp; progp = progp->pg_next) { for (i = 0; i < progp->pg_nvers; i++) { if (progp->pg_vers[i] == NULL) continue; - if (progp->pg_vers[i]->vs_hidden == 0) + if (!progp->pg_vers[i]->vs_hidden) return 1; } } @@ -416,12 +463,25 @@ int svc_bind(struct svc_serv *serv, struct net *net) } EXPORT_SYMBOL_GPL(svc_bind); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void +__svc_init_bc(struct svc_serv *serv) +{ + lwq_init(&serv->sv_cb_list); +} +#else +static void +__svc_init_bc(struct svc_serv *serv) +{ +} +#endif + /* * Create an RPC service */ static struct svc_serv * -__svc_create(struct svc_program *prog, unsigned int bufsize, int npools, - void (*shutdown)(struct svc_serv *serv, struct net *net)) +__svc_create(struct svc_program *prog, int nprogs, struct svc_stat *stats, + unsigned int bufsize, int npools, int (*threadfn)(void *data)) { struct svc_serv *serv; unsigned int vers; @@ -431,33 +491,36 @@ __svc_create(struct svc_program *prog, unsigned int bufsize, int npools, if (!(serv = kzalloc(sizeof(*serv), GFP_KERNEL))) return NULL; serv->sv_name = prog->pg_name; - serv->sv_program = prog; - serv->sv_nrthreads = 1; - serv->sv_stats = prog->pg_stats; + serv->sv_programs = prog; + serv->sv_nprogs = nprogs; + serv->sv_stats = stats; if (bufsize > RPCSVC_MAXPAYLOAD) bufsize = RPCSVC_MAXPAYLOAD; serv->sv_max_payload = bufsize? bufsize : 4096; serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE); - serv->sv_shutdown = shutdown; + serv->sv_threadfn = threadfn; xdrsize = 0; - while (prog) { - prog->pg_lovers = prog->pg_nvers-1; - for (vers=0; vers<prog->pg_nvers ; vers++) - if (prog->pg_vers[vers]) { - prog->pg_hivers = vers; - if (prog->pg_lovers > vers) - prog->pg_lovers = vers; - if (prog->pg_vers[vers]->vs_xdrsize > xdrsize) - xdrsize = prog->pg_vers[vers]->vs_xdrsize; + for (i = 0; i < nprogs; i++) { + struct svc_program *progp = &prog[i]; + + progp->pg_lovers = progp->pg_nvers-1; + for (vers = 0; vers < progp->pg_nvers ; vers++) + if (progp->pg_vers[vers]) { + progp->pg_hivers = vers; + if (progp->pg_lovers > vers) + progp->pg_lovers = vers; + if (progp->pg_vers[vers]->vs_xdrsize > xdrsize) + xdrsize = progp->pg_vers[vers]->vs_xdrsize; } - prog = prog->pg_next; } serv->sv_xdrsize = xdrsize; INIT_LIST_HEAD(&serv->sv_tempsocks); INIT_LIST_HEAD(&serv->sv_permsocks); - init_timer(&serv->sv_temptimer); + timer_setup(&serv->sv_temptimer, NULL, 0); spin_lock_init(&serv->sv_lock); + __svc_init_bc(serv); + serv->sv_nrpools = npools; serv->sv_pools = kcalloc(serv->sv_nrpools, sizeof(struct svc_pool), @@ -474,120 +537,116 @@ __svc_create(struct svc_program *prog, unsigned int bufsize, int npools, i, serv->sv_name); pool->sp_id = i; - INIT_LIST_HEAD(&pool->sp_threads); - INIT_LIST_HEAD(&pool->sp_sockets); + lwq_init(&pool->sp_xprts); INIT_LIST_HEAD(&pool->sp_all_threads); - spin_lock_init(&pool->sp_lock); - } + init_llist_head(&pool->sp_idle_threads); - if (svc_uses_rpcbind(serv) && (!serv->sv_shutdown)) - serv->sv_shutdown = svc_rpcb_cleanup; + percpu_counter_init(&pool->sp_messages_arrived, 0, GFP_KERNEL); + percpu_counter_init(&pool->sp_sockets_queued, 0, GFP_KERNEL); + percpu_counter_init(&pool->sp_threads_woken, 0, GFP_KERNEL); + } return serv; } -struct svc_serv * -svc_create(struct svc_program *prog, unsigned int bufsize, - void (*shutdown)(struct svc_serv *serv, struct net *net)) +/** + * svc_create - Create an RPC service + * @prog: the RPC program the new service will handle + * @bufsize: maximum message size for @prog + * @threadfn: a function to service RPC requests for @prog + * + * Returns an instantiated struct svc_serv object or NULL. + */ +struct svc_serv *svc_create(struct svc_program *prog, unsigned int bufsize, + int (*threadfn)(void *data)) { - return __svc_create(prog, bufsize, /*npools*/1, shutdown); + return __svc_create(prog, 1, NULL, bufsize, 1, threadfn); } EXPORT_SYMBOL_GPL(svc_create); -struct svc_serv * -svc_create_pooled(struct svc_program *prog, unsigned int bufsize, - void (*shutdown)(struct svc_serv *serv, struct net *net), - svc_thread_fn func, struct module *mod) +/** + * svc_create_pooled - Create an RPC service with pooled threads + * @prog: Array of RPC programs the new service will handle + * @nprogs: Number of programs in the array + * @stats: the stats struct if desired + * @bufsize: maximum message size for @prog + * @threadfn: a function to service RPC requests for @prog + * + * Returns an instantiated struct svc_serv object or NULL. + */ +struct svc_serv *svc_create_pooled(struct svc_program *prog, + unsigned int nprogs, + struct svc_stat *stats, + unsigned int bufsize, + int (*threadfn)(void *data)) { struct svc_serv *serv; unsigned int npools = svc_pool_map_get(); - serv = __svc_create(prog, bufsize, npools, shutdown); - - if (serv != NULL) { - serv->sv_function = func; - serv->sv_module = mod; - } - + serv = __svc_create(prog, nprogs, stats, bufsize, npools, threadfn); + if (!serv) + goto out_err; + serv->sv_is_pooled = true; return serv; +out_err: + svc_pool_map_put(); + return NULL; } EXPORT_SYMBOL_GPL(svc_create_pooled); -void svc_shutdown_net(struct svc_serv *serv, struct net *net) -{ - svc_close_net(serv, net); - - if (serv->sv_shutdown) - serv->sv_shutdown(serv, net); -} -EXPORT_SYMBOL_GPL(svc_shutdown_net); - /* * Destroy an RPC service. Should be called with appropriate locking to - * protect the sv_nrthreads, sv_permsocks and sv_tempsocks. + * protect sv_permsocks and sv_tempsocks. */ void -svc_destroy(struct svc_serv *serv) +svc_destroy(struct svc_serv **servp) { - dprintk("svc: svc_destroy(%s, %d)\n", - serv->sv_program->pg_name, - serv->sv_nrthreads); - - if (serv->sv_nrthreads) { - if (--(serv->sv_nrthreads) != 0) { - svc_sock_update_bufs(serv); - return; - } - } else - printk("svc_destroy: no threads for serv=%p!\n", serv); + struct svc_serv *serv = *servp; + unsigned int i; + + *servp = NULL; - del_timer_sync(&serv->sv_temptimer); + dprintk("svc: svc_destroy(%s)\n", serv->sv_programs->pg_name); + timer_shutdown_sync(&serv->sv_temptimer); /* - * The last user is gone and thus all sockets have to be destroyed to - * the point. Check this. + * Remaining transports at this point are not expected. */ - BUG_ON(!list_empty(&serv->sv_permsocks)); - BUG_ON(!list_empty(&serv->sv_tempsocks)); + WARN_ONCE(!list_empty(&serv->sv_permsocks), + "SVC: permsocks remain for %s\n", serv->sv_programs->pg_name); + WARN_ONCE(!list_empty(&serv->sv_tempsocks), + "SVC: tempsocks remain for %s\n", serv->sv_programs->pg_name); cache_clean_deferred(serv); - if (svc_serv_is_pooled(serv)) + if (serv->sv_is_pooled) svc_pool_map_put(); + for (i = 0; i < serv->sv_nrpools; i++) { + struct svc_pool *pool = &serv->sv_pools[i]; + + percpu_counter_destroy(&pool->sp_messages_arrived); + percpu_counter_destroy(&pool->sp_sockets_queued); + percpu_counter_destroy(&pool->sp_threads_woken); + } kfree(serv->sv_pools); kfree(serv); } EXPORT_SYMBOL_GPL(svc_destroy); -/* - * Allocate an RPC server's buffer space. - * We allocate pages and place them in rq_argpages. - */ -static int -svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) +static bool +svc_init_buffer(struct svc_rqst *rqstp, const struct svc_serv *serv, int node) { - unsigned int pages, arghi; - - /* bc_xprt uses fore channel allocated buffers */ - if (svc_is_backchannel(rqstp)) - return 1; - - pages = size / PAGE_SIZE + 1; /* extra page as we hold both request and reply. - * We assume one is at most one page - */ - arghi = 0; - WARN_ON_ONCE(pages > RPCSVC_MAXPAGES); - if (pages > RPCSVC_MAXPAGES) - pages = RPCSVC_MAXPAGES; - while (pages) { - struct page *p = alloc_pages_node(node, GFP_KERNEL, 0); - if (!p) - break; - rqstp->rq_pages[arghi++] = p; - pages--; - } - return pages == 0; + rqstp->rq_maxpages = svc_serv_maxpages(serv); + + /* rq_pages' last entry is NULL for historical reasons. */ + rqstp->rq_pages = kcalloc_node(rqstp->rq_maxpages + 1, + sizeof(struct page *), + GFP_KERNEL, node); + if (!rqstp->rq_pages) + return false; + + return true; } /* @@ -596,156 +655,168 @@ svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) static void svc_release_buffer(struct svc_rqst *rqstp) { - unsigned int i; + unsigned long i; - for (i = 0; i < ARRAY_SIZE(rqstp->rq_pages); i++) + for (i = 0; i < rqstp->rq_maxpages; i++) if (rqstp->rq_pages[i]) put_page(rqstp->rq_pages[i]); + kfree(rqstp->rq_pages); +} + +static void +svc_rqst_free(struct svc_rqst *rqstp) +{ + folio_batch_release(&rqstp->rq_fbatch); + kfree(rqstp->rq_bvec); + svc_release_buffer(rqstp); + if (rqstp->rq_scratch_folio) + folio_put(rqstp->rq_scratch_folio); + kfree(rqstp->rq_resp); + kfree(rqstp->rq_argp); + kfree(rqstp->rq_auth_data); + kfree_rcu(rqstp, rq_rcu_head); } -struct svc_rqst * +static struct svc_rqst * svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) { struct svc_rqst *rqstp; rqstp = kzalloc_node(sizeof(*rqstp), GFP_KERNEL, node); if (!rqstp) - goto out_enomem; + return rqstp; - init_waitqueue_head(&rqstp->rq_wait); + folio_batch_init(&rqstp->rq_fbatch); - serv->sv_nrthreads++; - spin_lock_bh(&pool->sp_lock); - pool->sp_nrthreads++; - list_add(&rqstp->rq_all, &pool->sp_all_threads); - spin_unlock_bh(&pool->sp_lock); rqstp->rq_server = serv; rqstp->rq_pool = pool; + rqstp->rq_scratch_folio = __folio_alloc_node(GFP_KERNEL, 0, node); + if (!rqstp->rq_scratch_folio) + goto out_enomem; + rqstp->rq_argp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); if (!rqstp->rq_argp) - goto out_thread; + goto out_enomem; rqstp->rq_resp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); if (!rqstp->rq_resp) - goto out_thread; + goto out_enomem; + + if (!svc_init_buffer(rqstp, serv, node)) + goto out_enomem; + + rqstp->rq_bvec = kcalloc_node(rqstp->rq_maxpages, + sizeof(struct bio_vec), + GFP_KERNEL, node); + if (!rqstp->rq_bvec) + goto out_enomem; + + rqstp->rq_err = -EAGAIN; /* No error yet */ + + serv->sv_nrthreads += 1; + pool->sp_nrthreads += 1; - if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node)) - goto out_thread; + /* Protected by whatever lock the service uses when calling + * svc_set_num_threads() + */ + list_add_rcu(&rqstp->rq_all, &pool->sp_all_threads); return rqstp; -out_thread: - svc_exit_thread(rqstp); + out_enomem: - return ERR_PTR(-ENOMEM); + svc_rqst_free(rqstp); + return NULL; } -EXPORT_SYMBOL_GPL(svc_prepare_thread); -/* - * Choose a pool in which to create a new thread, for svc_set_num_threads +/** + * svc_pool_wake_idle_thread - Awaken an idle thread in @pool + * @pool: service thread pool + * + * Can be called from soft IRQ or process context. Finding an idle + * service thread and marking it BUSY is atomic with respect to + * other calls to svc_pool_wake_idle_thread(). + * */ -static inline struct svc_pool * -choose_pool(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +void svc_pool_wake_idle_thread(struct svc_pool *pool) { - if (pool != NULL) - return pool; + struct svc_rqst *rqstp; + struct llist_node *ln; + + rcu_read_lock(); + ln = READ_ONCE(pool->sp_idle_threads.first); + if (ln) { + rqstp = llist_entry(ln, struct svc_rqst, rq_idle); + WRITE_ONCE(rqstp->rq_qtime, ktime_get()); + if (!task_is_running(rqstp->rq_task)) { + wake_up_process(rqstp->rq_task); + trace_svc_pool_thread_wake(pool, rqstp->rq_task->pid); + percpu_counter_inc(&pool->sp_threads_woken); + } else { + trace_svc_pool_thread_running(pool, rqstp->rq_task->pid); + } + rcu_read_unlock(); + return; + } + rcu_read_unlock(); + trace_svc_pool_thread_noidle(pool, 0); +} +EXPORT_SYMBOL_GPL(svc_pool_wake_idle_thread); - return &serv->sv_pools[(*state)++ % serv->sv_nrpools]; +static struct svc_pool * +svc_pool_next(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + return pool ? pool : &serv->sv_pools[(*state)++ % serv->sv_nrpools]; } -/* - * Choose a thread to kill, for svc_set_num_threads - */ -static inline struct task_struct * -choose_victim(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +static struct svc_pool * +svc_pool_victim(struct svc_serv *serv, struct svc_pool *target_pool, + unsigned int *state) { + struct svc_pool *pool; unsigned int i; - struct task_struct *task = NULL; - if (pool != NULL) { - spin_lock_bh(&pool->sp_lock); - } else { - /* choose a pool in round-robin fashion */ + pool = target_pool; + + if (!pool) { for (i = 0; i < serv->sv_nrpools; i++) { pool = &serv->sv_pools[--(*state) % serv->sv_nrpools]; - spin_lock_bh(&pool->sp_lock); - if (!list_empty(&pool->sp_all_threads)) - goto found_pool; - spin_unlock_bh(&pool->sp_lock); + if (pool->sp_nrthreads) + break; } - return NULL; } -found_pool: - if (!list_empty(&pool->sp_all_threads)) { - struct svc_rqst *rqstp; - - /* - * Remove from the pool->sp_all_threads list - * so we don't try to kill it again. - */ - rqstp = list_entry(pool->sp_all_threads.next, struct svc_rqst, rq_all); - list_del_init(&rqstp->rq_all); - task = rqstp->rq_task; + if (pool && pool->sp_nrthreads) { + set_bit(SP_VICTIM_REMAINS, &pool->sp_flags); + set_bit(SP_NEED_VICTIM, &pool->sp_flags); + return pool; } - spin_unlock_bh(&pool->sp_lock); - - return task; + return NULL; } -/* - * Create or destroy enough new threads to make the number - * of threads the given number. If `pool' is non-NULL, applies - * only to threads in that pool, otherwise round-robins between - * all pools. Caller must ensure that mutual exclusion between this and - * server startup or shutdown. - * - * Destroying threads relies on the service threads filling in - * rqstp->rq_task, which only the nfs ones do. Assumes the serv - * has been created using svc_create_pooled(). - * - * Based on code that used to be in nfsd_svc() but tweaked - * to be pool-aware. - */ -int -svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +static int +svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) { struct svc_rqst *rqstp; struct task_struct *task; struct svc_pool *chosen_pool; - int error = 0; unsigned int state = serv->sv_nrthreads-1; int node; + int err; - if (pool == NULL) { - /* The -1 assumes caller has done a svc_get() */ - nrservs -= (serv->sv_nrthreads-1); - } else { - spin_lock_bh(&pool->sp_lock); - nrservs -= pool->sp_nrthreads; - spin_unlock_bh(&pool->sp_lock); - } - - /* create new threads */ - while (nrservs > 0) { + do { nrservs--; - chosen_pool = choose_pool(serv, pool, &state); - + chosen_pool = svc_pool_next(serv, pool, &state); node = svc_pool_map_get_node(chosen_pool->sp_id); - rqstp = svc_prepare_thread(serv, chosen_pool, node); - if (IS_ERR(rqstp)) { - error = PTR_ERR(rqstp); - break; - } - __module_get(serv->sv_module); - task = kthread_create_on_node(serv->sv_function, rqstp, + rqstp = svc_prepare_thread(serv, chosen_pool, node); + if (!rqstp) + return -ENOMEM; + task = kthread_create_on_node(serv->sv_threadfn, rqstp, node, "%s", serv->sv_name); if (IS_ERR(task)) { - error = PTR_ERR(task); - module_put(serv->sv_module); svc_exit_thread(rqstp); - break; + return PTR_ERR(task); } rqstp->rq_task = task; @@ -754,21 +825,134 @@ svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) svc_sock_update_bufs(serv); wake_up_process(task); - } - /* destroy old threads */ - while (nrservs < 0 && - (task = choose_victim(serv, pool, &state)) != NULL) { - send_sig(SIGINT, task, 1); + + wait_var_event(&rqstp->rq_err, rqstp->rq_err != -EAGAIN); + err = rqstp->rq_err; + if (err) { + svc_exit_thread(rqstp); + return err; + } + } while (nrservs > 0); + + return 0; +} + +static int +svc_stop_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + unsigned int state = serv->sv_nrthreads-1; + struct svc_pool *victim; + + do { + victim = svc_pool_victim(serv, pool, &state); + if (!victim) + break; + svc_pool_wake_idle_thread(victim); + wait_on_bit(&victim->sp_flags, SP_VICTIM_REMAINS, + TASK_IDLE); nrservs++; - } + } while (nrservs < 0); + return 0; +} - return error; +/** + * svc_set_num_threads - adjust number of threads per RPC service + * @serv: RPC service to adjust + * @pool: Specific pool from which to choose threads, or NULL + * @nrservs: New number of threads for @serv (0 or less means kill all threads) + * + * Create or destroy threads to make the number of threads for @serv the + * given number. If @pool is non-NULL, change only threads in that pool; + * otherwise, round-robin between all pools for @serv. @serv's + * sv_nrthreads is adjusted for each thread created or destroyed. + * + * Caller must ensure mutual exclusion between this and server startup or + * shutdown. + * + * Returns zero on success or a negative errno if an error occurred while + * starting a thread. + */ +int +svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + if (!pool) + nrservs -= serv->sv_nrthreads; + else + nrservs -= pool->sp_nrthreads; + + if (nrservs > 0) + return svc_start_kthreads(serv, pool, nrservs); + if (nrservs < 0) + return svc_stop_kthreads(serv, pool, nrservs); + return 0; } EXPORT_SYMBOL_GPL(svc_set_num_threads); -/* - * Called from a server thread as it's exiting. Caller must hold the BKL or - * the "service mutex", whichever is appropriate for the service. +/** + * svc_rqst_replace_page - Replace one page in rq_pages[] + * @rqstp: svc_rqst with pages to replace + * @page: replacement page + * + * When replacing a page in rq_pages, batch the release of the + * replaced pages to avoid hammering the page allocator. + * + * Return values: + * %true: page replaced + * %false: array bounds checking failed + */ +bool svc_rqst_replace_page(struct svc_rqst *rqstp, struct page *page) +{ + struct page **begin = rqstp->rq_pages; + struct page **end = &rqstp->rq_pages[rqstp->rq_maxpages]; + + if (unlikely(rqstp->rq_next_page < begin || rqstp->rq_next_page > end)) { + trace_svc_replace_page_err(rqstp); + return false; + } + + if (*rqstp->rq_next_page) { + if (!folio_batch_add(&rqstp->rq_fbatch, + page_folio(*rqstp->rq_next_page))) + __folio_batch_release(&rqstp->rq_fbatch); + } + + get_page(page); + *(rqstp->rq_next_page++) = page; + return true; +} +EXPORT_SYMBOL_GPL(svc_rqst_replace_page); + +/** + * svc_rqst_release_pages - Release Reply buffer pages + * @rqstp: RPC transaction context + * + * Release response pages that might still be in flight after + * svc_send, and any spliced filesystem-owned pages. + */ +void svc_rqst_release_pages(struct svc_rqst *rqstp) +{ + int i, count = rqstp->rq_next_page - rqstp->rq_respages; + + if (count) { + release_pages(rqstp->rq_respages, count); + for (i = 0; i < count; i++) + rqstp->rq_respages[i] = NULL; + } +} + +/** + * svc_exit_thread - finalise the termination of a sunrpc server thread + * @rqstp: the svc_rqst which represents the thread. + * + * When a thread started with svc_new_thread() exits it must call + * svc_exit_thread() as its last act. This must be done with the + * service mutex held. Normally this is held by a DIFFERENT thread, the + * one that is calling svc_set_num_threads() and which will wait for + * SP_VICTIM_REMAINS to be cleared before dropping the mutex. If the + * thread exits for any reason other than svc_thread_should_stop() + * returning %true (which indicated that svc_set_num_threads() is + * waiting for it to exit), then it must take the service mutex itself, + * which can only safely be done using mutex_try_lock(). */ void svc_exit_thread(struct svc_rqst *rqstp) @@ -776,21 +960,15 @@ svc_exit_thread(struct svc_rqst *rqstp) struct svc_serv *serv = rqstp->rq_server; struct svc_pool *pool = rqstp->rq_pool; - svc_release_buffer(rqstp); - kfree(rqstp->rq_resp); - kfree(rqstp->rq_argp); - kfree(rqstp->rq_auth_data); + list_del_rcu(&rqstp->rq_all); - spin_lock_bh(&pool->sp_lock); - pool->sp_nrthreads--; - list_del(&rqstp->rq_all); - spin_unlock_bh(&pool->sp_lock); + pool->sp_nrthreads -= 1; + serv->sv_nrthreads -= 1; + svc_sock_update_bufs(serv); - kfree(rqstp); + svc_rqst_free(rqstp); - /* Release the server */ - if (serv) - svc_destroy(serv); + clear_and_wake_up_bit(SP_VICTIM_REMAINS, &pool->sp_flags); } EXPORT_SYMBOL_GPL(svc_exit_thread); @@ -916,12 +1094,54 @@ static int __svc_register(struct net *net, const char *progname, #endif } - if (error < 0) - printk(KERN_WARNING "svc: failed to register %sv%u RPC " - "service (errno %d).\n", progname, version, -error); + trace_svc_register(progname, version, family, protocol, port, error); return error; } +static +int svc_rpcbind_set_version(struct net *net, + const struct svc_program *progp, + u32 version, int family, + unsigned short proto, + unsigned short port) +{ + return __svc_register(net, progp->pg_name, progp->pg_prog, + version, family, proto, port); + +} + +int svc_generic_rpcbind_set(struct net *net, + const struct svc_program *progp, + u32 version, int family, + unsigned short proto, + unsigned short port) +{ + const struct svc_version *vers = progp->pg_vers[version]; + int error; + + if (vers == NULL) + return 0; + + if (vers->vs_hidden) { + trace_svc_noregister(progp->pg_name, version, proto, + port, family, 0); + return 0; + } + + /* + * Don't register a UDP port if we need congestion + * control. + */ + if (vers->vs_need_cong_ctrl && proto == IPPROTO_UDP) + return 0; + + error = svc_rpcbind_set_version(net, progp, version, + family, proto, port); + + return (vers->vs_rpcb_optnl) ? 0 : error; +} +EXPORT_SYMBOL_GPL(svc_generic_rpcbind_set); + /** * svc_register - register an RPC service with the local portmapper * @serv: svc_serv struct for the service to register @@ -936,35 +1156,26 @@ int svc_register(const struct svc_serv *serv, struct net *net, const int family, const unsigned short proto, const unsigned short port) { - struct svc_program *progp; - unsigned int i; + unsigned int p, i; int error = 0; WARN_ON_ONCE(proto == 0 && port == 0); if (proto == 0 && port == 0) return -EINVAL; - for (progp = serv->sv_program; progp; progp = progp->pg_next) { - for (i = 0; i < progp->pg_nvers; i++) { - if (progp->pg_vers[i] == NULL) - continue; - - dprintk("svc: svc_register(%sv%d, %s, %u, %u)%s\n", - progp->pg_name, - i, - proto == IPPROTO_UDP? "udp" : "tcp", - port, - family, - progp->pg_vers[i]->vs_hidden? - " (but not telling portmap)" : ""); + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; - if (progp->pg_vers[i]->vs_hidden) - continue; + for (i = 0; i < progp->pg_nvers; i++) { - error = __svc_register(net, progp->pg_name, progp->pg_prog, - i, family, proto, port); - if (error < 0) + error = progp->pg_rpcbind_set(net, progp, i, + family, proto, port); + if (error < 0) { + printk(KERN_WARNING "svc: failed to register " + "%sv%u RPC service (errno %d).\n", + progp->pg_name, i, -error); break; + } } } @@ -992,8 +1203,7 @@ static void __svc_unregister(struct net *net, const u32 program, const u32 versi if (error == -EPROTONOSUPPORT) error = rpcb_register(net, program, version, 0, 0); - dprintk("svc: %s(%sv%u), error %d\n", - __func__, progname, version, error); + trace_svc_unregister(progname, version, error); } /* @@ -1006,34 +1216,36 @@ static void __svc_unregister(struct net *net, const u32 program, const u32 versi */ static void svc_unregister(const struct svc_serv *serv, struct net *net) { - struct svc_program *progp; + struct sighand_struct *sighand; unsigned long flags; - unsigned int i; + unsigned int p, i; clear_thread_flag(TIF_SIGPENDING); - for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; + for (i = 0; i < progp->pg_nvers; i++) { if (progp->pg_vers[i] == NULL) continue; if (progp->pg_vers[i]->vs_hidden) continue; - - dprintk("svc: attempting to unregister %sv%u\n", - progp->pg_name, i); __svc_unregister(net, progp->pg_prog, i, progp->pg_name); } } - spin_lock_irqsave(¤t->sighand->siglock, flags); + rcu_read_lock(); + sighand = rcu_dereference(current->sighand); + spin_lock_irqsave(&sighand->siglock, flags); recalc_sigpending(); - spin_unlock_irqrestore(¤t->sighand->siglock, flags); + spin_unlock_irqrestore(&sighand->siglock, flags); + rcu_read_unlock(); } /* * dprintk the given error with the address of the client that caused it. */ -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) static __printf(2, 3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) { @@ -1054,117 +1266,156 @@ void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) static __printf(2,3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) {} #endif +__be32 +svc_generic_init_request(struct svc_rqst *rqstp, + const struct svc_program *progp, + struct svc_process_info *ret) +{ + const struct svc_version *versp = NULL; /* compiler food */ + const struct svc_procedure *procp = NULL; + + if (rqstp->rq_vers >= progp->pg_nvers ) + goto err_bad_vers; + versp = progp->pg_vers[rqstp->rq_vers]; + if (!versp) + goto err_bad_vers; + + /* + * Some protocol versions (namely NFSv4) require some form of + * congestion control. (See RFC 7530 section 3.1 paragraph 2) + * In other words, UDP is not allowed. We mark those when setting + * up the svc_xprt, and verify that here. + * + * The spec is not very clear about what error should be returned + * when someone tries to access a server that is listening on UDP + * for lower versions. RPC_PROG_MISMATCH seems to be the closest + * fit. + */ + if (versp->vs_need_cong_ctrl && rqstp->rq_xprt && + !test_bit(XPT_CONG_CTRL, &rqstp->rq_xprt->xpt_flags)) + goto err_bad_vers; + + if (rqstp->rq_proc >= versp->vs_nproc) + goto err_bad_proc; + rqstp->rq_procinfo = procp = &versp->vs_proc[rqstp->rq_proc]; + + /* Initialize storage for argp and resp */ + memset(rqstp->rq_argp, 0, procp->pc_argzero); + memset(rqstp->rq_resp, 0, procp->pc_ressize); + + /* Bump per-procedure stats counter */ + this_cpu_inc(versp->vs_count[rqstp->rq_proc]); + + ret->dispatch = versp->vs_dispatch; + return rpc_success; +err_bad_vers: + ret->mismatch.lovers = progp->pg_lovers; + ret->mismatch.hivers = progp->pg_hivers; + return rpc_prog_mismatch; +err_bad_proc: + return rpc_proc_unavail; +} +EXPORT_SYMBOL_GPL(svc_generic_init_request); + /* * Common routine for processing the RPC request. */ static int -svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) +svc_process_common(struct svc_rqst *rqstp) { - struct svc_program *progp; - struct svc_version *versp = NULL; /* compiler food */ - struct svc_procedure *procp = NULL; + struct xdr_stream *xdr = &rqstp->rq_res_stream; + struct svc_program *progp = NULL; + const struct svc_procedure *procp = NULL; struct svc_serv *serv = rqstp->rq_server; - kxdrproc_t xdr; - __be32 *statp; - u32 prog, vers, proc; - __be32 auth_stat, rpc_stat; - int auth_res; - __be32 *reply_statp; - - rpc_stat = rpc_success; + struct svc_process_info process; + enum svc_auth_status auth_res; + unsigned int aoffset; + int pr, rc; + __be32 *p; - if (argv->iov_len < 6*4) - goto err_short_len; + /* Reset the accept_stat for the RPC */ + rqstp->rq_accept_statp = NULL; - /* Will be turned off only in gss privacy case: */ - rqstp->rq_splice_ok = 1; /* Will be turned off only when NFSv4 Sessions are used */ - rqstp->rq_usedeferral = 1; - rqstp->rq_dropme = false; - - /* Setup reply header */ - rqstp->rq_xprt->xpt_ops->xpo_prep_reply_hdr(rqstp); - - svc_putu32(resv, rqstp->rq_xid); + set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags); + clear_bit(RQ_DROPME, &rqstp->rq_flags); - vers = svc_getnl(argv); + /* Construct the first words of the reply: */ + svcxdr_init_encode(rqstp); + xdr_stream_encode_be32(xdr, rqstp->rq_xid); + xdr_stream_encode_be32(xdr, rpc_reply); - /* First words of reply: */ - svc_putnl(resv, 1); /* REPLY */ - - if (vers != 2) /* RPC version number */ + p = xdr_inline_decode(&rqstp->rq_arg_stream, XDR_UNIT * 4); + if (unlikely(!p)) + goto err_short_len; + if (*p++ != cpu_to_be32(RPC_VERSION)) goto err_bad_rpc; - /* Save position in case we later decide to reject: */ - reply_statp = resv->iov_base + resv->iov_len; - - svc_putnl(resv, 0); /* ACCEPT */ + xdr_stream_encode_be32(xdr, rpc_msg_accepted); - rqstp->rq_prog = prog = svc_getnl(argv); /* program number */ - rqstp->rq_vers = vers = svc_getnl(argv); /* version number */ - rqstp->rq_proc = proc = svc_getnl(argv); /* procedure number */ + rqstp->rq_prog = be32_to_cpup(p++); + rqstp->rq_vers = be32_to_cpup(p++); + rqstp->rq_proc = be32_to_cpup(p); - progp = serv->sv_program; - - for (progp = serv->sv_program; progp; progp = progp->pg_next) - if (prog == progp->pg_prog) - break; + for (pr = 0; pr < serv->sv_nprogs; pr++) + if (rqstp->rq_prog == serv->sv_programs[pr].pg_prog) + progp = &serv->sv_programs[pr]; /* * Decode auth data, and add verifier to reply buffer. * We do this before anything else in order to get a decent * auth verifier. */ - auth_res = svc_authenticate(rqstp, &auth_stat); + auth_res = svc_authenticate(rqstp); /* Also give the program a chance to reject this call: */ - if (auth_res == SVC_OK && progp) { - auth_stat = rpc_autherr_badcred; + if (auth_res == SVC_OK && progp) auth_res = progp->pg_authenticate(rqstp); - } + trace_svc_authenticate(rqstp, auth_res); switch (auth_res) { case SVC_OK: break; case SVC_GARBAGE: - goto err_garbage; - case SVC_SYSERR: - rpc_stat = rpc_system_err; - goto err_bad; + rqstp->rq_auth_stat = rpc_autherr_badcred; + goto err_bad_auth; case SVC_DENIED: goto err_bad_auth; case SVC_CLOSE: - if (test_bit(XPT_TEMP, &rqstp->rq_xprt->xpt_flags)) - svc_close_xprt(rqstp->rq_xprt); + goto close; case SVC_DROP: goto dropit; case SVC_COMPLETE: goto sendit; + default: + pr_warn_once("Unexpected svc_auth_status (%d)\n", auth_res); + rqstp->rq_auth_stat = rpc_autherr_failed; + goto err_bad_auth; } if (progp == NULL) goto err_bad_prog; - if (vers >= progp->pg_nvers || - !(versp = progp->pg_vers[vers])) + switch (progp->pg_init_request(rqstp, progp, &process)) { + case rpc_success: + break; + case rpc_prog_unavail: + goto err_bad_prog; + case rpc_prog_mismatch: goto err_bad_vers; + case rpc_proc_unavail: + goto err_bad_proc; + } - procp = versp->vs_proc + proc; - if (proc >= versp->vs_nproc || !procp->pc_func) + procp = rqstp->rq_procinfo; + /* Should this check go into the dispatcher? */ + if (!procp || !procp->pc_func) goto err_bad_proc; - rqstp->rq_procinfo = procp; /* Syntactic check complete */ - serv->sv_stats->rpccnt++; - - /* Build the reply header. */ - statp = resv->iov_base +resv->iov_len; - svc_putnl(resv, RPC_SUCCESS); + if (serv->sv_stats) + serv->sv_stats->rpccnt++; + trace_svc_process(rqstp, progp->pg_name); - /* Bump per-procedure stats counter */ - procp->pc_count++; - - /* Initialize storage for argp and resp */ - memset(rqstp->rq_argp, 0, procp->pc_argsize); - memset(rqstp->rq_resp, 0, procp->pc_ressize); + aoffset = xdr_stream_pos(xdr); /* un-reserve some of the out-queue now that we have a * better idea of reply size @@ -1173,51 +1424,23 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) svc_reserve_auth(rqstp, procp->pc_xdrressize<<2); /* Call the function that processes the request. */ - if (!versp->vs_dispatch) { - /* Decode arguments */ - xdr = procp->pc_decode; - if (xdr && !xdr(rqstp, argv->iov_base, rqstp->rq_argp)) - goto err_garbage; - - *statp = procp->pc_func(rqstp, rqstp->rq_argp, rqstp->rq_resp); - - /* Encode reply */ - if (rqstp->rq_dropme) { - if (procp->pc_release) - procp->pc_release(rqstp, NULL, rqstp->rq_resp); - goto dropit; - } - if (*statp == rpc_success && - (xdr = procp->pc_encode) && - !xdr(rqstp, resv->iov_base+resv->iov_len, rqstp->rq_resp)) { - dprintk("svc: failed to encode reply\n"); - /* serv->sv_stats->rpcsystemerr++; */ - *statp = rpc_system_err; - } - } else { - dprintk("svc: calling dispatcher\n"); - if (!versp->vs_dispatch(rqstp, statp)) { - /* Release reply info */ - if (procp->pc_release) - procp->pc_release(rqstp, NULL, rqstp->rq_resp); - goto dropit; - } - } + rc = process.dispatch(rqstp); + xdr_finish_decode(xdr); - /* Check RPC status result */ - if (*statp != rpc_success) - resv->iov_len = ((void*)statp) - resv->iov_base + 4; + if (!rc) + goto dropit; + if (rqstp->rq_auth_stat != rpc_auth_ok) + goto err_bad_auth; - /* Release reply info */ - if (procp->pc_release) - procp->pc_release(rqstp, NULL, rqstp->rq_resp); + if (*rqstp->rq_accept_statp != rpc_success) + xdr_truncate_encode(xdr, aoffset); if (procp->pc_encode == NULL) goto dropit; sendit: if (svc_authorise(rqstp)) - goto dropit; + goto close_xprt; return 1; /* Caller can now send it */ dropit: @@ -1225,74 +1448,104 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) dprintk("svc: svc_process dropit\n"); return 0; -err_short_len: - svc_printk(rqstp, "short len %Zd, dropping request\n", - argv->iov_len); + close: + svc_authorise(rqstp); +close_xprt: + if (rqstp->rq_xprt && test_bit(XPT_TEMP, &rqstp->rq_xprt->xpt_flags)) + svc_xprt_close(rqstp->rq_xprt); + dprintk("svc: svc_process close\n"); + return 0; - goto dropit; /* drop request */ +err_short_len: + svc_printk(rqstp, "short len %u, dropping request\n", + rqstp->rq_arg.len); + goto close_xprt; err_bad_rpc: - serv->sv_stats->rpcbadfmt++; - svc_putnl(resv, 1); /* REJECT */ - svc_putnl(resv, 0); /* RPC_MISMATCH */ - svc_putnl(resv, 2); /* Only RPCv2 supported */ - svc_putnl(resv, 2); - goto sendit; + if (serv->sv_stats) + serv->sv_stats->rpcbadfmt++; + xdr_stream_encode_u32(xdr, RPC_MSG_DENIED); + xdr_stream_encode_u32(xdr, RPC_MISMATCH); + /* Only RPCv2 supported */ + xdr_stream_encode_u32(xdr, RPC_VERSION); + xdr_stream_encode_u32(xdr, RPC_VERSION); + return 1; /* don't wrap */ err_bad_auth: - dprintk("svc: authentication failed (%d)\n", ntohl(auth_stat)); - serv->sv_stats->rpcbadauth++; - /* Restore write pointer to location of accept status: */ - xdr_ressize_check(rqstp, reply_statp); - svc_putnl(resv, 1); /* REJECT */ - svc_putnl(resv, 1); /* AUTH_ERROR */ - svc_putnl(resv, ntohl(auth_stat)); /* status */ + dprintk("svc: authentication failed (%d)\n", + be32_to_cpu(rqstp->rq_auth_stat)); + if (serv->sv_stats) + serv->sv_stats->rpcbadauth++; + /* Restore write pointer to location of reply status: */ + xdr_truncate_encode(xdr, XDR_UNIT * 2); + xdr_stream_encode_u32(xdr, RPC_MSG_DENIED); + xdr_stream_encode_u32(xdr, RPC_AUTH_ERROR); + xdr_stream_encode_be32(xdr, rqstp->rq_auth_stat); goto sendit; err_bad_prog: - dprintk("svc: unknown program %d\n", prog); - serv->sv_stats->rpcbadfmt++; - svc_putnl(resv, RPC_PROG_UNAVAIL); + dprintk("svc: unknown program %d\n", rqstp->rq_prog); + if (serv->sv_stats) + serv->sv_stats->rpcbadfmt++; + *rqstp->rq_accept_statp = rpc_prog_unavail; goto sendit; err_bad_vers: svc_printk(rqstp, "unknown version (%d for prog %d, %s)\n", - vers, prog, progp->pg_name); + rqstp->rq_vers, rqstp->rq_prog, progp->pg_name); + + if (serv->sv_stats) + serv->sv_stats->rpcbadfmt++; + *rqstp->rq_accept_statp = rpc_prog_mismatch; - serv->sv_stats->rpcbadfmt++; - svc_putnl(resv, RPC_PROG_MISMATCH); - svc_putnl(resv, progp->pg_lovers); - svc_putnl(resv, progp->pg_hivers); + /* + * svc_authenticate() has already added the verifier and + * advanced the stream just past rq_accept_statp. + */ + xdr_stream_encode_u32(xdr, process.mismatch.lovers); + xdr_stream_encode_u32(xdr, process.mismatch.hivers); goto sendit; err_bad_proc: - svc_printk(rqstp, "unknown procedure (%d)\n", proc); + svc_printk(rqstp, "unknown procedure (%d)\n", rqstp->rq_proc); - serv->sv_stats->rpcbadfmt++; - svc_putnl(resv, RPC_PROC_UNAVAIL); + if (serv->sv_stats) + serv->sv_stats->rpcbadfmt++; + *rqstp->rq_accept_statp = rpc_proc_unavail; goto sendit; +} + +/* + * Drop request + */ +static void svc_drop(struct svc_rqst *rqstp) +{ + trace_svc_drop(rqstp); +} -err_garbage: - svc_printk(rqstp, "failed to decode args\n"); +static void svc_release_rqst(struct svc_rqst *rqstp) +{ + const struct svc_procedure *procp = rqstp->rq_procinfo; - rpc_stat = rpc_garbage_args; -err_bad: - serv->sv_stats->rpcbadfmt++; - svc_putnl(resv, ntohl(rpc_stat)); - goto sendit; + if (procp && procp->pc_release) + procp->pc_release(rqstp); } -EXPORT_SYMBOL_GPL(svc_process); -/* - * Process the RPC request. +/** + * svc_process - Execute one RPC transaction + * @rqstp: RPC transaction context + * */ -int -svc_process(struct svc_rqst *rqstp) +void svc_process(struct svc_rqst *rqstp) { - struct kvec *argv = &rqstp->rq_arg.head[0]; struct kvec *resv = &rqstp->rq_res.head[0]; - struct svc_serv *serv = rqstp->rq_server; - u32 dir; + __be32 *p; + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) + if (!fail_sunrpc.ignore_server_disconnect && + should_fail(&fail_sunrpc.attr, 1)) + svc_xprt_deferred_close(rqstp->rq_xprt); +#endif /* * Setup response xdr_buf. @@ -1301,7 +1554,7 @@ svc_process(struct svc_rqst *rqstp) rqstp->rq_next_page = &rqstp->rq_respages[1]; resv->iov_base = page_address(rqstp->rq_respages[0]); resv->iov_len = 0; - rqstp->rq_res.pages = rqstp->rq_respages + 1; + rqstp->rq_res.pages = rqstp->rq_next_page; rqstp->rq_res.len = 0; rqstp->rq_res.page_base = 0; rqstp->rq_res.page_len = 0; @@ -1309,80 +1562,117 @@ svc_process(struct svc_rqst *rqstp) rqstp->rq_res.tail[0].iov_base = NULL; rqstp->rq_res.tail[0].iov_len = 0; - rqstp->rq_xid = svc_getu32(argv); - - dir = svc_getnl(argv); - if (dir != 0) { - /* direction != CALL */ - svc_printk(rqstp, "bad direction %d, dropping request\n", dir); - serv->sv_stats->rpcbadfmt++; - svc_drop(rqstp); - return 0; - } - - /* Returns 1 for send, 0 for drop */ - if (svc_process_common(rqstp, argv, resv)) - return svc_send(rqstp); - else { - svc_drop(rqstp); - return 0; + svcxdr_init_decode(rqstp); + p = xdr_inline_decode(&rqstp->rq_arg_stream, XDR_UNIT * 2); + if (unlikely(!p)) + goto out_drop; + rqstp->rq_xid = *p++; + if (unlikely(*p != rpc_call)) + goto out_baddir; + + if (!svc_process_common(rqstp)) { + svc_release_rqst(rqstp); + goto out_drop; } + svc_send(rqstp); + svc_release_rqst(rqstp); + return; + +out_baddir: + svc_printk(rqstp, "bad direction 0x%08x, dropping request\n", + be32_to_cpu(*p)); + if (rqstp->rq_server->sv_stats) + rqstp->rq_server->sv_stats->rpcbadfmt++; +out_drop: + svc_drop(rqstp); } #if defined(CONFIG_SUNRPC_BACKCHANNEL) -/* - * Process a backchannel RPC request that arrived over an existing - * outbound connection +/** + * svc_process_bc - process a reverse-direction RPC request + * @req: RPC request to be used for client-side processing + * @rqstp: server-side execution context + * */ -int -bc_svc_process(struct svc_serv *serv, struct rpc_rqst *req, - struct svc_rqst *rqstp) +void svc_process_bc(struct rpc_rqst *req, struct svc_rqst *rqstp) { - struct kvec *argv = &rqstp->rq_arg.head[0]; - struct kvec *resv = &rqstp->rq_res.head[0]; + struct rpc_timeout timeout = { + .to_increment = 0, + }; + struct rpc_task *task; + int proc_error; /* Build the svc_rqst used by the common processing routine */ - rqstp->rq_xprt = serv->sv_bc_xprt; rqstp->rq_xid = req->rq_xid; rqstp->rq_prot = req->rq_xprt->prot; - rqstp->rq_server = serv; + rqstp->rq_bc_net = req->rq_xprt->xprt_net; rqstp->rq_addrlen = sizeof(req->rq_xprt->addr); memcpy(&rqstp->rq_addr, &req->rq_xprt->addr, rqstp->rq_addrlen); memcpy(&rqstp->rq_arg, &req->rq_rcv_buf, sizeof(rqstp->rq_arg)); memcpy(&rqstp->rq_res, &req->rq_snd_buf, sizeof(rqstp->rq_res)); - /* reset result send buffer "put" position */ - resv->iov_len = 0; + /* Adjust the argument buffer length */ + rqstp->rq_arg.len = req->rq_private_buf.len; + if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len; + rqstp->rq_arg.page_len = 0; + } else if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len) + rqstp->rq_arg.page_len = rqstp->rq_arg.len - + rqstp->rq_arg.head[0].iov_len; + else + rqstp->rq_arg.len = rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len; - if (rqstp->rq_prot != IPPROTO_TCP) { - printk(KERN_ERR "No support for Non-TCP transports!\n"); - BUG(); - } + /* Reset the response buffer */ + rqstp->rq_res.head[0].iov_len = 0; /* - * Skip the next two words because they've already been - * processed in the trasport + * Skip the XID and calldir fields because they've already + * been processed by the caller. */ - svc_getu32(argv); /* XID */ - svc_getnl(argv); /* CALLDIR */ - - /* Returns 1 for send, 0 for drop */ - if (svc_process_common(rqstp, argv, resv)) { - memcpy(&req->rq_snd_buf, &rqstp->rq_res, - sizeof(req->rq_snd_buf)); - return bc_send(req); - } else { - /* drop request */ + svcxdr_init_decode(rqstp); + if (!xdr_inline_decode(&rqstp->rq_arg_stream, XDR_UNIT * 2)) + return; + + /* Parse and execute the bc call */ + proc_error = svc_process_common(rqstp); + + atomic_dec(&req->rq_xprt->bc_slot_count); + if (!proc_error) { + /* Processing error: drop the request */ xprt_free_bc_request(req); - return 0; + svc_release_rqst(rqstp); + return; + } + /* Finally, send the reply synchronously */ + if (rqstp->bc_to_initval > 0) { + timeout.to_initval = rqstp->bc_to_initval; + timeout.to_retries = rqstp->bc_to_retries; + } else { + timeout.to_initval = req->rq_xprt->timeout->to_initval; + timeout.to_retries = req->rq_xprt->timeout->to_retries; } + timeout.to_maxval = timeout.to_initval; + memcpy(&req->rq_snd_buf, &rqstp->rq_res, sizeof(req->rq_snd_buf)); + task = rpc_run_bc_task(req, &timeout); + svc_release_rqst(rqstp); + + if (IS_ERR(task)) + return; + + WARN_ON_ONCE(atomic_read(&task->tk_count) != 1); + rpc_put_task(task); } -EXPORT_SYMBOL_GPL(bc_svc_process); #endif /* CONFIG_SUNRPC_BACKCHANNEL */ -/* - * Return (transport-specific) limit on the rpc payload. +/** + * svc_max_payload - Return transport-specific limit on the RPC payload + * @rqstp: RPC transaction context + * + * Returns the maximum number of payload bytes the current transport + * allows. */ u32 svc_max_payload(const struct svc_rqst *rqstp) { @@ -1393,3 +1683,85 @@ u32 svc_max_payload(const struct svc_rqst *rqstp) return max; } EXPORT_SYMBOL_GPL(svc_max_payload); + +/** + * svc_proc_name - Return RPC procedure name in string form + * @rqstp: svc_rqst to operate on + * + * Return value: + * Pointer to a NUL-terminated string + */ +const char *svc_proc_name(const struct svc_rqst *rqstp) +{ + if (rqstp && rqstp->rq_procinfo) + return rqstp->rq_procinfo->pc_name; + return "unknown"; +} + + +/** + * svc_encode_result_payload - mark a range of bytes as a result payload + * @rqstp: svc_rqst to operate on + * @offset: payload's byte offset in rqstp->rq_res + * @length: size of payload, in bytes + * + * Returns zero on success, or a negative errno if a permanent + * error occurred. + */ +int svc_encode_result_payload(struct svc_rqst *rqstp, unsigned int offset, + unsigned int length) +{ + return rqstp->rq_xprt->xpt_ops->xpo_result_payload(rqstp, offset, + length); +} +EXPORT_SYMBOL_GPL(svc_encode_result_payload); + +/** + * svc_fill_symlink_pathname - Construct pathname argument for VFS symlink call + * @rqstp: svc_rqst to operate on + * @first: buffer containing first section of pathname + * @p: buffer containing remaining section of pathname + * @total: total length of the pathname argument + * + * The VFS symlink API demands a NUL-terminated pathname in mapped memory. + * Returns pointer to a NUL-terminated string, or an ERR_PTR. Caller must free + * the returned string. + */ +char *svc_fill_symlink_pathname(struct svc_rqst *rqstp, struct kvec *first, + void *p, size_t total) +{ + size_t len, remaining; + char *result, *dst; + + result = kmalloc(total + 1, GFP_KERNEL); + if (!result) + return ERR_PTR(-ESERVERFAULT); + + dst = result; + remaining = total; + + len = min_t(size_t, total, first->iov_len); + if (len) { + memcpy(dst, first->iov_base, len); + dst += len; + remaining -= len; + } + + if (remaining) { + len = min_t(size_t, remaining, PAGE_SIZE); + memcpy(dst, p, len); + dst += len; + } + + *dst = '\0'; + + /* Sanity check: Linux doesn't allow the pathname argument to + * contain a NUL byte. + */ + if (strlen(result) != total) { + kfree(result); + return ERR_PTR(-EINVAL); + } + return result; +} +EXPORT_SYMBOL_GPL(svc_fill_symlink_pathname); diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index 80a6640f329b..6973184ff667 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/svc_xprt.c * @@ -5,29 +6,37 @@ */ #include <linux/sched.h> +#include <linux/sched/mm.h> #include <linux/errno.h> #include <linux/freezer.h> -#include <linux/kthread.h> #include <linux/slab.h> #include <net/sock.h> +#include <linux/sunrpc/addr.h> #include <linux/sunrpc/stats.h> #include <linux/sunrpc/svc_xprt.h> #include <linux/sunrpc/svcsock.h> #include <linux/sunrpc/xprt.h> +#include <linux/sunrpc/bc_xprt.h> #include <linux/module.h> +#include <linux/netdevice.h> +#include <trace/events/sunrpc.h> #define RPCDBG_FACILITY RPCDBG_SVCXPRT +static unsigned int svc_rpc_per_connection_limit __read_mostly; +module_param(svc_rpc_per_connection_limit, uint, 0644); + + static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt); static int svc_deferred_recv(struct svc_rqst *rqstp); static struct cache_deferred_req *svc_defer(struct cache_req *req); -static void svc_age_temp_xprts(unsigned long closure); +static void svc_age_temp_xprts(struct timer_list *t); static void svc_delete_xprt(struct svc_xprt *xprt); /* apparently the "standard" is that clients close * idle connections after 5 minutes, servers after * 6 minutes - * http://www.connectathon.org/talks96/nfstcp.pdf + * http://nfsv4bat.org/Documents/ConnectAThon/1996/nfstcp.pdf */ static int svc_conn_age_period = 6*60; @@ -37,10 +46,9 @@ static LIST_HEAD(svc_xprt_class_list); /* SMP locking strategy: * - * svc_pool->sp_lock protects most of the fields of that pool. * svc_serv->sv_lock protects sv_tempsocks, sv_permsocks, sv_tmpcnt. * when both need to be taken (rare), svc_serv->sv_lock is first. - * BKL protects svc_serv->sv_nrthread. + * The "service mutex" protects svc_serv->sv_nrthread. * svc_sock->sk_lock protects the svc_sock->sk_deferred list * and the ->sk_info_authunix cache. * @@ -66,13 +74,17 @@ static LIST_HEAD(svc_xprt_class_list); * try to set XPT_DEAD. */ +/** + * svc_reg_xprt_class - Register a server-side RPC transport class + * @xcl: New transport class to be registered + * + * Returns zero on success; otherwise a negative errno is returned. + */ int svc_reg_xprt_class(struct svc_xprt_class *xcl) { struct svc_xprt_class *cl; int res = -EEXIST; - dprintk("svc: Adding svc transport class '%s'\n", xcl->xcl_name); - INIT_LIST_HEAD(&xcl->xcl_list); spin_lock(&svc_xprt_class_lock); /* Make sure there isn't already a class with the same name */ @@ -88,17 +100,30 @@ out: } EXPORT_SYMBOL_GPL(svc_reg_xprt_class); +/** + * svc_unreg_xprt_class - Unregister a server-side RPC transport class + * @xcl: Transport class to be unregistered + * + */ void svc_unreg_xprt_class(struct svc_xprt_class *xcl) { - dprintk("svc: Removing svc transport class '%s'\n", xcl->xcl_name); spin_lock(&svc_xprt_class_lock); list_del_init(&xcl->xcl_list); spin_unlock(&svc_xprt_class_lock); } EXPORT_SYMBOL_GPL(svc_unreg_xprt_class); -/* - * Format the transport list for printing +/** + * svc_print_xprts - Format the transport list for printing + * @buf: target buffer for formatted address + * @maxlen: length of target buffer + * + * Fills in @buf with a string containing a list of transport names, each name + * terminated with '\n'. If the buffer is too small, some entries may be + * missing, but it is guaranteed that all lines in the output buffer are + * complete. + * + * Returns positive length of the filled-in string. */ int svc_print_xprts(char *buf, int maxlen) { @@ -111,9 +136,9 @@ int svc_print_xprts(char *buf, int maxlen) list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { int slen; - sprintf(tmpstr, "%s %d\n", xcl->xcl_name, xcl->xcl_max_payload); - slen = strlen(tmpstr); - if (len + slen > maxlen) + slen = snprintf(tmpstr, sizeof(tmpstr), "%s %d\n", + xcl->xcl_name, xcl->xcl_max_payload); + if (slen >= sizeof(tmpstr) || len + slen >= maxlen) break; len += slen; strcat(buf, tmpstr); @@ -123,6 +148,21 @@ int svc_print_xprts(char *buf, int maxlen) return len; } +/** + * svc_xprt_deferred_close - Close a transport + * @xprt: transport instance + * + * Used in contexts that need to defer the work of shutting down + * the transport to an nfsd thread. + */ +void svc_xprt_deferred_close(struct svc_xprt *xprt) +{ + trace_svc_xprt_close(xprt); + if (!test_and_set_bit(XPT_CLOSE, &xprt->xpt_flags)) + svc_xprt_enqueue(xprt); +} +EXPORT_SYMBOL_GPL(svc_xprt_deferred_close); + static void svc_xprt_free(struct kref *kref) { struct svc_xprt *xprt = @@ -130,10 +170,14 @@ static void svc_xprt_free(struct kref *kref) struct module *owner = xprt->xpt_class->xcl_owner; if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) svcauth_unix_info_release(xprt); - put_net(xprt->xpt_net); + put_cred(xprt->xpt_cred); + put_net_track(xprt->xpt_net, &xprt->ns_tracker); /* See comment on corresponding get in xs_setup_bc_tcp(): */ if (xprt->xpt_bc_xprt) xprt_put(xprt->xpt_bc_xprt); + if (xprt->xpt_bc_xps) + xprt_switch_put(xprt->xpt_bc_xps); + trace_svc_xprt_free(xprt); xprt->xpt_ops->xpo_free(xprt); module_put(owner); } @@ -157,78 +201,43 @@ void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl, kref_init(&xprt->xpt_ref); xprt->xpt_server = serv; INIT_LIST_HEAD(&xprt->xpt_list); - INIT_LIST_HEAD(&xprt->xpt_ready); INIT_LIST_HEAD(&xprt->xpt_deferred); INIT_LIST_HEAD(&xprt->xpt_users); mutex_init(&xprt->xpt_mutex); spin_lock_init(&xprt->xpt_lock); set_bit(XPT_BUSY, &xprt->xpt_flags); - rpc_init_wait_queue(&xprt->xpt_bc_pending, "xpt_bc_pending"); - xprt->xpt_net = get_net(net); + xprt->xpt_net = get_net_track(net, &xprt->ns_tracker, GFP_ATOMIC); + strcpy(xprt->xpt_remotebuf, "uninitialized"); } EXPORT_SYMBOL_GPL(svc_xprt_init); -static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, - struct svc_serv *serv, - struct net *net, - const int family, - const unsigned short port, - int flags) -{ - struct sockaddr_in sin = { - .sin_family = AF_INET, - .sin_addr.s_addr = htonl(INADDR_ANY), - .sin_port = htons(port), - }; -#if IS_ENABLED(CONFIG_IPV6) - struct sockaddr_in6 sin6 = { - .sin6_family = AF_INET6, - .sin6_addr = IN6ADDR_ANY_INIT, - .sin6_port = htons(port), - }; -#endif - struct sockaddr *sap; - size_t len; - - switch (family) { - case PF_INET: - sap = (struct sockaddr *)&sin; - len = sizeof(sin); - break; -#if IS_ENABLED(CONFIG_IPV6) - case PF_INET6: - sap = (struct sockaddr *)&sin6; - len = sizeof(sin6); - break; -#endif - default: - return ERR_PTR(-EAFNOSUPPORT); - } - - return xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); -} - -/* - * svc_xprt_received conditionally queues the transport for processing - * by another thread. The caller must hold the XPT_BUSY bit and must +/** + * svc_xprt_received - start next receiver thread + * @xprt: controlling transport + * + * The caller must hold the XPT_BUSY bit and must * not thereafter touch transport data. * * Note: XPT_DATA only gets cleared when a read-attempt finds no (or * insufficient) data. */ -static void svc_xprt_received(struct svc_xprt *xprt) +void svc_xprt_received(struct svc_xprt *xprt) { - WARN_ON_ONCE(!test_bit(XPT_BUSY, &xprt->xpt_flags)); - if (!test_bit(XPT_BUSY, &xprt->xpt_flags)) + if (!test_bit(XPT_BUSY, &xprt->xpt_flags)) { + WARN_ONCE(1, "xprt=0x%p already busy!", xprt); return; + } + /* As soon as we clear busy, the xprt could be closed and * 'put', so we need a reference to call svc_xprt_enqueue with: */ svc_xprt_get(xprt); + smp_mb__before_atomic(); clear_bit(XPT_BUSY, &xprt->xpt_flags); svc_xprt_enqueue(xprt); svc_xprt_put(xprt); } +EXPORT_SYMBOL_GPL(svc_xprt_received); void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) { @@ -239,13 +248,12 @@ void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) svc_xprt_received(new); } -int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, - struct net *net, const int family, - const unsigned short port, int flags) +static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, + struct net *net, struct sockaddr *sap, + size_t len, int flags, const struct cred *cred) { struct svc_xprt_class *xcl; - dprintk("svc: creating transport %s[%d]\n", xprt_name, port); spin_lock(&svc_xprt_class_lock); list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { struct svc_xprt *newxprt; @@ -258,24 +266,115 @@ int svc_create_xprt(struct svc_serv *serv, const char *xprt_name, goto err; spin_unlock(&svc_xprt_class_lock); - newxprt = __svc_xpo_create(xcl, serv, net, family, port, flags); + newxprt = xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); if (IS_ERR(newxprt)) { + trace_svc_xprt_create_err(serv->sv_programs->pg_name, + xcl->xcl_name, sap, len, + newxprt); module_put(xcl->xcl_owner); return PTR_ERR(newxprt); } + newxprt->xpt_cred = get_cred(cred); svc_add_new_perm_xprt(serv, newxprt); newport = svc_xprt_local_port(newxprt); return newport; } err: spin_unlock(&svc_xprt_class_lock); - dprintk("svc: transport %s not found\n", xprt_name); - /* This errno is exposed to user space. Provide a reasonable * perror msg for a bad transport. */ return -EPROTONOSUPPORT; } -EXPORT_SYMBOL_GPL(svc_create_xprt); + +/** + * svc_xprt_create_from_sa - Add a new listener to @serv from socket address + * @serv: target RPC service + * @xprt_name: transport class name + * @net: network namespace + * @sap: socket address pointer + * @flags: SVC_SOCK flags + * @cred: credential to bind to this transport + * + * Return local xprt port on success or %-EPROTONOSUPPORT on failure + */ +int svc_xprt_create_from_sa(struct svc_serv *serv, const char *xprt_name, + struct net *net, struct sockaddr *sap, + int flags, const struct cred *cred) +{ + size_t len; + int err; + + switch (sap->sa_family) { + case AF_INET: + len = sizeof(struct sockaddr_in); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + len = sizeof(struct sockaddr_in6); + break; +#endif + default: + return -EAFNOSUPPORT; + } + + err = _svc_xprt_create(serv, xprt_name, net, sap, len, flags, cred); + if (err == -EPROTONOSUPPORT) { + request_module("svc%s", xprt_name); + err = _svc_xprt_create(serv, xprt_name, net, sap, len, flags, + cred); + } + + return err; +} +EXPORT_SYMBOL_GPL(svc_xprt_create_from_sa); + +/** + * svc_xprt_create - Add a new listener to @serv + * @serv: target RPC service + * @xprt_name: transport class name + * @net: network namespace + * @family: network address family + * @port: listener port + * @flags: SVC_SOCK flags + * @cred: credential to bind to this transport + * + * Return local xprt port on success or %-EPROTONOSUPPORT on failure + */ +int svc_xprt_create(struct svc_serv *serv, const char *xprt_name, + struct net *net, const int family, + const unsigned short port, int flags, + const struct cred *cred) +{ + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; +#endif + struct sockaddr *sap; + + switch (family) { + case PF_INET: + sap = (struct sockaddr *)&sin; + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + sap = (struct sockaddr *)&sin6; + break; +#endif + default: + return -EAFNOSUPPORT; + } + + return svc_xprt_create_from_sa(serv, xprt_name, net, sap, flags, cred); +} +EXPORT_SYMBOL_GPL(svc_xprt_create); /* * Copy the local and remote xprt addresses to the rqstp structure @@ -307,116 +406,105 @@ char *svc_print_addr(struct svc_rqst *rqstp, char *buf, size_t len) } EXPORT_SYMBOL_GPL(svc_print_addr); -/* - * Queue up an idle server thread. Must have pool->sp_lock held. - * Note: this is really a stack rather than a queue, so that we only - * use as many different threads as we need, and the rest don't pollute - * the cache. - */ -static void svc_thread_enqueue(struct svc_pool *pool, struct svc_rqst *rqstp) +static bool svc_xprt_slots_in_range(struct svc_xprt *xprt) { - list_add(&rqstp->rq_list, &pool->sp_threads); + unsigned int limit = svc_rpc_per_connection_limit; + int nrqsts = atomic_read(&xprt->xpt_nr_rqsts); + + return limit == 0 || (nrqsts >= 0 && nrqsts < limit); } -/* - * Dequeue an nfsd thread. Must have pool->sp_lock held. - */ -static void svc_thread_dequeue(struct svc_pool *pool, struct svc_rqst *rqstp) +static bool svc_xprt_reserve_slot(struct svc_rqst *rqstp, struct svc_xprt *xprt) { - list_del(&rqstp->rq_list); + if (!test_bit(RQ_DATA, &rqstp->rq_flags)) { + if (!svc_xprt_slots_in_range(xprt)) + return false; + atomic_inc(&xprt->xpt_nr_rqsts); + set_bit(RQ_DATA, &rqstp->rq_flags); + } + return true; } -static bool svc_xprt_has_something_to_do(struct svc_xprt *xprt) +static void svc_xprt_release_slot(struct svc_rqst *rqstp) { - if (xprt->xpt_flags & ((1<<XPT_CONN)|(1<<XPT_CLOSE))) + struct svc_xprt *xprt = rqstp->rq_xprt; + if (test_and_clear_bit(RQ_DATA, &rqstp->rq_flags)) { + atomic_dec(&xprt->xpt_nr_rqsts); + smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */ + svc_xprt_enqueue(xprt); + } +} + +static bool svc_xprt_ready(struct svc_xprt *xprt) +{ + unsigned long xpt_flags; + + /* + * If another cpu has recently updated xpt_flags, + * sk_sock->flags, xpt_reserved, or xpt_nr_rqsts, we need to + * know about it; otherwise it's possible that both that cpu and + * this one could call svc_xprt_enqueue() without either + * svc_xprt_enqueue() recognizing that the conditions below + * are satisfied, and we could stall indefinitely: + */ + smp_rmb(); + xpt_flags = READ_ONCE(xprt->xpt_flags); + + trace_svc_xprt_enqueue(xprt, xpt_flags); + if (xpt_flags & BIT(XPT_BUSY)) + return false; + if (xpt_flags & (BIT(XPT_CONN) | BIT(XPT_CLOSE) | BIT(XPT_HANDSHAKE))) return true; - if (xprt->xpt_flags & ((1<<XPT_DATA)|(1<<XPT_DEFERRED))) - return xprt->xpt_ops->xpo_has_wspace(xprt); + if (xpt_flags & (BIT(XPT_DATA) | BIT(XPT_DEFERRED))) { + if (xprt->xpt_ops->xpo_has_wspace(xprt) && + svc_xprt_slots_in_range(xprt)) + return true; + trace_svc_xprt_no_write_space(xprt); + return false; + } return false; } -/* - * Queue up a transport with data pending. If there are idle nfsd - * processes, wake 'em up. +/** + * svc_xprt_enqueue - Queue a transport on an idle nfsd thread + * @xprt: transport with data pending * */ void svc_xprt_enqueue(struct svc_xprt *xprt) { struct svc_pool *pool; - struct svc_rqst *rqstp; - int cpu; - if (!svc_xprt_has_something_to_do(xprt)) + if (!svc_xprt_ready(xprt)) return; - cpu = get_cpu(); - pool = svc_pool_for_cpu(xprt->xpt_server, cpu); - put_cpu(); - - spin_lock_bh(&pool->sp_lock); - - if (!list_empty(&pool->sp_threads) && - !list_empty(&pool->sp_sockets)) - printk(KERN_ERR - "svc_xprt_enqueue: " - "threads and transports both waiting??\n"); - - pool->sp_stats.packets++; - /* Mark transport as busy. It will remain in this state until * the provider calls svc_xprt_received. We update XPT_BUSY * atomically because it also guards against trying to enqueue * the transport twice. */ - if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) { - /* Don't enqueue transport while already enqueued */ - dprintk("svc: transport %p busy, not enqueued\n", xprt); - goto out_unlock; - } + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) + return; - if (!list_empty(&pool->sp_threads)) { - rqstp = list_entry(pool->sp_threads.next, - struct svc_rqst, - rq_list); - dprintk("svc: transport %p served by daemon %p\n", - xprt, rqstp); - svc_thread_dequeue(pool, rqstp); - if (rqstp->rq_xprt) - printk(KERN_ERR - "svc_xprt_enqueue: server %p, rq_xprt=%p!\n", - rqstp, rqstp->rq_xprt); - rqstp->rq_xprt = xprt; - svc_xprt_get(xprt); - pool->sp_stats.threads_woken++; - wake_up(&rqstp->rq_wait); - } else { - dprintk("svc: transport %p put into queue\n", xprt); - list_add_tail(&xprt->xpt_ready, &pool->sp_sockets); - pool->sp_stats.sockets_queued++; - } + pool = svc_pool_for_cpu(xprt->xpt_server); + + percpu_counter_inc(&pool->sp_sockets_queued); + xprt->xpt_qtime = ktime_get(); + lwq_enqueue(&xprt->xpt_ready, &pool->sp_xprts); -out_unlock: - spin_unlock_bh(&pool->sp_lock); + svc_pool_wake_idle_thread(pool); } EXPORT_SYMBOL_GPL(svc_xprt_enqueue); /* - * Dequeue the first transport. Must be called with the pool->sp_lock held. + * Dequeue the first transport, if there is one. */ static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool) { - struct svc_xprt *xprt; - - if (list_empty(&pool->sp_sockets)) - return NULL; - - xprt = list_entry(pool->sp_sockets.next, - struct svc_xprt, xpt_ready); - list_del_init(&xprt->xpt_ready); - - dprintk("svc: transport %p dequeued, inuse=%d\n", - xprt, atomic_read(&xprt->xpt_ref.refcount)); + struct svc_xprt *xprt = NULL; + xprt = lwq_dequeue(&pool->sp_xprts, struct svc_xprt, xpt_ready); + if (xprt) + svc_xprt_get(xprt); return xprt; } @@ -432,28 +520,39 @@ static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool) */ void svc_reserve(struct svc_rqst *rqstp, int space) { + struct svc_xprt *xprt = rqstp->rq_xprt; + space += rqstp->rq_res.head[0].iov_len; - if (space < rqstp->rq_reserved) { - struct svc_xprt *xprt = rqstp->rq_xprt; + if (xprt && space < rqstp->rq_reserved) { atomic_sub((rqstp->rq_reserved - space), &xprt->xpt_reserved); rqstp->rq_reserved = space; - + smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */ svc_xprt_enqueue(xprt); } } EXPORT_SYMBOL_GPL(svc_reserve); +static void free_deferred(struct svc_xprt *xprt, struct svc_deferred_req *dr) +{ + if (!dr) + return; + + xprt->xpt_ops->xpo_release_ctxt(xprt, dr->xprt_ctxt); + kfree(dr); +} + static void svc_xprt_release(struct svc_rqst *rqstp) { struct svc_xprt *xprt = rqstp->rq_xprt; - rqstp->rq_xprt->xpt_ops->xpo_release_rqst(rqstp); + xprt->xpt_ops->xpo_release_ctxt(xprt, rqstp->rq_xprt_ctxt); + rqstp->rq_xprt_ctxt = NULL; - kfree(rqstp->rq_deferred); + free_deferred(xprt, rqstp->rq_deferred); rqstp->rq_deferred = NULL; - svc_free_res_pages(rqstp); + svc_rqst_release_pages(rqstp); rqstp->rq_res.page_len = 0; rqstp->rq_res.page_base = 0; @@ -469,40 +568,27 @@ static void svc_xprt_release(struct svc_rqst *rqstp) rqstp->rq_res.head[0].iov_len = 0; svc_reserve(rqstp, 0); + svc_xprt_release_slot(rqstp); rqstp->rq_xprt = NULL; - svc_xprt_put(xprt); } -/* - * External function to wake up a server waiting for data - * This really only makes sense for services like lockd - * which have exactly one thread anyway. +/** + * svc_wake_up - Wake up a service thread for non-transport work + * @serv: RPC service + * + * Some svc_serv's will have occasional work to do, even when a xprt is not + * waiting to be serviced. This function is there to "kick" a task in one of + * those services so that it can wake up and do that work. Note that we only + * bother with pool 0 as we don't need to wake up more than one thread for + * this purpose. */ void svc_wake_up(struct svc_serv *serv) { - struct svc_rqst *rqstp; - unsigned int i; - struct svc_pool *pool; + struct svc_pool *pool = &serv->sv_pools[0]; - for (i = 0; i < serv->sv_nrpools; i++) { - pool = &serv->sv_pools[i]; - - spin_lock_bh(&pool->sp_lock); - if (!list_empty(&pool->sp_threads)) { - rqstp = list_entry(pool->sp_threads.next, - struct svc_rqst, - rq_list); - dprintk("svc: daemon %p woken up.\n", rqstp); - /* - svc_thread_dequeue(pool, rqstp); - rqstp->rq_xprt = NULL; - */ - wake_up(&rqstp->rq_wait); - } else - pool->sp_task_pending = 1; - spin_unlock_bh(&pool->sp_lock); - } + set_bit(SP_TASK_PENDING, &pool->sp_flags); + svc_pool_wake_idle_thread(pool); } EXPORT_SYMBOL_GPL(svc_wake_up); @@ -521,7 +607,8 @@ int svc_port_is_privileged(struct sockaddr *sin) } /* - * Make sure that we don't have too many active connections. If we have, + * Make sure that we don't have too many connections that have not yet + * demonstrated that they have access to the NFS server. If we have, * something must be dropped. It's not clear what will happen if we allow * "too many" connections, but when dealing with network-facing software, * we have to code defensively. Here we do that by imposing hard limits. @@ -533,34 +620,26 @@ int svc_port_is_privileged(struct sockaddr *sin) * The only somewhat efficient mechanism would be if drop old * connections from the same IP first. But right now we don't even * record the client IP in svc_sock. - * - * single-threaded services that expect a lot of clients will probably - * need to set sv_maxconn to override the default value which is based - * on the number of threads */ static void svc_check_conn_limits(struct svc_serv *serv) { - unsigned int limit = serv->sv_maxconn ? serv->sv_maxconn : - (serv->sv_nrthreads+3) * 20; - - if (serv->sv_tmpcnt > limit) { - struct svc_xprt *xprt = NULL; + if (serv->sv_tmpcnt > XPT_MAX_TMP_CONN) { + struct svc_xprt *xprt = NULL, *xprti; spin_lock_bh(&serv->sv_lock); if (!list_empty(&serv->sv_tempsocks)) { - /* Try to help the admin */ - net_notice_ratelimited("%s: too many open connections, consider increasing the %s\n", - serv->sv_name, serv->sv_maxconn ? - "max number of connections" : - "number of threads"); /* * Always select the oldest connection. It's not fair, - * but so is life + * but nor is life. */ - xprt = list_entry(serv->sv_tempsocks.prev, - struct svc_xprt, - xpt_list); - set_bit(XPT_CLOSE, &xprt->xpt_flags); - svc_xprt_get(xprt); + list_for_each_entry_reverse(xprti, &serv->sv_tempsocks, + xpt_list) { + if (!test_bit(XPT_PEER_VALID, &xprti->xpt_flags)) { + xprt = xprti; + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_get(xprt); + break; + } + } } spin_unlock_bh(&serv->sv_lock); @@ -571,36 +650,30 @@ static void svc_check_conn_limits(struct svc_serv *serv) } } -int svc_alloc_arg(struct svc_rqst *rqstp) +static bool svc_alloc_arg(struct svc_rqst *rqstp) { - struct svc_serv *serv = rqstp->rq_server; - struct xdr_buf *arg; - int pages; - int i; + struct xdr_buf *arg = &rqstp->rq_arg; + unsigned long pages, filled, ret; + + pages = rqstp->rq_maxpages; + for (filled = 0; filled < pages; filled = ret) { + ret = alloc_pages_bulk(GFP_KERNEL, pages, rqstp->rq_pages); + if (ret > filled) + /* Made progress, don't sleep yet */ + continue; - /* now allocate needed pages. If we get a failure, sleep briefly */ - pages = (serv->sv_max_mesg + PAGE_SIZE) / PAGE_SIZE; - WARN_ON_ONCE(pages >= RPCSVC_MAXPAGES); - if (pages >= RPCSVC_MAXPAGES) - /* use as many pages as possible */ - pages = RPCSVC_MAXPAGES - 1; - for (i = 0; i < pages ; i++) - while (rqstp->rq_pages[i] == NULL) { - struct page *p = alloc_page(GFP_KERNEL); - if (!p) { - set_current_state(TASK_INTERRUPTIBLE); - if (signalled() || kthread_should_stop()) { - set_current_state(TASK_RUNNING); - return -EINTR; - } - schedule_timeout(msecs_to_jiffies(500)); - } - rqstp->rq_pages[i] = p; + set_current_state(TASK_IDLE); + if (svc_thread_should_stop(rqstp)) { + set_current_state(TASK_RUNNING); + return false; } - rqstp->rq_pages[i++] = NULL; /* this might be seen in nfs_read_actor */ + trace_svc_alloc_arg_err(pages, ret); + memalloc_retry_wait(GFP_KERNEL); + } + rqstp->rq_page_end = &rqstp->rq_pages[pages]; + rqstp->rq_pages[pages] = NULL; /* this might be seen in nfsd_splice_actor() */ /* Make arg->head point to first page and arg->pages point to rest */ - arg = &rqstp->rq_arg; arg->head[0].iov_base = page_address(rqstp->rq_pages[0]); arg->head[0].iov_len = PAGE_SIZE; arg->pages = rqstp->rq_pages + 1; @@ -609,89 +682,69 @@ int svc_alloc_arg(struct svc_rqst *rqstp) arg->page_len = (pages-2)*PAGE_SIZE; arg->len = (pages-1)*PAGE_SIZE; arg->tail[0].iov_len = 0; - return 0; + + rqstp->rq_xid = xdr_zero; + return true; } -struct svc_xprt *svc_get_next_xprt(struct svc_rqst *rqstp, long timeout) +static bool +svc_thread_should_sleep(struct svc_rqst *rqstp) { - struct svc_xprt *xprt; struct svc_pool *pool = rqstp->rq_pool; - DECLARE_WAITQUEUE(wait, current); - long time_left; - - /* Normally we will wait up to 5 seconds for any required - * cache information to be provided. - */ - rqstp->rq_chandle.thread_wait = 5*HZ; - - spin_lock_bh(&pool->sp_lock); - xprt = svc_xprt_dequeue(pool); - if (xprt) { - rqstp->rq_xprt = xprt; - svc_xprt_get(xprt); - /* As there is a shortage of threads and this request - * had to be queued, don't allow the thread to wait so - * long for cache updates. - */ - rqstp->rq_chandle.thread_wait = 1*HZ; - pool->sp_task_pending = 0; - } else { - if (pool->sp_task_pending) { - pool->sp_task_pending = 0; - spin_unlock_bh(&pool->sp_lock); - return ERR_PTR(-EAGAIN); - } - /* No data pending. Go to sleep */ - svc_thread_enqueue(pool, rqstp); + /* did someone call svc_wake_up? */ + if (test_bit(SP_TASK_PENDING, &pool->sp_flags)) + return false; - /* - * We have to be able to interrupt this wait - * to bring down the daemons ... - */ - set_current_state(TASK_INTERRUPTIBLE); - - /* - * checking kthread_should_stop() here allows us to avoid - * locking and signalling when stopping kthreads that call - * svc_recv. If the thread has already been woken up, then - * we can exit here without sleeping. If not, then it - * it'll be woken up quickly during the schedule_timeout - */ - if (kthread_should_stop()) { - set_current_state(TASK_RUNNING); - spin_unlock_bh(&pool->sp_lock); - return ERR_PTR(-EINTR); - } + /* was a socket queued? */ + if (!lwq_empty(&pool->sp_xprts)) + return false; - add_wait_queue(&rqstp->rq_wait, &wait); - spin_unlock_bh(&pool->sp_lock); + /* are we shutting down? */ + if (svc_thread_should_stop(rqstp)) + return false; - time_left = schedule_timeout(timeout); - - try_to_freeze(); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + if (svc_is_backchannel(rqstp)) { + if (!lwq_empty(&rqstp->rq_server->sv_cb_list)) + return false; + } +#endif - spin_lock_bh(&pool->sp_lock); - remove_wait_queue(&rqstp->rq_wait, &wait); - if (!time_left) - pool->sp_stats.threads_timedout++; + return true; +} - xprt = rqstp->rq_xprt; - if (!xprt) { - svc_thread_dequeue(pool, rqstp); - spin_unlock_bh(&pool->sp_lock); - dprintk("svc: server %p, no data yet\n", rqstp); - if (signalled() || kthread_should_stop()) - return ERR_PTR(-EINTR); - else - return ERR_PTR(-EAGAIN); +static void svc_thread_wait_for_work(struct svc_rqst *rqstp) +{ + struct svc_pool *pool = rqstp->rq_pool; + + if (svc_thread_should_sleep(rqstp)) { + set_current_state(TASK_IDLE | TASK_FREEZABLE); + llist_add(&rqstp->rq_idle, &pool->sp_idle_threads); + if (likely(svc_thread_should_sleep(rqstp))) + schedule(); + + while (!llist_del_first_this(&pool->sp_idle_threads, + &rqstp->rq_idle)) { + /* Work just became available. This thread can only + * handle it after removing rqstp from the idle + * list. If that attempt failed, some other thread + * must have queued itself after finding no + * work to do, so that thread has taken responsibly + * for this new work. This thread can safely sleep + * until woken again. + */ + schedule(); + set_current_state(TASK_IDLE | TASK_FREEZABLE); } + __set_current_state(TASK_RUNNING); + } else { + cond_resched(); } - spin_unlock_bh(&pool->sp_lock); - return xprt; + try_to_freeze(); } -void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) +static void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) { spin_lock_bh(&serv->sv_lock); set_bit(XPT_TEMP, &newxpt->xpt_flags); @@ -699,8 +752,7 @@ void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) serv->sv_tmpcnt++; if (serv->sv_temptimer.function == NULL) { /* setup timer to age temp transports */ - setup_timer(&serv->sv_temptimer, svc_age_temp_xprts, - (unsigned long)serv); + serv->sv_temptimer.function = svc_age_temp_xprts; mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ); } @@ -708,16 +760,17 @@ void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) svc_xprt_received(newxpt); } -static int svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt) +static void svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt) { struct svc_serv *serv = rqstp->rq_server; int len = 0; if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) { - dprintk("svc_recv: found XPT_CLOSE\n"); + if (test_and_clear_bit(XPT_KILL_TEMP, &xprt->xpt_flags)) + xprt->xpt_ops->xpo_kill_temp_xprt(xprt); svc_delete_xprt(xprt); /* Leave XPT_BUSY set on the dead xprt: */ - return 0; + goto out; } if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) { struct svc_xprt *newxpt; @@ -728,139 +781,148 @@ static int svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt) __module_get(xprt->xpt_class->xcl_owner); svc_check_conn_limits(xprt->xpt_server); newxpt = xprt->xpt_ops->xpo_accept(xprt); - if (newxpt) + if (newxpt) { + newxpt->xpt_cred = get_cred(xprt->xpt_cred); svc_add_new_temp_xprt(serv, newxpt); - } else if (xprt->xpt_ops->xpo_has_wspace(xprt)) { + trace_svc_xprt_accept(newxpt, serv->sv_name); + } else { + module_put(xprt->xpt_class->xcl_owner); + } + svc_xprt_received(xprt); + } else if (test_bit(XPT_HANDSHAKE, &xprt->xpt_flags)) { + xprt->xpt_ops->xpo_handshake(xprt); + svc_xprt_received(xprt); + } else if (svc_xprt_reserve_slot(rqstp, xprt)) { /* XPT_DATA|XPT_DEFERRED case: */ - dprintk("svc: server %p, pool %u, transport %p, inuse=%d\n", - rqstp, rqstp->rq_pool->sp_id, xprt, - atomic_read(&xprt->xpt_ref.refcount)); rqstp->rq_deferred = svc_deferred_dequeue(xprt); if (rqstp->rq_deferred) len = svc_deferred_recv(rqstp); else len = xprt->xpt_ops->xpo_recvfrom(rqstp); - dprintk("svc: got len=%d\n", len); rqstp->rq_reserved = serv->sv_max_mesg; atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); - } - /* clear XPT_BUSY: */ - svc_xprt_received(xprt); - return len; -} - -/* - * Receive the next request on any transport. This code is carefully - * organised not to touch any cachelines in the shared svc_serv - * structure, only cachelines in the local svc_pool. - */ -int svc_recv(struct svc_rqst *rqstp, long timeout) -{ - struct svc_xprt *xprt = NULL; - struct svc_serv *serv = rqstp->rq_server; - int len, err; - - dprintk("svc: server %p waiting for data (to = %ld)\n", - rqstp, timeout); - - if (rqstp->rq_xprt) - printk(KERN_ERR - "svc_recv: service %p, transport not NULL!\n", - rqstp); - if (waitqueue_active(&rqstp->rq_wait)) - printk(KERN_ERR - "svc_recv: service %p, wait queue active!\n", - rqstp); - - err = svc_alloc_arg(rqstp); - if (err) - return err; - - try_to_freeze(); - cond_resched(); - if (signalled() || kthread_should_stop()) - return -EINTR; - - xprt = svc_get_next_xprt(rqstp, timeout); - if (IS_ERR(xprt)) - return PTR_ERR(xprt); + if (len <= 0) + goto out; - len = svc_handle_xprt(rqstp, xprt); + trace_svc_xdr_recvfrom(&rqstp->rq_arg); - /* No data, incomplete (TCP) read, or accept() */ - if (len <= 0) - goto out; + clear_bit(XPT_OLD, &xprt->xpt_flags); - clear_bit(XPT_OLD, &xprt->xpt_flags); + rqstp->rq_chandle.defer = svc_defer; - rqstp->rq_secure = svc_port_is_privileged(svc_addr(rqstp)); - rqstp->rq_chandle.defer = svc_defer; + if (serv->sv_stats) + serv->sv_stats->netcnt++; + percpu_counter_inc(&rqstp->rq_pool->sp_messages_arrived); + rqstp->rq_stime = ktime_get(); + svc_process(rqstp); + } else + svc_xprt_received(xprt); - if (serv->sv_stats) - serv->sv_stats->netcnt++; - return len; out: rqstp->rq_res.len = 0; svc_xprt_release(rqstp); - return -EAGAIN; } -EXPORT_SYMBOL_GPL(svc_recv); -/* - * Drop request +static void svc_thread_wake_next(struct svc_rqst *rqstp) +{ + if (!svc_thread_should_sleep(rqstp)) + /* More work pending after I dequeued some, + * wake another worker + */ + svc_pool_wake_idle_thread(rqstp->rq_pool); +} + +/** + * svc_recv - Receive and process the next request on any transport + * @rqstp: an idle RPC service thread + * + * This code is carefully organised not to touch any cachelines in + * the shared svc_serv structure, only cachelines in the local + * svc_pool. */ -void svc_drop(struct svc_rqst *rqstp) +void svc_recv(struct svc_rqst *rqstp) { - dprintk("svc: xprt %p dropped request\n", rqstp->rq_xprt); - svc_xprt_release(rqstp); + struct svc_pool *pool = rqstp->rq_pool; + + if (!svc_alloc_arg(rqstp)) + return; + + svc_thread_wait_for_work(rqstp); + + clear_bit(SP_TASK_PENDING, &pool->sp_flags); + + if (svc_thread_should_stop(rqstp)) { + svc_thread_wake_next(rqstp); + return; + } + + rqstp->rq_xprt = svc_xprt_dequeue(pool); + if (rqstp->rq_xprt) { + struct svc_xprt *xprt = rqstp->rq_xprt; + + svc_thread_wake_next(rqstp); + /* Normally we will wait up to 5 seconds for any required + * cache information to be provided. When there are no + * idle threads, we reduce the wait time. + */ + if (pool->sp_idle_threads.first) + rqstp->rq_chandle.thread_wait = 5 * HZ; + else + rqstp->rq_chandle.thread_wait = 1 * HZ; + + trace_svc_xprt_dequeue(rqstp); + svc_handle_xprt(rqstp, xprt); + } + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + if (svc_is_backchannel(rqstp)) { + struct svc_serv *serv = rqstp->rq_server; + struct rpc_rqst *req; + + req = lwq_dequeue(&serv->sv_cb_list, + struct rpc_rqst, rq_bc_list); + if (req) { + svc_thread_wake_next(rqstp); + svc_process_bc(req, rqstp); + } + } +#endif } -EXPORT_SYMBOL_GPL(svc_drop); +EXPORT_SYMBOL_GPL(svc_recv); -/* - * Return reply to client. +/** + * svc_send - Return reply to client + * @rqstp: RPC transaction context + * */ -int svc_send(struct svc_rqst *rqstp) +void svc_send(struct svc_rqst *rqstp) { struct svc_xprt *xprt; - int len; struct xdr_buf *xb; + int status; xprt = rqstp->rq_xprt; - if (!xprt) - return -EFAULT; - - /* release the receive skb before sending the reply */ - rqstp->rq_xprt->xpt_ops->xpo_release_rqst(rqstp); /* calculate over-all length */ xb = &rqstp->rq_res; xb->len = xb->head[0].iov_len + xb->page_len + xb->tail[0].iov_len; + trace_svc_xdr_sendto(rqstp->rq_xid, xb); + trace_svc_stats_latency(rqstp); - /* Grab mutex to serialize outgoing data. */ - mutex_lock(&xprt->xpt_mutex); - if (test_bit(XPT_DEAD, &xprt->xpt_flags) - || test_bit(XPT_CLOSE, &xprt->xpt_flags)) - len = -ENOTCONN; - else - len = xprt->xpt_ops->xpo_sendto(rqstp); - mutex_unlock(&xprt->xpt_mutex); - rpc_wake_up(&xprt->xpt_bc_pending); - svc_xprt_release(rqstp); + status = xprt->xpt_ops->xpo_sendto(rqstp); - if (len == -ECONNREFUSED || len == -ENOTCONN || len == -EAGAIN) - return 0; - return len; + trace_svc_send(rqstp, status); } /* * Timer function to close old temporary transports, using * a mark-and-sweep algorithm. */ -static void svc_age_temp_xprts(unsigned long closure) +static void svc_age_temp_xprts(struct timer_list *t) { - struct svc_serv *serv = (struct svc_serv *)closure; + struct svc_serv *serv = timer_container_of(serv, t, sv_temptimer); struct svc_xprt *xprt; struct list_head *le, *next; @@ -880,12 +942,11 @@ static void svc_age_temp_xprts(unsigned long closure) * through, close it. */ if (!test_and_set_bit(XPT_OLD, &xprt->xpt_flags)) continue; - if (atomic_read(&xprt->xpt_ref.refcount) > 1 || + if (kref_read(&xprt->xpt_ref) > 1 || test_bit(XPT_BUSY, &xprt->xpt_flags)) continue; list_del_init(le); set_bit(XPT_CLOSE, &xprt->xpt_flags); - set_bit(XPT_DETACHED, &xprt->xpt_flags); dprintk("queuing xprt %p for closing\n", xprt); /* a thread will dequeue and close it soon */ @@ -896,6 +957,42 @@ static void svc_age_temp_xprts(unsigned long closure) mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ); } +/* Close temporary transports whose xpt_local matches server_addr immediately + * instead of waiting for them to be picked up by the timer. + * + * This is meant to be called from a notifier_block that runs when an ip + * address is deleted. + */ +void svc_age_temp_xprts_now(struct svc_serv *serv, struct sockaddr *server_addr) +{ + struct svc_xprt *xprt; + struct list_head *le, *next; + LIST_HEAD(to_be_closed); + + spin_lock_bh(&serv->sv_lock); + list_for_each_safe(le, next, &serv->sv_tempsocks) { + xprt = list_entry(le, struct svc_xprt, xpt_list); + if (rpc_cmp_addr(server_addr, (struct sockaddr *) + &xprt->xpt_local)) { + dprintk("svc_age_temp_xprts_now: found %p\n", xprt); + list_move(le, &to_be_closed); + } + } + spin_unlock_bh(&serv->sv_lock); + + while (!list_empty(&to_be_closed)) { + le = to_be_closed.next; + list_del_init(le); + xprt = list_entry(le, struct svc_xprt, xpt_list); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + set_bit(XPT_KILL_TEMP, &xprt->xpt_flags); + dprintk("svc_age_temp_xprts_now: queuing xprt %p for closing\n", + xprt); + svc_xprt_enqueue(xprt); + } +} +EXPORT_SYMBOL_GPL(svc_age_temp_xprts_now); + static void call_xpt_users(struct svc_xprt *xprt) { struct svc_xpt_user *u; @@ -903,7 +1000,7 @@ static void call_xpt_users(struct svc_xprt *xprt) spin_lock(&xprt->xpt_lock); while (!list_empty(&xprt->xpt_users)) { u = list_first_entry(&xprt->xpt_users, struct svc_xpt_user, list); - list_del(&u->list); + list_del_init(&u->list); u->callback(u); } spin_unlock(&xprt->xpt_lock); @@ -917,30 +1014,49 @@ static void svc_delete_xprt(struct svc_xprt *xprt) struct svc_serv *serv = xprt->xpt_server; struct svc_deferred_req *dr; - /* Only do this once */ + /* unregister with rpcbind for when transport type is TCP or UDP. + */ + if (test_bit(XPT_RPCB_UNREG, &xprt->xpt_flags)) { + struct svc_sock *svsk = container_of(xprt, struct svc_sock, + sk_xprt); + struct socket *sock = svsk->sk_sock; + + if (svc_register(serv, xprt->xpt_net, sock->sk->sk_family, + sock->sk->sk_protocol, 0) < 0) + pr_warn("failed to unregister %s with rpcbind\n", + xprt->xpt_class->xcl_name); + } + if (test_and_set_bit(XPT_DEAD, &xprt->xpt_flags)) - BUG(); + return; - dprintk("svc: svc_delete_xprt(%p)\n", xprt); + trace_svc_xprt_detach(xprt); xprt->xpt_ops->xpo_detach(xprt); + if (xprt->xpt_bc_xprt) + xprt->xpt_bc_xprt->ops->close(xprt->xpt_bc_xprt); spin_lock_bh(&serv->sv_lock); - if (!test_and_set_bit(XPT_DETACHED, &xprt->xpt_flags)) - list_del_init(&xprt->xpt_list); - WARN_ON_ONCE(!list_empty(&xprt->xpt_ready)); - if (test_bit(XPT_TEMP, &xprt->xpt_flags)) + list_del_init(&xprt->xpt_list); + if (test_bit(XPT_TEMP, &xprt->xpt_flags) && + !test_bit(XPT_PEER_VALID, &xprt->xpt_flags)) serv->sv_tmpcnt--; spin_unlock_bh(&serv->sv_lock); while ((dr = svc_deferred_dequeue(xprt)) != NULL) - kfree(dr); + free_deferred(xprt, dr); call_xpt_users(xprt); svc_xprt_put(xprt); } -void svc_close_xprt(struct svc_xprt *xprt) +/** + * svc_xprt_close - Close a client connection + * @xprt: transport to disconnect + * + */ +void svc_xprt_close(struct svc_xprt *xprt) { + trace_svc_xprt_close(xprt); set_bit(XPT_CLOSE, &xprt->xpt_flags); if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) /* someone else will have to effect the close */ @@ -953,14 +1069,14 @@ void svc_close_xprt(struct svc_xprt *xprt) */ svc_delete_xprt(xprt); } -EXPORT_SYMBOL_GPL(svc_close_xprt); +EXPORT_SYMBOL_GPL(svc_xprt_close); static int svc_close_list(struct svc_serv *serv, struct list_head *xprt_list, struct net *net) { struct svc_xprt *xprt; int ret = 0; - spin_lock(&serv->sv_lock); + spin_lock_bh(&serv->sv_lock); list_for_each_entry(xprt, xprt_list, xpt_list) { if (xprt->xpt_net != net) continue; @@ -968,44 +1084,39 @@ static int svc_close_list(struct svc_serv *serv, struct list_head *xprt_list, st set_bit(XPT_CLOSE, &xprt->xpt_flags); svc_xprt_enqueue(xprt); } - spin_unlock(&serv->sv_lock); + spin_unlock_bh(&serv->sv_lock); return ret; } -static struct svc_xprt *svc_dequeue_net(struct svc_serv *serv, struct net *net) +static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net) { - struct svc_pool *pool; struct svc_xprt *xprt; - struct svc_xprt *tmp; int i; for (i = 0; i < serv->sv_nrpools; i++) { - pool = &serv->sv_pools[i]; - - spin_lock_bh(&pool->sp_lock); - list_for_each_entry_safe(xprt, tmp, &pool->sp_sockets, xpt_ready) { - if (xprt->xpt_net != net) - continue; - list_del_init(&xprt->xpt_ready); - spin_unlock_bh(&pool->sp_lock); - return xprt; + struct svc_pool *pool = &serv->sv_pools[i]; + struct llist_node *q, **t1, *t2; + + q = lwq_dequeue_all(&pool->sp_xprts); + lwq_for_each_safe(xprt, t1, t2, &q, xpt_ready) { + if (xprt->xpt_net == net) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_delete_xprt(xprt); + xprt = NULL; + } } - spin_unlock_bh(&pool->sp_lock); - } - return NULL; -} - -static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net) -{ - struct svc_xprt *xprt; - while ((xprt = svc_dequeue_net(serv, net))) { - set_bit(XPT_CLOSE, &xprt->xpt_flags); - svc_delete_xprt(xprt); + if (q) + lwq_enqueue_batch(q, &pool->sp_xprts); } } -/* +/** + * svc_xprt_destroy_all - Destroy transports associated with @serv + * @serv: RPC service to be shut down + * @net: target network namespace + * @unregister: true if it is OK to unregister the destroyed xprts + * * Server threads may still be running (especially in the case where the * service is still running in other network namespaces). * @@ -1017,7 +1128,8 @@ static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net) * threads, we may need to wait a little while and then check again to * see if they're done. */ -void svc_close_net(struct svc_serv *serv, struct net *net) +void svc_xprt_destroy_all(struct svc_serv *serv, struct net *net, + bool unregister) { int delay = 0; @@ -1027,7 +1139,11 @@ void svc_close_net(struct svc_serv *serv, struct net *net) svc_clean_up_xprts(serv, net); msleep(delay++); } + + if (unregister) + svc_rpcb_cleanup(serv, net); } +EXPORT_SYMBOL_GPL(svc_xprt_destroy_all); /* * Handle defer and revisit of requests @@ -1043,15 +1159,15 @@ static void svc_revisit(struct cache_deferred_req *dreq, int too_many) set_bit(XPT_DEFERRED, &xprt->xpt_flags); if (too_many || test_bit(XPT_DEAD, &xprt->xpt_flags)) { spin_unlock(&xprt->xpt_lock); - dprintk("revisit canceled\n"); + trace_svc_defer_drop(dr); + free_deferred(xprt, dr); svc_xprt_put(xprt); - kfree(dr); return; } - dprintk("revisit queued\n"); dr->xprt = NULL; list_add(&dr->handle.recent, &xprt->xpt_deferred); spin_unlock(&xprt->xpt_lock); + trace_svc_defer_queue(dr); svc_xprt_enqueue(xprt); svc_xprt_put(xprt); } @@ -1070,7 +1186,7 @@ static struct cache_deferred_req *svc_defer(struct cache_req *req) struct svc_rqst *rqstp = container_of(req, struct svc_rqst, rq_chandle); struct svc_deferred_req *dr; - if (rqstp->rq_arg.page_len || !rqstp->rq_usedeferral) + if (rqstp->rq_arg.page_len || !test_bit(RQ_USEDEFERRAL, &rqstp->rq_flags)) return NULL; /* if more than a page, give up FIXME */ if (rqstp->rq_deferred) { dr = rqstp->rq_deferred; @@ -1090,16 +1206,18 @@ static struct cache_deferred_req *svc_defer(struct cache_req *req) dr->addrlen = rqstp->rq_addrlen; dr->daddr = rqstp->rq_daddr; dr->argslen = rqstp->rq_arg.len >> 2; - dr->xprt_hlen = rqstp->rq_xprt_hlen; /* back up head to the start of the buffer and copy */ skip = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; memcpy(dr->args, rqstp->rq_arg.head[0].iov_base - skip, dr->argslen << 2); } + dr->xprt_ctxt = rqstp->rq_xprt_ctxt; + rqstp->rq_xprt_ctxt = NULL; + trace_svc_defer(rqstp); svc_xprt_get(rqstp->rq_xprt); dr->xprt = rqstp->rq_xprt; - rqstp->rq_dropme = true; + set_bit(RQ_DROPME, &rqstp->rq_flags); dr->handle.revisit = svc_revisit; return &dr->handle; @@ -1108,25 +1226,30 @@ static struct cache_deferred_req *svc_defer(struct cache_req *req) /* * recv data from a deferred request into an active one */ -static int svc_deferred_recv(struct svc_rqst *rqstp) +static noinline int svc_deferred_recv(struct svc_rqst *rqstp) { struct svc_deferred_req *dr = rqstp->rq_deferred; + trace_svc_defer_recv(dr); + /* setup iov_base past transport header */ - rqstp->rq_arg.head[0].iov_base = dr->args + (dr->xprt_hlen>>2); + rqstp->rq_arg.head[0].iov_base = dr->args; /* The iov_len does not include the transport header bytes */ - rqstp->rq_arg.head[0].iov_len = (dr->argslen<<2) - dr->xprt_hlen; + rqstp->rq_arg.head[0].iov_len = dr->argslen << 2; rqstp->rq_arg.page_len = 0; /* The rq_arg.len includes the transport header bytes */ - rqstp->rq_arg.len = dr->argslen<<2; + rqstp->rq_arg.len = dr->argslen << 2; rqstp->rq_prot = dr->prot; memcpy(&rqstp->rq_addr, &dr->addr, dr->addrlen); rqstp->rq_addrlen = dr->addrlen; /* Save off transport header len in case we get deferred again */ - rqstp->rq_xprt_hlen = dr->xprt_hlen; rqstp->rq_daddr = dr->daddr; rqstp->rq_respages = rqstp->rq_pages; - return (dr->argslen<<2) - dr->xprt_hlen; + rqstp->rq_xprt_ctxt = dr->xprt_ctxt; + + dr->xprt_ctxt = NULL; + svc_xprt_received(rqstp->rq_xprt); + return dr->argslen << 2; } @@ -1149,6 +1272,40 @@ static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) } /** + * svc_find_listener - find an RPC transport instance + * @serv: pointer to svc_serv to search + * @xcl_name: C string containing transport's class name + * @net: owner net pointer + * @sa: sockaddr containing address + * + * Return the transport instance pointer for the endpoint accepting + * connections/peer traffic from the specified transport class, + * and matching sockaddr. + */ +struct svc_xprt *svc_find_listener(struct svc_serv *serv, const char *xcl_name, + struct net *net, const struct sockaddr *sa) +{ + struct svc_xprt *xprt; + struct svc_xprt *found = NULL; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + if (xprt->xpt_net != net) + continue; + if (strcmp(xprt->xpt_class->xcl_name, xcl_name)) + continue; + if (!rpc_cmp_addr_port(sa, (struct sockaddr *)&xprt->xpt_local)) + continue; + found = xprt; + svc_xprt_get(xprt); + break; + } + spin_unlock_bh(&serv->sv_lock); + return found; +} +EXPORT_SYMBOL_GPL(svc_find_listener); + +/** * svc_find_xprt - find an RPC transport instance * @serv: pointer to svc_serv to search * @xcl_name: C string containing transport's class name @@ -1251,29 +1408,36 @@ int svc_xprt_names(struct svc_serv *serv, char *buf, const int buflen) } EXPORT_SYMBOL_GPL(svc_xprt_names); - /*----------------------------------------------------------------------------*/ static void *svc_pool_stats_start(struct seq_file *m, loff_t *pos) { unsigned int pidx = (unsigned int)*pos; - struct svc_serv *serv = m->private; + struct svc_info *si = m->private; dprintk("svc_pool_stats_start, *pidx=%u\n", pidx); + mutex_lock(si->mutex); + if (!pidx) return SEQ_START_TOKEN; - return (pidx > serv->sv_nrpools ? NULL : &serv->sv_pools[pidx-1]); + if (!si->serv) + return NULL; + return pidx > si->serv->sv_nrpools ? NULL + : &si->serv->sv_pools[pidx - 1]; } static void *svc_pool_stats_next(struct seq_file *m, void *p, loff_t *pos) { struct svc_pool *pool = p; - struct svc_serv *serv = m->private; + struct svc_info *si = m->private; + struct svc_serv *serv = si->serv; dprintk("svc_pool_stats_next, *pos=%llu\n", *pos); - if (p == SEQ_START_TOKEN) { + if (!serv) { + pool = NULL; + } else if (p == SEQ_START_TOKEN) { pool = &serv->sv_pools[0]; } else { unsigned int pidx = (pool - &serv->sv_pools[0]); @@ -1288,6 +1452,9 @@ static void *svc_pool_stats_next(struct seq_file *m, void *p, loff_t *pos) static void svc_pool_stats_stop(struct seq_file *m, void *p) { + struct svc_info *si = m->private; + + mutex_unlock(si->mutex); } static int svc_pool_stats_show(struct seq_file *m, void *p) @@ -1299,12 +1466,11 @@ static int svc_pool_stats_show(struct seq_file *m, void *p) return 0; } - seq_printf(m, "%u %lu %lu %lu %lu\n", - pool->sp_id, - pool->sp_stats.packets, - pool->sp_stats.sockets_queued, - pool->sp_stats.threads_woken, - pool->sp_stats.threads_timedout); + seq_printf(m, "%u %llu %llu %llu 0\n", + pool->sp_id, + percpu_counter_sum_positive(&pool->sp_messages_arrived), + percpu_counter_sum_positive(&pool->sp_sockets_queued), + percpu_counter_sum_positive(&pool->sp_threads_woken)); return 0; } @@ -1316,14 +1482,18 @@ static const struct seq_operations svc_pool_stats_seq_ops = { .show = svc_pool_stats_show, }; -int svc_pool_stats_open(struct svc_serv *serv, struct file *file) +int svc_pool_stats_open(struct svc_info *info, struct file *file) { + struct seq_file *seq; int err; err = seq_open(file, &svc_pool_stats_seq_ops); - if (!err) - ((struct seq_file *) file->private_data)->private = serv; - return err; + if (err) + return err; + seq = file->private_data; + seq->private = info; + + return 0; } EXPORT_SYMBOL(svc_pool_stats_open); diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c index 2af7b0cba43a..55b4d2874188 100644 --- a/net/sunrpc/svcauth.c +++ b/net/sunrpc/svcauth.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/svcauth.c * @@ -17,6 +18,11 @@ #include <linux/sunrpc/svcauth.h> #include <linux/err.h> #include <linux/hash.h> +#include <linux/user_namespace.h> + +#include <trace/events/sunrpc.h> + +#include "sunrpc.h" #define RPCDBG_FACILITY RPCDBG_AUTH @@ -26,48 +32,96 @@ */ extern struct auth_ops svcauth_null; extern struct auth_ops svcauth_unix; +extern struct auth_ops svcauth_tls; -static DEFINE_SPINLOCK(authtab_lock); -static struct auth_ops *authtab[RPC_AUTH_MAXFLAVOR] = { - [0] = &svcauth_null, - [1] = &svcauth_unix, +static struct auth_ops __rcu *authtab[RPC_AUTH_MAXFLAVOR] = { + [RPC_AUTH_NULL] = (struct auth_ops __force __rcu *)&svcauth_null, + [RPC_AUTH_UNIX] = (struct auth_ops __force __rcu *)&svcauth_unix, + [RPC_AUTH_TLS] = (struct auth_ops __force __rcu *)&svcauth_tls, }; -int -svc_authenticate(struct svc_rqst *rqstp, __be32 *authp) +static struct auth_ops * +svc_get_auth_ops(rpc_authflavor_t flavor) { - rpc_authflavor_t flavor; struct auth_ops *aops; - *authp = rpc_auth_ok; + if (flavor >= RPC_AUTH_MAXFLAVOR) + return NULL; + rcu_read_lock(); + aops = rcu_dereference(authtab[flavor]); + if (aops != NULL && !try_module_get(aops->owner)) + aops = NULL; + rcu_read_unlock(); + return aops; +} + +static void +svc_put_auth_ops(struct auth_ops *aops) +{ + module_put(aops->owner); +} + +/** + * svc_authenticate - Initialize an outgoing credential + * @rqstp: RPC execution context + * + * Return values: + * %SVC_OK: XDR encoding of the result can begin + * %SVC_DENIED: Credential or verifier is not valid + * %SVC_GARBAGE: Failed to decode credential or verifier + * %SVC_COMPLETE: GSS context lifetime event; no further action + * %SVC_DROP: Drop this request; no further action + * %SVC_CLOSE: Like drop, but also close transport connection + */ +enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp) +{ + struct auth_ops *aops; + u32 flavor; - flavor = svc_getnl(&rqstp->rq_arg.head[0]); + rqstp->rq_auth_stat = rpc_auth_ok; - dprintk("svc: svc_authenticate (%d)\n", flavor); + /* + * Decode the Call credential's flavor field. The credential's + * body field is decoded in the chosen ->accept method below. + */ + if (xdr_stream_decode_u32(&rqstp->rq_arg_stream, &flavor) < 0) + return SVC_GARBAGE; - spin_lock(&authtab_lock); - if (flavor >= RPC_AUTH_MAXFLAVOR || !(aops = authtab[flavor]) || - !try_module_get(aops->owner)) { - spin_unlock(&authtab_lock); - *authp = rpc_autherr_badcred; + aops = svc_get_auth_ops(flavor); + if (aops == NULL) { + rqstp->rq_auth_stat = rpc_autherr_badcred; return SVC_DENIED; } - spin_unlock(&authtab_lock); + + rqstp->rq_auth_slack = 0; + init_svc_cred(&rqstp->rq_cred); rqstp->rq_authop = aops; - return aops->accept(rqstp, authp); + return aops->accept(rqstp); } -EXPORT_SYMBOL_GPL(svc_authenticate); -int svc_set_client(struct svc_rqst *rqstp) +/** + * svc_set_client - Assign an appropriate 'auth_domain' as the client + * @rqstp: RPC execution context + * + * Return values: + * %SVC_OK: Client was found and assigned + * %SVC_DENY: Client was explicitly denied + * %SVC_DROP: Ignore this request + * %SVC_CLOSE: Ignore this request and close the connection + */ +enum svc_auth_status svc_set_client(struct svc_rqst *rqstp) { + rqstp->rq_client = NULL; return rqstp->rq_authop->set_client(rqstp); } EXPORT_SYMBOL_GPL(svc_set_client); -/* A request, which was authenticated, has now executed. - * Time to finalise the credentials and verifier - * and release and resources +/** + * svc_authorise - Finalize credentials/verifier and release resources + * @rqstp: RPC execution context + * + * Returns zero on success, or a negative errno. */ int svc_authorise(struct svc_rqst *rqstp) { @@ -78,7 +132,7 @@ int svc_authorise(struct svc_rqst *rqstp) if (aops) { rv = aops->release(rqstp); - module_put(aops->owner); + svc_put_auth_ops(aops); } return rv; } @@ -86,13 +140,14 @@ int svc_authorise(struct svc_rqst *rqstp) int svc_auth_register(rpc_authflavor_t flavor, struct auth_ops *aops) { + struct auth_ops *old; int rv = -EINVAL; - spin_lock(&authtab_lock); - if (flavor < RPC_AUTH_MAXFLAVOR && authtab[flavor] == NULL) { - authtab[flavor] = aops; - rv = 0; + + if (flavor < RPC_AUTH_MAXFLAVOR) { + old = cmpxchg((struct auth_ops ** __force)&authtab[flavor], NULL, aops); + if (old == NULL || old == aops) + rv = 0; } - spin_unlock(&authtab_lock); return rv; } EXPORT_SYMBOL_GPL(svc_auth_register); @@ -100,13 +155,54 @@ EXPORT_SYMBOL_GPL(svc_auth_register); void svc_auth_unregister(rpc_authflavor_t flavor) { - spin_lock(&authtab_lock); if (flavor < RPC_AUTH_MAXFLAVOR) - authtab[flavor] = NULL; - spin_unlock(&authtab_lock); + rcu_assign_pointer(authtab[flavor], NULL); } EXPORT_SYMBOL_GPL(svc_auth_unregister); +/** + * svc_auth_flavor - return RPC transaction's RPC_AUTH flavor + * @rqstp: RPC transaction context + * + * Returns an RPC flavor or GSS pseudoflavor. + */ +rpc_authflavor_t svc_auth_flavor(struct svc_rqst *rqstp) +{ + struct auth_ops *aops = rqstp->rq_authop; + + if (!aops->pseudoflavor) + return aops->flavour; + return aops->pseudoflavor(rqstp); +} +EXPORT_SYMBOL_GPL(svc_auth_flavor); + +/** + * svcauth_map_clnt_to_svc_cred_local - maps a generic cred + * to a svc_cred suitable for use in nfsd. + * @clnt: rpc_clnt associated with nfs client + * @cred: generic cred associated with nfs client + * @svc: returned svc_cred that is suitable for use in nfsd + */ +void svcauth_map_clnt_to_svc_cred_local(struct rpc_clnt *clnt, + const struct cred *cred, + struct svc_cred *svc) +{ + struct user_namespace *userns = clnt->cl_cred ? + clnt->cl_cred->user_ns : &init_user_ns; + + memset(svc, 0, sizeof(struct svc_cred)); + + svc->cr_uid = KUIDT_INIT(from_kuid_munged(userns, cred->fsuid)); + svc->cr_gid = KGIDT_INIT(from_kgid_munged(userns, cred->fsgid)); + svc->cr_flavor = clnt->cl_auth->au_flavor; + if (cred->group_info) + svc->cr_group_info = get_group_info(cred->group_info); + /* These aren't relevant for local (network is bypassed) */ + svc->cr_principal = NULL; + svc->cr_gss_mech = NULL; +} +EXPORT_SYMBOL_GPL(svcauth_map_clnt_to_svc_cred_local); + /************************************************** * 'auth_domains' are stored in a hash table indexed by name. * When the last reference to an 'auth_domain' is dropped, @@ -120,16 +216,21 @@ EXPORT_SYMBOL_GPL(svc_auth_unregister); #define DN_HASHMAX (1<<DN_HASHBITS) static struct hlist_head auth_domain_table[DN_HASHMAX]; -static spinlock_t auth_domain_lock = - __SPIN_LOCK_UNLOCKED(auth_domain_lock); +static DEFINE_SPINLOCK(auth_domain_lock); + +static void auth_domain_release(struct kref *kref) + __releases(&auth_domain_lock) +{ + struct auth_domain *dom = container_of(kref, struct auth_domain, ref); + + hlist_del_rcu(&dom->hash); + dom->flavour->domain_release(dom); + spin_unlock(&auth_domain_lock); +} void auth_domain_put(struct auth_domain *dom) { - if (atomic_dec_and_lock(&dom->ref.refcount, &auth_domain_lock)) { - hlist_del(&dom->hash); - dom->flavour->domain_release(dom); - spin_unlock(&auth_domain_lock); - } + kref_put_lock(&dom->ref, auth_domain_release, &auth_domain_lock); } EXPORT_SYMBOL_GPL(auth_domain_put); @@ -151,7 +252,7 @@ auth_domain_lookup(char *name, struct auth_domain *new) } } if (new) - hlist_add_head(&new->hash, head); + hlist_add_head_rcu(&new->hash, head); spin_unlock(&auth_domain_lock); return new; } @@ -159,6 +260,44 @@ EXPORT_SYMBOL_GPL(auth_domain_lookup); struct auth_domain *auth_domain_find(char *name) { - return auth_domain_lookup(name, NULL); + struct auth_domain *hp; + struct hlist_head *head; + + head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(hp, head, hash) { + if (strcmp(hp->name, name)==0) { + if (!kref_get_unless_zero(&hp->ref)) + hp = NULL; + rcu_read_unlock(); + return hp; + } + } + rcu_read_unlock(); + return NULL; } EXPORT_SYMBOL_GPL(auth_domain_find); + +/** + * auth_domain_cleanup - check that the auth_domain table is empty + * + * On module unload the auth_domain_table must be empty. To make it + * easier to catch bugs which don't clean up domains properly, we + * warn if anything remains in the table at cleanup time. + * + * Note that we cannot proactively remove the domains at this stage. + * The ->release() function might be in a module that has already been + * unloaded. + */ + +void auth_domain_cleanup(void) +{ + int h; + struct auth_domain *hp; + + for (h = 0; h < DN_HASHMAX; h++) + hlist_for_each_entry(hp, &auth_domain_table[h], hash) + pr_warn("svc: domain %s still present at module unload.\n", + hp->name); +} diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c index 621ca7b4a155..8ca98b146ec8 100644 --- a/net/sunrpc/svcauth_unix.c +++ b/net/sunrpc/svcauth_unix.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only #include <linux/types.h> #include <linux/sched.h> #include <linux/module.h> @@ -16,8 +17,9 @@ #include <net/ipv6.h> #include <linux/kernel.h> #include <linux/user_namespace.h> -#define RPCDBG_FACILITY RPCDBG_AUTH +#include <trace/events/sunrpc.h> +#define RPCDBG_FACILITY RPCDBG_AUTH #include "netns.h" @@ -36,21 +38,28 @@ struct unix_domain { extern struct auth_ops svcauth_null; extern struct auth_ops svcauth_unix; +extern struct auth_ops svcauth_tls; -static void svcauth_unix_domain_release(struct auth_domain *dom) +static void svcauth_unix_domain_release_rcu(struct rcu_head *head) { + struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head); struct unix_domain *ud = container_of(dom, struct unix_domain, h); kfree(dom->name); kfree(ud); } +static void svcauth_unix_domain_release(struct auth_domain *dom) +{ + call_rcu(&dom->rcu_head, svcauth_unix_domain_release_rcu); +} + struct auth_domain *unix_domain_find(char *name) { struct auth_domain *rv; struct unix_domain *new = NULL; - rv = auth_domain_lookup(name, NULL); + rv = auth_domain_find(name); while(1) { if (rv) { if (new && rv != &new->h) @@ -91,6 +100,7 @@ struct ip_map { char m_class[8]; /* e.g. "nfsd" */ struct in6_addr m_addr; struct unix_domain *m_client; + struct rcu_head m_rcu; }; static void ip_map_put(struct kref *kref) @@ -101,7 +111,7 @@ static void ip_map_put(struct kref *kref) if (test_bit(CACHE_VALID, &item->flags) && !test_bit(CACHE_NEGATIVE, &item->flags)) auth_domain_put(&im->m_client->h); - kfree(im); + kfree_rcu(im, m_rcu); } static inline int hash_ip6(const struct in6_addr *ip) @@ -140,6 +150,11 @@ static struct cache_head *ip_map_alloc(void) return NULL; } +static int ip_map_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return sunrpc_cache_pipe_upcall(cd, h); +} + static void ip_map_request(struct cache_detail *cd, struct cache_head *h, char **bpp, int *blen) @@ -158,7 +173,7 @@ static void ip_map_request(struct cache_detail *cd, } static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, struct in6_addr *addr); -static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time_t expiry); +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time64_t expiry); static int ip_map_parse(struct cache_detail *cd, char *mesg, int mlen) @@ -179,7 +194,7 @@ static int ip_map_parse(struct cache_detail *cd, struct ip_map *ipmp; struct auth_domain *dom; - time_t expiry; + time64_t expiry; if (mesg[mlen-1] != '\n') return -EINVAL; @@ -211,9 +226,9 @@ static int ip_map_parse(struct cache_detail *cd, return -EINVAL; } - expiry = get_expiry(&mesg); - if (expiry ==0) - return -EINVAL; + err = get_expiry(&mesg, &expiry); + if (err) + return err; /* domainname, or empty for NEGATIVE */ len = qword_get(&mesg, buf, mlen); @@ -280,9 +295,9 @@ static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, strcpy(ip.m_class, class); ip.m_addr = *addr; - ch = sunrpc_cache_lookup(cd, &ip.h, - hash_str(class, IP_HASHBITS) ^ - hash_ip6(addr)); + ch = sunrpc_cache_lookup_rcu(cd, &ip.h, + hash_str(class, IP_HASHBITS) ^ + hash_ip6(addr)); if (ch) return container_of(ch, struct ip_map, h); @@ -290,17 +305,8 @@ static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, return NULL; } -static inline struct ip_map *ip_map_lookup(struct net *net, char *class, - struct in6_addr *addr) -{ - struct sunrpc_net *sn; - - sn = net_generic(net, sunrpc_net_id); - return __ip_map_lookup(sn->ip_map_cache, class, addr); -} - static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, - struct unix_domain *udom, time_t expiry) + struct unix_domain *udom, time64_t expiry) { struct ip_map ip; struct cache_head *ch; @@ -319,15 +325,6 @@ static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, return 0; } -static inline int ip_map_update(struct net *net, struct ip_map *ipm, - struct unix_domain *udom, time_t expiry) -{ - struct sunrpc_net *sn; - - sn = net_generic(net, sunrpc_net_id); - return __ip_map_update(sn->ip_map_cache, ipm, udom, expiry); -} - void svcauth_unix_purge(struct net *net) { struct sunrpc_net *sn; @@ -403,7 +400,7 @@ svcauth_unix_info_release(struct svc_xprt *xpt) /**************************************************************************** * auth.unix.gid cache * simple cache to map a UID to a list of GIDs - * because AUTH_UNIX aka AUTH_SYS has a max of 16 + * because AUTH_UNIX aka AUTH_SYS has a max of UNX_NGROUPS */ #define GID_HASHBITS 8 #define GID_HASHMAX (1<<GID_HASHBITS) @@ -412,6 +409,7 @@ struct unix_gid { struct cache_head h; kuid_t uid; struct group_info *gi; + struct rcu_head rcu; }; static int unix_gid_hash(kuid_t uid) @@ -419,16 +417,25 @@ static int unix_gid_hash(kuid_t uid) return hash_long(from_kuid(&init_user_ns, uid), GID_HASHBITS); } -static void unix_gid_put(struct kref *kref) +static void unix_gid_free(struct rcu_head *rcu) { - struct cache_head *item = container_of(kref, struct cache_head, ref); - struct unix_gid *ug = container_of(item, struct unix_gid, h); + struct unix_gid *ug = container_of(rcu, struct unix_gid, rcu); + struct cache_head *item = &ug->h; + if (test_bit(CACHE_VALID, &item->flags) && !test_bit(CACHE_NEGATIVE, &item->flags)) put_group_info(ug->gi); kfree(ug); } +static void unix_gid_put(struct kref *kref) +{ + struct cache_head *item = container_of(kref, struct cache_head, ref); + struct unix_gid *ug = container_of(item, struct unix_gid, h); + + call_rcu(&ug->rcu, unix_gid_free); +} + static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew) { struct unix_gid *orig = container_of(corig, struct unix_gid, h); @@ -458,6 +465,11 @@ static struct cache_head *unix_gid_alloc(void) return NULL; } +static int unix_gid_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return sunrpc_cache_pipe_upcall_timeout(cd, h); +} + static void unix_gid_request(struct cache_detail *cd, struct cache_head *h, char **bpp, int *blen) @@ -482,7 +494,7 @@ static int unix_gid_parse(struct cache_detail *cd, int rv; int i; int err; - time_t expiry; + time64_t expiry; struct unix_gid ug, *ugp; if (mesg[mlen - 1] != '\n') @@ -492,12 +504,12 @@ static int unix_gid_parse(struct cache_detail *cd, rv = get_int(&mesg, &id); if (rv) return -EINVAL; - uid = make_kuid(&init_user_ns, id); + uid = make_kuid(current_user_ns(), id); ug.uid = uid; - expiry = get_expiry(&mesg); - if (expiry == 0) - return -EINVAL; + err = get_expiry(&mesg, &expiry); + if (err) + return err; rv = get_int(&mesg, &gids); if (rv || gids < 0 || gids > 8192) @@ -514,12 +526,13 @@ static int unix_gid_parse(struct cache_detail *cd, err = -EINVAL; if (rv) goto out; - kgid = make_kgid(&init_user_ns, gid); + kgid = make_kgid(current_user_ns(), gid); if (!gid_valid(kgid)) goto out; - GROUP_AT(ug.gi, i) = kgid; + ug.gi->gid[i] = kgid; } + groups_sort(ug.gi); ugp = unix_gid_lookup(cd, uid); if (ugp) { struct cache_head *ch; @@ -546,7 +559,7 @@ static int unix_gid_show(struct seq_file *m, struct cache_detail *cd, struct cache_head *h) { - struct user_namespace *user_ns = &init_user_ns; + struct user_namespace *user_ns = m->file->f_cred->user_ns; struct unix_gid *ug; int i; int glen; @@ -564,16 +577,17 @@ static int unix_gid_show(struct seq_file *m, seq_printf(m, "%u %d:", from_kuid_munged(user_ns, ug->uid), glen); for (i = 0; i < glen; i++) - seq_printf(m, " %d", from_kgid_munged(user_ns, GROUP_AT(ug->gi, i))); + seq_printf(m, " %d", from_kgid_munged(user_ns, ug->gi->gid[i])); seq_printf(m, "\n"); return 0; } -static struct cache_detail unix_gid_cache_template = { +static const struct cache_detail unix_gid_cache_template = { .owner = THIS_MODULE, .hash_size = GID_HASHMAX, .name = "auth.unix.gid", .cache_put = unix_gid_put, + .cache_upcall = unix_gid_upcall, .cache_request = unix_gid_request, .cache_parse = unix_gid_parse, .cache_show = unix_gid_show, @@ -618,7 +632,7 @@ static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid) struct cache_head *ch; ug.uid = uid; - ch = sunrpc_cache_lookup(cd, &ug.h, unix_gid_hash(uid)); + ch = sunrpc_cache_lookup_rcu(cd, &ug.h, unix_gid_hash(uid)); if (ch) return container_of(ch, struct unix_gid, h); else @@ -651,7 +665,7 @@ static struct group_info *unix_gid_find(kuid_t uid, struct svc_rqst *rqstp) } } -int +enum svc_auth_status svcauth_unix_set_client(struct svc_rqst *rqstp) { struct sockaddr_in *sin; @@ -678,11 +692,13 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) rqstp->rq_client = NULL; if (rqstp->rq_proc == 0) - return SVC_OK; + goto out; + rqstp->rq_auth_stat = rpc_autherr_badcred; ipm = ip_map_cached_get(xprt); if (ipm == NULL) - ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class, + ipm = __ip_map_lookup(sn->ip_map_cache, + rqstp->rq_server->sv_programs->pg_class, &sin6->sin6_addr); if (ipm == NULL) @@ -716,33 +732,46 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) put_group_info(cred->cr_group_info); cred->cr_group_info = gi; } + +out: + rqstp->rq_auth_stat = rpc_auth_ok; return SVC_OK; } - EXPORT_SYMBOL_GPL(svcauth_unix_set_client); -static int -svcauth_null_accept(struct svc_rqst *rqstp, __be32 *authp) +/** + * svcauth_null_accept - Decode and validate incoming RPC_AUTH_NULL credential + * @rqstp: RPC transaction + * + * Return values: + * %SVC_OK: Both credential and verifier are valid + * %SVC_DENIED: Credential or verifier is not valid + * %SVC_GARBAGE: Failed to decode credential or verifier + * %SVC_CLOSE: Temporary failure + * + * rqstp->rq_auth_stat is set as mandated by RFC 5531. + */ +static enum svc_auth_status +svcauth_null_accept(struct svc_rqst *rqstp) { - struct kvec *argv = &rqstp->rq_arg.head[0]; - struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; struct svc_cred *cred = &rqstp->rq_cred; + u32 flavor, len; + void *body; - cred->cr_group_info = NULL; - cred->cr_principal = NULL; - rqstp->rq_client = NULL; - - if (argv->iov_len < 3*4) + /* Length of Call's credential body field: */ + if (xdr_stream_decode_u32(xdr, &len) < 0) return SVC_GARBAGE; - - if (svc_getu32(argv) != 0) { - dprintk("svc: bad null cred\n"); - *authp = rpc_autherr_badcred; + if (len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badcred; return SVC_DENIED; } - if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { - dprintk("svc: bad null verf\n"); - *authp = rpc_autherr_badverf; + + /* Call's verf field: */ + if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0) + return SVC_GARBAGE; + if (flavor != RPC_AUTH_NULL || len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badverf; return SVC_DENIED; } @@ -753,9 +782,11 @@ svcauth_null_accept(struct svc_rqst *rqstp, __be32 *authp) if (cred->cr_group_info == NULL) return SVC_CLOSE; /* kmalloc failure - client must retry */ - /* Put NULL verifier */ - svc_putnl(resv, RPC_AUTH_NULL); - svc_putnl(resv, 0); + if (xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream, + RPC_AUTH_NULL, NULL, 0) < 0) + return SVC_CLOSE; + if (!svcxdr_set_accept_stat(rqstp)) + return SVC_CLOSE; rqstp->rq_cred.cr_flavor = RPC_AUTH_NULL; return SVC_OK; @@ -779,35 +810,133 @@ struct auth_ops svcauth_null = { .name = "null", .owner = THIS_MODULE, .flavour = RPC_AUTH_NULL, - .accept = svcauth_null_accept, + .accept = svcauth_null_accept, .release = svcauth_null_release, .set_client = svcauth_unix_set_client, }; -static int -svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) +/** + * svcauth_tls_accept - Decode and validate incoming RPC_AUTH_TLS credential + * @rqstp: RPC transaction + * + * Return values: + * %SVC_OK: Both credential and verifier are valid + * %SVC_DENIED: Credential or verifier is not valid + * %SVC_GARBAGE: Failed to decode credential or verifier + * %SVC_CLOSE: Temporary failure + * + * rqstp->rq_auth_stat is set as mandated by RFC 5531. + */ +static enum svc_auth_status +svcauth_tls_accept(struct svc_rqst *rqstp) { - struct kvec *argv = &rqstp->rq_arg.head[0]; - struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_stream *xdr = &rqstp->rq_arg_stream; struct svc_cred *cred = &rqstp->rq_cred; - u32 slen, i; - int len = argv->iov_len; + struct svc_xprt *xprt = rqstp->rq_xprt; + u32 flavor, len; + void *body; + __be32 *p; - cred->cr_group_info = NULL; - cred->cr_principal = NULL; - rqstp->rq_client = NULL; + /* Length of Call's credential body field: */ + if (xdr_stream_decode_u32(xdr, &len) < 0) + return SVC_GARBAGE; + if (len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } - if ((len -= 3*4) < 0) + /* Call's verf field: */ + if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0) + return SVC_GARBAGE; + if (flavor != RPC_AUTH_NULL || len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* AUTH_TLS is not valid on non-NULL procedures */ + if (rqstp->rq_proc != 0) { + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } + + /* Signal that mapping to nobody uid/gid is required */ + cred->cr_uid = INVALID_UID; + cred->cr_gid = INVALID_GID; + cred->cr_group_info = groups_alloc(0); + if (cred->cr_group_info == NULL) + return SVC_CLOSE; + + if (xprt->xpt_ops->xpo_handshake) { + p = xdr_reserve_space(&rqstp->rq_res_stream, XDR_UNIT * 2 + 8); + if (!p) + return SVC_CLOSE; + trace_svc_tls_start(xprt); + *p++ = rpc_auth_null; + *p++ = cpu_to_be32(8); + memcpy(p, "STARTTLS", 8); + + set_bit(XPT_HANDSHAKE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + } else { + trace_svc_tls_unavailable(xprt); + if (xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream, + RPC_AUTH_NULL, NULL, 0) < 0) + return SVC_CLOSE; + } + if (!svcxdr_set_accept_stat(rqstp)) + return SVC_CLOSE; + + rqstp->rq_cred.cr_flavor = RPC_AUTH_TLS; + return SVC_OK; +} + +struct auth_ops svcauth_tls = { + .name = "tls", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_TLS, + .accept = svcauth_tls_accept, + .release = svcauth_null_release, + .set_client = svcauth_unix_set_client, +}; + + +/** + * svcauth_unix_accept - Decode and validate incoming RPC_AUTH_SYS credential + * @rqstp: RPC transaction + * + * Return values: + * %SVC_OK: Both credential and verifier are valid + * %SVC_DENIED: Credential or verifier is not valid + * %SVC_GARBAGE: Failed to decode credential or verifier + * %SVC_CLOSE: Temporary failure + * + * rqstp->rq_auth_stat is set as mandated by RFC 5531. + */ +static enum svc_auth_status +svcauth_unix_accept(struct svc_rqst *rqstp) +{ + struct xdr_stream *xdr = &rqstp->rq_arg_stream; + struct svc_cred *cred = &rqstp->rq_cred; + struct user_namespace *userns; + u32 flavor, len, i; + void *body; + __be32 *p; + + /* + * This implementation ignores the length of the Call's + * credential body field and the timestamp and machinename + * fields. + */ + p = xdr_inline_decode(xdr, XDR_UNIT * 3); + if (!p) + return SVC_GARBAGE; + len = be32_to_cpup(p + 2); + if (len > RPC_MAX_MACHINENAME) + return SVC_GARBAGE; + if (!xdr_inline_decode(xdr, len)) return SVC_GARBAGE; - svc_getu32(argv); /* length */ - svc_getu32(argv); /* time stamp */ - slen = XDR_QUADLEN(svc_getnl(argv)); /* machname length */ - if (slen > 64 || (len -= (slen + 3)*4) < 0) - goto badcred; - argv->iov_base = (void*)((__be32*)argv->iov_base + slen); /* skip machname */ - argv->iov_len -= slen*4; /* * Note: we skip uid_valid()/gid_valid() checks here for * backwards compatibility with clients that use -1 id's. @@ -815,32 +944,50 @@ svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) * (export-specific) anonymous id by nfsd_setuser. * Supplementary gid's will be left alone. */ - cred->cr_uid = make_kuid(&init_user_ns, svc_getnl(argv)); /* uid */ - cred->cr_gid = make_kgid(&init_user_ns, svc_getnl(argv)); /* gid */ - slen = svc_getnl(argv); /* gids length */ - if (slen > 16 || (len -= (slen + 2)*4) < 0) + userns = (rqstp->rq_xprt && rqstp->rq_xprt->xpt_cred) ? + rqstp->rq_xprt->xpt_cred->user_ns : &init_user_ns; + if (xdr_stream_decode_u32(xdr, &i) < 0) + return SVC_GARBAGE; + cred->cr_uid = make_kuid(userns, i); + if (xdr_stream_decode_u32(xdr, &i) < 0) + return SVC_GARBAGE; + cred->cr_gid = make_kgid(userns, i); + + if (xdr_stream_decode_u32(xdr, &len) < 0) + return SVC_GARBAGE; + if (len > UNX_NGROUPS) goto badcred; - cred->cr_group_info = groups_alloc(slen); + p = xdr_inline_decode(xdr, XDR_UNIT * len); + if (!p) + return SVC_GARBAGE; + cred->cr_group_info = groups_alloc(len); if (cred->cr_group_info == NULL) return SVC_CLOSE; - for (i = 0; i < slen; i++) { - kgid_t kgid = make_kgid(&init_user_ns, svc_getnl(argv)); - GROUP_AT(cred->cr_group_info, i) = kgid; + for (i = 0; i < len; i++) { + kgid_t kgid = make_kgid(userns, be32_to_cpup(p++)); + cred->cr_group_info->gid[i] = kgid; } - if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { - *authp = rpc_autherr_badverf; + groups_sort(cred->cr_group_info); + + /* Call's verf field: */ + if (xdr_stream_decode_opaque_auth(xdr, &flavor, &body, &len) < 0) + return SVC_GARBAGE; + if (flavor != RPC_AUTH_NULL || len != 0) { + rqstp->rq_auth_stat = rpc_autherr_badverf; return SVC_DENIED; } - /* Put NULL verifier */ - svc_putnl(resv, RPC_AUTH_NULL); - svc_putnl(resv, 0); + if (xdr_stream_encode_opaque_auth(&rqstp->rq_res_stream, + RPC_AUTH_NULL, NULL, 0) < 0) + return SVC_CLOSE; + if (!svcxdr_set_accept_stat(rqstp)) + return SVC_CLOSE; rqstp->rq_cred.cr_flavor = RPC_AUTH_UNIX; return SVC_OK; badcred: - *authp = rpc_autherr_badcred; + rqstp->rq_auth_stat = rpc_autherr_badcred; return SVC_DENIED; } @@ -864,17 +1011,18 @@ struct auth_ops svcauth_unix = { .name = "unix", .owner = THIS_MODULE, .flavour = RPC_AUTH_UNIX, - .accept = svcauth_unix_accept, + .accept = svcauth_unix_accept, .release = svcauth_unix_release, .domain_release = svcauth_unix_domain_release, .set_client = svcauth_unix_set_client, }; -static struct cache_detail ip_map_cache_template = { +static const struct cache_detail ip_map_cache_template = { .owner = THIS_MODULE, .hash_size = IP_HASHMAX, .name = "auth.unix.ip", .cache_put = ip_map_put, + .cache_upcall = ip_map_upcall, .cache_request = ip_map_request, .cache_parse = ip_map_parse, .cache_show = ip_map_show, diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index 305374d4fb98..d61cd9b40491 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/svcsock.c * @@ -35,15 +36,21 @@ #include <linux/skbuff.h> #include <linux/file.h> #include <linux/freezer.h> +#include <linux/bvec.h> + #include <net/sock.h> #include <net/checksum.h> #include <net/ip.h> #include <net/ipv6.h> +#include <net/udp.h> #include <net/tcp.h> #include <net/tcp_states.h> -#include <asm/uaccess.h> +#include <net/tls_prot.h> +#include <net/handshake.h> +#include <linux/uaccess.h> +#include <linux/highmem.h> #include <asm/ioctls.h> -#include <trace/events/skb.h> +#include <linux/key.h> #include <linux/sunrpc/types.h> #include <linux/sunrpc/clnt.h> @@ -53,14 +60,34 @@ #include <linux/sunrpc/stats.h> #include <linux/sunrpc/xprt.h> +#include <trace/events/sock.h> +#include <trace/events/sunrpc.h> + +#include "socklib.h" #include "sunrpc.h" #define RPCDBG_FACILITY RPCDBG_SVCXPRT +/* + * For UDP: + * 1 for header page + * enough pages for RPCSVC_MAXPAYLOAD_UDP + * 1 in case payload is not aligned + * 1 for tail page + */ +enum { + SUNRPC_MAX_UDP_SENDPAGES = 1 + RPCSVC_MAXPAYLOAD_UDP / PAGE_SIZE + 1 + 1 +}; + +/* To-do: to avoid tying up an nfsd thread while waiting for a + * handshake request, the request could instead be deferred. + */ +enum { + SVC_HANDSHAKE_TO = 5U * HZ +}; static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *, int flags); -static void svc_udp_data_ready(struct sock *, int); static int svc_udp_recvfrom(struct svc_rqst *); static int svc_udp_sendto(struct svc_rqst *); static void svc_sock_detach(struct svc_xprt *); @@ -70,13 +97,6 @@ static void svc_sock_free(struct svc_xprt *); static struct svc_xprt *svc_create_socket(struct svc_serv *, int, struct net *, struct sockaddr *, int, int); -#if defined(CONFIG_SUNRPC_BACKCHANNEL) -static struct svc_xprt *svc_bc_create_socket(struct svc_serv *, int, - struct net *, struct sockaddr *, - int, int); -static void svc_bc_sock_free(struct svc_xprt *xprt); -#endif /* CONFIG_SUNRPC_BACKCHANNEL */ - #ifdef CONFIG_DEBUG_LOCK_ALLOC static struct lock_class_key svc_key[2]; static struct lock_class_key svc_slock_key[2]; @@ -85,8 +105,7 @@ static void svc_reclassify_socket(struct socket *sock) { struct sock *sk = sock->sk; - WARN_ON_ONCE(sock_owned_by_user(sk)); - if (sock_owned_by_user(sk)) + if (WARN_ON_ONCE(!sock_allow_reclassification(sk))) return; switch (sk->sk_family) { @@ -114,21 +133,28 @@ static void svc_reclassify_socket(struct socket *sock) } #endif -/* - * Release an skbuff after use +/** + * svc_tcp_release_ctxt - Release transport-related resources + * @xprt: the transport which owned the context + * @ctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt + * */ -static void svc_release_skb(struct svc_rqst *rqstp) +static void svc_tcp_release_ctxt(struct svc_xprt *xprt, void *ctxt) { - struct sk_buff *skb = rqstp->rq_xprt_ctxt; +} - if (skb) { - struct svc_sock *svsk = - container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); - rqstp->rq_xprt_ctxt = NULL; +/** + * svc_udp_release_ctxt - Release transport-related resources + * @xprt: the transport which owned the context + * @ctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt + * + */ +static void svc_udp_release_ctxt(struct svc_xprt *xprt, void *ctxt) +{ + struct sk_buff *skb = ctxt; - dprintk("svc: service %p, releasing skb %p\n", rqstp, skb); - skb_free_datagram_locked(svsk->sk_sk, skb); - } + if (skb) + consume_skb(skb); } union svc_pktinfo_u { @@ -169,109 +195,10 @@ static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh) } } -/* - * send routine intended to be shared by the fore- and back-channel - */ -int svc_send_common(struct socket *sock, struct xdr_buf *xdr, - struct page *headpage, unsigned long headoffset, - struct page *tailpage, unsigned long tailoffset) +static int svc_sock_result_payload(struct svc_rqst *rqstp, unsigned int offset, + unsigned int length) { - int result; - int size; - struct page **ppage = xdr->pages; - size_t base = xdr->page_base; - unsigned int pglen = xdr->page_len; - unsigned int flags = MSG_MORE; - int slen; - int len = 0; - - slen = xdr->len; - - /* send head */ - if (slen == xdr->head[0].iov_len) - flags = 0; - len = kernel_sendpage(sock, headpage, headoffset, - xdr->head[0].iov_len, flags); - if (len != xdr->head[0].iov_len) - goto out; - slen -= xdr->head[0].iov_len; - if (slen == 0) - goto out; - - /* send page data */ - size = PAGE_SIZE - base < pglen ? PAGE_SIZE - base : pglen; - while (pglen > 0) { - if (slen == size) - flags = 0; - result = kernel_sendpage(sock, *ppage, base, size, flags); - if (result > 0) - len += result; - if (result != size) - goto out; - slen -= size; - pglen -= size; - size = PAGE_SIZE < pglen ? PAGE_SIZE : pglen; - base = 0; - ppage++; - } - - /* send tail */ - if (xdr->tail[0].iov_len) { - result = kernel_sendpage(sock, tailpage, tailoffset, - xdr->tail[0].iov_len, 0); - if (result > 0) - len += result; - } - -out: - return len; -} - - -/* - * Generic sendto routine - */ -static int svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr) -{ - struct svc_sock *svsk = - container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); - struct socket *sock = svsk->sk_sock; - union { - struct cmsghdr hdr; - long all[SVC_PKTINFO_SPACE / sizeof(long)]; - } buffer; - struct cmsghdr *cmh = &buffer.hdr; - int len = 0; - unsigned long tailoff; - unsigned long headoff; - RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); - - if (rqstp->rq_prot == IPPROTO_UDP) { - struct msghdr msg = { - .msg_name = &rqstp->rq_addr, - .msg_namelen = rqstp->rq_addrlen, - .msg_control = cmh, - .msg_controllen = sizeof(buffer), - .msg_flags = MSG_MORE, - }; - - svc_set_cmsg_data(rqstp, cmh); - - if (sock_sendmsg(sock, &msg, 0) < 0) - goto out; - } - - tailoff = ((unsigned long)xdr->tail[0].iov_base) & (PAGE_SIZE-1); - headoff = 0; - len = svc_send_common(sock, xdr, rqstp->rq_respages[0], headoff, - rqstp->rq_respages[0], tailoff); - -out: - dprintk("svc: socket %p sendto([%p %Zu... ], %d) = %d (addr %s)\n", - svsk, xdr->head[0].iov_base, xdr->head[0].iov_len, - xdr->len, len, svc_print_addr(rqstp, buf, sizeof(buf))); - - return len; + return 0; } /* @@ -291,12 +218,14 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) &inet_sk(sk)->inet_rcv_saddr, inet_sk(sk)->inet_num); break; +#if IS_ENABLED(CONFIG_IPV6) case PF_INET6: len = snprintf(buf, remaining, "ipv6 %s %pI6 %d\n", proto_name, - &inet6_sk(sk)->rcv_saddr, + &sk->sk_v6_rcv_saddr, inet_sk(sk)->inet_num); break; +#endif default: len = snprintf(buf, remaining, "*unknown-%d*\n", sk->sk_family); @@ -309,112 +238,184 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) return len; } -/* - * Check input queue length - */ -static int svc_recv_available(struct svc_sock *svsk) +static int +svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg, + struct cmsghdr *cmsg, int ret) { - struct socket *sock = svsk->sk_sock; - int avail, err; - - err = kernel_sock_ioctl(sock, TIOCINQ, (unsigned long) &avail); + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; - return (err >= 0)? avail : err; + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -ENOTCONN : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; + } + return ret; } -/* - * Generic recvfrom routine. - */ -static int svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, int nr, - int buflen) +static int +svc_tcp_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags) { - struct svc_sock *svsk = - container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + union { + struct cmsghdr cmsg; + u8 buf[CMSG_SPACE(sizeof(u8))]; + } u; + u8 alert[2]; + struct kvec alert_kvec = { + .iov_base = alert, + .iov_len = sizeof(alert), + }; struct msghdr msg = { - .msg_flags = MSG_DONTWAIT, + .msg_flags = *msg_flags, + .msg_control = &u, + .msg_controllen = sizeof(u), }; - int len; + int ret; - rqstp->rq_xprt_hlen = 0; + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1, + alert_kvec.iov_len); + ret = sock_recvmsg(sock, &msg, MSG_DONTWAIT); + if (ret > 0 && + tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT) { + iov_iter_revert(&msg.msg_iter, ret); + ret = svc_tcp_sock_process_cmsg(sock, &msg, &u.cmsg, -EAGAIN); + } + return ret; +} - len = kernel_recvmsg(svsk->sk_sock, &msg, iov, nr, buflen, - msg.msg_flags); +static int +svc_tcp_sock_recvmsg(struct svc_sock *svsk, struct msghdr *msg) +{ + int ret; + struct socket *sock = svsk->sk_sock; - dprintk("svc: socket %p recvfrom(%p, %Zu) = %d\n", - svsk, iov[0].iov_base, iov[0].iov_len, len); - return len; + ret = sock_recvmsg(sock, msg, MSG_DONTWAIT); + if (msg->msg_flags & MSG_CTRUNC) { + msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR); + if (ret == 0 || ret == -EIO) + ret = svc_tcp_sock_recv_cmsg(sock, &msg->msg_flags); + } + return ret; } -static int svc_partial_recvfrom(struct svc_rqst *rqstp, - struct kvec *iov, int nr, - int buflen, unsigned int base) +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE +static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek) { - size_t save_iovlen; - void *save_iovbase; + struct bvec_iter bi = { + .bi_size = size + seek, + }; + struct bio_vec bv; + + bvec_iter_advance(bvec, &bi, seek & PAGE_MASK); + for_each_bvec(bv, bvec, bi, bi) + flush_dcache_page(bv.bv_page); +} +#else +static inline void svc_flush_bvec(const struct bio_vec *bvec, size_t size, + size_t seek) +{ +} +#endif + +/* + * Read from @rqstp's transport socket. The incoming message fills whole + * pages in @rqstp's rq_pages array until the last page of the message + * has been received into a partial page. + */ +static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen, + size_t seek) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct bio_vec *bvec = rqstp->rq_bvec; + struct msghdr msg = { NULL }; unsigned int i; - int ret; + ssize_t len; + size_t t; - if (base == 0) - return svc_recvfrom(rqstp, iov, nr, buflen); + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - for (i = 0; i < nr; i++) { - if (iov[i].iov_len > base) - break; - base -= iov[i].iov_len; + for (i = 0, t = 0; t < buflen; i++, t += PAGE_SIZE) + bvec_set_page(&bvec[i], rqstp->rq_pages[i], PAGE_SIZE, 0); + rqstp->rq_respages = &rqstp->rq_pages[i]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + + iov_iter_bvec(&msg.msg_iter, ITER_DEST, bvec, i, buflen); + if (seek) { + iov_iter_advance(&msg.msg_iter, seek); + buflen -= seek; } - save_iovlen = iov[i].iov_len; - save_iovbase = iov[i].iov_base; - iov[i].iov_len -= base; - iov[i].iov_base += base; - ret = svc_recvfrom(rqstp, &iov[i], nr - i, buflen); - iov[i].iov_len = save_iovlen; - iov[i].iov_base = save_iovbase; - return ret; + len = svc_tcp_sock_recvmsg(svsk, &msg); + if (len > 0) + svc_flush_bvec(bvec, len, seek); + + /* If we read a full record, then assume there may be more + * data to read (stream based sockets only!) + */ + if (len == buflen) + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + + return len; } /* * Set socket snd and rcv buffer lengths */ -static void svc_sock_setbufsize(struct socket *sock, unsigned int snd, - unsigned int rcv) +static void svc_sock_setbufsize(struct svc_sock *svsk, unsigned int nreqs) { -#if 0 - mm_segment_t oldfs; - oldfs = get_fs(); set_fs(KERNEL_DS); - sock_setsockopt(sock, SOL_SOCKET, SO_SNDBUF, - (char*)&snd, sizeof(snd)); - sock_setsockopt(sock, SOL_SOCKET, SO_RCVBUF, - (char*)&rcv, sizeof(rcv)); -#else - /* sock_setsockopt limits use to sysctl_?mem_max, - * which isn't acceptable. Until that is made conditional - * on not having CAP_SYS_RESOURCE or similar, we go direct... - * DaveM said I could! - */ + unsigned int max_mesg = svsk->sk_xprt.xpt_server->sv_max_mesg; + struct socket *sock = svsk->sk_sock; + + nreqs = min(nreqs, INT_MAX / 2 / max_mesg); + lock_sock(sock->sk); - sock->sk->sk_sndbuf = snd * 2; - sock->sk->sk_rcvbuf = rcv * 2; + sock->sk->sk_sndbuf = nreqs * max_mesg * 2; + sock->sk->sk_rcvbuf = nreqs * max_mesg * 2; sock->sk->sk_write_space(sock->sk); release_sock(sock->sk); -#endif } + +static void svc_sock_secure_port(struct svc_rqst *rqstp) +{ + if (svc_port_is_privileged(svc_addr(rqstp))) + set_bit(RQ_SECURE, &rqstp->rq_flags); + else + clear_bit(RQ_SECURE, &rqstp->rq_flags); +} + /* * INET callback when data has been received on the socket. */ -static void svc_udp_data_ready(struct sock *sk, int count) +static void svc_data_ready(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; - wait_queue_head_t *wq = sk_sleep(sk); + + trace_sk_data_ready(sk); if (svsk) { - dprintk("svc: socket %p(inet %p), count=%d, busy=%d\n", - svsk, sk, count, - test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); - set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - svc_xprt_enqueue(&svsk->sk_xprt); + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_odata(sk); + trace_svcsock_data_ready(&svsk->sk_xprt, 0); + if (test_bit(XPT_HANDSHAKE, &svsk->sk_xprt.xpt_flags)) + return; + if (!test_and_set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags)) + svc_xprt_enqueue(&svsk->sk_xprt); } - if (wq && waitqueue_active(wq)) - wake_up_interruptible(wq); } /* @@ -423,28 +424,112 @@ static void svc_udp_data_ready(struct sock *sk, int count) static void svc_write_space(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)(sk->sk_user_data); - wait_queue_head_t *wq = sk_sleep(sk); if (svsk) { - dprintk("svc: socket %p(inet %p), write_space busy=%d\n", - svsk, sk, test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); + /* Refer to svc_setup_socket() for details. */ + rmb(); + trace_svcsock_write_space(&svsk->sk_xprt, 0); + svsk->sk_owspace(sk); svc_xprt_enqueue(&svsk->sk_xprt); } +} + +static int svc_tcp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); - if (wq && waitqueue_active(wq)) { - dprintk("RPC svc_write_space: someone sleeping on %p\n", - svsk); - wake_up_interruptible(wq); + if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) + return 1; + return !test_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); +} + +static void svc_tcp_kill_temp_xprt(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + sock_no_linger(svsk->sk_sock->sk); +} + +/** + * svc_tcp_handshake_done - Handshake completion handler + * @data: address of xprt to wake + * @status: status of handshake + * @peerid: serial number of key containing the remote peer's identity + * + * If a security policy is specified as an export option, we don't + * have a specific export here to check. So we set a "TLS session + * is present" flag on the xprt and let an upper layer enforce local + * security policy. + */ +static void svc_tcp_handshake_done(void *data, int status, key_serial_t peerid) +{ + struct svc_xprt *xprt = data; + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + if (!status) { + if (peerid != TLS_NO_PEERID) + set_bit(XPT_PEER_AUTH, &xprt->xpt_flags); + set_bit(XPT_TLS_SESSION, &xprt->xpt_flags); } + clear_bit(XPT_HANDSHAKE, &xprt->xpt_flags); + complete_all(&svsk->sk_handshake_done); } -static void svc_tcp_write_space(struct sock *sk) +/** + * svc_tcp_handshake - Perform a transport-layer security handshake + * @xprt: connected transport endpoint + * + */ +static void svc_tcp_handshake(struct svc_xprt *xprt) { - struct socket *sock = sk->sk_socket; + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sock *sk = svsk->sk_sock->sk; + struct tls_handshake_args args = { + .ta_sock = svsk->sk_sock, + .ta_done = svc_tcp_handshake_done, + .ta_data = xprt, + }; + int ret; + + trace_svc_tls_upcall(xprt); + + clear_bit(XPT_TLS_SESSION, &xprt->xpt_flags); + init_completion(&svsk->sk_handshake_done); - if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk) && sock) - clear_bit(SOCK_NOSPACE, &sock->flags); - svc_write_space(sk); + ret = tls_server_hello_x509(&args, GFP_KERNEL); + if (ret) { + trace_svc_tls_not_started(xprt); + goto out_failed; + } + + ret = wait_for_completion_interruptible_timeout(&svsk->sk_handshake_done, + SVC_HANDSHAKE_TO); + if (ret <= 0) { + if (tls_handshake_cancel(sk)) { + trace_svc_tls_timed_out(xprt); + goto out_close; + } + } + + if (!test_bit(XPT_TLS_SESSION, &xprt->xpt_flags)) { + trace_svc_tls_unavailable(xprt); + goto out_close; + } + + /* Mark the transport ready in case the remote sent RPC + * traffic before the kernel received the handshake + * completion downcall. + */ + set_bit(XPT_DATA, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + return; + +out_close: + set_bit(XPT_CLOSE, &xprt->xpt_flags); +out_failed: + clear_bit(XPT_HANDSHAKE, &xprt->xpt_flags); + set_bit(XPT_DATA, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); } /* @@ -502,8 +587,15 @@ static int svc_udp_get_dest_address(struct svc_rqst *rqstp, return 0; } -/* - * Receive a datagram from a UDP socket. +/** + * svc_udp_recvfrom - Receive a datagram from a UDP socket. + * @rqstp: request structure into which to receive an RPC Call + * + * Called in a loop when XPT_DATA has been set. + * + * Returns: + * On success, the number of bytes in a received RPC Call, or + * %0 if a complete RPC Call message was not ready to return */ static int svc_udp_recvfrom(struct svc_rqst *rqstp) { @@ -534,61 +626,47 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) * provides an upper bound on the number of threads * which will access the socket. */ - svc_sock_setbufsize(svsk->sk_sock, - (serv->sv_nrthreads+3) * serv->sv_max_mesg, - (serv->sv_nrthreads+3) * serv->sv_max_mesg); + svc_sock_setbufsize(svsk, serv->sv_nrthreads + 3); clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - skb = NULL; err = kernel_recvmsg(svsk->sk_sock, &msg, NULL, 0, 0, MSG_PEEK | MSG_DONTWAIT); - if (err >= 0) - skb = skb_recv_datagram(svsk->sk_sk, 0, 1, &err); - - if (skb == NULL) { - if (err != -EAGAIN) { - /* possibly an icmp error */ - dprintk("svc: recvfrom returned error %d\n", -err); - set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - } - return 0; - } + if (err < 0) + goto out_recv_err; + skb = skb_recv_udp(svsk->sk_sk, MSG_DONTWAIT, &err); + if (!skb) + goto out_recv_err; + len = svc_addr_len(svc_addr(rqstp)); rqstp->rq_addrlen = len; - if (skb->tstamp.tv64 == 0) { + if (skb->tstamp == 0) { skb->tstamp = ktime_get_real(); /* Don't enable netstamp, sunrpc doesn't need that much accuracy */ } - svsk->sk_sk->sk_stamp = skb->tstamp; + sock_write_timestamp(svsk->sk_sk, skb->tstamp); set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); /* there may be more data... */ - len = skb->len - sizeof(struct udphdr); + len = skb->len; rqstp->rq_arg.len = len; + trace_svcsock_udp_recv(&svsk->sk_xprt, len); rqstp->rq_prot = IPPROTO_UDP; - if (!svc_udp_get_dest_address(rqstp, cmh)) { - net_warn_ratelimited("svc: received unknown control message %d/%d; dropping RPC reply datagram\n", - cmh->cmsg_level, cmh->cmsg_type); - goto out_free; - } + if (!svc_udp_get_dest_address(rqstp, cmh)) + goto out_cmsg_err; rqstp->rq_daddrlen = svc_addr_len(svc_daddr(rqstp)); if (skb_is_nonlinear(skb)) { /* we have to copy */ local_bh_disable(); - if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb)) { - local_bh_enable(); - /* checksum error */ - goto out_free; - } + if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb)) + goto out_bh_enable; local_bh_enable(); - skb_free_datagram_locked(svsk->sk_sk, skb); + consume_skb(skb); } else { /* we can use it in-place */ - rqstp->rq_arg.head[0].iov_base = skb->data + - sizeof(struct udphdr); + rqstp->rq_arg.head[0].iov_base = skb->data; rqstp->rq_arg.head[0].iov_len = len; if (skb_checksum_complete(skb)) goto out_free; @@ -610,28 +688,89 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) if (serv->sv_stats) serv->sv_stats->netudpcnt++; + svc_sock_secure_port(rqstp); + svc_xprt_received(rqstp->rq_xprt); return len; + +out_recv_err: + if (err != -EAGAIN) { + /* possibly an icmp error */ + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + } + trace_svcsock_udp_recv_err(&svsk->sk_xprt, err); + goto out_clear_busy; +out_cmsg_err: + net_warn_ratelimited("svc: received unknown control message %d/%d; dropping RPC reply datagram\n", + cmh->cmsg_level, cmh->cmsg_type); + goto out_free; +out_bh_enable: + local_bh_enable(); out_free: - trace_kfree_skb(skb, svc_udp_recvfrom); - skb_free_datagram_locked(svsk->sk_sk, skb); + kfree_skb(skb); +out_clear_busy: + svc_xprt_received(rqstp->rq_xprt); return 0; } -static int -svc_udp_sendto(struct svc_rqst *rqstp) +/** + * svc_udp_sendto - Send out a reply on a UDP socket + * @rqstp: completed svc_rqst + * + * xpt_mutex ensures @rqstp's whole message is written to the socket + * without interruption. + * + * Returns the number of bytes sent, or a negative errno. + */ +static int svc_udp_sendto(struct svc_rqst *rqstp) { - int error; + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct xdr_buf *xdr = &rqstp->rq_res; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + struct msghdr msg = { + .msg_name = &rqstp->rq_addr, + .msg_namelen = rqstp->rq_addrlen, + .msg_control = cmh, + .msg_flags = MSG_SPLICE_PAGES, + .msg_controllen = sizeof(buffer), + }; + unsigned int count; + int err; + + svc_udp_release_ctxt(xprt, rqstp->rq_xprt_ctxt); + rqstp->rq_xprt_ctxt = NULL; + + svc_set_cmsg_data(rqstp, cmh); + + mutex_lock(&xprt->xpt_mutex); + + if (svc_xprt_is_dead(xprt)) + goto out_notconn; - error = svc_sendto(rqstp, &rqstp->rq_res); - if (error == -ECONNREFUSED) + count = xdr_buf_to_bvec(svsk->sk_bvec, SUNRPC_MAX_UDP_SENDPAGES, xdr); + + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, svsk->sk_bvec, + count, rqstp->rq_res.len); + err = sock_sendmsg(svsk->sk_sock, &msg); + if (err == -ECONNREFUSED) { /* ICMP error on earlier request. */ - error = svc_sendto(rqstp, &rqstp->rq_res); + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, svsk->sk_bvec, + count, rqstp->rq_res.len); + err = sock_sendmsg(svsk->sk_sock, &msg); + } - return error; -} + trace_svcsock_udp_send(xprt, err); -static void svc_udp_prep_reply_hdr(struct svc_rqst *rqstp) -{ + mutex_unlock(&xprt->xpt_mutex); + return err; + +out_notconn: + mutex_unlock(&xprt->xpt_mutex); + return -ENOTCONN; } static int svc_udp_has_wspace(struct svc_xprt *xprt) @@ -658,6 +797,10 @@ static struct svc_xprt *svc_udp_accept(struct svc_xprt *xprt) return NULL; } +static void svc_udp_kill_temp_xprt(struct svc_xprt *xprt) +{ +} + static struct svc_xprt *svc_udp_create(struct svc_serv *serv, struct net *net, struct sockaddr *sa, int salen, @@ -666,16 +809,17 @@ static struct svc_xprt *svc_udp_create(struct svc_serv *serv, return svc_create_socket(serv, IPPROTO_UDP, net, sa, salen, flags); } -static struct svc_xprt_ops svc_udp_ops = { +static const struct svc_xprt_ops svc_udp_ops = { .xpo_create = svc_udp_create, .xpo_recvfrom = svc_udp_recvfrom, .xpo_sendto = svc_udp_sendto, - .xpo_release_rqst = svc_release_skb, + .xpo_result_payload = svc_sock_result_payload, + .xpo_release_ctxt = svc_udp_release_ctxt, .xpo_detach = svc_sock_detach, .xpo_free = svc_sock_free, - .xpo_prep_reply_hdr = svc_udp_prep_reply_hdr, .xpo_has_wspace = svc_udp_has_wspace, .xpo_accept = svc_udp_accept, + .xpo_kill_temp_xprt = svc_udp_kill_temp_xprt, }; static struct svc_xprt_class svc_udp_class = { @@ -683,59 +827,50 @@ static struct svc_xprt_class svc_udp_class = { .xcl_owner = THIS_MODULE, .xcl_ops = &svc_udp_ops, .xcl_max_payload = RPCSVC_MAXPAYLOAD_UDP, + .xcl_ident = XPRT_TRANSPORT_UDP, }; static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) { - int err, level, optname, one = 1; - svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_udp_class, &svsk->sk_xprt, serv); clear_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); - svsk->sk_sk->sk_data_ready = svc_udp_data_ready; + svsk->sk_sk->sk_data_ready = svc_data_ready; svsk->sk_sk->sk_write_space = svc_write_space; /* initialise setting must have enough space to * receive and respond to one request. * svc_udp_recvfrom will re-adjust if necessary */ - svc_sock_setbufsize(svsk->sk_sock, - 3 * svsk->sk_xprt.xpt_server->sv_max_mesg, - 3 * svsk->sk_xprt.xpt_server->sv_max_mesg); + svc_sock_setbufsize(svsk, 3); /* data might have come in before data_ready set up */ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_RPCB_UNREG, &svsk->sk_xprt.xpt_flags); /* make sure we get destination address info */ switch (svsk->sk_sk->sk_family) { case AF_INET: - level = SOL_IP; - optname = IP_PKTINFO; + ip_sock_set_pktinfo(svsk->sk_sock->sk); break; case AF_INET6: - level = SOL_IPV6; - optname = IPV6_RECVPKTINFO; + ip6_sock_set_recvpktinfo(svsk->sk_sock->sk); break; default: BUG(); } - err = kernel_setsockopt(svsk->sk_sock, level, optname, - (char *)&one, sizeof(one)); - dprintk("svc: kernel_setsockopt returned %d\n", err); } /* * A data_ready event on a listening socket means there's a connection * pending. Do not use state_change as a substitute for it. */ -static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) +static void svc_tcp_listen_data_ready(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; - wait_queue_head_t *wq; - dprintk("svc: socket %p TCP (listen) state change %d\n", - sk, sk->sk_state); + trace_sk_data_ready(sk); /* * This callback may called twice when a new connection @@ -745,19 +880,19 @@ static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) * when one of child sockets become ESTABLISHED. * 2) data_ready method of the child socket may be called * when it receives data before the socket is accepted. - * In case of 2, we should ignore it silently. + * In case of 2, we should ignore it silently and DO NOT + * dereference svsk. */ - if (sk->sk_state == TCP_LISTEN) { - if (svsk) { - set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); - svc_xprt_enqueue(&svsk->sk_xprt); - } else - printk("svc: socket %p: no user data\n", sk); - } + if (sk->sk_state != TCP_LISTEN) + return; - wq = sk_sleep(sk); - if (wq && waitqueue_active(wq)) - wake_up_interruptible_all(wq); + if (svsk) { + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_odata(sk); + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } } /* @@ -766,34 +901,15 @@ static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) static void svc_tcp_state_change(struct sock *sk) { struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; - wait_queue_head_t *wq = sk_sleep(sk); - - dprintk("svc: socket %p TCP (connected) state change %d (svsk %p)\n", - sk, sk->sk_state, sk->sk_user_data); - - if (!svsk) - printk("svc: socket %p: no user data\n", sk); - else { - set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); - svc_xprt_enqueue(&svsk->sk_xprt); - } - if (wq && waitqueue_active(wq)) - wake_up_interruptible_all(wq); -} -static void svc_tcp_data_ready(struct sock *sk, int count) -{ - struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; - wait_queue_head_t *wq = sk_sleep(sk); - - dprintk("svc: socket %p TCP data ready (svsk %p)\n", - sk, sk->sk_user_data); if (svsk) { - set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - svc_xprt_enqueue(&svsk->sk_xprt); + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_ostate(sk); + trace_svcsock_tcp_state(&svsk->sk_xprt, svsk->sk_sock); + if (sk->sk_state != TCP_ESTABLISHED) + svc_xprt_deferred_close(&svsk->sk_xprt); } - if (wq && waitqueue_active(wq)) - wake_up_interruptible(wq); } /* @@ -809,44 +925,33 @@ static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) struct socket *newsock; struct svc_sock *newsvsk; int err, slen; - RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); - dprintk("svc: tcp_accept %p sock %p\n", svsk, sock); if (!sock) return NULL; clear_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); err = kernel_accept(sock, &newsock, O_NONBLOCK); if (err < 0) { - if (err == -ENOMEM) - printk(KERN_WARNING "%s: no more sockets!\n", - serv->sv_name); - else if (err != -EAGAIN) - net_warn_ratelimited("%s: accept failed (err %d)!\n", - serv->sv_name, -err); + if (err != -EAGAIN) + trace_svcsock_accept_err(xprt, serv->sv_name, err); return NULL; } + if (IS_ERR(sock_alloc_file(newsock, O_NONBLOCK, NULL))) + return NULL; + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); - err = kernel_getpeername(newsock, sin, &slen); + err = kernel_getpeername(newsock, sin); if (err < 0) { - net_warn_ratelimited("%s: peername failed (err %d)!\n", - serv->sv_name, -err); + trace_svcsock_getpeername_err(xprt, serv->sv_name, err); goto failed; /* aborted connection or whatever */ } + slen = err; - /* Ideally, we would want to reject connections from unauthorized - * hosts here, but when we get encryption, the IP of the host won't - * tell us anything. For now just warn about unpriv connections. - */ - if (!svc_port_is_privileged(sin)) { - dprintk(KERN_WARNING - "%s: connect from unprivileged port: %s\n", - serv->sv_name, - __svc_print_addr(sin, buf, sizeof(buf))); - } - dprintk("%s: connect from %s\n", serv->sv_name, - __svc_print_addr(sin, buf, sizeof(buf))); + /* Reset the inherited callbacks before calling svc_setup_socket */ + newsock->sk->sk_state_change = svsk->sk_ostate; + newsock->sk->sk_data_ready = svsk->sk_odata; + newsock->sk->sk_write_space = svsk->sk_owspace; /* make sure that a write doesn't block forever when * low on memory @@ -858,30 +963,34 @@ static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) if (IS_ERR(newsvsk)) goto failed; svc_xprt_set_remote(&newsvsk->sk_xprt, sin, slen); - err = kernel_getsockname(newsock, sin, &slen); - if (unlikely(err < 0)) { - dprintk("svc_tcp_accept: kernel_getsockname error %d\n", -err); + err = kernel_getsockname(newsock, sin); + slen = err; + if (unlikely(err < 0)) slen = offsetof(struct sockaddr, sa_data); - } svc_xprt_set_local(&newsvsk->sk_xprt, sin, slen); + if (sock_is_loopback(newsock->sk)) + set_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); + else + clear_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); if (serv->sv_stats) serv->sv_stats->nettcpconn++; return &newsvsk->sk_xprt; failed: - sock_release(newsock); + sockfd_put(newsock); return NULL; } -static unsigned int svc_tcp_restore_pages(struct svc_sock *svsk, struct svc_rqst *rqstp) +static size_t svc_tcp_restore_pages(struct svc_sock *svsk, + struct svc_rqst *rqstp) { - unsigned int i, len, npages; + size_t len = svsk->sk_datalen; + unsigned int i, npages; - if (svsk->sk_datalen == 0) + if (!len) return 0; - len = svsk->sk_datalen; npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; for (i = 0; i < npages; i++) { if (rqstp->rq_pages[i] != NULL) @@ -930,48 +1039,46 @@ out: } /* - * Receive fragment record header. - * If we haven't gotten the record length yet, get the next four bytes. + * Receive fragment record header into sk_marker. */ -static int svc_tcp_recv_record(struct svc_sock *svsk, struct svc_rqst *rqstp) +static ssize_t svc_tcp_read_marker(struct svc_sock *svsk, + struct svc_rqst *rqstp) { - struct svc_serv *serv = svsk->sk_xprt.xpt_server; - unsigned int want; - int len; - - clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + ssize_t want, len; + /* If we haven't gotten the record length yet, + * get the next four bytes. + */ if (svsk->sk_tcplen < sizeof(rpc_fraghdr)) { + struct msghdr msg = { NULL }; struct kvec iov; want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; - iov.iov_base = ((char *) &svsk->sk_reclen) + svsk->sk_tcplen; + iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen; iov.iov_len = want; - if ((len = svc_recvfrom(rqstp, &iov, 1, want)) < 0) - goto error; + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want); + len = svc_tcp_sock_recvmsg(svsk, &msg); + if (len < 0) + return len; svsk->sk_tcplen += len; - if (len < want) { - dprintk("svc: short recvfrom while reading record " - "length (%d of %d)\n", len, want); - return -EAGAIN; + /* call again to read the remaining bytes */ + goto err_short; } - - dprintk("svc: TCP record, %d bytes\n", svc_sock_reclen(svsk)); + trace_svcsock_marker(&svsk->sk_xprt, svsk->sk_marker); if (svc_sock_reclen(svsk) + svsk->sk_datalen > - serv->sv_max_mesg) { - net_notice_ratelimited("RPC: fragment too large: %d\n", - svc_sock_reclen(svsk)); - goto err_delete; - } + svsk->sk_xprt.xpt_server->sv_max_mesg) + goto err_too_large; } - return svc_sock_reclen(svsk); -error: - dprintk("RPC: TCP recv_record got %d\n", len); - return len; -err_delete: - set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + +err_too_large: + net_notice_ratelimited("svc: %s oversized RPC fragment (%u octets) from %pISpc\n", + svsk->sk_xprt.xpt_server->sv_name, + svc_sock_reclen(svsk), + (struct sockaddr *)&svsk->sk_xprt.xpt_remote); + svc_xprt_deferred_close(&svsk->sk_xprt); +err_short: return -EAGAIN; } @@ -981,23 +1088,14 @@ static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) struct rpc_rqst *req = NULL; struct kvec *src, *dst; __be32 *p = (__be32 *)rqstp->rq_arg.head[0].iov_base; - __be32 xid; - __be32 calldir; - - xid = *p++; - calldir = *p; + __be32 xid = *p; - if (bc_xprt) - req = xprt_lookup_rqst(bc_xprt, xid); - - if (!req) { - printk(KERN_NOTICE - "%s: Got unrecognized reply: " - "calldir 0x%x xpt_bc_xprt %p xid %08x\n", - __func__, ntohl(calldir), - bc_xprt, xid); + if (!bc_xprt) return -EAGAIN; - } + spin_lock(&bc_xprt->queue_lock); + req = xprt_lookup_rqst(bc_xprt, xid); + if (!req) + goto unlock_eagain; memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); /* @@ -1008,97 +1106,69 @@ static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) dst = &req->rq_private_buf.head[0]; src = &rqstp->rq_arg.head[0]; if (dst->iov_len < src->iov_len) - return -EAGAIN; /* whatever; just giving up. */ + goto unlock_eagain; /* whatever; just giving up. */ memcpy(dst->iov_base, src->iov_base, src->iov_len); xprt_complete_rqst(req->rq_task, rqstp->rq_arg.len); rqstp->rq_arg.len = 0; + spin_unlock(&bc_xprt->queue_lock); return 0; -} - -static int copy_pages_to_kvecs(struct kvec *vec, struct page **pages, int len) -{ - int i = 0; - int t = 0; - - while (t < len) { - vec[i].iov_base = page_address(pages[i]); - vec[i].iov_len = PAGE_SIZE; - i++; - t += PAGE_SIZE; - } - return i; +unlock_eagain: + spin_unlock(&bc_xprt->queue_lock); + return -EAGAIN; } static void svc_tcp_fragment_received(struct svc_sock *svsk) { /* If we have more data, signal svc_xprt_enqueue() to try again */ - if (svc_recv_available(svsk) > sizeof(rpc_fraghdr)) - set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - dprintk("svc: TCP %s record (%d bytes)\n", - svc_sock_final_rec(svsk) ? "final" : "nonfinal", - svc_sock_reclen(svsk)); svsk->sk_tcplen = 0; - svsk->sk_reclen = 0; + svsk->sk_marker = xdr_zero; } -/* - * Receive data from a TCP socket. +/** + * svc_tcp_recvfrom - Receive data from a TCP socket + * @rqstp: request structure into which to receive an RPC Call + * + * Called in a loop when XPT_DATA has been set. + * + * Read the 4-byte stream record marker, then use the record length + * in that marker to set up exactly the resources needed to receive + * the next RPC message into @rqstp. + * + * Returns: + * On success, the number of bytes in a received RPC Call, or + * %0 if a complete RPC Call message was not ready to return + * + * The zero return case handles partial receives and callback Replies. + * The state of a partial receive is preserved in the svc_sock for + * the next call to svc_tcp_recvfrom. */ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) { struct svc_sock *svsk = container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); struct svc_serv *serv = svsk->sk_xprt.xpt_server; - int len; - struct kvec *vec; - unsigned int want, base; + size_t want, base; + ssize_t len; __be32 *p; __be32 calldir; - int pnum; - - dprintk("svc: tcp_recv %p data %d conn %d close %d\n", - svsk, test_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags), - test_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags), - test_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags)); - len = svc_tcp_recv_record(svsk, rqstp); + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + len = svc_tcp_read_marker(svsk, rqstp); if (len < 0) goto error; base = svc_tcp_restore_pages(svsk, rqstp); - want = svc_sock_reclen(svsk) - (svsk->sk_tcplen - sizeof(rpc_fraghdr)); - - vec = rqstp->rq_vec; - - pnum = copy_pages_to_kvecs(&vec[0], &rqstp->rq_pages[0], - svsk->sk_datalen + want); - - rqstp->rq_respages = &rqstp->rq_pages[pnum]; - rqstp->rq_next_page = rqstp->rq_respages + 1; - - /* Now receive data */ - len = svc_partial_recvfrom(rqstp, vec, pnum, want, base); + want = len - (svsk->sk_tcplen - sizeof(rpc_fraghdr)); + len = svc_tcp_read_msg(rqstp, base + want, base); if (len >= 0) { + trace_svcsock_tcp_recv(&svsk->sk_xprt, len); svsk->sk_tcplen += len; svsk->sk_datalen += len; } - if (len != want || !svc_sock_final_rec(svsk)) { - svc_tcp_save_pages(svsk, rqstp); - if (len < 0 && len != -EAGAIN) - goto err_delete; - if (len == want) - svc_tcp_fragment_received(svsk); - else - dprintk("svc: incomplete TCP record (%d of %d)\n", - (int)(svsk->sk_tcplen - sizeof(rpc_fraghdr)), - svc_sock_reclen(svsk)); - goto err_noclose; - } - - if (svsk->sk_datalen < 8) { - svsk->sk_datalen = 0; - goto err_delete; /* client is nuts. */ - } + if (len != want || !svc_sock_final_rec(svsk)) + goto err_incomplete; + if (svsk->sk_datalen < 8) + goto err_nuts; rqstp->rq_arg.len = svsk->sk_datalen; rqstp->rq_arg.page_base = 0; @@ -1110,6 +1180,10 @@ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) rqstp->rq_xprt_ctxt = NULL; rqstp->rq_prot = IPPROTO_TCP; + if (test_bit(XPT_LOCAL, &svsk->sk_xprt.xpt_flags)) + set_bit(RQ_LOCAL, &rqstp->rq_flags); + else + clear_bit(RQ_LOCAL, &rqstp->rq_flags); p = (__be32 *)rqstp->rq_arg.head[0].iov_base; calldir = p[1]; @@ -1127,76 +1201,113 @@ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) if (serv->sv_stats) serv->sv_stats->nettcpcnt++; + svc_sock_secure_port(rqstp); + svc_xprt_received(rqstp->rq_xprt); return rqstp->rq_arg.len; +err_incomplete: + svc_tcp_save_pages(svsk, rqstp); + if (len < 0 && len != -EAGAIN) + goto err_delete; + if (len == want) + svc_tcp_fragment_received(svsk); + else + trace_svcsock_tcp_recv_short(&svsk->sk_xprt, + svc_sock_reclen(svsk), + svsk->sk_tcplen - sizeof(rpc_fraghdr)); + goto err_noclose; error: if (len != -EAGAIN) goto err_delete; - dprintk("RPC: TCP recvfrom got EAGAIN\n"); - return 0; + trace_svcsock_tcp_recv_eagain(&svsk->sk_xprt, 0); + goto err_noclose; +err_nuts: + svsk->sk_datalen = 0; err_delete: - printk(KERN_NOTICE "%s: recvfrom returned errno %d\n", - svsk->sk_xprt.xpt_server->sv_name, -len); - set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + trace_svcsock_tcp_recv_err(&svsk->sk_xprt, len); + svc_xprt_deferred_close(&svsk->sk_xprt); err_noclose: + svc_xprt_received(rqstp->rq_xprt); return 0; /* record not complete */ } /* - * Send out data on TCP socket. + * MSG_SPLICE_PAGES is used exclusively to reduce the number of + * copy operations in this path. Therefore the caller must ensure + * that the pages backing @xdr are unchanging. */ -static int svc_tcp_sendto(struct svc_rqst *rqstp) +static int svc_tcp_sendmsg(struct svc_sock *svsk, struct svc_rqst *rqstp, + rpc_fraghdr marker) { - struct xdr_buf *xbufp = &rqstp->rq_res; - int sent; - __be32 reclen; + struct msghdr msg = { + .msg_flags = MSG_SPLICE_PAGES, + }; + unsigned int count; + void *buf; + int ret; - /* Set up the first element of the reply kvec. - * Any other kvecs that may be in use have been taken - * care of by the server implementation itself. + /* The stream record marker is copied into a temporary page + * fragment buffer so that it can be included in sk_bvec. */ - reclen = htonl(0x80000000|((xbufp->len ) - 4)); - memcpy(xbufp->head[0].iov_base, &reclen, 4); - - sent = svc_sendto(rqstp, &rqstp->rq_res); - if (sent != xbufp->len) { - printk(KERN_NOTICE - "rpc-srv/tcp: %s: %s %d when sending %d bytes " - "- shutting down socket\n", - rqstp->rq_xprt->xpt_server->sv_name, - (sent<0)?"got error":"sent only", - sent, xbufp->len); - set_bit(XPT_CLOSE, &rqstp->rq_xprt->xpt_flags); - svc_xprt_enqueue(rqstp->rq_xprt); - sent = -EAGAIN; - } - return sent; + buf = page_frag_alloc(&svsk->sk_frag_cache, sizeof(marker), + GFP_KERNEL); + if (!buf) + return -ENOMEM; + memcpy(buf, &marker, sizeof(marker)); + bvec_set_virt(svsk->sk_bvec, buf, sizeof(marker)); + + count = xdr_buf_to_bvec(svsk->sk_bvec + 1, rqstp->rq_maxpages, + &rqstp->rq_res); + + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, svsk->sk_bvec, + 1 + count, sizeof(marker) + rqstp->rq_res.len); + ret = sock_sendmsg(svsk->sk_sock, &msg); + page_frag_free(buf); + return ret; } -/* - * Setup response header. TCP has a 4B record length field. +/** + * svc_tcp_sendto - Send out a reply on a TCP socket + * @rqstp: completed svc_rqst + * + * xpt_mutex ensures @rqstp's whole message is written to the socket + * without interruption. + * + * Returns the number of bytes sent, or a negative errno. */ -static void svc_tcp_prep_reply_hdr(struct svc_rqst *rqstp) +static int svc_tcp_sendto(struct svc_rqst *rqstp) { - struct kvec *resv = &rqstp->rq_res.head[0]; - - /* tcp needs a space for the record length... */ - svc_putnl(resv, 0); -} + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct xdr_buf *xdr = &rqstp->rq_res; + rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | + (u32)xdr->len); + int sent; -static int svc_tcp_has_wspace(struct svc_xprt *xprt) -{ - struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); - struct svc_serv *serv = svsk->sk_xprt.xpt_server; - int required; + svc_tcp_release_ctxt(xprt, rqstp->rq_xprt_ctxt); + rqstp->rq_xprt_ctxt = NULL; + + mutex_lock(&xprt->xpt_mutex); + if (svc_xprt_is_dead(xprt)) + goto out_notconn; + sent = svc_tcp_sendmsg(svsk, rqstp, marker); + trace_svcsock_tcp_send(xprt, sent); + if (sent < 0 || sent != (xdr->len + sizeof(marker))) + goto out_close; + mutex_unlock(&xprt->xpt_mutex); + return sent; - if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) - return 1; - required = atomic_read(&xprt->xpt_reserved) + serv->sv_max_mesg; - if (sk_stream_wspace(svsk->sk_sk) >= required) - return 1; - set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); - return 0; +out_notconn: + mutex_unlock(&xprt->xpt_mutex); + return -ENOTCONN; +out_close: + pr_notice("rpc-srv/tcp: %s: %s %d when sending %zu bytes - shutting down socket\n", + xprt->xpt_server->sv_name, + (sent < 0) ? "got error" : "sent", + sent, xdr->len + sizeof(marker)); + svc_xprt_deferred_close(xprt); + mutex_unlock(&xprt->xpt_mutex); + return -EAGAIN; } static struct svc_xprt *svc_tcp_create(struct svc_serv *serv, @@ -1207,67 +1318,18 @@ static struct svc_xprt *svc_tcp_create(struct svc_serv *serv, return svc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); } -#if defined(CONFIG_SUNRPC_BACKCHANNEL) -static struct svc_xprt *svc_bc_create_socket(struct svc_serv *, int, - struct net *, struct sockaddr *, - int, int); -static void svc_bc_sock_free(struct svc_xprt *xprt); - -static struct svc_xprt *svc_bc_tcp_create(struct svc_serv *serv, - struct net *net, - struct sockaddr *sa, int salen, - int flags) -{ - return svc_bc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); -} - -static void svc_bc_tcp_sock_detach(struct svc_xprt *xprt) -{ -} - -static struct svc_xprt_ops svc_tcp_bc_ops = { - .xpo_create = svc_bc_tcp_create, - .xpo_detach = svc_bc_tcp_sock_detach, - .xpo_free = svc_bc_sock_free, - .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, -}; - -static struct svc_xprt_class svc_tcp_bc_class = { - .xcl_name = "tcp-bc", - .xcl_owner = THIS_MODULE, - .xcl_ops = &svc_tcp_bc_ops, - .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, -}; - -static void svc_init_bc_xprt_sock(void) -{ - svc_reg_xprt_class(&svc_tcp_bc_class); -} - -static void svc_cleanup_bc_xprt_sock(void) -{ - svc_unreg_xprt_class(&svc_tcp_bc_class); -} -#else /* CONFIG_SUNRPC_BACKCHANNEL */ -static void svc_init_bc_xprt_sock(void) -{ -} - -static void svc_cleanup_bc_xprt_sock(void) -{ -} -#endif /* CONFIG_SUNRPC_BACKCHANNEL */ - -static struct svc_xprt_ops svc_tcp_ops = { +static const struct svc_xprt_ops svc_tcp_ops = { .xpo_create = svc_tcp_create, .xpo_recvfrom = svc_tcp_recvfrom, .xpo_sendto = svc_tcp_sendto, - .xpo_release_rqst = svc_release_skb, + .xpo_result_payload = svc_sock_result_payload, + .xpo_release_ctxt = svc_tcp_release_ctxt, .xpo_detach = svc_tcp_sock_detach, .xpo_free = svc_sock_free, - .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, .xpo_has_wspace = svc_tcp_has_wspace, .xpo_accept = svc_tcp_accept, + .xpo_kill_temp_xprt = svc_tcp_kill_temp_xprt, + .xpo_handshake = svc_tcp_handshake, }; static struct svc_xprt_class svc_tcp_class = { @@ -1275,20 +1337,19 @@ static struct svc_xprt_class svc_tcp_class = { .xcl_owner = THIS_MODULE, .xcl_ops = &svc_tcp_ops, .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, + .xcl_ident = XPRT_TRANSPORT_TCP, }; void svc_init_xprt_sock(void) { svc_reg_xprt_class(&svc_tcp_class); svc_reg_xprt_class(&svc_udp_class); - svc_init_bc_xprt_sock(); } void svc_cleanup_xprt_sock(void) { svc_unreg_xprt_class(&svc_tcp_class); svc_unreg_xprt_class(&svc_udp_class); - svc_cleanup_bc_xprt_sock(); } static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) @@ -1298,27 +1359,34 @@ static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_tcp_class, &svsk->sk_xprt, serv); set_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_CONG_CTRL, &svsk->sk_xprt.xpt_flags); if (sk->sk_state == TCP_LISTEN) { - dprintk("setting up TCP socket for listening\n"); + strcpy(svsk->sk_xprt.xpt_remotebuf, "listener"); set_bit(XPT_LISTENER, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_RPCB_UNREG, &svsk->sk_xprt.xpt_flags); sk->sk_data_ready = svc_tcp_listen_data_ready; set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); } else { - dprintk("setting up TCP socket for reading\n"); sk->sk_state_change = svc_tcp_state_change; - sk->sk_data_ready = svc_tcp_data_ready; - sk->sk_write_space = svc_tcp_write_space; + sk->sk_data_ready = svc_data_ready; + sk->sk_write_space = svc_write_space; - svsk->sk_reclen = 0; + svsk->sk_marker = xdr_zero; svsk->sk_tcplen = 0; svsk->sk_datalen = 0; - memset(&svsk->sk_pages[0], 0, sizeof(svsk->sk_pages)); + memset(&svsk->sk_pages[0], 0, + svsk->sk_maxpages * sizeof(struct page *)); - tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; + tcp_sock_set_nodelay(sk); set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - if (sk->sk_state != TCP_ESTABLISHED) - set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + switch (sk->sk_state) { + case TCP_SYN_RECV: + case TCP_ESTABLISHED: + break; + default: + svc_xprt_deferred_close(&svsk->sk_xprt); + } } } @@ -1335,11 +1403,23 @@ void svc_sock_update_bufs(struct svc_serv *serv) set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); spin_unlock_bh(&serv->sv_lock); } -EXPORT_SYMBOL_GPL(svc_sock_update_bufs); + +static int svc_sock_sendpages(struct svc_serv *serv, struct socket *sock, int flags) +{ + switch (sock->type) { + case SOCK_STREAM: + /* +1 for TCP record marker */ + if (flags & SVC_SOCK_TEMPORARY) + return svc_serv_maxpages(serv) + 1; + return 0; + case SOCK_DGRAM: + return SUNRPC_MAX_UDP_SENDPAGES; + } + return -EINVAL; +} /* * Initialize socket for RPC use and create svc_sock struct - * XXX: May want to setsockopt SO_SNDBUF and SO_RCVBUF. */ static struct svc_sock *svc_setup_socket(struct svc_serv *serv, struct socket *sock, @@ -1348,64 +1428,81 @@ static struct svc_sock *svc_setup_socket(struct svc_serv *serv, struct svc_sock *svsk; struct sock *inet; int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); - int err = 0; + int sendpages; + unsigned long pages; + + sendpages = svc_sock_sendpages(serv, sock, flags); + if (sendpages < 0) + return ERR_PTR(sendpages); - dprintk("svc: svc_setup_socket %p\n", sock); - svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); + pages = svc_serv_maxpages(serv); + svsk = kzalloc(struct_size(svsk, sk_pages, pages), GFP_KERNEL); if (!svsk) return ERR_PTR(-ENOMEM); + if (sendpages) { + svsk->sk_bvec = kcalloc(sendpages, sizeof(*svsk->sk_bvec), GFP_KERNEL); + if (!svsk->sk_bvec) { + kfree(svsk); + return ERR_PTR(-ENOMEM); + } + } + + svsk->sk_maxpages = pages; + inet = sock->sk; - /* Register socket with portmapper */ - if (pmap_register) + if (pmap_register) { + int err; + err = svc_register(serv, sock_net(sock->sk), inet->sk_family, inet->sk_protocol, ntohs(inet_sk(inet)->inet_sport)); - - if (err < 0) { - kfree(svsk); - return ERR_PTR(err); + if (err < 0) { + kfree(svsk->sk_bvec); + kfree(svsk); + return ERR_PTR(err); + } } - inet->sk_user_data = svsk; svsk->sk_sock = sock; svsk->sk_sk = inet; svsk->sk_ostate = inet->sk_state_change; svsk->sk_odata = inet->sk_data_ready; svsk->sk_owspace = inet->sk_write_space; + /* + * This barrier is necessary in order to prevent race condition + * with svc_data_ready(), svc_tcp_listen_data_ready(), and others + * when calling callbacks above. + */ + wmb(); + inet->sk_user_data = svsk; /* Initialize the socket */ if (sock->type == SOCK_DGRAM) svc_udp_init(svsk, serv); - else { - /* initialise setting must have enough space to - * receive and respond to one request. - */ - svc_sock_setbufsize(svsk->sk_sock, 4 * serv->sv_max_mesg, - 4 * serv->sv_max_mesg); + else svc_tcp_init(svsk, serv); - } - - dprintk("svc: svc_setup_socket created %p (inet %p)\n", - svsk, svsk->sk_sk); + trace_svcsock_new(svsk, sock); return svsk; } /** * svc_addsock - add a listener socket to an RPC service * @serv: pointer to RPC service to which to add a new listener + * @net: caller's network namespace * @fd: file descriptor of the new listener * @name_return: pointer to buffer to fill in with name of listener * @len: size of the buffer + * @cred: credential * * Fills in socket name and returns positive length of name if successful. * Name is terminated with '\n'. On error, returns a negative errno * value. */ -int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, - const size_t len) +int svc_addsock(struct svc_serv *serv, struct net *net, const int fd, + char *name_return, const size_t len, const struct cred *cred) { int err = 0; struct socket *so = sockfd_lookup(fd, &err); @@ -1416,6 +1513,9 @@ int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, if (!so) return err; + err = -EINVAL; + if (sock_net(so->sk) != net) + goto out; err = -EAFNOSUPPORT; if ((so->sk->sk_family != PF_INET) && (so->sk->sk_family != PF_INET6)) goto out; @@ -1435,8 +1535,10 @@ int svc_addsock(struct svc_serv *serv, const int fd, char *name_return, err = PTR_ERR(svsk); goto out; } - if (kernel_getsockname(svsk->sk_sock, sin, &salen) == 0) + salen = kernel_getsockname(svsk->sk_sock, sin); + if (salen >= 0) svc_xprt_set_local(&svsk->sk_xprt, sin, salen); + svsk->sk_xprt.xpt_cred = get_cred(cred); svc_add_new_perm_xprt(serv, &svsk->sk_xprt); return svc_one_sock_name(svsk, name_return, len); out: @@ -1462,12 +1564,6 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, struct sockaddr *newsin = (struct sockaddr *)&addr; int newlen; int family; - int val; - RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); - - dprintk("svc: svc_create_socket(%s, %d, %s)\n", - serv->sv_program->pg_name, protocol, - __svc_print_addr(sin, buf, sizeof(buf))); if (protocol != IPPROTO_UDP && protocol != IPPROTO_TCP) { printk(KERN_WARNING "svc: only UDP and TCP " @@ -1498,24 +1594,22 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, * getting requests from IPv4 remotes. Those should * be shunted to a PF_INET listener via rpcbind. */ - val = 1; if (family == PF_INET6) - kernel_setsockopt(sock, SOL_IPV6, IPV6_V6ONLY, - (char *)&val, sizeof(val)); - + ip6_sock_set_v6only(sock->sk); if (type == SOCK_STREAM) sock->sk->sk_reuse = SK_CAN_REUSE; /* allow address reuse */ - error = kernel_bind(sock, sin, len); + error = kernel_bind(sock, (struct sockaddr_unsized *)sin, len); if (error < 0) goto bummer; - newlen = len; - error = kernel_getsockname(sock, newsin, &newlen); + error = kernel_getsockname(sock, newsin); if (error < 0) goto bummer; + newlen = error; if (protocol == IPPROTO_TCP) { - if ((error = kernel_listen(sock, 64)) < 0) + sk_net_refcnt_upgrade(sock->sk); + if ((error = kernel_listen(sock, SOMAXCONN)) < 0) goto bummer; } @@ -1527,7 +1621,6 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen); return (struct svc_xprt *)svsk; bummer: - dprintk("svc: svc_create_socket error = %d\n", -error); sock_release(sock); return ERR_PTR(error); } @@ -1540,18 +1633,14 @@ static void svc_sock_detach(struct svc_xprt *xprt) { struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); struct sock *sk = svsk->sk_sk; - wait_queue_head_t *wq; - - dprintk("svc: svc_sock_detach(%p)\n", svsk); /* put back the old socket callbacks */ + lock_sock(sk); sk->sk_state_change = svsk->sk_ostate; sk->sk_data_ready = svsk->sk_odata; sk->sk_write_space = svsk->sk_owspace; - - wq = sk_sleep(sk); - if (wq && waitqueue_active(wq)) - wake_up_interruptible(wq); + sk->sk_user_data = NULL; + release_sock(sk); } /* @@ -1561,7 +1650,7 @@ static void svc_tcp_sock_detach(struct svc_xprt *xprt) { struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); - dprintk("svc: svc_tcp_sock_detach(%p)\n", svsk); + tls_handshake_close(svsk->sk_sock); svc_sock_detach(xprt); @@ -1577,52 +1666,17 @@ static void svc_tcp_sock_detach(struct svc_xprt *xprt) static void svc_sock_free(struct svc_xprt *xprt) { struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); - dprintk("svc: svc_sock_free(%p)\n", svsk); - - if (svsk->sk_sock->file) - sockfd_put(svsk->sk_sock); - else - sock_release(svsk->sk_sock); - kfree(svsk); -} - -#if defined(CONFIG_SUNRPC_BACKCHANNEL) -/* - * Create a back channel svc_xprt which shares the fore channel socket. - */ -static struct svc_xprt *svc_bc_create_socket(struct svc_serv *serv, - int protocol, - struct net *net, - struct sockaddr *sin, int len, - int flags) -{ - struct svc_sock *svsk; - struct svc_xprt *xprt; + struct socket *sock = svsk->sk_sock; - if (protocol != IPPROTO_TCP) { - printk(KERN_WARNING "svc: only TCP sockets" - " supported on shared back channel\n"); - return ERR_PTR(-EINVAL); - } + trace_svcsock_free(svsk, sock); - svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); - if (!svsk) - return ERR_PTR(-ENOMEM); - - xprt = &svsk->sk_xprt; - svc_xprt_init(net, &svc_tcp_bc_class, xprt, serv); - - serv->sv_bc_xprt = xprt; - - return xprt; -} + tls_handshake_cancel(sock->sk); + if (sock->file) + sockfd_put(sock); + else + sock_release(sock); -/* - * Free a back channel svc_sock. - */ -static void svc_bc_sock_free(struct svc_xprt *xprt) -{ - if (xprt) - kfree(container_of(xprt, struct svc_sock, sk_xprt)); + page_frag_cache_drain(&svsk->sk_frag_cache); + kfree(svsk->sk_bvec); + kfree(svsk); } -#endif /* CONFIG_SUNRPC_BACKCHANNEL */ diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c index c99c58e2ee66..bdb587a72422 100644 --- a/net/sunrpc/sysctl.c +++ b/net/sunrpc/sysctl.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/sysctl.c * @@ -14,7 +15,7 @@ #include <linux/sysctl.h> #include <linux/module.h> -#include <asm/uaccess.h> +#include <linux/uaccess.h> #include <linux/sunrpc/types.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/stats.h> @@ -37,47 +38,35 @@ EXPORT_SYMBOL_GPL(nfsd_debug); unsigned int nlm_debug; EXPORT_SYMBOL_GPL(nlm_debug); -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) -static struct ctl_table_header *sunrpc_table_header; -static struct ctl_table sunrpc_table[]; - -void -rpc_register_sysctl(void) -{ - if (!sunrpc_table_header) - sunrpc_table_header = register_sysctl_table(sunrpc_table); -} - -void -rpc_unregister_sysctl(void) -{ - if (sunrpc_table_header) { - unregister_sysctl_table(sunrpc_table_header); - sunrpc_table_header = NULL; - } -} - -static int proc_do_xprt(struct ctl_table *table, int write, - void __user *buffer, size_t *lenp, loff_t *ppos) +static int proc_do_xprt(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) { char tmpbuf[256]; - size_t len; + ssize_t len; - if ((*ppos && !write) || !*lenp) { + if (write || *ppos) { *lenp = 0; return 0; } len = svc_print_xprts(tmpbuf, sizeof(tmpbuf)); - return simple_read_from_buffer(buffer, *lenp, ppos, tmpbuf, len); + len = memory_read_from_buffer(buffer, *lenp, ppos, tmpbuf, len); + + if (len < 0) { + *lenp = 0; + return -EINVAL; + } + *lenp = len; + return 0; } static int -proc_dodebug(struct ctl_table *table, int write, - void __user *buffer, size_t *lenp, loff_t *ppos) +proc_dodebug(const struct ctl_table *table, int write, void *buffer, size_t *lenp, + loff_t *ppos) { - char tmpbuf[20], c, *s; - char __user *p; + char tmpbuf[20], *s = NULL; + char *p; unsigned int value; size_t left, len; @@ -89,41 +78,41 @@ proc_dodebug(struct ctl_table *table, int write, left = *lenp; if (write) { - if (!access_ok(VERIFY_READ, buffer, left)) - return -EFAULT; p = buffer; - while (left && __get_user(c, p) >= 0 && isspace(c)) - left--, p++; + while (left && isspace(*p)) { + left--; + p++; + } if (!left) goto done; if (left > sizeof(tmpbuf) - 1) return -EINVAL; - if (copy_from_user(tmpbuf, p, left)) - return -EFAULT; + memcpy(tmpbuf, p, left); tmpbuf[left] = '\0'; - for (s = tmpbuf, value = 0; '0' <= *s && *s <= '9'; s++, left--) - value = 10 * value + (*s - '0'); - if (*s && !isspace(*s)) - return -EINVAL; - while (left && isspace(*s)) - left--, s++; + value = simple_strtol(tmpbuf, &s, 0); + if (s) { + left -= (s - tmpbuf); + if (left && !isspace(*s)) + return -EINVAL; + while (left && isspace(*s)) { + left--; + s++; + } + } else + left = 0; *(unsigned int *) table->data = value; /* Display the RPC tasks on writing to rpc_debug */ if (strcmp(table->procname, "rpc_debug") == 0) rpc_show_tasks(&init_net); } else { - if (!access_ok(VERIFY_WRITE, buffer, left)) - return -EFAULT; - len = sprintf(tmpbuf, "%d", *(unsigned int *) table->data); + len = sprintf(tmpbuf, "0x%04x", *(unsigned int *) table->data); if (len > left) len = left; - if (__copy_to_user(buffer, tmpbuf, len)) - return -EFAULT; + memcpy(buffer, tmpbuf, len); if ((left -= len) > 0) { - if (put_user('\n', (char __user *)buffer + len)) - return -EFAULT; + *((char *)buffer + len) = '\n'; left--; } } @@ -134,6 +123,7 @@ done: return 0; } +static struct ctl_table_header *sunrpc_table_header; static struct ctl_table debug_table[] = { { @@ -170,16 +160,21 @@ static struct ctl_table debug_table[] = { .mode = 0444, .proc_handler = proc_do_xprt, }, - { } }; -static struct ctl_table sunrpc_table[] = { - { - .procname = "sunrpc", - .mode = 0555, - .child = debug_table - }, - { } -}; +void +rpc_register_sysctl(void) +{ + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl("sunrpc", debug_table); +} +void +rpc_unregister_sysctl(void) +{ + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +} #endif diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c new file mode 100644 index 000000000000..8b01b7ae2690 --- /dev/null +++ b/net/sunrpc/sysfs.c @@ -0,0 +1,829 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2020 Anna Schumaker <Anna.Schumaker@Netapp.com> + */ +#include <linux/sunrpc/clnt.h> +#include <linux/kobject.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/xprtsock.h> + +#include "sysfs.h" + +struct xprt_addr { + const char *addr; + struct rcu_head rcu; +}; + +static void free_xprt_addr(struct rcu_head *head) +{ + struct xprt_addr *addr = container_of(head, struct xprt_addr, rcu); + + kfree(addr->addr); + kfree(addr); +} + +static struct kset *rpc_sunrpc_kset; +static struct kobject *rpc_sunrpc_client_kobj, *rpc_sunrpc_xprt_switch_kobj; + +static void rpc_sysfs_object_release(struct kobject *kobj) +{ + kfree(kobj); +} + +static const struct kobj_ns_type_operations * +rpc_sysfs_object_child_ns_type(const struct kobject *kobj) +{ + return &net_ns_type_operations; +} + +static const struct kobj_type rpc_sysfs_object_type = { + .release = rpc_sysfs_object_release, + .sysfs_ops = &kobj_sysfs_ops, + .child_ns_type = rpc_sysfs_object_child_ns_type, +}; + +static struct kobject *rpc_sysfs_object_alloc(const char *name, + struct kset *kset, + struct kobject *parent) +{ + struct kobject *kobj; + + kobj = kzalloc(sizeof(*kobj), GFP_KERNEL); + if (kobj) { + kobj->kset = kset; + if (kobject_init_and_add(kobj, &rpc_sysfs_object_type, + parent, "%s", name) == 0) + return kobj; + kobject_put(kobj); + } + return NULL; +} + +static inline struct rpc_clnt * +rpc_sysfs_client_kobj_get_clnt(struct kobject *kobj) +{ + struct rpc_sysfs_client *c = container_of(kobj, + struct rpc_sysfs_client, kobject); + struct rpc_clnt *ret = c->clnt; + + return refcount_inc_not_zero(&ret->cl_count) ? ret : NULL; +} + +static inline struct rpc_xprt * +rpc_sysfs_xprt_kobj_get_xprt(struct kobject *kobj) +{ + struct rpc_sysfs_xprt *x = container_of(kobj, + struct rpc_sysfs_xprt, kobject); + + return xprt_get(x->xprt); +} + +static inline struct rpc_xprt_switch * +rpc_sysfs_xprt_kobj_get_xprt_switch(struct kobject *kobj) +{ + struct rpc_sysfs_xprt *x = container_of(kobj, + struct rpc_sysfs_xprt, kobject); + + return xprt_switch_get(x->xprt_switch); +} + +static inline struct rpc_xprt_switch * +rpc_sysfs_xprt_switch_kobj_get_xprt(struct kobject *kobj) +{ + struct rpc_sysfs_xprt_switch *x = container_of(kobj, + struct rpc_sysfs_xprt_switch, kobject); + + return xprt_switch_get(x->xprt_switch); +} + +static ssize_t rpc_sysfs_clnt_version_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_clnt *clnt = rpc_sysfs_client_kobj_get_clnt(kobj); + ssize_t ret; + + if (!clnt) + return sprintf(buf, "<closed>\n"); + + ret = sprintf(buf, "%u", clnt->cl_vers); + refcount_dec(&clnt->cl_count); + return ret; +} + +static ssize_t rpc_sysfs_clnt_program_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_clnt *clnt = rpc_sysfs_client_kobj_get_clnt(kobj); + ssize_t ret; + + if (!clnt) + return sprintf(buf, "<closed>\n"); + + ret = sprintf(buf, "%s", clnt->cl_program->name); + refcount_dec(&clnt->cl_count); + return ret; +} + +static ssize_t rpc_sysfs_clnt_max_connect_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_clnt *clnt = rpc_sysfs_client_kobj_get_clnt(kobj); + ssize_t ret; + + if (!clnt) + return sprintf(buf, "<closed>\n"); + + ret = sprintf(buf, "%u\n", clnt->cl_max_connect); + refcount_dec(&clnt->cl_count); + return ret; +} + +static ssize_t rpc_sysfs_xprt_dstaddr_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + ssize_t ret; + + if (!xprt) { + ret = sprintf(buf, "<closed>\n"); + goto out; + } + ret = sprintf(buf, "%s\n", xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt_put(xprt); +out: + return ret; +} + +static ssize_t rpc_sysfs_xprt_srcaddr_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + size_t buflen = PAGE_SIZE; + ssize_t ret; + + if (!xprt || !xprt_connected(xprt)) { + ret = sprintf(buf, "<closed>\n"); + } else if (xprt->ops->get_srcaddr) { + ret = xprt->ops->get_srcaddr(xprt, buf, buflen); + if (ret > 0) { + if (ret < buflen - 1) { + buf[ret] = '\n'; + ret++; + buf[ret] = '\0'; + } + } else + ret = sprintf(buf, "<closed>\n"); + } else + ret = sprintf(buf, "<not a socket>\n"); + xprt_put(xprt); + return ret; +} + +static const char *xprtsec_strings[] = { + [RPC_XPRTSEC_NONE] = "none", + [RPC_XPRTSEC_TLS_ANON] = "tls-anon", + [RPC_XPRTSEC_TLS_X509] = "tls-x509", +}; + +static ssize_t rpc_sysfs_xprt_xprtsec_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + ssize_t ret; + + if (!xprt) { + ret = sprintf(buf, "<closed>\n"); + goto out; + } + + ret = sprintf(buf, "%s\n", xprtsec_strings[xprt->xprtsec.policy]); + xprt_put(xprt); +out: + return ret; + +} + +static ssize_t rpc_sysfs_xprt_info_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + unsigned short srcport = 0; + size_t buflen = PAGE_SIZE; + ssize_t ret; + + if (!xprt || !xprt_connected(xprt)) { + ret = sprintf(buf, "<closed>\n"); + goto out; + } + + if (xprt->ops->get_srcport) + srcport = xprt->ops->get_srcport(xprt); + + ret = snprintf(buf, buflen, + "last_used=%lu\ncur_cong=%lu\ncong_win=%lu\n" + "max_num_slots=%u\nmin_num_slots=%u\nnum_reqs=%u\n" + "binding_q_len=%u\nsending_q_len=%u\npending_q_len=%u\n" + "backlog_q_len=%u\nmain_xprt=%d\nsrc_port=%u\n" + "tasks_queuelen=%ld\ndst_port=%s\n", + xprt->last_used, xprt->cong, xprt->cwnd, xprt->max_reqs, + xprt->min_reqs, xprt->num_reqs, xprt->binding.qlen, + xprt->sending.qlen, xprt->pending.qlen, + xprt->backlog.qlen, xprt->main, srcport, + atomic_long_read(&xprt->queuelen), + xprt->address_strings[RPC_DISPLAY_PORT]); +out: + xprt_put(xprt); + return ret; +} + +static ssize_t rpc_sysfs_xprt_state_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + ssize_t ret; + int locked, connected, connecting, close_wait, bound, binding, + closing, congested, cwnd_wait, write_space, offline, remove; + + if (!(xprt && xprt->state)) { + ret = sprintf(buf, "state=CLOSED\n"); + } else { + locked = test_bit(XPRT_LOCKED, &xprt->state); + connected = test_bit(XPRT_CONNECTED, &xprt->state); + connecting = test_bit(XPRT_CONNECTING, &xprt->state); + close_wait = test_bit(XPRT_CLOSE_WAIT, &xprt->state); + bound = test_bit(XPRT_BOUND, &xprt->state); + binding = test_bit(XPRT_BINDING, &xprt->state); + closing = test_bit(XPRT_CLOSING, &xprt->state); + congested = test_bit(XPRT_CONGESTED, &xprt->state); + cwnd_wait = test_bit(XPRT_CWND_WAIT, &xprt->state); + write_space = test_bit(XPRT_WRITE_SPACE, &xprt->state); + offline = test_bit(XPRT_OFFLINE, &xprt->state); + remove = test_bit(XPRT_REMOVE, &xprt->state); + + ret = sprintf(buf, "state=%s %s %s %s %s %s %s %s %s %s %s %s\n", + locked ? "LOCKED" : "", + connected ? "CONNECTED" : "", + connecting ? "CONNECTING" : "", + close_wait ? "CLOSE_WAIT" : "", + bound ? "BOUND" : "", + binding ? "BOUNDING" : "", + closing ? "CLOSING" : "", + congested ? "CONGESTED" : "", + cwnd_wait ? "CWND_WAIT" : "", + write_space ? "WRITE_SPACE" : "", + offline ? "OFFLINE" : "", + remove ? "REMOVE" : ""); + } + + xprt_put(xprt); + return ret; +} + +static ssize_t rpc_sysfs_xprt_del_xprt_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + return sprintf(buf, "# delete this xprt\n"); +} + + +static ssize_t rpc_sysfs_xprt_switch_info_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt_switch *xprt_switch = + rpc_sysfs_xprt_switch_kobj_get_xprt(kobj); + ssize_t ret; + + if (!xprt_switch) + return 0; + ret = sprintf(buf, "num_xprts=%u\nnum_active=%u\n" + "num_unique_destaddr=%u\nqueue_len=%ld\n", + xprt_switch->xps_nxprts, xprt_switch->xps_nactive, + xprt_switch->xps_nunique_destaddr_xprts, + atomic_long_read(&xprt_switch->xps_queuelen)); + xprt_switch_put(xprt_switch); + return ret; +} + +static ssize_t rpc_sysfs_xprt_switch_add_xprt_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + return sprintf(buf, "# add one xprt to this xprt_switch\n"); +} + +static ssize_t rpc_sysfs_xprt_switch_add_xprt_store(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt_switch *xprt_switch = + rpc_sysfs_xprt_switch_kobj_get_xprt(kobj); + struct xprt_create xprt_create_args; + struct rpc_xprt *xprt, *new; + + if (!xprt_switch) + return 0; + + xprt = rpc_xprt_switch_get_main_xprt(xprt_switch); + if (!xprt) + goto out; + + xprt_create_args.ident = xprt->xprt_class->ident; + xprt_create_args.net = xprt->xprt_net; + xprt_create_args.dstaddr = (struct sockaddr *)&xprt->addr; + xprt_create_args.addrlen = xprt->addrlen; + xprt_create_args.servername = xprt->servername; + xprt_create_args.bc_xprt = xprt->bc_xprt; + xprt_create_args.xprtsec = xprt->xprtsec; + xprt_create_args.connect_timeout = xprt->connect_timeout; + xprt_create_args.reconnect_timeout = xprt->max_reconnect_timeout; + + new = xprt_create_transport(&xprt_create_args); + if (IS_ERR_OR_NULL(new)) { + count = PTR_ERR(new); + goto out_put_xprt; + } + + rpc_xprt_switch_add_xprt(xprt_switch, new); + xprt_put(new); + +out_put_xprt: + xprt_put(xprt); +out: + xprt_switch_put(xprt_switch); + return count; +} + +static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + struct sockaddr *saddr; + char *dst_addr; + int port; + struct xprt_addr *saved_addr; + size_t buf_len; + + if (!xprt) + return 0; + if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP || + xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS || + xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) { + xprt_put(xprt); + return -EOPNOTSUPP; + } + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + count = -EINTR; + goto out_put; + } + saddr = (struct sockaddr *)&xprt->addr; + port = rpc_get_port(saddr); + + /* buf_len is the len until the first occurrence of either + * '\n' or '\0' + */ + buf_len = strcspn(buf, "\n"); + + dst_addr = kstrndup(buf, buf_len, GFP_KERNEL); + if (!dst_addr) + goto out_err; + saved_addr = kzalloc(sizeof(*saved_addr), GFP_KERNEL); + if (!saved_addr) + goto out_err_free; + saved_addr->addr = + rcu_dereference_raw(xprt->address_strings[RPC_DISPLAY_ADDR]); + rcu_assign_pointer(xprt->address_strings[RPC_DISPLAY_ADDR], dst_addr); + call_rcu(&saved_addr->rcu, free_xprt_addr); + xprt->addrlen = rpc_pton(xprt->xprt_net, buf, buf_len, saddr, + sizeof(*saddr)); + rpc_set_port(saddr, port); + + xprt_force_disconnect(xprt); +out: + xprt_release_write(xprt, NULL); +out_put: + xprt_put(xprt); + return count; +out_err_free: + kfree(dst_addr); +out_err: + count = -ENOMEM; + goto out; +} + +static ssize_t rpc_sysfs_xprt_state_change(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + int offline = 0, online = 0, remove = 0; + struct rpc_xprt_switch *xps = rpc_sysfs_xprt_kobj_get_xprt_switch(kobj); + + if (!xprt || !xps) { + count = 0; + goto out_put; + } + + if (!strncmp(buf, "offline", 7)) + offline = 1; + else if (!strncmp(buf, "online", 6)) + online = 1; + else if (!strncmp(buf, "remove", 6)) + remove = 1; + else { + count = -EINVAL; + goto out_put; + } + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + count = -EINTR; + goto out_put; + } + if (xprt->main) { + count = -EINVAL; + goto release_tasks; + } + if (offline) { + xprt_set_offline_locked(xprt, xps); + } else if (online) { + xprt_set_online_locked(xprt, xps); + } else if (remove) { + if (test_bit(XPRT_OFFLINE, &xprt->state)) + xprt_delete_locked(xprt, xps); + else + count = -EINVAL; + } + +release_tasks: + xprt_release_write(xprt, NULL); +out_put: + xprt_put(xprt); + xprt_switch_put(xps); + return count; +} + +static ssize_t rpc_sysfs_xprt_del_xprt(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + struct rpc_xprt_switch *xps = rpc_sysfs_xprt_kobj_get_xprt_switch(kobj); + + if (!xprt || !xps) { + count = 0; + goto out; + } + + if (xprt->main) { + count = -EINVAL; + goto release_tasks; + } + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + count = -EINTR; + goto out_put; + } + + xprt_set_offline_locked(xprt, xps); + xprt_delete_locked(xprt, xps); + +release_tasks: + xprt_release_write(xprt, NULL); +out_put: + xprt_put(xprt); + xprt_switch_put(xps); +out: + return count; +} + +int rpc_sysfs_init(void) +{ + rpc_sunrpc_kset = kset_create_and_add("sunrpc", NULL, kernel_kobj); + if (!rpc_sunrpc_kset) + return -ENOMEM; + rpc_sunrpc_client_kobj = + rpc_sysfs_object_alloc("rpc-clients", rpc_sunrpc_kset, NULL); + if (!rpc_sunrpc_client_kobj) + goto err_client; + rpc_sunrpc_xprt_switch_kobj = + rpc_sysfs_object_alloc("xprt-switches", rpc_sunrpc_kset, NULL); + if (!rpc_sunrpc_xprt_switch_kobj) + goto err_switch; + return 0; +err_switch: + kobject_put(rpc_sunrpc_client_kobj); + rpc_sunrpc_client_kobj = NULL; +err_client: + kset_unregister(rpc_sunrpc_kset); + rpc_sunrpc_kset = NULL; + return -ENOMEM; +} + +static void rpc_sysfs_client_release(struct kobject *kobj) +{ + struct rpc_sysfs_client *c; + + c = container_of(kobj, struct rpc_sysfs_client, kobject); + kfree(c); +} + +static void rpc_sysfs_xprt_switch_release(struct kobject *kobj) +{ + struct rpc_sysfs_xprt_switch *xprt_switch; + + xprt_switch = container_of(kobj, struct rpc_sysfs_xprt_switch, kobject); + kfree(xprt_switch); +} + +static void rpc_sysfs_xprt_release(struct kobject *kobj) +{ + struct rpc_sysfs_xprt *xprt; + + xprt = container_of(kobj, struct rpc_sysfs_xprt, kobject); + kfree(xprt); +} + +static const void *rpc_sysfs_client_namespace(const struct kobject *kobj) +{ + return container_of(kobj, struct rpc_sysfs_client, kobject)->net; +} + +static const void *rpc_sysfs_xprt_switch_namespace(const struct kobject *kobj) +{ + return container_of(kobj, struct rpc_sysfs_xprt_switch, kobject)->net; +} + +static const void *rpc_sysfs_xprt_namespace(const struct kobject *kobj) +{ + return container_of(kobj, struct rpc_sysfs_xprt, + kobject)->xprt->xprt_net; +} + +static struct kobj_attribute rpc_sysfs_clnt_version = __ATTR(rpc_version, + 0444, rpc_sysfs_clnt_version_show, NULL); + +static struct kobj_attribute rpc_sysfs_clnt_program = __ATTR(program, + 0444, rpc_sysfs_clnt_program_show, NULL); + +static struct kobj_attribute rpc_sysfs_clnt_max_connect = __ATTR(max_connect, + 0444, rpc_sysfs_clnt_max_connect_show, NULL); + +static struct attribute *rpc_sysfs_rpc_clnt_attrs[] = { + &rpc_sysfs_clnt_version.attr, + &rpc_sysfs_clnt_program.attr, + &rpc_sysfs_clnt_max_connect.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rpc_sysfs_rpc_clnt); + +static struct kobj_attribute rpc_sysfs_xprt_dstaddr = __ATTR(dstaddr, + 0644, rpc_sysfs_xprt_dstaddr_show, rpc_sysfs_xprt_dstaddr_store); + +static struct kobj_attribute rpc_sysfs_xprt_srcaddr = __ATTR(srcaddr, + 0644, rpc_sysfs_xprt_srcaddr_show, NULL); + +static struct kobj_attribute rpc_sysfs_xprt_xprtsec = __ATTR(xprtsec, + 0644, rpc_sysfs_xprt_xprtsec_show, NULL); + +static struct kobj_attribute rpc_sysfs_xprt_info = __ATTR(xprt_info, + 0444, rpc_sysfs_xprt_info_show, NULL); + +static struct kobj_attribute rpc_sysfs_xprt_change_state = __ATTR(xprt_state, + 0644, rpc_sysfs_xprt_state_show, rpc_sysfs_xprt_state_change); + +static struct kobj_attribute rpc_sysfs_xprt_del = __ATTR(del_xprt, + 0644, rpc_sysfs_xprt_del_xprt_show, rpc_sysfs_xprt_del_xprt); + +static struct attribute *rpc_sysfs_xprt_attrs[] = { + &rpc_sysfs_xprt_dstaddr.attr, + &rpc_sysfs_xprt_srcaddr.attr, + &rpc_sysfs_xprt_xprtsec.attr, + &rpc_sysfs_xprt_info.attr, + &rpc_sysfs_xprt_change_state.attr, + &rpc_sysfs_xprt_del.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rpc_sysfs_xprt); + +static struct kobj_attribute rpc_sysfs_xprt_switch_info = + __ATTR(xprt_switch_info, 0444, rpc_sysfs_xprt_switch_info_show, NULL); + +static struct kobj_attribute rpc_sysfs_xprt_switch_add_xprt = + __ATTR(add_xprt, 0644, rpc_sysfs_xprt_switch_add_xprt_show, + rpc_sysfs_xprt_switch_add_xprt_store); + +static struct attribute *rpc_sysfs_xprt_switch_attrs[] = { + &rpc_sysfs_xprt_switch_info.attr, + &rpc_sysfs_xprt_switch_add_xprt.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rpc_sysfs_xprt_switch); + +static const struct kobj_type rpc_sysfs_client_type = { + .release = rpc_sysfs_client_release, + .default_groups = rpc_sysfs_rpc_clnt_groups, + .sysfs_ops = &kobj_sysfs_ops, + .namespace = rpc_sysfs_client_namespace, +}; + +static const struct kobj_type rpc_sysfs_xprt_switch_type = { + .release = rpc_sysfs_xprt_switch_release, + .default_groups = rpc_sysfs_xprt_switch_groups, + .sysfs_ops = &kobj_sysfs_ops, + .namespace = rpc_sysfs_xprt_switch_namespace, +}; + +static const struct kobj_type rpc_sysfs_xprt_type = { + .release = rpc_sysfs_xprt_release, + .default_groups = rpc_sysfs_xprt_groups, + .sysfs_ops = &kobj_sysfs_ops, + .namespace = rpc_sysfs_xprt_namespace, +}; + +void rpc_sysfs_exit(void) +{ + kobject_put(rpc_sunrpc_client_kobj); + kobject_put(rpc_sunrpc_xprt_switch_kobj); + kset_unregister(rpc_sunrpc_kset); +} + +static struct rpc_sysfs_client *rpc_sysfs_client_alloc(struct kobject *parent, + struct net *net, + int clid) +{ + struct rpc_sysfs_client *p; + + p = kzalloc(sizeof(*p), GFP_KERNEL); + if (p) { + p->net = net; + p->kobject.kset = rpc_sunrpc_kset; + if (kobject_init_and_add(&p->kobject, &rpc_sysfs_client_type, + parent, "clnt-%d", clid) == 0) + return p; + kobject_put(&p->kobject); + } + return NULL; +} + +static struct rpc_sysfs_xprt_switch * +rpc_sysfs_xprt_switch_alloc(struct kobject *parent, + struct rpc_xprt_switch *xprt_switch, + struct net *net, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt_switch *p; + + p = kzalloc(sizeof(*p), gfp_flags); + if (p) { + p->net = net; + p->kobject.kset = rpc_sunrpc_kset; + if (kobject_init_and_add(&p->kobject, + &rpc_sysfs_xprt_switch_type, + parent, "switch-%d", + xprt_switch->xps_id) == 0) + return p; + kobject_put(&p->kobject); + } + return NULL; +} + +static struct rpc_sysfs_xprt *rpc_sysfs_xprt_alloc(struct kobject *parent, + struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt *p; + + p = kzalloc(sizeof(*p), gfp_flags); + if (!p) + goto out; + p->kobject.kset = rpc_sunrpc_kset; + if (kobject_init_and_add(&p->kobject, &rpc_sysfs_xprt_type, + parent, "xprt-%d-%s", xprt->id, + xprt->address_strings[RPC_DISPLAY_PROTO]) == 0) + return p; + kobject_put(&p->kobject); +out: + return NULL; +} + +void rpc_sysfs_client_setup(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xprt_switch, + struct net *net) +{ + struct rpc_sysfs_client *rpc_client; + struct rpc_sysfs_xprt_switch *xswitch = + (struct rpc_sysfs_xprt_switch *)xprt_switch->xps_sysfs; + + if (!xswitch) + return; + + rpc_client = rpc_sysfs_client_alloc(rpc_sunrpc_client_kobj, + net, clnt->cl_clid); + if (rpc_client) { + char name[] = "switch"; + int ret; + + clnt->cl_sysfs = rpc_client; + rpc_client->clnt = clnt; + rpc_client->xprt_switch = xprt_switch; + kobject_uevent(&rpc_client->kobject, KOBJ_ADD); + ret = sysfs_create_link_nowarn(&rpc_client->kobject, + &xswitch->kobject, name); + if (ret) + pr_warn("can't create link to %s in sysfs (%d)\n", + name, ret); + } +} + +void rpc_sysfs_xprt_switch_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt_switch *rpc_xprt_switch; + struct net *net; + + if (xprt_switch->xps_net) + net = xprt_switch->xps_net; + else + net = xprt->xprt_net; + rpc_xprt_switch = + rpc_sysfs_xprt_switch_alloc(rpc_sunrpc_xprt_switch_kobj, + xprt_switch, net, gfp_flags); + if (rpc_xprt_switch) { + xprt_switch->xps_sysfs = rpc_xprt_switch; + rpc_xprt_switch->xprt_switch = xprt_switch; + rpc_xprt_switch->xprt = xprt; + kobject_uevent(&rpc_xprt_switch->kobject, KOBJ_ADD); + } else { + xprt_switch->xps_sysfs = NULL; + } +} + +void rpc_sysfs_xprt_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt *rpc_xprt; + struct rpc_sysfs_xprt_switch *switch_obj = + (struct rpc_sysfs_xprt_switch *)xprt_switch->xps_sysfs; + + if (!switch_obj) + return; + + rpc_xprt = rpc_sysfs_xprt_alloc(&switch_obj->kobject, xprt, gfp_flags); + if (rpc_xprt) { + xprt->xprt_sysfs = rpc_xprt; + rpc_xprt->xprt = xprt; + rpc_xprt->xprt_switch = xprt_switch; + kobject_uevent(&rpc_xprt->kobject, KOBJ_ADD); + } +} + +void rpc_sysfs_client_destroy(struct rpc_clnt *clnt) +{ + struct rpc_sysfs_client *rpc_client = clnt->cl_sysfs; + + if (rpc_client) { + char name[] = "switch"; + + sysfs_remove_link(&rpc_client->kobject, name); + kobject_uevent(&rpc_client->kobject, KOBJ_REMOVE); + kobject_del(&rpc_client->kobject); + kobject_put(&rpc_client->kobject); + clnt->cl_sysfs = NULL; + } +} + +void rpc_sysfs_xprt_switch_destroy(struct rpc_xprt_switch *xprt_switch) +{ + struct rpc_sysfs_xprt_switch *rpc_xprt_switch = xprt_switch->xps_sysfs; + + if (rpc_xprt_switch) { + kobject_uevent(&rpc_xprt_switch->kobject, KOBJ_REMOVE); + kobject_del(&rpc_xprt_switch->kobject); + kobject_put(&rpc_xprt_switch->kobject); + xprt_switch->xps_sysfs = NULL; + } +} + +void rpc_sysfs_xprt_destroy(struct rpc_xprt *xprt) +{ + struct rpc_sysfs_xprt *rpc_xprt = xprt->xprt_sysfs; + + if (rpc_xprt) { + kobject_uevent(&rpc_xprt->kobject, KOBJ_REMOVE); + kobject_del(&rpc_xprt->kobject); + kobject_put(&rpc_xprt->kobject); + xprt->xprt_sysfs = NULL; + } +} diff --git a/net/sunrpc/sysfs.h b/net/sunrpc/sysfs.h new file mode 100644 index 000000000000..d2dd77a0a0e9 --- /dev/null +++ b/net/sunrpc/sysfs.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2020 Anna Schumaker <Anna.Schumaker@Netapp.com> + */ +#ifndef __SUNRPC_SYSFS_H +#define __SUNRPC_SYSFS_H + +struct rpc_sysfs_xprt_switch { + struct kobject kobject; + struct net *net; + struct rpc_xprt_switch *xprt_switch; + struct rpc_xprt *xprt; +}; + +struct rpc_sysfs_xprt { + struct kobject kobject; + struct rpc_xprt *xprt; + struct rpc_xprt_switch *xprt_switch; +}; + +int rpc_sysfs_init(void); +void rpc_sysfs_exit(void); + +void rpc_sysfs_client_setup(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xprt_switch, + struct net *net); +void rpc_sysfs_client_destroy(struct rpc_clnt *clnt); +void rpc_sysfs_xprt_switch_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, gfp_t gfp_flags); +void rpc_sysfs_xprt_switch_destroy(struct rpc_xprt_switch *xprt); +void rpc_sysfs_xprt_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, gfp_t gfp_flags); +void rpc_sysfs_xprt_destroy(struct rpc_xprt *xprt); + +#endif diff --git a/net/sunrpc/timer.c b/net/sunrpc/timer.c index 08881d0c9672..81ae35b3764f 100644 --- a/net/sunrpc/timer.c +++ b/net/sunrpc/timer.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/timer.c * diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c index 75edcfad6e26..70efc727a9cd 100644 --- a/net/sunrpc/xdr.c +++ b/net/sunrpc/xdr.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/xdr.c * @@ -15,6 +16,11 @@ #include <linux/errno.h> #include <linux/sunrpc/xdr.h> #include <linux/sunrpc/msg_prot.h> +#include <linux/bvec.h> +#include <trace/events/sunrpc.h> + +static void _copy_to_pages(struct page **, size_t, const char *, size_t); + /* * XDR functions for basic NFS types @@ -31,19 +37,6 @@ xdr_encode_netobj(__be32 *p, const struct xdr_netobj *obj) } EXPORT_SYMBOL_GPL(xdr_encode_netobj); -__be32 * -xdr_decode_netobj(__be32 *p, struct xdr_netobj *obj) -{ - unsigned int len; - - if ((len = be32_to_cpu(*p++)) > XDR_MAX_NETOBJ) - return NULL; - obj->len = len; - obj->data = (u8 *) p; - return p + XDR_QUADLEN(len); -} -EXPORT_SYMBOL_GPL(xdr_decode_netobj); - /** * xdr_encode_opaque_fixed - Encode fixed length opaque data * @p: pointer to current position in XDR buffer. @@ -96,29 +89,13 @@ xdr_encode_string(__be32 *p, const char *string) } EXPORT_SYMBOL_GPL(xdr_encode_string); -__be32 * -xdr_decode_string_inplace(__be32 *p, char **sp, - unsigned int *lenp, unsigned int maxlen) -{ - u32 len; - - len = be32_to_cpu(*p++); - if (len > maxlen) - return NULL; - *lenp = len; - *sp = (char *) p; - return p + XDR_QUADLEN(len); -} -EXPORT_SYMBOL_GPL(xdr_decode_string_inplace); - /** * xdr_terminate_string - '\0'-terminate a string residing in an xdr_buf * @buf: XDR buffer where string resides * @len: length of string, in bytes * */ -void -xdr_terminate_string(struct xdr_buf *buf, const u32 len) +void xdr_terminate_string(const struct xdr_buf *buf, const u32 len) { char *kaddr; @@ -128,6 +105,97 @@ xdr_terminate_string(struct xdr_buf *buf, const u32 len) } EXPORT_SYMBOL_GPL(xdr_terminate_string); +size_t xdr_buf_pagecount(const struct xdr_buf *buf) +{ + if (!buf->page_len) + return 0; + return (buf->page_base + buf->page_len + PAGE_SIZE - 1) >> PAGE_SHIFT; +} + +int +xdr_alloc_bvec(struct xdr_buf *buf, gfp_t gfp) +{ + size_t i, n = xdr_buf_pagecount(buf); + + if (n != 0 && buf->bvec == NULL) { + buf->bvec = kmalloc_array(n, sizeof(buf->bvec[0]), gfp); + if (!buf->bvec) + return -ENOMEM; + for (i = 0; i < n; i++) { + bvec_set_page(&buf->bvec[i], buf->pages[i], PAGE_SIZE, + 0); + } + } + return 0; +} + +void +xdr_free_bvec(struct xdr_buf *buf) +{ + kfree(buf->bvec); + buf->bvec = NULL; +} + +/** + * xdr_buf_to_bvec - Copy components of an xdr_buf into a bio_vec array + * @bvec: bio_vec array to populate + * @bvec_size: element count of @bio_vec + * @xdr: xdr_buf to be copied + * + * Returns the number of entries consumed in @bvec. + */ +unsigned int xdr_buf_to_bvec(struct bio_vec *bvec, unsigned int bvec_size, + const struct xdr_buf *xdr) +{ + const struct kvec *head = xdr->head; + const struct kvec *tail = xdr->tail; + unsigned int count = 0; + + if (head->iov_len) { + bvec_set_virt(bvec++, head->iov_base, head->iov_len); + ++count; + } + + if (xdr->page_len) { + unsigned int offset, len, remaining; + struct page **pages = xdr->pages; + + offset = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining > 0) { + len = min_t(unsigned int, remaining, + PAGE_SIZE - offset); + bvec_set_page(bvec++, *pages++, len, offset); + remaining -= len; + offset = 0; + if (unlikely(++count > bvec_size)) + goto bvec_overflow; + } + } + + if (tail->iov_len) { + bvec_set_virt(bvec, tail->iov_base, tail->iov_len); + if (unlikely(++count > bvec_size)) + goto bvec_overflow; + } + + return count; + +bvec_overflow: + pr_warn_once("%s: bio_vec array overflow\n", __func__); + return count - 1; +} +EXPORT_SYMBOL_GPL(xdr_buf_to_bvec); + +/** + * xdr_inline_pages - Prepare receive buffer for a large reply + * @xdr: xdr_buf into which reply will be placed + * @offset: expected offset where data payload will start, in bytes + * @pages: vector of struct page pointers + * @base: offset in first page where receive should start, in bytes + * @len: expected size of the upper layer data payload, in bytes + * + */ void xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, struct page **pages, unsigned int base, unsigned int len) @@ -145,7 +213,6 @@ xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, tail->iov_base = buf + offset; tail->iov_len = buflen - offset; - xdr->buflen += len; } EXPORT_SYMBOL_GPL(xdr_inline_pages); @@ -155,7 +222,7 @@ EXPORT_SYMBOL_GPL(xdr_inline_pages); */ /** - * _shift_data_right_pages + * _shift_data_left_pages * @pages: vector of pages containing both the source and dest memory area. * @pgto_base: page vector address of destination * @pgfrom_base: page vector address of source @@ -165,6 +232,71 @@ EXPORT_SYMBOL_GPL(xdr_inline_pages); * the same way: * if a memory area starts at byte 'base' in page 'pages[i]', * then its address is given as (i << PAGE_CACHE_SHIFT) + base + * Alse note: pgto_base must be < pgfrom_base, but the memory areas + * they point to may overlap. + */ +static void +_shift_data_left_pages(struct page **pages, size_t pgto_base, + size_t pgfrom_base, size_t len) +{ + struct page **pgfrom, **pgto; + char *vfrom, *vto; + size_t copy; + + BUG_ON(pgfrom_base <= pgto_base); + + if (!len) + return; + + pgto = pages + (pgto_base >> PAGE_SHIFT); + pgfrom = pages + (pgfrom_base >> PAGE_SHIFT); + + pgto_base &= ~PAGE_MASK; + pgfrom_base &= ~PAGE_MASK; + + do { + if (pgto_base >= PAGE_SIZE) { + pgto_base = 0; + pgto++; + } + if (pgfrom_base >= PAGE_SIZE){ + pgfrom_base = 0; + pgfrom++; + } + + copy = len; + if (copy > (PAGE_SIZE - pgto_base)) + copy = PAGE_SIZE - pgto_base; + if (copy > (PAGE_SIZE - pgfrom_base)) + copy = PAGE_SIZE - pgfrom_base; + + vto = kmap_atomic(*pgto); + if (*pgto != *pgfrom) { + vfrom = kmap_atomic(*pgfrom); + memcpy(vto + pgto_base, vfrom + pgfrom_base, copy); + kunmap_atomic(vfrom); + } else + memmove(vto + pgto_base, vto + pgfrom_base, copy); + flush_dcache_page(*pgto); + kunmap_atomic(vto); + + pgto_base += copy; + pgfrom_base += copy; + + } while ((len -= copy) != 0); +} + +/** + * _shift_data_right_pages + * @pages: vector of pages containing both the source and dest memory area. + * @pgto_base: page vector address of destination + * @pgfrom_base: page vector address of source + * @len: number of bytes to copy + * + * Note: the addresses pgto_base and pgfrom_base are both calculated in + * the same way: + * if a memory area starts at byte 'base' in page 'pages[i]', + * then its address is given as (i << PAGE_SHIFT) + base * Also note: pgfrom_base must be < pgto_base, but the memory areas * they point to may overlap. */ @@ -178,23 +310,26 @@ _shift_data_right_pages(struct page **pages, size_t pgto_base, BUG_ON(pgto_base <= pgfrom_base); + if (!len) + return; + pgto_base += len; pgfrom_base += len; - pgto = pages + (pgto_base >> PAGE_CACHE_SHIFT); - pgfrom = pages + (pgfrom_base >> PAGE_CACHE_SHIFT); + pgto = pages + (pgto_base >> PAGE_SHIFT); + pgfrom = pages + (pgfrom_base >> PAGE_SHIFT); - pgto_base &= ~PAGE_CACHE_MASK; - pgfrom_base &= ~PAGE_CACHE_MASK; + pgto_base &= ~PAGE_MASK; + pgfrom_base &= ~PAGE_MASK; do { /* Are any pointers crossing a page boundary? */ if (pgto_base == 0) { - pgto_base = PAGE_CACHE_SIZE; + pgto_base = PAGE_SIZE; pgto--; } if (pgfrom_base == 0) { - pgfrom_base = PAGE_CACHE_SIZE; + pgfrom_base = PAGE_SIZE; pgfrom--; } @@ -207,10 +342,13 @@ _shift_data_right_pages(struct page **pages, size_t pgto_base, pgfrom_base -= copy; vto = kmap_atomic(*pgto); - vfrom = kmap_atomic(*pgfrom); - memmove(vto + pgto_base, vfrom + pgfrom_base, copy); + if (*pgto != *pgfrom) { + vfrom = kmap_atomic(*pgfrom); + memcpy(vto + pgto_base, vfrom + pgfrom_base, copy); + kunmap_atomic(vfrom); + } else + memmove(vto + pgto_base, vto + pgfrom_base, copy); flush_dcache_page(*pgto); - kunmap_atomic(vfrom); kunmap_atomic(vto); } while ((len -= copy) != 0); @@ -233,11 +371,14 @@ _copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) char *vto; size_t copy; - pgto = pages + (pgbase >> PAGE_CACHE_SHIFT); - pgbase &= ~PAGE_CACHE_MASK; + if (!len) + return; + + pgto = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; for (;;) { - copy = PAGE_CACHE_SIZE - pgbase; + copy = PAGE_SIZE - pgbase; if (copy > len) copy = len; @@ -250,7 +391,7 @@ _copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) break; pgbase += copy; - if (pgbase == PAGE_CACHE_SIZE) { + if (pgbase == PAGE_SIZE) { flush_dcache_page(*pgto); pgbase = 0; pgto++; @@ -277,11 +418,14 @@ _copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) char *vfrom; size_t copy; - pgfrom = pages + (pgbase >> PAGE_CACHE_SHIFT); - pgbase &= ~PAGE_CACHE_MASK; + if (!len) + return; + + pgfrom = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; do { - copy = PAGE_CACHE_SIZE - pgbase; + copy = PAGE_SIZE - pgbase; if (copy > len) copy = len; @@ -290,7 +434,7 @@ _copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) kunmap_atomic(vfrom); pgbase += copy; - if (pgbase == PAGE_CACHE_SIZE) { + if (pgbase == PAGE_SIZE) { pgbase = 0; pgfrom++; } @@ -300,136 +444,446 @@ _copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) } EXPORT_SYMBOL_GPL(_copy_from_pages); +static void xdr_buf_iov_zero(const struct kvec *iov, unsigned int base, + unsigned int len) +{ + if (base >= iov->iov_len) + return; + if (len > iov->iov_len - base) + len = iov->iov_len - base; + memset(iov->iov_base + base, 0, len); +} + /** - * xdr_shrink_bufhead + * xdr_buf_pages_zero * @buf: xdr_buf - * @len: bytes to remove from buf->head[0] - * - * Shrinks XDR buffer's header kvec buf->head[0] by - * 'len' bytes. The extra data is not lost, but is instead - * moved into the inlined pages and/or the tail. + * @pgbase: beginning offset + * @len: length */ -static void -xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) +static void xdr_buf_pages_zero(const struct xdr_buf *buf, unsigned int pgbase, + unsigned int len) { - struct kvec *head, *tail; - size_t copy, offs; - unsigned int pglen = buf->page_len; + struct page **pages = buf->pages; + struct page **page; + char *vpage; + unsigned int zero; + + if (!len) + return; + if (pgbase >= buf->page_len) { + xdr_buf_iov_zero(buf->tail, pgbase - buf->page_len, len); + return; + } + if (pgbase + len > buf->page_len) { + xdr_buf_iov_zero(buf->tail, 0, pgbase + len - buf->page_len); + len = buf->page_len - pgbase; + } - tail = buf->tail; - head = buf->head; + pgbase += buf->page_base; - WARN_ON_ONCE(len > head->iov_len); - if (len > head->iov_len) - len = head->iov_len; - - /* Shift the tail first */ - if (tail->iov_len != 0) { - if (tail->iov_len > len) { - copy = tail->iov_len - len; - memmove((char *)tail->iov_base + len, - tail->iov_base, copy); + page = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; + + do { + zero = PAGE_SIZE - pgbase; + if (zero > len) + zero = len; + + vpage = kmap_atomic(*page); + memset(vpage + pgbase, 0, zero); + kunmap_atomic(vpage); + + flush_dcache_page(*page); + pgbase = 0; + page++; + + } while ((len -= zero) != 0); +} + +static unsigned int xdr_buf_pages_fill_sparse(const struct xdr_buf *buf, + unsigned int buflen, gfp_t gfp) +{ + unsigned int i, npages, pagelen; + + if (!(buf->flags & XDRBUF_SPARSE_PAGES)) + return buflen; + if (buflen <= buf->head->iov_len) + return buflen; + pagelen = buflen - buf->head->iov_len; + if (pagelen > buf->page_len) + pagelen = buf->page_len; + npages = (pagelen + buf->page_base + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (!buf->pages[i]) + continue; + buf->pages[i] = alloc_page(gfp); + if (likely(buf->pages[i])) + continue; + buflen -= pagelen; + pagelen = i << PAGE_SHIFT; + if (pagelen > buf->page_base) + buflen += pagelen - buf->page_base; + break; + } + return buflen; +} + +static void xdr_buf_try_expand(struct xdr_buf *buf, unsigned int len) +{ + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + unsigned int sum = head->iov_len + buf->page_len + tail->iov_len; + unsigned int free_space, newlen; + + if (sum > buf->len) { + free_space = min_t(unsigned int, sum - buf->len, len); + newlen = xdr_buf_pages_fill_sparse(buf, buf->len + free_space, + GFP_KERNEL); + free_space = newlen - buf->len; + buf->len = newlen; + len -= free_space; + if (!len) + return; + } + + if (buf->buflen > sum) { + /* Expand the tail buffer */ + free_space = min_t(unsigned int, buf->buflen - sum, len); + tail->iov_len += free_space; + buf->len += free_space; + } +} + +static void xdr_buf_tail_copy_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *tail = buf->tail; + unsigned int to = base + shift; + + if (to >= tail->iov_len) + return; + if (len + to > tail->iov_len) + len = tail->iov_len - to; + memmove(tail->iov_base + to, tail->iov_base + base, len); +} + +static void xdr_buf_pages_copy_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *tail = buf->tail; + unsigned int to = base + shift; + unsigned int pglen = 0; + unsigned int talen = 0, tato = 0; + + if (base >= buf->page_len) + return; + if (len > buf->page_len - base) + len = buf->page_len - base; + if (to >= buf->page_len) { + tato = to - buf->page_len; + if (tail->iov_len >= len + tato) + talen = len; + else if (tail->iov_len > tato) + talen = tail->iov_len - tato; + } else if (len + to >= buf->page_len) { + pglen = buf->page_len - to; + talen = len - pglen; + if (talen > tail->iov_len) + talen = tail->iov_len; + } else + pglen = len; + + _copy_from_pages(tail->iov_base + tato, buf->pages, + buf->page_base + base + pglen, talen); + _shift_data_right_pages(buf->pages, buf->page_base + to, + buf->page_base + base, pglen); +} + +static void xdr_buf_head_copy_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *head = buf->head; + const struct kvec *tail = buf->tail; + unsigned int to = base + shift; + unsigned int pglen = 0, pgto = 0; + unsigned int talen = 0, tato = 0; + + if (base >= head->iov_len) + return; + if (len > head->iov_len - base) + len = head->iov_len - base; + if (to >= buf->page_len + head->iov_len) { + tato = to - buf->page_len - head->iov_len; + talen = len; + } else if (to >= head->iov_len) { + pgto = to - head->iov_len; + pglen = len; + if (pgto + pglen > buf->page_len) { + talen = pgto + pglen - buf->page_len; + pglen -= talen; } - /* Copy from the inlined pages into the tail */ - copy = len; - if (copy > pglen) - copy = pglen; - offs = len - copy; - if (offs >= tail->iov_len) - copy = 0; - else if (copy > tail->iov_len - offs) - copy = tail->iov_len - offs; - if (copy != 0) - _copy_from_pages((char *)tail->iov_base + offs, - buf->pages, - buf->page_base + pglen + offs - len, - copy); - /* Do we also need to copy data from the head into the tail ? */ - if (len > pglen) { - offs = copy = len - pglen; - if (copy > tail->iov_len) - copy = tail->iov_len; - memcpy(tail->iov_base, - (char *)head->iov_base + - head->iov_len - offs, - copy); + } else { + pglen = len - to; + if (pglen > buf->page_len) { + talen = pglen - buf->page_len; + pglen = buf->page_len; } } - /* Now handle pages */ - if (pglen != 0) { - if (pglen > len) - _shift_data_right_pages(buf->pages, - buf->page_base + len, - buf->page_base, - pglen - len); - copy = len; - if (len > pglen) - copy = pglen; - _copy_to_pages(buf->pages, buf->page_base, - (char *)head->iov_base + head->iov_len - len, - copy); + + len -= talen; + base += len; + if (talen + tato > tail->iov_len) + talen = tail->iov_len > tato ? tail->iov_len - tato : 0; + memcpy(tail->iov_base + tato, head->iov_base + base, talen); + + len -= pglen; + base -= pglen; + _copy_to_pages(buf->pages, buf->page_base + pgto, head->iov_base + base, + pglen); + + base -= len; + memmove(head->iov_base + to, head->iov_base + base, len); +} + +static void xdr_buf_tail_shift_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *tail = buf->tail; + + if (base >= tail->iov_len || !shift || !len) + return; + xdr_buf_tail_copy_right(buf, base, len, shift); +} + +static void xdr_buf_pages_shift_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + if (!shift || !len) + return; + if (base >= buf->page_len) { + xdr_buf_tail_shift_right(buf, base - buf->page_len, len, shift); + return; + } + if (base + len > buf->page_len) + xdr_buf_tail_shift_right(buf, 0, base + len - buf->page_len, + shift); + xdr_buf_pages_copy_right(buf, base, len, shift); +} + +static void xdr_buf_head_shift_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *head = buf->head; + + if (!shift) + return; + if (base >= head->iov_len) { + xdr_buf_pages_shift_right(buf, head->iov_len - base, len, + shift); + return; } - head->iov_len -= len; - buf->buflen -= len; - /* Have we truncated the message? */ - if (buf->len > buf->buflen) - buf->len = buf->buflen; + if (base + len > head->iov_len) + xdr_buf_pages_shift_right(buf, 0, base + len - head->iov_len, + shift); + xdr_buf_head_copy_right(buf, base, len, shift); +} + +static void xdr_buf_tail_copy_left(const struct xdr_buf *buf, unsigned int base, + unsigned int len, unsigned int shift) +{ + const struct kvec *tail = buf->tail; + + if (base >= tail->iov_len) + return; + if (len > tail->iov_len - base) + len = tail->iov_len - base; + /* Shift data into head */ + if (shift > buf->page_len + base) { + const struct kvec *head = buf->head; + unsigned int hdto = + head->iov_len + buf->page_len + base - shift; + unsigned int hdlen = len; + + if (WARN_ONCE(shift > head->iov_len + buf->page_len + base, + "SUNRPC: Misaligned data.\n")) + return; + if (hdto + hdlen > head->iov_len) + hdlen = head->iov_len - hdto; + memcpy(head->iov_base + hdto, tail->iov_base + base, hdlen); + base += hdlen; + len -= hdlen; + if (!len) + return; + } + /* Shift data into pages */ + if (shift > base) { + unsigned int pgto = buf->page_len + base - shift; + unsigned int pglen = len; + + if (pgto + pglen > buf->page_len) + pglen = buf->page_len - pgto; + _copy_to_pages(buf->pages, buf->page_base + pgto, + tail->iov_base + base, pglen); + base += pglen; + len -= pglen; + if (!len) + return; + } + memmove(tail->iov_base + base - shift, tail->iov_base + base, len); +} + +static void xdr_buf_pages_copy_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + unsigned int pgto; + + if (base >= buf->page_len) + return; + if (len > buf->page_len - base) + len = buf->page_len - base; + /* Shift data into head */ + if (shift > base) { + const struct kvec *head = buf->head; + unsigned int hdto = head->iov_len + base - shift; + unsigned int hdlen = len; + + if (WARN_ONCE(shift > head->iov_len + base, + "SUNRPC: Misaligned data.\n")) + return; + if (hdto + hdlen > head->iov_len) + hdlen = head->iov_len - hdto; + _copy_from_pages(head->iov_base + hdto, buf->pages, + buf->page_base + base, hdlen); + base += hdlen; + len -= hdlen; + if (!len) + return; + } + pgto = base - shift; + _shift_data_left_pages(buf->pages, buf->page_base + pgto, + buf->page_base + base, len); +} + +static void xdr_buf_tail_shift_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + if (!shift || !len) + return; + xdr_buf_tail_copy_left(buf, base, len, shift); +} + +static void xdr_buf_pages_shift_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + if (!shift || !len) + return; + if (base >= buf->page_len) { + xdr_buf_tail_shift_left(buf, base - buf->page_len, len, shift); + return; + } + xdr_buf_pages_copy_left(buf, base, len, shift); + len += base; + if (len <= buf->page_len) + return; + xdr_buf_tail_copy_left(buf, 0, len - buf->page_len, shift); +} + +static void xdr_buf_head_shift_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *head = buf->head; + unsigned int bytes; + + if (!shift || !len) + return; + + if (shift > base) { + bytes = (shift - base); + if (bytes >= len) + return; + base += bytes; + len -= bytes; + } + + if (base < head->iov_len) { + bytes = min_t(unsigned int, len, head->iov_len - base); + memmove(head->iov_base + (base - shift), + head->iov_base + base, bytes); + base += bytes; + len -= bytes; + } + xdr_buf_pages_shift_left(buf, base - head->iov_len, len, shift); } /** - * xdr_shrink_pagelen + * xdr_shrink_bufhead * @buf: xdr_buf - * @len: bytes to remove from buf->pages + * @len: new length of buf->head[0] * - * Shrinks XDR buffer's page array buf->pages by + * Shrinks XDR buffer's header kvec buf->head[0], setting it to * 'len' bytes. The extra data is not lost, but is instead - * moved into the tail. + * moved into the inlined pages and/or the tail. */ -static void -xdr_shrink_pagelen(struct xdr_buf *buf, size_t len) +static unsigned int xdr_shrink_bufhead(struct xdr_buf *buf, unsigned int len) { - struct kvec *tail; - size_t copy; - unsigned int pglen = buf->page_len; - unsigned int tailbuf_len; - - tail = buf->tail; - BUG_ON (len > pglen); - - tailbuf_len = buf->buflen - buf->head->iov_len - buf->page_len; + struct kvec *head = buf->head; + unsigned int shift, buflen = max(buf->len, len); - /* Shift the tail first */ - if (tailbuf_len != 0) { - unsigned int free_space = tailbuf_len - tail->iov_len; - - if (len < free_space) - free_space = len; - tail->iov_len += free_space; - - copy = len; - if (tail->iov_len > len) { - char *p = (char *)tail->iov_base + len; - memmove(p, tail->iov_base, tail->iov_len - len); - } else - copy = tail->iov_len; - /* Copy from the inlined pages into the tail */ - _copy_from_pages((char *)tail->iov_base, - buf->pages, buf->page_base + pglen - len, - copy); + WARN_ON_ONCE(len > head->iov_len); + if (head->iov_len > buflen) { + buf->buflen -= head->iov_len - buflen; + head->iov_len = buflen; } - buf->page_len -= len; - buf->buflen -= len; - /* Have we truncated the message? */ - if (buf->len > buf->buflen) - buf->len = buf->buflen; + if (len >= head->iov_len) + return 0; + shift = head->iov_len - len; + xdr_buf_try_expand(buf, shift); + xdr_buf_head_shift_right(buf, len, buflen - len, shift); + head->iov_len = len; + buf->buflen -= shift; + buf->len -= shift; + return shift; } -void -xdr_shift_buf(struct xdr_buf *buf, size_t len) +/** + * xdr_shrink_pagelen - shrinks buf->pages to @len bytes + * @buf: xdr_buf + * @len: new page buffer length + * + * The extra data is not lost, but is instead moved into buf->tail. + * Returns the actual number of bytes moved. + */ +static unsigned int xdr_shrink_pagelen(struct xdr_buf *buf, unsigned int len) { - xdr_shrink_bufhead(buf, len); + unsigned int shift, buflen = buf->len - buf->head->iov_len; + + WARN_ON_ONCE(len > buf->page_len); + if (buf->head->iov_len >= buf->len || len > buflen) + buflen = len; + if (buf->page_len > buflen) { + buf->buflen -= buf->page_len - buflen; + buf->page_len = buflen; + } + if (len >= buf->page_len) + return 0; + shift = buf->page_len - len; + xdr_buf_try_expand(buf, shift); + xdr_buf_pages_shift_right(buf, len, buflen - len, shift); + buf->page_len = len; + buf->len -= shift; + buf->buflen -= shift; + return shift; } -EXPORT_SYMBOL_GPL(xdr_shift_buf); /** * xdr_stream_pos - Return the current offset from the start of the xdr_stream @@ -441,11 +895,37 @@ unsigned int xdr_stream_pos(const struct xdr_stream *xdr) } EXPORT_SYMBOL_GPL(xdr_stream_pos); +static void xdr_stream_set_pos(struct xdr_stream *xdr, unsigned int pos) +{ + unsigned int blen = xdr->buf->len; + + xdr->nwords = blen > pos ? XDR_QUADLEN(blen) - XDR_QUADLEN(pos) : 0; +} + +static void xdr_stream_page_set_pos(struct xdr_stream *xdr, unsigned int pos) +{ + xdr_stream_set_pos(xdr, pos + xdr->buf->head[0].iov_len); +} + +/** + * xdr_page_pos - Return the current offset from the start of the xdr pages + * @xdr: pointer to struct xdr_stream + */ +unsigned int xdr_page_pos(const struct xdr_stream *xdr) +{ + unsigned int pos = xdr_stream_pos(xdr); + + WARN_ON(pos < xdr->buf->head[0].iov_len); + return pos - xdr->buf->head[0].iov_len; +} +EXPORT_SYMBOL_GPL(xdr_page_pos); + /** * xdr_init_encode - Initialize a struct xdr_stream for sending data. * @xdr: pointer to xdr_stream struct * @buf: pointer to XDR buffer in which to encode data * @p: current pointer inside XDR buffer + * @rqst: pointer to controlling rpc_rqst, for debugging * * Note: at the moment the RPC client only passes the length of our * scratch buffer in the xdr_buf's header kvec. Previously this @@ -454,11 +934,13 @@ EXPORT_SYMBOL_GPL(xdr_stream_pos); * of the buffer length, and takes care of adjusting the kvec * length for us. */ -void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p, + struct rpc_rqst *rqst) { struct kvec *iov = buf->head; int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len; + xdr_reset_scratch_buffer(xdr); BUG_ON(scratch_len < 0); xdr->buf = buf; xdr->iov = iov; @@ -475,10 +957,109 @@ void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) buf->len += len; iov->iov_len += len; } + xdr->rqst = rqst; } EXPORT_SYMBOL_GPL(xdr_init_encode); /** + * xdr_init_encode_pages - Initialize an xdr_stream for encoding into pages + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer into which to encode data + * + */ +void xdr_init_encode_pages(struct xdr_stream *xdr, struct xdr_buf *buf) +{ + xdr_reset_scratch_buffer(xdr); + + xdr->buf = buf; + xdr->page_ptr = buf->pages; + xdr->iov = NULL; + xdr->p = page_address(*xdr->page_ptr); + xdr->end = (void *)xdr->p + min_t(u32, buf->buflen, PAGE_SIZE); + xdr->rqst = NULL; +} +EXPORT_SYMBOL_GPL(xdr_init_encode_pages); + +/** + * __xdr_commit_encode - Ensure all data is written to buffer + * @xdr: pointer to xdr_stream + * + * We handle encoding across page boundaries by giving the caller a + * temporary location to write to, then later copying the data into + * place; xdr_commit_encode does that copying. + * + * Normally the caller doesn't need to call this directly, as the + * following xdr_reserve_space will do it. But an explicit call may be + * required at the end of encoding, or any other time when the xdr_buf + * data might be read. + */ +void __xdr_commit_encode(struct xdr_stream *xdr) +{ + size_t shift = xdr->scratch.iov_len; + void *page; + + page = page_address(*xdr->page_ptr); + memcpy(xdr->scratch.iov_base, page, shift); + memmove(page, page + shift, (void *)xdr->p - page); + xdr_reset_scratch_buffer(xdr); +} +EXPORT_SYMBOL_GPL(__xdr_commit_encode); + +/* + * The buffer space to be reserved crosses the boundary between + * xdr->buf->head and xdr->buf->pages, or between two pages + * in xdr->buf->pages. + */ +static noinline __be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr, + size_t nbytes) +{ + int space_left; + int frag1bytes, frag2bytes; + void *p; + + if (nbytes > PAGE_SIZE) + goto out_overflow; /* Bigger buffers require special handling */ + if (xdr->buf->len + nbytes > xdr->buf->buflen) + goto out_overflow; /* Sorry, we're totally out of space */ + frag1bytes = (xdr->end - xdr->p) << 2; + frag2bytes = nbytes - frag1bytes; + if (xdr->iov) + xdr->iov->iov_len += frag1bytes; + else + xdr->buf->page_len += frag1bytes; + xdr->page_ptr++; + xdr->iov = NULL; + + /* + * If the last encode didn't end exactly on a page boundary, the + * next one will straddle boundaries. Encode into the next + * page, then copy it back later in xdr_commit_encode. We use + * the "scratch" iov to track any temporarily unused fragment of + * space at the end of the previous buffer: + */ + xdr_set_scratch_buffer(xdr, xdr->p, frag1bytes); + + /* + * xdr->p is where the next encode will start after + * xdr_commit_encode() has shifted this one back: + */ + p = page_address(*xdr->page_ptr); + xdr->p = p + frag2bytes; + space_left = xdr->buf->buflen - xdr->buf->len; + if (space_left - frag1bytes >= PAGE_SIZE) + xdr->end = p + PAGE_SIZE; + else + xdr->end = p + space_left - frag1bytes; + + xdr->buf->page_len += frag2bytes; + xdr->buf->len += nbytes; + return p; +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; +} + +/** * xdr_reserve_space - Reserve buffer space for sending * @xdr: pointer to xdr_stream * @nbytes: number of bytes to reserve @@ -486,52 +1067,222 @@ EXPORT_SYMBOL_GPL(xdr_init_encode); * Checks that we have enough buffer space to encode 'nbytes' more * bytes of data. If so, update the total xdr_buf length, and * adjust the length of the current kvec. + * + * The returned pointer is valid only until the next call to + * xdr_reserve_space() or xdr_commit_encode() on @xdr. The current + * implementation of this API guarantees that space reserved for a + * four-byte data item remains valid until @xdr is destroyed, but + * that might not always be true in the future. */ __be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes) { __be32 *p = xdr->p; __be32 *q; + xdr_commit_encode(xdr); /* align nbytes on the next 32-bit boundary */ nbytes += 3; nbytes &= ~3; q = p + (nbytes >> 2); if (unlikely(q > xdr->end || q < p)) - return NULL; + return xdr_get_next_encode_buffer(xdr, nbytes); xdr->p = q; - xdr->iov->iov_len += nbytes; + if (xdr->iov) + xdr->iov->iov_len += nbytes; + else + xdr->buf->page_len += nbytes; xdr->buf->len += nbytes; return p; } EXPORT_SYMBOL_GPL(xdr_reserve_space); /** + * xdr_reserve_space_vec - Reserves a large amount of buffer space for sending + * @xdr: pointer to xdr_stream + * @nbytes: number of bytes to reserve + * + * The size argument passed to xdr_reserve_space() is determined based + * on the number of bytes remaining in the current page to avoid + * invalidating iov_base pointers when xdr_commit_encode() is called. + * + * Return values: + * %0: success + * %-EMSGSIZE: not enough space is available in @xdr + */ +int xdr_reserve_space_vec(struct xdr_stream *xdr, size_t nbytes) +{ + size_t thislen; + __be32 *p; + + /* + * svcrdma requires every READ payload to start somewhere + * in xdr->pages. + */ + if (xdr->iov == xdr->buf->head) { + xdr->iov = NULL; + xdr->end = xdr->p; + } + + /* XXX: Let's find a way to make this more efficient */ + while (nbytes) { + thislen = xdr->buf->page_len % PAGE_SIZE; + thislen = min_t(size_t, nbytes, PAGE_SIZE - thislen); + + p = xdr_reserve_space(xdr, thislen); + if (!p) + return -EMSGSIZE; + + nbytes -= thislen; + } + + return 0; +} +EXPORT_SYMBOL_GPL(xdr_reserve_space_vec); + +/** + * xdr_truncate_encode - truncate an encode buffer + * @xdr: pointer to xdr_stream + * @len: new length of buffer + * + * Truncates the xdr stream, so that xdr->buf->len == len, + * and xdr->p points at offset len from the start of the buffer, and + * head, tail, and page lengths are adjusted to correspond. + * + * If this means moving xdr->p to a different buffer, we assume that + * the end pointer should be set to the end of the current page, + * except in the case of the head buffer when we assume the head + * buffer's current length represents the end of the available buffer. + * + * This is *not* safe to use on a buffer that already has inlined page + * cache pages (as in a zero-copy server read reply), except for the + * simple case of truncating from one position in the tail to another. + * + */ +void xdr_truncate_encode(struct xdr_stream *xdr, size_t len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + int fraglen; + int new; + + if (len > buf->len) { + WARN_ON_ONCE(1); + return; + } + xdr_commit_encode(xdr); + + fraglen = min_t(int, buf->len - len, tail->iov_len); + tail->iov_len -= fraglen; + buf->len -= fraglen; + if (tail->iov_len) { + xdr->p = tail->iov_base + tail->iov_len; + WARN_ON_ONCE(!xdr->end); + WARN_ON_ONCE(!xdr->iov); + return; + } + WARN_ON_ONCE(fraglen); + fraglen = min_t(int, buf->len - len, buf->page_len); + buf->page_len -= fraglen; + buf->len -= fraglen; + + new = buf->page_base + buf->page_len; + + xdr->page_ptr = buf->pages + (new >> PAGE_SHIFT); + + if (buf->page_len) { + xdr->p = page_address(*xdr->page_ptr); + xdr->end = (void *)xdr->p + PAGE_SIZE; + xdr->p = (void *)xdr->p + (new % PAGE_SIZE); + WARN_ON_ONCE(xdr->iov); + return; + } + if (fraglen) + xdr->end = head->iov_base + head->iov_len; + /* (otherwise assume xdr->end is already set) */ + xdr->page_ptr--; + head->iov_len = len; + buf->len = len; + xdr->p = head->iov_base + head->iov_len; + xdr->iov = buf->head; +} +EXPORT_SYMBOL(xdr_truncate_encode); + +/** + * xdr_truncate_decode - Truncate a decoding stream + * @xdr: pointer to struct xdr_stream + * @len: Number of bytes to remove + * + */ +void xdr_truncate_decode(struct xdr_stream *xdr, size_t len) +{ + unsigned int nbytes = xdr_align_size(len); + + xdr->buf->len -= nbytes; + xdr->nwords -= XDR_QUADLEN(nbytes); +} +EXPORT_SYMBOL_GPL(xdr_truncate_decode); + +/** + * xdr_restrict_buflen - decrease available buffer space + * @xdr: pointer to xdr_stream + * @newbuflen: new maximum number of bytes available + * + * Adjust our idea of how much space is available in the buffer. + * If we've already used too much space in the buffer, returns -1. + * If the available space is already smaller than newbuflen, returns 0 + * and does nothing. Otherwise, adjusts xdr->buf->buflen to newbuflen + * and ensures xdr->end is set at most offset newbuflen from the start + * of the buffer. + */ +int xdr_restrict_buflen(struct xdr_stream *xdr, int newbuflen) +{ + struct xdr_buf *buf = xdr->buf; + int left_in_this_buf = (void *)xdr->end - (void *)xdr->p; + int end_offset = buf->len + left_in_this_buf; + + if (newbuflen < 0 || newbuflen < buf->len) + return -1; + if (newbuflen > buf->buflen) + return 0; + if (newbuflen < end_offset) + xdr->end = (void *)xdr->end + newbuflen - end_offset; + buf->buflen = newbuflen; + return 0; +} +EXPORT_SYMBOL(xdr_restrict_buflen); + +/** * xdr_write_pages - Insert a list of pages into an XDR buffer for sending * @xdr: pointer to xdr_stream - * @pages: list of pages - * @base: offset of first byte - * @len: length of data in bytes + * @pages: array of pages to insert + * @base: starting offset of first data byte in @pages + * @len: number of data bytes in @pages to insert * + * After the @pages are added, the tail iovec is instantiated pointing to + * end of the head buffer, and the stream is set up to encode subsequent + * items into the tail. */ void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int base, unsigned int len) { struct xdr_buf *buf = xdr->buf; - struct kvec *iov = buf->tail; + struct kvec *tail = buf->tail; + buf->pages = pages; buf->page_base = base; buf->page_len = len; - iov->iov_base = (char *)xdr->p; - iov->iov_len = 0; - xdr->iov = iov; + tail->iov_base = xdr->p; + tail->iov_len = 0; + xdr->iov = tail; if (len & 3) { unsigned int pad = 4 - (len & 3); BUG_ON(xdr->p >= xdr->end); - iov->iov_base = (char *)xdr->p + (len & 3); - iov->iov_len += pad; + tail->iov_base = (char *)xdr->p + (len & 3); + tail->iov_len += pad; len += pad; *xdr->p++ = 0; } @@ -540,19 +1291,39 @@ void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int b } EXPORT_SYMBOL_GPL(xdr_write_pages); -static void xdr_set_iov(struct xdr_stream *xdr, struct kvec *iov, - unsigned int len) +static unsigned int xdr_set_iov(struct xdr_stream *xdr, struct kvec *iov, + unsigned int base, unsigned int len) { if (len > iov->iov_len) len = iov->iov_len; - xdr->p = (__be32*)iov->iov_base; + if (unlikely(base > len)) + base = len; + xdr->p = (__be32*)(iov->iov_base + base); xdr->end = (__be32*)(iov->iov_base + len); xdr->iov = iov; xdr->page_ptr = NULL; + return len - base; +} + +static unsigned int xdr_set_tail_base(struct xdr_stream *xdr, + unsigned int base, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + + xdr_stream_set_pos(xdr, base + buf->page_len + buf->head->iov_len); + return xdr_set_iov(xdr, buf->tail, base, len); } -static int xdr_set_page_base(struct xdr_stream *xdr, - unsigned int base, unsigned int len) +static void xdr_stream_unmap_current_page(struct xdr_stream *xdr) +{ + if (xdr->page_kaddr) { + kunmap_local(xdr->page_kaddr); + xdr->page_kaddr = NULL; + } +} + +static unsigned int xdr_set_page_base(struct xdr_stream *xdr, + unsigned int base, unsigned int len) { unsigned int pgnr; unsigned int maxlen; @@ -562,16 +1333,24 @@ static int xdr_set_page_base(struct xdr_stream *xdr, maxlen = xdr->buf->page_len; if (base >= maxlen) - return -EINVAL; - maxlen -= base; + return 0; + else + maxlen -= base; if (len > maxlen) len = maxlen; + xdr_stream_unmap_current_page(xdr); + xdr_stream_page_set_pos(xdr, base); base += xdr->buf->page_base; pgnr = base >> PAGE_SHIFT; xdr->page_ptr = &xdr->buf->pages[pgnr]; - kaddr = page_address(*xdr->page_ptr); + + if (PageHighMem(*xdr->page_ptr)) { + xdr->page_kaddr = kmap_local_page(*xdr->page_ptr); + kaddr = xdr->page_kaddr; + } else + kaddr = page_address(*xdr->page_ptr); pgoff = base & ~PAGE_MASK; xdr->p = (__be32*)(kaddr + pgoff); @@ -581,7 +1360,16 @@ static int xdr_set_page_base(struct xdr_stream *xdr, pgend = PAGE_SIZE; xdr->end = (__be32*)(kaddr + pgend); xdr->iov = NULL; - return 0; + return len; +} + +static void xdr_set_page(struct xdr_stream *xdr, unsigned int base, + unsigned int len) +{ + if (xdr_set_page_base(xdr, base, len) == 0) { + base -= xdr->buf->page_len; + xdr_set_tail_base(xdr, base, len); + } } static void xdr_set_next_page(struct xdr_stream *xdr) @@ -590,19 +1378,18 @@ static void xdr_set_next_page(struct xdr_stream *xdr) newbase = (1 + xdr->page_ptr - xdr->buf->pages) << PAGE_SHIFT; newbase -= xdr->buf->page_base; - - if (xdr_set_page_base(xdr, newbase, PAGE_SIZE) < 0) - xdr_set_iov(xdr, xdr->buf->tail, xdr->buf->len); + if (newbase < xdr->buf->page_len) + xdr_set_page_base(xdr, newbase, xdr_stream_remaining(xdr)); + else + xdr_set_tail_base(xdr, 0, xdr_stream_remaining(xdr)); } static bool xdr_set_next_buffer(struct xdr_stream *xdr) { if (xdr->page_ptr != NULL) xdr_set_next_page(xdr); - else if (xdr->iov == xdr->buf->head) { - if (xdr_set_page_base(xdr, 0, PAGE_SIZE) < 0) - xdr_set_iov(xdr, xdr->buf->tail, xdr->buf->len); - } + else if (xdr->iov == xdr->buf->head) + xdr_set_page(xdr, 0, xdr_stream_remaining(xdr)); return xdr->p != xdr->end; } @@ -611,26 +1398,28 @@ static bool xdr_set_next_buffer(struct xdr_stream *xdr) * @xdr: pointer to xdr_stream struct * @buf: pointer to XDR buffer from which to decode data * @p: current pointer inside XDR buffer + * @rqst: pointer to controlling rpc_rqst, for debugging */ -void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p, + struct rpc_rqst *rqst) { xdr->buf = buf; - xdr->scratch.iov_base = NULL; - xdr->scratch.iov_len = 0; + xdr->page_kaddr = NULL; + xdr_reset_scratch_buffer(xdr); xdr->nwords = XDR_QUADLEN(buf->len); - if (buf->head[0].iov_len != 0) - xdr_set_iov(xdr, buf->head, buf->len); - else if (buf->page_len != 0) - xdr_set_page_base(xdr, 0, buf->len); + if (xdr_set_iov(xdr, buf->head, 0, buf->len) == 0 && + xdr_set_page_base(xdr, 0, buf->len) == 0) + xdr_set_iov(xdr, buf->tail, 0, buf->len); if (p != NULL && p > xdr->p && xdr->end >= p) { xdr->nwords -= p - xdr->p; xdr->p = p; } + xdr->rqst = rqst; } EXPORT_SYMBOL_GPL(xdr_init_decode); /** - * xdr_init_decode - Initialize an xdr_stream for decoding data. + * xdr_init_decode_pages - Initialize an xdr_stream for decoding into pages * @xdr: pointer to xdr_stream struct * @buf: pointer to XDR buffer from which to decode data * @pages: list of pages to decode into @@ -644,10 +1433,20 @@ void xdr_init_decode_pages(struct xdr_stream *xdr, struct xdr_buf *buf, buf->page_len = len; buf->buflen = len; buf->len = len; - xdr_init_decode(xdr, buf, NULL); + xdr_init_decode(xdr, buf, NULL, NULL); } EXPORT_SYMBOL_GPL(xdr_init_decode_pages); +/** + * xdr_finish_decode - Clean up the xdr_stream after decoding data. + * @xdr: pointer to xdr_stream struct + */ +void xdr_finish_decode(struct xdr_stream *xdr) +{ + xdr_stream_unmap_current_page(xdr); +} +EXPORT_SYMBOL(xdr_finish_decode); + static __be32 * __xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) { unsigned int nwords = XDR_QUADLEN(nbytes); @@ -661,42 +1460,30 @@ static __be32 * __xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) return p; } -/** - * xdr_set_scratch_buffer - Attach a scratch buffer for decoding data. - * @xdr: pointer to xdr_stream struct - * @buf: pointer to an empty buffer - * @buflen: size of 'buf' - * - * The scratch buffer is used when decoding from an array of pages. - * If an xdr_inline_decode() call spans across page boundaries, then - * we copy the data into the scratch buffer in order to allow linear - * access. - */ -void xdr_set_scratch_buffer(struct xdr_stream *xdr, void *buf, size_t buflen) -{ - xdr->scratch.iov_base = buf; - xdr->scratch.iov_len = buflen; -} -EXPORT_SYMBOL_GPL(xdr_set_scratch_buffer); - static __be32 *xdr_copy_to_scratch(struct xdr_stream *xdr, size_t nbytes) { __be32 *p; - void *cpdest = xdr->scratch.iov_base; + char *cpdest = xdr->scratch.iov_base; size_t cplen = (char *)xdr->end - (char *)xdr->p; if (nbytes > xdr->scratch.iov_len) + goto out_overflow; + p = __xdr_inline_decode(xdr, cplen); + if (p == NULL) return NULL; - memcpy(cpdest, xdr->p, cplen); + memcpy(cpdest, p, cplen); + if (!xdr_set_next_buffer(xdr)) + goto out_overflow; cpdest += cplen; nbytes -= cplen; - if (!xdr_set_next_buffer(xdr)) - return NULL; p = __xdr_inline_decode(xdr, nbytes); if (p == NULL) return NULL; memcpy(cpdest, p, nbytes); return xdr->scratch.iov_base; +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; } /** @@ -713,33 +1500,45 @@ __be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) { __be32 *p; - if (nbytes == 0) + if (unlikely(nbytes == 0)) return xdr->p; if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) - return NULL; + goto out_overflow; p = __xdr_inline_decode(xdr, nbytes); if (p != NULL) return p; return xdr_copy_to_scratch(xdr, nbytes); +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; } EXPORT_SYMBOL_GPL(xdr_inline_decode); -static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) +static void xdr_realign_pages(struct xdr_stream *xdr) { struct xdr_buf *buf = xdr->buf; - struct kvec *iov; - unsigned int nwords = XDR_QUADLEN(len); + struct kvec *iov = buf->head; unsigned int cur = xdr_stream_pos(xdr); + unsigned int copied; - if (xdr->nwords == 0) - return 0; /* Realign pages to current pointer position */ - iov = buf->head; if (iov->iov_len > cur) { - xdr_shrink_bufhead(buf, iov->iov_len - cur); - xdr->nwords = XDR_QUADLEN(buf->len - cur); + copied = xdr_shrink_bufhead(buf, cur); + trace_rpc_xdr_alignment(xdr, cur, copied); + xdr_set_page(xdr, 0, buf->page_len); } +} + +static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + unsigned int nwords = XDR_QUADLEN(len); + unsigned int copied; + + if (xdr->nwords == 0) + return 0; + xdr_realign_pages(xdr); if (nwords > xdr->nwords) { nwords = xdr->nwords; len = nwords << 2; @@ -748,55 +1547,72 @@ static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) len = buf->page_len; else if (nwords < xdr->nwords) { /* Truncate page data and move it into the tail */ - xdr_shrink_pagelen(buf, buf->page_len - len); - xdr->nwords = XDR_QUADLEN(buf->len - cur); + copied = xdr_shrink_pagelen(buf, len); + trace_rpc_xdr_alignment(xdr, len, copied); } return len; } /** - * xdr_read_pages - Ensure page-based XDR data to decode is aligned at current pointer position + * xdr_read_pages - align page-based XDR data to current pointer position * @xdr: pointer to xdr_stream struct * @len: number of bytes of page data * * Moves data beyond the current pointer position from the XDR head[] buffer - * into the page list. Any data that lies beyond current position + "len" - * bytes is moved into the XDR tail[]. + * into the page list. Any data that lies beyond current position + @len + * bytes is moved into the XDR tail[]. The xdr_stream current position is + * then advanced past that data to align to the next XDR object in the tail. * * Returns the number of XDR encoded bytes now contained in the pages */ unsigned int xdr_read_pages(struct xdr_stream *xdr, unsigned int len) { - struct xdr_buf *buf = xdr->buf; - struct kvec *iov; - unsigned int nwords; - unsigned int end; - unsigned int padding; + unsigned int nwords = XDR_QUADLEN(len); + unsigned int base, end, pglen; - len = xdr_align_pages(xdr, len); - if (len == 0) + pglen = xdr_align_pages(xdr, nwords << 2); + if (pglen == 0) return 0; - nwords = XDR_QUADLEN(len); - padding = (nwords << 2) - len; - xdr->iov = iov = buf->tail; - /* Compute remaining message length. */ - end = ((xdr->nwords - nwords) << 2) + padding; - if (end > iov->iov_len) - end = iov->iov_len; - /* - * Position current pointer at beginning of tail, and - * set remaining message length. - */ - xdr->p = (__be32 *)((char *)iov->iov_base + padding); - xdr->end = (__be32 *)((char *)iov->iov_base + end); - xdr->page_ptr = NULL; - xdr->nwords = XDR_QUADLEN(end - padding); - return len; + base = (nwords << 2) - pglen; + end = xdr_stream_remaining(xdr) - pglen; + + xdr_set_tail_base(xdr, base, end); + return len <= pglen ? len : pglen; } EXPORT_SYMBOL_GPL(xdr_read_pages); /** + * xdr_set_pagelen - Sets the length of the XDR pages + * @xdr: pointer to xdr_stream struct + * @len: new length of the XDR page data + * + * Either grows or shrinks the length of the xdr pages by setting pagelen to + * @len bytes. When shrinking, any extra data is moved into buf->tail, whereas + * when growing any data beyond the current pointer is moved into the tail. + * + * Returns True if the operation was successful, and False otherwise. + */ +void xdr_set_pagelen(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + size_t remaining = xdr_stream_remaining(xdr); + size_t base = 0; + + if (len < buf->page_len) { + base = buf->page_len - len; + xdr_shrink_pagelen(buf, len); + } else { + xdr_buf_head_shift_right(buf, xdr_stream_pos(xdr), + buf->page_len, remaining); + if (len > buf->page_len) + xdr_buf_try_expand(buf, len - buf->page_len); + } + xdr_set_tail_base(xdr, base, remaining); +} +EXPORT_SYMBOL_GPL(xdr_set_pagelen); + +/** * xdr_enter_page - decode data from the XDR page * @xdr: pointer to xdr_stream struct * @len: number of bytes of page data @@ -818,10 +1634,9 @@ void xdr_enter_page(struct xdr_stream *xdr, unsigned int len) } EXPORT_SYMBOL_GPL(xdr_enter_page); -static struct kvec empty_iov = {.iov_base = NULL, .iov_len = 0}; +static const struct kvec empty_iov = {.iov_base = NULL, .iov_len = 0}; -void -xdr_buf_from_iov(struct kvec *iov, struct xdr_buf *buf) +void xdr_buf_from_iov(const struct kvec *iov, struct xdr_buf *buf) { buf->head[0] = *iov; buf->tail[0] = empty_iov; @@ -830,11 +1645,22 @@ xdr_buf_from_iov(struct kvec *iov, struct xdr_buf *buf) } EXPORT_SYMBOL_GPL(xdr_buf_from_iov); -/* Sets subbuf to the portion of buf of length len beginning base bytes - * from the start of buf. Returns -1 if base of length are out of bounds. */ -int -xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, - unsigned int base, unsigned int len) +/** + * xdr_buf_subsegment - set subbuf to a portion of buf + * @buf: an xdr buffer + * @subbuf: the result buffer + * @base: beginning of range in bytes + * @len: length of range in bytes + * + * sets @subbuf to an xdr buffer representing the portion of @buf of + * length @len starting at offset @base. + * + * @buf and @subbuf may be pointers to the same struct xdr_buf. + * + * Returns -1 if base or length are out of bounds. + */ +int xdr_buf_subsegment(const struct xdr_buf *buf, struct xdr_buf *subbuf, + unsigned int base, unsigned int len) { subbuf->buflen = subbuf->len = len; if (base < buf->head[0].iov_len) { @@ -844,20 +1670,22 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, len -= subbuf->head[0].iov_len; base = 0; } else { - subbuf->head[0].iov_base = NULL; - subbuf->head[0].iov_len = 0; base -= buf->head[0].iov_len; + subbuf->head[0].iov_base = buf->head[0].iov_base; + subbuf->head[0].iov_len = 0; } if (base < buf->page_len) { subbuf->page_len = min(buf->page_len - base, len); base += buf->page_base; - subbuf->page_base = base & ~PAGE_CACHE_MASK; - subbuf->pages = &buf->pages[base >> PAGE_CACHE_SHIFT]; + subbuf->page_base = base & ~PAGE_MASK; + subbuf->pages = &buf->pages[base >> PAGE_SHIFT]; len -= subbuf->page_len; base = 0; } else { base -= buf->page_len; + subbuf->pages = buf->pages; + subbuf->page_base = 0; subbuf->page_len = 0; } @@ -868,9 +1696,9 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, len -= subbuf->tail[0].iov_len; base = 0; } else { - subbuf->tail[0].iov_base = NULL; - subbuf->tail[0].iov_len = 0; base -= buf->tail[0].iov_len; + subbuf->tail[0].iov_base = buf->tail[0].iov_base; + subbuf->tail[0].iov_len = 0; } if (base || len) @@ -880,6 +1708,107 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, EXPORT_SYMBOL_GPL(xdr_buf_subsegment); /** + * xdr_stream_subsegment - set @subbuf to a portion of @xdr + * @xdr: an xdr_stream set up for decoding + * @subbuf: the result buffer + * @nbytes: length of @xdr to extract, in bytes + * + * Sets up @subbuf to represent a portion of @xdr. The portion + * starts at the current offset in @xdr, and extends for a length + * of @nbytes. If this is successful, @xdr is advanced to the next + * XDR data item following that portion. + * + * Return values: + * %true: @subbuf has been initialized, and @xdr has been advanced. + * %false: a bounds error has occurred + */ +bool xdr_stream_subsegment(struct xdr_stream *xdr, struct xdr_buf *subbuf, + unsigned int nbytes) +{ + unsigned int start = xdr_stream_pos(xdr); + unsigned int remaining, len; + + /* Extract @subbuf and bounds-check the fn arguments */ + if (xdr_buf_subsegment(xdr->buf, subbuf, start, nbytes)) + return false; + + /* Advance @xdr by @nbytes */ + for (remaining = nbytes; remaining;) { + if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) + return false; + + len = (char *)xdr->end - (char *)xdr->p; + if (remaining <= len) { + xdr->p = (__be32 *)((char *)xdr->p + + (remaining + xdr_pad_size(nbytes))); + break; + } + + xdr->p = (__be32 *)((char *)xdr->p + len); + xdr->end = xdr->p; + remaining -= len; + } + + xdr_stream_set_pos(xdr, start + nbytes); + return true; +} +EXPORT_SYMBOL_GPL(xdr_stream_subsegment); + +/** + * xdr_stream_move_subsegment - Move part of a stream to another position + * @xdr: the source xdr_stream + * @offset: the source offset of the segment + * @target: the target offset of the segment + * @length: the number of bytes to move + * + * Moves @length bytes from @offset to @target in the xdr_stream, overwriting + * anything in its space. Returns the number of bytes in the segment. + */ +unsigned int xdr_stream_move_subsegment(struct xdr_stream *xdr, unsigned int offset, + unsigned int target, unsigned int length) +{ + struct xdr_buf buf; + unsigned int shift; + + if (offset < target) { + shift = target - offset; + if (xdr_buf_subsegment(xdr->buf, &buf, offset, shift + length) < 0) + return 0; + xdr_buf_head_shift_right(&buf, 0, length, shift); + } else if (offset > target) { + shift = offset - target; + if (xdr_buf_subsegment(xdr->buf, &buf, target, shift + length) < 0) + return 0; + xdr_buf_head_shift_left(&buf, shift, length, shift); + } + return length; +} +EXPORT_SYMBOL_GPL(xdr_stream_move_subsegment); + +/** + * xdr_stream_zero - zero out a portion of an xdr_stream + * @xdr: an xdr_stream to zero out + * @offset: the starting point in the stream + * @length: the number of bytes to zero + */ +unsigned int xdr_stream_zero(struct xdr_stream *xdr, unsigned int offset, + unsigned int length) +{ + struct xdr_buf buf; + + if (xdr_buf_subsegment(xdr->buf, &buf, offset, length) < 0) + return 0; + if (buf.head[0].iov_len) + xdr_buf_iov_zero(buf.head, 0, buf.head[0].iov_len); + if (buf.page_len > 0) + xdr_buf_pages_zero(&buf, 0, buf.page_len); + if (buf.tail[0].iov_len) + xdr_buf_iov_zero(buf.tail, 0, buf.tail[0].iov_len); + return length; +} +EXPORT_SYMBOL_GPL(xdr_stream_zero); + +/** * xdr_buf_trim - lop at most "len" bytes off the end of "buf" * @buf: buf to be trimmed * @len: number of bytes to reduce "buf" by @@ -920,7 +1849,8 @@ fix_len: } EXPORT_SYMBOL_GPL(xdr_buf_trim); -static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) +static void __read_bytes_from_xdr_buf(const struct xdr_buf *subbuf, + void *obj, unsigned int len) { unsigned int this_len; @@ -929,8 +1859,7 @@ static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigne len -= this_len; obj += this_len; this_len = min_t(unsigned int, len, subbuf->page_len); - if (this_len) - _copy_from_pages(obj, subbuf->pages, subbuf->page_base, this_len); + _copy_from_pages(obj, subbuf->pages, subbuf->page_base, this_len); len -= this_len; obj += this_len; this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); @@ -938,7 +1867,8 @@ static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigne } /* obj is assumed to point to allocated memory of size at least len: */ -int read_bytes_from_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, unsigned int len) +int read_bytes_from_xdr_buf(const struct xdr_buf *buf, unsigned int base, + void *obj, unsigned int len) { struct xdr_buf subbuf; int status; @@ -951,7 +1881,8 @@ int read_bytes_from_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, u } EXPORT_SYMBOL_GPL(read_bytes_from_xdr_buf); -static void __write_bytes_to_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) +static void __write_bytes_to_xdr_buf(const struct xdr_buf *subbuf, + void *obj, unsigned int len) { unsigned int this_len; @@ -960,8 +1891,7 @@ static void __write_bytes_to_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned len -= this_len; obj += this_len; this_len = min_t(unsigned int, len, subbuf->page_len); - if (this_len) - _copy_to_pages(subbuf->pages, subbuf->page_base, obj, this_len); + _copy_to_pages(subbuf->pages, subbuf->page_base, obj, this_len); len -= this_len; obj += this_len; this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); @@ -969,7 +1899,8 @@ static void __write_bytes_to_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned } /* obj is assumed to point to allocated memory of size at least len: */ -int write_bytes_to_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, unsigned int len) +int write_bytes_to_xdr_buf(const struct xdr_buf *buf, unsigned int base, + void *obj, unsigned int len) { struct xdr_buf subbuf; int status; @@ -982,8 +1913,7 @@ int write_bytes_to_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, un } EXPORT_SYMBOL_GPL(write_bytes_to_xdr_buf); -int -xdr_decode_word(struct xdr_buf *buf, unsigned int base, u32 *obj) +int xdr_decode_word(const struct xdr_buf *buf, unsigned int base, u32 *obj) { __be32 raw; int status; @@ -996,8 +1926,7 @@ xdr_decode_word(struct xdr_buf *buf, unsigned int base, u32 *obj) } EXPORT_SYMBOL_GPL(xdr_decode_word); -int -xdr_encode_word(struct xdr_buf *buf, unsigned int base, u32 obj) +int xdr_encode_word(const struct xdr_buf *buf, unsigned int base, u32 obj) { __be32 raw = cpu_to_be32(obj); @@ -1005,48 +1934,9 @@ xdr_encode_word(struct xdr_buf *buf, unsigned int base, u32 obj) } EXPORT_SYMBOL_GPL(xdr_encode_word); -/* If the netobj starting offset bytes from the start of xdr_buf is contained - * entirely in the head or the tail, set object to point to it; otherwise - * try to find space for it at the end of the tail, copy it there, and - * set obj to point to it. */ -int xdr_buf_read_netobj(struct xdr_buf *buf, struct xdr_netobj *obj, unsigned int offset) -{ - struct xdr_buf subbuf; - - if (xdr_decode_word(buf, offset, &obj->len)) - return -EFAULT; - if (xdr_buf_subsegment(buf, &subbuf, offset + 4, obj->len)) - return -EFAULT; - - /* Is the obj contained entirely in the head? */ - obj->data = subbuf.head[0].iov_base; - if (subbuf.head[0].iov_len == obj->len) - return 0; - /* ..or is the obj contained entirely in the tail? */ - obj->data = subbuf.tail[0].iov_base; - if (subbuf.tail[0].iov_len == obj->len) - return 0; - - /* use end of tail as storage for obj: - * (We don't copy to the beginning because then we'd have - * to worry about doing a potentially overlapping copy. - * This assumes the object is at most half the length of the - * tail.) */ - if (obj->len > buf->buflen - buf->len) - return -ENOMEM; - if (buf->tail[0].iov_len != 0) - obj->data = buf->tail[0].iov_base + buf->tail[0].iov_len; - else - obj->data = buf->head[0].iov_base + buf->head[0].iov_len; - __read_bytes_from_xdr_buf(&subbuf, obj->data, obj->len); - return 0; -} -EXPORT_SYMBOL_GPL(xdr_buf_read_netobj); - /* Returns 0 on success, or else a negative error code. */ -static int -xdr_xcode_array2(struct xdr_buf *buf, unsigned int base, - struct xdr_array2_desc *desc, int encode) +static int xdr_xcode_array2(const struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc, int encode) { char *elem = NULL, *c; unsigned int copied = 0, todo, avail_here; @@ -1112,9 +2002,9 @@ xdr_xcode_array2(struct xdr_buf *buf, unsigned int base, todo -= avail_here; base += buf->page_base; - ppages = buf->pages + (base >> PAGE_CACHE_SHIFT); - base &= ~PAGE_CACHE_MASK; - avail_page = min_t(unsigned int, PAGE_CACHE_SIZE - base, + ppages = buf->pages + (base >> PAGE_SHIFT); + base &= ~PAGE_MASK; + avail_page = min_t(unsigned int, PAGE_SIZE - base, avail_here); c = kmap(*ppages) + base; @@ -1198,7 +2088,7 @@ xdr_xcode_array2(struct xdr_buf *buf, unsigned int base, } avail_page = min(avail_here, - (unsigned int) PAGE_CACHE_SIZE); + (unsigned int) PAGE_SIZE); } base = buf->page_len; /* align to start of tail */ } @@ -1238,9 +2128,8 @@ out: return err; } -int -xdr_decode_array2(struct xdr_buf *buf, unsigned int base, - struct xdr_array2_desc *desc) +int xdr_decode_array2(const struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) { if (base >= buf->len) return -EINVAL; @@ -1249,9 +2138,8 @@ xdr_decode_array2(struct xdr_buf *buf, unsigned int base, } EXPORT_SYMBOL_GPL(xdr_decode_array2); -int -xdr_encode_array2(struct xdr_buf *buf, unsigned int base, - struct xdr_array2_desc *desc) +int xdr_encode_array2(const struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) { if ((unsigned long) base + 4 + desc->array_len * desc->elem_size > buf->head->iov_len + buf->page_len + buf->tail->iov_len) @@ -1261,9 +2149,9 @@ xdr_encode_array2(struct xdr_buf *buf, unsigned int base, } EXPORT_SYMBOL_GPL(xdr_encode_array2); -int -xdr_process_buf(struct xdr_buf *buf, unsigned int offset, unsigned int len, - int (*actor)(struct scatterlist *, void *), void *data) +int xdr_process_buf(const struct xdr_buf *buf, unsigned int offset, + unsigned int len, + int (*actor)(struct scatterlist *, void *), void *data) { int i, ret = 0; unsigned int page_len, thislen, page_offset; @@ -1294,9 +2182,9 @@ xdr_process_buf(struct xdr_buf *buf, unsigned int offset, unsigned int len, if (page_len > len) page_len = len; len -= page_len; - page_offset = (offset + buf->page_base) & (PAGE_CACHE_SIZE - 1); - i = (offset + buf->page_base) >> PAGE_CACHE_SHIFT; - thislen = PAGE_CACHE_SIZE - page_offset; + page_offset = (offset + buf->page_base) & (PAGE_SIZE - 1); + i = (offset + buf->page_base) >> PAGE_SHIFT; + thislen = PAGE_SIZE - page_offset; do { if (thislen > page_len) thislen = page_len; @@ -1307,7 +2195,7 @@ xdr_process_buf(struct xdr_buf *buf, unsigned int offset, unsigned int len, page_len -= thislen; i++; page_offset = 0; - thislen = PAGE_CACHE_SIZE; + thislen = PAGE_SIZE; } while (page_len != 0); offset = 0; } @@ -1328,3 +2216,92 @@ out: } EXPORT_SYMBOL_GPL(xdr_process_buf); +/** + * xdr_stream_decode_string_dup - Decode and duplicate variable length string + * @xdr: pointer to xdr_stream + * @str: location to store pointer to string + * @maxlen: maximum acceptable string length + * @gfp_flags: GFP mask to use + * + * Return values: + * On success, returns length of NUL-terminated string stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the size of the string would exceed @maxlen + * %-ENOMEM on memory allocation failure + */ +ssize_t xdr_stream_decode_string_dup(struct xdr_stream *xdr, char **str, + size_t maxlen, gfp_t gfp_flags) +{ + void *p; + ssize_t ret; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen); + if (ret > 0) { + char *s = kmemdup_nul(p, ret, gfp_flags); + if (s != NULL) { + *str = s; + return strlen(s); + } + ret = -ENOMEM; + } + *str = NULL; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_string_dup); + +/** + * xdr_stream_decode_opaque_auth - Decode struct opaque_auth (RFC5531 S8.2) + * @xdr: pointer to xdr_stream + * @flavor: location to store decoded flavor + * @body: location to store decode body + * @body_len: location to store length of decoded body + * + * Return values: + * On success, returns the number of buffer bytes consumed + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the decoded size of the body field exceeds 400 octets + */ +ssize_t xdr_stream_decode_opaque_auth(struct xdr_stream *xdr, u32 *flavor, + void **body, unsigned int *body_len) +{ + ssize_t ret, len; + + len = xdr_stream_decode_u32(xdr, flavor); + if (unlikely(len < 0)) + return len; + ret = xdr_stream_decode_opaque_inline(xdr, body, RPC_MAX_AUTH_SIZE); + if (unlikely(ret < 0)) + return ret; + *body_len = ret; + return len + ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque_auth); + +/** + * xdr_stream_encode_opaque_auth - Encode struct opaque_auth (RFC5531 S8.2) + * @xdr: pointer to xdr_stream + * @flavor: verifier flavor to encode + * @body: content of body to encode + * @body_len: length of body to encode + * + * Return values: + * On success, returns length in bytes of XDR buffer consumed + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the size of @body exceeds 400 octets + */ +ssize_t xdr_stream_encode_opaque_auth(struct xdr_stream *xdr, u32 flavor, + void *body, unsigned int body_len) +{ + ssize_t ret, len; + + if (unlikely(body_len > RPC_MAX_AUTH_SIZE)) + return -EMSGSIZE; + len = xdr_stream_encode_u32(xdr, flavor); + if (unlikely(len < 0)) + return len; + ret = xdr_stream_encode_opaque(xdr, body, body_len); + if (unlikely(ret < 0)) + return ret; + return len + ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_encode_opaque_auth); diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c index 095363eee764..1023361845f9 100644 --- a/net/sunrpc/xprt.c +++ b/net/sunrpc/xprt.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * linux/net/sunrpc/xprt.c * @@ -48,46 +49,43 @@ #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/metrics.h> #include <linux/sunrpc/bc_xprt.h> +#include <linux/rcupdate.h> +#include <linux/sched/mm.h> + +#include <trace/events/sunrpc.h> #include "sunrpc.h" +#include "sysfs.h" +#include "fail.h" /* * Local variables */ -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_XPRT #endif /* * Local functions */ -static void xprt_init(struct rpc_xprt *xprt, struct net *net); -static void xprt_request_init(struct rpc_task *, struct rpc_xprt *); -static void xprt_connect_status(struct rpc_task *task); -static int __xprt_get_cong(struct rpc_xprt *, struct rpc_task *); -static void xprt_destroy(struct rpc_xprt *xprt); +static void xprt_init(struct rpc_xprt *xprt, struct net *net); +static __be32 xprt_alloc_xid(struct rpc_xprt *xprt); +static void xprt_destroy(struct rpc_xprt *xprt); +static void xprt_request_init(struct rpc_task *task); +static int xprt_request_prepare(struct rpc_rqst *req, struct xdr_buf *buf); static DEFINE_SPINLOCK(xprt_list_lock); static LIST_HEAD(xprt_list); -/* - * The transport code maintains an estimate on the maximum number of out- - * standing RPC requests, using a smoothed version of the congestion - * avoidance implemented in 44BSD. This is basically the Van Jacobson - * congestion algorithm: If a retransmit occurs, the congestion window is - * halved; otherwise, it is incremented by 1/cwnd when - * - * - a reply is received and - * - a full number of requests are outstanding and - * - the congestion window hasn't been updated recently. - */ -#define RPC_CWNDSHIFT (8U) -#define RPC_CWNDSCALE (1U << RPC_CWNDSHIFT) -#define RPC_INITCWND RPC_CWNDSCALE -#define RPC_MAXCWND(xprt) ((xprt)->max_reqs << RPC_CWNDSHIFT) +static unsigned long xprt_request_timeout(const struct rpc_rqst *req) +{ + unsigned long timeout = jiffies + req->rq_timeout; -#define RPCXPRT_CONGESTED(xprt) ((xprt)->cong >= (xprt)->cwnd) + if (time_before(timeout, req->rq_majortimeo)) + return timeout; + return req->rq_majortimeo; +} /** * xprt_register_transport - register a transport implementation @@ -157,33 +155,103 @@ out: } EXPORT_SYMBOL_GPL(xprt_unregister_transport); -/** - * xprt_load_transport - load a transport implementation - * @transport_name: transport to load - * - * Returns: - * 0: transport successfully loaded - * -ENOENT: transport module not available - */ -int xprt_load_transport(const char *transport_name) +static void +xprt_class_release(const struct xprt_class *t) { - struct xprt_class *t; - int result; + module_put(t->owner); +} + +static const struct xprt_class * +xprt_class_find_by_ident_locked(int ident) +{ + const struct xprt_class *t; + + list_for_each_entry(t, &xprt_list, list) { + if (t->ident != ident) + continue; + if (!try_module_get(t->owner)) + continue; + return t; + } + return NULL; +} + +static const struct xprt_class * +xprt_class_find_by_ident(int ident) +{ + const struct xprt_class *t; - result = 0; spin_lock(&xprt_list_lock); + t = xprt_class_find_by_ident_locked(ident); + spin_unlock(&xprt_list_lock); + return t; +} + +static const struct xprt_class * +xprt_class_find_by_netid_locked(const char *netid) +{ + const struct xprt_class *t; + unsigned int i; + list_for_each_entry(t, &xprt_list, list) { - if (strcmp(t->name, transport_name) == 0) { - spin_unlock(&xprt_list_lock); - goto out; + for (i = 0; t->netid[i][0] != '\0'; i++) { + if (strcmp(t->netid[i], netid) != 0) + continue; + if (!try_module_get(t->owner)) + continue; + return t; } } + return NULL; +} + +static const struct xprt_class * +xprt_class_find_by_netid(const char *netid) +{ + const struct xprt_class *t; + + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_netid_locked(netid); + if (!t) { + spin_unlock(&xprt_list_lock); + request_module("rpc%s", netid); + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_netid_locked(netid); + } spin_unlock(&xprt_list_lock); - result = request_module("xprt%s", transport_name); -out: - return result; + return t; +} + +/** + * xprt_find_transport_ident - convert a netid into a transport identifier + * @netid: transport to load + * + * Returns: + * > 0: transport identifier + * -ENOENT: transport module not available + */ +int xprt_find_transport_ident(const char *netid) +{ + const struct xprt_class *t; + int ret; + + t = xprt_class_find_by_netid(netid); + if (!t) + return -ENOENT; + ret = t->ident; + xprt_class_release(t); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_find_transport_ident); + +static void xprt_clear_locked(struct rpc_xprt *xprt) +{ + xprt->snd_task = NULL; + if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) + clear_bit_unlock(XPRT_LOCKED, &xprt->state); + else + queue_work(xprtiod_workqueue, &xprt->task_cleanup); } -EXPORT_SYMBOL_GPL(xprt_load_transport); /** * xprt_reserve_xprt - serialize write access to transports @@ -197,46 +265,56 @@ EXPORT_SYMBOL_GPL(xprt_load_transport); int xprt_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - int priority; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { if (task == xprt->snd_task) - return 1; + goto out_locked; goto out_sleep; } + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; xprt->snd_task = task; - if (req != NULL) { - req->rq_bytes_sent = 0; - req->rq_ntrans++; - } +out_locked: + trace_xprt_reserve_xprt(xprt, task); return 1; +out_unlock: + xprt_clear_locked(xprt); out_sleep: - dprintk("RPC: %5u failed to lock transport %p\n", - task->tk_pid, xprt); - task->tk_timeout = 0; task->tk_status = -EAGAIN; - if (req == NULL) - priority = RPC_PRIORITY_LOW; - else if (!req->rq_ntrans) - priority = RPC_PRIORITY_NORMAL; + if (RPC_IS_SOFT(task) || RPC_IS_SOFTCONN(task)) + rpc_sleep_on_timeout(&xprt->sending, task, NULL, + xprt_request_timeout(req)); else - priority = RPC_PRIORITY_HIGH; - rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); + rpc_sleep_on(&xprt->sending, task, NULL); return 0; } EXPORT_SYMBOL_GPL(xprt_reserve_xprt); -static void xprt_clear_locked(struct rpc_xprt *xprt) +static bool +xprt_need_congestion_window_wait(struct rpc_xprt *xprt) { - xprt->snd_task = NULL; - if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { - smp_mb__before_clear_bit(); - clear_bit(XPRT_LOCKED, &xprt->state); - smp_mb__after_clear_bit(); - } else - queue_work(rpciod_workqueue, &xprt->task_cleanup); + return test_bit(XPRT_CWND_WAIT, &xprt->state); +} + +static void +xprt_set_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (!list_empty(&xprt->xmit_queue)) { + /* Peek at head of queue to see if it can make progress */ + if (list_first_entry(&xprt->xmit_queue, struct rpc_rqst, + rq_xmit)->rq_cong) + return; + } + set_bit(XPRT_CWND_WAIT, &xprt->state); +} + +static void +xprt_test_and_clear_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (!RPCXPRT_CONGESTED(xprt)) + clear_bit(XPRT_CWND_WAIT, &xprt->state); } /* @@ -246,40 +324,40 @@ static void xprt_clear_locked(struct rpc_xprt *xprt) * Same as xprt_reserve_xprt, but Van Jacobson congestion control is * integrated into the decision of whether a request is allowed to be * woken up and given access to the transport. + * Note that the lock is only granted if we know there are free slots. */ int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - int priority; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { if (task == xprt->snd_task) - return 1; + goto out_locked; goto out_sleep; } if (req == NULL) { xprt->snd_task = task; - return 1; + goto out_locked; } - if (__xprt_get_cong(xprt, task)) { + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (!xprt_need_congestion_window_wait(xprt)) { xprt->snd_task = task; - req->rq_bytes_sent = 0; - req->rq_ntrans++; - return 1; + goto out_locked; } +out_unlock: xprt_clear_locked(xprt); out_sleep: - dprintk("RPC: %5u failed to lock transport %p\n", task->tk_pid, xprt); - task->tk_timeout = 0; task->tk_status = -EAGAIN; - if (req == NULL) - priority = RPC_PRIORITY_LOW; - else if (!req->rq_ntrans) - priority = RPC_PRIORITY_NORMAL; + if (RPC_IS_SOFT(task) || RPC_IS_SOFTCONN(task)) + rpc_sleep_on_timeout(&xprt->sending, task, NULL, + xprt_request_timeout(req)); else - priority = RPC_PRIORITY_HIGH; - rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); + rpc_sleep_on(&xprt->sending, task, NULL); return 0; +out_locked: + trace_xprt_reserve_cong(xprt, task); + return 1; } EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong); @@ -287,23 +365,19 @@ static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) { int retval; - spin_lock_bh(&xprt->transport_lock); + if (test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == task) + return 1; + spin_lock(&xprt->transport_lock); retval = xprt->ops->reserve_xprt(xprt, task); - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->transport_lock); return retval; } static bool __xprt_lock_write_func(struct rpc_task *task, void *data) { struct rpc_xprt *xprt = data; - struct rpc_rqst *req; - req = task->tk_rqstp; xprt->snd_task = task; - if (req) { - req->rq_bytes_sent = 0; - req->rq_ntrans++; - } return true; } @@ -311,38 +385,25 @@ static void __xprt_lock_write_next(struct rpc_xprt *xprt) { if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) return; - - if (rpc_wake_up_first(&xprt->sending, __xprt_lock_write_func, xprt)) + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, + __xprt_lock_write_func, xprt)) return; +out_unlock: xprt_clear_locked(xprt); } -static bool __xprt_lock_write_cong_func(struct rpc_task *task, void *data) -{ - struct rpc_xprt *xprt = data; - struct rpc_rqst *req; - - req = task->tk_rqstp; - if (req == NULL) { - xprt->snd_task = task; - return true; - } - if (__xprt_get_cong(xprt, task)) { - xprt->snd_task = task; - req->rq_bytes_sent = 0; - req->rq_ntrans++; - return true; - } - return false; -} - static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt) { if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) return; - if (RPCXPRT_CONGESTED(xprt)) + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (xprt_need_congestion_window_wait(xprt)) goto out_unlock; - if (rpc_wake_up_first(&xprt->sending, __xprt_lock_write_cong_func, xprt)) + if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, + __xprt_lock_write_func, xprt)) return; out_unlock: xprt_clear_locked(xprt); @@ -361,6 +422,7 @@ void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) xprt_clear_locked(xprt); __xprt_lock_write_next(xprt); } + trace_xprt_release_xprt(xprt, task); } EXPORT_SYMBOL_GPL(xprt_release_xprt); @@ -378,14 +440,17 @@ void xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) xprt_clear_locked(xprt); __xprt_lock_write_next_cong(xprt); } + trace_xprt_release_cong(xprt, task); } EXPORT_SYMBOL_GPL(xprt_release_xprt_cong); -static inline void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task) +void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task) { - spin_lock_bh(&xprt->transport_lock); + if (xprt->snd_task != task) + return; + spin_lock(&xprt->transport_lock); xprt->ops->release_xprt(xprt, task); - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->transport_lock); } /* @@ -393,16 +458,15 @@ static inline void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *ta * overflowed. Put the task to sleep if this is the case. */ static int -__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_task *task) +__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; - if (req->rq_cong) return 1; - dprintk("RPC: %5u xprt_cwnd_limited cong = %lu cwnd = %lu\n", - task->tk_pid, xprt->cong, xprt->cwnd); - if (RPCXPRT_CONGESTED(xprt)) + trace_xprt_get_cong(xprt, req->rq_task); + if (RPCXPRT_CONGESTED(xprt)) { + xprt_set_congestion_window_wait(xprt); return 0; + } req->rq_cong = 1; xprt->cong += RPC_CWNDSCALE; return 1; @@ -419,10 +483,33 @@ __xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) return; req->rq_cong = 0; xprt->cong -= RPC_CWNDSCALE; + xprt_test_and_clear_congestion_window_wait(xprt); + trace_xprt_put_cong(xprt, req->rq_task); __xprt_lock_write_next_cong(xprt); } /** + * xprt_request_get_cong - Request congestion control credits + * @xprt: pointer to transport + * @req: pointer to RPC request + * + * Useful for transports that require congestion control. + */ +bool +xprt_request_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + bool ret = false; + + if (req->rq_cong) + return true; + spin_lock(&xprt->transport_lock); + ret = __xprt_get_cong(xprt, req) != 0; + spin_unlock(&xprt->transport_lock); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_request_get_cong); + +/** * xprt_release_rqst_cong - housekeeping when request is complete * @task: RPC request that recently completed * @@ -436,13 +523,41 @@ void xprt_release_rqst_cong(struct rpc_task *task) } EXPORT_SYMBOL_GPL(xprt_release_rqst_cong); +static void xprt_clear_congestion_window_wait_locked(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state)) + __xprt_lock_write_next_cong(xprt); +} + +/* + * Clear the congestion window wait flag and wake up the next + * entry on xprt->sending + */ +static void +xprt_clear_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state)) { + spin_lock(&xprt->transport_lock); + __xprt_lock_write_next_cong(xprt); + spin_unlock(&xprt->transport_lock); + } +} + /** * xprt_adjust_cwnd - adjust transport congestion window * @xprt: pointer to xprt * @task: recently completed RPC request used to adjust window * @result: result code of completed RPC request * - * We use a time-smoothed congestion estimator to avoid heavy oscillation. + * The transport code maintains an estimate on the maximum number of out- + * standing RPC requests, using a smoothed version of the congestion + * avoidance implemented in 44BSD. This is basically the Van Jacobson + * congestion algorithm: If a retransmit occurs, the congestion window is + * halved; otherwise, it is incremented by 1/cwnd when + * + * - a reply is received and + * - a full number of requests are outstanding and + * - the congestion window hasn't been updated recently. */ void xprt_adjust_cwnd(struct rpc_xprt *xprt, struct rpc_task *task, int result) { @@ -485,88 +600,96 @@ EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks); /** * xprt_wait_for_buffer_space - wait for transport output buffer to clear - * @task: task to be put to sleep - * @action: function pointer to be executed after wait + * @xprt: transport * * Note that we only set the timer for the case of RPC_IS_SOFT(), since * we don't in general want to force a socket disconnection due to * an incomplete RPC call transmission. */ -void xprt_wait_for_buffer_space(struct rpc_task *task, rpc_action action) +void xprt_wait_for_buffer_space(struct rpc_xprt *xprt) { - struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = req->rq_xprt; - - task->tk_timeout = RPC_IS_SOFT(task) ? req->rq_timeout : 0; - rpc_sleep_on(&xprt->pending, task, action); + set_bit(XPRT_WRITE_SPACE, &xprt->state); } EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); +static bool +xprt_clear_write_space_locked(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_WRITE_SPACE, &xprt->state)) { + __xprt_lock_write_next(xprt); + dprintk("RPC: write space: waking waiting task on " + "xprt %p\n", xprt); + return true; + } + return false; +} + /** * xprt_write_space - wake the task waiting for transport output buffer space * @xprt: transport with waiting tasks * * Can be called in a soft IRQ context, so xprt_write_space never sleeps. */ -void xprt_write_space(struct rpc_xprt *xprt) +bool xprt_write_space(struct rpc_xprt *xprt) { - spin_lock_bh(&xprt->transport_lock); - if (xprt->snd_task) { - dprintk("RPC: write space: waking waiting task on " - "xprt %p\n", xprt); - rpc_wake_up_queued_task(&xprt->pending, xprt->snd_task); - } - spin_unlock_bh(&xprt->transport_lock); + bool ret; + + if (!test_bit(XPRT_WRITE_SPACE, &xprt->state)) + return false; + spin_lock(&xprt->transport_lock); + ret = xprt_clear_write_space_locked(xprt); + spin_unlock(&xprt->transport_lock); + return ret; } EXPORT_SYMBOL_GPL(xprt_write_space); -/** - * xprt_set_retrans_timeout_def - set a request's retransmit timeout - * @task: task whose timeout is to be set - * - * Set a request's retransmit timeout based on the transport's - * default timeout parameters. Used by transports that don't adjust - * the retransmit timeout based on round-trip time estimation. - */ -void xprt_set_retrans_timeout_def(struct rpc_task *task) +static unsigned long xprt_abs_ktime_to_jiffies(ktime_t abstime) { - task->tk_timeout = task->tk_rqstp->rq_timeout; + s64 delta = ktime_to_ns(ktime_get() - abstime); + return likely(delta >= 0) ? + jiffies - nsecs_to_jiffies(delta) : + jiffies + nsecs_to_jiffies(-delta); } -EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_def); -/** - * xprt_set_retrans_timeout_rtt - set a request's retransmit timeout - * @task: task whose timeout is to be set - * - * Set a request's retransmit timeout using the RTT estimator. - */ -void xprt_set_retrans_timeout_rtt(struct rpc_task *task) +static unsigned long xprt_calc_majortimeo(struct rpc_rqst *req, + const struct rpc_timeout *to) { - int timer = task->tk_msg.rpc_proc->p_timer; - struct rpc_clnt *clnt = task->tk_client; - struct rpc_rtt *rtt = clnt->cl_rtt; - struct rpc_rqst *req = task->tk_rqstp; - unsigned long max_timeout = clnt->cl_timeout->to_maxval; + unsigned long majortimeo = req->rq_timeout; - task->tk_timeout = rpc_calc_rto(rtt, timer); - task->tk_timeout <<= rpc_ntimeo(rtt, timer) + req->rq_retries; - if (task->tk_timeout > max_timeout || task->tk_timeout == 0) - task->tk_timeout = max_timeout; + if (to->to_exponential) + majortimeo <<= to->to_retries; + else + majortimeo += to->to_increment * to->to_retries; + if (majortimeo > to->to_maxval || majortimeo == 0) + majortimeo = to->to_maxval; + return majortimeo; } -EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_rtt); -static void xprt_reset_majortimeo(struct rpc_rqst *req) +static void xprt_reset_majortimeo(struct rpc_rqst *req, + const struct rpc_timeout *to) { - const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + req->rq_majortimeo += xprt_calc_majortimeo(req, to); +} - req->rq_majortimeo = req->rq_timeout; - if (to->to_exponential) - req->rq_majortimeo <<= to->to_retries; +static void xprt_reset_minortimeo(struct rpc_rqst *req) +{ + req->rq_minortimeo += req->rq_timeout; +} + +static void xprt_init_majortimeo(struct rpc_task *task, struct rpc_rqst *req, + const struct rpc_timeout *to) +{ + unsigned long time_init; + struct rpc_xprt *xprt = req->rq_xprt; + + if (likely(xprt && xprt_connected(xprt))) + time_init = jiffies; else - req->rq_majortimeo += to->to_increment * to->to_retries; - if (req->rq_majortimeo > to->to_maxval || req->rq_majortimeo == 0) - req->rq_majortimeo = to->to_maxval; - req->rq_majortimeo += jiffies; + time_init = xprt_abs_ktime_to_jiffies(task->tk_start); + + req->rq_timeout = to->to_initval; + req->rq_majortimeo = time_init + xprt_calc_majortimeo(req, to); + req->rq_minortimeo = time_init + req->rq_timeout; } /** @@ -581,6 +704,8 @@ int xprt_adjust_timeout(struct rpc_rqst *req) int status = 0; if (time_before(jiffies, req->rq_majortimeo)) { + if (time_before(jiffies, req->rq_minortimeo)) + return status; if (to->to_exponential) req->rq_timeout <<= 1; else @@ -591,13 +716,14 @@ int xprt_adjust_timeout(struct rpc_rqst *req) } else { req->rq_timeout = to->to_initval; req->rq_retries = 0; - xprt_reset_majortimeo(req); + xprt_reset_majortimeo(req, to); /* Reset the RTT counters == "slow start" */ - spin_lock_bh(&xprt->transport_lock); + spin_lock(&xprt->transport_lock); rpc_init_rtt(req->rq_task->tk_client->cl_rtt, to->to_initval); - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->transport_lock); status = -ETIMEDOUT; } + xprt_reset_minortimeo(req); if (req->rq_timeout == 0) { printk(KERN_WARNING "xprt_adjust_timeout: rq_timeout = 0!\n"); @@ -610,10 +736,16 @@ static void xprt_autoclose(struct work_struct *work) { struct rpc_xprt *xprt = container_of(work, struct rpc_xprt, task_cleanup); + unsigned int pflags = memalloc_nofs_save(); - xprt->ops->close(xprt); + trace_xprt_disconnect_auto(xprt); + xprt->connect_cookie++; + smp_mb__before_atomic(); clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + xprt->ops->close(xprt); xprt_release_write(xprt, NULL); + wake_up_bit(&xprt->state, XPRT_LOCKED); + memalloc_nofs_restore(pflags); } /** @@ -623,29 +755,61 @@ static void xprt_autoclose(struct work_struct *work) */ void xprt_disconnect_done(struct rpc_xprt *xprt) { - dprintk("RPC: disconnected transport %p\n", xprt); - spin_lock_bh(&xprt->transport_lock); + trace_xprt_disconnect_done(xprt); + spin_lock(&xprt->transport_lock); xprt_clear_connected(xprt); - xprt_wake_pending_tasks(xprt, -EAGAIN); - spin_unlock_bh(&xprt->transport_lock); + xprt_clear_write_space_locked(xprt); + xprt_clear_congestion_window_wait_locked(xprt); + xprt_wake_pending_tasks(xprt, -ENOTCONN); + spin_unlock(&xprt->transport_lock); } EXPORT_SYMBOL_GPL(xprt_disconnect_done); /** + * xprt_schedule_autoclose_locked - Try to schedule an autoclose RPC call + * @xprt: transport to disconnect + */ +static void xprt_schedule_autoclose_locked(struct rpc_xprt *xprt) +{ + if (test_and_set_bit(XPRT_CLOSE_WAIT, &xprt->state)) + return; + if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) + queue_work(xprtiod_workqueue, &xprt->task_cleanup); + else if (xprt->snd_task && !test_bit(XPRT_SND_IS_COOKIE, &xprt->state)) + rpc_wake_up_queued_task_set_status(&xprt->pending, + xprt->snd_task, -ENOTCONN); +} + +/** * xprt_force_disconnect - force a transport to disconnect * @xprt: transport to disconnect * */ void xprt_force_disconnect(struct rpc_xprt *xprt) { + trace_xprt_disconnect_force(xprt); + /* Don't race with the test_bit() in xprt_clear_locked() */ - spin_lock_bh(&xprt->transport_lock); - set_bit(XPRT_CLOSE_WAIT, &xprt->state); - /* Try to schedule an autoclose RPC call */ - if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) - queue_work(rpciod_workqueue, &xprt->task_cleanup); - xprt_wake_pending_tasks(xprt, -EAGAIN); - spin_unlock_bh(&xprt->transport_lock); + spin_lock(&xprt->transport_lock); + xprt_schedule_autoclose_locked(xprt); + spin_unlock(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_force_disconnect); + +static unsigned int +xprt_connect_cookie(struct rpc_xprt *xprt) +{ + return READ_ONCE(xprt->connect_cookie); +} + +static bool +xprt_request_retransmit_after_disconnect(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + return req->rq_connect_cookie != xprt_connect_cookie(xprt) || + !xprt_connected(xprt); } /** @@ -662,37 +826,94 @@ void xprt_force_disconnect(struct rpc_xprt *xprt) void xprt_conditional_disconnect(struct rpc_xprt *xprt, unsigned int cookie) { /* Don't race with the test_bit() in xprt_clear_locked() */ - spin_lock_bh(&xprt->transport_lock); + spin_lock(&xprt->transport_lock); if (cookie != xprt->connect_cookie) goto out; - if (test_bit(XPRT_CLOSING, &xprt->state) || !xprt_connected(xprt)) + if (test_bit(XPRT_CLOSING, &xprt->state)) goto out; - set_bit(XPRT_CLOSE_WAIT, &xprt->state); - /* Try to schedule an autoclose RPC call */ - if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) - queue_work(rpciod_workqueue, &xprt->task_cleanup); - xprt_wake_pending_tasks(xprt, -EAGAIN); + xprt_schedule_autoclose_locked(xprt); out: - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->transport_lock); +} + +static bool +xprt_has_timer(const struct rpc_xprt *xprt) +{ + return xprt->idle_timeout != 0; } static void -xprt_init_autodisconnect(unsigned long data) +xprt_schedule_autodisconnect(struct rpc_xprt *xprt) + __must_hold(&xprt->transport_lock) { - struct rpc_xprt *xprt = (struct rpc_xprt *)data; + xprt->last_used = jiffies; + if (RB_EMPTY_ROOT(&xprt->recv_queue) && xprt_has_timer(xprt)) + mod_timer(&xprt->timer, xprt->last_used + xprt->idle_timeout); +} - spin_lock(&xprt->transport_lock); - if (!list_empty(&xprt->recv)) - goto out_abort; +static void +xprt_init_autodisconnect(struct timer_list *t) +{ + struct rpc_xprt *xprt = timer_container_of(xprt, t, timer); + + if (!RB_EMPTY_ROOT(&xprt->recv_queue)) + return; + /* Reset xprt->last_used to avoid connect/autodisconnect cycling */ + xprt->last_used = jiffies; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) - goto out_abort; + return; + queue_work(xprtiod_workqueue, &xprt->task_cleanup); +} + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) +static void xprt_inject_disconnect(struct rpc_xprt *xprt) +{ + if (!fail_sunrpc.ignore_client_disconnect && + should_fail(&fail_sunrpc.attr, 1)) + xprt->ops->inject_disconnect(xprt); +} +#else +static inline void xprt_inject_disconnect(struct rpc_xprt *xprt) +{ +} +#endif + +bool xprt_lock_connect(struct rpc_xprt *xprt, + struct rpc_task *task, + void *cookie) +{ + bool ret = false; + + spin_lock(&xprt->transport_lock); + if (!test_bit(XPRT_LOCKED, &xprt->state)) + goto out; + if (xprt->snd_task != task) + goto out; + set_bit(XPRT_SND_IS_COOKIE, &xprt->state); + xprt->snd_task = cookie; + ret = true; +out: spin_unlock(&xprt->transport_lock); - set_bit(XPRT_CONNECTION_CLOSE, &xprt->state); - queue_work(rpciod_workqueue, &xprt->task_cleanup); - return; -out_abort: + return ret; +} +EXPORT_SYMBOL_GPL(xprt_lock_connect); + +void xprt_unlock_connect(struct rpc_xprt *xprt, void *cookie) +{ + spin_lock(&xprt->transport_lock); + if (xprt->snd_task != cookie) + goto out; + if (!test_bit(XPRT_LOCKED, &xprt->state)) + goto out; + xprt->snd_task =NULL; + clear_bit(XPRT_SND_IS_COOKIE, &xprt->state); + xprt->ops->release_xprt(xprt, NULL); + xprt_schedule_autodisconnect(xprt); +out: spin_unlock(&xprt->transport_lock); + wake_up_bit(&xprt->state, XPRT_LOCKED); } +EXPORT_SYMBOL_GPL(xprt_unlock_connect); /** * xprt_connect - schedule a transport connect operation @@ -703,8 +924,7 @@ void xprt_connect(struct rpc_task *task) { struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; - dprintk("RPC: %5u xprt_connect xprt %p %s connected\n", task->tk_pid, - xprt, (xprt_connected(xprt) ? "is" : "is not")); + trace_xprt_connect(xprt); if (!xprt_bound(xprt)) { task->tk_status = -EAGAIN; @@ -713,52 +933,127 @@ void xprt_connect(struct rpc_task *task) if (!xprt_lock_write(xprt, task)) return; - if (test_and_clear_bit(XPRT_CLOSE_WAIT, &xprt->state)) - xprt->ops->close(xprt); - - if (xprt_connected(xprt)) - xprt_release_write(xprt, task); - else { - task->tk_rqstp->rq_bytes_sent = 0; - task->tk_timeout = task->tk_rqstp->rq_timeout; - rpc_sleep_on(&xprt->pending, task, xprt_connect_status); + if (!xprt_connected(xprt) && !test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { + task->tk_rqstp->rq_connect_cookie = xprt->connect_cookie; + rpc_sleep_on_timeout(&xprt->pending, task, NULL, + xprt_request_timeout(task->tk_rqstp)); if (test_bit(XPRT_CLOSING, &xprt->state)) return; if (xprt_test_and_set_connecting(xprt)) return; - xprt->stat.connect_start = jiffies; - xprt->ops->connect(xprt, task); + /* Race breaker */ + if (!xprt_connected(xprt)) { + xprt->stat.connect_start = jiffies; + xprt->ops->connect(xprt, task); + } else { + xprt_clear_connecting(xprt); + task->tk_status = 0; + rpc_wake_up_queued_task(&xprt->pending, task); + } } + xprt_release_write(xprt, task); } -static void xprt_connect_status(struct rpc_task *task) +/** + * xprt_reconnect_delay - compute the wait before scheduling a connect + * @xprt: transport instance + * + */ +unsigned long xprt_reconnect_delay(const struct rpc_xprt *xprt) { - struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + unsigned long start, now = jiffies; - if (task->tk_status == 0) { - xprt->stat.connect_count++; - xprt->stat.connect_time += (long)jiffies - xprt->stat.connect_start; - dprintk("RPC: %5u xprt_connect_status: connection established\n", - task->tk_pid); - return; + start = xprt->stat.connect_start + xprt->reestablish_timeout; + if (time_after(start, now)) + return start - now; + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reconnect_delay); + +/** + * xprt_reconnect_backoff - compute the new re-establish timeout + * @xprt: transport instance + * @init_to: initial reestablish timeout + * + */ +void xprt_reconnect_backoff(struct rpc_xprt *xprt, unsigned long init_to) +{ + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > xprt->max_reconnect_timeout) + xprt->reestablish_timeout = xprt->max_reconnect_timeout; + if (xprt->reestablish_timeout < init_to) + xprt->reestablish_timeout = init_to; +} +EXPORT_SYMBOL_GPL(xprt_reconnect_backoff); + +enum xprt_xid_rb_cmp { + XID_RB_EQUAL, + XID_RB_LEFT, + XID_RB_RIGHT, +}; +static enum xprt_xid_rb_cmp +xprt_xid_cmp(__be32 xid1, __be32 xid2) +{ + if (xid1 == xid2) + return XID_RB_EQUAL; + if ((__force u32)xid1 < (__force u32)xid2) + return XID_RB_LEFT; + return XID_RB_RIGHT; +} + +static struct rpc_rqst * +xprt_request_rb_find(struct rpc_xprt *xprt, __be32 xid) +{ + struct rb_node *n = xprt->recv_queue.rb_node; + struct rpc_rqst *req; + + while (n != NULL) { + req = rb_entry(n, struct rpc_rqst, rq_recv); + switch (xprt_xid_cmp(xid, req->rq_xid)) { + case XID_RB_LEFT: + n = n->rb_left; + break; + case XID_RB_RIGHT: + n = n->rb_right; + break; + case XID_RB_EQUAL: + return req; + } } + return NULL; +} - switch (task->tk_status) { - case -EAGAIN: - dprintk("RPC: %5u xprt_connect_status: retrying\n", task->tk_pid); - break; - case -ETIMEDOUT: - dprintk("RPC: %5u xprt_connect_status: connect attempt timed " - "out\n", task->tk_pid); - break; - default: - dprintk("RPC: %5u xprt_connect_status: error %d connecting to " - "server %s\n", task->tk_pid, -task->tk_status, - xprt->servername); - xprt_release_write(xprt, task); - task->tk_status = -EIO; +static void +xprt_request_rb_insert(struct rpc_xprt *xprt, struct rpc_rqst *new) +{ + struct rb_node **p = &xprt->recv_queue.rb_node; + struct rb_node *n = NULL; + struct rpc_rqst *req; + + while (*p != NULL) { + n = *p; + req = rb_entry(n, struct rpc_rqst, rq_recv); + switch(xprt_xid_cmp(new->rq_xid, req->rq_xid)) { + case XID_RB_LEFT: + p = &n->rb_left; + break; + case XID_RB_RIGHT: + p = &n->rb_right; + break; + case XID_RB_EQUAL: + WARN_ON_ONCE(new != req); + return; + } } + rb_link_node(&new->rq_recv, n, p); + rb_insert_color(&new->rq_recv, &xprt->recv_queue); +} + +static void +xprt_request_rb_remove(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + rb_erase(&req->rq_recv, &xprt->recv_queue); } /** @@ -766,23 +1061,138 @@ static void xprt_connect_status(struct rpc_task *task) * @xprt: transport on which the original request was transmitted * @xid: RPC XID of incoming reply * + * Caller holds xprt->queue_lock. */ struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) { struct rpc_rqst *entry; - list_for_each_entry(entry, &xprt->recv, rq_list) - if (entry->rq_xid == xid) - return entry; + entry = xprt_request_rb_find(xprt, xid); + if (entry != NULL) { + trace_xprt_lookup_rqst(xprt, xid, 0); + entry->rq_rtt = ktime_sub(ktime_get(), entry->rq_xtime); + return entry; + } dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n", ntohl(xid)); + trace_xprt_lookup_rqst(xprt, xid, -ENOENT); xprt->stat.bad_xids++; return NULL; } EXPORT_SYMBOL_GPL(xprt_lookup_rqst); -static void xprt_update_rtt(struct rpc_task *task) +static bool +xprt_is_pinned_rqst(struct rpc_rqst *req) +{ + return atomic_read(&req->rq_pin) != 0; +} + +/** + * xprt_pin_rqst - Pin a request on the transport receive list + * @req: Request to pin + * + * Caller must ensure this is atomic with the call to xprt_lookup_rqst() + * so should be holding xprt->queue_lock. + */ +void xprt_pin_rqst(struct rpc_rqst *req) +{ + atomic_inc(&req->rq_pin); +} +EXPORT_SYMBOL_GPL(xprt_pin_rqst); + +/** + * xprt_unpin_rqst - Unpin a request on the transport receive list + * @req: Request to pin + * + * Caller should be holding xprt->queue_lock. + */ +void xprt_unpin_rqst(struct rpc_rqst *req) +{ + if (!test_bit(RPC_TASK_MSG_PIN_WAIT, &req->rq_task->tk_runstate)) { + atomic_dec(&req->rq_pin); + return; + } + if (atomic_dec_and_test(&req->rq_pin)) + wake_up_var(&req->rq_pin); +} +EXPORT_SYMBOL_GPL(xprt_unpin_rqst); + +static void xprt_wait_on_pinned_rqst(struct rpc_rqst *req) +{ + wait_var_event(&req->rq_pin, !xprt_is_pinned_rqst(req)); +} + +static bool +xprt_request_data_received(struct rpc_task *task) +{ + return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) && + READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) != 0; +} + +static bool +xprt_request_need_enqueue_receive(struct rpc_task *task, struct rpc_rqst *req) +{ + return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) && + READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) == 0; +} + +/** + * xprt_request_enqueue_receive - Add an request to the receive queue + * @task: RPC task + * + */ +int +xprt_request_enqueue_receive(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int ret; + + if (!xprt_request_need_enqueue_receive(task, req)) + return 0; + + ret = xprt_request_prepare(task->tk_rqstp, &req->rq_rcv_buf); + if (ret) + return ret; + spin_lock(&xprt->queue_lock); + + /* Update the softirq receive buffer */ + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + + /* Add request to the receive list */ + xprt_request_rb_insert(xprt, req); + set_bit(RPC_TASK_NEED_RECV, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + + /* Turn off autodisconnect */ + timer_delete_sync(&xprt->timer); + return 0; +} + +/** + * xprt_request_dequeue_receive_locked - Remove a request from the receive queue + * @task: RPC task + * + * Caller must hold xprt->queue_lock. + */ +static void +xprt_request_dequeue_receive_locked(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_clear_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) + xprt_request_rb_remove(req->rq_xprt, req); +} + +/** + * xprt_update_rtt - Update RPC RTT statistics + * @task: RPC request that recently completed + * + * Caller holds xprt->queue_lock. + */ +void xprt_update_rtt(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_rtt *rtt = task->tk_client->cl_rtt; @@ -795,33 +1205,30 @@ static void xprt_update_rtt(struct rpc_task *task) rpc_set_timeo(rtt, timer, req->rq_ntrans - 1); } } +EXPORT_SYMBOL_GPL(xprt_update_rtt); /** * xprt_complete_rqst - called when reply processing is complete * @task: RPC request that recently completed * @copied: actual number of bytes received from the transport * - * Caller holds transport lock. + * Caller holds xprt->queue_lock. */ void xprt_complete_rqst(struct rpc_task *task, int copied) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - dprintk("RPC: %5u xid %08x complete (%d bytes received)\n", - task->tk_pid, ntohl(req->rq_xid), copied); - xprt->stat.recvs++; - req->rq_rtt = ktime_sub(ktime_get(), req->rq_xtime); - if (xprt->ops->timer != NULL) - xprt_update_rtt(task); - list_del_init(&req->rq_list); + xdr_free_bvec(&req->rq_rcv_buf); + req->rq_private_buf.bvec = NULL; req->rq_private_buf.len = copied; /* Ensure all writes are done before we update */ /* req->rq_reply_bytes_recvd */ smp_wmb(); req->rq_reply_bytes_recvd = copied; + xprt_request_dequeue_receive_locked(task); rpc_wake_up_queued_task(&xprt->pending, task); } EXPORT_SYMBOL_GPL(xprt_complete_rqst); @@ -833,132 +1240,450 @@ static void xprt_timer(struct rpc_task *task) if (task->tk_status != -ETIMEDOUT) return; - dprintk("RPC: %5u xprt_timer\n", task->tk_pid); - spin_lock_bh(&xprt->transport_lock); + trace_xprt_timer(xprt, req->rq_xid, task->tk_status); if (!req->rq_reply_bytes_recvd) { if (xprt->ops->timer) xprt->ops->timer(xprt, task); } else task->tk_status = 0; - spin_unlock_bh(&xprt->transport_lock); } -static inline int xprt_has_timer(struct rpc_xprt *xprt) +/** + * xprt_wait_for_reply_request_def - wait for reply + * @task: pointer to rpc_task + * + * Set a request's retransmit timeout based on the transport's + * default timeout parameters. Used by transports that don't adjust + * the retransmit timeout based on round-trip time estimation, + * and put the task to sleep on the pending queue. + */ +void xprt_wait_for_reply_request_def(struct rpc_task *task) { - return xprt->idle_timeout != 0; + struct rpc_rqst *req = task->tk_rqstp; + + rpc_sleep_on_timeout(&req->rq_xprt->pending, task, xprt_timer, + xprt_request_timeout(req)); } +EXPORT_SYMBOL_GPL(xprt_wait_for_reply_request_def); /** - * xprt_prepare_transmit - reserve the transport before sending a request + * xprt_wait_for_reply_request_rtt - wait for reply using RTT estimator + * @task: pointer to rpc_task + * + * Set a request's retransmit timeout using the RTT estimator, + * and put the task to sleep on the pending queue. + */ +void xprt_wait_for_reply_request_rtt(struct rpc_task *task) +{ + int timer = task->tk_msg.rpc_proc->p_timer; + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rtt *rtt = clnt->cl_rtt; + struct rpc_rqst *req = task->tk_rqstp; + unsigned long max_timeout = clnt->cl_timeout->to_maxval; + unsigned long timeout; + + timeout = rpc_calc_rto(rtt, timer); + timeout <<= rpc_ntimeo(rtt, timer) + req->rq_retries; + if (timeout > max_timeout || timeout == 0) + timeout = max_timeout; + rpc_sleep_on_timeout(&req->rq_xprt->pending, task, xprt_timer, + jiffies + timeout); +} +EXPORT_SYMBOL_GPL(xprt_wait_for_reply_request_rtt); + +/** + * xprt_request_wait_receive - wait for the reply to an RPC request * @task: RPC task about to send a request * */ -int xprt_prepare_transmit(struct rpc_task *task) +void xprt_request_wait_receive(struct rpc_task *task) { - struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = req->rq_xprt; - int err = 0; + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; - dprintk("RPC: %5u xprt_prepare_transmit\n", task->tk_pid); + if (!test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) + return; + /* + * Sleep on the pending queue if we're expecting a reply. + * The spinlock ensures atomicity between the test of + * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on(). + */ + spin_lock(&xprt->queue_lock); + if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) { + xprt->ops->wait_for_reply_request(task); + /* + * Send an extra queue wakeup call if the + * connection was dropped in case the call to + * rpc_sleep_on() raced. + */ + if (xprt_request_retransmit_after_disconnect(task)) + rpc_wake_up_queued_task_set_status(&xprt->pending, + task, -ENOTCONN); + } + spin_unlock(&xprt->queue_lock); +} - spin_lock_bh(&xprt->transport_lock); - if (req->rq_reply_bytes_recvd && !req->rq_bytes_sent) { - err = req->rq_reply_bytes_recvd; - goto out_unlock; +static bool +xprt_request_need_enqueue_transmit(struct rpc_task *task, struct rpc_rqst *req) +{ + return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); +} + +/** + * xprt_request_enqueue_transmit - queue a task for transmission + * @task: pointer to rpc_task + * + * Add a task to the transmission queue. + */ +void +xprt_request_enqueue_transmit(struct rpc_task *task) +{ + struct rpc_rqst *pos, *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int ret; + + if (xprt_request_need_enqueue_transmit(task, req)) { + ret = xprt_request_prepare(task->tk_rqstp, &req->rq_snd_buf); + if (ret) { + task->tk_status = ret; + return; + } + req->rq_bytes_sent = 0; + spin_lock(&xprt->queue_lock); + /* + * Requests that carry congestion control credits are added + * to the head of the list to avoid starvation issues. + */ + if (req->rq_cong) { + xprt_clear_congestion_window_wait(xprt); + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_cong) + continue; + /* Note: req is added _before_ pos */ + list_add_tail(&req->rq_xmit, &pos->rq_xmit); + INIT_LIST_HEAD(&req->rq_xmit2); + goto out; + } + } else if (req->rq_seqno_count == 0) { + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_task->tk_owner != task->tk_owner) + continue; + list_add_tail(&req->rq_xmit2, &pos->rq_xmit2); + INIT_LIST_HEAD(&req->rq_xmit); + goto out; + } + } + list_add_tail(&req->rq_xmit, &xprt->xmit_queue); + INIT_LIST_HEAD(&req->rq_xmit2); +out: + atomic_long_inc(&xprt->xmit_queuelen); + set_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); } - if (!xprt->ops->reserve_xprt(xprt, task)) - err = -EAGAIN; -out_unlock: - spin_unlock_bh(&xprt->transport_lock); - return err; } -void xprt_end_transmit(struct rpc_task *task) +/** + * xprt_request_dequeue_transmit_locked - remove a task from the transmission queue + * @task: pointer to rpc_task + * + * Remove a task from the transmission queue + * Caller must hold xprt->queue_lock + */ +static void +xprt_request_dequeue_transmit_locked(struct rpc_task *task) { - xprt_release_write(task->tk_rqstp->rq_xprt, task); + struct rpc_rqst *req = task->tk_rqstp; + + if (!test_and_clear_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + return; + if (!list_empty(&req->rq_xmit)) { + struct rpc_xprt *xprt = req->rq_xprt; + + if (list_is_first(&req->rq_xmit, &xprt->xmit_queue) && + xprt->ops->abort_send_request) + xprt->ops->abort_send_request(req); + + list_del(&req->rq_xmit); + if (!list_empty(&req->rq_xmit2)) { + struct rpc_rqst *next = list_first_entry(&req->rq_xmit2, + struct rpc_rqst, rq_xmit2); + list_del(&req->rq_xmit2); + list_add_tail(&next->rq_xmit, &next->rq_xprt->xmit_queue); + } + } else + list_del(&req->rq_xmit2); + atomic_long_dec(&req->rq_xprt->xmit_queuelen); + xdr_free_bvec(&req->rq_snd_buf); } /** - * xprt_transmit - send an RPC request on a transport - * @task: controlling RPC task + * xprt_request_dequeue_transmit - remove a task from the transmission queue + * @task: pointer to rpc_task * - * We have to copy the iovec because sendmsg fiddles with its contents. + * Remove a task from the transmission queue */ -void xprt_transmit(struct rpc_task *task) +static void +xprt_request_dequeue_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + spin_lock(&xprt->queue_lock); + xprt_request_dequeue_transmit_locked(task); + spin_unlock(&xprt->queue_lock); +} + +/** + * xprt_request_dequeue_xprt - remove a task from the transmit+receive queue + * @task: pointer to rpc_task + * + * Remove a task from the transmit and receive queues, and ensure that + * it is not pinned by the receive work item. + */ +void +xprt_request_dequeue_xprt(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) || + test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) || + xprt_is_pinned_rqst(req)) { + spin_lock(&xprt->queue_lock); + while (xprt_is_pinned_rqst(req)) { + set_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + xprt_wait_on_pinned_rqst(req); + spin_lock(&xprt->queue_lock); + clear_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate); + } + xprt_request_dequeue_transmit_locked(task); + xprt_request_dequeue_receive_locked(task); + spin_unlock(&xprt->queue_lock); + xdr_free_bvec(&req->rq_rcv_buf); + } +} + +/** + * xprt_request_prepare - prepare an encoded request for transport + * @req: pointer to rpc_rqst + * @buf: pointer to send/rcv xdr_buf + * + * Calls into the transport layer to do whatever is needed to prepare + * the request for transmission or receive. + * Returns error, or zero. + */ +static int +xprt_request_prepare(struct rpc_rqst *req, struct xdr_buf *buf) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + if (xprt->ops->prepare_request) + return xprt->ops->prepare_request(req, buf); + return 0; +} + +/** + * xprt_request_need_retransmit - Test if a task needs retransmission + * @task: pointer to rpc_task + * + * Test for whether a connection breakage requires the task to retransmit + */ +bool +xprt_request_need_retransmit(struct rpc_task *task) +{ + return xprt_request_retransmit_after_disconnect(task); +} + +/** + * xprt_prepare_transmit - reserve the transport before sending a request + * @task: RPC task about to send a request + * + */ +bool xprt_prepare_transmit(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - int status, numreqs; - dprintk("RPC: %5u xprt_transmit(%u)\n", task->tk_pid, req->rq_slen); + if (!xprt_lock_write(xprt, task)) { + /* Race breaker: someone may have transmitted us */ + if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + rpc_wake_up_queued_task_set_status(&xprt->sending, + task, 0); + return false; - if (!req->rq_reply_bytes_recvd) { - if (list_empty(&req->rq_list) && rpc_reply_expected(task)) { - /* - * Add to the list only if we're expecting a reply - */ - spin_lock_bh(&xprt->transport_lock); - /* Update the softirq receive buffer */ - memcpy(&req->rq_private_buf, &req->rq_rcv_buf, - sizeof(req->rq_private_buf)); - /* Add request to the receive list */ - list_add_tail(&req->rq_list, &xprt->recv); - spin_unlock_bh(&xprt->transport_lock); - xprt_reset_majortimeo(req); - /* Turn off autodisconnect */ - del_singleshot_timer_sync(&xprt->timer); + } + if (atomic_read(&xprt->swapper)) + /* This will be clear in __rpc_execute */ + current->flags |= PF_MEMALLOC; + return true; +} + +void xprt_end_transmit(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + xprt_inject_disconnect(xprt); + xprt_release_write(xprt, task); +} + +/** + * xprt_request_transmit - send an RPC request on a transport + * @req: pointer to request to transmit + * @snd_task: RPC task that owns the transport lock + * + * This performs the transmission of a single request. + * Note that if the request is not the same as snd_task, then it + * does need to be pinned. + * Returns '0' on success. + */ +static int +xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct rpc_task *task = req->rq_task; + unsigned int connect_cookie; + int is_retrans = RPC_WAS_SENT(task); + int status; + + if (test_bit(XPRT_CLOSE_WAIT, &xprt->state)) + return -ENOTCONN; + + if (!req->rq_bytes_sent) { + if (xprt_request_data_received(task)) { + status = 0; + goto out_dequeue; } - } else if (!req->rq_bytes_sent) - return; + /* Verify that our message lies in the RPCSEC_GSS window */ + if (rpcauth_xmit_need_reencode(task)) { + status = -EBADMSG; + goto out_dequeue; + } + if (RPC_SIGNALLED(task)) { + status = -ERESTARTSYS; + goto out_dequeue; + } + } - req->rq_connect_cookie = xprt->connect_cookie; - req->rq_xtime = ktime_get(); - status = xprt->ops->send_request(task); + /* + * Update req->rq_ntrans before transmitting to avoid races with + * xprt_update_rtt(), which needs to know that it is recording a + * reply to the first transmission. + */ + req->rq_ntrans++; + + trace_rpc_xdr_sendto(task, &req->rq_snd_buf); + connect_cookie = xprt->connect_cookie; + status = xprt->ops->send_request(req); if (status != 0) { - task->tk_status = status; - return; + req->rq_ntrans--; + trace_xprt_transmit(req, status); + return status; } - dprintk("RPC: %5u xmit complete\n", task->tk_pid); - task->tk_flags |= RPC_TASK_SENT; - spin_lock_bh(&xprt->transport_lock); + if (is_retrans) { + task->tk_client->cl_stats->rpcretrans++; + trace_xprt_retransmit(req); + } + + xprt_inject_disconnect(xprt); - xprt->ops->set_retrans_timeout(task); + task->tk_flags |= RPC_TASK_SENT; + spin_lock(&xprt->transport_lock); - numreqs = atomic_read(&xprt->num_reqs); - if (numreqs > xprt->stat.max_slots) - xprt->stat.max_slots = numreqs; xprt->stat.sends++; xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs; xprt->stat.bklog_u += xprt->backlog.qlen; xprt->stat.sending_u += xprt->sending.qlen; xprt->stat.pending_u += xprt->pending.qlen; + spin_unlock(&xprt->transport_lock); - /* Don't race with disconnect */ - if (!xprt_connected(xprt)) - task->tk_status = -ENOTCONN; - else if (!req->rq_reply_bytes_recvd && rpc_reply_expected(task)) { - /* - * Sleep on the pending queue since - * we're expecting a reply. - */ - rpc_sleep_on(&xprt->pending, task, xprt_timer); + req->rq_connect_cookie = connect_cookie; +out_dequeue: + trace_xprt_transmit(req, status); + xprt_request_dequeue_transmit(task); + rpc_wake_up_queued_task_set_status(&xprt->sending, task, status); + return status; +} + +/** + * xprt_transmit - send an RPC request on a transport + * @task: controlling RPC task + * + * Attempts to drain the transmit queue. On exit, either the transport + * signalled an error that needs to be handled before transmission can + * resume, or @task finished transmitting, and detected that it already + * received a reply. + */ +void +xprt_transmit(struct rpc_task *task) +{ + struct rpc_rqst *next, *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int status; + + spin_lock(&xprt->queue_lock); + for (;;) { + next = list_first_entry_or_null(&xprt->xmit_queue, + struct rpc_rqst, rq_xmit); + if (!next) + break; + xprt_pin_rqst(next); + spin_unlock(&xprt->queue_lock); + status = xprt_request_transmit(next, task); + if (status == -EBADMSG && next != req) + status = 0; + spin_lock(&xprt->queue_lock); + xprt_unpin_rqst(next); + if (status < 0) { + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + task->tk_status = status; + break; + } + /* Was @task transmitted, and has it received a reply? */ + if (xprt_request_data_received(task) && + !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + break; + cond_resched_lock(&xprt->queue_lock); } - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->queue_lock); +} + +static void xprt_complete_request_init(struct rpc_task *task) +{ + if (task->tk_rqstp) + xprt_request_init(task); } -static void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task) +void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task) { set_bit(XPRT_CONGESTED, &xprt->state); - rpc_sleep_on(&xprt->backlog, task, NULL); + rpc_sleep_on(&xprt->backlog, task, xprt_complete_request_init); +} +EXPORT_SYMBOL_GPL(xprt_add_backlog); + +static bool __xprt_set_rq(struct rpc_task *task, void *data) +{ + struct rpc_rqst *req = data; + + if (task->tk_rqstp == NULL) { + memset(req, 0, sizeof(*req)); /* mark unused */ + task->tk_rqstp = req; + return true; + } + return false; } -static void xprt_wake_up_backlog(struct rpc_xprt *xprt) +bool xprt_wake_up_backlog(struct rpc_xprt *xprt, struct rpc_rqst *req) { - if (rpc_wake_up_next(&xprt->backlog) == NULL) + if (rpc_wake_up_first(&xprt->backlog, __xprt_set_rq, req) == NULL) { clear_bit(XPRT_CONGESTED, &xprt->state); + return false; + } + return true; } +EXPORT_SYMBOL_GPL(xprt_wake_up_backlog); static bool xprt_throttle_congested(struct rpc_xprt *xprt, struct rpc_task *task) { @@ -968,7 +1693,7 @@ static bool xprt_throttle_congested(struct rpc_xprt *xprt, struct rpc_task *task goto out; spin_lock(&xprt->reserve_lock); if (test_bit(XPRT_CONGESTED, &xprt->state)) { - rpc_sleep_on(&xprt->backlog, task, NULL); + xprt_add_backlog(xprt, task); ret = true; } spin_unlock(&xprt->reserve_lock); @@ -976,16 +1701,19 @@ out: return ret; } -static struct rpc_rqst *xprt_dynamic_alloc_slot(struct rpc_xprt *xprt, gfp_t gfp_flags) +static struct rpc_rqst *xprt_dynamic_alloc_slot(struct rpc_xprt *xprt) { struct rpc_rqst *req = ERR_PTR(-EAGAIN); - if (!atomic_add_unless(&xprt->num_reqs, 1, xprt->max_reqs)) + if (xprt->num_reqs >= xprt->max_reqs) goto out; - req = kzalloc(sizeof(struct rpc_rqst), gfp_flags); + ++xprt->num_reqs; + spin_unlock(&xprt->reserve_lock); + req = kzalloc(sizeof(*req), rpc_task_gfp_mask()); + spin_lock(&xprt->reserve_lock); if (req != NULL) goto out; - atomic_dec(&xprt->num_reqs); + --xprt->num_reqs; req = ERR_PTR(-ENOMEM); out: return req; @@ -993,7 +1721,8 @@ out: static bool xprt_dynamic_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) { - if (atomic_add_unless(&xprt->num_reqs, -1, xprt->min_reqs)) { + if (xprt->num_reqs > xprt->min_reqs) { + --xprt->num_reqs; kfree(req); return true; } @@ -1010,7 +1739,7 @@ void xprt_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) list_del(&req->rq_list); goto out_init_req; } - req = xprt_dynamic_alloc_slot(xprt, GFP_NOWAIT|__GFP_NOWARN); + req = xprt_dynamic_alloc_slot(xprt); if (!IS_ERR(req)) goto out_init_req; switch (PTR_ERR(req)) { @@ -1022,43 +1751,33 @@ void xprt_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) case -EAGAIN: xprt_add_backlog(xprt, task); dprintk("RPC: waiting for request slot\n"); + fallthrough; default: task->tk_status = -EAGAIN; } spin_unlock(&xprt->reserve_lock); return; out_init_req: + xprt->stat.max_slots = max_t(unsigned int, xprt->stat.max_slots, + xprt->num_reqs); + spin_unlock(&xprt->reserve_lock); + task->tk_status = 0; task->tk_rqstp = req; - xprt_request_init(task, xprt); - spin_unlock(&xprt->reserve_lock); } EXPORT_SYMBOL_GPL(xprt_alloc_slot); -void xprt_lock_and_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) -{ - /* Note: grabbing the xprt_lock_write() ensures that we throttle - * new slot allocation if the transport is congested (i.e. when - * reconnecting a stream transport or when out of socket write - * buffer space). - */ - if (xprt_lock_write(xprt, task)) { - xprt_alloc_slot(xprt, task); - xprt_release_write(xprt, task); - } -} -EXPORT_SYMBOL_GPL(xprt_lock_and_alloc_slot); - -static void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) { spin_lock(&xprt->reserve_lock); - if (!xprt_dynamic_free_slot(xprt, req)) { + if (!xprt_wake_up_backlog(xprt, req) && + !xprt_dynamic_free_slot(xprt, req)) { memset(req, 0, sizeof(*req)); /* mark unused */ list_add(&req->rq_list, &xprt->free); } - xprt_wake_up_backlog(xprt); spin_unlock(&xprt->reserve_lock); } +EXPORT_SYMBOL_GPL(xprt_free_slot); static void xprt_free_all_slots(struct rpc_xprt *xprt) { @@ -1070,6 +1789,30 @@ static void xprt_free_all_slots(struct rpc_xprt *xprt) } } +static DEFINE_IDA(rpc_xprt_ids); + +void xprt_cleanup_ids(void) +{ + ida_destroy(&rpc_xprt_ids); +} + +static int xprt_alloc_id(struct rpc_xprt *xprt) +{ + int id; + + id = ida_alloc(&rpc_xprt_ids, GFP_KERNEL); + if (id < 0) + return id; + + xprt->id = id; + return 0; +} + +static void xprt_free_id(struct rpc_xprt *xprt) +{ + ida_free(&rpc_xprt_ids, xprt->id); +} + struct rpc_xprt *xprt_alloc(struct net *net, size_t size, unsigned int num_prealloc, unsigned int max_alloc) @@ -1082,22 +1825,18 @@ struct rpc_xprt *xprt_alloc(struct net *net, size_t size, if (xprt == NULL) goto out; + xprt_alloc_id(xprt); xprt_init(xprt, net); for (i = 0; i < num_prealloc; i++) { req = kzalloc(sizeof(struct rpc_rqst), GFP_KERNEL); if (!req) - break; + goto out_free; list_add(&req->rq_list, &xprt->free); } - if (i < num_prealloc) - goto out_free; - if (max_alloc > num_prealloc) - xprt->max_reqs = max_alloc; - else - xprt->max_reqs = num_prealloc; + xprt->max_reqs = max_t(unsigned int, max_alloc, num_prealloc); xprt->min_reqs = num_prealloc; - atomic_set(&xprt->num_reqs, num_prealloc); + xprt->num_reqs = num_prealloc; return xprt; @@ -1110,12 +1849,69 @@ EXPORT_SYMBOL_GPL(xprt_alloc); void xprt_free(struct rpc_xprt *xprt) { - put_net(xprt->xprt_net); + put_net_track(xprt->xprt_net, &xprt->ns_tracker); xprt_free_all_slots(xprt); - kfree(xprt); + xprt_free_id(xprt); + rpc_sysfs_xprt_destroy(xprt); + kfree_rcu(xprt, rcu); } EXPORT_SYMBOL_GPL(xprt_free); +static void +xprt_init_connect_cookie(struct rpc_rqst *req, struct rpc_xprt *xprt) +{ + req->rq_connect_cookie = xprt_connect_cookie(xprt) - 1; +} + +static __be32 +xprt_alloc_xid(struct rpc_xprt *xprt) +{ + __be32 xid; + + spin_lock(&xprt->reserve_lock); + xid = (__force __be32)xprt->xid++; + spin_unlock(&xprt->reserve_lock); + return xid; +} + +static void +xprt_init_xid(struct rpc_xprt *xprt) +{ + xprt->xid = get_random_u32(); +} + +static void +xprt_request_init(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req = task->tk_rqstp; + + req->rq_task = task; + req->rq_xprt = xprt; + req->rq_buffer = NULL; + req->rq_xid = xprt_alloc_xid(xprt); + xprt_init_connect_cookie(req, xprt); + req->rq_snd_buf.len = 0; + req->rq_snd_buf.buflen = 0; + req->rq_rcv_buf.len = 0; + req->rq_rcv_buf.buflen = 0; + req->rq_snd_buf.bvec = NULL; + req->rq_rcv_buf.bvec = NULL; + req->rq_release_snd_buf = NULL; + req->rq_seqno_count = 0; + xprt_init_majortimeo(task, req, task->tk_client->cl_timeout); + + trace_xprt_reserve(req); +} + +static void +xprt_do_reserve(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt->ops->alloc_slot(xprt, task); + if (task->tk_rqstp != NULL) + xprt_request_init(task); +} + /** * xprt_reserve - allocate an RPC request slot * @task: RPC task requesting a slot allocation @@ -1126,19 +1922,15 @@ EXPORT_SYMBOL_GPL(xprt_free); */ void xprt_reserve(struct rpc_task *task) { - struct rpc_xprt *xprt; + struct rpc_xprt *xprt = task->tk_xprt; task->tk_status = 0; if (task->tk_rqstp != NULL) return; - task->tk_timeout = 0; task->tk_status = -EAGAIN; - rcu_read_lock(); - xprt = rcu_dereference(task->tk_client->cl_xprt); if (!xprt_throttle_congested(xprt, task)) - xprt->ops->alloc_slot(xprt, task); - rcu_read_unlock(); + xprt_do_reserve(xprt, task); } /** @@ -1152,44 +1944,14 @@ void xprt_reserve(struct rpc_task *task) */ void xprt_retry_reserve(struct rpc_task *task) { - struct rpc_xprt *xprt; + struct rpc_xprt *xprt = task->tk_xprt; task->tk_status = 0; if (task->tk_rqstp != NULL) return; - task->tk_timeout = 0; task->tk_status = -EAGAIN; - rcu_read_lock(); - xprt = rcu_dereference(task->tk_client->cl_xprt); - xprt->ops->alloc_slot(xprt, task); - rcu_read_unlock(); -} - -static inline __be32 xprt_alloc_xid(struct rpc_xprt *xprt) -{ - return (__force __be32)xprt->xid++; -} - -static inline void xprt_init_xid(struct rpc_xprt *xprt) -{ - xprt->xid = net_random(); -} - -static void xprt_request_init(struct rpc_task *task, struct rpc_xprt *xprt) -{ - struct rpc_rqst *req = task->tk_rqstp; - - INIT_LIST_HEAD(&req->rq_list); - req->rq_timeout = task->tk_client->cl_timeout->to_initval; - req->rq_task = task; - req->rq_xprt = xprt; - req->rq_buffer = NULL; - req->rq_xid = xprt_alloc_xid(xprt); - req->rq_release_snd_buf = NULL; - xprt_reset_majortimeo(req); - dprintk("RPC: %5u reserved req %p xid %08x\n", task->tk_pid, - req, ntohl(req->rq_xid)); + xprt_do_reserve(xprt, task); } /** @@ -1204,59 +1966,76 @@ void xprt_release(struct rpc_task *task) if (req == NULL) { if (task->tk_client) { - rcu_read_lock(); - xprt = rcu_dereference(task->tk_client->cl_xprt); - if (xprt->snd_task == task) - xprt_release_write(xprt, task); - rcu_read_unlock(); + xprt = task->tk_xprt; + xprt_release_write(xprt, task); } return; } xprt = req->rq_xprt; - if (task->tk_ops->rpc_count_stats != NULL) - task->tk_ops->rpc_count_stats(task, task->tk_calldata); - else if (task->tk_client) - rpc_count_iostats(task, task->tk_client->cl_metrics); - spin_lock_bh(&xprt->transport_lock); + xprt_request_dequeue_xprt(task); + spin_lock(&xprt->transport_lock); xprt->ops->release_xprt(xprt, task); if (xprt->ops->release_request) xprt->ops->release_request(task); - if (!list_empty(&req->rq_list)) - list_del(&req->rq_list); - xprt->last_used = jiffies; - if (list_empty(&xprt->recv) && xprt_has_timer(xprt)) - mod_timer(&xprt->timer, - xprt->last_used + xprt->idle_timeout); - spin_unlock_bh(&xprt->transport_lock); + xprt_schedule_autodisconnect(xprt); + spin_unlock(&xprt->transport_lock); if (req->rq_buffer) - xprt->ops->buf_free(req->rq_buffer); + xprt->ops->buf_free(task); if (req->rq_cred != NULL) put_rpccred(req->rq_cred); - task->tk_rqstp = NULL; if (req->rq_release_snd_buf) req->rq_release_snd_buf(req); - dprintk("RPC: %5u release request %p\n", task->tk_pid, req); + task->tk_rqstp = NULL; if (likely(!bc_prealloc(req))) - xprt_free_slot(xprt, req); + xprt->ops->free_slot(xprt, req); else xprt_free_bc_request(req); } +#ifdef CONFIG_SUNRPC_BACKCHANNEL +void +xprt_init_bc_request(struct rpc_rqst *req, struct rpc_task *task, + const struct rpc_timeout *to) +{ + struct xdr_buf *xbufp = &req->rq_snd_buf; + + task->tk_rqstp = req; + req->rq_task = task; + xprt_init_connect_cookie(req, req->rq_xprt); + /* + * Set up the xdr_buf length. + * This also indicates that the buffer is XDR encoded already. + */ + xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + + xbufp->tail[0].iov_len; + /* + * Backchannel Replies are sent with !RPC_TASK_SOFT and + * RPC_TASK_NO_RETRANS_TIMEOUT. The major timeout setting + * affects only how long each Reply waits to be sent when + * a transport connection cannot be established. + */ + xprt_init_majortimeo(task, req, to); +} +#endif + static void xprt_init(struct rpc_xprt *xprt, struct net *net) { - atomic_set(&xprt->count, 1); + kref_init(&xprt->kref); spin_lock_init(&xprt->transport_lock); spin_lock_init(&xprt->reserve_lock); + spin_lock_init(&xprt->queue_lock); INIT_LIST_HEAD(&xprt->free); - INIT_LIST_HEAD(&xprt->recv); + xprt->recv_queue = RB_ROOT; + INIT_LIST_HEAD(&xprt->xmit_queue); #if defined(CONFIG_SUNRPC_BACKCHANNEL) spin_lock_init(&xprt->bc_pa_lock); INIT_LIST_HEAD(&xprt->bc_pa_list); #endif /* CONFIG_SUNRPC_BACKCHANNEL */ + INIT_LIST_HEAD(&xprt->xprt_switch); xprt->last_used = jiffies; xprt->cwnd = RPC_INITCWND; @@ -1264,12 +2043,12 @@ static void xprt_init(struct rpc_xprt *xprt, struct net *net) rpc_init_wait_queue(&xprt->binding, "xprt_binding"); rpc_init_wait_queue(&xprt->pending, "xprt_pending"); - rpc_init_priority_wait_queue(&xprt->sending, "xprt_sending"); + rpc_init_wait_queue(&xprt->sending, "xprt_sending"); rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); xprt_init_xid(xprt); - xprt->xprt_net = get_net(net); + xprt->xprt_net = get_net_track(net, &xprt->ns_tracker, GFP_KERNEL); } /** @@ -1280,34 +2059,26 @@ static void xprt_init(struct rpc_xprt *xprt, struct net *net) struct rpc_xprt *xprt_create_transport(struct xprt_create *args) { struct rpc_xprt *xprt; - struct xprt_class *t; + const struct xprt_class *t; - spin_lock(&xprt_list_lock); - list_for_each_entry(t, &xprt_list, list) { - if (t->ident == args->ident) { - spin_unlock(&xprt_list_lock); - goto found; - } + t = xprt_class_find_by_ident(args->ident); + if (!t) { + dprintk("RPC: transport (%d) not supported\n", args->ident); + return ERR_PTR(-EIO); } - spin_unlock(&xprt_list_lock); - printk(KERN_ERR "RPC: transport (%d) not supported\n", args->ident); - return ERR_PTR(-EIO); -found: xprt = t->setup(args); - if (IS_ERR(xprt)) { - dprintk("RPC: xprt_create_transport: failed, %ld\n", - -PTR_ERR(xprt)); + xprt_class_release(t); + + if (IS_ERR(xprt)) goto out; - } if (args->flags & XPRT_CREATE_NO_IDLE_TIMEOUT) xprt->idle_timeout = 0; INIT_WORK(&xprt->task_cleanup, xprt_autoclose); if (xprt_has_timer(xprt)) - setup_timer(&xprt->timer, xprt_init_autodisconnect, - (unsigned long)xprt); + timer_setup(&xprt->timer, xprt_init_autodisconnect, 0); else - init_timer(&xprt->timer); + timer_setup(&xprt->timer, NULL, 0); if (strlen(args->servername) > RPC_MAXNETNAMELEN) { xprt_destroy(xprt); @@ -1319,43 +2090,69 @@ found: return ERR_PTR(-ENOMEM); } - dprintk("RPC: created transport %p with %u slots\n", xprt, - xprt->max_reqs); + rpc_xprt_debugfs_register(xprt); + + trace_xprt_create(xprt); out: return xprt; } -/** - * xprt_destroy - destroy an RPC transport, killing off all requests. - * @xprt: transport to destroy - * - */ -static void xprt_destroy(struct rpc_xprt *xprt) +static void xprt_destroy_cb(struct work_struct *work) { - dprintk("RPC: destroying transport %p\n", xprt); - del_timer_sync(&xprt->timer); + struct rpc_xprt *xprt = + container_of(work, struct rpc_xprt, task_cleanup); + + trace_xprt_destroy(xprt); + rpc_xprt_debugfs_unregister(xprt); rpc_destroy_wait_queue(&xprt->binding); rpc_destroy_wait_queue(&xprt->pending); rpc_destroy_wait_queue(&xprt->sending); rpc_destroy_wait_queue(&xprt->backlog); - cancel_work_sync(&xprt->task_cleanup); kfree(xprt->servername); /* + * Destroy any existing back channel + */ + xprt_destroy_backchannel(xprt, UINT_MAX); + + /* * Tear down transport state and free the rpc_xprt */ xprt->ops->destroy(xprt); } /** - * xprt_put - release a reference to an RPC transport. - * @xprt: pointer to the transport + * xprt_destroy - destroy an RPC transport, killing off all requests. + * @xprt: transport to destroy * */ -void xprt_put(struct rpc_xprt *xprt) +static void xprt_destroy(struct rpc_xprt *xprt) { - if (atomic_dec_and_test(&xprt->count)) - xprt_destroy(xprt); + /* + * Exclude transport connect/disconnect handlers and autoclose + */ + wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_UNINTERRUPTIBLE); + + /* + * xprt_schedule_autodisconnect() can run after XPRT_LOCKED + * is cleared. We use ->transport_lock to ensure the mod_timer() + * can only run *before* del_time_sync(), never after. + */ + spin_lock(&xprt->transport_lock); + timer_delete_sync(&xprt->timer); + spin_unlock(&xprt->transport_lock); + + /* + * Destroy sockets etc from the system workqueue so they can + * safely flush receive work running on rpciod. + */ + INIT_WORK(&xprt->task_cleanup, xprt_destroy_cb); + schedule_work(&xprt->task_cleanup); +} + +static void xprt_destroy_kref(struct kref *kref) +{ + xprt_destroy(container_of(kref, struct rpc_xprt, kref)); } /** @@ -1365,7 +2162,52 @@ void xprt_put(struct rpc_xprt *xprt) */ struct rpc_xprt *xprt_get(struct rpc_xprt *xprt) { - if (atomic_inc_not_zero(&xprt->count)) + if (xprt != NULL && kref_get_unless_zero(&xprt->kref)) return xprt; return NULL; } +EXPORT_SYMBOL_GPL(xprt_get); + +/** + * xprt_put - release a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +void xprt_put(struct rpc_xprt *xprt) +{ + if (xprt != NULL) + kref_put(&xprt->kref, xprt_destroy_kref); +} +EXPORT_SYMBOL_GPL(xprt_put); + +void xprt_set_offline_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps) +{ + if (!test_and_set_bit(XPRT_OFFLINE, &xprt->state)) { + spin_lock(&xps->xps_lock); + xps->xps_nactive--; + spin_unlock(&xps->xps_lock); + } +} + +void xprt_set_online_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps) +{ + if (test_and_clear_bit(XPRT_OFFLINE, &xprt->state)) { + spin_lock(&xps->xps_lock); + xps->xps_nactive++; + spin_unlock(&xps->xps_lock); + } +} + +void xprt_delete_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps) +{ + if (test_and_set_bit(XPRT_REMOVE, &xprt->state)) + return; + + xprt_force_disconnect(xprt); + if (!test_bit(XPRT_CONNECTED, &xprt->state)) + return; + + if (!xprt->sending.qlen && !xprt->pending.qlen && + !xprt->backlog.qlen && !atomic_long_read(&xprt->queuelen)) + rpc_xprt_switch_remove_xprt(xps, xprt, true); +} diff --git a/net/sunrpc/xprtmultipath.c b/net/sunrpc/xprtmultipath.c new file mode 100644 index 000000000000..4c5e08b0aa64 --- /dev/null +++ b/net/sunrpc/xprtmultipath.c @@ -0,0 +1,672 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Multipath support for RPC + * + * Copyright (c) 2015, 2016, Primary Data, Inc. All rights reserved. + * + * Trond Myklebust <trond.myklebust@primarydata.com> + * + */ +#include <linux/atomic.h> +#include <linux/types.h> +#include <linux/kref.h> +#include <linux/list.h> +#include <linux/rcupdate.h> +#include <linux/rculist.h> +#include <linux/slab.h> +#include <linux/spinlock.h> +#include <linux/sunrpc/xprt.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/xprtmultipath.h> + +#include "sysfs.h" + +typedef struct rpc_xprt *(*xprt_switch_find_xprt_t)(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur); + +static const struct rpc_xprt_iter_ops rpc_xprt_iter_singular; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_listall; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_listoffline; + +static void xprt_switch_add_xprt_locked(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (unlikely(xprt_get(xprt) == NULL)) + return; + list_add_tail_rcu(&xprt->xprt_switch, &xps->xps_xprt_list); + smp_wmb(); + if (xps->xps_nxprts == 0) + xps->xps_net = xprt->xprt_net; + xps->xps_nxprts++; + xps->xps_nactive++; +} + +/** + * rpc_xprt_switch_add_xprt - Add a new rpc_xprt to an rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * + * Adds xprt to the end of the list of struct rpc_xprt in xps. + */ +void rpc_xprt_switch_add_xprt(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (xprt == NULL) + return; + spin_lock(&xps->xps_lock); + if (xps->xps_net == xprt->xprt_net || xps->xps_net == NULL) + xprt_switch_add_xprt_locked(xps, xprt); + spin_unlock(&xps->xps_lock); + rpc_sysfs_xprt_setup(xps, xprt, GFP_KERNEL); +} + +static void xprt_switch_remove_xprt_locked(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, bool offline) +{ + if (unlikely(xprt == NULL)) + return; + if (!test_bit(XPRT_OFFLINE, &xprt->state) && offline) + xps->xps_nactive--; + xps->xps_nxprts--; + if (xps->xps_nxprts == 0) + xps->xps_net = NULL; + smp_wmb(); + list_del_rcu(&xprt->xprt_switch); +} + +/** + * rpc_xprt_switch_remove_xprt - Removes an rpc_xprt from a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * @offline: indicates if the xprt that's being removed is in an offline state + * + * Removes xprt from the list of struct rpc_xprt in xps. + */ +void rpc_xprt_switch_remove_xprt(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, bool offline) +{ + spin_lock(&xps->xps_lock); + xprt_switch_remove_xprt_locked(xps, xprt, offline); + spin_unlock(&xps->xps_lock); + xprt_put(xprt); +} + +/** + * rpc_xprt_switch_get_main_xprt - Get the 'main' xprt for an xprt switch. + * @xps: pointer to struct rpc_xprt_switch. + */ +struct rpc_xprt *rpc_xprt_switch_get_main_xprt(struct rpc_xprt_switch *xps) +{ + struct rpc_xprt_iter xpi; + struct rpc_xprt *xprt; + + xprt_iter_init_listall(&xpi, xps); + + xprt = xprt_iter_get_next(&xpi); + while (xprt && !xprt->main) { + xprt_put(xprt); + xprt = xprt_iter_get_next(&xpi); + } + + xprt_iter_destroy(&xpi); + return xprt; +} + +static DEFINE_IDA(rpc_xprtswitch_ids); + +void xprt_multipath_cleanup_ids(void) +{ + ida_destroy(&rpc_xprtswitch_ids); +} + +static int xprt_switch_alloc_id(struct rpc_xprt_switch *xps, gfp_t gfp_flags) +{ + int id; + + id = ida_alloc(&rpc_xprtswitch_ids, gfp_flags); + if (id < 0) + return id; + + xps->xps_id = id; + return 0; +} + +static void xprt_switch_free_id(struct rpc_xprt_switch *xps) +{ + ida_free(&rpc_xprtswitch_ids, xps->xps_id); +} + +/** + * xprt_switch_alloc - Allocate a new struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * @gfp_flags: allocation flags + * + * On success, returns an initialised struct rpc_xprt_switch, containing + * the entry xprt. Returns NULL on failure. + */ +struct rpc_xprt_switch *xprt_switch_alloc(struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_xprt_switch *xps; + + xps = kmalloc(sizeof(*xps), gfp_flags); + if (xps != NULL) { + spin_lock_init(&xps->xps_lock); + kref_init(&xps->xps_kref); + xprt_switch_alloc_id(xps, gfp_flags); + xps->xps_nxprts = xps->xps_nactive = 0; + atomic_long_set(&xps->xps_queuelen, 0); + xps->xps_net = NULL; + INIT_LIST_HEAD(&xps->xps_xprt_list); + xps->xps_iter_ops = &rpc_xprt_iter_singular; + rpc_sysfs_xprt_switch_setup(xps, xprt, gfp_flags); + xprt_switch_add_xprt_locked(xps, xprt); + xps->xps_nunique_destaddr_xprts = 1; + rpc_sysfs_xprt_setup(xps, xprt, gfp_flags); + } + + return xps; +} + +static void xprt_switch_free_entries(struct rpc_xprt_switch *xps) +{ + spin_lock(&xps->xps_lock); + while (!list_empty(&xps->xps_xprt_list)) { + struct rpc_xprt *xprt; + + xprt = list_first_entry(&xps->xps_xprt_list, + struct rpc_xprt, xprt_switch); + xprt_switch_remove_xprt_locked(xps, xprt, true); + spin_unlock(&xps->xps_lock); + xprt_put(xprt); + spin_lock(&xps->xps_lock); + } + spin_unlock(&xps->xps_lock); +} + +static void xprt_switch_free(struct kref *kref) +{ + struct rpc_xprt_switch *xps = container_of(kref, + struct rpc_xprt_switch, xps_kref); + + xprt_switch_free_entries(xps); + rpc_sysfs_xprt_switch_destroy(xps); + xprt_switch_free_id(xps); + kfree_rcu(xps, xps_rcu); +} + +/** + * xprt_switch_get - Return a reference to a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Returns a reference to xps unless the refcount is already zero. + */ +struct rpc_xprt_switch *xprt_switch_get(struct rpc_xprt_switch *xps) +{ + if (xps != NULL && kref_get_unless_zero(&xps->xps_kref)) + return xps; + return NULL; +} + +/** + * xprt_switch_put - Release a reference to a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Release the reference to xps, and free it once the refcount is zero. + */ +void xprt_switch_put(struct rpc_xprt_switch *xps) +{ + if (xps != NULL) + kref_put(&xps->xps_kref, xprt_switch_free); +} + +/** + * rpc_xprt_switch_set_roundrobin - Set a round-robin policy on rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Sets a round-robin default policy for iterators acting on xps. + */ +void rpc_xprt_switch_set_roundrobin(struct rpc_xprt_switch *xps) +{ + if (READ_ONCE(xps->xps_iter_ops) != &rpc_xprt_iter_roundrobin) + WRITE_ONCE(xps->xps_iter_ops, &rpc_xprt_iter_roundrobin); +} + +static +const struct rpc_xprt_iter_ops *xprt_iter_ops(const struct rpc_xprt_iter *xpi) +{ + if (xpi->xpi_ops != NULL) + return xpi->xpi_ops; + return rcu_dereference(xpi->xpi_xpswitch)->xps_iter_ops; +} + +static +void xprt_iter_no_rewind(struct rpc_xprt_iter *xpi) +{ +} + +static +void xprt_iter_default_rewind(struct rpc_xprt_iter *xpi) +{ + WRITE_ONCE(xpi->xpi_cursor, NULL); +} + +static +bool xprt_is_active(const struct rpc_xprt *xprt) +{ + return (kref_read(&xprt->kref) != 0 && + !test_bit(XPRT_OFFLINE, &xprt->state)); +} + +static +struct rpc_xprt *xprt_switch_find_first_entry(struct list_head *head) +{ + struct rpc_xprt *pos; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (xprt_is_active(pos)) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_find_first_entry_offline(struct list_head *head) +{ + struct rpc_xprt *pos; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (!xprt_is_active(pos)) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_iter_first_entry(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + + if (xps == NULL) + return NULL; + return xprt_switch_find_first_entry(&xps->xps_xprt_list); +} + +static +struct rpc_xprt *_xprt_switch_find_current_entry(struct list_head *head, + const struct rpc_xprt *cur, + bool find_active) +{ + struct rpc_xprt *pos; + bool found = false; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (cur == pos) + found = true; + if (found && ((find_active && xprt_is_active(pos)) || + (!find_active && !xprt_is_active(pos)))) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_find_current_entry(struct list_head *head, + const struct rpc_xprt *cur) +{ + return _xprt_switch_find_current_entry(head, cur, true); +} + +static +struct rpc_xprt * _xprt_iter_current_entry(struct rpc_xprt_iter *xpi, + struct rpc_xprt *first_entry(struct list_head *head), + struct rpc_xprt *current_entry(struct list_head *head, + const struct rpc_xprt *cur)) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + struct list_head *head; + + if (xps == NULL) + return NULL; + head = &xps->xps_xprt_list; + if (xpi->xpi_cursor == NULL || xps->xps_nxprts < 2) + return first_entry(head); + return current_entry(head, xpi->xpi_cursor); +} + +static +struct rpc_xprt *xprt_iter_current_entry(struct rpc_xprt_iter *xpi) +{ + return _xprt_iter_current_entry(xpi, xprt_switch_find_first_entry, + xprt_switch_find_current_entry); +} + +static +struct rpc_xprt *xprt_switch_find_current_entry_offline(struct list_head *head, + const struct rpc_xprt *cur) +{ + return _xprt_switch_find_current_entry(head, cur, false); +} + +static +struct rpc_xprt *xprt_iter_current_entry_offline(struct rpc_xprt_iter *xpi) +{ + return _xprt_iter_current_entry(xpi, + xprt_switch_find_first_entry_offline, + xprt_switch_find_current_entry_offline); +} + +static +bool __rpc_xprt_switch_has_addr(struct rpc_xprt_switch *xps, + const struct sockaddr *sap) +{ + struct list_head *head; + struct rpc_xprt *pos; + + if (xps == NULL || sap == NULL) + return false; + + head = &xps->xps_xprt_list; + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (rpc_cmp_addr_port(sap, (struct sockaddr *)&pos->addr)) { + pr_info("RPC: addr %s already in xprt switch\n", + pos->address_strings[RPC_DISPLAY_ADDR]); + return true; + } + } + return false; +} + +bool rpc_xprt_switch_has_addr(struct rpc_xprt_switch *xps, + const struct sockaddr *sap) +{ + bool res; + + rcu_read_lock(); + res = __rpc_xprt_switch_has_addr(xps, sap); + rcu_read_unlock(); + + return res; +} + +static +struct rpc_xprt *xprt_switch_find_next_entry(struct list_head *head, + const struct rpc_xprt *cur, bool check_active) +{ + struct rpc_xprt *pos, *prev = NULL; + bool found = false; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (cur == prev) + found = true; + /* for request to return active transports return only + * active, for request to return offline transports + * return only offline + */ + if (found && ((check_active && xprt_is_active(pos)) || + (!check_active && !xprt_is_active(pos)))) + return pos; + prev = pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_set_next_cursor(struct rpc_xprt_switch *xps, + struct rpc_xprt **cursor, + xprt_switch_find_xprt_t find_next) +{ + struct rpc_xprt *pos, *old; + + old = smp_load_acquire(cursor); + pos = find_next(xps, old); + smp_store_release(cursor, pos); + return pos; +} + +static +struct rpc_xprt *xprt_iter_next_entry_multiple(struct rpc_xprt_iter *xpi, + xprt_switch_find_xprt_t find_next) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + + if (xps == NULL) + return NULL; + return xprt_switch_set_next_cursor(xps, &xpi->xpi_cursor, find_next); +} + +static +struct rpc_xprt *__xprt_switch_find_next_entry_roundrobin(struct list_head *head, + const struct rpc_xprt *cur) +{ + struct rpc_xprt *ret; + + ret = xprt_switch_find_next_entry(head, cur, true); + if (ret != NULL) + return ret; + return xprt_switch_find_first_entry(head); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_roundrobin(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur) +{ + struct list_head *head = &xps->xps_xprt_list; + struct rpc_xprt *xprt; + unsigned int nactive; + + for (;;) { + unsigned long xprt_queuelen, xps_queuelen; + + xprt = __xprt_switch_find_next_entry_roundrobin(head, cur); + if (!xprt) + break; + xprt_queuelen = atomic_long_read(&xprt->queuelen); + xps_queuelen = atomic_long_read(&xps->xps_queuelen); + nactive = READ_ONCE(xps->xps_nactive); + /* Exit loop if xprt_queuelen <= average queue length */ + if (xprt_queuelen * nactive <= xps_queuelen) + break; + cur = xprt; + } + return xprt; +} + +static +struct rpc_xprt *xprt_iter_next_entry_roundrobin(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_roundrobin); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_all(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur) +{ + return xprt_switch_find_next_entry(&xps->xps_xprt_list, cur, true); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_offline(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur) +{ + return xprt_switch_find_next_entry(&xps->xps_xprt_list, cur, false); +} + +static +struct rpc_xprt *xprt_iter_next_entry_all(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_all); +} + +static +struct rpc_xprt *xprt_iter_next_entry_offline(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_offline); +} + +/* + * xprt_iter_rewind - Resets the xprt iterator + * @xpi: pointer to rpc_xprt_iter + * + * Resets xpi to ensure that it points to the first entry in the list + * of transports. + */ +void xprt_iter_rewind(struct rpc_xprt_iter *xpi) +{ + rcu_read_lock(); + xprt_iter_ops(xpi)->xpi_rewind(xpi); + rcu_read_unlock(); +} + +static void __xprt_iter_init(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps, + const struct rpc_xprt_iter_ops *ops) +{ + rcu_assign_pointer(xpi->xpi_xpswitch, xprt_switch_get(xps)); + xpi->xpi_cursor = NULL; + xpi->xpi_ops = ops; +} + +/** + * xprt_iter_init - Initialise an xprt iterator + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to rpc_xprt_switch + * + * Initialises the iterator to use the default iterator ops + * as set in xps. This function is mainly intended for internal + * use in the rpc_client. + */ +void xprt_iter_init(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, NULL); +} + +/** + * xprt_iter_init_listall - Initialise an xprt iterator + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to rpc_xprt_switch + * + * Initialises the iterator to iterate once through the entire list + * of entries in xps. + */ +void xprt_iter_init_listall(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listall); +} + +void xprt_iter_init_listoffline(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listoffline); +} + +/** + * xprt_iter_xchg_switch - Atomically swap out the rpc_xprt_switch + * @xpi: pointer to rpc_xprt_iter + * @newswitch: pointer to a new rpc_xprt_switch or NULL + * + * Swaps out the existing xpi->xpi_xpswitch with a new value. + */ +struct rpc_xprt_switch *xprt_iter_xchg_switch(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *newswitch) +{ + struct rpc_xprt_switch __rcu *oldswitch; + + /* Atomically swap out the old xpswitch */ + oldswitch = xchg(&xpi->xpi_xpswitch, RCU_INITIALIZER(newswitch)); + if (newswitch != NULL) + xprt_iter_rewind(xpi); + return rcu_dereference_protected(oldswitch, true); +} + +/** + * xprt_iter_destroy - Destroys the xprt iterator + * @xpi: pointer to rpc_xprt_iter + */ +void xprt_iter_destroy(struct rpc_xprt_iter *xpi) +{ + xprt_switch_put(xprt_iter_xchg_switch(xpi, NULL)); +} + +/** + * xprt_iter_xprt - Returns the rpc_xprt pointed to by the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a pointer to the struct rpc_xprt that is currently + * pointed to by the cursor. + * Caller must be holding rcu_read_lock(). + */ +struct rpc_xprt *xprt_iter_xprt(struct rpc_xprt_iter *xpi) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + return xprt_iter_ops(xpi)->xpi_xprt(xpi); +} + +static +struct rpc_xprt *xprt_iter_get_helper(struct rpc_xprt_iter *xpi, + struct rpc_xprt *(*fn)(struct rpc_xprt_iter *)) +{ + struct rpc_xprt *ret; + + do { + ret = fn(xpi); + if (ret == NULL) + break; + ret = xprt_get(ret); + } while (ret == NULL); + return ret; +} + +/** + * xprt_iter_get_next - Returns the next rpc_xprt following the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a reference to the struct rpc_xprt that immediately follows the + * entry pointed to by the cursor. + */ +struct rpc_xprt *xprt_iter_get_next(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_next); + rcu_read_unlock(); + return xprt; +} + +/* Policy for always returning the first entry in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_singular = { + .xpi_rewind = xprt_iter_no_rewind, + .xpi_xprt = xprt_iter_first_entry, + .xpi_next = xprt_iter_first_entry, +}; + +/* Policy for round-robin iteration of entries in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry, + .xpi_next = xprt_iter_next_entry_roundrobin, +}; + +/* Policy for once-through iteration of entries in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_listall = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry, + .xpi_next = xprt_iter_next_entry_all, +}; + +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_listoffline = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry_offline, + .xpi_next = xprt_iter_next_entry_offline, +}; diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile index 5a8f268bdd30..3232aa23cdb4 100644 --- a/net/sunrpc/xprtrdma/Makefile +++ b/net/sunrpc/xprtrdma/Makefile @@ -1,8 +1,8 @@ -obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma.o +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += rpcrdma.o -xprtrdma-y := transport.o rpc_rdma.o verbs.o - -obj-$(CONFIG_SUNRPC_XPRT_RDMA) += svcrdma.o - -svcrdma-y := svc_rdma.o svc_rdma_transport.o \ - svc_rdma_marshal.o svc_rdma_sendto.o svc_rdma_recvfrom.o +rpcrdma-y := transport.o rpc_rdma.o verbs.o frwr_ops.o ib_client.o \ + svc_rdma.o svc_rdma_backchannel.o svc_rdma_transport.o \ + svc_rdma_sendto.o svc_rdma_recvfrom.o svc_rdma_rw.o \ + svc_rdma_pcl.o module.o +rpcrdma-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel.o diff --git a/net/sunrpc/xprtrdma/backchannel.c b/net/sunrpc/xprtrdma/backchannel.c new file mode 100644 index 000000000000..8c817e755262 --- /dev/null +++ b/net/sunrpc/xprtrdma/backchannel.c @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015-2020, Oracle and/or its affiliates. + * + * Support for reverse-direction RPCs on RPC/RDMA. + */ + +#include <linux/sunrpc/xprt.h> +#include <linux/sunrpc/svc.h> +#include <linux/sunrpc/svc_xprt.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +#undef RPCRDMA_BACKCHANNEL_DEBUG + +/** + * xprt_rdma_bc_setup - Pre-allocate resources for handling backchannel requests + * @xprt: transport associated with these backchannel resources + * @reqs: number of concurrent incoming requests to expect + * + * Returns 0 on success; otherwise a negative errno + */ +int xprt_rdma_bc_setup(struct rpc_xprt *xprt, unsigned int reqs) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + r_xprt->rx_buf.rb_bc_srv_max_requests = RPCRDMA_BACKWARD_WRS >> 1; + trace_xprtrdma_cb_setup(r_xprt, reqs); + return 0; +} + +/** + * xprt_rdma_bc_maxpayload - Return maximum backchannel message size + * @xprt: transport + * + * Returns maximum size, in bytes, of a backchannel message + */ +size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_ep *ep = r_xprt->rx_ep; + size_t maxmsg; + + maxmsg = min_t(unsigned int, ep->re_inline_send, ep->re_inline_recv); + maxmsg = min_t(unsigned int, maxmsg, PAGE_SIZE); + return maxmsg - RPCRDMA_HDRLEN_MIN; +} + +unsigned int xprt_rdma_bc_max_slots(struct rpc_xprt *xprt) +{ + return RPCRDMA_BACKWARD_WRS >> 1; +} + +static int rpcrdma_bc_marshal_reply(struct rpc_rqst *rqst) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + __be32 *p; + + rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); + xdr_init_encode(&req->rl_stream, &req->rl_hdrbuf, + rdmab_data(req->rl_rdmabuf), rqst); + + p = xdr_reserve_space(&req->rl_stream, 28); + if (unlikely(!p)) + return -EIO; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_srv_max_requests); + *p++ = rdma_msg; + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + + if (rpcrdma_prepare_send_sges(r_xprt, req, RPCRDMA_HDRLEN_MIN, + &rqst->rq_snd_buf, rpcrdma_noch_pullup)) + return -EIO; + + trace_xprtrdma_cb_reply(r_xprt, rqst); + return 0; +} + +/** + * xprt_rdma_bc_send_reply - marshal and send a backchannel reply + * @rqst: RPC rqst with a backchannel RPC reply in rq_snd_buf + * + * Caller holds the transport's write lock. + * + * Returns: + * %0 if the RPC message has been sent + * %-ENOTCONN if the caller should reconnect and call again + * %-EIO if a permanent error occurred and the request was not + * sent. Do not try to send this message again. + */ +int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + int rc; + + if (!xprt_connected(xprt)) + return -ENOTCONN; + + if (!xprt_request_get_cong(xprt, rqst)) + return -EBADSLT; + + rc = rpcrdma_bc_marshal_reply(rqst); + if (rc < 0) + goto failed_marshal; + + if (frwr_send(r_xprt, req)) + goto drop_connection; + return 0; + +failed_marshal: + if (rc != -ENOTCONN) + return rc; +drop_connection: + xprt_rdma_close(xprt); + return -ENOTCONN; +} + +/** + * xprt_rdma_bc_destroy - Release resources for handling backchannel requests + * @xprt: transport associated with these backchannel resources + * @reqs: number of incoming requests to destroy; ignored + */ +void xprt_rdma_bc_destroy(struct rpc_xprt *xprt, unsigned int reqs) +{ + struct rpc_rqst *rqst, *tmp; + + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry_safe(rqst, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { + list_del(&rqst->rq_bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + + rpcrdma_req_destroy(rpcr_to_rdmar(rqst)); + + spin_lock(&xprt->bc_pa_lock); + } + spin_unlock(&xprt->bc_pa_lock); +} + +/** + * xprt_rdma_bc_free_rqst - Release a backchannel rqst + * @rqst: request to release + */ +void xprt_rdma_bc_free_rqst(struct rpc_rqst *rqst) +{ + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpcrdma_rep *rep = req->rl_reply; + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + rpcrdma_rep_put(&r_xprt->rx_buf, rep); + req->rl_reply = NULL; + + spin_lock(&xprt->bc_pa_lock); + list_add_tail(&rqst->rq_bc_pa_list, &xprt->bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + xprt_put(xprt); +} + +static struct rpc_rqst *rpcrdma_bc_rqst_get(struct rpcrdma_xprt *r_xprt) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + size_t size; + + spin_lock(&xprt->bc_pa_lock); + rqst = list_first_entry_or_null(&xprt->bc_pa_list, struct rpc_rqst, + rq_bc_pa_list); + if (!rqst) + goto create_req; + list_del(&rqst->rq_bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + return rqst; + +create_req: + spin_unlock(&xprt->bc_pa_lock); + + /* Set a limit to prevent a remote from overrunning our resources. + */ + if (xprt->bc_alloc_count >= RPCRDMA_BACKWARD_WRS) + return NULL; + + size = min_t(size_t, r_xprt->rx_ep->re_inline_recv, PAGE_SIZE); + req = rpcrdma_req_create(r_xprt, size); + if (!req) + return NULL; + if (rpcrdma_req_setup(r_xprt, req)) { + rpcrdma_req_destroy(req); + return NULL; + } + + xprt->bc_alloc_count++; + rqst = &req->rl_slot; + rqst->rq_xprt = xprt; + __set_bit(RPC_BC_PA_IN_USE, &rqst->rq_bc_pa_state); + xdr_buf_init(&rqst->rq_snd_buf, rdmab_data(req->rl_sendbuf), size); + return rqst; +} + +/** + * rpcrdma_bc_receive_call - Handle a reverse-direction Call + * @r_xprt: transport receiving the call + * @rep: receive buffer containing the call + * + * Operational assumptions: + * o Backchannel credits are ignored, just as the NFS server + * forechannel currently does + * o The ULP manages a replay cache (eg, NFSv4.1 sessions). + * No replay detection is done at the transport level + */ +void rpcrdma_bc_receive_call(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_rep *rep) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct svc_serv *bc_serv; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + struct xdr_buf *buf; + size_t size; + __be32 *p; + + p = xdr_inline_decode(&rep->rr_stream, 0); + size = xdr_stream_remaining(&rep->rr_stream); + +#ifdef RPCRDMA_BACKCHANNEL_DEBUG + pr_info("RPC: %s: callback XID %08x, length=%u\n", + __func__, be32_to_cpup(p), size); + pr_info("RPC: %s: %*ph\n", __func__, size, p); +#endif + + rqst = rpcrdma_bc_rqst_get(r_xprt); + if (!rqst) + goto out_overflow; + + rqst->rq_reply_bytes_recvd = 0; + rqst->rq_xid = *p; + + rqst->rq_private_buf.len = size; + + buf = &rqst->rq_rcv_buf; + memset(buf, 0, sizeof(*buf)); + buf->head[0].iov_base = p; + buf->head[0].iov_len = size; + buf->len = size; + + /* The receive buffer has to be hooked to the rpcrdma_req + * so that it is not released while the req is pointing + * to its buffer, and so that it can be reposted after + * the Upper Layer is done decoding it. + */ + req = rpcr_to_rdmar(rqst); + req->rl_reply = rep; + trace_xprtrdma_cb_call(r_xprt, rqst); + + /* Queue rqst for ULP's callback service */ + bc_serv = xprt->bc_serv; + xprt_get(xprt); + lwq_enqueue(&rqst->rq_bc_list, &bc_serv->sv_cb_list); + + svc_pool_wake_idle_thread(&bc_serv->sv_pools[0]); + + r_xprt->rx_stats.bcall_count++; + return; + +out_overflow: + pr_warn("RPC/RDMA backchannel overflow\n"); + xprt_force_disconnect(xprt); + /* This receive buffer gets reposted automatically + * when the connection is re-established. + */ + return; +} diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c new file mode 100644 index 000000000000..31434aeb8e29 --- /dev/null +++ b/net/sunrpc/xprtrdma/frwr_ops.c @@ -0,0 +1,697 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + */ + +/* Lightweight memory registration using Fast Registration Work + * Requests (FRWR). + * + * FRWR features ordered asynchronous registration and invalidation + * of arbitrarily-sized memory regions. This is the fastest and safest + * but most complex memory registration mode. + */ + +/* Normal operation + * + * A Memory Region is prepared for RDMA Read or Write using a FAST_REG + * Work Request (frwr_map). When the RDMA operation is finished, this + * Memory Region is invalidated using a LOCAL_INV Work Request + * (frwr_unmap_async and frwr_unmap_sync). + * + * Typically FAST_REG Work Requests are not signaled, and neither are + * RDMA Send Work Requests (with the exception of signaling occasionally + * to prevent provider work queue overflows). This greatly reduces HCA + * interrupt workload. + */ + +/* Transport recovery + * + * frwr_map and frwr_unmap_* cannot run at the same time the transport + * connect worker is running. The connect worker holds the transport + * send lock, just as ->send_request does. This prevents frwr_map and + * the connect worker from running concurrently. When a connection is + * closed, the Receive completion queue is drained before the allowing + * the connect worker to get control. This prevents frwr_unmap and the + * connect worker from running concurrently. + * + * When the underlying transport disconnects, MRs that are in flight + * are flushed and are likely unusable. Thus all MRs are destroyed. + * New MRs are created on demand. + */ + +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +static void frwr_cid_init(struct rpcrdma_ep *ep, + struct rpcrdma_mr *mr) +{ + struct rpc_rdma_cid *cid = &mr->mr_cid; + + cid->ci_queue_id = ep->re_attr.send_cq->res.id; + cid->ci_completion_id = mr->mr_ibmr->res.id; +} + +static void frwr_mr_unmap(struct rpcrdma_mr *mr) +{ + if (mr->mr_device) { + trace_xprtrdma_mr_unmap(mr); + ib_dma_unmap_sg(mr->mr_device, mr->mr_sg, mr->mr_nents, + mr->mr_dir); + mr->mr_device = NULL; + } +} + +/** + * frwr_mr_release - Destroy one MR + * @mr: MR allocated by frwr_mr_init + * + */ +void frwr_mr_release(struct rpcrdma_mr *mr) +{ + int rc; + + frwr_mr_unmap(mr); + + rc = ib_dereg_mr(mr->mr_ibmr); + if (rc) + trace_xprtrdma_frwr_dereg(mr, rc); + kfree(mr->mr_sg); + kfree(mr); +} + +static void frwr_mr_put(struct rpcrdma_mr *mr) +{ + frwr_mr_unmap(mr); + + /* The MR is returned to the req's MR free list instead + * of to the xprt's MR free list. No spinlock is needed. + */ + rpcrdma_mr_push(mr, &mr->mr_req->rl_free_mrs); +} + +/** + * frwr_reset - Place MRs back on @req's free list + * @req: request to reset + * + * Used after a failed marshal. For FRWR, this means the MRs + * don't have to be fully released and recreated. + * + * NB: This is safe only as long as none of @req's MRs are + * involved with an ongoing asynchronous FAST_REG or LOCAL_INV + * Work Request. + */ +void frwr_reset(struct rpcrdma_req *req) +{ + struct rpcrdma_mr *mr; + + while ((mr = rpcrdma_mr_pop(&req->rl_registered))) + frwr_mr_put(mr); +} + +/** + * frwr_mr_init - Initialize one MR + * @r_xprt: controlling transport instance + * @mr: generic MR to prepare for FRWR + * + * Returns zero if successful. Otherwise a negative errno + * is returned. + */ +int frwr_mr_init(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + unsigned int depth = ep->re_max_fr_depth; + struct scatterlist *sg; + struct ib_mr *frmr; + + sg = kcalloc_node(depth, sizeof(*sg), XPRTRDMA_GFP_FLAGS, + ibdev_to_node(ep->re_id->device)); + if (!sg) + return -ENOMEM; + + frmr = ib_alloc_mr(ep->re_pd, ep->re_mrtype, depth); + if (IS_ERR(frmr)) + goto out_mr_err; + + mr->mr_xprt = r_xprt; + mr->mr_ibmr = frmr; + mr->mr_device = NULL; + INIT_LIST_HEAD(&mr->mr_list); + init_completion(&mr->mr_linv_done); + frwr_cid_init(ep, mr); + + sg_init_table(sg, depth); + mr->mr_sg = sg; + return 0; + +out_mr_err: + kfree(sg); + trace_xprtrdma_frwr_alloc(mr, PTR_ERR(frmr)); + return PTR_ERR(frmr); +} + +/** + * frwr_query_device - Prepare a transport for use with FRWR + * @ep: endpoint to fill in + * @device: RDMA device to query + * + * On success, sets: + * ep->re_attr + * ep->re_max_requests + * ep->re_max_rdma_segs + * ep->re_max_fr_depth + * ep->re_mrtype + * + * Return values: + * On success, returns zero. + * %-EINVAL - the device does not support FRWR memory registration + * %-ENOMEM - the device is not sufficiently capable for NFS/RDMA + */ +int frwr_query_device(struct rpcrdma_ep *ep, const struct ib_device *device) +{ + const struct ib_device_attr *attrs = &device->attrs; + int max_qp_wr, depth, delta; + unsigned int max_sge; + + if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS) || + attrs->max_fast_reg_page_list_len == 0) { + pr_err("rpcrdma: 'frwr' mode is not supported by device %s\n", + device->name); + return -EINVAL; + } + + max_sge = min_t(unsigned int, attrs->max_send_sge, + RPCRDMA_MAX_SEND_SGES); + if (max_sge < RPCRDMA_MIN_SEND_SGES) { + pr_err("rpcrdma: HCA provides only %u send SGEs\n", max_sge); + return -ENOMEM; + } + ep->re_attr.cap.max_send_sge = max_sge; + ep->re_attr.cap.max_recv_sge = 1; + + ep->re_mrtype = IB_MR_TYPE_MEM_REG; + if (attrs->kernel_cap_flags & IBK_SG_GAPS_REG) + ep->re_mrtype = IB_MR_TYPE_SG_GAPS; + + /* Quirk: Some devices advertise a large max_fast_reg_page_list_len + * capability, but perform optimally when the MRs are not larger + * than a page. + */ + if (attrs->max_sge_rd > RPCRDMA_MAX_HDR_SEGS) + ep->re_max_fr_depth = attrs->max_sge_rd; + else + ep->re_max_fr_depth = attrs->max_fast_reg_page_list_len; + if (ep->re_max_fr_depth > RPCRDMA_MAX_DATA_SEGS) + ep->re_max_fr_depth = RPCRDMA_MAX_DATA_SEGS; + + /* Add room for frwr register and invalidate WRs. + * 1. FRWR reg WR for head + * 2. FRWR invalidate WR for head + * 3. N FRWR reg WRs for pagelist + * 4. N FRWR invalidate WRs for pagelist + * 5. FRWR reg WR for tail + * 6. FRWR invalidate WR for tail + * 7. The RDMA_SEND WR + */ + depth = 7; + + /* Calculate N if the device max FRWR depth is smaller than + * RPCRDMA_MAX_DATA_SEGS. + */ + if (ep->re_max_fr_depth < RPCRDMA_MAX_DATA_SEGS) { + delta = RPCRDMA_MAX_DATA_SEGS - ep->re_max_fr_depth; + do { + depth += 2; /* FRWR reg + invalidate */ + delta -= ep->re_max_fr_depth; + } while (delta > 0); + } + + max_qp_wr = attrs->max_qp_wr; + max_qp_wr -= RPCRDMA_BACKWARD_WRS; + max_qp_wr -= 1; + if (max_qp_wr < RPCRDMA_MIN_SLOT_TABLE) + return -ENOMEM; + if (ep->re_max_requests > max_qp_wr) + ep->re_max_requests = max_qp_wr; + ep->re_attr.cap.max_send_wr = ep->re_max_requests * depth; + if (ep->re_attr.cap.max_send_wr > max_qp_wr) { + ep->re_max_requests = max_qp_wr / depth; + if (!ep->re_max_requests) + return -ENOMEM; + ep->re_attr.cap.max_send_wr = ep->re_max_requests * depth; + } + ep->re_attr.cap.max_send_wr += RPCRDMA_BACKWARD_WRS; + ep->re_attr.cap.max_send_wr += 1; /* for ib_drain_sq */ + ep->re_attr.cap.max_recv_wr = ep->re_max_requests; + ep->re_attr.cap.max_recv_wr += RPCRDMA_BACKWARD_WRS; + ep->re_attr.cap.max_recv_wr += RPCRDMA_MAX_RECV_BATCH; + ep->re_attr.cap.max_recv_wr += 1; /* for ib_drain_rq */ + + ep->re_max_rdma_segs = + DIV_ROUND_UP(RPCRDMA_MAX_DATA_SEGS, ep->re_max_fr_depth); + /* Reply chunks require segments for head and tail buffers */ + ep->re_max_rdma_segs += 2; + if (ep->re_max_rdma_segs > RPCRDMA_MAX_HDR_SEGS) + ep->re_max_rdma_segs = RPCRDMA_MAX_HDR_SEGS; + + /* Ensure the underlying device is capable of conveying the + * largest r/wsize NFS will ask for. This guarantees that + * failing over from one RDMA device to another will not + * break NFS I/O. + */ + if ((ep->re_max_rdma_segs * ep->re_max_fr_depth) < RPCRDMA_MAX_SEGS) + return -ENOMEM; + + return 0; +} + +/** + * frwr_map - Register a memory region + * @r_xprt: controlling transport + * @seg: memory region co-ordinates + * @nsegs: number of segments remaining + * @writing: true when RDMA Write will be used + * @xid: XID of RPC using the registered memory + * @mr: MR to fill in + * + * Prepare a REG_MR Work Request to register a memory region + * for remote access via RDMA READ or RDMA WRITE. + * + * Returns the next segment or a negative errno pointer. + * On success, @mr is filled in. + */ +struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, __be32 xid, + struct rpcrdma_mr *mr) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_reg_wr *reg_wr; + int i, n, dma_nents; + struct ib_mr *ibmr; + u8 key; + + if (nsegs > ep->re_max_fr_depth) + nsegs = ep->re_max_fr_depth; + for (i = 0; i < nsegs;) { + sg_set_page(&mr->mr_sg[i], seg->mr_page, + seg->mr_len, seg->mr_offset); + + ++seg; + ++i; + if (ep->re_mrtype == IB_MR_TYPE_SG_GAPS) + continue; + if ((i < nsegs && seg->mr_offset) || + offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) + break; + } + mr->mr_dir = rpcrdma_data_dir(writing); + mr->mr_nents = i; + + dma_nents = ib_dma_map_sg(ep->re_id->device, mr->mr_sg, mr->mr_nents, + mr->mr_dir); + if (!dma_nents) + goto out_dmamap_err; + mr->mr_device = ep->re_id->device; + + ibmr = mr->mr_ibmr; + n = ib_map_mr_sg(ibmr, mr->mr_sg, dma_nents, NULL, PAGE_SIZE); + if (n != dma_nents) + goto out_mapmr_err; + + ibmr->iova &= 0x00000000ffffffff; + ibmr->iova |= ((u64)be32_to_cpu(xid)) << 32; + key = (u8)(ibmr->rkey & 0x000000FF); + ib_update_fast_reg_key(ibmr, ++key); + + reg_wr = &mr->mr_regwr; + reg_wr->mr = ibmr; + reg_wr->key = ibmr->rkey; + reg_wr->access = writing ? + IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE : + IB_ACCESS_REMOTE_READ; + + mr->mr_handle = ibmr->rkey; + mr->mr_length = ibmr->length; + mr->mr_offset = ibmr->iova; + trace_xprtrdma_mr_map(mr); + + return seg; + +out_dmamap_err: + trace_xprtrdma_frwr_sgerr(mr, i); + return ERR_PTR(-EIO); + +out_mapmr_err: + trace_xprtrdma_frwr_maperr(mr, n); + return ERR_PTR(-EIO); +} + +/** + * frwr_wc_fastreg - Invoked by RDMA provider for a flushed FastReg WC + * @cq: completion queue + * @wc: WCE for a completed FastReg WR + * + * Each flushed MR gets destroyed after the QP has drained. + */ +static void frwr_wc_fastreg(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_fastreg(wc, &mr->mr_cid); + + rpcrdma_flush_disconnect(cq->cq_context, wc); +} + +/** + * frwr_send - post Send WRs containing the RPC Call message + * @r_xprt: controlling transport instance + * @req: prepared RPC Call + * + * For FRWR, chain any FastReg WRs to the Send WR. Only a + * single ib_post_send call is needed to register memory + * and then post the Send WR. + * + * Returns the return code from ib_post_send. + * + * Caller must hold the transport send lock to ensure that the + * pointers to the transport's rdma_cm_id and QP are stable. + */ +int frwr_send(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct ib_send_wr *post_wr, *send_wr = &req->rl_wr; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr *mr; + unsigned int num_wrs; + int ret; + + num_wrs = 1; + post_wr = send_wr; + list_for_each_entry(mr, &req->rl_registered, mr_list) { + trace_xprtrdma_mr_fastreg(mr); + + mr->mr_cqe.done = frwr_wc_fastreg; + mr->mr_regwr.wr.next = post_wr; + mr->mr_regwr.wr.wr_cqe = &mr->mr_cqe; + mr->mr_regwr.wr.num_sge = 0; + mr->mr_regwr.wr.opcode = IB_WR_REG_MR; + mr->mr_regwr.wr.send_flags = 0; + post_wr = &mr->mr_regwr.wr; + ++num_wrs; + } + + if ((kref_read(&req->rl_kref) > 1) || num_wrs > ep->re_send_count) { + send_wr->send_flags |= IB_SEND_SIGNALED; + ep->re_send_count = min_t(unsigned int, ep->re_send_batch, + num_wrs - ep->re_send_count); + } else { + send_wr->send_flags &= ~IB_SEND_SIGNALED; + ep->re_send_count -= num_wrs; + } + + trace_xprtrdma_post_send(req); + ret = ib_post_send(ep->re_id->qp, post_wr, NULL); + if (ret) + trace_xprtrdma_post_send_err(r_xprt, req, ret); + return ret; +} + +/** + * frwr_reminv - handle a remotely invalidated mr on the @mrs list + * @rep: Received reply + * @mrs: list of MRs to check + * + */ +void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs) +{ + struct rpcrdma_mr *mr; + + list_for_each_entry(mr, mrs, mr_list) + if (mr->mr_handle == rep->rr_inv_rkey) { + list_del_init(&mr->mr_list); + trace_xprtrdma_mr_reminv(mr); + frwr_mr_put(mr); + break; /* only one invalidated MR per RPC */ + } +} + +static void frwr_mr_done(struct ib_wc *wc, struct rpcrdma_mr *mr) +{ + if (likely(wc->status == IB_WC_SUCCESS)) + frwr_mr_put(mr); +} + +/** + * frwr_wc_localinv - Invoked by RDMA provider for a LOCAL_INV WC + * @cq: completion queue + * @wc: WCE for a completed LocalInv WR + * + */ +static void frwr_wc_localinv(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_li(wc, &mr->mr_cid); + frwr_mr_done(wc, mr); + + rpcrdma_flush_disconnect(cq->cq_context, wc); +} + +/** + * frwr_wc_localinv_wake - Invoked by RDMA provider for a LOCAL_INV WC + * @cq: completion queue + * @wc: WCE for a completed LocalInv WR + * + * Awaken anyone waiting for an MR to finish being fenced. + */ +static void frwr_wc_localinv_wake(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_li_wake(wc, &mr->mr_cid); + frwr_mr_done(wc, mr); + complete(&mr->mr_linv_done); + + rpcrdma_flush_disconnect(cq->cq_context, wc); +} + +/** + * frwr_unmap_sync - invalidate memory regions that were registered for @req + * @r_xprt: controlling transport instance + * @req: rpcrdma_req with a non-empty list of MRs to process + * + * Sleeps until it is safe for the host CPU to access the previously mapped + * memory regions. This guarantees that registered MRs are properly fenced + * from the server before the RPC consumer accesses the data in them. It + * also ensures proper Send flow control: waking the next RPC waits until + * this RPC has relinquished all its Send Queue entries. + */ +void frwr_unmap_sync(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct ib_send_wr *first, **prev, *last; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + const struct ib_send_wr *bad_wr; + struct rpcrdma_mr *mr; + int rc; + + /* ORDER: Invalidate all of the MRs first + * + * Chain the LOCAL_INV Work Requests and post them with + * a single ib_post_send() call. + */ + prev = &first; + mr = rpcrdma_mr_pop(&req->rl_registered); + do { + trace_xprtrdma_mr_localinv(mr); + r_xprt->rx_stats.local_inv_needed++; + + last = &mr->mr_invwr; + last->next = NULL; + last->wr_cqe = &mr->mr_cqe; + last->sg_list = NULL; + last->num_sge = 0; + last->opcode = IB_WR_LOCAL_INV; + last->send_flags = IB_SEND_SIGNALED; + last->ex.invalidate_rkey = mr->mr_handle; + + last->wr_cqe->done = frwr_wc_localinv; + + *prev = last; + prev = &last->next; + } while ((mr = rpcrdma_mr_pop(&req->rl_registered))); + + mr = container_of(last, struct rpcrdma_mr, mr_invwr); + + /* Strong send queue ordering guarantees that when the + * last WR in the chain completes, all WRs in the chain + * are complete. + */ + last->wr_cqe->done = frwr_wc_localinv_wake; + reinit_completion(&mr->mr_linv_done); + + /* Transport disconnect drains the receive CQ before it + * replaces the QP. The RPC reply handler won't call us + * unless re_id->qp is a valid pointer. + */ + bad_wr = NULL; + rc = ib_post_send(ep->re_id->qp, first, &bad_wr); + + /* The final LOCAL_INV WR in the chain is supposed to + * do the wake. If it was never posted, the wake will + * not happen, so don't wait in that case. + */ + if (bad_wr != first) + wait_for_completion(&mr->mr_linv_done); + if (!rc) + return; + + /* On error, the MRs get destroyed once the QP has drained. */ + trace_xprtrdma_post_linv_err(req, rc); + + /* Force a connection loss to ensure complete recovery. + */ + rpcrdma_force_disconnect(ep); +} + +/** + * frwr_wc_localinv_done - Invoked by RDMA provider for a signaled LOCAL_INV WC + * @cq: completion queue + * @wc: WCE for a completed LocalInv WR + * + */ +static void frwr_wc_localinv_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + struct rpcrdma_rep *rep; + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_li_done(wc, &mr->mr_cid); + + /* Ensure that @rep is generated before the MR is released */ + rep = mr->mr_req->rl_reply; + smp_rmb(); + + if (wc->status != IB_WC_SUCCESS) { + if (rep) + rpcrdma_unpin_rqst(rep); + rpcrdma_flush_disconnect(cq->cq_context, wc); + return; + } + frwr_mr_put(mr); + rpcrdma_complete_rqst(rep); +} + +/** + * frwr_unmap_async - invalidate memory regions that were registered for @req + * @r_xprt: controlling transport instance + * @req: rpcrdma_req with a non-empty list of MRs to process + * + * This guarantees that registered MRs are properly fenced from the + * server before the RPC consumer accesses the data in them. It also + * ensures proper Send flow control: waking the next RPC waits until + * this RPC has relinquished all its Send Queue entries. + */ +void frwr_unmap_async(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct ib_send_wr *first, *last, **prev; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr *mr; + int rc; + + /* Chain the LOCAL_INV Work Requests and post them with + * a single ib_post_send() call. + */ + prev = &first; + mr = rpcrdma_mr_pop(&req->rl_registered); + do { + trace_xprtrdma_mr_localinv(mr); + r_xprt->rx_stats.local_inv_needed++; + + last = &mr->mr_invwr; + last->next = NULL; + last->wr_cqe = &mr->mr_cqe; + last->sg_list = NULL; + last->num_sge = 0; + last->opcode = IB_WR_LOCAL_INV; + last->send_flags = IB_SEND_SIGNALED; + last->ex.invalidate_rkey = mr->mr_handle; + + last->wr_cqe->done = frwr_wc_localinv; + + *prev = last; + prev = &last->next; + } while ((mr = rpcrdma_mr_pop(&req->rl_registered))); + + /* Strong send queue ordering guarantees that when the + * last WR in the chain completes, all WRs in the chain + * are complete. The last completion will wake up the + * RPC waiter. + */ + last->wr_cqe->done = frwr_wc_localinv_done; + + /* Transport disconnect drains the receive CQ before it + * replaces the QP. The RPC reply handler won't call us + * unless re_id->qp is a valid pointer. + */ + rc = ib_post_send(ep->re_id->qp, first, NULL); + if (!rc) + return; + + /* On error, the MRs get destroyed once the QP has drained. */ + trace_xprtrdma_post_linv_err(req, rc); + + /* The final LOCAL_INV WR in the chain is supposed to + * do the wake. If it was never posted, the wake does + * not happen. Unpin the rqst in preparation for its + * retransmission. + */ + rpcrdma_unpin_rqst(req->rl_reply); + + /* Force a connection loss to ensure complete recovery. + */ + rpcrdma_force_disconnect(ep); +} + +/** + * frwr_wp_create - Create an MR for padding Write chunks + * @r_xprt: transport resources to use + * + * Return 0 on success, negative errno on failure. + */ +int frwr_wp_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr_seg seg; + struct rpcrdma_mr *mr; + + mr = rpcrdma_mr_get(r_xprt); + if (!mr) + return -EAGAIN; + mr->mr_req = NULL; + ep->re_write_pad_mr = mr; + + seg.mr_len = XDR_UNIT; + seg.mr_page = virt_to_page(ep->re_write_pad); + seg.mr_offset = offset_in_page(ep->re_write_pad); + if (IS_ERR(frwr_map(r_xprt, &seg, 1, true, xdr_zero, mr))) + return -EIO; + trace_xprtrdma_mr_fastreg(mr); + + mr->mr_cqe.done = frwr_wc_fastreg; + mr->mr_regwr.wr.next = NULL; + mr->mr_regwr.wr.wr_cqe = &mr->mr_cqe; + mr->mr_regwr.wr.num_sge = 0; + mr->mr_regwr.wr.opcode = IB_WR_REG_MR; + mr->mr_regwr.wr.send_flags = 0; + + return ib_post_send(ep->re_id->qp, &mr->mr_regwr.wr, NULL); +} diff --git a/net/sunrpc/xprtrdma/ib_client.c b/net/sunrpc/xprtrdma/ib_client.c new file mode 100644 index 000000000000..28c68b5f6823 --- /dev/null +++ b/net/sunrpc/xprtrdma/ib_client.c @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2024 Oracle. All rights reserved. + */ + +/* #include <linux/module.h> +#include <linux/slab.h> */ +#include <linux/xarray.h> +#include <linux/types.h> +#include <linux/kref.h> +#include <linux/completion.h> + +#include <linux/sunrpc/svc_rdma.h> +#include <linux/sunrpc/rdma_rn.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +/* Per-ib_device private data for rpcrdma */ +struct rpcrdma_device { + struct kref rd_kref; + unsigned long rd_flags; + struct ib_device *rd_device; + struct xarray rd_xa; + struct completion rd_done; +}; + +#define RPCRDMA_RD_F_REMOVING (0) + +static struct ib_client rpcrdma_ib_client; + +/* + * Listeners have no associated device, so we never register them. + * Note that ib_get_client_data() does not check if @device is + * NULL for us. + */ +static struct rpcrdma_device *rpcrdma_get_client_data(struct ib_device *device) +{ + if (!device) + return NULL; + return ib_get_client_data(device, &rpcrdma_ib_client); +} + +/** + * rpcrdma_rn_register - register to get device removal notifications + * @device: device to monitor + * @rn: notification object that wishes to be notified + * @done: callback to notify caller of device removal + * + * Returns zero on success. The callback in rn_done is guaranteed + * to be invoked when the device is removed, unless this notification + * is unregistered first. + * + * On failure, a negative errno is returned. + */ +int rpcrdma_rn_register(struct ib_device *device, + struct rpcrdma_notification *rn, + void (*done)(struct rpcrdma_notification *rn)) +{ + struct rpcrdma_device *rd = rpcrdma_get_client_data(device); + + if (!rd || test_bit(RPCRDMA_RD_F_REMOVING, &rd->rd_flags)) + return -ENETUNREACH; + + if (xa_alloc(&rd->rd_xa, &rn->rn_index, rn, xa_limit_32b, GFP_KERNEL) < 0) + return -ENOMEM; + kref_get(&rd->rd_kref); + rn->rn_done = done; + trace_rpcrdma_client_register(device, rn); + return 0; +} + +static void rpcrdma_rn_release(struct kref *kref) +{ + struct rpcrdma_device *rd = container_of(kref, struct rpcrdma_device, + rd_kref); + + trace_rpcrdma_client_completion(rd->rd_device); + complete(&rd->rd_done); +} + +/** + * rpcrdma_rn_unregister - stop device removal notifications + * @device: monitored device + * @rn: notification object that no longer wishes to be notified + */ +void rpcrdma_rn_unregister(struct ib_device *device, + struct rpcrdma_notification *rn) +{ + struct rpcrdma_device *rd = rpcrdma_get_client_data(device); + + if (!rd) + return; + + trace_rpcrdma_client_unregister(device, rn); + xa_erase(&rd->rd_xa, rn->rn_index); + kref_put(&rd->rd_kref, rpcrdma_rn_release); +} + +/** + * rpcrdma_add_one - ib_client device insertion callback + * @device: device about to be inserted + * + * Returns zero on success. xprtrdma private data has been allocated + * for this device. On failure, a negative errno is returned. + */ +static int rpcrdma_add_one(struct ib_device *device) +{ + struct rpcrdma_device *rd; + + rd = kzalloc(sizeof(*rd), GFP_KERNEL); + if (!rd) + return -ENOMEM; + + kref_init(&rd->rd_kref); + xa_init_flags(&rd->rd_xa, XA_FLAGS_ALLOC); + rd->rd_device = device; + init_completion(&rd->rd_done); + ib_set_client_data(device, &rpcrdma_ib_client, rd); + + trace_rpcrdma_client_add_one(device); + return 0; +} + +/** + * rpcrdma_remove_one - ib_client device removal callback + * @device: device about to be removed + * @client_data: this module's private per-device data + * + * Upon return, all transports associated with @device have divested + * themselves from IB hardware resources. + */ +static void rpcrdma_remove_one(struct ib_device *device, + void *client_data) +{ + struct rpcrdma_device *rd = client_data; + struct rpcrdma_notification *rn; + unsigned long index; + + trace_rpcrdma_client_remove_one(device); + + set_bit(RPCRDMA_RD_F_REMOVING, &rd->rd_flags); + xa_for_each(&rd->rd_xa, index, rn) + rn->rn_done(rn); + + /* + * Wait only if there are still outstanding notification + * registrants for this device. + */ + if (!refcount_dec_and_test(&rd->rd_kref.refcount)) { + trace_rpcrdma_client_wait_on(device); + wait_for_completion(&rd->rd_done); + } + + trace_rpcrdma_client_remove_one_done(device); + xa_destroy(&rd->rd_xa); + kfree(rd); +} + +static struct ib_client rpcrdma_ib_client = { + .name = "rpcrdma", + .add = rpcrdma_add_one, + .remove = rpcrdma_remove_one, +}; + +/** + * rpcrdma_ib_client_unregister - unregister ib_client for xprtrdma + * + * cel: watch for orphaned rpcrdma_device objects on module unload + */ +void rpcrdma_ib_client_unregister(void) +{ + ib_unregister_client(&rpcrdma_ib_client); +} + +/** + * rpcrdma_ib_client_register - register ib_client for rpcrdma + * + * Returns zero on success, or a negative errno. + */ +int rpcrdma_ib_client_register(void) +{ + return ib_register_client(&rpcrdma_ib_client); +} diff --git a/net/sunrpc/xprtrdma/module.c b/net/sunrpc/xprtrdma/module.c new file mode 100644 index 000000000000..697f571d4c01 --- /dev/null +++ b/net/sunrpc/xprtrdma/module.c @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. + */ + +/* rpcrdma.ko module initialization + */ + +#include <linux/types.h> +#include <linux/compiler.h> +#include <linux/module.h> +#include <linux/init.h> +#include <linux/sunrpc/svc_rdma.h> +#include <linux/sunrpc/rdma_rn.h> + +#include <asm/swab.h> + +#include "xprt_rdma.h" + +#define CREATE_TRACE_POINTS +#include <trace/events/rpcrdma.h> + +MODULE_AUTHOR("Open Grid Computing and Network Appliance, Inc."); +MODULE_DESCRIPTION("RPC/RDMA Transport"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS("svcrdma"); +MODULE_ALIAS("xprtrdma"); +MODULE_ALIAS("rpcrdma6"); + +static void __exit rpc_rdma_cleanup(void) +{ + xprt_rdma_cleanup(); + svc_rdma_cleanup(); + rpcrdma_ib_client_unregister(); +} + +static int __init rpc_rdma_init(void) +{ + int rc; + + rc = rpcrdma_ib_client_register(); + if (rc) + goto out_rc; + + rc = svc_rdma_init(); + if (rc) + goto out_ib_client; + + rc = xprt_rdma_init(); + if (rc) + goto out_svc_rdma; + + return 0; + +out_svc_rdma: + svc_rdma_cleanup(); +out_ib_client: + rpcrdma_ib_client_unregister(); +out_rc: + return rc; +} + +module_init(rpc_rdma_init); +module_exit(rpc_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c index e03725bfe2b8..3aac1456e23e 100644 --- a/net/sunrpc/xprtrdma/rpc_rdma.c +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -1,4 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2014-2020, Oracle and/or its affiliates. * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -45,111 +47,277 @@ * to the Linux RPC framework lives. */ +#include <linux/highmem.h> + +#include <linux/sunrpc/svc_rdma.h> + #include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> -#include <linux/highmem.h> +/* Returns size of largest RPC-over-RDMA header in a Call message + * + * The largest Call header contains a full-size Read list and a + * minimal Reply chunk. + */ +static unsigned int rpcrdma_max_call_header_size(unsigned int maxsegs) +{ + unsigned int size; -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_TRANS -#endif - -enum rpcrdma_chunktype { - rpcrdma_noch = 0, - rpcrdma_readch, - rpcrdma_areadch, - rpcrdma_writech, - rpcrdma_replych -}; - -#ifdef RPC_DEBUG -static const char transfertypes[][12] = { - "pure inline", /* no chunks */ - " read chunk", /* some argument via rdma read */ - "*read chunk", /* entire request via rdma read */ - "write chunk", /* some result via rdma write */ - "reply chunk" /* entire reply via rdma write */ -}; -#endif + /* Fixed header fields and list discriminators */ + size = RPCRDMA_HDRLEN_MIN; -/* - * Chunk assembly from upper layer xdr_buf. + /* Maximum Read list size */ + size += maxsegs * rpcrdma_readchunk_maxsz * sizeof(__be32); + + /* Minimal Read chunk size */ + size += sizeof(__be32); /* segment count */ + size += rpcrdma_segment_maxsz * sizeof(__be32); + size += sizeof(__be32); /* list discriminator */ + + return size; +} + +/* Returns size of largest RPC-over-RDMA header in a Reply message + * + * There is only one Write list or one Reply chunk per Reply + * message. The larger list is the Write list. + */ +static unsigned int rpcrdma_max_reply_header_size(unsigned int maxsegs) +{ + unsigned int size; + + /* Fixed header fields and list discriminators */ + size = RPCRDMA_HDRLEN_MIN; + + /* Maximum Write list size */ + size += sizeof(__be32); /* segment count */ + size += maxsegs * rpcrdma_segment_maxsz * sizeof(__be32); + size += sizeof(__be32); /* list discriminator */ + + return size; +} + +/** + * rpcrdma_set_max_header_sizes - Initialize inline payload sizes + * @ep: endpoint to initialize + * + * The max_inline fields contain the maximum size of an RPC message + * so the marshaling code doesn't have to repeat this calculation + * for every RPC. + */ +void rpcrdma_set_max_header_sizes(struct rpcrdma_ep *ep) +{ + unsigned int maxsegs = ep->re_max_rdma_segs; + + ep->re_max_inline_send = + ep->re_inline_send - rpcrdma_max_call_header_size(maxsegs); + ep->re_max_inline_recv = + ep->re_inline_recv - rpcrdma_max_reply_header_size(maxsegs); +} + +/* The client can send a request inline as long as the RPCRDMA header + * plus the RPC call fit under the transport's inline limit. If the + * combined call message size exceeds that limit, the client must use + * a Read chunk for this operation. + * + * A Read chunk is also required if sending the RPC call inline would + * exceed this device's max_sge limit. + */ +static bool rpcrdma_args_inline(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + struct xdr_buf *xdr = &rqst->rq_snd_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + unsigned int count, remaining, offset; + + if (xdr->len > ep->re_max_inline_send) + return false; + + if (xdr->page_len) { + remaining = xdr->page_len; + offset = offset_in_page(xdr->page_base); + count = RPCRDMA_MIN_SEND_SGES; + while (remaining) { + remaining -= min_t(unsigned int, + PAGE_SIZE - offset, remaining); + offset = 0; + if (++count > ep->re_attr.cap.max_send_sge) + return false; + } + } + + return true; +} + +/* The client can't know how large the actual reply will be. Thus it + * plans for the largest possible reply for that particular ULP + * operation. If the maximum combined reply message size exceeds that + * limit, the client must provide a write list or a reply chunk for + * this request. + */ +static bool rpcrdma_results_inline(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + return rqst->rq_rcv_buf.buflen <= r_xprt->rx_ep->re_max_inline_recv; +} + +/* The client is required to provide a Reply chunk if the maximum + * size of the non-payload part of the RPC Reply is larger than + * the inline threshold. + */ +static bool +rpcrdma_nonpayload_inline(const struct rpcrdma_xprt *r_xprt, + const struct rpc_rqst *rqst) +{ + const struct xdr_buf *buf = &rqst->rq_rcv_buf; + + return (buf->head[0].iov_len + buf->tail[0].iov_len) < + r_xprt->rx_ep->re_max_inline_recv; +} + +/* ACL likes to be lazy in allocating pages. For TCP, these + * pages can be allocated during receive processing. Not true + * for RDMA, which must always provision receive buffers + * up front. + */ +static noinline int +rpcrdma_alloc_sparse_pages(struct xdr_buf *buf) +{ + struct page **ppages; + int len; + + len = buf->page_len; + ppages = buf->pages + (buf->page_base >> PAGE_SHIFT); + while (len > 0) { + if (!*ppages) + *ppages = alloc_page(GFP_NOWAIT); + if (!*ppages) + return -ENOBUFS; + ppages++; + len -= PAGE_SIZE; + } + + return 0; +} + +/* Convert @vec to a single SGL element. * - * Prepare the passed-in xdr_buf into representation as RPC/RDMA chunk - * elements. Segments are then coalesced when registered, if possible - * within the selected memreg mode. + * Returns pointer to next available SGE, and bumps the total number + * of SGEs consumed. + */ +static struct rpcrdma_mr_seg * +rpcrdma_convert_kvec(struct kvec *vec, struct rpcrdma_mr_seg *seg, + unsigned int *n) +{ + seg->mr_page = virt_to_page(vec->iov_base); + seg->mr_offset = offset_in_page(vec->iov_base); + seg->mr_len = vec->iov_len; + ++seg; + ++(*n); + return seg; +} + +/* Convert @xdrbuf into SGEs no larger than a page each. As they + * are registered, these SGEs are then coalesced into RDMA segments + * when the selected memreg mode supports it. * - * Note, this routine is never called if the connection's memory - * registration strategy is 0 (bounce buffers). + * Returns positive number of SGEs consumed, or a negative errno. */ static int -rpcrdma_convert_iovs(struct xdr_buf *xdrbuf, unsigned int pos, - enum rpcrdma_chunktype type, struct rpcrdma_mr_seg *seg, int nsegs) +rpcrdma_convert_iovs(struct rpcrdma_xprt *r_xprt, struct xdr_buf *xdrbuf, + unsigned int pos, enum rpcrdma_chunktype type, + struct rpcrdma_mr_seg *seg) { - int len, n = 0, p; - int page_base; + unsigned long page_base; + unsigned int len, n; struct page **ppages; - if (pos == 0 && xdrbuf->head[0].iov_len) { - seg[n].mr_page = NULL; - seg[n].mr_offset = xdrbuf->head[0].iov_base; - seg[n].mr_len = xdrbuf->head[0].iov_len; - ++n; - } + n = 0; + if (pos == 0) + seg = rpcrdma_convert_kvec(&xdrbuf->head[0], seg, &n); len = xdrbuf->page_len; ppages = xdrbuf->pages + (xdrbuf->page_base >> PAGE_SHIFT); - page_base = xdrbuf->page_base & ~PAGE_MASK; - p = 0; - while (len && n < nsegs) { - seg[n].mr_page = ppages[p]; - seg[n].mr_offset = (void *)(unsigned long) page_base; - seg[n].mr_len = min_t(u32, PAGE_SIZE - page_base, len); - BUG_ON(seg[n].mr_len > PAGE_SIZE); - len -= seg[n].mr_len; + page_base = offset_in_page(xdrbuf->page_base); + while (len) { + seg->mr_page = *ppages; + seg->mr_offset = page_base; + seg->mr_len = min_t(u32, PAGE_SIZE - page_base, len); + len -= seg->mr_len; + ++ppages; + ++seg; ++n; - ++p; - page_base = 0; /* page offset only applies to first page */ + page_base = 0; } - /* Message overflows the seg array */ - if (len && n == nsegs) - return 0; + if (type == rpcrdma_readch || type == rpcrdma_writech) + goto out; - if (xdrbuf->tail[0].iov_len) { - /* the rpcrdma protocol allows us to omit any trailing - * xdr pad bytes, saving the server an RDMA operation. */ - if (xdrbuf->tail[0].iov_len < 4 && xprt_rdma_pad_optimize) - return n; - if (n == nsegs) - /* Tail remains, but we're out of segments */ - return 0; - seg[n].mr_page = NULL; - seg[n].mr_offset = xdrbuf->tail[0].iov_base; - seg[n].mr_len = xdrbuf->tail[0].iov_len; - ++n; - } + if (xdrbuf->tail[0].iov_len) + rpcrdma_convert_kvec(&xdrbuf->tail[0], seg, &n); +out: + if (unlikely(n > RPCRDMA_MAX_SEGS)) + return -EIO; return n; } -/* - * Create read/write chunk lists, and reply chunks, for RDMA - * - * Assume check against THRESHOLD has been done, and chunks are required. - * Assume only encoding one list entry for read|write chunks. The NFSv3 - * protocol is simple enough to allow this as it only has a single "bulk - * result" in each procedure - complicated NFSv4 COMPOUNDs are not. (The - * RDMA/Sessions NFSv4 proposal addresses this for future v4 revs.) - * - * When used for a single reply chunk (which is a special write - * chunk used for the entire reply, rather than just the data), it - * is used primarily for READDIR and READLINK which would otherwise - * be severely size-limited by a small rdma inline read max. The server - * response will come back as an RDMA Write, followed by a message - * of type RDMA_NOMSG carrying the xid and length. As a result, reply - * chunks do not provide data alignment, however they do not require - * "fixup" (moving the response to the upper layer buffer) either. +static int +encode_rdma_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + xdr_encode_rdma_segment(p, mr->mr_handle, mr->mr_length, mr->mr_offset); + return 0; +} + +static int +encode_read_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr, + u32 position) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 6 * sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + *p++ = xdr_one; /* Item present */ + xdr_encode_read_segment(p, position, mr->mr_handle, mr->mr_length, + mr->mr_offset); + return 0; +} + +static struct rpcrdma_mr_seg *rpcrdma_mr_prepare(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, + struct rpcrdma_mr **mr) +{ + *mr = rpcrdma_mr_pop(&req->rl_free_mrs); + if (!*mr) { + *mr = rpcrdma_mr_get(r_xprt); + if (!*mr) + goto out_getmr_err; + (*mr)->mr_req = req; + } + + rpcrdma_mr_push(*mr, &req->rl_registered); + return frwr_map(r_xprt, seg, nsegs, writing, req->rl_slot.rq_xid, *mr); + +out_getmr_err: + trace_xprtrdma_nomrs_err(r_xprt, req); + xprt_wait_for_buffer_space(&r_xprt->rx_xprt); + rpcrdma_mrs_refresh(r_xprt); + return ERR_PTR(-EAGAIN); +} + +/* Register and XDR encode the Read list. Supports encoding a list of read + * segments that belong to a single read chunk. * * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): * @@ -157,262 +325,587 @@ rpcrdma_convert_iovs(struct xdr_buf *xdrbuf, unsigned int pos, * N elements, position P (same P for all chunks of same arg!): * 1 - PHLOO - 1 - PHLOO - ... - 1 - PHLOO - 0 * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + * + * Only a single @pos value is currently supported. + */ +static int rpcrdma_encode_read_list(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpc_rqst *rqst, + enum rpcrdma_chunktype rtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + unsigned int pos; + int nsegs; + + if (rtype == rpcrdma_noch_pullup || rtype == rpcrdma_noch_mapped) + goto done; + + pos = rqst->rq_snd_buf.head[0].iov_len; + if (rtype == rpcrdma_areadch) + pos = 0; + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_snd_buf, pos, + rtype, seg); + if (nsegs < 0) + return nsegs; + + do { + seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, false, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + + if (encode_read_segment(xdr, mr, pos) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_read(rqst->rq_task, pos, mr, nsegs); + r_xprt->rx_stats.read_chunk_count++; + nsegs -= mr->mr_nents; + } while (nsegs); + +done: + if (xdr_stream_encode_item_absent(xdr) < 0) + return -EMSGSIZE; + return 0; +} + +/* Register and XDR encode the Write list. Supports encoding a list + * containing one array of plain segments that belong to a single + * write chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * * Write chunklist (a list of (one) counted array): * N elements: * 1 - N - HLOO - HLOO - ... - HLOO - 0 * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + * + * Only a single Write chunk is currently supported. + */ +static int rpcrdma_encode_write_list(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpc_rqst *rqst, + enum rpcrdma_chunktype wtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + int nsegs, nchunks; + __be32 *segcount; + + if (wtype != rpcrdma_writech) + goto done; + + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, + rqst->rq_rcv_buf.head[0].iov_len, + wtype, seg); + if (nsegs < 0) + return nsegs; + + if (xdr_stream_encode_item_present(xdr) < 0) + return -EMSGSIZE; + segcount = xdr_reserve_space(xdr, sizeof(*segcount)); + if (unlikely(!segcount)) + return -EMSGSIZE; + /* Actual value encoded below */ + + nchunks = 0; + do { + seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, true, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + + if (encode_rdma_segment(xdr, mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_write(rqst->rq_task, mr, nsegs); + r_xprt->rx_stats.write_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } while (nsegs); + + if (xdr_pad_size(rqst->rq_rcv_buf.page_len)) { + if (encode_rdma_segment(xdr, ep->re_write_pad_mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_wp(rqst->rq_task, ep->re_write_pad_mr, + nsegs); + r_xprt->rx_stats.write_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } + + /* Update count of segments in this Write chunk */ + *segcount = cpu_to_be32(nchunks); + +done: + if (xdr_stream_encode_item_absent(xdr) < 0) + return -EMSGSIZE; + return 0; +} + +/* Register and XDR encode the Reply chunk. Supports encoding an array + * of plain segments that belong to a single write (reply) chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * * Reply chunk (a counted array): * N elements: * 1 - N - HLOO - HLOO - ... - HLOO + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. */ - -static unsigned int -rpcrdma_create_chunks(struct rpc_rqst *rqst, struct xdr_buf *target, - struct rpcrdma_msg *headerp, enum rpcrdma_chunktype type) +static int rpcrdma_encode_reply_chunk(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpc_rqst *rqst, + enum rpcrdma_chunktype wtype) { - struct rpcrdma_req *req = rpcr_to_rdmar(rqst); - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); - int nsegs, nchunks = 0; - unsigned int pos; - struct rpcrdma_mr_seg *seg = req->rl_segments; - struct rpcrdma_read_chunk *cur_rchunk = NULL; - struct rpcrdma_write_array *warray = NULL; - struct rpcrdma_write_chunk *cur_wchunk = NULL; - __be32 *iptr = headerp->rm_body.rm_chunks; - - if (type == rpcrdma_readch || type == rpcrdma_areadch) { - /* a read chunk - server will RDMA Read our memory */ - cur_rchunk = (struct rpcrdma_read_chunk *) iptr; - } else { - /* a write or reply chunk - server will RDMA Write our memory */ - *iptr++ = xdr_zero; /* encode a NULL read chunk list */ - if (type == rpcrdma_replych) - *iptr++ = xdr_zero; /* a NULL write chunk list */ - warray = (struct rpcrdma_write_array *) iptr; - cur_wchunk = (struct rpcrdma_write_chunk *) (warray + 1); + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + int nsegs, nchunks; + __be32 *segcount; + + if (wtype != rpcrdma_replych) { + if (xdr_stream_encode_item_absent(xdr) < 0) + return -EMSGSIZE; + return 0; } - if (type == rpcrdma_replych || type == rpcrdma_areadch) - pos = 0; - else - pos = target->head[0].iov_len; + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, 0, wtype, seg); + if (nsegs < 0) + return nsegs; - nsegs = rpcrdma_convert_iovs(target, pos, type, seg, RPCRDMA_MAX_SEGS); - if (nsegs == 0) - return 0; + if (xdr_stream_encode_item_present(xdr) < 0) + return -EMSGSIZE; + segcount = xdr_reserve_space(xdr, sizeof(*segcount)); + if (unlikely(!segcount)) + return -EMSGSIZE; + /* Actual value encoded below */ + nchunks = 0; do { - /* bind/register the memory, then build chunk from result. */ - int n = rpcrdma_register_external(seg, nsegs, - cur_wchunk != NULL, r_xprt); - if (n <= 0) - goto out; - if (cur_rchunk) { /* read */ - cur_rchunk->rc_discrim = xdr_one; - /* all read chunks have the same "position" */ - cur_rchunk->rc_position = htonl(pos); - cur_rchunk->rc_target.rs_handle = htonl(seg->mr_rkey); - cur_rchunk->rc_target.rs_length = htonl(seg->mr_len); - xdr_encode_hyper( - (__be32 *)&cur_rchunk->rc_target.rs_offset, - seg->mr_base); - dprintk("RPC: %s: read chunk " - "elem %d@0x%llx:0x%x pos %u (%s)\n", __func__, - seg->mr_len, (unsigned long long)seg->mr_base, - seg->mr_rkey, pos, n < nsegs ? "more" : "last"); - cur_rchunk++; - r_xprt->rx_stats.read_chunk_count++; - } else { /* write/reply */ - cur_wchunk->wc_target.rs_handle = htonl(seg->mr_rkey); - cur_wchunk->wc_target.rs_length = htonl(seg->mr_len); - xdr_encode_hyper( - (__be32 *)&cur_wchunk->wc_target.rs_offset, - seg->mr_base); - dprintk("RPC: %s: %s chunk " - "elem %d@0x%llx:0x%x (%s)\n", __func__, - (type == rpcrdma_replych) ? "reply" : "write", - seg->mr_len, (unsigned long long)seg->mr_base, - seg->mr_rkey, n < nsegs ? "more" : "last"); - cur_wchunk++; - if (type == rpcrdma_replych) - r_xprt->rx_stats.reply_chunk_count++; - else - r_xprt->rx_stats.write_chunk_count++; - r_xprt->rx_stats.total_rdma_request += seg->mr_len; - } + seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, true, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + + if (encode_rdma_segment(xdr, mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_reply(rqst->rq_task, mr, nsegs); + r_xprt->rx_stats.reply_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; nchunks++; - seg += n; - nsegs -= n; + nsegs -= mr->mr_nents; } while (nsegs); - /* success. all failures return above */ - req->rl_nchunks = nchunks; + /* Update count of segments in the Reply chunk */ + *segcount = cpu_to_be32(nchunks); - BUG_ON(nchunks == 0); - BUG_ON((r_xprt->rx_ia.ri_memreg_strategy == RPCRDMA_FRMR) - && (nchunks > 3)); + return 0; +} - /* - * finish off header. If write, marshal discrim and nchunks. - */ - if (cur_rchunk) { - iptr = (__be32 *) cur_rchunk; - *iptr++ = xdr_zero; /* finish the read chunk list */ - *iptr++ = xdr_zero; /* encode a NULL write chunk list */ - *iptr++ = xdr_zero; /* encode a NULL reply chunk */ - } else { - warray->wc_discrim = xdr_one; - warray->wc_nchunks = htonl(nchunks); - iptr = (__be32 *) cur_wchunk; - if (type == rpcrdma_writech) { - *iptr++ = xdr_zero; /* finish the write chunk list */ - *iptr++ = xdr_zero; /* encode a NULL reply chunk */ - } - } +static void rpcrdma_sendctx_done(struct kref *kref) +{ + struct rpcrdma_req *req = + container_of(kref, struct rpcrdma_req, rl_kref); + struct rpcrdma_rep *rep = req->rl_reply; - /* - * Return header size. + rpcrdma_complete_rqst(rep); + rep->rr_rxprt->rx_stats.reply_waits_for_send++; +} + +/** + * rpcrdma_sendctx_unmap - DMA-unmap Send buffer + * @sc: sendctx containing SGEs to unmap + * + */ +void rpcrdma_sendctx_unmap(struct rpcrdma_sendctx *sc) +{ + struct rpcrdma_regbuf *rb = sc->sc_req->rl_sendbuf; + struct ib_sge *sge; + + if (!sc->sc_unmap_count) + return; + + /* The first two SGEs contain the transport header and + * the inline buffer. These are always left mapped so + * they can be cheaply re-used. */ - return (unsigned char *)iptr - (unsigned char *)headerp; + for (sge = &sc->sc_sges[2]; sc->sc_unmap_count; + ++sge, --sc->sc_unmap_count) + ib_dma_unmap_page(rdmab_device(rb), sge->addr, sge->length, + DMA_TO_DEVICE); -out: - for (pos = 0; nchunks--;) - pos += rpcrdma_deregister_external( - &req->rl_segments[pos], r_xprt, NULL); - return 0; + kref_put(&sc->sc_req->rl_kref, rpcrdma_sendctx_done); } -/* - * Copy write data inline. - * This function is used for "small" requests. Data which is passed - * to RPC via iovecs (or page list) is copied directly into the - * pre-registered memory buffer for this request. For small amounts - * of data, this is efficient. The cutoff value is tunable. +/* Prepare an SGE for the RPC-over-RDMA transport header. */ -static int -rpcrdma_inline_pullup(struct rpc_rqst *rqst, int pad) +static void rpcrdma_prepare_hdr_sge(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 len) { - int i, npages, curlen; - int copy_len; - unsigned char *srcp, *destp; - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); - int page_base; - struct page **ppages; + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct rpcrdma_regbuf *rb = req->rl_rdmabuf; + struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++]; - destp = rqst->rq_svec[0].iov_base; - curlen = rqst->rq_svec[0].iov_len; - destp += curlen; - /* - * Do optional padding where it makes sense. Alignment of write - * payload can help the server, if our setting is accurate. - */ - pad -= (curlen + 36/*sizeof(struct rpcrdma_msg_padded)*/); - if (pad < 0 || rqst->rq_slen - curlen < RPCRDMA_INLINE_PAD_THRESH) - pad = 0; /* don't pad this request */ + sge->addr = rdmab_addr(rb); + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, sge->length, + DMA_TO_DEVICE); +} - dprintk("RPC: %s: pad %d destp 0x%p len %d hdrlen %d\n", - __func__, pad, destp, rqst->rq_slen, curlen); +/* The head iovec is straightforward, as it is usually already + * DMA-mapped. Sync the content that has changed. + */ +static bool rpcrdma_prepare_head_iov(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, unsigned int len) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++]; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; - copy_len = rqst->rq_snd_buf.page_len; + if (!rpcrdma_regbuf_dma_map(r_xprt, rb)) + return false; - if (rqst->rq_snd_buf.tail[0].iov_len) { - curlen = rqst->rq_snd_buf.tail[0].iov_len; - if (destp + copy_len != rqst->rq_snd_buf.tail[0].iov_base) { - memmove(destp + copy_len, - rqst->rq_snd_buf.tail[0].iov_base, curlen); - r_xprt->rx_stats.pullup_copy_count += curlen; - } - dprintk("RPC: %s: tail destp 0x%p len %d\n", - __func__, destp + copy_len, curlen); - rqst->rq_svec[0].iov_len += curlen; + sge->addr = rdmab_addr(rb); + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, sge->length, + DMA_TO_DEVICE); + return true; +} + +/* If there is a page list present, DMA map and prepare an + * SGE for each page to be sent. + */ +static bool rpcrdma_prepare_pagelist(struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; + unsigned int page_base, len, remaining; + struct page **ppages; + struct ib_sge *sge; + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + sge = &sc->sc_sges[req->rl_wr.num_sge++]; + len = min_t(unsigned int, PAGE_SIZE - page_base, remaining); + sge->addr = ib_dma_map_page(rdmab_device(rb), *ppages, + page_base, len, DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdmab_device(rb), sge->addr)) + goto out_mapping_err; + + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + sc->sc_unmap_count++; + ppages++; + remaining -= len; + page_base = 0; } - r_xprt->rx_stats.pullup_copy_count += copy_len; - - page_base = rqst->rq_snd_buf.page_base; - ppages = rqst->rq_snd_buf.pages + (page_base >> PAGE_SHIFT); - page_base &= ~PAGE_MASK; - npages = PAGE_ALIGN(page_base+copy_len) >> PAGE_SHIFT; - for (i = 0; copy_len && i < npages; i++) { - curlen = PAGE_SIZE - page_base; - if (curlen > copy_len) - curlen = copy_len; - dprintk("RPC: %s: page %d destp 0x%p len %d curlen %d\n", - __func__, i, destp, copy_len, curlen); - srcp = kmap_atomic(ppages[i]); - memcpy(destp, srcp+page_base, curlen); - kunmap_atomic(srcp); - rqst->rq_svec[0].iov_len += curlen; - destp += curlen; - copy_len -= curlen; + + return true; + +out_mapping_err: + trace_xprtrdma_dma_maperr(sge->addr); + return false; +} + +/* The tail iovec may include an XDR pad for the page list, + * as well as additional content, and may not reside in the + * same page as the head iovec. + */ +static bool rpcrdma_prepare_tail_iov(struct rpcrdma_req *req, + struct xdr_buf *xdr, + unsigned int page_base, unsigned int len) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++]; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; + struct page *page = virt_to_page(xdr->tail[0].iov_base); + + sge->addr = ib_dma_map_page(rdmab_device(rb), page, page_base, len, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdmab_device(rb), sge->addr)) + goto out_mapping_err; + + sge->length = len; + sge->lkey = rdmab_lkey(rb); + ++sc->sc_unmap_count; + return true; + +out_mapping_err: + trace_xprtrdma_dma_maperr(sge->addr); + return false; +} + +/* Copy the tail to the end of the head buffer. + */ +static void rpcrdma_pullup_tail_iov(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + unsigned char *dst; + + dst = (unsigned char *)xdr->head[0].iov_base; + dst += xdr->head[0].iov_len + xdr->page_len; + memmove(dst, xdr->tail[0].iov_base, xdr->tail[0].iov_len); + r_xprt->rx_stats.pullup_copy_count += xdr->tail[0].iov_len; +} + +/* Copy pagelist content into the head buffer. + */ +static void rpcrdma_pullup_pagelist(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + unsigned int len, page_base, remaining; + struct page **ppages; + unsigned char *src, *dst; + + dst = (unsigned char *)xdr->head[0].iov_base; + dst += xdr->head[0].iov_len; + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + src = page_address(*ppages); + src += page_base; + len = min_t(unsigned int, PAGE_SIZE - page_base, remaining); + memcpy(dst, src, len); + r_xprt->rx_stats.pullup_copy_count += len; + + ppages++; + dst += len; + remaining -= len; page_base = 0; } - /* header now contains entire send message */ - return pad; } -/* - * Marshal a request: the primary job of this routine is to choose - * the transfer modes. See comments below. +/* Copy the contents of @xdr into @rl_sendbuf and DMA sync it. + * When the head, pagelist, and tail are small, a pull-up copy + * is considerably less costly than DMA mapping the components + * of @xdr. + * + * Assumptions: + * - the caller has already verified that the total length + * of the RPC Call body will fit into @rl_sendbuf. + */ +static bool rpcrdma_prepare_noch_pullup(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + if (unlikely(xdr->tail[0].iov_len)) + rpcrdma_pullup_tail_iov(r_xprt, req, xdr); + + if (unlikely(xdr->page_len)) + rpcrdma_pullup_pagelist(r_xprt, req, xdr); + + /* The whole RPC message resides in the head iovec now */ + return rpcrdma_prepare_head_iov(r_xprt, req, xdr->len); +} + +static bool rpcrdma_prepare_noch_mapped(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + struct kvec *tail = &xdr->tail[0]; + + if (!rpcrdma_prepare_head_iov(r_xprt, req, xdr->head[0].iov_len)) + return false; + if (xdr->page_len) + if (!rpcrdma_prepare_pagelist(req, xdr)) + return false; + if (tail->iov_len) + if (!rpcrdma_prepare_tail_iov(req, xdr, + offset_in_page(tail->iov_base), + tail->iov_len)) + return false; + + if (req->rl_sendctx->sc_unmap_count) + kref_get(&req->rl_kref); + return true; +} + +static bool rpcrdma_prepare_readch(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + if (!rpcrdma_prepare_head_iov(r_xprt, req, xdr->head[0].iov_len)) + return false; + + /* If there is a Read chunk, the page list is being handled + * via explicit RDMA, and thus is skipped here. + */ + + /* Do not include the tail if it is only an XDR pad */ + if (xdr->tail[0].iov_len > 3) { + unsigned int page_base, len; + + /* If the content in the page list is an odd length, + * xdr_write_pages() adds a pad at the beginning of + * the tail iovec. Force the tail's non-pad content to + * land at the next XDR position in the Send message. + */ + page_base = offset_in_page(xdr->tail[0].iov_base); + len = xdr->tail[0].iov_len; + page_base += len & 3; + len -= len & 3; + if (!rpcrdma_prepare_tail_iov(req, xdr, page_base, len)) + return false; + kref_get(&req->rl_kref); + } + + return true; +} + +/** + * rpcrdma_prepare_send_sges - Construct SGEs for a Send WR + * @r_xprt: controlling transport + * @req: context of RPC Call being marshalled + * @hdrlen: size of transport header, in bytes + * @xdr: xdr_buf containing RPC Call + * @rtype: chunk type being encoded * - * Uses multiple RDMA IOVs for a request: - * [0] -- RPC RDMA header, which uses memory from the *start* of the - * preregistered buffer that already holds the RPC data in - * its middle. - * [1] -- the RPC header/data, marshaled by RPC and the NFS protocol. - * [2] -- optional padding. - * [3] -- if padded, header only in [1] and data here. + * Returns 0 on success; otherwise a negative errno is returned. */ +inline int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 hdrlen, + struct xdr_buf *xdr, + enum rpcrdma_chunktype rtype) +{ + int ret; + + ret = -EAGAIN; + req->rl_sendctx = rpcrdma_sendctx_get_locked(r_xprt); + if (!req->rl_sendctx) + goto out_nosc; + req->rl_sendctx->sc_unmap_count = 0; + req->rl_sendctx->sc_req = req; + kref_init(&req->rl_kref); + req->rl_wr.wr_cqe = &req->rl_sendctx->sc_cqe; + req->rl_wr.sg_list = req->rl_sendctx->sc_sges; + req->rl_wr.num_sge = 0; + req->rl_wr.opcode = IB_WR_SEND; + + rpcrdma_prepare_hdr_sge(r_xprt, req, hdrlen); + + ret = -EIO; + switch (rtype) { + case rpcrdma_noch_pullup: + if (!rpcrdma_prepare_noch_pullup(r_xprt, req, xdr)) + goto out_unmap; + break; + case rpcrdma_noch_mapped: + if (!rpcrdma_prepare_noch_mapped(r_xprt, req, xdr)) + goto out_unmap; + break; + case rpcrdma_readch: + if (!rpcrdma_prepare_readch(r_xprt, req, xdr)) + goto out_unmap; + break; + case rpcrdma_areadch: + break; + default: + goto out_unmap; + } + + return 0; + +out_unmap: + rpcrdma_sendctx_unmap(req->rl_sendctx); +out_nosc: + trace_xprtrdma_prepsend_failed(&req->rl_slot, ret); + return ret; +} +/** + * rpcrdma_marshal_req - Marshal and send one RPC request + * @r_xprt: controlling transport + * @rqst: RPC request to be marshaled + * + * For the RPC in "rqst", this function: + * - Chooses the transfer mode (eg., RDMA_MSG or RDMA_NOMSG) + * - Registers Read, Write, and Reply chunks + * - Constructs the transport header + * - Posts a Send WR to send the transport header and request + * + * Returns: + * %0 if the RPC was sent successfully, + * %-ENOTCONN if the connection was lost, + * %-EAGAIN if the caller should call again with the same arguments, + * %-ENOBUFS if the caller should call again after a delay, + * %-EMSGSIZE if the transport header is too small, + * %-EIO if a permanent problem occurred while marshaling. + */ int -rpcrdma_marshal_req(struct rpc_rqst *rqst) +rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) { - struct rpc_xprt *xprt = rqst->rq_xprt; - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); struct rpcrdma_req *req = rpcr_to_rdmar(rqst); - char *base; - size_t hdrlen, rpclen, padlen; + struct xdr_stream *xdr = &req->rl_stream; enum rpcrdma_chunktype rtype, wtype; - struct rpcrdma_msg *headerp; + struct xdr_buf *buf = &rqst->rq_snd_buf; + bool ddp_allowed; + __be32 *p; + int ret; + + if (unlikely(rqst->rq_rcv_buf.flags & XDRBUF_SPARSE_PAGES)) { + ret = rpcrdma_alloc_sparse_pages(&rqst->rq_rcv_buf); + if (ret) + return ret; + } - /* - * rpclen gets amount of data in first buffer, which is the - * pre-registered buffer. + rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); + xdr_init_encode(xdr, &req->rl_hdrbuf, rdmab_data(req->rl_rdmabuf), + rqst); + + /* Fixed header fields */ + ret = -EMSGSIZE; + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (!p) + goto out_err; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = r_xprt->rx_buf.rb_max_requests; + + /* When the ULP employs a GSS flavor that guarantees integrity + * or privacy, direct data placement of individual data items + * is not allowed. */ - base = rqst->rq_svec[0].iov_base; - rpclen = rqst->rq_svec[0].iov_len; - - /* build RDMA header in private area at front */ - headerp = (struct rpcrdma_msg *) req->rl_base; - /* don't htonl XID, it's already done in request */ - headerp->rm_xid = rqst->rq_xid; - headerp->rm_vers = xdr_one; - headerp->rm_credit = htonl(r_xprt->rx_buf.rb_max_requests); - headerp->rm_type = htonl(RDMA_MSG); + ddp_allowed = !test_bit(RPCAUTH_AUTH_DATATOUCH, + &rqst->rq_cred->cr_auth->au_flags); /* * Chunks needed for results? * * o If the expected result is under the inline threshold, all ops - * return as inline (but see later). + * return as inline. + * o Large read ops return data as write chunk(s), header as + * inline. * o Large non-read ops return as a single reply chunk. - * o Large read ops return data as write chunk(s), header as inline. - * - * Note: the NFS code sending down multiple result segments implies - * the op is one of read, readdir[plus], readlink or NFSv4 getacl. */ - - /* - * This code can handle read chunks, write chunks OR reply - * chunks -- only one type. If the request is too big to fit - * inline, then we will choose read chunks. If the request is - * a READ, then use write chunks to separate the file data - * into pages; otherwise use reply chunks. - */ - if (rqst->rq_rcv_buf.buflen <= RPCRDMA_INLINE_READ_THRESHOLD(rqst)) + if (rpcrdma_results_inline(r_xprt, rqst)) wtype = rpcrdma_noch; - else if (rqst->rq_rcv_buf.page_len == 0) - wtype = rpcrdma_replych; - else if (rqst->rq_rcv_buf.flags & XDRBUF_READ) + else if ((ddp_allowed && rqst->rq_rcv_buf.flags & XDRBUF_READ) && + rpcrdma_nonpayload_inline(r_xprt, rqst)) wtype = rpcrdma_writech; else wtype = rpcrdma_replych; @@ -422,466 +915,595 @@ rpcrdma_marshal_req(struct rpc_rqst *rqst) * * o If the total request is under the inline threshold, all ops * are sent as inline. - * o Large non-write ops are sent with the entire message as a - * single read chunk (protocol 0-position special case). * o Large write ops transmit data as read chunk(s), header as * inline. + * o Large non-write ops are sent with the entire message as a + * single read chunk (protocol 0-position special case). * - * Note: the NFS code sending down multiple argument segments - * implies the op is a write. - * TBD check NFSv4 setacl + * This assumes that the upper layer does not present a request + * that both has a data payload, and whose non-data arguments + * by themselves are larger than the inline threshold. */ - if (rqst->rq_snd_buf.len <= RPCRDMA_INLINE_WRITE_THRESHOLD(rqst)) - rtype = rpcrdma_noch; - else if (rqst->rq_snd_buf.page_len == 0) - rtype = rpcrdma_areadch; - else + if (rpcrdma_args_inline(r_xprt, rqst)) { + *p++ = rdma_msg; + rtype = buf->len < rdmab_length(req->rl_sendbuf) ? + rpcrdma_noch_pullup : rpcrdma_noch_mapped; + } else if (ddp_allowed && buf->flags & XDRBUF_WRITE) { + *p++ = rdma_msg; rtype = rpcrdma_readch; - - /* The following simplification is not true forever */ - if (rtype != rpcrdma_noch && wtype == rpcrdma_replych) - wtype = rpcrdma_noch; - BUG_ON(rtype != rpcrdma_noch && wtype != rpcrdma_noch); - - if (r_xprt->rx_ia.ri_memreg_strategy == RPCRDMA_BOUNCEBUFFERS && - (rtype != rpcrdma_noch || wtype != rpcrdma_noch)) { - /* forced to "pure inline"? */ - dprintk("RPC: %s: too much data (%d/%d) for inline\n", - __func__, rqst->rq_rcv_buf.len, rqst->rq_snd_buf.len); - return -1; - } - - hdrlen = 28; /*sizeof *headerp;*/ - padlen = 0; - - /* - * Pull up any extra send data into the preregistered buffer. - * When padding is in use and applies to the transfer, insert - * it and change the message type. - */ - if (rtype == rpcrdma_noch) { - - padlen = rpcrdma_inline_pullup(rqst, - RPCRDMA_INLINE_PAD_VALUE(rqst)); - - if (padlen) { - headerp->rm_type = htonl(RDMA_MSGP); - headerp->rm_body.rm_padded.rm_align = - htonl(RPCRDMA_INLINE_PAD_VALUE(rqst)); - headerp->rm_body.rm_padded.rm_thresh = - htonl(RPCRDMA_INLINE_PAD_THRESH); - headerp->rm_body.rm_padded.rm_pempty[0] = xdr_zero; - headerp->rm_body.rm_padded.rm_pempty[1] = xdr_zero; - headerp->rm_body.rm_padded.rm_pempty[2] = xdr_zero; - hdrlen += 2 * sizeof(u32); /* extra words in padhdr */ - BUG_ON(wtype != rpcrdma_noch); - - } else { - headerp->rm_body.rm_nochunks.rm_empty[0] = xdr_zero; - headerp->rm_body.rm_nochunks.rm_empty[1] = xdr_zero; - headerp->rm_body.rm_nochunks.rm_empty[2] = xdr_zero; - /* new length after pullup */ - rpclen = rqst->rq_svec[0].iov_len; - /* - * Currently we try to not actually use read inline. - * Reply chunks have the desirable property that - * they land, packed, directly in the target buffers - * without headers, so they require no fixup. The - * additional RDMA Write op sends the same amount - * of data, streams on-the-wire and adds no overhead - * on receive. Therefore, we request a reply chunk - * for non-writes wherever feasible and efficient. - */ - if (wtype == rpcrdma_noch && - r_xprt->rx_ia.ri_memreg_strategy > RPCRDMA_REGISTER) - wtype = rpcrdma_replych; - } - } - - /* - * Marshal chunks. This routine will return the header length - * consumed by marshaling. - */ - if (rtype != rpcrdma_noch) { - hdrlen = rpcrdma_create_chunks(rqst, - &rqst->rq_snd_buf, headerp, rtype); - wtype = rtype; /* simplify dprintk */ - - } else if (wtype != rpcrdma_noch) { - hdrlen = rpcrdma_create_chunks(rqst, - &rqst->rq_rcv_buf, headerp, wtype); + } else { + r_xprt->rx_stats.nomsg_call_count++; + *p++ = rdma_nomsg; + rtype = rpcrdma_areadch; } - if (hdrlen == 0) - return -1; - - dprintk("RPC: %s: %s: hdrlen %zd rpclen %zd padlen %zd" - " headerp 0x%p base 0x%p lkey 0x%x\n", - __func__, transfertypes[wtype], hdrlen, rpclen, padlen, - headerp, base, req->rl_iov.lkey); - - /* - * initialize send_iov's - normally only two: rdma chunk header and - * single preregistered RPC header buffer, but if padding is present, - * then use a preregistered (and zeroed) pad buffer between the RPC - * header and any write data. In all non-rdma cases, any following - * data has been copied into the RPC header buffer. + /* This implementation supports the following combinations + * of chunk lists in one RPC-over-RDMA Call message: + * + * - Read list + * - Write list + * - Reply chunk + * - Read list + Reply chunk + * + * It might not yet support the following combinations: + * + * - Read list + Write list + * + * It does not support the following combinations: + * + * - Write list + Reply chunk + * - Read list + Write list + Reply chunk + * + * This implementation supports only a single chunk in each + * Read or Write list. Thus for example the client cannot + * send a Call message with a Position Zero Read chunk and a + * regular Read chunk at the same time. */ - req->rl_send_iov[0].addr = req->rl_iov.addr; - req->rl_send_iov[0].length = hdrlen; - req->rl_send_iov[0].lkey = req->rl_iov.lkey; - - req->rl_send_iov[1].addr = req->rl_iov.addr + (base - req->rl_base); - req->rl_send_iov[1].length = rpclen; - req->rl_send_iov[1].lkey = req->rl_iov.lkey; - - req->rl_niovs = 2; - - if (padlen) { - struct rpcrdma_ep *ep = &r_xprt->rx_ep; + ret = rpcrdma_encode_read_list(r_xprt, req, rqst, rtype); + if (ret) + goto out_err; + ret = rpcrdma_encode_write_list(r_xprt, req, rqst, wtype); + if (ret) + goto out_err; + ret = rpcrdma_encode_reply_chunk(r_xprt, req, rqst, wtype); + if (ret) + goto out_err; + + ret = rpcrdma_prepare_send_sges(r_xprt, req, req->rl_hdrbuf.len, + buf, rtype); + if (ret) + goto out_err; + + trace_xprtrdma_marshal(req, rtype, wtype); + return 0; - req->rl_send_iov[2].addr = ep->rep_pad.addr; - req->rl_send_iov[2].length = padlen; - req->rl_send_iov[2].lkey = ep->rep_pad.lkey; +out_err: + trace_xprtrdma_marshal_failed(rqst, ret); + r_xprt->rx_stats.failed_marshal_count++; + frwr_reset(req); + return ret; +} - req->rl_send_iov[3].addr = req->rl_send_iov[1].addr + rpclen; - req->rl_send_iov[3].length = rqst->rq_slen - rpclen; - req->rl_send_iov[3].lkey = req->rl_iov.lkey; +static void __rpcrdma_update_cwnd_locked(struct rpc_xprt *xprt, + struct rpcrdma_buffer *buf, + u32 grant) +{ + buf->rb_credits = grant; + xprt->cwnd = grant << RPC_CWNDSHIFT; +} - req->rl_niovs = 4; - } +static void rpcrdma_update_cwnd(struct rpcrdma_xprt *r_xprt, u32 grant) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; - return 0; + spin_lock(&xprt->transport_lock); + __rpcrdma_update_cwnd_locked(xprt, &r_xprt->rx_buf, grant); + spin_unlock(&xprt->transport_lock); } -/* - * Chase down a received write or reply chunklist to get length - * RDMA'd by server. See map at rpcrdma_create_chunks()! :-) +/** + * rpcrdma_reset_cwnd - Reset the xprt's congestion window + * @r_xprt: controlling transport instance + * + * Prepare @r_xprt for the next connection by reinitializing + * its credit grant to one (see RFC 8166, Section 3.3.3). */ -static int -rpcrdma_count_chunks(struct rpcrdma_rep *rep, unsigned int max, int wrchunk, __be32 **iptrp) -{ - unsigned int i, total_len; - struct rpcrdma_write_chunk *cur_wchunk; - - i = ntohl(**iptrp); /* get array count */ - if (i > max) - return -1; - cur_wchunk = (struct rpcrdma_write_chunk *) (*iptrp + 1); - total_len = 0; - while (i--) { - struct rpcrdma_segment *seg = &cur_wchunk->wc_target; - ifdebug(FACILITY) { - u64 off; - xdr_decode_hyper((__be32 *)&seg->rs_offset, &off); - dprintk("RPC: %s: chunk %d@0x%llx:0x%x\n", - __func__, - ntohl(seg->rs_length), - (unsigned long long)off, - ntohl(seg->rs_handle)); - } - total_len += ntohl(seg->rs_length); - ++cur_wchunk; - } - /* check and adjust for properly terminated write chunk */ - if (wrchunk) { - __be32 *w = (__be32 *) cur_wchunk; - if (*w++ != xdr_zero) - return -1; - cur_wchunk = (struct rpcrdma_write_chunk *) w; - } - if ((char *) cur_wchunk > rep->rr_base + rep->rr_len) - return -1; +void rpcrdma_reset_cwnd(struct rpcrdma_xprt *r_xprt) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; - *iptrp = (__be32 *) cur_wchunk; - return total_len; + spin_lock(&xprt->transport_lock); + xprt->cong = 0; + __rpcrdma_update_cwnd_locked(xprt, &r_xprt->rx_buf, 1); + spin_unlock(&xprt->transport_lock); } -/* - * Scatter inline received data back into provided iov's. +/** + * rpcrdma_inline_fixup - Scatter inline received data into rqst's iovecs + * @rqst: controlling RPC request + * @srcp: points to RPC message payload in receive buffer + * @copy_len: remaining length of receive buffer content + * @pad: Write chunk pad bytes needed (zero for pure inline) + * + * The upper layer has set the maximum number of bytes it can + * receive in each component of rq_rcv_buf. These values are set in + * the head.iov_len, page_len, tail.iov_len, and buflen fields. + * + * Unlike the TCP equivalent (xdr_partial_copy_from_skb), in + * many cases this function simply updates iov_base pointers in + * rq_rcv_buf to point directly to the received reply data, to + * avoid copying reply data. + * + * Returns the count of bytes which had to be memcopied. */ -static void +static unsigned long rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) { - int i, npages, curlen, olen; + unsigned long fixup_copy_count; + int i, npages, curlen; char *destp; struct page **ppages; int page_base; + /* The head iovec is redirected to the RPC reply message + * in the receive buffer, to avoid a memcopy. + */ + rqst->rq_rcv_buf.head[0].iov_base = srcp; + rqst->rq_private_buf.head[0].iov_base = srcp; + + /* The contents of the receive buffer that follow + * head.iov_len bytes are copied into the page list. + */ curlen = rqst->rq_rcv_buf.head[0].iov_len; - if (curlen > copy_len) { /* write chunk header fixup */ + if (curlen > copy_len) curlen = copy_len; - rqst->rq_rcv_buf.head[0].iov_len = curlen; - } - - dprintk("RPC: %s: srcp 0x%p len %d hdrlen %d\n", - __func__, srcp, copy_len, curlen); - - /* Shift pointer for first receive segment only */ - rqst->rq_rcv_buf.head[0].iov_base = srcp; srcp += curlen; copy_len -= curlen; - olen = copy_len; - i = 0; - rpcx_to_rdmax(rqst->rq_xprt)->rx_stats.fixup_copy_count += olen; - page_base = rqst->rq_rcv_buf.page_base; - ppages = rqst->rq_rcv_buf.pages + (page_base >> PAGE_SHIFT); - page_base &= ~PAGE_MASK; - + ppages = rqst->rq_rcv_buf.pages + + (rqst->rq_rcv_buf.page_base >> PAGE_SHIFT); + page_base = offset_in_page(rqst->rq_rcv_buf.page_base); + fixup_copy_count = 0; if (copy_len && rqst->rq_rcv_buf.page_len) { - npages = PAGE_ALIGN(page_base + - rqst->rq_rcv_buf.page_len) >> PAGE_SHIFT; - for (; i < npages; i++) { + int pagelist_len; + + pagelist_len = rqst->rq_rcv_buf.page_len; + if (pagelist_len > copy_len) + pagelist_len = copy_len; + npages = PAGE_ALIGN(page_base + pagelist_len) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { curlen = PAGE_SIZE - page_base; - if (curlen > copy_len) - curlen = copy_len; - dprintk("RPC: %s: page %d" - " srcp 0x%p len %d curlen %d\n", - __func__, i, srcp, copy_len, curlen); + if (curlen > pagelist_len) + curlen = pagelist_len; + destp = kmap_atomic(ppages[i]); memcpy(destp + page_base, srcp, curlen); flush_dcache_page(ppages[i]); kunmap_atomic(destp); srcp += curlen; copy_len -= curlen; - if (copy_len == 0) + fixup_copy_count += curlen; + pagelist_len -= curlen; + if (!pagelist_len) break; page_base = 0; } - rqst->rq_rcv_buf.page_len = olen - copy_len; - } else - rqst->rq_rcv_buf.page_len = 0; - if (copy_len && rqst->rq_rcv_buf.tail[0].iov_len) { - curlen = copy_len; - if (curlen > rqst->rq_rcv_buf.tail[0].iov_len) - curlen = rqst->rq_rcv_buf.tail[0].iov_len; - if (rqst->rq_rcv_buf.tail[0].iov_base != srcp) - memmove(rqst->rq_rcv_buf.tail[0].iov_base, srcp, curlen); - dprintk("RPC: %s: tail srcp 0x%p len %d curlen %d\n", - __func__, srcp, copy_len, curlen); - rqst->rq_rcv_buf.tail[0].iov_len = curlen; - copy_len -= curlen; ++i; - } else - rqst->rq_rcv_buf.tail[0].iov_len = 0; - - if (pad) { - /* implicit padding on terminal chunk */ - unsigned char *p = rqst->rq_rcv_buf.tail[0].iov_base; - while (pad--) - p[rqst->rq_rcv_buf.tail[0].iov_len++] = 0; + /* Implicit padding for the last segment in a Write + * chunk is inserted inline at the front of the tail + * iovec. The upper layer ignores the content of + * the pad. Simply ensure inline content in the tail + * that follows the Write chunk is properly aligned. + */ + if (pad) + srcp -= pad; } - if (copy_len) - dprintk("RPC: %s: %d bytes in" - " %d extra segments (%d lost)\n", - __func__, olen, i, copy_len); + /* The tail iovec is redirected to the remaining data + * in the receive buffer, to avoid a memcopy. + */ + if (copy_len || pad) { + rqst->rq_rcv_buf.tail[0].iov_base = srcp; + rqst->rq_private_buf.tail[0].iov_base = srcp; + } - /* TBD avoid a warning from call_decode() */ - rqst->rq_private_buf = rqst->rq_rcv_buf; + if (fixup_copy_count) + trace_xprtrdma_fixup(rqst, fixup_copy_count); + return fixup_copy_count; } -/* - * This function is called when an async event is posted to - * the connection which changes the connection state. All it - * does at this point is mark the connection up/down, the rpc - * timers do the rest. - */ -void -rpcrdma_conn_func(struct rpcrdma_ep *ep) -{ - struct rpc_xprt *xprt = ep->rep_xprt; - - spin_lock_bh(&xprt->transport_lock); - if (++xprt->connect_cookie == 0) /* maintain a reserved value */ - ++xprt->connect_cookie; - if (ep->rep_connected > 0) { - if (!xprt_test_and_set_connected(xprt)) - xprt_wake_pending_tasks(xprt, 0); - } else { - if (xprt_test_and_clear_connected(xprt)) - xprt_wake_pending_tasks(xprt, -ENOTCONN); +/* By convention, backchannel calls arrive via rdma_msg type + * messages, and never populate the chunk lists. This makes + * the RPC/RDMA header small and fixed in size, so it is + * straightforward to check the RPC header's direction field. + */ +static bool +rpcrdma_is_bcall(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct xdr_stream *xdr = &rep->rr_stream; + __be32 *p; + + if (rep->rr_proc != rdma_msg) + return false; + + /* Peek at stream contents without advancing. */ + p = xdr_inline_decode(xdr, 0); + + /* Chunk lists */ + if (xdr_item_is_present(p++)) + return false; + if (xdr_item_is_present(p++)) + return false; + if (xdr_item_is_present(p++)) + return false; + + /* RPC header */ + if (*p++ != rep->rr_xid) + return false; + if (*p != cpu_to_be32(RPC_CALL)) + return false; + + /* No bc service. */ + if (xprt->bc_serv == NULL) + return false; + + /* Now that we are sure this is a backchannel call, + * advance to the RPC header. + */ + p = xdr_inline_decode(xdr, 3 * sizeof(*p)); + if (unlikely(!p)) + return true; + + rpcrdma_bc_receive_call(r_xprt, rep); + return true; +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +{ + return false; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static int decode_rdma_segment(struct xdr_stream *xdr, u32 *length) +{ + u32 handle; + u64 offset; + __be32 *p; + + p = xdr_inline_decode(xdr, 4 * sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + xdr_decode_rdma_segment(p, &handle, length, &offset); + trace_xprtrdma_decode_seg(handle, *length, offset); + return 0; +} + +static int decode_write_chunk(struct xdr_stream *xdr, u32 *length) +{ + u32 segcount, seglength; + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + *length = 0; + segcount = be32_to_cpup(p); + while (segcount--) { + if (decode_rdma_segment(xdr, &seglength)) + return -EIO; + *length += seglength; } - spin_unlock_bh(&xprt->transport_lock); + + return 0; } -/* - * This function is called when memory window unbind which we are waiting - * for completes. Just use rr_func (zeroed by upcall) to signal completion. +/* In RPC-over-RDMA Version One replies, a Read list is never + * expected. This decoder is a stub that returns an error if + * a Read list is present. */ -static void -rpcrdma_unbind_func(struct rpcrdma_rep *rep) +static int decode_read_list(struct xdr_stream *xdr) { - wake_up(&rep->rr_unbind); + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + if (unlikely(xdr_item_is_present(p))) + return -EIO; + return 0; } -/* - * Called as a tasklet to do req/reply match and complete a request - * Errors must result in the RPC task either being awakened, or - * allowed to timeout, to discover the errors at that time. +/* Supports only one Write chunk in the Write list */ -void -rpcrdma_reply_handler(struct rpcrdma_rep *rep) +static int decode_write_list(struct xdr_stream *xdr, u32 *length) { - struct rpcrdma_msg *headerp; - struct rpcrdma_req *req; - struct rpc_rqst *rqst; - struct rpc_xprt *xprt = rep->rr_xprt; - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - __be32 *iptr; - int i, rdmalen, status; - - /* Check status. If bad, signal disconnect and return rep to pool */ - if (rep->rr_len == ~0U) { - rpcrdma_recv_buffer_put(rep); - if (r_xprt->rx_ep.rep_connected == 1) { - r_xprt->rx_ep.rep_connected = -EIO; - rpcrdma_conn_func(&r_xprt->rx_ep); - } - return; - } - if (rep->rr_len < 28) { - dprintk("RPC: %s: short/invalid reply\n", __func__); - goto repost; - } - headerp = (struct rpcrdma_msg *) rep->rr_base; - if (headerp->rm_vers != xdr_one) { - dprintk("RPC: %s: invalid version %d\n", - __func__, ntohl(headerp->rm_vers)); - goto repost; - } + u32 chunklen; + bool first; + __be32 *p; - /* Get XID and try for a match. */ - spin_lock(&xprt->transport_lock); - rqst = xprt_lookup_rqst(xprt, headerp->rm_xid); - if (rqst == NULL) { - spin_unlock(&xprt->transport_lock); - dprintk("RPC: %s: reply 0x%p failed " - "to match any request xid 0x%08x len %d\n", - __func__, rep, headerp->rm_xid, rep->rr_len); -repost: - r_xprt->rx_stats.bad_reply_count++; - rep->rr_func = rpcrdma_reply_handler; - if (rpcrdma_ep_post_recv(&r_xprt->rx_ia, &r_xprt->rx_ep, rep)) - rpcrdma_recv_buffer_put(rep); + *length = 0; + first = true; + do { + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + if (xdr_item_is_absent(p)) + break; + if (!first) + return -EIO; + + if (decode_write_chunk(xdr, &chunklen)) + return -EIO; + *length += chunklen; + first = false; + } while (true); + return 0; +} - return; - } +static int decode_reply_chunk(struct xdr_stream *xdr, u32 *length) +{ + __be32 *p; - /* get request object */ - req = rpcr_to_rdmar(rqst); - if (req->rl_reply) { - spin_unlock(&xprt->transport_lock); - dprintk("RPC: %s: duplicate reply 0x%p to RPC " - "request 0x%p: xid 0x%08x\n", __func__, rep, req, - headerp->rm_xid); - goto repost; - } + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; - dprintk("RPC: %s: reply 0x%p completes request 0x%p\n" - " RPC request 0x%p xid 0x%08x\n", - __func__, rep, req, rqst, headerp->rm_xid); + *length = 0; + if (xdr_item_is_present(p)) + if (decode_write_chunk(xdr, length)) + return -EIO; + return 0; +} - /* from here on, the reply is no longer an orphan */ - req->rl_reply = rep; +static int +rpcrdma_decode_msg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep, + struct rpc_rqst *rqst) +{ + struct xdr_stream *xdr = &rep->rr_stream; + u32 writelist, replychunk, rpclen; + char *base; - /* check for expected message types */ - /* The order of some of these tests is important. */ - switch (headerp->rm_type) { - case htonl(RDMA_MSG): - /* never expect read chunks */ - /* never expect reply chunks (two ways to check) */ - /* never expect write chunks without having offered RDMA */ - if (headerp->rm_body.rm_chunks[0] != xdr_zero || - (headerp->rm_body.rm_chunks[1] == xdr_zero && - headerp->rm_body.rm_chunks[2] != xdr_zero) || - (headerp->rm_body.rm_chunks[1] != xdr_zero && - req->rl_nchunks == 0)) - goto badheader; - if (headerp->rm_body.rm_chunks[1] != xdr_zero) { - /* count any expected write chunks in read reply */ - /* start at write chunk array count */ - iptr = &headerp->rm_body.rm_chunks[2]; - rdmalen = rpcrdma_count_chunks(rep, - req->rl_nchunks, 1, &iptr); - /* check for validity, and no reply chunk after */ - if (rdmalen < 0 || *iptr++ != xdr_zero) - goto badheader; - rep->rr_len -= - ((unsigned char *)iptr - (unsigned char *)headerp); - status = rep->rr_len + rdmalen; - r_xprt->rx_stats.total_rdma_reply += rdmalen; - /* special case - last chunk may omit padding */ - if (rdmalen &= 3) { - rdmalen = 4 - rdmalen; - status += rdmalen; - } - } else { - /* else ordinary inline */ - rdmalen = 0; - iptr = (__be32 *)((unsigned char *)headerp + 28); - rep->rr_len -= 28; /*sizeof *headerp;*/ - status = rep->rr_len; - } - /* Fix up the rpc results for upper layer */ - rpcrdma_inline_fixup(rqst, (char *)iptr, rep->rr_len, rdmalen); - break; + /* Decode the chunk lists */ + if (decode_read_list(xdr)) + return -EIO; + if (decode_write_list(xdr, &writelist)) + return -EIO; + if (decode_reply_chunk(xdr, &replychunk)) + return -EIO; + + /* RDMA_MSG sanity checks */ + if (unlikely(replychunk)) + return -EIO; + + /* Build the RPC reply's Payload stream in rqst->rq_rcv_buf */ + base = (char *)xdr_inline_decode(xdr, 0); + rpclen = xdr_stream_remaining(xdr); + r_xprt->rx_stats.fixup_copy_count += + rpcrdma_inline_fixup(rqst, base, rpclen, writelist & 3); + + r_xprt->rx_stats.total_rdma_reply += writelist; + return rpclen + xdr_align_size(writelist); +} - case htonl(RDMA_NOMSG): - /* never expect read or write chunks, always reply chunks */ - if (headerp->rm_body.rm_chunks[0] != xdr_zero || - headerp->rm_body.rm_chunks[1] != xdr_zero || - headerp->rm_body.rm_chunks[2] != xdr_one || - req->rl_nchunks == 0) - goto badheader; - iptr = (__be32 *)((unsigned char *)headerp + 28); - rdmalen = rpcrdma_count_chunks(rep, req->rl_nchunks, 0, &iptr); - if (rdmalen < 0) - goto badheader; - r_xprt->rx_stats.total_rdma_reply += rdmalen; - /* Reply chunk buffer already is the reply vector - no fixup. */ - status = rdmalen; - break; +static noinline int +rpcrdma_decode_nomsg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep) +{ + struct xdr_stream *xdr = &rep->rr_stream; + u32 writelist, replychunk; + + /* Decode the chunk lists */ + if (decode_read_list(xdr)) + return -EIO; + if (decode_write_list(xdr, &writelist)) + return -EIO; + if (decode_reply_chunk(xdr, &replychunk)) + return -EIO; + + /* RDMA_NOMSG sanity checks */ + if (unlikely(writelist)) + return -EIO; + if (unlikely(!replychunk)) + return -EIO; + + /* Reply chunk buffer already is the reply vector */ + r_xprt->rx_stats.total_rdma_reply += replychunk; + return replychunk; +} -badheader: - default: - dprintk("%s: invalid rpcrdma reply header (type %d):" - " chunks[012] == %d %d %d" - " expected chunks <= %d\n", - __func__, ntohl(headerp->rm_type), - headerp->rm_body.rm_chunks[0], - headerp->rm_body.rm_chunks[1], - headerp->rm_body.rm_chunks[2], - req->rl_nchunks); - status = -EIO; - r_xprt->rx_stats.bad_reply_count++; +static noinline int +rpcrdma_decode_error(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep, + struct rpc_rqst *rqst) +{ + struct xdr_stream *xdr = &rep->rr_stream; + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + switch (*p) { + case err_vers: + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + break; + trace_xprtrdma_err_vers(rqst, p, p + 1); + break; + case err_chunk: + trace_xprtrdma_err_chunk(rqst); break; + default: + trace_xprtrdma_err_unrecognized(rqst, p); } - /* If using mw bind, start the deregister process now. */ - /* (Note: if mr_free(), cannot perform it here, in tasklet context) */ - if (req->rl_nchunks) switch (r_xprt->rx_ia.ri_memreg_strategy) { - case RPCRDMA_MEMWINDOWS: - for (i = 0; req->rl_nchunks-- > 1;) - i += rpcrdma_deregister_external( - &req->rl_segments[i], r_xprt, NULL); - /* Optionally wait (not here) for unbinds to complete */ - rep->rr_func = rpcrdma_unbind_func; - (void) rpcrdma_deregister_external(&req->rl_segments[i], - r_xprt, rep); + return -EIO; +} + +/** + * rpcrdma_unpin_rqst - Release rqst without completing it + * @rep: RPC/RDMA Receive context + * + * This is done when a connection is lost so that a Reply + * can be dropped and its matching Call can be subsequently + * retransmitted on a new connection. + */ +void rpcrdma_unpin_rqst(struct rpcrdma_rep *rep) +{ + struct rpc_xprt *xprt = &rep->rr_rxprt->rx_xprt; + struct rpc_rqst *rqst = rep->rr_rqst; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + + req->rl_reply = NULL; + rep->rr_rqst = NULL; + + spin_lock(&xprt->queue_lock); + xprt_unpin_rqst(rqst); + spin_unlock(&xprt->queue_lock); +} + +/** + * rpcrdma_complete_rqst - Pass completed rqst back to RPC + * @rep: RPC/RDMA Receive context + * + * Reconstruct the RPC reply and complete the transaction + * while @rqst is still pinned to ensure the rep, rqst, and + * rq_task pointers remain stable. + */ +void rpcrdma_complete_rqst(struct rpcrdma_rep *rep) +{ + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpc_rqst *rqst = rep->rr_rqst; + int status; + + switch (rep->rr_proc) { + case rdma_msg: + status = rpcrdma_decode_msg(r_xprt, rep, rqst); break; - case RPCRDMA_MEMWINDOWS_ASYNC: - for (i = 0; req->rl_nchunks--;) - i += rpcrdma_deregister_external(&req->rl_segments[i], - r_xprt, NULL); + case rdma_nomsg: + status = rpcrdma_decode_nomsg(r_xprt, rep); break; - default: + case rdma_error: + status = rpcrdma_decode_error(r_xprt, rep, rqst); break; + default: + status = -EIO; } + if (status < 0) + goto out_badheader; - dprintk("RPC: %s: xprt_complete_rqst(0x%p, 0x%p, %d)\n", - __func__, xprt, rqst, status); +out: + spin_lock(&xprt->queue_lock); xprt_complete_rqst(rqst->rq_task, status); - spin_unlock(&xprt->transport_lock); + xprt_unpin_rqst(rqst); + spin_unlock(&xprt->queue_lock); + return; + +out_badheader: + trace_xprtrdma_reply_hdr_err(rep); + r_xprt->rx_stats.bad_reply_count++; + rqst->rq_task->tk_status = status; + status = 0; + goto out; +} + +static void rpcrdma_reply_done(struct kref *kref) +{ + struct rpcrdma_req *req = + container_of(kref, struct rpcrdma_req, rl_kref); + + rpcrdma_complete_rqst(req->rl_reply); +} + +/** + * rpcrdma_reply_handler - Process received RPC/RDMA messages + * @rep: Incoming rpcrdma_rep object to process + * + * Errors must result in the RPC task either being awakened, or + * allowed to timeout, to discover the errors at that time. + */ +void rpcrdma_reply_handler(struct rpcrdma_rep *rep) +{ + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + u32 credits; + __be32 *p; + + /* Any data means we had a useful conversation, so + * then we don't need to delay the next reconnect. + */ + if (xprt->reestablish_timeout) + xprt->reestablish_timeout = 0; + + /* Fixed transport header fields */ + xdr_init_decode(&rep->rr_stream, &rep->rr_hdrbuf, + rep->rr_hdrbuf.head[0].iov_base, NULL); + p = xdr_inline_decode(&rep->rr_stream, 4 * sizeof(*p)); + if (unlikely(!p)) + goto out_shortreply; + rep->rr_xid = *p++; + rep->rr_vers = *p++; + credits = be32_to_cpu(*p++); + rep->rr_proc = *p++; + + if (rep->rr_vers != rpcrdma_version) + goto out_badversion; + + if (rpcrdma_is_bcall(r_xprt, rep)) + return; + + /* Match incoming rpcrdma_rep to an rpcrdma_req to + * get context for handling any incoming chunks. + */ + spin_lock(&xprt->queue_lock); + rqst = xprt_lookup_rqst(xprt, rep->rr_xid); + if (!rqst) + goto out_norqst; + xprt_pin_rqst(rqst); + spin_unlock(&xprt->queue_lock); + + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > r_xprt->rx_ep->re_max_requests) + credits = r_xprt->rx_ep->re_max_requests; + rpcrdma_post_recvs(r_xprt, credits + (buf->rb_bc_srv_max_requests << 1)); + if (buf->rb_credits != credits) + rpcrdma_update_cwnd(r_xprt, credits); + + req = rpcr_to_rdmar(rqst); + if (unlikely(req->rl_reply)) + rpcrdma_rep_put(buf, req->rl_reply); + req->rl_reply = rep; + rep->rr_rqst = rqst; + + trace_xprtrdma_reply(rqst->rq_task, rep, credits); + + if (rep->rr_wc_flags & IB_WC_WITH_INVALIDATE) + frwr_reminv(rep, &req->rl_registered); + if (!list_empty(&req->rl_registered)) + frwr_unmap_async(r_xprt, req); + /* LocalInv completion will complete the RPC */ + else + kref_put(&req->rl_kref, rpcrdma_reply_done); + return; + +out_badversion: + trace_xprtrdma_reply_vers_err(rep); + goto out; + +out_norqst: + spin_unlock(&xprt->queue_lock); + trace_xprtrdma_reply_rqst_err(rep); + goto out; + +out_shortreply: + trace_xprtrdma_reply_short_err(rep); + +out: + rpcrdma_rep_put(buf, rep); } diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c index c1b6270262c2..415c0310101f 100644 --- a/net/sunrpc/xprtrdma/svc_rdma.c +++ b/net/sunrpc/xprtrdma/svc_rdma.c @@ -1,4 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2015-2018 Oracle. All rights reserved. * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -38,8 +40,7 @@ * * Author: Tom Tucker <tom@opengridcomputing.com> */ -#include <linux/module.h> -#include <linux/init.h> + #include <linux/slab.h> #include <linux/fs.h> #include <linux/sysctl.h> @@ -47,74 +48,61 @@ #include <linux/sunrpc/clnt.h> #include <linux/sunrpc/sched.h> #include <linux/sunrpc/svc_rdma.h> -#include "xprt_rdma.h" #define RPCDBG_FACILITY RPCDBG_SVCXPRT /* RPC/RDMA parameters */ -unsigned int svcrdma_ord = RPCRDMA_ORD; +unsigned int svcrdma_ord = 16; /* historical default */ static unsigned int min_ord = 1; -static unsigned int max_ord = 4096; +static unsigned int max_ord = 255; unsigned int svcrdma_max_requests = RPCRDMA_MAX_REQUESTS; +unsigned int svcrdma_max_bc_requests = RPCRDMA_MAX_BC_REQUESTS; static unsigned int min_max_requests = 4; static unsigned int max_max_requests = 16384; -unsigned int svcrdma_max_req_size = RPCRDMA_MAX_REQ_SIZE; -static unsigned int min_max_inline = 4096; -static unsigned int max_max_inline = 65536; - -atomic_t rdma_stat_recv; -atomic_t rdma_stat_read; -atomic_t rdma_stat_write; -atomic_t rdma_stat_sq_starve; -atomic_t rdma_stat_rq_starve; -atomic_t rdma_stat_rq_poll; -atomic_t rdma_stat_rq_prod; -atomic_t rdma_stat_sq_poll; -atomic_t rdma_stat_sq_prod; +unsigned int svcrdma_max_req_size = RPCRDMA_DEF_INLINE_THRESH; +static unsigned int min_max_inline = RPCRDMA_DEF_INLINE_THRESH; +static unsigned int max_max_inline = RPCRDMA_MAX_INLINE_THRESH; +static unsigned int svcrdma_stat_unused; +static unsigned int zero; -/* Temporary NFS request map and context caches */ -struct kmem_cache *svc_rdma_map_cachep; -struct kmem_cache *svc_rdma_ctxt_cachep; +struct percpu_counter svcrdma_stat_read; +struct percpu_counter svcrdma_stat_recv; +struct percpu_counter svcrdma_stat_sq_starve; +struct percpu_counter svcrdma_stat_write; -struct workqueue_struct *svc_rdma_wq; +enum { + SVCRDMA_COUNTER_BUFSIZ = sizeof(unsigned long long), +}; -/* - * This function implements reading and resetting an atomic_t stat - * variable through read/write to a proc file. Any write to the file - * resets the associated statistic to zero. Any read returns it's - * current value. - */ -static int read_reset_stat(struct ctl_table *table, int write, - void __user *buffer, size_t *lenp, - loff_t *ppos) +static int svcrdma_counter_handler(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) { - atomic_t *stat = (atomic_t *)table->data; + struct percpu_counter *stat = (struct percpu_counter *)table->data; + char tmp[SVCRDMA_COUNTER_BUFSIZ + 1]; + int len; - if (!stat) - return -EINVAL; + if (write) { + percpu_counter_set(stat, 0); + return 0; + } - if (write) - atomic_set(stat, 0); - else { - char str_buf[32]; - char *data; - int len = snprintf(str_buf, 32, "%d\n", atomic_read(stat)); - if (len >= 32) - return -EFAULT; - len = strlen(str_buf); - if (*ppos > len) { - *lenp = 0; - return 0; - } - data = &str_buf[*ppos]; - len -= *ppos; - if (len > *lenp) - len = *lenp; - if (len && copy_to_user(buffer, str_buf, len)) - return -EFAULT; - *lenp = len; - *ppos += len; + len = snprintf(tmp, SVCRDMA_COUNTER_BUFSIZ, "%lld\n", + percpu_counter_sum_positive(stat)); + if (len >= SVCRDMA_COUNTER_BUFSIZ) + return -EFAULT; + len = strlen(tmp); + if (*ppos > len) { + *lenp = 0; + return 0; } + len -= *ppos; + if (len > *lenp) + len = *lenp; + if (len) + memcpy(buffer, tmp, len); + *lenp = len; + *ppos += len; + return 0; } @@ -150,153 +138,170 @@ static struct ctl_table svcrdma_parm_table[] = { { .procname = "rdma_stat_read", - .data = &rdma_stat_read, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_read, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = svcrdma_counter_handler, }, { .procname = "rdma_stat_recv", - .data = &rdma_stat_recv, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_recv, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = svcrdma_counter_handler, }, { .procname = "rdma_stat_write", - .data = &rdma_stat_write, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_write, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = svcrdma_counter_handler, }, { .procname = "rdma_stat_sq_starve", - .data = &rdma_stat_sq_starve, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_sq_starve, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = svcrdma_counter_handler, }, { .procname = "rdma_stat_rq_starve", - .data = &rdma_stat_rq_starve, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, }, { .procname = "rdma_stat_rq_poll", - .data = &rdma_stat_rq_poll, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, }, { .procname = "rdma_stat_rq_prod", - .data = &rdma_stat_rq_prod, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, }, { .procname = "rdma_stat_sq_poll", - .data = &rdma_stat_sq_poll, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, }, { .procname = "rdma_stat_sq_prod", - .data = &rdma_stat_sq_prod, - .maxlen = sizeof(atomic_t), + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = read_reset_stat, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, }, - { }, }; -static struct ctl_table svcrdma_table[] = { - { - .procname = "svc_rdma", - .mode = 0555, - .child = svcrdma_parm_table - }, - { }, -}; +static void svc_rdma_proc_cleanup(void) +{ + if (!svcrdma_table_header) + return; + unregister_sysctl_table(svcrdma_table_header); + svcrdma_table_header = NULL; -static struct ctl_table svcrdma_root_table[] = { - { - .procname = "sunrpc", - .mode = 0555, - .child = svcrdma_table - }, - { }, -}; + percpu_counter_destroy(&svcrdma_stat_write); + percpu_counter_destroy(&svcrdma_stat_sq_starve); + percpu_counter_destroy(&svcrdma_stat_recv); + percpu_counter_destroy(&svcrdma_stat_read); +} + +static int svc_rdma_proc_init(void) +{ + int rc; + + if (svcrdma_table_header) + return 0; + + rc = percpu_counter_init(&svcrdma_stat_read, 0, GFP_KERNEL); + if (rc) + goto err; + rc = percpu_counter_init(&svcrdma_stat_recv, 0, GFP_KERNEL); + if (rc) + goto err_read; + rc = percpu_counter_init(&svcrdma_stat_sq_starve, 0, GFP_KERNEL); + if (rc) + goto err_recv; + rc = percpu_counter_init(&svcrdma_stat_write, 0, GFP_KERNEL); + if (rc) + goto err_sq; + + svcrdma_table_header = register_sysctl("sunrpc/svc_rdma", + svcrdma_parm_table); + if (!svcrdma_table_header) + goto err_write; + + return 0; + +err_write: + rc = -ENOMEM; + percpu_counter_destroy(&svcrdma_stat_write); +err_sq: + percpu_counter_destroy(&svcrdma_stat_sq_starve); +err_recv: + percpu_counter_destroy(&svcrdma_stat_recv); +err_read: + percpu_counter_destroy(&svcrdma_stat_read); +err: + return rc; +} + +struct workqueue_struct *svcrdma_wq; void svc_rdma_cleanup(void) { - dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n"); - destroy_workqueue(svc_rdma_wq); - if (svcrdma_table_header) { - unregister_sysctl_table(svcrdma_table_header); - svcrdma_table_header = NULL; - } svc_unreg_xprt_class(&svc_rdma_class); - kmem_cache_destroy(svc_rdma_map_cachep); - kmem_cache_destroy(svc_rdma_ctxt_cachep); + svc_rdma_proc_cleanup(); + if (svcrdma_wq) { + struct workqueue_struct *wq = svcrdma_wq; + + svcrdma_wq = NULL; + destroy_workqueue(wq); + } + + dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n"); } int svc_rdma_init(void) { - dprintk("SVCRDMA Module Init, register RPC RDMA transport\n"); - dprintk("\tsvcrdma_ord : %d\n", svcrdma_ord); - dprintk("\tmax_requests : %d\n", svcrdma_max_requests); - dprintk("\tsq_depth : %d\n", - svcrdma_max_requests * RPCRDMA_SQ_DEPTH_MULT); - dprintk("\tmax_inline : %d\n", svcrdma_max_req_size); + struct workqueue_struct *wq; + int rc; - svc_rdma_wq = alloc_workqueue("svc_rdma", 0, 0); - if (!svc_rdma_wq) + wq = alloc_workqueue("svcrdma", WQ_UNBOUND, 0); + if (!wq) return -ENOMEM; - if (!svcrdma_table_header) - svcrdma_table_header = - register_sysctl_table(svcrdma_root_table); - - /* Create the temporary map cache */ - svc_rdma_map_cachep = kmem_cache_create("svc_rdma_map_cache", - sizeof(struct svc_rdma_req_map), - 0, - SLAB_HWCACHE_ALIGN, - NULL); - if (!svc_rdma_map_cachep) { - printk(KERN_INFO "Could not allocate map cache.\n"); - goto err0; + rc = svc_rdma_proc_init(); + if (rc) { + destroy_workqueue(wq); + return rc; } - /* Create the temporary context cache */ - svc_rdma_ctxt_cachep = - kmem_cache_create("svc_rdma_ctxt_cache", - sizeof(struct svc_rdma_op_ctxt), - 0, - SLAB_HWCACHE_ALIGN, - NULL); - if (!svc_rdma_ctxt_cachep) { - printk(KERN_INFO "Could not allocate WR ctxt cache.\n"); - goto err1; - } - - /* Register RDMA with the SVC transport switch */ + svcrdma_wq = wq; svc_reg_xprt_class(&svc_rdma_class); + + dprintk("SVCRDMA Module Init, register RPC RDMA transport\n"); + dprintk("\tsvcrdma_ord : %d\n", svcrdma_ord); + dprintk("\tmax_requests : %u\n", svcrdma_max_requests); + dprintk("\tmax_bc_requests : %u\n", svcrdma_max_bc_requests); + dprintk("\tmax_inline : %d\n", svcrdma_max_req_size); return 0; - err1: - kmem_cache_destroy(svc_rdma_map_cachep); - err0: - unregister_sysctl_table(svcrdma_table_header); - destroy_workqueue(svc_rdma_wq); - return -ENOMEM; } -MODULE_AUTHOR("Tom Tucker <tom@opengridcomputing.com>"); -MODULE_DESCRIPTION("SVC RDMA Transport"); -MODULE_LICENSE("Dual BSD/GPL"); -module_init(svc_rdma_init); -module_exit(svc_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c new file mode 100644 index 000000000000..e5a78b761012 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * + * Support for reverse-direction RPCs on RPC/RDMA (server-side). + */ + +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +/** + * svc_rdma_handle_bc_reply - Process incoming backchannel Reply + * @rqstp: resources for handling the Reply + * @rctxt: Received message + * + */ +void svc_rdma_handle_bc_reply(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *rctxt) +{ + struct svc_xprt *sxprt = rqstp->rq_xprt; + struct rpc_xprt *xprt = sxprt->xpt_bc_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct xdr_buf *rcvbuf = &rqstp->rq_arg; + struct kvec *dst, *src = &rcvbuf->head[0]; + __be32 *rdma_resp = rctxt->rc_recv_buf; + struct rpc_rqst *req; + u32 credits; + + spin_lock(&xprt->queue_lock); + req = xprt_lookup_rqst(xprt, *rdma_resp); + if (!req) + goto out_unlock; + + dst = &req->rq_private_buf.head[0]; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); + if (dst->iov_len < src->iov_len) + goto out_unlock; + memcpy(dst->iov_base, src->iov_base, src->iov_len); + xprt_pin_rqst(req); + spin_unlock(&xprt->queue_lock); + + credits = be32_to_cpup(rdma_resp + 2); + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > r_xprt->rx_buf.rb_bc_max_requests) + credits = r_xprt->rx_buf.rb_bc_max_requests; + spin_lock(&xprt->transport_lock); + xprt->cwnd = credits << RPC_CWNDSHIFT; + spin_unlock(&xprt->transport_lock); + + spin_lock(&xprt->queue_lock); + xprt_complete_rqst(req->rq_task, rcvbuf->len); + xprt_unpin_rqst(req); + rcvbuf->len = 0; + +out_unlock: + spin_unlock(&xprt->queue_lock); +} + +/* Send a reverse-direction RPC Call. + * + * Caller holds the connection's mutex and has already marshaled + * the RPC/RDMA request. + * + * This is similar to svc_rdma_send_reply_msg, but takes a struct + * rpc_rqst instead, does not support chunks, and avoids blocking + * memory allocation. + * + * XXX: There is still an opportunity to block in svc_rdma_send() + * if there are no SQ entries to post the Send. This may occur if + * the adapter has a small maximum SQ depth. + */ +static int svc_rdma_bc_sendto(struct svcxprt_rdma *rdma, + struct rpc_rqst *rqst, + struct svc_rdma_send_ctxt *sctxt) +{ + struct svc_rdma_pcl empty_pcl; + int ret; + + pcl_init(&empty_pcl); + ret = svc_rdma_map_reply_msg(rdma, sctxt, &empty_pcl, &empty_pcl, + &rqst->rq_snd_buf); + if (ret < 0) + return -EIO; + + /* Bump page refcnt so Send completion doesn't release + * the rq_buffer before all retransmits are complete. + */ + get_page(virt_to_page(rqst->rq_buffer)); + sctxt->sc_send_wr.opcode = IB_WR_SEND; + return svc_rdma_post_send(rdma, sctxt); +} + +/* Server-side transport endpoint wants a whole page for its send + * buffer. The client RPC code constructs the RPC header in this + * buffer before it invokes ->send_request. + */ +static int +xprt_rdma_bc_allocate(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize; + struct page *page; + + if (size > PAGE_SIZE) { + WARN_ONCE(1, "svcrdma: large bc buffer request (size %zu)\n", + size); + return -EINVAL; + } + + page = alloc_page(GFP_NOIO | __GFP_NOWARN); + if (!page) + return -ENOMEM; + rqst->rq_buffer = page_address(page); + + rqst->rq_rbuffer = kmalloc(rqst->rq_rcvsize, GFP_NOIO | __GFP_NOWARN); + if (!rqst->rq_rbuffer) { + put_page(page); + return -ENOMEM; + } + return 0; +} + +static void +xprt_rdma_bc_free(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + + put_page(virt_to_page(rqst->rq_buffer)); + kfree(rqst->rq_rbuffer); +} + +static int +rpcrdma_bc_send_request(struct svcxprt_rdma *rdma, struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct svc_rdma_send_ctxt *ctxt; + __be32 *p; + int rc; + + ctxt = svc_rdma_send_ctxt_get(rdma); + if (!ctxt) + goto drop_connection; + + p = xdr_reserve_space(&ctxt->sc_stream, RPCRDMA_HDRLEN_MIN); + if (!p) + goto put_ctxt; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_max_requests); + *p++ = rdma_msg; + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + + rqst->rq_xtime = ktime_get(); + rc = svc_rdma_bc_sendto(rdma, rqst, ctxt); + if (rc) + goto put_ctxt; + return 0; + +put_ctxt: + svc_rdma_send_ctxt_put(rdma, ctxt); + +drop_connection: + return -ENOTCONN; +} + +/** + * xprt_rdma_bc_send_request - Send a reverse-direction Call + * @rqst: rpc_rqst containing Call message to be sent + * + * Return values: + * %0 if the message was sent successfully + * %ENOTCONN if the message was not sent + */ +static int xprt_rdma_bc_send_request(struct rpc_rqst *rqst) +{ + struct svc_xprt *sxprt = rqst->rq_xprt->bc_xprt; + struct svcxprt_rdma *rdma = + container_of(sxprt, struct svcxprt_rdma, sc_xprt); + int ret; + + if (test_bit(XPT_DEAD, &sxprt->xpt_flags)) + return -ENOTCONN; + + ret = rpcrdma_bc_send_request(rdma, rqst); + if (ret == -ENOTCONN) + svc_xprt_close(sxprt); + return ret; +} + +static void +xprt_rdma_bc_close(struct rpc_xprt *xprt) +{ + xprt_disconnect_done(xprt); + xprt->cwnd = RPC_CWNDSHIFT; +} + +static void +xprt_rdma_bc_put(struct rpc_xprt *xprt) +{ + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); +} + +static const struct rpc_xprt_ops xprt_rdma_bc_procs = { + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .release_request = xprt_release_rqst_cong, + .buf_alloc = xprt_rdma_bc_allocate, + .buf_free = xprt_rdma_bc_free, + .send_request = xprt_rdma_bc_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, + .close = xprt_rdma_bc_close, + .destroy = xprt_rdma_bc_put, + .print_stats = xprt_rdma_print_stats +}; + +static const struct rpc_timeout xprt_rdma_bc_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, +}; + +/* It shouldn't matter if the number of backchannel session slots + * doesn't match the number of RPC/RDMA credits. That just means + * one or the other will have extra slots that aren't used. + */ +static struct rpc_xprt * +xprt_setup_rdma_bc(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + struct rpcrdma_xprt *new_xprt; + + if (args->addrlen > sizeof(xprt->addr)) + return ERR_PTR(-EBADF); + + xprt = xprt_alloc(args->net, sizeof(*new_xprt), + RPCRDMA_MAX_BC_REQUESTS, + RPCRDMA_MAX_BC_REQUESTS); + if (!xprt) + return ERR_PTR(-ENOMEM); + + xprt->timeout = &xprt_rdma_bc_timeout; + xprt_set_bound(xprt); + xprt_set_connected(xprt); + xprt->bind_timeout = 0; + xprt->reestablish_timeout = 0; + xprt->idle_timeout = 0; + + xprt->prot = XPRT_TRANSPORT_BC_RDMA; + xprt->ops = &xprt_rdma_bc_procs; + + memcpy(&xprt->addr, args->dstaddr, args->addrlen); + xprt->addrlen = args->addrlen; + xprt_rdma_format_addresses(xprt, (struct sockaddr *)&xprt->addr); + xprt->resvport = 0; + + xprt->max_payload = xprt_rdma_max_inline_read; + + new_xprt = rpcx_to_rdmax(xprt); + new_xprt->rx_buf.rb_bc_max_requests = xprt->max_reqs; + + xprt_get(xprt); + args->bc_xprt->xpt_bc_xprt = xprt; + xprt->bc_xprt = args->bc_xprt; + + /* Final put for backchannel xprt is in __svc_rdma_free */ + xprt_get(xprt); + return xprt; +} + +struct xprt_class xprt_rdma_bc = { + .list = LIST_HEAD_INIT(xprt_rdma_bc.list), + .name = "rdma backchannel", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_BC_RDMA, + .setup = xprt_setup_rdma_bc, +}; diff --git a/net/sunrpc/xprtrdma/svc_rdma_marshal.c b/net/sunrpc/xprtrdma/svc_rdma_marshal.c deleted file mode 100644 index 8d2edddf48cf..000000000000 --- a/net/sunrpc/xprtrdma/svc_rdma_marshal.c +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. - * - * This software is available to you under a choice of one of two - * licenses. You may choose to be licensed under the terms of the GNU - * General Public License (GPL) Version 2, available from the file - * COPYING in the main directory of this source tree, or the BSD-type - * license below: - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following - * disclaimer in the documentation and/or other materials provided - * with the distribution. - * - * Neither the name of the Network Appliance, Inc. nor the names of - * its contributors may be used to endorse or promote products - * derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * Author: Tom Tucker <tom@opengridcomputing.com> - */ - -#include <linux/sunrpc/xdr.h> -#include <linux/sunrpc/debug.h> -#include <asm/unaligned.h> -#include <linux/sunrpc/rpc_rdma.h> -#include <linux/sunrpc/svc_rdma.h> - -#define RPCDBG_FACILITY RPCDBG_SVCXPRT - -/* - * Decodes a read chunk list. The expected format is as follows: - * descrim : xdr_one - * position : u32 offset into XDR stream - * handle : u32 RKEY - * . . . - * end-of-list: xdr_zero - */ -static u32 *decode_read_list(u32 *va, u32 *vaend) -{ - struct rpcrdma_read_chunk *ch = (struct rpcrdma_read_chunk *)va; - - while (ch->rc_discrim != xdr_zero) { - if (((unsigned long)ch + sizeof(struct rpcrdma_read_chunk)) > - (unsigned long)vaend) { - dprintk("svcrdma: vaend=%p, ch=%p\n", vaend, ch); - return NULL; - } - ch++; - } - return (u32 *)&ch->rc_position; -} - -/* - * Determine number of chunks and total bytes in chunk list. The chunk - * list has already been verified to fit within the RPCRDMA header. - */ -void svc_rdma_rcl_chunk_counts(struct rpcrdma_read_chunk *ch, - int *ch_count, int *byte_count) -{ - /* compute the number of bytes represented by read chunks */ - *byte_count = 0; - *ch_count = 0; - for (; ch->rc_discrim != 0; ch++) { - *byte_count = *byte_count + ntohl(ch->rc_target.rs_length); - *ch_count = *ch_count + 1; - } -} - -/* - * Decodes a write chunk list. The expected format is as follows: - * descrim : xdr_one - * nchunks : <count> - * handle : u32 RKEY ---+ - * length : u32 <len of segment> | - * offset : remove va + <count> - * . . . | - * ---+ - */ -static u32 *decode_write_list(u32 *va, u32 *vaend) -{ - int nchunks; - - struct rpcrdma_write_array *ary = - (struct rpcrdma_write_array *)va; - - /* Check for not write-array */ - if (ary->wc_discrim == xdr_zero) - return (u32 *)&ary->wc_nchunks; - - if ((unsigned long)ary + sizeof(struct rpcrdma_write_array) > - (unsigned long)vaend) { - dprintk("svcrdma: ary=%p, vaend=%p\n", ary, vaend); - return NULL; - } - nchunks = ntohl(ary->wc_nchunks); - if (((unsigned long)&ary->wc_array[0] + - (sizeof(struct rpcrdma_write_chunk) * nchunks)) > - (unsigned long)vaend) { - dprintk("svcrdma: ary=%p, wc_nchunks=%d, vaend=%p\n", - ary, nchunks, vaend); - return NULL; - } - /* - * rs_length is the 2nd 4B field in wc_target and taking its - * address skips the list terminator - */ - return (u32 *)&ary->wc_array[nchunks].wc_target.rs_length; -} - -static u32 *decode_reply_array(u32 *va, u32 *vaend) -{ - int nchunks; - struct rpcrdma_write_array *ary = - (struct rpcrdma_write_array *)va; - - /* Check for no reply-array */ - if (ary->wc_discrim == xdr_zero) - return (u32 *)&ary->wc_nchunks; - - if ((unsigned long)ary + sizeof(struct rpcrdma_write_array) > - (unsigned long)vaend) { - dprintk("svcrdma: ary=%p, vaend=%p\n", ary, vaend); - return NULL; - } - nchunks = ntohl(ary->wc_nchunks); - if (((unsigned long)&ary->wc_array[0] + - (sizeof(struct rpcrdma_write_chunk) * nchunks)) > - (unsigned long)vaend) { - dprintk("svcrdma: ary=%p, wc_nchunks=%d, vaend=%p\n", - ary, nchunks, vaend); - return NULL; - } - return (u32 *)&ary->wc_array[nchunks]; -} - -int svc_rdma_xdr_decode_req(struct rpcrdma_msg **rdma_req, - struct svc_rqst *rqstp) -{ - struct rpcrdma_msg *rmsgp = NULL; - u32 *va; - u32 *vaend; - u32 hdr_len; - - rmsgp = (struct rpcrdma_msg *)rqstp->rq_arg.head[0].iov_base; - - /* Verify that there's enough bytes for header + something */ - if (rqstp->rq_arg.len <= RPCRDMA_HDRLEN_MIN) { - dprintk("svcrdma: header too short = %d\n", - rqstp->rq_arg.len); - return -EINVAL; - } - - /* Decode the header */ - rmsgp->rm_xid = ntohl(rmsgp->rm_xid); - rmsgp->rm_vers = ntohl(rmsgp->rm_vers); - rmsgp->rm_credit = ntohl(rmsgp->rm_credit); - rmsgp->rm_type = ntohl(rmsgp->rm_type); - - if (rmsgp->rm_vers != RPCRDMA_VERSION) - return -ENOSYS; - - /* Pull in the extra for the padded case and bump our pointer */ - if (rmsgp->rm_type == RDMA_MSGP) { - int hdrlen; - rmsgp->rm_body.rm_padded.rm_align = - ntohl(rmsgp->rm_body.rm_padded.rm_align); - rmsgp->rm_body.rm_padded.rm_thresh = - ntohl(rmsgp->rm_body.rm_padded.rm_thresh); - - va = &rmsgp->rm_body.rm_padded.rm_pempty[4]; - rqstp->rq_arg.head[0].iov_base = va; - hdrlen = (u32)((unsigned long)va - (unsigned long)rmsgp); - rqstp->rq_arg.head[0].iov_len -= hdrlen; - if (hdrlen > rqstp->rq_arg.len) - return -EINVAL; - return hdrlen; - } - - /* The chunk list may contain either a read chunk list or a write - * chunk list and a reply chunk list. - */ - va = &rmsgp->rm_body.rm_chunks[0]; - vaend = (u32 *)((unsigned long)rmsgp + rqstp->rq_arg.len); - va = decode_read_list(va, vaend); - if (!va) - return -EINVAL; - va = decode_write_list(va, vaend); - if (!va) - return -EINVAL; - va = decode_reply_array(va, vaend); - if (!va) - return -EINVAL; - - rqstp->rq_arg.head[0].iov_base = va; - hdr_len = (unsigned long)va - (unsigned long)rmsgp; - rqstp->rq_arg.head[0].iov_len -= hdr_len; - - *rdma_req = rmsgp; - return hdr_len; -} - -int svc_rdma_xdr_decode_deferred_req(struct svc_rqst *rqstp) -{ - struct rpcrdma_msg *rmsgp = NULL; - struct rpcrdma_read_chunk *ch; - struct rpcrdma_write_array *ary; - u32 *va; - u32 hdrlen; - - dprintk("svcrdma: processing deferred RDMA header on rqstp=%p\n", - rqstp); - rmsgp = (struct rpcrdma_msg *)rqstp->rq_arg.head[0].iov_base; - - /* Pull in the extra for the padded case and bump our pointer */ - if (rmsgp->rm_type == RDMA_MSGP) { - va = &rmsgp->rm_body.rm_padded.rm_pempty[4]; - rqstp->rq_arg.head[0].iov_base = va; - hdrlen = (u32)((unsigned long)va - (unsigned long)rmsgp); - rqstp->rq_arg.head[0].iov_len -= hdrlen; - return hdrlen; - } - - /* - * Skip all chunks to find RPC msg. These were previously processed - */ - va = &rmsgp->rm_body.rm_chunks[0]; - - /* Skip read-list */ - for (ch = (struct rpcrdma_read_chunk *)va; - ch->rc_discrim != xdr_zero; ch++); - va = (u32 *)&ch->rc_position; - - /* Skip write-list */ - ary = (struct rpcrdma_write_array *)va; - if (ary->wc_discrim == xdr_zero) - va = (u32 *)&ary->wc_nchunks; - else - /* - * rs_length is the 2nd 4B field in wc_target and taking its - * address skips the list terminator - */ - va = (u32 *)&ary->wc_array[ary->wc_nchunks].wc_target.rs_length; - - /* Skip reply-array */ - ary = (struct rpcrdma_write_array *)va; - if (ary->wc_discrim == xdr_zero) - va = (u32 *)&ary->wc_nchunks; - else - va = (u32 *)&ary->wc_array[ary->wc_nchunks]; - - rqstp->rq_arg.head[0].iov_base = va; - hdrlen = (unsigned long)va - (unsigned long)rmsgp; - rqstp->rq_arg.head[0].iov_len -= hdrlen; - - return hdrlen; -} - -int svc_rdma_xdr_encode_error(struct svcxprt_rdma *xprt, - struct rpcrdma_msg *rmsgp, - enum rpcrdma_errcode err, u32 *va) -{ - u32 *startp = va; - - *va++ = htonl(rmsgp->rm_xid); - *va++ = htonl(rmsgp->rm_vers); - *va++ = htonl(xprt->sc_max_requests); - *va++ = htonl(RDMA_ERROR); - *va++ = htonl(err); - if (err == ERR_VERS) { - *va++ = htonl(RPCRDMA_VERSION); - *va++ = htonl(RPCRDMA_VERSION); - } - - return (int)((unsigned long)va - (unsigned long)startp); -} - -int svc_rdma_xdr_get_reply_hdr_len(struct rpcrdma_msg *rmsgp) -{ - struct rpcrdma_write_array *wr_ary; - - /* There is no read-list in a reply */ - - /* skip write list */ - wr_ary = (struct rpcrdma_write_array *) - &rmsgp->rm_body.rm_chunks[1]; - if (wr_ary->wc_discrim) - wr_ary = (struct rpcrdma_write_array *) - &wr_ary->wc_array[ntohl(wr_ary->wc_nchunks)]. - wc_target.rs_length; - else - wr_ary = (struct rpcrdma_write_array *) - &wr_ary->wc_nchunks; - - /* skip reply array */ - if (wr_ary->wc_discrim) - wr_ary = (struct rpcrdma_write_array *) - &wr_ary->wc_array[ntohl(wr_ary->wc_nchunks)]; - else - wr_ary = (struct rpcrdma_write_array *) - &wr_ary->wc_nchunks; - - return (unsigned long) wr_ary - (unsigned long) rmsgp; -} - -void svc_rdma_xdr_encode_write_list(struct rpcrdma_msg *rmsgp, int chunks) -{ - struct rpcrdma_write_array *ary; - - /* no read-list */ - rmsgp->rm_body.rm_chunks[0] = xdr_zero; - - /* write-array discrim */ - ary = (struct rpcrdma_write_array *) - &rmsgp->rm_body.rm_chunks[1]; - ary->wc_discrim = xdr_one; - ary->wc_nchunks = htonl(chunks); - - /* write-list terminator */ - ary->wc_array[chunks].wc_target.rs_handle = xdr_zero; - - /* reply-array discriminator */ - ary->wc_array[chunks].wc_target.rs_length = xdr_zero; -} - -void svc_rdma_xdr_encode_reply_array(struct rpcrdma_write_array *ary, - int chunks) -{ - ary->wc_discrim = xdr_one; - ary->wc_nchunks = htonl(chunks); -} - -void svc_rdma_xdr_encode_array_chunk(struct rpcrdma_write_array *ary, - int chunk_no, - __be32 rs_handle, - __be64 rs_offset, - u32 write_len) -{ - struct rpcrdma_segment *seg = &ary->wc_array[chunk_no].wc_target; - seg->rs_handle = rs_handle; - seg->rs_offset = rs_offset; - seg->rs_length = htonl(write_len); -} - -void svc_rdma_xdr_encode_reply_header(struct svcxprt_rdma *xprt, - struct rpcrdma_msg *rdma_argp, - struct rpcrdma_msg *rdma_resp, - enum rpcrdma_proc rdma_type) -{ - rdma_resp->rm_xid = htonl(rdma_argp->rm_xid); - rdma_resp->rm_vers = htonl(rdma_argp->rm_vers); - rdma_resp->rm_credit = htonl(xprt->sc_max_requests); - rdma_resp->rm_type = htonl(rdma_type); - - /* Encode <nul> chunks lists */ - rdma_resp->rm_body.rm_chunks[0] = xdr_zero; - rdma_resp->rm_body.rm_chunks[1] = xdr_zero; - rdma_resp->rm_body.rm_chunks[2] = xdr_zero; -} diff --git a/net/sunrpc/xprtrdma/svc_rdma_pcl.c b/net/sunrpc/xprtrdma/svc_rdma_pcl.c new file mode 100644 index 000000000000..b63cfeaa2923 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_pcl.c @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2020 Oracle. All rights reserved. + */ + +#include <linux/sunrpc/svc_rdma.h> +#include <linux/sunrpc/rpc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +/** + * pcl_free - Release all memory associated with a parsed chunk list + * @pcl: parsed chunk list + * + */ +void pcl_free(struct svc_rdma_pcl *pcl) +{ + while (!list_empty(&pcl->cl_chunks)) { + struct svc_rdma_chunk *chunk; + + chunk = pcl_first_chunk(pcl); + list_del(&chunk->ch_list); + kfree(chunk); + } +} + +static struct svc_rdma_chunk *pcl_alloc_chunk(u32 segcount, u32 position) +{ + struct svc_rdma_chunk *chunk; + + chunk = kmalloc(struct_size(chunk, ch_segments, segcount), GFP_KERNEL); + if (!chunk) + return NULL; + + chunk->ch_position = position; + chunk->ch_length = 0; + chunk->ch_payload_length = 0; + chunk->ch_segcount = 0; + return chunk; +} + +static struct svc_rdma_chunk * +pcl_lookup_position(struct svc_rdma_pcl *pcl, u32 position) +{ + struct svc_rdma_chunk *pos; + + pcl_for_each_chunk(pos, pcl) { + if (pos->ch_position == position) + return pos; + } + return NULL; +} + +static void pcl_insert_position(struct svc_rdma_pcl *pcl, + struct svc_rdma_chunk *chunk) +{ + struct svc_rdma_chunk *pos; + + pcl_for_each_chunk(pos, pcl) { + if (pos->ch_position > chunk->ch_position) + break; + } + __list_add(&chunk->ch_list, pos->ch_list.prev, &pos->ch_list); + pcl->cl_count++; +} + +static void pcl_set_read_segment(const struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_chunk *chunk, + u32 handle, u32 length, u64 offset) +{ + struct svc_rdma_segment *segment; + + segment = &chunk->ch_segments[chunk->ch_segcount]; + segment->rs_handle = handle; + segment->rs_length = length; + segment->rs_offset = offset; + + trace_svcrdma_decode_rseg(&rctxt->rc_cid, chunk, segment); + + chunk->ch_length += length; + chunk->ch_segcount++; +} + +/** + * pcl_alloc_call - Construct a parsed chunk list for the Call body + * @rctxt: Ingress receive context + * @p: Start of an un-decoded Read list + * + * Assumptions: + * - The incoming Read list has already been sanity checked. + * - cl_count is already set to the number of segments in + * the un-decoded list. + * - The list might not be in order by position. + * + * Return values: + * %true: Parsed chunk list was successfully constructed, and + * cl_count is updated to be the number of chunks (ie. + * unique positions) in the Read list. + * %false: Memory allocation failed. + */ +bool pcl_alloc_call(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) +{ + struct svc_rdma_pcl *pcl = &rctxt->rc_call_pcl; + unsigned int i, segcount = pcl->cl_count; + + pcl->cl_count = 0; + for (i = 0; i < segcount; i++) { + struct svc_rdma_chunk *chunk; + u32 position, handle, length; + u64 offset; + + p++; /* skip the list discriminator */ + p = xdr_decode_read_segment(p, &position, &handle, + &length, &offset); + if (position != 0) + continue; + + if (pcl_is_empty(pcl)) { + chunk = pcl_alloc_chunk(segcount, position); + if (!chunk) + return false; + pcl_insert_position(pcl, chunk); + } else { + chunk = list_first_entry(&pcl->cl_chunks, + struct svc_rdma_chunk, + ch_list); + } + + pcl_set_read_segment(rctxt, chunk, handle, length, offset); + } + + return true; +} + +/** + * pcl_alloc_read - Construct a parsed chunk list for normal Read chunks + * @rctxt: Ingress receive context + * @p: Start of an un-decoded Read list + * + * Assumptions: + * - The incoming Read list has already been sanity checked. + * - cl_count is already set to the number of segments in + * the un-decoded list. + * - The list might not be in order by position. + * + * Return values: + * %true: Parsed chunk list was successfully constructed, and + * cl_count is updated to be the number of chunks (ie. + * unique position values) in the Read list. + * %false: Memory allocation failed. + * + * TODO: + * - Check for chunk range overlaps + */ +bool pcl_alloc_read(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) +{ + struct svc_rdma_pcl *pcl = &rctxt->rc_read_pcl; + unsigned int i, segcount = pcl->cl_count; + + pcl->cl_count = 0; + for (i = 0; i < segcount; i++) { + struct svc_rdma_chunk *chunk; + u32 position, handle, length; + u64 offset; + + p++; /* skip the list discriminator */ + p = xdr_decode_read_segment(p, &position, &handle, + &length, &offset); + if (position == 0) + continue; + + chunk = pcl_lookup_position(pcl, position); + if (!chunk) { + chunk = pcl_alloc_chunk(segcount, position); + if (!chunk) + return false; + pcl_insert_position(pcl, chunk); + } + + pcl_set_read_segment(rctxt, chunk, handle, length, offset); + } + + return true; +} + +/** + * pcl_alloc_write - Construct a parsed chunk list from a Write list + * @rctxt: Ingress receive context + * @pcl: Parsed chunk list to populate + * @p: Start of an un-decoded Write list + * + * Assumptions: + * - The incoming Write list has already been sanity checked, and + * - cl_count is set to the number of chunks in the un-decoded list. + * + * Return values: + * %true: Parsed chunk list was successfully constructed. + * %false: Memory allocation failed. + */ +bool pcl_alloc_write(struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_pcl *pcl, __be32 *p) +{ + struct svc_rdma_segment *segment; + struct svc_rdma_chunk *chunk; + unsigned int i, j; + u32 segcount; + + for (i = 0; i < pcl->cl_count; i++) { + p++; /* skip the list discriminator */ + segcount = be32_to_cpup(p++); + + chunk = pcl_alloc_chunk(segcount, 0); + if (!chunk) + return false; + list_add_tail(&chunk->ch_list, &pcl->cl_chunks); + + for (j = 0; j < segcount; j++) { + segment = &chunk->ch_segments[j]; + p = xdr_decode_rdma_segment(p, &segment->rs_handle, + &segment->rs_length, + &segment->rs_offset); + trace_svcrdma_decode_wseg(&rctxt->rc_cid, chunk, j); + + chunk->ch_length += segment->rs_length; + chunk->ch_segcount++; + } + } + return true; +} + +static int pcl_process_region(const struct xdr_buf *xdr, + unsigned int offset, unsigned int length, + int (*actor)(const struct xdr_buf *, void *), + void *data) +{ + struct xdr_buf subbuf; + + if (!length) + return 0; + if (xdr_buf_subsegment(xdr, &subbuf, offset, length)) + return -EMSGSIZE; + return actor(&subbuf, data); +} + +/** + * pcl_process_nonpayloads - Process non-payload regions inside @xdr + * @pcl: Chunk list to process + * @xdr: xdr_buf to process + * @actor: Function to invoke on each non-payload region + * @data: Arguments for @actor + * + * This mechanism must ignore not only result payloads that were already + * sent via RDMA Write, but also XDR padding for those payloads that + * the upper layer has added. + * + * Assumptions: + * The xdr->len and ch_position fields are aligned to 4-byte multiples. + * + * Returns: + * On success, zero, + * %-EMSGSIZE on XDR buffer overflow, or + * The return value of @actor + */ +int pcl_process_nonpayloads(const struct svc_rdma_pcl *pcl, + const struct xdr_buf *xdr, + int (*actor)(const struct xdr_buf *, void *), + void *data) +{ + struct svc_rdma_chunk *chunk, *next; + unsigned int start; + int ret; + + chunk = pcl_first_chunk(pcl); + + /* No result payloads were generated */ + if (!chunk || !chunk->ch_payload_length) + return actor(xdr, data); + + /* Process the region before the first result payload */ + ret = pcl_process_region(xdr, 0, chunk->ch_position, actor, data); + if (ret < 0) + return ret; + + /* Process the regions between each middle result payload */ + while ((next = pcl_next_chunk(pcl, chunk))) { + if (!next->ch_payload_length) + break; + + start = pcl_chunk_end_offset(chunk); + ret = pcl_process_region(xdr, start, next->ch_position - start, + actor, data); + if (ret < 0) + return ret; + + chunk = next; + } + + /* Process the region after the last result payload */ + start = pcl_chunk_end_offset(chunk); + ret = pcl_process_region(xdr, start, xdr->len - start, actor, data); + if (ret < 0) + return ret; + + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c index 0ce75524ed21..e7e4a39ca6c6 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -1,4 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -39,649 +42,979 @@ * Author: Tom Tucker <tom@opengridcomputing.com> */ -#include <linux/sunrpc/debug.h> -#include <linux/sunrpc/rpc_rdma.h> +/* Operation + * + * The main entry point is svc_rdma_recvfrom. This is called from + * svc_recv when the transport indicates there is incoming data to + * be read. "Data Ready" is signaled when an RDMA Receive completes, + * or when a set of RDMA Reads complete. + * + * An svc_rqst is passed in. This structure contains an array of + * free pages (rq_pages) that will contain the incoming RPC message. + * + * Short messages are moved directly into svc_rqst::rq_arg, and + * the RPC Call is ready to be processed by the Upper Layer. + * svc_rdma_recvfrom returns the length of the RPC Call message, + * completing the reception of the RPC Call. + * + * However, when an incoming message has Read chunks, + * svc_rdma_recvfrom must post RDMA Reads to pull the RPC Call's + * data payload from the client. svc_rdma_recvfrom sets up the + * RDMA Reads using pages in svc_rqst::rq_pages, which are + * transferred to an svc_rdma_recv_ctxt for the duration of the + * I/O. svc_rdma_recvfrom then returns zero, since the RPC message + * is still not yet ready. + * + * When the Read chunk payloads have become available on the + * server, "Data Ready" is raised again, and svc_recv calls + * svc_rdma_recvfrom again. This second call may use a different + * svc_rqst than the first one, thus any information that needs + * to be preserved across these two calls is kept in an + * svc_rdma_recv_ctxt. + * + * The second call to svc_rdma_recvfrom performs final assembly + * of the RPC Call message, using the RDMA Read sink pages kept in + * the svc_rdma_recv_ctxt. The xdr_buf is copied from the + * svc_rdma_recv_ctxt to the second svc_rqst. The second call returns + * the length of the completed RPC Call message. + * + * Page Management + * + * Pages under I/O must be transferred from the first svc_rqst to an + * svc_rdma_recv_ctxt before the first svc_rdma_recvfrom call returns. + * + * The first svc_rqst supplies pages for RDMA Reads. These are moved + * from rqstp::rq_pages into ctxt::pages. The consumed elements of + * the rq_pages array are set to NULL and refilled with the first + * svc_rdma_recvfrom call returns. + * + * During the second svc_rdma_recvfrom call, RDMA Read sink pages + * are transferred from the svc_rdma_recv_ctxt to the second svc_rqst. + */ + +#include <linux/slab.h> #include <linux/spinlock.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> #include <rdma/ib_verbs.h> #include <rdma/rdma_cm.h> + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> #include <linux/sunrpc/svc_rdma.h> -#define RPCDBG_FACILITY RPCDBG_SVCXPRT +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> -/* - * Replace the pages in the rq_argpages array with the pages from the SGE in - * the RDMA_RECV completion. The SGL should contain full pages up until the - * last one. +static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc); + +static inline struct svc_rdma_recv_ctxt * +svc_rdma_next_recv_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_recv_ctxt, + rc_list); +} + +static struct svc_rdma_recv_ctxt * +svc_rdma_recv_ctxt_alloc(struct svcxprt_rdma *rdma) +{ + int node = ibdev_to_node(rdma->sc_cm_id->device); + struct svc_rdma_recv_ctxt *ctxt; + unsigned long pages; + dma_addr_t addr; + void *buffer; + + pages = svc_serv_maxpages(rdma->sc_xprt.xpt_server); + ctxt = kzalloc_node(struct_size(ctxt, rc_pages, pages), + GFP_KERNEL, node); + if (!ctxt) + goto fail0; + ctxt->rc_maxpages = pages; + buffer = kmalloc_node(rdma->sc_max_req_size, GFP_KERNEL, node); + if (!buffer) + goto fail1; + addr = ib_dma_map_single(rdma->sc_pd->device, buffer, + rdma->sc_max_req_size, DMA_FROM_DEVICE); + if (ib_dma_mapping_error(rdma->sc_pd->device, addr)) + goto fail2; + + svc_rdma_recv_cid_init(rdma, &ctxt->rc_cid); + pcl_init(&ctxt->rc_call_pcl); + pcl_init(&ctxt->rc_read_pcl); + pcl_init(&ctxt->rc_write_pcl); + pcl_init(&ctxt->rc_reply_pcl); + + ctxt->rc_recv_wr.next = NULL; + ctxt->rc_recv_wr.wr_cqe = &ctxt->rc_cqe; + ctxt->rc_recv_wr.sg_list = &ctxt->rc_recv_sge; + ctxt->rc_recv_wr.num_sge = 1; + ctxt->rc_cqe.done = svc_rdma_wc_receive; + ctxt->rc_recv_sge.addr = addr; + ctxt->rc_recv_sge.length = rdma->sc_max_req_size; + ctxt->rc_recv_sge.lkey = rdma->sc_pd->local_dma_lkey; + ctxt->rc_recv_buf = buffer; + svc_rdma_cc_init(rdma, &ctxt->rc_cc); + return ctxt; + +fail2: + kfree(buffer); +fail1: + kfree(ctxt); +fail0: + return NULL; +} + +static void svc_rdma_recv_ctxt_destroy(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + ib_dma_unmap_single(rdma->sc_pd->device, ctxt->rc_recv_sge.addr, + ctxt->rc_recv_sge.length, DMA_FROM_DEVICE); + kfree(ctxt->rc_recv_buf); + kfree(ctxt); +} + +/** + * svc_rdma_recv_ctxts_destroy - Release all recv_ctxt's for an xprt + * @rdma: svcxprt_rdma being torn down + * */ -static void rdma_build_arg_xdr(struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *ctxt, - u32 byte_count) +void svc_rdma_recv_ctxts_destroy(struct svcxprt_rdma *rdma) { - struct page *page; - u32 bc; - int sge_no; - - /* Swap the page in the SGE with the page in argpages */ - page = ctxt->pages[0]; - put_page(rqstp->rq_pages[0]); - rqstp->rq_pages[0] = page; - - /* Set up the XDR head */ - rqstp->rq_arg.head[0].iov_base = page_address(page); - rqstp->rq_arg.head[0].iov_len = min(byte_count, ctxt->sge[0].length); - rqstp->rq_arg.len = byte_count; - rqstp->rq_arg.buflen = byte_count; - - /* Compute bytes past head in the SGL */ - bc = byte_count - rqstp->rq_arg.head[0].iov_len; - - /* If data remains, store it in the pagelist */ - rqstp->rq_arg.page_len = bc; - rqstp->rq_arg.page_base = 0; - rqstp->rq_arg.pages = &rqstp->rq_pages[1]; - sge_no = 1; - while (bc && sge_no < ctxt->count) { - page = ctxt->pages[sge_no]; - put_page(rqstp->rq_pages[sge_no]); - rqstp->rq_pages[sge_no] = page; - bc -= min(bc, ctxt->sge[sge_no].length); - rqstp->rq_arg.buflen += ctxt->sge[sge_no].length; - sge_no++; + struct svc_rdma_recv_ctxt *ctxt; + struct llist_node *node; + + while ((node = llist_del_first(&rdma->sc_recv_ctxts))) { + ctxt = llist_entry(node, struct svc_rdma_recv_ctxt, rc_node); + svc_rdma_recv_ctxt_destroy(rdma, ctxt); } - rqstp->rq_respages = &rqstp->rq_pages[sge_no]; +} + +/** + * svc_rdma_recv_ctxt_get - Allocate a recv_ctxt + * @rdma: controlling svcxprt_rdma + * + * Returns a recv_ctxt or (rarely) NULL if none are available. + */ +struct svc_rdma_recv_ctxt *svc_rdma_recv_ctxt_get(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + struct llist_node *node; + + node = llist_del_first(&rdma->sc_recv_ctxts); + if (!node) + return NULL; + + ctxt = llist_entry(node, struct svc_rdma_recv_ctxt, rc_node); + ctxt->rc_page_count = 0; + return ctxt; +} + +/** + * svc_rdma_recv_ctxt_put - Return recv_ctxt to free list + * @rdma: controlling svcxprt_rdma + * @ctxt: object to return to the free list + * + */ +void svc_rdma_recv_ctxt_put(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + svc_rdma_cc_release(rdma, &ctxt->rc_cc, DMA_FROM_DEVICE); - /* We should never run out of SGE because the limit is defined to - * support the max allowed RPC data length + /* @rc_page_count is normally zero here, but error flows + * can leave pages in @rc_pages. */ - BUG_ON(bc && (sge_no == ctxt->count)); - BUG_ON((rqstp->rq_arg.head[0].iov_len + rqstp->rq_arg.page_len) - != byte_count); - BUG_ON(rqstp->rq_arg.len != byte_count); - - /* If not all pages were used from the SGL, free the remaining ones */ - bc = sge_no; - while (sge_no < ctxt->count) { - page = ctxt->pages[sge_no++]; - put_page(page); + release_pages(ctxt->rc_pages, ctxt->rc_page_count); + + pcl_free(&ctxt->rc_call_pcl); + pcl_free(&ctxt->rc_read_pcl); + pcl_free(&ctxt->rc_write_pcl); + pcl_free(&ctxt->rc_reply_pcl); + + llist_add(&ctxt->rc_node, &rdma->sc_recv_ctxts); +} + +/** + * svc_rdma_release_ctxt - Release transport-specific per-rqst resources + * @xprt: the transport which owned the context + * @vctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt + * + * Ensure that the recv_ctxt is released whether or not a Reply + * was sent. For example, the client could close the connection, + * or svc_process could drop an RPC, before the Reply is sent. + */ +void svc_rdma_release_ctxt(struct svc_xprt *xprt, void *vctxt) +{ + struct svc_rdma_recv_ctxt *ctxt = vctxt; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + if (ctxt) + svc_rdma_recv_ctxt_put(rdma, ctxt); +} + +static bool svc_rdma_refresh_recvs(struct svcxprt_rdma *rdma, + unsigned int wanted) +{ + const struct ib_recv_wr *bad_wr = NULL; + struct svc_rdma_recv_ctxt *ctxt; + struct ib_recv_wr *recv_chain; + int ret; + + if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags)) + return false; + + recv_chain = NULL; + while (wanted--) { + ctxt = svc_rdma_recv_ctxt_get(rdma); + if (!ctxt) + break; + + trace_svcrdma_post_recv(&ctxt->rc_cid); + ctxt->rc_recv_wr.next = recv_chain; + recv_chain = &ctxt->rc_recv_wr; + rdma->sc_pending_recvs++; } - ctxt->count = bc; + if (!recv_chain) + return true; - /* Set up tail */ - rqstp->rq_arg.tail[0].iov_base = NULL; - rqstp->rq_arg.tail[0].iov_len = 0; + ret = ib_post_recv(rdma->sc_qp, recv_chain, &bad_wr); + if (ret) + goto err_free; + return true; + +err_free: + trace_svcrdma_rq_post_err(rdma, ret); + while (bad_wr) { + ctxt = container_of(bad_wr, struct svc_rdma_recv_ctxt, + rc_recv_wr); + bad_wr = bad_wr->next; + svc_rdma_recv_ctxt_put(rdma, ctxt); + } + /* Since we're destroying the xprt, no need to reset + * sc_pending_recvs. */ + return false; } -/* Encode a read-chunk-list as an array of IB SGE +/** + * svc_rdma_post_recvs - Post initial set of Recv WRs + * @rdma: fresh svcxprt_rdma * - * Assumptions: - * - chunk[0]->position points to pages[0] at an offset of 0 - * - pages[] is not physically or virtually contiguous and consists of - * PAGE_SIZE elements. + * Return values: + * %true: Receive Queue initialization successful + * %false: memory allocation or DMA error + */ +bool svc_rdma_post_recvs(struct svcxprt_rdma *rdma) +{ + unsigned int total; + + /* For each credit, allocate enough recv_ctxts for one + * posted Receive and one RPC in process. + */ + total = (rdma->sc_max_requests * 2) + rdma->sc_recv_batch; + while (total--) { + struct svc_rdma_recv_ctxt *ctxt; + + ctxt = svc_rdma_recv_ctxt_alloc(rdma); + if (!ctxt) + return false; + llist_add(&ctxt->rc_node, &rdma->sc_recv_ctxts); + } + + return svc_rdma_refresh_recvs(rdma, rdma->sc_max_requests); +} + +/** + * svc_rdma_wc_receive - Invoked by RDMA provider for each polled Receive WC + * @cq: Completion Queue context + * @wc: Work Completion object * - * Output: - * - sge array pointing into pages[] array. - * - chunk_sge array specifying sge index and count for each - * chunk in the read list + */ +static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_recv_ctxt *ctxt; + + rdma->sc_pending_recvs--; + + /* WARNING: Only wc->wr_cqe and wc->status are reliable */ + ctxt = container_of(cqe, struct svc_rdma_recv_ctxt, rc_cqe); + + if (wc->status != IB_WC_SUCCESS) + goto flushed; + trace_svcrdma_wc_recv(wc, &ctxt->rc_cid); + + /* If receive posting fails, the connection is about to be + * lost anyway. The server will not be able to send a reply + * for this RPC, and the client will retransmit this RPC + * anyway when it reconnects. + * + * Therefore we drop the Receive, even if status was SUCCESS + * to reduce the likelihood of replayed requests once the + * client reconnects. + */ + if (rdma->sc_pending_recvs < rdma->sc_max_requests) + if (!svc_rdma_refresh_recvs(rdma, rdma->sc_recv_batch)) + goto dropped; + + /* All wc fields are now known to be valid */ + ctxt->rc_byte_len = wc->byte_len; + + spin_lock(&rdma->sc_rq_dto_lock); + list_add_tail(&ctxt->rc_list, &rdma->sc_rq_dto_q); + /* Note the unlock pairs with the smp_rmb in svc_xprt_ready: */ + set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); + spin_unlock(&rdma->sc_rq_dto_lock); + if (!test_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags)) + svc_xprt_enqueue(&rdma->sc_xprt); + return; + +flushed: + if (wc->status == IB_WC_WR_FLUSH_ERR) + trace_svcrdma_wc_recv_flush(wc, &ctxt->rc_cid); + else + trace_svcrdma_wc_recv_err(wc, &ctxt->rc_cid); +dropped: + svc_rdma_recv_ctxt_put(rdma, ctxt); + svc_xprt_deferred_close(&rdma->sc_xprt); +} + +/** + * svc_rdma_flush_recv_queues - Drain pending Receive work + * @rdma: svcxprt_rdma being shut down * */ -static int map_read_chunks(struct svcxprt_rdma *xprt, - struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *head, - struct rpcrdma_msg *rmsgp, - struct svc_rdma_req_map *rpl_map, - struct svc_rdma_req_map *chl_map, - int ch_count, - int byte_count) +void svc_rdma_flush_recv_queues(struct svcxprt_rdma *rdma) { - int sge_no; - int sge_bytes; - int page_off; - int page_no; - int ch_bytes; - int ch_no; - struct rpcrdma_read_chunk *ch; - - sge_no = 0; - page_no = 0; - page_off = 0; - ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; - ch_no = 0; - ch_bytes = ntohl(ch->rc_target.rs_length); - head->arg.head[0] = rqstp->rq_arg.head[0]; - head->arg.tail[0] = rqstp->rq_arg.tail[0]; - head->arg.pages = &head->pages[head->count]; - head->hdr_count = head->count; /* save count of hdr pages */ - head->arg.page_base = 0; - head->arg.page_len = ch_bytes; - head->arg.len = rqstp->rq_arg.len + ch_bytes; - head->arg.buflen = rqstp->rq_arg.buflen + ch_bytes; - head->count++; - chl_map->ch[0].start = 0; - while (byte_count) { - rpl_map->sge[sge_no].iov_base = - page_address(rqstp->rq_arg.pages[page_no]) + page_off; - sge_bytes = min_t(int, PAGE_SIZE-page_off, ch_bytes); - rpl_map->sge[sge_no].iov_len = sge_bytes; - /* - * Don't bump head->count here because the same page - * may be used by multiple SGE. - */ - head->arg.pages[page_no] = rqstp->rq_arg.pages[page_no]; - rqstp->rq_respages = &rqstp->rq_arg.pages[page_no+1]; - - byte_count -= sge_bytes; - ch_bytes -= sge_bytes; - sge_no++; - /* - * If all bytes for this chunk have been mapped to an - * SGE, move to the next SGE - */ - if (ch_bytes == 0) { - chl_map->ch[ch_no].count = - sge_no - chl_map->ch[ch_no].start; - ch_no++; - ch++; - chl_map->ch[ch_no].start = sge_no; - ch_bytes = ntohl(ch->rc_target.rs_length); - /* If bytes remaining account for next chunk */ - if (byte_count) { - head->arg.page_len += ch_bytes; - head->arg.len += ch_bytes; - head->arg.buflen += ch_bytes; - } - } - /* - * If this SGE consumed all of the page, move to the - * next page - */ - if ((sge_bytes + page_off) == PAGE_SIZE) { - page_no++; - page_off = 0; - /* - * If there are still bytes left to map, bump - * the page count - */ - if (byte_count) - head->count++; - } else - page_off += sge_bytes; + struct svc_rdma_recv_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_read_complete_q))) { + list_del(&ctxt->rc_list); + svc_rdma_recv_ctxt_put(rdma, ctxt); + } + while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_rq_dto_q))) { + list_del(&ctxt->rc_list); + svc_rdma_recv_ctxt_put(rdma, ctxt); } - BUG_ON(byte_count != 0); - return sge_no; } -/* Map a read-chunk-list to an XDR and fast register the page-list. +static void svc_rdma_build_arg_xdr(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct xdr_buf *arg = &rqstp->rq_arg; + + arg->head[0].iov_base = ctxt->rc_recv_buf; + arg->head[0].iov_len = ctxt->rc_byte_len; + arg->tail[0].iov_base = NULL; + arg->tail[0].iov_len = 0; + arg->page_len = 0; + arg->page_base = 0; + arg->buflen = ctxt->rc_byte_len; + arg->len = ctxt->rc_byte_len; +} + +/** + * xdr_count_read_segments - Count number of Read segments in Read list + * @rctxt: Ingress receive context + * @p: Start of an un-decoded Read list * - * Assumptions: - * - chunk[0] position points to pages[0] at an offset of 0 - * - pages[] will be made physically contiguous by creating a one-off memory - * region using the fastreg verb. - * - byte_count is # of bytes in read-chunk-list - * - ch_count is # of chunks in read-chunk-list - * - * Output: - * - sge array pointing into pages[] array. - * - chunk_sge array specifying sge index and count for each - * chunk in the read list + * Before allocating anything, ensure the ingress Read list is safe + * to use. + * + * The segment count is limited to how many segments can fit in the + * transport header without overflowing the buffer. That's about 40 + * Read segments for a 1KB inline threshold. + * + * Return values: + * %true: Read list is valid. @rctxt's xdr_stream is updated to point + * to the first byte past the Read list. rc_read_pcl and + * rc_call_pcl cl_count fields are set to the number of + * Read segments in the list. + * %false: Read list is corrupt. @rctxt's xdr_stream is left in an + * unknown state. */ -static int fast_reg_read_chunks(struct svcxprt_rdma *xprt, - struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *head, - struct rpcrdma_msg *rmsgp, - struct svc_rdma_req_map *rpl_map, - struct svc_rdma_req_map *chl_map, - int ch_count, - int byte_count) +static bool xdr_count_read_segments(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) { - int page_no; - int ch_no; - u32 offset; - struct rpcrdma_read_chunk *ch; - struct svc_rdma_fastreg_mr *frmr; - int ret = 0; - - frmr = svc_rdma_get_frmr(xprt); - if (IS_ERR(frmr)) - return -ENOMEM; - - head->frmr = frmr; - head->arg.head[0] = rqstp->rq_arg.head[0]; - head->arg.tail[0] = rqstp->rq_arg.tail[0]; - head->arg.pages = &head->pages[head->count]; - head->hdr_count = head->count; /* save count of hdr pages */ - head->arg.page_base = 0; - head->arg.page_len = byte_count; - head->arg.len = rqstp->rq_arg.len + byte_count; - head->arg.buflen = rqstp->rq_arg.buflen + byte_count; - - /* Fast register the page list */ - frmr->kva = page_address(rqstp->rq_arg.pages[0]); - frmr->direction = DMA_FROM_DEVICE; - frmr->access_flags = (IB_ACCESS_LOCAL_WRITE|IB_ACCESS_REMOTE_WRITE); - frmr->map_len = byte_count; - frmr->page_list_len = PAGE_ALIGN(byte_count) >> PAGE_SHIFT; - for (page_no = 0; page_no < frmr->page_list_len; page_no++) { - frmr->page_list->page_list[page_no] = - ib_dma_map_page(xprt->sc_cm_id->device, - rqstp->rq_arg.pages[page_no], 0, - PAGE_SIZE, DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - atomic_inc(&xprt->sc_dma_used); - head->arg.pages[page_no] = rqstp->rq_arg.pages[page_no]; - } - head->count += page_no; - - /* rq_respages points one past arg pages */ - rqstp->rq_respages = &rqstp->rq_arg.pages[page_no]; - - /* Create the reply and chunk maps */ - offset = 0; - ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; - for (ch_no = 0; ch_no < ch_count; ch_no++) { - int len = ntohl(ch->rc_target.rs_length); - rpl_map->sge[ch_no].iov_base = frmr->kva + offset; - rpl_map->sge[ch_no].iov_len = len; - chl_map->ch[ch_no].count = 1; - chl_map->ch[ch_no].start = ch_no; - offset += len; - ch++; + rctxt->rc_call_pcl.cl_count = 0; + rctxt->rc_read_pcl.cl_count = 0; + while (xdr_item_is_present(p)) { + u32 position, handle, length; + u64 offset; + + p = xdr_inline_decode(&rctxt->rc_stream, + rpcrdma_readseg_maxsz * sizeof(*p)); + if (!p) + return false; + + xdr_decode_read_segment(p, &position, &handle, + &length, &offset); + if (position) { + if (position & 3) + return false; + ++rctxt->rc_read_pcl.cl_count; + } else { + ++rctxt->rc_call_pcl.cl_count; + } + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; } + return true; +} - ret = svc_rdma_fastreg(xprt, frmr); - if (ret) - goto fatal_err; +/* Sanity check the Read list. + * + * Sanity checks: + * - Read list does not overflow Receive buffer. + * - Chunk size limited by largest NFS data payload. + * + * Return values: + * %true: Read list is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Read list. + * %false: Read list is corrupt. @rctxt's xdr_stream is left + * in an unknown state. + */ +static bool xdr_check_read_list(struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p; + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + if (!xdr_count_read_segments(rctxt, p)) + return false; + if (!pcl_alloc_call(rctxt, p)) + return false; + return pcl_alloc_read(rctxt, p); +} - return ch_no; +static bool xdr_check_write_chunk(struct svc_rdma_recv_ctxt *rctxt) +{ + u32 segcount; + __be32 *p; + + if (xdr_stream_decode_u32(&rctxt->rc_stream, &segcount)) + return false; - fatal_err: - printk("svcrdma: error fast registering xdr for xprt %p", xprt); - svc_rdma_put_frmr(xprt, frmr); - return -EIO; + /* Before trusting the segcount value enough to use it in + * a computation, perform a simple range check. This is an + * arbitrary but sensible limit (ie, not architectural). + */ + if (unlikely(segcount > rctxt->rc_maxpages)) + return false; + + p = xdr_inline_decode(&rctxt->rc_stream, + segcount * rpcrdma_segment_maxsz * sizeof(*p)); + return p != NULL; } -static int rdma_set_ctxt_sge(struct svcxprt_rdma *xprt, - struct svc_rdma_op_ctxt *ctxt, - struct svc_rdma_fastreg_mr *frmr, - struct kvec *vec, - u64 *sgl_offset, - int count) +/** + * xdr_count_write_chunks - Count number of Write chunks in Write list + * @rctxt: Received header and decoding state + * @p: start of an un-decoded Write list + * + * Before allocating anything, ensure the ingress Write list is + * safe to use. + * + * Return values: + * %true: Write list is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Write list, and + * the number of Write chunks is in rc_write_pcl.cl_count. + * %false: Write list is corrupt. @rctxt's xdr_stream is left + * in an indeterminate state. + */ +static bool xdr_count_write_chunks(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) { - int i; - unsigned long off; - - ctxt->count = count; - ctxt->direction = DMA_FROM_DEVICE; - for (i = 0; i < count; i++) { - ctxt->sge[i].length = 0; /* in case map fails */ - if (!frmr) { - BUG_ON(!virt_to_page(vec[i].iov_base)); - off = (unsigned long)vec[i].iov_base & ~PAGE_MASK; - ctxt->sge[i].addr = - ib_dma_map_page(xprt->sc_cm_id->device, - virt_to_page(vec[i].iov_base), - off, - vec[i].iov_len, - DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - ctxt->sge[i].addr)) - return -EINVAL; - ctxt->sge[i].lkey = xprt->sc_dma_lkey; - atomic_inc(&xprt->sc_dma_used); - } else { - ctxt->sge[i].addr = (unsigned long)vec[i].iov_base; - ctxt->sge[i].lkey = frmr->mr->lkey; - } - ctxt->sge[i].length = vec[i].iov_len; - *sgl_offset = *sgl_offset + vec[i].iov_len; + rctxt->rc_write_pcl.cl_count = 0; + while (xdr_item_is_present(p)) { + if (!xdr_check_write_chunk(rctxt)) + return false; + ++rctxt->rc_write_pcl.cl_count; + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; } - return 0; + return true; } -static int rdma_read_max_sge(struct svcxprt_rdma *xprt, int sge_count) +/* Sanity check the Write list. + * + * Implementation limits: + * - This implementation currently supports only one Write chunk. + * + * Sanity checks: + * - Write list does not overflow Receive buffer. + * - Chunk size limited by largest NFS data payload. + * + * Return values: + * %true: Write list is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Write list. + * %false: Write list is corrupt. @rctxt's xdr_stream is left + * in an unknown state. + */ +static bool xdr_check_write_list(struct svc_rdma_recv_ctxt *rctxt) { - if ((rdma_node_get_transport(xprt->sc_cm_id->device->node_type) == - RDMA_TRANSPORT_IWARP) && - sge_count > 1) - return 1; - else - return min_t(int, sge_count, xprt->sc_max_sge); + __be32 *p; + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + if (!xdr_count_write_chunks(rctxt, p)) + return false; + if (!pcl_alloc_write(rctxt, &rctxt->rc_write_pcl, p)) + return false; + + rctxt->rc_cur_result_payload = pcl_first_chunk(&rctxt->rc_write_pcl); + return true; } -/* - * Use RDMA_READ to read data from the advertised client buffer into the - * XDR stream starting at rq_arg.head[0].iov_base. - * Each chunk in the array - * contains the following fields: - * discrim - '1', This isn't used for data placement - * position - The xdr stream offset (the same for every chunk) - * handle - RMR for client memory region - * length - data transfer length - * offset - 64 bit tagged offset in remote memory region - * - * On our side, we need to read into a pagelist. The first page immediately - * follows the RPC header. - * - * This function returns: - * 0 - No error and no read-list found. - * - * 1 - Successful read-list processing. The data is not yet in - * the pagelist and therefore the RPC request must be deferred. The - * I/O completion will enqueue the transport again and - * svc_rdma_recvfrom will complete the request. - * - * <0 - Error processing/posting read-list. - * - * NOTE: The ctxt must not be touched after the last WR has been posted - * because the I/O completion processing may occur on another - * processor and free / modify the context. Ne touche pas! +/* Sanity check the Reply chunk. + * + * Sanity checks: + * - Reply chunk does not overflow Receive buffer. + * - Chunk size limited by largest NFS data payload. + * + * Return values: + * %true: Reply chunk is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Reply chunk. + * %false: Reply chunk is corrupt. @rctxt's xdr_stream is left + * in an unknown state. */ -static int rdma_read_xdr(struct svcxprt_rdma *xprt, - struct rpcrdma_msg *rmsgp, - struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *hdr_ctxt) +static bool xdr_check_reply_chunk(struct svc_rdma_recv_ctxt *rctxt) { - struct ib_send_wr read_wr; - struct ib_send_wr inv_wr; - int err = 0; - int ch_no; - int ch_count; - int byte_count; - int sge_count; - u64 sgl_offset; - struct rpcrdma_read_chunk *ch; - struct svc_rdma_op_ctxt *ctxt = NULL; - struct svc_rdma_req_map *rpl_map; - struct svc_rdma_req_map *chl_map; - - /* If no read list is present, return 0 */ - ch = svc_rdma_get_read_chunk(rmsgp); - if (!ch) - return 0; + __be32 *p; - svc_rdma_rcl_chunk_counts(ch, &ch_count, &byte_count); - if (ch_count > RPCSVC_MAXPAGES) - return -EINVAL; + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; - /* Allocate temporary reply and chunk maps */ - rpl_map = svc_rdma_get_req_map(); - chl_map = svc_rdma_get_req_map(); + if (!xdr_item_is_present(p)) + return true; + if (!xdr_check_write_chunk(rctxt)) + return false; - if (!xprt->sc_frmr_pg_list_len) - sge_count = map_read_chunks(xprt, rqstp, hdr_ctxt, rmsgp, - rpl_map, chl_map, ch_count, - byte_count); - else - sge_count = fast_reg_read_chunks(xprt, rqstp, hdr_ctxt, rmsgp, - rpl_map, chl_map, ch_count, - byte_count); - if (sge_count < 0) { - err = -EIO; - goto out; - } + rctxt->rc_reply_pcl.cl_count = 1; + return pcl_alloc_write(rctxt, &rctxt->rc_reply_pcl, p); +} - sgl_offset = 0; - ch_no = 0; - - for (ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; - ch->rc_discrim != 0; ch++, ch_no++) { - u64 rs_offset; -next_sge: - ctxt = svc_rdma_get_context(xprt); - ctxt->direction = DMA_FROM_DEVICE; - ctxt->frmr = hdr_ctxt->frmr; - ctxt->read_hdr = NULL; - clear_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); - clear_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - - /* Prepare READ WR */ - memset(&read_wr, 0, sizeof read_wr); - read_wr.wr_id = (unsigned long)ctxt; - read_wr.opcode = IB_WR_RDMA_READ; - ctxt->wr_op = read_wr.opcode; - read_wr.send_flags = IB_SEND_SIGNALED; - read_wr.wr.rdma.rkey = ntohl(ch->rc_target.rs_handle); - xdr_decode_hyper((__be32 *)&ch->rc_target.rs_offset, - &rs_offset); - read_wr.wr.rdma.remote_addr = rs_offset + sgl_offset; - read_wr.sg_list = ctxt->sge; - read_wr.num_sge = - rdma_read_max_sge(xprt, chl_map->ch[ch_no].count); - err = rdma_set_ctxt_sge(xprt, ctxt, hdr_ctxt->frmr, - &rpl_map->sge[chl_map->ch[ch_no].start], - &sgl_offset, - read_wr.num_sge); - if (err) { - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_context(ctxt, 0); - goto out; +/* RPC-over-RDMA Version One private extension: Remote Invalidation. + * Responder's choice: requester signals it can handle Send With + * Invalidate, and responder chooses one R_key to invalidate. + * + * If there is exactly one distinct R_key in the received transport + * header, set rc_inv_rkey to that R_key. Otherwise, set it to zero. + */ +static void svc_rdma_get_inv_rkey(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct svc_rdma_segment *segment; + struct svc_rdma_chunk *chunk; + u32 inv_rkey; + + ctxt->rc_inv_rkey = 0; + + if (!rdma->sc_snd_w_inv) + return; + + inv_rkey = 0; + pcl_for_each_chunk(chunk, &ctxt->rc_call_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; } - if (((ch+1)->rc_discrim == 0) && - (read_wr.num_sge == chl_map->ch[ch_no].count)) { - /* - * Mark the last RDMA_READ with a bit to - * indicate all RPC data has been fetched from - * the client and the RPC needs to be enqueued. - */ - set_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); - if (hdr_ctxt->frmr) { - set_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - /* - * Invalidate the local MR used to map the data - * sink. - */ - if (xprt->sc_dev_caps & - SVCRDMA_DEVCAP_READ_W_INV) { - read_wr.opcode = - IB_WR_RDMA_READ_WITH_INV; - ctxt->wr_op = read_wr.opcode; - read_wr.ex.invalidate_rkey = - ctxt->frmr->mr->lkey; - } else { - /* Prepare INVALIDATE WR */ - memset(&inv_wr, 0, sizeof inv_wr); - inv_wr.opcode = IB_WR_LOCAL_INV; - inv_wr.send_flags = IB_SEND_SIGNALED; - inv_wr.ex.invalidate_rkey = - hdr_ctxt->frmr->mr->lkey; - read_wr.next = &inv_wr; - } - } - ctxt->read_hdr = hdr_ctxt; + } + pcl_for_each_chunk(chunk, &ctxt->rc_read_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; } - /* Post the read */ - err = svc_rdma_send(xprt, &read_wr); - if (err) { - printk(KERN_ERR "svcrdma: Error %d posting RDMA_READ\n", - err); - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_context(ctxt, 0); - goto out; + } + pcl_for_each_chunk(chunk, &ctxt->rc_write_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; } - atomic_inc(&rdma_stat_read); - - if (read_wr.num_sge < chl_map->ch[ch_no].count) { - chl_map->ch[ch_no].count -= read_wr.num_sge; - chl_map->ch[ch_no].start += read_wr.num_sge; - goto next_sge; + } + pcl_for_each_chunk(chunk, &ctxt->rc_reply_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; } - sgl_offset = 0; - err = 1; } + ctxt->rc_inv_rkey = inv_rkey; +} - out: - svc_rdma_put_req_map(rpl_map); - svc_rdma_put_req_map(chl_map); +/** + * svc_rdma_xdr_decode_req - Decode the transport header + * @rq_arg: xdr_buf containing ingress RPC/RDMA message + * @rctxt: state of decoding + * + * On entry, xdr->head[0].iov_base points to first byte of the + * RPC-over-RDMA transport header. + * + * On successful exit, head[0] points to first byte past the + * RPC-over-RDMA header. For RDMA_MSG, this is the RPC message. + * + * The length of the RPC-over-RDMA header is returned. + * + * Assumptions: + * - The transport header is entirely contained in the head iovec. + */ +static int svc_rdma_xdr_decode_req(struct xdr_buf *rq_arg, + struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p, *rdma_argp; + unsigned int hdr_len; + + rdma_argp = rq_arg->head[0].iov_base; + xdr_init_decode(&rctxt->rc_stream, rq_arg, rdma_argp, NULL); + + p = xdr_inline_decode(&rctxt->rc_stream, + rpcrdma_fixed_maxsz * sizeof(*p)); + if (unlikely(!p)) + goto out_short; + p++; + if (*p != rpcrdma_version) + goto out_version; + p += 2; + rctxt->rc_msgtype = *p; + switch (rctxt->rc_msgtype) { + case rdma_msg: + break; + case rdma_nomsg: + break; + case rdma_done: + goto out_drop; + case rdma_error: + goto out_drop; + default: + goto out_proc; + } + + if (!xdr_check_read_list(rctxt)) + goto out_inval; + if (!xdr_check_write_list(rctxt)) + goto out_inval; + if (!xdr_check_reply_chunk(rctxt)) + goto out_inval; + + rq_arg->head[0].iov_base = rctxt->rc_stream.p; + hdr_len = xdr_stream_pos(&rctxt->rc_stream); + rq_arg->head[0].iov_len -= hdr_len; + rq_arg->len -= hdr_len; + trace_svcrdma_decode_rqst(rctxt, rdma_argp, hdr_len); + return hdr_len; + +out_short: + trace_svcrdma_decode_short_err(rctxt, rq_arg->len); + return -EINVAL; + +out_version: + trace_svcrdma_decode_badvers_err(rctxt, rdma_argp); + return -EPROTONOSUPPORT; + +out_drop: + trace_svcrdma_decode_drop_err(rctxt, rdma_argp); + return 0; + +out_proc: + trace_svcrdma_decode_badproc_err(rctxt, rdma_argp); + return -EINVAL; + +out_inval: + trace_svcrdma_decode_parse_err(rctxt, rdma_argp); + return -EINVAL; +} + +static void svc_rdma_send_error(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *rctxt, + int status) +{ + struct svc_rdma_send_ctxt *sctxt; + + sctxt = svc_rdma_send_ctxt_get(rdma); + if (!sctxt) + return; + svc_rdma_send_error_msg(rdma, sctxt, rctxt, status); +} + +/* By convention, backchannel calls arrive via rdma_msg type + * messages, and never populate the chunk lists. This makes + * the RPC/RDMA header small and fixed in size, so it is + * straightforward to check the RPC header's direction field. + */ +static bool svc_rdma_is_reverse_direction_reply(struct svc_xprt *xprt, + struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p = rctxt->rc_recv_buf; + + if (!xprt->xpt_bc_xprt) + return false; + + if (rctxt->rc_msgtype != rdma_msg) + return false; + + if (!pcl_is_empty(&rctxt->rc_call_pcl)) + return false; + if (!pcl_is_empty(&rctxt->rc_read_pcl)) + return false; + if (!pcl_is_empty(&rctxt->rc_write_pcl)) + return false; + if (!pcl_is_empty(&rctxt->rc_reply_pcl)) + return false; - /* Detach arg pages. svc_recv will replenish them */ - for (ch_no = 0; &rqstp->rq_pages[ch_no] < rqstp->rq_respages; ch_no++) - rqstp->rq_pages[ch_no] = NULL; + /* RPC call direction */ + if (*(p + 8) == cpu_to_be32(RPC_CALL)) + return false; - /* - * Detach res pages. If svc_release sees any it will attempt to - * put them. + return true; +} + +/* Finish constructing the RPC Call message in rqstp::rq_arg. + * + * The incoming RPC/RDMA message is an RDMA_MSG type message + * with a single Read chunk (only the upper layer data payload + * was conveyed via RDMA Read). + */ +static void svc_rdma_read_complete_one(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct svc_rdma_chunk *chunk = pcl_first_chunk(&ctxt->rc_read_pcl); + struct xdr_buf *buf = &rqstp->rq_arg; + unsigned int length; + + /* Split the Receive buffer between the head and tail + * buffers at Read chunk's position. XDR roundup of the + * chunk is not included in either the pagelist or in + * the tail. + */ + buf->tail[0].iov_base = buf->head[0].iov_base + chunk->ch_position; + buf->tail[0].iov_len = buf->head[0].iov_len - chunk->ch_position; + buf->head[0].iov_len = chunk->ch_position; + + /* Read chunk may need XDR roundup (see RFC 8166, s. 3.4.5.2). + * + * If the client already rounded up the chunk length, the + * length does not change. Otherwise, the length of the page + * list is increased to include XDR round-up. + * + * Currently these chunks always start at page offset 0, + * thus the rounded-up length never crosses a page boundary. */ - while (rqstp->rq_next_page != rqstp->rq_respages) - *(--rqstp->rq_next_page) = NULL; + buf->pages = &rqstp->rq_pages[0]; + length = xdr_align_size(chunk->ch_length); + buf->page_len = length; + buf->len += length; + buf->buflen += length; +} - return err; +/* Finish constructing the RPC Call message in rqstp::rq_arg. + * + * The incoming RPC/RDMA message is an RDMA_MSG type message + * with payload in multiple Read chunks and no PZRC. + */ +static void svc_rdma_read_complete_multiple(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct xdr_buf *buf = &rqstp->rq_arg; + + buf->len += ctxt->rc_readbytes; + buf->buflen += ctxt->rc_readbytes; + + buf->head[0].iov_base = page_address(rqstp->rq_pages[0]); + buf->head[0].iov_len = min_t(size_t, PAGE_SIZE, ctxt->rc_readbytes); + buf->pages = &rqstp->rq_pages[1]; + buf->page_len = ctxt->rc_readbytes - buf->head[0].iov_len; } -static int rdma_read_complete(struct svc_rqst *rqstp, - struct svc_rdma_op_ctxt *head) +/* Finish constructing the RPC Call message in rqstp::rq_arg. + * + * The incoming RPC/RDMA message is an RDMA_NOMSG type message + * (the RPC message body was conveyed via RDMA Read). + */ +static void svc_rdma_read_complete_pzrc(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) { - int page_no; - int ret; + struct xdr_buf *buf = &rqstp->rq_arg; - BUG_ON(!head); + buf->len += ctxt->rc_readbytes; + buf->buflen += ctxt->rc_readbytes; - /* Copy RPC pages */ - for (page_no = 0; page_no < head->count; page_no++) { - put_page(rqstp->rq_pages[page_no]); - rqstp->rq_pages[page_no] = head->pages[page_no]; - } - /* Point rq_arg.pages past header */ - rqstp->rq_arg.pages = &rqstp->rq_pages[head->hdr_count]; - rqstp->rq_arg.page_len = head->arg.page_len; - rqstp->rq_arg.page_base = head->arg.page_base; + buf->head[0].iov_base = page_address(rqstp->rq_pages[0]); + buf->head[0].iov_len = min_t(size_t, PAGE_SIZE, ctxt->rc_readbytes); + buf->pages = &rqstp->rq_pages[1]; + buf->page_len = ctxt->rc_readbytes - buf->head[0].iov_len; +} - /* rq_respages starts after the last arg page */ - rqstp->rq_respages = &rqstp->rq_arg.pages[page_no]; - rqstp->rq_next_page = &rqstp->rq_arg.pages[page_no]; +static noinline void svc_rdma_read_complete(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) +{ + unsigned int i; - /* Rebuild rq_arg head and tail. */ - rqstp->rq_arg.head[0] = head->arg.head[0]; - rqstp->rq_arg.tail[0] = head->arg.tail[0]; - rqstp->rq_arg.len = head->arg.len; - rqstp->rq_arg.buflen = head->arg.buflen; + /* Transfer the Read chunk pages into @rqstp.rq_pages, replacing + * the rq_pages that were already allocated for this rqstp. + */ + release_pages(rqstp->rq_respages, ctxt->rc_page_count); + for (i = 0; i < ctxt->rc_page_count; i++) + rqstp->rq_pages[i] = ctxt->rc_pages[i]; - /* Free the context */ - svc_rdma_put_context(head, 0); + /* Update @rqstp's result send buffer to start after the + * last page in the RDMA Read payload. + */ + rqstp->rq_respages = &rqstp->rq_pages[ctxt->rc_page_count]; + rqstp->rq_next_page = rqstp->rq_respages + 1; - /* XXX: What should this be? */ - rqstp->rq_prot = IPPROTO_MAX; - svc_xprt_copy_addrs(rqstp, rqstp->rq_xprt); + /* Prevent svc_rdma_recv_ctxt_put() from releasing the + * pages in ctxt::rc_pages a second time. + */ + ctxt->rc_page_count = 0; - ret = rqstp->rq_arg.head[0].iov_len - + rqstp->rq_arg.page_len - + rqstp->rq_arg.tail[0].iov_len; - dprintk("svcrdma: deferred read ret=%d, rq_arg.len =%d, " - "rq_arg.head[0].iov_base=%p, rq_arg.head[0].iov_len = %zd\n", - ret, rqstp->rq_arg.len, rqstp->rq_arg.head[0].iov_base, - rqstp->rq_arg.head[0].iov_len); + /* Finish constructing the RPC Call message. The exact + * procedure for that depends on what kind of RPC/RDMA + * chunks were provided by the client. + */ + rqstp->rq_arg = ctxt->rc_saved_arg; + if (pcl_is_empty(&ctxt->rc_call_pcl)) { + if (ctxt->rc_read_pcl.cl_count == 1) + svc_rdma_read_complete_one(rqstp, ctxt); + else + svc_rdma_read_complete_multiple(rqstp, ctxt); + } else { + svc_rdma_read_complete_pzrc(rqstp, ctxt); + } - return ret; + trace_svcrdma_read_finished(&ctxt->rc_cid); } -/* - * Set up the rqstp thread context to point to the RQ buffer. If - * necessary, pull additional data from the client with an RDMA_READ - * request. +/** + * svc_rdma_recvfrom - Receive an RPC call + * @rqstp: request structure into which to receive an RPC Call + * + * Returns: + * The positive number of bytes in the RPC Call message, + * %0 if there were no Calls ready to return, + * %-EINVAL if the Read chunk data is too large, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + * + * Called in a loop when XPT_DATA is set. XPT_DATA is cleared only + * when there are no remaining ctxt's to process. + * + * The next ctxt is removed from the "receive" lists. + * + * - If the ctxt completes a Receive, then construct the Call + * message from the contents of the Receive buffer. + * + * - If there are no Read chunks in this message, then finish + * assembling the Call message and return the number of bytes + * in the message. + * + * - If there are Read chunks in this message, post Read WRs to + * pull that payload. When the Read WRs complete, build the + * full message and return the number of bytes in it. */ int svc_rdma_recvfrom(struct svc_rqst *rqstp) { struct svc_xprt *xprt = rqstp->rq_xprt; struct svcxprt_rdma *rdma_xprt = container_of(xprt, struct svcxprt_rdma, sc_xprt); - struct svc_rdma_op_ctxt *ctxt = NULL; - struct rpcrdma_msg *rmsgp; - int ret = 0; - int len; - - dprintk("svcrdma: rqstp=%p\n", rqstp); - - spin_lock_bh(&rdma_xprt->sc_rq_dto_lock); - if (!list_empty(&rdma_xprt->sc_read_complete_q)) { - ctxt = list_entry(rdma_xprt->sc_read_complete_q.next, - struct svc_rdma_op_ctxt, - dto_q); - list_del_init(&ctxt->dto_q); - } + struct svc_rdma_recv_ctxt *ctxt; + int ret; + + /* Prevent svc_xprt_release() from releasing pages in rq_pages + * when returning 0 or an error. + */ + rqstp->rq_respages = rqstp->rq_pages; + rqstp->rq_next_page = rqstp->rq_respages; + + rqstp->rq_xprt_ctxt = NULL; + + spin_lock(&rdma_xprt->sc_rq_dto_lock); + ctxt = svc_rdma_next_recv_ctxt(&rdma_xprt->sc_read_complete_q); if (ctxt) { - spin_unlock_bh(&rdma_xprt->sc_rq_dto_lock); - return rdma_read_complete(rqstp, ctxt); + list_del(&ctxt->rc_list); + spin_unlock(&rdma_xprt->sc_rq_dto_lock); + svc_xprt_received(xprt); + svc_rdma_read_complete(rqstp, ctxt); + goto complete; } - - if (!list_empty(&rdma_xprt->sc_rq_dto_q)) { - ctxt = list_entry(rdma_xprt->sc_rq_dto_q.next, - struct svc_rdma_op_ctxt, - dto_q); - list_del_init(&ctxt->dto_q); - } else { - atomic_inc(&rdma_stat_rq_starve); + ctxt = svc_rdma_next_recv_ctxt(&rdma_xprt->sc_rq_dto_q); + if (ctxt) + list_del(&ctxt->rc_list); + else + /* No new incoming requests, terminate the loop */ clear_bit(XPT_DATA, &xprt->xpt_flags); - ctxt = NULL; - } - spin_unlock_bh(&rdma_xprt->sc_rq_dto_lock); - if (!ctxt) { - /* This is the EAGAIN path. The svc_recv routine will - * return -EAGAIN, the nfsd thread will go to call into - * svc_recv again and we shouldn't be on the active - * transport list - */ - if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) - goto close_out; - - BUG_ON(ret); - goto out; - } - dprintk("svcrdma: processing ctxt=%p on xprt=%p, rqstp=%p, status=%d\n", - ctxt, rdma_xprt, rqstp, ctxt->wc_status); - BUG_ON(ctxt->wc_status != IB_WC_SUCCESS); - atomic_inc(&rdma_stat_recv); - - /* Build up the XDR from the receive buffers. */ - rdma_build_arg_xdr(rqstp, ctxt, ctxt->byte_len); - - /* Decode the RDMA header. */ - len = svc_rdma_xdr_decode_req(&rmsgp, rqstp); - rqstp->rq_xprt_hlen = len; - - /* If the request is invalid, reply with an error */ - if (len < 0) { - if (len == -ENOSYS) - svc_rdma_send_error(rdma_xprt, rmsgp, ERR_VERS); - goto close_out; - } + spin_unlock(&rdma_xprt->sc_rq_dto_lock); - /* Read read-list data. */ - ret = rdma_read_xdr(rdma_xprt, rmsgp, rqstp, ctxt); - if (ret > 0) { - /* read-list posted, defer until data received from client. */ - goto defer; - } - if (ret < 0) { - /* Post of read-list failed, free context. */ - svc_rdma_put_context(ctxt, 1); + /* Unblock the transport for the next receive */ + svc_xprt_received(xprt); + if (!ctxt) return 0; - } - ret = rqstp->rq_arg.head[0].iov_len - + rqstp->rq_arg.page_len - + rqstp->rq_arg.tail[0].iov_len; - svc_rdma_put_context(ctxt, 0); - out: - dprintk("svcrdma: ret = %d, rq_arg.len =%d, " - "rq_arg.head[0].iov_base=%p, rq_arg.head[0].iov_len = %zd\n", - ret, rqstp->rq_arg.len, - rqstp->rq_arg.head[0].iov_base, - rqstp->rq_arg.head[0].iov_len); + percpu_counter_inc(&svcrdma_stat_recv); + ib_dma_sync_single_for_cpu(rdma_xprt->sc_pd->device, + ctxt->rc_recv_sge.addr, ctxt->rc_byte_len, + DMA_FROM_DEVICE); + svc_rdma_build_arg_xdr(rqstp, ctxt); + + ret = svc_rdma_xdr_decode_req(&rqstp->rq_arg, ctxt); + if (ret < 0) + goto out_err; + if (ret == 0) + goto out_drop; + + if (svc_rdma_is_reverse_direction_reply(xprt, ctxt)) + goto out_backchannel; + + svc_rdma_get_inv_rkey(rdma_xprt, ctxt); + + if (!pcl_is_empty(&ctxt->rc_read_pcl) || + !pcl_is_empty(&ctxt->rc_call_pcl)) + goto out_readlist; + +complete: + rqstp->rq_xprt_ctxt = ctxt; rqstp->rq_prot = IPPROTO_MAX; svc_xprt_copy_addrs(rqstp, xprt); - return ret; + set_bit(RQ_SECURE, &rqstp->rq_flags); + return rqstp->rq_arg.len; - close_out: - if (ctxt) - svc_rdma_put_context(ctxt, 1); - dprintk("svcrdma: transport %p is closing\n", xprt); - /* - * Set the close bit and enqueue it. svc_recv will see the - * close bit and call svc_xprt_delete +out_err: + svc_rdma_send_error(rdma_xprt, ctxt, ret); + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); + return 0; + +out_readlist: + /* This @rqstp is about to be recycled. Save the work + * already done constructing the Call message in rq_arg + * so it can be restored when the RDMA Reads have + * completed. */ - set_bit(XPT_CLOSE, &xprt->xpt_flags); -defer: + ctxt->rc_saved_arg = rqstp->rq_arg; + + ret = svc_rdma_process_read_list(rdma_xprt, rqstp, ctxt); + if (ret < 0) { + if (ret == -EINVAL) + svc_rdma_send_error(rdma_xprt, ctxt, ret); + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); + svc_xprt_deferred_close(xprt); + return ret; + } + return 0; + +out_backchannel: + svc_rdma_handle_bc_reply(rqstp, ctxt); +out_drop: + svc_rdma_recv_ctxt_put(rdma_xprt, ctxt); return 0; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_rw.c b/net/sunrpc/xprtrdma/svc_rdma_rw.c new file mode 100644 index 000000000000..661b3fe2779f --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_rw.c @@ -0,0 +1,1142 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * + * Use the core R/W API to move RPC-over-RDMA Read and Write chunks. + */ + +#include <rdma/rw.h> + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_rdma.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc); +static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc); + +/* Each R/W context contains state for one chain of RDMA Read or + * Write Work Requests. + * + * Each WR chain handles a single contiguous server-side buffer, + * because scatterlist entries after the first have to start on + * page alignment. xdr_buf iovecs cannot guarantee alignment. + * + * Each WR chain handles only one R_key. Each RPC-over-RDMA segment + * from a client may contain a unique R_key, so each WR chain moves + * up to one segment at a time. + * + * The scatterlist makes this data structure over 4KB in size. To + * make it less likely to fail, and to handle the allocation for + * smaller I/O requests without disabling bottom-halves, these + * contexts are created on demand, but cached and reused until the + * controlling svcxprt_rdma is destroyed. + */ +struct svc_rdma_rw_ctxt { + struct llist_node rw_node; + struct list_head rw_list; + struct rdma_rw_ctx rw_ctx; + unsigned int rw_nents; + unsigned int rw_first_sgl_nents; + struct sg_table rw_sg_table; + struct scatterlist rw_first_sgl[]; +}; + +static inline struct svc_rdma_rw_ctxt * +svc_rdma_next_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_rw_ctxt, + rw_list); +} + +static struct svc_rdma_rw_ctxt * +svc_rdma_get_rw_ctxt(struct svcxprt_rdma *rdma, unsigned int sges) +{ + struct ib_device *dev = rdma->sc_cm_id->device; + unsigned int first_sgl_nents = dev->attrs.max_send_sge; + struct svc_rdma_rw_ctxt *ctxt; + struct llist_node *node; + + spin_lock(&rdma->sc_rw_ctxt_lock); + node = llist_del_first(&rdma->sc_rw_ctxts); + spin_unlock(&rdma->sc_rw_ctxt_lock); + if (node) { + ctxt = llist_entry(node, struct svc_rdma_rw_ctxt, rw_node); + } else { + ctxt = kmalloc_node(struct_size(ctxt, rw_first_sgl, first_sgl_nents), + GFP_KERNEL, ibdev_to_node(dev)); + if (!ctxt) + goto out_noctx; + + INIT_LIST_HEAD(&ctxt->rw_list); + ctxt->rw_first_sgl_nents = first_sgl_nents; + } + + ctxt->rw_sg_table.sgl = ctxt->rw_first_sgl; + if (sg_alloc_table_chained(&ctxt->rw_sg_table, sges, + ctxt->rw_sg_table.sgl, + first_sgl_nents)) + goto out_free; + return ctxt; + +out_free: + kfree(ctxt); +out_noctx: + trace_svcrdma_rwctx_empty(rdma, sges); + return NULL; +} + +static void __svc_rdma_put_rw_ctxt(struct svc_rdma_rw_ctxt *ctxt, + struct llist_head *list) +{ + sg_free_table_chained(&ctxt->rw_sg_table, ctxt->rw_first_sgl_nents); + llist_add(&ctxt->rw_node, list); +} + +static void svc_rdma_put_rw_ctxt(struct svcxprt_rdma *rdma, + struct svc_rdma_rw_ctxt *ctxt) +{ + __svc_rdma_put_rw_ctxt(ctxt, &rdma->sc_rw_ctxts); +} + +/** + * svc_rdma_destroy_rw_ctxts - Free accumulated R/W contexts + * @rdma: transport about to be destroyed + * + */ +void svc_rdma_destroy_rw_ctxts(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_rw_ctxt *ctxt; + struct llist_node *node; + + while ((node = llist_del_first(&rdma->sc_rw_ctxts)) != NULL) { + ctxt = llist_entry(node, struct svc_rdma_rw_ctxt, rw_node); + kfree(ctxt); + } +} + +/** + * svc_rdma_rw_ctx_init - Prepare a R/W context for I/O + * @rdma: controlling transport instance + * @ctxt: R/W context to prepare + * @offset: RDMA offset + * @handle: RDMA tag/handle + * @direction: I/O direction + * + * Returns on success, the number of WQEs that will be needed + * on the workqueue, or a negative errno. + */ +static int svc_rdma_rw_ctx_init(struct svcxprt_rdma *rdma, + struct svc_rdma_rw_ctxt *ctxt, + u64 offset, u32 handle, + enum dma_data_direction direction) +{ + int ret; + + ret = rdma_rw_ctx_init(&ctxt->rw_ctx, rdma->sc_qp, rdma->sc_port_num, + ctxt->rw_sg_table.sgl, ctxt->rw_nents, + 0, offset, handle, direction); + if (unlikely(ret < 0)) { + trace_svcrdma_dma_map_rw_err(rdma, offset, handle, + ctxt->rw_nents, ret); + svc_rdma_put_rw_ctxt(rdma, ctxt); + } + return ret; +} + +/** + * svc_rdma_cc_init - Initialize an svc_rdma_chunk_ctxt + * @rdma: controlling transport instance + * @cc: svc_rdma_chunk_ctxt to be initialized + */ +void svc_rdma_cc_init(struct svcxprt_rdma *rdma, + struct svc_rdma_chunk_ctxt *cc) +{ + struct rpc_rdma_cid *cid = &cc->cc_cid; + + if (unlikely(!cid->ci_completion_id)) + svc_rdma_send_cid_init(rdma, cid); + + INIT_LIST_HEAD(&cc->cc_rwctxts); + cc->cc_sqecount = 0; +} + +/** + * svc_rdma_cc_release - Release resources held by a svc_rdma_chunk_ctxt + * @rdma: controlling transport instance + * @cc: svc_rdma_chunk_ctxt to be released + * @dir: DMA direction + */ +void svc_rdma_cc_release(struct svcxprt_rdma *rdma, + struct svc_rdma_chunk_ctxt *cc, + enum dma_data_direction dir) +{ + struct llist_node *first, *last; + struct svc_rdma_rw_ctxt *ctxt; + LLIST_HEAD(free); + + trace_svcrdma_cc_release(&cc->cc_cid, cc->cc_sqecount); + + first = last = NULL; + while ((ctxt = svc_rdma_next_ctxt(&cc->cc_rwctxts)) != NULL) { + list_del(&ctxt->rw_list); + + rdma_rw_ctx_destroy(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, ctxt->rw_sg_table.sgl, + ctxt->rw_nents, dir); + __svc_rdma_put_rw_ctxt(ctxt, &free); + + ctxt->rw_node.next = first; + first = &ctxt->rw_node; + if (!last) + last = first; + } + if (first) + llist_add_batch(first, last, &rdma->sc_rw_ctxts); +} + +static struct svc_rdma_write_info * +svc_rdma_write_info_alloc(struct svcxprt_rdma *rdma, + const struct svc_rdma_chunk *chunk) +{ + struct svc_rdma_write_info *info; + + info = kzalloc_node(sizeof(*info), GFP_KERNEL, + ibdev_to_node(rdma->sc_cm_id->device)); + if (!info) + return info; + + info->wi_rdma = rdma; + info->wi_chunk = chunk; + svc_rdma_cc_init(rdma, &info->wi_cc); + info->wi_cc.cc_cqe.done = svc_rdma_write_done; + return info; +} + +static void svc_rdma_write_info_free_async(struct work_struct *work) +{ + struct svc_rdma_write_info *info; + + info = container_of(work, struct svc_rdma_write_info, wi_work); + svc_rdma_cc_release(info->wi_rdma, &info->wi_cc, DMA_TO_DEVICE); + kfree(info); +} + +static void svc_rdma_write_info_free(struct svc_rdma_write_info *info) +{ + INIT_WORK(&info->wi_work, svc_rdma_write_info_free_async); + queue_work(svcrdma_wq, &info->wi_work); +} + +/** + * svc_rdma_reply_chunk_release - Release Reply chunk I/O resources + * @rdma: controlling transport + * @ctxt: Send context that is being released + */ +void svc_rdma_reply_chunk_release(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt) +{ + struct svc_rdma_chunk_ctxt *cc = &ctxt->sc_reply_info.wi_cc; + + if (!cc->cc_sqecount) + return; + svc_rdma_cc_release(rdma, cc, DMA_TO_DEVICE); +} + +/** + * svc_rdma_reply_done - Reply chunk Write completion handler + * @cq: controlling Completion Queue + * @wc: Work Completion report + * + * Pages under I/O are released by a subsequent Send completion. + */ +static void svc_rdma_reply_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svcxprt_rdma *rdma = cq->cq_context; + + switch (wc->status) { + case IB_WC_SUCCESS: + trace_svcrdma_wc_reply(&cc->cc_cid); + return; + case IB_WC_WR_FLUSH_ERR: + trace_svcrdma_wc_reply_flush(wc, &cc->cc_cid); + break; + default: + trace_svcrdma_wc_reply_err(wc, &cc->cc_cid); + } + + svc_xprt_deferred_close(&rdma->sc_xprt); +} + +/** + * svc_rdma_write_done - Write chunk completion + * @cq: controlling Completion Queue + * @wc: Work Completion + * + * Pages under I/O are freed by a subsequent Send completion. + */ +static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svc_rdma_write_info *info = + container_of(cc, struct svc_rdma_write_info, wi_cc); + + switch (wc->status) { + case IB_WC_SUCCESS: + trace_svcrdma_wc_write(&cc->cc_cid); + break; + case IB_WC_WR_FLUSH_ERR: + trace_svcrdma_wc_write_flush(wc, &cc->cc_cid); + break; + default: + trace_svcrdma_wc_write_err(wc, &cc->cc_cid); + } + + svc_rdma_wake_send_waiters(rdma, cc->cc_sqecount); + + if (unlikely(wc->status != IB_WC_SUCCESS)) + svc_xprt_deferred_close(&rdma->sc_xprt); + + svc_rdma_write_info_free(info); +} + +/** + * svc_rdma_wc_read_done - Handle completion of an RDMA Read ctx + * @cq: controlling Completion Queue + * @wc: Work Completion + * + */ +static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svc_rdma_recv_ctxt *ctxt; + + svc_rdma_wake_send_waiters(rdma, cc->cc_sqecount); + + ctxt = container_of(cc, struct svc_rdma_recv_ctxt, rc_cc); + switch (wc->status) { + case IB_WC_SUCCESS: + trace_svcrdma_wc_read(wc, &cc->cc_cid, ctxt->rc_readbytes, + cc->cc_posttime); + + spin_lock(&rdma->sc_rq_dto_lock); + list_add_tail(&ctxt->rc_list, &rdma->sc_read_complete_q); + /* the unlock pairs with the smp_rmb in svc_xprt_ready */ + set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); + spin_unlock(&rdma->sc_rq_dto_lock); + svc_xprt_enqueue(&rdma->sc_xprt); + return; + case IB_WC_WR_FLUSH_ERR: + trace_svcrdma_wc_read_flush(wc, &cc->cc_cid); + break; + default: + trace_svcrdma_wc_read_err(wc, &cc->cc_cid); + } + + /* The RDMA Read has flushed, so the incoming RPC message + * cannot be constructed and must be dropped. Signal the + * loss to the client by closing the connection. + */ + svc_rdma_cc_release(rdma, cc, DMA_FROM_DEVICE); + svc_rdma_recv_ctxt_put(rdma, ctxt); + svc_xprt_deferred_close(&rdma->sc_xprt); +} + +/* + * Assumptions: + * - If ib_post_send() succeeds, only one completion is expected, + * even if one or more WRs are flushed. This is true when posting + * an rdma_rw_ctx or when posting a single signaled WR. + */ +static int svc_rdma_post_chunk_ctxt(struct svcxprt_rdma *rdma, + struct svc_rdma_chunk_ctxt *cc) +{ + struct ib_send_wr *first_wr; + const struct ib_send_wr *bad_wr; + struct list_head *tmp; + struct ib_cqe *cqe; + int ret; + + might_sleep(); + + if (cc->cc_sqecount > rdma->sc_sq_depth) + return -EINVAL; + + first_wr = NULL; + cqe = &cc->cc_cqe; + list_for_each(tmp, &cc->cc_rwctxts) { + struct svc_rdma_rw_ctxt *ctxt; + + ctxt = list_entry(tmp, struct svc_rdma_rw_ctxt, rw_list); + first_wr = rdma_rw_ctx_wrs(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, cqe, first_wr); + cqe = NULL; + } + + do { + if (atomic_sub_return(cc->cc_sqecount, + &rdma->sc_sq_avail) > 0) { + cc->cc_posttime = ktime_get(); + ret = ib_post_send(rdma->sc_qp, first_wr, &bad_wr); + if (ret) + break; + return 0; + } + + percpu_counter_inc(&svcrdma_stat_sq_starve); + trace_svcrdma_sq_full(rdma, &cc->cc_cid); + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wait_event(rdma->sc_send_wait, + atomic_read(&rdma->sc_sq_avail) > cc->cc_sqecount); + trace_svcrdma_sq_retry(rdma, &cc->cc_cid); + } while (1); + + trace_svcrdma_sq_post_err(rdma, &cc->cc_cid, ret); + svc_xprt_deferred_close(&rdma->sc_xprt); + + /* If even one was posted, there will be a completion. */ + if (bad_wr != first_wr) + return 0; + + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wake_up(&rdma->sc_send_wait); + return -ENOTCONN; +} + +/* Build and DMA-map an SGL that covers one kvec in an xdr_buf + */ +static void svc_rdma_vec_to_sg(struct svc_rdma_write_info *info, + unsigned int len, + struct svc_rdma_rw_ctxt *ctxt) +{ + struct scatterlist *sg = ctxt->rw_sg_table.sgl; + + sg_set_buf(&sg[0], info->wi_base, len); + info->wi_base += len; + + ctxt->rw_nents = 1; +} + +/* Build and DMA-map an SGL that covers part of an xdr_buf's pagelist. + */ +static void svc_rdma_pagelist_to_sg(struct svc_rdma_write_info *info, + unsigned int remaining, + struct svc_rdma_rw_ctxt *ctxt) +{ + unsigned int sge_no, sge_bytes, page_off, page_no; + const struct xdr_buf *xdr = info->wi_xdr; + struct scatterlist *sg; + struct page **page; + + page_off = info->wi_next_off + xdr->page_base; + page_no = page_off >> PAGE_SHIFT; + page_off = offset_in_page(page_off); + page = xdr->pages + page_no; + info->wi_next_off += remaining; + sg = ctxt->rw_sg_table.sgl; + sge_no = 0; + do { + sge_bytes = min_t(unsigned int, remaining, + PAGE_SIZE - page_off); + sg_set_page(sg, *page, sge_bytes, page_off); + + remaining -= sge_bytes; + sg = sg_next(sg); + page_off = 0; + sge_no++; + page++; + } while (remaining); + + ctxt->rw_nents = sge_no; +} + +/* Construct RDMA Write WRs to send a portion of an xdr_buf containing + * an RPC Reply. + */ +static int +svc_rdma_build_writes(struct svc_rdma_write_info *info, + void (*constructor)(struct svc_rdma_write_info *info, + unsigned int len, + struct svc_rdma_rw_ctxt *ctxt), + unsigned int remaining) +{ + struct svc_rdma_chunk_ctxt *cc = &info->wi_cc; + struct svcxprt_rdma *rdma = info->wi_rdma; + const struct svc_rdma_segment *seg; + struct svc_rdma_rw_ctxt *ctxt; + int ret; + + do { + unsigned int write_len; + u64 offset; + + if (info->wi_seg_no >= info->wi_chunk->ch_segcount) + goto out_overflow; + + seg = &info->wi_chunk->ch_segments[info->wi_seg_no]; + write_len = min(remaining, seg->rs_length - info->wi_seg_off); + if (!write_len) + goto out_overflow; + ctxt = svc_rdma_get_rw_ctxt(rdma, + (write_len >> PAGE_SHIFT) + 2); + if (!ctxt) + return -ENOMEM; + + constructor(info, write_len, ctxt); + offset = seg->rs_offset + info->wi_seg_off; + ret = svc_rdma_rw_ctx_init(rdma, ctxt, offset, seg->rs_handle, + DMA_TO_DEVICE); + if (ret < 0) + return -EIO; + percpu_counter_inc(&svcrdma_stat_write); + + list_add(&ctxt->rw_list, &cc->cc_rwctxts); + cc->cc_sqecount += ret; + if (write_len == seg->rs_length - info->wi_seg_off) { + info->wi_seg_no++; + info->wi_seg_off = 0; + } else { + info->wi_seg_off += write_len; + } + remaining -= write_len; + } while (remaining); + + return 0; + +out_overflow: + trace_svcrdma_small_wrch_err(&cc->cc_cid, remaining, info->wi_seg_no, + info->wi_chunk->ch_segcount); + return -E2BIG; +} + +/** + * svc_rdma_iov_write - Construct RDMA Writes from an iov + * @info: pointer to write arguments + * @iov: kvec to write + * + * Returns: + * On success, returns zero + * %-E2BIG if the client-provided Write chunk is too small + * %-ENOMEM if a resource has been exhausted + * %-EIO if an rdma-rw error occurred + */ +static int svc_rdma_iov_write(struct svc_rdma_write_info *info, + const struct kvec *iov) +{ + info->wi_base = iov->iov_base; + return svc_rdma_build_writes(info, svc_rdma_vec_to_sg, + iov->iov_len); +} + +/** + * svc_rdma_pages_write - Construct RDMA Writes from pages + * @info: pointer to write arguments + * @xdr: xdr_buf with pages to write + * @offset: offset into the content of @xdr + * @length: number of bytes to write + * + * Returns: + * On success, returns zero + * %-E2BIG if the client-provided Write chunk is too small + * %-ENOMEM if a resource has been exhausted + * %-EIO if an rdma-rw error occurred + */ +static int svc_rdma_pages_write(struct svc_rdma_write_info *info, + const struct xdr_buf *xdr, + unsigned int offset, + unsigned long length) +{ + info->wi_xdr = xdr; + info->wi_next_off = offset - xdr->head[0].iov_len; + return svc_rdma_build_writes(info, svc_rdma_pagelist_to_sg, + length); +} + +/** + * svc_rdma_xb_write - Construct RDMA Writes to write an xdr_buf + * @xdr: xdr_buf to write + * @data: pointer to write arguments + * + * Returns: + * On success, returns zero + * %-E2BIG if the client-provided Write chunk is too small + * %-ENOMEM if a resource has been exhausted + * %-EIO if an rdma-rw error occurred + */ +static int svc_rdma_xb_write(const struct xdr_buf *xdr, void *data) +{ + struct svc_rdma_write_info *info = data; + int ret; + + if (xdr->head[0].iov_len) { + ret = svc_rdma_iov_write(info, &xdr->head[0]); + if (ret < 0) + return ret; + } + + if (xdr->page_len) { + ret = svc_rdma_pages_write(info, xdr, xdr->head[0].iov_len, + xdr->page_len); + if (ret < 0) + return ret; + } + + if (xdr->tail[0].iov_len) { + ret = svc_rdma_iov_write(info, &xdr->tail[0]); + if (ret < 0) + return ret; + } + + return xdr->len; +} + +static int svc_rdma_send_write_chunk(struct svcxprt_rdma *rdma, + const struct svc_rdma_chunk *chunk, + const struct xdr_buf *xdr) +{ + struct svc_rdma_write_info *info; + struct svc_rdma_chunk_ctxt *cc; + struct xdr_buf payload; + int ret; + + if (xdr_buf_subsegment(xdr, &payload, chunk->ch_position, + chunk->ch_payload_length)) + return -EMSGSIZE; + + info = svc_rdma_write_info_alloc(rdma, chunk); + if (!info) + return -ENOMEM; + cc = &info->wi_cc; + + ret = svc_rdma_xb_write(&payload, info); + if (ret != payload.len) + goto out_err; + + trace_svcrdma_post_write_chunk(&cc->cc_cid, cc->cc_sqecount); + ret = svc_rdma_post_chunk_ctxt(rdma, cc); + if (ret < 0) + goto out_err; + return 0; + +out_err: + svc_rdma_write_info_free(info); + return ret; +} + +/** + * svc_rdma_send_write_list - Send all chunks on the Write list + * @rdma: controlling RDMA transport + * @rctxt: Write list provisioned by the client + * @xdr: xdr_buf containing an RPC Reply message + * + * Returns zero on success, or a negative errno if one or more + * Write chunks could not be sent. + */ +int svc_rdma_send_write_list(struct svcxprt_rdma *rdma, + const struct svc_rdma_recv_ctxt *rctxt, + const struct xdr_buf *xdr) +{ + struct svc_rdma_chunk *chunk; + int ret; + + pcl_for_each_chunk(chunk, &rctxt->rc_write_pcl) { + if (!chunk->ch_payload_length) + break; + ret = svc_rdma_send_write_chunk(rdma, chunk, xdr); + if (ret < 0) + return ret; + } + return 0; +} + +/** + * svc_rdma_prepare_reply_chunk - Construct WR chain for writing the Reply chunk + * @rdma: controlling RDMA transport + * @write_pcl: Write chunk list provided by client + * @reply_pcl: Reply chunk provided by client + * @sctxt: Send WR resources + * @xdr: xdr_buf containing an RPC Reply + * + * Returns a non-negative number of bytes the chunk consumed, or + * %-E2BIG if the payload was larger than the Reply chunk, + * %-EINVAL if client provided too many segments, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_prepare_reply_chunk(struct svcxprt_rdma *rdma, + const struct svc_rdma_pcl *write_pcl, + const struct svc_rdma_pcl *reply_pcl, + struct svc_rdma_send_ctxt *sctxt, + const struct xdr_buf *xdr) +{ + struct svc_rdma_write_info *info = &sctxt->sc_reply_info; + struct svc_rdma_chunk_ctxt *cc = &info->wi_cc; + struct ib_send_wr *first_wr; + struct list_head *pos; + struct ib_cqe *cqe; + int ret; + + info->wi_rdma = rdma; + info->wi_chunk = pcl_first_chunk(reply_pcl); + info->wi_seg_off = 0; + info->wi_seg_no = 0; + info->wi_cc.cc_cqe.done = svc_rdma_reply_done; + + ret = pcl_process_nonpayloads(write_pcl, xdr, + svc_rdma_xb_write, info); + if (ret < 0) + return ret; + + first_wr = sctxt->sc_wr_chain; + cqe = &cc->cc_cqe; + list_for_each(pos, &cc->cc_rwctxts) { + struct svc_rdma_rw_ctxt *rwc; + + rwc = list_entry(pos, struct svc_rdma_rw_ctxt, rw_list); + first_wr = rdma_rw_ctx_wrs(&rwc->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, cqe, first_wr); + cqe = NULL; + } + sctxt->sc_wr_chain = first_wr; + sctxt->sc_sqecount += cc->cc_sqecount; + + trace_svcrdma_post_reply_chunk(&cc->cc_cid, cc->cc_sqecount); + return xdr->len; +} + +/** + * svc_rdma_build_read_segment - Build RDMA Read WQEs to pull one RDMA segment + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * @segment: co-ordinates of remote memory to be read + * + * Returns: + * %0: the Read WR chain was constructed successfully + * %-EINVAL: there were not enough rq_pages to finish + * %-ENOMEM: allocating a local resources failed + * %-EIO: a DMA mapping error occurred + */ +static int svc_rdma_build_read_segment(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head, + const struct svc_rdma_segment *segment) +{ + struct svcxprt_rdma *rdma = svc_rdma_rqst_rdma(rqstp); + struct svc_rdma_chunk_ctxt *cc = &head->rc_cc; + unsigned int sge_no, seg_len, len; + struct svc_rdma_rw_ctxt *ctxt; + struct scatterlist *sg; + int ret; + + len = segment->rs_length; + sge_no = PAGE_ALIGN(head->rc_pageoff + len) >> PAGE_SHIFT; + ctxt = svc_rdma_get_rw_ctxt(rdma, sge_no); + if (!ctxt) + return -ENOMEM; + ctxt->rw_nents = sge_no; + + sg = ctxt->rw_sg_table.sgl; + for (sge_no = 0; sge_no < ctxt->rw_nents; sge_no++) { + seg_len = min_t(unsigned int, len, + PAGE_SIZE - head->rc_pageoff); + + if (!head->rc_pageoff) + head->rc_page_count++; + + sg_set_page(sg, rqstp->rq_pages[head->rc_curpage], + seg_len, head->rc_pageoff); + sg = sg_next(sg); + + head->rc_pageoff += seg_len; + if (head->rc_pageoff == PAGE_SIZE) { + head->rc_curpage++; + head->rc_pageoff = 0; + } + len -= seg_len; + + if (len && ((head->rc_curpage + 1) > rqstp->rq_maxpages)) + goto out_overrun; + } + + ret = svc_rdma_rw_ctx_init(rdma, ctxt, segment->rs_offset, + segment->rs_handle, DMA_FROM_DEVICE); + if (ret < 0) + return -EIO; + percpu_counter_inc(&svcrdma_stat_read); + + list_add(&ctxt->rw_list, &cc->cc_rwctxts); + cc->cc_sqecount += ret; + return 0; + +out_overrun: + trace_svcrdma_page_overrun_err(&cc->cc_cid, head->rc_curpage); + return -EINVAL; +} + +/** + * svc_rdma_build_read_chunk - Build RDMA Read WQEs to pull one RDMA chunk + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * @chunk: Read chunk to pull + * + * Return values: + * %0: the Read WR chain was constructed successfully + * %-EINVAL: there were not enough resources to finish + * %-ENOMEM: allocating a local resources failed + * %-EIO: a DMA mapping error occurred + */ +static int svc_rdma_build_read_chunk(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head, + const struct svc_rdma_chunk *chunk) +{ + const struct svc_rdma_segment *segment; + int ret; + + ret = -EINVAL; + pcl_for_each_segment(segment, chunk) { + ret = svc_rdma_build_read_segment(rqstp, head, segment); + if (ret < 0) + break; + head->rc_readbytes += segment->rs_length; + } + return ret; +} + +/** + * svc_rdma_copy_inline_range - Copy part of the inline content into pages + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * @offset: offset into the Receive buffer of region to copy + * @remaining: length of region to copy + * + * Take a page at a time from rqstp->rq_pages and copy the inline + * content from the Receive buffer into that page. Update + * head->rc_curpage and head->rc_pageoff so that the next RDMA Read + * result will land contiguously with the copied content. + * + * Return values: + * %0: Inline content was successfully copied + * %-EINVAL: offset or length was incorrect + */ +static int svc_rdma_copy_inline_range(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head, + unsigned int offset, + unsigned int remaining) +{ + unsigned char *dst, *src = head->rc_recv_buf; + unsigned int page_no, numpages; + + numpages = PAGE_ALIGN(head->rc_pageoff + remaining) >> PAGE_SHIFT; + for (page_no = 0; page_no < numpages; page_no++) { + unsigned int page_len; + + page_len = min_t(unsigned int, remaining, + PAGE_SIZE - head->rc_pageoff); + + if (!head->rc_pageoff) + head->rc_page_count++; + + dst = page_address(rqstp->rq_pages[head->rc_curpage]); + memcpy(dst + head->rc_curpage, src + offset, page_len); + + head->rc_readbytes += page_len; + head->rc_pageoff += page_len; + if (head->rc_pageoff == PAGE_SIZE) { + head->rc_curpage++; + head->rc_pageoff = 0; + } + remaining -= page_len; + offset += page_len; + } + + return -EINVAL; +} + +/** + * svc_rdma_read_multiple_chunks - Construct RDMA Reads to pull data item Read chunks + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * + * The chunk data lands in rqstp->rq_arg as a series of contiguous pages, + * like an incoming TCP call. + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static noinline int +svc_rdma_read_multiple_chunks(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + const struct svc_rdma_pcl *pcl = &head->rc_read_pcl; + struct svc_rdma_chunk *chunk, *next; + unsigned int start, length; + int ret; + + start = 0; + chunk = pcl_first_chunk(pcl); + length = chunk->ch_position; + ret = svc_rdma_copy_inline_range(rqstp, head, start, length); + if (ret < 0) + return ret; + + pcl_for_each_chunk(chunk, pcl) { + ret = svc_rdma_build_read_chunk(rqstp, head, chunk); + if (ret < 0) + return ret; + + next = pcl_next_chunk(pcl, chunk); + if (!next) + break; + + start += length; + length = next->ch_position - head->rc_readbytes; + ret = svc_rdma_copy_inline_range(rqstp, head, start, length); + if (ret < 0) + return ret; + } + + start += length; + length = head->rc_byte_len - start; + return svc_rdma_copy_inline_range(rqstp, head, start, length); +} + +/** + * svc_rdma_read_data_item - Construct RDMA Reads to pull data item Read chunks + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * + * The chunk data lands in the page list of rqstp->rq_arg.pages. + * + * Currently NFSD does not look at the rqstp->rq_arg.tail[0] kvec. + * Therefore, XDR round-up of the Read chunk and trailing + * inline content must both be added at the end of the pagelist. + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static int svc_rdma_read_data_item(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + return svc_rdma_build_read_chunk(rqstp, head, + pcl_first_chunk(&head->rc_read_pcl)); +} + +/** + * svc_rdma_read_chunk_range - Build RDMA Read WRs for portion of a chunk + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * @chunk: parsed Call chunk to pull + * @offset: offset of region to pull + * @length: length of region to pull + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: there were not enough resources to finish + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static int svc_rdma_read_chunk_range(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head, + const struct svc_rdma_chunk *chunk, + unsigned int offset, unsigned int length) +{ + const struct svc_rdma_segment *segment; + int ret; + + ret = -EINVAL; + pcl_for_each_segment(segment, chunk) { + struct svc_rdma_segment dummy; + + if (offset > segment->rs_length) { + offset -= segment->rs_length; + continue; + } + + dummy.rs_handle = segment->rs_handle; + dummy.rs_length = min_t(u32, length, segment->rs_length) - offset; + dummy.rs_offset = segment->rs_offset + offset; + + ret = svc_rdma_build_read_segment(rqstp, head, &dummy); + if (ret < 0) + break; + + head->rc_readbytes += dummy.rs_length; + length -= dummy.rs_length; + offset = 0; + } + return ret; +} + +/** + * svc_rdma_read_call_chunk - Build RDMA Read WQEs to pull a Long Message + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: there were not enough resources to finish + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static int svc_rdma_read_call_chunk(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + const struct svc_rdma_chunk *call_chunk = + pcl_first_chunk(&head->rc_call_pcl); + const struct svc_rdma_pcl *pcl = &head->rc_read_pcl; + struct svc_rdma_chunk *chunk, *next; + unsigned int start, length; + int ret; + + if (pcl_is_empty(pcl)) + return svc_rdma_build_read_chunk(rqstp, head, call_chunk); + + start = 0; + chunk = pcl_first_chunk(pcl); + length = chunk->ch_position; + ret = svc_rdma_read_chunk_range(rqstp, head, call_chunk, + start, length); + if (ret < 0) + return ret; + + pcl_for_each_chunk(chunk, pcl) { + ret = svc_rdma_build_read_chunk(rqstp, head, chunk); + if (ret < 0) + return ret; + + next = pcl_next_chunk(pcl, chunk); + if (!next) + break; + + start += length; + length = next->ch_position - head->rc_readbytes; + ret = svc_rdma_read_chunk_range(rqstp, head, call_chunk, + start, length); + if (ret < 0) + return ret; + } + + start += length; + length = call_chunk->ch_length - start; + return svc_rdma_read_chunk_range(rqstp, head, call_chunk, + start, length); +} + +/** + * svc_rdma_read_special - Build RDMA Read WQEs to pull a Long Message + * @rqstp: RPC transaction context + * @head: context for ongoing I/O + * + * The start of the data lands in the first page just after the + * Transport header, and the rest lands in rqstp->rq_arg.pages. + * + * Assumptions: + * - A PZRC is never sent in an RDMA_MSG message, though it's + * allowed by spec. + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static noinline int svc_rdma_read_special(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + return svc_rdma_read_call_chunk(rqstp, head); +} + +/* Pages under I/O have been copied to head->rc_pages. Ensure that + * svc_xprt_release() does not put them when svc_rdma_recvfrom() + * returns. This has to be done after all Read WRs are constructed + * to properly handle a page that happens to be part of I/O on behalf + * of two different RDMA segments. + * + * Note: if the subsequent post_send fails, these pages have already + * been moved to head->rc_pages and thus will be cleaned up by + * svc_rdma_recv_ctxt_put(). + */ +static void svc_rdma_clear_rqst_pages(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + unsigned int i; + + for (i = 0; i < head->rc_page_count; i++) { + head->rc_pages[i] = rqstp->rq_pages[i]; + rqstp->rq_pages[i] = NULL; + } +} + +/** + * svc_rdma_process_read_list - Pull list of Read chunks from the client + * @rdma: controlling RDMA transport + * @rqstp: set of pages to use as Read sink buffers + * @head: pages under I/O collect here + * + * The RPC/RDMA protocol assumes that the upper layer's XDR decoders + * pull each Read chunk as they decode an incoming RPC message. + * + * On Linux, however, the server needs to have a fully-constructed RPC + * message in rqstp->rq_arg when there is a positive return code from + * ->xpo_recvfrom. So the Read list is safety-checked immediately when + * it is received, then here the whole Read list is pulled all at once. + * The ingress RPC message is fully reconstructed once all associated + * RDMA Reads have completed. + * + * Return values: + * %1: all needed RDMA Reads were posted successfully, + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_process_read_list(struct svcxprt_rdma *rdma, + struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + struct svc_rdma_chunk_ctxt *cc = &head->rc_cc; + int ret; + + cc->cc_cqe.done = svc_rdma_wc_read_done; + cc->cc_sqecount = 0; + head->rc_pageoff = 0; + head->rc_curpage = 0; + head->rc_readbytes = 0; + + if (pcl_is_empty(&head->rc_call_pcl)) { + if (head->rc_read_pcl.cl_count == 1) + ret = svc_rdma_read_data_item(rqstp, head); + else + ret = svc_rdma_read_multiple_chunks(rqstp, head); + } else + ret = svc_rdma_read_special(rqstp, head); + svc_rdma_clear_rqst_pages(rqstp, head); + if (ret < 0) + return ret; + + trace_svcrdma_post_read_chunk(&cc->cc_cid, cc->cc_sqecount); + ret = svc_rdma_post_chunk_ctxt(rdma, cc); + return ret < 0 ? ret : 1; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c index c1d124dc772b..914cd263c2f1 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_sendto.c +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -1,4 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -39,709 +42,1069 @@ * Author: Tom Tucker <tom@opengridcomputing.com> */ -#include <linux/sunrpc/debug.h> -#include <linux/sunrpc/rpc_rdma.h> +/* Operation + * + * The main entry point is svc_rdma_sendto. This is called by the + * RPC server when an RPC Reply is ready to be transmitted to a client. + * + * The passed-in svc_rqst contains a struct xdr_buf which holds an + * XDR-encoded RPC Reply message. sendto must construct the RPC-over-RDMA + * transport header, post all Write WRs needed for this Reply, then post + * a Send WR conveying the transport header and the RPC message itself to + * the client. + * + * svc_rdma_sendto must fully transmit the Reply before returning, as + * the svc_rqst will be recycled as soon as sendto returns. Remaining + * resources referred to by the svc_rqst are also recycled at that time. + * Therefore any resources that must remain longer must be detached + * from the svc_rqst and released later. + * + * Page Management + * + * The I/O that performs Reply transmission is asynchronous, and may + * complete well after sendto returns. Thus pages under I/O must be + * removed from the svc_rqst before sendto returns. + * + * The logic here depends on Send Queue and completion ordering. Since + * the Send WR is always posted last, it will always complete last. Thus + * when it completes, it is guaranteed that all previous Write WRs have + * also completed. + * + * Write WRs are constructed and posted. Each Write segment gets its own + * svc_rdma_rw_ctxt, allowing the Write completion handler to find and + * DMA-unmap the pages under I/O for that Write segment. The Write + * completion handler does not release any pages. + * + * When the Send WR is constructed, it also gets its own svc_rdma_send_ctxt. + * The ownership of all of the Reply's pages are transferred into that + * ctxt, the Send WR is posted, and sendto returns. + * + * The svc_rdma_send_ctxt is presented when the Send WR completes. The + * Send completion handler finally releases the Reply's pages. + * + * This mechanism also assumes that completions on the transport's Send + * Completion Queue do not run in parallel. Otherwise a Write completion + * and Send completion running at the same time could release pages that + * are still DMA-mapped. + * + * Error Handling + * + * - If the Send WR is posted successfully, it will either complete + * successfully, or get flushed. Either way, the Send completion + * handler releases the Reply's pages. + * - If the Send WR cannot be not posted, the forward path releases + * the Reply's pages. + * + * This handles the case, without the use of page reference counting, + * where two different Write segments send portions of the same page. + */ + #include <linux/spinlock.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> + #include <rdma/ib_verbs.h> #include <rdma/rdma_cm.h> + +#include <linux/sunrpc/debug.h> #include <linux/sunrpc/svc_rdma.h> -#define RPCDBG_FACILITY RPCDBG_SVCXPRT +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc); + +static struct svc_rdma_send_ctxt * +svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma) +{ + int node = ibdev_to_node(rdma->sc_cm_id->device); + struct svc_rdma_send_ctxt *ctxt; + unsigned long pages; + dma_addr_t addr; + void *buffer; + int i; + + ctxt = kzalloc_node(struct_size(ctxt, sc_sges, rdma->sc_max_send_sges), + GFP_KERNEL, node); + if (!ctxt) + goto fail0; + pages = svc_serv_maxpages(rdma->sc_xprt.xpt_server); + ctxt->sc_pages = kcalloc_node(pages, sizeof(struct page *), + GFP_KERNEL, node); + if (!ctxt->sc_pages) + goto fail1; + ctxt->sc_maxpages = pages; + buffer = kmalloc_node(rdma->sc_max_req_size, GFP_KERNEL, node); + if (!buffer) + goto fail2; + addr = ib_dma_map_single(rdma->sc_pd->device, buffer, + rdma->sc_max_req_size, DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdma->sc_pd->device, addr)) + goto fail3; + + svc_rdma_send_cid_init(rdma, &ctxt->sc_cid); + + ctxt->sc_rdma = rdma; + ctxt->sc_send_wr.next = NULL; + ctxt->sc_send_wr.wr_cqe = &ctxt->sc_cqe; + ctxt->sc_send_wr.sg_list = ctxt->sc_sges; + ctxt->sc_send_wr.send_flags = IB_SEND_SIGNALED; + ctxt->sc_cqe.done = svc_rdma_wc_send; + ctxt->sc_xprt_buf = buffer; + xdr_buf_init(&ctxt->sc_hdrbuf, ctxt->sc_xprt_buf, + rdma->sc_max_req_size); + ctxt->sc_sges[0].addr = addr; + + for (i = 0; i < rdma->sc_max_send_sges; i++) + ctxt->sc_sges[i].lkey = rdma->sc_pd->local_dma_lkey; + return ctxt; + +fail3: + kfree(buffer); +fail2: + kfree(ctxt->sc_pages); +fail1: + kfree(ctxt); +fail0: + return NULL; +} -/* Encode an XDR as an array of IB SGE +/** + * svc_rdma_send_ctxts_destroy - Release all send_ctxt's for an xprt + * @rdma: svcxprt_rdma being torn down * - * Assumptions: - * - head[0] is physically contiguous. - * - tail[0] is physically contiguous. - * - pages[] is not physically or virtually contiguous and consists of - * PAGE_SIZE elements. - * - * Output: - * SGE[0] reserved for RCPRDMA header - * SGE[1] data from xdr->head[] - * SGE[2..sge_count-2] data from xdr->pages[] - * SGE[sge_count-1] data from xdr->tail. - * - * The max SGE we need is the length of the XDR / pagesize + one for - * head + one for tail + one for RPCRDMA header. Since RPCSVC_MAXPAGES - * reserves a page for both the request and the reply header, and this - * array is only concerned with the reply we are assured that we have - * on extra page for the RPCRMDA header. */ -static int fast_reg_xdr(struct svcxprt_rdma *xprt, - struct xdr_buf *xdr, - struct svc_rdma_req_map *vec) +void svc_rdma_send_ctxts_destroy(struct svcxprt_rdma *rdma) { - int sge_no; - u32 sge_bytes; - u32 page_bytes; - u32 page_off; - int page_no = 0; - u8 *frva; - struct svc_rdma_fastreg_mr *frmr; - - frmr = svc_rdma_get_frmr(xprt); - if (IS_ERR(frmr)) - return -ENOMEM; - vec->frmr = frmr; - - /* Skip the RPCRDMA header */ - sge_no = 1; - - /* Map the head. */ - frva = (void *)((unsigned long)(xdr->head[0].iov_base) & PAGE_MASK); - vec->sge[sge_no].iov_base = xdr->head[0].iov_base; - vec->sge[sge_no].iov_len = xdr->head[0].iov_len; - vec->count = 2; - sge_no++; - - /* Map the XDR head */ - frmr->kva = frva; - frmr->direction = DMA_TO_DEVICE; - frmr->access_flags = 0; - frmr->map_len = PAGE_SIZE; - frmr->page_list_len = 1; - page_off = (unsigned long)xdr->head[0].iov_base & ~PAGE_MASK; - frmr->page_list->page_list[page_no] = - ib_dma_map_page(xprt->sc_cm_id->device, - virt_to_page(xdr->head[0].iov_base), - page_off, - PAGE_SIZE - page_off, - DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - atomic_inc(&xprt->sc_dma_used); - - /* Map the XDR page list */ - page_off = xdr->page_base; - page_bytes = xdr->page_len + page_off; - if (!page_bytes) - goto encode_tail; - - /* Map the pages */ - vec->sge[sge_no].iov_base = frva + frmr->map_len + page_off; - vec->sge[sge_no].iov_len = page_bytes; - sge_no++; - while (page_bytes) { - struct page *page; - - page = xdr->pages[page_no++]; - sge_bytes = min_t(u32, page_bytes, (PAGE_SIZE - page_off)); - page_bytes -= sge_bytes; - - frmr->page_list->page_list[page_no] = - ib_dma_map_page(xprt->sc_cm_id->device, - page, page_off, - sge_bytes, DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - - atomic_inc(&xprt->sc_dma_used); - page_off = 0; /* reset for next time through loop */ - frmr->map_len += PAGE_SIZE; - frmr->page_list_len++; + struct svc_rdma_send_ctxt *ctxt; + struct llist_node *node; + + while ((node = llist_del_first(&rdma->sc_send_ctxts)) != NULL) { + ctxt = llist_entry(node, struct svc_rdma_send_ctxt, sc_node); + ib_dma_unmap_single(rdma->sc_pd->device, + ctxt->sc_sges[0].addr, + rdma->sc_max_req_size, + DMA_TO_DEVICE); + kfree(ctxt->sc_xprt_buf); + kfree(ctxt->sc_pages); + kfree(ctxt); } - vec->count++; - - encode_tail: - /* Map tail */ - if (0 == xdr->tail[0].iov_len) - goto done; - - vec->count++; - vec->sge[sge_no].iov_len = xdr->tail[0].iov_len; - - if (((unsigned long)xdr->tail[0].iov_base & PAGE_MASK) == - ((unsigned long)xdr->head[0].iov_base & PAGE_MASK)) { - /* - * If head and tail use the same page, we don't need - * to map it again. - */ - vec->sge[sge_no].iov_base = xdr->tail[0].iov_base; - } else { - void *va; +} - /* Map another page for the tail */ - page_off = (unsigned long)xdr->tail[0].iov_base & ~PAGE_MASK; - va = (void *)((unsigned long)xdr->tail[0].iov_base & PAGE_MASK); - vec->sge[sge_no].iov_base = frva + frmr->map_len + page_off; +/** + * svc_rdma_send_ctxt_get - Get a free send_ctxt + * @rdma: controlling svcxprt_rdma + * + * Returns a ready-to-use send_ctxt, or NULL if none are + * available and a fresh one cannot be allocated. + */ +struct svc_rdma_send_ctxt *svc_rdma_send_ctxt_get(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_send_ctxt *ctxt; + struct llist_node *node; + + spin_lock(&rdma->sc_send_lock); + node = llist_del_first(&rdma->sc_send_ctxts); + spin_unlock(&rdma->sc_send_lock); + if (!node) + goto out_empty; + + ctxt = llist_entry(node, struct svc_rdma_send_ctxt, sc_node); + +out: + rpcrdma_set_xdrlen(&ctxt->sc_hdrbuf, 0); + xdr_init_encode(&ctxt->sc_stream, &ctxt->sc_hdrbuf, + ctxt->sc_xprt_buf, NULL); + + svc_rdma_cc_init(rdma, &ctxt->sc_reply_info.wi_cc); + ctxt->sc_send_wr.num_sge = 0; + ctxt->sc_cur_sge_no = 0; + ctxt->sc_page_count = 0; + ctxt->sc_wr_chain = &ctxt->sc_send_wr; + ctxt->sc_sqecount = 1; + + return ctxt; + +out_empty: + ctxt = svc_rdma_send_ctxt_alloc(rdma); + if (!ctxt) + return NULL; + goto out; +} - frmr->page_list->page_list[page_no] = - ib_dma_map_page(xprt->sc_cm_id->device, virt_to_page(va), - page_off, - PAGE_SIZE, - DMA_TO_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - frmr->page_list->page_list[page_no])) - goto fatal_err; - atomic_inc(&xprt->sc_dma_used); - frmr->map_len += PAGE_SIZE; - frmr->page_list_len++; +static void svc_rdma_send_ctxt_release(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt) +{ + struct ib_device *device = rdma->sc_cm_id->device; + unsigned int i; + + svc_rdma_reply_chunk_release(rdma, ctxt); + + if (ctxt->sc_page_count) + release_pages(ctxt->sc_pages, ctxt->sc_page_count); + + /* The first SGE contains the transport header, which + * remains mapped until @ctxt is destroyed. + */ + for (i = 1; i < ctxt->sc_send_wr.num_sge; i++) { + trace_svcrdma_dma_unmap_page(&ctxt->sc_cid, + ctxt->sc_sges[i].addr, + ctxt->sc_sges[i].length); + ib_dma_unmap_page(device, + ctxt->sc_sges[i].addr, + ctxt->sc_sges[i].length, + DMA_TO_DEVICE); } - done: - if (svc_rdma_fastreg(xprt, frmr)) - goto fatal_err; + llist_add(&ctxt->sc_node, &rdma->sc_send_ctxts); +} - return 0; +static void svc_rdma_send_ctxt_put_async(struct work_struct *work) +{ + struct svc_rdma_send_ctxt *ctxt; - fatal_err: - printk("svcrdma: Error fast registering memory for xprt %p\n", xprt); - vec->frmr = NULL; - svc_rdma_put_frmr(xprt, frmr); - return -EIO; + ctxt = container_of(work, struct svc_rdma_send_ctxt, sc_work); + svc_rdma_send_ctxt_release(ctxt->sc_rdma, ctxt); } -static int map_xdr(struct svcxprt_rdma *xprt, - struct xdr_buf *xdr, - struct svc_rdma_req_map *vec) +/** + * svc_rdma_send_ctxt_put - Return send_ctxt to free list + * @rdma: controlling svcxprt_rdma + * @ctxt: object to return to the free list + * + * Pages left in sc_pages are DMA unmapped and released. + */ +void svc_rdma_send_ctxt_put(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt) { - int sge_no; - u32 sge_bytes; - u32 page_bytes; - u32 page_off; - int page_no; - - BUG_ON(xdr->len != - (xdr->head[0].iov_len + xdr->page_len + xdr->tail[0].iov_len)); - - if (xprt->sc_frmr_pg_list_len) - return fast_reg_xdr(xprt, xdr, vec); - - /* Skip the first sge, this is for the RPCRDMA header */ - sge_no = 1; - - /* Head SGE */ - vec->sge[sge_no].iov_base = xdr->head[0].iov_base; - vec->sge[sge_no].iov_len = xdr->head[0].iov_len; - sge_no++; - - /* pages SGE */ - page_no = 0; - page_bytes = xdr->page_len; - page_off = xdr->page_base; - while (page_bytes) { - vec->sge[sge_no].iov_base = - page_address(xdr->pages[page_no]) + page_off; - sge_bytes = min_t(u32, page_bytes, (PAGE_SIZE - page_off)); - page_bytes -= sge_bytes; - vec->sge[sge_no].iov_len = sge_bytes; - - sge_no++; - page_no++; - page_off = 0; /* reset for next time through loop */ - } + INIT_WORK(&ctxt->sc_work, svc_rdma_send_ctxt_put_async); + queue_work(svcrdma_wq, &ctxt->sc_work); +} - /* Tail SGE */ - if (xdr->tail[0].iov_len) { - vec->sge[sge_no].iov_base = xdr->tail[0].iov_base; - vec->sge[sge_no].iov_len = xdr->tail[0].iov_len; - sge_no++; - } +/** + * svc_rdma_wake_send_waiters - manage Send Queue accounting + * @rdma: controlling transport + * @avail: Number of additional SQEs that are now available + * + */ +void svc_rdma_wake_send_waiters(struct svcxprt_rdma *rdma, int avail) +{ + atomic_add(avail, &rdma->sc_sq_avail); + smp_mb__after_atomic(); + if (unlikely(waitqueue_active(&rdma->sc_send_wait))) + wake_up(&rdma->sc_send_wait); +} - dprintk("svcrdma: map_xdr: sge_no %d page_no %d " - "page_base %u page_len %u head_len %zu tail_len %zu\n", - sge_no, page_no, xdr->page_base, xdr->page_len, - xdr->head[0].iov_len, xdr->tail[0].iov_len); +/** + * svc_rdma_wc_send - Invoked by RDMA provider for each polled Send WC + * @cq: Completion Queue context + * @wc: Work Completion object + * + * NB: The svc_xprt/svcxprt_rdma is pinned whenever it's possible that + * the Send completion handler could be running. + */ +static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_send_ctxt *ctxt = + container_of(cqe, struct svc_rdma_send_ctxt, sc_cqe); - vec->count = sge_no; - return 0; + svc_rdma_wake_send_waiters(rdma, ctxt->sc_sqecount); + + if (unlikely(wc->status != IB_WC_SUCCESS)) + goto flushed; + + trace_svcrdma_wc_send(&ctxt->sc_cid); + svc_rdma_send_ctxt_put(rdma, ctxt); + return; + +flushed: + if (wc->status != IB_WC_WR_FLUSH_ERR) + trace_svcrdma_wc_send_err(wc, &ctxt->sc_cid); + else + trace_svcrdma_wc_send_flush(wc, &ctxt->sc_cid); + svc_rdma_send_ctxt_put(rdma, ctxt); + svc_xprt_deferred_close(&rdma->sc_xprt); } -static dma_addr_t dma_map_xdr(struct svcxprt_rdma *xprt, - struct xdr_buf *xdr, - u32 xdr_off, size_t len, int dir) +/** + * svc_rdma_post_send - Post a WR chain to the Send Queue + * @rdma: transport context + * @ctxt: WR chain to post + * + * Copy fields in @ctxt to stack variables in order to guarantee + * that these values remain available after the ib_post_send() call. + * In some error flow cases, svc_rdma_wc_send() releases @ctxt. + * + * Note there is potential for starvation when the Send Queue is + * full because there is no order to when waiting threads are + * awoken. The transport is typically provisioned with a deep + * enough Send Queue that SQ exhaustion should be a rare event. + * + * Return values: + * %0: @ctxt's WR chain was posted successfully + * %-ENOTCONN: The connection was lost + */ +int svc_rdma_post_send(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt) { - struct page *page; - dma_addr_t dma_addr; - if (xdr_off < xdr->head[0].iov_len) { - /* This offset is in the head */ - xdr_off += (unsigned long)xdr->head[0].iov_base & ~PAGE_MASK; - page = virt_to_page(xdr->head[0].iov_base); - } else { - xdr_off -= xdr->head[0].iov_len; - if (xdr_off < xdr->page_len) { - /* This offset is in the page list */ - page = xdr->pages[xdr_off >> PAGE_SHIFT]; - xdr_off &= ~PAGE_MASK; - } else { - /* This offset is in the tail */ - xdr_off -= xdr->page_len; - xdr_off += (unsigned long) - xdr->tail[0].iov_base & ~PAGE_MASK; - page = virt_to_page(xdr->tail[0].iov_base); + struct ib_send_wr *first_wr = ctxt->sc_wr_chain; + struct ib_send_wr *send_wr = &ctxt->sc_send_wr; + const struct ib_send_wr *bad_wr = first_wr; + struct rpc_rdma_cid cid = ctxt->sc_cid; + int ret, sqecount = ctxt->sc_sqecount; + + might_sleep(); + + /* Sync the transport header buffer */ + ib_dma_sync_single_for_device(rdma->sc_pd->device, + send_wr->sg_list[0].addr, + send_wr->sg_list[0].length, + DMA_TO_DEVICE); + + /* If the SQ is full, wait until an SQ entry is available */ + while (!test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags)) { + if (atomic_sub_return(sqecount, &rdma->sc_sq_avail) < 0) { + svc_rdma_wake_send_waiters(rdma, sqecount); + + /* When the transport is torn down, assume + * ib_drain_sq() will trigger enough Send + * completions to wake us. The XPT_CLOSE test + * above should then cause the while loop to + * exit. + */ + percpu_counter_inc(&svcrdma_stat_sq_starve); + trace_svcrdma_sq_full(rdma, &cid); + wait_event(rdma->sc_send_wait, + atomic_read(&rdma->sc_sq_avail) > 0); + trace_svcrdma_sq_retry(rdma, &cid); + continue; } + + trace_svcrdma_post_send(ctxt); + ret = ib_post_send(rdma->sc_qp, first_wr, &bad_wr); + if (ret) { + trace_svcrdma_sq_post_err(rdma, &cid, ret); + svc_xprt_deferred_close(&rdma->sc_xprt); + + /* If even one WR was posted, there will be a + * Send completion that bumps sc_sq_avail. + */ + if (bad_wr == first_wr) { + svc_rdma_wake_send_waiters(rdma, sqecount); + break; + } + } + return 0; } - dma_addr = ib_dma_map_page(xprt->sc_cm_id->device, page, xdr_off, - min_t(size_t, PAGE_SIZE, len), dir); - return dma_addr; + return -ENOTCONN; } -/* Assumptions: - * - We are using FRMR - * - or - - * - The specified write_len can be represented in sc_max_sge * PAGE_SIZE +/** + * svc_rdma_encode_read_list - Encode RPC Reply's Read chunk list + * @sctxt: Send context for the RPC Reply + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Reply Read list + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_read_list(struct svc_rdma_send_ctxt *sctxt) +{ + /* RPC-over-RDMA version 1 replies never have a Read list. */ + return xdr_stream_encode_item_absent(&sctxt->sc_stream); +} + +/** + * svc_rdma_encode_write_segment - Encode one Write segment + * @sctxt: Send context for the RPC Reply + * @chunk: Write chunk to push + * @remaining: remaining bytes of the payload left in the Write chunk + * @segno: which segment in the chunk + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Write segment, and updates @remaining + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_write_segment(struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_chunk *chunk, + u32 *remaining, unsigned int segno) +{ + const struct svc_rdma_segment *segment = &chunk->ch_segments[segno]; + const size_t len = rpcrdma_segment_maxsz * sizeof(__be32); + u32 length; + __be32 *p; + + p = xdr_reserve_space(&sctxt->sc_stream, len); + if (!p) + return -EMSGSIZE; + + length = min_t(u32, *remaining, segment->rs_length); + *remaining -= length; + xdr_encode_rdma_segment(p, segment->rs_handle, length, + segment->rs_offset); + trace_svcrdma_encode_wseg(sctxt, segno, segment->rs_handle, length, + segment->rs_offset); + return len; +} + +/** + * svc_rdma_encode_write_chunk - Encode one Write chunk + * @sctxt: Send context for the RPC Reply + * @chunk: Write chunk to push + * + * Copy a Write chunk from the Call transport header to the + * Reply transport header. Update each segment's length field + * to reflect the number of bytes written in that segment. + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Write chunk + * %-EMSGSIZE on XDR buffer overflow */ -static int send_write(struct svcxprt_rdma *xprt, struct svc_rqst *rqstp, - u32 rmr, u64 to, - u32 xdr_off, int write_len, - struct svc_rdma_req_map *vec) +static ssize_t svc_rdma_encode_write_chunk(struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_chunk *chunk) { - struct ib_send_wr write_wr; - struct ib_sge *sge; - int xdr_sge_no; - int sge_no; - int sge_bytes; - int sge_off; - int bc; - struct svc_rdma_op_ctxt *ctxt; - - BUG_ON(vec->count > RPCSVC_MAXPAGES); - dprintk("svcrdma: RDMA_WRITE rmr=%x, to=%llx, xdr_off=%d, " - "write_len=%d, vec->sge=%p, vec->count=%lu\n", - rmr, (unsigned long long)to, xdr_off, - write_len, vec->sge, vec->count); - - ctxt = svc_rdma_get_context(xprt); - ctxt->direction = DMA_TO_DEVICE; - sge = ctxt->sge; - - /* Find the SGE associated with xdr_off */ - for (bc = xdr_off, xdr_sge_no = 1; bc && xdr_sge_no < vec->count; - xdr_sge_no++) { - if (vec->sge[xdr_sge_no].iov_len > bc) - break; - bc -= vec->sge[xdr_sge_no].iov_len; + u32 remaining = chunk->ch_payload_length; + unsigned int segno; + ssize_t len, ret; + + len = 0; + ret = xdr_stream_encode_item_present(&sctxt->sc_stream); + if (ret < 0) + return ret; + len += ret; + + ret = xdr_stream_encode_u32(&sctxt->sc_stream, chunk->ch_segcount); + if (ret < 0) + return ret; + len += ret; + + for (segno = 0; segno < chunk->ch_segcount; segno++) { + ret = svc_rdma_encode_write_segment(sctxt, chunk, &remaining, segno); + if (ret < 0) + return ret; + len += ret; } - sge_off = bc; - bc = write_len; - sge_no = 0; - - /* Copy the remaining SGE */ - while (bc != 0) { - sge_bytes = min_t(size_t, - bc, vec->sge[xdr_sge_no].iov_len-sge_off); - sge[sge_no].length = sge_bytes; - if (!vec->frmr) { - sge[sge_no].addr = - dma_map_xdr(xprt, &rqstp->rq_res, xdr_off, - sge_bytes, DMA_TO_DEVICE); - xdr_off += sge_bytes; - if (ib_dma_mapping_error(xprt->sc_cm_id->device, - sge[sge_no].addr)) - goto err; - atomic_inc(&xprt->sc_dma_used); - sge[sge_no].lkey = xprt->sc_dma_lkey; - } else { - sge[sge_no].addr = (unsigned long) - vec->sge[xdr_sge_no].iov_base + sge_off; - sge[sge_no].lkey = vec->frmr->mr->lkey; - } - ctxt->count++; - ctxt->frmr = vec->frmr; - sge_off = 0; - sge_no++; - xdr_sge_no++; - BUG_ON(xdr_sge_no > vec->count); - bc -= sge_bytes; + return len; +} + +/** + * svc_rdma_encode_write_list - Encode RPC Reply's Write chunk list + * @rctxt: Reply context with information about the RPC Call + * @sctxt: Send context for the RPC Reply + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Reply's Write list + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_write_list(struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_send_ctxt *sctxt) +{ + struct svc_rdma_chunk *chunk; + ssize_t len, ret; + + len = 0; + pcl_for_each_chunk(chunk, &rctxt->rc_write_pcl) { + ret = svc_rdma_encode_write_chunk(sctxt, chunk); + if (ret < 0) + return ret; + len += ret; } - /* Prepare WRITE WR */ - memset(&write_wr, 0, sizeof write_wr); - ctxt->wr_op = IB_WR_RDMA_WRITE; - write_wr.wr_id = (unsigned long)ctxt; - write_wr.sg_list = &sge[0]; - write_wr.num_sge = sge_no; - write_wr.opcode = IB_WR_RDMA_WRITE; - write_wr.send_flags = IB_SEND_SIGNALED; - write_wr.wr.rdma.rkey = rmr; - write_wr.wr.rdma.remote_addr = to; - - /* Post It */ - atomic_inc(&rdma_stat_write); - if (svc_rdma_send(xprt, &write_wr)) - goto err; + /* Terminate the Write list */ + ret = xdr_stream_encode_item_absent(&sctxt->sc_stream); + if (ret < 0) + return ret; + + return len + ret; +} + +/** + * svc_rdma_encode_reply_chunk - Encode RPC Reply's Reply chunk + * @rctxt: Reply context with information about the RPC Call + * @sctxt: Send context for the RPC Reply + * @length: size in bytes of the payload in the Reply chunk + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Reply's Reply chunk + * %-EMSGSIZE on XDR buffer overflow + * %-E2BIG if the RPC message is larger than the Reply chunk + */ +static ssize_t +svc_rdma_encode_reply_chunk(struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_send_ctxt *sctxt, + unsigned int length) +{ + struct svc_rdma_chunk *chunk; + + if (pcl_is_empty(&rctxt->rc_reply_pcl)) + return xdr_stream_encode_item_absent(&sctxt->sc_stream); + + chunk = pcl_first_chunk(&rctxt->rc_reply_pcl); + if (length > chunk->ch_length) + return -E2BIG; + + chunk->ch_payload_length = length; + return svc_rdma_encode_write_chunk(sctxt, chunk); +} + +struct svc_rdma_map_data { + struct svcxprt_rdma *md_rdma; + struct svc_rdma_send_ctxt *md_ctxt; +}; + +/** + * svc_rdma_page_dma_map - DMA map one page + * @data: pointer to arguments + * @page: struct page to DMA map + * @offset: offset into the page + * @len: number of bytes to map + * + * Returns: + * %0 if DMA mapping was successful + * %-EIO if the page cannot be DMA mapped + */ +static int svc_rdma_page_dma_map(void *data, struct page *page, + unsigned long offset, unsigned int len) +{ + struct svc_rdma_map_data *args = data; + struct svcxprt_rdma *rdma = args->md_rdma; + struct svc_rdma_send_ctxt *ctxt = args->md_ctxt; + struct ib_device *dev = rdma->sc_cm_id->device; + dma_addr_t dma_addr; + + ++ctxt->sc_cur_sge_no; + + dma_addr = ib_dma_map_page(dev, page, offset, len, DMA_TO_DEVICE); + if (ib_dma_mapping_error(dev, dma_addr)) + goto out_maperr; + + trace_svcrdma_dma_map_page(&ctxt->sc_cid, dma_addr, len); + ctxt->sc_sges[ctxt->sc_cur_sge_no].addr = dma_addr; + ctxt->sc_sges[ctxt->sc_cur_sge_no].length = len; + ctxt->sc_send_wr.num_sge++; return 0; - err: - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_frmr(xprt, vec->frmr); - svc_rdma_put_context(ctxt, 0); - /* Fatal error, close transport */ + +out_maperr: + trace_svcrdma_dma_map_err(&ctxt->sc_cid, dma_addr, len); return -EIO; } -static int send_write_chunks(struct svcxprt_rdma *xprt, - struct rpcrdma_msg *rdma_argp, - struct rpcrdma_msg *rdma_resp, - struct svc_rqst *rqstp, - struct svc_rdma_req_map *vec) +/** + * svc_rdma_iov_dma_map - DMA map an iovec + * @data: pointer to arguments + * @iov: kvec to DMA map + * + * ib_dma_map_page() is used here because svc_rdma_dma_unmap() + * handles DMA-unmap and it uses ib_dma_unmap_page() exclusively. + * + * Returns: + * %0 if DMA mapping was successful + * %-EIO if the iovec cannot be DMA mapped + */ +static int svc_rdma_iov_dma_map(void *data, const struct kvec *iov) { - u32 xfer_len = rqstp->rq_res.page_len + rqstp->rq_res.tail[0].iov_len; - int write_len; - int max_write; - u32 xdr_off; - int chunk_off; - int chunk_no; - struct rpcrdma_write_array *arg_ary; - struct rpcrdma_write_array *res_ary; + if (!iov->iov_len) + return 0; + return svc_rdma_page_dma_map(data, virt_to_page(iov->iov_base), + offset_in_page(iov->iov_base), + iov->iov_len); +} + +/** + * svc_rdma_xb_dma_map - DMA map all segments of an xdr_buf + * @xdr: xdr_buf containing portion of an RPC message to transmit + * @data: pointer to arguments + * + * Returns: + * %0 if DMA mapping was successful + * %-EIO if DMA mapping failed + * + * On failure, any DMA mappings that have been already done must be + * unmapped by the caller. + */ +static int svc_rdma_xb_dma_map(const struct xdr_buf *xdr, void *data) +{ + unsigned int len, remaining; + unsigned long pageoff; + struct page **ppages; int ret; - arg_ary = svc_rdma_get_write_array(rdma_argp); - if (!arg_ary) - return 0; - res_ary = (struct rpcrdma_write_array *) - &rdma_resp->rm_body.rm_chunks[1]; + ret = svc_rdma_iov_dma_map(data, &xdr->head[0]); + if (ret < 0) + return ret; - if (vec->frmr) - max_write = vec->frmr->map_len; - else - max_write = xprt->sc_max_sge * PAGE_SIZE; - - /* Write chunks start at the pagelist */ - for (xdr_off = rqstp->rq_res.head[0].iov_len, chunk_no = 0; - xfer_len && chunk_no < arg_ary->wc_nchunks; - chunk_no++) { - struct rpcrdma_segment *arg_ch; - u64 rs_offset; - - arg_ch = &arg_ary->wc_array[chunk_no].wc_target; - write_len = min(xfer_len, ntohl(arg_ch->rs_length)); - - /* Prepare the response chunk given the length actually - * written */ - xdr_decode_hyper((__be32 *)&arg_ch->rs_offset, &rs_offset); - svc_rdma_xdr_encode_array_chunk(res_ary, chunk_no, - arg_ch->rs_handle, - arg_ch->rs_offset, - write_len); - chunk_off = 0; - while (write_len) { - int this_write; - this_write = min(write_len, max_write); - ret = send_write(xprt, rqstp, - ntohl(arg_ch->rs_handle), - rs_offset + chunk_off, - xdr_off, - this_write, - vec); - if (ret) { - dprintk("svcrdma: RDMA_WRITE failed, ret=%d\n", - ret); - return -EIO; - } - chunk_off += this_write; - xdr_off += this_write; - xfer_len -= this_write; - write_len -= this_write; - } + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + pageoff = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - pageoff, remaining); + + ret = svc_rdma_page_dma_map(data, *ppages++, pageoff, len); + if (ret < 0) + return ret; + + remaining -= len; + pageoff = 0; } - /* Update the req with the number of chunks actually used */ - svc_rdma_xdr_encode_write_list(rdma_resp, chunk_no); - return rqstp->rq_res.page_len + rqstp->rq_res.tail[0].iov_len; -} + ret = svc_rdma_iov_dma_map(data, &xdr->tail[0]); + if (ret < 0) + return ret; -static int send_reply_chunks(struct svcxprt_rdma *xprt, - struct rpcrdma_msg *rdma_argp, - struct rpcrdma_msg *rdma_resp, - struct svc_rqst *rqstp, - struct svc_rdma_req_map *vec) -{ - u32 xfer_len = rqstp->rq_res.len; - int write_len; - int max_write; - u32 xdr_off; - int chunk_no; - int chunk_off; - int nchunks; - struct rpcrdma_segment *ch; - struct rpcrdma_write_array *arg_ary; - struct rpcrdma_write_array *res_ary; - int ret; + return xdr->len; +} - arg_ary = svc_rdma_get_reply_array(rdma_argp); - if (!arg_ary) - return 0; - /* XXX: need to fix when reply lists occur with read-list and or - * write-list */ - res_ary = (struct rpcrdma_write_array *) - &rdma_resp->rm_body.rm_chunks[2]; +struct svc_rdma_pullup_data { + u8 *pd_dest; + unsigned int pd_length; + unsigned int pd_num_sges; +}; - if (vec->frmr) - max_write = vec->frmr->map_len; - else - max_write = xprt->sc_max_sge * PAGE_SIZE; - - /* xdr offset starts at RPC message */ - nchunks = ntohl(arg_ary->wc_nchunks); - for (xdr_off = 0, chunk_no = 0; - xfer_len && chunk_no < nchunks; - chunk_no++) { - u64 rs_offset; - ch = &arg_ary->wc_array[chunk_no].wc_target; - write_len = min(xfer_len, htonl(ch->rs_length)); - - /* Prepare the reply chunk given the length actually - * written */ - xdr_decode_hyper((__be32 *)&ch->rs_offset, &rs_offset); - svc_rdma_xdr_encode_array_chunk(res_ary, chunk_no, - ch->rs_handle, ch->rs_offset, - write_len); - chunk_off = 0; - while (write_len) { - int this_write; - - this_write = min(write_len, max_write); - ret = send_write(xprt, rqstp, - ntohl(ch->rs_handle), - rs_offset + chunk_off, - xdr_off, - this_write, - vec); - if (ret) { - dprintk("svcrdma: RDMA_WRITE failed, ret=%d\n", - ret); - return -EIO; - } - chunk_off += this_write; - xdr_off += this_write; - xfer_len -= this_write; - write_len -= this_write; - } +/** + * svc_rdma_xb_count_sges - Count how many SGEs will be needed + * @xdr: xdr_buf containing portion of an RPC message to transmit + * @data: pointer to arguments + * + * Returns: + * Number of SGEs needed to Send the contents of @xdr inline + */ +static int svc_rdma_xb_count_sges(const struct xdr_buf *xdr, + void *data) +{ + struct svc_rdma_pullup_data *args = data; + unsigned int remaining; + unsigned long offset; + + if (xdr->head[0].iov_len) + ++args->pd_num_sges; + + offset = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + ++args->pd_num_sges; + remaining -= min_t(u32, PAGE_SIZE - offset, remaining); + offset = 0; } - /* Update the req with the number of chunks actually used */ - svc_rdma_xdr_encode_reply_array(res_ary, chunk_no); - return rqstp->rq_res.len; + if (xdr->tail[0].iov_len) + ++args->pd_num_sges; + + args->pd_length += xdr->len; + return 0; } -/* This function prepares the portion of the RPCRDMA message to be - * sent in the RDMA_SEND. This function is called after data sent via - * RDMA has already been transmitted. There are three cases: - * - The RPCRDMA header, RPC header, and payload are all sent in a - * single RDMA_SEND. This is the "inline" case. - * - The RPCRDMA header and some portion of the RPC header and data - * are sent via this RDMA_SEND and another portion of the data is - * sent via RDMA. - * - The RPCRDMA header [NOMSG] is sent in this RDMA_SEND and the RPC - * header and data are all transmitted via RDMA. - * In all three cases, this function prepares the RPCRDMA header in - * sge[0], the 'type' parameter indicates the type to place in the - * RPCRDMA header, and the 'byte_count' field indicates how much of - * the XDR to include in this RDMA_SEND. NB: The offset of the payload - * to send is zero in the XDR. +/** + * svc_rdma_pull_up_needed - Determine whether to use pull-up + * @rdma: controlling transport + * @sctxt: send_ctxt for the Send WR + * @write_pcl: Write chunk list provided by client + * @xdr: xdr_buf containing RPC message to transmit + * + * Returns: + * %true if pull-up must be used + * %false otherwise */ -static int send_reply(struct svcxprt_rdma *rdma, - struct svc_rqst *rqstp, - struct page *page, - struct rpcrdma_msg *rdma_resp, - struct svc_rdma_op_ctxt *ctxt, - struct svc_rdma_req_map *vec, - int byte_count) +static bool svc_rdma_pull_up_needed(const struct svcxprt_rdma *rdma, + const struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_pcl *write_pcl, + const struct xdr_buf *xdr) { - struct ib_send_wr send_wr; - struct ib_send_wr inv_wr; - int sge_no; - int sge_bytes; - int page_no; - int pages; + /* Resources needed for the transport header */ + struct svc_rdma_pullup_data args = { + .pd_length = sctxt->sc_hdrbuf.len, + .pd_num_sges = 1, + }; int ret; - /* Post a recv buffer to handle another request. */ - ret = svc_rdma_post_recv(rdma); - if (ret) { - printk(KERN_INFO - "svcrdma: could not post a receive buffer, err=%d." - "Closing transport %p.\n", ret, rdma); - set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); - svc_rdma_put_frmr(rdma, vec->frmr); - svc_rdma_put_context(ctxt, 0); - return -ENOTCONN; - } + ret = pcl_process_nonpayloads(write_pcl, xdr, + svc_rdma_xb_count_sges, &args); + if (ret < 0) + return false; - /* Prepare the context */ - ctxt->pages[0] = page; - ctxt->count = 1; - ctxt->frmr = vec->frmr; - if (vec->frmr) - set_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - else - clear_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); - - /* Prepare the SGE for the RPCRDMA Header */ - ctxt->sge[0].lkey = rdma->sc_dma_lkey; - ctxt->sge[0].length = svc_rdma_xdr_get_reply_hdr_len(rdma_resp); - ctxt->sge[0].addr = - ib_dma_map_page(rdma->sc_cm_id->device, page, 0, - ctxt->sge[0].length, DMA_TO_DEVICE); - if (ib_dma_mapping_error(rdma->sc_cm_id->device, ctxt->sge[0].addr)) - goto err; - atomic_inc(&rdma->sc_dma_used); - - ctxt->direction = DMA_TO_DEVICE; - - /* Map the payload indicated by 'byte_count' */ - for (sge_no = 1; byte_count && sge_no < vec->count; sge_no++) { - int xdr_off = 0; - sge_bytes = min_t(size_t, vec->sge[sge_no].iov_len, byte_count); - byte_count -= sge_bytes; - if (!vec->frmr) { - ctxt->sge[sge_no].addr = - dma_map_xdr(rdma, &rqstp->rq_res, xdr_off, - sge_bytes, DMA_TO_DEVICE); - xdr_off += sge_bytes; - if (ib_dma_mapping_error(rdma->sc_cm_id->device, - ctxt->sge[sge_no].addr)) - goto err; - atomic_inc(&rdma->sc_dma_used); - ctxt->sge[sge_no].lkey = rdma->sc_dma_lkey; - } else { - ctxt->sge[sge_no].addr = (unsigned long) - vec->sge[sge_no].iov_base; - ctxt->sge[sge_no].lkey = vec->frmr->mr->lkey; - } - ctxt->sge[sge_no].length = sge_bytes; + if (args.pd_length < RPCRDMA_PULLUP_THRESH) + return true; + return args.pd_num_sges >= rdma->sc_max_send_sges; +} + +/** + * svc_rdma_xb_linearize - Copy region of xdr_buf to flat buffer + * @xdr: xdr_buf containing portion of an RPC message to copy + * @data: pointer to arguments + * + * Returns: + * Always zero. + */ +static int svc_rdma_xb_linearize(const struct xdr_buf *xdr, + void *data) +{ + struct svc_rdma_pullup_data *args = data; + unsigned int len, remaining; + unsigned long pageoff; + struct page **ppages; + + if (xdr->head[0].iov_len) { + memcpy(args->pd_dest, xdr->head[0].iov_base, xdr->head[0].iov_len); + args->pd_dest += xdr->head[0].iov_len; } - BUG_ON(byte_count != 0); - /* Save all respages in the ctxt and remove them from the - * respages array. They are our pages until the I/O - * completes. - */ - pages = rqstp->rq_next_page - rqstp->rq_respages; - for (page_no = 0; page_no < pages; page_no++) { - ctxt->pages[page_no+1] = rqstp->rq_respages[page_no]; - ctxt->count++; - rqstp->rq_respages[page_no] = NULL; - /* - * If there are more pages than SGE, terminate SGE - * list so that svc_rdma_unmap_dma doesn't attempt to - * unmap garbage. - */ - if (page_no+1 >= sge_no) - ctxt->sge[page_no+1].length = 0; + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + pageoff = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - pageoff, remaining); + memcpy(args->pd_dest, page_address(*ppages) + pageoff, len); + remaining -= len; + args->pd_dest += len; + pageoff = 0; + ppages++; } - BUG_ON(sge_no > rdma->sc_max_sge); - memset(&send_wr, 0, sizeof send_wr); - ctxt->wr_op = IB_WR_SEND; - send_wr.wr_id = (unsigned long)ctxt; - send_wr.sg_list = ctxt->sge; - send_wr.num_sge = sge_no; - send_wr.opcode = IB_WR_SEND; - send_wr.send_flags = IB_SEND_SIGNALED; - if (vec->frmr) { - /* Prepare INVALIDATE WR */ - memset(&inv_wr, 0, sizeof inv_wr); - inv_wr.opcode = IB_WR_LOCAL_INV; - inv_wr.send_flags = IB_SEND_SIGNALED; - inv_wr.ex.invalidate_rkey = - vec->frmr->mr->lkey; - send_wr.next = &inv_wr; + + if (xdr->tail[0].iov_len) { + memcpy(args->pd_dest, xdr->tail[0].iov_base, xdr->tail[0].iov_len); + args->pd_dest += xdr->tail[0].iov_len; } - ret = svc_rdma_send(rdma, &send_wr); - if (ret) - goto err; + args->pd_length += xdr->len; + return 0; +} +/** + * svc_rdma_pull_up_reply_msg - Copy Reply into a single buffer + * @rdma: controlling transport + * @sctxt: send_ctxt for the Send WR; xprt hdr is already prepared + * @write_pcl: Write chunk list provided by client + * @xdr: prepared xdr_buf containing RPC message + * + * The device is not capable of sending the reply directly. + * Assemble the elements of @xdr into the transport header buffer. + * + * Assumptions: + * pull_up_needed has determined that @xdr will fit in the buffer. + * + * Returns: + * %0 if pull-up was successful + * %-EMSGSIZE if a buffer manipulation problem occurred + */ +static int svc_rdma_pull_up_reply_msg(const struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_pcl *write_pcl, + const struct xdr_buf *xdr) +{ + struct svc_rdma_pullup_data args = { + .pd_dest = sctxt->sc_xprt_buf + sctxt->sc_hdrbuf.len, + }; + int ret; + + ret = pcl_process_nonpayloads(write_pcl, xdr, + svc_rdma_xb_linearize, &args); + if (ret < 0) + return ret; + + sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len + args.pd_length; + trace_svcrdma_send_pullup(sctxt, args.pd_length); return 0; +} - err: - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_frmr(rdma, vec->frmr); - svc_rdma_put_context(ctxt, 1); - return -EIO; +/* svc_rdma_map_reply_msg - DMA map the buffer holding RPC message + * @rdma: controlling transport + * @sctxt: send_ctxt for the Send WR + * @write_pcl: Write chunk list provided by client + * @reply_pcl: Reply chunk provided by client + * @xdr: prepared xdr_buf containing RPC message + * + * Returns: + * %0 if DMA mapping was successful. + * %-EMSGSIZE if a buffer manipulation problem occurred + * %-EIO if DMA mapping failed + * + * The Send WR's num_sge field is set in all cases. + */ +int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_pcl *write_pcl, + const struct svc_rdma_pcl *reply_pcl, + const struct xdr_buf *xdr) +{ + struct svc_rdma_map_data args = { + .md_rdma = rdma, + .md_ctxt = sctxt, + }; + + /* Set up the (persistently-mapped) transport header SGE. */ + sctxt->sc_send_wr.num_sge = 1; + sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len; + + /* If there is a Reply chunk, nothing follows the transport + * header, so there is nothing to map. + */ + if (!pcl_is_empty(reply_pcl)) + return 0; + + /* For pull-up, svc_rdma_send() will sync the transport header. + * No additional DMA mapping is necessary. + */ + if (svc_rdma_pull_up_needed(rdma, sctxt, write_pcl, xdr)) + return svc_rdma_pull_up_reply_msg(rdma, sctxt, write_pcl, xdr); + + return pcl_process_nonpayloads(write_pcl, xdr, + svc_rdma_xb_dma_map, &args); } -void svc_rdma_prep_reply_hdr(struct svc_rqst *rqstp) +/* The svc_rqst and all resources it owns are released as soon as + * svc_rdma_sendto returns. Transfer pages under I/O to the ctxt + * so they are released by the Send completion handler. + */ +static void svc_rdma_save_io_pages(struct svc_rqst *rqstp, + struct svc_rdma_send_ctxt *ctxt) { + int i, pages = rqstp->rq_next_page - rqstp->rq_respages; + + ctxt->sc_page_count += pages; + for (i = 0; i < pages; i++) { + ctxt->sc_pages[i] = rqstp->rq_respages[i]; + rqstp->rq_respages[i] = NULL; + } + + /* Prevent svc_xprt_release from releasing pages in rq_pages */ + rqstp->rq_next_page = rqstp->rq_respages; } -/* - * Return the start of an xdr buffer. +/* Prepare the portion of the RPC Reply that will be transmitted + * via RDMA Send. The RPC-over-RDMA transport header is prepared + * in sc_sges[0], and the RPC xdr_buf is prepared in following sges. + * + * Depending on whether a Write list or Reply chunk is present, + * the server may Send all, a portion of, or none of the xdr_buf. + * In the latter case, only the transport header (sc_sges[0]) is + * transmitted. + * + * Assumptions: + * - The Reply's transport header will never be larger than a page. + */ +static int svc_rdma_send_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_recv_ctxt *rctxt, + struct svc_rqst *rqstp) +{ + struct ib_send_wr *send_wr = &sctxt->sc_send_wr; + int ret; + + ret = svc_rdma_map_reply_msg(rdma, sctxt, &rctxt->rc_write_pcl, + &rctxt->rc_reply_pcl, &rqstp->rq_res); + if (ret < 0) + return ret; + + /* Transfer pages involved in RDMA Writes to the sctxt's + * page array. Completion handling releases these pages. + */ + svc_rdma_save_io_pages(rqstp, sctxt); + + if (rctxt->rc_inv_rkey) { + send_wr->opcode = IB_WR_SEND_WITH_INV; + send_wr->ex.invalidate_rkey = rctxt->rc_inv_rkey; + } else { + send_wr->opcode = IB_WR_SEND; + } + + return svc_rdma_post_send(rdma, sctxt); +} + +/** + * svc_rdma_send_error_msg - Send an RPC/RDMA v1 error response + * @rdma: controlling transport context + * @sctxt: Send context for the response + * @rctxt: Receive context for incoming bad message + * @status: negative errno indicating error that occurred + * + * Given the client-provided Read, Write, and Reply chunks, the + * server was not able to parse the Call or form a complete Reply. + * Return an RDMA_ERROR message so the client can retire the RPC + * transaction. + * + * The caller does not have to release @sctxt. It is released by + * Send completion, or by this function on error. */ -static void *xdr_start(struct xdr_buf *xdr) +void svc_rdma_send_error_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + struct svc_rdma_recv_ctxt *rctxt, + int status) { - return xdr->head[0].iov_base - - (xdr->len - - xdr->page_len - - xdr->tail[0].iov_len - - xdr->head[0].iov_len); + __be32 *rdma_argp = rctxt->rc_recv_buf; + __be32 *p; + + rpcrdma_set_xdrlen(&sctxt->sc_hdrbuf, 0); + xdr_init_encode(&sctxt->sc_stream, &sctxt->sc_hdrbuf, + sctxt->sc_xprt_buf, NULL); + + p = xdr_reserve_space(&sctxt->sc_stream, + rpcrdma_fixed_maxsz * sizeof(*p)); + if (!p) + goto put_ctxt; + + *p++ = *rdma_argp; + *p++ = *(rdma_argp + 1); + *p++ = rdma->sc_fc_credits; + *p = rdma_error; + + switch (status) { + case -EPROTONOSUPPORT: + p = xdr_reserve_space(&sctxt->sc_stream, 3 * sizeof(*p)); + if (!p) + goto put_ctxt; + + *p++ = err_vers; + *p++ = rpcrdma_version; + *p = rpcrdma_version; + trace_svcrdma_err_vers(*rdma_argp); + break; + default: + p = xdr_reserve_space(&sctxt->sc_stream, sizeof(*p)); + if (!p) + goto put_ctxt; + + *p = err_chunk; + trace_svcrdma_err_chunk(*rdma_argp); + } + + /* Remote Invalidation is skipped for simplicity. */ + sctxt->sc_send_wr.num_sge = 1; + sctxt->sc_send_wr.opcode = IB_WR_SEND; + sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len; + if (svc_rdma_post_send(rdma, sctxt)) + goto put_ctxt; + return; + +put_ctxt: + svc_rdma_send_ctxt_put(rdma, sctxt); } +/** + * svc_rdma_sendto - Transmit an RPC reply + * @rqstp: processed RPC request, reply XDR already in ::rq_res + * + * Any resources still associated with @rqstp are released upon return. + * If no reply message was possible, the connection is closed. + * + * Returns: + * %0 if an RPC reply has been successfully posted, + * %-ENOMEM if a resource shortage occurred (connection is lost), + * %-ENOTCONN if posting failed (connection is lost). + */ int svc_rdma_sendto(struct svc_rqst *rqstp) { struct svc_xprt *xprt = rqstp->rq_xprt; struct svcxprt_rdma *rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); - struct rpcrdma_msg *rdma_argp; - struct rpcrdma_msg *rdma_resp; - struct rpcrdma_write_array *reply_ary; - enum rpcrdma_proc reply_type; + struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt; + __be32 *rdma_argp = rctxt->rc_recv_buf; + struct svc_rdma_send_ctxt *sctxt; + unsigned int rc_size; + __be32 *p; int ret; - int inline_bytes; - struct page *res_page; - struct svc_rdma_op_ctxt *ctxt; - struct svc_rdma_req_map *vec; - - dprintk("svcrdma: sending response for rqstp=%p\n", rqstp); - - /* Get the RDMA request header. */ - rdma_argp = xdr_start(&rqstp->rq_arg); - - /* Build an req vec for the XDR */ - ctxt = svc_rdma_get_context(rdma); - ctxt->direction = DMA_TO_DEVICE; - vec = svc_rdma_get_req_map(); - ret = map_xdr(rdma, &rqstp->rq_res, vec); - if (ret) - goto err0; - inline_bytes = rqstp->rq_res.len; - - /* Create the RDMA response header */ - res_page = svc_rdma_get_page(); - rdma_resp = page_address(res_page); - reply_ary = svc_rdma_get_reply_array(rdma_argp); - if (reply_ary) - reply_type = RDMA_NOMSG; - else - reply_type = RDMA_MSG; - svc_rdma_xdr_encode_reply_header(rdma, rdma_argp, - rdma_resp, reply_type); - - /* Send any write-chunk data and build resp write-list */ - ret = send_write_chunks(rdma, rdma_argp, rdma_resp, - rqstp, vec); - if (ret < 0) { - printk(KERN_ERR "svcrdma: failed to send write chunks, rc=%d\n", - ret); - goto err1; - } - inline_bytes -= ret; - - /* Send any reply-list data and update resp reply-list */ - ret = send_reply_chunks(rdma, rdma_argp, rdma_resp, - rqstp, vec); - if (ret < 0) { - printk(KERN_ERR "svcrdma: failed to send reply chunks, rc=%d\n", - ret); - goto err1; + + ret = -ENOTCONN; + if (svc_xprt_is_dead(xprt)) + goto drop_connection; + + ret = -ENOMEM; + sctxt = svc_rdma_send_ctxt_get(rdma); + if (!sctxt) + goto drop_connection; + + ret = -EMSGSIZE; + p = xdr_reserve_space(&sctxt->sc_stream, + rpcrdma_fixed_maxsz * sizeof(*p)); + if (!p) + goto put_ctxt; + + ret = svc_rdma_send_write_list(rdma, rctxt, &rqstp->rq_res); + if (ret < 0) + goto put_ctxt; + + rc_size = 0; + if (!pcl_is_empty(&rctxt->rc_reply_pcl)) { + ret = svc_rdma_prepare_reply_chunk(rdma, &rctxt->rc_write_pcl, + &rctxt->rc_reply_pcl, sctxt, + &rqstp->rq_res); + if (ret < 0) + goto reply_chunk; + rc_size = ret; } - inline_bytes -= ret; - - ret = send_reply(rdma, rqstp, res_page, rdma_resp, ctxt, vec, - inline_bytes); - svc_rdma_put_req_map(vec); - dprintk("svcrdma: send_reply returns %d\n", ret); - return ret; - - err1: - put_page(res_page); - err0: - svc_rdma_put_req_map(vec); - svc_rdma_put_context(ctxt, 0); - return ret; + + *p++ = *rdma_argp; + *p++ = *(rdma_argp + 1); + *p++ = rdma->sc_fc_credits; + *p = pcl_is_empty(&rctxt->rc_reply_pcl) ? rdma_msg : rdma_nomsg; + + ret = svc_rdma_encode_read_list(sctxt); + if (ret < 0) + goto put_ctxt; + ret = svc_rdma_encode_write_list(rctxt, sctxt); + if (ret < 0) + goto put_ctxt; + ret = svc_rdma_encode_reply_chunk(rctxt, sctxt, rc_size); + if (ret < 0) + goto put_ctxt; + + ret = svc_rdma_send_reply_msg(rdma, sctxt, rctxt, rqstp); + if (ret < 0) + goto put_ctxt; + return 0; + +reply_chunk: + if (ret != -E2BIG && ret != -EINVAL) + goto put_ctxt; + + /* Send completion releases payload pages that were part + * of previously posted RDMA Writes. + */ + svc_rdma_save_io_pages(rqstp, sctxt); + svc_rdma_send_error_msg(rdma, sctxt, rctxt, ret); + return 0; + +put_ctxt: + svc_rdma_send_ctxt_put(rdma, sctxt); +drop_connection: + trace_svcrdma_send_err(rqstp, ret); + svc_xprt_deferred_close(&rdma->sc_xprt); + return -ENOTCONN; +} + +/** + * svc_rdma_result_payload - special processing for a result payload + * @rqstp: RPC transaction context + * @offset: payload's byte offset in @rqstp->rq_res + * @length: size of payload, in bytes + * + * Assign the passed-in result payload to the current Write chunk, + * and advance to cur_result_payload to the next Write chunk, if + * there is one. + * + * Return values: + * %0 if successful or nothing needed to be done + * %-E2BIG if the payload was larger than the Write chunk + */ +int svc_rdma_result_payload(struct svc_rqst *rqstp, unsigned int offset, + unsigned int length) +{ + struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt; + struct svc_rdma_chunk *chunk; + + chunk = rctxt->rc_cur_result_payload; + if (!length || !chunk) + return 0; + rctxt->rc_cur_result_payload = + pcl_next_chunk(&rctxt->rc_write_pcl, chunk); + + if (length > chunk->ch_length) + return -E2BIG; + chunk->ch_position = offset; + chunk->ch_payload_length = length; + return 0; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index 62e4f9bcc387..b7b318ad25c4 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -1,4 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. * Copyright (c) 2005-2007 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -39,159 +42,76 @@ * Author: Tom Tucker <tom@opengridcomputing.com> */ -#include <linux/sunrpc/svc_xprt.h> -#include <linux/sunrpc/debug.h> -#include <linux/sunrpc/rpc_rdma.h> #include <linux/interrupt.h> #include <linux/sched.h> #include <linux/slab.h> #include <linux/spinlock.h> #include <linux/workqueue.h> +#include <linux/export.h> + #include <rdma/ib_verbs.h> #include <rdma/rdma_cm.h> +#include <rdma/rw.h> + +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/svc_xprt.h> #include <linux/sunrpc/svc_rdma.h> -#include <linux/export.h> + #include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> #define RPCDBG_FACILITY RPCDBG_SVCXPRT +static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, + struct net *net, int node); +static int svc_rdma_listen_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event); static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, struct net *net, struct sockaddr *sa, int salen, int flags); static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt); -static void svc_rdma_release_rqst(struct svc_rqst *); -static void dto_tasklet_func(unsigned long data); static void svc_rdma_detach(struct svc_xprt *xprt); static void svc_rdma_free(struct svc_xprt *xprt); static int svc_rdma_has_wspace(struct svc_xprt *xprt); -static void rq_cq_reap(struct svcxprt_rdma *xprt); -static void sq_cq_reap(struct svcxprt_rdma *xprt); +static void svc_rdma_kill_temp_xprt(struct svc_xprt *); -static DECLARE_TASKLET(dto_tasklet, dto_tasklet_func, 0UL); -static DEFINE_SPINLOCK(dto_lock); -static LIST_HEAD(dto_xprt_q); - -static struct svc_xprt_ops svc_rdma_ops = { +static const struct svc_xprt_ops svc_rdma_ops = { .xpo_create = svc_rdma_create, .xpo_recvfrom = svc_rdma_recvfrom, .xpo_sendto = svc_rdma_sendto, - .xpo_release_rqst = svc_rdma_release_rqst, + .xpo_result_payload = svc_rdma_result_payload, + .xpo_release_ctxt = svc_rdma_release_ctxt, .xpo_detach = svc_rdma_detach, .xpo_free = svc_rdma_free, - .xpo_prep_reply_hdr = svc_rdma_prep_reply_hdr, .xpo_has_wspace = svc_rdma_has_wspace, .xpo_accept = svc_rdma_accept, + .xpo_kill_temp_xprt = svc_rdma_kill_temp_xprt, }; struct svc_xprt_class svc_rdma_class = { .xcl_name = "rdma", .xcl_owner = THIS_MODULE, .xcl_ops = &svc_rdma_ops, - .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_RDMA, + .xcl_ident = XPRT_TRANSPORT_RDMA, }; -struct svc_rdma_op_ctxt *svc_rdma_get_context(struct svcxprt_rdma *xprt) -{ - struct svc_rdma_op_ctxt *ctxt; - - while (1) { - ctxt = kmem_cache_alloc(svc_rdma_ctxt_cachep, GFP_KERNEL); - if (ctxt) - break; - schedule_timeout_uninterruptible(msecs_to_jiffies(500)); - } - ctxt->xprt = xprt; - INIT_LIST_HEAD(&ctxt->dto_q); - ctxt->count = 0; - ctxt->frmr = NULL; - atomic_inc(&xprt->sc_ctxt_used); - return ctxt; -} - -void svc_rdma_unmap_dma(struct svc_rdma_op_ctxt *ctxt) -{ - struct svcxprt_rdma *xprt = ctxt->xprt; - int i; - for (i = 0; i < ctxt->count && ctxt->sge[i].length; i++) { - /* - * Unmap the DMA addr in the SGE if the lkey matches - * the sc_dma_lkey, otherwise, ignore it since it is - * an FRMR lkey and will be unmapped later when the - * last WR that uses it completes. - */ - if (ctxt->sge[i].lkey == xprt->sc_dma_lkey) { - atomic_dec(&xprt->sc_dma_used); - ib_dma_unmap_page(xprt->sc_cm_id->device, - ctxt->sge[i].addr, - ctxt->sge[i].length, - ctxt->direction); - } - } -} - -void svc_rdma_put_context(struct svc_rdma_op_ctxt *ctxt, int free_pages) -{ - struct svcxprt_rdma *xprt; - int i; - - BUG_ON(!ctxt); - xprt = ctxt->xprt; - if (free_pages) - for (i = 0; i < ctxt->count; i++) - put_page(ctxt->pages[i]); - - kmem_cache_free(svc_rdma_ctxt_cachep, ctxt); - atomic_dec(&xprt->sc_ctxt_used); -} - -/* - * Temporary NFS req mappings are shared across all transport - * instances. These are short lived and should be bounded by the number - * of concurrent server threads * depth of the SQ. - */ -struct svc_rdma_req_map *svc_rdma_get_req_map(void) -{ - struct svc_rdma_req_map *map; - while (1) { - map = kmem_cache_alloc(svc_rdma_map_cachep, GFP_KERNEL); - if (map) - break; - schedule_timeout_uninterruptible(msecs_to_jiffies(500)); - } - map->count = 0; - map->frmr = NULL; - return map; -} - -void svc_rdma_put_req_map(struct svc_rdma_req_map *map) -{ - kmem_cache_free(svc_rdma_map_cachep, map); -} - -/* ib_cq event handler */ -static void cq_event_handler(struct ib_event *event, void *context) -{ - struct svc_xprt *xprt = context; - dprintk("svcrdma: received CQ event id=%d, context=%p\n", - event->event, context); - set_bit(XPT_CLOSE, &xprt->xpt_flags); -} - /* QP event handler */ static void qp_event_handler(struct ib_event *event, void *context) { struct svc_xprt *xprt = context; + trace_svcrdma_qp_error(event, (struct sockaddr *)&xprt->xpt_remote); switch (event->event) { /* These are considered benign events */ case IB_EVENT_PATH_MIG: case IB_EVENT_COMM_EST: case IB_EVENT_SQ_DRAINED: case IB_EVENT_QP_LAST_WQE_REACHED: - dprintk("svcrdma: QP event %d received for QP=%p\n", - event->event, event->element.qp); break; + /* These are considered fatal events */ case IB_EVENT_PATH_MIG_ERR: case IB_EVENT_QP_FATAL: @@ -199,338 +119,104 @@ static void qp_event_handler(struct ib_event *event, void *context) case IB_EVENT_QP_ACCESS_ERR: case IB_EVENT_DEVICE_FATAL: default: - dprintk("svcrdma: QP ERROR event %d received for QP=%p, " - "closing transport\n", - event->event, event->element.qp); - set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_deferred_close(xprt); break; } } -/* - * Data Transfer Operation Tasklet - * - * Walks a list of transports with I/O pending, removing entries as - * they are added to the server's I/O pending list. Two bits indicate - * if SQ, RQ, or both have I/O pending. The dto_lock is an irqsave - * spinlock that serializes access to the transport list with the RQ - * and SQ interrupt handlers. - */ -static void dto_tasklet_func(unsigned long data) -{ - struct svcxprt_rdma *xprt; - unsigned long flags; - - spin_lock_irqsave(&dto_lock, flags); - while (!list_empty(&dto_xprt_q)) { - xprt = list_entry(dto_xprt_q.next, - struct svcxprt_rdma, sc_dto_q); - list_del_init(&xprt->sc_dto_q); - spin_unlock_irqrestore(&dto_lock, flags); - - rq_cq_reap(xprt); - sq_cq_reap(xprt); - - svc_xprt_put(&xprt->sc_xprt); - spin_lock_irqsave(&dto_lock, flags); - } - spin_unlock_irqrestore(&dto_lock, flags); -} - -/* - * Receive Queue Completion Handler - * - * Since an RQ completion handler is called on interrupt context, we - * need to defer the handling of the I/O to a tasklet - */ -static void rq_comp_handler(struct ib_cq *cq, void *cq_context) -{ - struct svcxprt_rdma *xprt = cq_context; - unsigned long flags; - - /* Guard against unconditional flush call for destroyed QP */ - if (atomic_read(&xprt->sc_xprt.xpt_ref.refcount)==0) - return; - - /* - * Set the bit regardless of whether or not it's on the list - * because it may be on the list already due to an SQ - * completion. - */ - set_bit(RDMAXPRT_RQ_PENDING, &xprt->sc_flags); - - /* - * If this transport is not already on the DTO transport queue, - * add it - */ - spin_lock_irqsave(&dto_lock, flags); - if (list_empty(&xprt->sc_dto_q)) { - svc_xprt_get(&xprt->sc_xprt); - list_add_tail(&xprt->sc_dto_q, &dto_xprt_q); - } - spin_unlock_irqrestore(&dto_lock, flags); - - /* Tasklet does all the work to avoid irqsave locks. */ - tasklet_schedule(&dto_tasklet); -} - -/* - * rq_cq_reap - Process the RQ CQ. - * - * Take all completing WC off the CQE and enqueue the associated DTO - * context on the dto_q for the transport. - * - * Note that caller must hold a transport reference. - */ -static void rq_cq_reap(struct svcxprt_rdma *xprt) +static struct rdma_cm_id * +svc_rdma_create_listen_id(struct net *net, struct sockaddr *sap, + void *context) { + struct rdma_cm_id *listen_id; int ret; - struct ib_wc wc; - struct svc_rdma_op_ctxt *ctxt = NULL; - if (!test_and_clear_bit(RDMAXPRT_RQ_PENDING, &xprt->sc_flags)) - return; + listen_id = rdma_create_id(net, svc_rdma_listen_handler, context, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(listen_id)) + return listen_id; - ib_req_notify_cq(xprt->sc_rq_cq, IB_CQ_NEXT_COMP); - atomic_inc(&rdma_stat_rq_poll); - - while ((ret = ib_poll_cq(xprt->sc_rq_cq, 1, &wc)) > 0) { - ctxt = (struct svc_rdma_op_ctxt *)(unsigned long)wc.wr_id; - ctxt->wc_status = wc.status; - ctxt->byte_len = wc.byte_len; - svc_rdma_unmap_dma(ctxt); - if (wc.status != IB_WC_SUCCESS) { - /* Close the transport */ - dprintk("svcrdma: transport closing putting ctxt %p\n", ctxt); - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); - svc_rdma_put_context(ctxt, 1); - svc_xprt_put(&xprt->sc_xprt); - continue; - } - spin_lock_bh(&xprt->sc_rq_dto_lock); - list_add_tail(&ctxt->dto_q, &xprt->sc_rq_dto_q); - spin_unlock_bh(&xprt->sc_rq_dto_lock); - svc_xprt_put(&xprt->sc_xprt); - } - - if (ctxt) - atomic_inc(&rdma_stat_rq_prod); - - set_bit(XPT_DATA, &xprt->sc_xprt.xpt_flags); - /* - * If data arrived before established event, - * don't enqueue. This defers RPC I/O until the - * RDMA connection is complete. + /* Allow both IPv4 and IPv6 sockets to bind a single port + * at the same time. */ - if (!test_bit(RDMAXPRT_CONN_PENDING, &xprt->sc_flags)) - svc_xprt_enqueue(&xprt->sc_xprt); -} - -/* - * Process a completion context - */ -static void process_context(struct svcxprt_rdma *xprt, - struct svc_rdma_op_ctxt *ctxt) -{ - svc_rdma_unmap_dma(ctxt); - - switch (ctxt->wr_op) { - case IB_WR_SEND: - if (test_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags)) - svc_rdma_put_frmr(xprt, ctxt->frmr); - svc_rdma_put_context(ctxt, 1); - break; - - case IB_WR_RDMA_WRITE: - svc_rdma_put_context(ctxt, 0); - break; - - case IB_WR_RDMA_READ: - case IB_WR_RDMA_READ_WITH_INV: - if (test_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags)) { - struct svc_rdma_op_ctxt *read_hdr = ctxt->read_hdr; - BUG_ON(!read_hdr); - if (test_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags)) - svc_rdma_put_frmr(xprt, ctxt->frmr); - spin_lock_bh(&xprt->sc_rq_dto_lock); - set_bit(XPT_DATA, &xprt->sc_xprt.xpt_flags); - list_add_tail(&read_hdr->dto_q, - &xprt->sc_read_complete_q); - spin_unlock_bh(&xprt->sc_rq_dto_lock); - svc_xprt_enqueue(&xprt->sc_xprt); - } - svc_rdma_put_context(ctxt, 0); - break; - - default: - printk(KERN_ERR "svcrdma: unexpected completion type, " - "opcode=%d\n", - ctxt->wr_op); - break; - } -} - -/* - * Send Queue Completion Handler - potentially called on interrupt context. - * - * Note that caller must hold a transport reference. - */ -static void sq_cq_reap(struct svcxprt_rdma *xprt) -{ - struct svc_rdma_op_ctxt *ctxt = NULL; - struct ib_wc wc; - struct ib_cq *cq = xprt->sc_sq_cq; - int ret; - - if (!test_and_clear_bit(RDMAXPRT_SQ_PENDING, &xprt->sc_flags)) - return; - - ib_req_notify_cq(xprt->sc_sq_cq, IB_CQ_NEXT_COMP); - atomic_inc(&rdma_stat_sq_poll); - while ((ret = ib_poll_cq(cq, 1, &wc)) > 0) { - if (wc.status != IB_WC_SUCCESS) - /* Close the transport */ - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); +#if IS_ENABLED(CONFIG_IPV6) + ret = rdma_set_afonly(listen_id, 1); + if (ret) + goto out_destroy; +#endif + ret = rdma_bind_addr(listen_id, sap); + if (ret) + goto out_destroy; - /* Decrement used SQ WR count */ - atomic_dec(&xprt->sc_sq_count); - wake_up(&xprt->sc_send_wait); - - ctxt = (struct svc_rdma_op_ctxt *)(unsigned long)wc.wr_id; - if (ctxt) - process_context(xprt, ctxt); - - svc_xprt_put(&xprt->sc_xprt); - } - - if (ctxt) - atomic_inc(&rdma_stat_sq_prod); -} - -static void sq_comp_handler(struct ib_cq *cq, void *cq_context) -{ - struct svcxprt_rdma *xprt = cq_context; - unsigned long flags; - - /* Guard against unconditional flush call for destroyed QP */ - if (atomic_read(&xprt->sc_xprt.xpt_ref.refcount)==0) - return; - - /* - * Set the bit regardless of whether or not it's on the list - * because it may be on the list already due to an RQ - * completion. - */ - set_bit(RDMAXPRT_SQ_PENDING, &xprt->sc_flags); + ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); + if (ret) + goto out_destroy; - /* - * If this transport is not already on the DTO transport queue, - * add it - */ - spin_lock_irqsave(&dto_lock, flags); - if (list_empty(&xprt->sc_dto_q)) { - svc_xprt_get(&xprt->sc_xprt); - list_add_tail(&xprt->sc_dto_q, &dto_xprt_q); - } - spin_unlock_irqrestore(&dto_lock, flags); + return listen_id; - /* Tasklet does all the work to avoid irqsave locks. */ - tasklet_schedule(&dto_tasklet); +out_destroy: + rdma_destroy_id(listen_id); + return ERR_PTR(ret); } -static struct svcxprt_rdma *rdma_create_xprt(struct svc_serv *serv, - int listener) +static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, + struct net *net, int node) { - struct svcxprt_rdma *cma_xprt = kzalloc(sizeof *cma_xprt, GFP_KERNEL); + static struct lock_class_key svcrdma_rwctx_lock; + static struct lock_class_key svcrdma_sctx_lock; + static struct lock_class_key svcrdma_dto_lock; + struct svcxprt_rdma *cma_xprt; + cma_xprt = kzalloc_node(sizeof(*cma_xprt), GFP_KERNEL, node); if (!cma_xprt) return NULL; - svc_xprt_init(&init_net, &svc_rdma_class, &cma_xprt->sc_xprt, serv); + + svc_xprt_init(net, &svc_rdma_class, &cma_xprt->sc_xprt, serv); INIT_LIST_HEAD(&cma_xprt->sc_accept_q); - INIT_LIST_HEAD(&cma_xprt->sc_dto_q); INIT_LIST_HEAD(&cma_xprt->sc_rq_dto_q); INIT_LIST_HEAD(&cma_xprt->sc_read_complete_q); - INIT_LIST_HEAD(&cma_xprt->sc_frmr_q); + init_llist_head(&cma_xprt->sc_send_ctxts); + init_llist_head(&cma_xprt->sc_recv_ctxts); + init_llist_head(&cma_xprt->sc_rw_ctxts); init_waitqueue_head(&cma_xprt->sc_send_wait); spin_lock_init(&cma_xprt->sc_lock); spin_lock_init(&cma_xprt->sc_rq_dto_lock); - spin_lock_init(&cma_xprt->sc_frmr_q_lock); - - cma_xprt->sc_ord = svcrdma_ord; + lockdep_set_class(&cma_xprt->sc_rq_dto_lock, &svcrdma_dto_lock); + spin_lock_init(&cma_xprt->sc_send_lock); + lockdep_set_class(&cma_xprt->sc_send_lock, &svcrdma_sctx_lock); + spin_lock_init(&cma_xprt->sc_rw_ctxt_lock); + lockdep_set_class(&cma_xprt->sc_rw_ctxt_lock, &svcrdma_rwctx_lock); - cma_xprt->sc_max_req_size = svcrdma_max_req_size; - cma_xprt->sc_max_requests = svcrdma_max_requests; - cma_xprt->sc_sq_depth = svcrdma_max_requests * RPCRDMA_SQ_DEPTH_MULT; - atomic_set(&cma_xprt->sc_sq_count, 0); - atomic_set(&cma_xprt->sc_ctxt_used, 0); - - if (listener) - set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags); + /* + * Note that this implies that the underlying transport support + * has some form of congestion control (see RFC 7530 section 3.1 + * paragraph 2). For now, we assume that all supported RDMA + * transports are suitable here. + */ + set_bit(XPT_CONG_CTRL, &cma_xprt->sc_xprt.xpt_flags); return cma_xprt; } -struct page *svc_rdma_get_page(void) -{ - struct page *page; - - while ((page = alloc_page(GFP_KERNEL)) == NULL) { - /* If we can't get memory, wait a bit and try again */ - printk(KERN_INFO "svcrdma: out of memory...retrying in 1000 " - "jiffies.\n"); - schedule_timeout_uninterruptible(msecs_to_jiffies(1000)); - } - return page; -} - -int svc_rdma_post_recv(struct svcxprt_rdma *xprt) +static void +svc_rdma_parse_connect_private(struct svcxprt_rdma *newxprt, + struct rdma_conn_param *param) { - struct ib_recv_wr recv_wr, *bad_recv_wr; - struct svc_rdma_op_ctxt *ctxt; - struct page *page; - dma_addr_t pa; - int sge_no; - int buflen; - int ret; + const struct rpcrdma_connect_private *pmsg = param->private_data; - ctxt = svc_rdma_get_context(xprt); - buflen = 0; - ctxt->direction = DMA_FROM_DEVICE; - for (sge_no = 0; buflen < xprt->sc_max_req_size; sge_no++) { - BUG_ON(sge_no >= xprt->sc_max_sge); - page = svc_rdma_get_page(); - ctxt->pages[sge_no] = page; - pa = ib_dma_map_page(xprt->sc_cm_id->device, - page, 0, PAGE_SIZE, - DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, pa)) - goto err_put_ctxt; - atomic_inc(&xprt->sc_dma_used); - ctxt->sge[sge_no].addr = pa; - ctxt->sge[sge_no].length = PAGE_SIZE; - ctxt->sge[sge_no].lkey = xprt->sc_dma_lkey; - ctxt->count = sge_no + 1; - buflen += PAGE_SIZE; - } - recv_wr.next = NULL; - recv_wr.sg_list = &ctxt->sge[0]; - recv_wr.num_sge = ctxt->count; - recv_wr.wr_id = (u64)(unsigned long)ctxt; + if (pmsg && + pmsg->cp_magic == rpcrdma_cmp_magic && + pmsg->cp_version == RPCRDMA_CMP_VERSION) { + newxprt->sc_snd_w_inv = pmsg->cp_flags & + RPCRDMA_CMP_F_SND_W_INV_OK; - svc_xprt_get(&xprt->sc_xprt); - ret = ib_post_recv(xprt->sc_qp, &recv_wr, &bad_recv_wr); - if (ret) { - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_context(ctxt, 1); - svc_xprt_put(&xprt->sc_xprt); + dprintk("svcrdma: client send_size %u, recv_size %u " + "remote inv %ssupported\n", + rpcrdma_decode_buffer_size(pmsg->cp_send_size), + rpcrdma_decode_buffer_size(pmsg->cp_recv_size), + newxprt->sc_snd_w_inv ? "" : "un"); } - return ret; - - err_put_ctxt: - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_context(ctxt, 1); - return -ENOMEM; } /* @@ -544,29 +230,38 @@ int svc_rdma_post_recv(struct svcxprt_rdma *xprt) * will call the recvfrom method on the listen xprt which will accept the new * connection. */ -static void handle_connect_req(struct rdma_cm_id *new_cma_id, size_t client_ird) +static void handle_connect_req(struct rdma_cm_id *new_cma_id, + struct rdma_conn_param *param) { struct svcxprt_rdma *listen_xprt = new_cma_id->context; struct svcxprt_rdma *newxprt; struct sockaddr *sa; - /* Create a new transport */ - newxprt = rdma_create_xprt(listen_xprt->sc_xprt.xpt_server, 0); - if (!newxprt) { - dprintk("svcrdma: failed to create new transport\n"); + newxprt = svc_rdma_create_xprt(listen_xprt->sc_xprt.xpt_server, + listen_xprt->sc_xprt.xpt_net, + ibdev_to_node(new_cma_id->device)); + if (!newxprt) return; - } newxprt->sc_cm_id = new_cma_id; new_cma_id->context = newxprt; - dprintk("svcrdma: Creating newxprt=%p, cm_id=%p, listenxprt=%p\n", - newxprt, newxprt->sc_cm_id, listen_xprt); + svc_rdma_parse_connect_private(newxprt, param); /* Save client advertised inbound read limit for use later in accept. */ - newxprt->sc_ord = client_ird; + newxprt->sc_ord = param->initiator_depth; - /* Set the local and remote addresses in the transport */ sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; - svc_xprt_set_remote(&newxprt->sc_xprt, sa, svc_addr_len(sa)); + newxprt->sc_xprt.xpt_remotelen = svc_addr_len(sa); + memcpy(&newxprt->sc_xprt.xpt_remote, sa, + newxprt->sc_xprt.xpt_remotelen); + snprintf(newxprt->sc_xprt.xpt_remotebuf, + sizeof(newxprt->sc_xprt.xpt_remotebuf) - 1, "%pISc", sa); + + /* The remote port is arbitrary and not under the control of the + * client ULP. Set it to a fixed value so that the DRC continues + * to be effective after a reconnect. + */ + rpc_set_port((struct sockaddr *)&newxprt->sc_xprt.xpt_remote, 0); + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; svc_xprt_set_local(&newxprt->sc_xprt, sa, svc_addr_len(sa)); @@ -574,89 +269,79 @@ static void handle_connect_req(struct rdma_cm_id *new_cma_id, size_t client_ird) * Enqueue the new transport on the accept queue of the listening * transport */ - spin_lock_bh(&listen_xprt->sc_lock); + spin_lock(&listen_xprt->sc_lock); list_add_tail(&newxprt->sc_accept_q, &listen_xprt->sc_accept_q); - spin_unlock_bh(&listen_xprt->sc_lock); + spin_unlock(&listen_xprt->sc_lock); set_bit(XPT_CONN, &listen_xprt->sc_xprt.xpt_flags); svc_xprt_enqueue(&listen_xprt->sc_xprt); } -/* - * Handles events generated on the listening endpoint. These events will be - * either be incoming connect requests or adapter removal events. +/** + * svc_rdma_listen_handler - Handle CM events generated on a listening endpoint + * @cma_id: the server's listener rdma_cm_id + * @event: details of the event + * + * Return values: + * %0: Do not destroy @cma_id + * %1: Destroy @cma_id + * + * NB: There is never a DEVICE_REMOVAL event for INADDR_ANY listeners. */ -static int rdma_listen_handler(struct rdma_cm_id *cma_id, - struct rdma_cm_event *event) +static int svc_rdma_listen_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) { - struct svcxprt_rdma *xprt = cma_id->context; - int ret = 0; + struct sockaddr *sap = (struct sockaddr *)&cma_id->route.addr.src_addr; + struct svcxprt_rdma *cma_xprt = cma_id->context; + struct svc_xprt *cma_rdma = &cma_xprt->sc_xprt; + struct rdma_cm_id *listen_id; switch (event->event) { case RDMA_CM_EVENT_CONNECT_REQUEST: - dprintk("svcrdma: Connect request on cma_id=%p, xprt = %p, " - "event=%d\n", cma_id, cma_id->context, event->event); - handle_connect_req(cma_id, - event->param.conn.initiator_depth); + handle_connect_req(cma_id, &event->param.conn); break; - - case RDMA_CM_EVENT_ESTABLISHED: - /* Accept complete */ - dprintk("svcrdma: Connection completed on LISTEN xprt=%p, " - "cm_id=%p\n", xprt, cma_id); - break; - - case RDMA_CM_EVENT_DEVICE_REMOVAL: - dprintk("svcrdma: Device removal xprt=%p, cm_id=%p\n", - xprt, cma_id); - if (xprt) - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); - break; - + case RDMA_CM_EVENT_ADDR_CHANGE: + listen_id = svc_rdma_create_listen_id(cma_rdma->xpt_net, + sap, cma_xprt); + if (IS_ERR(listen_id)) { + pr_err("Listener dead, address change failed for device %s\n", + cma_id->device->name); + } else + cma_xprt->sc_cm_id = listen_id; + return 1; default: - dprintk("svcrdma: Unexpected event on listening endpoint %p, " - "event=%d\n", cma_id, event->event); break; } - - return ret; + return 0; } -static int rdma_cma_handler(struct rdma_cm_id *cma_id, - struct rdma_cm_event *event) +/** + * svc_rdma_cma_handler - Handle CM events on client connections + * @cma_id: the server's listener rdma_cm_id + * @event: details of the event + * + * Return values: + * %0: Do not destroy @cma_id + * %1: Destroy @cma_id (never returned here) + */ +static int svc_rdma_cma_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) { - struct svc_xprt *xprt = cma_id->context; - struct svcxprt_rdma *rdma = - container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct svcxprt_rdma *rdma = cma_id->context; + struct svc_xprt *xprt = &rdma->sc_xprt; + switch (event->event) { case RDMA_CM_EVENT_ESTABLISHED: - /* Accept complete */ - svc_xprt_get(xprt); - dprintk("svcrdma: Connection completed on DTO xprt=%p, " - "cm_id=%p\n", xprt, cma_id); clear_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags); + + /* Handle any requests that were received while + * CONN_PENDING was set. */ svc_xprt_enqueue(xprt); break; case RDMA_CM_EVENT_DISCONNECTED: - dprintk("svcrdma: Disconnect on DTO xprt=%p, cm_id=%p\n", - xprt, cma_id); - if (xprt) { - set_bit(XPT_CLOSE, &xprt->xpt_flags); - svc_xprt_enqueue(xprt); - svc_xprt_put(xprt); - } - break; - case RDMA_CM_EVENT_DEVICE_REMOVAL: - dprintk("svcrdma: Device removal cma_id=%p, xprt = %p, " - "event=%d\n", cma_id, xprt, event->event); - if (xprt) { - set_bit(XPT_CLOSE, &xprt->xpt_flags); - svc_xprt_enqueue(xprt); - } + svc_xprt_deferred_close(xprt); break; default: - dprintk("svcrdma: Unexpected event on DTO endpoint %p, " - "event=%d\n", cma_id, event->event); break; } return 0; @@ -672,40 +357,22 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, { struct rdma_cm_id *listen_id; struct svcxprt_rdma *cma_xprt; - struct svc_xprt *xprt; - int ret; - dprintk("svcrdma: Creating RDMA socket\n"); - if (sa->sa_family != AF_INET) { - dprintk("svcrdma: Address family %d is not supported.\n", sa->sa_family); + if (sa->sa_family != AF_INET && sa->sa_family != AF_INET6) return ERR_PTR(-EAFNOSUPPORT); - } - cma_xprt = rdma_create_xprt(serv, 1); + cma_xprt = svc_rdma_create_xprt(serv, net, NUMA_NO_NODE); if (!cma_xprt) return ERR_PTR(-ENOMEM); - xprt = &cma_xprt->sc_xprt; + set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags); + strcpy(cma_xprt->sc_xprt.xpt_remotebuf, "listener"); - listen_id = rdma_create_id(rdma_listen_handler, cma_xprt, RDMA_PS_TCP, - IB_QPT_RC); + listen_id = svc_rdma_create_listen_id(net, sa, cma_xprt); if (IS_ERR(listen_id)) { - ret = PTR_ERR(listen_id); - dprintk("svcrdma: rdma_create_id failed = %d\n", ret); - goto err0; - } - - ret = rdma_bind_addr(listen_id, sa); - if (ret) { - dprintk("svcrdma: rdma_bind_addr failed = %d\n", ret); - goto err1; + kfree(cma_xprt); + return ERR_CAST(listen_id); } cma_xprt->sc_cm_id = listen_id; - ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); - if (ret) { - dprintk("svcrdma: rdma_listen failed = %d\n", ret); - goto err1; - } - /* * We need to use the address from the cm_id in case the * caller specified 0 for the port number. @@ -714,103 +381,16 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, svc_xprt_set_local(&cma_xprt->sc_xprt, sa, salen); return &cma_xprt->sc_xprt; - - err1: - rdma_destroy_id(listen_id); - err0: - kfree(cma_xprt); - return ERR_PTR(ret); } -static struct svc_rdma_fastreg_mr *rdma_alloc_frmr(struct svcxprt_rdma *xprt) +static void svc_rdma_xprt_done(struct rpcrdma_notification *rn) { - struct ib_mr *mr; - struct ib_fast_reg_page_list *pl; - struct svc_rdma_fastreg_mr *frmr; - - frmr = kmalloc(sizeof(*frmr), GFP_KERNEL); - if (!frmr) - goto err; - - mr = ib_alloc_fast_reg_mr(xprt->sc_pd, RPCSVC_MAXPAGES); - if (IS_ERR(mr)) - goto err_free_frmr; - - pl = ib_alloc_fast_reg_page_list(xprt->sc_cm_id->device, - RPCSVC_MAXPAGES); - if (IS_ERR(pl)) - goto err_free_mr; - - frmr->mr = mr; - frmr->page_list = pl; - INIT_LIST_HEAD(&frmr->frmr_list); - return frmr; - - err_free_mr: - ib_dereg_mr(mr); - err_free_frmr: - kfree(frmr); - err: - return ERR_PTR(-ENOMEM); -} + struct svcxprt_rdma *rdma = container_of(rn, struct svcxprt_rdma, + sc_rn); + struct rdma_cm_id *id = rdma->sc_cm_id; -static void rdma_dealloc_frmr_q(struct svcxprt_rdma *xprt) -{ - struct svc_rdma_fastreg_mr *frmr; - - while (!list_empty(&xprt->sc_frmr_q)) { - frmr = list_entry(xprt->sc_frmr_q.next, - struct svc_rdma_fastreg_mr, frmr_list); - list_del_init(&frmr->frmr_list); - ib_dereg_mr(frmr->mr); - ib_free_fast_reg_page_list(frmr->page_list); - kfree(frmr); - } -} - -struct svc_rdma_fastreg_mr *svc_rdma_get_frmr(struct svcxprt_rdma *rdma) -{ - struct svc_rdma_fastreg_mr *frmr = NULL; - - spin_lock_bh(&rdma->sc_frmr_q_lock); - if (!list_empty(&rdma->sc_frmr_q)) { - frmr = list_entry(rdma->sc_frmr_q.next, - struct svc_rdma_fastreg_mr, frmr_list); - list_del_init(&frmr->frmr_list); - frmr->map_len = 0; - frmr->page_list_len = 0; - } - spin_unlock_bh(&rdma->sc_frmr_q_lock); - if (frmr) - return frmr; - - return rdma_alloc_frmr(rdma); -} - -static void frmr_unmap_dma(struct svcxprt_rdma *xprt, - struct svc_rdma_fastreg_mr *frmr) -{ - int page_no; - for (page_no = 0; page_no < frmr->page_list_len; page_no++) { - dma_addr_t addr = frmr->page_list->page_list[page_no]; - if (ib_dma_mapping_error(frmr->mr->device, addr)) - continue; - atomic_dec(&xprt->sc_dma_used); - ib_dma_unmap_page(frmr->mr->device, addr, PAGE_SIZE, - frmr->direction); - } -} - -void svc_rdma_put_frmr(struct svcxprt_rdma *rdma, - struct svc_rdma_fastreg_mr *frmr) -{ - if (frmr) { - frmr_unmap_dma(rdma, frmr); - spin_lock_bh(&rdma->sc_frmr_q_lock); - BUG_ON(!list_empty(&frmr->frmr_list)); - list_add(&frmr->frmr_list, &rdma->sc_frmr_q); - spin_unlock_bh(&rdma->sc_frmr_q_lock); - } + trace_svcrdma_device_removal(id); + svc_xprt_close(&rdma->sc_xprt); } /* @@ -826,20 +406,20 @@ void svc_rdma_put_frmr(struct svcxprt_rdma *rdma, */ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) { + unsigned int ctxts, rq_depth, maxpayload; struct svcxprt_rdma *listen_rdma; struct svcxprt_rdma *newxprt = NULL; struct rdma_conn_param conn_param; + struct rpcrdma_connect_private pmsg; struct ib_qp_init_attr qp_attr; - struct ib_device_attr devattr; - int uninitialized_var(dma_mr_acc); - int need_dma_mr; - int ret; - int i; + struct ib_device *dev; + int ret = 0; + RPC_IFDEBUG(struct sockaddr *sap); listen_rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); clear_bit(XPT_CONN, &xprt->xpt_flags); /* Get the next entry off the accept list */ - spin_lock_bh(&listen_rdma->sc_lock); + spin_lock(&listen_rdma->sc_lock); if (!list_empty(&listen_rdma->sc_accept_q)) { newxprt = list_entry(listen_rdma->sc_accept_q.next, struct svcxprt_rdma, sc_accept_q); @@ -847,326 +427,202 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) } if (!list_empty(&listen_rdma->sc_accept_q)) set_bit(XPT_CONN, &listen_rdma->sc_xprt.xpt_flags); - spin_unlock_bh(&listen_rdma->sc_lock); + spin_unlock(&listen_rdma->sc_lock); if (!newxprt) return NULL; - dprintk("svcrdma: newxprt from accept queue = %p, cm_id=%p\n", - newxprt, newxprt->sc_cm_id); + dev = newxprt->sc_cm_id->device; + newxprt->sc_port_num = newxprt->sc_cm_id->port_num; - ret = ib_query_device(newxprt->sc_cm_id->device, &devattr); - if (ret) { - dprintk("svcrdma: could not query device attributes on " - "device %p, rc=%d\n", newxprt->sc_cm_id->device, ret); + if (rpcrdma_rn_register(dev, &newxprt->sc_rn, svc_rdma_xprt_done)) goto errout; - } - /* Qualify the transport resource defaults with the - * capabilities of this particular device */ - newxprt->sc_max_sge = min((size_t)devattr.max_sge, - (size_t)RPCSVC_MAXPAGES); - newxprt->sc_max_requests = min((size_t)devattr.max_qp_wr, - (size_t)svcrdma_max_requests); - newxprt->sc_sq_depth = RPCRDMA_SQ_DEPTH_MULT * newxprt->sc_max_requests; + newxprt->sc_max_req_size = svcrdma_max_req_size; + newxprt->sc_max_requests = svcrdma_max_requests; + newxprt->sc_max_bc_requests = svcrdma_max_bc_requests; + newxprt->sc_recv_batch = RPCRDMA_MAX_RECV_BATCH; + newxprt->sc_fc_credits = cpu_to_be32(newxprt->sc_max_requests); - /* - * Limit ORD based on client limit, local device limit, and - * configured svcrdma limit. + /* Qualify the transport's resource defaults with the + * capabilities of this particular device. */ - newxprt->sc_ord = min_t(size_t, devattr.max_qp_rd_atom, newxprt->sc_ord); - newxprt->sc_ord = min_t(size_t, svcrdma_ord, newxprt->sc_ord); - newxprt->sc_pd = ib_alloc_pd(newxprt->sc_cm_id->device); + /* Transport header, head iovec, tail iovec */ + newxprt->sc_max_send_sges = 3; + /* Add one SGE per page list entry */ + newxprt->sc_max_send_sges += (svcrdma_max_req_size / PAGE_SIZE) + 1; + if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge) + newxprt->sc_max_send_sges = dev->attrs.max_send_sge; + rq_depth = newxprt->sc_max_requests + newxprt->sc_max_bc_requests + + newxprt->sc_recv_batch + 1 /* drain */; + if (rq_depth > dev->attrs.max_qp_wr) { + rq_depth = dev->attrs.max_qp_wr; + newxprt->sc_recv_batch = 1; + newxprt->sc_max_requests = rq_depth - 2; + newxprt->sc_max_bc_requests = 2; + } + + /* Arbitrary estimate of the needed number of rdma_rw contexts. + */ + maxpayload = min(xprt->xpt_server->sv_max_payload, + RPCSVC_MAXPAYLOAD_RDMA); + ctxts = newxprt->sc_max_requests * 3 * + rdma_rw_mr_factor(dev, newxprt->sc_port_num, + maxpayload >> PAGE_SHIFT); + + newxprt->sc_sq_depth = rq_depth + ctxts; + if (newxprt->sc_sq_depth > dev->attrs.max_qp_wr) + newxprt->sc_sq_depth = dev->attrs.max_qp_wr; + atomic_set(&newxprt->sc_sq_avail, newxprt->sc_sq_depth); + + newxprt->sc_pd = ib_alloc_pd(dev, 0); if (IS_ERR(newxprt->sc_pd)) { - dprintk("svcrdma: error creating PD for connect request\n"); + trace_svcrdma_pd_err(newxprt, PTR_ERR(newxprt->sc_pd)); goto errout; } - newxprt->sc_sq_cq = ib_create_cq(newxprt->sc_cm_id->device, - sq_comp_handler, - cq_event_handler, - newxprt, - newxprt->sc_sq_depth, - 0); - if (IS_ERR(newxprt->sc_sq_cq)) { - dprintk("svcrdma: error creating SQ CQ for connect request\n"); + newxprt->sc_sq_cq = ib_alloc_cq_any(dev, newxprt, newxprt->sc_sq_depth, + IB_POLL_WORKQUEUE); + if (IS_ERR(newxprt->sc_sq_cq)) goto errout; - } - newxprt->sc_rq_cq = ib_create_cq(newxprt->sc_cm_id->device, - rq_comp_handler, - cq_event_handler, - newxprt, - newxprt->sc_max_requests, - 0); - if (IS_ERR(newxprt->sc_rq_cq)) { - dprintk("svcrdma: error creating RQ CQ for connect request\n"); + newxprt->sc_rq_cq = + ib_alloc_cq_any(dev, newxprt, rq_depth, IB_POLL_WORKQUEUE); + if (IS_ERR(newxprt->sc_rq_cq)) goto errout; - } memset(&qp_attr, 0, sizeof qp_attr); qp_attr.event_handler = qp_event_handler; qp_attr.qp_context = &newxprt->sc_xprt; - qp_attr.cap.max_send_wr = newxprt->sc_sq_depth; - qp_attr.cap.max_recv_wr = newxprt->sc_max_requests; - qp_attr.cap.max_send_sge = newxprt->sc_max_sge; - qp_attr.cap.max_recv_sge = newxprt->sc_max_sge; + qp_attr.port_num = newxprt->sc_port_num; + qp_attr.cap.max_rdma_ctxs = ctxts; + qp_attr.cap.max_send_wr = newxprt->sc_sq_depth - ctxts; + qp_attr.cap.max_recv_wr = rq_depth; + qp_attr.cap.max_send_sge = newxprt->sc_max_send_sges; + qp_attr.cap.max_recv_sge = 1; qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR; qp_attr.qp_type = IB_QPT_RC; qp_attr.send_cq = newxprt->sc_sq_cq; qp_attr.recv_cq = newxprt->sc_rq_cq; - dprintk("svcrdma: newxprt->sc_cm_id=%p, newxprt->sc_pd=%p\n" - " cm_id->device=%p, sc_pd->device=%p\n" - " cap.max_send_wr = %d\n" - " cap.max_recv_wr = %d\n" - " cap.max_send_sge = %d\n" - " cap.max_recv_sge = %d\n", - newxprt->sc_cm_id, newxprt->sc_pd, - newxprt->sc_cm_id->device, newxprt->sc_pd->device, - qp_attr.cap.max_send_wr, - qp_attr.cap.max_recv_wr, - qp_attr.cap.max_send_sge, - qp_attr.cap.max_recv_sge); - + dprintk(" cap.max_send_wr = %d, cap.max_recv_wr = %d\n", + qp_attr.cap.max_send_wr, qp_attr.cap.max_recv_wr); + dprintk(" cap.max_send_sge = %d, cap.max_recv_sge = %d\n", + qp_attr.cap.max_send_sge, qp_attr.cap.max_recv_sge); + dprintk(" send CQ depth = %u, recv CQ depth = %u\n", + newxprt->sc_sq_depth, rq_depth); ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, &qp_attr); if (ret) { - /* - * XXX: This is a hack. We need a xx_request_qp interface - * that will adjust the qp_attr's with a best-effort - * number - */ - qp_attr.cap.max_send_sge -= 2; - qp_attr.cap.max_recv_sge -= 2; - ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, - &qp_attr); - if (ret) { - dprintk("svcrdma: failed to create QP, ret=%d\n", ret); - goto errout; - } - newxprt->sc_max_sge = qp_attr.cap.max_send_sge; - newxprt->sc_max_sge = qp_attr.cap.max_recv_sge; - newxprt->sc_sq_depth = qp_attr.cap.max_send_wr; - newxprt->sc_max_requests = qp_attr.cap.max_recv_wr; + trace_svcrdma_qp_err(newxprt, ret); + goto errout; } + newxprt->sc_max_send_sges = qp_attr.cap.max_send_sge; newxprt->sc_qp = newxprt->sc_cm_id->qp; - /* - * Use the most secure set of MR resources based on the - * transport type and available memory management features in - * the device. Here's the table implemented below: - * - * Fast Global DMA Remote WR - * Reg LKEY MR Access - * Sup'd Sup'd Needed Needed - * - * IWARP N N Y Y - * N Y Y Y - * Y N Y N - * Y Y N - - * - * IB N N Y N - * N Y N - - * Y N Y N - * Y Y N - - * - * NB: iWARP requires remote write access for the data sink - * of an RDMA_READ. IB does not. - */ - if (devattr.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS) { - newxprt->sc_frmr_pg_list_len = - devattr.max_fast_reg_page_list_len; - newxprt->sc_dev_caps |= SVCRDMA_DEVCAP_FAST_REG; - } - - /* - * Determine if a DMA MR is required and if so, what privs are required - */ - switch (rdma_node_get_transport(newxprt->sc_cm_id->device->node_type)) { - case RDMA_TRANSPORT_IWARP: - newxprt->sc_dev_caps |= SVCRDMA_DEVCAP_READ_W_INV; - if (!(newxprt->sc_dev_caps & SVCRDMA_DEVCAP_FAST_REG)) { - need_dma_mr = 1; - dma_mr_acc = - (IB_ACCESS_LOCAL_WRITE | - IB_ACCESS_REMOTE_WRITE); - } else if (!(devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY)) { - need_dma_mr = 1; - dma_mr_acc = IB_ACCESS_LOCAL_WRITE; - } else - need_dma_mr = 0; - break; - case RDMA_TRANSPORT_IB: - if (!(devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY)) { - need_dma_mr = 1; - dma_mr_acc = IB_ACCESS_LOCAL_WRITE; - } else - need_dma_mr = 0; - break; - default: + if (!(dev->attrs.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS)) + newxprt->sc_snd_w_inv = false; + if (!rdma_protocol_iwarp(dev, newxprt->sc_port_num) && + !rdma_ib_or_roce(dev, newxprt->sc_port_num)) { + trace_svcrdma_fabric_err(newxprt, -EINVAL); goto errout; } - /* Create the DMA MR if needed, otherwise, use the DMA LKEY */ - if (need_dma_mr) { - /* Register all of physical memory */ - newxprt->sc_phys_mr = - ib_get_dma_mr(newxprt->sc_pd, dma_mr_acc); - if (IS_ERR(newxprt->sc_phys_mr)) { - dprintk("svcrdma: Failed to create DMA MR ret=%d\n", - ret); - goto errout; - } - newxprt->sc_dma_lkey = newxprt->sc_phys_mr->lkey; - } else - newxprt->sc_dma_lkey = - newxprt->sc_cm_id->device->local_dma_lkey; - - /* Post receive buffers */ - for (i = 0; i < newxprt->sc_max_requests; i++) { - ret = svc_rdma_post_recv(newxprt); - if (ret) { - dprintk("svcrdma: failure posting receive buffers\n"); - goto errout; - } - } - - /* Swap out the handler */ - newxprt->sc_cm_id->event_handler = rdma_cma_handler; + if (!svc_rdma_post_recvs(newxprt)) + goto errout; - /* - * Arm the CQs for the SQ and RQ before accepting so we can't - * miss the first message - */ - ib_req_notify_cq(newxprt->sc_sq_cq, IB_CQ_NEXT_COMP); - ib_req_notify_cq(newxprt->sc_rq_cq, IB_CQ_NEXT_COMP); + /* Construct RDMA-CM private message */ + pmsg.cp_magic = rpcrdma_cmp_magic; + pmsg.cp_version = RPCRDMA_CMP_VERSION; + pmsg.cp_flags = 0; + pmsg.cp_send_size = pmsg.cp_recv_size = + rpcrdma_encode_buffer_size(newxprt->sc_max_req_size); /* Accept Connection */ set_bit(RDMAXPRT_CONN_PENDING, &newxprt->sc_flags); memset(&conn_param, 0, sizeof conn_param); conn_param.responder_resources = 0; - conn_param.initiator_depth = newxprt->sc_ord; + conn_param.initiator_depth = min_t(int, newxprt->sc_ord, + dev->attrs.max_qp_init_rd_atom); + if (!conn_param.initiator_depth) { + ret = -EINVAL; + trace_svcrdma_initdepth_err(newxprt, ret); + goto errout; + } + conn_param.private_data = &pmsg; + conn_param.private_data_len = sizeof(pmsg); + rdma_lock_handler(newxprt->sc_cm_id); + newxprt->sc_cm_id->event_handler = svc_rdma_cma_handler; ret = rdma_accept(newxprt->sc_cm_id, &conn_param); + rdma_unlock_handler(newxprt->sc_cm_id); if (ret) { - dprintk("svcrdma: failed to accept new connection, ret=%d\n", - ret); + trace_svcrdma_accept_err(newxprt, ret); goto errout; } - dprintk("svcrdma: new connection %p accepted with the following " - "attributes:\n" - " local_ip : %pI4\n" - " local_port : %d\n" - " remote_ip : %pI4\n" - " remote_port : %d\n" - " max_sge : %d\n" - " sq_depth : %d\n" - " max_requests : %d\n" - " ord : %d\n", - newxprt, - &((struct sockaddr_in *)&newxprt->sc_cm_id-> - route.addr.src_addr)->sin_addr.s_addr, - ntohs(((struct sockaddr_in *)&newxprt->sc_cm_id-> - route.addr.src_addr)->sin_port), - &((struct sockaddr_in *)&newxprt->sc_cm_id-> - route.addr.dst_addr)->sin_addr.s_addr, - ntohs(((struct sockaddr_in *)&newxprt->sc_cm_id-> - route.addr.dst_addr)->sin_port), - newxprt->sc_max_sge, - newxprt->sc_sq_depth, - newxprt->sc_max_requests, - newxprt->sc_ord); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + dprintk("svcrdma: new connection accepted on device %s:\n", dev->name); + sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; + dprintk(" local address : %pIS:%u\n", sap, rpc_get_port(sap)); + sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; + dprintk(" remote address : %pIS:%u\n", sap, rpc_get_port(sap)); + dprintk(" max_sge : %d\n", newxprt->sc_max_send_sges); + dprintk(" sq_depth : %d\n", newxprt->sc_sq_depth); + dprintk(" rdma_rw_ctxs : %d\n", ctxts); + dprintk(" max_requests : %d\n", newxprt->sc_max_requests); + dprintk(" ord : %d\n", conn_param.initiator_depth); +#endif return &newxprt->sc_xprt; errout: - dprintk("svcrdma: failure accepting new connection rc=%d.\n", ret); /* Take a reference in case the DTO handler runs */ svc_xprt_get(&newxprt->sc_xprt); if (newxprt->sc_qp && !IS_ERR(newxprt->sc_qp)) ib_destroy_qp(newxprt->sc_qp); rdma_destroy_id(newxprt->sc_cm_id); + rpcrdma_rn_unregister(dev, &newxprt->sc_rn); /* This call to put will destroy the transport */ svc_xprt_put(&newxprt->sc_xprt); return NULL; } -static void svc_rdma_release_rqst(struct svc_rqst *rqstp) -{ -} - -/* - * When connected, an svc_xprt has at least two references: - * - * - A reference held by the cm_id between the ESTABLISHED and - * DISCONNECTED events. If the remote peer disconnected first, this - * reference could be gone. - * - * - A reference held by the svc_recv code that called this function - * as part of close processing. - * - * At a minimum one references should still be held. - */ static void svc_rdma_detach(struct svc_xprt *xprt) { struct svcxprt_rdma *rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); - dprintk("svc: svc_rdma_detach(%p)\n", xprt); - /* Disconnect and flush posted WQE */ rdma_disconnect(rdma->sc_cm_id); } -static void __svc_rdma_free(struct work_struct *work) +/** + * svc_rdma_free - Release class-specific transport resources + * @xprt: Generic svc transport object + */ +static void svc_rdma_free(struct svc_xprt *xprt) { struct svcxprt_rdma *rdma = - container_of(work, struct svcxprt_rdma, sc_work); - dprintk("svcrdma: svc_rdma_free(%p)\n", rdma); - - /* We should only be called from kref_put */ - BUG_ON(atomic_read(&rdma->sc_xprt.xpt_ref.refcount) != 0); + container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct ib_device *device = rdma->sc_cm_id->device; - /* - * Destroy queued, but not processed read completions. Note - * that this cleanup has to be done before destroying the - * cm_id because the device ptr is needed to unmap the dma in - * svc_rdma_put_context. - */ - while (!list_empty(&rdma->sc_read_complete_q)) { - struct svc_rdma_op_ctxt *ctxt; - ctxt = list_entry(rdma->sc_read_complete_q.next, - struct svc_rdma_op_ctxt, - dto_q); - list_del_init(&ctxt->dto_q); - svc_rdma_put_context(ctxt, 1); - } + might_sleep(); - /* Destroy queued, but not processed recv completions */ - while (!list_empty(&rdma->sc_rq_dto_q)) { - struct svc_rdma_op_ctxt *ctxt; - ctxt = list_entry(rdma->sc_rq_dto_q.next, - struct svc_rdma_op_ctxt, - dto_q); - list_del_init(&ctxt->dto_q); - svc_rdma_put_context(ctxt, 1); - } + /* This blocks until the Completion Queues are empty */ + if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) + ib_drain_qp(rdma->sc_qp); + flush_workqueue(svcrdma_wq); - /* Warn if we leaked a resource or under-referenced */ - WARN_ON(atomic_read(&rdma->sc_ctxt_used) != 0); - WARN_ON(atomic_read(&rdma->sc_dma_used) != 0); + svc_rdma_flush_recv_queues(rdma); - /* De-allocate fastreg mr */ - rdma_dealloc_frmr_q(rdma); + svc_rdma_destroy_rw_ctxts(rdma); + svc_rdma_send_ctxts_destroy(rdma); + svc_rdma_recv_ctxts_destroy(rdma); /* Destroy the QP if present (not a listener) */ if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) ib_destroy_qp(rdma->sc_qp); if (rdma->sc_sq_cq && !IS_ERR(rdma->sc_sq_cq)) - ib_destroy_cq(rdma->sc_sq_cq); + ib_free_cq(rdma->sc_sq_cq); if (rdma->sc_rq_cq && !IS_ERR(rdma->sc_rq_cq)) - ib_destroy_cq(rdma->sc_rq_cq); - - if (rdma->sc_phys_mr && !IS_ERR(rdma->sc_phys_mr)) - ib_dereg_mr(rdma->sc_phys_mr); + ib_free_cq(rdma->sc_rq_cq); if (rdma->sc_pd && !IS_ERR(rdma->sc_pd)) ib_dealloc_pd(rdma->sc_pd); @@ -1174,31 +630,18 @@ static void __svc_rdma_free(struct work_struct *work) /* Destroy the CM ID */ rdma_destroy_id(rdma->sc_cm_id); + if (!test_bit(XPT_LISTENER, &rdma->sc_xprt.xpt_flags)) + rpcrdma_rn_unregister(device, &rdma->sc_rn); kfree(rdma); } -static void svc_rdma_free(struct svc_xprt *xprt) -{ - struct svcxprt_rdma *rdma = - container_of(xprt, struct svcxprt_rdma, sc_xprt); - INIT_WORK(&rdma->sc_work, __svc_rdma_free); - queue_work(svc_rdma_wq, &rdma->sc_work); -} - static int svc_rdma_has_wspace(struct svc_xprt *xprt) { struct svcxprt_rdma *rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); /* - * If there are fewer SQ WR available than required to send a - * simple response, return false. - */ - if ((rdma->sc_sq_depth - atomic_read(&rdma->sc_sq_count) < 3)) - return 0; - - /* - * ...or there are already waiters on the SQ, + * If there are already waiters on the SQ, * return false. */ if (waitqueue_active(&rdma->sc_send_wait)) @@ -1208,146 +651,6 @@ static int svc_rdma_has_wspace(struct svc_xprt *xprt) return 1; } -/* - * Attempt to register the kvec representing the RPC memory with the - * device. - * - * Returns: - * NULL : The device does not support fastreg or there were no more - * fastreg mr. - * frmr : The kvec register request was successfully posted. - * <0 : An error was encountered attempting to register the kvec. - */ -int svc_rdma_fastreg(struct svcxprt_rdma *xprt, - struct svc_rdma_fastreg_mr *frmr) -{ - struct ib_send_wr fastreg_wr; - u8 key; - - /* Bump the key */ - key = (u8)(frmr->mr->lkey & 0x000000FF); - ib_update_fast_reg_key(frmr->mr, ++key); - - /* Prepare FASTREG WR */ - memset(&fastreg_wr, 0, sizeof fastreg_wr); - fastreg_wr.opcode = IB_WR_FAST_REG_MR; - fastreg_wr.send_flags = IB_SEND_SIGNALED; - fastreg_wr.wr.fast_reg.iova_start = (unsigned long)frmr->kva; - fastreg_wr.wr.fast_reg.page_list = frmr->page_list; - fastreg_wr.wr.fast_reg.page_list_len = frmr->page_list_len; - fastreg_wr.wr.fast_reg.page_shift = PAGE_SHIFT; - fastreg_wr.wr.fast_reg.length = frmr->map_len; - fastreg_wr.wr.fast_reg.access_flags = frmr->access_flags; - fastreg_wr.wr.fast_reg.rkey = frmr->mr->lkey; - return svc_rdma_send(xprt, &fastreg_wr); -} - -int svc_rdma_send(struct svcxprt_rdma *xprt, struct ib_send_wr *wr) +static void svc_rdma_kill_temp_xprt(struct svc_xprt *xprt) { - struct ib_send_wr *bad_wr, *n_wr; - int wr_count; - int i; - int ret; - - if (test_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags)) - return -ENOTCONN; - - BUG_ON(wr->send_flags != IB_SEND_SIGNALED); - wr_count = 1; - for (n_wr = wr->next; n_wr; n_wr = n_wr->next) - wr_count++; - - /* If the SQ is full, wait until an SQ entry is available */ - while (1) { - spin_lock_bh(&xprt->sc_lock); - if (xprt->sc_sq_depth < atomic_read(&xprt->sc_sq_count) + wr_count) { - spin_unlock_bh(&xprt->sc_lock); - atomic_inc(&rdma_stat_sq_starve); - - /* See if we can opportunistically reap SQ WR to make room */ - sq_cq_reap(xprt); - - /* Wait until SQ WR available if SQ still full */ - wait_event(xprt->sc_send_wait, - atomic_read(&xprt->sc_sq_count) < - xprt->sc_sq_depth); - if (test_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags)) - return -ENOTCONN; - continue; - } - /* Take a transport ref for each WR posted */ - for (i = 0; i < wr_count; i++) - svc_xprt_get(&xprt->sc_xprt); - - /* Bump used SQ WR count and post */ - atomic_add(wr_count, &xprt->sc_sq_count); - ret = ib_post_send(xprt->sc_qp, wr, &bad_wr); - if (ret) { - set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); - atomic_sub(wr_count, &xprt->sc_sq_count); - for (i = 0; i < wr_count; i ++) - svc_xprt_put(&xprt->sc_xprt); - dprintk("svcrdma: failed to post SQ WR rc=%d, " - "sc_sq_count=%d, sc_sq_depth=%d\n", - ret, atomic_read(&xprt->sc_sq_count), - xprt->sc_sq_depth); - } - spin_unlock_bh(&xprt->sc_lock); - if (ret) - wake_up(&xprt->sc_send_wait); - break; - } - return ret; -} - -void svc_rdma_send_error(struct svcxprt_rdma *xprt, struct rpcrdma_msg *rmsgp, - enum rpcrdma_errcode err) -{ - struct ib_send_wr err_wr; - struct page *p; - struct svc_rdma_op_ctxt *ctxt; - u32 *va; - int length; - int ret; - - p = svc_rdma_get_page(); - va = page_address(p); - - /* XDR encode error */ - length = svc_rdma_xdr_encode_error(xprt, rmsgp, err, va); - - ctxt = svc_rdma_get_context(xprt); - ctxt->direction = DMA_FROM_DEVICE; - ctxt->count = 1; - ctxt->pages[0] = p; - - /* Prepare SGE for local address */ - ctxt->sge[0].addr = ib_dma_map_page(xprt->sc_cm_id->device, - p, 0, length, DMA_FROM_DEVICE); - if (ib_dma_mapping_error(xprt->sc_cm_id->device, ctxt->sge[0].addr)) { - put_page(p); - svc_rdma_put_context(ctxt, 1); - return; - } - atomic_inc(&xprt->sc_dma_used); - ctxt->sge[0].lkey = xprt->sc_dma_lkey; - ctxt->sge[0].length = length; - - /* Prepare SEND WR */ - memset(&err_wr, 0, sizeof err_wr); - ctxt->wr_op = IB_WR_SEND; - err_wr.wr_id = (unsigned long)ctxt; - err_wr.sg_list = ctxt->sge; - err_wr.num_sge = 1; - err_wr.opcode = IB_WR_SEND; - err_wr.send_flags = IB_SEND_SIGNALED; - - /* Post It */ - ret = svc_rdma_send(xprt, &err_wr); - if (ret) { - dprintk("svcrdma: Error %d posting send for protocol error\n", - ret); - svc_rdma_unmap_dma(ctxt); - svc_rdma_put_context(ctxt, 1); - } } diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c index 285dc0884115..9a8ce5df83ca 100644 --- a/net/sunrpc/xprtrdma/transport.c +++ b/net/sunrpc/xprtrdma/transport.c @@ -1,4 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2014-2017 Oracle. All rights reserved. * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -48,41 +50,37 @@ */ #include <linux/module.h> -#include <linux/init.h> #include <linux/slab.h> #include <linux/seq_file.h> +#include <linux/smp.h> + #include <linux/sunrpc/addr.h> +#include <linux/sunrpc/svc_rdma.h> #include "xprt_rdma.h" - -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_TRANS -#endif - -MODULE_LICENSE("Dual BSD/GPL"); - -MODULE_DESCRIPTION("RPC/RDMA Transport for Linux kernel NFS"); -MODULE_AUTHOR("Network Appliance, Inc."); +#include <trace/events/rpcrdma.h> /* * tunables */ static unsigned int xprt_rdma_slot_table_entries = RPCRDMA_DEF_SLOT_TABLE; -static unsigned int xprt_rdma_max_inline_read = RPCRDMA_DEF_INLINE; -static unsigned int xprt_rdma_max_inline_write = RPCRDMA_DEF_INLINE; -static unsigned int xprt_rdma_inline_write_padding; -static unsigned int xprt_rdma_memreg_strategy = RPCRDMA_FRMR; - int xprt_rdma_pad_optimize = 0; +unsigned int xprt_rdma_max_inline_read = RPCRDMA_DEF_INLINE; +unsigned int xprt_rdma_max_inline_write = RPCRDMA_DEF_INLINE; +unsigned int xprt_rdma_memreg_strategy = RPCRDMA_FRWR; +int xprt_rdma_pad_optimize; +static struct xprt_class xprt_rdma; -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) static unsigned int min_slot_table_size = RPCRDMA_MIN_SLOT_TABLE; static unsigned int max_slot_table_size = RPCRDMA_MAX_SLOT_TABLE; -static unsigned int zero; +static unsigned int min_inline_size = RPCRDMA_MIN_INLINE; +static unsigned int max_inline_size = RPCRDMA_MAX_INLINE; static unsigned int max_padding = PAGE_SIZE; static unsigned int min_memreg = RPCRDMA_BOUNCEBUFFERS; static unsigned int max_memreg = RPCRDMA_LAST - 1; +static unsigned int dummy; static struct ctl_table_header *sunrpc_table_header; @@ -101,22 +99,26 @@ static struct ctl_table xr_tunables_table[] = { .data = &xprt_rdma_max_inline_read, .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = proc_dointvec, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_inline_size, + .extra2 = &max_inline_size, }, { .procname = "rdma_max_inline_write", .data = &xprt_rdma_max_inline_write, .maxlen = sizeof(unsigned int), .mode = 0644, - .proc_handler = proc_dointvec, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_inline_size, + .extra2 = &max_inline_size, }, { .procname = "rdma_inline_write_padding", - .data = &xprt_rdma_inline_write_padding, + .data = &dummy, .maxlen = sizeof(unsigned int), .mode = 0644, .proc_handler = proc_dointvec_minmax, - .extra1 = &zero, + .extra1 = SYSCTL_ZERO, .extra2 = &max_padding, }, { @@ -135,29 +137,52 @@ static struct ctl_table xr_tunables_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, - { }, -}; - -static struct ctl_table sunrpc_table[] = { - { - .procname = "sunrpc", - .mode = 0555, - .child = xr_tunables_table - }, - { }, }; #endif -static struct rpc_xprt_ops xprt_rdma_procs; /* forward reference */ +static const struct rpc_xprt_ops xprt_rdma_procs; static void -xprt_rdma_format_addresses(struct rpc_xprt *xprt) +xprt_rdma_format_addresses4(struct rpc_xprt *xprt, struct sockaddr *sap) { - struct sockaddr *sap = (struct sockaddr *) - &rpcx_to_rdmad(xprt).addr; struct sockaddr_in *sin = (struct sockaddr_in *)sap; - char buf[64]; + char buf[20]; + + snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA; +} + +static void +xprt_rdma_format_addresses6(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + char buf[40]; + + snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA6; +} + +void +xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + char buf[128]; + + switch (sap->sa_family) { + case AF_INET: + xprt_rdma_format_addresses4(xprt, sap); + break; + case AF_INET6: + xprt_rdma_format_addresses6(xprt, sap); + break; + default: + pr_err("rpcrdma: Unrecognized address family\n"); + return; + } (void)rpc_ntop(sap, buf, sizeof(buf)); xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); @@ -165,19 +190,13 @@ xprt_rdma_format_addresses(struct rpc_xprt *xprt) snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap)); xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); - xprt->address_strings[RPC_DISPLAY_PROTO] = "rdma"; - - snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); - xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); - snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap)); xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); - /* netid */ - xprt->address_strings[RPC_DISPLAY_NETID] = "rdma"; + xprt->address_strings[RPC_DISPLAY_PROTO] = "rdma"; } -static void +void xprt_rdma_free_addresses(struct rpc_xprt *xprt) { unsigned int i; @@ -192,67 +211,83 @@ xprt_rdma_free_addresses(struct rpc_xprt *xprt) } } +/** + * xprt_rdma_connect_worker - establish connection in the background + * @work: worker thread context + * + * Requester holds the xprt's send lock to prevent activity on this + * transport while a fresh connection is being established. RPC tasks + * sleep on the xprt's pending queue waiting for connect to complete. + */ static void xprt_rdma_connect_worker(struct work_struct *work) { - struct rpcrdma_xprt *r_xprt = - container_of(work, struct rpcrdma_xprt, rdma_connect.work); - struct rpc_xprt *xprt = &r_xprt->xprt; - int rc = 0; - - current->flags |= PF_FSTRANS; - xprt_clear_connected(xprt); - - dprintk("RPC: %s: %sconnect\n", __func__, - r_xprt->rx_ep.rep_connected != 0 ? "re" : ""); - rc = rpcrdma_ep_connect(&r_xprt->rx_ep, &r_xprt->rx_ia); - if (rc) - xprt_wake_pending_tasks(xprt, rc); + struct rpcrdma_xprt *r_xprt = container_of(work, struct rpcrdma_xprt, + rx_connect_worker.work); + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + unsigned int pflags = current->flags; + int rc; - dprintk("RPC: %s: exit\n", __func__); + if (atomic_read(&xprt->swapper)) + current->flags |= PF_MEMALLOC; + rc = rpcrdma_xprt_connect(r_xprt); xprt_clear_connecting(xprt); - current->flags &= ~PF_FSTRANS; + if (!rc) { + xprt->connect_cookie++; + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_set_connected(xprt); + rc = -EAGAIN; + } else + rpcrdma_xprt_disconnect(r_xprt); + xprt_unlock_connect(xprt, r_xprt); + xprt_wake_pending_tasks(xprt, rc); + current_restore_flags(pflags, PF_MEMALLOC); } -/* - * xprt_rdma_destroy +/** + * xprt_rdma_inject_disconnect - inject a connection fault + * @xprt: transport context * - * Destroy the xprt. - * Free all memory associated with the object, including its own. - * NOTE: none of the *destroy methods free memory for their top-level - * objects, even though they may have allocated it (they do free - * private memory). It's up to the caller to handle it. In this - * case (RDMA transport), all structure memory is inlined with the - * struct rpcrdma_xprt. + * If @xprt is connected, disconnect it to simulate spurious + * connection loss. Caller must hold @xprt's send lock to + * ensure that data structures and hardware resources are + * stable during the rdma_disconnect() call. */ static void -xprt_rdma_destroy(struct rpc_xprt *xprt) +xprt_rdma_inject_disconnect(struct rpc_xprt *xprt) { struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - int rc; - dprintk("RPC: %s: called\n", __func__); + trace_xprtrdma_op_inject_dsc(r_xprt); + rdma_disconnect(r_xprt->rx_ep->re_id); +} - cancel_delayed_work_sync(&r_xprt->rdma_connect); +/** + * xprt_rdma_destroy - Full tear down of transport + * @xprt: doomed transport context + * + * Caller guarantees there will be no more calls to us with + * this @xprt. + */ +static void +xprt_rdma_destroy(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - xprt_clear_connected(xprt); + cancel_delayed_work_sync(&r_xprt->rx_connect_worker); + rpcrdma_xprt_disconnect(r_xprt); rpcrdma_buffer_destroy(&r_xprt->rx_buf); - rc = rpcrdma_ep_destroy(&r_xprt->rx_ep, &r_xprt->rx_ia); - if (rc) - dprintk("RPC: %s: rpcrdma_ep_destroy returned %i\n", - __func__, rc); - rpcrdma_ia_close(&r_xprt->rx_ia); xprt_rdma_free_addresses(xprt); - xprt_free(xprt); - dprintk("RPC: %s: returning\n", __func__); - module_put(THIS_MODULE); } +/* 60 second timeout, no retries */ static const struct rpc_timeout xprt_rdma_default_timeout = { .to_initval = 60 * HZ, .to_maxval = 60 * HZ, @@ -266,406 +301,371 @@ static const struct rpc_timeout xprt_rdma_default_timeout = { static struct rpc_xprt * xprt_setup_rdma(struct xprt_create *args) { - struct rpcrdma_create_data_internal cdata; struct rpc_xprt *xprt; struct rpcrdma_xprt *new_xprt; - struct rpcrdma_ep *new_ep; - struct sockaddr_in *sin; + struct sockaddr *sap; int rc; - if (args->addrlen > sizeof(xprt->addr)) { - dprintk("RPC: %s: address too large\n", __func__); + if (args->addrlen > sizeof(xprt->addr)) return ERR_PTR(-EBADF); - } - xprt = xprt_alloc(args->net, sizeof(struct rpcrdma_xprt), - xprt_rdma_slot_table_entries, - xprt_rdma_slot_table_entries); - if (xprt == NULL) { - dprintk("RPC: %s: couldn't allocate rpcrdma_xprt\n", - __func__); + if (!try_module_get(THIS_MODULE)) + return ERR_PTR(-EIO); + + xprt = xprt_alloc(args->net, sizeof(struct rpcrdma_xprt), 0, + xprt_rdma_slot_table_entries); + if (!xprt) { + module_put(THIS_MODULE); return ERR_PTR(-ENOMEM); } - /* 60 second timeout, no retries */ xprt->timeout = &xprt_rdma_default_timeout; - xprt->bind_timeout = (60U * HZ); - xprt->reestablish_timeout = (5U * HZ); - xprt->idle_timeout = (5U * 60 * HZ); + xprt->connect_timeout = xprt->timeout->to_initval; + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->bind_timeout = RPCRDMA_BIND_TO; + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO; xprt->resvport = 0; /* privileged port not needed */ - xprt->tsh_size = 0; /* RPC-RDMA handles framing */ - xprt->max_payload = RPCRDMA_MAX_DATA_SEGS * PAGE_SIZE; xprt->ops = &xprt_rdma_procs; /* * Set up RDMA-specific connect data. */ - - /* Put server RDMA address in local cdata */ - memcpy(&cdata.addr, args->dstaddr, args->addrlen); + sap = args->dstaddr; /* Ensure xprt->addr holds valid server TCP (not RDMA) * address, for any side protocols which peek at it */ xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xprt_rdma; xprt->addrlen = args->addrlen; - memcpy(&xprt->addr, &cdata.addr, xprt->addrlen); + memcpy(&xprt->addr, sap, xprt->addrlen); - sin = (struct sockaddr_in *)&cdata.addr; - if (ntohs(sin->sin_port) != 0) + if (rpc_get_port(sap)) xprt_set_bound(xprt); + xprt_rdma_format_addresses(xprt, sap); - dprintk("RPC: %s: %pI4:%u\n", - __func__, &sin->sin_addr.s_addr, ntohs(sin->sin_port)); + new_xprt = rpcx_to_rdmax(xprt); + rc = rpcrdma_buffer_create(new_xprt); + if (rc) { + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); + module_put(THIS_MODULE); + return ERR_PTR(rc); + } - /* Set max requests */ - cdata.max_requests = xprt->max_reqs; + INIT_DELAYED_WORK(&new_xprt->rx_connect_worker, + xprt_rdma_connect_worker); - /* Set some length limits */ - cdata.rsize = RPCRDMA_MAX_SEGS * PAGE_SIZE; /* RDMA write max */ - cdata.wsize = RPCRDMA_MAX_SEGS * PAGE_SIZE; /* RDMA read max */ + xprt->max_payload = RPCRDMA_MAX_DATA_SEGS << PAGE_SHIFT; - cdata.inline_wsize = xprt_rdma_max_inline_write; - if (cdata.inline_wsize > cdata.wsize) - cdata.inline_wsize = cdata.wsize; + return xprt; +} - cdata.inline_rsize = xprt_rdma_max_inline_read; - if (cdata.inline_rsize > cdata.rsize) - cdata.inline_rsize = cdata.rsize; +/** + * xprt_rdma_close - close a transport connection + * @xprt: transport context + * + * Called during autoclose or device removal. + * + * Caller holds @xprt's send lock to prevent activity on this + * transport while the connection is torn down. + */ +void xprt_rdma_close(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - cdata.padding = xprt_rdma_inline_write_padding; + rpcrdma_xprt_disconnect(r_xprt); - /* - * Create new transport instance, which includes initialized - * o ia - * o endpoint - * o buffers - */ + xprt->reestablish_timeout = 0; + ++xprt->connect_cookie; + xprt_disconnect_done(xprt); +} - new_xprt = rpcx_to_rdmax(xprt); +/** + * xprt_rdma_set_port - update server port with rpcbind result + * @xprt: controlling RPC transport + * @port: new port value + * + * Transport connect status is unchanged. + */ +static void +xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port) +{ + struct sockaddr *sap = (struct sockaddr *)&xprt->addr; + char buf[8]; - rc = rpcrdma_ia_open(new_xprt, (struct sockaddr *) &cdata.addr, - xprt_rdma_memreg_strategy); - if (rc) - goto out1; + rpc_set_port(sap, port); - /* - * initialize and create ep - */ - new_xprt->rx_data = cdata; - new_ep = &new_xprt->rx_ep; - new_ep->rep_remote_addr = cdata.addr; + kfree(xprt->address_strings[RPC_DISPLAY_PORT]); + snprintf(buf, sizeof(buf), "%u", port); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); - rc = rpcrdma_ep_create(&new_xprt->rx_ep, - &new_xprt->rx_ia, &new_xprt->rx_data); - if (rc) - goto out2; + kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]); + snprintf(buf, sizeof(buf), "%4hx", port); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); +} - /* - * Allocate pre-registered send and receive buffers for headers and - * any inline data. Also specify any padding which will be provided - * from a preregistered zero buffer. - */ - rc = rpcrdma_buffer_create(&new_xprt->rx_buf, new_ep, &new_xprt->rx_ia, - &new_xprt->rx_data); - if (rc) - goto out3; +/** + * xprt_rdma_timer - invoked when an RPC times out + * @xprt: controlling RPC transport + * @task: RPC task that timed out + * + * Invoked when the transport is still connected, but an RPC + * retransmit timeout occurs. + * + * Since RDMA connections don't have a keep-alive, forcibly + * disconnect and retry to connect. This drives full + * detection of the network path, and retransmissions of + * all pending RPCs. + */ +static void +xprt_rdma_timer(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt_force_disconnect(xprt); +} - /* - * Register a callback for connection events. This is necessary because - * connection loss notification is async. We also catch connection loss - * when reaping receives. - */ - INIT_DELAYED_WORK(&new_xprt->rdma_connect, xprt_rdma_connect_worker); - new_ep->rep_func = rpcrdma_conn_func; - new_ep->rep_xprt = xprt; +/** + * xprt_rdma_set_connect_timeout - set timeouts for establishing a connection + * @xprt: controlling transport instance + * @connect_timeout: reconnect timeout after client disconnects + * @reconnect_timeout: reconnect timeout after server disconnects + * + */ +static void xprt_rdma_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - xprt_rdma_format_addresses(xprt); + trace_xprtrdma_op_set_cto(r_xprt, connect_timeout, reconnect_timeout); - if (!try_module_get(THIS_MODULE)) - goto out4; + spin_lock(&xprt->transport_lock); - return xprt; + if (connect_timeout < xprt->connect_timeout) { + struct rpc_timeout to; + unsigned long initval; -out4: - xprt_rdma_free_addresses(xprt); - rc = -EINVAL; -out3: - (void) rpcrdma_ep_destroy(new_ep, &new_xprt->rx_ia); -out2: - rpcrdma_ia_close(&new_xprt->rx_ia); -out1: - xprt_free(xprt); - return ERR_PTR(rc); + to = *xprt->timeout; + initval = connect_timeout; + if (initval < RPCRDMA_INIT_REEST_TO << 1) + initval = RPCRDMA_INIT_REEST_TO << 1; + to.to_initval = initval; + to.to_maxval = initval; + r_xprt->rx_timeout = to; + xprt->timeout = &r_xprt->rx_timeout; + xprt->connect_timeout = connect_timeout; + } + + if (reconnect_timeout < xprt->max_reconnect_timeout) + xprt->max_reconnect_timeout = reconnect_timeout; + + spin_unlock(&xprt->transport_lock); } -/* - * Close a connection, during shutdown or timeout/reconnect +/** + * xprt_rdma_connect - schedule an attempt to reconnect + * @xprt: transport state + * @task: RPC scheduler context (unused) + * */ static void -xprt_rdma_close(struct rpc_xprt *xprt) +xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task) { struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_ep *ep = r_xprt->rx_ep; + unsigned long delay; - dprintk("RPC: %s: closing\n", __func__); - if (r_xprt->rx_ep.rep_connected > 0) - xprt->reestablish_timeout = 0; - xprt_disconnect_done(xprt); - (void) rpcrdma_ep_disconnect(&r_xprt->rx_ep, &r_xprt->rx_ia); + WARN_ON_ONCE(!xprt_lock_connect(xprt, task, r_xprt)); + + delay = 0; + if (ep && ep->re_connect_status != 0) { + delay = xprt_reconnect_delay(xprt); + xprt_reconnect_backoff(xprt, RPCRDMA_INIT_REEST_TO); + } + trace_xprtrdma_op_connect(r_xprt, delay); + queue_delayed_work(system_long_wq, &r_xprt->rx_connect_worker, delay); } +/** + * xprt_rdma_alloc_slot - allocate an rpc_rqst + * @xprt: controlling RPC transport + * @task: RPC task requesting a fresh rpc_rqst + * + * tk_status values: + * %0 if task->tk_rqstp points to a fresh rpc_rqst + * %-EAGAIN if no rpc_rqst is available; queued on backlog + */ static void -xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port) +xprt_rdma_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) { - struct sockaddr_in *sap; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_req *req; - sap = (struct sockaddr_in *)&xprt->addr; - sap->sin_port = htons(port); - sap = (struct sockaddr_in *)&rpcx_to_rdmad(xprt).addr; - sap->sin_port = htons(port); - dprintk("RPC: %s: %u\n", __func__, port); + req = rpcrdma_buffer_get(&r_xprt->rx_buf); + if (!req) + goto out_sleep; + task->tk_rqstp = &req->rl_slot; + task->tk_status = 0; + return; + +out_sleep: + task->tk_status = -ENOMEM; + xprt_add_backlog(xprt, task); } +/** + * xprt_rdma_free_slot - release an rpc_rqst + * @xprt: controlling RPC transport + * @rqst: rpc_rqst to release + * + */ static void -xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task) +xprt_rdma_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *rqst) { - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_xprt *r_xprt = + container_of(xprt, struct rpcrdma_xprt, rx_xprt); - if (r_xprt->rx_ep.rep_connected != 0) { - /* Reconnect */ - schedule_delayed_work(&r_xprt->rdma_connect, - xprt->reestablish_timeout); - xprt->reestablish_timeout <<= 1; - if (xprt->reestablish_timeout > (30 * HZ)) - xprt->reestablish_timeout = (30 * HZ); - else if (xprt->reestablish_timeout < (5 * HZ)) - xprt->reestablish_timeout = (5 * HZ); - } else { - schedule_delayed_work(&r_xprt->rdma_connect, 0); - if (!RPC_IS_ASYNC(task)) - flush_delayed_work(&r_xprt->rdma_connect); + rpcrdma_reply_put(&r_xprt->rx_buf, rpcr_to_rdmar(rqst)); + if (!xprt_wake_up_backlog(xprt, rqst)) { + memset(rqst, 0, sizeof(*rqst)); + rpcrdma_buffer_put(&r_xprt->rx_buf, rpcr_to_rdmar(rqst)); } } -static int -xprt_rdma_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +static bool rpcrdma_check_regbuf(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb, size_t size, + gfp_t flags) { - struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); - int credits = atomic_read(&r_xprt->rx_buf.rb_credits); - - /* == RPC_CWNDSCALE @ init, but *after* setup */ - if (r_xprt->rx_buf.rb_cwndscale == 0UL) { - r_xprt->rx_buf.rb_cwndscale = xprt->cwnd; - dprintk("RPC: %s: cwndscale %lu\n", __func__, - r_xprt->rx_buf.rb_cwndscale); - BUG_ON(r_xprt->rx_buf.rb_cwndscale <= 0); + if (unlikely(rdmab_length(rb) < size)) { + if (!rpcrdma_regbuf_realloc(rb, size, flags)) + return false; + r_xprt->rx_stats.hardway_register_count += size; } - xprt->cwnd = credits * r_xprt->rx_buf.rb_cwndscale; - return xprt_reserve_xprt_cong(xprt, task); + return true; } -/* - * The RDMA allocate/free functions need the task structure as a place - * to hide the struct rpcrdma_req, which is necessary for the actual send/recv - * sequence. For this reason, the recv buffers are attached to send - * buffers for portions of the RPC. Note that the RPC layer allocates - * both send and receive buffers in the same call. We may register - * the receive buffer portion when using reply chunks. +/** + * xprt_rdma_allocate - allocate transport resources for an RPC + * @task: RPC task + * + * Return values: + * 0: Success; rq_buffer points to RPC buffer to use + * ENOMEM: Out of memory, call again later + * EIO: A permanent error occurred, do not retry */ -static void * -xprt_rdma_allocate(struct rpc_task *task, size_t size) +static int +xprt_rdma_allocate(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; - struct rpcrdma_req *req, *nreq; - - req = rpcrdma_buffer_get(&rpcx_to_rdmax(xprt)->rx_buf); - BUG_ON(NULL == req); - - if (size > req->rl_size) { - dprintk("RPC: %s: size %zd too large for buffer[%zd]: " - "prog %d vers %d proc %d\n", - __func__, size, req->rl_size, - task->tk_client->cl_prog, task->tk_client->cl_vers, - task->tk_msg.rpc_proc->p_proc); - /* - * Outgoing length shortage. Our inline write max must have - * been configured to perform direct i/o. - * - * This is therefore a large metadata operation, and the - * allocate call was made on the maximum possible message, - * e.g. containing long filename(s) or symlink data. In - * fact, while these metadata operations *might* carry - * large outgoing payloads, they rarely *do*. However, we - * have to commit to the request here, so reallocate and - * register it now. The data path will never require this - * reallocation. - * - * If the allocation or registration fails, the RPC framework - * will (doggedly) retry. - */ - if (rpcx_to_rdmax(xprt)->rx_ia.ri_memreg_strategy == - RPCRDMA_BOUNCEBUFFERS) { - /* forced to "pure inline" */ - dprintk("RPC: %s: too much data (%zd) for inline " - "(r/w max %d/%d)\n", __func__, size, - rpcx_to_rdmad(xprt).inline_rsize, - rpcx_to_rdmad(xprt).inline_wsize); - size = req->rl_size; - rpc_exit(task, -EIO); /* fail the operation */ - rpcx_to_rdmax(xprt)->rx_stats.failed_marshal_count++; - goto out; - } - if (task->tk_flags & RPC_TASK_SWAPPER) - nreq = kmalloc(sizeof *req + size, GFP_ATOMIC); - else - nreq = kmalloc(sizeof *req + size, GFP_NOFS); - if (nreq == NULL) - goto outfail; - - if (rpcrdma_register_internal(&rpcx_to_rdmax(xprt)->rx_ia, - nreq->rl_base, size + sizeof(struct rpcrdma_req) - - offsetof(struct rpcrdma_req, rl_base), - &nreq->rl_handle, &nreq->rl_iov)) { - kfree(nreq); - goto outfail; - } - rpcx_to_rdmax(xprt)->rx_stats.hardway_register_count += size; - nreq->rl_size = size; - nreq->rl_niovs = 0; - nreq->rl_nchunks = 0; - nreq->rl_buffer = (struct rpcrdma_buffer *)req; - nreq->rl_reply = req->rl_reply; - memcpy(nreq->rl_segments, - req->rl_segments, sizeof nreq->rl_segments); - /* flag the swap with an unused field */ - nreq->rl_iov.length = 0; - req->rl_reply = NULL; - req = nreq; - } - dprintk("RPC: %s: size %zd, request 0x%p\n", __func__, size, req); -out: - req->rl_connect_cookie = 0; /* our reserved value */ - return req->rl_xdr_buf; - -outfail: - rpcrdma_buffer_put(req); - rpcx_to_rdmax(xprt)->rx_stats.failed_marshal_count++; - return NULL; + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + gfp_t flags = rpc_task_gfp_mask(); + + if (!rpcrdma_check_regbuf(r_xprt, req->rl_sendbuf, rqst->rq_callsize, + flags)) + goto out_fail; + if (!rpcrdma_check_regbuf(r_xprt, req->rl_recvbuf, rqst->rq_rcvsize, + flags)) + goto out_fail; + + rqst->rq_buffer = rdmab_data(req->rl_sendbuf); + rqst->rq_rbuffer = rdmab_data(req->rl_recvbuf); + return 0; + +out_fail: + return -ENOMEM; } -/* - * This function returns all RDMA resources to the pool. +/** + * xprt_rdma_free - release resources allocated by xprt_rdma_allocate + * @task: RPC task + * + * Caller guarantees rqst->rq_buffer is non-NULL. */ static void -xprt_rdma_free(void *buffer) +xprt_rdma_free(struct rpc_task *task) { - struct rpcrdma_req *req; - struct rpcrdma_xprt *r_xprt; - struct rpcrdma_rep *rep; - int i; - - if (buffer == NULL) - return; - - req = container_of(buffer, struct rpcrdma_req, rl_xdr_buf[0]); - if (req->rl_iov.length == 0) { /* see allocate above */ - r_xprt = container_of(((struct rpcrdma_req *) req->rl_buffer)->rl_buffer, - struct rpcrdma_xprt, rx_buf); - } else - r_xprt = container_of(req->rl_buffer, struct rpcrdma_xprt, rx_buf); - rep = req->rl_reply; - - dprintk("RPC: %s: called on 0x%p%s\n", - __func__, rep, (rep && rep->rr_func) ? " (with waiter)" : ""); - - /* - * Finish the deregistration. When using mw bind, this was - * begun in rpcrdma_reply_handler(). In all other modes, we - * do it here, in thread context. The process is considered - * complete when the rr_func vector becomes NULL - this - * was put in place during rpcrdma_reply_handler() - the wait - * call below will not block if the dereg is "done". If - * interrupted, our framework will clean up. - */ - for (i = 0; req->rl_nchunks;) { - --req->rl_nchunks; - i += rpcrdma_deregister_external( - &req->rl_segments[i], r_xprt, NULL); - } - - if (rep && wait_event_interruptible(rep->rr_unbind, !rep->rr_func)) { - rep->rr_func = NULL; /* abandon the callback */ - req->rl_reply = NULL; - } + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); - if (req->rl_iov.length == 0) { /* see allocate above */ - struct rpcrdma_req *oreq = (struct rpcrdma_req *)req->rl_buffer; - oreq->rl_reply = req->rl_reply; - (void) rpcrdma_deregister_internal(&r_xprt->rx_ia, - req->rl_handle, - &req->rl_iov); - kfree(req); - req = oreq; + if (unlikely(!list_empty(&req->rl_registered))) { + trace_xprtrdma_mrs_zap(task); + frwr_unmap_sync(rpcx_to_rdmax(rqst->rq_xprt), req); } - /* Put back request+reply buffers */ - rpcrdma_buffer_put(req); + /* XXX: If the RPC is completing because of a signal and + * not because a reply was received, we ought to ensure + * that the Send completion has fired, so that memory + * involved with the Send is not still visible to the NIC. + */ } -/* - * send_request invokes the meat of RPC RDMA. It must do the following: - * 1. Marshal the RPC request into an RPC RDMA request, which means - * putting a header in front of data, and creating IOVs for RDMA - * from those in the request. - * 2. In marshaling, detect opportunities for RDMA, and use them. - * 3. Post a recv message to set up asynch completion, then send - * the request (rpcrdma_ep_post). - * 4. No partial sends are possible in the RPC-RDMA protocol (as in UDP). +/** + * xprt_rdma_send_request - marshal and send an RPC request + * @rqst: RPC message in rq_snd_buf + * + * Caller holds the transport's write lock. + * + * Returns: + * %0 if the RPC message has been sent + * %-ENOTCONN if the caller should reconnect and call again + * %-EAGAIN if the caller should call again + * %-ENOBUFS if the caller should call again after a delay + * %-EMSGSIZE if encoding ran out of buffer space. The request + * was not sent. Do not try to send this message again. + * %-EIO if an I/O error occurred. The request was not sent. + * Do not try to send this message again. */ - static int -xprt_rdma_send_request(struct rpc_task *task) +xprt_rdma_send_request(struct rpc_rqst *rqst) { - struct rpc_rqst *rqst = task->tk_rqstp; struct rpc_xprt *xprt = rqst->rq_xprt; struct rpcrdma_req *req = rpcr_to_rdmar(rqst); struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int rc = 0; - /* marshal the send itself */ - if (req->rl_niovs == 0 && rpcrdma_marshal_req(rqst) != 0) { - r_xprt->rx_stats.failed_marshal_count++; - dprintk("RPC: %s: rpcrdma_marshal_req failed\n", - __func__); - return -EIO; - } +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + if (unlikely(!rqst->rq_buffer)) + return xprt_rdma_bc_send_reply(rqst); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ - if (req->rl_reply == NULL) /* e.g. reconnection */ - rpcrdma_recv_buffer_get(req); + if (!xprt_connected(xprt)) + return -ENOTCONN; - if (req->rl_reply) { - req->rl_reply->rr_func = rpcrdma_reply_handler; - /* this need only be done once, but... */ - req->rl_reply->rr_xprt = xprt; - } + if (!xprt_request_get_cong(xprt, rqst)) + return -EBADSLT; + + rc = rpcrdma_marshal_req(r_xprt, rqst); + if (rc < 0) + goto failed_marshal; /* Must suppress retransmit to maintain credits */ - if (req->rl_connect_cookie == xprt->connect_cookie) + if (rqst->rq_connect_cookie == xprt->connect_cookie) goto drop_connection; - req->rl_connect_cookie = xprt->connect_cookie; + rqst->rq_xtime = ktime_get(); - if (rpcrdma_ep_post(&r_xprt->rx_ia, &r_xprt->rx_ep, req)) + if (frwr_send(r_xprt, req)) goto drop_connection; rqst->rq_xmit_bytes_sent += rqst->rq_snd_buf.len; - rqst->rq_bytes_sent = 0; + + /* An RPC with no reply will throw off credit accounting, + * so drop the connection to reset the credit grant. + */ + if (!rpc_reply_expected(rqst->rq_task)) + goto drop_connection; return 0; +failed_marshal: + if (rc != -ENOTCONN) + return rc; drop_connection: - xprt_disconnect_done(xprt); - return -ENOTCONN; /* implies disconnect */ + xprt_rdma_close(xprt); + return -ENOTCONN; } -static void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) { struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); long idle_time = 0; @@ -673,43 +673,62 @@ static void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) if (xprt_connected(xprt)) idle_time = (long)(jiffies - xprt->last_used) / HZ; - seq_printf(seq, - "\txprt:\trdma %u %lu %lu %lu %ld %lu %lu %lu %Lu %Lu " - "%lu %lu %lu %Lu %Lu %Lu %Lu %lu %lu %lu\n", - - 0, /* need a local port? */ - xprt->stat.bind_count, - xprt->stat.connect_count, - xprt->stat.connect_time, - idle_time, - xprt->stat.sends, - xprt->stat.recvs, - xprt->stat.bad_xids, - xprt->stat.req_u, - xprt->stat.bklog_u, - - r_xprt->rx_stats.read_chunk_count, - r_xprt->rx_stats.write_chunk_count, - r_xprt->rx_stats.reply_chunk_count, - r_xprt->rx_stats.total_rdma_request, - r_xprt->rx_stats.total_rdma_reply, - r_xprt->rx_stats.pullup_copy_count, - r_xprt->rx_stats.fixup_copy_count, - r_xprt->rx_stats.hardway_register_count, - r_xprt->rx_stats.failed_marshal_count, - r_xprt->rx_stats.bad_reply_count); + seq_puts(seq, "\txprt:\trdma "); + seq_printf(seq, "%u %lu %lu %lu %ld %lu %lu %lu %llu %llu ", + 0, /* need a local port? */ + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time / HZ, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u); + seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %lu %lu %lu %lu ", + r_xprt->rx_stats.read_chunk_count, + r_xprt->rx_stats.write_chunk_count, + r_xprt->rx_stats.reply_chunk_count, + r_xprt->rx_stats.total_rdma_request, + r_xprt->rx_stats.total_rdma_reply, + r_xprt->rx_stats.pullup_copy_count, + r_xprt->rx_stats.fixup_copy_count, + r_xprt->rx_stats.hardway_register_count, + r_xprt->rx_stats.failed_marshal_count, + r_xprt->rx_stats.bad_reply_count, + r_xprt->rx_stats.nomsg_call_count); + seq_printf(seq, "%lu %lu %lu %lu %lu %lu\n", + r_xprt->rx_stats.mrs_recycled, + r_xprt->rx_stats.mrs_orphaned, + r_xprt->rx_stats.mrs_allocated, + r_xprt->rx_stats.local_inv_needed, + r_xprt->rx_stats.empty_sendctx_q, + r_xprt->rx_stats.reply_waits_for_send); +} + +static int +xprt_rdma_enable_swap(struct rpc_xprt *xprt) +{ + return 0; +} + +static void +xprt_rdma_disable_swap(struct rpc_xprt *xprt) +{ } /* * Plumbing for rpc transport switch and kernel module */ -static struct rpc_xprt_ops xprt_rdma_procs = { - .reserve_xprt = xprt_rdma_reserve_xprt, +static const struct rpc_xprt_ops xprt_rdma_procs = { + .reserve_xprt = xprt_reserve_xprt_cong, .release_xprt = xprt_release_xprt_cong, /* sunrpc/xprt.c */ - .alloc_slot = xprt_alloc_slot, + .alloc_slot = xprt_rdma_alloc_slot, + .free_slot = xprt_rdma_free_slot, .release_request = xprt_release_rqst_cong, /* ditto */ - .set_retrans_timeout = xprt_set_retrans_timeout_def, /* ditto */ + .wait_for_reply_request = xprt_wait_for_reply_request_def, /* ditto */ + .timer = xprt_rdma_timer, .rpcbind = rpcb_getport_async, /* sunrpc/rpcb_clnt.c */ .set_port = xprt_rdma_set_port, .connect = xprt_rdma_connect, @@ -718,7 +737,18 @@ static struct rpc_xprt_ops xprt_rdma_procs = { .send_request = xprt_rdma_send_request, .close = xprt_rdma_close, .destroy = xprt_rdma_destroy, - .print_stats = xprt_rdma_print_stats + .set_connect_timeout = xprt_rdma_set_connect_timeout, + .print_stats = xprt_rdma_print_stats, + .enable_swap = xprt_rdma_enable_swap, + .disable_swap = xprt_rdma_disable_swap, + .inject_disconnect = xprt_rdma_inject_disconnect, +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + .bc_setup = xprt_rdma_bc_setup, + .bc_maxpayload = xprt_rdma_bc_maxpayload, + .bc_num_slots = xprt_rdma_bc_max_slots, + .bc_free_rqst = xprt_rdma_bc_free_rqst, + .bc_destroy = xprt_rdma_bc_destroy, +#endif }; static struct xprt_class xprt_rdma = { @@ -727,50 +757,39 @@ static struct xprt_class xprt_rdma = { .owner = THIS_MODULE, .ident = XPRT_TRANSPORT_RDMA, .setup = xprt_setup_rdma, + .netid = { "rdma", "rdma6", "" }, }; -static void __exit xprt_rdma_cleanup(void) +void xprt_rdma_cleanup(void) { - int rc; - - dprintk(KERN_INFO "RPCRDMA Module Removed, deregister RPC RDMA transport\n"); -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) if (sunrpc_table_header) { unregister_sysctl_table(sunrpc_table_header); sunrpc_table_header = NULL; } #endif - rc = xprt_unregister_transport(&xprt_rdma); - if (rc) - dprintk("RPC: %s: xprt_unregister returned %i\n", - __func__, rc); + + xprt_unregister_transport(&xprt_rdma); + xprt_unregister_transport(&xprt_rdma_bc); } -static int __init xprt_rdma_init(void) +int xprt_rdma_init(void) { int rc; rc = xprt_register_transport(&xprt_rdma); - if (rc) return rc; - dprintk(KERN_INFO "RPCRDMA Module Init, register RPC RDMA transport\n"); - - dprintk(KERN_INFO "Defaults:\n"); - dprintk(KERN_INFO "\tSlots %d\n" - "\tMaxInlineRead %d\n\tMaxInlineWrite %d\n", - xprt_rdma_slot_table_entries, - xprt_rdma_max_inline_read, xprt_rdma_max_inline_write); - dprintk(KERN_INFO "\tPadding %d\n\tMemreg %d\n", - xprt_rdma_inline_write_padding, xprt_rdma_memreg_strategy); + rc = xprt_register_transport(&xprt_rdma_bc); + if (rc) { + xprt_unregister_transport(&xprt_rdma); + return rc; + } -#ifdef RPC_DEBUG +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) if (!sunrpc_table_header) - sunrpc_table_header = register_sysctl_table(sunrpc_table); + sunrpc_table_header = register_sysctl("sunrpc", xr_tunables_table); #endif return 0; } - -module_init(xprt_rdma_init); -module_exit(xprt_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c index 93726560eaa8..63262ef0c2e3 100644 --- a/net/sunrpc/xprtrdma/verbs.c +++ b/net/sunrpc/xprtrdma/verbs.c @@ -1,4 +1,6 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause /* + * Copyright (c) 2014-2017 Oracle. All rights reserved. * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -47,378 +49,282 @@ * o buffer memory */ +#include <linux/bitops.h> #include <linux/interrupt.h> -#include <linux/pci.h> /* for Tavor hack below */ #include <linux/slab.h> +#include <linux/sunrpc/addr.h> +#include <linux/sunrpc/svc_rdma.h> +#include <linux/log2.h> -#include "xprt_rdma.h" +#include <asm/barrier.h> -/* - * Globals/Macros - */ +#include <rdma/ib_cm.h> -#ifdef RPC_DEBUG -# define RPCDBG_FACILITY RPCDBG_TRANS -#endif - -/* - * internal functions - */ - -/* - * handle replies in tasklet context, using a single, global list - * rdma tasklet function -- just turn around and call the func - * for all replies on the list +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_sendctxs_destroy(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_sendctx *sc); +static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_ep_get(struct rpcrdma_ep *ep); +static int rpcrdma_ep_put(struct rpcrdma_ep *ep); +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc_node(size_t size, enum dma_data_direction direction, + int node); +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction); +static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb); +static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb); + +/* Wait for outstanding transport work to finish. ib_drain_qp + * handles the drains in the wrong order for us, so open code + * them here. */ - -static DEFINE_SPINLOCK(rpcrdma_tk_lock_g); -static LIST_HEAD(rpcrdma_tasklets_g); - -static void -rpcrdma_run_tasklet(unsigned long data) +static void rpcrdma_xprt_drain(struct rpcrdma_xprt *r_xprt) { - struct rpcrdma_rep *rep; - void (*func)(struct rpcrdma_rep *); - unsigned long flags; - - data = data; - spin_lock_irqsave(&rpcrdma_tk_lock_g, flags); - while (!list_empty(&rpcrdma_tasklets_g)) { - rep = list_entry(rpcrdma_tasklets_g.next, - struct rpcrdma_rep, rr_list); - list_del(&rep->rr_list); - func = rep->rr_func; - rep->rr_func = NULL; - spin_unlock_irqrestore(&rpcrdma_tk_lock_g, flags); - - if (func) - func(rep); - else - rpcrdma_recv_buffer_put(rep); - - spin_lock_irqsave(&rpcrdma_tk_lock_g, flags); - } - spin_unlock_irqrestore(&rpcrdma_tk_lock_g, flags); -} + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rdma_cm_id *id = ep->re_id; -static DECLARE_TASKLET(rpcrdma_tasklet_g, rpcrdma_run_tasklet, 0UL); + /* Wait for rpcrdma_post_recvs() to leave its critical + * section. + */ + if (atomic_inc_return(&ep->re_receiving) > 1) + wait_for_completion(&ep->re_done); -static inline void -rpcrdma_schedule_tasklet(struct rpcrdma_rep *rep) -{ - unsigned long flags; + /* Flush Receives, then wait for deferred Reply work + * to complete. + */ + ib_drain_rq(id->qp); - spin_lock_irqsave(&rpcrdma_tk_lock_g, flags); - list_add_tail(&rep->rr_list, &rpcrdma_tasklets_g); - spin_unlock_irqrestore(&rpcrdma_tk_lock_g, flags); - tasklet_schedule(&rpcrdma_tasklet_g); -} + /* Deferred Reply processing might have scheduled + * local invalidations. + */ + ib_drain_sq(id->qp); -static void -rpcrdma_qp_async_error_upcall(struct ib_event *event, void *context) -{ - struct rpcrdma_ep *ep = context; - - dprintk("RPC: %s: QP error %X on device %s ep %p\n", - __func__, event->event, event->device->name, context); - if (ep->rep_connected == 1) { - ep->rep_connected = -EIO; - ep->rep_func(ep); - wake_up_all(&ep->rep_connect_wait); - } + rpcrdma_ep_put(ep); } -static void -rpcrdma_cq_async_error_upcall(struct ib_event *event, void *context) +/* Ensure xprt_force_disconnect() is invoked exactly once when a + * connection is closed or lost. (The important thing is it needs + * to be invoked "at least" once). + */ +void rpcrdma_force_disconnect(struct rpcrdma_ep *ep) { - struct rpcrdma_ep *ep = context; - - dprintk("RPC: %s: CQ error %X on device %s ep %p\n", - __func__, event->event, event->device->name, context); - if (ep->rep_connected == 1) { - ep->rep_connected = -EIO; - ep->rep_func(ep); - wake_up_all(&ep->rep_connect_wait); - } + if (atomic_add_unless(&ep->re_force_disconnect, 1, 1)) + xprt_force_disconnect(ep->re_xprt); } -static inline -void rpcrdma_event_process(struct ib_wc *wc) +/** + * rpcrdma_flush_disconnect - Disconnect on flushed completion + * @r_xprt: transport to disconnect + * @wc: work completion entry + * + * Must be called in process context. + */ +void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc) { - struct rpcrdma_mw *frmr; - struct rpcrdma_rep *rep = - (struct rpcrdma_rep *)(unsigned long) wc->wr_id; - - dprintk("RPC: %s: event rep %p status %X opcode %X length %u\n", - __func__, rep, wc->status, wc->opcode, wc->byte_len); - - if (!rep) /* send or bind completion that we don't care about */ - return; - - if (IB_WC_SUCCESS != wc->status) { - dprintk("RPC: %s: WC opcode %d status %X, connection lost\n", - __func__, wc->opcode, wc->status); - rep->rr_len = ~0U; - if (wc->opcode != IB_WC_FAST_REG_MR && wc->opcode != IB_WC_LOCAL_INV) - rpcrdma_schedule_tasklet(rep); - return; - } - - switch (wc->opcode) { - case IB_WC_FAST_REG_MR: - frmr = (struct rpcrdma_mw *)(unsigned long)wc->wr_id; - frmr->r.frmr.state = FRMR_IS_VALID; - break; - case IB_WC_LOCAL_INV: - frmr = (struct rpcrdma_mw *)(unsigned long)wc->wr_id; - frmr->r.frmr.state = FRMR_IS_INVALID; - break; - case IB_WC_RECV: - rep->rr_len = wc->byte_len; - ib_dma_sync_single_for_cpu( - rdmab_to_ia(rep->rr_buffer)->ri_id->device, - rep->rr_iov.addr, rep->rr_len, DMA_FROM_DEVICE); - /* Keep (only) the most recent credits, after check validity */ - if (rep->rr_len >= 16) { - struct rpcrdma_msg *p = - (struct rpcrdma_msg *) rep->rr_base; - unsigned int credits = ntohl(p->rm_credit); - if (credits == 0) { - dprintk("RPC: %s: server" - " dropped credits to 0!\n", __func__); - /* don't deadlock */ - credits = 1; - } else if (credits > rep->rr_buffer->rb_max_requests) { - dprintk("RPC: %s: server" - " over-crediting: %d (%d)\n", - __func__, credits, - rep->rr_buffer->rb_max_requests); - credits = rep->rr_buffer->rb_max_requests; - } - atomic_set(&rep->rr_buffer->rb_credits, credits); - } - /* fall through */ - case IB_WC_BIND_MW: - rpcrdma_schedule_tasklet(rep); - break; - default: - dprintk("RPC: %s: unexpected WC event %X\n", - __func__, wc->opcode); - break; - } + if (wc->status != IB_WC_SUCCESS) + rpcrdma_force_disconnect(r_xprt->rx_ep); } -static inline int -rpcrdma_cq_poll(struct ib_cq *cq) +/** + * rpcrdma_wc_send - Invoked by RDMA provider for each polled Send WC + * @cq: completion queue + * @wc: WCE for a completed Send WR + * + */ +static void rpcrdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) { - struct ib_wc wc; - int rc; - - for (;;) { - rc = ib_poll_cq(cq, 1, &wc); - if (rc < 0) { - dprintk("RPC: %s: ib_poll_cq failed %i\n", - __func__, rc); - return rc; - } - if (rc == 0) - break; - - rpcrdma_event_process(&wc); - } - - return 0; + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_sendctx *sc = + container_of(cqe, struct rpcrdma_sendctx, sc_cqe); + struct rpcrdma_xprt *r_xprt = cq->cq_context; + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_send(wc, &sc->sc_cid); + rpcrdma_sendctx_put_locked(r_xprt, sc); + rpcrdma_flush_disconnect(r_xprt, wc); } -/* - * rpcrdma_cq_event_upcall - * - * This upcall handles recv, send, bind and unbind events. - * It is reentrant but processes single events in order to maintain - * ordering of receives to keep server credits. - * - * It is the responsibility of the scheduled tasklet to return - * recv buffers to the pool. NOTE: this affects synchronization of - * connection shutdown. That is, the structures required for - * the completion of the reply handler must remain intact until - * all memory has been reclaimed. +/** + * rpcrdma_wc_receive - Invoked by RDMA provider for each polled Receive WC + * @cq: completion queue + * @wc: WCE for a completed Receive WR * - * Note that send events are suppressed and do not result in an upcall. */ -static void -rpcrdma_cq_event_upcall(struct ib_cq *cq, void *context) +static void rpcrdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) { - int rc; - - rc = rpcrdma_cq_poll(cq); - if (rc) - return; + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_rep *rep = container_of(cqe, struct rpcrdma_rep, + rr_cqe); + struct rpcrdma_xprt *r_xprt = cq->cq_context; + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_receive(wc, &rep->rr_cid); + --r_xprt->rx_ep->re_receive_count; + if (wc->status != IB_WC_SUCCESS) + goto out_flushed; + + /* status == SUCCESS means all fields in wc are trustworthy */ + rpcrdma_set_xdrlen(&rep->rr_hdrbuf, wc->byte_len); + rep->rr_wc_flags = wc->wc_flags; + rep->rr_inv_rkey = wc->ex.invalidate_rkey; + + ib_dma_sync_single_for_cpu(rdmab_device(rep->rr_rdmabuf), + rdmab_addr(rep->rr_rdmabuf), + wc->byte_len, DMA_FROM_DEVICE); + + rpcrdma_reply_handler(rep); + return; + +out_flushed: + rpcrdma_flush_disconnect(r_xprt, wc); + rpcrdma_rep_put(&r_xprt->rx_buf, rep); +} - rc = ib_req_notify_cq(cq, IB_CQ_NEXT_COMP); - if (rc) { - dprintk("RPC: %s: ib_req_notify_cq failed %i\n", - __func__, rc); - return; +static void rpcrdma_update_cm_private(struct rpcrdma_ep *ep, + struct rdma_conn_param *param) +{ + const struct rpcrdma_connect_private *pmsg = param->private_data; + unsigned int rsize, wsize; + + /* Default settings for RPC-over-RDMA Version One */ + rsize = RPCRDMA_V1_DEF_INLINE_SIZE; + wsize = RPCRDMA_V1_DEF_INLINE_SIZE; + + if (pmsg && + pmsg->cp_magic == rpcrdma_cmp_magic && + pmsg->cp_version == RPCRDMA_CMP_VERSION) { + rsize = rpcrdma_decode_buffer_size(pmsg->cp_send_size); + wsize = rpcrdma_decode_buffer_size(pmsg->cp_recv_size); } - rpcrdma_cq_poll(cq); -} + if (rsize < ep->re_inline_recv) + ep->re_inline_recv = rsize; + if (wsize < ep->re_inline_send) + ep->re_inline_send = wsize; -#ifdef RPC_DEBUG -static const char * const conn[] = { - "address resolved", - "address error", - "route resolved", - "route error", - "connect request", - "connect response", - "connect error", - "unreachable", - "rejected", - "established", - "disconnected", - "device removal" -}; -#endif + rpcrdma_set_max_header_sizes(ep); +} +/** + * rpcrdma_cm_event_handler - Handle RDMA CM events + * @id: rdma_cm_id on which an event has occurred + * @event: details of the event + * + * Called with @id's mutex held. Returns 1 if caller should + * destroy @id, otherwise 0. + */ static int -rpcrdma_conn_upcall(struct rdma_cm_id *id, struct rdma_cm_event *event) +rpcrdma_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) { - struct rpcrdma_xprt *xprt = id->context; - struct rpcrdma_ia *ia = &xprt->rx_ia; - struct rpcrdma_ep *ep = &xprt->rx_ep; -#ifdef RPC_DEBUG - struct sockaddr_in *addr = (struct sockaddr_in *) &ep->rep_remote_addr; -#endif - struct ib_qp_attr attr; - struct ib_qp_init_attr iattr; - int connstate = 0; + struct rpcrdma_ep *ep = id->context; + + might_sleep(); switch (event->event) { case RDMA_CM_EVENT_ADDR_RESOLVED: case RDMA_CM_EVENT_ROUTE_RESOLVED: - ia->ri_async_rc = 0; - complete(&ia->ri_done); - break; + ep->re_async_rc = 0; + complete(&ep->re_done); + return 0; case RDMA_CM_EVENT_ADDR_ERROR: - ia->ri_async_rc = -EHOSTUNREACH; - dprintk("RPC: %s: CM address resolution error, ep 0x%p\n", - __func__, ep); - complete(&ia->ri_done); - break; + ep->re_async_rc = -EPROTO; + complete(&ep->re_done); + return 0; case RDMA_CM_EVENT_ROUTE_ERROR: - ia->ri_async_rc = -ENETUNREACH; - dprintk("RPC: %s: CM route resolution error, ep 0x%p\n", - __func__, ep); - complete(&ia->ri_done); - break; + ep->re_async_rc = -ENETUNREACH; + complete(&ep->re_done); + return 0; + case RDMA_CM_EVENT_ADDR_CHANGE: + ep->re_connect_status = -ENODEV; + goto disconnected; case RDMA_CM_EVENT_ESTABLISHED: - connstate = 1; - ib_query_qp(ia->ri_id->qp, &attr, - IB_QP_MAX_QP_RD_ATOMIC | IB_QP_MAX_DEST_RD_ATOMIC, - &iattr); - dprintk("RPC: %s: %d responder resources" - " (%d initiator)\n", - __func__, attr.max_dest_rd_atomic, attr.max_rd_atomic); - goto connected; + rpcrdma_ep_get(ep); + ep->re_connect_status = 1; + rpcrdma_update_cm_private(ep, &event->param.conn); + trace_xprtrdma_inline_thresh(ep); + wake_up_all(&ep->re_connect_wait); + break; case RDMA_CM_EVENT_CONNECT_ERROR: - connstate = -ENOTCONN; - goto connected; + ep->re_connect_status = -ENOTCONN; + goto wake_connect_worker; case RDMA_CM_EVENT_UNREACHABLE: - connstate = -ENETDOWN; - goto connected; + ep->re_connect_status = -ENETUNREACH; + goto wake_connect_worker; case RDMA_CM_EVENT_REJECTED: - connstate = -ECONNREFUSED; - goto connected; + ep->re_connect_status = -ECONNREFUSED; + if (event->status == IB_CM_REJ_STALE_CONN) + ep->re_connect_status = -ENOTCONN; +wake_connect_worker: + wake_up_all(&ep->re_connect_wait); + return 0; case RDMA_CM_EVENT_DISCONNECTED: - connstate = -ECONNABORTED; - goto connected; - case RDMA_CM_EVENT_DEVICE_REMOVAL: - connstate = -ENODEV; -connected: - dprintk("RPC: %s: %s: %pI4:%u (ep 0x%p event 0x%x)\n", - __func__, - (event->event <= 11) ? conn[event->event] : - "unknown connection error", - &addr->sin_addr.s_addr, - ntohs(addr->sin_port), - ep, event->event); - atomic_set(&rpcx_to_rdmax(ep->rep_xprt)->rx_buf.rb_credits, 1); - dprintk("RPC: %s: %sconnected\n", - __func__, connstate > 0 ? "" : "dis"); - ep->rep_connected = connstate; - ep->rep_func(ep); - wake_up_all(&ep->rep_connect_wait); - break; + ep->re_connect_status = -ECONNABORTED; +disconnected: + rpcrdma_force_disconnect(ep); + return rpcrdma_ep_put(ep); default: - dprintk("RPC: %s: unexpected CM event %d\n", - __func__, event->event); break; } -#ifdef RPC_DEBUG - if (connstate == 1) { - int ird = attr.max_dest_rd_atomic; - int tird = ep->rep_remote_cma.responder_resources; - printk(KERN_INFO "rpcrdma: connection to %pI4:%u " - "on %s, memreg %d slots %d ird %d%s\n", - &addr->sin_addr.s_addr, - ntohs(addr->sin_port), - ia->ri_id->device->name, - ia->ri_memreg_strategy, - xprt->rx_buf.rb_max_requests, - ird, ird < 4 && ird < tird / 2 ? " (low!)" : ""); - } else if (connstate < 0) { - printk(KERN_INFO "rpcrdma: connection to %pI4:%u closed (%d)\n", - &addr->sin_addr.s_addr, - ntohs(addr->sin_port), - connstate); - } -#endif - return 0; } -static struct rdma_cm_id * -rpcrdma_create_id(struct rpcrdma_xprt *xprt, - struct rpcrdma_ia *ia, struct sockaddr *addr) +static void rpcrdma_ep_removal_done(struct rpcrdma_notification *rn) +{ + struct rpcrdma_ep *ep = container_of(rn, struct rpcrdma_ep, re_rn); + + trace_xprtrdma_device_removal(ep->re_id); + xprt_force_disconnect(ep->re_xprt); +} + +static struct rdma_cm_id *rpcrdma_create_id(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_ep *ep) { + unsigned long wtimeout = msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; struct rdma_cm_id *id; int rc; - init_completion(&ia->ri_done); + init_completion(&ep->re_done); - id = rdma_create_id(rpcrdma_conn_upcall, xprt, RDMA_PS_TCP, IB_QPT_RC); - if (IS_ERR(id)) { - rc = PTR_ERR(id); - dprintk("RPC: %s: rdma_create_id() failed %i\n", - __func__, rc); + id = rdma_create_id(xprt->xprt_net, rpcrdma_cm_event_handler, ep, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(id)) return id; - } - ia->ri_async_rc = -ETIMEDOUT; - rc = rdma_resolve_addr(id, NULL, addr, RDMA_RESOLVE_TIMEOUT); - if (rc) { - dprintk("RPC: %s: rdma_resolve_addr() failed %i\n", - __func__, rc); + ep->re_async_rc = -ETIMEDOUT; + rc = rdma_resolve_addr(id, NULL, (struct sockaddr *)&xprt->addr, + RDMA_RESOLVE_TIMEOUT); + if (rc) goto out; - } - wait_for_completion_interruptible_timeout(&ia->ri_done, - msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1); - rc = ia->ri_async_rc; + rc = wait_for_completion_interruptible_timeout(&ep->re_done, wtimeout); + if (rc < 0) + goto out; + + rc = ep->re_async_rc; if (rc) goto out; - ia->ri_async_rc = -ETIMEDOUT; + ep->re_async_rc = -ETIMEDOUT; rc = rdma_resolve_route(id, RDMA_RESOLVE_TIMEOUT); - if (rc) { - dprintk("RPC: %s: rdma_resolve_route() failed %i\n", - __func__, rc); + if (rc) goto out; - } - wait_for_completion_interruptible_timeout(&ia->ri_done, - msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1); - rc = ia->ri_async_rc; + rc = wait_for_completion_interruptible_timeout(&ep->re_done, wtimeout); + if (rc < 0) + goto out; + rc = ep->re_async_rc; + if (rc) + goto out; + + rc = rpcrdma_rn_register(id->device, &ep->re_rn, rpcrdma_ep_removal_done); if (rc) goto out; @@ -429,1532 +335,1076 @@ out: return ERR_PTR(rc); } -/* - * Drain any cq, prior to teardown. - */ -static void -rpcrdma_clean_cq(struct ib_cq *cq) +static void rpcrdma_ep_destroy(struct kref *kref) { - struct ib_wc wc; - int count = 0; + struct rpcrdma_ep *ep = container_of(kref, struct rpcrdma_ep, re_kref); - while (1 == ib_poll_cq(cq, 1, &wc)) - ++count; + if (ep->re_id->qp) { + rdma_destroy_qp(ep->re_id); + ep->re_id->qp = NULL; + } + + if (ep->re_attr.recv_cq) + ib_free_cq(ep->re_attr.recv_cq); + ep->re_attr.recv_cq = NULL; + if (ep->re_attr.send_cq) + ib_free_cq(ep->re_attr.send_cq); + ep->re_attr.send_cq = NULL; + + if (ep->re_pd) + ib_dealloc_pd(ep->re_pd); + ep->re_pd = NULL; + + rpcrdma_rn_unregister(ep->re_id->device, &ep->re_rn); - if (count) - dprintk("RPC: %s: flushed %d events (last 0x%x)\n", - __func__, count, wc.opcode); + kfree(ep); + module_put(THIS_MODULE); } -/* - * Exported functions. - */ +static noinline void rpcrdma_ep_get(struct rpcrdma_ep *ep) +{ + kref_get(&ep->re_kref); +} -/* - * Open and initialize an Interface Adapter. - * o initializes fields of struct rpcrdma_ia, including - * interface and provider attributes and protection zone. +/* Returns: + * %0 if @ep still has a positive kref count, or + * %1 if @ep was destroyed successfully. */ -int -rpcrdma_ia_open(struct rpcrdma_xprt *xprt, struct sockaddr *addr, int memreg) +static noinline int rpcrdma_ep_put(struct rpcrdma_ep *ep) { - int rc, mem_priv; - struct ib_device_attr devattr; - struct rpcrdma_ia *ia = &xprt->rx_ia; + return kref_put(&ep->re_kref, rpcrdma_ep_destroy); +} - ia->ri_id = rpcrdma_create_id(xprt, ia, addr); - if (IS_ERR(ia->ri_id)) { - rc = PTR_ERR(ia->ri_id); - goto out1; - } +static int rpcrdma_ep_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_connect_private *pmsg; + struct ib_device *device; + struct rdma_cm_id *id; + struct rpcrdma_ep *ep; + int rc; - ia->ri_pd = ib_alloc_pd(ia->ri_id->device); - if (IS_ERR(ia->ri_pd)) { - rc = PTR_ERR(ia->ri_pd); - dprintk("RPC: %s: ib_alloc_pd() failed %i\n", - __func__, rc); - goto out2; - } + ep = kzalloc(sizeof(*ep), XPRTRDMA_GFP_FLAGS); + if (!ep) + return -ENOTCONN; + ep->re_xprt = &r_xprt->rx_xprt; + kref_init(&ep->re_kref); - /* - * Query the device to determine if the requested memory - * registration strategy is supported. If it isn't, set the - * strategy to a globally supported model. - */ - rc = ib_query_device(ia->ri_id->device, &devattr); - if (rc) { - dprintk("RPC: %s: ib_query_device failed %d\n", - __func__, rc); - goto out2; + id = rpcrdma_create_id(r_xprt, ep); + if (IS_ERR(id)) { + kfree(ep); + return PTR_ERR(id); } - - if (devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY) { - ia->ri_have_dma_lkey = 1; - ia->ri_dma_lkey = ia->ri_id->device->local_dma_lkey; + __module_get(THIS_MODULE); + device = id->device; + ep->re_id = id; + reinit_completion(&ep->re_done); + + ep->re_max_requests = r_xprt->rx_xprt.max_reqs; + ep->re_inline_send = xprt_rdma_max_inline_write; + ep->re_inline_recv = xprt_rdma_max_inline_read; + rc = frwr_query_device(ep, device); + if (rc) + goto out_destroy; + + r_xprt->rx_buf.rb_max_requests = cpu_to_be32(ep->re_max_requests); + + ep->re_attr.srq = NULL; + ep->re_attr.cap.max_inline_data = 0; + ep->re_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + ep->re_attr.qp_type = IB_QPT_RC; + ep->re_attr.port_num = ~0; + + ep->re_send_batch = ep->re_max_requests >> 3; + ep->re_send_count = ep->re_send_batch; + init_waitqueue_head(&ep->re_connect_wait); + + ep->re_attr.send_cq = ib_alloc_cq_any(device, r_xprt, + ep->re_attr.cap.max_send_wr, + IB_POLL_WORKQUEUE); + if (IS_ERR(ep->re_attr.send_cq)) { + rc = PTR_ERR(ep->re_attr.send_cq); + ep->re_attr.send_cq = NULL; + goto out_destroy; } - switch (memreg) { - case RPCRDMA_MEMWINDOWS: - case RPCRDMA_MEMWINDOWS_ASYNC: - if (!(devattr.device_cap_flags & IB_DEVICE_MEM_WINDOW)) { - dprintk("RPC: %s: MEMWINDOWS registration " - "specified but not supported by adapter, " - "using slower RPCRDMA_REGISTER\n", - __func__); - memreg = RPCRDMA_REGISTER; - } - break; - case RPCRDMA_MTHCAFMR: - if (!ia->ri_id->device->alloc_fmr) { -#if RPCRDMA_PERSISTENT_REGISTRATION - dprintk("RPC: %s: MTHCAFMR registration " - "specified but not supported by adapter, " - "using riskier RPCRDMA_ALLPHYSICAL\n", - __func__); - memreg = RPCRDMA_ALLPHYSICAL; -#else - dprintk("RPC: %s: MTHCAFMR registration " - "specified but not supported by adapter, " - "using slower RPCRDMA_REGISTER\n", - __func__); - memreg = RPCRDMA_REGISTER; -#endif - } - break; - case RPCRDMA_FRMR: - /* Requires both frmr reg and local dma lkey */ - if ((devattr.device_cap_flags & - (IB_DEVICE_MEM_MGT_EXTENSIONS|IB_DEVICE_LOCAL_DMA_LKEY)) != - (IB_DEVICE_MEM_MGT_EXTENSIONS|IB_DEVICE_LOCAL_DMA_LKEY)) { -#if RPCRDMA_PERSISTENT_REGISTRATION - dprintk("RPC: %s: FRMR registration " - "specified but not supported by adapter, " - "using riskier RPCRDMA_ALLPHYSICAL\n", - __func__); - memreg = RPCRDMA_ALLPHYSICAL; -#else - dprintk("RPC: %s: FRMR registration " - "specified but not supported by adapter, " - "using slower RPCRDMA_REGISTER\n", - __func__); - memreg = RPCRDMA_REGISTER; -#endif - } - break; + ep->re_attr.recv_cq = ib_alloc_cq_any(device, r_xprt, + ep->re_attr.cap.max_recv_wr, + IB_POLL_WORKQUEUE); + if (IS_ERR(ep->re_attr.recv_cq)) { + rc = PTR_ERR(ep->re_attr.recv_cq); + ep->re_attr.recv_cq = NULL; + goto out_destroy; } + ep->re_receive_count = 0; + + /* Initialize cma parameters */ + memset(&ep->re_remote_cma, 0, sizeof(ep->re_remote_cma)); + + /* Prepare RDMA-CM private message */ + pmsg = &ep->re_cm_private; + pmsg->cp_magic = rpcrdma_cmp_magic; + pmsg->cp_version = RPCRDMA_CMP_VERSION; + pmsg->cp_flags |= RPCRDMA_CMP_F_SND_W_INV_OK; + pmsg->cp_send_size = rpcrdma_encode_buffer_size(ep->re_inline_send); + pmsg->cp_recv_size = rpcrdma_encode_buffer_size(ep->re_inline_recv); + ep->re_remote_cma.private_data = pmsg; + ep->re_remote_cma.private_data_len = sizeof(*pmsg); + + /* Client offers RDMA Read but does not initiate */ + ep->re_remote_cma.initiator_depth = 0; + ep->re_remote_cma.responder_resources = + min_t(int, U8_MAX, device->attrs.max_qp_rd_atom); - /* - * Optionally obtain an underlying physical identity mapping in - * order to do a memory window-based bind. This base registration - * is protected from remote access - that is enabled only by binding - * for the specific bytes targeted during each RPC operation, and - * revoked after the corresponding completion similar to a storage - * adapter. + /* Limit transport retries so client can detect server + * GID changes quickly. RPC layer handles re-establishing + * transport connection and retransmission. */ - switch (memreg) { - case RPCRDMA_BOUNCEBUFFERS: - case RPCRDMA_REGISTER: - case RPCRDMA_FRMR: - break; -#if RPCRDMA_PERSISTENT_REGISTRATION - case RPCRDMA_ALLPHYSICAL: - mem_priv = IB_ACCESS_LOCAL_WRITE | - IB_ACCESS_REMOTE_WRITE | - IB_ACCESS_REMOTE_READ; - goto register_setup; -#endif - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - mem_priv = IB_ACCESS_LOCAL_WRITE | - IB_ACCESS_MW_BIND; - goto register_setup; - case RPCRDMA_MTHCAFMR: - if (ia->ri_have_dma_lkey) - break; - mem_priv = IB_ACCESS_LOCAL_WRITE; - register_setup: - ia->ri_bind_mem = ib_get_dma_mr(ia->ri_pd, mem_priv); - if (IS_ERR(ia->ri_bind_mem)) { - printk(KERN_ALERT "%s: ib_get_dma_mr for " - "phys register failed with %lX\n\t" - "Will continue with degraded performance\n", - __func__, PTR_ERR(ia->ri_bind_mem)); - memreg = RPCRDMA_REGISTER; - ia->ri_bind_mem = NULL; - } - break; - default: - printk(KERN_ERR "%s: invalid memory registration mode %d\n", - __func__, memreg); - rc = -EINVAL; - goto out2; + ep->re_remote_cma.retry_count = 6; + + /* RPC-over-RDMA handles its own flow control. In addition, + * make all RNR NAKs visible so we know that RPC-over-RDMA + * flow control is working correctly (no NAKs should be seen). + */ + ep->re_remote_cma.flow_control = 0; + ep->re_remote_cma.rnr_retry_count = 0; + + ep->re_pd = ib_alloc_pd(device, 0); + if (IS_ERR(ep->re_pd)) { + rc = PTR_ERR(ep->re_pd); + ep->re_pd = NULL; + goto out_destroy; } - dprintk("RPC: %s: memory registration strategy is %d\n", - __func__, memreg); - /* Else will do memory reg/dereg for each chunk */ - ia->ri_memreg_strategy = memreg; + rc = rdma_create_qp(id, ep->re_pd, &ep->re_attr); + if (rc) + goto out_destroy; + r_xprt->rx_ep = ep; return 0; -out2: - rdma_destroy_id(ia->ri_id); - ia->ri_id = NULL; -out1: + +out_destroy: + rpcrdma_ep_put(ep); + rdma_destroy_id(id); return rc; } -/* - * Clean up/close an IA. - * o if event handles and PD have been initialized, free them. - * o close the IA +/** + * rpcrdma_xprt_connect - Connect an unconnected transport + * @r_xprt: controlling transport instance + * + * Returns 0 on success or a negative errno. */ -void -rpcrdma_ia_close(struct rpcrdma_ia *ia) +int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt) { + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_ep *ep; int rc; - dprintk("RPC: %s: entering\n", __func__); - if (ia->ri_bind_mem != NULL) { - rc = ib_dereg_mr(ia->ri_bind_mem); - dprintk("RPC: %s: ib_dereg_mr returned %i\n", - __func__, rc); - } - if (ia->ri_id != NULL && !IS_ERR(ia->ri_id)) { - if (ia->ri_id->qp) - rdma_destroy_qp(ia->ri_id); - rdma_destroy_id(ia->ri_id); - ia->ri_id = NULL; - } - if (ia->ri_pd != NULL && !IS_ERR(ia->ri_pd)) { - rc = ib_dealloc_pd(ia->ri_pd); - dprintk("RPC: %s: ib_dealloc_pd returned %i\n", - __func__, rc); - } -} - -/* - * Create unconnected endpoint. - */ -int -rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, - struct rpcrdma_create_data_internal *cdata) -{ - struct ib_device_attr devattr; - int rc, err; - - rc = ib_query_device(ia->ri_id->device, &devattr); - if (rc) { - dprintk("RPC: %s: ib_query_device failed %d\n", - __func__, rc); + rc = rpcrdma_ep_create(r_xprt); + if (rc) return rc; - } + ep = r_xprt->rx_ep; - /* check provider's send/recv wr limits */ - if (cdata->max_requests > devattr.max_qp_wr) - cdata->max_requests = devattr.max_qp_wr; - - ep->rep_attr.event_handler = rpcrdma_qp_async_error_upcall; - ep->rep_attr.qp_context = ep; - /* send_cq and recv_cq initialized below */ - ep->rep_attr.srq = NULL; - ep->rep_attr.cap.max_send_wr = cdata->max_requests; - switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - /* Add room for frmr register and invalidate WRs. - * 1. FRMR reg WR for head - * 2. FRMR invalidate WR for head - * 3. FRMR reg WR for pagelist - * 4. FRMR invalidate WR for pagelist - * 5. FRMR reg WR for tail - * 6. FRMR invalidate WR for tail - * 7. The RDMA_SEND WR - */ - ep->rep_attr.cap.max_send_wr *= 7; - if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) { - cdata->max_requests = devattr.max_qp_wr / 7; - if (!cdata->max_requests) - return -EINVAL; - ep->rep_attr.cap.max_send_wr = cdata->max_requests * 7; - } - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - /* Add room for mw_binds+unbinds - overkill! */ - ep->rep_attr.cap.max_send_wr++; - ep->rep_attr.cap.max_send_wr *= (2 * RPCRDMA_MAX_SEGS); - if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) - return -EINVAL; - break; - default: - break; - } - ep->rep_attr.cap.max_recv_wr = cdata->max_requests; - ep->rep_attr.cap.max_send_sge = (cdata->padding ? 4 : 2); - ep->rep_attr.cap.max_recv_sge = 1; - ep->rep_attr.cap.max_inline_data = 0; - ep->rep_attr.sq_sig_type = IB_SIGNAL_REQ_WR; - ep->rep_attr.qp_type = IB_QPT_RC; - ep->rep_attr.port_num = ~0; - - dprintk("RPC: %s: requested max: dtos: send %d recv %d; " - "iovs: send %d recv %d\n", - __func__, - ep->rep_attr.cap.max_send_wr, - ep->rep_attr.cap.max_recv_wr, - ep->rep_attr.cap.max_send_sge, - ep->rep_attr.cap.max_recv_sge); - - /* set trigger for requesting send completion */ - ep->rep_cqinit = ep->rep_attr.cap.max_send_wr/2 /* - 1*/; - switch (ia->ri_memreg_strategy) { - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - ep->rep_cqinit -= RPCRDMA_MAX_SEGS; - break; - default: - break; - } - if (ep->rep_cqinit <= 2) - ep->rep_cqinit = 0; - INIT_CQCOUNT(ep); - ep->rep_ia = ia; - init_waitqueue_head(&ep->rep_connect_wait); - - /* - * Create a single cq for receive dto and mw_bind (only ever - * care about unbind, really). Send completions are suppressed. - * Use single threaded tasklet upcalls to maintain ordering. + xprt_clear_connected(xprt); + rpcrdma_reset_cwnd(r_xprt); + + /* Bump the ep's reference count while there are + * outstanding Receives. */ - ep->rep_cq = ib_create_cq(ia->ri_id->device, rpcrdma_cq_event_upcall, - rpcrdma_cq_async_error_upcall, NULL, - ep->rep_attr.cap.max_recv_wr + - ep->rep_attr.cap.max_send_wr + 1, 0); - if (IS_ERR(ep->rep_cq)) { - rc = PTR_ERR(ep->rep_cq); - dprintk("RPC: %s: ib_create_cq failed: %i\n", - __func__, rc); - goto out1; + rpcrdma_ep_get(ep); + rpcrdma_post_recvs(r_xprt, 1); + + rc = rdma_connect(ep->re_id, &ep->re_remote_cma); + if (rc) + goto out; + + if (xprt->reestablish_timeout < RPCRDMA_INIT_REEST_TO) + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + wait_event_interruptible(ep->re_connect_wait, + ep->re_connect_status != 0); + if (ep->re_connect_status <= 0) { + rc = ep->re_connect_status; + goto out; } - rc = ib_req_notify_cq(ep->rep_cq, IB_CQ_NEXT_COMP); + rc = rpcrdma_sendctxs_create(r_xprt); if (rc) { - dprintk("RPC: %s: ib_req_notify_cq failed: %i\n", - __func__, rc); - goto out2; + rc = -ENOTCONN; + goto out; } - ep->rep_attr.send_cq = ep->rep_cq; - ep->rep_attr.recv_cq = ep->rep_cq; - - /* Initialize cma parameters */ - - /* RPC/RDMA does not use private data */ - ep->rep_remote_cma.private_data = NULL; - ep->rep_remote_cma.private_data_len = 0; - - /* Client offers RDMA Read but does not initiate */ - ep->rep_remote_cma.initiator_depth = 0; - if (ia->ri_memreg_strategy == RPCRDMA_BOUNCEBUFFERS) - ep->rep_remote_cma.responder_resources = 0; - else if (devattr.max_qp_rd_atom > 32) /* arbitrary but <= 255 */ - ep->rep_remote_cma.responder_resources = 32; - else - ep->rep_remote_cma.responder_resources = devattr.max_qp_rd_atom; - - ep->rep_remote_cma.retry_count = 7; - ep->rep_remote_cma.flow_control = 0; - ep->rep_remote_cma.rnr_retry_count = 0; - - return 0; + rc = rpcrdma_reqs_setup(r_xprt); + if (rc) { + rc = -ENOTCONN; + goto out; + } + rpcrdma_mrs_create(r_xprt); + frwr_wp_create(r_xprt); -out2: - err = ib_destroy_cq(ep->rep_cq); - if (err) - dprintk("RPC: %s: ib_destroy_cq returned %i\n", - __func__, err); -out1: +out: + trace_xprtrdma_connect(r_xprt, rc); return rc; } -/* - * rpcrdma_ep_destroy +/** + * rpcrdma_xprt_disconnect - Disconnect underlying transport + * @r_xprt: controlling transport instance * - * Disconnect and destroy endpoint. After this, the only - * valid operations on the ep are to free it (if dynamically - * allocated) or re-create it. + * Caller serializes. Either the transport send lock is held, + * or we're being called to destroy the transport. * - * The caller's error handling must be sure to not leak the endpoint - * if this function fails. + * On return, @r_xprt is completely divested of all hardware + * resources and prepared for the next ->connect operation. */ -int -rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt) { + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rdma_cm_id *id; int rc; - dprintk("RPC: %s: entering, connected is %d\n", - __func__, ep->rep_connected); + if (!ep) + return; - if (ia->ri_id->qp) { - rc = rpcrdma_ep_disconnect(ep, ia); - if (rc) - dprintk("RPC: %s: rpcrdma_ep_disconnect" - " returned %i\n", __func__, rc); - rdma_destroy_qp(ia->ri_id); - ia->ri_id->qp = NULL; - } + id = ep->re_id; + rc = rdma_disconnect(id); + trace_xprtrdma_disconnect(r_xprt, rc); - /* padding - could be done in rpcrdma_buffer_destroy... */ - if (ep->rep_pad_mr) { - rpcrdma_deregister_internal(ia, ep->rep_pad_mr, &ep->rep_pad); - ep->rep_pad_mr = NULL; - } + rpcrdma_xprt_drain(r_xprt); + rpcrdma_reps_unmap(r_xprt); + rpcrdma_reqs_reset(r_xprt); + rpcrdma_mrs_destroy(r_xprt); + rpcrdma_sendctxs_destroy(r_xprt); - rpcrdma_clean_cq(ep->rep_cq); - rc = ib_destroy_cq(ep->rep_cq); - if (rc) - dprintk("RPC: %s: ib_destroy_cq returned %i\n", - __func__, rc); + if (rpcrdma_ep_put(ep)) + rdma_destroy_id(id); - return rc; + r_xprt->rx_ep = NULL; } -/* - * Connect unconnected endpoint. +/* Fixed-size circular FIFO queue. This implementation is wait-free and + * lock-free. + * + * Consumer is the code path that posts Sends. This path dequeues a + * sendctx for use by a Send operation. Multiple consumer threads + * are serialized by the RPC transport lock, which allows only one + * ->send_request call at a time. + * + * Producer is the code path that handles Send completions. This path + * enqueues a sendctx that has been completed. Multiple producer + * threads are serialized by the ib_poll_cq() function. */ -int -rpcrdma_ep_connect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) -{ - struct rdma_cm_id *id; - int rc = 0; - int retry_count = 0; - - if (ep->rep_connected != 0) { - struct rpcrdma_xprt *xprt; -retry: - rc = rpcrdma_ep_disconnect(ep, ia); - if (rc && rc != -ENOTCONN) - dprintk("RPC: %s: rpcrdma_ep_disconnect" - " status %i\n", __func__, rc); - rpcrdma_clean_cq(ep->rep_cq); - - xprt = container_of(ia, struct rpcrdma_xprt, rx_ia); - id = rpcrdma_create_id(xprt, ia, - (struct sockaddr *)&xprt->rx_data.addr); - if (IS_ERR(id)) { - rc = PTR_ERR(id); - goto out; - } - /* TEMP TEMP TEMP - fail if new device: - * Deregister/remarshal *all* requests! - * Close and recreate adapter, pd, etc! - * Re-determine all attributes still sane! - * More stuff I haven't thought of! - * Rrrgh! - */ - if (ia->ri_id->device != id->device) { - printk("RPC: %s: can't reconnect on " - "different device!\n", __func__); - rdma_destroy_id(id); - rc = -ENETDOWN; - goto out; - } - /* END TEMP */ - rdma_destroy_qp(ia->ri_id); - rdma_destroy_id(ia->ri_id); - ia->ri_id = id; - } - rc = rdma_create_qp(ia->ri_id, ia->ri_pd, &ep->rep_attr); - if (rc) { - dprintk("RPC: %s: rdma_create_qp failed %i\n", - __func__, rc); - goto out; - } +/* rpcrdma_sendctxs_destroy() assumes caller has already quiesced + * queue activity, and rpcrdma_xprt_drain has flushed all remaining + * Send requests. + */ +static void rpcrdma_sendctxs_destroy(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + unsigned long i; -/* XXX Tavor device performs badly with 2K MTU! */ -if (strnicmp(ia->ri_id->device->dma_device->bus->name, "pci", 3) == 0) { - struct pci_dev *pcid = to_pci_dev(ia->ri_id->device->dma_device); - if (pcid->device == PCI_DEVICE_ID_MELLANOX_TAVOR && - (pcid->vendor == PCI_VENDOR_ID_MELLANOX || - pcid->vendor == PCI_VENDOR_ID_TOPSPIN)) { - struct ib_qp_attr attr = { - .path_mtu = IB_MTU_1024 - }; - rc = ib_modify_qp(ia->ri_id->qp, &attr, IB_QP_PATH_MTU); - } + if (!buf->rb_sc_ctxs) + return; + for (i = 0; i <= buf->rb_sc_last; i++) + kfree(buf->rb_sc_ctxs[i]); + kfree(buf->rb_sc_ctxs); + buf->rb_sc_ctxs = NULL; } - ep->rep_connected = 0; - - rc = rdma_connect(ia->ri_id, &ep->rep_remote_cma); - if (rc) { - dprintk("RPC: %s: rdma_connect() failed with %i\n", - __func__, rc); - goto out; - } - - wait_event_interruptible(ep->rep_connect_wait, ep->rep_connected != 0); +static struct rpcrdma_sendctx *rpcrdma_sendctx_create(struct rpcrdma_ep *ep) +{ + struct rpcrdma_sendctx *sc; + + sc = kzalloc(struct_size(sc, sc_sges, ep->re_attr.cap.max_send_sge), + XPRTRDMA_GFP_FLAGS); + if (!sc) + return NULL; + + sc->sc_cqe.done = rpcrdma_wc_send; + sc->sc_cid.ci_queue_id = ep->re_attr.send_cq->res.id; + sc->sc_cid.ci_completion_id = + atomic_inc_return(&ep->re_completion_ids); + return sc; +} - /* - * Check state. A non-peer reject indicates no listener - * (ECONNREFUSED), which may be a transient state. All - * others indicate a transport condition which has already - * undergone a best-effort. +static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_sendctx *sc; + unsigned long i; + + /* Maximum number of concurrent outstanding Send WRs. Capping + * the circular queue size stops Send Queue overflow by causing + * the ->send_request call to fail temporarily before too many + * Sends are posted. */ - if (ep->rep_connected == -ECONNREFUSED && - ++retry_count <= RDMA_CONNECT_RETRY_MAX) { - dprintk("RPC: %s: non-peer_reject, retry\n", __func__); - goto retry; - } - if (ep->rep_connected <= 0) { - /* Sometimes, the only way to reliably connect to remote - * CMs is to use same nonzero values for ORD and IRD. */ - if (retry_count++ <= RDMA_CONNECT_RETRY_MAX + 1 && - (ep->rep_remote_cma.responder_resources == 0 || - ep->rep_remote_cma.initiator_depth != - ep->rep_remote_cma.responder_resources)) { - if (ep->rep_remote_cma.responder_resources == 0) - ep->rep_remote_cma.responder_resources = 1; - ep->rep_remote_cma.initiator_depth = - ep->rep_remote_cma.responder_resources; - goto retry; - } - rc = ep->rep_connected; - } else { - dprintk("RPC: %s: connected\n", __func__); + i = r_xprt->rx_ep->re_max_requests + RPCRDMA_MAX_BC_REQUESTS; + buf->rb_sc_ctxs = kcalloc(i, sizeof(sc), XPRTRDMA_GFP_FLAGS); + if (!buf->rb_sc_ctxs) + return -ENOMEM; + + buf->rb_sc_last = i - 1; + for (i = 0; i <= buf->rb_sc_last; i++) { + sc = rpcrdma_sendctx_create(r_xprt->rx_ep); + if (!sc) + return -ENOMEM; + + buf->rb_sc_ctxs[i] = sc; } -out: - if (rc) - ep->rep_connected = rc; - return rc; + buf->rb_sc_head = 0; + buf->rb_sc_tail = 0; + return 0; } -/* - * rpcrdma_ep_disconnect +/* The sendctx queue is not guaranteed to have a size that is a + * power of two, thus the helpers in circ_buf.h cannot be used. + * The other option is to use modulus (%), which can be expensive. + */ +static unsigned long rpcrdma_sendctx_next(struct rpcrdma_buffer *buf, + unsigned long item) +{ + return likely(item < buf->rb_sc_last) ? item + 1 : 0; +} + +/** + * rpcrdma_sendctx_get_locked - Acquire a send context + * @r_xprt: controlling transport instance + * + * Returns pointer to a free send completion context; or NULL if + * the queue is empty. * - * This is separate from destroy to facilitate the ability - * to reconnect without recreating the endpoint. + * Usage: Called to acquire an SGE array before preparing a Send WR. * - * This call is not reentrant, and must not be made in parallel - * on the same endpoint. + * The caller serializes calls to this function (per transport), and + * provides an effective memory barrier that flushes the new value + * of rb_sc_head. */ -int -rpcrdma_ep_disconnect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_xprt *r_xprt) { - int rc; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_sendctx *sc; + unsigned long next_head; - rpcrdma_clean_cq(ep->rep_cq); - rc = rdma_disconnect(ia->ri_id); - if (!rc) { - /* returns without wait if not connected */ - wait_event_interruptible(ep->rep_connect_wait, - ep->rep_connected != 1); - dprintk("RPC: %s: after wait, %sconnected\n", __func__, - (ep->rep_connected == 1) ? "still " : "dis"); - } else { - dprintk("RPC: %s: rdma_disconnect %i\n", __func__, rc); - ep->rep_connected = rc; - } - return rc; + next_head = rpcrdma_sendctx_next(buf, buf->rb_sc_head); + + if (next_head == READ_ONCE(buf->rb_sc_tail)) + goto out_emptyq; + + /* ORDER: item must be accessed _before_ head is updated */ + sc = buf->rb_sc_ctxs[next_head]; + + /* Releasing the lock in the caller acts as a memory + * barrier that flushes rb_sc_head. + */ + buf->rb_sc_head = next_head; + + return sc; + +out_emptyq: + /* The queue is "empty" if there have not been enough Send + * completions recently. This is a sign the Send Queue is + * backing up. Cause the caller to pause and try again. + */ + xprt_wait_for_buffer_space(&r_xprt->rx_xprt); + r_xprt->rx_stats.empty_sendctx_q++; + return NULL; } -/* - * Initialize buffer memory +/** + * rpcrdma_sendctx_put_locked - Release a send context + * @r_xprt: controlling transport instance + * @sc: send context to release + * + * Usage: Called from Send completion to return a sendctxt + * to the queue. + * + * The caller serializes calls to this function (per transport). */ -int -rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, - struct rpcrdma_ia *ia, struct rpcrdma_create_data_internal *cdata) +static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_sendctx *sc) { - char *p; - size_t len; - int i, rc; - struct rpcrdma_mw *r; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + unsigned long next_tail; - buf->rb_max_requests = cdata->max_requests; - spin_lock_init(&buf->rb_lock); - atomic_set(&buf->rb_credits, 1); - - /* Need to allocate: - * 1. arrays for send and recv pointers - * 2. arrays of struct rpcrdma_req to fill in pointers - * 3. array of struct rpcrdma_rep for replies - * 4. padding, if any - * 5. mw's, fmr's or frmr's, if any - * Send/recv buffers in req/rep need to be registered + /* Unmap SGEs of previously completed but unsignaled + * Sends by walking up the queue until @sc is found. */ + next_tail = buf->rb_sc_tail; + do { + next_tail = rpcrdma_sendctx_next(buf, next_tail); - len = buf->rb_max_requests * - (sizeof(struct rpcrdma_req *) + sizeof(struct rpcrdma_rep *)); - len += cdata->padding; - switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - len += buf->rb_max_requests * RPCRDMA_MAX_SEGS * - sizeof(struct rpcrdma_mw); - break; - case RPCRDMA_MTHCAFMR: - /* TBD we are perhaps overallocating here */ - len += (buf->rb_max_requests + 1) * RPCRDMA_MAX_SEGS * - sizeof(struct rpcrdma_mw); - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - len += (buf->rb_max_requests + 1) * RPCRDMA_MAX_SEGS * - sizeof(struct rpcrdma_mw); - break; - default: - break; - } + /* ORDER: item must be accessed _before_ tail is updated */ + rpcrdma_sendctx_unmap(buf->rb_sc_ctxs[next_tail]); - /* allocate 1, 4 and 5 in one shot */ - p = kzalloc(len, GFP_KERNEL); - if (p == NULL) { - dprintk("RPC: %s: req_t/rep_t/pad kzalloc(%zd) failed\n", - __func__, len); - rc = -ENOMEM; - goto out; - } - buf->rb_pool = p; /* for freeing it later */ + } while (buf->rb_sc_ctxs[next_tail] != sc); - buf->rb_send_bufs = (struct rpcrdma_req **) p; - p = (char *) &buf->rb_send_bufs[buf->rb_max_requests]; - buf->rb_recv_bufs = (struct rpcrdma_rep **) p; - p = (char *) &buf->rb_recv_bufs[buf->rb_max_requests]; + /* Paired with READ_ONCE */ + smp_store_release(&buf->rb_sc_tail, next_tail); - /* - * Register the zeroed pad buffer, if any. - */ - if (cdata->padding) { - rc = rpcrdma_register_internal(ia, p, cdata->padding, - &ep->rep_pad_mr, &ep->rep_pad); - if (rc) - goto out; - } - p += cdata->padding; + xprt_write_space(&r_xprt->rx_xprt); +} - /* - * Allocate the fmr's, or mw's for mw_bind chunk registration. - * We "cycle" the mw's in order to minimize rkey reuse, - * and also reduce unbind-to-bind collision. - */ - INIT_LIST_HEAD(&buf->rb_mws); - r = (struct rpcrdma_mw *)p; - switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - for (i = buf->rb_max_requests * RPCRDMA_MAX_SEGS; i; i--) { - r->r.frmr.fr_mr = ib_alloc_fast_reg_mr(ia->ri_pd, - RPCRDMA_MAX_SEGS); - if (IS_ERR(r->r.frmr.fr_mr)) { - rc = PTR_ERR(r->r.frmr.fr_mr); - dprintk("RPC: %s: ib_alloc_fast_reg_mr" - " failed %i\n", __func__, rc); - goto out; - } - r->r.frmr.fr_pgl = - ib_alloc_fast_reg_page_list(ia->ri_id->device, - RPCRDMA_MAX_SEGS); - if (IS_ERR(r->r.frmr.fr_pgl)) { - rc = PTR_ERR(r->r.frmr.fr_pgl); - dprintk("RPC: %s: " - "ib_alloc_fast_reg_page_list " - "failed %i\n", __func__, rc); - goto out; - } - list_add(&r->mw_list, &buf->rb_mws); - ++r; - } - break; - case RPCRDMA_MTHCAFMR: - /* TBD we are perhaps overallocating here */ - for (i = (buf->rb_max_requests+1) * RPCRDMA_MAX_SEGS; i; i--) { - static struct ib_fmr_attr fa = - { RPCRDMA_MAX_DATA_SEGS, 1, PAGE_SHIFT }; - r->r.fmr = ib_alloc_fmr(ia->ri_pd, - IB_ACCESS_REMOTE_WRITE | IB_ACCESS_REMOTE_READ, - &fa); - if (IS_ERR(r->r.fmr)) { - rc = PTR_ERR(r->r.fmr); - dprintk("RPC: %s: ib_alloc_fmr" - " failed %i\n", __func__, rc); - goto out; - } - list_add(&r->mw_list, &buf->rb_mws); - ++r; - } - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - /* Allocate one extra request's worth, for full cycling */ - for (i = (buf->rb_max_requests+1) * RPCRDMA_MAX_SEGS; i; i--) { - r->r.mw = ib_alloc_mw(ia->ri_pd, IB_MW_TYPE_1); - if (IS_ERR(r->r.mw)) { - rc = PTR_ERR(r->r.mw); - dprintk("RPC: %s: ib_alloc_mw" - " failed %i\n", __func__, rc); - goto out; - } - list_add(&r->mw_list, &buf->rb_mws); - ++r; - } - break; - default: - break; - } +static void +rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_device *device = ep->re_id->device; + unsigned int count; + + /* Try to allocate enough to perform one full-sized I/O */ + for (count = 0; count < ep->re_max_rdma_segs; count++) { + struct rpcrdma_mr *mr; + int rc; + + mr = kzalloc_node(sizeof(*mr), XPRTRDMA_GFP_FLAGS, + ibdev_to_node(device)); + if (!mr) + break; - /* - * Allocate/init the request/reply buffers. Doing this - * using kmalloc for now -- one for each buf. - */ - for (i = 0; i < buf->rb_max_requests; i++) { - struct rpcrdma_req *req; - struct rpcrdma_rep *rep; - - len = cdata->inline_wsize + sizeof(struct rpcrdma_req); - /* RPC layer requests *double* size + 1K RPC_SLACK_SPACE! */ - /* Typical ~2400b, so rounding up saves work later */ - if (len < 4096) - len = 4096; - req = kmalloc(len, GFP_KERNEL); - if (req == NULL) { - dprintk("RPC: %s: request buffer %d alloc" - " failed\n", __func__, i); - rc = -ENOMEM; - goto out; + rc = frwr_mr_init(r_xprt, mr); + if (rc) { + kfree(mr); + break; } - memset(req, 0, sizeof(struct rpcrdma_req)); - buf->rb_send_bufs[i] = req; - buf->rb_send_bufs[i]->rl_buffer = buf; - - rc = rpcrdma_register_internal(ia, req->rl_base, - len - offsetof(struct rpcrdma_req, rl_base), - &buf->rb_send_bufs[i]->rl_handle, - &buf->rb_send_bufs[i]->rl_iov); - if (rc) - goto out; - buf->rb_send_bufs[i]->rl_size = len-sizeof(struct rpcrdma_req); + spin_lock(&buf->rb_lock); + rpcrdma_mr_push(mr, &buf->rb_mrs); + list_add(&mr->mr_all, &buf->rb_all_mrs); + spin_unlock(&buf->rb_lock); + } - len = cdata->inline_rsize + sizeof(struct rpcrdma_rep); - rep = kmalloc(len, GFP_KERNEL); - if (rep == NULL) { - dprintk("RPC: %s: reply buffer %d alloc failed\n", - __func__, i); - rc = -ENOMEM; - goto out; - } - memset(rep, 0, sizeof(struct rpcrdma_rep)); - buf->rb_recv_bufs[i] = rep; - buf->rb_recv_bufs[i]->rr_buffer = buf; - init_waitqueue_head(&rep->rr_unbind); - - rc = rpcrdma_register_internal(ia, rep->rr_base, - len - offsetof(struct rpcrdma_rep, rr_base), - &buf->rb_recv_bufs[i]->rr_handle, - &buf->rb_recv_bufs[i]->rr_iov); - if (rc) - goto out; + r_xprt->rx_stats.mrs_allocated += count; + trace_xprtrdma_createmrs(r_xprt, count); +} - } - dprintk("RPC: %s: max_requests %d\n", - __func__, buf->rb_max_requests); - /* done */ - return 0; -out: - rpcrdma_buffer_destroy(buf); - return rc; +static void +rpcrdma_mr_refresh_worker(struct work_struct *work) +{ + struct rpcrdma_buffer *buf = container_of(work, struct rpcrdma_buffer, + rb_refresh_worker); + struct rpcrdma_xprt *r_xprt = container_of(buf, struct rpcrdma_xprt, + rx_buf); + + rpcrdma_mrs_create(r_xprt); + xprt_write_space(&r_xprt->rx_xprt); } -/* - * Unregister and destroy buffer memory. Need to deal with - * partial initialization, so it's callable from failed create. - * Must be called before destroying endpoint, as registrations - * reference it. +/** + * rpcrdma_mrs_refresh - Wake the MR refresh worker + * @r_xprt: controlling transport instance + * */ -void -rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) +void rpcrdma_mrs_refresh(struct rpcrdma_xprt *r_xprt) { - int rc, i; - struct rpcrdma_ia *ia = rdmab_to_ia(buf); - struct rpcrdma_mw *r; - - /* clean up in reverse order from create - * 1. recv mr memory (mr free, then kfree) - * 1a. bind mw memory - * 2. send mr memory (mr free, then kfree) - * 3. padding (if any) [moved to rpcrdma_ep_destroy] - * 4. arrays - */ - dprintk("RPC: %s: entering\n", __func__); - - for (i = 0; i < buf->rb_max_requests; i++) { - if (buf->rb_recv_bufs && buf->rb_recv_bufs[i]) { - rpcrdma_deregister_internal(ia, - buf->rb_recv_bufs[i]->rr_handle, - &buf->rb_recv_bufs[i]->rr_iov); - kfree(buf->rb_recv_bufs[i]); - } - if (buf->rb_send_bufs && buf->rb_send_bufs[i]) { - while (!list_empty(&buf->rb_mws)) { - r = list_entry(buf->rb_mws.next, - struct rpcrdma_mw, mw_list); - list_del(&r->mw_list); - switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - rc = ib_dereg_mr(r->r.frmr.fr_mr); - if (rc) - dprintk("RPC: %s:" - " ib_dereg_mr" - " failed %i\n", - __func__, rc); - ib_free_fast_reg_page_list(r->r.frmr.fr_pgl); - break; - case RPCRDMA_MTHCAFMR: - rc = ib_dealloc_fmr(r->r.fmr); - if (rc) - dprintk("RPC: %s:" - " ib_dealloc_fmr" - " failed %i\n", - __func__, rc); - break; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - rc = ib_dealloc_mw(r->r.mw); - if (rc) - dprintk("RPC: %s:" - " ib_dealloc_mw" - " failed %i\n", - __func__, rc); - break; - default: - break; - } - } - rpcrdma_deregister_internal(ia, - buf->rb_send_bufs[i]->rl_handle, - &buf->rb_send_bufs[i]->rl_iov); - kfree(buf->rb_send_bufs[i]); - } - } + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; - kfree(buf->rb_pool); + /* If there is no underlying connection, it's no use + * to wake the refresh worker. + */ + if (ep->re_connect_status != 1) + return; + queue_work(system_highpri_wq, &buf->rb_refresh_worker); } -/* - * Get a set of request/reply buffers. +/** + * rpcrdma_req_create - Allocate an rpcrdma_req object + * @r_xprt: controlling r_xprt + * @size: initial size, in bytes, of send and receive buffers * - * Reply buffer (if needed) is attached to send buffer upon return. - * Rule: - * rb_send_index and rb_recv_index MUST always be pointing to the - * *next* available buffer (non-NULL). They are incremented after - * removing buffers, and decremented *before* returning them. + * Returns an allocated and fully initialized rpcrdma_req or NULL. */ -struct rpcrdma_req * -rpcrdma_buffer_get(struct rpcrdma_buffer *buffers) +struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt, + size_t size) { + struct rpcrdma_buffer *buffer = &r_xprt->rx_buf; struct rpcrdma_req *req; - unsigned long flags; - int i; - struct rpcrdma_mw *r; - - spin_lock_irqsave(&buffers->rb_lock, flags); - if (buffers->rb_send_index == buffers->rb_max_requests) { - spin_unlock_irqrestore(&buffers->rb_lock, flags); - dprintk("RPC: %s: out of request buffers\n", __func__); - return ((struct rpcrdma_req *)NULL); - } - req = buffers->rb_send_bufs[buffers->rb_send_index]; - if (buffers->rb_send_index < buffers->rb_recv_index) { - dprintk("RPC: %s: %d extra receives outstanding (ok)\n", - __func__, - buffers->rb_recv_index - buffers->rb_send_index); - req->rl_reply = NULL; - } else { - req->rl_reply = buffers->rb_recv_bufs[buffers->rb_recv_index]; - buffers->rb_recv_bufs[buffers->rb_recv_index++] = NULL; - } - buffers->rb_send_bufs[buffers->rb_send_index++] = NULL; - if (!list_empty(&buffers->rb_mws)) { - i = RPCRDMA_MAX_SEGS - 1; - do { - r = list_entry(buffers->rb_mws.next, - struct rpcrdma_mw, mw_list); - list_del(&r->mw_list); - req->rl_segments[i].mr_chunk.rl_mw = r; - } while (--i >= 0); - } - spin_unlock_irqrestore(&buffers->rb_lock, flags); + req = kzalloc(sizeof(*req), XPRTRDMA_GFP_FLAGS); + if (req == NULL) + goto out1; + + req->rl_sendbuf = rpcrdma_regbuf_alloc(size, DMA_TO_DEVICE); + if (!req->rl_sendbuf) + goto out2; + + req->rl_recvbuf = rpcrdma_regbuf_alloc(size, DMA_NONE); + if (!req->rl_recvbuf) + goto out3; + + INIT_LIST_HEAD(&req->rl_free_mrs); + INIT_LIST_HEAD(&req->rl_registered); + spin_lock(&buffer->rb_lock); + list_add(&req->rl_all, &buffer->rb_allreqs); + spin_unlock(&buffer->rb_lock); return req; + +out3: + rpcrdma_regbuf_free(req->rl_sendbuf); +out2: + kfree(req); +out1: + return NULL; } -/* - * Put request/reply buffers back into pool. - * Pre-decrement counter/array index. +/** + * rpcrdma_req_setup - Per-connection instance setup of an rpcrdma_req object + * @r_xprt: controlling transport instance + * @req: rpcrdma_req object to set up + * + * Returns zero on success, and a negative errno on failure. */ -void -rpcrdma_buffer_put(struct rpcrdma_req *req) +int rpcrdma_req_setup(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) { - struct rpcrdma_buffer *buffers = req->rl_buffer; - struct rpcrdma_ia *ia = rdmab_to_ia(buffers); - int i; - unsigned long flags; - - BUG_ON(req->rl_nchunks != 0); - spin_lock_irqsave(&buffers->rb_lock, flags); - buffers->rb_send_bufs[--buffers->rb_send_index] = req; - req->rl_niovs = 0; - if (req->rl_reply) { - buffers->rb_recv_bufs[--buffers->rb_recv_index] = req->rl_reply; - init_waitqueue_head(&req->rl_reply->rr_unbind); - req->rl_reply->rr_func = NULL; - req->rl_reply = NULL; - } - switch (ia->ri_memreg_strategy) { - case RPCRDMA_FRMR: - case RPCRDMA_MTHCAFMR: - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - /* - * Cycle mw's back in reverse order, and "spin" them. - * This delays and scrambles reuse as much as possible. - */ - i = 1; - do { - struct rpcrdma_mw **mw; - mw = &req->rl_segments[i].mr_chunk.rl_mw; - list_add_tail(&(*mw)->mw_list, &buffers->rb_mws); - *mw = NULL; - } while (++i < RPCRDMA_MAX_SEGS); - list_add_tail(&req->rl_segments[0].mr_chunk.rl_mw->mw_list, - &buffers->rb_mws); - req->rl_segments[0].mr_chunk.rl_mw = NULL; - break; - default: - break; - } - spin_unlock_irqrestore(&buffers->rb_lock, flags); + struct rpcrdma_regbuf *rb; + size_t maxhdrsize; + + /* Compute maximum header buffer size in bytes */ + maxhdrsize = rpcrdma_fixed_maxsz + 3 + + r_xprt->rx_ep->re_max_rdma_segs * rpcrdma_readchunk_maxsz; + maxhdrsize *= sizeof(__be32); + rb = rpcrdma_regbuf_alloc(__roundup_pow_of_two(maxhdrsize), + DMA_TO_DEVICE); + if (!rb) + goto out; + + if (!__rpcrdma_regbuf_dma_map(r_xprt, rb)) + goto out_free; + + req->rl_rdmabuf = rb; + xdr_buf_init(&req->rl_hdrbuf, rdmab_data(rb), rdmab_length(rb)); + return 0; + +out_free: + rpcrdma_regbuf_free(rb); +out: + return -ENOMEM; } -/* - * Recover reply buffers from pool. - * This happens when recovering from error conditions. - * Post-increment counter/array index. +/* ASSUMPTION: the rb_allreqs list is stable for the duration, + * and thus can be walked without holding rb_lock. Eg. the + * caller is holding the transport send lock to exclude + * device removal or disconnection. */ -void -rpcrdma_recv_buffer_get(struct rpcrdma_req *req) +static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt) { - struct rpcrdma_buffer *buffers = req->rl_buffer; - unsigned long flags; - - if (req->rl_iov.length == 0) /* special case xprt_rdma_allocate() */ - buffers = ((struct rpcrdma_req *) buffers)->rl_buffer; - spin_lock_irqsave(&buffers->rb_lock, flags); - if (buffers->rb_recv_index < buffers->rb_max_requests) { - req->rl_reply = buffers->rb_recv_bufs[buffers->rb_recv_index]; - buffers->rb_recv_bufs[buffers->rb_recv_index++] = NULL; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + int rc; + + list_for_each_entry(req, &buf->rb_allreqs, rl_all) { + rc = rpcrdma_req_setup(r_xprt, req); + if (rc) + return rc; } - spin_unlock_irqrestore(&buffers->rb_lock, flags); + return 0; } -/* - * Put reply buffers back into pool when not attached to - * request. This happens in error conditions, and when - * aborting unbinds. Pre-decrement counter/array index. - */ -void -rpcrdma_recv_buffer_put(struct rpcrdma_rep *rep) +static void rpcrdma_req_reset(struct rpcrdma_req *req) { - struct rpcrdma_buffer *buffers = rep->rr_buffer; - unsigned long flags; + struct rpcrdma_mr *mr; - rep->rr_func = NULL; - spin_lock_irqsave(&buffers->rb_lock, flags); - buffers->rb_recv_bufs[--buffers->rb_recv_index] = rep; - spin_unlock_irqrestore(&buffers->rb_lock, flags); -} + /* Credits are valid for only one connection */ + req->rl_slot.rq_cong = 0; -/* - * Wrappers for internal-use kmalloc memory registration, used by buffer code. - */ + rpcrdma_regbuf_free(req->rl_rdmabuf); + req->rl_rdmabuf = NULL; -int -rpcrdma_register_internal(struct rpcrdma_ia *ia, void *va, int len, - struct ib_mr **mrp, struct ib_sge *iov) -{ - struct ib_phys_buf ipb; - struct ib_mr *mr; - int rc; + rpcrdma_regbuf_dma_unmap(req->rl_sendbuf); + rpcrdma_regbuf_dma_unmap(req->rl_recvbuf); - /* - * All memory passed here was kmalloc'ed, therefore phys-contiguous. + /* The verbs consumer can't know the state of an MR on the + * req->rl_registered list unless a successful completion + * has occurred, so they cannot be re-used. */ - iov->addr = ib_dma_map_single(ia->ri_id->device, - va, len, DMA_BIDIRECTIONAL); - iov->length = len; + while ((mr = rpcrdma_mr_pop(&req->rl_registered))) { + struct rpcrdma_buffer *buf = &mr->mr_xprt->rx_buf; - if (ia->ri_have_dma_lkey) { - *mrp = NULL; - iov->lkey = ia->ri_dma_lkey; - return 0; - } else if (ia->ri_bind_mem != NULL) { - *mrp = NULL; - iov->lkey = ia->ri_bind_mem->lkey; - return 0; - } + spin_lock(&buf->rb_lock); + list_del(&mr->mr_all); + spin_unlock(&buf->rb_lock); - ipb.addr = iov->addr; - ipb.size = iov->length; - mr = ib_reg_phys_mr(ia->ri_pd, &ipb, 1, - IB_ACCESS_LOCAL_WRITE, &iov->addr); - - dprintk("RPC: %s: phys convert: 0x%llx " - "registered 0x%llx length %d\n", - __func__, (unsigned long long)ipb.addr, - (unsigned long long)iov->addr, len); - - if (IS_ERR(mr)) { - *mrp = NULL; - rc = PTR_ERR(mr); - dprintk("RPC: %s: failed with %i\n", __func__, rc); - } else { - *mrp = mr; - iov->lkey = mr->lkey; - rc = 0; + frwr_mr_release(mr); } +} - return rc; +/* ASSUMPTION: the rb_allreqs list is stable for the duration, + * and thus can be walked without holding rb_lock. Eg. the + * caller is holding the transport send lock to exclude + * device removal or disconnection. + */ +static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + + list_for_each_entry(req, &buf->rb_allreqs, rl_all) + rpcrdma_req_reset(req); } -int -rpcrdma_deregister_internal(struct rpcrdma_ia *ia, - struct ib_mr *mr, struct ib_sge *iov) +static noinline +struct rpcrdma_rep *rpcrdma_rep_create(struct rpcrdma_xprt *r_xprt) { - int rc; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_device *device = ep->re_id->device; + struct rpcrdma_rep *rep; - ib_dma_unmap_single(ia->ri_id->device, - iov->addr, iov->length, DMA_BIDIRECTIONAL); + rep = kzalloc(sizeof(*rep), XPRTRDMA_GFP_FLAGS); + if (rep == NULL) + goto out; - if (NULL == mr) - return 0; + rep->rr_rdmabuf = rpcrdma_regbuf_alloc_node(ep->re_inline_recv, + DMA_FROM_DEVICE, + ibdev_to_node(device)); + if (!rep->rr_rdmabuf) + goto out_free; + + rep->rr_cid.ci_completion_id = + atomic_inc_return(&r_xprt->rx_ep->re_completion_ids); + + xdr_buf_init(&rep->rr_hdrbuf, rdmab_data(rep->rr_rdmabuf), + rdmab_length(rep->rr_rdmabuf)); + rep->rr_cqe.done = rpcrdma_wc_receive; + rep->rr_rxprt = r_xprt; + rep->rr_recv_wr.next = NULL; + rep->rr_recv_wr.wr_cqe = &rep->rr_cqe; + rep->rr_recv_wr.sg_list = &rep->rr_rdmabuf->rg_iov; + rep->rr_recv_wr.num_sge = 1; + + spin_lock(&buf->rb_lock); + list_add(&rep->rr_all, &buf->rb_all_reps); + spin_unlock(&buf->rb_lock); + return rep; + +out_free: + kfree(rep); +out: + return NULL; +} - rc = ib_dereg_mr(mr); - if (rc) - dprintk("RPC: %s: ib_dereg_mr failed %i\n", __func__, rc); - return rc; +static void rpcrdma_rep_free(struct rpcrdma_rep *rep) +{ + rpcrdma_regbuf_free(rep->rr_rdmabuf); + kfree(rep); } -/* - * Wrappers for chunk registration, shared by read/write chunk code. - */ +static struct rpcrdma_rep *rpcrdma_rep_get_locked(struct rpcrdma_buffer *buf) +{ + struct llist_node *node; -static void -rpcrdma_map_one(struct rpcrdma_ia *ia, struct rpcrdma_mr_seg *seg, int writing) + /* Calls to llist_del_first are required to be serialized */ + node = llist_del_first(&buf->rb_free_reps); + if (!node) + return NULL; + return llist_entry(node, struct rpcrdma_rep, rr_node); +} + +/** + * rpcrdma_rep_put - Release rpcrdma_rep back to free list + * @buf: buffer pool + * @rep: rep to release + * + */ +void rpcrdma_rep_put(struct rpcrdma_buffer *buf, struct rpcrdma_rep *rep) { - seg->mr_dir = writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE; - seg->mr_dmalen = seg->mr_len; - if (seg->mr_page) - seg->mr_dma = ib_dma_map_page(ia->ri_id->device, - seg->mr_page, offset_in_page(seg->mr_offset), - seg->mr_dmalen, seg->mr_dir); - else - seg->mr_dma = ib_dma_map_single(ia->ri_id->device, - seg->mr_offset, - seg->mr_dmalen, seg->mr_dir); - if (ib_dma_mapping_error(ia->ri_id->device, seg->mr_dma)) { - dprintk("RPC: %s: mr_dma %llx mr_offset %p mr_dma_len %zu\n", - __func__, - (unsigned long long)seg->mr_dma, - seg->mr_offset, seg->mr_dmalen); - } + llist_add(&rep->rr_node, &buf->rb_free_reps); } -static void -rpcrdma_unmap_one(struct rpcrdma_ia *ia, struct rpcrdma_mr_seg *seg) +/* Caller must ensure the QP is quiescent (RQ is drained) before + * invoking this function, to guarantee rb_all_reps is not + * changing. + */ +static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt) { - if (seg->mr_page) - ib_dma_unmap_page(ia->ri_id->device, - seg->mr_dma, seg->mr_dmalen, seg->mr_dir); - else - ib_dma_unmap_single(ia->ri_id->device, - seg->mr_dma, seg->mr_dmalen, seg->mr_dir); + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_rep *rep; + + list_for_each_entry(rep, &buf->rb_all_reps, rr_all) + rpcrdma_regbuf_dma_unmap(rep->rr_rdmabuf); } -static int -rpcrdma_register_frmr_external(struct rpcrdma_mr_seg *seg, - int *nsegs, int writing, struct rpcrdma_ia *ia, - struct rpcrdma_xprt *r_xprt) +static void rpcrdma_reps_destroy(struct rpcrdma_buffer *buf) { - struct rpcrdma_mr_seg *seg1 = seg; - struct ib_send_wr invalidate_wr, frmr_wr, *bad_wr, *post_wr; + struct rpcrdma_rep *rep; - u8 key; - int len, pageoff; - int i, rc; - int seg_len; - u64 pa; - int page_no; - - pageoff = offset_in_page(seg1->mr_offset); - seg1->mr_offset -= pageoff; /* start of page */ - seg1->mr_len += pageoff; - len = -pageoff; - if (*nsegs > RPCRDMA_MAX_DATA_SEGS) - *nsegs = RPCRDMA_MAX_DATA_SEGS; - for (page_no = i = 0; i < *nsegs;) { - rpcrdma_map_one(ia, seg, writing); - pa = seg->mr_dma; - for (seg_len = seg->mr_len; seg_len > 0; seg_len -= PAGE_SIZE) { - seg1->mr_chunk.rl_mw->r.frmr.fr_pgl-> - page_list[page_no++] = pa; - pa += PAGE_SIZE; - } - len += seg->mr_len; - ++seg; - ++i; - /* Check for holes */ - if ((i < *nsegs && offset_in_page(seg->mr_offset)) || - offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) - break; - } - dprintk("RPC: %s: Using frmr %p to map %d segments\n", - __func__, seg1->mr_chunk.rl_mw, i); - - if (unlikely(seg1->mr_chunk.rl_mw->r.frmr.state == FRMR_IS_VALID)) { - dprintk("RPC: %s: frmr %x left valid, posting invalidate.\n", - __func__, - seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey); - /* Invalidate before using. */ - memset(&invalidate_wr, 0, sizeof invalidate_wr); - invalidate_wr.wr_id = (unsigned long)(void *)seg1->mr_chunk.rl_mw; - invalidate_wr.next = &frmr_wr; - invalidate_wr.opcode = IB_WR_LOCAL_INV; - invalidate_wr.send_flags = IB_SEND_SIGNALED; - invalidate_wr.ex.invalidate_rkey = - seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; - DECR_CQCOUNT(&r_xprt->rx_ep); - post_wr = &invalidate_wr; - } else - post_wr = &frmr_wr; - - /* Bump the key */ - key = (u8)(seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey & 0x000000FF); - ib_update_fast_reg_key(seg1->mr_chunk.rl_mw->r.frmr.fr_mr, ++key); - - /* Prepare FRMR WR */ - memset(&frmr_wr, 0, sizeof frmr_wr); - frmr_wr.wr_id = (unsigned long)(void *)seg1->mr_chunk.rl_mw; - frmr_wr.opcode = IB_WR_FAST_REG_MR; - frmr_wr.send_flags = IB_SEND_SIGNALED; - frmr_wr.wr.fast_reg.iova_start = seg1->mr_dma; - frmr_wr.wr.fast_reg.page_list = seg1->mr_chunk.rl_mw->r.frmr.fr_pgl; - frmr_wr.wr.fast_reg.page_list_len = page_no; - frmr_wr.wr.fast_reg.page_shift = PAGE_SHIFT; - frmr_wr.wr.fast_reg.length = page_no << PAGE_SHIFT; - BUG_ON(frmr_wr.wr.fast_reg.length < len); - frmr_wr.wr.fast_reg.access_flags = (writing ? - IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE : - IB_ACCESS_REMOTE_READ); - frmr_wr.wr.fast_reg.rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; - DECR_CQCOUNT(&r_xprt->rx_ep); - - rc = ib_post_send(ia->ri_id->qp, post_wr, &bad_wr); + spin_lock(&buf->rb_lock); + while ((rep = list_first_entry_or_null(&buf->rb_all_reps, + struct rpcrdma_rep, + rr_all)) != NULL) { + list_del(&rep->rr_all); + spin_unlock(&buf->rb_lock); - if (rc) { - dprintk("RPC: %s: failed ib_post_send for register," - " status %i\n", __func__, rc); - while (i--) - rpcrdma_unmap_one(ia, --seg); - } else { - seg1->mr_rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; - seg1->mr_base = seg1->mr_dma + pageoff; - seg1->mr_nsegs = i; - seg1->mr_len = len; + rpcrdma_rep_free(rep); + + spin_lock(&buf->rb_lock); } - *nsegs = i; - return rc; + spin_unlock(&buf->rb_lock); } -static int -rpcrdma_deregister_frmr_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_ia *ia, struct rpcrdma_xprt *r_xprt) +/** + * rpcrdma_buffer_create - Create initial set of req/rep objects + * @r_xprt: transport instance to (re)initialize + * + * Returns zero on success, otherwise a negative errno. + */ +int rpcrdma_buffer_create(struct rpcrdma_xprt *r_xprt) { - struct rpcrdma_mr_seg *seg1 = seg; - struct ib_send_wr invalidate_wr, *bad_wr; - int rc; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + int i, rc; - while (seg1->mr_nsegs--) - rpcrdma_unmap_one(ia, seg++); + buf->rb_bc_srv_max_requests = 0; + spin_lock_init(&buf->rb_lock); + INIT_LIST_HEAD(&buf->rb_mrs); + INIT_LIST_HEAD(&buf->rb_all_mrs); + INIT_WORK(&buf->rb_refresh_worker, rpcrdma_mr_refresh_worker); - memset(&invalidate_wr, 0, sizeof invalidate_wr); - invalidate_wr.wr_id = (unsigned long)(void *)seg1->mr_chunk.rl_mw; - invalidate_wr.opcode = IB_WR_LOCAL_INV; - invalidate_wr.send_flags = IB_SEND_SIGNALED; - invalidate_wr.ex.invalidate_rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; - DECR_CQCOUNT(&r_xprt->rx_ep); + INIT_LIST_HEAD(&buf->rb_send_bufs); + INIT_LIST_HEAD(&buf->rb_allreqs); + INIT_LIST_HEAD(&buf->rb_all_reps); - rc = ib_post_send(ia->ri_id->qp, &invalidate_wr, &bad_wr); - if (rc) - dprintk("RPC: %s: failed ib_post_send for invalidate," - " status %i\n", __func__, rc); + rc = -ENOMEM; + for (i = 0; i < r_xprt->rx_xprt.max_reqs; i++) { + struct rpcrdma_req *req; + + req = rpcrdma_req_create(r_xprt, + RPCRDMA_V1_DEF_INLINE_SIZE * 2); + if (!req) + goto out; + list_add(&req->rl_list, &buf->rb_send_bufs); + } + + init_llist_head(&buf->rb_free_reps); + + return 0; +out: + rpcrdma_buffer_destroy(buf); return rc; } -static int -rpcrdma_register_fmr_external(struct rpcrdma_mr_seg *seg, - int *nsegs, int writing, struct rpcrdma_ia *ia) +/** + * rpcrdma_req_destroy - Destroy an rpcrdma_req object + * @req: unused object to be destroyed + * + * Relies on caller holding the transport send lock to protect + * removing req->rl_all from buf->rb_all_reqs safely. + */ +void rpcrdma_req_destroy(struct rpcrdma_req *req) { - struct rpcrdma_mr_seg *seg1 = seg; - u64 physaddrs[RPCRDMA_MAX_DATA_SEGS]; - int len, pageoff, i, rc; - - pageoff = offset_in_page(seg1->mr_offset); - seg1->mr_offset -= pageoff; /* start of page */ - seg1->mr_len += pageoff; - len = -pageoff; - if (*nsegs > RPCRDMA_MAX_DATA_SEGS) - *nsegs = RPCRDMA_MAX_DATA_SEGS; - for (i = 0; i < *nsegs;) { - rpcrdma_map_one(ia, seg, writing); - physaddrs[i] = seg->mr_dma; - len += seg->mr_len; - ++seg; - ++i; - /* Check for holes */ - if ((i < *nsegs && offset_in_page(seg->mr_offset)) || - offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) - break; - } - rc = ib_map_phys_fmr(seg1->mr_chunk.rl_mw->r.fmr, - physaddrs, i, seg1->mr_dma); - if (rc) { - dprintk("RPC: %s: failed ib_map_phys_fmr " - "%u@0x%llx+%i (%d)... status %i\n", __func__, - len, (unsigned long long)seg1->mr_dma, - pageoff, i, rc); - while (i--) - rpcrdma_unmap_one(ia, --seg); - } else { - seg1->mr_rkey = seg1->mr_chunk.rl_mw->r.fmr->rkey; - seg1->mr_base = seg1->mr_dma + pageoff; - seg1->mr_nsegs = i; - seg1->mr_len = len; + struct rpcrdma_mr *mr; + + list_del(&req->rl_all); + + while ((mr = rpcrdma_mr_pop(&req->rl_free_mrs))) { + struct rpcrdma_buffer *buf = &mr->mr_xprt->rx_buf; + + spin_lock(&buf->rb_lock); + list_del(&mr->mr_all); + spin_unlock(&buf->rb_lock); + + frwr_mr_release(mr); } - *nsegs = i; - return rc; + + rpcrdma_regbuf_free(req->rl_recvbuf); + rpcrdma_regbuf_free(req->rl_sendbuf); + rpcrdma_regbuf_free(req->rl_rdmabuf); + kfree(req); } -static int -rpcrdma_deregister_fmr_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_ia *ia) +/** + * rpcrdma_mrs_destroy - Release all of a transport's MRs + * @r_xprt: controlling transport instance + * + * Relies on caller holding the transport send lock to protect + * removing mr->mr_list from req->rl_free_mrs safely. + */ +static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt) { - struct rpcrdma_mr_seg *seg1 = seg; - LIST_HEAD(l); - int rc; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_mr *mr; - list_add(&seg1->mr_chunk.rl_mw->r.fmr->list, &l); - rc = ib_unmap_fmr(&l); - while (seg1->mr_nsegs--) - rpcrdma_unmap_one(ia, seg++); - if (rc) - dprintk("RPC: %s: failed ib_unmap_fmr," - " status %i\n", __func__, rc); - return rc; + cancel_work_sync(&buf->rb_refresh_worker); + + spin_lock(&buf->rb_lock); + while ((mr = list_first_entry_or_null(&buf->rb_all_mrs, + struct rpcrdma_mr, + mr_all)) != NULL) { + list_del(&mr->mr_list); + list_del(&mr->mr_all); + spin_unlock(&buf->rb_lock); + + frwr_mr_release(mr); + + spin_lock(&buf->rb_lock); + } + spin_unlock(&buf->rb_lock); } -static int -rpcrdma_register_memwin_external(struct rpcrdma_mr_seg *seg, - int *nsegs, int writing, struct rpcrdma_ia *ia, - struct rpcrdma_xprt *r_xprt) +/** + * rpcrdma_buffer_destroy - Release all hw resources + * @buf: root control block for resources + * + * ORDERING: relies on a prior rpcrdma_xprt_drain : + * - No more Send or Receive completions can occur + * - All MRs, reps, and reqs are returned to their free lists + */ +void +rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) { - int mem_priv = (writing ? IB_ACCESS_REMOTE_WRITE : - IB_ACCESS_REMOTE_READ); - struct ib_mw_bind param; - int rc; + rpcrdma_reps_destroy(buf); - *nsegs = 1; - rpcrdma_map_one(ia, seg, writing); - param.bind_info.mr = ia->ri_bind_mem; - param.wr_id = 0ULL; /* no send cookie */ - param.bind_info.addr = seg->mr_dma; - param.bind_info.length = seg->mr_len; - param.send_flags = 0; - param.bind_info.mw_access_flags = mem_priv; - - DECR_CQCOUNT(&r_xprt->rx_ep); - rc = ib_bind_mw(ia->ri_id->qp, seg->mr_chunk.rl_mw->r.mw, ¶m); - if (rc) { - dprintk("RPC: %s: failed ib_bind_mw " - "%u@0x%llx status %i\n", - __func__, seg->mr_len, - (unsigned long long)seg->mr_dma, rc); - rpcrdma_unmap_one(ia, seg); - } else { - seg->mr_rkey = seg->mr_chunk.rl_mw->r.mw->rkey; - seg->mr_base = param.bind_info.addr; - seg->mr_nsegs = 1; + while (!list_empty(&buf->rb_send_bufs)) { + struct rpcrdma_req *req; + + req = list_first_entry(&buf->rb_send_bufs, + struct rpcrdma_req, rl_list); + list_del(&req->rl_list); + rpcrdma_req_destroy(req); } - return rc; } -static int -rpcrdma_deregister_memwin_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_ia *ia, - struct rpcrdma_xprt *r_xprt, void **r) +/** + * rpcrdma_mr_get - Allocate an rpcrdma_mr object + * @r_xprt: controlling transport + * + * Returns an initialized rpcrdma_mr or NULL if no free + * rpcrdma_mr objects are available. + */ +struct rpcrdma_mr * +rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt) { - struct ib_mw_bind param; - LIST_HEAD(l); - int rc; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_mr *mr; - BUG_ON(seg->mr_nsegs != 1); - param.bind_info.mr = ia->ri_bind_mem; - param.bind_info.addr = 0ULL; /* unbind */ - param.bind_info.length = 0; - param.bind_info.mw_access_flags = 0; - if (*r) { - param.wr_id = (u64) (unsigned long) *r; - param.send_flags = IB_SEND_SIGNALED; - INIT_CQCOUNT(&r_xprt->rx_ep); - } else { - param.wr_id = 0ULL; - param.send_flags = 0; - DECR_CQCOUNT(&r_xprt->rx_ep); - } - rc = ib_bind_mw(ia->ri_id->qp, seg->mr_chunk.rl_mw->r.mw, ¶m); - rpcrdma_unmap_one(ia, seg); - if (rc) - dprintk("RPC: %s: failed ib_(un)bind_mw," - " status %i\n", __func__, rc); - else - *r = NULL; /* will upcall on completion */ - return rc; + spin_lock(&buf->rb_lock); + mr = rpcrdma_mr_pop(&buf->rb_mrs); + spin_unlock(&buf->rb_lock); + return mr; } -static int -rpcrdma_register_default_external(struct rpcrdma_mr_seg *seg, - int *nsegs, int writing, struct rpcrdma_ia *ia) +/** + * rpcrdma_reply_put - Put reply buffers back into pool + * @buffers: buffer pool + * @req: object to return + * + */ +void rpcrdma_reply_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req) { - int mem_priv = (writing ? IB_ACCESS_REMOTE_WRITE : - IB_ACCESS_REMOTE_READ); - struct rpcrdma_mr_seg *seg1 = seg; - struct ib_phys_buf ipb[RPCRDMA_MAX_DATA_SEGS]; - int len, i, rc = 0; - - if (*nsegs > RPCRDMA_MAX_DATA_SEGS) - *nsegs = RPCRDMA_MAX_DATA_SEGS; - for (len = 0, i = 0; i < *nsegs;) { - rpcrdma_map_one(ia, seg, writing); - ipb[i].addr = seg->mr_dma; - ipb[i].size = seg->mr_len; - len += seg->mr_len; - ++seg; - ++i; - /* Check for holes */ - if ((i < *nsegs && offset_in_page(seg->mr_offset)) || - offset_in_page((seg-1)->mr_offset+(seg-1)->mr_len)) - break; - } - seg1->mr_base = seg1->mr_dma; - seg1->mr_chunk.rl_mr = ib_reg_phys_mr(ia->ri_pd, - ipb, i, mem_priv, &seg1->mr_base); - if (IS_ERR(seg1->mr_chunk.rl_mr)) { - rc = PTR_ERR(seg1->mr_chunk.rl_mr); - dprintk("RPC: %s: failed ib_reg_phys_mr " - "%u@0x%llx (%d)... status %i\n", - __func__, len, - (unsigned long long)seg1->mr_dma, i, rc); - while (i--) - rpcrdma_unmap_one(ia, --seg); - } else { - seg1->mr_rkey = seg1->mr_chunk.rl_mr->rkey; - seg1->mr_nsegs = i; - seg1->mr_len = len; + if (req->rl_reply) { + rpcrdma_rep_put(buffers, req->rl_reply); + req->rl_reply = NULL; } - *nsegs = i; - return rc; } -static int -rpcrdma_deregister_default_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_ia *ia) +/** + * rpcrdma_buffer_get - Get a request buffer + * @buffers: Buffer pool from which to obtain a buffer + * + * Returns a fresh rpcrdma_req, or NULL if none are available. + */ +struct rpcrdma_req * +rpcrdma_buffer_get(struct rpcrdma_buffer *buffers) { - struct rpcrdma_mr_seg *seg1 = seg; - int rc; + struct rpcrdma_req *req; - rc = ib_dereg_mr(seg1->mr_chunk.rl_mr); - seg1->mr_chunk.rl_mr = NULL; - while (seg1->mr_nsegs--) - rpcrdma_unmap_one(ia, seg++); - if (rc) - dprintk("RPC: %s: failed ib_dereg_mr," - " status %i\n", __func__, rc); - return rc; + spin_lock(&buffers->rb_lock); + req = list_first_entry_or_null(&buffers->rb_send_bufs, + struct rpcrdma_req, rl_list); + if (req) + list_del_init(&req->rl_list); + spin_unlock(&buffers->rb_lock); + return req; } -int -rpcrdma_register_external(struct rpcrdma_mr_seg *seg, - int nsegs, int writing, struct rpcrdma_xprt *r_xprt) +/** + * rpcrdma_buffer_put - Put request/reply buffers back into pool + * @buffers: buffer pool + * @req: object to return + * + */ +void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req) { - struct rpcrdma_ia *ia = &r_xprt->rx_ia; - int rc = 0; - - switch (ia->ri_memreg_strategy) { - -#if RPCRDMA_PERSISTENT_REGISTRATION - case RPCRDMA_ALLPHYSICAL: - rpcrdma_map_one(ia, seg, writing); - seg->mr_rkey = ia->ri_bind_mem->rkey; - seg->mr_base = seg->mr_dma; - seg->mr_nsegs = 1; - nsegs = 1; - break; -#endif - - /* Registration using frmr registration */ - case RPCRDMA_FRMR: - rc = rpcrdma_register_frmr_external(seg, &nsegs, writing, ia, r_xprt); - break; - - /* Registration using fmr memory registration */ - case RPCRDMA_MTHCAFMR: - rc = rpcrdma_register_fmr_external(seg, &nsegs, writing, ia); - break; + rpcrdma_reply_put(buffers, req); - /* Registration using memory windows */ - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - rc = rpcrdma_register_memwin_external(seg, &nsegs, writing, ia, r_xprt); - break; + spin_lock(&buffers->rb_lock); + list_add(&req->rl_list, &buffers->rb_send_bufs); + spin_unlock(&buffers->rb_lock); +} - /* Default registration each time */ - default: - rc = rpcrdma_register_default_external(seg, &nsegs, writing, ia); - break; +/* Returns a pointer to a rpcrdma_regbuf object, or NULL. + * + * xprtrdma uses a regbuf for posting an outgoing RDMA SEND, or for + * receiving the payload of RDMA RECV operations. During Long Calls + * or Replies they may be registered externally via frwr_map. + */ +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc_node(size_t size, enum dma_data_direction direction, + int node) +{ + struct rpcrdma_regbuf *rb; + + rb = kmalloc_node(sizeof(*rb), XPRTRDMA_GFP_FLAGS, node); + if (!rb) + return NULL; + rb->rg_data = kmalloc_node(size, XPRTRDMA_GFP_FLAGS, node); + if (!rb->rg_data) { + kfree(rb); + return NULL; } - if (rc) - return -1; - return nsegs; + rb->rg_device = NULL; + rb->rg_direction = direction; + rb->rg_iov.length = size; + return rb; } -int -rpcrdma_deregister_external(struct rpcrdma_mr_seg *seg, - struct rpcrdma_xprt *r_xprt, void *r) +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction) { - struct rpcrdma_ia *ia = &r_xprt->rx_ia; - int nsegs = seg->mr_nsegs, rc; - - switch (ia->ri_memreg_strategy) { - -#if RPCRDMA_PERSISTENT_REGISTRATION - case RPCRDMA_ALLPHYSICAL: - BUG_ON(nsegs != 1); - rpcrdma_unmap_one(ia, seg); - rc = 0; - break; -#endif + return rpcrdma_regbuf_alloc_node(size, direction, NUMA_NO_NODE); +} - case RPCRDMA_FRMR: - rc = rpcrdma_deregister_frmr_external(seg, ia, r_xprt); - break; +/** + * rpcrdma_regbuf_realloc - re-allocate a SEND/RECV buffer + * @rb: regbuf to reallocate + * @size: size of buffer to be allocated, in bytes + * @flags: GFP flags + * + * Returns true if reallocation was successful. If false is + * returned, @rb is left untouched. + */ +bool rpcrdma_regbuf_realloc(struct rpcrdma_regbuf *rb, size_t size, gfp_t flags) +{ + void *buf; - case RPCRDMA_MTHCAFMR: - rc = rpcrdma_deregister_fmr_external(seg, ia); - break; + buf = kmalloc(size, flags); + if (!buf) + return false; - case RPCRDMA_MEMWINDOWS_ASYNC: - case RPCRDMA_MEMWINDOWS: - rc = rpcrdma_deregister_memwin_external(seg, ia, r_xprt, &r); - break; + rpcrdma_regbuf_dma_unmap(rb); + kfree(rb->rg_data); - default: - rc = rpcrdma_deregister_default_external(seg, ia); - break; - } - if (r) { - struct rpcrdma_rep *rep = r; - void (*func)(struct rpcrdma_rep *) = rep->rr_func; - rep->rr_func = NULL; - func(rep); /* dereg done, callback now */ - } - return nsegs; + rb->rg_data = buf; + rb->rg_iov.length = size; + return true; } -/* - * Prepost any receive buffer, then post send. +/** + * __rpcrdma_regbuf_dma_map - DMA-map a regbuf + * @r_xprt: controlling transport instance + * @rb: regbuf to be mapped * - * Receive buffer is donated to hardware, reclaimed upon recv completion. + * Returns true if the buffer is now DMA mapped to @r_xprt's device */ -int -rpcrdma_ep_post(struct rpcrdma_ia *ia, - struct rpcrdma_ep *ep, - struct rpcrdma_req *req) +bool __rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb) { - struct ib_send_wr send_wr, *send_wr_fail; - struct rpcrdma_rep *rep = req->rl_reply; - int rc; + struct ib_device *device = r_xprt->rx_ep->re_id->device; - if (rep) { - rc = rpcrdma_ep_post_recv(ia, ep, rep); - if (rc) - goto out; - req->rl_reply = NULL; - } + if (rb->rg_direction == DMA_NONE) + return false; - send_wr.next = NULL; - send_wr.wr_id = 0ULL; /* no send cookie */ - send_wr.sg_list = req->rl_send_iov; - send_wr.num_sge = req->rl_niovs; - send_wr.opcode = IB_WR_SEND; - if (send_wr.num_sge == 4) /* no need to sync any pad (constant) */ - ib_dma_sync_single_for_device(ia->ri_id->device, - req->rl_send_iov[3].addr, req->rl_send_iov[3].length, - DMA_TO_DEVICE); - ib_dma_sync_single_for_device(ia->ri_id->device, - req->rl_send_iov[1].addr, req->rl_send_iov[1].length, - DMA_TO_DEVICE); - ib_dma_sync_single_for_device(ia->ri_id->device, - req->rl_send_iov[0].addr, req->rl_send_iov[0].length, - DMA_TO_DEVICE); - - if (DECR_CQCOUNT(ep) > 0) - send_wr.send_flags = 0; - else { /* Provider must take a send completion every now and then */ - INIT_CQCOUNT(ep); - send_wr.send_flags = IB_SEND_SIGNALED; + rb->rg_iov.addr = ib_dma_map_single(device, rdmab_data(rb), + rdmab_length(rb), rb->rg_direction); + if (ib_dma_mapping_error(device, rdmab_addr(rb))) { + trace_xprtrdma_dma_maperr(rdmab_addr(rb)); + return false; } - rc = ib_post_send(ia->ri_id->qp, &send_wr, &send_wr_fail); - if (rc) - dprintk("RPC: %s: ib_post_send returned %i\n", __func__, - rc); -out: - return rc; + rb->rg_device = device; + rb->rg_iov.lkey = r_xprt->rx_ep->re_pd->local_dma_lkey; + return true; } -/* - * (Re)post a receive buffer. +static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb) +{ + if (!rb) + return; + + if (!rpcrdma_regbuf_is_mapped(rb)) + return; + + ib_dma_unmap_single(rb->rg_device, rdmab_addr(rb), rdmab_length(rb), + rb->rg_direction); + rb->rg_device = NULL; +} + +static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb) +{ + rpcrdma_regbuf_dma_unmap(rb); + if (rb) + kfree(rb->rg_data); + kfree(rb); +} + +/** + * rpcrdma_post_recvs - Refill the Receive Queue + * @r_xprt: controlling transport instance + * @needed: current credit grant + * */ -int -rpcrdma_ep_post_recv(struct rpcrdma_ia *ia, - struct rpcrdma_ep *ep, - struct rpcrdma_rep *rep) +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed) { - struct ib_recv_wr recv_wr, *recv_wr_fail; - int rc; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_recv_wr *wr, *bad_wr; + struct rpcrdma_rep *rep; + int count, rc; - recv_wr.next = NULL; - recv_wr.wr_id = (u64) (unsigned long) rep; - recv_wr.sg_list = &rep->rr_iov; - recv_wr.num_sge = 1; + rc = 0; + count = 0; - ib_dma_sync_single_for_cpu(ia->ri_id->device, - rep->rr_iov.addr, rep->rr_iov.length, DMA_BIDIRECTIONAL); + if (likely(ep->re_receive_count > needed)) + goto out; + needed -= ep->re_receive_count; + needed += RPCRDMA_MAX_RECV_BATCH; - DECR_CQCOUNT(ep); - rc = ib_post_recv(ia->ri_id->qp, &recv_wr, &recv_wr_fail); + if (atomic_inc_return(&ep->re_receiving) > 1) + goto out; - if (rc) - dprintk("RPC: %s: ib_post_recv returned %i\n", __func__, - rc); - return rc; + /* fast path: all needed reps can be found on the free list */ + wr = NULL; + while (needed) { + rep = rpcrdma_rep_get_locked(buf); + if (!rep) + rep = rpcrdma_rep_create(r_xprt); + if (!rep) + break; + if (!rpcrdma_regbuf_dma_map(r_xprt, rep->rr_rdmabuf)) { + rpcrdma_rep_put(buf, rep); + break; + } + + rep->rr_cid.ci_queue_id = ep->re_attr.recv_cq->res.id; + trace_xprtrdma_post_recv(&rep->rr_cid); + rep->rr_recv_wr.next = wr; + wr = &rep->rr_recv_wr; + --needed; + ++count; + } + if (!wr) + goto out; + + rc = ib_post_recv(ep->re_id->qp, wr, + (const struct ib_recv_wr **)&bad_wr); + if (rc) { + trace_xprtrdma_post_recvs_err(r_xprt, rc); + for (wr = bad_wr; wr;) { + struct rpcrdma_rep *rep; + + rep = container_of(wr, struct rpcrdma_rep, rr_recv_wr); + wr = wr->next; + rpcrdma_rep_put(buf, rep); + --count; + } + } + if (atomic_dec_return(&ep->re_receiving) > 0) + complete(&ep->re_done); + +out: + trace_xprtrdma_post_recvs(r_xprt, count); + ep->re_receive_count += count; + return; } diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h index cc1445dc1d1a..8147d2b41494 100644 --- a/net/sunrpc/xprtrdma/xprt_rdma.h +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -1,4 +1,6 @@ +/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */ /* + * Copyright (c) 2014-2017 Oracle. All rights reserved. * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. * * This software is available to you under a choice of one of two @@ -42,92 +44,221 @@ #include <linux/wait.h> /* wait_queue_head_t, etc */ #include <linux/spinlock.h> /* spinlock_t, etc */ -#include <linux/atomic.h> /* atomic_t, etc */ +#include <linux/atomic.h> /* atomic_t, etc */ +#include <linux/kref.h> /* struct kref */ +#include <linux/workqueue.h> /* struct work_struct */ +#include <linux/llist.h> #include <rdma/rdma_cm.h> /* RDMA connection api */ #include <rdma/ib_verbs.h> /* RDMA verbs api */ #include <linux/sunrpc/clnt.h> /* rpc_xprt */ +#include <linux/sunrpc/rpc_rdma_cid.h> /* completion IDs */ #include <linux/sunrpc/rpc_rdma.h> /* RPC/RDMA protocol */ #include <linux/sunrpc/xprtrdma.h> /* xprt parameters */ +#include <linux/sunrpc/rdma_rn.h> /* removal notifications */ #define RDMA_RESOLVE_TIMEOUT (5000) /* 5 seconds */ #define RDMA_CONNECT_RETRY_MAX (2) /* retries if no listener backlog */ +#define RPCRDMA_BIND_TO (60U * HZ) +#define RPCRDMA_INIT_REEST_TO (5U * HZ) +#define RPCRDMA_MAX_REEST_TO (30U * HZ) +#define RPCRDMA_IDLE_DISC_TO (5U * 60 * HZ) + /* - * Interface Adapter -- one per transport instance + * RDMA Endpoint -- connection endpoint details */ -struct rpcrdma_ia { - struct rdma_cm_id *ri_id; - struct ib_pd *ri_pd; - struct ib_mr *ri_bind_mem; - u32 ri_dma_lkey; - int ri_have_dma_lkey; - struct completion ri_done; - int ri_async_rc; - enum rpcrdma_memreg ri_memreg_strategy; +struct rpcrdma_mr; +struct rpcrdma_ep { + struct kref re_kref; + struct rdma_cm_id *re_id; + struct ib_pd *re_pd; + unsigned int re_max_rdma_segs; + unsigned int re_max_fr_depth; + struct rpcrdma_mr *re_write_pad_mr; + enum ib_mr_type re_mrtype; + struct completion re_done; + unsigned int re_send_count; + unsigned int re_send_batch; + unsigned int re_max_inline_send; + unsigned int re_max_inline_recv; + int re_async_rc; + int re_connect_status; + atomic_t re_receiving; + atomic_t re_force_disconnect; + struct ib_qp_init_attr re_attr; + wait_queue_head_t re_connect_wait; + struct rpc_xprt *re_xprt; + struct rpcrdma_connect_private + re_cm_private; + struct rdma_conn_param re_remote_cma; + struct rpcrdma_notification re_rn; + int re_receive_count; + unsigned int re_max_requests; /* depends on device */ + unsigned int re_inline_send; /* negotiated */ + unsigned int re_inline_recv; /* negotiated */ + + atomic_t re_completion_ids; + + char re_write_pad[XDR_UNIT]; }; -/* - * RDMA Endpoint -- one per transport instance +/* Pre-allocate extra Work Requests for handling reverse-direction + * Receives and Sends. This is a fixed value because the Work Queues + * are allocated when the forward channel is set up, long before the + * backchannel is provisioned. This value is two times + * NFS4_DEF_CB_SLOT_TABLE_SIZE. */ +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +#define RPCRDMA_BACKWARD_WRS (32) +#else +#define RPCRDMA_BACKWARD_WRS (0) +#endif -struct rpcrdma_ep { - atomic_t rep_cqcount; - int rep_cqinit; - int rep_connected; - struct rpcrdma_ia *rep_ia; - struct ib_cq *rep_cq; - struct ib_qp_init_attr rep_attr; - wait_queue_head_t rep_connect_wait; - struct ib_sge rep_pad; /* holds zeroed pad */ - struct ib_mr *rep_pad_mr; /* holds zeroed pad */ - void (*rep_func)(struct rpcrdma_ep *); - struct rpc_xprt *rep_xprt; /* for rep_func */ - struct rdma_conn_param rep_remote_cma; - struct sockaddr_storage rep_remote_addr; +/* Registered buffer -- registered kmalloc'd memory for RDMA SEND/RECV + */ + +struct rpcrdma_regbuf { + struct ib_sge rg_iov; + struct ib_device *rg_device; + enum dma_data_direction rg_direction; + void *rg_data; }; -#define INIT_CQCOUNT(ep) atomic_set(&(ep)->rep_cqcount, (ep)->rep_cqinit) -#define DECR_CQCOUNT(ep) atomic_sub_return(1, &(ep)->rep_cqcount) +static inline u64 rdmab_addr(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.addr; +} + +static inline u32 rdmab_length(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.length; +} + +static inline u32 rdmab_lkey(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.lkey; +} + +static inline struct ib_device *rdmab_device(struct rpcrdma_regbuf *rb) +{ + return rb->rg_device; +} + +static inline void *rdmab_data(const struct rpcrdma_regbuf *rb) +{ + return rb->rg_data; +} + +/* Do not use emergency memory reserves, and fail quickly if memory + * cannot be allocated easily. These flags may be used wherever there + * is robust logic to handle a failure to allocate. + */ +#define XPRTRDMA_GFP_FLAGS (__GFP_NOMEMALLOC | __GFP_NORETRY | __GFP_NOWARN) + +/* To ensure a transport can always make forward progress, + * the number of RDMA segments allowed in header chunk lists + * is capped at 16. This prevents less-capable devices from + * overrunning the Send buffer while building chunk lists. + * + * Elements of the Read list take up more room than the + * Write list or Reply chunk. 16 read segments means the + * chunk lists cannot consume more than + * + * ((16 + 2) * read segment size) + 1 XDR words, + * + * or about 400 bytes. The fixed part of the header is + * another 24 bytes. Thus when the inline threshold is + * 1024 bytes, at least 600 bytes are available for RPC + * message bodies. + */ +enum { + RPCRDMA_MAX_HDR_SEGS = 16, +}; /* - * struct rpcrdma_rep -- this structure encapsulates state required to recv - * and complete a reply, asychronously. It needs several pieces of - * state: - * o recv buffer (posted to provider) - * o ib_sge (also donated to provider) - * o status of reply (length, success or not) - * o bookkeeping state to get run by tasklet (list, etc) + * struct rpcrdma_rep -- this structure encapsulates state required + * to receive and complete an RPC Reply, asychronously. It needs + * several pieces of state: * - * These are allocated during initialization, per-transport instance; - * however, the tasklet execution list itself is global, as it should - * always be pretty short. + * o receive buffer and ib_sge (donated to provider) + * o status of receive (success or not, length, inv rkey) + * o bookkeeping state to get run by reply handler (XDR stream) * - * N of these are associated with a transport instance, and stored in - * struct rpcrdma_buffer. N is the max number of outstanding requests. + * These structures are allocated during transport initialization. + * N of these are associated with a transport instance, managed by + * struct rpcrdma_buffer. N is the max number of outstanding RPCs. */ -/* temporary static scatter/gather max */ -#define RPCRDMA_MAX_DATA_SEGS (64) /* max scatter/gather */ -#define RPCRDMA_MAX_SEGS (RPCRDMA_MAX_DATA_SEGS + 2) /* head+tail = 2 */ -#define MAX_RPCRDMAHDR (\ - /* max supported RPC/RDMA header */ \ - sizeof(struct rpcrdma_msg) + (2 * sizeof(u32)) + \ - (sizeof(struct rpcrdma_read_chunk) * RPCRDMA_MAX_SEGS) + sizeof(u32)) +struct rpcrdma_rep { + struct ib_cqe rr_cqe; + struct rpc_rdma_cid rr_cid; + + __be32 rr_xid; + __be32 rr_vers; + __be32 rr_proc; + int rr_wc_flags; + u32 rr_inv_rkey; + struct rpcrdma_regbuf *rr_rdmabuf; + struct rpcrdma_xprt *rr_rxprt; + struct rpc_rqst *rr_rqst; + struct xdr_buf rr_hdrbuf; + struct xdr_stream rr_stream; + struct llist_node rr_node; + struct ib_recv_wr rr_recv_wr; + struct list_head rr_all; +}; + +/* To reduce the rate at which a transport invokes ib_post_recv + * (and thus the hardware doorbell rate), xprtrdma posts Receive + * WRs in batches. + * + * Setting this to zero disables Receive post batching. + */ +enum { + RPCRDMA_MAX_RECV_BATCH = 7, +}; -struct rpcrdma_buffer; +/* struct rpcrdma_sendctx - DMA mapped SGEs to unmap after Send completes + */ +struct rpcrdma_req; +struct rpcrdma_sendctx { + struct ib_cqe sc_cqe; + struct rpc_rdma_cid sc_cid; + struct rpcrdma_req *sc_req; + unsigned int sc_unmap_count; + struct ib_sge sc_sges[]; +}; -struct rpcrdma_rep { - unsigned int rr_len; /* actual received reply length */ - struct rpcrdma_buffer *rr_buffer; /* home base for this structure */ - struct rpc_xprt *rr_xprt; /* needed for request/reply matching */ - void (*rr_func)(struct rpcrdma_rep *);/* called by tasklet in softint */ - struct list_head rr_list; /* tasklet list */ - wait_queue_head_t rr_unbind; /* optional unbind wait */ - struct ib_sge rr_iov; /* for posting */ - struct ib_mr *rr_handle; /* handle for mem in rr_iov */ - char rr_base[MAX_RPCRDMAHDR]; /* minimal inline receive buffer */ +/* + * struct rpcrdma_mr - external memory region metadata + * + * An external memory region is any buffer or page that is registered + * on the fly (ie, not pre-registered). + */ +struct rpcrdma_req; +struct rpcrdma_mr { + struct list_head mr_list; + struct rpcrdma_req *mr_req; + + struct ib_mr *mr_ibmr; + struct ib_device *mr_device; + struct scatterlist *mr_sg; + int mr_nents; + enum dma_data_direction mr_dir; + struct ib_cqe mr_cqe; + struct completion mr_linv_done; + union { + struct ib_reg_wr mr_regwr; + struct ib_send_wr mr_invwr; + }; + struct rpcrdma_xprt *mr_xprt; + u32 mr_handle; + u32 mr_length; + u64 mr_offset; + struct list_head mr_all; + struct rpc_rdma_cid mr_cid; }; /* @@ -146,57 +277,84 @@ struct rpcrdma_rep { * of iovs for send operations. The reason is that the iovs passed to * ib_post_{send,recv} must not be modified until the work request * completes. - * - * NOTES: - * o RPCRDMA_MAX_SEGS is the max number of addressible chunk elements we - * marshal. The number needed varies depending on the iov lists that - * are passed to us, the memory registration mode we are in, and if - * physical addressing is used, the layout. */ -struct rpcrdma_mr_seg { /* chunk descriptors */ - union { /* chunk memory handles */ - struct ib_mr *rl_mr; /* if registered directly */ - struct rpcrdma_mw { /* if registered from region */ - union { - struct ib_mw *mw; - struct ib_fmr *fmr; - struct { - struct ib_fast_reg_page_list *fr_pgl; - struct ib_mr *fr_mr; - enum { FRMR_IS_INVALID, FRMR_IS_VALID } state; - } frmr; - } r; - struct list_head mw_list; - } *rl_mw; - } mr_chunk; - u64 mr_base; /* registration result */ - u32 mr_rkey; /* registration result */ - u32 mr_len; /* length of chunk or segment */ - int mr_nsegs; /* number of segments in chunk or 0 */ - enum dma_data_direction mr_dir; /* segment mapping direction */ - dma_addr_t mr_dma; /* segment mapping address */ - size_t mr_dmalen; /* segment mapping length */ - struct page *mr_page; /* owning page, if any */ - char *mr_offset; /* kva if no page, else offset */ +/* Maximum number of page-sized "segments" per chunk list to be + * registered or invalidated. Must handle a Reply chunk: + */ +enum { + RPCRDMA_MAX_IOV_SEGS = 3, + RPCRDMA_MAX_DATA_SEGS = ((1 * 1024 * 1024) / PAGE_SIZE) + 1, + RPCRDMA_MAX_SEGS = RPCRDMA_MAX_DATA_SEGS + + RPCRDMA_MAX_IOV_SEGS, +}; + +/* Arguments for DMA mapping and registration */ +struct rpcrdma_mr_seg { + u32 mr_len; /* length of segment */ + struct page *mr_page; /* underlying struct page */ + u64 mr_offset; /* IN: page offset, OUT: iova */ +}; + +/* The Send SGE array is provisioned to send a maximum size + * inline request: + * - RPC-over-RDMA header + * - xdr_buf head iovec + * - RPCRDMA_MAX_INLINE bytes, in pages + * - xdr_buf tail iovec + * + * The actual number of array elements consumed by each RPC + * depends on the device's max_sge limit. + */ +enum { + RPCRDMA_MIN_SEND_SGES = 3, + RPCRDMA_MAX_PAGE_SGES = RPCRDMA_MAX_INLINE >> PAGE_SHIFT, + RPCRDMA_MAX_SEND_SGES = 1 + 1 + RPCRDMA_MAX_PAGE_SGES + 1, }; +struct rpcrdma_buffer; struct rpcrdma_req { - size_t rl_size; /* actual length of buffer */ - unsigned int rl_niovs; /* 0, 2 or 4 */ - unsigned int rl_nchunks; /* non-zero if chunks */ - unsigned int rl_connect_cookie; /* retry detection */ - struct rpcrdma_buffer *rl_buffer; /* home base for this structure */ - struct rpcrdma_rep *rl_reply;/* holder for reply buffer */ - struct rpcrdma_mr_seg rl_segments[RPCRDMA_MAX_SEGS];/* chunk segments */ - struct ib_sge rl_send_iov[4]; /* for active requests */ - struct ib_sge rl_iov; /* for posting */ - struct ib_mr *rl_handle; /* handle for mem in rl_iov */ - char rl_base[MAX_RPCRDMAHDR]; /* start of actual buffer */ - __u32 rl_xdr_buf[0]; /* start of returned rpc rq_buffer */ + struct list_head rl_list; + struct rpc_rqst rl_slot; + struct rpcrdma_rep *rl_reply; + struct xdr_stream rl_stream; + struct xdr_buf rl_hdrbuf; + struct ib_send_wr rl_wr; + struct rpcrdma_sendctx *rl_sendctx; + struct rpcrdma_regbuf *rl_rdmabuf; /* xprt header */ + struct rpcrdma_regbuf *rl_sendbuf; /* rq_snd_buf */ + struct rpcrdma_regbuf *rl_recvbuf; /* rq_rcv_buf */ + + struct list_head rl_all; + struct kref rl_kref; + + struct list_head rl_free_mrs; + struct list_head rl_registered; + struct rpcrdma_mr_seg rl_segments[RPCRDMA_MAX_SEGS]; }; -#define rpcr_to_rdmar(r) \ - container_of((r)->rq_buffer, struct rpcrdma_req, rl_xdr_buf[0]) + +static inline struct rpcrdma_req * +rpcr_to_rdmar(const struct rpc_rqst *rqst) +{ + return container_of(rqst, struct rpcrdma_req, rl_slot); +} + +static inline void +rpcrdma_mr_push(struct rpcrdma_mr *mr, struct list_head *list) +{ + list_add(&mr->mr_list, list); +} + +static inline struct rpcrdma_mr * +rpcrdma_mr_pop(struct list_head *list) +{ + struct rpcrdma_mr *mr; + + mr = list_first_entry_or_null(list, struct rpcrdma_mr, mr_list); + if (mr) + list_del_init(&mr->mr_list); + return mr; +} /* * struct rpcrdma_buffer -- holds list/queue of pre-registered memory for @@ -205,60 +363,57 @@ struct rpcrdma_req { * One of these is associated with a transport instance */ struct rpcrdma_buffer { - spinlock_t rb_lock; /* protects indexes */ - atomic_t rb_credits; /* most recent server credits */ - unsigned long rb_cwndscale; /* cached framework rpc_cwndscale */ - int rb_max_requests;/* client max requests */ - struct list_head rb_mws; /* optional memory windows/fmrs/frmrs */ - int rb_send_index; - struct rpcrdma_req **rb_send_bufs; - int rb_recv_index; - struct rpcrdma_rep **rb_recv_bufs; - char *rb_pool; -}; -#define rdmab_to_ia(b) (&container_of((b), struct rpcrdma_xprt, rx_buf)->rx_ia) + spinlock_t rb_lock; + struct list_head rb_send_bufs; + struct list_head rb_mrs; -/* - * Internal structure for transport instance creation. This - * exists primarily for modularity. - * - * This data should be set with mount options - */ -struct rpcrdma_create_data_internal { - struct sockaddr_storage addr; /* RDMA server address */ - unsigned int max_requests; /* max requests (slots) in flight */ - unsigned int rsize; /* mount rsize - max read hdr+data */ - unsigned int wsize; /* mount wsize - max write hdr+data */ - unsigned int inline_rsize; /* max non-rdma read data payload */ - unsigned int inline_wsize; /* max non-rdma write data payload */ - unsigned int padding; /* non-rdma write header padding */ -}; + unsigned long rb_sc_head; + unsigned long rb_sc_tail; + unsigned long rb_sc_last; + struct rpcrdma_sendctx **rb_sc_ctxs; -#define RPCRDMA_INLINE_READ_THRESHOLD(rq) \ - (rpcx_to_rdmad(rq->rq_xprt).inline_rsize) + struct list_head rb_allreqs; + struct list_head rb_all_mrs; + struct list_head rb_all_reps; -#define RPCRDMA_INLINE_WRITE_THRESHOLD(rq)\ - (rpcx_to_rdmad(rq->rq_xprt).inline_wsize) + struct llist_head rb_free_reps; -#define RPCRDMA_INLINE_PAD_VALUE(rq)\ - rpcx_to_rdmad(rq->rq_xprt).padding + __be32 rb_max_requests; + u32 rb_credits; /* most recent credit grant */ + + u32 rb_bc_srv_max_requests; + u32 rb_bc_max_requests; + + struct work_struct rb_refresh_worker; +}; /* * Statistics for RPCRDMA */ struct rpcrdma_stats { + /* accessed when sending a call */ unsigned long read_chunk_count; unsigned long write_chunk_count; unsigned long reply_chunk_count; - unsigned long long total_rdma_request; - unsigned long long total_rdma_reply; + /* rarely accessed error counters */ unsigned long long pullup_copy_count; - unsigned long long fixup_copy_count; unsigned long hardway_register_count; unsigned long failed_marshal_count; unsigned long bad_reply_count; + unsigned long mrs_recycled; + unsigned long mrs_orphaned; + unsigned long mrs_allocated; + unsigned long empty_sendctx_q; + + /* accessed when receiving a reply */ + unsigned long long total_rdma_reply; + unsigned long long fixup_copy_count; + unsigned long reply_waits_for_send; + unsigned long local_inv_needed; + unsigned long nomsg_call_count; + unsigned long bcall_count; }; /* @@ -272,82 +427,179 @@ struct rpcrdma_stats { * during unmount. */ struct rpcrdma_xprt { - struct rpc_xprt xprt; - struct rpcrdma_ia rx_ia; - struct rpcrdma_ep rx_ep; + struct rpc_xprt rx_xprt; + struct rpcrdma_ep *rx_ep; struct rpcrdma_buffer rx_buf; - struct rpcrdma_create_data_internal rx_data; - struct delayed_work rdma_connect; + struct delayed_work rx_connect_worker; + struct rpc_timeout rx_timeout; struct rpcrdma_stats rx_stats; }; -#define rpcx_to_rdmax(x) container_of(x, struct rpcrdma_xprt, xprt) -#define rpcx_to_rdmad(x) (rpcx_to_rdmax(x)->rx_data) +#define rpcx_to_rdmax(x) container_of(x, struct rpcrdma_xprt, rx_xprt) + +static inline const char * +rpcrdma_addrstr(const struct rpcrdma_xprt *r_xprt) +{ + return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_ADDR]; +} + +static inline const char * +rpcrdma_portstr(const struct rpcrdma_xprt *r_xprt) +{ + return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_PORT]; +} /* Setting this to 0 ensures interoperability with early servers. * Setting this to 1 enhances certain unaligned read/write performance. * Default is 0, see sysctl entry and rpc_rdma.c rpcrdma_convert_iovs() */ extern int xprt_rdma_pad_optimize; -/* - * Interface Adapter calls - xprtrdma/verbs.c +/* This setting controls the hunt for a supported memory + * registration strategy. */ -int rpcrdma_ia_open(struct rpcrdma_xprt *, struct sockaddr *, int); -void rpcrdma_ia_close(struct rpcrdma_ia *); +extern unsigned int xprt_rdma_memreg_strategy; /* * Endpoint calls - xprtrdma/verbs.c */ -int rpcrdma_ep_create(struct rpcrdma_ep *, struct rpcrdma_ia *, - struct rpcrdma_create_data_internal *); -int rpcrdma_ep_destroy(struct rpcrdma_ep *, struct rpcrdma_ia *); -int rpcrdma_ep_connect(struct rpcrdma_ep *, struct rpcrdma_ia *); -int rpcrdma_ep_disconnect(struct rpcrdma_ep *, struct rpcrdma_ia *); +void rpcrdma_force_disconnect(struct rpcrdma_ep *ep); +void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc); +int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt); +void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt); -int rpcrdma_ep_post(struct rpcrdma_ia *, struct rpcrdma_ep *, - struct rpcrdma_req *); -int rpcrdma_ep_post_recv(struct rpcrdma_ia *, struct rpcrdma_ep *, - struct rpcrdma_rep *); +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed); /* * Buffer calls - xprtrdma/verbs.c */ -int rpcrdma_buffer_create(struct rpcrdma_buffer *, struct rpcrdma_ep *, - struct rpcrdma_ia *, - struct rpcrdma_create_data_internal *); +struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt, + size_t size); +int rpcrdma_req_setup(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +void rpcrdma_req_destroy(struct rpcrdma_req *req); +int rpcrdma_buffer_create(struct rpcrdma_xprt *); void rpcrdma_buffer_destroy(struct rpcrdma_buffer *); +struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_xprt *r_xprt); + +struct rpcrdma_mr *rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt); +void rpcrdma_mrs_refresh(struct rpcrdma_xprt *r_xprt); struct rpcrdma_req *rpcrdma_buffer_get(struct rpcrdma_buffer *); -void rpcrdma_buffer_put(struct rpcrdma_req *); -void rpcrdma_recv_buffer_get(struct rpcrdma_req *); -void rpcrdma_recv_buffer_put(struct rpcrdma_rep *); +void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers, + struct rpcrdma_req *req); +void rpcrdma_rep_put(struct rpcrdma_buffer *buf, struct rpcrdma_rep *rep); +void rpcrdma_reply_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req); + +bool rpcrdma_regbuf_realloc(struct rpcrdma_regbuf *rb, size_t size, + gfp_t flags); +bool __rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb); + +/** + * rpcrdma_regbuf_is_mapped - check if buffer is DMA mapped + * + * Returns true if the buffer is now mapped to rb->rg_device. + */ +static inline bool rpcrdma_regbuf_is_mapped(struct rpcrdma_regbuf *rb) +{ + return rb->rg_device != NULL; +} + +/** + * rpcrdma_regbuf_dma_map - DMA-map a regbuf + * @r_xprt: controlling transport instance + * @rb: regbuf to be mapped + * + * Returns true if the buffer is currently DMA mapped. + */ +static inline bool rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb) +{ + if (likely(rpcrdma_regbuf_is_mapped(rb))) + return true; + return __rpcrdma_regbuf_dma_map(r_xprt, rb); +} -int rpcrdma_register_internal(struct rpcrdma_ia *, void *, int, - struct ib_mr **, struct ib_sge *); -int rpcrdma_deregister_internal(struct rpcrdma_ia *, - struct ib_mr *, struct ib_sge *); +/* + * Wrappers for chunk registration, shared by read/write chunk code. + */ -int rpcrdma_register_external(struct rpcrdma_mr_seg *, - int, int, struct rpcrdma_xprt *); -int rpcrdma_deregister_external(struct rpcrdma_mr_seg *, - struct rpcrdma_xprt *, void *); +static inline enum dma_data_direction +rpcrdma_data_dir(bool writing) +{ + return writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE; +} -/* - * RPC/RDMA connection management calls - xprtrdma/rpc_rdma.c +/* Memory registration calls xprtrdma/frwr_ops.c */ -void rpcrdma_conn_func(struct rpcrdma_ep *); -void rpcrdma_reply_handler(struct rpcrdma_rep *); +void frwr_reset(struct rpcrdma_req *req); +int frwr_query_device(struct rpcrdma_ep *ep, const struct ib_device *device); +int frwr_mr_init(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr); +void frwr_mr_release(struct rpcrdma_mr *mr); +struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, __be32 xid, + struct rpcrdma_mr *mr); +int frwr_send(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs); +void frwr_unmap_sync(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +void frwr_unmap_async(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +int frwr_wp_create(struct rpcrdma_xprt *r_xprt); /* * RPC/RDMA protocol calls - xprtrdma/rpc_rdma.c */ -int rpcrdma_marshal_req(struct rpc_rqst *); - -/* Temporary NFS request map cache. Created in svc_rdma.c */ -extern struct kmem_cache *svc_rdma_map_cachep; -/* WR context cache. Created in svc_rdma.c */ -extern struct kmem_cache *svc_rdma_ctxt_cachep; -/* Workqueue created in svc_rdma.c */ -extern struct workqueue_struct *svc_rdma_wq; + +enum rpcrdma_chunktype { + rpcrdma_noch = 0, + rpcrdma_noch_pullup, + rpcrdma_noch_mapped, + rpcrdma_readch, + rpcrdma_areadch, + rpcrdma_writech, + rpcrdma_replych +}; + +int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 hdrlen, + struct xdr_buf *xdr, + enum rpcrdma_chunktype rtype); +void rpcrdma_sendctx_unmap(struct rpcrdma_sendctx *sc); +int rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst); +void rpcrdma_set_max_header_sizes(struct rpcrdma_ep *ep); +void rpcrdma_reset_cwnd(struct rpcrdma_xprt *r_xprt); +void rpcrdma_complete_rqst(struct rpcrdma_rep *rep); +void rpcrdma_unpin_rqst(struct rpcrdma_rep *rep); +void rpcrdma_reply_handler(struct rpcrdma_rep *rep); + +static inline void rpcrdma_set_xdrlen(struct xdr_buf *xdr, size_t len) +{ + xdr->head[0].iov_len = len; + xdr->len = len; +} + +/* RPC/RDMA module init - xprtrdma/transport.c + */ +extern unsigned int xprt_rdma_max_inline_read; +extern unsigned int xprt_rdma_max_inline_write; +void xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap); +void xprt_rdma_free_addresses(struct rpc_xprt *xprt); +void xprt_rdma_close(struct rpc_xprt *xprt); +void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq); +int xprt_rdma_init(void); +void xprt_rdma_cleanup(void); + +/* Backchannel calls - xprtrdma/backchannel.c + */ +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +int xprt_rdma_bc_setup(struct rpc_xprt *, unsigned int); +size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *); +unsigned int xprt_rdma_bc_max_slots(struct rpc_xprt *); +void rpcrdma_bc_receive_call(struct rpcrdma_xprt *, struct rpcrdma_rep *); +int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst); +void xprt_rdma_bc_free_rqst(struct rpc_rqst *); +void xprt_rdma_bc_destroy(struct rpc_xprt *, unsigned int); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +extern struct xprt_class xprt_rdma_bc; #endif /* _LINUX_SUNRPC_XPRT_RDMA_H */ diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index ddf0602603bd..2e1fe6013361 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0 /* * linux/net/sunrpc/xprtsock.c * @@ -46,10 +47,25 @@ #include <net/checksum.h> #include <net/udp.h> #include <net/tcp.h> +#include <net/tls_prot.h> +#include <net/handshake.h> +#include <linux/bvec.h> +#include <linux/highmem.h> +#include <linux/uio.h> +#include <linux/sched/mm.h> + +#include <trace/events/sock.h> +#include <trace/events/sunrpc.h> + +#include "socklib.h" #include "sunrpc.h" static void xs_close(struct rpc_xprt *xprt); +static void xs_reset_srcport(struct sock_xprt *transport); +static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock); +static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, + struct socket *sock); /* * xprtsock tunables @@ -66,15 +82,13 @@ static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO; /* * We can register our own files under /proc/sys/sunrpc by - * calling register_sysctl_table() again. The files in that + * calling register_sysctl() again. The files in that * directory become the union of all files registered there. * * We simply need to make sure that we don't collide with * someone else's file names! */ -#ifdef RPC_DEBUG - static unsigned int min_slot_table_size = RPC_MIN_SLOT_TABLE; static unsigned int max_slot_table_size = RPC_MAX_SLOT_TABLE; static unsigned int max_tcp_slot_table_limit = RPC_MAX_SLOT_TABLE_LIMIT; @@ -83,6 +97,12 @@ static unsigned int xprt_max_resvport_limit = RPC_MAX_RESVPORT; static struct ctl_table_header *sunrpc_table_header; +static struct xprt_class xs_local_transport; +static struct xprt_class xs_udp_transport; +static struct xprt_class xs_tcp_transport; +static struct xprt_class xs_tcp_tls_transport; +static struct xprt_class xs_bc_tcp_transport; + /* * FIXME: changing the UDP slot table size should also resize the UDP * socket buffers for existing UDP transports @@ -140,20 +160,8 @@ static struct ctl_table xs_tunables_table[] = { .mode = 0644, .proc_handler = proc_dointvec_jiffies, }, - { }, }; -static struct ctl_table sunrpc_table[] = { - { - .procname = "sunrpc", - .mode = 0555, - .child = xs_tunables_table - }, - { }, -}; - -#endif - /* * Wait duration for a reply from the RPC portmapper. */ @@ -175,7 +183,6 @@ static struct ctl_table sunrpc_table[] = { * increase over time if the server is down or not responding. */ #define XS_TCP_INIT_REEST_TO (3U * HZ) -#define XS_TCP_MAX_REEST_TO (5U * 60 * HZ) /* * TCP idle timeout; client drops the transport socket if it is idle @@ -184,7 +191,12 @@ static struct ctl_table sunrpc_table[] = { */ #define XS_IDLE_DISC_TO (5U * 60 * HZ) -#ifdef RPC_DEBUG +/* + * TLS handshake timeout. + */ +#define XS_TLS_HANDSHAKE_TO (10U * HZ) + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # undef RPC_DEBUG_DATA # define RPCDBG_FACILITY RPCDBG_TRANS #endif @@ -214,63 +226,10 @@ static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count) } #endif -struct sock_xprt { - struct rpc_xprt xprt; - - /* - * Network layer - */ - struct socket * sock; - struct sock * inet; - - /* - * State of TCP reply receive - */ - __be32 tcp_fraghdr, - tcp_xid, - tcp_calldir; - - u32 tcp_offset, - tcp_reclen; - - unsigned long tcp_copied, - tcp_flags; - - /* - * Connection of transports - */ - struct delayed_work connect_worker; - struct sockaddr_storage srcaddr; - unsigned short srcport; - - /* - * UDP socket buffer size parameters - */ - size_t rcvsize, - sndsize; - - /* - * Saved socket callback addresses - */ - void (*old_data_ready)(struct sock *, int); - void (*old_state_change)(struct sock *); - void (*old_write_space)(struct sock *); -}; - -/* - * TCP receive state flags - */ -#define TCP_RCV_LAST_FRAG (1UL << 0) -#define TCP_RCV_COPY_FRAGHDR (1UL << 1) -#define TCP_RCV_COPY_XID (1UL << 2) -#define TCP_RCV_COPY_DATA (1UL << 3) -#define TCP_RCV_READ_CALLDIR (1UL << 4) -#define TCP_RCV_COPY_CALLDIR (1UL << 5) - -/* - * TCP RPC flags - */ -#define TCP_RPC_REPLY (1UL << 6) +static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) +{ + return (struct rpc_xprt *) sk->sk_user_data; +} static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt) { @@ -303,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) switch (sap->sa_family) { case AF_LOCAL: sun = xs_addr_un(xprt); - strlcpy(buf, sun->sun_path, sizeof(buf)); + if (sun->sun_path[0]) { + strscpy(buf, sun->sun_path, sizeof(buf)); + } else { + buf[0] = '@'; + strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1); + } xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); break; @@ -372,218 +336,670 @@ static void xs_free_peer_addresses(struct rpc_xprt *xprt) } } -#define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) +static size_t +xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) +{ + size_t i,n; -static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen, struct kvec *vec, unsigned int base, int more) + if (!want || !(buf->flags & XDRBUF_SPARSE_PAGES)) + return want; + n = (buf->page_base + want + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < n; i++) { + if (buf->pages[i]) + continue; + buf->bvec[i].bv_page = buf->pages[i] = alloc_page(gfp); + if (!buf->pages[i]) { + i *= PAGE_SIZE; + return i > buf->page_base ? i - buf->page_base : 0; + } + } + return want; +} + +static int +xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, + unsigned int *msg_flags, struct cmsghdr *cmsg, int ret) { - struct msghdr msg = { - .msg_name = addr, - .msg_namelen = addrlen, - .msg_flags = XS_SENDMSG_FLAGS | (more ? MSG_MORE : 0), + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; + + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + *msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -EACCES : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; + } + return ret; +} + +static int +xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags) +{ + union { + struct cmsghdr cmsg; + u8 buf[CMSG_SPACE(sizeof(u8))]; + } u; + u8 alert[2]; + struct kvec alert_kvec = { + .iov_base = alert, + .iov_len = sizeof(alert), }; - struct kvec iov = { - .iov_base = vec->iov_base + base, - .iov_len = vec->iov_len - base, + struct msghdr msg = { + .msg_flags = *msg_flags, + .msg_control = &u, + .msg_controllen = sizeof(u), }; + int ret; - if (iov.iov_len != 0) - return kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len); - return kernel_sendmsg(sock, &msg, NULL, 0, 0); + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &alert_kvec, 1, + alert_kvec.iov_len); + ret = sock_recvmsg(sock, &msg, flags); + if (ret > 0) { + if (tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT) + iov_iter_revert(&msg.msg_iter, ret); + ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg, + -EAGAIN); + } + return ret; } -static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more) +static ssize_t +xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) { - struct page **ppage; - unsigned int remainder; - int err, sent = 0; + ssize_t ret; + if (seek != 0) + iov_iter_advance(&msg->msg_iter, seek); + ret = sock_recvmsg(sock, msg, flags); + /* Handle TLS inband control message lazily */ + if (msg->msg_flags & MSG_CTRUNC) { + msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR); + if (ret == 0 || ret == -EIO) + ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags); + } + return ret > 0 ? ret + seek : ret; +} - remainder = xdr->page_len - base; - base += xdr->page_base; - ppage = xdr->pages + (base >> PAGE_SHIFT); - base &= ~PAGE_MASK; - for(;;) { - unsigned int len = min_t(unsigned int, PAGE_SIZE - base, remainder); - int flags = XS_SENDMSG_FLAGS; +static ssize_t +xs_read_kvec(struct socket *sock, struct msghdr *msg, int flags, + struct kvec *kvec, size_t count, size_t seek) +{ + iov_iter_kvec(&msg->msg_iter, ITER_DEST, kvec, 1, count); + return xs_sock_recvmsg(sock, msg, flags, seek); +} - remainder -= len; - if (remainder != 0 || more) - flags |= MSG_MORE; - err = sock->ops->sendpage(sock, *ppage, base, len, flags); - if (remainder == 0 || err != len) - break; - sent += err; - ppage++; - base = 0; - } - if (sent == 0) - return err; - if (err > 0) - sent += err; - return sent; +static ssize_t +xs_read_bvec(struct socket *sock, struct msghdr *msg, int flags, + struct bio_vec *bvec, unsigned long nr, size_t count, + size_t seek) +{ + iov_iter_bvec(&msg->msg_iter, ITER_DEST, bvec, nr, count); + return xs_sock_recvmsg(sock, msg, flags, seek); } -/** - * xs_sendpages - write pages directly to a socket - * @sock: socket to send on - * @addr: UDP only -- address of destination - * @addrlen: UDP only -- length of destination address - * @xdr: buffer containing this request - * @base: starting position in the buffer - * - */ -static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base) +static ssize_t +xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, + size_t count) +{ + iov_iter_discard(&msg->msg_iter, ITER_DEST, count); + return xs_sock_recvmsg(sock, msg, flags, 0); +} + +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE +static void +xs_flush_bvec(const struct bio_vec *bvec, size_t count, size_t seek) +{ + struct bvec_iter bi = { + .bi_size = count, + }; + struct bio_vec bv; + + bvec_iter_advance(bvec, &bi, seek & PAGE_MASK); + for_each_bvec(bv, bvec, bi, bi) + flush_dcache_page(bv.bv_page); +} +#else +static inline void +xs_flush_bvec(const struct bio_vec *bvec, size_t count, size_t seek) { - unsigned int remainder = xdr->len - base; - int err, sent = 0; +} +#endif - if (unlikely(!sock)) - return -ENOTSOCK; +static ssize_t +xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, + struct xdr_buf *buf, size_t count, size_t seek, size_t *read) +{ + size_t want, seek_init = seek, offset = 0; + ssize_t ret; - clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags); - if (base != 0) { - addr = NULL; - addrlen = 0; + want = min_t(size_t, count, buf->head[0].iov_len); + if (seek < want) { + ret = xs_read_kvec(sock, msg, flags, &buf->head[0], want, seek); + if (ret <= 0) + goto sock_err; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto out; + seek = 0; + } else { + seek -= want; + offset += want; } - if (base < xdr->head[0].iov_len || addr != NULL) { - unsigned int len = xdr->head[0].iov_len - base; - remainder -= len; - err = xs_send_kvec(sock, addr, addrlen, &xdr->head[0], base, remainder != 0); - if (remainder == 0 || err != len) + want = xs_alloc_sparse_pages( + buf, min_t(size_t, count - offset, buf->page_len), + GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); + if (seek < want) { + ret = xs_read_bvec(sock, msg, flags, buf->bvec, + xdr_buf_pagecount(buf), + want + buf->page_base, + seek + buf->page_base); + if (ret <= 0) + goto sock_err; + xs_flush_bvec(buf->bvec, ret, seek + buf->page_base); + ret -= buf->page_base; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) goto out; - sent += err; - base = 0; - } else - base -= xdr->head[0].iov_len; + if (ret != want) + goto out; + seek = 0; + } else { + seek -= want; + offset += want; + } - if (base < xdr->page_len) { - unsigned int len = xdr->page_len - base; - remainder -= len; - err = xs_send_pagedata(sock, xdr, base, remainder != 0); - if (remainder == 0 || err != len) + want = min_t(size_t, count - offset, buf->tail[0].iov_len); + if (seek < want) { + ret = xs_read_kvec(sock, msg, flags, &buf->tail[0], want, seek); + if (ret <= 0) + goto sock_err; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) goto out; - sent += err; - base = 0; - } else - base -= xdr->page_len; + if (ret != want) + goto out; + } else if (offset < seek_init) + offset = seek_init; + ret = -EMSGSIZE; +out: + *read = offset - seek_init; + return ret; +sock_err: + offset += seek; + goto out; +} + +static void +xs_read_header(struct sock_xprt *transport, struct xdr_buf *buf) +{ + if (!transport->recv.copied) { + if (buf->head[0].iov_len >= transport->recv.offset) + memcpy(buf->head[0].iov_base, + &transport->recv.xid, + transport->recv.offset); + transport->recv.copied = transport->recv.offset; + } +} + +static bool +xs_read_stream_request_done(struct sock_xprt *transport) +{ + return transport->recv.fraghdr & cpu_to_be32(RPC_LAST_STREAM_FRAGMENT); +} + +static void +xs_read_stream_check_eor(struct sock_xprt *transport, + struct msghdr *msg) +{ + if (xs_read_stream_request_done(transport)) + msg->msg_flags |= MSG_EOR; +} + +static ssize_t +xs_read_stream_request(struct sock_xprt *transport, struct msghdr *msg, + int flags, struct rpc_rqst *req) +{ + struct xdr_buf *buf = &req->rq_private_buf; + size_t want, read; + ssize_t ret; + + xs_read_header(transport, buf); + + want = transport->recv.len - transport->recv.offset; + if (want != 0) { + ret = xs_read_xdr_buf(transport->sock, msg, flags, buf, + transport->recv.copied + want, + transport->recv.copied, + &read); + transport->recv.offset += read; + transport->recv.copied += read; + } + + if (transport->recv.offset == transport->recv.len) + xs_read_stream_check_eor(transport, msg); + + if (want == 0) + return 0; + + switch (ret) { + default: + break; + case -EFAULT: + case -EMSGSIZE: + msg->msg_flags |= MSG_TRUNC; + return read; + case 0: + return -ESHUTDOWN; + } + return ret < 0 ? ret : read; +} + +static size_t +xs_read_stream_headersize(bool isfrag) +{ + if (isfrag) + return sizeof(__be32); + return 3 * sizeof(__be32); +} + +static ssize_t +xs_read_stream_header(struct sock_xprt *transport, struct msghdr *msg, + int flags, size_t want, size_t seek) +{ + struct kvec kvec = { + .iov_base = &transport->recv.fraghdr, + .iov_len = want, + }; + return xs_read_kvec(transport->sock, msg, flags, &kvec, want, seek); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static ssize_t +xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct rpc_rqst *req; + ssize_t ret; + + /* Is this transport associated with the backchannel? */ + if (!xprt->bc_serv) + return -ESHUTDOWN; + + /* Look up and lock the request corresponding to the given XID */ + req = xprt_lookup_bc_request(xprt, transport->recv.xid); + if (!req) { + printk(KERN_WARNING "Callback slot table overflowed\n"); + return -ESHUTDOWN; + } + if (transport->recv.copied && !req->rq_private_buf.len) + return -ESHUTDOWN; + + ret = xs_read_stream_request(transport, msg, flags, req); + if (msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + xprt_complete_bc_request(req, transport->recv.copied); + else + req->rq_private_buf.len = transport->recv.copied; - if (base >= xdr->tail[0].iov_len) - return sent; - err = xs_send_kvec(sock, NULL, 0, &xdr->tail[0], base, 0); + return ret; +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +static ssize_t +xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + return -ESHUTDOWN; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static ssize_t +xs_read_stream_reply(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct rpc_rqst *req; + ssize_t ret = 0; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->queue_lock); + req = xprt_lookup_rqst(xprt, transport->recv.xid); + if (!req || (transport->recv.copied && !req->rq_private_buf.len)) { + msg->msg_flags |= MSG_TRUNC; + goto out; + } + xprt_pin_rqst(req); + spin_unlock(&xprt->queue_lock); + + ret = xs_read_stream_request(transport, msg, flags, req); + + spin_lock(&xprt->queue_lock); + if (msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + xprt_complete_rqst(req->rq_task, transport->recv.copied); + else + req->rq_private_buf.len = transport->recv.copied; + xprt_unpin_rqst(req); out: - if (sent == 0) - return err; - if (err > 0) - sent += err; - return sent; + spin_unlock(&xprt->queue_lock); + return ret; } -static void xs_nospace_callback(struct rpc_task *task) +static ssize_t +xs_read_stream(struct sock_xprt *transport, int flags) { - struct sock_xprt *transport = container_of(task->tk_rqstp->rq_xprt, struct sock_xprt, xprt); + struct msghdr msg = { 0 }; + size_t want, read = 0; + ssize_t ret = 0; + + if (transport->recv.len == 0) { + want = xs_read_stream_headersize(transport->recv.copied != 0); + ret = xs_read_stream_header(transport, &msg, flags, want, + transport->recv.offset); + if (ret <= 0) + goto out_err; + transport->recv.offset = ret; + if (transport->recv.offset != want) + return transport->recv.offset; + transport->recv.len = be32_to_cpu(transport->recv.fraghdr) & + RPC_FRAGMENT_SIZE_MASK; + transport->recv.offset -= sizeof(transport->recv.fraghdr); + read = ret; + } - transport->inet->sk_write_pending--; - clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + switch (be32_to_cpu(transport->recv.calldir)) { + default: + msg.msg_flags |= MSG_TRUNC; + break; + case RPC_CALL: + ret = xs_read_stream_call(transport, &msg, flags); + break; + case RPC_REPLY: + ret = xs_read_stream_reply(transport, &msg, flags); + } + if (msg.msg_flags & MSG_TRUNC) { + transport->recv.calldir = cpu_to_be32(-1); + transport->recv.copied = -1; + } + if (ret < 0) + goto out_err; + read += ret; + if (transport->recv.offset < transport->recv.len) { + if (!(msg.msg_flags & MSG_TRUNC)) + return read; + msg.msg_flags = 0; + ret = xs_read_discard(transport->sock, &msg, flags, + transport->recv.len - transport->recv.offset); + if (ret <= 0) + goto out_err; + transport->recv.offset += ret; + read += ret; + if (transport->recv.offset != transport->recv.len) + return read; + } + if (xs_read_stream_request_done(transport)) { + trace_xs_stream_read_request(transport); + transport->recv.copied = 0; + } + transport->recv.offset = 0; + transport->recv.len = 0; + return read; +out_err: + return ret != 0 ? ret : -ESHUTDOWN; +} + +static __poll_t xs_poll_socket(struct sock_xprt *transport) +{ + return transport->sock->ops->poll(transport->file, transport->sock, + NULL); +} + +static bool xs_poll_socket_readable(struct sock_xprt *transport) +{ + __poll_t events = xs_poll_socket(transport); + + return (events & (EPOLLIN | EPOLLRDNORM)) && !(events & EPOLLRDHUP); +} + +static void xs_poll_check_readable(struct sock_xprt *transport) +{ + + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; + if (!xs_poll_socket_readable(transport)) + return; + if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); } +static void xs_stream_data_receive(struct sock_xprt *transport) +{ + size_t read = 0; + ssize_t ret = 0; + + mutex_lock(&transport->recv_mutex); + if (transport->sock == NULL) + goto out; + for (;;) { + ret = xs_read_stream(transport, MSG_DONTWAIT); + if (ret < 0) + break; + read += ret; + cond_resched(); + } + if (ret == -ESHUTDOWN) + kernel_sock_shutdown(transport->sock, SHUT_RDWR); + else if (ret == -EACCES) + xprt_wake_pending_tasks(&transport->xprt, -EACCES); + else + xs_poll_check_readable(transport); +out: + mutex_unlock(&transport->recv_mutex); + trace_xs_stream_read_data(&transport->xprt, ret, read); +} + +static void xs_stream_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + unsigned int pflags = memalloc_nofs_save(); + + xs_stream_data_receive(transport); + memalloc_nofs_restore(pflags); +} + +static void +xs_stream_reset_connect(struct sock_xprt *transport) +{ + transport->recv.offset = 0; + transport->recv.len = 0; + transport->recv.copied = 0; + transport->xmit.offset = 0; +} + +static void +xs_stream_start_connect(struct sock_xprt *transport) +{ + transport->xprt.stat.connect_count++; + transport->xprt.stat.connect_start = jiffies; +} + +#define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) + /** - * xs_nospace - place task on wait queue if transmit was incomplete - * @task: task to put to sleep + * xs_nospace - handle transmit was incomplete + * @req: pointer to RPC request + * @transport: pointer to struct sock_xprt * */ -static int xs_nospace(struct rpc_task *task) +static int xs_nospace(struct rpc_rqst *req, struct sock_xprt *transport) { - struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = req->rq_xprt; - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct rpc_xprt *xprt = &transport->xprt; + struct sock *sk = transport->inet; int ret = -EAGAIN; - dprintk("RPC: %5u xmit incomplete (%u left of %u)\n", - task->tk_pid, req->rq_slen - req->rq_bytes_sent, - req->rq_slen); + trace_rpc_socket_nospace(req, transport); /* Protect against races with write_space */ - spin_lock_bh(&xprt->transport_lock); + spin_lock(&xprt->transport_lock); /* Don't race with disconnect */ if (xprt_connected(xprt)) { - if (test_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags)) { - /* - * Notify TCP that we're limited by the application - * window size - */ - set_bit(SOCK_NOSPACE, &transport->sock->flags); - transport->inet->sk_write_pending++; - /* ...and wait for more buffer space */ - xprt_wait_for_buffer_space(task, xs_nospace_callback); - } - } else { - clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + /* wait for more buffer space */ + set_bit(XPRT_SOCK_NOSPACE, &transport->sock_state); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + sk->sk_write_pending++; + xprt_wait_for_buffer_space(xprt); + } else ret = -ENOTCONN; - } - spin_unlock_bh(&xprt->transport_lock); + spin_unlock(&xprt->transport_lock); + return ret; +} + +static int xs_sock_nospace(struct rpc_rqst *req) +{ + struct sock_xprt *transport = + container_of(req->rq_xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + int ret = -EAGAIN; + + lock_sock(sk); + if (!sock_writeable(sk)) + ret = xs_nospace(req, transport); + release_sock(sk); + return ret; +} + +static int xs_stream_nospace(struct rpc_rqst *req, bool vm_wait) +{ + struct sock_xprt *transport = + container_of(req->rq_xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + int ret = -EAGAIN; + + if (vm_wait) + return -ENOBUFS; + lock_sock(sk); + if (!sk_stream_memory_free(sk)) + ret = xs_nospace(req, transport); + release_sock(sk); return ret; } +static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf) +{ + return xdr_alloc_bvec(buf, rpc_task_gfp_mask()); +} + +static void xs_stream_abort_send_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + + if (transport->xmit.offset != 0 && + !test_bit(XPRT_CLOSE_WAIT, &xprt->state)) + xprt_force_disconnect(xprt); +} + +/* + * Determine if the previous message in the stream was aborted before it + * could complete transmission. + */ +static bool +xs_send_request_was_aborted(struct sock_xprt *transport, struct rpc_rqst *req) +{ + return transport->xmit.offset != 0 && req->rq_bytes_sent == 0; +} + /* - * Construct a stream transport record marker in @buf. + * Return the stream record marker field for a record of length < 2^31-1 */ -static inline void xs_encode_stream_record_marker(struct xdr_buf *buf) +static rpc_fraghdr +xs_stream_record_marker(struct xdr_buf *xdr) { - u32 reclen = buf->len - sizeof(rpc_fraghdr); - rpc_fraghdr *base = buf->head[0].iov_base; - *base = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | reclen); + if (!xdr->len) + return 0; + return cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | (u32)xdr->len); } /** * xs_local_send_request - write an RPC request to an AF_LOCAL socket - * @task: RPC task that manages the state of an RPC request + * @req: pointer to RPC request * * Return values: * 0: The request has been sent * EAGAIN: The socket was blocked, please call again later to * complete the request * ENOTCONN: Caller needs to invoke connect logic then call again - * other: Some other error occured, the request was not sent + * other: Some other error occurred, the request was not sent */ -static int xs_local_send_request(struct rpc_task *task) +static int xs_local_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; + rpc_fraghdr rm = xs_stream_record_marker(xdr); + unsigned int msglen = rm ? req->rq_slen + sizeof(rm) : req->rq_slen; + struct msghdr msg = { + .msg_flags = XS_SENDMSG_FLAGS, + }; + bool vm_wait; + unsigned int sent; int status; - xs_encode_stream_record_marker(&req->rq_snd_buf); + /* Close the stream if the previous transmission was incomplete */ + if (xs_send_request_was_aborted(transport, req)) { + xprt_force_disconnect(xprt); + return -ENOTCONN; + } xs_pktdump("packet data:", req->rq_svec->iov_base, req->rq_svec->iov_len); - status = xs_sendpages(transport->sock, NULL, 0, - xdr, req->rq_bytes_sent); + vm_wait = sk_stream_is_writeable(transport->inet) ? true : false; + + req->rq_xtime = ktime_get(); + status = xprt_sock_sendmsg(transport->sock, &msg, xdr, + transport->xmit.offset, rm, &sent); dprintk("RPC: %s(%u) = %d\n", - __func__, xdr->len - req->rq_bytes_sent, status); - if (likely(status >= 0)) { - req->rq_bytes_sent += status; - req->rq_xmit_bytes_sent += status; - if (likely(req->rq_bytes_sent >= req->rq_slen)) { - req->rq_bytes_sent = 0; + __func__, xdr->len - transport->xmit.offset, status); + + if (likely(sent > 0) || status == 0) { + transport->xmit.offset += sent; + req->rq_bytes_sent = transport->xmit.offset; + if (likely(req->rq_bytes_sent >= msglen)) { + req->rq_xmit_bytes_sent += transport->xmit.offset; + transport->xmit.offset = 0; return 0; } status = -EAGAIN; + vm_wait = false; } switch (status) { case -EAGAIN: - status = xs_nospace(task); + status = xs_stream_nospace(req, vm_wait); break; default: dprintk("RPC: sendmsg returned unrecognized error %d\n", -status); + fallthrough; case -EPIPE: - xs_close(xprt); + xprt_force_disconnect(xprt); status = -ENOTCONN; } @@ -592,7 +1008,7 @@ static int xs_local_send_request(struct rpc_task *task) /** * xs_udp_send_request - write an RPC request to a UDP socket - * @task: address of RPC task that manages the state of an RPC request + * @req: pointer to RPC request * * Return values: * 0: The request has been sent @@ -601,12 +1017,17 @@ static int xs_local_send_request(struct rpc_task *task) * ENOTCONN: Caller needs to invoke connect logic then call again * other: Some other error occurred, the request was not sent */ -static int xs_udp_send_request(struct rpc_task *task) +static int xs_udp_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; + struct msghdr msg = { + .msg_name = xs_addr(xprt), + .msg_namelen = xprt->addrlen, + .msg_flags = XS_SENDMSG_FLAGS, + }; + unsigned int sent; int status; xs_pktdump("packet data:", @@ -615,63 +1036,62 @@ static int xs_udp_send_request(struct rpc_task *task) if (!xprt_bound(xprt)) return -ENOTCONN; - status = xs_sendpages(transport->sock, - xs_addr(xprt), - xprt->addrlen, xdr, - req->rq_bytes_sent); + + if (!xprt_request_get_cong(xprt, req)) + return -EBADSLT; + + status = xdr_alloc_bvec(xdr, rpc_task_gfp_mask()); + if (status < 0) + return status; + req->rq_xtime = ktime_get(); + status = xprt_sock_sendmsg(transport->sock, &msg, xdr, 0, 0, &sent); dprintk("RPC: xs_udp_send_request(%u) = %d\n", - xdr->len - req->rq_bytes_sent, status); + xdr->len, status); - if (status >= 0) { - req->rq_xmit_bytes_sent += status; - if (status >= req->rq_slen) + /* firewall is blocking us, don't return -EAGAIN or we end up looping */ + if (status == -EPERM) + goto process_status; + + if (status == -EAGAIN && sock_writeable(transport->inet)) + status = -ENOBUFS; + + if (sent > 0 || status == 0) { + req->rq_xmit_bytes_sent += sent; + if (sent >= req->rq_slen) return 0; /* Still some bytes left; set up for a retry later. */ status = -EAGAIN; } +process_status: switch (status) { case -ENOTSOCK: status = -ENOTCONN; /* Should we call xs_close() here? */ break; case -EAGAIN: - status = xs_nospace(task); + status = xs_sock_nospace(req); break; - default: - dprintk("RPC: sendmsg returned unrecognized error %d\n", - -status); case -ENETUNREACH: + case -ENOBUFS: case -EPIPE: case -ECONNREFUSED: + case -EPERM: /* When the server has died, an ICMP port unreachable message * prompts ECONNREFUSED. */ - clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); } return status; } /** - * xs_tcp_shutdown - gracefully shut down a TCP socket - * @xprt: transport - * - * Initiates a graceful shutdown of the TCP socket by calling the - * equivalent of shutdown(SHUT_WR); - */ -static void xs_tcp_shutdown(struct rpc_xprt *xprt) -{ - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - struct socket *sock = transport->sock; - - if (sock != NULL) - kernel_sock_shutdown(sock, SHUT_WR); -} - -/** * xs_tcp_send_request - write an RPC request to a TCP socket - * @task: address of RPC task that manages the state of an RPC request + * @req: pointer to RPC request * * Return values: * 0: The request has been sent @@ -683,47 +1103,71 @@ static void xs_tcp_shutdown(struct rpc_xprt *xprt) * XXX: In the case of soft timeouts, should we eventually give up * if sendmsg is not able to make progress? */ -static int xs_tcp_send_request(struct rpc_task *task) +static int xs_tcp_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; + rpc_fraghdr rm = xs_stream_record_marker(xdr); + unsigned int msglen = rm ? req->rq_slen + sizeof(rm) : req->rq_slen; + struct msghdr msg = { + .msg_flags = XS_SENDMSG_FLAGS, + }; + bool vm_wait; + unsigned int sent; int status; - xs_encode_stream_record_marker(&req->rq_snd_buf); + /* Close the stream if the previous transmission was incomplete */ + if (xs_send_request_was_aborted(transport, req)) { + if (transport->sock != NULL) + kernel_sock_shutdown(transport->sock, SHUT_RDWR); + return -ENOTCONN; + } + if (!transport->inet) + return -ENOTCONN; xs_pktdump("packet data:", req->rq_svec->iov_base, req->rq_svec->iov_len); + if (test_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state)) + xs_tcp_set_socket_timeouts(xprt, transport->sock); + + xs_set_srcport(transport, transport->sock); + /* Continue transmitting the packet/record. We must be careful * to cope with writespace callbacks arriving _after_ we have * called sendmsg(). */ - while (1) { - status = xs_sendpages(transport->sock, - NULL, 0, xdr, req->rq_bytes_sent); + req->rq_xtime = ktime_get(); + tcp_sock_set_cork(transport->inet, true); - dprintk("RPC: xs_tcp_send_request(%u) = %d\n", - xdr->len - req->rq_bytes_sent, status); + vm_wait = sk_stream_is_writeable(transport->inet) ? true : false; - if (unlikely(status < 0)) - break; + do { + status = xprt_sock_sendmsg(transport->sock, &msg, xdr, + transport->xmit.offset, rm, &sent); + + dprintk("RPC: xs_tcp_send_request(%u) = %d\n", + xdr->len - transport->xmit.offset, status); /* If we've sent the entire packet, immediately * reset the count of bytes sent. */ - req->rq_bytes_sent += status; - req->rq_xmit_bytes_sent += status; - if (likely(req->rq_bytes_sent >= req->rq_slen)) { - req->rq_bytes_sent = 0; + transport->xmit.offset += sent; + req->rq_bytes_sent = transport->xmit.offset; + if (likely(req->rq_bytes_sent >= msglen)) { + req->rq_xmit_bytes_sent += transport->xmit.offset; + transport->xmit.offset = 0; + if (atomic_long_read(&xprt->xmit_queuelen) == 1) + tcp_sock_set_cork(transport->inet, false); return 0; } - if (status != 0) - continue; - status = -EAGAIN; - break; - } + WARN_ON_ONCE(sent == 0 && status == 0); + + if (sent > 0) + vm_wait = false; + + } while (status == 0); switch (status) { case -ENOTSOCK: @@ -731,56 +1175,29 @@ static int xs_tcp_send_request(struct rpc_task *task) /* Should we call xs_close() here? */ break; case -EAGAIN: - status = xs_nospace(task); + status = xs_stream_nospace(req, vm_wait); break; - default: - dprintk("RPC: sendmsg returned unrecognized error %d\n", - -status); case -ECONNRESET: - xs_tcp_shutdown(xprt); case -ECONNREFUSED: case -ENOTCONN: + case -EADDRINUSE: + case -ENOBUFS: case -EPIPE: - clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); } return status; } -/** - * xs_tcp_release_xprt - clean up after a tcp transmission - * @xprt: transport - * @task: rpc task - * - * This cleans up if an error causes us to abort the transmission of a request. - * In this case, the socket may need to be reset in order to avoid confusing - * the server. - */ -static void xs_tcp_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) -{ - struct rpc_rqst *req; - - if (task != xprt->snd_task) - return; - if (task == NULL) - goto out_release; - req = task->tk_rqstp; - if (req == NULL) - goto out_release; - if (req->rq_bytes_sent == 0) - goto out_release; - if (req->rq_bytes_sent == req->rq_snd_buf.len) - goto out_release; - set_bit(XPRT_CLOSE_WAIT, &xprt->state); -out_release: - xprt_release_xprt(xprt, task); -} - static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk) { transport->old_data_ready = sk->sk_data_ready; transport->old_state_change = sk->sk_state_change; transport->old_write_space = sk->sk_write_space; + transport->old_error_report = sk->sk_error_report; } static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *sk) @@ -788,30 +1205,114 @@ static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *s sk->sk_data_ready = transport->old_data_ready; sk->sk_state_change = transport->old_state_change; sk->sk_write_space = transport->old_write_space; + sk->sk_error_report = transport->old_error_report; +} + +static void xs_sock_reset_state_flags(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + transport->xprt_err = 0; + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state); + clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state); + clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state); + clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state); + clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); +} + +static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr) +{ + set_bit(nr, &transport->sock_state); + queue_work(xprtiod_workqueue, &transport->error_worker); +} + +static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt) +{ + xprt->connect_cookie++; + smp_mb__before_atomic(); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + clear_bit(XPRT_CLOSING, &xprt->state); + xs_sock_reset_state_flags(xprt); + smp_mb__after_atomic(); +} + +/** + * xs_error_report - callback to handle TCP socket state errors + * @sk: socket + * + * Note: we don't call sock_error() since there may be a rpc_task + * using the socket, and so we don't want to clear sk->sk_err. + */ +static void xs_error_report(struct sock *sk) +{ + struct sock_xprt *transport; + struct rpc_xprt *xprt; + + if (!(xprt = xprt_from_sock(sk))) + return; + + transport = container_of(xprt, struct sock_xprt, xprt); + transport->xprt_err = -sk->sk_err; + if (transport->xprt_err == 0) + return; + dprintk("RPC: xs_error_report client %p, error=%d...\n", + xprt, -transport->xprt_err); + trace_rpc_socket_error(xprt, sk->sk_socket, transport->xprt_err); + + /* barrier ensures xprt_err is set before XPRT_SOCK_WAKE_ERROR */ + smp_mb__before_atomic(); + xs_run_error_worker(transport, XPRT_SOCK_WAKE_ERROR); } static void xs_reset_transport(struct sock_xprt *transport) { struct socket *sock = transport->sock; struct sock *sk = transport->inet; + struct rpc_xprt *xprt = &transport->xprt; + struct file *filp = transport->file; if (sk == NULL) return; + /* + * Make sure we're calling this in a context from which it is safe + * to call __fput_sync(). In practice that means rpciod and the + * system workqueue. + */ + if (!(current->flags & PF_WQ_WORKER)) { + WARN_ON_ONCE(1); + set_bit(XPRT_CLOSE_WAIT, &xprt->state); + return; + } - transport->srcport = 0; + if (atomic_read(&transport->xprt.swapper)) + sk_clear_memalloc(sk); + + tls_handshake_cancel(sk); - write_lock_bh(&sk->sk_callback_lock); + kernel_sock_shutdown(sock, SHUT_RDWR); + + mutex_lock(&transport->recv_mutex); + lock_sock(sk); transport->inet = NULL; transport->sock = NULL; + transport->file = NULL; sk->sk_user_data = NULL; + sk->sk_sndtimeo = 0; xs_restore_old_callbacks(transport, sk); - write_unlock_bh(&sk->sk_callback_lock); + xprt_clear_connected(xprt); + xs_sock_reset_connection_flags(xprt); + /* Reset stream record info */ + xs_stream_reset_connect(transport); + release_sock(sk); + mutex_unlock(&transport->recv_mutex); - sk->sk_no_check = 0; + trace_rpc_socket_close(xprt, sock); + __fput_sync(filp); - sock_release(sock); + xprt_disconnect_done(xprt); } /** @@ -830,31 +1331,23 @@ static void xs_close(struct rpc_xprt *xprt) dprintk("RPC: xs_close xprt %p\n", xprt); + if (transport->sock) + tls_handshake_close(transport->sock); xs_reset_transport(transport); xprt->reestablish_timeout = 0; - - smp_mb__before_clear_bit(); - clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); - clear_bit(XPRT_CLOSE_WAIT, &xprt->state); - clear_bit(XPRT_CLOSING, &xprt->state); - smp_mb__after_clear_bit(); - xprt_disconnect_done(xprt); } -static void xs_tcp_close(struct rpc_xprt *xprt) +static void xs_inject_disconnect(struct rpc_xprt *xprt) { - if (test_and_clear_bit(XPRT_CONNECTION_CLOSE, &xprt->state)) - xs_close(xprt); - else - xs_tcp_shutdown(xprt); + dprintk("RPC: injecting transport disconnect on xprt=%p\n", + xprt); + xprt_disconnect_done(xprt); } -static void xs_local_destroy(struct rpc_xprt *xprt) +static void xs_xprt_free(struct rpc_xprt *xprt) { - xs_close(xprt); xs_free_peer_addresses(xprt); xprt_free(xprt); - module_put(THIS_MODULE); } /** @@ -864,140 +1357,54 @@ static void xs_local_destroy(struct rpc_xprt *xprt) */ static void xs_destroy(struct rpc_xprt *xprt) { - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); dprintk("RPC: xs_destroy xprt %p\n", xprt); cancel_delayed_work_sync(&transport->connect_worker); - - xs_local_destroy(xprt); -} - -static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) -{ - return (struct rpc_xprt *) sk->sk_user_data; -} - -static int xs_local_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) -{ - struct xdr_skb_reader desc = { - .skb = skb, - .offset = sizeof(rpc_fraghdr), - .count = skb->len - sizeof(rpc_fraghdr), - }; - - if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) - return -1; - if (desc.count) - return -1; - return 0; -} - -/** - * xs_local_data_ready - "data ready" callback for AF_LOCAL sockets - * @sk: socket with data to read - * @len: how much data to read - * - * Currently this assumes we can read the whole reply in a single gulp. - */ -static void xs_local_data_ready(struct sock *sk, int len) -{ - struct rpc_task *task; - struct rpc_xprt *xprt; - struct rpc_rqst *rovr; - struct sk_buff *skb; - int err, repsize, copied; - u32 _xid; - __be32 *xp; - - read_lock_bh(&sk->sk_callback_lock); - dprintk("RPC: %s...\n", __func__); - xprt = xprt_from_sock(sk); - if (xprt == NULL) - goto out; - - skb = skb_recv_datagram(sk, 0, 1, &err); - if (skb == NULL) - goto out; - - repsize = skb->len - sizeof(rpc_fraghdr); - if (repsize < 4) { - dprintk("RPC: impossible RPC reply size %d\n", repsize); - goto dropit; - } - - /* Copy the XID from the skb... */ - xp = skb_header_pointer(skb, sizeof(rpc_fraghdr), sizeof(_xid), &_xid); - if (xp == NULL) - goto dropit; - - /* Look up and lock the request corresponding to the given XID */ - spin_lock(&xprt->transport_lock); - rovr = xprt_lookup_rqst(xprt, *xp); - if (!rovr) - goto out_unlock; - task = rovr->rq_task; - - copied = rovr->rq_private_buf.buflen; - if (copied > repsize) - copied = repsize; - - if (xs_local_copy_to_xdr(&rovr->rq_private_buf, skb)) { - dprintk("RPC: sk_buff copy failed\n"); - goto out_unlock; - } - - xprt_complete_rqst(task, copied); - - out_unlock: - spin_unlock(&xprt->transport_lock); - dropit: - skb_free_datagram(sk, skb); - out: - read_unlock_bh(&sk->sk_callback_lock); + xs_close(xprt); + cancel_work_sync(&transport->recv_worker); + cancel_work_sync(&transport->error_worker); + xs_xprt_free(xprt); + module_put(THIS_MODULE); } /** - * xs_udp_data_ready - "data ready" callback for UDP sockets - * @sk: socket with data to read - * @len: how much data to read + * xs_udp_data_read_skb - receive callback for UDP sockets + * @xprt: transport + * @sk: socket + * @skb: skbuff * */ -static void xs_udp_data_ready(struct sock *sk, int len) +static void xs_udp_data_read_skb(struct rpc_xprt *xprt, + struct sock *sk, + struct sk_buff *skb) { struct rpc_task *task; - struct rpc_xprt *xprt; struct rpc_rqst *rovr; - struct sk_buff *skb; - int err, repsize, copied; + int repsize, copied; u32 _xid; __be32 *xp; - read_lock_bh(&sk->sk_callback_lock); - dprintk("RPC: xs_udp_data_ready...\n"); - if (!(xprt = xprt_from_sock(sk))) - goto out; - - if ((skb = skb_recv_datagram(sk, 0, 1, &err)) == NULL) - goto out; - - repsize = skb->len - sizeof(struct udphdr); + repsize = skb->len; if (repsize < 4) { dprintk("RPC: impossible RPC reply size %d!\n", repsize); - goto dropit; + return; } /* Copy the XID from the skb... */ - xp = skb_header_pointer(skb, sizeof(struct udphdr), - sizeof(_xid), &_xid); + xp = skb_header_pointer(skb, 0, sizeof(_xid), &_xid); if (xp == NULL) - goto dropit; + return; /* Look up and lock the request corresponding to the given XID */ - spin_lock(&xprt->transport_lock); + spin_lock(&xprt->queue_lock); rovr = xprt_lookup_rqst(xprt, *xp); if (!rovr) goto out_unlock; + xprt_pin_rqst(rovr); + xprt_update_rtt(rovr->rq_task); + spin_unlock(&xprt->queue_lock); task = rovr->rq_task; if ((copied = rovr->rq_private_buf.buflen) > repsize) @@ -1005,472 +1412,124 @@ static void xs_udp_data_ready(struct sock *sk, int len) /* Suck it into the iovec, verify checksum if not done by hw. */ if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb)) { - UDPX_INC_STATS_BH(sk, UDP_MIB_INERRORS); - goto out_unlock; + spin_lock(&xprt->queue_lock); + __UDPX_INC_STATS(sk, UDP_MIB_INERRORS); + goto out_unpin; } - UDPX_INC_STATS_BH(sk, UDP_MIB_INDATAGRAMS); + spin_lock(&xprt->transport_lock); xprt_adjust_cwnd(xprt, task, copied); + spin_unlock(&xprt->transport_lock); + spin_lock(&xprt->queue_lock); xprt_complete_rqst(task, copied); - + __UDPX_INC_STATS(sk, UDP_MIB_INDATAGRAMS); +out_unpin: + xprt_unpin_rqst(rovr); out_unlock: - spin_unlock(&xprt->transport_lock); - dropit: - skb_free_datagram(sk, skb); - out: - read_unlock_bh(&sk->sk_callback_lock); -} - -/* - * Helper function to force a TCP close if the server is sending - * junk and/or it has put us in CLOSE_WAIT - */ -static void xs_tcp_force_close(struct rpc_xprt *xprt) -{ - set_bit(XPRT_CONNECTION_CLOSE, &xprt->state); - xprt_force_disconnect(xprt); -} - -static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - size_t len, used; - char *p; - - p = ((char *) &transport->tcp_fraghdr) + transport->tcp_offset; - len = sizeof(transport->tcp_fraghdr) - transport->tcp_offset; - used = xdr_skb_read_bits(desc, p, len); - transport->tcp_offset += used; - if (used != len) - return; - - transport->tcp_reclen = ntohl(transport->tcp_fraghdr); - if (transport->tcp_reclen & RPC_LAST_STREAM_FRAGMENT) - transport->tcp_flags |= TCP_RCV_LAST_FRAG; - else - transport->tcp_flags &= ~TCP_RCV_LAST_FRAG; - transport->tcp_reclen &= RPC_FRAGMENT_SIZE_MASK; - - transport->tcp_flags &= ~TCP_RCV_COPY_FRAGHDR; - transport->tcp_offset = 0; - - /* Sanity check of the record length */ - if (unlikely(transport->tcp_reclen < 8)) { - dprintk("RPC: invalid TCP record fragment length\n"); - xs_tcp_force_close(xprt); - return; - } - dprintk("RPC: reading TCP record fragment of length %d\n", - transport->tcp_reclen); + spin_unlock(&xprt->queue_lock); } -static void xs_tcp_check_fraghdr(struct sock_xprt *transport) +static void xs_udp_data_receive(struct sock_xprt *transport) { - if (transport->tcp_offset == transport->tcp_reclen) { - transport->tcp_flags |= TCP_RCV_COPY_FRAGHDR; - transport->tcp_offset = 0; - if (transport->tcp_flags & TCP_RCV_LAST_FRAG) { - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - transport->tcp_flags |= TCP_RCV_COPY_XID; - transport->tcp_copied = 0; - } - } -} - -static inline void xs_tcp_read_xid(struct sock_xprt *transport, struct xdr_skb_reader *desc) -{ - size_t len, used; - char *p; - - len = sizeof(transport->tcp_xid) - transport->tcp_offset; - dprintk("RPC: reading XID (%Zu bytes)\n", len); - p = ((char *) &transport->tcp_xid) + transport->tcp_offset; - used = xdr_skb_read_bits(desc, p, len); - transport->tcp_offset += used; - if (used != len) - return; - transport->tcp_flags &= ~TCP_RCV_COPY_XID; - transport->tcp_flags |= TCP_RCV_READ_CALLDIR; - transport->tcp_copied = 4; - dprintk("RPC: reading %s XID %08x\n", - (transport->tcp_flags & TCP_RPC_REPLY) ? "reply for" - : "request with", - ntohl(transport->tcp_xid)); - xs_tcp_check_fraghdr(transport); -} - -static inline void xs_tcp_read_calldir(struct sock_xprt *transport, - struct xdr_skb_reader *desc) -{ - size_t len, used; - u32 offset; - char *p; + struct sk_buff *skb; + struct sock *sk; + int err; - /* - * We want transport->tcp_offset to be 8 at the end of this routine - * (4 bytes for the xid and 4 bytes for the call/reply flag). - * When this function is called for the first time, - * transport->tcp_offset is 4 (after having already read the xid). - */ - offset = transport->tcp_offset - sizeof(transport->tcp_xid); - len = sizeof(transport->tcp_calldir) - offset; - dprintk("RPC: reading CALL/REPLY flag (%Zu bytes)\n", len); - p = ((char *) &transport->tcp_calldir) + offset; - used = xdr_skb_read_bits(desc, p, len); - transport->tcp_offset += used; - if (used != len) - return; - transport->tcp_flags &= ~TCP_RCV_READ_CALLDIR; - /* - * We don't yet have the XDR buffer, so we will write the calldir - * out after we get the buffer from the 'struct rpc_rqst' - */ - switch (ntohl(transport->tcp_calldir)) { - case RPC_REPLY: - transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; - transport->tcp_flags |= TCP_RCV_COPY_DATA; - transport->tcp_flags |= TCP_RPC_REPLY; - break; - case RPC_CALL: - transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; - transport->tcp_flags |= TCP_RCV_COPY_DATA; - transport->tcp_flags &= ~TCP_RPC_REPLY; - break; - default: - dprintk("RPC: invalid request message type\n"); - xs_tcp_force_close(&transport->xprt); + mutex_lock(&transport->recv_mutex); + sk = transport->inet; + if (sk == NULL) + goto out; + for (;;) { + skb = skb_recv_udp(sk, MSG_DONTWAIT, &err); + if (skb == NULL) + break; + xs_udp_data_read_skb(&transport->xprt, sk, skb); + consume_skb(skb); + cond_resched(); } - xs_tcp_check_fraghdr(transport); + xs_poll_check_readable(transport); +out: + mutex_unlock(&transport->recv_mutex); } -static inline void xs_tcp_read_common(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc, - struct rpc_rqst *req) +static void xs_udp_data_receive_workfn(struct work_struct *work) { struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct xdr_buf *rcvbuf; - size_t len; - ssize_t r; + container_of(work, struct sock_xprt, recv_worker); + unsigned int pflags = memalloc_nofs_save(); - rcvbuf = &req->rq_private_buf; - - if (transport->tcp_flags & TCP_RCV_COPY_CALLDIR) { - /* - * Save the RPC direction in the XDR buffer - */ - memcpy(rcvbuf->head[0].iov_base + transport->tcp_copied, - &transport->tcp_calldir, - sizeof(transport->tcp_calldir)); - transport->tcp_copied += sizeof(transport->tcp_calldir); - transport->tcp_flags &= ~TCP_RCV_COPY_CALLDIR; - } - - len = desc->count; - if (len > transport->tcp_reclen - transport->tcp_offset) { - struct xdr_skb_reader my_desc; - - len = transport->tcp_reclen - transport->tcp_offset; - memcpy(&my_desc, desc, sizeof(my_desc)); - my_desc.count = len; - r = xdr_partial_copy_from_skb(rcvbuf, transport->tcp_copied, - &my_desc, xdr_skb_read_bits); - desc->count -= r; - desc->offset += r; - } else - r = xdr_partial_copy_from_skb(rcvbuf, transport->tcp_copied, - desc, xdr_skb_read_bits); - - if (r > 0) { - transport->tcp_copied += r; - transport->tcp_offset += r; - } - if (r != len) { - /* Error when copying to the receive buffer, - * usually because we weren't able to allocate - * additional buffer pages. All we can do now - * is turn off TCP_RCV_COPY_DATA, so the request - * will not receive any additional updates, - * and time out. - * Any remaining data from this record will - * be discarded. - */ - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - dprintk("RPC: XID %08x truncated request\n", - ntohl(transport->tcp_xid)); - dprintk("RPC: xprt = %p, tcp_copied = %lu, " - "tcp_offset = %u, tcp_reclen = %u\n", - xprt, transport->tcp_copied, - transport->tcp_offset, transport->tcp_reclen); - return; - } - - dprintk("RPC: XID %08x read %Zd bytes\n", - ntohl(transport->tcp_xid), r); - dprintk("RPC: xprt = %p, tcp_copied = %lu, tcp_offset = %u, " - "tcp_reclen = %u\n", xprt, transport->tcp_copied, - transport->tcp_offset, transport->tcp_reclen); - - if (transport->tcp_copied == req->rq_private_buf.buflen) - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - else if (transport->tcp_offset == transport->tcp_reclen) { - if (transport->tcp_flags & TCP_RCV_LAST_FRAG) - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - } + xs_udp_data_receive(transport); + memalloc_nofs_restore(pflags); } -/* - * Finds the request corresponding to the RPC xid and invokes the common - * tcp read code to read the data. +/** + * xs_data_ready - "data ready" callback for sockets + * @sk: socket with data to read + * */ -static inline int xs_tcp_read_reply(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) +static void xs_data_ready(struct sock *sk) { - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct rpc_rqst *req; - - dprintk("RPC: read reply XID %08x\n", ntohl(transport->tcp_xid)); - - /* Find and lock the request corresponding to this xid */ - spin_lock(&xprt->transport_lock); - req = xprt_lookup_rqst(xprt, transport->tcp_xid); - if (!req) { - dprintk("RPC: XID %08x request not found!\n", - ntohl(transport->tcp_xid)); - spin_unlock(&xprt->transport_lock); - return -1; - } - - xs_tcp_read_common(xprt, desc, req); + struct rpc_xprt *xprt; - if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) - xprt_complete_rqst(req->rq_task, transport->tcp_copied); + trace_sk_data_ready(sk); - spin_unlock(&xprt->transport_lock); - return 0; -} - -#if defined(CONFIG_SUNRPC_BACKCHANNEL) -/* - * Obtains an rpc_rqst previously allocated and invokes the common - * tcp read code to read the data. The result is placed in the callback - * queue. - * If we're unable to obtain the rpc_rqst we schedule the closing of the - * connection and return -1. - */ -static inline int xs_tcp_read_callback(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct rpc_rqst *req; + xprt = xprt_from_sock(sk); + if (xprt != NULL) { + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); - req = xprt_alloc_bc_request(xprt); - if (req == NULL) { - printk(KERN_WARNING "Callback slot table overflowed\n"); - xprt_force_disconnect(xprt); - return -1; - } + trace_xs_data_ready(xprt); - req->rq_xid = transport->tcp_xid; - dprintk("RPC: read callback XID %08x\n", ntohl(req->rq_xid)); - xs_tcp_read_common(xprt, desc, req); + transport->old_data_ready(sk); - if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) { - struct svc_serv *bc_serv = xprt->bc_serv; + if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state)) + return; - /* - * Add callback request to callback list. The callback - * service sleeps on the sv_cb_waitq waiting for new - * requests. Wake it up after adding enqueing the - * request. + /* Any data means we had a useful conversation, so + * then we don't need to delay the next reconnect */ - dprintk("RPC: add callback request to list\n"); - spin_lock(&bc_serv->sv_cb_lock); - list_add(&req->rq_bc_list, &bc_serv->sv_cb_list); - spin_unlock(&bc_serv->sv_cb_lock); - wake_up(&bc_serv->sv_cb_waitq); + if (xprt->reestablish_timeout) + xprt->reestablish_timeout = 0; + if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); } - - req->rq_private_buf.len = transport->tcp_copied; - - return 0; -} - -static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - - return (transport->tcp_flags & TCP_RPC_REPLY) ? - xs_tcp_read_reply(xprt, desc) : - xs_tcp_read_callback(xprt, desc); -} -#else -static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - return xs_tcp_read_reply(xprt, desc); } -#endif /* CONFIG_SUNRPC_BACKCHANNEL */ /* - * Read data off the transport. This can be either an RPC_CALL or an - * RPC_REPLY. Relay the processing to helper functions. + * Helper function to force a TCP close if the server is sending + * junk and/or it has put us in CLOSE_WAIT */ -static void xs_tcp_read_data(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - - if (_xs_tcp_read_data(xprt, desc) == 0) - xs_tcp_check_fraghdr(transport); - else { - /* - * The transport_lock protects the request handling. - * There's no need to hold it to update the tcp_flags. - */ - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - } -} - -static inline void xs_tcp_read_discard(struct sock_xprt *transport, struct xdr_skb_reader *desc) +static void xs_tcp_force_close(struct rpc_xprt *xprt) { - size_t len; - - len = transport->tcp_reclen - transport->tcp_offset; - if (len > desc->count) - len = desc->count; - desc->count -= len; - desc->offset += len; - transport->tcp_offset += len; - dprintk("RPC: discarded %Zu bytes\n", len); - xs_tcp_check_fraghdr(transport); + xprt_force_disconnect(xprt); } -static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, unsigned int offset, size_t len) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static size_t xs_tcp_bc_maxpayload(struct rpc_xprt *xprt) { - struct rpc_xprt *xprt = rd_desc->arg.data; - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - struct xdr_skb_reader desc = { - .skb = skb, - .offset = offset, - .count = len, - }; - - dprintk("RPC: xs_tcp_data_recv started\n"); - do { - /* Read in a new fragment marker if necessary */ - /* Can we ever really expect to get completely empty fragments? */ - if (transport->tcp_flags & TCP_RCV_COPY_FRAGHDR) { - xs_tcp_read_fraghdr(xprt, &desc); - continue; - } - /* Read in the xid if necessary */ - if (transport->tcp_flags & TCP_RCV_COPY_XID) { - xs_tcp_read_xid(transport, &desc); - continue; - } - /* Read in the call/reply flag */ - if (transport->tcp_flags & TCP_RCV_READ_CALLDIR) { - xs_tcp_read_calldir(transport, &desc); - continue; - } - /* Read in the request data */ - if (transport->tcp_flags & TCP_RCV_COPY_DATA) { - xs_tcp_read_data(xprt, &desc); - continue; - } - /* Skip over any trailing bytes on short reads */ - xs_tcp_read_discard(transport, &desc); - } while (desc.count); - dprintk("RPC: xs_tcp_data_recv done\n"); - return len - desc.count; + return PAGE_SIZE; } +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ /** - * xs_tcp_data_ready - "data ready" callback for TCP sockets - * @sk: socket with data to read - * @bytes: how much data to read + * xs_local_state_change - callback to handle AF_LOCAL socket state changes + * @sk: socket whose state has changed * */ -static void xs_tcp_data_ready(struct sock *sk, int bytes) +static void xs_local_state_change(struct sock *sk) { struct rpc_xprt *xprt; - read_descriptor_t rd_desc; - int read; - - dprintk("RPC: xs_tcp_data_ready...\n"); - - read_lock_bh(&sk->sk_callback_lock); - if (!(xprt = xprt_from_sock(sk))) - goto out; - /* Any data means we had a useful conversation, so - * the we don't need to delay the next reconnect - */ - if (xprt->reestablish_timeout) - xprt->reestablish_timeout = 0; - - /* We use rd_desc to pass struct xprt to xs_tcp_data_recv */ - rd_desc.arg.data = xprt; - do { - rd_desc.count = 65536; - read = tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv); - } while (read > 0); -out: - read_unlock_bh(&sk->sk_callback_lock); -} - -/* - * Do the equivalent of linger/linger2 handling for dealing with - * broken servers that don't close the socket in a timely - * fashion - */ -static void xs_tcp_schedule_linger_timeout(struct rpc_xprt *xprt, - unsigned long timeout) -{ struct sock_xprt *transport; - if (xprt_test_and_set_connecting(xprt)) + if (!(xprt = xprt_from_sock(sk))) return; - set_bit(XPRT_CONNECTION_ABORT, &xprt->state); - transport = container_of(xprt, struct sock_xprt, xprt); - queue_delayed_work(rpciod_workqueue, &transport->connect_worker, - timeout); -} - -static void xs_tcp_cancel_linger_timeout(struct rpc_xprt *xprt) -{ - struct sock_xprt *transport; - transport = container_of(xprt, struct sock_xprt, xprt); - - if (!test_bit(XPRT_CONNECTION_ABORT, &xprt->state) || - !cancel_delayed_work(&transport->connect_worker)) - return; - clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); - xprt_clear_connecting(xprt); -} - -static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt) -{ - smp_mb__before_clear_bit(); - clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); - clear_bit(XPRT_CONNECTION_CLOSE, &xprt->state); - clear_bit(XPRT_CLOSE_WAIT, &xprt->state); - clear_bit(XPRT_CLOSING, &xprt->state); - smp_mb__after_clear_bit(); -} - -static void xs_sock_mark_closed(struct rpc_xprt *xprt) -{ - xs_sock_reset_connection_flags(xprt); - /* Mark transport as closed and wake up all pending tasks */ - xprt_disconnect_done(xprt); + if (sk->sk_shutdown & SHUTDOWN_MASK) { + clear_bit(XPRT_CONNECTED, &xprt->state); + /* Trigger the socket release */ + xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); + } } /** @@ -1481,10 +1540,10 @@ static void xs_sock_mark_closed(struct rpc_xprt *xprt) static void xs_tcp_state_change(struct sock *sk) { struct rpc_xprt *xprt; + struct sock_xprt *transport; - read_lock_bh(&sk->sk_callback_lock); if (!(xprt = xprt_from_sock(sk))) - goto out; + return; dprintk("RPC: xs_tcp_state_change client %p...\n", xprt); dprintk("RPC: state %x conn %d dead %d zapped %d sk_shutdown %d\n", sk->sk_state, xprt_connected(xprt), @@ -1492,40 +1551,37 @@ static void xs_tcp_state_change(struct sock *sk) sock_flag(sk, SOCK_ZAPPED), sk->sk_shutdown); + transport = container_of(xprt, struct sock_xprt, xprt); + trace_rpc_socket_state_change(xprt, sk->sk_socket); switch (sk->sk_state) { case TCP_ESTABLISHED: - spin_lock(&xprt->transport_lock); if (!xprt_test_and_set_connected(xprt)) { - struct sock_xprt *transport = container_of(xprt, - struct sock_xprt, xprt); - - /* Reset TCP record info */ - transport->tcp_offset = 0; - transport->tcp_reclen = 0; - transport->tcp_copied = 0; - transport->tcp_flags = - TCP_RCV_COPY_FRAGHDR | TCP_RCV_COPY_XID; - - xprt_wake_pending_tasks(xprt, -EAGAIN); + xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); + xprt_clear_connecting(xprt); + + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xs_run_error_worker(transport, XPRT_SOCK_WAKE_PENDING); } - spin_unlock(&xprt->transport_lock); break; case TCP_FIN_WAIT1: /* The client initiated a shutdown of the socket */ xprt->connect_cookie++; xprt->reestablish_timeout = 0; set_bit(XPRT_CLOSING, &xprt->state); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(XPRT_CONNECTED, &xprt->state); clear_bit(XPRT_CLOSE_WAIT, &xprt->state); - smp_mb__after_clear_bit(); - xs_tcp_schedule_linger_timeout(xprt, xs_tcp_fin_timeout); + smp_mb__after_atomic(); break; case TCP_CLOSE_WAIT: /* The server initiated a shutdown of the socket */ xprt->connect_cookie++; clear_bit(XPRT_CONNECTED, &xprt->state); - xs_tcp_force_close(xprt); + xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); + fallthrough; case TCP_CLOSING: /* * If the server closed down the connection, make sure that @@ -1536,34 +1592,38 @@ static void xs_tcp_state_change(struct sock *sk) break; case TCP_LAST_ACK: set_bit(XPRT_CLOSING, &xprt->state); - xs_tcp_schedule_linger_timeout(xprt, xs_tcp_fin_timeout); - smp_mb__before_clear_bit(); + smp_mb__before_atomic(); clear_bit(XPRT_CONNECTED, &xprt->state); - smp_mb__after_clear_bit(); + smp_mb__after_atomic(); break; case TCP_CLOSE: - xs_tcp_cancel_linger_timeout(xprt); - xs_sock_mark_closed(xprt); + if (test_and_clear_bit(XPRT_SOCK_CONNECTING, + &transport->sock_state)) { + xs_reset_srcport(transport); + xprt_clear_connecting(xprt); + } + clear_bit(XPRT_CLOSING, &xprt->state); + /* Trigger the socket release */ + xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); } - out: - read_unlock_bh(&sk->sk_callback_lock); } static void xs_write_space(struct sock *sk) { - struct socket *sock; + struct sock_xprt *transport; struct rpc_xprt *xprt; - if (unlikely(!(sock = sk->sk_socket))) + if (!sk->sk_socket) return; - clear_bit(SOCK_NOSPACE, &sock->flags); + clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags); if (unlikely(!(xprt = xprt_from_sock(sk)))) return; - if (test_and_clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags) == 0) + transport = container_of(xprt, struct sock_xprt, xprt); + if (!test_and_clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state)) return; - - xprt_write_space(xprt); + xs_run_error_worker(transport, XPRT_SOCK_WAKE_WRITE); + sk->sk_write_pending--; } /** @@ -1578,13 +1638,9 @@ static void xs_write_space(struct sock *sk) */ static void xs_udp_write_space(struct sock *sk) { - read_lock_bh(&sk->sk_callback_lock); - /* from net/core/sock.c:sock_def_write_space */ if (sock_writeable(sk)) xs_write_space(sk); - - read_unlock_bh(&sk->sk_callback_lock); } /** @@ -1599,13 +1655,9 @@ static void xs_udp_write_space(struct sock *sk) */ static void xs_tcp_write_space(struct sock *sk) { - read_lock_bh(&sk->sk_callback_lock); - /* from net/core/stream.c:sk_stream_write_space */ - if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk)) + if (sk_stream_is_writeable(sk)) xs_write_space(sk); - - read_unlock_bh(&sk->sk_callback_lock); } static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt) @@ -1648,20 +1700,47 @@ static void xs_udp_set_buffer_size(struct rpc_xprt *xprt, size_t sndsize, size_t /** * xs_udp_timer - called when a retransmit timeout occurs on a UDP transport + * @xprt: controlling transport * @task: task that timed out * * Adjust the congestion window after a retransmit timeout has occurred. */ static void xs_udp_timer(struct rpc_xprt *xprt, struct rpc_task *task) { + spin_lock(&xprt->transport_lock); xprt_adjust_cwnd(xprt, task, -ETIMEDOUT); + spin_unlock(&xprt->transport_lock); } -static unsigned short xs_get_random_port(void) +static int xs_get_random_port(void) { - unsigned short range = xprt_max_resvport - xprt_min_resvport; - unsigned short rand = (unsigned short) net_random() % range; - return rand + xprt_min_resvport; + unsigned short min = xprt_min_resvport, max = xprt_max_resvport; + unsigned short range; + unsigned short rand; + + if (max < min) + return -EADDRINUSE; + range = max - min + 1; + rand = get_random_u32_below(range); + return rand + min; +} + +static unsigned short xs_sock_getport(struct socket *sock) +{ + struct sockaddr_storage buf; + unsigned short port = 0; + + if (kernel_getsockname(sock, (struct sockaddr *)&buf) < 0) + goto out; + switch (buf.ss_family) { + case AF_INET6: + port = ntohs(((struct sockaddr_in6 *)&buf)->sin6_port); + break; + case AF_INET: + port = ntohs(((struct sockaddr_in *)&buf)->sin_port); + } +out: + return port; } /** @@ -1678,15 +1757,56 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port) xs_update_peer_port(xprt); } -static unsigned short xs_get_srcport(struct sock_xprt *transport) +static void xs_reset_srcport(struct sock_xprt *transport) +{ + transport->srcport = 0; +} + +static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock) { - unsigned short port = transport->srcport; + if (transport->srcport == 0 && transport->xprt.reuseport) + transport->srcport = xs_sock_getport(sock); +} + +static int xs_get_srcport(struct sock_xprt *transport) +{ + int port = transport->srcport; if (port == 0 && transport->xprt.resvport) port = xs_get_random_port(); return port; } +static unsigned short xs_sock_srcport(struct rpc_xprt *xprt) +{ + struct sock_xprt *sock = container_of(xprt, struct sock_xprt, xprt); + unsigned short ret = 0; + mutex_lock(&sock->recv_mutex); + if (sock->sock) + ret = xs_sock_getport(sock->sock); + mutex_unlock(&sock->recv_mutex); + return ret; +} + +static int xs_sock_srcaddr(struct rpc_xprt *xprt, char *buf, size_t buflen) +{ + struct sock_xprt *sock = container_of(xprt, struct sock_xprt, xprt); + union { + struct sockaddr sa; + struct sockaddr_storage st; + } saddr; + int ret = -ENOTCONN; + + mutex_lock(&sock->recv_mutex); + if (sock->sock) { + ret = kernel_getsockname(sock->sock, &saddr.sa); + if (ret >= 0) + ret = snprintf(buf, buflen, "%pISc", &saddr.sa); + } + mutex_unlock(&sock->recv_mutex); + return ret; +} + static unsigned short xs_next_srcport(struct sock_xprt *transport, unsigned short port) { if (transport->srcport != 0) @@ -1701,18 +1821,35 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock) { struct sockaddr_storage myaddr; int err, nloop = 0; - unsigned short port = xs_get_srcport(transport); + int port = xs_get_srcport(transport); unsigned short last; + /* + * If we are asking for any ephemeral port (i.e. port == 0 && + * transport->xprt.resvport == 0), don't bind. Let the local + * port selection happen implicitly when the socket is used + * (for example at connect time). + * + * This ensures that we can continue to establish TCP + * connections even when all local ephemeral ports are already + * a part of some TCP connection. This makes no difference + * for UDP sockets, but also doesn't harm them. + * + * If we're asking for any reserved port (i.e. port == 0 && + * transport->xprt.resvport == 1) xs_get_srcport above will + * ensure that port is non-zero and we will bind as needed. + */ + if (port <= 0) + return port; + memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen); do { rpc_set_port((struct sockaddr *)&myaddr, port); - err = kernel_bind(sock, (struct sockaddr *)&myaddr, - transport->xprt.addrlen); - if (port == 0) - break; + err = kernel_bind(sock, (struct sockaddr_unsized *)&myaddr, + transport->xprt.addrlen); if (err == 0) { - transport->srcport = port; + if (transport->xprt.reuseport) + transport->srcport = port; break; } last = port; @@ -1737,9 +1874,7 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock) */ static void xs_local_rpcbind(struct rpc_task *task) { - rcu_read_lock(); - xprt_set_bound(rcu_dereference(task->tk_client->cl_xprt)); - rcu_read_unlock(); + xprt_set_bound(task->tk_xprt); } static void xs_local_set_port(struct rpc_xprt *xprt, unsigned short port) @@ -1747,15 +1882,15 @@ static void xs_local_set_port(struct rpc_xprt *xprt, unsigned short port) } #ifdef CONFIG_DEBUG_LOCK_ALLOC -static struct lock_class_key xs_key[2]; -static struct lock_class_key xs_slock_key[2]; +static struct lock_class_key xs_key[3]; +static struct lock_class_key xs_slock_key[3]; static inline void xs_reclassify_socketu(struct socket *sock) { struct sock *sk = sock->sk; sock_lock_init_class_and_name(sk, "slock-AF_LOCAL-RPC", - &xs_slock_key[1], "sk_lock-AF_LOCAL-RPC", &xs_key[1]); + &xs_slock_key[0], "sk_lock-AF_LOCAL-RPC", &xs_key[0]); } static inline void xs_reclassify_socket4(struct socket *sock) @@ -1763,7 +1898,7 @@ static inline void xs_reclassify_socket4(struct socket *sock) struct sock *sk = sock->sk; sock_lock_init_class_and_name(sk, "slock-AF_INET-RPC", - &xs_slock_key[0], "sk_lock-AF_INET-RPC", &xs_key[0]); + &xs_slock_key[1], "sk_lock-AF_INET-RPC", &xs_key[1]); } static inline void xs_reclassify_socket6(struct socket *sock) @@ -1771,13 +1906,12 @@ static inline void xs_reclassify_socket6(struct socket *sock) struct sock *sk = sock->sk; sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC", - &xs_slock_key[1], "sk_lock-AF_INET6-RPC", &xs_key[1]); + &xs_slock_key[2], "sk_lock-AF_INET6-RPC", &xs_key[2]); } static inline void xs_reclassify_socket(int family, struct socket *sock) { - WARN_ON_ONCE(sock_owned_by_user(sock->sk)); - if (sock_owned_by_user(sock->sk)) + if (WARN_ON_ONCE(!sock_allow_reclassification(sock->sk))) return; switch (family) { @@ -1793,26 +1927,20 @@ static inline void xs_reclassify_socket(int family, struct socket *sock) } } #else -static inline void xs_reclassify_socketu(struct socket *sock) -{ -} - -static inline void xs_reclassify_socket4(struct socket *sock) -{ -} - -static inline void xs_reclassify_socket6(struct socket *sock) +static inline void xs_reclassify_socket(int family, struct socket *sock) { } +#endif -static inline void xs_reclassify_socket(int family, struct socket *sock) +static void xs_dummy_setup_socket(struct work_struct *work) { } -#endif static struct socket *xs_create_sock(struct rpc_xprt *xprt, - struct sock_xprt *transport, int family, int type, int protocol) + struct sock_xprt *transport, int family, int type, + int protocol, bool reuseport) { + struct file *filp; struct socket *sock; int err; @@ -1824,12 +1952,23 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt, } xs_reclassify_socket(family, sock); + if (reuseport) + sock_set_reuseport(sock->sk); + err = xs_bind(transport, sock); if (err) { sock_release(sock); goto out; } + if (protocol == IPPROTO_TCP) + sk_net_refcnt_upgrade(sock->sk); + + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); + if (IS_ERR(filp)) + return ERR_CAST(filp); + transport->file = filp; + return sock; out: return ERR_PTR(err); @@ -1844,14 +1983,16 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, if (!transport->inet) { struct sock *sk = sock->sk; - write_lock_bh(&sk->sk_callback_lock); + lock_sock(sk); xs_save_old_callbacks(transport, sk); sk->sk_user_data = xprt; - sk->sk_data_ready = xs_local_data_ready; + sk->sk_data_ready = xs_data_ready; sk->sk_write_space = xs_udp_write_space; - sk->sk_allocation = GFP_ATOMIC; + sk->sk_state_change = xs_local_state_change; + sk->sk_error_report = xs_error_report; + sk->sk_use_task_frag = false; xprt_clear_connected(xprt); @@ -1859,30 +2000,25 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, transport->sock = sock; transport->inet = sk; - write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); } - /* Tell the socket layer to start connecting... */ - xprt->stat.connect_count++; - xprt->stat.connect_start = jiffies; - return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); + xs_stream_start_connect(transport); + + return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), xprt->addrlen, 0); } /** * xs_local_setup_socket - create AF_LOCAL socket, connect to a local endpoint - * @xprt: RPC transport to connect * @transport: socket transport to connect - * @create_sock: function to create a socket of the correct type */ static int xs_local_setup_socket(struct sock_xprt *transport) { struct rpc_xprt *xprt = &transport->xprt; + struct file *filp; struct socket *sock; - int status = -EIO; - - current->flags |= PF_FSTRANS; + int status; - clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); status = __sock_create(xprt->xprt_net, AF_LOCAL, SOCK_STREAM, 0, &sock, 1); if (status < 0) { @@ -1890,18 +2026,31 @@ static int xs_local_setup_socket(struct sock_xprt *transport) "transport socket (%d).\n", -status); goto out; } - xs_reclassify_socketu(sock); + xs_reclassify_socket(AF_LOCAL, sock); + + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); + if (IS_ERR(filp)) { + status = PTR_ERR(filp); + goto out; + } + transport->file = filp; dprintk("RPC: worker connecting xprt %p via AF_LOCAL to %s\n", xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); status = xs_local_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); switch (status) { case 0: dprintk("RPC: xprt %p connected to %s\n", xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; xprt_set_connected(xprt); break; + case -ENOBUFS: + break; case -ENOENT: dprintk("RPC: xprt %p: socket %s does not exist\n", xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); @@ -1919,7 +2068,6 @@ static int xs_local_setup_socket(struct sock_xprt *transport) out: xprt_clear_connecting(xprt); xprt_wake_pending_tasks(xprt, status); - current->flags &= ~PF_FSTRANS; return status; } @@ -1928,7 +2076,10 @@ static void xs_local_connect(struct rpc_xprt *xprt, struct rpc_task *task) struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); int ret; - if (RPC_IS_ASYNC(task)) { + if (transport->file) + goto force_disconnect; + + if (RPC_IS_ASYNC(task)) { /* * We want the AF_LOCAL connect to be resolved in the * filesystem namespace of the process making the rpc @@ -1938,51 +2089,94 @@ static void xs_local_connect(struct rpc_xprt *xprt, struct rpc_task *task) * we'll need to figure out how to pass a namespace to * connect. */ - rpc_exit(task, -ENOTCONN); - return; + rpc_task_set_rpc_status(task, -ENOTCONN); + goto out_wake; } ret = xs_local_setup_socket(transport); if (ret && !RPC_IS_SOFTCONN(task)) msleep_interruptible(15000); + return; +force_disconnect: + xprt_force_disconnect(xprt); +out_wake: + xprt_clear_connecting(xprt); + xprt_wake_pending_tasks(xprt, -ENOTCONN); } -#ifdef CONFIG_SUNRPC_SWAP +#if IS_ENABLED(CONFIG_SUNRPC_SWAP) +/* + * Note that this should be called with XPRT_LOCKED held, or recv_mutex + * held, or when we otherwise know that we have exclusive access to the + * socket, to guard against races with xs_reset_transport. + */ static void xs_set_memalloc(struct rpc_xprt *xprt) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - if (xprt->swapper) + /* + * If there's no sock, then we have nothing to set. The + * reconnecting process will get it for us. + */ + if (!transport->inet) + return; + if (atomic_read(&xprt->swapper)) sk_set_memalloc(transport->inet); } /** - * xs_swapper - Tag this transport as being used for swap. + * xs_enable_swap - Tag this transport as being used for swap. * @xprt: transport to tag - * @enable: enable/disable * + * Take a reference to this transport on behalf of the rpc_clnt, and + * optionally mark it for swapping if it wasn't already. */ -int xs_swapper(struct rpc_xprt *xprt, int enable) +static int +xs_enable_swap(struct rpc_xprt *xprt) { - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, - xprt); - int err = 0; + struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt); - if (enable) { - xprt->swapper++; - xs_set_memalloc(xprt); - } else if (xprt->swapper) { - xprt->swapper--; - sk_clear_memalloc(transport->inet); - } + mutex_lock(&xs->recv_mutex); + if (atomic_inc_return(&xprt->swapper) == 1 && + xs->inet) + sk_set_memalloc(xs->inet); + mutex_unlock(&xs->recv_mutex); + return 0; +} - return err; +/** + * xs_disable_swap - Untag this transport as being used for swap. + * @xprt: transport to tag + * + * Drop a "swapper" reference to this xprt on behalf of the rpc_clnt. If the + * swapper refcount goes to 0, untag the socket as a memalloc socket. + */ +static void +xs_disable_swap(struct rpc_xprt *xprt) +{ + struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt); + + mutex_lock(&xs->recv_mutex); + if (atomic_dec_and_test(&xprt->swapper) && + xs->inet) + sk_clear_memalloc(xs->inet); + mutex_unlock(&xs->recv_mutex); } -EXPORT_SYMBOL_GPL(xs_swapper); #else static void xs_set_memalloc(struct rpc_xprt *xprt) { } + +static int +xs_enable_swap(struct rpc_xprt *xprt) +{ + return -EINVAL; +} + +static void +xs_disable_swap(struct rpc_xprt *xprt) +{ +} #endif static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) @@ -1992,15 +2186,14 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) if (!transport->inet) { struct sock *sk = sock->sk; - write_lock_bh(&sk->sk_callback_lock); + lock_sock(sk); xs_save_old_callbacks(transport, sk); sk->sk_user_data = xprt; - sk->sk_data_ready = xs_udp_data_ready; + sk->sk_data_ready = xs_data_ready; sk->sk_write_space = xs_udp_write_space; - sk->sk_no_check = UDP_CSUM_NORCV; - sk->sk_allocation = GFP_ATOMIC; + sk->sk_use_task_frag = false; xprt_set_connected(xprt); @@ -2010,9 +2203,11 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) xs_set_memalloc(xprt); - write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); } xs_udp_do_set_buffer_size(xprt); + + xprt->stat.connect_start = jiffies; } static void xs_udp_setup_socket(struct work_struct *work) @@ -2020,15 +2215,15 @@ static void xs_udp_setup_socket(struct work_struct *work) struct sock_xprt *transport = container_of(work, struct sock_xprt, connect_worker.work); struct rpc_xprt *xprt = &transport->xprt; - struct socket *sock = transport->sock; + struct socket *sock; int status = -EIO; + unsigned int pflags = current->flags; - current->flags |= PF_FSTRANS; - - /* Start by resetting any existing state */ - xs_reset_transport(transport); + if (atomic_read(&xprt->swapper)) + current->flags |= PF_MEMALLOC; sock = xs_create_sock(xprt, transport, - xs_addr(xprt)->sa_family, SOCK_DGRAM, IPPROTO_UDP); + xs_addr(xprt)->sa_family, SOCK_DGRAM, + IPPROTO_UDP, false); if (IS_ERR(sock)) goto out; @@ -2039,85 +2234,158 @@ static void xs_udp_setup_socket(struct work_struct *work) xprt->address_strings[RPC_DISPLAY_PORT]); xs_udp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, 0); status = 0; out: xprt_clear_connecting(xprt); + xprt_unlock_connect(xprt, transport); xprt_wake_pending_tasks(xprt, status); - current->flags &= ~PF_FSTRANS; + current_restore_flags(pflags, PF_MEMALLOC); } -/* - * We need to preserve the port number so the reply cache on the server can - * find our cached RPC replies when we get around to reconnecting. +/** + * xs_tcp_shutdown - gracefully shut down a TCP socket + * @xprt: transport + * + * Initiates a graceful shutdown of the TCP socket by calling the + * equivalent of shutdown(SHUT_RDWR); */ -static void xs_abort_connection(struct sock_xprt *transport) +static void xs_tcp_shutdown(struct rpc_xprt *xprt) { - int result; - struct sockaddr any; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct socket *sock = transport->sock; + int skst = transport->inet ? transport->inet->sk_state : TCP_CLOSE; - dprintk("RPC: disconnecting xprt %p to reuse port\n", transport); + if (sock == NULL) + return; + if (!xprt->reuseport) { + xs_close(xprt); + return; + } + switch (skst) { + case TCP_FIN_WAIT1: + case TCP_FIN_WAIT2: + case TCP_LAST_ACK: + break; + case TCP_ESTABLISHED: + case TCP_CLOSE_WAIT: + kernel_sock_shutdown(sock, SHUT_RDWR); + trace_rpc_socket_shutdown(xprt, sock); + break; + default: + xs_reset_transport(transport); + } +} - /* - * Disconnect the transport socket by doing a connect operation - * with AF_UNSPEC. This should return immediately... - */ - memset(&any, 0, sizeof(any)); - any.sa_family = AF_UNSPEC; - result = kernel_connect(transport->sock, &any, sizeof(any), 0); - if (!result) - xs_sock_reset_connection_flags(&transport->xprt); - dprintk("RPC: AF_UNSPEC connect return code %d\n", result); +static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, + struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct net *net = sock_net(sock->sk); + unsigned long connect_timeout; + unsigned long syn_retries; + unsigned int keepidle; + unsigned int keepcnt; + unsigned int timeo; + unsigned long t; + + spin_lock(&xprt->transport_lock); + keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ); + keepcnt = xprt->timeout->to_retries + 1; + timeo = jiffies_to_msecs(xprt->timeout->to_initval) * + (xprt->timeout->to_retries + 1); + clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); + spin_unlock(&xprt->transport_lock); + + /* TCP Keepalive options */ + sock_set_keepalive(sock->sk); + tcp_sock_set_keepidle(sock->sk, keepidle); + tcp_sock_set_keepintvl(sock->sk, keepidle); + tcp_sock_set_keepcnt(sock->sk, keepcnt); + + /* TCP user timeout (see RFC5482) */ + tcp_sock_set_user_timeout(sock->sk, timeo); + + /* Connect timeout */ + connect_timeout = max_t(unsigned long, + DIV_ROUND_UP(xprt->connect_timeout, HZ), 1); + syn_retries = max_t(unsigned long, + READ_ONCE(net->ipv4.sysctl_tcp_syn_retries), 1); + for (t = 0; t <= syn_retries && (1UL << t) < connect_timeout; t++) + ; + if (t <= syn_retries) + tcp_sock_set_syncnt(sock->sk, t - 1); } -static void xs_tcp_reuse_connection(struct sock_xprt *transport) +static void xs_tcp_do_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout) { - unsigned int state = transport->inet->sk_state; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct rpc_timeout to; + unsigned long initval; - if (state == TCP_CLOSE && transport->sock->state == SS_UNCONNECTED) { - /* we don't need to abort the connection if the socket - * hasn't undergone a shutdown - */ - if (transport->inet->sk_shutdown == 0) - return; - dprintk("RPC: %s: TCP_CLOSEd and sk_shutdown set to %d\n", - __func__, transport->inet->sk_shutdown); - } - if ((1 << state) & (TCPF_ESTABLISHED|TCPF_SYN_SENT)) { - /* we don't need to abort the connection if the socket - * hasn't undergone a shutdown - */ - if (transport->inet->sk_shutdown == 0) - return; - dprintk("RPC: %s: ESTABLISHED/SYN_SENT " - "sk_shutdown set to %d\n", - __func__, transport->inet->sk_shutdown); - } - xs_abort_connection(transport); + memcpy(&to, xprt->timeout, sizeof(to)); + /* Arbitrary lower limit */ + initval = max_t(unsigned long, connect_timeout, XS_TCP_INIT_REEST_TO); + to.to_initval = initval; + to.to_maxval = initval; + to.to_retries = 0; + memcpy(&transport->tcp_timeout, &to, sizeof(transport->tcp_timeout)); + xprt->timeout = &transport->tcp_timeout; + xprt->connect_timeout = connect_timeout; +} + +static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + spin_lock(&xprt->transport_lock); + if (reconnect_timeout < xprt->max_reconnect_timeout) + xprt->max_reconnect_timeout = reconnect_timeout; + if (connect_timeout < xprt->connect_timeout) + xs_tcp_do_set_connect_timeout(xprt, connect_timeout); + set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); + spin_unlock(&xprt->transport_lock); } static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - int ret = -ENOTCONN; if (!transport->inet) { struct sock *sk = sock->sk; - write_lock_bh(&sk->sk_callback_lock); + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. + */ + if (xs_addr(xprt)->sa_family == PF_INET6) { + ip6_sock_set_addr_preferences(sk, + IPV6_PREFER_SRC_PUBLIC); + } + + xs_tcp_set_socket_timeouts(xprt, sock); + tcp_sock_set_nodelay(sk); + + lock_sock(sk); xs_save_old_callbacks(transport, sk); sk->sk_user_data = xprt; - sk->sk_data_ready = xs_tcp_data_ready; + sk->sk_data_ready = xs_data_ready; sk->sk_state_change = xs_tcp_state_change; sk->sk_write_space = xs_tcp_write_space; - sk->sk_allocation = GFP_ATOMIC; + sk->sk_error_report = xs_error_report; + sk->sk_use_task_frag = false; /* socket options */ - sk->sk_userlocks |= SOCK_BINDPORT_LOCK; sock_reset_flag(sk, SOCK_LINGER); - tcp_sk(sk)->linger2 = 0; - tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; xprt_clear_connected(xprt); @@ -2125,35 +2393,25 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) transport->sock = sock; transport->inet = sk; - write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); } if (!xprt_bound(xprt)) - goto out; + return -ENOTCONN; xs_set_memalloc(xprt); + xs_stream_start_connect(transport); + /* Tell the socket layer to start connecting... */ - xprt->stat.connect_count++; - xprt->stat.connect_start = jiffies; - ret = kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); - switch (ret) { - case 0: - case -EINPROGRESS: - /* SYN_SENT! */ - xprt->connect_cookie++; - if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) - xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; - } -out: - return ret; + set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); + return kernel_connect(sock, (struct sockaddr_unsized *)xs_addr(xprt), + xprt->addrlen, O_NONBLOCK); } /** * xs_tcp_setup_socket - create a TCP socket and connect to a remote endpoint - * @xprt: RPC transport to connect - * @transport: socket transport to connect - * @create_sock: function to create a socket of the correct type + * @work: queued work item * * Invoked by a work queue tasklet. */ @@ -2163,28 +2421,24 @@ static void xs_tcp_setup_socket(struct work_struct *work) container_of(work, struct sock_xprt, connect_worker.work); struct socket *sock = transport->sock; struct rpc_xprt *xprt = &transport->xprt; - int status = -EIO; + int status; + unsigned int pflags = current->flags; - current->flags |= PF_FSTRANS; + if (atomic_read(&xprt->swapper)) + current->flags |= PF_MEMALLOC; - if (!sock) { - clear_bit(XPRT_CONNECTION_ABORT, &xprt->state); - sock = xs_create_sock(xprt, transport, - xs_addr(xprt)->sa_family, SOCK_STREAM, IPPROTO_TCP); + if (xprt_connected(xprt)) + goto out; + if (test_and_clear_bit(XPRT_SOCK_CONNECT_SENT, + &transport->sock_state) || + !sock) { + xs_reset_transport(transport); + sock = xs_create_sock(xprt, transport, xs_addr(xprt)->sa_family, + SOCK_STREAM, IPPROTO_TCP, true); if (IS_ERR(sock)) { - status = PTR_ERR(sock); + xprt_wake_pending_tasks(xprt, PTR_ERR(sock)); goto out; } - } else { - int abort_and_exit; - - abort_and_exit = test_and_clear_bit(XPRT_CONNECTION_ABORT, - &xprt->state); - /* "close" the socket, preserving the local port */ - xs_tcp_reuse_connection(transport); - - if (abort_and_exit) - goto out_eagain; } dprintk("RPC: worker connecting xprt %p via %s to " @@ -2194,41 +2448,330 @@ static void xs_tcp_setup_socket(struct work_struct *work) xprt->address_strings[RPC_DISPLAY_PORT]); status = xs_tcp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); dprintk("RPC: %p connect status %d connected %d sock state %d\n", xprt, -status, xprt_connected(xprt), sock->sk->sk_state); switch (status) { - default: - printk("%s: connect returned unhandled error %d\n", - __func__, status); - case -EADDRNOTAVAIL: - /* We're probably in TIME_WAIT. Get rid of existing socket, - * and retry - */ - xs_tcp_force_close(xprt); - break; case 0: case -EINPROGRESS: + /* SYN_SENT! */ + set_bit(XPRT_SOCK_CONNECT_SENT, &transport->sock_state); + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + fallthrough; case -EALREADY: - xprt_clear_connecting(xprt); - current->flags &= ~PF_FSTRANS; - return; + goto out_unlock; + case -EADDRNOTAVAIL: + /* Source port number is unavailable. Try a new one! */ + transport->srcport = 0; + status = -EAGAIN; + break; + case -EPERM: + /* Happens, for instance, if a BPF program is preventing + * the connect. Remap the error so upper layers can better + * deal with it. + */ + status = -ECONNREFUSED; + fallthrough; case -EINVAL: /* Happens, for instance, if the user specified a link * local IPv6 address without a scope-id. */ case -ECONNREFUSED: case -ECONNRESET: + case -ENETDOWN: case -ENETUNREACH: - /* retry with existing socket, after a delay */ - goto out; + case -EHOSTUNREACH: + case -EADDRINUSE: + case -ENOBUFS: + case -ENOTCONN: + break; + default: + printk("%s: connect returned unhandled error %d\n", + __func__, status); + status = -EAGAIN; } -out_eagain: - status = -EAGAIN; + + /* xs_tcp_force_close() wakes tasks with a fixed error code. + * We need to wake them first to ensure the correct error code. + */ + xprt_wake_pending_tasks(xprt, status); + xs_tcp_force_close(xprt); out: xprt_clear_connecting(xprt); - xprt_wake_pending_tasks(xprt, status); - current->flags &= ~PF_FSTRANS; +out_unlock: + xprt_unlock_connect(xprt, transport); + current_restore_flags(pflags, PF_MEMALLOC); +} + +/* + * Transfer the connected socket to @upper_transport, then mark that + * xprt CONNECTED. + */ +static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt, + struct sock_xprt *upper_transport) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + + if (!upper_transport->inet) { + struct socket *sock = lower_transport->sock; + struct sock *sk = sock->sk; + + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. + */ + if (xs_addr(upper_xprt)->sa_family == PF_INET6) + ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC); + + xs_tcp_set_socket_timeouts(upper_xprt, sock); + tcp_sock_set_nodelay(sk); + + lock_sock(sk); + + /* @sk is already connected, so it now has the RPC callbacks. + * Reach into @lower_transport to save the original ones. + */ + upper_transport->old_data_ready = lower_transport->old_data_ready; + upper_transport->old_state_change = lower_transport->old_state_change; + upper_transport->old_write_space = lower_transport->old_write_space; + upper_transport->old_error_report = lower_transport->old_error_report; + sk->sk_user_data = upper_xprt; + + /* socket options */ + sock_reset_flag(sk, SOCK_LINGER); + + xprt_clear_connected(upper_xprt); + + upper_transport->sock = sock; + upper_transport->inet = sk; + upper_transport->file = lower_transport->file; + + release_sock(sk); + + /* Reset lower_transport before shutting down its clnt */ + mutex_lock(&lower_transport->recv_mutex); + lower_transport->inet = NULL; + lower_transport->sock = NULL; + lower_transport->file = NULL; + + xprt_clear_connected(lower_xprt); + xs_sock_reset_connection_flags(lower_xprt); + xs_stream_reset_connect(lower_transport); + mutex_unlock(&lower_transport->recv_mutex); + } + + if (!xprt_bound(upper_xprt)) + return -ENOTCONN; + + xs_set_memalloc(upper_xprt); + + if (!xprt_test_and_set_connected(upper_xprt)) { + upper_xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + + upper_xprt->stat.connect_count++; + upper_xprt->stat.connect_time += (long)jiffies - + upper_xprt->stat.connect_start; + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + } + return 0; +} + +/** + * xs_tls_handshake_done - TLS handshake completion handler + * @data: address of xprt to wake + * @status: status of handshake + * @peerid: serial number of key containing the remote's identity + * + */ +static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid) +{ + struct rpc_xprt *lower_xprt = data; + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + + switch (status) { + case 0: + case -EACCES: + case -ETIMEDOUT: + lower_transport->xprt_err = status; + break; + default: + lower_transport->xprt_err = -EACCES; + } + complete(&lower_transport->handshake_done); + xprt_put(lower_xprt); +} + +static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec) +{ + struct sock_xprt *lower_transport = + container_of(lower_xprt, struct sock_xprt, xprt); + struct tls_handshake_args args = { + .ta_sock = lower_transport->sock, + .ta_done = xs_tls_handshake_done, + .ta_data = xprt_get(lower_xprt), + .ta_peername = lower_xprt->servername, + }; + struct sock *sk = lower_transport->inet; + int rc; + + init_completion(&lower_transport->handshake_done); + set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + lower_transport->xprt_err = -ETIMEDOUT; + switch (xprtsec->policy) { + case RPC_XPRTSEC_TLS_ANON: + rc = tls_client_hello_anon(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + case RPC_XPRTSEC_TLS_X509: + args.ta_my_cert = xprtsec->cert_serial; + args.ta_my_privkey = xprtsec->privkey_serial; + rc = tls_client_hello_x509(&args, GFP_KERNEL); + if (rc) + goto out_put_xprt; + break; + default: + rc = -EACCES; + goto out_put_xprt; + } + + rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done, + XS_TLS_HANDSHAKE_TO); + if (rc <= 0) { + tls_handshake_cancel(sk); + if (rc == 0) + rc = -ETIMEDOUT; + goto out_put_xprt; + } + + rc = lower_transport->xprt_err; + +out: + xs_stream_reset_connect(lower_transport); + clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state); + return rc; + +out_put_xprt: + xprt_put(lower_xprt); + goto out; +} + +/** + * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket + * @work: queued work item + * + * Invoked by a work queue tasklet. + * + * For RPC-with-TLS, there is a two-stage connection process. + * + * The "upper-layer xprt" is visible to the RPC consumer. Once it has + * been marked connected, the consumer knows that a TCP connection and + * a TLS session have been established. + * + * A "lower-layer xprt", created in this function, handles the mechanics + * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and + * then driving the TLS handshake. Once all that is complete, the upper + * layer xprt is marked connected. + */ +static void xs_tcp_tls_setup_socket(struct work_struct *work) +{ + struct sock_xprt *upper_transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_clnt *upper_clnt = upper_transport->clnt; + struct rpc_xprt *upper_xprt = &upper_transport->xprt; + struct rpc_create_args args = { + .net = upper_xprt->xprt_net, + .protocol = upper_xprt->prot, + .address = (struct sockaddr *)&upper_xprt->addr, + .addrsize = upper_xprt->addrlen, + .timeout = upper_clnt->cl_timeout, + .servername = upper_xprt->servername, + .program = upper_clnt->cl_program, + .prognumber = upper_clnt->cl_prog, + .version = upper_clnt->cl_vers, + .authflavor = RPC_AUTH_TLS, + .cred = upper_clnt->cl_cred, + .xprtsec = { + .policy = RPC_XPRTSEC_NONE, + }, + .stats = upper_clnt->cl_stats, + }; + unsigned int pflags = current->flags; + struct rpc_clnt *lower_clnt; + struct rpc_xprt *lower_xprt; + int status; + + if (atomic_read(&upper_xprt->swapper)) + current->flags |= PF_MEMALLOC; + + xs_stream_start_connect(upper_transport); + + /* This implicitly sends an RPC_AUTH_TLS probe */ + lower_clnt = rpc_create(&args); + if (IS_ERR(lower_clnt)) { + trace_rpc_tls_unavailable(upper_clnt, upper_xprt); + clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); + xprt_clear_connecting(upper_xprt); + xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt)); + xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); + goto out_unlock; + } + + /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on + * the lower xprt. + */ + rcu_read_lock(); + lower_xprt = rcu_dereference(lower_clnt->cl_xprt); + rcu_read_unlock(); + + if (wait_on_bit_lock(&lower_xprt->state, XPRT_LOCKED, TASK_KILLABLE)) + goto out_unlock; + + status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec); + if (status) { + trace_rpc_tls_not_started(upper_clnt, upper_xprt); + goto out_close; + } + + status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport); + if (status) + goto out_close; + xprt_release_write(lower_xprt, NULL); + trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0); + rpc_shutdown_client(lower_clnt); + + /* Check for ingress data that arrived before the socket's + * ->data_ready callback was set up. + */ + xs_poll_check_readable(upper_transport); + +out_unlock: + current_restore_flags(pflags, PF_MEMALLOC); + upper_transport->clnt = NULL; + xprt_unlock_connect(upper_xprt, upper_transport); + return; + +out_close: + xprt_release_write(lower_xprt, NULL); + rpc_shutdown_client(lower_clnt); + + /* xprt_force_disconnect() wakes tasks with a fixed tk_status code. + * Wake them first here to ensure they get our tk_status code. + */ + xprt_wake_pending_tasks(upper_xprt, status); + xs_tcp_force_close(upper_xprt); + xprt_clear_connecting(upper_xprt); + goto out_unlock; } /** @@ -2248,28 +2791,70 @@ out: static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) { struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + unsigned long delay = 0; + + WARN_ON_ONCE(!xprt_lock_connect(xprt, task, transport)); - if (transport->sock != NULL && !RPC_IS_SOFTCONN(task)) { + if (transport->sock != NULL) { dprintk("RPC: xs_connect delayed xprt %p for %lu " - "seconds\n", - xprt, xprt->reestablish_timeout / HZ); - queue_delayed_work(rpciod_workqueue, - &transport->connect_worker, - xprt->reestablish_timeout); - xprt->reestablish_timeout <<= 1; - if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) - xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; - if (xprt->reestablish_timeout > XS_TCP_MAX_REEST_TO) - xprt->reestablish_timeout = XS_TCP_MAX_REEST_TO; - } else { + "seconds\n", xprt, xprt->reestablish_timeout / HZ); + + delay = xprt_reconnect_delay(xprt); + xprt_reconnect_backoff(xprt, XS_TCP_INIT_REEST_TO); + + } else dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); - queue_delayed_work(rpciod_workqueue, - &transport->connect_worker, 0); + + transport->clnt = task->tk_client; + queue_delayed_work(xprtiod_workqueue, + &transport->connect_worker, + delay); +} + +static void xs_wake_disconnect(struct sock_xprt *transport) +{ + if (test_and_clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state)) + xs_tcp_force_close(&transport->xprt); +} + +static void xs_wake_write(struct sock_xprt *transport) +{ + if (test_and_clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state)) + xprt_write_space(&transport->xprt); +} + +static void xs_wake_error(struct sock_xprt *transport) +{ + int sockerr; + + if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state)) + return; + sockerr = xchg(&transport->xprt_err, 0); + if (sockerr < 0) { + xprt_wake_pending_tasks(&transport->xprt, sockerr); + xs_tcp_force_close(&transport->xprt); } } +static void xs_wake_pending(struct sock_xprt *transport) +{ + if (test_and_clear_bit(XPRT_SOCK_WAKE_PENDING, &transport->sock_state)) + xprt_wake_pending_tasks(&transport->xprt, -EAGAIN); +} + +static void xs_error_handle(struct work_struct *work) +{ + struct sock_xprt *transport = container_of(work, + struct sock_xprt, error_worker); + + xs_wake_disconnect(transport); + xs_wake_write(transport); + xs_wake_error(transport); + xs_wake_pending(transport); +} + /** - * xs_local_print_stats - display AF_LOCAL socket-specifc stats + * xs_local_print_stats - display AF_LOCAL socket-specific stats * @xprt: rpc_xprt struct containing statistics * @seq: output file * @@ -2285,7 +2870,7 @@ static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) "%llu %llu %lu %llu %llu\n", xprt->stat.bind_count, xprt->stat.connect_count, - xprt->stat.connect_time, + xprt->stat.connect_time / HZ, idle_time, xprt->stat.sends, xprt->stat.recvs, @@ -2298,7 +2883,7 @@ static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) } /** - * xs_udp_print_stats - display UDP socket-specifc stats + * xs_udp_print_stats - display UDP socket-specific stats * @xprt: rpc_xprt struct containing statistics * @seq: output file * @@ -2322,7 +2907,7 @@ static void xs_udp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) } /** - * xs_tcp_print_stats - display TCP socket-specifc stats + * xs_tcp_print_stats - display TCP socket-specific stats * @xprt: rpc_xprt struct containing statistics * @seq: output file * @@ -2340,7 +2925,7 @@ static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) transport->srcport, xprt->stat.bind_count, xprt->stat.connect_count, - xprt->stat.connect_time, + xprt->stat.connect_time / HZ, idle_time, xprt->stat.sends, xprt->stat.recvs, @@ -2357,80 +2942,83 @@ static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) * we allocate pages instead doing a kmalloc like rpc_malloc is because we want * to use the server side send routines. */ -static void *bc_malloc(struct rpc_task *task, size_t size) +static int bc_malloc(struct rpc_task *task) { + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize; struct page *page; struct rpc_buffer *buf; - WARN_ON_ONCE(size > PAGE_SIZE - sizeof(struct rpc_buffer)); - if (size > PAGE_SIZE - sizeof(struct rpc_buffer)) - return NULL; + if (size > PAGE_SIZE - sizeof(struct rpc_buffer)) { + WARN_ONCE(1, "xprtsock: large bc buffer request (size %zu)\n", + size); + return -EINVAL; + } - page = alloc_page(GFP_KERNEL); + page = alloc_page(GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); if (!page) - return NULL; + return -ENOMEM; buf = page_address(page); buf->len = PAGE_SIZE; - return buf->data; + rqst->rq_buffer = buf->data; + rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize; + return 0; } /* * Free the space allocated in the bc_alloc routine */ -static void bc_free(void *buffer) +static void bc_free(struct rpc_task *task) { + void *buffer = task->tk_rqstp->rq_buffer; struct rpc_buffer *buf; - if (!buffer) - return; - buf = container_of(buffer, struct rpc_buffer, data); free_page((unsigned long)buf); } -/* - * Use the svc_sock to send the callback. Must be called with svsk->sk_mutex - * held. Borrows heavily from svc_tcp_sendto and xs_tcp_send_request. - */ static int bc_sendto(struct rpc_rqst *req) { - int len; - struct xdr_buf *xbufp = &req->rq_snd_buf; - struct rpc_xprt *xprt = req->rq_xprt; + struct xdr_buf *xdr = &req->rq_snd_buf; struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct socket *sock = transport->sock; - unsigned long headoff; - unsigned long tailoff; - - xs_encode_stream_record_marker(xbufp); - - tailoff = (unsigned long)xbufp->tail[0].iov_base & ~PAGE_MASK; - headoff = (unsigned long)xbufp->head[0].iov_base & ~PAGE_MASK; - len = svc_send_common(sock, xbufp, - virt_to_page(xbufp->head[0].iov_base), headoff, - xbufp->tail[0].iov_base, tailoff); - - if (len != xbufp->len) { - printk(KERN_NOTICE "Error sending entire callback!\n"); - len = -EAGAIN; - } + container_of(req->rq_xprt, struct sock_xprt, xprt); + struct msghdr msg = { + .msg_flags = 0, + }; + rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | + (u32)xdr->len); + unsigned int sent = 0; + int err; - return len; + req->rq_xtime = ktime_get(); + err = xdr_alloc_bvec(xdr, rpc_task_gfp_mask()); + if (err < 0) + return err; + err = xprt_sock_sendmsg(transport->sock, &msg, xdr, 0, marker, &sent); + xdr_free_bvec(xdr); + if (err < 0 || sent != (xdr->len + sizeof(marker))) + return -EAGAIN; + return sent; } -/* - * The send routine. Borrows from svc_send +/** + * bc_send_request - Send a backchannel Call on a TCP socket + * @req: rpc_rqst containing Call message to be sent + * + * xpt_mutex ensures @rqstp's whole message is written to the socket + * without interruption. + * + * Return values: + * %0 if the message was sent successfully + * %ENOTCONN if the message was not sent */ -static int bc_send_request(struct rpc_task *task) +static int bc_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct svc_xprt *xprt; - u32 len; + int len; - dprintk("sending request with xid: %08x\n", ntohl(req->rq_xid)); /* * Get the server socket associated with this callback xprt */ @@ -2440,12 +3028,7 @@ static int bc_send_request(struct rpc_task *task) * Grab the mutex to serialize data as the connection is shared * with the fore channel */ - if (!mutex_trylock(&xprt->xpt_mutex)) { - rpc_sleep_on(&xprt->xpt_bc_pending, task, NULL); - if (!mutex_trylock(&xprt->xpt_mutex)) - return -EAGAIN; - rpc_wake_up_queued_task(&xprt->xpt_bc_pending, task); - } + mutex_lock(&xprt->xpt_mutex); if (test_bit(XPT_DEAD, &xprt->xpt_flags)) len = -ENOTCONN; else @@ -2458,89 +3041,116 @@ static int bc_send_request(struct rpc_task *task) return len; } -/* - * The close routine. Since this is client initiated, we do nothing - */ - static void bc_close(struct rpc_xprt *xprt) { + xprt_disconnect_done(xprt); } -/* - * The xprt destroy routine. Again, because this connection is client - * initiated, we do nothing - */ - static void bc_destroy(struct rpc_xprt *xprt) { + dprintk("RPC: bc_destroy xprt %p\n", xprt); + + xs_xprt_free(xprt); + module_put(THIS_MODULE); } -static struct rpc_xprt_ops xs_local_ops = { +static const struct rpc_xprt_ops xs_local_ops = { .reserve_xprt = xprt_reserve_xprt, - .release_xprt = xs_tcp_release_xprt, + .release_xprt = xprt_release_xprt, .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, .rpcbind = xs_local_rpcbind, .set_port = xs_local_set_port, .connect = xs_local_connect, .buf_alloc = rpc_malloc, .buf_free = rpc_free, + .prepare_request = xs_stream_prepare_request, .send_request = xs_local_send_request, - .set_retrans_timeout = xprt_set_retrans_timeout_def, + .abort_send_request = xs_stream_abort_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, .close = xs_close, - .destroy = xs_local_destroy, + .destroy = xs_destroy, .print_stats = xs_local_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, }; -static struct rpc_xprt_ops xs_udp_ops = { +static const struct rpc_xprt_ops xs_udp_ops = { .set_buffer_size = xs_udp_set_buffer_size, .reserve_xprt = xprt_reserve_xprt_cong, .release_xprt = xprt_release_xprt_cong, .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, .rpcbind = rpcb_getport_async, .set_port = xs_set_port, .connect = xs_connect, + .get_srcaddr = xs_sock_srcaddr, + .get_srcport = xs_sock_srcport, .buf_alloc = rpc_malloc, .buf_free = rpc_free, .send_request = xs_udp_send_request, - .set_retrans_timeout = xprt_set_retrans_timeout_rtt, + .wait_for_reply_request = xprt_wait_for_reply_request_rtt, .timer = xs_udp_timer, .release_request = xprt_release_rqst_cong, .close = xs_close, .destroy = xs_destroy, .print_stats = xs_udp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, }; -static struct rpc_xprt_ops xs_tcp_ops = { +static const struct rpc_xprt_ops xs_tcp_ops = { .reserve_xprt = xprt_reserve_xprt, - .release_xprt = xs_tcp_release_xprt, - .alloc_slot = xprt_lock_and_alloc_slot, + .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, .rpcbind = rpcb_getport_async, .set_port = xs_set_port, .connect = xs_connect, + .get_srcaddr = xs_sock_srcaddr, + .get_srcport = xs_sock_srcport, .buf_alloc = rpc_malloc, .buf_free = rpc_free, + .prepare_request = xs_stream_prepare_request, .send_request = xs_tcp_send_request, - .set_retrans_timeout = xprt_set_retrans_timeout_def, - .close = xs_tcp_close, + .abort_send_request = xs_stream_abort_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, + .close = xs_tcp_shutdown, .destroy = xs_destroy, + .set_connect_timeout = xs_tcp_set_connect_timeout, .print_stats = xs_tcp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +#ifdef CONFIG_SUNRPC_BACKCHANNEL + .bc_setup = xprt_setup_bc, + .bc_maxpayload = xs_tcp_bc_maxpayload, + .bc_num_slots = xprt_bc_max_slots, + .bc_free_rqst = xprt_free_bc_rqst, + .bc_destroy = xprt_destroy_bc, +#endif }; /* * The rpc_xprt_ops for the server backchannel */ -static struct rpc_xprt_ops bc_tcp_ops = { +static const struct rpc_xprt_ops bc_tcp_ops = { .reserve_xprt = xprt_reserve_xprt, .release_xprt = xprt_release_xprt, .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, .buf_alloc = bc_malloc, .buf_free = bc_free, .send_request = bc_send_request, - .set_retrans_timeout = xprt_set_retrans_timeout_def, + .wait_for_reply_request = xprt_wait_for_reply_request_def, .close = bc_close, .destroy = bc_destroy, .print_stats = xs_tcp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, }; static int xs_init_anyaddr(const int family, struct sockaddr *sap) @@ -2591,6 +3201,7 @@ static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args, } new = container_of(xprt, struct sock_xprt, xprt); + mutex_init(&new->recv_mutex); memcpy(&xprt->addr, args->dstaddr, args->addrlen); xprt->addrlen = args->addrlen; if (args->srcaddr) @@ -2634,7 +3245,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = 0; - xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->xprt_class = &xs_local_transport; xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->bind_timeout = XS_BIND_TO; @@ -2644,9 +3255,13 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) xprt->ops = &xs_local_ops; xprt->timeout = &xs_local_default_timeout; + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + INIT_DELAYED_WORK(&transport->connect_worker, xs_dummy_setup_socket); + switch (sun->sun_family) { case AF_LOCAL: - if (sun->sun_path[0] != '/') { + if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') { dprintk("RPC: bad AF_LOCAL address: %s\n", sun->sun_path); ret = ERR_PTR(-EINVAL); @@ -2654,9 +3269,6 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) } xprt_set_bound(xprt); xs_format_peer_addresses(xprt, "local", RPCBIND_NETID_LOCAL); - ret = ERR_PTR(xs_local_setup_socket(transport)); - if (ret) - goto out_err; break; default: ret = ERR_PTR(-EAFNOSUPPORT); @@ -2670,7 +3282,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) return xprt; ret = ERR_PTR(-EINVAL); out_err: - xprt_free(xprt); + xs_xprt_free(xprt); return ret; } @@ -2700,7 +3312,7 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = IPPROTO_UDP; - xprt->tsh_size = 0; + xprt->xprt_class = &xs_udp_transport; /* XXX: header size can vary due to auth type, IPv6, etc. */ xprt->max_payload = (1U << 16) - (MAX_HEADER << 3); @@ -2712,21 +3324,21 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) xprt->timeout = &xs_udp_default_timeout; + INIT_WORK(&transport->recv_worker, xs_udp_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + INIT_DELAYED_WORK(&transport->connect_worker, xs_udp_setup_socket); + switch (addr->sa_family) { case AF_INET: if (((struct sockaddr_in *)addr)->sin_port != htons(0)) xprt_set_bound(xprt); - INIT_DELAYED_WORK(&transport->connect_worker, - xs_udp_setup_socket); xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP); break; case AF_INET6: if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) xprt_set_bound(xprt); - INIT_DELAYED_WORK(&transport->connect_worker, - xs_udp_setup_socket); xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6); break; default: @@ -2748,7 +3360,7 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) return xprt; ret = ERR_PTR(-EINVAL); out_err: - xprt_free(xprt); + xs_xprt_free(xprt); return ret; } @@ -2781,7 +3393,7 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = IPPROTO_TCP; - xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->xprt_class = &xs_tcp_transport; xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->bind_timeout = XS_BIND_TO; @@ -2791,21 +3403,30 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt->ops = &xs_tcp_ops; xprt->timeout = &xs_tcp_default_timeout; + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + if (args->reconnect_timeout) + xprt->max_reconnect_timeout = args->reconnect_timeout; + + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + if (args->connect_timeout) + xs_tcp_do_set_connect_timeout(xprt, args->connect_timeout); + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_setup_socket); + switch (addr->sa_family) { case AF_INET: if (((struct sockaddr_in *)addr)->sin_port != htons(0)) xprt_set_bound(xprt); - INIT_DELAYED_WORK(&transport->connect_worker, - xs_tcp_setup_socket); xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); break; case AF_INET6: if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) xprt_set_bound(xprt); - INIT_DELAYED_WORK(&transport->connect_worker, - xs_tcp_setup_socket); xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); break; default: @@ -2823,12 +3444,99 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt->address_strings[RPC_DISPLAY_ADDR], xprt->address_strings[RPC_DISPLAY_PROTO]); + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +/** + * xs_setup_tcp_tls - Set up transport to use a TCP with TLS + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xs_tcp_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + + switch (args->xprtsec.policy) { + case RPC_XPRTSEC_TLS_ANON: + case RPC_XPRTSEC_TLS_X509: + xprt->xprtsec = args->xprtsec; + INIT_DELAYED_WORK(&transport->connect_worker, + xs_tcp_tls_setup_socket); + break; + default: + ret = ERR_PTR(-EACCES); + goto out_err; + } + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); if (try_module_get(THIS_MODULE)) return xprt; ret = ERR_PTR(-EINVAL); out_err: - xprt_free(xprt); + xs_xprt_free(xprt); return ret; } @@ -2845,15 +3553,6 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) struct svc_sock *bc_sock; struct rpc_xprt *ret; - if (args->bc_xprt->xpt_bc_xprt) { - /* - * This server connection already has a backchannel - * export; we can't create a new one, as we wouldn't be - * able to match replies based on xid any more. So, - * reuse the already-existing one: - */ - return args->bc_xprt->xpt_bc_xprt; - } xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, xprt_tcp_slot_table_entries); if (IS_ERR(xprt)) @@ -2861,7 +3560,7 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = IPPROTO_TCP; - xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->xprt_class = &xs_bc_tcp_transport; xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->timeout = &xs_tcp_default_timeout; @@ -2894,10 +3593,9 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) /* * Once we've associated a backchannel xprt with a connection, - * we want to keep it around as long as long as the connection - * lasts, in case we need to start using it for a backchannel - * again; this reference won't be dropped until bc_xprt is - * destroyed. + * we want to keep it around as long as the connection lasts, + * in case we need to start using it for a backchannel again; + * this reference won't be dropped until bc_xprt is destroyed. */ xprt_get(xprt); args->bc_xprt->xpt_bc_xprt = xprt; @@ -2912,13 +3610,15 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) */ xprt_set_connected(xprt); - if (try_module_get(THIS_MODULE)) return xprt; + + args->bc_xprt->xpt_bc_xprt = NULL; + args->bc_xprt->xpt_bc_xps = NULL; xprt_put(xprt); ret = ERR_PTR(-EINVAL); out_err: - xprt_free(xprt); + xs_xprt_free(xprt); return ret; } @@ -2928,6 +3628,7 @@ static struct xprt_class xs_local_transport = { .owner = THIS_MODULE, .ident = XPRT_TRANSPORT_LOCAL, .setup = xs_setup_local, + .netid = { "" }, }; static struct xprt_class xs_udp_transport = { @@ -2936,6 +3637,7 @@ static struct xprt_class xs_udp_transport = { .owner = THIS_MODULE, .ident = XPRT_TRANSPORT_UDP, .setup = xs_setup_udp, + .netid = { "udp", "udp6", "" }, }; static struct xprt_class xs_tcp_transport = { @@ -2944,6 +3646,16 @@ static struct xprt_class xs_tcp_transport = { .owner = THIS_MODULE, .ident = XPRT_TRANSPORT_TCP, .setup = xs_setup_tcp, + .netid = { "tcp", "tcp6", "" }, +}; + +static struct xprt_class xs_tcp_tls_transport = { + .list = LIST_HEAD_INIT(xs_tcp_tls_transport.list), + .name = "tcp-with-tls", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_TCP_TLS, + .setup = xs_setup_tcp_tls, + .netid = { "tcp", "tcp6", "" }, }; static struct xprt_class xs_bc_tcp_transport = { @@ -2952,6 +3664,7 @@ static struct xprt_class xs_bc_tcp_transport = { .owner = THIS_MODULE, .ident = XPRT_TRANSPORT_BC_TCP, .setup = xs_setup_bc_tcp, + .netid = { "" }, }; /** @@ -2960,14 +3673,13 @@ static struct xprt_class xs_bc_tcp_transport = { */ int init_socket_xprt(void) { -#ifdef RPC_DEBUG if (!sunrpc_table_header) - sunrpc_table_header = register_sysctl_table(sunrpc_table); -#endif + sunrpc_table_header = register_sysctl("sunrpc", xs_tunables_table); xprt_register_transport(&xs_local_transport); xprt_register_transport(&xs_udp_transport); xprt_register_transport(&xs_tcp_transport); + xprt_register_transport(&xs_tcp_tls_transport); xprt_register_transport(&xs_bc_tcp_transport); return 0; @@ -2979,35 +3691,18 @@ int init_socket_xprt(void) */ void cleanup_socket_xprt(void) { -#ifdef RPC_DEBUG if (sunrpc_table_header) { unregister_sysctl_table(sunrpc_table_header); sunrpc_table_header = NULL; } -#endif xprt_unregister_transport(&xs_local_transport); xprt_unregister_transport(&xs_udp_transport); xprt_unregister_transport(&xs_tcp_transport); + xprt_unregister_transport(&xs_tcp_tls_transport); xprt_unregister_transport(&xs_bc_tcp_transport); } -static int param_set_uint_minmax(const char *val, - const struct kernel_param *kp, - unsigned int min, unsigned int max) -{ - unsigned long num; - int ret; - - if (!val) - return -EINVAL; - ret = strict_strtoul(val, 0, &num); - if (ret == -EINVAL || num < min || num > max) - return -EINVAL; - *((unsigned int *)kp->arg) = num; - return 0; -} - static int param_set_portnr(const char *val, const struct kernel_param *kp) { return param_set_uint_minmax(val, kp, @@ -3015,7 +3710,7 @@ static int param_set_portnr(const char *val, const struct kernel_param *kp) RPC_MAX_RESVPORT); } -static struct kernel_param_ops param_ops_portnr = { +static const struct kernel_param_ops param_ops_portnr = { .set = param_set_portnr, .get = param_get_uint, }; @@ -3034,7 +3729,7 @@ static int param_set_slot_table_size(const char *val, RPC_MAX_SLOT_TABLE); } -static struct kernel_param_ops param_ops_slot_table_size = { +static const struct kernel_param_ops param_ops_slot_table_size = { .set = param_set_slot_table_size, .get = param_get_uint, }; @@ -3050,7 +3745,7 @@ static int param_set_max_slot_table_size(const char *val, RPC_MAX_SLOT_TABLE_LIMIT); } -static struct kernel_param_ops param_ops_max_slot_table_size = { +static const struct kernel_param_ops param_ops_max_slot_table_size = { .set = param_set_max_slot_table_size, .get = param_get_uint, }; @@ -3064,4 +3759,3 @@ module_param_named(tcp_max_slot_table_entries, xprt_max_tcp_slot_table_entries, max_slot_table_size, 0644); module_param_named(udp_slot_table_entries, xprt_udp_slot_table_entries, slot_table_size, 0644); - |
