feat: add simpCasesOnCtor
This commit is contained in:
parent
255d34d2ac
commit
30c75b4b88
3 changed files with 39 additions and 7 deletions
|
|
@ -66,7 +66,6 @@ inductive Code where
|
|||
| unreach (type : Expr)
|
||||
deriving Inhabited
|
||||
|
||||
|
||||
abbrev Alt := AltCore Code
|
||||
abbrev FunDecl := FunDeclCore Code
|
||||
abbrev Cases := CasesCore Code
|
||||
|
|
@ -230,6 +229,15 @@ to be updated.
|
|||
-/
|
||||
@[implementedBy updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
|
||||
|
||||
def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
||||
let found (i : Nat) := (cases.alts[i]!, { cases with alts := cases.alts.eraseIdx i })
|
||||
if let some i := cases.alts.findIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
found i
|
||||
else if let some i := cases.alts.findIdx? fun | .default _ => true | _ => false then
|
||||
found i
|
||||
else
|
||||
unreachable!
|
||||
|
||||
def Code.isDecl : Code → Bool
|
||||
| .let .. | .fun .. | .jp .. => true
|
||||
| _ => false
|
||||
|
|
|
|||
|
|
@ -51,6 +51,9 @@ def eraseFVar (fvarId : FVarId) (recursive := true) : CompilerM Unit := do
|
|||
def eraseFVarsAt (code : Code) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseFVarsAt code
|
||||
|
||||
def eraseParams (params : Array Param) : CompilerM Unit :=
|
||||
params.forM (eraseFVar ·.fvarId)
|
||||
|
||||
/--
|
||||
A free variable substitution.
|
||||
We use these substitutions when inlining definitions and "internalizing" LCNF code into `CompilerM`.
|
||||
|
|
|
|||
|
|
@ -357,6 +357,24 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
|
|||
let value ← simp decl.value
|
||||
decl.update type params value
|
||||
|
||||
/-- Try to simplify `cases` of `constructor` -/
|
||||
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
||||
let discr ← normFVar cases.discr
|
||||
let discrExpr ← findExpr (.fvar discr)
|
||||
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none
|
||||
let (alt, cases) := cases.extractAlt! ctorVal.name
|
||||
eraseFVarsAt (.cases cases)
|
||||
markSimplified
|
||||
match alt with
|
||||
| .default k => simp k
|
||||
| .alt _ params k =>
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
for param in params, field in fields do
|
||||
addSubst param.fvarId field
|
||||
let k ← simp k
|
||||
eraseParams params
|
||||
return k
|
||||
|
||||
partial def simp (code : Code) : SimpM Code := do
|
||||
incVisited
|
||||
match code with
|
||||
|
|
@ -413,12 +431,15 @@ partial def simp (code : Code) : SimpM Code := do
|
|||
args.forM markUsedExpr
|
||||
return code.updateJmp! fvarId args
|
||||
| .cases c =>
|
||||
-- TODO: cases simplifications
|
||||
let resultType ← normExpr c.resultType
|
||||
let discr ← normFVar c.discr
|
||||
markUsedFVar discr
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
|
||||
return code.updateCases! resultType discr alts
|
||||
if let some k ← simpCasesOnCtor? c then
|
||||
return k
|
||||
else
|
||||
-- TODO: other cases simplifications
|
||||
let discr ← normFVar c.discr
|
||||
let resultType ← normExpr c.resultType
|
||||
markUsedFVar discr
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
|
||||
return code.updateCases! resultType discr alts
|
||||
|
||||
end
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue