From 11c8253f6ca737eebe1d73c76e7fb4b66dfbe6f0 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 27 Aug 2022 08:48:52 -0700 Subject: [PATCH] feat: more update functions for LCNF --- src/Lean/Compiler/LCNF/Basic.lean | 20 +++++++++++++++++ src/Lean/Compiler/LCNF/Bind.lean | 2 +- src/Lean/Compiler/LCNF/CSE.lean | 2 +- src/Lean/Compiler/LCNF/CompilerM.lean | 28 +++++++++++++++++++----- src/Lean/Compiler/LCNF/PullLetDecls.lean | 6 +++-- 5 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index fbc9e770ec..57610adc46 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -73,6 +73,26 @@ def AltCore.getCode : Alt → Code | .default k => k | .alt _ _ k => k +private unsafe def updateAltCodeImp (alt : Alt) (c : Code) : Alt := + match alt with + | .default k => if ptrEq k c then alt else AltCore.default c + | .alt ctorName ps k => if ptrEq k c then alt else AltCore.alt ctorName ps c + +@[implementedBy updateAltCodeImp] opaque AltCore.updateCode (alt : Alt) (c : Code) : Alt + +@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code := + match c with + | .let decl k => if ptrEq k k' then c else .let decl k' + | .fun decl k => if ptrEq k k' then c else .fun decl k' + | .jp decl k => if ptrEq k k' then c else .jp decl k' + | _ => unreachable! + +@[implementedBy updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code + +def Code.isDecl : Code → Bool + | .let .. | .fun .. | .jp .. => true + | _ => false + partial def Code.size (c : Code) : Nat := go c 0 where diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index e26f8da378..a392fa480d 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -26,7 +26,7 @@ where | .jp decl k => let value ← go decl.value let type ← value.inferParamType decl.params - let decl := { decl with value, type } + let decl ← decl.update' type value withReader (fun s => s.insert decl.fvarId) do return .jp decl (← go k) | .cases c => diff --git a/src/Lean/Compiler/LCNF/CSE.lean b/src/Lean/Compiler/LCNF/CSE.lean index f67ad99be7..613893cad7 100644 --- a/src/Lean/Compiler/LCNF/CSE.lean +++ b/src/Lean/Compiler/LCNF/CSE.lean @@ -45,7 +45,7 @@ where let type := (← getSubst).applyToExpr decl.type let params ← (← getSubst).applyToParams decl.params let value ← withNewScope do go decl.value - return { decl with type, params, value } + decl.update type params value go (code : Code) : M Code := do match code with diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 8256f26317..824b459f3a 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -96,8 +96,8 @@ open Internalize in /-- Refresh free variables ids in `code`, and store their declarations in the local context. -/ -partial def Code.internalize (code : Code) : CompilerM Code := - go code |>.run' {} +partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code := + go code |>.run' s where goFunDecl (decl : FunDecl) : M FunDecl := do let type ← translateExpr decl.type @@ -159,7 +159,7 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := modifyLCtx fun lctx => lctx.addLocalDecl p.fvarId p.binderName p.type return p -@[implementedBy updateParamImp] opaque updateParam (p : Param) (type : Expr) : CompilerM Param +@[implementedBy updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do if ptrEq type decl.type && ptrEq value decl.value then @@ -169,16 +169,32 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr modifyLCtx fun lctx => lctx.addLetDecl decl.fvarId decl.binderName decl.type decl.value return decl -@[implementedBy updateLetDeclImp] opaque updateLetDecl (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl +@[implementedBy updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl + +private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do + if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then + return decl + else + let decl := { decl with type, params, value } + modifyLCtx fun lctx => lctx.addFunDecl decl + return decl + +@[implementedBy updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl + +abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl := + decl.update type decl.params value + +abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl := + decl.update decl.type decl.params value def FVarSubst.applyToParam (s : FVarSubst) (p : Param) : CompilerM Param := - updateParam p (s.applyToExpr p.type) + p.update (s.applyToExpr p.type) def FVarSubst.applyToParams (s : FVarSubst) (ps : Array Param) : CompilerM (Array Param) := ps.mapM s.applyToParam def FVarSubst.applyToLetDecl (s : FVarSubst) (decl : LetDecl) : CompilerM LetDecl := - updateLetDecl decl (s.applyToExpr decl.type) (s.applyToExpr decl.value) + decl.update (s.applyToExpr decl.type) (s.applyToExpr decl.value) def mkFreshBinderName (binderName := `_x): CompilerM Name := do let declName := .num binderName (← get).nextIdx diff --git a/src/Lean/Compiler/LCNF/PullLetDecls.lean b/src/Lean/Compiler/LCNF/PullLetDecls.lean index 88c07db0b0..df0da2aaee 100644 --- a/src/Lean/Compiler/LCNF/PullLetDecls.lean +++ b/src/Lean/Compiler/LCNF/PullLetDecls.lean @@ -78,11 +78,13 @@ mutual | .fun decl k => withCheckpoint do let value ← withParams decl.params <| pullDecls decl.value - withFVar decl.fvarId do return .fun { decl with value } (← pullDecls k) + let decl ← decl.updateValue value + withFVar decl.fvarId do return .fun decl (← pullDecls k) | .jp decl k => withCheckpoint do let value ← withParams decl.params <| pullDecls decl.value - withFVar decl.fvarId do return .jp { decl with value } (← pullDecls k) + let decl ← decl.updateValue value + withFVar decl.fvarId do return .jp decl (← pullDecls k) | _ => return code end