quote all identifiers used in internal sql queries

This should fix #102 as well as future-proof against similar issues arising for
other backends
3 files changed, 115 insertions(+), 64 deletions(-)

M yoyo/backends/base.py
M yoyo/internalmigrations/__init__.py
M yoyo/internalmigrations/v2.py
M yoyo/backends/base.py +55 -31
@@ 115,35 115,40 @@ class DatabaseBackend:
     migration_table = "_yoyo_migrations"
     is_applied_sql = """
         SELECT COUNT(1) FROM {0.migration_table_quoted}
-        WHERE id=:id"""
+        WHERE {quoted.id}=:id}"""
     mark_migration_sql = (
         "INSERT INTO {0.migration_table_quoted} "
-        "(migration_hash, migration_id, applied_at_utc) "
+        "({quoted.migration_hash}, {quoted.migration_id}, {quoted.applied_at_utc}) "
         "VALUES (:migration_hash, :migration_id, :when)"
     )
+
     unmark_migration_sql = (
         "DELETE FROM {0.migration_table_quoted} WHERE "
-        "migration_hash = :migration_hash"
+        "{quoted.migration_hash} = :migration_hash"
     )
+
     applied_migrations_sql = (
-        "SELECT migration_hash FROM "
+        "SELECT {quoted.migration_hash} FROM "
         "{0.migration_table_quoted} "
-        "ORDER by applied_at_utc"
+        "ORDER by {quoted.applied_at_utc}"
     )
-    create_test_table_sql = "CREATE TABLE {table_name_quoted} " "(id INT PRIMARY KEY)"
+    create_test_table_sql = (
+        "CREATE TABLE {table_name_quoted} ({quoted.id} INT PRIMARY KEY)"
+    )
     log_migration_sql = (
         "INSERT INTO {0.log_table_quoted} "
-        "(id, migration_hash, migration_id, operation, "
-        "username, hostname, created_at_utc) "
+        "({quoted.id}, {quoted.migration_hash}, {quoted.migration_id}, "
+        "{quoted.operation}, {quoted.username}, {quoted.hostname}, "
+        "{quoted.created_at_utc}) "
         "VALUES (:id, :migration_hash, :migration_id, "
         ":operation, :username, :hostname, :created_at_utc)"
     )
     create_lock_table_sql = (
         "CREATE TABLE {0.lock_table_quoted} ("
-        "locked INT DEFAULT 1, "
-        "ctime TIMESTAMP,"
-        "pid INT NOT NULL,"
-        "PRIMARY KEY (locked))"
+        "{quoted.locked} INT DEFAULT 1, "
+        "{quoted.ctime} TIMESTAMP,"
+        "{quoted.pid} INT NOT NULL,"
+        "PRIMARY KEY ({quoted.locked}))"
     )
 
     _driver = None

          
@@ 177,6 182,19 @@ class DatabaseBackend:
         exceptions.register(driver.DatabaseError)
         return driver
 
+    def format_sql(self, s, **kwargs):
+        """
+        Take a string in the format used by the various ``..._sql`` class
+        variables and format it.
+        """
+        quote_identifier = self.quote_identifier
+
+        class Quoter:
+            def __getattr__(self, s):
+                return quote_identifier(s)
+
+        return s.format(self, quoted=Quoter(), **kwargs)
+
     @property
     def driver(self):
         if self._driver:

          
@@ 228,7 246,9 @@ class DatabaseBackend:
         """
         table_name = "yoyo_tmp_{}".format(utils.get_random_string(10))
         table_name_quoted = self.quote_identifier(table_name)
-        sql = self.create_test_table_sql.format(table_name_quoted=table_name_quoted)
+        sql = self.format_sql(
+            self.create_test_table_sql, table_name_quoted=table_name_quoted
+        )
         try:
             with self.transaction(rollback_on_exit=True):
                 self.execute(sql)

          
@@ 237,7 257,7 @@ class DatabaseBackend:
 
         try:
             with self.transaction():
-                self.execute("DROP TABLE {}".format(table_name_quoted))
+                self.execute(f"DROP TABLE {table_name_quoted}")
         except self.DatabaseError:
             return True
         return False

          
@@ 285,19 305,19 @@ class DatabaseBackend:
         """
         Create a new savepoint with the given id
         """
-        self.execute("SAVEPOINT {}".format(id))
+        self.execute(f"SAVEPOINT {self.quote_identifier(id)}")
 
     def savepoint_release(self, id):
         """
         Release (commit) the savepoint with the given id
         """
-        self.execute("RELEASE SAVEPOINT {}".format(id))
+        self.execute(f"RELEASE SAVEPOINT {self.quote_identifier(id)}")
 
     def savepoint_rollback(self, id):
         """
         Rollback the savepoint with the given id
         """
-        self.execute("ROLLBACK TO SAVEPOINT {}".format(id))
+        self.execute(f"ROLLBACK TO SAVEPOINT {self.quote_identifier(id)}")
 
     @contextmanager
     def disable_transactions(self):

          
@@ 331,12 351,16 @@ class DatabaseBackend:
     def _insert_lock_row(self, pid, timeout, poll_interval=0.5):
         poll_interval = min(poll_interval, timeout)
         started = time.time()
+        qi = self.quote_identifier
         while True:
             try:
                 with self.transaction():
                     self.execute(
-                        "INSERT INTO {} (locked, ctime, pid) "
-                        "VALUES (1, :when, :pid)".format(self.lock_table_quoted),
+                        f"""
+                        INSERT INTO {self.lock_table_quoted}
+                        ({qi('locked')}, {qi('ctime')}, {qi('pid')})
+                        VALUES (1, :when, :pid)
+                        """,
                         {
                             "when": datetime.now(timezone.utc).replace(tzinfo=None),
                             "pid": pid,

          
@@ 345,18 369,17 @@ class DatabaseBackend:
             except self.DatabaseError:
                 if timeout and time.time() > started + timeout:
                     cursor = self.execute(
-                        "SELECT pid FROM {}".format(self.lock_table_quoted)
+                        f"SELECT {qi('pid')} FROM {self.lock_table_quoted}"
                     )
                     row = cursor.fetchone()
                     if row:
                         raise exceptions.LockTimeout(
-                            "Process {} has locked this database "
-                            "(run yoyo break-lock to remove this lock)".format(row[0])
+                            f"Process {row[0]} has locked this database "
+                            "(run yoyo break-lock to remove this lock)"
                         )
                     else:
                         raise exceptions.LockTimeout(
-                            "Database locked "
-                            "(run yoyo break-lock to remove this lock)"
+                            "Database locked (run yoyo break-lock to remove this lock)"
                         )
                 time.sleep(poll_interval)
             else:

          
@@ 364,14 387,15 @@ class DatabaseBackend:
 
     def _delete_lock_row(self, pid):
         with self.transaction():
+            qi = self.quote_identifier
             self.execute(
-                "DELETE FROM {} WHERE pid=:pid".format(self.lock_table_quoted),
+                f"DELETE FROM {self.lock_table_quoted} WHERE {qi('pid')}=:pid",
                 {"pid": pid},
             )
 
     def break_lock(self):
         with self.transaction():
-            self.execute("DELETE FROM {}".format(self.lock_table_quoted))
+            self.execute(f"DELETE FROM {self.lock_table_quoted}")
 
     def execute(self, sql, params: t.Union[abc.Mapping[str, t.Any], None] = None):
         """

          
@@ 393,7 417,7 @@ class DatabaseBackend:
         """
         try:
             with self.transaction():
-                self.execute(self.create_lock_table_sql.format(self))
+                self.execute(self.format_sql(self.create_lock_table_sql))
         except self.DatabaseError:
             pass
 

          
@@ 419,7 443,7 @@ class DatabaseBackend:
         were applied
         """
         self.ensure_internal_schema_updated()
-        sql = self.applied_migrations_sql.format(self)
+        sql = self.format_sql(self.applied_migrations_sql)
         return [row[0] for row in self.execute(sql).fetchall()]
 
     def to_apply(self, migrations):

          
@@ 523,7 547,7 @@ class DatabaseBackend:
 
     def unmark_one(self, migration, log=True):
         self.ensure_internal_schema_updated()
-        sql = self.unmark_migration_sql.format(self)
+        sql = self.format_sql(self.unmark_migration_sql)
         self.execute(sql, {"migration_hash": migration.hash})
         if log:
             self.log_migration(migration, "unmark")

          
@@ 531,7 555,7 @@ class DatabaseBackend:
     def mark_one(self, migration, log=True):
         self.ensure_internal_schema_updated()
         logger.info("Marking %s applied", migration.id)
-        sql = self.mark_migration_sql.format(self)
+        sql = self.format_sql(self.mark_migration_sql)
         self.execute(
             sql,
             {

          
@@ 544,7 568,7 @@ class DatabaseBackend:
             self.log_migration(migration, "mark")
 
     def log_migration(self, migration, operation, comment=None):
-        sql = self.log_migration_sql.format(self)
+        sql = self.format_sql(self.log_migration_sql)
         self.execute(sql, self.get_log_data(migration, operation, comment))
 
     def get_log_data(self, migration=None, operation="apply", comment=None):

          
M yoyo/internalmigrations/__init__.py +3 -2
@@ 49,8 49,9 @@ def get_current_version(backend):
         return 0
     if version_table not in tables:
         return 1
+    qi = backend.quote_identifier
     cursor = backend.execute(
-        f"SELECT max(version) FROM {backend.quote_identifier(version_table)}"
+        f"SELECT max({qi('version')}) FROM {qi(version_table)}"
     )
     version = cursor.fetchone()[0]
     assert version in schema_versions

          
@@ 65,6 66,6 @@ def mark_schema_version(backend, version
     if version < USE_VERSION_TABLE_FROM:
         return
     backend.execute(
-        "INSERT INTO {0.version_table_quoted} VALUES (:version, :when)".format(backend),
+        f"INSERT INTO {backend.version_table_quoted} VALUES (:version, :when)",
         {"version": version, "when": datetime.now(timezone.utc).replace(tzinfo=None)},
     )

          
M yoyo/internalmigrations/v2.py +57 -31
@@ 8,10 8,11 @@ from yoyo.migrations import get_migratio
 
 
 def upgrade(backend):
+    qi = backend.quote_identifier
     create_log_table(backend)
     create_version_table(backend)
     cursor = backend.execute(
-        "SELECT id, ctime FROM {}".format(backend.migration_table_quoted)
+        f"SELECT {qi('id')}, {qi('ctime')} FROM {backend.migration_table_quoted}"
     )
     migration_id = ""
     created_at = datetime(1970, 1, 1)

          
@@ 27,56 28,81 @@ def upgrade(backend):
             migration_hash=migration_hash,
             migration_id=migration_id,
         )
+        qi = backend.quote_identifier
         backend.execute(
-            "INSERT INTO {0.log_table_quoted} "
-            "(id, migration_hash, migration_id, operation, created_at_utc, "
-            "username, hostname, comment) "
-            "VALUES "
-            "(:id, :migration_hash, :migration_id, 'apply', :created_at_utc, "
-            ":username, :hostname, :comment)".format(backend),
+            f"""
+            INSERT INTO {backend.log_table_quoted} (
+                {qi('id')},
+                {qi('migration_hash')},
+                {qi('migration_id')},
+                {qi('operation')},
+                {qi('created_at_utc')},
+                {qi('username')},
+                {qi('hostname')},
+                {qi('comment')}
+            ) VALUES (
+                :id, :migration_hash, :migration_id, 'apply', :created_at_utc,
+                :username, :hostname, :comment
+            )
+            """,
             log_data,
         )
 
     backend.execute("DROP TABLE {0.migration_table_quoted}".format(backend))
     create_migration_table(backend)
     backend.execute(
-        "INSERT INTO {0.migration_table_quoted} "
-        "SELECT migration_hash, migration_id, created_at_utc "
-        "FROM {0.log_table_quoted}".format(backend)
+        f"""
+        INSERT INTO {backend.migration_table_quoted}
+        SELECT {qi('migration_hash')}, {qi('migration_id')}, {qi('created_at_utc')}
+        FROM {backend.log_table_quoted}
+        """
     )
 
 
 def create_migration_table(backend):
+    qi = backend.quote_identifier
     backend.execute(
-        "CREATE TABLE {0.migration_table_quoted} ( "
-        # sha256 hash of the migration id
-        "migration_hash VARCHAR(64), "
-        # The migration id (ie path basename without extension)
-        "migration_id VARCHAR(255), "
-        # When this id was applied
-        "applied_at_utc TIMESTAMP, "
-        "PRIMARY KEY (migration_hash))".format(backend)
+        # migration_hash: sha256 hash of the migration id
+        # migration_id: identifier of the migration file
+        #               (path basename without extension)
+        # applied_at_utc: time in UTC of when the id was applied
+        f"""
+        CREATE TABLE {backend.migration_table_quoted} (
+            {qi('migration_hash')} VARCHAR(64),
+            {qi('migration_id')} VARCHAR(255),
+            {qi('applied_at_utc')} TIMESTAMP,
+            PRIMARY KEY ({qi('migration_hash')})
+        )
+        """
     )
 
 
 def create_log_table(backend):
+    qi = backend.quote_identifier
     backend.execute(
-        "CREATE TABLE {0.log_table_quoted} ( "
-        "id VARCHAR(36), "
-        "migration_hash VARCHAR(64), "
-        "migration_id VARCHAR(255), "
-        "operation VARCHAR(10), "
-        "username VARCHAR(255), "
-        "hostname VARCHAR(255), "
-        "comment VARCHAR(255), "
-        "created_at_utc TIMESTAMP, "
-        "PRIMARY KEY (id))".format(backend)
+        f"""
+        CREATE TABLE {backend.log_table_quoted} (
+            {qi('id')} VARCHAR(36),
+            {qi('migration_hash')} VARCHAR(64),
+            {qi('migration_id')} VARCHAR(255),
+            {qi('operation')} VARCHAR(10),
+            {qi('username')} VARCHAR(255),
+            {qi('hostname')} VARCHAR(255),
+            {qi('comment')} VARCHAR(255),
+            {qi('created_at_utc')} TIMESTAMP,
+            PRIMARY KEY ({qi('id')})
+        )
+        """
     )
 
 
 def create_version_table(backend):
+    qi = backend.quote_identifier
     backend.execute(
-        "CREATE TABLE {0.version_table_quoted} ("
-        "version INT NOT NULL PRIMARY KEY, "
-        "installed_at_utc TIMESTAMP)".format(backend)
+        f"""
+        CREATE TABLE {backend.version_table_quoted} (
+            {qi('version')} INT NOT NULL PRIMARY KEY,
+            {qi('installed_at_utc')} TIMESTAMP
+        )
+        """
     )