package db import ( "context" "embed" "fmt" "io/fs" "log/slog" "sort" "strconv" "strings" "time" ) //go:embed migrations/*.sql var migrationsFS embed.FS type migration struct { version int name string sql string } func (d *DB) Migrate(ctx context.Context) error { if _, err := d.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( version INTEGER PRIMARY KEY, applied_at TEXT NOT NULL )`); err != nil { return fmt.Errorf("ensure schema_migrations: %w", err) } migs, err := loadMigrations() if err != nil { return err } applied, err := d.appliedVersions(ctx) if err != nil { return err } for _, m := range migs { if applied[m.version] { continue } slog.Info("applying migration", "version", m.version, "name", m.name) if err := d.applyOne(ctx, m); err != nil { return fmt.Errorf("apply %d %s: %w", m.version, m.name, err) } } return nil } func (d *DB) applyOne(ctx context.Context, m migration) error { tx, err := d.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() if _, err := tx.ExecContext(ctx, m.sql); err != nil { return err } if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations(version, applied_at) VALUES(?, ?)`, m.version, time.Now().UTC().Format(time.RFC3339)); err != nil { return err } return tx.Commit() } func (d *DB) appliedVersions(ctx context.Context) (map[int]bool, error) { rows, err := d.QueryContext(ctx, `SELECT version FROM schema_migrations`) if err != nil { return nil, err } defer rows.Close() out := map[int]bool{} for rows.Next() { var v int if err := rows.Scan(&v); err != nil { return nil, err } out[v] = true } return out, rows.Err() } func loadMigrations() ([]migration, error) { entries, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { return nil, fmt.Errorf("read migrations: %w", err) } out := make([]migration, 0, len(entries)) for _, e := range entries { if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") { continue } parts := strings.SplitN(e.Name(), "_", 2) if len(parts) != 2 { continue } v, err := strconv.Atoi(parts[0]) if err != nil { continue } b, err := migrationsFS.ReadFile("migrations/" + e.Name()) if err != nil { return nil, err } out = append(out, migration{version: v, name: e.Name(), sql: string(b)}) } sort.Slice(out, func(i, j int) bool { return out[i].version < out[j].version }) return out, nil }