feat: generate auxiliary declaration for "smart unfolding"
This commit is contained in:
parent
0ca7dabb2a
commit
461c0786fd
2 changed files with 70 additions and 0 deletions
|
|
@ -349,6 +349,67 @@ private def elimRecursion (preDef : PreDefinition) : MetaM PreDefinition :=
|
|||
let valueNew ← ensureNoRecFn preDef.declName valueNew
|
||||
pure { preDef with value := valueNew }
|
||||
|
||||
/-
|
||||
Return true if `e` contains a matcher with nested recursive applications of `recFnName`.
|
||||
This is auxiliary function used by the smartUnfolding procedure to decide where to insert
|
||||
`idRhs` auxiliary applications.
|
||||
|
||||
TODO: refine this test. It is just an approximation right now.
|
||||
The perfect test should reflect the behavior of replaceRecApps. -/
|
||||
private def containsMatcherWithRecApp (recFnName : Name) (e : Expr) : MetaM Bool := do
|
||||
let env ← getEnv
|
||||
let m? := e.find? fun e =>
|
||||
match e.getAppFn with
|
||||
| Expr.const constName .. =>
|
||||
match Match.Extension.getMatcherInfo? env constName with
|
||||
| some info => containsRecFn recFnName e
|
||||
| none => false
|
||||
| _ => false
|
||||
pure m?.isSome
|
||||
|
||||
partial def addSmartUnfoldingDef (preDef : PreDefinition) : TermElabM Unit := do
|
||||
if (← isProp preDef.type) then
|
||||
return ()
|
||||
else
|
||||
let recFnName := preDef.declName
|
||||
let rec visit (e : Expr) : MetaM Expr := do
|
||||
match e with
|
||||
| Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← visit b)
|
||||
| Expr.forallE .. => forallTelescope e fun xs b => do mkForallFVars xs (← visit b)
|
||||
| Expr.letE n type val body _ =>
|
||||
withLetDecl n type (← visit val) fun x => do
|
||||
mkLetFVars #[x] (← visit (body.instantiate1 x))
|
||||
| Expr.mdata d b _ => return mkMData d (← visit b)
|
||||
| Expr.proj n i s _ => return mkProj n i (← visit s)
|
||||
| Expr.app .. =>
|
||||
let processApp (e : Expr) : MetaM Expr :=
|
||||
e.withApp fun f args => do
|
||||
return mkAppN (← visit f) (← args.mapM visit)
|
||||
let matcherApp? ← matchMatcherApp? e
|
||||
match matcherApp? with
|
||||
| some matcherApp =>
|
||||
let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) =>
|
||||
lambdaTelescope alt fun xs altBody => do
|
||||
unless xs.size >= numParams do
|
||||
throwError! "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
|
||||
if (← containsMatcherWithRecApp recFnName altBody) then
|
||||
-- continue
|
||||
mkLambdaFVars xs (← visit altBody)
|
||||
else
|
||||
-- add idRhs marker
|
||||
let altBody ← mkLambdaFVars xs[numParams:xs.size] altBody
|
||||
let altBody ← mkIdRhs altBody
|
||||
mkLambdaFVars xs[0:numParams] altBody
|
||||
pure { matcherApp with alts := altsNew }.toExpr
|
||||
| none => processApp e
|
||||
| _ => pure e
|
||||
trace[Meta.debug]! "preDef {preDef.value}"
|
||||
addNonRec { preDef with
|
||||
declName := mkSmartUnfoldingNameFor preDef.declName,
|
||||
value := (← visit preDef.value),
|
||||
modifiers := {}
|
||||
}
|
||||
|
||||
def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
|
||||
if preDefs.size != 1 then
|
||||
throwError "structural recursion does not handle mutually recursive functions"
|
||||
|
|
@ -356,6 +417,7 @@ def structuralRecursion (preDefs : Array PreDefinition) : TermElabM Unit :=
|
|||
let preDefNonRec ← elimRecursion preDefs[0]
|
||||
mapError (addNonRec preDefNonRec) (fun msg => m!"structural recursion failed, produced type incorrect term{indentD msg}")
|
||||
addAndCompileUnsafeRec preDefs
|
||||
addSmartUnfoldingDef preDefs[0]
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.definition.structural
|
||||
|
|
|
|||
|
|
@ -20,6 +20,14 @@ private def mkIdImp (e : Expr) : MetaM Expr := do
|
|||
def mkId (e : Expr) : m Expr :=
|
||||
liftMetaM $ mkIdImp e
|
||||
|
||||
def mkIdRhsImp (e : Expr) : MetaM Expr := do
|
||||
let type ← inferType e
|
||||
let u ← getLevel type
|
||||
pure $ mkApp2 (mkConst `idRhs [u]) type e
|
||||
/-- Return `idRhs e` -/
|
||||
def mkIdRhs (e : Expr) : m Expr :=
|
||||
liftMetaM $ mkIdRhsImp e
|
||||
|
||||
private def mkExpectedTypeHintImp (e : Expr) (expectedType : Expr) : MetaM Expr := do
|
||||
let u ← getLevel expectedType
|
||||
pure $ mkApp2 (mkConst `id [u]) expectedType e
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue