myl7/fss 1.1.0
Function secret sharing (FSS) primitives including distributed point/comparison function (DPF/DCF)
Loading...
Searching...
No Matches
dcf.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
41#pragma once
42#include <cuda_runtime.h>
43#include <type_traits>
44#include <cstddef>
45#include <cassert>
46#include <omp.h>
47#include <fss/group.cuh>
48#include <fss/prg.cuh>
49#include <fss/util.cuh>
50
51namespace fss {
52
/**
 * Comparison predicate baked into a Dcf instantiation.
 *
 * Dcf::Gen uses it to decide at which tree levels the payload is folded into
 * the value correction word, depending on the current bit of the threshold a.
 */
enum class DcfPred {
    /// Payload is added at levels where the bit of a is 1.
    /// NOTE(review): presumably yields the payload for inputs x < a -- confirm.
    kLt = 0,
    /// Payload is added at levels where the bit of a is 0.
    /// NOTE(review): presumably yields the payload for inputs x > a -- confirm.
    kGt = 1,
};
62
/**
 * 2-party distributed comparison function (DCF) scheme.
 *
 * @tparam in_bits   Number of input bits walked by Gen/Eval. Bits are consumed
 *                   from bit (in_bits - 1) down to bit 0 of the input.
 * @tparam Group     Output group type; must satisfy Groupable. Values are
 *                   converted from/to int4 buffers with Group::From / Into().
 * @tparam Prg       PRG type; must satisfy Prgable<Prg, 4>. prg.Gen(seed) is
 *                   unpacked as {left seed, left value, right seed, right value}.
 * @tparam In        Unsigned input type (or __uint128_t) holding at least
 *                   in_bits bits.
 * @tparam pred      Comparison predicate, see DcfPred.
 * @tparam par_depth Passed to util::ResolveParDepth in EvalAll; tree levels
 *                   below the resolved depth are evaluated as OpenMP tasks.
 */
template <int in_bits, typename Group, typename Prg, typename In = uint,
          DcfPred pred = DcfPred::kLt, int par_depth = -1>
    requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
             in_bits <= sizeof(In) * 8 && Groupable<Group> && Prgable<Prg, 4>)
class Dcf {
public:
    /// PRG instance used by Gen, Eval and EvalAll to expand seeds.
    Prg prg;

    /**
     * Correction word.
     *
     * For levels 0 .. in_bits - 1 (as written by Gen):
     *  - s: seed correction, with the left control-bit correction in its LSB;
     *  - v: value correction, with the right control-bit correction in its LSB.
     * For the final entry cws[in_bits]: s is all zero and v holds the last
     * value correction.
     */
    struct __align__(32) Cw {
        int4 s;
        int4 v;
    };
    // Keep a correction word readable with a single aligned memory access on GPU.
    static_assert(sizeof(Cw) == 32);

    /**
     * Key generation method.
     *
     * @param cws   Output array of in_bits + 1 correction words.
     * @param s0s   Initial seeds of party 0 and party 1. Their LSBs are
     *              cleared before use.
     * @param a     Threshold input; only its low in_bits bits are read.
     * @param b_buf Payload as a group buffer. Its LSB is cleared before use.
     */
    __host__ __device__ void Gen(Cw cws[], const int4 s0s[2], In a, int4 b_buf) {
        int4 s0 = s0s[0];
        s0 = util::SetLsb(s0, false);
        int4 s1 = s0s[1];
        s1 = util::SetLsb(s1, false);
        // Control bits start as the party indices.
        bool t0 = false;
        bool t1 = true;
        // Running value accumulator.
        // NOTE(review): assumes a default-constructed Group is the identity -- confirm against Groupable.
        Group v;
        b_buf = util::SetLsb(b_buf, false);

        for (int i = 0; i < in_bits; ++i) {
            auto [s0l, v0l_buf, s0r, v0r_buf] = prg.Gen(s0);
            auto [s1l, v1l_buf, s1r, v1r_buf] = prg.Gen(s1);

            // The LSB of each child seed is taken as its control bit; the LSB of
            // every buffer is cleared before it is used as a seed or group element.
            bool t0l = util::GetLsb(s0l);
            s0l = util::SetLsb(s0l, false);
            v0l_buf = util::SetLsb(v0l_buf, false);
            auto v0l = Group::From(v0l_buf);
            bool t0r = util::GetLsb(s0r);
            s0r = util::SetLsb(s0r, false);
            v0r_buf = util::SetLsb(v0r_buf, false);
            auto v0r = Group::From(v0r_buf);
            bool t1l = util::GetLsb(s1l);
            s1l = util::SetLsb(s1l, false);
            v1l_buf = util::SetLsb(v1l_buf, false);
            auto v1l = Group::From(v1l_buf);
            bool t1r = util::GetLsb(s1r);
            s1r = util::SetLsb(s1r, false);
            v1r_buf = util::SetLsb(v1r_buf, false);
            auto v1r = Group::From(v1r_buf);

            // Current bit of a, most significant of the in_bits bits first.
            bool a_bit = (a >> (in_bits - 1 - i)) & 1;

            // Seed correction: XOR of both parties' child seeds on the side a does NOT take.
            int4 s_cw;
            if (!a_bit) s_cw = util::Xor(s0r, s1r);
            else s_cw = util::Xor(s0l, s1l);

            // Value correction for the side a does NOT take; the payload is added
            // on the side selected by the predicate.
            Group v_cw = (-v);
            if (!a_bit) {
                v_cw = v_cw + v1r + (-v0r);
                if constexpr (pred == DcfPred::kGt) v_cw = v_cw + Group::From(b_buf);
            } else {
                v_cw = v_cw + v1l + (-v0l);
                if constexpr (pred == DcfPred::kLt) v_cw = v_cw + Group::From(b_buf);
            }
            if (t1) v_cw = -v_cw;

            // Update the accumulator with the values on the side a takes.
            if (!a_bit) v = v + (-v1l) + v0l;
            else v = v + (-v1r) + v0r;
            if (t1) v = v + (-v_cw);
            else v = v + v_cw;

            // Control-bit corrections for the left and right children.
            bool tl_cw = t0l ^ t1l ^ a_bit ^ 1;
            bool tr_cw = t0r ^ t1r ^ a_bit;

            // Descend to the child on a's side; the correction is applied only
            // when the party's current control bit is set.
            if (!a_bit) {
                s0 = s0l;
                if (t0) s0 = util::Xor(s0, s_cw);
                s1 = s1l;
                if (t1) s1 = util::Xor(s1, s_cw);

                if (t0) t0 = t0l ^ tl_cw;
                else t0 = t0l;
                if (t1) t1 = t1l ^ tl_cw;
                else t1 = t1l;
            } else {
                s0 = s0r;
                if (t0) s0 = util::Xor(s0, s_cw);
                s1 = s1r;
                if (t1) s1 = util::Xor(s1, s_cw);

                if (t0) t0 = t0r ^ tr_cw;
                else t0 = t0r;
                if (t1) t1 = t1r ^ tr_cw;
                else t1 = t1r;
            }

            // Pack the two control-bit corrections into the LSBs of the stored buffers.
            s_cw = util::SetLsb(s_cw, tl_cw);
            int4 v_buf = v_cw.Into();
            v_buf = util::SetLsb(v_buf, tr_cw);
            cws[i] = {s_cw, v_buf};
        }

        // Final value correction derived from the two leaf seeds.
        auto v_cw_np1 = Group::From(s1) + (-Group::From(s0)) + (-v);
        if (t1) v_cw_np1 = -v_cw_np1;
        cws[in_bits] = {{0, 0, 0, 0}, v_cw_np1.Into()};
    }

    /**
     * Evaluation method.
     *
     * @param b   Party flag. It is the initial control bit and, when true,
     *            every accumulated term is negated.
     * @param s0  Initial seed of this party. Its LSB is cleared before use.
     * @param cws The in_bits + 1 correction words written by Gen.
     * @param x   Evaluated input; only its low in_bits bits are read.
     * @return    This party's output share as a group buffer (Group::Into()).
     */
    __host__ __device__ int4 Eval(bool b, int4 s0, const Cw cws[], In x) {
        int4 s = s0;
        s = util::SetLsb(s, false);
        // NOTE(review): assumes a default-constructed Group is the identity -- confirm against Groupable.
        Group v;
        bool t = b;

        for (int i = 0; i < in_bits; ++i) {
            auto cw = cws[i];

            // Unpack the seed correction and the left control-bit correction.
            int4 s_cw = cw.s;
            bool tl_cw = util::GetLsb(s_cw);
            s_cw = util::SetLsb(s_cw, false);

            // Unpack the value correction and the right control-bit correction.
            int4 v_cw_buf = cw.v;
            bool tr_cw = util::GetLsb(v_cw_buf);
            v_cw_buf = util::SetLsb(v_cw_buf, false);
            auto v_cw = Group::From(v_cw_buf);

            auto [sl, vl_buf, sr, vr_buf] = prg.Gen(s);

            bool tl = util::GetLsb(sl);
            sl = util::SetLsb(sl, false);
            vl_buf = util::SetLsb(vl_buf, false);
            auto vl = Group::From(vl_buf);

            bool tr = util::GetLsb(sr);
            sr = util::SetLsb(sr, false);
            vr_buf = util::SetLsb(vr_buf, false);
            auto vr = Group::From(vr_buf);

            // Corrections apply only while the current control bit is set.
            if (t) {
                sl = util::Xor(sl, s_cw);
                sr = util::Xor(sr, s_cw);
                tl = tl ^ tl_cw;
                tr = tr ^ tr_cw;
            }

            // Current bit of x, most significant of the in_bits bits first.
            bool x_bit = (x >> (in_bits - 1 - i)) & 1;

            // Accumulate the value of the chosen child (plus the correction when
            // t is set), negated when b is true.
            if (b) {
                if (!x_bit) v = v + (-vl);
                else v = v + (-vr);
                if (t) v = v + (-v_cw);
            } else {
                if (!x_bit) v = v + vl;
                else v = v + vr;
                if (t) v = v + v_cw;
            }

            if (!x_bit) {
                s = sl;
                t = tl;
            } else {
                s = sr;
                t = tr;
            }
        }

        int4 v_cw_np1_buf = cws[in_bits].v;
        // The final correction word is expected to carry no packed control bit.
        assert((v_cw_np1_buf.w & 1) == 0);
        auto v_cw_np1 = Group::From(v_cw_np1_buf);

        // Fold in the leaf seed and, when t is set, the final correction.
        if (b) {
            v = v + (-Group::From(s));
            if (t) v = v + (-v_cw_np1);
        } else {
            v = v + Group::From(s);
            if (t) v = v + v_cw_np1;
        }

        return v.Into();
    }

    /**
     * Full domain evaluation method (host only).
     *
     * Walks the whole binary tree once and writes one output per leaf, using
     * the same bit order as Eval: leaf index l is the input whose bits, most
     * significant first, select left (0) or right (1) at each level.
     *
     * @param b   Party flag, as in Eval.
     * @param s0  Initial seed of this party; its LSB is overwritten with b.
     * @param cws The in_bits + 1 correction words written by Gen.
     * @param ys  Output array of 2^in_bits group buffers.
     */
    void EvalAll(bool b, int4 s0, const Cw cws[], int4 ys[]) {
        // 2^in_bits leaves must be indexable by size_t. in_bits is a compile-time
        // constant, so reject oversized domains at compile time: a runtime assert
        // vanishes under NDEBUG and would leave the shift below undefined.
        static_assert(static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
                      "Dcf::EvalAll requires 2^in_bits to fit in size_t");

        // Seed and control bit travel together: the control bit lives in the seed's LSB.
        int4 st = s0;
        bool t = b;
        st = util::SetLsb(st, t);

        size_t l = 0;
        size_t r = static_cast<size_t>(1) << in_bits;
        int i = 0;

        int par_depth_ = util::ResolveParDepth(par_depth);

        // NOTE(review): assumes a default-constructed Group is the identity -- confirm against Groupable.
        Group v;

        // One thread starts the recursion; EvalTree spawns tasks for the team.
#pragma omp parallel
#pragma omp single
        EvalTree(b, st, cws, ys, l, r, i, par_depth_, v);
    }

private:
    /**
     * Recursive helper of EvalAll covering the leaf range [l, r) at level i.
     *
     * @param b          Party flag, as in Eval.
     * @param st         Node seed with the node's control bit in its LSB.
     * @param cws        The in_bits + 1 correction words written by Gen.
     * @param ys         Output array; ys[l] is written when a leaf is reached.
     * @param l          First leaf index covered by this node.
     * @param r          One past the last leaf index covered by this node.
     * @param i          Current tree level; i == in_bits means a leaf.
     * @param par_depth_ Levels with i < par_depth_ spawn OpenMP tasks for both
     *                   children; deeper levels recurse sequentially.
     * @param v          Value accumulated on the path from the root to this node.
     */
    void EvalTree(bool b, int4 st, const Cw cws[], int4 ys[], size_t l, size_t r, int i,
                  int par_depth_, Group v) {
        bool t = util::GetLsb(st);
        int4 s = st;
        s = util::SetLsb(s, false);

        if (i == in_bits) {
            // Leaf: fold in the leaf seed and, when t is set, the final correction.
            int4 v_cw_np1_buf = cws[in_bits].v;
            assert((v_cw_np1_buf.w & 1) == 0);
            auto term = Group::From(s);
            if (t) term = term + Group::From(v_cw_np1_buf);
            if (b) term = -term;
            v = v + term;
            assert(l + 1 == r);
            ys[l] = v.Into();
            return;
        }

        // Unpack this level's correction word.
        Cw cw = cws[i];
        int4 s_cw = cw.s;
        bool tl_cw = util::GetLsb(s_cw);
        s_cw = util::SetLsb(s_cw, false);
        int4 v_cw_buf = cw.v;
        bool tr_cw = util::GetLsb(v_cw_buf);
        v_cw_buf = util::SetLsb(v_cw_buf, false);
        auto v_cw = Group::From(v_cw_buf);

        auto [sl, vl_buf, sr, vr_buf] = prg.Gen(s);

        bool tl = util::GetLsb(sl);
        sl = util::SetLsb(sl, false);
        vl_buf = util::SetLsb(vl_buf, false);
        auto vl = Group::From(vl_buf);

        bool tr = util::GetLsb(sr);
        sr = util::SetLsb(sr, false);
        vr_buf = util::SetLsb(vr_buf, false);
        auto vr = Group::From(vr_buf);

        // Corrections apply only while the current control bit is set.
        if (t) {
            sl = util::Xor(sl, s_cw);
            sr = util::Xor(sr, s_cw);
            tl = tl ^ tl_cw;
            tr = tr ^ tr_cw;
            vl = vl + v_cw;
            vr = vr + v_cw;
        }
        if (b) {
            vl = -vl;
            vr = -vr;
        }

        // Children inherit the value accumulated so far.
        vl = vl + v;
        vr = vr + v;

        // Re-pack each child's control bit into its seed. The children are passed
        // on through these plain locals rather than the structured bindings above.
        int4 stl = sl;
        stl = util::SetLsb(stl, tl);
        int4 str = sr;
        str = util::SetLsb(str, tr);

        size_t mid = (l + r) / 2;

        if (i < par_depth_) {
#pragma omp task
            EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_, vl);
#pragma omp task
            EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_, vr);
#pragma omp taskwait
        } else {
            EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_, vl);
            EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_, vr);
        }
    }
};
387
388} // namespace fss
2-party DCF scheme.
Definition dcf.cuh:78
int4 Eval(bool b, int4 s0, const Cw cws[], In x)
Evaluation method.
Definition dcf.cuh:205
void Gen(Cw cws[], const int4 s0s[2], In a, int4 b_buf)
Key generation method.
Definition dcf.cuh:108
void EvalAll(bool b, int4 s0, const Cw cws[], int4 ys[])
Full domain evaluation method.
Definition dcf.cuh:293
Group interface.
Definition group.cuh:40
Pseudorandom generator (PRG) interface.
Definition prg.cuh:21
DcfPred
Comparison predicate.
Definition dcf.cuh:58
Correction word.
Definition dcf.cuh:91