575d6a8dd84f — Leonard Ritter 5 years ago
* memoize all constants
3 files changed, 58 insertions(+), 16 deletions(-)

M src/globals.cpp
M src/value.cpp
M src/value.hpp
M src/globals.cpp +4 -1
@@ 476,6 476,7 @@ namespace scopes {
 
 struct MemoKeyEqual {
     bool operator()( Value *lhs, Value *rhs ) const {
+        if (lhs == rhs) return true;
         if (lhs->kind() != rhs->kind())
             return false;
         if (isa<ArgumentList>(lhs)) {

          
@@ 486,6 487,7 @@ struct MemoKeyEqual {
             for (int i = 0; i < a->values.size(); ++i) {
                 auto u = a->values[i];
                 auto v = b->values[i];
+                if (u == v) continue;
                 if (u->kind() != v->kind())
                     return false;
                 if (u.isa<Pure>()) {

          
@@ 503,6 505,7 @@ struct MemoKeyEqual {
             for (int i = 0; i < a->values.size(); ++i) {
                 auto u = a->values[i];
                 auto v = b->values[i];
+                if (u == v) continue;
                 if (u->kind() != v->kind())
                     return false;
                 if (u.isa<Pure>()) {

          
@@ 517,7 520,7 @@ struct MemoKeyEqual {
         } else if (isa<Pure>(lhs)) {
             return cast<Pure>(lhs)->key_equal(cast<Pure>(rhs));
         } else {
-            return lhs == rhs;
+            return false;
         }
     }
 };

          
M src/value.cpp +53 -15
@@ 1052,6 1052,39 @@ Const::Const(ValueKind _kind, const Type
 
 //------------------------------------------------------------------------------
 
+template<typename T>
+struct ConstSet {
+    struct Hash {
+        std::size_t operator()(const T *k) const {
+            return k->hash();
+        }
+    };
+
+    struct Equal {
+        std::size_t operator()(const T *self, const T *other) const {
+            return self->key_equal(other);
+        }
+    };
+
+    std::unordered_set<T *, Hash, Equal> map;
+
+    template<typename ... Args>
+    TValueRef<T> from(Args ... args) {
+        T key(args ...);
+        auto it = map.find(&key);
+        if (it != map.end()) {
+            return ref(unknown_anchor(), *it);
+        }
+        auto val = new T(args ...);
+        map.insert(val);
+        return ref(unknown_anchor(), val);
+    }
+};
+
+//------------------------------------------------------------------------------
+
+static ConstSet<ConstInt> constints;
+
 ConstInt::ConstInt(const Type *type, uint64_t _value)
     : Const(VK_ConstInt, type), value(_value) {
 }

          
@@ 1062,11 1095,11 @@ bool ConstInt::key_equal(const ConstInt 
 }
 
 std::size_t ConstInt::hash() const {
-    return std::hash<uint64_t>{}(value);
+    return hash2(std::hash<const Type *>{}(get_type()), std::hash<uint64_t>{}(value));
 }
 
 ConstIntRef ConstInt::from(const Type *type, uint64_t value) {
-    return ref(unknown_anchor(), new ConstInt(type, value));
+    return constints.from(type, value);
 }
 
 ConstIntRef ConstInt::symbol_from(Symbol value) {

          
@@ 1079,6 1112,8 @@ ConstIntRef ConstInt::builtin_from(Built
 
 //------------------------------------------------------------------------------
 
+static ConstSet<ConstReal> constreals;
+
 ConstReal::ConstReal(const Type *type, double _value)
     : Const(VK_ConstReal, type), value(_value) {}
 

          
@@ 1088,17 1123,24 @@ bool ConstReal::key_equal(const ConstRea
 }
 
 std::size_t ConstReal::hash() const {
-    return std::hash<double>{}(value);
+    return hash2(std::hash<const Type *>{}(get_type()), std::hash<double>{}(value));
 }
 
 ConstRealRef ConstReal::from(const Type *type, double value) {
-    return ref(unknown_anchor(), new ConstReal(type, value));
+    return constreals.from(type, value);
 }
 
 //------------------------------------------------------------------------------
 
+static ConstSet<ConstAggregate> constaggs;
+
 ConstAggregate::ConstAggregate(const Type *type, const ConstantPtrs &_fields)
     : Const(VK_ConstAggregate, type), values(_fields) {
+    uint64_t h = std::hash<const Type *>{}(get_type());
+    for (int i = 0; i < values.size(); ++i) {
+        h = hash2(h, values[i]->hash());
+    }
+    _hash = h;
 }
 
 bool ConstAggregate::key_equal(const ConstAggregate *other) const {

          
@@ 1107,22 1149,18 @@ bool ConstAggregate::key_equal(const Con
     for (int i = 0; i < values.size(); ++i) {
         auto a = values[i];
         auto b = other->values[i];
-        if (!a->key_equal(b))
+        if (a != b)
             return false;
     }
     return true;
 }
 
 std::size_t ConstAggregate::hash() const {
-    uint64_t h = std::hash<const Type *>{}(get_type());
-    for (int i = 0; i < values.size(); ++i) {
-        h = hash2(h, values[i]->hash());
-    }
-    return h;
+    return _hash;
 }
 
 ConstAggregateRef ConstAggregate::from(const Type *type, const ConstantPtrs &fields) {
-    return ref(unknown_anchor(), new ConstAggregate(type, fields));
+    return constaggs.from(type, fields);
 }
 
 ConstAggregateRef ConstAggregate::none_from() {

          
@@ 1140,23 1178,23 @@ ConstRef get_field(const ConstAggregateR
 
 //------------------------------------------------------------------------------
 
+static ConstSet<ConstPointer> constptrs;
+
 ConstPointer::ConstPointer(const Type *type, const void *_pointer)
     : Const(VK_ConstPointer, type), value(_pointer) {}
 
 bool ConstPointer::key_equal(const ConstPointer *other) const {
     if (get_type() != other->get_type())
         return false;
-    if (get_type() == TYPE_List)
-        return sc_list_compare((const List *)value, (const List *)other->value);
     return value == other->value;
 }
 
 std::size_t ConstPointer::hash() const {
-    return std::hash<const void *>{}(value);
+    return hash2(std::hash<const Type *>{}(get_type()), std::hash<const void *>{}(value));
 }
 
 ConstPointerRef ConstPointer::from(const Type *type, const void *pointer) {
-    return ref(unknown_anchor(), new ConstPointer(type, pointer));
+    return constptrs.from(type, pointer);
 }
 
 ConstPointerRef ConstPointer::type_from(const Type *type) {

          
M src/value.hpp +1 -0
@@ 711,6 711,7 @@ struct ConstAggregate : Const {
     static ConstAggregateRef ast_from(const ValueRef &node);
 
     ConstantPtrs values;
+    std::size_t _hash;
 };
 
 ConstRef get_field(const ConstAggregateRef &value, int i);