From c16bec6e30e556df44ac07ff9710cd43de7aae70 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 7 Aug 2022 17:17:20 -0700 Subject: [PATCH] refactor: move auxiliary let declaration support to `CompilerM.lean` --- src/Lean/Compiler/CompilerM.lean | 18 +++++++++++++-- src/Lean/Compiler/LCNF.lean | 39 ++++++++++++-------------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/Lean/Compiler/CompilerM.lean b/src/Lean/Compiler/CompilerM.lean index 708fb8bd50..d9057d08c0 100644 --- a/src/Lean/Compiler/CompilerM.lean +++ b/src/Lean/Compiler/CompilerM.lean @@ -10,6 +10,8 @@ namespace Lean.Compiler structure CompilerM.State where lctx : LocalContext letFVars : Array Expr := #[] + /-- Next auxiliary variable suffix -/ + nextIdx : Nat := 1 abbrev CompilerM := StateRefT CompilerM.State CoreM @@ -42,6 +44,15 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (nonDep : Bool) : modify fun s => { s with lctx := s.lctx.mkLetDecl fvarId binderName type value nonDep, letFVars := s.letFVars.push x } return x +def mkAuxLetDecl (e : Expr) (prefixName := `_x): CompilerM Expr := do + if e.isFVar then + return e + else + try + mkLetDecl (prefixName.appendIndexAfter (← get).nextIdx) (← inferType e) e (nonDep := false) + finally + modify fun s => { s with nextIdx := s.nextIdx + 1 } + def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) := go e #[] where @@ -67,9 +78,12 @@ where return (fvars, e.instantiateRev fvars) def withNewScopeImp (x : CompilerM α) : CompilerM α := do - let s ← get + let saved ← get modify fun s => { s with letFVars := #[] } - try x finally set s + try x + finally + let saved := { saved with nextIdx := (← get).nextIdx } + set saved def withNewScope [MonadFunctorT CompilerM m] (x : m α) : m α := monadMap (m := CompilerM) withNewScopeImp x diff --git a/src/Lean/Compiler/LCNF.lean b/src/Lean/Compiler/LCNF.lean index 045802e8ef..db5d9d7d70 100644 --- a/src/Lean/Compiler/LCNF.lean +++ b/src/Lean/Compiler/LCNF.lean @@ -28,29 +28,18 @@ structure Context where structure State where cache : PersistentExprMap Expr := {} - /-- Next auxiliary variable suffix -/ - nextIdx : Nat := 1 abbrev M := ReaderT Context $ StateRefT State CompilerM -def mkFreshLetDecl (e : Expr) : M Expr := do - if (← read).root then - return e - else - try - mkLetDecl ((`_x).appendIndexAfter (← get).nextIdx) (← inferType e) e (nonDep := false) - finally - modify fun s => { s with nextIdx := s.nextIdx + 1 } - @[inline] def withRoot (flag : Bool) (x : M α) : M α := withReader (fun _ => { root := flag }) x def withNewRootScope (x : M α) : M α := do - let cacheSaved := (← get).cache + let saved ← get try withRoot true <| Compiler.withNewScope x finally - modify fun s => { s with cache := cacheSaved } + set saved /-- Eta-expand with `n` lambdas. @@ -79,7 +68,7 @@ where /-- Visit args, and return `f args` -/ visitAppDefault (f : Expr) (args : Array Expr) : M Expr := do let args ← args.mapM visitChild - mkFreshLetDecl <| mkAppN f args + mkAuxLetDecl <| mkAppN f args /-- Eta expand if under applied, otherwise apply k -/ etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Expr) : M Expr := do @@ -113,13 +102,13 @@ where -/ mkOverApplication (app : Expr) (args : Array Expr) (arity : Nat) : M Expr := do if args.size == arity then - mkFreshLetDecl app + mkAuxLetDecl app else - let k ← withRoot false <| mkFreshLetDecl app + let k ← withRoot false <| mkAuxLetDecl app let mut args := args for i in [arity : args.size] do args ← args.modifyM i visitChild - mkFreshLetDecl (mkAppN k args[arity:]) + mkAuxLetDecl (mkAppN k args[arity:]) /-- Create an application `f args` that is expected to have arity `arity`. @@ -132,7 +121,7 @@ where -/ mkAppWithArity (f : Expr) (args : Array Expr) (arity : Nat) : M Expr := do if args.size == arity then - mkFreshLetDecl (mkAppN f args) + mkAuxLetDecl (mkAppN f args) else mkOverApplication (mkAppN f args[:arity]) args arity @@ -169,7 +158,7 @@ where let f ← visitChild args[3]! let q ← visitChild args[5]! let .const _ [u, _] := e.getAppFn | unreachable! - let invq ← mkFreshLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q) + let invq ← mkAuxLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q) let r := mkApp f invq mkOverApplication r args arity @@ -186,7 +175,7 @@ where visitFalseRec (e : Expr) : M Expr := let arity := 2 etaIfUnderApplied e arity do - mkFreshLetDecl (← mkLcUnreachable (← inferType e)) + mkAuxLetDecl (← mkLcUnreachable (← inferType e)) visitAndRec (e : Expr) : M Expr := let arity := 5 @@ -218,7 +207,7 @@ where let major := mkAppN major args[arity+1:] visit major else - mkFreshLetDecl (← mkLcUnreachable (← inferType e)) + mkAuxLetDecl (← mkLcUnreachable (← inferType e)) | _, _ => throwError "code generator failed, unsupported occurrence of `{declName}`" @@ -280,16 +269,16 @@ where let (as, e) ← Compiler.visitLambda e let e ← mkLetUsingScope (← visit e) mkLambda as e - mkFreshLetDecl r + mkAuxLetDecl r visitMData (mdata : MData) (e : Expr) : M Expr := do if isCompilerRelevantMData mdata then - mkFreshLetDecl <| .mdata mdata (← visitChild e) + mkAuxLetDecl <| .mdata mdata (← visitChild e) else visit e visitProj (s : Name) (i : Nat) (e : Expr) : M Expr := do - mkFreshLetDecl <| .proj s i (← visitChild e) + mkAuxLetDecl <| .proj s i (← visitChild e) visit (e : Expr) : M Expr := withIncRecDepth do match e with @@ -315,7 +304,7 @@ where | .mdata d e => visitMData d e | .lam .. => visitLambda e | .letE .. => visit (← visitLet e visitChild) - | .lit .. => mkFreshLetDecl e + | .lit .. => mkAuxLetDecl e | _ => pure e modify fun s => { s with cache := s.cache.insert e r } return r