diff --git a/src/Lean/Elab/PreDefinition/Structural.lean b/src/Lean/Elab/PreDefinition/Structural.lean index ca61d3037e..3071cafd7a 100644 --- a/src/Lean/Elab/PreDefinition/Structural.lean +++ b/src/Lean/Elab/PreDefinition/Structural.lean @@ -11,6 +11,14 @@ namespace Lean namespace Elab open Meta +def registerStructualId : IO InternalExceptionId := +registerInternalExceptionId `structuralRec +@[init registerStructualId] +constant structuralExceptionId : InternalExceptionId := arbitrary _ + +def throwStructuralFailed {α m} [MonadExceptOf Exception m] : m α := +throw $ Exception.internal structuralExceptionId + private def getFixedPrefix (declName : Name) (xs : Array Expr) (value : Expr) : Nat := let visitor {ω} : StateRefT Nat (ST ω) Unit := value.forEach' fun e => @@ -43,7 +51,7 @@ indices.foldl xs.size -- Indices can only depend on other indices -private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : TermElabM (Option (Expr × Expr)) := +private def hasBadIndexDep? (ys : Array Expr) (indices : Array Expr) : MetaM (Option (Expr × Expr)) := indices.findSomeM? fun index => do indexType ← inferType index; ys.findSomeM? fun y => @@ -53,7 +61,7 @@ indices.findSomeM? fun index => do (pure none) -- Inductive datatype parameters cannot depend on ys -private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : TermElabM (Option (Expr × Expr)) := +private def hasBadParamDep? (ys : Array Expr) (indParams : Array Expr) : MetaM (Option (Expr × Expr)) := indParams.findSomeM? fun p => do pType ← inferType p; ys.findSomeM? fun y => @@ -61,28 +69,29 @@ indParams.findSomeM? fun p => do (pure (some (p, y))) (pure none) -private partial def findRecArgAux? {α} (numFixed : Nat) (xs : Array Expr) (k? : RecArgInfo → TermElabM (Option α)) : Nat → TermElabM (Option α) +private partial def findRecArgAux {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : Nat → MetaM α | i => if h : i < xs.size then do let x := xs.get ⟨i, h⟩; localDecl ← getFVarLocalDecl x; - if localDecl.isLet then pure none + if localDecl.isLet then + throwStructuralFailed else do xType ← whnfD localDecl.type; - matchConstInduct xType.getAppFn (fun _ => findRecArgAux? (i+1)) fun indInfo us => do - condM (not <$> hasConst (mkBRecOnFor indInfo.name)) (findRecArgAux? (i+1)) do - condM (do hasBInductionOn ← hasConst (mkBInductionOnFor indInfo.name); pure $ indInfo.isReflexive && !hasBInductionOn) (findRecArgAux? (i+1)) do + matchConstInduct xType.getAppFn (fun _ => findRecArgAux (i+1)) fun indInfo us => do + condM (not <$> hasConst (mkBRecOnFor indInfo.name)) (findRecArgAux (i+1)) do + condM (do hasBInductionOn ← hasConst (mkBInductionOnFor indInfo.name); pure $ indInfo.isReflexive && !hasBInductionOn) (findRecArgAux (i+1)) do let indArgs := xType.getAppArgs; let indParams := indArgs.extract 0 indInfo.nparams; let indIndices := indArgs.extract indInfo.nparams indArgs.size; if !indIndices.all Expr.isFVar then do - trace `Elab.definition.structural fun _ => - "argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not variables" ++ indentExpr xType; - findRecArgAux? (i+1) + trace! `Elab.definition.structural + ("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not variables" ++ indentExpr xType); + findRecArgAux (i+1) else if !indIndices.allDiff then do - trace `Elab.definition.structural fun _ => - "argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not pairwise distinct" ++ indentExpr xType; - findRecArgAux? (i+1) + trace! `Elab.definition.structural + ("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family and indices are not pairwise distinct" ++ indentExpr xType); + findRecArgAux (i+1) else do let indexMinPos := getIndexMinPos xs indIndices; let numFixed := if indexMinPos < numFixed then indexMinPos else numFixed; @@ -91,43 +100,42 @@ private partial def findRecArgAux? {α} (numFixed : Nat) (xs : Array Expr) (k? : badDep? ← hasBadIndexDep? ys indIndices; match badDep? with | some (index, y) => do - trace `Elab.definition.structural fun _ => - "argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family" ++ indentExpr xType ++ - Format.line ++ "and index" ++ indentExpr index ++ - Format.line ++ "depends on the non index" ++ indentExpr y; - findRecArgAux? (i+1) + trace! `Elab.definition.structural + ("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive family" ++ indentExpr xType ++ + Format.line ++ "and index" ++ indentExpr index ++ + Format.line ++ "depends on the non index" ++ indentExpr y); + findRecArgAux (i+1) | none => do badDep? ← hasBadParamDep? ys indParams; match badDep? with | some (indParam, y) => do - trace `Elab.definition.structural fun _ => - "argument #" ++ toString (i+1) ++ " was not used because its type is an inductive datatype" ++ indentExpr xType ++ - Format.line ++ "and parameter" ++ indentExpr indParam ++ - Format.line ++ "depends on" ++ indentExpr y; - findRecArgAux? (i+1) + trace! `Elab.definition.structural + ("argument #" ++ toString (i+1) ++ " was not used because its type is an inductive datatype" ++ indentExpr xType ++ + Format.line ++ "and parameter" ++ indentExpr indParam ++ + Format.line ++ "depends on" ++ indentExpr y); + findRecArgAux (i+1) | none => do let indicesPos := indIndices.map fun index => match ys.indexOf index with | some i => i.val | none => unreachable!; - a? ← k? { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size, - indicesPos := indicesPos, - indName := indInfo.name, - indLevels := us, - indParams := indParams, - indIndices := indIndices, - reflexive := indInfo.isReflexive }; - match a? with - | some a => pure a - | none => findRecArgAux? (i+1) + catchInternalId structuralExceptionId + (k { fixedParams := fixedParams, ys := ys, pos := i - fixedParams.size, + indicesPos := indicesPos, + indName := indInfo.name, + indLevels := us, + indParams := indParams, + indIndices := indIndices, + reflexive := indInfo.isReflexive }) + (fun _ => findRecArgAux (i+1)) else - pure none + throwStructuralFailed -@[inline] private def findRecArg? {α} (numFixed : Nat) (xs : Array Expr) (k? : RecArgInfo → TermElabM (Option α)) : TermElabM (Option α) := -findRecArgAux? numFixed xs k? numFixed +@[inline] private def findRecArg {α} (numFixed : Nat) (xs : Array Expr) (k : RecArgInfo → MetaM α) : MetaM α := +findRecArgAux numFixed xs k numFixed -private def replaceRecApps? (argInfo : RecArgInfo) (below : Expr) (value : Expr) : TermElabM (Option Expr) := +private def replaceRecApps (argInfo : RecArgInfo) (below : Expr) (value : Expr) : MetaM Expr := -- TODO pure value -private def mkBRecOn? (argInfo : RecArgInfo) (value : Expr) : TermElabM (Option Expr) := do +private def mkBRecOn (argInfo : RecArgInfo) (value : Expr) : MetaM Expr := do type ← inferType value; let type := type.headBeta; let major := argInfo.ys.get! argInfo.pos; @@ -135,15 +143,15 @@ let otherArgs := argInfo.ys.filter fun y => y != major && !argInfo.indIndices.co motive ← mkForallFVars otherArgs type; brecOnUniv ← getDecLevel motive; motive ← mkLambdaFVars (argInfo.indIndices.push major) motive; -trace `Elab.definition.structural fun _ => "brecOn motive: " ++ motive; +trace! `Elab.definition.structural ("brecOn motive: " ++ motive); let brecOn := Lean.mkConst (mkBRecOnFor argInfo.indName) (brecOnUniv :: argInfo.indLevels); let brecOn := mkAppN brecOn argInfo.indParams; let brecOn := mkApp brecOn motive; let brecOn := mkAppN brecOn argInfo.indIndices; let brecOn := mkApp brecOn major; brecOnType ← inferType brecOn; -trace `Elab.definition.structural fun _ => "brecOn " ++ brecOn; -trace `Elab.definition.structural fun _ => "brecOnType " ++ brecOnType; +trace! `Elab.definition.structural ("brecOn " ++ brecOn); +trace! `Elab.definition.structural ("brecOnType " ++ brecOnType); forallBoundedTelescope brecOnType (some 1) fun F _ => do let F := F.get! 0; FType ← inferType F; @@ -154,36 +162,33 @@ forallBoundedTelescope brecOnType (some 1) fun F _ => do let below := Fargs.get! (numIndices+1); let valueNew := value.replaceFVars argInfo.indIndices indicesNew; let valueNew := valueNew.replaceFVar major majorNew; - valueNew? ← replaceRecApps? argInfo below valueNew; - match valueNew? with - | none => pure none - | some valueNew => do - Farg ← mkLambdaFVars Fargs valueNew; - let brecOn := mkApp brecOn Farg; - pure $ mkAppN brecOn otherArgs + valueNew ← replaceRecApps argInfo below valueNew; + Farg ← mkLambdaFVars Fargs valueNew; + let brecOn := mkApp brecOn Farg; + pure $ mkAppN brecOn otherArgs -private def elimRecursion? (preDef : PreDefinition) : TermElabM (Option PreDefinition) := +private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition := lambdaTelescope preDef.value fun xs value => do - trace `Elab.definition.structural fun _ => preDef.declName ++ " " ++ xs ++ " :=\n" ++ value; + trace! `Elab.definition.structural (preDef.declName ++ " " ++ xs ++ " :=\n" ++ value); let numFixed := getFixedPrefix preDef.declName xs value; - findRecArg? numFixed xs fun argInfo => do - some valueNew ← mkBRecOn? argInfo value | pure none; + findRecArg numFixed xs fun argInfo => do + -- when (argInfo.indName == `Nat) throwStructuralFailed; -- HACK to skip Nat argument + valueNew ← mkBRecOn argInfo value; valueNew ← mkLambdaFVars xs valueNew; - trace `Elab.definition.structural fun _ => "result: " ++ valueNew; + trace! `Elab.definition.structural ("result: " ++ valueNew); -- pure $ some { preDef with value := valueNew } throwError "WIP" def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Bool := if preDefs.size != 1 then pure false -else do - preDefNonRec? ← elimRecursion? (preDefs.get! 0); - match preDefNonRec? with - | none => pure false - | some preDefNonRec => do - addNonRec preDefNonRec; - addAndCompileUnsafeRec preDefs; - pure true +else + catchInternalId structuralExceptionId + (do preDefNonRec ← liftMetaM $ elimRecursion (preDefs.get! 0); + addNonRec preDefNonRec; + addAndCompileUnsafeRec preDefs; + pure true) + (fun _ => pure false) @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Elab.definition.structural;