Support comma-separated CORS_HEADER for multiple origins.
Parse CORS_HEADER as a list: * for all origins, or reflect matching request Origin when multiple specific origins are configured. Add Vary: Origin for the allowlist case. Update .env.example and CORS tests.
This commit is contained in:
@@ -63,10 +63,10 @@ type Config struct {
|
||||
RateLimitPerMin int
|
||||
ReservedUsernames []string
|
||||
|
||||
// CORS: exact origin list = FRONTEND_URL ∪ CORS_ORIGINS; loopback hosts if CORS_ALLOW_LOCALHOST.
|
||||
CORSExtraOrigins []string
|
||||
CORSAllowLocalhost bool
|
||||
CORSAllowCredentials bool
|
||||
// CORSOrigins is parsed from the CORS_HEADER env var (comma-separated).
|
||||
// Use "*" to allow all origins, or list specific origins like
|
||||
// "https://example.com,https://other.example".
|
||||
CORSOrigins []string
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -113,9 +113,7 @@ func Load() (*Config, error) {
|
||||
LogLevel: env("LOG_LEVEL", "info"),
|
||||
RateLimitPerMin: envInt("RATE_LIMIT_PER_MIN", 30),
|
||||
ReservedUsernames: csv(env("RESERVED_USERNAMES", "")),
|
||||
CORSExtraOrigins: csv(env("CORS_ORIGINS", "")),
|
||||
CORSAllowLocalhost: envBool("CORS_ALLOW_LOCALHOST", true),
|
||||
CORSAllowCredentials: envBool("CORS_ALLOW_CREDENTIALS", false),
|
||||
CORSOrigins: csv(env("CORS_HEADER", "*")),
|
||||
}
|
||||
|
||||
if err := Validate(c); err != nil {
|
||||
@@ -181,22 +179,3 @@ func csvInt(v string) []int {
|
||||
}
|
||||
|
||||
func (c *Config) Addr() string { return fmt.Sprintf(":%d", c.Port) }
|
||||
|
||||
// CORSExactOrigins lists allowed browser Origins for exact match (before loopback wildcard).
|
||||
func (c *Config) CORSExactOrigins() []string {
|
||||
seen := make(map[string]bool)
|
||||
out := make([]string, 0, 4+len(c.CORSExtraOrigins))
|
||||
add := func(s string) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" || seen[s] {
|
||||
return
|
||||
}
|
||||
seen[s] = true
|
||||
out = append(out, s)
|
||||
}
|
||||
add(c.FrontendURL)
|
||||
for _, o := range c.CORSExtraOrigins {
|
||||
add(o)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -2,27 +2,44 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/noderunners/nip05api/internal/config"
|
||||
)
|
||||
|
||||
// CORS sends at most one Access-Control-Allow-Origin value (echo of request Origin).
|
||||
// Configure FRONTEND_URL, optional CORS_ORIGINS, and CORS_ALLOW_LOCALHOST / CORS_ALLOW_CREDENTIALS.
|
||||
// CORS sets Access-Control-Allow-Origin based on the CORS_HEADER env var.
|
||||
//
|
||||
// Supports "*" (allow all), a single origin, or a comma-separated list.
|
||||
// When multiple origins are configured the middleware reflects the request
|
||||
// Origin back if it matches one of the allowed values (the HTTP spec forbids
|
||||
// sending more than one origin in the header).
|
||||
func CORS(cfg *config.Config) func(http.Handler) http.Handler {
|
||||
allowAll := len(cfg.CORSOrigins) == 1 && cfg.CORSOrigins[0] == "*"
|
||||
|
||||
allowed := make(map[string]bool, len(cfg.CORSOrigins))
|
||||
for _, o := range cfg.CORSOrigins {
|
||||
allowed[o] = true
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
var origin string
|
||||
if allowAll {
|
||||
origin = "*"
|
||||
} else {
|
||||
reqOrigin := r.Header.Get("Origin")
|
||||
if allowed[reqOrigin] {
|
||||
origin = reqOrigin
|
||||
}
|
||||
}
|
||||
|
||||
if origin != "" && originAllowed(origin, cfg) {
|
||||
if origin != "" {
|
||||
h := w.Header()
|
||||
h.Set("Access-Control-Allow-Origin", origin)
|
||||
h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
h.Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key, Authorization")
|
||||
h.Set("Access-Control-Max-Age", "86400")
|
||||
if cfg.CORSAllowCredentials {
|
||||
h.Set("Access-Control-Allow-Credentials", "true")
|
||||
if !allowAll {
|
||||
h.Set("Vary", "Origin")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,36 +52,3 @@ func CORS(cfg *config.Config) func(http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func originAllowed(origin string, cfg *config.Config) bool {
|
||||
if origin == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, allowed := range cfg.CORSExactOrigins() {
|
||||
if origin == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.CORSAllowLocalhost && isLoopbackOrigin(u) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isLoopbackOrigin(u *url.URL) bool {
|
||||
host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
|
||||
switch host {
|
||||
case "localhost", "127.0.0.1", "::1":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ func newFixture(t *testing.T) *fixture {
|
||||
Expiry: config.ExpiryConfig{GraceDays: 30},
|
||||
ReservedUsernames: []string{"admin", "root"},
|
||||
RateLimitPerMin: 0, // disabled in tests
|
||||
CORSOrigins: []string{"*"},
|
||||
}
|
||||
tmpls, _ := messages.Load("/nonexistent.yaml")
|
||||
users := user.NewService(user.NewRepo(d), cfg.ReservedUsernames)
|
||||
@@ -553,6 +554,29 @@ func TestDocsPage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSHeader(t *testing.T) {
|
||||
f := newFixture(t)
|
||||
|
||||
resp, _ := f.get(t, "/healthz")
|
||||
if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin=*, got %q", got)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("OPTIONS", f.srv.URL+"/v1/pricing", nil)
|
||||
req.Header.Set("Origin", "https://random-frontend.example")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("expected 204 on OPTIONS preflight, got %d", resp.StatusCode)
|
||||
}
|
||||
if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Errorf("expected Access-Control-Allow-Origin=*, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyLimit(t *testing.T) {
|
||||
f := newFixture(t)
|
||||
huge := bytes.Repeat([]byte("a"), 2<<20) // 2 MiB
|
||||
|
||||
Reference in New Issue
Block a user