myl7/fss 1.1.0
Function secret sharing (FSS) primitives, including distributed point functions (DPFs) and distributed comparison functions (DCFs)
Loading...
Searching...
No Matches
vdpf.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
35#pragma once
36#include <cuda_runtime.h>
37#include <cuda/std/array>
38#include <cuda/std/span>
39#include <cuda/std/tuple>
40#include <type_traits>
41#include <cstddef>
42#include <cassert>
43#include <omp.h>
44#include <fss/group.cuh>
45#include <fss/prg.cuh>
46#include <fss/hash.cuh>
47#include <fss/util.cuh>
48
49namespace fss {
50
63template <int in_bits, typename Group, typename Prg, typename XorHash, typename Hash,
64 typename In = uint, int par_depth = -1>
65 requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
66 in_bits <= sizeof(In) * 8 && Groupable<Group> && Prgable<Prg, 2> && XorHashable<XorHash> &&
68class Vdpf {
public:
  /// PRG expanding a seed into its (left, right) children; the children's LSBs supply the control bits.
  Prg prg;
  /// Hash applied to (packed point, leaf seed) pairs, producing the per-point proofs `pi_tilde`.
  XorHash xor_hash;
  /// Hash used to fold per-point proofs into the accumulated proof (Prove / EvalAll).
  Hash hash;

/**
 * Correction word (CW) for one tree level.
 *
 * - `s`: seed CW. Gen stores the left control-bit CW in its LSB (via util::SetLsb), and
 *   Eval/EvalTree read it back from there before clearing it.
 * - `tr`: right control-bit CW.
 */
struct __align__(32) Cw {
  int4 s;
  bool tr;
};
// Padded to exactly 32 bytes so that a GPU thread fetches a whole CW with a single aligned access.
static_assert(sizeof(Cw) == 32, "Cw must occupy exactly 32 bytes");

/**
 * Key generation method.
 *
 * Produces the correction words shared by both parties' keys for the point function whose
 * value at `a` is `Group::From(b_buf)` (with `b_buf`'s LSB cleared).
 *
 * @param cws   Output: per-level correction words; `cws[i]` is written for every i in
 *              [0, in_bits), so the array must have room for `in_bits` entries.
 * @param cs    Output: XOR of the two parties' `xor_hash` values at `a`.
 * @param ocw   Output: output correction word. Only written when 0 is returned.
 * @param s0s   The two parties' root seeds; their LSBs are cleared before use.
 * @param a     The special point. The tree walk consumes its low `in_bits` bits, MSB first,
 *              while the verification hash packs the whole value.
 * @param b_buf The value at `a` in packed form; its LSB is cleared before conversion.
 * @return 0 on success; 1 if the final control bits coincide (the retry condition), in which
 *         case `ocw` is left untouched and the caller is expected to retry.
 */
__host__ __device__ int Gen(Cw cws[], cuda::std::array<int4, 4> &cs, int4 &ocw,
    cuda::std::span<const int4, 2> s0s, In a, int4 b_buf) {
  // Working seeds of both parties. Control bits are tracked separately, so the seed LSBs
  // are kept clear; party 0 starts with control bit 0 and party 1 with control bit 1.
  int4 seed0 = s0s[0];
  seed0 = util::SetLsb(seed0, false);
  int4 seed1 = s0s[1];
  seed1 = util::SetLsb(seed1, false);
  bool ctrl0 = false;
  bool ctrl1 = true;
  b_buf = util::SetLsb(b_buf, false);

  for (int level = 0; level < in_bits; ++level) {
    // Expand both seeds and split every child into (seed with LSB cleared, control bit).
    auto [l0, r0] = prg.Gen(seed0);
    auto [l1, r1] = prg.Gen(seed1);
    bool tl0 = util::GetLsb(l0);
    l0 = util::SetLsb(l0, false);
    bool tr0 = util::GetLsb(r0);
    r0 = util::SetLsb(r0, false);
    bool tl1 = util::GetLsb(l1);
    l1 = util::SetLsb(l1, false);
    bool tr1 = util::GetLsb(r1);
    r1 = util::SetLsb(r1, false);

    // Direction of the path to `a` at this level (MSB first): true means "go right".
    bool go_right = (a >> (in_bits - 1 - level)) & 1;

    // The seed CW is built from the child that is off the path to `a`.
    int4 s_cw = go_right ? util::Xor(l0, l1) : util::Xor(r0, r1);
    bool tl_cw = tl0 ^ tl1 ^ go_right ^ 1;
    bool tr_cw = tr0 ^ tr1 ^ go_right;

    // Descend along the path; only a party whose control bit is set applies the CWs.
    // The new seeds are computed before the control bits are overwritten.
    int4 keep0 = go_right ? r0 : l0;
    int4 keep1 = go_right ? r1 : l1;
    bool keep_t0 = go_right ? tr0 : tl0;
    bool keep_t1 = go_right ? tr1 : tl1;
    bool keep_t_cw = go_right ? tr_cw : tl_cw;
    seed0 = ctrl0 ? util::Xor(keep0, s_cw) : keep0;
    seed1 = ctrl1 ? util::Xor(keep1, s_cw) : keep1;
    ctrl0 = ctrl0 ? (keep_t0 ^ keep_t_cw) : keep_t0;
    ctrl1 = ctrl1 ? (keep_t1 ^ keep_t_cw) : keep_t1;

    // Pack the left control-bit CW into the LSB of the seed CW (Eval unpacks it there).
    s_cw = util::SetLsb(s_cw, tl_cw);
    cws[level] = {s_cw, tr_cw};
  }

  // Per-party verification hashes at `a`. In Eval, the party whose final control bit is set
  // XORs its hash with `cs`, which turns it into the other party's hash.
  int4 a_buf = util::Pack(a);
  auto pi_tilde_0 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, seed0});
  auto pi_tilde_1 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, seed1});
  cs = util::Xor(
      cuda::std::span<const int4, 4>(pi_tilde_0), cuda::std::span<const int4, 4>(pi_tilde_1));

  // NOTE(review): starting from ctrl0 = 0 and ctrl1 = 1, the update above keeps
  // ctrl0 != ctrl1 at every level, so this branch appears unreachable; kept as a
  // defensive check.
  if (ctrl0 == ctrl1) return 1;

  // Output CW: (-1)^ctrl1 * (From(b_buf) - From(seed0) + From(seed1)).
  auto v_cw = Group::From(b_buf) + (-Group::From(seed0)) + Group::From(seed1);
  if (ctrl1) v_cw = -v_cw;
  ocw = v_cw.Into();

  return 0;
}

/**
 * Evaluation method: evaluates one party's key at the single point `x`.
 *
 * @param b    Party index (false: party 0, true: party 1); also the initial control bit.
 * @param s0   This party's root seed; its LSB is cleared before use.
 * @param cws  Per-level correction words from Gen; must contain at least `in_bits` entries.
 * @param cs   Verification correction from Gen.
 * @param ocw  Output correction word from Gen; the low bit of `ocw.w` must be clear.
 * @param x    Evaluation point. The tree walk consumes its low `in_bits` bits, MSB first,
 *             while the verification hash packs the whole value.
 * @param y    Output: this party's share of the function value at `x`.
 * @return This party's per-point proof `pi_tilde`, to be accumulated by Prove.
 */
__host__ __device__ cuda::std::array<int4, 4> Eval(bool b, int4 s0,
    cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs, int4 ocw, In x, int4 &y) {
  // `cws` has a dynamic extent and is indexed once per level with no bounds check.
  assert(cws.size() >= static_cast<size_t>(in_bits));

  int4 s = s0;
  s = util::SetLsb(s, false);
  bool t = b;

  for (int i = 0; i < in_bits; ++i) {
    // Unpack the CW: the left control-bit CW is stored in the LSB of the seed CW.
    Cw cw = cws[i];
    int4 s_cw = cw.s;
    bool tl_cw = util::GetLsb(s_cw);
    s_cw = util::SetLsb(s_cw, false);
    bool tr_cw = cw.tr;

    // Expand and split each child into (seed with LSB cleared, control bit).
    auto [sl, sr] = prg.Gen(s);
    bool tl = util::GetLsb(sl);
    sl = util::SetLsb(sl, false);
    bool tr = util::GetLsb(sr);
    sr = util::SetLsb(sr, false);

    // The CWs are applied only when the current control bit is set.
    if (t) {
      sl = util::Xor(sl, s_cw);
      sr = util::Xor(sr, s_cw);
      tl = tl ^ tl_cw;
      tr = tr ^ tr_cw;
    }

    // Follow the bit of `x` for this level, MSB first.
    bool x_bit = (x >> (in_bits - 1 - i)) & 1;
    if (!x_bit) {
      s = sl;
      t = tl;
    } else {
      s = sr;
      t = tr;
    }
  }

  // Output share: (-1)^b * (From(s) + t * From(ocw)).
  auto g = Group::From(s);
  assert((ocw.w & 1) == 0);
  if (t) g = g + Group::From(ocw);
  if (b) g = -g;
  y = g.Into();

  // Corrected verification hash: the party with the set control bit folds in `cs`.
  int4 x_buf = util::Pack(x);
  auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, s});
  if (t) {
    return util::Xor(
        cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
  }
  return pi_tilde;
}

/**
 * Proof accumulation method.
 *
 * Starts from `cs` and, for each per-point proof in order, XORs the first two words of
 * `hash(pi ^ pi_tilde)` into `pi`. Words 2 and 3 of `pi` therefore remain equal to `cs`.
 *
 * @param pi_tildes Per-point proofs returned by Eval; the order matters.
 * @param cs        Verification correction from Gen.
 * @param pi        Output: the accumulated proof.
 */
void Prove(cuda::std::span<const cuda::std::array<int4, 4>> pi_tildes,
    cuda::std::span<const int4, 4> cs, cuda::std::array<int4, 4> &pi) {
  pi = {cs[0], cs[1], cs[2], cs[3]};
  for (const auto &pi_tilde : pi_tildes) {
    cuda::std::array<int4, 4> mixed = util::Xor(
        cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tilde));
    auto digest = hash.Hash(cuda::std::span<const int4, 4>(mixed));
    for (int k = 0; k < 2; ++k) pi[k] = util::Xor(pi[k], digest[k]);
  }
}

/**
 * Verification method: the two parties' accumulated proofs must be identical.
 *
 * @return true iff every component of all four words of `pi0` and `pi1` matches.
 */
__host__ __device__ static bool Verify(
    cuda::std::span<const int4, 4> pi0, cuda::std::span<const int4, 4> pi1) {
  for (size_t k = 0; k < pi0.size(); ++k) {
    const int4 u = pi0[k];
    const int4 v = pi1[k];
    const bool same = u.x == v.x && u.y == v.y && u.z == v.z && u.w == v.w;
    if (!same) return false;
  }
  return true;
}

/**
 * Full domain evaluation method: evaluates one party's key at every point in
 * [0, 2^in_bits) and accumulates the proof over all of them.
 *
 * Phase 1 walks the whole tree and temporarily stores each leaf seed in `ys`, with the
 * leaf's control bit packed into its LSB. It spawns OpenMP tasks for levels shallower than
 * the resolved parallel depth. Phase 2 runs sequentially: it converts each leaf into an
 * output share and folds the leaf's per-point proof into `pi` in ascending point order.
 * The result matches calling Eval on every point and then Prove over those proofs in order.
 *
 * Requires `in_bits` to be smaller than the bit width of `size_t` (checked at compile time).
 *
 * @param b    Party index (false: party 0, true: party 1); also the initial control bit.
 * @param s0   This party's root seed; its LSB is replaced by the control bit.
 * @param cws  Per-level correction words from Gen; at least `in_bits` entries.
 * @param cs   Verification correction from Gen.
 * @param ocw  Output correction word from Gen; the low bit of `ocw.w` must be clear.
 * @param ys   Output: shares indexed by point; must hold at least 2^in_bits elements.
 * @param pi   Output: the accumulated proof.
 */
void EvalAll(bool b, int4 s0, cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs,
    int4 ocw, cuda::std::span<int4> ys, cuda::std::array<int4, 4> &pi) {
  // Check at compile time. A runtime assert disappears under NDEBUG, leaving the shift
  // below undefined.
  static_assert(in_bits < sizeof(size_t) * 8, "EvalAll: 2^in_bits must fit in size_t");
  constexpr size_t n = size_t{1} << in_bits;
  // Both spans have dynamic extent and are indexed without bounds checks.
  assert(cws.size() >= static_cast<size_t>(in_bits));
  assert(ys.size() >= n);
  assert((ocw.w & 1) == 0);

  // Root seed with the initial control bit (the party index) packed into its LSB,
  // which is the node layout EvalTree expects.
  int4 st = s0;
  st = util::SetLsb(st, b);

  int par_depth_ = util::ResolveParDepth(par_depth);

  // Phase 1: tree traversal, storing (s, t) packed into ys temporarily.
#pragma omp parallel
#pragma omp single
  EvalTree(st, cws, ys, 0, n, 0, par_depth_);

  // Phase 2: sequential output computation and proof accumulation.
  pi = {cs[0], cs[1], cs[2], cs[3]};
  auto ocw_group = Group::From(ocw);
  for (size_t j = 0; j < n; ++j) {
    int4 sj = ys[j];
    bool tj = util::GetLsb(sj);
    sj = util::SetLsb(sj, false);

    // Output share: (-1)^b * (From(sj) + tj * From(ocw)).
    auto g = Group::From(sj);
    if (tj) g = g + ocw_group;
    if (b) g = -g;
    ys[j] = g.Into();

    // Per-point proof, corrected by `cs` when the leaf's control bit is set (as in Eval).
    int4 x_buf = util::Pack(static_cast<In>(j));
    auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, sj});
    if (tj) {
      pi_tilde = util::Xor(
          cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
    }

    // Fold into the running proof (same step as Prove).
    cuda::std::array<int4, 4> h_input = util::Xor(
        cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tilde));
    auto h_out = hash.Hash(cuda::std::span<const int4, 4>(h_input));
    pi[0] = util::Xor(pi[0], h_out[0]);
    pi[1] = util::Xor(pi[1], h_out[1]);
  }
}

private:
/**
 * Recursive helper of EvalAll. Expands the subtree rooted at level `i` that covers the leaf
 * index range [l, r), and writes each leaf's packed (seed, control bit) into `ys[leaf]`.
 *
 * `st` is the node's seed with its control bit packed into the LSB. Levels shallower than
 * `par_depth_` evaluate their two children as OpenMP tasks; deeper levels recurse serially.
 */
void EvalTree(int4 st, cuda::std::span<const Cw> cws, cuda::std::span<int4> ys, size_t l,
    size_t r, int i, int par_depth_) {
  if (i == in_bits) {
    // Leaf: exactly one point remains. Keep it packed for EvalAll's second phase.
    assert(l + 1 == r);
    ys[l] = st;
    return;
  }

  // Split the packed node into its seed and control bit.
  bool ctrl = util::GetLsb(st);
  int4 seed = st;
  seed = util::SetLsb(seed, false);

  // Unpack this level's CW; the left control-bit CW is stored in the LSB of the seed CW.
  Cw cw = cws[i];
  int4 s_cw = cw.s;
  bool tl_cw = util::GetLsb(s_cw);
  s_cw = util::SetLsb(s_cw, false);
  bool tr_cw = cw.tr;

  // Expand into children and apply the CWs when the control bit is set.
  auto [left_seed, right_seed] = prg.Gen(seed);
  bool left_t = util::GetLsb(left_seed);
  left_seed = util::SetLsb(left_seed, false);
  bool right_t = util::GetLsb(right_seed);
  right_seed = util::SetLsb(right_seed, false);
  if (ctrl) {
    left_seed = util::Xor(left_seed, s_cw);
    right_seed = util::Xor(right_seed, s_cw);
    left_t ^= tl_cw;
    right_t ^= tr_cw;
  }

  // Repack each child with its own control bit. These are plain locals (not structured
  // bindings), so they can be captured by the OpenMP tasks below.
  int4 left = util::SetLsb(left_seed, left_t);
  int4 right = util::SetLsb(right_seed, right_t);
  size_t mid = l + (r - l) / 2;
  int next = i + 1;

  if (i >= par_depth_) {
    EvalTree(left, cws, ys, l, mid, next, par_depth_);
    EvalTree(right, cws, ys, mid, r, next, par_depth_);
    return;
  }
#pragma omp task
  EvalTree(left, cws, ys, l, mid, next, par_depth_);
#pragma omp task
  EvalTree(right, cws, ys, mid, r, next, par_depth_);
#pragma omp taskwait
}
401};
402
403} // namespace fss
Two-party verifiable distributed point function (VDPF) scheme.
Definition vdpf.cuh:68
int Gen(Cw cws[], cuda::std::array< int4, 4 > &cs, int4 &ocw, cuda::std::span< const int4, 2 > s0s, In a, int4 b_buf)
Key generation method.
Definition vdpf.cuh:102
void Prove(cuda::std::span< const cuda::std::array< int4, 4 > > pi_tildes, cuda::std::span< const int4, 4 > cs, cuda::std::array< int4, 4 > &pi)
Proof accumulation method.
Definition vdpf.cuh:256
cuda::std::array< int4, 4 > Eval(bool b, int4 s0, cuda::std::span< const Cw > cws, cuda::std::span< const int4, 4 > cs, int4 ocw, In x, int4 &y)
Evaluation method.
Definition vdpf.cuh:191
static bool Verify(cuda::std::span< const int4, 4 > pi0, cuda::std::span< const int4, 4 > pi1)
Verification method.
Definition vdpf.cuh:273
void EvalAll(bool b, int4 s0, cuda::std::span< const Cw > cws, cuda::std::span< const int4, 4 > cs, int4 ocw, cuda::std::span< int4 > ys, cuda::std::array< int4, 4 > &pi)
Full domain evaluation method.
Definition vdpf.cuh:299
Group interface.
Definition group.cuh:40
Collision-resistant hash interface.
Definition hash.cuh:19
Pseudorandom generator (PRG) interface.
Definition prg.cuh:21
Collision-resistant and XOR-collision-resistant hash interface.
Definition hash.cuh:27
Correction word.
Definition vdpf.cuh:82