Support comma-separated CORS_HEADER for multiple origins.
Parse CORS_HEADER as a list: * for all origins, or reflect matching request Origin when multiple specific origins are configured. Add Vary: Origin for the allowlist case. Update .env.example and CORS tests.
This commit is contained in:
@@ -2,27 +2,44 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/noderunners/nip05api/internal/config"
|
||||
)
|
||||
|
||||
// CORS sends at most one Access-Control-Allow-Origin value (echo of request Origin).
|
||||
// Configure FRONTEND_URL, optional CORS_ORIGINS, and CORS_ALLOW_LOCALHOST / CORS_ALLOW_CREDENTIALS.
|
||||
// CORS sets Access-Control-Allow-Origin based on the CORS_HEADER env var.
|
||||
//
|
||||
// Supports "*" (allow all), a single origin, or a comma-separated list.
|
||||
// When multiple origins are configured the middleware reflects the request
|
||||
// Origin back if it matches one of the allowed values (the HTTP spec forbids
|
||||
// sending more than one origin in the header).
|
||||
func CORS(cfg *config.Config) func(http.Handler) http.Handler {
|
||||
allowAll := len(cfg.CORSOrigins) == 1 && cfg.CORSOrigins[0] == "*"
|
||||
|
||||
allowed := make(map[string]bool, len(cfg.CORSOrigins))
|
||||
for _, o := range cfg.CORSOrigins {
|
||||
allowed[o] = true
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
var origin string
|
||||
if allowAll {
|
||||
origin = "*"
|
||||
} else {
|
||||
reqOrigin := r.Header.Get("Origin")
|
||||
if allowed[reqOrigin] {
|
||||
origin = reqOrigin
|
||||
}
|
||||
}
|
||||
|
||||
if origin != "" && originAllowed(origin, cfg) {
|
||||
if origin != "" {
|
||||
h := w.Header()
|
||||
h.Set("Access-Control-Allow-Origin", origin)
|
||||
h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
h.Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key, Authorization")
|
||||
h.Set("Access-Control-Max-Age", "86400")
|
||||
if cfg.CORSAllowCredentials {
|
||||
h.Set("Access-Control-Allow-Credentials", "true")
|
||||
if !allowAll {
|
||||
h.Set("Vary", "Origin")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,36 +52,3 @@ func CORS(cfg *config.Config) func(http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func originAllowed(origin string, cfg *config.Config) bool {
|
||||
if origin == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, allowed := range cfg.CORSExactOrigins() {
|
||||
if origin == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.CORSAllowLocalhost && isLoopbackOrigin(u) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isLoopbackOrigin(u *url.URL) bool {
|
||||
host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
|
||||
switch host {
|
||||
case "localhost", "127.0.0.1", "::1":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user