diff --git a/src/PrimeNetwork.sol b/src/PrimeNetwork.sol index 228a8e4..b73c8e2 100644 --- a/src/PrimeNetwork.sol +++ b/src/PrimeNetwork.sol @@ -99,6 +99,13 @@ contract PrimeNetwork is AccessControlEnumerable { domainRegistry.updateValidationLogic(domainId, validationLogic); } + function updateDomainParameters(uint256 domainId, string calldata domainParametersURI) + external + onlyRole(FEDERATOR_ROLE) + { + domainRegistry.updateParameters(domainId, domainParametersURI); + } + function registerProvider(uint256 stake) external { uint256 stakeMinimum = stakeManager.getStakeMinimum(); require(stake >= stakeMinimum, "Stake amount is below minimum"); @@ -114,9 +121,20 @@ contract PrimeNetwork is AccessControlEnumerable { function increaseStake(uint256 amount) external { address provider = msg.sender; require(computeRegistry.checkProviderExists(provider), "Provider not registered"); - AIToken.transferFrom(msg.sender, address(this), amount); - AIToken.approve(address(stakeManager), amount); - stakeManager.stake(provider, amount); + // check if provider has pending unbonding, if so rebond those + uint256 pendingUnbond = stakeManager.getPendingUnbondTotal(provider); + if (pendingUnbond > 0) { + if (pendingUnbond >= amount) { + pendingUnbond = amount; // rebond only the amount we can + } + amount -= pendingUnbond; // reduce the amount to stake by the rebonded amount + stakeManager.rebond(provider, pendingUnbond); + } + if (amount != 0) { + AIToken.transferFrom(msg.sender, address(this), amount); + AIToken.approve(address(stakeManager), amount); + stakeManager.stake(provider, amount); + } } function reclaimStake(uint256 amount) external { diff --git a/src/StakeManager.sol b/src/StakeManager.sol index d05e6ff..61b27a4 100644 --- a/src/StakeManager.sol +++ b/src/StakeManager.sol @@ -12,10 +12,13 @@ contract StakeManager is IStakeManager, AccessControlEnumerable { } bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE"); + uint256 constant MAX_PENDING_UNBONDS = 10; + IERC20 public AIToken; mapping(address => uint256) private _stakes; mapping(address => UnbondTracker) private _unbonds; + mapping(address => uint256) private _unbonding; uint256 private _totalStaked; uint256 private _totalUnbonding; uint256 private _unbondingPeriod; @@ -38,21 +41,58 @@ contract StakeManager is IStakeManager, AccessControlEnumerable { function unstake(address staker, uint256 amount) external onlyRole(PRIME_ROLE) { require(_stakes[staker] >= amount, "StakeManager: insufficient balance"); _stakes[staker] -= amount; + _unbonding[staker] += amount; _totalStaked -= amount; _totalUnbonding += amount; - // add unbonding - _unbonds[staker].unbonds.push(Unbond(amount, block.timestamp + _unbondingPeriod)); + if (_unbonds[staker].unbonds.length - _unbonds[staker].offset >= MAX_PENDING_UNBONDS) { + // add to the newest unbond and reset the time + UnbondTracker storage pending = _unbonds[staker]; + pending.unbonds[pending.unbonds.length - 1].amount += amount; + pending.unbonds[pending.unbonds.length - 1].timestamp = block.timestamp + _unbondingPeriod; + } else { + // add a new unbond + _unbonds[staker].unbonds.push(Unbond(amount, block.timestamp + _unbondingPeriod)); + } emit Unstake(staker, amount); } + function rebond(address staker, uint256 amount) external onlyRole(PRIME_ROLE) { + require(_unbonding[staker] >= amount, "StakeManager: insufficient unbonding balance"); + _unbonding[staker] -= amount; + _totalUnbonding -= amount; + _stakes[staker] += amount; + _totalStaked += amount; + // remove unbond + uint256 rebonded = 0; + UnbondTracker storage pending = _unbonds[staker]; + for (uint256 i = pending.offset; i < pending.unbonds.length; i++) { + uint256 unbond_amount = pending.unbonds[i].amount; + if (amount >= unbond_amount + rebonded) { + rebonded += unbond_amount; + delete pending.unbonds[i]; + pending.offset = i + 1; + } else { + // only part of the unbond is rebonded + uint256 leftover = unbond_amount + rebonded - amount; + pending.unbonds[i].amount = leftover; + rebonded = amount; + pending.offset = i; + break; + } + } + emit Rebond(staker, amount); + } + function withdraw() external { // calculate amount that can be withdrawn uint256 amount = 0; - Unbond[] storage pending = _unbonds[msg.sender].unbonds; - for (uint256 i = 0; i < pending.length; i++) { + UnbondTracker storage staker = _unbonds[msg.sender]; + Unbond[] storage pending = staker.unbonds; + for (uint256 i = staker.offset; i < pending.length; i++) { if (pending[i].timestamp <= block.timestamp) { amount += pending[i].amount; _totalUnbonding -= pending[i].amount; + _unbonding[msg.sender] -= pending[i].amount; delete pending[i]; } else { _unbonds[msg.sender].offset = i; @@ -83,6 +123,7 @@ contract StakeManager is IStakeManager, AccessControlEnumerable { // slash the difference uint256 diff = unbonding_amount - amount; _totalUnbonding -= (pending.unbonds[i].amount - diff); + _unbonding[staker] -= (pending.unbonds[i].amount - diff); pending.unbonds[i].amount = diff; unbonding_amount = amount; pending.offset = i; @@ -90,12 +131,14 @@ contract StakeManager is IStakeManager, AccessControlEnumerable { } else if (unbonding_amount == amount) { // slash the whole unbond _totalUnbonding -= pending.unbonds[i].amount; + _unbonding[staker] -= pending.unbonds[i].amount; delete pending.unbonds[i]; pending.offset = i + 1; break; } else { // slash the whole unbond and continue _totalUnbonding -= pending.unbonds[i].amount; + _unbonding[staker] -= pending.unbonds[i].amount; delete pending.unbonds[i]; pending.offset = i + 1; } @@ -146,7 +189,17 @@ contract StakeManager is IStakeManager, AccessControlEnumerable { } function getPendingUnbonds(address staker) external view returns (Unbond[] memory) { - return _unbonds[staker].unbonds; + Unbond[] memory unbonds = _unbonds[staker].unbonds; + uint256 length = unbonds.length - _unbonds[staker].offset; + Unbond[] memory pendingUnbonds = new Unbond[](length); + for (uint256 i = 0; i < length; i++) { + pendingUnbonds[i] = unbonds[_unbonds[staker].offset + i]; + } + return pendingUnbonds; + } + + function getPendingUnbondTotal(address staker) external view returns (uint256) { + return _unbonding[staker]; } function getUnbondingPeriod() external view returns (uint256) { diff --git a/src/interfaces/IStakeManager.sol b/src/interfaces/IStakeManager.sol index 35b53ef..269c030 100644 --- a/src/interfaces/IStakeManager.sol +++ b/src/interfaces/IStakeManager.sol @@ -11,6 +11,8 @@ event Withdraw(address staker, uint256 amount); event Slashed(address staker, uint256 amount, bytes reason); +event Rebond(address staker, uint256 amount); + event UpdateUnbondingPeriod(uint256 period); event StakeMinimumUpdate(uint256 minimum); @@ -23,6 +25,7 @@ interface IStakeManager is IAccessControlEnumerable { function stake(address staker, uint256 amount) external; function unstake(address staker, uint256 amount) external; + function rebond(address staker, uint256 amount) external; function withdraw() external; function slash(address staker, uint256 amount, bytes calldata reason) external returns (uint256 slashed); @@ -32,6 +35,7 @@ interface IStakeManager is IAccessControlEnumerable { function getStake(address staker) external view returns (uint256); function getTotalStaked() external view returns (uint256); function getPendingUnbonds(address staker) external view returns (Unbond[] memory); + function getPendingUnbondTotal(address staker) external view returns (uint256); function getUnbondingPeriod() external view returns (uint256); function getTotalUnbonding() external view returns (uint256); function getStakeMinimum() external view returns (uint256); diff --git a/test/StakeManager.t.sol b/test/StakeManager.t.sol new file mode 100644 index 0000000..fe2fe14 --- /dev/null +++ b/test/StakeManager.t.sol @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.24; + +import "forge-std/Test.sol"; +import "../src/StakeManager.sol"; +import "../src/interfaces/IStakeManager.sol"; +import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +/*──────────────────────── Mocks ────────────────────────*/ +contract MockERC20 is ERC20 { + constructor() ERC20("AI Token", "AI") {} + + function mint(address to, uint256 amt) external { + _mint(to, amt); + } +} + +/*─────────────────────── Test Suite ─────────────────────*/ +contract StakeManagerTest is Test { + /* roles & addresses */ + address internal prime = address(this); // gets PRIME + DEFAULT_ADMIN in ctor + address internal staker = address(0xBEEF); + bytes32 constant PRIME_ROLE = keccak256("PRIME_ROLE"); + + /* system under test */ + MockERC20 internal token; + StakeManager internal mgr; + + /* constants */ + uint256 constant PERIOD = 7 days; + uint256 constant ONE = 1 ether; + uint256 constant HUND = 100 * ONE; + + /*────────── Setup ─────────*/ + function setUp() public { + token = new MockERC20(); + mgr = new StakeManager(prime, PERIOD, IERC20(token)); + + token.mint(prime, 1_000_000 * ONE); + token.mint(staker, 1_000_000 * ONE); + + token.approve(address(mgr), type(uint256).max); + vm.prank(staker); + token.approve(address(mgr), type(uint256).max); // not strictly needed but convenient + } + + /*────────── helpers ─────────*/ + function _stake(address who, uint256 amt) internal { + vm.prank(prime); + mgr.stake(who, amt); + } + + function _unstake(address who, uint256 amt) internal { + vm.prank(prime); + mgr.unstake(who, amt); + } + + /*────────── tests ─────────*/ + + function testStakeIncreasesBalances() public { + vm.expectEmit(true, false, false, true); + emit Stake(staker, HUND); + + _stake(staker, HUND); + + assertEq(mgr.getStake(staker), HUND); + assertEq(mgr.getTotalStaked(), HUND); + assertEq(token.balanceOf(address(mgr)), HUND); + assertEq(token.balanceOf(prime), 1_000_000 * ONE - HUND); + } + + function testStakeRequiresPrime() public { + vm.prank(staker); + vm.expectRevert(); // AccessControl + mgr.stake(staker, ONE); + } + + function testUnstakeMovesToPending() public { + _stake(staker, HUND); + + vm.expectEmit(true, false, false, true); + emit Unstake(staker, 40 * ONE); + + _unstake(staker, 40 * ONE); + + assertEq(mgr.getStake(staker), 60 * ONE); + assertEq(mgr.getPendingUnbondTotal(staker), 40 * ONE); + + StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker); + assertEq(arr.length, 1); + assertEq(arr[0].amount, 40 * ONE); + assertEq(arr[0].timestamp, block.timestamp + PERIOD); + } + + function testWithdrawAfterPeriod() public { + _stake(staker, HUND); + _unstake(staker, HUND); + + vm.warp(block.timestamp + PERIOD + 1); + + uint256 balBefore = token.balanceOf(staker); + + vm.prank(staker); + vm.expectEmit(true, false, false, true); + emit Withdraw(staker, HUND); + mgr.withdraw(); + + assertEq(token.balanceOf(staker), balBefore + HUND); + assertEq(mgr.getPendingUnbondTotal(staker), 0); + assertEq(mgr.getTotalUnbonding(), 0); + } + + function testWithdrawRevertsIfNothing() public { + vm.prank(staker); + vm.expectRevert(bytes("StakeManager: no funds to withdraw")); + mgr.withdraw(); + } + + function testRebondPartial() public { + _stake(staker, HUND); + _unstake(staker, HUND); + + vm.expectEmit(true, false, false, true); + emit Rebond(staker, 40 * ONE); + + vm.prank(prime); + mgr.rebond(staker, 40 * ONE); + + assertEq(mgr.getStake(staker), 40 * ONE + 0); // 40 rebonded on top of 0 residual + assertEq(mgr.getPendingUnbondTotal(staker), 60 * ONE); + StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker); + // first slot reduced by 40 + assertEq(arr[0].amount, 60 * ONE); + } + + function testRebondFull() public { + _stake(staker, HUND); + _unstake(staker, HUND); + + vm.expectEmit(true, false, false, true); + emit Rebond(staker, 100 * ONE); + + vm.prank(prime); + mgr.rebond(staker, 100 * ONE); + + assertEq(mgr.getStake(staker), 100 * ONE); + assertEq(mgr.getPendingUnbondTotal(staker), 0); + StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker); + assertEq(arr.length, 0); // all unbonds rebonded + } + + function testMaxPendingAggregates() public { + _stake(staker, 1000 * ONE); + + // create 10 unbonds + for (uint256 i; i < 9; i++) { + _unstake(staker, ONE); + } // first 9 individual + _unstake(staker, ONE); // 10th — separate slot + // next call should aggregate into the 10th + _unstake(staker, ONE); + + StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker); + assertEq(arr.length, 10); + assertEq(arr[9].amount, 2 * ONE); // aggregated + } + + function testSlashFromStake() public { + _stake(staker, HUND); + + vm.expectEmit(true, false, false, true); + emit Slashed(staker, 30 * ONE, ""); + + vm.prank(prime); + uint256 slashed = mgr.slash(staker, 30 * ONE, ""); + + assertEq(slashed, 30 * ONE); + assertEq(mgr.getStake(staker), 70 * ONE); + assertEq(token.balanceOf(prime), 1_000_000 * ONE - HUND + 30 * ONE); + } + + function testSlashPrefersUnbonding() public { + _stake(staker, HUND); + _unstake(staker, HUND); // entire stake now pending + + vm.prank(prime); + uint256 slashed = mgr.slash(staker, 60 * ONE, ""); + + assertEq(slashed, 60 * ONE); + assertEq(mgr.getPendingUnbondTotal(staker), 40 * ONE); + assertEq(mgr.getTotalUnbonding(), 40 * ONE); + } + + function testSettersRequirePrime() public { + vm.prank(staker); + vm.expectRevert(); // AccessControl + mgr.setUnbondingPeriod(1 days); + + vm.prank(prime); + mgr.setUnbondingPeriod(1 days); + assertEq(mgr.getUnbondingPeriod(), 1 days); + + vm.prank(staker); + vm.expectRevert(); + mgr.setStakeMinimum(10 * ONE); + + vm.prank(prime); + mgr.setStakeMinimum(10 * ONE); + assertEq(mgr.getStakeMinimum(), 10 * ONE); + } +}