108 __host__ __device__
void Gen(
Cw cws[],
const int4 s0s[2], In a, int4 b_buf) {
110 s0 = util::SetLsb(s0,
false);
112 s1 = util::SetLsb(s1,
false);
116 b_buf = util::SetLsb(b_buf,
false);
118 for (
int i = 0; i < in_bits; ++i) {
119 auto [s0l, v0l_buf, s0r, v0r_buf] = prg.Gen(s0);
120 auto [s1l, v1l_buf, s1r, v1r_buf] = prg.Gen(s1);
122 bool t0l = util::GetLsb(s0l);
123 s0l = util::SetLsb(s0l,
false);
124 v0l_buf = util::SetLsb(v0l_buf,
false);
125 auto v0l = Group::From(v0l_buf);
126 bool t0r = util::GetLsb(s0r);
127 s0r = util::SetLsb(s0r,
false);
128 v0r_buf = util::SetLsb(v0r_buf,
false);
129 auto v0r = Group::From(v0r_buf);
130 bool t1l = util::GetLsb(s1l);
131 s1l = util::SetLsb(s1l,
false);
132 v1l_buf = util::SetLsb(v1l_buf,
false);
133 auto v1l = Group::From(v1l_buf);
134 bool t1r = util::GetLsb(s1r);
135 s1r = util::SetLsb(s1r,
false);
136 v1r_buf = util::SetLsb(v1r_buf,
false);
137 auto v1r = Group::From(v1r_buf);
139 bool a_bit = (a >> (in_bits - 1 - i)) & 1;
142 if (!a_bit) s_cw = util::Xor(s0r, s1r);
143 else s_cw = util::Xor(s0l, s1l);
147 v_cw = v_cw + v1r + (-v0r);
148 if constexpr (pred == DcfPred::kGt) v_cw = v_cw + Group::From(b_buf);
150 v_cw = v_cw + v1l + (-v0l);
151 if constexpr (pred == DcfPred::kLt) v_cw = v_cw + Group::From(b_buf);
153 if (t1) v_cw = -v_cw;
155 if (!a_bit) v = v + (-v1l) + v0l;
156 else v = v + (-v1r) + v0r;
157 if (t1) v = v + (-v_cw);
160 bool tl_cw = t0l ^ t1l ^ a_bit ^ 1;
161 bool tr_cw = t0r ^ t1r ^ a_bit;
165 if (t0) s0 = util::Xor(s0, s_cw);
167 if (t1) s1 = util::Xor(s1, s_cw);
169 if (t0) t0 = t0l ^ tl_cw;
171 if (t1) t1 = t1l ^ tl_cw;
175 if (t0) s0 = util::Xor(s0, s_cw);
177 if (t1) s1 = util::Xor(s1, s_cw);
179 if (t0) t0 = t0r ^ tr_cw;
181 if (t1) t1 = t1r ^ tr_cw;
185 s_cw = util::SetLsb(s_cw, tl_cw);
186 int4 v_buf = v_cw.Into();
187 v_buf = util::SetLsb(v_buf, tr_cw);
188 cws[i] = {s_cw, v_buf};
191 auto v_cw_np1 = Group::From(s1) + (-Group::From(s0)) + (-v);
192 if (t1) v_cw_np1 = -v_cw_np1;
193 cws[in_bits] = {{0, 0, 0, 0}, v_cw_np1.Into()};
205 __host__ __device__ int4
Eval(
bool b, int4 s0,
const Cw cws[], In x) {
207 s = util::SetLsb(s,
false);
211 for (
int i = 0; i < in_bits; ++i) {
215 bool tl_cw = util::GetLsb(s_cw);
216 s_cw = util::SetLsb(s_cw,
false);
218 int4 v_cw_buf = cw.v;
219 bool tr_cw = util::GetLsb(v_cw_buf);
220 v_cw_buf = util::SetLsb(v_cw_buf,
false);
221 auto v_cw = Group::From(v_cw_buf);
223 auto [sl, vl_buf, sr, vr_buf] = prg.Gen(s);
225 bool tl = util::GetLsb(sl);
226 sl = util::SetLsb(sl,
false);
227 vl_buf = util::SetLsb(vl_buf,
false);
228 auto vl = Group::From(vl_buf);
230 bool tr = util::GetLsb(sr);
231 sr = util::SetLsb(sr,
false);
232 vr_buf = util::SetLsb(vr_buf,
false);
233 auto vr = Group::From(vr_buf);
236 sl = util::Xor(sl, s_cw);
237 sr = util::Xor(sr, s_cw);
242 bool x_bit = (x >> (in_bits - 1 - i)) & 1;
245 if (!x_bit) v = v + (-vl);
247 if (t) v = v + (-v_cw);
249 if (!x_bit) v = v + vl;
263 int4 v_cw_np1_buf = cws[in_bits].v;
264 assert((v_cw_np1_buf.w & 1) == 0);
265 auto v_cw_np1 = Group::From(v_cw_np1_buf);
268 v = v + (-Group::From(s));
269 if (t) v = v + (-v_cw_np1);
271 v = v + Group::From(s);
272 if (t) v = v + v_cw_np1;
293 void EvalAll(
bool b, int4 s0,
const Cw cws[], int4 ys[]) {
296 st = util::SetLsb(st, t);
298 assert(in_bits <
sizeof(
size_t) * 8);
300 size_t r = 1ULL << in_bits;
303 int par_depth_ = util::ResolveParDepth(par_depth);
309 EvalTree(b, st, cws, ys, l, r, i, par_depth_, v);
377 EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_, vl);
379 EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_, vr);
382 EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_, vl);
383 EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_, vr);