chore: cleanup
This commit is contained in:
parent
0ab38742db
commit
525fb7ca91
5 changed files with 448 additions and 447 deletions
|
|
@ -19,134 +19,134 @@ A (potentially recursive) definition.
|
|||
The elaborator converts it into Kernel definitions using many different strategies.
|
||||
-/
|
||||
structure PreDefinition :=
|
||||
(kind : DefKind)
|
||||
(lparams : List Name)
|
||||
(modifiers : Modifiers)
|
||||
(declName : Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
(kind : DefKind)
|
||||
(lparams : List Name)
|
||||
(modifiers : Modifiers)
|
||||
(declName : Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
|
||||
instance : Inhabited PreDefinition :=
|
||||
⟨⟨DefKind.«def», [], {}, arbitrary _, arbitrary _, arbitrary _⟩⟩
|
||||
⟨⟨DefKind.«def», [], {}, arbitrary _, arbitrary _, arbitrary _⟩⟩
|
||||
|
||||
def instantiateMVarsAtPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDef => do
|
||||
pure { preDef with type := (← instantiateMVars preDef.type), value := (← instantiateMVars preDef.value) }
|
||||
preDefs.mapM fun preDef => do
|
||||
pure { preDef with type := (← instantiateMVars preDef.type), value := (← instantiateMVars preDef.value) }
|
||||
|
||||
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
|
||||
let nextIdx ← get;
|
||||
let (e, nextIdx) ← levelMVarToParam e nextIdx;
|
||||
set nextIdx;
|
||||
pure e
|
||||
let nextIdx ← get
|
||||
let (e, nextIdx) ← levelMVarToParam e nextIdx;
|
||||
set nextIdx;
|
||||
pure e
|
||||
|
||||
private def levelMVarToParamPreDeclsAux (preDefs : Array PreDefinition) : StateRefT Nat TermElabM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDef => do
|
||||
pure { preDef with type := (← levelMVarToParamExpr preDef.type), value := (← levelMVarToParamExpr preDef.value) }
|
||||
preDefs.mapM fun preDef => do
|
||||
pure { preDef with type := (← levelMVarToParamExpr preDef.type), value := (← levelMVarToParamExpr preDef.value) }
|
||||
|
||||
def levelMVarToParamPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
|
||||
(levelMVarToParamPreDeclsAux preDefs).run' 1
|
||||
(levelMVarToParamPreDeclsAux preDefs).run' 1
|
||||
|
||||
private def getLevelParamsPreDecls (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) := do
|
||||
let s : CollectLevelParams.State := {}
|
||||
for preDef in preDefs do
|
||||
s := collectLevelParams s preDef.type
|
||||
s := collectLevelParams s preDef.value
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
|
||||
| Except.error msg => throwError msg
|
||||
| Except.ok levelParams => pure levelParams
|
||||
let s : CollectLevelParams.State := {}
|
||||
for preDef in preDefs do
|
||||
s := collectLevelParams s preDef.type
|
||||
s := collectLevelParams s preDef.value
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
|
||||
| Except.error msg => throwError msg
|
||||
| Except.ok levelParams => pure levelParams
|
||||
|
||||
private def shareCommon (preDefs : Array PreDefinition) : Array PreDefinition :=
|
||||
let result : Std.ShareCommonM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDef => do
|
||||
pure { preDef with type := (← Std.withShareCommon preDef.type), value := (← Std.withShareCommon preDef.value) }
|
||||
result.run
|
||||
let result : Std.ShareCommonM (Array PreDefinition) :=
|
||||
preDefs.mapM fun preDef => do
|
||||
pure { preDef with type := (← Std.withShareCommon preDef.type), value := (← Std.withShareCommon preDef.value) }
|
||||
result.run
|
||||
|
||||
def fixLevelParams (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDefinition) := do
|
||||
let preDefs := shareCommon preDefs
|
||||
let lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames
|
||||
let us := lparams.map mkLevelParam
|
||||
let fixExpr (e : Expr) : Expr :=
|
||||
e.replace fun c => match c with
|
||||
| Expr.const declName _ _ => if preDefs.any fun preDef => preDef.declName == declName then some $ Lean.mkConst declName us else none
|
||||
| _ => none
|
||||
pure $ preDefs.map fun preDef =>
|
||||
{ preDef with
|
||||
type := fixExpr preDef.type,
|
||||
value := fixExpr preDef.value,
|
||||
lparams := lparams }
|
||||
let preDefs := shareCommon preDefs
|
||||
let lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames
|
||||
let us := lparams.map mkLevelParam
|
||||
let fixExpr (e : Expr) : Expr :=
|
||||
e.replace fun c => match c with
|
||||
| Expr.const declName _ _ => if preDefs.any fun preDef => preDef.declName == declName then some $ Lean.mkConst declName us else none
|
||||
| _ => none
|
||||
pure $ preDefs.map fun preDef =>
|
||||
{ preDef with
|
||||
type := fixExpr preDef.type,
|
||||
value := fixExpr preDef.value,
|
||||
lparams := lparams }
|
||||
|
||||
def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
applyAttributesAt preDef.declName preDef.modifiers.attrs applicationTime
|
||||
for preDef in preDefs do
|
||||
applyAttributesAt preDef.declName preDef.modifiers.attrs applicationTime
|
||||
|
||||
def abstractNestedProofs (preDef : PreDefinition) : MetaM PreDefinition :=
|
||||
if preDef.kind.isTheorem || preDef.kind.isExample then
|
||||
pure preDef
|
||||
else do
|
||||
let value ← Meta.abstractNestedProofs preDef.declName preDef.value
|
||||
pure { preDef with value := value }
|
||||
if preDef.kind.isTheorem || preDef.kind.isExample then
|
||||
pure preDef
|
||||
else do
|
||||
let value ← Meta.abstractNestedProofs preDef.declName preDef.value
|
||||
pure { preDef with value := value }
|
||||
|
||||
/- Auxiliary method for (temporarily) adding pre definition as an axiom -/
|
||||
def addAsAxiom (preDef : PreDefinition) : MetaM Unit := do
|
||||
addDecl $ Declaration.axiomDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, isUnsafe := preDef.modifiers.isUnsafe }
|
||||
addDecl $ Declaration.axiomDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, isUnsafe := preDef.modifiers.isUnsafe }
|
||||
|
||||
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) : TermElabM Unit := do
|
||||
let preDef ← abstractNestedProofs preDef
|
||||
let env ← getEnv
|
||||
let decl :=
|
||||
match preDef.kind with
|
||||
| DefKind.«example» => unreachable!
|
||||
| DefKind.«theorem» =>
|
||||
Declaration.thmDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value }
|
||||
| DefKind.«opaque» =>
|
||||
Declaration.opaqueDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
|
||||
let preDef ← abstractNestedProofs preDef
|
||||
let env ← getEnv
|
||||
let decl :=
|
||||
match preDef.kind with
|
||||
| DefKind.«example» => unreachable!
|
||||
| DefKind.«theorem» =>
|
||||
Declaration.thmDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value }
|
||||
| DefKind.«opaque» =>
|
||||
Declaration.opaqueDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
|
||||
isUnsafe := preDef.modifiers.isUnsafe }
|
||||
| DefKind.«abbrev» =>
|
||||
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
|
||||
hints := ReducibilityHints.«abbrev», isUnsafe := preDef.modifiers.isUnsafe }
|
||||
| DefKind.«def» =>
|
||||
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight env preDef.value + 1),
|
||||
isUnsafe := preDef.modifiers.isUnsafe }
|
||||
| DefKind.«abbrev» =>
|
||||
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
|
||||
hints := ReducibilityHints.«abbrev», isUnsafe := preDef.modifiers.isUnsafe }
|
||||
| DefKind.«def» =>
|
||||
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight env preDef.value + 1),
|
||||
isUnsafe := preDef.modifiers.isUnsafe }
|
||||
addDecl decl
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
|
||||
when (compile && !preDef.kind.isTheorem) $
|
||||
compileDecl decl
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
pure ()
|
||||
addDecl decl
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
|
||||
if compile && !preDef.kind.isTheorem then
|
||||
compileDecl decl
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
pure ()
|
||||
|
||||
def addAndCompileNonRec (preDef : PreDefinition) : TermElabM Unit := do
|
||||
addNonRecAux preDef true
|
||||
addNonRecAux preDef true
|
||||
|
||||
def addNonRec (preDef : PreDefinition) : TermElabM Unit := do
|
||||
addNonRecAux preDef false
|
||||
addNonRecAux preDef false
|
||||
|
||||
def addAndCompileUnsafe (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDef => {
|
||||
name := preDef.declName,
|
||||
lparams := preDef.lparams,
|
||||
type := preDef.type,
|
||||
value := preDef.value,
|
||||
isUnsafe := true,
|
||||
hints := ReducibilityHints.opaque
|
||||
}
|
||||
addDecl decl
|
||||
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking
|
||||
compileDecl decl
|
||||
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation
|
||||
pure ()
|
||||
let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDef => {
|
||||
name := preDef.declName,
|
||||
lparams := preDef.lparams,
|
||||
type := preDef.type,
|
||||
value := preDef.value,
|
||||
isUnsafe := true,
|
||||
hints := ReducibilityHints.opaque
|
||||
}
|
||||
addDecl decl
|
||||
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking
|
||||
compileDecl decl
|
||||
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation
|
||||
pure ()
|
||||
|
||||
def addAndCompileUnsafeRec (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
addAndCompileUnsafe $ preDefs.map fun preDef =>
|
||||
{ preDef with
|
||||
declName := Compiler.mkUnsafeRecName preDef.declName,
|
||||
value := preDef.value.replace fun e => match e with
|
||||
| Expr.const declName us _ =>
|
||||
if preDefs.any fun preDef => preDef.declName == declName then
|
||||
some $ mkConst (Compiler.mkUnsafeRecName declName) us
|
||||
else
|
||||
none
|
||||
| _ => none,
|
||||
modifiers := {} }
|
||||
addAndCompileUnsafe $ preDefs.map fun preDef =>
|
||||
{ preDef with
|
||||
declName := Compiler.mkUnsafeRecName preDef.declName,
|
||||
value := preDef.value.replace fun e => match e with
|
||||
| Expr.const declName us _ =>
|
||||
if preDefs.any fun preDef => preDef.declName == declName then
|
||||
some $ mkConst (Compiler.mkUnsafeRecName declName) us
|
||||
else
|
||||
none
|
||||
| _ => none,
|
||||
modifiers := {} }
|
||||
|
||||
end Lean.Elab
|
||||
|
|
|
|||
|
|
@ -12,69 +12,70 @@ open Meta
|
|||
open Term
|
||||
|
||||
private def addAndCompilePartial (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
trace[Elab.definition]! "processing {preDef.declName}"
|
||||
forallTelescope preDef.type fun xs type => do
|
||||
let inh ← liftM $ mkInhabitantFor preDef.declName xs type
|
||||
trace[Elab.definition]! "inhabitant for {preDef.declName}"
|
||||
addNonRec { preDef with
|
||||
kind := DefKind.«opaque»,
|
||||
value := inh }
|
||||
addAndCompileUnsafeRec preDefs
|
||||
for preDef in preDefs do
|
||||
trace[Elab.definition]! "processing {preDef.declName}"
|
||||
forallTelescope preDef.type fun xs type => do
|
||||
let inh ← liftM $ mkInhabitantFor preDef.declName xs type
|
||||
trace[Elab.definition]! "inhabitant for {preDef.declName}"
|
||||
addNonRec { preDef with
|
||||
kind := DefKind.«opaque»,
|
||||
value := inh
|
||||
}
|
||||
addAndCompileUnsafeRec preDefs
|
||||
|
||||
private def isNonRecursive (preDef : PreDefinition) : Bool :=
|
||||
Option.isNone $ preDef.value.find? fun
|
||||
| Expr.const declName _ _ => preDef.declName == declName
|
||||
| _ => false
|
||||
Option.isNone $ preDef.value.find? fun
|
||||
| Expr.const declName _ _ => preDef.declName == declName
|
||||
| _ => false
|
||||
|
||||
private def partitionPreDefs (preDefs : Array PreDefinition) : Array (Array PreDefinition) :=
|
||||
let getPreDef := fun declName => (preDefs.find? fun preDef => preDef.declName == declName).get!
|
||||
let vertices := preDefs.toList.map (·.declName)
|
||||
let successorsOf := fun declName => (getPreDef declName).value.foldConsts [] fun declName successors =>
|
||||
if preDefs.any fun preDef => preDef.declName == declName then
|
||||
declName :: successors
|
||||
else
|
||||
successors
|
||||
let sccs := SCC.scc vertices successorsOf
|
||||
sccs.toArray.map fun scc => scc.toArray.map getPreDef
|
||||
let getPreDef := fun declName => (preDefs.find? fun preDef => preDef.declName == declName).get!
|
||||
let vertices := preDefs.toList.map (·.declName)
|
||||
let successorsOf := fun declName => (getPreDef declName).value.foldConsts [] fun declName successors =>
|
||||
if preDefs.any fun preDef => preDef.declName == declName then
|
||||
declName :: successors
|
||||
else
|
||||
successors
|
||||
let sccs := SCC.scc vertices successorsOf
|
||||
sccs.toArray.map fun scc => scc.toArray.map getPreDef
|
||||
|
||||
private def collectMVarsAtPreDef (preDef : PreDefinition) : StateRefT CollectMVars.State MetaM Unit := do
|
||||
collectMVars preDef.value
|
||||
collectMVars preDef.type
|
||||
collectMVars preDef.value
|
||||
collectMVars preDef.type
|
||||
|
||||
private def getMVarsAtPreDef (preDef : PreDefinition) : MetaM (Array MVarId) := do
|
||||
let (_, s) ← (collectMVarsAtPreDef preDef).run {}
|
||||
pure s.result
|
||||
let (_, s) ← (collectMVarsAtPreDef preDef).run {}
|
||||
pure s.result
|
||||
|
||||
private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM Unit := do
|
||||
let pendingMVarIds ← liftMetaM $ getMVarsAtPreDef preDef
|
||||
if ← logUnassignedUsingErrorInfos pendingMVarIds then
|
||||
throwAbort
|
||||
let pendingMVarIds ← liftMetaM $ getMVarsAtPreDef preDef
|
||||
if ← logUnassignedUsingErrorInfos pendingMVarIds then
|
||||
throwAbort
|
||||
|
||||
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
trace[Elab.definition.body]! "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
|
||||
for preDef in preDefs do
|
||||
ensureNoUnassignedMVarsAtPreDef preDef
|
||||
for preDefs in partitionPreDefs preDefs do
|
||||
trace[Elab.definition.scc]! "{preDefs.map (·.declName)}"
|
||||
if preDefs.size == 1 && isNonRecursive preDefs[0] then
|
||||
let preDef := preDefs[0]
|
||||
if preDef.modifiers.isNoncomputable then
|
||||
addNonRec preDef
|
||||
for preDef in preDefs do
|
||||
trace[Elab.definition.body]! "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
|
||||
for preDef in preDefs do
|
||||
ensureNoUnassignedMVarsAtPreDef preDef
|
||||
for preDefs in partitionPreDefs preDefs do
|
||||
trace[Elab.definition.scc]! "{preDefs.map (·.declName)}"
|
||||
if preDefs.size == 1 && isNonRecursive preDefs[0] then
|
||||
let preDef := preDefs[0]
|
||||
if preDef.modifiers.isNoncomputable then
|
||||
addNonRec preDef
|
||||
else
|
||||
addAndCompileNonRec preDef
|
||||
else if preDefs.any (·.modifiers.isUnsafe) then
|
||||
addAndCompileUnsafe preDefs
|
||||
else if preDefs.any (·.modifiers.isPartial) then
|
||||
addAndCompilePartial preDefs
|
||||
else
|
||||
addAndCompileNonRec preDef
|
||||
else if preDefs.any (·.modifiers.isUnsafe) then
|
||||
addAndCompileUnsafe preDefs
|
||||
else if preDefs.any (·.modifiers.isPartial) then
|
||||
addAndCompilePartial preDefs
|
||||
else
|
||||
mapError
|
||||
(orelseMergeErrors
|
||||
(structuralRecursion preDefs)
|
||||
(WFRecursion preDefs))
|
||||
(fun msg =>
|
||||
let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName)
|
||||
msg!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}")
|
||||
mapError
|
||||
(orelseMergeErrors
|
||||
(structuralRecursion preDefs)
|
||||
(WFRecursion preDefs))
|
||||
(fun msg =>
|
||||
let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName)
|
||||
msg!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}")
|
||||
|
||||
end Lean.Elab
|
||||
|
|
|
|||
|
|
@ -9,36 +9,36 @@ namespace Lean.Elab
|
|||
open Meta
|
||||
|
||||
private def mkInhabitant? (type : Expr) : MetaM (Option Expr) := do
|
||||
try
|
||||
pure $ some (← mkAppM `arbitrary #[type])
|
||||
catch _ =>
|
||||
pure none
|
||||
try
|
||||
pure $ some (← mkAppM `arbitrary #[type])
|
||||
catch _ =>
|
||||
pure none
|
||||
|
||||
private def findAssumption? (xs : Array Expr) (type : Expr) : MetaM (Option Expr) := do
|
||||
xs.findM? fun x => do isDefEq (← inferType x) type
|
||||
xs.findM? fun x => do isDefEq (← inferType x) type
|
||||
|
||||
private def mkFnInhabitant? (xs : Array Expr) (type : Expr) : MetaM (Option Expr) :=
|
||||
let rec loop
|
||||
| 0, type => mkInhabitant? type
|
||||
| i+1, type => do
|
||||
let x := xs[i]
|
||||
let type ← mkForallFVars #[x] type;
|
||||
match ← mkInhabitant? type with
|
||||
| none => loop i type
|
||||
| some val => pure $ some (← mkLambdaFVars xs[0:i] val)
|
||||
loop xs.size type
|
||||
let rec loop
|
||||
| 0, type => mkInhabitant? type
|
||||
| i+1, type => do
|
||||
let x := xs[i]
|
||||
let type ← mkForallFVars #[x] type;
|
||||
match ← mkInhabitant? type with
|
||||
| none => loop i type
|
||||
| some val => pure $ some (← mkLambdaFVars xs[0:i] val)
|
||||
loop xs.size type
|
||||
|
||||
/- TODO: add a global IO.Ref to let users customize/extend this procedure -/
|
||||
|
||||
def mkInhabitantFor (declName : Name) (xs : Array Expr) (type : Expr) : MetaM Expr := do
|
||||
match ← mkInhabitant? type with
|
||||
| some val => mkLambdaFVars xs val
|
||||
| none =>
|
||||
match ← findAssumption? xs type with
|
||||
| some x => mkLambdaFVars xs x
|
||||
| none =>
|
||||
match ← mkFnInhabitant? xs type with
|
||||
| some val => pure val
|
||||
| none => throwError! "failed to compile partial definition '{declName}', failed to show that type is inhabited"
|
||||
match ← mkInhabitant? type with
|
||||
| some val => mkLambdaFVars xs val
|
||||
| none =>
|
||||
match ← findAssumption? xs type with
|
||||
| some x => mkLambdaFVars xs x
|
||||
| none =>
|
||||
match ← mkFnInhabitant? xs type with
|
||||
| some val => pure val
|
||||
| none => throwError! "failed to compile partial definition '{declName}', failed to show that type is inhabited"
|
||||
|
||||
end Lean.Elab
|
||||
|
|
|
|||
|
|
@ -13,173 +13,173 @@ namespace Lean.Elab
|
|||
open Meta
|
||||
|
||||
private def getFixedPrefix (declName : Name) (xs : Array Expr) (value : Expr) : Nat :=
|
||||
let visitor {ω} : StateRefT Nat (ST ω) Unit :=
|
||||
value.forEach' fun e =>
|
||||
if e.isAppOf declName then do
|
||||
let args := e.getAppArgs
|
||||
modify fun numFixed => if args.size < numFixed then args.size else numFixed
|
||||
-- we continue searching if the e's arguments are not a prefix of `xs`
|
||||
pure !args.isPrefixOf xs
|
||||
else
|
||||
pure true
|
||||
runST fun _ => do let (_, numFixed) ← visitor.run xs.size; pure numFixed
|
||||
let visitor {ω} : StateRefT Nat (ST ω) Unit :=
|
||||
value.forEach' fun e =>
|
||||
if e.isAppOf declName then do
|
||||
let args := e.getAppArgs
|
||||
modify fun numFixed => if args.size < numFixed then args.size else numFixed
|
||||
-- we continue searching if the e's arguments are not a prefix of `xs`
|
||||
pure !args.isPrefixOf xs
|
||||
else
|
||||
pure true
|
||||
runST fun _ => do let (_, numFixed) ← visitor.run xs.size; pure numFixed
|
||||
|
||||
structure RecArgInfo :=
|
||||
/- `fixedParams ++ ys` are the arguments of the function we are trying to justify termination using structural recursion. -/
|
||||
(fixedParams : Array Expr)
|
||||
(ys : Array Expr) -- recursion arguments
|
||||
(pos : Nat) -- position in `ys` of the argument we are recursing on
|
||||
(indicesPos : Array Nat) -- position in `ys` of the inductive datatype indices we are recursing on
|
||||
(indName : Name) -- inductive datatype name of the argument we are recursing on
|
||||
(indLevels : List Level) -- inductice datatype universe levels of the argument we are recursing on
|
||||
(indParams : Array Expr) -- inductive datatype parameters of the argument we are recursing on
|
||||
(indIndices : Array Expr) -- inductive datatype indices of the argument we are recursing on, it is equal to `indicesPos.map fun i => ys.get! i`
|
||||
(reflexive : Bool) -- true if we are recursing over a reflexive inductive datatype
|
||||
/- `fixedParams ++ ys` are the arguments of the function we are trying to justify termination using structural recursion. -/
|
||||
(fixedParams : Array Expr)
|
||||
(ys : Array Expr) -- recursion arguments
|
||||
(pos : Nat) -- position in `ys` of the argument we are recursing on
|
||||
(indicesPos : Array Nat) -- position in `ys` of the inductive datatype indices we are recursing on
|
||||
(indName : Name) -- inductive datatype name of the argument we are recursing on
|
||||
(indLevels : List Level) -- inductice datatype universe levels of the argument we are recursing on
|
||||
(indParams : Array Expr) -- inductive datatype parameters of the argument we are recursing on
|
||||
(indIndices : Array Expr) -- inductive datatype indices of the argument we are recursing on, it is equal to `indicesPos.map fun i => ys.get! i`
|
||||
(reflexive : Bool) -- true if we are recursing over a reflexive inductive datatype
|
||||
|
||||
private def getIndexMinPos (xs : Array Expr) (indices : Array Expr) : Nat := do
|
||||
let minPos := xs.size
|
||||
for index in indices do
|
||||
match xs.indexOf? index with
|
||||
| some pos => if pos.val < minPos then minPos := pos.val
|
||||
| _ => pure ()
|
||||
return minPos
|
||||
let minPos := xs.size
|
||||
for index in indices do
|
||||
match xs.indexOf? index with
|
||||
| some pos => if pos.val < minPos then minPos := pos.val
|
||||
| _ => pure ()
|
||||
return minPos
|
||||
|
||||
-- Indices can only depend on other indices
|
||||
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
for index in indices do
|
||||
let indexType ← inferType index
|
||||
for y in ys do
|
||||
if !indices.contains y && (← dependsOn indexType y.fvarId!) then
|
||||
return some (index, y)
|
||||
return none
|
||||
for index in indices do
|
||||
let indexType ← inferType index
|
||||
for y in ys do
|
||||
if !indices.contains y && (← dependsOn indexType y.fvarId!) then
|
||||
return some (index, y)
|
||||
return none
|
||||
|
||||
-- Inductive datatype parameters cannot depend on ys
|
||||
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
for p in indParams do
|
||||
let pType ← inferType p
|
||||
for y in ys do
|
||||
if ← dependsOn pType y.fvarId! then
|
||||
return some (p, y)
|
||||
return none
|
||||
for p in indParams do
|
||||
let pType ← inferType p
|
||||
for y in ys do
|
||||
if ← dependsOn pType y.fvarId! then
|
||||
return some (p, y)
|
||||
return none
|
||||
|
||||
private def throwStructuralFailed {α} : MetaM α :=
|
||||
throwError "structural recursion cannot be used"
|
||||
throwError "structural recursion cannot be used"
|
||||
|
||||
private partial def findRecArg {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : MetaM α :=
|
||||
let rec loop (i : Nat) : MetaM α := do
|
||||
if h : i < xs.size then
|
||||
let x := xs.get ⟨i, h⟩
|
||||
let localDecl ← getFVarLocalDecl x
|
||||
if localDecl.isLet then
|
||||
throwStructuralFailed
|
||||
else
|
||||
let xType ← whnfD localDecl.type
|
||||
matchConstInduct xType.getAppFn (fun _ => loop (i+1)) fun indInfo us => do
|
||||
if !(← hasConst (mkBRecOnFor indInfo.name)) then
|
||||
loop (i+1)
|
||||
else if indInfo.isReflexive && !(← hasConst (mkBInductionOnFor indInfo.name)) then
|
||||
loop (i+1)
|
||||
let rec loop (i : Nat) : MetaM α := do
|
||||
if h : i < xs.size then
|
||||
let x := xs.get ⟨i, h⟩
|
||||
let localDecl ← getFVarLocalDecl x
|
||||
if localDecl.isLet then
|
||||
throwStructuralFailed
|
||||
else
|
||||
let indArgs := xType.getAppArgs
|
||||
let indParams := indArgs.extract 0 indInfo.nparams
|
||||
let indIndices := indArgs.extract indInfo.nparams indArgs.size
|
||||
if !indIndices.all Expr.isFVar then
|
||||
orelseMergeErrors
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not variables{indentExpr xType}")
|
||||
(loop (i+1))
|
||||
else if !indIndices.allDiff then
|
||||
orelseMergeErrors
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not pairwise distinct{indentExpr xType}")
|
||||
(loop (i+1))
|
||||
let xType ← whnfD localDecl.type
|
||||
matchConstInduct xType.getAppFn (fun _ => loop (i+1)) fun indInfo us => do
|
||||
if !(← hasConst (mkBRecOnFor indInfo.name)) then
|
||||
loop (i+1)
|
||||
else if indInfo.isReflexive && !(← hasConst (mkBInductionOnFor indInfo.name)) then
|
||||
loop (i+1)
|
||||
else
|
||||
let indexMinPos := getIndexMinPos xs indIndices
|
||||
let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed
|
||||
let fixedParams := xs.extract 0 numFixed
|
||||
let ys := xs.extract numFixed xs.size
|
||||
match ← hasBadIndexDep? ys indIndices with
|
||||
| some (index, y) =>
|
||||
let indArgs := xType.getAppArgs
|
||||
let indParams := indArgs.extract 0 indInfo.nparams
|
||||
let indIndices := indArgs.extract indInfo.nparams indArgs.size
|
||||
if !indIndices.all Expr.isFVar then
|
||||
orelseMergeErrors
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}")
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not variables{indentExpr xType}")
|
||||
(loop (i+1))
|
||||
| none =>
|
||||
match ← hasBadParamDep? ys indParams with
|
||||
| some (indParam, y) =>
|
||||
else if !indIndices.allDiff then
|
||||
orelseMergeErrors
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not pairwise distinct{indentExpr xType}")
|
||||
(loop (i+1))
|
||||
else
|
||||
let indexMinPos := getIndexMinPos xs indIndices
|
||||
let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed
|
||||
let fixedParams := xs.extract 0 numFixed
|
||||
let ys := xs.extract numFixed xs.size
|
||||
match ← hasBadIndexDep? ys indIndices with
|
||||
| some (index, y) =>
|
||||
orelseMergeErrors
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive datatype{indentExpr xType}\nand parameter{indentExpr indParam}\ndepends on{indentExpr y}")
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}")
|
||||
(loop (i+1))
|
||||
| none =>
|
||||
let indicesPos := indIndices.map fun index => match ys.indexOf? index with | some i => i.val | none => unreachable!
|
||||
orelseMergeErrors
|
||||
(k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
|
||||
indicesPos := indicesPos,
|
||||
indName := indInfo.name,
|
||||
indLevels := us,
|
||||
indParams := indParams,
|
||||
indIndices := indIndices,
|
||||
reflexive := indInfo.isReflexive })
|
||||
(loop (i+1))
|
||||
else
|
||||
throwStructuralFailed
|
||||
loop numFixed
|
||||
match ← hasBadParamDep? ys indParams with
|
||||
| some (indParam, y) =>
|
||||
orelseMergeErrors
|
||||
(throwError! "argument #{i+1} was not used because its type is an inductive datatype{indentExpr xType}\nand parameter{indentExpr indParam}\ndepends on{indentExpr y}")
|
||||
(loop (i+1))
|
||||
| none =>
|
||||
let indicesPos := indIndices.map fun index => match ys.indexOf? index with | some i => i.val | none => unreachable!
|
||||
orelseMergeErrors
|
||||
(k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
|
||||
indicesPos := indicesPos,
|
||||
indName := indInfo.name,
|
||||
indLevels := us,
|
||||
indParams := indParams,
|
||||
indIndices := indIndices,
|
||||
reflexive := indInfo.isReflexive })
|
||||
(loop (i+1))
|
||||
else
|
||||
throwStructuralFailed
|
||||
loop numFixed
|
||||
|
||||
private def containsRecFn (recFnName : Name) (e : Expr) : Bool :=
|
||||
(e.find? fun e => e.isConstOf recFnName).isSome
|
||||
(e.find? fun e => e.isConstOf recFnName).isSome
|
||||
|
||||
private def ensureNoRecFn (recFnName : Name) (e : Expr) : MetaM Expr := do
|
||||
if containsRecFn recFnName e then
|
||||
Meta.forEachExpr e fun e => do
|
||||
if e.isAppOf recFnName then
|
||||
throwError! "unexpected occurrence of recursive application{indentExpr e}"
|
||||
pure e
|
||||
else
|
||||
pure e
|
||||
if containsRecFn recFnName e then
|
||||
Meta.forEachExpr e fun e => do
|
||||
if e.isAppOf recFnName then
|
||||
throwError! "unexpected occurrence of recursive application{indentExpr e}"
|
||||
pure e
|
||||
else
|
||||
pure e
|
||||
|
||||
private def throwToBelowFailed {α} : MetaM α :=
|
||||
throwError "toBelow failed"
|
||||
throwError "toBelow failed"
|
||||
|
||||
/- See toBelow -/
|
||||
private partial def toBelowAux (C : Expr) : Expr → Expr → Expr → MetaM Expr
|
||||
| belowDict, arg, F => do
|
||||
belowDict ← whnf belowDict
|
||||
trace[Elab.definition.structural]! "belowDict: {belowDict}, arg: {arg}"
|
||||
match belowDict with
|
||||
| Expr.app (Expr.app (Expr.const `PProd _ _) d1 _) d2 _ =>
|
||||
(do toBelowAux C d1 arg (← mkAppM `PProd.fst #[F]))
|
||||
<|>
|
||||
(do toBelowAux C d2 arg (← mkAppM `PProd.snd #[F]))
|
||||
| Expr.app (Expr.app (Expr.const `And _ _) d1 _) d2 _ =>
|
||||
(do toBelowAux C d1 arg (← mkAppM `And.left #[F]))
|
||||
<|>
|
||||
(do toBelowAux C d2 arg (← mkAppM `And.right #[F]))
|
||||
| _ => forallTelescopeReducing belowDict fun xs belowDict => do
|
||||
let argArgs := arg.getAppArgs
|
||||
unless argArgs.size >= xs.size do throwToBelowFailed
|
||||
let n := argArgs.size
|
||||
let argTailArgs := argArgs.extract (n - xs.size) n
|
||||
let belowDict := belowDict.replaceFVars xs argTailArgs
|
||||
| belowDict, arg, F => do
|
||||
belowDict ← whnf belowDict
|
||||
trace[Elab.definition.structural]! "belowDict: {belowDict}, arg: {arg}"
|
||||
match belowDict with
|
||||
| Expr.app belowDictFun belowDictArg _ =>
|
||||
unless belowDictFun.getAppFn == C do throwToBelowFailed
|
||||
unless ← isDefEq belowDictArg arg do throwToBelowFailed
|
||||
pure (mkAppN F argTailArgs)
|
||||
| _ => throwToBelowFailed
|
||||
| Expr.app (Expr.app (Expr.const `PProd _ _) d1 _) d2 _ =>
|
||||
(do toBelowAux C d1 arg (← mkAppM `PProd.fst #[F]))
|
||||
<|>
|
||||
(do toBelowAux C d2 arg (← mkAppM `PProd.snd #[F]))
|
||||
| Expr.app (Expr.app (Expr.const `And _ _) d1 _) d2 _ =>
|
||||
(do toBelowAux C d1 arg (← mkAppM `And.left #[F]))
|
||||
<|>
|
||||
(do toBelowAux C d2 arg (← mkAppM `And.right #[F]))
|
||||
| _ => forallTelescopeReducing belowDict fun xs belowDict => do
|
||||
let argArgs := arg.getAppArgs
|
||||
unless argArgs.size >= xs.size do throwToBelowFailed
|
||||
let n := argArgs.size
|
||||
let argTailArgs := argArgs.extract (n - xs.size) n
|
||||
let belowDict := belowDict.replaceFVars xs argTailArgs
|
||||
match belowDict with
|
||||
| Expr.app belowDictFun belowDictArg _ =>
|
||||
unless belowDictFun.getAppFn == C do throwToBelowFailed
|
||||
unless ← isDefEq belowDictArg arg do throwToBelowFailed
|
||||
pure (mkAppN F argTailArgs)
|
||||
| _ => throwToBelowFailed
|
||||
|
||||
/- See toBelow -/
|
||||
private def withBelowDict {α} (below : Expr) (numIndParams : Nat) (k : Expr → Expr → MetaM α) : MetaM α := do
|
||||
let belowType ← inferType below
|
||||
trace[Elab.definition.structural]! "belowType: {belowType}"
|
||||
belowType.withApp fun f args => do
|
||||
let motivePos := numIndParams + 1
|
||||
unless motivePos < args.size do throwError! "unexpected 'below' type{indentExpr belowType}"
|
||||
let pre := mkAppN f (args.extract 0 numIndParams)
|
||||
let preType ← inferType pre
|
||||
forallBoundedTelescope preType (some 1) fun x _ => do
|
||||
let motiveType ← inferType x[0]
|
||||
let C ← mkFreshUserName `C
|
||||
withLocalDeclD C motiveType fun C =>
|
||||
let belowDict := mkApp pre C
|
||||
let belowDict := mkAppN belowDict (args.extract (numIndParams + 1) args.size)
|
||||
k C belowDict
|
||||
let belowType ← inferType below
|
||||
trace[Elab.definition.structural]! "belowType: {belowType}"
|
||||
belowType.withApp fun f args => do
|
||||
let motivePos := numIndParams + 1
|
||||
unless motivePos < args.size do throwError! "unexpected 'below' type{indentExpr belowType}"
|
||||
let pre := mkAppN f (args.extract 0 numIndParams)
|
||||
let preType ← inferType pre
|
||||
forallBoundedTelescope preType (some 1) fun x _ => do
|
||||
let motiveType ← inferType x[0]
|
||||
let C ← mkFreshUserName `C
|
||||
withLocalDeclD C motiveType fun C =>
|
||||
let belowDict := mkApp pre C
|
||||
let belowDict := mkAppN belowDict (args.extract (numIndParams + 1) args.size)
|
||||
k C belowDict
|
||||
|
||||
/-
|
||||
`below` is a free variable with type of the form `I.below indParams motive indices major`,
|
||||
|
|
@ -202,160 +202,160 @@ belowType.withApp fun f args => do
|
|||
The dictionary is built using the `PProd` (`And` for inductive predicates).
|
||||
We keep searching it until we find `C recArg`, where `C` is the auxiliary fresh variable created at `withBelowDict`. -/
|
||||
private partial def toBelow (below : Expr) (numIndParams : Nat) (recArg : Expr) : MetaM Expr := do
|
||||
withBelowDict below numIndParams fun C belowDict =>
|
||||
toBelowAux C belowDict recArg below
|
||||
withBelowDict below numIndParams fun C belowDict =>
|
||||
toBelowAux C belowDict recArg below
|
||||
|
||||
private partial def replaceRecApps (recFnName : Name) (recArgInfo : RecArgInfo) (below : Expr) (e : Expr) : MetaM Expr :=
|
||||
let rec loop : Expr → Expr → MetaM Expr
|
||||
| below, e@(Expr.lam n d b c) => do
|
||||
withLocalDecl n c.binderInfo (← loop below d) fun x => do
|
||||
mkLambdaFVars #[x] (← loop below (b.instantiate1 x))
|
||||
| below, e@(Expr.forallE n d b c) => do
|
||||
withLocalDecl n c.binderInfo (← loop below d) fun x => do
|
||||
mkForallFVars #[x] (← loop below (b.instantiate1 x))
|
||||
| below, Expr.letE n type val body _ => do
|
||||
withLetDecl n (← loop below type) (← loop below val) fun x => do
|
||||
mkLetFVars #[x] (← loop below (body.instantiate1 x))
|
||||
| below, Expr.mdata d e _ => do pure $ mkMData d (← loop below e)
|
||||
| below, Expr.proj n i e _ => do pure $ mkProj n i (← loop below e)
|
||||
| below, e@(Expr.app _ _ _) => do
|
||||
let processApp (e : Expr) : MetaM Expr :=
|
||||
e.withApp fun f args => do
|
||||
if f.isConstOf recFnName then
|
||||
let numFixed := recArgInfo.fixedParams.size
|
||||
let recArgPos := recArgInfo.fixedParams.size + recArgInfo.pos
|
||||
if recArgPos >= args.size then
|
||||
throwError! "insufficient number of parameters at recursive application {indentExpr e}"
|
||||
let recArg := args[recArgPos]
|
||||
let f ← try toBelow below recArgInfo.indParams.size recArg catch _ => throwError! "failed to eliminate recursive application{indentExpr e}"
|
||||
-- Recall that the fixed parameters are not in the scope of the `brecOn`. So, we skip them.
|
||||
let argsNonFixed := args.extract numFixed args.size
|
||||
-- The function `f` does not explicitly take `recArg` and its indices as arguments. So, we skip them too.
|
||||
let fArgs := #[]
|
||||
for i in [:argsNonFixed.size] do
|
||||
if recArgInfo.pos != i && !recArgInfo.indicesPos.contains i then
|
||||
let arg := argsNonFixed[i]
|
||||
let arg ← replaceRecApps recFnName recArgInfo below arg
|
||||
fArgs := fArgs.push arg
|
||||
pure $ mkAppN f fArgs
|
||||
let rec loop : Expr → Expr → MetaM Expr
|
||||
| below, e@(Expr.lam n d b c) => do
|
||||
withLocalDecl n c.binderInfo (← loop below d) fun x => do
|
||||
mkLambdaFVars #[x] (← loop below (b.instantiate1 x))
|
||||
| below, e@(Expr.forallE n d b c) => do
|
||||
withLocalDecl n c.binderInfo (← loop below d) fun x => do
|
||||
mkForallFVars #[x] (← loop below (b.instantiate1 x))
|
||||
| below, Expr.letE n type val body _ => do
|
||||
withLetDecl n (← loop below type) (← loop below val) fun x => do
|
||||
mkLetFVars #[x] (← loop below (body.instantiate1 x))
|
||||
| below, Expr.mdata d e _ => do pure $ mkMData d (← loop below e)
|
||||
| below, Expr.proj n i e _ => do pure $ mkProj n i (← loop below e)
|
||||
| below, e@(Expr.app _ _ _) => do
|
||||
let processApp (e : Expr) : MetaM Expr :=
|
||||
e.withApp fun f args => do
|
||||
if f.isConstOf recFnName then
|
||||
let numFixed := recArgInfo.fixedParams.size
|
||||
let recArgPos := recArgInfo.fixedParams.size + recArgInfo.pos
|
||||
if recArgPos >= args.size then
|
||||
throwError! "insufficient number of parameters at recursive application {indentExpr e}"
|
||||
let recArg := args[recArgPos]
|
||||
let f ← try toBelow below recArgInfo.indParams.size recArg catch _ => throwError! "failed to eliminate recursive application{indentExpr e}"
|
||||
-- Recall that the fixed parameters are not in the scope of the `brecOn`. So, we skip them.
|
||||
let argsNonFixed := args.extract numFixed args.size
|
||||
-- The function `f` does not explicitly take `recArg` and its indices as arguments. So, we skip them too.
|
||||
let fArgs := #[]
|
||||
for i in [:argsNonFixed.size] do
|
||||
if recArgInfo.pos != i && !recArgInfo.indicesPos.contains i then
|
||||
let arg := argsNonFixed[i]
|
||||
let arg ← replaceRecApps recFnName recArgInfo below arg
|
||||
fArgs := fArgs.push arg
|
||||
pure $ mkAppN f fArgs
|
||||
else
|
||||
pure $ mkAppN (← loop below f) (← args.mapM (loop below))
|
||||
let matcherApp? ← matchMatcherApp? e
|
||||
match matcherApp? with
|
||||
| some matcherApp =>
|
||||
if !containsRecFn recFnName e then
|
||||
processApp e
|
||||
else
|
||||
pure $ mkAppN (← loop below f) (← args.mapM (loop below))
|
||||
let matcherApp? ← matchMatcherApp? e
|
||||
match matcherApp? with
|
||||
| some matcherApp =>
|
||||
if !containsRecFn recFnName e then
|
||||
processApp e
|
||||
else
|
||||
/- If we first try to process the `match` as a regular application. If it fails, then we try to `push` the below over the dependent `match`.
|
||||
This is useful for examples such as:
|
||||
```
|
||||
def f (xs : List Nat) : Nat :=
|
||||
match xs with
|
||||
| [] => 0
|
||||
| y::ys =>
|
||||
match ys with
|
||||
| [] => 1
|
||||
| zs => f ys + 1
|
||||
```
|
||||
We are matching on `ys`, but still using `ys` in the second alternative.
|
||||
If we push the `below` argument over the dependent match it will be able to eliminate recursive call using `zs`.
|
||||
This trick is not sufficient for the slightly more complicated example:
|
||||
```
|
||||
def g (xs : List Nat) : Nat :=
|
||||
match xs with
|
||||
| [] => 0
|
||||
| y::ys =>
|
||||
match ys with
|
||||
| [] => 1
|
||||
| _::_::zs => g zs + 1
|
||||
| _ => g ys + 2
|
||||
```
|
||||
To make it work, users would have to write the last alternative as
|
||||
```
|
||||
| zs => g zs + 2
|
||||
```
|
||||
/- If we first try to process the `match` as a regular application. If it fails, then we try to `push` the below over the dependent `match`.
|
||||
This is useful for examples such as:
|
||||
```
|
||||
def f (xs : List Nat) : Nat :=
|
||||
match xs with
|
||||
| [] => 0
|
||||
| y::ys =>
|
||||
match ys with
|
||||
| [] => 1
|
||||
| zs => f ys + 1
|
||||
```
|
||||
We are matching on `ys`, but still using `ys` in the second alternative.
|
||||
If we push the `below` argument over the dependent match it will be able to eliminate recursive call using `zs`.
|
||||
This trick is not sufficient for the slightly more complicated example:
|
||||
```
|
||||
def g (xs : List Nat) : Nat :=
|
||||
match xs with
|
||||
| [] => 0
|
||||
| y::ys =>
|
||||
match ys with
|
||||
| [] => 1
|
||||
| _::_::zs => g zs + 1
|
||||
| _ => g ys + 2
|
||||
```
|
||||
To make it work, users would have to write the last alternative as
|
||||
```
|
||||
| zs => g zs + 2
|
||||
```
|
||||
|
||||
If this is too annoying in practice, we may replace `ys` with the matching term.
|
||||
This may generate weird error messages, when it doesn't work.
|
||||
-/
|
||||
processApp e
|
||||
<|>
|
||||
do let matcherApp ← mapError (matcherApp.addArg below) (fun msg => "failed to add `below` argument to 'matcher' application" ++ indentD msg)
|
||||
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
|
||||
lambdaTelescope alt fun xs altBody => do
|
||||
trace[Elab.definition.structural]! "altNumParams: {numParams}, xs: {xs}"
|
||||
unless xs.size >= numParams do
|
||||
throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
let belowForAlt := xs[numParams - 1]
|
||||
mkLambdaFVars xs (← loop belowForAlt altBody)
|
||||
pure { matcherApp with alts := altsNew }.toExpr
|
||||
| none => processApp e
|
||||
| _, e => ensureNoRecFn recFnName e
|
||||
loop below e
|
||||
If this is too annoying in practice, we may replace `ys` with the matching term.
|
||||
This may generate weird error messages, when it doesn't work.
|
||||
-/
|
||||
processApp e
|
||||
<|>
|
||||
do let matcherApp ← mapError (matcherApp.addArg below) (fun msg => "failed to add `below` argument to 'matcher' application" ++ indentD msg)
|
||||
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
|
||||
lambdaTelescope alt fun xs altBody => do
|
||||
trace[Elab.definition.structural]! "altNumParams: {numParams}, xs: {xs}"
|
||||
unless xs.size >= numParams do
|
||||
throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
let belowForAlt := xs[numParams - 1]
|
||||
mkLambdaFVars xs (← loop belowForAlt altBody)
|
||||
pure { matcherApp with alts := altsNew }.toExpr
|
||||
| none => processApp e
|
||||
| _, e => ensureNoRecFn recFnName e
|
||||
loop below e
|
||||
|
||||
private def mkBRecOn (recFnName : Name) (recArgInfo : RecArgInfo) (value : Expr) : MetaM Expr := do
|
||||
let type := (← inferType value).headBeta
|
||||
let major := recArgInfo.ys[recArgInfo.pos]
|
||||
let otherArgs := recArgInfo.ys.filter fun y => y != major && !recArgInfo.indIndices.contains y
|
||||
let motive ← mkForallFVars otherArgs type
|
||||
let brecOnUniv ← getLevel motive
|
||||
trace[Elab.definition.structural]! "brecOn univ: {brecOnUniv}"
|
||||
let useBInductionOn := recArgInfo.reflexive && brecOnUniv == levelZero
|
||||
if recArgInfo.reflexive && brecOnUniv != levelZero then
|
||||
brecOnUniv ← decLevel brecOnUniv
|
||||
let motive ← mkLambdaFVars (recArgInfo.indIndices.push major) motive
|
||||
trace[Elab.definition.structural]! "brecOn motive: {motive}"
|
||||
let brecOn :=
|
||||
if useBInductionOn then
|
||||
Lean.mkConst (mkBInductionOnFor recArgInfo.indName) recArgInfo.indLevels
|
||||
else
|
||||
Lean.mkConst (mkBRecOnFor recArgInfo.indName) (brecOnUniv :: recArgInfo.indLevels)
|
||||
let brecOn := mkAppN brecOn recArgInfo.indParams
|
||||
let brecOn := mkApp brecOn motive
|
||||
let brecOn := mkAppN brecOn recArgInfo.indIndices
|
||||
let brecOn := mkApp brecOn major
|
||||
check brecOn
|
||||
let brecOnType ← inferType brecOn
|
||||
trace[Elab.definition.structural]! "brecOn {brecOn}"
|
||||
trace[Elab.definition.structural]! "brecOnType {brecOnType}"
|
||||
forallBoundedTelescope brecOnType (some 1) fun F _ => do
|
||||
let F := F[0]
|
||||
let FType ← inferType F
|
||||
let numIndices := recArgInfo.indIndices.size
|
||||
forallBoundedTelescope FType (some $ numIndices + 1 /- major -/ + 1 /- below -/ + otherArgs.size) fun Fargs _ => do
|
||||
let indicesNew := Fargs.extract 0 numIndices
|
||||
let majorNew := Fargs[numIndices]
|
||||
let below := Fargs[numIndices+1]
|
||||
let otherArgsNew := Fargs.extract (numIndices+2) Fargs.size
|
||||
let valueNew := value.replaceFVars recArgInfo.indIndices indicesNew
|
||||
let valueNew := valueNew.replaceFVar major majorNew
|
||||
let valueNew := valueNew.replaceFVars otherArgs otherArgsNew
|
||||
let valueNew ← replaceRecApps recFnName recArgInfo below valueNew
|
||||
let Farg ← mkLambdaFVars Fargs valueNew
|
||||
let brecOn := mkApp brecOn Farg
|
||||
pure $ mkAppN brecOn otherArgs
|
||||
let type := (← inferType value).headBeta
|
||||
let major := recArgInfo.ys[recArgInfo.pos]
|
||||
let otherArgs := recArgInfo.ys.filter fun y => y != major && !recArgInfo.indIndices.contains y
|
||||
let motive ← mkForallFVars otherArgs type
|
||||
let brecOnUniv ← getLevel motive
|
||||
trace[Elab.definition.structural]! "brecOn univ: {brecOnUniv}"
|
||||
let useBInductionOn := recArgInfo.reflexive && brecOnUniv == levelZero
|
||||
if recArgInfo.reflexive && brecOnUniv != levelZero then
|
||||
brecOnUniv ← decLevel brecOnUniv
|
||||
let motive ← mkLambdaFVars (recArgInfo.indIndices.push major) motive
|
||||
trace[Elab.definition.structural]! "brecOn motive: {motive}"
|
||||
let brecOn :=
|
||||
if useBInductionOn then
|
||||
Lean.mkConst (mkBInductionOnFor recArgInfo.indName) recArgInfo.indLevels
|
||||
else
|
||||
Lean.mkConst (mkBRecOnFor recArgInfo.indName) (brecOnUniv :: recArgInfo.indLevels)
|
||||
let brecOn := mkAppN brecOn recArgInfo.indParams
|
||||
let brecOn := mkApp brecOn motive
|
||||
let brecOn := mkAppN brecOn recArgInfo.indIndices
|
||||
let brecOn := mkApp brecOn major
|
||||
check brecOn
|
||||
let brecOnType ← inferType brecOn
|
||||
trace[Elab.definition.structural]! "brecOn {brecOn}"
|
||||
trace[Elab.definition.structural]! "brecOnType {brecOnType}"
|
||||
forallBoundedTelescope brecOnType (some 1) fun F _ => do
|
||||
let F := F[0]
|
||||
let FType ← inferType F
|
||||
let numIndices := recArgInfo.indIndices.size
|
||||
forallBoundedTelescope FType (some $ numIndices + 1 /- major -/ + 1 /- below -/ + otherArgs.size) fun Fargs _ => do
|
||||
let indicesNew := Fargs.extract 0 numIndices
|
||||
let majorNew := Fargs[numIndices]
|
||||
let below := Fargs[numIndices+1]
|
||||
let otherArgsNew := Fargs.extract (numIndices+2) Fargs.size
|
||||
let valueNew := value.replaceFVars recArgInfo.indIndices indicesNew
|
||||
let valueNew := valueNew.replaceFVar major majorNew
|
||||
let valueNew := valueNew.replaceFVars otherArgs otherArgsNew
|
||||
let valueNew ← replaceRecApps recFnName recArgInfo below valueNew
|
||||
let Farg ← mkLambdaFVars Fargs valueNew
|
||||
let brecOn := mkApp brecOn Farg
|
||||
pure $ mkAppN brecOn otherArgs
|
||||
|
||||
private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition :=
|
||||
withoutModifyingEnv do lambdaTelescope preDef.value fun xs value => do
|
||||
addAsAxiom preDef
|
||||
trace[Elab.definition.structural]! "{preDef.declName} {xs} :=\n{value}"
|
||||
let numFixed := getFixedPrefix preDef.declName xs value
|
||||
findRecArg numFixed xs fun recArgInfo => do
|
||||
-- when (recArgInfo.indName == `Nat) throwStructuralFailed -- HACK to skip Nat argument
|
||||
let valueNew ← mkBRecOn preDef.declName recArgInfo value
|
||||
let valueNew ← mkLambdaFVars xs valueNew
|
||||
trace[Elab.definition.structural]! "result: {valueNew}"
|
||||
-- Recursive applications may still occur in expressions that were not visited by replaceRecApps (e.g., in types)
|
||||
let valueNew ← ensureNoRecFn preDef.declName valueNew
|
||||
pure { preDef with value := valueNew }
|
||||
withoutModifyingEnv do lambdaTelescope preDef.value fun xs value => do
|
||||
addAsAxiom preDef
|
||||
trace[Elab.definition.structural]! "{preDef.declName} {xs} :=\n{value}"
|
||||
let numFixed := getFixedPrefix preDef.declName xs value
|
||||
findRecArg numFixed xs fun recArgInfo => do
|
||||
-- when (recArgInfo.indName == `Nat) throwStructuralFailed -- HACK to skip Nat argument
|
||||
let valueNew ← mkBRecOn preDef.declName recArgInfo value
|
||||
let valueNew ← mkLambdaFVars xs valueNew
|
||||
trace[Elab.definition.structural]! "result: {valueNew}"
|
||||
-- Recursive applications may still occur in expressions that were not visited by replaceRecApps (e.g., in types)
|
||||
let valueNew ← ensureNoRecFn preDef.declName valueNew
|
||||
pure { preDef with value := valueNew }
|
||||
|
||||
def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
|
||||
if preDefs.size != 1 then
|
||||
throwError "structural recursion does not handle mutually recursive functions"
|
||||
else do
|
||||
let preDefNonRec ← elimRecursion preDefs[0]
|
||||
addNonRec preDefNonRec
|
||||
addAndCompileUnsafeRec preDefs
|
||||
if preDefs.size != 1 then
|
||||
throwError "structural recursion does not handle mutually recursive functions"
|
||||
else do
|
||||
let preDefNonRec ← elimRecursion preDefs[0]
|
||||
addNonRec preDefNonRec
|
||||
addAndCompileUnsafeRec preDefs
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.definition.structural
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace Elab
|
|||
open Meta
|
||||
|
||||
def WFRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
|
||||
throwError "well founded recursion has not been implemented yet"
|
||||
throwError "well founded recursion has not been implemented yet"
|
||||
|
||||
end Elab
|
||||
end Lean
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue