refactor: add Lean.Meta directory for MetaM monad

This commit is contained in:
Leonardo de Moura 2019-11-10 08:04:00 -08:00
parent bff5e4ed37
commit 41fccd976b
4 changed files with 234 additions and 197 deletions

View file

@ -9,9 +9,7 @@ import Init.Lean.NameGenerator
import Init.Lean.Environment
import Init.Lean.LOption
import Init.Lean.Trace
import Init.Lean.AuxRecursor
import Init.Lean.Class
import Init.Lean.WHNF
import Init.Lean.ReducibilityAttrs
/-
@ -113,21 +111,21 @@ abbrev MetaM := ReaderT Context (EStateM Exception State)
instance MetaM.inhabited {α} : Inhabited (MetaM α) :=
⟨fun c s => EStateM.Result.error (arbitrary _) s⟩
@[inline] private def getLCtx : MetaM LocalContext :=
@[inline] def getLCtx : MetaM LocalContext :=
do ctx ← read; pure ctx.lctx
@[inline] private def getMCtx : MetaM MetavarContext :=
@[inline] def getMCtx : MetaM MetavarContext :=
do s ← get; pure s.mctx
@[inline] private def getEnv : MetaM Environment :=
@[inline] def getEnv : MetaM Environment :=
do s ← get; pure s.env
@[inline] private def throwEx {α} (f : ExceptionContext → Exception) : MetaM α :=
@[inline] def throwEx {α} (f : ExceptionContext → Exception) : MetaM α :=
do ctx ← read;
s ← get;
throw (f {env := s.env, mctx := s.mctx, lctx := ctx.lctx })
@[inline] private def throwBug {α} (b : Bug) : MetaM α :=
@[inline] def throwBug {α} (b : Bug) : MetaM α :=
throwEx $ Exception.bug b
/-- Execute `x` only in debugging mode. -/
@ -135,7 +133,7 @@ throwEx $ Exception.bug b
do ctx ← read;
when ctx.config.debug (do x; pure ())
private def mkFreshId : MetaM Name :=
def mkFreshId : MetaM Name :=
do s ← get;
let id := s.ngen.curr;
modify $ fun s => { ngen := s.ngen.next, .. s };
@ -147,7 +145,7 @@ do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.all
@[inline] private def reduceReducibleOnly? : MetaM Bool :=
do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.reducible
@[inline] private def getTransparency : MetaM TransparencyMode :=
@[inline] def getTransparency : MetaM TransparencyMode :=
do ctx ← read; pure $ ctx.config.transparency
@[inline] private def getOptions : MetaM Options :=
@ -164,7 +162,7 @@ adaptReader
(fun (ctx : Context) => { config := { transparency := TransparencyMode.reducible, .. ctx.config }, .. ctx })
x
private def isReadOnlyOrSyntheticMVar (mvarId : Name) : MetaM Bool :=
def isReadOnlyOrSyntheticMVar (mvarId : Name) : MetaM Bool :=
do mctx ← getMCtx;
match mctx.findDecl mvarId with
| some d => pure $ d.synthetic || d.depth != mctx.depth
@ -174,10 +172,10 @@ do mctx ← getMCtx;
do mctx ← getMCtx;
pure $ mctx.isExprAssigned mvarId
@[inline] private def getMVarAssignment (mvarId : Name) : MetaM (Option Expr) :=
@[inline] def getExprMVarAssignment (mvarId : Name) : MetaM (Option Expr) :=
do mctx ← getMCtx; pure (mctx.getExprAssignment mvarId)
private def assignExprMVar (mvarId : Name) (val : Expr) : MetaM Unit :=
def assignExprMVar (mvarId : Name) (val : Expr) : MetaM Unit :=
do whenDebugging $ whenM (isExprAssigned mvarId) $ throwBug $ Bug.overwritingExprMVar mvarId;
modify $ fun s => { mctx := s.mctx.assignExpr mvarId val, .. s }
@ -189,7 +187,7 @@ instance tracer : SimpleMonadTracerAdapter MetaM :=
getTraceState := getTraceState,
modifyTraceState := fun f => modify $ fun s => { traceState := f s.traceState, .. s } }
private def getConst (constName : Name) : MetaM (Option ConstantInfo) :=
def getConst (constName : Name) : MetaM (Option ConstantInfo) :=
do env ← getEnv;
match env.find constName with
| some (info@(ConstantInfo.thmInfo _)) =>
@ -201,24 +199,12 @@ do env ← getEnv;
| none =>
throwEx $ Exception.unknownConst constName
private def isAuxDef? (constName : Name) : MetaM Bool :=
do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName)
private def getLocalDecl (fvarId : Name) : MetaM LocalDecl :=
def getLocalDecl (fvarId : Name) : MetaM LocalDecl :=
do lctx ← getLCtx;
match lctx.find fvarId with
| some d => pure d
| none => throwEx $ Exception.unknownFVar fvarId
@[inline] private def getCachedWHNF (e : Expr) : MetaM (Option Expr) :=
do t ← getTransparency;
s ← get;
pure $ s.cache.whnf.find (t, e)
@[inline] private def cacheWHNF (e : Expr) (r : Expr) : MetaM Unit :=
do t ← getTransparency;
modify $ fun s => { cache := { whnf := s.cache.whnf.insert (t, e) r, .. s.cache }, .. s }
def instantiateMVars (e : Expr) : MetaM Expr :=
if e.hasMVar then
modifyGet $ fun s =>
@ -227,21 +213,6 @@ if e.hasMVar then
else
pure e
@[specialize] private partial def whnfAux
(inferType : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
: Expr → MetaM Expr
| e => whnfEasyCases getLocalDecl getMVarAssignment e $ fun e => do
cached? ← getCachedWHNF e;
match cached? with
| some r => pure r
| none => do
e ← whnfCore getConst isAuxDef? whnfAux inferType isDefEq getLocalDecl getMVarAssignment e;
r ← unfoldDefinition getConst isAuxDef? whnfAux inferType isDefEq synthesizePending getLocalDecl getMVarAssignment e (fun _ => pure e) whnfAux;
cacheWHNF e r;
pure r
@[inline] private def liftMkBindingM {α} (x : MetavarContext.MkBindingM α) : MetaM α :=
fun ctx s =>
match x ctx.lctx { mctx := s.mctx, ngen := s.ngen } with
@ -263,7 +234,7 @@ def mkLambda (xs : Array Expr) (e : Expr) : MetaM Expr :=
liftMkBindingM $ MetavarContext.mkLambda xs e
/-- Save cache, execute `x`, restore cache -/
@[inline] private def savingCache {α} (x : MetaM α) : MetaM α :=
@[inline] def savingCache {α} (x : MetaM α) : MetaM α :=
do s ← get;
let savedCache := s.cache;
finally x (modify $ fun s => { cache := savedCache, .. s })
@ -290,7 +261,7 @@ private partial def isClassQuick : Expr → MetaM (LOption Name)
| Expr.mdata _ e => isClassQuick e
| Expr.const n _ => isClassQuickConst n
| Expr.mvar mvarId => do
val? ← getMVarAssignment mvarId;
val? ← getExprMVarAssignment mvarId;
match val? with
| some val => isClassQuick val
| none => pure LOption.none
@ -463,166 +434,13 @@ do c? ← isClassQuick type;
k fvars e
/-- Similar to `forallTelescope` but for lambda and let expressions. -/
@[specialize] private def lambdaTelescope {α}
@[specialize] def lambdaTelescope {α}
(whnf : Expr → MetaM Expr)
(e : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
savingCache $ do
lctx ← getLCtx;
lambdaTelescopeAux whnf k lctx #[] 0 e
@[specialize] private def getForallResultType
(whnf : Expr → MetaM Expr)
(fType : Expr) (args : Array Expr) : MetaM Expr :=
do (j, fType) ← args.size.foldM
(fun i (acc : Nat × Expr) =>
let (j, type) := acc;
match type with
| Expr.forallE _ _ _ b => pure (j, b)
| _ => do
type ← whnf $ type.instantiateRevRange j i args;
match type with
| Expr.forallE _ _ _ b => pure (i, b)
| _ => throwEx $ Exception.functionExpected fType args)
(0, fType);
pure $ fType.instantiateRevRange j args.size args
@[specialize] private def inferAppType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(f : Expr) (args : Array Expr) : MetaM Expr :=
do fType ← inferType f;
getForallResultType whnf fType args
private def inferConstType (c : Name) (lvls : List Level) : MetaM Expr :=
do env ← getEnv;
match env.find c with
| some cinfo =>
if cinfo.lparams.length == lvls.length then
throwEx $ Exception.incorrectNumOfLevels c lvls
else
pure $ cinfo.instantiateTypeLevelParams lvls
| none =>
throwEx $ Exception.unknownConst c
@[specialize] private def inferProjType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(structName : Name) (idx : Nat) (e : Expr) : MetaM Expr :=
do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProjection structName idx e;
structType ← inferType e;
structType ← whnf structType;
env ← getEnv;
matchConst env structType.getAppFn failed $ fun structInfo structLvls => do
match structInfo with
| ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } =>
let structParams := structType.getAppArgs;
if n != structParams.size then failed ()
else match env.find ctor with
| none => failed ()
| some (ctorInfo) => do
let ctorType := ctorInfo.instantiateTypeLevelParams structLvls;
ctorType ← getForallResultType whnf ctorType structParams;
ctorType ← idx.foldM
(fun i ctorType => do
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ _ body =>
if body.hasLooseBVars then
pure $ body.instantiate1 $ Expr.proj structName i e
else
pure body
| _ => failed ())
ctorType;
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ d _ => pure d
| _ => failed ()
| _ => failed ()
@[specialize] private def getLevel
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(type : Expr) : MetaM Level :=
do typeType ← inferType type;
typeType ← whnf type;
match typeType with
| Expr.sort lvl => pure lvl
| Expr.mvar mvarId =>
condM (isReadOnlyOrSyntheticMVar mvarId)
(throwEx $ Exception.typeExpected type)
(do levelMVarId ← mkFreshId;
let lvl := Level.mvar levelMVarId;
assignExprMVar mvarId (Expr.sort lvl);
pure lvl)
| _ => throwEx $ Exception.typeExpected type
@[specialize] private def inferForallType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
forallTelescope whnf e $ fun xs e => do
type ← inferType e;
lvl ← getLevel whnf inferType type;
lvl ← xs.foldrM
(fun x lvl => do
xType ← inferType x;
xTypeLvl ← getLevel whnf inferType xType;
pure $ Level.imax xTypeLvl lvl)
lvl;
pure $ Expr.sort lvl
/- Infer type of lambda and let expressions -/
@[specialize] private def inferLambdaType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
lambdaTelescope whnf e $ fun xs e => do
type ← inferType e;
mkForall xs type
@[inline] private def withLocalDecl {α} (name : Name) (bi : BinderInfo) (type : Expr) (x : Expr → MetaM α) : MetaM α :=
savingCache $ do
fvarId ← mkFreshId;
adaptReader (fun (ctx : Context) => { lctx := ctx.lctx.mkLocalDecl fvarId name type bi, .. ctx }) $
x (Expr.fvar fvarId)
private def inferMVarType (mvarId : Name) : MetaM Expr :=
do mctx ← getMCtx;
match mctx.findDecl mvarId with
| some d => pure d.type
| none => throwEx $ Exception.unknownMVar mvarId
private def inferFVarType (fvarId : Name) : MetaM Expr :=
do lctx ← getLCtx;
match lctx.find fvarId with
| some d => pure d.type
| none => throwEx $ Exception.unknownFVar fvarId
@[inline] private def checkInferTypeCache (e : Expr) (inferType : MetaM Expr) : MetaM Expr :=
do s ← get;
match s.cache.inferType.find e with
| some type => pure type
| none => do
type ← inferType;
modify $ fun s => { cache := { inferType := s.cache.inferType.insert e type, .. s.cache }, .. s };
pure type
@[specialize] private partial def inferTypeAux
(whnf : Expr → MetaM Expr)
: Expr → MetaM Expr
| Expr.const c lvls => inferConstType c lvls
| e@(Expr.proj n i s) => checkInferTypeCache e (inferProjType whnf inferTypeAux n i s)
| e@(Expr.app f _) => checkInferTypeCache e (inferAppType whnf inferTypeAux f e.getAppArgs)
| Expr.mvar mvarId => inferMVarType mvarId
| Expr.fvar fvarId => inferFVarType fvarId
| Expr.bvar _ => unreachable!
| Expr.mdata _ e => inferTypeAux e
| Expr.lit v => pure v.type
| Expr.sort lvl => pure $ Expr.sort (Level.succ lvl)
| e@(Expr.forallE _ _ _ _) => checkInferTypeCache e (inferForallType whnf inferTypeAux e)
| e@(Expr.lam _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAux e)
| e@(Expr.letE _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAux e)
#exit
@[inline] private def liftStateMCtx {α} (x : StateM σ α) : TypeUtilM σ ϕ α :=

View file

@ -0,0 +1,9 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Lean.Meta.Basic
import Init.Lean.Meta.WHNF
import Init.Lean.Meta.InferType

View file

@ -0,0 +1,167 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Lean.Meta.Basic
namespace Lean
namespace Meta
@[specialize] private def getForallResultType
(whnf : Expr → MetaM Expr)
(fType : Expr) (args : Array Expr) : MetaM Expr :=
do (j, fType) ← args.size.foldM
(fun i (acc : Nat × Expr) =>
let (j, type) := acc;
match type with
| Expr.forallE _ _ _ b => pure (j, b)
| _ => do
type ← whnf $ type.instantiateRevRange j i args;
match type with
| Expr.forallE _ _ _ b => pure (i, b)
| _ => throwEx $ Exception.functionExpected fType args)
(0, fType);
pure $ fType.instantiateRevRange j args.size args
@[specialize] private def inferAppType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(f : Expr) (args : Array Expr) : MetaM Expr :=
do fType ← inferType f;
getForallResultType whnf fType args
private def inferConstType (c : Name) (lvls : List Level) : MetaM Expr :=
do env ← getEnv;
match env.find c with
| some cinfo =>
if cinfo.lparams.length == lvls.length then
throwEx $ Exception.incorrectNumOfLevels c lvls
else
pure $ cinfo.instantiateTypeLevelParams lvls
| none =>
throwEx $ Exception.unknownConst c
@[specialize] private def inferProjType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(structName : Name) (idx : Nat) (e : Expr) : MetaM Expr :=
do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProjection structName idx e;
structType ← inferType e;
structType ← whnf structType;
env ← getEnv;
matchConst env structType.getAppFn failed $ fun structInfo structLvls => do
match structInfo with
| ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } =>
let structParams := structType.getAppArgs;
if n != structParams.size then failed ()
else match env.find ctor with
| none => failed ()
| some (ctorInfo) => do
let ctorType := ctorInfo.instantiateTypeLevelParams structLvls;
ctorType ← getForallResultType whnf ctorType structParams;
ctorType ← idx.foldM
(fun i ctorType => do
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ _ body =>
if body.hasLooseBVars then
pure $ body.instantiate1 $ Expr.proj structName i e
else
pure body
| _ => failed ())
ctorType;
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ d _ => pure d
| _ => failed ()
| _ => failed ()
@[specialize] private def getLevel
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(type : Expr) : MetaM Level :=
do typeType ← inferType type;
typeType ← whnf type;
match typeType with
| Expr.sort lvl => pure lvl
| Expr.mvar mvarId =>
condM (isReadOnlyOrSyntheticMVar mvarId)
(throwEx $ Exception.typeExpected type)
(do levelMVarId ← mkFreshId;
let lvl := Level.mvar levelMVarId;
assignExprMVar mvarId (Expr.sort lvl);
pure lvl)
| _ => throwEx $ Exception.typeExpected type
@[specialize] private def inferForallType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
forallTelescope whnf e $ fun xs e => do
type ← inferType e;
lvl ← getLevel whnf inferType type;
lvl ← xs.foldrM
(fun x lvl => do
xType ← inferType x;
xTypeLvl ← getLevel whnf inferType xType;
pure $ Level.imax xTypeLvl lvl)
lvl;
pure $ Expr.sort lvl
/- Infer type of lambda and let expressions -/
@[specialize] private def inferLambdaType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
lambdaTelescope whnf e $ fun xs e => do
type ← inferType e;
mkForall xs type
@[inline] private def withLocalDecl {α} (name : Name) (bi : BinderInfo) (type : Expr) (x : Expr → MetaM α) : MetaM α :=
savingCache $ do
fvarId ← mkFreshId;
adaptReader (fun (ctx : Context) => { lctx := ctx.lctx.mkLocalDecl fvarId name type bi, .. ctx }) $
x (Expr.fvar fvarId)
private def inferMVarType (mvarId : Name) : MetaM Expr :=
do mctx ← getMCtx;
match mctx.findDecl mvarId with
| some d => pure d.type
| none => throwEx $ Exception.unknownMVar mvarId
private def inferFVarType (fvarId : Name) : MetaM Expr :=
do lctx ← getLCtx;
match lctx.find fvarId with
| some d => pure d.type
| none => throwEx $ Exception.unknownFVar fvarId
@[inline] private def checkInferTypeCache (e : Expr) (inferType : MetaM Expr) : MetaM Expr :=
do s ← get;
match s.cache.inferType.find e with
| some type => pure type
| none => do
type ← inferType;
modify $ fun s => { cache := { inferType := s.cache.inferType.insert e type, .. s.cache }, .. s };
pure type
@[specialize] private partial def inferTypeAux
(whnf : Expr → MetaM Expr)
: Expr → MetaM Expr
| Expr.const c lvls => inferConstType c lvls
| e@(Expr.proj n i s) => checkInferTypeCache e (inferProjType whnf inferTypeAux n i s)
| e@(Expr.app f _) => checkInferTypeCache e (inferAppType whnf inferTypeAux f e.getAppArgs)
| Expr.mvar mvarId => inferMVarType mvarId
| Expr.fvar fvarId => inferFVarType fvarId
| Expr.bvar _ => unreachable!
| Expr.mdata _ e => inferTypeAux e
| Expr.lit v => pure v.type
| Expr.sort lvl => pure $ Expr.sort (Level.succ lvl)
| e@(Expr.forallE _ _ _ _) => checkInferTypeCache e (inferForallType whnf inferTypeAux e)
| e@(Expr.lam _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAux e)
| e@(Expr.letE _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAux e)
end Meta
end Lean

View file

@ -0,0 +1,43 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Lean.AuxRecursor
import Init.Lean.WHNF
import Init.Lean.Meta.Basic
namespace Lean
namespace Meta
private def isAuxDef? (constName : Name) : MetaM Bool :=
do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName)
@[inline] private def getCachedWHNF (e : Expr) : MetaM (Option Expr) :=
do t ← getTransparency;
s ← get;
pure $ s.cache.whnf.find (t, e)
@[inline] private def cacheWHNF (e : Expr) (r : Expr) : MetaM Unit :=
do t ← getTransparency;
modify $ fun s => { cache := { whnf := s.cache.whnf.insert (t, e) r, .. s.cache }, .. s }
@[specialize] private partial def whnfAux
(inferType : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
: Expr → MetaM Expr
| e => whnfEasyCases getLocalDecl getExprMVarAssignment e $ fun e => do
cached? ← getCachedWHNF e;
match cached? with
| some r => pure r
| none => do
e ← whnfCore getConst isAuxDef? whnfAux inferType isDefEq getLocalDecl getExprMVarAssignment e;
r ← unfoldDefinition getConst isAuxDef? whnfAux inferType isDefEq synthesizePending getLocalDecl
getExprMVarAssignment e (fun _ => pure e) whnfAux;
cacheWHNF e r;
pure r
end Meta
end Lean