feat: TerminalCases for new LCNF

This commit is contained in:
Leonardo de Moura 2022-08-11 18:29:26 -07:00
parent 073e72181d
commit 623e0e9af9
6 changed files with 35 additions and 44 deletions

View file

@ -6,14 +6,15 @@ Authors: Leonardo de Moura
import Lean.Compiler.InferType
import Lean.Compiler.Util
namespace Lean.Compiler.InferType
namespace Lean.Compiler
open InferType
/-!
Type checker for LCNF expressions
-/
structure Config where
terminalCasesOnly : Bool := false
structure Check.Config where
terminalCasesOnly : Bool := true
def lambdaBoundedTelescope (e : Expr) (n : Nat) (k : Array Expr → Expr → InferTypeM α) : InferTypeM α :=
go e n #[]
@ -26,7 +27,7 @@ where
withLocalDecl n (d.instantiateRev xs) bi fun x => go b i (xs.push x)
| _ => throwError "lambda expected"
partial def check (e : Expr) (cfg : Config := {}) : InferTypeM Unit := do
partial def check (e : Expr) (cfg : Check.Config := {}) : InferTypeM Unit := do
checkBlock e #[]
where
checkBlock (e : Expr) (xs : Array Expr) : InferTypeM Unit := do
@ -95,4 +96,4 @@ where
unless compatibleTypes expectedType altBodyType do
throwError "type mismatch at LCNF `cases` alternative{indentExpr altBody}\nhas type{indentExpr altBodyType}\nbut is expected to have type{indentExpr expectedType}"
end Lean.Compiler.InferType
end Lean.Compiler

View file

@ -61,8 +61,8 @@ def toDecl (declName : Name) : CoreM Decl := do
let value ← toLCNF value -- TODO: uncomment
return { name := declName, type, value }
def Decl.check (decl : Decl) : CoreM Unit := do
InferType.check decl.value {} { lctx := {} }
def Decl.check (decl : Decl) (cfg : Check.Config := {}): CoreM Unit := do
Compiler.check decl.value cfg { lctx := {} }
let valueType ← InferType.inferType decl.value { lctx := {} }
unless compatibleTypes decl.type valueType do
throwError "declaration type mismatch at `{decl.name}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr decl.type}"

View file

@ -155,6 +155,9 @@ instance : MonadInferType InferType.InferTypeM where
export MonadInferType (inferType)
instance [MonadLift m n] [MonadInferType m] : MonadInferType n where
inferType e := liftM (inferType e : m _)
def getLevel [Monad m] [MonadInferType m] [MonadError m] (type : Expr) : m Level := do
match (← inferType type) with
| .sort u => return u

View file

@ -28,19 +28,18 @@ where
let info ← getConstInfo declName
Meta.isProp info.type <||> Meta.isTypeFormerType info.type
def checkpoint (step : Name) (decls : Array Decl) : CoreM Unit := do
def checkpoint (step : Name) (decls : Array Decl) (cfg : Check.Config := {}): CoreM Unit := do
trace[Meta.debug] "After {step}"
for decl in decls do
withOptions (fun opts => opts.setBool `pp.motives.pi false) do
trace[Meta.debug] "{decl.name} := {decl.value}"
decl.check
decl.check cfg
def compile (declNames : Array Name) : CoreM Unit := do
let declNames ← declNames.filterM shouldGenerateCode
let decls ← declNames.mapM toDecl
checkpoint `init decls
-- TODO: uncomment
-- let decls ← decls.mapM (·.terminalCases)
-- checkpoint `terminalCases decls
checkpoint `init decls { terminalCasesOnly := false }
let decls ← decls.mapM (·.terminalCases)
checkpoint `terminalCases decls
end Lean.Compiler

View file

@ -6,8 +6,7 @@ Authors: Leonardo de Moura
import Lean.Meta.Check
import Lean.Compiler.Util
import Lean.Compiler.Decl
#exit -- TODO: port file to new LCNF format
import Lean.Compiler.CompilerM
namespace Lean.Compiler
@ -29,8 +28,8 @@ partial def visitAlt (e : Expr) (numParams : Nat) : M Expr := do
partial def visitCases (casesInfo : CasesInfo) (cases : Expr) : M Expr := do
let mut args := cases.getAppArgs
if let some jp := (← read).jp? then
let .forallE _ _ b _ ← inferType' jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow, see `visitLet`
args ← liftMetaM <| updateMotive casesInfo args b
let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow, see `visitLet`
args := casesInfo.updateResultingType args b
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do
args ← args.modifyM i (visitAlt · numParams)
return mkAppN cases.getAppFn args
@ -47,34 +46,15 @@ partial def visitLet (e : Expr) (fvars : Array Expr) : M Expr := do
let type := type.instantiateRev fvars
let mut value := value.instantiateRev fvars
if let some casesInfo ← isCasesApp? value then
let (bodyAbst, safeJp) ← withNewScope do
let bodyAbst ← withNewScope do
let x ← mkLocalDecl binderName type
let body ← visitLet body (fvars.push x)
let body ← mkLetUsingScope body
let bodyType ← inferType body
let bodyAbst := body.abstract #[x]
if (bodyType.abstract #[x]).hasLooseBVars then
/-
We cannot eliminate this nonterminal `cases` because the resulting type of the joinpoint
depends on `x`. We have to wait until we perform erasure to do it.
-/
return (bodyAbst, false)
else if !(← liftMetaM <| Meta.isTypeCorrect body) then
/-
We cannot eliminate this nonterminal `cases` because the joinpoint is not type correct.
This can happen because we abstracted `x`.
We have to wait until we perform erasure to do it.
Remark: we can skip this test if we set `nonDep` properly.
-/
return (bodyAbst, false)
else
return (bodyAbst, true)
if !safeJp then
return .letE binderName type value bodyAbst nonDep
else
let jp ← mkJpDecl (.lam binderName type bodyAbst .default)
withReader (fun _ => { jp? := some jp }) do
visitCases casesInfo value
return bodyAbst
let jp ← mkJpDecl (.lam binderName type bodyAbst .default)
withReader (fun _ => { jp? := some jp }) do
visitCases casesInfo value
else
if value.isLambda then
value ← visitLambda value
@ -91,10 +71,10 @@ partial def visitLet (e : Expr) (fvars : Array Expr) : M Expr := do
else
return e
| some jp =>
let .forallE _ d _ _ ← inferType' jp | unreachable!
let .forallE _ d _ _ ← inferType jp | unreachable!
if isLcUnreachable e then
mkLcUnreachable d
else if (← isDefEq (← inferType e) d) then
else if compatibleTypes (← inferType e) d then
let x ← mkAuxLetDecl e
return mkApp jp x
else if let some x := isLcCast? e then
@ -110,7 +90,7 @@ end
end TerminalCases
/--
(Try to) ensure all `casesOn` and `matcher` applications are terminal.
Ensure all `casesOn` and `matcher` applications are terminal.
-/
def Decl.terminalCases (decl : Decl) : CoreM Decl := do
return { decl with value := (← TerminalCases.visitLambda decl.value |>.run {} |>.run' { nextIdx := (← getMaxLetVarIdx decl.value) + 1 }) }

View file

@ -89,6 +89,14 @@ def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
def CasesInfo.geNumDiscrs (casesInfo : CasesInfo) : Nat :=
casesInfo.discrsRange.stop - casesInfo.discrsRange.start
def CasesInfo.updateResultingType (casesInfo : CasesInfo) (casesArgs : Array Expr) (typeNew : Expr) : Array Expr :=
casesArgs.modify casesInfo.motivePos fun motive => go motive
where
go (e : Expr) : Expr :=
match e with
| .lam n b d bi => .lam n b (go d) bi
| _ => typeNew
def isCasesApp? (e : Expr) : CoreM (Option CasesInfo) := do
let .const declName _ := e.getAppFn | return none
if let some info ← getCasesInfo? declName then