164 lines
5.2 KiB
Go
164 lines
5.2 KiB
Go
package invoice
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/noderunners/nip05api/internal/db"
|
|
"github.com/noderunners/nip05api/internal/user"
|
|
)
|
|
|
|
// Repo groups the pending-invoice queries in this file around a single
// db.DB handle.
type Repo struct {
	db *db.DB
}
|
|
|
|
// NewRepo returns a Repo that issues its queries through d.
func NewRepo(d *db.DB) *Repo {
	repo := new(Repo)
	repo.db = d
	return repo
}
|
|
|
|
// invCols is the column list used by every SELECT that feeds scanInvoice.
// The order here must match the order of the Scan destinations in
// scanInvoice.
const invCols = `payment_hash, payment_request, username, pubkey, subscription_type,
	years, amount_sats, expires_at, paid, is_renewal, created_at, target_expires_at`
|
|
|
|
func scanInvoice(row interface{ Scan(...any) error }) (*PendingInvoice, error) {
|
|
var p PendingInvoice
|
|
var sub, expires, created string
|
|
var paid, renewal int
|
|
var target sql.NullString
|
|
if err := row.Scan(&p.PaymentHash, &p.PaymentRequest, &p.Username, &p.Pubkey,
|
|
&sub, &p.Years, &p.AmountSats, &expires, &paid, &renewal, &created, &target); err != nil {
|
|
return nil, err
|
|
}
|
|
p.SubscriptionType = user.SubscriptionType(sub)
|
|
if t, err := time.Parse(time.RFC3339, expires); err == nil {
|
|
p.ExpiresAt = t
|
|
}
|
|
if t, err := time.Parse(time.RFC3339, created); err == nil {
|
|
p.CreatedAt = t
|
|
} else if t, err := time.Parse("2006-01-02 15:04:05", created); err == nil {
|
|
p.CreatedAt = t
|
|
}
|
|
p.Paid = paid == 1
|
|
p.IsRenewal = renewal == 1
|
|
if target.Valid {
|
|
p.TargetSet = true
|
|
if target.String != "" {
|
|
if t, err := time.Parse(time.RFC3339, target.String); err == nil {
|
|
p.TargetExpiresAt = &t
|
|
}
|
|
}
|
|
}
|
|
return &p, nil
|
|
}
|
|
|
|
func (r *Repo) Insert(ctx context.Context, p *PendingInvoice) error {
|
|
var target any
|
|
if p.TargetSet {
|
|
if p.TargetExpiresAt != nil {
|
|
target = p.TargetExpiresAt.UTC().Format(time.RFC3339)
|
|
} else {
|
|
target = ""
|
|
}
|
|
}
|
|
_, err := r.db.ExecContext(ctx, `INSERT INTO pending_invoices
|
|
(payment_hash, payment_request, username, pubkey, subscription_type,
|
|
years, amount_sats, expires_at, paid, is_renewal, target_expires_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
p.PaymentHash, p.PaymentRequest, p.Username, p.Pubkey,
|
|
string(p.SubscriptionType), p.Years, p.AmountSats,
|
|
p.ExpiresAt.UTC().Format(time.RFC3339),
|
|
boolToInt(p.Paid), boolToInt(p.IsRenewal), target)
|
|
return err
|
|
}
|
|
|
|
func (r *Repo) Get(ctx context.Context, hash string) (*PendingInvoice, error) {
|
|
row := r.db.QueryRowContext(ctx, `SELECT `+invCols+` FROM pending_invoices WHERE payment_hash = ?`, hash)
|
|
p, err := scanInvoice(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrInvoiceNotFound
|
|
}
|
|
return p, err
|
|
}
|
|
|
|
func (r *Repo) MarkPaid(ctx context.Context, hash string) error {
|
|
_, err := r.db.ExecContext(ctx, `UPDATE pending_invoices SET paid = 1 WHERE payment_hash = ?`, hash)
|
|
return err
|
|
}
|
|
|
|
// SetTargetIfUnset writes target_expires_at only when currently NULL.
|
|
// Returns true if this call won the race. Lifetime is encoded as empty string,
|
|
// allowing the caller to distinguish "not yet set" (NULL) from "set to nil".
|
|
func (r *Repo) SetTargetIfUnset(ctx context.Context, hash string, target *time.Time) (bool, error) {
|
|
stored := ""
|
|
if target != nil {
|
|
stored = target.UTC().Format(time.RFC3339)
|
|
}
|
|
res, err := r.db.ExecContext(ctx,
|
|
`UPDATE pending_invoices SET target_expires_at = ? WHERE payment_hash = ? AND target_expires_at IS NULL`,
|
|
stored, hash)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
return n == 1, nil
|
|
}
|
|
|
|
// ClaimPaid atomically transitions paid 0 → 1. Returns true if the caller
|
|
// performed the transition (i.e. it was unpaid before this call).
|
|
func (r *Repo) ClaimPaid(ctx context.Context, hash string) (bool, error) {
|
|
res, err := r.db.ExecContext(ctx,
|
|
`UPDATE pending_invoices SET paid = 1 WHERE payment_hash = ? AND paid = 0`, hash)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
return n == 1, nil
|
|
}
|
|
|
|
func (r *Repo) ListUnpaid(ctx context.Context) ([]*PendingInvoice, error) {
|
|
rows, err := r.db.QueryContext(ctx, `SELECT `+invCols+` FROM pending_invoices
|
|
WHERE paid = 0 AND expires_at > ?`,
|
|
time.Now().UTC().Format(time.RFC3339))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
out := []*PendingInvoice{}
|
|
for rows.Next() {
|
|
p, err := scanInvoice(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, p)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// HasUnpaidForUsername returns true if there is an active unpaid invoice for the username.
|
|
func (r *Repo) HasUnpaidForUsername(ctx context.Context, username string) (bool, error) {
|
|
var count int
|
|
err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM pending_invoices
|
|
WHERE username = ? COLLATE NOCASE AND paid = 0 AND expires_at > ?`,
|
|
username, time.Now().UTC().Format(time.RFC3339)).Scan(&count)
|
|
return count > 0, err
|
|
}
|
|
|
|
// HasUnpaidForPubkey returns true if there is an active unpaid invoice for the pubkey.
|
|
func (r *Repo) HasUnpaidForPubkey(ctx context.Context, pubkey string) (bool, error) {
|
|
var count int
|
|
err := r.db.QueryRowContext(ctx, `SELECT COUNT(1) FROM pending_invoices
|
|
WHERE pubkey = ? AND paid = 0 AND expires_at > ?`,
|
|
pubkey, time.Now().UTC().Format(time.RFC3339)).Scan(&count)
|
|
return count > 0, err
|
|
}
|
|
|
|
func (r *Repo) PurgeOldUnpaid(ctx context.Context) error {
|
|
cutoff := time.Now().UTC().Add(-1 * time.Hour).Format(time.RFC3339)
|
|
_, err := r.db.ExecContext(ctx, `DELETE FROM pending_invoices WHERE paid = 0 AND expires_at < ?`, cutoff)
|
|
return err
|
|
}
|
|
|
|
// boolToInt maps true to 1 and false to 0, the integer form used for the
// paid and is_renewal columns.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
|