7f6d34babb73 — aurelien@trantor.local 13 years ago
use ambiguity check from gvr's overloading.py, remove useless test
3 files changed, 38 insertions(+), 52 deletions(-)

M __init__.py
M dispatch.py
M test/test_gf.py
M __init__.py +2 -1
@@ 3,7 3,8 @@ 
 from gf import dispatch as mm
 
 def method(*types):
-    "a method decorator on top of David Mertz's GF package"
+    """ decorator to globally register a method in a gf
+    named like the method itself"""
     def _wrap(meth):
         name = meth.__name__
         gf = _METHS.get(name)

          
M dispatch.py +31 -17
@@ 3,18 3,29 @@ from operator import mul
 def proximity(klass, mro):
     return mro.index(klass)
 
-def lexicographic_mro(signature, matches):
-    "dispatch ranking similar to CLOS"
-    mros = tuple(klass.__mro__ for klass in signature)
-    return sorted((map(proximity, sig, mros), func)
-                  for sig, func in matches)
+def check_ambiguity(candidates, mros):
+    def dominates(dom, sub,
+                  orders=tuple(dict((t, i) for i, t in enumerate(mro))
+                               for mro in mros)):
+        if dom is sub:
+            return False
+        return all(order[d] <= order[s]
+                   for d, s, order in zip(dom, sub, orders))
+    return [cand for cand, _func in candidates
+            if not any(dominates(dom, cand)
+                       for dom, _xxx in candidates)]
 
 def clos_mro(types_func_map, callsig):
     siglen = len(callsig)
-    table = ((sig, func) for sig, func in types_func_map.iteritems()
+    table = [(sig, func) for sig, func in types_func_map.iteritems()
              if len(sig) == siglen
-             and reduce(mul, map(issubclass, callsig, sig)))
-    return (x[1] for x in lexicographic_mro(callsig, table))
+             and reduce(mul, map(issubclass, callsig, sig))]
+    mros = tuple(klass.__mro__ for klass in callsig)
+    if len(check_ambiguity(table, mros)) > 1:
+        raise TypeError("ambigous call; types=%r; candidates=%r" %
+                        (types_func_map, table))
+    return (x[1] for x in sorted((map(proximity, sig, mros), func)
+                  for sig, func in table))
 
 class GF(object):
     def __init__(self, mro=clos_mro):

          
@@ 23,22 34,25 @@ class GF(object):
         self._cache = {}
 
     def register(self, sig, func):
-        self._reg[tuple(sig)] = func
+        self._reg[sig] = func
         self._cache = {}
 
     def __call__(self, *args, **kwargs):
+        self._pos = 0
         sig = tuple(x.__class__ for x in args)
         func = self._cache.get(sig)
         if func is None:
-            self._funcs = self.mro(self._reg, sig)
+            self._funcs = list(self.mro(self._reg, sig))
             try:
-                self._cache[sig] = func = self._funcs.next()
-            except StopIteration:
-                raise TypeError('no defined call signature args (%s)' % args)
-        return func(*args)
+                self._cache[sig] = func = self._funcs[self._pos]
+            except IndexError:
+                raise TypeError('no defined call signature for args (%s)'
+                                % ','.join(str(arg) for arg in args))
+        return func(*args, **kwargs)
 
-    def next_method(self, *args):
-        func = self._funcs.next()
-        return func(*args)
+    def next_method(self, *args, **kwargs):
+        self._pos += 1
+        func = self._funcs[self._pos]
+        return func(*args, **kwargs)
 
 

          
M test/test_gf.py +5 -34
@@ 11,10 11,10 @@ class TestGF(unittest.TestCase):
         @method(str, str)
         def foo(a, b):
             return a+b
-
         self.assertEquals(foo(1, 2), 3)
         self.assertEquals(foo('a', 'b'), 'ab')
         self.assertRaises(TypeError, foo, 1, 'b')
+        del foo
 
     def test_diamond(self):
         class Top(object): pass

          
@@ 42,38 42,9 @@ class TestGF(unittest.TestCase):
         out = []
         bip(b, out)
         self.assertEquals(out, ['top', 'right', 'left', 'bottom'])
-
-    def test_scissors(self):
-        class Thing(object): pass
-        class Scissors(Thing): pass
-        class Paper(Thing): pass
-        class Rock(Thing): pass
-        @method(Scissors, Paper)
-        def beats(x, y):
-            return True
-        @method(Paper, Scissors)
-        def beats(x, y):
-            return not beats(y, x)
-        @method(Scissors, Rock)
-        def beats(x, y):
-            return False
-        @method(Rock, Scissors)
-        def beats(x, y):
-            return not beats(y, x)
-        @method(Rock, Paper)
-        def beats(x, y):
-            return False
-        @method(Paper, Rock)
-        def beats(x, y):
-            return not beats(y, x)
-
-        s, p, r = Scissors(), Paper(), Rock()
-        self.assertTrue(beats(s, p))
-        self.assertFalse(beats(p, s))
-        self.assertTrue(beats(r, s))
-        self.assertFalse(beats(s, r))
-        self.assertTrue(beats(p, r))
-        self.assertFalse(beats(r, p))
+        out = []
+        bip(b, out)
+        self.assertEquals(out, ['top', 'right', 'left', 'bottom'])
 
     def test_deep_diamond(self):
         class A(object): pass

          
@@ 108,7 79,7 @@ class TestGF(unittest.TestCase):
         def foo(b,x):
             return b,x
         b, y = B(), Y()
-        foo(b, y)
+        self.assertRaises(TypeError, foo, b, y)
 
 if __name__ == '__main__':
     unittest.main()