Adding unit tests
4 files changed, 260 insertions(+), 0 deletions(-)

M go.mod
A => go.sum
A => sql_test.go
A => testdata/widgets.sql
M go.mod +2 -0
@@ 1,3 1,5 @@ 
 module petersanchez.com/migrate
 
 go 1.16
+
+require github.com/mattn/go-sqlite3 v1.14.11 // indirect

          
A => go.sum +2 -0
@@ 0,0 1,2 @@ 
+github.com/mattn/go-sqlite3 v1.14.11 h1:gt+cp9c0XGqe9S/wAHTL3n/7MqY+siPWgWJgqdsFrzQ=
+github.com/mattn/go-sqlite3 v1.14.11/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=

          
A => sql_test.go +250 -0
@@ 0,0 1,250 @@ 
+package migrate_test
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"testing"
+
+	_ "github.com/mattn/go-sqlite3"
+	"petersanchez.com/migrate"
+)
+
+func sqliteInMem(t *testing.T) *sql.DB {
+	db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()))
+	if err != nil {
+		t.Fatalf("Open() err = %v; want nil", err)
+	}
+	t.Cleanup(func() {
+		err = db.Close()
+		if err != nil {
+			t.Errorf("Close() err = %v; want nil", err)
+		}
+	})
+	return db
+}
+
+func TestMigrate(t *testing.T) {
+	t.Run("simple", func(t *testing.T) {
+		db := sqliteInMem(t)
+		migrations := []migrate.Migration{
+			migrate.QueryMigration("001_create_courses", createCoursesSQL, "", 0),
+		}
+		engine := migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		ctx := context.Background()
+		err := engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO courses (name) VALUES (?) ", "cor_test")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+	})
+
+	t.Run("existing migrations", func(t *testing.T) {
+		db := sqliteInMem(t)
+		migrations := []migrate.Migration{
+			migrate.QueryMigration("001_create_courses", createCoursesSQL, "", 0),
+		}
+		engine := migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		ctx := context.Background()
+		err := engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO courses (name) VALUES (?) ", "cor_test")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+
+		// the real test
+		migrations = []migrate.Migration{
+			migrate.QueryMigration("001_create_courses", createCoursesSQL, "", 0),
+			migrate.QueryMigration("002_create_users", createUsersSQL, "", 0),
+		}
+		engine = migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		err = engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO users (email) VALUES (?) ", "abc@test.com")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+	})
+
+	t.Run("file", func(t *testing.T) {
+		db := sqliteInMem(t)
+		migrations := []migrate.Migration{
+			migrate.FileMigration("001_create_widgets", "testdata/widgets.sql", "", 0),
+		}
+		engine := migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		ctx := context.Background()
+		err := engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO widgets (color, price) VALUES (?, ?)", "red", 1200)
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+	})
+
+	t.Run("rollback", func(t *testing.T) {
+		db := sqliteInMem(t)
+		migrations := []migrate.Migration{
+			migrate.QueryMigration("001_create_courses", createCoursesSQL, dropCouresesSQL, 0),
+		}
+		engine := migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		ctx := context.Background()
+		err := engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO courses (name) VALUES (?) ", "cor_test")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+		err = engine.Rollback(ctx, "")
+		if err != nil {
+			t.Fatalf("Rollback() err = %v; want nil", err)
+		}
+		var count int
+		err = db.QueryRow("SELECT COUNT(id) FROM courses;").Scan(&count)
+		if err == nil {
+			// Want an error here
+			t.Fatalf("db.QueryRow() err = nil; want table missing error")
+		}
+		// Don't want to test inner workings of lib, so let's just migrate again and verify we have a table now
+		err = engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO courses (name) VALUES (?) ", "cor_test")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+		err = db.QueryRow("SELECT COUNT(*) FROM courses;").Scan(&count)
+		if err != nil {
+			// Want an error here
+			t.Fatalf("db.QueryRow() err = %v; want nil", err)
+		}
+		if count != 1 {
+			t.Fatalf("count = %d; want %d", count, 1)
+		}
+	})
+
+	t.Run("specific migration", func(t *testing.T) {
+		db := sqliteInMem(t)
+		migrations := []migrate.Migration{
+			migrate.QueryMigration("001_create_courses", createCoursesSQL, "", 0),
+			migrate.QueryMigration("002_create_users", createUsersSQL, dropUsersSQL, 0),
+		}
+		engine := migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		ctx := context.Background()
+		err := engine.Migrate(ctx, "001_create_courses", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO users (email) VALUES (?) ", "abc@test.com")
+		if err == nil {
+			t.Fatalf("db.Exec() err = nil; want error (table doesn't exist)")
+		}
+		err = engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO users (email) VALUES (?) ", "abc@test.com")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+
+		// Rollback one
+		err = engine.Rollback(ctx, "002_create_users")
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		_, err = db.Exec("INSERT INTO users (email) VALUES (?) ", "abc@test.com")
+		if err == nil {
+			t.Fatalf("db.Exec() err = nil; want error (table doesn't exist)")
+		}
+		_, err = db.Exec("INSERT INTO courses (name) VALUES (?) ", "cor_test")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+	})
+
+	t.Run("fake migration", func(t *testing.T) {
+		db := sqliteInMem(t)
+		migrations := []migrate.Migration{
+			migrate.QueryMigration("001_create_courses", createCoursesSQL, "", 0),
+			migrate.QueryMigration("002_create_users", createUsersSQL, dropUsersSQL, 0),
+		}
+		engine := migrate.NewEngine(db, migrations, migrate.QUESTION, false)
+		engine.Printf = func(format string, args ...interface{}) (int, error) {
+			t.Logf(format, args...)
+			return 0, nil
+		}
+		ctx := context.Background()
+		err := engine.Migrate(ctx, "", false)
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+
+		// Delete migration history
+		_, err = db.Exec("DELETE FROM migrations WHERE id=?", "002_create_users")
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+		err = engine.Migrate(ctx, "", true) // Fake migrations
+		if err != nil {
+			t.Fatalf("Migrate() err = %v; want nil", err)
+		}
+		var mid string
+		err = db.QueryRow("SELECT id FROM migrations WHERE id=?", "002_create_users").Scan(&mid)
+		if err != nil {
+			t.Fatalf("db.Exec() err = %v; want nil", err)
+		}
+	})
+}
+
+var (
+	createCoursesSQL = `
+CREATE TABLE courses (
+  id serial PRIMARY KEY,
+  name text
+);`
+	dropCouresesSQL = `DROP TABLE courses;`
+
+	createUsersSQL = `
+CREATE TABLE users (
+  id serial PRIMARY KEY,
+  email text UNIQUE NOT NULL
+);`
+	dropUsersSQL = `DROP TABLE users;`
+)

          
A => testdata/widgets.sql +6 -0
@@ 0,0 1,6 @@ 
+CREATE TABLE widgets (
+  id serial PRIMARY KEY,
+  color text NOT NULL,
+  price integer
+);
+