feat: store FunDecls at LCNF local context

This commit is contained in:
Leonardo de Moura 2022-08-22 21:46:37 -07:00
parent 82acc2b39c
commit 766afdd0bc
2 changed files with 45 additions and 14 deletions

View file

@ -30,7 +30,7 @@ instance : AddMessageContext CompilerM where
return MessageData.withContext { env, lctx, opts, mctx := {} } msgData
def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
let some decl := (← get).lctx.find? fvarId | throwError "unknown free variable {fvarId.name}"
let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}"
return decl
namespace Internalize
@ -66,15 +66,18 @@ where
@[inline] private def translate (e : Expr) : M Expr :=
return translateCore (← get) e
private def declareFVarId (fvarId : FVarId) (binderName : Name) (type : Expr) : M FVarId := do
let fvarId' ← mkFreshFVarId
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
modify fun s => { s with lctx := f s.lctx }
private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
let fvarId' ← Lean.mkFreshFVarId
modify fun s => { s with fvarIdMap := s.fvarIdMap.insert fvarId fvarId' }
modifyThe CompilerM.State fun s => { s with lctx := s.lctx.insert fvarId (.cdecl 0 fvarId' binderName type .default) }
return fvarId'
private def declareParam (p : Param) : M Param := do
private def addParam (p : Param) : M Param := do
let type ← translate p.type
let fvarId ← declareFVarId p.fvarId p.binderName type
let fvarId ← mkNewFVarId p.fvarId
modifyLCtx fun lctx => lctx.addLocalDecl fvarId p.binderName type
return { p with fvarId, type }
end Internalize
@ -88,17 +91,20 @@ partial def internalize (code : Code) : CompilerM Code :=
where
goFunDecl (decl : FunDecl) : M FunDecl := do
let type ← translate decl.type
let params ← decl.params.mapM declareParam
let params ← decl.params.mapM addParam
let value ← go decl.value
let fvarId ← declareFVarId decl.fvarId decl.binderName type
return { decl with fvarId, params, type, value }
let fvarId ← mkNewFVarId decl.fvarId
let decl := { decl with fvarId, params, type, value }
modifyLCtx fun lctx => lctx.addFunDecl decl
return decl
go (code : Code) : M Code := do
match code with
| .let decl k =>
let type ← translate decl.type
let value ← translate decl.value
let fvarId ← declareFVarId decl.fvarId decl.binderName type
let fvarId ← mkNewFVarId decl.fvarId
modifyLCtx fun lctx => lctx.addLetDecl fvarId decl.binderName type value
let k ← go k
return .let { decl with fvarId, type, value } k
| .fun decl k =>
@ -111,7 +117,7 @@ where
| .cases c =>
let discr ← translateFVarId c.discr
let alts ← c.alts.mapM fun
| .alt ctorName params k => return .alt ctorName (← params.mapM declareParam) (← go k)
| .alt ctorName params k => return .alt ctorName (← params.mapM addParam) (← go k)
| .default k => return .default (← go k)
return .cases { c with discr, alts }

View file

@ -4,18 +4,43 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.LocalContext
import Lean.Compiler.LCNF.Basic
namespace Lean.Compiler.LCNF
/--
LCNF local context.
-/
abbrev LCtx := Std.HashMap FVarId LocalDecl
structure LCtx where
localDecls : Std.HashMap FVarId LocalDecl := {}
funDecls : Std.HashMap FVarId FunDecl := {}
fvarIds : Array FVarId := #[]
deriving Inhabited
def LCtx.addLocalDecl (lctx : LCtx) (fvarId : FVarId) (binderName : Name) (type : Expr) : LCtx :=
{ lctx with
localDecls := lctx.localDecls.insert fvarId (.cdecl 0 fvarId binderName type .default)
fvarIds := lctx.fvarIds.push fvarId }
def LCtx.addLetDecl (lctx : LCtx) (fvarId : FVarId) (binderName : Name) (type : Expr) (value : Expr) : LCtx :=
{ lctx with
localDecls := lctx.localDecls.insert fvarId (.ldecl 0 fvarId binderName type value true)
fvarIds := lctx.fvarIds.push fvarId }
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl) : LCtx :=
{ lctx with
localDecls := lctx.localDecls.insert funDecl.fvarId (.cdecl 0 funDecl.fvarId funDecl.binderName funDecl.type .default)
funDecls := lctx.funDecls.insert funDecl.fvarId funDecl
fvarIds := lctx.fvarIds.push funDecl.fvarId }
/--
Convert a LCNF local context into a regular Lean local context.
-/
def LCtx.toLocalContext (lctx : LCtx) : LocalContext :=
lctx.fold (init := {}) fun lctx _ localDecl => lctx.addDecl localDecl
def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do
let mut result := {}
for fvarId in lctx.fvarIds do
let localDecl := lctx.localDecls.find? fvarId |>.get!
result := result.addDecl localDecl
return result
end Lean.Compiler.LCNF