refactor: use MetaM and exceptions

This commit is contained in:
Leonardo de Moura 2020-09-22 16:05:28 -07:00
parent 25bcc95b13
commit 41e6447837

View file

@ -11,6 +11,14 @@ namespace Lean
namespace Elab
open Meta
def registerStructualId : IO InternalExceptionId :=
registerInternalExceptionId `structuralRec
@[init registerStructualId]
constant structuralExceptionId : InternalExceptionId := arbitrary _
def throwStructuralFailed {α m} [MonadExceptOf Exception m] : m α :=
throw $ Exception.internal structuralExceptionId
private def getFixedPrefix (declName : Name) (xs : Array Expr) (value : Expr) : Nat :=
let visitor {ω} : StateRefT Nat (ST ω) Unit :=
value.forEach' fun e =>
@ -43,7 +51,7 @@ indices.foldl
xs.size
-- Indices can only depend on other indices
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : TermElabM (Option (Expr × Expr)) :=
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : MetaM (Option (Expr × Expr)) :=
indices.findSomeM? fun index => do
indexType ← inferType index;
ys.findSomeM? fun y =>
@ -53,7 +61,7 @@ indices.findSomeM? fun index => do
(pure none)
-- Inductive datatype parameters cannot depend on ys
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : TermElabM (Option (Expr × Expr)) :=
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : MetaM (Option (Expr × Expr)) :=
indParams.findSomeM? fun p => do
pType ← inferType p;
ys.findSomeM? fun y =>
@ -61,28 +69,29 @@ indParams.findSomeM? fun p => do
(pure (some (p, y)))
(pure none)
private partial def findRecArgAux? {α} (numFixed : Nat) (xs : Array Expr) (k? : RecArgInfo → TermElabM (Option α)) : Nat → TermElabM (Option α)
private partial def findRecArgAux {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : Nat → MetaM α
| i =>
if h : i < xs.size then do
let x := xs.get ⟨i, h⟩;
localDecl ← getFVarLocalDecl x;
if localDecl.isLet then pure none
if localDecl.isLet then
throwStructuralFailed
else do
xType ← whnfD localDecl.type;
matchConstInduct xType.getAppFn (fun _ => findRecArgAux? (i+1)) fun indInfo us => do
condM (not <$> hasConst (mkBRecOnFor indInfo.name)) (findRecArgAux? (i+1)) do
condM (do hasBInductionOn ← hasConst (mkBInductionOnFor indInfo.name); pure $ indInfo.isReflexive && !hasBInductionOn) (findRecArgAux? (i+1)) do
matchConstInduct xType.getAppFn (fun _ => findRecArgAux (i+1)) fun indInfo us => do
condM (not <$> hasConst (mkBRecOnFor indInfo.name)) (findRecArgAux (i+1)) do
condM (do hasBInductionOn ← hasConst (mkBInductionOnFor indInfo.name); pure $ indInfo.isReflexive && !hasBInductionOn) (findRecArgAux (i+1)) do
let indArgs := xType.getAppArgs;
let indParams := indArgs.extract 0 indInfo.nparams;
let indIndices := indArgs.extract indInfo.nparams indArgs.size;
if !indIndices.all Expr.isFVar then do
trace `Elab.definition.structural fun _ =>
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not variables" ++ indentExpr xType;
findRecArgAux? (i+1)
trace! `Elab.definition.structural
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not variables" ++ indentExpr xType);
findRecArgAux (i+1)
else if !indIndices.allDiff then do
trace `Elab.definition.structural fun _ =>
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not pairwise distinct" ++ indentExpr xType;
findRecArgAux? (i+1)
trace! `Elab.definition.structural
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not pairwise distinct" ++ indentExpr xType);
findRecArgAux (i+1)
else do
let indexMinPos := getIndexMinPos xs indIndices;
let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed;
@ -91,43 +100,42 @@ private partial def findRecArgAux? {α} (numFixed : Nat) (xs : Array Expr) (k? :
badDep? ← hasBadIndexDep? ys indIndices;
match badDep? with
| some (index, y) => do
trace `Elab.definition.structural fun _ =>
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family" ++ indentExpr xType ++
Format.line ++ "and index" ++ indentExpr index ++
Format.line ++ "depends on the non index" ++ indentExpr y;
findRecArgAux? (i+1)
trace! `Elab.definition.structural
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family" ++ indentExpr xType ++
Format.line ++ "and index" ++ indentExpr index ++
Format.line ++ "depends on the non index" ++ indentExpr y);
findRecArgAux (i+1)
| none => do
badDep? ← hasBadParamDep? ys indParams;
match badDep? with
| some (indParam, y) => do
trace `Elab.definition.structural fun _ =>
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive datatype" ++ indentExpr xType ++
Format.line ++ "and parameter" ++ indentExpr indParam ++
Format.line ++ "depends on" ++ indentExpr y;
findRecArgAux? (i+1)
trace! `Elab.definition.structural
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive datatype" ++ indentExpr xType ++
Format.line ++ "and parameter" ++ indentExpr indParam ++
Format.line ++ "depends on" ++ indentExpr y);
findRecArgAux (i+1)
| none => do
let indicesPos := indIndices.map fun index => match ys.indexOf index with | some i => i.val | none => unreachable!;
a? ← k? { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
indicesPos := indicesPos,
indName := indInfo.name,
indLevels := us,
indParams := indParams,
indIndices := indIndices,
reflexive := indInfo.isReflexive };
match a? with
| some a => pure a
| none => findRecArgAux? (i+1)
catchInternalId structuralExceptionId
(k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
indicesPos := indicesPos,
indName := indInfo.name,
indLevels := us,
indParams := indParams,
indIndices := indIndices,
reflexive := indInfo.isReflexive })
(fun _ => findRecArgAux (i+1))
else
pure none
throwStructuralFailed
@[inline] private def findRecArg? {α} (numFixed : Nat) (xs : Array Expr) (k? : RecArgInfo → TermElabM (Option α)) : TermElabM (Option α) :=
findRecArgAux? numFixed xs k? numFixed
@[inline] private def findRecArg {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : MetaM α :=
findRecArgAux numFixed xs k numFixed
private def replaceRecApps? (argInfo : RecArgInfo) (below : Expr) (value : Expr) : TermElabM (Option Expr) :=
private def replaceRecApps (argInfo : RecArgInfo) (below : Expr) (value : Expr) : MetaM Expr :=
-- TODO
pure value
private def mkBRecOn? (argInfo : RecArgInfo) (value : Expr) : TermElabM (Option Expr) := do
private def mkBRecOn (argInfo : RecArgInfo) (value : Expr) : MetaM Expr := do
type ← inferType value;
let type := type.headBeta;
let major := argInfo.ys.get! argInfo.pos;
@ -135,15 +143,15 @@ let otherArgs := argInfo.ys.filter fun y => y != major && !argInfo.indIndices.co
motive ← mkForallFVars otherArgs type;
brecOnUniv ← getDecLevel motive;
motive ← mkLambdaFVars (argInfo.indIndices.push major) motive;
trace `Elab.definition.structural fun _ => "brecOn motive: " ++ motive;
trace! `Elab.definition.structural ("brecOn motive: " ++ motive);
let brecOn := Lean.mkConst (mkBRecOnFor argInfo.indName) (brecOnUniv :: argInfo.indLevels);
let brecOn := mkAppN brecOn argInfo.indParams;
let brecOn := mkApp brecOn motive;
let brecOn := mkAppN brecOn argInfo.indIndices;
let brecOn := mkApp brecOn major;
brecOnType ← inferType brecOn;
trace `Elab.definition.structural fun _ => "brecOn " ++ brecOn;
trace `Elab.definition.structural fun _ => "brecOnType " ++ brecOnType;
trace! `Elab.definition.structural ("brecOn " ++ brecOn);
trace! `Elab.definition.structural ("brecOnType " ++ brecOnType);
forallBoundedTelescope brecOnType (some 1) fun F _ => do
let F := F.get! 0;
FType ← inferType F;
@ -154,36 +162,33 @@ forallBoundedTelescope brecOnType (some 1) fun F _ => do
let below := Fargs.get! (numIndices+1);
let valueNew := value.replaceFVars argInfo.indIndices indicesNew;
let valueNew := valueNew.replaceFVar major majorNew;
valueNew? ← replaceRecApps? argInfo below valueNew;
match valueNew? with
| none => pure none
| some valueNew => do
Farg ← mkLambdaFVars Fargs valueNew;
let brecOn := mkApp brecOn Farg;
pure $ mkAppN brecOn otherArgs
valueNew ← replaceRecApps argInfo below valueNew;
Farg ← mkLambdaFVars Fargs valueNew;
let brecOn := mkApp brecOn Farg;
pure $ mkAppN brecOn otherArgs
private def elimRecursion? (preDef : PreDefinition) : TermElabM (Option PreDefinition) :=
private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition :=
lambdaTelescope preDef.value fun xs value => do
trace `Elab.definition.structural fun _ => preDef.declName ++ " " ++ xs ++ " :=\n" ++ value;
trace! `Elab.definition.structural (preDef.declName ++ " " ++ xs ++ " :=\n" ++ value);
let numFixed := getFixedPrefix preDef.declName xs value;
findRecArg? numFixed xs fun argInfo => do
some valueNew ← mkBRecOn? argInfo value | pure none;
findRecArg numFixed xs fun argInfo => do
-- when (argInfo.indName == `Nat) throwStructuralFailed; -- HACK to skip Nat argument
valueNew ← mkBRecOn argInfo value;
valueNew ← mkLambdaFVars xs valueNew;
trace `Elab.definition.structural fun _ => "result: " ++ valueNew;
trace! `Elab.definition.structural ("result: " ++ valueNew);
-- pure $ some { preDef with value := valueNew }
throwError "WIP"
def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Bool :=
if preDefs.size != 1 then
pure false
else do
preDefNonRec? ← elimRecursion? (preDefs.get! 0);
match preDefNonRec? with
| none => pure false
| some preDefNonRec => do
addNonRec preDefNonRec;
addAndCompileUnsafeRec preDefs;
pure true
else
catchInternalId structuralExceptionId
(do preDefNonRec ← liftMetaM $ elimRecursion (preDefs.get! 0);
addNonRec preDefNonRec;
addAndCompileUnsafeRec preDefs;
pure true)
(fun _ => pure false)
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Elab.definition.structural;