diff --git a/src/Lean/Compiler/IR/ToIR.lean b/src/Lean/Compiler/IR/ToIR.lean index a3ec28a2ce..15dce36813 100644 --- a/src/Lean/Compiler/IR/ToIR.lean +++ b/src/Lean/Compiler/IR/ToIR.lean @@ -40,14 +40,14 @@ structure BuilderState where For this reason we carry around these kinds of bindings in this substitution and apply it whenever we access an fvar in the conversion. -/ - subst : LCNF.FVarSubst := {} + subst : LCNF.FVarSubst .pure := {} abbrev M := StateRefT BuilderState CoreM -instance : LCNF.MonadFVarSubst M false where +instance : LCNF.MonadFVarSubst M .pure false where getSubst := return (← get).subst -instance : LCNF.MonadFVarSubstState M where +instance : LCNF.MonadFVarSubstState M .pure where modifySubst f := modify fun s => { s with subst := f s.subst } def M.run (x : M α) : CoreM α := do @@ -102,7 +102,7 @@ def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType := | .uint64 v => ⟨.num (UInt64.toNat v), .uint64⟩ | .usize v => ⟨.num (UInt64.toNat v), .usize⟩ -def lowerArg (a : LCNF.Arg) : M Arg := do +def lowerArg (a : LCNF.Arg .pure) : M Arg := do match a with | .fvar fvarId => getFVarValue fvarId | .erased | .type .. => return .erased @@ -121,15 +121,15 @@ def lowerProj (base : VarId) (ctorInfo : CtorInfo) (field : CtorFieldInfo) | .erased => ⟨.erased, .erased⟩ | .void => ⟨.erased, .void⟩ -def lowerParam (p : LCNF.Param) : M Param := do +def lowerParam (p : LCNF.Param .pure) : M Param := do let x ← bindVar p.fvarId let ty ← toIRType p.type if ty.isVoid || ty.isErased then - Compiler.LCNF.addSubst p.fvarId .erased + Compiler.LCNF.addSubst p.fvarId (.erased : LCNF.Arg .pure) return { x, borrow := p.borrow, ty } mutual -partial def lowerCode (c : LCNF.Code) : M FnBody := do +partial def lowerCode (c : LCNF.Code .pure) : M FnBody := do match c with | .let decl k => lowerLet decl k | .jp decl k => @@ -149,7 +149,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do for idx in 0...ps.size do let p := ps[idx]! if idx == info.fieldIdx then - LCNF.addSubst p.fvarId (.fvar cases.discr) + LCNF.addSubst p.fvarId (.fvar cases.discr : LCNF.Arg .pure) else bindErased p.fvarId lowerCode k @@ -165,7 +165,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do | .unreach .. => return .unreachable | .fun .. => panic! "all local functions should be λ-lifted" -partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do +partial def lowerLet (decl : LCNF.LetDecl .pure) (k : LCNF.Code .pure) : M FnBody := do let value ← LCNF.normLetValue decl.value match value with | .lit litValue => @@ -175,7 +175,7 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do | .proj typeName i fvarId => if let some info ← hasTrivialStructure? typeName then if info.fieldIdx == i then - LCNF.addSubst decl.fvarId (.fvar fvarId) + LCNF.addSubst decl.fvarId (.fvar fvarId : LCNF.Arg .pure) else bindErased decl.fvarId lowerCode k @@ -302,11 +302,11 @@ where else mkOverApplication name numParams args -partial def lowerAlt (discr : VarId) (a : LCNF.Alt) : M Alt := do +partial def lowerAlt (discr : VarId) (a : LCNF.Alt .pure) : M Alt := do match a with | .alt ctorName params code => let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName - let lowerParams (params : Array LCNF.Param) (fields : Array CtorFieldInfo) : M FnBody := do + let lowerParams (params : Array (LCNF.Param .pure)) (fields : Array CtorFieldInfo) : M FnBody := do let rec loop (i : Nat) : M FnBody := do match params[i]?, fields[i]? with | some param, some field => @@ -340,7 +340,7 @@ where resultTypeForArity (type : Lean.Expr) (arity : Nat) : Lean.Expr := | .const ``lcErased _ => mkConst ``lcErased | _ => panic! "invalid arity" -def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do +def lowerDecl (d : LCNF.Decl .pure) : M (Option Decl) := do let params ← d.params.mapM lowerParam let mut resultType ← lowerResultType d.type d.params.size let taggedReturn := taggedReturnAttr.hasTag (← getEnv) d.name @@ -366,7 +366,7 @@ def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do end ToIR -def toIR (decls: Array LCNF.Decl) : CoreM (Array Decl) := do +def toIR (decls: Array (LCNF.Decl .pure)) : CoreM (Array Decl) := do let mut irDecls := #[] for decl in decls do if let some irDecl ← ToIR.lowerDecl decl |>.run then diff --git a/src/Lean/Compiler/LCNF/AlphaEqv.lean b/src/Lean/Compiler/LCNF/AlphaEqv.lean index 052382516c..ca7f153408 100644 --- a/src/Lean/Compiler/LCNF/AlphaEqv.lean +++ b/src/Lean/Compiler/LCNF/AlphaEqv.lean @@ -40,14 +40,14 @@ def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do else return false -def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do +def eqvArg (a₁ a₂ : Arg pu) : EqvM Bool := do match a₁, a₂ with - | .type e₁, .type e₂ => eqvType e₁ e₂ + | .type e₁ _, .type e₂ _ => eqvType e₁ e₂ | .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂ | .erased, .erased => return true | _, _ => return false -def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do +def eqvArgs (as₁ as₂ : Array (Arg pu)) : EqvM Bool := do if as₁.size = as₂.size then for a₁ in as₁, a₂ in as₂ do unless (← eqvArg a₁ a₂) do @@ -56,19 +56,19 @@ def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do else return false -def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do +def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do match e₁, e₂ with | .lit v₁, .lit v₂ => return v₁ == v₂ | .erased, .erased => return true - | .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂ - | .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂ + | .proj s₁ i₁ x₁ _, .proj s₂ i₂ x₂ _ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂ + | .const n₁ us₁ as₁ _, .const n₂ us₂ as₂ _ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂ | .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂ | _, _ => return false @[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α := withReader (·.insert fvarId₂ fvarId₁) x -@[inline] def withParams (params₁ params₂ : Array Param) (x : EqvM Bool) : EqvM Bool := do +@[inline] def withParams (params₁ params₂ : Array (Param pu)) (x : EqvM Bool) : EqvM Bool := do if h : params₂.size = params₁.size then let rec @[specialize] go (i : Nat) : EqvM Bool := do if h : i < params₁.size then @@ -85,7 +85,7 @@ def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do else return false -def sortAlts (alts : Array Alt) : Array Alt := +def sortAlts (alts : Array (Alt pu)) : Array (Alt pu) := alts.qsort fun | .alt .., .default .. => true | .alt ctorName₁ .., .alt ctorName₂ .. => Name.lt ctorName₁ ctorName₂ @@ -93,13 +93,13 @@ def sortAlts (alts : Array Alt) : Array Alt := mutual -partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do +partial def eqvAlts (alts₁ alts₂ : Array (Alt pu)) : EqvM Bool := do if alts₁.size = alts₂.size then let alts₁ := sortAlts alts₁ let alts₂ := sortAlts alts₂ for alt₁ in alts₁, alt₂ in alts₂ do match alt₁, alt₂ with - | .alt ctorName₁ ps₁ k₁, .alt ctorName₂ ps₂ k₂ => + | .alt ctorName₁ ps₁ k₁ _, .alt ctorName₂ ps₂ k₂ _ => unless ctorName₁ == ctorName₂ do return false unless (← withParams ps₁ ps₂ (eqv k₁ k₂)) do return false | .default k₁, .default k₂ => unless (← eqv k₁ k₂) do return false @@ -108,13 +108,13 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do else return false -partial def eqv (code₁ code₂ : Code) : EqvM Bool := do +partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do match code₁, code₂ with | .let decl₁ k₁, .let decl₂ k₂ => eqvType decl₁.type decl₂.type <&&> eqvLetValue decl₁.value decl₂.value <&&> withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂) - | .fun decl₁ k₁, .fun decl₂ k₂ + | .fun decl₁ k₁ _, .fun decl₂ k₂ _ | .jp decl₁ k₁, .jp decl₂ k₂ => eqvType decl₁.type decl₂.type <&&> withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&> @@ -135,7 +135,7 @@ end AlphaEqv /-- Return `true` if `c₁` and `c₂` are alpha equivalent. -/ -def Code.alphaEqv (c₁ c₂ : Code) : Bool := +def Code.alphaEqv (c₁ c₂ : Code pu) : Bool := AlphaEqv.eqv c₁ c₂ |>.run {} end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/AuxDeclCache.lean b/src/Lean/Compiler/LCNF/AuxDeclCache.lean index 406d57d4dc..d2d21c70ea 100644 --- a/src/Lean/Compiler/LCNF/AuxDeclCache.lean +++ b/src/Lean/Compiler/LCNF/AuxDeclCache.lean @@ -13,15 +13,21 @@ public section namespace Lean.Compiler.LCNF -builtin_initialize auxDeclCacheExt : CacheExtension Decl Name ← CacheExtension.register +structure AuxDeclCacheKey where + pu : Purity + decl : Decl pu + deriving BEq, Hashable + +builtin_initialize auxDeclCacheExt : CacheExtension AuxDeclCacheKey Name ← CacheExtension.register inductive CacheAuxDeclResult where | new | alreadyCached (declName : Name) -def cacheAuxDecl (decl : Decl) : CompilerM CacheAuxDeclResult := do +def cacheAuxDecl (decl : Decl pu) : CompilerM CacheAuxDeclResult := do let key := { decl with name := .anonymous } let key ← normalizeFVarIds key + let key := ⟨pu, key⟩ match (← auxDeclCacheExt.find? key) with | some declName => return .alreadyCached declName diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index b8230041c3..1db18432a1 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -24,14 +24,50 @@ and the approach described in the paper -/ -structure Param where +/-- +This type is used to index the fundamental LCNF IR data structures. Depending on its value different +constructors are available for the different semantic phases of LCNF. + +Notably in order to save memory we never index the IR types over `Purity`. Instead the type is +parametrized by the phase and the individual constructors might carry a proof (that will be erased) +that they are only allowed in a certain phase. +-/ +inductive Purity where + /-- + The code we are acting on is still pure, things like reordering up to value dependencies are + acceptable. + -/ + | pure + /-- + The code we are acting on is to be considered generally impure, doing reorderings is potentially + no longer legal. + -/ + | impure + deriving Inhabited, DecidableEq, Hashable + +instance : ToString Purity where + toString + | .pure => "pure" + | .impure => "impure" + +@[inline] +def Purity.withAssertPurity [Inhabited α] (is : Purity) (should : Purity) + (k : (is = should) → α) : α := + if h : is = should then + k h + else + panic! s!"Purity should be {should} but is {is}, this is a bug" + +scoped macro "purity_tac" : tactic => `(tactic| first | with_reducible rfl | assumption) + +structure Param (pu : Purity) where fvarId : FVarId binderName : Name type : Expr borrow : Bool deriving Inhabited, BEq -def Param.toExpr (p : Param) : Expr := +def Param.toExpr (p : Param pu) : Expr := .fvar p.fvarId inductive LitValue where @@ -55,111 +91,111 @@ def LitValue.toExpr : LitValue → Expr | .uint64 v => .app (.const ``UInt64.ofNat []) (.lit (.natVal (UInt64.toNat v))) | .usize v => .app (.const ``USize.ofNat []) (.lit (.natVal (UInt64.toNat v))) -inductive Arg where +inductive Arg (pu : Purity) where | erased | fvar (fvarId : FVarId) - | type (expr : Expr) + | type (expr : Expr) (h : pu = .pure := by purity_tac) deriving Inhabited, BEq, Hashable -def Param.toArg (p : Param) : Arg := +def Param.toArg (p : Param pu) : Arg pu := .fvar p.fvarId -def Arg.toExpr (arg : Arg) : Expr := +def Arg.toExpr (arg : Arg pu) : Expr := match arg with | .erased => erasedExpr | .fvar fvarId => .fvar fvarId - | .type e => e + | .type e _ => e -private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg := +private unsafe def Arg.updateTypeImp (arg : Arg pu) (type' : Expr) : Arg pu := match arg with - | .type ty => if ptrEq ty type' then arg else .type type' + | .type ty _ => if ptrEq ty type' then arg else .type type' | _ => unreachable! -@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg +@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg pu) (type : Expr) : Arg pu -private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg := +private unsafe def Arg.updateFVarImp (arg : Arg pu) (fvarId' : FVarId) : Arg pu := match arg with | .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId' | _ => unreachable! -@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg +@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg pu) (fvarId' : FVarId) : Arg pu -inductive LetValue where +inductive LetValue (pu : Purity) where | lit (value : LitValue) | erased - | proj (typeName : Name) (idx : Nat) (struct : FVarId) - | const (declName : Name) (us : List Level) (args : Array Arg) - | fvar (fvarId : FVarId) (args : Array Arg) + | proj (typeName : Name) (idx : Nat) (struct : FVarId) (h : pu = .pure := by purity_tac) + | const (declName : Name) (us : List Level) (args : Array (Arg pu)) (h : pu = .pure := by purity_tac) + | fvar (fvarId : FVarId) (args : Array (Arg pu)) deriving Inhabited, BEq, Hashable -def Arg.toLetValue (arg : Arg) : LetValue := +def Arg.toLetValue (arg : Arg pu) : LetValue pu := match arg with | .fvar fvarId => .fvar fvarId #[] | .erased | .type .. => .erased -private unsafe def LetValue.updateProjImp (e : LetValue) (fvarId' : FVarId) : LetValue := +private unsafe def LetValue.updateProjImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu := match e with - | .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId' + | .proj s i fvarId _ => if fvarId == fvarId' then e else .proj s i fvarId' | _ => unreachable! -@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue) (fvarId' : FVarId) : LetValue +@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu -private unsafe def LetValue.updateConstImp (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue := +private unsafe def LetValue.updateConstImp (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu := match e with - | .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args' + | .const declName us args _ => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args' | _ => unreachable! -@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue +@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu -private unsafe def LetValue.updateFVarImp (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue := +private unsafe def LetValue.updateFVarImp (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu := match e with | .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args' | _ => unreachable! -@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue +@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu -private unsafe def LetValue.updateArgsImp (e : LetValue) (args' : Array Arg) : LetValue := +private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu := match e with - | .const declName us args => if ptrEq args args' then e else .const declName us args' + | .const declName us args h => if ptrEq args args' then e else .const declName us args' | .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args' | _ => unreachable! -@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue) (args' : Array Arg) : LetValue +@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu -def LetValue.toExpr (e : LetValue) : Expr := +def LetValue.toExpr (e : LetValue pu) : Expr := match e with | .lit v => v.toExpr | .erased => erasedExpr - | .proj n i s => .proj n i (.fvar s) - | .const n us as => mkAppN (.const n us) (as.map Arg.toExpr) + | .proj n i s _ => .proj n i (.fvar s) + | .const n us as _ => mkAppN (.const n us) (as.map Arg.toExpr) | .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr) -structure LetDecl where +structure LetDecl (pu : Purity) where fvarId : FVarId binderName : Name type : Expr - value : LetValue + value : LetValue pu deriving Inhabited, BEq mutual -inductive Alt where - | alt (ctorName : Name) (params : Array Param) (code : Code) - | default (code : Code) +inductive Alt (pu : Purity) where + | alt (ctorName : Name) (params : Array (Param pu)) (code : Code pu) (h : pu = .pure := by purity_tac) + | default (code : Code pu) -inductive FunDecl where - | mk (fvarId : FVarId) (binderName : Name) (params : Array Param) (type : Expr) (value : Code) +inductive FunDecl (pu : Purity) where + | mk (fvarId : FVarId) (binderName : Name) (params : Array (Param pu)) (type : Expr) (value : Code pu) -inductive Cases where - | mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array Alt) +inductive Cases (pu : Purity) where + | mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) deriving Inhabited -inductive Code where - | let (decl : LetDecl) (k : Code) - | fun (decl : FunDecl) (k : Code) - | jp (decl : FunDecl) (k : Code) - | jmp (fvarId : FVarId) (args : Array Arg) - | cases (cases : Cases) +inductive Code (pu : Purity) where + | let (decl : LetDecl pu) (k : Code pu) + | fun (decl : FunDecl pu) (k : Code pu) (h : pu = .pure := by purity_tac) + | jp (decl : FunDecl pu) (k : Code pu) + | jmp (fvarId : FVarId) (args : Array (Arg pu)) + | cases (cases : Cases pu) | return (fvarId : FVarId) | unreach (type : Expr) deriving Inhabited @@ -167,99 +203,99 @@ inductive Code where end @[inline] -def FunDecl.fvarId : FunDecl → FVarId +def FunDecl.fvarId : FunDecl pu → FVarId | .mk (fvarId := fvarId) .. => fvarId @[inline] -def FunDecl.binderName : FunDecl → Name +def FunDecl.binderName : FunDecl pu → Name | .mk (binderName := binderName) .. => binderName @[inline] -def FunDecl.params : FunDecl → Array Param +def FunDecl.params : FunDecl pu → Array (Param pu) | .mk (params := params) .. => params @[inline] -def FunDecl.type : FunDecl → Expr +def FunDecl.type : FunDecl pu → Expr | .mk (type := type) .. => type @[inline] -def FunDecl.value : FunDecl → Code +def FunDecl.value : FunDecl pu → Code pu | .mk (value := value) .. => value @[inline] -def FunDecl.updateBinderName : FunDecl → Name → FunDecl +def FunDecl.updateBinderName : FunDecl pu → Name → FunDecl pu | .mk fvarId _ params type value, new => .mk fvarId new params type value @[inline] -def FunDecl.toParam (decl : FunDecl) (borrow : Bool) : Param := +def FunDecl.toParam (decl : FunDecl pu) (borrow : Bool) : Param pu := match decl with | .mk fvarId binderName _ type .. => ⟨fvarId, binderName, type, borrow⟩ @[inline] -def Cases.typeName : Cases → Name +def Cases.typeName : Cases pu → Name | .mk (typeName := typeName) .. => typeName @[inline] -def Cases.resultType : Cases → Expr +def Cases.resultType : Cases pu → Expr | .mk (resultType := resultType) .. => resultType @[inline] -def Cases.discr : Cases → FVarId +def Cases.discr : Cases pu → FVarId | .mk (discr := discr) .. => discr @[inline] -def Cases.alts : Cases → Array Alt +def Cases.alts : Cases pu → Array (Alt pu) | .mk (alts := alts) .. => alts @[inline] -def Cases.updateAlts : Cases → Array Alt → Cases +def Cases.updateAlts : Cases pu → Array (Alt pu) → Cases pu | .mk typeName resultType discr _, new => .mk typeName resultType discr new deriving instance Inhabited for Alt deriving instance Inhabited for FunDecl -def FunDecl.getArity (decl : FunDecl) : Nat := +def FunDecl.getArity (decl : FunDecl pu) : Nat := decl.params.size /-- Return the constructor names that have an explicit (non-default) alternative. -/ -def Cases.getCtorNames (c : Cases) : NameSet := +def Cases.getCtorNames (c : Cases pu) : NameSet := c.alts.foldl (init := {}) fun ctorNames alt => match alt with | .default _ => ctorNames | .alt ctorName .. => ctorNames.insert ctorName -inductive CodeDecl where - | let (decl : LetDecl) - | fun (decl : FunDecl) - | jp (decl : FunDecl) +inductive CodeDecl (pu : Purity) where + | let (decl : LetDecl pu) + | fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac) + | jp (decl : FunDecl pu) deriving Inhabited -def CodeDecl.fvarId : CodeDecl → FVarId - | .let decl | .fun decl | .jp decl => decl.fvarId +def CodeDecl.fvarId : CodeDecl pu → FVarId + | .let decl | .fun decl _ | .jp decl => decl.fvarId -def attachCodeDecls (decls : Array CodeDecl) (code : Code) : Code := +def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu := go decls.size code where - go (i : Nat) (code : Code) : Code := + go (i : Nat) (code : Code pu) : Code pu := if i > 0 then match decls[i-1]! with | .let decl => go (i-1) (.let decl code) - | .fun decl => go (i-1) (.fun decl code) + | .fun decl _ => go (i-1) (.fun decl code) | .jp decl => go (i-1) (.jp decl code) else code mutual - private unsafe def eqImp (c₁ c₂ : Code) : Bool := + private unsafe def eqImp (c₁ c₂ : Code pu) : Bool := if ptrEq c₁ c₂ then true else match c₁, c₂ with | .let d₁ k₁, .let d₂ k₂ => d₁ == d₂ && eqImp k₁ k₂ - | .fun d₁ k₁, .fun d₂ k₂ + | .fun d₁ k₁ _, .fun d₂ k₂ _ | .jp d₁ k₁, .jp d₂ k₂ => eqFunDecl d₁ d₂ && eqImp k₁ k₂ | .cases c₁, .cases c₂ => eqCases c₁ c₂ | .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂ @@ -267,7 +303,7 @@ mutual | .unreach t₁, .unreach t₂ => t₁ == t₂ | _, _ => false - private unsafe def eqFunDecl (d₁ d₂ : FunDecl) : Bool := + private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool := if ptrEq d₁ d₂ then true else @@ -275,62 +311,62 @@ mutual d₁.params == d₂.params && d₁.type == d₂.type && eqImp d₁.value d₂.value - private unsafe def eqCases (c₁ c₂ : Cases) : Bool := + private unsafe def eqCases (c₁ c₂ : Cases pu) : Bool := c₁.resultType == c₂.resultType && c₁.discr == c₂.discr && c₁.typeName == c₂.typeName && c₁.alts.isEqv c₂.alts eqAlt - private unsafe def eqAlt (a₁ a₂ : Alt) : Bool := + private unsafe def eqAlt (a₁ a₂ : Alt pu) : Bool := match a₁, a₂ with | .default k₁, .default k₂ => eqImp k₁ k₂ - | .alt c₁ ps₁ k₁, .alt c₂ ps₂ k₂ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂ + | .alt c₁ ps₁ k₁ _, .alt c₂ ps₂ k₂ _ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂ | _, _ => false end -@[implemented_by eqImp] protected opaque Code.beq : Code → Code → Bool +@[implemented_by eqImp] protected opaque Code.beq : Code pu → Code pu → Bool -instance : BEq Code where +instance : BEq (Code pu) where beq := Code.beq -@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl → FunDecl → Bool +@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl pu → FunDecl pu → Bool -instance : BEq FunDecl where +instance : BEq (FunDecl pu) where beq := FunDecl.beq -def Alt.getCode : Alt → Code +def Alt.getCode : Alt pu → Code pu | .default k => k - | .alt _ _ k => k + | .alt _ _ k _ => k -def Alt.getParams : Alt → Array Param +def Alt.getParams : Alt pu → Array (Param pu) | .default _ => #[] - | .alt _ ps _ => ps + | .alt _ ps _ _ => ps -def Alt.forCodeM [Monad m] (alt : Alt) (f : Code → m Unit) : m Unit := do +def Alt.forCodeM [Monad m] (alt : Alt pu) (f : Code pu → m Unit) : m Unit := do match alt with | .default k => f k - | .alt _ _ k => f k + | .alt _ _ k _ => f k -private unsafe def updateAltCodeImp (alt : Alt) (k' : Code) : Alt := +private unsafe def updateAltCodeImp (alt : Alt pu) (k' : Code pu) : Alt pu := match alt with | .default k => if ptrEq k k' then alt else .default k' - | .alt ctorName ps k => if ptrEq k k' then alt else .alt ctorName ps k' + | .alt ctorName ps k _ => if ptrEq k k' then alt else .alt ctorName ps k' -@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt) (c : Code) : Alt +@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt pu) (c : Code pu) : Alt pu -private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Alt := +private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu := match alt with - | .alt ctorName ps k => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k' + | .alt ctorName ps k _ => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k' | _ => unreachable! -@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt) (ps' : Array Param) (k' : Code) : Alt +@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu -@[inline] private unsafe def updateAltsImp (c : Code) (alts : Array Alt) : Code := +@[inline] private unsafe def updateAltsImp (c : Code pu) (alts : Array (Alt pu)) : Code pu := match c with | .cases cs => if ptrEq cs.alts alts then c else .cases <| cs.updateAlts alts | _ => unreachable! -@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code) (alts : Array Alt) : Code +@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code pu) (alts : Array (Alt pu)) : Code pu -@[inline] private unsafe def updateCasesImp (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code := +@[inline] private unsafe def updateCasesImp (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu := match c with | .cases cs => if ptrEq cs.alts alts && ptrEq cs.resultType resultType && cs.discr == discr then @@ -339,54 +375,54 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al .cases <| ⟨cs.typeName, resultType, discr, alts⟩ | _ => unreachable! -@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code +@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu -@[inline] private unsafe def updateLetImp (c : Code) (decl' : LetDecl) (k' : Code) : Code := +@[inline] private unsafe def updateLetImp (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu := match c with | .let decl k => if ptrEq k k' && ptrEq decl decl' then c else .let decl' k' | _ => unreachable! -@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code) (decl' : LetDecl) (k' : Code) : Code +@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu -@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code := +@[inline] private unsafe def updateContImp (c : Code pu) (k' : Code pu) : Code pu := match c with | .let decl k => if ptrEq k k' then c else .let decl k' - | .fun decl k => if ptrEq k k' then c else .fun decl k' + | .fun decl k _ => if ptrEq k k' then c else .fun decl k' | .jp decl k => if ptrEq k k' then c else .jp decl k' | _ => unreachable! -@[implemented_by updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code +@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu -@[inline] private unsafe def updateFunImp (c : Code) (decl' : FunDecl) (k' : Code) : Code := +@[inline] private unsafe def updateFunImp (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu := match c with - | .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k' + | .fun decl k _ => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k' | .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k' | _ => unreachable! -@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code +@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu -@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code := +@[inline] private unsafe def updateReturnImp (c : Code pu) (fvarId' : FVarId) : Code pu := match c with | .return fvarId => if fvarId == fvarId' then c else .return fvarId' | _ => unreachable! -@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code +@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code pu) (fvarId' : FVarId) : Code pu -@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code := +@[inline] private unsafe def updateJmpImp (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu := match c with | .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args' | _ => unreachable! -@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code +@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu -@[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code := +@[inline] private unsafe def updateUnreachImp (c : Code pu) (type' : Expr) : Code pu := match c with | .unreach type => if ptrEq type type' then c else .unreach type' | _ => unreachable! -@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code) (type' : Expr) : Code +@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code pu) (type' : Expr) : Code pu -private unsafe def updateParamCoreImp (p : Param) (type : Expr) : Param := +private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu := if ptrEq type p.type then p else @@ -397,9 +433,9 @@ Low-level update `Param` function. It does not update the local context. Consider using `Param.update : Param → Expr → CompilerM Param` if you want the local context to be updated. -/ -@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param +@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param pu) (type : Expr) : Param pu -private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl := +private unsafe def updateLetDeclCoreImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu := if ptrEq type decl.type && ptrEq value decl.value then decl else @@ -410,9 +446,9 @@ Low-level update `LetDecl` function. It does not update the local context. Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context to be updated. -/ -@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl +@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu -private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl := +private unsafe def updateFunDeclCoreImp (decl: FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu := if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then decl else @@ -423,9 +459,9 @@ Low-level update `FunDecl` function. It does not update the local context. Consider using `FunDecl.update : LetDecl → Expr → Array Param → Code → CompilerM FunDecl` if you want the local context to be updated. -/ -@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl +@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu -def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases := +def Cases.extractAlt! (cases : Cases pu) (ctorName : Name) : Alt pu × Cases pu := let found i := (cases.alts[i]!, cases.updateAlts (cases.alts.eraseIdx! i)) if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then found i @@ -434,34 +470,34 @@ def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases := else unreachable! -def Alt.mapCodeM [Monad m] (alt : Alt) (f : Code → m Code) : m Alt := do +def Alt.mapCodeM [Monad m] (alt : Alt pu) (f : Code pu → m (Code pu)) : m (Alt pu) := do return alt.updateCode (← f alt.getCode) -def Code.isDecl : Code → Bool +def Code.isDecl : Code pu → Bool | .let .. | .fun .. | .jp .. => true | _ => false -def Code.isFun : Code → Bool +def Code.isFun : Code pu → Bool | .fun .. => true | _ => false -def Code.isReturnOf : Code → FVarId → Bool +def Code.isReturnOf : Code pu → FVarId → Bool | .return fvarId, fvarId' => fvarId == fvarId' | _, _ => false -partial def Code.size (c : Code) : Nat := +partial def Code.size (c : Code pu) : Nat := go c 0 where - go (c : Code) (n : Nat) : Nat := + go (c : Code pu) (n : Nat) : Nat := match c with | .let _ k => go k (n+1) - | .jp decl k | .fun decl k => go k <| go decl.value n + | .jp decl k | .fun decl k _ => go k <| go decl.value n | .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1) | .jmp .. => n+1 | .return .. | unreach .. => n -- `return` & `unreach` have weight zero /-- Return true iff `c.size ≤ n` -/ -partial def Code.sizeLe (c : Code) (n : Nat) : Bool := +partial def Code.sizeLe (c : Code pu) (n : Nat) : Bool := match go c |>.run 0 with | .ok .. => true | .error .. => false @@ -470,26 +506,26 @@ where modify (·+1) unless (← get) <= n do throw () - go (c : Code) : EStateM Unit Nat Unit := do + go (c : Code pu) : EStateM Unit Nat Unit := do match c with | .let _ k => inc; go k - | .jp decl k | .fun decl k => inc; go decl.value; go k + | .jp decl k | .fun decl k _ => inc; go decl.value; go k | .cases c => inc; c.alts.forM fun alt => go alt.getCode | .jmp .. => inc | .return .. | unreach .. => return () -partial def Code.forM [Monad m] (c : Code) (f : Code → m Unit) : m Unit := +partial def Code.forM [Monad m] (c : Code pu) (f : Code pu → m Unit) : m Unit := go c where - go (c : Code) : m Unit := do + go (c : Code pu) : m Unit := do f c match c with | .let _ k => go k - | .fun decl k | .jp decl k => go decl.value; go k + | .fun decl k _ | .jp decl k => go decl.value; go k | .cases c => c.alts.forM fun alt => go alt.getCode | .unreach .. | .return .. | .jmp .. => return () -partial def Code.instantiateValueLevelParams (code : Code) (levelParams : List Name) (us : List Level) : Code := +partial def Code.instantiateValueLevelParams (code : Code .pure) (levelParams : List Name) (us : List Level) : Code .pure := instCode code where instLevel (u : Level) := @@ -498,67 +534,67 @@ where instExpr (e : Expr) := e.instantiateLevelParamsNoCache levelParams us - instParams (ps : Array Param) := + instParams (ps : Array (Param .pure)) := ps.mapMono fun p => p.updateCore (instExpr p.type) - instAlt (alt : Alt) := + instAlt (alt : Alt .pure) := match alt with | .default k => alt.updateCode (instCode k) - | .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k) + | .alt _ ps k _ => alt.updateAlt! (instParams ps) (instCode k) - instArg (arg : Arg) : Arg := + instArg (arg : Arg .pure) : Arg .pure := match arg with - | .type e => arg.updateType! (instExpr e) + | .type e _ => arg.updateType! (instExpr e) | .fvar .. | .erased => arg - instLetValue (e : LetValue) : LetValue := + instLetValue (e : LetValue .pure) : LetValue .pure := match e with - | .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg) + | .const declName vs args _ => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg) | .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg) | .proj .. | .lit .. | .erased => e - instLetDecl (decl : LetDecl) := + instLetDecl (decl : LetDecl .pure) := decl.updateCore (instExpr decl.type) (instLetValue decl.value) - instFunDecl (decl : FunDecl) := + instFunDecl (decl : FunDecl .pure) := decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value) - instCode (code : Code) := + instCode (code : Code .pure) := match code with | .let decl k => code.updateLet! (instLetDecl decl) (instCode k) - | .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k) + | .jp decl k | .fun decl k _ => code.updateFun! (instFunDecl decl) (instCode k) | .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt) | .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg) | .return .. => code | .unreach type => code.updateUnreach! (instExpr type) -inductive DeclValue where - | code (code : Code) +inductive DeclValue (pu : Purity) where + | code (code : Code pu) | extern (externAttrData : ExternAttrData) deriving Inhabited, BEq -partial def DeclValue.size : DeclValue → Nat +partial def DeclValue.size : DeclValue pu → Nat | .code c => c.size | .extern .. => 0 -def DeclValue.mapCode (f : Code → Code) : DeclValue → DeclValue := +def DeclValue.mapCode (f : Code pu → Code pu) : DeclValue pu → DeclValue pu := fun | .code c => .code (f c) | .extern e => .extern e -def DeclValue.mapCodeM [Monad m] (f : Code → m Code) : DeclValue → m DeclValue := +def DeclValue.mapCodeM [Monad m] (f : Code pu → m (Code pu)) : DeclValue pu → m (DeclValue pu) := fun v => do match v with | .code c => return .code (← f c) | .extern .. => return v -def DeclValue.forCodeM [Monad m] (f : Code → m Unit) : DeclValue → m Unit := +def DeclValue.forCodeM [Monad m] (f : Code pu → m Unit) : DeclValue pu → m Unit := fun v => do match v with | .code c => f c | .extern .. => return () -def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Bool := +def DeclValue.isCodeAndM [Monad m] (v : DeclValue pu) (f : Code pu → m Bool) : m Bool := match v with | .code c => f c | .extern .. => pure false @@ -566,7 +602,7 @@ def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Boo /-- Declaration being processed by the Lean to Lean compiler passes. -/ -structure Decl where +structure Decl (pu : Purity) where /-- The name of the declaration from the `Environment` it came from -/ @@ -584,12 +620,12 @@ structure Decl where /-- Parameters. -/ - params : Array Param + params : Array (Param pu) /-- The body of the declaration, usually changes as it progresses through compiler passes. -/ - value : DeclValue + value : DeclValue pu /-- We set this flag to true during LCNF conversion. When we receive a block of functions to be compiled, we set this flag to `true` @@ -631,31 +667,37 @@ structure Decl where inlineAttr? : Option InlineAttributeKind deriving Inhabited, BEq -def Decl.size (decl : Decl) : Nat := +def Decl.size (decl : Decl pu) : Nat := decl.value.size -def Decl.getArity (decl : Decl) : Nat := +def Decl.getArity (decl : Decl pu) : Nat := decl.params.size -def Decl.inlineAttr (decl : Decl) : Bool := +def Decl.inlineAttr (decl : Decl pu) : Bool := decl.inlineAttr? matches some .inline -def Decl.noinlineAttr (decl : Decl) : Bool := +def Decl.noinlineAttr (decl : Decl pu) : Bool := decl.inlineAttr? matches some .noinline -def Decl.inlineIfReduceAttr (decl : Decl) : Bool := +def Decl.inlineIfReduceAttr (decl : Decl pu) : Bool := decl.inlineAttr? matches some .inlineIfReduce -def Decl.alwaysInlineAttr (decl : Decl) : Bool := +def Decl.alwaysInlineAttr (decl : Decl pu) : Bool := decl.inlineAttr? matches some .alwaysInline /-- Return `true` if the given declaration has been annotated with `[inline]`, `[inline_if_reduce]`, `[macro_inline]`, or `[always_inline]` -/ -def Decl.inlineable (decl : Decl) : Bool := +def Decl.inlineable (decl : Decl pu) : Bool := match decl.inlineAttr? with | some .noinline => false | some _ => true | none => false +def Decl.castPurity! (decl : Decl pu1) (pu2 : Purity) : Decl pu2 := + if h : pu1 = pu2 then + h ▸ decl + else + panic! s!"Purity {pu1} does not match {pu2}, this is a bug" + /-- Return `some i` if `decl` is of the form ``` @@ -669,21 +711,21 @@ That is, `f` is a sequence of declarations followed by a `cases` on the paramete We use this function to decide whether we should inline a declaration tagged with `[inline_if_reduce]` or not. -/ -def Decl.isCasesOnParam? (decl : Decl) : Option Nat := +def Decl.isCasesOnParam? (decl : Decl pu) : Option Nat := match decl.value with | .code c => go c | .extern .. => none where - go (code : Code) : Option Nat := + go {pu : Purity} (code : Code pu) : Option Nat := match code with - | .let _ k | .jp _ k | .fun _ k => go k + | .let _ k | .jp _ k | .fun _ k _ => go k | .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr | _ => none -def Decl.instantiateTypeLevelParams (decl : Decl) (us : List Level) : Expr := +def Decl.instantiateTypeLevelParams (decl : Decl pu) (us : List Level) : Expr := decl.type.instantiateLevelParamsNoCache decl.levelParams us -def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Param := +def Decl.instantiateParamsLevelParams (decl : Decl pu) (us : List Level) : Array (Param pu) := decl.params.mapMono fun param => param.updateCore (param.type.instantiateLevelParamsNoCache decl.levelParams us) /-- @@ -700,7 +742,7 @@ def hasLocalInst (type : Expr) : CoreM Bool := do /-- Return `true` if `decl` is supposed to be inlined/specialized. -/ -def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do +def Decl.isTemplateLike (decl : Decl pu) : CoreM Bool := do let env ← getEnv if ← hasLocalInst decl.type then return true -- `decl` applications will be specialized @@ -721,40 +763,40 @@ private partial def collectType (e : Expr) : FVarIdHashSet → FVarIdHashSet := | .proj .. | .letE .. => unreachable! | _ => id -private def collectArg (arg : Arg) (s : FVarIdHashSet) : FVarIdHashSet := +private def collectArg (arg : Arg pu) (s : FVarIdHashSet) : FVarIdHashSet := match arg with | .erased => s | .fvar fvarId => s.insert fvarId - | .type e => collectType e s + | .type e _ => collectType e s -private def collectArgs (args : Array Arg) (s : FVarIdHashSet) : FVarIdHashSet := +private def collectArgs (args : Array (Arg pu)) (s : FVarIdHashSet) : FVarIdHashSet := args.foldl (init := s) fun s arg => collectArg arg s -private def collectLetValue (e : LetValue) (s : FVarIdHashSet) : FVarIdHashSet := +private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSet := match e with | .fvar fvarId args => collectArgs args <| s.insert fvarId - | .const _ _ args => collectArgs args s - | .proj _ _ fvarId => s.insert fvarId + | .const _ _ args _ => collectArgs args s + | .proj _ _ fvarId _ => s.insert fvarId | .lit .. | .erased => s -private partial def collectParams (ps : Array Param) (s : FVarIdHashSet) : FVarIdHashSet := +private partial def collectParams (ps : Array (Param pu)) (s : FVarIdHashSet) : FVarIdHashSet := ps.foldl (init := s) fun s p => collectType p.type s mutual -partial def FunDecl.collectUsed (decl : FunDecl) (s : FVarIdHashSet := {}) : FVarIdHashSet := +partial def FunDecl.collectUsed (decl : FunDecl pu) (s : FVarIdHashSet := {}) : FVarIdHashSet := decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s -partial def Code.collectUsed (code : Code) (s : FVarIdHashSet := {}) : FVarIdHashSet := +partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarIdHashSet := match code with | .let decl k => k.collectUsed <| collectLetValue decl.value <| collectType decl.type s - | .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s + | .jp decl k | .fun decl k _ => k.collectUsed <| decl.collectUsed s | .cases c => let s := s.insert c.discr let s := collectType c.resultType s c.alts.foldl (init := s) fun s alt => match alt with | .default k => k.collectUsed s - | .alt _ ps k => k.collectUsed <| collectParams ps s + | .alt _ ps k _ => k.collectUsed <| collectParams ps s | .return fvarId => s.insert fvarId | .unreach type => collectType type s | .jmp fvarId args => collectArgs args <| s.insert fvarId @@ -771,7 +813,7 @@ This is an overapproximation, and relies on the fact that our frontend computes strongly connected components. See comment at `recursive` field. -/ -partial def markRecDecls (decls : Array Decl) : Array Decl := +partial def markRecDecls (decls : Array (Decl pu)) : Array (Decl pu) := let (_, isRec) := go |>.run {} decls.map fun decl => if isRec.contains decl.name then @@ -779,13 +821,13 @@ partial def markRecDecls (decls : Array Decl) : Array Decl := else decl where - visit (code : Code) : StateM NameSet Unit := do + visit {pu : Purity} (code : Code pu) : StateM NameSet Unit := do match code with - | .jp decl k | .fun decl k => visit decl.value; visit k + | .jp decl k | .fun decl k _ => visit decl.value; visit k | .cases c => c.alts.forM fun alt => visit alt.getCode | .unreach .. | .jmp .. | .return .. => return () | .let decl k => - if let .const declName _ _ := decl.value then + if let .const declName _ _ _ := decl.value then if decls.any (·.name == declName) then modify fun s => s.insert declName visit k @@ -793,13 +835,13 @@ where go : StateM NameSet Unit := decls.forM (·.value.forCodeM visit) -def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr := +def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr := if !e.hasLooseBVars then e else e.instantiateRange beginIdx endIdx (args.map (·.toExpr)) -def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr := +def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr := if !e.hasLooseBVars then e else diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index 4685f60c66..8b69f8aa08 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -14,7 +14,7 @@ namespace Lean.Compiler.LCNF /-- Helper class for lifting `CompilerM.codeBind` -/ class MonadCodeBind (m : Type → Type) where - codeBind : (c : Code) → (f : FVarId → m Code) → m Code + codeBind : {pu : Purity} → (c : Code pu) → (f : FVarId → m (Code pu)) → m (Code pu) /-- Return code that is equivalent to `c >>= f`. That is, executes `c`, and then `f x`, where @@ -25,16 +25,17 @@ an invalid block would be generated. It would be invalid because `f` would not be applied to `jp_i`. Note that, we could have decided to create a copy of `jp_i` where we apply `f` to it, by we decided to not do it to avoid code duplication. -/ -abbrev Code.bind [MonadCodeBind m] (c : Code) (f : FVarId → m Code) : m Code := +abbrev Code.bind [MonadCodeBind m] (c : Code pu) (f : FVarId → m (Code pu)) : m (Code pu) := MonadCodeBind.codeBind c f -partial def CompilerM.codeBind (c : Code) (f : FVarId → CompilerM Code) : CompilerM Code := do +partial def CompilerM.codeBind (c : Code pu) (f : FVarId → CompilerM (Code pu)) : + CompilerM (Code pu) := do go c |>.run {} where - go (c : Code) : ReaderT FVarIdSet CompilerM Code := do + go (c : Code pu) : ReaderT FVarIdSet CompilerM (Code pu) := do match c with | .let decl k => return .let decl (← go k) - | .fun decl k => return .fun decl (← go k) + | .fun decl k _ => return .fun decl (← go k) | .jp decl k => let value ← go decl.value let type ← value.inferParamType decl.params @@ -43,7 +44,7 @@ where return .jp decl (← go k) | .cases c => let alts ← c.alts.mapM fun - | .alt ctorName params k => return .alt ctorName params (← go k) + | .alt ctorName params k _ => return .alt ctorName params (← go k) | .default k => return .default (← go k) if alts.isEmpty then throwError "`Code.bind` failed, empty `cases` found" @@ -60,7 +61,7 @@ where This code is not very efficient, we could ask caller to provide the type of `c >>= f`, but this is more convenient, and this case is seldom reached. -/ - let auxParam ← mkAuxParam type + let auxParam ← mkAuxParam (pu := pu) type let k ← f auxParam.fvarId let typeNew ← k.inferType eraseCode k @@ -81,10 +82,10 @@ Create new parameters for the given arrow type. Example: if `type` is `Nat → Bool → Int`, the result is an array containing two new parameters with types `Nat` and `Bool`. -/ -partial def mkNewParams (type : Expr) : CompilerM (Array Param) := +partial def mkNewParams (type : Expr) : CompilerM (Array (Param pu)) := go type #[] #[] where - go (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do + go (type : Expr) (xs : Array Expr) (ps : Array (Param pu)) : CompilerM (Array (Param pu)) := do match type with | .forallE _ d b _ => let d := d.instantiateRev xs @@ -98,15 +99,16 @@ where else return ps -def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool := +def isEtaExpandCandidateCore (type : Expr) (params : Array (Param .pure)) : Bool := let typeArity := getArrowArity type let valueArity := params.size typeArity > valueArity -abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool := +abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl .pure) : Bool := isEtaExpandCandidateCore decl.type decl.params -def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do +def etaExpandCore (type : Expr) (params : Array (Param .pure)) (value : Code .pure) : + CompilerM (Array (Param .pure) × Code .pure) := do let valueType ← instantiateForall type (params.map (mkFVar ·.fvarId)) let psNew ← mkNewParams valueType let params := params ++ psNew @@ -116,17 +118,17 @@ def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : Compiler return .let auxDecl (.return auxDecl.fvarId) return (params, value) -def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : CompilerM (Option (Array Param × Code)) := do +def etaExpandCore? (type : Expr) (params : Array (Param .pure)) (value : Code .pure) : CompilerM (Option (Array (Param .pure) × Code .pure)) := do if isEtaExpandCandidateCore type params then etaExpandCore type params value else return none -def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do +def FunDecl.etaExpand (decl : FunDecl .pure) : CompilerM (FunDecl .pure) := do let some (params, value) ← etaExpandCore? decl.type decl.params decl.value | return decl decl.update decl.type params value -def Decl.etaExpand (decl : Decl) : CompilerM Decl := do +def Decl.etaExpand (decl : Decl .pure) : CompilerM (Decl .pure) := do match decl.value with | .code code => let some (params, newCode) ← etaExpandCore? decl.type decl.params code | return decl diff --git a/src/Lean/Compiler/LCNF/CSE.lean b/src/Lean/Compiler/LCNF/CSE.lean index 69f3b07302..c925c714df 100644 --- a/src/Lean/Compiler/LCNF/CSE.lean +++ b/src/Lean/Compiler/LCNF/CSE.lean @@ -20,17 +20,17 @@ namespace CSE structure State where map : PHashMap Expr FVarId := {} - subst : FVarSubst := {} + subst : FVarSubst .pure := {} abbrev M := StateRefT State CompilerM -instance : MonadFVarSubst M false where +instance : MonadFVarSubst M .pure false where getSubst := return (← get).subst -instance : MonadFVarSubstState M where +instance : MonadFVarSubstState M .pure where modifySubst f := modify fun s => { s with subst := f s.subst } -@[inline] def getSubst : M FVarSubst := +@[inline] def getSubst : M (FVarSubst .pure) := return (← get).subst @[inline] def addEntry (value : Expr) (fvarId : FVarId) : M Unit := @@ -40,31 +40,32 @@ instance : MonadFVarSubstState M where let map := (← get).map try x finally modify fun s => { s with map } -def replaceLet (decl : LetDecl) (fvarId : FVarId) : M Unit := do +def replaceLet (decl : LetDecl .pure) (fvarId : FVarId) : M Unit := do eraseLetDecl decl addFVarSubst decl.fvarId fvarId -def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do +def replaceFun (decl : FunDecl .pure) (fvarId : FVarId) : M Unit := do eraseFunDecl decl addFVarSubst decl.fvarId fvarId -def hasNeverExtract (v : LetValue) : CompilerM Bool := +def hasNeverExtract (v : LetValue .pure) : CompilerM Bool := match v with | .const declName .. => return hasNeverExtractAttribute (← getEnv) declName | .lit _ | .erased | .proj .. | .fvar .. => return false -partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code := +partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code .pure) : + CompilerM (Code .pure) := go code |>.run' {} where - goFunDecl (decl : FunDecl) : M FunDecl := do + goFunDecl (decl : FunDecl .pure) : M (FunDecl .pure) := do let type ← normExpr decl.type let params ← normParams decl.params let value ← withNewScope do go decl.value decl.update type params value - go (code : Code) : M Code := do + go (code : Code .pure) : M (Code .pure) := do match code with | .let decl k => let decl ← normLetDecl decl @@ -118,12 +119,13 @@ end CSE /-- Common sub-expression elimination -/ -def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl) : CompilerM Decl := do +def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl .pure) : CompilerM (Decl .pure) := do let value ← decl.value.mapCodeM (·.cse shouldElimFunDecls) return { decl with value } def cse (phase : Phase := .base) (shouldElimFunDecls := false) (occurrence := 0) : Pass := - .mkPerDeclaration `cse (Decl.cse shouldElimFunDecls) phase occurrence + phase.withPurityCheck .pure fun h => + .mkPerDeclaration `cse phase (h ▸ Decl.cse shouldElimFunDecls) occurrence builtin_initialize registerTraceClass `Compiler.cse (inherited := true) diff --git a/src/Lean/Compiler/LCNF/Check.lean b/src/Lean/Compiler/LCNF/Check.lean index cc17d990a5..0352411c27 100644 --- a/src/Lean/Compiler/LCNF/Check.lean +++ b/src/Lean/Compiler/LCNF/Check.lean @@ -79,7 +79,8 @@ the subtype relation in sanity checks and add the necessary casts. -/ namespace Check -open InferType +namespace Pure +open InferType InferType.Pure /- Type and structural properties checker for LCNF expressions. @@ -110,7 +111,7 @@ def isCtorParam (f : Expr) (i : Nat) : CoreM Bool := do let .ctorInfo info ← getConstInfo declName | return false return i < info.numParams -def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do +def checkAppArgs (f : Expr) (args : Array (Arg .pure)) : CheckM Unit := do let mut fType ← inferType f let mut j := 0 for h : i in *...args.size do @@ -129,11 +130,11 @@ def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do let expectedType := instantiateRevRangeArgs d j i args if (← checkTypes) then let argType ← arg.inferType - unless (← InferType.compatibleTypes argType expectedType) do + unless (← compatibleTypes argType expectedType) do throwError "type mismatch at LCNF application{indentExpr (mkAppN f (args.map Arg.toExpr))}\nargument {arg.toExpr} has type{indentExpr argType}\nbut is expected to have type{indentExpr expectedType}" fType := b -def checkLetValue (e : LetValue) : CheckM Unit := do +def checkLetValue (e : LetValue .pure) : CheckM Unit := do match e with | .lit .. | .erased => pure () | .const declName us args => checkAppArgs (mkConst declName us) args @@ -154,18 +155,18 @@ def checkJpInScope (jp : FVarId) : CheckM Unit := do -/ throwError "invalid jump to out of scope join point `{mkFVar jp}`" -def checkParam (param : Param) : CheckM Unit := do +def checkParam (param : Param .pure) : CheckM Unit := do unless param == (← getParam param.fvarId) do throwError "LCNF parameter mismatch at `{param.binderName}`, does not value in local context" -def checkParams (params : Array Param) : CheckM Unit := +def checkParams (params : Array (Param .pure)) : CheckM Unit := params.forM checkParam -def checkLetDecl (letDecl : LetDecl) : CheckM Unit := do +def checkLetDecl (letDecl : LetDecl .pure) : CheckM Unit := do checkLetValue letDecl.value if (← checkTypes) then let valueType ← letDecl.value.inferType - unless (← InferType.compatibleTypes letDecl.type valueType) do + unless (← compatibleTypes letDecl.type valueType) do throwError "type mismatch at `{letDecl.binderName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr letDecl.type}" unless letDecl == (← getLetDecl letDecl.fvarId) do throwError "LCNF let declaration mismatch at `{letDecl.binderName}`, does not match value in local context" @@ -183,7 +184,7 @@ def addFVarId (fvarId : FVarId) : CheckM Unit := do addFVarId fvarId withReader (fun ctx => { ctx with jps := ctx.jps.insert fvarId }) x -@[inline] def withParams (params : Array Param) (x : CheckM α) : CheckM α := do +@[inline] def withParams (params : Array (Param .pure)) (x : CheckM α) : CheckM α := do params.forM (addFVarId ·.fvarId) withReader (fun ctx => { ctx with vars := params.foldl (init := ctx.vars) fun vars p => vars.insert p.fvarId }) x @@ -192,18 +193,18 @@ mutual set_option linter.all false -partial def checkFunDeclCore (declName : Name) (params : Array Param) (type : Expr) (value : Code) : CheckM Unit := do +partial def checkFunDeclCore (declName : Name) (params : Array (Param .pure)) (type : Expr) (value : Code .pure) : CheckM Unit := do checkParams params withParams params do discard <| check value if (← checkTypes) then let valueType ← mkForallParams params (← value.inferType) - unless (← InferType.compatibleTypes type valueType) do + unless (← compatibleTypes type valueType) do throwError "type mismatch at `{.ofConstName declName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr type}" -partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do +partial def checkFunDecl (funDecl : FunDecl .pure) : CheckM Unit := do checkFunDeclCore funDecl.binderName funDecl.params funDecl.type funDecl.value - let decl ← getFunDecl funDecl.fvarId + let decl ← getFunDecl (pu := .pure) funDecl.fvarId unless decl.binderName == funDecl.binderName do throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, binder name in local context `{decl.binderName}`" unless decl.type == funDecl.type do @@ -211,7 +212,7 @@ partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do unless (← getFunDecl funDecl.fvarId) == funDecl do throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, declaration in local context does match" -partial def checkCases (c : Cases) : CheckM Unit := do +partial def checkCases (c : Cases .pure) : CheckM Unit := do let mut ctorNames : NameSet := {} let mut hasDefault := false checkFVar c.discr @@ -230,7 +231,7 @@ partial def checkCases (c : Cases) : CheckM Unit := do throwError "invalid LCNF `cases`, `{ctorName}` has # {val.numFields} fields, but alternative has # {params.size} alternatives" withParams params do check k -partial def check (code : Code) : CheckM Unit := do +partial def check (code : Code .pure) : CheckM Unit := do match code with | .let decl k => checkLetDecl decl; withFVarId decl.fvarId do check k | .fun decl k => @@ -241,7 +242,7 @@ partial def check (code : Code) : CheckM Unit := do | .cases c => checkCases c | .jmp fvarId args => checkJpInScope fvarId - let decl ← getFunDecl fvarId + let decl ← getFunDecl (pu := .pure) fvarId unless decl.getArity == args.size do throwError "invalid LCNF `goto`, join point {decl.binderName} has #{decl.getArity} parameters, but #{args.size} were provided" checkAppArgs (.fvar fvarId) args @@ -253,9 +254,12 @@ end def run (x : CheckM α) : CompilerM α := x |>.run {} |>.run' {} |>.run {} +end Pure end Check -def Decl.check (decl : Decl) : CompilerM Unit := do - Check.run do decl.value.forCodeM (Check.checkFunDeclCore decl.name decl.params decl.type) +def Decl.check (decl : Decl pu) : CompilerM Unit := do + match pu with + | .pure => Check.Pure.run do decl.value.forCodeM (Check.Pure.checkFunDeclCore decl.name decl.params decl.type) + | .impure => panic! "Check for impure unimplemented" -- TODO end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Closure.lean b/src/Lean/Compiler/LCNF/Closure.lean index e13b88d2af..f336355436 100644 --- a/src/Lean/Compiler/LCNF/Closure.lean +++ b/src/Lean/Compiler/LCNF/Closure.lean @@ -45,7 +45,7 @@ structure State where /-- Free variables that must become new parameters of the code being specialized. -/ - params : Array Param := #[] + params : Array (Param .pure) := #[] /-- Let-declarations and local function declarations that are going to be "copied" to the code being processed. For example, when this module is used in the code specializer, the let-declarations @@ -56,7 +56,7 @@ structure State where All customers of this module try to avoid work duplication. If a let-declaration is a ground value, it most likely will be computed during compilation time, and work duplication is not an issue. -/ - decls : Array CodeDecl := #[] + decls : Array (CodeDecl .pure) := #[] /-- Monad for implementing the dependency collector. @@ -75,16 +75,16 @@ mutual Collect dependencies in parameters. We need this because parameters may contain other type parameters. -/ - partial def collectParams (params : Array Param) : ClosureM Unit := + partial def collectParams (params : Array (Param .pure)) : ClosureM Unit := params.forM (collectType ·.type) - partial def collectArg (arg : Arg) : ClosureM Unit := + partial def collectArg (arg : Arg .pure) : ClosureM Unit := match arg with | .erased => return () | .type e => collectType e | .fvar fvarId => collectFVar fvarId - partial def collectLetValue (e : LetValue) : ClosureM Unit := do + partial def collectLetValue (e : LetValue .pure) : ClosureM Unit := do match e with | .erased | .lit .. => return () | .proj _ _ fvarId => collectFVar fvarId @@ -95,7 +95,7 @@ mutual Collect dependencies in the given code. We need this function to be able to collect dependencies in a local function declaration. -/ - partial def collectCode (c : Code) : ClosureM Unit := do + partial def collectCode (c : Code .pure) : ClosureM Unit := do match c with | .let decl k => collectType decl.type @@ -114,7 +114,7 @@ mutual | .return fvarId => collectFVar fvarId /-- Collect dependencies of a local function declaration. -/ - partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do + partial def collectFunDecl (decl : FunDecl .pure) : ClosureM Unit := do collectType decl.type collectParams decl.params collectCode decl.value @@ -155,7 +155,8 @@ mutual end -def run (x : ClosureM α) (inScope : FVarId → Bool) (abstract : FVarId → Bool := fun _ => true) : CompilerM (α × Array Param × Array CodeDecl) := do +def run (x : ClosureM α) (inScope : FVarId → Bool) (abstract : FVarId → Bool := fun _ => true) : + CompilerM (α × Array (Param .pure) × Array (CodeDecl .pure)) := do let (a, s) ← x { inScope, abstract } |>.run {} -- If we've abstracted an fvar into a param, exclude its definition. Note that this still allows -- for other decls the removed decl depends upon to be included, but they will be removed later diff --git a/src/Lean/Compiler/LCNF/CompatibleTypes.lean b/src/Lean/Compiler/LCNF/CompatibleTypes.lean index 542e7c69f2..0ca5a6fdaf 100644 --- a/src/Lean/Compiler/LCNF/CompatibleTypes.lean +++ b/src/Lean/Compiler/LCNF/CompatibleTypes.lean @@ -72,10 +72,13 @@ partial def compatibleTypesQuick (a b : Expr) : Bool := | .const n us, .const m vs => n == m && List.isEqv us vs Level.isEquiv | _, _ => false +namespace InferType +namespace Pure + /-- Complete check for `compatibleTypes`. It eta-expands type formers. See comment at `compatibleTypes`. -/ -partial def InferType.compatibleTypesFull (a b : Expr) : InferTypeM Bool := do +partial def compatibleTypesFull (a b : Expr) : InferTypeM Bool := do if a.isErased || b.isErased then return true else @@ -141,10 +144,13 @@ This is a simplification. We used to use `isErasedCompatible`, but this only add For item 2, we would have to modify the `toLCNFType` function and make sure a type former is erased if the expected type is not always a type former (see `S.mk` type and example in the note above). -/ -def InferType.compatibleTypes (a b : Expr) : InferTypeM Bool := do +def compatibleTypes (a b : Expr) : InferTypeM Bool := do if compatibleTypesQuick a b then return true else compatibleTypesFull a b +end Pure +end InferType + end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 096a20f31b..0b84c570d7 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -21,7 +21,12 @@ inductive Phase where | base /-- In this phase polymorphism has been eliminated. -/ | mono - deriving Inhabited, BEq + | impure + deriving Inhabited, DecidableEq + +@[expose, reducible] def Phase.toPurity : Phase → Purity + | .base | .mono => .pure + | .impure => .impure /-- The state managed by the `CompilerM` `Monad`. @@ -52,48 +57,53 @@ instance : Monad CompilerM := let i := inferInstanceAs (Monad CompilerM); { pure def getPhase : CompilerM Phase := return (← read).phase +def getPurity : CompilerM Purity := + return (← getPhase).toPurity + def inBasePhase : CompilerM Bool := return (← getPhase) matches .base instance : AddMessageContext CompilerM where addMessageContext msgData := do let env ← getEnv - let lctx := (← get).lctx.toLocalContext + let lctx := (← get).lctx.toLocalContext (← getPurity) let opts ← getOptions return MessageData.withContext { env, lctx, opts, mctx := {} } msgData def getType (fvarId : FVarId) : CompilerM Expr := do let lctx := (← get).lctx - if let some decl := lctx.letDecls[fvarId]? then + let pu ← getPurity + if let some decl := (lctx.letDecls pu)[fvarId]? then return decl.type - else if let some decl := lctx.params[fvarId]? then + else if let some decl := (lctx.params pu)[fvarId]? then return decl.type - else if let some decl := lctx.funDecls[fvarId]? then + else if let some decl := (lctx.funDecls pu)[fvarId]? then return decl.type else throwError "unknown free variable {fvarId.name}" def getBinderName (fvarId : FVarId) : CompilerM Name := do let lctx := (← get).lctx - if let some decl := lctx.letDecls[fvarId]? then + let pu ← getPurity + if let some decl := (lctx.letDecls pu)[fvarId]? then return decl.binderName - else if let some decl := lctx.params[fvarId]? then + else if let some decl := (lctx.params pu)[fvarId]? then return decl.binderName - else if let some decl := lctx.funDecls[fvarId]? then + else if let some decl := (lctx.funDecls pu)[fvarId]? then return decl.binderName else throwError "unknown free variable {fvarId.name}" -def findParam? (fvarId : FVarId) : CompilerM (Option Param) := - return (← get).lctx.params[fvarId]? +def findParam? (fvarId : FVarId) : CompilerM (Option (Param pu)) := do + return ((← get).lctx.params pu)[fvarId]? -def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) := - return (← get).lctx.letDecls[fvarId]? +def findLetDecl? (fvarId : FVarId) : CompilerM (Option (LetDecl pu)) := do + return ((← get).lctx.letDecls pu)[fvarId]? -def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) := - return (← get).lctx.funDecls[fvarId]? +def findFunDecl? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do + return ((← get).lctx.funDecls pu)[fvarId]? -def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do +def findLetValue? (fvarId : FVarId) : CompilerM (Option (LetValue pu)) := do let some { value, .. } ← findLetDecl? fvarId | return none return some value @@ -101,56 +111,56 @@ def isConstructorApp (fvarId : FVarId) : CompilerM Bool := do let some (.const declName _ _) ← findLetValue? fvarId | return false return (← getEnv).find? declName matches some (.ctorInfo ..) -def Arg.isConstructorApp (arg : Arg) : CompilerM Bool := do +def Arg.isConstructorApp (arg : Arg pu) : CompilerM Bool := do let .fvar fvarId := arg | return false LCNF.isConstructorApp fvarId -def getParam (fvarId : FVarId) : CompilerM Param := do +def getParam (fvarId : FVarId) : CompilerM (Param pu) := do let some param ← findParam? fvarId | throwError "unknown parameter {fvarId.name}" return param -def getLetDecl (fvarId : FVarId) : CompilerM LetDecl := do +def getLetDecl (fvarId : FVarId) : CompilerM (LetDecl pu) := do let some decl ← findLetDecl? fvarId | throwError "unknown let-declaration {fvarId.name}" return decl -def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do +def getFunDecl (fvarId : FVarId) : CompilerM (FunDecl pu) := do let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}" return decl @[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do modify fun s => { s with lctx := f s.lctx } -def eraseLetDecl (decl : LetDecl) : CompilerM Unit := do +def eraseLetDecl (decl : LetDecl pu) : CompilerM Unit := do modifyLCtx fun lctx => lctx.eraseLetDecl decl -def eraseFunDecl (decl : FunDecl) (recursive := true) : CompilerM Unit := do +def eraseFunDecl (decl : FunDecl pu) (recursive := true) : CompilerM Unit := do modifyLCtx fun lctx => lctx.eraseFunDecl decl recursive -def eraseCode (code : Code) : CompilerM Unit := do +def eraseCode (code : Code pu) : CompilerM Unit := do modifyLCtx fun lctx => lctx.eraseCode code -def eraseParam (param : Param) : CompilerM Unit := +def eraseParam (param : Param pu) : CompilerM Unit := modifyLCtx fun lctx => lctx.eraseParam param -def eraseParams (params : Array Param) : CompilerM Unit := +def eraseParams (params : Array (Param pu)) : CompilerM Unit := modifyLCtx fun lctx => lctx.eraseParams params -def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do +def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do match decl with | .let decl => eraseLetDecl decl - | .jp decl | .fun decl => eraseFunDecl decl + | .jp decl | .fun decl _ => eraseFunDecl decl /-- Erase all free variables occurring in `decls` from the local context. -/ -def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do +def eraseCodeDecls (decls : Array (CodeDecl pu)) : CompilerM Unit := do decls.forM fun decl => eraseCodeDecl decl -def eraseDecl (decl : Decl) : CompilerM Unit := do +def eraseDecl (decl : Decl pu) : CompilerM Unit := do eraseParams decl.params decl.value.forCodeM eraseCode -abbrev Decl.erase (decl : Decl) : CompilerM Unit := +abbrev Decl.erase (decl : Decl pu) : CompilerM Unit := eraseDecl decl /-- @@ -166,7 +176,7 @@ it is a free variable, a type (or type former), or `lcErased`. `Check.lean` contains a substitution validator. -/ -abbrev FVarSubst := Std.HashMap FVarId Arg +abbrev FVarSubst (pu : Purity) := Std.HashMap FVarId (Arg pu) /-- Replace the free variables in `e` using the given substitution. @@ -179,7 +189,7 @@ If `translator = false`, we assume the substitution contains free variable repla and given entries such as `x₁ ↦ x₂`, `x₂ ↦ x₃`, ..., `xₙ₋₁ ↦ xₙ`, and the expression `f x₁ x₂`, we want the resulting expression to be `f xₙ xₙ`. We use this setting, for example, in the simplifier. -/ -private partial def normExprImp (s : FVarSubst) (e : Expr) (translator : Bool) : Expr := +private partial def normExprImp (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr := go e where goApp (e : Expr) : Expr := @@ -192,7 +202,7 @@ where match e with | .fvar fvarId => match s[fvarId]? with | some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId') - | some (.type e) => if translator then e else go e + | some (.type e _) => if translator then e else go e | some .erased => erasedExpr | none => e | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e @@ -225,7 +235,7 @@ This function panics if the substitution is mapping `fvarId` to an expression th That is, it is not a type (or type former), nor `lcErased`. Recall that a valid `FVarSubst` contains only expressions that are free variables, `lcErased`, or type formers. -/ -partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult := +partial def normFVarImp (s : FVarSubst pu) (fvarId : FVarId) (translator : Bool) : NormFVarResult := match s[fvarId]? with | some (.fvar fvarId') => if translator then @@ -234,7 +244,7 @@ partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : normFVarImp s fvarId' translator -- Types and type formers are only preserved as hints and -- are erased in computationally relevant contexts. - | some .erased | some (.type _) => .erased + | some .erased | some (.type _ _) => .erased | none => .fvar fvarId /-- @@ -242,18 +252,18 @@ Replace the free variables in `arg` using the given substitution. See `normExprImp` -/ -private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg := +private partial def normArgImp (s : FVarSubst pu) (arg : Arg pu) (translator : Bool) : Arg pu := match arg with | .erased => arg | .fvar fvarId => match s[fvarId]? with | some (arg'@(.fvar _)) => if translator then arg' else normArgImp s arg' translator - | some (arg'@.erased) | some (arg'@(.type _)) => arg' + | some (arg'@.erased) | some (arg'@(.type _ _)) => arg' | none => arg - | .type e => arg.updateType! (normExprImp s e translator) + | .type e _ => arg.updateType! (normExprImp s e translator) -private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg := +private def normArgsImp (s : FVarSubst pu) (args : Array (Arg pu)) (translator : Bool) : Array (Arg pu) := args.mapMono (normArgImp s · translator) /-- @@ -261,13 +271,13 @@ Replace the free variables in `e` using the given substitution. See `normExprImp` -/ -private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator : Bool) : LetValue := +private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (translator : Bool) : LetValue pu := match e with | .erased | .lit .. => e - | .proj _ _ fvarId => match normFVarImp s fvarId translator with + | .proj _ _ fvarId _ => match normFVarImp s fvarId translator with | .fvar fvarId' => e.updateProj! fvarId' | .erased => .erased - | .const _ _ args => e.updateArgs! (normArgsImp s args translator) + | .const _ _ args _ => e.updateArgs! (normArgsImp s args translator) | .fvar fvarId args => match normFVarImp s fvarId translator with | .fvar fvarId' => e.updateFVar! fvarId' (normArgsImp s args translator) | .erased => .erased @@ -275,20 +285,20 @@ private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator : /-- Interface for monads that have a free substitutions. -/ -class MonadFVarSubst (m : Type → Type) (translator : outParam Bool) where - getSubst : m FVarSubst +class MonadFVarSubst (m : Type → Type) (pu : outParam Purity) (translator : outParam Bool) where + getSubst : m (FVarSubst pu) export MonadFVarSubst (getSubst) -instance (m n) [MonadLift m n] [MonadFVarSubst m t] : MonadFVarSubst n t where +instance (m n) [MonadLift m n] [MonadFVarSubst m pu t] : MonadFVarSubst n pu t where getSubst := liftM (getSubst : m _) -class MonadFVarSubstState (m : Type → Type) where - modifySubst : (FVarSubst → FVarSubst) → m Unit +class MonadFVarSubstState (m : Type → Type) (pu : outParam Purity) where + modifySubst : (FVarSubst pu → FVarSubst pu) → m Unit export MonadFVarSubstState (modifySubst) -instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where +instance (m n) [MonadLift m n] [MonadFVarSubstState m pu] : MonadFVarSubstState n pu where modifySubst f := liftM (modifySubst f : m _) /-- @@ -296,35 +306,35 @@ Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`. See `Check.lean` for the free variable substitution checker. -/ -@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit := +@[inline] def addSubst [MonadFVarSubstState m pu] (fvarId : FVarId) (arg : Arg pu) : m Unit := modifySubst fun s => s.insert fvarId arg /-- Add the entry `fvarId ↦ fvarId'` to the free variable substitution. -/ -@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit := +@[inline] def addFVarSubst [MonadFVarSubstState m ph] (fvarId : FVarId) (fvarId' : FVarId) : m Unit := modifySubst fun s => s.insert fvarId (.fvar fvarId') -@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult := +@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m pu t] [Monad m] (fvarId : FVarId) : m NormFVarResult := return normFVarImp (← getSubst) fvarId t -@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr := +@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m pu t] [Monad m] (e : Expr) : m Expr := return normExprImp (← getSubst) e t -@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg := +@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m pu t] [Monad m] (arg : Arg pu) : m (Arg pu) := return normArgImp (← getSubst) arg t -@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m t] [Monad m] (e : LetValue) : m LetValue := +@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m pu t] [Monad m] (e : LetValue pu) : m (LetValue pu) := return normLetValueImp (← getSubst) e t @[inherit_doc normExprImp, inline] -def normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr := +def normExprCore (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr := normExprImp s e translator /-- Normalize the given arguments using the current substitution. -/ -def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) := +def normArgs [MonadFVarSubst m pu t] [Monad m] (args : Array (Arg pu)) : m (Array (Arg pu)) := return normArgsImp (← getSubst) args t def mkFreshBinderName (binderName := `_x): CompilerM Name := do @@ -342,35 +352,35 @@ def ensureNotAnonymous (binderName : Name) (baseName : Name) : CompilerM Name := Helper functions for creating LCNF local declarations. -/ -def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param := do +def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM (Param pu) := do let fvarId ← mkFreshFVarId let binderName ← ensureNotAnonymous binderName `_y let param := { fvarId, binderName, type, borrow } modifyLCtx fun lctx => lctx.addParam param return param -def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue) : CompilerM LetDecl := do +def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do let fvarId ← mkFreshFVarId let binderName ← ensureNotAnonymous binderName `_x let decl := { fvarId, binderName, type, value } modifyLCtx fun lctx => lctx.addLetDecl decl return decl -def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do +def mkFunDecl (binderName : Name) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do let fvarId ← mkFreshFVarId let binderName ← ensureNotAnonymous binderName `_f let funDecl := ⟨fvarId, binderName, params, type, value⟩ modifyLCtx fun lctx => lctx.addFunDecl funDecl return funDecl -def mkLetDeclErased : CompilerM LetDecl := do +def mkLetDeclErased : CompilerM (LetDecl pu) := do mkLetDecl (← mkFreshBinderName `_x) erasedExpr .erased -def mkReturnErased : CompilerM Code := do +def mkReturnErased : CompilerM (Code pu) := do let auxDecl ← mkLetDeclErased return .let auxDecl (.return auxDecl.fvarId) -private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do +private unsafe def updateParamImp (p : Param pu) (type : Expr) : CompilerM (Param pu) := do if ptrEq type p.type then return p else @@ -378,9 +388,9 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := modifyLCtx fun lctx => lctx.addParam p return p -@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param +@[implemented_by updateParamImp] opaque Param.update (p : Param pu) (type : Expr) : CompilerM (Param pu) -private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl := do +private unsafe def updateLetDeclImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do if ptrEq type decl.type && ptrEq value decl.value then return decl else @@ -388,12 +398,12 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetV modifyLCtx fun lctx => lctx.addLetDecl decl return decl -@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl +@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) -def LetDecl.updateValue (decl : LetDecl) (value : LetValue) : CompilerM LetDecl := +def LetDecl.updateValue (decl : LetDecl pu) (value : LetValue pu) : CompilerM (LetDecl pu) := decl.update decl.type value -private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do +private unsafe def updateFunDeclImp (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then return decl else @@ -401,48 +411,48 @@ private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Arr modifyLCtx fun lctx => lctx.addFunDecl decl return decl -@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl +@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) -abbrev FunDecl.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl := +abbrev FunDecl.update' (decl : FunDecl pu) (type : Expr) (value : Code pu) : CompilerM (FunDecl pu) := decl.update type decl.params value -abbrev FunDecl.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl := +abbrev FunDecl.updateValue (decl : FunDecl pu) (value : Code pu) : CompilerM (FunDecl pu) := decl.update decl.type decl.params value -@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do +@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (p : Param pu) : m (Param pu) := do p.update (← normExpr p.type) -def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Array Param) : m (Array Param) := +def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (ps : Array (Param pu)) : m (Array (Param pu)) := ps.mapMonoM normParam -def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do +def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : LetDecl pu) : m (LetDecl pu) := do decl.update (← normExpr decl.type) (← normLetValue decl.value) -abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM +abbrev NormalizerM (pu : Purity) (_translator : Bool) := ReaderT (FVarSubst pu) CompilerM -instance : MonadFVarSubst (NormalizerM t) t where +instance : MonadFVarSubst (NormalizerM pu t) pu t where getSubst := read /-- If `result` is `.fvar fvarId`, then return `x fvarId`. Otherwise, it is `.erased`, and method returns `let _x.i := .erased; return _x.i`. -/ -@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m Code) : m Code := do +@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m (Code pu)) : m (Code pu) := do match result with | .fvar fvarId => x fvarId | .erased => mkReturnErased mutual - partial def normFunDeclImp (decl : FunDecl) : NormalizerM t FunDecl := do + partial def normFunDeclImp (decl : FunDecl pu) : NormalizerM pu t (FunDecl pu) := do let type ← normExpr decl.type let params ← normParams decl.params let value ← normCodeImp decl.value decl.update type params value - partial def normCodeImp (code : Code) : NormalizerM t Code := do + partial def normCodeImp (code : Code pu) : NormalizerM pu t (Code pu) := do match code with | .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k) - | .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k) + | .fun decl k _ | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k) | .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateReturn! fvarId | .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateJmp! fvarId (← normArgs args) | .unreach type => return code.updateUnreach! (← normExpr type) @@ -451,28 +461,28 @@ mutual withNormFVarResult (← normFVar c.discr) fun discr => do let alts ← c.alts.mapMonoM fun alt => match alt with - | .alt _ params k => return alt.updateAlt! (← normParams params) (← normCodeImp k) + | .alt _ params k _ => return alt.updateAlt! (← normParams params) (← normCodeImp k) | .default k => return alt.updateCode (← normCodeImp k) return code.updateCases! resultType discr alts end -@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : FunDecl) : m FunDecl := do +@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do normFunDeclImp (t := t) decl (← getSubst) /-- Similar to `internalize`, but does not refresh `FVarId`s. -/ -@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (code : Code) : m Code := do +@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (code : Code pu) : m (Code pu) := do normCodeImp (t := t) code (← getSubst) -def replaceExprFVars (e : Expr) (s : FVarSubst) (translator : Bool) : CompilerM Expr := - (normExpr e : NormalizerM translator Expr).run s +def replaceExprFVars (e : Expr) (s : FVarSubst pu) (translator : Bool) : CompilerM Expr := + (normExpr e : NormalizerM pu translator Expr).run s -def replaceFVars (code : Code) (s : FVarSubst) (translator : Bool) : CompilerM Code := - (normCode code : NormalizerM translator Code).run s +def replaceFVars (code : Code pu) (s : FVarSubst pu) (translator : Bool) : CompilerM (Code pu) := + (normCode code : NormalizerM pu translator (Code pu)).run s def mkFreshJpName : CompilerM Name := do mkFreshBinderName `_jp -def mkAuxParam (type : Expr) (borrow := false) : CompilerM Param := do +def mkAuxParam (type : Expr) (borrow := false) : CompilerM (Param pu) := do mkParam (← mkFreshBinderName `_y) type borrow def getConfig : CompilerM ConfigOptions := diff --git a/src/Lean/Compiler/LCNF/DeclHash.lean b/src/Lean/Compiler/LCNF/DeclHash.lean index 1eee7d51c0..97a5c87abd 100644 --- a/src/Lean/Compiler/LCNF/DeclHash.lean +++ b/src/Lean/Compiler/LCNF/DeclHash.lean @@ -12,25 +12,25 @@ public section namespace Lean.Compiler.LCNF -instance : Hashable Param where +instance : Hashable (Param pu) where hash p := mixHash (hash p.fvarId) (hash p.type) -def hashParams (ps : Array Param) : UInt64 := +def hashParams (ps : Array (Param pu)) : UInt64 := hash ps mutual -partial def hashAlt (alt : Alt) : UInt64 := +partial def hashAlt (alt : Alt pu) : UInt64 := match alt with - | .alt ctorName ps k => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k) + | .alt ctorName ps k _ => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k) | .default k => hashCode k -partial def hashAlts (alts : Array Alt) : UInt64 := +partial def hashAlts (alts : Array (Alt pu)) : UInt64 := alts.foldl (fun r a => mixHash r (hashAlt a)) 7 -partial def hashCode (code : Code) : UInt64 := +partial def hashCode (code : Code pu) : UInt64 := match code with | .let decl k => mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hash decl.value) (hashCode k)) - | .fun decl k | .jp decl k => + | .fun decl k _ | .jp decl k => mixHash (mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hashCode decl.value) (hashCode k))) (hash decl.params) | .return fvarId => hash fvarId | .unreach type => hash type @@ -39,7 +39,7 @@ partial def hashCode (code : Code) : UInt64 := end -instance : Hashable Code where +instance : Hashable (Code pu) where hash c := hashCode c deriving instance Hashable for DeclValue diff --git a/src/Lean/Compiler/LCNF/DependsOn.lean b/src/Lean/Compiler/LCNF/DependsOn.lean index 6a0ca4ca82..d618f03102 100644 --- a/src/Lean/Compiler/LCNF/DependsOn.lean +++ b/src/Lean/Compiler/LCNF/DependsOn.lean @@ -21,46 +21,46 @@ private def typeDepOn (e : Expr) : M Bool := do let s ← read return e.hasAnyFVar fun fvarId => s.contains fvarId -private def argDepOn (a : Arg) : M Bool := do +private def argDepOn (a : Arg pu) : M Bool := do match a with | .erased => return false | .fvar fvarId => fvarDepOn fvarId - | .type e => typeDepOn e + | .type e _ => typeDepOn e -private def letValueDepOn (e : LetValue) : M Bool := +private def letValueDepOn (e : LetValue pu) : M Bool := match e with | .erased | .lit .. => return false - | .proj _ _ fvarId => fvarDepOn fvarId + | .proj _ _ fvarId _ => fvarDepOn fvarId | .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn - | .const _ _ args => args.anyM argDepOn + | .const _ _ args _ => args.anyM argDepOn -private def LetDecl.depOn (decl : LetDecl) : M Bool := +private def LetDecl.depOn (decl : LetDecl pu) : M Bool := typeDepOn decl.type <||> letValueDepOn decl.value -private partial def depOn (c : Code) : M Bool := +private partial def depOn (c : Code pu) : M Bool := match c with | .let decl k => decl.depOn <||> depOn k - | .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k + | .jp decl k | .fun decl k _ => typeDepOn decl.type <||> depOn decl.value <||> depOn k | .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode | .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn | .return fvarId => fvarDepOn fvarId | .unreach _ => return false -@[inline] def LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool := +@[inline] def LetDecl.dependsOn (decl : LetDecl pu) (s : FVarIdSet) : Bool := decl.depOn s -@[inline] def FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool := +@[inline] def FunDecl.dependsOn (decl : FunDecl pu) (s : FVarIdSet) : Bool := typeDepOn decl.type s || depOn decl.value s -def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool := +def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool := match decl with | .let decl => decl.dependsOn s - | .jp decl | .fun decl => decl.dependsOn s + | .jp decl | .fun decl _ => decl.dependsOn s /-- Return `true` is `c` depends on a free variable in `s`. -/ -def Code.dependsOn (c : Code) (s : FVarIdSet) : Bool := +def Code.dependsOn (c : Code pu) (s : FVarIdSet) : Bool := depOn c s end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ElimDead.lean b/src/Lean/Compiler/LCNF/ElimDead.lean index 95533e3314..393293de4c 100644 --- a/src/Lean/Compiler/LCNF/ElimDead.lean +++ b/src/Lean/Compiler/LCNF/ElimDead.lean @@ -19,16 +19,16 @@ Collect set of (let) free variables in a LCNF value. This code exploits the LCNF property that local declarations do not occur in types. -/ -def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls := +def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg .pure) : UsedLocalDecls := match arg with | .fvar fvarId => s.insert fvarId -- Locally declared variables do not occur in types. | .type _ | .erased => s -def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls := +def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg .pure)) : UsedLocalDecls := args.foldl (init := s) collectLocalDeclsArg -def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue) : UsedLocalDecls := +def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue .pure) : UsedLocalDecls := match e with | .erased | .lit .. => s | .proj _ _ fvarId => s.insert fvarId @@ -39,21 +39,22 @@ namespace ElimDead abbrev M := StateRefT UsedLocalDecls CompilerM -private abbrev collectArgM (arg : Arg) : M Unit := +private abbrev collectArgM (arg : Arg .pure) : M Unit := modify (collectLocalDeclsArg · arg) -private abbrev collectLetValueM (e : LetValue) : M Unit := +private abbrev collectLetValueM (e : LetValue .pure) : M Unit := modify (collectLocalDeclsLetValue · e) private abbrev collectFVarM (fvarId : FVarId) : M Unit := modify (·.insert fvarId) mutual -partial def visitFunDecl (funDecl : FunDecl) : M FunDecl := do + +partial def visitFunDecl (funDecl : FunDecl .pure) : M (FunDecl .pure) := do let value ← elimDead funDecl.value funDecl.updateValue value -partial def elimDead (code : Code) : M Code := do +partial def elimDead (code : Code .pure) : M (Code .pure) := do match code with | .let decl k => let k ← elimDead k @@ -84,10 +85,11 @@ end end ElimDead -def Code.elimDead (code : Code) : CompilerM Code := +-- TODO: Generalize this to arbitrary phases, keep in mind that in impure elim dead is not as easy though +def Code.elimDead (code : Code .pure) : CompilerM (Code .pure) := ElimDead.elimDead code |>.run' {} -def Decl.elimDead (decl : Decl) : CompilerM Decl := do +def Decl.elimDead (decl : Decl .pure) : CompilerM (Decl .pure) := do return { decl with value := (← decl.value.mapCodeM Code.elimDead) } end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ElimDeadBranches.lean b/src/Lean/Compiler/LCNF/ElimDeadBranches.lean index 9648fb5b2b..89e720f04e 100644 --- a/src/Lean/Compiler/LCNF/ElimDeadBranches.lean +++ b/src/Lean/Compiler/LCNF/ElimDeadBranches.lean @@ -239,14 +239,14 @@ Attempt to turn a `Value` that is representing a literal into a set of auxiliary declarations + the final `FVarId` of the declaration that contains the actual literal. If it is not a literal return none. -/ -partial def getLiteral (v : Value) : CompilerM (Option ((Array CodeDecl) × FVarId)) := do +partial def getLiteral (v : Value) : CompilerM (Option ((Array (CodeDecl .pure)) × FVarId)) := do if isLiteral v then let literal ← go v return some literal else return none where - go : Value → CompilerM ((Array CodeDecl) × FVarId) + go : Value → CompilerM ((Array (CodeDecl .pure)) × FVarId) | .ctor ``Nat.zero #[] .. => do let decl ← mkAuxLetDecl <| .lit <| .nat <| 0 return (#[.let decl], decl.fvarId) @@ -260,7 +260,7 @@ where let flatten acc := fun (decls, var) => (acc.fst ++ decls, acc.snd.push <| .fvar var) let (decls, args) := fields.foldl (init := (#[], Array.replicate ctorInfo.numParams .erased)) flatten - let letVal : LetValue := .const ctorName [] args + let letVal : LetValue .pure := .const ctorName [] args let letDecl ← mkAuxLetDecl letVal return (decls.push <| .let letDecl, letDecl.fvarId) | _ => unreachable! @@ -328,7 +328,7 @@ structure InterpContext where a single declaration or a mutual block of declarations where their analysis might influence each other as we approach the fixpoint. -/ - decls : Array Decl + decls : Array (Decl .pure) /-- The index of the function we are currently operating on in `decls.` -/ @@ -386,7 +386,7 @@ def findVarValue (var : FVarId) : InterpM Value := do /-- Find the value of `arg` using the logic of `findVarValue`. -/ -def findArgValue (arg : Arg) : InterpM Value := do +def findArgValue (arg : Arg .pure) : InterpM Value := do match arg with | .fvar fvarId => findVarValue fvarId | _ => return .top @@ -421,7 +421,8 @@ Furthermore if we see that `params.size != args.size` we know that this is a partial application and set the values of the remaining parameters to `top` since it is impossible to track what will happen with them from here on. -/ -def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : InterpM Bool := do +def updateFunDeclParamsAssignment (params : Array (Param .pure)) (args : Array (Arg .pure)) : + InterpM Bool := do let mut ret := false let env ← getEnv for param in params, arg in args do @@ -443,7 +444,7 @@ def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : In updateVarAssignment param.fvarId .top return ret -def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do +def updateFunDeclParamsTop (params : Array (Param .pure)) : InterpM Bool := do let mut ret := false for param in params do let paramVal ← findVarValue param.fvarId @@ -453,7 +454,7 @@ def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do ret := true return ret -private partial def resetNestedFunDeclParams : Code → InterpM Unit +private partial def resetNestedFunDeclParams : Code .pure → InterpM Unit | .let _ k => resetNestedFunDeclParams k | .jp decl k | .fun decl k => do decl.params.forM (resetVarAssignment ·.fvarId) @@ -467,7 +468,7 @@ private partial def resetNestedFunDeclParams : Code → InterpM Unit /-- The actual abstract interpreter on a block of `Code`. -/ -partial def interpCode : Code → InterpM Unit +partial def interpCode : Code .pure → InterpM Unit | .let decl k => do let val ← interpLetValue decl.value updateVarAssignment decl.fvarId val @@ -503,7 +504,7 @@ where /-- The abstract interpreter on a `LetValue`. -/ - interpLetValue (letVal : LetValue) : InterpM Value := do + interpLetValue (letVal : LetValue .pure) : InterpM Value := do match letVal with | .lit val => return .ofLCNFLit val | .proj _ idx struct => @@ -513,7 +514,7 @@ where let env ← getEnv args.forM handleFunArg match (← getDecl? declName) with - | some decl => + | some ⟨_, decl⟩ => if decl.getArity == args.size then match getFunctionSummary? env declName with | some v => return v @@ -538,7 +539,7 @@ where return .top | .erased => return .top - handleFunArg (arg : Arg) : InterpM Unit := do + handleFunArg (arg : Arg .pure) : InterpM Unit := do if let .fvar fvarId := arg then handleFunVar fvarId @@ -557,7 +558,7 @@ where resetNestedFunDeclParams funDecl.value interpCode funDecl.value - interpFunCall (funDecl : FunDecl) (args : Array Arg) : InterpM Unit := do + interpFunCall (funDecl : FunDecl .pure) (args : Array (Arg .pure)) : InterpM Unit := do let updated ← updateFunDeclParamsAssignment funDecl.params args if updated then /- We must reset the value of nested function declaration @@ -608,11 +609,11 @@ Use the information produced by the abstract interpreter to: - Eliminate branches that we know cannot be hit - Eliminate values that we know have to be constants. -/ -partial def elimDead (assignment : Assignment) (decl : Decl) : CompilerM Decl := do +partial def elimDead (assignment : Assignment) (decl : Decl .pure) : CompilerM (Decl .pure) := do trace[Compiler.elimDeadBranches] s!"Eliminating {decl.name} with {repr (← assignment.toArray |>.mapM (fun (name, val) => do return (toString (← getBinderName name), val)))}" return { decl with value := (← decl.value.mapCodeM go) } where - go (code : Code) : CompilerM Code := do + go (code : Code .pure) : CompilerM (Code .pure) := do match code with | .let decl k => return code.updateLet! decl (← go k) @@ -624,16 +625,14 @@ where match alt with | .alt ctor args body => if discrVal.containsCtor ctor then - let filter param := do + let constantInfos ← args.filterMapM fun param => do if let some val := assignment[param.fvarId]? then if let some literal ← val.getLiteral then return some (param, literal) return none - let constantInfos ← args.filterMapM filter if constantInfos.size != 0 then - let folder := fun (body, subst) (param, decls, var) => do + let (body, subst) ← constantInfos.foldlM (init := (← go body, {})) fun (body, subst) (param, decls, var) => do return (attachCodeDecls decls body, subst.insert param.fvarId (.fvar var)) - let (body, subst) ← constantInfos.foldlM (init := (← go body, {})) folder let body ← replaceFVars body subst false return alt.updateCode body else @@ -649,7 +648,7 @@ where end UnreachableBranches open UnreachableBranches in -def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do +def Decl.elimDeadBranches (decls : Array (Decl .pure)) : CompilerM (Array (Decl .pure)) := do /- We sort declarations by size here to ensure that when we restart in inferStep it will mostly be small declarations that get re-analyzed. diff --git a/src/Lean/Compiler/LCNF/ExtractClosed.lean b/src/Lean/Compiler/LCNF/ExtractClosed.lean index e15d64fa92..3ccd4a26ae 100644 --- a/src/Lean/Compiler/LCNF/ExtractClosed.lean +++ b/src/Lean/Compiler/LCNF/ExtractClosed.lean @@ -16,11 +16,11 @@ public section namespace Lean.Compiler.LCNF namespace ExtractClosed -abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM +abbrev ExtractM := StateRefT (Array (CodeDecl .pure)) CompilerM mutual -partial def extractLetValue (v : LetValue) : ExtractM Unit := do +partial def extractLetValue (v : LetValue .pure) : ExtractM Unit := do match v with | .const _ _ args => args.forM extractArg | .fvar fnVar args => @@ -29,7 +29,7 @@ partial def extractLetValue (v : LetValue) : ExtractM Unit := do | .proj _ _ baseVar => extractFVar baseVar | .lit _ | .erased => return () -partial def extractArg (arg : Arg) : ExtractM Unit := do +partial def extractArg (arg : Arg .pure) : ExtractM Unit := do match arg with | .fvar fvarId => extractFVar fvarId | .type _ | .erased => return () @@ -41,17 +41,17 @@ partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do end -def isIrrelevantArg (arg : Arg) : Bool := +def isIrrelevantArg (arg : Arg .pure) : Bool := match arg with | .erased | .type _ => true | .fvar _ => false structure Context where baseName : Name - sccDecls : Array Decl + sccDecls : Array (Decl .pure) structure State where - decls : Array Decl := {} + decls : Array (Decl .pure) := {} /-- Cache for `shouldExtractFVar` in order to avoid superlinear behavior. -/ @@ -61,7 +61,7 @@ abbrev M := ReaderT Context $ StateRefT State CompilerM mutual -partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do +partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue .pure) : M Bool := do match v with | .lit (.str _) => return true | .lit (.nat v) => @@ -90,7 +90,7 @@ partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do | .fvar fnVar args => return (← shouldExtractFVar fnVar) && (← args.allM shouldExtractArg) | .proj _ _ baseVar => shouldExtractFVar baseVar -partial def shouldExtractArg (arg : Arg) : M Bool := do +partial def shouldExtractArg (arg : Arg .pure) : M Bool := do match arg with | .fvar fvarId => shouldExtractFVar fvarId | .type _ | .erased => return true @@ -113,7 +113,7 @@ end mutual -partial def visitCode (code : Code) : M Code := do +partial def visitCode (code : Code .pure) : M (Code .pure) := do match code with | .let decl k => if (← shouldExtractLetValue true decl.value) then @@ -151,13 +151,14 @@ partial def visitCode (code : Code) : M Code := do end -def visitDecl (decl : Decl) : M Decl := do +def visitDecl (decl : Decl .pure) : M (Decl .pure) := do let value ← decl.value.mapCodeM visitCode return { decl with value } end ExtractClosed -partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do +partial def Decl.extractClosed (decl : Decl .pure) (sccDecls : Array (Decl .pure)) : + CompilerM (Array (Decl .pure)) := do let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {} return s.decls.push decl diff --git a/src/Lean/Compiler/LCNF/FVarUtil.lean b/src/Lean/Compiler/LCNF/FVarUtil.lean index e0cae2d4d8..edd1440e09 100644 --- a/src/Lean/Compiler/LCNF/FVarUtil.lean +++ b/src/Lean/Compiler/LCNF/FVarUtil.lean @@ -48,67 +48,67 @@ instance : TraverseFVar Expr where mapFVarM := Expr.mapFVarM forFVarM := Expr.forFVarM -def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (arg : Arg) : m Arg := do +def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (arg : Arg pu) : m (Arg pu) := do match arg with | .erased => return .erased - | .type e => return arg.updateType! (← TraverseFVar.mapFVarM f e) + | .type e _ => return arg.updateType! (← TraverseFVar.mapFVarM f e) | .fvar fvarId => return arg.updateFVar! (← f fvarId) -def Arg.forFVarM [Monad m] (f : FVarId → m Unit) (arg : Arg) : m Unit := do +def Arg.forFVarM [Monad m] (f : FVarId → m Unit) (arg : Arg pu) : m Unit := do match arg with | .erased => return () - | .type e => TraverseFVar.forFVarM f e + | .type e _ => TraverseFVar.forFVarM f e | .fvar fvarId => f fvarId -instance : TraverseFVar Arg where +instance : TraverseFVar (Arg pu) where mapFVarM := Arg.mapFVarM forFVarM := Arg.forFVarM -def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (e : LetValue) : m LetValue := do +def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (e : LetValue pu) : m (LetValue pu) := do match e with | .lit .. | .erased => return e - | .proj _ _ fvarId => return e.updateProj! (← f fvarId) - | .const _ _ args => return e.updateArgs! (← args.mapM (TraverseFVar.mapFVarM f)) + | .proj _ _ fvarId _ => return e.updateProj! (← f fvarId) + | .const _ _ args _ => return e.updateArgs! (← args.mapM (TraverseFVar.mapFVarM f)) | .fvar fvarId args => return e.updateFVar! (← f fvarId) (← args.mapM (TraverseFVar.mapFVarM f)) -def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue) : m Unit := do +def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue pu) : m Unit := do match e with | .lit .. | .erased => return () - | .proj _ _ fvarId => f fvarId - | .const _ _ args => args.forM (TraverseFVar.forFVarM f) + | .proj _ _ fvarId _ => f fvarId + | .const _ _ args _ => args.forM (TraverseFVar.forFVarM f) | .fvar fvarId args => f fvarId; args.forM (TraverseFVar.forFVarM f) -instance : TraverseFVar LetValue where +instance : TraverseFVar (LetValue pu) where mapFVarM := LetValue.mapFVarM forFVarM := LetValue.forFVarM -partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : LetDecl) : m LetDecl := do +partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : LetDecl pu) : m (LetDecl pu) := do decl.update (← Expr.mapFVarM f decl.type) (← LetValue.mapFVarM f decl.value) -partial def LetDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : LetDecl) : m Unit := do +partial def LetDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : LetDecl pu) : m Unit := do Expr.forFVarM f decl.type LetValue.forFVarM f decl.value -instance : TraverseFVar LetDecl where +instance : TraverseFVar (LetDecl pu) where mapFVarM := LetDecl.mapFVarM forFVarM := LetDecl.forFVarM -partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (param : Param) : m Param := do +partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (param : Param pu) : m (Param pu) := do param.update (← Expr.mapFVarM f param.type) -partial def Param.forFVarM [Monad m] (f : FVarId → m Unit) (param : Param) : m Unit := do +partial def Param.forFVarM [Monad m] (f : FVarId → m Unit) (param : Param pu) : m Unit := do Expr.forFVarM f param.type -instance : TraverseFVar Param where +instance : TraverseFVar (Param pu) where mapFVarM := Param.mapFVarM forFVarM := Param.forFVarM -partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (c : Code) : m Code := do +partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (c : Code pu) : m (Code pu) := do match c with | .let decl k => let decl ← LetDecl.mapFVarM f decl return Code.updateLet! c decl (← mapFVarM f k) - | .fun decl k => + | .fun decl k _ => let params ← decl.params.mapM (Param.mapFVarM f) let decl ← decl.update (← Expr.mapFVarM f decl.type) params (← mapFVarM f decl.value) return Code.updateFun! c decl (← mapFVarM f k) @@ -125,12 +125,12 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F | .unreach typ => return Code.updateUnreach! c (← Expr.mapFVarM f typ) -partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit := do +partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do match c with | .let decl k => LetDecl.forFVarM f decl forFVarM f k - | .fun decl k => + | .fun decl k _ => decl.params.forM (Param.forFVarM f) Expr.forFVarM f decl.type forFVarM f decl.value @@ -151,45 +151,45 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit | .unreach typ => Expr.forFVarM f typ -instance : TraverseFVar Code where +instance : TraverseFVar (Code pu) where mapFVarM := Code.mapFVarM forFVarM := Code.forFVarM -def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : FunDecl) : m FunDecl := do +def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : FunDecl pu) : m (FunDecl pu) := do let params ← decl.params.mapM (Param.mapFVarM f) decl.update (← Expr.mapFVarM f decl.type) params (← Code.mapFVarM f decl.value) -def FunDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : FunDecl) : m Unit := do +def FunDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : FunDecl pu) : m Unit := do decl.params.forM (Param.forFVarM f) Expr.forFVarM f decl.type Code.forFVarM f decl.value -instance : TraverseFVar FunDecl where +instance : TraverseFVar (FunDecl pu) where mapFVarM := FunDecl.mapFVarM forFVarM := FunDecl.forFVarM -instance : TraverseFVar CodeDecl where +instance : TraverseFVar (CodeDecl pu) where mapFVarM f decl := do match decl with - | .fun decl => return .fun (← mapFVarM f decl) + | .fun decl _ => return .fun (← mapFVarM f decl) | .jp decl => return .jp (← mapFVarM f decl) | .let decl => return .let (← mapFVarM f decl) forFVarM f decl := match decl with - | .fun decl => forFVarM f decl + | .fun decl _ => forFVarM f decl | .jp decl => forFVarM f decl | .let decl => forFVarM f decl -instance : TraverseFVar Alt where +instance : TraverseFVar (Alt pu) where mapFVarM f alt := do match alt with - | .alt ctor params c => + | .alt ctor params c _ => let params ← params.mapM (Param.mapFVarM f) return .alt ctor params (← Code.mapFVarM f c) | .default c => return .default (← Code.mapFVarM f c) forFVarM f alt := do match alt with - | .alt _ params c => + | .alt _ params c _ => params.forM (Param.forFVarM f) Code.forFVarM f c | .default c => Code.forFVarM f c diff --git a/src/Lean/Compiler/LCNF/FixedParams.lean b/src/Lean/Compiler/LCNF/FixedParams.lean index 2ddb1c96fa..e00d47dafa 100644 --- a/src/Lean/Compiler/LCNF/FixedParams.lean +++ b/src/Lean/Compiler/LCNF/FixedParams.lean @@ -46,12 +46,12 @@ inductive AbsValue where structure Context where /-- Declaration in the same mutual block. -/ - decls : Array Decl + decls : Array (Decl .pure) /-- Function being analyzed. We check every recursive call to this function. Remark: `main` is in `decls`. -/ - main : Decl + main : Decl .pure /-- The assignment maps free variable ids in the current code being analyzed to abstract values. We only track the abstract value assigned to parameters. @@ -84,17 +84,17 @@ def evalFVar (fvarId : FVarId) : FixParamM AbsValue := do let some val := (← read).assignment.get? fvarId | return .top return val -def evalArg (arg : Arg) : FixParamM AbsValue := do +def evalArg (arg : Arg .pure) : FixParamM AbsValue := do match arg with | .erased => return .erased - | .type (.fvar fvarId) => evalFVar fvarId - | .type _ => return .top + | .type (.fvar fvarId) _ => evalFVar fvarId + | .type _ _ => return .top | .fvar fvarId => evalFVar fvarId def inMutualBlock (declName : Name) : FixParamM Bool := return (← read).decls.any (·.name == declName) -def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do +def mkAssignment (decl : Decl .pure) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do let mut assignment := {} for param in decl.params, value in values do assignment := assignment.insert param.fvarId value @@ -102,12 +102,12 @@ def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := mutual -partial def evalLetValue (e : LetValue) : FixParamM Unit := do +partial def evalLetValue (e : LetValue .pure) : FixParamM Unit := do match e with - | .const declName _ args => evalApp declName args + | .const declName _ args _ => evalApp declName args | _ => return () -partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do +partial def isEquivalentFunDecl? (decl : FunDecl .pure) : FixParamM (Option Nat) := do let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none if args.size != decl.params.size then return none let .return retFVarId := k | return none @@ -120,10 +120,10 @@ partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do if arg != .fvar param.fvarId && arg != .erased then return none return some funIdx -partial def evalCode (code : Code) : FixParamM Unit := do +partial def evalCode (code : Code .pure) : FixParamM Unit := do match code with | .let decl k => evalLetValue decl.value; evalCode k - | .fun decl k => + | .fun decl k _ => if let some paramIdx ← isEquivalentFunDecl? decl then withReader (fun ctx => { ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) }) @@ -135,7 +135,7 @@ partial def evalCode (code : Code) : FixParamM Unit := do | .cases c => c.alts.forM fun alt => evalCode alt.getCode | .unreach .. | .jmp .. | .return .. => return () -partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do +partial def evalApp (declName : Name) (args : Array (Arg .pure)) : FixParamM Unit := do let main := (← read).main if declName == main.name then -- Recursive call to the function being analyzed @@ -180,6 +180,9 @@ def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do end FixedParams open FixedParams +-- TODO: consider making it phase polymorphic, this requires detecting in place mutations of +-- variables etc in addition to just graph theory + /-- Given the (potentially mutually) recursive declarations `decls`, return a map from declaration name `decl.name` to a bit-mask `m` where `m[i]` is true @@ -188,7 +191,7 @@ applications. The function assumes that if a function `f` was declared in a mutual block, then `decls` contains all (computationally relevant) functions in the mutual block. -/ -def mkFixedParamsMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do +def mkFixedParamsMap (decls : Array (Decl .pure)) : NameMap (Array Bool) := Id.run do let mut result := {} for decl in decls do let values := mkInitialValues decl.params.size diff --git a/src/Lean/Compiler/LCNF/FloatLetIn.lean b/src/Lean/Compiler/LCNF/FloatLetIn.lean index 765e64cbef..d1f323219e 100644 --- a/src/Lean/Compiler/LCNF/FloatLetIn.lean +++ b/src/Lean/Compiler/LCNF/FloatLetIn.lean @@ -38,7 +38,7 @@ inductive Decision where | unknown deriving Hashable, BEq, Inhabited, Repr -def Decision.ofAlt : Alt → Decision +def Decision.ofAlt : Alt .pure → Decision | .alt name _ _ => .arm name | .default _ => .default @@ -50,7 +50,7 @@ structure BaseFloatContext where All the declarations that were collected in the current LCNF basic block up to the current statement (in reverse order for efficiency). -/ - decls : List CodeDecl := [] + decls : List (CodeDecl .pure) := [] /-- The state for `FloatM` @@ -67,7 +67,7 @@ structure FloatState where - Which declarations do we move into a certain arm - Which declarations do we move into the default arm -/ - newArms : Std.HashMap Decision (List CodeDecl) + newArms : Std.HashMap Decision (List (CodeDecl .pure)) /-- Use to collect relevant declarations for the floating mechanism. @@ -82,7 +82,7 @@ abbrev FloatM := StateRefT FloatState BaseFloatM /-- Add `decl` to the list of declarations and run `x` with that updated context. -/ -def withNewCandidate (decl : CodeDecl) (x : BaseFloatM α) : BaseFloatM α := +def withNewCandidate (decl : CodeDecl .pure) (x : BaseFloatM α) : BaseFloatM α := withReader (fun r => { r with decls := decl :: r.decls }) do x @@ -98,7 +98,7 @@ Whether to ignore `decl` for the floating mechanism. We want to do this if: - `decl`' is storing a typeclass instance - `decl` is a projection from a variable that is storing a typeclass instance -/ -def ignore? (decl : LetDecl) : BaseFloatM Bool := do +def ignore? (decl : LetDecl .pure) : BaseFloatM Bool := do if (← isArrowClass? decl.type).isSome then return true else if let .proj _ _ fvarId := decl.value then @@ -117,7 +117,7 @@ up to this point, with respect to `cs`. The initial decisions are: - `arm` or `default` if we see the declaration only being used in exactly one cases arm - `unknown` otherwise -/ -def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do +def initialDecisions (cs : Cases .pure) : BaseFloatM (Std.HashMap FVarId Decision) := do let mut map := Std.HashMap.emptyWithCapacity (← read).decls.length let owned : Std.HashSet FVarId := ∅ (map, _) ← (← read).decls.foldlM (init := (map, owned)) fun (acc, owned) val => do @@ -135,12 +135,12 @@ def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := (_, map) ← goCases cs |>.run map return map where - visitDecl (env : Environment) (value : CodeDecl) : StateM (Std.HashSet FVarId) Bool := do + visitDecl (env : Environment) (value : CodeDecl .pure) : StateM (Std.HashSet FVarId) Bool := do match value with | .let decl => visitLetValue env decl.value | _ => return false -- will need to investigate whether that can be a problem - visitLetValue (env : Environment) (value : LetValue) : StateM (Std.HashSet FVarId) Bool := do + visitLetValue (env : Environment) (value : LetValue .pure) : StateM (Std.HashSet FVarId) Bool := do match value with | .proj _ _ x => visitArg (.fvar x) true | .const nm _ args => @@ -158,7 +158,7 @@ where (← visitArg (.fvar x) false) | .erased | .lit _ => return false - visitArg (var : Arg) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do + visitArg (var : Arg .pure) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do let .fvar v := var | return false let res := (← get).contains v unless borrowed do @@ -173,16 +173,16 @@ where modify fun s => s.insert var .dont -- otherwise we already have the proper decision - goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := + goAlt (alt : Alt .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := forFVarM (goFVar (.ofAlt alt)) alt - goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := + goCases (cs : Cases .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit := cs.alts.forM goAlt /-- Compute the initial new arms. This will just set up a map from all arms of `cs` to empty `Array`s, plus one additional entry for `dont`. -/ -def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do +def initialNewArms (cs : Cases .pure) : Std.HashMap Decision (List (CodeDecl .pure)) := Id.run do let mut map := Std.HashMap.emptyWithCapacity (cs.alts.size + 1) map := map.insert .dont [] cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) [] @@ -203,7 +203,7 @@ cases z with Here `x` and `y` are originally marked as getting floated into `n` and `m` respectively but since `z` can't be moved we don't want that to move `x` and `y`. -/ -def dontFloat (decl : CodeDecl) : FloatM Unit := do +def dontFloat (decl : CodeDecl .pure) : FloatM Unit := do forFVarM goFVar decl modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) } where @@ -257,7 +257,7 @@ Will: ``` If we are at `y` `x` is still marked to be moved but we don't want that. -/ -def float (decl : CodeDecl) : FloatM Unit := do +def float (decl : CodeDecl .pure) : FloatM Unit := do let arm := (← get).decision[decl.fvarId]! forFVarM (goFVar · arm) decl modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) } @@ -273,7 +273,7 @@ where Iterate through `decl`, pushing local declarations that are only used in one control flow arm into said arm in order to avoid useless computations. -/ -partial def floatLetIn (decl : Decl) : CompilerM Decl := do +partial def floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do let newValue ← decl.value.mapCodeM go |>.run {} return { decl with value := newValue } where @@ -296,7 +296,7 @@ where else float decl - go (code : Code) : BaseFloatM Code := do + go (code : Code .pure) : BaseFloatM (Code .pure) := do match code with | .let decl k => withNewCandidate (.let decl) do @@ -334,11 +334,12 @@ where end FloatLetIn -def Decl.floatLetIn (decl : Decl) : CompilerM Decl := do +def Decl.floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do FloatLetIn.floatLetIn decl def floatLetIn (phase := Phase.base) (occurrence := 0) : Pass := - .mkPerDeclaration `floatLetIn Decl.floatLetIn phase occurrence + phase.withPurityCheck .pure fun h => + .mkPerDeclaration `floatLetIn phase (h ▸ Decl.floatLetIn) occurrence builtin_initialize registerTraceClass `Compiler.floatLetIn (inherited := true) diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index 9779eb649e..e7968ff36f 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -14,6 +14,10 @@ public section namespace Lean.Compiler.LCNF /-! # Type inference for LCNF -/ +namespace InferType + +namespace Pure + /- Note about **erasure confusion**. @@ -53,10 +57,9 @@ but the expected type is `S Nat Type (fun x => Nat)`. `fun x => Nat` is not eras here because it is a type former. -/ -namespace InferType /- -Type inference algorithm for LCNF. Invoked by the LCNF type checker +Type inference algorithm for pure LCNF. Invoked by the LCNF type checker to check correctness of LCNF IR. -/ @@ -80,12 +83,12 @@ def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr := let b := type.abstract xs xs.size.foldRevM (init := b) fun i _ b => do let x := xs[i] - let n ← InferType.getBinderName x.fvarId! - let ty ← InferType.getType x.fvarId! + let n ← getBinderName x.fvarId! + let ty ← getType x.fvarId! let ty := ty.abstractRange i xs; return .forallE n ty b .default -def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr := +def mkForallParams (params : Array (Param .pure)) (type : Expr) : InferTypeM Expr := let xs := params.map fun p => .fvar p.fvarId mkForallFVars xs type |>.run {} @@ -97,7 +100,7 @@ def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr := def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do if declName == ``lcErased then return erasedExpr - else if let some decl ← getDecl? declName then + else if let some ⟨_, decl⟩ ← getDecl? declName then return decl.instantiateTypeLevelParams us else /- Declaration does not have code associated with it: constructor, inductive type, foreign function -/ @@ -114,7 +117,7 @@ def inferLitValueType (value : LitValue) : Expr := | .usize .. => mkConst ``USize mutual - partial def inferArgType (arg : Arg) : InferTypeM Expr := + partial def inferArgType (arg : Arg .pure) : InferTypeM Expr := match arg with | .erased => return erasedExpr | .type e => inferType e @@ -124,13 +127,13 @@ mutual match e with | .const c us => inferConstType c us | .app .. => inferAppType e - | .fvar fvarId => InferType.getType fvarId + | .fvar fvarId => getType fvarId | .sort lvl => return .sort (mkLevelSucc lvl) | .forallE .. => inferForallType e | .lam .. => inferLambdaType e | .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable! - partial def inferLetValueType (e : LetValue) : InferTypeM Expr := do + partial def inferLetValueType (e : LetValue .pure) : InferTypeM Expr := do match e with | .erased => return erasedExpr | .lit v => return inferLitValueType v @@ -138,7 +141,7 @@ mutual | .const declName us args => inferAppTypeCore (← inferConstType declName us) args | .fvar fvarId args => inferAppTypeCore (← getType fvarId) args - partial def inferAppTypeCore (fType : Expr) (args : Array Arg) : InferTypeM Expr := do + partial def inferAppTypeCore (fType : Expr) (args : Array (Arg .pure)) : InferTypeM Expr := do let mut j := 0 let mut fType := fType for i in *...args.size do @@ -237,60 +240,79 @@ mutual mkForallFVars fvars type end +end Pure + +namespace Impure +end Impure + end InferType +-- TODO def inferType (e : Expr) : CompilerM Expr := - InferType.inferType e |>.run {} + InferType.Pure.inferType e |>.run {} -def inferAppType (fnType : Expr) (args : Array Arg) : CompilerM Expr := - InferType.inferAppTypeCore fnType args |>.run {} +def inferAppType (fnType : Expr) (args : Array (Arg pu)) : CompilerM Expr := + match pu with + | .pure => InferType.Pure.inferAppTypeCore fnType args |>.run {} + | .impure => panic! "Infer type for impure unimplemented" -- TODO -def getLevel (type : Expr) : CompilerM Level := do - match (← inferType type) with - | .sort u => return u - | e => if e.isErased then return levelOne else throwError "type expected{indentExpr type}" +def Arg.inferType (arg : Arg pu) : CompilerM Expr := + match pu with + | .pure => InferType.Pure.inferArgType arg |>.run {} + | .impure => panic! "Infer type for impure unimplemented" -- TODO -def Arg.inferType (arg : Arg) : CompilerM Expr := - InferType.inferArgType arg |>.run {} +def LetValue.inferType (e : LetValue pu) : CompilerM Expr := + match pu with + | .pure => InferType.Pure.inferLetValueType e |>.run {} + | .impure => panic! "Infer type for impure unimplemented" -- TODO -def LetValue.inferType (e : LetValue) : CompilerM Expr := - InferType.inferLetValueType e |>.run {} +def Code.inferType (code : Code pu) : CompilerM Expr := do + match pu with + | .pure => + match code with + | .let _ k | .fun _ k _ | .jp _ k => k.inferType + | .return fvarId => getType fvarId + | .jmp fvarId args => InferType.Pure.inferAppTypeCore (← getType fvarId) args |>.run {} + | .unreach type => return type + | .cases c => return c.resultType + | .impure => panic! "Infer type for impure unimplemented" -- TODO -def Code.inferType (code : Code) : CompilerM Expr := do - match code with - | .let _ k | .fun _ k | .jp _ k => k.inferType - | .return fvarId => getType fvarId - | .jmp fvarId args => InferType.inferAppTypeCore (← getType fvarId) args |>.run {} - | .unreach type => return type - | .cases c => return c.resultType - -def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr := do +def Code.inferParamType (params : Array (Param pu)) (code : Code pu) : CompilerM Expr := do let type ← code.inferType let xs := params.map fun p => .fvar p.fvarId - InferType.mkForallFVars xs type |>.run {} + InferType.Pure.mkForallFVars xs type |>.run {} -def Alt.inferType (alt : Alt) : CompilerM Expr := +def Alt.inferType (alt : Alt pu) : CompilerM Expr := alt.getCode.inferType -def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : CompilerM LetDecl := do +def mkAuxLetDecl (e : LetValue pu) (prefixName := `_x) : CompilerM (LetDecl pu) := do mkLetDecl (← mkFreshBinderName prefixName) (← e.inferType) e -def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr := - InferType.mkForallParams params type |>.run {} +def mkForallParams (params : Array (Param pu)) (type : Expr) : CompilerM Expr := + match pu with + | .pure => InferType.Pure.mkForallParams params type |>.run {} + | .impure => panic! "Infer type for impure unimplemented" -- TODO -def mkAuxFunDecl (params : Array Param) (code : Code) (prefixName := `_f) : CompilerM FunDecl := do +private def mkAuxFunDeclAux (params : Array (Param pu)) (code : Code pu) (prefixName : Name) : + CompilerM (FunDecl pu) := do let type ← mkForallParams params (← code.inferType) let binderName ← mkFreshBinderName prefixName mkFunDecl binderName type params code -def mkAuxJpDecl (params : Array Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do - mkAuxFunDecl params code prefixName +def mkAuxFunDecl (params : Array (Param .pure)) (code : Code .pure) (prefixName := `_f) : + CompilerM (FunDecl .pure) := do + mkAuxFunDeclAux params code prefixName -def mkAuxJpDecl' (param : Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do +def mkAuxJpDecl (params : Array (Param pu)) (code : Code pu) (prefixName := `_jp) : + CompilerM (FunDecl pu) := do + mkAuxFunDeclAux params code prefixName + +def mkAuxJpDecl' (param : Param pu) (code : Code pu) (prefixName := `_jp) : + CompilerM (FunDecl pu) := do let params := #[param] - mkAuxFunDecl params code prefixName + mkAuxFunDeclAux params code prefixName -def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do +def mkCasesResultType (alts : Array (Alt pu)) : CompilerM Expr := do if alts.isEmpty then throwError "`Code.bind` failed, empty `cases` found" let mut resultType ← alts[0]!.inferType diff --git a/src/Lean/Compiler/LCNF/Internalize.lean b/src/Lean/Compiler/LCNF/Internalize.lean index 7c1a26cf3d..8033353621 100644 --- a/src/Lean/Compiler/LCNF/Internalize.lean +++ b/src/Lean/Compiler/LCNF/Internalize.lean @@ -22,44 +22,45 @@ private def refreshBinderName (binderName : Name) : CompilerM Name := do namespace Internalize -abbrev InternalizeM := StateRefT FVarSubst CompilerM +abbrev InternalizeM (pu : Purity) := StateRefT (FVarSubst pu) CompilerM /-- The `InternalizeM` monad is a translator. It "translates" the free variables in the input expressions and `Code`, into new fresh free variables in the local context. -/ -instance : MonadFVarSubst InternalizeM true where +instance : MonadFVarSubst (InternalizeM pu) pu true where getSubst := get -instance : MonadFVarSubstState InternalizeM where +instance : MonadFVarSubstState (InternalizeM pu) pu where modifySubst := modify -private def mkNewFVarId (fvarId : FVarId) : InternalizeM FVarId := do +private def mkNewFVarId (fvarId : FVarId) : InternalizeM pu FVarId := do let fvarId' ← Lean.mkFreshFVarId addFVarSubst fvarId fvarId' return fvarId' -private partial def internalizeExpr (e : Expr) : InternalizeM Expr := +private partial def internalizeExpr (e : Expr) : InternalizeM pu Expr := go e where - goApp (e : Expr) : InternalizeM Expr := do + goApp (e : Expr) : InternalizeM pu Expr := do match e with | .app f a => return e.updateApp! (← goApp f) (← go a) | _ => go e - go (e : Expr) : InternalizeM Expr := do + go (e : Expr) : InternalizeM pu Expr := do if e.hasFVar then match e with - | .fvar fvarId => match (← get)[fvarId]? with + | .fvar fvarId => + match (← get)[fvarId]? with | some (.fvar fvarId') => -- In LCNF, types can't depend on let-bound fvars. - if (← findParam? fvarId').isSome then + if (← findParam? (pu := pu) fvarId').isSome then return .fvar fvarId' else return anyExpr | some .erased => return erasedExpr - | some (.type e) | none => return e + | some (.type e _) | none => return e | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => return e | .app f a => return e.updateApp! (← goApp f) (← go a) |>.headBeta | .mdata _ b => return e.updateMData! (← go b) @@ -70,7 +71,7 @@ where else return e -def internalizeParam (p : Param) : InternalizeM Param := do +def internalizeParam (p : Param pu) : InternalizeM pu (Param pu) := do let binderName ← refreshBinderName p.binderName let type ← internalizeExpr p.type let fvarId ← mkNewFVarId p.fvarId @@ -78,31 +79,31 @@ def internalizeParam (p : Param) : InternalizeM Param := do modifyLCtx fun lctx => lctx.addParam p return p -def internalizeArg (arg : Arg) : InternalizeM Arg := do +def internalizeArg (arg : Arg pu) : InternalizeM pu (Arg pu) := do match arg with | .fvar fvarId => match (← get)[fvarId]? with | some arg'@(.fvar _) => return arg' - | some arg'@.erased | some arg'@(.type _) => return arg' + | some arg'@.erased | some arg'@(.type _ _) => return arg' | none => return arg - | .type e => return arg.updateType! (← internalizeExpr e) + | .type e _ => return arg.updateType! (← internalizeExpr e) | .erased => return arg -def internalizeArgs (args : Array Arg) : InternalizeM (Array Arg) := +def internalizeArgs (args : Array (Arg pu)) : InternalizeM pu (Array (Arg pu)) := args.mapM internalizeArg -private partial def internalizeLetValue (e : LetValue) : InternalizeM LetValue := do +private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (LetValue pu) := do match e with | .erased | .lit .. => return e - | .proj _ _ fvarId => match (← normFVar fvarId) with + | .proj _ _ fvarId _ => match (← normFVar fvarId) with | .fvar fvarId' => return e.updateProj! fvarId' | .erased => return .erased - | .const _ _ args => return e.updateArgs! (← internalizeArgs args) + | .const _ _ args _ => return e.updateArgs! (← internalizeArgs args) | .fvar fvarId args => match (← normFVar fvarId) with | .fvar fvarId' => return e.updateFVar! fvarId' (← internalizeArgs args) | .erased => return .erased -def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do +def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do let binderName ← refreshBinderName decl.binderName let type ← internalizeExpr decl.type let value ← internalizeLetValue decl.value @@ -113,7 +114,7 @@ def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do mutual -partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do +partial def internalizeFunDecl (decl : FunDecl pu) : InternalizeM pu (FunDecl pu) := do let type ← internalizeExpr decl.type let binderName ← refreshBinderName decl.binderName let params ← decl.params.mapM internalizeParam @@ -123,10 +124,10 @@ partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do modifyLCtx fun lctx => lctx.addFunDecl decl return decl -partial def internalizeCode (code : Code) : InternalizeM Code := do +partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do match code with | .let decl k => return .let (← internalizeLetDecl decl) (← internalizeCode k) - | .fun decl k => return .fun (← internalizeFunDecl decl) (← internalizeCode k) + | .fun decl k _ => return .fun (← internalizeFunDecl decl) (← internalizeCode k) | .jp decl k => return .jp (← internalizeFunDecl decl) (← internalizeCode k) | .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return .return fvarId | .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return .jmp fvarId (← internalizeArgs args) @@ -134,19 +135,19 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do | .cases c => withNormFVarResult (← normFVar c.discr) fun discr => do let resultType ← internalizeExpr c.resultType - let internalizeAltCode (k : Code) : InternalizeM Code := + let internalizeAltCode (k : Code pu) : InternalizeM pu (Code pu) := internalizeCode k let alts ← c.alts.mapM fun - | .alt ctorName params k => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k) + | .alt ctorName params k _ => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k) | .default k => return .default (← internalizeAltCode k) return .cases ⟨c.typeName, resultType, discr, alts⟩ end -partial def internalizeCodeDecl (decl : CodeDecl) : InternalizeM CodeDecl := do +partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl pu) := do match decl with | .let decl => return .let (← internalizeLetDecl decl) - | .fun decl => return .fun (← internalizeFunDecl decl) + | .fun decl _ => return .fun (← internalizeFunDecl decl) | .jp decl => return .jp (← internalizeFunDecl decl) end Internalize @@ -154,14 +155,14 @@ end Internalize /-- Refresh free variables ids in `code`, and store their declarations in the local context. -/ -partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code := +partial def Code.internalize (code : Code pu) (s : FVarSubst pu := {}) : CompilerM (Code pu) := Internalize.internalizeCode code |>.run' s open Internalize in -def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl := +def Decl.internalize (decl : Decl pu) (s : FVarSubst pu := {}): CompilerM (Decl pu) := go decl |>.run' s where - go (decl : Decl) : InternalizeM Decl := do + go (decl : Decl pu) : InternalizeM pu (Decl pu) := do let type ← internalizeExpr decl.type let params ← decl.params.mapM internalizeParam let value ← decl.value.mapCodeM internalizeCode @@ -170,13 +171,13 @@ where /-- Create a fresh local context and internalize the given decls. -/ -def cleanup (decl : Array Decl) : CompilerM (Array Decl) := do +def cleanup (decl : Array (Decl pu)) : CompilerM (Array (Decl pu)) := do modify fun _ => {} decl.mapM fun decl => do modify fun s => { s with nextIdx := 1 } decl.internalize -def normalizeFVarIds (decl : Decl) : CoreM Decl := do +def normalizeFVarIds (decl : Decl pu) : CoreM (Decl pu) := do let ngenSaved ← getNGen setNGen {} try diff --git a/src/Lean/Compiler/LCNF/JoinPoints.lean b/src/Lean/Compiler/LCNF/JoinPoints.lean index da8a79c611..3b0ad560a9 100644 --- a/src/Lean/Compiler/LCNF/JoinPoints.lean +++ b/src/Lean/Compiler/LCNF/JoinPoints.lean @@ -92,13 +92,13 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do /-- Remove all join point candidates contained in `a`. -/ -private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do +private partial def removeCandidatesInArg (a : Arg .pure) : FindM Unit := do forFVarM eraseCandidate a /-- Remove all join point candidates contained in `a`. -/ -private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do +private partial def removeCandidatesInLetValue (e : LetValue .pure) : FindM Unit := do forFVarM eraseCandidate e /-- @@ -117,7 +117,7 @@ private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do { targetInfo with associated := targetInfo.associated.insert src } @[inline] -private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α := +private def withFnBody (decl : FunDecl .pure) (x : FindM α) : FindM α := withReader (fun ctx => { ctx with definitionDepth := ctx.definitionDepth + 1, @@ -125,7 +125,7 @@ private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α := x @[inline] -private def withFnDefined (decl : FunDecl) (x : FindM α) : FindM α := +private def withFnDefined (decl : FunDecl .pure) (x : FindM α) : FindM α := withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do @@ -163,11 +163,11 @@ def test (b : Bool) (x y : Nat) : Nat := this. This is because otherwise the calls to `myjp` in `f` and `g` would produce out of scope join point jumps. -/ -partial def find (decl : Decl) : CompilerM FindState := do +partial def find (decl : Decl .pure) : CompilerM FindState := do let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {} return candidates where - go : Code → FindM Unit + go : Code .pure → FindM Unit | .let decl k => do match k, decl.value with | .return valId, .fvar fvarId args => @@ -207,13 +207,13 @@ where Replace all join point candidate `fun` declarations with `jp` ones and all calls to them with `jmp`s. -/ -partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do +partial def replace (decl : Decl .pure) (state : FindState) : CompilerM (Decl .pure) := do let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName) let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := ∅) mapper let newValue ← decl.value.mapCodeM go |>.run replaceCtx return { decl with value := newValue } where - go (code : Code) : ReplaceM Code := do + go (code : Code .pure) : ReplaceM (Code .pure) := do match code with | .let decl k => match k, decl.value with @@ -274,7 +274,7 @@ structure ExtendState where to `Param`s. The free variables in this map are the once that the context of said join point will be extended by passing in the respective parameter. -/ - fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {} + fvarMap : Std.HashMap FVarId (Std.HashMap FVarId (Param .pure)) := {} /-- The monad for the `extendJoinPointContext` pass. @@ -388,7 +388,7 @@ the join point. This is so in the case of nested join points that refer to parameters of the current one we extend the context of the nested join points by said parameters. -/ -def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do +def withNewJpScope (decl : FunDecl .pure) (x : ExtendM α): ExtendM α := do withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} } withNewScope do @@ -401,7 +401,7 @@ It will back up the current scope (since we are doing a case split and want to continue with other arms afterwards) and add all of the parameters of the match arm to the list of candidates. -/ -def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do +def withNewAltScope (alt : Alt .pure) (x : ExtendM α) : ExtendM α := do withBackTrackingScope do withNewCandidates (alt.getParams.map (·.fvarId)) do x @@ -418,7 +418,7 @@ All of this is done to eliminate dependencies of join points onto their position within the code so we can pull them out as far as possible, hopefully enabling new inlining possibilities in the next simplifier run. -/ -partial def extend (decl : Decl) : CompilerM Decl := do +partial def extend (decl : Decl .pure) : CompilerM (Decl .pure) := do let newValue ← decl.value.mapCodeM go |>.run {} |>.run' {} |>.run' {} let decl := { decl with value := newValue } decl.pullFunDecls @@ -426,7 +426,7 @@ where goFVar (fvar : FVarId) : ExtendM FVarId := do extendByIfNecessary fvar replaceFVar fvar - go (code : Code) : ExtendM Code := do + go (code : Code .pure) : ExtendM (Code .pure) := do match code with | .let decl k => let decl ← decl.updateValue (← mapFVarM goFVar decl.value) @@ -491,7 +491,7 @@ structure AnalysisState where A map, that for each join point id contains a map from all (so far) duplicated argument ids to the respective duplicate value -/ - jpJmpArgs : FVarIdMap FVarSubst := {} + jpJmpArgs : FVarIdMap (FVarSubst .pure) := {} abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM abbrev ReduceActionM := ReaderT AnalysisState CompilerM @@ -539,17 +539,17 @@ After we have performed all of these optimizations we can take away the (remaining) common arguments and end up with nicely floated and optimized code that has as little arguments as possible in the join points. -/ -partial def reduce (decl : Decl) : CompilerM Decl := do +partial def reduce (decl : Decl .pure) : CompilerM (Decl .pure) := do let (_, analysis) ← decl.value.forCodeM goAnalyze |>.run {} |>.run {} |>.run' {} let newValue ← decl.value.mapCodeM goReduce |>.run analysis return { decl with value := newValue } where - goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do + goAnalyzeFunDecl (fn : FunDecl .pure) : ReduceAnalysisM Unit := do withNewScope do fn.params.forM (addToScope ·.fvarId) goAnalyze fn.value - goAnalyze (code : Code) : ReduceAnalysisM Unit := do + goAnalyze (code : Code .pure) : ReduceAnalysisM Unit := do match code with | .let decl k => addToScope decl.fvarId @@ -571,7 +571,7 @@ where goAnalyze alt.getCode cs.alts.forM visitor | .jmp fn args => - let decl ← getFunDecl fn + let decl ← getFunDecl (pu := .pure) fn if let some knownArgs := (← get).jpJmpArgs.get? fn then let mut newArgs := knownArgs for (param, arg) in decl.params.zip args do @@ -589,7 +589,7 @@ where modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs } | .return .. | .unreach .. => return () - goReduce (code : Code) : ReduceActionM Code := do + goReduce (code : Code .pure) : ReduceActionM (Code .pure) := do match code with | .jp decl k => if let some reducibleArgs := (← read).jpJmpArgs.get? decl.fvarId then @@ -613,7 +613,7 @@ where return Code.updateFun! code decl (← goReduce k) | .jmp fn args => let reducibleArgs := (← read).jpJmpArgs.get! fn - let decl ← getFunDecl fn + let decl ← getFunDecl (pu := .pure) fn let newParams := decl.params.zip args |>.filter (!reducibleArgs.contains ·.fst.fvarId) |>.map Prod.snd @@ -630,7 +630,7 @@ where end JoinPointCommonArgs -def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do +def Decl.findJoinPoints? (decl : Decl .pure) : CompilerM (Option (Decl .pure)) := do let findResult ← JoinPointFinder.find decl trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}" if findResult.candidates.isEmpty then @@ -642,29 +642,32 @@ def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do Find all `fun` declarations in `decl` that qualify as join points then replace their definitions and call sites with `jp`/`jmp`. -/ -def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do +def Decl.findJoinPoints (decl : Decl .pure) : CompilerM (Decl .pure) := do return (← Decl.findJoinPoints? decl).getD decl def findJoinPoints (occurrence : Nat := 0) : Pass := - .mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base (occurrence := occurrence) + .mkPerDeclaration `findJoinPoints .base Decl.findJoinPoints (occurrence := occurrence) builtin_initialize registerTraceClass `Compiler.findJoinPoints (inherited := true) -def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do +def Decl.extendJoinPointContext (decl : Decl .pure) : CompilerM (Decl .pure) := do JoinPointContextExtender.extend decl +-- TODO: It might make sense to extend this to impure one day def extendJoinPointContext (occurrence : Nat := 0) (phase := Phase.mono) (_h : phase ≠ .base := by simp): Pass := - .mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext phase (occurrence := occurrence) + phase.withPurityCheck .pure fun h => + .mkPerDeclaration `extendJoinPointContext phase (h ▸ Decl.extendJoinPointContext) (occurrence := occurrence) builtin_initialize registerTraceClass `Compiler.extendJoinPointContext (inherited := true) -def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do +def Decl.commonJoinPointArgs (decl : Decl .pure) : CompilerM (Decl .pure) := do JoinPointCommonArgs.reduce decl +-- TODO: It might make sense to extend this to impure one day def commonJoinPointArgs : Pass := - .mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono + .mkPerDeclaration `commonJoinPointArgs .mono Decl.commonJoinPointArgs builtin_initialize registerTraceClass `Compiler.commonJoinPointArgs (inherited := true) diff --git a/src/Lean/Compiler/LCNF/LCtx.lean b/src/Lean/Compiler/LCNF/LCtx.lean index 202da6f775..f71e32da1e 100644 --- a/src/Lean/Compiler/LCNF/LCtx.lean +++ b/src/Lean/Compiler/LCNF/LCtx.lean @@ -16,61 +16,97 @@ namespace Lean.Compiler.LCNF LCNF local context. -/ structure LCtx where - params : Std.HashMap FVarId Param := {} - letDecls : Std.HashMap FVarId LetDecl := {} - funDecls : Std.HashMap FVarId FunDecl := {} + paramsPure : Std.HashMap FVarId (Param .pure) := {} + paramsImpure : Std.HashMap FVarId (Param .impure) := {} + letDeclsPure : Std.HashMap FVarId (LetDecl .pure) := {} + letDeclsImpure : Std.HashMap FVarId (LetDecl .impure) := {} + funDeclsPure : Std.HashMap FVarId (FunDecl .pure) := {} + funDeclsImpure : Std.HashMap FVarId (FunDecl .impure) := {} deriving Inhabited -def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx := - { lctx with params := lctx.params.insert param.fvarId param } +def LCtx.addParam (lctx : LCtx) (param : Param pu) : LCtx := + match pu with + | .pure => { lctx with paramsPure := lctx.paramsPure.insert param.fvarId param } + | .impure => { lctx with paramsImpure := lctx.paramsImpure.insert param.fvarId param } -def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl) : LCtx := - { lctx with letDecls := lctx.letDecls.insert letDecl.fvarId letDecl } +def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl pu) : LCtx := + match pu with + | .pure => { lctx with letDeclsPure := lctx.letDeclsPure.insert letDecl.fvarId letDecl } + | .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.insert letDecl.fvarId letDecl } -def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl) : LCtx := - { lctx with funDecls := lctx.funDecls.insert funDecl.fvarId funDecl } +def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl pu) : LCtx := + match pu with + | .pure => { lctx with funDeclsPure := lctx.funDeclsPure.insert funDecl.fvarId funDecl } + | .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.insert funDecl.fvarId funDecl } -def LCtx.eraseParam (lctx : LCtx) (param : Param) : LCtx := - { lctx with params := lctx.params.erase param.fvarId } +def LCtx.eraseParam (lctx : LCtx) (param : Param pu) : LCtx := + match pu with + | .pure => { lctx with paramsPure := lctx.paramsPure.erase param.fvarId } + | .impure => { lctx with paramsImpure := lctx.paramsImpure.erase param.fvarId } -def LCtx.eraseParams (lctx : LCtx) (ps : Array Param) : LCtx := - { lctx with params := ps.foldl (init := lctx.params) fun params p => params.erase p.fvarId } +def LCtx.eraseParams (lctx : LCtx) (ps : Array (Param pu)) : LCtx := + match pu with + | .pure => { lctx with paramsPure := ps.foldl (init := lctx.paramsPure) fun params p => params.erase p.fvarId } + | .impure => { lctx with paramsImpure := ps.foldl (init := lctx.paramsImpure) fun params p => params.erase p.fvarId } -def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl) : LCtx := - { lctx with letDecls := lctx.letDecls.erase decl.fvarId } +def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl pu) : LCtx := + match pu with + | .pure => { lctx with letDeclsPure := lctx.letDeclsPure.erase decl.fvarId } + | .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.erase decl.fvarId } mutual - partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl) (recursive := true) : LCtx := - let lctx := { lctx with funDecls := lctx.funDecls.erase decl.fvarId } + partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl pu) (recursive := true) : LCtx := + let lctx := + match pu with + | .pure => { lctx with funDeclsPure := lctx.funDeclsPure.erase decl.fvarId } + | .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.erase decl.fvarId } if recursive then eraseCode decl.value <| eraseParams lctx decl.params else lctx - partial def LCtx.eraseAlts (alts : Array Alt) (lctx : LCtx) : LCtx := + partial def LCtx.eraseAlts (alts : Array (Alt pu)) (lctx : LCtx) : LCtx := alts.foldl (init := lctx) fun lctx alt => match alt with | .default k => eraseCode k lctx - | .alt _ ps k => eraseCode k <| eraseParams lctx ps + | .alt _ ps k _ => eraseCode k <| eraseParams lctx ps - partial def LCtx.eraseCode (code : Code) (lctx : LCtx) : LCtx := + partial def LCtx.eraseCode (code : Code pu) (lctx : LCtx) : LCtx := match code with | .let decl k => eraseCode k <| lctx.eraseLetDecl decl - | .jp decl k | .fun decl k => eraseCode k <| eraseFunDecl lctx decl + | .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl | .cases c => eraseAlts c.alts lctx | _ => lctx end +@[inline] +def LCtx.params (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (Param pu) := + match pu with + | .pure => lctx.paramsPure + | .impure => lctx.paramsImpure + +@[inline] +def LCtx.letDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (LetDecl pu) := + match pu with + | .pure => lctx.letDeclsPure + | .impure => lctx.letDeclsImpure + +@[inline] +def LCtx.funDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (FunDecl pu) := + match pu with + | .pure => lctx.funDeclsPure + | .impure => lctx.funDeclsImpure + /-- Convert a LCNF local context into a regular Lean local context. -/ -def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do +def LCtx.toLocalContext (lctx : LCtx) (pu : Purity) : LocalContext := Id.run do let mut result := {} - for (_, param) in lctx.params.toArray do + for (_, param) in lctx.params pu do result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default) - for (_, decl) in lctx.letDecls.toArray do + for (_, decl) in lctx.letDecls pu do result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default) - for (_, decl) in lctx.funDecls.toArray do + for (_, decl) in lctx.funDecls pu do result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default) return result diff --git a/src/Lean/Compiler/LCNF/LambdaLifting.lean b/src/Lean/Compiler/LCNF/LambdaLifting.lean index 06f607fbef..d930ecb500 100644 --- a/src/Lean/Compiler/LCNF/LambdaLifting.lean +++ b/src/Lean/Compiler/LCNF/LambdaLifting.lean @@ -29,7 +29,7 @@ structure Context where Declaration where lambda lifting is being applied. We use it to provide the "base name" for auxiliary declarations and the flag `safe`. -/ - mainDecl : Decl + mainDecl : Decl .pure /-- If true, the lambda-lifted functions inherit the inline attribute from `mainDecl`. We use this feature to implement `@[inline] instance ...` and `@[always_inline] instance ...` @@ -51,7 +51,7 @@ structure State where /-- New auxiliary declarations -/ - decls : Array Decl := #[] + decls : Array (Decl .pure) := #[] /-- Next index for generating auxiliary declaration name. -/ @@ -64,13 +64,13 @@ abbrev LiftM := ReaderT Context (StateRefT State (ScopeT CompilerM)) Return `true` if the given declaration takes a local instance as a parameter. We lambda lift this kind of local function declaration before specialization. -/ -def hasInstParam (decl : FunDecl) : CompilerM Bool := +def hasInstParam (decl : FunDecl .pure) : CompilerM Bool := decl.params.anyM fun param => return (← isArrowClass? param.type).isSome /-- Return `true` if the given declaration should be lambda lifted. -/ -def shouldLift (decl : FunDecl) : LiftM Bool := do +def shouldLift (decl : FunDecl .pure) : LiftM Bool := do let minSize := (← read).minSize if decl.value.size < minSize then return false @@ -85,7 +85,7 @@ partial def mkAuxDeclName : LiftM Name := do if (← getDecl? nameNew).isNone then return nameNew mkAuxDeclName -def replaceFunDecl (decl : FunDecl) (value : LetValue) : LiftM LetDecl := do +def replaceFunDecl (decl : FunDecl .pure) (value : LetValue .pure) : LiftM (LetDecl .pure) := do /- We reuse `decl`s `fvarId` to avoid substitution -/ let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value } modifyLCtx fun lctx => lctx.addLetDecl declNew @@ -97,7 +97,7 @@ open Internalize in Create a new auxiliary declaration. The array `closure` contains all free variables occurring in `decl`. -/ -def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do +def mkAuxDecl (closure : Array (Param .pure)) (decl : FunDecl .pure) : LiftM (LetDecl .pure) := do let nameNew ← mkAuxDeclName let inlineAttr? ← if (← read).inheritInlineAttrs then pure (← read).mainDecl.inlineAttr? else pure none let auxDecl ← go nameNew (← read).mainDecl.safe inlineAttr? |>.run' {} @@ -113,16 +113,16 @@ def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do let value := .const auxDeclName us (closure.map (.fvar ·.fvarId)) replaceFunDecl decl value where - go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do + go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM .pure (Decl .pure):= do let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam) let code ← internalizeCode decl.value let type ← code.inferType let type ← mkForallParams params type let value := .code code - let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl } + let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl .pure } return decl.setLevelParams -def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do +def etaContractibleDecl? (decl : FunDecl .pure) : LiftM (Option (LetDecl .pure)) := do if !(← read).allowEtaContraction then return none let .let { fvarId := letVar, value := .const declName us args, .. } (.return retVar) := decl.value | return none @@ -137,11 +137,11 @@ def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do replaceFunDecl decl value mutual - partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do + partial def visitFunDecl (funDecl : FunDecl .pure) : LiftM (FunDecl .pure) := do let value ← withParams funDecl.params <| visitCode funDecl.value funDecl.update' funDecl.type value - partial def visitCode (code : Code) : LiftM Code := do + partial def visitCode (code : Code .pure) : LiftM (Code .pure) := do match code with | .let decl k => let k ← withFVar decl.fvarId <| visitCode k @@ -174,14 +174,14 @@ mutual | .unreach .. | .jmp .. | .return .. => return code end -def main (decl : Decl) : LiftM Decl := do +def main (decl : Decl .pure) : LiftM (Decl .pure) := do let value ← withParams decl.params <| decl.value.mapCodeM visitCode return { decl with value } end LambdaLifting -partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (allowEtaContraction : Bool) - (suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do +partial def Decl.lambdaLifting (decl : Decl .pure) (liftInstParamOnly : Bool) (allowEtaContraction : Bool) + (suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array (Decl .pure)) := do let ctx := { mainDecl := decl, liftInstParamOnly, diff --git a/src/Lean/Compiler/LCNF/Level.lean b/src/Lean/Compiler/LCNF/Level.lean index 1d72f98b93..0b596de831 100644 --- a/src/Lean/Compiler/LCNF/Level.lean +++ b/src/Lean/Compiler/LCNF/Level.lean @@ -105,45 +105,45 @@ open Lean.CollectLevelParams abbrev visitType (type : Expr) : Visitor := visitExpr type -def visitArg (arg : Arg) : Visitor := +def visitArg (arg : Arg .pure) : Visitor := match arg with | .erased | .fvar .. => id - | .type e => visitType e + | .type e _ => visitType e -def visitArgs (args : Array Arg) : Visitor := +def visitArgs (args : Array (Arg .pure)) : Visitor := fun s => args.foldl (init := s) fun s arg => visitArg arg s -def visitLetValue (e : LetValue) : Visitor := +def visitLetValue (e : LetValue .pure) : Visitor := match e with | .erased | .lit .. | .proj .. => id - | .const _ us args => visitLevels us ∘ visitArgs args + | .const _ us args _ => visitLevels us ∘ visitArgs args | .fvar _ args => visitArgs args -def visitParam (p : Param) : Visitor := +def visitParam (p : Param .pure) : Visitor := visitType p.type -def visitParams (ps : Array Param) : Visitor := +def visitParams (ps : Array (Param .pure)) : Visitor := fun s => ps.foldl (init := s) fun s p => visitParam p s mutual - partial def visitAlt (alt : Alt) : Visitor := + partial def visitAlt (alt : Alt .pure) : Visitor := match alt with | .default k => visitCode k - | .alt _ ps k => visitCode k ∘ visitParams ps + | .alt _ ps k _ => visitCode k ∘ visitParams ps - partial def visitAlts (alts : Array Alt) : Visitor := + partial def visitAlts (alts : Array (Alt .pure)) : Visitor := fun s => alts.foldl (init := s) fun s alt => visitAlt alt s - partial def visitCode : Code → Visitor + partial def visitCode : Code .pure → Visitor | .let decl k => visitCode k ∘ visitLetValue decl.value ∘ visitType decl.type - | .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type + | .fun decl k _ | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type | .cases c => visitAlts c.alts ∘ visitType c.resultType | .unreach type => visitType type | .return _ => id | .jmp _ args => visitArgs args end -def visitDeclValue : DeclValue → Visitor +def visitDeclValue : DeclValue .pure → Visitor | .code c => visitCode c | .extern .. => id @@ -156,7 +156,7 @@ open CollectLevelParams Collect universe level parameters collecting in the type, parameters, and value, and then set `decl.levelParams` with the resulting value. -/ -def Decl.setLevelParams (decl : Decl) : Decl := +def Decl.setLevelParams (decl : Decl .pure) : Decl .pure := let levelParams := (visitDeclValue decl.value ∘ visitParams decl.params ∘ visitType decl.type) {} |>.params.toList { decl with levelParams } diff --git a/src/Lean/Compiler/LCNF/Main.lean b/src/Lean/Compiler/LCNF/Main.lean index f4f47a9b33..a7111462e4 100644 --- a/src/Lean/Compiler/LCNF/Main.lean +++ b/src/Lean/Compiler/LCNF/Main.lean @@ -14,6 +14,7 @@ import Lean.Meta.Match.MatcherInfo import Lean.Compiler.LCNF.SplitSCC public import Lean.Compiler.IR.Basic public import Lean.Compiler.LCNF.CompilerM + public section namespace Lean.Compiler.LCNF /-- @@ -50,7 +51,7 @@ A checkpoint in code generation to print all declarations in between compiler passes in order to ease debugging. The trace can be viewed with `set_option trace.Compiler.step true`. -/ -def checkpoint (stepName : Name) (decls : Array Decl) (shouldCheck : Bool) : CompilerM Unit := do +def checkpoint (stepName : Name) (decls : Array (Decl pu)) (shouldCheck : Bool) : CompilerM Unit := do for decl in decls do trace[Compiler.stat] "{decl.name} : {decl.size}" withOptions (fun opts => opts.set `pp.motives.pi false) do @@ -101,12 +102,12 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe let decls := markRecDecls decls let manager ← getPassManager let isCheckEnabled := compiler.check.get (← getOptions) - let decls ← runPassManagerPart "compilation (LCNF base)" manager.basePasses decls isCheckEnabled - let decls ← runPassManagerPart "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled + let decls ← runPassManagerPart .pure .pure "compilation (LCNF base)" manager.basePasses decls isCheckEnabled + let decls ← runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled let sccs ← withTraceNode `Compiler.splitSCC (fun _ => return m!"Splitting up SCC") do splitScc decls sccs.mapM fun decls => do - let decls ← runPassManagerPart "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled + let decls ← runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled if (← Lean.isTracingEnabledFor `Compiler.result) then for decl in decls do let decl ← normalizeFVarIds decl @@ -115,14 +116,19 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe let irDecls ← IR.toIR decls IR.compile irDecls where - runPassManagerPart (profilerName : String) (passes : Array Pass) (decls : Array Decl) - (isCheckEnabled : Bool) : CompilerM (Array Decl) := do + runPassManagerPart (inPhase outPhase : Purity) (profilerName : String) + (passes : Array Pass) (decls : Array (Decl inPhase)) (isCheckEnabled : Bool) : + CompilerM (Array (Decl outPhase)) := do profileitM Exception profilerName (← getOptions) do - let mut decls := decls + let mut state : (pu : Purity) × Array (Decl pu) := ⟨inPhase, decls⟩ for pass in passes do - decls ← withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do - withPhase pass.phase <| pass.run decls - withPhase pass.phaseOut <| checkpoint pass.name decls (isCheckEnabled || pass.shouldAlwaysRunCheck) + state ← withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do + let decls ← withPhase pass.phase do + state.fst.withAssertPurity pass.phase.toPurity fun h => do + pass.run (h ▸ state.snd) + pure ⟨_, decls⟩ + withPhase pass.phaseOut <| checkpoint pass.name state.snd (isCheckEnabled || pass.shouldAlwaysRunCheck) + let decls := state.fst.withAssertPurity outPhase fun h => h ▸ state.snd return decls end PassManager diff --git a/src/Lean/Compiler/LCNF/MonadScope.lean b/src/Lean/Compiler/LCNF/MonadScope.lean index b996c6856c..52d26a8500 100644 --- a/src/Lean/Compiler/LCNF/MonadScope.lean +++ b/src/Lean/Compiler/LCNF/MonadScope.lean @@ -33,7 +33,7 @@ instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool := return (← getScope).contains fvarId -@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α := +@[inline] def withParams [MonadScope m] [Monad m] (ps : Array (Param pu)) (x : m α) : m α := withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x @[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α := diff --git a/src/Lean/Compiler/LCNF/MonoTypes.lean b/src/Lean/Compiler/LCNF/MonoTypes.lean index 60f8373d63..4c65f85b5a 100644 --- a/src/Lean/Compiler/LCNF/MonoTypes.lean +++ b/src/Lean/Compiler/LCNF/MonoTypes.lean @@ -99,4 +99,7 @@ def getOtherDeclMonoType (declName : Name) : CoreM Expr := do monoTypeExt.insert declName type return type +def getOtherDeclImpureType (_declName : Name) : CoreM Expr := do + panic! "Other decl impure type unimplemented" -- TODO + end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/OtherDecl.lean b/src/Lean/Compiler/LCNF/OtherDecl.lean index a917e2268c..ae659fe074 100644 --- a/src/Lean/Compiler/LCNF/OtherDecl.lean +++ b/src/Lean/Compiler/LCNF/OtherDecl.lean @@ -19,5 +19,6 @@ def getOtherDeclType (declName : Name) (us : List Level := []) : CompilerM Expr match (← getPhase) with | .base => getOtherDeclBaseType declName us | .mono => getOtherDeclMonoType declName + | .impure => getOtherDeclImpureType declName end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/PassManager.lean b/src/Lean/Compiler/LCNF/PassManager.lean index 1e1d025f52..514a5b3d5b 100644 --- a/src/Lean/Compiler/LCNF/PassManager.lean +++ b/src/Lean/Compiler/LCNF/PassManager.lean @@ -15,6 +15,20 @@ namespace Lean.Compiler.LCNF @[expose] def Phase.toNat : Phase → Nat | .base => 0 | .mono => 1 + | .impure => 2 + +instance : ToString Phase where + toString + | .base => "base" + | .mono => "mono" + | .impure => "impure" + +def Phase.withPurityCheck [Inhabited α] (pp : Phase) (ip : Purity) + (x : pp.toPurity = ip → α) : α := + if h : pp.toPurity = ip then + x h + else + panic! s!"Compiler error: {pp} is not equivalent to IR phase {ip}, this is a bug" instance : LT Phase where lt l r := l.toNat < r.toNat @@ -60,7 +74,7 @@ structure Pass where /-- The actual pass function, operating on the `Decl`s. -/ - run : Array Decl → CompilerM (Array Decl) + run : Array (Decl phase.toPurity) → CompilerM (Array (Decl phase.toPurity)) instance : Inhabited Pass where default := { phase := .base, name := default, run := fun decls => return decls } @@ -90,14 +104,10 @@ structure PassManager where monoPassesNoLambda : Array Pass deriving Inhabited -instance : ToString Phase where - toString - | .base => "base" - | .mono => "mono" - namespace Pass -def mkPerDeclaration (name : Name) (run : Decl → CompilerM Decl) (phase : Phase) (occurrence : Nat := 0) : Pass where +def mkPerDeclaration (name : Name) (phase : Phase) + (run : Decl phase.toPurity → CompilerM (Decl phase.toPurity)) (occurrence : Nat := 0) : Pass where occurrence := occurrence phase := phase name := name @@ -190,6 +200,7 @@ def run (manager : PassManager) (installer : PassInstaller) : CoreM PassManager return { manager with basePasses := (← installer.install manager.basePasses) } | .mono => return { manager with monoPasses := (← installer.install manager.monoPasses) } + | .impure => panic! "Pass manager support for impure unimplemented" -- TODO private unsafe def getPassInstallerUnsafe (declName : Name) : CoreM PassInstaller := do ofExcept <| (← getEnv).evalConstCheck PassInstaller (← getOptions) ``PassInstaller declName diff --git a/src/Lean/Compiler/LCNF/PhaseExt.lean b/src/Lean/Compiler/LCNF/PhaseExt.lean index 34a84cabb6..738155a99f 100644 --- a/src/Lean/Compiler/LCNF/PhaseExt.lean +++ b/src/Lean/Compiler/LCNF/PhaseExt.lean @@ -45,10 +45,15 @@ private builtin_initialize baseTransparentDeclsExt : EnvExtension (List Name × Set of public declarations whose mono bodies should be exported to other modules -/ private builtin_initialize monoTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt +/-- +Set of public declarations whose impure bodies should be exported to other modules +-/ +private builtin_initialize impureTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt private def getTransparencyExt : Phase → EnvExtension (List Name × NameSet) | .base => baseTransparentDeclsExt | .mono => monoTransparentDeclsExt + | .impure => impureTransparentDeclsExt def isDeclPublic (env : Environment) (declName : Name) : Bool := Id.run do if !env.header.isModule then @@ -81,26 +86,28 @@ def setDeclTransparent (env : Environment) (phase : Phase) (declName : Name) : E getTransparencyExt phase |>.modifyState env fun s => (declName :: s.1, s.2.insert declName) -abbrev DeclExtState := PHashMap Name Decl +abbrev DeclExtState (pu : Purity) := PHashMap Name (Decl pu) -private abbrev declLt (a b : Decl) := +private abbrev declLt (a b : Decl pu) := Name.quickLt a.name b.name -private def sortedDecls (s : DeclExtState) : Array Decl := +private def sortedDecls (s : DeclExtState pu) : Array (Decl pu) := let decls := s.foldl (init := #[]) fun ps _ v => ps.push v decls.qsort declLt -private abbrev findAtSorted? (decls : Array Decl) (declName : Name) : Option Decl := - let tmpDecl : Decl := default +private abbrev findAtSorted? (decls : Array (Decl pu)) (declName : Name) : Option (Decl pu) := + let tmpDecl : Decl pu := default let tmpDecl := { tmpDecl with name := declName } decls.binSearch tmpDecl declLt -@[expose] def DeclExt := PersistentEnvExtension Decl Decl DeclExtState +@[expose] def DeclExt (pu : Purity) := + PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu) -instance : Inhabited DeclExt := - inferInstanceAs (Inhabited (PersistentEnvExtension Decl Decl DeclExtState)) +instance : Inhabited (DeclExt pu) := + inferInstanceAs (Inhabited (PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu))) -def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt := +def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : + IO (DeclExt phase.toPurity) := registerPersistentEnvExtension { name, mkInitial := pure {}, @@ -128,74 +135,77 @@ def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt otherState.insert k v } -builtin_initialize baseExt : DeclExt ← mkDeclExt .base -builtin_initialize monoExt : DeclExt ← mkDeclExt .mono +builtin_initialize baseExt : DeclExt .pure ← mkDeclExt .base +builtin_initialize monoExt : DeclExt .pure ← mkDeclExt .mono +builtin_initialize impureExt : DeclExt .impure ← mkDeclExt .impure -def getDeclCore? (env : Environment) (ext : DeclExt) (declName : Name) : Option Decl := +def getDeclCore? (env : Environment) (ext : DeclExt pu) (declName : Name) : Option (Decl pu) := match env.getModuleIdxFor? declName with | some modIdx => findAtSorted? (ext.getModuleEntries env modIdx) declName | none => ext.getState env |>.find? declName -def getBaseDecl? (declName : Name) : CoreM (Option Decl) := do +def getBaseDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do return getDeclCore? (← getEnv) baseExt declName -def getMonoDecl? (declName : Name) : CoreM (Option Decl) := do +def getMonoDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do return getDeclCore? (← getEnv) monoExt declName -def saveBaseDeclCore (env : Environment) (decl : Decl) : Environment := +def getImpureDecl? (declName : Name) : CoreM (Option (Decl .impure)) := do + return getDeclCore? (← getEnv) impureExt declName + +def saveBaseDeclCore (env : Environment) (decl : Decl .pure) : Environment := baseExt.addEntry env decl -def saveMonoDeclCore (env : Environment) (decl : Decl) : Environment := +def saveMonoDeclCore (env : Environment) (decl : Decl .pure) : Environment := monoExt.addEntry env decl -def Decl.saveBase (decl : Decl) : CoreM Unit := +def saveImpureDeclCore (env : Environment) (decl : Decl .impure) : Environment := + impureExt.addEntry env decl + +def Decl.saveBase (decl : Decl .pure) : CoreM Unit := modifyEnv (saveBaseDeclCore · decl) -def Decl.saveMono (decl : Decl) : CoreM Unit := +def Decl.saveMono (decl : Decl .pure) : CoreM Unit := modifyEnv (saveMonoDeclCore · decl) -def Decl.save (decl : Decl) : CompilerM Unit := do - match (← getPhase) with - | .base => decl.saveBase - | .mono => decl.saveMono +def Decl.saveImpure (decl : Decl .impure) : CoreM Unit := + modifyEnv (saveImpureDeclCore · decl) -def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option Decl) := +def Decl.save (decl : Decl pu) : CompilerM Unit := do + match (← getPhase) with + | .base => Phase.withPurityCheck .base pu fun h => + (h.symm ▸ decl).saveBase + | .mono => Phase.withPurityCheck .mono pu fun h => + (h.symm ▸ decl).saveMono + | .impure => Phase.withPurityCheck .impure pu fun h => + (h.symm ▸ decl).saveImpure + +def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option (Decl phase.toPurity)) := match phase with | .base => getBaseDecl? declName | .mono => getMonoDecl? declName + | .impure => getImpureDecl? declName -def getDecl? (declName : Name) : CompilerM (Option Decl) := do - getDeclAt? declName (← getPhase) +@[inline] +def getDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do + let some decl ← getDeclAt? declName (← getPhase) | return none + return some ⟨_, decl⟩ -def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option Decl) := do +def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option (Decl phase.toPurity)) := do match phase with | .base => return baseExt.getState (← getEnv) |>.find? declName | .mono => return monoExt.getState (← getEnv) |>.find? declName + | .impure => return impureExt.getState (← getEnv) |>.find? declName -def getLocalDecl? (declName : Name) : CompilerM (Option Decl) := do - getLocalDeclAt? declName (← getPhase) +@[inline] +def getLocalDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do + let some decl ← getLocalDeclAt? declName (← getPhase) | return none + return some ⟨_, decl⟩ -def getExt (phase : Phase) : DeclExt := +def getExt (phase : Phase) : DeclExt phase.toPurity := match phase with | .base => baseExt | .mono => monoExt - -def forEachDecl (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do - let ext := getExt phase - let env ← getEnv - for modIdx in *...env.allImportedModuleNames.size do - for decl in ext.getModuleEntries env modIdx do - f decl - ext.getState env |>.forM fun _ decl => f decl - -def forEachModuleDecl (moduleName : Name) (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do - let ext := getExt phase - let env ← getEnv - let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found" - for decl in ext.getModuleEntries env modIdx do - f decl - -def forEachMainModuleDecl (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do - (getExt phase).getState (← getEnv) |>.forM fun _ decl => f decl + | .impure => impureExt end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/PrettyPrinter.lean b/src/Lean/Compiler/LCNF/PrettyPrinter.lean index 6a956be83e..2b967911c8 100644 --- a/src/Lean/Compiler/LCNF/PrettyPrinter.lean +++ b/src/Lean/Compiler/LCNF/PrettyPrinter.lean @@ -43,11 +43,11 @@ def ppFVar (fvarId : FVarId) : M Format := def ppExpr (e : Expr) : M Format := do Meta.ppExpr e |>.run' { lctx := (← read) } -def ppArg (e : Arg) : M Format := do +def ppArg (e : Arg pu) : M Format := do match e with | .erased => return "◾" | .fvar fvarId => ppFVar fvarId - | .type e => + | .type e _ => if pp.explicit.get (← getOptions) then if e.isConst || e.isProp || e.isType0 || e.isFVar then ppExpr e @@ -56,7 +56,7 @@ def ppArg (e : Arg) : M Format := do else return "_" -def ppArgs (args : Array Arg) : M Format := do +def ppArgs (args : Array (Arg pu)) : M Format := do prefixJoin " " args ppArg def ppLitValue (lit : LitValue) : M Format := do @@ -64,49 +64,49 @@ def ppLitValue (lit : LitValue) : M Format := do | .nat v | .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return format v | .str v => return format (repr v) -def ppLetValue (e : LetValue) : M Format := do +def ppLetValue (e : LetValue pu) : M Format := do match e with | .erased => return "◾" | .lit v => ppLitValue v - | .proj _ i fvarId => return f!"{← ppFVar fvarId} # {i}" + | .proj _ i fvarId _ => return f!"{← ppFVar fvarId} # {i}" | .fvar fvarId args => return f!"{← ppFVar fvarId}{← ppArgs args}" - | .const declName us args => return f!"{← ppExpr (.const declName us)}{← ppArgs args}" + | .const declName us args _ => return f!"{← ppExpr (.const declName us)}{← ppArgs args}" -def ppParam (param : Param) : M Format := do +def ppParam (param : Param pu) : M Format := do let borrow := if param.borrow then "@&" else "" if pp.funBinderTypes.get (← getOptions) then return Format.paren f!"{param.binderName} : {borrow}{← ppExpr param.type}" else return format s!"{borrow}{param.binderName}" -def ppParams (params : Array Param) : M Format := do +def ppParams (params : Array (Param pu)) : M Format := do prefixJoin " " params ppParam -def ppLetDecl (letDecl : LetDecl) : M Format := do +def ppLetDecl (letDecl : LetDecl pu) : M Format := do if pp.letVarTypes.get (← getOptions) then return f!"let {letDecl.binderName} : {← ppExpr letDecl.type} := {← ppLetValue letDecl.value}" else return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}" -def getFunType (ps : Array Param) (type : Expr) : CoreM Expr := +def getFunType (ps : Array (Param pu)) (type : Expr) : CoreM Expr := if type.isErased then pure type else instantiateForall type (ps.map (mkFVar ·.fvarId)) mutual - partial def ppFunDecl (funDecl : FunDecl) : M Format := do + partial def ppFunDecl (funDecl : FunDecl pu) : M Format := do return f!"{funDecl.binderName}{← ppParams funDecl.params} : {← ppExpr (← getFunType funDecl.params funDecl.type)} :={indentD (← ppCode funDecl.value)}" - partial def ppAlt (alt : Alt) : M Format := do + partial def ppAlt (alt : Alt pu) : M Format := do match alt with | .default k => return f!"| _ =>{indentD (← ppCode k)}" - | .alt ctorName params k => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}" + | .alt ctorName params k _ => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}" - partial def ppCode (c : Code) : M Format := do + partial def ppCode (c : Code pu) : M Format := do match c with | .let decl k => return (← ppLetDecl decl) ++ ";" ++ .line ++ (← ppCode k) - | .fun decl k => return f!"fun " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k) + | .fun decl k _ => return f!"fun " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k) | .jp decl k => return f!"jp " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k) | .cases c => return f!"cases {← ppFVar c.discr} : {← ppExpr c.resultType}{← prefixJoin .line c.alts ppAlt}" | .return fvarId => return f!"return {← ppFVar fvarId}" @@ -117,7 +117,7 @@ mutual else return "⊥" - partial def ppDeclValue (b : DeclValue) : M Format := do + partial def ppDeclValue (b : DeclValue pu) : M Format := do match b with | .code c => ppCode c | .extern .. => return "extern" @@ -125,21 +125,21 @@ end def run (x : M α) : CompilerM α := withOptions (pp.sanitizeNames.set · false) do - x |>.run (← get).lctx.toLocalContext + x |>.run ((← get).lctx.toLocalContext (← getPurity)) end PP -def ppCode (code : Code) : CompilerM Format := +def ppCode (code : Code pu) : CompilerM Format := PP.run <| PP.ppCode code -def ppLetValue (e : LetValue) : CompilerM Format := +def ppLetValue (e : LetValue pu) : CompilerM Format := PP.run <| PP.ppLetValue e -def ppDecl (decl : Decl) : CompilerM Format := +def ppDecl (decl : Decl pu) : CompilerM Format := PP.run do return f!"def {decl.name}{← PP.ppParams decl.params} : {← PP.ppExpr (← PP.getFunType decl.params decl.type)} :={indentD (← PP.ppDeclValue decl.value)}" -def ppFunDecl (decl : FunDecl) : CompilerM Format := +def ppFunDecl (decl : FunDecl pu) : CompilerM Format := PP.run do return f!"fun {← PP.ppFunDecl decl}" @@ -159,7 +159,7 @@ Similar to `ppDecl`, but in `CoreM`, and it does not assume `decl` has already been internalized. This function is used for debugging purposes. -/ -def ppDecl' (decl : Decl) : CoreM Format := do +def ppDecl' (decl : Decl pu) : CoreM Format := do runCompilerWithoutModifyingState do ppDecl (← decl.internalize) @@ -167,7 +167,7 @@ def ppDecl' (decl : Decl) : CoreM Format := do Similar to `ppCode`, but in `CoreM`, and it does not assume `code` has already been internalized. -/ -def ppCode' (code : Code) : CoreM Format := do +def ppCode' (code : Code pu) : CoreM Format := do runCompilerWithoutModifyingState do ppCode (← code.internalize) diff --git a/src/Lean/Compiler/LCNF/Probing.lean b/src/Lean/Compiler/LCNF/Probing.lean index 05a4727f10..9adff41f9f 100644 --- a/src/Lean/Compiler/LCNF/Probing.lean +++ b/src/Lean/Compiler/LCNF/Probing.lean @@ -26,7 +26,7 @@ def filter (f : α → CompilerM Bool) : Probe α α := fun data => data.filterM def sorted [Inhabited α] [LT α] [DecidableLT α] : Probe α α := fun data => return data.qsort (· < ·) @[inline] -def sortedBySize : Probe Decl (Nat × Decl) := fun decls => +def sortedBySize (pu : Purity) : Probe (Decl pu) (Nat × Decl pu) := fun decls => let decls := decls.map fun decl => (decl.size, decl) return decls.qsort fun (sz₁, decl₁) (sz₂, decl₂) => if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂ @@ -44,116 +44,118 @@ def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := f def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) := countUnique >=> fun data => return data.qsort (fun l r => l.snd < r.snd) -partial def getLetValues : Probe Decl LetValue := fun decls => do +partial def getLetValues (pu : Purity) : Probe (Decl pu) (LetValue pu) := fun decls => do let (_, res) ← start decls |>.run #[] return res where - go (c : Code) : StateRefT (Array LetValue) CompilerM Unit := do + go (c : Code pu) : StateRefT (Array (LetValue pu)) CompilerM Unit := do match c with - | .let (decl : LetDecl) (k : Code) => + | .let decl k => modify fun s => s.push decl.value go k - | .fun decl k | .jp decl k => + | .fun decl k _ | .jp decl k => go decl.value go k | .cases cs => cs.alts.forM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return () - start (decls : Array Decl) : StateRefT (Array LetValue) CompilerM Unit := + start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit := decls.forM (·.value.forCodeM go) -partial def getJps : Probe Decl FunDecl := fun decls => do +partial def getJps (pu : Purity) : Probe (Decl pu) (FunDecl pu) := fun decls => do let (_, res) ← start decls |>.run #[] return res where - go (code : Code) : StateRefT (Array FunDecl) CompilerM Unit := do + go (code : Code pu) : StateRefT (Array (FunDecl pu)) CompilerM Unit := do match code with | .let _ k => go k - | .fun decl k => go decl.value; go k + | .fun decl k _ => go decl.value; go k | .jp decl k => modify (·.push decl); go decl.value; go k | .cases cs => cs.alts.forM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return () - start (decls : Array Decl) : StateRefT (Array FunDecl) CompilerM Unit := + start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit := decls.forM (·.value.forCodeM go) -partial def filterByLet (f : LetDecl → CompilerM Bool) : Probe Decl Decl := +partial def filterByLet (pu : Purity) (f : LetDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let decl k => do if (← f decl) then return true else go k - | .fun decl k | .jp decl k => go decl.value <||> go k + | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false -partial def filterByFun (f : FunDecl → CompilerM Bool) : Probe Decl Decl := +partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k | .jp _ k => go k - | .fun decl k => do if (← f decl) then return true else go decl.value <||> go k + | .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false -partial def filterByJp (f : FunDecl → CompilerM Bool) : Probe Decl Decl := +partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k => go k - | .fun decl k => go decl.value <||> go k + | .fun decl k _ => go decl.value <||> go k | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false -partial def filterByFunDecl (f : FunDecl → CompilerM Bool) : Probe Decl Decl := +partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) : + Probe (Decl pu) (Decl pu):= filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k => go k - | .fun decl k | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k + | .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false -partial def filterByCases (f : Cases → CompilerM Bool) : Probe Decl Decl := +partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k => go k - | .fun decl k | .jp decl k => go decl.value <||> go k + | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false -partial def filterByJmp (f : FVarId → Array Arg → CompilerM Bool) : Probe Decl Decl := +partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) : + Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k => go k - | .fun decl k | .jp decl k => go decl.value <||> go k + | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp fn var => f fn var | .return .. | .unreach .. => return false -partial def filterByReturn (f : FVarId → CompilerM Bool) : Probe Decl Decl := +partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k => go k - | .fun decl k | .jp decl k => go decl.value <||> go k + | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .unreach .. => return false | .return var => f var -partial def filterByUnreach (f : Expr → CompilerM Bool) : Probe Decl Decl := +partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) where - go : Code → CompilerM Bool + go : Code pu → CompilerM Bool | .let _ k => go k - | .fun decl k | .jp decl k => go decl.value <||> go k + | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. => return false | .unreach typ => f typ @[inline] -def declNames : Probe Decl Name := +def declNames (pu : Purity) : Probe (Decl pu) Name := Probe.map (fun decl => return decl.name) @[inline] @@ -172,7 +174,8 @@ def tail (n : Nat) : Probe α α := fun data => return data[(data.size - n)...*] @[inline] def head (n : Nat) : Probe α α := fun data => return data[*...n] -def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do +def runOnDeclsNamed (declNames : Array Name) (phase : Phase := Phase.base) + (probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do let ext := getExt phase let env ← getEnv let decls ← declNames.mapM fun name => do @@ -180,14 +183,15 @@ def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Ph return decl probe decls |>.run (phase := phase) -def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do +def runOnModule (moduleName : Name) (phase : Phase := Phase.base) + (probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do let ext := getExt phase let env ← getEnv let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found" let decls := ext.getModuleEntries env modIdx probe decls |>.run (phase := phase) -def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do +def runGlobally (phase : Phase := Phase.base) (probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do let ext := getExt phase let env ← getEnv let mut decls := #[] @@ -195,7 +199,7 @@ def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (A decls := decls.append <| ext.getModuleEntries env modIdx probe decls |>.run (phase := phase) -def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass where +def toPass [ToString β] (phase : Phase) (probe : Probe (Decl phase.toPurity) β) : Pass where phase := phase name := `probe run := fun decls => do diff --git a/src/Lean/Compiler/LCNF/PullFunDecls.lean b/src/Lean/Compiler/LCNF/PullFunDecls.lean index 25961718ea..9bdb2d8a27 100644 --- a/src/Lean/Compiler/LCNF/PullFunDecls.lean +++ b/src/Lean/Compiler/LCNF/PullFunDecls.lean @@ -19,7 +19,7 @@ Local function declaration and join point being pulled. -/ structure ToPull where isFun : Bool - decl : FunDecl + decl : FunDecl .pure used : FVarIdHashSet deriving Inhabited @@ -50,7 +50,8 @@ where else go as (a :: keep) dep -partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) : PullM (Array ToPull) := do +partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) : + PullM (Array ToPull) := do match todo with | [] => return acc | p :: ps => @@ -65,7 +66,7 @@ partial def findFVarDeps (fvarId : FVarId) : PullM (Array ToPull) := do Similar to `findFVarDeps`. Extract from the state any local function declarations that depends on the given parameters. -/ -def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do +def findParamsDeps (params : Array (Param pu)) : PullM (Array ToPull) := do let mut acc := #[] for param in params do acc := acc ++ (← findFVarDeps param.fvarId) @@ -74,7 +75,7 @@ def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do /-- Construct the code `fun p.decl k` or `jp p.decl k`. -/ -def ToPull.attach (p : ToPull) (k : Code) : Code := +def ToPull.attach (p : ToPull) (k : Code .pure) : Code .pure := if p.isFun then .fun p.decl k else @@ -83,19 +84,19 @@ def ToPull.attach (p : ToPull) (k : Code) : Code := /-- Attach the given array of local function declarations and join points to `k`. -/ -partial def attach (ps : Array ToPull) (k : Code) : Code := Id.run do +partial def attach (ps : Array ToPull) (k : Code .pure) : Code .pure := Id.run do let visited := ps.map fun _ => false let (_, (k, _)) := go |>.run (k, visited) return k where - go : StateM (Code × Array Bool) Unit := do + go : StateM (Code .pure × Array Bool) Unit := do for i in *...ps.size do visit i - visited (i : Nat) : StateM (Code × Array Bool) Bool := + visited (i : Nat) : StateM (Code .pure × Array Bool) Bool := return (← get).2[i]! - visit (i : Nat) : StateM (Code × Array Bool) Unit := do + visit (i : Nat) : StateM (Code .pure × Array Bool) Unit := do unless (← visited i) do modify fun (k, visited) => (k, visited.set! i true) let pi := ps[i]! @@ -110,7 +111,7 @@ where Extract from the state any local function declarations that depends on the given free variable, **and** attach to code `k`. -/ -partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do +partial def attachFVarDeps (fvarId : FVarId) (k : Code .pure) : PullM (Code .pure) := do let ps ← findFVarDeps fvarId return attach ps k @@ -118,11 +119,11 @@ partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do Similar to `attachFVarDeps`. Extract from the state any local function declarations that depends on the given parameters, **and** attach to code `k`. -/ -def attachParamsDeps (params : Array Param) (k : Code) : PullM Code := do +def attachParamsDeps (params : Array (Param .pure)) (k : Code .pure) : PullM (Code .pure) := do let ps ← findParamsDeps params return attach ps k -def attachJps (k : Code) : PullM Code := do +def attachJps (k : Code .pure) : PullM (Code .pure) := do let jps := (← get).filter fun info => !info.isFun modify fun s => s.filter fun info => info.isFun let jps ← findFVarDepsFixpoint jps @@ -132,7 +133,7 @@ mutual /-- Add local function declaration (or join point if `isFun = false`) to the state. -/ -partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do +partial def addToPull (isFun : Bool) (decl : FunDecl .pure) : PullM Unit := do let saved ← get modify fun _ => [] let mut value ← pull decl.value @@ -147,19 +148,19 @@ partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do Pull local function declarations and join points in `code`. The state contains the declarations being pulled. -/ -partial def pull (code : Code) : PullM Code := do +partial def pull (code : Code .pure) : PullM (Code .pure) := do match code with | .let decl k => let k ← pull k let k ← attachFVarDeps decl.fvarId k return code.updateLet! decl k - | .fun decl k => addToPull true decl; pull k + | .fun decl k _ => addToPull true decl; pull k | .jp decl k => addToPull false decl; pull k | .cases c => let alts ← c.alts.mapMonoM fun alt => do match alt with | .default k => return alt.updateCode (← pull k) - | .alt _ ps k => + | .alt _ ps k _ => let k ← pull k let k ← attachParamsDeps ps k return alt.updateCode k @@ -174,13 +175,13 @@ open PullFunDecls /-- Pull local function declarations and join points in the given declaration. -/ -def Decl.pullFunDecls (decl : Decl) : CompilerM Decl := do +def Decl.pullFunDecls (decl : Decl .pure) : CompilerM (Decl .pure) := do let (value, ps) ← decl.value.mapCodeM pull |>.run [] let value := value.mapCode (attach ps.toArray) return { decl with value } def pullFunDecls : Pass := - .mkPerDeclaration `pullFunDecls Decl.pullFunDecls .base + .mkPerDeclaration `pullFunDecls .base Decl.pullFunDecls builtin_initialize registerTraceClass `Compiler.pullFunDecls (inherited := true) diff --git a/src/Lean/Compiler/LCNF/PullLetDecls.lean b/src/Lean/Compiler/LCNF/PullLetDecls.lean index 34d3e41701..4c9f3aedb7 100644 --- a/src/Lean/Compiler/LCNF/PullLetDecls.lean +++ b/src/Lean/Compiler/LCNF/PullLetDecls.lean @@ -15,28 +15,28 @@ namespace Lean.Compiler.LCNF namespace PullLetDecls structure Context where - isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool + isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool included : FVarIdSet := {} structure State where - toPull : Array LetDecl := #[] + toPull : Array (LetDecl .pure) := #[] abbrev PullM := ReaderT Context $ StateRefT State CompilerM @[inline] def withFVar (fvarId : FVarId) (x : PullM α) : PullM α := withReader (fun ctx => { ctx with included := ctx.included.insert fvarId }) x -@[inline] def withParams (ps : Array Param) (x : PullM α) : PullM α := +@[inline] def withParams (ps : Array (Param .pure)) (x : PullM α) : PullM α := withReader (fun ctx => { ctx with included := ps.foldl (init := ctx.included) fun s p => s.insert p.fvarId }) x @[inline] def withNewScope (x : PullM α) : PullM α := withReader (fun ctx => { ctx with included := {} }) x -partial def withCheckpoint (x : PullM Code) : PullM Code := do +partial def withCheckpoint (x : PullM (Code .pure)) : PullM (Code .pure) := do let toPullSizeSaved := (← get).toPull.size let c ← withNewScope x let toPull := (← get).toPull - let rec go (i : Nat) (included : FVarIdSet) : StateM (Array LetDecl) Code := do + let rec go (i : Nat) (included : FVarIdSet) : StateM (Array (LetDecl .pure)) (Code .pure) := do if h : i < toPull.size then let letDecl := toPull[i] if letDecl.dependsOn included then @@ -51,11 +51,11 @@ partial def withCheckpoint (x : PullM Code) : PullM Code := do modify fun s => { s with toPull := s.toPull.shrink toPullSizeSaved ++ keep } return c -def attachToPull (c : Code) : PullM Code := do +def attachToPull (c : Code .pure) : PullM (Code .pure) := do let toPull := (← get).toPull return toPull.foldr (init := c) fun decl c => .let decl c -def shouldPull (decl : LetDecl) : PullM Bool := do +def shouldPull (decl : LetDecl .pure) : PullM Bool := do unless decl.dependsOn (← read).included do if (← (← read).isCandidateFn decl (← read).included) then modify fun s => { s with toPull := s.toPull.push decl } @@ -63,12 +63,12 @@ def shouldPull (decl : LetDecl) : PullM Bool := do return false mutual - partial def pullAlt (alt : Alt) : PullM Alt := + partial def pullAlt (alt : (Alt .pure)) : PullM (Alt .pure) := match alt with | .default k => return alt.updateCode (← withNewScope <| pullDecls k) | .alt _ params k => return alt.updateCode (← withNewScope <| withParams params <| pullDecls k) - partial def pullDecls (code : Code) : PullM Code := do + partial def pullDecls (code : Code .pure) : PullM (Code .pure) := do match code with | .cases c => -- At the present time, we can't correctly enforce the dependencies required for lifting @@ -93,21 +93,21 @@ mutual end -def PullM.run (x : PullM α) (isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool) : CompilerM α := +def PullM.run (x : PullM α) (isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool) : CompilerM α := x { isCandidateFn } |>.run' {} end PullLetDecls open PullLetDecls -def Decl.pullLetDecls (decl : Decl) (isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool) : CompilerM Decl := do +def Decl.pullLetDecls (decl : Decl .pure) (isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool) : CompilerM (Decl .pure) := do PullM.run (isCandidateFn := isCandidateFn) do withParams decl.params do let value ← decl.value.mapCodeM pullDecls let value ← value.mapCodeM attachToPull return { decl with value } -def Decl.pullInstances (decl : Decl) : CompilerM Decl := +def Decl.pullInstances (decl : Decl .pure) : CompilerM (Decl .pure) := decl.pullLetDecls fun letDecl candidates => do -- TODO: Correctly represent these dependencies so this check isn't required. if let .const _ _ args := letDecl.value then @@ -122,7 +122,7 @@ def Decl.pullInstances (decl : Decl) : CompilerM Decl := return false def pullInstances : Pass := - .mkPerDeclaration `pullInstances Decl.pullInstances .base + .mkPerDeclaration `pullInstances .base Decl.pullInstances builtin_initialize registerTraceClass `Compiler.pullInstances (inherited := true) diff --git a/src/Lean/Compiler/LCNF/ReduceArity.lean b/src/Lean/Compiler/LCNF/ReduceArity.lean index f5c4bc03fa..1958690a26 100644 --- a/src/Lean/Compiler/LCNF/ReduceArity.lean +++ b/src/Lean/Compiler/LCNF/ReduceArity.lean @@ -52,7 +52,7 @@ We assume this limitation is irrelevant in practice. namespace FindUsed structure Context where - decl : Decl + decl : Decl .pure params : FVarIdSet structure State where @@ -64,12 +64,12 @@ def visitFVar (fvarId : FVarId) : FindUsedM Unit := do if (← read).params.contains fvarId then modify fun s => { s with used := s.used.insert fvarId } -def visitArg (arg : Arg) : FindUsedM Unit := do +def visitArg (arg : Arg .pure) : FindUsedM Unit := do match arg with | .erased | .type .. => return () | .fvar fvarId => visitFVar fvarId -def visitLetValue (e : LetValue) : FindUsedM Unit := do +def visitLetValue (e : LetValue .pure) : FindUsedM Unit := do match e with | .erased | .lit .. => return () | .proj _ _ fvarId => visitFVar fvarId @@ -93,7 +93,7 @@ def visitLetValue (e : LetValue) : FindUsedM Unit := do else args.forM visitArg -partial def visit (code : Code) : FindUsedM Unit := do +partial def visit (code : Code .pure) : FindUsedM Unit := do match code with | .let decl k => visitLetValue decl.value @@ -107,7 +107,7 @@ partial def visit (code : Code) : FindUsedM Unit := do | .return fvarId => visitFVar fvarId | .unreach _ => return () -def collectUsedParams (decl : Decl) : CompilerM FVarIdHashSet := do +def collectUsedParams (decl : Decl .pure) : CompilerM FVarIdHashSet := do let params := decl.params.foldl (init := {}) fun s p => s.insert p.fvarId let (_, { used, .. }) ← decl.value.forCodeM visit |>.run { decl, params } |>.run {} return used @@ -123,7 +123,7 @@ structure Context where abbrev ReduceM := ReaderT Context CompilerM -partial def reduce (code : Code) : ReduceM Code := do +partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do match code with | .let decl k => let .const declName _ args := decl.value | do return code.updateLet! decl (← reduce k) @@ -148,7 +148,7 @@ end ReduceArity open FindUsed ReduceArity Internalize -def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do +def Decl.reduceArity (decl : Decl .pure) : CompilerM (Array (Decl .pure)) := do match decl.value with | .code code => let used ← collectUsedParams decl @@ -160,7 +160,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}" let mask := decl.params.map fun param => used.contains param.fvarId let auxName := decl.name ++ `_redArg - let mkAuxDecl : CompilerM Decl := do + let mkAuxDecl : CompilerM (Decl .pure) := do let params := decl.params.filter fun param => used.contains param.fvarId let value ← decl.value.mapCodeM reduce |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask } let type ← code.inferType @@ -168,7 +168,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do let auxDecl := { decl with name := auxName, levelParams := [], type, params, value } auxDecl.saveMono return auxDecl - let updateDecl : InternalizeM Decl := do + let updateDecl : InternalizeM .pure (Decl .pure) := do let params ← decl.params.mapM internalizeParam let mut args := #[] for used in mask, param in params do diff --git a/src/Lean/Compiler/LCNF/ReduceJpArity.lean b/src/Lean/Compiler/LCNF/ReduceJpArity.lean index 71dc511ba9..cb7040bf18 100644 --- a/src/Lean/Compiler/LCNF/ReduceJpArity.lean +++ b/src/Lean/Compiler/LCNF/ReduceJpArity.lean @@ -18,7 +18,7 @@ namespace ReduceJpArity abbrev ReduceM := ReaderT (FVarIdMap (Array Bool)) CompilerM -partial def reduce (code : Code) : ReduceM Code := do +partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do match code with | .let decl k => return code.updateLet! decl (← reduce k) | .fun decl k => @@ -69,12 +69,14 @@ open ReduceJpArity /-- Try to reduce arity of join points -/ -def Decl.reduceJpArity (decl : Decl) : CompilerM Decl := do +def Decl.reduceJpArity (decl : Decl .pure) : CompilerM (Decl .pure) := do let value ← decl.value.mapCodeM reduce |>.run {} return { decl with value } +-- TODO: This can be made Purity generic def reduceJpArity (phase := Phase.base) : Pass := - .mkPerDeclaration `reduceJpArity Decl.reduceJpArity phase + phase.withPurityCheck .pure fun h => + .mkPerDeclaration `reduceJpArity phase (h ▸ Decl.reduceJpArity) builtin_initialize registerTraceClass `Compiler.reduceJpArity (inherited := true) diff --git a/src/Lean/Compiler/LCNF/Renaming.lean b/src/Lean/Compiler/LCNF/Renaming.lean index 83a9b8fbeb..ed79204e34 100644 --- a/src/Lean/Compiler/LCNF/Renaming.lean +++ b/src/Lean/Compiler/LCNF/Renaming.lean @@ -16,7 +16,7 @@ A mapping from free variable id to binder name. -/ abbrev Renaming := FVarIdMap Name -def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do +def Param.applyRenaming (param : Param pu) (r : Renaming) : CompilerM (Param pu) := do if let some binderName := r.get? param.fvarId then let param := { param with binderName } modifyLCtx fun lctx => lctx.addParam param @@ -24,7 +24,7 @@ def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do else return param -def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl := do +def LetDecl.applyRenaming (decl : LetDecl pu) (r : Renaming) : CompilerM (LetDecl pu) := do if let some binderName := r.get? decl.fvarId then let decl := { decl with binderName } modifyLCtx fun lctx => lctx.addLetDecl decl @@ -33,7 +33,7 @@ def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl := return decl mutual -partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM FunDecl := do +partial def FunDecl.applyRenaming (decl : (FunDecl pu)) (r : Renaming) : CompilerM (FunDecl pu) := do if let some binderName := r.get? decl.fvarId then let decl := decl.updateBinderName binderName modifyLCtx fun lctx => lctx.addFunDecl decl @@ -41,20 +41,20 @@ partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM Fu else decl.updateValue (← decl.value.applyRenaming r) -partial def Code.applyRenaming (code : Code) (r : Renaming) : CompilerM Code := do +partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code pu) := do match code with | .let decl k => return code.updateLet! (← decl.applyRenaming r) (← k.applyRenaming r) - | .fun decl k | .jp decl k => return code.updateFun! (← decl.applyRenaming r) (← k.applyRenaming r) + | .fun decl k _ | .jp decl k => return code.updateFun! (← decl.applyRenaming r) (← k.applyRenaming r) | .cases c => let alts ← c.alts.mapMonoM fun alt => match alt with | .default k => return alt.updateCode (← k.applyRenaming r) - | .alt _ ps k => return alt.updateAlt! (← ps.mapMonoM (·.applyRenaming r)) (← k.applyRenaming r) + | .alt _ ps k _ => return alt.updateAlt! (← ps.mapMonoM (·.applyRenaming r)) (← k.applyRenaming r) return code.updateAlts! alts | .jmp .. | .unreach .. | .return .. => return code end -def Decl.applyRenaming (decl : Decl) (r : Renaming) : CompilerM Decl := do +def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do if r.isEmpty then return decl else diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index f682238849..6097989580 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -24,7 +24,7 @@ public section namespace Lean.Compiler.LCNF open Simp -def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do +def Decl.simp? (decl : Decl .pure) : SimpM (Option (Decl .pure)) := do let .code code := decl.value | return none updateFunDeclInfo code traceM `Compiler.simp.inline.info do return m!"{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}" @@ -42,7 +42,7 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do else return none -partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do +partial def Decl.simp (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do let mut config := config if (← isTemplateLike decl) then /- @@ -54,7 +54,7 @@ partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do config := { config with etaPoly := false, inlinePartial := false } go decl config where - go (decl : Decl) (config : Config) : CompilerM Decl := do + go (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do if let some decl ← decl.simp? |>.run { config, declName := decl.name } |>.run' {} |>.run {} then -- TODO: bound number of steps? go decl config @@ -62,7 +62,8 @@ where return decl def simp (config : Config := {}) (occurrence : Nat := 0) (phase := Phase.base) : Pass := - .mkPerDeclaration `simp (Decl.simp · config) phase (occurrence := occurrence) + phase.withPurityCheck .pure fun h => + .mkPerDeclaration `simp phase (h ▸ (Decl.simp · config)) (occurrence := occurrence) builtin_initialize registerTraceClass `Compiler.simp (inherited := true) diff --git a/src/Lean/Compiler/LCNF/Simp/Basic.lean b/src/Lean/Compiler/LCNF/Simp/Basic.lean index f1d0ffc3f7..721d078926 100644 --- a/src/Lean/Compiler/LCNF/Simp/Basic.lean +++ b/src/Lean/Compiler/LCNF/Simp/Basic.lean @@ -22,10 +22,10 @@ let _x.2 := _f.1 ``` `findFunDecl? _x.2` returns `none`, but `findFunDecl'? _x.2` returns the declaration for `_f.1`. -/ -partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option FunDecl) := do - if let some decl ← findFunDecl? fvarId then +partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do + if let some decl ← findFunDecl? (pu := pu) fvarId then return decl - else if let some (.fvar fvarId' #[]) ← findLetValue? fvarId then + else if let some (.fvar fvarId' #[]) ← findLetValue? (pu := pu) fvarId then findFunDecl'? fvarId' else return none diff --git a/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean b/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean index 575dee4aa3..dcee2451ba 100644 --- a/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean +++ b/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean @@ -18,14 +18,14 @@ namespace ConstantFold A constant folding monad, the additional state stores auxiliary declarations required to build the new constant. -/ -abbrev FolderM := StateRefT (Array CodeDecl) CompilerM +abbrev FolderM := StateRefT (Array (CodeDecl .pure)) CompilerM /-- A constant folder for a specific function, takes all the arguments of a certain function and produces a new `Expr` + auxiliary declarations in the `FolderM` monad on success. If the folding fails it returns `none`. -/ -abbrev Folder := Array Arg → FolderM (Option LetValue) +abbrev Folder := Array (Arg .pure) → FolderM (Option (LetValue .pure)) /-- A typeclass for detecting and producing literals of arbitrary types @@ -43,7 +43,7 @@ class Literal (α : Type) where final `Expr` putting them all together into a literal of type `α`, where again the idea of what a literal is depends on `α`. -/ - mkLit : α → FolderM LetValue + mkLit : α → FolderM (LetValue .pure) export Literal (getLit mkLit) @@ -51,7 +51,7 @@ export Literal (getLit mkLit) A wrapper around `LCNF.mkAuxLetDecl` that will automatically store the `LetDecl` in the state of `FolderM`. -/ -def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : FolderM FVarId := do +def mkAuxLetDecl (e : LetValue .pure) (prefixName := `_x) : FolderM FVarId := do let decl ← LCNF.mkAuxLetDecl e prefixName modify fun s => s.push <| .let decl return decl.fvarId @@ -66,10 +66,10 @@ def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do mkAuxLetDecl lit prefixName partial def getNatLit (fvarId : FVarId) : CompilerM (Option Nat) := do - let some (.lit (.nat n)) ← findLetValue? fvarId | return none + let some (.lit (.nat n)) ← findLetValue? (pu := .pure) fvarId | return none return n -def mkNatLit (n : Nat) : FolderM LetValue := +def mkNatLit (n : Nat) : FolderM (LetValue .pure) := return .lit (.nat n) instance : Literal Nat where @@ -77,10 +77,10 @@ instance : Literal Nat where mkLit := mkNatLit def getStringLit (fvarId : FVarId) : CompilerM (Option String) := do - let some (.lit (.str s)) ← findLetValue? fvarId | return none + let some (.lit (.str s)) ← findLetValue? (pu := .pure) fvarId | return none return s -def mkStringLit (n : String) : FolderM LetValue := +def mkStringLit (n : String) : FolderM (LetValue .pure) := return .lit (.str n) instance : Literal String where @@ -91,7 +91,7 @@ def getBoolLit (fvarId : FVarId) : CompilerM (Option Bool) := do let some (.const ctor [] #[]) ← findLetValue? fvarId | return none return ctor == ``Bool.true -def mkBoolLit (b : Bool) : FolderM LetValue := +def mkBoolLit (b : Bool) : FolderM (LetValue .pure) := let ctor := if b then ``Bool.true else ``Bool.false return .const ctor [] #[] @@ -115,7 +115,7 @@ instance : Literal Char := mkNatWrapperInstance Char.ofNat ``Char.ofNat Char.toN def mkUIntInstance (matchLit : LitValue → Option α) (litValueCtor : α → LitValue) : Literal α where getLit fvarId := do - let some (.lit litVal) ← findLetValue? fvarId | return none + let some (.lit litVal) ← findLetValue? (pu := .pure) fvarId | return none return matchLit litVal mkLit x := return .lit <| litValueCtor x @@ -162,7 +162,7 @@ let _x.26 := @Array.push _ _x.24 z _x.26 ``` -/ -def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM LetValue := do +def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM (LetValue .pure) := do let sizeLit ← mkAuxLit elements.size let mut literal ← mkAuxLetDecl <| .const ``Array.mkEmpty [typLevel] #[.type typ, .fvar sizeLit] for element in elements do @@ -335,7 +335,7 @@ def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α) -- TODO: add option for controlling the limit def natPowThreshold := 256 -def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do +def foldNatPow (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none let some value₁ ← getNatLit fvarId₁ | return none let some value₂ ← getNatLit fvarId₂ | return none @@ -347,14 +347,14 @@ def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do /-- Folder for ofNat operations on fixed-sized integer types. -/ -def Folder.ofNat (f : Nat → LitValue) (args : Array Arg) : FolderM (Option LetValue) := do +def Folder.ofNat (f : Nat → LitValue) (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do let #[.fvar fvarId] := args | return none let some value ← getNatLit fvarId | return none return some (.lit (f value)) -def Folder.toNat (args : Array Arg) : FolderM (Option LetValue) := do +def Folder.toNat (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do let #[.fvar fvarId] := args | return none - let some (.lit lit) ← findLetValue? fvarId | return none + let some (.lit lit) ← findLetValue? (pu := .pure) fvarId | return none match lit with | .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return some (.lit (.nat v.toNat)) | .nat _ | .str _ => return none @@ -436,7 +436,7 @@ def stringFolders : List (Name × Folder) := [ /-- Apply all known folders to `decl`. -/ -def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do +def applyFolders (decl : LetDecl .pure) (folders : SMap Name Folder) : CompilerM (Option (Array (CodeDecl .pure))) := do match decl.value with | .const name _ args => if let some folder := folders.find? name then @@ -495,7 +495,7 @@ def getFolders : CoreM (SMap Name Folder) := /-- Apply a list of default folders to `decl` -/ -def foldConstants (decl : LetDecl) : CompilerM (Option (Array CodeDecl)) := do +def foldConstants (decl : LetDecl .pure) : CompilerM (Option (Array (CodeDecl .pure))) := do applyFolders decl (← getFolders) end ConstantFold diff --git a/src/Lean/Compiler/LCNF/Simp/DefaultAlt.lean b/src/Lean/Compiler/LCNF/Simp/DefaultAlt.lean index 63774e540c..9b609321e7 100644 --- a/src/Lean/Compiler/LCNF/Simp/DefaultAlt.lean +++ b/src/Lean/Compiler/LCNF/Simp/DefaultAlt.lean @@ -19,7 +19,7 @@ and the number of occurrences. We use this function to decide whether to create a `.default` case or not. -/ -private def getMaxOccs (alts : Array Alt) : Alt × Nat := Id.run do +private def getMaxOccs (alts : Array (Alt .pure)) : Alt .pure × Nat := Id.run do let mut maxAlt := alts[0]! let mut max := getNumOccsOf alts 0 for h : i in 1...alts.size do @@ -35,7 +35,7 @@ where Note that the number of occurrences can be greater than 1 only when the alternative does not depend on field parameters -/ - getNumOccsOf (alts : Array Alt) (i : Nat) : Nat := Id.run do + getNumOccsOf (alts : Array (Alt .pure)) (i : Nat) : Nat := Id.run do let code := alts[i]!.getCode let mut n := 1 for h : j in (i+1)...alts.size do @@ -47,7 +47,7 @@ where Add a default case to the given `cases` alternatives if there are alternatives with equivalent (aka alpha equivalent) right hand sides. -/ -def addDefaultAlt (alts : Array Alt) : SimpM (Array Alt) := do +def addDefaultAlt (alts : Array (Alt .pure)) : SimpM (Array (Alt .pure)) := do if alts.size <= 1 || alts.any (· matches .default ..) then return alts else diff --git a/src/Lean/Compiler/LCNF/Simp/DiscrM.lean b/src/Lean/Compiler/LCNF/Simp/DiscrM.lean index 33b24cbfe9..8cba16ab6e 100644 --- a/src/Lean/Compiler/LCNF/Simp/DiscrM.lean +++ b/src/Lean/Compiler/LCNF/Simp/DiscrM.lean @@ -15,7 +15,7 @@ namespace Lean.Compiler.LCNF namespace Simp inductive CtorInfo where - | ctor (val : ConstructorVal) (args : Array Arg) + | ctor (val : ConstructorVal) (args : Array (Arg .pure)) | /-- Natural numbers are morally constructor applications -/ natVal (n : Nat) @@ -70,7 +70,7 @@ def findCtorName? (fvarId : FVarId) : DiscrM (Option Name) := do /-- If `type` is an application of the inductive type `ind`, return its universe levels and parameters. -/ -def getIndInfo? (type : Expr) (ind : Name) : CoreM (Option (List Level × Array Arg)) := do +def getIndInfo? (type : Expr) (ind : Name) : CoreM (Option (List Level × Array (Arg .pure))) := do let type := type.headBeta let .const declName us := type.getAppFn | return none unless declName == ind do return none @@ -85,7 +85,8 @@ def getIndInfo? (type : Expr) (ind : Name) : CoreM (Option (List Level × Array Execute `x` with the information that `discr = ctorName ctorFields`. We use this information to simplify nested cases on the same discriminant. -/ -@[inline] def withDiscrCtorImp (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : DiscrM α) : DiscrM α := do +@[inline] def withDiscrCtorImp (discr : FVarId) (ctorName : Name) + (ctorFields : Array (Param .pure)) (x : DiscrM α) : DiscrM α := do let ctx ← updateCtx withReader (fun _ => ctx) x where @@ -103,7 +104,9 @@ where let ctorInfo := .ctor ctorVal (.replicate ctorVal.numParams Arg.erased ++ fieldArgs) return { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctorInfo } -@[inline, inherit_doc withDiscrCtorImp] def withDiscrCtor [MonadFunctorT DiscrM m] (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) : m α → m α := +@[inline, inherit_doc withDiscrCtorImp] +def withDiscrCtor [MonadFunctorT DiscrM m] (discr : FVarId) (ctorName : Name) + (ctorFields : Array (Param .pure)) : m α → m α := monadMap (m := DiscrM) <| withDiscrCtorImp discr ctorName ctorFields def simpCtorDiscrCore? (e : Expr) : DiscrM (Option FVarId) := do diff --git a/src/Lean/Compiler/LCNF/Simp/FunDeclInfo.lean b/src/Lean/Compiler/LCNF/Simp/FunDeclInfo.lean index 5939a5a8ad..53d7b7274a 100644 --- a/src/Lean/Compiler/LCNF/Simp/FunDeclInfo.lean +++ b/src/Lean/Compiler/LCNF/Simp/FunDeclInfo.lean @@ -94,27 +94,27 @@ If `mustInline := true`, then all local function declarations occurring in `code` are tagged as `.mustInline`. Recall that we use `.mustInline` for local function declarations occurring in type class instances. -/ -partial def FunDeclInfoMap.update (s : FunDeclInfoMap) (code : Code) (mustInline := false) : CompilerM FunDeclInfoMap := do +partial def FunDeclInfoMap.update (s : FunDeclInfoMap) (code : Code .pure) (mustInline := false) : CompilerM FunDeclInfoMap := do let (_, s) ← go code |>.run s return s where - addArgOcc (arg : Arg) : StateRefT FunDeclInfoMap CompilerM Unit := do + addArgOcc (arg : Arg .pure) : StateRefT FunDeclInfoMap CompilerM Unit := do match arg with | .fvar fvarId => - let some funDecl ← findFunDecl'? fvarId | return () + let some funDecl ← findFunDecl'? (pu := .pure) fvarId | return () modify fun s => s.addHo funDecl.fvarId | .erased .. | .type .. => return () - addLetValueOccs (e : LetValue) : StateRefT FunDeclInfoMap CompilerM Unit := do + addLetValueOccs (e : LetValue .pure) : StateRefT FunDeclInfoMap CompilerM Unit := do match e with | .erased | .lit .. | .proj .. => return () | .const _ _ args => args.forM addArgOcc | .fvar fvarId args => - let some funDecl ← findFunDecl'? fvarId | return () + let some funDecl ← findFunDecl'? (pu := .pure) fvarId | return () modify fun s => s.add funDecl.fvarId args.forM addArgOcc - go (code : Code) : StateRefT FunDeclInfoMap CompilerM Unit := do + go (code : Code .pure) : StateRefT FunDeclInfoMap CompilerM Unit := do match code with | .let decl k => addLetValueOccs decl.value @@ -126,7 +126,7 @@ where | .jp decl k => go decl.value; go k | .cases c => c.alts.forM fun alt => go alt.getCode | .jmp fvarId args => - let funDecl ← getFunDecl fvarId + let funDecl ← getFunDecl (pu := .pure) fvarId modify fun s => s.add funDecl.fvarId args.forM addArgOcc | .return .. | .unreach .. => return () diff --git a/src/Lean/Compiler/LCNF/Simp/InlineCandidate.lean b/src/Lean/Compiler/LCNF/Simp/InlineCandidate.lean index df746b42a9..d84ce49f4c 100644 --- a/src/Lean/Compiler/LCNF/Simp/InlineCandidate.lean +++ b/src/Lean/Compiler/LCNF/Simp/InlineCandidate.lean @@ -19,11 +19,11 @@ It contains information for inlining local and global functions. -/ structure InlineCandidateInfo where isLocal : Bool - params : Array Param + params : Array (Param .pure) /-- Value (lambda expression) of the function to be inlined. -/ - value : Code + value : Code .pure fType : Expr - args : Array Arg + args : Array (Arg .pure) /-- `ifReduce = true` if the declaration being inlined was tagged with `inline_if_reduce`. -/ ifReduce : Bool /-- `recursive = true` if the declaration being inline is in a mutually recursive block. -/ @@ -36,7 +36,7 @@ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat /-- Return `some info` if `e` should be inlined. -/ -def inlineCandidate? (e : LetValue) : SimpM (Option InlineCandidateInfo) := do +def inlineCandidate? (e : LetValue .pure) : SimpM (Option InlineCandidateInfo) := do let mut e := e let mut mustInline := false if let .const ``inline _ #[_, .fvar argFVarId] := e then @@ -46,7 +46,7 @@ def inlineCandidate? (e : LetValue) : SimpM (Option InlineCandidateInfo) := do if let .const declName us args := e then unless (← read).config.inlineDefs do return none - let some decl ← getDecl? declName | return none + let some ⟨.pure, decl⟩ ← getDecl? declName | return none let .code code := decl.value | return none let shouldInline : SimpM Bool := do if !decl.inlineIfReduceAttr && decl.recursive then return false diff --git a/src/Lean/Compiler/LCNF/Simp/InlineProj.lean b/src/Lean/Compiler/LCNF/Simp/InlineProj.lean index 94cdfde4a6..14369e50f1 100644 --- a/src/Lean/Compiler/LCNF/Simp/InlineProj.lean +++ b/src/Lean/Compiler/LCNF/Simp/InlineProj.lean @@ -39,7 +39,7 @@ and the free variable containing the result (`FVarId`). The resulting `FVarId` o subset of `Array CodeDecl`. However, this method does try to filter the relevant ones. We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`. -/ -partial def inlineProjInst? (e : LetValue) : SimpM (Option (Array CodeDecl × FVarId)) := do +partial def inlineProjInst? (e : LetValue .pure) : SimpM (Option (Array (CodeDecl .pure) × FVarId)) := do let .proj _ i s := e | return none let sType ← getType s unless (← isClass? sType).isSome do return none @@ -52,7 +52,7 @@ partial def inlineProjInst? (e : LetValue) : SimpM (Option (Array CodeDecl × FV eraseCodeDecls decls return none where - visit (fvarId : FVarId) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do + visit (fvarId : FVarId) (projs : List Nat) : OptionT (StateRefT (Array (CodeDecl .pure)) SimpM) FVarId := do let some letDecl ← findLetDecl? fvarId | failure match letDecl.value with | .proj _ i s => visit s (i :: projs) @@ -72,7 +72,7 @@ where else visit fvarId projs else - let some decl ← getDecl? declName | failure + let some ⟨.pure, decl⟩ ← getDecl? declName | failure match decl.value with | .code code => guard (!decl.recursive && decl.getArity == args.size) @@ -82,7 +82,7 @@ where visitCode code projs | .extern .. => failure - visitCode (code : Code) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do + visitCode (code : Code .pure) (projs : List Nat) : OptionT (StateRefT (Array (CodeDecl .pure)) SimpM) FVarId := do match code with | .let decl k => modify (·.push (.let decl)); visitCode k projs | .fun decl k => modify (·.push (.fun decl)); visitCode k projs diff --git a/src/Lean/Compiler/LCNF/Simp/JpCases.lean b/src/Lean/Compiler/LCNF/Simp/JpCases.lean index 6be93d9fb3..4c312c7911 100644 --- a/src/Lean/Compiler/LCNF/Simp/JpCases.lean +++ b/src/Lean/Compiler/LCNF/Simp/JpCases.lean @@ -26,12 +26,12 @@ f y := ``` `idx` is the index of the parameter used in the `cases` statement. -/ -def isJpCases? (decl : FunDecl) : CompilerM (Option Nat) := do +def isJpCases? (decl : FunDecl .pure) : CompilerM (Option Nat) := do if decl.params.size == 0 then return none else let small := (← getConfig).smallThreshold - let rec go (code : Code) (prefixSize : Nat) : Option Nat := + let rec go (code : Code .pure) (prefixSize : Nat) : Option Nat := if prefixSize > small then none else match code with | .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/ @@ -64,11 +64,11 @@ in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor name there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructor application. `paramIdx` is the index of the parameter -/ -partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do +partial def collectJpCasesInfo (code : Code .pure) : CompilerM JpCasesInfoMap := do let (_, s) ← go code |>.run {} |>.run {} return s where - go (code : Code) : StateRefT JpCasesInfoMap DiscrM Unit := do + go (code : Code .pure) : StateRefT JpCasesInfoMap DiscrM Unit := do match code with | .let _ k => go k | .fun decl k => go decl.value; go k @@ -90,17 +90,17 @@ where /-- Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases?`. -/ -private def extractJpCases (code : Code) : Array CodeDecl × Cases := +private def extractJpCases (code : Code .pure) : Array (CodeDecl .pure) × Cases .pure := go code #[] where - go (code : Code) (decls : Array CodeDecl) := + go (code : Code .pure) (decls : Array (CodeDecl .pure)) := match code with | .let decl k => go k <| decls.push (.let decl) | .cases c => (decls, c) | _ => unreachable! -- `code` is not the body of a join point that satisfies `isJpCases` structure JpCasesAlt where - decl : FunDecl + decl : FunDecl .pure default : Bool dependsOnDiscr : Bool @@ -116,10 +116,12 @@ Construct an auxiliary join point for a particular alternative in a join-point t - `k` is the body of the alternative. - `default` is true if it is a default alternative. -/ -private def mkJpAlt (decls : Array CodeDecl) (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do +private def mkJpAlt (decls : Array (CodeDecl .pure)) (params : Array (Param .pure)) + (targetParamIdx : Nat) (fields : Array (Param .pure)) (k : Code .pure) (default : Bool) : + CompilerM JpCasesAlt := do go |>.run' {} where - go : InternalizeM JpCasesAlt := do + go : InternalizeM .pure JpCasesAlt := do let mut paramsNew := #[] let singleton : FVarIdSet := ({} : FVarIdSet).insert params[targetParamIdx]!.fvarId let dependsOnDiscr := k.dependsOn singleton || decls.any (·.dependsOn singleton) @@ -137,7 +139,8 @@ where return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr } /-- Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. -/ -private def mkJmpNewArgs (args : Array Arg) (targetParamIdx : Nat) (fields : Array Arg) (dependsOnTarget : Bool) : Array Arg := +private def mkJmpNewArgs (args : Array (Arg .pure)) (targetParamIdx : Nat) + (fields : Array (Arg .pure)) (dependsOnTarget : Bool) : Array (Arg .pure) := if dependsOnTarget then args[*...=targetParamIdx] ++ fields ++ args[targetParamIdx<...*] else @@ -147,7 +150,8 @@ private def mkJmpNewArgs (args : Array Arg) (targetParamIdx : Nat) (fields : Arr Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. This function is used to create jumps from the join point satisfying `isJpCases?` to the new auxiliary join points created using `mkJpAlt`. -/ -private def mkJmpArgsAtJp (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (dependsOnTarget : Bool) : Array Arg := Id.run do +private def mkJmpArgsAtJp (params : Array (Param .pure)) (targetParamIdx : Nat) + (fields : Array (Param .pure)) (dependsOnTarget : Bool) : Array (Arg .pure) := Id.run do mkJmpNewArgs (params.map (Arg.fvar ·.fvarId)) targetParamIdx (fields.map (Arg.fvar ·.fvarId)) dependsOnTarget /-- @@ -194,7 +198,7 @@ cases x.4 Note that if all jumps to the join point are with constructors, then the join point is eliminated as dead code. -/ -partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do +partial def simpJpCases? (code : Code .pure) : CompilerM (Option (Code .pure)) := do let map ← collectJpCasesInfo code unless map.isCandidate do return none traceM `Compiler.simp.jpCases do @@ -204,7 +208,7 @@ partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do return msg visit code map |>.run' {} |>.run {} where - visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) Code := do + visit (code : Code .pure) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Code .pure) := do match code with | .let decl k => return code.updateLet! decl (← visit k) @@ -232,7 +236,8 @@ where let some code ← visitJmp? fvarId args | return code return code - visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do + visitJp? (decl : FunDecl .pure) (k : Code .pure) : + ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option (Code .pure)) := do let some info := (← read).get? decl.fvarId | return none if info.ctorNames.isEmpty then return none -- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it. @@ -273,7 +278,8 @@ where let code := .jp decl (← visit k) return LCNF.attachCodeDecls jpAltDecls code - visitJmp? (fvarId : FVarId) (args : Array Arg) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do + visitJmp? (fvarId : FVarId) (args : Array (Arg .pure)) : + ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option (Code .pure)) := do let some ctorJpAltMap := (← get).get? fvarId | return none let some info := (← read).get? fvarId | return none let .fvar argFVarId := args[info.paramIdx]! | return none diff --git a/src/Lean/Compiler/LCNF/Simp/Main.lean b/src/Lean/Compiler/LCNF/Simp/Main.lean index 846f45decf..abd12f151d 100644 --- a/src/Lean/Compiler/LCNF/Simp/Main.lean +++ b/src/Lean/Compiler/LCNF/Simp/Main.lean @@ -25,10 +25,10 @@ such as: a `cases` with many but only one alternative is not reachable. It is only used to avoid the creation of auxiliary join points, and does not need to be precise. -/ -private partial def oneExitPointQuick (c : Code) : Bool := +private partial def oneExitPointQuick (c : Code .pure) : Bool := go c where - go (c : Code) : Bool := + go (c : Code .pure) : Bool := match c with | .let _ k | .fun _ k => go k -- Approximation, the cases may have many unreachable alternatives, and only reachable. @@ -41,7 +41,7 @@ where Create a new local function declaration when `info.args.size < info.params.size`. We use this function to inline/specialize a partial application of a local function. -/ -def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do +def specializePartialApp (info : InlineCandidateInfo) : SimpM (FunDecl .pure) := do let mut subst := {} for param in info.params, arg in info.args do subst := subst.insert param.fvarId arg @@ -58,7 +58,7 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do /-- Try to inline a join point. -/ -partial def inlineJp? (fvarId : FVarId) (args : Array Arg) : SimpM (Option Code) := do +partial def inlineJp? (fvarId : FVarId) (args : Array (Arg .pure)) : SimpM (Option (Code .pure)) := do /- Remark: we don't need to use `findFunDecl'?` here. -/ let some decl ← findFunDecl? fvarId | return none unless (← shouldInlineLocal decl) do return none @@ -71,13 +71,13 @@ partial applications of functions that take local instances as arguments. This kind of function is inlined or specialized, and we create new simplification opportunities by eta-expanding them. -/ -def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do +def etaPolyApp? (letDecl : LetDecl .pure) : OptionT SimpM (FunDecl .pure) := do guard <| (← read).config.etaPoly let .const declName us args := letDecl.value | failure let some info := (← getEnv).find? declName | failure guard <| (← hasLocalInst info.type) guard <| !(← Meta.isInstance declName) - let some decl ← getDecl? declName | failure + let some ⟨.pure, decl⟩ ← getDecl? declName | failure guard <| decl.getArity > args.size let params ← mkNewParams letDecl.type let auxDecl ← mkAuxLetDecl (.const declName us (args ++ params.map (.fvar ·.fvarId))) @@ -89,14 +89,14 @@ def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do /-- Similar to `Code.isReturnOf`, but taking the current substitution into account. -/ -def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do +def isReturnOf (c : Code .pure) (fvarId : FVarId) : SimpM Bool := do match c with | .return fvarId' => match (← normFVar fvarId') with | .fvar fvarId'' => return fvarId'' == fvarId | .erased => return false | _ => return false -def elimVar? (value : LetValue) : SimpM (Option FVarId) := do +def elimVar? (value : LetValue .pure) : SimpM (Option FVarId) := do let .fvar fvarId #[] := value | return none return fvarId @@ -117,7 +117,7 @@ of exit points by simplified the inlined code, and then connecting the result to continuation `k`. However, this optimization is only possible if we simplify the inlined code **before** we attach it to the continuation. -/ -partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do +partial def inlineApp? (letDecl : LetDecl .pure) (k : Code .pure) : SimpM (Option (Code .pure)) := do let some info ← inlineCandidate? letDecl.value | return none let numArgs := info.args.size withInlining letDecl.value info.recursive do @@ -135,7 +135,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d simp code else let code ← simp code - let simpK (result : FVarId) : SimpM Code := do + let simpK (result : FVarId) : SimpM (Code .pure) := do /- `result` contains the result of the inlined code -/ if numArgs > info.arity then let decl ← mkAuxLetDecl (.fvar result info.args[info.arity...*]) @@ -151,7 +151,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d markUsedFVar fvarId' simpK fvarId' else - let expectedType ← inferAppType info.fType info.args[*...info.arity] + let expectedType ← inferAppType info.fType (info.args[*...info.arity]).toArray if expectedType.headBeta.isForall then /- If `code` returns a function, we create an auxiliary local function declaration (and eta-expand it) @@ -171,7 +171,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d /-- Simplify the given local function declaration. -/ -partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do +partial def simpFunDecl (decl : FunDecl .pure) : SimpM (FunDecl .pure) := do let type ← normExpr decl.type let params ← normParams decl.params let value ← simp decl.value @@ -180,9 +180,11 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do /-- Try to simplify `cases` of `constructor` -/ -partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do +partial def simpCasesOnCtor? (cases : Cases .pure) : SimpM (Option (Code .pure)) := do match (← normFVar cases.discr) with - | .erased => mkReturnErased + | .erased => + let ret ← mkReturnErased + return some ret | .fvar discr => let some ctorInfo ← findCtor? discr | return none let (alt, cases) := cases.extractAlt! ctorInfo.getName @@ -210,7 +212,7 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do /-- Simplify `code` -/ -partial def simp (code : Code) : SimpM Code := withIncRecDepth do +partial def simp (code : Code .pure) : SimpM (Code .pure) := withIncRecDepth do incVisited match code with | .let decl k => @@ -225,7 +227,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do -- and `FVarId` rather than `Arg`, and the substitution will end up -- creating a new erased let decl in that case. if decl.type.isErased && decl.value != .erased then - addSubst decl.fvarId .erased + addSubst decl.fvarId (.erased : Arg .pure) eraseLetDecl decl simp k else if let some decls ← ConstantFold.foldConstants decl then diff --git a/src/Lean/Compiler/LCNF/Simp/SimpM.lean b/src/Lean/Compiler/LCNF/Simp/SimpM.lean index 156c241139..1c15fa72cb 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpM.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpM.lean @@ -41,7 +41,7 @@ structure State where /-- Free variable substitution. We use it to implement inlining and removing redundant variables `let _x.i := _x.j` -/ - subst : FVarSubst := {} + subst : FVarSubst .pure := {} /-- Track used local declarations to be able to eliminate dead variables. -/ @@ -80,10 +80,10 @@ abbrev SimpM := ReaderT Context $ StateRefT State DiscrM @[always_inline] instance : Monad SimpM := let i := inferInstanceAs (Monad SimpM); { pure := i.pure, bind := i.bind } -instance : MonadFVarSubst SimpM false where +instance : MonadFVarSubst SimpM .pure false where getSubst := return (← get).subst -instance : MonadFVarSubstState SimpM where +instance : MonadFVarSubstState SimpM .pure where modifySubst f := modify fun s => { s with subst := f s.subst } /-- Set the `simplified` flag to `true`. -/ @@ -115,7 +115,7 @@ def addFunHoOcc (fvarId : FVarId) : SimpM Unit := modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addHo fvarId } @[inherit_doc FunDeclInfoMap.update] -partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit := do +partial def updateFunDeclInfo (code : Code .pure) (mustInline := false) : SimpM Unit := do let map ← modifyGet fun s => (s.funDeclInfoMap, { s with funDeclInfoMap := {} }) let map ← map.update code mustInline modify fun s => { s with funDeclInfoMap := map } @@ -124,7 +124,7 @@ partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit : Execute `x` with an updated `inlineStack`. If `value` is of the form `const ...`, add `const` to the stack. Otherwise, do not change the `inlineStack`. -/ -@[inline] def withInlining (value : LetValue) (recursive : Bool) (x : SimpM α) : SimpM α := do +@[inline] def withInlining (value : LetValue .pure) (recursive : Bool) (x : SimpM α) : SimpM α := do if let .const declName _ _ := value then let numOccs ← check declName withReader (fun ctx => { ctx with inlineStack := declName :: ctx.inlineStack, inlineStackOccs := ctx.inlineStackOccs.insert declName numOccs }) x @@ -135,7 +135,7 @@ where trace[Compiler.simp.inline] "{.ofConstName declName}" let numOccs := (← read).inlineStackOccs.find? declName |>.getD 0 let numOccs := numOccs + 1 - let inlineIfReduce ← if let some decl ← getDecl? declName then pure decl.inlineIfReduceAttr else pure false + let inlineIfReduce ← if let some ⟨_, decl⟩ ← getDecl? declName then pure decl.inlineIfReduceAttr else pure false if recursive && inlineIfReduce && numOccs > (← getConfig).maxRecInlineIfReduce then throwError "function `{.ofConstName declName}` has been recursively inlined more than #{(← getConfig).maxRecInlineIfReduce}, consider removing the attribute `[inline_if_reduce]` from this declaration or increasing the limit using `set_option compiler.maxRecInlineIfReduce `" return numOccs @@ -193,13 +193,13 @@ def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do /-- Return `true` if the given code is considered "small". -/ -def isSmall (code : Code) : SimpM Bool := +def isSmall (code : Code .pure) : SimpM Bool := return code.sizeLe (← getConfig).smallThreshold /-- Return `true` if the given local function declaration should be inlined. -/ -def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do +def shouldInlineLocal (decl : FunDecl .pure) : SimpM Bool := do if (← isOnceOrMustInline decl.fvarId) then return true else @@ -210,7 +210,8 @@ LCNF "Beta-reduce". The equivalent of `(fun params => code) args`. If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`. See comment at `updateFunDeclInfo`. -/ -def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInline := false) : SimpM Code := do +def betaReduce (params : Array (Param .pure)) (code : Code .pure) (args : Array (Arg .pure)) + (mustInline := false) : SimpM (Code .pure) := do let mut subst := {} for param in params, arg in args do subst := subst.insert param.fvarId arg @@ -222,7 +223,7 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInli Erase the given let-declaration from the local context, and set the `simplified` flag to true. -/ -def eraseLetDecl (decl : LetDecl) : SimpM Unit := do +def eraseLetDecl (decl : LetDecl .pure) : SimpM Unit := do LCNF.eraseLetDecl decl markSimplified @@ -230,7 +231,7 @@ def eraseLetDecl (decl : LetDecl) : SimpM Unit := do Erase the given local function declaration from the local context, and set the `simplified` flag to true. -/ -def eraseFunDecl (decl : FunDecl) : SimpM Unit := do +def eraseFunDecl (decl : FunDecl .pure) : SimpM Unit := do LCNF.eraseFunDecl decl markSimplified diff --git a/src/Lean/Compiler/LCNF/Simp/SimpValue.lean b/src/Lean/Compiler/LCNF/Simp/SimpValue.lean index 3801f62330..0b63db8414 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpValue.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpValue.lean @@ -16,7 +16,7 @@ namespace Simp /-- Try to simplify projections `.proj _ i s` where `s` is constructor. -/ -def simpProj? (e : LetValue) : OptionT SimpM LetValue := do +def simpProj? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do let .proj _ i s := e | failure let some ctorInfo ← findCtor? s | failure match ctorInfo with @@ -31,7 +31,7 @@ g b ``` is simplified to `f a b`. -/ -def simpAppApp? (e : LetValue) : OptionT SimpM LetValue := do +def simpAppApp? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do let .fvar g args := e | failure let some decl ← findLetDecl? g | failure match decl.value with @@ -46,19 +46,19 @@ def simpAppApp? (e : LetValue) : OptionT SimpM LetValue := do | .erased => return .erased | .proj .. | .lit .. => failure -def simpCtorDiscr? (e : LetValue) : OptionT SimpM LetValue := do +def simpCtorDiscr? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do let .const declName _ _ := e | failure let some (.ctorInfo _) := (← getEnv).find? declName | failure let some fvarId ← simpCtorDiscrCore? e.toExpr | failure return .fvar fvarId #[] -def applyImplementedBy? (e : LetValue) : OptionT SimpM LetValue := do +def applyImplementedBy? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do guard <| (← read).config.implementedBy let .const declName us args := e | failure let some declNameNew := getImplementedBy? (← getEnv) declName | failure return .const declNameNew us args /-- Try to apply simple simplifications. -/ -def simpValue? (e : LetValue) : SimpM (Option LetValue) := +def simpValue? (e : LetValue .pure) : SimpM (Option (LetValue .pure)) := -- TODO: more simplifications simpProj? e <|> simpAppApp? e <|> simpCtorDiscr? e <|> applyImplementedBy? e diff --git a/src/Lean/Compiler/LCNF/Simp/Used.lean b/src/Lean/Compiler/LCNF/Simp/Used.lean index 30dfcf7c87..1718b0038c 100644 --- a/src/Lean/Compiler/LCNF/Simp/Used.lean +++ b/src/Lean/Compiler/LCNF/Simp/Used.lean @@ -23,7 +23,7 @@ def markUsedFVar (fvarId : FVarId) : SimpM Unit := /-- Mark all free variables occurring in `arg` as used. -/ -def markUsedArg (arg : Arg) : SimpM Unit := +def markUsedArg (arg : Arg .pure) : SimpM Unit := match arg with | .fvar fvarId => markUsedFVar fvarId -- Locally declared variables do not occur in types. @@ -32,7 +32,7 @@ def markUsedArg (arg : Arg) : SimpM Unit := /-- Mark all free variables occurring in `e` as used. -/ -def markUsedLetValue (e : LetValue) : SimpM Unit := do +def markUsedLetValue (e : LetValue .pure) : SimpM Unit := do match e with | .lit .. | .erased => return () | .proj _ _ fvarId => markUsedFVar fvarId @@ -43,14 +43,14 @@ def markUsedLetValue (e : LetValue) : SimpM Unit := do Mark all free variables occurring on the right-hand side of the given let declaration as used. This is information is used to eliminate dead local declarations. -/ -def markUsedLetDecl (letDecl : LetDecl) : SimpM Unit := +def markUsedLetDecl (letDecl : LetDecl .pure) : SimpM Unit := markUsedLetValue letDecl.value mutual /-- Mark all free variables occurring in `code` as used. -/ -partial def markUsedCode (code : Code) : SimpM Unit := do +partial def markUsedCode (code : Code .pure) : SimpM Unit := do match code with | .let decl k => markUsedLetDecl decl; markUsedCode k | .jp decl k | .fun decl k => markUsedFunDecl decl; markUsedCode k @@ -62,7 +62,7 @@ partial def markUsedCode (code : Code) : SimpM Unit := do /-- Mark all free variables occurring in `funDecl` as used. -/ -partial def markUsedFunDecl (funDecl : FunDecl) : SimpM Unit := +partial def markUsedFunDecl (funDecl : FunDecl .pure) : SimpM Unit := markUsedCode funDecl.value end @@ -81,10 +81,10 @@ let _x.2 := true ``` -/ -def attachCodeDecls (decls : Array CodeDecl) (code : Code) : SimpM Code := do +def attachCodeDecls (decls : Array (CodeDecl .pure)) (code : Code .pure) : SimpM (Code .pure) := do go decls.size code where - go (i : Nat) (code : Code) : SimpM Code := do + go (i : Nat) (code : Code .pure) : SimpM (Code .pure) := do if i > 0 then let decl := decls[i-1]! if (← isUsed decl.fvarId) then diff --git a/src/Lean/Compiler/LCNF/SpecInfo.lean b/src/Lean/Compiler/LCNF/SpecInfo.lean index aa1a6c2e5e..5936d07e3c 100644 --- a/src/Lean/Compiler/LCNF/SpecInfo.lean +++ b/src/Lean/Compiler/LCNF/SpecInfo.lean @@ -157,7 +157,7 @@ and `k` is tagged as `.user`, `.fixedHO`, or `.fixedInst`. See comment at `.fixedNeutral`. -/ -private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do +private def hasFwdDeps (decl : Decl .pure) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do let param := decl.params[j]! for h : k in (j+1)...decl.params.size do if paramsInfo[k]!.causesSpecialization then @@ -175,7 +175,7 @@ computationally relevant declarations. Furthermore this function takes: - `alreadySpecialized` which is a mask that says whether a decl is already a specialized declaration itself. -/ -def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (Array Nat) → Bool) +def computeSpecEntries (decls : Array (Decl .pure)) (autoSpecialize : Name → Option (Array Nat) → Bool) (alreadySpecialized : Array Bool) : CompilerM (Array SpecEntry) := do let mut declsInfo := #[] for decl in decls do @@ -245,7 +245,7 @@ def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (A Compute and save specialization information for `decls`. Assumes that `decls` is an SCC of user defined declarations. -/ -def saveSpecEntries (decls : Array Decl) : CompilerM Unit := do +def saveSpecEntries (decls : Array (Decl .pure)) : CompilerM Unit := do let entries ← computeSpecEntries decls (fun _ specArgs? => specArgs? == some #[]) diff --git a/src/Lean/Compiler/LCNF/Specialize.lean b/src/Lean/Compiler/LCNF/Specialize.lean index 89f6ecbe3c..9c4cce17ab 100644 --- a/src/Lean/Compiler/LCNF/Specialize.lean +++ b/src/Lean/Compiler/LCNF/Specialize.lean @@ -66,11 +66,11 @@ structure State where /-- The set of `Decl` that we are done processing. -/ - processedDecls : Array Decl := #[] + processedDecls : Array (Decl .pure) := #[] /-- The set of `Decl` that we will attempt recursive specialization on in the next iteration. -/ - workingDecls : Array Decl := #[] + workingDecls : Array (Decl .pure) := #[] /-- Specialization information about specialized declarations generated in this SCC so far. -/ @@ -101,7 +101,7 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do let s := (← read).ground return allFVar (s.contains ·) e -@[inline] def withLetDecl (decl : LetDecl) (x : SpecializeM α) : SpecializeM α := do +@[inline] def withLetDecl (decl : LetDecl .pure) (x : SpecializeM α) : SpecializeM α := do let grd ← isGround decl.value <||> (pure (← isArrowClass? decl.type).isSome) let isUnderApplied ← match decl.value with @@ -109,10 +109,10 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do match ← getDecl? fnName with -- This ascription to `Bool` is required to avoid this being inferred as `Prop`, -- even with a type specified on the `let` binding. - | some { params, .. } => pure ((args.size < params.size) : Bool) + | some ⟨_, { params, .. }⟩ => pure ((args.size < params.size) : Bool) | none => pure false | .fvar fnFVarId args => - match ← findFunDecl? fnFVarId with + match ← findFunDecl? (pu := .pure) fnFVarId with -- This ascription to `Bool` is required to avoid this being inferred as `Prop`, -- even with a type specified on the `let` binding. | some (.mk (params := params) ..) => pure ((args.size < params.size) : Bool) @@ -125,7 +125,7 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do ground := if grd then ctx.ground.insert fvarId else ctx.ground } -@[inline] def withFunDecl (decl : FunDecl) (x : SpecializeM α) : SpecializeM α := do +@[inline] def withFunDecl (decl : FunDecl .pure) (x : SpecializeM α) : SpecializeM α := do let ctx ← read let grd := allFVar (x := decl.value) fun fvarId => !(ctx.scope.contains fvarId) || ctx.ground.contains fvarId @@ -193,12 +193,13 @@ That is, `mask` contains only the arguments that are contributing to the code sp We use this information to compute a "key" to uniquely identify the code specialization, and creating the specialized code. -/ -def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM (Array (Option Arg) × Array Param × Array CodeDecl) := do +def collect (paramsInfo : Array SpecParamInfo) (args : Array (Arg .pure)) : + SpecializeM (Array (Option (Arg .pure)) × Array (Param .pure) × Array (CodeDecl .pure)) := do let ctx ← read let lctx := (← getThe CompilerM.State).lctx let abstract (fvarId : FVarId) : Bool := -- We convert let-declarations that are not ground into parameters - !lctx.funDecls.contains fvarId && + !(lctx.funDecls .pure).contains fvarId && !ctx.underApplied.contains fvarId && !ctx.ground.contains fvarId Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do @@ -217,7 +218,7 @@ end Collector /-- Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information. -/ -def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bool := do +def shouldSpecialize (specEntry : SpecEntry) (args : Array (Arg .pure)) : SpecializeM Bool := do let hoCheck := if specEntry.alreadySpecialized then fun arg => do @@ -248,7 +249,7 @@ def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bo -/ match arg with | .erased | .type .. => return false - | .fvar fvar => return (← findParam? fvar).isNone + | .fvar fvar => return (← findParam? (pu := .pure) fvar).isNone else fun _ => pure true for paramInfo in specEntry.paramsInfo, arg in args do @@ -264,7 +265,7 @@ def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bo Convert the given declarations into `Expr`, and "zeta-reduce" them into body. This function is used to compute the key that uniquely identifies an code specialization. -/ -def expandCodeDecls (decls : Array CodeDecl) (body : LetValue) : CompilerM Expr := do +def expandCodeDecls (decls : Array (CodeDecl .pure)) (body : LetValue .pure) : CompilerM Expr := do let xs := decls.map (mkFVar ·.fvarId) let values := decls.map fun | .let decl => decl.value.toExpr @@ -285,7 +286,8 @@ Create the "key" that uniquely identifies a code specialization. The result contains the list of universe level parameter names the key that `params`, `decls`, and `body` depends on. We use this information to create the new auxiliary declaration and resulting application. -/ -def mkKey (params : Array Param) (decls : Array CodeDecl) (body : LetValue) : CompilerM (Expr × List Name) := do +def mkKey (params : Array (Param .pure)) (decls : Array (CodeDecl .pure)) (body : LetValue .pure) : + CompilerM (Expr × List Name) := do let body ← expandCodeDecls decls body let key := ToExpr.run do ToExpr.withParams params do @@ -308,7 +310,9 @@ Specialize `decl` using - `decls`: local declarations that arguments in `argMask` depend on. - `levelParamsNew`: the universe level parameters for the new declaration. -/ -def mkSpecDecl (decl : Decl) (us : List Level) (argMask : Array (Option Arg)) (params : Array Param) (decls : Array CodeDecl) (levelParamsNew : List Name) : SpecializeM Decl := do +def mkSpecDecl (decl : Decl .pure) (us : List Level) (argMask : Array (Option (Arg .pure))) + (params : Array (Param .pure)) (decls : Array (CodeDecl .pure)) (levelParamsNew : List Name) : + SpecializeM (Decl .pure) := do let nameNew := decl.name.appendCore `_at_ |>.appendCore (← read).declName |>.appendCore `spec @@ -325,7 +329,7 @@ def mkSpecDecl (decl : Decl) (us : List Level) (argMask : Array (Option Arg)) (p finally eraseDecl decl where - go (decl : Decl) (nameNew : Name) : InternalizeM Decl := do + go (decl : Decl .pure) (nameNew : Name) : InternalizeM .pure (Decl .pure) := do let .code code := decl.value | panic! "can only specialize decls with code" let mut params ← params.mapM internalizeParam let decls ← decls.mapM internalizeCodeDecl @@ -348,21 +352,21 @@ where let value := .code code let safe := decl.safe let recursive := decl.recursive - let decl := { name := nameNew, levelParams := levelParamsNew, params, type, value, safe, recursive, inlineAttr? := none : Decl } + let decl := { name := nameNew, levelParams := levelParamsNew, params, type, value, safe, recursive, inlineAttr? := none : Decl .pure } return decl.setLevelParams /-- Given the specialization mask `paramsInfo` and the arguments `args`, return the arguments that have not been considered for specialization. -/ -def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array Arg) : Array Arg := Id.run do +def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array (Arg .pure)) : Array (Arg .pure) := Id.run do let mut result := #[] for info in paramsInfo, arg in args do if info matches .other then result := result.push arg return result ++ args[paramsInfo.size...*] -def paramsToGroundVars (params : Array Param) : CompilerM FVarIdSet := +def paramsToGroundVars (params : Array (Param .pure)) : CompilerM FVarIdSet := params.foldlM (init := {}) fun r p => do if isTypeFormerType p.type || (← isArrowClass? p.type).isSome then return r.insert p.fvarId @@ -392,13 +396,13 @@ mutual Try to specialize the function application in the given let-declaration. `k` is the continuation for the let-declaration. -/ - partial def specializeApp? (e : LetValue) : SpecializeM (Option LetValue) := do + partial def specializeApp? (e : LetValue .pure) : SpecializeM (Option (LetValue .pure)) := do let .const declName us args := e | return none if args.isEmpty then return none if (← Meta.isInstance declName) then return none let some specEntry ← getSpecEntry? declName | return none unless (← shouldSpecialize specEntry args) do return none - let some decl ← getDecl? declName | return none + let some ⟨.pure, decl⟩ ← getDecl? declName | return none let .code _ := decl.value | return none trace[Compiler.specialize.candidate] "{e.toExpr}, {specEntry}" let paramsInfo := specEntry.paramsInfo @@ -419,7 +423,7 @@ mutual fun | .type .. | .erased => return false | .fvar fvar => do - if let some param ← findParam? fvar then + if let some param ← findParam? (pu := .pure) fvar then /- For now we only allow recursive specialization on non class parameters, reason: We can encounter situations where we repeatedly re-abstract over type classes @@ -442,11 +446,11 @@ mutual } return some (.const specDecl.name usNew argsNew) - partial def visitFunDecl (funDecl : FunDecl) : SpecializeM FunDecl := do + partial def visitFunDecl (funDecl : FunDecl .pure) : SpecializeM (FunDecl .pure) := do let value ← withParams funDecl.params <| visitCode funDecl.value funDecl.update' funDecl.type value - partial def visitCode (code : Code) : SpecializeM Code := do + partial def visitCode (code : Code .pure) : SpecializeM (Code .pure) := do match code with | .let decl k => let mut decl := decl @@ -476,7 +480,7 @@ end /-- Run specialization on the body of `decl`. -/ -def specializeDecl (decl : Decl) : SpecializeM (Decl × Bool) := do +def specializeDecl (decl : Decl .pure) : SpecializeM (Decl .pure × Bool) := do trace[Compiler.specialize.step] m!"Working {decl.name}" if (← decl.isTemplateLike) then return (decl, false) @@ -544,7 +548,7 @@ partial def loop (round : Nat := 0) : SpecializeM Unit := do loop (round + 1) -def main (decls : Array Decl) : CompilerM (Array Decl) := do +def main (decls : Array (Decl .pure)) : CompilerM (Array (Decl .pure)) := do saveSpecEntries decls let (_, s) ← loop |>.run { declName := .anonymous } |>.run { workingDecls := decls } return s.processedDecls diff --git a/src/Lean/Compiler/LCNF/SplitSCC.lean b/src/Lean/Compiler/LCNF/SplitSCC.lean index 7b06e8a5a5..29e126fcc3 100644 --- a/src/Lean/Compiler/LCNF/SplitSCC.lean +++ b/src/Lean/Compiler/LCNF/SplitSCC.lean @@ -13,21 +13,21 @@ namespace Lean.Compiler.LCNF namespace SplitScc -partial def findSccCalls (scc : Std.HashMap Name Decl) (decl : Decl) : BaseIO (Std.HashSet Name) := do +partial def findSccCalls (scc : Std.HashMap Name (Decl pu)) (decl : Decl pu) : BaseIO (Std.HashSet Name) := do match decl.value with | .code code => let (_, calls) ← goCode code |>.run {} return calls | .extern .. => return {} where - goCode (c : Code) : StateRefT (Std.HashSet Name) BaseIO Unit := do + goCode (c : Code pu) : StateRefT (Std.HashSet Name) BaseIO Unit := do match c with | .let decl k => if let .const name .. := decl.value then if scc.contains name then modify fun s => s.insert name goCode k - | .fun decl k | .jp decl k => + | .fun decl k _ | .jp decl k => goCode decl.value goCode k | .cases cases => cases.alts.forM (·.forCodeM goCode) @@ -35,7 +35,7 @@ where end SplitScc -public def splitScc (scc : Array Decl) : CompilerM (Array (Array Decl)) := do +public def splitScc (scc : Array (Decl pu)) : CompilerM (Array (Array (Decl pu))) := do if scc.size == 1 then return #[scc] let declMap := Std.HashMap.ofArray <| scc.map fun decl => (decl.name, decl) diff --git a/src/Lean/Compiler/LCNF/StructProjCases.lean b/src/Lean/Compiler/LCNF/StructProjCases.lean index c33c1d6b1a..b1feb87850 100644 --- a/src/Lean/Compiler/LCNF/StructProjCases.lean +++ b/src/Lean/Compiler/LCNF/StructProjCases.lean @@ -20,7 +20,7 @@ def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do return ctorInfo def mkFieldParamsForCtorType (ctorType : Expr) (numParams : Nat) (numFields : Nat) : - CompilerM (Array Param) := do + CompilerM (Array (Param .pure)) := do let mut type ← Meta.MetaM.run' <| toLCNFType ctorType type ← toMonoType type for _ in *...numParams do @@ -52,7 +52,7 @@ def remapFVar (fvarId : FVarId) : M FVarId := do mutual -partial def visitCode (code : Code) : M Code := do +partial def visitCode (code : Code .pure) : M (Code .pure) := do match code with | .let decl k => match decl.value with @@ -105,7 +105,7 @@ partial def visitCode (code : Code) : M Code := do | .return fvarId => return code.updateReturn! (← remapFVar fvarId) | .unreach .. => return code -partial def visitLetValue (v : LetValue) : M LetValue := do +partial def visitLetValue (v : LetValue .pure) : M (LetValue .pure) := do match v with | .const _ _ args => return v.updateArgs! (← args.mapM visitArg) @@ -115,24 +115,24 @@ partial def visitLetValue (v : LetValue) : M LetValue := do -- Projections should be handled directly by `visitCode`. | .proj .. => unreachable! -partial def visitAlt (alt : Alt) : M Alt := do +partial def visitAlt (alt : Alt .pure) : M (Alt .pure) := do return alt.updateCode (← visitCode alt.getCode) -partial def visitArg (arg : Arg) : M Arg := +partial def visitArg (arg : Arg .pure) : M (Arg .pure) := match arg with | .fvar fvarId => return arg.updateFVar! (← remapFVar fvarId) | .type _ | .erased => return arg end -def visitDecl (decl : Decl) : M Decl := do +def visitDecl (decl : Decl .pure) : M (Decl .pure) := do let value ← decl.value.mapCodeM (visitCode ·) return { decl with value } end StructProjCases def structProjCases : Pass := - .mkPerDeclaration `structProjCases (StructProjCases.visitDecl · |>.run) .mono + .mkPerDeclaration `structProjCases .mono (StructProjCases.visitDecl · |>.run) builtin_initialize registerTraceClass `Compiler.structProjCases (inherited := true) diff --git a/src/Lean/Compiler/LCNF/ToDecl.lean b/src/Lean/Compiler/LCNF/ToDecl.lean index 640594a3ad..7a3e72fe92 100644 --- a/src/Lean/Compiler/LCNF/ToDecl.lean +++ b/src/Lean/Compiler/LCNF/ToDecl.lean @@ -105,13 +105,13 @@ The steps for this are roughly: - expand declarations tagged with the `[macro_inline]` attribute - turn the resulting term into LCNF declaration -/ -def toDecl (declName : Name) : CompilerM Decl := do +def toDecl (declName : Name) : CompilerM (Decl .pure) := do let declName := if let some name := isUnsafeRecName? declName then name else declName let some info ← getDeclInfo? declName | throwError "declaration `{.ofConstName declName}` not found" let safe ← declIsNotUnsafe declName let env ← getEnv let inlineAttr? := getInlineAttribute? env declName - let paramsFromTypeBinders (expr : Expr) : CompilerM (Array Param) := do + let paramsFromTypeBinders (expr : Expr) : CompilerM (Array (Param .pure)) := do let mut params := #[] let mut currentExpr := expr repeat @@ -145,7 +145,7 @@ def toDecl (declName : Name) : CompilerM Decl := do let code ← toLCNF value let decl ← if let .fun decl (.return _) := code then eraseFunDecl decl (recursive := false) - pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl } + pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl .pure } else pure { name := declName, params := #[], type, value := .code code, levelParams := info.levelParams, safe, inlineAttr? } /- `toLCNF` may eta-reduce simple declarations. -/ diff --git a/src/Lean/Compiler/LCNF/ToExpr.lean b/src/Lean/Compiler/LCNF/ToExpr.lean index ea648e46fc..701c5ba6e2 100644 --- a/src/Lean/Compiler/LCNF/ToExpr.lean +++ b/src/Lean/Compiler/LCNF/ToExpr.lean @@ -37,7 +37,7 @@ where abbrev ToExprM := ReaderT Nat $ StateM LevelMap -@[inline] def mkLambdaM (params : Array Param) (e : Expr) : ToExprM Expr := +@[inline] def mkLambdaM (params : Array (Param pu)) (e : Expr) : ToExprM Expr := return go (← read) (← get) params.size e where go (offset : Nat) (m : LevelMap) (i : Nat) (e : Expr) : Expr := @@ -59,7 +59,7 @@ private abbrev _root_.Lean.FVarId.toExprM (fvarId : FVarId) : ToExprM Expr := modify fun s => s.insert fvarId offset withReader (·+1) k -@[inline] partial def withParams (params : Array Param) (k : ToExprM α) : ToExprM α := +@[inline] partial def withParams (params : Array (Param pu)) (k : ToExprM α) : ToExprM α := go 0 where @[specialize] go (i : Nat) : ToExprM α := do @@ -79,21 +79,21 @@ end ToExpr open ToExpr -private def Arg.toExprM (arg : Arg) : ToExprM Expr := +private def Arg.toExprM (arg : Arg pu) : ToExprM Expr := return arg.toExpr.abstract' (← read) (← get) mutual -partial def FunDecl.toExprM (decl : FunDecl) : ToExprM Expr := +partial def FunDecl.toExprM (decl : FunDecl pu) : ToExprM Expr := withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM) -partial def Code.toExprM (code : Code) : ToExprM Expr := do +partial def Code.toExprM (code : Code pu) : ToExprM Expr := do match code with | .let decl k => let type ← abstractM decl.type let value ← abstractM decl.value.toExpr let body ← withFVar decl.fvarId k.toExprM return .letE decl.binderName type value body true - | .fun decl k | .jp decl k => + | .fun decl k _ | .jp decl k => let type ← abstractM decl.type let value ← decl.toExprM let body ← withFVar decl.fvarId k.toExprM @@ -103,17 +103,17 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do | .unreach type => return mkApp (mkConst ``lcUnreachable) (← abstractM type) | .cases c => let alts ← c.alts.mapM fun - | .alt ctorName params k => do + | .alt ctorName params k _ => do let body ← withParams params do mkLambdaM params (← k.toExprM) return mkApp (mkConst ctorName) body | .default k => k.toExprM return mkAppN (mkConst `cases) (#[← c.discr.toExprM] ++ alts) end -public def Code.toExpr (code : Code) (xs : Array FVarId := #[]) : Expr := +public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr := run' code.toExprM xs -public def FunDecl.toExpr (decl : FunDecl) (xs : Array FVarId := #[]) : Expr := +public def FunDecl.toExpr (decl : FunDecl pu) (xs : Array FVarId := #[]) : Expr := run' decl.toExprM xs end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 0e639b9b6c..22b13d275f 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -33,18 +33,18 @@ The `toLCNF` function maintains a sequence of elements that is eventually converted into `Code`. -/ inductive Element where - | jp (decl : FunDecl) - | fun (decl : FunDecl) - | let (decl : LetDecl) - | cases (p : Param) (cases : Cases) - | unreach (p : Param) + | jp (decl : FunDecl .pure) + | fun (decl : FunDecl .pure) + | let (decl : LetDecl .pure) + | cases (p : Param .pure) (cases : Cases .pure) + | unreach (p : Param .pure) deriving Inhabited /-- State for `BindCasesM` monad Mapping from `_alt.` variables to new join points -/ -abbrev BindCasesM.State := FVarIdMap FunDecl +abbrev BindCasesM.State := FVarIdMap (FunDecl .pure) /-- Auxiliary monad for implementing `bindCases` -/ abbrev BindCasesM := StateRefT BindCasesM.State CompilerM @@ -60,25 +60,25 @@ and then jumps to `jpDecl`. The goal is to make sure the auxiliary join point is of `_alt.`, then `simp` will inline it. That is, our goal is to try to promote the pre join points `_alt.` into a proper join point. -/ -partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do +partial def bindCases (jpDecl : FunDecl .pure) (cases : Cases .pure) : CompilerM (Code .pure) := do let (alts, s) ← visitAlts cases.alts |>.run {} let resultType ← mkCasesResultType alts let result := .cases ⟨cases.typeName, resultType, cases.discr, alts⟩ let result := s.foldl (init := result) fun result _ altJp => .jp altJp result return .jp jpDecl result where - visitAlts (alts : Array Alt) : BindCasesM (Array Alt) := + visitAlts (alts : Array (Alt .pure)) : BindCasesM (Array (Alt .pure)) := alts.mapM fun alt => return alt.updateCode (← go alt.getCode) - findFun? (f : FVarId) : CompilerM (Option FunDecl) := do - if let some funDecl ← findFunDecl? f then + findFun? (f : FVarId) : CompilerM (Option (FunDecl .pure)) := do + if let some funDecl ← findFunDecl? (pu := .pure) f then return funDecl - else if let some (.fvar f' #[]) ← findLetValue? f then + else if let some (.fvar f' #[]) ← findLetValue? (pu := .pure) f then findFun? f' else return none - go (code : Code) : BindCasesM Code := do + go (code : Code .pure) : BindCasesM (Code .pure) := do match code with | .let decl k => if let .return fvarId := k then @@ -112,7 +112,7 @@ where Then, we replace the current `let`-declaration with `jmp altJp args` -/ let mut jpParams := #[] - let mut subst := {} + let mut subst : FVarSubst .pure := {} let mut jpArgs := #[] /- Remark: `funDecl.params.size` may be greater than `args.size`. -/ for param in funDecl.params[*...args.size] do @@ -151,10 +151,10 @@ where | .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] | .jmp .. | .unreach .. => return code -def seqToCode (seq : Array Element) (k : Code) : CompilerM Code := do +def seqToCode (seq : Array Element) (k : Code .pure) : CompilerM (Code .pure) := do go seq seq.size k where - go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do + go (seq : Array Element) (i : Nat) (c : Code .pure) : CompilerM (Code .pure) := do if i > 0 then match seq[i-1]! with | .jp decl => go seq (i - 1) (.jp decl c) @@ -198,7 +198,7 @@ structure State where /-- Local context containing the original Lean types (not LCNF ones). -/ lctx : LocalContext := {} /-- Cache from Lean regular expression to LCNF argument. -/ - cache : PHashMap Expr Arg := {} + cache : PHashMap Expr (Arg .pure) := {} /-- Determines whether caching has been disabled due to finding a use of a constant marked with `never_extract`. @@ -228,12 +228,12 @@ abbrev M := StateRefT State CompilerM def pushElement (elem : Element) : M Unit := do modify fun s => { s with seq := s.seq.push elem } -def mkUnreachable (type : Expr) : M Arg := do +def mkUnreachable (type : Expr) : M (Arg .pure) := do let p ← mkAuxParam type pushElement (.unreach p) return .fvar p.fvarId -def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : M FVarId := do +def mkAuxLetDecl (e : LetValue .pure) (prefixName := `_x) : M FVarId := do match e with | .fvar fvarId #[] => return fvarId | _ => @@ -241,11 +241,11 @@ def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : M FVarId := do pushElement (.let letDecl) return letDecl.fvarId -def letValueToArg (e : LetValue) (prefixName := `_x) : M Arg := +def letValueToArg (e : LetValue .pure) (prefixName := `_x) : M (Arg .pure) := return .fvar (← mkAuxLetDecl e prefixName) /-- Create `Code` that executes the current `seq` and then returns `result` -/ -def toCode (result : Arg) : M Code := do +def toCode (result : Arg .pure) : M (Code .pure) := do match result with | .fvar fvarId => seqToCode (← get).seq (.return fvarId) | .erased | .type .. => @@ -327,7 +327,7 @@ def cleanupBinderName (binderName : Name) : CompilerM Name := return binderName /-- Create a new local declaration using a Lean regular type. -/ -def mkParam (binderName : Name) (type : Expr) : M Param := do +def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do let binderName ← cleanupBinderName binderName let borrow := isMarkedBorrowed type let type' ← toLCNFType type @@ -335,7 +335,8 @@ def mkParam (binderName : Name) (type : Expr) : M Param := do modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default } return param -def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg) : M LetDecl := do +def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure) : + M (LetDecl .pure) := do let binderName ← cleanupBinderName binderName let value' ← match arg with | .fvar fvarId => pure <| .fvar fvarId #[] @@ -347,10 +348,10 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (a } return letDecl -def visitLambda (e : Expr) : M (Array Param × Expr) := +def visitLambda (e : Expr) : M (Array (Param .pure) × Expr) := go e #[] #[] where - go (e : Expr) (xs : Array Expr) (ps : Array Param) := do + go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) := do if let .lam binderName type body _ := e then let type := type.instantiateRev xs let p ← mkParam binderName type @@ -358,10 +359,10 @@ where else return (ps, e.instantiateRev xs) -def visitBoundedLambda (e : Expr) (n : Nat) : M (Array Param × Expr) := +def visitBoundedLambda (e : Expr) (n : Nat) : M (Array (Param .pure) × Expr) := go e n #[] #[] where - go (e : Expr) (n : Nat) (xs : Array Expr) (ps : Array Param) := do + go (e : Expr) (n : Nat) (xs : Array Expr) (ps : Array (Param .pure)) := do if n == 0 then return (ps, e.instantiateRev xs) else if let .lam binderName type body _ := e then @@ -422,10 +423,10 @@ Put the given expression in `LCNF`. - Eta-expand applications of declarations that satisfy `shouldEtaExpand`. - Put computationally relevant expressions in A-normal form. -/ -partial def toLCNF (e : Expr) : CompilerM Code := do +partial def toLCNF (e : Expr) : CompilerM (Code .pure) := do run do toCode (← visit e) where - visitCore (e : Expr) : M Arg := withIncRecDepth do + visitCore (e : Expr) : M (Arg .pure) := withIncRecDepth do if let some arg := (← get).cache.find? e then return arg let r : Arg ← match e with @@ -441,7 +442,7 @@ where modify fun s => if s.shouldCache then { s with cache := s.cache.insert e r } else s return r - visit (e : Expr) : M Arg := withIncRecDepth do + visit (e : Expr) : M (Arg .pure) := withIncRecDepth do if isLCProof e then return .erased let type ← liftMetaM <| Meta.inferType e @@ -457,10 +458,10 @@ where return .erased visitCore e - visitLit (lit : Literal) : M Arg := + visitLit (lit : Literal) : M (Arg .pure) := letValueToArg (.lit (litToValue lit)) - visitAppArg (e : Expr) : M Arg := do + visitAppArg (e : Expr) : M (Arg .pure) := do if isLCProof e then return .erased let type ← liftMetaM <| Meta.inferType e @@ -478,7 +479,7 @@ where visitCore e /-- Giving `f` a constant `.const declName us`, convert `args` into `args'`, and return `.const declName us args'` -/ - visitAppDefaultConst (f : Expr) (args : Array Expr) : M Arg := do + visitAppDefaultConst (f : Expr) (args : Array Expr) : M (Arg .pure) := do let env ← getEnv let .const declName us := CSimp.replaceConstants env f | unreachable! let args ← args.mapM visitAppArg @@ -487,7 +488,7 @@ where letValueToArg <| .const declName us args /-- Eta expand if under applied, otherwise apply k -/ - etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Arg) : M Arg := do + etaIfUnderApplied (e : Expr) (arity : Nat) (k : M (Arg .pure)) : M (Arg .pure) := do let numArgs := e.getAppNumArgs if numArgs < arity then visit (← etaExpandN e (arity - numArgs)) @@ -502,7 +503,7 @@ where k args[arity...*] ``` -/ - mkOverApplication (app : Arg) (args : Array Expr) (arity : Nat) : M Arg := do + mkOverApplication (app : (Arg .pure)) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do if args.size == arity then return app else @@ -517,7 +518,7 @@ where /-- Visit a `matcher`/`casesOn` alternative. -/ - visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × Alt) := do + visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do withNewScope do match casesAltInfo with | .default numHyps => @@ -552,7 +553,7 @@ where let altType ← c.inferType return (altType, .alt ctorName ps c) - visitCases (casesInfo : CasesInfo) (e : Expr) : M Arg := + visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) := etaIfUnderApplied e casesInfo.arity do let args := e.getAppArgs let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity])) @@ -603,11 +604,11 @@ where let result := .fvar auxDecl.fvarId mkOverApplication result args casesInfo.arity - visitCtor (arity : Nat) (e : Expr) : M Arg := + visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) := etaIfUnderApplied e arity do visitAppDefaultConst e.getAppFn e.getAppArgs - visitQuotLift (e : Expr) : M Arg := do + visitQuotLift (e : Expr) : M (Arg .pure) := do let arity := 6 etaIfUnderApplied e arity do let mut args := e.getAppArgs @@ -622,7 +623,7 @@ where | .type _ => unreachable! | .fvar fvarId => mkOverApplication (← letValueToArg <| .fvar fvarId #[.fvar invq]) args arity - visitEqRec (e : Expr) : M Arg := + visitEqRec (e : Expr) : M (Arg .pure) := let arity := 6 etaIfUnderApplied e arity do let args := e.getAppArgs @@ -630,7 +631,7 @@ where let minor ← visit minor mkOverApplication minor args arity - visitHEqRec (e : Expr) : M Arg := + visitHEqRec (e : Expr) : M (Arg .pure) := let arity := 7 etaIfUnderApplied e arity do let args := e.getAppArgs @@ -638,19 +639,19 @@ where let minor ← visit minor mkOverApplication minor args arity - visitFalseRec (e : Expr) : M Arg := + visitFalseRec (e : Expr) : M (Arg .pure) := let arity := 2 etaIfUnderApplied e arity do let type ← toLCNFType (← liftMetaM do Meta.inferType e) mkUnreachable type - visitLcUnreachable (e : Expr) : M Arg := + visitLcUnreachable (e : Expr) : M (Arg .pure) := let arity := 1 etaIfUnderApplied e arity do let type ← toLCNFType (← liftMetaM do Meta.inferType e) mkUnreachable type - visitAndIffRecCore (e : Expr) (minorPos : Nat) : M Arg := + visitAndIffRecCore (e : Expr) (minorPos : Nat) : M (Arg .pure) := let arity := 5 etaIfUnderApplied e arity do let args := e.getAppArgs @@ -660,7 +661,7 @@ where let minor := minor.beta #[ha, hb] visit (mkAppN minor args[arity...*]) - visitNoConfusion (e : Expr) : M Arg := do + visitNoConfusion (e : Expr) : M (Arg .pure) := do let .const declName _ := e.getAppFn | unreachable! let info := getNoConfusionInfo (← getEnv) declName let typeName := declName.getPrefix @@ -705,7 +706,7 @@ where else expandNoConfusionMajor (← etaExpandN major (n+1)) (n+1) - visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M Arg := do + visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M (Arg .pure) := do let typeName := projInfo.ctorName.getPrefix if isRuntimeBuiltinType typeName then let numArgs := e.getAppNumArgs @@ -720,7 +721,7 @@ where let f ← Core.instantiateValueLevelParams info us visit (f.beta e.getAppArgs) - visitApp (e : Expr) : M Arg := do + visitApp (e : Expr) : M (Arg .pure) := do if let .const declName us := CSimp.replaceConstants (← getEnv) e.getAppFn then if declName == ``Quot.lift then visitQuotLift e @@ -754,7 +755,7 @@ where let args ← args.mapM visitAppArg letValueToArg <| .fvar fvarId args - visitLambda (e : Expr) : M Arg := do + visitLambda (e : Expr) : M (Arg .pure) := do let b := etaReduceImplicit e /- Note: we don't want to eta-reduce arbitrary lambda expressions since it can @@ -790,10 +791,10 @@ where pushElement (.fun funDecl) return .fvar funDecl.fvarId - visitMData (_mdata : MData) (e : Expr) : M Arg := do + visitMData (_mdata : MData) (e : Expr) : M (Arg .pure) := do visit e - visitProj (s : Name) (i : Nat) (e : Expr) : M Arg := do + visitProj (s : Name) (i : Nat) (e : Expr) : M (Arg .pure) := do if isRuntimeBuiltinType s then let structInfo := getStructureInfo (← getEnv) s let projExpr ← liftMetaM <| Meta.mkProjection e structInfo.fieldNames[i]! @@ -803,7 +804,7 @@ where | .erased | .type .. => return .erased | .fvar fvarId => letValueToArg <| .proj s i fvarId - visitLet (e : Expr) (xs : Array Expr) : M Arg := do + visitLet (e : Expr) (xs : Array Expr) : M (Arg .pure) := do match e with | .letE binderName type value body _ => let type := type.instantiateRev xs diff --git a/src/Lean/Compiler/LCNF/ToMono.lean b/src/Lean/Compiler/LCNF/ToMono.lean index e90bb97b66..9f74b1f436 100644 --- a/src/Lean/Compiler/LCNF/ToMono.lean +++ b/src/Lean/Compiler/LCNF/ToMono.lean @@ -20,7 +20,7 @@ structure ToMonoM.State where abbrev ToMonoM := StateRefT ToMonoM.State CompilerM -def Param.toMono (param : Param) : ToMonoM Param := do +def Param.toMono (param : Param .pure) : ToMonoM (Param .pure) := do if isTypeFormerType param.type then modify fun s => { s with typeParams := s.typeParams.insert param.fvarId } param.update (← toMonoType param.type) @@ -37,7 +37,7 @@ def checkFVarUseDeferred (resultFVar fvarId : FVarId) : ToMonoM Unit := do modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName } @[inline] -def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := do +def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg .pure) : ToMonoM (Arg .pure) := do match arg with | .erased | .type .. => return .erased | .fvar fvarId => @@ -47,13 +47,13 @@ def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := check fvarId return arg -def argToMono (arg : Arg) : ToMonoM Arg := argToMonoBase checkFVarUse arg +def argToMono (arg : Arg .pure) : ToMonoM (Arg .pure) := argToMonoBase checkFVarUse arg -def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg) : ToMonoM Arg := +def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg .pure) : ToMonoM (Arg .pure) := argToMonoBase (checkFVarUseDeferred resultFVar) arg -def argsToMonoWithFnType (resultFVar : FVarId) (args : Array Arg) (type : Expr) - : ToMonoM (Array Arg) := do +def argsToMonoWithFnType (resultFVar : FVarId) (args : Array (Arg .pure)) (type : Expr) + : ToMonoM (Array (Arg .pure)) := do let mut remainingType : Option Expr := some type let mut result := Array.emptyWithCapacity args.size for arg in args do @@ -69,8 +69,8 @@ def argsToMonoWithFnType (resultFVar : FVarId) (args : Array Arg) (type : Expr) result := result.push monoArg return result -def argsToMonoRedArg (resultFVar : FVarId) (args : Array Arg) (params : Array Param) - (redArgs : Array Arg) : ToMonoM (Array Arg) := do +def argsToMonoRedArg (resultFVar : FVarId) (args : Array (Arg .pure)) (params : Array (Param .pure)) + (redArgs : Array (Arg .pure)) : ToMonoM (Array (Arg .pure)) := do let mut result := #[] let mut argIdx := 0 for redArg in redArgs do @@ -87,14 +87,14 @@ def argsToMonoRedArg (resultFVar : FVarId) (args : Array Arg) (params : Array Pa result := result.push arg return result -def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array Arg) - : ToMonoM LetValue := do - let argsNewParams : Array Arg := .replicate ctorInfo.numParams .erased +def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array (Arg .pure)) + : ToMonoM (LetValue .pure) := do + let argsNewParams : Array (Arg .pure) := .replicate ctorInfo.numParams .erased let argsNewFields ← args[ctorInfo.numParams...*].toArray.mapM (argToMonoDeferredCheck resultFVar) let argsNew := argsNewParams ++ argsNewFields return .const ctorInfo.name [] argsNew -partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetValue := do +partial def LetValue.toMono (e : LetValue .pure) (resultFVar : FVarId) : ToMonoM (LetValue .pure) := do match e with | .erased | .lit .. => return e | .const declName _ args => @@ -111,7 +111,7 @@ partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetVa else if declName == ``Quot.lcInv then match args[2]! with | .fvar fvarId => - let mut extraArgs : Array Arg := .emptyWithCapacity (args.size - 3) + let mut extraArgs : Array (Arg .pure) := .emptyWithCapacity (args.size - 3) for i in 3...args.size do let arg ← argToMono args[i]! extraArgs := extraArgs.push arg @@ -164,13 +164,13 @@ partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetVa else return e -def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do +def LetDecl.toMono (decl : LetDecl .pure) : ToMonoM (LetDecl .pure) := do let type ← toMonoType decl.type let value ← decl.value.toMono decl.fvarId decl.update type value def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat) - (oldFields : Array Param) : ToMonoM (Array Param) := do + (oldFields : Array (Param .pure)) : ToMonoM (Array (Param .pure)) := do let mut type := ctorType for _ in *...numParams do match type with @@ -189,14 +189,14 @@ def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFi mutual -partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do +partial def FunDecl.toMono (decl : FunDecl .pure) : ToMonoM (FunDecl .pure) := do let type ← toMonoType decl.type let params ← decl.params.mapM (·.toMono) let value ← decl.value.toMono decl.update type params value /-- Convert `cases` `Decidable` => `Bool` -/ -partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code := do +partial def decToMono (c : Cases .pure) (_ : c.typeName == ``Decidable) : ToMonoM (Code .pure) := do let resultType ← toMonoType c.resultType let alts ← c.alts.mapM fun alt => do match alt with @@ -208,7 +208,7 @@ partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code return .cases ⟨``Bool, resultType, c.discr, alts⟩ /-- Eliminate `cases` for `Nat`. -/ -partial def casesNatToMono (c: Cases) (_ : c.typeName == ``Nat) : ToMonoM Code := do +partial def casesNatToMono (c: Cases .pure) (_ : c.typeName == ``Nat) : ToMonoM (Code .pure) := do let resultType ← toMonoType c.resultType let natType := mkConst ``Nat let zeroDecl ← mkLetDecl `zero natType (.lit (.nat 0)) @@ -229,7 +229,7 @@ partial def casesNatToMono (c: Cases) (_ : c.typeName == ``Nat) : ToMonoM Code : return .let zeroDecl (.let isZeroDecl (.cases ⟨``Bool, resultType, isZeroDecl.fvarId, alts⟩)) /-- Eliminate `cases` for `Int`. -/ -partial def casesIntToMono (c: Cases) (_ : c.typeName == ``Int) : ToMonoM Code := do +partial def casesIntToMono (c: Cases .pure) (_ : c.typeName == ``Int) : ToMonoM (Code .pure) := do let resultType ← toMonoType c.resultType let natType := mkConst ``Nat let zeroNatDecl ← mkLetDecl `natZero natType (.lit (.nat 0)) @@ -254,7 +254,8 @@ partial def casesIntToMono (c: Cases) (_ : c.typeName == ``Int) : ToMonoM Code : return .let zeroNatDecl (.let zeroIntDecl (.let isNegDecl (.cases ⟨``Bool, resultType, isNegDecl.fvarId, alts⟩))) /-- Eliminate `cases` for `UInt` types. -/ -partial def casesUIntToMono (c : Cases) (uintName : Name) (_ : c.typeName == uintName) : ToMonoM Code := do +partial def casesUIntToMono (c : Cases .pure) (uintName : Name) (_ : c.typeName == uintName) : + ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -265,7 +266,7 @@ partial def casesUIntToMono (c : Cases) (uintName : Name) (_ : c.typeName == uin return .let decl k /-- Eliminate `cases` for `Array. -/ -partial def casesArrayToMono (c : Cases) (_ : c.typeName == ``Array) : ToMonoM Code := do +partial def casesArrayToMono (c : Cases .pure) (_ : c.typeName == ``Array) : ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -276,7 +277,8 @@ partial def casesArrayToMono (c : Cases) (_ : c.typeName == ``Array) : ToMonoM C return .let decl k /-- Eliminate `cases` for `ByteArray. -/ -partial def casesByteArrayToMono (c : Cases) (_ : c.typeName == ``ByteArray) : ToMonoM Code := do +partial def casesByteArrayToMono (c : Cases .pure) (_ : c.typeName == ``ByteArray) : + ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -287,7 +289,8 @@ partial def casesByteArrayToMono (c : Cases) (_ : c.typeName == ``ByteArray) : T return .let decl k /-- Eliminate `cases` for `FloatArray. -/ -partial def casesFloatArrayToMono (c : Cases) (_ : c.typeName == ``FloatArray) : ToMonoM Code := do +partial def casesFloatArrayToMono (c : Cases .pure) (_ : c.typeName == ``FloatArray) : + ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -298,7 +301,7 @@ partial def casesFloatArrayToMono (c : Cases) (_ : c.typeName == ``FloatArray) : return .let decl k /-- Eliminate `cases` for `String. -/ -partial def casesStringToMono (c : Cases) (_ : c.typeName == ``String) : ToMonoM Code := do +partial def casesStringToMono (c : Cases .pure) (_ : c.typeName == ``String) : ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -309,7 +312,7 @@ partial def casesStringToMono (c : Cases) (_ : c.typeName == ``String) : ToMonoM return .let decl k /-- Eliminate `cases` for `Thunk. -/ -partial def casesThunkToMono (c : Cases) (_ : c.typeName == ``Thunk) : ToMonoM Code := do +partial def casesThunkToMono (c : Cases .pure) (_ : c.typeName == ``Thunk) : ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -329,7 +332,7 @@ partial def casesThunkToMono (c : Cases) (_ : c.typeName == ``Thunk) : ToMonoM C return .fun decl k /-- Eliminate `cases` for `Task. -/ -partial def casesTaskToMono (c : Cases) (_ : c.typeName == ``Task) : ToMonoM Code := do +partial def casesTaskToMono (c : Cases .pure) (_ : c.typeName == ``Task) : ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt _ ps k := c.alts[0]! | unreachable! eraseParams ps @@ -340,7 +343,7 @@ partial def casesTaskToMono (c : Cases) (_ : c.typeName == ``Task) : ToMonoM Cod return .let decl k /-- Eliminate `cases` for trivial structure. See `hasTrivialStructure?` -/ -partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMonoM Code := do +partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases .pure) : ToMonoM (Code .pure) := do assert! c.alts.size == 1 let .alt ctorName ps k := c.alts[0]! | unreachable! assert! ctorName == info.ctorName @@ -353,7 +356,7 @@ partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMo let k ← k.toMono return .let decl k -partial def Code.toMono (code : Code) : ToMonoM Code := do +partial def Code.toMono (code : Code .pure) : ToMonoM (Code .pure) := do match code with | .let decl k => match decl.value with @@ -428,10 +431,10 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do end -def Decl.toMono (decl : Decl) : CompilerM Decl := do +def Decl.toMono (decl : Decl .pure) : CompilerM (Decl .pure) := do go |>.run' {} where - go : ToMonoM Decl := do + go : ToMonoM (Decl .pure) := do let type ← toMonoType decl.type let params ← decl.params.mapM (·.toMono) let value ← decl.value.mapCodeM (·.toMono) diff --git a/src/Lean/Compiler/LCNF/Visibility.lean b/src/Lean/Compiler/LCNF/Visibility.lean index 89699baa5a..f354fd2f65 100644 --- a/src/Lean/Compiler/LCNF/Visibility.lean +++ b/src/Lean/Compiler/LCNF/Visibility.lean @@ -14,23 +14,23 @@ public section namespace Lean.Compiler.LCNF -private partial def collectUsedDecls (code : Code) (s : NameSet := {}) : NameSet := +private partial def collectUsedDecls (code : Code pu) (s : NameSet := {}) : NameSet := match code with | .let decl k => collectUsedDecls k <| collectLetValue decl.value s - | .jp decl k | .fun decl k => collectUsedDecls decl.value <| collectUsedDecls k s + | .jp decl k | .fun decl k _ => collectUsedDecls decl.value <| collectUsedDecls k s | .cases c => c.alts.foldl (init := s) fun s alt => match alt with | .default k => collectUsedDecls k s - | .alt _ _ k => collectUsedDecls k s + | .alt _ _ k _ => collectUsedDecls k s | _ => s where - collectLetValue (e : LetValue) (s : NameSet) : NameSet := + collectLetValue (e : LetValue pu) (s : NameSet) : NameSet := match e with | .const declName .. => s.insert declName | _ => s -private def shouldExportBody (decl : Decl) : CompilerM Bool := do +private def shouldExportBody (decl : Decl pu) : CompilerM Bool := do -- Export body if template-like... decl.isTemplateLike <||> -- ...or it is below the (local) opportunistic inlining threshold and its `Expr` is exported @@ -44,7 +44,7 @@ private def shouldExportBody (decl : Decl) : CompilerM Bool := do Marks the given declaration as to be exported and recursively infers the correct visibility of its body and referenced declarations based on that. -/ -partial def markDeclPublicRec (phase : Phase) (decl : Decl) : CompilerM Unit := do +partial def markDeclPublicRec (phase : Phase) (decl : Decl pu) : CompilerM Unit := do modifyEnv (setDeclPublic · decl.name) if (← shouldExportBody decl) && !isDeclTransparent (← getEnv) phase decl.name then trace[Compiler.inferVisibility] m!"Marking {decl.name} as transparent because it is opaque and its body looks relevant" @@ -57,7 +57,7 @@ partial def markDeclPublicRec (phase : Phase) (decl : Decl) : CompilerM Unit := markDeclPublicRec phase refDecl /-- Checks whether references in the given declaration adhere to phase distinction. -/ -partial def checkMeta (origDecl : Decl) : CompilerM Unit := do +partial def checkMeta (origDecl : Decl pu) : CompilerM Unit := do if !(← getEnv).header.isModule || !compiler.checkMeta.get (← getOptions) then return let irPhases := getIRPhases (← getEnv) origDecl.name @@ -68,7 +68,7 @@ partial def checkMeta (origDecl : Decl) : CompilerM Unit := do -- decls with relevant global attrs are public (`Lean.ensureAttrDeclIsMeta`). let isPublic := !isPrivateName origDecl.name go (irPhases == .comptime) isPublic origDecl |>.run' {} -where go (isMeta isPublic : Bool) (decl : Decl) : StateT NameSet CompilerM Unit := do +where go (isMeta isPublic : Bool) (decl : Decl pu) : StateT NameSet CompilerM Unit := do decl.value.forCodeM fun code => for ref in collectUsedDecls code do if (← get).contains ref then @@ -112,8 +112,8 @@ where go (isMeta isPublic : Bool) (decl : Decl) : StateT NameSet CompilerM Unit -- *their* references in this case. We also need to do this for non-auxiliary defs in case a -- public meta def tries to use a private meta import via a local private meta def :/ . if irPhases == .all || isPublic && isPrivateName ref then - if let some refDecl ← getLocalDecl? ref then - go isMeta isPublic refDecl + if let some ⟨_, refDecl⟩ ← getLocalDecl? ref then + go isMeta isPublic (refDecl.castPurity! pu) /-- Checks that imports necessary for inlining/specialization are public as otherwise we may run into @@ -134,7 +134,7 @@ partial def checkTemplateVisibility : Pass where let isMeta := isMarkedMeta (← getEnv) decl.name go decl decl |>.run' {} return decls -where go (origDecl decl : Decl) : StateT NameSet CompilerM Unit := do +where go (origDecl decl : Decl .pure) : StateT NameSet CompilerM Unit := do decl.value.forCodeM fun code => for ref in collectUsedDecls code do if (← get).contains ref then diff --git a/tests/lean/run/Decidable-decide-erasure.lean b/tests/lean/run/Decidable-decide-erasure.lean index 1b72c6fff9..11c3391e2e 100644 --- a/tests/lean/run/Decidable-decide-erasure.lean +++ b/tests/lean/run/Decidable-decide-erasure.lean @@ -9,8 +9,8 @@ def f (a : Nat) : Bool := -- This is only required until the new code generator is enabled. run_meta Lean.Compiler.compile #[``f] -def countCalls : Probe Decl Nat := - Probe.getLetValues >=> +def countCalls : Probe (Decl .pure) Nat := + Probe.getLetValues .pure >=> Probe.filter (fun e => return e matches .const `Decidable.decide ..) >=> Probe.count