feat: save mutual block information for definitions/theorems/opaques
This commit is contained in:
parent
ea3e27bbc4
commit
98c775da34
6 changed files with 154 additions and 16 deletions
|
|
@ -409,6 +409,16 @@ def isInductive : ConstantInfo → Bool
|
|||
| inductInfo _ => true
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
List of all (including this one) declarations in the same mutual block.
|
||||
-/
|
||||
def all : ConstantInfo → List Name
|
||||
| inductInfo val => val.all
|
||||
| defnInfo val => val.all
|
||||
| thmInfo val => val.all
|
||||
| opaqueInfo val => val.all
|
||||
| _ => []
|
||||
|
||||
@[extern "lean_instantiate_type_lparams"]
|
||||
opaque instantiateTypeLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr
|
||||
|
||||
|
|
|
|||
|
|
@ -92,30 +92,32 @@ private def compileDecl (decl : Declaration) : TermElabM Bool := do
|
|||
throw ex
|
||||
return true
|
||||
|
||||
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (applyAttrAfterCompilation := true) : TermElabM Unit :=
|
||||
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) : TermElabM Unit :=
|
||||
withRef preDef.ref do
|
||||
let preDef ← abstractNestedProofs preDef
|
||||
let decl ←
|
||||
match preDef.kind with
|
||||
| DefKind.«theorem» =>
|
||||
pure <| Declaration.thmDecl {
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value, all
|
||||
}
|
||||
| DefKind.«opaque» =>
|
||||
pure <| Declaration.opaqueDecl {
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
|
||||
isUnsafe := preDef.modifiers.isUnsafe
|
||||
isUnsafe := preDef.modifiers.isUnsafe, all
|
||||
}
|
||||
| DefKind.«abbrev» =>
|
||||
pure <| Declaration.defnDecl {
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
|
||||
hints := ReducibilityHints.«abbrev»
|
||||
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe }
|
||||
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe,
|
||||
all }
|
||||
| _ => -- definitions and examples
|
||||
pure <| Declaration.defnDecl {
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight (← getEnv) preDef.value + 1),
|
||||
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe }
|
||||
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
|
||||
hints := ReducibilityHints.regular (getMaxHeight (← getEnv) preDef.value + 1)
|
||||
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe,
|
||||
all }
|
||||
addDecl decl
|
||||
withSaveInfoContext do -- save new env
|
||||
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
|
||||
|
|
@ -128,11 +130,11 @@ private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (applyAttrAft
|
|||
if applyAttrAfterCompilation then
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
|
||||
def addAndCompileNonRec (preDef : PreDefinition) : TermElabM Unit := do
|
||||
addNonRecAux preDef true
|
||||
def addAndCompileNonRec (preDef : PreDefinition) (all : List Name := [preDef.declName]) : TermElabM Unit := do
|
||||
addNonRecAux preDef (compile := true) (all := all)
|
||||
|
||||
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) : TermElabM Unit := do
|
||||
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation)
|
||||
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) : TermElabM Unit := do
|
||||
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all)
|
||||
|
||||
/--
|
||||
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
|
||||
|
|
@ -146,13 +148,14 @@ def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
|
|||
def addAndCompileUnsafe (preDefs : Array PreDefinition) (safety := DefinitionSafety.unsafe) : TermElabM Unit := do
|
||||
let preDefs ← preDefs.mapM fun d => eraseRecAppSyntax d
|
||||
withRef preDefs[0].ref do
|
||||
let all := preDefs.toList.map (·.declName)
|
||||
let decl := Declaration.mutualDefnDecl <| ← preDefs.toList.mapM fun preDef => return {
|
||||
name := preDef.declName
|
||||
levelParams := preDef.levelParams
|
||||
type := preDef.type
|
||||
value := preDef.value
|
||||
safety := safety
|
||||
hints := ReducibilityHints.opaque
|
||||
safety, all
|
||||
}
|
||||
addDecl decl
|
||||
withSaveInfoContext do -- save new env
|
||||
|
|
|
|||
|
|
@ -19,15 +19,16 @@ structure TerminationHints where
|
|||
private def addAndCompilePartial (preDefs : Array PreDefinition) (useSorry := false) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
trace[Elab.definition] "processing {preDef.declName}"
|
||||
let all := preDefs.toList.map (·.declName)
|
||||
forallTelescope preDef.type fun xs type => do
|
||||
let val ← if useSorry then
|
||||
let value ← if useSorry then
|
||||
mkLambdaFVars xs (← mkSorry type (synthetic := true))
|
||||
else
|
||||
liftM <| mkInhabitantFor preDef.declName xs type
|
||||
addNonRec { preDef with
|
||||
kind := DefKind.«opaque»
|
||||
value := val
|
||||
}
|
||||
value
|
||||
} (all := all)
|
||||
addAndCompilePartialRec preDefs
|
||||
|
||||
private def isNonRecursive (preDef : PreDefinition) : Bool :=
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
|
|||
if (← isOnlyOneUnaryDef preDefs fixedPrefixSize) then
|
||||
return ()
|
||||
let us := preDefNonRec.levelParams.map mkLevelParam
|
||||
let all := preDefs.toList.map (·.declName)
|
||||
for fidx in [:preDefs.size] do
|
||||
let preDef := preDefs[fidx]
|
||||
let value ← lambdaTelescope preDef.value fun xs _ => do
|
||||
|
|
@ -47,7 +48,7 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
|
|||
let arg ← mkSum 0 domain
|
||||
mkLambdaFVars xs (mkApp (mkAppN (mkConst preDefNonRec.declName us) xs[:fixedPrefixSize]) arg)
|
||||
trace[Elab.definition.wf] "{preDef.declName} := {value}"
|
||||
addNonRec { preDef with value } (applyAttrAfterCompilation := false)
|
||||
addNonRec { preDef with value } (applyAttrAfterCompilation := false) (all := all)
|
||||
|
||||
partial def withCommonTelescope (preDefs : Array PreDefinition) (k : Array Expr → Array Expr → TermElabM α) : TermElabM α :=
|
||||
go #[] (preDefs.map (·.value))
|
||||
|
|
|
|||
114
tests/lean/allFieldForConstants.lean
Normal file
114
tests/lean/allFieldForConstants.lean
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
import Lean
|
||||
|
||||
open Lean Meta in
|
||||
def printMutualBlock (declName : Name) : MetaM Unit := do
|
||||
IO.println (← getConstInfo declName).all
|
||||
|
||||
mutual
|
||||
def even : Nat → Bool
|
||||
| 0 => true
|
||||
| x+1 => !odd x
|
||||
def odd : Nat → Bool
|
||||
| 0 => false
|
||||
| x+1 => !even x
|
||||
end
|
||||
|
||||
#eval printMutualBlock ``even
|
||||
#eval printMutualBlock ``odd
|
||||
|
||||
namespace Ex1
|
||||
mutual
|
||||
partial def f (x : Nat) : Nat :=
|
||||
if x < 10 then g x + 1 else 0
|
||||
partial def g (x : Nat) : Nat :=
|
||||
f (x * 3 / 2)
|
||||
partial def h (x : Nat) : Nat :=
|
||||
f x
|
||||
end
|
||||
|
||||
#eval printMutualBlock ``f
|
||||
#eval printMutualBlock ``g
|
||||
-- Recall that Lean breaks a mutual block into strongly connected components
|
||||
#eval printMutualBlock ``h
|
||||
|
||||
end Ex1
|
||||
|
||||
namespace Ex2
|
||||
mutual
|
||||
unsafe def f (x : Nat) : Nat :=
|
||||
if x < 10 then g x + 1 else 0
|
||||
unsafe def g (x : Nat) : Nat :=
|
||||
f (x * 3 / 2)
|
||||
unsafe def h (x : Nat) : Nat :=
|
||||
f x
|
||||
end
|
||||
|
||||
#eval printMutualBlock ``f
|
||||
#eval printMutualBlock ``g
|
||||
-- Recall that Lean breaks a mutual block into strongly connected components
|
||||
#eval printMutualBlock ``h
|
||||
end Ex2
|
||||
|
||||
inductive Foo where
|
||||
| text : String → Foo
|
||||
| element : List Foo → Foo
|
||||
|
||||
namespace Foo
|
||||
|
||||
mutual
|
||||
@[simp] def textLengthList : List Foo → Nat
|
||||
| [] => 0
|
||||
| f::fs => textLength f + textLengthList fs
|
||||
|
||||
@[simp] def textLength : Foo → Nat
|
||||
| text s => s.length
|
||||
| element children => textLengthList children
|
||||
end
|
||||
|
||||
def concat (f₁ f₂ : Foo) : Foo :=
|
||||
Foo.element [f₁, f₂]
|
||||
|
||||
theorem textLength_concat (f₁ f₂ : Foo) : textLength (concat f₁ f₂) = textLength f₁ + textLength f₂ := by
|
||||
simp [concat]
|
||||
|
||||
mutual
|
||||
@[simp] def flatList : List Foo → List String
|
||||
| [] => []
|
||||
| f :: fs => flat f ++ flatList fs
|
||||
|
||||
@[simp] def flat : Foo → List String
|
||||
| text s => [s]
|
||||
| element children => flatList children
|
||||
end
|
||||
|
||||
def listStringLen (xs : List String) : Nat :=
|
||||
xs.foldl (init := 0) fun sum s => sum + s.length
|
||||
|
||||
attribute [simp] List.foldl
|
||||
|
||||
theorem foldl_init (s : Nat) (xs : List String) : (xs.foldl (init := s) fun sum s => sum + s.length) = s + (xs.foldl (init := 0) fun sum s => sum + s.length) := by
|
||||
induction xs generalizing s with
|
||||
| nil => simp
|
||||
| cons x xs ih => simp_arith [ih x.length, ih (s + x.length)]
|
||||
|
||||
theorem listStringLen_append (xs ys : List String) : listStringLen (xs ++ ys) = listStringLen xs + listStringLen ys := by
|
||||
simp [listStringLen]
|
||||
induction xs with
|
||||
| nil => simp
|
||||
| cons x xs ih => simp_arith [foldl_init x.length, ih]
|
||||
|
||||
mutual
|
||||
theorem listStringLen_flat (f : Foo) : listStringLen (flat f) = textLength f := by
|
||||
match f with
|
||||
| text s => simp [listStringLen]
|
||||
| element cs => simp [listStringLen_flatList cs]
|
||||
|
||||
theorem listStringLen_flatList (cs : List Foo) : listStringLen (flatList cs) = textLengthList cs := by
|
||||
match cs with
|
||||
| [] => simp
|
||||
| f :: fs => simp [listStringLen_append, listStringLen_flatList fs, listStringLen_flat f]
|
||||
end
|
||||
|
||||
#eval printMutualBlock ``listStringLen_flat
|
||||
|
||||
end Foo
|
||||
9
tests/lean/allFieldForConstants.lean.expected.out
Normal file
9
tests/lean/allFieldForConstants.lean.expected.out
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
[even, odd]
|
||||
[even, odd]
|
||||
[Ex1.f, Ex1.g]
|
||||
[Ex1.f, Ex1.g]
|
||||
[Ex1.h]
|
||||
[Ex2.f, Ex2.g]
|
||||
[Ex2.f, Ex2.g]
|
||||
[Ex2.h]
|
||||
[Foo.listStringLen_flat, Foo.listStringLen_flatList]
|
||||
Loading…
Add table
Reference in a new issue