303b62af36f4 — Oben Sonne 9 years ago
provide a factory function for database objects
3 files changed, 90 insertions(+), 12 deletions(-)

M src/deeno/db.py
M src/tests/__init__.py
M src/tests/test_db.py
M src/deeno/db.py +31 -6
@@ 1,3 1,4 @@ 
+import os
 from functools import partial, wraps
 from contextlib import contextmanager
 import logging

          
@@ 5,6 6,7 @@ import re
 import sqlite3
 import threading
 from collections import namedtuple
+from urllib.parse import urlparse, unquote
 
 from deeno import sql, DeenoException
 

          
@@ 379,17 381,19 @@ class PostgresDatabase(Database):
     placeholder = '%s'
     schema_separator_char = '.'
 
-    def __init__(self, host, user, dbname):
+    def __init__(self, host, port=None, user=None, password=None,
+            database=None):
         if isinstance(psycopg2, ImportError):
             raise psycopg2
         super(PostgresDatabase, self).__init__()
-        self._host = host
-        self._user = user
-        self._dbname = dbname
+        self._config = {
+            'host': host, 'port': port,
+            'user': user, 'password': password,
+            'database': database
+        }
 
     def connect(self):
-        self._tl.connection = psycopg2.connect(host=self._host,
-            user=self._user, database=self._dbname)
+        self._tl.connection = psycopg2.connect(**self._config)
         self._tl.connection.cursor_factory = NamedTupleCursor
         self._tl.connection.autocommit = True
 

          
@@ 468,3 472,24 @@ class PostgresDatabase(Database):
     def insert_and_return_row(self, relation, row):
         rows = relation.insert(row=row, returning='*')
         return next(rows)
+
+
+def get_db(uri):
+    """Database factory method based on a connection URI."""
+
+    u = urlparse(uri)
+
+    if u.scheme == 'sqlite':
+        path = u.netloc if u.netloc == ':memory:' else u.path
+        return SQLiteDatabase(path)
+
+    if u.scheme == 'postgres':
+        if not u.hostname:
+            host = os.path.dirname(u.path)
+        else:
+            host = unquote(u.hostname)
+        database = os.path.basename(u.path)
+        return PostgresDatabase(host, port=u.port, user=u.username,
+            password=u.password, database=database)
+
+    raise ValueError('unsupported db: %s' % u.scheme)

          
M src/tests/__init__.py +4 -0
@@ 1,7 1,11 @@ 
+import os
 import mock
 
 from deeno.db import Database
 
+HERE = os.path.dirname(__file__)
+TESTS_WD = os.path.realpath(os.path.join(HERE, '..', '..', 'var', 'tests'))
+
 
 def MockedDatabase():
 

          
M src/tests/test_db.py +55 -6
@@ 12,9 12,9 @@ from nose.tools import eq_, ok_, raises
 import mock
 
 from deeno.db import (Record, Relation, Relatiomat, NoMatchingRecord,
-    SQLiteDatabase, PostgresDatabase, RelationNotFound)
+    SQLiteDatabase, PostgresDatabase, RelationNotFound, get_db)
 
-from tests import MockedDatabase
+from tests import MockedDatabase, TESTS_WD
 
 
 def make_row(**kwargs):

          
@@ 620,9 620,8 @@ class PostgresDatabaseTest(AbstractDatab
     def setUpClass(cls):
 
         here = os.path.dirname(__file__)
-        pgconf = os.path.join(os.path.dirname(__file__), 'ressources',
-            'postgresql.conf')
-        pgtc = os.path.realpath(os.path.join(here, '..', '..', 'var', 'pgtc'))
+        pgconf = os.path.join(here, 'ressources', 'postgresql.conf')
+        pgtc = os.path.join(TESTS_WD, 'pgtc')
 
         for pgbin in cls.PG_BINS:
             if os.path.exists(os.path.join(pgbin, 'postgres')):

          
@@ 665,7 664,8 @@ class PostgresDatabaseTest(AbstractDatab
 
     def setUp(self):
 
-        self.db = PostgresDatabase(self.pgtc, os.environ['USER'], 'deeno')
+        self.db = PostgresDatabase(self.pgtc, user=os.environ['USER'],
+            database='deeno')
 
     def tearDown(self):
 

          
@@ 721,3 721,52 @@ class PostgresDatabaseTest(AbstractDatab
         eq_(r, {'c1': 2, 'c2': 'y'})
         r = self.db.r['s1.t1'].get(c1=2)
         eq_(r, {'c1': 2, 'c2': 'y'})
+
+
+class GetDbTest(unittest.TestCase):
+
+    def setUp(self):
+        self.db = SQLiteDatabase(':memory:')
+        self.sqlite_fname = os.path.join(TESTS_WD, 'sqlite.db')
+        self.cleanup()
+
+    def cleanup(self):
+        if os.path.exists(self.sqlite_fname):
+            os.remove(self.sqlite_fname)
+
+    def test(self):
+
+        db = get_db('sqlite://:memory:')
+        eq_(type(db), SQLiteDatabase)
+        eq_(db._path, ':memory:')
+
+        db = get_db('sqlite://%s' % self.sqlite_fname)
+        eq_(type(db), SQLiteDatabase)
+        eq_(db._path, self.sqlite_fname)
+
+        db = get_db('postgres://joe:secret@thehost:44/thedb')
+        eq_(type(db), PostgresDatabase)
+        eq_(db._config, {
+            'host': 'thehost', 'port': 44,
+            'user': 'joe', 'password': 'secret',
+            'database': 'thedb',
+        })
+
+        db = get_db('postgres:///path/to/local/cluster/thedb')
+        eq_(type(db), PostgresDatabase)
+        eq_(db._config, {
+            'host': '/path/to/local/cluster', 'port': None,
+            'user': None, 'password': None,
+            'database': 'thedb',
+        })
+
+        db = get_db('postgres://%2Fpath%2Fto%2Flocal%2Fcluster/thedb')
+        eq_(type(db), PostgresDatabase)
+        eq_(db._config, {
+            'host': '/path/to/local/cluster', 'port': None,
+            'user': None, 'password': None,
+            'database': 'thedb',
+        })
+
+        with self.assertRaises(ValueError):
+            get_db('unknown://host/db')