mirror of
https://github.com/lone-cloud/prism
synced 2026-06-03 08:43:10 -07:00
adding basic webpush request validation
This commit is contained in:
parent
ba032f72cc
commit
6f84d107cc
3 changed files with 130 additions and 8 deletions
|
|
@ -6,7 +6,6 @@ import (
|
|||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"prism/service/notification"
|
||||
"prism/service/util"
|
||||
|
|
@ -54,8 +53,23 @@ func (h *Handlers) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if _, err := url.Parse(req.PushEndpoint); err != nil {
|
||||
util.JSONError(w, "Invalid pushEndpoint URL", http.StatusBadRequest)
|
||||
encryptedFieldCount := 0
|
||||
if req.P256dh != nil {
|
||||
encryptedFieldCount++
|
||||
}
|
||||
if req.Auth != nil {
|
||||
encryptedFieldCount++
|
||||
}
|
||||
if req.VapidPrivateKey != nil {
|
||||
encryptedFieldCount++
|
||||
}
|
||||
if encryptedFieldCount > 0 && encryptedFieldCount < 3 {
|
||||
util.JSONError(w, "p256dh, auth, and vapidPrivateKey must all be provided together", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePushEndpoint(req.PushEndpoint, encryptedFieldCount == 3); err != nil {
|
||||
util.JSONError(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -67,11 +81,29 @@ func (h *Handlers) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var webPush *notification.WebPushSubscription
|
||||
if req.P256dh != nil && req.Auth != nil && req.VapidPrivateKey != nil {
|
||||
normalizedP256dh, err := normalizeP256DH(*req.P256dh)
|
||||
if err != nil {
|
||||
util.JSONError(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
normalizedAuth, err := normalizeAuthSecret(*req.Auth)
|
||||
if err != nil {
|
||||
util.JSONError(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
normalizedKey, err := normalizeVAPIDPrivateKey(*req.VapidPrivateKey)
|
||||
if err != nil {
|
||||
util.JSONError(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
webPush = ¬ification.WebPushSubscription{
|
||||
Endpoint: req.PushEndpoint,
|
||||
P256dh: *req.P256dh,
|
||||
Auth: *req.Auth,
|
||||
VapidPrivateKey: *req.VapidPrivateKey,
|
||||
P256dh: normalizedP256dh,
|
||||
Auth: normalizedAuth,
|
||||
VapidPrivateKey: normalizedKey,
|
||||
}
|
||||
} else {
|
||||
webPush = ¬ification.WebPushSubscription{
|
||||
|
|
@ -87,6 +119,7 @@ func (h *Handlers) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if err := h.store.AddSubscription(sub); err != nil {
|
||||
h.logger.Warn("Failed to add webpush subscription", "app", req.AppName, "error", err)
|
||||
util.LogAndError(w, h.logger, "Failed to add subscription", http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
89
service/integration/webpush/validation.go
Normal file
89
service/integration/webpush/validation.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
package webpush
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/elliptic"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func validatePushEndpoint(raw string, requireHTTPS bool) error {
|
||||
u, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil || u == nil || u.Scheme == "" || u.Host == "" {
|
||||
return fmt.Errorf("invalid pushEndpoint URL")
|
||||
}
|
||||
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return fmt.Errorf("pushEndpoint must use http or https")
|
||||
}
|
||||
|
||||
if requireHTTPS && u.Scheme != "https" {
|
||||
return fmt.Errorf("encrypted webpush endpoint must use https")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeVAPIDPrivateKey(raw string) (string, error) {
|
||||
decoded, err := decodeBase64URL(raw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid VAPID private key encoding")
|
||||
}
|
||||
|
||||
if len(decoded) != 32 {
|
||||
return "", fmt.Errorf("invalid VAPID private key length: expected 32 bytes, got %d", len(decoded))
|
||||
}
|
||||
|
||||
n := elliptic.P256().Params().N
|
||||
d := new(big.Int).SetBytes(decoded)
|
||||
if d.Sign() <= 0 || d.Cmp(n) >= 0 {
|
||||
return "", fmt.Errorf("invalid VAPID private key scalar")
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(decoded), nil
|
||||
}
|
||||
|
||||
func normalizeP256DH(raw string) (string, error) {
|
||||
decoded, err := decodeBase64URL(raw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid p256dh encoding")
|
||||
}
|
||||
|
||||
if len(decoded) != 65 || decoded[0] != 0x04 {
|
||||
return "", fmt.Errorf("invalid p256dh key format")
|
||||
}
|
||||
|
||||
if _, err := ecdh.P256().NewPublicKey(decoded); err != nil {
|
||||
return "", fmt.Errorf("invalid p256dh point")
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(decoded), nil
|
||||
}
|
||||
|
||||
func normalizeAuthSecret(raw string) (string, error) {
|
||||
decoded, err := decodeBase64URL(raw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid auth encoding")
|
||||
}
|
||||
|
||||
if len(decoded) != 16 {
|
||||
return "", fmt.Errorf("invalid auth length: expected 16 bytes, got %d", len(decoded))
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(decoded), nil
|
||||
}
|
||||
|
||||
func decodeBase64URL(raw string) ([]byte, error) {
|
||||
key := strings.TrimSpace(raw)
|
||||
if decoded, err := base64.RawURLEncoding.DecodeString(key); err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
decoded, err := base64.URLEncoding.DecodeString(key)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -159,9 +159,9 @@ func (s *Server) handleDeleteSubscription(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
message := fmt.Sprintf("%s disabled", sub.Channel.Label())
|
||||
message := fmt.Sprintf("%s channel disabled", sub.Channel.Label())
|
||||
if sub.Channel == notification.ChannelWebPush {
|
||||
message = fmt.Sprintf("%s deleted", sub.Channel.Label())
|
||||
message = fmt.Sprintf("%s channel deleted", sub.Channel.Label())
|
||||
}
|
||||
util.SetToast(w, message, "success")
|
||||
s.handleFragmentApps(w, r)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue