diff options
Diffstat (limited to 'crypto/zstd.c')
-rw-r--r-- | crypto/zstd.c | 356 |
1 files changed, 236 insertions, 120 deletions
diff --git a/crypto/zstd.c b/crypto/zstd.c index 7570e11b4ee6..c2a19cb0879d 100644 --- a/crypto/zstd.c +++ b/crypto/zstd.c @@ -12,188 +12,304 @@ #include <linux/net.h> #include <linux/vmalloc.h> #include <linux/zstd.h> -#include <crypto/internal/scompress.h> +#include <crypto/internal/acompress.h> +#include <crypto/scatterwalk.h> -#define ZSTD_DEF_LEVEL 3 +#define ZSTD_DEF_LEVEL 3 +#define ZSTD_MAX_WINDOWLOG 18 +#define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG) struct zstd_ctx { zstd_cctx *cctx; zstd_dctx *dctx; - void *cwksp; - void *dwksp; + size_t wksp_size; + zstd_parameters params; + u8 wksp[] __aligned(8); }; -static zstd_parameters zstd_params(void) -{ - return zstd_get_params(ZSTD_DEF_LEVEL, 0); -} +static DEFINE_MUTEX(zstd_stream_lock); -static int zstd_comp_init(struct zstd_ctx *ctx) +static void *zstd_alloc_stream(void) { - int ret = 0; - const zstd_parameters params = zstd_params(); - const size_t wksp_size = zstd_cctx_workspace_bound(¶ms.cParams); + zstd_parameters params; + struct zstd_ctx *ctx; + size_t wksp_size; - ctx->cwksp = vzalloc(wksp_size); - if (!ctx->cwksp) { - ret = -ENOMEM; - goto out; - } + params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE); - ctx->cctx = zstd_init_cctx(ctx->cwksp, wksp_size); - if (!ctx->cctx) { - ret = -EINVAL; - goto out_free; - } -out: - return ret; -out_free: - vfree(ctx->cwksp); - goto out; + wksp_size = max_t(size_t, + zstd_cstream_workspace_bound(¶ms.cParams), + zstd_dstream_workspace_bound(ZSTD_MAX_SIZE)); + if (!wksp_size) + return ERR_PTR(-EINVAL); + + ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL); + if (!ctx) + return ERR_PTR(-ENOMEM); + + ctx->params = params; + ctx->wksp_size = wksp_size; + + return ctx; } -static int zstd_decomp_init(struct zstd_ctx *ctx) +static void zstd_free_stream(void *ctx) +{ + kvfree(ctx); +} + +static struct crypto_acomp_streams zstd_streams = { + .alloc_ctx = zstd_alloc_stream, + .free_ctx = zstd_free_stream, +}; + +static int zstd_init(struct crypto_acomp *acomp_tfm) { int ret = 0; - const size_t wksp_size = zstd_dctx_workspace_bound(); - ctx->dwksp = vzalloc(wksp_size); - if (!ctx->dwksp) { - ret = -ENOMEM; - goto out; - } + mutex_lock(&zstd_stream_lock); + ret = crypto_acomp_alloc_streams(&zstd_streams); + mutex_unlock(&zstd_stream_lock); - ctx->dctx = zstd_init_dctx(ctx->dwksp, wksp_size); - if (!ctx->dctx) { - ret = -EINVAL; - goto out_free; - } -out: return ret; -out_free: - vfree(ctx->dwksp); - goto out; } -static void zstd_comp_exit(struct zstd_ctx *ctx) +static void zstd_exit(struct crypto_acomp *acomp_tfm) { - vfree(ctx->cwksp); - ctx->cwksp = NULL; - ctx->cctx = NULL; + crypto_acomp_free_streams(&zstd_streams); } -static void zstd_decomp_exit(struct zstd_ctx *ctx) +static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx, + const void *src, void *dst, unsigned int *dlen) { - vfree(ctx->dwksp); - ctx->dwksp = NULL; - ctx->dctx = NULL; -} + unsigned int out_len; -static int __zstd_init(void *ctx) -{ - int ret; + ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size); + if (!ctx->cctx) + return -EINVAL; - ret = zstd_comp_init(ctx); - if (ret) - return ret; - ret = zstd_decomp_init(ctx); - if (ret) - zstd_comp_exit(ctx); - return ret; + out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen, + &ctx->params); + if (zstd_is_error(out_len)) + return -EINVAL; + + *dlen = out_len; + + return 0; } -static void *zstd_alloc_ctx(void) +static int zstd_compress(struct acomp_req *req) { - int ret; + struct crypto_acomp_stream *s; + unsigned int pos, scur, dcur; + unsigned int total_out = 0; + bool data_available = true; + zstd_out_buffer outbuf; + struct acomp_walk walk; + zstd_in_buffer inbuf; struct zstd_ctx *ctx; + size_t pending_bytes; + size_t num_bytes; + int ret; - ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); - if (!ctx) - return ERR_PTR(-ENOMEM); + s = crypto_acomp_lock_stream_bh(&zstd_streams); + ctx = s->ctx; - ret = __zstd_init(ctx); - if (ret) { - kfree(ctx); - return ERR_PTR(ret); + ret = acomp_walk_virt(&walk, req, true); + if (ret) + goto out; + + ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size); + if (!ctx->cctx) { + ret = -EINVAL; + goto out; } - return ctx; -} + do { + dcur = acomp_walk_next_dst(&walk); + if (!dcur) { + ret = -ENOSPC; + goto out; + } -static void __zstd_exit(void *ctx) -{ - zstd_comp_exit(ctx); - zstd_decomp_exit(ctx); -} + outbuf.pos = 0; + outbuf.dst = (u8 *)walk.dst.virt.addr; + outbuf.size = dcur; -static void zstd_free_ctx(void *ctx) -{ - __zstd_exit(ctx); - kfree_sensitive(ctx); -} + do { + scur = acomp_walk_next_src(&walk); + if (dcur == req->dlen && scur == req->slen) { + ret = zstd_compress_one(req, ctx, walk.src.virt.addr, + walk.dst.virt.addr, &total_out); + acomp_walk_done_src(&walk, scur); + acomp_walk_done_dst(&walk, dcur); + goto out; + } -static int __zstd_compress(const u8 *src, unsigned int slen, - u8 *dst, unsigned int *dlen, void *ctx) -{ - size_t out_len; - struct zstd_ctx *zctx = ctx; - const zstd_parameters params = zstd_params(); + if (scur) { + inbuf.pos = 0; + inbuf.src = walk.src.virt.addr; + inbuf.size = scur; + } else { + data_available = false; + break; + } - out_len = zstd_compress_cctx(zctx->cctx, dst, *dlen, src, slen, ¶ms); - if (zstd_is_error(out_len)) - return -EINVAL; - *dlen = out_len; - return 0; -} + num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf); + if (ZSTD_isError(num_bytes)) { + ret = -EIO; + goto out; + } -static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src, - unsigned int slen, u8 *dst, unsigned int *dlen, - void *ctx) -{ - return __zstd_compress(src, slen, dst, dlen, ctx); + pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf); + if (ZSTD_isError(pending_bytes)) { + ret = -EIO; + goto out; + } + acomp_walk_done_src(&walk, inbuf.pos); + } while (dcur != outbuf.pos); + + total_out += outbuf.pos; + acomp_walk_done_dst(&walk, dcur); + } while (data_available); + + pos = outbuf.pos; + num_bytes = zstd_end_stream(ctx->cctx, &outbuf); + if (ZSTD_isError(num_bytes)) + ret = -EIO; + else + total_out += (outbuf.pos - pos); + +out: + if (ret) + req->dlen = 0; + else + req->dlen = total_out; + + crypto_acomp_unlock_stream_bh(s); + + return ret; } -static int __zstd_decompress(const u8 *src, unsigned int slen, - u8 *dst, unsigned int *dlen, void *ctx) +static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx, + const void *src, void *dst, unsigned int *dlen) { size_t out_len; - struct zstd_ctx *zctx = ctx; - out_len = zstd_decompress_dctx(zctx->dctx, dst, *dlen, src, slen); + ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size); + if (!ctx->dctx) + return -EINVAL; + + out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen); if (zstd_is_error(out_len)) return -EINVAL; + *dlen = out_len; + return 0; } -static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src, - unsigned int slen, u8 *dst, unsigned int *dlen, - void *ctx) +static int zstd_decompress(struct acomp_req *req) { - return __zstd_decompress(src, slen, dst, dlen, ctx); -} + struct crypto_acomp_stream *s; + unsigned int total_out = 0; + unsigned int scur, dcur; + zstd_out_buffer outbuf; + struct acomp_walk walk; + zstd_in_buffer inbuf; + struct zstd_ctx *ctx; + size_t pending_bytes; + int ret; -static struct scomp_alg scomp = { - .alloc_ctx = zstd_alloc_ctx, - .free_ctx = zstd_free_ctx, - .compress = zstd_scompress, - .decompress = zstd_sdecompress, - .base = { - .cra_name = "zstd", - .cra_driver_name = "zstd-scomp", - .cra_module = THIS_MODULE, + s = crypto_acomp_lock_stream_bh(&zstd_streams); + ctx = s->ctx; + + ret = acomp_walk_virt(&walk, req, true); + if (ret) + goto out; + + ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size); + if (!ctx->dctx) { + ret = -EINVAL; + goto out; } + + do { + scur = acomp_walk_next_src(&walk); + if (scur) { + inbuf.pos = 0; + inbuf.size = scur; + inbuf.src = walk.src.virt.addr; + } else { + break; + } + + do { + dcur = acomp_walk_next_dst(&walk); + if (dcur == req->dlen && scur == req->slen) { + ret = zstd_decompress_one(req, ctx, walk.src.virt.addr, + walk.dst.virt.addr, &total_out); + acomp_walk_done_dst(&walk, dcur); + acomp_walk_done_src(&walk, scur); + goto out; + } + + if (!dcur) { + ret = -ENOSPC; + goto out; + } + + outbuf.pos = 0; + outbuf.dst = (u8 *)walk.dst.virt.addr; + outbuf.size = dcur; + + pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf); + if (ZSTD_isError(pending_bytes)) { + ret = -EIO; + goto out; + } + + total_out += outbuf.pos; + + acomp_walk_done_dst(&walk, outbuf.pos); + } while (inbuf.pos != scur); + + acomp_walk_done_src(&walk, scur); + } while (ret == 0); + +out: + if (ret) + req->dlen = 0; + else + req->dlen = total_out; + + crypto_acomp_unlock_stream_bh(s); + + return ret; +} + +static struct acomp_alg zstd_acomp = { + .base = { + .cra_name = "zstd", + .cra_driver_name = "zstd-generic", + .cra_flags = CRYPTO_ALG_REQ_VIRT, + .cra_module = THIS_MODULE, + }, + .init = zstd_init, + .exit = zstd_exit, + .compress = zstd_compress, + .decompress = zstd_decompress, }; static int __init zstd_mod_init(void) { - return crypto_register_scomp(&scomp); + return crypto_register_acomp(&zstd_acomp); } static void __exit zstd_mod_fini(void) { - crypto_unregister_scomp(&scomp); + crypto_unregister_acomp(&zstd_acomp); } module_init(zstd_mod_init); |