From b8ead62fc469b324d2917547cc0a6e59d4587209 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Fri, 6 Feb 2026 23:39:19 +0100 Subject: [PATCH] Add more tests for Byzantine roubust aggregator --- discojs/package.json | 3 +- discojs/src/aggregator/byzantine.spec.ts | 208 +++++++++++++++++++++++ package-lock.json | 93 ++++++++++ package.json | 1 + 4 files changed, 304 insertions(+), 1 deletion(-) diff --git a/discojs/package.json b/discojs/package.json index 4e3b63a5f..6723eea10 100644 --- a/discojs/package.json +++ b/discojs/package.json @@ -34,6 +34,7 @@ "@tensorflow/tfjs-node": "4", "@types/simple-peer": "9", "nodemon": "3", - "ts-node": "10" + "ts-node": "10", + "fast-check": "^3" } } diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts index d300fbb4c..77de1b367 100644 --- a/discojs/src/aggregator/byzantine.spec.ts +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -1,3 +1,4 @@ +import fc from "fast-check"; import { Set } from "immutable"; import { describe, expect, it } from "vitest"; @@ -100,4 +101,211 @@ describe("ByzantineRobustAggregator", () => { const arr2 = await WSIntoArrays(out2); expect(arr2[0][0]).to.equal(20); }); + + it("applies momentum before aggregation", async () => { + const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0.5); + const [a, b] = ["a", "b"]; + agg.setNodes(Set.of(a, b)); + + const p1 = agg.getPromiseForAggregation(); + agg.add(a, WeightsContainer.of([0]), 0); + agg.add(b, WeightsContainer.of([10]), 0); + await p1; + + const p2 = agg.getPromiseForAggregation(); + agg.add(a, WeightsContainer.of([0]), 1); + agg.add(b, WeightsContainer.of([20]), 1); + const out = await p2; + + const arr = await WSIntoArrays(out); + expect(arr[0][0]).to.be.closeTo(7.5, 1e-6); // mean of (0, 15) + }); + + it("beta = 1 freezes aggregation after first round", async () => { + const agg = new ByzantineRobustAggregator(0, 1, 'absolute', 1e6, 1, 1); + const id = "c1"; + agg.setNodes(Set.of(id)); + + const p1 = agg.getPromiseForAggregation(); + agg.add(id, WeightsContainer.of([5]), 0); + await p1; + + const p2 = agg.getPromiseForAggregation(); + agg.add(id, WeightsContainer.of([100]), 1); + const out = await p2; + + const arr = await WSIntoArrays(out); + expect(arr[0][0]).to.equal(5); + }); + + it("remains robust with 30% Byzantine clients", async () => { + const honest = Array(7).fill(1); + const byzantine = Array(3).fill(100); + + const agg = new ByzantineRobustAggregator(0, 10, 'absolute', 1.0, 5, 0); + const ids = [...honest, ...byzantine].map((_, i) => `c${i}`); + agg.setNodes(Set(ids)); + + const p = agg.getPromiseForAggregation(); + honest.forEach((v, i) => agg.add(`c${i}`, WeightsContainer.of([v]), 0)); + byzantine.forEach((v, i) => agg.add(`c${i + honest.length}`, WeightsContainer.of([v]), 0)); + + const out = await p; + const arr = await WSIntoArrays(out); + + const honestMean = honest.reduce((a, b) => a + b, 0) / honest.length; + const rawMean = [...honest, ...byzantine].reduce((a, b) => a + b, 0) / (honest.length + byzantine.length); + + expect(Math.abs(arr[0][0] - honestMean)).to.be.lessThan(Math.abs(rawMean - honestMean)); + }); + + it("stays close to the honest signal under constant input", async () => { + const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 3, 0.9); + const ids = ["a", "b", "c"]; + agg.setNodes(Set(ids)); + + for (let r = 0; r < 10; r++) { + const p = agg.getPromiseForAggregation(); + ids.forEach(id => agg.add(id, WeightsContainer.of([1]), r)); + const out = await p; + const arr = await WSIntoArrays(out); + + expect(Math.abs(arr[0][0] - 1)).to.be.lessThan(0.3); + } + }); + + it("bounds the marginal influence of a single Byzantine client", async () => { + const clipRadius = 1.0; + + await fc.assert( + fc.asyncProperty( + fc.array(fc.double({ min: -1, max: 1 }), { minLength: 3, maxLength: 10 }), + async (honest) => { + const n = honest.length + 1; + + // aggregation without Byzantine + const aggClean = new ByzantineRobustAggregator(0, honest.length, "absolute", clipRadius, 1, 0); + const honestIds = honest.map((_, i) => `h${i}`); + aggClean.setNodes(Set(honestIds)); + + const pClean = aggClean.getPromiseForAggregation(); + honest.forEach((v, i) => aggClean.add(`h${i}`, WeightsContainer.of([v]), 0)); + const cleanOut = await pClean; + const clean = (await cleanOut.weights[0].data())[0]; + + // aggregation with Byzantine + const aggByz = new ByzantineRobustAggregator(0, n, "absolute", clipRadius, 1, 0); + const ids = honestIds.concat("byz"); + aggByz.setNodes(Set(ids)); + + const pByz = aggByz.getPromiseForAggregation(); + honest.forEach((v, i) => aggByz.add(`h${i}`, WeightsContainer.of([v]), 0)); + aggByz.add("byz", WeightsContainer.of([1e9]), 0); + const byzOut = await pByz; + const byz = (await byzOut.weights[0].data())[0]; + + const deviation = Math.abs(byz - clean); + const maxAllowed = 2 * clipRadius / n; // realistic tolerance for extreme inputs + expect(deviation).toBeLessThanOrEqual(maxAllowed); + } + ), + { numRuns: 200 } + ); + }); + + it("is invariant to client ordering", async () => { + const values = [0, 1, 100]; + const ids1 = ["a", "b", "c"]; + const ids2 = ["c", "a", "b"]; + + const run = async (ids: string[]) => { + const agg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 3, 0); + agg.setNodes(Set(ids)); + const p = agg.getPromiseForAggregation(); + ids.forEach((id, i) => + agg.add(id, WeightsContainer.of([values[i]]), 0) + ); + return (await (await p).weights[0].data())[0]; + }; + + const out1 = await run(ids1); + const out2 = await run(ids2); + + expect(out1).to.be.closeTo(out2, 1e-6); + }); + + it("is idempotent when all inputs are identical and within clipping radius", async () => { + const agg = new ByzantineRobustAggregator(0, 5, "absolute", 10.0, 5, 0); + const ids = ["a", "b", "c", "d", "e"]; + agg.setNodes(Set(ids)); + + const p = agg.getPromiseForAggregation(); + ids.forEach(id => agg.add(id, WeightsContainer.of([3.14]), 0)); + const out = await p; + + const v = (await out.weights[0].data())[0]; + expect(v).to.be.closeTo(3.14, 1e-6); + }); + + it("limits bias under symmetric Byzantine attacks", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["h1", "h2", "b1", "b2"])); + + const p = agg.getPromiseForAggregation(); + agg.add("h1", WeightsContainer.of([1]), 0); + agg.add("h2", WeightsContainer.of([1]), 0); + agg.add("b1", WeightsContainer.of([100]), 0); + agg.add("b2", WeightsContainer.of([-100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(Math.abs(v - 1)).to.be.lessThan(0.3); + }); + + it("output lies within the range of clipped inputs", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([0]), 0); + agg.add("b", WeightsContainer.of([0.5]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(v).to.be.greaterThanOrEqual(0); + expect(v).to.be.lessThanOrEqual(1); + }); + + it("single client cannot dominate aggregation", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["h1", "h2", "h3", "b"])); + + const p = agg.getPromiseForAggregation(); + agg.add("h1", WeightsContainer.of([0]), 0); + agg.add("h2", WeightsContainer.of([0]), 0); + agg.add("h3", WeightsContainer.of([0]), 0); + agg.add("b", WeightsContainer.of([1e9]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(Math.abs(v)).to.be.lessThan(0.5); + }); + + it("reset state when starting fresh aggregator", async () => { + const run = async () => { + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1.0, 3, 0.9); + agg.setNodes(Set(["a", "b"])); + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + return (await (await p).weights[0].data())[0]; + }; + + expect(await run()).to.be.closeTo(await run(), 1e-6); + }); }); diff --git a/package-lock.json b/package-lock.json index 258547a31..45d1af2e7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -29,6 +29,7 @@ "eslint": "9", "eslint-plugin-cypress": "5", "eslint-plugin-vue": "10", + "fast-check": "^4.5.3", "typescript": "5", "typescript-eslint": "8" } @@ -64,6 +65,7 @@ "devDependencies": { "@tensorflow/tfjs-node": "4", "@types/simple-peer": "9", + "fast-check": "^3", "nodemon": "3", "ts-node": "10" } @@ -99,6 +101,29 @@ "nodemon": "3" } }, + "discojs/node_modules/fast-check": { + "version": "3.23.2", + "resolved": "https://registry.npmjs.org/fast-check/-/fast-check-3.23.2.tgz", + "integrity": "sha512-h5+1OzzfCC3Ef7VbtKdcv7zsstUQwUDlYpUTvjeUsJAssPgLn7QzbboPtL5ro04Mq0rPOsMzl7q5hIbRs2wD1A==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT", + "dependencies": { + "pure-rand": "^6.1.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/@acemir/cssom": { "version": "0.9.28", "resolved": "https://registry.npmjs.org/@acemir/cssom/-/cssom-0.9.28.tgz", @@ -7913,6 +7938,46 @@ ], "license": "MIT" }, + "node_modules/fast-check": { + "version": "4.5.3", + "resolved": "https://registry.npmjs.org/fast-check/-/fast-check-4.5.3.tgz", + "integrity": "sha512-IE9csY7lnhxBnA8g/WI5eg/hygA6MGWJMSNfFRrBlXUciADEhS1EDB0SIsMSvzubzIlOBbVITSsypCsW717poA==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT", + "dependencies": { + "pure-rand": "^7.0.0" + }, + "engines": { + "node": ">=12.17.0" + } + }, + "node_modules/fast-check/node_modules/pure-rand": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/pure-rand/-/pure-rand-7.0.1.tgz", + "integrity": "sha512-oTUZM/NAZS8p7ANR3SHh30kXB+zK2r2BPcEn/awJIbOvq82WoMN4p62AWWp3Hhw50G0xMsw1mhIBLqHw64EcNQ==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT" + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -9642,6 +9707,7 @@ "os": [ "android" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9662,6 +9728,7 @@ "os": [ "darwin" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9682,6 +9749,7 @@ "os": [ "darwin" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9702,6 +9770,7 @@ "os": [ "freebsd" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9722,6 +9791,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9742,6 +9812,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9762,6 +9833,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9782,6 +9854,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9802,6 +9875,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9822,6 +9896,7 @@ "os": [ "win32" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9842,6 +9917,7 @@ "os": [ "win32" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -11500,6 +11576,23 @@ "node": ">=6" } }, + "node_modules/pure-rand": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/pure-rand/-/pure-rand-6.1.0.tgz", + "integrity": "sha512-bVWawvoZoBYpp6yIoQtQXHZjmz35RSVHnUOTefl8Vcjr8snTPY1wnpSPMWekcFwbxI6gtmT7rSYPFvz71ldiOA==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT" + }, "node_modules/qs": { "version": "6.14.1", "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", diff --git a/package.json b/package.json index 703afa3c5..84a7f697a 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "eslint": "9", "eslint-plugin-cypress": "5", "eslint-plugin-vue": "10", + "fast-check": "^4.5.3", "typescript": "5", "typescript-eslint": "8" }