mirror of
https://github.com/lone-cloud/prism
synced 2026-06-03 08:43:10 -07:00
125 lines
3.1 KiB
Go
125 lines
3.1 KiB
Go
package server
|
|
|
|
import (
|
|
"log/slog"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"prism/internal/util"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
func authMiddleware(apiKey string, allowInsecureHTTP bool) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !util.VerifyAPIKey(r, apiKey, allowInsecureHTTP) {
|
|
w.Header().Set("WWW-Authenticate", `Basic realm="Prism Admin - Username: any, Password: API_KEY"`)
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
|
next.ServeHTTP(wrapped, r)
|
|
|
|
logger.Debug("HTTP request",
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"status", wrapped.statusCode,
|
|
"duration", time.Since(start),
|
|
"ip", util.GetClientIP(r),
|
|
)
|
|
})
|
|
}
|
|
}
|
|
|
|
// responseWriter wraps an http.ResponseWriter so middleware can observe the
// status code that was actually sent to the client. Creators should set
// statusCode to http.StatusOK, since a handler that never calls WriteHeader
// implicitly sends 200.
type responseWriter struct {
	http.ResponseWriter
	statusCode int
	// wroteHeader is set once a final status has been committed, either by
	// WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// WriteHeader records the status code and forwards the call. Only the first
// final status is recorded: net/http ignores later WriteHeader calls, so
// recording them would make logs disagree with what the client received.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		// 1xx informational responses (except 101 Switching Protocols) may
		// precede the final status, so they must not latch it.
		informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
		rw.wroteHeader = !informational
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. If no final status has been sent yet,
// net/http commits an implicit 200 on the first Write, so that is recorded.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional capabilities (Flush, Hijack, deadlines) through this wrapper.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
|
|
|
|
// securityHeadersMiddleware returns middleware that sets a fixed set of
// defensive response headers (MIME-sniffing, framing, referrer, XSS-auditor
// and Content-Security-Policy) before delegating to the wrapped handler.
//
// allowInsecureHTTP is currently not consulted by this middleware.
func securityHeadersMiddleware(allowInsecureHTTP bool) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Set("X-Content-Type-Options", "nosniff")
			h.Set("X-Frame-Options", "DENY")
			// "1; mode=block" turns on the legacy XSS auditor, which modern
			// browsers have removed and which could itself introduce
			// vulnerabilities in older ones. Current guidance is to disable it
			// explicitly and rely on the Content-Security-Policy below.
			h.Set("X-XSS-Protection", "0")
			h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
			h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; form-action 'self'; frame-ancestors 'none'; object-src 'none'")

			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
type visitor struct {
|
|
limiter *rate.Limiter
|
|
lastSeen time.Time
|
|
}
|
|
|
|
var (
|
|
visitors = make(map[string]*visitor)
|
|
visitorsMu sync.RWMutex
|
|
)
|
|
|
|
func rateLimitMiddleware(rps int, allowInsecureHTTP bool) func(http.Handler) http.Handler {
|
|
go cleanupVisitors()
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := util.GetClientIP(r)
|
|
|
|
if allowInsecureHTTP || util.IsLocalhost(ip) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
visitorsMu.Lock()
|
|
v, exists := visitors[ip]
|
|
if !exists {
|
|
limiter := rate.NewLimiter(rate.Limit(rps), rps*2)
|
|
visitors[ip] = &visitor{limiter: limiter, lastSeen: time.Now()}
|
|
v = visitors[ip]
|
|
}
|
|
v.lastSeen = time.Now()
|
|
visitorsMu.Unlock()
|
|
|
|
if !v.limiter.Allow() {
|
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
func cleanupVisitors() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
visitorsMu.Lock()
|
|
for ip, v := range visitors {
|
|
if time.Since(v.lastSeen) > 10*time.Minute {
|
|
delete(visitors, ip)
|
|
}
|
|
}
|
|
visitorsMu.Unlock()
|
|
}
|
|
}
|