76 struct __align__(32)
Cw {
81 static_assert(
sizeof(
Cw) == 32);
93 __host__ __device__
void Gen(
Cw cws[],
const int4 s0s[2], In a, int4 b_buf) {
95 s0 = util::SetLsb(s0,
false);
97 s1 = util::SetLsb(s1,
false);
100 b_buf = util::SetLsb(b_buf,
false);
102 for (
int i = 0; i < in_bits; ++i) {
103 auto [s0l, s0r] = prg.Gen(s0);
104 auto [s1l, s1r] = prg.Gen(s1);
106 bool t0l = util::GetLsb(s0l);
107 s0l = util::SetLsb(s0l,
false);
108 bool t0r = util::GetLsb(s0r);
109 s0r = util::SetLsb(s0r,
false);
110 bool t1l = util::GetLsb(s1l);
111 s1l = util::SetLsb(s1l,
false);
112 bool t1r = util::GetLsb(s1r);
113 s1r = util::SetLsb(s1r,
false);
115 bool a_bit = (a >> (in_bits - 1 - i)) & 1;
118 if (!a_bit) s_cw = util::Xor(s0r, s1r);
119 else s_cw = util::Xor(s0l, s1l);
121 bool tl_cw = t0l ^ t1l ^ a_bit ^ 1;
122 bool tr_cw = t0r ^ t1r ^ a_bit;
126 if (t0) s0 = util::Xor(s0, s_cw);
128 if (t1) s1 = util::Xor(s1, s_cw);
130 if (t0) t0 = t0l ^ tl_cw;
132 if (t1) t1 = t1l ^ tl_cw;
136 if (t0) s0 = util::Xor(s0, s_cw);
138 if (t1) s1 = util::Xor(s1, s_cw);
140 if (t0) t0 = t0r ^ tr_cw;
142 if (t1) t1 = t1r ^ tr_cw;
146 s_cw = util::SetLsb(s_cw, tl_cw);
147 cws[i] = {s_cw, tr_cw};
150 auto v_cw_np1 = Group::From(b_buf) + (-Group::From(s0)) + Group::From(s1);
151 if (t1) v_cw_np1 = -v_cw_np1;
152 cws[in_bits] = {v_cw_np1.Into(),
false};
164 __host__ __device__ int4
Eval(
bool b, int4 s0,
const Cw cws[], In x) {
166 s = util::SetLsb(s,
false);
169 for (
int i = 0; i < in_bits; ++i) {
172 bool tl_cw = util::GetLsb(s_cw);
173 s_cw = util::SetLsb(s_cw,
false);
176 auto [sl, sr] = prg.Gen(s);
178 bool tl = util::GetLsb(sl);
179 sl = util::SetLsb(sl,
false);
180 bool tr = util::GetLsb(sr);
181 sr = util::SetLsb(sr,
false);
184 sl = util::Xor(sl, s_cw);
185 sr = util::Xor(sr, s_cw);
190 bool x_bit = (x >> (in_bits - 1 - i)) & 1;
201 auto y = Group::From(s);
202 int4 v_cw_np1 = cws[in_bits].s;
203 assert((v_cw_np1.w & 1) == 0);
204 if (t) y = y + Group::From(v_cw_np1);
225 void EvalAll(
bool b, int4 s0,
const Cw cws[], int4 ys[]) {
228 st = util::SetLsb(st, t);
230 assert(in_bits <
sizeof(
size_t) * 8);
232 size_t r = 1ULL << in_bits;
235 int par_depth_ = util::ResolveParDepth(par_depth);
239 EvalTree(b, st, cws, ys, l, r, i, par_depth_);
244 bool b, int4 st,
const Cw cws[], int4 ys[],
size_t l,
size_t r,
int i,
int par_depth_) {
245 bool t = util::GetLsb(st);
247 s = util::SetLsb(s,
false);
250 auto y = Group::From(s);
251 int4 v_cw_np1 = cws[in_bits].s;
252 assert((v_cw_np1.w & 1) == 0);
253 if (t) y = y + Group::From(v_cw_np1);
262 bool tl_cw = util::GetLsb(s_cw);
263 s_cw = util::SetLsb(s_cw,
false);
266 auto [sl, sr] = prg.Gen(s);
268 bool tl = util::GetLsb(sl);
269 sl = util::SetLsb(sl,
false);
270 bool tr = util::GetLsb(sr);
271 sr = util::SetLsb(sr,
false);
274 sl = util::Xor(sl, s_cw);
275 sr = util::Xor(sr, s_cw);
281 stl = util::SetLsb(stl, tl);
283 str = util::SetLsb(str, tr);
285 size_t mid = (l + r) / 2;
287 if (i < par_depth_) {
289 EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_);
291 EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_);
294 EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_);
295 EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_);
void EvalAll(bool b, int4 s0, const Cw cws[], int4 ys[])
Full domain evaluation method.
Definition dpf.cuh:225