68 __host__ __device__
void Gen(
Cw cws[], int4 &ocw,
const int4 s0s[2], In a, int4 b_buf) {
69 b_buf = util::SetLsb(b_buf,
false);
72 int4 node0 = util::SetLsb(s0s[0],
false);
73 int4 node1 = util::SetLsb(s0s[1],
true);
74 int4 delta = util::Xor(node0, node1);
77 for (
int i = 0; i < in_bits - 1; ++i) {
78 int4 h0 = prg.Gen(util::Xor(hash_key, node0))[0];
79 int4 h1 = prg.Gen(util::Xor(hash_key, node1))[0];
81 bool a_bit = (a >> (in_bits - 1 - i)) & 1;
86 int4 cw = util::Xor(h0, h1);
87 if (!a_bit) cw = util::Xor(cw, delta);
91 bool t0 = util::GetLsb(node0);
92 bool t1 = util::GetLsb(node1);
95 int4 zero4 = {0, 0, 0, 0};
96 int4 ab_mask0 = a_bit ? node0 : zero4;
97 int4 ab_mask1 = a_bit ? node1 : zero4;
98 int4 t0_mask = t0 ? cw : zero4;
99 int4 t1_mask = t1 ? cw : zero4;
101 node0 = util::Xor(util::Xor(h0, ab_mask0), t0_mask);
102 node1 = util::Xor(util::Xor(h1, ab_mask1), t1_mask);
105 delta = util::Xor(node0, node1);
110 bool a_n = (a >> 0) & 1;
111 bool t0 = util::GetLsb(node0);
112 bool t1 = util::GetLsb(node1);
115 int4 h0_0 = prg.Gen(util::Xor(hash_key, util::SetLsb(node0,
false)))[0];
116 int4 h0_1 = prg.Gen(util::Xor(hash_key, util::SetLsb(node0,
true)))[0];
117 int4 h1_0 = prg.Gen(util::Xor(hash_key, util::SetLsb(node1,
false)))[0];
118 int4 h1_1 = prg.Gen(util::Xor(hash_key, util::SetLsb(node1,
true)))[0];
121 int4 high0_0 = util::SetLsb(h0_0,
false);
122 bool low0_0 = util::GetLsb(h0_0);
123 int4 high0_1 = util::SetLsb(h0_1,
false);
124 bool low0_1 = util::GetLsb(h0_1);
125 int4 high1_0 = util::SetLsb(h1_0,
false);
126 bool low1_0 = util::GetLsb(h1_0);
127 int4 high1_1 = util::SetLsb(h1_1,
false);
128 bool low1_1 = util::GetLsb(h1_1);
133 if (a_n) HCW = util::Xor(high0_0, high1_0);
134 else HCW = util::Xor(high0_1, high1_1);
141 bool LCW_0 = low0_0 ^ low1_0 ^ !a_n;
142 bool LCW_1 = low0_1 ^ low1_1 ^ a_n;
145 cws[in_bits - 1] = {util::SetLsb(HCW, LCW_0), LCW_1};
151 leaf0 = util::SetLsb(high0_1, low0_1);
152 leaf1 = util::SetLsb(high1_1, low1_1);
154 leaf0 = util::SetLsb(high0_0, low0_0);
155 leaf1 = util::SetLsb(high1_0, low1_0);
159 bool lcw_an = a_n ? LCW_1 : LCW_0;
160 int4 leaf_cw = util::SetLsb(HCW, lcw_an);
161 if (t0) leaf0 = util::Xor(leaf0, leaf_cw);
162 if (t1) leaf1 = util::Xor(leaf1, leaf_cw);
165 auto v_cw = Group::From(b_buf) + (-Group::From(util::SetLsb(leaf0,
false))) +
166 Group::From(util::SetLsb(leaf1,
false));
167 if (util::GetLsb(leaf1)) v_cw = -v_cw;
182 __host__ __device__ int4
Eval(
bool b, int4 s0,
const Cw cws[], int4 ocw, In x) {
183 int4 node = util::SetLsb(s0, b);
186 for (
int i = 0; i < in_bits - 1; ++i) {
187 bool x_bit = (x >> (in_bits - 1 - i)) & 1;
188 bool t = util::GetLsb(node);
190 int4 h = prg.Gen(util::Xor(hash_key, node))[0];
192 int4 zero4 = {0, 0, 0, 0};
193 int4 xb_mask = x_bit ? node : zero4;
194 int4 t_mask = t ? cws[i].s : zero4;
196 node = util::Xor(util::Xor(h, xb_mask), t_mask);
201 bool x_n = (x >> 0) & 1;
202 bool t = util::GetLsb(node);
204 int4 h = prg.Gen(util::Xor(hash_key, util::SetLsb(node, x_n)))[0];
207 int4 hcw = util::SetLsb(cws[in_bits - 1].s,
false);
209 if (x_n) lcw_xn = cws[in_bits - 1].extra;
210 else lcw_xn = util::GetLsb(cws[in_bits - 1].s);
212 int4 high = util::SetLsb(h,
false);
213 bool low = util::GetLsb(h);
216 high = util::Xor(high, hcw);
220 auto y = Group::From(high);
221 if (low) y = y + Group::From(ocw);
237 void EvalAll(
bool b, int4 s0,
const Cw cws[], int4 ocw, int4 ys[]) {
238 int4 node = util::SetLsb(s0, b);
240 assert(in_bits <
sizeof(
size_t) * 8);
242 int par_depth_ = util::ResolveParDepth(par_depth);
244 if constexpr (in_bits == 1) {
248 EvalLastLevel(b, node, cws, ocw, ys);
255 size_t num_leaves = 1ULL << (in_bits - 1);
260 EvalTree(node, cws, ys, 0, num_leaves, 0, par_depth_);
263 int4 hcw = util::SetLsb(cws[in_bits - 1].s,
false);
264 bool lcw_0 = util::GetLsb(cws[in_bits - 1].s);
265 bool lcw_1 = cws[in_bits - 1].extra;
266 auto ocw_group = Group::From(ocw);
269 for (
size_t j = num_leaves; j-- > 0;) {
270 ConvertLastLevel(b, ys[j], hcw, lcw_0, lcw_1, ocw_group, ys[2 * j], ys[2 * j + 1]);