From a01797e9b2287add2c9c5c4b67a398dcf4b62028 Mon Sep 17 00:00:00 2001 From: Michilis Date: Wed, 29 Apr 2026 05:44:59 +0000 Subject: [PATCH] Improve CORS origin handling; extend invoice repo/service and payments dispatch; rate limit and nginx config updates Made-with: Love --- .env.example | 6 ++ deploy/nginx.conf | 3 +- internal/config/config.go | 27 +++++++++ internal/http/docs/openapi.yaml | 7 ++- internal/http/handlers/invoices.go | 2 - internal/http/middleware/cors.go | 80 ++++++++++++++++++++++----- internal/http/middleware/ratelimit.go | 6 ++ internal/http/server.go | 2 +- internal/invoice/repo.go | 25 +++++++++ internal/invoice/service.go | 24 +++++++- internal/invoice/service_test.go | 66 +++++++++++++++++----- internal/payments/dispatch.go | 11 ++++ 12 files changed, 224 insertions(+), 35 deletions(-) diff --git a/.env.example b/.env.example index 217d85f..d676179 100644 --- a/.env.example +++ b/.env.example @@ -4,6 +4,12 @@ PORT=8080 ADMIN_API_KEY=change-me-to-a-long-random-string FRONTEND_URL=https://azzamo.net/nip05 +# Optional extra browser origins (comma-separated). Merged with FRONTEND_URL for CORS. +# CORS_ORIGINS= + +# Allow http(s)://localhost:* and 127.0.0.1 for local UI dev hitting this API directly (not via Vite proxy). +CORS_ALLOW_LOCALHOST=true + # --- Database --- DATABASE_PATH=.data/nip05.db diff --git a/deploy/nginx.conf b/deploy/nginx.conf index 0df489d..d705d0f 100644 --- a/deploy/nginx.conf +++ b/deploy/nginx.conf @@ -45,7 +45,8 @@ server { proxy_http_version 1.1; proxy_set_header Connection ""; - add_header Access-Control-Allow-Origin "*" always; + # Do not set CORS headers here — nip05api sends a single reflected Origin (see FRONTEND_URL / CORS_* env). + # Duplicate ACAO headers break browsers ("multiple values"). add_header Cache-Control "public, max-age=60" always; } diff --git a/internal/config/config.go b/internal/config/config.go index 91fca7a..d5c48ce 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -60,6 +60,11 @@ type Config struct { LogLevel string RateLimitPerMin int ReservedUsernames []string + + // CORS: exact origin list = FRONTEND_URL ∪ CORS_ORIGINS; loopback hosts if CORS_ALLOW_LOCALHOST. + CORSExtraOrigins []string + CORSAllowLocalhost bool + CORSAllowCredentials bool } func Load() (*Config, error) { @@ -104,6 +109,9 @@ func Load() (*Config, error) { LogLevel: env("LOG_LEVEL", "info"), RateLimitPerMin: envInt("RATE_LIMIT_PER_MIN", 30), ReservedUsernames: csv(env("RESERVED_USERNAMES", "")), + CORSExtraOrigins: csv(env("CORS_ORIGINS", "")), + CORSAllowLocalhost: envBool("CORS_ALLOW_LOCALHOST", true), + CORSAllowCredentials: envBool("CORS_ALLOW_CREDENTIALS", false), } if err := Validate(c); err != nil { @@ -169,3 +177,22 @@ func csvInt(v string) []int { } func (c *Config) Addr() string { return fmt.Sprintf(":%d", c.Port) } + +// CORSExactOrigins lists allowed browser Origins for exact match (before loopback wildcard). +func (c *Config) CORSExactOrigins() []string { + seen := make(map[string]bool) + out := make([]string, 0, 4+len(c.CORSExtraOrigins)) + add := func(s string) { + s = strings.TrimSpace(s) + if s == "" || seen[s] { + return + } + seen[s] = true + out = append(out, s) + } + add(c.FrontendURL) + for _, o := range c.CORSExtraOrigins { + add(o) + } + return out +} diff --git a/internal/http/docs/openapi.yaml b/internal/http/docs/openapi.yaml index d26c9a3..857519c 100644 --- a/internal/http/docs/openapi.yaml +++ b/internal/http/docs/openapi.yaml @@ -112,6 +112,11 @@ paths: post: tags: [User] summary: Create payment invoice + description: | + Creates a new Lightning invoice. If this pubkey already has an unpaid, unexpired invoice + for the same subscription_type (and same years when yearly), that invoice is returned + (idempotent resume). If an unpaid invoice exists for a different plan, it is discarded and + a new invoice is created for the requested plan. requestBody: required: true content: @@ -140,7 +145,7 @@ paths: schema: { $ref: '#/components/schemas/Invoice' } '400': { description: Validation error, content: { application/json: { schema: { $ref: '#/components/schemas/Error' } } } } '403': { description: Forbidden — user already has lifetime access, content: { application/json: { schema: { $ref: '#/components/schemas/Error' } } } } - '409': { description: Conflict — username unavailable or pending invoice already exists, content: { application/json: { schema: { $ref: '#/components/schemas/Error' } } } } + '409': { description: Conflict — username unavailable, content: { application/json: { schema: { $ref: '#/components/schemas/Error' } } } } '503': { description: Lightning unavailable, content: { application/json: { schema: { $ref: '#/components/schemas/Error' } } } } /v1/invoices/{payment_hash}: get: diff --git a/internal/http/handlers/invoices.go b/internal/http/handlers/invoices.go index 34732cd..114e46c 100644 --- a/internal/http/handlers/invoices.go +++ b/internal/http/handlers/invoices.go @@ -63,8 +63,6 @@ func (h *Invoices) Create(w http.ResponseWriter, r *http.Request) { switch { case errors.Is(err, invoice.ErrLifetimeAccess): WriteError(w, http.StatusForbidden, "User already has lifetime access", "") - case errors.Is(err, invoice.ErrPendingInvoiceExists): - WriteError(w, http.StatusConflict, "Conflict", err.Error()) case errors.Is(err, invoice.ErrUsernameTaken), errors.Is(err, user.ErrUsernameTaken): WriteError(w, http.StatusConflict, "Conflict", "username unavailable") diff --git a/internal/http/middleware/cors.go b/internal/http/middleware/cors.go index df16d94..daafc81 100644 --- a/internal/http/middleware/cors.go +++ b/internal/http/middleware/cors.go @@ -1,18 +1,70 @@ package middleware -import "net/http" +import ( + "net/http" + "net/url" + "strings" -func CORS(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() - h.Set("Access-Control-Allow-Origin", "*") - h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - h.Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key, Authorization") - h.Set("Access-Control-Max-Age", "86400") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - next.ServeHTTP(w, r) - }) + "github.com/noderunners/nip05api/internal/config" +) + +// CORS sends at most one Access-Control-Allow-Origin value (echo of request Origin). +// Configure FRONTEND_URL, optional CORS_ORIGINS, and CORS_ALLOW_LOCALHOST / CORS_ALLOW_CREDENTIALS. +func CORS(cfg *config.Config) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + if origin != "" && originAllowed(origin, cfg) { + h := w.Header() + h.Set("Access-Control-Allow-Origin", origin) + h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + h.Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key, Authorization") + h.Set("Access-Control-Max-Age", "86400") + if cfg.CORSAllowCredentials { + h.Set("Access-Control-Allow-Credentials", "true") + } + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func originAllowed(origin string, cfg *config.Config) bool { + if origin == "" { + return false + } + + u, err := url.Parse(origin) + if err != nil || u.Scheme == "" || u.Host == "" { + return false + } + + for _, allowed := range cfg.CORSExactOrigins() { + if origin == allowed { + return true + } + } + + if cfg.CORSAllowLocalhost && isLoopbackOrigin(u) { + return true + } + + return false +} + +func isLoopbackOrigin(u *url.URL) bool { + host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".") + switch host { + case "localhost", "127.0.0.1", "::1": + return true + default: + return false + } } diff --git a/internal/http/middleware/ratelimit.go b/internal/http/middleware/ratelimit.go index 8211b24..e57d419 100644 --- a/internal/http/middleware/ratelimit.go +++ b/internal/http/middleware/ratelimit.go @@ -10,6 +10,8 @@ import ( // RateLimit returns a middleware that limits requests per minute by IP. // Admin routes are skipped. +// GET /v1/invoices/{hash} is skipped: the SPA polls invoice status ~30/min while +// the default global limit is 30/min, which starves pricing and user lookups on the same IP. func RateLimit(perMin int) func(http.Handler) http.Handler { if perMin <= 0 { return func(next http.Handler) http.Handler { return next } @@ -21,6 +23,10 @@ func RateLimit(perMin int) func(http.Handler) http.Handler { next.ServeHTTP(w, r) return } + if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/v1/invoices/") { + next.ServeHTTP(w, r) + return + } limiter(next).ServeHTTP(w, r) }) } diff --git a/internal/http/server.go b/internal/http/server.go index 15243a4..f6c81b0 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -34,7 +34,7 @@ func NewServer(d Deps) *http.Server { r.Use(middleware.Recoverer) r.Use(middleware.RealIP) r.Use(middleware.Logging) - r.Use(middleware.CORS) + r.Use(middleware.CORS(d.Cfg)) r.Use(middleware.BodyLimit(1 << 20)) // 1 MiB max request body r.Use(middleware.RateLimit(d.Cfg.RateLimitPerMin)) diff --git a/internal/invoice/repo.go b/internal/invoice/repo.go index bd1f6cb..ed65b42 100644 --- a/internal/invoice/repo.go +++ b/internal/invoice/repo.go @@ -149,6 +149,31 @@ func (r *Repo) HasUnpaidForPubkey(ctx context.Context, pubkey string) (bool, err return count > 0, err } +// GetActiveUnpaidByPubkey returns the most recent unpaid, unexpired invoice for the pubkey, or nil if none. +func (r *Repo) GetActiveUnpaidByPubkey(ctx context.Context, pubkey string) (*PendingInvoice, error) { + row := r.db.QueryRowContext(ctx, `SELECT `+invCols+` FROM pending_invoices + WHERE pubkey = ? AND paid = 0 AND expires_at > ? + ORDER BY created_at DESC LIMIT 1`, + pubkey, time.Now().UTC().Format(time.RFC3339)) + p, err := scanInvoice(row) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + return p, nil +} + +// DeleteActiveUnpaidForPubkey removes all unpaid, unexpired invoices for the pubkey so a new +// invoice can be issued when the user switches plan (replacing the previous Bolt11). +func (r *Repo) DeleteActiveUnpaidForPubkey(ctx context.Context, pubkey string) error { + _, err := r.db.ExecContext(ctx, + `DELETE FROM pending_invoices WHERE pubkey = ? AND paid = 0 AND expires_at > ?`, + pubkey, time.Now().UTC().Format(time.RFC3339)) + return err +} + func (r *Repo) PurgeOldUnpaid(ctx context.Context) error { cutoff := time.Now().UTC().Add(-1 * time.Hour).Format(time.RFC3339) _, err := r.db.ExecContext(ctx, `DELETE FROM pending_invoices WHERE paid = 0 AND expires_at < ?`, cutoff) diff --git a/internal/invoice/service.go b/internal/invoice/service.go index e868ebd..66fdfed 100644 --- a/internal/invoice/service.go +++ b/internal/invoice/service.go @@ -43,9 +43,20 @@ var ( ErrUsernameTaken = errors.New("username taken") ErrInvalidYears = errors.New("invalid years") ErrLifetimeAccess = errors.New("user already has lifetime access") + // Deprecated: Create no longer returns this; an existing unpaid invoice is returned instead. ErrPendingInvoiceExists = errors.New("pending unpaid invoice already exists") ) +func pendingMatchesRequest(p *PendingInvoice, req CreateRequest) bool { + if p.SubscriptionType != req.SubscriptionType { + return false + } + if req.SubscriptionType == user.SubYearly { + return p.Years == req.Years + } + return true +} + // Create computes amount, calls LNbits, persists pending invoice. Detects renewal. func (s *Service) Create(ctx context.Context, req CreateRequest) (*PendingInvoice, error) { if !req.SubscriptionType.Valid() { @@ -62,12 +73,19 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (*PendingInvoic req.Years = 0 } - hasPendingPubkey, err := s.repo.HasUnpaidForPubkey(ctx, req.Pubkey) + pendingExisting, err := s.repo.GetActiveUnpaidByPubkey(ctx, req.Pubkey) if err != nil { return nil, err } - if hasPendingPubkey { - return nil, ErrPendingInvoiceExists + if pendingExisting != nil { + if pendingMatchesRequest(pendingExisting, req) { + // Idempotent resume: same Bolt11 until paid or LN invoice expiry. + return pendingExisting, nil + } + // Replace pending Bolt11 with one for the newly requested plan. + if err := s.repo.DeleteActiveUnpaidForPubkey(ctx, req.Pubkey); err != nil { + return nil, err + } } username := user.NormalizeUsername(req.Username) diff --git a/internal/invoice/service_test.go b/internal/invoice/service_test.go index 02b3d33..1609944 100644 --- a/internal/invoice/service_test.go +++ b/internal/invoice/service_test.go @@ -127,26 +127,66 @@ func TestCreate_BlocksActiveLifetimeUser(t *testing.T) { } } -func TestCreate_RejectsDuplicatePending(t *testing.T) { +func TestCreate_ResumesDuplicatePending(t *testing.T) { svc, _, _ := newServiceFixture(t) - if _, err := svc.Create(context.Background(), invoice.CreateRequest{ - Pubkey: testHex, - Username: "alice", - SubscriptionType: user.SubYearly, - Years: 1, - }); err != nil { - t.Fatalf("first create: %v", err) - } - - _, err := svc.Create(context.Background(), invoice.CreateRequest{ + first, err := svc.Create(context.Background(), invoice.CreateRequest{ Pubkey: testHex, Username: "alice", SubscriptionType: user.SubYearly, Years: 1, }) - if !errors.Is(err, invoice.ErrPendingInvoiceExists) { - t.Fatalf("expected ErrPendingInvoiceExists, got %v", err) + if err != nil { + t.Fatalf("first create: %v", err) + } + + second, err := svc.Create(context.Background(), invoice.CreateRequest{ + Pubkey: testHex, + Username: "alice", + SubscriptionType: user.SubYearly, + Years: 1, + }) + if err != nil { + t.Fatalf("second create should resume pending: %v", err) + } + if second.PaymentHash != first.PaymentHash || second.PaymentRequest != first.PaymentRequest { + t.Fatalf("expected same pending invoice on resume, got hash=%q vs %q", second.PaymentHash, first.PaymentHash) + } +} + +func TestCreate_SupersedesDifferentPendingPlan(t *testing.T) { + svc, _, _ := newServiceFixture(t) + + first, err := svc.Create(context.Background(), invoice.CreateRequest{ + Pubkey: testHex, + Username: "alice", + SubscriptionType: user.SubLifetime, + }) + if err != nil { + t.Fatalf("lifetime pending: %v", err) + } + + second, err := svc.Create(context.Background(), invoice.CreateRequest{ + Pubkey: testHex, + Username: "alice", + SubscriptionType: user.SubYearly, + Years: 1, + }) + if err != nil { + t.Fatalf("yearly after lifetime pending: %v", err) + } + stored, err := svc.Repo().Get(context.Background(), second.PaymentHash) + if err != nil { + t.Fatal(err) + } + if stored.SubscriptionType != user.SubYearly || stored.Years != 1 || stored.AmountSats != 1000 { + t.Fatalf("DB row should reflect yearly plan after supersede: %+v", stored) + } + if second.SubscriptionType != user.SubYearly || second.Years != 1 || second.AmountSats != 1000 { + t.Fatalf("unexpected yearly invoice: %+v", second) + } + if first.SubscriptionType == second.SubscriptionType || first.AmountSats == second.AmountSats { + t.Fatalf("expected plan switch lifetime -> yearly") } } diff --git a/internal/payments/dispatch.go b/internal/payments/dispatch.go index fd983ae..1a623cb 100644 --- a/internal/payments/dispatch.go +++ b/internal/payments/dispatch.go @@ -2,6 +2,7 @@ package payments import ( "context" + "fmt" "log/slog" "strconv" "time" @@ -19,14 +20,24 @@ func (w *Worker) dispatchEvents(ctx context.Context, u *user.User, p *invoice.Pe if err := w.dms.Send(ctx, ev, u.Pubkey, vars); err != nil { slog.Error("dm enqueue", "err", err) } + confirmed := time.Now().UTC() data := map[string]any{ "pubkey": u.Pubkey, "npub": nostr.HexToNpub(u.Pubkey), "username": u.Username, "subscription_type": string(u.SubscriptionType), + "is_lifetime": u.IsLifetime(), + "years": p.Years, "amount_sats": p.AmountSats, "payment_hash": p.PaymentHash, "is_renewal": p.IsRenewal, + "confirmed_at": confirmed.Format(time.RFC3339), + "confirmed_at_unix": confirmed.Unix(), + } + if u.IsLifetime() { + data["duration"] = "lifetime" + } else { + data["duration"] = fmt.Sprintf("%dy", p.Years) } if u.ExpiresAt != nil { data["expires_at"] = u.ExpiresAt.UTC().Format(time.RFC3339)