diff --git a/src/ComputePool.sol b/src/ComputePool.sol index 7976464..95d2f68 100644 --- a/src/ComputePool.sol +++ b/src/ComputePool.sol @@ -369,6 +369,21 @@ contract ComputePool is IComputePool, AccessControlEnumerable { return (provider, node); } + function softInvalidateWork(uint256 poolId, bytes calldata data) + external + onlyExistingPool(poolId) + onlyRole(PRIME_ROLE) + returns (address, address, uint256) + { + IDomainRegistry.Domain memory domainInfo = domainRegistry.get(pools[poolId].domainId); + IWorkValidation workValidation = IWorkValidation(domainInfo.validationLogic); + (address provider, address node, uint256 workUnits) = workValidation.softInvalidateWork(poolId, data); + IRewardsDistributor rewardsDistributor = poolStates[poolId].rewardsDistributor; + rewardsDistributor.removeWork(node, workUnits); + // Note: we don't eject the node with soft invalidation + return (provider, node, workUnits); + } + // // Management functions // diff --git a/src/PrimeNetwork.sol b/src/PrimeNetwork.sol index 40680a4..228a8e4 100644 --- a/src/PrimeNetwork.sol +++ b/src/PrimeNetwork.sol @@ -214,6 +214,16 @@ contract PrimeNetwork is AccessControlEnumerable { } catch {} } + function softInvalidateWork(uint256 poolId, bytes calldata data) + external + onlyRole(VALIDATOR_ROLE) + returns (address, address, uint256) + { + (address provider, address node, uint256 workUnits) = computePool.softInvalidateWork(poolId, data); + // Note: No slashing, no node invalidation, no blacklisting - just remove the work + return (provider, node, workUnits); + } + function _verifyNodekeySignature(address provider, address nodekey, bytes memory signature) internal view diff --git a/src/RewardsDistributor.sol b/src/RewardsDistributor.sol index 66a0945..c2499aa 100644 --- a/src/RewardsDistributor.sol +++ b/src/RewardsDistributor.sol @@ -203,4 +203,11 @@ contract RewardsDistributor is IRewardsDistributor, AccessControlEnumerable { workUnits == workUnits; return; } + + function removeWork(address node, uint256 workUnits) external pure { + // suppress warnings - not implemented in this distributor type + node; + workUnits; + return; + } } diff --git a/src/RewardsDistributorFixed.sol b/src/RewardsDistributorFixed.sol index 2d9de75..69053fe 100644 --- a/src/RewardsDistributorFixed.sol +++ b/src/RewardsDistributorFixed.sol @@ -190,4 +190,12 @@ contract RewardsDistributorFixed is IRewardsDistributor, AccessControlEnumerable workUnits == workUnits; return; } + + // Soft slash function for work removal (not implemented in Fixed distributor) + function removeWork(address node, uint256 workUnits) external pure { + // suppress warnings - not implemented in fixed distributor + node; + workUnits; + return; + } } diff --git a/src/RewardsDistributorWorkSubmission.sol b/src/RewardsDistributorWorkSubmission.sol index 8ee5f2c..5e9e698 100644 --- a/src/RewardsDistributorWorkSubmission.sol +++ b/src/RewardsDistributorWorkSubmission.sol @@ -91,20 +91,20 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE // -------------------------------------------------------------------------------------------- /// @notice Called by the pool to record that `node` performed `workUnits`. - /// This increments the node’s current bucket, ensuring O(1) ring buffer updates. + /// This increments the node's current bucket, ensuring O(1) ring buffer updates. function submitWork(address node, uint256 workUnits) external onlyRole(COMPUTE_POOL_ROLE) { require(endTime == 0, "Rewards have ended"); require(computePool.isNodeInPool(poolId, node), "Node not in pool"); NodeBuckets storage nb = nodeBuckets[node]; - // Roll forward first to ensure we’re in the correct active bucket + // Roll forward first to ensure we're in the correct active bucket _rollBuckets(node); // Increment the current bucket nb.buckets[nb.currentBucket] += workUnits; nb.totalLast24H += workUnits; - // Track an all-time total if you want to do “locked/unlocked” logic + // Track an all-time total if you want to do "locked/unlocked" logic nb.totalAllSubmissions += workUnits; // Optionally, ensure lastBucketTimestamp is set if first time @@ -120,9 +120,9 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE /** * @notice Bucket approach: * - totalAllSubmissions: total submissions ever done by this node. - * - totalLast24H: the sum of the ring buffer’s most recent 24h. - * We treat that as “locked.” - * - The difference (totalAllSubmissions - totalLast24H) is “unlocked.” + * - totalLast24H: the sum of the ring buffer's most recent 24h. + * We treat that as "locked." + * - The difference (totalAllSubmissions - totalLast24H) is "unlocked." * - We track lastClaimed to ensure we only pay incremental amounts. */ function claimRewards(address node) external { @@ -196,7 +196,7 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE NodeBuckets memory nb = nodeBuckets[node]; - // Simulate the ring buffer if updated “now” + // Simulate the ring buffer if updated "now" uint256 elapsed = (block.timestamp - nb.lastBucketTimestamp) / BUCKET_DURATION; uint256 simulatedTotalLast24H = nb.totalLast24H; if (elapsed >= NUM_BUCKETS) { @@ -211,7 +211,7 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE simulatedTotalLast24H -= nb.buckets[idx]; } } - // “Unlocked so far” if we hypothetically updated now + // "Unlocked so far" if we hypothetically updated now uint256 unlockedNow = nb.totalAllSubmissions - simulatedTotalLast24H; uint256 claimable = unlockedNow - nb.lastClaimed; uint256 claimableTokens = claimable * rewardRatePerUnit; @@ -252,7 +252,7 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE } function leavePool(address node) external onlyRole(COMPUTE_POOL_ROLE) { - // Optionally roll + finalize the node’s data. Zero out buckets, etc. + // Optionally roll + finalize the node's data. Zero out buckets, etc. _rollBuckets(node); } @@ -261,4 +261,51 @@ contract RewardsDistributorWorkSubmission is IRewardsDistributor, AccessControlE require(endTime == 0, "Already ended"); endTime = block.timestamp; } + + // -------------------------------------------------------------------------------------------- + // Soft slash functions for batch work scenarios + // -------------------------------------------------------------------------------------------- + + /** + * @notice Removes specific work units from a node's submissions (soft slash). + * This is useful when work is incomplete and you want to remove + * work submissions without harsh penalties. + * @param node The address of the node whose work should be removed. + * @param workUnits The number of work units to remove. + * @dev This function can only be called by the COMPUTE_POOL_ROLE. + * It removes work from the current bucket and adjusts totals accordingly. + */ + function removeWork(address node, uint256 workUnits) external onlyRole(COMPUTE_POOL_ROLE) { + if (workUnits == 0) return; // No-op if no work to remove + + NodeBuckets storage nb = nodeBuckets[node]; + _rollBuckets(node); + + // Ensure we don't remove more than what exists + uint256 toRemove = workUnits > nb.totalLast24H ? nb.totalLast24H : workUnits; + if (toRemove == 0) return; // Nothing to remove + + // Remove from current bucket first, then work backwards if needed + uint256 remaining = toRemove; + uint256 bucketIdx = nb.currentBucket; + + // Remove from buckets, starting with the current one + for (uint256 i = 0; i < NUM_BUCKETS && remaining > 0; i++) { + uint256 bucketAmount = nb.buckets[bucketIdx]; + if (bucketAmount > 0) { + uint256 removeFromBucket = remaining > bucketAmount ? bucketAmount : remaining; + nb.buckets[bucketIdx] -= removeFromBucket; + remaining -= removeFromBucket; + } + // Move to previous bucket (circular) + bucketIdx = bucketIdx == 0 ? NUM_BUCKETS - 1 : bucketIdx - 1; + } + + // Update totals + uint256 actualRemoved = toRemove - remaining; + nb.totalLast24H -= actualRemoved; + + // Also reduce total submissions to maintain consistency + nb.totalAllSubmissions = nb.totalAllSubmissions > actualRemoved ? nb.totalAllSubmissions - actualRemoved : 0; + } } diff --git a/src/SyntheticDataWorkValidator.sol b/src/SyntheticDataWorkValidator.sol index 00dbdd0..fe0e9fe 100644 --- a/src/SyntheticDataWorkValidator.sol +++ b/src/SyntheticDataWorkValidator.sol @@ -8,6 +8,8 @@ event WorkSubmitted(uint256 poolId, address provider, address nodeId, bytes32 wo event WorkInvalidated(uint256 poolId, address provider, address nodeId, bytes32 workKey, uint256 workUnits); +event WorkRemoved(uint256 poolId, address provider, address nodeId, bytes32 workKey, uint256 workUnits); + contract SyntheticDataWorkValidator is IWorkValidation { using EnumerableSet for EnumerableSet.Bytes32Set; @@ -84,6 +86,37 @@ contract SyntheticDataWorkValidator is IWorkValidation { return (info.provider, info.nodeId); } + function softInvalidateWork(uint256 poolId, bytes calldata data) external returns (address, address, uint256) { + require(msg.sender == computePool, "Unauthorized"); + require(data.length >= 64, "Data too short for soft invalidation"); + + // Decode both workKey and workUnits for soft invalidation + bytes32 workKey; + uint256 workUnits; + assembly { + workKey := calldataload(data.offset) + workUnits := calldataload(add(data.offset, 32)) + } + + require(poolWork[poolId].workKeys.contains(workKey), "Work not found"); + require(!poolWork[poolId].invalidWorkKeys.contains(workKey), "Work already invalidated"); + require( + block.timestamp - poolWork[poolId].work[workKey].timestamp < workValidityPeriod, + "Work invalidation window has lapsed" + ); + + WorkInfo memory info = poolWork[poolId].work[workKey]; + + // Move work from valid to invalid (same as hard invalidation) + poolWork[poolId].invalidWorkKeys.add(workKey); + poolWork[poolId].workKeys.remove(workKey); + + emit WorkRemoved(poolId, info.provider, info.nodeId, workKey, workUnits); + + // Return the specified work units to be soft-removed (key difference) + return (info.provider, info.nodeId, workUnits); + } + function getWorkInfo(uint256 poolId, bytes32 workKey) external view returns (WorkInfo memory) { return poolWork[poolId].work[workKey]; } diff --git a/src/interfaces/IComputePool.sol b/src/interfaces/IComputePool.sol index f556b0f..2867ead 100644 --- a/src/interfaces/IComputePool.sol +++ b/src/interfaces/IComputePool.sol @@ -92,6 +92,7 @@ interface IComputePool is IAccessControlEnumerable { function ejectNode(uint256 poolId, address nodekey) external; function submitWork(uint256 poolId, address nodekey, bytes calldata data) external; function invalidateWork(uint256 poolId, bytes calldata data) external returns (address, address); + function softInvalidateWork(uint256 poolId, bytes calldata data) external returns (address, address, uint256); function blacklistProvider(uint256 poolId, address provider) external; function blacklistProviderList(uint256 poolId, address[] memory providers) external; function blacklistAndPurgeProvider(uint256 poolId, address provider) external; diff --git a/src/interfaces/IRewardsDistributor.sol b/src/interfaces/IRewardsDistributor.sol index 070740a..592e4c9 100644 --- a/src/interfaces/IRewardsDistributor.sol +++ b/src/interfaces/IRewardsDistributor.sol @@ -19,4 +19,7 @@ interface IRewardsDistributor { function joinPool(address node) external; function leavePool(address node) external; function submitWork(address node, uint256 workUnits) external; + + // Soft slash function for work removal + function removeWork(address node, uint256 workUnits) external; } diff --git a/src/interfaces/IWorkValidation.sol b/src/interfaces/IWorkValidation.sol index 0990412..f6e863f 100644 --- a/src/interfaces/IWorkValidation.sol +++ b/src/interfaces/IWorkValidation.sol @@ -7,4 +7,6 @@ interface IWorkValidation { returns (bool, uint256); function invalidateWork(uint256 poolId, bytes calldata data) external returns (address, address); + + function softInvalidateWork(uint256 poolId, bytes calldata data) external returns (address, address, uint256); } diff --git a/test/RewardsDistributorWorkSubmission.t.sol b/test/RewardsDistributorWorkSubmission.t.sol index bc75581..dba79b9 100644 --- a/test/RewardsDistributorWorkSubmission.t.sol +++ b/test/RewardsDistributorWorkSubmission.t.sol @@ -480,7 +480,7 @@ contract RewardsDistributorWorkSubmissionRingBufferTest is Test { uint256 unlockedNow = fetchRewards(node, false); assertEq(unlockedNow, 100, "First 100 is unlocked, second 200 is locked."); - // 3) Slash the node’s pending 24h => manager only + // 3) Slash the node's pending 24h => manager only // That should remove the locked 200 from totalAll, zero the ring buffer, etc. vm.prank(manager); distributor.slashPendingRewards(node); @@ -489,36 +489,118 @@ contract RewardsDistributorWorkSubmissionRingBufferTest is Test { (uint256 last24HAfter, uint256 totalAllAfter,,) = distributor.nodeInfo(node); assertEq(last24HAfter, 0, "Should have cleared the ring buffer"); assertEq(totalAllAfter, 100, "Should have subtracted the slashed 200 from totalAll"); - // lastClaimed should remain the same, because we didn’t claim. + // lastClaimed should remain the same, because we didn't claim. // 5) Confirm that now, if we skip 25 hours more, there is no "locked" portion to unlock skip(25 hours); uint256 unlockedAfterSlash = fetchRewards(node, false); // The 100 is still unlocked, but we never claimed it, so it remains unclaimed. - // Because slash only subtracted from totalAll the “locked” portion, the older 100 is unaffected. + // Because slash only subtracted from totalAll the "locked" portion, the older 100 is unaffected. // So unlockedAfterSlash == 100 - lastClaimedAfter. But we haven't claimed at all, so lastClaimedAfter=0. assertEq(unlockedAfterSlash, 100, "The older 100 remains claimable."); + } - // 6) Claim the 100 - vm.prank(nodeProvider); - distributor.claimRewards(node); + // ----------------------------------------------------------------------- + // Test: Soft slash functionality - removeWork + // ----------------------------------------------------------------------- + function testRemoveWork() public { + vm.prank(address(mockComputePool)); + mockComputePool.joinComputePool(node, 10); - // Confirm node got 100 - uint256 nodeBalance = mockRewardToken.balanceOf(nodeProvider); - assertEq(nodeBalance, 100, "Node should receive the older (unlocked) 100 tokens"); + // Submit some work + vm.prank(address(mockComputePool)); + distributor.submitWork(node, 500); + + (uint256 last24HBefore, uint256 totalAllBefore,,) = distributor.nodeInfo(node); + assertEq(last24HBefore, 500); + assertEq(totalAllBefore, 500); + + // Remove 200 work units + vm.prank(address(mockComputePool)); + distributor.removeWork(node, 200); + + (uint256 last24HAfter, uint256 totalAllAfter,,) = distributor.nodeInfo(node); + assertEq(last24HAfter, 300, "Should have removed 200 work units"); + assertEq(totalAllAfter, 300, "Total submissions should also be reduced"); + } + + function testRemoveWorkExceedsAvailable() public { + vm.prank(address(mockComputePool)); + mockComputePool.joinComputePool(node, 10); + + // Submit 100 work units + vm.prank(address(mockComputePool)); + distributor.submitWork(node, 100); + + // Try to remove 200 (more than available) + vm.prank(address(mockComputePool)); + distributor.removeWork(node, 200); + + // Should only remove what's available (100) + (uint256 last24H, uint256 totalAll,,) = distributor.nodeInfo(node); + assertEq(last24H, 0, "Should have removed all available work"); + assertEq(totalAll, 0, "Total should also be zero"); + } + + function testRemoveWorkAcrossMultipleBuckets() public { + vm.prank(address(mockComputePool)); + mockComputePool.joinComputePool(node, 10); + + // Submit work in different time buckets + vm.prank(address(mockComputePool)); + distributor.submitWork(node, 100); + + skip(2 hours); // Move to next bucket + vm.prank(address(mockComputePool)); + distributor.submitWork(node, 200); + + skip(2 hours); // Move to next bucket + vm.prank(address(mockComputePool)); + distributor.submitWork(node, 300); + + // Total should be 600 + (uint256 last24HBefore, uint256 totalAllBefore,,) = distributor.nodeInfo(node); + assertEq(last24HBefore, 600); + assertEq(totalAllBefore, 600); + + // Remove 450 work units (should span multiple buckets) + vm.prank(address(mockComputePool)); + distributor.removeWork(node, 450); + + (uint256 last24HAfter, uint256 totalAllAfter,,) = distributor.nodeInfo(node); + assertEq(last24HAfter, 150, "Should have removed 450 work units"); + assertEq(totalAllAfter, 150, "Total should also be reduced"); } // ----------------------------------------------------------------------- - // Test: setRewardRate and endRewards in ring-buffer version + // Test: Your use case - multiple nodes (call multiple times) // ----------------------------------------------------------------------- - function testSetRewardRate() public { - vm.prank(manager); - vm.expectRevert(); - distributor.setRewardRate(12345); - } + function testMultipleNodeRemoval() public { + // Setup multiple nodes + vm.prank(address(mockComputePool)); + mockComputePool.joinComputePool(node1, 10); + vm.prank(address(mockComputePool)); + mockComputePool.joinComputePool(node2, 10); - function testEndRewards() public { + // Submit work for both nodes vm.prank(address(mockComputePool)); - distributor.endRewards(); + distributor.submitWork(node1, 100); + vm.prank(address(mockComputePool)); + distributor.submitWork(node2, 150); + + // Remove work from both nodes (call multiple times) + vm.prank(address(mockComputePool)); + distributor.removeWork(node1, 80); + vm.prank(address(mockComputePool)); + distributor.removeWork(node2, 80); + + // Check results - both should have 80 removed + (uint256 last24H1, uint256 totalAll1,,) = distributor.nodeInfo(node1); + (uint256 last24H2, uint256 totalAll2,,) = distributor.nodeInfo(node2); + + assertEq(last24H1, 20, "Node1 should have 20 remaining (100-80)"); + assertEq(totalAll1, 20, "Node1 total should be 20"); + assertEq(last24H2, 70, "Node2 should have 70 remaining (150-80)"); + assertEq(totalAll2, 70, "Node2 total should be 70"); } }