diff --git a/cmd/serve.go b/cmd/serve.go index a32514b07..e8c674f31 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -94,6 +94,14 @@ func (c *serveCmd) Command() *cobra.Command { FlagDefault: 15, Required: true, }, + { + Name: "min-distribution-account-balance", + Usage: "Minimum XLM balance required for the distribution account in stroops (1 XLM = 10,000,000 stroops). Server will fail to start if balance is below this threshold. Set to 0 to only check account existence.", + OptType: types.Int, + ConfigKey: &cfg.MinDistributionAccountBalance, + FlagDefault: 100_000_000, // 10 XLM in stroops + Required: false, + }, } // Distribution Account Signature Client options diff --git a/internal/integrationtests/infrastructure/setup.go b/internal/integrationtests/infrastructure/setup.go index 12dec9d2c..d4d73e17e 100644 --- a/internal/integrationtests/infrastructure/setup.go +++ b/internal/integrationtests/infrastructure/setup.go @@ -1023,7 +1023,10 @@ func createRPCService(containers *SharedContainers, ctx context.Context) (servic } // Start tracking RPC health - go rpcService.TrackRPCServiceHealth(ctx, nil) + go func() { + //nolint:errcheck // Error is expected on context cancellation during shutdown + rpcService.TrackRPCServiceHealth(ctx, nil) + }() return rpcService, nil } diff --git a/internal/serve/serve.go b/internal/serve/serve.go index ea5f49654..9ae9b801d 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -2,6 +2,7 @@ package serve import ( "context" + "errors" "fmt" "net/http" "time" @@ -60,6 +61,9 @@ type Configs struct { // RPC RPCURL string + // Distribution Account Validation + MinDistributionAccountBalance int64 // Minimum balance in stroops. 0 to only check existence. + // GraphQL GraphQLComplexityLimit int @@ -136,7 +140,15 @@ func initHandlerDeps(ctx context.Context, cfg Configs) (handlerDeps, error) { if err != nil { return handlerDeps{}, fmt.Errorf("instantiating rpc service: %w", err) } - go rpcService.TrackRPCServiceHealth(ctx, nil) + + // Validate distribution account exists and has sufficient balance + distributionAccountPublicKey, err := cfg.DistributionAccountSignatureClient.GetAccountPublicKey(ctx) + if err != nil { + return handlerDeps{}, fmt.Errorf("getting distribution account public key: %w", err) + } + if err := validateDistributionAccount(rpcService, distributionAccountPublicKey, cfg.MinDistributionAccountBalance); err != nil { + return handlerDeps{}, fmt.Errorf("distribution account validation failed: %w", err) + } channelAccountStore := store.NewChannelAccountModel(dbConnectionPool) @@ -180,7 +192,13 @@ func initHandlerDeps(ctx context.Context, cfg Configs) (handlerDeps, error) { if err != nil { return handlerDeps{}, fmt.Errorf("instantiating channel account service: %w", err) } - go ensureChannelAccounts(ctx, channelAccountService, int64(cfg.NumberOfChannelAccounts)) + + // Ensure channel accounts exist synchronously - fail startup if validation fails + log.Ctx(ctx).Info("Ensuring the number of channel accounts...") + if err := channelAccountService.EnsureChannelAccounts(ctx, int64(cfg.NumberOfChannelAccounts)); err != nil { + return handlerDeps{}, fmt.Errorf("ensuring channel accounts: %w", err) + } + log.Ctx(ctx).Infof("✅ Ensured that %d channel accounts exist", cfg.NumberOfChannelAccounts) return handlerDeps{ Models: models, @@ -197,16 +215,6 @@ func initHandlerDeps(ctx context.Context, cfg Configs) (handlerDeps, error) { }, nil } -func ensureChannelAccounts(ctx context.Context, channelAccountService services.ChannelAccountService, numberOfChannelAccounts int64) { - log.Ctx(ctx).Info("Ensuring the number of channel accounts in the database...") - err := channelAccountService.EnsureChannelAccounts(ctx, numberOfChannelAccounts) - if err != nil { - log.Ctx(ctx).Errorf("error ensuring the number of channel accounts: %s", err.Error()) - return - } - log.Ctx(ctx).Infof("Ensured that exactly %d channel accounts exist in the database", numberOfChannelAccounts) -} - func handler(deps handlerDeps) http.Handler { mux := supporthttp.NewAPIMux(log.DefaultLogger) mux.NotFound(httperror.ErrorHandler{Error: httperror.NotFound}.ServeHTTP) @@ -270,6 +278,35 @@ func handler(deps handlerDeps) http.Handler { return mux } +// validateDistributionAccount checks that the distribution account exists on the network +// and has sufficient balance for operations. +func validateDistributionAccount(rpcService services.RPCService, distributionAccountPublicKey string, minBalance int64) error { + accountInfo, err := rpcService.GetAccountInfo(distributionAccountPublicKey) + if err != nil { + if errors.Is(err, services.ErrAccountNotFound) { + return fmt.Errorf("distribution account %s does not exist on the network", distributionAccountPublicKey) + } + return fmt.Errorf("validating distribution account: %w", err) + } + + if minBalance > 0 && accountInfo.Balance < minBalance { + return fmt.Errorf( + "distribution account %s has insufficient balance: %d stroops (minimum: %d stroops / %.2f XLM)", + distributionAccountPublicKey, + accountInfo.Balance, + minBalance, + float64(minBalance)/10_000_000, + ) + } + + log.Infof("✅ Distribution account %s validated: balance %d stroops (%.2f XLM)", + distributionAccountPublicKey, + accountInfo.Balance, + float64(accountInfo.Balance)/10_000_000, + ) + return nil +} + func addComplexityCalculation(config *generated.Config) { /* Complexity Calculation diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go new file mode 100644 index 000000000..0281d3f57 --- /dev/null +++ b/internal/serve/serve_test.go @@ -0,0 +1,90 @@ +// Tests for serve package initialization and validation functions. +package serve + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stellar/wallet-backend/internal/services" +) + +func TestValidateDistributionAccount(t *testing.T) { + testAccountAddress := "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5" + + t.Run("successful_with_sufficient_balance", func(t *testing.T) { + mockRPCService := services.NewRPCServiceMock(t) + mockRPCService.On("GetAccountInfo", testAccountAddress). + Return(services.AccountInfo{ + Balance: 100_000_000, // 10 XLM + SeqNum: 12345, + }, nil).Once() + + err := validateDistributionAccount(mockRPCService, testAccountAddress, 100_000_000) + require.NoError(t, err) + }) + + t.Run("successful_with_balance_above_threshold", func(t *testing.T) { + mockRPCService := services.NewRPCServiceMock(t) + mockRPCService.On("GetAccountInfo", testAccountAddress). + Return(services.AccountInfo{ + Balance: 500_000_000, // 50 XLM + SeqNum: 12345, + }, nil).Once() + + err := validateDistributionAccount(mockRPCService, testAccountAddress, 100_000_000) // 10 XLM threshold + require.NoError(t, err) + }) + + t.Run("successful_with_zero_threshold_existence_only", func(t *testing.T) { + mockRPCService := services.NewRPCServiceMock(t) + mockRPCService.On("GetAccountInfo", testAccountAddress). + Return(services.AccountInfo{ + Balance: 1_000_000, // 0.1 XLM (below typical threshold but should pass) + SeqNum: 12345, + }, nil).Once() + + err := validateDistributionAccount(mockRPCService, testAccountAddress, 0) // 0 means only check existence + require.NoError(t, err) + }) + + t.Run("account_not_found", func(t *testing.T) { + mockRPCService := services.NewRPCServiceMock(t) + mockRPCService.On("GetAccountInfo", testAccountAddress). + Return(services.AccountInfo{}, services.ErrAccountNotFound).Once() + + err := validateDistributionAccount(mockRPCService, testAccountAddress, 100_000_000) + require.Error(t, err) + assert.Contains(t, err.Error(), "does not exist on the network") + assert.Contains(t, err.Error(), testAccountAddress) + }) + + t.Run("insufficient_balance", func(t *testing.T) { + mockRPCService := services.NewRPCServiceMock(t) + mockRPCService.On("GetAccountInfo", testAccountAddress). + Return(services.AccountInfo{ + Balance: 50_000_000, // 5 XLM + SeqNum: 12345, + }, nil).Once() + + err := validateDistributionAccount(mockRPCService, testAccountAddress, 100_000_000) // 10 XLM threshold + require.Error(t, err) + assert.Contains(t, err.Error(), "insufficient balance") + assert.Contains(t, err.Error(), testAccountAddress) + assert.Contains(t, err.Error(), "50000000 stroops") + assert.Contains(t, err.Error(), "100000000 stroops") + }) + + t.Run("rpc_error", func(t *testing.T) { + mockRPCService := services.NewRPCServiceMock(t) + mockRPCService.On("GetAccountInfo", testAccountAddress). + Return(services.AccountInfo{}, errors.New("connection failed")).Once() + + err := validateDistributionAccount(mockRPCService, testAccountAddress, 100_000_000) + require.Error(t, err) + assert.Contains(t, err.Error(), "validating distribution account") + assert.Contains(t, err.Error(), "connection failed") + }) +} diff --git a/internal/services/channel_account_service.go b/internal/services/channel_account_service.go index 62bb351e9..4d57478de 100644 --- a/internal/services/channel_account_service.go +++ b/internal/services/channel_account_service.go @@ -140,6 +140,24 @@ func (s *channelAccountService) createChannelAccounts(ctx context.Context, amoun } log.Ctx(ctx).Infof("🎉 Successfully created %d channel account(s) on chain", amount) + // Validate that all accounts exist on-chain before inserting into database + log.Ctx(ctx).Infof("⏳ Validating %d channel account(s) exist on-chain...", amount) + var missingAccounts []string + for _, chAcc := range channelAccountsToInsert { + _, err = s.RPCService.GetAccountInfo(chAcc.PublicKey) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + missingAccounts = append(missingAccounts, chAcc.PublicKey) + continue + } + return fmt.Errorf("validating channel account %s on-chain: %w", chAcc.PublicKey, err) + } + } + if len(missingAccounts) > 0 { + return fmt.Errorf("channel account validation failed: %d of %d accounts missing from network after creation: %v", len(missingAccounts), amount, missingAccounts) + } + log.Ctx(ctx).Infof("✅ Validated %d channel account(s) exist on-chain", amount) + if err = s.ChannelAccountStore.BatchInsert(ctx, s.DB, channelAccountsToInsert); err != nil { return fmt.Errorf("inserting channel account(s): %w", err) } @@ -236,70 +254,91 @@ func (s *channelAccountService) submitChannelAccountsTxOnChain( ops []txnbuild.Operation, chAccSigner ChannelAccSigner, ) error { + // Wait for RPC service to become healthy by polling GetHealth directly. + // This lets the API server startup so that users can start interacting with the API + // which does not depend on RPC, instead of waiting till it becomes healthy. log.Ctx(ctx).Infof("⏳ Waiting for RPC service to become healthy") - rpcHeartbeatChannel := s.RPCService.GetHeartbeatChannel() - select { - case <-ctx.Done(): - return fmt.Errorf("context cancelled while waiting for rpc service to become healthy: %w", ctx.Err()) - - // The channel account creation goroutine will wait in the background for the rpc service to become healthy on startup. - // This lets the API server startup so that users can start interacting with the API which does not depend on RPC, instead of waiting till it becomes healthy. - case <-rpcHeartbeatChannel: - log.Ctx(ctx).Infof("👍 RPC service is healthy") - accountSeq, err := s.RPCService.GetAccountLedgerSequence(distributionAccountPublicKey) - if err != nil { - return fmt.Errorf("getting ledger sequence for distribution account public key=%s: %w", distributionAccountPublicKey, err) - } - tx, err := txnbuild.NewTransaction( - txnbuild.TransactionParams{ - SourceAccount: &txnbuild.SimpleAccount{ - AccountID: distributionAccountPublicKey, - Sequence: accountSeq, - }, - IncrementSequenceNum: true, - Operations: ops, - BaseFee: s.BaseFee, - Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(300)}, - }, - ) - if err != nil { - return fmt.Errorf("building transaction: %w", err) - } + healthCheckCtx, cancel := context.WithTimeout(ctx, rpcHealthCheckTimeout) + defer cancel() - // Sign the transaction for the distribution account - tx, err = s.DistributionAccountSignatureClient.SignStellarTransaction(ctx, tx, distributionAccountPublicKey) - if err != nil { - return fmt.Errorf("signing transaction for distribution account: %w", err) - } - // Sign the transaction for the channel accounts - tx, err = chAccSigner(ctx, tx) - if err != nil { - return fmt.Errorf("signing transaction with channel account(s) keypairs: %w", err) - } + ticker := time.NewTicker(sleepDelayForChannelAccountCreation) + defer ticker.Stop() - txHash, err := tx.HashHex(s.DistributionAccountSignatureClient.NetworkPassphrase()) - if err != nil { - return fmt.Errorf("getting transaction hash: %w", err) - } - txXDR, err := tx.Base64() - if err != nil { - return fmt.Errorf("getting transaction envelope: %w", err) + // Try immediately first + _, err := s.RPCService.GetHealth() + if err != nil { + log.Ctx(ctx).Debugf("Initial RPC health check failed: %v, will retry...", err) + for { + select { + case <-healthCheckCtx.Done(): + return fmt.Errorf("timeout waiting for RPC service to become healthy: %w", healthCheckCtx.Err()) + case <-ticker.C: + _, err = s.RPCService.GetHealth() + if err == nil { + break + } + log.Ctx(ctx).Debugf("RPC health check failed: %v, will retry...", err) + continue + } + break } + } - log.Ctx(ctx).Infof("🚧 Submitting channel account transaction to RPC service") - err = s.submitTransaction(ctx, txHash, txXDR) - if err != nil { - return fmt.Errorf("submitting channel account transaction to RPC service: %w", err) - } - log.Ctx(ctx).Infof("🚧 Successfully submitted channel account transaction to RPC service, waiting for confirmation...") - err = s.waitForTransactionConfirmation(ctx, txHash) - if err != nil { - return fmt.Errorf("getting transaction status: %w", err) - } + log.Ctx(ctx).Infof("👍 RPC service is healthy") + accountSeq, err := s.RPCService.GetAccountLedgerSequence(distributionAccountPublicKey) + if err != nil { + return fmt.Errorf("getting ledger sequence for distribution account public key=%s: %w", distributionAccountPublicKey, err) + } - return nil + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionAccountPublicKey, + Sequence: accountSeq, + }, + IncrementSequenceNum: true, + Operations: ops, + BaseFee: s.BaseFee, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(300)}, + }, + ) + if err != nil { + return fmt.Errorf("building transaction: %w", err) + } + + // Sign the transaction for the distribution account + tx, err = s.DistributionAccountSignatureClient.SignStellarTransaction(ctx, tx, distributionAccountPublicKey) + if err != nil { + return fmt.Errorf("signing transaction for distribution account: %w", err) + } + // Sign the transaction for the channel accounts + tx, err = chAccSigner(ctx, tx) + if err != nil { + return fmt.Errorf("signing transaction with channel account(s) keypairs: %w", err) } + + txHash, err := tx.HashHex(s.DistributionAccountSignatureClient.NetworkPassphrase()) + if err != nil { + return fmt.Errorf("getting transaction hash: %w", err) + } + txXDR, err := tx.Base64() + if err != nil { + return fmt.Errorf("getting transaction envelope: %w", err) + } + + log.Ctx(ctx).Infof("🚧 Submitting channel account transaction to RPC service") + err = s.submitTransaction(ctx, txHash, txXDR) + if err != nil { + return fmt.Errorf("submitting channel account transaction to RPC service: %w", err) + } + log.Ctx(ctx).Infof("🚧 Successfully submitted channel account transaction to RPC service, waiting for confirmation...") + err = s.waitForTransactionConfirmation(ctx, txHash) + if err != nil { + return fmt.Errorf("getting transaction status: %w", err) + } + + return nil } func (s *channelAccountService) submitTransaction(_ context.Context, hash string, signedTxXDR string) error { @@ -399,14 +438,12 @@ func (o *ChannelAccountServiceOptions) Validate() error { return nil } -func NewChannelAccountService(ctx context.Context, opts ChannelAccountServiceOptions) (*channelAccountService, error) { +func NewChannelAccountService(_ context.Context, opts ChannelAccountServiceOptions) (*channelAccountService, error) { err := opts.Validate() if err != nil { return nil, fmt.Errorf("validating channel account service options: %w", err) } - go opts.RPCService.TrackRPCServiceHealth(ctx, nil) - return &channelAccountService{ DB: opts.DB, RPCService: opts.RPCService, diff --git a/internal/services/channel_account_service_test.go b/internal/services/channel_account_service_test.go index 3acd07be6..728f09cff 100644 --- a/internal/services/channel_account_service_test.go +++ b/internal/services/channel_account_service_test.go @@ -2,8 +2,8 @@ package services import ( "context" + "fmt" "testing" - "time" "github.com/stellar/go/keypair" "github.com/stellar/go/network" @@ -107,11 +107,9 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { Return(network.TestNetworkPassphrase). Once() - heartbeatChan := make(chan entities.RPCGetHealthResult, 1) - heartbeatChan <- entities.RPCGetHealthResult{Status: "healthy"} mockRPCService. - On("GetHeartbeatChannel"). - Return(heartbeatChan). + On("GetHealth"). + Return(entities.RPCGetHealthResult{Status: "healthy"}, nil). Once(). On("GetAccountLedgerSequence", distributionAccount.Address()). Return(int64(123), nil). @@ -121,7 +119,11 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { Once(). On("GetTransaction", mock.AnythingOfType("string")). Return(entities.RPCGetTransactionResult{Status: entities.SuccessStatus}, nil). - Once() + Once(). + // Mock GetAccountInfo for on-chain validation after creation + On("GetAccountInfo", mock.AnythingOfType("string")). + Return(AccountInfo{Balance: 0, SeqNum: 0}, nil). + Times(3) // 3 channel accounts being created channelAccountStore. On("BatchInsert", mock.Anything, dbConnectionPool, mock.AnythingOfType("[]*store.ChannelAccount")). @@ -195,10 +197,8 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { Return(&signedTx, nil). Once() - heartbeatChan := make(chan entities.RPCGetHealthResult, 1) - heartbeatChan <- entities.RPCGetHealthResult{Status: "healthy"} mockRPCService. - On("GetHeartbeatChannel").Return(heartbeatChan).Once(). + On("GetHealth").Return(entities.RPCGetHealthResult{Status: "healthy"}, nil).Once(). On("GetAccountLedgerSequence", distributionAccount.Address()).Return(int64(123), nil).Once(). On("GetAccountLedgerSequence", chAcc1.PublicKey).Return(int64(123), nil).Once(). On("GetAccountLedgerSequence", chAcc2.PublicKey).Return(int64(123), nil).Once(). @@ -256,11 +256,9 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { Return(network.TestNetworkPassphrase). Once() - heartbeatChan := make(chan entities.RPCGetHealthResult, 1) - heartbeatChan <- entities.RPCGetHealthResult{Status: "healthy"} mockRPCService. - On("GetHeartbeatChannel"). - Return(heartbeatChan). + On("GetHealth"). + Return(entities.RPCGetHealthResult{Status: "healthy"}, nil). Once(). On("GetAccountLedgerSequence", distributionAccount.Address()). Return(int64(123), nil). @@ -312,11 +310,9 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { Return(network.TestNetworkPassphrase). Once() - heartbeatChan := make(chan entities.RPCGetHealthResult, 1) - heartbeatChan <- entities.RPCGetHealthResult{Status: "healthy"} mockRPCService. - On("GetHeartbeatChannel"). - Return(heartbeatChan). + On("GetHealth"). + Return(entities.RPCGetHealthResult{Status: "healthy"}, nil). Once(). On("GetAccountLedgerSequence", distributionAccount.Address()). Return(int64(123), nil). @@ -349,10 +345,12 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { Return(distributionAccount.Address(), nil). Once() - heartbeatChan := make(chan entities.RPCGetHealthResult, 1) - mockRPCService.On("GetHeartbeatChannel").Return(heartbeatChan) + mockRPCService. + On("GetHealth"). + Return(entities.RPCGetHealthResult{}, fmt.Errorf("RPC unavailable")). + Maybe() }, - expectedError: "context cancelled while waiting for rpc service to become healthy", + expectedError: "timeout waiting for RPC service to become healthy", }, } @@ -361,7 +359,6 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { // Create fresh mocks for each test ctx := tc.getCtx() mockRPCService := NewRPCServiceMock(t) - mockRPCService.On("TrackRPCServiceHealth", ctx, mock.Anything).Return() distAccSigClient := signing.NewSignatureClientMock(t) chAccSigClient := signing.NewSignatureClientMock(t) channelAccountStore := store.NewChannelAccountStoreMock(t) @@ -380,7 +377,6 @@ func Test_ChannelAccountService_EnsureChannelAccounts(t *testing.T) { PrivateKeyEncrypter: &signingutils.DefaultPrivateKeyEncrypter{}, EncryptionPassphrase: "my-encryption-passphrase", }) - time.Sleep(50 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` require.NoError(t, err) // Execute test @@ -407,7 +403,6 @@ func TestSubmitTransaction(t *testing.T) { ctx := context.Background() mockRPCService := RPCServiceMock{} - mockRPCService.On("TrackRPCServiceHealth", ctx, mock.Anything).Return() defer mockRPCService.AssertExpectations(t) signatureClient := signing.SignatureClientMock{} channelAccountStore := store.ChannelAccountStoreMock{} @@ -422,9 +417,7 @@ func TestSubmitTransaction(t *testing.T) { PrivateKeyEncrypter: &privateKeyEncrypter, EncryptionPassphrase: passphrase, }) - time.Sleep(100 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` require.NoError(t, err) - time.Sleep(100 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` hash := "test_hash" signedTxXDR := "test_xdr" @@ -465,7 +458,6 @@ func TestWaitForTransactionConfirmation(t *testing.T) { ctx := context.Background() mockRPCService := RPCServiceMock{} defer mockRPCService.AssertExpectations(t) - mockRPCService.On("TrackRPCServiceHealth", ctx, mock.Anything).Return() signatureClient := signing.SignatureClientMock{} channelAccountStore := store.ChannelAccountStoreMock{} privateKeyEncrypter := signingutils.DefaultPrivateKeyEncrypter{} @@ -480,7 +472,6 @@ func TestWaitForTransactionConfirmation(t *testing.T) { EncryptionPassphrase: passphrase, }) require.NoError(t, err) - time.Sleep(100 * time.Millisecond) // waiting for the goroutine to call `TrackRPCServiceHealth` hash := "test_hash" diff --git a/internal/services/ingest.go b/internal/services/ingest.go index b0c39f057..0cc09c135 100644 --- a/internal/services/ingest.go +++ b/internal/services/ingest.go @@ -140,7 +140,11 @@ func NewIngestService( func (m *ingestService) DeprecatedRun(ctx context.Context, startLedger uint32, endLedger uint32) error { manualTriggerChannel := make(chan any, 1) - go m.rpcService.TrackRPCServiceHealth(ctx, manualTriggerChannel) + go func() { + if err := m.rpcService.TrackRPCServiceHealth(ctx, manualTriggerChannel); err != nil { + log.Ctx(ctx).Warnf("RPC health tracking stopped: %v", err) + } + }() ingestHeartbeatChannel := make(chan any, 1) rpcHeartbeatChannel := m.rpcService.GetHeartbeatChannel() go trackIngestServiceHealth(ctx, ingestHeartbeatChannel, m.appTracker) @@ -269,7 +273,11 @@ func (m *ingestService) Run(ctx context.Context, startLedger uint32, endLedger u // Prepare the health check: manualTriggerChan := make(chan any, 1) - go m.rpcService.TrackRPCServiceHealth(ctx, manualTriggerChan) + go func() { + if err := m.rpcService.TrackRPCServiceHealth(ctx, manualTriggerChan); err != nil { + log.Ctx(ctx).Warnf("RPC health tracking stopped: %v", err) + } + }() ingestHeartbeatChannel := make(chan any, 1) rpcHeartbeatChannel := m.rpcService.GetHeartbeatChannel() go trackIngestServiceHealth(ctx, ingestHeartbeatChannel, m.appTracker) @@ -277,12 +285,16 @@ func (m *ingestService) Run(ctx context.Context, startLedger uint32, endLedger u log.Ctx(ctx).Info("Starting ingestion loop") for { + var ok bool select { case sig := <-signalChan: return fmt.Errorf("ingestor stopped due to signal %q", sig) case <-ctx.Done(): return fmt.Errorf("ingestor stopped due to context cancellation: %w", ctx.Err()) - case rpcHealth = <-rpcHeartbeatChannel: + case rpcHealth, ok = <-rpcHeartbeatChannel: + if !ok { + return fmt.Errorf("RPC heartbeat channel closed unexpectedly") + } ingestHeartbeatChannel <- true // ⬅️ indicate that it's still running // this will fallthrough to execute the code below ⬇️ } diff --git a/internal/services/ingest_test.go b/internal/services/ingest_test.go index 0bb2ff4bf..cd8caf7b1 100644 --- a/internal/services/ingest_test.go +++ b/internal/services/ingest_test.go @@ -235,7 +235,7 @@ func TestIngest_LatestSyncedLedgerBehindRPC(t *testing.T) { mockAppTracker := apptracker.MockAppTracker{} mockRPCService := RPCServiceMock{} mockRPCService. - On("TrackRPCServiceHealth", ctx, mock.Anything).Once(). + On("TrackRPCServiceHealth", ctx, mock.Anything).Return(nil).Once(). On("NetworkPassphrase").Return(network.TestNetworkPassphrase) mockChAccStore := &store.ChannelAccountStoreMock{} mockContractStore := &cache.MockTokenContractStore{} @@ -328,7 +328,7 @@ func TestIngest_LatestSyncedLedgerAheadOfRPC(t *testing.T) { mockAppTracker := apptracker.MockAppTracker{} mockRPCService := RPCServiceMock{} mockRPCService. - On("TrackRPCServiceHealth", ctx, mock.Anything).Once(). + On("TrackRPCServiceHealth", ctx, mock.Anything).Return(nil).Once(). On("NetworkPassphrase").Return(network.TestNetworkPassphrase) mockChAccStore := &store.ChannelAccountStoreMock{} mockChAccStore.On("UnassignTxAndUnlockChannelAccounts", mock.Anything, mock.Anything, testInnerTxHash).Return(int64(1), nil).Twice() diff --git a/internal/services/mocks.go b/internal/services/mocks.go index 611c84c8f..bc6331cce 100644 --- a/internal/services/mocks.go +++ b/internal/services/mocks.go @@ -14,8 +14,9 @@ type RPCServiceMock struct { var _ RPCService = (*RPCServiceMock)(nil) -func (r *RPCServiceMock) TrackRPCServiceHealth(ctx context.Context, triggerHeartbeat chan any) { - r.Called(ctx, triggerHeartbeat) +func (r *RPCServiceMock) TrackRPCServiceHealth(ctx context.Context, triggerHeartbeat <-chan any) error { + args := r.Called(ctx, triggerHeartbeat) + return args.Error(0) } func (r *RPCServiceMock) GetHeartbeatChannel() chan entities.RPCGetHealthResult { @@ -58,6 +59,11 @@ func (r *RPCServiceMock) GetAccountLedgerSequence(address string) (int64, error) return args.Get(0).(int64), args.Error(1) } +func (r *RPCServiceMock) GetAccountInfo(address string) (AccountInfo, error) { + args := r.Called(address) + return args.Get(0).(AccountInfo), args.Error(1) +} + func (r *RPCServiceMock) SimulateTransaction(transactionXDR string, resourceConfig entities.RPCResourceConfig) (entities.RPCSimulateTransactionResult, error) { args := r.Called(transactionXDR, resourceConfig) return args.Get(0).(entities.RPCSimulateTransactionResult), args.Error(1) diff --git a/internal/services/rpc_service.go b/internal/services/rpc_service.go index ff9417f18..0d44e5cec 100644 --- a/internal/services/rpc_service.go +++ b/internal/services/rpc_service.go @@ -8,9 +8,6 @@ import ( "fmt" "io" "net/http" - "os" - "os/signal" - "syscall" "time" "github.com/stellar/go/support/log" @@ -28,6 +25,12 @@ const ( getHealthMethodName = "getHealth" ) +// AccountInfo contains the balance and sequence number for a Stellar account. +type AccountInfo struct { + Balance int64 // Balance in stroops + SeqNum int64 +} + type RPCService interface { GetTransaction(transactionHash string) (entities.RPCGetTransactionResult, error) GetTransactions(startLedger int64, startCursor string, limit int) (entities.RPCGetTransactionsResult, error) @@ -36,13 +39,16 @@ type RPCService interface { GetLedgers(startLedger uint32, limit uint32) (GetLedgersResponse, error) GetLedgerEntries(keys []string) (entities.RPCGetLedgerEntriesResult, error) GetAccountLedgerSequence(address string) (int64, error) + GetAccountInfo(address string) (AccountInfo, error) GetHeartbeatChannel() chan entities.RPCGetHealthResult // TrackRPCServiceHealth continuously monitors the health of the RPC service and updates metrics. // It runs health checks at regular intervals and can be triggered on-demand via immediateHealthCheckTrigger. // // The immediateHealthCheckTrigger channel allows external components to request an immediate health check, // which is particularly useful when the ingestor needs to catch up with the RPC service. - TrackRPCServiceHealth(ctx context.Context, immediateHealthCheckTrigger chan any) + // + // Returns an error if the context is cancelled. The caller is responsible for handling shutdown signals. + TrackRPCServiceHealth(ctx context.Context, immediateHealthCheckTrigger <-chan any) error SimulateTransaction(transactionXDR string, resourceConfig entities.RPCResourceConfig) (entities.RPCSimulateTransactionResult, error) NetworkPassphrase() string } @@ -310,6 +316,42 @@ func (r *rpcService) GetAccountLedgerSequence(address string) (int64, error) { return int64(accountEntry.SeqNum), nil } +func (r *rpcService) GetAccountInfo(address string) (AccountInfo, error) { + startTime := time.Now() + r.metricsService.IncRPCMethodCalls("GetAccountInfo") + defer func() { + duration := time.Since(startTime).Seconds() + r.metricsService.ObserveRPCMethodDuration("GetAccountInfo", duration) + }() + + keyXdr, err := utils.GetAccountLedgerKey(address) + if err != nil { + r.metricsService.IncRPCMethodErrors("GetAccountInfo", "validation_error") + return AccountInfo{}, fmt.Errorf("getting ledger key for account public key: %w", err) + } + result, err := r.GetLedgerEntries([]string{keyXdr}) + if err != nil { + r.metricsService.IncRPCMethodErrors("GetAccountInfo", "rpc_error") + return AccountInfo{}, fmt.Errorf("getting ledger entry for account public key: %w", err) + } + if len(result.Entries) == 0 { + r.metricsService.IncRPCMethodErrors("GetAccountInfo", "not_found_error") + return AccountInfo{}, fmt.Errorf("%w: entry not found for account public key", ErrAccountNotFound) + } + + var ledgerEntryData xdr.LedgerEntryData + err = xdr.SafeUnmarshalBase64(result.Entries[0].DataXDR, &ledgerEntryData) + if err != nil { + r.metricsService.IncRPCMethodErrors("GetAccountInfo", "xdr_decode_error") + return AccountInfo{}, fmt.Errorf("decoding account entry for account public key: %w", err) + } + accountEntry := ledgerEntryData.MustAccount() + return AccountInfo{ + Balance: int64(accountEntry.Balance), + SeqNum: int64(accountEntry.SeqNum), + }, nil +} + func (r *rpcService) NetworkPassphrase() string { return r.networkPassphrase } @@ -333,14 +375,17 @@ func (r *rpcService) HealthCheckTickInterval() time.Duration { // // The immediateHealthCheckTrigger channel allows external components to request an immediate health check, // which is particularly useful when the ingestor needs to catch up with the RPC service. -func (r *rpcService) TrackRPCServiceHealth(ctx context.Context, immediateHealthCheckTrigger chan any) { - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) +// +// Returns an error if the context is cancelled. The caller is responsible for handling shutdown signals. +func (r *rpcService) TrackRPCServiceHealth(ctx context.Context, immediateHealthCheckTrigger <-chan any) error { + // Handle nil channel by creating a never-firing channel + if immediateHealthCheckTrigger == nil { + immediateHealthCheckTrigger = make(chan any) + } healthCheckTicker := time.NewTicker(r.HealthCheckTickInterval()) unhealthyWarningTicker := time.NewTicker(r.HealthCheckWarningInterval()) defer func() { - signal.Stop(signalChan) healthCheckTicker.Stop() unhealthyWarningTicker.Stop() close(r.heartbeatChannel) @@ -368,15 +413,14 @@ func (r *rpcService) TrackRPCServiceHealth(ctx context.Context, immediateHealthC r.metricsService.SetRPCLatestLedger(int64(health.LatestLedger)) } + // Perform immediate health check at startup to avoid 5-second delay + performHealthCheck() + for { select { case <-ctx.Done(): log.Ctx(ctx).Infof("RPC health tracking stopped due to context cancellation: %v", ctx.Err()) - return - - case sig := <-signalChan: - log.Ctx(ctx).Warnf("RPC health tracking stopped due to signal %s", sig) - return + return fmt.Errorf("context cancelled: %w", ctx.Err()) case <-unhealthyWarningTicker.C: log.Ctx(ctx).Warnf("RPC service unhealthy for over %s", r.HealthCheckWarningInterval()) diff --git a/internal/services/rpc_service_test.go b/internal/services/rpc_service_test.go index 01f079fcc..245ace157 100644 --- a/internal/services/rpc_service_test.go +++ b/internal/services/rpc_service_test.go @@ -926,7 +926,8 @@ func TestTrackRPCServiceHealth_HealthyService(t *testing.T) { mockHTTPClient.On("Post", rpcURL, "application/json", mock.Anything).Return(mockResponse, nil).Run(func(args mock.Arguments) { cancel() }) - rpcService.TrackRPCServiceHealth(ctx, nil) + err = rpcService.TrackRPCServiceHealth(ctx, nil) + require.Error(t, err) // Get result from heartbeat channel select { @@ -998,7 +999,8 @@ func TestTrackRPCServiceHealth_UnhealthyService(t *testing.T) { Return(mockResponse, nil) // The ctx will timeout after {contextTimeout}, which is enough for the warning to trigger - rpcService.TrackRPCServiceHealth(ctx, nil) + err = rpcService.TrackRPCServiceHealth(ctx, nil) + require.Error(t, err) entries := getLogs() testSucceeded := false @@ -1014,8 +1016,9 @@ func TestTrackRPCServiceHealth_UnhealthyService(t *testing.T) { } func TestTrackRPCService_ContextCancelled(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + // Create and immediately cancel context to test cancellation handling + ctx, cancel := context.WithCancel(context.Background()) + cancel() dbt := dbtest.Open(t) defer dbt.Close() @@ -1029,14 +1032,26 @@ func TestTrackRPCService_ContextCancelled(t *testing.T) { rpcService, err := NewRPCService(rpcURL, network.TestNetworkPassphrase, mockHTTPClient, mockMetricsService) require.NoError(t, err) - rpcService.TrackRPCServiceHealth(ctx, nil) + // Mock metrics for the initial health check that happens before context check + mockMetricsService.On("IncRPCMethodCalls", "GetHealth").Maybe() + mockMetricsService.On("ObserveRPCMethodDuration", "GetHealth", mock.AnythingOfType("float64")).Maybe() + mockMetricsService.On("IncRPCRequests", "getHealth").Maybe() + mockMetricsService.On("IncRPCEndpointFailure", "getHealth").Maybe() + mockMetricsService.On("IncRPCMethodErrors", "GetHealth", "rpc_error").Maybe() + mockMetricsService.On("ObserveRPCRequestDuration", "getHealth", mock.AnythingOfType("float64")).Maybe() + mockMetricsService.On("SetRPCServiceHealth", false).Maybe() + + // Mock HTTP client to return error (simulating cancelled context) + mockHTTPClient.On("Post", rpcURL, "application/json", mock.Anything). + Return(&http.Response{}, context.Canceled).Maybe() + + err = rpcService.TrackRPCServiceHealth(ctx, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "context") // Verify channel is closed after context cancellation - time.Sleep(100 * time.Millisecond) _, ok := <-rpcService.GetHeartbeatChannel() assert.False(t, ok, "channel should be closed") - - mockHTTPClient.AssertNotCalled(t, "Post") } func TestTrackRPCService_DeadlockPrevention(t *testing.T) { @@ -1086,7 +1101,10 @@ func TestTrackRPCService_DeadlockPrevention(t *testing.T) { defer cancel() manualTriggerChan := make(chan any, 1) - go rpcService.TrackRPCServiceHealth(ctx, manualTriggerChan) + go func() { + //nolint:errcheck // Error is expected on context cancellation + rpcService.TrackRPCServiceHealth(ctx, manualTriggerChan) + }() time.Sleep(20 * time.Millisecond) manualTriggerChan <- nil @@ -1097,3 +1115,134 @@ func TestTrackRPCService_DeadlockPrevention(t *testing.T) { t.Log("🎉 Deadlock prevented!") } } + +func TestGetAccountInfo(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + mockMetricsService := metrics.NewMockMetricsService() + mockHTTPClient := utils.MockHTTPClient{} + rpcURL := "http://api.vibrantapp.com/soroban/rpc" + rpcService, err := NewRPCService(rpcURL, network.TestNetworkPassphrase, &mockHTTPClient, mockMetricsService) + require.NoError(t, err) + + testAccountAddress := "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5" + + t.Run("successful", func(t *testing.T) { + mockMetricsService.On("IncRPCMethodCalls", "GetAccountInfo").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetAccountInfo", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCMethodCalls", "GetLedgerEntries").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetLedgerEntries", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCRequests", "getLedgerEntries").Once() + mockMetricsService.On("IncRPCEndpointSuccess", "getLedgerEntries").Once() + mockMetricsService.On("ObserveRPCRequestDuration", "getLedgerEntries", mock.AnythingOfType("float64")).Once() + defer mockMetricsService.AssertExpectations(t) + + // Create account entry XDR with balance=10000000000 (1000 XLM) and seqNum=12345 + accountEntry := xdr.LedgerEntryData{ + Type: xdr.LedgerEntryTypeAccount, + Account: &xdr.AccountEntry{ + AccountId: xdr.MustAddress(testAccountAddress), + Balance: 10000000000, // 1000 XLM + SeqNum: 12345, + }, + } + accountEntryXDR, err := xdr.MarshalBase64(accountEntry) + require.NoError(t, err) + + httpResponse := http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "entries": [{ + "xdr": "%s" + }] + } + }`, accountEntryXDR))), + } + + mockHTTPClient. + On("Post", rpcURL, "application/json", mock.Anything). + Return(&httpResponse, nil). + Once() + + result, err := rpcService.GetAccountInfo(testAccountAddress) + require.NoError(t, err) + assert.Equal(t, int64(10000000000), result.Balance) + assert.Equal(t, int64(12345), result.SeqNum) + }) + + t.Run("account_not_found", func(t *testing.T) { + mockMetricsService.On("IncRPCMethodCalls", "GetAccountInfo").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetAccountInfo", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCMethodCalls", "GetLedgerEntries").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetLedgerEntries", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCRequests", "getLedgerEntries").Once() + mockMetricsService.On("IncRPCEndpointSuccess", "getLedgerEntries").Once() + mockMetricsService.On("ObserveRPCRequestDuration", "getLedgerEntries", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCMethodErrors", "GetAccountInfo", "not_found_error").Once() + defer mockMetricsService.AssertExpectations(t) + + httpResponse := http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "entries": [] + } + }`)), + } + + mockHTTPClient. + On("Post", rpcURL, "application/json", mock.Anything). + Return(&httpResponse, nil). + Once() + + result, err := rpcService.GetAccountInfo(testAccountAddress) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrAccountNotFound)) + assert.Contains(t, err.Error(), "entry not found for account public key") + assert.Equal(t, AccountInfo{}, result) + }) + + t.Run("invalid_address", func(t *testing.T) { + mockMetricsService.On("IncRPCMethodCalls", "GetAccountInfo").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetAccountInfo", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCMethodErrors", "GetAccountInfo", "validation_error").Once() + defer mockMetricsService.AssertExpectations(t) + + result, err := rpcService.GetAccountInfo("invalid-address") + require.Error(t, err) + assert.Contains(t, err.Error(), "getting ledger key for account public key") + assert.Equal(t, AccountInfo{}, result) + }) + + t.Run("rpc_request_fails", func(t *testing.T) { + mockMetricsService.On("IncRPCMethodCalls", "GetAccountInfo").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetAccountInfo", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCMethodCalls", "GetLedgerEntries").Once() + mockMetricsService.On("ObserveRPCMethodDuration", "GetLedgerEntries", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCRequests", "getLedgerEntries").Once() + mockMetricsService.On("IncRPCEndpointFailure", "getLedgerEntries").Once() + mockMetricsService.On("ObserveRPCRequestDuration", "getLedgerEntries", mock.AnythingOfType("float64")).Once() + mockMetricsService.On("IncRPCMethodErrors", "GetLedgerEntries", "rpc_error").Once() + mockMetricsService.On("IncRPCMethodErrors", "GetAccountInfo", "rpc_error").Once() + defer mockMetricsService.AssertExpectations(t) + + mockHTTPClient. + On("Post", rpcURL, "application/json", mock.Anything). + Return(&http.Response{}, errors.New("connection failed")). + Once() + + result, err := rpcService.GetAccountInfo(testAccountAddress) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting ledger entry for account public key") + assert.Equal(t, AccountInfo{}, result) + }) +}