feat: add isDefEqAux

`isDefEq` specialization is currently disabled because implementation
is producing 150k lines of code.
It seems the CPS trick I am using is producing a code explosion.
This commit is contained in:
Leonardo de Moura 2019-11-20 13:20:09 -08:00
parent 235ef740e4
commit da8f9806a8
5 changed files with 142 additions and 68 deletions

View file

@ -38,8 +38,9 @@ private partial def auxFixpoint : MetaOp → Expr → Expr → MetaM Expr
match op with
| whnfOp => whnfAux inferType isDefEq synthPending e₁
| inferTypeOp => inferTypeAux whnf e₁
-- | isDefEqOp => boolToExpr <$> isExprDefEqAux whnf synthPending e₁ e₂
| isDefEqOp => boolToExpr <$> pure false
| synthPendingOp => boolToExpr <$> pure false -- TODO
| isDefEqOp => boolToExpr <$> pure false -- TODO
def whnf (e : Expr) : MetaM Expr :=
auxFixpoint whnfOp e e
@ -48,7 +49,7 @@ def inferType (e : Expr) : MetaM Expr :=
auxFixpoint inferTypeOp e e
def isDefEq (e₁ e₂ : Expr) : MetaM Bool :=
exprToBool <$> auxFixpoint isDefEqOp e₁ e₂
try $ exprToBool <$> auxFixpoint isDefEqOp e₁ e₂
/- =========================================== -/
/- END OF BIG HACK -/
/- =========================================== -/

View file

@ -377,10 +377,10 @@ structure State :=
abbrev CheckAssignmentM := ReaderT Context (EStateM Exception State)
def findCached (e : Expr) : CheckAssignmentM (Option Expr) :=
private def findCached (e : Expr) : CheckAssignmentM (Option Expr) :=
do s ← get; pure $ s.cache.find e
def cache (e r : Expr) : CheckAssignmentM Unit :=
private def cache (e r : Expr) : CheckAssignmentM Unit :=
modify $ fun s => { cache := s.cache.insert e r, .. s }
instance : MonadCache Expr Expr CheckAssignmentM :=
@ -748,22 +748,22 @@ match t.getAppFn, s.getAppFn with
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(t s : Expr) : MetaM LBool :=
do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s;
let isDefEqLeft (fn : Name) (t s : Expr) : MetaM LBool := do {
(t s : Expr)
(k : MetaM Bool) -- continuation when `isDefEqDelta` could not decide
: MetaM Bool :=
do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do {
trace! `Meta.isDefEq.delta.unfoldLeft fn;
toLBoolM $ isDefEq t s
isDefEq t s
};
let isDefEqRight (fn : Name) (t s : Expr) : MetaM LBool := do {
let isDefEqRight (fn : Name) (t s : Expr) : MetaM Bool := do {
trace! `Meta.isDefEq.delta.unfoldRight fn;
toLBoolM $ isDefEq t s
isDefEq t s
};
let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM LBool := do {
let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM Bool := do {
trace! `Meta.isDefEq.delta.unfoldLeftRight fn;
toLBoolM $ isDefEq t s
isDefEq t s
};
let isListLevelDefEqL (us vs : List Level) : MetaM LBool := toLBoolM $ isListLevelDefEqAux us vs;
let unfold (e failK successK) : MetaM LBool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK;
let unfold (e failK successK) : MetaM Bool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK;
let tryHeuristic : MetaM Bool :=
/- Try to solve `f a₁ ... aₙ =?= f b₁ ... bₙ` by solving `a₁ =?= b₁, ..., aₙ =?= bₙ` -/
let tFn := t.getAppFn;
@ -776,32 +776,32 @@ do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s;
tInfo? ← isDeltaCandidate t.getAppFn;
sInfo? ← isDeltaCandidate s.getAppFn;
match tInfo?, sInfo? with
| none, none => pure LBool.undef
| some tInfo, none => unfold t (pure LBool.undef) $ fun t => isDefEqLeft tInfo.name t s
| none, some sInfo => unfold s (pure LBool.undef) $ fun s => isDefEqRight sInfo.name t s
| none, none => k
| some tInfo, none => unfold t k $ fun t => isDefEqLeft tInfo.name t s
| none, some sInfo => unfold s k $ fun s => isDefEqRight sInfo.name t s
| some tInfo, some sInfo =>
let isDefEqLeft (t s) := isDefEqLeft tInfo.name t s;
let isDefEqRight (t s) := isDefEqRight sInfo.name t s;
let isDefEqLeftRight (t s) := isDefEqLeftRight tInfo.name t s;
if tInfo.name == sInfo.name then
match t, s with
| Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqL ls₁ ls₂
| Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAux ls₁ ls₂
| Expr.app _ _ _, Expr.app _ _ _ =>
condM tryHeuristic
(pure LBool.true)
(pure true)
(unfold t
(unfold s (pure LBool.undef) (fun s => isDefEqRight t s))
(unfold s (pure false) (fun s => isDefEqRight t s))
(fun t => unfold s (isDefEqLeft t s) (fun s => isDefEqLeftRight t s)))
| _, _ => pure LBool.false
| _, _ => pure false
else
let unfoldComparingHeads : Unit → MetaM LBool := fun _ =>
let unfoldComparingHeads : Unit → MetaM Bool := fun _ =>
/-
- If headSymbol (unfold t) == headSymbol s, then unfold t
- If headSymbol (unfold s) == headSymbol t, then unfold s
- Otherwise unfold t and s if possible. -/
unfold t
(unfold s
(pure LBool.undef) -- `t` and `s` failed to be unfolded
k -- `t` and `s` failed to be unfolded
(fun s => isDefEqRight t s))
(fun tNew =>
if sameHeadSymbol tNew s then
@ -810,7 +810,7 @@ do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s;
unfold s
(isDefEqLeft tNew s)
(fun sNew => if sameHeadSymbol t sNew then isDefEqRight t sNew else isDefEqLeftRight tNew sNew));
let kernelLikeUnfolding : Unit → MetaM LBool := fun _ =>
let kernelLikeUnfolding : Unit → MetaM Bool := fun _ =>
if !t.hasExprMVar && !s.hasExprMVar then
/- If `t` and `s` do not contain metavariables,
we simulate strategy used in the kernel. -/
@ -862,38 +862,38 @@ do decl ← getLocalDecl fvarId;
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
: Expr → Expr → MetaM LBool
| Expr.lit l₁ _, Expr.lit l₂ _ => pure (l₁ == l₂).toLBool
| Expr.sort u _, Expr.sort v _ => toLBoolM $ isLevelDefEqAux u v
| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _) => toLBoolM $ isDefEqBinding whnf isDefEq t s
| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _) => toLBoolM $ isDefEqBinding whnf isDefEq t s
| Expr.mdata _ t _, s => isDefEqQuick t s
| t, Expr.mdata _ s _ => isDefEqQuick t s
| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _ =>
: Expr → Expr → MetaM Bool → MetaM Bool
| Expr.lit l₁ _, Expr.lit l₂ _, _ => pure (l₁ == l₂)
| Expr.sort u _, Expr.sort v _, _ => isLevelDefEqAux u v
| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _), _ => isDefEqBinding whnf isDefEq t s
| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _), _ => isDefEqBinding whnf isDefEq t s
| Expr.mdata _ t _, s, k => isDefEqQuick t s k
| t, Expr.mdata _ s _, k => isDefEqQuick t s k
| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _, k =>
condM (isLetFVar fvarId₁ <||> isLetFVar fvarId₂)
(pure LBool.undef)
(pure (fvarId₁ == fvarId₂).toLBool)
| t, s =>
cond (t == s) (pure LBool.true) $
cond (etaEq t s || etaEq s t) (pure LBool.true) $ -- t =?= (fun xs => t xs)
k
(pure (fvarId₁ == fvarId₂))
| t, s, k =>
cond (t == s) (pure true) $
cond (etaEq t s || etaEq s t) (pure true) $ -- t =?= (fun xs => t xs)
let tFn := t.getAppFn;
let sFn := s.getAppFn;
cond (!tFn.isMVar && !sFn.isMVar) (pure LBool.undef) $
condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $
condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do
cond (!tFn.isMVar && !sFn.isMVar) k $
condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $
condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $
condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ do
tAssign? ← isAssignable tFn;
sAssign? ← isAssignable sFn;
let assign (t s : Expr) : MetaM LBool := toLBoolM $ processAssignment whnf isDefEq synthesizePending t s;
let assign (t s : Expr) : MetaM Bool := processAssignment whnf isDefEq synthesizePending t s;
cond (tAssign? && !sAssign?) (assign t s) $
cond (!tAssign? && sAssign?) (assign s t) $
cond (!tAssign? && !sAssign?)
(if tFn.isMVar || sFn.isMVar then do
ctx ← read;
if ctx.config.isDefEqStuckEx then throwEx $ Exception.isDefEqStuck t s
else pure LBool.false
else pure LBool.undef) $ do
else pure false
else k) $ do
-- Both `t` and `s` are terms of the form `?m ...`
tMVarDecl ← getMVarDecl tFn.mvarId!;
sMVarDecl ← getMVarDecl sFn.mvarId!;
@ -912,5 +912,80 @@ do decl ← getLocalDecl fvarId;
cond (!s.isApp && t.isApp && tMVarDecl.lctx.isSubPrefixOf sMVarDecl.lctx) (assign s t) $
assign t s
@[specialize] private def isDefEqProofIrrel
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(t s : Expr)
(k : MetaM Bool) -- continuation when `isDefEqQuick` could not decide
: MetaM Bool :=
do tType ← inferTypeAux whnf t;
condM (isPropAux whnf tType)
(do sType ← inferTypeAux whnf s; isDefEq tType sType)
k
@[specialize] private partial def whnfCoreAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(e : Expr) : MetaM Expr :=
Lean.whnfCore getConstNoEx isAuxDef? whnf (inferTypeAux whnf) isDefEq getLocalDecl getExprMVarAssignment e
@[specialize] private partial def isDefEqWHNF
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(t s : Expr)
(k : Expr → Expr → MetaM Bool) : MetaM Bool :=
do t' ← whnfCoreAux whnf isDefEq synthesizePending t;
s' ← whnfCoreAux whnf isDefEq synthesizePending s;
if t == t' && s == s' then
k t s
else
isDefEqQuick whnf isDefEq synthesizePending t' s' $ k t' s'
@[specialize] private def unstuckMVar
(whnf : Expr → MetaM Expr)
(synthesizePending : Expr → MetaM Bool)
(e : Expr)
(successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool :=
do s? ← getStuckMVar getConst whnf e;
match s? with
| some s =>
condM (synthesizePending s)
(do e ← instantiateMVars e; successK e)
failK
| none => failK
@[specialize] private def isDefEqOnFailure
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(t s : Expr) : MetaM Bool :=
unstuckMVar whnf synthesizePending t (fun t => isDefEq t s) $
unstuckMVar whnf synthesizePending s (fun s => isDefEq t s) $
pure false
@[specialize] partial def isExprDefEqAux
(whnf : Expr → MetaM Expr)
(synthesizePending : Expr → MetaM Bool)
: Expr → Expr → MetaM Bool
| t, s => do
trace! `Meta.isDefEq.step (t ++ " =?= " ++ s);
isDefEqQuick whnf isExprDefEqAux synthesizePending t s $
isDefEqProofIrrel whnf isExprDefEqAux t s $
isDefEqWHNF whnf isExprDefEqAux synthesizePending t s $ fun t s =>
isDefEqOffset whnf isExprDefEqAux t s $
isDefEqDelta whnf isExprDefEqAux synthesizePending t s $
condM (isDefEqEta whnf isExprDefEqAux t s <||> isDefEqEta whnf isExprDefEqAux s t) (pure true) $
match t, s with
| Expr.const _ us _, Expr.const _ vs _ => isListLevelDefEqAux us vs
| Expr.app _ _ _, Expr.app _ _ _ =>
let tFn := t.getAppFn;
condM (try (isExprDefEqAux tFn s.getAppFn <&&>
isDefEqArgs whnf isExprDefEqAux synthesizePending tFn t.getAppArgs s.getAppArgs))
(pure true)
(isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s)
| _, _ => isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s
end Meta
end Lean

View file

@ -115,8 +115,8 @@ traceCtx `type_context.level_is_def_eq.postponed_step $ do
pure false)
true
private partial def processPostponedAux : Bool → MetaM Bool
| mayPostpone => do
private partial def processPostponedAux : Unit → MetaM Bool
| _ => do
numPostponed ← getNumPostponed;
if numPostponed == 0 then
pure true
@ -130,15 +130,15 @@ private partial def processPostponedAux : Bool → MetaM Bool
if numPostponed' == 0 then
pure true
else if numPostponed' < numPostponed then
processPostponedAux mayPostpone
processPostponedAux ()
else do
trace! `type_context.level_is_def_eq ("no progress solving pending is-def-eq level constraints");
pure mayPostpone
pure false
private def processPostponed (mayPostpone : Bool) : MetaM Bool :=
private def processPostponed : MetaM Bool :=
do numPostponed ← getNumPostponed;
if numPostponed == 0 then pure true
else traceCtx `type_context.level_is_def_eq.postponed $ processPostponedAux mayPostpone
else traceCtx `type_context.level_is_def_eq.postponed $ processPostponedAux ()
private def restore (env : Environment) (mctx : MetavarContext) (postponed : PersistentArray PostponedEntry) : MetaM Unit :=
@ -158,7 +158,7 @@ do s ← get;
modify $ fun s => { postponed := {}, .. s };
catch
(condM x
(condM (processPostponed false)
(condM processPostponed
(pure true)
(do restore env mctx postponed; pure false))
(do restore env mctx postponed; pure false))
@ -167,10 +167,7 @@ do s ← get;
/- Public interface -/
def isLevelDefEq (u v : Level) : MetaM Bool :=
try $ do
r ← isLevelDefEqAux u v;
if !r then pure false
else processPostponed false
try $ isLevelDefEqAux u v
end Meta
end Lean

View file

@ -92,26 +92,27 @@ private partial def isOffset : Expr → Option (Expr × Nat)
@[specialize] def isDefEqOffset
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(s t : Expr) : MetaM LBool :=
(s t : Expr)
(k : MetaM Bool) -- continuation when `isDefEqOffset` could not decide
: MetaM Bool :=
match isOffset s with
| some (s, k₁) => match isOffset t with
| some (t, k₂) => -- s+k₁ =?= t+k₂
toLBoolM $
if k₁ == k₂ then isDefEq s t
else if k₁ < k₂ then isDefEq s (mkCAppB `Nat.add t (mkNatLit $ k₂ - k₁))
else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t
if k₁ == k₂ then isDefEq s t
else if k₁ < k₂ then isDefEq s (mkCAppB `Nat.add t (mkNatLit $ k₂ - k₁))
else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t
| none => match evalNat t with
| some v₂ => -- s+k₁ =?= v₂
if v₂ ≥ k₁ then toLBoolM $ isDefEq s (mkNatLit $ v₂ - k₁) else pure LBool.false
| none => pure LBool.undef
if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure false
| none => k
| none => match evalNat s with
| some v₁ => match isOffset t with
| some (t, k₂) => -- v₁ =?= t+k₂
if v₁ ≥ k₂ then toLBoolM $ isDefEq s (mkNatLit $ v₁ - k₂) else pure LBool.false
if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure false
| none => match evalNat t with
| some v₂ => pure (v₁ == v₂).toLBool -- v₁ =?= v₂
| none => pure LBool.undef
| none => pure LBool.undef
| some v₂ => pure (v₁ == v₂) -- v₁ =?= v₂
| none => k
| none => k
end Meta
end Lean

View file

@ -11,7 +11,7 @@ import Init.Lean.Meta.Basic
namespace Lean
namespace Meta
private def isAuxDef? (constName : Name) : MetaM Bool :=
def isAuxDef? (constName : Name) : MetaM Bool :=
do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName)
@[specialize] def unfoldDefinitionAux {α}