feat: inferType for new LCNF

This commit is contained in:
Leonardo de Moura 2022-08-22 19:51:17 -07:00
parent bd1186536f
commit a2fabc6d49
2 changed files with 177 additions and 0 deletions

View file

@ -29,6 +29,10 @@ instance : AddMessageContext CompilerM where
let opts ← getOptions
return MessageData.withContext { env, lctx, opts, mctx := {} } msgData
def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
let some decl := (← get).lctx.find? fvarId | throwError "unknown free variable {fvarId.name}"
return decl
namespace Internalize
structure State where

View file

@ -0,0 +1,173 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.Types
namespace Lean.Compiler.LCNF
namespace InferType
/--
We use a regular local context to store temporary local declarations
created during type inference.
-/
abbrev InferTypeM := ReaderT LocalContext CompilerM
def getLocalDecl (fvarId : FVarId) : InferTypeM LocalDecl := do
match (← read).find? fvarId with
| some localDecl => return localDecl
| none => LCNF.getLocalDecl fvarId
def mkForallFVars (xs : Array Expr) (b : Expr) : InferTypeM Expr :=
let b := b.abstract xs
xs.size.foldRevM (init := b) fun i b => do
let x := xs[i]!
let .cdecl _ _ n ty _ ← getLocalDecl x.fvarId! | unreachable!
let ty := ty.abstractRange i xs;
return .forallE n ty b .default
@[inline] def withLocalDecl (binderName : Name) (type : Expr) (binderInfo : BinderInfo) (k : Expr → InferTypeM α) : InferTypeM α := do
let fvarId ← mkFreshFVarId
withReader (fun lctx => lctx.mkLocalDecl fvarId binderName type binderInfo) do
k (.fvar fvarId)
def inferFVarType (fvarId : FVarId) : InferTypeM Expr :=
return (← getLocalDecl fvarId).type
def inferConstType (declName : Name) (us : List Level) : CoreM Expr :=
if declName == ``lcAny || declName == ``lcErased then
return anyTypeExpr
else
instantiateLCNFTypeLevelParams declName us
mutual
partial def inferType (e : Expr) : InferTypeM Expr :=
match e with
| .const c us => inferConstType c us
| .proj n i s => inferProjType n i s
| .app .. => inferAppType e
| .mvar .. => throwError "unexpected metavariable {e}"
| .fvar fvarId => inferFVarType fvarId
| .bvar .. => throwError "unexpected bound variable {e}"
| .mdata _ e => inferType e
| .lit v => return v.type
| .sort lvl => return .sort (mkLevelSucc lvl)
| .forallE .. => inferForallType e
| .lam .. => inferLambdaType e
| .letE .. => inferLambdaType e
partial def inferAppTypeCore (f : Expr) (args : Array Expr) : InferTypeM Expr := do
let mut j := 0
let mut fType ← inferType f
for i in [:args.size] do
fType := fType.headBeta
match fType with
| .forallE _ _ b _ => fType := b
| _ =>
fType := fType.instantiateRevRange j i args |>.headBeta
match fType with
| .forallE _ _ b _ => j := i; fType := b
| _ =>
if fType.isAnyType then return anyTypeExpr
throwError "function expected{indentExpr (mkAppN f args[:i])} : {fType}\nfunction type{indentExpr (← inferType f)}"
return fType.instantiateRevRange j args.size args |>.headBeta
partial def inferAppType (e : Expr) : InferTypeM Expr := do
inferAppTypeCore e.getAppFn e.getAppArgs
partial def inferProjType (structName : Name) (idx : Nat) (s : Expr) : InferTypeM Expr := do
let failed {α} : Unit → InferTypeM α := fun _ =>
throwError "invalid projection{indentExpr (mkProj structName idx s)}"
let structType ← inferType s
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
let n := structVal.numParams
let structParams := structType.getAppArgs
if n != structParams.size then
failed ()
else do
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
for _ in [:idx] do
match ctorType with
| .forallE _ _ body _ =>
assert! !body.hasLooseBVars
ctorType := body
| _ =>
if ctorType.isAnyType then return anyTypeExpr
failed ()
match ctorType with
| .forallE _ d _ _ => return d
| _ =>
if ctorType.isAnyType then return anyTypeExpr
failed ()
partial def getLevel? (type : Expr) : InferTypeM (Option Level) := do
match (← inferType type) with
| .sort u => return some u
| e =>
if e.isAnyType then
return none
else
throwError "type expected{indentExpr type}"
partial def inferForallType (e : Expr) : InferTypeM Expr :=
go e #[]
where
go (e : Expr) (fvars : Array Expr) : InferTypeM Expr := do
match e with
| .forallE n d b bi =>
withLocalDecl n (d.instantiateRev fvars) bi fun fvar =>
go b (fvars.push fvar)
| _ =>
let e := e.instantiateRev fvars
let some u ← getLevel? e | return anyTypeExpr
let mut u := u
for x in fvars do
let xType ← inferType x
let some v ← getLevel? xType | return anyTypeExpr
u := .imax v u
return .sort u.normalize
partial def inferLambdaType (e : Expr) : InferTypeM Expr :=
go e #[] #[]
where
go (e : Expr) (fvars : Array Expr) (all : Array Expr) : InferTypeM Expr := do
match e with
| .lam n d b bi =>
withLocalDecl n (d.instantiateRev all) bi fun fvar => go b (fvars.push fvar) (all.push fvar)
| .letE n t _ b _ =>
withLocalDecl n (t.instantiateRev all) .default fun fvar => go b fvars (all.push fvar)
| e =>
let type ← inferType (e.instantiateRev all)
mkForallFVars fvars type
end
end InferType
def inferType (e : Expr) : CompilerM Expr :=
InferType.inferType e |>.run {}
def getLevel (type : Expr) : CompilerM Level := do
match (← inferType type) with
| .sort u => return u
| e => if e.isAnyType then return levelOne else throwError "type expected{indentExpr type}"
/-- Create `lcCast expectedType e : expectedType` -/
def mkLcCast (e : Expr) (expectedType : Expr) : CompilerM Expr := do
let type ← inferType e
let u ← getLevel type
let v ← getLevel expectedType
return mkApp3 (.const ``lcCast [u, v]) type expectedType e
def Code.inferType (code : Code) : CompilerM Expr := do
match code with
| .let _ k | .fun _ k | .jp _ k => k.inferType
| .return fvarId => return (← getLocalDecl fvarId).type
| .jmp fvarId args => InferType.inferAppTypeCore (.fvar fvarId) args |>.run {}
| .unreach type => return type
| .cases c => return c.resultType
end Lean.Compiler.LCNF