101 __host__ __device__
int Gen(
102 Cw cws[], cuda::std::array<int4, 4> &cs, int4 &ocw, cuda::std::span<const int4, 2> s0s, In a, int4 b_buf) {
104 s0 = util::SetLsb(s0,
false);
106 s1 = util::SetLsb(s1,
false);
109 b_buf = util::SetLsb(b_buf,
false);
111 for (
int i = 0; i < in_bits; ++i) {
112 auto [s0l, s0r] = prg.Gen(s0);
113 auto [s1l, s1r] = prg.Gen(s1);
115 bool t0l = util::GetLsb(s0l);
116 s0l = util::SetLsb(s0l,
false);
117 bool t0r = util::GetLsb(s0r);
118 s0r = util::SetLsb(s0r,
false);
119 bool t1l = util::GetLsb(s1l);
120 s1l = util::SetLsb(s1l,
false);
121 bool t1r = util::GetLsb(s1r);
122 s1r = util::SetLsb(s1r,
false);
124 bool a_bit = (a >> (in_bits - 1 - i)) & 1;
127 if (!a_bit) s_cw = util::Xor(s0r, s1r);
128 else s_cw = util::Xor(s0l, s1l);
130 bool tl_cw = t0l ^ t1l ^ a_bit ^ 1;
131 bool tr_cw = t0r ^ t1r ^ a_bit;
135 if (t0) s0 = util::Xor(s0, s_cw);
137 if (t1) s1 = util::Xor(s1, s_cw);
139 if (t0) t0 = t0l ^ tl_cw;
141 if (t1) t1 = t1l ^ tl_cw;
145 if (t0) s0 = util::Xor(s0, s_cw);
147 if (t1) s1 = util::Xor(s1, s_cw);
149 if (t0) t0 = t0r ^ tr_cw;
151 if (t1) t1 = t1r ^ tr_cw;
155 s_cw = util::SetLsb(s_cw, tl_cw);
156 cws[i] = {s_cw, tr_cw};
160 int4 a_buf = util::Pack(a);
162 auto pi_tilde_0 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, s0});
163 auto pi_tilde_1 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, s1});
164 cs = util::Xor(cuda::std::span<const int4, 4>(pi_tilde_0), cuda::std::span<const int4, 4>(pi_tilde_1));
167 if (t0 == t1)
return 1;
170 auto v_cw = Group::From(b_buf) + (-Group::From(s0)) + Group::From(s1);
171 if (t1) v_cw = -v_cw;
189 __host__ __device__ cuda::std::array<int4, 4>
Eval(
190 bool b, int4 s0, cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs, int4 ocw, In x, int4 &y) {
192 s = util::SetLsb(s,
false);
195 for (
int i = 0; i < in_bits; ++i) {
198 bool tl_cw = util::GetLsb(s_cw);
199 s_cw = util::SetLsb(s_cw,
false);
202 auto [sl, sr] = prg.Gen(s);
204 bool tl = util::GetLsb(sl);
205 sl = util::SetLsb(sl,
false);
206 bool tr = util::GetLsb(sr);
207 sr = util::SetLsb(sr,
false);
210 sl = util::Xor(sl, s_cw);
211 sr = util::Xor(sr, s_cw);
216 bool x_bit = (x >> (in_bits - 1 - i)) & 1;
228 auto g = Group::From(s);
229 assert((ocw.w & 1) == 0);
230 if (t) g = g + Group::From(ocw);
235 int4 x_buf = util::Pack(x);
237 auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, s});
239 return util::Xor(cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
253 void Prove(cuda::std::span<
const cuda::std::array<int4, 4>> pi_tildes, cuda::std::span<const int4, 4> cs,
254 cuda::std::array<int4, 4> &pi) {
255 pi = {cs[0], cs[1], cs[2], cs[3]};
256 for (
size_t i = 0; i < pi_tildes.size(); ++i) {
257 cuda::std::array<int4, 4> h_input =
258 util::Xor(cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tildes[i]));
259 auto h_out = hash.Hash(cuda::std::span<const int4, 4>(h_input));
260 pi[0] = util::Xor(pi[0], h_out[0]);
261 pi[1] = util::Xor(pi[1], h_out[1]);
293 void EvalAll(
bool b, int4 s0, cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs, int4 ocw,
294 cuda::std::span<int4> ys, cuda::std::array<int4, 4> &pi) {
297 st = util::SetLsb(st, t);
299 assert(in_bits <
sizeof(
size_t) * 8);
301 size_t r = 1ULL << in_bits;
304 int par_depth_ = util::ResolveParDepth(par_depth);
309 EvalTree(st, cws, ys, l, r, i, par_depth_);
312 pi = {cs[0], cs[1], cs[2], cs[3]};
313 size_t n = 1ULL << in_bits;
314 assert((ocw.w & 1) == 0);
315 auto ocw_group = Group::From(ocw);
316 for (
size_t j = 0; j < n; ++j) {
318 bool tj = util::GetLsb(sj);
319 sj = util::SetLsb(sj,
false);
322 auto g = Group::From(sj);
323 if (tj) g = g + ocw_group;
328 int4 x_buf = util::Pack(
static_cast<In
>(j));
330 auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, sj});
332 pi_tilde = util::Xor(cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
335 cuda::std::array<int4, 4> h_input =
336 util::Xor(cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tilde));
337 auto h_out = hash.Hash(cuda::std::span<const int4, 4>(h_input));
338 pi[0] = util::Xor(pi[0], h_out[0]);
339 pi[1] = util::Xor(pi[1], h_out[1]);
345 int4 st, cuda::std::span<const Cw> cws, cuda::std::span<int4> ys,
size_t l,
size_t r,
int i,
int par_depth_) {
385 EvalTree(stl, cws, ys, l, mid, i + 1, par_depth_);
387 EvalTree(str, cws, ys, mid, r, i + 1, par_depth_);
390 EvalTree(stl, cws, ys, l, mid, i + 1, par_depth_);
391 EvalTree(str, cws, ys, mid, r, i + 1, par_depth_);