feat: dsimproc command
Simplification procedures that produce definitionally equal results. WIP
This commit is contained in:
parent
f986f69a32
commit
b24fbf44f3
4 changed files with 142 additions and 40 deletions
|
|
@ -31,22 +31,43 @@ Simplification procedures can be also scoped or local.
|
|||
-/
|
||||
syntax (docComment)? attrKind "simproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
Similar to `simproc`, but resulting expression must be definitionally equal to the input one.
|
||||
-/
|
||||
syntax (docComment)? attrKind "dsimproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
A user-defined simplification procedure declaration. To activate this procedure in `simp` tactic,
|
||||
we must provide it as an argument, or use the command `attribute` to set its `[simproc]` attribute.
|
||||
-/
|
||||
syntax (docComment)? "simproc_decl " ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
A user-defined defeq simplification procedure declaration. To activate this procedure in `simp` tactic,
|
||||
we must provide it as an argument, or use the command `attribute` to set its `[simproc]` attribute.
|
||||
-/
|
||||
syntax (docComment)? "dsimproc_decl " ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
A builtin simplification procedure.
|
||||
-/
|
||||
syntax (docComment)? attrKind "builtin_simproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
A builtin defeq simplification procedure.
|
||||
-/
|
||||
syntax (docComment)? attrKind "builtin_dsimproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
A builtin simplification procedure declaration.
|
||||
-/
|
||||
syntax (docComment)? "builtin_simproc_decl " ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
A builtin defeq simplification procedure declaration.
|
||||
-/
|
||||
syntax (docComment)? "builtin_dsimproc_decl " ident " (" term ")" " := " term : command
|
||||
|
||||
/--
|
||||
Auxiliary command for associating a pattern with a simplification procedure.
|
||||
-/
|
||||
|
|
@ -86,33 +107,60 @@ macro_rules
|
|||
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
|
||||
simproc_pattern% $pattern => $n)
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? dsimproc_decl $n:ident ($pattern:term) := $body) => do
|
||||
let simprocType := `Lean.Meta.Simp.DSimproc
|
||||
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
|
||||
simproc_pattern% $pattern => $n)
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? builtin_simproc_decl $n:ident ($pattern:term) := $body) => do
|
||||
let simprocType := `Lean.Meta.Simp.Simproc
|
||||
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
|
||||
builtin_simproc_pattern% $pattern => $n)
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? builtin_dsimproc_decl $n:ident ($pattern:term) := $body) => do
|
||||
let simprocType := `Lean.Meta.Simp.DSimproc
|
||||
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
|
||||
builtin_simproc_pattern% $pattern => $n)
|
||||
|
||||
private def mkAttributeCmds
|
||||
(kind : TSyntax `Lean.Parser.Term.attrKind)
|
||||
(pre? : Option (TSyntax [`Lean.Parser.Tactic.simpPre, `Lean.Parser.Tactic.simpPost]))
|
||||
(ids? : Option (Syntax.TSepArray `ident ","))
|
||||
(n : Ident) : MacroM (Array Syntax) := do
|
||||
let mut cmds := #[]
|
||||
let pushDefault (cmds : Array (TSyntax `command)) : MacroM (Array (TSyntax `command)) := do
|
||||
return cmds.push (← `(attribute [$kind simproc $[$pre?]?] $n))
|
||||
if let some ids := ids? then
|
||||
for id in ids.getElems do
|
||||
let idName := id.getId
|
||||
let (attrName, attrKey) :=
|
||||
if idName == `simp then
|
||||
(`simprocAttr, "simproc")
|
||||
else if idName == `seval then
|
||||
(`sevalprocAttr, "sevalproc")
|
||||
else
|
||||
let idName := idName.appendAfter "_proc"
|
||||
(`Parser.Attr ++ idName, idName.toString)
|
||||
let attrStx : TSyntax `attr := ⟨mkNode attrName #[mkAtom attrKey, mkOptionalNode pre?]⟩
|
||||
cmds := cmds.push (← `(attribute [$kind $attrStx] $n))
|
||||
else
|
||||
cmds ← pushDefault cmds
|
||||
return cmds
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? $kind:attrKind simproc $[$pre?]? $[ [ $ids?:ident,* ] ]? $n:ident ($pattern:term) := $body) => do
|
||||
let mut cmds := #[(← `($[$doc?:docComment]? simproc_decl $n ($pattern) := $body))]
|
||||
let pushDefault (cmds : Array (TSyntax `command)) : MacroM (Array (TSyntax `command)) := do
|
||||
return cmds.push (← `(attribute [$kind simproc $[$pre?]?] $n))
|
||||
if let some ids := ids? then
|
||||
for id in ids.getElems do
|
||||
let idName := id.getId
|
||||
let (attrName, attrKey) :=
|
||||
if idName == `simp then
|
||||
(`simprocAttr, "simproc")
|
||||
else if idName == `seval then
|
||||
(`sevalprocAttr, "sevalproc")
|
||||
else
|
||||
let idName := idName.appendAfter "_proc"
|
||||
(`Parser.Attr ++ idName, idName.toString)
|
||||
let attrStx : TSyntax `attr := ⟨mkNode attrName #[mkAtom attrKey, mkOptionalNode pre?]⟩
|
||||
cmds := cmds.push (← `(attribute [$kind $attrStx] $n))
|
||||
else
|
||||
cmds ← pushDefault cmds
|
||||
return mkNullNode cmds
|
||||
return mkNullNode <|
|
||||
#[(← `($[$doc?:docComment]? simproc_decl $n ($pattern) := $body))]
|
||||
++ (← mkAttributeCmds kind pre? ids? n)
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? $kind:attrKind dsimproc $[$pre?]? $[ [ $ids?:ident,* ] ]? $n:ident ($pattern:term) := $body) => do
|
||||
return mkNullNode <|
|
||||
#[(← `($[$doc?:docComment]? dsimproc_decl $n ($pattern) := $body))]
|
||||
++ (← mkAttributeCmds kind pre? ids? n)
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? $kind:attrKind builtin_simproc $[$pre?]? $n:ident ($pattern:term) := $body) => do
|
||||
|
|
@ -126,4 +174,16 @@ macro_rules
|
|||
attribute [$kind builtin_simproc $[$pre?]?] $n
|
||||
attribute [$kind builtin_sevalproc $[$pre?]?] $n)
|
||||
|
||||
macro_rules
|
||||
| `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? $n:ident ($pattern:term) := $body) => do
|
||||
`($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body
|
||||
attribute [$kind builtin_simproc $[$pre?]?] $n)
|
||||
| `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? [seval] $n:ident ($pattern:term) := $body) => do
|
||||
`($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body
|
||||
attribute [$kind builtin_sevalproc $[$pre?]?] $n)
|
||||
| `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? [simp, seval] $n:ident ($pattern:term) := $body) => do
|
||||
`($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body
|
||||
attribute [$kind builtin_simproc $[$pre?]?] $n
|
||||
attribute [$kind builtin_sevalproc $[$pre?]?] $n)
|
||||
|
||||
end Lean.Parser
|
||||
|
|
|
|||
|
|
@ -26,10 +26,11 @@ def elabSimprocKeys (stx : Syntax) : MetaM (Array Meta.SimpTheoremKey) := do
|
|||
let pattern ← elabSimprocPattern stx
|
||||
DiscrTree.mkPath pattern simpDtConfig
|
||||
|
||||
def checkSimprocType (declName : Name) : CoreM Unit := do
|
||||
def checkSimprocType (declName : Name) : CoreM Bool := do
|
||||
let decl ← getConstInfo declName
|
||||
match decl.type with
|
||||
| .const ``Simproc _ => pure ()
|
||||
| .const ``Simproc _ => pure false
|
||||
| .const ``DSimproc _ => pure true
|
||||
| _ => throwError "unexpected type at '{declName}', 'Simproc' expected"
|
||||
|
||||
namespace Command
|
||||
|
|
@ -38,7 +39,7 @@ namespace Command
|
|||
let `(simproc_pattern% $pattern => $declName) := stx | throwUnsupportedSyntax
|
||||
let declName ← resolveGlobalConstNoOverload declName
|
||||
liftTermElabM do
|
||||
checkSimprocType declName
|
||||
discard <| checkSimprocType declName
|
||||
let keys ← elabSimprocKeys pattern
|
||||
registerSimproc declName keys
|
||||
|
||||
|
|
@ -46,9 +47,10 @@ namespace Command
|
|||
let `(builtin_simproc_pattern% $pattern => $declName) := stx | throwUnsupportedSyntax
|
||||
let declName ← resolveGlobalConstNoOverload declName
|
||||
liftTermElabM do
|
||||
checkSimprocType declName
|
||||
let dsimp ← checkSimprocType declName
|
||||
let keys ← elabSimprocKeys pattern
|
||||
let val := mkAppN (mkConst ``registerBuiltinSimproc) #[toExpr declName, toExpr keys, mkConst declName]
|
||||
let registerProcName := if dsimp then ``registerBuiltinDSimproc else ``registerBuiltinSimproc
|
||||
let val := mkAppN (mkConst registerProcName) #[toExpr declName, toExpr keys, mkConst declName]
|
||||
let initDeclName ← mkFreshUserName (declName ++ `declare)
|
||||
declareBuiltin initDeclName val
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ It contains:
|
|||
-/
|
||||
structure BuiltinSimprocs where
|
||||
keys : HashMap Name (Array SimpTheoremKey) := {}
|
||||
procs : HashMap Name Simproc := {}
|
||||
procs : HashMap Name (Sum Simproc DSimproc) := {}
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
|
|
@ -79,7 +79,7 @@ Given a declaration name `declName`, store the discrimination tree keys and the
|
|||
|
||||
This method is invoked by the command `builtin_simproc_pattern%` elaborator.
|
||||
-/
|
||||
def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : Simproc) : IO Unit := do
|
||||
def registerBuiltinSimprocCore (declName : Name) (key : Array SimpTheoremKey) (proc : Sum Simproc DSimproc) : IO Unit := do
|
||||
unless (← initializing) do
|
||||
throw (IO.userError s!"invalid builtin simproc declaration, it can only be registered during initialization")
|
||||
if (← builtinSimprocDeclsRef.get).keys.contains declName then
|
||||
|
|
@ -87,6 +87,12 @@ def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc
|
|||
builtinSimprocDeclsRef.modify fun { keys, procs } =>
|
||||
{ keys := keys.insert declName key, procs := procs.insert declName proc }
|
||||
|
||||
def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : Simproc) : IO Unit := do
|
||||
registerBuiltinSimprocCore declName key (.inl proc)
|
||||
|
||||
def registerBuiltinDSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : DSimproc) : IO Unit := do
|
||||
registerBuiltinSimprocCore declName key (.inr proc)
|
||||
|
||||
def registerSimproc (declName : Name) (keys : Array SimpTheoremKey) : CoreM Unit := do
|
||||
let env ← getEnv
|
||||
unless (env.getModuleIdxFor? declName).isNone do
|
||||
|
|
@ -112,14 +118,21 @@ builtin_initialize builtinSEvalprocsRef : IO.Ref Simprocs ← IO.mkRef {}
|
|||
|
||||
abbrev SimprocExtension := ScopedEnvExtension SimprocOLeanEntry SimprocEntry Simprocs
|
||||
|
||||
unsafe def getSimprocFromDeclImpl (declName : Name) : ImportM Simproc := do
|
||||
unsafe def getSimprocFromDeclImpl (declName : Name) : ImportM (Sum Simproc DSimproc) := do
|
||||
let ctx ← read
|
||||
match ctx.env.evalConstCheck Simproc ctx.opts ``Lean.Meta.Simp.Simproc declName with
|
||||
| .ok proc => return proc
|
||||
| .error ex => throw (IO.userError ex)
|
||||
match ctx.env.find? declName with
|
||||
| none => throw <| IO.userError ("unknown constant '" ++ toString declName ++ "'")
|
||||
| some info =>
|
||||
match info.type with
|
||||
| .const ``Simproc _ =>
|
||||
return .inl (← IO.ofExcept <| ctx.env.evalConst Simproc ctx.opts declName)
|
||||
| .const ``DSimproc _ =>
|
||||
return .inr (← IO.ofExcept <| ctx.env.evalConst DSimproc ctx.opts declName)
|
||||
| _ => throw <| IO.userError "unexpected type at simproc"
|
||||
|
||||
|
||||
@[implemented_by getSimprocFromDeclImpl]
|
||||
opaque getSimprocFromDecl (declName: Name) : ImportM Simproc
|
||||
opaque getSimprocFromDecl (declName: Name) : ImportM (Sum Simproc DSimproc)
|
||||
|
||||
def toSimprocEntry (e : SimprocOLeanEntry) : ImportM SimprocEntry := do
|
||||
return { toSimprocOLeanEntry := e, proc := (← getSimprocFromDecl e.declName) }
|
||||
|
|
@ -136,7 +149,7 @@ def addSimprocAttrCore (ext : SimprocExtension) (declName : Name) (kind : Attrib
|
|||
throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
|
||||
ext.add { declName, post, keys, proc } kind
|
||||
|
||||
def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Simproc) : Simprocs :=
|
||||
def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : Simprocs :=
|
||||
let s := { s with simprocNames := s.simprocNames.insert declName, erased := s.erased.erase declName }
|
||||
if post then
|
||||
{ s with post := s.post.insertCore keys { declName, keys, post, proc } }
|
||||
|
|
@ -146,15 +159,21 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
|
|||
/--
|
||||
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
|
||||
-/
|
||||
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Simproc) : IO Unit := do
|
||||
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
|
||||
let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName |
|
||||
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
|
||||
ref.modify fun s => s.addCore keys declName post proc
|
||||
|
||||
def addSimprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit :=
|
||||
addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc
|
||||
addSimprocBuiltinAttrCore builtinSimprocsRef declName post (.inl proc)
|
||||
|
||||
def addSEvalprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit :=
|
||||
addSimprocBuiltinAttrCore builtinSEvalprocsRef declName post (.inl proc)
|
||||
|
||||
def addSimprocBuiltinAttrNew (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit :=
|
||||
addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc
|
||||
|
||||
def addSEvalprocBuiltinAttrNew (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit :=
|
||||
addSimprocBuiltinAttrCore builtinSEvalprocsRef declName post proc
|
||||
|
||||
def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs := do
|
||||
|
|
@ -179,8 +198,13 @@ def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM
|
|||
extraArgs := extraArgs.push e.appArg!
|
||||
e := e.appFn!
|
||||
extraArgs := extraArgs.reverse
|
||||
let s ← s.proc e
|
||||
s.addExtraArgs extraArgs
|
||||
match s.proc with
|
||||
| .inl proc =>
|
||||
let s ← proc e
|
||||
s.addExtraArgs extraArgs
|
||||
| .inr proc =>
|
||||
let s ← proc e
|
||||
s.toStep.addExtraArgs extraArgs
|
||||
|
||||
def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do
|
||||
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
|
||||
|
|
@ -315,7 +339,11 @@ builtin_initialize simprocSEvalExtension : SimprocExtension ← registerSimprocA
|
|||
private def addBuiltin (declName : Name) (stx : Syntax) (addDeclName : Name) : AttrM Unit := do
|
||||
let go : MetaM Unit := do
|
||||
let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost
|
||||
let val := mkAppN (mkConst addDeclName) #[toExpr declName, toExpr post, mkConst declName]
|
||||
let procExpr ← match (← getConstInfo declName).type with
|
||||
| .const ``Simproc _ => pure <| mkApp3 (mkConst ``Sum.inl [0, 0]) (mkConst ``Simproc) (mkConst ``DSimproc) (mkConst declName)
|
||||
| .const ``DSimproc _ => pure <| mkApp3 (mkConst ``Sum.inr [0, 0]) (mkConst ``Simproc) (mkConst ``DSimproc) (mkConst declName)
|
||||
| _ => throwError "unexpected type at simproc"
|
||||
let val := mkAppN (mkConst addDeclName) #[toExpr declName, toExpr post, procExpr]
|
||||
let initDeclName ← mkFreshUserName (declName ++ `declare)
|
||||
declareBuiltin initDeclName val
|
||||
go.run' {}
|
||||
|
|
@ -327,7 +355,7 @@ builtin_initialize
|
|||
descr := "Builtin simplification procedure"
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]"
|
||||
add := fun declName stx _ => addBuiltin declName stx ``addSimprocBuiltinAttr
|
||||
add := fun declName stx _ => addBuiltin declName stx ``addSimprocBuiltinAttrNew
|
||||
}
|
||||
|
||||
builtin_initialize
|
||||
|
|
@ -337,7 +365,7 @@ builtin_initialize
|
|||
descr := "Builtin symbolic evaluation procedure"
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
erase := fun _ => throwError "Not implemented yet, [-builtin_sevalproc]"
|
||||
add := fun declName stx _ => addBuiltin declName stx ``addSEvalprocBuiltinAttr
|
||||
add := fun declName stx _ => addBuiltin declName stx ``addSEvalprocBuiltinAttrNew
|
||||
}
|
||||
|
||||
def getSimprocs : CoreM Simprocs :=
|
||||
|
|
|
|||
|
|
@ -146,6 +146,18 @@ See `Step`.
|
|||
-/
|
||||
abbrev Simproc := Expr → SimpM Step
|
||||
|
||||
/--
|
||||
Similar to `Simproc`, but resulting expression should be definitionally equal to the input one.
|
||||
-/
|
||||
abbrev DSimproc := Expr → SimpM TransformStep
|
||||
|
||||
def _root_.Lean.TransformStep.toStep (s : TransformStep) : Step :=
|
||||
match s with
|
||||
| .done e => .done { expr := e }
|
||||
| .visit e => .visit { expr := e }
|
||||
| .continue (some e) => .continue (some { expr := e })
|
||||
| .continue none => .continue none
|
||||
|
||||
def mkEqTransResultStep (r : Result) (s : Step) : MetaM Step :=
|
||||
match s with
|
||||
| .done r' => return .done (← mkEqTransOptProofResult r.proof? r.cache r')
|
||||
|
|
@ -189,7 +201,7 @@ structure SimprocEntry extends SimprocOLeanEntry where
|
|||
Recall that we cannot store `Simproc` into .olean files because it is a closure.
|
||||
Given `SimprocOLeanEntry.declName`, we convert it into a `Simproc` by using the unsafe function `evalConstCheck`.
|
||||
-/
|
||||
proc : Simproc
|
||||
proc : Sum Simproc DSimproc
|
||||
|
||||
abbrev SimprocTree := DiscrTree SimprocEntry
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue