diff --git a/src/Lean/Declaration.lean b/src/Lean/Declaration.lean index f93a11f1c3..30336c054f 100644 --- a/src/Lean/Declaration.lean +++ b/src/Lean/Declaration.lean @@ -409,6 +409,16 @@ def isInductive : ConstantInfo → Bool | inductInfo _ => true | _ => false +/-- + List of all (including this one) declarations in the same mutual block. +-/ +def all : ConstantInfo → List Name + | inductInfo val => val.all + | defnInfo val => val.all + | thmInfo val => val.all + | opaqueInfo val => val.all + | _ => [] + @[extern "lean_instantiate_type_lparams"] opaque instantiateTypeLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr diff --git a/src/Lean/Elab/PreDefinition/Basic.lean b/src/Lean/Elab/PreDefinition/Basic.lean index aa9c37a5ec..e1d490f5ec 100644 --- a/src/Lean/Elab/PreDefinition/Basic.lean +++ b/src/Lean/Elab/PreDefinition/Basic.lean @@ -92,30 +92,32 @@ private def compileDecl (decl : Declaration) : TermElabM Bool := do throw ex return true -private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (applyAttrAfterCompilation := true) : TermElabM Unit := +private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) : TermElabM Unit := withRef preDef.ref do let preDef ← abstractNestedProofs preDef let decl ← match preDef.kind with | DefKind.«theorem» => pure <| Declaration.thmDecl { - name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value + name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value, all } | DefKind.«opaque» => pure <| Declaration.opaqueDecl { name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value - isUnsafe := preDef.modifiers.isUnsafe + isUnsafe := preDef.modifiers.isUnsafe, all } | DefKind.«abbrev» => pure <| Declaration.defnDecl { name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value hints := ReducibilityHints.«abbrev» - safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe } + safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe, + all } | _ => -- definitions and examples pure <| Declaration.defnDecl { - name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value, - hints := ReducibilityHints.regular (getMaxHeight (← getEnv) preDef.value + 1), - safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe } + name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) preDef.value + 1) + safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe, + all } addDecl decl withSaveInfoContext do -- save new env addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true) @@ -128,11 +130,11 @@ private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (applyAttrAft if applyAttrAfterCompilation then applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation -def addAndCompileNonRec (preDef : PreDefinition) : TermElabM Unit := do - addNonRecAux preDef true +def addAndCompileNonRec (preDef : PreDefinition) (all : List Name := [preDef.declName]) : TermElabM Unit := do + addNonRecAux preDef (compile := true) (all := all) -def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) : TermElabM Unit := do - addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) +def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) : TermElabM Unit := do + addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all) /-- Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module @@ -146,13 +148,14 @@ def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition := def addAndCompileUnsafe (preDefs : Array PreDefinition) (safety := DefinitionSafety.unsafe) : TermElabM Unit := do let preDefs ← preDefs.mapM fun d => eraseRecAppSyntax d withRef preDefs[0].ref do + let all := preDefs.toList.map (·.declName) let decl := Declaration.mutualDefnDecl <| ← preDefs.toList.mapM fun preDef => return { name := preDef.declName levelParams := preDef.levelParams type := preDef.type value := preDef.value - safety := safety hints := ReducibilityHints.opaque + safety, all } addDecl decl withSaveInfoContext do -- save new env diff --git a/src/Lean/Elab/PreDefinition/Main.lean b/src/Lean/Elab/PreDefinition/Main.lean index 53dff57c58..88c7293412 100644 --- a/src/Lean/Elab/PreDefinition/Main.lean +++ b/src/Lean/Elab/PreDefinition/Main.lean @@ -19,15 +19,16 @@ structure TerminationHints where private def addAndCompilePartial (preDefs : Array PreDefinition) (useSorry := false) : TermElabM Unit := do for preDef in preDefs do trace[Elab.definition] "processing {preDef.declName}" + let all := preDefs.toList.map (·.declName) forallTelescope preDef.type fun xs type => do - let val ← if useSorry then + let value ← if useSorry then mkLambdaFVars xs (← mkSorry type (synthetic := true)) else liftM <| mkInhabitantFor preDef.declName xs type addNonRec { preDef with kind := DefKind.«opaque» - value := val - } + value + } (all := all) addAndCompilePartialRec preDefs private def isNonRecursive (preDef : PreDefinition) : Bool := diff --git a/src/Lean/Elab/PreDefinition/WF/Main.lean b/src/Lean/Elab/PreDefinition/WF/Main.lean index 8e5e41883b..0db395002f 100644 --- a/src/Lean/Elab/PreDefinition/WF/Main.lean +++ b/src/Lean/Elab/PreDefinition/WF/Main.lean @@ -26,6 +26,7 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR if (← isOnlyOneUnaryDef preDefs fixedPrefixSize) then return () let us := preDefNonRec.levelParams.map mkLevelParam + let all := preDefs.toList.map (·.declName) for fidx in [:preDefs.size] do let preDef := preDefs[fidx] let value ← lambdaTelescope preDef.value fun xs _ => do @@ -47,7 +48,7 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR let arg ← mkSum 0 domain mkLambdaFVars xs (mkApp (mkAppN (mkConst preDefNonRec.declName us) xs[:fixedPrefixSize]) arg) trace[Elab.definition.wf] "{preDef.declName} := {value}" - addNonRec { preDef with value } (applyAttrAfterCompilation := false) + addNonRec { preDef with value } (applyAttrAfterCompilation := false) (all := all) partial def withCommonTelescope (preDefs : Array PreDefinition) (k : Array Expr → Array Expr → TermElabM α) : TermElabM α := go #[] (preDefs.map (·.value)) diff --git a/tests/lean/allFieldForConstants.lean b/tests/lean/allFieldForConstants.lean new file mode 100644 index 0000000000..b89e333f41 --- /dev/null +++ b/tests/lean/allFieldForConstants.lean @@ -0,0 +1,114 @@ +import Lean + +open Lean Meta in +def printMutualBlock (declName : Name) : MetaM Unit := do + IO.println (← getConstInfo declName).all + +mutual + def even : Nat → Bool + | 0 => true + | x+1 => !odd x + def odd : Nat → Bool + | 0 => false + | x+1 => !even x +end + +#eval printMutualBlock ``even +#eval printMutualBlock ``odd + +namespace Ex1 +mutual + partial def f (x : Nat) : Nat := + if x < 10 then g x + 1 else 0 + partial def g (x : Nat) : Nat := + f (x * 3 / 2) + partial def h (x : Nat) : Nat := + f x +end + +#eval printMutualBlock ``f +#eval printMutualBlock ``g +-- Recall that Lean breaks a mutual block into strongly connected components +#eval printMutualBlock ``h + +end Ex1 + +namespace Ex2 +mutual + unsafe def f (x : Nat) : Nat := + if x < 10 then g x + 1 else 0 + unsafe def g (x : Nat) : Nat := + f (x * 3 / 2) + unsafe def h (x : Nat) : Nat := + f x +end + +#eval printMutualBlock ``f +#eval printMutualBlock ``g +-- Recall that Lean breaks a mutual block into strongly connected components +#eval printMutualBlock ``h +end Ex2 + +inductive Foo where + | text : String → Foo + | element : List Foo → Foo + +namespace Foo + +mutual + @[simp] def textLengthList : List Foo → Nat + | [] => 0 + | f::fs => textLength f + textLengthList fs + + @[simp] def textLength : Foo → Nat + | text s => s.length + | element children => textLengthList children +end + +def concat (f₁ f₂ : Foo) : Foo := + Foo.element [f₁, f₂] + +theorem textLength_concat (f₁ f₂ : Foo) : textLength (concat f₁ f₂) = textLength f₁ + textLength f₂ := by + simp [concat] + +mutual + @[simp] def flatList : List Foo → List String + | [] => [] + | f :: fs => flat f ++ flatList fs + + @[simp] def flat : Foo → List String + | text s => [s] + | element children => flatList children +end + +def listStringLen (xs : List String) : Nat := + xs.foldl (init := 0) fun sum s => sum + s.length + +attribute [simp] List.foldl + +theorem foldl_init (s : Nat) (xs : List String) : (xs.foldl (init := s) fun sum s => sum + s.length) = s + (xs.foldl (init := 0) fun sum s => sum + s.length) := by + induction xs generalizing s with + | nil => simp + | cons x xs ih => simp_arith [ih x.length, ih (s + x.length)] + +theorem listStringLen_append (xs ys : List String) : listStringLen (xs ++ ys) = listStringLen xs + listStringLen ys := by + simp [listStringLen] + induction xs with + | nil => simp + | cons x xs ih => simp_arith [foldl_init x.length, ih] + +mutual + theorem listStringLen_flat (f : Foo) : listStringLen (flat f) = textLength f := by + match f with + | text s => simp [listStringLen] + | element cs => simp [listStringLen_flatList cs] + + theorem listStringLen_flatList (cs : List Foo) : listStringLen (flatList cs) = textLengthList cs := by + match cs with + | [] => simp + | f :: fs => simp [listStringLen_append, listStringLen_flatList fs, listStringLen_flat f] +end + +#eval printMutualBlock ``listStringLen_flat + +end Foo diff --git a/tests/lean/allFieldForConstants.lean.expected.out b/tests/lean/allFieldForConstants.lean.expected.out new file mode 100644 index 0000000000..9131814468 --- /dev/null +++ b/tests/lean/allFieldForConstants.lean.expected.out @@ -0,0 +1,9 @@ +[even, odd] +[even, odd] +[Ex1.f, Ex1.g] +[Ex1.f, Ex1.g] +[Ex1.h] +[Ex2.f, Ex2.g] +[Ex2.f, Ex2.g] +[Ex2.h] +[Foo.listStringLen_flat, Foo.listStringLen_flatList]