M __init__.py +2 -1
@@ 3,7 3,8 @@
from gf import dispatch as mm
def method(*types):
- "a method decorator on top of David Mertz's GF package"
+ """ decorator to globally register a method in a gf
+ named like the method itself"""
def _wrap(meth):
name = meth.__name__
gf = _METHS.get(name)
M dispatch.py +31 -17
@@ 3,18 3,29 @@ from operator import mul
def proximity(klass, mro):
return mro.index(klass)
-def lexicographic_mro(signature, matches):
- "dispatch ranking similar to CLOS"
- mros = tuple(klass.__mro__ for klass in signature)
- return sorted((map(proximity, sig, mros), func)
- for sig, func in matches)
+def check_ambiguity(candidates, mros):
+ def dominates(dom, sub,
+ orders=tuple(dict((t, i) for i, t in enumerate(mro))
+ for mro in mros)):
+ if dom is sub:
+ return False
+ return all(order[d] <= order[s]
+ for d, s, order in zip(dom, sub, orders))
+ return [cand for cand, _func in candidates
+ if not any(dominates(dom, cand)
+ for dom, _xxx in candidates)]
def clos_mro(types_func_map, callsig):
siglen = len(callsig)
- table = ((sig, func) for sig, func in types_func_map.iteritems()
+ table = [(sig, func) for sig, func in types_func_map.iteritems()
if len(sig) == siglen
- and reduce(mul, map(issubclass, callsig, sig)))
- return (x[1] for x in lexicographic_mro(callsig, table))
+ and reduce(mul, map(issubclass, callsig, sig))]
+ mros = tuple(klass.__mro__ for klass in callsig)
+ if len(check_ambiguity(table, mros)) > 1:
+ raise TypeError("ambigous call; types=%r; candidates=%r" %
+ (types_func_map, table))
+ return (x[1] for x in sorted((map(proximity, sig, mros), func)
+ for sig, func in table))
class GF(object):
def __init__(self, mro=clos_mro):
@@ 23,22 34,25 @@ class GF(object):
self._cache = {}
def register(self, sig, func):
- self._reg[tuple(sig)] = func
+ self._reg[sig] = func
self._cache = {}
def __call__(self, *args, **kwargs):
+ self._pos = 0
sig = tuple(x.__class__ for x in args)
func = self._cache.get(sig)
if func is None:
- self._funcs = self.mro(self._reg, sig)
+ self._funcs = list(self.mro(self._reg, sig))
try:
- self._cache[sig] = func = self._funcs.next()
- except StopIteration:
- raise TypeError('no defined call signature args (%s)' % args)
- return func(*args)
+ self._cache[sig] = func = self._funcs[self._pos]
+ except IndexError:
+ raise TypeError('no defined call signature for args (%s)'
+ % ','.join(str(arg) for arg in args))
+ return func(*args, **kwargs)
- def next_method(self, *args):
- func = self._funcs.next()
- return func(*args)
+ def next_method(self, *args, **kwargs):
+ self._pos += 1
+ func = self._funcs[self._pos]
+ return func(*args, **kwargs)
M test/test_gf.py +5 -34
@@ 11,10 11,10 @@ class TestGF(unittest.TestCase):
@method(str, str)
def foo(a, b):
return a+b
-
self.assertEquals(foo(1, 2), 3)
self.assertEquals(foo('a', 'b'), 'ab')
self.assertRaises(TypeError, foo, 1, 'b')
+ del foo
def test_diamond(self):
class Top(object): pass
@@ 42,38 42,9 @@ class TestGF(unittest.TestCase):
out = []
bip(b, out)
self.assertEquals(out, ['top', 'right', 'left', 'bottom'])
-
- def test_scissors(self):
- class Thing(object): pass
- class Scissors(Thing): pass
- class Paper(Thing): pass
- class Rock(Thing): pass
- @method(Scissors, Paper)
- def beats(x, y):
- return True
- @method(Paper, Scissors)
- def beats(x, y):
- return not beats(y, x)
- @method(Scissors, Rock)
- def beats(x, y):
- return False
- @method(Rock, Scissors)
- def beats(x, y):
- return not beats(y, x)
- @method(Rock, Paper)
- def beats(x, y):
- return False
- @method(Paper, Rock)
- def beats(x, y):
- return not beats(y, x)
-
- s, p, r = Scissors(), Paper(), Rock()
- self.assertTrue(beats(s, p))
- self.assertFalse(beats(p, s))
- self.assertTrue(beats(r, s))
- self.assertFalse(beats(s, r))
- self.assertTrue(beats(p, r))
- self.assertFalse(beats(r, p))
+ out = []
+ bip(b, out)
+ self.assertEquals(out, ['top', 'right', 'left', 'bottom'])
def test_deep_diamond(self):
class A(object): pass
@@ 108,7 79,7 @@ class TestGF(unittest.TestCase):
def foo(b,x):
return b,x
b, y = B(), Y()
- foo(b, y)
+ self.assertRaises(TypeError, foo, b, y)
if __name__ == '__main__':
unittest.main()