diff --git a/src/Init/Lean/Meta/Basic.lean b/src/Init/Lean/Meta/Basic.lean index 4b5664efcd..5363761b4c 100644 --- a/src/Init/Lean/Meta/Basic.lean +++ b/src/Init/Lean/Meta/Basic.lean @@ -168,16 +168,16 @@ IO.mkRef $ fun _ _ => throw $ Exception.other "isDefEq implementation was not se @[init mkIsExprDefEqAuxRef] def isExprDefEqAuxRef : IO.Ref (Expr → Expr → MetaM Bool) := arbitrary _ -def mkSynthPendingRef : IO (IO.Ref (Expr → MetaM Bool)) := +def mkSynthPendingRef : IO (IO.Ref (MVarId → MetaM Bool)) := IO.mkRef $ fun _ => pure false -@[init mkSynthPendingRef] def synthPendingRef : IO.Ref (Expr → MetaM Bool) := arbitrary _ +@[init mkSynthPendingRef] def synthPendingRef : IO.Ref (MVarId → MetaM Bool) := arbitrary _ structure MetaExtState := (whnf : Expr → MetaM Expr) (inferType : Expr → MetaM Expr) (isDefEqAux : Expr → Expr → MetaM Bool) -(synthPending : Expr → MetaM Bool) +(synthPending : MVarId → MetaM Bool) instance MetaExtState.inhabited : Inhabited MetaExtState := ⟨{ whnf := arbitrary _, inferType := arbitrary _, isDefEqAux := arbitrary _, synthPending := arbitrary _ }⟩ @@ -212,10 +212,10 @@ withIncRecDepth $ do env ← getEnv; (metaExt.getState env).isDefEqAux t s -def synthPending (e : Expr) : MetaM Bool := +def synthPending (mvarId : MVarId) : MetaM Bool := withIncRecDepth $ do env ← getEnv; - (metaExt.getState env).synthPending e + (metaExt.getState env).synthPending mvarId def mkFreshId : MetaM Name := do s ← get; diff --git a/src/Init/Lean/Meta/ExprDefEq.lean b/src/Init/Lean/Meta/ExprDefEq.lean index e860144e5a..e447f51ddb 100644 --- a/src/Init/Lean/Meta/ExprDefEq.lean +++ b/src/Init/Lean/Meta/ExprDefEq.lean @@ -110,6 +110,12 @@ private partial def isDefEqArgsAux (args₁ args₂ : Array Expr) (h : args₁.s else pure true +@[specialize] private def trySynthPending (e : Expr) : MetaM Bool := do +mvarId? ← getStuckMVar? e; +match mvarId? with +| some mvarId => synthPending mvarId +| none => pure false + private def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := if h : args₁.size = args₂.size then do finfo ← getFunInfoNArgs f args₁.size; @@ -124,8 +130,8 @@ if h : args₁.size = args₂.size then do let a₂ := args₂.get! i; let info := finfo.paramInfo.get! i; when info.instImplicit $ do { - synthPending a₁; - synthPending a₂; + trySynthPending a₁; + trySynthPending a₂; pure () }; withAtLeastTransparency TransparencyMode.default $ isExprDefEqAux a₁ a₂) @@ -892,8 +898,8 @@ private partial def isDefEqQuick : Expr → Expr → MetaM LBool condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ condM (isDelayedAssignedHead tFn t) (do t ← instantiateMVars t; isDefEqQuick t s) $ condM (isDelayedAssignedHead sFn s) (do s ← instantiateMVars s; isDefEqQuick t s) $ - condM (isSynthetic tFn <&&> synthPending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $ - condM (isSynthetic sFn <&&> synthPending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do + condM (isSynthetic tFn <&&> trySynthPending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $ + condM (isSynthetic sFn <&&> trySynthPending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do tAssign? ← isAssignable tFn; sAssign? ← isAssignable sFn; trace! `Meta.isDefEq @@ -957,10 +963,10 @@ else @[specialize] private def unstuckMVar (e : Expr) (successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool := do -s? ← WHNF.getStuckMVar getConst whnf e; -match s? with -| some s => - condM (synthPending s) +mvarId? ← getStuckMVar? e; +match mvarId? with +| some mvarId => + condM (synthPending mvarId) (do e ← instantiateMVars e; successK e) failK | none => failK diff --git a/src/Init/Lean/Meta/SynthInstance.lean b/src/Init/Lean/Meta/SynthInstance.lean index 8b61f7e32d..dd2cda5b5f 100644 --- a/src/Init/Lean/Meta/SynthInstance.lean +++ b/src/Init/Lean/Meta/SynthInstance.lean @@ -9,6 +9,7 @@ import Init.Lean.Meta.Basic import Init.Lean.Meta.Instances import Init.Lean.Meta.LevelDefEq import Init.Lean.Meta.AbstractMVars +import Init.Lean.Meta.WHNF namespace Lean namespace Meta @@ -594,6 +595,26 @@ match result? with | some result => pure result | none => throwEx $ Exception.synthInstance type +def synthPendingImp (mvarId : MVarId) : MetaM Bool := do +mvarDecl ← getMVarDecl mvarId; +match mvarDecl.kind with +| MetavarKind.synthetic => do + c? ← isClass mvarDecl.type; + match c? with + | none => pure false + | some _ => do + val? ← synthInstance? mvarDecl.type; + match val? with + | none => pure false + | some val => + condM (isExprMVarAssigned mvarId) (pure false) $ do + assignExprMVar mvarId val; + pure true +| _ => pure false + +@[init] def setSynthPendingRef : IO Unit := +synthPendingRef.set synthPendingImp + @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Meta.synthInstance; registerTraceClass `Meta.synthInstance.globalInstances; diff --git a/src/Init/Lean/Meta/WHNF.lean b/src/Init/Lean/Meta/WHNF.lean index 84ce45baec..61eff6c4f4 100644 --- a/src/Init/Lean/Meta/WHNF.lean +++ b/src/Init/Lean/Meta/WHNF.lean @@ -66,5 +66,8 @@ whnfHeadPredAux pred e def whnfUntil (e : Expr) (declName : Name) : MetaM Expr := whnfHeadPredAux (fun e => pure $ e.isAppOf declName) e +def getStuckMVar? (e : Expr) : MetaM (Option MVarId) := +WHNF.getStuckMVar? getConst whnf e + end Meta end Lean diff --git a/src/Init/Lean/Util/WHNF.lean b/src/Init/Lean/Util/WHNF.lean index 9a3ad4c88c..1f2b7c000f 100644 --- a/src/Init/Lean/Util/WHNF.lean +++ b/src/Init/Lean/Util/WHNF.lean @@ -121,10 +121,10 @@ if h : majorIdx < recArgs.size then do else failK () -@[specialize] def isRecStuck {m : Type → Type} [Monad m] - (whnf : Expr → m Expr) - (isStuck : Expr → m (Option Expr)) - (rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option Expr) := +@[specialize] def isRecStuck? {m : Type → Type} [Monad m] + (whnf : Expr → m Expr) + (isStuck? : Expr → m (Option MVarId)) + (rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option MVarId) := if rec.k then -- TODO: improve this case pure none @@ -133,7 +133,7 @@ else do if h : majorIdx < recArgs.size then do let major := recArgs.get ⟨majorIdx, h⟩; major ← whnf major; - isStuck major + isStuck? major else pure none @@ -166,20 +166,20 @@ match rec.kind with | QuotKind.ind => process 4 3 | _ => failK () -@[specialize] def isQuotRecStuck {m : Type → Type} [Monad m] +@[specialize] def isQuotRecStuck? {m : Type → Type} [Monad m] (whnf : Expr → m Expr) - (isStuck : Expr → m (Option Expr)) - (rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option Expr) := -let process (majorPos : Nat) : m (Option Expr) := + (isStuck? : Expr → m (Option MVarId)) + (rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option MVarId) := +let process? (majorPos : Nat) : m (Option MVarId) := if h : majorPos < recArgs.size then do let major := recArgs.get ⟨majorPos, h⟩; major ← whnf major; - isStuck major + isStuck? major else pure none; match rec.kind with -| QuotKind.lift => process 5 -| QuotKind.ind => process 4 +| QuotKind.lift => process? 5 +| QuotKind.ind => process? 4 | _ => pure none /- =========================== @@ -187,22 +187,22 @@ match rec.kind with =========================== -/ /-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/ -@[specialize] partial def getStuckMVar {m : Type → Type} [Monad m] +@[specialize] partial def getStuckMVar? {m : Type → Type} [Monad m] (getConst : Name → m (Option ConstantInfo)) (whnf : Expr → m Expr) - : Expr → m (Option Expr) -| Expr.mdata _ e _ => getStuckMVar e -| Expr.proj _ _ e _ => do e ← whnf e; getStuckMVar e -| e@(Expr.mvar _ _) => pure (some e) + : Expr → m (Option MVarId) +| Expr.mdata _ e _ => getStuckMVar? e +| Expr.proj _ _ e _ => do e ← whnf e; getStuckMVar? e +| e@(Expr.mvar mvarId _) => pure (some mvarId) | e@(Expr.app f _ _) => let f := f.getAppFn; match f with - | Expr.mvar _ _ => pure (some f) + | Expr.mvar mvarId _ => pure (some mvarId) | Expr.const fName fLvls _ => do cinfo? ← getConst fName; match cinfo? with - | some $ ConstantInfo.recInfo rec => isRecStuck whnf getStuckMVar rec fLvls e.getAppArgs - | some $ ConstantInfo.quotInfo rec => isQuotRecStuck whnf getStuckMVar rec fLvls e.getAppArgs + | some $ ConstantInfo.recInfo rec => isRecStuck? whnf getStuckMVar? rec fLvls e.getAppArgs + | some $ ConstantInfo.quotInfo rec => isQuotRecStuck? whnf getStuckMVar? rec fLvls e.getAppArgs | _ => pure none | _ => pure none | _ => pure none @@ -320,14 +320,14 @@ else (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) - (synthesizePending : Expr → m Bool) + (synthesizePending : MVarId → m Bool) (getLocalDecl : FVarId → m LocalDecl) (getMVarAssignment : MVarId → m (Option Expr)) : Expr → m Expr | e => do e ← whnfCore getConst isAuxDef? whnf inferType isDefEq getLocalDecl getMVarAssignment e; - (some mvar) ← getStuckMVar getConst whnf e | pure e; - succeeded ← synthesizePending mvar; + (some mvarId) ← getStuckMVar? getConst whnf e | pure e; + succeeded ← synthesizePending mvarId; if succeeded then whnfCoreUnstuck e else pure e /-- Unfold definition using "smart unfolding" if possible. -/ @@ -337,7 +337,7 @@ else (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) - (synthesizePending : Expr → m Bool) + (synthesizePending : MVarId → m Bool) (getLocalDecl : FVarId → m LocalDecl) (getMVarAssignment : MVarId → m (Option Expr)) (e : Expr) : m (Option Expr) := @@ -379,7 +379,7 @@ match e with (isAuxDef? : Name → m Bool) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) - (synthesizePending : Expr → m Bool) + (synthesizePending : MVarId → m Bool) (getLocalDecl : FVarId → m LocalDecl) (getMVarAssignment : MVarId → m (Option Expr)) : Expr → m Expr diff --git a/tests/lean/run/synthPending1.lean b/tests/lean/run/synthPending1.lean new file mode 100644 index 0000000000..d0acc0580c --- /dev/null +++ b/tests/lean/run/synthPending1.lean @@ -0,0 +1,4 @@ +new_frontend + +theorem ex : Not (1 = 2) := +ofDecideEqFalse rfl