M yoyo/backends/base.py +6 -9
@@ 12,14 12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections.abc import Mapping
+from collections import abc
from datetime import datetime
from datetime import timezone
from contextlib import contextmanager
from importlib import import_module
from itertools import count
from logging import getLogger
-from typing import Dict
from importlib_metadata import entry_points
import getpass
@@ 27,6 26,7 @@ import os
import pickle
import socket
import time
+import typing as t
import uuid
from yoyo import exceptions
@@ 150,7 150,7 @@ class DatabaseBackend:
_is_locked = False
_in_transaction = False
_internal_schema_updated = False
- _transactional_ddl_cache: Dict[bytes, bool] = {}
+ _transactional_ddl_cache: dict[bytes, bool] = {}
def __init__(self, dburi, migration_table):
self.uri = dburi
@@ 373,7 373,7 @@ class DatabaseBackend:
with self.transaction():
self.execute("DELETE FROM {}".format(self.lock_table_quoted))
- def execute(self, sql, params=None):
+ def execute(self, sql, params: t.Union[abc.Mapping[str, t.Any], None] = None):
"""
Create a new cursor, execute a single statement and return the cursor
object.
@@ 382,12 382,9 @@ class DatabaseBackend:
(eg 'SELECT * FROM foo WHERE :bar IS NULL')
:param params: A dictionary of parameters
"""
- if params and not isinstance(params, Mapping):
- raise TypeError("Expected dict or other mapping object")
-
cursor = self.cursor()
- sql, params = utils.change_param_style(self.driver.paramstyle, sql, params)
- cursor.execute(sql, params)
+ sql, queryparams = utils.change_param_style(self.driver.paramstyle, sql, params)
+ cursor.execute(sql, queryparams)
return cursor
def create_lock_table(self):
M yoyo/utils.py +9 -3
@@ 13,6 13,7 @@
# limitations under the License.
from itertools import count
+from collections import abc
import configparser
import os
import random
@@ 20,6 21,7 @@ import re
import string
import sys
import unicodedata
+import typing as t
from yoyo.config import CONFIG_EDITOR_KEY
@@ 125,18 127,22 @@ def get_random_string(length, chars=(str
return "".join(rng.choice(chars) for i in range(length))
-def change_param_style(target_style, sql, bind_parameters):
+def change_param_style(
+ target_style: str,
+ sql: str,
+ bind_parameters: t.Optional[abc.Mapping[str, t.Any]]
+) -> tuple[str, t.Union[abc.Mapping[str, t.Any], abc.Sequence[str]]]:
"""
:param target_style: A DBAPI paramstyle value (eg 'qmark', 'format', etc)
:param sql: An SQL str
- :bind_parameters: A dict of bind parameters for the query
+ :param bind_parameters: A dict of bind parameters for the query
:return: tuple of `(sql, bind_parameters)`. ``sql`` will be rewritten with
the target paramstyle; ``bind_parameters`` will be a tuple or
dict as required.
"""
if target_style == "named":
- return sql, bind_parameters
+ return sql, bind_parameters or {}
positional = target_style in {"qmark", "numeric", "format"}
if not bind_parameters:
return (sql, (tuple() if positional else {}))