feat: synthPending implementation
This commit is contained in:
parent
174771cdb3
commit
abf27d843e
6 changed files with 72 additions and 38 deletions
|
|
@ -168,16 +168,16 @@ IO.mkRef $ fun _ _ => throw $ Exception.other "isDefEq implementation was not se
|
|||
|
||||
@[init mkIsExprDefEqAuxRef] def isExprDefEqAuxRef : IO.Ref (Expr → Expr → MetaM Bool) := arbitrary _
|
||||
|
||||
def mkSynthPendingRef : IO (IO.Ref (Expr → MetaM Bool)) :=
|
||||
def mkSynthPendingRef : IO (IO.Ref (MVarId → MetaM Bool)) :=
|
||||
IO.mkRef $ fun _ => pure false
|
||||
|
||||
@[init mkSynthPendingRef] def synthPendingRef : IO.Ref (Expr → MetaM Bool) := arbitrary _
|
||||
@[init mkSynthPendingRef] def synthPendingRef : IO.Ref (MVarId → MetaM Bool) := arbitrary _
|
||||
|
||||
structure MetaExtState :=
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(inferType : Expr → MetaM Expr)
|
||||
(isDefEqAux : Expr → Expr → MetaM Bool)
|
||||
(synthPending : Expr → MetaM Bool)
|
||||
(synthPending : MVarId → MetaM Bool)
|
||||
|
||||
instance MetaExtState.inhabited : Inhabited MetaExtState :=
|
||||
⟨{ whnf := arbitrary _, inferType := arbitrary _, isDefEqAux := arbitrary _, synthPending := arbitrary _ }⟩
|
||||
|
|
@ -212,10 +212,10 @@ withIncRecDepth $ do
|
|||
env ← getEnv;
|
||||
(metaExt.getState env).isDefEqAux t s
|
||||
|
||||
def synthPending (e : Expr) : MetaM Bool :=
|
||||
def synthPending (mvarId : MVarId) : MetaM Bool :=
|
||||
withIncRecDepth $ do
|
||||
env ← getEnv;
|
||||
(metaExt.getState env).synthPending e
|
||||
(metaExt.getState env).synthPending mvarId
|
||||
|
||||
def mkFreshId : MetaM Name := do
|
||||
s ← get;
|
||||
|
|
|
|||
|
|
@ -110,6 +110,12 @@ private partial def isDefEqArgsAux (args₁ args₂ : Array Expr) (h : args₁.s
|
|||
else
|
||||
pure true
|
||||
|
||||
@[specialize] private def trySynthPending (e : Expr) : MetaM Bool := do
|
||||
mvarId? ← getStuckMVar? e;
|
||||
match mvarId? with
|
||||
| some mvarId => synthPending mvarId
|
||||
| none => pure false
|
||||
|
||||
private def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : MetaM Bool :=
|
||||
if h : args₁.size = args₂.size then do
|
||||
finfo ← getFunInfoNArgs f args₁.size;
|
||||
|
|
@ -124,8 +130,8 @@ if h : args₁.size = args₂.size then do
|
|||
let a₂ := args₂.get! i;
|
||||
let info := finfo.paramInfo.get! i;
|
||||
when info.instImplicit $ do {
|
||||
synthPending a₁;
|
||||
synthPending a₂;
|
||||
trySynthPending a₁;
|
||||
trySynthPending a₂;
|
||||
pure ()
|
||||
};
|
||||
withAtLeastTransparency TransparencyMode.default $ isExprDefEqAux a₁ a₂)
|
||||
|
|
@ -892,8 +898,8 @@ private partial def isDefEqQuick : Expr → Expr → MetaM LBool
|
|||
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $
|
||||
condM (isDelayedAssignedHead tFn t) (do t ← instantiateMVars t; isDefEqQuick t s) $
|
||||
condM (isDelayedAssignedHead sFn s) (do s ← instantiateMVars s; isDefEqQuick t s) $
|
||||
condM (isSynthetic tFn <&&> synthPending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
|
||||
condM (isSynthetic sFn <&&> synthPending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do
|
||||
condM (isSynthetic tFn <&&> trySynthPending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
|
||||
condM (isSynthetic sFn <&&> trySynthPending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do
|
||||
tAssign? ← isAssignable tFn;
|
||||
sAssign? ← isAssignable sFn;
|
||||
trace! `Meta.isDefEq
|
||||
|
|
@ -957,10 +963,10 @@ else
|
|||
@[specialize] private def unstuckMVar
|
||||
(e : Expr)
|
||||
(successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool := do
|
||||
s? ← WHNF.getStuckMVar getConst whnf e;
|
||||
match s? with
|
||||
| some s =>
|
||||
condM (synthPending s)
|
||||
mvarId? ← getStuckMVar? e;
|
||||
match mvarId? with
|
||||
| some mvarId =>
|
||||
condM (synthPending mvarId)
|
||||
(do e ← instantiateMVars e; successK e)
|
||||
failK
|
||||
| none => failK
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import Init.Lean.Meta.Basic
|
|||
import Init.Lean.Meta.Instances
|
||||
import Init.Lean.Meta.LevelDefEq
|
||||
import Init.Lean.Meta.AbstractMVars
|
||||
import Init.Lean.Meta.WHNF
|
||||
|
||||
namespace Lean
|
||||
namespace Meta
|
||||
|
|
@ -594,6 +595,26 @@ match result? with
|
|||
| some result => pure result
|
||||
| none => throwEx $ Exception.synthInstance type
|
||||
|
||||
def synthPendingImp (mvarId : MVarId) : MetaM Bool := do
|
||||
mvarDecl ← getMVarDecl mvarId;
|
||||
match mvarDecl.kind with
|
||||
| MetavarKind.synthetic => do
|
||||
c? ← isClass mvarDecl.type;
|
||||
match c? with
|
||||
| none => pure false
|
||||
| some _ => do
|
||||
val? ← synthInstance? mvarDecl.type;
|
||||
match val? with
|
||||
| none => pure false
|
||||
| some val =>
|
||||
condM (isExprMVarAssigned mvarId) (pure false) $ do
|
||||
assignExprMVar mvarId val;
|
||||
pure true
|
||||
| _ => pure false
|
||||
|
||||
@[init] def setSynthPendingRef : IO Unit :=
|
||||
synthPendingRef.set synthPendingImp
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Meta.synthInstance;
|
||||
registerTraceClass `Meta.synthInstance.globalInstances;
|
||||
|
|
|
|||
|
|
@ -66,5 +66,8 @@ whnfHeadPredAux pred e
|
|||
def whnfUntil (e : Expr) (declName : Name) : MetaM Expr :=
|
||||
whnfHeadPredAux (fun e => pure $ e.isAppOf declName) e
|
||||
|
||||
def getStuckMVar? (e : Expr) : MetaM (Option MVarId) :=
|
||||
WHNF.getStuckMVar? getConst whnf e
|
||||
|
||||
end Meta
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -121,10 +121,10 @@ if h : majorIdx < recArgs.size then do
|
|||
else
|
||||
failK ()
|
||||
|
||||
@[specialize] def isRecStuck {m : Type → Type} [Monad m]
|
||||
(whnf : Expr → m Expr)
|
||||
(isStuck : Expr → m (Option Expr))
|
||||
(rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option Expr) :=
|
||||
@[specialize] def isRecStuck? {m : Type → Type} [Monad m]
|
||||
(whnf : Expr → m Expr)
|
||||
(isStuck? : Expr → m (Option MVarId))
|
||||
(rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option MVarId) :=
|
||||
if rec.k then
|
||||
-- TODO: improve this case
|
||||
pure none
|
||||
|
|
@ -133,7 +133,7 @@ else do
|
|||
if h : majorIdx < recArgs.size then do
|
||||
let major := recArgs.get ⟨majorIdx, h⟩;
|
||||
major ← whnf major;
|
||||
isStuck major
|
||||
isStuck? major
|
||||
else
|
||||
pure none
|
||||
|
||||
|
|
@ -166,20 +166,20 @@ match rec.kind with
|
|||
| QuotKind.ind => process 4 3
|
||||
| _ => failK ()
|
||||
|
||||
@[specialize] def isQuotRecStuck {m : Type → Type} [Monad m]
|
||||
@[specialize] def isQuotRecStuck? {m : Type → Type} [Monad m]
|
||||
(whnf : Expr → m Expr)
|
||||
(isStuck : Expr → m (Option Expr))
|
||||
(rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option Expr) :=
|
||||
let process (majorPos : Nat) : m (Option Expr) :=
|
||||
(isStuck? : Expr → m (Option MVarId))
|
||||
(rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option MVarId) :=
|
||||
let process? (majorPos : Nat) : m (Option MVarId) :=
|
||||
if h : majorPos < recArgs.size then do
|
||||
let major := recArgs.get ⟨majorPos, h⟩;
|
||||
major ← whnf major;
|
||||
isStuck major
|
||||
isStuck? major
|
||||
else
|
||||
pure none;
|
||||
match rec.kind with
|
||||
| QuotKind.lift => process 5
|
||||
| QuotKind.ind => process 4
|
||||
| QuotKind.lift => process? 5
|
||||
| QuotKind.ind => process? 4
|
||||
| _ => pure none
|
||||
|
||||
/- ===========================
|
||||
|
|
@ -187,22 +187,22 @@ match rec.kind with
|
|||
=========================== -/
|
||||
|
||||
/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/
|
||||
@[specialize] partial def getStuckMVar {m : Type → Type} [Monad m]
|
||||
@[specialize] partial def getStuckMVar? {m : Type → Type} [Monad m]
|
||||
(getConst : Name → m (Option ConstantInfo))
|
||||
(whnf : Expr → m Expr)
|
||||
: Expr → m (Option Expr)
|
||||
| Expr.mdata _ e _ => getStuckMVar e
|
||||
| Expr.proj _ _ e _ => do e ← whnf e; getStuckMVar e
|
||||
| e@(Expr.mvar _ _) => pure (some e)
|
||||
: Expr → m (Option MVarId)
|
||||
| Expr.mdata _ e _ => getStuckMVar? e
|
||||
| Expr.proj _ _ e _ => do e ← whnf e; getStuckMVar? e
|
||||
| e@(Expr.mvar mvarId _) => pure (some mvarId)
|
||||
| e@(Expr.app f _ _) =>
|
||||
let f := f.getAppFn;
|
||||
match f with
|
||||
| Expr.mvar _ _ => pure (some f)
|
||||
| Expr.mvar mvarId _ => pure (some mvarId)
|
||||
| Expr.const fName fLvls _ => do
|
||||
cinfo? ← getConst fName;
|
||||
match cinfo? with
|
||||
| some $ ConstantInfo.recInfo rec => isRecStuck whnf getStuckMVar rec fLvls e.getAppArgs
|
||||
| some $ ConstantInfo.quotInfo rec => isQuotRecStuck whnf getStuckMVar rec fLvls e.getAppArgs
|
||||
| some $ ConstantInfo.recInfo rec => isRecStuck? whnf getStuckMVar? rec fLvls e.getAppArgs
|
||||
| some $ ConstantInfo.quotInfo rec => isQuotRecStuck? whnf getStuckMVar? rec fLvls e.getAppArgs
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
|
|
@ -320,14 +320,14 @@ else
|
|||
(whnf : Expr → m Expr)
|
||||
(inferType : Expr → m Expr)
|
||||
(isDefEq : Expr → Expr → m Bool)
|
||||
(synthesizePending : Expr → m Bool)
|
||||
(synthesizePending : MVarId → m Bool)
|
||||
(getLocalDecl : FVarId → m LocalDecl)
|
||||
(getMVarAssignment : MVarId → m (Option Expr))
|
||||
: Expr → m Expr
|
||||
| e => do
|
||||
e ← whnfCore getConst isAuxDef? whnf inferType isDefEq getLocalDecl getMVarAssignment e;
|
||||
(some mvar) ← getStuckMVar getConst whnf e | pure e;
|
||||
succeeded ← synthesizePending mvar;
|
||||
(some mvarId) ← getStuckMVar? getConst whnf e | pure e;
|
||||
succeeded ← synthesizePending mvarId;
|
||||
if succeeded then whnfCoreUnstuck e else pure e
|
||||
|
||||
/-- Unfold definition using "smart unfolding" if possible. -/
|
||||
|
|
@ -337,7 +337,7 @@ else
|
|||
(whnf : Expr → m Expr)
|
||||
(inferType : Expr → m Expr)
|
||||
(isDefEq : Expr → Expr → m Bool)
|
||||
(synthesizePending : Expr → m Bool)
|
||||
(synthesizePending : MVarId → m Bool)
|
||||
(getLocalDecl : FVarId → m LocalDecl)
|
||||
(getMVarAssignment : MVarId → m (Option Expr))
|
||||
(e : Expr) : m (Option Expr) :=
|
||||
|
|
@ -379,7 +379,7 @@ match e with
|
|||
(isAuxDef? : Name → m Bool)
|
||||
(inferType : Expr → m Expr)
|
||||
(isDefEq : Expr → Expr → m Bool)
|
||||
(synthesizePending : Expr → m Bool)
|
||||
(synthesizePending : MVarId → m Bool)
|
||||
(getLocalDecl : FVarId → m LocalDecl)
|
||||
(getMVarAssignment : MVarId → m (Option Expr))
|
||||
: Expr → m Expr
|
||||
|
|
|
|||
4
tests/lean/run/synthPending1.lean
Normal file
4
tests/lean/run/synthPending1.lean
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
new_frontend
|
||||
|
||||
theorem ex : Not (1 = 2) :=
|
||||
ofDecideEqFalse rfl
|
||||
Loading…
Add table
Reference in a new issue