myl7/fss 1.1.0
Function secret sharing (FSS) primitives, including distributed point and comparison functions (DPF/DCF)
Loading...
Searching...
No Matches
uint.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
8#pragma once
9#include <fss/group.cuh>
10#include <cuda_runtime.h>
11#include <type_traits>
12#include <cassert>
13
14namespace fss::group {
15
/**
 * @brief Unsigned integers with arithmetic addition and optional modulo as a group.
 *
 * @tparam T Underlying integer type: a standard unsigned integer type other than `bool`, or
 *   `__uint128_t`. `bool` is rejected because `static_cast<bool>(1 + 1)` is `true`, so `true`
 *   would have no additive inverse and the group laws would not hold.
 * @tparam mod Group modulus. `0` selects the natural wrap-around of `T`, i.e. arithmetic modulo
 *   2^(8 * sizeof(T)); it is only allowed when `sizeof(T) < 16`. For `__uint128_t` a modulus in
 *   (0, 2^127] is required (From() decodes only 127 bits for that type).
 *
 * When `mod > 0`, the operators assume every operand satisfies `val < mod`. Values obtained from
 * the default constructor, From(), and the operators keep that invariant; assigning `val`
 * directly may break it.
 */
template <typename T, T mod = 0>
  requires((std::is_unsigned_v<T> || std::is_same_v<T, __uint128_t>) && !std::is_same_v<T, bool> &&
           sizeof(T) <= 16 && (sizeof(T) < 16 || (mod > 0 && mod <= static_cast<T>(1) << 127)))
struct Uint {
  /// Raw value; in [0, mod) when mod > 0.
  T val;

  /// Group addition: wrap-around addition when mod == 0, addition modulo mod otherwise.
  __host__ __device__ Uint operator+(Uint rhs) const {
    // The cast truncates back to T after integer promotion of narrow types.
    if constexpr (mod == 0) return {static_cast<T>(val + rhs.val)};

    // Decide whether the sum reaches mod without computing val + rhs.val first, which may
    // overflow T for large moduli. The subtraction below is then exact modulo 2^bits, and its
    // true result (< mod) fits in T.
    if (val >= mod - rhs.val) return {static_cast<T>(val + rhs.val - mod)};
    else return {static_cast<T>(val + rhs.val)};
  }

  /// Group negation (additive inverse).
  __host__ __device__ Uint operator-() const {
    if constexpr (mod == 0) return {static_cast<T>(-val)};

    // 0 is its own inverse; mod - 0 would be mod, which lies outside [0, mod).
    if (val == 0) return {0};
    else return {static_cast<T>(mod - val)};
  }

  /// The group identity, 0.
  __host__ __device__ Uint() : val(0) {}

  /**
   * @brief Decodes a group element from an int4 buffer.
   *
   * Reads the low 8 * sizeof(T) bits, taken from buf.x, then buf.y, buf.z, buf.w as 32-bit words
   * from least to most significant. For `__uint128_t` the LSB of buf.w is the clamped bit and is
   * skipped, so 127 bits are decoded. When mod > 0 the result is reduced modulo mod.
   *
   * @param buf Input buffer. The LSB of buf.w must be 0 (checked with assert).
   * @return The decoded element.
   */
  __host__ __device__ static Uint From(int4 buf) {
    assert((buf.w & 1) == 0);

    T decoded = 0;
    if constexpr (sizeof(T) < 4) decoded = buf.x & ((1 << 8 * sizeof(T)) - 1);
    // Cast to unsigned int first to prevent sign extension when promoting to larger types
    else if constexpr (sizeof(T) == 4) decoded = static_cast<unsigned int>(buf.x);
    else if constexpr (sizeof(T) == 8)
      decoded = static_cast<T>(static_cast<unsigned int>(buf.x)) |
                static_cast<T>(static_cast<unsigned int>(buf.y)) << 32;
    else if constexpr (sizeof(T) == 16)
      decoded = static_cast<T>(static_cast<unsigned int>(buf.x)) |
                static_cast<T>(static_cast<unsigned int>(buf.y)) << 32 |
                static_cast<T>(static_cast<unsigned int>(buf.z)) << 64 |
                // For uint128, LSB of buf.w is the clamped bit, so only its upper 31 bits are used
                static_cast<T>(static_cast<unsigned int>(buf.w) >> 1) << 96;
    else __builtin_unreachable();

    if constexpr (mod > 0) decoded %= mod;

    return {decoded};
  }

  /**
   * @brief Encodes the value into an int4 buffer; From() recovers the same value.
   *
   * 32-bit words go to buf.x, buf.y, buf.z, buf.w from least to most significant. Unused
   * components stay 0. For `__uint128_t`, bits 96 and up are stored in buf.w shifted left by one,
   * so that its LSB (the clamped bit) is 0.
   *
   * @return The encoded buffer.
   */
  __host__ __device__ int4 Into() const {
    int4 buf = {0, 0, 0, 0};
    if constexpr (sizeof(T) <= 4) buf.x = static_cast<int>(val);
    else if constexpr (sizeof(T) == 8) {
      buf.x = static_cast<int>(val & 0xffffffff);
      buf.y = static_cast<int>(val >> 32);
    } else if constexpr (sizeof(T) == 16) {
      buf.x = static_cast<int>(val & 0xffffffff);
      buf.y = static_cast<int>((val >> 32) & 0xffffffff);
      buf.z = static_cast<int>((val >> 64) & 0xffffffff);
      // For uint128, its LSB is the clamped bit
      buf.w = static_cast<int>((val >> 96) << 1);
    } else __builtin_unreachable();
    return buf;
  }

private:
  /// Wraps a raw value without any reduction; only reachable from within the struct.
  __host__ __device__ Uint(T v) : val(v) {}
};
// Compile-time checks: Uint over each fixed-width unsigned type, with the default modulus of 0,
// must satisfy the Groupable group interface.
static_assert(Groupable<Uint<uint8_t>>, "Uint<uint8_t> must satisfy Groupable");
static_assert(Groupable<Uint<uint16_t>>, "Uint<uint16_t> must satisfy Groupable");
static_assert(Groupable<Uint<uint32_t>>, "Uint<uint32_t> must satisfy Groupable");
static_assert(Groupable<Uint<uint64_t>>, "Uint<uint64_t> must satisfy Groupable");
96
97} // namespace fss::group
Group interface.
Definition group.cuh:40
Unsigned integers with arithmetic addition and optional modulo as a group.
Definition uint.cuh:30