diff --git a/README.md b/README.md index 78e0fc0..2a62d9f 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,47 @@ Manners ensures that all requests are served by incrementing a WaitGroup when a If your request handler spawns Goroutines that are not guaranteed to finish with the request, you can ensure they are also completed with the `StartRoutine` and `FinishRoutine` functions on the server. +### HTTP, HTTPS and FCGI + +Manners supports three protocols: HTTP, HTTPS and FCGI. HTTP is illustrated above. +For HTTPS, Manners can likewise act as a drop-in replacement for the standard library's +[http.ListenAndServeTLS](http://golang.org/pkg/net/http/#ListenAndServeTLS) function: + +```go +func main() { + handler := MyHTTPHandler() + certFile := MyCertificate() + keyFile := MyKeyFile() + manners.ListenAndServeTLS(":https", certFile, keyFile, handler) +} +``` + +In Manners, FCGI only operates via local a Unix socket connected to a co-hosted proxy, such as Apache or Nginx. + +```go +func main() { + handler := MyHTTPHandler() + manners.ListenAndServe("/var/run/goserver.sock", handler) +} +``` + +To use FCGI, the port string must specify the Unix socket and start with a slash or dot, as in the example above. In this case, Manners will use [fcgi.Serve](http://golang.org/pkg/net/http/fcgi/#Serve). + +In each of the protocols, Manners drains down the connections cleanly when `manners.Close()` is called. + +### Handling signals + +It's good to close down the server cleanly when OS signals are received. This is easy: just add + +```go +manners.CloseOnInterrupt() +``` +before the `ListenAndServe` call. This kicks off a separate goroutine to wait for an OS signal, upon which it simply calls `manners.Close()` for you. Optionally, you can pass in a list of the particular signals you care about and you can find out which signal was received, if any, afterwards. + +### Known Issues + +Manners does not correctly shut down long-lived keepalive connections when issued a shutdown command. Clients on an idle keepalive connection may see a connection reset error rather than a close. See https://github.com/braintree/manners/issues/13 for details. + ### Compatability Manners 0.3.0 and above uses standard library functionality introduced in Go 1.3. diff --git a/helpers_test.go b/helpers_test.go index bde3703..d176abe 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -9,21 +9,21 @@ import ( "testing" ) -func newServer() *GracefulServer { - return NewWithServer(new(http.Server)) -} - // a simple step-controllable http client type client struct { tls bool addr net.Addr connected chan error sendrequest chan bool - idle chan error - idlerelease chan bool + response chan *rawResponse closed chan bool } +type rawResponse struct { + body []string + err error +} + func (c *client) Run() { go func() { var err error @@ -39,19 +39,21 @@ func (c *client) Run() { for <-c.sendrequest { _, err = conn.Write([]byte("GET / HTTP/1.1\nHost: localhost:8000\n\n")) if err != nil { - c.idle <- err + c.response <- &rawResponse{err: err} } // Read response; no content scanner := bufio.NewScanner(conn) + var lines []string for scanner.Scan() { // our null handler doesn't send a body, so we know the request is // done when we reach the blank line after the headers - if scanner.Text() == "" { + line := scanner.Text() + if line == "" { break } + lines = append(lines, line) } - c.idle <- scanner.Err() - <-c.idlerelease + c.response <- &rawResponse{lines, scanner.Err()} } conn.Close() ioutil.ReadAll(conn) @@ -65,8 +67,7 @@ func newClient(addr net.Addr, tls bool) *client { tls: tls, connected: make(chan error), sendrequest: make(chan bool), - idle: make(chan error), - idlerelease: make(chan bool), + response: make(chan *rawResponse), closed: make(chan bool), } } @@ -81,7 +82,7 @@ func startGenericServer(t *testing.T, server *GracefulServer, statechanged chan // Wrap the ConnState handler with something that will notify // the statechanged channel when a state change happens server.ConnState = func(conn net.Conn, newState http.ConnState) { - statechanged <- newState + statechanged <- conn.LocalAddr().(*gracefulAddr).gconn.lastHTTPState } } diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..ccb1c9f --- /dev/null +++ b/listener.go @@ -0,0 +1,190 @@ +package manners + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "os" + "sync" + "time" +) + +// NewListener wraps an existing listener for use with +// GracefulServer. +// +// Note that you generally don't need to use this directly as +// GracefulServer will automatically wrap any non-graceful listeners +// supplied to it. +func NewListener(l net.Listener) *GracefulListener { + return &GracefulListener{ + listener: l, + mutex: &sync.RWMutex{}, + open: true, + } +} + +// A gracefulCon wraps a normal net.Conn and tracks the last known http state. +type gracefulConn struct { + net.Conn + lastHTTPState http.ConnState + // protected tells whether the connection is going to defer server shutdown + // until the current HTTP request is completed. + protected bool +} + +type gracefulAddr struct { + net.Addr + gconn *gracefulConn +} + +func (g *gracefulConn) LocalAddr() net.Addr { + return &gracefulAddr{g.Conn.LocalAddr(), g} +} + +// retrieveGracefulConn retrieves a concrete gracefulConn instance from an +// interface value that can either refer to it directly or refer to a tls.Conn +// instance wrapping around a gracefulConn one. +func retrieveGracefulConn(conn net.Conn) *gracefulConn { + return conn.LocalAddr().(*gracefulAddr).gconn +} + +// A GracefulListener differs from a standard net.Listener in one way: if +// Accept() is called after it is gracefully closed, it returns a +// listenerAlreadyClosed error. The GracefulServer will ignore this error. +type GracefulListener struct { + listener net.Listener + open bool + mutex *sync.RWMutex +} + +func (l *GracefulListener) isClosed() bool { + l.mutex.RLock() + defer l.mutex.RUnlock() + return !l.open +} + +func (l *GracefulListener) Addr() net.Addr { + return l.listener.Addr() +} + +// Accept implements the Accept method in the Listener interface. +func (l *GracefulListener) Accept() (net.Conn, error) { + conn, err := l.listener.Accept() + if err != nil { + if l.isClosed() { + err = listenerAlreadyClosed{err} + } + return nil, err + } + + // don't wrap connection if it's tls so we won't break + // http server internal logic that relies on the type + if _, ok := conn.(*tls.Conn); ok { + return conn, nil + } + return &gracefulConn{Conn: conn}, nil +} + +// Close tells the wrapped listener to stop listening. It is idempotent. +func (l *GracefulListener) Close() error { + l.mutex.Lock() + defer l.mutex.Unlock() + if !l.open { + return nil + } + l.open = false + return l.listener.Close() +} + +func (l *GracefulListener) GetFile() (*os.File, error) { + return getListenerFile(l.listener) +} + +func (l *GracefulListener) Clone() (net.Listener, error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + if !l.open { + return nil, fmt.Errorf("listener is already closed") + } + + file, err := l.GetFile() + if err != nil { + return nil, err + } + defer file.Close() + + fl, err := net.FileListener(file) + if nil != err { + return nil, err + } + return fl, nil +} + +// A listener implements a network listener (net.Listener) for TLS connections. +// direct lift from crypto/tls.go +type TLSListener struct { + net.Listener + config *tls.Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *TLSListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + c = tls.Server(&gracefulConn{Conn: c}, l.config) + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must have +// at least one certificate. +func NewTLSListener(inner net.Listener, config *tls.Config) net.Listener { + l := new(TLSListener) + l.Listener = inner + l.config = config + return l +} + +type listenerAlreadyClosed struct { + error +} + +// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +// +// direct lift from net/http/server.go +type TCPKeepAliveListener struct { + *net.TCPListener +} + +func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +func getListenerFile(listener net.Listener) (*os.File, error) { + switch t := listener.(type) { + case *net.TCPListener: + return t.File() + case *net.UnixListener: + return t.File() + case TCPKeepAliveListener: + return t.TCPListener.File() + case *TLSListener: + return getListenerFile(t.Listener) + } + return nil, fmt.Errorf("Unsupported listener: %T", listener) +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..7de7041 --- /dev/null +++ b/logger.go @@ -0,0 +1,17 @@ +package manners + +import ( + "io/ioutil" + "log" +) + +var logger = log.New(ioutil.Discard, "", 0) + +// SetLogger changes the logger used for the startup and shutdown messages +// generated by Manners. By default, no log messages are emitted. +// To make Manners logging behave the same as per the standard +// log package, i.e. to stderr, use +// `SetLogger(log.New(os.Stderr, "", log.LstdFlags))` +func SetLogger(l *log.Logger) { + logger = l +} diff --git a/server.go b/server.go index 869fe05..c32022a 100644 --- a/server.go +++ b/server.go @@ -27,14 +27,7 @@ or for a customized server: The server will shut down cleanly when the Close() method is called: - go func() { - sigchan := make(chan os.Signal, 1) - signal.Notify(sigchan, os.Interrupt, os.Kill) - <-sigchan - log.Info("Shutting down...") - manners.Close() - }() - + manners.CloseOnInterrupt() http.Handle("/hello", myHandler) log.Fatal(manners.ListenAndServe(":8080", nil)) */ @@ -42,12 +35,28 @@ package manners import ( "crypto/tls" + "fmt" "net" "net/http" + "net/http/fcgi" + "os" + "os/signal" + "strings" "sync" "sync/atomic" + "syscall" ) +// StateHandler can be called by the server if the state of the connection changes. +// Notice that it passed previous state and the new state as parameters. +type StateHandler func(net.Conn, http.ConnState, http.ConnState) + +type Options struct { + Server *http.Server + StateHandler StateHandler + Listener net.Listener +} + // A GracefulServer maintains a WaitGroup that counts how many in-flight // requests the server is handling. When it receives a shutdown signal, // it stops accepting new requests but does not actually shut down until @@ -56,48 +65,121 @@ import ( // GracefulServer embeds the underlying net/http.Server making its non-override // methods and properties avaiable. // -// It must be initialized by calling NewWithServer. +// It must be initialized by calling NewServer or NewWithServer type GracefulServer struct { *http.Server - shutdown chan bool - wg waitGroup - - lcsmu sync.RWMutex - lastConnState map[net.Conn]http.ConnState + shutdown chan bool + shutdownFinished chan bool + wg waitGroup + listener *GracefulListener + stateHandler StateHandler up chan net.Listener // Only used by test code. + + signal os.Signal +} + +// NewServer creates a new GracefulServer. +func NewServer() *GracefulServer { + return NewWithServer(new(http.Server)) } // NewWithServer wraps an existing http.Server object and returns a // GracefulServer that supports all of the original Server operations. func NewWithServer(s *http.Server) *GracefulServer { return &GracefulServer{ - Server: s, - shutdown: make(chan bool), - wg: new(sync.WaitGroup), - lastConnState: make(map[net.Conn]http.ConnState), + Server: s, + shutdown: make(chan bool), + shutdownFinished: make(chan bool, 1), + wg: new(sync.WaitGroup), + } +} + +func NewWithOptions(o Options) *GracefulServer { + // Set up listener + var listener *GracefulListener + if o.Listener != nil { + g, ok := o.Listener.(*GracefulListener) + if !ok { + listener = NewListener(o.Listener) + } else { + listener = g + } + } + + return &GracefulServer{ + listener: listener, + Server: o.Server, + stateHandler: o.StateHandler, + shutdown: make(chan bool), + shutdownFinished: make(chan bool, 1), + wg: new(sync.WaitGroup), } } // Close stops the server from accepting new requets and begins shutting down. // It returns true if it's the first time Close is called. func (s *GracefulServer) Close() bool { + logger.Printf("Shutting down server on %s\n", s.Server.Addr) return <-s.shutdown } -// ListenAndServe provides a graceful equivalent of net/http.Serve.ListenAndServe. -func (s *GracefulServer) ListenAndServe() error { - addr := s.Addr - if addr == "" { - addr = ":http" +// BlockingClose is similar to Close, except that it blocks until the last +// connection has been closed. +func (s *GracefulServer) BlockingClose() bool { + logger.Printf("Shutting down server on %s (blocking)\n", s.Server.Addr) + result := s.Close() + <-s.shutdownFinished + return result +} + +func isUnixNetwork(addr string) bool { + return strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, ".") +} + +func listenToUnix(bind string) (listener net.Listener, err error) { + _, err = os.Stat(bind) + if err == nil { + // socket exists and is "already in use"; + // presume this is from earlier run and therefore delete it + err = os.Remove(bind) + if err != nil { + return + } + } else if !os.IsNotExist(err) { + return } - listener, err := net.Listen("tcp", addr) - if err != nil { - return err + listener, err = net.Listen("unix", bind) + return +} + +func listen(bind string) (listener net.Listener, err error) { + if isUnixNetwork(bind) { + logger.Printf("Listening on unix socket %s\n", bind) + return listenToUnix(bind) + } else if strings.Contains(bind, ":") { + logger.Printf("Listening on tcp socket %s\n", bind) + return net.Listen("tcp", bind) + } else { + return nil, fmt.Errorf("error while parsing bind arg %v", bind) } +} - return s.Serve(listener) +// ListenAndServe provides a graceful equivalent of net/http.Serve.ListenAndServe. +func (s *GracefulServer) ListenAndServe() error { + if s.listener == nil { + addr := s.Addr + if addr == "" { + addr = ":http" + } + oldListener, err := listen(addr) + if err != nil { + return err + } + s.listener = NewListener(oldListener) + } + return s.Serve(s.listener) } // ListenAndServeTLS provides a graceful equivalent of net/http.Serve.ListenAndServeTLS. @@ -122,74 +204,120 @@ func (s *GracefulServer) ListenAndServeTLS(certFile, keyFile string) error { return err } - ln, err := net.Listen("tcp", addr) + return s.ListenAndServeTLSWithConfig(config) +} + +// ListenAndServeTLSWithConfig provides a graceful equivalent of net/http.Serve.ListenAndServeTLS +// using a bespoke TLS config. +func (s *GracefulServer) ListenAndServeTLSWithConfig(config *tls.Config) error { + addr := s.Addr + if addr == "" { + addr = ":https" + } + + if s.listener == nil { + logger.Printf("Listening on tcp socket %s\n", addr) + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := NewTLSListener(TCPKeepAliveListener{ln.(*net.TCPListener)}, config) + s.listener = NewListener(tlsListener) + } + return s.Serve(s.listener) +} + +func (gs *GracefulServer) GetFile() (*os.File, error) { + return gs.listener.GetFile() +} + +func (gs *GracefulServer) HijackListener(s *http.Server, config *tls.Config) (*GracefulServer, error) { + listener, err := gs.listener.Clone() if err != nil { - return err + return nil, err + } + + if config != nil { + listener = NewTLSListener(TCPKeepAliveListener{listener.(*net.TCPListener)}, config) } - return s.Serve(tls.NewListener(ln, config)) + other := NewWithServer(s) + other.listener = NewListener(listener) + return other, nil } // Serve provides a graceful equivalent net/http.Server.Serve. +// +// If listener is not an instance of *GracefulListener it will be wrapped +// to become one. func (s *GracefulServer) Serve(listener net.Listener) error { - var closing int32 + // Accept a net.Listener to preserve the interface compatibility with the + // standard http.Server. If it is not a GracefulListener then wrap it into + // one. + gracefulListener, ok := listener.(*GracefulListener) + if !ok { + gracefulListener = NewListener(listener) + listener = gracefulListener + } + s.listener = gracefulListener + + // Wrap the server HTTP handler into graceful one, that will close kept + // alive connections if a new request is received after shutdown. + gracefulHandler := newGracefulHandler(s.Server.Handler) + s.Server.Handler = gracefulHandler + // Start a goroutine that waits for a shutdown signal and will stop the + // listener when it receives the signal. That in turn will result in + // unblocking of the http.Serve call. go func() { s.shutdown <- true close(s.shutdown) - atomic.StoreInt32(&closing, 1) + gracefulHandler.Close() s.Server.SetKeepAlivesEnabled(false) - listener.Close() + gracefulListener.Close() }() originalConnState := s.Server.ConnState - // s.ConnState is invoked by the net/http.Server every time a connectiion + // s.ConnState is invoked by the net/http.Server every time a connection // changes state. It keeps track of each connection's state over time, // enabling manners to handle persisted connections correctly. s.ConnState = func(conn net.Conn, newState http.ConnState) { - s.lcsmu.RLock() - lastConnState := s.lastConnState[conn] - s.lcsmu.RUnlock() + gracefulConn := retrieveGracefulConn(conn) + oldState := gracefulConn.lastHTTPState + gracefulConn.lastHTTPState = newState switch newState { - // New connection -> StateNew case http.StateNew: + // New connection -> StateNew + gracefulConn.protected = true s.StartRoutine() - // (StateNew, StateIdle) -> StateActive case http.StateActive: - // The connection transitioned from idle back to active - if lastConnState == http.StateIdle { - s.StartRoutine() + // (StateNew, StateIdle) -> StateActive + if gracefulHandler.IsClosed() { + gracefulConn.Close() + break } - // StateActive -> StateIdle - // Immediately close newly idle connections; if not they may make - // one more request before SetKeepAliveEnabled(false) takes effect. - case http.StateIdle: - if atomic.LoadInt32(&closing) == 1 { - conn.Close() + if !gracefulConn.protected { + gracefulConn.protected = true + s.StartRoutine() } - s.FinishRoutine() - // (StateNew, StateActive, StateIdle) -> (StateClosed, StateHiJacked) - // If the connection was idle we do not need to decrement the counter. - case http.StateClosed, http.StateHijacked: - if lastConnState != http.StateIdle { + default: + // (StateNew, StateActive) -> (StateIdle, StateClosed, StateHiJacked) + if gracefulConn.protected { s.FinishRoutine() + gracefulConn.protected = false } - } - s.lcsmu.Lock() - if newState == http.StateClosed || newState == http.StateHijacked { - delete(s.lastConnState, conn) - } else { - s.lastConnState[conn] = newState + if s.stateHandler != nil { + s.stateHandler(conn, oldState, newState) } - s.lcsmu.Unlock() if originalConnState != nil { originalConnState(conn, newState) @@ -202,15 +330,22 @@ func (s *GracefulServer) Serve(listener net.Listener) error { s.up <- listener } - err := s.Server.Serve(listener) + var err error + if isUnixNetwork(s.Server.Addr) { + os.Chmod(s.Server.Addr, os.ModePerm) + err = fcgi.Serve(listener, s.Server.Handler) + } else { + err = s.Server.Serve(listener) + } - // This block is reached when the server has received a shut down command - // or a real error happened. - if err == nil || atomic.LoadInt32(&closing) == 1 { - s.wg.Wait() - return nil + // An error returned on shutdown is not worth reporting. + if _, ok = err.(listenerAlreadyClosed); ok { + err = nil } + // Wait for pending requests to complete regardless the Serve result. + s.wg.Wait() + s.shutdownFinished <- true return err } @@ -226,3 +361,68 @@ func (s *GracefulServer) StartRoutine() { func (s *GracefulServer) FinishRoutine() { s.wg.Done() } + +// CloseOnInterrupt creates a go-routine that will call the Close() function when certain OS +// signals are received. If no signals are specified, +// the following are used: SIGINT, SIGTERM, SIGKILL, SIGQUIT, SIGHUP, SIGUSR1. +// This function must be called before ListenAndServe, ListenAndServeTLS, or Serve. +func (s *GracefulServer) CloseOnInterrupt(signals ...os.Signal) *GracefulServer { + if s == nil { + panic("Program error: the server must exist before this method is called.") + } + go func(rx *GracefulServer) { + sigchan := make(chan os.Signal, 1) + if len(signals) > 0 { + signal.Notify(sigchan, signals...) + } else { + signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, + syscall.SIGQUIT, syscall.SIGHUP, syscall.SIGUSR1) + } + rx.signal = <-sigchan + rx.Close() + }(s) + return s +} + +// SignalReceived gets the signal that caused the server to close, if any. If Close() was called +// some other way, this method will return nil. +// +// Note that, by convention, SIGUSR1 is often used to cause a server to close all its current +// connections cleanly, close its log files, and then restart. This facilitates log rotation. +// If you need this behaviour, you will need to provide a loop around both the CloseOnInterrupt and +// ListenAndServe calls. +func (s *GracefulServer) SignalReceived() os.Signal { + return s.signal +} + +// gracefulHandler is used by GracefulServer to prevent calling ServeHTTP on +// to be closed kept-alive connections during the server shutdown. +type gracefulHandler struct { + closed int32 // accessed atomically. + wrapped http.Handler +} + +func newGracefulHandler(wrapped http.Handler) *gracefulHandler { + return &gracefulHandler{ + wrapped: wrapped, + } +} + +func (gh *gracefulHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if atomic.LoadInt32(&gh.closed) == 0 { + gh.wrapped.ServeHTTP(w, r) + return + } + r.Body.Close() + // Server is shutting down at this moment, and the connection that this + // handler is being called on is about to be closed. So we do not need to + // actually execute the handler logic. +} + +func (gh *gracefulHandler) Close() { + atomic.StoreInt32(&gh.closed, 1) +} + +func (gh *gracefulHandler) IsClosed() bool { + return atomic.LoadInt32(&gh.closed) == 1 +} diff --git a/server_test.go b/server_test.go index 2f54eaf..4d96760 100644 --- a/server_test.go +++ b/server_test.go @@ -1,17 +1,37 @@ package manners import ( - helpers "github.com/braintree/manners/test_helpers" + helpers "github.com/rickb777/manners/test_helpers" "net" "net/http" + "os" "testing" "time" ) +type httpInterface interface { + ListenAndServe() error + ListenAndServeTLS(certFile, keyFile string) error + Serve(listener net.Listener) error +} + +// Test that the method signatures of the methods we override from net/http/Server match those of the original. +func TestInterface(t *testing.T) { + var original, ours interface{} + original = &http.Server{} + ours = &GracefulServer{} + if _, ok := original.(httpInterface); !ok { + t.Errorf("httpInterface definition does not match the canonical server!") + } + if _, ok := ours.(httpInterface); !ok { + t.Errorf("GracefulServer does not implement httpInterface") + } +} + // Tests that the server allows in-flight requests to complete // before shutting down. func TestGracefulness(t *testing.T) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() server.wg = wg statechanged := make(chan http.ConnState) @@ -24,10 +44,9 @@ func TestGracefulness(t *testing.T) { if err := <-client.connected; err != nil { t.Fatal("Client failed to connect to server", err) } - // avoid a race between the client connection and the server accept - if state := <-statechanged; state != http.StateNew { - t.Fatal("Unexpected state", state) - } + // Even though the client is connected, the server ConnState handler may + // not know about that yet. So wait until it is called. + waitForState(t, statechanged, http.StateNew, "Request not received") server.Close() @@ -48,7 +67,7 @@ func TestGracefulness(t *testing.T) { // Tests that the server begins to shut down when told to and does not accept // new requests once shutdown has begun func TestShutdown(t *testing.T) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() server.wg = wg statechanged := make(chan http.ConnState) @@ -61,10 +80,9 @@ func TestShutdown(t *testing.T) { if err := <-client1.connected; err != nil { t.Fatal("Client failed to connect to server", err) } - // avoid a race between the client connection and the server accept - if state := <-statechanged; state != http.StateNew { - t.Fatal("Unexpected state", state) - } + // Even though the client is connected, the server ConnState handler may + // not know about that yet. So wait until it is called. + waitForState(t, statechanged, http.StateNew, "Request not received") // start the shutdown; once it hits waitgroup.Wait() // the listener should of been closed, though client1 is still connected @@ -94,36 +112,32 @@ func TestShutdown(t *testing.T) { <-exitchan } -// Test that a connection is closed upon reaching an idle state if and only if the server -// is shutting down. -func TestCloseOnIdle(t *testing.T) { - server := newServer() - wg := helpers.NewWaitGroup() - server.wg = wg - fl := helpers.NewListener() - runner := func() error { - return server.Serve(fl) - } - - startGenericServer(t, server, nil, runner) +// If a request is sent to a closed server via a kept alive connection then +// the server closes the connection upon receiving the request. +func TestRequestAfterClose(t *testing.T) { + // Given + server := NewServer() + srvStateChangedCh := make(chan http.ConnState, 100) + listener, srvClosedCh := startServer(t, server, srvStateChangedCh) - // Change to idle state while server is not closing; Close should not be called - conn := &helpers.Conn{} - server.ConnState(conn, http.StateIdle) - if conn.CloseCalled { - t.Error("Close was called unexpected") - } + client := newClient(listener.Addr(), false) + client.Run() + <-client.connected + client.sendrequest <- true + <-client.response server.Close() + if err := <-srvClosedCh; err != nil { + t.Error("Unexpected error during shutdown", err) + } - // wait until the server calls Close() on the listener - // by that point the atomic closing variable will have been updated, avoiding a race. - <-fl.CloseCalled + // When + client.sendrequest <- true + rr := <-client.response - conn = &helpers.Conn{} - server.ConnState(conn, http.StateIdle) - if !conn.CloseCalled { - t.Error("Close was not called") + // Then + if rr.body != nil || rr.err != nil { + t.Errorf("Request should be rejected, body=%v, err=%v", rr.body, rr.err) } } @@ -143,7 +157,7 @@ func waitForState(t *testing.T, waiter chan http.ConnState, state http.ConnState // Test that a request moving from active->idle->active using an actual // network connection still results in a corect shutdown func TestStateTransitionActiveIdleActive(t *testing.T) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() statechanged := make(chan http.ConnState) server.wg = wg @@ -160,8 +174,7 @@ func TestStateTransitionActiveIdleActive(t *testing.T) { for i := 0; i < 2; i++ { client.sendrequest <- true waitForState(t, statechanged, http.StateActive, "Client failed to reach active state") - <-client.idle - client.idlerelease <- true + <-client.response waitForState(t, statechanged, http.StateIdle, "Client failed to reach idle state") } @@ -196,7 +209,7 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { } for _, withTLS := range []bool{false, true} { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() statechanged := make(chan http.ConnState) server.wg = wg @@ -217,12 +230,11 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { client.sendrequest <- true waitForState(t, statechanged, http.StateActive, "Client failed to reach active state") - err := <-client.idle - if err != nil { - t.Fatalf("tls=%t unexpected error from client %s", withTLS, err) + rr := <-client.response + if rr.err != nil { + t.Fatalf("tls=%t unexpected error from client %s", withTLS, rr.err) } - client.idlerelease <- true waitForState(t, statechanged, http.StateIdle, "Client failed to reach idle state") // client is now in an idle state @@ -241,3 +253,171 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { } } } + +// Test that supplying a non GracefulListener to Serve works +// correctly (ie. that the listener is wrapped to become graceful) +func TestWrapConnectionTcp(t *testing.T) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Failed to create listener", err) + } + + s := NewServer() + s.up = make(chan net.Listener) + + var called bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + s.Close() // clean shutdown as soon as handler exits + }) + s.Handler = handler + + serverr := make(chan error) + + go func() { + serverr <- s.Serve(l) + }() + + gl := <-s.up + if _, ok := gl.(*GracefulListener); !ok { + t.Fatal("connection was not wrapped into a GracefulListener") + } + + addr := l.Addr() + if _, err := http.Get("http://" + addr.String()); err != nil { + t.Fatal("Get failed", err) + } + + if err := <-serverr; err != nil { + t.Fatal("Error from Serve()", err) + } + + if !called { + t.Error("Handler was not called") + } +} + +func TestWrapConnectionUnix(t *testing.T) { + l, err := listenToUnix("/var/tmp/servertest") + if err != nil { + t.Fatal("Failed to create listener", err) + } + defer os.Remove("/var/tmp/servertest") + + _, err = os.Stat("/var/tmp/servertest") + if err != nil { + t.Fatal("Failed to create listener", err) + } + + s := NewServer() + s.up = make(chan net.Listener) + + var called bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + s.Close() // clean shutdown as soon as handler exits + }) + s.Handler = handler + + serverr := make(chan error) + + go func() { + serverr <- s.Serve(l) + }() + + gl := <-s.up + if _, ok := gl.(*GracefulListener); !ok { + t.Fatal("connection was not wrapped into a GracefulListener") + } + + //addr := l.Addr() + //if _, err := http.Get("http://" + addr.String()); err != nil { + // t.Fatal("Get failed", err) + //} + // + //if err := <-serverr; err != nil { + // t.Fatal("Error from Serve()", err) + //} + // + //if !called { + // t.Error("Handler was not called") + //} +} + +// Hijack listener +func TestHijackListener(t *testing.T) { + server := NewServer() + wg := helpers.NewWaitGroup() + server.wg = wg + listener, exitchan := startServer(t, server, nil) + + client := newClient(listener.Addr(), false) + client.Run() + + // wait for client to connect, but don't let it send the request yet + if err := <-client.connected; err != nil { + t.Fatal("Client failed to connect to server", err) + } + + // Make sure server1 got the request and added it to the waiting group + <-wg.CountChanged + + wg2 := helpers.NewWaitGroup() + server2, err := server.HijackListener(new(http.Server), nil) + server2.wg = wg2 + if err != nil { + t.Fatal("Failed to hijack listener", err) + } + + listener2, exitchan2 := startServer(t, server2, nil) + + // Close the first server + server.Close() + + // First server waits for the first request to finish + waiting := <-wg.WaitCalled + if waiting < 1 { + t.Errorf("Expected the waitgroup to equal 1 at shutdown; actually %d", waiting) + } + + // allow the client to finish sending the request and make sure the server exits after + // (client will be in connected but idle state at that point) + client.sendrequest <- true + close(client.sendrequest) + if err := <-exitchan; err != nil { + t.Error("Unexpected error during shutdown", err) + } + + client2 := newClient(listener2.Addr(), false) + client2.Run() + + // wait for client to connect, but don't let it send the request yet + select { + case err := <-client2.connected: + if err != nil { + t.Fatal("Client failed to connect to server", err) + } + case <-time.After(time.Second): + t.Fatal("Timeout connecting to the server", err) + } + + // Close the second server + server2.Close() + + waiting = <-wg2.WaitCalled + if waiting < 1 { + t.Errorf("Expected the waitgroup to equal 1 at shutdown; actually %d", waiting) + } + + // allow the client to finish sending the request and make sure the server exits after + // (client will be in connected but idle state at that point) + client2.sendrequest <- true + // Make sure that request resulted in success + if rr := <-client2.response; rr.err != nil { + t.Errorf("Client failed to write the request, error: %s", err) + } + close(client2.sendrequest) + if err := <-exitchan2; err != nil { + t.Error("Unexpected error during shutdown", err) + } +} diff --git a/static.go b/static.go index 2a74b09..726c941 100644 --- a/static.go +++ b/static.go @@ -3,33 +3,78 @@ package manners import ( "net" "net/http" + "os" ) -var defaultServer *GracefulServer +var ( + defaultServer *GracefulServer + defaultSignals []os.Signal + hasSignals = false +) + +func preventReEntrance() { + if defaultServer != nil { + panic("Program error: the default server must be closed before re-use.") + } +} // ListenAndServe provides a graceful version of the function provided by the // net/http package. Call Close() to stop the server. func ListenAndServe(addr string, handler http.Handler) error { + preventReEntrance() defaultServer = NewWithServer(&http.Server{Addr: addr, Handler: handler}) + if hasSignals { + defaultServer.CloseOnInterrupt(defaultSignals...) + } return defaultServer.ListenAndServe() } // ListenAndServeTLS provides a graceful version of the function provided by the // net/http package. Call Close() to stop the server. func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + preventReEntrance() defaultServer = NewWithServer(&http.Server{Addr: addr, Handler: handler}) + if hasSignals { + defaultServer.CloseOnInterrupt(defaultSignals...) + } return defaultServer.ListenAndServeTLS(certFile, keyFile) } // Serve provides a graceful version of the function provided by the net/http // package. Call Close() to stop the server. func Serve(l net.Listener, handler http.Handler) error { + preventReEntrance() defaultServer = NewWithServer(&http.Server{Handler: handler}) + if hasSignals { + defaultServer.CloseOnInterrupt(defaultSignals...) + } return defaultServer.Serve(l) } // Shuts down the default server used by ListenAndServe, ListenAndServeTLS and // Serve. It returns true if it's the first time Close is called. func Close() bool { - return defaultServer.Close() + outcome := defaultServer.Close() + defaultServer = nil + return outcome +} + +// CloseOnInterrupt creates a go-routine that will call the Close() function when certain OS +// signals are received. If no signals are specified, +// the following are used: SIGINT, SIGTERM, SIGKILL, SIGQUIT, SIGHUP, SIGUSR1. +// This function must be called before ListenAndServe, ListenAndServeTLS, or Serve. +func CloseOnInterrupt(signals ...os.Signal) { + defaultSignals = signals + hasSignals = true +} + +// After a signal has cause the server to close, this method allows you to determine which +// signal had been received. If Close was called some other way, this method will return nil. +// +// Note that, by convention, SIGUSR1 is often used to cause a server to close all its current +// connections cleanly, close its log files, and then restart. This facilitates log rotation. +// If you need this behaviour, you will need to provide a loop around the CloseOnInterrupt and +// ListenAndServe calls. +func SignalReceived() os.Signal { + return defaultServer.SignalReceived() } diff --git a/test_helpers/conn.go b/test_helpers/conn.go index 8c610f5..d7a298b 100644 --- a/test_helpers/conn.go +++ b/test_helpers/conn.go @@ -4,10 +4,13 @@ import "net" type Conn struct { net.Conn - CloseCalled bool + localAddr net.Addr +} + +func (f *Conn) LocalAddr() net.Addr { + return &net.IPAddr{} } func (c *Conn) Close() error { - c.CloseCalled = true return nil } diff --git a/test_helpers/listener.go b/test_helpers/listener.go index a74ac11..e3af35a 100644 --- a/test_helpers/listener.go +++ b/test_helpers/listener.go @@ -1,8 +1,8 @@ package test_helpers import ( - "net" - "errors" + "errors" + "net" ) type Listener struct { @@ -11,10 +11,10 @@ type Listener struct { } func NewListener() *Listener { - return &Listener{ - make(chan bool, 1), - make(chan bool, 1), - } + return &Listener{ + make(chan bool, 1), + make(chan bool, 1), + } } func (l *Listener) Addr() net.Addr { diff --git a/test_helpers/wait_group.go b/test_helpers/wait_group.go index 1df590d..192a121 100644 --- a/test_helpers/wait_group.go +++ b/test_helpers/wait_group.go @@ -4,25 +4,29 @@ import "sync" type WaitGroup struct { sync.Mutex - Count int - WaitCalled chan int + Count int + WaitCalled chan int + CountChanged chan int } func NewWaitGroup() *WaitGroup { return &WaitGroup{ - WaitCalled: make(chan int, 1), + WaitCalled: make(chan int, 1), + CountChanged: make(chan int, 1024), } } func (wg *WaitGroup) Add(delta int) { wg.Lock() wg.Count++ + wg.CountChanged <- wg.Count wg.Unlock() } func (wg *WaitGroup) Done() { wg.Lock() wg.Count-- + wg.CountChanged <- wg.Count wg.Unlock() } diff --git a/transition_test.go b/transition_test.go index 34fe5c6..b0bb4a3 100644 --- a/transition_test.go +++ b/transition_test.go @@ -1,7 +1,7 @@ package manners import ( - helpers "github.com/braintree/manners/test_helpers" + helpers "github.com/rickb777/manners/test_helpers" "net/http" "strings" "testing" @@ -31,12 +31,12 @@ type transitionTest struct { } func testStateTransition(t *testing.T, test transitionTest) { - server := newServer() + server := NewServer() wg := helpers.NewWaitGroup() server.wg = wg startServer(t, server, nil) - conn := &helpers.Conn{} + conn := &gracefulConn{Conn: &helpers.Conn{}} for _, newState := range test.states { server.ConnState(conn, newState) }