package server

import (
	"log/slog"
	"net/http"
	"sync"
	"time"

	"github.com/go-chi/chi/v5/middleware"
	"golang.org/x/time/rate"

	"prism/service/util"
)

// noisyPaths lists request paths that loggingMiddleware does not log, so that
// frequent health checks and browser devtools probes do not flood the logs.
var noisyPaths = map[string]bool{
	"/.well-known/appspecific/com.chrome.devtools.json": true,
	"/health": true,
}

// authMiddleware rejects requests that fail util.VerifyAPIKey against apiKey
// with 401 Unauthorized and a Basic authentication challenge. Requests that
// pass are forwarded to next unchanged.
func authMiddleware(apiKey string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !util.VerifyAPIKey(r, apiKey) {
				// The realm text spells out the expected credentials:
				// any username, with the API key as the password.
				w.Header().Set("WWW-Authenticate", `Basic realm="Prism Admin - Username: any, Password: API_KEY"`)
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// loggingMiddleware logs each request at debug level after it has been
// handled, recording method, path, response status, handling duration and
// client IP. Paths listed in noisyPaths are served but not logged.
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			// Wrap the writer so the status code chosen by next can be read back.
			ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
			next.ServeHTTP(ww, r)

			if noisyPaths[r.URL.Path] {
				return
			}
			logger.Debug("HTTP request",
				"method", r.Method,
				"path", r.URL.Path,
				"status", ww.Status(),
				"duration", time.Since(start),
				"ip", util.GetClientIP(r),
			)
		})
	}
}

// securityHeadersMiddleware sets hardening headers on every response before
// calling next: MIME-type sniffing is disabled, the Referer header is reduced
// to the origin on cross-origin requests, and a Content-Security-Policy
// restricts resources to the same origin (images may also come from data:
// URLs and api.qrserver.com), forbids framing and plugin objects, and limits
// form submissions to the same origin.
func securityHeadersMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-Content-Type-Options", "nosniff")
			w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
			w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https://api.qrserver.com; form-action 'self'; frame-ancestors 'none'; object-src 'none'")
			next.ServeHTTP(w, r)
		})
	}
}

// visitor is the rate-limiting state kept for one client IP.
type visitor struct {
	limiter  *rate.Limiter // token bucket shared by all requests from the IP
	lastSeen time.Time     // time of the most recent request, used for eviction
}

var (
	// visitors maps a client IP to its rate-limiting state. It is guarded by
	// visitorsMu and shared by every handler rateLimitMiddleware builds.
	visitors   = make(map[string]*visitor)
	visitorsMu sync.RWMutex
)

// cleanupOnce ensures at most one cleanupVisitors goroutine is started,
// however many times rateLimitMiddleware is called.
var cleanupOnce sync.Once

// rateLimitMiddleware limits each client IP to rps requests per second with a
// burst of 2*rps, answering 429 Too Many Requests when the IP's bucket is
// empty. Requests whose IP satisfies util.IsLocalhost are never limited.
//
// Limiter state lives in the package-level visitors map. All handlers built
// by this function therefore share it, and an IP keeps the rate of whichever
// handler first created its entry. Idle entries are evicted by a single
// background cleanupVisitors goroutine, which never exits.
//
// NOTE(review): the key comes from util.GetClientIP; if that trusts proxy
// headers such as X-Forwarded-For, clients may be able to spoof it — confirm.
func rateLimitMiddleware(rps int) func(http.Handler) http.Handler {
	// Start the evictor at most once so repeated calls do not pile up
	// goroutines that all sweep the same map.
	cleanupOnce.Do(func() { go cleanupVisitors() })

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ip := util.GetClientIP(r)
			if util.IsLocalhost(ip) {
				next.ServeHTTP(w, r)
				return
			}

			now := time.Now()
			visitorsMu.Lock()
			v, ok := visitors[ip]
			if !ok {
				v = &visitor{limiter: rate.NewLimiter(rate.Limit(rps), rps*2)}
				visitors[ip] = v
			}
			v.lastSeen = now
			visitorsMu.Unlock()

			// rate.Limiter is safe for concurrent use, so Allow is called
			// without holding the map lock.
			if !v.limiter.Allow() {
				http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// cleanupVisitors never returns. Every 5 minutes it evicts visitors whose
// last request was more than 10 minutes ago, so the map does not grow
// without bound.
func cleanupVisitors() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		visitorsMu.Lock()
		// Deleting entries while ranging over a map is permitted in Go.
		for ip, v := range visitors {
			if time.Since(v.lastSeen) > 10*time.Minute {
				delete(visitors, ip)
			}
		}
		visitorsMu.Unlock()
	}
}

// maxBodySizeMiddleware replaces the request body with http.MaxBytesReader so
// that downstream handlers get a read error once more than maxBytes bytes are
// read. The limit only takes effect when a handler actually reads the body.
func maxBodySizeMiddleware(maxBytes int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		})
	}
}