102 __host__ __device__
int Gen(
Cw cws[], cuda::std::array<int4, 4> &cs, int4 &ocw,
103 cuda::std::span<const int4, 2> s0s, In a, int4 b_buf) {
105 s0 = util::SetLsb(s0,
false);
107 s1 = util::SetLsb(s1,
false);
110 b_buf = util::SetLsb(b_buf,
false);
112 for (
int i = 0; i < in_bits; ++i) {
113 auto [s0l, s0r] = prg.Gen(s0);
114 auto [s1l, s1r] = prg.Gen(s1);
116 bool t0l = util::GetLsb(s0l);
117 s0l = util::SetLsb(s0l,
false);
118 bool t0r = util::GetLsb(s0r);
119 s0r = util::SetLsb(s0r,
false);
120 bool t1l = util::GetLsb(s1l);
121 s1l = util::SetLsb(s1l,
false);
122 bool t1r = util::GetLsb(s1r);
123 s1r = util::SetLsb(s1r,
false);
125 bool a_bit = (a >> (in_bits - 1 - i)) & 1;
128 if (!a_bit) s_cw = util::Xor(s0r, s1r);
129 else s_cw = util::Xor(s0l, s1l);
131 bool tl_cw = t0l ^ t1l ^ a_bit ^ 1;
132 bool tr_cw = t0r ^ t1r ^ a_bit;
136 if (t0) s0 = util::Xor(s0, s_cw);
138 if (t1) s1 = util::Xor(s1, s_cw);
140 if (t0) t0 = t0l ^ tl_cw;
142 if (t1) t1 = t1l ^ tl_cw;
146 if (t0) s0 = util::Xor(s0, s_cw);
148 if (t1) s1 = util::Xor(s1, s_cw);
150 if (t0) t0 = t0r ^ tr_cw;
152 if (t1) t1 = t1r ^ tr_cw;
156 s_cw = util::SetLsb(s_cw, tl_cw);
157 cws[i] = {s_cw, tr_cw};
161 int4 a_buf = util::Pack(a);
163 auto pi_tilde_0 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, s0});
164 auto pi_tilde_1 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, s1});
166 cuda::std::span<const int4, 4>(pi_tilde_0), cuda::std::span<const int4, 4>(pi_tilde_1));
169 if (t0 == t1)
return 1;
172 auto v_cw = Group::From(b_buf) + (-Group::From(s0)) + Group::From(s1);
173 if (t1) v_cw = -v_cw;
191 __host__ __device__ cuda::std::array<int4, 4>
Eval(
bool b, int4 s0,
192 cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs, int4 ocw, In x, int4 &y) {
194 s = util::SetLsb(s,
false);
197 for (
int i = 0; i < in_bits; ++i) {
200 bool tl_cw = util::GetLsb(s_cw);
201 s_cw = util::SetLsb(s_cw,
false);
204 auto [sl, sr] = prg.Gen(s);
206 bool tl = util::GetLsb(sl);
207 sl = util::SetLsb(sl,
false);
208 bool tr = util::GetLsb(sr);
209 sr = util::SetLsb(sr,
false);
212 sl = util::Xor(sl, s_cw);
213 sr = util::Xor(sr, s_cw);
218 bool x_bit = (x >> (in_bits - 1 - i)) & 1;
230 auto g = Group::From(s);
231 assert((ocw.w & 1) == 0);
232 if (t) g = g + Group::From(ocw);
237 int4 x_buf = util::Pack(x);
239 auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, s});
242 cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
256 void Prove(cuda::std::span<
const cuda::std::array<int4, 4>> pi_tildes,
257 cuda::std::span<const int4, 4> cs, cuda::std::array<int4, 4> &pi) {
258 pi = {cs[0], cs[1], cs[2], cs[3]};
259 for (
size_t i = 0; i < pi_tildes.size(); ++i) {
260 cuda::std::array<int4, 4> h_input = util::Xor(
261 cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tildes[i]));
262 auto h_out = hash.Hash(cuda::std::span<const int4, 4>(h_input));
263 pi[0] = util::Xor(pi[0], h_out[0]);
264 pi[1] = util::Xor(pi[1], h_out[1]);
299 void EvalAll(
bool b, int4 s0, cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs,
300 int4 ocw, cuda::std::span<int4> ys, cuda::std::array<int4, 4> &pi) {
303 st = util::SetLsb(st, t);
305 assert(in_bits <
sizeof(
size_t) * 8);
307 size_t r = 1ULL << in_bits;
310 int par_depth_ = util::ResolveParDepth(par_depth);
315 EvalTree(st, cws, ys, l, r, i, par_depth_);
318 pi = {cs[0], cs[1], cs[2], cs[3]};
319 size_t n = 1ULL << in_bits;
320 assert((ocw.w & 1) == 0);
321 auto ocw_group = Group::From(ocw);
322 for (
size_t j = 0; j < n; ++j) {
324 bool tj = util::GetLsb(sj);
325 sj = util::SetLsb(sj,
false);
328 auto g = Group::From(sj);
329 if (tj) g = g + ocw_group;
334 int4 x_buf = util::Pack(
static_cast<In
>(j));
336 auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, sj});
338 pi_tilde = util::Xor(
339 cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
342 cuda::std::array<int4, 4> h_input = util::Xor(
343 cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tilde));
344 auto h_out = hash.Hash(cuda::std::span<const int4, 4>(h_input));
345 pi[0] = util::Xor(pi[0], h_out[0]);
346 pi[1] = util::Xor(pi[1], h_out[1]);
351 void EvalTree(int4 st, cuda::std::span<const Cw> cws, cuda::std::span<int4> ys,
size_t l,
392 EvalTree(stl, cws, ys, l, mid, i + 1, par_depth_);
394 EvalTree(str, cws, ys, mid, r, i + 1, par_depth_);
397 EvalTree(stl, cws, ys, l, mid, i + 1, par_depth_);
398 EvalTree(str, cws, ys, mid, r, i + 1, par_depth_);