7205074644b3 — Leonard Ritter 22 days ago
* glm: matrix constructor now returns constants when all arguments are constant
* `extractvalue`, `insertvalue`, `extractelement` and `insertelement` return constants when all arguments are constant
4 files changed, 74 insertions(+), 36 deletions(-)

M lib/scopes/glm.sc
M src/value.cpp
M src/value.hpp
M testing/test_glm.sc
M lib/scopes/glm.sc +10 -20
@@ 208,15 208,13 @@ typedef+ vec-type
         let ET argsz =
             'element@ self 0; 'argcount ...
         # count sum of elements
-        let flatargsz const? =
-            fold (total const? = 0 false) for arg in ('args ...)
+        let flatargsz =
+            fold (total = 0) for arg in ('args ...)
                 let argT = ('typeof arg)
-                _
-                    + total
-                        if (argT < vec-type)
-                            'element-count argT
-                        else 1
-                    const? & ('constant? arg)
+                + total
+                    if (argT < vec-type)
+                        'element-count argT
+                    else 1
         let vecsz = ('element-count self)
         let initval = (sc_const_null_new self)
         if (flatargsz == 0)

          
@@ 227,10 225,7 @@ typedef+ vec-type
             let argT = ('typeof arg)
             let arg =
                 if (argT < vec-type)
-                    if ('constant? arg)
-                        sc_const_extract_at arg 0
-                    else
-                        `(extractelement arg 0)
+                    `(extractelement arg 0)
                 else `(arg as ET)
             let smear = vector.smear
             `(bitcast (smear arg vecsz) self)

          
@@ 241,14 236,9 @@ typedef+ vec-type
                 if (argT < vec-type)
                     let argET argvecsz =
                         'element@ argT 0; 'element-count argT
-                    if ('constant? arg)
-                        fold (total = total) for k in (range argvecsz)
-                            values @ total = (sc_const_extract_at arg k)
-                            total + 1
-                    else
-                        fold (total = total) for k in (range argvecsz)
-                            values @ total = `(extractelement arg k)
-                            total + 1
+                    fold (total = total) for k in (range argvecsz)
+                        values @ total = `(extractelement arg k)
+                        total + 1
                 else
                     values @ total = arg
                     total + 1

          
M src/value.cpp +48 -11
@@ 1395,8 1395,13 @@ ExtractValue::ExtractValue(const TypedVa
         value_type_at_index(_value->get_type(), _index)),
     value(_value), index(_index) {}
 
-ExtractValueRef ExtractValue::from(const TypedValueRef &value, uint32_t index) {
-    return ref(unknown_anchor(), new ExtractValue(value, index));
+TypedValueRef ExtractValue::from(const TypedValueRef &value, uint32_t index) {
+    if (value.isa<ConstAggregate>()
+        && is_plain(value->get_type())) {
+        return get_field(value.cast<ConstAggregate>(), index);
+    } else {
+        return ref(unknown_anchor(), new ExtractValue(value, index));
+    }
 }
 
 //------------------------------------------------------------------------------

          
@@ 1405,22 1410,43 @@ InsertValue::InsertValue(const TypedValu
     : Instruction(VK_InsertValue, _value->get_type()),
         value(_value), element(_element), index(_index) {}
 
-InsertValueRef InsertValue::from(const TypedValueRef &value, const TypedValueRef &element, uint32_t index) {
-    return ref(unknown_anchor(), new InsertValue(value, element, index));
+TypedValueRef InsertValue::from(const TypedValueRef &value, const TypedValueRef &element, uint32_t index) {
+    if (value.isa<ConstAggregate>()
+        && element.isa<Const>()
+        && is_plain(value->get_type())
+        && is_plain(element->get_type())) {
+        return set_field(value.cast<ConstAggregate>(), element.cast<Const>(), index);
+    } else {
+        return ref(unknown_anchor(), new InsertValue(value, element, index));
+    }
 }
 
 //------------------------------------------------------------------------------
 
 ExtractElement::ExtractElement(const TypedValueRef &_value, const TypedValueRef &_index)
     : Instruction(VK_ExtractElement, value_type_at_index(_value->get_type(), 0)), value(_value), index(_index) {}
-ExtractElementRef ExtractElement::from(const TypedValueRef &value, const TypedValueRef &index) {
-    return ref(unknown_anchor(), new ExtractElement(value, index));
+TypedValueRef ExtractElement::from(const TypedValueRef &value, const TypedValueRef &index) {
+    if (value.isa<ConstAggregate>()
+        && index.isa<ConstInt>()
+        && is_plain(value->get_type())) {
+        return get_field(value.cast<ConstAggregate>(), index.cast<ConstInt>()->msw());
+    } else {
+        return ref(unknown_anchor(), new ExtractElement(value, index));
+    }
 }
 
 InsertElement::InsertElement(const TypedValueRef &_value, const TypedValueRef &_element, const TypedValueRef &_index)
     : Instruction(VK_InsertElement, _value->get_type()), value(_value), element(_element), index(_index) {}
-InsertElementRef InsertElement::from(const TypedValueRef &value, const TypedValueRef &element, const TypedValueRef &index) {
-    return ref(unknown_anchor(), new InsertElement(value, element, index));
+TypedValueRef InsertElement::from(const TypedValueRef &value, const TypedValueRef &element, const TypedValueRef &index) {
+    if (value.isa<ConstAggregate>()
+        && element.isa<Const>()
+        && index.isa<ConstInt>()
+        && is_plain(value->get_type())
+        && is_plain(element->get_type())) {
+        return set_field(value.cast<ConstAggregate>(), element.cast<Const>(), index.cast<ConstInt>()->msw());
+    } else {
+        return ref(unknown_anchor(), new InsertElement(value, element, index));
+    }
 }
 
 ShuffleVector::ShuffleVector(const TypedValueRef &_v1, const TypedValueRef &_v2, const std::vector<uint32_t> &_mask)

          
@@ 1797,8 1823,16 @@ ConstAggregateRef ConstAggregate::ast_fr
     return from(TYPE_ValueRef, { ptr, ConstPointer::anchor_from(node.anchor()).unref() });
 }
 
-ConstRef get_field(const ConstAggregateRef &value, int i) {
-    return ref(value.anchor(), value->values[i]);
+ConstRef get_field(const ConstAggregateRef &value, uint32_t i) {
+    auto VT = value_type_at_index(value->get_type(), i);
+    return PureCast::from(VT, ref(value.anchor(), value->values[i])).cast<Const>();
+}
+
+ConstRef set_field(const ConstAggregateRef &value, const ConstRef &element, uint32_t i) {
+    ConstantPtrs values = value->values;
+    assert (i < values.size());
+    values[i] = element.unref();
+    return ConstAggregate::from(value->get_type(), values);
 }
 
 //------------------------------------------------------------------------------

          
@@ 2049,7 2083,10 @@ TypedValue::TypedValue(ValueKind _kind, 
 
 void TypedValue::hack_change_value(const Type *T) {
     assert(T);
-    _type = T;
+    if (_type != T) {
+        assert(!Const::classof(this)); // must not retype constants
+        _type = T;
+    }
 }
 
 const Type *TypedValue::get_type() const {

          
M src/value.hpp +6 -5
@@ 904,7 904,7 @@ struct ExtractValue : Instruction {
     static bool classof(const Value *T);
 
     ExtractValue(const TypedValueRef &value, uint32_t index);
-    static ExtractValueRef from(const TypedValueRef &value, uint32_t index);
+    static TypedValueRef from(const TypedValueRef &value, uint32_t index);
     TypedValueRef value;
     uint32_t index;
 };

          
@@ 913,7 913,7 @@ struct InsertValue : Instruction {
     static bool classof(const Value *T);
 
     InsertValue(const TypedValueRef &value, const TypedValueRef &element, uint32_t index);
-    static InsertValueRef from(const TypedValueRef &value, const TypedValueRef &element, uint32_t index);
+    static TypedValueRef from(const TypedValueRef &value, const TypedValueRef &element, uint32_t index);
     TypedValueRef value;
     TypedValueRef element;
     uint32_t index;

          
@@ 925,7 925,7 @@ struct ExtractElement : Instruction {
     static bool classof(const Value *T);
 
     ExtractElement(const TypedValueRef &value, const TypedValueRef &index);
-    static ExtractElementRef from(const TypedValueRef &value, const TypedValueRef &index);
+    static TypedValueRef from(const TypedValueRef &value, const TypedValueRef &index);
     TypedValueRef value;
     TypedValueRef index;
 };

          
@@ 934,7 934,7 @@ struct InsertElement : Instruction {
     static bool classof(const Value *T);
 
     InsertElement(const TypedValueRef &value, const TypedValueRef &element, const TypedValueRef &index);
-    static InsertElementRef from(const TypedValueRef &value, const TypedValueRef &element, const TypedValueRef &index);
+    static TypedValueRef from(const TypedValueRef &value, const TypedValueRef &element, const TypedValueRef &index);
     TypedValueRef value;
     TypedValueRef element;
     TypedValueRef index;

          
@@ 1255,7 1255,8 @@ struct ConstAggregate : Const {
     std::size_t _hash;
 };
 
-ConstRef get_field(const ConstAggregateRef &value, int i);
+ConstRef get_field(const ConstAggregateRef &value, uint32_t i);
+ConstRef set_field(const ConstAggregateRef &value, const ConstRef &element, uint32_t i);
 
 //------------------------------------------------------------------------------
 

          
M testing/test_glm.sc +10 -0
@@ 75,6 75,16 @@ test ((mat4 1) != (mat4 0))
 
 test ((mat4 1) == (mat4))
 
+test (constant? (mat4))
+test (constant? (mat4 1))
+test
+    constant?
+        mat4
+            \ 10 0 0 0
+            \ 0 10 0 0
+            \ 0 0 1 0
+            vec2 0; vec2 1
+
 test
     ==
         mat4