BACKPORT: crypto: arm64/aes-ce-cipher - move assembler code to .S file
author Ard Biesheuvel <ard.biesheuvel@linaro.org>
Tue, 21 Nov 2017 13:40:17 +0000 (13:40 +0000)
committer Bruno Martins <bgcngm@gmail.com>
Sun, 22 Oct 2023 14:12:33 +0000 (15:12 +0100)
commit 019cd46984d04703a39924178f503a98436ac0d7 upstream.

Most crypto drivers involving kernel mode NEON take care to put the code
that actually touches the NEON register file in a separate compilation
unit, to prevent the compiler from reordering code that preserves or
restores the NEON context with code that may corrupt it. This is
necessary because we currently have no way to express the restrictions
imposed upon use of the NEON in kernel mode in a way that the compiler
understands.

However, in the case of aes-ce-cipher, it did not seem unreasonable to
deviate from this rule, given how it does not seem possible for the
compiler to reorder cross object function calls with asm blocks whose
in- and output constraints reflect that it reads from and writes to
memory.

Now that LTO is being proposed for the arm64 kernel, it is time to
revisit this. The link time optimization may replace the function
calls to kernel_neon_begin() and kernel_neon_end() with instantiations
of the IR that make up its implementation, allowing further reordering
with the asm block.

So let's clean this up, and move the asm() blocks into a separate .S
file.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Reviewed-By: Nick Desaulniers <ndesaulniers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
Cc: Matthias Kaehlcke <mka@google.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
[ bgcngm: Backported to 4.9 ]
Signed-off-by: Bruno Martins <bgcngm@gmail.com>
Change-Id: I7b271d83a3d5baf81aa5fca69cd4f2817945a477

arch/arm64/crypto/Makefile
arch/arm64/crypto/aes-ce-cipher-core.c [deleted file]
arch/arm64/crypto/aes-ce-cipher-glue.c [deleted file]
arch/arm64/crypto/aes-ce-core.S [new file with mode: 0644]
arch/arm64/crypto/aes-ce-glue.c [new file with mode: 0644]

index 550e02a5aa32941571ddc5a4c558eef16cae32c8..55e19e3dcb1586f1235ca8f7e6cb27455079c4e1 100644 (file)
@@ -18,8 +18,7 @@ obj-$(CONFIG_CRYPTO_GHASH_ARM64_CE) += ghash-ce.o
 ghash-ce-y := ghash-ce-glue.o ghash-ce-core.o
 
 obj-$(CONFIG_CRYPTO_AES_ARM64_CE) += aes-ce-cipher.o
-aes-ce-cipher-y := aes-ce-cipher-glue.o aes-ce-cipher-core.o
-CFLAGS_aes-ce-cipher-core.o += -march=armv8-a+crypto -Wa,-march=armv8-a+crypto $(DISABLE_LTO)
+aes-ce-cipher-y := aes-ce-core.o aes-ce-glue.o
 
 obj-$(CONFIG_CRYPTO_AES_ARM64_CE_CCM) += aes-ce-ccm.o
 aes-ce-ccm-y := aes-ce-ccm-glue.o aes-ce-ccm-core.o
diff --git a/arch/arm64/crypto/aes-ce-cipher-core.c b/arch/arm64/crypto/aes-ce-cipher-core.c
deleted file mode 100644 (file)
index 948a244..0000000
+++ /dev/null
@@ -1,216 +0,0 @@
-/*
- * aes-ce-cipher-core.c - core AES cipher using ARMv8 Crypto Extensions
- *
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 2 as
- * published by the Free Software Foundation.
- */
-
-#include <asm/neon.h>
-#include <asm/unaligned.h>
-#include <crypto/aes.h>
-#include <linux/crypto.h>
-
-#include "aes-ce-setkey.h"
-
-struct aes_block {
-       u8 b[AES_BLOCK_SIZE];
-};
-
-static int num_rounds(struct crypto_aes_ctx *ctx)
-{
-       /*
-        * # of rounds specified by AES:
-        * 128 bit key          10 rounds
-        * 192 bit key          12 rounds
-        * 256 bit key          14 rounds
-        * => n byte key        => 6 + (n/4) rounds
-        */
-       return 6 + ctx->key_length / 4;
-}
-
-void aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
-{
-       struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
-       struct aes_block *out = (struct aes_block *)dst;
-       struct aes_block const *in = (struct aes_block *)src;
-       void *dummy0;
-       int dummy1;
-
-       kernel_neon_begin_partial(4);
-
-       __asm__("       ld1     {v0.16b}, %[in]                 ;"
-               "       ld1     {v1.4s}, [%[key]], #16          ;"
-               "       cmp     %w[rounds], #10                 ;"
-               "       bmi     0f                              ;"
-               "       bne     3f                              ;"
-               "       mov     v3.16b, v1.16b                  ;"
-               "       b       2f                              ;"
-               "0:     mov     v2.16b, v1.16b                  ;"
-               "       ld1     {v3.4s}, [%[key]], #16          ;"
-               "1:     aese    v0.16b, v2.16b                  ;"
-               "       aesmc   v0.16b, v0.16b                  ;"
-               "2:     ld1     {v1.4s}, [%[key]], #16          ;"
-               "       aese    v0.16b, v3.16b                  ;"
-               "       aesmc   v0.16b, v0.16b                  ;"
-               "3:     ld1     {v2.4s}, [%[key]], #16          ;"
-               "       subs    %w[rounds], %w[rounds], #3      ;"
-               "       aese    v0.16b, v1.16b                  ;"
-               "       aesmc   v0.16b, v0.16b                  ;"
-               "       ld1     {v3.4s}, [%[key]], #16          ;"
-               "       bpl     1b                              ;"
-               "       aese    v0.16b, v2.16b                  ;"
-               "       eor     v0.16b, v0.16b, v3.16b          ;"
-               "       st1     {v0.16b}, %[out]                ;"
-
-       :       [out]           "=Q"(*out),
-               [key]           "=r"(dummy0),
-               [rounds]        "=r"(dummy1)
-       :       [in]            "Q"(*in),
-                               "1"(ctx->key_enc),
-                               "2"(num_rounds(ctx) - 2)
-       :       "cc");
-
-       kernel_neon_end();
-}
-
-void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
-{
-       struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
-       struct aes_block *out = (struct aes_block *)dst;
-       struct aes_block const *in = (struct aes_block *)src;
-       void *dummy0;
-       int dummy1;
-
-       kernel_neon_begin_partial(4);
-
-       __asm__("       ld1     {v0.16b}, %[in]                 ;"
-               "       ld1     {v1.4s}, [%[key]], #16          ;"
-               "       cmp     %w[rounds], #10                 ;"
-               "       bmi     0f                              ;"
-               "       bne     3f                              ;"
-               "       mov     v3.16b, v1.16b                  ;"
-               "       b       2f                              ;"
-               "0:     mov     v2.16b, v1.16b                  ;"
-               "       ld1     {v3.4s}, [%[key]], #16          ;"
-               "1:     aesd    v0.16b, v2.16b                  ;"
-               "       aesimc  v0.16b, v0.16b                  ;"
-               "2:     ld1     {v1.4s}, [%[key]], #16          ;"
-               "       aesd    v0.16b, v3.16b                  ;"
-               "       aesimc  v0.16b, v0.16b                  ;"
-               "3:     ld1     {v2.4s}, [%[key]], #16          ;"
-               "       subs    %w[rounds], %w[rounds], #3      ;"
-               "       aesd    v0.16b, v1.16b                  ;"
-               "       aesimc  v0.16b, v0.16b                  ;"
-               "       ld1     {v3.4s}, [%[key]], #16          ;"
-               "       bpl     1b                              ;"
-               "       aesd    v0.16b, v2.16b                  ;"
-               "       eor     v0.16b, v0.16b, v3.16b          ;"
-               "       st1     {v0.16b}, %[out]                ;"
-
-       :       [out]           "=Q"(*out),
-               [key]           "=r"(dummy0),
-               [rounds]        "=r"(dummy1)
-       :       [in]            "Q"(*in),
-                               "1"(ctx->key_dec),
-                               "2"(num_rounds(ctx) - 2)
-       :       "cc");
-
-       kernel_neon_end();
-}
-
-/*
- * aes_sub() - use the aese instruction to perform the AES sbox substitution
- *             on each byte in 'input'
- */
-static u32 aes_sub(u32 input)
-{
-       u32 ret;
-
-       __asm__("dup    v1.4s, %w[in]           ;"
-               "movi   v0.16b, #0              ;"
-               "aese   v0.16b, v1.16b          ;"
-               "umov   %w[out], v0.4s[0]       ;"
-
-       :       [out]   "=r"(ret)
-       :       [in]    "r"(input)
-       :               "v0","v1");
-
-       return ret;
-}
-
-int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
-                    unsigned int key_len)
-{
-       /*
-        * The AES key schedule round constants
-        */
-       static u8 const rcon[] = {
-               0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
-       };
-
-       u32 kwords = key_len / sizeof(u32);
-       struct aes_block *key_enc, *key_dec;
-       int i, j;
-
-       if (key_len != AES_KEYSIZE_128 &&
-           key_len != AES_KEYSIZE_192 &&
-           key_len != AES_KEYSIZE_256)
-               return -EINVAL;
-
-       ctx->key_length = key_len;
-       for (i = 0; i < kwords; i++)
-               ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
-
-       kernel_neon_begin_partial(2);
-       for (i = 0; i < sizeof(rcon); i++) {
-               u32 *rki = ctx->key_enc + (i * kwords);
-               u32 *rko = rki + kwords;
-
-               rko[0] = ror32(aes_sub(rki[kwords - 1]), 8) ^ rcon[i] ^ rki[0];
-               rko[1] = rko[0] ^ rki[1];
-               rko[2] = rko[1] ^ rki[2];
-               rko[3] = rko[2] ^ rki[3];
-
-               if (key_len == AES_KEYSIZE_192) {
-                       if (i >= 7)
-                               break;
-                       rko[4] = rko[3] ^ rki[4];
-                       rko[5] = rko[4] ^ rki[5];
-               } else if (key_len == AES_KEYSIZE_256) {
-                       if (i >= 6)
-                               break;
-                       rko[4] = aes_sub(rko[3]) ^ rki[4];
-                       rko[5] = rko[4] ^ rki[5];
-                       rko[6] = rko[5] ^ rki[6];
-                       rko[7] = rko[6] ^ rki[7];
-               }
-       }
-
-       /*
-        * Generate the decryption keys for the Equivalent Inverse Cipher.
-        * This involves reversing the order of the round keys, and applying
-        * the Inverse Mix Columns transformation on all but the first and
-        * the last one.
-        */
-       key_enc = (struct aes_block *)ctx->key_enc;
-       key_dec = (struct aes_block *)ctx->key_dec;
-       j = num_rounds(ctx);
-
-       key_dec[0] = key_enc[j];
-       for (i = 1, j--; j > 0; i++, j--)
-               __asm__("ld1    {v0.4s}, %[in]          ;"
-                       "aesimc v1.16b, v0.16b          ;"
-                       "st1    {v1.4s}, %[out] ;"
-
-               :       [out]   "=Q"(key_dec[i])
-               :       [in]    "Q"(key_enc[j])
-               :               "v0","v1");
-       key_dec[i] = key_enc[0];
-
-       kernel_neon_end();
-       return 0;
-}
-EXPORT_SYMBOL(ce_aes_expandkey);
diff --git a/arch/arm64/crypto/aes-ce-cipher-glue.c b/arch/arm64/crypto/aes-ce-cipher-glue.c
deleted file mode 100644 (file)
index 442949e..0000000
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * aes-ce-cipher.c - core AES cipher using ARMv8 Crypto Extensions
- *
- * Copyright (C) 2013 - 2014 Linaro Ltd <ard.biesheuvel@linaro.org>
- *
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 2 as
- * published by the Free Software Foundation.
- */
-
-#include <crypto/aes.h>
-#include <linux/cpufeature.h>
-#include <linux/crypto.h>
-#include <linux/module.h>
-
-#include "aes-ce-setkey.h"
-
-MODULE_DESCRIPTION("Synchronous AES cipher using ARMv8 Crypto Extensions");
-MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
-MODULE_LICENSE("GPL v2");
-
-extern void aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[]);
-extern void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[]);
-
-#ifdef CONFIG_CFI_CLANG
-static inline void __cfi_aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
-{
-       aes_cipher_encrypt(tfm, dst, src);
-}
-
-static inline void __cfi_aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
-{
-       aes_cipher_decrypt(tfm, dst, src);
-}
-
-#define aes_cipher_encrypt __cfi_aes_cipher_encrypt
-#define aes_cipher_decrypt __cfi_aes_cipher_decrypt
-#endif
-
-int ce_aes_setkey(struct crypto_tfm *tfm, const u8 *in_key,
-                 unsigned int key_len)
-{
-       struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
-       int ret;
-
-       ret = ce_aes_expandkey(ctx, in_key, key_len);
-       if (!ret)
-               return 0;
-
-       tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
-       return -EINVAL;
-}
-EXPORT_SYMBOL(ce_aes_setkey);
-
-static struct crypto_alg aes_alg = {
-       .cra_name               = "aes",
-       .cra_driver_name        = "aes-ce",
-       .cra_priority           = 250,
-       .cra_flags              = CRYPTO_ALG_TYPE_CIPHER,
-       .cra_blocksize          = AES_BLOCK_SIZE,
-       .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
-       .cra_module             = THIS_MODULE,
-       .cra_cipher = {
-               .cia_min_keysize        = AES_MIN_KEY_SIZE,
-               .cia_max_keysize        = AES_MAX_KEY_SIZE,
-               .cia_setkey             = ce_aes_setkey,
-               .cia_encrypt            = aes_cipher_encrypt,
-               .cia_decrypt            = aes_cipher_decrypt
-       }
-};
-
-static int __init aes_mod_init(void)
-{
-       return crypto_register_alg(&aes_alg);
-}
-
-static void __exit aes_mod_exit(void)
-{
-       crypto_unregister_alg(&aes_alg);
-}
-
-module_cpu_feature_match(AES, aes_mod_init);
-module_exit(aes_mod_exit);
diff --git a/arch/arm64/crypto/aes-ce-core.S b/arch/arm64/crypto/aes-ce-core.S
new file mode 100644 (file)
index 0000000..8efdfda
--- /dev/null
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 as
+ * published by the Free Software Foundation.
+ */
+
+#include <linux/linkage.h>
+#include <asm/assembler.h>
+
+       .arch           armv8-a+crypto
+
+ENTRY(__aes_ce_encrypt)                /* x0=rk, x1=out, x2=in, w3=rounds */
+       sub             w3, w3, #2              /* w3 = rounds - 2 */
+       ld1             {v0.16b}, [x2]          /* v0 = input block */
+       ld1             {v1.4s}, [x0], #16      /* v1 = first round key */
+       cmp             w3, #10                 /* choose loop entry point */
+       bmi             0f                      /* w3 < 10 */
+       bne             3f                      /* w3 > 10 */
+       mov             v3.16b, v1.16b          /* w3 == 10 */
+       b               2f
+0:     mov             v2.16b, v1.16b
+       ld1             {v3.4s}, [x0], #16
+1:     aese            v0.16b, v2.16b          /* loop: 3 rounds per pass */
+       aesmc           v0.16b, v0.16b
+2:     ld1             {v1.4s}, [x0], #16      /* fetch next round key */
+       aese            v0.16b, v3.16b
+       aesmc           v0.16b, v0.16b
+3:     ld1             {v2.4s}, [x0], #16
+       subs            w3, w3, #3              /* sets flags for bpl below */
+       aese            v0.16b, v1.16b
+       aesmc           v0.16b, v0.16b
+       ld1             {v3.4s}, [x0], #16
+       bpl             1b                      /* repeat while w3 >= 0 */
+       aese            v0.16b, v2.16b          /* final round: no aesmc */
+       eor             v0.16b, v0.16b, v3.16b  /* xor in last round key */
+       st1             {v0.16b}, [x1]          /* write result to out */
+       ret
+ENDPROC(__aes_ce_encrypt)
+
+ENTRY(__aes_ce_decrypt)                /* x0=rk, x1=out, x2=in, w3=rounds */
+       sub             w3, w3, #2              /* w3 = rounds - 2 */
+       ld1             {v0.16b}, [x2]          /* v0 = input block */
+       ld1             {v1.4s}, [x0], #16      /* v1 = first round key */
+       cmp             w3, #10                 /* choose loop entry point */
+       bmi             0f                      /* w3 < 10 */
+       bne             3f                      /* w3 > 10 */
+       mov             v3.16b, v1.16b          /* w3 == 10 */
+       b               2f
+0:     mov             v2.16b, v1.16b
+       ld1             {v3.4s}, [x0], #16
+1:     aesd            v0.16b, v2.16b          /* loop: 3 rounds per pass */
+       aesimc          v0.16b, v0.16b
+2:     ld1             {v1.4s}, [x0], #16      /* fetch next round key */
+       aesd            v0.16b, v3.16b
+       aesimc          v0.16b, v0.16b
+3:     ld1             {v2.4s}, [x0], #16
+       subs            w3, w3, #3              /* sets flags for bpl below */
+       aesd            v0.16b, v1.16b
+       aesimc          v0.16b, v0.16b
+       ld1             {v3.4s}, [x0], #16
+       bpl             1b                      /* repeat while w3 >= 0 */
+       aesd            v0.16b, v2.16b          /* final round: no aesimc */
+       eor             v0.16b, v0.16b, v3.16b  /* xor in last round key */
+       st1             {v0.16b}, [x1]          /* write result to out */
+       ret
+ENDPROC(__aes_ce_decrypt)
+
+/*
+ * __aes_ce_sub() - use the aese instruction to perform the AES sbox
+ *                  substitution on each byte of the 32-bit word in w0
+ */
+ENTRY(__aes_ce_sub)
+       dup             v1.4s, w0               /* input word in all 4 lanes */
+       movi            v0.16b, #0              /* v0 = 0 */
+       aese            v0.16b, v1.16b
+       umov            w0, v0.s[0]             /* return lane 0 in w0 */
+       ret
+ENDPROC(__aes_ce_sub)
+
+ENTRY(__aes_ce_invert)                 /* x0 = out, x1 = in */
+       ld1             {v0.4s}, [x1]           /* load 16 bytes from in */
+       aesimc          v1.16b, v0.16b          /* Inverse Mix Columns */
+       st1             {v1.4s}, [x0]           /* store 16 bytes to out */
+       ret
+ENDPROC(__aes_ce_invert)
diff --git a/arch/arm64/crypto/aes-ce-glue.c b/arch/arm64/crypto/aes-ce-glue.c
new file mode 100644 (file)
index 0000000..01837b6
--- /dev/null
@@ -0,0 +1,194 @@
+/*
+ * aes-ce-glue.c - core AES cipher using ARMv8 Crypto Extensions
+ *
+ * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 as
+ * published by the Free Software Foundation.
+ */
+
+#include <asm/neon.h>
+#include <asm/unaligned.h>
+#include <crypto/aes.h>
+#include <linux/cpufeature.h>
+#include <linux/crypto.h>
+#include <linux/module.h>
+
+#include "aes-ce-setkey.h"
+
+MODULE_DESCRIPTION("Synchronous AES cipher using ARMv8 Crypto Extensions");
+MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
+MODULE_LICENSE("GPL v2");
+
+struct aes_block {
+       u8 b[AES_BLOCK_SIZE];   /* raw bytes of one block-sized round key */
+};
+
+asmlinkage void __aes_ce_encrypt(u32 *rk, u8 *out, const u8 *in, int rounds);
+asmlinkage void __aes_ce_decrypt(u32 *rk, u8 *out, const u8 *in, int rounds);
+
+asmlinkage u32 __aes_ce_sub(u32 l);     /* sbox substitution of each byte */
+asmlinkage void __aes_ce_invert(struct aes_block *out,
+                               const struct aes_block *in);
+
+static int num_rounds(struct crypto_aes_ctx *ctx)
+{
+       /*
+        * AES uses 10, 12 or 14 rounds for 128, 192 or 256 bit keys
+        * respectively: six rounds plus one for each 32-bit word (four
+        * bytes) of key material.
+        */
+       u32 key_words = ctx->key_length / 4;
+
+       return 6 + key_words;
+}
+
+extern void aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[]);
+extern void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[]);
+
+/*
+ * No CONFIG_CFI_CLANG wrappers: the functions below are defined in this
+ * file, so #define-ing their names to static __cfi_*() wrappers would rename
+ * these definitions and collide with the wrappers themselves.
+ */
+
+/*
+ * aes_cipher_encrypt() - encrypt the single AES block at 'src' into 'dst'
+ *
+ * kernel_neon_begin_partial(4) matches v0-v3 used by __aes_ce_encrypt().
+ */
+void aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
+{
+       struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
+
+       kernel_neon_begin_partial(4);
+       __aes_ce_encrypt(ctx->key_enc, dst, src, num_rounds(ctx));
+       kernel_neon_end();
+}
+
+/*
+ * aes_cipher_decrypt() - decrypt the single AES block at 'src' into 'dst'
+ *                        with the decryption key schedule (ctx->key_dec)
+ */
+void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
+{
+       struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
+
+       kernel_neon_begin_partial(4);
+       __aes_ce_decrypt(ctx->key_dec, dst, src, num_rounds(ctx));
+       kernel_neon_end();
+}
+
+int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
+                    unsigned int key_len)
+{
+       /*
+        * The AES key schedule round constants
+        */
+       static u8 const rcon[] = {
+               0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
+       };
+
+       u32 kwords = key_len / sizeof(u32);     /* key size in 32-bit words */
+       struct aes_block *key_enc, *key_dec;
+       int i, j;
+
+       if (key_len != AES_KEYSIZE_128 &&
+           key_len != AES_KEYSIZE_192 &&
+           key_len != AES_KEYSIZE_256)
+               return -EINVAL;
+
+       ctx->key_length = key_len;
+       for (i = 0; i < kwords; i++)    /* schedule starts with the raw key */
+               ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));
+
+       kernel_neon_begin_partial(2);   /* the asm helpers use v0 and v1 */
+       for (i = 0; i < sizeof(rcon); i++) {
+               u32 *rki = ctx->key_enc + (i * kwords); /* previous words */
+               u32 *rko = rki + kwords;                /* words made now */
+
+               rko[0] = ror32(__aes_ce_sub(rki[kwords - 1]), 8) ^ rcon[i] ^ rki[0];
+               rko[1] = rko[0] ^ rki[1];
+               rko[2] = rko[1] ^ rki[2];
+               rko[3] = rko[2] ^ rki[3];
+
+               if (key_len == AES_KEYSIZE_192) {
+                       if (i >= 7)     /* all 13 round keys are complete */
+                               break;
+                       rko[4] = rko[3] ^ rki[4];
+                       rko[5] = rko[4] ^ rki[5];
+               } else if (key_len == AES_KEYSIZE_256) {
+                       if (i >= 6)     /* all 15 round keys are complete */
+                               break;
+                       rko[4] = __aes_ce_sub(rko[3]) ^ rki[4];
+                       rko[5] = rko[4] ^ rki[5];
+                       rko[6] = rko[5] ^ rki[6];
+                       rko[7] = rko[6] ^ rki[7];
+               }
+       }
+
+       /*
+        * Generate the decryption keys for the Equivalent Inverse Cipher.
+        * This involves reversing the order of the round keys, and applying
+        * the Inverse Mix Columns transformation on all but the first and
+        * the last one.
+        */
+       key_enc = (struct aes_block *)ctx->key_enc;
+       key_dec = (struct aes_block *)ctx->key_dec;
+       j = num_rounds(ctx);            /* index of the last round key */
+
+       key_dec[0] = key_enc[j];        /* copied unchanged */
+       for (i = 1, j--; j > 0; i++, j--)
+               __aes_ce_invert(key_dec + i, key_enc + j);
+       key_dec[i] = key_enc[0];        /* copied unchanged */
+
+       kernel_neon_end();
+       return 0;
+}
+EXPORT_SYMBOL(ce_aes_expandkey);
+
+int ce_aes_setkey(struct crypto_tfm *tfm, const u8 *in_key,
+                 unsigned int key_len)
+{
+       /*
+        * Expand straight into the tfm context.  Any failure is reported
+        * to the caller as a bad key length.
+        */
+       if (ce_aes_expandkey(crypto_tfm_ctx(tfm), in_key, key_len)) {
+               tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
+               return -EINVAL;
+       }
+       return 0;
+}
+EXPORT_SYMBOL(ce_aes_setkey);
+
+static struct crypto_alg aes_alg = {
+       .cra_name               = "aes",
+       .cra_driver_name        = "aes-ce",
+       .cra_priority           = 250,
+       .cra_flags              = CRYPTO_ALG_TYPE_CIPHER,
+       .cra_blocksize          = AES_BLOCK_SIZE,
+       .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
+       .cra_module             = THIS_MODULE,
+       .cra_cipher = {
+               .cia_min_keysize        = AES_MIN_KEY_SIZE,
+               .cia_max_keysize        = AES_MAX_KEY_SIZE,
+               .cia_setkey             = ce_aes_setkey,        /* expands key into ctx */
+               .cia_encrypt            = aes_cipher_encrypt,   /* one block per call */
+               .cia_decrypt            = aes_cipher_decrypt    /* one block per call */
+       }
+};
+
+static int __init aes_mod_init(void)
+{
+       return crypto_register_alg(&aes_alg);   /* result goes to the caller */
+}
+
+static void __exit aes_mod_exit(void)
+{
+       crypto_unregister_alg(&aes_alg);        /* undoes aes_mod_init() */
+}
+
+module_cpu_feature_match(AES, aes_mod_init);   /* presumably AES-feature gated - TODO confirm */
+module_exit(aes_mod_exit);