about summary refs log tree commit diff
path: root/store.go
blob: 5f228aed1bf6a3a3615556bc54c5e6ef85e25f11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// SPDX-FileCopyrightText: V <v@anomalous.eu>
// SPDX-FileCopyrightText: edef <edef@edef.eu>
// SPDX-License-Identifier: OSL-3.0

package main

import (
	"database/sql"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

// schema is the database layout. OpenStore executes it every time a store is
// opened; the IF NOT EXISTS clauses make re-running it harmless.
//
// sessions: one row per session. opened_at and host are written by
// WriteOpen. closed_at, closed_by (a Side value) and closed_because stay NULL
// until WriteClose fills them in; OpenStore also sets closed_at (only) on
// rows an earlier run left open.
//
// messages: rows written by WriteBatch, each referencing its session. The
// reference is enforced because OpenStore opens the database with
// _foreign_keys=yes. side holds a Side value and time holds a Timestamp.
// NOTE(review): there is no index on messages.session; add one if per-session
// queries become common.
const schema = `
	CREATE TABLE IF NOT EXISTS sessions (
		id             INTEGER PRIMARY KEY,
		opened_at      INTEGER NOT NULL,
		host           TEXT    NOT NULL,
		closed_at      INTEGER,
		closed_by      INTEGER,
		closed_because TEXT
	);

	CREATE TABLE IF NOT EXISTS messages (
		session INTEGER NOT NULL REFERENCES sessions(id),
		time    INTEGER NOT NULL,
		side    INTEGER NOT NULL,
		data    TEXT    NOT NULL
	);
`

type (
	// SessionID is the row id WriteOpen obtains for a newly inserted
	// sessions row.
	SessionID int64

	// Timestamp is a point in time as stored in the INTEGER time columns.
	// NOTE(review): the unit is decided by whoever produces the values
	// (e.g. now(), defined outside this view); confirm before relying on it.
	Timestamp int64

	// Side identifies the party a message came from, or the party that
	// closed a session. Its numeric value is written to the database as-is.
	Side byte

	// Message is one record queued for the messages table.
	Message struct {
		when Timestamp // stored in messages.time
		side Side      // stored in messages.side
		data string    // stored in messages.data
	}
)

// Side values. The numbers are persisted, so existing values must keep their
// meaning; add new sides at the end instead of renumbering.
const (
	SideProxy  Side = 0
	SideClient Side = 1
	SideServer Side = 2
)

// Store persists sessions and their messages to SQLite. All database work is
// sent through q and executed one operation at a time by the goroutine that
// OpenStore starts; done is closed once that goroutine has closed the
// database.
type Store struct {
	// Prepared statements, created once by OpenStore.
	open  *sql.Stmt // inserts a sessions row
	batch *sql.Stmt // inserts a messages row
	close *sql.Stmt // records how a session ended

	q    chan func(*sql.DB) // operations for the store goroutine
	done chan struct{}      // closed after the store goroutine finishes
}

// OpenStore opens the SQLite database at path, applies schema, marks sessions
// that an earlier run left open as closed now, and starts the single goroutine
// that runs all subsequent Store operations in order. Call Close to stop that
// goroutine and close the database.
//
// path is passed to the go-sqlite3 driver as a DSN and may already carry its
// own query parameters (for example "file:log.db?cache=shared"); the
// foreign-key option is appended to them. Any error panics.
func OpenStore(path string) *Store {
	// The driver reads parameters from everything after the first '?'. A
	// second '?' would glue "_foreign_keys=yes" onto the previous parameter's
	// value and leave foreign keys silently disabled, so join with '&' instead.
	sep := "?"
	if strings.Contains(path, "?") {
		sep = "&"
	}
	db, err := sql.Open("sqlite3", path+sep+"_foreign_keys=yes")
	check(err)

	// All work is serialized through the store goroutine, so one connection is
	// sufficient. Capping the pool also guarantees every statement sees the
	// same database for ":memory:", where each connection would otherwise get
	// its own empty database.
	db.SetMaxOpenConns(1)

	must(db.Exec(schema))
	// Assuming a single Store per database, no session can be open yet in this
	// process, so any row without closed_at was left unclosed by an earlier
	// run; stamp it with the current time.
	must(db.Exec(`UPDATE sessions SET closed_at = ? WHERE closed_at IS NULL`, now()))

	prepare := func(query string) *sql.Stmt {
		stmt, err := db.Prepare(query)
		check(err)
		return stmt
	}

	s := &Store{
		open:  prepare(`INSERT INTO sessions(opened_at, host) VALUES(?, ?)`),
		batch: prepare(`INSERT INTO messages(session, time, side, data) VALUES(?, ?, ?, ?)`),
		close: prepare(`UPDATE sessions SET closed_at = ?, closed_by = ?, closed_because = ? WHERE id = ?`),

		q:    make(chan func(*sql.DB)),
		done: make(chan struct{}),
	}

	// The store goroutine runs operations one at a time until Close closes q,
	// then closes the database and signals done.
	go func() {
		for op := range s.q {
			op(db)
		}
		check(db.Close())
		close(s.done)
	}()

	return s
}

// WriteOpen inserts a sessions row opened at when for host and returns the new
// row's id. It blocks until the store goroutine has performed the insert; any
// database error panics on that goroutine.
func (s *Store) WriteOpen(when Timestamp, host string) SessionID {
	reply := make(chan int64, 1)
	insert := func(*sql.DB) {
		res := must(s.open.Exec(when, host))
		rowID, err := res.LastInsertId()
		check(err)
		reply <- rowID
	}
	s.q <- insert
	return SessionID(<-reply)
}

// WriteBatch inserts every message in batch for session id inside a single
// transaction and returns once that transaction has been committed. Any
// database error panics on the store goroutine.
func (s *Store) WriteBatch(id SessionID, batch []Message) {
	committed := make(chan struct{})
	s.q <- func(db *sql.DB) {
		tx, err := db.Begin()
		check(err)

		// Rebind the shared prepared insert to this transaction.
		insert := tx.Stmt(s.batch)
		for i := range batch {
			m := &batch[i]
			must(insert.Exec(id, m.when, m.side, m.data))
		}

		check(tx.Commit())
		close(committed)
	}
	<-committed
}

// WriteClose records that session id was closed at when by side by, with
// reason as the explanation. Unlike WriteOpen and WriteBatch, it does not wait
// for the update to run. It returns as soon as the store goroutine has taken
// the operation off the queue.
func (s *Store) WriteClose(id SessionID, when Timestamp, by Side, reason string) {
	update := func(*sql.DB) {
		must(s.close.Exec(when, by, reason, id))
	}
	s.q <- update
}

// Close stops the store from accepting operations and blocks until the store
// goroutine has finished every operation it already received and closed the
// database. Calling Close twice, or any Write method after Close, panics
// because it closes or sends on a closed channel.
func (s *Store) Close() {
	queue, finished := s.q, s.done
	close(queue)
	<-finished
}

// check panics with err when it is non-nil. The store treats every database
// error as fatal, and this is how it does so.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}

// must adapts a (sql.Result, error) pair, as returned by Exec, for inline
// use: it panics with err when err is non-nil, and otherwise returns res
// unchanged.
func must(res sql.Result, err error) sql.Result {
	if err != nil {
		panic(err)
	}
	return res
}