800281fa04c0 — Leonard Ritter a month ago
* typer can deduce stages
3 files changed, 410 insertions(+), 131 deletions(-)

M lib/tukan/CADAG/init.sc
M lib/tukan/FIR.sc
M testing/tukdag.sc
M lib/tukan/CADAG/init.sc +12 -0
@@ 25,6 25,18 @@ type Id
             inline (a b)
                 (storagecast a) == (storagecast b)
 
+    # `<` comparison hook for Id (wrapped in `@@ memo`).
+    # a comparison is only provided when `T < this-type` holds at compile
+    # time; it orders two ids by their storagecast values.
+    @@ memo
+    inline __< (cls T)
+        static-if (T < this-type)
+            inline (a b)
+                (storagecast a) < (storagecast b)
+
+    # `>` comparison hook for Id (wrapped in `@@ memo`).
+    # a comparison is only provided when `T < this-type` holds at compile
+    # time; it orders two ids by their storagecast values.
+    @@ memo
+    inline __> (cls T)
+        static-if (T < this-type)
+            inline (a b)
+                (storagecast a) > (storagecast b)
+
 'define-symbol Id '__copy integer.__copy
 'define-symbol Id '__hash integer.__hash
 

          
M lib/tukan/FIR.sc +392 -128
@@ 6,6 6,7 @@ using import Map
 using import Array
 using import Option
 using import glm
+using import Rc
 
 using import .CADAG
 using import .gl

          
@@ 300,6 301,8 @@ define-type "uvec"      (RIFF "UVEC") (t
     typecolor...
 define-type "mrv"       (RIFF "VART") (tuple (types = (array AnyId)))
     typecolor...
+define-type "args"      (RIFF "ARGS") (tuple (args = (array AnyId)))
+    typecolor...
 define-type "fvec2"     (RIFF "FVC2") (tuple (x = AnyId) (y = AnyId))
     instrcolor...
 define-type "fvec3"     (RIFF "FVC3") (tuple (x = AnyId) (y = AnyId) (z = AnyId))

          
@@ 400,10 403,134 @@ define-op1 "cos"    "FCOS"
 
 ################################################################################
 
+#
+    expressions can belong to one of these contexts:
+
+    * constant
+        runs anywhere
+
+    * control stage, and sampling order
+        these run as a single invocation on the CPU
+        a control can only source control expressions
+        a control can be of type uvec, fvec, imagestorage, bufferstorage
+        imagestorage and bufferstorage have a sampling order
+
+    * shader stage:
+        * compute range, and sampling order aka buffer index
+            each expression is executed as wide as the range is.
+            two expressions can share the same range, but one samples the other.
+            a range can source range (of same size and sampling order) and
+            control through uniform variables.
+            a range can be of type uvec, fvec, image, sampler
+
+        * vertex primitive, and sampling order
+            each expression is executed as wide as primitive count times instance
+            count, or less. this is an intermediate stage for fragment shaders.
+            a primitive can source primitive (of same vertex primitive) and control
+            through uniform variables.
+            a primitive can be of type uvec, fvec, image, sampler
+
+        * fragment
+            each expression is executed for as many fragments as can be generated
+            for a primitive. a fragment can source fragments (of same vertex primitive)
+            and control through uniform variables.
+            a fragment can be of type uvec, fvec, image, sampler
+
+            fragments from different primitives should be combinable
+            could sampling a fragment use the range of the uv coordinate?
+
+    for shader stages, we need to aggregate bindings, and perhaps search / replace
+    later instead of doing eager replacements; i would suggest a linked list of
+    external value assignments here. we can then make proper bindings for them
+    and run a replace operation afterwards.
+
+# array of ids, and the Rc-wrapped form in which such arrays are passed around
+let AnyIdArray = (Array AnyId)
+let RcAnyIdArray = (Rc AnyIdArray)
+
+# merge the elements of `a` and `b` into a newly wrapped array.
+# if either input is empty, a copy of the other one is returned. otherwise
+# both inputs are walked with one cursor each: the smaller of the two current
+# elements is appended and its cursor advanced, equal elements are appended
+# once while both cursors advance, and whatever remains of the longer input
+# is appended at the end.
+# NOTE(review): the result is a sorted union only if both inputs are already
+# sorted ascending; nothing in here checks that - confirm at the call sites.
+fn... merge-arrays (a : RcAnyIdArray, b : RcAnyIdArray)
+    if (empty? b) (copy a)
+    elseif (empty? a) (copy b)
+    else
+        local imports : AnyIdArray
+        # element counts and pointers to the first element of each input
+        let acount aptr = ((countof a) as u32) (& (a @ 0))
+        let bcount bptr = ((countof b) as u32) (& (b @ 0))
+        # inside the loop `a` and `b` are the cursors and shadow the parameters
+        loop (a b = 0:u32 0:u32)
+            if (a < acount)
+                if (b < bcount)
+                    let A B = (copy (aptr @ a)) (copy (bptr @ b))
+                    if (A == B)
+                        'append imports A
+                        repeat (a + 1) (b + 1)
+                    elseif (A < B)
+                        'append imports A
+                        repeat (a + 1) b
+                    else # (A > B)
+                        'append imports B
+                        repeat a (b + 1)
+                else
+                    'append imports (copy (aptr @ a))
+                    repeat (a + 1) b
+            elseif (b < bcount)
+                'append imports (copy (bptr @ b))
+                repeat a (b + 1)
+            else
+                break;
+        Rc.wrap (deref imports)
+
+# per-expression staging record: sample-index, generator and the list of
+# values to import, plus helpers to copy, print and combine such records.
+struct StageInfo
+    # how many sample indirections have been performed?
+    sample-index : u32
+    # NoId, range, primitive or fragment
+    generator : AnyId
+    # values that need to be imported in order to evaluate
+        this expression in a shader stage
+    stage-import : RcAnyIdArray
+
+    # build a new StageInfo from the fields of `self`; stage-import goes
+    # through (copy ...)
+    fn __copy (self)
+        this-type
+            sample-index = self.sample-index
+            generator = self.generator
+            stage-import = (copy self.stage-import)
+
+    # string form "(sample-index=... generator=... stage-import=...)"
+    fn __repr (self)
+        ..
+            \ "(sample-index=" (repr self.sample-index)
+            \ " generator=" (repr self.generator)
+            \ " stage-import=" (repr self.stage-import)
+            \ ")"
+
+    # combine the stage info of two arguments into a new StageInfo.
+    # sample-index is the maximum of both. the generator is whichever side
+    # is not NoId; if both are set they must be equal, otherwise the error
+    # "generators of arguments do not match" is raised. stage-import is taken
+    # from the non-empty side, or from `self` when both compare equal, or
+    # else produced by merge-arrays.
+    # `module` is not used in this function body.
+    fn combine (self other module)
+        let sample-index = (max self.sample-index other.sample-index)
+        let generator =
+            if (self.generator == NoId)
+                copy other.generator
+            elseif (other.generator == NoId)
+                copy self.generator
+            elseif (self.generator == other.generator)
+                copy self.generator
+            else
+                error "generators of arguments do not match"
+        let stage-import =
+            if (empty? self.stage-import)
+                copy other.stage-import
+            elseif (empty? other.stage-import)
+                copy self.stage-import
+            elseif (self.stage-import == other.stage-import)
+                copy self.stage-import
+            else
+                # merge two ordered arrays
+                merge-arrays (view self.stage-import) (view other.stage-import)
+        this-type
+            sample-index = sample-index
+            generator = generator
+            stage-import = stage-import
+
 struct FIRTyper
     types : (Map AnyId AnyId)
+    stages : (Map AnyId StageInfo)
+    rootstage : StageInfo
 
-    fn setup (self module)
+    fn setup (module)
         # insert common types early, so we don't interfere with late insertions
         from (methodsof module.builder) let uvec fvec
         uvec 1

          
@@ 416,6 543,118 @@ struct FIRTyper
         fvec 4
         ;
 
+    # deduce the StageInfo of the single node `id` and store it in ctx.stages.
+    # the stages of the nodes it reads are looked up through the local `get`,
+    # which raises "stage missing for: ..." when an entry is absent.
+    # raises "failed to deduce stage" for node types without a case below.
+    fn stage-value (ctx module id)
+        # NOTE(review): uvec and fvec are bound here but not used below
+        from (methodsof module.builder) let uvec fvec
+        report "staging" ('repr module id)
+        # fetch an already deduced stage from ctx.stages, or raise
+        inline get (id)
+            try ('get ctx.stages id)
+            else
+                error
+                    .. "stage missing for: " ('repr module id)
+
+        let handle = ('handleof module id)
+        let vacount = ('vacount handle)
+        vvv bind stage
+        dispatch handle
+        case range (self)
+            # a range is its own generator; its sample-index is the maximum
+            # over the first `vacount` entries of dims, and its imports are
+            # copied from the root stage
+            let dims = self.dims
+            let sample-index =
+                fold (si = 0:u32) for i in (range vacount)
+                    max si ((get (dims @ i)) . sample-index)
+            StageInfo
+                sample-index = sample-index
+                generator = id
+                stage-import = (copy ctx.rootstage.stage-import)
+        case sample (self)
+            # generator and imports are taken from uv only; source contributes
+            # to sample-index alone.
+            # NOTE(review): sample-index is not incremented here even though
+            # StageInfo describes it as a count of sample indirections -
+            # confirm whether a `+ 1` is intended.
+            let uv = (get self.uv)
+            let source = (get self.source)
+            StageInfo
+                sample-index = (max uv.sample-index source.sample-index)
+                generator = uv.generator
+                stage-import = (copy uv.stage-import)
+        default
+            # remaining node types are told apart by their type id
+            switch ('typeidof module id)
+            # these take a copy of the root stage
+            pass TypeId.typeid_uconst
+            pass TypeId.typeid_fconst
+            pass TypeId.typeid_input
+            pass TypeId.typeid_output
+            pass TypeId.typeid_outputs
+            do
+                copy ctx.rootstage
+            # these combine the stages of all their sources
+            pass TypeId.typeid_comp
+            pass TypeId.typeid_add
+            pass TypeId.typeid_sub
+            pass TypeId.typeid_mul
+            pass TypeId.typeid_udiv
+            pass TypeId.typeid_sdiv
+            pass TypeId.typeid_and
+            pass TypeId.typeid_or
+            pass TypeId.typeid_xor
+            pass TypeId.typeid_utof
+            pass TypeId.typeid_fadd
+            pass TypeId.typeid_fsub
+            pass TypeId.typeid_fmul
+            pass TypeId.typeid_fdiv
+            pass TypeId.typeid_frem
+            pass TypeId.typeid_sin
+            pass TypeId.typeid_cos
+            pass TypeId.typeid_fvec2
+            pass TypeId.typeid_fvec3
+            pass TypeId.typeid_fvec4
+            pass TypeId.typeid_uvec2
+            pass TypeId.typeid_uvec3
+            pass TypeId.typeid_uvec4
+            do
+                # aggregate all arguments
+                local si =
+                    fold (si = (copy ctx.rootstage)) for srcid in ('sources handle)
+                        'combine si (get srcid) module
+                # true when the combined stage has a generator and at least
+                # one source has none while also not being constant
+                let new-imports? =
+                    if (si.generator == NoId) false
+                    else
+                        for srcid in ('sources handle)
+                            let s = (get srcid)
+                            if (s.generator == NoId)
+                                if (not ('constant? module srcid))
+                                    break true
+                        else false
+                # collect the sources without generator, sort them and merge
+                # them into the imports of the combined stage.
+                # NOTE(review): unlike the check above, this loop also appends
+                # constant sources, and repeated source ids are not removed
+                # before merging - confirm whether both are intended.
+                if new-imports?
+                    local newarr : (Array AnyId)
+                    for srcid in ('sources handle)
+                        let s = (get srcid)
+                        if (s.generator == NoId)
+                            'append newarr srcid
+                    'sort newarr
+                    si.stage-import =
+                        merge-arrays si.stage-import (Rc.wrap (deref newarr))
+                deref si
+            default
+                error "failed to deduce stage"
+        report "staged to" stage
+        'set ctx.stages id stage
+        ;
+
+    # return the StageInfo of `id`, deducing it on demand.
+    # an existing entry in ctx.stages is returned directly. otherwise the
+    # module is walked with 'descend starting at `id`: on-enter answers
+    # whether a node has no stage yet, on-leave runs stage-value for the node
+    # and hands any error to error@+ together with a
+    # "while deducing stage of ..." message. afterwards the entry for `id` is
+    # read again; if it is still missing, the function traps.
+    fn... stageof (ctx, module : FIR, id : AnyId)
+        try
+            return ('get ctx.stages id)
+        else;
+
+        'descend module id
+            on-enter =
+                capture (module id) {&ctx}
+                    not ('in? ctx.stages id)
+            on-leave =
+                capture on-leave (module id) {&ctx}
+                    try
+                        stage-value ctx module id
+                    except (err)
+                        error@+ err unknown-anchor
+                            .. "while deducing stage of " ('repr module id)
+        try
+            return ('get ctx.stages id)
+        else
+            trap;
+
     fn imagetypeof (self module id)
         let tid = ('typeof self module id)
         dispatch ('handleof module tid)

          
@@ 432,124 671,124 @@ struct FIRTyper
                     " has type "
                     'repr module tid
 
+    # deduce the type id of the single node `id` and store it in ctx.types.
+    # types of the nodes it reads are looked up through the local `get`,
+    # which raises "type missing for: ..." when an entry is absent.
+    # raises "failed to deduce type" for node types without a case below.
+    # NOTE(review): stage-value handles typeid_fsub, typeid_frem and
+    # typeid_cos, but none of them has a case in the switch below, so they
+    # would end in "failed to deduce type" - confirm whether they belong to
+    # the (fvec 1) group.
+    fn type-value (ctx module id)
+        from (methodsof module.builder) let uvec fvec
+        report "typing" ('repr module id)
+        # fetch a copy of an already deduced type from ctx.types, or raise
+        inline get (id)
+            try (copy ('get ctx.types id))
+            else
+                error
+                    .. "type missing for: " ('repr module id)
+
+        let handle = ('handleof module id)
+        vvv bind type
+        dispatch handle
+        case input (self)
+            # system inputs have fixed types depending on their source key
+            switch self.source
+            case SystemKey.ScreenSize (uvec 2)
+            case SystemKey.Iteration (uvec 1)
+            default
+                error
+                    .. "don't know how to type source: " (repr self.source)
+        case load (self)
+            get self.pointer
+        case sample (self)
+            get self.source
+        case clearimage (self)
+            get self.target
+        case dispatch (self)
+            # the type of a dispatch is derived from its sinks, which must be
+            # a bindings node (anything else traps)
+            let sinkhandle = ('handleof module self.sinks)
+            let vacount = ('vacount sinkhandle)
+            dispatch sinkhandle
+            case bindings (sinks)
+                if (vacount == 1)
+                    # single sink: type of the first element of its entry
+                    let entries = sinks.entries
+                    get (entries @ 0 @ 0)
+                else
+                    # several sinks: allocate an mrv node with one slot per
+                    # entry and fill each slot with the type of that entry's
+                    # first element; the result of 'commit is the value of
+                    # this branch (presumably the id of the new node)
+                    let mrv =
+                        'alloc module TypeId.typeid_mrv vacount
+                    let entries = sinks.entries
+                    let args = mrv.types
+                    for i in (range vacount)
+                        args @ i = (get (entries @ i @ 0))
+                    'commit module
+            default
+                trap;
+        default
+            # remaining node types are told apart by their type id
+            switch ('typeidof module id)
+            #case TypeId.typeid_range (uvec 3)
+            case TypeId.typeid_globalid (uvec 3)
+
+            # typed as (fvec 1)
+            pass TypeId.typeid_fconst
+            pass TypeId.typeid_utof
+            pass TypeId.typeid_fadd
+            pass TypeId.typeid_fmul
+            pass TypeId.typeid_fdiv
+            pass TypeId.typeid_sin
+            do (fvec 1)
+
+            # typed as (uvec 1)
+            pass TypeId.typeid_uconst
+            pass TypeId.typeid_add
+            pass TypeId.typeid_sub
+            pass TypeId.typeid_mul
+            pass TypeId.typeid_udiv
+            pass TypeId.typeid_sdiv
+            pass TypeId.typeid_and
+            pass TypeId.typeid_or
+            pass TypeId.typeid_xor
+            do (uvec 1)
+
+            # typed as NoId
+            pass TypeId.typeid_outputs
+            pass TypeId.typeid_output
+            pass TypeId.typeid_imagewrite
+            pass TypeId.typeid_computefn
+            pass TypeId.typeid_bindings
+            do NoId
+
+            # vector nodes are typed by their component count
+            case TypeId.typeid_uvec2 (uvec 2)
+            case TypeId.typeid_uvec3 (uvec 3)
+            case TypeId.typeid_uvec4 (uvec 4)
+            case TypeId.typeid_fvec2 (fvec 2)
+            case TypeId.typeid_fvec3 (fvec 3)
+            case TypeId.typeid_fvec4 (fvec 4)
+
+            # types have no type
+            pass TypeId.typeid_image
+            pass TypeId.typeid_sampler
+            pass TypeId.typeid_mrv
+            pass TypeId.typeid_fvec
+            pass TypeId.typeid_uvec
+            pass TypeId.typeid_imagestorage
+            do NoId
+
+            # first value is type
+            pass TypeId.typeid_undef
+            pass TypeId.typeid_wimage
+            pass TypeId.typeid_uniform
+            do
+                # the loop breaks on its first iteration, yielding the first
+                # source id; NoId when there are no sources
+                for srcid in ('sources handle)
+                    if true
+                        break (copy srcid)
+                else NoId
+
+            # type of first value
+            pass TypeId.typeid_comp
+            do
+                # as above, but yields the deduced type of the first source
+                for srcid in ('sources handle)
+                    if true
+                        break (get srcid)
+                else NoId
+            default
+                error "failed to deduce type"
+        report "typed to" type
+        'set ctx.types id type
+        ;
+
     fn... typeof (ctx, module : FIR, id : AnyId)
         try
             return (copy ('get ctx.types id))
         else;
 
-        fn type-value (ctx module id)
-            from (methodsof module.builder) let uvec fvec
-            report "typing" ('repr module id)
-            inline get (id)
-                try (copy ('get ctx.types id))
-                else
-                    error
-                        .. "type missing for: " ('repr module id)
-
-            let handle = ('handleof module id)
-            vvv bind type
-            dispatch handle
-            case input (self)
-                switch self.source
-                case SystemKey.ScreenSize (uvec 2)
-                case SystemKey.Iteration (uvec 1)
-                default
-                    error
-                        .. "don't know how to type source: " (repr self.source)
-            case load (self)
-                get self.pointer
-            case sample (self)
-                get self.source
-            case clearimage (self)
-                get self.target
-            case dispatch (self)
-                let sinkhandle = ('handleof module self.sinks)
-                let vacount = ('vacount sinkhandle)
-                dispatch sinkhandle
-                case bindings (sinks)
-                    if (vacount == 1)
-                        let entries = sinks.entries
-                        get (entries @ 0 @ 0)
-                    else
-                        let mrv =
-                            'alloc module TypeId.typeid_mrv vacount
-                        let entries = sinks.entries
-                        let args = mrv.types
-                        for i in (range vacount)
-                            args @ i = (get (entries @ i @ 0))
-                        'commit module
-                default
-                    trap;
-            default
-                switch ('typeidof module id)
-                #case TypeId.typeid_range (uvec 3)
-                case TypeId.typeid_globalid (uvec 3)
-
-                pass TypeId.typeid_fconst
-                pass TypeId.typeid_utof
-                pass TypeId.typeid_fadd
-                pass TypeId.typeid_fmul
-                pass TypeId.typeid_fdiv
-                pass TypeId.typeid_sin
-                do (fvec 1)
-
-                pass TypeId.typeid_uconst
-                pass TypeId.typeid_add
-                pass TypeId.typeid_sub
-                pass TypeId.typeid_mul
-                pass TypeId.typeid_udiv
-                pass TypeId.typeid_sdiv
-                pass TypeId.typeid_and
-                pass TypeId.typeid_or
-                pass TypeId.typeid_xor
-                do (uvec 1)
-
-                pass TypeId.typeid_outputs
-                pass TypeId.typeid_output
-                pass TypeId.typeid_imagewrite
-                pass TypeId.typeid_computefn
-                pass TypeId.typeid_bindings
-                do NoId
-
-                case TypeId.typeid_uvec2 (uvec 2)
-                case TypeId.typeid_uvec3 (uvec 3)
-                case TypeId.typeid_uvec4 (uvec 4)
-                case TypeId.typeid_fvec2 (fvec 2)
-                case TypeId.typeid_fvec3 (fvec 3)
-                case TypeId.typeid_fvec4 (fvec 4)
-
-                # types have no type
-                pass TypeId.typeid_image
-                pass TypeId.typeid_sampler
-                pass TypeId.typeid_mrv
-                pass TypeId.typeid_fvec
-                pass TypeId.typeid_uvec
-                pass TypeId.typeid_imagestorage
-                do NoId
-
-                # first value is type
-                pass TypeId.typeid_undef
-                pass TypeId.typeid_wimage
-                pass TypeId.typeid_uniform
-                do
-                    for srcid in ('sources handle)
-                        if true
-                            break (copy srcid)
-                    else NoId
-
-                # type of first value
-                pass TypeId.typeid_comp
-                do
-                    for srcid in ('sources handle)
-                        if true
-                            break (get srcid)
-                    else NoId
-                default
-                    error "failed to deduce type"
-            report "typed to" type
-            'set ctx.types id type
-            ;
-
         'descend module id
             on-enter =
                 capture (module id) {&ctx}

          
@@ 1539,7 1778,7 @@ fn generate-IL (module)
             drop-body = drop-body
 
     let rootid = ('rootid module)
-    'setup ctx.typer module
+    FIRTyper.setup module
 
     'descend module rootid
         on-leave =

          
@@ 1598,7 1837,7 @@ fn tryuconsts (lhs rhs)
     _ (| (? a? 1 0) (? b? 2 0)) a b
 
 # returns NoId if expression can't be folded, otherwise id of new expression
-fn fold-constant-expression (self handle)
+fn fold-constant-expression (typer self handle)
     viewing self
     raising Error
     from (methodsof self.builder) let uconst fconst

          
@@ 1615,14 1854,18 @@ fn fold-constant-expression (self handle
     dispatch handle
     case comp (self)
         let typeid sz ptr = (unpack ('handleof module self.value))
+        if (self.index == 0)
+            let val = (copy self.value)
+            let tid = ('typeof typer module val)
+            dispatch ('handleof module tid)
+            case fvec (self)
+                if (self.count == 1)
+                    return val
+            case uvec (self)
+                if (self.count == 1)
+                    return val
+            default;
         switch typeid
-        pass TypeId.typeid_fconst
-        pass TypeId.typeid_uconst
-        do
-            if (self.index == 0)
-                return (copy self.value)
-            else
-                report "index out of vector size"
         pass TypeId.typeid_fvec2
         pass TypeId.typeid_fvec3
         pass TypeId.typeid_fvec4

          
@@ 1713,11 1956,12 @@ fn fold-constant-expression (self handle
 
 fn fold-constant-expressions (self)
     let cls = (typeof self)
+    local typer : FIRTyper
     'translate self self ('rootid self)
         on-leave =
-            capture (module handle oldmodule id) {}
+            capture (module handle oldmodule id) {&typer}
                 try
-                    let id = (fold-constant-expression module handle)
+                    let id = (fold-constant-expression typer module handle)
                     if (id != cls.NoId) id
                     else
                         'commit module handle

          
@@ 1728,7 1972,7 @@ fn fold-constant-expressions (self)
 ################################################################################
 
 # lower range based expressions to compute functions and dispatches
-fn lower-FIR (module)
+fn lower-FIR (module rootid)
     viewing module
 
     fn get-capacity (module id)

          
@@ 2043,8 2287,7 @@ fn lower-FIR (module)
             merge-gpujobs;
         newid
 
-    let rootid = ('rootid module)
-    'setup ctx.typer module
+    FIRTyper.setup module
     'translate module module rootid
         on-leave =
             capture (module handle oldmodule id) {&ctx}

          
@@ 2061,6 2304,27 @@ type+ FIR
     let lower = lower-FIR
     let fold-constant-expressions
 
+    # report whether the expression `id` is constant.
+    # fconst and uconst nodes are constant. fvec2-4, uvec2-4 and comp nodes
+    # are constant when every one of their sources is constant, which is
+    # checked by calling this function recursively. every other node type
+    # yields false.
+    fn constant? (self id)
+        returning bool
+        let handle = ('handleof self id)
+        switch handle.typeid
+        pass TypeId.typeid_fconst
+        pass TypeId.typeid_uconst
+        do true
+        pass TypeId.typeid_fvec2
+        pass TypeId.typeid_fvec3
+        pass TypeId.typeid_fvec4
+        pass TypeId.typeid_uvec2
+        pass TypeId.typeid_uvec3
+        pass TypeId.typeid_uvec4
+        pass TypeId.typeid_comp
+        do
+            # the loop variable `id` shadows the parameter of the same name
+            for id in ('sources handle)
+                if (not (this-function self id))
+                    return false
+            else true
+        default false
+
 type+ FIR.BuilderType
     inline unpack-comp (self value n)
         va-map

          
M testing/tukdag.sc +6 -3
@@ 154,7 154,7 @@ inline gen-level2-test ()
     let w h =
         unpack-comp (input SystemKey.ScreenSize) 2
     let screenrange = (range w h)
-    outputs
+    #outputs
         output SystemKey.Screen
             clear screenrange
                 fvec3 (fconst 0) (fconst 0) (fconst 1)

          
@@ 181,7 181,7 @@ inline gen-level2-test ()
                     clear screenrange
                         fvec3 (uconst 0) (uconst 0) (uconst 1)
 
-    #outputs
+    outputs
         output SystemKey.Screen
             do
                 # frame time

          
@@ 223,7 223,10 @@ do
     gen-level2-test;
     cleanup;
     'dump module
-    'lower module
+    let rootid = ('rootid module)
+    local typer : FIRTyper
+    'stageof typer module rootid
+    'lower module rootid
 print;
 'fold-constant-expressions module
 #cleanup;