-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathgc.rs
297 lines (260 loc) · 9.62 KB
/
gc.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
#![allow(non_snake_case)]
use algebra::{BitIterator, FixedPointParameters, Fp64Parameters, FpParameters, PrimeField};
pub use fancy_garbling;
use fancy_garbling::{
circuit::CircuitBuilder, error::CircuitBuilderError, util, BinaryBundle, BinaryGadgets,
BundleGadgets, Fancy,
};
/// Single-wire multiplexer: selects `x` or `y` according to the selector `b`.
///
/// Builds `x + b·(x + y)` from two additions and one multiplication. Assuming
/// `add`/`mul` are addition and multiplication modulo the wires' modulus, this is
/// `x` when `b = 0` and `2x + y` when `b = 1`, i.e. `y` for mod-2 wires.
#[inline(always)]
fn mux_single_bit<F: Fancy>(
    f: &mut F,
    b: &F::Item,
    x: &F::Item,
    y: &F::Item,
) -> Result<F::Item, F::Error> {
    let x_plus_y = f.add(x, y)?;
    let gated = f.mul(b, &x_plus_y)?;
    f.add(x, &gated)
}
/// Bitwise multiplexer over two binary bundles: if `b = 0` returns `x` else `y`.
///
/// Each output wire comes from `mux_single_bit`, which computes
/// `x_i + b·(x_i + y_i)`. For `b = 1` that is `2·x_i + y_i`, which equals `y_i`
/// only modulo 2. So `b`, `x` and `y` must all be mod-2 wires; other moduli do
/// not produce a correct selection.
///
/// `x` and `y` are expected to have the same number of wires. This is checked in
/// debug builds; in release builds a mismatch silently truncates to the shorter
/// bundle.
fn mux<F: Fancy>(
    f: &mut F,
    b: &F::Item,
    x: &BinaryBundle<F::Item>,
    y: &BinaryBundle<F::Item>,
) -> Result<Vec<F::Item>, F::Error> {
    let (xs, ys) = (x.wires(), y.wires());
    debug_assert_eq!(xs.len(), ys.len(), "mux: bundles must have equal length");
    xs.iter()
        .zip(ys)
        .map(|(xi, yi)| mux_single_bit(f, b, xi, yi))
        .collect()
}
/// Conditionally subtracts `p` from `bits`.
///
/// `neg_p` is expected to hold `-p` in two's complement, with as many wires as
/// `bits` (this is how `relu` builds it). Then `bits + neg_p` equals
/// `bits - p` modulo `2^len`, and the carry-out returned by `bin_addition` is
/// set exactly when `bits >= p`. In that case the difference is selected;
/// otherwise `bits` is returned unchanged. Values in `[0, 2p)` are therefore
/// reduced into `[0, p)`.
#[inline]
fn mod_p_helper<F: Fancy>(
    b: &mut F,
    neg_p: &BinaryBundle<F::Item>,
    bits: &BinaryBundle<F::Item>,
) -> Result<BinaryBundle<F::Item>, F::Error> {
    let (difference, carry_out) = b.bin_addition(bits, neg_p)?;
    // carry_out = 0 -> keep `bits`; carry_out = 1 -> take `bits - p`.
    let selected = mux(b, &carry_out, bits, &difference)?;
    Ok(BinaryBundle::new(selected))
}
/// One-bit binary adder. Returns the sum bit and the carry-out.
///
/// With a carry-in `c` this is a full adder. The sum is `x ^ y ^ c` and the
/// carry is `((x ^ y) & (x ^ c)) ^ x`, which is the majority of `x`, `y`, `c`
/// computed with a single AND gate. `b` is ignored in that case.
///
/// Without a carry-in the sum is `x ^ y`, and `b` decides whether a carry wire
/// is produced:
/// * `b == true`: the carry is `x & y`;
/// * `b == false`: the carry is reported as `None` and no AND gate is emitted.
///   This is only correct when `x & y` is known to be 0. In this file,
///   `neg_p_over_2_helper` passes `false` while every bit seen so far of the
///   constant operand `y` is zero.
fn adder_const<F: Fancy>(
    f: &mut F,
    x: &F::Item,
    y: &F::Item,
    b: bool,
    carry_in: Option<&F::Item>,
) -> Result<(F::Item, Option<F::Item>), F::Error> {
    match carry_in {
        Some(c) => {
            let x_xor_y = f.xor(x, y)?;
            let sum = f.xor(&x_xor_y, c)?;
            let x_xor_c = f.xor(x, c)?;
            let both_differ = f.and(&x_xor_y, &x_xor_c)?;
            let carry_out = f.xor(&both_differ, x)?;
            Ok((sum, Some(carry_out)))
        }
        None => {
            let sum = f.xor(x, y)?;
            let carry_out = if b { Some(f.and(x, y)?) } else { None };
            Ok((sum, carry_out))
        }
    }
}
/// Adds the constant `neg_p_over_2` to `bits` modulo `2^bits.len()` and returns
/// the sum, least-significant wire first.
///
/// `neg_p_over_2_bits` must be the wire encoding of the same constant. The
/// plaintext `neg_p_over_2` is only used to avoid emitting AND gates for carries
/// that are known to be zero while the constant's low-order bits are all zero.
/// In `relu` the constant is `-(p / 2)` in two's complement. For inputs already
/// reduced mod `p`, the top output wire is then set exactly when `bits < p / 2`.
///
/// Assumes (not checked here) that `bits` and `neg_p_over_2_bits` have the same
/// number of wires, between 2 and 65 inclusive. Only the low 64 bits of the
/// constant are read, and the middle-bit loop stops when they run out.
///
/// # Panics
///
/// Panics if `bits` has fewer than two wires, or if none of the constant bits
/// consumed below the top wire is 1 (no carry wire then exists for the final
/// addition).
fn neg_p_over_2_helper<F: Fancy>(
    f: &mut F,
    neg_p_over_2: u128,
    neg_p_over_2_bits: &BinaryBundle<F::Item>,
    bits: &BinaryBundle<F::Item>,
) -> Result<BinaryBundle<F::Item>, F::Error> {
    let xwires = bits.wires();
    let ywires = neg_p_over_2_bits.wires();
    // Plaintext bits of the constant, least-significant first (the `as u64`
    // keeps only the low 64 bits). NOTE(review): assumes algebra's `BitIterator`
    // yields most-significant first, which is why the vector is reversed.
    let mut neg_p_over_2 = BitIterator::new([neg_p_over_2 as u64]).collect::<Vec<_>>();
    neg_p_over_2.reverse();
    let mut neg_p_over_2 = neg_p_over_2.into_iter();
    // `seen_one` records whether any constant bit processed so far is 1. While it
    // is false and there is no carry-in, `adder_const` emits no carry (it is 0).
    let mut seen_one = neg_p_over_2.next().unwrap();
    // Bit 0: half adder (no carry-in).
    let (mut z, mut c) = adder_const(f, &xwires[0], &ywires[0], seen_one, None)?;
    let mut bs = vec![z];
    // Bits 1 ..= len-2: full adders threading the carry. `zip` also bounds this
    // by the 63 remaining plaintext constant bits.
    for ((x, y), b) in xwires[1..(xwires.len() - 1)]
        .iter()
        .zip(&ywires[1..])
        .zip(neg_p_over_2)
    {
        seen_one |= b;
        let res = adder_const(f, x, y, seen_one, c.as_ref())?;
        z = res.0;
        c = res.1;
        bs.push(z);
    }
    // Top bit: sum only (`x + y + carry` on mod-2 wires). The final carry-out is
    // not computed, so the result wraps modulo 2^len.
    z = f.add_many(&[
        xwires.last().unwrap().clone(),
        ywires.last().unwrap().clone(),
        c.unwrap(),
    ])?;
    bs.push(z);
    Ok(BinaryBundle::new(bs))
}
/// Compute the number of bits needed to represent `p`, plus one.
///
/// This is `p`'s bit length (the 1-based position of its highest set bit) plus
/// one extra bit. For example, `num_bits(5) == 4` (`0b101` needs 3 bits) and
/// `num_bits(8) == 5` (`0b1000` needs 4 bits). `num_bits(0) == 1`.
#[inline]
pub fn num_bits(p: u128) -> usize {
    // `BITS - leading_zeros` is the exact bit length. It is correct for powers
    // of two and cannot overflow for any `u128`, unlike a formula based on
    // `next_power_of_two() * 2`.
    (u128::BITS - p.leading_zeros()) as usize + 1
}
/// Compute the `ReLU` of `n` over the field `P::Field`.
///
/// More precisely, this appends `n` independent gadgets to `b`. Each gadget
/// computes a ReLU clamped at 6 ("ReLU6") on a secret-shared fixed-point value
/// with `P::EXPONENT_CAPACITY` fractional bits, then re-shares the result:
///
/// * Inputs: one evaluator bundle `s1` and two garbler bundles `s2` and
///   `s2_next`, each `num_bits(p)` mod-2 wires, least-significant first.
/// * `x` is `s1 + s2` reduced mod `p` with one conditional subtraction, which
///   is exact for shares below `p`.
/// * `x` is treated as positive iff `x < p / 2` (integer division). Otherwise
///   the ReLU output is 0.
/// * Positive values whose integer part is 6 or more are clamped to exactly 6.
/// * Output bundle: `(relu6(x) + s2_next)` reduced mod `p` the same way.
///
/// Only the low `P::EXPONENT_CAPACITY + 8` bits of `x` are kept. Positive
/// values with an integer part of `2^8` or more are therefore not guaranteed to
/// clamp correctly.
///
/// # Errors
///
/// Propagates any [`CircuitBuilderError`] raised while adding gates to `b`.
///
/// # Panics
///
/// Panics if `n > 0` and `num_bits(p) < P::EXPONENT_CAPACITY + 8`.
pub fn relu<P: FixedPointParameters>(
    b: &mut CircuitBuilder,
    n: usize,
) -> Result<(), CircuitBuilderError>
where
    <P::Field as PrimeField>::Params: Fp64Parameters,
    P::Field: PrimeField<BigInt = <<P::Field as PrimeField>::Params as FpParameters>::BigInt>,
{
    let p = u128::from(<<P::Field as PrimeField>::Params>::MODULUS.0);
    let exponent_size = P::EXPONENT_CAPACITY as usize;
    let p_over_2 = p / 2;
    // Two's-complement negations (`2^128 - x`). `wrapping_neg` is equivalent to
    // `!x + 1` but cannot overflow when `x == 0`.
    let neg_p_over_2 = p_over_2.wrapping_neg();
    let neg_p = p.wrapping_neg();
    let q = 2;
    let num_bits = num_bits(p);
    let moduli = vec![q; num_bits];

    // Bit `exponent_size + k` is the `2^k` bit of the integer part, so
    // `6 = 0b110` occupies bits `exponent_size..relu_6_size`.
    let relu_6_size = exponent_size + 3;
    // We choose 5 arbitrarily here; the idea is that we won't see values of
    // greater than 2^8. Bits at or above `kept_bits` are dropped.
    let kept_bits = relu_6_size + 5;
    assert!(
        n == 0 || num_bits >= kept_bits,
        "relu: the field modulus gives {} bits but the gadget needs at least {}",
        num_bits,
        kept_bits
    );

    // Constants for addition with -p (mod-p reduction) and -(p/2) (sign test).
    let neg_p = b.bin_constant_bundle(neg_p, num_bits)?;
    let neg_p_over_2_bits = b
        .constant_bundle(&util::u128_to_bits(neg_p_over_2, num_bits), &moduli)?
        .into();
    let zero = b.constant(0, 2)?;
    let one = b.constant(1, 2)?;
    for _ in 0..n {
        let s1 = BinaryBundle::new(b.evaluator_inputs(&moduli));
        let s2 = BinaryBundle::new(b.garbler_inputs(&moduli));
        let s2_next = BinaryBundle::new(b.garbler_inputs(&moduli));
        // Add the secret shares as integers, then take the result mod p.
        let sum = b.bin_addition_no_carry(&s1, &s2)?;
        let layer_input = mod_p_helper(b, &neg_p, &sum)?;
        // Compare with p/2. Since values >= p/2 count as negative, the top bit
        // of `layer_input - p/2` (two's complement) is set exactly when the
        // value is positive.
        let shifted = neg_p_over_2_helper(b, neg_p_over_2, &neg_p_over_2_bits, &layer_input)?;
        let zs_is_positive = shifted
            .wires()
            .last()
            .expect("sign test produces at least one wire");
        // ReLU: AND every kept bit with the sign, zeroing negative values.
        let mut relu_res = Vec::with_capacity(num_bits);
        for wire in layer_input.wires().iter().take(kept_bits) {
            relu_res.push(b.and(zs_is_positive, wire)?);
        }
        // Integer-part bits of weight 2 and 4 both set: the integer part is 6 or
        // 7 (mod 8). Any set bit of weight >= 8 also means "at least 6".
        let is_at_least_six = b.and_many(&relu_res[(exponent_size + 1)..relu_6_size])?;
        let some_higher_bit_is_set = b.or_many(&relu_res[relu_6_size..])?;
        let should_be_six = b.or(&some_higher_bit_is_set, &is_at_least_six)?;
        // Clamp: clear the high bits, then force the pattern 0b110.000... when
        // `should_be_six` is set.
        for wire in &mut relu_res[relu_6_size..] {
            *wire = zero;
        }
        let lsb = &mut relu_res[exponent_size];
        *lsb = mux_single_bit(b, &should_be_six, lsb, &zero)?;
        let middle_bit = &mut relu_res[exponent_size + 1];
        *middle_bit = mux_single_bit(b, &should_be_six, middle_bit, &one)?;
        let msb = &mut relu_res[exponent_size + 2];
        *msb = mux_single_bit(b, &should_be_six, msb, &one)?;
        for wire in &mut relu_res[..exponent_size] {
            *wire = mux_single_bit(b, &should_be_six, wire, &zero)?;
        }
        // Pad back to the full width, then re-share with the garbler's mask.
        relu_res.extend(std::iter::repeat(zero).take(num_bits - kept_bits));
        let relu_res = BinaryBundle::new(relu_res);
        let masked = b.bin_addition_no_carry(&relu_res, &s2_next)?;
        let next_share = mod_p_helper(b, &neg_p, &masked)?;
        b.output_bundle(&next_share)?;
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::Share;
    use algebra::{fields::near_mersenne_64::F, *};
    use fancy_garbling::circuit::CircuitBuilder;
    use rand::{thread_rng, Rng};

    /// Fixed-point format for the test: 3 mantissa bits and 10 exponent bits over `F`.
    struct TenBitExpParams {}

    impl FixedPointParameters for TenBitExpParams {
        type Field = F;
        const MANTISSA_CAPACITY: u8 = 3;
        const EXPONENT_CAPACITY: u8 = 10;
    }

    type TenBitExpFP = FixedPoint<TenBitExpParams>;

    /// Scales a random `f64` sample by +10 or -10 (sign drawn first), snaps it to
    /// the fixed-point grid, and returns it both as a float and as a `TenBitExpFP`.
    fn generate_random_number<R: Rng>(rng: &mut R) -> (f64, TenBitExpFP) {
        let scale = if rng.gen::<bool>() { -10.0 } else { 10.0 };
        let unit: f64 = rng.gen();
        let snapped = TenBitExpFP::truncate_float(unit * scale);
        (snapped, TenBitExpFP::from(snapped))
    }

    /// Multiplies `u16` moduli together as a `u128` (1 for an empty slice).
    #[inline]
    pub(crate) fn product(xs: &[u16]) -> u128 {
        xs.iter().map(|&x| u128::from(x)).product()
    }

    #[test]
    pub(crate) fn test_relu() {
        // TODO: There is currently an off-by-one in this test that causes it
        // to fail occasionally
        let mut rng = thread_rng();
        let n = 42;
        let q = 2;
        let p = <F as PrimeField>::Params::MODULUS.0 as u128;
        let Q = product(&vec![q; n]);
        println!("n={} q={} Q={}", n, q, Q);

        let mut circuit = {
            let mut builder = CircuitBuilder::new();
            relu::<TenBitExpParams>(&mut builder, 1).unwrap();
            builder.finish()
        };
        let _ = circuit.print_info();

        let zero = TenBitExpFP::zero();
        let six = TenBitExpFP::from(6.0);
        for i in 0..10000 {
            let (_, n1) = generate_random_number(&mut rng);
            let (share_a, share_b) = n1.share(&mut rng);
            let res_should_be_fp = if n1 <= zero {
                zero
            } else if n1 > six {
                six
            } else {
                n1
            };
            // Expected circuit output: the clamped value re-masked by the garbler.
            let mask = F::uniform(&mut rng).into_repr().0 as u128;
            let expected = (res_should_be_fp.inner.into_repr().0 as u128 + mask) % p;

            let a_value = share_a.inner.inner.into_repr().0 as u128;
            let b_value = share_b.inner.inner.into_repr().0 as u128;
            let garbler_inputs: Vec<_> = util::u128_to_bits(a_value, n)
                .into_iter()
                .chain(util::u128_to_bits(mask, n))
                .collect();
            let evaluator_inputs = util::u128_to_bits(b_value, n);

            let (encoder, evaluator) = fancy_garbling::garble(&mut circuit).unwrap();
            let garbler_wires = encoder.encode_garbler_inputs(&garbler_inputs);
            let evaluator_wires = encoder.encode_evaluator_inputs(&evaluator_inputs);
            let garbled_eval_results = evaluator
                .eval(&mut circuit, &garbler_wires, &evaluator_wires)
                .unwrap();
            let plain_eval_results = circuit
                .eval_plain(&garbler_inputs, &evaluator_inputs)
                .unwrap();

            let within_one =
                |bits: &[u16]| util::u128_from_bits(bits).abs_diff(expected) <= 1;
            assert!(
                within_one(&plain_eval_results),
                "Iteration {}, Pre-ReLU value is {}, value should be {}, {:?}",
                i,
                n1,
                res_should_be_fp,
                res_should_be_fp
            );
            assert!(
                within_one(&garbled_eval_results),
                "Iteration {}, Pre-ReLU value is {}, value should be {}, {:?}",
                i,
                n1,
                res_should_be_fp,
                res_should_be_fp
            );
        }
    }
}