diff --git a/internal/pkg/logger/file.go b/internal/pkg/logger/file.go index 06f3bb7..6f0a16e 100644 --- a/internal/pkg/logger/file.go +++ b/internal/pkg/logger/file.go @@ -93,6 +93,11 @@ func (l *FileTransactionLogger) ReadEvents() (<-chan Event, <-chan error) { return outEvent, outErr } +func (l *FileTransactionLogger) Close() error { + close(l.events) + return l.file.Close() +} + type Event struct { Sequence uint64 Kind EventKind diff --git a/internal/pkg/logger/file_test.go b/internal/pkg/logger/file_test.go index 42a3e20..799cbec 100644 --- a/internal/pkg/logger/file_test.go +++ b/internal/pkg/logger/file_test.go @@ -6,7 +6,10 @@ import ( "strings" "sync" "testing" + "testing/synctest" "time" + + "github.com/stretchr/testify/assert" ) // mockReadWriteCloser is a mock implementation of an io.ReadWriterCloser for testing @@ -98,35 +101,41 @@ func TestFileTransactionLogger_Run(t *testing.T) { // TestFileTransactionLogger_WritePut tests writing PUT events func TestFileTransactionLogger_WritePut(t *testing.T) { - mock := newMockReadWriteCloser("") - logger := NewFileTransactionLogger(mock) - logger.Run() + synctest.Test(t, func(t *testing.T) { + mock := newMockReadWriteCloser("") + logger := NewFileTransactionLogger(mock) + logger.Run() - // Give goroutine time to start - time.Sleep(10 * time.Millisecond) + // Give goroutine time to start + time.Sleep(10 * time.Millisecond) - logger.WritePut("key1", "value1") - logger.WritePut("key2", "value2") + logger.WritePut("key1", "value1") + logger.WritePut("key2", "value2") - // Give time for writes to complete - time.Sleep(50 * time.Millisecond) + // Give time for writes to complete + time.Sleep(50 * time.Millisecond) - output := mock.String() - expectedLines := []string{ - "1\t2\tkey1\tvalue1", - "2\t2\tkey2\tvalue2", - } + output := mock.String() + expectedLines := []string{ + "1\t2\tkey1\tvalue1", + "2\t2\tkey2\tvalue2", + } - for _, expected := range expectedLines { - if !strings.Contains(output, expected) { - t.Errorf("Expected output to contain %q, got: %s", expected, output) + for _, expected := range expectedLines { + if !strings.Contains(output, expected) { + t.Errorf("Expected output to contain %q, got: %s", expected, output) + } } - } - last := logger.lastSequence.Load() - if last != 2 { - t.Errorf("Expected lastSequence to be 2, got %d", last) - } + last := logger.lastSequence.Load() + if last != 2 { + t.Errorf("Expected lastSequence to be 2, got %d", last) + } + + err := logger.Close() + assert.Nil(t, err) + synctest.Wait() + }) } // TestFileTransactionLogger_WriteDelete tests writing DELETE events diff --git a/internal/pkg/logger/postgres.go b/internal/pkg/logger/postgres.go index 1ed12ff..6ca45dd 100644 --- a/internal/pkg/logger/postgres.go +++ b/internal/pkg/logger/postgres.go @@ -164,3 +164,8 @@ func (p *PostgresTransactionLogger) createTable() error { return nil } + +func (p *PostgresTransactionLogger) Close() error { + close(p.events) + return nil +} diff --git a/internal/pkg/logger/transaction.go b/internal/pkg/logger/transaction.go index 7080a25..2203b15 100644 --- a/internal/pkg/logger/transaction.go +++ b/internal/pkg/logger/transaction.go @@ -7,4 +7,5 @@ type TransactionLog interface { Run() ReadEvents() (<-chan Event, <-chan error) + Close() error }