@@ 1,4 1,3 @@
-import os
from functools import partial, wraps
from contextlib import contextmanager
import logging
@@ 6,7 5,7 @@ import re
import sqlite3
import threading
from collections import namedtuple
-from urllib.parse import urlparse, unquote
+from urllib.parse import urlparse
from deeno import sql, DeenoException
@@ 381,16 380,21 @@ class PostgresDatabase(Database):
placeholder = '%s'
schema_separator_char = '.'
- def __init__(self, host, port=None, user=None, password=None,
- database=None):
+ def __init__(self, dsn=None, host=None, port=None, user=None,
+ password=None, database=None):
if isinstance(psycopg2, ImportError):
raise psycopg2
super(PostgresDatabase, self).__init__()
- self._config = {
- 'host': host, 'port': port,
- 'user': user, 'password': password,
- 'database': database
- }
+ if dsn:
+ self._config = {
+ 'dsn': dsn
+ }
+ else:
+ self._config = {
+ 'host': host, 'port': port,
+ 'user': user, 'password': password,
+ 'database': database
+ }
def connect(self):
self._tl.connection = psycopg2.connect(**self._config)
@@ 477,19 481,10 @@ class PostgresDatabase(Database):
def get_db(uri):
"""Database factory method based on a connection URI."""
- u = urlparse(uri)
-
- if u.scheme == 'sqlite':
- path = u.netloc if u.netloc == ':memory:' else u.path
- return SQLiteDatabase(path)
+ if uri == ':memory:' or uri.startswith('file:'):
+ return SQLiteDatabase(urlparse(uri).path)
- if u.scheme == 'postgres':
- if not u.hostname:
- host = os.path.dirname(u.path)
- else:
- host = unquote(u.hostname)
- database = os.path.basename(u.path)
- return PostgresDatabase(host, port=u.port, user=u.username,
- password=u.password, database=database)
+ if uri.startswith('postgres:') or uri.startswith('postgresql:'):
+ return PostgresDatabase(dsn=uri)
- raise ValueError('unsupported db: %s' % u.scheme)
+ raise ValueError('unsupported uri: %s' % uri)
@@ 4,6 4,7 @@ from collections import namedtuple
import subprocess
import shutil
import time
+from urllib.parse import quote
import psycopg2
@@ 593,15 594,41 @@ class AbstractDatabaseTest(object):
r = self.db.r.select.get(integer=1)
eq_(r, {'text': 'x', 'integer': 1})
+ def test_connect_with_uri(self):
+ raise NotImplementedError
+
class SqliteDatabaseTest(AbstractDatabaseTest, unittest.TestCase):
ROWCOUNT_FOR_SELECT = False
SERIAL_KEY_TYPE = 'INTEGER PRIMARY KEY ASC'
+ TEST_DB_FNAME = os.path.join(TESTS_WD, 'sqlite.db')
def setUp(self):
self.db = SQLiteDatabase(':memory:')
+ def tearDown(self):
+ if os.path.exists(self.TEST_DB_FNAME):
+ os.remove(self.TEST_DB_FNAME)
+
+ def test_connect_with_uri(self):
+
+ db = get_db(':memory:')
+ eq_(type(db), SQLiteDatabase)
+ db.fetchall('SELECT 1')
+
+ db = get_db('file://%s' % self.TEST_DB_FNAME)
+ eq_(type(db), SQLiteDatabase)
+ db.fetchall('SELECT 1')
+
+ db = get_db('file:%s' % self.TEST_DB_FNAME)
+ eq_(type(db), SQLiteDatabase)
+ db.fetchall('SELECT 1')
+
+ db = get_db('file:%s' % os.path.relpath(self.TEST_DB_FNAME))
+ eq_(type(db), SQLiteDatabase)
+ db.fetchall('SELECT 1')
+
class PostgresDatabaseTest(AbstractDatabaseTest, unittest.TestCase):
@@ 664,7 691,7 @@ class PostgresDatabaseTest(AbstractDatab
def setUp(self):
- self.db = PostgresDatabase(self.pgtc, user=os.environ['USER'],
+ self.db = PostgresDatabase(host=self.pgtc, user=os.environ['USER'],
database='deeno')
def tearDown(self):
@@ 722,51 749,19 @@ class PostgresDatabaseTest(AbstractDatab
r = self.db.r['s1.t1'].get(c1=2)
eq_(r, {'c1': 2, 'c2': 'y'})
-
-class GetDbTest(unittest.TestCase):
-
- def setUp(self):
- self.db = SQLiteDatabase(':memory:')
- self.sqlite_fname = os.path.join(TESTS_WD, 'sqlite.db')
- self.cleanup()
-
- def cleanup(self):
- if os.path.exists(self.sqlite_fname):
- os.remove(self.sqlite_fname)
+ def test_connect_with_uri(self):
- def test(self):
-
- db = get_db('sqlite://:memory:')
- eq_(type(db), SQLiteDatabase)
- eq_(db._path, ':memory:')
+ host = quote(self.pgtc, safe='')
+ uri = 'postgres://%s@%s/deeno' % (os.environ['USER'], host)
+ db = PostgresDatabase(uri)
+ db.fetchall('SELECT 1')
- db = get_db('sqlite://%s' % self.sqlite_fname)
- eq_(type(db), SQLiteDatabase)
- eq_(db._path, self.sqlite_fname)
-
- db = get_db('postgres://joe:secret@thehost:44/thedb')
+ uri = 'postgres://joe:secret@thehost:44/thedb'
+ db = get_db(uri)
eq_(type(db), PostgresDatabase)
- eq_(db._config, {
- 'host': 'thehost', 'port': 44,
- 'user': 'joe', 'password': 'secret',
- 'database': 'thedb',
- })
+ eq_(db._config, {'dsn': uri})
- db = get_db('postgres:///path/to/local/cluster/thedb')
+ uri = 'postgresql://joe:secret@thehost:44/thedb'
+ db = get_db(uri)
eq_(type(db), PostgresDatabase)
- eq_(db._config, {
- 'host': '/path/to/local/cluster', 'port': None,
- 'user': None, 'password': None,
- 'database': 'thedb',
- })
-
- db = get_db('postgres://%2Fpath%2Fto%2Flocal%2Fcluster/thedb')
- eq_(type(db), PostgresDatabase)
- eq_(db._config, {
- 'host': '/path/to/local/cluster', 'port': None,
- 'user': None, 'password': None,
- 'database': 'thedb',
- })
-
- with self.assertRaises(ValueError):
- get_db('unknown://host/db')
+ eq_(db._config, {'dsn': uri})