feat: TerminalCases for new LCNF
This commit is contained in:
parent
073e72181d
commit
623e0e9af9
6 changed files with 35 additions and 44 deletions
|
|
@ -6,14 +6,15 @@ Authors: Leonardo de Moura
|
|||
import Lean.Compiler.InferType
|
||||
import Lean.Compiler.Util
|
||||
|
||||
namespace Lean.Compiler.InferType
|
||||
namespace Lean.Compiler
|
||||
open InferType
|
||||
|
||||
/-!
|
||||
Type checker for LCNF expressions
|
||||
-/
|
||||
|
||||
structure Config where
|
||||
terminalCasesOnly : Bool := false
|
||||
structure Check.Config where
|
||||
terminalCasesOnly : Bool := true
|
||||
|
||||
def lambdaBoundedTelescope (e : Expr) (n : Nat) (k : Array Expr → Expr → InferTypeM α) : InferTypeM α :=
|
||||
go e n #[]
|
||||
|
|
@ -26,7 +27,7 @@ where
|
|||
withLocalDecl n (d.instantiateRev xs) bi fun x => go b i (xs.push x)
|
||||
| _ => throwError "lambda expected"
|
||||
|
||||
partial def check (e : Expr) (cfg : Config := {}) : InferTypeM Unit := do
|
||||
partial def check (e : Expr) (cfg : Check.Config := {}) : InferTypeM Unit := do
|
||||
checkBlock e #[]
|
||||
where
|
||||
checkBlock (e : Expr) (xs : Array Expr) : InferTypeM Unit := do
|
||||
|
|
@ -95,4 +96,4 @@ where
|
|||
unless compatibleTypes expectedType altBodyType do
|
||||
throwError "type mismatch at LCNF `cases` alternative{indentExpr altBody}\nhas type{indentExpr altBodyType}\nbut is expected to have type{indentExpr expectedType}"
|
||||
|
||||
end Lean.Compiler.InferType
|
||||
end Lean.Compiler
|
||||
|
|
@ -61,8 +61,8 @@ def toDecl (declName : Name) : CoreM Decl := do
|
|||
let value ← toLCNF value -- TODO: uncomment
|
||||
return { name := declName, type, value }
|
||||
|
||||
def Decl.check (decl : Decl) : CoreM Unit := do
|
||||
InferType.check decl.value {} { lctx := {} }
|
||||
def Decl.check (decl : Decl) (cfg : Check.Config := {}): CoreM Unit := do
|
||||
Compiler.check decl.value cfg { lctx := {} }
|
||||
let valueType ← InferType.inferType decl.value { lctx := {} }
|
||||
unless compatibleTypes decl.type valueType do
|
||||
throwError "declaration type mismatch at `{decl.name}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr decl.type}"
|
||||
|
|
|
|||
|
|
@ -155,6 +155,9 @@ instance : MonadInferType InferType.InferTypeM where
|
|||
|
||||
export MonadInferType (inferType)
|
||||
|
||||
instance [MonadLift m n] [MonadInferType m] : MonadInferType n where
|
||||
inferType e := liftM (inferType e : m _)
|
||||
|
||||
def getLevel [Monad m] [MonadInferType m] [MonadError m] (type : Expr) : m Level := do
|
||||
match (← inferType type) with
|
||||
| .sort u => return u
|
||||
|
|
|
|||
|
|
@ -28,19 +28,18 @@ where
|
|||
let info ← getConstInfo declName
|
||||
Meta.isProp info.type <||> Meta.isTypeFormerType info.type
|
||||
|
||||
def checkpoint (step : Name) (decls : Array Decl) : CoreM Unit := do
|
||||
def checkpoint (step : Name) (decls : Array Decl) (cfg : Check.Config := {}): CoreM Unit := do
|
||||
trace[Meta.debug] "After {step}"
|
||||
for decl in decls do
|
||||
withOptions (fun opts => opts.setBool `pp.motives.pi false) do
|
||||
trace[Meta.debug] "{decl.name} := {decl.value}"
|
||||
decl.check
|
||||
decl.check cfg
|
||||
|
||||
def compile (declNames : Array Name) : CoreM Unit := do
|
||||
let declNames ← declNames.filterM shouldGenerateCode
|
||||
let decls ← declNames.mapM toDecl
|
||||
checkpoint `init decls
|
||||
-- TODO: uncomment
|
||||
-- let decls ← decls.mapM (·.terminalCases)
|
||||
-- checkpoint `terminalCases decls
|
||||
checkpoint `init decls { terminalCasesOnly := false }
|
||||
let decls ← decls.mapM (·.terminalCases)
|
||||
checkpoint `terminalCases decls
|
||||
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ Authors: Leonardo de Moura
|
|||
import Lean.Meta.Check
|
||||
import Lean.Compiler.Util
|
||||
import Lean.Compiler.Decl
|
||||
|
||||
#exit -- TODO: port file to new LCNF format
|
||||
import Lean.Compiler.CompilerM
|
||||
|
||||
namespace Lean.Compiler
|
||||
|
||||
|
|
@ -29,8 +28,8 @@ partial def visitAlt (e : Expr) (numParams : Nat) : M Expr := do
|
|||
partial def visitCases (casesInfo : CasesInfo) (cases : Expr) : M Expr := do
|
||||
let mut args := cases.getAppArgs
|
||||
if let some jp := (← read).jp? then
|
||||
let .forallE _ _ b _ ← inferType' jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow, see `visitLet`
|
||||
args ← liftMetaM <| updateMotive casesInfo args b
|
||||
let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow, see `visitLet`
|
||||
args := casesInfo.updateResultingType args b
|
||||
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do
|
||||
args ← args.modifyM i (visitAlt · numParams)
|
||||
return mkAppN cases.getAppFn args
|
||||
|
|
@ -47,34 +46,15 @@ partial def visitLet (e : Expr) (fvars : Array Expr) : M Expr := do
|
|||
let type := type.instantiateRev fvars
|
||||
let mut value := value.instantiateRev fvars
|
||||
if let some casesInfo ← isCasesApp? value then
|
||||
let (bodyAbst, safeJp) ← withNewScope do
|
||||
let bodyAbst ← withNewScope do
|
||||
let x ← mkLocalDecl binderName type
|
||||
let body ← visitLet body (fvars.push x)
|
||||
let body ← mkLetUsingScope body
|
||||
let bodyType ← inferType body
|
||||
let bodyAbst := body.abstract #[x]
|
||||
if (bodyType.abstract #[x]).hasLooseBVars then
|
||||
/-
|
||||
We cannot eliminate this nonterminal `cases` because the resulting type of the joinpoint
|
||||
depends on `x`. We have to wait until we perform erasure to do it.
|
||||
-/
|
||||
return (bodyAbst, false)
|
||||
else if !(← liftMetaM <| Meta.isTypeCorrect body) then
|
||||
/-
|
||||
We cannot eliminate this nonterminal `cases` because the joinpoint is not type correct.
|
||||
This can happen because we abstracted `x`.
|
||||
We have to wait until we perform erasure to do it.
|
||||
Remark: we can skip this test if we set `nonDep` properly.
|
||||
-/
|
||||
return (bodyAbst, false)
|
||||
else
|
||||
return (bodyAbst, true)
|
||||
if !safeJp then
|
||||
return .letE binderName type value bodyAbst nonDep
|
||||
else
|
||||
let jp ← mkJpDecl (.lam binderName type bodyAbst .default)
|
||||
withReader (fun _ => { jp? := some jp }) do
|
||||
visitCases casesInfo value
|
||||
return bodyAbst
|
||||
let jp ← mkJpDecl (.lam binderName type bodyAbst .default)
|
||||
withReader (fun _ => { jp? := some jp }) do
|
||||
visitCases casesInfo value
|
||||
else
|
||||
if value.isLambda then
|
||||
value ← visitLambda value
|
||||
|
|
@ -91,10 +71,10 @@ partial def visitLet (e : Expr) (fvars : Array Expr) : M Expr := do
|
|||
else
|
||||
return e
|
||||
| some jp =>
|
||||
let .forallE _ d _ _ ← inferType' jp | unreachable!
|
||||
let .forallE _ d _ _ ← inferType jp | unreachable!
|
||||
if isLcUnreachable e then
|
||||
mkLcUnreachable d
|
||||
else if (← isDefEq (← inferType e) d) then
|
||||
else if compatibleTypes (← inferType e) d then
|
||||
let x ← mkAuxLetDecl e
|
||||
return mkApp jp x
|
||||
else if let some x := isLcCast? e then
|
||||
|
|
@ -110,7 +90,7 @@ end
|
|||
end TerminalCases
|
||||
|
||||
/--
|
||||
(Try to) ensure all `casesOn` and `matcher` applications are terminal.
|
||||
Ensure all `casesOn` and `matcher` applications are terminal.
|
||||
-/
|
||||
def Decl.terminalCases (decl : Decl) : CoreM Decl := do
|
||||
return { decl with value := (← TerminalCases.visitLambda decl.value |>.run {} |>.run' { nextIdx := (← getMaxLetVarIdx decl.value) + 1 }) }
|
||||
|
|
|
|||
|
|
@ -89,6 +89,14 @@ def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
|
|||
def CasesInfo.geNumDiscrs (casesInfo : CasesInfo) : Nat :=
|
||||
casesInfo.discrsRange.stop - casesInfo.discrsRange.start
|
||||
|
||||
def CasesInfo.updateResultingType (casesInfo : CasesInfo) (casesArgs : Array Expr) (typeNew : Expr) : Array Expr :=
|
||||
casesArgs.modify casesInfo.motivePos fun motive => go motive
|
||||
where
|
||||
go (e : Expr) : Expr :=
|
||||
match e with
|
||||
| .lam n b d bi => .lam n b (go d) bi
|
||||
| _ => typeNew
|
||||
|
||||
def isCasesApp? (e : Expr) : CoreM (Option CasesInfo) := do
|
||||
let .const declName _ := e.getAppFn | return none
|
||||
if let some info ← getCasesInfo? declName then
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue