refactor: move auxiliary let declaration support to CompilerM.lean
This commit is contained in:
parent
578adcd7f0
commit
c16bec6e30
2 changed files with 30 additions and 27 deletions
|
|
@ -10,6 +10,8 @@ namespace Lean.Compiler
|
|||
structure CompilerM.State where
|
||||
lctx : LocalContext
|
||||
letFVars : Array Expr := #[]
|
||||
/-- Next auxiliary variable suffix -/
|
||||
nextIdx : Nat := 1
|
||||
|
||||
abbrev CompilerM := StateRefT CompilerM.State CoreM
|
||||
|
||||
|
|
@ -42,6 +44,15 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (nonDep : Bool) :
|
|||
modify fun s => { s with lctx := s.lctx.mkLetDecl fvarId binderName type value nonDep, letFVars := s.letFVars.push x }
|
||||
return x
|
||||
|
||||
def mkAuxLetDecl (e : Expr) (prefixName := `_x): CompilerM Expr := do
|
||||
if e.isFVar then
|
||||
return e
|
||||
else
|
||||
try
|
||||
mkLetDecl (prefixName.appendIndexAfter (← get).nextIdx) (← inferType e) e (nonDep := false)
|
||||
finally
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
|
||||
def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) :=
|
||||
go e #[]
|
||||
where
|
||||
|
|
@ -67,9 +78,12 @@ where
|
|||
return (fvars, e.instantiateRev fvars)
|
||||
|
||||
def withNewScopeImp (x : CompilerM α) : CompilerM α := do
|
||||
let s ← get
|
||||
let saved ← get
|
||||
modify fun s => { s with letFVars := #[] }
|
||||
try x finally set s
|
||||
try x
|
||||
finally
|
||||
let saved := { saved with nextIdx := (← get).nextIdx }
|
||||
set saved
|
||||
|
||||
def withNewScope [MonadFunctorT CompilerM m] (x : m α) : m α :=
|
||||
monadMap (m := CompilerM) withNewScopeImp x
|
||||
|
|
|
|||
|
|
@ -28,29 +28,18 @@ structure Context where
|
|||
|
||||
structure State where
|
||||
cache : PersistentExprMap Expr := {}
|
||||
/-- Next auxiliary variable suffix -/
|
||||
nextIdx : Nat := 1
|
||||
|
||||
abbrev M := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
def mkFreshLetDecl (e : Expr) : M Expr := do
|
||||
if (← read).root then
|
||||
return e
|
||||
else
|
||||
try
|
||||
mkLetDecl ((`_x).appendIndexAfter (← get).nextIdx) (← inferType e) e (nonDep := false)
|
||||
finally
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
|
||||
@[inline] def withRoot (flag : Bool) (x : M α) : M α :=
|
||||
withReader (fun _ => { root := flag }) x
|
||||
|
||||
def withNewRootScope (x : M α) : M α := do
|
||||
let cacheSaved := (← get).cache
|
||||
let saved ← get
|
||||
try
|
||||
withRoot true <| Compiler.withNewScope x
|
||||
finally
|
||||
modify fun s => { s with cache := cacheSaved }
|
||||
set saved
|
||||
|
||||
/--
|
||||
Eta-expand with `n` lambdas.
|
||||
|
|
@ -79,7 +68,7 @@ where
|
|||
/-- Visit args, and return `f args` -/
|
||||
visitAppDefault (f : Expr) (args : Array Expr) : M Expr := do
|
||||
let args ← args.mapM visitChild
|
||||
mkFreshLetDecl <| mkAppN f args
|
||||
mkAuxLetDecl <| mkAppN f args
|
||||
|
||||
/-- Eta expand if under applied, otherwise apply k -/
|
||||
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Expr) : M Expr := do
|
||||
|
|
@ -113,13 +102,13 @@ where
|
|||
-/
|
||||
mkOverApplication (app : Expr) (args : Array Expr) (arity : Nat) : M Expr := do
|
||||
if args.size == arity then
|
||||
mkFreshLetDecl app
|
||||
mkAuxLetDecl app
|
||||
else
|
||||
let k ← withRoot false <| mkFreshLetDecl app
|
||||
let k ← withRoot false <| mkAuxLetDecl app
|
||||
let mut args := args
|
||||
for i in [arity : args.size] do
|
||||
args ← args.modifyM i visitChild
|
||||
mkFreshLetDecl (mkAppN k args[arity:])
|
||||
mkAuxLetDecl (mkAppN k args[arity:])
|
||||
|
||||
/--
|
||||
Create an application `f args` that is expected to have arity `arity`.
|
||||
|
|
@ -132,7 +121,7 @@ where
|
|||
-/
|
||||
mkAppWithArity (f : Expr) (args : Array Expr) (arity : Nat) : M Expr := do
|
||||
if args.size == arity then
|
||||
mkFreshLetDecl (mkAppN f args)
|
||||
mkAuxLetDecl (mkAppN f args)
|
||||
else
|
||||
mkOverApplication (mkAppN f args[:arity]) args arity
|
||||
|
||||
|
|
@ -169,7 +158,7 @@ where
|
|||
let f ← visitChild args[3]!
|
||||
let q ← visitChild args[5]!
|
||||
let .const _ [u, _] := e.getAppFn | unreachable!
|
||||
let invq ← mkFreshLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q)
|
||||
let invq ← mkAuxLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q)
|
||||
let r := mkApp f invq
|
||||
mkOverApplication r args arity
|
||||
|
||||
|
|
@ -186,7 +175,7 @@ where
|
|||
visitFalseRec (e : Expr) : M Expr :=
|
||||
let arity := 2
|
||||
etaIfUnderApplied e arity do
|
||||
mkFreshLetDecl (← mkLcUnreachable (← inferType e))
|
||||
mkAuxLetDecl (← mkLcUnreachable (← inferType e))
|
||||
|
||||
visitAndRec (e : Expr) : M Expr :=
|
||||
let arity := 5
|
||||
|
|
@ -218,7 +207,7 @@ where
|
|||
let major := mkAppN major args[arity+1:]
|
||||
visit major
|
||||
else
|
||||
mkFreshLetDecl (← mkLcUnreachable (← inferType e))
|
||||
mkAuxLetDecl (← mkLcUnreachable (← inferType e))
|
||||
| _, _ =>
|
||||
throwError "code generator failed, unsupported occurrence of `{declName}`"
|
||||
|
||||
|
|
@ -280,16 +269,16 @@ where
|
|||
let (as, e) ← Compiler.visitLambda e
|
||||
let e ← mkLetUsingScope (← visit e)
|
||||
mkLambda as e
|
||||
mkFreshLetDecl r
|
||||
mkAuxLetDecl r
|
||||
|
||||
visitMData (mdata : MData) (e : Expr) : M Expr := do
|
||||
if isCompilerRelevantMData mdata then
|
||||
mkFreshLetDecl <| .mdata mdata (← visitChild e)
|
||||
mkAuxLetDecl <| .mdata mdata (← visitChild e)
|
||||
else
|
||||
visit e
|
||||
|
||||
visitProj (s : Name) (i : Nat) (e : Expr) : M Expr := do
|
||||
mkFreshLetDecl <| .proj s i (← visitChild e)
|
||||
mkAuxLetDecl <| .proj s i (← visitChild e)
|
||||
|
||||
visit (e : Expr) : M Expr := withIncRecDepth do
|
||||
match e with
|
||||
|
|
@ -315,7 +304,7 @@ where
|
|||
| .mdata d e => visitMData d e
|
||||
| .lam .. => visitLambda e
|
||||
| .letE .. => visit (← visitLet e visitChild)
|
||||
| .lit .. => mkFreshLetDecl e
|
||||
| .lit .. => mkAuxLetDecl e
|
||||
| _ => pure e
|
||||
modify fun s => { s with cache := s.cache.insert e r }
|
||||
return r
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue