Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/PrimeNetwork.sol
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ contract PrimeNetwork is AccessControlEnumerable {
domainRegistry.updateValidationLogic(domainId, validationLogic);
}

function updateDomainParameters(uint256 domainId, string calldata domainParametersURI)
external
onlyRole(FEDERATOR_ROLE)
{
domainRegistry.updateParameters(domainId, domainParametersURI);
}

function registerProvider(uint256 stake) external {
uint256 stakeMinimum = stakeManager.getStakeMinimum();
require(stake >= stakeMinimum, "Stake amount is below minimum");
Expand All @@ -114,9 +121,20 @@ contract PrimeNetwork is AccessControlEnumerable {
function increaseStake(uint256 amount) external {
address provider = msg.sender;
require(computeRegistry.checkProviderExists(provider), "Provider not registered");
AIToken.transferFrom(msg.sender, address(this), amount);
AIToken.approve(address(stakeManager), amount);
stakeManager.stake(provider, amount);
// check if provider has pending unbonding, if so rebond those
uint256 pendingUnbond = stakeManager.getPendingUnbondTotal(provider);
if (pendingUnbond > 0) {
if (pendingUnbond >= amount) {
pendingUnbond = amount; // rebond only the amount we can
}
amount -= pendingUnbond; // reduce the amount to stake by the rebonded amount
stakeManager.rebond(provider, pendingUnbond);
}
if (amount != 0) {
AIToken.transferFrom(msg.sender, address(this), amount);
AIToken.approve(address(stakeManager), amount);
stakeManager.stake(provider, amount);
}
}

function reclaimStake(uint256 amount) external {
Expand Down
63 changes: 58 additions & 5 deletions src/StakeManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ contract StakeManager is IStakeManager, AccessControlEnumerable {
}

bytes32 public constant PRIME_ROLE = keccak256("PRIME_ROLE");
uint256 constant MAX_PENDING_UNBONDS = 10;

IERC20 public AIToken;

mapping(address => uint256) private _stakes;
mapping(address => UnbondTracker) private _unbonds;
mapping(address => uint256) private _unbonding;
uint256 private _totalStaked;
uint256 private _totalUnbonding;
uint256 private _unbondingPeriod;
Expand All @@ -38,21 +41,58 @@ contract StakeManager is IStakeManager, AccessControlEnumerable {
function unstake(address staker, uint256 amount) external onlyRole(PRIME_ROLE) {
require(_stakes[staker] >= amount, "StakeManager: insufficient balance");
_stakes[staker] -= amount;
_unbonding[staker] += amount;
_totalStaked -= amount;
_totalUnbonding += amount;
// add unbonding
_unbonds[staker].unbonds.push(Unbond(amount, block.timestamp + _unbondingPeriod));
if (_unbonds[staker].unbonds.length - _unbonds[staker].offset >= MAX_PENDING_UNBONDS) {
// add to the newest unbond and reset the time
UnbondTracker storage pending = _unbonds[staker];
pending.unbonds[pending.unbonds.length - 1].amount += amount;
pending.unbonds[pending.unbonds.length - 1].timestamp = block.timestamp + _unbondingPeriod;
} else {
// add a new unbond
_unbonds[staker].unbonds.push(Unbond(amount, block.timestamp + _unbondingPeriod));
}
emit Unstake(staker, amount);
}

function rebond(address staker, uint256 amount) external onlyRole(PRIME_ROLE) {
require(_unbonding[staker] >= amount, "StakeManager: insufficient unbonding balance");
_unbonding[staker] -= amount;
_totalUnbonding -= amount;
_stakes[staker] += amount;
_totalStaked += amount;
// remove unbond
uint256 rebonded = 0;
UnbondTracker storage pending = _unbonds[staker];
for (uint256 i = pending.offset; i < pending.unbonds.length; i++) {
uint256 unbond_amount = pending.unbonds[i].amount;
if (amount >= unbond_amount + rebonded) {
rebonded += unbond_amount;
delete pending.unbonds[i];
pending.offset = i + 1;
} else {
// only part of the unbond is rebonded
uint256 leftover = unbond_amount + rebonded - amount;
pending.unbonds[i].amount = leftover;
rebonded = amount;
pending.offset = i;
break;
}
}
emit Rebond(staker, amount);
}

function withdraw() external {
// calculate amount that can be withdrawn
uint256 amount = 0;
Unbond[] storage pending = _unbonds[msg.sender].unbonds;
for (uint256 i = 0; i < pending.length; i++) {
UnbondTracker storage staker = _unbonds[msg.sender];
Unbond[] storage pending = staker.unbonds;
for (uint256 i = staker.offset; i < pending.length; i++) {
if (pending[i].timestamp <= block.timestamp) {
amount += pending[i].amount;
_totalUnbonding -= pending[i].amount;
_unbonding[msg.sender] -= pending[i].amount;
delete pending[i];
} else {
_unbonds[msg.sender].offset = i;
Expand Down Expand Up @@ -83,19 +123,22 @@ contract StakeManager is IStakeManager, AccessControlEnumerable {
// slash the difference
uint256 diff = unbonding_amount - amount;
_totalUnbonding -= (pending.unbonds[i].amount - diff);
_unbonding[staker] -= (pending.unbonds[i].amount - diff);
pending.unbonds[i].amount = diff;
unbonding_amount = amount;
pending.offset = i;
break;
} else if (unbonding_amount == amount) {
// slash the whole unbond
_totalUnbonding -= pending.unbonds[i].amount;
_unbonding[staker] -= pending.unbonds[i].amount;
delete pending.unbonds[i];
pending.offset = i + 1;
break;
} else {
// slash the whole unbond and continue
_totalUnbonding -= pending.unbonds[i].amount;
_unbonding[staker] -= pending.unbonds[i].amount;
delete pending.unbonds[i];
pending.offset = i + 1;
}
Expand Down Expand Up @@ -146,7 +189,17 @@ contract StakeManager is IStakeManager, AccessControlEnumerable {
}

function getPendingUnbonds(address staker) external view returns (Unbond[] memory) {
return _unbonds[staker].unbonds;
Unbond[] memory unbonds = _unbonds[staker].unbonds;
uint256 length = unbonds.length - _unbonds[staker].offset;
Unbond[] memory pendingUnbonds = new Unbond[](length);
for (uint256 i = 0; i < length; i++) {
pendingUnbonds[i] = unbonds[_unbonds[staker].offset + i];
}
return pendingUnbonds;
}

function getPendingUnbondTotal(address staker) external view returns (uint256) {
return _unbonding[staker];
}

function getUnbondingPeriod() external view returns (uint256) {
Expand Down
4 changes: 4 additions & 0 deletions src/interfaces/IStakeManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ event Withdraw(address staker, uint256 amount);

event Slashed(address staker, uint256 amount, bytes reason);

event Rebond(address staker, uint256 amount);

event UpdateUnbondingPeriod(uint256 period);

event StakeMinimumUpdate(uint256 minimum);
Expand All @@ -23,6 +25,7 @@ interface IStakeManager is IAccessControlEnumerable {

function stake(address staker, uint256 amount) external;
function unstake(address staker, uint256 amount) external;
function rebond(address staker, uint256 amount) external;
function withdraw() external;
function slash(address staker, uint256 amount, bytes calldata reason) external returns (uint256 slashed);

Expand All @@ -32,6 +35,7 @@ interface IStakeManager is IAccessControlEnumerable {
function getStake(address staker) external view returns (uint256);
function getTotalStaked() external view returns (uint256);
function getPendingUnbonds(address staker) external view returns (Unbond[] memory);
function getPendingUnbondTotal(address staker) external view returns (uint256);
function getUnbondingPeriod() external view returns (uint256);
function getTotalUnbonding() external view returns (uint256);
function getStakeMinimum() external view returns (uint256);
Expand Down
211 changes: 211 additions & 0 deletions test/StakeManager.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// SPDX-License-Identifier: Apache-2.0
pragma solidity ^0.8.24;

import "forge-std/Test.sol";
import "../src/StakeManager.sol";
import "../src/interfaces/IStakeManager.sol";
import "@openzeppelin/contracts/token/ERC20/ERC20.sol";

/*──────────────────────── Mocks ────────────────────────*/
contract MockERC20 is ERC20 {
constructor() ERC20("AI Token", "AI") {}

function mint(address to, uint256 amt) external {
_mint(to, amt);
}
}

/*─────────────────────── Test Suite ─────────────────────*/
contract StakeManagerTest is Test {
/* roles & addresses */
address internal prime = address(this); // gets PRIME + DEFAULT_ADMIN in ctor
address internal staker = address(0xBEEF);
bytes32 constant PRIME_ROLE = keccak256("PRIME_ROLE");

/* system under test */
MockERC20 internal token;
StakeManager internal mgr;

/* constants */
uint256 constant PERIOD = 7 days;
uint256 constant ONE = 1 ether;
uint256 constant HUND = 100 * ONE;

/*────────── Setup ─────────*/
function setUp() public {
token = new MockERC20();
mgr = new StakeManager(prime, PERIOD, IERC20(token));

token.mint(prime, 1_000_000 * ONE);
token.mint(staker, 1_000_000 * ONE);

token.approve(address(mgr), type(uint256).max);
vm.prank(staker);
token.approve(address(mgr), type(uint256).max); // not strictly needed but convenient
}

/*────────── helpers ─────────*/
function _stake(address who, uint256 amt) internal {
vm.prank(prime);
mgr.stake(who, amt);
}

function _unstake(address who, uint256 amt) internal {
vm.prank(prime);
mgr.unstake(who, amt);
}

/*────────── tests ─────────*/

function testStakeIncreasesBalances() public {
vm.expectEmit(true, false, false, true);
emit Stake(staker, HUND);

_stake(staker, HUND);

assertEq(mgr.getStake(staker), HUND);
assertEq(mgr.getTotalStaked(), HUND);
assertEq(token.balanceOf(address(mgr)), HUND);
assertEq(token.balanceOf(prime), 1_000_000 * ONE - HUND);
}

function testStakeRequiresPrime() public {
vm.prank(staker);
vm.expectRevert(); // AccessControl
mgr.stake(staker, ONE);
}

function testUnstakeMovesToPending() public {
_stake(staker, HUND);

vm.expectEmit(true, false, false, true);
emit Unstake(staker, 40 * ONE);

_unstake(staker, 40 * ONE);

assertEq(mgr.getStake(staker), 60 * ONE);
assertEq(mgr.getPendingUnbondTotal(staker), 40 * ONE);

StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker);
assertEq(arr.length, 1);
assertEq(arr[0].amount, 40 * ONE);
assertEq(arr[0].timestamp, block.timestamp + PERIOD);
}

function testWithdrawAfterPeriod() public {
_stake(staker, HUND);
_unstake(staker, HUND);

vm.warp(block.timestamp + PERIOD + 1);

uint256 balBefore = token.balanceOf(staker);

vm.prank(staker);
vm.expectEmit(true, false, false, true);
emit Withdraw(staker, HUND);
mgr.withdraw();

assertEq(token.balanceOf(staker), balBefore + HUND);
assertEq(mgr.getPendingUnbondTotal(staker), 0);
assertEq(mgr.getTotalUnbonding(), 0);
}

function testWithdrawRevertsIfNothing() public {
vm.prank(staker);
vm.expectRevert(bytes("StakeManager: no funds to withdraw"));
mgr.withdraw();
}

function testRebondPartial() public {
_stake(staker, HUND);
_unstake(staker, HUND);

vm.expectEmit(true, false, false, true);
emit Rebond(staker, 40 * ONE);

vm.prank(prime);
mgr.rebond(staker, 40 * ONE);

assertEq(mgr.getStake(staker), 40 * ONE + 0); // 40 rebonded on top of 0 residual
assertEq(mgr.getPendingUnbondTotal(staker), 60 * ONE);
StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker);
// first slot reduced by 40
assertEq(arr[0].amount, 60 * ONE);
}

function testRebondFull() public {
_stake(staker, HUND);
_unstake(staker, HUND);

vm.expectEmit(true, false, false, true);
emit Rebond(staker, 100 * ONE);

vm.prank(prime);
mgr.rebond(staker, 100 * ONE);

assertEq(mgr.getStake(staker), 100 * ONE);
assertEq(mgr.getPendingUnbondTotal(staker), 0);
StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker);
assertEq(arr.length, 0); // all unbonds rebonded
}

function testMaxPendingAggregates() public {
_stake(staker, 1000 * ONE);

// create 10 unbonds
for (uint256 i; i < 9; i++) {
_unstake(staker, ONE);
} // first 9 individual
_unstake(staker, ONE); // 10th — separate slot
// next call should aggregate into the 10th
_unstake(staker, ONE);

StakeManager.Unbond[] memory arr = mgr.getPendingUnbonds(staker);
assertEq(arr.length, 10);
assertEq(arr[9].amount, 2 * ONE); // aggregated
}

function testSlashFromStake() public {
_stake(staker, HUND);

vm.expectEmit(true, false, false, true);
emit Slashed(staker, 30 * ONE, "");

vm.prank(prime);
uint256 slashed = mgr.slash(staker, 30 * ONE, "");

assertEq(slashed, 30 * ONE);
assertEq(mgr.getStake(staker), 70 * ONE);
assertEq(token.balanceOf(prime), 1_000_000 * ONE - HUND + 30 * ONE);
}

function testSlashPrefersUnbonding() public {
_stake(staker, HUND);
_unstake(staker, HUND); // entire stake now pending

vm.prank(prime);
uint256 slashed = mgr.slash(staker, 60 * ONE, "");

assertEq(slashed, 60 * ONE);
assertEq(mgr.getPendingUnbondTotal(staker), 40 * ONE);
assertEq(mgr.getTotalUnbonding(), 40 * ONE);
}

function testSettersRequirePrime() public {
vm.prank(staker);
vm.expectRevert(); // AccessControl
mgr.setUnbondingPeriod(1 days);

vm.prank(prime);
mgr.setUnbondingPeriod(1 days);
assertEq(mgr.getUnbondingPeriod(), 1 days);

vm.prank(staker);
vm.expectRevert();
mgr.setStakeMinimum(10 * ONE);

vm.prank(prime);
mgr.setStakeMinimum(10 * ONE);
assertEq(mgr.getStakeMinimum(), 10 * ONE);
}
}