myl7/fss 1.1.0
Function secret sharing (FSS) primitives, including distributed point functions (DPF) and distributed comparison functions (DCF)
Loading...
Searching...
No Matches
uint.cuh
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
8#pragma once
9#include <fss/group.cuh>
10#include <cuda_runtime.h>
11#include <type_traits>
12#include <cassert>
13
14namespace fss::group {
15
/**
 * Unsigned integers with arithmetic addition and optional modulo as a group.
 *
 * @tparam T   Underlying unsigned integer type: a standard unsigned type or __uint128_t.
 * @tparam mod Modulus. 0 means arithmetic wraps at 2^(8 * sizeof(T)) (native unsigned overflow).
 *             A 128-bit T requires a nonzero modulus no greater than 2^127. This is because the
 *             int4 encoding used by From/Into carries only 127 bits for that width: the LSB of
 *             buf.w is reserved.
 *
 * When mod > 0, operator+ and operator- assume both operands satisfy val < mod. From, the default
 * constructor, and the operators themselves only produce values in that range. Because `val` is
 * public, code that writes it directly must keep it reduced.
 */
template <typename T, T mod = 0>
  requires((std::is_unsigned_v<T> || std::is_same_v<T, __uint128_t>) && sizeof(T) <= 16 &&
           (sizeof(T) < 16 || (mod > 0 && mod <= static_cast<T>(1) << 127)))
struct Uint {
  // The group element; expected to lie in [0, mod) when mod > 0.
  T val;

  /// Group operation: addition wrapping at 2^(8 * sizeof(T)) when mod == 0, otherwise modulo mod.
  __host__ __device__ Uint operator+(Uint rhs) const {
    if constexpr (mod == 0) {
      return {static_cast<T>(val + rhs.val)};
    } else {
      // The `else` makes this a discarded statement for mod == 0, so it is not instantiated there.
      // Test val + rhs.val >= mod without relying on the sum fitting in T. mod - rhs.val does not
      // underflow because rhs.val < mod. If the sum itself wraps in T, the unsigned wraparound of
      // `val + rhs.val - mod` still yields the correct reduced value.
      if (val >= mod - rhs.val) return {static_cast<T>(val + rhs.val - mod)};
      return {static_cast<T>(val + rhs.val)};
    }
  }

  /// Group inverse: additive negation modulo 2^(8 * sizeof(T)) when mod == 0, otherwise modulo mod.
  __host__ __device__ Uint operator-() const {
    if constexpr (mod == 0) {
      return {static_cast<T>(-val)};
    } else {
      // 0 is its own inverse. Computing mod - 0 would produce mod, which is outside [0, mod).
      if (val == 0) return {0};
      return {static_cast<T>(mod - val)};
    }
  }

  /// Identity element (zero).
  __host__ __device__ Uint() : val(0) {}

  /**
   * Decodes a group element from a 128-bit buffer.
   *
   * buf.x supplies the least-significant 32 bits, followed by buf.y, buf.z, and buf.w. Only as
   * many bits as T holds are used.
   * - For 8/16-bit T, only the low bits of buf.x are used.
   * - For 128-bit T, the LSB of buf.w is the clamped bit and is skipped, so the upper 31 bits of
   *   buf.w become bits 96..126 of the result.
   *
   * When mod > 0 the decoded value is reduced with `%`.
   * NOTE(review): unless mod divides the decoded range, this reduction is slightly non-uniform even
   * if buf is uniformly random. Confirm this bias is acceptable for the callers.
   *
   * @pre The LSB of buf.w is 0 (checked with assert).
   */
  __host__ __device__ static Uint From(int4 buf) {
    assert((buf.w & 1) == 0);

    T out = 0;
    if constexpr (sizeof(T) < 4) {
      out = buf.x & ((1 << 8 * sizeof(T)) - 1);
    } else if constexpr (sizeof(T) == 4) {
      // Cast to unsigned int first to prevent sign extension when promoting to larger types
      out = static_cast<unsigned int>(buf.x);
    } else if constexpr (sizeof(T) == 8) {
      out = static_cast<T>(static_cast<unsigned int>(buf.x)) |
            static_cast<T>(static_cast<unsigned int>(buf.y)) << 32;
    } else if constexpr (sizeof(T) == 16) {
      out = static_cast<T>(static_cast<unsigned int>(buf.x)) |
            static_cast<T>(static_cast<unsigned int>(buf.y)) << 32 |
            static_cast<T>(static_cast<unsigned int>(buf.z)) << 64 |
            // For uint128, LSB of buf.w is the clamped bit
            static_cast<T>(static_cast<unsigned int>(buf.w) >> 1) << 96;
    } else {
      __builtin_unreachable();
    }

    if constexpr (mod > 0) out %= mod;

    return {out};
  }

  /**
   * Encodes this element into a 128-bit buffer. This is the inverse of From for reduced values.
   *
   * Uses the same word order as From: the least-significant 32 bits go to buf.x. Unused words are
   * zero. For 128-bit T, bits 96..126 are stored shifted left by one, so the LSB of buf.w (the
   * clamped bit) is always 0. Bit 127 is dropped; it is 0 whenever val < mod <= 2^127.
   */
  __host__ __device__ int4 Into() const {
    int4 buf = {0, 0, 0, 0};
    if constexpr (sizeof(T) <= 4) {
      buf.x = static_cast<int>(val);
    } else if constexpr (sizeof(T) == 8) {
      buf.x = static_cast<int>(val & 0xffffffff);
      buf.y = static_cast<int>(val >> 32);
    } else if constexpr (sizeof(T) == 16) {
      buf.x = static_cast<int>(val & 0xffffffff);
      buf.y = static_cast<int>((val >> 32) & 0xffffffff);
      buf.z = static_cast<int>((val >> 64) & 0xffffffff);
      // For uint128, its LSB is the clamped bit
      buf.w = static_cast<int>((val >> 96) << 1);
    } else {
      __builtin_unreachable();
    }
    return buf;
  }

private:
  // Wraps an already-reduced raw value; used only by the members above.
  __host__ __device__ Uint(T v) : val(v) {}
};
// Compile-time checks that each native-width instantiation (mod == 0) satisfies the Groupable
// concept from fss/group.cuh. The 128-bit variant is not checked here because its requires-clause
// demands an explicit nonzero modulus.
static_assert(Groupable<Uint<uint64_t>>);
static_assert(Groupable<Uint<uint32_t>>);
static_assert(Groupable<Uint<uint16_t>>);
static_assert(Groupable<Uint<uint8_t>>);
94
95} // namespace fss::group
Group interface.
Definition group.cuh:40
Unsigned integers under arithmetic addition, optionally reduced modulo a constant, forming a group.
Definition uint.cuh:30