From e15e6bfaeed31186510d415e9b67ed8ade064d86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 5 Oct 2022 13:36:11 +0200 Subject: [PATCH] chore: address PR comments --- src/Lean/Compiler/LCNF/FVarUtil.lean | 46 ++++++++ src/Lean/Compiler/LCNF/JoinPoints.lean | 150 ++++++------------------- src/Lean/Compiler/LCNF/ScopeM.lean | 53 +++++++++ 3 files changed, 136 insertions(+), 113 deletions(-) create mode 100644 src/Lean/Compiler/LCNF/FVarUtil.lean create mode 100644 src/Lean/Compiler/LCNF/ScopeM.lean diff --git a/src/Lean/Compiler/LCNF/FVarUtil.lean b/src/Lean/Compiler/LCNF/FVarUtil.lean new file mode 100644 index 0000000000..28e8115b7e --- /dev/null +++ b/src/Lean/Compiler/LCNF/FVarUtil.lean @@ -0,0 +1,46 @@ +/- +Copyright (c) 2022 Henrik Böving. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Henrik Böving +-/ +import Lean.Expr + +namespace Lean.Compiler.LCNF + +partial def mapFVarM [Monad m] (f : FVarId → m FVarId) (e : Expr) : m Expr := do + match e with + | .proj typ idx struct => return .proj typ idx (← mapFVarM f struct) + | .app fn arg => return .app (← mapFVarM f fn) (← mapFVarM f arg) + | .fvar fvarId => return .fvar (← f fvarId) + | .lam arg ty body bi => + return .lam arg (← mapFVarM f ty) (← mapFVarM f body) bi + | .forallE arg ty body bi => + return .forallE arg (←mapFVarM f ty) (← mapFVarM f body) bi + | .letE var ty value body nonDep => + return .letE var (← mapFVarM f ty) (← mapFVarM f value) (← mapFVarM f body) nonDep + | .bvar .. | .sort .. => return e + | .mdata .. | .const .. | .lit .. => return e + | .mvar .. => unreachable! + +partial def forFVarM [Monad m] (f : FVarId → m Unit) (e : Expr) : m Unit := do + match e with + | .proj _ _ struct => forFVarM f struct + | .app fn arg => + forFVarM f fn + forFVarM f arg + | .fvar fvarId => f fvarId + | .lam _ ty body .. => + forFVarM f ty + forFVarM f body + | .forallE _ ty body .. => + forFVarM f ty + forFVarM f body + | .letE _ ty value body .. => + forFVarM f ty + forFVarM f value + forFVarM f body + | .bvar .. | .sort .. => return + | .mdata .. | .const .. | .lit .. => return + | .mvar .. => unreachable! + +end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/JoinPoints.lean b/src/Lean/Compiler/LCNF/JoinPoints.lean index fed4d8db7d..7d723600db 100644 --- a/src/Lean/Compiler/LCNF/JoinPoints.lean +++ b/src/Lean/Compiler/LCNF/JoinPoints.lean @@ -6,89 +6,11 @@ Authors: Henrik Böving import Lean.Compiler.LCNF.CompilerM import Lean.Compiler.LCNF.PassManager import Lean.Compiler.LCNF.PullFunDecls +import Lean.Compiler.LCNF.FVarUtil +import Lean.Compiler.LCNF.ScopeM namespace Lean.Compiler.LCNF --- TODO: These can be used in a much more general context -partial def mapFVarM [Monad m] (f : FVarId → m FVarId) (e : Expr) : m Expr := do - match e with - | .proj typ idx struct => return .proj typ idx (←mapFVarM f struct) - | .app fn arg => return .app (←mapFVarM f fn) (←mapFVarM f arg) - | .fvar fvarId => return .fvar (←f fvarId) - | .lam arg ty body bi => - return .lam arg (←mapFVarM f ty) (←mapFVarM f body) bi - | .forallE arg ty body bi => - return .forallE arg (←mapFVarM f ty) (←mapFVarM f body) bi - | .letE var ty value body nonDep => - return .letE var (←mapFVarM f ty) (←mapFVarM f value) (←mapFVarM f body) nonDep - | .bvar .. | .sort .. => return e - | .mdata .. | .const .. | .lit .. => return e - | .mvar .. => unreachable! - -partial def forFVarM [Monad m] (f : FVarId → m Unit) (e : Expr) : m Unit := do - match e with - | .proj _ _ struct => forFVarM f struct - | .app fn arg => - forFVarM f fn - forFVarM f arg - | .fvar fvarId => f fvarId - | .lam _ ty body .. => - forFVarM f ty - forFVarM f body - | .forallE _ ty body .. => - forFVarM f ty - forFVarM f body - | .letE _ ty value body .. => - forFVarM f ty - forFVarM f value - forFVarM f body - | .bvar .. | .sort .. => return - | .mdata .. | .const .. | .lit .. => return - | .mvar .. => unreachable! - -/-- -A general abstraction for the idea of a scope in the compiler. --/ -abbrev ScopeM := StateRefT FVarIdSet CompilerM - -namespace ScopeM - -def getScope : ScopeM FVarIdSet := get -def setScope (newScope : FVarIdSet) : ScopeM Unit := set newScope -def clearScope : ScopeM Unit := setScope {} - -/-- -Execute `x` but recover the previous scope after doing so. --/ -def withBackTrackingScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do - let scope ← getScope - try x finally setScope scope - -/-- -Clear the current scope for the monadic action `x`, afterwards continuing -with the old one. --/ -def withNewScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do - withBackTrackingScope do - clearScope - x - -/-- -Check whether `fvarId` is in the current scope, that is, was declared within -the current `fun` declaration that is being processed. --/ -def isInScope (fvarId : FVarId) : ScopeM Bool := do - let scope ← getScope - return scope.contains fvarId - -/-- -Add a new `FVarId` to the current scope. --/ -def addToScope (fvarId : FVarId) : ScopeM Unit := - modify fun scope => scope.insert fvarId - -end ScopeM - namespace JoinPointFinder open ScopeM @@ -217,7 +139,7 @@ where eraseCandidate fvarId -- Out of scope join point candidate handling else if let some upperCandidate ← read then - if !(←isInScope fvarId) then + if !(← isInScope fvarId) then addDependency fvarId upperCandidate else eraseCandidate fvarId @@ -246,7 +168,7 @@ Replace all join point candidate `fun` declarations with `jp` ones and all calls to them with `jmp`s. -/ partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do - let mapper := fun acc cname _ => do return acc.insert cname (←mkFreshJpName) + let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName) let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := .empty) mapper let newValue ← go decl.value |>.run replaceCtx return { decl with value := newValue } @@ -264,23 +186,23 @@ where return code else return code - | _, _, _ => return Code.updateLet! code decl (←go k) + | _, _, _ => return Code.updateLet! code decl (← go k) | .fun decl k => if let some replacement := (← read).find? decl.fvarId then let newDecl := { decl with binderName := replacement, - value := (←go decl.value) + value := (← go decl.value) } modifyLCtx fun lctx => lctx.addFunDecl newDecl - return .jp newDecl (←go k) + return .jp newDecl (← go k) else - let newDecl ← decl.updateValue (←go decl.value) - return Code.updateFun! code newDecl (←go k) + let newDecl ← decl.updateValue (← go decl.value) + return Code.updateFun! code newDecl (← go k) | .jp decl k => - let newDecl ← decl.updateValue (←go decl.value) - return Code.updateFun! code newDecl (←go k) + let newDecl ← decl.updateValue (← go decl.value) + return Code.updateFun! code newDecl (← go k) | .cases cs => - return Code.updateCases! code cs.resultType cs.discr (←cs.alts.mapM (·.mapCodeM go)) + return Code.updateCases! code cs.resultType cs.discr (← cs.alts.mapM (·.mapCodeM go)) | .jmp .. | .return .. | .unreach .. => return code @@ -330,9 +252,9 @@ Replace a free variable if necessary, that is: otherwise just return `fvar`. -/ def replaceFVar (fvar : FVarId) : ExtendM FVarId := do - if (←read).candidates.contains fvar then - if let some currentJp := (←read).currentJp? then - if let some replacement := (←get).fvarMap.find! currentJp |>.find? fvar then + if (← read).candidates.contains fvar then + if let some currentJp := (← read).currentJp? then + if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then return replacement.fvarId return fvar @@ -342,7 +264,7 @@ if we are currently within a join point. Then execute `x`. -/ def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do addToScope fvar - if (←read).currentJp?.isSome then + if (← read).currentJp?.isSome then withReader (fun ctx => { ctx with candidates := ctx.candidates.insert fvar }) do x else @@ -352,9 +274,11 @@ def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do Same as `withNewCandidate` but with multiple `FVarId`s. -/ def withNewCandidates (fvars : Array FVarId) (x : ExtendM α) : ExtendM α := do - if (←read).currentJp?.isSome then - let candidates := (←read).candidates - let folder := (fun acc val => do addToScope val; return acc.insert val) + if (← read).currentJp?.isSome then + let candidates := (← read).candidates + let folder (acc : FVarIdSet) (val : FVarId) := do + addToScope val + return acc.insert val let newCandidates ← fvars.foldlM (init := candidates) folder withReader (fun ctx => { ctx with candidates := newCandidates }) do x @@ -380,10 +304,10 @@ This is necessary if: cannot lift a join point outside of a local function declaration. -/ def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do - if let some currentJp := (←read).currentJp? then - let mut translator := (←get).fvarMap.find! currentJp - let candidates := (←read).candidates - if !(←isInScope fvar) && !translator.contains fvar && candidates.contains fvar then + if let some currentJp := (← read).currentJp? then + let mut translator := (← get).fvarMap.find! currentJp + let candidates := (← read).candidates + if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then let typ ← getType fvar let newParam ← mkAuxParam typ translator := translator.insert fvar newParam @@ -404,8 +328,8 @@ as well because we need to drag these variables through at the call sites of `j.2` in `j.1`. -/ def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do - if (←read).currentJp?.isSome then - let additionalArgs := (←get).fvarMap.find! jp |>.toArray + if (← read).currentJp?.isSome then + let additionalArgs := (← get).fvarMap.find! jp |>.toArray for (fvar, _) in additionalArgs do extendByIfNecessary fvar @@ -469,23 +393,23 @@ where go (code : Code) : ExtendM Code := do match code with | .let decl k => - let decl ← decl.updateValue (←goExpr decl.value) + let decl ← decl.updateValue (← goExpr decl.value) withNewCandidate decl.fvarId do - return Code.updateLet! code decl (←go k) + return Code.updateLet! code decl (← go k) | .jp decl k => let decl ← withNewJpScope decl do let value ← go decl.value - let additionalParams := (←get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd + let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default) decl.update newType (additionalParams ++ decl.params) value mergeJpContextIfNecessary decl.fvarId withNewCandidate decl.fvarId do - return Code.updateFun! code decl (←go k) + return Code.updateFun! code decl (← go k) | .fun decl k => let decl ← withNewFunScope decl do - decl.updateValue (←go decl.value) + decl.updateValue (← go decl.value) withNewCandidate decl.fvarId do - return Code.updateFun! code decl (←go k) + return Code.updateFun! code decl (← go k) | .cases cs => extendByIfNecessary cs.discr let discr ← replaceFVar cs.discr @@ -496,9 +420,9 @@ where return Code.updateCases! code cs.resultType discr alts | .jmp fn args => let mut newArgs ← args.mapM goExpr - let additionalArgs := (←get).fvarMap.find! fn |>.toArray |>.map Prod.fst - if let some currentJp := (←read).currentJp? then - let translator := (←get).fvarMap.find! currentJp + let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst + if let some currentJp := (← read).currentJp? then + let translator := (← get).fvarMap.find! currentJp let f := fun arg => if let some translated := translator.find? arg then .fvar translated.fvarId @@ -510,7 +434,7 @@ where return Code.updateJmp! code fn newArgs | .return var => extendByIfNecessary var - return Code.updateReturn! code (←replaceFVar var) + return Code.updateReturn! code (← replaceFVar var) | .unreach .. => return code end JoinPointContextExtender diff --git a/src/Lean/Compiler/LCNF/ScopeM.lean b/src/Lean/Compiler/LCNF/ScopeM.lean new file mode 100644 index 0000000000..893af43acd --- /dev/null +++ b/src/Lean/Compiler/LCNF/ScopeM.lean @@ -0,0 +1,53 @@ +/- +Copyright (c) 2022 Henrik Böving. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Henrik Böving +-/ +import Lean.Compiler.LCNF.CompilerM + +namespace Lean.Compiler.LCNF + +/-- +A general abstraction for the idea of a scope in the compiler. +-/ +abbrev ScopeM := StateRefT FVarIdSet CompilerM + +namespace ScopeM + +def getScope : ScopeM FVarIdSet := get +def setScope (newScope : FVarIdSet) : ScopeM Unit := set newScope +def clearScope : ScopeM Unit := setScope {} + +/-- +Execute `x` but recover the previous scope after doing so. +-/ +def withBackTrackingScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do + let scope ← getScope + try x finally setScope scope + +/-- +Clear the current scope for the monadic action `x`, afterwards continuing +with the old one. +-/ +def withNewScope [MonadLiftT ScopeM m] [Monad m] [MonadFinally m] (x : m α) : m α := do + withBackTrackingScope do + clearScope + x + +/-- +Check whether `fvarId` is in the current scope, that is, was declared within +the current `fun` declaration that is being processed. +-/ +def isInScope (fvarId : FVarId) : ScopeM Bool := do + let scope ← getScope + return scope.contains fvarId + +/-- +Add a new `FVarId` to the current scope. +-/ +def addToScope (fvarId : FVarId) : ScopeM Unit := + modify fun scope => scope.insert fvarId + +end ScopeM + +end Lean.Compiler.LCNF