refactor: move auxiliary let declaration support to CompilerM.lean

This commit is contained in:
Leonardo de Moura 2022-08-07 17:17:20 -07:00
parent 578adcd7f0
commit c16bec6e30
2 changed files with 30 additions and 27 deletions

View file

@ -10,6 +10,8 @@ namespace Lean.Compiler
structure CompilerM.State where
lctx : LocalContext
letFVars : Array Expr := #[]
/-- Next auxiliary variable suffix -/
nextIdx : Nat := 1
abbrev CompilerM := StateRefT CompilerM.State CoreM
@ -42,6 +44,15 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (nonDep : Bool) :
modify fun s => { s with lctx := s.lctx.mkLetDecl fvarId binderName type value nonDep, letFVars := s.letFVars.push x }
return x
def mkAuxLetDecl (e : Expr) (prefixName := `_x): CompilerM Expr := do
if e.isFVar then
return e
else
try
mkLetDecl (prefixName.appendIndexAfter (← get).nextIdx) (← inferType e) e (nonDep := false)
finally
modify fun s => { s with nextIdx := s.nextIdx + 1 }
def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) :=
go e #[]
where
@ -67,9 +78,12 @@ where
return (fvars, e.instantiateRev fvars)
def withNewScopeImp (x : CompilerM α) : CompilerM α := do
let s ← get
let saved ← get
modify fun s => { s with letFVars := #[] }
try x finally set s
try x
finally
let saved := { saved with nextIdx := (← get).nextIdx }
set saved
def withNewScope [MonadFunctorT CompilerM m] (x : m α) : m α :=
monadMap (m := CompilerM) withNewScopeImp x

View file

@ -28,29 +28,18 @@ structure Context where
structure State where
cache : PersistentExprMap Expr := {}
/-- Next auxiliary variable suffix -/
nextIdx : Nat := 1
abbrev M := ReaderT Context $ StateRefT State CompilerM
def mkFreshLetDecl (e : Expr) : M Expr := do
if (← read).root then
return e
else
try
mkLetDecl ((`_x).appendIndexAfter (← get).nextIdx) (← inferType e) e (nonDep := false)
finally
modify fun s => { s with nextIdx := s.nextIdx + 1 }
@[inline] def withRoot (flag : Bool) (x : M α) : M α :=
withReader (fun _ => { root := flag }) x
def withNewRootScope (x : M α) : M α := do
let cacheSaved := (← get).cache
let saved ← get
try
withRoot true <| Compiler.withNewScope x
finally
modify fun s => { s with cache := cacheSaved }
set saved
/--
Eta-expand with `n` lambdas.
@ -79,7 +68,7 @@ where
/-- Visit args, and return `f args` -/
visitAppDefault (f : Expr) (args : Array Expr) : M Expr := do
let args ← args.mapM visitChild
mkFreshLetDecl <| mkAppN f args
mkAuxLetDecl <| mkAppN f args
/-- Eta expand if under applied, otherwise apply k -/
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Expr) : M Expr := do
@ -113,13 +102,13 @@ where
-/
mkOverApplication (app : Expr) (args : Array Expr) (arity : Nat) : M Expr := do
if args.size == arity then
mkFreshLetDecl app
mkAuxLetDecl app
else
let k ← withRoot false <| mkFreshLetDecl app
let k ← withRoot false <| mkAuxLetDecl app
let mut args := args
for i in [arity : args.size] do
args ← args.modifyM i visitChild
mkFreshLetDecl (mkAppN k args[arity:])
mkAuxLetDecl (mkAppN k args[arity:])
/--
Create an application `f args` that is expected to have arity `arity`.
@ -132,7 +121,7 @@ where
-/
mkAppWithArity (f : Expr) (args : Array Expr) (arity : Nat) : M Expr := do
if args.size == arity then
mkFreshLetDecl (mkAppN f args)
mkAuxLetDecl (mkAppN f args)
else
mkOverApplication (mkAppN f args[:arity]) args arity
@ -169,7 +158,7 @@ where
let f ← visitChild args[3]!
let q ← visitChild args[5]!
let .const _ [u, _] := e.getAppFn | unreachable!
let invq ← mkFreshLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q)
let invq ← mkAuxLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q)
let r := mkApp f invq
mkOverApplication r args arity
@ -186,7 +175,7 @@ where
visitFalseRec (e : Expr) : M Expr :=
let arity := 2
etaIfUnderApplied e arity do
mkFreshLetDecl (← mkLcUnreachable (← inferType e))
mkAuxLetDecl (← mkLcUnreachable (← inferType e))
visitAndRec (e : Expr) : M Expr :=
let arity := 5
@ -218,7 +207,7 @@ where
let major := mkAppN major args[arity+1:]
visit major
else
mkFreshLetDecl (← mkLcUnreachable (← inferType e))
mkAuxLetDecl (← mkLcUnreachable (← inferType e))
| _, _ =>
throwError "code generator failed, unsupported occurrence of `{declName}`"
@ -280,16 +269,16 @@ where
let (as, e) ← Compiler.visitLambda e
let e ← mkLetUsingScope (← visit e)
mkLambda as e
mkFreshLetDecl r
mkAuxLetDecl r
visitMData (mdata : MData) (e : Expr) : M Expr := do
if isCompilerRelevantMData mdata then
mkFreshLetDecl <| .mdata mdata (← visitChild e)
mkAuxLetDecl <| .mdata mdata (← visitChild e)
else
visit e
visitProj (s : Name) (i : Nat) (e : Expr) : M Expr := do
mkFreshLetDecl <| .proj s i (← visitChild e)
mkAuxLetDecl <| .proj s i (← visitChild e)
visit (e : Expr) : M Expr := withIncRecDepth do
match e with
@ -315,7 +304,7 @@ where
| .mdata d e => visitMData d e
| .lam .. => visitLambda e
| .letE .. => visit (← visitLet e visitChild)
| .lit .. => mkFreshLetDecl e
| .lit .. => mkAuxLetDecl e
| _ => pure e
modify fun s => { s with cache := s.cache.insert e r }
return r