# HG changeset patch # User Olly Cope # Date 1714724484 0 # Fri May 03 08:21:24 2024 +0000 # Node ID 914dff3af4bd79e95f6bde5c7463e4873f5127e7 # Parent a38828e1ac8545a6869cbaeeea64943e958a13cf quote all identifiers used in internal sql queries This should fix #102 as well as future-proof against similar issues arising for other backends diff --git a/yoyo/backends/base.py b/yoyo/backends/base.py --- a/yoyo/backends/base.py +++ b/yoyo/backends/base.py @@ -115,35 +115,40 @@ migration_table = "_yoyo_migrations" is_applied_sql = """ SELECT COUNT(1) FROM {0.migration_table_quoted} - WHERE id=:id""" + WHERE {quoted.id}=:id}""" mark_migration_sql = ( "INSERT INTO {0.migration_table_quoted} " - "(migration_hash, migration_id, applied_at_utc) " + "({quoted.migration_hash}, {quoted.migration_id}, {quoted.applied_at_utc}) " "VALUES (:migration_hash, :migration_id, :when)" ) + unmark_migration_sql = ( "DELETE FROM {0.migration_table_quoted} WHERE " - "migration_hash = :migration_hash" + "{quoted.migration_hash} = :migration_hash" ) + applied_migrations_sql = ( - "SELECT migration_hash FROM " + "SELECT {quoted.migration_hash} FROM " "{0.migration_table_quoted} " - "ORDER by applied_at_utc" + "ORDER by {quoted.applied_at_utc}" ) - create_test_table_sql = "CREATE TABLE {table_name_quoted} " "(id INT PRIMARY KEY)" + create_test_table_sql = ( + "CREATE TABLE {table_name_quoted} ({quoted.id} INT PRIMARY KEY)" + ) log_migration_sql = ( "INSERT INTO {0.log_table_quoted} " - "(id, migration_hash, migration_id, operation, " - "username, hostname, created_at_utc) " + "({quoted.id}, {quoted.migration_hash}, {quoted.migration_id}, " + "{quoted.operation}, {quoted.username}, {quoted.hostname}, " + "{quoted.created_at_utc}) " "VALUES (:id, :migration_hash, :migration_id, " ":operation, :username, :hostname, :created_at_utc)" ) create_lock_table_sql = ( "CREATE TABLE {0.lock_table_quoted} (" - "locked INT DEFAULT 1, " - "ctime TIMESTAMP," - "pid INT NOT NULL," - "PRIMARY KEY (locked))" + "{quoted.locked} INT DEFAULT 1, " + "{quoted.ctime} TIMESTAMP," + "{quoted.pid} INT NOT NULL," + "PRIMARY KEY ({quoted.locked}))" ) _driver = None @@ -177,6 +182,19 @@ exceptions.register(driver.DatabaseError) return driver + def format_sql(self, s, **kwargs): + """ + Take a string in the format used by the various ``..._sql`` class + variables and format it. + """ + quote_identifier = self.quote_identifier + + class Quoter: + def __getattr__(self, s): + return quote_identifier(s) + + return s.format(self, quoted=Quoter(), **kwargs) + @property def driver(self): if self._driver: @@ -228,7 +246,9 @@ """ table_name = "yoyo_tmp_{}".format(utils.get_random_string(10)) table_name_quoted = self.quote_identifier(table_name) - sql = self.create_test_table_sql.format(table_name_quoted=table_name_quoted) + sql = self.format_sql( + self.create_test_table_sql, table_name_quoted=table_name_quoted + ) try: with self.transaction(rollback_on_exit=True): self.execute(sql) @@ -237,7 +257,7 @@ try: with self.transaction(): - self.execute("DROP TABLE {}".format(table_name_quoted)) + self.execute(f"DROP TABLE {table_name_quoted}") except self.DatabaseError: return True return False @@ -285,19 +305,19 @@ """ Create a new savepoint with the given id """ - self.execute("SAVEPOINT {}".format(id)) + self.execute(f"SAVEPOINT {self.quote_identifier(id)}") def savepoint_release(self, id): """ Release (commit) the savepoint with the given id """ - self.execute("RELEASE SAVEPOINT {}".format(id)) + self.execute(f"RELEASE SAVEPOINT {self.quote_identifier(id)}") def savepoint_rollback(self, id): """ Rollback the savepoint with the given id """ - self.execute("ROLLBACK TO SAVEPOINT {}".format(id)) + self.execute(f"ROLLBACK TO SAVEPOINT {self.quote_identifier(id)}") @contextmanager def disable_transactions(self): @@ -331,12 +351,16 @@ def _insert_lock_row(self, pid, timeout, poll_interval=0.5): poll_interval = min(poll_interval, timeout) started = time.time() + qi = self.quote_identifier while True: try: with self.transaction(): self.execute( - "INSERT INTO {} (locked, ctime, pid) " - "VALUES (1, :when, :pid)".format(self.lock_table_quoted), + f""" + INSERT INTO {self.lock_table_quoted} + ({qi('locked')}, {qi('ctime')}, {qi('pid')}) + VALUES (1, :when, :pid) + """, { "when": datetime.now(timezone.utc).replace(tzinfo=None), "pid": pid, @@ -345,18 +369,17 @@ except self.DatabaseError: if timeout and time.time() > started + timeout: cursor = self.execute( - "SELECT pid FROM {}".format(self.lock_table_quoted) + f"SELECT {qi('pid')} FROM {self.lock_table_quoted}" ) row = cursor.fetchone() if row: raise exceptions.LockTimeout( - "Process {} has locked this database " - "(run yoyo break-lock to remove this lock)".format(row[0]) + f"Process {row[0]} has locked this database " + "(run yoyo break-lock to remove this lock)" ) else: raise exceptions.LockTimeout( - "Database locked " - "(run yoyo break-lock to remove this lock)" + "Database locked (run yoyo break-lock to remove this lock)" ) time.sleep(poll_interval) else: @@ -364,14 +387,15 @@ def _delete_lock_row(self, pid): with self.transaction(): + qi = self.quote_identifier self.execute( - "DELETE FROM {} WHERE pid=:pid".format(self.lock_table_quoted), + f"DELETE FROM {self.lock_table_quoted} WHERE {qi('pid')}=:pid", {"pid": pid}, ) def break_lock(self): with self.transaction(): - self.execute("DELETE FROM {}".format(self.lock_table_quoted)) + self.execute(f"DELETE FROM {self.lock_table_quoted}") def execute(self, sql, params: t.Union[abc.Mapping[str, t.Any], None] = None): """ @@ -393,7 +417,7 @@ """ try: with self.transaction(): - self.execute(self.create_lock_table_sql.format(self)) + self.execute(self.format_sql(self.create_lock_table_sql)) except self.DatabaseError: pass @@ -419,7 +443,7 @@ were applied """ self.ensure_internal_schema_updated() - sql = self.applied_migrations_sql.format(self) + sql = self.format_sql(self.applied_migrations_sql) return [row[0] for row in self.execute(sql).fetchall()] def to_apply(self, migrations): @@ -523,7 +547,7 @@ def unmark_one(self, migration, log=True): self.ensure_internal_schema_updated() - sql = self.unmark_migration_sql.format(self) + sql = self.format_sql(self.unmark_migration_sql) self.execute(sql, {"migration_hash": migration.hash}) if log: self.log_migration(migration, "unmark") @@ -531,7 +555,7 @@ def mark_one(self, migration, log=True): self.ensure_internal_schema_updated() logger.info("Marking %s applied", migration.id) - sql = self.mark_migration_sql.format(self) + sql = self.format_sql(self.mark_migration_sql) self.execute( sql, { @@ -544,7 +568,7 @@ self.log_migration(migration, "mark") def log_migration(self, migration, operation, comment=None): - sql = self.log_migration_sql.format(self) + sql = self.format_sql(self.log_migration_sql) self.execute(sql, self.get_log_data(migration, operation, comment)) def get_log_data(self, migration=None, operation="apply", comment=None): diff --git a/yoyo/internalmigrations/__init__.py b/yoyo/internalmigrations/__init__.py --- a/yoyo/internalmigrations/__init__.py +++ b/yoyo/internalmigrations/__init__.py @@ -49,8 +49,9 @@ return 0 if version_table not in tables: return 1 + qi = backend.quote_identifier cursor = backend.execute( - f"SELECT max(version) FROM {backend.quote_identifier(version_table)}" + f"SELECT max({qi('version')}) FROM {qi(version_table)}" ) version = cursor.fetchone()[0] assert version in schema_versions @@ -65,6 +66,6 @@ if version < USE_VERSION_TABLE_FROM: return backend.execute( - "INSERT INTO {0.version_table_quoted} VALUES (:version, :when)".format(backend), + f"INSERT INTO {backend.version_table_quoted} VALUES (:version, :when)", {"version": version, "when": datetime.now(timezone.utc).replace(tzinfo=None)}, ) diff --git a/yoyo/internalmigrations/v2.py b/yoyo/internalmigrations/v2.py --- a/yoyo/internalmigrations/v2.py +++ b/yoyo/internalmigrations/v2.py @@ -8,10 +8,11 @@ def upgrade(backend): + qi = backend.quote_identifier create_log_table(backend) create_version_table(backend) cursor = backend.execute( - "SELECT id, ctime FROM {}".format(backend.migration_table_quoted) + f"SELECT {qi('id')}, {qi('ctime')} FROM {backend.migration_table_quoted}" ) migration_id = "" created_at = datetime(1970, 1, 1) @@ -27,56 +28,81 @@ migration_hash=migration_hash, migration_id=migration_id, ) + qi = backend.quote_identifier backend.execute( - "INSERT INTO {0.log_table_quoted} " - "(id, migration_hash, migration_id, operation, created_at_utc, " - "username, hostname, comment) " - "VALUES " - "(:id, :migration_hash, :migration_id, 'apply', :created_at_utc, " - ":username, :hostname, :comment)".format(backend), + f""" + INSERT INTO {backend.log_table_quoted} ( + {qi('id')}, + {qi('migration_hash')}, + {qi('migration_id')}, + {qi('operation')}, + {qi('created_at_utc')}, + {qi('username')}, + {qi('hostname')}, + {qi('comment')} + ) VALUES ( + :id, :migration_hash, :migration_id, 'apply', :created_at_utc, + :username, :hostname, :comment + ) + """, log_data, ) backend.execute("DROP TABLE {0.migration_table_quoted}".format(backend)) create_migration_table(backend) backend.execute( - "INSERT INTO {0.migration_table_quoted} " - "SELECT migration_hash, migration_id, created_at_utc " - "FROM {0.log_table_quoted}".format(backend) + f""" + INSERT INTO {backend.migration_table_quoted} + SELECT {qi('migration_hash')}, {qi('migration_id')}, {qi('created_at_utc')} + FROM {backend.log_table_quoted} + """ ) def create_migration_table(backend): + qi = backend.quote_identifier backend.execute( - "CREATE TABLE {0.migration_table_quoted} ( " - # sha256 hash of the migration id - "migration_hash VARCHAR(64), " - # The migration id (ie path basename without extension) - "migration_id VARCHAR(255), " - # When this id was applied - "applied_at_utc TIMESTAMP, " - "PRIMARY KEY (migration_hash))".format(backend) + # migration_hash: sha256 hash of the migration id + # migration_id: identifier of the migration file + # (path basename without extension) + # applied_at_utc: time in UTC of when the id was applied + f""" + CREATE TABLE {backend.migration_table_quoted} ( + {qi('migration_hash')} VARCHAR(64), + {qi('migration_id')} VARCHAR(255), + {qi('applied_at_utc')} TIMESTAMP, + PRIMARY KEY ({qi('migration_hash')}) + ) + """ ) def create_log_table(backend): + qi = backend.quote_identifier backend.execute( - "CREATE TABLE {0.log_table_quoted} ( " - "id VARCHAR(36), " - "migration_hash VARCHAR(64), " - "migration_id VARCHAR(255), " - "operation VARCHAR(10), " - "username VARCHAR(255), " - "hostname VARCHAR(255), " - "comment VARCHAR(255), " - "created_at_utc TIMESTAMP, " - "PRIMARY KEY (id))".format(backend) + f""" + CREATE TABLE {backend.log_table_quoted} ( + {qi('id')} VARCHAR(36), + {qi('migration_hash')} VARCHAR(64), + {qi('migration_id')} VARCHAR(255), + {qi('operation')} VARCHAR(10), + {qi('username')} VARCHAR(255), + {qi('hostname')} VARCHAR(255), + {qi('comment')} VARCHAR(255), + {qi('created_at_utc')} TIMESTAMP, + PRIMARY KEY ({qi('id')}) + ) + """ ) def create_version_table(backend): + qi = backend.quote_identifier backend.execute( - "CREATE TABLE {0.version_table_quoted} (" - "version INT NOT NULL PRIMARY KEY, " - "installed_at_utc TIMESTAMP)".format(backend) + f""" + CREATE TABLE {backend.version_table_quoted} ( + {qi('version')} INT NOT NULL PRIMARY KEY, + {qi('installed_at_utc')} TIMESTAMP + ) + """ )