summaryrefslogtreecommitdiff
path: root/crypto/zstd.c
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/zstd.c')
-rw-r--r--crypto/zstd.c315
1 files changed, 315 insertions, 0 deletions
diff --git a/crypto/zstd.c b/crypto/zstd.c
new file mode 100644
index 000000000000..cbbd0413751a
--- /dev/null
+++ b/crypto/zstd.c
@@ -0,0 +1,315 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Cryptographic API.
+ *
+ * Copyright (c) 2017-present, Facebook, Inc.
+ */
+#include <linux/crypto.h>
+#include <linux/init.h>
+#include <linux/interrupt.h>
+#include <linux/mm.h>
+#include <linux/module.h>
+#include <linux/net.h>
+#include <linux/overflow.h>
+#include <linux/vmalloc.h>
+#include <linux/zstd.h>
+#include <crypto/internal/acompress.h>
+#include <crypto/scatterwalk.h>
+
+
+#define ZSTD_DEF_LEVEL 3
+#define ZSTD_MAX_WINDOWLOG 18
+#define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG)
+
+struct zstd_ctx {
+ zstd_cctx *cctx;
+ zstd_dctx *dctx;
+ size_t wksp_size;
+ zstd_parameters params;
+ u8 wksp[] __aligned(8) __counted_by(wksp_size);
+};
+
+static DEFINE_MUTEX(zstd_stream_lock);
+
+static void *zstd_alloc_stream(void)
+{
+ zstd_parameters params;
+ struct zstd_ctx *ctx;
+ size_t wksp_size;
+
+ params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
+
+ wksp_size = max(zstd_cstream_workspace_bound(&params.cParams),
+ zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
+ if (!wksp_size)
+ return ERR_PTR(-EINVAL);
+
+ ctx = kvmalloc(struct_size(ctx, wksp, wksp_size), GFP_KERNEL);
+ if (!ctx)
+ return ERR_PTR(-ENOMEM);
+
+ ctx->params = params;
+ ctx->wksp_size = wksp_size;
+
+ return ctx;
+}
+
+static void zstd_free_stream(void *ctx)
+{
+ kvfree(ctx);
+}
+
+static struct crypto_acomp_streams zstd_streams = {
+ .alloc_ctx = zstd_alloc_stream,
+ .free_ctx = zstd_free_stream,
+};
+
+static int zstd_init(struct crypto_acomp *acomp_tfm)
+{
+ int ret = 0;
+
+ mutex_lock(&zstd_stream_lock);
+ ret = crypto_acomp_alloc_streams(&zstd_streams);
+ mutex_unlock(&zstd_stream_lock);
+
+ return ret;
+}
+
+static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
+ const void *src, void *dst, unsigned int *dlen)
+{
+ size_t out_len;
+
+ ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
+ if (!ctx->cctx)
+ return -EINVAL;
+
+ out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
+ &ctx->params);
+ if (zstd_is_error(out_len))
+ return -EINVAL;
+
+ *dlen = out_len;
+
+ return 0;
+}
+
+static int zstd_compress(struct acomp_req *req)
+{
+ struct crypto_acomp_stream *s;
+ unsigned int pos, scur, dcur;
+ unsigned int total_out = 0;
+ bool data_available = true;
+ zstd_out_buffer outbuf;
+ struct acomp_walk walk;
+ zstd_in_buffer inbuf;
+ struct zstd_ctx *ctx;
+ size_t pending_bytes;
+ size_t num_bytes;
+ int ret;
+
+ s = crypto_acomp_lock_stream_bh(&zstd_streams);
+ ctx = s->ctx;
+
+ ret = acomp_walk_virt(&walk, req, true);
+ if (ret)
+ goto out;
+
+ ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
+ if (!ctx->cctx) {
+ ret = -EINVAL;
+ goto out;
+ }
+
+ do {
+ dcur = acomp_walk_next_dst(&walk);
+ if (!dcur) {
+ ret = -ENOSPC;
+ goto out;
+ }
+
+ outbuf.pos = 0;
+ outbuf.dst = (u8 *)walk.dst.virt.addr;
+ outbuf.size = dcur;
+
+ do {
+ scur = acomp_walk_next_src(&walk);
+ if (dcur == req->dlen && scur == req->slen) {
+ ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
+ walk.dst.virt.addr, &total_out);
+ acomp_walk_done_src(&walk, scur);
+ acomp_walk_done_dst(&walk, dcur);
+ goto out;
+ }
+
+ if (scur) {
+ inbuf.pos = 0;
+ inbuf.src = walk.src.virt.addr;
+ inbuf.size = scur;
+ } else {
+ data_available = false;
+ break;
+ }
+
+ num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
+ if (ZSTD_isError(num_bytes)) {
+ ret = -EIO;
+ goto out;
+ }
+
+ pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
+ if (ZSTD_isError(pending_bytes)) {
+ ret = -EIO;
+ goto out;
+ }
+ acomp_walk_done_src(&walk, inbuf.pos);
+ } while (dcur != outbuf.pos);
+
+ total_out += outbuf.pos;
+ acomp_walk_done_dst(&walk, dcur);
+ } while (data_available);
+
+ pos = outbuf.pos;
+ num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
+ if (ZSTD_isError(num_bytes))
+ ret = -EIO;
+ else
+ total_out += (outbuf.pos - pos);
+
+out:
+ if (ret)
+ req->dlen = 0;
+ else
+ req->dlen = total_out;
+
+ crypto_acomp_unlock_stream_bh(s);
+
+ return ret;
+}
+
+static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
+ const void *src, void *dst, unsigned int *dlen)
+{
+ size_t out_len;
+
+ ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
+ if (!ctx->dctx)
+ return -EINVAL;
+
+ out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
+ if (zstd_is_error(out_len))
+ return -EINVAL;
+
+ *dlen = out_len;
+
+ return 0;
+}
+
+static int zstd_decompress(struct acomp_req *req)
+{
+ struct crypto_acomp_stream *s;
+ unsigned int total_out = 0;
+ unsigned int scur, dcur;
+ zstd_out_buffer outbuf;
+ struct acomp_walk walk;
+ zstd_in_buffer inbuf;
+ struct zstd_ctx *ctx;
+ size_t pending_bytes;
+ int ret;
+
+ s = crypto_acomp_lock_stream_bh(&zstd_streams);
+ ctx = s->ctx;
+
+ ret = acomp_walk_virt(&walk, req, true);
+ if (ret)
+ goto out;
+
+ ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
+ if (!ctx->dctx) {
+ ret = -EINVAL;
+ goto out;
+ }
+
+ do {
+ scur = acomp_walk_next_src(&walk);
+ if (scur) {
+ inbuf.pos = 0;
+ inbuf.size = scur;
+ inbuf.src = walk.src.virt.addr;
+ } else {
+ break;
+ }
+
+ do {
+ dcur = acomp_walk_next_dst(&walk);
+ if (dcur == req->dlen && scur == req->slen) {
+ ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
+ walk.dst.virt.addr, &total_out);
+ acomp_walk_done_dst(&walk, dcur);
+ acomp_walk_done_src(&walk, scur);
+ goto out;
+ }
+
+ if (!dcur) {
+ ret = -ENOSPC;
+ goto out;
+ }
+
+ outbuf.pos = 0;
+ outbuf.dst = (u8 *)walk.dst.virt.addr;
+ outbuf.size = dcur;
+
+ pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
+ if (ZSTD_isError(pending_bytes)) {
+ ret = -EIO;
+ goto out;
+ }
+
+ total_out += outbuf.pos;
+
+ acomp_walk_done_dst(&walk, outbuf.pos);
+ } while (inbuf.pos != scur);
+
+ acomp_walk_done_src(&walk, scur);
+ } while (ret == 0);
+
+out:
+ if (ret)
+ req->dlen = 0;
+ else
+ req->dlen = total_out;
+
+ crypto_acomp_unlock_stream_bh(s);
+
+ return ret;
+}
+
+static struct acomp_alg zstd_acomp = {
+ .base = {
+ .cra_name = "zstd",
+ .cra_driver_name = "zstd-generic",
+ .cra_flags = CRYPTO_ALG_REQ_VIRT,
+ .cra_module = THIS_MODULE,
+ },
+ .init = zstd_init,
+ .compress = zstd_compress,
+ .decompress = zstd_decompress,
+};
+
+static int __init zstd_mod_init(void)
+{
+ return crypto_register_acomp(&zstd_acomp);
+}
+
+static void __exit zstd_mod_fini(void)
+{
+ crypto_unregister_acomp(&zstd_acomp);
+ crypto_acomp_free_streams(&zstd_streams);
+}
+
+module_init(zstd_mod_init);
+module_exit(zstd_mod_fini);
+
+MODULE_LICENSE("GPL");
+MODULE_DESCRIPTION("Zstd Compression Algorithm");
+MODULE_ALIAS_CRYPTO("zstd");