about summary refs log tree commit diff
path: root/proxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'proxy.go')
-rw-r--r--proxy.go145
1 files changed, 145 insertions, 0 deletions
diff --git a/proxy.go b/proxy.go
new file mode 100644
index 0000000..dfee27e
--- /dev/null
+++ b/proxy.go
@@ -0,0 +1,145 @@
+// SPDX-FileCopyrightText: V <v@anomalous.eu>
+// SPDX-FileCopyrightText: edef <edef@edef.eu>
+// SPDX-License-Identifier: OSL-3.0
+
+package main
+
+import (
+	"bufio"
+	"bytes"
+	"crypto/tls"
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"sync"
+)
+
+type sidedConn struct {
+	net.Conn
+	Side
+}
+
+type Proxy struct {
+	store     *Store
+	tlsConfig *tls.Config
+
+	exiting chan struct{}
+	wg      sync.WaitGroup
+}
+
+func NewProxy(store *Store, tlsConfig *tls.Config) *Proxy {
+	return &Proxy{
+		store:     store,
+		tlsConfig: tlsConfig,
+
+		exiting: make(chan struct{}),
+	}
+}
+
+func (p *Proxy) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+	if req.Method != http.MethodConnect {
+		log.Printf("%s - invalid request %s %s", req.RemoteAddr, req.Method, req.RequestURI)
+		http.Error(resp, "405 I'm a proxy", http.StatusMethodNotAllowed)
+		return
+	}
+
+	host := req.URL.Hostname()
+	log.Printf("%s - new connection to %s", req.RemoteAddr, host)
+
+	server, err := tls.DialWithDialer(
+		&net.Dialer{Cancel: req.Context().Done()},
+		"tcp", net.JoinHostPort(host, "6697"),
+		p.tlsConfig,
+	)
+	if err != nil {
+		log.Printf("%s - failed connection %v", req.RemoteAddr, err)
+		resp.WriteHeader(http.StatusBadGateway)
+		return
+	}
+
+	session := OpenSession(p.store, host)
+	resp.WriteHeader(http.StatusOK)
+
+	// http.Server's Shutdown "does not attempt to close nor wait for hijacked connections",
+	// so we have to bump the waitgroup prior to calling Hijack()
+	p.wg.Add(1)
+	defer p.wg.Done()
+
+	// XXX: bufio.ReadWriter might still contain data
+	// I think it's impossible for err to be non-nil
+	client, _, _ := resp.(http.Hijacker).Hijack()
+	p.proxy(session, sidedConn{client, SideClient}, sidedConn{server, SideServer})
+}
+
+func (p *Proxy) proxy(session *Session, a, b sidedConn) {
+	ch := make(chan func())
+
+	pipe := func(r, w sidedConn) {
+		scanner := bufio.NewScanner(r)
+		scanner.Split(scanIRCLines)
+		for scanner.Scan() {
+			session.Write(r.Side, string(dropLineEnding(scanner.Bytes())))
+			_, err := w.Write(scanner.Bytes())
+			if err != nil {
+				ch <- func() { session.Close(w.Side, err.Error()) }
+				return
+			}
+		}
+
+		err := scanner.Err()
+		if err == nil {
+			err = io.EOF
+		}
+
+		ch <- func() { session.Close(r.Side, err.Error()) }
+	}
+
+	go pipe(a, b)
+	go pipe(b, a)
+
+	done := func() { session.Close(SideProxy, "shutting down") }
+
+	select {
+	case <-p.exiting:
+		a.Close()
+		b.Close()
+		<-ch
+		<-ch
+	case done = <-ch:
+		a.Close()
+		b.Close()
+		<-ch
+	}
+
+	done()
+}
+
+func (p *Proxy) Shutdown() {
+	close(p.exiting)
+	p.wg.Wait()
+	p.store.Close()
+}
+
+func scanIRCLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	if atEOF && len(data) == 0 {
+		return 0, nil, nil
+	}
+	if i := bytes.IndexByte(data, '\n'); i >= 0 {
+		// we have a full newline-terminated line
+		return i + 1, data[:i+1], nil
+	}
+	if atEOF {
+		return 0, nil, io.ErrUnexpectedEOF
+	}
+	return 0, nil, nil // request more data
+}
+
+// on a buffer known to end in \n, drop \n or \r\n
+func dropLineEnding(data []byte) []byte {
+	n := len(data)
+	if n > 1 && data[n-1] == '\r' {
+		return data[:n-2]
+	}
+	return data[:n-1]
+}