feat: allow rw to unfold nonrecursive definitions too
This commit is contained in:
parent
c0a72172f1
commit
f0c31d7e28
4 changed files with 61 additions and 4 deletions
|
|
@ -46,7 +46,7 @@ def withRWRulesSeq (token : Syntax) (rwRulesSeqStx : Syntax) (x : (symm : Bool)
|
|||
let processId (id : Syntax) : TacticM Unit := do
|
||||
-- Try to get equation theorems for `id` first
|
||||
let declName ← try resolveGlobalConstNoOverload id catch _ => return (← x symm term)
|
||||
let some eqThms ← getEqnsFor? declName | x symm term
|
||||
let some eqThms ← getEqnsFor? declName (nonRec := true) | x symm term
|
||||
let rec go : List Name → TacticM Unit
|
||||
| [] => throwError "failed to rewrite using equation theorems for '{declName}'"
|
||||
| eqThm::eqThms => (x symm (mkIdentFrom id eqThm)) <|> go eqThms
|
||||
|
|
|
|||
|
|
@ -58,9 +58,29 @@ builtin_initialize eqnsExt : EnvExtension EqnsExtState ←
|
|||
registerEnvExtension (pure {})
|
||||
|
||||
/--
|
||||
Return equation theorems for the given declaration.
|
||||
Simple equation theorem for nonrecursive definitions.
|
||||
-/
|
||||
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
|
||||
private def mkSimpleEqThm (declName : Name) : MetaM (Option Name) := do
|
||||
if let some (.defnInfo info) := (← getEnv).find? declName then
|
||||
lambdaTelescope info.value fun xs body => do
|
||||
let lhs := mkAppN (mkConst info.name <| info.levelParams.map mkLevelParam) xs
|
||||
let type ← mkForallFVars xs (← mkEq lhs body)
|
||||
let value ← mkLambdaFVars xs (← mkEqRefl lhs)
|
||||
let name := mkPrivateName (← getEnv) declName ++ `_eq_1
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return some name
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Return equation theorems for the given declaration.
|
||||
By default, we not create equation theorems for nonrecursive definitions.
|
||||
You can use `nonRec := true` to override this behavior, a dummy `rfl` proof is created on the fly.
|
||||
-/
|
||||
def getEqnsFor? (declName : Name) (nonRec := false) : MetaM (Option (Array Name)) := do
|
||||
if let some eqs := eqnsExt.getState (← getEnv) |>.map.find? declName then
|
||||
return some eqs
|
||||
else if (← shouldGenerateEqnThms declName) then
|
||||
|
|
@ -68,6 +88,11 @@ def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
|
|||
if let some r ← f declName then
|
||||
modifyEnv fun env => eqnsExt.modifyState env fun s => { s with map := s.map.insert declName r }
|
||||
return some r
|
||||
if nonRec then
|
||||
let some eqThm ← mkSimpleEqThm declName | return none
|
||||
let r := #[eqThm]
|
||||
modifyEnv fun env => eqnsExt.modifyState env fun s => { s with map := s.map.insert declName r }
|
||||
return some r
|
||||
return none
|
||||
|
||||
def GetUnfoldEqnFn := Name → MetaM (Option Name)
|
||||
|
|
@ -104,11 +129,19 @@ def registerGetUnfoldEqnFn (f : GetUnfoldEqnFn) : IO Unit := do
|
|||
throw (IO.userError "failed to register equation getter, this kind of extension can only be registered during initialization")
|
||||
getUnfoldEqnFnsRef.modify (f :: ·)
|
||||
|
||||
def getUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do
|
||||
/--
|
||||
Return a "unfold" theorem for the given declaration.
|
||||
By default, we not create unfold theorems for nonrecursive definitions.
|
||||
You can use `nonRec := true` to override this behavior.
|
||||
-/
|
||||
def getUnfoldEqnFor? (declName : Name) (nonRec := false) : MetaM (Option Name) := do
|
||||
if (← shouldGenerateEqnThms declName) then
|
||||
for f in (← getUnfoldEqnFnsRef.get) do
|
||||
if let some r ← f declName then
|
||||
return some r
|
||||
if nonRec then
|
||||
let some #[eqThm] ← getEqnsFor? declName (nonRec := true) | return none
|
||||
return some eqThm
|
||||
return none
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -14,3 +14,12 @@ example {a : α} {as bs : List α} (h : as = bs) : (a::b::as).length = (b::bs).l
|
|||
conv => rhs; rw [List.length]
|
||||
trace_state -- rhs was unfolded
|
||||
rw [h]
|
||||
|
||||
example {a : α} {as bs : List α} (h : as = bs) : id (id ((a::b::as).length)) = (b::bs).length + 1 := by
|
||||
rw [id]
|
||||
trace_state
|
||||
rw [id]
|
||||
trace_state
|
||||
rw [List.length, List.length, List.length]
|
||||
trace_state
|
||||
rw [h]
|
||||
|
|
|
|||
|
|
@ -25,3 +25,18 @@ b a : α
|
|||
as bs : List α
|
||||
h : as = bs
|
||||
⊢ List.length as + 1 + 1 = List.length bs + 1 + 1
|
||||
α : Type ?u
|
||||
b a : α
|
||||
as bs : List α
|
||||
h : as = bs
|
||||
⊢ id (List.length (a :: b :: as)) = List.length (b :: bs) + 1
|
||||
α : Type ?u
|
||||
b a : α
|
||||
as bs : List α
|
||||
h : as = bs
|
||||
⊢ List.length (a :: b :: as) = List.length (b :: bs) + 1
|
||||
α : Type ?u
|
||||
b a : α
|
||||
as bs : List α
|
||||
h : as = bs
|
||||
⊢ List.length as + 1 + 1 = List.length bs + 1 + 1
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue