myl7/fss 1.1.0
Function secret sharing (FSS) primitives, including distributed point functions (DPF) and distributed comparison functions (DCF)
Loading...
Searching...
No Matches
half_tree_dpf.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
17#pragma once
18#include <cuda_runtime.h>
19#include <type_traits>
20#include <cstddef>
21#include <cassert>
22#include <omp.h>
23#include <fss/group.cuh>
24#include <fss/prg.cuh>
25#include <fss/util.cuh>
26
27namespace fss {
28
39template <int in_bits, typename Group, typename Prg, typename In = uint, int par_depth = -1>
40 requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
41 in_bits <= sizeof(In) * 8 && Groupable<Group> && Prgable<Prg, 1>)
43public:
 // PRG that implements the tree hash: every hash in this class is computed as
 // prg.Gen(util::Xor(hash_key, input))[0].
 44 Prg prg;
 // Key XORed into every hash input before the input is passed to prg.Gen.
 45 int4 hash_key;
 46
 // Correction word. It is 32-byte aligned, and its size is pinned by the
 // static_assert below.
 // - Levels 0 .. in_bits-2: `s` is that level's correction word and `extra` is false.
 // - Level in_bits-1: `s` holds HCW with LCW_0 stored in its LSB, and `extra` holds LCW_1.
 53 struct __align__(32) Cw {
 54 int4 s;
 55 bool extra;
 56 };
 57 static_assert(sizeof(Cw) == 32);
58
68 __host__ __device__ void Gen(Cw cws[], int4 &ocw, const int4 s0s[2], In a, int4 b_buf) {
69 b_buf = util::SetLsb(b_buf, false);
70
71 // Initialize: node0 has t=0, node1 has t=1
72 int4 node0 = util::SetLsb(s0s[0], false);
73 int4 node1 = util::SetLsb(s0s[1], true);
74 int4 delta = util::Xor(node0, node1); // LSB = 0^1 = 1
75
76 // Levels 1 to n-1 (index i = 0 to in_bits-2)
77 for (int i = 0; i < in_bits - 1; ++i) {
78 int4 h0 = prg.Gen(util::Xor(hash_key, node0))[0];
79 int4 h1 = prg.Gen(util::Xor(hash_key, node1))[0];
80
81 bool a_bit = (a >> (in_bits - 1 - i)) & 1;
82
83 // CW = h0 ^ h1 ^ (!a_bit ? delta : 0)
84 // When a_bit=0 (go left): non-alpha is right, CW = h0^h1^delta makes right0=right1
85 // When a_bit=1 (go right): non-alpha is left, CW = h0^h1 makes left0=left1
86 int4 cw = util::Xor(h0, h1);
87 if (!a_bit) cw = util::Xor(cw, delta);
88
89 cws[i] = {cw, false};
90
91 bool t0 = util::GetLsb(node0);
92 bool t1 = util::GetLsb(node1);
93
94 // node_b = h_b ^ (a_bit ? node_b : 0) ^ (t_b ? cw : 0)
95 int4 zero4 = {0, 0, 0, 0};
96 int4 ab_mask0 = a_bit ? node0 : zero4;
97 int4 ab_mask1 = a_bit ? node1 : zero4;
98 int4 t0_mask = t0 ? cw : zero4;
99 int4 t1_mask = t1 ? cw : zero4;
100
101 node0 = util::Xor(util::Xor(h0, ab_mask0), t0_mask);
102 node1 = util::Xor(util::Xor(h1, ab_mask1), t1_mask);
103
104 // delta = node0 ^ node1 for next level
105 delta = util::Xor(node0, node1);
106 }
107
108 // Level n (last level, index i = in_bits-1)
109 {
110 bool a_n = (a >> 0) & 1; // last bit of alpha
111 bool t0 = util::GetLsb(node0);
112 bool t1 = util::GetLsb(node1);
113
114 // Hash with sigma in {0, 1}
115 int4 h0_0 = prg.Gen(util::Xor(hash_key, util::SetLsb(node0, false)))[0];
116 int4 h0_1 = prg.Gen(util::Xor(hash_key, util::SetLsb(node0, true)))[0];
117 int4 h1_0 = prg.Gen(util::Xor(hash_key, util::SetLsb(node1, false)))[0];
118 int4 h1_1 = prg.Gen(util::Xor(hash_key, util::SetLsb(node1, true)))[0];
119
120 // Extract high (s) and low (t) parts
121 int4 high0_0 = util::SetLsb(h0_0, false);
122 bool low0_0 = util::GetLsb(h0_0);
123 int4 high0_1 = util::SetLsb(h0_1, false);
124 bool low0_1 = util::GetLsb(h0_1);
125 int4 high1_0 = util::SetLsb(h1_0, false);
126 bool low1_0 = util::GetLsb(h1_0);
127 int4 high1_1 = util::SetLsb(h1_1, false);
128 bool low1_1 = util::GetLsb(h1_1);
129
130 // HCW corrects the non-alpha direction so both parties converge.
131 // HCW = high{!a_n}_0 ^ high{!a_n}_1
132 int4 HCW;
133 if (a_n) HCW = util::Xor(high0_0, high1_0);
134 else HCW = util::Xor(high0_1, high1_1);
135
136 // LCW ensures:
137 // Alpha direction (sigma=a_n): low0 ^ low1 = 1 (exactly one adds ocw)
138 // Non-alpha direction (sigma=!a_n): low0 ^ low1 = 0 (cancels)
139 // LCW_0 = low{0}_0 ^ low{0}_1 ^ !a_n
140 // LCW_1 = low{1}_0 ^ low{1}_1 ^ a_n
141 bool LCW_0 = low0_0 ^ low1_0 ^ !a_n;
142 bool LCW_1 = low0_1 ^ low1_1 ^ a_n;
143
144 // Store CW_n
145 cws[in_bits - 1] = {util::SetLsb(HCW, LCW_0), LCW_1};
146
147 // Compute leaf for each party
148 // leaf_b = (a_n ? high{1}_b||low{1}_b : high{0}_b||low{0}_b)
149 int4 leaf0, leaf1;
150 if (a_n) {
151 leaf0 = util::SetLsb(high0_1, low0_1);
152 leaf1 = util::SetLsb(high1_1, low1_1);
153 } else {
154 leaf0 = util::SetLsb(high0_0, low0_0);
155 leaf1 = util::SetLsb(high1_0, low1_0);
156 }
157
158 // Apply CW correction: if t_b: leaf_b ^= SetLsb(HCW, lcw_an)
159 bool lcw_an = a_n ? LCW_1 : LCW_0;
160 int4 leaf_cw = util::SetLsb(HCW, lcw_an);
161 if (t0) leaf0 = util::Xor(leaf0, leaf_cw);
162 if (t1) leaf1 = util::Xor(leaf1, leaf_cw);
163
164 // Output CW: v_cw = Group::From(b_buf) + (-Group::From(SetLsb(leaf0,false))) + Group::From(SetLsb(leaf1,false))
165 auto v_cw = Group::From(b_buf) + (-Group::From(util::SetLsb(leaf0, false))) +
166 Group::From(util::SetLsb(leaf1, false));
167 if (util::GetLsb(leaf1)) v_cw = -v_cw;
168 ocw = v_cw.Into();
169 }
170 }
171
182 __host__ __device__ int4 Eval(bool b, int4 s0, const Cw cws[], int4 ocw, In x) {
183 int4 node = util::SetLsb(s0, b);
184
185 // Levels 1 to n-1 (index i = 0 to in_bits-2)
186 for (int i = 0; i < in_bits - 1; ++i) {
187 bool x_bit = (x >> (in_bits - 1 - i)) & 1;
188 bool t = util::GetLsb(node);
189
190 int4 h = prg.Gen(util::Xor(hash_key, node))[0];
191
192 int4 zero4 = {0, 0, 0, 0};
193 int4 xb_mask = x_bit ? node : zero4;
194 int4 t_mask = t ? cws[i].s : zero4;
195
196 node = util::Xor(util::Xor(h, xb_mask), t_mask);
197 }
198
199 // Level n (last level)
200 {
201 bool x_n = (x >> 0) & 1;
202 bool t = util::GetLsb(node);
203
204 int4 h = prg.Gen(util::Xor(hash_key, util::SetLsb(node, x_n)))[0];
205
206 // Unpack CW_n
207 int4 hcw = util::SetLsb(cws[in_bits - 1].s, false);
208 bool lcw_xn;
209 if (x_n) lcw_xn = cws[in_bits - 1].extra;
210 else lcw_xn = util::GetLsb(cws[in_bits - 1].s);
211
212 int4 high = util::SetLsb(h, false);
213 bool low = util::GetLsb(h);
214
215 if (t) {
216 high = util::Xor(high, hcw);
217 low = low ^ lcw_xn;
218 }
219
220 auto y = Group::From(high);
221 if (low) y = y + Group::From(ocw);
222 if (b) y = -y;
223
224 return y.Into();
225 }
226 }
227
237 void EvalAll(bool b, int4 s0, const Cw cws[], int4 ocw, int4 ys[]) {
238 int4 node = util::SetLsb(s0, b);
239
240 assert(in_bits < sizeof(size_t) * 8);
241
242 int par_depth_ = util::ResolveParDepth(par_depth);
243
244 if constexpr (in_bits == 1) {
245 // Only level n (last level), no tree traversal
246#pragma omp parallel
247#pragma omp single
248 EvalLastLevel(b, node, cws, ocw, ys);
249 return;
250 }
251
252 // Phase 1: tree traversal for levels 1..n-1, stores nodes at level n-1
253 // We use ys[] as scratch space for intermediate nodes.
254 // After phase 1, ys[0..2^(in_bits-1)-1] hold the level n-1 nodes (packed s||t).
255 size_t num_leaves = 1ULL << (in_bits - 1);
256
257 // Recursive tree traversal
258#pragma omp parallel
259#pragma omp single
260 EvalTree(node, cws, ys, 0, num_leaves, 0, par_depth_);
261
262 // Phase 2: level n + output conversion
263 int4 hcw = util::SetLsb(cws[in_bits - 1].s, false);
264 bool lcw_0 = util::GetLsb(cws[in_bits - 1].s);
265 bool lcw_1 = cws[in_bits - 1].extra;
266 auto ocw_group = Group::From(ocw);
267
268 // Iterate backward to avoid overwriting unprocessed parent nodes.
269 for (size_t j = num_leaves; j-- > 0;) {
270 ConvertLastLevel(b, ys[j], hcw, lcw_0, lcw_1, ocw_group, ys[2 * j], ys[2 * j + 1]);
271 }
272 }
273
274private:
275 void EvalTree(int4 node, const Cw cws[], int4 ys[], size_t l, size_t r, int i, int par_depth_) {
276 // i is the level index (0-based), we traverse levels 0..in_bits-2
277 // At level in_bits-1, we store the node
278 if (i == in_bits - 1) {
279 assert(l + 1 == r);
280 ys[l] = node;
281 return;
282 }
283
284 bool t = util::GetLsb(node);
285 int4 h = prg.Gen(util::Xor(hash_key, node))[0];
286
287 int4 zero4 = {0, 0, 0, 0};
288 int4 t_mask = t ? cws[i].s : zero4;
289
290 // Left child: left = H_S(parent) ^ (t ? cw : 0)
291 int4 left = util::Xor(h, t_mask);
292 // Right child: right = left ^ parent
293 int4 right = util::Xor(left, node);
294
295 size_t mid = (l + r) / 2;
296
297 if (i < par_depth_) {
298#pragma omp task
299 EvalTree(left, cws, ys, l, mid, i + 1, par_depth_);
300#pragma omp task
301 EvalTree(right, cws, ys, mid, r, i + 1, par_depth_);
302#pragma omp taskwait
303 } else {
304 EvalTree(left, cws, ys, l, mid, i + 1, par_depth_);
305 EvalTree(right, cws, ys, mid, r, i + 1, par_depth_);
306 }
307 }
308
309 void EvalLastLevel(bool b, int4 node, const Cw cws[], int4 ocw, int4 ys[]) {
310 int4 hcw = util::SetLsb(cws[0].s, false);
311 bool lcw_0 = util::GetLsb(cws[0].s);
312 bool lcw_1 = cws[0].extra;
313 ConvertLastLevel(b, node, hcw, lcw_0, lcw_1, Group::From(ocw), ys[0], ys[1]);
314 }
315
316 void ConvertLastLevel(bool b, int4 parent, int4 hcw, bool lcw_0, bool lcw_1, Group ocw_group,
317 int4 &y0_out, int4 &y1_out) {
318 bool t_parent = util::GetLsb(parent);
319
320 int4 h0 = prg.Gen(util::Xor(hash_key, util::SetLsb(parent, false)))[0];
321 int4 h1 = prg.Gen(util::Xor(hash_key, util::SetLsb(parent, true)))[0];
322
323 int4 high0 = util::SetLsb(h0, false);
324 bool low0 = util::GetLsb(h0);
325 int4 high1 = util::SetLsb(h1, false);
326 bool low1 = util::GetLsb(h1);
327
328 if (t_parent) {
329 high0 = util::Xor(high0, hcw);
330 low0 = low0 ^ lcw_0;
331 high1 = util::Xor(high1, hcw);
332 low1 = low1 ^ lcw_1;
333 }
334
335 auto y0 = Group::From(high0);
336 if (low0) y0 = y0 + ocw_group;
337 if (b) y0 = -y0;
338
339 auto y1 = Group::From(high1);
340 if (low1) y1 = y1 + ocw_group;
341 if (b) y1 = -y1;
342
343 y0_out = y0.Into();
344 y1_out = y1.Into();
345 }
346};
347
348} // namespace fss
2-party DPF scheme using the Half-Tree construction.
Definition half_tree_dpf.cuh:42
void EvalAll(bool b, int4 s0, const Cw cws[], int4 ocw, int4 ys[])
Full domain evaluation method.
Definition half_tree_dpf.cuh:237
int4 Eval(bool b, int4 s0, const Cw cws[], int4 ocw, In x)
Evaluation method.
Definition half_tree_dpf.cuh:182
void Gen(Cw cws[], int4 &ocw, const int4 s0s[2], In a, int4 b_buf)
Key generation method.
Definition half_tree_dpf.cuh:68
Group interface.
Definition group.cuh:40
Pseudorandom generator (PRG) interface.
Definition prg.cuh:21
Correction word.
Definition half_tree_dpf.cuh:53