From 461c0786fda1949e11c99027dbafc8f4d8972efe Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 15 Nov 2020 16:31:40 -0800 Subject: [PATCH] feat: generate auxiliary declaration for "smart unfolding" --- src/Lean/Elab/PreDefinition/Structural.lean | 62 +++++++++++++++++++++ src/Lean/Meta/AppBuilder.lean | 8 +++ 2 files changed, 70 insertions(+) diff --git a/src/Lean/Elab/PreDefinition/Structural.lean b/src/Lean/Elab/PreDefinition/Structural.lean index 5fb23debb1..3cbc5a9ed9 100644 --- a/src/Lean/Elab/PreDefinition/Structural.lean +++ b/src/Lean/Elab/PreDefinition/Structural.lean @@ -349,6 +349,67 @@ private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition := let valueNew ← ensureNoRecFn preDef.declName valueNew pure { preDef with value := valueNew } +/- + Return true if `e` contains a matcher with nested recursive applications of `recFnName`. + This is auxiliary function used by the smartUnfolding procedure to decide where to insert + `idRhs` auxiliary applications. + + TODO: refine this test. It is just an approximation right now. + The perfect test should reflect the behavior of replaceRecApps. -/ +private def containsMatcherWithRecApp (recFnName : Name) (e : Expr) : MetaM Bool := do + let env ← getEnv + let m? := e.find? fun e => + match e.getAppFn with + | Expr.const constName .. => + match Match.Extension.getMatcherInfo? env constName with + | some info => containsRecFn recFnName e + | none => false + | _ => false + pure m?.isSome + +partial def addSmartUnfoldingDef (preDef : PreDefinition) : TermElabM Unit := do + if (← isProp preDef.type) then + return () + else + let recFnName := preDef.declName + let rec visit (e : Expr) : MetaM Expr := do + match e with + | Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← visit b) + | Expr.forallE .. => forallTelescope e fun xs b => do mkForallFVars xs (← visit b) + | Expr.letE n type val body _ => + withLetDecl n type (← visit val) fun x => do + mkLetFVars #[x] (← visit (body.instantiate1 x)) + | Expr.mdata d b _ => return mkMData d (← visit b) + | Expr.proj n i s _ => return mkProj n i (← visit s) + | Expr.app .. => + let processApp (e : Expr) : MetaM Expr := + e.withApp fun f args => do + return mkAppN (← visit f) (← args.mapM visit) + let matcherApp? ← matchMatcherApp? e + match matcherApp? with + | some matcherApp => + let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) => + lambdaTelescope alt fun xs altBody => do + unless xs.size >= numParams do + throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" + if (← containsMatcherWithRecApp recFnName altBody) then + -- continue + mkLambdaFVars xs (← visit altBody) + else + -- add idRhs marker + let altBody ← mkLambdaFVars xs[numParams:xs.size] altBody + let altBody ← mkIdRhs altBody + mkLambdaFVars xs[0:numParams] altBody + pure { matcherApp with alts := altsNew }.toExpr + | none => processApp e + | _ => pure e + trace[Meta.debug]! "preDef {preDef.value}" + addNonRec { preDef with + declName := mkSmartUnfoldingNameFor preDef.declName, + value := (← visit preDef.value), + modifiers := {} + } + def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit := if preDefs.size != 1 then throwError "structural recursion does not handle mutually recursive functions" @@ -356,6 +417,7 @@ def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit := let preDefNonRec ← elimRecursion preDefs[0] mapError (addNonRec preDefNonRec) (fun msg => m!"structural recursion failed, produced type incorrect term{indentD msg}") addAndCompileUnsafeRec preDefs + addSmartUnfoldingDef preDefs[0] builtin_initialize registerTraceClass `Elab.definition.structural diff --git a/src/Lean/Meta/AppBuilder.lean b/src/Lean/Meta/AppBuilder.lean index cd766b2e5f..85d23c70a1 100644 --- a/src/Lean/Meta/AppBuilder.lean +++ b/src/Lean/Meta/AppBuilder.lean @@ -20,6 +20,14 @@ private def mkIdImp (e : Expr) : MetaM Expr := do def mkId (e : Expr) : m Expr := liftMetaM $ mkIdImp e +def mkIdRhsImp (e : Expr) : MetaM Expr := do + let type ← inferType e + let u ← getLevel type + pure $ mkApp2 (mkConst `idRhs [u]) type e +/-- Return `idRhs e` -/ +def mkIdRhs (e : Expr) : m Expr := + liftMetaM $ mkIdRhsImp e + private def mkExpectedTypeHintImp (e : Expr) (expectedType : Expr) : MetaM Expr := do let u ← getLevel expectedType pure $ mkApp2 (mkConst `id [u]) expectedType e