From 6f362f33c626a02398b5faa7b270371fba906d7f Mon Sep 17 00:00:00 2001 From: finiteprods Date: Fri, 13 Nov 2020 10:18:17 +0100 Subject: [PATCH] mini histogram eg --- configs/docker-dev.toml | 6 ++--- rust/Cargo.lock | 7 ++++++ rust/examples/test-drive/main.rs | 14 ++++++++++- rust/examples/test-drive/participant.rs | 1 + rust/xaynet-core/Cargo.toml | 1 + rust/xaynet-core/src/mask/masking.rs | 32 +++++++++++++++++++++---- 6 files changed, 53 insertions(+), 8 deletions(-) diff --git a/configs/docker-dev.toml b/configs/docker-dev.toml index 5a6296da8..c6bceb049 100644 --- a/configs/docker-dev.toml +++ b/configs/docker-dev.toml @@ -14,8 +14,8 @@ min_sum_time = 5 min_update_time = 10 max_sum_time = 3600 max_update_time = 3600 -sum = 0.01 -update = 0.1 +sum = 0.5 +update = 0.9 [mask] group_type = "Prime" @@ -31,7 +31,7 @@ url = "http://influxdb:8086" db = "metrics" [redis] -url = "redis://redis" +url = "redis://127.00.0.1/" [s3] access_key = "minio" diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 9e2b40fd3..1cbf31edd 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -945,6 +945,12 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35" +[[package]] +name = "hist" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca1d61bc7e668d28d2a9759992dc83e73461f6efc5da114d5ab8502780cebd6c" + [[package]] name = "hmac" version = "0.8.1" @@ -3532,6 +3538,7 @@ dependencies = [ "anyhow", "bitflags", "derive_more", + "hist", "num", "paste", "rand 0.7.3", diff --git a/rust/examples/test-drive/main.rs b/rust/examples/test-drive/main.rs index ed11dbbb3..371c8f75a 100644 --- a/rust/examples/test-drive/main.rs +++ b/rust/examples/test-drive/main.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), ClientError> { // dummy local model for clients let len = opt.len as usize; - let model = Arc::new(Model::from_primitives(vec![0; len].into_iter()).unwrap()); + // let model = Arc::new(Model::from_primitives(vec![0; len].into_iter()).unwrap()); // optional certificates for TLS server authentication let certificates = opt @@ -39,6 +39,17 @@ async fn main() -> Result<(), ClientError> { .transpose()?; for id in 0..opt.nb_client { + let mut v = Vec::with_capacity(len); + for i in 0..len as u32 { + if i == id % (len as u32) { + v.push(1) + } else { + v.push(0) + } + } + println!("private model {:?}", v); + let it = v.into_iter(); + let model = Arc::new(Model::from_primitives(it).unwrap()); spawn_participant( id as u32, &opt.url, @@ -53,6 +64,7 @@ async fn main() -> Result<(), ClientError> { } fn generate_agent_config() -> PetSettings { + // TODO check this config is ok let mask_config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, diff --git a/rust/examples/test-drive/participant.rs b/rust/examples/test-drive/participant.rs index ec29244be..99e8b331b 100644 --- a/rust/examples/test-drive/participant.rs +++ b/rust/examples/test-drive/participant.rs @@ -64,6 +64,7 @@ impl Agent { impl Participant { pub fn new(settings: PetSettings, xaynet_client: Client, model: Arc) -> (Self, Agent) { + info!("private model {:?}", model); let (tx, rx) = mpsc::channel::(10); let notifier = Notifier(tx); let agent = Agent::new(settings, xaynet_client.clone(), LocalModel(model), notifier); diff --git a/rust/xaynet-core/Cargo.toml b/rust/xaynet-core/Cargo.toml index e60eaca4e..6ad1f33d0 100644 --- a/rust/xaynet-core/Cargo.toml +++ b/rust/xaynet-core/Cargo.toml @@ -27,6 +27,7 @@ derive_more = { version = "0.99.10", default-features = false, features = [ "index_mut", "into", ] } +hist = "0.1.0" num = { version = "0.3.0", features = ["serde"] } paste = "1.0.1" rand = "0.7.3" diff --git a/rust/xaynet-core/src/mask/masking.rs b/rust/xaynet-core/src/mask/masking.rs index 2ca5cba5a..07384a598 100644 --- a/rust/xaynet-core/src/mask/masking.rs +++ b/rust/xaynet-core/src/mask/masking.rs @@ -25,6 +25,8 @@ use crate::{ }, }; +extern crate hist; + #[derive(Debug, Error, Eq, PartialEq)] /// Errors related to the unmasking of models. pub enum UnmaskingError { @@ -198,13 +200,13 @@ impl Aggregation { let order_1 = config_1.order(); let n = (masked_1 + &order_1 - mask_1) % &order_1; let ratio = Ratio::::from(n.to_bigint().unwrap()); - let scalar_sum = ratio / &exp_shift_1 - &scaled_add_shift_1; + let _scalar_sum = ratio / &exp_shift_1 - &scaled_add_shift_1; // unmask global model let scaled_add_shift_n = config_n.add_shift() * BigInt::from(self.nb_models); let exp_shift_n = config_n.exp_shift(); let order_n = config_n.order(); - masked_n + let unmasked: Model = masked_n .into_iter() .zip(mask_n) .map(|(masked, mask)| { @@ -222,10 +224,32 @@ impl Aggregation { let ratio = Ratio::::from(n.to_bigint().unwrap()); let unmasked = ratio / &exp_shift_n - &scaled_add_shift_n; + // TEMP suppress // scaling correction - unmasked / &scalar_sum + //unmasked / &scalar_sum + + //let unmasked_int = unmasked.to_integer(); + //println!("rounded unmasked int {}", unmasked_int.to_str_radix(10)); + unmasked + }) + .collect(); + + let approxs: Vec = unmasked + .iter() + .map(|x| { + let y = x.to_integer(); + let digits = y.to_str_radix(10); + i32::from_str_radix(&digits, 10).unwrap() }) - .collect() + .collect(); + + println!("histogram for {:?}", approxs); + + let h = hist::Hist::new(12, 3, &vec![0, 1, 2, 3], &approxs); + h.display(); + + //println!("unmasked model {:#?}", unmasked); + unmasked } /// Validates if aggregation of the aggregated mask object with the given `object` may be safely