From d64f17fde4d0ca12d98740a4dae6bd2bcf3023e1 Mon Sep 17 00:00:00 2001 From: dinahmaccodes Date: Fri, 9 May 2025 07:48:07 +0100 Subject: [PATCH] feat: improve aggregator --- src/aggregator.cairo | 45 ++++++++++++++++++++++++++++++++++++++- tests/test_contract.cairo | 17 ++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/aggregator.cairo b/src/aggregator.cairo index 12d8229..2d4d589 100644 --- a/src/aggregator.cairo +++ b/src/aggregator.cairo @@ -20,6 +20,7 @@ mod Agggregator { use cohort_4::counter::{ICounterDispatcher, ICounterDispatcherTrait}; use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; use starknet::ContractAddress; + use core::num::traits::Zero; use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; @@ -30,10 +31,48 @@ mod Agggregator { killswitch: ContractAddress, } + + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event { + CountIncreased: CountIncreased, + CounterCountIncreased: CounterCountIncreased, + CountDecreasedByOne: CountDecreasedByOne, + SwitchStatus: SwitchStatus, + } + + #[derive(Drop, Serde, starknet::Event)] + pub struct CountIncreased { + amount: u32, + } + + #[derive(Drop, Serde, starknet::Event)] + pub struct CounterCountIncreased { + amount: u32, + } + + #[derive(Drop, Serde, starknet::Event)] + pub struct CountDecreasedByOne { + pub previous_count: u32, + } + + #[derive(Drop, Serde, starknet::Event)] + pub struct SwitchStatus { + pub get_status: bool, + } + #[constructor] fn constructor(ref self: ContractState, counter: ContractAddress, killswitch: ContractAddress) { self.counter.write(counter); self.killswitch.write(killswitch); + //Assert address is not 0 + assert(!self.counter.read().is_zero(), 'Counter address cannot be 0'); + assert(!self.killswitch.read().is_zero(), 'KillSwitch address cannot be 0'); + //Assert count is not kill_switch + assert( + self.counter.read() != self.killswitch.read(), + 'Odd! Counter is KillSwitch' + ); } @@ -44,6 +83,7 @@ mod Agggregator { let counter = ICounterDispatcher { contract_address: self.counter.read() }; let counter_count = counter.get_count(); self.count.write(counter_count + amount); + self.emit(CountIncreased { amount }); } fn increase_counter_count(ref self: ContractState, amount: u32) { @@ -51,13 +91,15 @@ mod Agggregator { contract_address: self.killswitch.read(), }; assert(killswitch.get_status(), 'should be active'); - ICounterDispatcher { contract_address: self.counter.read() }.increase_count(amount) + ICounterDispatcher { contract_address: self.counter.read() }.increase_count(amount); + self.emit(CounterCountIncreased { amount }); } fn decrease_count_by_one(ref self: ContractState) { let current_count = self.get_count(); assert!(current_count != 0, "Amount cannot be 0"); self.count.write(current_count - 1); + self.emit(CountDecreasedByOne { previous_count: current_count }); } fn activate_switch(ref self: ContractState) { @@ -68,6 +110,7 @@ mod Agggregator { if !killswitch.get_status() { killswitch.switch() } + self.emit(SwitchStatus { get_status: true }); } fn get_count(self: @ContractState) -> u32 { diff --git a/tests/test_contract.cairo b/tests/test_contract.cairo index 71b3178..e9ffb8a 100644 --- a/tests/test_contract.cairo +++ b/tests/test_contract.cairo @@ -115,6 +115,7 @@ fn test_decrease_count_by_one_aggregator() { assert(count_after == 19, 'incorrect count'); } + #[test] fn test_increase_activate_switch() { let (_, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); @@ -130,7 +131,7 @@ fn test_increase_activate_switch() { #[test] #[should_panic(expect: 'Amount cannot be 0')] -fn test_increase_count_by_zero() { +fn test_increase_count_by_zero_aggregator() { let (_, _, aggregator_dispatcher) = deploy_contract(); let count_after = aggregator_dispatcher.get_count(); @@ -149,3 +150,17 @@ fn test_increase_counter_count_error() { aggregator_dispatcher.increase_counter_count(42); } + +#[test] +#[should_panic(expect: "Amount cannot be 0")] +fn test_decrease_zero_count_by_one_aggregator() { + let (_, _, aggregator_dispatcher) = deploy_contract(); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 0, 'invalid count'); + + aggregator_dispatcher.decrease_count_by_one(); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 19, 'incorrect count'); +}