feat: helper functions for debugging, handling metavars, creating telescopes, extract universe level from types, checking whether type is a class, and declaring locals

This commit is contained in:
Leonardo de Moura 2019-11-09 11:36:05 -08:00
parent d10e08236f
commit d54880b6d1
2 changed files with 216 additions and 6 deletions

View file

@ -0,0 +1,38 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.ToString
universes u
namespace Lean
inductive LOption (α : Type u)
| none {} : LOption
| some : α → LOption
| undef {} : LOption
namespace LOption
variables {α : Type u}
instance : Inhabited (LOption α) := ⟨none⟩
instance [HasToString α] : HasToString (LOption α) :=
⟨fun o => match o with | none => "none" | undef => "undef" | (some a) => "(some " ++ toString a ++ ")"⟩
def beq [HasBeq α] : LOption α → LOption α → Bool
| none, none => true
| undef, undef => true
| some a, some b => a == b
| _, _ => false
instance [HasBeq α] : HasBeq (LOption α) := ⟨beq⟩
end LOption
end Lean
def Option.toLOption {α : Type u} : Option α → Lean.LOption α
| none => Lean.LOption.none
| some a => Lean.LOption.some a

View file

@ -7,8 +7,10 @@ prelude
import Init.Control.Reader
import Init.Lean.NameGenerator
import Init.Lean.Environment
import Init.Lean.LOption
import Init.Lean.Trace
import Init.Lean.AuxRecursor
import Init.Lean.Class
import Init.Lean.WHNF
import Init.Lean.ReducibilityAttrs
@ -59,6 +61,7 @@ structure Config :=
(foApprox : Bool := false)
(ctxApprox : Bool := false)
(quasiPatternApprox : Bool := false)
(debug : Bool := false)
(transparency : TransparencyMode := TransparencyMode.default)
structure Cache :=
@ -68,15 +71,22 @@ structure Cache :=
structure ExceptionContext :=
(env : Environment) (mctx : MetavarContext) (lctx : LocalContext)
inductive Bug
| overwritingExprMVar (mvarId : Name)
inductive Exception
| unknownConst (constName : Name) (ctx : ExceptionContext)
| unknownFVar (fvarId : Name) (ctx : ExceptionContext)
| unknownMVar (mvarId : Name) (ctx : ExceptionContext)
| functionExpected (fType : Expr) (args : Array Expr) (ctx : ExceptionContext)
| typeExpected (type : Expr) (ctx : ExceptionContext)
| incorrectNumOfLevels (constName : Name) (constLvls : List Level) (ctx : ExceptionContext)
| invalidProjection (structName : Name) (idx : Nat) (s : Expr) (ctx : ExceptionContext)
| bug (b : Bug) (ctx : ExceptionContext)
| other (msg : String)
instance Exception.inhabited : Inhabited Exception := ⟨Exception.other ""⟩
structure Context :=
(config : Config := {})
(lctx : LocalContext := {})
@ -94,10 +104,13 @@ structure State :=
(cache : Cache := {})
(ngen : NameGenerator := {})
(traceState : TraceState := {})
(postponed : Array PostponedEntry := #[])
(postponed : PersistentArray PostponedEntry := {})
abbrev MetaM := ReaderT Context (EStateM Exception State)
instance MetaM.inhabited {α} : Inhabited (MetaM α) :=
⟨fun c s => EStateM.Result.error (arbitrary _) s⟩
@[inline] private def getLCtx : MetaM LocalContext :=
do ctx ← read; pure ctx.lctx
@ -112,6 +125,20 @@ do ctx ← read;
s ← get;
throw (f {env := s.env, mctx := s.mctx, lctx := ctx.lctx })
@[inline] private def throwBug {α} (b : Bug) : MetaM α :=
throwEx $ Exception.bug b
/-- Execute `x` only in debugging mode. -/
@[inline] private def whenDebugging {α} (x : MetaM α) : MetaM Unit :=
do ctx ← read;
when ctx.config.debug (do x; pure ())
private def mkFreshId : MetaM Name :=
do s ← get;
let id := s.ngen.curr;
modify $ fun s => { ngen := s.ngen.next, .. s };
pure id
@[inline] private def reduceAll? : MetaM Bool :=
do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.all
@ -124,8 +151,27 @@ do ctx ← read; pure $ ctx.config.transparency
@[inline] private def getOptions : MetaM Options :=
do ctx ← read; pure ctx.config.opts
@[inline] def isReducible (n : Name) : MetaM Bool :=
do env ← getEnv; pure $ isReducible env n
-- Remark: wanted to use `private`, but in C++ parser, `private` declarations do not shadow outer public ones.
-- TODO: fix this bug
@[inline] def isReducible (constName : Name) : MetaM Bool :=
do env ← getEnv; pure $ isReducible env constName
private def isReadOnlyOrSyntheticMVar (mvarId : Name) : MetaM Bool :=
do mctx ← getMCtx;
match mctx.findDecl mvarId with
| some d => pure $ d.synthetic || d.depth != mctx.depth
| _ => throwEx $ Exception.unknownMVar mvarId
@[inline] private def isExprAssigned (mvarId : Name) : MetaM Bool :=
do mctx ← getMCtx;
pure $ mctx.isExprAssigned mvarId
@[inline] private def getMVarAssignment (mvarId : Name) : MetaM (Option Expr) :=
do mctx ← getMCtx; pure (mctx.getExprAssignment mvarId)
private def assignExprMVar (mvarId : Name) (val : Expr) : MetaM Unit :=
do whenDebugging $ whenM (isExprAssigned mvarId) $ throwBug $ Bug.overwritingExprMVar mvarId;
modify $ fun s => { mctx := s.mctx.assignExpr mvarId val, .. s }
@[inline] private def getTraceState : MetaM TraceState :=
do s ← get; pure s.traceState
@ -156,9 +202,6 @@ do lctx ← getLCtx;
| some d => pure d
| none => throwEx $ Exception.unknownFVar fvarId
@[inline] private def getMVarAssignment (mvarId : Name) : MetaM (Option Expr) :=
do mctx ← getMCtx; pure (mctx.getExprAssignment mvarId)
@[inline] private def getCachedWHNF (e : Expr) : MetaM (Option Expr) :=
do t ← getTransparency;
s ← get;
@ -168,6 +211,14 @@ do t ← getTransparency;
do t ← getTransparency;
modify $ fun s => { cache := { whnf := s.cache.whnf.insert (t, e) r, .. s.cache }, .. s }
def instantiateMVars (e : Expr) : MetaM Expr :=
if e.hasMVar then
modifyGet $ fun s =>
let (e, mctx) := s.mctx.instantiateMVars e;
(e, { mctx := mctx, .. s })
else
pure e
@[specialize] private partial def whnfAux
(inferType : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
@ -183,6 +234,104 @@ do t ← getTransparency;
cacheWHNF e r;
pure r
/-- Save cache, execute `x`, restore cache -/
@[inline] private def withCacheScope {α} (x : MetaM α) : MetaM α :=
do s ← get;
let savedCache := s.cache;
finally x (modify $ fun s => { cache := savedCache, .. s })
/--
`forallTelescopeAux whnf k lctx fvars j type`
Remarks:
- `lctx` is the `MetaM` local context exteded with the declaration for `fvars`.
- `type` is the type we are computing the telescope for. It contains only
dangling bound variables in the range `[j, fvars.size)`
- when `type` is not a `forallE` nor it can't be reduced to one, we
excute the continuation `k`. -/
@[specialize] private partial def forallTelescopeAux {α}
(whnf : Expr → MetaM Expr)
(k : Array Expr → Expr → MetaM α)
: LocalContext → Array Expr → Nat → Expr → MetaM α
| lctx, fvars, j, Expr.forallE n bi d b => do
let d := d.instantiateRevRange j fvars.size fvars;
fvarId ← mkFreshId;
let lctx := lctx.mkLocalDecl fvarId n d bi;
let fvar := Expr.fvar fvarId;
forallTelescopeAux lctx (fvars.push fvar) j b
| lctx, fvars, j, type =>
let type := type.instantiateRevRange j fvars.size fvars;
adaptReader (fun (ctx : Context) => { lctx := lctx, .. ctx }) $ do
newType ← whnf type;
if newType.isForall then
forallTelescopeAux lctx fvars fvars.size type
else
k fvars type
/-- Given `type` of the form `forall xs, A`, execute `k xs A`.
This combinator will declare local declarations, create free variables for them,
execute `k` with updated local context, and make sure the cache is restored after executing `k`. -/
@[specialize] def forallTelescope {α}
(whnf : Expr → MetaM Expr)
(type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
do newType ← whnf type;
if newType.isForall then
withCacheScope $ do
lctx ← getLCtx;
forallTelescopeAux whnf k lctx #[] 0 newType
else do
k #[] type
def isClassQuickConst (constName : Name) : MetaM (LOption Name) :=
do env ← getEnv;
if isClass env constName then
pure (LOption.some constName)
else do
cinfo? ← getConst constName;
match cinfo? with
| some _ => pure LOption.undef
| none => pure LOption.none
private partial def isClassQuick : Expr → MetaM (LOption Name)
| Expr.bvar _ => pure LOption.none
| Expr.lit _ => pure LOption.none
| Expr.fvar _ => pure LOption.none
| Expr.sort _ => pure LOption.none
| Expr.lam _ _ _ _ => pure LOption.none
| Expr.letE _ _ _ _ => pure LOption.undef
| Expr.proj _ _ _ => pure LOption.undef
| Expr.forallE _ _ _ b => isClassQuick b
| Expr.mdata _ e => isClassQuick e
| Expr.const n _ => isClassQuickConst n
| Expr.mvar mvarId => do
val? ← getMVarAssignment mvarId;
match val? with
| some val => isClassQuick val
| none => pure LOption.none
| Expr.app f _ =>
match f.getAppFn with
| Expr.const n _ => isClassQuickConst n
| Expr.lam _ _ _ _ => pure LOption.undef
| _ => pure LOption.none
@[specialize] private partial def isClassExpensive
(whnf : Expr → MetaM Expr)
(type : Expr) : MetaM (Option Name) :=
do forallTelescope whnf type $ fun xs type => do
match type.getAppFn with
| Expr.const c _ => do
env ← getEnv;
pure $ if isClass env c then some c else none
| _ => pure none
@[specialize] def isClass
(whnf : Expr → MetaM Expr)
(type : Expr) : MetaM (Option Name) :=
do c? ← isClassQuick type;
match c? with
| LOption.none => pure none
| LOption.some c => pure (some c)
| LOption.undef => isClassExpensive whnf type
@[specialize] private def getForallResultType
(whnf : Expr → MetaM Expr)
(fType : Expr) (args : Array Expr) : MetaM Expr :=
@ -252,6 +401,29 @@ do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProje
| _ => failed ()
| _ => failed ()
@[specialize] private def getLevel
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(type : Expr) : MetaM Level :=
do typeType ← inferType type;
typeType ← whnf type;
match typeType with
| Expr.sort lvl => pure lvl
| Expr.mvar mvarId =>
condM (isReadOnlyOrSyntheticMVar mvarId)
(throwEx $ Exception.typeExpected type)
(do levelMVarId ← mkFreshId;
let lvl := Level.mvar levelMVarId;
assignExprMVar mvarId (Expr.sort lvl);
pure lvl)
| _ => throwEx $ Exception.typeExpected type
@[inline] private def withLocalDecl {α} (name : Name) (bi : BinderInfo) (type : Expr) (x : Expr → MetaM α) : MetaM α :=
withCacheScope $ do
fvarId ← mkFreshId;
adaptReader (fun (ctx : Context) => { lctx := ctx.lctx.mkLocalDecl fvarId name type bi, .. ctx }) $
x (Expr.fvar fvarId)
private def inferMVarType (mvarId : Name) : MetaM Expr :=
do mctx ← getMCtx;
match mctx.findDecl mvarId with