2f03333cb137 — Leonard Ritter 27 days ago
* UVM: added select optimizations
2 files changed, 243 insertions(+), 54 deletions(-)

M testing/BDD.sc
M testing/test_node2.sc
M testing/BDD.sc +20 -5
@@ 282,12 282,27 @@ for x3 x2 x1 in (dim 2 2 2)
     test ((T a) == b)
     print x1 x2 x3 a b
 
-b1 := (& (& (T 'a) (T 'b)) (T 'c))
-b2 := (& (& (& (T 'a) (T 'b)) (T 'c)) (T 'd))
+do
+    b1 := (& (& (& (T 'a) (T 'b)) (T 'c)) (T 'd))
+    b2 := (& (& (& (T 'a) (T 'b)) (T 'c)) (~ (T 'd)))
+
+    print
+        b1 | b2
+    print
+        b1 & (~ b2)
+    print
+        b2 & (~ b1)
 
-print
-    b1 & (~ b2)
-print b2
+do
+    b1 := (& (& (& (T 'a) (T 'b)) (T 'c)) (T 'd))
+    b2 := (& (& (T 'a) (T 'b)) (T 'c))
+    m :=
+        b1 | b2
+    print "m" m
+    print "b1"
+        T (copy b2) 0 (copy b1)
+    print "b2"
+        b2 | (~ b1)
 
 #
     if a

          
M testing/test_node2.sc +223 -49
@@ 374,7 374,7 @@ do
         Input               OpPure      "I"         -           -               -
         Output              OpPure      "O"         -           -               -
         State               OpPure      "@@"        -           -               -
-        Load                OpControl   ""          evalload    -               -
+        Defined             OpPure      "@"         -           -               -
         Then                OpControl   "@@"        evalthen    -               -
         Merge               OpControl   "@@"        evalmerge   partmerge       commutative
         Select              OpControl   "@@@"       -           -               -

          
@@ 438,15 438,10 @@ struct Node
             hash h operand
 
     fn __copy (self)
-        local node =
-            this-type
-                self.condition
-                self.opCode
-                self.typeId
-        let ops = node.operands
-        for operand in self.operands
-            'append ops operand
-        node
+        super-type.__typecall this-type
+            self.opCode
+            self.typeId
+            copy self.operands
 
     fn getstring (self idx)
         local text : String

          
@@ 503,6 498,8 @@ struct Module
         Map RcNode Id
             fn (rcnode)
                 hash (rcnode as Node)
+    # cached select simplifications
+    select_cache : (Map (tuple Id Id Id) Id)
     rnodes_valid = false
 
     inline __typecall (cls)

          
@@ 615,9 612,23 @@ struct Module
         'nodeId self
             Node Op.Parameter typeId params index
 
-    fn... input (self, name : Input)
+    fn... rawinput (self, typeId : Id, name : Input)
+        'nodeId self
+            Node Op.Input typeId name
+
+    fn... defined? (self, source : Id)
         'nodeId self
-            Node Op.Input ('boolType self) name
+            Node Op.Defined ('boolType self) source
+
+    fn... input
+    case (self, typeId : Id, name : Input)
+        let inp = (rawinput self typeId name)
+        'then self
+            'defined? self inp
+            inp
+    case (self, name : Input)
+        let inp = (rawinput self NoId name)
+        'defined? self inp
 
     fn... output (self, typeId : Id, name : Output)
         'nodeId self

          
@@ 651,13 662,16 @@ struct Module
         fn... (self, a : Id, b : Id)
             'nodeId self
                 Node
-                    'condAnd self ('getCond self a) ('getCond self b)
                     \ op (getType self a) a b
 
     inline binary_bool_op (op)
         fn... (self, a : Id, b : Id)
-            'nodeId self
-                Node op ('boolType self) a b
+            let ca a = ('unthen self a)
+            let cb b = ('unthen self b)
+            'then self
+                'and self ca cb            
+                'nodeId self
+                    Node op ('boolType self) a b
     let
         sin = (unary_op Op.Sin)
         cos = (unary_op Op.Cos)

          
@@ 743,8 757,33 @@ struct Module
                     visit-id = visit
                 \ self node remap
         self.rnodes_valid = false
+        'clear self.select_cache
 
-    fn... reachable-indegrees (self)
+    fn... topolist (self, root : Id)
+        let nodes = self.nodes
+        let count = ((countof nodes) as u32)
+        local visited : (Set Id)
+
+        # find roots
+        local queue : (Array Id)
+        'insert visited root
+        'append queue root
+        # tag all reachable nodes
+        for nodeid in queue
+            node := nodes @ nodeid
+            fn visit (self id queue visited)
+                if (id != NoType)
+                    if (not ('in? visited id))
+                        'append queue id
+                    'insert visited id
+            visit self (copy node.typeId) queue visited
+            call
+                genvisitor
+                    visit-id = visit
+                \ self node queue visited
+        queue
+
+    fn... reachable-indegrees (self, root : Id = NoId)
         let nodes = self.nodes
         let count = ((countof nodes) as u32)
         local visited : (Array i32)

          
@@ 752,10 791,14 @@ struct Module
 
         # find roots
         local queue : (Array Id)
-        for id node in (enumerate nodes Id)
-            if (node.opCode == Op.Bind)
-                visited @ id = 1
-                'append queue id
+        if (root != NoId)
+            visited @ root = 1
+            'append queue root
+        else
+            for id node in (enumerate nodes Id)
+                if (node.opCode == Op.Bind)
+                    visited @ id = 1
+                    'append queue id
         # tag all reachable nodes
         for nodeid in queue
             node := nodes @ nodeid

          
@@ 968,45 1011,155 @@ struct Module
         va-switch-case OpDefs Op 'constant? ('getOp self id)
             inline (cls) false
 
-    fn iscondition? (self condid id)
-        """"id is implicitly true if we are already in a conditional branch depending on it
-        loop (condid = condid)
-            if (condid == NoId)
-                break false
-            if (id == condid)
-                break true
-            'getCond self condid
+    fn unthen (self id)
+        if (('getOp self id) == Op.Then)
+            _ ('getArg self id 0) ('getArg self id 1)
+        else
+            _ ('constBool self true) id
 
-    fn unthen (self id)
-        loop (id = id)
-            if (('getOp self id) == Op.Then)
-                repeat ('getArg self id 0)
-            else
-                break id
-
-    fn... getbool (self, id : Id, cond = NoId)
+    fn... getbool (self, id : Id)
         """"returns bool as signed integer; zero = undefined
         if ('isbool? self id)
             if ('isconstant? self id)
                 ? (('getArg self id 0) != 0) 1 -1
-            elseif ('iscondition? self cond id) 1
             else 0
         else 0
 
+    fn... subst (self, id : Id, substmap : (Map Id Id))
+        """"recursive substitution
+        returning Id
+        # see if there is a direct replacement
+        for id in ('reverse (topolist self id))
+            local found = false
+            # see if any change is necessary
+            node := self.nodes @ id
+            fn visit-argument (self id substmap found)            
+                if (id != NoId)
+                    if ('in? substmap id)
+                        found = true
+                        raise;
+                ;
+            let visitctx... = substmap found
+            try
+                visit-argument self node.typeId visitctx...
+                call
+                    genvisitor
+                        visit-id = visit-argument
+                    \ self node visitctx...
+            else;
+            if (not found)
+                continue;
+            # make the changes
+            local node = (copy ((self.nodes @ id) as Node))
+            fn visit-argument (self id substmap found)            
+                if (id != NoId)
+                    try
+                        id = ('get substmap id)
+                    else;
+                ;
+            let visitctx... = substmap found
+            visit-argument self node.typeId visitctx...
+            call
+                genvisitor
+                    visit-id = visit-argument
+                \ self node visitctx...
+            let newid =
+                vvv copy
+                'nodeId self (deref node)            
+            assert (id != newid)
+            # update select statements that have become constant
+            let newid =
+                if (('getOp self newid) == Op.Select)
+                    let cond = ('getArg self newid 0)
+                    let b = ('getbool self cond)
+                    if (b != 0)
+                        if (b > 0)
+                            'getArg self newid 2
+                        else
+                            'getArg self newid 1
+                    elseif
+                        and
+                            ('getbool self ('getArg self newid 1)) == -1
+                            ('getbool self ('getArg self newid 2)) == 1
+                        cond
+                    else newid
+                else newid
+            'set substmap id newid
+        try (copy ('get substmap id))
+        else (copy id)
+
 'define-symbols Module
     select =
         fn... "select" (self, cond : Id, fvalue : Id, tvalue : Id)
-            'nodeId self
-                Node Op.Select ('getType self tvalue) cond fvalue tvalue
+            returning Id
+            key := (tupleof cond fvalue tvalue)
+            try
+                return
+                    copy
+                        'get self.select_cache key
+            else;
+            let result =
+                loop (cond fvalue tvalue = cond fvalue tvalue)
+                    let c cond = ('unthen self cond)
+                    let cf fvalue = ('unthen self fvalue)
+                    let ct tvalue = ('unthen self tvalue)
+                    if (('getOp self cond) == Op.Select)
+                        # (select (select a b c) d e) -> (select a (select b d e) (select c d e))
+                        let u v w =
+                            'getArg self cond 0
+                            'getArg self cond 1
+                            'getArg self cond 2
+                        repeat
+                            'then self
+                                'and self c ('and self cf ct)
+                                u
+                            this-function self v fvalue tvalue
+                            this-function self w fvalue tvalue
+                    else 
+                        let b = ('getbool self cond)
+                        let stmt =
+                            if (b > 0) (copy tvalue)
+                            elseif (b < 0) (copy fvalue)
+                            else
+                                let constfalse = ('constBool self false)
+                                let consttrue = ('constBool self true)
+                                let fvalue =
+                                    do
+                                        local substmap : (Map Id Id)
+                                        'set substmap cond constfalse
+                                        'subst self fvalue substmap
+                                let tvalue = 
+                                    do
+                                        local substmap : (Map Id Id)
+                                        'set substmap cond consttrue
+                                        'subst self tvalue substmap 
+                                if (fvalue == tvalue)
+                                    copy tvalue
+                                elseif ((fvalue == constfalse) & (tvalue == consttrue))
+                                    copy cond
+                                else
+                                    vvv copy
+                                    'nodeId self
+                                        Node Op.Select ('getType self tvalue) cond fvalue tvalue
+                        break
+                            'then self
+                                'and self c ('and self cf ct)
+                                stmt
+            'set self.select_cache key result
+            result
 
     merge =
         fn... "merge" (self, value1 : Id, value2 : Id)
-            'nodeId self
-                Node Op.Merge ('getType self value1) value1 value2
-    load =
-        fn... "load" (self, typeId : Id, source : Id)
-            'nodeId self
-                Node Op.Load typeId
+            let c1 value1 = ('unthen self value1)
+            let c2 value2 = ('unthen self value2)
+            let lcond1 = ('and self c1 ('not self c2))
+            let lcond2 = ('and self ('not self c1) c2)
+            'then self
+                'or self c1 c2
+                if (('getbool self lcond1) >= 0)
+                    'select self lcond1 value2 value1
+                else
+                    'select self lcond2 value1 value2
 
     bind =
         fn... "bind" (self, target : Id, source : Id)

          
@@ 1015,8 1168,24 @@ struct Module
 
     then =
         fn... "then" (self, cond : Id, value : Id)
+            returning Id
+            let c1 cond = ('unthen self cond)
+            let c2 value = ('unthen self value)
+            let cond = ('and self ('and self c1 cond) c2)
+            let b = ('getbool self cond)
+            if (b > 0) (copy value)
+            elseif (b < 0)
+                vvv copy
+                'undef self ('getType self value)
+            else
+                vvv copy
+                'nodeId self
+                    Node Op.Then ('getType self value) cond value
+
+    undef =
+        fn... "undef" (self, typeId : Id)
             'nodeId self
-                Node Op.Then ('getType self value) cond value
+                Node Op.Undefined typeId
 
     else =
         fn... "else" (self, condition : Id, value : Id)

          
@@ 1026,10 1195,15 @@ struct Module
 
     or =
         fn... "or" (self, value1 : Id, value2 : Id)
+            if (value1 == value2)
+                return value1
             'select self value1 value2 value1
 
     and =
         fn... "and" (self, value1 : Id, value2 : Id)
+            returning Id
+            if (value1 == value2)
+                return value2
             'select self value1 value1 value2
 
     not =

          
@@ 1359,7 1533,7 @@ static-if main-module?
             \ getType vectorType intToReal state constFloat fadd fmul fdiv sin cos
             \ constComposite compositeInsert compositeConstruct constBool constString
             \ parameter parameters function tupleType else bind then merge
-            \ equal input load output constInt
+            \ equal input load output constInt and or
 
         let string =
             stringType;

          
@@ 1368,7 1542,7 @@ static-if main-module?
         let inttype =
             integerType 32 true
         let readline =
-            load string (input Input.Readline)
+            input string Input.Readline
         let setup =
             input Input.Setup
         let stdout =

          
@@ 1385,7 1559,7 @@ static-if main-module?
             then exit? (constInt inttype 0)
 
         bind prompt
-            then (merge setup exit?) (constString "> ")
+            then (or setup exit?) (constString "> ")
 
         bind stdout
             merge