diff options
Diffstat (limited to 'net/sunrpc')
42 files changed, 1258 insertions, 1150 deletions
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index 04534ea537c8..5a827afd8e3b 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -489,7 +489,7 @@ static unsigned long rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) { - return number_cred_unused * sysctl_vfs_cache_pressure / 100; + return number_cred_unused; } static void diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile index ad1736d93b76..452f67deebc6 100644 --- a/net/sunrpc/auth_gss/Makefile +++ b/net/sunrpc/auth_gss/Makefile @@ -5,7 +5,7 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o -auth_rpcgss-y := auth_gss.o gss_generic_token.o \ +auth_rpcgss-y := auth_gss.o \ gss_mech_switch.o svcauth_gss.o \ gss_rpc_upcall.o gss_rpc_xdr.o trace.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index c7af0220f82f..7b943fbafcc3 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -1545,6 +1545,7 @@ static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr) struct kvec iov; struct xdr_buf verf_buf; int status; + u32 seqno; /* Credential */ @@ -1556,15 +1557,16 @@ static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr) cred_len = p++; spin_lock(&ctx->gc_seq_lock); - req->rq_seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ; + seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ; + xprt_rqst_add_seqno(req, seqno); spin_unlock(&ctx->gc_seq_lock); - if (req->rq_seqno == MAXSEQ) + if (*req->rq_seqnos == MAXSEQ) goto expired; trace_rpcgss_seqno(task); *p++ = cpu_to_be32(RPC_GSS_VERSION); *p++ = cpu_to_be32(ctx->gc_proc); - *p++ = cpu_to_be32(req->rq_seqno); + *p++ = cpu_to_be32(*req->rq_seqnos); *p++ = cpu_to_be32(gss_cred->gc_service); p = xdr_encode_netobj(p, &ctx->gc_wire_ctx); *cred_len = cpu_to_be32((p - (cred_len + 1)) << 2); @@ -1678,17 +1680,31 @@ gss_refresh_null(struct rpc_task *task) return 0; } +static u32 +gss_validate_seqno_mic(struct gss_cl_ctx *ctx, u32 seqno, __be32 *seq, __be32 *p, u32 len) +{ + struct kvec iov; + struct xdr_buf verf_buf; + struct xdr_netobj mic; + + *seq = cpu_to_be32(seqno); + iov.iov_base = seq; + iov.iov_len = 4; + xdr_buf_from_iov(&iov, &verf_buf); + mic.data = (u8 *)p; + mic.len = len; + return gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); +} + static int gss_validate(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); __be32 *p, *seq = NULL; - struct kvec iov; - struct xdr_buf verf_buf; - struct xdr_netobj mic; u32 len, maj_stat; int status; + int i = 1; /* don't recheck the first item */ p = xdr_inline_decode(xdr, 2 * sizeof(*p)); if (!p) @@ -1705,13 +1721,10 @@ gss_validate(struct rpc_task *task, struct xdr_stream *xdr) seq = kmalloc(4, GFP_KERNEL); if (!seq) goto validate_failed; - *seq = cpu_to_be32(task->tk_rqstp->rq_seqno); - iov.iov_base = seq; - iov.iov_len = 4; - xdr_buf_from_iov(&iov, &verf_buf); - mic.data = (u8 *)p; - mic.len = len; - maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + maj_stat = gss_validate_seqno_mic(ctx, task->tk_rqstp->rq_seqnos[0], seq, p, len); + /* RFC 2203 5.3.3.1 - compute the checksum of each sequence number in the cache */ + while (unlikely(maj_stat == GSS_S_BAD_SIG && i < task->tk_rqstp->rq_seqno_count)) + maj_stat = gss_validate_seqno_mic(ctx, task->tk_rqstp->rq_seqnos[i++], seq, p, len); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); if (maj_stat) @@ -1750,7 +1763,7 @@ gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, if (!p) goto wrap_failed; integ_len = p++; - *p = cpu_to_be32(rqstp->rq_seqno); + *p = cpu_to_be32(*rqstp->rq_seqnos); if (rpcauth_wrap_req_encode(task, xdr)) goto wrap_failed; @@ -1847,7 +1860,7 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, if (!p) goto wrap_failed; opaque_len = p++; - *p = cpu_to_be32(rqstp->rq_seqno); + *p = cpu_to_be32(*rqstp->rq_seqnos); if (rpcauth_wrap_req_encode(task, xdr)) goto wrap_failed; @@ -1875,8 +1888,10 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); /* slack space should prevent this ever happening: */ - if (unlikely(snd_buf->len > snd_buf->buflen)) + if (unlikely(snd_buf->len > snd_buf->buflen)) { + status = -EIO; goto wrap_failed; + } /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was * done anyway, so it's safe to put the request on the wire: */ if (maj_stat == GSS_S_CONTEXT_EXPIRED) @@ -1999,7 +2014,7 @@ gss_unwrap_resp_integ(struct rpc_task *task, struct rpc_cred *cred, offset = rcv_buf->len - xdr_stream_remaining(xdr); if (xdr_stream_decode_u32(xdr, &seqno)) goto unwrap_failed; - if (seqno != rqstp->rq_seqno) + if (seqno != *rqstp->rq_seqnos) goto bad_seqno; if (xdr_buf_subsegment(rcv_buf, &gss_data, offset, len)) goto unwrap_failed; @@ -2043,7 +2058,7 @@ unwrap_failed: trace_rpcgss_unwrap_failed(task); goto out; bad_seqno: - trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, seqno); + trace_rpcgss_bad_seqno(task, *rqstp->rq_seqnos, seqno); goto out; bad_mic: trace_rpcgss_verify_mic(task, maj_stat); @@ -2075,7 +2090,7 @@ gss_unwrap_resp_priv(struct rpc_task *task, struct rpc_cred *cred, if (maj_stat != GSS_S_COMPLETE) goto bad_unwrap; /* gss_unwrap decrypted the sequence number */ - if (be32_to_cpup(p++) != rqstp->rq_seqno) + if (be32_to_cpup(p++) != *rqstp->rq_seqnos) goto bad_seqno; /* gss_unwrap redacts the opaque blob from the head iovec. @@ -2091,7 +2106,7 @@ unwrap_failed: trace_rpcgss_unwrap_failed(task); return -EIO; bad_seqno: - trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(--p)); + trace_rpcgss_bad_seqno(task, *rqstp->rq_seqnos, be32_to_cpup(--p)); return -EIO; bad_unwrap: trace_rpcgss_unwrap(task, maj_stat); @@ -2116,14 +2131,14 @@ gss_xmit_need_reencode(struct rpc_task *task) if (!ctx) goto out; - if (gss_seq_is_newer(req->rq_seqno, READ_ONCE(ctx->gc_seq))) + if (gss_seq_is_newer(*req->rq_seqnos, READ_ONCE(ctx->gc_seq))) goto out_ctx; seq_xmit = READ_ONCE(ctx->gc_seq_xmit); - while (gss_seq_is_newer(req->rq_seqno, seq_xmit)) { + while (gss_seq_is_newer(*req->rq_seqnos, seq_xmit)) { u32 tmp = seq_xmit; - seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, req->rq_seqno); + seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, *req->rq_seqnos); if (seq_xmit == tmp) { ret = false; goto out_ctx; @@ -2132,7 +2147,7 @@ gss_xmit_need_reencode(struct rpc_task *task) win = ctx->gc_win; if (win > 0) - ret = !gss_seq_is_newer(req->rq_seqno, seq_xmit - win); + ret = !gss_seq_is_newer(*req->rq_seqnos, seq_xmit - win); out_ctx: gss_put_ctx(ctx); diff --git a/net/sunrpc/auth_gss/auth_gss_internal.h b/net/sunrpc/auth_gss/auth_gss_internal.h index c53b329092d4..4ebc1b7043d9 100644 --- a/net/sunrpc/auth_gss/auth_gss_internal.h +++ b/net/sunrpc/auth_gss/auth_gss_internal.h @@ -23,7 +23,7 @@ simple_get_bytes(const void *p, const void *end, void *res, size_t len) } static inline const void * -simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) +simple_get_netobj_noprof(const void *p, const void *end, struct xdr_netobj *dest) { const void *q; unsigned int len; @@ -35,7 +35,7 @@ simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) if (unlikely(q > end || q < p)) return ERR_PTR(-EFAULT); if (len) { - dest->data = kmemdup(p, len, GFP_KERNEL); + dest->data = kmemdup_noprof(p, len, GFP_KERNEL); if (unlikely(dest->data == NULL)) return ERR_PTR(-ENOMEM); } else @@ -43,3 +43,5 @@ simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) dest->len = len; return q; } + +#define simple_get_netobj(...) alloc_hooks(simple_get_netobj_noprof(__VA_ARGS__)) diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c deleted file mode 100644 index 4a4082bb22ad..000000000000 --- a/net/sunrpc/auth_gss/gss_generic_token.c +++ /dev/null @@ -1,231 +0,0 @@ -/* - * linux/net/sunrpc/gss_generic_token.c - * - * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/generic/util_token.c - * - * Copyright (c) 2000 The Regents of the University of Michigan. - * All rights reserved. - * - * Andy Adamson <andros@umich.edu> - */ - -/* - * Copyright 1993 by OpenVision Technologies, Inc. - * - * Permission to use, copy, modify, distribute, and sell this software - * and its documentation for any purpose is hereby granted without fee, - * provided that the above copyright notice appears in all copies and - * that both that copyright notice and this permission notice appear in - * supporting documentation, and that the name of OpenVision not be used - * in advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. OpenVision makes no - * representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. - * - * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, - * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO - * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR - * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF - * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR - * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR - * PERFORMANCE OF THIS SOFTWARE. - */ - -#include <linux/types.h> -#include <linux/module.h> -#include <linux/string.h> -#include <linux/sunrpc/sched.h> -#include <linux/sunrpc/gss_asn1.h> - - -#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif - - -/* TWRITE_STR from gssapiP_generic.h */ -#define TWRITE_STR(ptr, str, len) \ - memcpy((ptr), (char *) (str), (len)); \ - (ptr) += (len); - -/* XXXX this code currently makes the assumption that a mech oid will - never be longer than 127 bytes. This assumption is not inherent in - the interfaces, so the code can be fixed if the OSI namespace - balloons unexpectedly. */ - -/* Each token looks like this: - -0x60 tag for APPLICATION 0, SEQUENCE - (constructed, definite-length) - <length> possible multiple bytes, need to parse/generate - 0x06 tag for OBJECT IDENTIFIER - <moid_length> compile-time constant string (assume 1 byte) - <moid_bytes> compile-time constant string - <inner_bytes> the ANY containing the application token - bytes 0,1 are the token type - bytes 2,n are the token data - -For the purposes of this abstraction, the token "header" consists of -the sequence tag and length octets, the mech OID DER encoding, and the -first two inner bytes, which indicate the token type. The token -"body" consists of everything else. - -*/ - -static int -der_length_size( int length) -{ - if (length < (1<<7)) - return 1; - else if (length < (1<<8)) - return 2; -#if (SIZEOF_INT == 2) - else - return 3; -#else - else if (length < (1<<16)) - return 3; - else if (length < (1<<24)) - return 4; - else - return 5; -#endif -} - -static void -der_write_length(unsigned char **buf, int length) -{ - if (length < (1<<7)) { - *(*buf)++ = (unsigned char) length; - } else { - *(*buf)++ = (unsigned char) (der_length_size(length)+127); -#if (SIZEOF_INT > 2) - if (length >= (1<<24)) - *(*buf)++ = (unsigned char) (length>>24); - if (length >= (1<<16)) - *(*buf)++ = (unsigned char) ((length>>16)&0xff); -#endif - if (length >= (1<<8)) - *(*buf)++ = (unsigned char) ((length>>8)&0xff); - *(*buf)++ = (unsigned char) (length&0xff); - } -} - -/* returns decoded length, or < 0 on failure. Advances buf and - decrements bufsize */ - -static int -der_read_length(unsigned char **buf, int *bufsize) -{ - unsigned char sf; - int ret; - - if (*bufsize < 1) - return -1; - sf = *(*buf)++; - (*bufsize)--; - if (sf & 0x80) { - if ((sf &= 0x7f) > ((*bufsize)-1)) - return -1; - if (sf > SIZEOF_INT) - return -1; - ret = 0; - for (; sf; sf--) { - ret = (ret<<8) + (*(*buf)++); - (*bufsize)--; - } - } else { - ret = sf; - } - - return ret; -} - -/* returns the length of a token, given the mech oid and the body size */ - -int -g_token_size(struct xdr_netobj *mech, unsigned int body_size) -{ - /* set body_size to sequence contents size */ - body_size += 2 + (int) mech->len; /* NEED overflow check */ - return 1 + der_length_size(body_size) + body_size; -} - -EXPORT_SYMBOL_GPL(g_token_size); - -/* fills in a buffer with the token header. The buffer is assumed to - be the right size. buf is advanced past the token header */ - -void -g_make_token_header(struct xdr_netobj *mech, int body_size, unsigned char **buf) -{ - *(*buf)++ = 0x60; - der_write_length(buf, 2 + mech->len + body_size); - *(*buf)++ = 0x06; - *(*buf)++ = (unsigned char) mech->len; - TWRITE_STR(*buf, mech->data, ((int) mech->len)); -} - -EXPORT_SYMBOL_GPL(g_make_token_header); - -/* - * Given a buffer containing a token, reads and verifies the token, - * leaving buf advanced past the token header, and setting body_size - * to the number of remaining bytes. Returns 0 on success, - * G_BAD_TOK_HEADER for a variety of errors, and G_WRONG_MECH if the - * mechanism in the token does not match the mech argument. buf and - * *body_size are left unmodified on error. - */ -u32 -g_verify_token_header(struct xdr_netobj *mech, int *body_size, - unsigned char **buf_in, int toksize) -{ - unsigned char *buf = *buf_in; - int seqsize; - struct xdr_netobj toid; - int ret = 0; - - if ((toksize-=1) < 0) - return G_BAD_TOK_HEADER; - if (*buf++ != 0x60) - return G_BAD_TOK_HEADER; - - if ((seqsize = der_read_length(&buf, &toksize)) < 0) - return G_BAD_TOK_HEADER; - - if (seqsize != toksize) - return G_BAD_TOK_HEADER; - - if ((toksize-=1) < 0) - return G_BAD_TOK_HEADER; - if (*buf++ != 0x06) - return G_BAD_TOK_HEADER; - - if ((toksize-=1) < 0) - return G_BAD_TOK_HEADER; - toid.len = *buf++; - - if ((toksize-=toid.len) < 0) - return G_BAD_TOK_HEADER; - toid.data = buf; - buf+=toid.len; - - if (! g_OID_equal(&toid, mech)) - ret = G_WRONG_MECH; - - /* G_WRONG_MECH is not returned immediately because it's more important - to return G_BAD_TOK_HEADER if the token header is in fact bad */ - - if ((toksize-=2) < 0) - return G_BAD_TOK_HEADER; - - if (ret) - return ret; - - *buf_in = buf; - *body_size = toksize; - - return ret; -} - -EXPORT_SYMBOL_GPL(g_verify_token_header); diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c index b2c1b683a88e..8f2d65c1e831 100644 --- a/net/sunrpc/auth_gss/gss_krb5_crypto.c +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -138,60 +138,6 @@ out: return ret; } -/** - * krb5_decrypt - simple decryption of an RPCSEC GSS payload - * @tfm: initialized cipher transform - * @iv: pointer to an IV - * @in: ciphertext to decrypt - * @out: OUT: plaintext - * @length: length of input and output buffers, in bytes - * - * @iv may be NULL to force the use of an all-zero IV. - * The buffer containing the IV must be as large as the - * cipher's ivsize. - * - * Return values: - * %0: @in successfully decrypted into @out - * negative errno: @in not decrypted - */ -u32 -krb5_decrypt( - struct crypto_sync_skcipher *tfm, - void * iv, - void * in, - void * out, - int length) -{ - u32 ret = -EINVAL; - struct scatterlist sg[1]; - u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; - SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - - if (length % crypto_sync_skcipher_blocksize(tfm) != 0) - goto out; - - if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { - dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", - crypto_sync_skcipher_ivsize(tfm)); - goto out; - } - if (iv) - memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm)); - - memcpy(out, in, length); - sg_init_one(sg, out, length); - - skcipher_request_set_sync_tfm(req, tfm); - skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, sg, sg, length, local_iv); - - ret = crypto_skcipher_decrypt(req); - skcipher_request_zero(req); -out: - dprintk("RPC: gss_k5decrypt returns %d\n",ret); - return ret; -} - static int checksummer(struct scatterlist *sg, void *data) { @@ -202,96 +148,6 @@ checksummer(struct scatterlist *sg, void *data) return crypto_ahash_update(req); } -/* - * checksum the plaintext data and hdrlen bytes of the token header - * The checksum is performed over the first 8 bytes of the - * gss token header and then over the data body - */ -u32 -make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen, - struct xdr_buf *body, int body_offset, u8 *cksumkey, - unsigned int usage, struct xdr_netobj *cksumout) -{ - struct crypto_ahash *tfm; - struct ahash_request *req; - struct scatterlist sg[1]; - int err = -1; - u8 *checksumdata; - unsigned int checksumlen; - - if (cksumout->len < kctx->gk5e->cksumlength) { - dprintk("%s: checksum buffer length, %u, too small for %s\n", - __func__, cksumout->len, kctx->gk5e->name); - return GSS_S_FAILURE; - } - - checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_KERNEL); - if (checksumdata == NULL) - return GSS_S_FAILURE; - - tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); - if (IS_ERR(tfm)) - goto out_free_cksum; - - req = ahash_request_alloc(tfm, GFP_KERNEL); - if (!req) - goto out_free_ahash; - - ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); - - checksumlen = crypto_ahash_digestsize(tfm); - - if (cksumkey != NULL) { - err = crypto_ahash_setkey(tfm, cksumkey, - kctx->gk5e->keylength); - if (err) - goto out; - } - - err = crypto_ahash_init(req); - if (err) - goto out; - sg_init_one(sg, header, hdrlen); - ahash_request_set_crypt(req, sg, NULL, hdrlen); - err = crypto_ahash_update(req); - if (err) - goto out; - err = xdr_process_buf(body, body_offset, body->len - body_offset, - checksummer, req); - if (err) - goto out; - ahash_request_set_crypt(req, NULL, checksumdata, 0); - err = crypto_ahash_final(req); - if (err) - goto out; - - switch (kctx->gk5e->ctype) { - case CKSUMTYPE_RSA_MD5: - err = krb5_encrypt(kctx->seq, NULL, checksumdata, - checksumdata, checksumlen); - if (err) - goto out; - memcpy(cksumout->data, - checksumdata + checksumlen - kctx->gk5e->cksumlength, - kctx->gk5e->cksumlength); - break; - case CKSUMTYPE_HMAC_SHA1_DES3: - memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); - break; - default: - BUG(); - break; - } - cksumout->len = kctx->gk5e->cksumlength; -out: - ahash_request_free(req); -out_free_ahash: - crypto_free_ahash(tfm); -out_free_cksum: - kfree(checksumdata); - return err ? GSS_S_FAILURE : 0; -} - /** * gss_krb5_checksum - Compute the MAC for a GSS Wrap or MIC token * @tfm: an initialized hash transform @@ -442,35 +298,6 @@ encryptor(struct scatterlist *sg, void *data) return 0; } -int -gss_encrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf, - int offset, struct page **pages) -{ - int ret; - struct encryptor_desc desc; - SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - - BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0); - - skcipher_request_set_sync_tfm(req, tfm); - skcipher_request_set_callback(req, 0, NULL, NULL); - - memset(desc.iv, 0, sizeof(desc.iv)); - desc.req = req; - desc.pos = offset; - desc.outbuf = buf; - desc.pages = pages; - desc.fragno = 0; - desc.fraglen = 0; - - sg_init_table(desc.infrags, 4); - sg_init_table(desc.outfrags, 4); - - ret = xdr_process_buf(buf, offset, buf->len - offset, encryptor, &desc); - skcipher_request_zero(req); - return ret; -} - struct decryptor_desc { u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; struct skcipher_request *req; @@ -525,32 +352,6 @@ decryptor(struct scatterlist *sg, void *data) return 0; } -int -gss_decrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf, - int offset) -{ - int ret; - struct decryptor_desc desc; - SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - - /* XXXJBF: */ - BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0); - - skcipher_request_set_sync_tfm(req, tfm); - skcipher_request_set_callback(req, 0, NULL, NULL); - - memset(desc.iv, 0, sizeof(desc.iv)); - desc.req = req; - desc.fragno = 0; - desc.fraglen = 0; - - sg_init_table(desc.frags, 4); - - ret = xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc); - skcipher_request_zero(req); - return ret; -} - /* * This function makes the assumption that it was ultimately called * from gss_wrap(). @@ -921,8 +722,6 @@ out_err: * Caller provides the truncation length of the output token (h) in * cksumout.len. * - * Note that for RPCSEC, the "initial cipher state" is always all zeroes. - * * Return values: * %GSS_S_COMPLETE: Digest computed, @cksumout filled in * %GSS_S_FAILURE: Call failed @@ -933,19 +732,22 @@ u32 krb5_etm_checksum(struct crypto_sync_skcipher *cipher, int body_offset, struct xdr_netobj *cksumout) { unsigned int ivsize = crypto_sync_skcipher_ivsize(cipher); - static const u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; struct ahash_request *req; struct scatterlist sg[1]; + u8 *iv, *checksumdata; int err = -ENOMEM; - u8 *checksumdata; checksumdata = kmalloc(crypto_ahash_digestsize(tfm), GFP_KERNEL); if (!checksumdata) return GSS_S_FAILURE; + /* For RPCSEC, the "initial cipher state" is always all zeroes. */ + iv = kzalloc(ivsize, GFP_KERNEL); + if (!iv) + goto out_free_mem; req = ahash_request_alloc(tfm, GFP_KERNEL); if (!req) - goto out_free_cksumdata; + goto out_free_mem; ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); err = crypto_ahash_init(req); if (err) @@ -969,7 +771,8 @@ u32 krb5_etm_checksum(struct crypto_sync_skcipher *cipher, out_free_ahash: ahash_request_free(req); -out_free_cksumdata: +out_free_mem: + kfree(iv); kfree_sensitive(checksumdata); return err ? GSS_S_FAILURE : GSS_S_COMPLETE; } diff --git a/net/sunrpc/auth_gss/gss_krb5_internal.h b/net/sunrpc/auth_gss/gss_krb5_internal.h index 3afd4065bf3d..8769e9e705bf 100644 --- a/net/sunrpc/auth_gss/gss_krb5_internal.h +++ b/net/sunrpc/auth_gss/gss_krb5_internal.h @@ -155,10 +155,6 @@ static inline int krb5_derive_key(struct krb5_ctx *kctx, void krb5_make_confounder(u8 *p, int conflen); -u32 make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen, - struct xdr_buf *body, int body_offset, u8 *cksumkey, - unsigned int usage, struct xdr_netobj *cksumout); - u32 gss_krb5_checksum(struct crypto_ahash *tfm, char *header, int hdrlen, const struct xdr_buf *body, int body_offset, struct xdr_netobj *cksumout); @@ -166,19 +162,9 @@ u32 gss_krb5_checksum(struct crypto_ahash *tfm, char *header, int hdrlen, u32 krb5_encrypt(struct crypto_sync_skcipher *key, void *iv, void *in, void *out, int length); -u32 krb5_decrypt(struct crypto_sync_skcipher *key, void *iv, void *in, - void *out, int length); - int xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen); -int gss_encrypt_xdr_buf(struct crypto_sync_skcipher *tfm, - struct xdr_buf *outbuf, int offset, - struct page **pages); - -int gss_decrypt_xdr_buf(struct crypto_sync_skcipher *tfm, - struct xdr_buf *inbuf, int offset); - u32 gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, struct page **pages); diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c index 06d8ee0db000..4eb19c3a54c7 100644 --- a/net/sunrpc/auth_gss/gss_krb5_keys.c +++ b/net/sunrpc/auth_gss/gss_krb5_keys.c @@ -168,7 +168,7 @@ static int krb5_DK(const struct gss_krb5_enctype *gk5e, goto err_return; blocksize = crypto_sync_skcipher_blocksize(cipher); if (crypto_sync_skcipher_setkey(cipher, inkey->data, inkey->len)) - goto err_return; + goto err_free_cipher; ret = -ENOMEM; inblockdata = kmalloc(blocksize, gfp_mask); diff --git a/net/sunrpc/auth_gss/gss_krb5_test.c b/net/sunrpc/auth_gss/gss_krb5_test.c index 85625e3f3814..a5bff02cd7ba 100644 --- a/net/sunrpc/auth_gss/gss_krb5_test.c +++ b/net/sunrpc/auth_gss/gss_krb5_test.c @@ -17,7 +17,7 @@ #include "gss_krb5_internal.h" -MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING); +MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING"); struct gss_krb5_test_param { const char *desc; diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c index fae632da1058..c84d0cf61980 100644 --- a/net/sunrpc/auth_gss/gss_mech_switch.c +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -13,7 +13,6 @@ #include <linux/module.h> #include <linux/oid_registry.h> #include <linux/sunrpc/msg_prot.h> -#include <linux/sunrpc/gss_asn1.h> #include <linux/sunrpc/auth_gss.h> #include <linux/sunrpc/svcauth_gss.h> #include <linux/sunrpc/gss_err.h> diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c index 24de94184700..73a90ad873fb 100644 --- a/net/sunrpc/auth_gss/svcauth_gss.c +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -1033,17 +1033,11 @@ null_verifier: static void gss_free_in_token_pages(struct gssp_in_token *in_token) { - u32 inlen; int i; i = 0; - inlen = in_token->page_len; - while (inlen) { - if (in_token->pages[i]) - put_page(in_token->pages[i]); - inlen -= inlen > PAGE_SIZE ? PAGE_SIZE : inlen; - } - + while (in_token->pages[i]) + put_page(in_token->pages[i++]); kfree(in_token->pages); in_token->pages = NULL; } @@ -1075,7 +1069,7 @@ static int gss_read_proxy_verf(struct svc_rqst *rqstp, goto out_denied_free; pages = DIV_ROUND_UP(inlen, PAGE_SIZE); - in_token->pages = kcalloc(pages, sizeof(struct page *), GFP_KERNEL); + in_token->pages = kcalloc(pages + 1, sizeof(struct page *), GFP_KERNEL); if (!in_token->pages) goto out_denied_free; in_token->page_base = 0; diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c index 95ff74706104..131090f31e6a 100644 --- a/net/sunrpc/cache.c +++ b/net/sunrpc/cache.c @@ -135,6 +135,8 @@ static struct cache_head *sunrpc_cache_add_entry(struct cache_detail *detail, hlist_add_head_rcu(&new->cache_list, head); detail->entries++; + if (detail->nextcheck > new->expiry_time) + detail->nextcheck = new->expiry_time + 1; cache_get(new); spin_unlock(&detail->hash_lock); @@ -281,21 +283,7 @@ static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h return rv; } -/* - * This is the generic cache management routine for all - * the authentication caches. - * It checks the currency of a cache item and will (later) - * initiate an upcall to fill it if needed. - * - * - * Returns 0 if the cache_head can be used, or cache_puts it and returns - * -EAGAIN if upcall is pending and request has been queued - * -ETIMEDOUT if upcall failed or request could not be queue or - * upcall completed but item is still invalid (implying that - * the cache item has been replaced with a newer one). - * -ENOENT if cache entry was negative - */ -int cache_check(struct cache_detail *detail, +int cache_check_rcu(struct cache_detail *detail, struct cache_head *h, struct cache_req *rqstp) { int rv; @@ -336,6 +324,31 @@ int cache_check(struct cache_detail *detail, rv = -ETIMEDOUT; } } + + return rv; +} +EXPORT_SYMBOL_GPL(cache_check_rcu); + +/* + * This is the generic cache management routine for all + * the authentication caches. + * It checks the currency of a cache item and will (later) + * initiate an upcall to fill it if needed. + * + * + * Returns 0 if the cache_head can be used, or cache_puts it and returns + * -EAGAIN if upcall is pending and request has been queued + * -ETIMEDOUT if upcall failed or request could not be queue or + * upcall completed but item is still invalid (implying that + * the cache item has been replaced with a newer one). + * -ENOENT if cache entry was negative + */ +int cache_check(struct cache_detail *detail, + struct cache_head *h, struct cache_req *rqstp) +{ + int rv; + + rv = cache_check_rcu(detail, h, rqstp); if (rv) cache_put(h, detail); return rv; @@ -451,24 +464,21 @@ static int cache_clean(void) } } + spin_lock(¤t_detail->hash_lock); + /* find a non-empty bucket in the table */ - while (current_detail && - current_index < current_detail->hash_size && + while (current_index < current_detail->hash_size && hlist_empty(¤t_detail->hash_table[current_index])) current_index++; /* find a cleanable entry in the bucket and clean it, or set to next bucket */ - - if (current_detail && current_index < current_detail->hash_size) { + if (current_index < current_detail->hash_size) { struct cache_head *ch = NULL; struct cache_detail *d; struct hlist_head *head; struct hlist_node *tmp; - spin_lock(¤t_detail->hash_lock); - /* Ok, now to clean this strand */ - head = ¤t_detail->hash_table[current_index]; hlist_for_each_entry_safe(ch, tmp, head, cache_list) { if (current_detail->nextcheck > ch->expiry_time) @@ -489,8 +499,10 @@ static int cache_clean(void) spin_unlock(&cache_list_lock); if (ch) sunrpc_end_cache_remove_entry(ch, d); - } else + } else { + spin_unlock(¤t_detail->hash_lock); spin_unlock(&cache_list_lock); + } return rv; } @@ -731,11 +743,10 @@ static bool cache_defer_req(struct cache_req *req, struct cache_head *item) static void cache_revisit_request(struct cache_head *item) { struct cache_deferred_req *dreq; - struct list_head pending; struct hlist_node *tmp; int hash = DFR_HASH(item); + LIST_HEAD(pending); - INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash) @@ -756,10 +767,8 @@ static void cache_revisit_request(struct cache_head *item) void cache_clean_deferred(void *owner) { struct cache_deferred_req *dreq, *tmp; - struct list_head pending; + LIST_HEAD(pending); - - INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { @@ -1085,9 +1094,8 @@ static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch) { struct cache_queue *cq, *tmp; struct cache_request *cr; - struct list_head dequeued; + LIST_HEAD(dequeued); - INIT_LIST_HEAD(&dequeued); spin_lock(&queue_lock); list_for_each_entry_safe(cq, tmp, &detail->queue, list) if (!cq->reader) { @@ -1431,15 +1439,11 @@ static int c_show(struct seq_file *m, void *p) seq_printf(m, "# expiry=%lld refcnt=%d flags=%lx\n", convert_to_wallclock(cp->expiry_time), kref_read(&cp->ref), cp->flags); - cache_get(cp); - if (cache_check(cd, cp, NULL)) - /* cache_check does a cache_put on failure */ + + if (cache_check_rcu(cd, cp, NULL)) + seq_puts(m, "# "); + else if (cache_is_expired(cd, cp)) seq_puts(m, "# "); - else { - if (cache_is_expired(cd, cp)) - seq_puts(m, "# "); - cache_put(cp, cd); - } return cd->cache_show(m, cd, cp); } @@ -1596,7 +1600,6 @@ static int cache_release_procfs(struct inode *inode, struct file *filp) } static const struct proc_ops cache_channel_proc_ops = { - .proc_lseek = no_llseek, .proc_read = cache_read_procfs, .proc_write = cache_write_procfs, .proc_poll = cache_poll_procfs, @@ -1662,7 +1665,6 @@ static const struct proc_ops cache_flush_proc_ops = { .proc_read = read_flush_procfs, .proc_write = write_flush_procfs, .proc_release = release_flush_procfs, - .proc_lseek = no_llseek, }; static void remove_cache_proc_entries(struct cache_detail *cd) @@ -1673,12 +1675,14 @@ static void remove_cache_proc_entries(struct cache_detail *cd) } } -#ifdef CONFIG_PROC_FS static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) { struct proc_dir_entry *p; struct sunrpc_net *sn; + if (!IS_ENABLED(CONFIG_PROC_FS)) + return 0; + sn = net_generic(net, sunrpc_net_id); cd->procfs = proc_mkdir(cd->name, sn->proc_net_rpc); if (cd->procfs == NULL) @@ -1706,12 +1710,6 @@ out_nomem: remove_cache_proc_entries(cd); return -ENOMEM; } -#else /* CONFIG_PROC_FS */ -static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) -{ - return 0; -} -#endif void __init cache_initialize(void) { @@ -1815,7 +1813,6 @@ static int cache_release_pipefs(struct inode *inode, struct file *filp) const struct file_operations cache_file_operations_pipefs = { .owner = THIS_MODULE, - .llseek = no_llseek, .read = cache_read_pipefs, .write = cache_write_pipefs, .poll = cache_poll_pipefs, @@ -1881,7 +1878,6 @@ const struct file_operations cache_flush_operations_pipefs = { .read = read_flush_pipefs, .write = write_flush_pipefs, .release = release_flush_pipefs, - .llseek = no_llseek, }; int sunrpc_cache_register_pipefs(struct dentry *parent, diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index 28f3749f6dc6..21426c3049d3 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -48,13 +48,8 @@ # define RPCDBG_FACILITY RPCDBG_CALL #endif -/* - * All RPC clients are linked into this list - */ - static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); - static void call_start(struct rpc_task *task); static void call_reserve(struct rpc_task *task); static void call_reserveresult(struct rpc_task *task); @@ -275,9 +270,6 @@ static struct rpc_xprt *rpc_clnt_set_transport(struct rpc_clnt *clnt, old = rcu_dereference_protected(clnt->cl_xprt, lockdep_is_held(&clnt->cl_lock)); - if (!xprt_bound(xprt)) - clnt->cl_autobind = 1; - clnt->cl_timeout = timeout; rcu_assign_pointer(clnt->cl_xprt, xprt); spin_unlock(&clnt->cl_lock); @@ -517,6 +509,8 @@ static struct rpc_clnt *rpc_create_xprt(struct rpc_create_args *args, clnt->cl_discrtry = 1; if (!(args->flags & RPC_CLNT_CREATE_QUIET)) clnt->cl_chatty = 1; + if (args->flags & RPC_CLNT_CREATE_NETUNREACH_FATAL) + clnt->cl_netunreach_fatal = 1; return clnt; } @@ -546,7 +540,7 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) .connect_timeout = args->connect_timeout, .reconnect_timeout = args->reconnect_timeout, }; - char servername[48]; + char servername[RPC_MAXNETNAMELEN]; struct rpc_clnt *clnt; int i; @@ -667,6 +661,7 @@ static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args, new->cl_noretranstimeo = clnt->cl_noretranstimeo; new->cl_discrtry = clnt->cl_discrtry; new->cl_chatty = clnt->cl_chatty; + new->cl_netunreach_fatal = clnt->cl_netunreach_fatal; new->cl_principal = clnt->cl_principal; new->cl_max_connect = clnt->cl_max_connect; return new; @@ -963,12 +958,17 @@ void rpc_shutdown_client(struct rpc_clnt *clnt) trace_rpc_clnt_shutdown(clnt); + clnt->cl_shutdown = 1; while (!list_empty(&clnt->cl_tasks)) { rpc_killall_tasks(clnt); wait_event_timeout(destroy_wait, list_empty(&clnt->cl_tasks), 1*HZ); } + /* wait for tasks still in workqueue or waitqueue */ + wait_event_timeout(destroy_wait, + atomic_read(&clnt->cl_task_count) == 0, 1 * HZ); + rpc_release_client(clnt); } EXPORT_SYMBOL_GPL(rpc_shutdown_client); @@ -1071,6 +1071,7 @@ struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old, .authflavor = old->cl_auth->au_flavor, .cred = old->cl_cred, .stats = old->cl_stats, + .timeout = old->cl_timeout, }; struct rpc_clnt *clnt; int err; @@ -1143,6 +1144,7 @@ void rpc_task_release_client(struct rpc_task *task) list_del(&task->tk_task); spin_unlock(&clnt->cl_lock); task->tk_client = NULL; + atomic_dec(&clnt->cl_task_count); rpc_release_client(clnt); } @@ -1193,10 +1195,9 @@ void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt) task->tk_flags |= RPC_TASK_TIMEOUT; if (clnt->cl_noretranstimeo) task->tk_flags |= RPC_TASK_NO_RETRANS_TIMEOUT; - /* Add to the client's list of all tasks */ - spin_lock(&clnt->cl_lock); - list_add_tail(&task->tk_task, &clnt->cl_tasks); - spin_unlock(&clnt->cl_lock); + if (clnt->cl_netunreach_fatal) + task->tk_flags |= RPC_TASK_NETUNREACH_FATAL; + atomic_inc(&clnt->cl_task_count); } static void @@ -1791,9 +1792,14 @@ call_reserveresult(struct rpc_task *task) if (status >= 0) { if (task->tk_rqstp) { task->tk_action = call_refresh; + + /* Add to the client's list of all tasks */ + spin_lock(&task->tk_client->cl_lock); + if (list_empty(&task->tk_task)) + list_add_tail(&task->tk_task, &task->tk_client->cl_tasks); + spin_unlock(&task->tk_client->cl_lock); return; } - rpc_call_rpcerror(task, -EIO); return; } @@ -1858,13 +1864,13 @@ call_refreshresult(struct rpc_task *task) fallthrough; case -EAGAIN: status = -EACCES; - fallthrough; - case -EKEYEXPIRED: if (!task->tk_cred_retry) break; task->tk_cred_retry--; trace_rpc_retry_refresh_status(task); return; + case -EKEYEXPIRED: + break; case -ENOMEM: rpc_delay(task, HZ >> 4); return; @@ -1892,12 +1898,6 @@ call_allocate(struct rpc_task *task) if (req->rq_buffer) return; - if (proc->p_proc != 0) { - BUG_ON(proc->p_arglen == 0); - if (proc->p_decode != NULL) - BUG_ON(proc->p_replen == 0); - } - /* * Calculate the size (in quads) of the RPC call * and reply headers, and convert both values @@ -2104,14 +2104,17 @@ call_bind_status(struct rpc_task *task) case -EPROTONOSUPPORT: trace_rpcb_bind_version_err(task); goto retry_timeout; + case -ENETDOWN: + case -ENETUNREACH: + if (task->tk_flags & RPC_TASK_NETUNREACH_FATAL) + break; + fallthrough; case -ECONNREFUSED: /* connection problems */ case -ECONNRESET: case -ECONNABORTED: case -ENOTCONN: case -EHOSTDOWN: - case -ENETDOWN: case -EHOSTUNREACH: - case -ENETUNREACH: case -EPIPE: trace_rpcb_unreachable_err(task); if (!RPC_IS_SOFTCONN(task)) { @@ -2193,19 +2196,22 @@ call_connect_status(struct rpc_task *task) task->tk_status = 0; switch (status) { + case -ENETDOWN: + case -ENETUNREACH: + if (task->tk_flags & RPC_TASK_NETUNREACH_FATAL) + break; + fallthrough; case -ECONNREFUSED: case -ECONNRESET: /* A positive refusal suggests a rebind is needed. */ - if (RPC_IS_SOFTCONN(task)) - break; if (clnt->cl_autobind) { rpc_force_rebind(clnt); + if (RPC_IS_SOFTCONN(task)) + break; goto out_retry; } fallthrough; case -ECONNABORTED: - case -ENETDOWN: - case -ENETUNREACH: case -EHOSTUNREACH: case -EPIPE: case -EPROTO: @@ -2325,12 +2331,13 @@ call_transmit_status(struct rpc_task *task) task->tk_action = call_transmit; task->tk_status = 0; break; - case -ECONNREFUSED: case -EHOSTDOWN: case -ENETDOWN: case -EHOSTUNREACH: case -ENETUNREACH: case -EPERM: + break; + case -ECONNREFUSED: if (RPC_IS_SOFTCONN(task)) { if (!task->tk_msg.rpc_proc->p_proc) trace_xprt_ping(task->tk_xprt, @@ -2456,10 +2463,13 @@ call_status(struct rpc_task *task) trace_rpc_call_status(task); task->tk_status = 0; switch(status) { - case -EHOSTDOWN: case -ENETDOWN: - case -EHOSTUNREACH: case -ENETUNREACH: + if (task->tk_flags & RPC_TASK_NETUNREACH_FATAL) + goto out_exit; + fallthrough; + case -EHOSTDOWN: + case -EHOSTUNREACH: case -EPERM: if (RPC_IS_SOFTCONN(task)) goto out_exit; @@ -2698,8 +2708,19 @@ rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr) goto out_msg_denied; error = rpcauth_checkverf(task, xdr); - if (error) + if (error) { + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + if (!test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) { + rpcauth_invalcred(task); + if (!task->tk_cred_retry) + goto out_err; + task->tk_cred_retry--; + trace_rpc__stale_creds(task); + return -EKEYREJECTED; + } goto out_verifier; + } p = xdr_inline_decode(xdr, sizeof(*p)); if (!p) @@ -2750,8 +2771,13 @@ out_verifier: case -EPROTONOSUPPORT: goto out_err; case -EACCES: - /* Re-encode with a fresh cred */ - fallthrough; + /* possible RPCSEC_GSS out-of-sequence event (RFC2203), + * reset recv state and keep waiting, don't retransmit + */ + task->tk_rqstp->rq_reply_bytes_recvd = 0; + task->tk_status = xprt_request_enqueue_receive(task); + task->tk_action = call_transmit_status; + return -EBADMSG; default: goto out_garbage; } @@ -3317,8 +3343,11 @@ bool rpc_clnt_xprt_switch_has_addr(struct rpc_clnt *clnt, EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_has_addr); #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) -static void rpc_show_header(void) +static void rpc_show_header(struct rpc_clnt *clnt) { + printk(KERN_INFO "clnt[%pISpc] RPC tasks[%d]\n", + (struct sockaddr *)&clnt->cl_xprt->addr, + atomic_read(&clnt->cl_task_count)); printk(KERN_INFO "-pid- flgs status -client- --rqstp- " "-timeout ---ops--\n"); } @@ -3350,7 +3379,7 @@ void rpc_show_tasks(struct net *net) spin_lock(&clnt->cl_lock); list_for_each_entry(task, &clnt->cl_tasks, tk_task) { if (!header) { - rpc_show_header(); + rpc_show_header(clnt); header++; } rpc_show_task(clnt, task); diff --git a/net/sunrpc/debugfs.c b/net/sunrpc/debugfs.c index a176d5a0b0ee..32417db340de 100644 --- a/net/sunrpc/debugfs.c +++ b/net/sunrpc/debugfs.c @@ -74,6 +74,9 @@ tasks_stop(struct seq_file *f, void *v) { struct rpc_clnt *clnt = f->private; spin_unlock(&clnt->cl_lock); + seq_printf(f, "clnt[%pISpc] RPC tasks[%d]\n", + (struct sockaddr *)&clnt->cl_xprt->addr, + atomic_read(&clnt->cl_task_count)); } static const struct seq_operations tasks_seq_operations = { @@ -179,6 +182,18 @@ xprt_info_show(struct seq_file *f, void *v) seq_printf(f, "addr: %s\n", xprt->address_strings[RPC_DISPLAY_ADDR]); seq_printf(f, "port: %s\n", xprt->address_strings[RPC_DISPLAY_PORT]); seq_printf(f, "state: 0x%lx\n", xprt->state); + seq_printf(f, "netns: %u\n", xprt->xprt_net->ns.inum); + + if (xprt->ops->get_srcaddr) { + int ret, buflen; + char buf[INET6_ADDRSTRLEN]; + + buflen = ARRAY_SIZE(buf); + ret = xprt->ops->get_srcaddr(xprt, buf, buflen); + if (ret < 0) + ret = sprintf(buf, "<closed>"); + seq_printf(f, "saddr: %.*s\n", ret, buf); + } return 0; } diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c index 910a5d850d04..98f78cd55905 100644 --- a/net/sunrpc/rpc_pipe.c +++ b/net/sunrpc/rpc_pipe.c @@ -385,7 +385,6 @@ rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) static const struct file_operations rpc_pipe_fops = { .owner = THIS_MODULE, - .llseek = no_llseek, .read = rpc_pipe_read, .write = rpc_pipe_write, .poll = rpc_pipe_poll, @@ -631,8 +630,8 @@ static int __rpc_rmpipe(struct inode *dir, struct dentry *dentry) static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent, const char *name) { - struct qstr q = QSTR_INIT(name, strlen(name)); - struct dentry *dentry = d_hash_and_lookup(parent, &q); + struct qstr q = QSTR(name); + struct dentry *dentry = try_lookup_noperm(&q, parent); if (!dentry) { dentry = d_alloc(parent, &q); if (!dentry) @@ -659,7 +658,7 @@ static void __rpc_depopulate(struct dentry *parent, for (i = start; i < eof; i++) { name.name = files[i].name; name.len = strlen(files[i].name); - dentry = d_hash_and_lookup(parent, &name); + dentry = try_lookup_noperm(&name, parent); if (dentry == NULL) continue; @@ -1191,8 +1190,7 @@ static const struct rpc_filelist files[] = { struct dentry *rpc_d_lookup_sb(const struct super_block *sb, const unsigned char *dir_name) { - struct qstr dir = QSTR_INIT(dir_name, strlen(dir_name)); - return d_hash_and_lookup(sb->s_root, &dir); + return try_lookup_noperm(&QSTR(dir_name), sb->s_root); } EXPORT_SYMBOL_GPL(rpc_d_lookup_sb); @@ -1301,11 +1299,9 @@ rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data) struct dentry *gssd_dentry; struct dentry *clnt_dentry = NULL; struct dentry *pipe_dentry = NULL; - struct qstr q = QSTR_INIT(files[RPCAUTH_gssd].name, - strlen(files[RPCAUTH_gssd].name)); /* We should never get this far if "gssd" doesn't exist */ - gssd_dentry = d_hash_and_lookup(root, &q); + gssd_dentry = try_lookup_noperm(&QSTR(files[RPCAUTH_gssd].name), root); if (!gssd_dentry) return ERR_PTR(-ENOENT); @@ -1315,9 +1311,8 @@ rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data) goto out; } - q.name = gssd_dummy_clnt_dir[0].name; - q.len = strlen(gssd_dummy_clnt_dir[0].name); - clnt_dentry = d_hash_and_lookup(gssd_dentry, &q); + clnt_dentry = try_lookup_noperm(&QSTR(gssd_dummy_clnt_dir[0].name), + gssd_dentry); if (!clnt_dentry) { __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); pipe_dentry = ERR_PTR(-ENOENT); diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c index 102c3818bc54..53bcca365fb1 100644 --- a/net/sunrpc/rpcb_clnt.c +++ b/net/sunrpc/rpcb_clnt.c @@ -820,9 +820,10 @@ static void rpcb_getport_done(struct rpc_task *child, void *data) } trace_rpcb_setport(child, map->r_status, map->r_port); - xprt->ops->set_port(xprt, map->r_port); - if (map->r_port) + if (map->r_port) { + xprt->ops->set_port(xprt, map->r_port); xprt_set_bound(xprt); + } } /* diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c index 6debf4fd42d4..73bc39281ef5 100644 --- a/net/sunrpc/sched.c +++ b/net/sunrpc/sched.c @@ -276,6 +276,8 @@ EXPORT_SYMBOL_GPL(rpc_destroy_wait_queue); static int rpc_wait_bit_killable(struct wait_bit_key *key, int mode) { + if (unlikely(current->flags & PF_EXITING)) + return -EINTR; schedule(); if (signal_pending_state(mode, current)) return -ERESTARTSYS; @@ -369,8 +371,10 @@ static void rpc_make_runnable(struct workqueue_struct *wq, if (RPC_IS_ASYNC(task)) { INIT_WORK(&task->u.tk_work, rpc_async_schedule); queue_work(wq, &task->u.tk_work); - } else + } else { + smp_mb__after_atomic(); wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED); + } } /* @@ -862,8 +866,6 @@ void rpc_signal_task(struct rpc_task *task) if (!rpc_task_set_rpc_status(task, -ERESTARTSYS)) return; trace_rpc_task_signalled(task, task->tk_action); - set_bit(RPC_TASK_SIGNALLED, &task->tk_runstate); - smp_mb__after_atomic(); queue = READ_ONCE(task->tk_waitqueue); if (queue) rpc_wake_up_queued_task(queue, task); diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h index d4a362c9e4b3..e3c6e3b63f0b 100644 --- a/net/sunrpc/sunrpc.h +++ b/net/sunrpc/sunrpc.h @@ -36,7 +36,11 @@ static inline int sock_is_loopback(struct sock *sk) return loopback; } +struct svc_serv; +struct svc_rqst; int rpc_clients_notifier_register(void); void rpc_clients_notifier_unregister(void); void auth_domain_cleanup(void); +void svc_sock_update_bufs(struct svc_serv *serv); +enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp); #endif /* _NET_SUNRPC_SUNRPC_H */ diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c index b33e429336fb..9c93b854e809 100644 --- a/net/sunrpc/svc.c +++ b/net/sunrpc/svc.c @@ -32,6 +32,7 @@ #include <trace/events/sunrpc.h> #include "fail.h" +#include "sunrpc.h" #define RPCDBG_FACILITY RPCDBG_SVCDSP @@ -72,57 +73,100 @@ static struct svc_pool_map svc_pool_map = { static DEFINE_MUTEX(svc_pool_map_mutex);/* protects svc_pool_map.count only */ static int -param_set_pool_mode(const char *val, const struct kernel_param *kp) +__param_set_pool_mode(const char *val, struct svc_pool_map *m) { - int *ip = (int *)kp->arg; - struct svc_pool_map *m = &svc_pool_map; - int err; + int err, mode; mutex_lock(&svc_pool_map_mutex); - err = -EBUSY; - if (m->count) - goto out; - err = 0; if (!strncmp(val, "auto", 4)) - *ip = SVC_POOL_AUTO; + mode = SVC_POOL_AUTO; else if (!strncmp(val, "global", 6)) - *ip = SVC_POOL_GLOBAL; + mode = SVC_POOL_GLOBAL; else if (!strncmp(val, "percpu", 6)) - *ip = SVC_POOL_PERCPU; + mode = SVC_POOL_PERCPU; else if (!strncmp(val, "pernode", 7)) - *ip = SVC_POOL_PERNODE; + mode = SVC_POOL_PERNODE; else err = -EINVAL; + if (err) + goto out; + + if (m->count == 0) + m->mode = mode; + else if (mode != m->mode) + err = -EBUSY; out: mutex_unlock(&svc_pool_map_mutex); return err; } static int -param_get_pool_mode(char *buf, const struct kernel_param *kp) +param_set_pool_mode(const char *val, const struct kernel_param *kp) +{ + struct svc_pool_map *m = kp->arg; + + return __param_set_pool_mode(val, m); +} + +int sunrpc_set_pool_mode(const char *val) +{ + return __param_set_pool_mode(val, &svc_pool_map); +} +EXPORT_SYMBOL(sunrpc_set_pool_mode); + +/** + * sunrpc_get_pool_mode - get the current pool_mode for the host + * @buf: where to write the current pool_mode + * @size: size of @buf + * + * Grab the current pool_mode from the svc_pool_map and write + * the resulting string to @buf. Returns the number of characters + * written to @buf (a'la snprintf()). + */ +int +sunrpc_get_pool_mode(char *buf, size_t size) { - int *ip = (int *)kp->arg; + struct svc_pool_map *m = &svc_pool_map; - switch (*ip) + switch (m->mode) { case SVC_POOL_AUTO: - return sysfs_emit(buf, "auto\n"); + return snprintf(buf, size, "auto"); case SVC_POOL_GLOBAL: - return sysfs_emit(buf, "global\n"); + return snprintf(buf, size, "global"); case SVC_POOL_PERCPU: - return sysfs_emit(buf, "percpu\n"); + return snprintf(buf, size, "percpu"); case SVC_POOL_PERNODE: - return sysfs_emit(buf, "pernode\n"); + return snprintf(buf, size, "pernode"); default: - return sysfs_emit(buf, "%d\n", *ip); + return snprintf(buf, size, "%d", m->mode); } } +EXPORT_SYMBOL(sunrpc_get_pool_mode); + +static int +param_get_pool_mode(char *buf, const struct kernel_param *kp) +{ + char str[16]; + int len; + + len = sunrpc_get_pool_mode(str, ARRAY_SIZE(str)); + + /* Ensure we have room for newline and NUL */ + len = min_t(int, len, ARRAY_SIZE(str) - 2); + + /* tack on the newline */ + str[len] = '\n'; + str[len + 1] = '\0'; + + return sysfs_emit(buf, "%s", str); +} module_param_call(pool_mode, param_set_pool_mode, param_get_pool_mode, - &svc_pool_map.mode, 0644); + &svc_pool_map, 0644); /* * Detect best pool mapping mode heuristically, @@ -250,10 +294,8 @@ svc_pool_map_get(void) int npools = -1; mutex_lock(&svc_pool_map_mutex); - if (m->count++) { mutex_unlock(&svc_pool_map_mutex); - WARN_ON_ONCE(m->npools <= 1); return m->npools; } @@ -275,32 +317,21 @@ svc_pool_map_get(void) m->mode = SVC_POOL_GLOBAL; } m->npools = npools; - - if (npools == 1) - /* service is unpooled, so doesn't hold a reference */ - m->count--; - mutex_unlock(&svc_pool_map_mutex); return npools; } /* - * Drop a reference to the global map of cpus to pools, if - * pools were in use, i.e. if npools > 1. + * Drop a reference to the global map of cpus to pools. * When the last reference is dropped, the map data is - * freed; this allows the sysadmin to change the pool - * mode using the pool_mode module option without - * rebooting or re-loading sunrpc.ko. + * freed; this allows the sysadmin to change the pool. */ static void -svc_pool_map_put(int npools) +svc_pool_map_put(void) { struct svc_pool_map *m = &svc_pool_map; - if (npools <= 1) - return; mutex_lock(&svc_pool_map_mutex); - if (!--m->count) { kfree(m->to_pool); m->to_pool = NULL; @@ -308,7 +339,6 @@ svc_pool_map_put(int npools) m->pool_to = NULL; m->npools = 0; } - mutex_unlock(&svc_pool_map_mutex); } @@ -388,7 +418,7 @@ struct svc_pool *svc_pool_for_cpu(struct svc_serv *serv) return &serv->sv_pools[pidx % serv->sv_nrpools]; } -int svc_rpcb_setup(struct svc_serv *serv, struct net *net) +static int svc_rpcb_setup(struct svc_serv *serv, struct net *net) { int err; @@ -400,7 +430,6 @@ int svc_rpcb_setup(struct svc_serv *serv, struct net *net) svc_unregister(serv, net); return 0; } -EXPORT_SYMBOL_GPL(svc_rpcb_setup); void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net) { @@ -411,10 +440,11 @@ EXPORT_SYMBOL_GPL(svc_rpcb_cleanup); static int svc_uses_rpcbind(struct svc_serv *serv) { - struct svc_program *progp; - unsigned int i; + unsigned int p, i; + + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; - for (progp = serv->sv_program; progp; progp = progp->pg_next) { for (i = 0; i < progp->pg_nvers; i++) { if (progp->pg_vers[i] == NULL) continue; @@ -451,7 +481,7 @@ __svc_init_bc(struct svc_serv *serv) * Create an RPC service */ static struct svc_serv * -__svc_create(struct svc_program *prog, struct svc_stat *stats, +__svc_create(struct svc_program *prog, int nprogs, struct svc_stat *stats, unsigned int bufsize, int npools, int (*threadfn)(void *data)) { struct svc_serv *serv; @@ -462,7 +492,8 @@ __svc_create(struct svc_program *prog, struct svc_stat *stats, if (!(serv = kzalloc(sizeof(*serv), GFP_KERNEL))) return NULL; serv->sv_name = prog->pg_name; - serv->sv_program = prog; + serv->sv_programs = prog; + serv->sv_nprogs = nprogs; serv->sv_stats = stats; if (bufsize > RPCSVC_MAXPAYLOAD) bufsize = RPCSVC_MAXPAYLOAD; @@ -470,17 +501,18 @@ __svc_create(struct svc_program *prog, struct svc_stat *stats, serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE); serv->sv_threadfn = threadfn; xdrsize = 0; - while (prog) { - prog->pg_lovers = prog->pg_nvers-1; - for (vers=0; vers<prog->pg_nvers ; vers++) - if (prog->pg_vers[vers]) { - prog->pg_hivers = vers; - if (prog->pg_lovers > vers) - prog->pg_lovers = vers; - if (prog->pg_vers[vers]->vs_xdrsize > xdrsize) - xdrsize = prog->pg_vers[vers]->vs_xdrsize; + for (i = 0; i < nprogs; i++) { + struct svc_program *progp = &prog[i]; + + progp->pg_lovers = progp->pg_nvers-1; + for (vers = 0; vers < progp->pg_nvers ; vers++) + if (progp->pg_vers[vers]) { + progp->pg_hivers = vers; + if (progp->pg_lovers > vers) + progp->pg_lovers = vers; + if (progp->pg_vers[vers]->vs_xdrsize > xdrsize) + xdrsize = progp->pg_vers[vers]->vs_xdrsize; } - prog = prog->pg_next; } serv->sv_xdrsize = xdrsize; INIT_LIST_HEAD(&serv->sv_tempsocks); @@ -529,13 +561,14 @@ __svc_create(struct svc_program *prog, struct svc_stat *stats, struct svc_serv *svc_create(struct svc_program *prog, unsigned int bufsize, int (*threadfn)(void *data)) { - return __svc_create(prog, NULL, bufsize, 1, threadfn); + return __svc_create(prog, 1, NULL, bufsize, 1, threadfn); } EXPORT_SYMBOL_GPL(svc_create); /** * svc_create_pooled - Create an RPC service with pooled threads - * @prog: the RPC program the new service will handle + * @prog: Array of RPC programs the new service will handle + * @nprogs: Number of programs in the array * @stats: the stats struct if desired * @bufsize: maximum message size for @prog * @threadfn: a function to service RPC requests for @prog @@ -543,6 +576,7 @@ EXPORT_SYMBOL_GPL(svc_create); * Returns an instantiated struct svc_serv object or NULL. */ struct svc_serv *svc_create_pooled(struct svc_program *prog, + unsigned int nprogs, struct svc_stat *stats, unsigned int bufsize, int (*threadfn)(void *data)) @@ -550,12 +584,13 @@ struct svc_serv *svc_create_pooled(struct svc_program *prog, struct svc_serv *serv; unsigned int npools = svc_pool_map_get(); - serv = __svc_create(prog, stats, bufsize, npools, threadfn); + serv = __svc_create(prog, nprogs, stats, bufsize, npools, threadfn); if (!serv) goto out_err; + serv->sv_is_pooled = true; return serv; out_err: - svc_pool_map_put(npools); + svc_pool_map_put(); return NULL; } EXPORT_SYMBOL_GPL(svc_create_pooled); @@ -572,20 +607,21 @@ svc_destroy(struct svc_serv **servp) *servp = NULL; - dprintk("svc: svc_destroy(%s)\n", serv->sv_program->pg_name); + dprintk("svc: svc_destroy(%s)\n", serv->sv_programs->pg_name); timer_shutdown_sync(&serv->sv_temptimer); /* * Remaining transports at this point are not expected. */ WARN_ONCE(!list_empty(&serv->sv_permsocks), - "SVC: permsocks remain for %s\n", serv->sv_program->pg_name); + "SVC: permsocks remain for %s\n", serv->sv_programs->pg_name); WARN_ONCE(!list_empty(&serv->sv_tempsocks), - "SVC: tempsocks remain for %s\n", serv->sv_program->pg_name); + "SVC: tempsocks remain for %s\n", serv->sv_programs->pg_name); cache_clean_deferred(serv); - svc_pool_map_put(serv->sv_nrpools); + if (serv->sv_is_pooled) + svc_pool_map_put(); for (i = 0; i < serv->sv_nrpools; i++) { struct svc_pool *pool = &serv->sv_pools[i]; @@ -600,24 +636,18 @@ svc_destroy(struct svc_serv **servp) EXPORT_SYMBOL_GPL(svc_destroy); static bool -svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) +svc_init_buffer(struct svc_rqst *rqstp, const struct svc_serv *serv, int node) { - unsigned long pages, ret; - - /* bc_xprt uses fore channel allocated buffers */ - if (svc_is_backchannel(rqstp)) - return true; - - pages = size / PAGE_SIZE + 1; /* extra page as we hold both request and reply. - * We assume one is at most one page - */ - WARN_ON_ONCE(pages > RPCSVC_MAXPAGES); - if (pages > RPCSVC_MAXPAGES) - pages = RPCSVC_MAXPAGES; - - ret = alloc_pages_bulk_array_node(GFP_KERNEL, node, pages, - rqstp->rq_pages); - return ret == pages; + rqstp->rq_maxpages = svc_serv_maxpages(serv); + + /* rq_pages' last entry is NULL for historical reasons. */ + rqstp->rq_pages = kcalloc_node(rqstp->rq_maxpages + 1, + sizeof(struct page *), + GFP_KERNEL, node); + if (!rqstp->rq_pages) + return false; + + return true; } /* @@ -626,15 +656,30 @@ svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) static void svc_release_buffer(struct svc_rqst *rqstp) { - unsigned int i; + unsigned long i; - for (i = 0; i < ARRAY_SIZE(rqstp->rq_pages); i++) + for (i = 0; i < rqstp->rq_maxpages; i++) if (rqstp->rq_pages[i]) put_page(rqstp->rq_pages[i]); + kfree(rqstp->rq_pages); +} + +static void +svc_rqst_free(struct svc_rqst *rqstp) +{ + folio_batch_release(&rqstp->rq_fbatch); + kfree(rqstp->rq_bvec); + svc_release_buffer(rqstp); + if (rqstp->rq_scratch_page) + put_page(rqstp->rq_scratch_page); + kfree(rqstp->rq_resp); + kfree(rqstp->rq_argp); + kfree(rqstp->rq_auth_data); + kfree_rcu(rqstp, rq_rcu_head); } -struct svc_rqst * -svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) +static struct svc_rqst * +svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) { struct svc_rqst *rqstp; @@ -659,30 +704,19 @@ svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) if (!rqstp->rq_resp) goto out_enomem; - if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node)) + if (!svc_init_buffer(rqstp, serv, node)) goto out_enomem; - return rqstp; -out_enomem: - svc_rqst_free(rqstp); - return NULL; -} -EXPORT_SYMBOL_GPL(svc_rqst_alloc); - -static struct svc_rqst * -svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) -{ - struct svc_rqst *rqstp; + rqstp->rq_bvec = kcalloc_node(rqstp->rq_maxpages, + sizeof(struct bio_vec), + GFP_KERNEL, node); + if (!rqstp->rq_bvec) + goto out_enomem; - rqstp = svc_rqst_alloc(serv, pool, node); - if (!rqstp) - return ERR_PTR(-ENOMEM); + rqstp->rq_err = -EAGAIN; /* No error yet */ - spin_lock_bh(&serv->sv_lock); serv->sv_nrthreads += 1; - spin_unlock_bh(&serv->sv_lock); - - atomic_inc(&pool->sp_nrthreads); + pool->sp_nrthreads += 1; /* Protected by whatever lock the service uses when calling * svc_set_num_threads() @@ -690,6 +724,10 @@ svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) list_add_rcu(&rqstp->rq_all, &pool->sp_all_threads); return rqstp; + +out_enomem: + svc_rqst_free(rqstp); + return NULL; } /** @@ -737,31 +775,22 @@ svc_pool_victim(struct svc_serv *serv, struct svc_pool *target_pool, struct svc_pool *pool; unsigned int i; -retry: pool = target_pool; - if (pool != NULL) { - if (atomic_inc_not_zero(&pool->sp_nrthreads)) - goto found_pool; - return NULL; - } else { + if (!pool) { for (i = 0; i < serv->sv_nrpools; i++) { pool = &serv->sv_pools[--(*state) % serv->sv_nrpools]; - if (atomic_inc_not_zero(&pool->sp_nrthreads)) - goto found_pool; + if (pool->sp_nrthreads) + break; } - return NULL; } -found_pool: - set_bit(SP_VICTIM_REMAINS, &pool->sp_flags); - set_bit(SP_NEED_VICTIM, &pool->sp_flags); - if (!atomic_dec_and_test(&pool->sp_nrthreads)) + if (pool && pool->sp_nrthreads) { + set_bit(SP_VICTIM_REMAINS, &pool->sp_flags); + set_bit(SP_NEED_VICTIM, &pool->sp_flags); return pool; - /* Nothing left in this pool any more */ - clear_bit(SP_NEED_VICTIM, &pool->sp_flags); - clear_bit(SP_VICTIM_REMAINS, &pool->sp_flags); - goto retry; + } + return NULL; } static int @@ -772,6 +801,7 @@ svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) struct svc_pool *chosen_pool; unsigned int state = serv->sv_nrthreads-1; int node; + int err; do { nrservs--; @@ -779,8 +809,8 @@ svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) node = svc_pool_map_get_node(chosen_pool->sp_id); rqstp = svc_prepare_thread(serv, chosen_pool, node); - if (IS_ERR(rqstp)) - return PTR_ERR(rqstp); + if (!rqstp) + return -ENOMEM; task = kthread_create_on_node(serv->sv_threadfn, rqstp, node, "%s", serv->sv_name); if (IS_ERR(task)) { @@ -794,6 +824,13 @@ svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) svc_sock_update_bufs(serv); wake_up_process(task); + + wait_var_event(&rqstp->rq_err, rqstp->rq_err != -EAGAIN); + err = rqstp->rq_err; + if (err) { + svc_exit_thread(rqstp); + return err; + } } while (nrservs > 0); return 0; @@ -840,7 +877,7 @@ svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) if (!pool) nrservs -= serv->sv_nrthreads; else - nrservs -= atomic_read(&pool->sp_nrthreads); + nrservs -= pool->sp_nrthreads; if (nrservs > 0) return svc_start_kthreads(serv, pool, nrservs); @@ -865,7 +902,7 @@ EXPORT_SYMBOL_GPL(svc_set_num_threads); bool svc_rqst_replace_page(struct svc_rqst *rqstp, struct page *page) { struct page **begin = rqstp->rq_pages; - struct page **end = &rqstp->rq_pages[RPCSVC_MAXPAGES]; + struct page **end = &rqstp->rq_pages[rqstp->rq_maxpages]; if (unlikely(rqstp->rq_next_page < begin || rqstp->rq_next_page > end)) { trace_svc_replace_page_err(rqstp); @@ -902,25 +939,21 @@ void svc_rqst_release_pages(struct svc_rqst *rqstp) } } -/* - * Called from a server thread as it's exiting. Caller must hold the "service - * mutex" for the service. +/** + * svc_exit_thread - finalise the termination of a sunrpc server thread + * @rqstp: the svc_rqst which represents the thread. + * + * When a thread started with svc_new_thread() exits it must call + * svc_exit_thread() as its last act. This must be done with the + * service mutex held. Normally this is held by a DIFFERENT thread, the + * one that is calling svc_set_num_threads() and which will wait for + * SP_VICTIM_REMAINS to be cleared before dropping the mutex. If the + * thread exits for any reason other than svc_thread_should_stop() + * returning %true (which indicated that svc_set_num_threads() is + * waiting for it to exit), then it must take the service mutex itself, + * which can only safely be done using mutex_try_lock(). */ void -svc_rqst_free(struct svc_rqst *rqstp) -{ - folio_batch_release(&rqstp->rq_fbatch); - svc_release_buffer(rqstp); - if (rqstp->rq_scratch_page) - put_page(rqstp->rq_scratch_page); - kfree(rqstp->rq_resp); - kfree(rqstp->rq_argp); - kfree(rqstp->rq_auth_data); - kfree_rcu(rqstp, rq_rcu_head); -} -EXPORT_SYMBOL_GPL(svc_rqst_free); - -void svc_exit_thread(struct svc_rqst *rqstp) { struct svc_serv *serv = rqstp->rq_server; @@ -928,11 +961,8 @@ svc_exit_thread(struct svc_rqst *rqstp) list_del_rcu(&rqstp->rq_all); - atomic_dec(&pool->sp_nrthreads); - - spin_lock_bh(&serv->sv_lock); + pool->sp_nrthreads -= 1; serv->sv_nrthreads -= 1; - spin_unlock_bh(&serv->sv_lock); svc_sock_update_bufs(serv); svc_rqst_free(rqstp); @@ -1067,6 +1097,7 @@ static int __svc_register(struct net *net, const char *progname, return error; } +static int svc_rpcbind_set_version(struct net *net, const struct svc_program *progp, u32 version, int family, @@ -1077,7 +1108,6 @@ int svc_rpcbind_set_version(struct net *net, version, family, proto, port); } -EXPORT_SYMBOL_GPL(svc_rpcbind_set_version); int svc_generic_rpcbind_set(struct net *net, const struct svc_program *progp, @@ -1125,15 +1155,16 @@ int svc_register(const struct svc_serv *serv, struct net *net, const int family, const unsigned short proto, const unsigned short port) { - struct svc_program *progp; - unsigned int i; + unsigned int p, i; int error = 0; WARN_ON_ONCE(proto == 0 && port == 0); if (proto == 0 && port == 0) return -EINVAL; - for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; + for (i = 0; i < progp->pg_nvers; i++) { error = progp->pg_rpcbind_set(net, progp, i, @@ -1185,13 +1216,14 @@ static void __svc_unregister(struct net *net, const u32 program, const u32 versi static void svc_unregister(const struct svc_serv *serv, struct net *net) { struct sighand_struct *sighand; - struct svc_program *progp; unsigned long flags; - unsigned int i; + unsigned int p, i; clear_thread_flag(TIF_SIGPENDING); - for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; + for (i = 0; i < progp->pg_nvers; i++) { if (progp->pg_vers[i] == NULL) continue; @@ -1265,8 +1297,6 @@ svc_generic_init_request(struct svc_rqst *rqstp, if (rqstp->rq_proc >= versp->vs_nproc) goto err_bad_proc; rqstp->rq_procinfo = procp = &versp->vs_proc[rqstp->rq_proc]; - if (!procp) - goto err_bad_proc; /* Initialize storage for argp and resp */ memset(rqstp->rq_argp, 0, procp->pc_argzero); @@ -1293,13 +1323,13 @@ static int svc_process_common(struct svc_rqst *rqstp) { struct xdr_stream *xdr = &rqstp->rq_res_stream; - struct svc_program *progp; + struct svc_program *progp = NULL; const struct svc_procedure *procp = NULL; struct svc_serv *serv = rqstp->rq_server; struct svc_process_info process; enum svc_auth_status auth_res; unsigned int aoffset; - int rc; + int pr, rc; __be32 *p; /* Will be turned off only when NFSv4 Sessions are used */ @@ -1323,9 +1353,9 @@ svc_process_common(struct svc_rqst *rqstp) rqstp->rq_vers = be32_to_cpup(p++); rqstp->rq_proc = be32_to_cpup(p); - for (progp = serv->sv_program; progp; progp = progp->pg_next) - if (rqstp->rq_prog == progp->pg_prog) - break; + for (pr = 0; pr < serv->sv_nprogs; pr++) + if (rqstp->rq_prog == serv->sv_programs[pr].pg_prog) + progp = &serv->sv_programs[pr]; /* * Decode auth data, and add verifier to reply buffer. @@ -1341,7 +1371,8 @@ svc_process_common(struct svc_rqst *rqstp) case SVC_OK: break; case SVC_GARBAGE: - goto err_garbage_args; + rqstp->rq_auth_stat = rpc_autherr_badcred; + goto err_bad_auth; case SVC_SYSERR: goto err_system_err; case SVC_DENIED: @@ -1482,14 +1513,6 @@ err_bad_proc: *rqstp->rq_accept_statp = rpc_proc_unavail; goto sendit; -err_garbage_args: - svc_printk(rqstp, "failed to decode RPC header\n"); - - if (serv->sv_stats) - serv->sv_stats->rpcbadfmt++; - *rqstp->rq_accept_statp = rpc_garbage_args; - goto sendit; - err_system_err: if (serv->sv_stats) serv->sv_stats->rpcbadfmt++; @@ -1497,6 +1520,14 @@ err_system_err: goto sendit; } +/* + * Drop request + */ +static void svc_drop(struct svc_rqst *rqstp) +{ + trace_svc_drop(rqstp); +} + /** * svc_process - Execute one RPC transaction * @rqstp: RPC transaction context @@ -1559,9 +1590,11 @@ out_drop: */ void svc_process_bc(struct rpc_rqst *req, struct svc_rqst *rqstp) { + struct rpc_timeout timeout = { + .to_increment = 0, + }; struct rpc_task *task; int proc_error; - struct rpc_timeout timeout; /* Build the svc_rqst used by the common processing routine */ rqstp->rq_xid = req->rq_xid; @@ -1614,6 +1647,7 @@ void svc_process_bc(struct rpc_rqst *req, struct svc_rqst *rqstp) timeout.to_initval = req->rq_xprt->timeout->to_initval; timeout.to_retries = req->rq_xprt->timeout->to_retries; } + timeout.to_maxval = timeout.to_initval; memcpy(&req->rq_snd_buf, &rqstp->rq_res, sizeof(req->rq_snd_buf)); task = rpc_run_bc_task(req, &timeout); @@ -1675,46 +1709,6 @@ int svc_encode_result_payload(struct svc_rqst *rqstp, unsigned int offset, EXPORT_SYMBOL_GPL(svc_encode_result_payload); /** - * svc_fill_write_vector - Construct data argument for VFS write call - * @rqstp: svc_rqst to operate on - * @payload: xdr_buf containing only the write data payload - * - * Fills in rqstp::rq_vec, and returns the number of elements. - */ -unsigned int svc_fill_write_vector(struct svc_rqst *rqstp, - struct xdr_buf *payload) -{ - struct page **pages = payload->pages; - struct kvec *first = payload->head; - struct kvec *vec = rqstp->rq_vec; - size_t total = payload->len; - unsigned int i; - - /* Some types of transport can present the write payload - * entirely in rq_arg.pages. In this case, @first is empty. - */ - i = 0; - if (first->iov_len) { - vec[i].iov_base = first->iov_base; - vec[i].iov_len = min_t(size_t, total, first->iov_len); - total -= vec[i].iov_len; - ++i; - } - - while (total) { - vec[i].iov_base = page_address(*pages); - vec[i].iov_len = min_t(size_t, total, PAGE_SIZE); - total -= vec[i].iov_len; - ++i; - ++pages; - } - - WARN_ON_ONCE(i > ARRAY_SIZE(rqstp->rq_vec)); - return i; -} -EXPORT_SYMBOL_GPL(svc_fill_write_vector); - -/** * svc_fill_symlink_pathname - Construct pathname argument for VFS symlink call * @rqstp: svc_rqst to operate on * @first: buffer containing first section of pathname diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index b4a85a227bd7..8b1837228799 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -46,7 +46,6 @@ static LIST_HEAD(svc_xprt_class_list); /* SMP locking strategy: * - * svc_pool->sp_lock protects most of the fields of that pool. * svc_serv->sv_lock protects sv_tempsocks, sv_permsocks, sv_tmpcnt. * when both need to be taken (rare), svc_serv->sv_lock is first. * The "service mutex" protects svc_serv->sv_nrthread. @@ -158,6 +157,7 @@ int svc_print_xprts(char *buf, int maxlen) */ void svc_xprt_deferred_close(struct svc_xprt *xprt) { + trace_svc_xprt_close(xprt); if (!test_and_set_bit(XPT_CLOSE, &xprt->xpt_flags)) svc_xprt_enqueue(xprt); } @@ -211,51 +211,6 @@ void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl, } EXPORT_SYMBOL_GPL(svc_xprt_init); -static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, - struct svc_serv *serv, - struct net *net, - const int family, - const unsigned short port, - int flags) -{ - struct sockaddr_in sin = { - .sin_family = AF_INET, - .sin_addr.s_addr = htonl(INADDR_ANY), - .sin_port = htons(port), - }; -#if IS_ENABLED(CONFIG_IPV6) - struct sockaddr_in6 sin6 = { - .sin6_family = AF_INET6, - .sin6_addr = IN6ADDR_ANY_INIT, - .sin6_port = htons(port), - }; -#endif - struct svc_xprt *xprt; - struct sockaddr *sap; - size_t len; - - switch (family) { - case PF_INET: - sap = (struct sockaddr *)&sin; - len = sizeof(sin); - break; -#if IS_ENABLED(CONFIG_IPV6) - case PF_INET6: - sap = (struct sockaddr *)&sin6; - len = sizeof(sin6); - break; -#endif - default: - return ERR_PTR(-EAFNOSUPPORT); - } - - xprt = xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); - if (IS_ERR(xprt)) - trace_svc_xprt_create_err(serv->sv_program->pg_name, - xcl->xcl_name, sap, len, xprt); - return xprt; -} - /** * svc_xprt_received - start next receiver thread * @xprt: controlling transport @@ -294,9 +249,8 @@ void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) } static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, - struct net *net, const int family, - const unsigned short port, int flags, - const struct cred *cred) + struct net *net, struct sockaddr *sap, + size_t len, int flags, const struct cred *cred) { struct svc_xprt_class *xcl; @@ -312,8 +266,11 @@ static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, goto err; spin_unlock(&svc_xprt_class_lock); - newxprt = __svc_xpo_create(xcl, serv, net, family, port, flags); + newxprt = xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); if (IS_ERR(newxprt)) { + trace_svc_xprt_create_err(serv->sv_programs->pg_name, + xcl->xcl_name, sap, len, + newxprt); module_put(xcl->xcl_owner); return PTR_ERR(newxprt); } @@ -330,6 +287,48 @@ static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, } /** + * svc_xprt_create_from_sa - Add a new listener to @serv from socket address + * @serv: target RPC service + * @xprt_name: transport class name + * @net: network namespace + * @sap: socket address pointer + * @flags: SVC_SOCK flags + * @cred: credential to bind to this transport + * + * Return local xprt port on success or %-EPROTONOSUPPORT on failure + */ +int svc_xprt_create_from_sa(struct svc_serv *serv, const char *xprt_name, + struct net *net, struct sockaddr *sap, + int flags, const struct cred *cred) +{ + size_t len; + int err; + + switch (sap->sa_family) { + case AF_INET: + len = sizeof(struct sockaddr_in); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + len = sizeof(struct sockaddr_in6); + break; +#endif + default: + return -EAFNOSUPPORT; + } + + err = _svc_xprt_create(serv, xprt_name, net, sap, len, flags, cred); + if (err == -EPROTONOSUPPORT) { + request_module("svc%s", xprt_name); + err = _svc_xprt_create(serv, xprt_name, net, sap, len, flags, + cred); + } + + return err; +} +EXPORT_SYMBOL_GPL(svc_xprt_create_from_sa); + +/** * svc_xprt_create - Add a new listener to @serv * @serv: target RPC service * @xprt_name: transport class name @@ -339,23 +338,41 @@ static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, * @flags: SVC_SOCK flags * @cred: credential to bind to this transport * - * Return values: - * %0: New listener added successfully - * %-EPROTONOSUPPORT: Requested transport type not supported + * Return local xprt port on success or %-EPROTONOSUPPORT on failure */ int svc_xprt_create(struct svc_serv *serv, const char *xprt_name, struct net *net, const int family, const unsigned short port, int flags, const struct cred *cred) { - int err; + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; +#endif + struct sockaddr *sap; - err = _svc_xprt_create(serv, xprt_name, net, family, port, flags, cred); - if (err == -EPROTONOSUPPORT) { - request_module("svc%s", xprt_name); - err = _svc_xprt_create(serv, xprt_name, net, family, port, flags, cred); + switch (family) { + case PF_INET: + sap = (struct sockaddr *)&sin; + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + sap = (struct sockaddr *)&sin6; + break; +#endif + default: + return -EAFNOSUPPORT; } - return err; + + return svc_xprt_create_from_sa(serv, xprt_name, net, sap, flags, cred); } EXPORT_SYMBOL_GPL(svc_xprt_create); @@ -471,6 +488,7 @@ void svc_xprt_enqueue(struct svc_xprt *xprt) pool = svc_pool_for_cpu(xprt->xpt_server); percpu_counter_inc(&pool->sp_sockets_queued); + xprt->xpt_qtime = ktime_get(); lwq_enqueue(&xprt->xpt_ready, &pool->sp_xprts); svc_pool_wake_idle_thread(pool); @@ -589,7 +607,8 @@ int svc_port_is_privileged(struct sockaddr *sin) } /* - * Make sure that we don't have too many active connections. If we have, + * Make sure that we don't have too many connections that have not yet + * demonstrated that they have access to the NFS server. If we have, * something must be dropped. It's not clear what will happen if we allow * "too many" connections, but when dealing with network-facing software, * we have to code defensively. Here we do that by imposing hard limits. @@ -601,34 +620,26 @@ int svc_port_is_privileged(struct sockaddr *sin) * The only somewhat efficient mechanism would be if drop old * connections from the same IP first. But right now we don't even * record the client IP in svc_sock. - * - * single-threaded services that expect a lot of clients will probably - * need to set sv_maxconn to override the default value which is based - * on the number of threads */ static void svc_check_conn_limits(struct svc_serv *serv) { - unsigned int limit = serv->sv_maxconn ? serv->sv_maxconn : - (serv->sv_nrthreads+3) * 20; - - if (serv->sv_tmpcnt > limit) { - struct svc_xprt *xprt = NULL; + if (serv->sv_tmpcnt > XPT_MAX_TMP_CONN) { + struct svc_xprt *xprt = NULL, *xprti; spin_lock_bh(&serv->sv_lock); if (!list_empty(&serv->sv_tempsocks)) { - /* Try to help the admin */ - net_notice_ratelimited("%s: too many open connections, consider increasing the %s\n", - serv->sv_name, serv->sv_maxconn ? - "max number of connections" : - "number of threads"); /* * Always select the oldest connection. It's not fair, - * but so is life + * but nor is life. */ - xprt = list_entry(serv->sv_tempsocks.prev, - struct svc_xprt, - xpt_list); - set_bit(XPT_CLOSE, &xprt->xpt_flags); - svc_xprt_get(xprt); + list_for_each_entry_reverse(xprti, &serv->sv_tempsocks, + xpt_list) { + if (!test_bit(XPT_PEER_VALID, &xprti->xpt_flags)) { + xprt = xprti; + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_get(xprt); + break; + } + } } spin_unlock_bh(&serv->sv_lock); @@ -641,21 +652,12 @@ static void svc_check_conn_limits(struct svc_serv *serv) static bool svc_alloc_arg(struct svc_rqst *rqstp) { - struct svc_serv *serv = rqstp->rq_server; struct xdr_buf *arg = &rqstp->rq_arg; unsigned long pages, filled, ret; - pages = (serv->sv_max_mesg + 2 * PAGE_SIZE) >> PAGE_SHIFT; - if (pages > RPCSVC_MAXPAGES) { - pr_warn_once("svc: warning: pages=%lu > RPCSVC_MAXPAGES=%lu\n", - pages, RPCSVC_MAXPAGES); - /* use as many pages as possible */ - pages = RPCSVC_MAXPAGES; - } - + pages = rqstp->rq_maxpages; for (filled = 0; filled < pages; filled = ret) { - ret = alloc_pages_bulk_array(GFP_KERNEL, pages, - rqstp->rq_pages); + ret = alloc_pages_bulk(GFP_KERNEL, pages, rqstp->rq_pages); if (ret > filled) /* Made progress, don't sleep yet */ continue; @@ -888,15 +890,6 @@ void svc_recv(struct svc_rqst *rqstp) } EXPORT_SYMBOL_GPL(svc_recv); -/* - * Drop request - */ -void svc_drop(struct svc_rqst *rqstp) -{ - trace_svc_drop(rqstp); -} -EXPORT_SYMBOL_GPL(svc_drop); - /** * svc_send - Return reply to client * @rqstp: RPC transaction context @@ -929,7 +922,7 @@ void svc_send(struct svc_rqst *rqstp) */ static void svc_age_temp_xprts(struct timer_list *t) { - struct svc_serv *serv = from_timer(serv, t, sv_temptimer); + struct svc_serv *serv = timer_container_of(serv, t, sv_temptimer); struct svc_xprt *xprt; struct list_head *le, *next; @@ -1031,7 +1024,8 @@ static void svc_delete_xprt(struct svc_xprt *xprt) spin_lock_bh(&serv->sv_lock); list_del_init(&xprt->xpt_list); - if (test_bit(XPT_TEMP, &xprt->xpt_flags)) + if (test_bit(XPT_TEMP, &xprt->xpt_flags) && + !test_bit(XPT_PEER_VALID, &xprt->xpt_flags)) serv->sv_tmpcnt--; spin_unlock_bh(&serv->sv_lock); @@ -1260,6 +1254,40 @@ static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) } /** + * svc_find_listener - find an RPC transport instance + * @serv: pointer to svc_serv to search + * @xcl_name: C string containing transport's class name + * @net: owner net pointer + * @sa: sockaddr containing address + * + * Return the transport instance pointer for the endpoint accepting + * connections/peer traffic from the specified transport class, + * and matching sockaddr. + */ +struct svc_xprt *svc_find_listener(struct svc_serv *serv, const char *xcl_name, + struct net *net, const struct sockaddr *sa) +{ + struct svc_xprt *xprt; + struct svc_xprt *found = NULL; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + if (xprt->xpt_net != net) + continue; + if (strcmp(xprt->xpt_class->xcl_name, xcl_name)) + continue; + if (!rpc_cmp_addr_port(sa, (struct sockaddr *)&xprt->xpt_local)) + continue; + found = xprt; + svc_xprt_get(xprt); + break; + } + spin_unlock_bh(&serv->sv_lock); + return found; +} +EXPORT_SYMBOL_GPL(svc_find_listener); + +/** * svc_find_xprt - find an RPC transport instance * @serv: pointer to svc_serv to search * @xcl_name: C string containing transport's class name diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c index 1619211f0960..55b4d2874188 100644 --- a/net/sunrpc/svcauth.c +++ b/net/sunrpc/svcauth.c @@ -18,6 +18,7 @@ #include <linux/sunrpc/svcauth.h> #include <linux/err.h> #include <linux/hash.h> +#include <linux/user_namespace.h> #include <trace/events/sunrpc.h> @@ -98,7 +99,6 @@ enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp) rqstp->rq_authop = aops; return aops->accept(rqstp); } -EXPORT_SYMBOL_GPL(svc_authenticate); /** * svc_set_client - Assign an appropriate 'auth_domain' as the client @@ -176,6 +176,33 @@ rpc_authflavor_t svc_auth_flavor(struct svc_rqst *rqstp) } EXPORT_SYMBOL_GPL(svc_auth_flavor); +/** + * svcauth_map_clnt_to_svc_cred_local - maps a generic cred + * to a svc_cred suitable for use in nfsd. + * @clnt: rpc_clnt associated with nfs client + * @cred: generic cred associated with nfs client + * @svc: returned svc_cred that is suitable for use in nfsd + */ +void svcauth_map_clnt_to_svc_cred_local(struct rpc_clnt *clnt, + const struct cred *cred, + struct svc_cred *svc) +{ + struct user_namespace *userns = clnt->cl_cred ? + clnt->cl_cred->user_ns : &init_user_ns; + + memset(svc, 0, sizeof(struct svc_cred)); + + svc->cr_uid = KUIDT_INIT(from_kuid_munged(userns, cred->fsuid)); + svc->cr_gid = KGIDT_INIT(from_kgid_munged(userns, cred->fsgid)); + svc->cr_flavor = clnt->cl_auth->au_flavor; + if (cred->group_info) + svc->cr_group_info = get_group_info(cred->group_info); + /* These aren't relevant for local (network is bypassed) */ + svc->cr_principal = NULL; + svc->cr_gss_mech = NULL; +} +EXPORT_SYMBOL_GPL(svcauth_map_clnt_to_svc_cred_local); + /************************************************** * 'auth_domains' are stored in a hash table indexed by name. * When the last reference to an 'auth_domain' is dropped, diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c index 04b45588ae6f..8ca98b146ec8 100644 --- a/net/sunrpc/svcauth_unix.c +++ b/net/sunrpc/svcauth_unix.c @@ -697,7 +697,8 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) rqstp->rq_auth_stat = rpc_autherr_badcred; ipm = ip_map_cached_get(xprt); if (ipm == NULL) - ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class, + ipm = __ip_map_lookup(sn->ip_map_cache, + rqstp->rq_server->sv_programs->pg_class, &sin6->sin6_addr); if (ipm == NULL) diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index 545017a3daa4..e1c85123b445 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -713,8 +713,7 @@ static int svc_udp_sendto(struct svc_rqst *rqstp) if (svc_xprt_is_dead(xprt)) goto out_notconn; - count = xdr_buf_to_bvec(rqstp->rq_bvec, - ARRAY_SIZE(rqstp->rq_bvec), xdr); + count = xdr_buf_to_bvec(rqstp->rq_bvec, rqstp->rq_maxpages, xdr); iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, rqstp->rq_bvec, count, rqstp->rq_res.len); @@ -1083,9 +1082,6 @@ static void svc_tcp_fragment_received(struct svc_sock *svsk) /* If we have more data, signal svc_xprt_enqueue() to try again */ svsk->sk_tcplen = 0; svsk->sk_marker = xdr_zero; - - smp_wmb(); - tcp_set_rcvlowat(svsk->sk_sk, 1); } /** @@ -1175,17 +1171,10 @@ err_incomplete: goto err_delete; if (len == want) svc_tcp_fragment_received(svsk); - else { - /* Avoid more ->sk_data_ready() calls until the rest - * of the message has arrived. This reduces service - * thread wake-ups on large incoming messages. */ - tcp_set_rcvlowat(svsk->sk_sk, - svc_sock_reclen(svsk) - svsk->sk_tcplen); - + else trace_svcsock_tcp_recv_short(&svsk->sk_xprt, svc_sock_reclen(svsk), svsk->sk_tcplen - sizeof(rpc_fraghdr)); - } goto err_noclose; error: if (len != -EAGAIN) @@ -1206,15 +1195,6 @@ err_noclose: * MSG_SPLICE_PAGES is used exclusively to reduce the number of * copy operations in this path. Therefore the caller must ensure * that the pages backing @xdr are unchanging. - * - * Note that the send is non-blocking. The caller has incremented - * the reference count on each page backing the RPC message, and - * the network layer will "put" these pages when transmission is - * complete. - * - * This is safe for our RPC services because the memory backing - * the head and tail components is never kmalloc'd. These always - * come from pages in the svc_rqst::rq_pages array. */ static int svc_tcp_sendmsg(struct svc_sock *svsk, struct svc_rqst *rqstp, rpc_fraghdr marker, unsigned int *sentp) @@ -1238,12 +1218,13 @@ static int svc_tcp_sendmsg(struct svc_sock *svsk, struct svc_rqst *rqstp, memcpy(buf, &marker, sizeof(marker)); bvec_set_virt(rqstp->rq_bvec, buf, sizeof(marker)); - count = xdr_buf_to_bvec(rqstp->rq_bvec + 1, - ARRAY_SIZE(rqstp->rq_bvec) - 1, &rqstp->rq_res); + count = xdr_buf_to_bvec(rqstp->rq_bvec + 1, rqstp->rq_maxpages, + &rqstp->rq_res); iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, rqstp->rq_bvec, 1 + count, sizeof(marker) + rqstp->rq_res.len); ret = sock_sendmsg(svsk->sk_sock, &msg); + page_frag_free(buf); if (ret < 0) return ret; *sentp += ret; @@ -1358,7 +1339,8 @@ static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) svsk->sk_marker = xdr_zero; svsk->sk_tcplen = 0; svsk->sk_datalen = 0; - memset(&svsk->sk_pages[0], 0, sizeof(svsk->sk_pages)); + memset(&svsk->sk_pages[0], 0, + svsk->sk_maxpages * sizeof(struct page *)); tcp_sock_set_nodelay(sk); @@ -1386,7 +1368,6 @@ void svc_sock_update_bufs(struct svc_serv *serv) set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); spin_unlock_bh(&serv->sv_lock); } -EXPORT_SYMBOL_GPL(svc_sock_update_bufs); /* * Initialize socket for RPC use and create svc_sock struct @@ -1398,10 +1379,13 @@ static struct svc_sock *svc_setup_socket(struct svc_serv *serv, struct svc_sock *svsk; struct sock *inet; int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); + unsigned long pages; - svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); + pages = svc_serv_maxpages(serv); + svsk = kzalloc(struct_size(svsk, sk_pages, pages), GFP_KERNEL); if (!svsk) return ERR_PTR(-ENOMEM); + svsk->sk_maxpages = pages; inet = sock->sk; @@ -1560,7 +1544,8 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv, newlen = error; if (protocol == IPPROTO_TCP) { - if ((error = kernel_listen(sock, 64)) < 0) + sk_net_refcnt_upgrade(sock->sk); + if ((error = kernel_listen(sock, SOMAXCONN)) < 0) goto bummer; } @@ -1617,7 +1602,6 @@ static void svc_tcp_sock_detach(struct svc_xprt *xprt) static void svc_sock_free(struct svc_xprt *xprt) { struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); - struct page_frag_cache *pfc = &svsk->sk_frag_cache; struct socket *sock = svsk->sk_sock; trace_svcsock_free(svsk, sock); @@ -1627,8 +1611,7 @@ static void svc_sock_free(struct svc_xprt *xprt) sockfd_put(sock); else sock_release(sock); - if (pfc->va) - __page_frag_cache_drain(virt_to_head_page(pfc->va), - pfc->pagecnt_bias); + + page_frag_cache_drain(&svsk->sk_frag_cache); kfree(svsk); } diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c index 93941ab12549..bdb587a72422 100644 --- a/net/sunrpc/sysctl.c +++ b/net/sunrpc/sysctl.c @@ -40,7 +40,7 @@ EXPORT_SYMBOL_GPL(nlm_debug); #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) -static int proc_do_xprt(struct ctl_table *table, int write, +static int proc_do_xprt(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { char tmpbuf[256]; @@ -62,7 +62,7 @@ static int proc_do_xprt(struct ctl_table *table, int write, } static int -proc_dodebug(struct ctl_table *table, int write, void *buffer, size_t *lenp, +proc_dodebug(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { char tmpbuf[20], *s = NULL; @@ -160,7 +160,6 @@ static struct ctl_table debug_table[] = { .mode = 0444, .proc_handler = proc_do_xprt, }, - { } }; void diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c index 5c8ecdaaa985..09434e1143c5 100644 --- a/net/sunrpc/sysfs.c +++ b/net/sunrpc/sysfs.c @@ -59,6 +59,16 @@ static struct kobject *rpc_sysfs_object_alloc(const char *name, return NULL; } +static inline struct rpc_clnt * +rpc_sysfs_client_kobj_get_clnt(struct kobject *kobj) +{ + struct rpc_sysfs_client *c = container_of(kobj, + struct rpc_sysfs_client, kobject); + struct rpc_clnt *ret = c->clnt; + + return refcount_inc_not_zero(&ret->cl_count) ? ret : NULL; +} + static inline struct rpc_xprt * rpc_sysfs_xprt_kobj_get_xprt(struct kobject *kobj) { @@ -86,6 +96,51 @@ rpc_sysfs_xprt_switch_kobj_get_xprt(struct kobject *kobj) return xprt_switch_get(x->xprt_switch); } +static ssize_t rpc_sysfs_clnt_version_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_clnt *clnt = rpc_sysfs_client_kobj_get_clnt(kobj); + ssize_t ret; + + if (!clnt) + return sprintf(buf, "<closed>\n"); + + ret = sprintf(buf, "%u", clnt->cl_vers); + refcount_dec(&clnt->cl_count); + return ret; +} + +static ssize_t rpc_sysfs_clnt_program_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_clnt *clnt = rpc_sysfs_client_kobj_get_clnt(kobj); + ssize_t ret; + + if (!clnt) + return sprintf(buf, "<closed>\n"); + + ret = sprintf(buf, "%s", clnt->cl_program->name); + refcount_dec(&clnt->cl_count); + return ret; +} + +static ssize_t rpc_sysfs_clnt_max_connect_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_clnt *clnt = rpc_sysfs_client_kobj_get_clnt(kobj); + ssize_t ret; + + if (!clnt) + return sprintf(buf, "<closed>\n"); + + ret = sprintf(buf, "%u\n", clnt->cl_max_connect); + refcount_dec(&clnt->cl_count); + return ret; +} + static ssize_t rpc_sysfs_xprt_dstaddr_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf) @@ -129,6 +184,31 @@ static ssize_t rpc_sysfs_xprt_srcaddr_show(struct kobject *kobj, return ret; } +static const char *xprtsec_strings[] = { + [RPC_XPRTSEC_NONE] = "none", + [RPC_XPRTSEC_TLS_ANON] = "tls-anon", + [RPC_XPRTSEC_TLS_X509] = "tls-x509", +}; + +static ssize_t rpc_sysfs_xprt_xprtsec_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + ssize_t ret; + + if (!xprt) { + ret = sprintf(buf, "<closed>\n"); + goto out; + } + + ret = sprintf(buf, "%s\n", xprtsec_strings[xprt->xprtsec.policy]); + xprt_put(xprt); +out: + return ret; + +} + static ssize_t rpc_sysfs_xprt_info_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf) { @@ -206,6 +286,14 @@ static ssize_t rpc_sysfs_xprt_state_show(struct kobject *kobj, return ret; } +static ssize_t rpc_sysfs_xprt_del_xprt_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + return sprintf(buf, "# delete this xprt\n"); +} + + static ssize_t rpc_sysfs_xprt_switch_info_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf) @@ -225,6 +313,55 @@ static ssize_t rpc_sysfs_xprt_switch_info_show(struct kobject *kobj, return ret; } +static ssize_t rpc_sysfs_xprt_switch_add_xprt_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + return sprintf(buf, "# add one xprt to this xprt_switch\n"); +} + +static ssize_t rpc_sysfs_xprt_switch_add_xprt_store(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt_switch *xprt_switch = + rpc_sysfs_xprt_switch_kobj_get_xprt(kobj); + struct xprt_create xprt_create_args; + struct rpc_xprt *xprt, *new; + + if (!xprt_switch) + return 0; + + xprt = rpc_xprt_switch_get_main_xprt(xprt_switch); + if (!xprt) + goto out; + + xprt_create_args.ident = xprt->xprt_class->ident; + xprt_create_args.net = xprt->xprt_net; + xprt_create_args.dstaddr = (struct sockaddr *)&xprt->addr; + xprt_create_args.addrlen = xprt->addrlen; + xprt_create_args.servername = xprt->servername; + xprt_create_args.bc_xprt = xprt->bc_xprt; + xprt_create_args.xprtsec = xprt->xprtsec; + xprt_create_args.connect_timeout = xprt->connect_timeout; + xprt_create_args.reconnect_timeout = xprt->max_reconnect_timeout; + + new = xprt_create_transport(&xprt_create_args); + if (IS_ERR_OR_NULL(new)) { + count = PTR_ERR(new); + goto out_put_xprt; + } + + rpc_xprt_switch_add_xprt(xprt_switch, new); + xprt_put(new); + +out_put_xprt: + xprt_put(xprt); +out: + xprt_switch_put(xprt_switch); + return count; +} + static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj, struct kobj_attribute *attr, const char *buf, size_t count) @@ -335,6 +472,40 @@ out_put: return count; } +static ssize_t rpc_sysfs_xprt_del_xprt(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + struct rpc_xprt_switch *xps = rpc_sysfs_xprt_kobj_get_xprt_switch(kobj); + + if (!xprt || !xps) { + count = 0; + goto out; + } + + if (xprt->main) { + count = -EINVAL; + goto release_tasks; + } + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + count = -EINTR; + goto out_put; + } + + xprt_set_offline_locked(xprt, xps); + xprt_delete_locked(xprt, xps); + +release_tasks: + xprt_release_write(xprt, NULL); +out_put: + xprt_put(xprt); + xprt_switch_put(xps); +out: + return count; +} + int rpc_sysfs_init(void) { rpc_sunrpc_kset = kset_create_and_add("sunrpc", NULL, kernel_kobj); @@ -398,23 +569,48 @@ static const void *rpc_sysfs_xprt_namespace(const struct kobject *kobj) kobject)->xprt->xprt_net; } +static struct kobj_attribute rpc_sysfs_clnt_version = __ATTR(rpc_version, + 0444, rpc_sysfs_clnt_version_show, NULL); + +static struct kobj_attribute rpc_sysfs_clnt_program = __ATTR(program, + 0444, rpc_sysfs_clnt_program_show, NULL); + +static struct kobj_attribute rpc_sysfs_clnt_max_connect = __ATTR(max_connect, + 0444, rpc_sysfs_clnt_max_connect_show, NULL); + +static struct attribute *rpc_sysfs_rpc_clnt_attrs[] = { + &rpc_sysfs_clnt_version.attr, + &rpc_sysfs_clnt_program.attr, + &rpc_sysfs_clnt_max_connect.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rpc_sysfs_rpc_clnt); + static struct kobj_attribute rpc_sysfs_xprt_dstaddr = __ATTR(dstaddr, 0644, rpc_sysfs_xprt_dstaddr_show, rpc_sysfs_xprt_dstaddr_store); static struct kobj_attribute rpc_sysfs_xprt_srcaddr = __ATTR(srcaddr, 0644, rpc_sysfs_xprt_srcaddr_show, NULL); +static struct kobj_attribute rpc_sysfs_xprt_xprtsec = __ATTR(xprtsec, + 0644, rpc_sysfs_xprt_xprtsec_show, NULL); + static struct kobj_attribute rpc_sysfs_xprt_info = __ATTR(xprt_info, 0444, rpc_sysfs_xprt_info_show, NULL); static struct kobj_attribute rpc_sysfs_xprt_change_state = __ATTR(xprt_state, 0644, rpc_sysfs_xprt_state_show, rpc_sysfs_xprt_state_change); +static struct kobj_attribute rpc_sysfs_xprt_del = __ATTR(del_xprt, + 0644, rpc_sysfs_xprt_del_xprt_show, rpc_sysfs_xprt_del_xprt); + static struct attribute *rpc_sysfs_xprt_attrs[] = { &rpc_sysfs_xprt_dstaddr.attr, &rpc_sysfs_xprt_srcaddr.attr, + &rpc_sysfs_xprt_xprtsec.attr, &rpc_sysfs_xprt_info.attr, &rpc_sysfs_xprt_change_state.attr, + &rpc_sysfs_xprt_del.attr, NULL, }; ATTRIBUTE_GROUPS(rpc_sysfs_xprt); @@ -422,14 +618,20 @@ ATTRIBUTE_GROUPS(rpc_sysfs_xprt); static struct kobj_attribute rpc_sysfs_xprt_switch_info = __ATTR(xprt_switch_info, 0444, rpc_sysfs_xprt_switch_info_show, NULL); +static struct kobj_attribute rpc_sysfs_xprt_switch_add_xprt = + __ATTR(add_xprt, 0644, rpc_sysfs_xprt_switch_add_xprt_show, + rpc_sysfs_xprt_switch_add_xprt_store); + static struct attribute *rpc_sysfs_xprt_switch_attrs[] = { &rpc_sysfs_xprt_switch_info.attr, + &rpc_sysfs_xprt_switch_add_xprt.attr, NULL, }; ATTRIBUTE_GROUPS(rpc_sysfs_xprt_switch); static const struct kobj_type rpc_sysfs_client_type = { .release = rpc_sysfs_client_release, + .default_groups = rpc_sysfs_rpc_clnt_groups, .sysfs_ops = &kobj_sysfs_ops, .namespace = rpc_sysfs_client_namespace, }; diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c index 62e07c330a66..2ea00e354ba6 100644 --- a/net/sunrpc/xdr.c +++ b/net/sunrpc/xdr.c @@ -213,6 +213,7 @@ bvec_overflow: pr_warn_once("%s: bio_vec array overflow\n", __func__); return count - 1; } +EXPORT_SYMBOL_GPL(xdr_buf_to_bvec); /** * xdr_inline_pages - Prepare receive buffer for a large reply @@ -1097,6 +1098,12 @@ out_overflow: * Checks that we have enough buffer space to encode 'nbytes' more * bytes of data. If so, update the total xdr_buf length, and * adjust the length of the current kvec. + * + * The returned pointer is valid only until the next call to + * xdr_reserve_space() or xdr_commit_encode() on @xdr. The current + * implementation of this API guarantees that space reserved for a + * four-byte data item remains valid until @xdr is destroyed, but + * that might not always be true in the future. */ __be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes) { diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c index 09f245cda526..1023361845f9 100644 --- a/net/sunrpc/xprt.c +++ b/net/sunrpc/xprt.c @@ -854,7 +854,7 @@ xprt_schedule_autodisconnect(struct rpc_xprt *xprt) static void xprt_init_autodisconnect(struct timer_list *t) { - struct rpc_xprt *xprt = from_timer(xprt, t, timer); + struct rpc_xprt *xprt = timer_container_of(xprt, t, timer); if (!RB_EMPTY_ROOT(&xprt->recv_queue)) return; @@ -1167,7 +1167,7 @@ xprt_request_enqueue_receive(struct rpc_task *task) spin_unlock(&xprt->queue_lock); /* Turn off autodisconnect */ - del_timer_sync(&xprt->timer); + timer_delete_sync(&xprt->timer); return 0; } @@ -1365,7 +1365,7 @@ xprt_request_enqueue_transmit(struct rpc_task *task) INIT_LIST_HEAD(&req->rq_xmit2); goto out; } - } else if (!req->rq_seqno) { + } else if (req->rq_seqno_count == 0) { list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { if (pos->rq_task->tk_owner != task->tk_owner) continue; @@ -1898,6 +1898,7 @@ xprt_request_init(struct rpc_task *task) req->rq_snd_buf.bvec = NULL; req->rq_rcv_buf.bvec = NULL; req->rq_release_snd_buf = NULL; + req->rq_seqno_count = 0; xprt_init_majortimeo(task, req, task->tk_client->cl_timeout); trace_xprt_reserve(req); @@ -2138,7 +2139,7 @@ static void xprt_destroy(struct rpc_xprt *xprt) * can only run *before* del_time_sync(), never after. */ spin_lock(&xprt->transport_lock); - del_timer_sync(&xprt->timer); + timer_delete_sync(&xprt->timer); spin_unlock(&xprt->transport_lock); /* diff --git a/net/sunrpc/xprtmultipath.c b/net/sunrpc/xprtmultipath.c index 720d3ba742ec..4c5e08b0aa64 100644 --- a/net/sunrpc/xprtmultipath.c +++ b/net/sunrpc/xprtmultipath.c @@ -92,6 +92,27 @@ void rpc_xprt_switch_remove_xprt(struct rpc_xprt_switch *xps, xprt_put(xprt); } +/** + * rpc_xprt_switch_get_main_xprt - Get the 'main' xprt for an xprt switch. + * @xps: pointer to struct rpc_xprt_switch. + */ +struct rpc_xprt *rpc_xprt_switch_get_main_xprt(struct rpc_xprt_switch *xps) +{ + struct rpc_xprt_iter xpi; + struct rpc_xprt *xprt; + + xprt_iter_init_listall(&xpi, xps); + + xprt = xprt_iter_get_next(&xpi); + while (xprt && !xprt->main) { + xprt_put(xprt); + xprt = xprt_iter_get_next(&xpi); + } + + xprt_iter_destroy(&xpi); + return xprt; +} + static DEFINE_IDA(rpc_xprtswitch_ids); void xprt_multipath_cleanup_ids(void) @@ -603,23 +624,6 @@ struct rpc_xprt *xprt_iter_get_helper(struct rpc_xprt_iter *xpi, } /** - * xprt_iter_get_xprt - Returns the rpc_xprt pointed to by the cursor - * @xpi: pointer to rpc_xprt_iter - * - * Returns a reference to the struct rpc_xprt that is currently - * pointed to by the cursor. - */ -struct rpc_xprt *xprt_iter_get_xprt(struct rpc_xprt_iter *xpi) -{ - struct rpc_xprt *xprt; - - rcu_read_lock(); - xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_xprt); - rcu_read_unlock(); - return xprt; -} - -/** * xprt_iter_get_next - Returns the next rpc_xprt following the cursor * @xpi: pointer to rpc_xprt_iter * diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile index 55b21bae866d..3232aa23cdb4 100644 --- a/net/sunrpc/xprtrdma/Makefile +++ b/net/sunrpc/xprtrdma/Makefile @@ -1,7 +1,7 @@ # SPDX-License-Identifier: GPL-2.0 obj-$(CONFIG_SUNRPC_XPRT_RDMA) += rpcrdma.o -rpcrdma-y := transport.o rpc_rdma.o verbs.o frwr_ops.o \ +rpcrdma-y := transport.o rpc_rdma.o verbs.o frwr_ops.o ib_client.o \ svc_rdma.o svc_rdma_backchannel.o svc_rdma_transport.o \ svc_rdma_sendto.o svc_rdma_recvfrom.o svc_rdma_rw.o \ svc_rdma_pcl.o module.o diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c index ffbf99894970..31434aeb8e29 100644 --- a/net/sunrpc/xprtrdma/frwr_ops.c +++ b/net/sunrpc/xprtrdma/frwr_ops.c @@ -54,7 +54,7 @@ static void frwr_cid_init(struct rpcrdma_ep *ep, cid->ci_completion_id = mr->mr_ibmr->res.id; } -static void frwr_mr_unmap(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr) +static void frwr_mr_unmap(struct rpcrdma_mr *mr) { if (mr->mr_device) { trace_xprtrdma_mr_unmap(mr); @@ -73,7 +73,7 @@ void frwr_mr_release(struct rpcrdma_mr *mr) { int rc; - frwr_mr_unmap(mr->mr_xprt, mr); + frwr_mr_unmap(mr); rc = ib_dereg_mr(mr->mr_ibmr); if (rc) @@ -84,7 +84,7 @@ void frwr_mr_release(struct rpcrdma_mr *mr) static void frwr_mr_put(struct rpcrdma_mr *mr) { - frwr_mr_unmap(mr->mr_xprt, mr); + frwr_mr_unmap(mr); /* The MR is returned to the req's MR free list instead * of to the xprt's MR free list. No spinlock is needed. @@ -92,7 +92,8 @@ static void frwr_mr_put(struct rpcrdma_mr *mr) rpcrdma_mr_push(mr, &mr->mr_req->rl_free_mrs); } -/* frwr_reset - Place MRs back on the free list +/** + * frwr_reset - Place MRs back on @req's free list * @req: request to reset * * Used after a failed marshal. For FRWR, this means the MRs diff --git a/net/sunrpc/xprtrdma/ib_client.c b/net/sunrpc/xprtrdma/ib_client.c new file mode 100644 index 000000000000..28c68b5f6823 --- /dev/null +++ b/net/sunrpc/xprtrdma/ib_client.c @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2024 Oracle. All rights reserved. + */ + +/* #include <linux/module.h> +#include <linux/slab.h> */ +#include <linux/xarray.h> +#include <linux/types.h> +#include <linux/kref.h> +#include <linux/completion.h> + +#include <linux/sunrpc/svc_rdma.h> +#include <linux/sunrpc/rdma_rn.h> + +#include "xprt_rdma.h" +#include <trace/events/rpcrdma.h> + +/* Per-ib_device private data for rpcrdma */ +struct rpcrdma_device { + struct kref rd_kref; + unsigned long rd_flags; + struct ib_device *rd_device; + struct xarray rd_xa; + struct completion rd_done; +}; + +#define RPCRDMA_RD_F_REMOVING (0) + +static struct ib_client rpcrdma_ib_client; + +/* + * Listeners have no associated device, so we never register them. + * Note that ib_get_client_data() does not check if @device is + * NULL for us. + */ +static struct rpcrdma_device *rpcrdma_get_client_data(struct ib_device *device) +{ + if (!device) + return NULL; + return ib_get_client_data(device, &rpcrdma_ib_client); +} + +/** + * rpcrdma_rn_register - register to get device removal notifications + * @device: device to monitor + * @rn: notification object that wishes to be notified + * @done: callback to notify caller of device removal + * + * Returns zero on success. The callback in rn_done is guaranteed + * to be invoked when the device is removed, unless this notification + * is unregistered first. + * + * On failure, a negative errno is returned. + */ +int rpcrdma_rn_register(struct ib_device *device, + struct rpcrdma_notification *rn, + void (*done)(struct rpcrdma_notification *rn)) +{ + struct rpcrdma_device *rd = rpcrdma_get_client_data(device); + + if (!rd || test_bit(RPCRDMA_RD_F_REMOVING, &rd->rd_flags)) + return -ENETUNREACH; + + if (xa_alloc(&rd->rd_xa, &rn->rn_index, rn, xa_limit_32b, GFP_KERNEL) < 0) + return -ENOMEM; + kref_get(&rd->rd_kref); + rn->rn_done = done; + trace_rpcrdma_client_register(device, rn); + return 0; +} + +static void rpcrdma_rn_release(struct kref *kref) +{ + struct rpcrdma_device *rd = container_of(kref, struct rpcrdma_device, + rd_kref); + + trace_rpcrdma_client_completion(rd->rd_device); + complete(&rd->rd_done); +} + +/** + * rpcrdma_rn_unregister - stop device removal notifications + * @device: monitored device + * @rn: notification object that no longer wishes to be notified + */ +void rpcrdma_rn_unregister(struct ib_device *device, + struct rpcrdma_notification *rn) +{ + struct rpcrdma_device *rd = rpcrdma_get_client_data(device); + + if (!rd) + return; + + trace_rpcrdma_client_unregister(device, rn); + xa_erase(&rd->rd_xa, rn->rn_index); + kref_put(&rd->rd_kref, rpcrdma_rn_release); +} + +/** + * rpcrdma_add_one - ib_client device insertion callback + * @device: device about to be inserted + * + * Returns zero on success. xprtrdma private data has been allocated + * for this device. On failure, a negative errno is returned. + */ +static int rpcrdma_add_one(struct ib_device *device) +{ + struct rpcrdma_device *rd; + + rd = kzalloc(sizeof(*rd), GFP_KERNEL); + if (!rd) + return -ENOMEM; + + kref_init(&rd->rd_kref); + xa_init_flags(&rd->rd_xa, XA_FLAGS_ALLOC); + rd->rd_device = device; + init_completion(&rd->rd_done); + ib_set_client_data(device, &rpcrdma_ib_client, rd); + + trace_rpcrdma_client_add_one(device); + return 0; +} + +/** + * rpcrdma_remove_one - ib_client device removal callback + * @device: device about to be removed + * @client_data: this module's private per-device data + * + * Upon return, all transports associated with @device have divested + * themselves from IB hardware resources. + */ +static void rpcrdma_remove_one(struct ib_device *device, + void *client_data) +{ + struct rpcrdma_device *rd = client_data; + struct rpcrdma_notification *rn; + unsigned long index; + + trace_rpcrdma_client_remove_one(device); + + set_bit(RPCRDMA_RD_F_REMOVING, &rd->rd_flags); + xa_for_each(&rd->rd_xa, index, rn) + rn->rn_done(rn); + + /* + * Wait only if there are still outstanding notification + * registrants for this device. + */ + if (!refcount_dec_and_test(&rd->rd_kref.refcount)) { + trace_rpcrdma_client_wait_on(device); + wait_for_completion(&rd->rd_done); + } + + trace_rpcrdma_client_remove_one_done(device); + xa_destroy(&rd->rd_xa); + kfree(rd); +} + +static struct ib_client rpcrdma_ib_client = { + .name = "rpcrdma", + .add = rpcrdma_add_one, + .remove = rpcrdma_remove_one, +}; + +/** + * rpcrdma_ib_client_unregister - unregister ib_client for xprtrdma + * + * cel: watch for orphaned rpcrdma_device objects on module unload + */ +void rpcrdma_ib_client_unregister(void) +{ + ib_unregister_client(&rpcrdma_ib_client); +} + +/** + * rpcrdma_ib_client_register - register ib_client for rpcrdma + * + * Returns zero on success, or a negative errno. + */ +int rpcrdma_ib_client_register(void) +{ + return ib_register_client(&rpcrdma_ib_client); +} diff --git a/net/sunrpc/xprtrdma/module.c b/net/sunrpc/xprtrdma/module.c index 45c5b41ac8dc..697f571d4c01 100644 --- a/net/sunrpc/xprtrdma/module.c +++ b/net/sunrpc/xprtrdma/module.c @@ -11,6 +11,7 @@ #include <linux/module.h> #include <linux/init.h> #include <linux/sunrpc/svc_rdma.h> +#include <linux/sunrpc/rdma_rn.h> #include <asm/swab.h> @@ -30,21 +31,32 @@ static void __exit rpc_rdma_cleanup(void) { xprt_rdma_cleanup(); svc_rdma_cleanup(); + rpcrdma_ib_client_unregister(); } static int __init rpc_rdma_init(void) { int rc; + rc = rpcrdma_ib_client_register(); + if (rc) + goto out_rc; + rc = svc_rdma_init(); if (rc) - goto out; + goto out_ib_client; rc = xprt_rdma_init(); if (rc) - svc_rdma_cleanup(); + goto out_svc_rdma; -out: + return 0; + +out_svc_rdma: + svc_rdma_cleanup(); +out_ib_client: + rpcrdma_ib_client_unregister(); +out_rc: return rc; } diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c index 190a4de239c8..1478c41c7e9d 100644 --- a/net/sunrpc/xprtrdma/rpc_rdma.c +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -1471,8 +1471,7 @@ void rpcrdma_reply_handler(struct rpcrdma_rep *rep) credits = 1; /* don't deadlock */ else if (credits > r_xprt->rx_ep->re_max_requests) credits = r_xprt->rx_ep->re_max_requests; - rpcrdma_post_recvs(r_xprt, credits + (buf->rb_bc_srv_max_requests << 1), - false); + rpcrdma_post_recvs(r_xprt, credits + (buf->rb_bc_srv_max_requests << 1)); if (buf->rb_credits != credits) rpcrdma_update_cwnd(r_xprt, credits); diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c index f86970733eb0..415c0310101f 100644 --- a/net/sunrpc/xprtrdma/svc_rdma.c +++ b/net/sunrpc/xprtrdma/svc_rdma.c @@ -74,7 +74,7 @@ enum { SVCRDMA_COUNTER_BUFSIZ = sizeof(unsigned long long), }; -static int svcrdma_counter_handler(struct ctl_table *table, int write, +static int svcrdma_counter_handler(const struct ctl_table *table, int write, void *buffer, size_t *lenp, loff_t *ppos) { struct percpu_counter *stat = (struct percpu_counter *)table->data; @@ -209,7 +209,6 @@ static struct ctl_table svcrdma_parm_table[] = { .extra1 = &zero, .extra2 = &zero, }, - { }, }; static void svc_rdma_proc_cleanup(void) @@ -234,25 +233,34 @@ static int svc_rdma_proc_init(void) rc = percpu_counter_init(&svcrdma_stat_read, 0, GFP_KERNEL); if (rc) - goto out_err; + goto err; rc = percpu_counter_init(&svcrdma_stat_recv, 0, GFP_KERNEL); if (rc) - goto out_err; + goto err_read; rc = percpu_counter_init(&svcrdma_stat_sq_starve, 0, GFP_KERNEL); if (rc) - goto out_err; + goto err_recv; rc = percpu_counter_init(&svcrdma_stat_write, 0, GFP_KERNEL); if (rc) - goto out_err; + goto err_sq; svcrdma_table_header = register_sysctl("sunrpc/svc_rdma", svcrdma_parm_table); + if (!svcrdma_table_header) + goto err_write; + return 0; -out_err: +err_write: + rc = -ENOMEM; + percpu_counter_destroy(&svcrdma_stat_write); +err_sq: percpu_counter_destroy(&svcrdma_stat_sq_starve); +err_recv: percpu_counter_destroy(&svcrdma_stat_recv); +err_read: percpu_counter_destroy(&svcrdma_stat_read); +err: return rc; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c index d72953f29258..e7e4a39ca6c6 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -94,7 +94,7 @@ #include <linux/slab.h> #include <linux/spinlock.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> #include <rdma/ib_verbs.h> #include <rdma/rdma_cm.h> @@ -120,12 +120,16 @@ svc_rdma_recv_ctxt_alloc(struct svcxprt_rdma *rdma) { int node = ibdev_to_node(rdma->sc_cm_id->device); struct svc_rdma_recv_ctxt *ctxt; + unsigned long pages; dma_addr_t addr; void *buffer; - ctxt = kzalloc_node(sizeof(*ctxt), GFP_KERNEL, node); + pages = svc_serv_maxpages(rdma->sc_xprt.xpt_server); + ctxt = kzalloc_node(struct_size(ctxt, rc_pages, pages), + GFP_KERNEL, node); if (!ctxt) goto fail0; + ctxt->rc_maxpages = pages; buffer = kmalloc_node(rdma->sc_max_req_size, GFP_KERNEL, node); if (!buffer) goto fail1; @@ -493,7 +497,13 @@ static bool xdr_check_write_chunk(struct svc_rdma_recv_ctxt *rctxt) if (xdr_stream_decode_u32(&rctxt->rc_stream, &segcount)) return false; - /* A bogus segcount causes this buffer overflow check to fail. */ + /* Before trusting the segcount value enough to use it in + * a computation, perform a simple range check. This is an + * arbitrary but sensible limit (ie, not architectural). + */ + if (unlikely(segcount > rctxt->rc_maxpages)) + return false; + p = xdr_inline_decode(&rctxt->rc_stream, segcount * rpcrdma_segment_maxsz * sizeof(*p)); return p != NULL; diff --git a/net/sunrpc/xprtrdma/svc_rdma_rw.c b/net/sunrpc/xprtrdma/svc_rdma_rw.c index f2a100c4c81f..661b3fe2779f 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_rw.c +++ b/net/sunrpc/xprtrdma/svc_rdma_rw.c @@ -231,28 +231,6 @@ static void svc_rdma_write_info_free(struct svc_rdma_write_info *info) } /** - * svc_rdma_write_chunk_release - Release Write chunk I/O resources - * @rdma: controlling transport - * @ctxt: Send context that is being released - */ -void svc_rdma_write_chunk_release(struct svcxprt_rdma *rdma, - struct svc_rdma_send_ctxt *ctxt) -{ - struct svc_rdma_write_info *info; - struct svc_rdma_chunk_ctxt *cc; - - while (!list_empty(&ctxt->sc_write_info_list)) { - info = list_first_entry(&ctxt->sc_write_info_list, - struct svc_rdma_write_info, wi_list); - list_del(&info->wi_list); - - cc = &info->wi_cc; - svc_rdma_wake_send_waiters(rdma, cc->cc_sqecount); - svc_rdma_write_info_free(info); - } -} - -/** * svc_rdma_reply_chunk_release - Release Reply chunk I/O resources * @rdma: controlling transport * @ctxt: Send context that is being released @@ -308,11 +286,13 @@ static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) struct ib_cqe *cqe = wc->wr_cqe; struct svc_rdma_chunk_ctxt *cc = container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svc_rdma_write_info *info = + container_of(cc, struct svc_rdma_write_info, wi_cc); switch (wc->status) { case IB_WC_SUCCESS: trace_svcrdma_wc_write(&cc->cc_cid); - return; + break; case IB_WC_WR_FLUSH_ERR: trace_svcrdma_wc_write_flush(wc, &cc->cc_cid); break; @@ -320,11 +300,12 @@ static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) trace_svcrdma_wc_write_err(wc, &cc->cc_cid); } - /* The RDMA Write has flushed, so the client won't get - * some of the outgoing RPC message. Signal the loss - * to the client by closing the connection. - */ - svc_xprt_deferred_close(&rdma->sc_xprt); + svc_rdma_wake_send_waiters(rdma, cc->cc_sqecount); + + if (unlikely(wc->status != IB_WC_SUCCESS)) + svc_xprt_deferred_close(&rdma->sc_xprt); + + svc_rdma_write_info_free(info); } /** @@ -620,19 +601,13 @@ static int svc_rdma_xb_write(const struct xdr_buf *xdr, void *data) return xdr->len; } -/* Link Write WRs for @chunk onto @sctxt's WR chain. - */ -static int svc_rdma_prepare_write_chunk(struct svcxprt_rdma *rdma, - struct svc_rdma_send_ctxt *sctxt, - const struct svc_rdma_chunk *chunk, - const struct xdr_buf *xdr) +static int svc_rdma_send_write_chunk(struct svcxprt_rdma *rdma, + const struct svc_rdma_chunk *chunk, + const struct xdr_buf *xdr) { struct svc_rdma_write_info *info; struct svc_rdma_chunk_ctxt *cc; - struct ib_send_wr *first_wr; struct xdr_buf payload; - struct list_head *pos; - struct ib_cqe *cqe; int ret; if (xdr_buf_subsegment(xdr, &payload, chunk->ch_position, @@ -648,25 +623,10 @@ static int svc_rdma_prepare_write_chunk(struct svcxprt_rdma *rdma, if (ret != payload.len) goto out_err; - ret = -EINVAL; - if (unlikely(cc->cc_sqecount > rdma->sc_sq_depth)) - goto out_err; - - first_wr = sctxt->sc_wr_chain; - cqe = &cc->cc_cqe; - list_for_each(pos, &cc->cc_rwctxts) { - struct svc_rdma_rw_ctxt *rwc; - - rwc = list_entry(pos, struct svc_rdma_rw_ctxt, rw_list); - first_wr = rdma_rw_ctx_wrs(&rwc->rw_ctx, rdma->sc_qp, - rdma->sc_port_num, cqe, first_wr); - cqe = NULL; - } - sctxt->sc_wr_chain = first_wr; - sctxt->sc_sqecount += cc->cc_sqecount; - list_add(&info->wi_list, &sctxt->sc_write_info_list); - trace_svcrdma_post_write_chunk(&cc->cc_cid, cc->cc_sqecount); + ret = svc_rdma_post_chunk_ctxt(rdma, cc); + if (ret < 0) + goto out_err; return 0; out_err: @@ -675,27 +635,25 @@ out_err: } /** - * svc_rdma_prepare_write_list - Construct WR chain for sending Write list + * svc_rdma_send_write_list - Send all chunks on the Write list * @rdma: controlling RDMA transport - * @write_pcl: Write list provisioned by the client - * @sctxt: Send WR resources + * @rctxt: Write list provisioned by the client * @xdr: xdr_buf containing an RPC Reply message * * Returns zero on success, or a negative errno if one or more * Write chunks could not be sent. */ -int svc_rdma_prepare_write_list(struct svcxprt_rdma *rdma, - const struct svc_rdma_pcl *write_pcl, - struct svc_rdma_send_ctxt *sctxt, - const struct xdr_buf *xdr) +int svc_rdma_send_write_list(struct svcxprt_rdma *rdma, + const struct svc_rdma_recv_ctxt *rctxt, + const struct xdr_buf *xdr) { struct svc_rdma_chunk *chunk; int ret; - pcl_for_each_chunk(chunk, write_pcl) { + pcl_for_each_chunk(chunk, &rctxt->rc_write_pcl) { if (!chunk->ch_payload_length) break; - ret = svc_rdma_prepare_write_chunk(rdma, sctxt, chunk, xdr); + ret = svc_rdma_send_write_chunk(rdma, chunk, xdr); if (ret < 0) return ret; } @@ -807,7 +765,7 @@ static int svc_rdma_build_read_segment(struct svc_rqst *rqstp, } len -= seg_len; - if (len && ((head->rc_curpage + 1) > ARRAY_SIZE(rqstp->rq_pages))) + if (len && ((head->rc_curpage + 1) > rqstp->rq_maxpages)) goto out_overrun; } diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c index dfca39abd16c..914cd263c2f1 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_sendto.c +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -100,7 +100,7 @@ */ #include <linux/spinlock.h> -#include <asm/unaligned.h> +#include <linux/unaligned.h> #include <rdma/ib_verbs.h> #include <rdma/rdma_cm.h> @@ -118,6 +118,7 @@ svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma) { int node = ibdev_to_node(rdma->sc_cm_id->device); struct svc_rdma_send_ctxt *ctxt; + unsigned long pages; dma_addr_t addr; void *buffer; int i; @@ -126,13 +127,19 @@ svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma) GFP_KERNEL, node); if (!ctxt) goto fail0; + pages = svc_serv_maxpages(rdma->sc_xprt.xpt_server); + ctxt->sc_pages = kcalloc_node(pages, sizeof(struct page *), + GFP_KERNEL, node); + if (!ctxt->sc_pages) + goto fail1; + ctxt->sc_maxpages = pages; buffer = kmalloc_node(rdma->sc_max_req_size, GFP_KERNEL, node); if (!buffer) - goto fail1; + goto fail2; addr = ib_dma_map_single(rdma->sc_pd->device, buffer, rdma->sc_max_req_size, DMA_TO_DEVICE); if (ib_dma_mapping_error(rdma->sc_pd->device, addr)) - goto fail2; + goto fail3; svc_rdma_send_cid_init(rdma, &ctxt->sc_cid); @@ -142,7 +149,6 @@ svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma) ctxt->sc_send_wr.sg_list = ctxt->sc_sges; ctxt->sc_send_wr.send_flags = IB_SEND_SIGNALED; ctxt->sc_cqe.done = svc_rdma_wc_send; - INIT_LIST_HEAD(&ctxt->sc_write_info_list); ctxt->sc_xprt_buf = buffer; xdr_buf_init(&ctxt->sc_hdrbuf, ctxt->sc_xprt_buf, rdma->sc_max_req_size); @@ -152,8 +158,10 @@ svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma) ctxt->sc_sges[i].lkey = rdma->sc_pd->local_dma_lkey; return ctxt; -fail2: +fail3: kfree(buffer); +fail2: + kfree(ctxt->sc_pages); fail1: kfree(ctxt); fail0: @@ -177,6 +185,7 @@ void svc_rdma_send_ctxts_destroy(struct svcxprt_rdma *rdma) rdma->sc_max_req_size, DMA_TO_DEVICE); kfree(ctxt->sc_xprt_buf); + kfree(ctxt->sc_pages); kfree(ctxt); } } @@ -228,7 +237,6 @@ static void svc_rdma_send_ctxt_release(struct svcxprt_rdma *rdma, struct ib_device *device = rdma->sc_cm_id->device; unsigned int i; - svc_rdma_write_chunk_release(rdma, ctxt); svc_rdma_reply_chunk_release(rdma, ctxt); if (ctxt->sc_page_count) @@ -1015,8 +1023,7 @@ int svc_rdma_sendto(struct svc_rqst *rqstp) if (!p) goto put_ctxt; - ret = svc_rdma_prepare_write_list(rdma, &rctxt->rc_write_pcl, sctxt, - &rqstp->rq_res); + ret = svc_rdma_send_write_list(rdma, rctxt, &rqstp->rq_res); if (ret < 0) goto put_ctxt; diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index 2b1c16b9547d..3d7f1413df02 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -65,6 +65,8 @@ static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, struct net *net, int node); +static int svc_rdma_listen_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event); static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, struct net *net, struct sockaddr *sa, int salen, @@ -122,6 +124,41 @@ static void qp_event_handler(struct ib_event *event, void *context) } } +static struct rdma_cm_id * +svc_rdma_create_listen_id(struct net *net, struct sockaddr *sap, + void *context) +{ + struct rdma_cm_id *listen_id; + int ret; + + listen_id = rdma_create_id(net, svc_rdma_listen_handler, context, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(listen_id)) + return listen_id; + + /* Allow both IPv4 and IPv6 sockets to bind a single port + * at the same time. + */ +#if IS_ENABLED(CONFIG_IPV6) + ret = rdma_set_afonly(listen_id, 1); + if (ret) + goto out_destroy; +#endif + ret = rdma_bind_addr(listen_id, sap); + if (ret) + goto out_destroy; + + ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); + if (ret) + goto out_destroy; + + return listen_id; + +out_destroy: + rdma_destroy_id(listen_id); + return ERR_PTR(ret); +} + static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, struct net *net, int node) { @@ -247,17 +284,31 @@ static void handle_connect_req(struct rdma_cm_id *new_cma_id, * * Return values: * %0: Do not destroy @cma_id - * %1: Destroy @cma_id (never returned here) + * %1: Destroy @cma_id * * NB: There is never a DEVICE_REMOVAL event for INADDR_ANY listeners. */ static int svc_rdma_listen_handler(struct rdma_cm_id *cma_id, struct rdma_cm_event *event) { + struct sockaddr *sap = (struct sockaddr *)&cma_id->route.addr.src_addr; + struct svcxprt_rdma *cma_xprt = cma_id->context; + struct svc_xprt *cma_rdma = &cma_xprt->sc_xprt; + struct rdma_cm_id *listen_id; + switch (event->event) { case RDMA_CM_EVENT_CONNECT_REQUEST: handle_connect_req(cma_id, &event->param.conn); break; + case RDMA_CM_EVENT_ADDR_CHANGE: + listen_id = svc_rdma_create_listen_id(cma_rdma->xpt_net, + sap, cma_xprt); + if (IS_ERR(listen_id)) { + pr_err("Listener dead, address change failed for device %s\n", + cma_id->device->name); + } else + cma_xprt->sc_cm_id = listen_id; + return 1; default: break; } @@ -288,7 +339,6 @@ static int svc_rdma_cma_handler(struct rdma_cm_id *cma_id, svc_xprt_enqueue(xprt); break; case RDMA_CM_EVENT_DISCONNECTED: - case RDMA_CM_EVENT_DEVICE_REMOVAL: svc_xprt_deferred_close(xprt); break; default: @@ -307,7 +357,6 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, { struct rdma_cm_id *listen_id; struct svcxprt_rdma *cma_xprt; - int ret; if (sa->sa_family != AF_INET && sa->sa_family != AF_INET6) return ERR_PTR(-EAFNOSUPPORT); @@ -317,30 +366,13 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags); strcpy(cma_xprt->sc_xprt.xpt_remotebuf, "listener"); - listen_id = rdma_create_id(net, svc_rdma_listen_handler, cma_xprt, - RDMA_PS_TCP, IB_QPT_RC); + listen_id = svc_rdma_create_listen_id(net, sa, cma_xprt); if (IS_ERR(listen_id)) { - ret = PTR_ERR(listen_id); - goto err0; + kfree(cma_xprt); + return ERR_CAST(listen_id); } - - /* Allow both IPv4 and IPv6 sockets to bind a single port - * at the same time. - */ -#if IS_ENABLED(CONFIG_IPV6) - ret = rdma_set_afonly(listen_id, 1); - if (ret) - goto err1; -#endif - ret = rdma_bind_addr(listen_id, sa); - if (ret) - goto err1; cma_xprt->sc_cm_id = listen_id; - ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); - if (ret) - goto err1; - /* * We need to use the address from the cm_id in case the * caller specified 0 for the port number. @@ -349,12 +381,16 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, svc_xprt_set_local(&cma_xprt->sc_xprt, sa, salen); return &cma_xprt->sc_xprt; +} - err1: - rdma_destroy_id(listen_id); - err0: - kfree(cma_xprt); - return ERR_PTR(ret); +static void svc_rdma_xprt_done(struct rpcrdma_notification *rn) +{ + struct svcxprt_rdma *rdma = container_of(rn, struct svcxprt_rdma, + sc_rn); + struct rdma_cm_id *id = rdma->sc_cm_id; + + trace_svcrdma_device_removal(id); + svc_xprt_close(&rdma->sc_xprt); } /* @@ -370,12 +406,12 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, */ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) { + unsigned int ctxts, rq_depth, maxpayload; struct svcxprt_rdma *listen_rdma; struct svcxprt_rdma *newxprt = NULL; struct rdma_conn_param conn_param; struct rpcrdma_connect_private pmsg; struct ib_qp_init_attr qp_attr; - unsigned int ctxts, rq_depth; struct ib_device *dev; int ret = 0; RPC_IFDEBUG(struct sockaddr *sap); @@ -398,6 +434,9 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) dev = newxprt->sc_cm_id->device; newxprt->sc_port_num = newxprt->sc_cm_id->port_num; + if (rpcrdma_rn_register(dev, &newxprt->sc_rn, svc_rdma_xprt_done)) + goto errout; + newxprt->sc_max_req_size = svcrdma_max_req_size; newxprt->sc_max_requests = svcrdma_max_requests; newxprt->sc_max_bc_requests = svcrdma_max_bc_requests; @@ -423,12 +462,14 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) newxprt->sc_max_bc_requests = 2; } - /* Arbitrarily estimate the number of rw_ctxs needed for - * this transport. This is enough rw_ctxs to make forward - * progress even if the client is using one rkey per page - * in each Read chunk. + /* Arbitrary estimate of the needed number of rdma_rw contexts. */ - ctxts = 3 * RPCSVC_MAXPAGES; + maxpayload = min(xprt->xpt_server->sv_max_payload, + RPCSVC_MAXPAYLOAD_RDMA); + ctxts = newxprt->sc_max_requests * 3 * + rdma_rw_mr_factor(dev, newxprt->sc_port_num, + maxpayload >> PAGE_SHIFT); + newxprt->sc_sq_depth = rq_depth + ctxts; if (newxprt->sc_sq_depth > dev->attrs.max_qp_wr) newxprt->sc_sq_depth = dev->attrs.max_qp_wr; @@ -536,6 +577,7 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) if (newxprt->sc_qp && !IS_ERR(newxprt->sc_qp)) ib_destroy_qp(newxprt->sc_qp); rdma_destroy_id(newxprt->sc_cm_id); + rpcrdma_rn_unregister(dev, &newxprt->sc_rn); /* This call to put will destroy the transport */ svc_xprt_put(&newxprt->sc_xprt); return NULL; @@ -553,6 +595,7 @@ static void __svc_rdma_free(struct work_struct *work) { struct svcxprt_rdma *rdma = container_of(work, struct svcxprt_rdma, sc_work); + struct ib_device *device = rdma->sc_cm_id->device; /* This blocks until the Completion Queues are empty */ if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) @@ -581,6 +624,8 @@ static void __svc_rdma_free(struct work_struct *work) /* Destroy the CM ID */ rdma_destroy_id(rdma->sc_cm_id); + if (!test_bit(XPT_LISTENER, &rdma->sc_xprt.xpt_flags)) + rpcrdma_rn_unregister(device, &rdma->sc_rn); kfree(rdma); } diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c index 29b0562d62e7..9a8ce5df83ca 100644 --- a/net/sunrpc/xprtrdma/transport.c +++ b/net/sunrpc/xprtrdma/transport.c @@ -137,7 +137,6 @@ static struct ctl_table xr_tunables_table[] = { .mode = 0644, .proc_handler = proc_dointvec, }, - { }, }; #endif diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c index 4f8d7efa469f..63262ef0c2e3 100644 --- a/net/sunrpc/xprtrdma/verbs.c +++ b/net/sunrpc/xprtrdma/verbs.c @@ -49,14 +49,14 @@ * o buffer memory */ +#include <linux/bitops.h> #include <linux/interrupt.h> #include <linux/slab.h> #include <linux/sunrpc/addr.h> #include <linux/sunrpc/svc_rdma.h> #include <linux/log2.h> -#include <asm-generic/barrier.h> -#include <asm/bitops.h> +#include <asm/barrier.h> #include <rdma/ib_cm.h> @@ -69,13 +69,15 @@ static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt, struct rpcrdma_sendctx *sc); static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt); static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt); -static void rpcrdma_rep_destroy(struct rpcrdma_rep *rep); static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt); static void rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt); static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt); static void rpcrdma_ep_get(struct rpcrdma_ep *ep); static int rpcrdma_ep_put(struct rpcrdma_ep *ep); static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc_node(size_t size, enum dma_data_direction direction, + int node); +static struct rpcrdma_regbuf * rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction); static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb); static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb); @@ -222,7 +224,6 @@ static void rpcrdma_update_cm_private(struct rpcrdma_ep *ep, static int rpcrdma_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) { - struct sockaddr *sap = (struct sockaddr *)&id->route.addr.dst_addr; struct rpcrdma_ep *ep = id->context; might_sleep(); @@ -241,10 +242,6 @@ rpcrdma_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) ep->re_async_rc = -ENETUNREACH; complete(&ep->re_done); return 0; - case RDMA_CM_EVENT_DEVICE_REMOVAL: - pr_info("rpcrdma: removing device %s for %pISpc\n", - ep->re_id->device->name, sap); - fallthrough; case RDMA_CM_EVENT_ADDR_CHANGE: ep->re_connect_status = -ENODEV; goto disconnected; @@ -280,6 +277,14 @@ disconnected: return 0; } +static void rpcrdma_ep_removal_done(struct rpcrdma_notification *rn) +{ + struct rpcrdma_ep *ep = container_of(rn, struct rpcrdma_ep, re_rn); + + trace_xprtrdma_device_removal(ep->re_id); + xprt_force_disconnect(ep->re_xprt); +} + static struct rdma_cm_id *rpcrdma_create_id(struct rpcrdma_xprt *r_xprt, struct rpcrdma_ep *ep) { @@ -319,6 +324,10 @@ static struct rdma_cm_id *rpcrdma_create_id(struct rpcrdma_xprt *r_xprt, if (rc) goto out; + rc = rpcrdma_rn_register(id->device, &ep->re_rn, rpcrdma_ep_removal_done); + if (rc) + goto out; + return id; out: @@ -346,6 +355,8 @@ static void rpcrdma_ep_destroy(struct kref *kref) ib_dealloc_pd(ep->re_pd); ep->re_pd = NULL; + rpcrdma_rn_unregister(ep->re_id->device, &ep->re_rn); + kfree(ep); module_put(THIS_MODULE); } @@ -501,7 +512,7 @@ int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt) * outstanding Receives. */ rpcrdma_ep_get(ep); - rpcrdma_post_recvs(r_xprt, 1, true); + rpcrdma_post_recvs(r_xprt, 1); rc = rdma_connect(ep->re_id, &ep->re_remote_cma); if (rc) @@ -893,6 +904,8 @@ static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt) static void rpcrdma_req_reset(struct rpcrdma_req *req) { + struct rpcrdma_mr *mr; + /* Credits are valid for only one connection */ req->rl_slot.rq_cong = 0; @@ -902,7 +915,19 @@ static void rpcrdma_req_reset(struct rpcrdma_req *req) rpcrdma_regbuf_dma_unmap(req->rl_sendbuf); rpcrdma_regbuf_dma_unmap(req->rl_recvbuf); - frwr_reset(req); + /* The verbs consumer can't know the state of an MR on the + * req->rl_registered list unless a successful completion + * has occurred, so they cannot be re-used. + */ + while ((mr = rpcrdma_mr_pop(&req->rl_registered))) { + struct rpcrdma_buffer *buf = &mr->mr_xprt->rx_buf; + + spin_lock(&buf->rb_lock); + list_del(&mr->mr_all); + spin_unlock(&buf->rb_lock); + + frwr_mr_release(mr); + } } /* ASSUMPTION: the rb_allreqs list is stable for the duration, @@ -920,18 +945,20 @@ static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt) } static noinline -struct rpcrdma_rep *rpcrdma_rep_create(struct rpcrdma_xprt *r_xprt, - bool temp) +struct rpcrdma_rep *rpcrdma_rep_create(struct rpcrdma_xprt *r_xprt) { struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_device *device = ep->re_id->device; struct rpcrdma_rep *rep; rep = kzalloc(sizeof(*rep), XPRTRDMA_GFP_FLAGS); if (rep == NULL) goto out; - rep->rr_rdmabuf = rpcrdma_regbuf_alloc(r_xprt->rx_ep->re_inline_recv, - DMA_FROM_DEVICE); + rep->rr_rdmabuf = rpcrdma_regbuf_alloc_node(ep->re_inline_recv, + DMA_FROM_DEVICE, + ibdev_to_node(device)); if (!rep->rr_rdmabuf) goto out_free; @@ -946,7 +973,6 @@ struct rpcrdma_rep *rpcrdma_rep_create(struct rpcrdma_xprt *r_xprt, rep->rr_recv_wr.wr_cqe = &rep->rr_cqe; rep->rr_recv_wr.sg_list = &rep->rr_rdmabuf->rg_iov; rep->rr_recv_wr.num_sge = 1; - rep->rr_temp = temp; spin_lock(&buf->rb_lock); list_add(&rep->rr_all, &buf->rb_all_reps); @@ -965,17 +991,6 @@ static void rpcrdma_rep_free(struct rpcrdma_rep *rep) kfree(rep); } -static void rpcrdma_rep_destroy(struct rpcrdma_rep *rep) -{ - struct rpcrdma_buffer *buf = &rep->rr_rxprt->rx_buf; - - spin_lock(&buf->rb_lock); - list_del(&rep->rr_all); - spin_unlock(&buf->rb_lock); - - rpcrdma_rep_free(rep); -} - static struct rpcrdma_rep *rpcrdma_rep_get_locked(struct rpcrdma_buffer *buf) { struct llist_node *node; @@ -1007,10 +1022,8 @@ static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt) struct rpcrdma_buffer *buf = &r_xprt->rx_buf; struct rpcrdma_rep *rep; - list_for_each_entry(rep, &buf->rb_all_reps, rr_all) { + list_for_each_entry(rep, &buf->rb_all_reps, rr_all) rpcrdma_regbuf_dma_unmap(rep->rr_rdmabuf); - rep->rr_temp = true; /* Mark this rep for destruction */ - } } static void rpcrdma_reps_destroy(struct rpcrdma_buffer *buf) @@ -1227,14 +1240,15 @@ void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req) * or Replies they may be registered externally via frwr_map. */ static struct rpcrdma_regbuf * -rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction) +rpcrdma_regbuf_alloc_node(size_t size, enum dma_data_direction direction, + int node) { struct rpcrdma_regbuf *rb; - rb = kmalloc(sizeof(*rb), XPRTRDMA_GFP_FLAGS); + rb = kmalloc_node(sizeof(*rb), XPRTRDMA_GFP_FLAGS, node); if (!rb) return NULL; - rb->rg_data = kmalloc(size, XPRTRDMA_GFP_FLAGS); + rb->rg_data = kmalloc_node(size, XPRTRDMA_GFP_FLAGS, node); if (!rb->rg_data) { kfree(rb); return NULL; @@ -1246,6 +1260,12 @@ rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction) return rb; } +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction) +{ + return rpcrdma_regbuf_alloc_node(size, direction, NUMA_NO_NODE); +} + /** * rpcrdma_regbuf_realloc - re-allocate a SEND/RECV buffer * @rb: regbuf to reallocate @@ -1323,10 +1343,9 @@ static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb) * rpcrdma_post_recvs - Refill the Receive Queue * @r_xprt: controlling transport instance * @needed: current credit grant - * @temp: mark Receive buffers to be deleted after one use * */ -void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp) +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed) { struct rpcrdma_buffer *buf = &r_xprt->rx_buf; struct rpcrdma_ep *ep = r_xprt->rx_ep; @@ -1340,8 +1359,7 @@ void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp) if (likely(ep->re_receive_count > needed)) goto out; needed -= ep->re_receive_count; - if (!temp) - needed += RPCRDMA_MAX_RECV_BATCH; + needed += RPCRDMA_MAX_RECV_BATCH; if (atomic_inc_return(&ep->re_receiving) > 1) goto out; @@ -1350,12 +1368,8 @@ void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp) wr = NULL; while (needed) { rep = rpcrdma_rep_get_locked(buf); - if (rep && rep->rr_temp) { - rpcrdma_rep_destroy(rep); - continue; - } if (!rep) - rep = rpcrdma_rep_create(r_xprt, temp); + rep = rpcrdma_rep_create(r_xprt); if (!rep) break; if (!rpcrdma_regbuf_dma_map(r_xprt, rep->rr_rdmabuf)) { diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h index da409450dfc0..8147d2b41494 100644 --- a/net/sunrpc/xprtrdma/xprt_rdma.h +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -56,6 +56,7 @@ #include <linux/sunrpc/rpc_rdma_cid.h> /* completion IDs */ #include <linux/sunrpc/rpc_rdma.h> /* RPC/RDMA protocol */ #include <linux/sunrpc/xprtrdma.h> /* xprt parameters */ +#include <linux/sunrpc/rdma_rn.h> /* removal notifications */ #define RDMA_RESOLVE_TIMEOUT (5000) /* 5 seconds */ #define RDMA_CONNECT_RETRY_MAX (2) /* retries if no listener backlog */ @@ -92,6 +93,7 @@ struct rpcrdma_ep { struct rpcrdma_connect_private re_cm_private; struct rdma_conn_param re_remote_cma; + struct rpcrdma_notification re_rn; int re_receive_count; unsigned int re_max_requests; /* depends on device */ unsigned int re_inline_send; /* negotiated */ @@ -198,7 +200,6 @@ struct rpcrdma_rep { __be32 rr_proc; int rr_wc_flags; u32 rr_inv_rkey; - bool rr_temp; struct rpcrdma_regbuf *rr_rdmabuf; struct rpcrdma_xprt *rr_rxprt; struct rpc_rqst *rr_rqst; @@ -466,7 +467,7 @@ void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc); int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt); void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt); -void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp); +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed); /* * Buffer calls - xprtrdma/verbs.c diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index bb9b747d58a1..04ff66758fc3 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -160,7 +160,6 @@ static struct ctl_table xs_tunables_table[] = { .mode = 0644, .proc_handler = proc_dointvec_jiffies, }, - { }, }; /* @@ -1199,6 +1198,7 @@ static void xs_sock_reset_state_flags(struct rpc_xprt *xprt) clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state); clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state); clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state); + clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); } static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr) @@ -1279,6 +1279,7 @@ static void xs_reset_transport(struct sock_xprt *transport) transport->file = NULL; sk->sk_user_data = NULL; + sk->sk_sndtimeo = 0; xs_restore_old_callbacks(transport, sk); xprt_clear_connected(xprt); @@ -1940,6 +1941,9 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt, goto out; } + if (protocol == IPPROTO_TCP) + sk_net_refcnt_upgrade(sock->sk); + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); if (IS_ERR(filp)) return ERR_CAST(filp); @@ -2442,6 +2446,13 @@ static void xs_tcp_setup_socket(struct work_struct *work) transport->srcport = 0; status = -EAGAIN; break; + case -EPERM: + /* Happens, for instance, if a BPF program is preventing + * the connect. Remap the error so upper layers can better + * deal with it. + */ + status = -ECONNREFUSED; + fallthrough; case -EINVAL: /* Happens, for instance, if the user specified a link * local IPv6 address without a scope-id. @@ -2453,6 +2464,7 @@ static void xs_tcp_setup_socket(struct work_struct *work) case -EHOSTUNREACH: case -EADDRINUSE: case -ENOBUFS: + case -ENOTCONN: break; default: printk("%s: connect returned unhandled error %d\n", @@ -2565,7 +2577,15 @@ static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid) struct sock_xprt *lower_transport = container_of(lower_xprt, struct sock_xprt, xprt); - lower_transport->xprt_err = status ? -EACCES : 0; + switch (status) { + case 0: + case -EACCES: + case -ETIMEDOUT: + lower_transport->xprt_err = status; + break; + default: + lower_transport->xprt_err = -EACCES; + } complete(&lower_transport->handshake_done); xprt_put(lower_xprt); } @@ -2607,11 +2627,10 @@ static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_par rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done, XS_TLS_HANDSHAKE_TO); if (rc <= 0) { - if (!tls_handshake_cancel(sk)) { - if (rc == 0) - rc = -ETIMEDOUT; - goto out_put_xprt; - } + tls_handshake_cancel(sk); + if (rc == 0) + rc = -ETIMEDOUT; + goto out_put_xprt; } rc = lower_transport->xprt_err; @@ -2664,6 +2683,7 @@ static void xs_tcp_tls_setup_socket(struct work_struct *work) .xprtsec = { .policy = RPC_XPRTSEC_NONE, }, + .stats = upper_clnt->cl_stats, }; unsigned int pflags = current->flags; struct rpc_clnt *lower_clnt; @@ -2706,20 +2726,14 @@ static void xs_tcp_tls_setup_socket(struct work_struct *work) if (status) goto out_close; xprt_release_write(lower_xprt, NULL); - trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0); - if (!xprt_test_and_set_connected(upper_xprt)) { - upper_xprt->connect_cookie++; - clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state); - xprt_clear_connecting(upper_xprt); - - upper_xprt->stat.connect_count++; - upper_xprt->stat.connect_time += (long)jiffies - - upper_xprt->stat.connect_start; - xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING); - } rpc_shutdown_client(lower_clnt); + /* Check for ingress data that arrived before the socket's + * ->data_ready callback was set up. + */ + xs_poll_check_readable(upper_transport); + out_unlock: current_restore_flags(pflags, PF_MEMALLOC); upper_transport->clnt = NULL; |