code cleanups, switch to chi rate limiting middleware

This commit is contained in:
Egor 2026-02-24 01:24:22 -08:00
parent e91b0fd74f
commit b27a943bd3
14 changed files with 40 additions and 163 deletions

View file

@ -1 +1 @@
0.4.0
0.4.1

3
go.mod
View file

@ -6,10 +6,10 @@ require (
github.com/SherClockHolmes/webpush-go v1.4.0
github.com/emersion/hydroxide v0.2.31
github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/httprate v0.15.0
github.com/joho/godotenv v1.5.1
github.com/mymmrac/telego v1.6.0
golang.org/x/crypto v0.48.0
golang.org/x/time v0.14.0
modernc.org/sqlite v1.46.1
)
@ -35,6 +35,7 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/exp v0.0.0-20260209203927-2842357ff358 // indirect
golang.org/x/sys v0.41.0 // indirect

8
go.sum
View file

@ -25,6 +25,8 @@ github.com/emersion/hydroxide v0.2.31 h1:ofPKtEpD+AtE2oKJhcQKhUjubQS8+AHIjCuO3ae
github.com/emersion/hydroxide v0.2.31/go.mod h1:jhoMVyP0z2GACrmFkL0ppcWKt2LbC02auADSSsB4vH4=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-chi/httprate v0.15.0 h1:j54xcWV9KGmPf/X4H32/aTH+wBlrvxL7P+SdnRqxh5g=
github.com/go-chi/httprate v0.15.0/go.mod h1:rzGHhVrsBn3IMLYDOZQsSU4fJNWcjui4fWKJcCId1R4=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
@ -74,6 +76,10 @@ github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLr
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
@ -144,8 +150,6 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View file

@ -29,7 +29,7 @@ func init() {
func main() {
if len(os.Args) > 1 && os.Args[1] == "version" {
fmt.Printf("Prism %s\n", version)
fmt.Println("Prism v", version)
return
}
@ -49,7 +49,7 @@ func main() {
}
func runServer(cfg *config.Config, logger *slog.Logger) error {
srv, err := server.New(cfg, publicAssets, version)
srv, err := server.New(cfg, publicAssets, version, logger)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}

View file

@ -61,7 +61,9 @@ func getEnvInt(key string, defaultValue int) int {
func getEnvBool(key string, defaultValue bool) bool {
if value := os.Getenv(key); value != "" {
return value == "true" || value == "1"
if b, err := strconv.ParseBool(value); err == nil {
return b
}
}
return defaultValue
}

View file

@ -97,7 +97,7 @@ func Initialize(cfg *config.Config, store *notification.Store, logger *slog.Logg
}, fragmentTmpl, nil
}
func (i *Integrations) Start(ctx context.Context, cfg *config.Config, logger *slog.Logger) {
func (i *Integrations) Start(ctx context.Context, logger *slog.Logger) {
for _, integration := range i.integrations {
if integration.IsEnabled() {
integration.Start(ctx, logger)

View file

@ -17,9 +17,7 @@ import (
type Integration struct {
cfg *config.Config
dispatcher *notification.Dispatcher
logger *slog.Logger
handlers *Handlers
tmpl *util.TemplateRenderer
monitor *Monitor
db *sql.DB
apiKey string
@ -31,9 +29,7 @@ func NewIntegration(cfg *config.Config, dispatcher *notification.Dispatcher, log
return &Integration{
cfg: cfg,
dispatcher: dispatcher,
logger: logger,
handlers: handlers,
tmpl: tmpl,
monitor: monitor,
db: db,
apiKey: apiKey,

View file

@ -19,7 +19,6 @@ type Integration struct {
handlers *Handlers
sender *Sender
tmpl *util.TemplateRenderer
logger *slog.Logger
}
func NewIntegration(cfg *config.Config, store *notification.Store, logger *slog.Logger, tmpl *util.TemplateRenderer) *Integration {
@ -33,7 +32,6 @@ func NewIntegration(cfg *config.Config, store *notification.Store, logger *slog.
client: client,
sender: sender,
tmpl: tmpl,
logger: logger,
}
}
@ -42,7 +40,7 @@ func (s *Integration) GetSender() *Sender {
}
func (s *Integration) RegisterRoutes(router *chi.Mux, auth func(http.Handler) http.Handler, db *sql.DB, apiKey string, logger *slog.Logger) {
s.handlers = RegisterRoutes(router, s.cfg, auth, s.tmpl, s.logger, s.client)
s.handlers = RegisterRoutes(router, s.cfg, auth, s.tmpl, logger, s.client)
}
func (s *Integration) Start(ctx context.Context, logger *slog.Logger) {

View file

@ -16,10 +16,8 @@ import (
)
type Integration struct {
cfg *config.Config
handlers *Handlers
sender *Sender
logger *slog.Logger
}
func NewIntegration(cfg *config.Config, store *notification.Store, logger *slog.Logger, tmpl *util.TemplateRenderer) *Integration {
@ -52,10 +50,8 @@ func NewIntegration(cfg *config.Config, store *notification.Store, logger *slog.
handlers := NewHandlers(client, chatID, tmpl, logger)
return &Integration{
cfg: cfg,
handlers: handlers,
sender: sender,
logger: logger,
}
}

View file

@ -5,8 +5,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
_ "modernc.org/sqlite"
)
@ -21,11 +19,7 @@ func NewStore(dbPath string) (*Store, error) {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
// foreign_keys(1): enable FK constraints
// busy_timeout(5000): wait up to 5s for locks instead of failing immediately
// journal_mode(WAL): improves concurrent read/write behavior
// synchronous(FULL): durability-first mode; safest on sudden power loss
db, err := sql.Open("sqlite", dbPath+"?_pragma=foreign_keys(1)&_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)&_pragma=synchronous(FULL)")
db, err := sql.Open("sqlite", dbPath+"?_pragma=foreign_keys(1)&_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)")
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
@ -33,10 +27,6 @@ func NewStore(dbPath string) (*Store, error) {
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
store := &Store{db: db}
if err := store.createTables(); err != nil {
return nil, err
@ -90,8 +80,7 @@ func (s *Store) GetDB() *sql.DB {
}
func (s *Store) RegisterApp(appName string) error {
query := `INSERT INTO apps (appName) VALUES (?) ON CONFLICT(appName) DO NOTHING`
_, err := s.execWrite(query, appName)
_, err := s.db.Exec(`INSERT INTO apps (appName) VALUES (?) ON CONFLICT(appName) DO NOTHING`, appName)
return err
}
@ -121,7 +110,7 @@ func (s *Store) AddSubscription(sub Subscription) error {
vapidPrivateKey = &sub.WebPush.VapidPrivateKey
}
_, err := s.execWrite(query, sub.ID, sub.AppName, sub.Channel, signalGroupID, signalAccount, telegramChatID, pushEndpoint, p256dh, auth, vapidPrivateKey)
_, err := s.db.Exec(query, sub.ID, sub.AppName, sub.Channel, signalGroupID, signalAccount, telegramChatID, pushEndpoint, p256dh, auth, vapidPrivateKey)
return err
}
@ -261,15 +250,13 @@ func (s *Store) SaveSignalGroup(appName string, sub *SignalSubscription) error {
return err
}
query := `INSERT INTO signal_groups (appName, groupId, account) VALUES (?, ?, ?)
ON CONFLICT(appName) DO UPDATE SET groupId=excluded.groupId, account=excluded.account`
_, err := s.execWrite(query, appName, sub.GroupID, sub.Account)
_, err := s.db.Exec(`INSERT INTO signal_groups (appName, groupId, account) VALUES (?, ?, ?)
ON CONFLICT(appName) DO UPDATE SET groupId=excluded.groupId, account=excluded.account`, appName, sub.GroupID, sub.Account)
return err
}
func (s *Store) DeleteSubscription(subscriptionID string) error {
query := `DELETE FROM subscriptions WHERE id = ?`
_, err := s.execWrite(query, subscriptionID)
_, err := s.db.Exec(`DELETE FROM subscriptions WHERE id = ?`, subscriptionID)
return err
}
@ -322,35 +309,6 @@ func (s *Store) GetSubscription(subscriptionID string) (*Subscription, error) {
}
func (s *Store) RemoveApp(appName string) error {
_, err := s.execWrite(`DELETE FROM apps WHERE appName = ?`, appName)
_, err := s.db.Exec(`DELETE FROM apps WHERE appName = ?`, appName)
return err
}
func (s *Store) execWrite(query string, args ...any) (sql.Result, error) {
const maxAttempts = 4
delay := 50 * time.Millisecond
for attempt := 0; attempt < maxAttempts; attempt++ {
result, err := s.db.Exec(query, args...)
if err == nil {
return result, nil
}
if !isSQLiteBusyError(err) || attempt == maxAttempts-1 {
return nil, err
}
time.Sleep(delay)
delay *= 2
}
return nil, fmt.Errorf("write failed")
}
func isSQLiteBusyError(err error) bool {
if err == nil {
return false
}
message := strings.ToLower(err.Error())
return strings.Contains(message, "sqlite_busy") || strings.Contains(message, "database is locked")
}

View file

@ -3,13 +3,12 @@ package server
import (
"log/slog"
"net/http"
"sync"
"time"
"prism/service/util"
"github.com/go-chi/chi/v5/middleware"
"golang.org/x/time/rate"
"github.com/go-chi/httprate"
)
var noisyPaths = map[string]bool{
@ -63,63 +62,20 @@ func securityHeadersMiddleware() func(http.Handler) http.Handler {
}
}
type visitor struct {
limiter *rate.Limiter
lastSeen time.Time
}
var (
visitors = make(map[string]*visitor)
visitorsMu sync.RWMutex
)
func rateLimitMiddleware(rps int) func(http.Handler) http.Handler {
go cleanupVisitors()
rl := httprate.LimitByIP(rps, time.Second)
return func(next http.Handler) http.Handler {
limited := rl(next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := util.GetClientIP(r)
if util.IsLocalhost(ip) {
if util.IsLocalhost(util.GetClientIP(r)) {
next.ServeHTTP(w, r)
return
}
visitorsMu.Lock()
v, exists := visitors[ip]
if !exists {
limiter := rate.NewLimiter(rate.Limit(rps), rps*2)
visitors[ip] = &visitor{limiter: limiter, lastSeen: time.Now()}
v = visitors[ip]
}
v.lastSeen = time.Now()
visitorsMu.Unlock()
if !v.limiter.Allow() {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
limited.ServeHTTP(w, r)
})
}
}
func cleanupVisitors() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
visitorsMu.Lock()
for ip, v := range visitors {
if time.Since(v.lastSeen) > 10*time.Minute {
delete(visitors, ip)
}
}
visitorsMu.Unlock()
}
}
func maxBodySizeMiddleware(maxBytes int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View file

@ -37,9 +37,7 @@ type Server struct {
version string
}
func New(cfg *config.Config, publicAssets embed.FS, version string) (*Server, error) {
logger := util.NewLogger(cfg.VerboseLogging)
func New(cfg *config.Config, publicAssets embed.FS, version string, logger *slog.Logger) (*Server, error) {
store, err := notification.NewStore(cfg.StoragePath)
if err != nil {
return nil, fmt.Errorf("failed to create store: %w", err)
@ -128,7 +126,7 @@ func (s *Server) setupRoutes() {
}
func (s *Server) Start(ctx context.Context) error {
s.integrations.Start(ctx, s.cfg, s.logger)
s.integrations.Start(ctx, s.logger)
addr := fmt.Sprintf(":%d", s.cfg.Port)
s.httpServer = &http.Server{

View file

@ -34,48 +34,9 @@ func VerifyAPIKey(r *http.Request, apiKey string) bool {
return false
}
if len(password) != len(apiKey) {
return false
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if r.TLS != nil {
proto = "https"
} else {
proto = "http"
}
}
clientIP := GetClientIP(r)
if proto != "https" && !isLocalIP(clientIP) {
return false
}
return subtle.ConstantTimeCompare([]byte(password), []byte(apiKey)) == 1
}
func isLocalIP(addr string) bool {
if addr == "" || addr == "::1" || addr == "localhost" {
return true
}
ip := net.ParseIP(addr)
if ip == nil {
return false
}
if ip.IsLoopback() {
return true
}
if ip.To4() != nil {
return ip.IsPrivate()
}
return false
}
func GetClientIP(r *http.Request) string {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil || host == "" {

View file

@ -18,8 +18,9 @@ const (
)
type ColorHandler struct {
w io.Writer
level slog.Level
w io.Writer
level slog.Level
preAttrs []slog.Attr
}
func NewColorHandler(w io.Writer, opts *slog.HandlerOptions) *ColorHandler {
@ -38,7 +39,9 @@ func (h *ColorHandler) Enabled(ctx context.Context, level slog.Level) bool {
}
func (h *ColorHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return h
newH := *h
newH.preAttrs = append(append([]slog.Attr{}, h.preAttrs...), attrs...)
return &newH
}
func (h *ColorHandler) WithGroup(name string) slog.Handler {
@ -68,6 +71,10 @@ func (h *ColorHandler) Handle(ctx context.Context, r slog.Record) error {
color, level, colorReset,
r.Message)
for _, a := range h.preAttrs {
_, _ = fmt.Fprintf(h.w, " %s=%v", a.Key, a.Value) //nolint:errcheck
}
r.Attrs(func(a slog.Attr) bool {
_, _ = fmt.Fprintf(h.w, " %s=%v", a.Key, a.Value) //nolint:errcheck
return true