Make pass pre- and post-traversal work properly.

With unit tests!  Huzzah!
1 files changed, 128 insertions(+), 48 deletions(-)

M src/passes.rs
M src/passes.rs +128 -48
@@ 65,33 65,11 @@ pub fn run_typechecked_passes(ir: Ir, tc
     res
 }
 
-fn exprs_map(
-    exprs: Vec<ExprNode>,
-    f_pre: &mut dyn FnMut(ExprNode) -> ExprNode,
-    f_post: &mut dyn FnMut(ExprNode) -> ExprNode,
-) -> Vec<ExprNode> {
-    exprs
-        .into_iter()
-        .map(|e| expr_map(e, f_pre, f_post))
-        .collect()
-}
-
 /// Handy do-nothing combinator
 fn id<T>(thing: T) -> T {
     thing
 }
 
-fn exprs_map_pre(exprs: Vec<ExprNode>, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> Vec<ExprNode> {
-    let f_post = &mut |e| e;
-    exprs_map(exprs, f, f_post)
-    //exprs.into_iter().map(|e| expr_map(e, f, f_post)).collect()
-}
-
-fn exprs_map_post(exprs: Vec<ExprNode>, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> Vec<ExprNode> {
-    exprs_map(exprs, &mut id, f)
-    //exprs.into_iter().map(|e| expr_map(e, f, f_post)).collect()
-}
-
 /// This is the core recusion scheme that takes a function and applies
 /// it to an expression, and all its subexpressions.
 /// The function is `&mut FnMut` so if it needs to have access to more

          
@@ 132,12 110,12 @@ fn expr_map(
         E::Break => e,
         E::EnumCtor { .. } => e,
         E::TupleCtor { body } => E::TupleCtor {
-            body: exprs_map_pre(body, f_pre),
+            body: exprs_map(body, f_pre, f_post),
         },
         E::StructCtor { body } => {
             let new_body = body
                 .into_iter()
-                .map(|(sym, vl)| (sym, expr_map_pre(vl, f_pre)))
+                .map(|(sym, vl)| (sym, expr_map(vl, f_pre, f_post)))
                 .collect();
             E::StructCtor { body: new_body }
         }

          
@@ 148,7 126,7 @@ fn expr_map(
         } => E::TypeCtor {
             name,
             type_params,
-            body: expr_map_pre(body, f_pre),
+            body: expr_map(body, f_pre, f_post),
         },
         E::SumCtor {
             name,

          
@@ 157,41 135,41 @@ fn expr_map(
         } => E::SumCtor {
             name,
             variant,
-            body: expr_map_pre(body, f_pre),
+            body: expr_map(body, f_pre, f_post),
         },
         E::ArrayCtor { body } => E::ArrayCtor {
-            body: exprs_map_pre(body, f_pre),
+            body: exprs_map(body, f_pre, f_post),
         },
         E::TypeUnwrap { expr } => E::TypeUnwrap {
-            expr: expr_map_pre(expr, f_pre),
+            expr: expr_map(expr, f_pre, f_post),
         },
         E::TupleRef { expr, elt } => E::TupleRef {
-            expr: expr_map_pre(expr, f_pre),
+            expr: expr_map(expr, f_pre, f_post),
             elt,
         },
         E::StructRef { expr, elt } => E::StructRef {
-            expr: expr_map_pre(expr, f_pre),
+            expr: expr_map(expr, f_pre, f_post),
             elt,
         },
         E::ArrayRef { expr, idx } => E::ArrayRef {
-            expr: expr_map_pre(expr, f_pre),
+            expr: expr_map(expr, f_pre, f_post),
             idx,
         },
         E::Assign { lhs, rhs } => E::Assign {
-            lhs: expr_map_pre(lhs, f_pre), // TODO: Think real hard about lvalues
-            rhs: expr_map_pre(rhs, f_pre),
+            lhs: expr_map(lhs, f_pre, f_post), // TODO: Think real hard about lvalues
+            rhs: expr_map(rhs, f_pre, f_post),
         },
         E::BinOp { op, lhs, rhs } => E::BinOp {
             op,
-            lhs: expr_map_pre(lhs, f_pre),
-            rhs: expr_map_pre(rhs, f_pre),
+            lhs: expr_map(lhs, f_pre, f_post),
+            rhs: expr_map(rhs, f_pre, f_post),
         },
         E::UniOp { op, rhs } => E::UniOp {
             op,
-            rhs: expr_map_pre(rhs, f_pre),
+            rhs: expr_map(rhs, f_pre, f_post),
         },
         E::Block { body } => E::Block {
-            body: exprs_map_pre(body, f_pre),
+            body: exprs_map(body, f_pre, f_post),
         },
         E::Let {
             varname,

          
@@ 201,33 179,33 @@ fn expr_map(
         } => E::Let {
             varname,
             typename,
-            init: expr_map_pre(init, f_pre),
+            init: expr_map(init, f_pre, f_post),
             mutable,
         },
         E::If { cases } => {
             let new_cases = cases
                 .into_iter()
                 .map(|(test, case)| {
-                    let new_test = expr_map_pre(test, f_pre);
-                    let new_cases = exprs_map_pre(case, f_pre);
+                    let new_test = expr_map(test, f_pre, f_post);
+                    let new_cases = exprs_map(case, f_pre, f_post);
                     (new_test, new_cases)
                 })
                 .collect();
             E::If { cases: new_cases }
         }
         E::Loop { body } => E::Loop {
-            body: exprs_map_pre(body, f_pre),
+            body: exprs_map(body, f_pre, f_post),
         },
         E::Return { retval } => E::Return {
-            retval: expr_map_pre(retval, f_pre),
+            retval: expr_map(retval, f_pre, f_post),
         },
         E::Funcall {
             func,
             params,
             type_params,
         } => {
-            let new_func = expr_map_pre(func, f_pre);
-            let new_params = exprs_map_pre(params, f_pre);
+            let new_func = expr_map(func, f_pre, f_post);
+            let new_params = exprs_map(params, f_pre, f_post);
             E::Funcall {
                 func: new_func,
                 params: new_params,

          
@@ 236,10 214,10 @@ fn expr_map(
         }
         E::Lambda { signature, body } => E::Lambda {
             signature,
-            body: exprs_map_pre(body, f_pre),
+            body: exprs_map(body, f_pre, f_post),
         },
         E::Typecast { e, to } => E::Typecast {
-            e: expr_map_pre(e, f_pre),
+            e: expr_map(e, f_pre, f_post),
             to,
         },
     };

          
@@ 247,14 225,34 @@ fn expr_map(
     f_post(post_thing)
 }
 
-fn expr_map_pre(expr: ExprNode, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> ExprNode {
+pub fn expr_map_pre(expr: ExprNode, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> ExprNode {
     expr_map(expr, f, &mut id)
 }
 
-fn expr_map_post(expr: ExprNode, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> ExprNode {
+pub fn expr_map_post(expr: ExprNode, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> ExprNode {
     expr_map(expr, &mut id, f)
 }
 
+/// Map functions over a list of exprs.
+fn exprs_map(
+    exprs: Vec<ExprNode>,
+    f_pre: &mut dyn FnMut(ExprNode) -> ExprNode,
+    f_post: &mut dyn FnMut(ExprNode) -> ExprNode,
+) -> Vec<ExprNode> {
+    exprs
+        .into_iter()
+        .map(|e| expr_map(e, f_pre, f_post))
+        .collect()
+}
+
+fn exprs_map_pre(exprs: Vec<ExprNode>, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> Vec<ExprNode> {
+    exprs_map(exprs, f, &mut id)
+}
+
+fn exprs_map_post(exprs: Vec<ExprNode>, f: &mut dyn FnMut(ExprNode) -> ExprNode) -> Vec<ExprNode> {
+    exprs_map(exprs, &mut id, f)
+}
+
 fn decl_map_pre(
     decl: D,
     fe: &mut dyn FnMut(ExprNode) -> ExprNode,

          
@@ 471,4 469,86 @@ mod tests {
         let out2 = expr_map_pre(inp2, &mut swap_binop_args);
         assert_eq!(out2, desired2);
     }
+
+    /// Test whether our pre-traversal expr map works properly.
+    #[test]
+    fn test_expr_pretraverse() {
+        let inp = ExprNode::new(E::Block {
+            body: vec![ExprNode::new(E::Var {
+                name: Sym::new("foo"),
+            })],
+        });
+        // Make a transformer that renames vars, and check in the Block
+        // body whether the inner var has been transformed.
+        let f = &mut |e: ExprNode| {
+            let helper = &mut |e| match e {
+                E::Var { .. } => E::Var {
+                    name: Sym::new("bar!"),
+                },
+                E::Block { body } => {
+                    // Has the name within the body been transformed yet?
+                    {
+                        let inner = &*body[0].e;
+                        assert_eq!(
+                            inner,
+                            &E::Var {
+                                name: Sym::new("foo")
+                            }
+                        );
+                    }
+                    E::Block { body }
+                }
+                _other => _other,
+            };
+            e.map(helper)
+        };
+        let outp = expr_map_pre(inp, f);
+        // Make sure that the actual transformation has done what we expected.
+        let expected = ExprNode::new(E::Block {
+            body: vec![ExprNode::new(E::Var {
+                name: Sym::new("bar!"),
+            })],
+        });
+        assert_eq!(outp, expected);
+    }
+
+    /// Similar to above with the case in helper() reversed
+    #[test]
+    fn test_expr_posttraverse() {
+        let inp = ExprNode::new(E::Block {
+            body: vec![ExprNode::new(E::Var {
+                name: Sym::new("foo"),
+            })],
+        });
+        let f = &mut |e: ExprNode| {
+            let helper = &mut |e| match e {
+                E::Var { .. } => E::Var {
+                    name: Sym::new("bar!"),
+                },
+                E::Block { body } => {
+                    // Has the name within the body been transformed yet?
+                    {
+                        let inner = &*body[0].e;
+                        assert_eq!(
+                            inner,
+                            &E::Var {
+                                name: Sym::new("bar!")
+                            }
+                        );
+                    }
+                    E::Block { body }
+                }
+                _other => _other,
+            };
+            e.map(helper)
+        };
+        let outp = expr_map_post(inp, f);
+        // Make sure that the actual transformation has done what we expected.
+        let expected = ExprNode::new(E::Block {
+            body: vec![ExprNode::new(E::Var {
+                name: Sym::new("bar!"),
+            })],
+        });
+        assert_eq!(outp, expected);
+    }
 }