1d6117b3648a — Oben Sonne 9 years ago
pass uri parsing do db binding lib
2 files changed, 57 insertions(+), 67 deletions(-)

M src/deeno/db.py
M src/tests/test_db.py
M src/deeno/db.py +18 -23
@@ 1,4 1,3 @@ 
-import os
 from functools import partial, wraps
 from contextlib import contextmanager
 import logging

          
@@ 6,7 5,7 @@ import re
 import sqlite3
 import threading
 from collections import namedtuple
-from urllib.parse import urlparse, unquote
+from urllib.parse import urlparse
 
 from deeno import sql, DeenoException
 

          
@@ 381,16 380,21 @@ class PostgresDatabase(Database):
     placeholder = '%s'
     schema_separator_char = '.'
 
-    def __init__(self, host, port=None, user=None, password=None,
-            database=None):
+    def __init__(self, dsn=None, host=None, port=None, user=None,
+            password=None, database=None):
         if isinstance(psycopg2, ImportError):
             raise psycopg2
         super(PostgresDatabase, self).__init__()
-        self._config = {
-            'host': host, 'port': port,
-            'user': user, 'password': password,
-            'database': database
-        }
+        if dsn:
+            self._config = {
+                'dsn': dsn
+            }
+        else:
+            self._config = {
+                'host': host, 'port': port,
+                'user': user, 'password': password,
+                'database': database
+            }
 
     def connect(self):
         self._tl.connection = psycopg2.connect(**self._config)

          
@@ 477,19 481,10 @@ class PostgresDatabase(Database):
 def get_db(uri):
     """Database factory method based on a connection URI."""
 
-    u = urlparse(uri)
-
-    if u.scheme == 'sqlite':
-        path = u.netloc if u.netloc == ':memory:' else u.path
-        return SQLiteDatabase(path)
+    if uri == ':memory:' or uri.startswith('file:'):
+        return SQLiteDatabase(urlparse(uri).path)
 
-    if u.scheme == 'postgres':
-        if not u.hostname:
-            host = os.path.dirname(u.path)
-        else:
-            host = unquote(u.hostname)
-        database = os.path.basename(u.path)
-        return PostgresDatabase(host, port=u.port, user=u.username,
-            password=u.password, database=database)
+    if uri.startswith('postgres:') or uri.startswith('postgresql:'):
+        return PostgresDatabase(dsn=uri)
 
-    raise ValueError('unsupported db: %s' % u.scheme)
+    raise ValueError('unsupported uri: %s' % uri)

          
M src/tests/test_db.py +39 -44
@@ 4,6 4,7 @@ from collections import namedtuple
 import subprocess
 import shutil
 import time
+from urllib.parse import quote
 
 import psycopg2
 

          
@@ 593,15 594,41 @@ class AbstractDatabaseTest(object):
         r = self.db.r.select.get(integer=1)
         eq_(r, {'text': 'x', 'integer': 1})
 
+    def test_connect_with_uri(self):
+        raise NotImplementedError
+
 
 class SqliteDatabaseTest(AbstractDatabaseTest, unittest.TestCase):
 
     ROWCOUNT_FOR_SELECT = False
     SERIAL_KEY_TYPE = 'INTEGER PRIMARY KEY ASC'
+    TEST_DB_FNAME = os.path.join(TESTS_WD, 'sqlite.db')
 
     def setUp(self):
         self.db = SQLiteDatabase(':memory:')
 
+    def tearDown(self):
+        if os.path.exists(self.TEST_DB_FNAME):
+            os.remove(self.TEST_DB_FNAME)
+
+    def test_connect_with_uri(self):
+
+        db = get_db(':memory:')
+        eq_(type(db), SQLiteDatabase)
+        db.fetchall('SELECT 1')
+
+        db = get_db('file://%s' % self.TEST_DB_FNAME)
+        eq_(type(db), SQLiteDatabase)
+        db.fetchall('SELECT 1')
+
+        db = get_db('file:%s' % self.TEST_DB_FNAME)
+        eq_(type(db), SQLiteDatabase)
+        db.fetchall('SELECT 1')
+
+        db = get_db('file:%s' % os.path.relpath(self.TEST_DB_FNAME))
+        eq_(type(db), SQLiteDatabase)
+        db.fetchall('SELECT 1')
+
 
 class PostgresDatabaseTest(AbstractDatabaseTest, unittest.TestCase):
 

          
@@ 664,7 691,7 @@ class PostgresDatabaseTest(AbstractDatab
 
     def setUp(self):
 
-        self.db = PostgresDatabase(self.pgtc, user=os.environ['USER'],
+        self.db = PostgresDatabase(host=self.pgtc, user=os.environ['USER'],
             database='deeno')
 
     def tearDown(self):

          
@@ 722,51 749,19 @@ class PostgresDatabaseTest(AbstractDatab
         r = self.db.r['s1.t1'].get(c1=2)
         eq_(r, {'c1': 2, 'c2': 'y'})
 
-
-class GetDbTest(unittest.TestCase):
-
-    def setUp(self):
-        self.db = SQLiteDatabase(':memory:')
-        self.sqlite_fname = os.path.join(TESTS_WD, 'sqlite.db')
-        self.cleanup()
-
-    def cleanup(self):
-        if os.path.exists(self.sqlite_fname):
-            os.remove(self.sqlite_fname)
+    def test_connect_with_uri(self):
 
-    def test(self):
-
-        db = get_db('sqlite://:memory:')
-        eq_(type(db), SQLiteDatabase)
-        eq_(db._path, ':memory:')
+        host = quote(self.pgtc, safe='')
+        uri = 'postgres://%s@%s/deeno' % (os.environ['USER'], host)
+        db = PostgresDatabase(uri)
+        db.fetchall('SELECT 1')
 
-        db = get_db('sqlite://%s' % self.sqlite_fname)
-        eq_(type(db), SQLiteDatabase)
-        eq_(db._path, self.sqlite_fname)
-
-        db = get_db('postgres://joe:secret@thehost:44/thedb')
+        uri = 'postgres://joe:secret@thehost:44/thedb'
+        db = get_db(uri)
         eq_(type(db), PostgresDatabase)
-        eq_(db._config, {
-            'host': 'thehost', 'port': 44,
-            'user': 'joe', 'password': 'secret',
-            'database': 'thedb',
-        })
+        eq_(db._config, {'dsn': uri})
 
-        db = get_db('postgres:///path/to/local/cluster/thedb')
+        uri = 'postgresql://joe:secret@thehost:44/thedb'
+        db = get_db(uri)
         eq_(type(db), PostgresDatabase)
-        eq_(db._config, {
-            'host': '/path/to/local/cluster', 'port': None,
-            'user': None, 'password': None,
-            'database': 'thedb',
-        })
-
-        db = get_db('postgres://%2Fpath%2Fto%2Flocal%2Fcluster/thedb')
-        eq_(type(db), PostgresDatabase)
-        eq_(db._config, {
-            'host': '/path/to/local/cluster', 'port': None,
-            'user': None, 'password': None,
-            'database': 'thedb',
-        })
-
-        with self.assertRaises(ValueError):
-            get_db('unknown://host/db')
+        eq_(db._config, {'dsn': uri})