From da8f9806a8d5d9cc7fa2aff363639caed194d312 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 20 Nov 2019 13:20:09 -0800 Subject: [PATCH] feat: add `isDefEqAux` `isDefEq` specialization is currently disabled because implementation is producing 150k lines of code. It seems the CPS trick I am using is producing a code explosion. --- library/Init/Lean/Meta.lean | 5 +- library/Init/Lean/Meta/ExprDefEq.lean | 161 ++++++++++++++++++------- library/Init/Lean/Meta/LevelDefEq.lean | 19 ++- library/Init/Lean/Meta/Offset.lean | 23 ++-- library/Init/Lean/Meta/WHNF.lean | 2 +- 5 files changed, 142 insertions(+), 68 deletions(-) diff --git a/library/Init/Lean/Meta.lean b/library/Init/Lean/Meta.lean index 7e841ed269..9de4136181 100644 --- a/library/Init/Lean/Meta.lean +++ b/library/Init/Lean/Meta.lean @@ -38,8 +38,9 @@ private partial def auxFixpoint : MetaOp → Expr → Expr → MetaM Expr match op with | whnfOp => whnfAux inferType isDefEq synthPending e₁ | inferTypeOp => inferTypeAux whnf e₁ + -- | isDefEqOp => boolToExpr <$> isExprDefEqAux whnf synthPending e₁ e₂ + | isDefEqOp => boolToExpr <$> pure false | synthPendingOp => boolToExpr <$> pure false -- TODO - | isDefEqOp => boolToExpr <$> pure false -- TODO def whnf (e : Expr) : MetaM Expr := auxFixpoint whnfOp e e @@ -48,7 +49,7 @@ def inferType (e : Expr) : MetaM Expr := auxFixpoint inferTypeOp e e def isDefEq (e₁ e₂ : Expr) : MetaM Bool := -exprToBool <$> auxFixpoint isDefEqOp e₁ e₂ +try $ exprToBool <$> auxFixpoint isDefEqOp e₁ e₂ /- =========================================== -/ /- END OF BIG HACK -/ /- =========================================== -/ diff --git a/library/Init/Lean/Meta/ExprDefEq.lean b/library/Init/Lean/Meta/ExprDefEq.lean index 125d0c24a4..5e32febeee 100644 --- a/library/Init/Lean/Meta/ExprDefEq.lean +++ b/library/Init/Lean/Meta/ExprDefEq.lean @@ -377,10 +377,10 @@ structure State := abbrev CheckAssignmentM := ReaderT Context (EStateM Exception State) -def findCached (e : Expr) : CheckAssignmentM (Option Expr) := +private def findCached (e : Expr) : CheckAssignmentM (Option Expr) := do s ← get; pure $ s.cache.find e -def cache (e r : Expr) : CheckAssignmentM Unit := +private def cache (e r : Expr) : CheckAssignmentM Unit := modify $ fun s => { cache := s.cache.insert e r, .. s } instance : MonadCache Expr Expr CheckAssignmentM := @@ -748,22 +748,22 @@ match t.getAppFn, s.getAppFn with (whnf : Expr → MetaM Expr) (isDefEq : Expr → Expr → MetaM Bool) (synthesizePending : Expr → MetaM Bool) - (t s : Expr) : MetaM LBool := -do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s; - let isDefEqLeft (fn : Name) (t s : Expr) : MetaM LBool := do { + (t s : Expr) + (k : MetaM Bool) -- continuation when `isDefEqDelta` could not decide + : MetaM Bool := +do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do { trace! `Meta.isDefEq.delta.unfoldLeft fn; - toLBoolM $ isDefEq t s + isDefEq t s }; - let isDefEqRight (fn : Name) (t s : Expr) : MetaM LBool := do { + let isDefEqRight (fn : Name) (t s : Expr) : MetaM Bool := do { trace! `Meta.isDefEq.delta.unfoldRight fn; - toLBoolM $ isDefEq t s + isDefEq t s }; - let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM LBool := do { + let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM Bool := do { trace! `Meta.isDefEq.delta.unfoldLeftRight fn; - toLBoolM $ isDefEq t s + isDefEq t s }; - let isListLevelDefEqL (us vs : List Level) : MetaM LBool := toLBoolM $ isListLevelDefEqAux us vs; - let unfold (e failK successK) : MetaM LBool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK; + let unfold (e failK successK) : MetaM Bool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK; let tryHeuristic : MetaM Bool := /- Try to solve `f a₁ ... aₙ =?= f b₁ ... bₙ` by solving `a₁ =?= b₁, ..., aₙ =?= bₙ` -/ let tFn := t.getAppFn; @@ -776,32 +776,32 @@ do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s; tInfo? ← isDeltaCandidate t.getAppFn; sInfo? ← isDeltaCandidate s.getAppFn; match tInfo?, sInfo? with - | none, none => pure LBool.undef - | some tInfo, none => unfold t (pure LBool.undef) $ fun t => isDefEqLeft tInfo.name t s - | none, some sInfo => unfold s (pure LBool.undef) $ fun s => isDefEqRight sInfo.name t s + | none, none => k + | some tInfo, none => unfold t k $ fun t => isDefEqLeft tInfo.name t s + | none, some sInfo => unfold s k $ fun s => isDefEqRight sInfo.name t s | some tInfo, some sInfo => let isDefEqLeft (t s) := isDefEqLeft tInfo.name t s; let isDefEqRight (t s) := isDefEqRight sInfo.name t s; let isDefEqLeftRight (t s) := isDefEqLeftRight tInfo.name t s; if tInfo.name == sInfo.name then match t, s with - | Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqL ls₁ ls₂ + | Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAux ls₁ ls₂ | Expr.app _ _ _, Expr.app _ _ _ => condM tryHeuristic - (pure LBool.true) + (pure true) (unfold t - (unfold s (pure LBool.undef) (fun s => isDefEqRight t s)) + (unfold s (pure false) (fun s => isDefEqRight t s)) (fun t => unfold s (isDefEqLeft t s) (fun s => isDefEqLeftRight t s))) - | _, _ => pure LBool.false + | _, _ => pure false else - let unfoldComparingHeads : Unit → MetaM LBool := fun _ => + let unfoldComparingHeads : Unit → MetaM Bool := fun _ => /- - If headSymbol (unfold t) == headSymbol s, then unfold t - If headSymbol (unfold s) == headSymbol t, then unfold s - Otherwise unfold t and s if possible. -/ unfold t (unfold s - (pure LBool.undef) -- `t` and `s` failed to be unfolded + k -- `t` and `s` failed to be unfolded (fun s => isDefEqRight t s)) (fun tNew => if sameHeadSymbol tNew s then @@ -810,7 +810,7 @@ do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s; unfold s (isDefEqLeft tNew s) (fun sNew => if sameHeadSymbol t sNew then isDefEqRight t sNew else isDefEqLeftRight tNew sNew)); - let kernelLikeUnfolding : Unit → MetaM LBool := fun _ => + let kernelLikeUnfolding : Unit → MetaM Bool := fun _ => if !t.hasExprMVar && !s.hasExprMVar then /- If `t` and `s` do not contain metavariables, we simulate strategy used in the kernel. -/ @@ -862,38 +862,38 @@ do decl ← getLocalDecl fvarId; (whnf : Expr → MetaM Expr) (isDefEq : Expr → Expr → MetaM Bool) (synthesizePending : Expr → MetaM Bool) - : Expr → Expr → MetaM LBool -| Expr.lit l₁ _, Expr.lit l₂ _ => pure (l₁ == l₂).toLBool -| Expr.sort u _, Expr.sort v _ => toLBoolM $ isLevelDefEqAux u v -| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _) => toLBoolM $ isDefEqBinding whnf isDefEq t s -| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _) => toLBoolM $ isDefEqBinding whnf isDefEq t s -| Expr.mdata _ t _, s => isDefEqQuick t s -| t, Expr.mdata _ s _ => isDefEqQuick t s -| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _ => + : Expr → Expr → MetaM Bool → MetaM Bool +| Expr.lit l₁ _, Expr.lit l₂ _, _ => pure (l₁ == l₂) +| Expr.sort u _, Expr.sort v _, _ => isLevelDefEqAux u v +| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _), _ => isDefEqBinding whnf isDefEq t s +| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _), _ => isDefEqBinding whnf isDefEq t s +| Expr.mdata _ t _, s, k => isDefEqQuick t s k +| t, Expr.mdata _ s _, k => isDefEqQuick t s k +| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _, k => condM (isLetFVar fvarId₁ <||> isLetFVar fvarId₂) - (pure LBool.undef) - (pure (fvarId₁ == fvarId₂).toLBool) -| t, s => - cond (t == s) (pure LBool.true) $ - cond (etaEq t s || etaEq s t) (pure LBool.true) $ -- t =?= (fun xs => t xs) + k + (pure (fvarId₁ == fvarId₂)) +| t, s, k => + cond (t == s) (pure true) $ + cond (etaEq t s || etaEq s t) (pure true) $ -- t =?= (fun xs => t xs) let tFn := t.getAppFn; let sFn := s.getAppFn; - cond (!tFn.isMVar && !sFn.isMVar) (pure LBool.undef) $ - condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $ - condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ - condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $ - condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do + cond (!tFn.isMVar && !sFn.isMVar) k $ + condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $ + condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ + condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $ + condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ do tAssign? ← isAssignable tFn; sAssign? ← isAssignable sFn; - let assign (t s : Expr) : MetaM LBool := toLBoolM $ processAssignment whnf isDefEq synthesizePending t s; + let assign (t s : Expr) : MetaM Bool := processAssignment whnf isDefEq synthesizePending t s; cond (tAssign? && !sAssign?) (assign t s) $ cond (!tAssign? && sAssign?) (assign s t) $ cond (!tAssign? && !sAssign?) (if tFn.isMVar || sFn.isMVar then do ctx ← read; if ctx.config.isDefEqStuckEx then throwEx $ Exception.isDefEqStuck t s - else pure LBool.false - else pure LBool.undef) $ do + else pure false + else k) $ do -- Both `t` and `s` are terms of the form `?m ...` tMVarDecl ← getMVarDecl tFn.mvarId!; sMVarDecl ← getMVarDecl sFn.mvarId!; @@ -912,5 +912,80 @@ do decl ← getLocalDecl fvarId; cond (!s.isApp && t.isApp && tMVarDecl.lctx.isSubPrefixOf sMVarDecl.lctx) (assign s t) $ assign t s +@[specialize] private def isDefEqProofIrrel + (whnf : Expr → MetaM Expr) + (isDefEq : Expr → Expr → MetaM Bool) + (t s : Expr) + (k : MetaM Bool) -- continuation when `isDefEqQuick` could not decide + : MetaM Bool := +do tType ← inferTypeAux whnf t; + condM (isPropAux whnf tType) + (do sType ← inferTypeAux whnf s; isDefEq tType sType) + k + +@[specialize] private partial def whnfCoreAux + (whnf : Expr → MetaM Expr) + (isDefEq : Expr → Expr → MetaM Bool) + (synthesizePending : Expr → MetaM Bool) + (e : Expr) : MetaM Expr := +Lean.whnfCore getConstNoEx isAuxDef? whnf (inferTypeAux whnf) isDefEq getLocalDecl getExprMVarAssignment e + +@[specialize] private partial def isDefEqWHNF + (whnf : Expr → MetaM Expr) + (isDefEq : Expr → Expr → MetaM Bool) + (synthesizePending : Expr → MetaM Bool) + (t s : Expr) + (k : Expr → Expr → MetaM Bool) : MetaM Bool := +do t' ← whnfCoreAux whnf isDefEq synthesizePending t; + s' ← whnfCoreAux whnf isDefEq synthesizePending s; + if t == t' && s == s' then + k t s + else + isDefEqQuick whnf isDefEq synthesizePending t' s' $ k t' s' + +@[specialize] private def unstuckMVar + (whnf : Expr → MetaM Expr) + (synthesizePending : Expr → MetaM Bool) + (e : Expr) + (successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool := +do s? ← getStuckMVar getConst whnf e; + match s? with + | some s => + condM (synthesizePending s) + (do e ← instantiateMVars e; successK e) + failK + | none => failK + +@[specialize] private def isDefEqOnFailure + (whnf : Expr → MetaM Expr) + (isDefEq : Expr → Expr → MetaM Bool) + (synthesizePending : Expr → MetaM Bool) + (t s : Expr) : MetaM Bool := +unstuckMVar whnf synthesizePending t (fun t => isDefEq t s) $ +unstuckMVar whnf synthesizePending s (fun s => isDefEq t s) $ +pure false + +@[specialize] partial def isExprDefEqAux + (whnf : Expr → MetaM Expr) + (synthesizePending : Expr → MetaM Bool) + : Expr → Expr → MetaM Bool +| t, s => do + trace! `Meta.isDefEq.step (t ++ " =?= " ++ s); + isDefEqQuick whnf isExprDefEqAux synthesizePending t s $ + isDefEqProofIrrel whnf isExprDefEqAux t s $ + isDefEqWHNF whnf isExprDefEqAux synthesizePending t s $ fun t s => + isDefEqOffset whnf isExprDefEqAux t s $ + isDefEqDelta whnf isExprDefEqAux synthesizePending t s $ + condM (isDefEqEta whnf isExprDefEqAux t s <||> isDefEqEta whnf isExprDefEqAux s t) (pure true) $ + match t, s with + | Expr.const _ us _, Expr.const _ vs _ => isListLevelDefEqAux us vs + | Expr.app _ _ _, Expr.app _ _ _ => + let tFn := t.getAppFn; + condM (try (isExprDefEqAux tFn s.getAppFn <&&> + isDefEqArgs whnf isExprDefEqAux synthesizePending tFn t.getAppArgs s.getAppArgs)) + (pure true) + (isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s) + | _, _ => isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s + end Meta end Lean diff --git a/library/Init/Lean/Meta/LevelDefEq.lean b/library/Init/Lean/Meta/LevelDefEq.lean index 3566445cc4..5e55e5112c 100644 --- a/library/Init/Lean/Meta/LevelDefEq.lean +++ b/library/Init/Lean/Meta/LevelDefEq.lean @@ -115,8 +115,8 @@ traceCtx `type_context.level_is_def_eq.postponed_step $ do pure false) true -private partial def processPostponedAux : Bool → MetaM Bool -| mayPostpone => do +private partial def processPostponedAux : Unit → MetaM Bool +| _ => do numPostponed ← getNumPostponed; if numPostponed == 0 then pure true @@ -130,15 +130,15 @@ private partial def processPostponedAux : Bool → MetaM Bool if numPostponed' == 0 then pure true else if numPostponed' < numPostponed then - processPostponedAux mayPostpone + processPostponedAux () else do trace! `type_context.level_is_def_eq ("no progress solving pending is-def-eq level constraints"); - pure mayPostpone + pure false -private def processPostponed (mayPostpone : Bool) : MetaM Bool := +private def processPostponed : MetaM Bool := do numPostponed ← getNumPostponed; if numPostponed == 0 then pure true - else traceCtx `type_context.level_is_def_eq.postponed $ processPostponedAux mayPostpone + else traceCtx `type_context.level_is_def_eq.postponed $ processPostponedAux () private def restore (env : Environment) (mctx : MetavarContext) (postponed : PersistentArray PostponedEntry) : MetaM Unit := @@ -158,7 +158,7 @@ do s ← get; modify $ fun s => { postponed := {}, .. s }; catch (condM x - (condM (processPostponed false) + (condM processPostponed (pure true) (do restore env mctx postponed; pure false)) (do restore env mctx postponed; pure false)) @@ -167,10 +167,7 @@ do s ← get; /- Public interface -/ def isLevelDefEq (u v : Level) : MetaM Bool := -try $ do - r ← isLevelDefEqAux u v; - if !r then pure false - else processPostponed false +try $ isLevelDefEqAux u v end Meta end Lean diff --git a/library/Init/Lean/Meta/Offset.lean b/library/Init/Lean/Meta/Offset.lean index 062cecf16e..668a034d2b 100644 --- a/library/Init/Lean/Meta/Offset.lean +++ b/library/Init/Lean/Meta/Offset.lean @@ -92,26 +92,27 @@ private partial def isOffset : Expr → Option (Expr × Nat) @[specialize] def isDefEqOffset (whnf : Expr → MetaM Expr) (isDefEq : Expr → Expr → MetaM Bool) - (s t : Expr) : MetaM LBool := + (s t : Expr) + (k : MetaM Bool) -- continuation when `isDefEqOffset` could not decide + : MetaM Bool := match isOffset s with | some (s, k₁) => match isOffset t with | some (t, k₂) => -- s+k₁ =?= t+k₂ - toLBoolM $ - if k₁ == k₂ then isDefEq s t - else if k₁ < k₂ then isDefEq s (mkCAppB `Nat.add t (mkNatLit $ k₂ - k₁)) - else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t + if k₁ == k₂ then isDefEq s t + else if k₁ < k₂ then isDefEq s (mkCAppB `Nat.add t (mkNatLit $ k₂ - k₁)) + else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t | none => match evalNat t with | some v₂ => -- s+k₁ =?= v₂ - if v₂ ≥ k₁ then toLBoolM $ isDefEq s (mkNatLit $ v₂ - k₁) else pure LBool.false - | none => pure LBool.undef + if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure false + | none => k | none => match evalNat s with | some v₁ => match isOffset t with | some (t, k₂) => -- v₁ =?= t+k₂ - if v₁ ≥ k₂ then toLBoolM $ isDefEq s (mkNatLit $ v₁ - k₂) else pure LBool.false + if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure false | none => match evalNat t with - | some v₂ => pure (v₁ == v₂).toLBool -- v₁ =?= v₂ - | none => pure LBool.undef - | none => pure LBool.undef + | some v₂ => pure (v₁ == v₂) -- v₁ =?= v₂ + | none => k + | none => k end Meta end Lean diff --git a/library/Init/Lean/Meta/WHNF.lean b/library/Init/Lean/Meta/WHNF.lean index 35ebcf1d83..8ffe89fd9b 100644 --- a/library/Init/Lean/Meta/WHNF.lean +++ b/library/Init/Lean/Meta/WHNF.lean @@ -11,7 +11,7 @@ import Init.Lean.Meta.Basic namespace Lean namespace Meta -private def isAuxDef? (constName : Name) : MetaM Bool := +def isAuxDef? (constName : Name) : MetaM Bool := do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName) @[specialize] def unfoldDefinitionAux {α}