Files
Nip-05-api/internal/invoice/repo.go
2026-04-29 02:35:00 +00:00

164 lines
5.2 KiB
Go

package invoice
import (
"context"
"database/sql"
"errors"
"time"
"github.com/noderunners/nip05api/internal/db"
"github.com/noderunners/nip05api/internal/user"
)
// Repo reads and writes rows of the pending_invoices table through the
// project's db.DB handle.
type Repo struct {
	db *db.DB
}
// NewRepo returns a Repo that issues all of its queries through d.
func NewRepo(d *db.DB) *Repo {
	repo := &Repo{db: d}
	return repo
}
// invCols is the column list used when selecting pending invoices. Its order
// must match the Scan destinations in scanInvoice exactly.
const invCols = "payment_hash, payment_request, username, pubkey, subscription_type,\n" +
	"years, amount_sats, expires_at, paid, is_renewal, created_at, target_expires_at"
// scanInvoice decodes one pending_invoices row, selected in invCols column
// order, into a PendingInvoice.
//
// Timestamps are stored as text. expires_at is parsed as RFC 3339, and
// created_at as RFC 3339 or "2006-01-02 15:04:05". If either fails to parse,
// the corresponding field is left at its zero value. paid and is_renewal are
// true only when the stored integer is exactly 1.
//
// target_expires_at has three states:
//   - NULL: no target chosen yet (TargetSet=false).
//   - "": target explicitly set to nil (TargetSet=true, TargetExpiresAt=nil).
//     SetTargetIfUnset documents this as the lifetime encoding.
//   - anything else: must be an RFC 3339 timestamp.
//
// A non-empty target that does not parse is returned as an error. It is not
// silently read as a nil target, because that would turn corrupt data into
// the lifetime value.
//
// Scan errors, including sql.ErrNoRows, are returned unwrapped so callers can
// match them with errors.Is. On any error the returned invoice is nil.
func scanInvoice(row interface{ Scan(...any) error }) (*PendingInvoice, error) {
	var p PendingInvoice
	var sub, expires, created string
	var paid, renewal int
	var target sql.NullString
	if err := row.Scan(&p.PaymentHash, &p.PaymentRequest, &p.Username, &p.Pubkey,
		&sub, &p.Years, &p.AmountSats, &expires, &paid, &renewal, &created, &target); err != nil {
		return nil, err
	}
	p.SubscriptionType = user.SubscriptionType(sub)
	if t, err := time.Parse(time.RFC3339, expires); err == nil {
		p.ExpiresAt = t
	}
	// created_at is not written by Insert. Accept RFC 3339 as well as the
	// space-separated "YYYY-MM-DD HH:MM:SS" form.
	if t, err := time.Parse(time.RFC3339, created); err == nil {
		p.CreatedAt = t
	} else if t, err := time.Parse("2006-01-02 15:04:05", created); err == nil {
		p.CreatedAt = t
	}
	p.Paid = paid == 1
	p.IsRenewal = renewal == 1
	if target.Valid {
		p.TargetSet = true
		if target.String != "" {
			t, err := time.Parse(time.RFC3339, target.String)
			if err != nil {
				// Falling through would leave TargetExpiresAt nil, which is
				// indistinguishable from an explicitly stored "" (lifetime).
				return nil, err
			}
			p.TargetExpiresAt = &t
		}
	}
	return &p, nil
}
// Insert stores a new pending invoice row. created_at is not part of the
// column list, so its value comes from the table definition (presumably a
// default; confirm against the schema).
//
// Timestamps are written as RFC 3339 in UTC. target_expires_at is stored in
// one of three ways:
//   - SQL NULL when p.TargetSet is false.
//   - "" when the target is set but p.TargetExpiresAt is nil.
//   - the formatted timestamp otherwise.
//
// paid and is_renewal are stored as 0/1.
func (r *Repo) Insert(ctx context.Context, p *PendingInvoice) error {
	// A nil interface value is sent to the driver as SQL NULL.
	var target any
	if p.TargetSet {
		target = ""
		if p.TargetExpiresAt != nil {
			target = p.TargetExpiresAt.UTC().Format(time.RFC3339)
		}
	}
	const stmt = `INSERT INTO pending_invoices
(payment_hash, payment_request, username, pubkey, subscription_type,
years, amount_sats, expires_at, paid, is_renewal, target_expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	expires := p.ExpiresAt.UTC().Format(time.RFC3339)
	_, err := r.db.ExecContext(ctx, stmt,
		p.PaymentHash, p.PaymentRequest, p.Username, p.Pubkey,
		string(p.SubscriptionType), p.Years, p.AmountSats,
		expires,
		boolToInt(p.Paid), boolToInt(p.IsRenewal), target)
	return err
}
func (r *Repo) Get(ctx context.Context, hash string) (*PendingInvoice, error) {
row := r.db.QueryRowContext(ctx, `SELECT `+invCols+` FROM pending_invoices WHERE payment_hash = ?`, hash)
p, err := scanInvoice(row)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrInvoiceNotFound
}
return p, err
}
func (r *Repo) MarkPaid(ctx context.Context, hash string) error {
_, err := r.db.ExecContext(ctx, `UPDATE pending_invoices SET paid = 1 WHERE payment_hash = ?`, hash)
return err
}
// SetTargetIfUnset writes target_expires_at only when it is currently NULL.
// It returns true if this call performed the write (won the race). It returns
// false if a target was already stored or no invoice matches hash.
//
// Lifetime (a nil target) is encoded as the empty string, so callers can
// distinguish "not yet set" (NULL) from "set to nil" (""). Non-nil targets
// are stored as RFC 3339 in UTC.
func (r *Repo) SetTargetIfUnset(ctx context.Context, hash string, target *time.Time) (bool, error) {
	stored := ""
	if target != nil {
		stored = target.UTC().Format(time.RFC3339)
	}
	res, err := r.db.ExecContext(ctx,
		`UPDATE pending_invoices SET target_expires_at = ? WHERE payment_hash = ? AND target_expires_at IS NULL`,
		stored, hash)
	if err != nil {
		return false, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		// Without a row count we cannot tell whether this call won the race.
		// Surface the error instead of silently reporting "lost".
		return false, err
	}
	return n == 1, nil
}
// ClaimPaid transitions paid from 0 to 1 with a single conditional UPDATE.
// It returns true only if this call performed the transition, meaning the
// invoice was unpaid immediately before. It returns false if the invoice was
// already paid or no invoice matches hash.
func (r *Repo) ClaimPaid(ctx context.Context, hash string) (bool, error) {
	res, err := r.db.ExecContext(ctx,
		`UPDATE pending_invoices SET paid = 1 WHERE payment_hash = ? AND paid = 0`, hash)
	if err != nil {
		return false, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		// Without a row count we cannot tell whether this call made the
		// transition. Surface the error instead of reporting "not claimed".
		return false, err
	}
	return n == 1, nil
}
// ListUnpaid returns every unpaid invoice whose expires_at is later than the
// current time. On success the slice is non-nil, possibly empty.
//
// The cutoff uses the same RFC 3339 UTC formatting that Insert uses for
// expires_at, so the SQL comparison is between like-formatted timestamps.
// If row iteration fails part-way, the invoices read so far are returned
// together with rows.Err().
func (r *Repo) ListUnpaid(ctx context.Context) ([]*PendingInvoice, error) {
	now := time.Now().UTC().Format(time.RFC3339)
	rows, err := r.db.QueryContext(ctx, `SELECT `+invCols+` FROM pending_invoices
WHERE paid = 0 AND expires_at > ?`, now)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	invoices := make([]*PendingInvoice, 0)
	for rows.Next() {
		inv, scanErr := scanInvoice(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		invoices = append(invoices, inv)
	}
	return invoices, rows.Err()
}
// HasUnpaidForUsername reports whether at least one unpaid invoice with an
// expires_at later than now exists for username. The username comparison
// uses the COLLATE NOCASE collation. On error it returns false and the error.
func (r *Repo) HasUnpaidForUsername(ctx context.Context, username string) (bool, error) {
	now := time.Now().UTC().Format(time.RFC3339)
	var n int
	if err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM pending_invoices
WHERE username = ? COLLATE NOCASE AND paid = 0 AND expires_at > ?`,
		username, now).Scan(&n); err != nil {
		return false, err
	}
	return n > 0, nil
}
// HasUnpaidForPubkey reports whether at least one unpaid invoice with an
// expires_at later than now exists for pubkey. The pubkey is matched with
// plain equality (no collation override). On error it returns false and the
// error.
func (r *Repo) HasUnpaidForPubkey(ctx context.Context, pubkey string) (bool, error) {
	now := time.Now().UTC().Format(time.RFC3339)
	var n int
	if err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM pending_invoices
WHERE pubkey = ? AND paid = 0 AND expires_at > ?`,
		pubkey, now).Scan(&n); err != nil {
		return false, err
	}
	return n > 0, nil
}
// PurgeOldUnpaid deletes unpaid invoices whose expires_at is more than one
// hour in the past. Paid invoices are never removed by this method.
func (r *Repo) PurgeOldUnpaid(ctx context.Context) error {
	// Keep a one-hour grace window after expiry before deleting.
	cutoff := time.Now().UTC().Add(-time.Hour)
	_, err := r.db.ExecContext(ctx,
		`DELETE FROM pending_invoices WHERE paid = 0 AND expires_at < ?`,
		cutoff.Format(time.RFC3339))
	return err
}
// boolToInt encodes b as 1 (true) or 0 (false), the integer form used for the
// paid and is_renewal columns.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}