@@ 12,40 12,32 @@ def lexicographic_mro(signature, matches
#-- GF class
class GF(object):
- def __init__(self, table=(), order=lexicographic_mro):
- self.table = []
- for rule in table:
- self.add_rule(*rule)
- self.order = order
+ def __init__(self):
+ self._table = []
+ self._cache = {}
def __call__(self, *args):
- signature = [o.__class__ for o in args]
- linearization = self.linearize_table(signature)
- return linearization[0][0](*args)
+ signature = tuple(o.__class__ for o in args)
+ if signature not in self._cache:
+ func = self.linearize_table(signature)[0][0]
+ self._cache[signature] = func
+ else:
+ func = self._cache[signature]
+ return func(*args)
def add_rule(self, signature, func):
- self.table.append((signature, func, 0))
+ self._table.append((signature, func, 0))
- def next_method(self):
- func, _ = self.linearization[self.pos]
- self.pos +=1
- result = func(*self.args)
- self.pos -= 1
- return result
-
- def linearize_table(self, sig):
+ def linearize_table(self, signature):
from operator import mul
- len_match = lambda (s,f,nm): len(s) == len(sig)
- typ_match = lambda (s,f,nm): reduce(mul, map(issubclass, sig, s))
- all_match = lambda x: len_match(x) and typ_match(x)
- table = filter(all_match, self.table)
+ table = [(s,f,nm) for s,f,nm in self._table
+ if len(s) == len(signature)
+ and reduce(mul, map(issubclass, signature, s))]
if not table:
def nomatch(*a):
- raise TypeError, \
- "%s instance: no defined call signature <%s> for args (%s)" % \
- (self.__class__.__name__,
- ",".join([str(o) for o in sig]),
- a)
+ raise TypeError("%s: no defined call signature <%s> for args (%s)" %
+ (self.__class__.__name__,
+ ",".join([str(o) for o in signature]), a))
return [(nomatch,0)]
- return map(lambda l:l[1:], self.order(sig, table))
+ return map(lambda l:l[1:], lexicographic_mro(signature, table))
@@ 16,6 16,40 @@ class TestGF(unittest.TestCase):
self.assertEquals(foo('a', 'b'), 'ab')
self.assertRaises(TypeError, foo, 1, 'b')
+ def test_scissors(self):
+ class Thing(object): pass
+ class Scissors(Thing): pass
+ class Paper(Thing): pass
+ class Rock(Thing): pass
+ @method(Scissors, Paper)
+ def beats(x, y):
+ return True
+ @method(Paper, Scissors)
+ def beats(x, y):
+ return not beats(y, x)
+ @method(Scissors, Rock)
+ def beats(x, y):
+ return False
+ @method(Rock, Scissors)
+ def beats(x, y):
+ return not beats(y, x)
+ @method(Rock, Paper)
+ def beats(x, y):
+ return False
+ @method(Paper, Rock)
+ def beats(x, y):
+ return not beats(y, x)
+
+ s, p, r = Scissors(), Paper(), Rock()
+ print unittest.__file__
+ self.assertTrue(beats(s, p))
+ self.assertFalse(beats(p, s))
+ self.assertTrue(beats(r, s))
+ self.assertFalse(beats(s, r))
+ self.assertTrue(beats(p, r))
+ self.assertFalse(beats(r, p))
+
+
if __name__ == '__main__':
unittest.main()