ensure files accessed through importlib.resources are cleaned up at exit
1 files changed, 23 insertions(+), 18 deletions(-)

M yoyo/migrations.py
M yoyo/migrations.py +23 -18
@@ 15,23 15,26 @@ 
 from collections import Counter
 from collections import OrderedDict
 from collections import abc
+from contextlib import ExitStack
 from copy import copy
 from glob import glob
+from importlib import resources
 from itertools import chain
 from itertools import count
 from itertools import zip_longest
 from logging import getLogger
+import atexit
 import typing as t
 import hashlib
 import importlib.util
 import os
+import pathlib
 import re
 import sys
 import inspect
 import types
 import textwrap
 
-from importlib.resources import files
 import sqlparse
 
 from yoyo import exceptions

          
@@ 43,15 46,16 @@ default_migration_table = "_yoyo_migrati
 hash_function = hashlib.sha256
 
 
-def _is_migration_file(path):
+def _is_migration_file(path: pathlib.Path):
     """
     Return True if the given path matches a migration file pattern
     """
     from yoyo.scripts import newmigration
 
-    _, extension = os.path.splitext(path)
-    return extension in {".py", ".sql"} and not path.startswith(
-        newmigration.tempfile_prefix
+    return (
+        path.is_file()
+        and path.suffix in {".py", ".sql"}
+        and not path.name.startswith(newmigration.tempfile_prefix)
     )
 
 

          
@@ 446,30 450,31 @@ class StepGroup(MigrationStep):
 
 def _expand_sources(sources) -> t.Iterable[t.Tuple[str, t.List[str]]]:
     package_match = re.compile(r"^package:([^\s\/:]+):(.*)$").match
+
+    filecontext = ExitStack()
+    atexit.register(filecontext.close)
+
     for source in sources:
         mo = package_match(source)
         if mo:
             package_name = mo.group(1)
             resource_dir = mo.group(2)
             try:
-                pkg_files = files(package_name).joinpath(resource_dir)
+                pkg_files = resources.files(package_name).joinpath(resource_dir)
                 if pkg_files.is_dir():
-                    paths = [
-                        str(file)
-                        for file in sorted(pkg_files.iterdir())
-                        if _is_migration_file(file.name)
-                    ]
+                    all_files = (
+                        filecontext.enter_context(resources.as_file(traversable))
+                        for traversable in pkg_files.iterdir()
+                        if traversable.is_file()
+                    )
+                    paths = [str(f) for f in sorted(all_files) if _is_migration_file(f)]
                     yield (source, paths)
             except FileNotFoundError:
                 continue
         else:
-            for directory in glob(source):
-                paths = [
-                    os.path.join(directory, path)
-                    for path in os.listdir(directory)
-                    if _is_migration_file(path)
-                ]
-                yield (directory, sorted(paths))
+            for directory in map(pathlib.Path, glob(source)):
+                paths = [str(f) for f in directory.iterdir() if _is_migration_file(f)]
+                yield (str(directory), sorted(paths))
 
 
 def read_migrations(*sources):