diff --git a/cmd/api/app.go b/cmd/api/app.go index 1d93bd78..f2e1183f 100644 --- a/cmd/api/app.go +++ b/cmd/api/app.go @@ -70,6 +70,11 @@ func BuildApplication(cfg appconf.Config, gtfsCfg gtfs.Config) (*app.Application var directionCalculator *gtfs.AdvancedDirectionCalculator if gtfsManager != nil { directionCalculator = gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries) + + err = gtfs.InitializeGlobalCache(context.Background(), gtfsManager.GtfsDB.Queries, directionCalculator) + if err != nil { + return nil, fmt.Errorf("failed to initialize global cache: %w", err) + } } // Select clock implementation based on environment diff --git a/gtfsdb/db.go b/gtfsdb/db.go index 0cb8a80c..696dc9f8 100644 --- a/gtfsdb/db.go +++ b/gtfsdb/db.go @@ -117,6 +117,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getAllShapesStmt, err = db.PrepareContext(ctx, getAllShapes); err != nil { return nil, fmt.Errorf("error preparing query GetAllShapes: %w", err) } + if q.getAllStopIDsStmt, err = db.PrepareContext(ctx, getAllStopIDs); err != nil { + return nil, fmt.Errorf("error preparing query GetAllStopIDs: %w", err) + } if q.getAllTripsForRouteStmt, err = db.PrepareContext(ctx, getAllTripsForRoute); err != nil { return nil, fmt.Errorf("error preparing query GetAllTripsForRoute: %w", err) } @@ -460,6 +463,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getAllShapesStmt: %w", cerr) } } + if q.getAllStopIDsStmt != nil { + if cerr := q.getAllStopIDsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getAllStopIDsStmt: %w", cerr) + } + } if q.getAllTripsForRouteStmt != nil { if cerr := q.getAllTripsForRouteStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getAllTripsForRouteStmt: %w", cerr) @@ -835,6 +843,7 @@ type Queries struct { getAgencyStmt *sql.Stmt getAgencyForStopStmt *sql.Stmt getAllShapesStmt *sql.Stmt + getAllStopIDsStmt *sql.Stmt getAllTripsForRouteStmt *sql.Stmt getArrivalsAndDeparturesForStopStmt *sql.Stmt getBlockDetailsStmt *sql.Stmt @@ -933,6 +942,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { getAgencyStmt: q.getAgencyStmt, getAgencyForStopStmt: q.getAgencyForStopStmt, getAllShapesStmt: q.getAllShapesStmt, + getAllStopIDsStmt: q.getAllStopIDsStmt, getAllTripsForRouteStmt: q.getAllTripsForRouteStmt, getArrivalsAndDeparturesForStopStmt: q.getArrivalsAndDeparturesForStopStmt, getBlockDetailsStmt: q.getBlockDetailsStmt, diff --git a/gtfsdb/query.sql b/gtfsdb/query.sql index 8b846634..1797ee35 100644 --- a/gtfsdb/query.sql +++ b/gtfsdb/query.sql @@ -194,6 +194,12 @@ ORDER BY LIMIT 1; +-- name: GetAllStopIDs :many +SELECT + id +FROM + stops; + -- name: GetStopIDsForAgency :many SELECT DISTINCT s.id diff --git a/gtfsdb/query.sql.go b/gtfsdb/query.sql.go index 12903d19..ceaf0251 100644 --- a/gtfsdb/query.sql.go +++ b/gtfsdb/query.sql.go @@ -1136,6 +1136,36 @@ func (q *Queries) GetAllShapes(ctx context.Context) ([]Shape, error) { return items, nil } +const getAllStopIDs = `-- name: GetAllStopIDs :many +SELECT + id +FROM + stops +` + +func (q *Queries) GetAllStopIDs(ctx context.Context) ([]string, error) { + rows, err := q.query(ctx, q.getAllStopIDsStmt, getAllStopIDs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAllTripsForRoute = `-- name: GetAllTripsForRoute :many SELECT DISTINCT id, route_id, service_id, trip_headsign, trip_short_name, direction_id, block_id, shape_id, wheelchair_accessible, bikes_allowed FROM trips t diff --git a/internal/gtfs/advanced_direction_calculator_test.go b/internal/gtfs/advanced_direction_calculator_test.go index 6c2ef17e..253e60ff 100644 --- a/internal/gtfs/advanced_direction_calculator_test.go +++ b/internal/gtfs/advanced_direction_calculator_test.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "math" + "os" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -11,6 +13,44 @@ import ( "maglev.onebusaway.org/internal/models" ) +// This uses a Singleton pattern to load the DB and Warm the Cache exactly ONCE +// for this test file. This prevents re-loading the ZIP file 15+ times. + +var ( + sharedManager *Manager + sharedCalc *AdvancedDirectionCalculator + setupOnce sync.Once +) + +// Helper function to get the shared instances. +func getSharedTestComponents(t *testing.T) (*Manager, *AdvancedDirectionCalculator) { + setupOnce.Do(func() { + // Initialize the DB (In-Memory) + gtfsConfig := Config{ + GtfsURL: models.GetFixturePath(t, "raba.zip"), + GTFSDataPath: ":memory:", + } + + var err error + sharedManager, err = InitGTFSManager(gtfsConfig) + if err != nil { + panic("Failed to init shared GTFS manager: " + err.Error()) + } + + // Create the Calculator + sharedCalc = NewAdvancedDirectionCalculator(sharedManager.GtfsDB.Queries) + + // Warm the Global Cache (The heavy operation) + // We do this only once per test suite execution. + err = InitializeGlobalCache(context.Background(), sharedManager.GtfsDB.Queries, sharedCalc) + if err != nil { + panic("Failed to warm global cache: " + err.Error()) + } + }) + + return sharedManager, sharedCalc +} + func TestTranslateGtfsDirection(t *testing.T) { calc := &AdvancedDirectionCalculator{} @@ -41,7 +81,6 @@ func TestTranslateGtfsDirection(t *testing.T) { {"225 degrees", "225", "SW"}, {"270 degrees", "270", "W"}, {"315 degrees", "315", "NW"}, - // Invalid {"invalid text", "invalid", ""}, {"empty string", "", ""}, @@ -139,7 +178,6 @@ func TestStatisticalFunctions(t *testing.T) { m := mean(values) v := variance(values, m) assert.InDelta(t, 2.5, v, 0.001) // Sample variance of 1,2,3,4,5 is 2.5 - assert.Equal(t, 0.0, variance([]float64{5}, 5.0)) }) @@ -159,74 +197,41 @@ func TestStatisticalFunctions(t *testing.T) { func TestVarianceThreshold(t *testing.T) { calc := NewAdvancedDirectionCalculator(nil) - // Test default threshold assert.Equal(t, defaultVarianceThreshold, calc.varianceThreshold) - // Test setting custom threshold calc.SetVarianceThreshold(1.0) assert.Equal(t, 1.0, calc.varianceThreshold) } func TestCalculateStopDirection_WithShapeData(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() + // Optimization: Reuse shared DB and Cache + _, calc := getSharedTestComponents(t) - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) - - // Test with a real stop from RABA data direction := calc.CalculateStopDirection(context.Background(), "7000", sql.NullString{Valid: false}) - // Should return a valid direction or empty string assert.True(t, direction == "" || len(direction) <= 2) } func TestComputeFromShapes_NoShapeData(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + // Optimization: Reuse shared DB and Cache + _, calc := getSharedTestComponents(t) - // Test with a non-existent stop direction := calc.computeFromShapes(context.Background(), "nonexistent") assert.Equal(t, "", direction) } func TestComputeFromShapes_SingleOrientation(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + // Optimization: Reuse shared DB and Cache + _, calc := getSharedTestComponents(t) - // Test with actual stop data - single orientation path will be taken if only one trip direction := calc.computeFromShapes(context.Background(), "7000") - // Direction should be valid or empty assert.True(t, direction == "" || len(direction) <= 2) } func TestComputeFromShapes_VarianceThreshold(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - + // Note: We reuse the Shared Manager (DB) but create a NEW Calculator. + // This is because we modify the variance threshold and don't want to break other tests. + manager, _ := getSharedTestComponents(t) calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) // Set a very low variance threshold to trigger variance check @@ -239,15 +244,7 @@ func TestComputeFromShapes_VarianceThreshold(t *testing.T) { } func TestCalculateOrientationAtStop_WithDistanceTraveled(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + manager, calc := getSharedTestComponents(t) // Get a shape ID from the database shapes, err := manager.GtfsDB.Queries.GetShapePointsWithDistance(context.Background(), "19_0_1") @@ -264,15 +261,7 @@ func TestCalculateOrientationAtStop_WithDistanceTraveled(t *testing.T) { } func TestCalculateOrientationAtStop_GeographicMatching(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + manager, calc := getSharedTestComponents(t) // Get a shape ID from the database shapes, err := manager.GtfsDB.Queries.GetShapePointsWithDistance(context.Background(), "19_0_1") @@ -291,15 +280,7 @@ func TestCalculateOrientationAtStop_GeographicMatching(t *testing.T) { } func TestCalculateOrientationAtStop_NoShapePoints(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + _, calc := getSharedTestComponents(t) // Test with non-existent shape - should return error or 0 orientation orientation, err := calc.calculateOrientationAtStop(context.Background(), "nonexistent", 0, 0, 0) @@ -308,22 +289,13 @@ func TestCalculateOrientationAtStop_NoShapePoints(t *testing.T) { } func TestCalculateOrientationAtStop_EdgeCases(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + manager, calc := getSharedTestComponents(t) // Test with shape that has points at the boundaries shapes, err := manager.GtfsDB.Queries.GetShapePointsWithDistance(context.Background(), "19_0_1") if err != nil || len(shapes) < 2 { t.Skip("No shape data available for testing") } - // Test at the very beginning of the shape if len(shapes) > 0 && shapes[0].ShapeDistTraveled.Valid { orientation, err := calc.calculateOrientationAtStop(context.Background(), "19_0_1", shapes[0].ShapeDistTraveled.Float64, 0, 0) @@ -431,18 +403,7 @@ func TestSetContextCache_PanicAfterInit(t *testing.T) { } func TestCalculateStopDirection_VariadicSignature(t *testing.T) { - - // Setup in-memory DB so the calculator has a valid query interface - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - // Create the calculator using the VALID queries object - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + _, calc := getSharedTestComponents(t) // Case 1: Caller provides the optimized direction (should be used instantly) // We pass "North", expect "N" @@ -457,15 +418,8 @@ func TestCalculateStopDirection_VariadicSignature(t *testing.T) { } func TestSetContextCache_ConcurrentAccess(t *testing.T) { - // Setup - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - + manager, _ := getSharedTestComponents(t) + // We use shared DB, but MUST use a fresh Calculator to test the race condition specifically on that instance. calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) // Create dummy cache @@ -506,22 +460,10 @@ func TestSetContextCache_ConcurrentAccess(t *testing.T) { // TestBulkQuery_GetStopsWithShapeContextByIDs verifies the bulk optimization func TestBulkQuery_GetStopsWithShapeContextByIDs(t *testing.T) { - // Setup - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - if err != nil { - t.Fatalf("Failed to init manager: %v", err) - } - defer manager.Shutdown() - + manager, _ := getSharedTestComponents(t) ctx := context.Background() - // DYNAMICALLY fetch valid Stop IDs rows, err := manager.GtfsDB.DB.QueryContext(ctx, "SELECT id FROM stops LIMIT 5") - if err != nil { t.Fatalf("Failed to query stops: %v", err) } @@ -564,22 +506,12 @@ func TestBulkQuery_GetStopsWithShapeContextByIDs(t *testing.T) { // TestBulkQuery_GetShapePointsByIDs verifies fetching shape points in bulk. func TestBulkQuery_GetShapePointsByIDs(t *testing.T) { - // Setup - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - if err != nil { - t.Fatalf("Failed to init manager: %v", err) - } - defer manager.Shutdown() - + manager, _ := getSharedTestComponents(t) ctx := context.Background() // DYNAMICALLY fetch a real Shape ID from the DB var shapeID string - err = manager.GtfsDB.DB.QueryRowContext(ctx, "SELECT shape_id FROM shapes LIMIT 1").Scan(&shapeID) + err := manager.GtfsDB.DB.QueryRowContext(ctx, "SELECT shape_id FROM shapes LIMIT 1").Scan(&shapeID) // Stop immediately on error if err != nil { @@ -607,3 +539,17 @@ func TestBulkQuery_GetShapePointsByIDs(t *testing.T) { } assert.True(t, isSorted, "Shape points should be returned in sequence order") } + +func TestMain(m *testing.M) { + // Run all tests + code := m.Run() + + // Global Teardown + // If sharedManager was initialized during tests, shut it down now. + if sharedManager != nil { + sharedManager.Shutdown() + } + + // Exit with the test result code + os.Exit(code) +} diff --git a/internal/gtfs/global_cache.go b/internal/gtfs/global_cache.go new file mode 100644 index 00000000..53a44d87 --- /dev/null +++ b/internal/gtfs/global_cache.go @@ -0,0 +1,71 @@ +package gtfs + +import ( + "context" + "fmt" + "log/slog" + + "maglev.onebusaway.org/gtfsdb" +) + +func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *AdvancedDirectionCalculator) error { + slog.Info("starting global cache warmup...") + + allStopIDs, err := queries.GetAllStopIDs(ctx) + if err != nil { + return fmt.Errorf("failed to fetch all stop IDs: %w", err) + } + + // Fetch Context (Stop -> Shape mappings) + contextRows, err := queries.GetStopsWithShapeContextByIDs(ctx, allStopIDs) + if err != nil { + return fmt.Errorf("failed to fetch stop context rows: %w", err) + } + + contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) + shapeIDMap := make(map[string]bool) + var uniqueShapeIDs []string + + for _, row := range contextRows { + calcRow := gtfsdb.GetStopsWithShapeContextRow{ + ID: row.StopID, + ShapeID: row.ShapeID, + Lat: row.Lat, + Lon: row.Lon, + ShapeDistTraveled: row.ShapeDistTraveled, + } + contextCache[row.StopID] = append(contextCache[row.StopID], calcRow) + + if row.ShapeID.Valid && row.ShapeID.String != "" && !shapeIDMap[row.ShapeID.String] { + shapeIDMap[row.ShapeID.String] = true + uniqueShapeIDs = append(uniqueShapeIDs, row.ShapeID.String) + } + } + + shapeCache := make(map[string][]gtfsdb.GetShapePointsWithDistanceRow) + + if len(uniqueShapeIDs) > 0 { + shapePoints, err := queries.GetShapePointsByIDs(ctx, uniqueShapeIDs) + if err != nil { + return fmt.Errorf("failed to fetch shape points for global cache: %w", err) + } + + for _, p := range shapePoints { + shapeCache[p.ShapeID] = append(shapeCache[p.ShapeID], gtfsdb.GetShapePointsWithDistanceRow{ + Lat: p.Lat, + Lon: p.Lon, + ShapeDistTraveled: p.ShapeDistTraveled, + ShapePtSequence: p.ShapePtSequence, + }) + } + } + + adc.SetShapeCache(shapeCache) + adc.SetContextCache(contextCache) + + slog.Info("global cache warmup complete", + slog.Int("stops_cached", len(contextCache)), + slog.Int("shapes_cached", len(shapeCache))) + + return nil +} diff --git a/internal/gtfs/gtfs_manager_test.go b/internal/gtfs/gtfs_manager_test.go index b51f5ab3..814e183e 100644 --- a/internal/gtfs/gtfs_manager_test.go +++ b/internal/gtfs/gtfs_manager_test.go @@ -15,81 +15,43 @@ import ( ) func TestManager_GetAgencies(t *testing.T) { - testCases := []struct { - name string - dataPath string - }{ - { - name: "FromLocalFile", - dataPath: models.GetFixturePath(t, "raba.zip"), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: tc.dataPath, - Env: appconf.Test, - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - agencies := manager.GetAgencies() - assert.Equal(t, 1, len(agencies)) - - agency := agencies[0] - assert.Equal(t, "25", agency.Id) - assert.Equal(t, "Redding Area Bus Authority", agency.Name) - assert.Equal(t, "http://www.rabaride.com/", agency.Url) - assert.Equal(t, "America/Los_Angeles", agency.Timezone) - assert.Equal(t, "en", agency.Language) - assert.Equal(t, "530-241-2877", agency.Phone) - assert.Equal(t, "", agency.FareUrl) - assert.Equal(t, "", agency.Email) - }) - } + // Use shared component to avoid reloading DB + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) + + agencies := manager.GetAgencies() + assert.Equal(t, 1, len(agencies)) + + agency := agencies[0] + assert.Equal(t, "25", agency.Id) + assert.Equal(t, "Redding Area Bus Authority", agency.Name) + assert.Equal(t, "http://www.rabaride.com/", agency.Url) + assert.Equal(t, "America/Los_Angeles", agency.Timezone) + assert.Equal(t, "en", agency.Language) + assert.Equal(t, "530-241-2877", agency.Phone) + assert.Equal(t, "", agency.FareUrl) + assert.Equal(t, "", agency.Email) } func TestManager_RoutesForAgencyID(t *testing.T) { - testCases := []struct { - name string - dataPath string - }{ - { - name: "FromLocalFile", - dataPath: models.GetFixturePath(t, "raba.zip"), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: tc.dataPath, - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - manager.RLock() - routes := manager.RoutesForAgencyID("25") - manager.RUnlock() - assert.Equal(t, 13, len(routes)) - - route := routes[0] - assert.Equal(t, "1", route.ShortName) - assert.Equal(t, "25", route.Agency.Id) - }) - } + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) + + manager.RLock() + routes := manager.RoutesForAgencyID("25") + manager.RUnlock() + assert.Equal(t, 13, len(routes)) + + route := routes[0] + assert.Equal(t, "1", route.ShortName) + assert.Equal(t, "25", route.Agency.Id) } func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { testCases := []struct { name string - dataPath string lat float64 lon float64 radius float64 @@ -97,7 +59,6 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { }{ { name: "FindStopsWithinRadius", - dataPath: models.GetFixturePath(t, "raba.zip"), lat: 40.589123, // Near Redding, CA lon: -122.390830, radius: 2000, // 2km radius @@ -105,7 +66,6 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { }, { name: "FindStopsWithinRadius", - dataPath: models.GetFixturePath(t, "raba.zip"), lat: 47.589123, // West Seattle lon: -122.390830, radius: 2000, // 2km radius @@ -115,14 +75,8 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: tc.dataPath, - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) // Get stops using the manager method stops := manager.GetStopsForLocation(context.Background(), tc.lat, tc.lon, tc.radius, 0, 0, "", 100, false, nil, time.Time{}) @@ -142,14 +96,8 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { } func TestManager_GetTrips(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) trips := manager.GetTrips() assert.NotEmpty(t, trips) @@ -157,14 +105,7 @@ func TestManager_GetTrips(t *testing.T) { } func TestManager_FindAgency(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() + manager, _ := getSharedTestComponents(t) agency := manager.FindAgency("25") assert.NotNil(t, agency) @@ -256,14 +197,7 @@ func TestManager_GetTripUpdateByID(t *testing.T) { } func TestManager_IsServiceActiveOnDate(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() + manager, _ := getSharedTestComponents(t) // Get a trip to find a valid service ID trips := manager.GetTrips() @@ -326,11 +260,13 @@ func TestManager_IsServiceActiveOnDate(t *testing.T) { } func TestManager_GetVehicleForTrip(t *testing.T) { + gtfsConfig := Config{ GtfsURL: models.GetFixturePath(t, "raba.zip"), GTFSDataPath: ":memory:", Env: appconf.Test, } + //We use isolated GTFSManager here instead of shared test components because we want to control the real-time vehicles for this test. manager, err := InitGTFSManager(gtfsConfig) assert.Nil(t, err) defer manager.Shutdown() @@ -352,6 +288,10 @@ func TestManager_GetVehicleForTrip(t *testing.T) { assert.NotNil(t, vehicle) assert.Equal(t, "vehicle1", vehicle.ID.ID) } + + // Test Not Found + nilVehicle := manager.GetVehicleForTrip("nonexistent") + assert.Nil(t, nilVehicle) } func TestBuildLookupMaps(t *testing.T) { diff --git a/internal/restapi/arrival_and_departure_for_stop_handler.go b/internal/restapi/arrival_and_departure_for_stop_handler.go index 50084931..54202d4c 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler.go @@ -8,7 +8,6 @@ import ( "github.com/OneBusAway/go-gtfs" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -417,7 +416,6 @@ func (api *RestAPI) arrivalAndDepartureForStopHandler(w http.ResponseWriter, r * stopIDSet[closestStopID] = true } } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) for stopID := range stopIDSet { stopData, err := api.GtfsManager.GtfsDB.Queries.GetStop(ctx, stopID) if err != nil { @@ -448,7 +446,7 @@ func (api *RestAPI) arrivalAndDepartureForStopHandler(w http.ResponseWriter, r * Lat: stopData.Lat, Lon: stopData.Lon, Code: stopData.Code.String, - Direction: calc.CalculateStopDirection(r.Context(), stopData.ID, stopData.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(r.Context(), stopData.ID, stopData.Direction), LocationType: int(stopData.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/arrivals_and_departure_for_stop.go b/internal/restapi/arrivals_and_departure_for_stop.go index 308d0b02..f71638fc 100644 --- a/internal/restapi/arrivals_and_departure_for_stop.go +++ b/internal/restapi/arrivals_and_departure_for_stop.go @@ -8,7 +8,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -443,8 +442,6 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r references.Trips = append(references.Trips, tripRef) } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - for stopID := range stopIDSet { if ctx.Err() != nil { return @@ -489,7 +486,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r Lat: stopData.Lat, Lon: stopData.Lon, Code: stopData.Code.String, - Direction: calc.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), LocationType: int(stopData.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/block_handler.go b/internal/restapi/block_handler.go index 39d59892..84765ac3 100644 --- a/internal/restapi/block_handler.go +++ b/internal/restapi/block_handler.go @@ -7,7 +7,6 @@ import ( "sort" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -57,8 +56,7 @@ func (api *RestAPI) blockHandler(w http.ResponseWriter, r *http.Request) { Data: blockData, } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - references, err := api.getReferences(ctx, agencyID, calc, block) + references, err := api.getReferences(ctx, agencyID, block) if err != nil { api.serverErrorResponse(w, r, err) return @@ -167,7 +165,7 @@ func transformBlockToEntry(block []gtfsdb.GetBlockDetailsRow, blockID, agencyID } // IMPORTANT: Caller must hold manager.RLock() before calling this method. -func (api *RestAPI) getReferences(ctx context.Context, agencyID string, calc *GTFS.AdvancedDirectionCalculator, block []gtfsdb.GetBlockDetailsRow) (models.ReferencesModel, error) { +func (api *RestAPI) getReferences(ctx context.Context, agencyID string, block []gtfsdb.GetBlockDetailsRow) (models.ReferencesModel, error) { routeIDs := make(map[string]struct{}) stopIDs := make(map[string]struct{}) tripIDs := make(map[string]struct{}) @@ -227,7 +225,7 @@ func (api *RestAPI) getReferences(ctx context.Context, agencyID string, calc *GT Code: stop.Code.String, Lat: stop.Lat, Lon: stop.Lon, - Direction: calc.CalculateStopDirection(ctx, stop.ID, stop.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction), }) } diff --git a/internal/restapi/http_test.go b/internal/restapi/http_test.go index e59577d3..30eee4fb 100644 --- a/internal/restapi/http_test.go +++ b/internal/restapi/http_test.go @@ -3,6 +3,7 @@ package restapi import ( "bytes" "compress/gzip" + "context" "encoding/json" "io" "log/slog" @@ -26,9 +27,10 @@ import ( // Shared test database setup var ( - testGtfsManager *gtfs.Manager - testDbSetupOnce sync.Once - testDbPath = filepath.Join("../../testdata", "raba-test.db") + testGtfsManager *gtfs.Manager + testDirectionCalculator *gtfs.AdvancedDirectionCalculator + testDbSetupOnce sync.Once + testDbPath = filepath.Join("../../testdata", "raba-test.db") ) // TestMain handles setup and cleanup for all tests in this package @@ -59,6 +61,13 @@ func createTestApiWithClock(t testing.TB, c clock.Clock) *RestAPI { if err != nil { t.Fatalf("Failed to initialize shared test GTFS manager: %v", err) } + + // Create the DirectionCalculator using the shared manager's queries + testDirectionCalculator = gtfs.NewAdvancedDirectionCalculator(testGtfsManager.GtfsDB.Queries) + + // Warm up the cache with test data + err = gtfs.InitializeGlobalCache(context.Background(), testGtfsManager.GtfsDB.Queries, testDirectionCalculator) + require.NoError(t, err, "Failed to initialize global cache for tests") }) gtfsConfig := gtfs.Config{ @@ -73,9 +82,10 @@ func createTestApiWithClock(t testing.TB, c clock.Clock) *RestAPI { RateLimit: 5, // Low rate limit for testing ExemptApiKeys: []string{"org.onebusaway.iphone"}, }, - GtfsConfig: gtfsConfig, - GtfsManager: testGtfsManager, - Clock: c, + GtfsConfig: gtfsConfig, + GtfsManager: testGtfsManager, + DirectionCalculator: testDirectionCalculator, + Clock: c, } api := NewRestAPI(application) diff --git a/internal/restapi/schedule_for_route_handler.go b/internal/restapi/schedule_for_route_handler.go index e167bb5e..f49da05b 100644 --- a/internal/restapi/schedule_for_route_handler.go +++ b/internal/restapi/schedule_for_route_handler.go @@ -5,7 +5,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -183,22 +182,25 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque references := models.NewEmptyReferences() agency, err := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, agencyID) - if err == nil { - agencyModel := models.NewAgencyReference( - agency.ID, - agency.Name, - agency.Url, - agency.Timezone, - agency.Lang.String, - agency.Phone.String, - agency.Email.String, - agency.FareUrl.String, - "", - false, - ) - references.Agencies = append(references.Agencies, agencyModel) + if err != nil { + api.serverErrorResponse(w, r, err) + return } + agencyModel := models.NewAgencyReference( + agency.ID, + agency.Name, + agency.Url, + agency.Timezone, + agency.Lang.String, + agency.Phone.String, + agency.Email.String, + agency.FareUrl.String, + "", + false, + ) + references.Agencies = append(references.Agencies, agencyModel) + for _, r := range routeRefs { references.Routes = append(references.Routes, r) } @@ -209,37 +211,39 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque } if len(tripIDs) > 0 { tripRows, err := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, tripIDs) - if err == nil { - for _, t := range tripRows { - combinedTripID := utils.FormCombinedID(agencyID, t.ID) - tripRef := models.NewTripReference( - combinedTripID, - t.RouteID, - t.ServiceID, - t.TripHeadsign.String, - t.TripShortName.String, - t.DirectionID.Int64, - utils.FormCombinedID(agencyID, t.BlockID.String), - utils.FormCombinedID(agencyID, t.ShapeID.String), - ) - references.Trips = append(references.Trips, tripRef) - } + if err != nil { + api.serverErrorResponse(w, r, err) + return } - } - // Create a local calculator to ensure thread safety - calc := gtfs.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) + for _, t := range tripRows { + combinedTripID := utils.FormCombinedID(agencyID, t.ID) + tripRef := models.NewTripReference( + combinedTripID, + t.RouteID, + t.ServiceID, + t.TripHeadsign.String, + t.TripShortName.String, + t.DirectionID.Int64, + utils.FormCombinedID(agencyID, t.BlockID.String), + utils.FormCombinedID(agencyID, t.ShapeID.String), + ) + references.Trips = append(references.Trips, tripRef) + } + } uniqueStopIDs := make([]string, 0, len(globalStopIDSet)) for sid := range globalStopIDSet { uniqueStopIDs = append(uniqueStopIDs, sid) } if len(uniqueStopIDs) > 0 { - // Pass the local calculator - modelStops, _, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, uniqueStopIDs, calc) - if err == nil { - references.Stops = append(references.Stops, modelStops...) + + modelStops, _, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, uniqueStopIDs) + if err != nil { + api.serverErrorResponse(w, r, err) + return } + references.Stops = append(references.Stops, modelStops...) } for _, sref := range stopTimesRefs { diff --git a/internal/restapi/stops_for_route_handler.go b/internal/restapi/stops_for_route_handler.go index e1b26785..9f15214c 100644 --- a/internal/restapi/stops_for_route_handler.go +++ b/internal/restapi/stops_for_route_handler.go @@ -10,7 +10,6 @@ import ( "github.com/twpayne/go-polyline" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -92,67 +91,7 @@ func (api *RestAPI) stopsForRouteHandler(w http.ResponseWriter, r *http.Request) return } - // This prevents nil pointer panics and ensures thread-safety. - adc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - - // Get Stop IDs for the route to drive the bulk-loading caches - stopIDs, err := api.GtfsManager.GtfsDB.Queries.GetStopIDsForRoute(ctx, routeID) - if err == nil && len(stopIDs) > 0 { - - contextRows, err := api.GtfsManager.GtfsDB.Queries.GetStopsWithShapeContextByIDs(ctx, stopIDs) - if err != nil { - // Log error when bulk context load fails - slog.Warn("bulk context cache load failed, falling back to per-stop queries", - slog.String("routeID", routeID), - slog.String("error", err.Error())) - } else { - contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) - shapeIDMap := make(map[string]bool) - var uniqueShapeIDs []string - - for _, row := range contextRows { - calcRow := gtfsdb.GetStopsWithShapeContextRow{ - ID: row.StopID, - ShapeID: row.ShapeID, - Lat: row.Lat, - Lon: row.Lon, - ShapeDistTraveled: row.ShapeDistTraveled, - } - contextCache[row.StopID] = append(contextCache[row.StopID], calcRow) - - if row.ShapeID.Valid && row.ShapeID.String != "" && !shapeIDMap[row.ShapeID.String] { - shapeIDMap[row.ShapeID.String] = true - uniqueShapeIDs = append(uniqueShapeIDs, row.ShapeID.String) - } - } - - // Fetch Shape Points in bulk to populate the local cache - if len(uniqueShapeIDs) > 0 { - shapePoints, err := api.GtfsManager.GtfsDB.Queries.GetShapePointsByIDs(ctx, uniqueShapeIDs) - if err != nil { - // Log error when bulk shape load fails - slog.Warn("bulk shape cache load failed, falling back to per-stop queries", - slog.String("routeID", routeID), - slog.String("error", err.Error())) - } else { - shapeCache := make(map[string][]gtfsdb.GetShapePointsWithDistanceRow) - for _, p := range shapePoints { - shapeCache[p.ShapeID] = append(shapeCache[p.ShapeID], gtfsdb.GetShapePointsWithDistanceRow{ - Lat: p.Lat, - Lon: p.Lon, - ShapeDistTraveled: p.ShapeDistTraveled, - }) - } - - // Inject caches into the LOCAL instance. - adc.SetShapeCache(shapeCache) - adc.SetContextCache(contextCache) - } - } - } - } - - result, stopsList, err := api.processRouteStops(ctx, agencyID, routeID, serviceIDs, params.IncludePolylines, adc) + result, stopsList, err := api.processRouteStops(ctx, agencyID, routeID, serviceIDs, params.IncludePolylines) if err != nil { api.serverErrorResponse(w, r, err) return @@ -161,7 +100,7 @@ func (api *RestAPI) stopsForRouteHandler(w http.ResponseWriter, r *http.Request) api.buildAndSendResponse(w, r, ctx, result, stopsList, currentAgency) } -func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, routeID string, serviceIDs []string, includePolylines bool, adc *GTFS.AdvancedDirectionCalculator) (models.RouteEntry, []models.Stop, error) { +func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, routeID string, serviceIDs []string, includePolylines bool) (models.RouteEntry, []models.Stop, error) { allStops := make(map[string]bool) allPolylines := make([]models.Polyline, 0, 100) var stopGroupings []models.StopGrouping @@ -193,7 +132,7 @@ func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, rout } allStopsIds := formatStopIDs(agencyID, allStops) - stopsList, err := buildStopsList(ctx, api, adc, agencyID, allStops) + stopsList, err := buildStopsList(ctx, api, agencyID, allStops) if err != nil { return models.RouteEntry{}, nil, err } @@ -208,7 +147,7 @@ func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, rout return result, stopsList, nil } -func buildStopsList(ctx context.Context, api *RestAPI, calc *GTFS.AdvancedDirectionCalculator, agencyID string, allStops map[string]bool) ([]models.Stop, error) { +func buildStopsList(ctx context.Context, api *RestAPI, agencyID string, allStops map[string]bool) ([]models.Stop, error) { stopIDs := make([]string, 0, len(allStops)) for stopID := range allStops { @@ -243,7 +182,7 @@ func buildStopsList(ctx context.Context, api *RestAPI, calc *GTFS.AdvancedDirect return nil, ctx.Err() } - direction := calc.CalculateStopDirection(ctx, stop.ID, stop.Direction) + direction := api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction) routeIdsString := append([]string(nil), routesMap[stop.ID]...) diff --git a/internal/restapi/trip_details_handler.go b/internal/restapi/trip_details_handler.go index e88e49b2..1bcd8e53 100644 --- a/internal/restapi/trip_details_handler.go +++ b/internal/restapi/trip_details_handler.go @@ -7,7 +7,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -194,8 +193,6 @@ func (api *RestAPI) tripDetailsHandler(w http.ResponseWriter, r *http.Request) { references.Trips = referencedTripsIface } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - agencyModel := models.NewAgencyReference( agency.ID, agency.Name, @@ -211,7 +208,7 @@ func (api *RestAPI) tripDetailsHandler(w http.ResponseWriter, r *http.Request) { references.Agencies = append(references.Agencies, agencyModel) if params.IncludeSchedule && schedule != nil { - stops, err := api.buildStopReferences(ctx, calc, agencyID, schedule.StopTimes) + stops, err := api.buildStopReferences(ctx, agencyID, schedule.StopTimes) if err != nil { api.serverErrorResponse(w, r, err) return @@ -289,7 +286,7 @@ func (api *RestAPI) buildReferencedTrips(ctx context.Context, agencyID string, t } // IMPORTANT: Caller must hold manager.RLock() before calling this method. -func (api *RestAPI) buildStopReferences(ctx context.Context, calc *GTFS.AdvancedDirectionCalculator, agencyID string, stopTimes []models.StopTime) ([]models.Stop, error) { +func (api *RestAPI) buildStopReferences(ctx context.Context, agencyID string, stopTimes []models.StopTime) ([]models.Stop, error) { stopIDSet := make(map[string]bool) originalStopIDs := make([]string, 0, len(stopTimes)) @@ -379,7 +376,7 @@ func (api *RestAPI) buildStopReferences(ctx context.Context, calc *GTFS.Advanced Lat: stop.Lat, Lon: stop.Lon, Code: stop.Code.String, - Direction: calc.CalculateStopDirection(ctx, stop.ID, stop.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction), LocationType: int(stop.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stop.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/trip_for_vehicle_handler.go b/internal/restapi/trip_for_vehicle_handler.go index 82934a02..2acb7f43 100644 --- a/internal/restapi/trip_for_vehicle_handler.go +++ b/internal/restapi/trip_for_vehicle_handler.go @@ -9,7 +9,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -222,7 +221,6 @@ func (api *RestAPI) tripForVehicleHandler(w http.ResponseWriter, r *http.Request ) stopIDs := []string{} - calc := gtfs.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) if status != nil { if status.ClosestStop != "" { @@ -242,7 +240,7 @@ func (api *RestAPI) tripForVehicleHandler(w http.ResponseWriter, r *http.Request stopIDs = append(stopIDs, nextStopID) } } - stops, uniqueRouteMap, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, stopIDs, calc) + stops, uniqueRouteMap, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, stopIDs) if err != nil { api.serverErrorResponse(w, r, err) return @@ -287,7 +285,7 @@ func (api *RestAPI) tripForVehicleHandler(w http.ResponseWriter, r *http.Request // BuildStopReferencesAndRouteIDsForStops builds stop references and collects unique routes for the given stop IDs. // IMPORTANT: Caller must hold manager.RLock() before calling this method. -func BuildStopReferencesAndRouteIDsForStops(api *RestAPI, ctx context.Context, agencyID string, stopIDs []string, calc *gtfs.AdvancedDirectionCalculator) ([]models.Stop, map[string]gtfsdb.GetRoutesForStopsRow, error) { +func BuildStopReferencesAndRouteIDsForStops(api *RestAPI, ctx context.Context, agencyID string, stopIDs []string) ([]models.Stop, map[string]gtfsdb.GetRoutesForStopsRow, error) { if len(stopIDs) == 0 { return []models.Stop{}, map[string]gtfsdb.GetRoutesForStopsRow{}, nil } @@ -351,7 +349,7 @@ func BuildStopReferencesAndRouteIDsForStops(api *RestAPI, ctx context.Context, a Lat: stop.Lat, Lon: stop.Lon, Code: stop.Code.String, - Direction: calc.CalculateStopDirection(ctx, stop.ID, stop.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction), LocationType: int(stop.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stop.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/vehicles_for_agency_handler_test.go b/internal/restapi/vehicles_for_agency_handler_test.go index 477bcfec..7ac3d8bf 100644 --- a/internal/restapi/vehicles_for_agency_handler_test.go +++ b/internal/restapi/vehicles_for_agency_handler_test.go @@ -1,6 +1,7 @@ package restapi import ( + "context" "net/http" "net/http/httptest" "os" @@ -340,15 +341,20 @@ func createTestApiWithRealTimeData(t *testing.T) (*RestAPI, func()) { gtfsManager, err := gtfs.InitGTFSManager(gtfsConfig) require.NoError(t, err) + dirCalc := gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries) + err = gtfs.InitializeGlobalCache(context.Background(), gtfsManager.GtfsDB.Queries, dirCalc) + require.NoError(t, err) + application := &app.Application{ Config: appconf.Config{ Env: appconf.EnvFlagToEnvironment("test"), ApiKeys: []string{"TEST"}, RateLimit: 100, // Higher rate limit for this test }, - GtfsConfig: gtfsConfig, - GtfsManager: gtfsManager, - Clock: clock.RealClock{}, + GtfsConfig: gtfsConfig, + GtfsManager: gtfsManager, + DirectionCalculator: dirCalc, + Clock: clock.RealClock{}, } api := NewRestAPI(application)