diff --git a/src/Init/Tactics.lean b/src/Init/Tactics.lean index 0e566b2359..05a3843fde 100644 --- a/src/Init/Tactics.lean +++ b/src/Init/Tactics.lean @@ -2262,6 +2262,18 @@ such as replacing `if c then _ else _` with `if h : c then _ else _` or `xs.map` -/ syntax (name := wf_preprocess) "wf_preprocess" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? (ppSpace prio)? : attr +/-- +Theorems tagged with the `method_specs_simp` attribute are used by `@[method_specs]` to further +rewrite the theorem statement. This is primarily used to rewrite type class methods further to +the desired user-visible form, e.g. from `Append.append` to `HAppend.hAppend`, which has the familiar +notation associated. + +The `method_specs` theorems are created on demand (using the realizable constant feature). Thus, +this simp set should behave the same in all modules. Do not add theorems to it except in the module +defining the thing you are rewriting. +-/ +syntax (name := method_specs_simp) "method_specs_simp" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? (ppSpace prio)? : attr + /-- The possible `norm_cast` kinds: `elim`, `move`, or `squash`. -/ syntax normCastLabel := &"elim" <|> &"move" <|> &"squash" diff --git a/src/Lean/Meta.lean b/src/Lean/Meta.lean index bd9f7bf4eb..063b930ea8 100644 --- a/src/Lean/Meta.lean +++ b/src/Lean/Meta.lean @@ -57,5 +57,4 @@ public import Lean.Meta.Diagnostics public import Lean.Meta.BinderNameHint public import Lean.Meta.TryThis public import Lean.Meta.Hint - -public section +public import Lean.Meta.MethodSpecs diff --git a/src/Lean/Meta/MethodSpecs.lean b/src/Lean/Meta/MethodSpecs.lean new file mode 100644 index 0000000000..dd822215c0 --- /dev/null +++ b/src/Lean/Meta/MethodSpecs.lean @@ -0,0 +1,211 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ + +module +prelude +public import Init.System.IO +public import Lean.Attributes +public import Lean.Meta.Tactic.Simp.SimpTheorems +import Lean.Meta.Basic +import Lean.Structure +import Lean.Meta.CtorRecognizer +import Lean.Meta.InferType +import Lean.Meta.AppBuilder +import Lean.ReservedNameAction +import Lean.Meta.Tactic.Simp.SimpTheorems +import Lean.Meta.Tactic.Simp.Types +import Lean.Meta.Tactic.Simp.Main + +namespace Lean + +open Meta + +structure MethodSpecTheorem where + /-- Name of the implementation function -/ + name : Name + levelParams : List Name + /-- `opImpl = Cls.op instClsT` -/ + type : Expr + +structure MethodSpecsInfo where + clsName : Name + /-- Array mapping field names to implementation functions. -/ + fieldImpls : Array (Name × Name) + /-- rewrite rules to apply -/ + thms : Array MethodSpecTheorem + +/-- +This function checks the `instName` for eligibility and collects the information to rewrite. +It is run twice: when setting the `@[specs]` attribute as a preflight check, and when actually realizing +the constants. +-/ +def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do + let instInfo ← getConstInfoDefn instName + let some clsName ← isClass? instInfo.type + | throwError "expected `{.ofConstName instName}` to be a type class instance, but its \ + type{inlineExpr instInfo.type}does not look like a class." + let instArity ← forallTelescopeReducing instInfo.type fun xs _ => pure xs.size + let some structInfo := getStructureInfo? (← getEnv) clsName + | throwError "`{.ofConstName clsName}` is not a structure" + + lambdaTelescope instInfo.value fun xs body => do + let inst := mkAppN (mkConst instInfo.name (instInfo.levelParams.map mkLevelParam)) xs + let clsApp ← instantiateForall instInfo.type xs + unless xs.size == instArity && (← isConstructorApp body) do + throwError "the definition of `{.ofConstName instName}` does not have the expected shape" + + let mut fieldImpls := #[] + let mut thms := #[] + + for field in structInfo.fieldNames, arg in body.getAppArgsN structInfo.fieldNames.size do + if (← isProof arg) then continue + let arg := arg.eta + let f := arg.getAppFn + let ys := arg.getAppArgs + unless f.isConst do + throwError "field `{field}` of the instance is not an application of a constant" + unless f.constLevels! == instInfo.levelParams.map mkLevelParam do + throwError "function `{f}` is called with universe parameters\n {f.constLevels!}\nwhich differs from \ + the instances' universe parameters\n {instInfo.levelParams.map mkLevelParam}" + unless xs == ys do + throwError "function `{f}` does not take its arguments in the same order as the instance" + -- Construct the replacement theorems + let some fieldInfo := getFieldInfo? (← getEnv) clsName field + | throwError "internal error: could not find field {field} in structure {clsName}" + let lhs := arg + let projFn := mkConst fieldInfo.projFn clsApp.getAppFn.constLevels! + let rhs := mkAppN projFn (clsApp.getAppArgs ++ #[inst]) + let eq ← mkEq lhs rhs + let thm ← mkForallFVars xs eq + unless (← isDefEq lhs rhs) do + throwError "internal error: equation `{eq}` does not hold definitionally" + fieldImpls := fieldImpls.push (field, f.constName!) + thms := thms.push { name := f.constName!, levelParams := instInfo.levelParams, type := thm } + trace[Meta.MethodSpecs] "MethodSpecs for {instName}:\n{fieldImpls}\nthms: {thms.map (·.type)}" + + return {clsName, fieldImpls, thms} + +public structure MethodSpecsAttrData where + clsName : Name +deriving Inhabited + +def getParam (instName : Name) (_stx : Syntax) : AttrM MethodSpecsAttrData := do + -- Preflight check + let specsInfo ← (getMethodSpecsInfo instName).run' + return { + clsName := specsInfo.clsName + } + +/-- +Generate method specification theorems for the methods of the given type class instance. + +This expects all (non-proof) methods of the instance to be defined via separate helper functions, +which must take the same arguments as the instance itself, in the same order. + +If it is applied to an instance +``` +instance instClsT : Cls T where op := opImpl +``` +it produces a theorem `instClsT.op_spec` based on `opImpl.eq_def`, but phrased in terms of the +overloaded `Cls.op` operation, and similarly `instClsT.op_spec_` based on the equational theorems +`opImpl.eq_`. +-/ +@[builtin_doc] +builtin_initialize methodSpecsAttr : ParametricAttribute MethodSpecsAttrData ← + registerParametricAttribute { + name := `method_specs + descr := "generate method specification theorems" + getParam + } + +builtin_initialize methodSpecsSimpExtension : SimpExtension ← + registerSimpAttr `method_specs_simp + "simp lemma used to post-process the theorem created by `@[method_specs]`." + +def rewriteThm (ctx : Simp.Context) (simprocs : Simprocs) + (eqThmName destThmName : Name) : MetaM Unit := do + let thmInfo ← getConstVal eqThmName + let (result, _) ← simp thmInfo.type ctx (simprocs := #[simprocs]) + trace[Meta.MethodSpecs] "type for {destThmName}:{indentExpr result.expr}" + let value ← result.mkEqMPR <| mkConst eqThmName (thmInfo.levelParams.map mkLevelParam) + addDecl <| Declaration.thmDecl { + name := destThmName + levelParams := thmInfo.levelParams + type := result.expr + value := value + } + +def genSpecs (instName : Name) : MetaM Unit := do + let methodSpecsInfo ← getMethodSpecsInfo instName + let key := instName.str s!"{methodSpecsInfo.fieldImpls[0]!.1}_spec" + realizeConst instName key doRealize +where + doRealize := do + let methodSpecsInfo ← getMethodSpecsInfo instName + let mut s ← methodSpecsSimpExtension.getTheorems + for thm in methodSpecsInfo.thms do + trace[Meta.MethodSpecs] "adding simp theorem for {thm.name} : {thm.type}" + s := s.addSimpTheorem <| ← mkDSimpTheorem (.other thm.name) thm.levelParams.toArray thm.type + let ctx ← Simp.mkContext + (simpTheorems := #[s]) + (congrTheorems := (← getSimpCongrTheorems)) + (config := { } ) -- Simp.neutralConfig with dsimp := true, letToHave := false }) + let simprocs ← Simp.getSimprocs + + for (fieldName, implName) in methodSpecsInfo.fieldImpls do + let some unfoldThm ← getUnfoldEqnFor? implName (nonRec := true) + | throwError "failed to generate unfolding theorem for {.ofConstName implName}" + rewriteThm ctx simprocs unfoldThm (instName.str s!"{fieldName}_spec") + + if let some eqnThms ← getEqnsFor? implName then + for eqnThm in eqnThms, i in [:eqnThms.size] do + rewriteThm ctx simprocs eqnThm (instName.str s!"{fieldName}_spec_{i+1}") + +def startsWithFollowedByNumber (s p : String) : Bool := + s.startsWith p && (s.drop p.length).isNat + +def isSpecThmNameFor (env : Environment) (name : Name) : Option Name := do + let .str p n := name | none + let attrData ← methodSpecsAttr.getParam? env p + for fieldName in getStructureFields env attrData.clsName do + if n == s!"{fieldName}_spec" || startsWithFollowedByNumber n s!"{fieldName}_spec_" then + return p + none + +public partial def getMethodSpecTheorem (instName : Name) (op : String) : MetaM (Option Name) := do + let env ← getEnv + let some _ := methodSpecsAttr.getParam? env instName | return none + realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec") + +public partial def getMethodSpecTheorems (instName : Name) (op : String) : MetaM (Option (Array Name)) := do + let some _ := methodSpecsAttr.getParam? (← getEnv) instName | return none + -- Realize spec theorems + let _ ← realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec") + -- Now collect the generated ones + let mut i := 0 + let mut thms := #[] + let env ← getEnv + while true do + let thmName := instName.str s!"{op}_spec_{i+1}" + if env.containsOnBranch thmName then + thms := thms.push thmName + i := i + 1 + else + break + return some thms + +builtin_initialize + registerReservedNamePredicate fun env name => isSpecThmNameFor env name |>.isSome + + registerReservedNameAction fun name => do + if let some instName := isSpecThmNameFor (← getEnv) name then + (genSpecs instName).run' + return true + return false + + +builtin_initialize + Lean.registerTraceClass `Meta.MethodSpecs diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index 7c8ca354b0..45774cccda 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -122,6 +122,7 @@ private def useImplicitDefEqProof (thm : SimpTheorem) : SimpM Bool := do private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) : SimpM (Option Result) := do recordTriedSimpTheorem thm.origin let rec go (e : Expr) : SimpM (Option Result) := do + trace[Debug.Meta.Tactic.simp] "trying {← ppSimpTheorem thm} to rewrite{indentExpr e}" if (← withSimpMetaConfig <| isDefEq lhs e) then unless (← synthesizeArgs thm.origin bis xs) do return none @@ -142,6 +143,9 @@ private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInf we seldom have assigned metavariables in goals. -/ if (← instantiateMVars e) == rhs then + trace[Debug.Meta.Tactic.simp] "Not applying {← ppSimpTheorem thm} with type\ + {indentExpr type}\nto{indentExpr e}\nas the result is structurally equal \ + to the original expression" return none if thm.perm then /- diff --git a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean index 68a0377117..bc187c56e9 100644 --- a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean +++ b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean @@ -364,17 +364,19 @@ private def checkTypeIsProp (type : Expr) : MetaM Unit := unless (← isProp type) do throwError "Invalid simp theorem: Expected a proposition, but found{indentExpr type}" -private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do - assert! origin != .fvar ⟨.anonymous⟩ - let type ← instantiateMVars (← inferType e) +private def mkSimpTheoremKeys (type : Expr) (noIndexAtArgs : Bool) : MetaM (Array SimpTheoremKey × Bool) := do withNewMCtxDepth do let (_, _, type) ← forallMetaTelescopeReducing type let type ← whnfR type - let (keys, perm) ← - match type.eq? with - | some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs) - | none => throwError "Unexpected kind of simp theorem{indentExpr type}" - return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) } + match type.eq? with + | some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs) + | none => throwError "Unexpected kind of simp theorem{indentExpr type}" + +private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do + assert! origin != .fvar ⟨.anonymous⟩ + let type ← instantiateMVars (← inferType e) + let (keys, perm) ← mkSimpTheoremKeys type noIndexAtArgs + return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) } /-- Creates a `SimpTheorem` from a global theorem. @@ -427,6 +429,19 @@ def mkSimpTheoremFromExpr (id : Origin) (levelParams : Array Name) (proof : Expr (← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true) +/-- Creates a `SimpTheorem` from a definitional equality. -/ +def mkDSimpTheorem (id : Origin) (levelParams : Array Name) (type : Expr) + (post := true) (prio : Nat := eval_prio default) (config : ConfigWithKey := simpGlobalConfig) : + MetaM SimpTheorem := do + withConfigWithKey config do + let (keys, perm) ← mkSimpTheoremKeys type (noIndexAtArgs := true) + let proof ← forallTelescopeReducing type fun xs r => do + let some (_, lhs, _rhs) := r.eq? + | throwError "Unexpected kind of dsimp theorem{indentExpr type}" + -- We need to wrap the proof in a type hint, else the type is lost + mkExpectedTypeHint (← mkLambdaFVars xs (← mkEqRefl lhs)) type + return { origin := id, keys, perm, post, levelParams, proof, priority := prio, rfl := true } + /-- A simp theorem or information about a declaration to unfold by simp. This is stored in the oleans to implement the `simp` attribute and user-defined simp sets. diff --git a/tests/lean/run/methodSpecs.lean b/tests/lean/run/methodSpecs.lean new file mode 100644 index 0000000000..b3251eeb44 --- /dev/null +++ b/tests/lean/run/methodSpecs.lean @@ -0,0 +1,144 @@ +inductive L α where + | nil : L α + | cons : α → L α → L α + +def L.beqImpl [BEq α] : L α → L α → Bool + | nil, nil => true + | cons x xs, cons y ys => x == y && L.beqImpl xs ys + | _, _ => false + +@[method_specs] instance [BEq α] : BEq (L α) := ⟨L.beqImpl⟩ + +/-- +info: theorem instBEqL.beq_spec.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x x_1 : L α), + (x == x_1) = + match x, x_1 with + | L.nil, L.nil => true + | L.cons x xs, L.cons y ys => x == y && xs == ys + | x, x_2 => false +-/ +#guard_msgs(pass trace, all) in +#print sig instBEqL.beq_spec + +/-- +info: theorem instBEqL.beq_spec_1.{u_1} : ∀ {α : Type u_1} [inst : BEq α], (L.nil == L.nil) = true +-/ +#guard_msgs(pass trace, all) in +#print sig instBEqL.beq_spec_1 + +/-- +info: theorem instBEqL.beq_spec_2.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x_2 : α) (xs : L α) (y : α) (ys : L α), + (L.cons x_2 xs == L.cons y ys) = (x_2 == y && xs == ys) +-/ +#guard_msgs(pass trace, all) in +#print sig instBEqL.beq_spec_2 + +/-- +info: theorem instBEqL.beq_spec_3.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x x_1 : L α), + (x = L.nil → x_1 = L.nil → False) → + (∀ (x_2 : α) (xs : L α) (y : α) (ys : L α), x = L.cons x_2 xs → x_1 = L.cons y ys → False) → (x == x_1) = false +-/ +#guard_msgs(pass trace, all) in +#print sig instBEqL.beq_spec_3 + +/-- error: Unknown constant `instBEqL.beq_spec_4` -/ +#guard_msgs(pass trace, all) in +#print sig instBEqL.beq_spec_4 + +-- Other names are not reserved + +/-- error: Unknown constant `instBEqL.eq_spec` -/ +#guard_msgs in #print sig instBEqL.eq_spec + +/-- error: Unknown constant `instBEqL.beq_spec_` -/ +#guard_msgs in #print sig instBEqL.beq_spec_ + +-- Test rewriting all the way to HAppend + +@[method_specs_simp] theorem Append.append_eq_happend : + @Append.append α inst = @HAppend.hAppend α α α (@instHAppendOfAppend α inst) := rfl + +def L.append {α : Type u} : L α → L α → L α + | nil, ys => ys + | cons x xs, ys => cons x (L.append xs ys) + +@[method_specs] instance (α : Type u) : Append (L α) where + append := L.append + +/-- +info: theorem instAppendL.append_spec_2.{u} : ∀ {α : Type u} (x : L α) (x_2 : α) (xs : L α), + L.cons x_2 xs ++ x = L.cons x_2 (xs ++ x) +-/ +#guard_msgs in #print sig instAppendL.append_spec_2 + +-- Test that rewriting works with non-rfl theorem too + +class Cls α where op : α → α +class HCls α where hOp : α → α +instance instHClsOfCls [Cls α] : HCls α where hOp := Cls.op +-- NB: Not a rfl theorem +@[method_specs_simp] theorem Cls.op_eq_hOp : @Cls.op α inst = @HCls.hOp α (@instHClsOfCls α inst) := (rfl) + +def L.op {α : Type u} : L α → L α + | nil => nil + | cons x xs => cons x (L.op xs) +@[method_specs] instance : Cls (L α) where op := L.op + +/-- +info: theorem instClsL.op_spec_2.{u} : ∀ {α : Type u} (x_1 : α) (xs : L α), + HCls.hOp (L.cons x_1 xs) = L.cons x_1 (HCls.hOp xs) +-/ +#guard_msgs in +#print sig instClsL.op_spec_2 + +/-! +Now some error conditions +-/ + +/-- error: `Foo` is not a definition -/ +#guard_msgs in @[method_specs] inductive Foo + +/-- +error: expected `foo` to be a type class instance, but its type `Nat` does not look like a class. +-/ +#guard_msgs in @[method_specs] def foo := 1 + +structure S where field : Nat +/-- +error: expected `aS` to be a type class instance, but its type `S` does not look like a class. +-/ +#guard_msgs in @[method_specs] def aS : S := ⟨1⟩ + +@[class] inductive indClass where | mk +/-- error: `indClass` is not a structure -/ +#guard_msgs in @[method_specs] def instIndClass : indClass := .mk + +-- This used to fail until we eta-reduced the field values +@[method_specs] instance anotherInstBEqL [BEq α] : BEq (L α) := ⟨fun x y => L.beqImpl x y⟩ + +def L.badBeqImpl {α : Type u} : L α → L α → Bool + | nil, nil => true + | cons _ xs, cons _ ys => L.badBeqImpl xs ys + | _, _ => false + +/-- error: function `@L.badBeqImpl` does not take its arguments in the same order as the instance -/ +#guard_msgs in +@[method_specs] instance badInstBEqL [BEq α] : BEq (L α) := ⟨L.badBeqImpl⟩ + +-- L.append has a more general type (in terms of universes) +-- than the instance below. +-- This should be caught and warned about. + +def L.badAppend : L α → L α → L α + | nil, ys => ys + | cons x xs, ys => cons x (L.badAppend xs ys) + +/-- +error: function `@L.badAppend` is called with universe parameters + [u+1] +which differs from the instances' universe parameters + [u] +-/ +#guard_msgs in +@[method_specs] instance badInstAppendL (α : Type u) : Append (L α) where + append := L.badAppend