refactor: extract methods from Lean.Meta.transform

Lean.Meta.transform is created with a series of recursive
visit functions. However these visit functions are useful
on their own outside of transform for traversing expressions.

This commit moves the visit functions outside the main function.
This commit is contained in:
E.W.Ayers 2022-06-03 19:29:57 -04:00 committed by Leonardo de Moura
parent ece1c1085c
commit 8311c88fd0

View file

@ -1,7 +1,7 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
Authors: Leonardo de Moura, E.W.Ayers
-/
import Lean.Meta.Basic
@ -11,6 +11,13 @@ inductive TransformStep where
| done (e : Expr)
| visit (e : Expr)
/-- Given `e = fn a₁ ... aₙ`, runs `f` on `fn` and each of the arguments `aᵢ` and
makes a new function application with the results. -/
def Expr.traverseApp {M} [Monad M]
(f : Expr → M Expr) (e : Expr) : M Expr :=
e.withApp fun fn args => (pure mkAppN) <*> (f fn) <*> (args.mapM f)
namespace Core
/--
@ -58,10 +65,59 @@ end Core
namespace Meta
variable {M} [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M] [MonadOptions M]
def usedLetOnly : M Bool := getBoolOption `visit.usedLetOnly false
/-- Given an expression `fun (x₁ : α₁) ... (xₙ : αₙ) => b`, will run
`f` on each of the variable types `αᵢ` and `b` with the correct MetaM context,
replacing each expression with the output of `f` and creating a new lambda.
(that is, correctly instantiating bound variables and repackaging them after) -/
def traverseLambda
(f : Expr → M Expr) (e : Expr) : M Expr := visit #[] e
where visit (fvars : Array Expr) : Expr → M Expr
| (Expr.lam n d b c) => do withLocalDecl n c.binderInfo (← f (d.instantiateRev fvars)) fun x => visit (fvars.push x) b
| e => do mkLambdaFVars (usedLetOnly := ← usedLetOnly) fvars (← f (e.instantiateRev fvars))
/-- Given an expression ` (x₁ : α₁) → ... → (xₙ : αₙ) → b`, will run
`f` on each of the variable types `αᵢ` and `b` with the correct MetaM context,
replacing the expression with the output of `f` and creating a new forall expression.
(that is, correctly instantiating bound variables and repackaging them after) -/
def traverseForall
(f : Expr → M Expr) (e : Expr) : M Expr := visit #[] e
where visit fvars : Expr → M Expr
| (Expr.forallE n d b c) => do withLocalDecl n c.binderInfo (← f (d.instantiateRev fvars)) fun x => visit (fvars.push x) b
| e => do mkForallFVars (usedLetOnly := ←usedLetOnly) fvars (← f (e.instantiateRev fvars))
/-- Similar to traverseLambda and traverseForall but with let binders. -/
def traverseLet
(f : Expr → M Expr) (e : Expr) : M Expr := visit #[] e
where visit fvars
| Expr.letE n t v b _ => do
withLetDecl n (← f (t.instantiateRev fvars)) (← f (v.instantiateRev fvars)) fun x =>
visit (fvars.push x) b
| e => do mkLetFVars (usedLetOnly := ←usedLetOnly) fvars (← f (e.instantiateRev fvars))
/-- Maps `f` on each child of the given expression.
Applications, foralls, lambdas and let binders are bundled (as they are bundled in `Expr.traverseApp`, `traverseForall`, ...).
So `traverseChildren f e` where ``e = `(fn a₁ ... aₙ)`` will return
``(← f `(fn)) (← f `(a₁)) ... (← f `(aₙ))`` rather than ``(← f `(fn a₁ ... aₙ₋₁)) (← f `(aₙ))``
-/
def traverseChildren (f : Expr → M Expr) (e: Expr) : M Expr := do
match e with
| Expr.forallE .. => traverseForall f e
| Expr.lam .. => traverseLambda f e
| Expr.letE .. => traverseLet f e
| Expr.app .. => Expr.traverseApp f e
| Expr.mdata _ b _ => return e.updateMData! (← f b)
| Expr.proj _ _ b _ => return e.updateProj! (← f b)
| _ => return e
/--
Similar to `Core.transform`, but terms provided to `pre` and `post` do not contain loose bound variables.
So, it is safe to use any `MetaM` method at `pre` and `post`. -/
partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m] [MonadTrace m] [MonadRef m] [MonadOptions m] [AddMessageContext m]
partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m] [MonadTrace m] [MonadRef m] [MonadOptions m] [MonadWithOptions m] [AddMessageContext m]
(input : Expr)
(pre : Expr → m TransformStep := fun e => return TransformStep.visit e)
(post : Expr → m TransformStep := fun e => return TransformStep.done e)
@ -71,42 +127,15 @@ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
let _ : MonadLiftT (ST IO.RealWorld) m := { monadLift := fun x => liftM (m := MetaM) (liftM (m := ST IO.RealWorld) x) }
let rec visit (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
checkCache { val := e : ExprStructEq } fun _ => Meta.withIncRecDepth do
let rec visitPost (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match (← post e) with
| TransformStep.done e => pure e
| TransformStep.visit e => visit e
let rec visitLambda (fvars : Array Expr) (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match e with
| Expr.lam n d b c =>
withLocalDecl n c.binderInfo (← visit (d.instantiateRev fvars)) fun x =>
visitLambda (fvars.push x) b
| e => visitPost (← mkLambdaFVars (usedLetOnly := usedLetOnly) fvars (← visit (e.instantiateRev fvars)))
let rec visitForall (fvars : Array Expr) (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match e with
| Expr.forallE n d b c =>
withLocalDecl n c.binderInfo (← visit (d.instantiateRev fvars)) fun x =>
visitForall (fvars.push x) b
| e => visitPost (← mkForallFVars (usedLetOnly := usedLetOnly) fvars (← visit (e.instantiateRev fvars)))
let rec visitLet (fvars : Array Expr) (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := do
match e with
| Expr.letE n t v b _ =>
withLetDecl n (← visit (t.instantiateRev fvars)) (← visit (v.instantiateRev fvars)) fun x =>
visitLet (fvars.push x) b
| e => visitPost (← mkLetFVars (usedLetOnly := usedLetOnly) fvars (← visit (e.instantiateRev fvars)))
let visitApp (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
e.withApp fun f args => do
visitPost (mkAppN (← visit f) (← args.mapM visit))
match (← pre e) with
| TransformStep.done e => pure e
| TransformStep.visit e => match e with
| Expr.forallE .. => visitForall #[] e
| Expr.lam .. => visitLambda #[] e
| Expr.letE .. => visitLet #[] e
| Expr.app .. => visitApp e
| Expr.mdata _ b _ => visitPost (e.updateMData! (← visit b))
| Expr.proj _ _ b _ => visitPost (e.updateProj! (← visit b))
| _ => visitPost e
visit input |>.run
| TransformStep.visit e =>
let e ← traverseChildren visit e
match (← post e) with
| TransformStep.done e => pure e
| TransformStep.visit e => visit e
withOptions (fun o => o.setBool `visit.usedLetOnly usedLetOnly)
(visit input |>.run)
def zetaReduce (e : Expr) : MetaM Expr := do
let pre (e : Expr) : MetaM TransformStep := do