package invoice import ( "context" "database/sql" "errors" "time" "github.com/noderunners/nip05api/internal/db" "github.com/noderunners/nip05api/internal/user" ) type Repo struct{ db *db.DB } func NewRepo(d *db.DB) *Repo { return &Repo{db: d} } const invCols = `payment_hash, payment_request, username, pubkey, subscription_type, years, amount_sats, expires_at, paid, is_renewal, created_at, target_expires_at` func scanInvoice(row interface{ Scan(...any) error }) (*PendingInvoice, error) { var p PendingInvoice var sub, expires, created string var paid, renewal int var target sql.NullString if err := row.Scan(&p.PaymentHash, &p.PaymentRequest, &p.Username, &p.Pubkey, &sub, &p.Years, &p.AmountSats, &expires, &paid, &renewal, &created, &target); err != nil { return nil, err } p.SubscriptionType = user.SubscriptionType(sub) if t, err := time.Parse(time.RFC3339, expires); err == nil { p.ExpiresAt = t } if t, err := time.Parse(time.RFC3339, created); err == nil { p.CreatedAt = t } else if t, err := time.Parse("2006-01-02 15:04:05", created); err == nil { p.CreatedAt = t } p.Paid = paid == 1 p.IsRenewal = renewal == 1 if target.Valid { p.TargetSet = true if target.String != "" { if t, err := time.Parse(time.RFC3339, target.String); err == nil { p.TargetExpiresAt = &t } } } return &p, nil } func (r *Repo) Insert(ctx context.Context, p *PendingInvoice) error { var target any if p.TargetSet { if p.TargetExpiresAt != nil { target = p.TargetExpiresAt.UTC().Format(time.RFC3339) } else { target = "" } } _, err := r.db.ExecContext(ctx, `INSERT INTO pending_invoices (payment_hash, payment_request, username, pubkey, subscription_type, years, amount_sats, expires_at, paid, is_renewal, target_expires_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, p.PaymentHash, p.PaymentRequest, p.Username, p.Pubkey, string(p.SubscriptionType), p.Years, p.AmountSats, p.ExpiresAt.UTC().Format(time.RFC3339), boolToInt(p.Paid), boolToInt(p.IsRenewal), target) return err } func (r *Repo) Get(ctx context.Context, hash string) (*PendingInvoice, error) { row := r.db.QueryRowContext(ctx, `SELECT `+invCols+` FROM pending_invoices WHERE payment_hash = ?`, hash) p, err := scanInvoice(row) if errors.Is(err, sql.ErrNoRows) { return nil, ErrInvoiceNotFound } return p, err } func (r *Repo) MarkPaid(ctx context.Context, hash string) error { _, err := r.db.ExecContext(ctx, `UPDATE pending_invoices SET paid = 1 WHERE payment_hash = ?`, hash) return err } // SetTargetIfUnset writes target_expires_at only when currently NULL. // Returns true if this call won the race. Lifetime is encoded as empty string, // allowing the caller to distinguish "not yet set" (NULL) from "set to nil". func (r *Repo) SetTargetIfUnset(ctx context.Context, hash string, target *time.Time) (bool, error) { stored := "" if target != nil { stored = target.UTC().Format(time.RFC3339) } res, err := r.db.ExecContext(ctx, `UPDATE pending_invoices SET target_expires_at = ? WHERE payment_hash = ? AND target_expires_at IS NULL`, stored, hash) if err != nil { return false, err } n, _ := res.RowsAffected() return n == 1, nil } // ClaimPaid atomically transitions paid 0 → 1. Returns true if the caller // performed the transition (i.e. it was unpaid before this call). func (r *Repo) ClaimPaid(ctx context.Context, hash string) (bool, error) { res, err := r.db.ExecContext(ctx, `UPDATE pending_invoices SET paid = 1 WHERE payment_hash = ? AND paid = 0`, hash) if err != nil { return false, err } n, _ := res.RowsAffected() return n == 1, nil } func (r *Repo) ListUnpaid(ctx context.Context) ([]*PendingInvoice, error) { rows, err := r.db.QueryContext(ctx, `SELECT `+invCols+` FROM pending_invoices WHERE paid = 0 AND expires_at > ?`, time.Now().UTC().Format(time.RFC3339)) if err != nil { return nil, err } defer rows.Close() out := []*PendingInvoice{} for rows.Next() { p, err := scanInvoice(rows) if err != nil { return nil, err } out = append(out, p) } return out, rows.Err() } // HasUnpaidForUsername returns true if there is an active unpaid invoice for the username. func (r *Repo) HasUnpaidForUsername(ctx context.Context, username string) (bool, error) { var count int err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM pending_invoices WHERE username = ? COLLATE NOCASE AND paid = 0 AND expires_at > ?`, username, time.Now().UTC().Format(time.RFC3339)).Scan(&count) return count > 0, err } // HasUnpaidForPubkey returns true if there is an active unpaid invoice for the pubkey. func (r *Repo) HasUnpaidForPubkey(ctx context.Context, pubkey string) (bool, error) { var count int err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM pending_invoices WHERE pubkey = ? AND paid = 0 AND expires_at > ?`, pubkey, time.Now().UTC().Format(time.RFC3339)).Scan(&count) return count > 0, err } // GetActiveUnpaidByPubkey returns the most recent unpaid, unexpired invoice for the pubkey, or nil if none. func (r *Repo) GetActiveUnpaidByPubkey(ctx context.Context, pubkey string) (*PendingInvoice, error) { row := r.db.QueryRowContext(ctx, `SELECT `+invCols+` FROM pending_invoices WHERE pubkey = ? AND paid = 0 AND expires_at > ? ORDER BY created_at DESC LIMIT 1`, pubkey, time.Now().UTC().Format(time.RFC3339)) p, err := scanInvoice(row) if errors.Is(err, sql.ErrNoRows) { return nil, nil } if err != nil { return nil, err } return p, nil } // DeleteActiveUnpaidForPubkey removes all unpaid, unexpired invoices for the pubkey so a new // invoice can be issued when the user switches plan (replacing the previous Bolt11). func (r *Repo) DeleteActiveUnpaidForPubkey(ctx context.Context, pubkey string) error { _, err := r.db.ExecContext(ctx, `DELETE FROM pending_invoices WHERE pubkey = ? AND paid = 0 AND expires_at > ?`, pubkey, time.Now().UTC().Format(time.RFC3339)) return err } func (r *Repo) PurgeOldUnpaid(ctx context.Context) error { cutoff := time.Now().UTC().Add(-1 * time.Hour).Format(time.RFC3339) _, err := r.db.ExecContext(ctx, `DELETE FROM pending_invoices WHERE paid = 0 AND expires_at < ?`, cutoff) return err } func boolToInt(b bool) int { if b { return 1 } return 0 }