feat: add simpCasesOnCtor

This commit is contained in:
Leonardo de Moura 2022-09-01 16:52:23 -07:00
parent 255d34d2ac
commit 30c75b4b88
3 changed files with 39 additions and 7 deletions

View file

@ -66,7 +66,6 @@ inductive Code where
| unreach (type : Expr)
deriving Inhabited
abbrev Alt := AltCore Code
abbrev FunDecl := FunDeclCore Code
abbrev Cases := CasesCore Code
@ -230,6 +229,15 @@ to be updated.
-/
@[implementedBy updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
let found (i : Nat) := (cases.alts[i]!, { cases with alts := cases.alts.eraseIdx i })
if let some i := cases.alts.findIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
found i
else if let some i := cases.alts.findIdx? fun | .default _ => true | _ => false then
found i
else
unreachable!
def Code.isDecl : Code → Bool
| .let .. | .fun .. | .jp .. => true
| _ => false

View file

@ -51,6 +51,9 @@ def eraseFVar (fvarId : FVarId) (recursive := true) : CompilerM Unit := do
def eraseFVarsAt (code : Code) : CompilerM Unit := do
modifyLCtx fun lctx => lctx.eraseFVarsAt code
def eraseParams (params : Array Param) : CompilerM Unit :=
params.forM (eraseFVar ·.fvarId)
/--
A free variable substitution.
We use these substitutions when inlining definitions and "internalizing" LCNF code into `CompilerM`.

View file

@ -357,6 +357,24 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
let value ← simp decl.value
decl.update type params value
/-- Try to simplify `cases` of `constructor` -/
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
let discr ← normFVar cases.discr
let discrExpr ← findExpr (.fvar discr)
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none
let (alt, cases) := cases.extractAlt! ctorVal.name
eraseFVarsAt (.cases cases)
markSimplified
match alt with
| .default k => simp k
| .alt _ params k =>
let fields := ctorArgs[ctorVal.numParams:]
for param in params, field in fields do
addSubst param.fvarId field
let k ← simp k
eraseParams params
return k
partial def simp (code : Code) : SimpM Code := do
incVisited
match code with
@ -413,12 +431,15 @@ partial def simp (code : Code) : SimpM Code := do
args.forM markUsedExpr
return code.updateJmp! fvarId args
| .cases c =>
-- TODO: cases simplifications
let resultType ← normExpr c.resultType
let discr ← normFVar c.discr
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
return code.updateCases! resultType discr alts
if let some k ← simpCasesOnCtor? c then
return k
else
-- TODO: other cases simplifications
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
return code.updateCases! resultType discr alts
end