Skip to content
This repository was archived by the owner on Aug 30, 2022. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions configs/docker-dev.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ min_sum_time = 5
min_update_time = 10
max_sum_time = 3600
max_update_time = 3600
sum = 0.01
update = 0.1
sum = 0.5
update = 0.9

[mask]
group_type = "Prime"
Expand All @@ -31,7 +31,7 @@ url = "http://influxdb:8086"
db = "metrics"

[redis]
url = "redis://redis"
url = "redis://127.00.0.1/"

[s3]
access_key = "minio"
Expand Down
7 changes: 7 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 13 additions & 1 deletion rust/examples/test-drive/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async fn main() -> Result<(), ClientError> {

// dummy local model for clients
let len = opt.len as usize;
let model = Arc::new(Model::from_primitives(vec![0; len].into_iter()).unwrap());
// let model = Arc::new(Model::from_primitives(vec![0; len].into_iter()).unwrap());

// optional certificates for TLS server authentication
let certificates = opt
Expand All @@ -39,6 +39,17 @@ async fn main() -> Result<(), ClientError> {
.transpose()?;

for id in 0..opt.nb_client {
let mut v = Vec::with_capacity(len);
for i in 0..len as u32 {
if i == id % (len as u32) {
v.push(1)
} else {
v.push(0)
}
}
println!("private model {:?}", v);
let it = v.into_iter();
let model = Arc::new(Model::from_primitives(it).unwrap());
spawn_participant(
id as u32,
&opt.url,
Expand All @@ -53,6 +64,7 @@ async fn main() -> Result<(), ClientError> {
}

fn generate_agent_config() -> PetSettings {
// TODO check this config is ok
let mask_config = MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
Expand Down
1 change: 1 addition & 0 deletions rust/examples/test-drive/participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ impl Agent {

impl Participant {
pub fn new(settings: PetSettings, xaynet_client: Client, model: Arc<Model>) -> (Self, Agent) {
info!("private model {:?}", model);
let (tx, rx) = mpsc::channel::<Event>(10);
let notifier = Notifier(tx);
let agent = Agent::new(settings, xaynet_client.clone(), LocalModel(model), notifier);
Expand Down
1 change: 1 addition & 0 deletions rust/xaynet-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ derive_more = { version = "0.99.10", default-features = false, features = [
"index_mut",
"into",
] }
hist = "0.1.0"
num = { version = "0.3.0", features = ["serde"] }
paste = "1.0.1"
rand = "0.7.3"
Expand Down
32 changes: 28 additions & 4 deletions rust/xaynet-core/src/mask/masking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use crate::{
},
};

extern crate hist;

#[derive(Debug, Error, Eq, PartialEq)]
/// Errors related to the unmasking of models.
pub enum UnmaskingError {
Expand Down Expand Up @@ -198,13 +200,13 @@ impl Aggregation {
let order_1 = config_1.order();
let n = (masked_1 + &order_1 - mask_1) % &order_1;
let ratio = Ratio::<BigInt>::from(n.to_bigint().unwrap());
let scalar_sum = ratio / &exp_shift_1 - &scaled_add_shift_1;
let _scalar_sum = ratio / &exp_shift_1 - &scaled_add_shift_1;

// unmask global model
let scaled_add_shift_n = config_n.add_shift() * BigInt::from(self.nb_models);
let exp_shift_n = config_n.exp_shift();
let order_n = config_n.order();
masked_n
let unmasked: Model = masked_n
.into_iter()
.zip(mask_n)
.map(|(masked, mask)| {
Expand All @@ -222,10 +224,32 @@ impl Aggregation {
let ratio = Ratio::<BigInt>::from(n.to_bigint().unwrap());
let unmasked = ratio / &exp_shift_n - &scaled_add_shift_n;

// TEMP suppress
// scaling correction
unmasked / &scalar_sum
//unmasked / &scalar_sum

//let unmasked_int = unmasked.to_integer();
//println!("rounded unmasked int {}", unmasked_int.to_str_radix(10));
unmasked
})
.collect();

let approxs: Vec<i32> = unmasked
.iter()
.map(|x| {
let y = x.to_integer();
let digits = y.to_str_radix(10);
i32::from_str_radix(&digits, 10).unwrap()
})
.collect()
.collect();

println!("histogram for {:?}", approxs);

let h = hist::Hist::new(12, 3, &vec![0, 1, 2, 3], &approxs);
h.display();

//println!("unmasked model {:#?}", unmasked);
unmasked
}

/// Validates if aggregation of the aggregated mask object with the given `object` may be safely
Expand Down