prism/service/server/middleware.go

130 lines
3.2 KiB
Go

package server
import (
"log/slog"
"net/http"
"sync"
"time"
"prism/service/util"
"github.com/go-chi/chi/v5/middleware"
"golang.org/x/time/rate"
)
// noisyPaths lists request paths that loggingMiddleware never logs: the
// health-check endpoint and the devtools manifest probe browsers request on
// their own. Lookups of any other path yield false.
var noisyPaths = map[string]bool{
	"/health": true,
	"/.well-known/appspecific/com.chrome.devtools.json": true,
}
// authMiddleware returns middleware that only lets a request through when
// util.VerifyAPIKey accepts it for apiKey. Rejected requests get a Basic-auth
// challenge header and a 401 Unauthorized response, and the wrapped handler is
// not invoked.
func authMiddleware(apiKey string) func(http.Handler) http.Handler {
	// Realm text shown by browsers in the login prompt.
	const challenge = `Basic realm="Prism Admin - Username: any, Password: API_KEY"`

	return func(next http.Handler) http.Handler {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			if util.VerifyAPIKey(r, apiKey) {
				next.ServeHTTP(w, r)
				return
			}
			w.Header().Set("WWW-Authenticate", challenge)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
		}
		return http.HandlerFunc(guarded)
	}
}
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
next.ServeHTTP(ww, r)
if !noisyPaths[r.URL.Path] {
logger.Debug("HTTP request",
"method", r.Method,
"path", r.URL.Path,
"status", ww.Status(),
"duration", time.Since(start),
"ip", util.GetClientIP(r),
)
}
})
}
}
// securityHeadersMiddleware returns middleware that stamps a fixed set of
// security-related response headers before handing the request to the next
// handler. The headers are set up front, so the wrapped handler can still
// override any of them.
func securityHeadersMiddleware() func(http.Handler) http.Handler {
	// Each entry is a {header name, value} pair applied with Header().Set.
	fixed := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https://api.qrserver.com; form-action 'self'; frame-ancestors 'none'; object-src 'none'"},
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			hdr := w.Header()
			for _, pair := range fixed {
				hdr.Set(pair[0], pair[1])
			}
			next.ServeHTTP(w, r)
		})
	}
}
// visitor is the rate-limit state kept for one client IP: its token bucket and
// the last time it made a request. cleanupVisitors evicts entries whose
// lastSeen is old enough.
type visitor struct {
	limiter  *rate.Limiter // token bucket for this IP
	lastSeen time.Time     // time of the most recent request; guarded by visitorsMu
}

// visitors maps a client IP to its rate-limit state. Access it only while
// holding visitorsMu.
var visitors = map[string]*visitor{}

// visitorsMu guards the visitors map and each entry's lastSeen field.
// NOTE(review): only Lock/Unlock are used in this file; a plain sync.Mutex
// would do if no other file relies on RLock.
var visitorsMu sync.RWMutex
// visitorSweeperOnce ensures the cleanupVisitors goroutine is started at most
// once per process, no matter how many times rateLimitMiddleware is called.
var visitorSweeperOnce sync.Once

// rateLimitMiddleware returns middleware that enforces a per-client-IP token
// bucket. Each IP gets rps tokens per second with a burst of 2*rps, and
// requests beyond that receive 429 Too Many Requests. Requests whose client IP
// util.IsLocalhost reports as local bypass the limiter entirely.
//
// NOTE(review): limiters live in the package-level visitors map. If this is
// called more than once with different rps values, an IP keeps the limiter
// from whichever call first saw it. Also, rps <= 0 yields a zero burst, which
// rejects every non-local request. Confirm callers always pass a single
// positive rps.
func rateLimitMiddleware(rps int) func(http.Handler) http.Handler {
	// The visitors map is shared package state, so a single sweeper is enough.
	// Starting one per call would leak a never-ending goroutine per extra use.
	visitorSweeperOnce.Do(func() { go cleanupVisitors() })

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ip := util.GetClientIP(r)
			if util.IsLocalhost(ip) {
				next.ServeHTTP(w, r)
				return
			}

			visitorsMu.Lock()
			v, ok := visitors[ip]
			if !ok {
				v = &visitor{limiter: rate.NewLimiter(rate.Limit(rps), rps*2)}
				visitors[ip] = v
			}
			v.lastSeen = time.Now()
			visitorsMu.Unlock()

			// rate.Limiter is safe for concurrent use, so Allow runs outside
			// the map lock to keep the critical section short.
			if !v.limiter.Allow() {
				http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// cleanupVisitors runs forever. Every five minutes it takes visitorsMu and
// removes visitors whose last request was more than ten minutes ago, so that
// the per-IP map does not grow without bound.
func cleanupVisitors() {
	const (
		sweepInterval = 5 * time.Minute
		idleTimeout   = 10 * time.Minute
	)

	tick := time.NewTicker(sweepInterval)
	defer tick.Stop()

	for {
		<-tick.C
		visitorsMu.Lock()
		for addr, entry := range visitors {
			// Deleting the current key while ranging over a map is allowed in Go.
			if idle := time.Since(entry.lastSeen); idle > idleTimeout {
				delete(visitors, addr)
			}
		}
		visitorsMu.Unlock()
	}
}
// maxBodySizeMiddleware returns middleware that caps how many request-body
// bytes downstream handlers can read. It replaces r.Body with an
// http.MaxBytesReader limited to maxBytes, so reads past the limit return an
// error instead of consuming an unbounded body.
func maxBodySizeMiddleware(maxBytes int64) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		limited := func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(limited)
	}
	return wrap
}