@@ 1,3 1,4 @@
+import os
from functools import partial, wraps
from contextlib import contextmanager
import logging
@@ 5,6 6,7 @@ import re
import sqlite3
import threading
from collections import namedtuple
+from urllib.parse import urlparse, unquote
from deeno import sql, DeenoException
@@ 379,17 381,19 @@ class PostgresDatabase(Database):
placeholder = '%s'
schema_separator_char = '.'
- def __init__(self, host, user, dbname):
+ def __init__(self, host, port=None, user=None, password=None,
+ database=None):
if isinstance(psycopg2, ImportError):
raise psycopg2
super(PostgresDatabase, self).__init__()
- self._host = host
- self._user = user
- self._dbname = dbname
+ self._config = {
+ 'host': host, 'port': port,
+ 'user': user, 'password': password,
+ 'database': database
+ }
def connect(self):
- self._tl.connection = psycopg2.connect(host=self._host,
- user=self._user, database=self._dbname)
+ self._tl.connection = psycopg2.connect(**self._config)
self._tl.connection.cursor_factory = NamedTupleCursor
self._tl.connection.autocommit = True
@@ 468,3 472,24 @@ class PostgresDatabase(Database):
def insert_and_return_row(self, relation, row):
rows = relation.insert(row=row, returning='*')
return next(rows)
+
+
+def get_db(uri):
+ """Database factory method based on a connection URI."""
+
+ u = urlparse(uri)
+
+ if u.scheme == 'sqlite':
+ path = u.netloc if u.netloc == ':memory:' else u.path
+ return SQLiteDatabase(path)
+
+ if u.scheme == 'postgres':
+ if not u.hostname:
+ host = os.path.dirname(u.path)
+ else:
+ host = unquote(u.hostname)
+ database = os.path.basename(u.path)
+ return PostgresDatabase(host, port=u.port, user=u.username,
+ password=u.password, database=database)
+
+ raise ValueError('unsupported db: %s' % u.scheme)
@@ 1,7 1,11 @@
+import os
import mock
from deeno.db import Database
+HERE = os.path.dirname(__file__)
+TESTS_WD = os.path.realpath(os.path.join(HERE, '..', '..', 'var', 'tests'))
+
def MockedDatabase():
@@ 12,9 12,9 @@ from nose.tools import eq_, ok_, raises
import mock
from deeno.db import (Record, Relation, Relatiomat, NoMatchingRecord,
- SQLiteDatabase, PostgresDatabase, RelationNotFound)
+ SQLiteDatabase, PostgresDatabase, RelationNotFound, get_db)
-from tests import MockedDatabase
+from tests import MockedDatabase, TESTS_WD
def make_row(**kwargs):
@@ 620,9 620,8 @@ class PostgresDatabaseTest(AbstractDatab
def setUpClass(cls):
here = os.path.dirname(__file__)
- pgconf = os.path.join(os.path.dirname(__file__), 'ressources',
- 'postgresql.conf')
- pgtc = os.path.realpath(os.path.join(here, '..', '..', 'var', 'pgtc'))
+ pgconf = os.path.join(here, 'ressources', 'postgresql.conf')
+ pgtc = os.path.join(TESTS_WD, 'pgtc')
for pgbin in cls.PG_BINS:
if os.path.exists(os.path.join(pgbin, 'postgres')):
@@ 665,7 664,8 @@ class PostgresDatabaseTest(AbstractDatab
def setUp(self):
- self.db = PostgresDatabase(self.pgtc, os.environ['USER'], 'deeno')
+ self.db = PostgresDatabase(self.pgtc, user=os.environ['USER'],
+ database='deeno')
def tearDown(self):
@@ 721,3 721,52 @@ class PostgresDatabaseTest(AbstractDatab
eq_(r, {'c1': 2, 'c2': 'y'})
r = self.db.r['s1.t1'].get(c1=2)
eq_(r, {'c1': 2, 'c2': 'y'})
+
+
+class GetDbTest(unittest.TestCase):
+
+ def setUp(self):
+ self.db = SQLiteDatabase(':memory:')
+ self.sqlite_fname = os.path.join(TESTS_WD, 'sqlite.db')
+ self.cleanup()
+
+ def cleanup(self):
+ if os.path.exists(self.sqlite_fname):
+ os.remove(self.sqlite_fname)
+
+ def test(self):
+
+ db = get_db('sqlite://:memory:')
+ eq_(type(db), SQLiteDatabase)
+ eq_(db._path, ':memory:')
+
+ db = get_db('sqlite://%s' % self.sqlite_fname)
+ eq_(type(db), SQLiteDatabase)
+ eq_(db._path, self.sqlite_fname)
+
+ db = get_db('postgres://joe:secret@thehost:44/thedb')
+ eq_(type(db), PostgresDatabase)
+ eq_(db._config, {
+ 'host': 'thehost', 'port': 44,
+ 'user': 'joe', 'password': 'secret',
+ 'database': 'thedb',
+ })
+
+ db = get_db('postgres:///path/to/local/cluster/thedb')
+ eq_(type(db), PostgresDatabase)
+ eq_(db._config, {
+ 'host': '/path/to/local/cluster', 'port': None,
+ 'user': None, 'password': None,
+ 'database': 'thedb',
+ })
+
+ db = get_db('postgres://%2Fpath%2Fto%2Flocal%2Fcluster/thedb')
+ eq_(type(db), PostgresDatabase)
+ eq_(db._config, {
+ 'host': '/path/to/local/cluster', 'port': None,
+ 'user': None, 'password': None,
+ 'database': 'thedb',
+ })
+
+ with self.assertRaises(ValueError):
+ get_db('unknown://host/db')