refactor: use MetaM and exceptions
This commit is contained in:
parent
25bcc95b13
commit
41e6447837
1 changed files with 66 additions and 61 deletions
|
|
@ -11,6 +11,14 @@ namespace Lean
|
|||
namespace Elab
|
||||
open Meta
|
||||
|
||||
def registerStructualId : IO InternalExceptionId :=
|
||||
registerInternalExceptionId `structuralRec
|
||||
@[init registerStructualId]
|
||||
constant structuralExceptionId : InternalExceptionId := arbitrary _
|
||||
|
||||
def throwStructuralFailed {α m} [MonadExceptOf Exception m] : m α :=
|
||||
throw $ Exception.internal structuralExceptionId
|
||||
|
||||
private def getFixedPrefix (declName : Name) (xs : Array Expr) (value : Expr) : Nat :=
|
||||
let visitor {ω} : StateRefT Nat (ST ω) Unit :=
|
||||
value.forEach' fun e =>
|
||||
|
|
@ -43,7 +51,7 @@ indices.foldl
|
|||
xs.size
|
||||
|
||||
-- Indices can only depend on other indices
|
||||
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : TermElabM (Option (Expr × Expr)) :=
|
||||
private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : MetaM (Option (Expr × Expr)) :=
|
||||
indices.findSomeM? fun index => do
|
||||
indexType ← inferType index;
|
||||
ys.findSomeM? fun y =>
|
||||
|
|
@ -53,7 +61,7 @@ indices.findSomeM? fun index => do
|
|||
(pure none)
|
||||
|
||||
-- Inductive datatype parameters cannot depend on ys
|
||||
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : TermElabM (Option (Expr × Expr)) :=
|
||||
private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : MetaM (Option (Expr × Expr)) :=
|
||||
indParams.findSomeM? fun p => do
|
||||
pType ← inferType p;
|
||||
ys.findSomeM? fun y =>
|
||||
|
|
@ -61,28 +69,29 @@ indParams.findSomeM? fun p => do
|
|||
(pure (some (p, y)))
|
||||
(pure none)
|
||||
|
||||
private partial def findRecArgAux? {α} (numFixed : Nat) (xs : Array Expr) (k? : RecArgInfo → TermElabM (Option α)) : Nat → TermElabM (Option α)
|
||||
private partial def findRecArgAux {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : Nat → MetaM α
|
||||
| i =>
|
||||
if h : i < xs.size then do
|
||||
let x := xs.get ⟨i, h⟩;
|
||||
localDecl ← getFVarLocalDecl x;
|
||||
if localDecl.isLet then pure none
|
||||
if localDecl.isLet then
|
||||
throwStructuralFailed
|
||||
else do
|
||||
xType ← whnfD localDecl.type;
|
||||
matchConstInduct xType.getAppFn (fun _ => findRecArgAux? (i+1)) fun indInfo us => do
|
||||
condM (not <$> hasConst (mkBRecOnFor indInfo.name)) (findRecArgAux? (i+1)) do
|
||||
condM (do hasBInductionOn ← hasConst (mkBInductionOnFor indInfo.name); pure $ indInfo.isReflexive && !hasBInductionOn) (findRecArgAux? (i+1)) do
|
||||
matchConstInduct xType.getAppFn (fun _ => findRecArgAux (i+1)) fun indInfo us => do
|
||||
condM (not <$> hasConst (mkBRecOnFor indInfo.name)) (findRecArgAux (i+1)) do
|
||||
condM (do hasBInductionOn ← hasConst (mkBInductionOnFor indInfo.name); pure $ indInfo.isReflexive && !hasBInductionOn) (findRecArgAux (i+1)) do
|
||||
let indArgs := xType.getAppArgs;
|
||||
let indParams := indArgs.extract 0 indInfo.nparams;
|
||||
let indIndices := indArgs.extract indInfo.nparams indArgs.size;
|
||||
if !indIndices.all Expr.isFVar then do
|
||||
trace `Elab.definition.structural fun _ =>
|
||||
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not variables" ++ indentExpr xType;
|
||||
findRecArgAux? (i+1)
|
||||
trace! `Elab.definition.structural
|
||||
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not variables" ++ indentExpr xType);
|
||||
findRecArgAux (i+1)
|
||||
else if !indIndices.allDiff then do
|
||||
trace `Elab.definition.structural fun _ =>
|
||||
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not pairwise distinct" ++ indentExpr xType;
|
||||
findRecArgAux? (i+1)
|
||||
trace! `Elab.definition.structural
|
||||
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not pairwise distinct" ++ indentExpr xType);
|
||||
findRecArgAux (i+1)
|
||||
else do
|
||||
let indexMinPos := getIndexMinPos xs indIndices;
|
||||
let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed;
|
||||
|
|
@ -91,43 +100,42 @@ private partial def findRecArgAux? {α} (numFixed : Nat) (xs : Array Expr) (k? :
|
|||
badDep? ← hasBadIndexDep? ys indIndices;
|
||||
match badDep? with
|
||||
| some (index, y) => do
|
||||
trace `Elab.definition.structural fun _ =>
|
||||
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family" ++ indentExpr xType ++
|
||||
Format.line ++ "and index" ++ indentExpr index ++
|
||||
Format.line ++ "depends on the non index" ++ indentExpr y;
|
||||
findRecArgAux? (i+1)
|
||||
trace! `Elab.definition.structural
|
||||
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family" ++ indentExpr xType ++
|
||||
Format.line ++ "and index" ++ indentExpr index ++
|
||||
Format.line ++ "depends on the non index" ++ indentExpr y);
|
||||
findRecArgAux (i+1)
|
||||
| none => do
|
||||
badDep? ← hasBadParamDep? ys indParams;
|
||||
match badDep? with
|
||||
| some (indParam, y) => do
|
||||
trace `Elab.definition.structural fun _ =>
|
||||
"argument #" ++ toString (i+1) ++ " was not used because its type is an inductive datatype" ++ indentExpr xType ++
|
||||
Format.line ++ "and parameter" ++ indentExpr indParam ++
|
||||
Format.line ++ "depends on" ++ indentExpr y;
|
||||
findRecArgAux? (i+1)
|
||||
trace! `Elab.definition.structural
|
||||
("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive datatype" ++ indentExpr xType ++
|
||||
Format.line ++ "and parameter" ++ indentExpr indParam ++
|
||||
Format.line ++ "depends on" ++ indentExpr y);
|
||||
findRecArgAux (i+1)
|
||||
| none => do
|
||||
let indicesPos := indIndices.map fun index => match ys.indexOf index with | some i => i.val | none => unreachable!;
|
||||
a? ← k? { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
|
||||
indicesPos := indicesPos,
|
||||
indName := indInfo.name,
|
||||
indLevels := us,
|
||||
indParams := indParams,
|
||||
indIndices := indIndices,
|
||||
reflexive := indInfo.isReflexive };
|
||||
match a? with
|
||||
| some a => pure a
|
||||
| none => findRecArgAux? (i+1)
|
||||
catchInternalId structuralExceptionId
|
||||
(k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size,
|
||||
indicesPos := indicesPos,
|
||||
indName := indInfo.name,
|
||||
indLevels := us,
|
||||
indParams := indParams,
|
||||
indIndices := indIndices,
|
||||
reflexive := indInfo.isReflexive })
|
||||
(fun _ => findRecArgAux (i+1))
|
||||
else
|
||||
pure none
|
||||
throwStructuralFailed
|
||||
|
||||
@[inline] private def findRecArg? {α} (numFixed : Nat) (xs : Array Expr) (k? : RecArgInfo → TermElabM (Option α)) : TermElabM (Option α) :=
|
||||
findRecArgAux? numFixed xs k? numFixed
|
||||
@[inline] private def findRecArg {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : MetaM α :=
|
||||
findRecArgAux numFixed xs k numFixed
|
||||
|
||||
private def replaceRecApps? (argInfo : RecArgInfo) (below : Expr) (value : Expr) : TermElabM (Option Expr) :=
|
||||
private def replaceRecApps (argInfo : RecArgInfo) (below : Expr) (value : Expr) : MetaM Expr :=
|
||||
-- TODO
|
||||
pure value
|
||||
|
||||
private def mkBRecOn? (argInfo : RecArgInfo) (value : Expr) : TermElabM (Option Expr) := do
|
||||
private def mkBRecOn (argInfo : RecArgInfo) (value : Expr) : MetaM Expr := do
|
||||
type ← inferType value;
|
||||
let type := type.headBeta;
|
||||
let major := argInfo.ys.get! argInfo.pos;
|
||||
|
|
@ -135,15 +143,15 @@ let otherArgs := argInfo.ys.filter fun y => y != major && !argInfo.indIndices.co
|
|||
motive ← mkForallFVars otherArgs type;
|
||||
brecOnUniv ← getDecLevel motive;
|
||||
motive ← mkLambdaFVars (argInfo.indIndices.push major) motive;
|
||||
trace `Elab.definition.structural fun _ => "brecOn motive: " ++ motive;
|
||||
trace! `Elab.definition.structural ("brecOn motive: " ++ motive);
|
||||
let brecOn := Lean.mkConst (mkBRecOnFor argInfo.indName) (brecOnUniv :: argInfo.indLevels);
|
||||
let brecOn := mkAppN brecOn argInfo.indParams;
|
||||
let brecOn := mkApp brecOn motive;
|
||||
let brecOn := mkAppN brecOn argInfo.indIndices;
|
||||
let brecOn := mkApp brecOn major;
|
||||
brecOnType ← inferType brecOn;
|
||||
trace `Elab.definition.structural fun _ => "brecOn " ++ brecOn;
|
||||
trace `Elab.definition.structural fun _ => "brecOnType " ++ brecOnType;
|
||||
trace! `Elab.definition.structural ("brecOn " ++ brecOn);
|
||||
trace! `Elab.definition.structural ("brecOnType " ++ brecOnType);
|
||||
forallBoundedTelescope brecOnType (some 1) fun F _ => do
|
||||
let F := F.get! 0;
|
||||
FType ← inferType F;
|
||||
|
|
@ -154,36 +162,33 @@ forallBoundedTelescope brecOnType (some 1) fun F _ => do
|
|||
let below := Fargs.get! (numIndices+1);
|
||||
let valueNew := value.replaceFVars argInfo.indIndices indicesNew;
|
||||
let valueNew := valueNew.replaceFVar major majorNew;
|
||||
valueNew? ← replaceRecApps? argInfo below valueNew;
|
||||
match valueNew? with
|
||||
| none => pure none
|
||||
| some valueNew => do
|
||||
Farg ← mkLambdaFVars Fargs valueNew;
|
||||
let brecOn := mkApp brecOn Farg;
|
||||
pure $ mkAppN brecOn otherArgs
|
||||
valueNew ← replaceRecApps argInfo below valueNew;
|
||||
Farg ← mkLambdaFVars Fargs valueNew;
|
||||
let brecOn := mkApp brecOn Farg;
|
||||
pure $ mkAppN brecOn otherArgs
|
||||
|
||||
private def elimRecursion? (preDef : PreDefinition) : TermElabM (Option PreDefinition) :=
|
||||
private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition :=
|
||||
lambdaTelescope preDef.value fun xs value => do
|
||||
trace `Elab.definition.structural fun _ => preDef.declName ++ " " ++ xs ++ " :=\n" ++ value;
|
||||
trace! `Elab.definition.structural (preDef.declName ++ " " ++ xs ++ " :=\n" ++ value);
|
||||
let numFixed := getFixedPrefix preDef.declName xs value;
|
||||
findRecArg? numFixed xs fun argInfo => do
|
||||
some valueNew ← mkBRecOn? argInfo value | pure none;
|
||||
findRecArg numFixed xs fun argInfo => do
|
||||
-- when (argInfo.indName == `Nat) throwStructuralFailed; -- HACK to skip Nat argument
|
||||
valueNew ← mkBRecOn argInfo value;
|
||||
valueNew ← mkLambdaFVars xs valueNew;
|
||||
trace `Elab.definition.structural fun _ => "result: " ++ valueNew;
|
||||
trace! `Elab.definition.structural ("result: " ++ valueNew);
|
||||
-- pure $ some { preDef with value := valueNew }
|
||||
throwError "WIP"
|
||||
|
||||
def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Bool :=
|
||||
if preDefs.size != 1 then
|
||||
pure false
|
||||
else do
|
||||
preDefNonRec? ← elimRecursion? (preDefs.get! 0);
|
||||
match preDefNonRec? with
|
||||
| none => pure false
|
||||
| some preDefNonRec => do
|
||||
addNonRec preDefNonRec;
|
||||
addAndCompileUnsafeRec preDefs;
|
||||
pure true
|
||||
else
|
||||
catchInternalId structuralExceptionId
|
||||
(do preDefNonRec ← liftMetaM $ elimRecursion (preDefs.get! 0);
|
||||
addNonRec preDefNonRec;
|
||||
addAndCompileUnsafeRec preDefs;
|
||||
pure true)
|
||||
(fun _ => pure false)
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Elab.definition.structural;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue