diff --git a/tests/integration/x/vm/test_gas.go b/tests/integration/x/vm/test_gas.go new file mode 100644 index 00000000..bce4c7d2 --- /dev/null +++ b/tests/integration/x/vm/test_gas.go @@ -0,0 +1,206 @@ +package vm + +import ( + "math/big" + + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + "github.com/cosmos/evm/testutil/integration/evm/factory" + "github.com/cosmos/evm/testutil/integration/evm/grpc" + testkeyring "github.com/cosmos/evm/testutil/keyring" + erc20mocks "github.com/cosmos/evm/x/erc20/types/mocks" + "github.com/cosmos/evm/x/vm/keeper" + "github.com/cosmos/evm/x/vm/types" + "go.uber.org/mock/gomock" +) + +const ( + DefaultCoreMsgGasUsage = 21000 + DefaultGasPrice = 120000 +) + +// TestGasRefundGas tests the refund gas exclusively without going though the state transition +// The gas part on the name refers to the file name to not generate a duplicated test name +func (suite *KeeperTestSuite) TestGasRefundGas() { + // Create a txFactory + grpcHandler := grpc.NewIntegrationHandler(suite.Network) + txFactory := factory.New(suite.Network, grpcHandler) + + // Create a core message to use for the test + keyring := testkeyring.New(2) + sender := keyring.GetKey(0) + recipient := keyring.GetAddr(1) + coreMsg, err := txFactory.GenerateGethCoreMsg( + sender.Priv, + types.EvmTxArgs{ + To: &recipient, + Amount: big.NewInt(100), + GasPrice: big.NewInt(120000), + }, + ) + suite.Require().NoError(err) + + // Produce all the test cases + testCases := []struct { + name string + leftoverGas uint64 // The coreMsg always uses 21000 gas limit + malleate func(sdk.Context) sdk.Context + expectedRefund sdk.Coins + errContains string + }{ + { + name: "Refund the full value as no gas was used", + leftoverGas: DefaultCoreMsgGasUsage, + expectedRefund: sdk.NewCoins( + sdk.NewCoin(suite.Network.GetBaseDenom(), sdkmath.NewInt(DefaultCoreMsgGasUsage*DefaultGasPrice)), + ), + }, + { + name: "Refund half the value as half gas was used", + leftoverGas: DefaultCoreMsgGasUsage / 2, + expectedRefund: sdk.NewCoins( + sdk.NewCoin(suite.Network.GetBaseDenom(), sdkmath.NewInt((DefaultCoreMsgGasUsage*DefaultGasPrice)/2)), + ), + }, + { + name: "No refund as no gas was left over used", + leftoverGas: 0, + expectedRefund: sdk.NewCoins( + sdk.NewCoin(suite.Network.GetBaseDenom(), sdkmath.NewInt(0)), + ), + }, + { + name: "Refund with context fees, refunding the full value", + leftoverGas: DefaultCoreMsgGasUsage, + malleate: func(ctx sdk.Context) sdk.Context { + // Set the fee abstraction paid fee key with a single coin + return ctx.WithValue( + keeper.ContextPaidFeesKey{}, + sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(750_000_000)), + ), + ) + }, + expectedRefund: sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(750_000_000)), + ), + }, + { + name: "Refund with context fees, refunding the half the value", + leftoverGas: DefaultCoreMsgGasUsage / 2, + malleate: func(ctx sdk.Context) sdk.Context { + // Set the fee abstraction paid fee key with a single coin + return ctx.WithValue( + keeper.ContextPaidFeesKey{}, + sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(750_000_000)), + ), + ) + }, + expectedRefund: sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(750_000_000/2)), + ), + }, + { + name: "Refund with context fees, no refund", + leftoverGas: 0, + malleate: func(ctx sdk.Context) sdk.Context { + // Set the fee abstraction paid fee key with a single coin + return ctx.WithValue( + keeper.ContextPaidFeesKey{}, + sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(750_000_000)), + ), + ) + }, + expectedRefund: sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(0)), + ), + }, + { + name: "Error - More than one coin being passed", + leftoverGas: DefaultCoreMsgGasUsage, + malleate: func(ctx sdk.Context) sdk.Context { + // Set the fee abstraction paid fee key with a single coin + return ctx.WithValue( + keeper.ContextPaidFeesKey{}, + sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(750_000_000)), + sdk.NewCoin("atwo", sdkmath.NewInt(750_000_000)), + ), + ) + }, + expectedRefund: sdk.NewCoins( + sdk.NewCoin("acoin", sdkmath.NewInt(0)), // We say as zero to skip the mock bank check + ), + errContains: "expected a single coin for EVM refunds, got 2", + }, + } + + // Iterate though the test cases + for _, tc := range testCases { + suite.Run(tc.name, func() { + // Generate a cached context to not leak data between tests + ctx, _ := suite.Network.GetContext().CacheContext() + + // Create a new controller for the mock + ctrl := gomock.NewController(suite.T()) + defer ctrl.Finish() + + // Apply the malleate function to the context + if tc.malleate != nil { + ctx = tc.malleate(ctx) + } + + // Create a new mock bank keeper + mockBankKeeper := erc20mocks.NewMockBankKeeper(ctrl) + + // Apply the expect, but only if expected refund is not zero + if !tc.expectedRefund.IsZero() { + mockBankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx sdk.Context, senderModule string, recipient sdk.AccAddress, coins sdk.Coins) error { + if !coins.Equal(tc.expectedRefund) { + suite.T().Errorf("expected %s, got %s", tc.expectedRefund, coins) + } + + return nil + }) + } + + // Initialize a new EVM keeper with the mock bank keeper + // We need to redo this every time, since we will apply the mocked bank keeper at this step + evmKeeper := keeper.NewKeeper( + suite.Network.App.AppCodec(), + suite.Network.App.GetKey(types.StoreKey), + suite.Network.App.GetKey(types.StoreKey), + suite.Network.App.GetEVMKeeper().KVStoreKeys(), + authtypes.NewModuleAddress(govtypes.ModuleName), + suite.Network.App.GetAccountKeeper(), + mockBankKeeper, + suite.Network.App.GetStakingKeeper(), + suite.Network.App.GetFeeMarketKeeper(), + suite.Network.App.GetConsensusParamsKeeper(), + suite.Network.App.GetErc20Keeper(), + "", + ) + + // Call the msg, not further checks are needed, all balance checks are done in the mock + err := evmKeeper.RefundGas( + ctx, + *coreMsg, + tc.leftoverGas, + suite.Network.GetBaseDenom(), + ) + + // Check the error + if tc.errContains != "" { + suite.Require().ErrorContains(err, tc.errContains, "RefundGas should return an error") + } else { + suite.Require().NoError(err, "RefundGas should not return an error") + } + }) + } + +} diff --git a/x/vm/keeper/gas.go b/x/vm/keeper/gas.go index ab5dcf08..be7c565f 100644 --- a/x/vm/keeper/gas.go +++ b/x/vm/keeper/gas.go @@ -16,6 +16,9 @@ import ( authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" ) +// ContextPaidFeesKey is a key used to store the paid fee in the context +type ContextPaidFeesKey struct{} + // GetEthIntrinsicGas returns the intrinsic gas cost for the transaction func (k *Keeper) GetEthIntrinsicGas(ctx sdk.Context, msg core.Message, cfg *params.ChainConfig, isContractCreation bool, @@ -36,12 +39,45 @@ func (k *Keeper) RefundGas(ctx sdk.Context, msg core.Message, leftoverGas uint64 // Return EVM tokens for remaining gas, exchanged at the original rate. remaining := new(big.Int).Mul(new(big.Int).SetUint64(leftoverGas), msg.GasPrice) + // Check if gas is zero + if msg.GasLimit == 0 { + // If gas is zero, we cannot refund anything, so we return early + return nil + } + switch remaining.Sign() { case -1: // negative refund errors return errorsmod.Wrapf(types.ErrInvalidRefund, "refunded amount value cannot be negative %d", remaining.Int64()) case 1: - // positive amount refund + // Attempt to extract the paid coin from the context + // This is used when fee abstraction is applied into the fee payment + // If no value is found under the context, the original denom is used + if val := ctx.Value(ContextPaidFeesKey{}); val != nil { + // We check if a coin exists under the value and if it's not empty + if paidCoins, ok := val.(sdk.Coins); ok && !paidCoins.IsZero() { + // We know that only a single coin is used for EVM payments + if len(paidCoins) != 1 { + // This should never happen, but if it does, we return an error + return errorsmod.Wrapf(types.ErrInvalidRefund, "expected a single coin for EVM refunds, got %d", len(paidCoins)) + } + paidCoin := paidCoins[0] + + // Extract the coin information + denom = paidCoin.Denom + amount := paidCoin.Amount.BigInt() + + // Calculate the amount to refund + // This is calculated as: + // remaining = amount * leftoverGas / gasUsed + remaining = new(big.Int).Div( + new(big.Int).Mul(amount, new(big.Int).SetUint64(leftoverGas)), + new(big.Int).SetUint64(msg.GasLimit), + ) + } + } + + // Positive amount refund refundedCoins := sdk.Coins{sdk.NewCoin(denom, sdkmath.NewIntFromBigInt(remaining))} // refund to sender from the fee collector module account, which is the escrow account in charge of collecting tx fees @@ -51,7 +87,7 @@ func (k *Keeper) RefundGas(ctx sdk.Context, msg core.Message, leftoverGas uint64 return errorsmod.Wrapf(err, "failed to refund %d leftover gas (%s)", leftoverGas, refundedCoins.String()) } default: - // no refund, consume gas and update the tx gas meter + // No refund, consume gas and update the tx gas meter } return nil