about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--go.mod1
-rw-r--r--go.sum2
-rw-r--r--main.go45
3 files changed, 34 insertions, 14 deletions
diff --git a/go.mod b/go.mod
index ef4566c..ffbc52b 100644
--- a/go.mod
+++ b/go.mod
@@ -8,4 +8,5 @@ go 1.16
 require (
 	github.com/coreos/go-systemd/v22 v22.3.2
 	github.com/gorilla/handlers v1.5.1
+	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
 )
diff --git a/go.sum b/go.sum
index 5b74262..00aa90b 100644
--- a/go.sum
+++ b/go.sum
@@ -5,3 +5,5 @@ github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4=
 github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/main.go b/main.go
index 2019888..3da87b6 100644
--- a/main.go
+++ b/main.go
@@ -8,6 +8,7 @@ import (
 	"flag"
 	"fmt"
 	"log"
+	"net"
 	"net/http"
 	"net/http/cgi"
 	"os"
@@ -16,6 +17,7 @@ import (
 
 	"github.com/coreos/go-systemd/v22/activation"
 	"github.com/gorilla/handlers"
+	"golang.org/x/sync/errgroup"
 )
 
 func main() {
@@ -62,31 +64,46 @@ func main() {
 	// Additionally, we want to log requests.
 	handler = handlers.CombinedLoggingHandler(os.Stdout, handler)
 
-	// Catch SIGTERM so we can shutdown gracefully.
-	sig := make(chan os.Signal, 1)
-	signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
-
 	// Create a Server so we can call Shutdown() on it later.
 	srv := &http.Server{Handler: handler}
 
-	doom := make(chan error)
+	// TODO(V): comment
+	ctx, cancel := context.WithCancel(context.Background())
+	group, ctx := errgroup.WithContext(ctx)
+	srv.BaseContext = func(net.Listener) context.Context { return ctx }
+
+	group.Go(func() error {
+		// Catch SIGTERM so we can shutdown gracefully.
+		sig := make(chan os.Signal, 1)
+		signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
+
+		select {
+		case <-sig:
+			log.Print("Caught signal, shutting down")
+			cancel()
+		case <-ctx.Done():
+			// Nothing to do here
+		}
+
+		return nil
+	})
+
 	for _, ln := range lns {
 		// Loop variables are unsafe to close over with goroutines,
 		// so this just shadows it with a fresh binding.
 		ln := ln
 
-		go func() {
-			// If any one of these return, it's game over.
-			doom <- srv.Serve(ln)
-		}()
+		group.Go(func() error {
+			err := srv.Serve(ln)
+			if err != http.ErrServerClosed {
+				return err
+			}
+			return nil
+		})
 	}
 
-	select {
-	case err = <-doom:
-		// Only the first error matters, the rest will be ErrServerClosed.
+	if err := group.Wait(); err != nil {
 		log.Printf("Fatal server error: %v", err)
-	case <-sig:
-		log.Print("Caught signal, shutting down")
 	}
 
 	// This will block until all existing connections are handled.