Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/ComputePool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,21 @@ contract ComputePool is IComputePool, AccessControlEnumerable {
return (provider, node);
}

function softInvalidateWork(uint256 poolId, bytes calldata data)
external
onlyExistingPool(poolId)
onlyRole(PRIME_ROLE)
returns (address, address, uint256)
{
IDomainRegistry.Domain memory domainInfo = domainRegistry.get(pools[poolId].domainId);
IWorkValidation workValidation = IWorkValidation(domainInfo.validationLogic);
(address provider, address node, uint256 workUnits) = workValidation.softInvalidateWork(poolId, data);
IRewardsDistributor rewardsDistributor = poolStates[poolId].rewardsDistributor;
rewardsDistributor.removeWork(node, workUnits);
// Note: we don't eject the node with soft invalidation
return (provider, node, workUnits);
}

//
// Management functions
//
Expand Down
10 changes: 10 additions & 0 deletions src/PrimeNetwork.sol
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ contract PrimeNetwork is AccessControlEnumerable {
} catch {}
}

function softInvalidateWork(uint256 poolId, bytes calldata data)
external
onlyRole(VALIDATOR_ROLE)
returns (address, address, uint256)
{
(address provider, address node, uint256 workUnits) = computePool.softInvalidateWork(poolId, data);
// Note: No slashing, no node invalidation, no blacklisting - just remove the work
return (provider, node, workUnits);
}

function _verifyNodekeySignature(address provider, address nodekey, bytes memory signature)
internal
view
Expand Down
7 changes: 7 additions & 0 deletions src/RewardsDistributor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,11 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable {
workUnits == workUnits;
return;
}

function removeWork(address node, uint256 workUnits) external pure {
// suppress warnings - not implemented in this distributor type
node;
workUnits;
return;
}
}
8 changes: 8 additions & 0 deletions src/RewardsDistributorFixed.sol
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,12 @@ contract RewardsDistributorFixed is IRewardsDistributor, AccessControlEnumerable
workUnits == workUnits;
return;
}

// Soft slash function for work removal (not implemented in Fixed distributor)
function removeWork(address node, uint256 workUnits) external pure {
// suppress warnings - not implemented in fixed distributor
node;
workUnits;
return;
}
}
65 changes: 56 additions & 9 deletions src/RewardsDistributorWorkSubmission.sol
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,20 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE
// --------------------------------------------------------------------------------------------

/// @notice Called by the pool to record that `node` performed `workUnits`.
/// This increments the nodes current bucket, ensuring O(1) ring buffer updates.
/// This increments the node's current bucket, ensuring O(1) ring buffer updates.
function submitWork(address node, uint256 workUnits) external onlyRole(COMPUTE_POOL_ROLE) {
require(endTime == 0, "Rewards have ended");
require(computePool.isNodeInPool(poolId, node), "Node not in pool");

NodeBuckets storage nb = nodeBuckets[node];
// Roll forward first to ensure were in the correct active bucket
// Roll forward first to ensure we're in the correct active bucket
_rollBuckets(node);

// Increment the current bucket
nb.buckets[nb.currentBucket] += workUnits;
nb.totalLast24H += workUnits;

// Track an all-time total if you want to do locked/unlocked logic
// Track an all-time total if you want to do "locked/unlocked" logic
nb.totalAllSubmissions += workUnits;

// Optionally, ensure lastBucketTimestamp is set if first time
Expand All @@ -120,9 +120,9 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE
/**
* @notice Bucket approach:
* - totalAllSubmissions: total submissions ever done by this node.
* - totalLast24H: the sum of the ring buffers most recent 24h.
* We treat that as locked.
* - The difference (totalAllSubmissions - totalLast24H) is unlocked.
* - totalLast24H: the sum of the ring buffer's most recent 24h.
* We treat that as "locked."
* - The difference (totalAllSubmissions - totalLast24H) is "unlocked."
* - We track lastClaimed to ensure we only pay incremental amounts.
*/
function claimRewards(address node) external {
Expand Down Expand Up @@ -196,7 +196,7 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE

NodeBuckets memory nb = nodeBuckets[node];

// Simulate the ring buffer if updated now
// Simulate the ring buffer if updated "now"
uint256 elapsed = (block.timestamp - nb.lastBucketTimestamp) / BUCKET_DURATION;
uint256 simulatedTotalLast24H = nb.totalLast24H;
if (elapsed >= NUM_BUCKETS) {
Expand All @@ -211,7 +211,7 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE
simulatedTotalLast24H -= nb.buckets[idx];
}
}
// Unlocked so far if we hypothetically updated now
// "Unlocked so far" if we hypothetically updated now
uint256 unlockedNow = nb.totalAllSubmissions - simulatedTotalLast24H;
uint256 claimable = unlockedNow - nb.lastClaimed;
uint256 claimableTokens = claimable * rewardRatePerUnit;
Expand Down Expand Up @@ -252,7 +252,7 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE
}

function leavePool(address node) external onlyRole(COMPUTE_POOL_ROLE) {
// Optionally roll + finalize the nodes data. Zero out buckets, etc.
// Optionally roll + finalize the node's data. Zero out buckets, etc.
_rollBuckets(node);
}

Expand All @@ -261,4 +261,51 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE
require(endTime == 0, "Already ended");
endTime = block.timestamp;
}

// --------------------------------------------------------------------------------------------
// Soft slash functions for batch work scenarios
// --------------------------------------------------------------------------------------------

/**
* @notice Removes specific work units from a node's submissions (soft slash).
* This is useful when work is incomplete and you want to remove
* work submissions without harsh penalties.
* @param node The address of the node whose work should be removed.
* @param workUnits The number of work units to remove.
* @dev This function can only be called by the COMPUTE_POOL_ROLE.
* It removes work from the current bucket and adjusts totals accordingly.
*/
function removeWork(address node, uint256 workUnits) external onlyRole(COMPUTE_POOL_ROLE) {
if (workUnits == 0) return; // No-op if no work to remove

NodeBuckets storage nb = nodeBuckets[node];
_rollBuckets(node);

// Ensure we don't remove more than what exists
uint256 toRemove = workUnits > nb.totalLast24H ? nb.totalLast24H : workUnits;
if (toRemove == 0) return; // Nothing to remove

// Remove from current bucket first, then work backwards if needed
uint256 remaining = toRemove;
uint256 bucketIdx = nb.currentBucket;

// Remove from buckets, starting with the current one
for (uint256 i = 0; i < NUM_BUCKETS && remaining > 0; i++) {
uint256 bucketAmount = nb.buckets[bucketIdx];
if (bucketAmount > 0) {
uint256 removeFromBucket = remaining > bucketAmount ? bucketAmount : remaining;
nb.buckets[bucketIdx] -= removeFromBucket;
remaining -= removeFromBucket;
}
// Move to previous bucket (circular)
bucketIdx = bucketIdx == 0 ? NUM_BUCKETS - 1 : bucketIdx - 1;
}

// Update totals
uint256 actualRemoved = toRemove - remaining;
nb.totalLast24H -= actualRemoved;

// Also reduce total submissions to maintain consistency
nb.totalAllSubmissions = nb.totalAllSubmissions > actualRemoved ? nb.totalAllSubmissions - actualRemoved : 0;
}
}
33 changes: 33 additions & 0 deletions src/SyntheticDataWorkValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ event WorkSubmitted(uint256 poolId, address provider, address nodeId, bytes32 wo

event WorkInvalidated(uint256 poolId, address provider, address nodeId, bytes32 workKey, uint256 workUnits);

event WorkRemoved(uint256 poolId, address provider, address nodeId, bytes32 workKey, uint256 workUnits);

contract SyntheticDataWorkValidator is IWorkValidation {
using EnumerableSet for EnumerableSet.Bytes32Set;

Expand Down Expand Up @@ -84,6 +86,37 @@ contract SyntheticDataWorkValidator is IWorkValidation {
return (info.provider, info.nodeId);
}

function softInvalidateWork(uint256 poolId, bytes calldata data) external returns (address, address, uint256) {
require(msg.sender == computePool, "Unauthorized");
require(data.length >= 64, "Data too short for soft invalidation");

// Decode both workKey and workUnits for soft invalidation
bytes32 workKey;
uint256 workUnits;
assembly {
workKey := calldataload(data.offset)
workUnits := calldataload(add(data.offset, 32))
}

require(poolWork[poolId].workKeys.contains(workKey), "Work not found");
require(!poolWork[poolId].invalidWorkKeys.contains(workKey), "Work already invalidated");
require(
block.timestamp - poolWork[poolId].work[workKey].timestamp < workValidityPeriod,
"Work invalidation window has lapsed"
);

WorkInfo memory info = poolWork[poolId].work[workKey];

// Move work from valid to invalid (same as hard invalidation)
poolWork[poolId].invalidWorkKeys.add(workKey);
poolWork[poolId].workKeys.remove(workKey);

emit WorkRemoved(poolId, info.provider, info.nodeId, workKey, workUnits);

// Return the specified work units to be soft-removed (key difference)
return (info.provider, info.nodeId, workUnits);
}

function getWorkInfo(uint256 poolId, bytes32 workKey) external view returns (WorkInfo memory) {
return poolWork[poolId].work[workKey];
}
Expand Down
1 change: 1 addition & 0 deletions src/interfaces/IComputePool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ interface IComputePool is IAccessControlEnumerable {
function ejectNode(uint256 poolId, address nodekey) external;
function submitWork(uint256 poolId, address nodekey, bytes calldata data) external;
function invalidateWork(uint256 poolId, bytes calldata data) external returns (address, address);
function softInvalidateWork(uint256 poolId, bytes calldata data) external returns (address, address, uint256);
function blacklistProvider(uint256 poolId, address provider) external;
function blacklistProviderList(uint256 poolId, address[] memory providers) external;
function blacklistAndPurgeProvider(uint256 poolId, address provider) external;
Expand Down
3 changes: 3 additions & 0 deletions src/interfaces/IRewardsDistributor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ interface IRewardsDistributor {
function joinPool(address node) external;
function leavePool(address node) external;
function submitWork(address node, uint256 workUnits) external;

// Soft slash function for work removal
function removeWork(address node, uint256 workUnits) external;
}
2 changes: 2 additions & 0 deletions src/interfaces/IWorkValidation.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ interface IWorkValidation {
returns (bool, uint256);

function invalidateWork(uint256 poolId, bytes calldata data) external returns (address, address);

function softInvalidateWork(uint256 poolId, bytes calldata data) external returns (address, address, uint256);
}
116 changes: 99 additions & 17 deletions test/RewardsDistributorWorkSubmission.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ contract RewardsDistributorWorkSubmissionRingBufferTest is Test {
uint256 unlockedNow = fetchRewards(node, false);
assertEq(unlockedNow, 100, "First 100 is unlocked, second 200 is locked.");

// 3) Slash the nodes pending 24h => manager only
// 3) Slash the node's pending 24h => manager only
// That should remove the locked 200 from totalAll, zero the ring buffer, etc.
vm.prank(manager);
distributor.slashPendingRewards(node);
Expand All @@ -489,36 +489,118 @@ contract RewardsDistributorWorkSubmissionRingBufferTest is Test {
(uint256 last24HAfter, uint256 totalAllAfter,,) = distributor.nodeInfo(node);
assertEq(last24HAfter, 0, "Should have cleared the ring buffer");
assertEq(totalAllAfter, 100, "Should have subtracted the slashed 200 from totalAll");
// lastClaimed should remain the same, because we didnt claim.
// lastClaimed should remain the same, because we didn't claim.

// 5) Confirm that now, if we skip 25 hours more, there is no "locked" portion to unlock
skip(25 hours);
uint256 unlockedAfterSlash = fetchRewards(node, false);
// The 100 is still unlocked, but we never claimed it, so it remains unclaimed.
// Because slash only subtracted from totalAll the locked portion, the older 100 is unaffected.
// Because slash only subtracted from totalAll the "locked" portion, the older 100 is unaffected.
// So unlockedAfterSlash == 100 - lastClaimedAfter. But we haven't claimed at all, so lastClaimedAfter=0.
assertEq(unlockedAfterSlash, 100, "The older 100 remains claimable.");
}

// 6) Claim the 100
vm.prank(nodeProvider);
distributor.claimRewards(node);
// -----------------------------------------------------------------------
// Test: Soft slash functionality - removeWork
// -----------------------------------------------------------------------
function testRemoveWork() public {
vm.prank(address(mockComputePool));
mockComputePool.joinComputePool(node, 10);

// Confirm node got 100
uint256 nodeBalance = mockRewardToken.balanceOf(nodeProvider);
assertEq(nodeBalance, 100, "Node should receive the older (unlocked) 100 tokens");
// Submit some work
vm.prank(address(mockComputePool));
distributor.submitWork(node, 500);

(uint256 last24HBefore, uint256 totalAllBefore,,) = distributor.nodeInfo(node);
assertEq(last24HBefore, 500);
assertEq(totalAllBefore, 500);

// Remove 200 work units
vm.prank(address(mockComputePool));
distributor.removeWork(node, 200);

(uint256 last24HAfter, uint256 totalAllAfter,,) = distributor.nodeInfo(node);
assertEq(last24HAfter, 300, "Should have removed 200 work units");
assertEq(totalAllAfter, 300, "Total submissions should also be reduced");
}

function testRemoveWorkExceedsAvailable() public {
vm.prank(address(mockComputePool));
mockComputePool.joinComputePool(node, 10);

// Submit 100 work units
vm.prank(address(mockComputePool));
distributor.submitWork(node, 100);

// Try to remove 200 (more than available)
vm.prank(address(mockComputePool));
distributor.removeWork(node, 200);

// Should only remove what's available (100)
(uint256 last24H, uint256 totalAll,,) = distributor.nodeInfo(node);
assertEq(last24H, 0, "Should have removed all available work");
assertEq(totalAll, 0, "Total should also be zero");
}

function testRemoveWorkAcrossMultipleBuckets() public {
vm.prank(address(mockComputePool));
mockComputePool.joinComputePool(node, 10);

// Submit work in different time buckets
vm.prank(address(mockComputePool));
distributor.submitWork(node, 100);

skip(2 hours); // Move to next bucket
vm.prank(address(mockComputePool));
distributor.submitWork(node, 200);

skip(2 hours); // Move to next bucket
vm.prank(address(mockComputePool));
distributor.submitWork(node, 300);

// Total should be 600
(uint256 last24HBefore, uint256 totalAllBefore,,) = distributor.nodeInfo(node);
assertEq(last24HBefore, 600);
assertEq(totalAllBefore, 600);

// Remove 450 work units (should span multiple buckets)
vm.prank(address(mockComputePool));
distributor.removeWork(node, 450);

(uint256 last24HAfter, uint256 totalAllAfter,,) = distributor.nodeInfo(node);
assertEq(last24HAfter, 150, "Should have removed 450 work units");
assertEq(totalAllAfter, 150, "Total should also be reduced");
}

// -----------------------------------------------------------------------
// Test: setRewardRate and endRewards in ring-buffer version
// Test: Your use case - multiple nodes (call multiple times)
// -----------------------------------------------------------------------
function testSetRewardRate() public {
vm.prank(manager);
vm.expectRevert();
distributor.setRewardRate(12345);
}
function testMultipleNodeRemoval() public {
// Setup multiple nodes
vm.prank(address(mockComputePool));
mockComputePool.joinComputePool(node1, 10);
vm.prank(address(mockComputePool));
mockComputePool.joinComputePool(node2, 10);

function testEndRewards() public {
// Submit work for both nodes
vm.prank(address(mockComputePool));
distributor.endRewards();
distributor.submitWork(node1, 100);
vm.prank(address(mockComputePool));
distributor.submitWork(node2, 150);

// Remove work from both nodes (call multiple times)
vm.prank(address(mockComputePool));
distributor.removeWork(node1, 80);
vm.prank(address(mockComputePool));
distributor.removeWork(node2, 80);

// Check results - both should have 80 removed
(uint256 last24H1, uint256 totalAll1,,) = distributor.nodeInfo(node1);
(uint256 last24H2, uint256 totalAll2,,) = distributor.nodeInfo(node2);

assertEq(last24H1, 20, "Node1 should have 20 remaining (100-80)");
assertEq(totalAll1, 20, "Node1 total should be 20");
assertEq(last24H2, 70, "Node2 should have 70 remaining (150-80)");
assertEq(totalAll2, 70, "Node2 total should be 70");
}
}