myl7/fss 1.1.0
Function secret sharing (FSS) primitives, including distributed point functions (DPF) and distributed comparison functions (DCF)
Loading...
Searching...
No Matches
dpf.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
39#pragma once
40#include <cuda_runtime.h>
41#include <type_traits>
42#include <cstddef>
43#include <cassert>
44#include <omp.h>
45#include <fss/group.cuh>
46#include <fss/prg.cuh>
47#include <fss/util.cuh>
48
49namespace fss {
50
/// 2-party DPF scheme.
///
/// Key generation (Gen) produces `in_bits + 1` correction words shared by both
/// parties; each party evaluates with its own seed and its party flag `b`
/// (Gen starts party 0 with control bit `false` and party 1 with `true`).
///
/// NOTE(review): the construction looks like the usual tree-based DPF in which
/// the two parties' outputs combine in `Group` to the payload at the special
/// input and to the identity elsewhere -- confirm against the project docs;
/// nothing in this class checks that property.
///
/// @tparam in_bits   Number of input bits walked, from bit `in_bits - 1` down to bit 0.
/// @tparam Group     Output group; must satisfy `Groupable`. Used here through
///                   `Group::From(int4)`, `Into()`, binary `+` and unary `-`.
/// @tparam Prg       Seed expander; must satisfy `Prgable<Prg, 2>`. `prg.Gen(seed)`
///                   is destructured into a left and a right child seed.
/// @tparam In        Input integer type (unsigned, or `__uint128_t`) with at
///                   least `in_bits` bits.
/// @tparam par_depth Passed to `util::ResolveParDepth`; tree levels below the
///                   resolved depth are evaluated as OpenMP tasks in EvalAll.
template <int in_bits, typename Group, typename Prg, typename In = uint, int par_depth = -1>
    requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
        in_bits <= sizeof(In) * 8 && Groupable<Group> && Prgable<Prg, 2>)
class Dpf {
public:
    /// PRG instance used for every seed expansion.
    Prg prg;

    /// Correction word.
    ///
    /// For levels `0 .. in_bits - 1`, `s` holds the seed correction with the
    /// left control-bit correction stored in its least significant bit, and
    /// `tr` holds the right control-bit correction. For the final entry
    /// (`cws[in_bits]`), `s` holds the output correction and `tr` is `false`.
    struct __align__(32) Cw {
        int4 s;
        bool tr;
    };
    // Padded to exactly 32 bytes so that a correction word can be fetched with
    // a single aligned memory access on the GPU.
    static_assert(sizeof(Cw) == 32);

    /// Key generation method.
    ///
    /// @param cws    Output array of `in_bits + 1` correction words.
    /// @param s0s    Initial seeds of party 0 and party 1. The least significant
    ///               bit of each seed is cleared before use.
    /// @param a      Special input; its bits are consumed from bit `in_bits - 1`
    ///               down to bit 0.
    /// @param b_buf  Payload buffer, converted with `Group::From` after its
    ///               least significant bit has been cleared.
    __host__ __device__ void Gen(Cw cws[], const int4 s0s[2], In a, int4 b_buf) {
        int4 s0 = s0s[0];
        s0 = util::SetLsb(s0, false);
        int4 s1 = s0s[1];
        s1 = util::SetLsb(s1, false);
        // The two parties start with different control bits.
        bool t0 = false;
        bool t1 = true;
        b_buf = util::SetLsb(b_buf, false);

        for (int i = 0; i < in_bits; ++i) {
            auto [s0l, s0r] = prg.Gen(s0);
            auto [s1l, s1r] = prg.Gen(s1);

            // The LSB of every expanded seed is split off as a control bit.
            bool t0l = util::GetLsb(s0l);
            s0l = util::SetLsb(s0l, false);
            bool t0r = util::GetLsb(s0r);
            s0r = util::SetLsb(s0r, false);
            bool t1l = util::GetLsb(s1l);
            s1l = util::SetLsb(s1l, false);
            bool t1r = util::GetLsb(s1r);
            s1r = util::SetLsb(s1r, false);

            bool a_bit = (a >> (in_bits - 1 - i)) & 1;

            // The seed correction is built from the children on the side NOT
            // selected by a_bit.
            int4 s_cw = a_bit ? util::Xor(s0l, s1l) : util::Xor(s0r, s1r);

            bool tl_cw = t0l ^ t1l ^ a_bit ^ 1;
            bool tr_cw = t0r ^ t1r ^ a_bit;

            // Children and control-bit correction on the side selected by a_bit.
            int4 s0_keep = a_bit ? s0r : s0l;
            int4 s1_keep = a_bit ? s1r : s1l;
            bool t0_keep = a_bit ? t0r : t0l;
            bool t1_keep = a_bit ? t1r : t1l;
            bool t_cw_keep = a_bit ? tr_cw : tl_cw;

            // Seeds first: they must be corrected with the control bits of the
            // current level, before t0/t1 are overwritten below.
            s0 = t0 ? util::Xor(s0_keep, s_cw) : s0_keep;
            s1 = t1 ? util::Xor(s1_keep, s_cw) : s1_keep;
            t0 = t0_keep ^ (t0 && t_cw_keep);
            t1 = t1_keep ^ (t1 && t_cw_keep);

            // Pack the left control-bit correction into the LSB of the seed
            // correction; the right one goes into the separate `tr` field.
            s_cw = util::SetLsb(s_cw, tl_cw);
            cws[i] = {s_cw, tr_cw};
        }

        // Output correction word, negated when party 1 ends with its control bit set.
        auto v_cw_np1 = Group::From(b_buf) + (-Group::From(s0)) + Group::From(s1);
        if (t1) v_cw_np1 = -v_cw_np1;
        cws[in_bits] = {v_cw_np1.Into(), false};
    }

    /// Evaluation method.
    ///
    /// @param b    Party flag (`false` for party 0, `true` for party 1); used as
    ///             the initial control bit and to negate the result for party 1.
    /// @param s0   This party's initial seed; its least significant bit is
    ///             cleared before use.
    /// @param cws  The `in_bits + 1` correction words written by Gen.
    /// @param x    Evaluation point; bits are consumed from bit `in_bits - 1`
    ///             down to bit 0.
    /// @return     This party's output share, as produced by `Group::Into()`.
    __host__ __device__ int4 Eval(bool b, int4 s0, const Cw cws[], In x) {
        int4 s = util::SetLsb(s0, false);
        bool t = b;

        for (int i = 0; i < in_bits; ++i) {
            const Cw cw = cws[i];
            Children c = Expand(s, t, cw);

            // Descend to the right child on a 1 bit, to the left child on a 0 bit.
            bool x_bit = (x >> (in_bits - 1 - i)) & 1;
            s = x_bit ? c.sr : c.sl;
            t = x_bit ? c.tr : c.tl;
        }

        return Finalize(b, s, t, cws[in_bits].s);
    }

    /// Full domain evaluation method.
    ///
    /// Host-only. Walks the whole binary tree once and writes one output share
    /// per leaf, using the same per-level expansion and leaf conversion as Eval;
    /// the leaf reached through the left/right choices of input `x` is stored
    /// at `ys[x]`.
    ///
    /// Requires `in_bits` to be smaller than the bit width of `size_t`
    /// (checked at compile time when this method is instantiated).
    ///
    /// @param b    Party flag, as in Eval.
    /// @param s0   This party's initial seed; its least significant bit is
    ///             overwritten with `b`.
    /// @param cws  The `in_bits + 1` correction words written by Gen.
    /// @param ys   Output array with room for `2^in_bits` elements.
    void EvalAll(bool b, int4 s0, const Cw cws[], int4 ys[]) {
        // in_bits is a compile-time constant, so reject an out-of-range shift at
        // compile time instead of with a runtime assert that NDEBUG would drop.
        static_assert(in_bits < sizeof(size_t) * 8,
            "EvalAll: 2^in_bits must be representable in size_t");

        // The control bit travels in the LSB of the seed through the recursion.
        int4 st = util::SetLsb(s0, b);

        const size_t leaf_count = size_t{1} << in_bits;
        const int par_depth_ = util::ResolveParDepth(par_depth);

        // One thread seeds the recursion; the tasks it spawns are run by the team.
#pragma omp parallel
#pragma omp single
        EvalTree(b, st, cws, ys, 0, leaf_count, 0, par_depth_);
    }

private:
    /// Both children of one tree node after control-bit extraction and correction.
    struct Children {
        int4 sl;
        bool tl;
        int4 sr;
        bool tr;
    };

    /// Expands the node `(s, t)` by one level using correction word `cw`.
    ///
    /// @param s   Node seed with its least significant bit already cleared.
    /// @param t   Node control bit; the correction is applied only when it is set.
    /// @param cw  Correction word of the current level.
    /// @return    Child seeds (LSB cleared before correction) and child control bits.
    __host__ __device__ Children Expand(int4 s, bool t, const Cw &cw) {
        // Unpack the left control-bit correction from the LSB of the seed correction.
        int4 s_cw = cw.s;
        bool tl_cw = util::GetLsb(s_cw);
        s_cw = util::SetLsb(s_cw, false);
        bool tr_cw = cw.tr;

        auto [sl, sr] = prg.Gen(s);

        bool tl = util::GetLsb(sl);
        sl = util::SetLsb(sl, false);
        bool tr = util::GetLsb(sr);
        sr = util::SetLsb(sr, false);

        if (t) {
            sl = util::Xor(sl, s_cw);
            sr = util::Xor(sr, s_cw);
            tl = tl ^ tl_cw;
            tr = tr ^ tr_cw;
        }

        return {sl, tl, sr, tr};
    }

    /// Converts a leaf `(s, t)` into this party's output share.
    ///
    /// @param b         Party flag; the result is negated when set.
    /// @param s         Leaf seed with its least significant bit cleared.
    /// @param t         Leaf control bit; the output correction is added when set.
    /// @param v_cw_np1  Output correction, i.e. `cws[in_bits].s`.
    /// @return          The output share, as produced by `Group::Into()`.
    __host__ __device__ static int4 Finalize(bool b, int4 s, bool t, int4 v_cw_np1) {
        auto y = Group::From(s);
        // The output correction is expected to have bit 0 of its `.w` component clear.
        assert((v_cw_np1.w & 1) == 0);
        if (t) y = y + Group::From(v_cw_np1);
        if (b) y = -y;
        return y.Into();
    }

    /// Recursive worker of EvalAll for the subtree covering outputs `[l, r)`.
    ///
    /// @param b           Party flag.
    /// @param st          Node seed with the node's control bit in its LSB.
    /// @param cws         Correction words written by Gen.
    /// @param ys          Output array; a leaf writes `ys[l]`.
    /// @param l           First output index covered by this subtree.
    /// @param r           One past the last output index covered by this subtree.
    /// @param i           Depth of this node (0 at the root, `in_bits` at a leaf).
    /// @param par_depth_  Nodes with `i < par_depth_` evaluate their children as
    ///                    OpenMP tasks; deeper nodes recurse sequentially.
    void EvalTree(
        bool b, int4 st, const Cw cws[], int4 ys[], size_t l, size_t r, int i, int par_depth_) {
        // Split the packed node state back into control bit and seed.
        bool t = util::GetLsb(st);
        int4 s = util::SetLsb(st, false);

        if (i == in_bits) {
            int4 y = Finalize(b, s, t, cws[in_bits].s);
            assert(l + 1 == r);
            ys[l] = y;
            return;
        }

        const Cw cw = cws[i];
        Children c = Expand(s, t, cw);

        // Re-pack each child's control bit into the LSB of its seed.
        int4 stl = util::SetLsb(c.sl, c.tl);
        int4 str = util::SetLsb(c.sr, c.tr);

        // The left child covers [l, mid), the right child covers [mid, r).
        size_t mid = (l + r) / 2;

        if (i < par_depth_) {
#pragma omp task
            EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_);
#pragma omp task
            EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_);
#pragma omp taskwait
        } else {
            EvalTree(b, stl, cws, ys, l, mid, i + 1, par_depth_);
            EvalTree(b, str, cws, ys, mid, r, i + 1, par_depth_);
        }
    }
};
299
300} // namespace fss
2-party DPF scheme.
Definition dpf.cuh:64
int4 Eval(bool b, int4 s0, const Cw cws[], In x)
Evaluation method.
Definition dpf.cuh:164
void EvalAll(bool b, int4 s0, const Cw cws[], int4 ys[])
Full domain evaluation method.
Definition dpf.cuh:225
void Gen(Cw cws[], const int4 s0s[2], In a, int4 b_buf)
Key generation method.
Definition dpf.cuh:93
Group interface.
Definition group.cuh:40
Pseudorandom generator (PRG) interface.
Definition prg.cuh:21
Correction word.
Definition dpf.cuh:76