// AES-128 uses 10 rounds plus the initial AddRoundKey, hence 11 round keys.
41 static constexpr int kRoundKeys = 11;
// Size in bytes of one expanded key schedule (11 round keys x 16 bytes).
42 static constexpr int kRoundKeySize = kRoundKeys * 16;
// One expanded key schedule per output lane (mul of them). The 16-byte
// alignment (and a row size that is a multiple of 16) is what allows the
// key expansion and the encryption loop to use aligned SSE stores/loads
// (_mm_store_si128 / _mm_load_si128) on every round key.
43 alignas(16) uint8_t round_keys_[mul][kRoundKeySize];
// One AES-128 key-schedule step (FIPS-197 Key Expansion), done on SSE lanes.
//
// key : previous round key, viewed as four 32-bit words w0..w3 (lane 0 = w0).
// kg  : output of _mm_aeskeygenassist_si128(key, rcon). Only its lane 3 is
//       used; per the intrinsic's definition that lane holds
//       RotWord(SubWord(w3)) ^ rcon.
// Returns the next round key, whose lane i is w0 ^ ... ^ wi ^ t, where t is
// lane 3 of kg.
static __m128i KeyExpStep(__m128i key, __m128i kg) {
  // Prefix XOR across the four words. One pass computes a ^= (a << 32 bits),
  // i.e. multiplies by (1 + x) over GF(2). Three passes give
  // (1 + x)^3 = 1 + x + x^2 + x^3, so lane i ends up as w0 ^ w1 ^ ... ^ wi.
  __m128i acc = key;
  for (int pass = 0; pass < 3; ++pass) {
    acc = _mm_xor_si128(acc, _mm_slli_si128(acc, 4));
  }

  // Broadcast lane 3 of the assist value to all four lanes and mix it in.
  const __m128i t = _mm_shuffle_epi32(kg, _MM_SHUFFLE(3, 3, 3, 3));
  return _mm_xor_si128(acc, t);
}
// Expands a 128-bit AES key into the full AES-128 encryption key schedule.
//
// key      : 16 raw key bytes. Read once, with an unaligned load, so any
//            alignment is accepted.
// rk_bytes : destination for 11 round keys of 16 bytes each (176 bytes),
//            round key 0 first. Written with aligned stores, so it MUST be
//            16-byte aligned; the alignas(16) round_keys_ rows satisfy this.
// Requires AES-NI (_mm_aeskeygenassist_si128) on the host.
static void ExpandKey(const uint8_t *key, uint8_t *rk_bytes) {
  // Destination viewed as 16-byte slots; slot r receives round key r.
  __m128i *const slot = reinterpret_cast<__m128i *>(rk_bytes);

  // Round key 0 is the cipher key itself.
  __m128i cur = _mm_loadu_si128(reinterpret_cast<const __m128i *>(key));
  _mm_store_si128(slot + 0, cur);

  // Round keys 1..10. _mm_aeskeygenassist_si128 needs its rcon argument as a
  // compile-time immediate, so the sequence is spelled out rather than looped
  // over a table. The rcon values 0x01..0x80, 0x1b, 0x36 are the successive
  // powers of x in GF(2^8).
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x01));
  _mm_store_si128(slot + 1, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x02));
  _mm_store_si128(slot + 2, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x04));
  _mm_store_si128(slot + 3, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x08));
  _mm_store_si128(slot + 4, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x10));
  _mm_store_si128(slot + 5, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x20));
  _mm_store_si128(slot + 6, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x40));
  _mm_store_si128(slot + 7, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x80));
  _mm_store_si128(slot + 8, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x1b));
  _mm_store_si128(slot + 9, cur);
  cur = KeyExpStep(cur, _mm_aeskeygenassist_si128(cur, 0x36));
  _mm_store_si128(slot + 10, cur);
}
// Build one independent AES-128 key schedule per lane: round_keys_[i] is
// expanded from keys[i]. NOTE(review): keys[i] presumably points to 16 raw
// key bytes; confirm against the constructor signature, which is outside
// this view.
80 for (
int i = 0; i < mul; ++i) ExpandKey(keys[i], round_keys_[i]);
// Produces `mul` pseudo-random 128-bit blocks from one 128-bit seed using
// AES-128 in Matyas-Meyer-Oseas form, one key per lane:
//   out[i] = AES128_Enc(round_keys_[i], seed) XOR seed.
// Only the host path (AES-NI) computes anything; device execution is
// rejected by the assert below.
// NOTE(review): the preprocessor guards that select the device and host
// paths, and the final return, are outside this view. Confirm what the device
// branch returns. With NDEBUG the assert compiles out, so device callers
// would silently receive the zero-initialized `out`.
84 __host__ __device__ cuda::std::array<int4, mul> Gen(int4 seed) {
// Value-initialized: every lane starts as all zeros.
85 cuda::std::array<int4, mul> out{};
// Device side: not supported.
88 assert(
false &&
"Aes128MmoRaw is not supported on device side");
// Host side: load the 16-byte seed (unaligned load from the by-value int4).
91 __m128i s = _mm_loadu_si128(
reinterpret_cast<const __m128i *
>(&seed));
92 for (
int i = 0; i < mul; ++i) {
// Lane i's 11 round keys are consecutive 16-byte blocks in round_keys_[i].
// Aligned loads are valid because that member is declared alignas(16) with
// a row size that is a multiple of 16.
93 const auto *rk =
reinterpret_cast<const __m128i *
>(round_keys_[i]);
// Initial AddRoundKey: whiten the seed with round key 0.
94 __m128i b = _mm_xor_si128(s, _mm_load_si128(rk));
// Rounds 1-9: full AES rounds.
95 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 1));
96 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 2));
97 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 3));
98 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 4));
99 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 5));
100 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 6));
101 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 7));
102 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 8));
103 b = _mm_aesenc_si128(b, _mm_load_si128(rk + 9));
// Round 10: the final round, which aesenclast performs without MixColumns.
104 b = _mm_aesenclast_si128(b, _mm_load_si128(rk + 10));
// Matyas-Meyer-Oseas feed-forward: XOR the input (seed) back in.
105 b = _mm_xor_si128(b, s);
// Write lane i's 16-byte result into out[i].
106 _mm_storeu_si128(
reinterpret_cast<__m128i *
>(&out[i]), b);