327 lines
7.7 KiB
Go
327 lines
7.7 KiB
Go
// Package gomigrate is a simple database migrator for PostgreSQL.
|
|
|
|
package gomigrate
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"errors"
|
|
"io/ioutil"
|
|
"log"
|
|
"path/filepath"
|
|
"sort"
|
|
)
|
|
|
|
// migrationType selects which file of a migration pair to run. The values
// defined in this file are upMigration ("up") and downMigration ("down").
type migrationType string
|
|
|
|
const (
|
|
migrationTableName = "gomigrate"
|
|
upMigration = migrationType("up")
|
|
downMigration = migrationType("down")
|
|
)
|
|
|
|
// Sentinel errors exposed by the migrator.
var (
	// InvalidMigrationFile is the error for an unusable migration file.
	// NOTE(review): not returned anywhere in this file; confirm its users.
	InvalidMigrationFile = errors.New("Invalid migration file")

	// InvalidMigrationPair is returned by fetchMigrations when a migration
	// fails its validity check (e.g. an up file without a matching down).
	InvalidMigrationPair = errors.New("Invalid pair of migration files")

	// InvalidMigrationsPath is the error for an unusable migrations path.
	InvalidMigrationsPath = errors.New("Invalid migrations path")

	// InvalidMigrationType is returned by ApplyMigration when the requested
	// direction is neither upMigration nor downMigration.
	InvalidMigrationType = errors.New("Invalid migration type")

	// NoActiveMigrations is returned by RollbackN when no migration is
	// currently active.
	NoActiveMigrations = errors.New("No active migrations to rollback")
)
|
|
|
|
type Migrator struct {
|
|
DB *sql.DB
|
|
MigrationsPath string
|
|
dbAdapter Migratable
|
|
migrations map[uint64]*Migration
|
|
}
|
|
|
|
// Returns true if the migration table already exists.
|
|
func (m *Migrator) MigrationTableExists() (bool, error) {
|
|
row := m.DB.QueryRow(m.dbAdapter.SelectMigrationTableSql(), migrationTableName)
|
|
var tableName string
|
|
err := row.Scan(&tableName)
|
|
if err == sql.ErrNoRows {
|
|
log.Print("Migrations table not found")
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
log.Printf("Error checking for migration table: %v", err)
|
|
return false, err
|
|
}
|
|
log.Print("Migrations table found")
|
|
return true, nil
|
|
}
|
|
|
|
// Creates the migrations table if it doesn't exist.
|
|
func (m *Migrator) CreateMigrationsTable() error {
|
|
_, err := m.DB.Query(m.dbAdapter.CreateMigrationTableSql())
|
|
if err != nil {
|
|
log.Fatalf("Error creating migrations table: %v", err)
|
|
}
|
|
|
|
log.Printf("Created migrations table: %s", migrationTableName)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Returns a new migrator.
|
|
func NewMigrator(db *sql.DB, adapter Migratable, migrationsPath string) (*Migrator, error) {
|
|
// Normalize the migrations path.
|
|
path := []byte(migrationsPath)
|
|
pathLength := len(path)
|
|
if path[pathLength-1] != '/' {
|
|
path = append(path, '/')
|
|
}
|
|
|
|
log.Printf("Migrations path: %s", path)
|
|
|
|
migrator := Migrator{
|
|
db,
|
|
string(path),
|
|
adapter,
|
|
make(map[uint64]*Migration),
|
|
}
|
|
|
|
// Create the migrations table if it doesn't exist.
|
|
tableExists, err := migrator.MigrationTableExists()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !tableExists {
|
|
if err := migrator.CreateMigrationsTable(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Get all metadata from the database.
|
|
if err := migrator.fetchMigrations(); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := migrator.getMigrationStatuses(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &migrator, nil
|
|
}
|
|
|
|
// Populates a migrator with a sorted list of migrations from the file system.
|
|
func (m *Migrator) fetchMigrations() error {
|
|
pathGlob := append([]byte(m.MigrationsPath), []byte("*")...)
|
|
|
|
matches, err := filepath.Glob(string(pathGlob))
|
|
if err != nil {
|
|
log.Fatalf("Error while globbing migrations: %v", err)
|
|
}
|
|
|
|
for _, match := range matches {
|
|
num, migrationType, name, err := parseMigrationPath(match)
|
|
if err != nil {
|
|
log.Printf("Invalid migration file found: %s", match)
|
|
continue
|
|
}
|
|
|
|
log.Printf("Migration file found: %s", match)
|
|
|
|
migration, ok := m.migrations[num]
|
|
if !ok {
|
|
migration = &Migration{Id: num, Name: name, Status: Inactive}
|
|
m.migrations[num] = migration
|
|
}
|
|
if migrationType == upMigration {
|
|
migration.UpPath = match
|
|
} else {
|
|
migration.DownPath = match
|
|
}
|
|
}
|
|
|
|
// Validate each migration.
|
|
for _, migration := range m.migrations {
|
|
if !migration.valid() {
|
|
path := migration.UpPath
|
|
if path == "" {
|
|
path = migration.DownPath
|
|
}
|
|
log.Printf("Invalid migration pair for path: %s", path)
|
|
return InvalidMigrationPair
|
|
}
|
|
}
|
|
|
|
log.Printf("Migrations file pairs found: %v", len(m.migrations))
|
|
|
|
return nil
|
|
}
|
|
|
|
// Queries the migration table to determine the status of each
|
|
// migration.
|
|
func (m *Migrator) getMigrationStatuses() error {
|
|
for _, migration := range m.migrations {
|
|
row := m.DB.QueryRow(m.dbAdapter.GetMigrationSql(), migration.Id)
|
|
var mid uint64
|
|
err := row.Scan(&mid)
|
|
if err == sql.ErrNoRows {
|
|
continue
|
|
}
|
|
if err != nil {
|
|
log.Printf(
|
|
"Error getting migration status for %s: %v",
|
|
migration.Name,
|
|
err,
|
|
)
|
|
return err
|
|
}
|
|
migration.Status = Active
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Returns a sorted list of migration ids for a given status. -1 returns
|
|
// all migrations.
|
|
func (m *Migrator) Migrations(status int) []*Migration {
|
|
// Sort all migration ids.
|
|
ids := make([]uint64, 0)
|
|
for id, _ := range m.migrations {
|
|
ids = append(ids, id)
|
|
}
|
|
sort.Sort(uint64slice(ids))
|
|
|
|
// Find ids for the given status.
|
|
migrations := make([]*Migration, 0)
|
|
for _, id := range ids {
|
|
migration := m.migrations[id]
|
|
if status == -1 || migration.Status == status {
|
|
migrations = append(migrations, migration)
|
|
}
|
|
}
|
|
return migrations
|
|
}
|
|
|
|
// Applies a single migration.
|
|
func (m *Migrator) ApplyMigration(migration *Migration, mType migrationType) error {
|
|
var path string
|
|
if mType == upMigration {
|
|
path = migration.UpPath
|
|
} else if mType == downMigration {
|
|
path = migration.DownPath
|
|
} else {
|
|
return InvalidMigrationType
|
|
}
|
|
|
|
log.Printf("Applying migration: %s", path)
|
|
|
|
sql, err := ioutil.ReadFile(path)
|
|
if err != nil {
|
|
log.Printf("Error reading migration: %s", path)
|
|
return err
|
|
}
|
|
transaction, err := m.DB.Begin()
|
|
if err != nil {
|
|
log.Printf("Error opening transaction: %v", err)
|
|
return err
|
|
}
|
|
|
|
for n, subMigration := range splitMigrationString(string(sql)) {
|
|
if allWhitespace.Match([]byte(subMigration)) {
|
|
continue
|
|
}
|
|
|
|
log.Printf("Applying submigration: %v", n+1)
|
|
|
|
for _, line := range bytes.Split([]byte(subMigration), []byte("\n")) {
|
|
log.Printf("MIGRATION: %s", line)
|
|
}
|
|
|
|
// Perform the migration.
|
|
result, err := transaction.Exec(string(subMigration))
|
|
if err != nil {
|
|
log.Printf("Error executing migration: %v", err)
|
|
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
|
|
log.Printf("Error rolling back transaction: %v", rollbackErr)
|
|
return rollbackErr
|
|
}
|
|
return err
|
|
}
|
|
if rowsAffected, err := result.RowsAffected(); err != nil {
|
|
log.Printf("Error getting rows affected: %v", err)
|
|
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
|
|
log.Printf("Error rolling back transaction: %v", rollbackErr)
|
|
return rollbackErr
|
|
}
|
|
return err
|
|
} else {
|
|
log.Printf("Rows affected: %v", rowsAffected)
|
|
}
|
|
}
|
|
|
|
// Log the event.
|
|
if mType == upMigration {
|
|
_, err = transaction.Exec(
|
|
m.dbAdapter.MigrationLogInsertSql(),
|
|
migration.Id,
|
|
)
|
|
} else {
|
|
_, err = transaction.Exec(
|
|
m.dbAdapter.MigrationLogDeleteSql(),
|
|
migration.Id,
|
|
)
|
|
}
|
|
if err != nil {
|
|
log.Printf("Error logging migration: %v", err)
|
|
if rollbackErr := transaction.Rollback(); rollbackErr != nil {
|
|
log.Printf("Error rolling back transaction: %v", rollbackErr)
|
|
return rollbackErr
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Commit and update the struct status.
|
|
if err := transaction.Commit(); err != nil {
|
|
log.Printf("Error commiting transaction: %v", err)
|
|
return err
|
|
}
|
|
if mType == upMigration {
|
|
migration.Status = Active
|
|
} else {
|
|
migration.Status = Inactive
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Applies all inactive migrations.
|
|
func (m *Migrator) Migrate() error {
|
|
for _, migration := range m.Migrations(Inactive) {
|
|
if err := m.ApplyMigration(migration, upMigration); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Rolls back the last migration.
|
|
func (m *Migrator) Rollback() error {
|
|
return m.RollbackN(1)
|
|
}
|
|
|
|
// Rolls back N migrations.
|
|
func (m *Migrator) RollbackN(n int) error {
|
|
migrations := m.Migrations(Active)
|
|
if len(migrations) == 0 {
|
|
return NoActiveMigrations
|
|
}
|
|
|
|
last_migration := len(migrations) - 1 - n
|
|
|
|
for i := len(migrations) - 1; i != last_migration; i-- {
|
|
if err := m.ApplyMigration(migrations[i], downMigration); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Rolls back all migrations.
|
|
func (m *Migrator) RollbackAll() error {
|
|
migrations := m.Migrations(Active)
|
|
return m.RollbackN(len(migrations))
|
|
}
|