chore: cleanup

This commit is contained in:
Leonardo de Moura 2020-10-24 06:26:32 -07:00
parent 0ab38742db
commit 525fb7ca91
5 changed files with 448 additions and 447 deletions

View file

@ -19,134 +19,134 @@ A (potentially recursive) definition.
The elaborator converts it into Kernel definitions using many different strategies.
-/
structure PreDefinition :=
(kind : DefKind)
(lparams : List Name)
(modifiers : Modifiers)
(declName : Name)
(type : Expr)
(value : Expr)
(kind : DefKind)
(lparams : List Name)
(modifiers : Modifiers)
(declName : Name)
(type : Expr)
(value : Expr)
instance : Inhabited PreDefinition :=
⟨⟨DefKind.«def», [], {}, arbitrary _, arbitrary _, arbitrary _⟩⟩
⟨⟨DefKind.«def», [], {}, arbitrary _, arbitrary _, arbitrary _⟩⟩
def instantiateMVarsAtPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
pure { preDef with type := (← instantiateMVars preDef.type), value := (← instantiateMVars preDef.value) }
preDefs.mapM fun preDef => do
pure { preDef with type := (← instantiateMVars preDef.type), value := (← instantiateMVars preDef.value) }
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
let nextIdx ← get;
let (e, nextIdx) ← levelMVarToParam e nextIdx;
set nextIdx;
pure e
let nextIdx ← get
let (e, nextIdx) ← levelMVarToParam e nextIdx;
set nextIdx;
pure e
private def levelMVarToParamPreDeclsAux (preDefs : Array PreDefinition) : StateRefT Nat TermElabM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
pure { preDef with type := (← levelMVarToParamExpr preDef.type), value := (← levelMVarToParamExpr preDef.value) }
preDefs.mapM fun preDef => do
pure { preDef with type := (← levelMVarToParamExpr preDef.type), value := (← levelMVarToParamExpr preDef.value) }
def levelMVarToParamPreDecls (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) :=
(levelMVarToParamPreDeclsAux preDefs).run' 1
(levelMVarToParamPreDeclsAux preDefs).run' 1
private def getLevelParamsPreDecls (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) := do
let s : CollectLevelParams.State := {}
for preDef in preDefs do
s := collectLevelParams s preDef.type
s := collectLevelParams s preDef.value
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
| Except.error msg => throwError msg
| Except.ok levelParams => pure levelParams
let s : CollectLevelParams.State := {}
for preDef in preDefs do
s := collectLevelParams s preDef.type
s := collectLevelParams s preDef.value
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
| Except.error msg => throwError msg
| Except.ok levelParams => pure levelParams
private def shareCommon (preDefs : Array PreDefinition) : Array PreDefinition :=
let result : Std.ShareCommonM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
pure { preDef with type := (← Std.withShareCommon preDef.type), value := (← Std.withShareCommon preDef.value) }
result.run
let result : Std.ShareCommonM (Array PreDefinition) :=
preDefs.mapM fun preDef => do
pure { preDef with type := (← Std.withShareCommon preDef.type), value := (← Std.withShareCommon preDef.value) }
result.run
def fixLevelParams (preDefs : Array PreDefinition) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDefinition) := do
let preDefs := shareCommon preDefs
let lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames
let us := lparams.map mkLevelParam
let fixExpr (e : Expr) : Expr :=
e.replace fun c => match c with
| Expr.const declName _ _ => if preDefs.any fun preDef => preDef.declName == declName then some $ Lean.mkConst declName us else none
| _ => none
pure $ preDefs.map fun preDef =>
{ preDef with
type := fixExpr preDef.type,
value := fixExpr preDef.value,
lparams := lparams }
let preDefs := shareCommon preDefs
let lparams ← getLevelParamsPreDecls preDefs scopeLevelNames allUserLevelNames
let us := lparams.map mkLevelParam
let fixExpr (e : Expr) : Expr :=
e.replace fun c => match c with
| Expr.const declName _ _ => if preDefs.any fun preDef => preDef.declName == declName then some $ Lean.mkConst declName us else none
| _ => none
pure $ preDefs.map fun preDef =>
{ preDef with
type := fixExpr preDef.type,
value := fixExpr preDef.value,
lparams := lparams }
def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : AttributeApplicationTime) : TermElabM Unit := do
for preDef in preDefs do
applyAttributesAt preDef.declName preDef.modifiers.attrs applicationTime
for preDef in preDefs do
applyAttributesAt preDef.declName preDef.modifiers.attrs applicationTime
def abstractNestedProofs (preDef : PreDefinition) : MetaM PreDefinition :=
if preDef.kind.isTheorem || preDef.kind.isExample then
pure preDef
else do
let value ← Meta.abstractNestedProofs preDef.declName preDef.value
pure { preDef with value := value }
if preDef.kind.isTheorem || preDef.kind.isExample then
pure preDef
else do
let value ← Meta.abstractNestedProofs preDef.declName preDef.value
pure { preDef with value := value }
/- Auxiliary method for (temporarily) adding pre definition as an axiom -/
def addAsAxiom (preDef : PreDefinition) : MetaM Unit := do
addDecl $ Declaration.axiomDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, isUnsafe := preDef.modifiers.isUnsafe }
addDecl $ Declaration.axiomDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, isUnsafe := preDef.modifiers.isUnsafe }
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) : TermElabM Unit := do
let preDef ← abstractNestedProofs preDef
let env ← getEnv
let decl :=
match preDef.kind with
| DefKind.«example» => unreachable!
| DefKind.«theorem» =>
Declaration.thmDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value }
| DefKind.«opaque» =>
Declaration.opaqueDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
let preDef ← abstractNestedProofs preDef
let env ← getEnv
let decl :=
match preDef.kind with
| DefKind.«example» => unreachable!
| DefKind.«theorem» =>
Declaration.thmDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value }
| DefKind.«opaque» =>
Declaration.opaqueDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
isUnsafe := preDef.modifiers.isUnsafe }
| DefKind.«abbrev» =>
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
hints := ReducibilityHints.«abbrev», isUnsafe := preDef.modifiers.isUnsafe }
| DefKind.«def» =>
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
hints := ReducibilityHints.regular (getMaxHeight env preDef.value + 1),
isUnsafe := preDef.modifiers.isUnsafe }
| DefKind.«abbrev» =>
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
hints := ReducibilityHints.«abbrev», isUnsafe := preDef.modifiers.isUnsafe }
| DefKind.«def» =>
Declaration.defnDecl { name := preDef.declName, lparams := preDef.lparams, type := preDef.type, value := preDef.value,
hints := ReducibilityHints.regular (getMaxHeight env preDef.value + 1),
isUnsafe := preDef.modifiers.isUnsafe }
addDecl decl
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
when (compile && !preDef.kind.isTheorem) $
compileDecl decl
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
pure ()
addDecl decl
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
if compile && !preDef.kind.isTheorem then
compileDecl decl
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
pure ()
def addAndCompileNonRec (preDef : PreDefinition) : TermElabM Unit := do
addNonRecAux preDef true
addNonRecAux preDef true
def addNonRec (preDef : PreDefinition) : TermElabM Unit := do
addNonRecAux preDef false
addNonRecAux preDef false
def addAndCompileUnsafe (preDefs : Array PreDefinition) : TermElabM Unit := do
let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDef => {
name := preDef.declName,
lparams := preDef.lparams,
type := preDef.type,
value := preDef.value,
isUnsafe := true,
hints := ReducibilityHints.opaque
}
addDecl decl
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking
compileDecl decl
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation
pure ()
let decl := Declaration.mutualDefnDecl $ preDefs.toList.map fun preDef => {
name := preDef.declName,
lparams := preDef.lparams,
type := preDef.type,
value := preDef.value,
isUnsafe := true,
hints := ReducibilityHints.opaque
}
addDecl decl
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking
compileDecl decl
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation
pure ()
def addAndCompileUnsafeRec (preDefs : Array PreDefinition) : TermElabM Unit := do
addAndCompileUnsafe $ preDefs.map fun preDef =>
{ preDef with
declName := Compiler.mkUnsafeRecName preDef.declName,
value := preDef.value.replace fun e => match e with
| Expr.const declName us _ =>
if preDefs.any fun preDef => preDef.declName == declName then
some $ mkConst (Compiler.mkUnsafeRecName declName) us
else
none
| _ => none,
modifiers := {} }
addAndCompileUnsafe $ preDefs.map fun preDef =>
{ preDef with
declName := Compiler.mkUnsafeRecName preDef.declName,
value := preDef.value.replace fun e => match e with
| Expr.const declName us _ =>
if preDefs.any fun preDef => preDef.declName == declName then
some $ mkConst (Compiler.mkUnsafeRecName declName) us
else
none
| _ => none,
modifiers := {} }
end Lean.Elab

View file

@ -12,69 +12,70 @@ open Meta
open Term
private def addAndCompilePartial (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do
trace[Elab.definition]! "processing {preDef.declName}"
forallTelescope preDef.type fun xs type => do
let inh ← liftM $ mkInhabitantFor preDef.declName xs type
trace[Elab.definition]! "inhabitant for {preDef.declName}"
addNonRec { preDef with
kind := DefKind.«opaque»,
value := inh }
addAndCompileUnsafeRec preDefs
for preDef in preDefs do
trace[Elab.definition]! "processing {preDef.declName}"
forallTelescope preDef.type fun xs type => do
let inh ← liftM $ mkInhabitantFor preDef.declName xs type
trace[Elab.definition]! "inhabitant for {preDef.declName}"
addNonRec { preDef with
kind := DefKind.«opaque»,
value := inh
}
addAndCompileUnsafeRec preDefs
private def isNonRecursive (preDef : PreDefinition) : Bool :=
Option.isNone $ preDef.value.find? fun
| Expr.const declName _ _ => preDef.declName == declName
| _ => false
Option.isNone $ preDef.value.find? fun
| Expr.const declName _ _ => preDef.declName == declName
| _ => false
private def partitionPreDefs (preDefs : Array PreDefinition) : Array (Array PreDefinition) :=
let getPreDef := fun declName => (preDefs.find? fun preDef => preDef.declName == declName).get!
let vertices := preDefs.toList.map (·.declName)
let successorsOf := fun declName => (getPreDef declName).value.foldConsts [] fun declName successors =>
if preDefs.any fun preDef => preDef.declName == declName then
declName :: successors
else
successors
let sccs := SCC.scc vertices successorsOf
sccs.toArray.map fun scc => scc.toArray.map getPreDef
let getPreDef := fun declName => (preDefs.find? fun preDef => preDef.declName == declName).get!
let vertices := preDefs.toList.map (·.declName)
let successorsOf := fun declName => (getPreDef declName).value.foldConsts [] fun declName successors =>
if preDefs.any fun preDef => preDef.declName == declName then
declName :: successors
else
successors
let sccs := SCC.scc vertices successorsOf
sccs.toArray.map fun scc => scc.toArray.map getPreDef
private def collectMVarsAtPreDef (preDef : PreDefinition) : StateRefT CollectMVars.State MetaM Unit := do
collectMVars preDef.value
collectMVars preDef.type
collectMVars preDef.value
collectMVars preDef.type
private def getMVarsAtPreDef (preDef : PreDefinition) : MetaM (Array MVarId) := do
let (_, s) ← (collectMVarsAtPreDef preDef).run {}
pure s.result
let (_, s) ← (collectMVarsAtPreDef preDef).run {}
pure s.result
private def ensureNoUnassignedMVarsAtPreDef (preDef : PreDefinition) : TermElabM Unit := do
let pendingMVarIds ← liftMetaM $ getMVarsAtPreDef preDef
if ← logUnassignedUsingErrorInfos pendingMVarIds then
throwAbort
let pendingMVarIds ← liftMetaM $ getMVarsAtPreDef preDef
if ← logUnassignedUsingErrorInfos pendingMVarIds then
throwAbort
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do
trace[Elab.definition.body]! "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
for preDef in preDefs do
ensureNoUnassignedMVarsAtPreDef preDef
for preDefs in partitionPreDefs preDefs do
trace[Elab.definition.scc]! "{preDefs.map (·.declName)}"
if preDefs.size == 1 && isNonRecursive preDefs[0] then
let preDef := preDefs[0]
if preDef.modifiers.isNoncomputable then
addNonRec preDef
for preDef in preDefs do
trace[Elab.definition.body]! "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
for preDef in preDefs do
ensureNoUnassignedMVarsAtPreDef preDef
for preDefs in partitionPreDefs preDefs do
trace[Elab.definition.scc]! "{preDefs.map (·.declName)}"
if preDefs.size == 1 && isNonRecursive preDefs[0] then
let preDef := preDefs[0]
if preDef.modifiers.isNoncomputable then
addNonRec preDef
else
addAndCompileNonRec preDef
else if preDefs.any (·.modifiers.isUnsafe) then
addAndCompileUnsafe preDefs
else if preDefs.any (·.modifiers.isPartial) then
addAndCompilePartial preDefs
else
addAndCompileNonRec preDef
else if preDefs.any (·.modifiers.isUnsafe) then
addAndCompileUnsafe preDefs
else if preDefs.any (·.modifiers.isPartial) then
addAndCompilePartial preDefs
else
mapError
(orelseMergeErrors
(structuralRecursion preDefs)
(WFRecursion preDefs))
(fun msg =>
let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName)
msg!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}")
mapError
(orelseMergeErrors
(structuralRecursion preDefs)
(WFRecursion preDefs))
(fun msg =>
let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName)
msg!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}")
end Lean.Elab

View file

@ -9,36 +9,36 @@ namespace Lean.Elab
open Meta
private def mkInhabitant? (type : Expr) : MetaM (Option Expr) := do
try
pure $ some (← mkAppM `arbitrary #[type])
catch _ =>
pure none
try
pure $ some (← mkAppM `arbitrary #[type])
catch _ =>
pure none
private def findAssumption? (xs : Array Expr) (type : Expr) : MetaM (Option Expr) := do
xs.findM? fun x => do isDefEq (← inferType x) type
xs.findM? fun x => do isDefEq (← inferType x) type
private def mkFnInhabitant? (xs : Array Expr) (type : Expr) : MetaM (Option Expr) :=
let rec loop
| 0, type => mkInhabitant? type
| i+1, type => do
let x := xs[i]
let type ← mkForallFVars #[x] type;
match ← mkInhabitant? type with
| none => loop i type
| some val => pure $ some (← mkLambdaFVars xs[0:i] val)
loop xs.size type
let rec loop
| 0, type => mkInhabitant? type
| i+1, type => do
let x := xs[i]
let type ← mkForallFVars #[x] type;
match ← mkInhabitant? type with
| none => loop i type
| some val => pure $ some (← mkLambdaFVars xs[0:i] val)
loop xs.size type
/- TODO: add a global IO.Ref to let users customize/extend this procedure -/
def mkInhabitantFor (declName : Name) (xs : Array Expr) (type : Expr) : MetaM Expr := do
match ← mkInhabitant? type with
| some val => mkLambdaFVars xs val
| none =>
match ← findAssumption? xs type with
| some x => mkLambdaFVars xs x
| none =>
match ← mkFnInhabitant? xs type with
| some val => pure val
| none => throwError! "failed to compile partial definition '{declName}', failed to show that type is inhabited"
match ← mkInhabitant? type with
| some val => mkLambdaFVars xs val
| none =>
match ← findAssumption? xs type with
| some x => mkLambdaFVars xs x
| none =>
match ← mkFnInhabitant? xs type with
| some val => pure val
| none => throwError! "failed to compile partial definition '{declName}', failed to show that type is inhabited"
end Lean.Elab

View file

@ -13,173 +13,173 @@ namespace Lean.Elab
open Meta
private def getFixedPrefix (declName : Name) (xs : Array Expr) (value : Expr) : Nat :=
let visitor {ω} : StateRefT Nat (ST ω) Unit :=
value.forEach' fun e =>
if e.isAppOf declName then do
let args := e.getAppArgs
modify fun numFixed => if args.size < numFixed then args.size else numFixed
-- we continue searching if the e's arguments are not a prefix of `xs`
pure !args.isPrefixOf xs
else
pure true
runST fun _ => do let (_, numFixed) ← visitor.run xs.size; pure numFixed
let visitor {ω} : StateRefT Nat (ST ω) Unit :=
value.forEach' fun e =>
if e.isAppOf declName then do
let args := e.getAppArgs
modify fun numFixed => if args.size < numFixed then args.size else numFixed
-- we continue searching if the e's arguments are not a prefix of `xs`
pure !args.isPrefixOf xs
else
pure true
runST fun _ => do let (_, numFixed) ← visitor.run xs.size; pure numFixed
structure RecArgInfo :=
/- `fixedParams ++ ys` are the arguments of the function we are trying to justify termination using structural recursion. -/
(fixedParams : Array Expr)
(ys : Array Expr) -- recursion arguments
(pos : Nat) -- position in `ys` of the argument we are recursing on
(indicesPos : Array Nat) -- position in `ys` of the inductive datatype indices we are recursing on
(indName : Name) -- inductive datatype name of the argument we are recursing on
(indLevels : List Level) -- inductice datatype universe levels of the argument we are recursing on
(indParams : Array Expr) -- inductive datatype parameters of the argument we are recursing on
(indIndices : Array Expr) -- inductive datatype indices of the argument we are recursing on, it is equal to `indicesPos.map fun i => ys.get! i`
(reflexive : Bool) -- true if we are recursing over a reflexive inductive datatype
/- `fixedParams ++ ys` are the arguments of the function we are trying to justify termination using structural recursion. -/
(fixedParams : Array Expr)
(ys : Array Expr) -- recursion arguments
(pos : Nat) -- position in `ys` of the argument we are recursing on
(indicesPos : Array Nat) -- position in `ys` of the inductive datatype indices we are recursing on
(indName : Name) -- inductive datatype name of the argument we are recursing on
(indLevels : List Level) -- inductice datatype universe levels of the argument we are recursing on
(indParams : Array Expr) -- inductive datatype parameters of the argument we are recursing on
(indIndices : Array Expr) -- inductive datatype indices of the argument we are recursing on, it is equal to `indicesPos.map fun i => ys.get! i`
(reflexive : Bool) -- true if we are recursing over a reflexive inductive datatype
private def getIndexMinPos (xs : Array Expr) (indices : Array Expr) : Nat := do
let minPos := xs.size
for index in indices do
match xs.indexOf? index with
| some pos => if pos.val < minPos then minPos := pos.val
| _ => pure ()
return minPos
let minPos := xs.size
for index in indices do
match xs.indexOf? index with
| some pos => if pos.val < minPos then minPos := pos.val
| _ => pure ()
return minPos
-- Indices can only depend on other indices
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : MetaM (Option (Expr × Expr)) := do
for index in indices do
let indexType ← inferType index
for y in ys do
if !indices.contains y && (← dependsOn indexType y.fvarId!) then
return some (index, y)
return none
for index in indices do
let indexType ← inferType index
for y in ys do
if !indices.contains y && (← dependsOn indexType y.fvarId!) then
return some (index, y)
return none
-- Inductive datatype parameters cannot depend on ys
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : MetaM (Option (Expr × Expr)) := do
for p in indParams do
let pType ← inferType p
for y in ys do
if ← dependsOn pType y.fvarId! then
return some (p, y)
return none
for p in indParams do
let pType ← inferType p
for y in ys do
if ← dependsOn pType y.fvarId! then
return some (p, y)
return none
private def throwStructuralFailed {α} : MetaM α :=
throwError "structural recursion cannot be used"
throwError "structural recursion cannot be used"
private partial def findRecArg {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : MetaM α :=
let rec loop (i : Nat) : MetaM α := do
if h : i < xs.size then
let x := xs.get ⟨i, h⟩
let localDecl ← getFVarLocalDecl x
if localDecl.isLet then
throwStructuralFailed
else
let xType ← whnfD localDecl.type
matchConstInduct xType.getAppFn (fun _ => loop (i+1)) fun indInfo us => do
if !(← hasConst (mkBRecOnFor indInfo.name)) then
loop (i+1)
else if indInfo.isReflexive && !(← hasConst (mkBInductionOnFor indInfo.name)) then
loop (i+1)
let rec loop (i : Nat) : MetaM α := do
if h : i < xs.size then
let x := xs.get ⟨i, h⟩
let localDecl ← getFVarLocalDecl x
if localDecl.isLet then
throwStructuralFailed
else
let indArgs := xType.getAppArgs
let indParams := indArgs.extract 0 indInfo.nparams
let indIndices := indArgs.extract indInfo.nparams indArgs.size
if !indIndices.all Expr.isFVar then
orelseMergeErrors
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not variables{indentExpr xType}")
(loop (i+1))
else if !indIndices.allDiff then
orelseMergeErrors
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not pairwise distinct{indentExpr xType}")
(loop (i+1))
let xType ← whnfD localDecl.type
matchConstInduct xType.getAppFn (fun _ => loop (i+1)) fun indInfo us => do
if !(← hasConst (mkBRecOnFor indInfo.name)) then
loop (i+1)
else if indInfo.isReflexive && !(← hasConst (mkBInductionOnFor indInfo.name)) then
loop (i+1)
else
let indexMinPos := getIndexMinPos xs indIndices
let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed
let fixedParams := xs.extract 0 numFixed
let ys := xs.extract numFixed xs.size
match ← hasBadIndexDep? ys indIndices with
| some (index, y) =>
let indArgs := xType.getAppArgs
let indParams := indArgs.extract 0 indInfo.nparams
let indIndices := indArgs.extract indInfo.nparams indArgs.size
if !indIndices.all Expr.isFVar then
orelseMergeErrors
(throwError! "argument #{i+1} was not used because its type is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}")
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not variables{indentExpr xType}")
(loop (i+1))
| none =>
match ← hasBadParamDep? ys indParams with
| some (indParam, y) =>
else if !indIndices.allDiff then
orelseMergeErrors
(throwError! "argument #{i+1} was not used because its type is an inductive family and indices are not pairwise distinct{indentExpr xType}")
(loop (i+1))
else
let indexMinPos := getIndexMinPos xs indIndices
let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed
let fixedParams := xs.extract 0 numFixed
let ys := xs.extract numFixed xs.size
match ← hasBadIndexDep? ys indIndices with
| some (index, y) =>
orelseMergeErrors
(throwError! "argument #{i+1} was not used because its type is an inductive datatype{indentExpr xType}\nand parameter{indentExpr indParam}\ndepends on{indentExpr y}")
(throwError! "argument #{i+1} was not used because its type is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}")
(loop (i+1))
| none =>
let indicesPos := indIndices.map fun index => match ys.indexOf? index with | some i => i.val | none => unreachable!
orelseMergeErrors
(k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
indicesPos := indicesPos,
indName := indInfo.name,
indLevels := us,
indParams := indParams,
indIndices := indIndices,
reflexive := indInfo.isReflexive })
(loop (i+1))
else
throwStructuralFailed
loop numFixed
match ← hasBadParamDep? ys indParams with
| some (indParam, y) =>
orelseMergeErrors
(throwError! "argument #{i+1} was not used because its type is an inductive datatype{indentExpr xType}\nand parameter{indentExpr indParam}\ndepends on{indentExpr y}")
(loop (i+1))
| none =>
let indicesPos := indIndices.map fun index => match ys.indexOf? index with | some i => i.val | none => unreachable!
orelseMergeErrors
(k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
indicesPos := indicesPos,
indName := indInfo.name,
indLevels := us,
indParams := indParams,
indIndices := indIndices,
reflexive := indInfo.isReflexive })
(loop (i+1))
else
throwStructuralFailed
loop numFixed
private def containsRecFn (recFnName : Name) (e : Expr) : Bool :=
(e.find? fun e => e.isConstOf recFnName).isSome
(e.find? fun e => e.isConstOf recFnName).isSome
private def ensureNoRecFn (recFnName : Name) (e : Expr) : MetaM Expr := do
if containsRecFn recFnName e then
Meta.forEachExpr e fun e => do
if e.isAppOf recFnName then
throwError! "unexpected occurrence of recursive application{indentExpr e}"
pure e
else
pure e
if containsRecFn recFnName e then
Meta.forEachExpr e fun e => do
if e.isAppOf recFnName then
throwError! "unexpected occurrence of recursive application{indentExpr e}"
pure e
else
pure e
private def throwToBelowFailed {α} : MetaM α :=
throwError "toBelow failed"
throwError "toBelow failed"
/- See toBelow -/
private partial def toBelowAux (C : Expr) : Expr → Expr → Expr → MetaM Expr
| belowDict, arg, F => do
belowDict ← whnf belowDict
trace[Elab.definition.structural]! "belowDict: {belowDict}, arg: {arg}"
match belowDict with
| Expr.app (Expr.app (Expr.const `PProd _ _) d1 _) d2 _ =>
(do toBelowAux C d1 arg (← mkAppM `PProd.fst #[F]))
<|>
(do toBelowAux C d2 arg (← mkAppM `PProd.snd #[F]))
| Expr.app (Expr.app (Expr.const `And _ _) d1 _) d2 _ =>
(do toBelowAux C d1 arg (← mkAppM `And.left #[F]))
<|>
(do toBelowAux C d2 arg (← mkAppM `And.right #[F]))
| _ => forallTelescopeReducing belowDict fun xs belowDict => do
let argArgs := arg.getAppArgs
unless argArgs.size >= xs.size do throwToBelowFailed
let n := argArgs.size
let argTailArgs := argArgs.extract (n - xs.size) n
let belowDict := belowDict.replaceFVars xs argTailArgs
| belowDict, arg, F => do
belowDict ← whnf belowDict
trace[Elab.definition.structural]! "belowDict: {belowDict}, arg: {arg}"
match belowDict with
| Expr.app belowDictFun belowDictArg _ =>
unless belowDictFun.getAppFn == C do throwToBelowFailed
unless ← isDefEq belowDictArg arg do throwToBelowFailed
pure (mkAppN F argTailArgs)
| _ => throwToBelowFailed
| Expr.app (Expr.app (Expr.const `PProd _ _) d1 _) d2 _ =>
(do toBelowAux C d1 arg (← mkAppM `PProd.fst #[F]))
<|>
(do toBelowAux C d2 arg (← mkAppM `PProd.snd #[F]))
| Expr.app (Expr.app (Expr.const `And _ _) d1 _) d2 _ =>
(do toBelowAux C d1 arg (← mkAppM `And.left #[F]))
<|>
(do toBelowAux C d2 arg (← mkAppM `And.right #[F]))
| _ => forallTelescopeReducing belowDict fun xs belowDict => do
let argArgs := arg.getAppArgs
unless argArgs.size >= xs.size do throwToBelowFailed
let n := argArgs.size
let argTailArgs := argArgs.extract (n - xs.size) n
let belowDict := belowDict.replaceFVars xs argTailArgs
match belowDict with
| Expr.app belowDictFun belowDictArg _ =>
unless belowDictFun.getAppFn == C do throwToBelowFailed
unless ← isDefEq belowDictArg arg do throwToBelowFailed
pure (mkAppN F argTailArgs)
| _ => throwToBelowFailed
/- See toBelow -/
private def withBelowDict {α} (below : Expr) (numIndParams : Nat) (k : Expr → Expr → MetaM α) : MetaM α := do
let belowType ← inferType below
trace[Elab.definition.structural]! "belowType: {belowType}"
belowType.withApp fun f args => do
let motivePos := numIndParams + 1
unless motivePos < args.size do throwError! "unexpected 'below' type{indentExpr belowType}"
let pre := mkAppN f (args.extract 0 numIndParams)
let preType ← inferType pre
forallBoundedTelescope preType (some 1) fun x _ => do
let motiveType ← inferType x[0]
let C ← mkFreshUserName `C
withLocalDeclD C motiveType fun C =>
let belowDict := mkApp pre C
let belowDict := mkAppN belowDict (args.extract (numIndParams + 1) args.size)
k C belowDict
let belowType ← inferType below
trace[Elab.definition.structural]! "belowType: {belowType}"
belowType.withApp fun f args => do
let motivePos := numIndParams + 1
unless motivePos < args.size do throwError! "unexpected 'below' type{indentExpr belowType}"
let pre := mkAppN f (args.extract 0 numIndParams)
let preType ← inferType pre
forallBoundedTelescope preType (some 1) fun x _ => do
let motiveType ← inferType x[0]
let C ← mkFreshUserName `C
withLocalDeclD C motiveType fun C =>
let belowDict := mkApp pre C
let belowDict := mkAppN belowDict (args.extract (numIndParams + 1) args.size)
k C belowDict
/-
`below` is a free variable with type of the form `I.below indParams motive indices major`,
@ -202,160 +202,160 @@ belowType.withApp fun f args => do
The dictionary is built using the `PProd` (`And` for inductive predicates).
We keep searching it until we find `C recArg`, where `C` is the auxiliary fresh variable created at `withBelowDict`. -/
private partial def toBelow (below : Expr) (numIndParams : Nat) (recArg : Expr) : MetaM Expr := do
withBelowDict below numIndParams fun C belowDict =>
toBelowAux C belowDict recArg below
withBelowDict below numIndParams fun C belowDict =>
toBelowAux C belowDict recArg below
private partial def replaceRecApps (recFnName : Name) (recArgInfo : RecArgInfo) (below : Expr) (e : Expr) : MetaM Expr :=
let rec loop : Expr → Expr → MetaM Expr
| below, e@(Expr.lam n d b c) => do
withLocalDecl n c.binderInfo (← loop below d) fun x => do
mkLambdaFVars #[x] (← loop below (b.instantiate1 x))
| below, e@(Expr.forallE n d b c) => do
withLocalDecl n c.binderInfo (← loop below d) fun x => do
mkForallFVars #[x] (← loop below (b.instantiate1 x))
| below, Expr.letE n type val body _ => do
withLetDecl n (← loop below type) (← loop below val) fun x => do
mkLetFVars #[x] (← loop below (body.instantiate1 x))
| below, Expr.mdata d e _ => do pure $ mkMData d (← loop below e)
| below, Expr.proj n i e _ => do pure $ mkProj n i (← loop below e)
| below, e@(Expr.app _ _ _) => do
let processApp (e : Expr) : MetaM Expr :=
e.withApp fun f args => do
if f.isConstOf recFnName then
let numFixed := recArgInfo.fixedParams.size
let recArgPos := recArgInfo.fixedParams.size + recArgInfo.pos
if recArgPos >= args.size then
throwError! "insufficient number of parameters at recursive application {indentExpr e}"
let recArg := args[recArgPos]
let f ← try toBelow below recArgInfo.indParams.size recArg catch _ => throwError! "failed to eliminate recursive application{indentExpr e}"
-- Recall that the fixed parameters are not in the scope of the `brecOn`. So, we skip them.
let argsNonFixed := args.extract numFixed args.size
-- The function `f` does not explicitly take `recArg` and its indices as arguments. So, we skip them too.
let fArgs := #[]
for i in [:argsNonFixed.size] do
if recArgInfo.pos != i && !recArgInfo.indicesPos.contains i then
let arg := argsNonFixed[i]
let arg ← replaceRecApps recFnName recArgInfo below arg
fArgs := fArgs.push arg
pure $ mkAppN f fArgs
let rec loop : Expr → Expr → MetaM Expr
| below, e@(Expr.lam n d b c) => do
withLocalDecl n c.binderInfo (← loop below d) fun x => do
mkLambdaFVars #[x] (← loop below (b.instantiate1 x))
| below, e@(Expr.forallE n d b c) => do
withLocalDecl n c.binderInfo (← loop below d) fun x => do
mkForallFVars #[x] (← loop below (b.instantiate1 x))
| below, Expr.letE n type val body _ => do
withLetDecl n (← loop below type) (← loop below val) fun x => do
mkLetFVars #[x] (← loop below (body.instantiate1 x))
| below, Expr.mdata d e _ => do pure $ mkMData d (← loop below e)
| below, Expr.proj n i e _ => do pure $ mkProj n i (← loop below e)
| below, e@(Expr.app _ _ _) => do
let processApp (e : Expr) : MetaM Expr :=
e.withApp fun f args => do
if f.isConstOf recFnName then
let numFixed := recArgInfo.fixedParams.size
let recArgPos := recArgInfo.fixedParams.size + recArgInfo.pos
if recArgPos >= args.size then
throwError! "insufficient number of parameters at recursive application {indentExpr e}"
let recArg := args[recArgPos]
let f ← try toBelow below recArgInfo.indParams.size recArg catch _ => throwError! "failed to eliminate recursive application{indentExpr e}"
-- Recall that the fixed parameters are not in the scope of the `brecOn`. So, we skip them.
let argsNonFixed := args.extract numFixed args.size
-- The function `f` does not explicitly take `recArg` and its indices as arguments. So, we skip them too.
let fArgs := #[]
for i in [:argsNonFixed.size] do
if recArgInfo.pos != i && !recArgInfo.indicesPos.contains i then
let arg := argsNonFixed[i]
let arg ← replaceRecApps recFnName recArgInfo below arg
fArgs := fArgs.push arg
pure $ mkAppN f fArgs
else
pure $ mkAppN (← loop below f) (← args.mapM (loop below))
let matcherApp? ← matchMatcherApp? e
match matcherApp? with
| some matcherApp =>
if !containsRecFn recFnName e then
processApp e
else
pure $ mkAppN (← loop below f) (← args.mapM (loop below))
let matcherApp? ← matchMatcherApp? e
match matcherApp? with
| some matcherApp =>
if !containsRecFn recFnName e then
processApp e
else
/- If we first try to process the `match` as a regular application. If it fails, then we try to `push` the below over the dependent `match`.
This is useful for examples such as:
```
def f (xs : List Nat) : Nat :=
match xs with
| [] => 0
| y::ys =>
match ys with
| [] => 1
| zs => f ys + 1
```
We are matching on `ys`, but still using `ys` in the second alternative.
If we push the `below` argument over the dependent match it will be able to eliminate recursive call using `zs`.
This trick is not sufficient for the slightly more complicated example:
```
def g (xs : List Nat) : Nat :=
match xs with
| [] => 0
| y::ys =>
match ys with
| [] => 1
| _::_::zs => g zs + 1
| _ => g ys + 2
```
To make it work, users would have to write the last alternative as
```
| zs => g zs + 2
```
/- If we first try to process the `match` as a regular application. If it fails, then we try to `push` the below over the dependent `match`.
This is useful for examples such as:
```
def f (xs : List Nat) : Nat :=
match xs with
| [] => 0
| y::ys =>
match ys with
| [] => 1
| zs => f ys + 1
```
We are matching on `ys`, but still using `ys` in the second alternative.
If we push the `below` argument over the dependent match it will be able to eliminate recursive call using `zs`.
This trick is not sufficient for the slightly more complicated example:
```
def g (xs : List Nat) : Nat :=
match xs with
| [] => 0
| y::ys =>
match ys with
| [] => 1
| _::_::zs => g zs + 1
| _ => g ys + 2
```
To make it work, users would have to write the last alternative as
```
| zs => g zs + 2
```
If this is too annoying in practice, we may replace `ys` with the matching term.
This may generate weird error messages, when it doesn't work.
-/
processApp e
<|>
do let matcherApp ← mapError (matcherApp.addArg below) (fun msg => "failed to add `below` argument to 'matcher' application" ++ indentD msg)
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
lambdaTelescope alt fun xs altBody => do
trace[Elab.definition.structural]! "altNumParams: {numParams}, xs: {xs}"
unless xs.size >= numParams do
throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let belowForAlt := xs[numParams - 1]
mkLambdaFVars xs (← loop belowForAlt altBody)
pure { matcherApp with alts := altsNew }.toExpr
| none => processApp e
| _, e => ensureNoRecFn recFnName e
loop below e
If this is too annoying in practice, we may replace `ys` with the matching term.
This may generate weird error messages, when it doesn't work.
-/
processApp e
<|>
do let matcherApp ← mapError (matcherApp.addArg below) (fun msg => "failed to add `below` argument to 'matcher' application" ++ indentD msg)
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
lambdaTelescope alt fun xs altBody => do
trace[Elab.definition.structural]! "altNumParams: {numParams}, xs: {xs}"
unless xs.size >= numParams do
throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let belowForAlt := xs[numParams - 1]
mkLambdaFVars xs (← loop belowForAlt altBody)
pure { matcherApp with alts := altsNew }.toExpr
| none => processApp e
| _, e => ensureNoRecFn recFnName e
loop below e
private def mkBRecOn (recFnName : Name) (recArgInfo : RecArgInfo) (value : Expr) : MetaM Expr := do
let type := (← inferType value).headBeta
let major := recArgInfo.ys[recArgInfo.pos]
let otherArgs := recArgInfo.ys.filter fun y => y != major && !recArgInfo.indIndices.contains y
let motive ← mkForallFVars otherArgs type
let brecOnUniv ← getLevel motive
trace[Elab.definition.structural]! "brecOn univ: {brecOnUniv}"
let useBInductionOn := recArgInfo.reflexive && brecOnUniv == levelZero
if recArgInfo.reflexive && brecOnUniv != levelZero then
brecOnUniv ← decLevel brecOnUniv
let motive ← mkLambdaFVars (recArgInfo.indIndices.push major) motive
trace[Elab.definition.structural]! "brecOn motive: {motive}"
let brecOn :=
if useBInductionOn then
Lean.mkConst (mkBInductionOnFor recArgInfo.indName) recArgInfo.indLevels
else
Lean.mkConst (mkBRecOnFor recArgInfo.indName) (brecOnUniv :: recArgInfo.indLevels)
let brecOn := mkAppN brecOn recArgInfo.indParams
let brecOn := mkApp brecOn motive
let brecOn := mkAppN brecOn recArgInfo.indIndices
let brecOn := mkApp brecOn major
check brecOn
let brecOnType ← inferType brecOn
trace[Elab.definition.structural]! "brecOn {brecOn}"
trace[Elab.definition.structural]! "brecOnType {brecOnType}"
forallBoundedTelescope brecOnType (some 1) fun F _ => do
let F := F[0]
let FType ← inferType F
let numIndices := recArgInfo.indIndices.size
forallBoundedTelescope FType (some $ numIndices + 1 /- major -/ + 1 /- below -/ + otherArgs.size) fun Fargs _ => do
let indicesNew := Fargs.extract 0 numIndices
let majorNew := Fargs[numIndices]
let below := Fargs[numIndices+1]
let otherArgsNew := Fargs.extract (numIndices+2) Fargs.size
let valueNew := value.replaceFVars recArgInfo.indIndices indicesNew
let valueNew := valueNew.replaceFVar major majorNew
let valueNew := valueNew.replaceFVars otherArgs otherArgsNew
let valueNew ← replaceRecApps recFnName recArgInfo below valueNew
let Farg ← mkLambdaFVars Fargs valueNew
let brecOn := mkApp brecOn Farg
pure $ mkAppN brecOn otherArgs
let type := (← inferType value).headBeta
let major := recArgInfo.ys[recArgInfo.pos]
let otherArgs := recArgInfo.ys.filter fun y => y != major && !recArgInfo.indIndices.contains y
let motive ← mkForallFVars otherArgs type
let brecOnUniv ← getLevel motive
trace[Elab.definition.structural]! "brecOn univ: {brecOnUniv}"
let useBInductionOn := recArgInfo.reflexive && brecOnUniv == levelZero
if recArgInfo.reflexive && brecOnUniv != levelZero then
brecOnUniv ← decLevel brecOnUniv
let motive ← mkLambdaFVars (recArgInfo.indIndices.push major) motive
trace[Elab.definition.structural]! "brecOn motive: {motive}"
let brecOn :=
if useBInductionOn then
Lean.mkConst (mkBInductionOnFor recArgInfo.indName) recArgInfo.indLevels
else
Lean.mkConst (mkBRecOnFor recArgInfo.indName) (brecOnUniv :: recArgInfo.indLevels)
let brecOn := mkAppN brecOn recArgInfo.indParams
let brecOn := mkApp brecOn motive
let brecOn := mkAppN brecOn recArgInfo.indIndices
let brecOn := mkApp brecOn major
check brecOn
let brecOnType ← inferType brecOn
trace[Elab.definition.structural]! "brecOn {brecOn}"
trace[Elab.definition.structural]! "brecOnType {brecOnType}"
forallBoundedTelescope brecOnType (some 1) fun F _ => do
let F := F[0]
let FType ← inferType F
let numIndices := recArgInfo.indIndices.size
forallBoundedTelescope FType (some $ numIndices + 1 /- major -/ + 1 /- below -/ + otherArgs.size) fun Fargs _ => do
let indicesNew := Fargs.extract 0 numIndices
let majorNew := Fargs[numIndices]
let below := Fargs[numIndices+1]
let otherArgsNew := Fargs.extract (numIndices+2) Fargs.size
let valueNew := value.replaceFVars recArgInfo.indIndices indicesNew
let valueNew := valueNew.replaceFVar major majorNew
let valueNew := valueNew.replaceFVars otherArgs otherArgsNew
let valueNew ← replaceRecApps recFnName recArgInfo below valueNew
let Farg ← mkLambdaFVars Fargs valueNew
let brecOn := mkApp brecOn Farg
pure $ mkAppN brecOn otherArgs
private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition :=
withoutModifyingEnv do lambdaTelescope preDef.value fun xs value => do
addAsAxiom preDef
trace[Elab.definition.structural]! "{preDef.declName} {xs} :=\n{value}"
let numFixed := getFixedPrefix preDef.declName xs value
findRecArg numFixed xs fun recArgInfo => do
-- when (recArgInfo.indName == `Nat) throwStructuralFailed -- HACK to skip Nat argument
let valueNew ← mkBRecOn preDef.declName recArgInfo value
let valueNew ← mkLambdaFVars xs valueNew
trace[Elab.definition.structural]! "result: {valueNew}"
-- Recursive applications may still occur in expressions that were not visited by replaceRecApps (e.g., in types)
let valueNew ← ensureNoRecFn preDef.declName valueNew
pure { preDef with value := valueNew }
withoutModifyingEnv do lambdaTelescope preDef.value fun xs value => do
addAsAxiom preDef
trace[Elab.definition.structural]! "{preDef.declName} {xs} :=\n{value}"
let numFixed := getFixedPrefix preDef.declName xs value
findRecArg numFixed xs fun recArgInfo => do
-- when (recArgInfo.indName == `Nat) throwStructuralFailed -- HACK to skip Nat argument
let valueNew ← mkBRecOn preDef.declName recArgInfo value
let valueNew ← mkLambdaFVars xs valueNew
trace[Elab.definition.structural]! "result: {valueNew}"
-- Recursive applications may still occur in expressions that were not visited by replaceRecApps (e.g., in types)
let valueNew ← ensureNoRecFn preDef.declName valueNew
pure { preDef with value := valueNew }
def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
if preDefs.size != 1 then
throwError "structural recursion does not handle mutually recursive functions"
else do
let preDefNonRec ← elimRecursion preDefs[0]
addNonRec preDefNonRec
addAndCompileUnsafeRec preDefs
if preDefs.size != 1 then
throwError "structural recursion does not handle mutually recursive functions"
else do
let preDefNonRec ← elimRecursion preDefs[0]
addNonRec preDefNonRec
addAndCompileUnsafeRec preDefs
builtin_initialize
registerTraceClass `Elab.definition.structural

View file

@ -10,7 +10,7 @@ namespace Elab
open Meta
def WFRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
throwError "well founded recursion has not been implemented yet"
throwError "well founded recursion has not been implemented yet"
end Elab
end Lean