From 98407798aff8787c456d9a24364984767b483830 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 25 Jan 2022 18:41:07 -0800 Subject: [PATCH] fix: `mkEquationsFor` at `Match/MatchEqs.lean` closes #974 --- src/Lean/Meta/Match/MatchEqs.lean | 174 +++++++++++++++++++-------- tests/lean/974.lean | 15 +++ tests/lean/974.lean.expected.out | 7 ++ tests/lean/eqValue.lean | 11 ++ tests/lean/eqValue.lean.expected.out | 4 + 5 files changed, 164 insertions(+), 47 deletions(-) create mode 100644 tests/lean/974.lean create mode 100644 tests/lean/974.lean.expected.out create mode 100644 tests/lean/eqValue.lean create mode 100644 tests/lean/eqValue.lean.expected.out diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 44e1a6ccb1..b1f7262b49 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -7,6 +7,8 @@ import Lean.Meta.Match.Match import Lean.Meta.Tactic.Apply import Lean.Meta.Tactic.Delta import Lean.Meta.Tactic.SplitIf +import Lean.Meta.Tactic.Injection +import Lean.Meta.Tactic.Contradiction namespace Lean.Meta @@ -122,51 +124,125 @@ where Option.isSome <| type.find? fun e => e.isAppOfArity ``namedPattern 4 && e.appArg! == h -/-- - Simplify/filter hypotheses that ensure that a match alternative does not match the previous ones. - Remark: if there is no overlaping between the alternatives, the empty array is returned. -/ -private partial def simpHs (hs : Array Expr) (numPatterns : Nat) : MetaM (Array Expr) := do - hs.filterMapM fun h => forallTelescope h fun ys _ => do - let xs := ys[:ys.size - numPatterns].toArray - let eqs ← ys[ys.size - numPatterns : ys.size].toArray.mapM inferType - if let some eqsNew ← simpEqs eqs *> get |>.run |>.run' #[] then - let newH ← eqsNew.foldrM (init := mkConst ``False) mkArrow - let xs ← xs.filterM fun x => dependsOn newH x.fvarId! - return some (← mkForallFVars xs newH) - else - none -where - simpEq (lhs : Expr) (rhs : Expr) : OptionT (StateRefT (Array Expr) MetaM) Unit := do - if isMatchValue lhs && isMatchValue rhs then - unless (← isDefEq lhs rhs) do - failure - else if rhs.isFVar then - -- Ignore case since it matches anything - pure () - else match lhs.arrayLit?, rhs.arrayLit? with - | some (_, lhsArgs), some (_, rhsArgs) => - if lhsArgs.length != rhsArgs.length then - failure - else - for lhsArg in lhsArgs, rhsArg in rhsArgs do - simpEq lhsArg rhsArg - | _, _ => - match toCtorIfLit lhs |>.constructorApp? (← getEnv), toCtorIfLit rhs |>.constructorApp? (← getEnv) with - | some (lhsCtor, lhsArgs), some (rhsCtor, rhsArgs) => - if lhsCtor.name == rhsCtor.name then - for lhsArg in lhsArgs[lhsCtor.numParams:], rhsArg in rhsArgs[lhsCtor.numParams:] do - simpEq lhsArg rhsArg - else - failure - | _, _ => - let newEq ← mkEq lhs rhs - modify fun eqs => eqs.push newEq +namespace SimpH - simpEqs (eqs : Array Expr) : OptionT (StateRefT (Array Expr) MetaM) Unit := do - eqs.forM fun eq => - match eq.eq? with - | some (_, lhs, rhs) => simpEq lhs rhs - | _ => throwError "failed to generate equality theorems for 'match', equality expected{indentExpr eq}" +/-- + State for the equational theorem hypothesis simplifier. + + Recall that each equation contains additional hypotheses to ensure the associated case does not taken by previous cases. + We have one hypothesis for each previous case. + + Each hypothesis is of the form `forall xs, eqs → False` + + We use tactics to minimize code duplication. +-/ +structure State where + mvarId : MVarId -- Goal representing the hypothesis + xs : List FVarId -- Pattern variables for a previous case + eqs : List FVarId -- Equations to be processed + eqsNew : List FVarId := [] -- Simplied (already processed) equations + +abbrev M := StateRefT State MetaM + +/-- + Apply the given substitution to `fvarIds`. + This is an auxiliary method for `substRHS`. +-/ +private def applySubst (s : FVarSubst) (fvarIds : List FVarId) : List FVarId := + fvarIds.filterMap fun fvarId => match s.apply (mkFVar fvarId) with + | Expr.fvar fvarId .. => some fvarId + | _ => none + +/-- + Given an equation of the form `lhs = rhs` where `rhs` is variable in `xs`, + the replace it everywhere with `lhs`. +-/ +private def substRHS (eq : FVarId) (rhs : FVarId) : M Unit := do + assert! (← get).xs.contains rhs + let (subst, mvarId) ← substCore (← get).mvarId eq (symm := true) + modify fun s => { s with + mvarId, + xs := applySubst subst (s.xs.erase rhs) + eqs := applySubst subst s.eqs + eqsNew := applySubst subst s.eqsNew + } + +private def isDone : M Bool := + return (← get).eqs.isEmpty + +/-- + Auxiliary tactic that tries to replace as many variables as possible and then apply `contradiction`. + We use it to discard redundant hypotheses. +-/ +private def trySubstVarsAndContradiction (mvarId : MVarId) : MetaM Bool := + commitWhen do + let mvarId ← substVars mvarId + contradictionCore mvarId {} + +private def processNextEq : M Bool := do + let s ← get + withMVarContext s.mvarId do + -- If the goal is contradictory, the hypothesis is redundant. + if (← contradictionCore s.mvarId {}) then + return false + if let eq :: eqs := s.eqs then + modify fun s => { s with eqs } + let eqType ← inferType (mkFVar eq) + -- See `substRHS`. Recall that if `rhs` is a variable then if must be in `s.xs` + if let some (_, lhs, rhs) ← matchEq? eqType then + if rhs.isFVar then + substRHS eq rhs.fvarId! + return true + if let some (α, lhs, β, rhs) ← matchHEq? eqType then + -- Try to convert `HEq` into `Eq` + if (← isDefEq α β) then + let (eqNew, mvarId) ← heqToEq s.mvarId eq (tryToClear := true) + modify fun s => { s with mvarId, eqs := eqNew :: s.eqs } + return true + -- If it is not possible, we try to show the hypothesis is redundant by substituting even variables that are not at `s.xs`, and then use contradiction. + else if (← trySubstVarsAndContradiction s.mvarId) then + return false + try + -- Try to simplify equation using `injection` tactic. + match (← injection s.mvarId eq) with + | InjectionResult.solved => return false + | InjectionResult.subgoal mvarId eqNews .. => + modify fun s => { s with mvarId, eqs := eqNews.toList ++ s.eqs } + catch _ => + modify fun s => { s with eqsNew := eq :: s.eqsNew } + return true + +partial def go : M Bool := do + if (← isDone) then + return true + else if (← processNextEq) then + go + else + return false + +end SimpH + +/-- + Auxiliary method for simplifying equational theorem hypotheses. + + Recall that each equation contains additional hypotheses to ensure the associated case does not taken by previous cases. + We have one hypothesis for each previous case. +-/ +private partial def simpH? (h : Expr) (numEqs : Nat) : MetaM (Option Expr) := withDefault do + let numVars ← forallTelescope h fun ys _ => pure (ys.size - numEqs) + let mvarId := (← mkFreshExprSyntheticOpaqueMVar h).mvarId! + let (xs, mvarId) ← introN mvarId numVars + let (eqs, mvarId) ← introN mvarId numEqs + let (r, s) ← SimpH.go |>.run { mvarId, xs := xs.toList, eqs := eqs.toList } + if r then + withMVarContext s.mvarId do + let vars := (s.xs ++ s.eqsNew.reverse).toArray.map mkFVar + let r ← mkForallFVars vars (mkConst ``False) + trace[Meta.Match.matchEqs] "simplified hypothesis{indentExpr r}" + check r + return some r + else + return none private def substSomeVar (mvarId : MVarId) : MetaM (Array MVarId) := withMVarContext mvarId do for localDecl in (← getLCtx) do @@ -330,15 +406,19 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := let patterns := altResultType.getAppArgs let mut hs := #[] for notAlt in notAlts do - hs := hs.push (← instantiateForall notAlt patterns) - hs ← simpHs hs patterns.size + let h ← instantiateForall notAlt patterns + if let some h ← simpH? h patterns.size then + hs := hs.push h trace[Meta.Match.matchEqs] "hs: {hs}" let splitterAltType ← mkForallFVars ys (← hs.foldrM (init := altResultType) mkArrow) let splitterAltNumParam := hs.size + ys.size -- Create a proposition for representing terms that do not match `patterns` let mut notAlt := mkConst ``False for discr in discrs.toArray.reverse, pattern in patterns.reverse do - notAlt ← mkArrow (← mkEq discr pattern) notAlt + if (← isDefEq (← inferType discr) (← inferType pattern)) then + notAlt ← mkArrow (← mkEq discr pattern) notAlt + else + notAlt ← mkArrow (← mkHEq discr pattern) notAlt notAlt ← mkForallFVars (discrs ++ ys) notAlt let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts) let rhs := mkAppN alt rhsArgs diff --git a/tests/lean/974.lean b/tests/lean/974.lean new file mode 100644 index 0000000000..45a00cd0b0 --- /dev/null +++ b/tests/lean/974.lean @@ -0,0 +1,15 @@ +inductive Formula : Nat → Type u +| bot : Formula n +| imp (f₁ f₂ : Formula n ) : Formula n +| all (f : Formula (n+1)) : Formula n + +def Formula.count_quantifiers : {n:Nat} → Formula n → Nat +| _, imp f₁ f₂ => f₁.count_quantifiers + f₂.count_quantifiers +| _, all f => f.count_quantifiers + 1 +| _, _ => 0 + +attribute [simp] Formula.count_quantifiers + +#check @Formula.count_quantifiers._eq_1 +#check @Formula.count_quantifiers._eq_2 +#check @Formula.count_quantifiers._eq_3 diff --git a/tests/lean/974.lean.expected.out b/tests/lean/974.lean.expected.out new file mode 100644 index 0000000000..97674a3ada --- /dev/null +++ b/tests/lean/974.lean.expected.out @@ -0,0 +1,7 @@ +Formula.count_quantifiers._eq_1 : ∀ (x : Nat) (f₁ f₂ : Formula x), + Formula.count_quantifiers (Formula.imp f₁ f₂) = Formula.count_quantifiers f₁ + Formula.count_quantifiers f₂ +Formula.count_quantifiers._eq_2 : ∀ (x : Nat) (f : Formula (x + 1)), + Formula.count_quantifiers (Formula.all f) = Formula.count_quantifiers f + 1 +Formula.count_quantifiers._eq_3 : ∀ (x : Nat) (x_1 : Formula x), + (∀ (f₁ f₂ : Formula x), x_1 = Formula.imp f₁ f₂ → False) → + (∀ (f : Formula (x + 1)), x_1 = Formula.all f → False) → Formula.count_quantifiers x_1 = 0 diff --git a/tests/lean/eqValue.lean b/tests/lean/eqValue.lean new file mode 100644 index 0000000000..2ccbcc4bfc --- /dev/null +++ b/tests/lean/eqValue.lean @@ -0,0 +1,11 @@ +@[simp] def f (x : Nat) : Nat := + match x with + | 0 => 1 + | 100 => 2 + | 1000 => 3 + | x+1 => f x + +#check f._eq_1 +#check f._eq_2 +#check f._eq_3 +#check f._eq_4 diff --git a/tests/lean/eqValue.lean.expected.out b/tests/lean/eqValue.lean.expected.out new file mode 100644 index 0000000000..d920bc916e --- /dev/null +++ b/tests/lean/eqValue.lean.expected.out @@ -0,0 +1,4 @@ +f._eq_1 : f 0 = 1 +f._eq_2 : f 100 = 2 +f._eq_3 : f 1000 = 3 +f._eq_4 : ∀ (x_1 : Nat), (x_1 = 99 → False) → (x_1 = 999 → False) → f (Nat.succ x_1) = f x_1