# HG changeset patch # User aurelien@trantor.local # Date 1258918645 -3600 # Sun Nov 22 20:37:25 2009 +0100 # Node ID 95e622fb42ab6e6dd9e288219f519a1267ae52d6 # Parent 1d5f549da77a3866d48736a0d135471bfd0cef44 trim down a bit, add an internal signature/method cache, add a test diff --git a/dispatch.py b/dispatch.py --- a/dispatch.py +++ b/dispatch.py @@ -12,40 +12,32 @@ #-- GF class class GF(object): - def __init__(self, table=(), order=lexicographic_mro): - self.table = [] - for rule in table: - self.add_rule(*rule) - self.order = order + def __init__(self): + self._table = [] + self._cache = {} def __call__(self, *args): - signature = [o.__class__ for o in args] - linearization = self.linearize_table(signature) - return linearization[0][0](*args) + signature = tuple(o.__class__ for o in args) + if signature not in self._cache: + func = self.linearize_table(signature)[0][0] + self._cache[signature] = func + else: + func = self._cache[signature] + return func(*args) def add_rule(self, signature, func): - self.table.append((signature, func, 0)) + self._table.append((signature, func, 0)) - def next_method(self): - func, _ = self.linearization[self.pos] - self.pos +=1 - result = func(*self.args) - self.pos -= 1 - return result - - def linearize_table(self, sig): + def linearize_table(self, signature): from operator import mul - len_match = lambda (s,f,nm): len(s) == len(sig) - typ_match = lambda (s,f,nm): reduce(mul, map(issubclass, sig, s)) - all_match = lambda x: len_match(x) and typ_match(x) - table = filter(all_match, self.table) + table = [(s,f,nm) for s,f,nm in self._table + if len(s) == len(signature) + and reduce(mul, map(issubclass, signature, s))] if not table: def nomatch(*a): - raise TypeError, \ - "%s instance: no defined call signature <%s> for args (%s)" % \ - (self.__class__.__name__, - ",".join([str(o) for o in sig]), - a) + raise TypeError("%s: no defined call signature <%s> for args (%s)" % + (self.__class__.__name__, + ",".join([str(o) for o in signature]), a)) return [(nomatch,0)] - return map(lambda l:l[1:], self.order(sig, table)) + return map(lambda l:l[1:], lexicographic_mro(signature, table)) diff --git a/test/test_gf.py b/test/test_gf.py --- a/test/test_gf.py +++ b/test/test_gf.py @@ -16,6 +16,40 @@ self.assertEquals(foo('a', 'b'), 'ab') self.assertRaises(TypeError, foo, 1, 'b') + def test_scissors(self): + class Thing(object): pass + class Scissors(Thing): pass + class Paper(Thing): pass + class Rock(Thing): pass + @method(Scissors, Paper) + def beats(x, y): + return True + @method(Paper, Scissors) + def beats(x, y): + return not beats(y, x) + @method(Scissors, Rock) + def beats(x, y): + return False + @method(Rock, Scissors) + def beats(x, y): + return not beats(y, x) + @method(Rock, Paper) + def beats(x, y): + return False + @method(Paper, Rock) + def beats(x, y): + return not beats(y, x) + + s, p, r = Scissors(), Paper(), Rock() + print unittest.__file__ + self.assertTrue(beats(s, p)) + self.assertFalse(beats(p, s)) + self.assertTrue(beats(r, s)) + self.assertFalse(beats(s, r)) + self.assertTrue(beats(p, r)) + self.assertFalse(beats(r, p)) + + if __name__ == '__main__': unittest.main()