45d72af2b420 — aurelien@trantor.local 13 years ago
re-add next-method support
2 files changed, 51 insertions(+), 19 deletions(-)

M dispatch.py
M test/test_gf.py
M dispatch.py +24 -18
@@ 4,41 4,47 @@ def proximity(klass, mro):
     return mro.index(klass)
 
 def lexicographic_mro(signature, matches):
-    "Use dispatch ranking similar to CLOS"
-    # Schwartzian transform to weight match sigs, left-to-right"
+    "dispatch ranking similar to CLOS"
     mros = [klass.mro() for klass in signature]
     for (sig,func),i in zip(matches,xrange(1000)):
         matches[i] = (map(proximity, sig, mros), matches[i])
-    matches.sort()
+    matches.sort(reverse=True)
     return map(lambda t:t[1], matches)
 
-#-- GF class
+def signature(*args):
+    return tuple(o.__class__ for o in args)
+
 class GF(object):
     def __init__(self):
         self._table = []
+
+    def add_rule(self, sig, func):
+        self._table.append((sig, func))
         self._cache = {}
 
     def __call__(self, *args):
-        signature = tuple(o.__class__ for o in args)
-        if signature not in self._cache:
-            func = self.linearize_table(signature)[0][0]
-            self._cache[signature] = func
+        sig = signature(*args)
+        self._pos = 0
+        self._funcs = self.linearized_table(sig)
+        if sig not in self._cache:
+            func = self._funcs.pop()[0]
+            self._cache[sig] = func
         else:
-            func = self._cache[signature]
+            func = self._cache[sig]
         return func(*args)
 
-    def add_rule(self, signature, func):
-        self._table.append((signature, func))
+    def next_method(self, *args):
+        func = self._funcs.pop()[0]
+        return func(*args)
 
-    def linearize_table(self, signature):
+    def linearized_table(self, sig):
         table = [(s,f) for s,f in self._table
-                 if len(s) == len(signature)
-                 and reduce(mul, map(issubclass, signature, s))]
+                 if len(s) == len(sig)
+                 and reduce(mul, map(issubclass, sig, s))]
         if not table:
             def nomatch(*a):
-                raise TypeError("%s: no defined call signature <%s> for args (%s)" %
-                                (self.__class__.__name__,
-                                 ",".join([str(o) for o in signature]), a))
+                raise TypeError('no defined call signature <%s> for args (%s)' %
+                                (','.join([str(o) for o in sig]), a))
             return [(nomatch,)]
-        return map(lambda l:l[1:], lexicographic_mro(signature, table))
+        return map(lambda l:l[1:], lexicographic_mro(sig, table))
 

          
M test/test_gf.py +27 -1
@@ 16,6 16,33 @@ class TestGF(unittest.TestCase):
         self.assertEquals(foo('a', 'b'), 'ab')
         self.assertRaises(TypeError, foo, 1, 'b')
 
+    def test_diamond(self):
+        class Top(object): pass
+        class Left(Top): pass
+        class Right(Top): pass
+        class Bottom(Left, Right): pass
+
+        @method(Top, list)
+        def bip(x, l):
+            l.append('top')
+        @method(Left, list)
+        def bip(x, l):
+            bip.next_method(x, l)
+            l.append('left')
+        @method(Right, list)
+        def bip(x, l):
+            bip.next_method(x, l)
+            l.append('right')
+        @method(Bottom, list)
+        def bip(x, l):
+            bip.next_method(x, l)
+            l.append('bottom')
+
+        b = Bottom()
+        out = []
+        bip(b, out)
+        self.assertEquals(out, ['top', 'right', 'left', 'bottom'])
+
     def test_scissors(self):
         class Thing(object): pass
         class Scissors(Thing): pass

          
@@ 41,7 68,6 @@ class TestGF(unittest.TestCase):
             return not beats(y, x)
 
         s, p, r = Scissors(), Paper(), Rock()
-        print unittest.__file__
         self.assertTrue(beats(s, p))
         self.assertFalse(beats(p, s))
         self.assertTrue(beats(r, s))