95e622fb42ab — aurelien@trantor.local 13 years ago
trim down a bit, add an internal signature/method cache, add a test
2 files changed, 53 insertions(+), 27 deletions(-)

M dispatch.py
M test/test_gf.py
M dispatch.py +19 -27
@@ 12,40 12,32 @@ def lexicographic_mro(signature, matches
 
 #-- GF class
 class GF(object):
-    def __init__(self, table=(), order=lexicographic_mro):
-        self.table = []
-        for rule in table:
-            self.add_rule(*rule)
-        self.order = order
+    def __init__(self):
+        self._table = []
+        self._cache = {}
 
     def __call__(self, *args):
-        signature = [o.__class__ for o in args]
-        linearization = self.linearize_table(signature)
-        return linearization[0][0](*args)
+        signature = tuple(o.__class__ for o in args)
+        if signature not in self._cache:
+            func = self.linearize_table(signature)[0][0]
+            self._cache[signature] = func
+        else:
+            func = self._cache[signature]
+        return func(*args)
 
     def add_rule(self, signature, func):
-        self.table.append((signature, func, 0))
+        self._table.append((signature, func, 0))
 
-    def next_method(self):
-        func, _ = self.linearization[self.pos]
-        self.pos +=1
-        result = func(*self.args)
-        self.pos -= 1
-        return result
-
-    def linearize_table(self, sig):
+    def linearize_table(self, signature):
         from operator import mul
-        len_match = lambda (s,f,nm): len(s) == len(sig)
-        typ_match = lambda (s,f,nm): reduce(mul, map(issubclass, sig, s))
-        all_match = lambda x: len_match(x) and typ_match(x)
-        table = filter(all_match, self.table)
+        table = [(s,f,nm) for s,f,nm in self._table
+                 if len(s) == len(signature)
+                 and reduce(mul, map(issubclass, signature, s))]
         if not table:
             def nomatch(*a):
-                raise TypeError, \
-                  "%s instance: no defined call signature <%s> for args (%s)" % \
-                  (self.__class__.__name__,
-                   ",".join([str(o) for o in sig]),
-                   a)
+                raise TypeError("%s: no defined call signature <%s> for args (%s)" %
+                                (self.__class__.__name__,
+                                 ",".join([str(o) for o in signature]), a))
             return [(nomatch,0)]
-        return map(lambda l:l[1:], self.order(sig, table))
+        return map(lambda l:l[1:], lexicographic_mro(signature, table))
 

          
M test/test_gf.py +34 -0
@@ 16,6 16,40 @@ class TestGF(unittest.TestCase):
         self.assertEquals(foo('a', 'b'), 'ab')
         self.assertRaises(TypeError, foo, 1, 'b')
 
+    def test_scissors(self):
+        class Thing(object): pass
+        class Scissors(Thing): pass
+        class Paper(Thing): pass
+        class Rock(Thing): pass
+        @method(Scissors, Paper)
+        def beats(x, y):
+            return True
+        @method(Paper, Scissors)
+        def beats(x, y):
+            return not beats(y, x)
+        @method(Scissors, Rock)
+        def beats(x, y):
+            return False
+        @method(Rock, Scissors)
+        def beats(x, y):
+            return not beats(y, x)
+        @method(Rock, Paper)
+        def beats(x, y):
+            return False
+        @method(Paper, Rock)
+        def beats(x, y):
+            return not beats(y, x)
+
+        s, p, r = Scissors(), Paper(), Rock()
+        print unittest.__file__
+        self.assertTrue(beats(s, p))
+        self.assertFalse(beats(p, s))
+        self.assertTrue(beats(r, s))
+        self.assertFalse(beats(s, r))
+        self.assertTrue(beats(p, r))
+        self.assertFalse(beats(r, p))
+
+
 if __name__ == '__main__':
     unittest.main()