myl7/fss 1.1.0
Function secret sharing (FSS) primitives including distributed point/comparison function (DPF/DCF)
Loading...
Searching...
No Matches
grotto_dcf.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
23#pragma once
24#include <cuda_runtime.h>
25#include <type_traits>
26#include <cstddef>
27#include <cassert>
28#include <omp.h>
29#include <fss/dpf.cuh>
30#include <fss/group/bytes.cuh>
31#include <fss/prg.cuh>
32#include <fss/util.cuh>
33
34namespace fss {
35
45template <int in_bits, typename Prg, typename In = uint, int par_depth = -1>
46 requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
47 in_bits <= sizeof(In) * 8 && Prgable<Prg, 2>)
48class GrottoDcf {
50
51public:
52 using Cw = typename DpfType::Cw;
53 Prg prg;
54
/**
 * Key generation.
 *
 * Wraps this object's PRG in a DpfType instance and delegates to its Gen
 * with an all-zero int4 payload.
 *
 * @param cws Correction-word array handed to DpfType::Gen (presumably filled
 *            by it; the required length is not visible here).
 * @param s0s Two initial seeds, forwarded unchanged (presumably one per party).
 * @param a Forwarded unchanged as the third argument of DpfType::Gen.
 */
__host__ __device__ void Gen(Cw cws[], const int4 s0s[2], In a) {
    int4 zero_payload{0, 0, 0, 0};
    DpfType point_fn{prg};
    point_fn.Gen(cws, s0s, a, zero_payload);
}
69
/**
 * Parity segment tree over leaf control bits.
 *
 * With N = 2^in_bits, Preprocess fills p[0 .. 2N-2]: the N leaf control bits
 * live at p[N-1 .. 2N-2], and every internal node j in [0, N-2] holds
 * p[2j+1] ^ p[2j+2]. The struct does not own or size the buffer; the caller
 * must make p point at storage for at least 2N - 1 bools.
 */
struct ParityTree {
    /// Node array in implicit binary-heap order (children of j are 2j+1, 2j+2).
    bool *p;
    /// Bit that Preprocess passes to the tree expansion, where it becomes the
    /// LSB of the root seed (presumably the party index -- confirm with caller).
    bool b;
};
83
/**
 * Preprocess: expand the DPF tree and build the parity segment tree.
 *
 * Reads pt.b and overwrites pt.p[0 .. 2N-2], where N = 2^in_bits.
 *
 * @param pt Tree to fill; pt.p must already point at 2N - 1 writable bools
 *           and pt.b must already be set.
 * @param s0 Initial seed handed to the tree expansion.
 * @param cws Correction words handed to the tree expansion.
 */
void Preprocess(ParityTree &pt, int4 s0, const Cw cws[]) {
    constexpr size_t kLeaves = 1ULL << in_bits;
    constexpr size_t kInternal = kLeaves - 1;
    bool *const nodes = pt.p;

    // Step 1: the leaf control bits go into the last N slots of the array.
    ExpandTree(pt.b, s0, cws, nodes + kInternal);

    // Step 2: fill internal nodes from the deepest one up to the root, so both
    // children are final before their parent is computed.
    for (size_t node = kInternal; node-- > 0;) {
        const size_t left = 2 * node + 1;
        nodes[node] = nodes[left] ^ nodes[left + 1];
    }
}
106
/**
 * Prefix-parity query on the parity segment tree.
 *
 * Returns the XOR of the leaf control bits at positions [0 .. x], obtained by
 * walking from the root towards leaf x + 1 and accumulating the parity of
 * every left subtree that lies entirely inside the prefix.
 *
 * @param pt Tree previously filled by Preprocess.
 * @param x Query point. Only the low in_bits bits of x + 1 drive the walk, so
 *          x is assumed to lie inside the 2^in_bits domain.
 * @return Parity of the leaf control bits in [0 .. x].
 */
__host__ __device__ static bool Eval(const ParityTree &pt, In x) {
    constexpr size_t kDomain = 1ULL << in_bits;
    const In end = static_cast<In>(x) + 1;

    // The prefix covers the whole domain (either x + 1 wrapped to 0 or it
    // equals the domain size): the root already stores that parity.
    if (end == 0 || end == kDomain) {
        return pt.p[0];
    }

    bool parity = false;
    size_t node = 0;
    // Consume the bits of `end` from most to least significant.
    for (int shift = in_bits - 1; shift >= 0; --shift) {
        const size_t left = 2 * node + 1;
        if ((end >> shift) & 1) {
            // Left subtree is fully inside the prefix: take its parity and
            // continue in the right subtree.
            parity ^= pt.p[left];
            node = left + 1;
        } else {
            node = left;
        }
    }
    return parity;
}
137
/**
 * Full domain evaluation.
 *
 * Expands the tree so that ys[x] holds leaf x's control bit, then turns ys
 * into a running XOR in place, so that on return ys[x] is the XOR of the
 * control bits of leaves [0 .. x].
 *
 * @param b Bit placed in the LSB of the root seed by the tree expansion.
 * @param s0 Initial seed.
 * @param cws Correction words.
 * @param ys Output buffer; must hold 2^in_bits bools.
 */
void EvalAll(bool b, int4 s0, const Cw cws[], bool ys[]) {
    constexpr size_t kDomain = 1ULL << in_bits;

    // Leaf control bits land directly in the output buffer.
    ExpandTree(b, s0, cws, ys);

    // Sequential inclusive XOR scan over the leaves.
    bool running = false;
    for (size_t x = 0; x < kDomain; ++x) {
        running ^= ys[x];
        ys[x] = running;
    }
}
165
166private:
/**
 * Expand the whole tree and record every leaf's control bit.
 *
 * Sets the LSB of the initial seed to b, then runs the recursive expansion
 * over the leaf range [0, 2^in_bits) inside an OpenMP parallel region in
 * which a single thread starts the recursion.
 *
 * @param b Bit stored in the LSB of the root seed.
 * @param s0 Initial seed.
 * @param cws Correction words, indexed by tree level.
 * @param t Output; t[x] receives the control bit of leaf x, so it must hold
 *          2^in_bits bools.
 */
void ExpandTree(bool b, int4 s0, const Cw cws[], bool t[]) {
    // in_bits is a template parameter, so the shift-width precondition is
    // checked at compile time rather than by a runtime assert that disappears
    // under NDEBUG and compares a signed int against size_t.
    static_assert(in_bits >= 0 && static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
        "in_bits must be non-negative and smaller than the bit width of size_t");

    int4 st = s0;
    st = util::SetLsb(st, b);

    size_t l = 0;
    size_t r = 1ULL << in_bits;
    int i = 0;

    int par_depth_ = util::ResolveParDepth(par_depth);

#pragma omp parallel
#pragma omp single
    ExpandTreeRec(st, cws, t, l, r, i, par_depth_);
}
190
/**
 * Recursive worker for ExpandTree.
 *
 * Handles the node at depth i whose subtree covers leaves [l, r). At a leaf
 * (i == in_bits) it stores the node's control bit in t[l]; otherwise it
 * derives both children and recurses into each half of the range.
 *
 * @param st Node state: seed with the control bit stored in its LSB.
 * @param cws Correction words; cws[i] is applied at depth i.
 * @param t Output array of leaf control bits.
 * @param l First leaf index covered by this node.
 * @param r One past the last leaf index covered by this node.
 * @param i Depth of this node (root is 0).
 * @param par_depth_ Depths below this value spawn their two children as
 *                   OpenMP tasks; deeper levels recurse sequentially.
 */
void ExpandTreeRec(
    int4 st, const Cw cws[], bool t[], size_t l, size_t r, int i, int par_depth_) {
    const bool ctrl = util::GetLsb(st);

    // Leaf: the range has collapsed to one index; emit its control bit.
    if (i == in_bits) {
        assert(l + 1 == r);
        t[l] = ctrl;
        return;
    }

    // Seed with the control bit masked out.
    int4 seed = util::SetLsb(st, false);

    // Split this level's correction word: the seed part carries the left
    // control-bit correction in its LSB, the right one is stored separately.
    const Cw &level_cw = cws[i];
    const bool left_fix = util::GetLsb(level_cw.s);
    const int4 seed_fix = util::SetLsb(level_cw.s, false);
    const bool right_fix = level_cw.tr;

    auto [sl, sr] = prg.Gen(seed);

    // Separate each child's control bit from its seed.
    bool tl = util::GetLsb(sl);
    bool tr = util::GetLsb(sr);
    sl = util::SetLsb(sl, false);
    sr = util::SetLsb(sr, false);

    // The correction word is applied only when this node's control bit is set.
    if (ctrl) {
        sl = util::Xor(sl, seed_fix);
        sr = util::Xor(sr, seed_fix);
        tl ^= left_fix;
        tr ^= right_fix;
    }

    // Re-pack each child as seed + control bit in plain variables; the tasks
    // below use these copies rather than the structured bindings.
    int4 left_st = util::SetLsb(sl, tl);
    int4 right_st = util::SetLsb(sr, tr);

    const size_t mid = l + (r - l) / 2;
    const int next = i + 1;

    if (i >= par_depth_) {
        ExpandTreeRec(left_st, cws, t, l, mid, next, par_depth_);
        ExpandTreeRec(right_st, cws, t, mid, r, next, par_depth_);
        return;
    }

#pragma omp task
    ExpandTreeRec(left_st, cws, t, l, mid, next, par_depth_);
#pragma omp task
    ExpandTreeRec(right_st, cws, t, mid, r, next, par_depth_);
#pragma omp taskwait
}
241};
242
243} // namespace fss
2-party DPF scheme.
Definition dpf.cuh:64
void Gen(Cw cws[], const int4 s0s[2], In a, int4 b_buf)
Key generation method.
Definition dpf.cuh:93
2-party DCF scheme over F2 from standard DPF (Grotto construction).
Definition grotto_dcf.cuh:48
void Gen(Cw cws[], const int4 s0s[2], In a)
Key generation method.
Definition grotto_dcf.cuh:64
static bool Eval(const ParityTree &pt, In x)
Prefix-parity query on the parity segment tree.
Definition grotto_dcf.cuh:117
void Preprocess(ParityTree &pt, int4 s0, const Cw cws[])
Preprocess: expand DPF tree and build parity segment tree.
Definition grotto_dcf.cuh:95
void EvalAll(bool b, int4 s0, const Cw cws[], bool ys[])
Full domain evaluation.
Definition grotto_dcf.cuh:152
Pseudorandom generator (PRG) interface.
Definition prg.cuh:21
2-party distributed point function (DPF).
Correction word.
Definition dpf.cuh:76
Parity segment tree over leaf control bits.
Definition grotto_dcf.cuh:79