From e2c5ae093782c5bd5dd0eb18f6bc83f057425138 Mon Sep 17 00:00:00 2001 From: Vedanth Date: Sun, 1 Feb 2026 23:06:14 +0530 Subject: [PATCH 1/7] refactor:initialize caches at startup and use shared directionCalculator --- cmd/api/app.go | 5 + gtfsdb/db.go | 30 ++-- gtfsdb/models.go | 9 + gtfsdb/query.sql | 6 + gtfsdb/query.sql.go | 162 +++++++++++------- internal/gtfs/global_cache.go | 76 ++++++++ .../arrival_and_departure_for_stop_handler.go | 4 +- .../arrivals_and_departure_for_stop.go | 5 +- internal/restapi/block_handler.go | 8 +- internal/restapi/http_test.go | 14 +- .../restapi/schedule_for_route_handler.go | 8 +- internal/restapi/stops_for_route_handler.go | 68 +------- internal/restapi/trip_details_handler.go | 9 +- internal/restapi/trip_for_vehicle_handler.go | 8 +- 14 files changed, 240 insertions(+), 172 deletions(-) create mode 100644 internal/gtfs/global_cache.go diff --git a/cmd/api/app.go b/cmd/api/app.go index 1cda2cdc..b2e76ea8 100644 --- a/cmd/api/app.go +++ b/cmd/api/app.go @@ -51,6 +51,11 @@ func BuildApplication(cfg appconf.Config, gtfsCfg gtfs.Config) (*app.Application directionCalculator = gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries) } + err = gtfs.InitializeGlobalCache(context.Background(), gtfsManager.GtfsDB.Queries, directionCalculator) + if err != nil { + return nil, fmt.Errorf("failed to initialize global cache: %w", err) + } + // Select clock implementation based on environment appClock := createClock(cfg.Env) diff --git a/gtfsdb/db.go b/gtfsdb/db.go index 9bb74084..7081aec1 100644 --- a/gtfsdb/db.go +++ b/gtfsdb/db.go @@ -105,6 +105,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getAllShapesStmt, err = db.PrepareContext(ctx, getAllShapes); err != nil { return nil, fmt.Errorf("error preparing query GetAllShapes: %w", err) } + if q.getAllStopIDsStmt, err = db.PrepareContext(ctx, getAllStopIDs); err != nil { + return nil, fmt.Errorf("error preparing query GetAllStopIDs: %w", err) + } if q.getAllTripsForRouteStmt, err = db.PrepareContext(ctx, getAllTripsForRoute); err != nil { return nil, fmt.Errorf("error preparing query GetAllTripsForRoute: %w", err) } @@ -264,15 +267,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.listRoutesStmt, err = db.PrepareContext(ctx, listRoutes); err != nil { return nil, fmt.Errorf("error preparing query ListRoutes: %w", err) } - if q.searchRoutesByFullTextStmt, err = db.PrepareContext(ctx, searchRoutesByFullText); err != nil { - return nil, fmt.Errorf("error preparing query SearchRoutesByFullText: %w", err) - } if q.listStopsStmt, err = db.PrepareContext(ctx, listStops); err != nil { return nil, fmt.Errorf("error preparing query ListStops: %w", err) } if q.listTripsStmt, err = db.PrepareContext(ctx, listTrips); err != nil { return nil, fmt.Errorf("error preparing query ListTrips: %w", err) } + if q.searchRoutesByFullTextStmt, err = db.PrepareContext(ctx, searchRoutesByFullText); err != nil { + return nil, fmt.Errorf("error preparing query SearchRoutesByFullText: %w", err) + } if q.searchStopsByNameStmt, err = db.PrepareContext(ctx, searchStopsByName); err != nil { return nil, fmt.Errorf("error preparing query SearchStopsByName: %w", err) } @@ -422,6 +425,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getAllShapesStmt: %w", cerr) } } + if q.getAllStopIDsStmt != nil { + if cerr := q.getAllStopIDsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getAllStopIDsStmt: %w", cerr) + } + } if q.getAllTripsForRouteStmt != nil { if cerr := q.getAllTripsForRouteStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getAllTripsForRouteStmt: %w", cerr) @@ -687,11 +695,6 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing listRoutesStmt: %w", cerr) } } - if q.searchRoutesByFullTextStmt != nil { - if cerr := q.searchRoutesByFullTextStmt.Close(); cerr != nil { - err = fmt.Errorf("error closing searchRoutesByFullTextStmt: %w", cerr) - } - } if q.listStopsStmt != nil { if cerr := q.listStopsStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listStopsStmt: %w", cerr) @@ -702,6 +705,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing listTripsStmt: %w", cerr) } } + if q.searchRoutesByFullTextStmt != nil { + if cerr := q.searchRoutesByFullTextStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing searchRoutesByFullTextStmt: %w", cerr) + } + } if q.searchStopsByNameStmt != nil { if cerr := q.searchStopsByNameStmt.Close(); cerr != nil { err = fmt.Errorf("error closing searchStopsByNameStmt: %w", cerr) @@ -783,6 +791,7 @@ type Queries struct { getAgencyStmt *sql.Stmt getAgencyForStopStmt *sql.Stmt getAllShapesStmt *sql.Stmt + getAllStopIDsStmt *sql.Stmt getAllTripsForRouteStmt *sql.Stmt getArrivalsAndDeparturesForStopStmt *sql.Stmt getBlockDetailsStmt *sql.Stmt @@ -836,9 +845,9 @@ type Queries struct { getTripsInBlockStmt *sql.Stmt listAgenciesStmt *sql.Stmt listRoutesStmt *sql.Stmt - searchRoutesByFullTextStmt *sql.Stmt listStopsStmt *sql.Stmt listTripsStmt *sql.Stmt + searchRoutesByFullTextStmt *sql.Stmt searchStopsByNameStmt *sql.Stmt updateStopDirectionStmt *sql.Stmt upsertImportMetadataStmt *sql.Stmt @@ -875,6 +884,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { getAgencyStmt: q.getAgencyStmt, getAgencyForStopStmt: q.getAgencyForStopStmt, getAllShapesStmt: q.getAllShapesStmt, + getAllStopIDsStmt: q.getAllStopIDsStmt, getAllTripsForRouteStmt: q.getAllTripsForRouteStmt, getArrivalsAndDeparturesForStopStmt: q.getArrivalsAndDeparturesForStopStmt, getBlockDetailsStmt: q.getBlockDetailsStmt, @@ -928,9 +938,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { getTripsInBlockStmt: q.getTripsInBlockStmt, listAgenciesStmt: q.listAgenciesStmt, listRoutesStmt: q.listRoutesStmt, - searchRoutesByFullTextStmt: q.searchRoutesByFullTextStmt, listStopsStmt: q.listStopsStmt, listTripsStmt: q.listTripsStmt, + searchRoutesByFullTextStmt: q.searchRoutesByFullTextStmt, searchStopsByNameStmt: q.searchStopsByNameStmt, updateStopDirectionStmt: q.updateStopDirectionStmt, upsertImportMetadataStmt: q.upsertImportMetadataStmt, diff --git a/gtfsdb/models.go b/gtfsdb/models.go index 5ea8bb10..1a43b0c4 100644 --- a/gtfsdb/models.go +++ b/gtfsdb/models.go @@ -76,6 +76,15 @@ type Route struct { ContinuousDropOff sql.NullInt64 } +type RoutesFt struct { + RoutesFts string + ID string + AgencyID string + ShortName string + LongName string + Desc string +} + type Shape struct { ID int64 ShapeID string diff --git a/gtfsdb/query.sql b/gtfsdb/query.sql index c3f8c717..23bdea02 100644 --- a/gtfsdb/query.sql +++ b/gtfsdb/query.sql @@ -218,6 +218,12 @@ ORDER BY LIMIT 1; +-- name: GetAllStopIDs :many +SELECT DISTINCT + id +FROM + stops; + -- name: GetStopIDsForAgency :many SELECT DISTINCT s.id diff --git a/gtfsdb/query.sql.go b/gtfsdb/query.sql.go index 0ae828ca..f78a6ec9 100644 --- a/gtfsdb/query.sql.go +++ b/gtfsdb/query.sql.go @@ -978,6 +978,36 @@ func (q *Queries) GetAllShapes(ctx context.Context) ([]Shape, error) { return items, nil } +const getAllStopIDs = `-- name: GetAllStopIDs :many +SELECT DISTINCT + id +FROM + stops +` + +func (q *Queries) GetAllStopIDs(ctx context.Context) ([]string, error) { + rows, err := q.query(ctx, q.getAllStopIDsStmt, getAllStopIDs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAllTripsForRoute = `-- name: GetAllTripsForRoute :many SELECT DISTINCT id, route_id, service_id, trip_headsign, trip_short_name, direction_id, block_id, shape_id, wheelchair_accessible, bikes_allowed FROM trips t @@ -3628,72 +3658,6 @@ func (q *Queries) ListRoutes(ctx context.Context) ([]Route, error) { return items, nil } -const searchRoutesByFullText = `-- name: SearchRoutesByFullText :many -SELECT - r.id, - r.agency_id, - r.short_name, - r.long_name, - r."desc", - r.type, - r.url, - r.color, - r.text_color, - r.continuous_pickup, - r.continuous_drop_off -FROM - routes_fts - JOIN routes r ON r.rowid = routes_fts.rowid -WHERE - routes_fts MATCH ? -ORDER BY - bm25(routes_fts), - r.agency_id, - r.id -LIMIT - ? -` - -type SearchRoutesByFullTextParams struct { - Query string - Limit int64 -} - -func (q *Queries) SearchRoutesByFullText(ctx context.Context, arg SearchRoutesByFullTextParams) ([]Route, error) { - rows, err := q.query(ctx, q.searchRoutesByFullTextStmt, searchRoutesByFullText, arg.Query, arg.Limit) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Route - for rows.Next() { - var i Route - if err := rows.Scan( - &i.ID, - &i.AgencyID, - &i.ShortName, - &i.LongName, - &i.Desc, - &i.Type, - &i.Url, - &i.Color, - &i.TextColor, - &i.ContinuousPickup, - &i.ContinuousDropOff, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const listStops = `-- name: ListStops :many SELECT id, code, name, "desc", lat, lon, zone_id, url, location_type, timezone, wheelchair_boarding, platform_code, direction, parent_station @@ -3782,6 +3746,72 @@ func (q *Queries) ListTrips(ctx context.Context) ([]Trip, error) { return items, nil } +const searchRoutesByFullText = `-- name: SearchRoutesByFullText :many +SELECT + r.id, + r.agency_id, + r.short_name, + r.long_name, + r."desc", + r.type, + r.url, + r.color, + r.text_color, + r.continuous_pickup, + r.continuous_drop_off +FROM + routes_fts + JOIN routes r ON r.rowid = routes_fts.rowid +WHERE + routes_fts MATCH ?1 +ORDER BY + rank, + r.agency_id, + r.id +LIMIT + ?2 +` + +type SearchRoutesByFullTextParams struct { + Query string + Limit int64 +} + +func (q *Queries) SearchRoutesByFullText(ctx context.Context, arg SearchRoutesByFullTextParams) ([]Route, error) { + rows, err := q.query(ctx, q.searchRoutesByFullTextStmt, searchRoutesByFullText, arg.Query, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Route + for rows.Next() { + var i Route + if err := rows.Scan( + &i.ID, + &i.AgencyID, + &i.ShortName, + &i.LongName, + &i.Desc, + &i.Type, + &i.Url, + &i.Color, + &i.TextColor, + &i.ContinuousPickup, + &i.ContinuousDropOff, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const searchStopsByName = `-- name: SearchStopsByName :many SELECT s.id, diff --git a/internal/gtfs/global_cache.go b/internal/gtfs/global_cache.go new file mode 100644 index 00000000..0d0a8092 --- /dev/null +++ b/internal/gtfs/global_cache.go @@ -0,0 +1,76 @@ +package gtfs + +import ( + "context" + "log/slog" + + "maglev.onebusaway.org/gtfsdb" +) + +func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *AdvancedDirectionCalculator) error { + slog.Info("starting global cache warmup...") + + // Fetch ALL Stop IDs + allStopIDs, err := queries.GetAllStopIDs(ctx) + if err != nil { + return err + } + + // Fetch Context (Stop -> Shape mappings) + contextRows, err := queries.GetStopsWithShapeContextByIDs(ctx, allStopIDs) + if err != nil { + return err + } + + contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) + shapeIDMap := make(map[string]bool) + var uniqueShapeIDs []string + + for _, row := range contextRows { + // Map the DB row to the Cache row struct + calcRow := gtfsdb.GetStopsWithShapeContextRow{ + ID: row.StopID, + ShapeID: row.ShapeID, + Lat: row.Lat, + Lon: row.Lon, + ShapeDistTraveled: row.ShapeDistTraveled, + } + contextCache[row.StopID] = append(contextCache[row.StopID], calcRow) + + // Collect unique valid Shape IDs + if row.ShapeID.Valid && row.ShapeID.String != "" && !shapeIDMap[row.ShapeID.String] { + shapeIDMap[row.ShapeID.String] = true + uniqueShapeIDs = append(uniqueShapeIDs, row.ShapeID.String) + } + } + + // Fetch Shape Points (Geometry) + shapeCache := make(map[string][]gtfsdb.GetShapePointsWithDistanceRow) + + if len(uniqueShapeIDs) > 0 { + shapePoints, err := queries.GetShapePointsByIDs(ctx, uniqueShapeIDs) + if err != nil { + // Fail fast if we can't load shapes (or just log error if you want to be resilient) + slog.Warn("Failed to fetch shape points for global cache", "error", err) + return err + } + + for _, p := range shapePoints { + shapeCache[p.ShapeID] = append(shapeCache[p.ShapeID], gtfsdb.GetShapePointsWithDistanceRow{ + Lat: p.Lat, + Lon: p.Lon, + ShapeDistTraveled: p.ShapeDistTraveled, + }) + } + } + + // Set Cache + adc.SetShapeCache(shapeCache) + adc.SetContextCache(contextCache) + + slog.Info("global cache warmup complete", + slog.Int("stops_cached", len(contextCache)), + slog.Int("shapes_cached", len(shapeCache))) + + return nil +} diff --git a/internal/restapi/arrival_and_departure_for_stop_handler.go b/internal/restapi/arrival_and_departure_for_stop_handler.go index 262c2b20..ed89588f 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler.go @@ -8,7 +8,6 @@ import ( "github.com/OneBusAway/go-gtfs" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -389,7 +388,6 @@ func (api *RestAPI) arrivalAndDepartureForStopHandler(w http.ResponseWriter, r * stopIDSet[closestStopID] = true } } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) for stopID := range stopIDSet { stopData, err := api.GtfsManager.GtfsDB.Queries.GetStop(ctx, stopID) if err != nil { @@ -420,7 +418,7 @@ func (api *RestAPI) arrivalAndDepartureForStopHandler(w http.ResponseWriter, r * Lat: stopData.Lat, Lon: stopData.Lon, Code: stopData.Code.String, - Direction: calc.CalculateStopDirection(r.Context(), stopData.ID, stopData.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(r.Context(), stopData.ID, stopData.Direction), LocationType: int(stopData.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/arrivals_and_departure_for_stop.go b/internal/restapi/arrivals_and_departure_for_stop.go index 4f05a3cf..44dd7474 100644 --- a/internal/restapi/arrivals_and_departure_for_stop.go +++ b/internal/restapi/arrivals_and_departure_for_stop.go @@ -7,7 +7,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -326,8 +325,6 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r references.Trips = append(references.Trips, tripRef) } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - for stopID := range stopIDSet { stopData, err := api.GtfsManager.GtfsDB.Queries.GetStop(ctx, stopID) if err != nil { @@ -360,7 +357,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r Lat: stopData.Lat, Lon: stopData.Lon, Code: stopData.Code.String, - Direction: calc.CalculateStopDirection(ctx, stopID), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stopID), LocationType: int(stopData.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/block_handler.go b/internal/restapi/block_handler.go index 13589e77..bfc566c7 100644 --- a/internal/restapi/block_handler.go +++ b/internal/restapi/block_handler.go @@ -7,7 +7,6 @@ import ( "sort" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -56,8 +55,7 @@ func (api *RestAPI) blockHandler(w http.ResponseWriter, r *http.Request) { Data: blockData, } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - references, err := api.getReferences(ctx, agencyID, calc, block) + references, err := api.getReferences(ctx, agencyID, block) if err != nil { api.serverErrorResponse(w, r, err) return @@ -166,7 +164,7 @@ func transformBlockToEntry(block []gtfsdb.GetBlockDetailsRow, blockID, agencyID } // IMPORTANT: Caller must hold manager.RLock() before calling this method. -func (api *RestAPI) getReferences(ctx context.Context, agencyID string, calc *GTFS.AdvancedDirectionCalculator, block []gtfsdb.GetBlockDetailsRow) (models.ReferencesModel, error) { +func (api *RestAPI) getReferences(ctx context.Context, agencyID string, block []gtfsdb.GetBlockDetailsRow) (models.ReferencesModel, error) { routeIDs := make(map[string]struct{}) stopIDs := make(map[string]struct{}) tripIDs := make(map[string]struct{}) @@ -222,7 +220,7 @@ func (api *RestAPI) getReferences(ctx context.Context, agencyID string, calc *GT Code: stop.Code.String, Lat: stop.Lat, Lon: stop.Lon, - Direction: calc.CalculateStopDirection(ctx, stop.ID, stop.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction), }) } diff --git a/internal/restapi/http_test.go b/internal/restapi/http_test.go index 2128ff2d..18410596 100644 --- a/internal/restapi/http_test.go +++ b/internal/restapi/http_test.go @@ -3,6 +3,7 @@ package restapi import ( "bytes" "compress/gzip" + "context" "encoding/json" "io" "log/slog" @@ -66,15 +67,22 @@ func createTestApiWithClock(t testing.TB, c clock.Clock) *RestAPI { GTFSDataPath: testDbPath, } + // Create the DirectionCalculator using the shared manager's queries + directionCalculator := gtfs.NewAdvancedDirectionCalculator(testGtfsManager.GtfsDB.Queries) + + // Warm up the cache with test data + _ = gtfs.InitializeGlobalCache(context.Background(), testGtfsManager.GtfsDB.Queries, directionCalculator) + application := &app.Application{ Config: appconf.Config{ Env: appconf.EnvFlagToEnvironment("test"), ApiKeys: []string{"TEST", "test", "test-rate-limit", "test-headers", "test-refill", "test-error-format", "org.onebusaway.iphone"}, RateLimit: 5, // Low rate limit for testing }, - GtfsConfig: gtfsConfig, - GtfsManager: testGtfsManager, - Clock: c, + GtfsConfig: gtfsConfig, + GtfsManager: testGtfsManager, + DirectionCalculator: directionCalculator, + Clock: c, } api := NewRestAPI(application) diff --git a/internal/restapi/schedule_for_route_handler.go b/internal/restapi/schedule_for_route_handler.go index e6e0fc37..75386d70 100644 --- a/internal/restapi/schedule_for_route_handler.go +++ b/internal/restapi/schedule_for_route_handler.go @@ -5,7 +5,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -236,16 +235,13 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque } } - // Create a local calculator to ensure thread safety - calc := gtfs.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - uniqueStopIDs := make([]string, 0, len(globalStopIDSet)) for sid := range globalStopIDSet { uniqueStopIDs = append(uniqueStopIDs, sid) } if len(uniqueStopIDs) > 0 { - // Pass the local calculator - modelStops, _, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, uniqueStopIDs, calc) + + modelStops, _, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, uniqueStopIDs) if err == nil { references.Stops = append(references.Stops, modelStops...) } diff --git a/internal/restapi/stops_for_route_handler.go b/internal/restapi/stops_for_route_handler.go index 5581a1db..50ec424a 100644 --- a/internal/restapi/stops_for_route_handler.go +++ b/internal/restapi/stops_for_route_handler.go @@ -97,67 +97,7 @@ func (api *RestAPI) stopsForRouteHandler(w http.ResponseWriter, r *http.Request) return } - // This prevents nil pointer panics and ensures thread-safety. - adc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - - // Get Stop IDs for the route to drive the bulk-loading caches - stopIDs, err := api.GtfsManager.GtfsDB.Queries.GetStopIDsForRoute(ctx, routeID) - if err == nil && len(stopIDs) > 0 { - - contextRows, err := api.GtfsManager.GtfsDB.Queries.GetStopsWithShapeContextByIDs(ctx, stopIDs) - if err != nil { - // Log error when bulk context load fails - slog.Warn("bulk context cache load failed, falling back to per-stop queries", - slog.String("routeID", routeID), - slog.String("error", err.Error())) - } else { - contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) - shapeIDMap := make(map[string]bool) - var uniqueShapeIDs []string - - for _, row := range contextRows { - calcRow := gtfsdb.GetStopsWithShapeContextRow{ - ID: row.StopID, - ShapeID: row.ShapeID, - Lat: row.Lat, - Lon: row.Lon, - ShapeDistTraveled: row.ShapeDistTraveled, - } - contextCache[row.StopID] = append(contextCache[row.StopID], calcRow) - - if row.ShapeID.Valid && row.ShapeID.String != "" && !shapeIDMap[row.ShapeID.String] { - shapeIDMap[row.ShapeID.String] = true - uniqueShapeIDs = append(uniqueShapeIDs, row.ShapeID.String) - } - } - - // Fetch Shape Points in bulk to populate the local cache - if len(uniqueShapeIDs) > 0 { - shapePoints, err := api.GtfsManager.GtfsDB.Queries.GetShapePointsByIDs(ctx, uniqueShapeIDs) - if err != nil { - // Log error when bulk shape load fails - slog.Warn("bulk shape cache load failed, falling back to per-stop queries", - slog.String("routeID", routeID), - slog.String("error", err.Error())) - } else { - shapeCache := make(map[string][]gtfsdb.GetShapePointsWithDistanceRow) - for _, p := range shapePoints { - shapeCache[p.ShapeID] = append(shapeCache[p.ShapeID], gtfsdb.GetShapePointsWithDistanceRow{ - Lat: p.Lat, - Lon: p.Lon, - ShapeDistTraveled: p.ShapeDistTraveled, - }) - } - - // Inject caches into the LOCAL instance. - adc.SetShapeCache(shapeCache) - adc.SetContextCache(contextCache) - } - } - } - } - - result, stopsList, err := api.processRouteStops(ctx, agencyID, routeID, serviceIDs, params.IncludePolylines, adc) + result, stopsList, err := api.processRouteStops(ctx, agencyID, routeID, serviceIDs, params.IncludePolylines, api.DirectionCalculator) if err != nil { api.serverErrorResponse(w, r, err) return @@ -197,7 +137,7 @@ func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, rout } allStopsIds := formatStopIDs(agencyID, allStops) - stopsList, err := buildStopsList(ctx, api, adc, agencyID, allStops) + stopsList, err := buildStopsList(ctx, api, agencyID, allStops) if err != nil { return models.RouteEntry{}, nil, err } @@ -212,7 +152,7 @@ func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, rout return result, stopsList, nil } -func buildStopsList(ctx context.Context, api *RestAPI, calc *GTFS.AdvancedDirectionCalculator, agencyID string, allStops map[string]bool) ([]models.Stop, error) { +func buildStopsList(ctx context.Context, api *RestAPI, agencyID string, allStops map[string]bool) ([]models.Stop, error) { stopIDs := make([]string, 0, len(allStops)) for stopID := range allStops { @@ -244,7 +184,7 @@ func buildStopsList(ctx context.Context, api *RestAPI, calc *GTFS.AdvancedDirect for _, stop := range stops { - direction := calc.CalculateStopDirection(ctx, stop.ID, stop.Direction) + direction := api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction) routeIdsString := append([]string(nil), routesMap[stop.ID]...) diff --git a/internal/restapi/trip_details_handler.go b/internal/restapi/trip_details_handler.go index 6fa0afb4..e841a660 100644 --- a/internal/restapi/trip_details_handler.go +++ b/internal/restapi/trip_details_handler.go @@ -7,7 +7,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -170,8 +169,6 @@ func (api *RestAPI) tripDetailsHandler(w http.ResponseWriter, r *http.Request) { references.Trips = referencedTripsIface } - calc := GTFS.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) - agencyModel := models.NewAgencyReference( agency.ID, agency.Name, @@ -187,7 +184,7 @@ func (api *RestAPI) tripDetailsHandler(w http.ResponseWriter, r *http.Request) { references.Agencies = append(references.Agencies, agencyModel) if params.IncludeSchedule && schedule != nil { - stops, err := api.buildStopReferences(ctx, calc, agencyID, schedule.StopTimes) + stops, err := api.buildStopReferences(ctx, agencyID, schedule.StopTimes) if err != nil { api.serverErrorResponse(w, r, err) return @@ -261,7 +258,7 @@ func (api *RestAPI) buildReferencedTrips(ctx context.Context, agencyID string, t } // IMPORTANT: Caller must hold manager.RLock() before calling this method. -func (api *RestAPI) buildStopReferences(ctx context.Context, calc *GTFS.AdvancedDirectionCalculator, agencyID string, stopTimes []models.StopTime) ([]models.Stop, error) { +func (api *RestAPI) buildStopReferences(ctx context.Context, agencyID string, stopTimes []models.StopTime) ([]models.Stop, error) { stopIDSet := make(map[string]bool) originalStopIDs := make([]string, 0, len(stopTimes)) @@ -343,7 +340,7 @@ func (api *RestAPI) buildStopReferences(ctx context.Context, calc *GTFS.Advanced Lat: stop.Lat, Lon: stop.Lon, Code: stop.Code.String, - Direction: calc.CalculateStopDirection(ctx, stop.ID, stop.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction), LocationType: int(stop.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stop.WheelchairBoarding)), RouteIDs: combinedRouteIDs, diff --git a/internal/restapi/trip_for_vehicle_handler.go b/internal/restapi/trip_for_vehicle_handler.go index 2c4a2270..0e3c8ef5 100644 --- a/internal/restapi/trip_for_vehicle_handler.go +++ b/internal/restapi/trip_for_vehicle_handler.go @@ -9,7 +9,6 @@ import ( "time" "maglev.onebusaway.org/gtfsdb" - "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -191,7 +190,6 @@ func (api *RestAPI) tripForVehicleHandler(w http.ResponseWriter, r *http.Request ) stopIDs := []string{} - calc := gtfs.NewAdvancedDirectionCalculator(api.GtfsManager.GtfsDB.Queries) if status != nil { if status.ClosestStop != "" { @@ -211,7 +209,7 @@ func (api *RestAPI) tripForVehicleHandler(w http.ResponseWriter, r *http.Request stopIDs = append(stopIDs, nextStopID) } } - stops, uniqueRouteMap, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, stopIDs, calc) + stops, uniqueRouteMap, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, stopIDs) if err != nil { api.serverErrorResponse(w, r, err) return @@ -257,7 +255,7 @@ func (api *RestAPI) tripForVehicleHandler(w http.ResponseWriter, r *http.Request // BuildStopReferencesAndRouteIDsForStops builds stop references and collects unique routes for the given stop IDs. // IMPORTANT: Caller must hold manager.RLock() before calling this method. -func BuildStopReferencesAndRouteIDsForStops(api *RestAPI, ctx context.Context, agencyID string, stopIDs []string, calc *gtfs.AdvancedDirectionCalculator) ([]models.Stop, map[string]gtfsdb.GetRoutesForStopsRow, error) { +func BuildStopReferencesAndRouteIDsForStops(api *RestAPI, ctx context.Context, agencyID string, stopIDs []string) ([]models.Stop, map[string]gtfsdb.GetRoutesForStopsRow, error) { if len(stopIDs) == 0 { return []models.Stop{}, map[string]gtfsdb.GetRoutesForStopsRow{}, nil } @@ -321,7 +319,7 @@ func BuildStopReferencesAndRouteIDsForStops(api *RestAPI, ctx context.Context, a Lat: stop.Lat, Lon: stop.Lon, Code: stop.Code.String, - Direction: calc.CalculateStopDirection(ctx, stop.ID, stop.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stop.ID, stop.Direction), LocationType: int(stop.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stop.WheelchairBoarding)), RouteIDs: combinedRouteIDs, From 11b1abc3d661ff7a8b47c0a7d829aa3f2fcf0bb3 Mon Sep 17 00:00:00 2001 From: Vedanth Date: Sun, 8 Feb 2026 19:49:19 +0530 Subject: [PATCH 2/7] refactor: optimize test setup by reusing shared DB and cache --- .../advanced_direction_calculator_test.go | 185 ++++++------------ internal/gtfs/gtfs_manager_test.go | 46 ++--- 2 files changed, 70 insertions(+), 161 deletions(-) diff --git a/internal/gtfs/advanced_direction_calculator_test.go b/internal/gtfs/advanced_direction_calculator_test.go index 6c2ef17e..408bd8d6 100644 --- a/internal/gtfs/advanced_direction_calculator_test.go +++ b/internal/gtfs/advanced_direction_calculator_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "math" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -11,6 +12,44 @@ import ( "maglev.onebusaway.org/internal/models" ) +// This uses a Singleton pattern to load the DB and Warm the Cache exactly ONCE +// for this test file. This prevents re-loading the ZIP file 15+ times. + +var ( + sharedManager *Manager + sharedCalc *AdvancedDirectionCalculator + setupOnce sync.Once +) + +// Helper function to get the shared instances. +func getSharedTestComponents(t *testing.T) (*Manager, *AdvancedDirectionCalculator) { + setupOnce.Do(func() { + // Initialize the DB (In-Memory) + gtfsConfig := Config{ + GtfsURL: models.GetFixturePath(t, "raba.zip"), + GTFSDataPath: ":memory:", + } + + var err error + sharedManager, err = InitGTFSManager(gtfsConfig) + if err != nil { + panic("Failed to init shared GTFS manager: " + err.Error()) + } + + // Create the Calculator + sharedCalc = NewAdvancedDirectionCalculator(sharedManager.GtfsDB.Queries) + + // Warm the Global Cache (The heavy operation) + // We do this only once per test suite execution. + err = InitializeGlobalCache(context.Background(), sharedManager.GtfsDB.Queries, sharedCalc) + if err != nil { + panic("Failed to warm global cache: " + err.Error()) + } + }) + + return sharedManager, sharedCalc +} + func TestTranslateGtfsDirection(t *testing.T) { calc := &AdvancedDirectionCalculator{} @@ -41,7 +80,6 @@ func TestTranslateGtfsDirection(t *testing.T) { {"225 degrees", "225", "SW"}, {"270 degrees", "270", "W"}, {"315 degrees", "315", "NW"}, - // Invalid {"invalid text", "invalid", ""}, {"empty string", "", ""}, @@ -139,7 +177,6 @@ func TestStatisticalFunctions(t *testing.T) { m := mean(values) v := variance(values, m) assert.InDelta(t, 2.5, v, 0.001) // Sample variance of 1,2,3,4,5 is 2.5 - assert.Equal(t, 0.0, variance([]float64{5}, 5.0)) }) @@ -159,74 +196,41 @@ func TestStatisticalFunctions(t *testing.T) { func TestVarianceThreshold(t *testing.T) { calc := NewAdvancedDirectionCalculator(nil) - // Test default threshold assert.Equal(t, defaultVarianceThreshold, calc.varianceThreshold) - // Test setting custom threshold calc.SetVarianceThreshold(1.0) assert.Equal(t, 1.0, calc.varianceThreshold) } func TestCalculateStopDirection_WithShapeData(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + // Optimization: Reuse shared DB and Cache + _, calc := getSharedTestComponents(t) - // Test with a real stop from RABA data direction := calc.CalculateStopDirection(context.Background(), "7000", sql.NullString{Valid: false}) - // Should return a valid direction or empty string assert.True(t, direction == "" || len(direction) <= 2) } func TestComputeFromShapes_NoShapeData(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + // Optimization: Reuse shared DB and Cache + _, calc := getSharedTestComponents(t) - // Test with a non-existent stop direction := calc.computeFromShapes(context.Background(), "nonexistent") assert.Equal(t, "", direction) } func TestComputeFromShapes_SingleOrientation(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + // Optimization: Reuse shared DB and Cache + _, calc := getSharedTestComponents(t) - // Test with actual stop data - single orientation path will be taken if only one trip direction := calc.computeFromShapes(context.Background(), "7000") - // Direction should be valid or empty assert.True(t, direction == "" || len(direction) <= 2) } func TestComputeFromShapes_VarianceThreshold(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - + // Note: We reuse the Shared Manager (DB) but create a NEW Calculator. + // This is because we modify the variance threshold and don't want to break other tests. + manager, _ := getSharedTestComponents(t) calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) // Set a very low variance threshold to trigger variance check @@ -239,15 +243,7 @@ func TestComputeFromShapes_VarianceThreshold(t *testing.T) { } func TestCalculateOrientationAtStop_WithDistanceTraveled(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + manager, calc := getSharedTestComponents(t) // Get a shape ID from the database shapes, err := manager.GtfsDB.Queries.GetShapePointsWithDistance(context.Background(), "19_0_1") @@ -264,15 +260,7 @@ func TestCalculateOrientationAtStop_WithDistanceTraveled(t *testing.T) { } func TestCalculateOrientationAtStop_GeographicMatching(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + manager, calc := getSharedTestComponents(t) // Get a shape ID from the database shapes, err := manager.GtfsDB.Queries.GetShapePointsWithDistance(context.Background(), "19_0_1") @@ -291,15 +279,7 @@ func TestCalculateOrientationAtStop_GeographicMatching(t *testing.T) { } func TestCalculateOrientationAtStop_NoShapePoints(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + _, calc := getSharedTestComponents(t) // Test with non-existent shape - should return error or 0 orientation orientation, err := calc.calculateOrientationAtStop(context.Background(), "nonexistent", 0, 0, 0) @@ -308,22 +288,13 @@ func TestCalculateOrientationAtStop_NoShapePoints(t *testing.T) { } func TestCalculateOrientationAtStop_EdgeCases(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + manager, calc := getSharedTestComponents(t) // Test with shape that has points at the boundaries shapes, err := manager.GtfsDB.Queries.GetShapePointsWithDistance(context.Background(), "19_0_1") if err != nil || len(shapes) < 2 { t.Skip("No shape data available for testing") } - // Test at the very beginning of the shape if len(shapes) > 0 && shapes[0].ShapeDistTraveled.Valid { orientation, err := calc.calculateOrientationAtStop(context.Background(), "19_0_1", shapes[0].ShapeDistTraveled.Float64, 0, 0) @@ -431,18 +402,7 @@ func TestSetContextCache_PanicAfterInit(t *testing.T) { } func TestCalculateStopDirection_VariadicSignature(t *testing.T) { - - // Setup in-memory DB so the calculator has a valid query interface - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - - // Create the calculator using the VALID queries object - calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) + _, calc := getSharedTestComponents(t) // Case 1: Caller provides the optimized direction (should be used instantly) // We pass "North", expect "N" @@ -457,15 +417,8 @@ func TestCalculateStopDirection_VariadicSignature(t *testing.T) { } func TestSetContextCache_ConcurrentAccess(t *testing.T) { - // Setup - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) - defer manager.Shutdown() - + manager, _ := getSharedTestComponents(t) + // We use shared DB, but MUST use a fresh Calculator to test the race condition specifically on that instance. calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries) // Create dummy cache @@ -506,22 +459,10 @@ func TestSetContextCache_ConcurrentAccess(t *testing.T) { // TestBulkQuery_GetStopsWithShapeContextByIDs verifies the bulk optimization func TestBulkQuery_GetStopsWithShapeContextByIDs(t *testing.T) { - // Setup - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - if err != nil { - t.Fatalf("Failed to init manager: %v", err) - } - defer manager.Shutdown() - + manager, _ := getSharedTestComponents(t) ctx := context.Background() - // DYNAMICALLY fetch valid Stop IDs rows, err := manager.GtfsDB.DB.QueryContext(ctx, "SELECT id FROM stops LIMIT 5") - if err != nil { t.Fatalf("Failed to query stops: %v", err) } @@ -564,22 +505,12 @@ func TestBulkQuery_GetStopsWithShapeContextByIDs(t *testing.T) { // TestBulkQuery_GetShapePointsByIDs verifies fetching shape points in bulk. func TestBulkQuery_GetShapePointsByIDs(t *testing.T) { - // Setup - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - if err != nil { - t.Fatalf("Failed to init manager: %v", err) - } - defer manager.Shutdown() - + manager, _ := getSharedTestComponents(t) ctx := context.Background() // DYNAMICALLY fetch a real Shape ID from the DB var shapeID string - err = manager.GtfsDB.DB.QueryRowContext(ctx, "SELECT shape_id FROM shapes LIMIT 1").Scan(&shapeID) + err := manager.GtfsDB.DB.QueryRowContext(ctx, "SELECT shape_id FROM shapes LIMIT 1").Scan(&shapeID) // Stop immediately on error if err != nil { diff --git a/internal/gtfs/gtfs_manager_test.go b/internal/gtfs/gtfs_manager_test.go index 97e7ef11..97cf5550 100644 --- a/internal/gtfs/gtfs_manager_test.go +++ b/internal/gtfs/gtfs_manager_test.go @@ -24,13 +24,9 @@ func TestManager_GetAgencies(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: tc.dataPath, - Env: appconf.Test, - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) agencies := manager.GetAgencies() assert.Equal(t, 1, len(agencies)) @@ -61,12 +57,8 @@ func TestManager_RoutesForAgencyID(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: tc.dataPath, - GTFSDataPath: ":memory:", - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) routes := manager.RoutesForAgencyID("25") assert.Equal(t, 13, len(routes)) @@ -107,13 +99,8 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: tc.dataPath, - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) // Get stops using the manager method stops := manager.GetStopsForLocation(context.Background(), tc.lat, tc.lon, tc.radius, 0, 0, "", 100, false, nil, time.Time{}) @@ -144,13 +131,8 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { } func TestManager_GetTrips(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) trips := manager.GetTrips() assert.NotEmpty(t, trips) @@ -327,13 +309,9 @@ func TestManager_IsServiceActiveOnDate(t *testing.T) { } func TestManager_GetVehicleForTrip(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) // Set up real-time vehicle with a trip trip := >fs.Trip{ From 56b7889205e5a717e5e9e4af5dcee4849575c767 Mon Sep 17 00:00:00 2001 From: Vedanth Date: Thu, 19 Feb 2026 23:18:57 +0530 Subject: [PATCH 3/7] resolve minor conflicts --- internal/restapi/arrivals_and_departure_for_stop.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/restapi/arrivals_and_departure_for_stop.go b/internal/restapi/arrivals_and_departure_for_stop.go index d4d53414..67477151 100644 --- a/internal/restapi/arrivals_and_departure_for_stop.go +++ b/internal/restapi/arrivals_and_departure_for_stop.go @@ -471,7 +471,7 @@ func (api *RestAPI) arrivalsAndDeparturesForStopHandler(w http.ResponseWriter, r Lat: stopData.Lat, Lon: stopData.Lon, Code: stopData.Code.String, - Direction: calc.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), + Direction: api.DirectionCalculator.CalculateStopDirection(ctx, stopData.ID, stopData.Direction), LocationType: int(stopData.LocationType.Int64), WheelchairBoarding: utils.MapWheelchairBoarding(utils.NullWheelchairBoardingOrUnknown(stopData.WheelchairBoarding)), RouteIDs: combinedRouteIDs, From 5a77bae62c8f462ae3c467e4fc44135851abbb50 Mon Sep 17 00:00:00 2001 From: Vedanth Date: Fri, 20 Feb 2026 00:46:28 +0530 Subject: [PATCH 4/7] Refactor global cache initialization and error handling; streamline SQL queries and test setup --- cmd/api/app.go | 8 ++-- gtfsdb/query.sql | 2 +- gtfsdb/query.sql.go | 2 +- .../advanced_direction_calculator_test.go | 15 ++++++ internal/gtfs/global_cache.go | 10 ++-- internal/gtfs/gtfs_manager_test.go | 48 ++++++++----------- internal/restapi/http_test.go | 22 +++++---- .../restapi/schedule_for_route_handler.go | 6 ++- internal/restapi/stops_for_route_handler.go | 5 +- 9 files changed, 64 insertions(+), 54 deletions(-) diff --git a/cmd/api/app.go b/cmd/api/app.go index 3a65c9ab..5965637f 100644 --- a/cmd/api/app.go +++ b/cmd/api/app.go @@ -51,11 +51,11 @@ func BuildApplication(cfg appconf.Config, gtfsCfg gtfs.Config) (*app.Application var directionCalculator *gtfs.AdvancedDirectionCalculator if gtfsManager != nil { directionCalculator = gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries) - } - err = gtfs.InitializeGlobalCache(context.Background(), gtfsManager.GtfsDB.Queries, directionCalculator) - if err != nil { - return nil, fmt.Errorf("failed to initialize global cache: %w", err) + err = gtfs.InitializeGlobalCache(context.Background(), gtfsManager.GtfsDB.Queries, directionCalculator) + if err != nil { + return nil, fmt.Errorf("failed to initialize global cache: %w", err) + } } // Select clock implementation based on environment diff --git a/gtfsdb/query.sql b/gtfsdb/query.sql index 3193fbe2..528de8db 100644 --- a/gtfsdb/query.sql +++ b/gtfsdb/query.sql @@ -195,7 +195,7 @@ LIMIT 1; -- name: GetAllStopIDs :many -SELECT DISTINCT +SELECT id FROM stops; diff --git a/gtfsdb/query.sql.go b/gtfsdb/query.sql.go index 661fe109..5296c7cf 100644 --- a/gtfsdb/query.sql.go +++ b/gtfsdb/query.sql.go @@ -1072,7 +1072,7 @@ func (q *Queries) GetAllShapes(ctx context.Context) ([]Shape, error) { } const getAllStopIDs = `-- name: GetAllStopIDs :many -SELECT DISTINCT +SELECT id FROM stops diff --git a/internal/gtfs/advanced_direction_calculator_test.go b/internal/gtfs/advanced_direction_calculator_test.go index 408bd8d6..253e60ff 100644 --- a/internal/gtfs/advanced_direction_calculator_test.go +++ b/internal/gtfs/advanced_direction_calculator_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "math" + "os" "sync" "testing" @@ -538,3 +539,17 @@ func TestBulkQuery_GetShapePointsByIDs(t *testing.T) { } assert.True(t, isSorted, "Shape points should be returned in sequence order") } + +func TestMain(m *testing.M) { + // Run all tests + code := m.Run() + + // Global Teardown + // If sharedManager was initialized during tests, shut it down now. + if sharedManager != nil { + sharedManager.Shutdown() + } + + // Exit with the test result code + os.Exit(code) +} diff --git a/internal/gtfs/global_cache.go b/internal/gtfs/global_cache.go index 0d0a8092..963e221d 100644 --- a/internal/gtfs/global_cache.go +++ b/internal/gtfs/global_cache.go @@ -2,6 +2,7 @@ package gtfs import ( "context" + "fmt" "log/slog" "maglev.onebusaway.org/gtfsdb" @@ -13,13 +14,13 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad // Fetch ALL Stop IDs allStopIDs, err := queries.GetAllStopIDs(ctx) if err != nil { - return err + return fmt.Errorf("failed to fetch all stop IDs: %w", err) } // Fetch Context (Stop -> Shape mappings) contextRows, err := queries.GetStopsWithShapeContextByIDs(ctx, allStopIDs) if err != nil { - return err + return fmt.Errorf("failed to fetch stop context rows: %w", err) } contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) @@ -50,9 +51,7 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad if len(uniqueShapeIDs) > 0 { shapePoints, err := queries.GetShapePointsByIDs(ctx, uniqueShapeIDs) if err != nil { - // Fail fast if we can't load shapes (or just log error if you want to be resilient) - slog.Warn("Failed to fetch shape points for global cache", "error", err) - return err + return fmt.Errorf("failed to fetch shape points for global cache: %w", err) } for _, p := range shapePoints { @@ -60,6 +59,7 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad Lat: p.Lat, Lon: p.Lon, ShapeDistTraveled: p.ShapeDistTraveled, + ShapePtSequence: int64(p.ShapePtSequence), // Added ShapePtSequence }) } } diff --git a/internal/gtfs/gtfs_manager_test.go b/internal/gtfs/gtfs_manager_test.go index 117795ba..9b6e39b9 100644 --- a/internal/gtfs/gtfs_manager_test.go +++ b/internal/gtfs/gtfs_manager_test.go @@ -13,18 +13,16 @@ import ( func TestManager_GetAgencies(t *testing.T) { testCases := []struct { - name string - dataPath string + name string }{ { - name: "FromLocalFile", - dataPath: models.GetFixturePath(t, "raba.zip"), + name: "FromLocalFile", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - + // Use shared component to avoid reloading DB manager, _ := getSharedTestComponents(t) assert.NotNil(t, manager) @@ -46,12 +44,10 @@ func TestManager_GetAgencies(t *testing.T) { func TestManager_RoutesForAgencyID(t *testing.T) { testCases := []struct { - name string - dataPath string + name string }{ { - name: "FromLocalFile", - dataPath: models.GetFixturePath(t, "raba.zip"), + name: "FromLocalFile", }, } @@ -73,7 +69,6 @@ func TestManager_RoutesForAgencyID(t *testing.T) { func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { testCases := []struct { name string - dataPath string lat float64 lon float64 radius float64 @@ -81,7 +76,6 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { }{ { name: "FindStopsWithinRadius", - dataPath: models.GetFixturePath(t, "raba.zip"), lat: 40.589123, // Near Redding, CA lon: -122.390830, radius: 2000, // 2km radius @@ -89,7 +83,6 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { }, { name: "FindStopsWithinRadius", - dataPath: models.GetFixturePath(t, "raba.zip"), lat: 47.589123, // West Seattle lon: -122.390830, radius: 2000, // 2km radius @@ -98,6 +91,7 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { } for _, tc := range testCases { + tc := tc // Capture loop variable t.Run(tc.name, func(t *testing.T) { manager, _ := getSharedTestComponents(t) assert.NotNil(t, manager) @@ -140,13 +134,7 @@ func TestManager_GetTrips(t *testing.T) { } func TestManager_FindAgency(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + manager, _ := getSharedTestComponents(t) agency := manager.FindAgency("25") assert.NotNil(t, agency) @@ -238,13 +226,7 @@ func TestManager_GetTripUpdateByID(t *testing.T) { } func TestManager_IsServiceActiveOnDate(t *testing.T) { - gtfsConfig := Config{ - GtfsURL: models.GetFixturePath(t, "raba.zip"), - GTFSDataPath: ":memory:", - Env: appconf.Test, - } - manager, err := InitGTFSManager(gtfsConfig) - assert.Nil(t, err) + manager, _ := getSharedTestComponents(t) // Get a trip to find a valid service ID trips := manager.GetTrips() @@ -295,6 +277,7 @@ func TestManager_IsServiceActiveOnDate(t *testing.T) { } for _, tc := range testCases { + tc := tc t.Run(tc.name, func(t *testing.T) { // Verify the date is the expected weekday assert.Equal(t, tc.weekday, tc.date.Weekday().String()) @@ -310,8 +293,13 @@ func TestManager_IsServiceActiveOnDate(t *testing.T) { func TestManager_GetVehicleForTrip(t *testing.T) { - manager, _ := getSharedTestComponents(t) - assert.NotNil(t, manager) + gtfsConfig := Config{ + GtfsURL: models.GetFixturePath(t, "raba.zip"), + GTFSDataPath: ":memory:", + Env: appconf.Test, + } + manager, err := InitGTFSManager(gtfsConfig) + assert.Nil(t, err) // Set up real-time vehicle with a trip trip := >fs.Trip{ @@ -330,6 +318,10 @@ func TestManager_GetVehicleForTrip(t *testing.T) { assert.NotNil(t, vehicle) assert.Equal(t, "vehicle1", vehicle.ID.ID) } + + // Test Not Found + nilVehicle := manager.GetVehicleForTrip("nonexistent") + assert.Nil(t, nilVehicle) } func TestBuildLookupMaps(t *testing.T) { diff --git a/internal/restapi/http_test.go b/internal/restapi/http_test.go index 2748a5f0..30eee4fb 100644 --- a/internal/restapi/http_test.go +++ b/internal/restapi/http_test.go @@ -27,9 +27,10 @@ import ( // Shared test database setup var ( - testGtfsManager *gtfs.Manager - testDbSetupOnce sync.Once - testDbPath = filepath.Join("../../testdata", "raba-test.db") + testGtfsManager *gtfs.Manager + testDirectionCalculator *gtfs.AdvancedDirectionCalculator + testDbSetupOnce sync.Once + testDbPath = filepath.Join("../../testdata", "raba-test.db") ) // TestMain handles setup and cleanup for all tests in this package @@ -60,6 +61,13 @@ func createTestApiWithClock(t testing.TB, c clock.Clock) *RestAPI { if err != nil { t.Fatalf("Failed to initialize shared test GTFS manager: %v", err) } + + // Create the DirectionCalculator using the shared manager's queries + testDirectionCalculator = gtfs.NewAdvancedDirectionCalculator(testGtfsManager.GtfsDB.Queries) + + // Warm up the cache with test data + err = gtfs.InitializeGlobalCache(context.Background(), testGtfsManager.GtfsDB.Queries, testDirectionCalculator) + require.NoError(t, err, "Failed to initialize global cache for tests") }) gtfsConfig := gtfs.Config{ @@ -67,12 +75,6 @@ func createTestApiWithClock(t testing.TB, c clock.Clock) *RestAPI { GTFSDataPath: testDbPath, } - // Create the DirectionCalculator using the shared manager's queries - directionCalculator := gtfs.NewAdvancedDirectionCalculator(testGtfsManager.GtfsDB.Queries) - - // Warm up the cache with test data - _ = gtfs.InitializeGlobalCache(context.Background(), testGtfsManager.GtfsDB.Queries, directionCalculator) - application := &app.Application{ Config: appconf.Config{ Env: appconf.EnvFlagToEnvironment("test"), @@ -82,7 +84,7 @@ func createTestApiWithClock(t testing.TB, c clock.Clock) *RestAPI { }, GtfsConfig: gtfsConfig, GtfsManager: testGtfsManager, - DirectionCalculator: directionCalculator, + DirectionCalculator: testDirectionCalculator, Clock: c, } diff --git a/internal/restapi/schedule_for_route_handler.go b/internal/restapi/schedule_for_route_handler.go index 75386d70..19aecbbb 100644 --- a/internal/restapi/schedule_for_route_handler.go +++ b/internal/restapi/schedule_for_route_handler.go @@ -242,9 +242,11 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque if len(uniqueStopIDs) > 0 { modelStops, _, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, uniqueStopIDs) - if err == nil { - references.Stops = append(references.Stops, modelStops...) + if err != nil { + api.serverErrorResponse(w, r, err) + return } + references.Stops = append(references.Stops, modelStops...) } for _, sref := range stopTimesRefs { diff --git a/internal/restapi/stops_for_route_handler.go b/internal/restapi/stops_for_route_handler.go index 8d466dbf..5f998738 100644 --- a/internal/restapi/stops_for_route_handler.go +++ b/internal/restapi/stops_for_route_handler.go @@ -10,7 +10,6 @@ import ( "github.com/twpayne/go-polyline" "maglev.onebusaway.org/gtfsdb" - GTFS "maglev.onebusaway.org/internal/gtfs" "maglev.onebusaway.org/internal/models" "maglev.onebusaway.org/internal/utils" ) @@ -106,7 +105,7 @@ func (api *RestAPI) stopsForRouteHandler(w http.ResponseWriter, r *http.Request) return } - result, stopsList, err := api.processRouteStops(ctx, agencyID, routeID, serviceIDs, params.IncludePolylines, api.DirectionCalculator) + result, stopsList, err := api.processRouteStops(ctx, agencyID, routeID, serviceIDs, params.IncludePolylines) if err != nil { api.serverErrorResponse(w, r, err) return @@ -114,7 +113,7 @@ func (api *RestAPI) stopsForRouteHandler(w http.ResponseWriter, r *http.Request) api.buildAndSendResponse(w, r, ctx, result, stopsList, currentAgency) } -func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, routeID string, serviceIDs []string, includePolylines bool, adc *GTFS.AdvancedDirectionCalculator) (models.RouteEntry, []models.Stop, error) { +func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, routeID string, serviceIDs []string, includePolylines bool) (models.RouteEntry, []models.Stop, error) { allStops := make(map[string]bool) allPolylines := make([]models.Polyline, 0, 100) var stopGroupings []models.StopGrouping From 48bc39a2e31efde13d69a438b21574f924bbadb9 Mon Sep 17 00:00:00 2001 From: Vedanth Date: Mon, 23 Feb 2026 18:02:10 +0530 Subject: [PATCH 5/7] Refactor: streamline global cache initialization and improve test structure --- internal/gtfs/global_cache.go | 9 +-- internal/gtfs/gtfs_manager_test.go | 73 +++++++------------ internal/restapi/stops_for_route_handler.go | 2 +- .../vehicles_for_agency_handler_test.go | 6 ++ 4 files changed, 34 insertions(+), 56 deletions(-) diff --git a/internal/gtfs/global_cache.go b/internal/gtfs/global_cache.go index 963e221d..485324f7 100644 --- a/internal/gtfs/global_cache.go +++ b/internal/gtfs/global_cache.go @@ -11,7 +11,6 @@ import ( func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *AdvancedDirectionCalculator) error { slog.Info("starting global cache warmup...") - // Fetch ALL Stop IDs allStopIDs, err := queries.GetAllStopIDs(ctx) if err != nil { return fmt.Errorf("failed to fetch all stop IDs: %w", err) @@ -22,13 +21,12 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad if err != nil { return fmt.Errorf("failed to fetch stop context rows: %w", err) } - + contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) shapeIDMap := make(map[string]bool) var uniqueShapeIDs []string for _, row := range contextRows { - // Map the DB row to the Cache row struct calcRow := gtfsdb.GetStopsWithShapeContextRow{ ID: row.StopID, ShapeID: row.ShapeID, @@ -38,14 +36,12 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad } contextCache[row.StopID] = append(contextCache[row.StopID], calcRow) - // Collect unique valid Shape IDs if row.ShapeID.Valid && row.ShapeID.String != "" && !shapeIDMap[row.ShapeID.String] { shapeIDMap[row.ShapeID.String] = true uniqueShapeIDs = append(uniqueShapeIDs, row.ShapeID.String) } } - // Fetch Shape Points (Geometry) shapeCache := make(map[string][]gtfsdb.GetShapePointsWithDistanceRow) if len(uniqueShapeIDs) > 0 { @@ -59,12 +55,11 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad Lat: p.Lat, Lon: p.Lon, ShapeDistTraveled: p.ShapeDistTraveled, - ShapePtSequence: int64(p.ShapePtSequence), // Added ShapePtSequence + ShapePtSequence: p.ShapePtSequence, }) } } - // Set Cache adc.SetShapeCache(shapeCache) adc.SetContextCache(contextCache) diff --git a/internal/gtfs/gtfs_manager_test.go b/internal/gtfs/gtfs_manager_test.go index ce446331..814e183e 100644 --- a/internal/gtfs/gtfs_manager_test.go +++ b/internal/gtfs/gtfs_manager_test.go @@ -15,60 +15,38 @@ import ( ) func TestManager_GetAgencies(t *testing.T) { - testCases := []struct { - name string - }{ - { - name: "FromLocalFile", - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Use shared component to avoid reloading DB - manager, _ := getSharedTestComponents(t) - assert.NotNil(t, manager) + // Use shared component to avoid reloading DB + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) - agencies := manager.GetAgencies() - assert.Equal(t, 1, len(agencies)) - - agency := agencies[0] - assert.Equal(t, "25", agency.Id) - assert.Equal(t, "Redding Area Bus Authority", agency.Name) - assert.Equal(t, "http://www.rabaride.com/", agency.Url) - assert.Equal(t, "America/Los_Angeles", agency.Timezone) - assert.Equal(t, "en", agency.Language) - assert.Equal(t, "530-241-2877", agency.Phone) - assert.Equal(t, "", agency.FareUrl) - assert.Equal(t, "", agency.Email) - }) - } + agencies := manager.GetAgencies() + assert.Equal(t, 1, len(agencies)) + + agency := agencies[0] + assert.Equal(t, "25", agency.Id) + assert.Equal(t, "Redding Area Bus Authority", agency.Name) + assert.Equal(t, "http://www.rabaride.com/", agency.Url) + assert.Equal(t, "America/Los_Angeles", agency.Timezone) + assert.Equal(t, "en", agency.Language) + assert.Equal(t, "530-241-2877", agency.Phone) + assert.Equal(t, "", agency.FareUrl) + assert.Equal(t, "", agency.Email) } func TestManager_RoutesForAgencyID(t *testing.T) { - testCases := []struct { - name string - }{ - { - name: "FromLocalFile", - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - manager, _ := getSharedTestComponents(t) - assert.NotNil(t, manager) + manager, _ := getSharedTestComponents(t) + assert.NotNil(t, manager) - manager.RLock() - routes := manager.RoutesForAgencyID("25") - manager.RUnlock() - assert.Equal(t, 13, len(routes)) + manager.RLock() + routes := manager.RoutesForAgencyID("25") + manager.RUnlock() + assert.Equal(t, 13, len(routes)) - route := routes[0] - assert.Equal(t, "1", route.ShortName) - assert.Equal(t, "25", route.Agency.Id) - }) - } + route := routes[0] + assert.Equal(t, "1", route.ShortName) + assert.Equal(t, "25", route.Agency.Id) } func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { @@ -96,7 +74,6 @@ func TestManager_GetStopsForLocation_UsesSpatialIndex(t *testing.T) { } for _, tc := range testCases { - tc := tc // Capture loop variable t.Run(tc.name, func(t *testing.T) { manager, _ := getSharedTestComponents(t) assert.NotNil(t, manager) @@ -271,7 +248,6 @@ func TestManager_IsServiceActiveOnDate(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { assert.Equal(t, tc.weekday, tc.date.Weekday().String()) @@ -290,6 +266,7 @@ func TestManager_GetVehicleForTrip(t *testing.T) { GTFSDataPath: ":memory:", Env: appconf.Test, } + //We use isolated GTFSManager here instead of shared test components because we want to control the real-time vehicles for this test. manager, err := InitGTFSManager(gtfsConfig) assert.Nil(t, err) defer manager.Shutdown() diff --git a/internal/restapi/stops_for_route_handler.go b/internal/restapi/stops_for_route_handler.go index df431f69..9f15214c 100644 --- a/internal/restapi/stops_for_route_handler.go +++ b/internal/restapi/stops_for_route_handler.go @@ -100,7 +100,7 @@ func (api *RestAPI) stopsForRouteHandler(w http.ResponseWriter, r *http.Request) api.buildAndSendResponse(w, r, ctx, result, stopsList, currentAgency) } -func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, routeID string, serviceIDs []string, includePolylines bool, adc *GTFS.AdvancedDirectionCalculator) (models.RouteEntry, []models.Stop, error) { +func (api *RestAPI) processRouteStops(ctx context.Context, agencyID string, routeID string, serviceIDs []string, includePolylines bool) (models.RouteEntry, []models.Stop, error) { allStops := make(map[string]bool) allPolylines := make([]models.Polyline, 0, 100) var stopGroupings []models.StopGrouping diff --git a/internal/restapi/vehicles_for_agency_handler_test.go b/internal/restapi/vehicles_for_agency_handler_test.go index 477bcfec..5ac09fb3 100644 --- a/internal/restapi/vehicles_for_agency_handler_test.go +++ b/internal/restapi/vehicles_for_agency_handler_test.go @@ -1,6 +1,7 @@ package restapi import ( + "context" "net/http" "net/http/httptest" "os" @@ -340,6 +341,10 @@ func createTestApiWithRealTimeData(t *testing.T) (*RestAPI, func()) { gtfsManager, err := gtfs.InitGTFSManager(gtfsConfig) require.NoError(t, err) + dirCalc := gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries) + err = gtfs.InitializeGlobalCache(context.Background(), gtfsManager.GtfsDB.Queries, dirCalc) + require.NoError(t, err) + application := &app.Application{ Config: appconf.Config{ Env: appconf.EnvFlagToEnvironment("test"), @@ -348,6 +353,7 @@ func createTestApiWithRealTimeData(t *testing.T) (*RestAPI, func()) { }, GtfsConfig: gtfsConfig, GtfsManager: gtfsManager, + DirectionCalculator: dirCalc, Clock: clock.RealClock{}, } From d7343c26561beefb0a280e5387dbffc9fef8a7ae Mon Sep 17 00:00:00 2001 From: Vedanth Date: Mon, 23 Feb 2026 18:02:41 +0530 Subject: [PATCH 6/7] Refactor: clean up whitespace and improve code formatting in global cache and vehicles handler tests --- internal/gtfs/global_cache.go | 2 +- internal/restapi/vehicles_for_agency_handler_test.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/gtfs/global_cache.go b/internal/gtfs/global_cache.go index 485324f7..53a44d87 100644 --- a/internal/gtfs/global_cache.go +++ b/internal/gtfs/global_cache.go @@ -21,7 +21,7 @@ func InitializeGlobalCache(ctx context.Context, queries *gtfsdb.Queries, adc *Ad if err != nil { return fmt.Errorf("failed to fetch stop context rows: %w", err) } - + contextCache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow) shapeIDMap := make(map[string]bool) var uniqueShapeIDs []string diff --git a/internal/restapi/vehicles_for_agency_handler_test.go b/internal/restapi/vehicles_for_agency_handler_test.go index 5ac09fb3..7ac3d8bf 100644 --- a/internal/restapi/vehicles_for_agency_handler_test.go +++ b/internal/restapi/vehicles_for_agency_handler_test.go @@ -351,10 +351,10 @@ func createTestApiWithRealTimeData(t *testing.T) (*RestAPI, func()) { ApiKeys: []string{"TEST"}, RateLimit: 100, // Higher rate limit for this test }, - GtfsConfig: gtfsConfig, - GtfsManager: gtfsManager, + GtfsConfig: gtfsConfig, + GtfsManager: gtfsManager, DirectionCalculator: dirCalc, - Clock: clock.RealClock{}, + Clock: clock.RealClock{}, } api := NewRestAPI(application) From 990f11f29774e111624be6964c09f9ad6884b069 Mon Sep 17 00:00:00 2001 From: Vedanth Date: Mon, 23 Feb 2026 18:12:27 +0530 Subject: [PATCH 7/7] Refactor: improve error handling in scheduleForRouteHandler for agency and trip retrieval --- .../restapi/schedule_for_route_handler.go | 64 ++++++++++--------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/internal/restapi/schedule_for_route_handler.go b/internal/restapi/schedule_for_route_handler.go index e5a0bdde..f49da05b 100644 --- a/internal/restapi/schedule_for_route_handler.go +++ b/internal/restapi/schedule_for_route_handler.go @@ -182,22 +182,25 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque references := models.NewEmptyReferences() agency, err := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, agencyID) - if err == nil { - agencyModel := models.NewAgencyReference( - agency.ID, - agency.Name, - agency.Url, - agency.Timezone, - agency.Lang.String, - agency.Phone.String, - agency.Email.String, - agency.FareUrl.String, - "", - false, - ) - references.Agencies = append(references.Agencies, agencyModel) + if err != nil { + api.serverErrorResponse(w, r, err) + return } + agencyModel := models.NewAgencyReference( + agency.ID, + agency.Name, + agency.Url, + agency.Timezone, + agency.Lang.String, + agency.Phone.String, + agency.Email.String, + agency.FareUrl.String, + "", + false, + ) + references.Agencies = append(references.Agencies, agencyModel) + for _, r := range routeRefs { references.Routes = append(references.Routes, r) } @@ -208,21 +211,24 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque } if len(tripIDs) > 0 { tripRows, err := api.GtfsManager.GtfsDB.Queries.GetTripsByIDs(ctx, tripIDs) - if err == nil { - for _, t := range tripRows { - combinedTripID := utils.FormCombinedID(agencyID, t.ID) - tripRef := models.NewTripReference( - combinedTripID, - t.RouteID, - t.ServiceID, - t.TripHeadsign.String, - t.TripShortName.String, - t.DirectionID.Int64, - utils.FormCombinedID(agencyID, t.BlockID.String), - utils.FormCombinedID(agencyID, t.ShapeID.String), - ) - references.Trips = append(references.Trips, tripRef) - } + if err != nil { + api.serverErrorResponse(w, r, err) + return + } + + for _, t := range tripRows { + combinedTripID := utils.FormCombinedID(agencyID, t.ID) + tripRef := models.NewTripReference( + combinedTripID, + t.RouteID, + t.ServiceID, + t.TripHeadsign.String, + t.TripShortName.String, + t.DirectionID.Int64, + utils.FormCombinedID(agencyID, t.BlockID.String), + utils.FormCombinedID(agencyID, t.ShapeID.String), + ) + references.Trips = append(references.Trips, tripRef) } }