From f0c31d7e28a342c4ed555404aeb66801b7009ea2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 12 Mar 2022 15:44:52 -0800 Subject: [PATCH] feat: allow `rw` to unfold nonrecursive definitions too --- src/Lean/Elab/Tactic/Rewrite.lean | 2 +- src/Lean/Meta/Eqns.lean | 39 ++++++++++++++++++++++++--- tests/lean/rwEqThms.lean | 9 +++++++ tests/lean/rwEqThms.lean.expected.out | 15 +++++++++++ 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/src/Lean/Elab/Tactic/Rewrite.lean b/src/Lean/Elab/Tactic/Rewrite.lean index eb484bd244..c6dad5da9d 100644 --- a/src/Lean/Elab/Tactic/Rewrite.lean +++ b/src/Lean/Elab/Tactic/Rewrite.lean @@ -46,7 +46,7 @@ def withRWRulesSeq (token : Syntax) (rwRulesSeqStx : Syntax) (x : (symm : Bool) let processId (id : Syntax) : TacticM Unit := do -- Try to get equation theorems for `id` first let declName ← try resolveGlobalConstNoOverload id catch _ => return (← x symm term) - let some eqThms ← getEqnsFor? declName | x symm term + let some eqThms ← getEqnsFor? declName (nonRec := true) | x symm term let rec go : List Name → TacticM Unit | [] => throwError "failed to rewrite using equation theorems for '{declName}'" | eqThm::eqThms => (x symm (mkIdentFrom id eqThm)) <|> go eqThms diff --git a/src/Lean/Meta/Eqns.lean b/src/Lean/Meta/Eqns.lean index 1e4b52b8d2..ecd593baa3 100644 --- a/src/Lean/Meta/Eqns.lean +++ b/src/Lean/Meta/Eqns.lean @@ -58,9 +58,29 @@ builtin_initialize eqnsExt : EnvExtension EqnsExtState ← registerEnvExtension (pure {}) /-- - Return equation theorems for the given declaration. + Simple equation theorem for nonrecursive definitions. -/ -def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do +private def mkSimpleEqThm (declName : Name) : MetaM (Option Name) := do + if let some (.defnInfo info) := (← getEnv).find? declName then + lambdaTelescope info.value fun xs body => do + let lhs := mkAppN (mkConst info.name <| info.levelParams.map mkLevelParam) xs + let type ← mkForallFVars xs (← mkEq lhs body) + let value ← mkLambdaFVars xs (← mkEqRefl lhs) + let name := mkPrivateName (← getEnv) declName ++ `_eq_1 + addDecl <| Declaration.thmDecl { + name, type, value + levelParams := info.levelParams + } + return some name + else + return none + +/-- + Return equation theorems for the given declaration. + By default, we not create equation theorems for nonrecursive definitions. + You can use `nonRec := true` to override this behavior, a dummy `rfl` proof is created on the fly. +-/ +def getEqnsFor? (declName : Name) (nonRec := false) : MetaM (Option (Array Name)) := do if let some eqs := eqnsExt.getState (← getEnv) |>.map.find? declName then return some eqs else if (← shouldGenerateEqnThms declName) then @@ -68,6 +88,11 @@ def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do if let some r ← f declName then modifyEnv fun env => eqnsExt.modifyState env fun s => { s with map := s.map.insert declName r } return some r + if nonRec then + let some eqThm ← mkSimpleEqThm declName | return none + let r := #[eqThm] + modifyEnv fun env => eqnsExt.modifyState env fun s => { s with map := s.map.insert declName r } + return some r return none def GetUnfoldEqnFn := Name → MetaM (Option Name) @@ -104,11 +129,19 @@ def registerGetUnfoldEqnFn (f : GetUnfoldEqnFn) : IO Unit := do throw (IO.userError "failed to register equation getter, this kind of extension can only be registered during initialization") getUnfoldEqnFnsRef.modify (f :: ·) -def getUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do +/-- + Return a "unfold" theorem for the given declaration. + By default, we not create unfold theorems for nonrecursive definitions. + You can use `nonRec := true` to override this behavior. +-/ +def getUnfoldEqnFor? (declName : Name) (nonRec := false) : MetaM (Option Name) := do if (← shouldGenerateEqnThms declName) then for f in (← getUnfoldEqnFnsRef.get) do if let some r ← f declName then return some r + if nonRec then + let some #[eqThm] ← getEqnsFor? declName (nonRec := true) | return none + return some eqThm return none end Lean.Meta diff --git a/tests/lean/rwEqThms.lean b/tests/lean/rwEqThms.lean index 218082ae0f..d320382b20 100644 --- a/tests/lean/rwEqThms.lean +++ b/tests/lean/rwEqThms.lean @@ -14,3 +14,12 @@ example {a : α} {as bs : List α} (h : as = bs) : (a::b::as).length = (b::bs).l conv => rhs; rw [List.length] trace_state -- rhs was unfolded rw [h] + +example {a : α} {as bs : List α} (h : as = bs) : id (id ((a::b::as).length)) = (b::bs).length + 1 := by + rw [id] + trace_state + rw [id] + trace_state + rw [List.length, List.length, List.length] + trace_state + rw [h] diff --git a/tests/lean/rwEqThms.lean.expected.out b/tests/lean/rwEqThms.lean.expected.out index 9511f6b048..b16aade4ee 100644 --- a/tests/lean/rwEqThms.lean.expected.out +++ b/tests/lean/rwEqThms.lean.expected.out @@ -25,3 +25,18 @@ b a : α as bs : List α h : as = bs ⊢ List.length as + 1 + 1 = List.length bs + 1 + 1 +α : Type ?u +b a : α +as bs : List α +h : as = bs +⊢ id (List.length (a :: b :: as)) = List.length (b :: bs) + 1 +α : Type ?u +b a : α +as bs : List α +h : as = bs +⊢ List.length (a :: b :: as) = List.length (b :: bs) + 1 +α : Type ?u +b a : α +as bs : List α +h : as = bs +⊢ List.length as + 1 + 1 = List.length bs + 1 + 1