# HG changeset patch # User Oben Sonne # Date 1414261632 -7200 # Sat Oct 25 20:27:12 2014 +0200 # Node ID 1d6117b3648a7a9d576560c87742d970d4979e9d # Parent 71eadafb3343cc915b9a91495465ff86c76c5f63 pass uri parsing do db binding lib diff --git a/src/deeno/db.py b/src/deeno/db.py --- a/src/deeno/db.py +++ b/src/deeno/db.py @@ -1,4 +1,3 @@ -import os from functools import partial, wraps from contextlib import contextmanager import logging @@ -6,7 +5,7 @@ import sqlite3 import threading from collections import namedtuple -from urllib.parse import urlparse, unquote +from urllib.parse import urlparse from deeno import sql, DeenoException @@ -381,16 +380,21 @@ placeholder = '%s' schema_separator_char = '.' - def __init__(self, host, port=None, user=None, password=None, - database=None): + def __init__(self, dsn=None, host=None, port=None, user=None, + password=None, database=None): if isinstance(psycopg2, ImportError): raise psycopg2 super(PostgresDatabase, self).__init__() - self._config = { - 'host': host, 'port': port, - 'user': user, 'password': password, - 'database': database - } + if dsn: + self._config = { + 'dsn': dsn + } + else: + self._config = { + 'host': host, 'port': port, + 'user': user, 'password': password, + 'database': database + } def connect(self): self._tl.connection = psycopg2.connect(**self._config) @@ -477,19 +481,10 @@ def get_db(uri): """Database factory method based on a connection URI.""" - u = urlparse(uri) - - if u.scheme == 'sqlite': - path = u.netloc if u.netloc == ':memory:' else u.path - return SQLiteDatabase(path) + if uri == ':memory:' or uri.startswith('file:'): + return SQLiteDatabase(urlparse(uri).path) - if u.scheme == 'postgres': - if not u.hostname: - host = os.path.dirname(u.path) - else: - host = unquote(u.hostname) - database = os.path.basename(u.path) - return PostgresDatabase(host, port=u.port, user=u.username, - password=u.password, database=database) + if uri.startswith('postgres:') or uri.startswith('postgresql:'): + return PostgresDatabase(dsn=uri) - raise ValueError('unsupported db: %s' % u.scheme) + raise ValueError('unsupported uri: %s' % uri) diff --git a/src/tests/test_db.py b/src/tests/test_db.py --- a/src/tests/test_db.py +++ b/src/tests/test_db.py @@ -4,6 +4,7 @@ import subprocess import shutil import time +from urllib.parse import quote import psycopg2 @@ -593,15 +594,41 @@ r = self.db.r.select.get(integer=1) eq_(r, {'text': 'x', 'integer': 1}) + def test_connect_with_uri(self): + raise NotImplementedError + class SqliteDatabaseTest(AbstractDatabaseTest, unittest.TestCase): ROWCOUNT_FOR_SELECT = False SERIAL_KEY_TYPE = 'INTEGER PRIMARY KEY ASC' + TEST_DB_FNAME = os.path.join(TESTS_WD, 'sqlite.db') def setUp(self): self.db = SQLiteDatabase(':memory:') + def tearDown(self): + if os.path.exists(self.TEST_DB_FNAME): + os.remove(self.TEST_DB_FNAME) + + def test_connect_with_uri(self): + + db = get_db(':memory:') + eq_(type(db), SQLiteDatabase) + db.fetchall('SELECT 1') + + db = get_db('file://%s' % self.TEST_DB_FNAME) + eq_(type(db), SQLiteDatabase) + db.fetchall('SELECT 1') + + db = get_db('file:%s' % self.TEST_DB_FNAME) + eq_(type(db), SQLiteDatabase) + db.fetchall('SELECT 1') + + db = get_db('file:%s' % os.path.relpath(self.TEST_DB_FNAME)) + eq_(type(db), SQLiteDatabase) + db.fetchall('SELECT 1') + class PostgresDatabaseTest(AbstractDatabaseTest, unittest.TestCase): @@ -664,7 +691,7 @@ def setUp(self): - self.db = PostgresDatabase(self.pgtc, user=os.environ['USER'], + self.db = PostgresDatabase(host=self.pgtc, user=os.environ['USER'], database='deeno') def tearDown(self): @@ -722,51 +749,19 @@ r = self.db.r['s1.t1'].get(c1=2) eq_(r, {'c1': 2, 'c2': 'y'}) - -class GetDbTest(unittest.TestCase): - - def setUp(self): - self.db = SQLiteDatabase(':memory:') - self.sqlite_fname = os.path.join(TESTS_WD, 'sqlite.db') - self.cleanup() - - def cleanup(self): - if os.path.exists(self.sqlite_fname): - os.remove(self.sqlite_fname) + def test_connect_with_uri(self): - def test(self): - - db = get_db('sqlite://:memory:') - eq_(type(db), SQLiteDatabase) - eq_(db._path, ':memory:') + host = quote(self.pgtc, safe='') + uri = 'postgres://%s@%s/deeno' % (os.environ['USER'], host) + db = PostgresDatabase(uri) + db.fetchall('SELECT 1') - db = get_db('sqlite://%s' % self.sqlite_fname) - eq_(type(db), SQLiteDatabase) - eq_(db._path, self.sqlite_fname) - - db = get_db('postgres://joe:secret@thehost:44/thedb') + uri = 'postgres://joe:secret@thehost:44/thedb' + db = get_db(uri) eq_(type(db), PostgresDatabase) - eq_(db._config, { - 'host': 'thehost', 'port': 44, - 'user': 'joe', 'password': 'secret', - 'database': 'thedb', - }) + eq_(db._config, {'dsn': uri}) - db = get_db('postgres:///path/to/local/cluster/thedb') + uri = 'postgresql://joe:secret@thehost:44/thedb' + db = get_db(uri) eq_(type(db), PostgresDatabase) - eq_(db._config, { - 'host': '/path/to/local/cluster', 'port': None, - 'user': None, 'password': None, - 'database': 'thedb', - }) - - db = get_db('postgres://%2Fpath%2Fto%2Flocal%2Fcluster/thedb') - eq_(type(db), PostgresDatabase) - eq_(db._config, { - 'host': '/path/to/local/cluster', 'port': None, - 'user': None, 'password': None, - 'database': 'thedb', - }) - - with self.assertRaises(ValueError): - get_db('unknown://host/db') + eq_(db._config, {'dsn': uri})