@@ 15,23 15,26 @@
from collections import Counter
from collections import OrderedDict
from collections import abc
+from contextlib import ExitStack
from copy import copy
from glob import glob
+from importlib import resources
from itertools import chain
from itertools import count
from itertools import zip_longest
from logging import getLogger
+import atexit
import typing as t
import hashlib
import importlib.util
import os
+import pathlib
import re
import sys
import inspect
import types
import textwrap
-from importlib.resources import files
import sqlparse
from yoyo import exceptions
@@ 43,15 46,16 @@ default_migration_table = "_yoyo_migrati
hash_function = hashlib.sha256
-def _is_migration_file(path):
+def _is_migration_file(path: pathlib.Path):
"""
Return True if the given path matches a migration file pattern
"""
from yoyo.scripts import newmigration
- _, extension = os.path.splitext(path)
- return extension in {".py", ".sql"} and not path.startswith(
- newmigration.tempfile_prefix
+ return (
+ path.is_file()
+ and path.suffix in {".py", ".sql"}
+ and not path.name.startswith(newmigration.tempfile_prefix)
)
@@ 446,30 450,31 @@ class StepGroup(MigrationStep):
def _expand_sources(sources) -> t.Iterable[t.Tuple[str, t.List[str]]]:
package_match = re.compile(r"^package:([^\s\/:]+):(.*)$").match
+
+ filecontext = ExitStack()
+ atexit.register(filecontext.close)
+
for source in sources:
mo = package_match(source)
if mo:
package_name = mo.group(1)
resource_dir = mo.group(2)
try:
- pkg_files = files(package_name).joinpath(resource_dir)
+ pkg_files = resources.files(package_name).joinpath(resource_dir)
if pkg_files.is_dir():
- paths = [
- str(file)
- for file in sorted(pkg_files.iterdir())
- if _is_migration_file(file.name)
- ]
+ all_files = (
+ filecontext.enter_context(resources.as_file(traversable))
+ for traversable in pkg_files.iterdir()
+ if traversable.is_file()
+ )
+ paths = [str(f) for f in sorted(all_files) if _is_migration_file(f)]
yield (source, paths)
except FileNotFoundError:
continue
else:
- for directory in glob(source):
- paths = [
- os.path.join(directory, path)
- for path in os.listdir(directory)
- if _is_migration_file(path)
- ]
- yield (directory, sorted(paths))
+ for directory in map(pathlib.Path, glob(source)):
+ paths = [str(f) for f in directory.iterdir() if _is_migration_file(f)]
+ yield (str(directory), sorted(paths))
def read_migrations(*sources):