From 8bc721227fc9024d80643e1c29661f6251a819e9 Mon Sep 17 00:00:00 2001 From: Awointa Date: Fri, 16 May 2025 12:07:57 +0100 Subject: [PATCH 1/2] tassk 4 --- src/IHello.cairo | 2 +- src/INumber.cairo | 2 +- src/IOwnable.cairo | 8 +++ src/aggregator.cairo | 118 ++++++++++++++++++++++++++++-- src/counter.cairo | 24 ++++++- src/hello.cairo | 6 +- src/killswitch.cairo | 16 ++++- src/lib.cairo | 11 +-- src/ownable.cairo | 47 ++++++++++++ tests/test_aggregator.cairo | 140 ++++++++++++++++++++++++++++++++++++ tests/test_contract.cairo | 15 ++-- tests/test_counter.cairo | 106 +++++++++++++++++++++++++++ tests/test_killswitch.cairo | 46 ++++++++++++ tests/test_ownable.cairo | 70 ++++++++++++++++++ 14 files changed, 585 insertions(+), 26 deletions(-) create mode 100644 src/IOwnable.cairo create mode 100644 src/ownable.cairo create mode 100644 tests/test_aggregator.cairo create mode 100644 tests/test_counter.cairo create mode 100644 tests/test_killswitch.cairo create mode 100644 tests/test_ownable.cairo diff --git a/src/IHello.cairo b/src/IHello.cairo index 0d1c511..df6f372 100644 --- a/src/IHello.cairo +++ b/src/IHello.cairo @@ -8,4 +8,4 @@ pub trait IHelloStarknet { fn get_balance(self: @TContractState) -> felt252; fn add_and_subtract(ref self: TContractState, amount: felt252); -} \ No newline at end of file +} diff --git a/src/INumber.cairo b/src/INumber.cairo index eb4014b..5a4f700 100644 --- a/src/INumber.cairo +++ b/src/INumber.cairo @@ -5,4 +5,4 @@ pub trait INumber { fn set_number(ref self: TContractState, amount: u8); /// Returns the current number fn get_number(self: @TContractState) -> u8; -} \ No newline at end of file +} diff --git a/src/IOwnable.cairo b/src/IOwnable.cairo new file mode 100644 index 0000000..adf8754 --- /dev/null +++ b/src/IOwnable.cairo @@ -0,0 +1,8 @@ +use starknet::ContractAddress; + +#[starknet::interface] +pub trait IOwnable { + fn set_owner(ref self: TContractState, new_owner: ContractAddress); + + fn get_owner(self: @TContractState) -> ContractAddress; +} diff --git a/src/aggregator.cairo b/src/aggregator.cairo index 3e4cbe8..fa0f622 100644 --- a/src/aggregator.cairo +++ b/src/aggregator.cairo @@ -16,11 +16,12 @@ pub trait IAggregator { /// Simple contract for managing count. #[starknet::contract] -mod Agggregator { +pub mod Agggregator { + use cohort_4::IOwnable::{IOwnableDispatcher, IOwnableDispatcherTrait}; use cohort_4::counter::{ICounterDispatcher, ICounterDispatcherTrait}; use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; - use starknet::ContractAddress; use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use starknet::{ContractAddress, get_caller_address}; #[storage] @@ -28,39 +29,117 @@ mod Agggregator { count: u32, counter: ContractAddress, killswitch: ContractAddress, + ownable_address: ContractAddress, } + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event { + increased_aggregator_count: increased_aggregator_count, + decreased_aggregator_count: decreased_aggregator_count, + } + + + #[derive(Drop, starknet::Event)] + pub struct increased_aggregator_count { + pub amount: u32, + pub caller_address: ContractAddress, + } + + + #[derive(Drop, starknet::Event)] + pub struct decreased_aggregator_count { + pub amount: u32, + pub caller_address: ContractAddress, + } + + #[constructor] - fn constructor(ref self: ContractState, counter: ContractAddress, killswitch: ContractAddress) { + fn constructor( + ref self: ContractState, + counter: ContractAddress, + killswitch: ContractAddress, + ownable_address: ContractAddress, + ) { + self.__validator(counter, killswitch, ownable_address); + self.counter.write(counter); self.killswitch.write(killswitch); + self.ownable_address.write(ownable_address); } #[abi(embed_v0)] impl AggregatorImpl of super::IAggregator { fn increase_count(ref self: ContractState, amount: u32) { + + let caller = get_caller_address(); + + // check if the caller has access + self.__has_access(); + assert(amount > 0, 'Amount cannot be 0'); let counter = ICounterDispatcher { contract_address: self.counter.read() }; + let counter_count = counter.get_count(); self.count.write(counter_count + amount); + + self + .emit( + Event::increased_aggregator_count( + increased_aggregator_count { + amount: amount, + caller_address: caller, + }, + ), + ); } fn increase_counter_count(ref self: ContractState, amount: u32) { let killswitch: IKillSwitchDispatcher = IKillSwitchDispatcher { contract_address: self.killswitch.read(), }; + + self.__has_access(); + + let caller = get_caller_address(); + assert(killswitch.get_status(), 'not active'); - ICounterDispatcher { contract_address: self.counter.read() }.increase_count(amount) + ICounterDispatcher { contract_address: self.counter.read() }.increase_count(amount); + + self + .emit( + Event::increased_aggregator_count( + increased_aggregator_count { + amount: amount, + caller_address: caller + }, + ), + ); } fn decrease_count_by_one(ref self: ContractState) { + self.__has_access(); + + let caller = get_caller_address(); + let current_count = self.get_count(); assert!(current_count != 0, "Amount cannot be 0"); self.count.write(current_count - 1); + + self + .emit( + Event::decreased_aggregator_count( + decreased_aggregator_count { + amount: 1, caller_address: caller, + }, + ), + ); } fn activate_switch(ref self: ContractState) { + self.__has_access(); + let killswitch: IKillSwitchDispatcher = IKillSwitchDispatcher { contract_address: self.killswitch.read(), }; @@ -74,4 +153,35 @@ mod Agggregator { self.count.read() } } + + #[generate_trait] + impl internalImpl of InternalTrait { + fn __has_access(self: @ContractState) { + let caller = get_caller_address(); + + let owner = IOwnableDispatcher { contract_address: self.ownable_address.read() } + .get_owner(); + + assert(caller == owner, 'Only owner can increase count'); + } + + fn __validator( + self: @ContractState, + counter: ContractAddress, + killswitch: ContractAddress, + ownable_address: ContractAddress, + ) { + // Check if addresses are not reused + assert(counter != killswitch && counter != ownable_address, 'counter address reused'); + + assert( + killswitch != ownable_address && killswitch != counter, 'killswitch address reused', + ); + + assert( + ownable_address != killswitch && ownable_address != counter, + 'ownable address reused', + ) + } + } } diff --git a/src/counter.cairo b/src/counter.cairo index 190c073..5c13e59 100644 --- a/src/counter.cairo +++ b/src/counter.cairo @@ -14,30 +14,50 @@ pub trait ICounter { /// Simple contract for managing count. #[starknet::contract] -mod Counter { +pub mod Counter { use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + #[storage] struct Storage { count: u32, } + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event { + CounterIncreased: CounterWasIncreased, + CounterDecreased: CounterWasDecreased, + } + + #[derive(Drop, starknet::Event)] + pub struct CounterWasIncreased { + pub amount: u32, + } + + #[derive(Drop, starknet::Event)] + pub struct CounterWasDecreased { + pub amount: u32, + } + #[abi(embed_v0)] impl CounterImpl of super::ICounter { fn increase_count(ref self: ContractState, amount: u32) { assert(amount > 0, 'Amount cannot be 0'); let counter_count = self.get_count(); self.count.write(counter_count + amount); + self.emit(Event::CounterIncreased(CounterWasIncreased { amount: amount })); } fn decrease_count_by_one(ref self: ContractState) { let current_count = self.get_count(); assert(current_count > 0, 'Amount cannot be 0'); self.count.write(current_count - 1); + self.emit(Event::CounterDecreased(CounterWasDecreased { amount: 1 })); } fn get_count(self: @ContractState) -> u32 { - self.count.read() + self.count.read() } } } diff --git a/src/hello.cairo b/src/hello.cairo index fd7d3e2..13c95fd 100644 --- a/src/hello.cairo +++ b/src/hello.cairo @@ -39,8 +39,8 @@ mod HelloStarknet { } fn add_and_subtract(ref self: ContractState, amount: felt252) { - self._add(amount); - self._subtract(amount); + self._add(amount); + self._subtract(amount); } } @@ -52,4 +52,4 @@ mod HelloStarknet { number } } -} \ No newline at end of file +} diff --git a/src/killswitch.cairo b/src/killswitch.cairo index 6437088..08fef3e 100644 --- a/src/killswitch.cairo +++ b/src/killswitch.cairo @@ -11,7 +11,7 @@ pub trait IKillSwitch { /// Simple contract for managing count. #[starknet::contract] -mod KillSwitch { +pub mod KillSwitch { use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; #[storage] @@ -19,12 +19,24 @@ mod KillSwitch { status: bool, } + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event{ + EventSwitched: EventSwitched + } + + #[derive(Drop, starknet::Event)] + pub struct EventSwitched{ + pub status: bool + } #[abi(embed_v0)] impl KillSwitchImpl of super::IKillSwitch { fn switch(ref self: ContractState) { - // assert(amount != 0, 'Amount cannot be 0'); self.status.write(!self.status.read()); + self.emit(Event::EventSwitched(EventSwitched{ + status: self.status.read() + })) } diff --git a/src/lib.cairo b/src/lib.cairo index 8b3993c..62819d0 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -1,22 +1,23 @@ -pub mod hello; pub mod IHello; pub mod INumber; +pub mod IOwnable; +pub mod aggregator; pub mod counter; +pub mod hello; pub mod killswitch; -pub mod aggregator; - +pub mod ownable; fn main() { // Function calls (Uncomment to execute them) // say_name("Sylvia Nnoruka!"); // intro_to_felt(); - + let num_1 = 5; let num_2 = 10; let sum = sum_num(num_1, num_2); println!("The sum of {} and {} is = {}", num_1, num_2, sum); - + // check_u16(6553); // Uncomment if needed is_greater_than_50(3); } diff --git a/src/ownable.cairo b/src/ownable.cairo new file mode 100644 index 0000000..38d295b --- /dev/null +++ b/src/ownable.cairo @@ -0,0 +1,47 @@ +use starknet::{ContractAddress, get_caller_address}; +use crate::IOwnable::IOwnable; + +#[starknet::contract] +mod Ownable { + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use super::*; + + + #[storage] + struct Storage { + owner: ContractAddress, + } + + #[constructor] + fn constructor(ref self: ContractState, owner: ContractAddress) { + self.__zero_checker(owner); + self.owner.write(owner); + } + + #[abi(embed_v0)] + impl OwnableImpl of IOwnable { + fn set_owner(ref self: ContractState, new_owner: ContractAddress) { + let caller = get_caller_address(); + assert(self.owner.read() == caller, 'Only owner can set new owner'); + + self.__zero_checker(new_owner); + + assert( + self.owner.read() != new_owner, 'New owner not valid address', + ); + self.owner.write(new_owner); + } + + fn get_owner(self: @ContractState) -> ContractAddress { + self.owner.read() + } + } + + #[generate_trait] + impl internalimpl of internalInterface { + fn __zero_checker(ref self: ContractState, some_address: ContractAddress) { + let zero_address = 'zero_address'.try_into().unwrap(); + assert(some_address != zero_address, 'Address cannot be zero'); + } + } +} diff --git a/tests/test_aggregator.cairo b/tests/test_aggregator.cairo new file mode 100644 index 0000000..5f6a9b3 --- /dev/null +++ b/tests/test_aggregator.cairo @@ -0,0 +1,140 @@ +use cohort_4::aggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait, IAggregatorSafeDispatcher, Agggregator}; +use cohort_4::IOwnable::{IOwnableDispatcher, IOwnableDispatcherTrait, IOwnableSafeDispatcher}; +use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; +use cohort_4::counter::{ICounterDispatcher,ICounterDispatcherTrait}; +use starknet::ContractAddress; +use snforge_std::{declare, ContractClassTrait, DeclareResultTrait, start_cheat_caller_address, spy_events, EventSpyAssertionsTrait}; + +use cohort_4::aggregator::Agggregator::{increased_aggregator_count, Event}; + + + +fn deploy_ownable_contract(initial_address: ContractAddress) -> (IOwnableDispatcher, ContractAddress, ) { + let contract = declare("Ownable").unwrap().contract_class(); + + // serialize call_data + let mut calldata: Array = array![]; + initial_address.serialize(ref calldata); + + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + let ownable = IOwnableDispatcher{contract_address}; + let safe_owner = IOwnableSafeDispatcher{contract_address}; + + return (ownable, contract_address); +} + +fn deploy_counter_contract() -> (ICounterDispatcher, ContractAddress) { + let countract_name: ByteArray = "Counter"; + let contract = declare(countract_name).unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + let counter = ICounterDispatcher{contract_address}; + return (counter, contract_address); +} + +fn deploy_killswitch_contract()-> (IKillSwitchDispatcher, ContractAddress){ + let contract = declare("KillSwitch").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@array![]).unwrap(); + let killswitch = IKillSwitchDispatcher{contract_address}; + + return (killswitch, contract_address); +} + +// deploy aggregator +fn deploy_aggregator() -> (IAggregatorDispatcher, IAggregatorSafeDispatcher, ContractAddress, ContractAddress) { + + let (ownable, ownable_address) = deploy_ownable_contract(OWNER()); + let owner_address = ownable.get_owner(); // Get the actual owner address + + let (_, counter_address) = deploy_counter_contract(); + let (_, killswitch_address) = deploy_killswitch_contract(); + + let contract = declare("Agggregator").unwrap().contract_class(); + + // serialize call data + let mut calldata: Array = array![]; + counter_address.serialize(ref calldata); + killswitch_address.serialize(ref calldata); + ownable_address.serialize(ref calldata); // Use the actual ownable contract address + + let(contract_address, _) = contract.deploy(@calldata).unwrap(); + let aggregator = IAggregatorDispatcher{contract_address}; + let safe_aggregator = IAggregatorSafeDispatcher{contract_address}; + + return (aggregator, safe_aggregator, contract_address, owner_address); +} + +fn OWNER() -> ContractAddress{ + 'OWNER'.try_into().unwrap() +} + +// fn ADDRESS_2() -> ContractAddress { +// let owner_address: ContractAddress = 'OWNER_1'.try_into().unwrap(); +// owner_address +// } + +// fn ADDRESS_3() -> ContractAddress { +// let owner_address: ContractAddress = 'OWNER_2'.try_into().unwrap(); +// owner_address +// } + +#[test] +fn test_get_count(){ + let (aggregator, _, _, _) = deploy_aggregator(); + + let aggregator_current_count = aggregator.get_count(); + assert(aggregator_current_count == 0, 'Aggregator count should be zero'); +} + +#[test] +fn test_aggregator_increase_count() { + // Get all values from the single deployment + let (aggregator, _, aggregator_address, owner_address) = deploy_aggregator(); + + let aggregator_count_1 = aggregator.get_count(); + assert(aggregator_count_1 == 0, 'Aggregator count should be zero'); + + // Use the owner address to make the call + start_cheat_caller_address(aggregator_address, owner_address); + + aggregator.increase_count(20); + + let aggregator_count_2 = aggregator.get_count(); + assert(aggregator_count_2 == 20, 'Aggregator count invalid'); +} + +// #[test] +// #[feature("safe_dispatcher")] +// fn test_aggregator_increase_count_by_zero(){ +// let (aggregator, safe_aggregator, aggregator_address, owner_address) = deploy_aggregator(); + +// let aggregator_count_1 = aggregator.get_count(); +// assert(aggregator_count_1== 0, 'Aggregator count should be zero'); + +// // impersonate owner +// start_cheat_caller_address(aggregator_address, owner_address); + +// match safe_aggregator.increase_count(0){ +// Result::Ok(_) => panic!("increasing by zero should panic"), +// Result::Err(panic_data) => assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)) +// } +// } + +// #[test] +// fn test_aggregator_increase_count_event(){ + +// let (aggregator, _, aggregator_address, owner_address) = deploy_aggregator(); +// let mut spy = spy_events(); + +// let aggregator_count_1 = aggregator.get_count(); +// assert(aggregator_count_1 == 0, 'Aggregator count should be zero'); + +// // Use the owner address to make the call +// start_cheat_caller_address(aggregator_address, owner_address); + +// aggregator.increase_count(20); + + +// spy.assert_emitted(@array![(owner_address, +// Event::increased_aggregator_count(increased_aggregator_count{ amount: 20, caller_address: owner_address}), +// ),]) +// } \ No newline at end of file diff --git a/tests/test_contract.cairo b/tests/test_contract.cairo index 9d6708a..d83d72d 100644 --- a/tests/test_contract.cairo +++ b/tests/test_contract.cairo @@ -1,9 +1,10 @@ +use cohort_4::counter::{ + ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, + ICounterSafeDispatcherTrait, +}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; use starknet::ContractAddress; -use snforge_std::{declare, ContractClassTrait, DeclareResultTrait}; - -use cohort_4::counter::{ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, ICounterSafeDispatcherTrait}; - fn deploy_contract() -> ContractAddress { let countract_name: ByteArray = "Counter"; let contract = declare(countract_name).unwrap().contract_class(); @@ -34,7 +35,7 @@ fn test_cannot_increase_balance_with_zero_value() { let dispatcher = ICounterDispatcher { contract_address }; let balance_before = dispatcher.get_count(); - assert(balance_before == 0 , 'Invalid balance'); + assert(balance_before == 0, 'Invalid balance'); let safe_dispatcher = ICounterSafeDispatcher { contract_address }; @@ -42,8 +43,6 @@ fn test_cannot_increase_balance_with_zero_value() { Result::Ok(_) => core::panic_with_felt252('Should have panicked'), Result::Err(panic_data) => { assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)); - } + }, }; - - } diff --git a/tests/test_counter.cairo b/tests/test_counter.cairo new file mode 100644 index 0000000..5547e7d --- /dev/null +++ b/tests/test_counter.cairo @@ -0,0 +1,106 @@ +use cohort_4::counter::{ + ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, ICounterSafeDispatcherTrait, Counter +}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare, start_cheat_caller_address, EventSpyAssertionsTrait, spy_events}; +use starknet::ContractAddress; +use cohort_4::counter::Counter::{CounterWasIncreased, CounterWasDecreased}; + +fn deploy() -> (ICounterDispatcher, ICounterSafeDispatcher, ContractAddress){ + let counter_countract_name: ByteArray = "Counter"; + let contract = declare(counter_countract_name).unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + let counter = ICounterDispatcher { contract_address}; + let safe_counter = ICounterSafeDispatcher{contract_address}; + return (counter, safe_counter, contract_address); +} + + + +#[test] +fn test_initialized_count(){ + let (counter, _, _) = deploy(); + let count_1 = counter.get_count(); + assert(count_1 == 0, 'invalid count value'); +} + +#[test] +fn test_increase_count(){ + let (counter, _, _) = deploy(); + let count_1 = counter.get_count(); + assert(count_1 == 0, 'invalid count value'); + + let amount = 30; + counter.increase_count(amount); + let count_2 = counter.get_count(); + assert(count_2 == amount, 'Invalid amount'); + assert(count_2 != 10, 'invalid 10') +} + +#[test] +#[feature("safe_dispatcher")] +fn test_increase_count_by_zero_amount(){ + let (counter,safe_counter, _) = deploy(); + + let count_1 = counter.get_count(); + assert(count_1 == 0, 'invalid count value'); + + + match safe_counter.increase_count(0){ + Result::Ok(_) => panic!("amount passed not zero"), + Result::Err(panic_data) => { + assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)) + } + } +} + +#[test] +fn test_for_increase_emit_event(){ + let (counter, _, contract_address) = deploy(); + + let mut spy = spy_events(); + + counter.increase_count(2); + + spy.assert_emitted( + @array![ + ( contract_address, + Counter::Event::CounterIncreased( + CounterWasIncreased { + amount: 2 + } + ), + ), + ], + ); +} + +#[test] +fn test_for_decrease_count_by_one(){ + let (counter, _, _) = deploy(); + + + let count_1 = counter.get_count(); + assert(count_1 == 0, 'Invalid count value not zero'); + + counter.increase_count(2); + + counter.decrease_count_by_one(); + let count_2 = counter.get_count(); + assert(count_1 == count_2-1, 'Invalid count value'); +} + +#[test] +fn test_emit_decrease_count_event(){ + let (counter,_, contract_address) = deploy(); + + let mut spy = spy_events(); + counter.increase_count(2); + + + counter.decrease_count_by_one(); + + spy.assert_emitted(@array![(contract_address, Counter::Event::CounterDecreased(CounterWasDecreased{ + amount: 1 + }))]) + +} \ No newline at end of file diff --git a/tests/test_killswitch.cairo b/tests/test_killswitch.cairo new file mode 100644 index 0000000..934bd8b --- /dev/null +++ b/tests/test_killswitch.cairo @@ -0,0 +1,46 @@ +use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait, KillSwitch}; +use cohort_4::killswitch::KillSwitch::{EventSwitched}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare, spy_events, EventSpyAssertionsTrait}; +use starknet::ContractAddress; + +fn deploy()-> (IKillSwitchDispatcher, ContractAddress){ + let contract = declare("KillSwitch").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@array![]).unwrap(); + let killswitch = IKillSwitchDispatcher{contract_address}; + + (killswitch, contract_address) +} + +#[test] +fn test_get_status(){ + let (killswitch, _)= deploy(); + + let current_status = killswitch.get_status(); + + assert(!current_status, 'status is not valid'); +} + +#[test] +fn test_switch(){ + let (killswitch, _) = deploy(); + + let prev_status = killswitch.get_status(); + killswitch.switch(); + + assert(!prev_status, 'Status did not change'); +} + +#[test] +fn test_switch_event(){ + let (killswitch, contract_address) = deploy(); + + + let mut spy = spy_events(); + + killswitch.switch(); + + + spy.assert_emitted(@array![(contract_address, + KillSwitch::Event::EventSwitched( EventSwitched{ status: true}), + ),]) +} \ No newline at end of file diff --git a/tests/test_ownable.cairo b/tests/test_ownable.cairo new file mode 100644 index 0000000..0aa8fe6 --- /dev/null +++ b/tests/test_ownable.cairo @@ -0,0 +1,70 @@ +use cohort_4::IOwnable::IOwnableSafeDispatcherTrait; +use cohort_4::IOwnable::{IOwnableDispatcher, IOwnableDispatcherTrait, IOwnableSafeDispatcher}; +use snforge_std::{ + ContractClassTrait, DeclareResultTrait, declare, start_cheat_caller_address, + stop_cheat_caller_address, +}; +use starknet::ContractAddress; +// create deploy utils function +fn deploy(initial_address: ContractAddress) -> (IOwnableDispatcher, IOwnableSafeDispatcher, ContractAddress, ContractAddress) { + // let owner_address: ContractAddress = 'ownable'.try_into().unwrap(); + + let contract = declare("Ownable").unwrap().contract_class(); + + // serialize call_data + let mut calldata: Array = array![]; + initial_address.serialize(ref calldata); + + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + let owner = IOwnableDispatcher { contract_address}; + let safe_owner = IOwnableSafeDispatcher{contract_address}; + + return (owner, safe_owner, initial_address, contract_address); +} + +fn OWNER() -> ContractAddress{ + 'OWNER'.try_into().unwrap() +} + +fn OWNER_1() -> ContractAddress { + let owner_address: ContractAddress = 'OWNER_1'.try_into().unwrap(); + owner_address +} + +fn OWNER_2() -> ContractAddress { + let owner_address: ContractAddress = 'OWNER_2'.try_into().unwrap(); + owner_address +} + +#[test] +fn test_ownable_get_owner() { + let (owner, _, owner_address, _) = deploy(OWNER()); + + let current_owner = owner.get_owner(); + assert(owner_address == current_owner, 'got wrong owner address'); +} + +#[test] +fn test_ownable_set_owner(){ + let (owner, _,owner_address,ownable_address) = deploy(OWNER()); + + start_cheat_caller_address(ownable_address, owner_address); + + owner.set_owner(OWNER_2()); + + let updated_owner = owner.get_owner(); + + assert(OWNER_2() == updated_owner, 'Did not update owner'); + +} + +#[test] +#[feature("safe_dispatcher")] +fn test_set_owner_to_previous_owner(){ + let (_,safe_owner, _, _) = deploy(OWNER()); + + match safe_owner.set_owner(OWNER()){ + Result::Ok(_) => panic!("setting same owner should panic"), + Result::Err(panic_data) => assert(*panic_data.at(0) == 'Only owner can set new owner', *panic_data.at(0)) + } +} From cc4d189c5e2129cc0145b114de6a2bede8e0759d Mon Sep 17 00:00:00 2001 From: Awointa Date: Fri, 16 May 2025 13:24:31 +0100 Subject: [PATCH 2/2] aggregator test --- tests/test_aggregator.cairo | 304 ++++++++++++++++++++++++------------ 1 file changed, 202 insertions(+), 102 deletions(-) diff --git a/tests/test_aggregator.cairo b/tests/test_aggregator.cairo index 5f6a9b3..8c5a602 100644 --- a/tests/test_aggregator.cairo +++ b/tests/test_aggregator.cairo @@ -1,140 +1,240 @@ -use cohort_4::aggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait, IAggregatorSafeDispatcher, Agggregator}; -use cohort_4::IOwnable::{IOwnableDispatcher, IOwnableDispatcherTrait, IOwnableSafeDispatcher}; +use cohort_4::aggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait}; +use cohort_4::counter::{ + ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, + ICounterSafeDispatcherTrait, +}; use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; -use cohort_4::counter::{ICounterDispatcher,ICounterDispatcherTrait}; +use cohort_4::IOwnable::{IOwnableDispatcher, IOwnableDispatcherTrait}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare, start_cheat_caller_address, stop_cheat_caller_address}; use starknet::ContractAddress; -use snforge_std::{declare, ContractClassTrait, DeclareResultTrait, start_cheat_caller_address, spy_events, EventSpyAssertionsTrait}; - -use cohort_4::aggregator::Agggregator::{increased_aggregator_count, Event}; +use starknet::contract_address_const; +fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IOwnableDispatcher, IAggregatorDispatcher , ContractAddress) { + + // owner address + let owner_address: ContractAddress = contract_address_const::<'1'>(); + + + //counter deployment + let counter_countract_name: ByteArray = "Counter"; + let contract = declare(counter_countract_name).unwrap().contract_class(); + let (counter_contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + let counter_dispatcher = ICounterDispatcher { contract_address: counter_contract_address }; + + //killswitch deployment + let killswitch_contract_name: ByteArray = "KillSwitch"; + let killswitch_contract = declare(killswitch_contract_name).unwrap().contract_class(); + let (killswitch_contract_address, _) = killswitch_contract.deploy(@ArrayTrait::new()).unwrap(); + let killswitch_dispatcher = IKillSwitchDispatcher { + contract_address: killswitch_contract_address, + }; + + //ownable deployment + let ownable_contract_name: ByteArray = "Ownable"; + let ownable_contract = declare(ownable_contract_name).unwrap().contract_class(); + // Deploy with owner_address as the initial owner + let (ownable_contract_address, _) = ownable_contract + .deploy(@array![owner_address.into()]) + .unwrap(); + let ownable_dispatcher = IOwnableDispatcher { + contract_address: ownable_contract_address, + }; + + //aggregator deployment + let aggregator = declare("Agggregator").unwrap().contract_class(); + let (aggregator_contract_address, _) = aggregator + .deploy(@array![counter_contract_address.into(), killswitch_contract_address.into(), ownable_contract_address.into()]) + .unwrap(); + + let aggregator_dispatcher = IAggregatorDispatcher { + contract_address: aggregator_contract_address, + }; + + (counter_dispatcher, killswitch_dispatcher, ownable_dispatcher, aggregator_dispatcher, owner_address) +} -fn deploy_ownable_contract(initial_address: ContractAddress) -> (IOwnableDispatcher, ContractAddress, ) { - let contract = declare("Ownable").unwrap().contract_class(); +#[test] +fn test_increase_count() { + let (counter_dispatcher, _, _, _, _) = deploy_contract(); - // serialize call_data - let mut calldata: Array = array![]; - initial_address.serialize(ref calldata); + let balance_before = counter_dispatcher.get_count(); + assert(balance_before == 0, 'Invalid balance'); - let (contract_address, _) = contract.deploy(@calldata).unwrap(); - let ownable = IOwnableDispatcher{contract_address}; - let safe_owner = IOwnableSafeDispatcher{contract_address}; + counter_dispatcher.increase_count(42); - return (ownable, contract_address); + let balance_after = counter_dispatcher.get_count(); + assert(balance_after == 42, 'Invalid balance'); } -fn deploy_counter_contract() -> (ICounterDispatcher, ContractAddress) { - let countract_name: ByteArray = "Counter"; - let contract = declare(countract_name).unwrap().contract_class(); - let (contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); - let counter = ICounterDispatcher{contract_address}; - return (counter, contract_address); -} +#[test] +fn test_increase_count_aggregator_as_owner() { + let (_, _, _, aggregator_dispatcher, owner_address) = deploy_contract(); -fn deploy_killswitch_contract()-> (IKillSwitchDispatcher, ContractAddress){ - let contract = declare("KillSwitch").unwrap().contract_class(); - let (contract_address, _) = contract.deploy(@array![]).unwrap(); - let killswitch = IKillSwitchDispatcher{contract_address}; + + + + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); + + + - return (killswitch, contract_address); + let balance_before = aggregator_dispatcher.get_count(); + assert(balance_before == 0, 'Invalid balance'); + + + aggregator_dispatcher.increase_count(42); + + let balance_after = aggregator_dispatcher.get_count(); + assert(balance_after == 42, 'Invalid balance'); + + + stop_cheat_caller_address(aggregator_dispatcher.contract_address,); } -// deploy aggregator -fn deploy_aggregator() -> (IAggregatorDispatcher, IAggregatorSafeDispatcher, ContractAddress, ContractAddress) { +#[test] +#[should_panic(expected: 'Only owner can increase count')] +fn test_increase_count_aggregator_as_non_owner() { + let (_, _, _, aggregator_dispatcher, _) = deploy_contract(); + + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + - let (ownable, ownable_address) = deploy_ownable_contract(OWNER()); - let owner_address = ownable.get_owner(); // Get the actual owner address + start_cheat_caller_address(aggregator_dispatcher.contract_address, non_owner_address); + + + aggregator_dispatcher.increase_count(42); + + + let balance_after = aggregator_dispatcher.get_count(); + assert(balance_after == 0, 'you are not the owner'); - let (_, counter_address) = deploy_counter_contract(); - let (_, killswitch_address) = deploy_killswitch_contract(); + stop_cheat_caller_address(aggregator_dispatcher.contract_address,); +} + +#[test] +#[feature("safe_dispatcher")] +fn test_cannot_increase_balance_with_zero_value() { + let (counter_dispatcher, _, _, _, _) = deploy_contract(); + + let balance_before = counter_dispatcher.get_count(); + assert(balance_before == 0, 'Invalid balance'); + + let safe_dispatcher = ICounterSafeDispatcher { + contract_address: counter_dispatcher.contract_address, + }; + + match safe_dispatcher.increase_count(0) { + Result::Ok(_) => core::panic_with_felt252('Should have panicked'), + Result::Err(panic_data) => { + assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)); + }, + }; +} + + +#[test] +fn test_increase_counter_count() { + let (counter_dispatcher, _, _, aggregator_dispatcher, owner_address) = deploy_contract(); - let contract = declare("Agggregator").unwrap().contract_class(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); + - // serialize call data - let mut calldata: Array = array![]; - counter_address.serialize(ref calldata); - killswitch_address.serialize(ref calldata); - ownable_address.serialize(ref calldata); // Use the actual ownable contract address + let initial_count = counter_dispatcher.get_count(); + assert(initial_count == 0, 'invalid initial count'); - let(contract_address, _) = contract.deploy(@calldata).unwrap(); - let aggregator = IAggregatorDispatcher{contract_address}; - let safe_aggregator = IAggregatorSafeDispatcher{contract_address}; + + aggregator_dispatcher.activate_switch(); + + aggregator_dispatcher.increase_counter_count(42); - return (aggregator, safe_aggregator, contract_address, owner_address); + let final_count = counter_dispatcher.get_count(); + assert(final_count == 42, 'counter not increased correctly'); } -fn OWNER() -> ContractAddress{ - 'OWNER'.try_into().unwrap() +#[test] +fn test_decrease_count_by_one_aggregator() { + let (counter_dispatcher, _, _, aggregator_dispatcher, owner_address) = deploy_contract(); + + + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); + let count_before = aggregator_dispatcher.get_count(); + assert(count_before == 0, 'invalid count'); + + + + aggregator_dispatcher.increase_count(20); + aggregator_dispatcher.decrease_count_by_one(); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 19, 'incorrect count'); } -// fn ADDRESS_2() -> ContractAddress { -// let owner_address: ContractAddress = 'OWNER_1'.try_into().unwrap(); -// owner_address -// } +#[test] +fn test_increase_activate_switch() { + let (_, killswitch_dispatcher, _, aggregator_dispatcher, owner_address) = deploy_contract(); + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); + let status = killswitch_dispatcher.get_status(); + assert(!status, 'failed'); + + aggregator_dispatcher.activate_switch(); + + stop_cheat_caller_address(aggregator_dispatcher.contract_address,); + + let status_after = killswitch_dispatcher.get_status(); + assert(status_after, 'invalid status'); +} -// fn ADDRESS_3() -> ContractAddress { -// let owner_address: ContractAddress = 'OWNER_2'.try_into().unwrap(); -// owner_address -// } #[test] -fn test_get_count(){ - let (aggregator, _, _, _) = deploy_aggregator(); +#[should_panic(expect: 'you are not the owner' )] +fn test_increase_activate_switch_non_owner() { + + let (_, killswitch_dispatcher, _, aggregator_dispatcher, _) = deploy_contract(); + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, non_owner_address); + let status = killswitch_dispatcher.get_status(); + assert(!status, 'failed'); + + aggregator_dispatcher.activate_switch(); + + stop_cheat_caller_address(aggregator_dispatcher.contract_address,); + + let status_after = killswitch_dispatcher.get_status(); + assert(status_after, 'invalid status'); - let aggregator_current_count = aggregator.get_count(); - assert(aggregator_current_count == 0, 'Aggregator count should be zero'); } #[test] -fn test_aggregator_increase_count() { - // Get all values from the single deployment - let (aggregator, _, aggregator_address, owner_address) = deploy_aggregator(); +#[should_panic(expect: 'Amount cannot be 0')] +fn test_increase_count_by_zero() { + let (_, killswitch_dispatcher, _, aggregator_dispatcher, _) = deploy_contract(); - let aggregator_count_1 = aggregator.get_count(); - assert(aggregator_count_1 == 0, 'Aggregator count should be zero'); - - // Use the owner address to make the call - start_cheat_caller_address(aggregator_address, owner_address); - - aggregator.increase_count(20); - - let aggregator_count_2 = aggregator.get_count(); - assert(aggregator_count_2 == 20, 'Aggregator count invalid'); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 0, 'incorrect count'); + + aggregator_dispatcher.increase_count(0); } -// #[test] -// #[feature("safe_dispatcher")] -// fn test_aggregator_increase_count_by_zero(){ -// let (aggregator, safe_aggregator, aggregator_address, owner_address) = deploy_aggregator(); +#[test] +#[should_panic(expect: 'should be active')] +fn test_increase_counter_count_error() { + let (_, killswitch_dispatcher, _, aggregator_dispatcher, _) = deploy_contract(); -// let aggregator_count_1 = aggregator.get_count(); -// assert(aggregator_count_1== 0, 'Aggregator count should be zero'); + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); -// // impersonate owner -// start_cheat_caller_address(aggregator_address, owner_address); - -// match safe_aggregator.increase_count(0){ -// Result::Ok(_) => panic!("increasing by zero should panic"), -// Result::Err(panic_data) => assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)) -// } -// } - -// #[test] -// fn test_aggregator_increase_count_event(){ - -// let (aggregator, _, aggregator_address, owner_address) = deploy_aggregator(); -// let mut spy = spy_events(); + start_cheat_caller_address(aggregator_dispatcher.contract_address, non_owner_address); -// let aggregator_count_1 = aggregator.get_count(); -// assert(aggregator_count_1 == 0, 'Aggregator count should be zero'); - -// // Use the owner address to make the call -// start_cheat_caller_address(aggregator_address, owner_address); - -// aggregator.increase_count(20); + let status_before = killswitch_dispatcher.get_status(); + assert(status_before, 'invalid status'); - -// spy.assert_emitted(@array![(owner_address, -// Event::increased_aggregator_count(increased_aggregator_count{ amount: 20, caller_address: owner_address}), -// ),]) -// } \ No newline at end of file + aggregator_dispatcher.increase_counter_count(42); +} \ No newline at end of file