From b24fbf44f3aaa112f5d799ef2a341772d1eb222d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 5 Mar 2024 10:35:24 -0800 Subject: [PATCH] feat: `dsimproc` command Simplification procedures that produce definitionally equal results. WIP --- src/Init/Simproc.lean | 98 +++++++++++++++++++++----- src/Lean/Elab/Tactic/Simproc.lean | 12 ++-- src/Lean/Meta/Tactic/Simp/Simproc.lean | 58 +++++++++++---- src/Lean/Meta/Tactic/Simp/Types.lean | 14 +++- 4 files changed, 142 insertions(+), 40 deletions(-) diff --git a/src/Init/Simproc.lean b/src/Init/Simproc.lean index 199cc6fdee..debf5d10af 100644 --- a/src/Init/Simproc.lean +++ b/src/Init/Simproc.lean @@ -31,22 +31,43 @@ Simplification procedures can be also scoped or local. -/ syntax (docComment)? attrKind "simproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command +/-- +Similar to `simproc`, but resulting expression must be definitionally equal to the input one. +-/ +syntax (docComment)? attrKind "dsimproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command + /-- A user-defined simplification procedure declaration. To activate this procedure in `simp` tactic, we must provide it as an argument, or use the command `attribute` to set its `[simproc]` attribute. -/ syntax (docComment)? "simproc_decl " ident " (" term ")" " := " term : command +/-- +A user-defined defeq simplification procedure declaration. To activate this procedure in `simp` tactic, +we must provide it as an argument, or use the command `attribute` to set its `[simproc]` attribute. +-/ +syntax (docComment)? "dsimproc_decl " ident " (" term ")" " := " term : command + /-- A builtin simplification procedure. -/ syntax (docComment)? attrKind "builtin_simproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command +/-- +A builtin defeq simplification procedure. +-/ +syntax (docComment)? attrKind "builtin_dsimproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command + /-- A builtin simplification procedure declaration. -/ syntax (docComment)? "builtin_simproc_decl " ident " (" term ")" " := " term : command +/-- +A builtin defeq simplification procedure declaration. +-/ +syntax (docComment)? "builtin_dsimproc_decl " ident " (" term ")" " := " term : command + /-- Auxiliary command for associating a pattern with a simplification procedure. -/ @@ -86,33 +107,60 @@ macro_rules `($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body simproc_pattern% $pattern => $n) +macro_rules + | `($[$doc?:docComment]? dsimproc_decl $n:ident ($pattern:term) := $body) => do + let simprocType := `Lean.Meta.Simp.DSimproc + `($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body + simproc_pattern% $pattern => $n) + macro_rules | `($[$doc?:docComment]? builtin_simproc_decl $n:ident ($pattern:term) := $body) => do let simprocType := `Lean.Meta.Simp.Simproc `($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body builtin_simproc_pattern% $pattern => $n) +macro_rules + | `($[$doc?:docComment]? builtin_dsimproc_decl $n:ident ($pattern:term) := $body) => do + let simprocType := `Lean.Meta.Simp.DSimproc + `($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body + builtin_simproc_pattern% $pattern => $n) + +private def mkAttributeCmds + (kind : TSyntax `Lean.Parser.Term.attrKind) + (pre? : Option (TSyntax [`Lean.Parser.Tactic.simpPre, `Lean.Parser.Tactic.simpPost])) + (ids? : Option (Syntax.TSepArray `ident ",")) + (n : Ident) : MacroM (Array Syntax) := do + let mut cmds := #[] + let pushDefault (cmds : Array (TSyntax `command)) : MacroM (Array (TSyntax `command)) := do + return cmds.push (← `(attribute [$kind simproc $[$pre?]?] $n)) + if let some ids := ids? then + for id in ids.getElems do + let idName := id.getId + let (attrName, attrKey) := + if idName == `simp then + (`simprocAttr, "simproc") + else if idName == `seval then + (`sevalprocAttr, "sevalproc") + else + let idName := idName.appendAfter "_proc" + (`Parser.Attr ++ idName, idName.toString) + let attrStx : TSyntax `attr := ⟨mkNode attrName #[mkAtom attrKey, mkOptionalNode pre?]⟩ + cmds := cmds.push (← `(attribute [$kind $attrStx] $n)) + else + cmds ← pushDefault cmds + return cmds + macro_rules | `($[$doc?:docComment]? $kind:attrKind simproc $[$pre?]? $[ [ $ids?:ident,* ] ]? $n:ident ($pattern:term) := $body) => do - let mut cmds := #[(← `($[$doc?:docComment]? simproc_decl $n ($pattern) := $body))] - let pushDefault (cmds : Array (TSyntax `command)) : MacroM (Array (TSyntax `command)) := do - return cmds.push (← `(attribute [$kind simproc $[$pre?]?] $n)) - if let some ids := ids? then - for id in ids.getElems do - let idName := id.getId - let (attrName, attrKey) := - if idName == `simp then - (`simprocAttr, "simproc") - else if idName == `seval then - (`sevalprocAttr, "sevalproc") - else - let idName := idName.appendAfter "_proc" - (`Parser.Attr ++ idName, idName.toString) - let attrStx : TSyntax `attr := ⟨mkNode attrName #[mkAtom attrKey, mkOptionalNode pre?]⟩ - cmds := cmds.push (← `(attribute [$kind $attrStx] $n)) - else - cmds ← pushDefault cmds - return mkNullNode cmds + return mkNullNode <| + #[(← `($[$doc?:docComment]? simproc_decl $n ($pattern) := $body))] + ++ (← mkAttributeCmds kind pre? ids? n) + +macro_rules + | `($[$doc?:docComment]? $kind:attrKind dsimproc $[$pre?]? $[ [ $ids?:ident,* ] ]? $n:ident ($pattern:term) := $body) => do + return mkNullNode <| + #[(← `($[$doc?:docComment]? dsimproc_decl $n ($pattern) := $body))] + ++ (← mkAttributeCmds kind pre? ids? n) macro_rules | `($[$doc?:docComment]? $kind:attrKind builtin_simproc $[$pre?]? $n:ident ($pattern:term) := $body) => do @@ -126,4 +174,16 @@ macro_rules attribute [$kind builtin_simproc $[$pre?]?] $n attribute [$kind builtin_sevalproc $[$pre?]?] $n) +macro_rules + | `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? $n:ident ($pattern:term) := $body) => do + `($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body + attribute [$kind builtin_simproc $[$pre?]?] $n) + | `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? [seval] $n:ident ($pattern:term) := $body) => do + `($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body + attribute [$kind builtin_sevalproc $[$pre?]?] $n) + | `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? [simp, seval] $n:ident ($pattern:term) := $body) => do + `($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body + attribute [$kind builtin_simproc $[$pre?]?] $n + attribute [$kind builtin_sevalproc $[$pre?]?] $n) + end Lean.Parser diff --git a/src/Lean/Elab/Tactic/Simproc.lean b/src/Lean/Elab/Tactic/Simproc.lean index 791403e8ab..1a13305f25 100644 --- a/src/Lean/Elab/Tactic/Simproc.lean +++ b/src/Lean/Elab/Tactic/Simproc.lean @@ -26,10 +26,11 @@ def elabSimprocKeys (stx : Syntax) : MetaM (Array Meta.SimpTheoremKey) := do let pattern ← elabSimprocPattern stx DiscrTree.mkPath pattern simpDtConfig -def checkSimprocType (declName : Name) : CoreM Unit := do +def checkSimprocType (declName : Name) : CoreM Bool := do let decl ← getConstInfo declName match decl.type with - | .const ``Simproc _ => pure () + | .const ``Simproc _ => pure false + | .const ``DSimproc _ => pure true | _ => throwError "unexpected type at '{declName}', 'Simproc' expected" namespace Command @@ -38,7 +39,7 @@ namespace Command let `(simproc_pattern% $pattern => $declName) := stx | throwUnsupportedSyntax let declName ← resolveGlobalConstNoOverload declName liftTermElabM do - checkSimprocType declName + discard <| checkSimprocType declName let keys ← elabSimprocKeys pattern registerSimproc declName keys @@ -46,9 +47,10 @@ namespace Command let `(builtin_simproc_pattern% $pattern => $declName) := stx | throwUnsupportedSyntax let declName ← resolveGlobalConstNoOverload declName liftTermElabM do - checkSimprocType declName + let dsimp ← checkSimprocType declName let keys ← elabSimprocKeys pattern - let val := mkAppN (mkConst ``registerBuiltinSimproc) #[toExpr declName, toExpr keys, mkConst declName] + let registerProcName := if dsimp then ``registerBuiltinDSimproc else ``registerBuiltinSimproc + let val := mkAppN (mkConst registerProcName) #[toExpr declName, toExpr keys, mkConst declName] let initDeclName ← mkFreshUserName (declName ++ `declare) declareBuiltin initDeclName val diff --git a/src/Lean/Meta/Tactic/Simp/Simproc.lean b/src/Lean/Meta/Tactic/Simp/Simproc.lean index d5b62491f0..2e0afd7f53 100644 --- a/src/Lean/Meta/Tactic/Simp/Simproc.lean +++ b/src/Lean/Meta/Tactic/Simp/Simproc.lean @@ -20,7 +20,7 @@ It contains: -/ structure BuiltinSimprocs where keys : HashMap Name (Array SimpTheoremKey) := {} - procs : HashMap Name Simproc := {} + procs : HashMap Name (Sum Simproc DSimproc) := {} deriving Inhabited /-- @@ -79,7 +79,7 @@ Given a declaration name `declName`, store the discrimination tree keys and the This method is invoked by the command `builtin_simproc_pattern%` elaborator. -/ -def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : Simproc) : IO Unit := do +def registerBuiltinSimprocCore (declName : Name) (key : Array SimpTheoremKey) (proc : Sum Simproc DSimproc) : IO Unit := do unless (← initializing) do throw (IO.userError s!"invalid builtin simproc declaration, it can only be registered during initialization") if (← builtinSimprocDeclsRef.get).keys.contains declName then @@ -87,6 +87,12 @@ def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc builtinSimprocDeclsRef.modify fun { keys, procs } => { keys := keys.insert declName key, procs := procs.insert declName proc } +def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : Simproc) : IO Unit := do + registerBuiltinSimprocCore declName key (.inl proc) + +def registerBuiltinDSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : DSimproc) : IO Unit := do + registerBuiltinSimprocCore declName key (.inr proc) + def registerSimproc (declName : Name) (keys : Array SimpTheoremKey) : CoreM Unit := do let env ← getEnv unless (env.getModuleIdxFor? declName).isNone do @@ -112,14 +118,21 @@ builtin_initialize builtinSEvalprocsRef : IO.Ref Simprocs ← IO.mkRef {} abbrev SimprocExtension := ScopedEnvExtension SimprocOLeanEntry SimprocEntry Simprocs -unsafe def getSimprocFromDeclImpl (declName : Name) : ImportM Simproc := do +unsafe def getSimprocFromDeclImpl (declName : Name) : ImportM (Sum Simproc DSimproc) := do let ctx ← read - match ctx.env.evalConstCheck Simproc ctx.opts ``Lean.Meta.Simp.Simproc declName with - | .ok proc => return proc - | .error ex => throw (IO.userError ex) + match ctx.env.find? declName with + | none => throw <| IO.userError ("unknown constant '" ++ toString declName ++ "'") + | some info => + match info.type with + | .const ``Simproc _ => + return .inl (← IO.ofExcept <| ctx.env.evalConst Simproc ctx.opts declName) + | .const ``DSimproc _ => + return .inr (← IO.ofExcept <| ctx.env.evalConst DSimproc ctx.opts declName) + | _ => throw <| IO.userError "unexpected type at simproc" + @[implemented_by getSimprocFromDeclImpl] -opaque getSimprocFromDecl (declName: Name) : ImportM Simproc +opaque getSimprocFromDecl (declName: Name) : ImportM (Sum Simproc DSimproc) def toSimprocEntry (e : SimprocOLeanEntry) : ImportM SimprocEntry := do return { toSimprocOLeanEntry := e, proc := (← getSimprocFromDecl e.declName) } @@ -136,7 +149,7 @@ def addSimprocAttrCore (ext : SimprocExtension) (declName : Name) (kind : Attrib throwError "invalid [simproc] attribute, '{declName}' is not a simproc" ext.add { declName, post, keys, proc } kind -def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Simproc) : Simprocs := +def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : Simprocs := let s := { s with simprocNames := s.simprocNames.insert declName, erased := s.erased.erase declName } if post then { s with post := s.post.insertCore keys { declName, keys, post, proc } } @@ -146,15 +159,21 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na /-- Implements attributes `builtin_simproc` and `builtin_sevalproc`. -/ -def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Simproc) : IO Unit := do +def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName | throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc") ref.modify fun s => s.addCore keys declName post proc def addSimprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit := - addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc + addSimprocBuiltinAttrCore builtinSimprocsRef declName post (.inl proc) def addSEvalprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit := + addSimprocBuiltinAttrCore builtinSEvalprocsRef declName post (.inl proc) + +def addSimprocBuiltinAttrNew (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := + addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc + +def addSEvalprocBuiltinAttrNew (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := addSimprocBuiltinAttrCore builtinSEvalprocsRef declName post proc def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs := do @@ -179,8 +198,13 @@ def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM extraArgs := extraArgs.push e.appArg! e := e.appFn! extraArgs := extraArgs.reverse - let s ← s.proc e - s.addExtraArgs extraArgs + match s.proc with + | .inl proc => + let s ← proc e + s.addExtraArgs extraArgs + | .inr proc => + let s ← proc e + s.toStep.addExtraArgs extraArgs def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig)) @@ -315,7 +339,11 @@ builtin_initialize simprocSEvalExtension : SimprocExtension ← registerSimprocA private def addBuiltin (declName : Name) (stx : Syntax) (addDeclName : Name) : AttrM Unit := do let go : MetaM Unit := do let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost - let val := mkAppN (mkConst addDeclName) #[toExpr declName, toExpr post, mkConst declName] + let procExpr ← match (← getConstInfo declName).type with + | .const ``Simproc _ => pure <| mkApp3 (mkConst ``Sum.inl [0, 0]) (mkConst ``Simproc) (mkConst ``DSimproc) (mkConst declName) + | .const ``DSimproc _ => pure <| mkApp3 (mkConst ``Sum.inr [0, 0]) (mkConst ``Simproc) (mkConst ``DSimproc) (mkConst declName) + | _ => throwError "unexpected type at simproc" + let val := mkAppN (mkConst addDeclName) #[toExpr declName, toExpr post, procExpr] let initDeclName ← mkFreshUserName (declName ++ `declare) declareBuiltin initDeclName val go.run' {} @@ -327,7 +355,7 @@ builtin_initialize descr := "Builtin simplification procedure" applicationTime := AttributeApplicationTime.afterCompilation erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]" - add := fun declName stx _ => addBuiltin declName stx ``addSimprocBuiltinAttr + add := fun declName stx _ => addBuiltin declName stx ``addSimprocBuiltinAttrNew } builtin_initialize @@ -337,7 +365,7 @@ builtin_initialize descr := "Builtin symbolic evaluation procedure" applicationTime := AttributeApplicationTime.afterCompilation erase := fun _ => throwError "Not implemented yet, [-builtin_sevalproc]" - add := fun declName stx _ => addBuiltin declName stx ``addSEvalprocBuiltinAttr + add := fun declName stx _ => addBuiltin declName stx ``addSEvalprocBuiltinAttrNew } def getSimprocs : CoreM Simprocs := diff --git a/src/Lean/Meta/Tactic/Simp/Types.lean b/src/Lean/Meta/Tactic/Simp/Types.lean index f9af107abd..94e4fb775c 100644 --- a/src/Lean/Meta/Tactic/Simp/Types.lean +++ b/src/Lean/Meta/Tactic/Simp/Types.lean @@ -146,6 +146,18 @@ See `Step`. -/ abbrev Simproc := Expr → SimpM Step +/-- +Similar to `Simproc`, but resulting expression should be definitionally equal to the input one. +-/ +abbrev DSimproc := Expr → SimpM TransformStep + +def _root_.Lean.TransformStep.toStep (s : TransformStep) : Step := + match s with + | .done e => .done { expr := e } + | .visit e => .visit { expr := e } + | .continue (some e) => .continue (some { expr := e }) + | .continue none => .continue none + def mkEqTransResultStep (r : Result) (s : Step) : MetaM Step := match s with | .done r' => return .done (← mkEqTransOptProofResult r.proof? r.cache r') @@ -189,7 +201,7 @@ structure SimprocEntry extends SimprocOLeanEntry where Recall that we cannot store `Simproc` into .olean files because it is a closure. Given `SimprocOLeanEntry.declName`, we convert it into a `Simproc` by using the unsafe function `evalConstCheck`. -/ - proc : Simproc + proc : Sum Simproc DSimproc abbrev SimprocTree := DiscrTree SimprocEntry