diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 48a14c8f73..679d224ce5 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -487,6 +487,14 @@ where else congrDefault e + simpMatch? (e : Expr) : M (Option Result) := do + let .const declName _ := e.getAppFn + | return none + if let some info ← getMatcherInfo? declName then + simpMatchDiscrs? simp dsimp info e + else + return none + simpApp (e : Expr) : M Result := do let e' ← reduceStep e if e' != e then @@ -494,6 +502,8 @@ where else if isOfNatNatLit e' then -- Recall that we expand "orphan" kernel nat literals `n` into `ofNat n` return { expr := e' } + else if let some r ← simpMatch? e' then + simpLoop r else congr e' diff --git a/src/Lean/Meta/Tactic/Simp/Types.lean b/src/Lean/Meta/Tactic/Simp/Types.lean index 849b7dc268..7d26120e1d 100644 --- a/src/Lean/Meta/Tactic/Simp/Types.lean +++ b/src/Lean/Meta/Tactic/Simp/Types.lean @@ -190,6 +190,46 @@ The resulting proof is built using `congr` and `congrFun` theorems. i := i + 1 return r +/-- +Given a match-application `e` with `MatcherInfo` `info`, return `some result` +if at least of one of the discriminants has been simplified. +-/ +@[specialize] def simpMatchDiscrs? + [Monad m] [MonadLiftT MetaM m] [MonadLiftT IO m] [MonadRef m] [MonadOptions m] [MonadTrace m] [AddMessageContext m] + (simp : Expr → m Result) + (dsimp : Expr → m Expr) + (info : MatcherInfo) (e : Expr) : m (Option Result) := do + let numArgs := e.getAppNumArgs + if numArgs < info.arity then + return none + let prefixSize := info.numParams + 1 /- motive -/ + let n := numArgs - prefixSize + let f := e.extractNumArgs n + let infos := (← getFunInfoNArgs f n).paramInfo + let args := e.getAppArgsN n + let mut r : Result := { expr := f } + let mut modified := false + for i in [0 : info.numDiscrs] do + let arg := args[i]! + if i < infos.size && !infos[i]!.hasFwdDeps then + let argNew ← simp arg + if argNew.expr != arg then modified := true + r ← mkCongr r argNew + else if (← whnfD (← inferType r.expr)).isArrow then + let argNew ← simp arg + if argNew.expr != arg then modified := true + r ← mkCongr r argNew + else + let argNew ← dsimp arg + if argNew != arg then modified := true + r ← mkCongrFun r argNew + unless modified do + return none + for i in [info.numDiscrs : args.size] do + let arg := args[i]! + r ← mkCongrFun r arg + return some r + /-- Helper class for generalizing `mkCongrSimp?` -/ diff --git a/tests/lean/run/simpMatchDiscrIssue.lean b/tests/lean/run/simpMatchDiscrIssue.lean new file mode 100644 index 0000000000..b7cf30ef5e --- /dev/null +++ b/tests/lean/run/simpMatchDiscrIssue.lean @@ -0,0 +1,24 @@ +/-! +Test support for `match`-applications in the simplifier. +The discriminants should be simplified before trying to apply congruence. +In the following example, the term `g (a + )` takes an +exponential amount of time to be simplified the discriminants are not simplified, +and the `match`-application reduced before visiting the alternatives. +-/ + +def myid (x : Nat) := 0 + x + +@[simp] theorem myid_eq : myid x = x := by simp [myid] + +def f (x : Nat) (y z : Nat) : Nat := + match myid x with + | 0 => y + | _+1 => z + +def g (x : Nat) : Nat := + match x with + | 0 => 1 + | a+1 => f x (g a + 1) (g a) + +theorem ex (h : a = 1) : g (a+64) = a := by + simp [g, f, h] diff --git a/tests/lean/simp_trace.lean.expected.out b/tests/lean/simp_trace.lean.expected.out index aa380ca3af..2d4398c90b 100644 --- a/tests/lean/simp_trace.lean.expected.out +++ b/tests/lean/simp_trace.lean.expected.out @@ -34,7 +34,7 @@ Try this: simp (config := { unfoldPartialApp := true }) only [f1, modify, modify | (a, s) => (fun s => set (g s)) a s [Meta.Tactic.simp.rewrite] unfold getThe, getThe Nat s ==> MonadStateOf.get s [Meta.Tactic.simp.rewrite] unfold StateT.get, StateT.get s ==> pure (s, s) -[Meta.Tactic.simp.rewrite] unfold StateT.set, StateT.set (g a) s ==> pure (PUnit.unit, g a) +[Meta.Tactic.simp.rewrite] unfold StateT.set, StateT.set (g s) s ==> pure (PUnit.unit, g s) [Meta.Tactic.simp.rewrite] @eq_self:1000, (fun s => (PUnit.unit, g s)) = fun s => (PUnit.unit, g s) ==> True Try this: simp only [bla, h] [Meta.Tactic.simp.rewrite] unfold bla, bla x ==> match h x with