feat: allow rw to unfold nonrecursive definitions too

This commit is contained in:
Leonardo de Moura 2022-03-12 15:44:52 -08:00
parent c0a72172f1
commit f0c31d7e28
4 changed files with 61 additions and 4 deletions

View file

@ -46,7 +46,7 @@ def withRWRulesSeq (token : Syntax) (rwRulesSeqStx : Syntax) (x : (symm : Bool)
let processId (id : Syntax) : TacticM Unit := do
-- Try to get equation theorems for `id` first
let declName ← try resolveGlobalConstNoOverload id catch _ => return (← x symm term)
let some eqThms ← getEqnsFor? declName | x symm term
let some eqThms ← getEqnsFor? declName (nonRec := true) | x symm term
let rec go : List Name → TacticM Unit
| [] => throwError "failed to rewrite using equation theorems for '{declName}'"
| eqThm::eqThms => (x symm (mkIdentFrom id eqThm)) <|> go eqThms

View file

@ -58,9 +58,29 @@ builtin_initialize eqnsExt : EnvExtension EqnsExtState ←
registerEnvExtension (pure {})
/--
Return equation theorems for the given declaration.
Simple equation theorem for nonrecursive definitions.
-/
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
private def mkSimpleEqThm (declName : Name) : MetaM (Option Name) := do
if let some (.defnInfo info) := (← getEnv).find? declName then
lambdaTelescope info.value fun xs body => do
let lhs := mkAppN (mkConst info.name <| info.levelParams.map mkLevelParam) xs
let type ← mkForallFVars xs (← mkEq lhs body)
let value ← mkLambdaFVars xs (← mkEqRefl lhs)
let name := mkPrivateName (← getEnv) declName ++ `_eq_1
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return some name
else
return none
/--
Return equation theorems for the given declaration.
By default, we not create equation theorems for nonrecursive definitions.
You can use `nonRec := true` to override this behavior, a dummy `rfl` proof is created on the fly.
-/
def getEqnsFor? (declName : Name) (nonRec := false) : MetaM (Option (Array Name)) := do
if let some eqs := eqnsExt.getState (← getEnv) |>.map.find? declName then
return some eqs
else if (← shouldGenerateEqnThms declName) then
@ -68,6 +88,11 @@ def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some r ← f declName then
modifyEnv fun env => eqnsExt.modifyState env fun s => { s with map := s.map.insert declName r }
return some r
if nonRec then
let some eqThm ← mkSimpleEqThm declName | return none
let r := #[eqThm]
modifyEnv fun env => eqnsExt.modifyState env fun s => { s with map := s.map.insert declName r }
return some r
return none
def GetUnfoldEqnFn := Name → MetaM (Option Name)
@ -104,11 +129,19 @@ def registerGetUnfoldEqnFn (f : GetUnfoldEqnFn) : IO Unit := do
throw (IO.userError "failed to register equation getter, this kind of extension can only be registered during initialization")
getUnfoldEqnFnsRef.modify (f :: ·)
def getUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do
/--
Return a "unfold" theorem for the given declaration.
By default, we not create unfold theorems for nonrecursive definitions.
You can use `nonRec := true` to override this behavior.
-/
def getUnfoldEqnFor? (declName : Name) (nonRec := false) : MetaM (Option Name) := do
if (← shouldGenerateEqnThms declName) then
for f in (← getUnfoldEqnFnsRef.get) do
if let some r ← f declName then
return some r
if nonRec then
let some #[eqThm] ← getEqnsFor? declName (nonRec := true) | return none
return some eqThm
return none
end Lean.Meta

View file

@ -14,3 +14,12 @@ example {a : α} {as bs : List α} (h : as = bs) : (a::b::as).length = (b::bs).l
conv => rhs; rw [List.length]
trace_state -- rhs was unfolded
rw [h]
example {a : α} {as bs : List α} (h : as = bs) : id (id ((a::b::as).length)) = (b::bs).length + 1 := by
rw [id]
trace_state
rw [id]
trace_state
rw [List.length, List.length, List.length]
trace_state
rw [h]

View file

@ -25,3 +25,18 @@ b a : α
as bs : List α
h : as = bs
⊢ List.length as + 1 + 1 = List.length bs + 1 + 1
α : Type ?u
b a : α
as bs : List α
h : as = bs
⊢ id (List.length (a :: b :: as)) = List.length (b :: bs) + 1
α : Type ?u
b a : α
as bs : List α
h : as = bs
⊢ List.length (a :: b :: as) = List.length (b :: bs) + 1
α : Type ?u
b a : α
as bs : List α
h : as = bs
⊢ List.length as + 1 + 1 = List.length bs + 1 + 1