tests: ensure backends are always accessed as a contextmanager

This ensures that backends are cleaned up and the connection closed at the end
of tests, helping improve isolation between tests
M yoyo/tests/__init__.py +0 -5
@@ 20,7 20,6 @@ import contextlib
 import os.path
 
 from yoyo.config import get_configparser
-from yoyo.connections import get_backend
 
 dburi_sqlite3 = "sqlite:///:memory:"
 

          
@@ 39,10 38,6 @@ def get_test_dburis(only=frozenset(), ex
     ]
 
 
-def get_test_backends(only=frozenset(), exclude=frozenset()):
-    return [get_backend(dburi) for dburi in get_test_dburis(only, exclude)]
-
-
 def clear_database(backend):
     with backend.transaction():
         for table in backend.list_tables():

          
M yoyo/tests/conftest.py +16 -16
@@ 3,7 3,6 @@ import pytest
 import yoyo.backends.core
 from yoyo.connections import get_backend
 from yoyo.tests import dburi_sqlite3
-from yoyo.tests import get_test_backends
 from yoyo.tests import get_test_dburis
 
 

          
@@ 11,20 10,20 @@ def _backend(dburi):
     """
     Return a backend configured in ``test_databases.ini``
     """
-    backend = get_backend(dburi)
-    with backend.transaction():
-        if backend.__class__ is yoyo.backends.core.MySQLBackend:
-            backend.execute(
-                "CREATE TABLE yoyo_t (id CHAR(1) primary key) ENGINE=InnoDB"
-            )
-        else:
-            backend.execute("CREATE TABLE yoyo_t " "(id CHAR(1) primary key)")
-    try:
-        yield backend
-    finally:
-        backend.rollback()
+    with get_backend(dburi) as backend:
         with backend.transaction():
-            drop_all_tables(backend)
+            if backend.__class__ is yoyo.backends.core.MySQLBackend:
+                backend.execute(
+                    "CREATE TABLE yoyo_t (id CHAR(1) primary key) ENGINE=InnoDB"
+                )
+            else:
+                backend.execute("CREATE TABLE yoyo_t " "(id CHAR(1) primary key)")
+        try:
+            yield backend
+        finally:
+            backend.rollback()
+            with backend.transaction():
+                drop_all_tables(backend)
 
 
 @pytest.fixture(params=get_test_dburis())

          
@@ 63,5 62,6 @@ def drop_yoyo_tables(backend):
 
 
 def pytest_configure(config):
-    for backend in get_test_backends():
-        drop_yoyo_tables(backend)
+    for dburi in get_test_dburis():
+        with get_backend(dburi) as backend:
+            drop_yoyo_tables(backend)

          
M yoyo/tests/test_backends.py +63 -61
@@ 12,7 12,6 @@ from yoyo import read_migrations
 from yoyo import exceptions
 from yoyo.backends.contrib.redshift import RedshiftBackend
 from yoyo.connections import get_backend
-from yoyo.tests import get_test_backends
 from yoyo.tests import get_test_dburis
 from yoyo.tests import migrations_dir
 

          
@@ 122,10 121,11 @@ class TestTransactionHandling(object):
                 )
         """
         ) as tmpdir:
-            for backend in get_test_backends(exclude={"sqlite", "oracle"}):
-                migrations = read_migrations(tmpdir)
-                backend.apply_migrations(migrations)
-                backend.rollback_migrations(migrations)
+            for dburi in get_test_dburis(exclude={"sqlite", "oracle"}):
+                with get_backend(dburi) as backend:
+                    migrations = read_migrations(tmpdir)
+                    backend.apply_migrations(migrations)
+                    backend.rollback_migrations(migrations)
 
     def test_disabling_transactions_in_sqlite(self):
         """

          
@@ 191,36 191,36 @@ class TestConcurrency(object):
         Test that :meth:`~yoyo.backends.DatabaseBackend.lock`
         acquires an exclusive lock
         """
-        backend = get_backend(dburi)
-        self.skip_if_not_concurrency_safe(backend)
-        thread = Thread(target=self.get_lock_sleeper(dburi))
-        t = time.time()
-        thread.start()
+        with get_backend(dburi) as backend:
+            self.skip_if_not_concurrency_safe(backend)
+            thread = Thread(target=self.get_lock_sleeper(dburi))
+            t = time.time()
+            thread.start()
 
-        # Give the thread time to acquire the lock, but not enough
-        # to complete
-        time.sleep(self.lock_duration * 0.6)
+            # Give the thread time to acquire the lock, but not enough
+            # to complete
+            time.sleep(self.lock_duration * 0.6)
 
-        with backend.lock():
-            delta = time.time() - t
-            assert delta >= self.lock_duration
+            with backend.lock():
+                delta = time.time() - t
+                assert delta >= self.lock_duration
 
-        thread.join()
+            thread.join()
 
     def test_lock_times_out(self, dburi):
-        backend = get_backend(dburi)
-        self.skip_if_not_concurrency_safe(backend)
+        with get_backend(dburi) as backend:
+            self.skip_if_not_concurrency_safe(backend)
 
-        thread = Thread(target=self.get_lock_sleeper(dburi))
-        thread.start()
-        # Give the thread time to acquire the lock, but not enough
-        # to complete
-        time.sleep(self.lock_duration * 0.6)
-        with pytest.raises(exceptions.LockTimeout):
-            with backend.lock(timeout=0.001):
-                assert False, "Execution should never reach this point"
+            thread = Thread(target=self.get_lock_sleeper(dburi))
+            thread.start()
+            # Give the thread time to acquire the lock, but not enough
+            # to complete
+            time.sleep(self.lock_duration * 0.6)
+            with pytest.raises(exceptions.LockTimeout):
+                with backend.lock(timeout=0.001):
+                    assert False, "Execution should never reach this point"
 
-        thread.join()
+            thread.join()
 
 
 class TestInitConnection(object):

          
@@ 265,39 265,39 @@ class TestInitConnection(object):
         if dburi is None:
             pytest.skip("PostgreSQL backend not available")
             return
-        backend = get_backend(dburi)
-        with backend.transaction():
-            backend.execute("CREATE SCHEMA foo")
-        try:
-            assert get_backend(dburi + "?schema=foo").execute(
-                "SHOW search_path"
-            ).fetchone() == ("foo",)
-        finally:
+        with get_backend(dburi) as backend:
             with backend.transaction():
-                backend.execute("DROP SCHEMA foo CASCADE")
+                backend.execute("CREATE SCHEMA foo")
+
+                try:
+                    with get_backend(f"{dburi}?schema=foo") as b2:
+                        assert b2.execute("SHOW search_path").fetchone() == ("foo",)
+                finally:
+                    with backend.transaction():
+                        backend.execute("DROP SCHEMA foo CASCADE")
 
     def test_postgresql_list_table_uses_current_schema(self):
         dburi = next(iter(get_test_dburis(only={"postgresql"})), None)
         if dburi is None:
             pytest.skip("PostgreSQL backend not available")
-        backend = get_backend(dburi)
-        dbname = backend.uri.database
-        with backend.transaction():
-            backend.execute(
-                "ALTER DATABASE {} SET SEARCH_PATH = custom_schema,public".format(
-                    dbname
+        with get_backend(dburi) as backend:
+            dbname = backend.uri.database
+            with backend.transaction():
+                backend.execute(
+                    "ALTER DATABASE {} SET SEARCH_PATH = custom_schema,public".format(
+                        dbname
+                    )
                 )
-            )
-        try:
-            with backend.transaction():
-                backend.execute("CREATE SCHEMA custom_schema")
-                backend.execute("CREATE TABLE custom_schema.foo (x int)")
-            assert "foo" in get_backend(dburi).list_tables()
+            try:
+                with backend.transaction():
+                    backend.execute("CREATE SCHEMA custom_schema")
+                    backend.execute("CREATE TABLE custom_schema.foo (x int)")
+                assert "foo" in get_backend(dburi).list_tables()
 
-        finally:
-            with backend.transaction():
-                backend.execute("ALTER DATABASE {} RESET SEARCH_PATH".format(dbname))
-                backend.execute("DROP SCHEMA custom_schema CASCADE")
+            finally:
+                with backend.transaction():
+                    backend.execute(f"ALTER DATABASE {dbname} RESET SEARCH_PATH")
+                    backend.execute("DROP SCHEMA custom_schema CASCADE")
 
     def test_postgresql_migrations_can_change_schema_search_path(self):
         """

          
@@ 306,11 306,13 @@ class TestInitConnection(object):
         dburi = next(iter(get_test_dburis(only={"postgresql"})), None)
         if dburi is None:
             pytest.skip("PostgreSQL backend not available")
-        backend = get_backend(dburi)
-        with migrations_dir(
-            **{"1.sql": "SELECT pg_catalog.set_config('search_path', '', false)"}
-        ) as tmpdir:
-            migrations = read_migrations(tmpdir)
-            backend.apply_migrations(migrations)
-            applied = backend.execute("SELECT migration_id FROM _yoyo_log").fetchall()
-            assert applied == [("1",)]
+        with get_backend(dburi) as backend:
+            with migrations_dir(
+                **{"1.sql": "SELECT pg_catalog.set_config('search_path', '', false)"}
+            ) as tmpdir:
+                migrations = read_migrations(tmpdir)
+                backend.apply_migrations(migrations)
+                applied = backend.execute(
+                    "SELECT migration_id FROM _yoyo_log"
+                ).fetchall()
+                assert applied == [("1",)]

          
M yoyo/tests/test_cli_script.py +42 -41
@@ 33,9 33,9 @@ import tms
 
 from yoyo import read_migrations
 from yoyo.config import get_configparser
+from yoyo.connections import get_backend
 from yoyo.tests import dburi_sqlite3
 from yoyo.tests import migrations_dir
-from yoyo.tests import get_backend
 from yoyo.scripts.main import main, parse_args, LEGACY_CONFIG_FILENAME
 from yoyo.scripts import newmigration
 

          
@@ 79,11 79,10 @@ class TestInteractiveScript(object):
                 cp.write(f)
 
     def get_migration_log(self):
-        return (
-            get_backend(self.dburi)
-            .execute("SELECT migration_id, operation FROM _yoyo_log")
-            .fetchall()
-        )
+        with get_backend(self.dburi) as backend:
+            return backend.execute(
+                "SELECT migration_id, operation FROM _yoyo_log"
+            ).fetchall()
 
 
 class TestYoyoScript(TestInteractiveScript):

          
@@ 257,15 256,15 @@ class TestYoyoScript(TestInteractiveScri
     def test_it_breaks_lock(self, dburi):
         if dburi.startswith("sqlite"):
             pytest.skip("Test not supported for sqlite databases")
-        backend = get_backend(dburi)
-        backend.execute(
-            "INSERT INTO yoyo_lock (locked, ctime, pid) " "VALUES (1, :now, 1)",
-            {"now": datetime.now(timezone.utc).replace(tzinfo=None)},
-        )
-        backend.commit()
-        main(["break-lock", "--database", dburi])
-        lock_count = backend.execute("SELECT COUNT(1) FROM yoyo_lock").fetchone()[0]
-        assert lock_count == 0
+        with get_backend(dburi) as backend:
+            backend.execute(
+                "INSERT INTO yoyo_lock (locked, ctime, pid) " "VALUES (1, :now, 1)",
+                {"now": datetime.now(timezone.utc).replace(tzinfo=None)},
+            )
+            backend.commit()
+            main(["break-lock", "--database", dburi])
+            lock_count = backend.execute("SELECT COUNT(1) FROM yoyo_lock").fetchone()[0]
+            assert lock_count == 0
 
     def test_it_prompts_password_on_break_lock(self):
         dburi = "sqlite://user@/:memory"

          
@@ 324,14 323,16 @@ class TestMarkCommand(TestInteractiveScr
             from yoyo.connections import get_backend
 
             migrations = read_migrations(tmpdir)
-            backend = get_backend(self.dburi)
-            backend.apply_migrations(migrations[:1])
+            with get_backend(self.dburi) as backend:
+                backend.apply_migrations(migrations[:1])
 
-            with patch("yoyo.scripts.migrate.prompt_migrations") as prompt_migrations:
-                main(["mark", tmpdir, "--database", self.dburi])
-                _, prompted, _ = prompt_migrations.call_args[0]
-                prompted = [m.id for m in prompted]
-                assert prompted == ["m2", "m3"]
+                with patch(
+                    "yoyo.scripts.migrate.prompt_migrations"
+                ) as prompt_migrations:
+                    main(["mark", tmpdir, "--database", self.dburi])
+                    _, prompted, _ = prompt_migrations.call_args[0]
+                    prompted = [m.id for m in prompted]
+                    assert prompted == ["m2", "m3"]
 
     def test_it_marks_at_selected_version(self):
         with migrations_dir(

          
@@ 343,18 344,18 @@ class TestMarkCommand(TestInteractiveScr
 
             self.confirm.return_value = True
             migrations = read_migrations(tmpdir)
-            backend = get_backend(self.dburi)
-            with backend.transaction():
-                backend.execute("CREATE TABLE t (id INT)")
+            with get_backend(self.dburi) as backend:
+                with backend.transaction():
+                    backend.execute("CREATE TABLE t (id INT)")
 
-            main(["mark", "-r", "m2", tmpdir, "--database", self.dburi])
-            assert backend.is_applied(migrations[0])
-            assert backend.is_applied(migrations[1])
-            assert not backend.is_applied(migrations[2])
+                main(["mark", "-r", "m2", tmpdir, "--database", self.dburi])
+                assert backend.is_applied(migrations[0])
+                assert backend.is_applied(migrations[1])
+                assert not backend.is_applied(migrations[2])
 
-            # Check that migration steps have not been applied
-            c = backend.execute("SELECT * FROM t")
-            assert len(c.fetchall()) == 0
+                # Check that migration steps have not been applied
+                c = backend.execute("SELECT * FROM t")
+                assert len(c.fetchall()) == 0
 
 
 class TestUnmarkCommand(TestInteractiveScript):

          
@@ 363,9 364,9 @@ class TestUnmarkCommand(TestInteractiveS
             from yoyo.connections import get_backend
 
             migrations = read_migrations(tmpdir)
-            backend = get_backend(self.dburi)
-            backend.apply_migrations(migrations[:2])
-            assert len(backend.get_applied_migration_hashes()) == 2
+            with get_backend(self.dburi) as backend:
+                backend.apply_migrations(migrations[:2])
+                assert len(backend.get_applied_migration_hashes()) == 2
 
             with patch("yoyo.scripts.migrate.prompt_migrations") as prompt_migrations:
                 main(["unmark", tmpdir, "--database", self.dburi])

          
@@ 381,13 382,13 @@ class TestUnmarkCommand(TestInteractiveS
 
             self.confirm.return_value = True
             migrations = read_migrations(tmpdir)
-            backend = get_backend(self.dburi)
-            backend.apply_migrations(migrations)
+            with get_backend(self.dburi) as backend:
+                backend.apply_migrations(migrations)
 
-            main(["unmark", "-r", "m2", tmpdir, "--database", self.dburi])
-            assert backend.is_applied(migrations[0])
-            assert not backend.is_applied(migrations[1])
-            assert not backend.is_applied(migrations[2])
+                main(["unmark", "-r", "m2", tmpdir, "--database", self.dburi])
+                assert backend.is_applied(migrations[0])
+                assert not backend.is_applied(migrations[1])
+                assert not backend.is_applied(migrations[2])
 
 
 class TestNewMigration(TestInteractiveScript):

          
M yoyo/tests/test_migrations.py +6 -6
@@ 176,12 176,12 @@ def test_specify_migration_table(tmpdir,
         step("DROP TABLE yoyo_test")
         """
     ) as tmpdir:
-        backend = get_backend(dburi, migration_table="another_migration_table")
-        migrations = read_migrations(tmpdir)
-        backend.apply_migrations(migrations)
-        cursor = backend.cursor()
-        cursor.execute("SELECT migration_id FROM another_migration_table")
-        assert list(cursor.fetchall()) == [("0",)]
+        with get_backend(dburi, migration_table="another_migration_table") as backend:
+            migrations = read_migrations(tmpdir)
+            backend.apply_migrations(migrations)
+            cursor = backend.cursor()
+            cursor.execute("SELECT migration_id FROM another_migration_table")
+            assert list(cursor.fetchall()) == [("0",)]
 
 
 def test_migration_functions_have_namespace_access(backend_sqlite3):