myl7/fss 1.1.0
Function secret sharing (FSS) primitives including distributed point/comparison function (DPF/DCF)
Loading...
Searching...
No Matches
vdpf.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
35#pragma once
36#include <cuda_runtime.h>
37#include <cuda/std/array>
38#include <cuda/std/span>
39#include <cuda/std/tuple>
40#include <type_traits>
41#include <cstddef>
42#include <cassert>
43#include <omp.h>
44#include <fss/group.cuh>
45#include <fss/prg.cuh>
46#include <fss/hash.cuh>
47#include <fss/util.cuh>
48
49namespace fss {
50
63template <int in_bits, typename Group, typename Prg, typename XorHash, typename Hash, typename In = uint,
64 int par_depth = -1>
65 requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) && in_bits <= sizeof(In) * 8 &&
67class Vdpf {
68public:
69 Prg prg;
70 XorHash xor_hash;
71 Hash hash;
72
// Correction word for one tree level: written by Gen (one per input bit) and read back by
// Eval and EvalTree. Declared with 32-byte alignment.
81 struct __align__(32) Cw {
// Seed correction. Its LSB carries the left control-bit correction: Gen stores it there with
// util::SetLsb, and the evaluators read it with util::GetLsb before clearing it again.
82 int4 s;
// Right control-bit correction.
83 bool tr;
84 };
85 // Each Cw must occupy exactly 32 bytes, for only 1 and aligned memory access on GPU
86 static_assert(sizeof(Cw) == 32);
87
/**
 * Key generation method.
 *
 * Walks the in_bits levels of the tree along the path selected by a (most significant of the
 * low in_bits bits first), emitting one correction word per level, then derives the
 * verification correction cs and the output correction word ocw from the two final seeds.
 *
 * @param cws   Output. Entries [0, in_bits) are written.
 * @param cs    Output. XOR of xor_hash.Hash over (packed a, final seed) of the two parties.
 *              Written even when 1 is returned.
 * @param ocw   Output correction word. Written only when 0 is returned.
 * @param s0s   Initial seeds of the two parties. Their LSBs are cleared before use.
 * @param a     Point whose path is followed.
 * @param b_buf Payload. Its LSB is cleared before it is converted with Group::From.
 * @return 1 when both final control bits are equal (the retry condition), 0 otherwise.
 */
__host__ __device__ int Gen(
    Cw cws[], cuda::std::array<int4, 4> &cs, int4 &ocw, cuda::std::span<const int4, 2> s0s, In a, int4 b_buf) {
    // Running seed (LSB cleared) and control bit of each party.
    int4 seed0 = s0s[0];
    int4 seed1 = s0s[1];
    seed0 = util::SetLsb(seed0, false);
    seed1 = util::SetLsb(seed1, false);
    bool ctl0 = false;
    bool ctl1 = true;
    b_buf = util::SetLsb(b_buf, false);

    for (int level = 0; level < in_bits; ++level) {
        auto [left0, right0] = prg.Gen(seed0);
        auto [left1, right1] = prg.Gen(seed1);

        // The LSB of every expanded child is its control bit; strip it from the seed part.
        bool left_t0 = util::GetLsb(left0);
        bool right_t0 = util::GetLsb(right0);
        bool left_t1 = util::GetLsb(left1);
        bool right_t1 = util::GetLsb(right1);
        left0 = util::SetLsb(left0, false);
        right0 = util::SetLsb(right0, false);
        left1 = util::SetLsb(left1, false);
        right1 = util::SetLsb(right1, false);

        bool go_right = (a >> (in_bits - 1 - level)) & 1;

        // Seed correction: XOR of both parties' children on the side the path does NOT take.
        int4 seed_cw = go_right ? util::Xor(left0, left1) : util::Xor(right0, right1);

        bool left_cw = left_t0 ^ left_t1 ^ go_right ^ 1;
        bool right_cw = right_t0 ^ right_t1 ^ go_right;

        // Children and control-bit correction on the side the path DOES take.
        int4 next0 = go_right ? right0 : left0;
        int4 next1 = go_right ? right1 : left1;
        bool next_t0 = go_right ? right_t0 : left_t0;
        bool next_t1 = go_right ? right_t1 : left_t1;
        bool kept_cw = go_right ? right_cw : left_cw;

        // A party applies the corrections only while its current control bit is set.
        if (ctl0) {
            next0 = util::Xor(next0, seed_cw);
            next_t0 = next_t0 ^ kept_cw;
        }
        if (ctl1) {
            next1 = util::Xor(next1, seed_cw);
            next_t1 = next_t1 ^ kept_cw;
        }
        seed0 = next0;
        seed1 = next1;
        ctl0 = next_t0;
        ctl1 = next_t1;

        // The left control-bit correction travels in the LSB of the seed correction.
        seed_cw = util::SetLsb(seed_cw, left_cw);
        cws[level] = {seed_cw, right_cw};
    }

    // Verification hash
    int4 a_buf = util::Pack(a);

    auto tilde0 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, seed0});
    auto tilde1 = xor_hash.Hash(cuda::std::tuple<int4, const int4>{a_buf, seed1});
    cs = util::Xor(cuda::std::span<const int4, 4>(tilde0), cuda::std::span<const int4, 4>(tilde1));

    // Check retry condition
    if (ctl0 == ctl1) return 1;

    // Output correction word
    auto out_cw = Group::From(b_buf) + (-Group::From(seed0)) + Group::From(seed1);
    if (ctl1) out_cw = -out_cw;
    ocw = out_cw.Into();

    return 0;
}
176
/**
 * Evaluation method for a single input.
 *
 * @param b   Party index; also the initial control bit.
 * @param s0  This party's initial seed. Its LSB is cleared before use.
 * @param cws Correction words; entries [0, in_bits) are read.
 * @param cs  Verification correction, XORed into the returned hash when the final control
 *            bit is set.
 * @param ocw Output correction word. Its LSB (checked as ocw.w & 1) is asserted to be 0.
 * @param x   Input to evaluate; its low in_bits bits are consumed most significant first.
 * @param y   Output. This party's share, negated when b is set.
 * @return The (possibly corrected) verification hash of (packed x, final seed).
 */
__host__ __device__ cuda::std::array<int4, 4> Eval(
    bool b, int4 s0, cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs, int4 ocw, In x, int4 &y) {
    int4 seed = s0;
    seed = util::SetLsb(seed, false);
    bool ctl = b;

    for (int level = 0; level < in_bits; ++level) {
        // Unpack this level's correction word; the left control-bit correction sits in the LSB.
        Cw cw = cws[level];
        int4 seed_cw = cw.s;
        bool left_cw = util::GetLsb(seed_cw);
        bool right_cw = cw.tr;
        seed_cw = util::SetLsb(seed_cw, false);

        auto [left, right] = prg.Gen(seed);

        bool left_t = util::GetLsb(left);
        bool right_t = util::GetLsb(right);
        left = util::SetLsb(left, false);
        right = util::SetLsb(right, false);

        // Corrections are applied only while the current control bit is set.
        if (ctl) {
            left = util::Xor(left, seed_cw);
            right = util::Xor(right, seed_cw);
            left_t = left_t ^ left_cw;
            right_t = right_t ^ right_cw;
        }

        bool go_right = (x >> (in_bits - 1 - level)) & 1;
        seed = go_right ? right : left;
        ctl = go_right ? right_t : left_t;
    }

    // Output share
    auto share = Group::From(seed);
    assert((ocw.w & 1) == 0);
    if (ctl) share = share + Group::From(ocw);
    if (b) share = -share;
    y = share.Into();

    // Corrected verification hash
    int4 x_buf = util::Pack(x);

    auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, seed});
    if (!ctl) return pi_tilde;
    return util::Xor(cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
}
243
/**
 * Proof accumulation method.
 *
 * Starts the proof from cs and folds every entry of pi_tildes into it, in order: the running
 * proof is XORed with the entry, hashed with hash.Hash, and the first two blocks of the hash
 * output are XORed into the first two blocks of the running proof.
 *
 * @param pi_tildes Verification hashes to accumulate, e.g. the values returned by Eval.
 * @param cs        Verification correction; the initial value of the proof.
 * @param pi        Output. The accumulated proof.
 */
void Prove(cuda::std::span<const cuda::std::array<int4, 4>> pi_tildes, cuda::std::span<const int4, 4> cs,
    cuda::std::array<int4, 4> &pi) {
    pi = {cs[0], cs[1], cs[2], cs[3]};
    for (const cuda::std::array<int4, 4> &pi_tilde : pi_tildes) {
        cuda::std::array<int4, 4> mixed =
            util::Xor(cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tilde));
        auto digest = hash.Hash(cuda::std::span<const int4, 4>(mixed));
        // Only blocks 0 and 1 of the proof are updated per step.
        for (int k = 0; k < 2; ++k) {
            pi[k] = util::Xor(pi[k], digest[k]);
        }
    }
}
264
/**
 * Verification method.
 *
 * @param pi0 Proof of one party.
 * @param pi1 Proof of the other party.
 * @return true when all four blocks of both proofs match in every component (x, y, z, w),
 *         false as soon as one block differs.
 */
__host__ __device__ static bool Verify(cuda::std::span<const int4, 4> pi0, cuda::std::span<const int4, 4> pi1) {
    for (int k = 0; k < 4; ++k) {
        const int4 &lhs = pi0[k];
        const int4 &rhs = pi1[k];
        const bool same = lhs.x == rhs.x && lhs.y == rhs.y && lhs.z == rhs.z && lhs.w == rhs.w;
        if (!same) return false;
    }
    return true;
}
276
/**
 * Full domain evaluation method.
 *
 * Evaluates all 2^in_bits inputs and accumulates the proof over them in increasing input
 * order. Runs in two phases: an OpenMP task-parallel tree expansion (EvalTree) that parks
 * each leaf's seed, with its control bit packed into the LSB, in ys; then a sequential pass
 * that turns every leaf into an output share and folds its verification hash into pi.
 *
 * Instantiating this method requires 2^in_bits to be representable in size_t; this is
 * enforced at compile time because the domain size is computed with a shift by in_bits.
 *
 * @param b   Party index; also the initial control bit.
 * @param s0  This party's initial seed.
 * @param cws Correction words; must hold at least in_bits entries.
 * @param cs  Verification correction; the initial value of the proof.
 * @param ocw Output correction word. Its LSB (checked as ocw.w & 1) is asserted to be 0.
 * @param ys  Output. Must hold at least 2^in_bits entries; ys[j] receives the share for
 *            input j. Used as scratch space during the tree expansion.
 * @param pi  Output. The accumulated proof.
 */
void EvalAll(bool b, int4 s0, cuda::std::span<const Cw> cws, cuda::std::span<const int4, 4> cs, int4 ocw,
    cuda::std::span<int4> ys, cuda::std::array<int4, 4> &pi) {
    // in_bits is a compile-time constant, so reject an undefined shift at compile time instead
    // of relying on a runtime assert that disappears under NDEBUG.
    static_assert(in_bits >= 0 && static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
        "EvalAll requires 2^in_bits to be representable in size_t");
    const size_t n = 1ULL << in_bits;

    // Preconditions: EvalTree reads cws[0..in_bits) and writes ys[0..n).
    assert(cws.size() >= static_cast<size_t>(in_bits));
    assert(ys.size() >= n);
    assert((ocw.w & 1) == 0);

    // Root node: seed with the control bit packed into its LSB.
    int4 st = s0;
    bool t = b;
    st = util::SetLsb(st, t);

    size_t l = 0;
    size_t r = n;
    int i = 0;

    int par_depth_ = util::ResolveParDepth(par_depth);

    // Phase 1: tree traversal, store (s, t) packed into ys temporarily
#pragma omp parallel
#pragma omp single
    EvalTree(st, cws, ys, l, r, i, par_depth_);

    // Phase 2: sequential output computation and proof accumulation
    pi = {cs[0], cs[1], cs[2], cs[3]};
    auto ocw_group = Group::From(ocw);
    for (size_t j = 0; j < n; ++j) {
        // Unpack the leaf left behind by phase 1.
        int4 sj = ys[j];
        bool tj = util::GetLsb(sj);
        sj = util::SetLsb(sj, false);

        // Output share
        auto g = Group::From(sj);
        if (tj) g = g + ocw_group;
        if (b) g = -g;
        ys[j] = g.Into();

        // Proof accumulation
        int4 x_buf = util::Pack(static_cast<In>(j));

        auto pi_tilde = xor_hash.Hash(cuda::std::tuple<int4, const int4>{x_buf, sj});
        if (tj) {
            pi_tilde = util::Xor(cuda::std::span<const int4, 4>(pi_tilde), cuda::std::span<const int4, 4>(cs));
        }

        cuda::std::array<int4, 4> h_input =
            util::Xor(cuda::std::span<const int4, 4>(pi), cuda::std::span<const int4, 4>(pi_tilde));
        auto h_out = hash.Hash(cuda::std::span<const int4, 4>(h_input));
        // Only blocks 0 and 1 of the proof are updated per step.
        pi[0] = util::Xor(pi[0], h_out[0]);
        pi[1] = util::Xor(pi[1], h_out[1]);
    }
}
342
343private:
/**
 * Recursive tree expansion used by EvalAll.
 *
 * Expands the node st (seed with its control bit packed into the LSB) at depth i, which
 * covers the leaf range [l, r). At depth in_bits the packed node is stored into ys[l].
 * Levels with i < par_depth_ spawn one OpenMP task per child and wait for both; deeper
 * levels recurse sequentially.
 *
 * @param st         Packed node: seed with the control bit in its LSB.
 * @param cws        Correction words; cws[i] is read at depth i.
 * @param ys         Output. Receives the packed leaves.
 * @param l          First leaf index covered by this node.
 * @param r          One past the last leaf index covered by this node.
 * @param i          Depth of this node.
 * @param par_depth_ Depth below which children are expanded as OpenMP tasks.
 */
void EvalTree(
    int4 st, cuda::std::span<const Cw> cws, cuda::std::span<int4> ys, size_t l, size_t r, int i, int par_depth_) {
    // Leaf: park the packed (seed, control bit) for the caller's second phase.
    if (i == in_bits) {
        assert(l + 1 == r);
        ys[l] = st;
        return;
    }

    // Unpack the node.
    bool ctl = util::GetLsb(st);
    int4 seed = st;
    seed = util::SetLsb(seed, false);

    // Unpack this level's correction word; the left control-bit correction sits in the LSB.
    Cw cw = cws[i];
    int4 seed_cw = cw.s;
    bool left_cw = util::GetLsb(seed_cw);
    bool right_cw = cw.tr;
    seed_cw = util::SetLsb(seed_cw, false);

    auto [left, right] = prg.Gen(seed);

    bool left_t = util::GetLsb(left);
    bool right_t = util::GetLsb(right);
    left = util::SetLsb(left, false);
    right = util::SetLsb(right, false);

    // Corrections are applied only when this node's control bit is set.
    if (ctl) {
        left = util::Xor(left, seed_cw);
        right = util::Xor(right, seed_cw);
        left_t = left_t ^ left_cw;
        right_t = right_t ^ right_cw;
    }

    // Re-pack each child into an ordinary local before handing it to the child call.
    int4 left_packed = left;
    left_packed = util::SetLsb(left_packed, left_t);
    int4 right_packed = right;
    right_packed = util::SetLsb(right_packed, right_t);

    size_t mid = (l + r) / 2;

    if (i >= par_depth_) {
        EvalTree(left_packed, cws, ys, l, mid, i + 1, par_depth_);
        EvalTree(right_packed, cws, ys, mid, r, i + 1, par_depth_);
    } else {
#pragma omp task
        EvalTree(left_packed, cws, ys, l, mid, i + 1, par_depth_);
#pragma omp task
        EvalTree(right_packed, cws, ys, mid, r, i + 1, par_depth_);
#pragma omp taskwait
    }
}
394};
395
396} // namespace fss
2-party VDPF scheme.
Definition vdpf.cuh:67
int Gen(Cw cws[], cuda::std::array< int4, 4 > &cs, int4 &ocw, cuda::std::span< const int4, 2 > s0s, In a, int4 b_buf)
Key generation method.
Definition vdpf.cuh:101
void Prove(cuda::std::span< const cuda::std::array< int4, 4 > > pi_tildes, cuda::std::span< const int4, 4 > cs, cuda::std::array< int4, 4 > &pi)
Proof accumulation method.
Definition vdpf.cuh:253
cuda::std::array< int4, 4 > Eval(bool b, int4 s0, cuda::std::span< const Cw > cws, cuda::std::span< const int4, 4 > cs, int4 ocw, In x, int4 &y)
Evaluation method.
Definition vdpf.cuh:189
static bool Verify(cuda::std::span< const int4, 4 > pi0, cuda::std::span< const int4, 4 > pi1)
Verification method.
Definition vdpf.cuh:270
void EvalAll(bool b, int4 s0, cuda::std::span< const Cw > cws, cuda::std::span< const int4, 4 > cs, int4 ocw, cuda::std::span< int4 > ys, cuda::std::array< int4, 4 > &pi)
Full domain evaluation method.
Definition vdpf.cuh:293
Group interface.
Definition group.cuh:40
Collision-resistant hash interface.
Definition hash.cuh:19
Pseudorandom generator (PRG) interface.
Definition prg.cuh:21
Collision-resistant and XOR-collision-resistant hash interface.
Definition hash.cuh:27
Correction word.
Definition vdpf.cuh:81