diff --git a/.env.example b/.env.example index 2e7958a..6d06ec2 100644 --- a/.env.example +++ b/.env.example @@ -4,11 +4,10 @@ PORT=8080 ADMIN_API_KEY=change-me-to-a-long-random-string FRONTEND_URL=https://azzamo.net/nip05 -# Optional extra browser origins (comma-separated). Merged with FRONTEND_URL for CORS. -# CORS_ORIGINS= - -# Allow http(s)://localhost:* and 127.0.0.1 for local UI dev hitting this API directly (not via Vite proxy). -CORS_ALLOW_LOCALHOST=true +# --- CORS --- +# Comma-separated list of allowed origins, or "*" to allow all. +# Examples: "*" | "https://azzamo.net" | "https://azzamo.net,https://other.example" +CORS_HEADER=* # --- Database --- DATABASE_PATH=.data/nip05.db diff --git a/internal/config/config.go b/internal/config/config.go index 898bbce..e7a1d6d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -63,10 +63,10 @@ type Config struct { RateLimitPerMin int ReservedUsernames []string - // CORS: exact origin list = FRONTEND_URL ∪ CORS_ORIGINS; loopback hosts if CORS_ALLOW_LOCALHOST. - CORSExtraOrigins []string - CORSAllowLocalhost bool - CORSAllowCredentials bool + // CORSOrigins is parsed from the CORS_HEADER env var (comma-separated). + // Use "*" to allow all origins, or list specific origins like + // "https://example.com,https://other.example". + CORSOrigins []string } func Load() (*Config, error) { @@ -113,9 +113,7 @@ func Load() (*Config, error) { LogLevel: env("LOG_LEVEL", "info"), RateLimitPerMin: envInt("RATE_LIMIT_PER_MIN", 30), ReservedUsernames: csv(env("RESERVED_USERNAMES", "")), - CORSExtraOrigins: csv(env("CORS_ORIGINS", "")), - CORSAllowLocalhost: envBool("CORS_ALLOW_LOCALHOST", true), - CORSAllowCredentials: envBool("CORS_ALLOW_CREDENTIALS", false), + CORSOrigins: csv(env("CORS_HEADER", "*")), } if err := Validate(c); err != nil { @@ -181,22 +179,3 @@ func csvInt(v string) []int { } func (c *Config) Addr() string { return fmt.Sprintf(":%d", c.Port) } - -// CORSExactOrigins lists allowed browser Origins for exact match (before loopback wildcard). -func (c *Config) CORSExactOrigins() []string { - seen := make(map[string]bool) - out := make([]string, 0, 4+len(c.CORSExtraOrigins)) - add := func(s string) { - s = strings.TrimSpace(s) - if s == "" || seen[s] { - return - } - seen[s] = true - out = append(out, s) - } - add(c.FrontendURL) - for _, o := range c.CORSExtraOrigins { - add(o) - } - return out -} diff --git a/internal/http/middleware/cors.go b/internal/http/middleware/cors.go index daafc81..f60e095 100644 --- a/internal/http/middleware/cors.go +++ b/internal/http/middleware/cors.go @@ -2,27 +2,44 @@ package middleware import ( "net/http" - "net/url" - "strings" "github.com/noderunners/nip05api/internal/config" ) -// CORS sends at most one Access-Control-Allow-Origin value (echo of request Origin). -// Configure FRONTEND_URL, optional CORS_ORIGINS, and CORS_ALLOW_LOCALHOST / CORS_ALLOW_CREDENTIALS. +// CORS sets Access-Control-Allow-Origin based on the CORS_HEADER env var. +// +// Supports "*" (allow all), a single origin, or a comma-separated list. +// When multiple origins are configured the middleware reflects the request +// Origin back if it matches one of the allowed values (the HTTP spec forbids +// sending more than one origin in the header). func CORS(cfg *config.Config) func(http.Handler) http.Handler { + allowAll := len(cfg.CORSOrigins) == 1 && cfg.CORSOrigins[0] == "*" + + allowed := make(map[string]bool, len(cfg.CORSOrigins)) + for _, o := range cfg.CORSOrigins { + allowed[o] = true + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") + var origin string + if allowAll { + origin = "*" + } else { + reqOrigin := r.Header.Get("Origin") + if allowed[reqOrigin] { + origin = reqOrigin + } + } - if origin != "" && originAllowed(origin, cfg) { + if origin != "" { h := w.Header() h.Set("Access-Control-Allow-Origin", origin) h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") h.Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key, Authorization") h.Set("Access-Control-Max-Age", "86400") - if cfg.CORSAllowCredentials { - h.Set("Access-Control-Allow-Credentials", "true") + if !allowAll { + h.Set("Vary", "Origin") } } @@ -35,36 +52,3 @@ func CORS(cfg *config.Config) func(http.Handler) http.Handler { }) } } - -func originAllowed(origin string, cfg *config.Config) bool { - if origin == "" { - return false - } - - u, err := url.Parse(origin) - if err != nil || u.Scheme == "" || u.Host == "" { - return false - } - - for _, allowed := range cfg.CORSExactOrigins() { - if origin == allowed { - return true - } - } - - if cfg.CORSAllowLocalhost && isLoopbackOrigin(u) { - return true - } - - return false -} - -func isLoopbackOrigin(u *url.URL) bool { - host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".") - switch host { - case "localhost", "127.0.0.1", "::1": - return true - default: - return false - } -} diff --git a/internal/http/server_test.go b/internal/http/server_test.go index 920c4f3..bdbfbe9 100644 --- a/internal/http/server_test.go +++ b/internal/http/server_test.go @@ -48,6 +48,7 @@ func newFixture(t *testing.T) *fixture { Expiry: config.ExpiryConfig{GraceDays: 30}, ReservedUsernames: []string{"admin", "root"}, RateLimitPerMin: 0, // disabled in tests + CORSOrigins: []string{"*"}, } tmpls, _ := messages.Load("/nonexistent.yaml") users := user.NewService(user.NewRepo(d), cfg.ReservedUsernames) @@ -553,6 +554,29 @@ func TestDocsPage(t *testing.T) { } } +func TestCORSHeader(t *testing.T) { + f := newFixture(t) + + resp, _ := f.get(t, "/healthz") + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("expected Access-Control-Allow-Origin=*, got %q", got) + } + + req, _ := http.NewRequest("OPTIONS", f.srv.URL+"/v1/pricing", nil) + req.Header.Set("Origin", "https://random-frontend.example") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + t.Errorf("expected 204 on OPTIONS preflight, got %d", resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("expected Access-Control-Allow-Origin=*, got %q", got) + } +} + func TestBodyLimit(t *testing.T) { f := newFixture(t) huge := bytes.Repeat([]byte("a"), 2<<20) // 2 MiB