From 107eb9b10ba3180df05c313104f0c6d29efcec6b Mon Sep 17 00:00:00 2001 From: enyinnaya1234 Date: Thu, 8 May 2025 18:04:11 +0100 Subject: [PATCH 1/3] test: validate aggregator logic --- .tool-versions | 2 + src/aggregator.cairo | 2 +- tests/test_contract.cairo | 139 +++++++++++++++++++++++++++++++++----- 3 files changed, 125 insertions(+), 18 deletions(-) create mode 100644 .tool-versions diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..9f71edc --- /dev/null +++ b/.tool-versions @@ -0,0 +1,2 @@ +scarb 2.11.4 +starknet-foundry 0.40.0 diff --git a/src/aggregator.cairo b/src/aggregator.cairo index 3e4cbe8..12d8229 100644 --- a/src/aggregator.cairo +++ b/src/aggregator.cairo @@ -50,7 +50,7 @@ mod Agggregator { let killswitch: IKillSwitchDispatcher = IKillSwitchDispatcher { contract_address: self.killswitch.read(), }; - assert(killswitch.get_status(), 'not active'); + assert(killswitch.get_status(), 'should be active'); ICounterDispatcher { contract_address: self.counter.read() }.increase_count(amount) } diff --git a/tests/test_contract.cairo b/tests/test_contract.cairo index 9d6708a..20d224e 100644 --- a/tests/test_contract.cairo +++ b/tests/test_contract.cairo @@ -1,42 +1,62 @@ -use starknet::ContractAddress; - use snforge_std::{declare, ContractClassTrait, DeclareResultTrait}; use cohort_4::counter::{ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, ICounterSafeDispatcherTrait}; +use cohort_4::aggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait}; +use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; + -fn deploy_contract() -> ContractAddress { - let countract_name: ByteArray = "Counter"; - let contract = declare(countract_name).unwrap().contract_class(); - let (contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); - contract_address +fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IAggregatorDispatcher) { + //counter deployment + let counter_countract_name: ByteArray = "Counter"; + let contract = declare(counter_countract_name).unwrap().contract_class(); + let (counter_contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + let counter_dispatcher = ICounterDispatcher{ + contract_address: counter_contract_address + }; + + //killswitch deployment + let killswitch_contract_name: ByteArray = "KillSwitch"; + let killswitch_contract = declare(killswitch_contract_name).unwrap().contract_class(); + let (killswitch_contract_address, _) = killswitch_contract.deploy(@ArrayTrait::new()).unwrap(); + let killswitch_dispatcher = IKillSwitchDispatcher{ + contract_address: killswitch_contract_address + }; + + //aggregator deployment + let aggregator = declare("Agggregator").unwrap().contract_class(); + let (aggregator_contract_address, _) = aggregator.deploy(@array![counter_contract_address.into(), killswitch_contract_address.into()]).unwrap(); + + let aggregator_dispatcher = IAggregatorDispatcher{ + contract_address: aggregator_contract_address + }; + + (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) } #[test] fn test_increase_count() { - let contract_address = deploy_contract(); + let (counter_dispatcher, _, _) = deploy_contract(); - let dispatcher = ICounterDispatcher { contract_address }; + - let balance_before = dispatcher.get_count(); + let balance_before = counter_dispatcher.get_count(); assert(balance_before == 0, 'Invalid balance'); - dispatcher.increase_count(42); + counter_dispatcher.increase_count(42); - let balance_after = dispatcher.get_count(); + let balance_after = counter_dispatcher.get_count(); assert(balance_after == 42, 'Invalid balance'); } #[test] #[feature("safe_dispatcher")] fn test_cannot_increase_balance_with_zero_value() { - let contract_address = deploy_contract(); - - let dispatcher = ICounterDispatcher { contract_address }; + let (counter_dispatcher, _, _) = deploy_contract(); - let balance_before = dispatcher.get_count(); + let balance_before = counter_dispatcher.get_count(); assert(balance_before == 0 , 'Invalid balance'); - let safe_dispatcher = ICounterSafeDispatcher { contract_address }; + let safe_dispatcher = ICounterSafeDispatcher { contract_address: counter_dispatcher.contract_address }; match safe_dispatcher.increase_count(0) { Result::Ok(_) => core::panic_with_felt252('Should have panicked'), @@ -47,3 +67,88 @@ fn test_cannot_increase_balance_with_zero_value() { } + +#[test] +fn test_increase_count_aggregator() { + let (_, _, aggregator_dispatcher)= deploy_contract(); + + let balance_before = aggregator_dispatcher.get_count(); + assert(balance_before == 0, 'Invalid balance'); + + aggregator_dispatcher.increase_count(42); + + let balance_after = aggregator_dispatcher.get_count(); + assert(balance_after == 42, 'Invalid balance'); +} + +#[test] +fn test_increase_counter_count_aggregator() { + let (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); + + let count_1 = counter_dispatcher.get_count(); + assert(count_1 == 0, 'invalid count 1'); + + let status_before = killswitch_dispatcher.get_status(); + assert(!status_before, 'incorrect status'); + + aggregator_dispatcher.activate_switch(); + let status_after = killswitch_dispatcher.get_status(); + assert(status_after, 'failed to activate'); + + aggregator_dispatcher.increase_counter_count(42); + + let count_2 = counter_dispatcher.get_count(); + assert(count_2 == 42, 'invalid count 2'); +} + +#[test] +fn test_decrease_count_by_one_aggregator() { + let (_, _, aggregator_dispatcher) = deploy_contract(); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 0, 'invalid count'); + + aggregator_dispatcher.increase_count(20); + aggregator_dispatcher.decrease_count_by_one(); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 19, 'incorrect count'); + +} + +#[test] +fn test_increase_activate_switch() { + let (_, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); + + let status = killswitch_dispatcher.get_status(); + assert(!status, 'failed'); + + aggregator_dispatcher.activate_switch(); + + let status_after = killswitch_dispatcher.get_status(); + assert(status_after, 'invalid status'); + +} + +#[test] +#[should_panic(expect: 'Amount cannot be 0')] +fn test_increase_count_by_zero(){ + let (_, _, aggregator_dispatcher) = deploy_contract(); + + let count_after = aggregator_dispatcher.get_count(); + assert(count_after == 0, 'incorrect count'); + + aggregator_dispatcher.increase_count(0); +} + +#[test] +#[should_panic(expect: 'should be active')] +fn test_increase_counter_count_error() { + let (_, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); + + let status_before = killswitch_dispatcher.get_status(); + assert(status_before, 'invalid status'); + + aggregator_dispatcher.increase_counter_count(42); + +} \ No newline at end of file From f8963aeac2a3b6138af62f639728b2af85103250 Mon Sep 17 00:00:00 2001 From: enyinnaya1234 Date: Thu, 8 May 2025 18:07:23 +0100 Subject: [PATCH 2/3] chore: run scarb fmt --- src/IHello.cairo | 2 +- src/INumber.cairo | 2 +- src/counter.cairo | 2 +- src/hello.cairo | 6 ++-- src/lib.cairo | 9 +++--- tests/test_contract.cairo | 61 +++++++++++++++++++-------------------- 6 files changed, 39 insertions(+), 43 deletions(-) diff --git a/src/IHello.cairo b/src/IHello.cairo index 0d1c511..df6f372 100644 --- a/src/IHello.cairo +++ b/src/IHello.cairo @@ -8,4 +8,4 @@ pub trait IHelloStarknet { fn get_balance(self: @TContractState) -> felt252; fn add_and_subtract(ref self: TContractState, amount: felt252); -} \ No newline at end of file +} diff --git a/src/INumber.cairo b/src/INumber.cairo index eb4014b..5a4f700 100644 --- a/src/INumber.cairo +++ b/src/INumber.cairo @@ -5,4 +5,4 @@ pub trait INumber { fn set_number(ref self: TContractState, amount: u8); /// Returns the current number fn get_number(self: @TContractState) -> u8; -} \ No newline at end of file +} diff --git a/src/counter.cairo b/src/counter.cairo index 190c073..ced698b 100644 --- a/src/counter.cairo +++ b/src/counter.cairo @@ -37,7 +37,7 @@ mod Counter { } fn get_count(self: @ContractState) -> u32 { - self.count.read() + self.count.read() } } } diff --git a/src/hello.cairo b/src/hello.cairo index fd7d3e2..13c95fd 100644 --- a/src/hello.cairo +++ b/src/hello.cairo @@ -39,8 +39,8 @@ mod HelloStarknet { } fn add_and_subtract(ref self: ContractState, amount: felt252) { - self._add(amount); - self._subtract(amount); + self._add(amount); + self._subtract(amount); } } @@ -52,4 +52,4 @@ mod HelloStarknet { number } } -} \ No newline at end of file +} diff --git a/src/lib.cairo b/src/lib.cairo index 8b3993c..7df1810 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -1,22 +1,21 @@ -pub mod hello; pub mod IHello; pub mod INumber; +pub mod aggregator; pub mod counter; +pub mod hello; pub mod killswitch; -pub mod aggregator; - fn main() { // Function calls (Uncomment to execute them) // say_name("Sylvia Nnoruka!"); // intro_to_felt(); - + let num_1 = 5; let num_2 = 10; let sum = sum_num(num_1, num_2); println!("The sum of {} and {} is = {}", num_1, num_2, sum); - + // check_u16(6553); // Uncomment if needed is_greater_than_50(3); } diff --git a/tests/test_contract.cairo b/tests/test_contract.cairo index 20d224e..71b3178 100644 --- a/tests/test_contract.cairo +++ b/tests/test_contract.cairo @@ -1,8 +1,10 @@ -use snforge_std::{declare, ContractClassTrait, DeclareResultTrait}; - -use cohort_4::counter::{ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, ICounterSafeDispatcherTrait}; use cohort_4::aggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait}; +use cohort_4::counter::{ + ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, + ICounterSafeDispatcherTrait, +}; use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IAggregatorDispatcher) { @@ -10,24 +12,24 @@ fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IAggregatorD let counter_countract_name: ByteArray = "Counter"; let contract = declare(counter_countract_name).unwrap().contract_class(); let (counter_contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); - let counter_dispatcher = ICounterDispatcher{ - contract_address: counter_contract_address - }; - - //killswitch deployment + let counter_dispatcher = ICounterDispatcher { contract_address: counter_contract_address }; + + //killswitch deployment let killswitch_contract_name: ByteArray = "KillSwitch"; let killswitch_contract = declare(killswitch_contract_name).unwrap().contract_class(); let (killswitch_contract_address, _) = killswitch_contract.deploy(@ArrayTrait::new()).unwrap(); - let killswitch_dispatcher = IKillSwitchDispatcher{ - contract_address: killswitch_contract_address + let killswitch_dispatcher = IKillSwitchDispatcher { + contract_address: killswitch_contract_address, }; //aggregator deployment let aggregator = declare("Agggregator").unwrap().contract_class(); - let (aggregator_contract_address, _) = aggregator.deploy(@array![counter_contract_address.into(), killswitch_contract_address.into()]).unwrap(); + let (aggregator_contract_address, _) = aggregator + .deploy(@array![counter_contract_address.into(), killswitch_contract_address.into()]) + .unwrap(); - let aggregator_dispatcher = IAggregatorDispatcher{ - contract_address: aggregator_contract_address + let aggregator_dispatcher = IAggregatorDispatcher { + contract_address: aggregator_contract_address, }; (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) @@ -35,9 +37,7 @@ fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IAggregatorD #[test] fn test_increase_count() { - let (counter_dispatcher, _, _) = deploy_contract(); - - + let (counter_dispatcher, _, _) = deploy_contract(); let balance_before = counter_dispatcher.get_count(); assert(balance_before == 0, 'Invalid balance'); @@ -51,26 +51,26 @@ fn test_increase_count() { #[test] #[feature("safe_dispatcher")] fn test_cannot_increase_balance_with_zero_value() { - let (counter_dispatcher, _, _) = deploy_contract(); + let (counter_dispatcher, _, _) = deploy_contract(); let balance_before = counter_dispatcher.get_count(); - assert(balance_before == 0 , 'Invalid balance'); + assert(balance_before == 0, 'Invalid balance'); - let safe_dispatcher = ICounterSafeDispatcher { contract_address: counter_dispatcher.contract_address }; + let safe_dispatcher = ICounterSafeDispatcher { + contract_address: counter_dispatcher.contract_address, + }; match safe_dispatcher.increase_count(0) { Result::Ok(_) => core::panic_with_felt252('Should have panicked'), Result::Err(panic_data) => { assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)); - } + }, }; - - } #[test] fn test_increase_count_aggregator() { - let (_, _, aggregator_dispatcher)= deploy_contract(); + let (_, _, aggregator_dispatcher) = deploy_contract(); let balance_before = aggregator_dispatcher.get_count(); assert(balance_before == 0, 'Invalid balance'); @@ -83,7 +83,7 @@ fn test_increase_count_aggregator() { #[test] fn test_increase_counter_count_aggregator() { - let (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); + let (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); let count_1 = counter_dispatcher.get_count(); assert(count_1 == 0, 'invalid count 1'); @@ -103,7 +103,7 @@ fn test_increase_counter_count_aggregator() { #[test] fn test_decrease_count_by_one_aggregator() { - let (_, _, aggregator_dispatcher) = deploy_contract(); + let (_, _, aggregator_dispatcher) = deploy_contract(); let count_after = aggregator_dispatcher.get_count(); assert(count_after == 0, 'invalid count'); @@ -113,7 +113,6 @@ fn test_decrease_count_by_one_aggregator() { let count_after = aggregator_dispatcher.get_count(); assert(count_after == 19, 'incorrect count'); - } #[test] @@ -124,16 +123,15 @@ fn test_increase_activate_switch() { assert(!status, 'failed'); aggregator_dispatcher.activate_switch(); - + let status_after = killswitch_dispatcher.get_status(); assert(status_after, 'invalid status'); - } #[test] #[should_panic(expect: 'Amount cannot be 0')] -fn test_increase_count_by_zero(){ - let (_, _, aggregator_dispatcher) = deploy_contract(); +fn test_increase_count_by_zero() { + let (_, _, aggregator_dispatcher) = deploy_contract(); let count_after = aggregator_dispatcher.get_count(); assert(count_after == 0, 'incorrect count'); @@ -150,5 +148,4 @@ fn test_increase_counter_count_error() { assert(status_before, 'invalid status'); aggregator_dispatcher.increase_counter_count(42); - -} \ No newline at end of file +} From 72c332179b92e41c14695a328dfe3f8dd81d20ca Mon Sep 17 00:00:00 2001 From: olowo Date: Fri, 16 May 2025 13:38:36 +0100 Subject: [PATCH 3/3] refac: moduralize the code --- src/aggregator.cairo | 57 ++++++----- src/counter.cairo | 44 +++++---- src/interfaces/Iaggregator.cairo | 15 +++ src/interfaces/Icounter.cairo | 11 +++ src/interfaces/Ikillswitch.cairo | 10 ++ src/interfaces/Iowner.cairo | 7 ++ src/killswitch.cairo | 14 +-- src/lib.cairo | 9 ++ src/ownable.cairo | 63 ++++++++++++ tests/test_contract.cairo | 158 +++++++++++++++++++++++-------- tests/test_counter.cairo | 69 ++++++++++++++ tests/test_killswitch.cairo | 31 ++++++ 12 files changed, 400 insertions(+), 88 deletions(-) create mode 100644 src/interfaces/Iaggregator.cairo create mode 100644 src/interfaces/Icounter.cairo create mode 100644 src/interfaces/Ikillswitch.cairo create mode 100644 src/interfaces/Iowner.cairo create mode 100644 src/ownable.cairo create mode 100644 tests/test_counter.cairo create mode 100644 tests/test_killswitch.cairo diff --git a/src/aggregator.cairo b/src/aggregator.cairo index 12d8229..4d9a851 100644 --- a/src/aggregator.cairo +++ b/src/aggregator.cairo @@ -1,26 +1,11 @@ -#[starknet::interface] -pub trait IAggregator { - /// Increase contract count. - fn increase_count(ref self: TContractState, amount: u32); - /// Increase contract count. - /// - fn increase_counter_count(ref self: TContractState, amount: u32); - - /// Retrieve contract count. - fn decrease_count_by_one(ref self: TContractState); - /// Retrieve contract count. - fn get_count(self: @TContractState) -> u32; - - fn activate_switch(ref self: TContractState); -} - -/// Simple contract for managing count. #[starknet::contract] mod Agggregator { - use cohort_4::counter::{ICounterDispatcher, ICounterDispatcherTrait}; - use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; - use starknet::ContractAddress; + use cohort_4::interfaces::Iaggregator::IAggregator; + use cohort_4::interfaces::Icounter::{ICounterDispatcher, ICounterDispatcherTrait}; + use cohort_4::interfaces::Ikillswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; + use cohort_4::interfaces::Iowner::{IOwnerDispatcher, IOwnerDispatcherTrait}; use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use starknet::{ContractAddress, get_caller_address}; #[storage] @@ -28,18 +13,30 @@ mod Agggregator { count: u32, counter: ContractAddress, killswitch: ContractAddress, + ownable: ContractAddress, } + #[constructor] - fn constructor(ref self: ContractState, counter: ContractAddress, killswitch: ContractAddress) { + fn constructor( + ref self: ContractState, + counter: ContractAddress, + killswitch: ContractAddress, + ownable: ContractAddress, + ) { + assert(counter != killswitch && counter != ownable, 'use counter address'); + assert(killswitch != counter && killswitch != ownable, 'use killswitch address'); + assert(ownable != killswitch && ownable != counter, 'use counter address'); self.counter.write(counter); self.killswitch.write(killswitch); + self.ownable.write(ownable); } #[abi(embed_v0)] - impl AggregatorImpl of super::IAggregator { + impl AggregatorImpl of IAggregator { fn increase_count(ref self: ContractState, amount: u32) { + self.check_for_owner(); assert(amount > 0, 'Amount cannot be 0'); let counter = ICounterDispatcher { contract_address: self.counter.read() }; let counter_count = counter.get_count(); @@ -47,6 +44,7 @@ mod Agggregator { } fn increase_counter_count(ref self: ContractState, amount: u32) { + self.check_for_owner(); let killswitch: IKillSwitchDispatcher = IKillSwitchDispatcher { contract_address: self.killswitch.read(), }; @@ -55,12 +53,15 @@ mod Agggregator { } fn decrease_count_by_one(ref self: ContractState) { + self.check_for_owner(); let current_count = self.get_count(); assert!(current_count != 0, "Amount cannot be 0"); self.count.write(current_count - 1); } fn activate_switch(ref self: ContractState) { + self.check_for_owner(); + let killswitch: IKillSwitchDispatcher = IKillSwitchDispatcher { contract_address: self.killswitch.read(), }; @@ -74,4 +75,16 @@ mod Agggregator { self.count.read() } } + + + #[generate_trait] + impl privateAggregator of privateAggregatorTrait { + fn check_for_owner(self: @ContractState) { + let caller = get_caller_address(); + + let ownable = IOwnerDispatcher { contract_address: self.ownable.read() }; + let owner = ownable.get_owner(); + assert(caller == owner, 'you are not the owner'); + } + } } diff --git a/src/counter.cairo b/src/counter.cairo index ced698b..5f6a87f 100644 --- a/src/counter.cairo +++ b/src/counter.cairo @@ -1,20 +1,7 @@ -/// Interface representing `Counter`. -/// This interface allows modification and retrieval of the contract count. -#[starknet::interface] -pub trait ICounter { - /// Increase contract count. - fn increase_count(ref self: TContractState, amount: u32); - - /// Decrease contract count by one - fn decrease_count_by_one(ref self: TContractState); - - /// Retrieve contract count. - fn get_count(self: @TContractState) -> u32; -} - /// Simple contract for managing count. #[starknet::contract] -mod Counter { +pub mod Counter { + use cohort_4::interfaces::Icounter::ICounter; use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; #[storage] @@ -22,18 +9,41 @@ mod Counter { count: u32, } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + IncreaseCount: IncreaseCount, + DecreaseCount: DecreaseCount, + } + + #[derive(Drop, starknet::Event)] + pub struct IncreaseCount { + pub count: u32, + } + + #[derive(Drop, starknet::Event)] + pub struct DecreaseCount { + pub count: u32, + } + #[abi(embed_v0)] - impl CounterImpl of super::ICounter { + impl CounterImpl of ICounter { fn increase_count(ref self: ContractState, amount: u32) { assert(amount > 0, 'Amount cannot be 0'); let counter_count = self.get_count(); + + let new_count = counter_count + amount; self.count.write(counter_count + amount); + self.emit(Event::IncreaseCount(IncreaseCount { count: new_count })); } fn decrease_count_by_one(ref self: ContractState) { let current_count = self.get_count(); assert(current_count > 0, 'Amount cannot be 0'); - self.count.write(current_count - 1); + let new_count = current_count - 1; + self.count.write(new_count); + self.emit(Event::DecreaseCount(DecreaseCount { count: new_count })); } fn get_count(self: @ContractState) -> u32 { diff --git a/src/interfaces/Iaggregator.cairo b/src/interfaces/Iaggregator.cairo new file mode 100644 index 0000000..58ff900 --- /dev/null +++ b/src/interfaces/Iaggregator.cairo @@ -0,0 +1,15 @@ +#[starknet::interface] +pub trait IAggregator { + /// Increase contract count. + fn increase_count(ref self: TContractState, amount: u32); + /// Increase contract count. + /// + fn increase_counter_count(ref self: TContractState, amount: u32); + + /// Retrieve contract count. + fn decrease_count_by_one(ref self: TContractState); + /// Retrieve contract count. + fn get_count(self: @TContractState) -> u32; + + fn activate_switch(ref self: TContractState); +} diff --git a/src/interfaces/Icounter.cairo b/src/interfaces/Icounter.cairo new file mode 100644 index 0000000..1ffa408 --- /dev/null +++ b/src/interfaces/Icounter.cairo @@ -0,0 +1,11 @@ +#[starknet::interface] +pub trait ICounter { + /// Increase contract count. + fn increase_count(ref self: TContractState, amount: u32); + + /// Decrease contract count by one + fn decrease_count_by_one(ref self: TContractState); + + /// Retrieve contract count. + fn get_count(self: @TContractState) -> u32; +} diff --git a/src/interfaces/Ikillswitch.cairo b/src/interfaces/Ikillswitch.cairo new file mode 100644 index 0000000..f76c958 --- /dev/null +++ b/src/interfaces/Ikillswitch.cairo @@ -0,0 +1,10 @@ +/// Interface representing `HelloContract`. +/// This interface allows modification and retrieval of the contract count. +#[starknet::interface] +pub trait IKillSwitch { + /// Increase contract count. + fn switch(ref self: TContractState); + + /// Retrieve contract count. + fn get_status(self: @TContractState) -> bool; +} diff --git a/src/interfaces/Iowner.cairo b/src/interfaces/Iowner.cairo new file mode 100644 index 0000000..27127b9 --- /dev/null +++ b/src/interfaces/Iowner.cairo @@ -0,0 +1,7 @@ +#[starknet::interface] +pub trait IOwner { + fn get_owner(self: @TContractState) -> starknet::ContractAddress; + + + fn transfer_owner(ref self: TContractState, new_owner: starknet::ContractAddress); +} diff --git a/src/killswitch.cairo b/src/killswitch.cairo index 6437088..2b8ba78 100644 --- a/src/killswitch.cairo +++ b/src/killswitch.cairo @@ -1,17 +1,7 @@ -/// Interface representing `HelloContract`. -/// This interface allows modification and retrieval of the contract count. -#[starknet::interface] -pub trait IKillSwitch { - /// Increase contract count. - fn switch(ref self: TContractState); - - /// Retrieve contract count. - fn get_status(self: @TContractState) -> bool; -} - /// Simple contract for managing count. #[starknet::contract] mod KillSwitch { + use cohort_4::interfaces::Ikillswitch::IKillSwitch; use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; #[storage] @@ -21,7 +11,7 @@ mod KillSwitch { #[abi(embed_v0)] - impl KillSwitchImpl of super::IKillSwitch { + impl KillSwitchImpl of IKillSwitch { fn switch(ref self: ContractState) { // assert(amount != 0, 'Amount cannot be 0'); self.status.write(!self.status.read()); diff --git a/src/lib.cairo b/src/lib.cairo index 7df1810..20fa20d 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -1,9 +1,18 @@ pub mod IHello; pub mod INumber; + pub mod aggregator; pub mod counter; pub mod hello; pub mod killswitch; +pub mod ownable; + +pub mod interfaces { + pub mod Iaggregator; + pub mod Icounter; + pub mod Ikillswitch; + pub mod Iowner; +} fn main() { diff --git a/src/ownable.cairo b/src/ownable.cairo new file mode 100644 index 0000000..7d74033 --- /dev/null +++ b/src/ownable.cairo @@ -0,0 +1,63 @@ +#[starknet::contract] +pub mod Ownable { + use starknet::event::EventEmitter; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use starknet::{ContractAddress, contract_address_const, get_caller_address}; + use crate::interfaces::Iowner::IOwner; + + + #[storage] + struct Storage { + pub owner: ContractAddress, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + OwnershipTransferred: OwnershipTransferred, + } + + #[derive(Drop, starknet::Event)] + struct OwnershipTransferred { + previous_owner: ContractAddress, + new_owner: ContractAddress, + } + + #[constructor] + fn constructor(ref self: ContractState, initial_owner: ContractAddress) { + let address_zero = contract_address_const::<0>(); + assert(initial_owner != address_zero, 'Owner cannot be zero'); + self.owner.write(initial_owner); + } + + + #[abi(embed_v0)] + impl Owner of IOwner { + fn get_owner(self: @ContractState) -> ContractAddress { + self.owner.read() + } + + + fn transfer_owner(ref self: ContractState, new_owner: ContractAddress) { + self.check_for_owner(); + let address_zero = contract_address_const::<0>(); + assert(new_owner != address_zero, 'new owner cannot be zero'); + + let previous_owner = self.owner.read(); + self.owner.write(new_owner) + } + } + + + //private function + + #[generate_trait] + impl Internalownable of InternalownableTrait { + fn check_for_owner(self: @ContractState) { + let caller = get_caller_address(); + + let owner = self.owner.read(); + assert(caller == owner, 'you are not the owner'); + } + } +} diff --git a/tests/test_contract.cairo b/tests/test_contract.cairo index 71b3178..d703609 100644 --- a/tests/test_contract.cairo +++ b/tests/test_contract.cairo @@ -1,13 +1,27 @@ -use cohort_4::aggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait}; -use cohort_4::counter::{ +use cohort_4::interfaces::Iaggregator::{IAggregatorDispatcher, IAggregatorDispatcherTrait}; +use cohort_4::interfaces::Icounter::{ ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, ICounterSafeDispatcherTrait, }; -use cohort_4::killswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; -use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; +use cohort_4::interfaces::Ikillswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; +use cohort_4::interfaces::Iowner::{IOwnerDispatcher, IOwnerDispatcherTrait}; +use snforge_std::{ + ContractClassTrait, DeclareResultTrait, declare, start_cheat_caller_address, + stop_cheat_caller_address, +}; +use starknet::{ContractAddress, contract_address_const}; + +fn deploy_contract() -> ( + ICounterDispatcher, + IKillSwitchDispatcher, + IOwnerDispatcher, + IAggregatorDispatcher, + ContractAddress, +) { + // owner address + let owner_address: ContractAddress = contract_address_const::<'1'>(); -fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IAggregatorDispatcher) { //counter deployment let counter_countract_name: ByteArray = "Counter"; let contract = declare(counter_countract_name).unwrap().contract_class(); @@ -22,22 +36,43 @@ fn deploy_contract() -> (ICounterDispatcher, IKillSwitchDispatcher, IAggregatorD contract_address: killswitch_contract_address, }; + //ownable deployment + let ownable_contract_name: ByteArray = "Ownable"; + let ownable_contract = declare(ownable_contract_name).unwrap().contract_class(); + // Deploy with owner_address as the initial owner + let (ownable_contract_address, _) = ownable_contract + .deploy(@array![owner_address.into()]) + .unwrap(); + let ownable_dispatcher = IOwnerDispatcher { contract_address: ownable_contract_address }; + //aggregator deployment let aggregator = declare("Agggregator").unwrap().contract_class(); let (aggregator_contract_address, _) = aggregator - .deploy(@array![counter_contract_address.into(), killswitch_contract_address.into()]) + .deploy( + @array![ + counter_contract_address.into(), + killswitch_contract_address.into(), + ownable_contract_address.into(), + ], + ) .unwrap(); let aggregator_dispatcher = IAggregatorDispatcher { contract_address: aggregator_contract_address, }; - (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) + ( + counter_dispatcher, + killswitch_dispatcher, + ownable_dispatcher, + aggregator_dispatcher, + owner_address, + ) } #[test] fn test_increase_count() { - let (counter_dispatcher, _, _) = deploy_contract(); + let (counter_dispatcher, _, _, _, _) = deploy_contract(); let balance_before = counter_dispatcher.get_count(); assert(balance_before == 0, 'Invalid balance'); @@ -48,10 +83,44 @@ fn test_increase_count() { assert(balance_after == 42, 'Invalid balance'); } +#[test] +fn test_increase_count_aggregator_as_owner() { + let (_, _, _, aggregator_dispatcher, owner_address) = deploy_contract(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); + + let balance_before = aggregator_dispatcher.get_count(); + assert(balance_before == 0, 'Invalid balance'); + + aggregator_dispatcher.increase_count(42); + + let balance_after = aggregator_dispatcher.get_count(); + assert(balance_after == 42, 'Invalid balance'); + + stop_cheat_caller_address(aggregator_dispatcher.contract_address); +} + +#[test] +#[should_panic(expected: ('you are not the owner',))] +fn test_increase_count_aggregator_as_non_owner() { + let (_, _, _, aggregator_dispatcher, _) = deploy_contract(); + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, non_owner_address); + + aggregator_dispatcher.increase_count(42); + + let balance_after = aggregator_dispatcher.get_count(); + assert(balance_after == 0, 'you are not the owner'); + + stop_cheat_caller_address(aggregator_dispatcher.contract_address); +} + #[test] #[feature("safe_dispatcher")] fn test_cannot_increase_balance_with_zero_value() { - let (counter_dispatcher, _, _) = deploy_contract(); + let (counter_dispatcher, _, _, _, _) = deploy_contract(); let balance_before = counter_dispatcher.get_count(); assert(balance_before == 0, 'Invalid balance'); @@ -68,45 +137,31 @@ fn test_cannot_increase_balance_with_zero_value() { }; } -#[test] -fn test_increase_count_aggregator() { - let (_, _, aggregator_dispatcher) = deploy_contract(); - - let balance_before = aggregator_dispatcher.get_count(); - assert(balance_before == 0, 'Invalid balance'); - - aggregator_dispatcher.increase_count(42); - - let balance_after = aggregator_dispatcher.get_count(); - assert(balance_after == 42, 'Invalid balance'); -} #[test] -fn test_increase_counter_count_aggregator() { - let (counter_dispatcher, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); +fn test_increase_counter_count() { + let (counter_dispatcher, _, _, aggregator_dispatcher, owner_address) = deploy_contract(); - let count_1 = counter_dispatcher.get_count(); - assert(count_1 == 0, 'invalid count 1'); + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); - let status_before = killswitch_dispatcher.get_status(); - assert(!status_before, 'incorrect status'); + let initial_count = counter_dispatcher.get_count(); + assert(initial_count == 0, 'invalid initial count'); aggregator_dispatcher.activate_switch(); - let status_after = killswitch_dispatcher.get_status(); - assert(status_after, 'failed to activate'); aggregator_dispatcher.increase_counter_count(42); - let count_2 = counter_dispatcher.get_count(); - assert(count_2 == 42, 'invalid count 2'); + let final_count = counter_dispatcher.get_count(); + assert(final_count == 42, 'counter not increased correctly'); } #[test] fn test_decrease_count_by_one_aggregator() { - let (_, _, aggregator_dispatcher) = deploy_contract(); + let (counter_dispatcher, _, _, aggregator_dispatcher, owner_address) = deploy_contract(); - let count_after = aggregator_dispatcher.get_count(); - assert(count_after == 0, 'invalid count'); + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); + let count_before = aggregator_dispatcher.get_count(); + assert(count_before == 0, 'invalid count'); aggregator_dispatcher.increase_count(20); aggregator_dispatcher.decrease_count_by_one(); @@ -117,13 +172,38 @@ fn test_decrease_count_by_one_aggregator() { #[test] fn test_increase_activate_switch() { - let (_, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); + let (_, killswitch_dispatcher, _, aggregator_dispatcher, owner_address) = deploy_contract(); + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + start_cheat_caller_address(aggregator_dispatcher.contract_address, owner_address); let status = killswitch_dispatcher.get_status(); assert(!status, 'failed'); aggregator_dispatcher.activate_switch(); + stop_cheat_caller_address(aggregator_dispatcher.contract_address); + + let status_after = killswitch_dispatcher.get_status(); + assert(status_after, 'invalid status'); +} + + +#[test] +#[should_panic(expect: 'you are not the owner')] +fn test_increase_activate_switch_non_owner() { + let (_, killswitch_dispatcher, _, aggregator_dispatcher, _) = deploy_contract(); + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, non_owner_address); + let status = killswitch_dispatcher.get_status(); + assert(!status, 'failed'); + + aggregator_dispatcher.activate_switch(); + + stop_cheat_caller_address(aggregator_dispatcher.contract_address); + let status_after = killswitch_dispatcher.get_status(); assert(status_after, 'invalid status'); } @@ -131,7 +211,7 @@ fn test_increase_activate_switch() { #[test] #[should_panic(expect: 'Amount cannot be 0')] fn test_increase_count_by_zero() { - let (_, _, aggregator_dispatcher) = deploy_contract(); + let (_, killswitch_dispatcher, _, aggregator_dispatcher, _) = deploy_contract(); let count_after = aggregator_dispatcher.get_count(); assert(count_after == 0, 'incorrect count'); @@ -142,7 +222,11 @@ fn test_increase_count_by_zero() { #[test] #[should_panic(expect: 'should be active')] fn test_increase_counter_count_error() { - let (_, killswitch_dispatcher, aggregator_dispatcher) = deploy_contract(); + let (_, killswitch_dispatcher, _, aggregator_dispatcher, _) = deploy_contract(); + + let non_owner_address: ContractAddress = contract_address_const::<'2'>(); + + start_cheat_caller_address(aggregator_dispatcher.contract_address, non_owner_address); let status_before = killswitch_dispatcher.get_status(); assert(status_before, 'invalid status'); diff --git a/tests/test_counter.cairo b/tests/test_counter.cairo new file mode 100644 index 0000000..cf3006f --- /dev/null +++ b/tests/test_counter.cairo @@ -0,0 +1,69 @@ +use cohort_4::interfaces::Icounter::{ + ICounterDispatcher, ICounterDispatcherTrait, ICounterSafeDispatcher, + ICounterSafeDispatcherTrait, +}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; +use starknet::ContractAddress; + + +fn deploy() -> (ICounterDispatcher, ICounterSafeDispatcher, ContractAddress) { + let counter_countract_name: ByteArray = "Counter"; + let contract = declare(counter_countract_name).unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + let counter = ICounterDispatcher { contract_address }; + let safe_counter = ICounterSafeDispatcher { contract_address }; + return (counter, safe_counter, contract_address); +} + + +#[test] +fn test_initialized_count() { + let (counter, _, _) = deploy(); + let count_1 = counter.get_count(); + assert(count_1 == 0, 'invalid count value'); +} + +#[test] +fn test_increase_count() { + let (counter, _, _) = deploy(); + let count_1 = counter.get_count(); + assert(count_1 == 0, 'invalid count value'); + + let amount = 30; + counter.increase_count(amount); + let count_2 = counter.get_count(); + assert(count_2 == amount, 'Invalid amount'); + assert(count_2 != 10, 'invalid 10') +} + +#[test] +#[feature("safe_dispatcher")] +fn test_increase_count_by_zero_amount() { + let (counter, safe_counter, _) = deploy(); + + let count_1 = counter.get_count(); + assert(count_1 == 0, 'invalid count value'); + + match safe_counter.increase_count(0) { + Result::Ok(_) => panic!("amount passed not zero"), + Result::Err(panic_data) => { + assert(*panic_data.at(0) == 'Amount cannot be 0', *panic_data.at(0)) + }, + } +} + + +#[test] +fn test_for_decrease_count_by_one() { + let (counter, _, _) = deploy(); + + let count_1 = counter.get_count(); + assert(count_1 == 0, 'Invalid count value not zero'); + + counter.increase_count(2); + + counter.decrease_count_by_one(); + let count_2 = counter.get_count(); + assert(count_1 == count_2 - 1, 'Invalid count value'); +} + diff --git a/tests/test_killswitch.cairo b/tests/test_killswitch.cairo new file mode 100644 index 0000000..02eeea5 --- /dev/null +++ b/tests/test_killswitch.cairo @@ -0,0 +1,31 @@ +use cohort_4::interfaces::Ikillswitch::{IKillSwitchDispatcher, IKillSwitchDispatcherTrait}; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; +use starknet::ContractAddress; + +fn deploy() -> (IKillSwitchDispatcher, ContractAddress) { + let contract = declare("KillSwitch").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@array![]).unwrap(); + let killswitch = IKillSwitchDispatcher { contract_address }; + + (killswitch, contract_address) +} + +#[test] +fn test_get_status() { + let (killswitch, _) = deploy(); + + let current_status = killswitch.get_status(); + + assert(!current_status, 'status is not valid'); +} + +#[test] +fn test_switch() { + let (killswitch, _) = deploy(); + + let prev_status = killswitch.get_status(); + killswitch.switch(); + + assert(!prev_status, 'Status did not change'); +} +