feat: generate auxiliary declaration for "smart unfolding"

This commit is contained in:
Leonardo de Moura 2020-11-15 16:31:40 -08:00
parent 0ca7dabb2a
commit 461c0786fd
2 changed files with 70 additions and 0 deletions

View file

@ -349,6 +349,67 @@ private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition :=
let valueNew ← ensureNoRecFn preDef.declName valueNew
pure { preDef with value := valueNew }
/-
Return true if `e` contains a matcher with nested recursive applications of `recFnName`.
This is auxiliary function used by the smartUnfolding procedure to decide where to insert
`idRhs` auxiliary applications.
TODO: refine this test. It is just an approximation right now.
The perfect test should reflect the behavior of replaceRecApps. -/
private def containsMatcherWithRecApp (recFnName : Name) (e : Expr) : MetaM Bool := do
let env ← getEnv
let m? := e.find? fun e =>
match e.getAppFn with
| Expr.const constName .. =>
match Match.Extension.getMatcherInfo? env constName with
| some info => containsRecFn recFnName e
| none => false
| _ => false
pure m?.isSome
partial def addSmartUnfoldingDef (preDef : PreDefinition) : TermElabM Unit := do
if (← isProp preDef.type) then
return ()
else
let recFnName := preDef.declName
let rec visit (e : Expr) : MetaM Expr := do
match e with
| Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← visit b)
| Expr.forallE .. => forallTelescope e fun xs b => do mkForallFVars xs (← visit b)
| Expr.letE n type val body _ =>
withLetDecl n type (← visit val) fun x => do
mkLetFVars #[x] (← visit (body.instantiate1 x))
| Expr.mdata d b _ => return mkMData d (← visit b)
| Expr.proj n i s _ => return mkProj n i (← visit s)
| Expr.app .. =>
let processApp (e : Expr) : MetaM Expr :=
e.withApp fun f args => do
return mkAppN (← visit f) (← args.mapM visit)
let matcherApp? ← matchMatcherApp? e
match matcherApp? with
| some matcherApp =>
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
lambdaTelescope alt fun xs altBody => do
unless xs.size >= numParams do
throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
if (← containsMatcherWithRecApp recFnName altBody) then
-- continue
mkLambdaFVars xs (← visit altBody)
else
-- add idRhs marker
let altBody ← mkLambdaFVars xs[numParams:xs.size] altBody
let altBody ← mkIdRhs altBody
mkLambdaFVars xs[0:numParams] altBody
pure { matcherApp with alts := altsNew }.toExpr
| none => processApp e
| _ => pure e
trace[Meta.debug]! "preDef {preDef.value}"
addNonRec { preDef with
declName := mkSmartUnfoldingNameFor preDef.declName,
value := (← visit preDef.value),
modifiers := {}
}
def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
if preDefs.size != 1 then
throwError "structural recursion does not handle mutually recursive functions"
@ -356,6 +417,7 @@ def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
let preDefNonRec ← elimRecursion preDefs[0]
mapError (addNonRec preDefNonRec) (fun msg => m!"structural recursion failed, produced type incorrect term{indentD msg}")
addAndCompileUnsafeRec preDefs
addSmartUnfoldingDef preDefs[0]
builtin_initialize
registerTraceClass `Elab.definition.structural

View file

@ -20,6 +20,14 @@ private def mkIdImp (e : Expr) : MetaM Expr := do
def mkId (e : Expr) : m Expr :=
liftMetaM $ mkIdImp e
def mkIdRhsImp (e : Expr) : MetaM Expr := do
let type ← inferType e
let u ← getLevel type
pure $ mkApp2 (mkConst `idRhs [u]) type e
/-- Return `idRhs e` -/
def mkIdRhs (e : Expr) : m Expr :=
liftMetaM $ mkIdRhsImp e
private def mkExpectedTypeHintImp (e : Expr) (expectedType : Expr) : MetaM Expr := do
let u ← getLevel expectedType
pure $ mkApp2 (mkConst `id [u]) expectedType e