feat: add isDefEqAux
`isDefEq` specialization is currently disabled because implementation is producing 150k lines of code. It seems the CPS trick I am using is producing a code explosion.
This commit is contained in:
parent
235ef740e4
commit
da8f9806a8
5 changed files with 142 additions and 68 deletions
|
|
@ -38,8 +38,9 @@ private partial def auxFixpoint : MetaOp → Expr → Expr → MetaM Expr
|
|||
match op with
|
||||
| whnfOp => whnfAux inferType isDefEq synthPending e₁
|
||||
| inferTypeOp => inferTypeAux whnf e₁
|
||||
-- | isDefEqOp => boolToExpr <$> isExprDefEqAux whnf synthPending e₁ e₂
|
||||
| isDefEqOp => boolToExpr <$> pure false
|
||||
| synthPendingOp => boolToExpr <$> pure false -- TODO
|
||||
| isDefEqOp => boolToExpr <$> pure false -- TODO
|
||||
|
||||
def whnf (e : Expr) : MetaM Expr :=
|
||||
auxFixpoint whnfOp e e
|
||||
|
|
@ -48,7 +49,7 @@ def inferType (e : Expr) : MetaM Expr :=
|
|||
auxFixpoint inferTypeOp e e
|
||||
|
||||
def isDefEq (e₁ e₂ : Expr) : MetaM Bool :=
|
||||
exprToBool <$> auxFixpoint isDefEqOp e₁ e₂
|
||||
try $ exprToBool <$> auxFixpoint isDefEqOp e₁ e₂
|
||||
/- =========================================== -/
|
||||
/- END OF BIG HACK -/
|
||||
/- =========================================== -/
|
||||
|
|
|
|||
|
|
@ -377,10 +377,10 @@ structure State :=
|
|||
|
||||
abbrev CheckAssignmentM := ReaderT Context (EStateM Exception State)
|
||||
|
||||
def findCached (e : Expr) : CheckAssignmentM (Option Expr) :=
|
||||
private def findCached (e : Expr) : CheckAssignmentM (Option Expr) :=
|
||||
do s ← get; pure $ s.cache.find e
|
||||
|
||||
def cache (e r : Expr) : CheckAssignmentM Unit :=
|
||||
private def cache (e r : Expr) : CheckAssignmentM Unit :=
|
||||
modify $ fun s => { cache := s.cache.insert e r, .. s }
|
||||
|
||||
instance : MonadCache Expr Expr CheckAssignmentM :=
|
||||
|
|
@ -748,22 +748,22 @@ match t.getAppFn, s.getAppFn with
|
|||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
(t s : Expr) : MetaM LBool :=
|
||||
do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s;
|
||||
let isDefEqLeft (fn : Name) (t s : Expr) : MetaM LBool := do {
|
||||
(t s : Expr)
|
||||
(k : MetaM Bool) -- continuation when `isDefEqDelta` could not decide
|
||||
: MetaM Bool :=
|
||||
do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do {
|
||||
trace! `Meta.isDefEq.delta.unfoldLeft fn;
|
||||
toLBoolM $ isDefEq t s
|
||||
isDefEq t s
|
||||
};
|
||||
let isDefEqRight (fn : Name) (t s : Expr) : MetaM LBool := do {
|
||||
let isDefEqRight (fn : Name) (t s : Expr) : MetaM Bool := do {
|
||||
trace! `Meta.isDefEq.delta.unfoldRight fn;
|
||||
toLBoolM $ isDefEq t s
|
||||
isDefEq t s
|
||||
};
|
||||
let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM LBool := do {
|
||||
let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM Bool := do {
|
||||
trace! `Meta.isDefEq.delta.unfoldLeftRight fn;
|
||||
toLBoolM $ isDefEq t s
|
||||
isDefEq t s
|
||||
};
|
||||
let isListLevelDefEqL (us vs : List Level) : MetaM LBool := toLBoolM $ isListLevelDefEqAux us vs;
|
||||
let unfold (e failK successK) : MetaM LBool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK;
|
||||
let unfold (e failK successK) : MetaM Bool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK;
|
||||
let tryHeuristic : MetaM Bool :=
|
||||
/- Try to solve `f a₁ ... aₙ =?= f b₁ ... bₙ` by solving `a₁ =?= b₁, ..., aₙ =?= bₙ` -/
|
||||
let tFn := t.getAppFn;
|
||||
|
|
@ -776,32 +776,32 @@ do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s;
|
|||
tInfo? ← isDeltaCandidate t.getAppFn;
|
||||
sInfo? ← isDeltaCandidate s.getAppFn;
|
||||
match tInfo?, sInfo? with
|
||||
| none, none => pure LBool.undef
|
||||
| some tInfo, none => unfold t (pure LBool.undef) $ fun t => isDefEqLeft tInfo.name t s
|
||||
| none, some sInfo => unfold s (pure LBool.undef) $ fun s => isDefEqRight sInfo.name t s
|
||||
| none, none => k
|
||||
| some tInfo, none => unfold t k $ fun t => isDefEqLeft tInfo.name t s
|
||||
| none, some sInfo => unfold s k $ fun s => isDefEqRight sInfo.name t s
|
||||
| some tInfo, some sInfo =>
|
||||
let isDefEqLeft (t s) := isDefEqLeft tInfo.name t s;
|
||||
let isDefEqRight (t s) := isDefEqRight sInfo.name t s;
|
||||
let isDefEqLeftRight (t s) := isDefEqLeftRight tInfo.name t s;
|
||||
if tInfo.name == sInfo.name then
|
||||
match t, s with
|
||||
| Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqL ls₁ ls₂
|
||||
| Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAux ls₁ ls₂
|
||||
| Expr.app _ _ _, Expr.app _ _ _ =>
|
||||
condM tryHeuristic
|
||||
(pure LBool.true)
|
||||
(pure true)
|
||||
(unfold t
|
||||
(unfold s (pure LBool.undef) (fun s => isDefEqRight t s))
|
||||
(unfold s (pure false) (fun s => isDefEqRight t s))
|
||||
(fun t => unfold s (isDefEqLeft t s) (fun s => isDefEqLeftRight t s)))
|
||||
| _, _ => pure LBool.false
|
||||
| _, _ => pure false
|
||||
else
|
||||
let unfoldComparingHeads : Unit → MetaM LBool := fun _ =>
|
||||
let unfoldComparingHeads : Unit → MetaM Bool := fun _ =>
|
||||
/-
|
||||
- If headSymbol (unfold t) == headSymbol s, then unfold t
|
||||
- If headSymbol (unfold s) == headSymbol t, then unfold s
|
||||
- Otherwise unfold t and s if possible. -/
|
||||
unfold t
|
||||
(unfold s
|
||||
(pure LBool.undef) -- `t` and `s` failed to be unfolded
|
||||
k -- `t` and `s` failed to be unfolded
|
||||
(fun s => isDefEqRight t s))
|
||||
(fun tNew =>
|
||||
if sameHeadSymbol tNew s then
|
||||
|
|
@ -810,7 +810,7 @@ do let isDefEqL (t s : Expr) : MetaM LBool := toLBoolM $ isDefEq t s;
|
|||
unfold s
|
||||
(isDefEqLeft tNew s)
|
||||
(fun sNew => if sameHeadSymbol t sNew then isDefEqRight t sNew else isDefEqLeftRight tNew sNew));
|
||||
let kernelLikeUnfolding : Unit → MetaM LBool := fun _ =>
|
||||
let kernelLikeUnfolding : Unit → MetaM Bool := fun _ =>
|
||||
if !t.hasExprMVar && !s.hasExprMVar then
|
||||
/- If `t` and `s` do not contain metavariables,
|
||||
we simulate strategy used in the kernel. -/
|
||||
|
|
@ -862,38 +862,38 @@ do decl ← getLocalDecl fvarId;
|
|||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
: Expr → Expr → MetaM LBool
|
||||
| Expr.lit l₁ _, Expr.lit l₂ _ => pure (l₁ == l₂).toLBool
|
||||
| Expr.sort u _, Expr.sort v _ => toLBoolM $ isLevelDefEqAux u v
|
||||
| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _) => toLBoolM $ isDefEqBinding whnf isDefEq t s
|
||||
| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _) => toLBoolM $ isDefEqBinding whnf isDefEq t s
|
||||
| Expr.mdata _ t _, s => isDefEqQuick t s
|
||||
| t, Expr.mdata _ s _ => isDefEqQuick t s
|
||||
| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _ =>
|
||||
: Expr → Expr → MetaM Bool → MetaM Bool
|
||||
| Expr.lit l₁ _, Expr.lit l₂ _, _ => pure (l₁ == l₂)
|
||||
| Expr.sort u _, Expr.sort v _, _ => isLevelDefEqAux u v
|
||||
| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _), _ => isDefEqBinding whnf isDefEq t s
|
||||
| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _), _ => isDefEqBinding whnf isDefEq t s
|
||||
| Expr.mdata _ t _, s, k => isDefEqQuick t s k
|
||||
| t, Expr.mdata _ s _, k => isDefEqQuick t s k
|
||||
| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _, k =>
|
||||
condM (isLetFVar fvarId₁ <||> isLetFVar fvarId₂)
|
||||
(pure LBool.undef)
|
||||
(pure (fvarId₁ == fvarId₂).toLBool)
|
||||
| t, s =>
|
||||
cond (t == s) (pure LBool.true) $
|
||||
cond (etaEq t s || etaEq s t) (pure LBool.true) $ -- t =?= (fun xs => t xs)
|
||||
k
|
||||
(pure (fvarId₁ == fvarId₂))
|
||||
| t, s, k =>
|
||||
cond (t == s) (pure true) $
|
||||
cond (etaEq t s || etaEq s t) (pure true) $ -- t =?= (fun xs => t xs)
|
||||
let tFn := t.getAppFn;
|
||||
let sFn := s.getAppFn;
|
||||
cond (!tFn.isMVar && !sFn.isMVar) (pure LBool.undef) $
|
||||
condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
|
||||
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $
|
||||
condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
|
||||
condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do
|
||||
cond (!tFn.isMVar && !sFn.isMVar) k $
|
||||
condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $
|
||||
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $
|
||||
condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $
|
||||
condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ do
|
||||
tAssign? ← isAssignable tFn;
|
||||
sAssign? ← isAssignable sFn;
|
||||
let assign (t s : Expr) : MetaM LBool := toLBoolM $ processAssignment whnf isDefEq synthesizePending t s;
|
||||
let assign (t s : Expr) : MetaM Bool := processAssignment whnf isDefEq synthesizePending t s;
|
||||
cond (tAssign? && !sAssign?) (assign t s) $
|
||||
cond (!tAssign? && sAssign?) (assign s t) $
|
||||
cond (!tAssign? && !sAssign?)
|
||||
(if tFn.isMVar || sFn.isMVar then do
|
||||
ctx ← read;
|
||||
if ctx.config.isDefEqStuckEx then throwEx $ Exception.isDefEqStuck t s
|
||||
else pure LBool.false
|
||||
else pure LBool.undef) $ do
|
||||
else pure false
|
||||
else k) $ do
|
||||
-- Both `t` and `s` are terms of the form `?m ...`
|
||||
tMVarDecl ← getMVarDecl tFn.mvarId!;
|
||||
sMVarDecl ← getMVarDecl sFn.mvarId!;
|
||||
|
|
@ -912,5 +912,80 @@ do decl ← getLocalDecl fvarId;
|
|||
cond (!s.isApp && t.isApp && tMVarDecl.lctx.isSubPrefixOf sMVarDecl.lctx) (assign s t) $
|
||||
assign t s
|
||||
|
||||
@[specialize] private def isDefEqProofIrrel
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(t s : Expr)
|
||||
(k : MetaM Bool) -- continuation when `isDefEqQuick` could not decide
|
||||
: MetaM Bool :=
|
||||
do tType ← inferTypeAux whnf t;
|
||||
condM (isPropAux whnf tType)
|
||||
(do sType ← inferTypeAux whnf s; isDefEq tType sType)
|
||||
k
|
||||
|
||||
@[specialize] private partial def whnfCoreAux
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
(e : Expr) : MetaM Expr :=
|
||||
Lean.whnfCore getConstNoEx isAuxDef? whnf (inferTypeAux whnf) isDefEq getLocalDecl getExprMVarAssignment e
|
||||
|
||||
@[specialize] private partial def isDefEqWHNF
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
(t s : Expr)
|
||||
(k : Expr → Expr → MetaM Bool) : MetaM Bool :=
|
||||
do t' ← whnfCoreAux whnf isDefEq synthesizePending t;
|
||||
s' ← whnfCoreAux whnf isDefEq synthesizePending s;
|
||||
if t == t' && s == s' then
|
||||
k t s
|
||||
else
|
||||
isDefEqQuick whnf isDefEq synthesizePending t' s' $ k t' s'
|
||||
|
||||
@[specialize] private def unstuckMVar
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
(e : Expr)
|
||||
(successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool :=
|
||||
do s? ← getStuckMVar getConst whnf e;
|
||||
match s? with
|
||||
| some s =>
|
||||
condM (synthesizePending s)
|
||||
(do e ← instantiateMVars e; successK e)
|
||||
failK
|
||||
| none => failK
|
||||
|
||||
@[specialize] private def isDefEqOnFailure
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
(t s : Expr) : MetaM Bool :=
|
||||
unstuckMVar whnf synthesizePending t (fun t => isDefEq t s) $
|
||||
unstuckMVar whnf synthesizePending s (fun s => isDefEq t s) $
|
||||
pure false
|
||||
|
||||
@[specialize] partial def isExprDefEqAux
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(synthesizePending : Expr → MetaM Bool)
|
||||
: Expr → Expr → MetaM Bool
|
||||
| t, s => do
|
||||
trace! `Meta.isDefEq.step (t ++ " =?= " ++ s);
|
||||
isDefEqQuick whnf isExprDefEqAux synthesizePending t s $
|
||||
isDefEqProofIrrel whnf isExprDefEqAux t s $
|
||||
isDefEqWHNF whnf isExprDefEqAux synthesizePending t s $ fun t s =>
|
||||
isDefEqOffset whnf isExprDefEqAux t s $
|
||||
isDefEqDelta whnf isExprDefEqAux synthesizePending t s $
|
||||
condM (isDefEqEta whnf isExprDefEqAux t s <||> isDefEqEta whnf isExprDefEqAux s t) (pure true) $
|
||||
match t, s with
|
||||
| Expr.const _ us _, Expr.const _ vs _ => isListLevelDefEqAux us vs
|
||||
| Expr.app _ _ _, Expr.app _ _ _ =>
|
||||
let tFn := t.getAppFn;
|
||||
condM (try (isExprDefEqAux tFn s.getAppFn <&&>
|
||||
isDefEqArgs whnf isExprDefEqAux synthesizePending tFn t.getAppArgs s.getAppArgs))
|
||||
(pure true)
|
||||
(isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s)
|
||||
| _, _ => isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s
|
||||
|
||||
end Meta
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -115,8 +115,8 @@ traceCtx `type_context.level_is_def_eq.postponed_step $ do
|
|||
pure false)
|
||||
true
|
||||
|
||||
private partial def processPostponedAux : Bool → MetaM Bool
|
||||
| mayPostpone => do
|
||||
private partial def processPostponedAux : Unit → MetaM Bool
|
||||
| _ => do
|
||||
numPostponed ← getNumPostponed;
|
||||
if numPostponed == 0 then
|
||||
pure true
|
||||
|
|
@ -130,15 +130,15 @@ private partial def processPostponedAux : Bool → MetaM Bool
|
|||
if numPostponed' == 0 then
|
||||
pure true
|
||||
else if numPostponed' < numPostponed then
|
||||
processPostponedAux mayPostpone
|
||||
processPostponedAux ()
|
||||
else do
|
||||
trace! `type_context.level_is_def_eq ("no progress solving pending is-def-eq level constraints");
|
||||
pure mayPostpone
|
||||
pure false
|
||||
|
||||
private def processPostponed (mayPostpone : Bool) : MetaM Bool :=
|
||||
private def processPostponed : MetaM Bool :=
|
||||
do numPostponed ← getNumPostponed;
|
||||
if numPostponed == 0 then pure true
|
||||
else traceCtx `type_context.level_is_def_eq.postponed $ processPostponedAux mayPostpone
|
||||
else traceCtx `type_context.level_is_def_eq.postponed $ processPostponedAux ()
|
||||
|
||||
|
||||
private def restore (env : Environment) (mctx : MetavarContext) (postponed : PersistentArray PostponedEntry) : MetaM Unit :=
|
||||
|
|
@ -158,7 +158,7 @@ do s ← get;
|
|||
modify $ fun s => { postponed := {}, .. s };
|
||||
catch
|
||||
(condM x
|
||||
(condM (processPostponed false)
|
||||
(condM processPostponed
|
||||
(pure true)
|
||||
(do restore env mctx postponed; pure false))
|
||||
(do restore env mctx postponed; pure false))
|
||||
|
|
@ -167,10 +167,7 @@ do s ← get;
|
|||
/- Public interface -/
|
||||
|
||||
def isLevelDefEq (u v : Level) : MetaM Bool :=
|
||||
try $ do
|
||||
r ← isLevelDefEqAux u v;
|
||||
if !r then pure false
|
||||
else processPostponed false
|
||||
try $ isLevelDefEqAux u v
|
||||
|
||||
end Meta
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -92,26 +92,27 @@ private partial def isOffset : Expr → Option (Expr × Nat)
|
|||
@[specialize] def isDefEqOffset
|
||||
(whnf : Expr → MetaM Expr)
|
||||
(isDefEq : Expr → Expr → MetaM Bool)
|
||||
(s t : Expr) : MetaM LBool :=
|
||||
(s t : Expr)
|
||||
(k : MetaM Bool) -- continuation when `isDefEqOffset` could not decide
|
||||
: MetaM Bool :=
|
||||
match isOffset s with
|
||||
| some (s, k₁) => match isOffset t with
|
||||
| some (t, k₂) => -- s+k₁ =?= t+k₂
|
||||
toLBoolM $
|
||||
if k₁ == k₂ then isDefEq s t
|
||||
else if k₁ < k₂ then isDefEq s (mkCAppB `Nat.add t (mkNatLit $ k₂ - k₁))
|
||||
else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t
|
||||
if k₁ == k₂ then isDefEq s t
|
||||
else if k₁ < k₂ then isDefEq s (mkCAppB `Nat.add t (mkNatLit $ k₂ - k₁))
|
||||
else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t
|
||||
| none => match evalNat t with
|
||||
| some v₂ => -- s+k₁ =?= v₂
|
||||
if v₂ ≥ k₁ then toLBoolM $ isDefEq s (mkNatLit $ v₂ - k₁) else pure LBool.false
|
||||
| none => pure LBool.undef
|
||||
if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure false
|
||||
| none => k
|
||||
| none => match evalNat s with
|
||||
| some v₁ => match isOffset t with
|
||||
| some (t, k₂) => -- v₁ =?= t+k₂
|
||||
if v₁ ≥ k₂ then toLBoolM $ isDefEq s (mkNatLit $ v₁ - k₂) else pure LBool.false
|
||||
if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure false
|
||||
| none => match evalNat t with
|
||||
| some v₂ => pure (v₁ == v₂).toLBool -- v₁ =?= v₂
|
||||
| none => pure LBool.undef
|
||||
| none => pure LBool.undef
|
||||
| some v₂ => pure (v₁ == v₂) -- v₁ =?= v₂
|
||||
| none => k
|
||||
| none => k
|
||||
|
||||
end Meta
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import Init.Lean.Meta.Basic
|
|||
namespace Lean
|
||||
namespace Meta
|
||||
|
||||
private def isAuxDef? (constName : Name) : MetaM Bool :=
|
||||
def isAuxDef? (constName : Name) : MetaM Bool :=
|
||||
do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName)
|
||||
|
||||
@[specialize] def unfoldDefinitionAux {α}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue