From 6858cb5fb64e0c015ae78564e9be73e573b99f86 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 29 Oct 2020 10:24:16 -0700 Subject: [PATCH] chore: cleanup --- src/Init/Core.lean | 2 +- src/Init/Data/Array/BinSearch.lean | 60 +- src/Init/Data/List/Control.lean | 4 +- src/Init/Data/Nat/Basic.lean | 20 +- src/Init/Data/ToString/Macro.lean | 8 +- src/Init/LeanInit.lean | 18 +- src/Init/WF.lean | 6 +- src/Lean/Compiler/IR/Borrow.lean | 339 +++--- src/Lean/Compiler/IR/Boxing.lean | 470 ++++---- src/Lean/Compiler/IR/Checker.lean | 200 ++-- src/Lean/Compiler/IR/CompilerM.lean | 132 +-- src/Lean/Compiler/IR/CtorLayout.lean | 28 +- src/Lean/Compiler/IR/ElimDeadBranches.lean | 402 ++++--- src/Lean/Compiler/IR/ElimDeadVars.lean | 47 +- src/Lean/Compiler/IR/EmitC.lean | 1178 ++++++++++---------- src/Lean/Compiler/IR/EmitUtil.lean | 60 +- src/Lean/Compiler/IR/ExpandResetReuse.lean | 285 +++-- src/Lean/Compiler/IR/FreeVars.lean | 226 ++-- src/Lean/Compiler/IR/PushProj.lean | 10 +- src/Lean/Compiler/IR/RC.lean | 377 ++++--- src/Lean/Compiler/IR/ResetReuse.lean | 188 ++-- src/Lean/Compiler/IR/SimpCase.lean | 75 +- src/Lean/Elab/Match.lean | 4 +- src/Lean/Exception.lean | 20 +- src/Lean/Level.lean | 12 +- src/Lean/Meta/Match/Match.lean | 14 +- src/Lean/Parser/Extension.lean | 552 +++++---- src/Lean/PrettyPrinter/Meta.lean | 98 +- src/Lean/Util/ReplaceLevel.lean | 18 +- 29 files changed, 2420 insertions(+), 2433 deletions(-) diff --git a/src/Init/Core.lean b/src/Init/Core.lean index 5bd57fbfe0..c57b99b9cc 100644 --- a/src/Init/Core.lean +++ b/src/Init/Core.lean @@ -1271,7 +1271,7 @@ def Prod.map.{u₁, u₂, v₁, v₂} {α₁ : Type u₁} {α₂ : Type u₂} { /- Dependent products -/ theorem exOfPsig {α : Type u} {p : α → Prop} : (PSigma (fun x => p x)) → Exists (fun x => p x) -| ⟨x, hx⟩ => ⟨x, hx⟩ + | ⟨x, hx⟩ => ⟨x, hx⟩ protected theorem PSigma.eta {α : Sort u} {β : α → Sort v} {a₁ a₂ : α} {b₁ : β a₁} {b₂ : β a₂} (h₁ : a₁ = a₂) (h₂ : Eq.rec (motive := fun a _ => β a) b₁ h₁ = b₂) : PSigma.mk a₁ b₁ = PSigma.mk a₂ b₂ := by diff --git a/src/Init/Data/Array/BinSearch.lean b/src/Init/Data/Array/BinSearch.lean index d653a1a587..5172a282f7 100644 --- a/src/Init/Data/Array/BinSearch.lean +++ b/src/Init/Data/Array/BinSearch.lean @@ -14,22 +14,22 @@ namespace Array -- TODO: remove `partial` using well-founded recursion @[specialize] partial def binSearchAux {α : Type u} {β : Type v} [Inhabited α] [Inhabited β] (lt : α → α → Bool) (found : Option α → β) (as : Array α) (k : α) : Nat → Nat → β -| lo, hi => - if lo <= hi then - let m := (lo + hi)/2; - let a := as.get! m; - if lt a k then binSearchAux lt found as k (m+1) hi - else if lt k a then - if m == 0 then found none - else binSearchAux lt found as k lo (m-1) - else found (some a) - else found none + | lo, hi => + if lo <= hi then + let m := (lo + hi)/2; + let a := as.get! m; + if lt a k then binSearchAux lt found as k (m+1) hi + else if lt k a then + if m == 0 then found none + else binSearchAux lt found as k lo (m-1) + else found (some a) + else found none @[inline] def binSearch {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Option α := -binSearchAux lt id as k lo hi + binSearchAux lt id as k lo hi @[inline] def binSearchContains {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Bool := -binSearchAux lt Option.isSome as k lo hi + binSearchAux lt Option.isSome as k lo hi @[specialize] private partial def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α] (lt : α → α → Bool) @@ -37,17 +37,17 @@ binSearchAux lt Option.isSome as k lo hi (add : Unit → m α) (as : Array α) (k : α) : Nat → Nat → m (Array α) -| lo, hi => - -- as[lo] < k < as[hi] - let mid := (lo + hi)/2; - let midVal := as.get! mid; - if lt midVal k then - if mid == lo then do let v ← add (); pure $ as.insertAt (lo+1) v - else binInsertAux lt merge add as k mid hi - else if lt k midVal then - binInsertAux lt merge add as k lo mid - else do - as.modifyM mid $ fun v => merge v + | lo, hi => + -- as[lo] < k < as[hi] + let mid := (lo + hi)/2; + let midVal := as.get! mid; + if lt midVal k then + if mid == lo then do let v ← add (); pure $ as.insertAt (lo+1) v + else binInsertAux lt merge add as k mid hi + else if lt k midVal then + binInsertAux lt merge add as k lo mid + else do + as.modifyM mid $ fun v => merge v @[specialize] partial def binInsertM {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α] (lt : α → α → Bool) @@ -55,14 +55,14 @@ binSearchAux lt Option.isSome as k lo hi (add : Unit → m α) (as : Array α) (k : α) : m (Array α) := -if as.isEmpty then do let v ← add (); pure $ as.push v -else if lt k (as.get! 0) then do let v ← add (); pure $ as.insertAt 0 v -else if !lt (as.get! 0) k then as.modifyM 0 $ merge -else if lt as.back k then do let v ← add (); pure $ as.push v -else if !lt k as.back then as.modifyM (as.size - 1) $ merge -else binInsertAux lt merge add as k 0 (as.size - 1) + if as.isEmpty then do let v ← add (); pure $ as.push v + else if lt k (as.get! 0) then do let v ← add (); pure $ as.insertAt 0 v + else if !lt (as.get! 0) k then as.modifyM 0 $ merge + else if lt as.back k then do let v ← add (); pure $ as.push v + else if !lt k as.back then as.modifyM (as.size - 1) $ merge + else binInsertAux lt merge add as k 0 (as.size - 1) @[inline] def binInsert {α : Type u} [Inhabited α] (lt : α → α → Bool) (as : Array α) (k : α) : Array α := -Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k + Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k end Array diff --git a/src/Init/Data/List/Control.lean b/src/Init/Data/List/Control.lean index f6ccd31d2a..3040a30578 100644 --- a/src/Init/Data/List/Control.lean +++ b/src/Init/Data/List/Control.lean @@ -16,8 +16,8 @@ Remark: we can define `mapM`, `mapM₂` and `forM` using `Applicative` instead o Example: ``` def mapM {m : Type u → Type v} [Applicative m] {α : Type w} {β : Type u} (f : α → m β) : List α → m (List β) -| [] => pure [] -| a::as => List.cons <$> (f a) <*> mapM as + | [] => pure [] + | a::as => List.cons <$> (f a) <*> mapM as ``` However, we consider `f <$> a <*> b` an anti-idiom because the generated code diff --git a/src/Init/Data/Nat/Basic.lean b/src/Init/Data/Nat/Basic.lean index 7f20a81edf..eaac9e07b0 100644 --- a/src/Init/Data/Nat/Basic.lean +++ b/src/Init/Data/Nat/Basic.lean @@ -201,16 +201,16 @@ protected theorem oneMul (n : Nat) : 1 * n = n := Nat.mulComm n 1 ▸ Nat.mulOne n protected theorem leftDistrib : ∀ (n m k : Nat), n * (m + k) = n * m + n * k -| 0, m, k => (Nat.zeroMul (m + k)).symm ▸ (Nat.zeroMul m).symm ▸ (Nat.zeroMul k).symm ▸ rfl -| succ n, m, k => - have h₁ : succ n * (m + k) = n * (m + k) + (m + k) from succMul .. - have h₂ : n * (m + k) + (m + k) = (n * m + n * k) + (m + k) from Nat.leftDistrib n m k ▸ rfl - have h₃ : (n * m + n * k) + (m + k) = n * m + (n * k + (m + k)) from Nat.addAssoc .. - have h₄ : n * m + (n * k + (m + k)) = n * m + (m + (n * k + k)) from congrArg (fun x => n*m + x) (Nat.addLeftComm ..) - have h₅ : n * m + (m + (n * k + k)) = (n * m + m) + (n * k + k) from (Nat.addAssoc ..).symm - have h₆ : (n * m + m) + (n * k + k) = (n * m + m) + succ n * k from succMul n k ▸ rfl - have h₇ : (n * m + m) + succ n * k = succ n * m + succ n * k from succMul n m ▸ rfl - (((((h₁.trans h₂).trans h₃).trans h₄).trans h₅).trans h₆).trans h₇ + | 0, m, k => (Nat.zeroMul (m + k)).symm ▸ (Nat.zeroMul m).symm ▸ (Nat.zeroMul k).symm ▸ rfl + | succ n, m, k => + have h₁ : succ n * (m + k) = n * (m + k) + (m + k) from succMul .. + have h₂ : n * (m + k) + (m + k) = (n * m + n * k) + (m + k) from Nat.leftDistrib n m k ▸ rfl + have h₃ : (n * m + n * k) + (m + k) = n * m + (n * k + (m + k)) from Nat.addAssoc .. + have h₄ : n * m + (n * k + (m + k)) = n * m + (m + (n * k + k)) from congrArg (fun x => n*m + x) (Nat.addLeftComm ..) + have h₅ : n * m + (m + (n * k + k)) = (n * m + m) + (n * k + k) from (Nat.addAssoc ..).symm + have h₆ : (n * m + m) + (n * k + k) = (n * m + m) + succ n * k from succMul n k ▸ rfl + have h₇ : (n * m + m) + succ n * k = succ n * m + succ n * k from succMul n m ▸ rfl + (((((h₁.trans h₂).trans h₃).trans h₄).trans h₅).trans h₆).trans h₇ protected theorem rightDistrib (n m k : Nat) : (n + m) * k = n * k + m * k := have h₁ : (n + m) * k = k * (n + m) from Nat.mulComm .. diff --git a/src/Init/Data/ToString/Macro.lean b/src/Init/Data/ToString/Macro.lean index fb675ae810..83eda30133 100644 --- a/src/Init/Data/ToString/Macro.lean +++ b/src/Init/Data/ToString/Macro.lean @@ -10,7 +10,7 @@ import Init.Data.ToString.Basic syntax:max "s!" (interpolatedStr term) : term macro_rules -| `(s! $interpStr) => do - let chunks := interpStr.getArgs - let r ← Lean.Syntax.expandInterpolatedStrChunks chunks (fun a b => `($a ++ $b)) (fun a => `(toString $a)) - `(($r : String)) + | `(s! $interpStr) => do + let chunks := interpStr.getArgs + let r ← Lean.Syntax.expandInterpolatedStrChunks chunks (fun a b => `($a ++ $b)) (fun a => `(toString $a)) + `(($r : String)) diff --git a/src/Init/LeanInit.lean b/src/Init/LeanInit.lean index 403698432a..63e50a2ad1 100644 --- a/src/Init/LeanInit.lean +++ b/src/Init/LeanInit.lean @@ -109,21 +109,21 @@ protected def append : Name → Name → Name instance : Append Name := ⟨Name.append⟩ def capitalize : Name → Name -| Name.str p s _ => mkNameStr p s.capitalize -| n => n + | Name.str p s _ => mkNameStr p s.capitalize + | n => n def appendAfter : Name → String → Name -| str p s _, suffix => mkNameStr p (s ++ suffix) -| n, suffix => mkNameStr n suffix + | str p s _, suffix => mkNameStr p (s ++ suffix) + | n, suffix => mkNameStr n suffix def appendIndexAfter : Name → Nat → Name -| str p s _, idx => mkNameStr p (s ++ "_" ++ toString idx) -| n, idx => mkNameStr n ("_" ++ toString idx) + | str p s _, idx => mkNameStr p (s ++ "_" ++ toString idx) + | n, idx => mkNameStr n ("_" ++ toString idx) def appendBefore : Name → String → Name -| anonymous, pre => mkNameStr anonymous pre -| str p s _, pre => mkNameStr p (pre ++ s) -| num p n _, pre => mkNameNum (mkNameStr p pre) n + | anonymous, pre => mkNameStr anonymous pre + | str p s _, pre => mkNameStr p (pre ++ s) + | num p n _, pre => mkNameNum (mkNameStr p pre) n end Name diff --git a/src/Init/WF.lean b/src/Init/WF.lean index c0a55fb084..909b0967f1 100644 --- a/src/Init/WF.lean +++ b/src/Init/WF.lean @@ -11,7 +11,7 @@ universes u v set_option codegen false inductive Acc {α : Sort u} (r : α → α → Prop) : α → Prop -| intro (x : α) (h : (y : α) → r y x → Acc r y) : Acc r x + | intro (x : α) (h : (y : α) → r y x → Acc r y) : Acc r x @[elabAsEliminator, inline, reducible] def Acc.ndrec.{u1, u2} {α : Sort u2} {r : α → α → Prop} {C : α → Sort u1} @@ -289,8 +289,8 @@ variables {α : Sort u} {β : Sort v} -- Reverse lexicographical order based on r and s inductive RevLex (r : α → α → Prop) (s : β → β → Prop) : @PSigma α (fun a => β) → @PSigma α (fun a => β) → Prop -| left : {a₁ a₂ : α} → (b : β) → r a₁ a₂ → RevLex r s ⟨a₁, b⟩ ⟨a₂, b⟩ -| right : (a₁ : α) → {b₁ : β} → (a₂ : α) → {b₂ : β} → s b₁ b₂ → RevLex r s ⟨a₁, b₁⟩ ⟨a₂, b₂⟩ + | left : {a₁ a₂ : α} → (b : β) → r a₁ a₂ → RevLex r s ⟨a₁, b⟩ ⟨a₂, b⟩ + | right : (a₁ : α) → {b₁ : β} → (a₂ : α) → {b₂ : β} → s b₁ b₂ → RevLex r s ⟨a₁, b₁⟩ ⟨a₂, b₂⟩ end section diff --git a/src/Lean/Compiler/IR/Borrow.lean b/src/Lean/Compiler/IR/Borrow.lean index b229ecd0c2..1c8fcd2e06 100644 --- a/src/Lean/Compiler/IR/Borrow.lean +++ b/src/Lean/Compiler/IR/Borrow.lean @@ -15,11 +15,12 @@ namespace OwnedSet abbrev Key := FunId × Index def beq : Key → Key → Bool -| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂ + | (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂ + instance : BEq Key := ⟨beq⟩ def getHash : Key → USize -| (f, x) => mixHash (hash f) (hash x) + | (f, x) => mixHash (hash f) (hash x) instance : Hashable Key := ⟨getHash⟩ end OwnedSet @@ -35,19 +36,19 @@ def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap. Recall that `Param` contains the field `borrow`. -/ namespace ParamMap inductive Key -| decl (name : FunId) -| jp (name : FunId) (jpid : JoinPointId) + | decl (name : FunId) + | jp (name : FunId) (jpid : JoinPointId) def beq : Key → Key → Bool -| Key.decl n₁, Key.decl n₂ => n₁ == n₂ -| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂ -| _, _ => false + | Key.decl n₁, Key.decl n₂ => n₁ == n₂ + | Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂ + | _, _ => false instance : BEq Key := ⟨beq⟩ def getHash : Key → USize -| Key.decl n => hash n -| Key.jp n id => mixHash (hash n) (hash id) + | Key.decl n => hash n + | Key.jp n id => mixHash (hash n) (hash id) instance : Hashable Key := ⟨getHash⟩ end ParamMap @@ -56,13 +57,13 @@ open ParamMap (Key) abbrev ParamMap := Std.HashMap Key (Array Param) def ParamMap.fmt (map : ParamMap) : Format := -let fmts := map.fold (fun fmt k ps => - let k := match k with - | ParamMap.Key.decl n => format n - | ParamMap.Key.jp n id => format n ++ ":" ++ format id - fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps) - Format.nil -"{" ++ (Format.nest 1 fmts) ++ "}" + let fmts := map.fold (fun fmt k ps => + let k := match k with + | ParamMap.Key.decl n => format n + | ParamMap.Key.jp n id => format n ++ ":" ++ format id + fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps) + Format.nil + "{" ++ (Format.nest 1 fmts) ++ "}" instance : ToFormat ParamMap := ⟨ParamMap.fmt⟩ instance : ToString ParamMap := ⟨fun m => Format.pretty (format m)⟩ @@ -70,7 +71,7 @@ instance : ToString ParamMap := ⟨fun m => Format.pretty (format m)⟩ namespace InitParamMap /- Mark parameters that take a reference as borrow -/ def initBorrow (ps : Array Param) : Array Param := -ps.map $ fun p => { p with borrow := p.ty.isObj } + ps.map $ fun p => { p with borrow := p.ty.isObj } /- We do perform borrow inference for constants marked as `export`. Reason: we current write wrappers in C++ for using exported functions. @@ -80,26 +81,26 @@ ps.map $ fun p => { p with borrow := p.ty.isObj } We can revise this decision when we implement code for generating the wrappers automatically. -/ def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param := -if exported then ps else initBorrow ps + if exported then ps else initBorrow ps partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit -| FnBody.jdecl j xs v b => do - modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs) - visitFnBody fnid v - visitFnBody fnid b -| FnBody.case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body -| e => do - unless e.isTerminal do - let (instr, b) := e.split + | FnBody.jdecl j xs v b => do + modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs) + visitFnBody fnid v visitFnBody fnid b + | FnBody.case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body + | e => do + unless e.isTerminal do + let (instr, b) := e.split + visitFnBody fnid b -def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit := -decls.forM fun decl => match decl with - | Decl.fdecl f xs _ b => do - let exported := isExport env f - modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs) - visitFnBody f b - | _ => pure () + def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit := + decls.forM fun decl => match decl with + | Decl.fdecl f xs _ b => do + let exported := isExport env f + modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs) + visitFnBody f b + | _ => pure () end InitParamMap def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap := @@ -110,111 +111,111 @@ def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap := namespace ApplyParamMap partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody -| FnBody.jdecl j xs v b => - let v := visitFnBody fn paramMap v - let b := visitFnBody fn paramMap b - match paramMap.find? (ParamMap.Key.jp fn j) with - | some ys => FnBody.jdecl j ys v b - | none => unreachable! -| FnBody.case tid x xType alts => - FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (visitFnBody fn paramMap) -| e => - if e.isTerminal then e - else - let (instr, b) := e.split + | FnBody.jdecl j xs v b => + let v := visitFnBody fn paramMap v let b := visitFnBody fn paramMap b - instr.setBody b + match paramMap.find? (ParamMap.Key.jp fn j) with + | some ys => FnBody.jdecl j ys v b + | none => unreachable! + | FnBody.case tid x xType alts => + FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (visitFnBody fn paramMap) + | e => + if e.isTerminal then e + else + let (instr, b) := e.split + let b := visitFnBody fn paramMap b + instr.setBody b def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl := -decls.map fun decl => match decl with - | Decl.fdecl f xs ty b => - let b := visitFnBody f paramMap b - match paramMap.find? (ParamMap.Key.decl f) with - | some xs => Decl.fdecl f xs ty b - | none => unreachable! - | other => other + decls.map fun decl => match decl with + | Decl.fdecl f xs ty b => + let b := visitFnBody f paramMap b + match paramMap.find? (ParamMap.Key.decl f) with + | some xs => Decl.fdecl f xs ty b + | none => unreachable! + | other => other end ApplyParamMap def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl := --- dbgTrace ("applyParamMap " ++ toString map) $ fun _ => -ApplyParamMap.visitDecls decls map + -- dbgTrace ("applyParamMap " ++ toString map) $ fun _ => + ApplyParamMap.visitDecls decls map structure BorrowInfCtx := -(env : Environment) -(currFn : FunId := arbitrary _) -- Function being analyzed. -(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams` + (env : Environment) + (currFn : FunId := arbitrary _) -- Function being analyzed. + (paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams` structure BorrowInfState := -/- Set of variables that must be `owned`. -/ -(owned : OwnedSet := {}) -(modified : Bool := false) -(paramMap : ParamMap) + /- Set of variables that must be `owned`. -/ + (owned : OwnedSet := {}) + (modified : Bool := false) + (paramMap : ParamMap) abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState) def getCurrFn : M FunId := do -let ctx ← read -pure ctx.currFn + let ctx ← read + pure ctx.currFn def markModified : M Unit := -modify $ fun s => { s with modified := true } + modify fun s => { s with modified := true } def ownVar (x : VarId) : M Unit := do --- dbgTrace ("ownVar " ++ toString x) $ fun _ => -let currFn ← getCurrFn -modify fun s => - if s.owned.contains (currFn, x.idx) then s - else { s with owned := s.owned.insert (currFn, x.idx), modified := true } + -- dbgTrace ("ownVar " ++ toString x) $ fun _ => + let currFn ← getCurrFn + modify fun s => + if s.owned.contains (currFn, x.idx) then s + else { s with owned := s.owned.insert (currFn, x.idx), modified := true } def ownArg (x : Arg) : M Unit := -match x with -| Arg.var x => ownVar x -| _ => pure () + match x with + | Arg.var x => ownVar x + | _ => pure () def ownArgs (xs : Array Arg) : M Unit := -xs.forM ownArg + xs.forM ownArg def isOwned (x : VarId) : M Bool := do -let currFn ← getCurrFn -let s ← get -pure $ s.owned.contains (currFn, x.idx) + let currFn ← getCurrFn + let s ← get + pure $ s.owned.contains (currFn, x.idx) /- Updates `map[k]` using the current set of `owned` variables. -/ def updateParamMap (k : ParamMap.Key) : M Unit := do -let currFn ← getCurrFn -let s ← get -match s.paramMap.find? k with -| some ps => do - let ps ← ps.mapM fun (p : Param) => do - if !p.borrow then pure p - else if (← isOwned p.x) then - markModified - pure { p with borrow := false } - else - pure p - modify fun s => { s with paramMap := s.paramMap.insert k ps } -| none => pure () + let currFn ← getCurrFn + let s ← get + match s.paramMap.find? k with + | some ps => do + let ps ← ps.mapM fun (p : Param) => do + if !p.borrow then pure p + else if (← isOwned p.x) then + markModified + pure { p with borrow := false } + else + pure p + modify fun s => { s with paramMap := s.paramMap.insert k ps } + | none => pure () def getParamInfo (k : ParamMap.Key) : M (Array Param) := do -let s ← get -match s.paramMap.find? k with -| some ps => pure ps -| none => - match k with - | ParamMap.Key.decl fn => do - let ctx ← read - match findEnvDecl ctx.env fn with - | some decl => pure decl.params - | none => unreachable! - | _ => unreachable! + let s ← get + match s.paramMap.find? k with + | some ps => pure ps + | none => + match k with + | ParamMap.Key.decl fn => do + let ctx ← read + match findEnvDecl ctx.env fn with + | some decl => pure decl.params + | none => unreachable! + | _ => unreachable! /- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit := -xs.size.forM fun i => do - let x := xs[i] - let p := ps[i] - unless p.borrow do ownArg x + xs.size.forM fun i => do + let x := xs[i] + let p := ps[i] + unless p.borrow do ownArg x /- For each xs[i], if xs[i] is owned, then mark ps[i] as owned. We use this action to preserve tail calls. That is, if we have @@ -222,12 +223,12 @@ xs.size.forM fun i => do we would have to insert a `dec xs[i]` after `f xs` and consequently "break" the tail call. -/ def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit := -xs.size.forM fun i => do - let x := xs[i] - let p := ps[i] - match x with - | Arg.var x => if (← isOwned x) then ownVar p.x - | _ => pure () + xs.size.forM fun i => do + let x := xs[i] + let p := ps[i] + match x with + | Arg.var x => if (← isOwned x) then ownVar p.x + | _ => pure () /- Mark `xs[i]` as owned if it is one of the parameters `ps`. We use this action to mark function parameters that are being "packed" inside constructors. @@ -239,85 +240,85 @@ xs.size.forM fun i => do ret z ``` -/ def ownArgsIfParam (xs : Array Arg) : M Unit := do -let ctx ← read -xs.forM fun x => do - match x with - | Arg.var x => if ctx.paramSet.contains x.idx then ownVar x - | _ => pure () + let ctx ← read + xs.forM fun x => do + match x with + | Arg.var x => if ctx.paramSet.contains x.idx then ownVar x + | _ => pure () def collectExpr (z : VarId) : Expr → M Unit -| Expr.reset _ x => ownVar z *> ownVar x -| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys -| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs -| Expr.proj _ x => do - if (← isOwned x) then ownVar z - if (← isOwned z) then ownVar x -| Expr.fap g xs => do - let ps ← getParamInfo (ParamMap.Key.decl g) - ownVar z *> ownArgsUsingParams xs ps -| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys -| Expr.pap _ xs => ownVar z *> ownArgs xs -| other => pure () + | Expr.reset _ x => ownVar z *> ownVar x + | Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys + | Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs + | Expr.proj _ x => do + if (← isOwned x) then ownVar z + if (← isOwned z) then ownVar x + | Expr.fap g xs => do + let ps ← getParamInfo (ParamMap.Key.decl g) + ownVar z *> ownArgsUsingParams xs ps + | Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys + | Expr.pap _ xs => ownVar z *> ownArgs xs + | other => pure () def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do -let ctx ← read -match v, b with -| (Expr.fap g ys), (FnBody.ret (Arg.var z)) => - if ctx.currFn == g && x == z then - -- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do - let ps ← getParamInfo (ParamMap.Key.decl g) - ownParamsUsingArgs ys ps -| _, _ => pure () + let ctx ← read + match v, b with + | (Expr.fap g ys), (FnBody.ret (Arg.var z)) => + if ctx.currFn == g && x == z then + -- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do + let ps ← getParamInfo (ParamMap.Key.decl g) + ownParamsUsingArgs ys ps + | _, _ => pure () def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx := -{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet } + { ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet } partial def collectFnBody : FnBody → M Unit -| FnBody.jdecl j ys v b => do - withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v) - let ctx ← read - updateParamMap (ParamMap.Key.jp ctx.currFn j) - collectFnBody b -| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b -| FnBody.jmp j ys => do - let ctx ← read - let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j) - ownArgsUsingParams ys ps -- for making sure the join point can reuse - ownParamsUsingArgs ys ps -- for making sure the tail call is preserved -| FnBody.case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body -| e => do unless e.isTerminal do collectFnBody e.body + | FnBody.jdecl j ys v b => do + withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v) + let ctx ← read + updateParamMap (ParamMap.Key.jp ctx.currFn j) + collectFnBody b + | FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b + | FnBody.jmp j ys => do + let ctx ← read + let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j) + ownArgsUsingParams ys ps -- for making sure the join point can reuse + ownParamsUsingArgs ys ps -- for making sure the tail call is preserved + | FnBody.case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body + | e => do unless e.isTerminal do collectFnBody e.body partial def collectDecl : Decl → M Unit -| Decl.fdecl f ys _ b => - withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do - collectFnBody b - updateParamMap (ParamMap.Key.decl f) -| _ => pure () + | Decl.fdecl f ys _ b => + withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do + collectFnBody b + updateParamMap (ParamMap.Key.decl f) + | _ => pure () /- Keep executing `x` until it reaches a fixpoint -/ @[inline] partial def whileModifing (x : M Unit) : M Unit := do -modify fun s => { s with modified := false } -x -let s ← get -if s.modified then - whileModifing x -else - pure () + modify fun s => { s with modified := false } + x + let s ← get + if s.modified then + whileModifing x + else + pure () def collectDecls (decls : Array Decl) : M ParamMap := do -whileModifing (decls.forM collectDecl) -let s ← get -pure s.paramMap + whileModifing (decls.forM collectDecl) + let s ← get + pure s.paramMap def infer (env : Environment) (decls : Array Decl) : ParamMap := -collectDecls decls { env := env } $.run' { paramMap := mkInitParamMap env decls } + collectDecls decls { env := env } $.run' { paramMap := mkInitParamMap env decls } end Borrow def inferBorrow (decls : Array Decl) : CompilerM (Array Decl) := do -let env ← getEnv -let paramMap := Borrow.infer env decls -pure (Borrow.applyParamMap decls paramMap) + let env ← getEnv + let paramMap := Borrow.infer env decls + pure (Borrow.applyParamMap decls paramMap) end IR end Lean diff --git a/src/Lean/Compiler/IR/Boxing.lean b/src/Lean/Compiler/IR/Boxing.lean index de4faf3c70..3111223545 100644 --- a/src/Lean/Compiler/IR/Boxing.lean +++ b/src/Lean/Compiler/IR/Boxing.lean @@ -30,314 +30,320 @@ Assumptions: open Std (AssocList) def mkBoxedName (n : Name) : Name := -mkNameStr n "_boxed" + mkNameStr n "_boxed" def isBoxedName : Name → Bool -| Name.str _ "_boxed" _ => true -| _ => false + | Name.str _ "_boxed" _ => true + | _ => false abbrev N := StateM Nat private def N.mkFresh : N VarId := -modifyGet fun n => ({ idx := n }, n + 1) + modifyGet fun n => ({ idx := n }, n + 1) def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool := -let ps := decl.params -(ps.size > 0 && (decl.resultType.isScalar || ps.any (fun p => p.ty.isScalar || p.borrow) || isExtern env decl.name)) -|| ps.size > closureMaxArgs + let ps := decl.params + (ps.size > 0 && (decl.resultType.isScalar || ps.any (fun p => p.ty.isScalar || p.borrow) || isExtern env decl.name)) + || ps.size > closureMaxArgs def mkBoxedVersionAux (decl : Decl) : N Decl := do -let ps := decl.params -let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param } -let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do - let p := ps[i] - let q := qs[i] - if !p.ty.isScalar then - pure (newVDecls, xs.push (Arg.var q.x)) - else - let x ← N.mkFresh - pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (arbitrary _)), xs.push (Arg.var x)) -let r ← N.mkFresh -let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (arbitrary _)) -let body ← - if !decl.resultType.isScalar then - pure $ reshape newVDecls (FnBody.ret (Arg.var r)) - else - let newR ← N.mkFresh - let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (arbitrary _)) - pure $ reshape newVDecls (FnBody.ret (Arg.var newR)) -pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body + let ps := decl.params + let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param } + let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do + let p := ps[i] + let q := qs[i] + if !p.ty.isScalar then + pure (newVDecls, xs.push (Arg.var q.x)) + else + let x ← N.mkFresh + pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (arbitrary _)), xs.push (Arg.var x)) + let r ← N.mkFresh + let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (arbitrary _)) + let body ← + if !decl.resultType.isScalar then + pure $ reshape newVDecls (FnBody.ret (Arg.var r)) + else + let newR ← N.mkFresh + let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (arbitrary _)) + pure $ reshape newVDecls (FnBody.ret (Arg.var newR)) + pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body def mkBoxedVersion (decl : Decl) : Decl := -(mkBoxedVersionAux decl).run' 1 + (mkBoxedVersionAux decl).run' 1 def addBoxedVersions (env : Environment) (decls : Array Decl) : Array Decl := -let boxedDecls := decls.foldl (init := #[]) fun newDecls decl => - if requiresBoxedVersion env decl then newDecls.push (mkBoxedVersion decl) else newDecls -decls ++ boxedDecls + let boxedDecls := decls.foldl (init := #[]) fun newDecls decl => + if requiresBoxedVersion env decl then newDecls.push (mkBoxedVersion decl) else newDecls + decls ++ boxedDecls /- Infer scrutinee type using `case` alternatives. This can be done whenever `alts` does not contain an `Alt.default _` value. -/ def getScrutineeType (alts : Array Alt) : IRType := -let isScalar := - alts.size > 1 && -- Recall that we encode Unit and PUnit using `object`. - alts.all fun - | Alt.ctor c _ => c.isScalar - | Alt.default _ => false -match isScalar with -| false => IRType.object -| true => - let n := alts.size - if n < 256 then IRType.uint8 - else if n < 65536 then IRType.uint16 - else if n < 4294967296 then IRType.uint32 - else IRType.object -- in practice this should be unreachable + let isScalar := + alts.size > 1 && -- Recall that we encode Unit and PUnit using `object`. + alts.all fun + | Alt.ctor c _ => c.isScalar + | Alt.default _ => false + match isScalar with + | false => IRType.object + | true => + let n := alts.size + if n < 256 then IRType.uint8 + else if n < 65536 then IRType.uint16 + else if n < 4294967296 then IRType.uint32 + else IRType.object -- in practice this should be unreachable def eqvTypes (t₁ t₂ : IRType) : Bool := -(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂) + (t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂) structure BoxingContext := -(f : FunId := arbitrary _) (localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment) + (f : FunId := arbitrary _) + (localCtx : LocalContext := {}) + (resultType : IRType := IRType.irrelevant) + (decls : Array Decl) + (env : Environment) structure BoxingState := -(nextIdx : Index) -/- We create auxiliary declarations when boxing constant and literals. - The idea is to avoid code such as - ``` - let x1 := Uint64.inhabited; - let x2 := box x1; - ... - ``` - We currently do not cache these declarations in an environment extension, but - we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when - processing the same IR declaration. --/ -(auxDecls : Array Decl := #[]) -(auxDeclCache : AssocList FnBody Expr := Std.AssocList.empty) -(nextAuxId : Nat := 1) + (nextIdx : Index) + /- We create auxiliary declarations when boxing constant and literals. + The idea is to avoid code such as + ``` + let x1 := Uint64.inhabited; + let x2 := box x1; + ... + ``` + We currently do not cache these declarations in an environment extension, but + we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when + processing the same IR declaration. + -/ + (auxDecls : Array Decl := #[]) + (auxDeclCache : AssocList FnBody Expr := Std.AssocList.empty) + (nextAuxId : Nat := 1) abbrev M := ReaderT BoxingContext (StateT BoxingState Id) private def M.mkFresh : M VarId := do -let oldS ← getModify fun s => { s with nextIdx := s.nextIdx + 1 } -pure { idx := oldS.nextIdx } + let oldS ← getModify fun s => { s with nextIdx := s.nextIdx + 1 } + pure { idx := oldS.nextIdx } def getEnv : M Environment := BoxingContext.env <$> read def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read def getResultType : M IRType := BoxingContext.resultType <$> read def getVarType (x : VarId) : M IRType := do -let localCtx ← getLocalContext -match localCtx.getType x with -| some t => pure t -| none => pure IRType.object -- unreachable, we assume the code is well formed + let localCtx ← getLocalContext + match localCtx.getType x with + | some t => pure t + | none => pure IRType.object -- unreachable, we assume the code is well formed def getJPParams (j : JoinPointId) : M (Array Param) := do -let localCtx ← getLocalContext -match localCtx.getJPParams j with -| some ys => pure ys -| none => pure #[] -- unreachable, we assume the code is well formed + let localCtx ← getLocalContext + match localCtx.getJPParams j with + | some ys => pure ys + | none => pure #[] -- unreachable, we assume the code is well formed def getDecl (fid : FunId) : M Decl := do -let ctx ← read -match findEnvDecl' ctx.env fid ctx.decls with -| some decl => pure decl -| none => pure (arbitrary _) -- unreachable if well-formed + let ctx ← read + match findEnvDecl' ctx.env fid ctx.decls with + | some decl => pure decl + | none => pure (arbitrary _) -- unreachable if well-formed @[inline] def withParams {α : Type} (xs : Array Param) (k : M α) : M α := -withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k + withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k @[inline] def withVDecl {α : Type} (x : VarId) (ty : IRType) (v : Expr) (k : M α) : M α := -withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k + withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k @[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α := -withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k + withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k /- If `x` declaration is of the form `x := Expr.lit _` or `x := Expr.fap c #[]`, and `x`'s type is not cheap to box (e.g., it is `UInt64), then return its value. -/ private def isExpensiveConstantValueBoxing (x : VarId) (xType : IRType) : M (Option Expr) := -if !xType.isScalar then pure none -- We assume unboxing is always cheap -else match xType with -| IRType.uint8 => pure none -| IRType.uint16 => pure none -| _ => do - let localCtx ← getLocalContext - match localCtx.getValue x with - | some val => - match val with - | Expr.lit _ => pure $ some val - | Expr.fap _ args => pure $ if args.size == 0 then some val else none - | _ => pure none - | _ => pure none + if !xType.isScalar then + pure none -- We assume unboxing is always cheap + else match xType with + | IRType.uint8 => pure none + | IRType.uint16 => pure none + | _ => do + let localCtx ← getLocalContext + match localCtx.getValue x with + | some val => + match val with + | Expr.lit _ => pure $ some val + | Expr.fap _ args => pure $ if args.size == 0 then some val else none + | _ => pure none + | _ => pure none /- Auxiliary function used by castVarIfNeeded. It is used when the expected type does not match `xType`. If `xType` is scalar, then we need to "box" it. Otherwise, we need to "unbox" it. -/ def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := do -match (← isExpensiveConstantValueBoxing x xType) with -| some v => do - let ctx ← read - let s ← get - /- Create auxiliary FnBody - ``` - let x_1 : xType := v; - let x_2 : expectedType := Expr.box xType x_1; - ret x_2 - ``` - -/ - let body : FnBody := - FnBody.vdecl { idx := 1 } xType v $ - FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $ - FnBody.ret (mkVarArg { idx := 2 }) - match s.auxDeclCache.find? body with - | some v => pure v - | none => do - let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId) - let auxConst := Expr.fap auxName #[] - let auxDecl := Decl.fdecl auxName #[] expectedType body - modify fun s => { - s with - auxDecls := s.auxDecls.push auxDecl, - auxDeclCache := s.auxDeclCache.cons body auxConst, - nextAuxId := s.nextAuxId + 1 - } - pure auxConst -| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x + match (← isExpensiveConstantValueBoxing x xType) with + | some v => do + let ctx ← read + let s ← get + /- Create auxiliary FnBody + ``` + let x_1 : xType := v; + let x_2 : expectedType := Expr.box xType x_1; + ret x_2 + ``` + -/ + let body : FnBody := + FnBody.vdecl { idx := 1 } xType v $ + FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $ + FnBody.ret (mkVarArg { idx := 2 }) + match s.auxDeclCache.find? body with + | some v => pure v + | none => do + let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId) + let auxConst := Expr.fap auxName #[] + let auxDecl := Decl.fdecl auxName #[] expectedType body + modify fun s => { + s with + auxDecls := s.auxDecls.push auxDecl, + auxDeclCache := s.auxDeclCache.cons body auxConst, + nextAuxId := s.nextAuxId + 1 + } + pure auxConst + | none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x @[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody := do -let xType ← getVarType x -if eqvTypes xType expected then k x -else - let y ← M.mkFresh - let v ← mkCast x xType expected - FnBody.vdecl y expected v <$> k y + let xType ← getVarType x + if eqvTypes xType expected then + k x + else + let y ← M.mkFresh + let v ← mkCast x xType expected + FnBody.vdecl y expected v <$> k y @[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody := -match x with -| Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x)) -| _ => k x + match x with + | Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x)) + | _ => k x @[specialize] def castArgsIfNeededAux (xs : Array Arg) (typeFromIdx : Nat → IRType) : M (Array Arg × Array FnBody) := do -let xs' := #[] -let bs := #[] -let i := 0 -for x in xs do - let expected := typeFromIdx i - match x with - | Arg.irrelevant => - xs' := xs'.push x - | Arg.var x => - let xType ← getVarType x - if eqvTypes xType expected then - xs' := xs'.push (Arg.var x) - else - let y ← M.mkFresh - let v ← mkCast x xType expected - let b := FnBody.vdecl y expected v FnBody.nil - xs' := xs'.push (Arg.var y) - bs := bs.push b - i := i + 1 -return (xs', bs) + let xs' := #[] + let bs := #[] + let i := 0 + for x in xs do + let expected := typeFromIdx i + match x with + | Arg.irrelevant => + xs' := xs'.push x + | Arg.var x => + let xType ← getVarType x + if eqvTypes xType expected then + xs' := xs'.push (Arg.var x) + else + let y ← M.mkFresh + let v ← mkCast x xType expected + let b := FnBody.vdecl y expected v FnBody.nil + xs' := xs'.push (Arg.var y) + bs := bs.push b + i := i + 1 + return (xs', bs) @[inline] def castArgsIfNeeded (xs : Array Arg) (ps : Array Param) (k : Array Arg → M FnBody) : M FnBody := do -let (ys, bs) ← castArgsIfNeededAux xs fun i => ps[i].ty -let b ← k ys -pure (reshape bs b) + let (ys, bs) ← castArgsIfNeededAux xs fun i => ps[i].ty + let b ← k ys + pure (reshape bs b) @[inline] def boxArgsIfNeeded (xs : Array Arg) (k : Array Arg → M FnBody) : M FnBody := do -let (ys, bs) ← castArgsIfNeededAux xs (fun _ => IRType.object) -let b ← k ys -pure (reshape bs b) + let (ys, bs) ← castArgsIfNeededAux xs (fun _ => IRType.object) + let b ← k ys + pure (reshape bs b) def unboxResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody := do -if ty.isScalar then - let y ← M.mkFresh - pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b) -else - pure $ FnBody.vdecl x ty e b + if ty.isScalar then + let y ← M.mkFresh + pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b) + else + pure $ FnBody.vdecl x ty e b def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b : FnBody) : M FnBody := do -if eqvTypes ty eType then - pure $ FnBody.vdecl x ty e b -else - let y ← M.mkFresh - let v ← mkCast y eType ty - pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b) + if eqvTypes ty eType then + pure $ FnBody.vdecl x ty e b + else + let y ← M.mkFresh + let v ← mkCast y eType ty + pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b) def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody := -match e with -| Expr.ctor c ys => - if c.isScalar && ty.isScalar then - pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b - else - boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.ctor c ys) b -| Expr.reuse w c u ys => - boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b -| Expr.fap f ys => do - let decl ← getDecl f - castArgsIfNeeded ys decl.params fun ys => - castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b -| Expr.pap f ys => do - let env ← getEnv - let decl ← getDecl f - let f := if requiresBoxedVersion env decl then mkBoxedName f else f - boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.pap f ys) b -| Expr.ap f ys => - boxArgsIfNeeded ys fun ys => - unboxResultIfNeeded x ty (Expr.ap f ys) b -| other => - pure $ FnBody.vdecl x ty e b + match e with + | Expr.ctor c ys => + if c.isScalar && ty.isScalar then + pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b + else + boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.ctor c ys) b + | Expr.reuse w c u ys => + boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b + | Expr.fap f ys => do + let decl ← getDecl f + castArgsIfNeeded ys decl.params fun ys => + castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b + | Expr.pap f ys => do + let env ← getEnv + let decl ← getDecl f + let f := if requiresBoxedVersion env decl then mkBoxedName f else f + boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.pap f ys) b + | Expr.ap f ys => + boxArgsIfNeeded ys fun ys => + unboxResultIfNeeded x ty (Expr.ap f ys) b + | other => + pure $ FnBody.vdecl x ty e b partial def visitFnBody : FnBody → M FnBody -| FnBody.vdecl x t v b => do - let b ← withVDecl x t v (visitFnBody b) - visitVDeclExpr x t v b -| FnBody.jdecl j xs v b => do - let v ← withParams xs (visitFnBody v) - let b ← withJDecl j xs v (visitFnBody b) - pure $ FnBody.jdecl j xs v b -| FnBody.uset x i y b => do - let b ← visitFnBody b - castVarIfNeeded y IRType.usize fun y => - pure $ FnBody.uset x i y b -| FnBody.sset x i o y ty b => do - let b ← visitFnBody b - castVarIfNeeded y ty fun y => - pure $ FnBody.sset x i o y ty b -| FnBody.mdata d b => - FnBody.mdata d <$> visitFnBody b -| FnBody.case tid x _ alts => do - let expected := getScrutineeType alts - let alts ← alts.mapM fun alt => alt.mmodifyBody visitFnBody - castVarIfNeeded x expected fun x => do - pure $ FnBody.case tid x expected alts -| FnBody.ret x => do - let expected ← getResultType - castArgIfNeeded x expected (fun x => pure $ FnBody.ret x) -| FnBody.jmp j ys => do - let ps ← getJPParams j - castArgsIfNeeded ys ps fun ys => pure $ FnBody.jmp j ys -| other => - pure other + | FnBody.vdecl x t v b => do + let b ← withVDecl x t v (visitFnBody b) + visitVDeclExpr x t v b + | FnBody.jdecl j xs v b => do + let v ← withParams xs (visitFnBody v) + let b ← withJDecl j xs v (visitFnBody b) + pure $ FnBody.jdecl j xs v b + | FnBody.uset x i y b => do + let b ← visitFnBody b + castVarIfNeeded y IRType.usize fun y => + pure $ FnBody.uset x i y b + | FnBody.sset x i o y ty b => do + let b ← visitFnBody b + castVarIfNeeded y ty fun y => + pure $ FnBody.sset x i o y ty b + | FnBody.mdata d b => + FnBody.mdata d <$> visitFnBody b + | FnBody.case tid x _ alts => do + let expected := getScrutineeType alts + let alts ← alts.mapM fun alt => alt.mmodifyBody visitFnBody + castVarIfNeeded x expected fun x => do + pure $ FnBody.case tid x expected alts + | FnBody.ret x => do + let expected ← getResultType + castArgIfNeeded x expected (fun x => pure $ FnBody.ret x) + | FnBody.jmp j ys => do + let ps ← getJPParams j + castArgsIfNeeded ys ps fun ys => pure $ FnBody.jmp j ys + | other => + pure other def run (env : Environment) (decls : Array Decl) : Array Decl := -let ctx : BoxingContext := { decls := decls, env := env } -let decls := decls.foldl (init := #[]) fun newDecls decl => - match decl with - | Decl.fdecl f xs t b => - let nextIdx := decl.maxIndex + 1 - let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx } - let newDecls := newDecls ++ s.auxDecls - let newDecl := Decl.fdecl f xs t b - let newDecl := newDecl.elimDead - newDecls.push newDecl - | d => newDecls.push d -addBoxedVersions env decls + let ctx : BoxingContext := { decls := decls, env := env } + let decls := decls.foldl (init := #[]) fun newDecls decl => + match decl with + | Decl.fdecl f xs t b => + let nextIdx := decl.maxIndex + 1 + let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx } + let newDecls := newDecls ++ s.auxDecls + let newDecl := Decl.fdecl f xs t b + let newDecl := newDecl.elimDead + newDecls.push newDecl + | d => newDecls.push d + addBoxedVersions env decls end ExplicitBoxing def explicitBoxing (decls : Array Decl) : CompilerM (Array Decl) := do -let env ← getEnv -pure $ ExplicitBoxing.run env decls + let env ← getEnv + pure $ ExplicitBoxing.run env decls end Lean.IR diff --git a/src/Lean/Compiler/IR/Checker.lean b/src/Lean/Compiler/IR/Checker.lean index e1dca2d271..bce386d275 100644 --- a/src/Lean/Compiler/IR/Checker.lean +++ b/src/Lean/Compiler/IR/Checker.lean @@ -9,159 +9,161 @@ import Lean.Compiler.IR.Format namespace Lean.IR.Checker structure CheckerContext := -(env : Environment) (localCtx : LocalContext := {}) (decls : Array Decl) + (env : Environment) + (localCtx : LocalContext := {}) + (decls : Array Decl) structure CheckerState := -(foundVars : IndexSet := {}) + (foundVars : IndexSet := {}) abbrev M := ReaderT CheckerContext (ExceptT String (StateT CheckerState Id)) def markIndex (i : Index) : M Unit := do -let s ← get -if s.foundVars.contains i then - throw s!"variable / joinpoint index {i} has already been used" -modify fun s => { s with foundVars := s.foundVars.insert i } + let s ← get + if s.foundVars.contains i then + throw s!"variable / joinpoint index {i} has already been used" + modify fun s => { s with foundVars := s.foundVars.insert i } def markVar (x : VarId) : M Unit := -markIndex x.idx + markIndex x.idx def markJP (j : JoinPointId) : M Unit := -markIndex j.idx + markIndex j.idx def getDecl (c : Name) : M Decl := do -let ctx ← read -match findEnvDecl' ctx.env c ctx.decls with -| none => throw s!"unknown declaration '{c}'" -| some d => pure d + let ctx ← read + match findEnvDecl' ctx.env c ctx.decls with + | none => throw s!"unknown declaration '{c}'" + | some d => pure d def checkVar (x : VarId) : M Unit := do -let ctx ← read -unless ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx do - throw s!"unknown variable '{x}'" + let ctx ← read + unless ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx do + throw s!"unknown variable '{x}'" def checkJP (j : JoinPointId) : M Unit := do -let ctx ← read -unless ctx.localCtx.isJP j.idx do - throw s!"unknown join point '{j}'" + let ctx ← read + unless ctx.localCtx.isJP j.idx do + throw s!"unknown join point '{j}'" def checkArg (a : Arg) : M Unit := -match a with -| Arg.var x => checkVar x -| other => pure () + match a with + | Arg.var x => checkVar x + | other => pure () def checkArgs (as : Array Arg) : M Unit := -as.forM checkArg + as.forM checkArg @[inline] def checkEqTypes (ty₁ ty₂ : IRType) : M Unit := do -unless ty₁ == ty₂ do - throw "unexpected type" + unless ty₁ == ty₂ do + throw "unexpected type" @[inline] def checkType (ty : IRType) (p : IRType → Bool) : M Unit := do -unless p ty do - throw s!"unexpected type '{ty}'" + unless p ty do + throw s!"unexpected type '{ty}'" def checkObjType (ty : IRType) : M Unit := checkType ty IRType.isObj def checkScalarType (ty : IRType) : M Unit := checkType ty IRType.isScalar def getType (x : VarId) : M IRType := do -let ctx ← read -match ctx.localCtx.getType x with -| some ty => pure ty -| none => throw s!"unknown variable '{x}'" + let ctx ← read + match ctx.localCtx.getType x with + | some ty => pure ty + | none => throw s!"unknown variable '{x}'" @[inline] def checkVarType (x : VarId) (p : IRType → Bool) : M Unit := do -let ty ← getType x; checkType ty p + let ty ← getType x; checkType ty p def checkObjVar (x : VarId) : M Unit := -checkVarType x IRType.isObj + checkVarType x IRType.isObj def checkScalarVar (x : VarId) : M Unit := -checkVarType x IRType.isScalar + checkVarType x IRType.isScalar def checkFullApp (c : FunId) (ys : Array Arg) : M Unit := do -if c == `hugeFuel then - throw "the auxiliary constant `hugeFuel` cannot be used in code, it is used internally for compiling `partial` definitions" -let decl ← getDecl c -unless ys.size == decl.params.size do - throw s!"incorrect number of arguments to '{c}', {ys.size} provided, {decl.params.size} expected" -checkArgs ys + if c == `hugeFuel then + throw "the auxiliary constant `hugeFuel` cannot be used in code, it is used internally for compiling `partial` definitions" + let decl ← getDecl c + unless ys.size == decl.params.size do + throw s!"incorrect number of arguments to '{c}', {ys.size} provided, {decl.params.size} expected" + checkArgs ys def checkPartialApp (c : FunId) (ys : Array Arg) : M Unit := do -let decl ← getDecl c -unless ys.size < decl.params.size do - throw s!"too many arguments too partial application '{c}', num. args: {ys.size}, arity: {decl.params.size}" -checkArgs ys + let decl ← getDecl c + unless ys.size < decl.params.size do + throw s!"too many arguments too partial application '{c}', num. args: {ys.size}, arity: {decl.params.size}" + checkArgs ys def checkExpr (ty : IRType) : Expr → M Unit -| Expr.pap f ys => checkPartialApp f ys *> checkObjType ty -- partial applications should always produce a closure object -| Expr.ap x ys => checkObjVar x *> checkArgs ys -| Expr.fap f ys => checkFullApp f ys -| Expr.ctor c ys => when (!ty.isStruct && !ty.isUnion && c.isRef) (checkObjType ty) *> checkArgs ys -| Expr.reset _ x => checkObjVar x *> checkObjType ty -| Expr.reuse x i u ys => checkObjVar x *> checkArgs ys *> checkObjType ty -| Expr.box xty x => checkObjType ty *> checkScalarVar x *> checkVarType x (fun t => t == xty) -| Expr.unbox x => checkScalarType ty *> checkObjVar x -| Expr.proj i x => do - let xType ← getType x; - match xType with - | IRType.object => checkObjType ty - | IRType.tobject => checkObjType ty - | IRType.struct _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index" - | IRType.union _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index" - | other => throw s!"unexpected IR type '{xType}'" -| Expr.uproj _ x => checkObjVar x *> checkType ty (fun t => t == IRType.usize) -| Expr.sproj _ _ x => checkObjVar x *> checkScalarType ty -| Expr.isShared x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8) -| Expr.isTaggedPtr x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8) -| Expr.lit (LitVal.str _) => checkObjType ty -| Expr.lit _ => pure () + | Expr.pap f ys => checkPartialApp f ys *> checkObjType ty -- partial applications should always produce a closure object + | Expr.ap x ys => checkObjVar x *> checkArgs ys + | Expr.fap f ys => checkFullApp f ys + | Expr.ctor c ys => when (!ty.isStruct && !ty.isUnion && c.isRef) (checkObjType ty) *> checkArgs ys + | Expr.reset _ x => checkObjVar x *> checkObjType ty + | Expr.reuse x i u ys => checkObjVar x *> checkArgs ys *> checkObjType ty + | Expr.box xty x => checkObjType ty *> checkScalarVar x *> checkVarType x (fun t => t == xty) + | Expr.unbox x => checkScalarType ty *> checkObjVar x + | Expr.proj i x => do + let xType ← getType x; + match xType with + | IRType.object => checkObjType ty + | IRType.tobject => checkObjType ty + | IRType.struct _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index" + | IRType.union _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index" + | other => throw s!"unexpected IR type '{xType}'" + | Expr.uproj _ x => checkObjVar x *> checkType ty (fun t => t == IRType.usize) + | Expr.sproj _ _ x => checkObjVar x *> checkScalarType ty + | Expr.isShared x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8) + | Expr.isTaggedPtr x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8) + | Expr.lit (LitVal.str _) => checkObjType ty + | Expr.lit _ => pure () @[inline] def withParams (ps : Array Param) (k : M Unit) : M Unit := do -let ctx ← read -let localCtx ← ps.foldlM (init := ctx.localCtx) fun (ctx : LocalContext) p => do - markVar p.x - pure $ ctx.addParam p -withReader (fun _ => { ctx with localCtx := localCtx }) k + let ctx ← read + let localCtx ← ps.foldlM (init := ctx.localCtx) fun (ctx : LocalContext) p => do + markVar p.x + pure $ ctx.addParam p + withReader (fun _ => { ctx with localCtx := localCtx }) k partial def checkFnBody : FnBody → M Unit -| FnBody.vdecl x t v b => do - checkExpr t v; - markVar x; - let ctx ← read - withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b) -| FnBody.jdecl j ys v b => do - markJP j; - withParams ys (checkFnBody v); - let ctx ← read - withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b) -| FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b -| FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b -| FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b -| FnBody.setTag x _ b => checkVar x *> checkFnBody b -| FnBody.inc x _ _ _ b => checkVar x *> checkFnBody b -| FnBody.dec x _ _ _ b => checkVar x *> checkFnBody b -| FnBody.del x b => checkVar x *> checkFnBody b -| FnBody.mdata _ b => checkFnBody b -| FnBody.jmp j ys => checkJP j *> checkArgs ys -| FnBody.ret x => checkArg x -| FnBody.case _ x _ alts => checkVar x *> alts.forM (fun alt => checkFnBody alt.body) -| FnBody.unreachable => pure () + | FnBody.vdecl x t v b => do + checkExpr t v; + markVar x; + let ctx ← read + withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b) + | FnBody.jdecl j ys v b => do + markJP j; + withParams ys (checkFnBody v); + let ctx ← read + withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b) + | FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b + | FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b + | FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b + | FnBody.setTag x _ b => checkVar x *> checkFnBody b + | FnBody.inc x _ _ _ b => checkVar x *> checkFnBody b + | FnBody.dec x _ _ _ b => checkVar x *> checkFnBody b + | FnBody.del x b => checkVar x *> checkFnBody b + | FnBody.mdata _ b => checkFnBody b + | FnBody.jmp j ys => checkJP j *> checkArgs ys + | FnBody.ret x => checkArg x + | FnBody.case _ x _ alts => checkVar x *> alts.forM (fun alt => checkFnBody alt.body) + | FnBody.unreachable => pure () def checkDecl : Decl → M Unit -| Decl.fdecl f xs t b => withParams xs (checkFnBody b) -| Decl.extern f xs t _ => withParams xs (pure ()) + | Decl.fdecl f xs t b => withParams xs (checkFnBody b) + | Decl.extern f xs t _ => withParams xs (pure ()) end Checker def checkDecl (decls : Array Decl) (decl : Decl) : CompilerM Unit := do -let env ← getEnv -match (Checker.checkDecl decl { env := env, decls := decls }).run' {} with -| Except.error msg => throw s!"IR check failed at '{decl.name}', error: {msg}" -| other => pure () + let env ← getEnv + match (Checker.checkDecl decl { env := env, decls := decls }).run' {} with + | Except.error msg => throw s!"IR check failed at '{decl.name}', error: {msg}" + | other => pure () def checkDecls (decls : Array Decl) : CompilerM Unit := -decls.forM (checkDecl decls) + decls.forM (checkDecl decls) end IR end Lean diff --git a/src/Lean/Compiler/IR/CompilerM.lean b/src/Lean/Compiler/IR/CompilerM.lean index 07631584cb..f251930060 100644 --- a/src/Lean/Compiler/IR/CompilerM.lean +++ b/src/Lean/Compiler/IR/CompilerM.lean @@ -10,13 +10,13 @@ import Lean.Compiler.IR.Format namespace Lean.IR inductive LogEntry -| step (cls : Name) (decls : Array Decl) -| message (msg : Format) + | step (cls : Name) (decls : Array Decl) + | message (msg : Format) namespace LogEntry protected def fmt : LogEntry → Format -| step cls decls => Format.bracket "[" (format cls) "]" ++ decls.foldl (fun fmt decl => fmt ++ Format.line ++ format decl) Format.nil -| message msg => msg + | step cls decls => Format.bracket "[" (format cls) "]" ++ decls.foldl (fun fmt decl => fmt ++ Format.line ++ format decl) Format.nil + | message msg => msg instance : ToFormat LogEntry := ⟨LogEntry.fmt⟩ end LogEntry @@ -24,49 +24,50 @@ end LogEntry abbrev Log := Array LogEntry def Log.format (log : Log) : Format := -log.foldl (init := Format.nil) fun fmt entry => - f!"{fmt}{Format.line}{entry}" + log.foldl (init := Format.nil) fun fmt entry => + f!"{fmt}{Format.line}{entry}" @[export lean_ir_log_to_string] def Log.toString (log : Log) : String := -log.format.pretty + log.format.pretty structure CompilerState := -(env : Environment) (log : Log := #[]) + (env : Environment) + (log : Log := #[]) abbrev CompilerM := ReaderT Options (EStateM String CompilerState) def log (entry : LogEntry) : CompilerM Unit := -modify $ fun s => { s with log := s.log.push entry } + modify fun s => { s with log := s.log.push entry } def tracePrefixOptionName := `trace.compiler.ir private def isLogEnabledFor (opts : Options) (optName : Name) : Bool := -match opts.find optName with -| some (DataValue.ofBool v) => v -| other => opts.getBool tracePrefixOptionName + match opts.find optName with + | some (DataValue.ofBool v) => v + | other => opts.getBool tracePrefixOptionName private def logDeclsAux (optName : Name) (cls : Name) (decls : Array Decl) : CompilerM Unit := do -let opts ← read -if isLogEnabledFor opts optName then - log (LogEntry.step cls decls) + let opts ← read + if isLogEnabledFor opts optName then + log (LogEntry.step cls decls) @[inline] def logDecls (cls : Name) (decl : Array Decl) : CompilerM Unit := -logDeclsAux (tracePrefixOptionName ++ cls) cls decl + logDeclsAux (tracePrefixOptionName ++ cls) cls decl private def logMessageIfAux {α : Type} [ToFormat α] (optName : Name) (a : α) : CompilerM Unit := do -let opts ← read -if isLogEnabledFor opts optName then - log (LogEntry.message (format a)) + let opts ← read + if isLogEnabledFor opts optName then + log (LogEntry.message (format a)) @[inline] def logMessageIf {α : Type} [ToFormat α] (cls : Name) (a : α) : CompilerM Unit := -logMessageIfAux (tracePrefixOptionName ++ cls) a + logMessageIfAux (tracePrefixOptionName ++ cls) a @[inline] def logMessage {α : Type} [ToFormat α] (cls : Name) (a : α) : CompilerM Unit := -logMessageIfAux tracePrefixOptionName a + logMessageIfAux tracePrefixOptionName a @[inline] def modifyEnv (f : Environment → Environment) : CompilerM Unit := -modify fun s => { s with env := f s.env } + modify fun s => { s with env := f s.env } open Std (HashMap) @@ -75,10 +76,10 @@ abbrev DeclMap := SMap Name Decl /- Create an array of decls to be saved on .olean file. `decls` may contain duplicate entries, but we assume the one that occurs last is the most recent one. -/ private def mkEntryArray (decls : List Decl) : Array Decl := -/- Remove duplicates by adding decls into a map -/ -let map : HashMap Name Decl := {} -let map := decls.foldl (init := map) fun map decl => map.insert decl.name decl -map.fold (fun a k v => a.push v) #[] + /- Remove duplicates by adding decls into a map -/ + let map : HashMap Name Decl := {} + let map := decls.foldl (init := map) fun map decl => map.insert decl.name decl + map.fold (fun a k v => a.push v) #[] builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ← registerSimplePersistentEnvExtension { @@ -92,53 +93,54 @@ builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ← @[export lean_ir_find_env_decl] def findEnvDecl (env : Environment) (n : Name) : Option Decl := -(declMapExt.getState env).find? n + (declMapExt.getState env).find? n def findDecl (n : Name) : CompilerM (Option Decl) := do -let s ← get -pure $ findEnvDecl s.env n + let s ← get + pure $ findEnvDecl s.env n def containsDecl (n : Name) : CompilerM Bool := do -let s ← get -pure $ (declMapExt.getState s.env).contains n - -def getDecl (n : Name) : CompilerM Decl := do -let (some decl) ← findDecl n | throw s!"unknown declaration '{n}'" -pure decl - -@[export lean_ir_add_decl] -def addDeclAux (env : Environment) (decl : Decl) : Environment := -declMapExt.addEntry env decl - -def getDecls (env : Environment) : List Decl := -declMapExt.getEntries env - -def getEnv : CompilerM Environment := do -let s ← get; pure s.env - -def addDecl (decl : Decl) : CompilerM Unit := -modifyEnv fun env => declMapExt.addEntry env decl - -def addDecls (decls : Array Decl) : CompilerM Unit := -decls.forM addDecl - -def findEnvDecl' (env : Environment) (n : Name) (decls : Array Decl) : Option Decl := -match decls.find? (fun decl => decl.name == n) with -| some decl => some decl -| none => (declMapExt.getState env).find? n - -def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) := do -let s ← get; pure $ findEnvDecl' s.env n decls - -def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool := -if decls.any (fun decl => decl.name == n) then pure true -else do let s ← get pure $ (declMapExt.getState s.env).contains n +def getDecl (n : Name) : CompilerM Decl := do + let (some decl) ← findDecl n | throw s!"unknown declaration '{n}'" + pure decl + +@[export lean_ir_add_decl] +def addDeclAux (env : Environment) (decl : Decl) : Environment := + declMapExt.addEntry env decl + +def getDecls (env : Environment) : List Decl := + declMapExt.getEntries env + +def getEnv : CompilerM Environment := do + let s ← get; pure s.env + +def addDecl (decl : Decl) : CompilerM Unit := + modifyEnv fun env => declMapExt.addEntry env decl + +def addDecls (decls : Array Decl) : CompilerM Unit := + decls.forM addDecl + +def findEnvDecl' (env : Environment) (n : Name) (decls : Array Decl) : Option Decl := + match decls.find? (fun decl => decl.name == n) with + | some decl => some decl + | none => (declMapExt.getState env).find? n + +def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) := do + let s ← get; pure $ findEnvDecl' s.env n decls + +def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool := do + if decls.any fun decl => decl.name == n then + pure true + else + let s ← get + pure $ (declMapExt.getState s.env).contains n + def getDecl' (n : Name) (decls : Array Decl) : CompilerM Decl := do -let (some decl) ← findDecl' n decls | throw s!"unknown declaration '{n}'" -pure decl + let (some decl) ← findDecl' n decls | throw s!"unknown declaration '{n}'" + pure decl end IR end Lean diff --git a/src/Lean/Compiler/IR/CtorLayout.lean b/src/Lean/Compiler/IR/CtorLayout.lean index 8fdda8dc89..8555f8a622 100644 --- a/src/Lean/Compiler/IR/CtorLayout.lean +++ b/src/Lean/Compiler/IR/CtorLayout.lean @@ -10,32 +10,32 @@ namespace Lean namespace IR inductive CtorFieldInfo -| irrelevant -| object (i : Nat) -| usize (i : Nat) -| scalar (sz : Nat) (offset : Nat) (type : IRType) + | irrelevant + | object (i : Nat) + | usize (i : Nat) + | scalar (sz : Nat) (offset : Nat) (type : IRType) namespace CtorFieldInfo def format : CtorFieldInfo → Format -| irrelevant => "◾" -| object i => f!"obj@{i}" -| usize i => f!"usize@{i}" -| scalar sz offset type => f!"scalar#{sz}@{offset}:{type}" + | irrelevant => "◾" + | object i => f!"obj@{i}" + | usize i => f!"usize@{i}" + | scalar sz offset type => f!"scalar#{sz}@{offset}:{type}" instance : ToFormat CtorFieldInfo := ⟨format⟩ end CtorFieldInfo structure CtorLayout := -(cidx : Nat) -(fieldInfo : List CtorFieldInfo) -(numObjs : Nat) -(numUSize : Nat) -(scalarSize : Nat) + (cidx : Nat) + (fieldInfo : List CtorFieldInfo) + (numObjs : Nat) + (numUSize : Nat) + (scalarSize : Nat) @[extern "lean_ir_get_ctor_layout"] -constant getCtorLayout (env : @& Environment) (ctorName : @& Name) : Except String CtorLayout := arbitrary _ +constant getCtorLayout (env : @& Environment) (ctorName : @& Name) : Except String CtorLayout end IR end Lean diff --git a/src/Lean/Compiler/IR/ElimDeadBranches.lean b/src/Lean/Compiler/IR/ElimDeadBranches.lean index 640c153915..09330522af 100644 --- a/src/Lean/Compiler/IR/ElimDeadBranches.lean +++ b/src/Lean/Compiler/IR/ElimDeadBranches.lean @@ -11,51 +11,51 @@ namespace Lean.IR.UnreachableBranches /-- Value used in the abstract interpreter -/ inductive Value -| bot -- undefined -| top -- any value -| ctor (i : CtorInfo) (vs : Array Value) -| choice (vs : List Value) + | bot -- undefined + | top -- any value + | ctor (i : CtorInfo) (vs : Array Value) + | choice (vs : List Value) namespace Value instance : Inhabited Value := ⟨top⟩ protected partial def beq : Value → Value → Bool -| bot, bot => true -| top, top => true -| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq -| choice vs₁, choice vs₂ => - vs₁.all (fun v₁ => vs₂.any fun v₂ => Value.beq v₁ v₂) - && - vs₂.all (fun v₂ => vs₁.any fun v₁ => Value.beq v₁ v₂) -| _, _ => false + | bot, bot => true + | top, top => true + | ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq + | choice vs₁, choice vs₂ => + vs₁.all (fun v₁ => vs₂.any fun v₂ => Value.beq v₁ v₂) + && + vs₂.all (fun v₂ => vs₁.any fun v₁ => Value.beq v₁ v₂) + | _, _ => false instance : BEq Value := ⟨Value.beq⟩ partial def addChoice (merge : Value → Value → Value) : List Value → Value → List Value -| [], v => [v] -| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) => - if i₁ == i₂ then merge v₁ v₂ :: cs - else v₁ :: addChoice merge cs v₂ -| _, _ => panic! "invalid addChoice" + | [], v => [v] + | v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) => + if i₁ == i₂ then merge v₁ v₂ :: cs + else v₁ :: addChoice merge cs v₂ + | _, _ => panic! "invalid addChoice" partial def merge : Value → Value → Value -| bot, v => v -| v, bot => v -| top, _ => top -| _, top => top -| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) => - if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i] vs₂[i]) - else choice [v₁, v₂] -| choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂ -| choice vs, v => choice $ addChoice merge vs v -| v, choice vs => choice $ addChoice merge vs v + | bot, v => v + | v, bot => v + | top, _ => top + | _, top => top + | v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) => + if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i] vs₂[i]) + else choice [v₁, v₂] + | choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂ + | choice vs, v => choice $ addChoice merge vs v + | v, choice vs => choice $ addChoice merge vs v protected partial def format : Value → Format -| top => "top" -| bot => "bot" -| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs -| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs) + | top => "top" + | bot => "bot" + | choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs + | ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs) instance : ToFormat Value := ⟨Value.format⟩ instance : ToString Value := ⟨Format.pretty ∘ Value.format⟩ @@ -64,240 +64,236 @@ instance : ToString Value := ⟨Format.pretty ∘ Value.format⟩ We use this function this function to implement a simple widening operation for our abstract interpreter. -/ partial def truncate (env : Environment) : Value → NameSet → Value -| ctor i vs, found => - let I := i.name.getPrefix - if found.contains I then - top - else - let cont (found' : NameSet) : Value := - ctor i (vs.map fun v => truncate env v found') - match env.find? I with - | some (ConstantInfo.inductInfo d) => - if d.isRec then cont (found.insert I) - else cont found - | _ => cont found -| choice vs, found => - let newVs := vs.map fun v => truncate env v found - if newVs.elem top then top - else choice newVs -| v, _ => v + | ctor i vs, found => + let I := i.name.getPrefix + if found.contains I then + top + else + let cont (found' : NameSet) : Value := + ctor i (vs.map fun v => truncate env v found') + match env.find? I with + | some (ConstantInfo.inductInfo d) => + if d.isRec then cont (found.insert I) + else cont found + | _ => cont found + | choice vs, found => + let newVs := vs.map fun v => truncate env v found + if newVs.elem top then top + else choice newVs + | v, _ => v /- Widening operator that guarantees termination in our abstract interpreter. -/ def widening (env : Environment) (v₁ v₂ : Value) : Value := -truncate env (merge v₁ v₂) {} + truncate env (merge v₁ v₂) {} end Value abbrev FunctionSummaries := SMap FunId Value -def mkFunctionSummariesExtension : IO (SimplePersistentEnvExtension (FunId × Value) FunctionSummaries) := -registerSimplePersistentEnvExtension { - name := `unreachBranchesFunSummary, - addImportedFn := fun as => - let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as - cache.switch, - addEntryFn := fun s ⟨e, n⟩ => s.insert e n -} - -@[builtinInit mkFunctionSummariesExtension] -constant functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries := arbitrary _ +builtin_initialize functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries ← + registerSimplePersistentEnvExtension { + name := `unreachBranchesFunSummary, + addImportedFn := fun as => + let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as + cache.switch, + addEntryFn := fun s ⟨e, n⟩ => s.insert e n + } def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment := -functionSummariesExt.addEntry env (fid, v) + functionSummariesExt.addEntry env (fid, v) def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value := -(functionSummariesExt.getState env).find? fid + (functionSummariesExt.getState env).find? fid abbrev Assignment := Std.HashMap VarId Value structure InterpContext := -(currFnIdx : Nat := 0) -(decls : Array Decl) -(env : Environment) -(lctx : LocalContext := {}) + (currFnIdx : Nat := 0) + (decls : Array Decl) + (env : Environment) + (lctx : LocalContext := {}) structure InterpState := -(assignments : Array Assignment) -(funVals : Std.PArray Value) -- we take snapshots during fixpoint computations + (assignments : Array Assignment) + (funVals : Std.PArray Value) -- we take snapshots during fixpoint computations abbrev M := ReaderT InterpContext (StateM InterpState) open Value def findVarValue (x : VarId) : M Value := do -let ctx ← read -let s ← get -let assignment := s.assignments[ctx.currFnIdx] -pure $ assignment.findD x bot + let ctx ← read + let s ← get + let assignment := s.assignments[ctx.currFnIdx] + pure $ assignment.findD x bot def findArgValue (arg : Arg) : M Value := -match arg with -| Arg.var x => findVarValue x -| _ => pure top + match arg with + | Arg.var x => findVarValue x + | _ => pure top def updateVarAssignment (x : VarId) (v : Value) : M Unit := do -let v' ← findVarValue x -let ctx ← read -modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') } + let v' ← findVarValue x + let ctx ← read + modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') } def resetVarAssignment (x : VarId) : M Unit := do -let ctx ← read -modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot } + let ctx ← read + modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot } def resetParamAssignment (y : Param) : M Unit := -resetVarAssignment y.x + resetVarAssignment y.x partial def projValue : Value → Nat → Value -| ctor _ vs, i => vs.getD i bot -| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot -| v, _ => v + | ctor _ vs, i => vs.getD i bot + | choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot + | v, _ => v def interpExpr : Expr → M Value -| Expr.ctor i ys => do return ctor i (← ys.mapM fun y => findArgValue y) -| Expr.proj i x => do return projValue (← findVarValue x) i -| Expr.fap fid ys => do - let ctx ← read - match getFunctionSummary? ctx.env fid with - | some v => pure v - | none => do - let s ← get - match ctx.decls.findIdx? (fun decl => decl.name == fid) with - | some idx => pure s.funVals[idx] - | none => pure top -| _ => pure top + | Expr.ctor i ys => do return ctor i (← ys.mapM fun y => findArgValue y) + | Expr.proj i x => do return projValue (← findVarValue x) i + | Expr.fap fid ys => do + let ctx ← read + match getFunctionSummary? ctx.env fid with + | some v => pure v + | none => do + let s ← get + match ctx.decls.findIdx? (fun decl => decl.name == fid) with + | some idx => pure s.funVals[idx] + | none => pure top + | _ => pure top partial def containsCtor : Value → CtorInfo → Bool -| top, _ => true -| ctor i _, j => i == j -| choice vs, j => vs.any $ fun v => containsCtor v j -| _, _ => false + | top, _ => true + | ctor i _, j => i == j + | choice vs, j => vs.any $ fun v => containsCtor v j + | _, _ => false def updateCurrFnSummary (v : Value) : M Unit := do -let ctx ← read -let currFnIdx := ctx.currFnIdx -modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') } + let ctx ← read + let currFnIdx := ctx.currFnIdx + modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') } /-- Return true if the assignment of at least one parameter has been updated. -/ def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do -let ctx ← read -let currFnIdx := ctx.currFnIdx -ys.size.foldM (init := false) fun i r => do - let y := ys[i] - let x := xs[i] - let yVal ← findVarValue y.x - let xVal ← findArgValue x - let newVal := merge yVal xVal - if newVal == yVal then - pure r - else - modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal } - pure true - -private partial def resetNestedJPParams : FnBody → M Unit -| FnBody.jdecl _ ys b k => do let ctx ← read let currFnIdx := ctx.currFnIdx - ys.forM resetParamAssignment - /- Remark we don't need to reset the parameters of joint-points - nested in `b` since they will be reset if this JP is used. -/ - resetNestedJPParams k -| FnBody.case _ _ _ alts => - alts.forM fun alt => match alt with - | Alt.ctor _ b => resetNestedJPParams b - | Alt.default b => resetNestedJPParams b -| e => do unless e.isTerminal do resetNestedJPParams e.body + ys.size.foldM (init := false) fun i r => do + let y := ys[i] + let x := xs[i] + let yVal ← findVarValue y.x + let xVal ← findArgValue x + let newVal := merge yVal xVal + if newVal == yVal then + pure r + else + modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal } + pure true + +private partial def resetNestedJPParams : FnBody → M Unit + | FnBody.jdecl _ ys b k => do + let ctx ← read + let currFnIdx := ctx.currFnIdx + ys.forM resetParamAssignment + /- Remark we don't need to reset the parameters of joint-points + nested in `b` since they will be reset if this JP is used. -/ + resetNestedJPParams k + | FnBody.case _ _ _ alts => + alts.forM fun alt => match alt with + | Alt.ctor _ b => resetNestedJPParams b + | Alt.default b => resetNestedJPParams b + | e => do unless e.isTerminal do resetNestedJPParams e.body partial def interpFnBody : FnBody → M Unit -| FnBody.vdecl x _ e b => do - let v ← interpExpr e - updateVarAssignment x v - interpFnBody b -| FnBody.jdecl j ys v b => - withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do + | FnBody.vdecl x _ e b => do + let v ← interpExpr e + updateVarAssignment x v interpFnBody b -| FnBody.case _ x _ alts => do - let v ← findVarValue x - alts.forM fun alt => do - match alt with - | Alt.ctor i b => if containsCtor v i then interpFnBody b - | Alt.default b => interpFnBody b -| FnBody.ret x => do - let v ← findArgValue x - -- dbgTrace ("ret " ++ toString v) $ fun _ => - updateCurrFnSummary v -| FnBody.jmp j xs => do - let ctx ← read - let ys := (ctx.lctx.getJPParams j).get! - let b := (ctx.lctx.getJPBody j).get! - let updated ← updateJPParamsAssignment ys xs - if updated then - -- We must reset the value of nested join-point parameters since they depend on `ys` values - resetNestedJPParams b - interpFnBody b -| e => do - unless e.isTerminal do - interpFnBody e.body + | FnBody.jdecl j ys v b => + withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do + interpFnBody b + | FnBody.case _ x _ alts => do + let v ← findVarValue x + alts.forM fun alt => do + match alt with + | Alt.ctor i b => if containsCtor v i then interpFnBody b + | Alt.default b => interpFnBody b + | FnBody.ret x => do + let v ← findArgValue x + -- dbgTrace ("ret " ++ toString v) $ fun _ => + updateCurrFnSummary v + | FnBody.jmp j xs => do + let ctx ← read + let ys := (ctx.lctx.getJPParams j).get! + let b := (ctx.lctx.getJPBody j).get! + let updated ← updateJPParamsAssignment ys xs + if updated then + -- We must reset the value of nested join-point parameters since they depend on `ys` values + resetNestedJPParams b + interpFnBody b + | e => do + unless e.isTerminal do + interpFnBody e.body def inferStep : M Bool := do -let ctx ← read -modify fun s => { s with assignments := ctx.decls.map fun _ => {} } -ctx.decls.size.foldM (init := false) fun idx modified => do - match ctx.decls[idx] with - | Decl.fdecl fid ys _ b => do - let s ← get - let currVals := s.funVals[idx] - withReader (fun ctx => { ctx with currFnIdx := idx }) do - ys.forM fun y => updateVarAssignment y.x top - interpFnBody b + let ctx ← read + modify fun s => { s with assignments := ctx.decls.map fun _ => {} } + ctx.decls.size.foldM (init := false) fun idx modified => do + match ctx.decls[idx] with + | Decl.fdecl fid ys _ b => do let s ← get - let newVals := s.funVals[idx] - pure (modified || currVals != newVals) - | Decl.extern _ _ _ _ => pure modified + let currVals := s.funVals[idx] + withReader (fun ctx => { ctx with currFnIdx := idx }) do + ys.forM fun y => updateVarAssignment y.x top + interpFnBody b + let s ← get + let newVals := s.funVals[idx] + pure (modified || currVals != newVals) + | Decl.extern _ _ _ _ => pure modified -partial def inferMain : Unit → M Unit -| _ => do +partial def inferMain : M Unit := do let modified ← inferStep - if modified then inferMain () else pure () + if modified then inferMain else pure () partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody -| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b) -| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b) -| FnBody.case tid x xType alts => - let v := assignment.findD x bot - let alts := alts.map fun alt => - match alt with - | Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable - | Alt.default b => Alt.default (elimDeadAux assignment b) - FnBody.case tid x xType alts -| e => - if e.isTerminal then e - else - let (instr, b) := e.split - let b := elimDeadAux assignment b - instr.setBody b + | FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b) + | FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b) + | FnBody.case tid x xType alts => + let v := assignment.findD x bot + let alts := alts.map fun alt => + match alt with + | Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable + | Alt.default b => Alt.default (elimDeadAux assignment b) + FnBody.case tid x xType alts + | e => + if e.isTerminal then e + else + let (instr, b) := e.split + let b := elimDeadAux assignment b + instr.setBody b partial def elimDead (assignment : Assignment) : Decl → Decl -| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b -| other => other + | Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b + | other => other end UnreachableBranches open UnreachableBranches def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do -let s ← get -let env := s.env -let assignments : Array Assignment := decls.map fun _ => {} -let funVals := Std.mkPArray decls.size Value.bot -let ctx : InterpContext := { decls := decls, env := env } -let s : InterpState := { assignments := assignments, funVals := funVals } -let (_, s) := (inferMain () ctx).run s -let funVals := s.funVals -let assignments := s.assignments -modify fun s => - let env := decls.size.fold (init := s.env) fun i env => - addFunctionSummary env decls[i].name funVals[i] - { s with env := env } -pure $ decls.mapIdx fun i decl => elimDead assignments[i] decl + let s ← get + let env := s.env + let assignments : Array Assignment := decls.map fun _ => {} + let funVals := Std.mkPArray decls.size Value.bot + let ctx : InterpContext := { decls := decls, env := env } + let s : InterpState := { assignments := assignments, funVals := funVals } + let (_, s) := (inferMain ctx).run s + let funVals := s.funVals + let assignments := s.assignments + modify fun s => + let env := decls.size.fold (init := s.env) fun i env => + addFunctionSummary env decls[i].name funVals[i] + { s with env := env } + pure $ decls.mapIdx fun i decl => elimDead assignments[i] decl end Lean.IR diff --git a/src/Lean/Compiler/IR/ElimDeadVars.lean b/src/Lean/Compiler/IR/ElimDeadVars.lean index 1772f41c81..145ade379a 100644 --- a/src/Lean/Compiler/IR/ElimDeadVars.lean +++ b/src/Lean/Compiler/IR/ElimDeadVars.lean @@ -8,30 +8,27 @@ import Lean.Compiler.IR.FreeVars namespace Lean.IR -partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnBody -| bs, b, used => - if bs.isEmpty then b - else - let curr := bs.back - let bs := bs.pop - let keep (_ : Unit) := - let used := curr.collectFreeIndices used - let b := curr.setBody b - reshapeWithoutDeadAux bs b used - let keepIfUsed (vidx : Index) := - if used.contains vidx then keep () - else reshapeWithoutDeadAux bs b used - match curr with - | FnBody.vdecl x _ _ _ => keepIfUsed x.idx - -- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed. - | FnBody.jdecl j _ _ _ => keepIfUsed j.idx - | _ => keep () +partial def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody := + let rec reshape (bs : Array FnBody) (b : FnBody) (used : IndexSet) := + if bs.isEmpty then b + else + let curr := bs.back + let bs := bs.pop + let keep (_ : Unit) := + let used := curr.collectFreeIndices used + let b := curr.setBody b + reshape bs b used + let keepIfUsed (vidx : Index) := + if used.contains vidx then keep () + else reshape bs b used + match curr with + | FnBody.vdecl x _ _ _ => keepIfUsed x.idx + -- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed. + | FnBody.jdecl j _ _ _ => keepIfUsed j.idx + | _ => keep () + reshape bs term term.freeIndices -def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody := -reshapeWithoutDeadAux bs term term.freeIndices - -partial def FnBody.elimDead : FnBody → FnBody -| b => +partial def FnBody.elimDead (b : FnBody) : FnBody := let (bs, term) := b.flatten let bs := modifyJPs bs elimDead let term := match term with @@ -43,7 +40,7 @@ partial def FnBody.elimDead : FnBody → FnBody /-- Eliminate dead let-declarations and join points -/ def Decl.elimDead : Decl → Decl -| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead -| other => other + | Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead + | other => other end Lean.IR diff --git a/src/Lean/Compiler/IR/EmitC.lean b/src/Lean/Compiler/IR/EmitC.lean index b93b18beae..c30a390754 100644 --- a/src/Lean/Compiler/IR/EmitC.lean +++ b/src/Lean/Compiler/IR/EmitC.lean @@ -19,501 +19,501 @@ open ExplicitBoxing (requiresBoxedVersion mkBoxedName isBoxedName) def leanMainFn := "_lean_main" structure Context := -(env : Environment) -(modName : Name) -(jpMap : JPParamsMap := {}) -(mainFn : FunId := arbitrary _) -(mainParams : Array Param := #[]) + (env : Environment) + (modName : Name) + (jpMap : JPParamsMap := {}) + (mainFn : FunId := arbitrary _) + (mainParams : Array Param := #[]) abbrev M := ReaderT Context (EStateM String String) def getEnv : M Environment := Context.env <$> read def getModName : M Name := Context.modName <$> read def getDecl (n : Name) : M Decl := do -let env ← getEnv -match findEnvDecl env n with -| some d => pure d -| none => throw s!"unknown declaration '{n}'" + let env ← getEnv + match findEnvDecl env n with + | some d => pure d + | none => throw s!"unknown declaration '{n}'" @[inline] def emit {α : Type} [ToString α] (a : α) : M Unit := -modify fun out => out ++ toString a + modify fun out => out ++ toString a @[inline] def emitLn {α : Type} [ToString α] (a : α) : M Unit := do -emit a; emit "\n" + emit a; emit "\n" def emitLns {α : Type} [ToString α] (as : List α) : M Unit := -as.forM fun a => emitLn a + as.forM fun a => emitLn a def argToCString (x : Arg) : String := -match x with -| Arg.var x => toString x -| _ => "lean_box(0)" + match x with + | Arg.var x => toString x + | _ => "lean_box(0)" def emitArg (x : Arg) : M Unit := -emit (argToCString x) + emit (argToCString x) def toCType : IRType → String -| IRType.float => "double" -| IRType.uint8 => "uint8_t" -| IRType.uint16 => "uint16_t" -| IRType.uint32 => "uint32_t" -| IRType.uint64 => "uint64_t" -| IRType.usize => "size_t" -| IRType.object => "lean_object*" -| IRType.tobject => "lean_object*" -| IRType.irrelevant => "lean_object*" -| IRType.struct _ _ => panic! "not implemented yet" -| IRType.union _ _ => panic! "not implemented yet" + | IRType.float => "double" + | IRType.uint8 => "uint8_t" + | IRType.uint16 => "uint16_t" + | IRType.uint32 => "uint32_t" + | IRType.uint64 => "uint64_t" + | IRType.usize => "size_t" + | IRType.object => "lean_object*" + | IRType.tobject => "lean_object*" + | IRType.irrelevant => "lean_object*" + | IRType.struct _ _ => panic! "not implemented yet" + | IRType.union _ _ => panic! "not implemented yet" def throwInvalidExportName {α : Type} (n : Name) : M α := -throw s!"invalid export name '{n}'" + throw s!"invalid export name '{n}'" def toCName (n : Name) : M String := do -let env ← getEnv; --- TODO: we should support simple export names only -match getExportNameFor env n with -| some (Name.str Name.anonymous s _) => pure s -| some _ => throwInvalidExportName n -| none => if n == `main then pure leanMainFn else pure n.mangle + let env ← getEnv; + -- TODO: we should support simple export names only + match getExportNameFor env n with + | some (Name.str Name.anonymous s _) => pure s + | some _ => throwInvalidExportName n + | none => if n == `main then pure leanMainFn else pure n.mangle def emitCName (n : Name) : M Unit := -toCName n >>= emit + toCName n >>= emit def toCInitName (n : Name) : M String := do -let env ← getEnv; --- TODO: we should support simple export names only -match getExportNameFor env n with -| some (Name.str Name.anonymous s _) => pure $ "_init_" ++ s -| some _ => throwInvalidExportName n -| none => pure ("_init_" ++ n.mangle) + let env ← getEnv; + -- TODO: we should support simple export names only + match getExportNameFor env n with + | some (Name.str Name.anonymous s _) => pure $ "_init_" ++ s + | some _ => throwInvalidExportName n + | none => pure ("_init_" ++ n.mangle) def emitCInitName (n : Name) : M Unit := -toCInitName n >>= emit + toCInitName n >>= emit def emitFnDeclAux (decl : Decl) (cppBaseName : String) (addExternForConsts : Bool) : M Unit := do -let ps := decl.params -let env ← getEnv -if ps.isEmpty && addExternForConsts then emit "extern " -emit (toCType decl.resultType ++ " " ++ cppBaseName) -unless ps.isEmpty do - emit "(" - -- We omit irrelevant parameters for extern constants - let ps := if isExternC env decl.name then ps.filter (fun p => !p.ty.isIrrelevant) else ps - if ps.size > closureMaxArgs && isBoxedName decl.name then - emit "lean_object**" - else - ps.size.forM fun i => do - if i > 0 then emit ", " - emit (toCType ps[i].ty) - emit ")" -emitLn ";" - -def emitFnDecl (decl : Decl) (addExternForConsts : Bool) : M Unit := do -let cppBaseName ← toCName decl.name -emitFnDeclAux decl cppBaseName addExternForConsts - -def emitExternDeclAux (decl : Decl) (cNameStr : String) : M Unit := do -let cName := mkNameSimple cNameStr -let env ← getEnv -let extC := isExternC env decl.name -emitFnDeclAux decl cNameStr (!extC) - -def emitFnDecls : M Unit := do -let env ← getEnv -let decls := getDecls env -let modDecls : NameSet := decls.foldl (fun s d => s.insert d.name) {} -let usedDecls : NameSet := decls.foldl (fun s d => collectUsedDecls env d (s.insert d.name)) {} -let usedDecls := usedDecls.toList -usedDecls.forM fun n => do - let decl ← getDecl n; - match getExternNameFor env `c decl.name with - | some cName => emitExternDeclAux decl cName - | none => emitFnDecl decl (!modDecls.contains n) - -def emitMainFn : M Unit := do -let d ← getDecl `main -match d with -| Decl.fdecl f xs t b => do - unless xs.size == 2 || xs.size == 1 do throw "invalid main function, incorrect arity when generating code" + let ps := decl.params let env ← getEnv - let usesLeanAPI := usesModuleFrom env `Lean - if usesLeanAPI then - emitLn "void lean_initialize();" - else - emitLn "void lean_initialize_runtime_module();"; - emitLn " -#if defined(WIN32) || defined(_WIN32) -#include -#endif - -int main(int argc, char ** argv) { -#if defined(WIN32) || defined(_WIN32) -SetErrorMode(SEM_FAILCRITICALERRORS); -#endif -lean_object* in; lean_object* res;"; - if usesLeanAPI then - emitLn "lean_initialize();" - else - emitLn "lean_initialize_runtime_module();" - let modName ← getModName - emitLn ("res = " ++ mkModuleInitializationFunctionName modName ++ "(lean_io_mk_world());") - emitLns ["lean_io_mark_end_initialization();", - "if (lean_io_result_is_ok(res)) {", - "lean_dec_ref(res);", - "lean_init_task_manager();"]; - if xs.size == 2 then - emitLns ["in = lean_box(0);", - "int i = argc;", - "while (i > 1) {", - " lean_object* n;", - " i--;", - " n = lean_alloc_ctor(1,2,0); lean_ctor_set(n, 0, lean_mk_string(argv[i])); lean_ctor_set(n, 1, in);", - " in = n;", - "}"] - emitLn ("res = " ++ leanMainFn ++ "(in, lean_io_mk_world());") - else - emitLn ("res = " ++ leanMainFn ++ "(lean_io_mk_world());") - emitLn "}" - emitLns ["if (lean_io_result_is_ok(res)) {", - " int ret = lean_unbox(lean_io_result_get_value(res));", - " lean_dec_ref(res);", - " return ret;", - "} else {", - " lean_io_result_show_error(res);", - " lean_dec_ref(res);", - " return 1;", - "}"] - emitLn "}" -| other => throw "function declaration expected" - -def hasMainFn : M Bool := do -let env ← getEnv -let decls := getDecls env -pure $ decls.any (fun d => d.name == `main) - -def emitMainFnIfNeeded : M Unit := do -if (← hasMainFn) then emitMainFn - -def emitFileHeader : M Unit := do -let env ← getEnv -let modName ← getModName -emitLn "// Lean compiler output" -emitLn ("// Module: " ++ toString modName) -emit "// Imports:" -env.imports.forM fun m => emit (" " ++ toString m) -emitLn "" -emitLn "#include " -emitLns [ - "#if defined(__clang__)", - "#pragma clang diagnostic ignored \"-Wunused-parameter\"", - "#pragma clang diagnostic ignored \"-Wunused-label\"", - "#elif defined(__GNUC__) && !defined(__CLANG__)", - "#pragma GCC diagnostic ignored \"-Wunused-parameter\"", - "#pragma GCC diagnostic ignored \"-Wunused-label\"", - "#pragma GCC diagnostic ignored \"-Wunused-but-set-variable\"", - "#endif", - "#ifdef __cplusplus", - "extern \"C\" {", - "#endif" -] - -def emitFileFooter : M Unit := -emitLns [ - "#ifdef __cplusplus", - "}", - "#endif" -] - -def throwUnknownVar {α : Type} (x : VarId) : M α := -throw s!"unknown variable '{x}'" - -def getJPParams (j : JoinPointId) : M (Array Param) := do -let ctx ← read; -match ctx.jpMap.find? j with -| some ps => pure ps -| none => throw "unknown join point" - -def declareVar (x : VarId) (t : IRType) : M Unit := do -emit (toCType t); emit " "; emit x; emit "; " - -def declareParams (ps : Array Param) : M Unit := -ps.forM fun p => declareVar p.x p.ty - -partial def declareVars : FnBody → Bool → M Bool -| e@(FnBody.vdecl x t _ b), d => do - let ctx ← read - if isTailCallTo ctx.mainFn e then - pure d - else - declareVar x t; declareVars b true -| FnBody.jdecl j xs _ b, d => do declareParams xs; declareVars b (d || xs.size > 0) -| e, d => if e.isTerminal then pure d else declareVars e.body d - -def emitTag (x : VarId) (xType : IRType) : M Unit := do -if xType.isObj then do - emit "lean_obj_tag("; emit x; emit ")" -else - emit x - -def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) := -if alts.size != 2 then none -else match alts[0] with - | Alt.ctor c b => some (c.cidx, b, alts[1].body) - | _ => none - -def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do -emit $ - if checkRef then (if n == 1 then "lean_inc" else "lean_inc_n") - else (if n == 1 then "lean_inc_ref" else "lean_inc_ref_n") -emit "("; emit x -if n != 1 then emit ", "; emit n -emitLn ");" - -def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do -emit (if checkRef then "lean_dec" else "lean_dec_ref"); -emit "("; emit x; -if n != 1 then emit ", "; emit n -emitLn ");" - -def emitDel (x : VarId) : M Unit := do -emit "lean_free_object("; emit x; emitLn ");" - -def emitSetTag (x : VarId) (i : Nat) : M Unit := do -emit "lean_ctor_set_tag("; emit x; emit ", "; emit i; emitLn ");" - -def emitSet (x : VarId) (i : Nat) (y : Arg) : M Unit := do -emit "lean_ctor_set("; emit x; emit ", "; emit i; emit ", "; emitArg y; emitLn ");" - -def emitOffset (n : Nat) (offset : Nat) : M Unit := do -if n > 0 then - emit "sizeof(void*)*"; emit n; - if offset > 0 then emit " + "; emit offset -else - emit offset - -def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit := do -emit "lean_ctor_set_usize("; emit x; emit ", "; emit n; emit ", "; emit y; emitLn ");" - -def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M Unit := do -match t with -| IRType.float => emit "lean_ctor_set_float" -| IRType.uint8 => emit "lean_ctor_set_uint8" -| IRType.uint16 => emit "lean_ctor_set_uint16" -| IRType.uint32 => emit "lean_ctor_set_uint32" -| IRType.uint64 => emit "lean_ctor_set_uint64" -| _ => throw "invalid instruction"; -emit "("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");" - -def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do -let ps ← getJPParams j -unless xs.size == ps.size do throw "invalid goto" -xs.size.forM fun i => do - let p := ps[i] - let x := xs[i] - emit p.x; emit " = "; emitArg x; emitLn ";" -emit "goto "; emit j; emitLn ";" - -def emitLhs (z : VarId) : M Unit := do -emit z; emit " = " - -def emitArgs (ys : Array Arg) : M Unit := -ys.size.forM fun i => do - if i > 0 then emit ", " - emitArg ys[i] - -def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do -if usize == 0 then emit ssize -else if ssize == 0 then emit "sizeof(size_t)*"; emit usize -else emit "sizeof(size_t)*"; emit usize; emit " + "; emit ssize - -def emitAllocCtor (c : CtorInfo) : M Unit := do -emit "lean_alloc_ctor("; emit c.cidx; emit ", "; emit c.size; emit ", "; -emitCtorScalarSize c.usize c.ssize; emitLn ");" - -def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit := -ys.size.forM fun i => do - emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");" - -def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do -emitLhs z; -if c.size == 0 && c.usize == 0 && c.ssize == 0 then do - emit "lean_box("; emit c.cidx; emitLn ");" -else do - emitAllocCtor c; emitCtorSetArgs z ys - -def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do -emit "if (lean_is_exclusive("; emit x; emitLn ")) {"; -n.forM fun i => do - emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");" -emit " "; emitLhs z; emit x; emitLn ";"; -emitLn "} else {"; -emit " lean_dec_ref("; emit x; emitLn ");"; -emit " "; emitLhs z; emitLn "lean_box(0);"; -emitLn "}" - -def emitReuse (z : VarId) (x : VarId) (c : CtorInfo) (updtHeader : Bool) (ys : Array Arg) : M Unit := do -emit "if (lean_is_scalar("; emit x; emitLn ")) {"; -emit " "; emitLhs z; emitAllocCtor c; -emitLn "} else {"; -emit " "; emitLhs z; emit x; emitLn ";"; -if updtHeader then emit " lean_ctor_set_tag("; emit z; emit ", "; emit c.cidx; emitLn ");" -emitLn "}"; -emitCtorSetArgs z ys - -def emitProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do -emitLhs z; emit "lean_ctor_get("; emit x; emit ", "; emit i; emitLn ");" - -def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do -emitLhs z; emit "lean_ctor_get_usize("; emit x; emit ", "; emit i; emitLn ");" - -def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit := do -emitLhs z; -match t with -| IRType.float => emit "lean_ctor_get_float" -| IRType.uint8 => emit "lean_ctor_get_uint8" -| IRType.uint16 => emit "lean_ctor_get_uint16" -| IRType.uint32 => emit "lean_ctor_get_uint32" -| IRType.uint64 => emit "lean_ctor_get_uint64" -| _ => throw "invalid instruction" -emit "("; emit x; emit ", "; emitOffset n offset; emitLn ");" - -def toStringArgs (ys : Array Arg) : List String := -ys.toList.map argToCString - -def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M Unit := do -emit f; emit "(" --- We must remove irrelevant arguments to extern calls. -ys.size.foldM - (fun i (first : Bool) => - if ps[i].ty.isIrrelevant then - pure first - else do - unless first do emit ", " - emitArg ys[i] - pure false) - true -emitLn ");" -pure () - -def emitExternCall (f : FunId) (ps : Array Param) (extData : ExternAttrData) (ys : Array Arg) : M Unit := -match getExternEntryFor extData `c with -| some (ExternEntry.standard _ extFn) => emitSimpleExternalCall extFn ps ys -| some (ExternEntry.inline _ pat) => do emit (expandExternPattern pat (toStringArgs ys)); emitLn ";" -| some (ExternEntry.foreign _ extFn) => emitSimpleExternalCall extFn ps ys -| _ => throw s!"failed to emit extern application '{f}'" - -def emitFullApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do -emitLhs z -let decl ← getDecl f -match decl with -| Decl.extern _ ps _ extData => emitExternCall f ps extData ys -| _ => - emitCName f - if ys.size > 0 then emit "("; emitArgs ys; emit ")" + if ps.isEmpty && addExternForConsts then emit "extern " + emit (toCType decl.resultType ++ " " ++ cppBaseName) + unless ps.isEmpty do + emit "(" + -- We omit irrelevant parameters for extern constants + let ps := if isExternC env decl.name then ps.filter (fun p => !p.ty.isIrrelevant) else ps + if ps.size > closureMaxArgs && isBoxedName decl.name then + emit "lean_object**" + else + ps.size.forM fun i => do + if i > 0 then emit ", " + emit (toCType ps[i].ty) + emit ")" emitLn ";" +def emitFnDecl (decl : Decl) (addExternForConsts : Bool) : M Unit := do + let cppBaseName ← toCName decl.name + emitFnDeclAux decl cppBaseName addExternForConsts + +def emitExternDeclAux (decl : Decl) (cNameStr : String) : M Unit := do + let cName := mkNameSimple cNameStr + let env ← getEnv + let extC := isExternC env decl.name + emitFnDeclAux decl cNameStr (!extC) + +def emitFnDecls : M Unit := do + let env ← getEnv + let decls := getDecls env + let modDecls : NameSet := decls.foldl (fun s d => s.insert d.name) {} + let usedDecls : NameSet := decls.foldl (fun s d => collectUsedDecls env d (s.insert d.name)) {} + let usedDecls := usedDecls.toList + usedDecls.forM fun n => do + let decl ← getDecl n; + match getExternNameFor env `c decl.name with + | some cName => emitExternDeclAux decl cName + | none => emitFnDecl decl (!modDecls.contains n) + +def emitMainFn : M Unit := do + let d ← getDecl `main + match d with + | Decl.fdecl f xs t b => do + unless xs.size == 2 || xs.size == 1 do throw "invalid main function, incorrect arity when generating code" + let env ← getEnv + let usesLeanAPI := usesModuleFrom env `Lean + if usesLeanAPI then + emitLn "void lean_initialize();" + else + emitLn "void lean_initialize_runtime_module();"; + emitLn " + #if defined(WIN32) || defined(_WIN32) + #include + #endif + + int main(int argc, char ** argv) { + #if defined(WIN32) || defined(_WIN32) + SetErrorMode(SEM_FAILCRITICALERRORS); + #endif + lean_object* in; lean_object* res;"; + if usesLeanAPI then + emitLn "lean_initialize();" + else + emitLn "lean_initialize_runtime_module();" + let modName ← getModName + emitLn ("res = " ++ mkModuleInitializationFunctionName modName ++ "(lean_io_mk_world());") + emitLns ["lean_io_mark_end_initialization();", + "if (lean_io_result_is_ok(res)) {", + "lean_dec_ref(res);", + "lean_init_task_manager();"]; + if xs.size == 2 then + emitLns ["in = lean_box(0);", + "int i = argc;", + "while (i > 1) {", + " lean_object* n;", + " i--;", + " n = lean_alloc_ctor(1,2,0); lean_ctor_set(n, 0, lean_mk_string(argv[i])); lean_ctor_set(n, 1, in);", + " in = n;", + "}"] + emitLn ("res = " ++ leanMainFn ++ "(in, lean_io_mk_world());") + else + emitLn ("res = " ++ leanMainFn ++ "(lean_io_mk_world());") + emitLn "}" + emitLns ["if (lean_io_result_is_ok(res)) {", + " int ret = lean_unbox(lean_io_result_get_value(res));", + " lean_dec_ref(res);", + " return ret;", + "} else {", + " lean_io_result_show_error(res);", + " lean_dec_ref(res);", + " return 1;", + "}"] + emitLn "}" + | other => throw "function declaration expected" + +def hasMainFn : M Bool := do + let env ← getEnv + let decls := getDecls env + pure $ decls.any (fun d => d.name == `main) + +def emitMainFnIfNeeded : M Unit := do + if (← hasMainFn) then emitMainFn + +def emitFileHeader : M Unit := do + let env ← getEnv + let modName ← getModName + emitLn "// Lean compiler output" + emitLn ("// Module: " ++ toString modName) + emit "// Imports:" + env.imports.forM fun m => emit (" " ++ toString m) + emitLn "" + emitLn "#include " + emitLns [ + "#if defined(__clang__)", + "#pragma clang diagnostic ignored \"-Wunused-parameter\"", + "#pragma clang diagnostic ignored \"-Wunused-label\"", + "#elif defined(__GNUC__) && !defined(__CLANG__)", + "#pragma GCC diagnostic ignored \"-Wunused-parameter\"", + "#pragma GCC diagnostic ignored \"-Wunused-label\"", + "#pragma GCC diagnostic ignored \"-Wunused-but-set-variable\"", + "#endif", + "#ifdef __cplusplus", + "extern \"C\" {", + "#endif" + ] + +def emitFileFooter : M Unit := + emitLns [ + "#ifdef __cplusplus", + "}", + "#endif" + ] + +def throwUnknownVar {α : Type} (x : VarId) : M α := + throw s!"unknown variable '{x}'" + +def getJPParams (j : JoinPointId) : M (Array Param) := do + let ctx ← read; + match ctx.jpMap.find? j with + | some ps => pure ps + | none => throw "unknown join point" + +def declareVar (x : VarId) (t : IRType) : M Unit := do + emit (toCType t); emit " "; emit x; emit "; " + +def declareParams (ps : Array Param) : M Unit := + ps.forM fun p => declareVar p.x p.ty + +partial def declareVars : FnBody → Bool → M Bool + | e@(FnBody.vdecl x t _ b), d => do + let ctx ← read + if isTailCallTo ctx.mainFn e then + pure d + else + declareVar x t; declareVars b true + | FnBody.jdecl j xs _ b, d => do declareParams xs; declareVars b (d || xs.size > 0) + | e, d => if e.isTerminal then pure d else declareVars e.body d + +def emitTag (x : VarId) (xType : IRType) : M Unit := do + if xType.isObj then do + emit "lean_obj_tag("; emit x; emit ")" + else + emit x + +def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) := + if alts.size != 2 then none + else match alts[0] with + | Alt.ctor c b => some (c.cidx, b, alts[1].body) + | _ => none + +def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do + emit $ + if checkRef then (if n == 1 then "lean_inc" else "lean_inc_n") + else (if n == 1 then "lean_inc_ref" else "lean_inc_ref_n") + emit "("; emit x + if n != 1 then emit ", "; emit n + emitLn ");" + +def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit := do + emit (if checkRef then "lean_dec" else "lean_dec_ref"); + emit "("; emit x; + if n != 1 then emit ", "; emit n + emitLn ");" + +def emitDel (x : VarId) : M Unit := do + emit "lean_free_object("; emit x; emitLn ");" + +def emitSetTag (x : VarId) (i : Nat) : M Unit := do + emit "lean_ctor_set_tag("; emit x; emit ", "; emit i; emitLn ");" + +def emitSet (x : VarId) (i : Nat) (y : Arg) : M Unit := do + emit "lean_ctor_set("; emit x; emit ", "; emit i; emit ", "; emitArg y; emitLn ");" + +def emitOffset (n : Nat) (offset : Nat) : M Unit := do + if n > 0 then + emit "sizeof(void*)*"; emit n; + if offset > 0 then emit " + "; emit offset + else + emit offset + +def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit := do + emit "lean_ctor_set_usize("; emit x; emit ", "; emit n; emit ", "; emit y; emitLn ");" + +def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M Unit := do + match t with + | IRType.float => emit "lean_ctor_set_float" + | IRType.uint8 => emit "lean_ctor_set_uint8" + | IRType.uint16 => emit "lean_ctor_set_uint16" + | IRType.uint32 => emit "lean_ctor_set_uint32" + | IRType.uint64 => emit "lean_ctor_set_uint64" + | _ => throw "invalid instruction"; + emit "("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");" + +def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do + let ps ← getJPParams j + unless xs.size == ps.size do throw "invalid goto" + xs.size.forM fun i => do + let p := ps[i] + let x := xs[i] + emit p.x; emit " = "; emitArg x; emitLn ";" + emit "goto "; emit j; emitLn ";" + +def emitLhs (z : VarId) : M Unit := do + emit z; emit " = " + +def emitArgs (ys : Array Arg) : M Unit := + ys.size.forM fun i => do + if i > 0 then emit ", " + emitArg ys[i] + +def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit := do + if usize == 0 then emit ssize + else if ssize == 0 then emit "sizeof(size_t)*"; emit usize + else emit "sizeof(size_t)*"; emit usize; emit " + "; emit ssize + +def emitAllocCtor (c : CtorInfo) : M Unit := do + emit "lean_alloc_ctor("; emit c.cidx; emit ", "; emit c.size; emit ", "; + emitCtorScalarSize c.usize c.ssize; emitLn ");" + +def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit := + ys.size.forM fun i => do + emit "lean_ctor_set("; emit z; emit ", "; emit i; emit ", "; emitArg ys[i]; emitLn ");" + +def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit := do + emitLhs z; + if c.size == 0 && c.usize == 0 && c.ssize == 0 then do + emit "lean_box("; emit c.cidx; emitLn ");" + else do + emitAllocCtor c; emitCtorSetArgs z ys + +def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit := do + emit "if (lean_is_exclusive("; emit x; emitLn ")) {"; + n.forM fun i => do + emit " lean_ctor_release("; emit x; emit ", "; emit i; emitLn ");" + emit " "; emitLhs z; emit x; emitLn ";"; + emitLn "} else {"; + emit " lean_dec_ref("; emit x; emitLn ");"; + emit " "; emitLhs z; emitLn "lean_box(0);"; + emitLn "}" + +def emitReuse (z : VarId) (x : VarId) (c : CtorInfo) (updtHeader : Bool) (ys : Array Arg) : M Unit := do + emit "if (lean_is_scalar("; emit x; emitLn ")) {"; + emit " "; emitLhs z; emitAllocCtor c; + emitLn "} else {"; + emit " "; emitLhs z; emit x; emitLn ";"; + if updtHeader then emit " lean_ctor_set_tag("; emit z; emit ", "; emit c.cidx; emitLn ");" + emitLn "}"; + emitCtorSetArgs z ys + +def emitProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do + emitLhs z; emit "lean_ctor_get("; emit x; emit ", "; emit i; emitLn ");" + +def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do + emitLhs z; emit "lean_ctor_get_usize("; emit x; emit ", "; emit i; emitLn ");" + +def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit := do + emitLhs z; + match t with + | IRType.float => emit "lean_ctor_get_float" + | IRType.uint8 => emit "lean_ctor_get_uint8" + | IRType.uint16 => emit "lean_ctor_get_uint16" + | IRType.uint32 => emit "lean_ctor_get_uint32" + | IRType.uint64 => emit "lean_ctor_get_uint64" + | _ => throw "invalid instruction" + emit "("; emit x; emit ", "; emitOffset n offset; emitLn ");" + +def toStringArgs (ys : Array Arg) : List String := + ys.toList.map argToCString + +def emitSimpleExternalCall (f : String) (ps : Array Param) (ys : Array Arg) : M Unit := do + emit f; emit "(" + -- We must remove irrelevant arguments to extern calls. + ys.size.foldM + (fun i (first : Bool) => + if ps[i].ty.isIrrelevant then + pure first + else do + unless first do emit ", " + emitArg ys[i] + pure false) + true + emitLn ");" + pure () + +def emitExternCall (f : FunId) (ps : Array Param) (extData : ExternAttrData) (ys : Array Arg) : M Unit := + match getExternEntryFor extData `c with + | some (ExternEntry.standard _ extFn) => emitSimpleExternalCall extFn ps ys + | some (ExternEntry.inline _ pat) => do emit (expandExternPattern pat (toStringArgs ys)); emitLn ";" + | some (ExternEntry.foreign _ extFn) => emitSimpleExternalCall extFn ps ys + | _ => throw s!"failed to emit extern application '{f}'" + +def emitFullApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do + emitLhs z + let decl ← getDecl f + match decl with + | Decl.extern _ ps _ extData => emitExternCall f ps extData ys + | _ => + emitCName f + if ys.size > 0 then emit "("; emitArgs ys; emit ")" + emitLn ";" + def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit := do -let decl ← getDecl f -let arity := decl.params.size; -emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");"; -ys.size.forM fun i => do - let y := ys[i] - emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");" + let decl ← getDecl f + let arity := decl.params.size; + emitLhs z; emit "lean_alloc_closure((void*)("; emitCName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");"; + ys.size.forM fun i => do + let y := ys[i] + emit "lean_closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");" def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit := -if ys.size > closureMaxArgs then do - emit "{ lean_object* _aargs[] = {"; emitArgs ys; emitLn "};"; - emitLhs z; emit "lean_apply_m("; emit f; emit ", "; emit ys.size; emitLn ", _aargs); }" -else do - emitLhs z; emit "lean_apply_"; emit ys.size; emit "("; emit f; emit ", "; emitArgs ys; emitLn ");" + if ys.size > closureMaxArgs then do + emit "{ lean_object* _aargs[] = {"; emitArgs ys; emitLn "};"; + emitLhs z; emit "lean_apply_m("; emit f; emit ", "; emit ys.size; emitLn ", _aargs); }" + else do + emitLhs z; emit "lean_apply_"; emit ys.size; emit "("; emit f; emit ", "; emitArgs ys; emitLn ");" def emitBoxFn (xType : IRType) : M Unit := -match xType with -| IRType.usize => emit "lean_box_usize" -| IRType.uint32 => emit "lean_box_uint32" -| IRType.uint64 => emit "lean_box_uint64" -| IRType.float => emit "lean_box_float" -| other => emit "lean_box" + match xType with + | IRType.usize => emit "lean_box_usize" + | IRType.uint32 => emit "lean_box_uint32" + | IRType.uint64 => emit "lean_box_uint64" + | IRType.float => emit "lean_box_float" + | other => emit "lean_box" def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit := do -emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");" + emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");" def emitUnbox (z : VarId) (t : IRType) (x : VarId) : M Unit := do -emitLhs z; -match t with -| IRType.usize => emit "lean_unbox_usize" -| IRType.uint32 => emit "lean_unbox_uint32" -| IRType.uint64 => emit "lean_unbox_uint64" -| IRType.float => emit "lean_unbox_float" -| other => emit "lean_unbox"; -emit "("; emit x; emitLn ");" + emitLhs z; + match t with + | IRType.usize => emit "lean_unbox_usize" + | IRType.uint32 => emit "lean_unbox_uint32" + | IRType.uint64 => emit "lean_unbox_uint64" + | IRType.float => emit "lean_unbox_float" + | other => emit "lean_unbox"; + emit "("; emit x; emitLn ");" def emitIsShared (z : VarId) (x : VarId) : M Unit := do -emitLhs z; emit "!lean_is_exclusive("; emit x; emitLn ");" + emitLhs z; emit "!lean_is_exclusive("; emit x; emitLn ");" def emitIsTaggedPtr (z : VarId) (x : VarId) : M Unit := do -emitLhs z; emit "!lean_is_scalar("; emit x; emitLn ");" + emitLhs z; emit "!lean_is_scalar("; emit x; emitLn ");" def toHexDigit (c : Nat) : String := -String.singleton c.digitChar + String.singleton c.digitChar def quoteString (s : String) : String := -let q := "\""; -let q := s.foldl - (fun q c => q ++ - if c == '\n' then "\\n" - else if c == '\n' then "\\t" - else if c == '\\' then "\\\\" - else if c == '\"' then "\\\"" - else if c.toNat <= 31 then - "\\x" ++ toHexDigit (c.toNat / 16) ++ toHexDigit (c.toNat % 16) - -- TODO(Leo): we should use `\unnnn` for escaping unicode characters. - else String.singleton c) - q; -q ++ "\"" + let q := "\""; + let q := s.foldl + (fun q c => q ++ + if c == '\n' then "\\n" + else if c == '\n' then "\\t" + else if c == '\\' then "\\\\" + else if c == '\"' then "\\\"" + else if c.toNat <= 31 then + "\\x" ++ toHexDigit (c.toNat / 16) ++ toHexDigit (c.toNat % 16) + -- TODO(Leo): we should use `\unnnn` for escaping unicode characters. + else String.singleton c) + q; + q ++ "\"" def emitNumLit (t : IRType) (v : Nat) : M Unit := do -if t.isObj then - if v < uint32Sz then - emit "lean_unsigned_to_nat("; emit v; emit "u)" + if t.isObj then + if v < uint32Sz then + emit "lean_unsigned_to_nat("; emit v; emit "u)" + else + emit "lean_cstr_to_nat(\""; emit v; emit "\")" else - emit "lean_cstr_to_nat(\""; emit v; emit "\")" -else - emit v + emit v def emitLit (z : VarId) (t : IRType) (v : LitVal) : M Unit := do -emitLhs z; -match v with -| LitVal.num v => emitNumLit t v; emitLn ";" -| LitVal.str v => emit "lean_mk_string("; emit (quoteString v); emitLn ");" + emitLhs z; + match v with + | LitVal.num v => emitNumLit t v; emitLn ";" + | LitVal.str v => emit "lean_mk_string("; emit (quoteString v); emitLn ");" def emitVDecl (z : VarId) (t : IRType) (v : Expr) : M Unit := -match v with -| Expr.ctor c ys => emitCtor z c ys -| Expr.reset n x => emitReset z n x -| Expr.reuse x c u ys => emitReuse z x c u ys -| Expr.proj i x => emitProj z i x -| Expr.uproj i x => emitUProj z i x -| Expr.sproj n o x => emitSProj z t n o x -| Expr.fap c ys => emitFullApp z c ys -| Expr.pap c ys => emitPartialApp z c ys -| Expr.ap x ys => emitApp z x ys -| Expr.box t x => emitBox z x t -| Expr.unbox x => emitUnbox z t x -| Expr.isShared x => emitIsShared z x -| Expr.isTaggedPtr x => emitIsTaggedPtr z x -| Expr.lit v => emitLit z t v + match v with + | Expr.ctor c ys => emitCtor z c ys + | Expr.reset n x => emitReset z n x + | Expr.reuse x c u ys => emitReuse z x c u ys + | Expr.proj i x => emitProj z i x + | Expr.uproj i x => emitUProj z i x + | Expr.sproj n o x => emitSProj z t n o x + | Expr.fap c ys => emitFullApp z c ys + | Expr.pap c ys => emitPartialApp z c ys + | Expr.ap x ys => emitApp z x ys + | Expr.box t x => emitBox z x t + | Expr.unbox x => emitUnbox z t x + | Expr.isShared x => emitIsShared z x + | Expr.isTaggedPtr x => emitIsTaggedPtr z x + | Expr.lit v => emitLit z t v def isTailCall (x : VarId) (v : Expr) (b : FnBody) : M Bool := do -let ctx ← read; -match v, b with -| Expr.fap f _, FnBody.ret (Arg.var y) => pure $ f == ctx.mainFn && x == y -| _, _ => pure false + let ctx ← read; + match v, b with + | Expr.fap f _, FnBody.ret (Arg.var y) => pure $ f == ctx.mainFn && x == y + | _, _ => pure false def paramEqArg (p : Param) (x : Arg) : Bool := -match x with -| Arg.var x => p.x == x -| _ => false + match x with + | Arg.var x => p.x == x + | _ => false /- Given `[p_0, ..., p_{n-1}]`, `[y_0, ..., y_{n-1}]`, representing the assignments @@ -531,201 +531,201 @@ That is, we have ``` -/ def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool := -let n := ps.size; -n.any $ fun i => - let p := ps[i] - (i+1, n).anyI fun j => paramEqArg p ys[j] + let n := ps.size; + n.any $ fun i => + let p := ps[i] + (i+1, n).anyI fun j => paramEqArg p ys[j] def emitTailCall (v : Expr) : M Unit := -match v with -| Expr.fap _ ys => do - let ctx ← read - let ps := ctx.mainParams - unless ps.size == ys.size do throw "invalid tail call" - if overwriteParam ps ys then - emitLn "{" - ps.size.forM fun i => do - let p := ps[i] - let y := ys[i] - unless paramEqArg p y do - emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";" - ps.size.forM fun i => do - let p := ps[i] - let y := ys[i] - unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";" - emitLn "}" - else - ys.size.forM fun i => do - let p := ps[i] - let y := ys[i] - unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";" - emitLn "goto _start;" -| _ => throw "bug at emitTailCall" + match v with + | Expr.fap _ ys => do + let ctx ← read + let ps := ctx.mainParams + unless ps.size == ys.size do throw "invalid tail call" + if overwriteParam ps ys then + emitLn "{" + ps.size.forM fun i => do + let p := ps[i] + let y := ys[i] + unless paramEqArg p y do + emit (toCType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";" + ps.size.forM fun i => do + let p := ps[i] + let y := ys[i] + unless paramEqArg p y do emit p.x; emit " = _tmp_"; emit i; emitLn ";" + emitLn "}" + else + ys.size.forM fun i => do + let p := ps[i] + let y := ys[i] + unless paramEqArg p y do emit p.x; emit " = "; emitArg y; emitLn ";" + emitLn "goto _start;" + | _ => throw "bug at emitTailCall" mutual partial def emitIf (x : VarId) (xType : IRType) (tag : Nat) (t : FnBody) (e : FnBody) : M Unit := do -emit "if ("; emitTag x xType; emit " == "; emit tag; emitLn ")"; -emitFnBody t; -emitLn "else"; -emitFnBody e + emit "if ("; emitTag x xType; emit " == "; emit tag; emitLn ")"; + emitFnBody t; + emitLn "else"; + emitFnBody e partial def emitCase (x : VarId) (xType : IRType) (alts : Array Alt) : M Unit := -match isIf alts with -| some (tag, t, e) => emitIf x xType tag t e -| _ => do - emit "switch ("; emitTag x xType; emitLn ") {"; - let alts := ensureHasDefault alts; - alts.forM fun alt => do - match alt with - | Alt.ctor c b => emit "case "; emit c.cidx; emitLn ":"; emitFnBody b - | Alt.default b => emitLn "default: "; emitFnBody b - emitLn "}" + match isIf alts with + | some (tag, t, e) => emitIf x xType tag t e + | _ => do + emit "switch ("; emitTag x xType; emitLn ") {"; + let alts := ensureHasDefault alts; + alts.forM fun alt => do + match alt with + | Alt.ctor c b => emit "case "; emit c.cidx; emitLn ":"; emitFnBody b + | Alt.default b => emitLn "default: "; emitFnBody b + emitLn "}" partial def emitBlock (b : FnBody) : M Unit := do -match b with -| FnBody.jdecl j xs v b => emitBlock b -| d@(FnBody.vdecl x t v b) => - let ctx ← read - if isTailCallTo ctx.mainFn d then - emitTailCall v - else - emitVDecl x t v + match b with + | FnBody.jdecl j xs v b => emitBlock b + | d@(FnBody.vdecl x t v b) => + let ctx ← read + if isTailCallTo ctx.mainFn d then + emitTailCall v + else + emitVDecl x t v + emitBlock b + | FnBody.inc x n c p b => + unless p do emitInc x n c emitBlock b -| FnBody.inc x n c p b => - unless p do emitInc x n c - emitBlock b -| FnBody.dec x n c p b => - unless p do emitDec x n c - emitBlock b -| FnBody.del x b => emitDel x; emitBlock b -| FnBody.setTag x i b => emitSetTag x i; emitBlock b -| FnBody.set x i y b => emitSet x i y; emitBlock b -| FnBody.uset x i y b => emitUSet x i y; emitBlock b -| FnBody.sset x i o y t b => emitSSet x i o y t; emitBlock b -| FnBody.mdata _ b => emitBlock b -| FnBody.ret x => emit "return "; emitArg x; emitLn ";" -| FnBody.case _ x xType alts => emitCase x xType alts -| FnBody.jmp j xs => emitJmp j xs -| FnBody.unreachable => emitLn "lean_panic_unreachable();" + | FnBody.dec x n c p b => + unless p do emitDec x n c + emitBlock b + | FnBody.del x b => emitDel x; emitBlock b + | FnBody.setTag x i b => emitSetTag x i; emitBlock b + | FnBody.set x i y b => emitSet x i y; emitBlock b + | FnBody.uset x i y b => emitUSet x i y; emitBlock b + | FnBody.sset x i o y t b => emitSSet x i o y t; emitBlock b + | FnBody.mdata _ b => emitBlock b + | FnBody.ret x => emit "return "; emitArg x; emitLn ";" + | FnBody.case _ x xType alts => emitCase x xType alts + | FnBody.jmp j xs => emitJmp j xs + | FnBody.unreachable => emitLn "lean_panic_unreachable();" partial def emitJPs : FnBody → M Unit -| FnBody.jdecl j xs v b => do emit j; emitLn ":"; emitFnBody v; emitJPs b -| e => do unless e.isTerminal do emitJPs e.body + | FnBody.jdecl j xs v b => do emit j; emitLn ":"; emitFnBody v; emitJPs b + | e => do unless e.isTerminal do emitJPs e.body partial def emitFnBody (b : FnBody) : M Unit := do -emitLn "{" -let declared ← declareVars b false -if declared then emitLn "" -emitBlock b -emitJPs b -emitLn "}" + emitLn "{" + let declared ← declareVars b false + if declared then emitLn "" + emitBlock b + emitJPs b + emitLn "}" end def emitDeclAux (d : Decl) : M Unit := do -let env ← getEnv -let (vMap, jpMap) := mkVarJPMaps d -withReader (fun ctx => { ctx with jpMap := jpMap }) do -unless hasInitAttr env d.name do - match d with - | Decl.fdecl f xs t b => - let baseName ← toCName f; - if xs.size == 0 then - emit "static " - emit (toCType t); emit " "; - if xs.size > 0 then - emit baseName; - emit "("; - if xs.size > closureMaxArgs && isBoxedName d.name then - emit "lean_object** _args" + let env ← getEnv + let (vMap, jpMap) := mkVarJPMaps d + withReader (fun ctx => { ctx with jpMap := jpMap }) do + unless hasInitAttr env d.name do + match d with + | Decl.fdecl f xs t b => + let baseName ← toCName f; + if xs.size == 0 then + emit "static " + emit (toCType t); emit " "; + if xs.size > 0 then + emit baseName; + emit "("; + if xs.size > closureMaxArgs && isBoxedName d.name then + emit "lean_object** _args" + else + xs.size.forM fun i => do + if i > 0 then emit ", " + let x := xs[i] + emit (toCType x.ty); emit " "; emit x.x + emit ")" else + emit ("_init_" ++ baseName ++ "()") + emitLn " {"; + if xs.size > closureMaxArgs && isBoxedName d.name then xs.size.forM fun i => do - if i > 0 then emit ", " let x := xs[i] - emit (toCType x.ty); emit " "; emit x.x - emit ")" - else - emit ("_init_" ++ baseName ++ "()") - emitLn " {"; - if xs.size > closureMaxArgs && isBoxedName d.name then - xs.size.forM fun i => do - let x := xs[i] - emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];" - emitLn "_start:"; - withReader (fun ctx => { ctx with mainFn := f, mainParams := xs }) (emitFnBody b); - emitLn "}" - | _ => pure () + emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];" + emitLn "_start:"; + withReader (fun ctx => { ctx with mainFn := f, mainParams := xs }) (emitFnBody b); + emitLn "}" + | _ => pure () def emitDecl (d : Decl) : M Unit := do -let d := d.normalizeIds; -- ensure we don't have gaps in the variable indices -try - emitDeclAux d -catch err => - throw s!"{err}\ncompiling:\n{d}" + let d := d.normalizeIds; -- ensure we don't have gaps in the variable indices + try + emitDeclAux d + catch err => + throw s!"{err}\ncompiling:\n{d}" def emitFns : M Unit := do -let env ← getEnv; -let decls := getDecls env; -decls.reverse.forM emitDecl + let env ← getEnv; + let decls := getDecls env; + decls.reverse.forM emitDecl def emitMarkPersistent (d : Decl) (n : Name) : M Unit := do -if d.resultType.isObj then - emit "lean_mark_persistent(" - emitCName n - emitLn ");" + if d.resultType.isObj then + emit "lean_mark_persistent(" + emitCName n + emitLn ");" def emitDeclInit (d : Decl) : M Unit := do -let env ← getEnv -let n := d.name -if isIOUnitInitFn env n then - emit "res = "; emitCName n; emitLn "(lean_io_mk_world());" - emitLn "if (lean_io_result_is_error(res)) return res;" - emitLn "lean_dec_ref(res);" -else if d.params.size == 0 then - match getInitFnNameFor? env d.name with - | some initFn => - emit "res = "; emitCName initFn; emitLn "(lean_io_mk_world());" + let env ← getEnv + let n := d.name + if isIOUnitInitFn env n then + emit "res = "; emitCName n; emitLn "(lean_io_mk_world());" emitLn "if (lean_io_result_is_error(res)) return res;" - emitCName n; emitLn " = lean_io_result_get_value(res);" - emitMarkPersistent d n emitLn "lean_dec_ref(res);" - | _ => - emitCName n; emit " = "; emitCInitName n; emitLn "();"; emitMarkPersistent d n + else if d.params.size == 0 then + match getInitFnNameFor? env d.name with + | some initFn => + emit "res = "; emitCName initFn; emitLn "(lean_io_mk_world());" + emitLn "if (lean_io_result_is_error(res)) return res;" + emitCName n; emitLn " = lean_io_result_get_value(res);" + emitMarkPersistent d n + emitLn "lean_dec_ref(res);" + | _ => + emitCName n; emit " = "; emitCInitName n; emitLn "();"; emitMarkPersistent d n def emitInitFn : M Unit := do -let env ← getEnv -let modName ← getModName -env.imports.forM fun imp => emitLn ("lean_object* " ++ mkModuleInitializationFunctionName imp.module ++ "(lean_object*);") -emitLns [ - "static bool _G_initialized = false;", - "lean_object* " ++ mkModuleInitializationFunctionName modName ++ "(lean_object* w) {", - "lean_object * res;", - "if (_G_initialized) return lean_io_result_mk_ok(lean_box(0));", - "_G_initialized = true;" -] -env.imports.forM fun imp => emitLns [ - "res = " ++ mkModuleInitializationFunctionName imp.module ++ "(lean_io_mk_world());", - "if (lean_io_result_is_error(res)) return res;", - "lean_dec_ref(res);"] -let decls := getDecls env -decls.reverse.forM emitDeclInit -emitLns ["return lean_io_result_mk_ok(lean_box(0));", "}"] + let env ← getEnv + let modName ← getModName + env.imports.forM fun imp => emitLn ("lean_object* " ++ mkModuleInitializationFunctionName imp.module ++ "(lean_object*);") + emitLns [ + "static bool _G_initialized = false;", + "lean_object* " ++ mkModuleInitializationFunctionName modName ++ "(lean_object* w) {", + "lean_object * res;", + "if (_G_initialized) return lean_io_result_mk_ok(lean_box(0));", + "_G_initialized = true;" + ] + env.imports.forM fun imp => emitLns [ + "res = " ++ mkModuleInitializationFunctionName imp.module ++ "(lean_io_mk_world());", + "if (lean_io_result_is_error(res)) return res;", + "lean_dec_ref(res);"] + let decls := getDecls env + decls.reverse.forM emitDeclInit + emitLns ["return lean_io_result_mk_ok(lean_box(0));", "}"] def main : M Unit := do -emitFileHeader -emitFnDecls -emitFns -emitInitFn -emitMainFnIfNeeded -emitFileFooter + emitFileHeader + emitFnDecls + emitFns + emitInitFn + emitMainFnIfNeeded + emitFileFooter end EmitC @[export lean_ir_emit_c] def emitC (env : Environment) (modName : Name) : Except String String := -match (EmitC.main { env := env, modName := modName }).run "" with -| EStateM.Result.ok _ s => Except.ok s -| EStateM.Result.error err _ => Except.error err + match (EmitC.main { env := env, modName := modName }).run "" with + | EStateM.Result.ok _ s => Except.ok s + | EStateM.Result.error err _ => Except.error err end Lean.IR diff --git a/src/Lean/Compiler/IR/EmitUtil.lean b/src/Lean/Compiler/IR/EmitUtil.lean index 62f3944d90..35abafa9ff 100644 --- a/src/Lean/Compiler/IR/EmitUtil.lean +++ b/src/Lean/Compiler/IR/EmitUtil.lean @@ -11,44 +11,44 @@ import Lean.Compiler.IR.CompilerM namespace Lean.IR /- Return true iff `b` is of the form `let x := g ys; ret x` -/ def isTailCallTo (g : Name) (b : FnBody) : Bool := -match b with -| FnBody.vdecl x _ (Expr.fap f _) (FnBody.ret (Arg.var y)) => x == y && f == g -| _ => false + match b with + | FnBody.vdecl x _ (Expr.fap f _) (FnBody.ret (Arg.var y)) => x == y && f == g + | _ => false def usesModuleFrom (env : Environment) (modulePrefix : Name) : Bool := -env.allImportedModuleNames.toList.any $ fun modName => modulePrefix.isPrefixOf modName + env.allImportedModuleNames.toList.any $ fun modName => modulePrefix.isPrefixOf modName namespace CollectUsedDecls abbrev M := ReaderT Environment (StateM NameSet) @[inline] def collect (f : FunId) : M Unit := -modify $ fun s => s.insert f + modify fun s => s.insert f partial def collectFnBody : FnBody → M Unit -| FnBody.vdecl _ _ v b => - match v with - | Expr.fap f _ => collect f *> collectFnBody b - | Expr.pap f _ => collect f *> collectFnBody b - | other => collectFnBody b -| FnBody.jdecl _ _ v b => collectFnBody v *> collectFnBody b -| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body -| e => do unless e.isTerminal do collectFnBody e.body + | FnBody.vdecl _ _ v b => + match v with + | Expr.fap f _ => collect f *> collectFnBody b + | Expr.pap f _ => collect f *> collectFnBody b + | other => collectFnBody b + | FnBody.jdecl _ _ v b => collectFnBody v *> collectFnBody b + | FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body + | e => do unless e.isTerminal do collectFnBody e.body def collectInitDecl (fn : Name) : M Unit := do -let env ← read -match getInitFnNameFor? env fn with -| some initFn => collect initFn -| _ => pure () + let env ← read + match getInitFnNameFor? env fn with + | some initFn => collect initFn + | _ => pure () def collectDecl : Decl → M NameSet -| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get -| Decl.extern fn _ _ _ => collectInitDecl fn *> get + | Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get + | Decl.extern fn _ _ _ => collectInitDecl fn *> get end CollectUsedDecls def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet := -(CollectUsedDecls.collectDecl decl env).run' used + (CollectUsedDecls.collectDecl decl env).run' used abbrev VarTypeMap := Std.HashMap VarId IRType abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param) @@ -56,22 +56,22 @@ abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param) namespace CollectMaps abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap) @[inline] def collectVar (x : VarId) (t : IRType) : Collector -| (vs, js) => (vs.insert x t, js) + | (vs, js) => (vs.insert x t, js) def collectParams (ps : Array Param) : Collector := -fun s => ps.foldl (fun s p => collectVar p.x p.ty s) s + fun s => ps.foldl (fun s p => collectVar p.x p.ty s) s @[inline] def collectJP (j : JoinPointId) (xs : Array Param) : Collector -| (vs, js) => (vs, js.insert j xs) + | (vs, js) => (vs, js.insert j xs) /- `collectFnBody` assumes the variables in -/ partial def collectFnBody : FnBody → Collector -| FnBody.vdecl x t _ b => collectVar x t ∘ collectFnBody b -| FnBody.jdecl j xs v b => collectJP j xs ∘ collectParams xs ∘ collectFnBody v ∘ collectFnBody b -| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s -| e => if e.isTerminal then id else collectFnBody e.body + | FnBody.vdecl x t _ b => collectVar x t ∘ collectFnBody b + | FnBody.jdecl j xs v b => collectJP j xs ∘ collectParams xs ∘ collectFnBody v ∘ collectFnBody b + | FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s + | e => if e.isTerminal then id else collectFnBody e.body def collectDecl : Decl → Collector -| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b -| _ => id + | Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b + | _ => id end CollectMaps @@ -79,7 +79,7 @@ end CollectMaps and `j` is a mapping from join point to parameters. This function assumes `d` has normalized indexes (see `normids.lean`). -/ def mkVarJPMaps (d : Decl) : VarTypeMap × JPParamsMap := -CollectMaps.collectDecl d ({}, {}) + CollectMaps.collectDecl d ({}, {}) end IR end Lean diff --git a/src/Lean/Compiler/IR/ExpandResetReuse.lean b/src/Lean/Compiler/IR/ExpandResetReuse.lean index 2978bc1ca4..1a08a8a26a 100644 --- a/src/Lean/Compiler/IR/ExpandResetReuse.lean +++ b/src/Lean/Compiler/IR/ExpandResetReuse.lean @@ -12,46 +12,45 @@ namespace Lean.IR.ExpandResetReuse abbrev ProjMap := Std.HashMap VarId Expr namespace CollectProjMap abbrev Collector := ProjMap → ProjMap -@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := -fun m => match v with - | Expr.proj _ _ => m.insert x v - | Expr.sproj _ _ _ => m.insert x v - | Expr.uproj _ _ => m.insert x v - | _ => m +@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m => + match v with + | Expr.proj .. => m.insert x v + | Expr.sproj .. => m.insert x v + | Expr.uproj .. => m.insert x v + | _ => m partial def collectFnBody : FnBody → Collector -| FnBody.vdecl x _ v b => collectVDecl x v ∘ collectFnBody b -| FnBody.jdecl _ _ v b => collectFnBody v ∘ collectFnBody b -| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s -| e => if e.isTerminal then id else collectFnBody e.body + | FnBody.vdecl x _ v b => collectVDecl x v ∘ collectFnBody b + | FnBody.jdecl _ _ v b => collectFnBody v ∘ collectFnBody b + | FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s + | e => if e.isTerminal then id else collectFnBody e.body end CollectProjMap /- Create a mapping from variables to projections. This function assumes variable ids have been normalized -/ def mkProjMap (d : Decl) : ProjMap := -match d with -| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {} -| _ => {} + match d with + | Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {} + | _ => {} structure Context := -(projMap : ProjMap) + (projMap : ProjMap) /- Return true iff `x` is consumed in all branches of the current block. Here consumption means the block contains a `dec x` or `reuse x ...`. -/ partial def consumed (x : VarId) : FnBody → Bool -| FnBody.vdecl _ _ v b => - match v with - | Expr.reuse y _ _ _ => x == y || consumed x b - | _ => consumed x b -| FnBody.dec y _ _ _ b => x == y || consumed x b -| FnBody.case _ _ _ alts => alts.all fun alt => consumed x alt.body -| e => !e.isTerminal && consumed x e.body + | FnBody.vdecl _ _ v b => + match v with + | Expr.reuse y _ _ _ => x == y || consumed x b + | _ => consumed x b + | FnBody.dec y _ _ _ b => x == y || consumed x b + | FnBody.case _ _ _ alts => alts.all fun alt => consumed x alt.body + | e => !e.isTerminal && consumed x e.body abbrev Mask := Array (Option VarId) /- Auxiliary function for eraseProjIncFor -/ -partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask -| bs, mask, keep => +partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask := let done (_ : Unit) := (bs ++ keep.reverse, mask) let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b) if bs.size < 2 then done () @@ -85,30 +84,30 @@ partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnB /- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`. Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/ def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask := -eraseProjIncForAux y bs (mkArray n none) #[] + eraseProjIncForAux y bs (mkArray n none) #[] /- Replace `reuse x ctor ...` with `ctor ...`, and remoce `dec x` -/ partial def reuseToCtor (x : VarId) : FnBody → FnBody -| FnBody.dec y n c p b => - if x == y then b -- n must be 1 since `x := reset ...` - else FnBody.dec y n c p (reuseToCtor x b) -| FnBody.vdecl z t v b => - match v with - | Expr.reuse y c u xs => - if x == y then FnBody.vdecl z t (Expr.ctor c xs) b - else FnBody.vdecl z t v (reuseToCtor x b) - | _ => - FnBody.vdecl z t v (reuseToCtor x b) -| FnBody.case tid y yType alts => - let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x) - FnBody.case tid y yType alts -| e => - if e.isTerminal then - e - else - let (instr, b) := e.split - let b := reuseToCtor x b - instr.setBody b + | FnBody.dec y n c p b => + if x == y then b -- n must be 1 since `x := reset ...` + else FnBody.dec y n c p (reuseToCtor x b) + | FnBody.vdecl z t v b => + match v with + | Expr.reuse y c u xs => + if x == y then FnBody.vdecl z t (Expr.ctor c xs) b + else FnBody.vdecl z t v (reuseToCtor x b) + | _ => + FnBody.vdecl z t v (reuseToCtor x b) + | FnBody.case tid y yType alts => + let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x) + FnBody.case tid y yType alts + | e => + if e.isTerminal then + e + else + let (instr, b) := e.split + let b := reuseToCtor x b + instr.setBody b /- replace @@ -123,97 +122,91 @@ where `z_i`'s are the variables in `mask`, and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`. -/ def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody := -let b := reuseToCtor x b -let b := FnBody.dec y 1 true false b -mask.foldl - (fun b m => match m with - | some z => FnBody.inc z 1 true false b - | none => b) - b + let b := reuseToCtor x b + let b := FnBody.dec y 1 true false b + mask.foldl (init := b) fun b m => match m with + | some z => FnBody.inc z 1 true false b + | none => b abbrev M := ReaderT Context (StateM Nat) -def mkFresh : M VarId := -modifyGet $ fun n => ({ idx := n }, n + 1) + def mkFresh : M VarId := + modifyGet $ fun n => ({ idx := n }, n + 1) def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody := -mask.size.foldM - (fun i b => + mask.size.foldM (init := b) fun i b => match mask.get! i with | some _ => pure b -- code took ownership of this field | none => do let fld ← mkFresh - pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))) - b + pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b)) def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody := -zs.size.fold - (fun i b => FnBody.set y i (zs.get! i) b) - b + zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b /- Given `set x[i] := y`, return true iff `y := proj[i] x` -/ def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool := -match y with -| Arg.var y => - match ctx.projMap.find? y with - | some (Expr.proj j w) => j == i && w == x + match y with + | Arg.var y => + match ctx.projMap.find? y with + | some (Expr.proj j w) => j == i && w == x + | _ => false | _ => false -| _ => false /- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/ def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool := -match ctx.projMap.find? y with -| some (Expr.uproj j w) => j == i && w == x -| _ => false + match ctx.projMap.find? y with + | some (Expr.uproj j w) => j == i && w == x + | _ => false /- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/ def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool := -match ctx.projMap.find? y with -| some (Expr.sproj m j w) => n == m && j == i && w == x -| _ => false + match ctx.projMap.find? y with + | some (Expr.sproj m j w) => n == m && j == i && w == x + | _ => false /- Remove unnecessary `set/uset/sset` operations -/ partial def removeSelfSet (ctx : Context) : FnBody → FnBody -| FnBody.set x i y b => - if isSelfSet ctx x i y then removeSelfSet ctx b - else FnBody.set x i y (removeSelfSet ctx b) -| FnBody.uset x i y b => - if isSelfUSet ctx x i y then removeSelfSet ctx b - else FnBody.uset x i y (removeSelfSet ctx b) -| FnBody.sset x n i y t b => - if isSelfSSet ctx x n i y then removeSelfSet ctx b - else FnBody.sset x n i y t (removeSelfSet ctx b) -| FnBody.case tid y yType alts => - let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx) - FnBody.case tid y yType alts -| e => - if e.isTerminal then e - else - let (instr, b) := e.split - let b := removeSelfSet ctx b - instr.setBody b + | FnBody.set x i y b => + if isSelfSet ctx x i y then removeSelfSet ctx b + else FnBody.set x i y (removeSelfSet ctx b) + | FnBody.uset x i y b => + if isSelfUSet ctx x i y then removeSelfSet ctx b + else FnBody.uset x i y (removeSelfSet ctx b) + | FnBody.sset x n i y t b => + if isSelfSSet ctx x n i y then removeSelfSet ctx b + else FnBody.sset x n i y t (removeSelfSet ctx b) + | FnBody.case tid y yType alts => + let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx) + FnBody.case tid y yType alts + | e => + if e.isTerminal then e + else + let (instr, b) := e.split + let b := removeSelfSet ctx b + instr.setBody b partial def reuseToSet (ctx : Context) (x y : VarId) : FnBody → FnBody -| FnBody.dec z n c p b => - if x == z then FnBody.del y b - else FnBody.dec z n c p (reuseToSet ctx x y b) -| FnBody.vdecl z t v b => - match v with - | Expr.reuse w c u zs => - if x == w then - let b := setFields y zs (b.replaceVar z y) - let b := if u then FnBody.setTag y c.cidx b else b - removeSelfSet ctx b - else FnBody.vdecl z t v (reuseToSet ctx x y b) - | _ => FnBody.vdecl z t v (reuseToSet ctx x y b) -| FnBody.case tid z zType alts => - let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y) - FnBody.case tid z zType alts -| e => - if e.isTerminal then e - else - let (instr, b) := e.split - let b := reuseToSet ctx x y b - instr.setBody b + | FnBody.dec z n c p b => + if x == z then FnBody.del y b + else FnBody.dec z n c p (reuseToSet ctx x y b) + | FnBody.vdecl z t v b => + match v with + | Expr.reuse w c u zs => + if x == w then + let b := setFields y zs (b.replaceVar z y) + let b := if u then FnBody.setTag y c.cidx b else b + removeSelfSet ctx b + else FnBody.vdecl z t v (reuseToSet ctx x y b) + | _ => FnBody.vdecl z t v (reuseToSet ctx x y b) + | FnBody.case tid z zType alts => + let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y) + FnBody.case tid z zType alts + | e => + if e.isTerminal then e + else + let (instr, b) := e.split + let b := reuseToSet ctx x y b + instr.setBody b /- replace @@ -235,54 +228,54 @@ and `z := reuse x ctor_i ws; F` is replaced with `set x i ws[i]` operations, and we replace `z` with `x` in `F` -/ def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody := do -let ctx ← read -let b := reuseToSet ctx x y b -releaseUnreadFields y mask b + let ctx ← read + let b := reuseToSet ctx x y b + releaseUnreadFields y mask b -- Expand `bs; x := reset[n] y; b` partial def expand (mainFn : FnBody → Array FnBody → M FnBody) (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do -let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b -let (bs, mask) := eraseProjIncFor n y bs -/- Remark: we may be duplicting variable/JP indices. That is, `bSlow` and `bFast` may - have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/ -let bSlow := mkSlowPath x y mask b -let bFast ← mkFastPath x y mask b -/- We only optimize recursively the fast. -/ -let bFast ← mainFn bFast #[] -let c ← mkFresh -let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast) -pure $ reshape bs b + let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b + let (bs, mask) := eraseProjIncFor n y bs + /- Remark: we may be duplicting variable/JP indices. That is, `bSlow` and `bFast` may + have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/ + let bSlow := mkSlowPath x y mask b + let bFast ← mkFastPath x y mask b + /- We only optimize recursively the fast. -/ + let bFast ← mainFn bFast #[] + let c ← mkFresh + let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast) + pure $ reshape bs b partial def searchAndExpand : FnBody → Array FnBody → M FnBody -| d@(FnBody.vdecl x t (Expr.reset n y) b), bs => - if consumed x b then do - expand searchAndExpand bs x n y b - else - searchAndExpand b (push bs d) -| FnBody.jdecl j xs v b, bs => do - let v ← searchAndExpand v #[] - searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil)) -| FnBody.case tid x xType alts, bs => do - let alts ← alts.mapM $ fun alt => alt.mmodifyBody fun b => searchAndExpand b #[] - pure $ reshape bs (FnBody.case tid x xType alts) -| b, bs => - if b.isTerminal then pure $ reshape bs b - else searchAndExpand b.body (push bs b) + | d@(FnBody.vdecl x t (Expr.reset n y) b), bs => + if consumed x b then do + expand searchAndExpand bs x n y b + else + searchAndExpand b (push bs d) + | FnBody.jdecl j xs v b, bs => do + let v ← searchAndExpand v #[] + searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil)) + | FnBody.case tid x xType alts, bs => do + let alts ← alts.mapM $ fun alt => alt.mmodifyBody fun b => searchAndExpand b #[] + pure $ reshape bs (FnBody.case tid x xType alts) + | b, bs => + if b.isTerminal then pure $ reshape bs b + else searchAndExpand b.body (push bs b) def main (d : Decl) : Decl := -match d with -| (Decl.fdecl f xs t b) => - let m := mkProjMap d - let nextIdx := d.maxIndex + 1 - let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx - Decl.fdecl f xs t b -| d => d + match d with + | (Decl.fdecl f xs t b) => + let m := mkProjMap d + let nextIdx := d.maxIndex + 1 + let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx + Decl.fdecl f xs t b + | d => d end ExpandResetReuse /-- (Try to) expand `reset` and `reuse` instructions. -/ def Decl.expandResetReuse (d : Decl) : Decl := -(ExpandResetReuse.main d).normalizeIds + (ExpandResetReuse.main d).normalizeIds end Lean.IR diff --git a/src/Lean/Compiler/IR/FreeVars.lean b/src/Lean/Compiler/IR/FreeVars.lean index 4b5eeb11d3..b9560361d9 100644 --- a/src/Lean/Compiler/IR/FreeVars.lean +++ b/src/Lean/Compiler/IR/FreeVars.lean @@ -26,62 +26,62 @@ abbrev Collector := Index → Index instance : AndThen Collector := ⟨seq⟩ private def collectArg : Arg → Collector -| Arg.var x => collectVar x -| irrelevant => skip + | Arg.var x => collectVar x + | irrelevant => skip @[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := -fun m => as.foldl (fun m a => f a m) m + fun m => as.foldl (fun m a => f a m) m private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg private def collectParam (p : Param) : Collector := collectVar p.x private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam private def collectExpr : Expr → Collector -| Expr.ctor _ ys => collectArgs ys -| Expr.reset _ x => collectVar x -| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys -| Expr.proj _ x => collectVar x -| Expr.uproj _ x => collectVar x -| Expr.sproj _ _ x => collectVar x -| Expr.fap _ ys => collectArgs ys -| Expr.pap _ ys => collectArgs ys -| Expr.ap x ys => collectVar x >> collectArgs ys -| Expr.box _ x => collectVar x -| Expr.unbox x => collectVar x -| Expr.lit v => skip -| Expr.isShared x => collectVar x -| Expr.isTaggedPtr x => collectVar x + | Expr.ctor _ ys => collectArgs ys + | Expr.reset _ x => collectVar x + | Expr.reuse x _ _ ys => collectVar x >> collectArgs ys + | Expr.proj _ x => collectVar x + | Expr.uproj _ x => collectVar x + | Expr.sproj _ _ x => collectVar x + | Expr.fap _ ys => collectArgs ys + | Expr.pap _ ys => collectArgs ys + | Expr.ap x ys => collectVar x >> collectArgs ys + | Expr.box _ x => collectVar x + | Expr.unbox x => collectVar x + | Expr.lit v => skip + | Expr.isShared x => collectVar x + | Expr.isTaggedPtr x => collectVar x private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector := -collectArray alts $ fun alt => f alt.body + collectArray alts $ fun alt => f alt.body partial def collectFnBody : FnBody → Collector -| FnBody.vdecl x _ v b => collectVar x >> collectExpr v >> collectFnBody b -| FnBody.jdecl j ys v b => collectJP j >> collectFnBody v >> collectParams ys >> collectFnBody b -| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b -| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b -| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b -| FnBody.setTag x _ b => collectVar x >> collectFnBody b -| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b -| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b -| FnBody.del x b => collectVar x >> collectFnBody b -| FnBody.mdata _ b => collectFnBody b -| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts -| FnBody.jmp j ys => collectJP j >> collectArgs ys -| FnBody.ret x => collectArg x -| FnBody.unreachable => skip + | FnBody.vdecl x _ v b => collectVar x >> collectExpr v >> collectFnBody b + | FnBody.jdecl j ys v b => collectJP j >> collectFnBody v >> collectParams ys >> collectFnBody b + | FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b + | FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b + | FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b + | FnBody.setTag x _ b => collectVar x >> collectFnBody b + | FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b + | FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b + | FnBody.del x b => collectVar x >> collectFnBody b + | FnBody.mdata _ b => collectFnBody b + | FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts + | FnBody.jmp j ys => collectJP j >> collectArgs ys + | FnBody.ret x => collectArg x + | FnBody.unreachable => skip partial def collectDecl : Decl → Collector -| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b -| Decl.extern _ xs _ _ => collectParams xs + | Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b + | Decl.extern _ xs _ _ => collectParams xs end MaxIndex def FnBody.maxIndex (b : FnBody) : Index := -MaxIndex.collectFnBody b 0 + MaxIndex.collectFnBody b 0 def Decl.maxIndex (d : Decl) : Index := -MaxIndex.collectDecl d 0 + MaxIndex.collectDecl d 0 namespace FreeIndices /- We say a variable (join point) index (aka name) is free in a function body @@ -90,89 +90,89 @@ namespace FreeIndices abbrev Collector := IndexSet → IndexSet → IndexSet @[inline] private def skip : Collector := -fun bv fv => fv + fun bv fv => fv @[inline] private def collectIndex (x : Index) : Collector := -fun bv fv => if bv.contains x then fv else fv.insert x + fun bv fv => if bv.contains x then fv else fv.insert x @[inline] private def collectVar (x : VarId) : Collector := -collectIndex x.idx + collectIndex x.idx @[inline] private def collectJP (x : JoinPointId) : Collector := -collectIndex x.idx + collectIndex x.idx @[inline] private def withIndex (x : Index) : Collector → Collector := -fun k bv fv => k (bv.insert x) fv + fun k bv fv => k (bv.insert x) fv @[inline] private def withVar (x : VarId) : Collector → Collector := -withIndex x.idx + withIndex x.idx @[inline] private def withJP (x : JoinPointId) : Collector → Collector := -withIndex x.idx + withIndex x.idx def insertParams (s : IndexSet) (ys : Array Param) : IndexSet := -ys.foldl (fun s p => s.insert p.x.idx) s + ys.foldl (init := s) fun s p => s.insert p.x.idx @[inline] private def withParams (ys : Array Param) : Collector → Collector := -fun k bv fv => k (insertParams bv ys) fv + fun k bv fv => k (insertParams bv ys) fv @[inline] private def seq : Collector → Collector → Collector := -fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv) + fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv) instance : AndThen Collector := ⟨seq⟩ private def collectArg : Arg → Collector -| Arg.var x => collectVar x -| irrelevant => skip + | Arg.var x => collectVar x + | irrelevant => skip @[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := -fun bv fv => as.foldl (fun fv a => f a bv fv) fv + fun bv fv => as.foldl (fun fv a => f a bv fv) fv private def collectArgs (as : Array Arg) : Collector := -collectArray as collectArg + collectArray as collectArg private def collectExpr : Expr → Collector -| Expr.ctor _ ys => collectArgs ys -| Expr.reset _ x => collectVar x -| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys -| Expr.proj _ x => collectVar x -| Expr.uproj _ x => collectVar x -| Expr.sproj _ _ x => collectVar x -| Expr.fap _ ys => collectArgs ys -| Expr.pap _ ys => collectArgs ys -| Expr.ap x ys => collectVar x >> collectArgs ys -| Expr.box _ x => collectVar x -| Expr.unbox x => collectVar x -| Expr.lit v => skip -| Expr.isShared x => collectVar x -| Expr.isTaggedPtr x => collectVar x + | Expr.ctor _ ys => collectArgs ys + | Expr.reset _ x => collectVar x + | Expr.reuse x _ _ ys => collectVar x >> collectArgs ys + | Expr.proj _ x => collectVar x + | Expr.uproj _ x => collectVar x + | Expr.sproj _ _ x => collectVar x + | Expr.fap _ ys => collectArgs ys + | Expr.pap _ ys => collectArgs ys + | Expr.ap x ys => collectVar x >> collectArgs ys + | Expr.box _ x => collectVar x + | Expr.unbox x => collectVar x + | Expr.lit v => skip + | Expr.isShared x => collectVar x + | Expr.isTaggedPtr x => collectVar x private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector := -collectArray alts $ fun alt => f alt.body + collectArray alts $ fun alt => f alt.body partial def collectFnBody : FnBody → Collector -| FnBody.vdecl x _ v b => collectExpr v >> withVar x (collectFnBody b) -| FnBody.jdecl j ys v b => withParams ys (collectFnBody v) >> withJP j (collectFnBody b) -| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b -| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b -| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b -| FnBody.setTag x _ b => collectVar x >> collectFnBody b -| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b -| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b -| FnBody.del x b => collectVar x >> collectFnBody b -| FnBody.mdata _ b => collectFnBody b -| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts -| FnBody.jmp j ys => collectJP j >> collectArgs ys -| FnBody.ret x => collectArg x -| FnBody.unreachable => skip + | FnBody.vdecl x _ v b => collectExpr v >> withVar x (collectFnBody b) + | FnBody.jdecl j ys v b => withParams ys (collectFnBody v) >> withJP j (collectFnBody b) + | FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b + | FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b + | FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b + | FnBody.setTag x _ b => collectVar x >> collectFnBody b + | FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b + | FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b + | FnBody.del x b => collectVar x >> collectFnBody b + | FnBody.mdata _ b => collectFnBody b + | FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts + | FnBody.jmp j ys => collectJP j >> collectArgs ys + | FnBody.ret x => collectArg x + | FnBody.unreachable => skip end FreeIndices def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet := -FreeIndices.collectFnBody b {} vs + FreeIndices.collectFnBody b {} vs def FnBody.freeIndices (b : FnBody) : IndexSet := -b.collectFreeIndices {} + b.collectFreeIndices {} namespace HasIndex /- In principle, we can check whether a function body `b` contains an index `i` using @@ -183,46 +183,46 @@ def visitVar (w : Index) (x : VarId) : Bool := w == x.idx def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx def visitArg (w : Index) : Arg → Bool -| Arg.var x => visitVar w x -| _ => false + | Arg.var x => visitVar w x + | _ => false def visitArgs (w : Index) (xs : Array Arg) : Bool := -xs.any (visitArg w) + xs.any (visitArg w) def visitParams (w : Index) (ps : Array Param) : Bool := -ps.any (fun p => w == p.x.idx) + ps.any (fun p => w == p.x.idx) def visitExpr (w : Index) : Expr → Bool -| Expr.ctor _ ys => visitArgs w ys -| Expr.reset _ x => visitVar w x -| Expr.reuse x _ _ ys => visitVar w x || visitArgs w ys -| Expr.proj _ x => visitVar w x -| Expr.uproj _ x => visitVar w x -| Expr.sproj _ _ x => visitVar w x -| Expr.fap _ ys => visitArgs w ys -| Expr.pap _ ys => visitArgs w ys -| Expr.ap x ys => visitVar w x || visitArgs w ys -| Expr.box _ x => visitVar w x -| Expr.unbox x => visitVar w x -| Expr.lit v => false -| Expr.isShared x => visitVar w x -| Expr.isTaggedPtr x => visitVar w x + | Expr.ctor _ ys => visitArgs w ys + | Expr.reset _ x => visitVar w x + | Expr.reuse x _ _ ys => visitVar w x || visitArgs w ys + | Expr.proj _ x => visitVar w x + | Expr.uproj _ x => visitVar w x + | Expr.sproj _ _ x => visitVar w x + | Expr.fap _ ys => visitArgs w ys + | Expr.pap _ ys => visitArgs w ys + | Expr.ap x ys => visitVar w x || visitArgs w ys + | Expr.box _ x => visitVar w x + | Expr.unbox x => visitVar w x + | Expr.lit v => false + | Expr.isShared x => visitVar w x + | Expr.isTaggedPtr x => visitVar w x partial def visitFnBody (w : Index) : FnBody → Bool -| FnBody.vdecl x _ v b => visitExpr w v || visitFnBody w b -| FnBody.jdecl j ys v b => visitFnBody w v || visitFnBody w b -| FnBody.set x _ y b => visitVar w x || visitArg w y || visitFnBody w b -| FnBody.uset x _ y b => visitVar w x || visitVar w y || visitFnBody w b -| FnBody.sset x _ _ y _ b => visitVar w x || visitVar w y || visitFnBody w b -| FnBody.setTag x _ b => visitVar w x || visitFnBody w b -| FnBody.inc x _ _ _ b => visitVar w x || visitFnBody w b -| FnBody.dec x _ _ _ b => visitVar w x || visitFnBody w b -| FnBody.del x b => visitVar w x || visitFnBody w b -| FnBody.mdata _ b => visitFnBody w b -| FnBody.jmp j ys => visitJP w j || visitArgs w ys -| FnBody.ret x => visitArg w x -| FnBody.case _ x _ alts => visitVar w x || alts.any (fun alt => visitFnBody w alt.body) -| FnBody.unreachable => false + | FnBody.vdecl x _ v b => visitExpr w v || visitFnBody w b + | FnBody.jdecl j ys v b => visitFnBody w v || visitFnBody w b + | FnBody.set x _ y b => visitVar w x || visitArg w y || visitFnBody w b + | FnBody.uset x _ y b => visitVar w x || visitVar w y || visitFnBody w b + | FnBody.sset x _ _ y _ b => visitVar w x || visitVar w y || visitFnBody w b + | FnBody.setTag x _ b => visitVar w x || visitFnBody w b + | FnBody.inc x _ _ _ b => visitVar w x || visitFnBody w b + | FnBody.dec x _ _ _ b => visitVar w x || visitFnBody w b + | FnBody.del x b => visitVar w x || visitFnBody w b + | FnBody.mdata _ b => visitFnBody w b + | FnBody.jmp j ys => visitJP w j || visitArgs w ys + | FnBody.ret x => visitArg w x + | FnBody.case _ x _ alts => visitVar w x || alts.any (fun alt => visitFnBody w alt.body) + | FnBody.unreachable => false end HasIndex diff --git a/src/Lean/Compiler/IR/PushProj.lean b/src/Lean/Compiler/IR/PushProj.lean index b40fb22fb7..19fbec98b6 100644 --- a/src/Lean/Compiler/IR/PushProj.lean +++ b/src/Lean/Compiler/IR/PushProj.lean @@ -9,8 +9,7 @@ import Lean.Compiler.IR.NormIds namespace Lean.IR -partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array FnBody → IndexSet → Array FnBody × Array Alt -| bs, alts, altsF, ctx, ctxF => +partial def pushProjs (bs : Array FnBody) (alts : Array Alt) (altsF : Array IndexSet) (ctx : Array FnBody) (ctxF : IndexSet) : Array FnBody × Array Alt := if bs.isEmpty then (ctx.reverse, alts) else let b := bs.back @@ -37,8 +36,7 @@ partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array | _ => done () | _ => done () -partial def FnBody.pushProj : FnBody → FnBody -| b => +partial def FnBody.pushProj (b : FnBody) : FnBody := let (bs, term) := b.flatten let bs := modifyJPs bs pushProj match term with @@ -52,7 +50,7 @@ partial def FnBody.pushProj : FnBody → FnBody /-- Push projections inside `case` branches. -/ def Decl.pushProj : Decl → Decl -| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds -| other => other + | Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds + | other => other end Lean.IR diff --git a/src/Lean/Compiler/IR/RC.lean b/src/Lean/Compiler/IR/RC.lean index 9382acf275..191777df1d 100644 --- a/src/Lean/Compiler/IR/RC.lean +++ b/src/Lean/Compiler/IR/RC.lean @@ -14,82 +14,82 @@ namespace Lean.IR.ExplicitRC -/ structure VarInfo := -(ref : Bool := true) -- true if the variable may be a reference (aka pointer) at runtime -(persistent : Bool := false) -- true if the variable is statically known to be marked a Persistent at runtime -(consume : Bool := false) -- true if the variable RC must be "consumed" + (ref : Bool := true) -- true if the variable may be a reference (aka pointer) at runtime + (persistent : Bool := false) -- true if the variable is statically known to be marked a Persistent at runtime + (consume : Bool := false) -- true if the variable RC must be "consumed" abbrev VarMap := Std.RBMap VarId VarInfo (fun x y => x.idx < y.idx) structure Context := -(env : Environment) -(decls : Array Decl) -(varMap : VarMap := {}) -(jpLiveVarMap : JPLiveVarMap := {}) -- map: join point => live variables -(localCtx : LocalContext := {}) -- we use it to store the join point declarations + (env : Environment) + (decls : Array Decl) + (varMap : VarMap := {}) + (jpLiveVarMap : JPLiveVarMap := {}) -- map: join point => live variables + (localCtx : LocalContext := {}) -- we use it to store the join point declarations def getDecl (ctx : Context) (fid : FunId) : Decl := - match findEnvDecl' ctx.env fid ctx.decls with -| some decl => decl -| none => arbitrary _ -- unreachable if well-formed + match findEnvDecl' ctx.env fid ctx.decls with + | some decl => decl + | none => arbitrary _ -- unreachable if well-formed def getVarInfo (ctx : Context) (x : VarId) : VarInfo := -match ctx.varMap.find? x with -| some info => info -| none => {} -- unreachable in well-formed code + match ctx.varMap.find? x with + | some info => info + | none => {} -- unreachable in well-formed code def getJPParams (ctx : Context) (j : JoinPointId) : Array Param := -match ctx.localCtx.getJPParams j with -| some ps => ps -| none => #[] -- unreachable in well-formed code + match ctx.localCtx.getJPParams j with + | some ps => ps + | none => #[] -- unreachable in well-formed code def getJPLiveVars (ctx : Context) (j : JoinPointId) : LiveVarSet := -match ctx.jpLiveVarMap.find? j with -| some s => s -| none => {} + match ctx.jpLiveVarMap.find? j with + | some s => s + | none => {} def mustConsume (ctx : Context) (x : VarId) : Bool := -let info := getVarInfo ctx x -info.ref && info.consume + let info := getVarInfo ctx x + info.ref && info.consume @[inline] def addInc (ctx : Context) (x : VarId) (b : FnBody) (n := 1) : FnBody := -let info := getVarInfo ctx x -if n == 0 then b else FnBody.inc x n true info.persistent b + let info := getVarInfo ctx x + if n == 0 then b else FnBody.inc x n true info.persistent b @[inline] def addDec (ctx : Context) (x : VarId) (b : FnBody) : FnBody := -let info := getVarInfo ctx x -FnBody.dec x 1 true info.persistent b + let info := getVarInfo ctx x + FnBody.dec x 1 true info.persistent b private def updateRefUsingCtorInfo (ctx : Context) (x : VarId) (c : CtorInfo) : Context := -if c.isRef then ctx -else - let m := ctx.varMap - { ctx with - varMap := match m.find? x with - | some info => m.insert x { info with ref := false } -- I really want a Lenses library + notation - | none => m } + if c.isRef then + ctx + else + let m := ctx.varMap + { ctx with + varMap := match m.find? x with + | some info => m.insert x { info with ref := false } -- I really want a Lenses library + notation + | none => m } private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet) (b : FnBody) : FnBody := -caseLiveVars.fold - (fun b x => if !altLiveVars.contains x && mustConsume ctx x then addDec ctx x b else b) - b + caseLiveVars.fold (init := b) fun b x => + if !altLiveVars.contains x && mustConsume ctx x then addDec ctx x b else b /- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/ private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool := -let x := xs[i] -i.all fun j => xs[j] != x + let x := xs[i] + i.all fun j => xs[j] != x /- Return true if `x` also occurs in `ys` in a position that is not consumed. That is, it is also passed as a borrow reference. -/ @[specialize] private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool := -ys.size.any $ fun i => - let y := ys[i] - match y with - | Arg.irrelevant => false - | Arg.var y => x == y && !consumeParamPred i + ys.size.any fun i => + let y := ys[i] + match y with + | Arg.irrelevant => false + | Arg.var y => x == y && !consumeParamPred i private def isBorrowParam (x : VarId) (ys : Array Arg) (ps : Array Param) : Bool := -isBorrowParamAux x ys fun i => not ps[i].borrow + isBorrowParamAux x ys fun i => not ps[i].borrow /- Return `n`, the number of times `x` is consumed. @@ -98,190 +98,187 @@ Return `n`, the number of times `x` is consumed. -/ @[specialize] private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat := -ys.size.fold (init := 0) fun i n => - let y := ys[i] - match y with - | Arg.irrelevant => n - | Arg.var y => if x == y && consumeParamPred i then n+1 else n + ys.size.fold (init := 0) fun i n => + let y := ys[i] + match y with + | Arg.irrelevant => n + | Arg.var y => if x == y && consumeParamPred i then n+1 else n @[specialize] private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody := -xs.size.fold (init := b) fun i b => - let x := xs[i] - match x with - | Arg.irrelevant => b - | Arg.var x => - let info := getVarInfo ctx x - if !info.ref || !isFirstOcc xs i then b - else - let numConsuptions := getNumConsumptions x xs consumeParamPred -- number of times the argument is - let numIncs := - if !info.consume || -- `x` is not a variable that must be consumed by the current procedure - liveVarsAfter.contains x || -- `x` is live after executing instruction - isBorrowParamAux x xs consumeParamPred -- `x` is used in a position that is passed as a borrow reference - then numConsuptions - else numConsuptions - 1 - -- dbgTrace ("addInc " ++ toString x ++ " nconsumptions: " ++ toString numConsuptions ++ " incs: " ++ toString numIncs - -- ++ " consume: " ++ toString info.consume ++ " live: " ++ toString (liveVarsAfter.contains x) - -- ++ " borrowParam : " ++ toString (isBorrowParamAux x xs consumeParamPred)) $ fun _ => - addInc ctx x b numIncs - + xs.size.fold (init := b) fun i b => + let x := xs[i] + match x with + | Arg.irrelevant => b + | Arg.var x => + let info := getVarInfo ctx x + if !info.ref || !isFirstOcc xs i then b + else + let numConsuptions := getNumConsumptions x xs consumeParamPred -- number of times the argument is + let numIncs := + if !info.consume || -- `x` is not a variable that must be consumed by the current procedure + liveVarsAfter.contains x || -- `x` is live after executing instruction + isBorrowParamAux x xs consumeParamPred -- `x` is used in a position that is passed as a borrow reference + then numConsuptions + else numConsuptions - 1 + -- dbgTrace ("addInc " ++ toString x ++ " nconsumptions: " ++ toString numConsuptions ++ " incs: " ++ toString numIncs + -- ++ " consume: " ++ toString info.consume ++ " live: " ++ toString (liveVarsAfter.contains x) + -- ++ " borrowParam : " ++ toString (isBorrowParamAux x xs consumeParamPred)) $ fun _ => + addInc ctx x b numIncs private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody := -addIncBeforeAux ctx xs (fun i => not ps[i].borrow) b liveVarsAfter + addIncBeforeAux ctx xs (fun i => not ps[i].borrow) b liveVarsAfter /- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/ private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody := -xs.size.fold - (fun i b => - match xs[i] with - | Arg.irrelevant => b - | Arg.var x => - /- We must add a `dec` if `x` must be consumed, it is alive after the application, - and it has been borrowed by the application. - Remark: `x` may occur multiple times in the application (e.g., `f x y x`). - This is why we check whether it is the first occurrence. -/ - if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then - addDec ctx x b - else b) - b +xs.size.fold (init := b) fun i b => + match xs[i] with + | Arg.irrelevant => b + | Arg.var x => + /- We must add a `dec` if `x` must be consumed, it is alive after the application, + and it has been borrowed by the application. + Remark: `x` may occur multiple times in the application (e.g., `f x y x`). + This is why we check whether it is the first occurrence. -/ + if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then + addDec ctx x b + else b private def addIncBeforeConsumeAll (ctx : Context) (xs : Array Arg) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody := -addIncBeforeAux ctx xs (fun i => true) b liveVarsAfter + addIncBeforeAux ctx xs (fun i => true) b liveVarsAfter /- Add `dec` instructions for parameters that are references, are not alive in `b`, and are not borrow. That is, we must make sure these parameters are consumed. -/ private def addDecForDeadParams (ctx : Context) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody := -ps.foldl - (fun b p => if !p.borrow && p.ty.isObj && !bLiveVars.contains p.x then addDec ctx p.x b else b) - b + ps.foldl (init := b) fun b p => + if !p.borrow && p.ty.isObj && !bLiveVars.contains p.x then addDec ctx p.x b else b private def isPersistent : Expr → Bool -| Expr.fap c xs => xs.isEmpty -- all global constants are persistent objects -| _ => false + | Expr.fap c xs => xs.isEmpty -- all global constants are persistent objects + | _ => false /- We do not need to consume the projection of a variable that is not consumed -/ private def consumeExpr (m : VarMap) : Expr → Bool -| Expr.proj i x => match m.find? x with - | some info => info.consume - | none => true -| other => true + | Expr.proj i x => match m.find? x with + | some info => info.consume + | none => true + | other => true /- Return true iff `v` at runtime is a scalar value stored in a tagged pointer. We do not need RC operations for this kind of value. -/ private def isScalarBoxedInTaggedPtr (v : Expr) : Bool := -match v with -| Expr.ctor c ys => c.size == 0 && c.ssize == 0 && c.usize == 0 -| Expr.lit (LitVal.num n) => n ≤ maxSmallNat -| _ => false + match v with + | Expr.ctor c ys => c.size == 0 && c.ssize == 0 && c.usize == 0 + | Expr.lit (LitVal.num n) => n ≤ maxSmallNat + | _ => false private def updateVarInfo (ctx : Context) (x : VarId) (t : IRType) (v : Expr) : Context := -{ ctx with - varMap := ctx.varMap.insert x { - ref := t.isObj && !isScalarBoxedInTaggedPtr v, - persistent := isPersistent v, - consume := consumeExpr ctx.varMap v } } + { ctx with + varMap := ctx.varMap.insert x { + ref := t.isObj && !isScalarBoxedInTaggedPtr v, + persistent := isPersistent v, + consume := consumeExpr ctx.varMap v + } + } private def addDecIfNeeded (ctx : Context) (x : VarId) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody := -if mustConsume ctx x && !bLiveVars.contains x then addDec ctx x b else b + if mustConsume ctx x && !bLiveVars.contains x then addDec ctx x b else b private def processVDecl (ctx : Context) (z : VarId) (t : IRType) (v : Expr) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody × LiveVarSet := --- dbgTrace ("processVDecl " ++ toString z ++ " " ++ toString (format v)) $ fun _ => -let b := match v with - | (Expr.ctor _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars - | (Expr.reuse _ _ _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars - | (Expr.proj _ x) => - let b := addDecIfNeeded ctx x b bLiveVars - let b := if (getVarInfo ctx x).consume then addInc ctx z b else b - (FnBody.vdecl z t v b) - | (Expr.uproj _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars) - | (Expr.sproj _ _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars) - | (Expr.fap f ys) => - -- dbgTrace ("processVDecl " ++ toString v) $ fun _ => - let ps := (getDecl ctx f).params - let b := addDecAfterFullApp ctx ys ps b bLiveVars - let b := FnBody.vdecl z t v b - addIncBefore ctx ys ps b bLiveVars - | (Expr.pap _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars - | (Expr.ap x ys) => - let ysx := ys.push (Arg.var x) -- TODO: avoid temporary array allocation - addIncBeforeConsumeAll ctx ysx (FnBody.vdecl z t v b) bLiveVars - | (Expr.unbox x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars) - | other => FnBody.vdecl z t v b -- Expr.reset, Expr.box, Expr.lit are handled here -let liveVars := updateLiveVars v bLiveVars -let liveVars := liveVars.erase z -(b, liveVars) + let b := match v with + | (Expr.ctor _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars + | (Expr.reuse _ _ _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars + | (Expr.proj _ x) => + let b := addDecIfNeeded ctx x b bLiveVars + let b := if (getVarInfo ctx x).consume then addInc ctx z b else b + (FnBody.vdecl z t v b) + | (Expr.uproj _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars) + | (Expr.sproj _ _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars) + | (Expr.fap f ys) => + -- dbgTrace ("processVDecl " ++ toString v) $ fun _ => + let ps := (getDecl ctx f).params + let b := addDecAfterFullApp ctx ys ps b bLiveVars + let b := FnBody.vdecl z t v b + addIncBefore ctx ys ps b bLiveVars + | (Expr.pap _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars + | (Expr.ap x ys) => + let ysx := ys.push (Arg.var x) -- TODO: avoid temporary array allocation + addIncBeforeConsumeAll ctx ysx (FnBody.vdecl z t v b) bLiveVars + | (Expr.unbox x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars) + | other => FnBody.vdecl z t v b -- Expr.reset, Expr.box, Expr.lit are handled here + let liveVars := updateLiveVars v bLiveVars + let liveVars := liveVars.erase z + (b, liveVars) def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context := -let m := ps.foldl (init := ctx.varMap) fun m p => - m.insert p.x { ref := p.ty.isObj, consume := !p.borrow } -{ ctx with varMap := m } + let m := ps.foldl (init := ctx.varMap) fun m p => + m.insert p.x { ref := p.ty.isObj, consume := !p.borrow } + { ctx with varMap := m } partial def visitFnBody : FnBody → Context → (FnBody × LiveVarSet) -| FnBody.vdecl x t v b, ctx => - let ctx := updateVarInfo ctx x t v - let (b, bLiveVars) := visitFnBody b ctx - processVDecl ctx x t v b bLiveVars -| FnBody.jdecl j xs v b, ctx => - let (v, vLiveVars) := visitFnBody v (updateVarInfoWithParams ctx xs) - let v := addDecForDeadParams ctx xs v vLiveVars - let ctx := { ctx with jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap } - let (b, bLiveVars) := visitFnBody b ctx - (FnBody.jdecl j xs v b, bLiveVars) -| FnBody.uset x i y b, ctx => - let (b, s) := visitFnBody b ctx - -- We don't need to insert `y` since we only need to track live variables that are references at runtime - let s := s.insert x - (FnBody.uset x i y b, s) -| FnBody.sset x i o y t b, ctx => - let (b, s) := visitFnBody b ctx - -- We don't need to insert `y` since we only need to track live variables that are references at runtime - let s := s.insert x - (FnBody.sset x i o y t b, s) -| FnBody.mdata m b, ctx => - let (b, s) := visitFnBody b ctx - (FnBody.mdata m b, s) -| b@(FnBody.case tid x xType alts), ctx => - let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap - let alts := alts.map $ fun alt => match alt with - | Alt.ctor c b => - let ctx := updateRefUsingCtorInfo ctx x c - let (b, altLiveVars) := visitFnBody b ctx - let b := addDecForAlt ctx caseLiveVars altLiveVars b - Alt.ctor c b - | Alt.default b => - let (b, altLiveVars) := visitFnBody b ctx - let b := addDecForAlt ctx caseLiveVars altLiveVars b - Alt.default b - (FnBody.case tid x xType alts, caseLiveVars) -| b@(FnBody.ret x), ctx => - match x with - | Arg.var x => - let info := getVarInfo ctx x - if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x) - | _ => (b, {}) -| b@(FnBody.jmp j xs), ctx => - let jLiveVars := getJPLiveVars ctx j - let ps := getJPParams ctx j - let b := addIncBefore ctx xs ps b jLiveVars - let bLiveVars := collectLiveVars b ctx.jpLiveVarMap - (b, bLiveVars) -| FnBody.unreachable, _ => (FnBody.unreachable, {}) -| other, ctx => (other, {}) -- unreachable if well-formed + | FnBody.vdecl x t v b, ctx => + let ctx := updateVarInfo ctx x t v + let (b, bLiveVars) := visitFnBody b ctx + processVDecl ctx x t v b bLiveVars + | FnBody.jdecl j xs v b, ctx => + let (v, vLiveVars) := visitFnBody v (updateVarInfoWithParams ctx xs) + let v := addDecForDeadParams ctx xs v vLiveVars + let ctx := { ctx with jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap } + let (b, bLiveVars) := visitFnBody b ctx + (FnBody.jdecl j xs v b, bLiveVars) + | FnBody.uset x i y b, ctx => + let (b, s) := visitFnBody b ctx + -- We don't need to insert `y` since we only need to track live variables that are references at runtime + let s := s.insert x + (FnBody.uset x i y b, s) + | FnBody.sset x i o y t b, ctx => + let (b, s) := visitFnBody b ctx + -- We don't need to insert `y` since we only need to track live variables that are references at runtime + let s := s.insert x + (FnBody.sset x i o y t b, s) + | FnBody.mdata m b, ctx => + let (b, s) := visitFnBody b ctx + (FnBody.mdata m b, s) + | b@(FnBody.case tid x xType alts), ctx => + let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap + let alts := alts.map $ fun alt => match alt with + | Alt.ctor c b => + let ctx := updateRefUsingCtorInfo ctx x c + let (b, altLiveVars) := visitFnBody b ctx + let b := addDecForAlt ctx caseLiveVars altLiveVars b + Alt.ctor c b + | Alt.default b => + let (b, altLiveVars) := visitFnBody b ctx + let b := addDecForAlt ctx caseLiveVars altLiveVars b + Alt.default b + (FnBody.case tid x xType alts, caseLiveVars) + | b@(FnBody.ret x), ctx => + match x with + | Arg.var x => + let info := getVarInfo ctx x + if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x) + | _ => (b, {}) + | b@(FnBody.jmp j xs), ctx => + let jLiveVars := getJPLiveVars ctx j + let ps := getJPParams ctx j + let b := addIncBefore ctx xs ps b jLiveVars + let bLiveVars := collectLiveVars b ctx.jpLiveVarMap + (b, bLiveVars) + | FnBody.unreachable, _ => (FnBody.unreachable, {}) + | other, ctx => (other, {}) -- unreachable if well-formed partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl -| Decl.fdecl f xs t b => - let ctx : Context := { env := env, decls := decls } - let ctx := updateVarInfoWithParams ctx xs - let (b, bLiveVars) := visitFnBody b ctx - let b := addDecForDeadParams ctx xs b bLiveVars - Decl.fdecl f xs t b -| other => other + | Decl.fdecl f xs t b => + let ctx : Context := { env := env, decls := decls } + let ctx := updateVarInfoWithParams ctx xs + let (b, bLiveVars) := visitFnBody b ctx + let b := addDecForDeadParams ctx xs b bLiveVars + Decl.fdecl f xs t b + | other => other end ExplicitRC def explicitRC (decls : Array Decl) : CompilerM (Array Decl) := do -let env ← getEnv -pure $ decls.map (ExplicitRC.visitDecl env decls) + let env ← getEnv + pure $ decls.map (ExplicitRC.visitDecl env decls) end Lean.IR diff --git a/src/Lean/Compiler/IR/ResetReuse.lean b/src/Lean/Compiler/IR/ResetReuse.lean index d794505104..eeb998f4e4 100644 --- a/src/Lean/Compiler/IR/ResetReuse.lean +++ b/src/Lean/Compiler/IR/ResetReuse.lean @@ -26,56 +26,56 @@ namespace Lean.IR.ResetReuse -/ private def mayReuse (c₁ c₂ : CtorInfo) : Bool := -c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize && -/- The following condition is a heuristic. - We don't want to reuse cells from different types even when they are compatible - because it produces counterintuitive behavior. -/ -c₁.name.getPrefix == c₂.name.getPrefix + c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize && + /- The following condition is a heuristic. + We don't want to reuse cells from different types even when they are compatible + because it produces counterintuitive behavior. -/ + c₁.name.getPrefix == c₂.name.getPrefix private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody -| FnBody.vdecl x t v@(Expr.ctor c' ys) b => - if mayReuse c c' then - let updtCidx := c.cidx != c'.cidx - FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b - else - FnBody.vdecl x t v (S w c b) -| FnBody.jdecl j ys v b => - let v' := S w c v - if v == v' then FnBody.jdecl j ys v (S w c b) - else FnBody.jdecl j ys v' b -| FnBody.case tid x xType alts => FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (S w c) -| b => - if b.isTerminal then b - else let - (instr, b) := b.split - instr.setBody (S w c b) + | FnBody.vdecl x t v@(Expr.ctor c' ys) b => + if mayReuse c c' then + let updtCidx := c.cidx != c'.cidx + FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b + else + FnBody.vdecl x t v (S w c b) + | FnBody.jdecl j ys v b => + let v' := S w c v + if v == v' then FnBody.jdecl j ys v (S w c b) + else FnBody.jdecl j ys v' b + | FnBody.case tid x xType alts => FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (S w c) + | b => + if b.isTerminal then b + else let + (instr, b) := b.split + instr.setBody (S w c b) /- We use `Context` to track join points in scope. -/ abbrev M := ReaderT LocalContext (StateT Index Id) private def mkFresh : M VarId := do -let idx ← getModify (fun n => n + 1) -pure { idx := idx } + let idx ← getModify (fun n => n + 1) + pure { idx := idx } private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := do -let w ← mkFresh -let b' := S w c b -if b == b' then pure b -else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b' + let w ← mkFresh + let b' := S w c b + if b == b' then pure b + else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b' private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool → M FnBody -| (b, true) => pure b -| (b, false) => tryS x c b + | (b, true) => pure b + | (b, false) => tryS x c b private def argsContainsVar (ys : Array Arg) (x : VarId) : Bool := -ys.any fun arg => match arg with - | Arg.var y => x == y - | _ => false + ys.any fun arg => match arg with + | Arg.var y => x == y + | _ => false private def isCtorUsing (b : FnBody) (x : VarId) : Bool := -match b with -| (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x -| _ => false + match b with + | (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x + | _ => false /- Given `Dmain b`, the resulting pair `(new_b, flag)` contains the new body `new_b`, and `flag == true` if `x` is live in `b`. @@ -84,75 +84,75 @@ match b with `D` checks whether `x` is live in `F` or not. This is great for clarity but it is expensive: `O(n^2)` where `n` is the size of the function body. -/ private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool) -| e@(FnBody.case tid y yType alts) => do - let ctx ← read - if e.hasLiveVar ctx x then do - /- If `x` is live in `e`, we recursively process each branch. -/ - let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c - pure (FnBody.case tid y yType alts, true) - else pure (e, false) -| FnBody.jdecl j ys v b => do - let (b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b) - let (v, _ /- found' -/) ← Dmain x c v - /- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since - we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`), - then it must also live in `b` since `j` is reachable from `b` with a `jmp`. - On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/ - pure (FnBody.jdecl j ys v b, found) -| e => do - let ctx ← read - if e.isTerminal then - pure (e, e.hasLiveVar ctx x) - else do - let (instr, b) := e.split - if isCtorUsing instr x then - /- If the scrutinee `x` (the one that is providing memory) is being - stored in a constructor, then reuse will probably not be able to reuse memory at runtime. - It may work only if the new cell is consumed, but we ignore this case. -/ - pure (e, true) + | e@(FnBody.case tid y yType alts) => do + let ctx ← read + if e.hasLiveVar ctx x then do + /- If `x` is live in `e`, we recursively process each branch. -/ + let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c + pure (FnBody.case tid y yType alts, true) + else pure (e, false) + | FnBody.jdecl j ys v b => do + let (b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b) + let (v, _ /- found' -/) ← Dmain x c v + /- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since + we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`), + then it must also live in `b` since `j` is reachable from `b` with a `jmp`. + On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/ + pure (FnBody.jdecl j ys v b, found) + | e => do + let ctx ← read + if e.isTerminal then + pure (e, e.hasLiveVar ctx x) else do - let (b, found) ← Dmain x c b - /- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar` - since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/ - if found || !instr.hasFreeVar x then - pure (instr.setBody b, found) - else do - let b ← tryS x c b - pure (instr.setBody b, true) + let (instr, b) := e.split + if isCtorUsing instr x then + /- If the scrutinee `x` (the one that is providing memory) is being + stored in a constructor, then reuse will probably not be able to reuse memory at runtime. + It may work only if the new cell is consumed, but we ignore this case. -/ + pure (e, true) + else + let (b, found) ← Dmain x c b + /- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar` + since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/ + if found || !instr.hasFreeVar x then + pure (instr.setBody b, found) + else + let b ← tryS x c b + pure (instr.setBody b, true) private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := -Dmain x c b >>= Dfinalize x c + Dmain x c b >>= Dfinalize x c partial def R : FnBody → M FnBody -| FnBody.case tid x xType alts => do - let alts ← alts.mapM fun alt => do - let alt ← alt.mmodifyBody R - match alt with - | Alt.ctor c b => - if c.isScalar then pure alt - else Alt.ctor c <$> D x c b - | _ => pure alt - pure $ FnBody.case tid x xType alts -| FnBody.jdecl j ys v b => do - let v ← R v - let b ← withReader (fun ctx => ctx.addJP j ys v) (R b) - pure $ FnBody.jdecl j ys v b -| e => do - if e.isTerminal then pure e - else do - let (instr, b) := e.split - let b ← R b - pure (instr.setBody b) + | FnBody.case tid x xType alts => do + let alts ← alts.mapM fun alt => do + let alt ← alt.mmodifyBody R + match alt with + | Alt.ctor c b => + if c.isScalar then pure alt + else Alt.ctor c <$> D x c b + | _ => pure alt + pure $ FnBody.case tid x xType alts + | FnBody.jdecl j ys v b => do + let v ← R v + let b ← withReader (fun ctx => ctx.addJP j ys v) (R b) + pure $ FnBody.jdecl j ys v b + | e => do + if e.isTerminal then pure e + else do + let (instr, b) := e.split + let b ← R b + pure (instr.setBody b) end ResetReuse open ResetReuse def Decl.insertResetReuse : Decl → Decl -| d@(Decl.fdecl f xs t b) => - let nextIndex := d.maxIndex + 1 - let b := (R b {}).run' nextIndex - Decl.fdecl f xs t b -| other => other + | d@(Decl.fdecl f xs t b) => + let nextIndex := d.maxIndex + 1 + let b := (R b {}).run' nextIndex + Decl.fdecl f xs t b + | other => other end Lean.IR diff --git a/src/Lean/Compiler/IR/SimpCase.lean b/src/Lean/Compiler/IR/SimpCase.lean index d443e7decf..c75cf0ea1e 100644 --- a/src/Lean/Compiler/IR/SimpCase.lean +++ b/src/Lean/Compiler/IR/SimpCase.lean @@ -9,52 +9,51 @@ import Lean.Compiler.IR.Format namespace Lean.IR def ensureHasDefault (alts : Array Alt) : Array Alt := -if alts.any Alt.isDefault then alts -else if alts.size < 2 then alts -else - let last := alts.back; - let alts := alts.pop; - alts.push (Alt.default last.body) + if alts.any Alt.isDefault then alts + else if alts.size < 2 then alts + else + let last := alts.back; + let alts := alts.pop; + alts.push (Alt.default last.body) private def getOccsOf (alts : Array Alt) (i : Nat) : Nat := do -let aBody := (alts.get! i).body -let n := 1 -for j in [i+1:alts.size] do - if alts[j].body == aBody then - n := n+1 -return n + let aBody := (alts.get! i).body + let n := 1 + for j in [i+1:alts.size] do + if alts[j].body == aBody then + n := n+1 + return n private def maxOccs (alts : Array Alt) : Alt × Nat := do -let maxAlt := alts[0] -let max := getOccsOf alts 0 -for i in [1:alts.size] do - let curr := getOccsOf alts i - if curr > max then - maxAlt := alts[i] - max := curr -return (maxAlt, max) + let maxAlt := alts[0] + let max := getOccsOf alts 0 + for i in [1:alts.size] do + let curr := getOccsOf alts i + if curr > max then + maxAlt := alts[i] + max := curr + return (maxAlt, max) private def addDefault (alts : Array Alt) : Array Alt := -if alts.size <= 1 || alts.any Alt.isDefault then alts -else - let (max, noccs) := maxOccs alts; - if noccs == 1 then alts + if alts.size <= 1 || alts.any Alt.isDefault then alts else - let alts := alts.filter $ (fun alt => alt.body != max.body); - alts.push (Alt.default max.body) + let (max, noccs) := maxOccs alts; + if noccs == 1 then alts + else + let alts := alts.filter $ (fun alt => alt.body != max.body); + alts.push (Alt.default max.body) private def mkSimpCase (tid : Name) (x : VarId) (xType : IRType) (alts : Array Alt) : FnBody := -let alts := alts.filter (fun alt => alt.body != FnBody.unreachable); -let alts := addDefault alts; -if alts.size == 0 then - FnBody.unreachable -else if alts.size == 1 then - (alts.get! 0).body -else - FnBody.case tid x xType alts + let alts := alts.filter (fun alt => alt.body != FnBody.unreachable); + let alts := addDefault alts; + if alts.size == 0 then + FnBody.unreachable + else if alts.size == 1 then + (alts.get! 0).body + else + FnBody.case tid x xType alts -partial def FnBody.simpCase : FnBody → FnBody -| b => +partial def FnBody.simpCase (b : FnBody) : FnBody := let (bs, term) := b.flatten; let bs := modifyJPs bs simpCase; match term with @@ -68,7 +67,7 @@ partial def FnBody.simpCase : FnBody → FnBody - Remove `case` if there is only one branch. - Merge most common branches using `Alt.default`. -/ def Decl.simpCase : Decl → Decl -| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase -| other => other + | Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase + | other => other end Lean.IR diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index af0706577e..18a4607966 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -824,13 +824,13 @@ This is a standard trick we use in the elaborator, and it is also used to elabor Suppose, we are trying to elaborate ``` match g x with -| ... => ... + | ... => ... ``` `expandNonAtomicDiscrs?` converts it intro ``` let _discr := g x match _discr with -| ... => ... + | ... => ... ``` Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`. Note that, the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`. diff --git a/src/Lean/Exception.lean b/src/Lean/Exception.lean index b5e8f1260d..22ed707b87 100644 --- a/src/Lean/Exception.lean +++ b/src/Lean/Exception.lean @@ -108,17 +108,17 @@ syntax "throwError! " ((interpolatedStr term) <|> term) : term syntax "throwErrorAt! " term:max ((interpolatedStr term) <|> term) : term macro_rules -| `(throwError! $msg) => - if msg.getKind == interpolatedStrKind then - `(throwError (msg! $msg)) - else - `(throwError $msg) + | `(throwError! $msg) => + if msg.getKind == interpolatedStrKind then + `(throwError (msg! $msg)) + else + `(throwError $msg) macro_rules -| `(throwErrorAt! $ref $msg) => - if msg.getKind == interpolatedStrKind then - `(throwErrorAt $ref (msg! $msg)) - else - `(throwErrorAt $ref $msg) + | `(throwErrorAt! $ref $msg) => + if msg.getKind == interpolatedStrKind then + `(throwErrorAt $ref (msg! $msg)) + else + `(throwErrorAt $ref $msg) end Lean diff --git a/src/Lean/Level.lean b/src/Lean/Level.lean index 9f1233af51..777174ef6c 100644 --- a/src/Lean/Level.lean +++ b/src/Lean/Level.lean @@ -211,12 +211,12 @@ def occurs : Level → Level → Bool | u, v => u == v def ctorToNat : Level → Nat -| zero .. => 0 -| param .. => 1 -| mvar .. => 2 -| succ .. => 3 -| max .. => 4 -| imax .. => 5 + | zero .. => 0 + | param .. => 1 + | mvar .. => 2 + | succ .. => 3 + | max .. => 4 + | imax .. => 5 /- TODO: use well founded recursion. -/ partial def normLtAux : Level → Nat → Level → Nat → Bool diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index ce8105d0e1..4fd649412a 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -111,15 +111,15 @@ def replaceFVarId (fvarId : FVarId) (v : Expr) (alt : Alt) : Alt := ``` inductive Vec (α : Type u) : Nat → Type u -| nil : Vec α 0 -| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1) + | nil : Vec α 0 + | cons {n} (head : α) (tail : Vec α n) : Vec α (n+1) inductive VecPred {α : Type u} (P : α → Prop) : {n : Nat} → Vec α n → Prop -| nil : VecPred P Vec.nil -| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail) + | nil : VecPred P Vec.nil + | cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail) theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P -| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ + | _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ ``` Recall that `_` in a pattern can be elaborated into pattern variable or an inaccessible term. The elaborator uses an inaccessible term when typing constraints restrict its value. @@ -127,7 +127,7 @@ Thus, in the example above, the `_` at `Vec.cons head _` becomes the inaccessibl because the type ascription `(w : VecPred P Vec.nil)` propagates typing constraints that restrict its value to be `Vec.nil`. After elaboration the alternative becomes: ``` -| .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩ + | .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩ ``` where ``` @@ -138,7 +138,7 @@ Then, when we process this alternative in this module, the following check will Note that if we had written ``` theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P -| _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ + | _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ ``` we would get the easier to digest error message ``` diff --git a/src/Lean/Parser/Extension.lean b/src/Lean/Parser/Extension.lean index 7ea4a660b5..b2f72d2708 100644 --- a/src/Lean/Parser/Extension.lean +++ b/src/Lean/Parser/Extension.lean @@ -19,7 +19,7 @@ builtin_initialize builtinTokenTable : IO.Ref TokenTable ← IO.mkRef {} builtin_initialize builtinSyntaxNodeKindSetRef : IO.Ref SyntaxNodeKindSet ← IO.mkRef {} def registerBuiltinNodeKind (k : SyntaxNodeKind) : IO Unit := -builtinSyntaxNodeKindSetRef.modify fun s => s.insert k + builtinSyntaxNodeKindSetRef.modify fun s => s.insert k builtin_initialize registerBuiltinNodeKind choiceKind @@ -32,215 +32,213 @@ builtin_initialize builtin_initialize builtinParserCategoriesRef : IO.Ref ParserCategories ← IO.mkRef {} private def throwParserCategoryAlreadyDefined {α} (catName : Name) : ExceptT String Id α := -throw s!"parser category '{catName}' has already been defined" + throw s!"parser category '{catName}' has already been defined" private def addParserCategoryCore (categories : ParserCategories) (catName : Name) (initial : ParserCategory) : Except String ParserCategories := -if categories.contains catName then - throwParserCategoryAlreadyDefined catName -else - pure $ categories.insert catName initial + if categories.contains catName then + throwParserCategoryAlreadyDefined catName + else + pure $ categories.insert catName initial /-- All builtin parser categories are Pratt's parsers -/ private def addBuiltinParserCategory (catName : Name) (leadingIdentAsSymbol : Bool) : IO Unit := do -let categories ← builtinParserCategoriesRef.get -let categories ← IO.ofExcept $ addParserCategoryCore categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol} -builtinParserCategoriesRef.set categories + let categories ← builtinParserCategoriesRef.get + let categories ← IO.ofExcept $ addParserCategoryCore categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol} + builtinParserCategoriesRef.set categories inductive ParserExtensionOleanEntry -| token (val : Token) : ParserExtensionOleanEntry -| kind (val : SyntaxNodeKind) : ParserExtensionOleanEntry -| category (catName : Name) (leadingIdentAsSymbol : Bool) -| parser (catName : Name) (declName : Name) (prio : Nat) : ParserExtensionOleanEntry + | token (val : Token) : ParserExtensionOleanEntry + | kind (val : SyntaxNodeKind) : ParserExtensionOleanEntry + | category (catName : Name) (leadingIdentAsSymbol : Bool) + | parser (catName : Name) (declName : Name) (prio : Nat) : ParserExtensionOleanEntry inductive ParserExtensionEntry -| token (val : Token) : ParserExtensionEntry -| kind (val : SyntaxNodeKind) : ParserExtensionEntry -| category (catName : Name) (leadingIdentAsSymbol : Bool) -| parser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : ParserExtensionEntry + | token (val : Token) : ParserExtensionEntry + | kind (val : SyntaxNodeKind) : ParserExtensionEntry + | category (catName : Name) (leadingIdentAsSymbol : Bool) + | parser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : ParserExtensionEntry structure ParserExtensionState := -(tokens : TokenTable := {}) -(kinds : SyntaxNodeKindSet := {}) -(categories : ParserCategories := {}) -(newEntries : List ParserExtensionOleanEntry := []) + (tokens : TokenTable := {}) + (kinds : SyntaxNodeKindSet := {}) + (categories : ParserCategories := {}) + (newEntries : List ParserExtensionOleanEntry := []) instance : Inhabited ParserExtensionState := ⟨{}⟩ abbrev ParserExtension := PersistentEnvExtension ParserExtensionOleanEntry ParserExtensionEntry ParserExtensionState private def ParserExtension.mkInitial : IO ParserExtensionState := do -let tokens ← builtinTokenTable.get -let kinds ← builtinSyntaxNodeKindSetRef.get -let categories ← builtinParserCategoriesRef.get -pure { tokens := tokens, kinds := kinds, categories := categories } + let tokens ← builtinTokenTable.get + let kinds ← builtinSyntaxNodeKindSetRef.get + let categories ← builtinParserCategoriesRef.get + pure { tokens := tokens, kinds := kinds, categories := categories } private def addTokenConfig (tokens : TokenTable) (tk : Token) : Except String TokenTable := do -if tk == "" then throw "invalid empty symbol" -else match tokens.find? tk with - | none => pure $ tokens.insert tk tk - | some _ => pure tokens + if tk == "" then throw "invalid empty symbol" + else match tokens.find? tk with + | none => pure $ tokens.insert tk tk + | some _ => pure tokens def throwUnknownParserCategory {α} (catName : Name) : ExceptT String Id α := -throw s!"unknown parser category '{catName}'" + throw s!"unknown parser category '{catName}'" abbrev getCategory (categories : ParserCategories) (catName : Name) : Option ParserCategory := -categories.find? catName + categories.find? catName def addLeadingParser (categories : ParserCategories) (catName : Name) (parserName : Name) (p : Parser) (prio : Nat) : Except String ParserCategories := -match getCategory categories catName with -| none => - throwUnknownParserCategory catName -| some cat => - let addTokens (tks : List Token) : Except String ParserCategories := - let tks := tks.map $ fun tk => mkNameSimple tk - let tables := tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with leadingTable := tables.leadingTable.insert tk (p, prio) }) cat.tables - pure $ categories.insert catName { cat with tables := tables } + match getCategory categories catName with + | none => + throwUnknownParserCategory catName + | some cat => + let addTokens (tks : List Token) : Except String ParserCategories := + let tks := tks.map $ fun tk => mkNameSimple tk + let tables := tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with leadingTable := tables.leadingTable.insert tk (p, prio) }) cat.tables + pure $ categories.insert catName { cat with tables := tables } + match p.info.firstTokens with + | FirstTokens.tokens tks => addTokens tks + | FirstTokens.optTokens tks => addTokens tks + | _ => + let tables := { cat.tables with leadingParsers := (p, prio) :: cat.tables.leadingParsers } + pure $ categories.insert catName { cat with tables := tables } + +private def addTrailingParserAux (tables : PrattParsingTables) (p : TrailingParser) (prio : Nat) : PrattParsingTables := + let addTokens (tks : List Token) : PrattParsingTables := + let tks := tks.map fun tk => mkNameSimple tk + tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with trailingTable := tables.trailingTable.insert tk (p, prio) }) tables match p.info.firstTokens with | FirstTokens.tokens tks => addTokens tks | FirstTokens.optTokens tks => addTokens tks - | _ => - let tables := { cat.tables with leadingParsers := (p, prio) :: cat.tables.leadingParsers } - pure $ categories.insert catName { cat with tables := tables } - -private def addTrailingParserAux (tables : PrattParsingTables) (p : TrailingParser) (prio : Nat) : PrattParsingTables := -let addTokens (tks : List Token) : PrattParsingTables := - let tks := tks.map fun tk => mkNameSimple tk - tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with trailingTable := tables.trailingTable.insert tk (p, prio) }) tables -match p.info.firstTokens with -| FirstTokens.tokens tks => addTokens tks -| FirstTokens.optTokens tks => addTokens tks -| _ => { tables with trailingParsers := (p, prio) :: tables.trailingParsers } + | _ => { tables with trailingParsers := (p, prio) :: tables.trailingParsers } def addTrailingParser (categories : ParserCategories) (catName : Name) (p : TrailingParser) (prio : Nat) : Except String ParserCategories := -match getCategory categories catName with -| none => throwUnknownParserCategory catName -| some cat => pure $ categories.insert catName { cat with tables := addTrailingParserAux cat.tables p prio } + match getCategory categories catName with + | none => throwUnknownParserCategory catName + | some cat => pure $ categories.insert catName { cat with tables := addTrailingParserAux cat.tables p prio } def addParser (categories : ParserCategories) (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : Except String ParserCategories := -match leading, p with -| true, p => addLeadingParser categories catName declName p prio -| false, p => addTrailingParser categories catName p prio + match leading, p with + | true, p => addLeadingParser categories catName declName p prio + | false, p => addTrailingParser categories catName p prio def addParserTokens (tokenTable : TokenTable) (info : ParserInfo) : Except String TokenTable := -let newTokens := info.collectTokens [] -newTokens.foldlM addTokenConfig tokenTable + let newTokens := info.collectTokens [] + newTokens.foldlM addTokenConfig tokenTable private def updateBuiltinTokens (info : ParserInfo) (declName : Name) : IO Unit := do -let tokenTable ← builtinTokenTable.swap {} -match addParserTokens tokenTable info with -| Except.ok tokenTable => builtinTokenTable.set tokenTable -| Except.error msg => throw (IO.userError s!"invalid builtin parser '{declName}', {msg}") + let tokenTable ← builtinTokenTable.swap {} + match addParserTokens tokenTable info with + | Except.ok tokenTable => builtinTokenTable.set tokenTable + | Except.error msg => throw (IO.userError s!"invalid builtin parser '{declName}', {msg}") def addBuiltinParser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : IO Unit := do -let categories ← builtinParserCategoriesRef.get -let categories ← IO.ofExcept $ addParser categories catName declName leading p prio -builtinParserCategoriesRef.set categories -builtinSyntaxNodeKindSetRef.modify p.info.collectKinds -updateBuiltinTokens p.info declName + let categories ← builtinParserCategoriesRef.get + let categories ← IO.ofExcept $ addParser categories catName declName leading p prio + builtinParserCategoriesRef.set categories + builtinSyntaxNodeKindSetRef.modify p.info.collectKinds + updateBuiltinTokens p.info declName def addBuiltinLeadingParser (catName : Name) (declName : Name) (p : Parser) (prio : Nat) : IO Unit := -addBuiltinParser catName declName true p prio + addBuiltinParser catName declName true p prio def addBuiltinTrailingParser (catName : Name) (declName : Name) (p : TrailingParser) (prio : Nat) : IO Unit := -addBuiltinParser catName declName false p prio + addBuiltinParser catName declName false p prio private def ParserExtensionAddEntry (s : ParserExtensionState) (e : ParserExtensionEntry) : ParserExtensionState := -match e with -| ParserExtensionEntry.token tk => - match addTokenConfig s.tokens tk with - | Except.ok tokens => { s with tokens := tokens, newEntries := ParserExtensionOleanEntry.token tk :: s.newEntries } - | _ => unreachable! -| ParserExtensionEntry.kind k => - { s with kinds := s.kinds.insert k, newEntries := ParserExtensionOleanEntry.kind k :: s.newEntries } -| ParserExtensionEntry.category catName leadingIdentAsSymbol => - if s.categories.contains catName then s - else { s with - categories := s.categories.insert catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol }, - newEntries := ParserExtensionOleanEntry.category catName leadingIdentAsSymbol :: s.newEntries } -| ParserExtensionEntry.parser catName declName leading parser prio => - match addParser s.categories catName declName leading parser prio with - | Except.ok categories => { s with categories := categories, newEntries := ParserExtensionOleanEntry.parser catName declName prio :: s.newEntries } - | _ => unreachable! + match e with + | ParserExtensionEntry.token tk => + match addTokenConfig s.tokens tk with + | Except.ok tokens => { s with tokens := tokens, newEntries := ParserExtensionOleanEntry.token tk :: s.newEntries } + | _ => unreachable! + | ParserExtensionEntry.kind k => + { s with kinds := s.kinds.insert k, newEntries := ParserExtensionOleanEntry.kind k :: s.newEntries } + | ParserExtensionEntry.category catName leadingIdentAsSymbol => + if s.categories.contains catName then s + else { s with + categories := s.categories.insert catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol }, + newEntries := ParserExtensionOleanEntry.category catName leadingIdentAsSymbol :: s.newEntries } + | ParserExtensionEntry.parser catName declName leading parser prio => + match addParser s.categories catName declName leading parser prio with + | Except.ok categories => { s with categories := categories, newEntries := ParserExtensionOleanEntry.parser catName declName prio :: s.newEntries } + | _ => unreachable! unsafe def mkParserOfConstantUnsafe (env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name) (compileParserDescr : ParserDescr → Except String Parser) : Except String (Bool × Parser) := -match env.find? constName with -| none => throw s!"unknow constant '{constName}'" -| some info => - match info.type with - | Expr.const `Lean.Parser.TrailingParser _ _ => do - let p ← env.evalConst Parser opts constName - pure ⟨false, p⟩ - | Expr.const `Lean.Parser.Parser _ _ => do - let p ← env.evalConst Parser opts constName - pure ⟨true, p⟩ - | Expr.const `Lean.ParserDescr _ _ => do - let d ← env.evalConst ParserDescr opts constName - let p ← compileParserDescr d - pure ⟨true, p⟩ - | Expr.const `Lean.TrailingParserDescr _ _ => do - let d ← env.evalConst TrailingParserDescr opts constName - let p ← compileParserDescr d - pure ⟨false, p⟩ - | _ => throw s!"unexpected parser type at '{constName}' (`ParserDescr`, `TrailingParserDescr`, `Parser` or `TrailingParser` expected" + match env.find? constName with + | none => throw s!"unknow constant '{constName}'" + | some info => + match info.type with + | Expr.const `Lean.Parser.TrailingParser _ _ => do + let p ← env.evalConst Parser opts constName + pure ⟨false, p⟩ + | Expr.const `Lean.Parser.Parser _ _ => do + let p ← env.evalConst Parser opts constName + pure ⟨true, p⟩ + | Expr.const `Lean.ParserDescr _ _ => do + let d ← env.evalConst ParserDescr opts constName + let p ← compileParserDescr d + pure ⟨true, p⟩ + | Expr.const `Lean.TrailingParserDescr _ _ => do + let d ← env.evalConst TrailingParserDescr opts constName + let p ← compileParserDescr d + pure ⟨false, p⟩ + | _ => throw s!"unexpected parser type at '{constName}' (`ParserDescr`, `TrailingParserDescr`, `Parser` or `TrailingParser` expected" @[implementedBy mkParserOfConstantUnsafe] constant mkParserOfConstantAux (env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name) (compileParserDescr : ParserDescr → Except String Parser) - : Except String (Bool × Parser) := -arbitrary _ + : Except String (Bool × Parser) partial def compileParserDescr (env : Environment) (opts : Options) (categories : ParserCategories) (d : ParserDescr) : Except String Parser := -let rec visit : ParserDescr → Except String Parser - | ParserDescr.andthen d₁ d₂ => andthen <$> visit d₁ <*> visit d₂ - | ParserDescr.orelse d₁ d₂ => orelse <$> visit d₁ <*> visit d₂ - | ParserDescr.optional d => optional <$> visit d - | ParserDescr.lookahead d => lookahead <$> visit d - | ParserDescr.«try» d => «try» <$> visit d - | ParserDescr.notFollowedBy d => do let p ← visit d; pure $ notFollowedBy p "element" -- TODO allow user to set msg at ParserDescr - | ParserDescr.many d => many <$> visit d - | ParserDescr.many1 d => many1 <$> visit d - | ParserDescr.sepBy d₁ d₂ => sepBy <$> visit d₁ <*> visit d₂ - | ParserDescr.sepBy1 d₁ d₂ => sepBy1 <$> visit d₁ <*> visit d₂ - | ParserDescr.node k prec d => leadingNode k prec <$> visit d - | ParserDescr.trailingNode k prec d => trailingNode k prec <$> visit d - | ParserDescr.symbol tk => pure $ symbol tk - | ParserDescr.noWs => pure $ checkNoWsBefore - | ParserDescr.numLit => pure $ numLit - | ParserDescr.strLit => pure $ strLit - | ParserDescr.charLit => pure $ charLit - | ParserDescr.nameLit => pure $ nameLit - | ParserDescr.interpolatedStr d => interpolatedStr <$> visit d - | ParserDescr.ident => pure $ ident - | ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol tk includeIdent - | ParserDescr.parser constName => do - let (_, p) ← mkParserOfConstantAux env opts categories constName visit; - pure p - | ParserDescr.cat catName prec => - match getCategory categories catName with - | some _ => pure $ categoryParser catName prec - | none => throwUnknownParserCategory catName -visit d + let rec visit : ParserDescr → Except String Parser + | ParserDescr.andthen d₁ d₂ => andthen <$> visit d₁ <*> visit d₂ + | ParserDescr.orelse d₁ d₂ => orelse <$> visit d₁ <*> visit d₂ + | ParserDescr.optional d => optional <$> visit d + | ParserDescr.lookahead d => lookahead <$> visit d + | ParserDescr.«try» d => «try» <$> visit d + | ParserDescr.notFollowedBy d => do let p ← visit d; pure $ notFollowedBy p "element" -- TODO allow user to set msg at ParserDescr + | ParserDescr.many d => many <$> visit d + | ParserDescr.many1 d => many1 <$> visit d + | ParserDescr.sepBy d₁ d₂ => sepBy <$> visit d₁ <*> visit d₂ + | ParserDescr.sepBy1 d₁ d₂ => sepBy1 <$> visit d₁ <*> visit d₂ + | ParserDescr.node k prec d => leadingNode k prec <$> visit d + | ParserDescr.trailingNode k prec d => trailingNode k prec <$> visit d + | ParserDescr.symbol tk => pure $ symbol tk + | ParserDescr.noWs => pure $ checkNoWsBefore + | ParserDescr.numLit => pure $ numLit + | ParserDescr.strLit => pure $ strLit + | ParserDescr.charLit => pure $ charLit + | ParserDescr.nameLit => pure $ nameLit + | ParserDescr.interpolatedStr d => interpolatedStr <$> visit d + | ParserDescr.ident => pure $ ident + | ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol tk includeIdent + | ParserDescr.parser constName => do + let (_, p) ← mkParserOfConstantAux env opts categories constName visit; + pure p + | ParserDescr.cat catName prec => + match getCategory categories catName with + | some _ => pure $ categoryParser catName prec + | none => throwUnknownParserCategory catName + visit d def mkParserOfConstant (env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name) : Except String (Bool × Parser) := -mkParserOfConstantAux env opts categories constName (compileParserDescr env opts categories) + mkParserOfConstantAux env opts categories constName (compileParserDescr env opts categories) structure ParserAttributeHook := -/- Called after a parser attribute is applied to a declaration. -/ -(postAdd (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit) + /- Called after a parser attribute is applied to a declaration. -/ + (postAdd (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit) -def mkParserAttributeHooks : IO (IO.Ref (List ParserAttributeHook)) := IO.mkRef {} -@[builtinInit mkParserAttributeHooks] constant parserAttributeHooks : IO.Ref (List ParserAttributeHook) := arbitrary _ +builtin_initialize parserAttributeHooks : IO.Ref (List ParserAttributeHook) ← IO.mkRef {} def registerParserAttributeHook (hook : ParserAttributeHook) : IO Unit := do -parserAttributeHooks.modify fun hooks => hook::hooks + parserAttributeHooks.modify fun hooks => hook::hooks def runParserAttributeHooks (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit := do -let hooks ← parserAttributeHooks.get -hooks.forM fun hook => hook.postAdd catName declName builtin + let hooks ← parserAttributeHooks.get + hooks.forM fun hook => hook.postAdd catName declName builtin builtin_initialize registerBuiltinAttribute { @@ -261,27 +259,23 @@ builtin_initialize } private def ParserExtension.addImported (es : Array (Array ParserExtensionOleanEntry)) : ImportM ParserExtensionState := do -let ctx ← read -let s ← ParserExtension.mkInitial -es.foldlM - (fun s entries => - entries.foldlM - (fun s entry => - match entry with - | ParserExtensionOleanEntry.token tk => do - let tokens ← IO.ofExcept (addTokenConfig s.tokens tk) - pure { s with tokens := tokens } - | ParserExtensionOleanEntry.kind k => - pure { s with kinds := s.kinds.insert k } - | ParserExtensionOleanEntry.category catName leadingIdentAsSymbol => do - let categories ← IO.ofExcept (addParserCategoryCore s.categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol}) - pure { s with categories := categories } - | ParserExtensionOleanEntry.parser catName declName prio => do - let p ← IO.ofExcept $ mkParserOfConstant ctx.env ctx.opts s.categories declName - let categories ← IO.ofExcept $ addParser s.categories catName declName p.1 p.2 prio - pure { s with categories := categories }) - s) - s + let ctx ← read + let s ← ParserExtension.mkInitial + es.foldlM (init := s) fun s entries => + entries.foldlM (init := s) fun s entry => + match entry with + | ParserExtensionOleanEntry.token tk => do + let tokens ← IO.ofExcept (addTokenConfig s.tokens tk) + pure { s with tokens := tokens } + | ParserExtensionOleanEntry.kind k => + pure { s with kinds := s.kinds.insert k } + | ParserExtensionOleanEntry.category catName leadingIdentAsSymbol => do + let categories ← IO.ofExcept (addParserCategoryCore s.categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol}) + pure { s with categories := categories } + | ParserExtensionOleanEntry.parser catName declName prio => do + let p ← IO.ofExcept $ mkParserOfConstant ctx.env ctx.opts s.categories declName + let categories ← IO.ofExcept $ addParser s.categories catName declName p.1 p.2 prio + pure { s with categories := categories } builtin_initialize parserExtension : ParserExtension ← registerPersistentEnvExtension { @@ -294,31 +288,30 @@ builtin_initialize parserExtension : ParserExtension ← } def isParserCategory (env : Environment) (catName : Name) : Bool := -(parserExtension.getState env).categories.contains catName + (parserExtension.getState env).categories.contains catName def addParserCategory (env : Environment) (catName : Name) (leadingIdentAsSymbol : Bool) : Except String Environment := do -if isParserCategory env catName then - throwParserCategoryAlreadyDefined catName -else - pure $ parserExtension.addEntry env $ ParserExtensionEntry.category catName leadingIdentAsSymbol + if isParserCategory env catName then + throwParserCategoryAlreadyDefined catName + else + pure $ parserExtension.addEntry env $ ParserExtensionEntry.category catName leadingIdentAsSymbol /- Return true if in the given category leading identifiers in parsers may be treated as atoms/symbols. See comment at `ParserCategory`. -/ def leadingIdentAsSymbol (env : Environment) (catName : Name) : Bool := -match getCategory (parserExtension.getState env).categories catName with -| none => false -| some cat => cat.leadingIdentAsSymbol + match getCategory (parserExtension.getState env).categories catName with + | none => false + | some cat => cat.leadingIdentAsSymbol def mkCategoryAntiquotParser (kind : Name) : Parser := -mkAntiquot kind.toString none + mkAntiquot kind.toString none -- helper decl to work around inlining issue https://github.com/leanprover/lean4/commit/3f6de2af06dd9a25f62294129f64bc05a29ea912#r41340377 @[inline] private def mkCategoryAntiquotParserFn (kind : Name) : ParserFn := (mkCategoryAntiquotParser kind).fn -def categoryParserFnImpl (catName : Name) : ParserFn := -fun ctx s => +def categoryParserFnImpl (catName : Name) : ParserFn := fun ctx s => let catName := if catName == `syntax then `stx else catName -- temporary Hack let categories := (parserExtension.getState ctx.env).categories match getCategory categories catName with @@ -327,149 +320,152 @@ fun ctx s => | none => s.mkUnexpectedError ("unknown parser category '" ++ toString catName ++ "'") @[builtinInit] def setCategoryParserFnRef : IO Unit := -categoryParserFnRef.set categoryParserFnImpl + categoryParserFnRef.set categoryParserFnImpl def addToken (env : Environment) (tk : Token) : Except String Environment := do --- Recall that `ParserExtension.addEntry` is pure, and assumes `addTokenConfig` does not fail. --- So, we must run it here to handle exception. -addTokenConfig (parserExtension.getState env).tokens tk -pure $ parserExtension.addEntry env $ ParserExtensionEntry.token tk + -- Recall that `ParserExtension.addEntry` is pure, and assumes `addTokenConfig` does not fail. + -- So, we must run it here to handle exception. + addTokenConfig (parserExtension.getState env).tokens tk + pure $ parserExtension.addEntry env $ ParserExtensionEntry.token tk def addSyntaxNodeKind (env : Environment) (k : SyntaxNodeKind) : Environment := -parserExtension.addEntry env $ ParserExtensionEntry.kind k + parserExtension.addEntry env $ ParserExtensionEntry.kind k def isValidSyntaxNodeKind (env : Environment) (k : SyntaxNodeKind) : Bool := -let kinds := (parserExtension.getState env).kinds -kinds.contains k + let kinds := (parserExtension.getState env).kinds + kinds.contains k def getSyntaxNodeKinds (env : Environment) : List SyntaxNodeKind := do -let kinds := (parserExtension.getState env).kinds -kinds.foldl (fun ks k _ => k::ks) [] + let kinds := (parserExtension.getState env).kinds + kinds.foldl (fun ks k _ => k::ks) [] def getTokenTable (env : Environment) : TokenTable := -(parserExtension.getState env).tokens + (parserExtension.getState env).tokens -def mkInputContext (input : String) (fileName : String) : InputContext := -{ input := input, +def mkInputContext (input : String) (fileName : String) : InputContext := { + input := input, fileName := fileName, - fileMap := input.toFileMap } + fileMap := input.toFileMap +} -def mkParserContext (env : Environment) (ctx : InputContext) : ParserContext := -{ prec := 0, +def mkParserContext (env : Environment) (ctx : InputContext) : ParserContext := { + prec := 0, toInputContext := ctx, env := env, - tokens := getTokenTable env } + tokens := getTokenTable env +} def mkParserState (input : String) : ParserState := -{ cache := initCacheForInput input } + { cache := initCacheForInput input } /- convenience function for testing -/ def runParserCategory (env : Environment) (catName : Name) (input : String) (fileName := "") : Except String Syntax := -let c := mkParserContext env (mkInputContext input fileName) -let s := mkParserState input -let s := whitespace c s -let s := categoryParserFnImpl catName c s -if s.hasError then - Except.error (s.toErrorMsg c) -else if input.atEnd s.pos then - Except.ok s.stxStack.back -else - Except.error ((s.mkError "end of input").toErrorMsg c) + let c := mkParserContext env (mkInputContext input fileName) + let s := mkParserState input + let s := whitespace c s + let s := categoryParserFnImpl catName c s + if s.hasError then + Except.error (s.toErrorMsg c) + else if input.atEnd s.pos then + Except.ok s.stxStack.back + else + Except.error ((s.mkError "end of input").toErrorMsg c) def declareBuiltinParser (env : Environment) (addFnName : Name) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -let name := `_regBuiltinParser ++ declName -let type := mkApp (mkConst `IO) (mkConst `Unit) -let val := mkAppN (mkConst addFnName) #[toExpr catName, toExpr declName, mkConst declName, mkNatLit prio] -let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false } -match env.addAndCompile {} decl with --- TODO: pretty print error -| Except.error _ => throw (IO.userError ("failed to emit registration code for builtin parser '" ++ toString declName ++ "'")) -| Except.ok env => IO.ofExcept (setBuiltinInitAttr env name) + let name := `_regBuiltinParser ++ declName + let type := mkApp (mkConst `IO) (mkConst `Unit) + let val := mkAppN (mkConst addFnName) #[toExpr catName, toExpr declName, mkConst declName, mkNatLit prio] + let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false } + match env.addAndCompile {} decl with + -- TODO: pretty print error + | Except.error _ => throw (IO.userError ("failed to emit registration code for builtin parser '" ++ toString declName ++ "'")) + | Except.ok env => IO.ofExcept (setBuiltinInitAttr env name) def declareLeadingBuiltinParser (env : Environment) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -- TODO: use CoreM? -declareBuiltinParser env `Lean.Parser.addBuiltinLeadingParser catName declName prio + declareBuiltinParser env `Lean.Parser.addBuiltinLeadingParser catName declName prio def declareTrailingBuiltinParser (env : Environment) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -- TODO: use CoreM? -declareBuiltinParser env `Lean.Parser.addBuiltinTrailingParser catName declName prio + declareBuiltinParser env `Lean.Parser.addBuiltinTrailingParser catName declName prio def getParserPriority (args : Syntax) : Except String Nat := -match args.getNumArgs with -| 0 => pure 0 -| 1 => match (args.getArg 0).isNatLit? with - | some prio => pure prio - | none => throw "invalid parser attribute, numeral expected" -| _ => throw "invalid parser attribute, no argument or numeral expected" + match args.getNumArgs with + | 0 => pure 0 + | 1 => match (args.getArg 0).isNatLit? with + | some prio => pure prio + | none => throw "invalid parser attribute, numeral expected" + | _ => throw "invalid parser attribute, no argument or numeral expected" private def BuiltinParserAttribute.add (attrName : Name) (catName : Name) (declName : Name) (args : Syntax) (persistent : Bool) : AttrM Unit := do -let prio ← ofExcept (getParserPriority args) -unless persistent do throwError! "invalid attribute '{attrName}', must be persistent" -let decl ← getConstInfo declName -let env ← getEnv -match decl.type with -| Expr.const `Lean.Parser.TrailingParser _ _ => do - let env ← liftIO $ declareTrailingBuiltinParser env catName declName prio - setEnv env -| Expr.const `Lean.Parser.Parser _ _ => do - let env ← liftIO $ declareLeadingBuiltinParser env catName declName prio - setEnv env -| _ => throwError! "unexpected parser type at '{declName}' (`Parser` or `TrailingParser` expected)" -runParserAttributeHooks catName declName (builtin := true) + let prio ← ofExcept (getParserPriority args) + unless persistent do throwError! "invalid attribute '{attrName}', must be persistent" + let decl ← getConstInfo declName + let env ← getEnv + match decl.type with + | Expr.const `Lean.Parser.TrailingParser _ _ => do + let env ← liftIO $ declareTrailingBuiltinParser env catName declName prio + setEnv env + | Expr.const `Lean.Parser.Parser _ _ => do + let env ← liftIO $ declareLeadingBuiltinParser env catName declName prio + setEnv env + | _ => throwError! "unexpected parser type at '{declName}' (`Parser` or `TrailingParser` expected)" + runParserAttributeHooks catName declName (builtin := true) /- The parsing tables for builtin parsers are "stored" in the extracted source code. -/ def registerBuiltinParserAttribute (attrName : Name) (catName : Name) (leadingIdentAsSymbol := false) : IO Unit := do -addBuiltinParserCategory catName leadingIdentAsSymbol -registerBuiltinAttribute { - name := attrName, - descr := "Builtin parser", - add := fun declName args persistent => liftM $ BuiltinParserAttribute.add attrName catName declName args persistent, - applicationTime := AttributeApplicationTime.afterCompilation -} + addBuiltinParserCategory catName leadingIdentAsSymbol + registerBuiltinAttribute { + name := attrName, + descr := "Builtin parser", + add := fun declName args persistent => liftM $ BuiltinParserAttribute.add attrName catName declName args persistent, + applicationTime := AttributeApplicationTime.afterCompilation + } private def ParserAttribute.add (attrName : Name) (catName : Name) (declName : Name) (args : Syntax) (persistent : Bool) : AttrM Unit := do -let prio ← ofExcept (getParserPriority args) -let env ← getEnv -let opts ← getOptions -let categories := (parserExtension.getState env).categories -match mkParserOfConstant env opts categories declName with -| Except.error ex => throwError ex -| Except.ok p => do - let leading := p.1 - let parser := p.2 - let tokens := parser.info.collectTokens [] - tokens.forM fun token => do - env ← getEnv - match addToken env token with - | Except.ok env => setEnv env - | Except.error msg => throwError! "invalid parser '{declName}', {msg}" - let kinds := parser.info.collectKinds {} - kinds.forM fun kind _ => modifyEnv fun env => addSyntaxNodeKind env kind - match addParser categories catName declName leading parser prio with - | Except.ok _ => modifyEnv fun env => parserExtension.addEntry env $ ParserExtensionEntry.parser catName declName leading parser prio + let prio ← ofExcept (getParserPriority args) + let env ← getEnv + let opts ← getOptions + let categories := (parserExtension.getState env).categories + match mkParserOfConstant env opts categories declName with | Except.error ex => throwError ex - runParserAttributeHooks catName declName /- builtin -/ false + | Except.ok p => do + let leading := p.1 + let parser := p.2 + let tokens := parser.info.collectTokens [] + tokens.forM fun token => do + env ← getEnv + match addToken env token with + | Except.ok env => setEnv env + | Except.error msg => throwError! "invalid parser '{declName}', {msg}" + let kinds := parser.info.collectKinds {} + kinds.forM fun kind _ => modifyEnv fun env => addSyntaxNodeKind env kind + match addParser categories catName declName leading parser prio with + | Except.ok _ => modifyEnv fun env => parserExtension.addEntry env $ ParserExtensionEntry.parser catName declName leading parser prio + | Except.error ex => throwError ex + runParserAttributeHooks catName declName /- builtin -/ false -def mkParserAttributeImpl (attrName : Name) (catName : Name) : AttributeImpl := -{ name := attrName, +def mkParserAttributeImpl (attrName : Name) (catName : Name) : AttributeImpl := { + name := attrName, descr := "parser", add := fun declName args persistent => liftM $ ParserAttribute.add attrName catName declName args persistent, - applicationTime := AttributeApplicationTime.afterCompilation } + applicationTime := AttributeApplicationTime.afterCompilation +} /- A builtin parser attribute that can be extended by users. -/ def registerBuiltinDynamicParserAttribute (attrName : Name) (catName : Name) : IO Unit := do -registerBuiltinAttribute (mkParserAttributeImpl attrName catName) + registerBuiltinAttribute (mkParserAttributeImpl attrName catName) @[builtinInit] private def registerParserAttributeImplBuilder : IO Unit := -registerAttributeImplBuilder `parserAttr fun args => - match args with - | [DataValue.ofName attrName, DataValue.ofName catName] => pure $ mkParserAttributeImpl attrName catName - | _ => throw "invalid parser attribute implementation builder arguments" + registerAttributeImplBuilder `parserAttr fun args => + match args with + | [DataValue.ofName attrName, DataValue.ofName catName] => pure $ mkParserAttributeImpl attrName catName + | _ => throw "invalid parser attribute implementation builder arguments" def registerParserCategory (env : Environment) (attrName : Name) (catName : Name) (leadingIdentAsSymbol := false) : IO Environment := do -env ← IO.ofExcept $ addParserCategory env catName leadingIdentAsSymbol -registerAttributeOfBuilder env `parserAttr [DataValue.ofName attrName, DataValue.ofName catName] + let env ← IO.ofExcept $ addParserCategory env catName leadingIdentAsSymbol + registerAttributeOfBuilder env `parserAttr [DataValue.ofName attrName, DataValue.ofName catName] -- declare `termParser` here since it is used everywhere via antiquotations @@ -483,10 +479,9 @@ builtin_initialize registerBuiltinParserAttribute `builtinCommandParser `command builtin_initialize registerBuiltinDynamicParserAttribute `commandParser `command @[inline] def commandParser (rbp : Nat := 0) : Parser := -categoryParser `command rbp + categoryParser `command rbp -def notFollowedByCategoryTokenFn (catName : Name) : ParserFn := -fun ctx s => +def notFollowedByCategoryTokenFn (catName : Name) : ParserFn := fun ctx s => let categories := (parserExtension.getState ctx.env).categories match getCategory categories catName with | none => s.mkUnexpectedError s!"unknown parser category '{catName}'" @@ -500,14 +495,15 @@ fun ctx s => | _ => s | _ => s -@[inline] def notFollowedByCategoryToken (catName : Name) : Parser := -{ fn := notFollowedByCategoryTokenFn catName } +@[inline] def notFollowedByCategoryToken (catName : Name) : Parser := { + fn := notFollowedByCategoryTokenFn catName +} abbrev notFollowedByCommandToken : Parser := -notFollowedByCategoryToken `command + notFollowedByCategoryToken `command abbrev notFollowedByTermToken : Parser := -notFollowedByCategoryToken `term + notFollowedByCategoryToken `term end Parser end Lean diff --git a/src/Lean/PrettyPrinter/Meta.lean b/src/Lean/PrettyPrinter/Meta.lean index 5c6808f7d4..cd495e988a 100644 --- a/src/Lean/PrettyPrinter/Meta.lean +++ b/src/Lean/PrettyPrinter/Meta.lean @@ -23,67 +23,67 @@ def ctx (interp) : ParserCompiler.Context Parenthesizer := ⟨`parenthesizer, parenthesizerAttribute, combinatorParenthesizerAttribute, interp⟩ unsafe def interpretParserDescr : ParserDescr → AttrM Parenthesizer -| ParserDescr.andthen d₁ d₂ => andthen.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.orelse d₁ d₂ => orelse.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.optional d => optional.parenthesizer <$> interpretParserDescr d -| ParserDescr.lookahead d => lookahead.parenthesizer <$> interpretParserDescr d -| ParserDescr.try d => try.parenthesizer <$> interpretParserDescr d -| ParserDescr.notFollowedBy d => notFollowedBy.parenthesizer <$> interpretParserDescr d -| ParserDescr.many d => many.parenthesizer <$> interpretParserDescr d -| ParserDescr.many1 d => many1.parenthesizer <$> interpretParserDescr d -| ParserDescr.sepBy d₁ d₂ => sepBy.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.sepBy1 d₁ d₂ => sepBy1.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.node k prec d => leadingNode.parenthesizer k prec <$> interpretParserDescr d -| ParserDescr.trailingNode k prec d => trailingNode.parenthesizer k prec <$> interpretParserDescr d -| ParserDescr.symbol tk => pure $ symbol.parenthesizer tk -| ParserDescr.numLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer -| ParserDescr.strLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer -| ParserDescr.charLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer -| ParserDescr.nameLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer -| ParserDescr.ident => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "ident" `ident) identNoAntiquot.parenthesizer -| ParserDescr.interpolatedStr d => interpolatedStr.parenthesizer <$> interpretParserDescr d -| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.parenthesizer tk includeIdent -| ParserDescr.noWs => pure $ checkNoWsBefore.parenthesizer -| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName -| ParserDescr.cat catName prec => pure $ categoryParser.parenthesizer catName prec + | ParserDescr.andthen d₁ d₂ => andthen.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.orelse d₁ d₂ => orelse.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.optional d => optional.parenthesizer <$> interpretParserDescr d + | ParserDescr.lookahead d => lookahead.parenthesizer <$> interpretParserDescr d + | ParserDescr.try d => try.parenthesizer <$> interpretParserDescr d + | ParserDescr.notFollowedBy d => notFollowedBy.parenthesizer <$> interpretParserDescr d + | ParserDescr.many d => many.parenthesizer <$> interpretParserDescr d + | ParserDescr.many1 d => many1.parenthesizer <$> interpretParserDescr d + | ParserDescr.sepBy d₁ d₂ => sepBy.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.sepBy1 d₁ d₂ => sepBy1.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.node k prec d => leadingNode.parenthesizer k prec <$> interpretParserDescr d + | ParserDescr.trailingNode k prec d => trailingNode.parenthesizer k prec <$> interpretParserDescr d + | ParserDescr.symbol tk => pure $ symbol.parenthesizer tk + | ParserDescr.numLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer + | ParserDescr.strLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer + | ParserDescr.charLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer + | ParserDescr.nameLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer + | ParserDescr.ident => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "ident" `ident) identNoAntiquot.parenthesizer + | ParserDescr.interpolatedStr d => interpolatedStr.parenthesizer <$> interpretParserDescr d + | ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.parenthesizer tk includeIdent + | ParserDescr.noWs => pure $ checkNoWsBefore.parenthesizer + | ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName + | ParserDescr.cat catName prec => pure $ categoryParser.parenthesizer catName prec @[builtinInit] unsafe def regHook : IO Unit := -ParserCompiler.registerParserCompiler (ctx interpretParserDescr) + ParserCompiler.registerParserCompiler (ctx interpretParserDescr) end Parenthesizer namespace Formatter def ctx (interp) : ParserCompiler.Context Formatter := -⟨`formatter, formatterAttribute, combinatorFormatterAttribute, interp⟩ + ⟨`formatter, formatterAttribute, combinatorFormatterAttribute, interp⟩ unsafe def interpretParserDescr : ParserDescr → AttrM Formatter -| ParserDescr.andthen d₁ d₂ => andthen.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.orelse d₁ d₂ => orelse.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.optional d => optional.formatter <$> interpretParserDescr d -| ParserDescr.lookahead d => lookahead.formatter <$> interpretParserDescr d -| ParserDescr.try d => try.formatter <$> interpretParserDescr d -| ParserDescr.notFollowedBy d => notFollowedBy.formatter <$> interpretParserDescr d -| ParserDescr.many d => many.formatter <$> interpretParserDescr d -| ParserDescr.many1 d => many1.formatter <$> interpretParserDescr d -| ParserDescr.sepBy d₁ d₂ => sepBy.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.sepBy1 d₁ d₂ => sepBy1.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ -| ParserDescr.node k prec d => node.formatter k <$> interpretParserDescr d -| ParserDescr.trailingNode k prec d => trailingNode.formatter k prec <$> interpretParserDescr d -| ParserDescr.symbol tk => pure $ symbol.formatter tk -| ParserDescr.numLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter -| ParserDescr.strLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter -| ParserDescr.charLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter -| ParserDescr.nameLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter -| ParserDescr.interpolatedStr d => interpolatedStr.formatter <$> interpretParserDescr d -| ParserDescr.ident => pure $ withAntiquot.formatter (mkAntiquot.formatter' "ident" `ident) identNoAntiquot.formatter -| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.formatter tk -| ParserDescr.noWs => pure $ checkNoWsBefore.formatter -| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName -| ParserDescr.cat catName prec => pure $ categoryParser.formatter catName + | ParserDescr.andthen d₁ d₂ => andthen.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.orelse d₁ d₂ => orelse.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.optional d => optional.formatter <$> interpretParserDescr d + | ParserDescr.lookahead d => lookahead.formatter <$> interpretParserDescr d + | ParserDescr.try d => try.formatter <$> interpretParserDescr d + | ParserDescr.notFollowedBy d => notFollowedBy.formatter <$> interpretParserDescr d + | ParserDescr.many d => many.formatter <$> interpretParserDescr d + | ParserDescr.many1 d => many1.formatter <$> interpretParserDescr d + | ParserDescr.sepBy d₁ d₂ => sepBy.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.sepBy1 d₁ d₂ => sepBy1.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂ + | ParserDescr.node k prec d => node.formatter k <$> interpretParserDescr d + | ParserDescr.trailingNode k prec d => trailingNode.formatter k prec <$> interpretParserDescr d + | ParserDescr.symbol tk => pure $ symbol.formatter tk + | ParserDescr.numLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter + | ParserDescr.strLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter + | ParserDescr.charLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter + | ParserDescr.nameLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter + | ParserDescr.interpolatedStr d => interpolatedStr.formatter <$> interpretParserDescr d + | ParserDescr.ident => pure $ withAntiquot.formatter (mkAntiquot.formatter' "ident" `ident) identNoAntiquot.formatter + | ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.formatter tk + | ParserDescr.noWs => pure $ checkNoWsBefore.formatter + | ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName + | ParserDescr.cat catName prec => pure $ categoryParser.formatter catName @[builtinInit] unsafe def regHook : IO Unit := -ParserCompiler.registerParserCompiler (ctx interpretParserDescr) + ParserCompiler.registerParserCompiler (ctx interpretParserDescr) end Formatter diff --git a/src/Lean/Util/ReplaceLevel.lean b/src/Lean/Util/ReplaceLevel.lean index 67c7805d1d..753f9f49c7 100644 --- a/src/Lean/Util/ReplaceLevel.lean +++ b/src/Lean/Util/ReplaceLevel.lean @@ -66,15 +66,15 @@ end ReplaceLevelImpl @[implementedBy ReplaceLevelImpl.replaceUnsafe] partial def replaceLevel (f? : Level → Option Level) : Expr → Expr -| e@(Expr.forallE _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateForallE! d b -| e@(Expr.lam _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateLambdaE! d b -| e@(Expr.mdata _ b _) => let b := replaceLevel f? b; e.updateMData! b -| e@(Expr.letE _ t v b _) => let t := replaceLevel f? t; let v := replaceLevel f? v; let b := replaceLevel f? b; e.updateLet! t v b -| e@(Expr.app f a _) => let f := replaceLevel f? f; let a := replaceLevel f? a; e.updateApp! f a -| e@(Expr.proj _ _ b _) => let b := replaceLevel f? b; e.updateProj! b -| e@(Expr.sort u _) => e.updateSort! (u.replace f?) -| e@(Expr.const n us _) => e.updateConst! (us.map (Level.replace f?)) -| e => e + | e@(Expr.forallE _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateForallE! d b + | e@(Expr.lam _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateLambdaE! d b + | e@(Expr.mdata _ b _) => let b := replaceLevel f? b; e.updateMData! b + | e@(Expr.letE _ t v b _) => let t := replaceLevel f? t; let v := replaceLevel f? v; let b := replaceLevel f? b; e.updateLet! t v b + | e@(Expr.app f a _) => let f := replaceLevel f? f; let a := replaceLevel f? a; e.updateApp! f a + | e@(Expr.proj _ _ b _) => let b := replaceLevel f? b; e.updateProj! b + | e@(Expr.sort u _) => e.updateSort! (u.replace f?) + | e@(Expr.const n us _) => e.updateConst! (us.map (Level.replace f?)) + | e => e end Expr end Lean