diff --git a/gateway/go.mod b/gateway/go.mod
index 62b1122..d8d08df 100644
--- a/gateway/go.mod
+++ b/gateway/go.mod
@@ -11,6 +11,7 @@ require (
 )
 
 require (
+	github.com/ProjectZKM/Ziren/crates/go-runtime/zkvm_runtime v0.0.0-20251001021608-1fe7b43fc4d6 // indirect
 	github.com/bytedance/sonic v1.14.0 // indirect
 	github.com/bytedance/sonic/loader v0.3.0 // indirect
 	github.com/cloudwego/base64x v0.1.6 // indirect
diff --git a/gateway/main.go b/gateway/main.go
index a2d68be..4a5a610 100644
--- a/gateway/main.go
+++ b/gateway/main.go
@@ -20,6 +20,8 @@ import (
 	"strings"
 	"sync"
 	"time"
+	"os/signal"
+	"syscall"
 
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/gin-contrib/cors"
@@ -157,6 +159,7 @@ func main() {
 	// deadline; the middleware implementation always uses the earliest
 	// deadline when nested timeouts are present to avoid surprising behavior.
 	r.Use(RequestTimeoutMiddleware(getRequestTimeout()))
+	r.Use(TrackInFlightRequests())
 
 	// Health check with shorter timeout (2s)
 	r.GET("/healthz", RequestTimeoutMiddleware(getHealthCheckTimeout()), handleHealth)
@@ -187,8 +190,61 @@ func main() {
 		port = "3000"
 	}
 
-	log.Printf("Go Gateway running on port %s", port)
-	r.Run(":" + port)
+	addr := ":" + port
+
+	// Explicit http.Server (instead of r.Run) so timeouts can be set and the
+	// process can drain connections on shutdown.
+	srv := &http.Server{
+		Addr:              addr,
+		Handler:           r,
+		ReadHeaderTimeout: 5 * time.Second,
+		ReadTimeout:       10 * time.Second,
+		WriteTimeout:      60 * time.Second,
+		IdleTimeout:       120 * time.Second,
+	}
+
+	go func() {
+		log.Printf("[INFO] Gateway listening on %s", addr)
+		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+			log.Fatalf("[FATAL] listen error: %v", err)
+		}
+	}()
+
+	// ---- Graceful shutdown ----
+	quit := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+
+	<-quit
+	log.Println("[INFO] Shutdown signal received, draining connections...")
+
+	// One deadline bounds the whole drain so a stuck request cannot block exit.
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	if active := GetActiveRequestCount(); active > 0 {
+		log.Printf("[INFO] Waiting for %d in-flight request(s)...", active)
+	}
+
+	// Stop accepting connections before waiting on the tracker; waiting first
+	// lets new requests keep arriving, so the wait might never finish.
+	if err := srv.Shutdown(ctx); err != nil {
+		log.Printf("[ERROR] Server forced to shutdown: %v", err)
+		return
+	}
+
+	drained := make(chan struct{})
+	go func() {
+		WaitForInFlightRequests()
+		close(drained)
+	}()
+
+	select {
+	case <-drained:
+		log.Println("[INFO] All in-flight requests completed")
+		log.Println("[OK] Server shutdown completed")
+	case <-ctx.Done():
+		log.Printf("[ERROR] Server forced to shutdown: %v", ctx.Err())
+	}
 }
 
 // handleSummarize handles POST /api/ai/summarize requests. It validates
diff --git a/gateway/request_tracker.go b/gateway/request_tracker.go
new file mode 100644
index 0000000..ad2c723
--- /dev/null
+++ b/gateway/request_tracker.go
@@ -0,0 +1,38 @@
+package main
+
+import (
+	"sync"
+	"sync/atomic"
+
+	"github.com/gin-gonic/gin"
+)
+
+var (
+	activeRequestsWG sync.WaitGroup
+	activeRequestCnt int64
+)
+
+// TrackInFlightRequests tracks active HTTP requests.
+func TrackInFlightRequests() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		activeRequestsWG.Add(1)
+		atomic.AddInt64(&activeRequestCnt, 1)
+
+		defer func() {
+			atomic.AddInt64(&activeRequestCnt, -1)
+			activeRequestsWG.Done()
+		}()
+
+		c.Next()
+	}
+}
+
+// WaitForInFlightRequests blocks until all active requests finish.
+func WaitForInFlightRequests() {
+	activeRequestsWG.Wait()
+}
+
+// GetActiveRequestCount returns the current number of active requests.
+func GetActiveRequestCount() int64 {
+	return atomic.LoadInt64(&activeRequestCnt)
+}
diff --git a/gateway/shutdown_test.go b/gateway/shutdown_test.go
new file mode 100644
index 0000000..a121dde
--- /dev/null
+++ b/gateway/shutdown_test.go
@@ -0,0 +1,86 @@
+package main
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+// TestGracefulShutdown_WaitsForInFlightRequests checks that shutting the
+// server down does not cut off a request that is still being handled.
+func TestGracefulShutdown_WaitsForInFlightRequests(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	r := gin.New()
+	r.Use(TrackInFlightRequests())
+
+	// started/finished mark the handler's lifetime so the test can
+	// synchronize on events instead of wall-clock guesses.
+	started := make(chan struct{})
+	finished := make(chan struct{})
+
+	// Simulate slow handler
+	r.GET("/slow", func(c *gin.Context) {
+		close(started)
+		time.Sleep(200 * time.Millisecond)
+		c.Status(http.StatusOK)
+		close(finished)
+	})
+
+	srv := &http.Server{
+		Handler: r,
+	}
+
+	// Start test server
+	ln := httptest.NewUnstartedServer(r)
+	ln.Config = srv
+	ln.Start()
+	defer ln.Close()
+
+	// Make request in background
+	done := make(chan struct{})
+	go func() {
+		// Close done on every path so a failed request cannot hang the test.
+		defer close(done)
+		resp, err := http.Get(ln.URL + "/slow")
+		if err != nil {
+			t.Errorf("request failed: %v", err)
+			return
+		}
+		resp.Body.Close()
+	}()
+
+	// Wait until the handler is actually running before shutting down.
+	select {
+	case <-started:
+	case <-done:
+		t.Fatal("request ended before the handler started")
+	case <-time.After(2 * time.Second):
+		t.Fatal("handler did not start in time")
+	}
+
+	// Shutdown server
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	if err := srv.Shutdown(ctx); err != nil {
+		t.Fatalf("shutdown failed: %v", err)
+	}
+	WaitForInFlightRequests()
+
+	// Assert shutdown waited for request
+	select {
+	case <-finished:
+	default:
+		t.Fatalf("shutdown did not wait for in-flight request")
+	}
+	if n := GetActiveRequestCount(); n != 0 {
+		t.Fatalf("active request count after shutdown = %d, want 0", n)
+	}
+
+	<-done
+}