feat: save mutual block information for definitions/theorems/opaques

This commit is contained in:
Leonardo de Moura 2022-06-23 16:39:51 -07:00
parent ea3e27bbc4
commit 98c775da34
6 changed files with 154 additions and 16 deletions

View file

@ -409,6 +409,16 @@ def isInductive : ConstantInfo → Bool
| inductInfo _ => true
| _ => false
/--
List of all (including this one) declarations in the same mutual block.
-/
def all : ConstantInfo → List Name
| inductInfo val => val.all
| defnInfo val => val.all
| thmInfo val => val.all
| opaqueInfo val => val.all
| _ => []
@[extern "lean_instantiate_type_lparams"]
opaque instantiateTypeLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr

View file

@ -92,30 +92,32 @@ private def compileDecl (decl : Declaration) : TermElabM Bool := do
throw ex
return true
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (applyAttrAfterCompilation := true) : TermElabM Unit :=
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) : TermElabM Unit :=
withRef preDef.ref do
let preDef ← abstractNestedProofs preDef
let decl ←
match preDef.kind with
| DefKind.«theorem» =>
pure <| Declaration.thmDecl {
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value, all
}
| DefKind.«opaque» =>
pure <| Declaration.opaqueDecl {
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
isUnsafe := preDef.modifiers.isUnsafe
isUnsafe := preDef.modifiers.isUnsafe, all
}
| DefKind.«abbrev» =>
pure <| Declaration.defnDecl {
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
hints := ReducibilityHints.«abbrev»
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe }
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe,
all }
| _ => -- definitions and examples
pure <| Declaration.defnDecl {
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value,
hints := ReducibilityHints.regular (getMaxHeight (← getEnv) preDef.value + 1),
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe }
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
hints := ReducibilityHints.regular (getMaxHeight (← getEnv) preDef.value + 1)
safety := if preDef.modifiers.isUnsafe then DefinitionSafety.unsafe else DefinitionSafety.safe,
all }
addDecl decl
withSaveInfoContext do -- save new env
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
@ -128,11 +130,11 @@ private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (applyAttrAft
if applyAttrAfterCompilation then
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
def addAndCompileNonRec (preDef : PreDefinition) : TermElabM Unit := do
addNonRecAux preDef true
def addAndCompileNonRec (preDef : PreDefinition) (all : List Name := [preDef.declName]) : TermElabM Unit := do
addNonRecAux preDef (compile := true) (all := all)
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) : TermElabM Unit := do
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation)
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) : TermElabM Unit := do
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all)
/--
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
@ -146,13 +148,14 @@ def eraseRecAppSyntax (preDef : PreDefinition) : CoreM PreDefinition :=
def addAndCompileUnsafe (preDefs : Array PreDefinition) (safety := DefinitionSafety.unsafe) : TermElabM Unit := do
let preDefs ← preDefs.mapM fun d => eraseRecAppSyntax d
withRef preDefs[0].ref do
let all := preDefs.toList.map (·.declName)
let decl := Declaration.mutualDefnDecl <| ← preDefs.toList.mapM fun preDef => return {
name := preDef.declName
levelParams := preDef.levelParams
type := preDef.type
value := preDef.value
safety := safety
hints := ReducibilityHints.opaque
safety, all
}
addDecl decl
withSaveInfoContext do -- save new env

View file

@ -19,15 +19,16 @@ structure TerminationHints where
private def addAndCompilePartial (preDefs : Array PreDefinition) (useSorry := false) : TermElabM Unit := do
for preDef in preDefs do
trace[Elab.definition] "processing {preDef.declName}"
let all := preDefs.toList.map (·.declName)
forallTelescope preDef.type fun xs type => do
let val ← if useSorry then
let value ← if useSorry then
mkLambdaFVars xs (← mkSorry type (synthetic := true))
else
liftM <| mkInhabitantFor preDef.declName xs type
addNonRec { preDef with
kind := DefKind.«opaque»
value := val
}
value
} (all := all)
addAndCompilePartialRec preDefs
private def isNonRecursive (preDef : PreDefinition) : Bool :=

View file

@ -26,6 +26,7 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
if (← isOnlyOneUnaryDef preDefs fixedPrefixSize) then
return ()
let us := preDefNonRec.levelParams.map mkLevelParam
let all := preDefs.toList.map (·.declName)
for fidx in [:preDefs.size] do
let preDef := preDefs[fidx]
let value ← lambdaTelescope preDef.value fun xs _ => do
@ -47,7 +48,7 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
let arg ← mkSum 0 domain
mkLambdaFVars xs (mkApp (mkAppN (mkConst preDefNonRec.declName us) xs[:fixedPrefixSize]) arg)
trace[Elab.definition.wf] "{preDef.declName} := {value}"
addNonRec { preDef with value } (applyAttrAfterCompilation := false)
addNonRec { preDef with value } (applyAttrAfterCompilation := false) (all := all)
partial def withCommonTelescope (preDefs : Array PreDefinition) (k : Array Expr → Array Expr → TermElabM α) : TermElabM α :=
go #[] (preDefs.map (·.value))

View file

@ -0,0 +1,114 @@
import Lean
open Lean Meta in
def printMutualBlock (declName : Name) : MetaM Unit := do
IO.println (← getConstInfo declName).all
mutual
def even : Nat → Bool
| 0 => true
| x+1 => !odd x
def odd : Nat → Bool
| 0 => false
| x+1 => !even x
end
#eval printMutualBlock ``even
#eval printMutualBlock ``odd
namespace Ex1
mutual
partial def f (x : Nat) : Nat :=
if x < 10 then g x + 1 else 0
partial def g (x : Nat) : Nat :=
f (x * 3 / 2)
partial def h (x : Nat) : Nat :=
f x
end
#eval printMutualBlock ``f
#eval printMutualBlock ``g
-- Recall that Lean breaks a mutual block into strongly connected components
#eval printMutualBlock ``h
end Ex1
namespace Ex2
mutual
unsafe def f (x : Nat) : Nat :=
if x < 10 then g x + 1 else 0
unsafe def g (x : Nat) : Nat :=
f (x * 3 / 2)
unsafe def h (x : Nat) : Nat :=
f x
end
#eval printMutualBlock ``f
#eval printMutualBlock ``g
-- Recall that Lean breaks a mutual block into strongly connected components
#eval printMutualBlock ``h
end Ex2
inductive Foo where
| text : String → Foo
| element : List Foo → Foo
namespace Foo
mutual
@[simp] def textLengthList : List Foo → Nat
| [] => 0
| f::fs => textLength f + textLengthList fs
@[simp] def textLength : Foo → Nat
| text s => s.length
| element children => textLengthList children
end
def concat (f₁ f₂ : Foo) : Foo :=
Foo.element [f₁, f₂]
theorem textLength_concat (f₁ f₂ : Foo) : textLength (concat f₁ f₂) = textLength f₁ + textLength f₂ := by
simp [concat]
mutual
@[simp] def flatList : List Foo → List String
| [] => []
| f :: fs => flat f ++ flatList fs
@[simp] def flat : Foo → List String
| text s => [s]
| element children => flatList children
end
def listStringLen (xs : List String) : Nat :=
xs.foldl (init := 0) fun sum s => sum + s.length
attribute [simp] List.foldl
theorem foldl_init (s : Nat) (xs : List String) : (xs.foldl (init := s) fun sum s => sum + s.length) = s + (xs.foldl (init := 0) fun sum s => sum + s.length) := by
induction xs generalizing s with
| nil => simp
| cons x xs ih => simp_arith [ih x.length, ih (s + x.length)]
theorem listStringLen_append (xs ys : List String) : listStringLen (xs ++ ys) = listStringLen xs + listStringLen ys := by
simp [listStringLen]
induction xs with
| nil => simp
| cons x xs ih => simp_arith [foldl_init x.length, ih]
mutual
theorem listStringLen_flat (f : Foo) : listStringLen (flat f) = textLength f := by
match f with
| text s => simp [listStringLen]
| element cs => simp [listStringLen_flatList cs]
theorem listStringLen_flatList (cs : List Foo) : listStringLen (flatList cs) = textLengthList cs := by
match cs with
| [] => simp
| f :: fs => simp [listStringLen_append, listStringLen_flatList fs, listStringLen_flat f]
end
#eval printMutualBlock ``listStringLen_flat
end Foo

View file

@ -0,0 +1,9 @@
[even, odd]
[even, odd]
[Ex1.f, Ex1.g]
[Ex1.f, Ex1.g]
[Ex1.h]
[Ex2.f, Ex2.g]
[Ex2.f, Ex2.g]
[Ex2.h]
[Foo.listStringLen_flat, Foo.listStringLen_flatList]