From 30c75b4b88eb7a8a3ec9cf014b0a5f56ea241aa1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 1 Sep 2022 16:52:23 -0700 Subject: [PATCH] feat: add `simpCasesOnCtor` --- src/Lean/Compiler/LCNF/Basic.lean | 10 +++++++- src/Lean/Compiler/LCNF/CompilerM.lean | 3 +++ src/Lean/Compiler/LCNF/Simp.lean | 33 ++++++++++++++++++++++----- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index c33dcd8a4b..3dc06e8cd4 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -66,7 +66,6 @@ inductive Code where | unreach (type : Expr) deriving Inhabited - abbrev Alt := AltCore Code abbrev FunDecl := FunDeclCore Code abbrev Cases := CasesCore Code @@ -230,6 +229,15 @@ to be updated. -/ @[implementedBy updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl +def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases := + let found (i : Nat) := (cases.alts[i]!, { cases with alts := cases.alts.eraseIdx i }) + if let some i := cases.alts.findIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then + found i + else if let some i := cases.alts.findIdx? fun | .default _ => true | _ => false then + found i + else + unreachable! + def Code.isDecl : Code → Bool | .let .. | .fun .. | .jp .. => true | _ => false diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 739b662b20..df9558f403 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -51,6 +51,9 @@ def eraseFVar (fvarId : FVarId) (recursive := true) : CompilerM Unit := do def eraseFVarsAt (code : Code) : CompilerM Unit := do modifyLCtx fun lctx => lctx.eraseFVarsAt code +def eraseParams (params : Array Param) : CompilerM Unit := + params.forM (eraseFVar ·.fvarId) + /-- A free variable substitution. We use these substitutions when inlining definitions and "internalizing" LCNF code into `CompilerM`. diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 2aa26c61fd..e2ac342e71 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -357,6 +357,24 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do let value ← simp decl.value decl.update type params value +/-- Try to simplify `cases` of `constructor` -/ +partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do + let discr ← normFVar cases.discr + let discrExpr ← findExpr (.fvar discr) + let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none + let (alt, cases) := cases.extractAlt! ctorVal.name + eraseFVarsAt (.cases cases) + markSimplified + match alt with + | .default k => simp k + | .alt _ params k => + let fields := ctorArgs[ctorVal.numParams:] + for param in params, field in fields do + addSubst param.fvarId field + let k ← simp k + eraseParams params + return k + partial def simp (code : Code) : SimpM Code := do incVisited match code with @@ -413,12 +431,15 @@ partial def simp (code : Code) : SimpM Code := do args.forM markUsedExpr return code.updateJmp! fvarId args | .cases c => - -- TODO: cases simplifications - let resultType ← normExpr c.resultType - let discr ← normFVar c.discr - markUsedFVar discr - let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode) - return code.updateCases! resultType discr alts + if let some k ← simpCasesOnCtor? c then + return k + else + -- TODO: other cases simplifications + let discr ← normFVar c.discr + let resultType ← normExpr c.resultType + markUsedFVar discr + let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode) + return code.updateCases! resultType discr alts end