From e4b358a01e8b2af35284dcf2a07db8ef2e4c8444 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 9 Jul 2022 15:24:22 -0700 Subject: [PATCH] refactor: prepare to elaborate `a[i]` notation using typeclasses --- src/Init/Data/Array/Basic.lean | 40 ++++++++------- src/Init/Data/Array/DecidableEq.lean | 2 +- src/Init/Data/Array/InsertionSort.lean | 2 +- src/Init/Data/Array/Subarray.lean | 5 +- src/Init/Data/Fin/Basic.lean | 9 +++- src/Init/Meta.lean | 9 ++-- src/Init/Notation.lean | 2 +- src/Init/Prelude.lean | 11 +++++ src/Init/Tactics.lean | 12 +++++ src/Init/Util.lean | 9 ++++ src/Lean/Attributes.lean | 4 +- src/Lean/Compiler/IR/ElimDeadBranches.lean | 10 ++-- src/Lean/Elab/Do.lean | 2 +- src/Lean/Elab/InfoTree.lean | 3 +- src/Lean/Elab/StructInst.lean | 2 +- src/Lean/Environment.lean | 1 + src/Lean/LocalContext.lean | 6 +-- src/Lean/MonadEnv.lean | 2 +- .../Delaborator/TopDownAnalyze.lean | 2 +- .../Server/FileWorker/RequestHandling.lean | 49 ++++++++++--------- src/Lean/Server/Snapshots.lean | 2 +- src/Lean/Util/Trace.lean | 2 +- src/Std/Data/HashMap.lean | 3 ++ src/Std/Data/PersistentArray.lean | 8 ++- src/Std/Data/PersistentHashMap.lean | 26 +++++----- 25 files changed, 144 insertions(+), 79 deletions(-) diff --git a/src/Init/Data/Array/Basic.lean b/src/Init/Data/Array/Basic.lean index 7acc95079d..7f8bc0b407 100644 --- a/src/Init/Data/Array/Basic.lean +++ b/src/Init/Data/Array/Basic.lean @@ -39,13 +39,13 @@ def singleton (v : α) : Array α := `fget` may be slightly slower than `uget`. -/ @[extern "lean_array_uget"] def uget (a : @& Array α) (i : USize) (h : i.toNat < a.size) : α := - a[⟨i.toNat, h⟩] + a[i.toNat] def back [Inhabited α] (a : Array α) : α := a.get! (a.size - 1) def get? (a : Array α) (i : Nat) : Option α := - if h : i < a.size then some a[⟨i, h⟩] else none + if h : i < a.size then some a[i] else none abbrev getOp? (self : Array α) (idx : Nat) : Option α := self.get? idx @@ -55,7 +55,8 @@ def back? (a : Array α) : Option α := -- auxiliary declaration used in the equation compiler when pattern matching array literals. abbrev getLit {α : Type u} {n : Nat} (a : Array α) (i : Nat) (h₁ : a.size = n) (h₂ : i < n) : α := - a[⟨i, h₁.symm ▸ h₂⟩] + have := h₁.symm ▸ h₂ + a[i] @[simp] theorem size_set (a : Array α) (i : Fin a.size) (v : α) : (set a i v).size = a.size := List.length_set .. @@ -166,7 +167,7 @@ protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide) have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this - match (← f as[⟨as.size - 1 - i, this⟩] b) with + match (← f as[as.size - 1 - i] b) with | ForInStep.done b => pure b | ForInStep.yield b => loop i (Nat.le_of_lt h') b loop as.size (Nat.le_refl _) b @@ -199,7 +200,8 @@ def foldlM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β match i with | 0 => pure b | i'+1 => - loop i' (j+1) (← f b as[⟨j, Nat.lt_of_lt_of_le hlt h⟩]) + have : j < as.size := Nat.lt_of_lt_of_le hlt h + loop i' (j+1) (← f b as[j]) else pure b loop (stop - start) start init @@ -236,7 +238,7 @@ def foldrM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α | 0, _ => pure b | i+1, h => have : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self _) h - fold i (Nat.le_of_lt this) (← f as[⟨i, this⟩] b) + fold i (Nat.le_of_lt this) (← f as[i] b) if h : start ≤ as.size then if stop < start then fold start h init @@ -330,7 +332,8 @@ def anyM {α : Type u} {m : Type → Type w} [Monad m] (p : α → m Bool) (as : let any (stop : Nat) (h : stop ≤ as.size) := let rec loop (j : Nat) : m Bool := do if hlt : j < stop then - if (← p as[⟨j, Nat.lt_of_lt_of_le hlt h⟩]) then + have : j < as.size := Nat.lt_of_lt_of_le hlt h + if (← p as[j]) then pure true else loop (j+1) @@ -353,7 +356,7 @@ def findSomeRevM? {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] | 0, _ => pure none | i+1, h => do have : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self _) h - let r ← f as[⟨i, this⟩] + let r ← f as[i] match r with | some _ => pure r | none => @@ -422,7 +425,7 @@ def findIdx? {α : Type u} (as : Array α) (p : α → Bool) : Option Nat := rw [inv] at hlt exact absurd hlt (Nat.lt_irrefl _) | i+1, inv => - if p as[⟨j, hlt⟩] then + if p as[j] then some j else have : i + (j+1) = as.size := by @@ -515,7 +518,8 @@ namespace Array @[specialize] def isEqvAux (a b : Array α) (hsz : a.size = b.size) (p : α → α → Bool) (i : Nat) : Bool := if h : i < a.size then - p a[⟨i, h⟩] b[⟨i, hsz ▸ h⟩] && isEqvAux a b hsz p (i+1) + have : i < b.size := hsz ▸ h + p a[i] b[i] && isEqvAux a b hsz p (i+1) else true termination_by _ => a.size - i @@ -553,7 +557,7 @@ def filterMap (f : α → Option β) (as : Array α) (start := 0) (stop := as.si @[specialize] def getMax? (as : Array α) (lt : α → α → Bool) : Option α := if h : 0 < as.size then - let a0 := as[⟨0, h⟩] + let a0 := as[0] some <| as.foldl (init := a0) (start := 1) fun best a => if lt best a then a else best else @@ -572,7 +576,7 @@ def partition (p : α → Bool) (as : Array α) : Array α × Array α := Id.run theorem ext (a b : Array α) (h₁ : a.size = b.size) - (h₂ : (i : Nat) → (hi₁ : i < a.size) → (hi₂ : i < b.size) → a[⟨i, hi₁⟩] = b[⟨i, hi₂⟩]) + (h₂ : (i : Nat) → (hi₁ : i < a.size) → (hi₂ : i < b.size) → a[i] = b[i]) : a = b := by let rec extAux (a b : List α) (h₁ : a.length = b.length) @@ -758,8 +762,8 @@ theorem toArrayLit_eq (a : Array α) (n : Nat) (hsz : a.size = n) : a = toArrayL def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size ≤ bs.size) (i : Nat) : Bool := if h : i < as.size then - let a := as[⟨i, h⟩] - let b := bs[⟨i, Nat.lt_of_lt_of_le h hle⟩] + let a := as[i] + let b := bs[i]'(Nat.lt_of_lt_of_le h hle) if a == b then isPrefixOfAux as bs hle (i+1) else @@ -780,11 +784,11 @@ private def allDiffAuxAux [BEq α] (as : Array α) (a : α) : forall (i : Nat), | 0, _ => true | i+1, h => have : i < as.size := Nat.lt_trans (Nat.lt_succ_self _) h; - a != as[⟨i, this⟩] && allDiffAuxAux as a i this + a != as[i] && allDiffAuxAux as a i this private def allDiffAux [BEq α] (as : Array α) (i : Nat) : Bool := if h : i < as.size then - allDiffAuxAux as as[⟨i, h⟩] i h && allDiffAux as (i+1) + allDiffAuxAux as as[i] i h && allDiffAux as (i+1) else true termination_by _ => as.size - i @@ -794,9 +798,9 @@ def allDiff [BEq α] (as : Array α) : Bool := @[specialize] def zipWithAux (f : α → β → γ) (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) : Array γ := if h : i < as.size then - let a := as[⟨i, h⟩] + let a := as[i] if h : i < bs.size then - let b := bs[⟨i, h⟩] + let b := bs[i] zipWithAux f as bs (i+1) <| cs.push <| f a b else cs diff --git a/src/Init/Data/Array/DecidableEq.lean b/src/Init/Data/Array/DecidableEq.lean index cc40c9bc3d..fb6c1f50bc 100644 --- a/src/Init/Data/Array/DecidableEq.lean +++ b/src/Init/Data/Array/DecidableEq.lean @@ -9,7 +9,7 @@ import Init.Classical namespace Array -theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i ≤ a.size) (heqv : Array.isEqvAux a b hsz (fun x y => x = y) i) (j : Nat) (low : i ≤ j) (high : j < a.size) : a[⟨j, high⟩] = b[⟨j, hsz ▸ high⟩] := by +theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i ≤ a.size) (heqv : Array.isEqvAux a b hsz (fun x y => x = y) i) (j : Nat) (low : i ≤ j) (high : j < a.size) : a[j] = b[j]'(hsz ▸ high) := by by_cases h : i < a.size · unfold Array.isEqvAux at heqv simp [h] at heqv diff --git a/src/Init/Data/Array/InsertionSort.lean b/src/Init/Data/Array/InsertionSort.lean index 7194b773af..dd92e2aead 100644 --- a/src/Init/Data/Array/InsertionSort.lean +++ b/src/Init/Data/Array/InsertionSort.lean @@ -22,7 +22,7 @@ where | 0 => a | j'+1 => have h' : j' < a.size := by subst j; exact Nat.lt_trans (Nat.lt_succ_self _) h - if lt a[⟨j, h⟩] a[⟨j', h'⟩] then + if lt a[j] a[j'] then swapLoop (a.swap ⟨j, h⟩ ⟨j', h'⟩) j' (by rw [size_swap]; assumption done) else a diff --git a/src/Init/Data/Array/Subarray.lean b/src/Init/Data/Array/Subarray.lean index 9f178fcb8e..32de1a97f1 100644 --- a/src/Init/Data/Array/Subarray.lean +++ b/src/Init/Data/Array/Subarray.lean @@ -27,11 +27,14 @@ def get (s : Subarray α) (i : Fin s.size) : α := simp [size] at this rw [Nat.add_comm] exact Nat.add_lt_of_lt_sub this - s.as[⟨s.start + i.val, this⟩] + s.as[s.start + i.val] abbrev getOp (self : Subarray α) (idx : Fin self.size) : α := self.get idx +instance : GetElem (Subarray α) Nat α fun xs i => i < xs.size where + getElem xs i h := xs.get ⟨i, h⟩ + @[inline] def getD (s : Subarray α) (i : Nat) (v₀ : α) : α := if h : i < s.size then s.get ⟨i, h⟩ else v₀ diff --git a/src/Init/Data/Fin/Basic.lean b/src/Init/Data/Fin/Basic.lean index 8330ce80fe..b65ac8548c 100644 --- a/src/Init/Data/Fin/Basic.lean +++ b/src/Init/Data/Fin/Basic.lean @@ -112,6 +112,13 @@ theorem val_ne_of_ne {i j : Fin n} (h : i ≠ j) : val i ≠ val j := theorem modn_lt : ∀ {m : Nat} (i : Fin n), m > 0 → (modn i m).val < m | _, ⟨_, _⟩, hp => Nat.lt_of_le_of_lt (mod_le _ _) (mod_lt _ hp) +theorem val_lt_of_le (i : Fin b) (h : b ≤ n) : i.val < n := + Nat.lt_of_lt_of_le i.isLt h + end Fin -open Fin +instance [GetElem Cont Nat Elem Dom] : GetElem Cont (Fin n) Elem fun xs i => Dom xs i where + getElem xs i h := getElem xs i.1 h + +macro_rules + | `(tactic| get_elem_tactic_trivial) => `(tactic| apply Fin.val_lt_of_le; (first | assumption | simp (config := { arith := true })); done) diff --git a/src/Init/Meta.lean b/src/Init/Meta.lean index df17c4c975..39c8c6b5e5 100644 --- a/src/Init/Meta.lean +++ b/src/Init/Meta.lean @@ -384,7 +384,7 @@ def unsetTrailing (stx : Syntax) : Syntax := @[specialize] private partial def updateFirst {α} [Inhabited α] (a : Array α) (f : α → Option α) (i : Nat) : Option (Array α) := if h : i < a.size then - let v := a[⟨i, h⟩] + let v := a[i] match f v with | some v => some <| a.set ⟨i, h⟩ v | none => updateFirst a f (i+1) @@ -1001,13 +1001,14 @@ open Lean private partial def filterSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (p : Syntax → m Bool) (i : Nat) (acc : Array Syntax) : m (Array Syntax) := do if h : i < a.size then - let stx := a[⟨i, h⟩] + let stx := a[i] if (← p stx) then if acc.isEmpty then filterSepElemsMAux a p (i+2) (acc.push stx) else if hz : i ≠ 0 then have : i.pred < i := Nat.pred_lt hz - let sepStx := a[⟨i.pred, Nat.lt_trans this h⟩] + have : i.pred < a.size := Nat.lt_trans this h + let sepStx := a[i.pred] filterSepElemsMAux a p (i+2) ((acc.push sepStx).push stx) else filterSepElemsMAux a p (i+2) (acc.push stx) @@ -1024,7 +1025,7 @@ def filterSepElems (a : Array Syntax) (p : Syntax → Bool) : Array Syntax := private partial def mapSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) (i : Nat) (acc : Array Syntax) : m (Array Syntax) := do if h : i < a.size then - let stx := a[⟨i, h⟩] + let stx := a[i] if i % 2 == 0 then do let stx ← f stx mapSepElemsMAux a f (i+1) (acc.push stx) diff --git a/src/Init/Notation.lean b/src/Init/Notation.lean index 235d53ce38..ab33e4967b 100644 --- a/src/Init/Notation.lean +++ b/src/Init/Notation.lean @@ -220,7 +220,7 @@ macro_rules match i, skip with | 0, _ => pure result | i+1, true => expandListLit i false result - | i+1, false => expandListLit i true (← ``(List.cons $(⟨elems.elemsAndSeps[i]!⟩) $result)) + | i+1, false => expandListLit i true (← ``(List.cons $(⟨elems.elemsAndSeps.get! i⟩) $result)) if elems.elemsAndSeps.size < 64 then expandListLit elems.elemsAndSeps.size false (← ``(List.nil)) else diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 6b90bfbff8..a7b0db3c20 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -1229,6 +1229,11 @@ def panic {α : Type u} [Inhabited α] (msg : String) : α := -- TODO: this be applied directly to `Inhabited`'s definition when we remove the above workaround attribute [nospecialize] Inhabited +class GetElem (Cont : Type u) (Idx : Type v) (Elem : outParam (Type w)) (Dom : outParam (Cont → Idx → Prop)) where + getElem (xs : Cont) (i : Idx) (h : Dom xs i) : Elem + +export GetElem (getElem) + /- The Compiler has special support for arrays. They are implemented using dynamic arrays: https://en.wikipedia.org/wiki/Dynamic_array @@ -1270,6 +1275,9 @@ abbrev Array.getOp {α : Type u} (self : Array α) (idx : Fin self.size) : α := abbrev Array.getOp! {α : Type u} [Inhabited α] (self : Array α) (idx : Nat) : α := self.get! idx +instance : GetElem (Array α) Nat α fun xs i => LT.lt i xs.size where + getElem xs i h := xs.get ⟨i, h⟩ + @[extern "lean_array_push"] def Array.push {α : Type u} (a : Array α) (v : α) : Array α := { data := List.concat a.data v @@ -1913,6 +1921,9 @@ def getArg (stx : Syntax) (i : Nat) : Syntax := @[inline] def getOp (self : Syntax) (idx : Nat) : Syntax := self.getArg idx +instance : GetElem Syntax Nat Syntax fun _ _ => True where + getElem stx i _ := stx.getArg i + def getArgs (stx : Syntax) : Array Syntax := match stx with | Syntax.node _ _ args => args diff --git a/src/Init/Tactics.lean b/src/Init/Tactics.lean index 99af674773..3eede15cac 100644 --- a/src/Init/Tactics.lean +++ b/src/Init/Tactics.lean @@ -427,3 +427,15 @@ end Lean `‹t›` resolves to an (arbitrary) hypothesis of type `t`. It is useful for referring to hypotheses without accessible names. `t` may contain holes that are solved by unification with the expected type; in particular, `‹_›` is a shortcut for `by assumption`. -/ macro "‹" type:term "›" : term => `((by assumption : $type)) + +syntax "get_elem_tactic_trivial" : tactic -- extensible tactic + +macro_rules | `(tactic| get_elem_tactic_trivial) => `(tactic| trivial) +macro_rules | `(tactic| get_elem_tactic_trivial) => `(tactic| decide) +macro_rules | `(tactic| get_elem_tactic_trivial) => `(tactic| assumption) + +macro "get_elem_tactic" : tactic => `(get_elem_tactic_trivial) -- TODO: add error message + +macro:max (priority := high) x:term noWs "[" i:term "]" : term => `(getElem $x $i (by get_elem_tactic)) + +macro x:term noWs "[" i:term "]'" h:term:max : term => `(getElem $x $i $h) diff --git a/src/Init/Util.lean b/src/Init/Util.lean index e9d1d61b7e..0016caae32 100644 --- a/src/Init/Util.lean +++ b/src/Init/Util.lean @@ -57,3 +57,12 @@ def withPtrEq {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = @[implementedBy withPtrAddrUnsafe] def withPtrAddr {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀ u₁ u₂, k u₁ = k u₂) : β := k 0 + +@[inline] def getElem! [GetElem Cont Idx Elem Dom] [Inhabited Elem] (xs : Cont) (i : Idx) [Decidable (Dom xs i)] : Elem := + if h : _ then getElem xs i h else unreachable! + +@[inline] def getElem? [GetElem Cont Idx Elem Dom] (xs : Cont) (i : Idx) [Decidable (Dom xs i)] : Option Elem := + if h : _ then some (getElem xs i h) else none + +macro:max (priority := high) x:term noWs "[" i:term "]" noWs "?" : term => `(getElem? $x $i) -- TODO: remove priority +macro:max (priority := high) x:term noWs "[" i:term "]" noWs "!" : term => `(getElem! $x $i) -- TODO: remove priority diff --git a/src/Lean/Attributes.lean b/src/Lean/Attributes.lean index 7e374c3f0f..4961781234 100644 --- a/src/Lean/Attributes.lean +++ b/src/Lean/Attributes.lean @@ -77,12 +77,12 @@ def Attribute.Builtin.ensureNoArgs (stx : Syntax) : AttrM Unit := do def Attribute.Builtin.getIdent? (stx : Syntax) : AttrM (Option Syntax) := do if stx.getKind == `Lean.Parser.Attr.simple then if !stx[1].isNone && stx[1][0].isIdent then - return stx[1][0] + return some stx[1][0] else return none /- We handle `macro` here because it is handled by the generic `KeyedDeclsAttribute -/ else if stx.getKind == `Lean.Parser.Attr.«macro» || stx.getKind == `Lean.Parser.Attr.«export» then - return stx[1] + return some stx[1] else throwErrorAt stx "unexpected attribute argument" diff --git a/src/Lean/Compiler/IR/ElimDeadBranches.lean b/src/Lean/Compiler/IR/ElimDeadBranches.lean index 30660a3b17..cf9e24a4d3 100644 --- a/src/Lean/Compiler/IR/ElimDeadBranches.lean +++ b/src/Lean/Compiler/IR/ElimDeadBranches.lean @@ -193,7 +193,7 @@ def interpExpr : Expr → M Value | none => do let s ← get match ctx.decls.findIdx? (fun decl => decl.name == fid) with - | some idx => pure s.funVals[idx] + | some idx => pure s.funVals[idx]! | none => pure top | _ => pure top @@ -273,14 +273,14 @@ def inferStep : M Bool := do match ctx.decls[idx]! with | Decl.fdecl (xs := ys) (body := b) .. => do let s ← get - let currVals := s.funVals[idx] + let currVals := s.funVals[idx]! withReader (fun ctx => { ctx with currFnIdx := idx }) do ys.forM fun y => updateVarAssignment y.x top interpFnBody b let s ← get - let newVals := s.funVals[idx] + let newVals := s.funVals[idx]! pure (modified || currVals != newVals) - | Decl.extern _ _ _ _ => pure modified + | Decl.extern .. => pure modified partial def inferMain : M Unit := do let modified ← inferStep @@ -324,7 +324,7 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do let assignments := s.assignments modify fun s => let env := decls.size.fold (init := s.env) fun i env => - addFunctionSummary env decls[i]!.name funVals[i] + addFunctionSummary env decls[i]!.name funVals[i]! { s with env := env } return decls.mapIdx fun i decl => elimDead assignments[i]! decl diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index e9a7f709a4..784a5d3793 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -1475,7 +1475,7 @@ mutual partial def doTryToCode (doTry : Syntax) (doElems: List Syntax) : M CodeBlock := do let tryCode ← doSeqToCode (getDoSeqElems doTry[1]) let optFinally := doTry[3] - let catches ← doTry[2].getArgs.mapM fun catchStx => do + let catches ← doTry[2].getArgs.mapM fun catchStx : Syntax => do if catchStx.getKind == ``Lean.Parser.Term.doCatch then let x := catchStx[1] if x.isIdent then diff --git a/src/Lean/Elab/InfoTree.lean b/src/Lean/Elab/InfoTree.lean index b9b35f6233..360ceb5aca 100644 --- a/src/Lean/Elab/InfoTree.lean +++ b/src/Lean/Elab/InfoTree.lean @@ -449,7 +449,8 @@ def withMacroExpansionInfo [MonadFinally m] [Monad m] [MonadInfoTree m] [MonadLC if (← getInfoState).enabled then let treesSaved ← getResetInfoTrees Prod.fst <$> MonadFinally.tryFinally' x fun _ => modifyInfoState fun s => - if s.trees.size > 0 then + if h : s.trees.size > 0 then + have : s.trees.size - 1 < s.trees.size := Nat.sub_lt h (by decide) { s with trees := treesSaved, assignment := s.assignment.insert mvarId s.trees[s.trees.size - 1] } else { s with trees := treesSaved } diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index 116e54762d..255a55a869 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -65,7 +65,7 @@ private def expandNonAtomicExplicitSources (stx : Syntax) : TermElabM (Option Sy return none if sources.any (·.isMissing) then throwAbortTerm - go sources.toList #[] + return some (← go sources.toList #[]) where go (sources : List Syntax) (sourcesNew : Array Syntax) : TermElabM Syntax := do match sources with diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index 4b52143f5a..6515f4a1b5 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -19,6 +19,7 @@ def EnvExtensionState : Type := EnvExtensionStateSpec.fst instance : Inhabited EnvExtensionState := EnvExtensionStateSpec.snd def ModuleIdx := Nat +abbrev ModuleIdx.toNat (midx : ModuleIdx) : Nat := midx instance : Inhabited ModuleIdx := inferInstanceAs (Inhabited Nat) diff --git a/src/Lean/LocalContext.lean b/src/Lean/LocalContext.lean index a80adca01d..a3290c47df 100644 --- a/src/Lean/LocalContext.lean +++ b/src/Lean/LocalContext.lean @@ -326,13 +326,13 @@ instance : ForIn m LocalContext LocalDecl where partial def isSubPrefixOfAux (a₁ a₂ : PArray (Option LocalDecl)) (exceptFVars : Array Expr) (i j : Nat) : Bool := if i < a₁.size then - match a₁[i] with + match a₁[i]! with | none => isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) j | some decl₁ => if exceptFVars.any fun fvar => fvar.fvarId! == decl₁.fvarId then isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) j else if j < a₂.size then - match a₂[j] with + match a₂[j]! with | none => isSubPrefixOfAux a₁ a₂ exceptFVars i (j+1) | some decl₂ => if decl₁.fvarId == decl₂.fvarId then isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) (j+1) else isSubPrefixOfAux a₁ a₂ exceptFVars i (j+1) else false @@ -395,7 +395,7 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext if !getSanitizeNames st.options then pure lctx else StateT.run' (s := ({} : NameSet)) <| lctx.decls.size.foldRevM (init := lctx) fun i lctx => do - match lctx.decls[i] with + match lctx.decls[i]! with | none => pure lctx | some decl => if decl.userName.hasMacroScopes || (← get).contains decl.userName then do diff --git a/src/Lean/MonadEnv.lean b/src/Lean/MonadEnv.lean index c3241cf046..8bb7b2be29 100644 --- a/src/Lean/MonadEnv.lean +++ b/src/Lean/MonadEnv.lean @@ -169,7 +169,7 @@ def findModuleOf? [Monad m] [MonadEnv m] [MonadError m] (declName : Name) : m (O discard <| getConstInfo declName -- ensure declaration exists match (← getEnv).getModuleIdxFor? declName with | none => return none - | some modIdx => return some ((← getEnv).allImportedModuleNames[modIdx]!) + | some modIdx => return some ((← getEnv).allImportedModuleNames[modIdx.toNat]!) def isEnumType [Monad m] [MonadEnv m] [MonadError m] (declName : Name) : m Bool := do if let ConstantInfo.inductInfo info ← getConstInfo declName then diff --git a/src/Lean/PrettyPrinter/Delaborator/TopDownAnalyze.lean b/src/Lean/PrettyPrinter/Delaborator/TopDownAnalyze.lean index 9157911c30..0dadefb0bf 100644 --- a/src/Lean/PrettyPrinter/Delaborator/TopDownAnalyze.lean +++ b/src/Lean/PrettyPrinter/Delaborator/TopDownAnalyze.lean @@ -326,7 +326,7 @@ partial def canBottomUp (e : Expr) (mvar? : Option Expr := none) (fuel : Nat := inspectOutParams args[i]! mvars[i]! else if bInfos[i]! == BinderInfo.default then if ← isTrivialBottomUp args[i]! then tryUnify args[i]! mvars[i]! - else if ← typeUnknown mvars[i]! <&&> canBottomUp args[i]! mvars[i]! fuel then tryUnify args[i]! mvars[i]! + else if ← typeUnknown mvars[i]! <&&> canBottomUp args[i]! (some mvars[i]!) fuel then tryUnify args[i]! mvars[i]! if ← (pure (isHBinOp e) <&&> (valUnknown mvars[0]! <||> valUnknown mvars[1]!)) then tryUnify mvars[0]! mvars[1]! if mvar?.isSome then tryUnify resultType (← inferType mvar?.get!) return !(← valUnknown resultType) diff --git a/src/Lean/Server/FileWorker/RequestHandling.lean b/src/Lean/Server/FileWorker/RequestHandling.lean index 53cb0dcefd..411c58e375 100644 --- a/src/Lean/Server/FileWorker/RequestHandling.lean +++ b/src/Lean/Server/FileWorker/RequestHandling.lean @@ -251,8 +251,9 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams) stxs := stxs ++ (← parseAhead doc.meta.mkInputContext lastSnap).toList let (syms, _) := toDocumentSymbols doc.meta.text stxs return { syms := syms.toArray } - where - toDocumentSymbols (text : FileMap) +where + toDocumentSymbols (text : FileMap) (stxs : List Syntax) : List DocumentSymbol × List Syntax := + match stxs with | [] => ([], []) | stx::stxs => match stx with | `(namespace $id) => sectionLikeToDocumentSymbols text stx stxs (id.getId.toString) SymbolKind.namespace id @@ -263,12 +264,15 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams) unless stx.isOfKind ``Lean.Parser.Command.declaration do return (syms, stxs') if let some stxRange := stx.getRange? then - let (name, selection) := match stx with + let (name, selection) : String × Syntax := match stx with | `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id:ident$[.{$ls,*}]?]? $sig:declSig $_) => ((·.getId.toString) <$> id |>.getD s!"instance {sig.raw.reprint.getD ""}", id.map (·.raw) |>.getD sig) - | _ => match stx[1][1] with - | `(declId|$id:ident$[.{$ls,*}]?) => (id.getId.toString, id) - | _ => (stx[1][0].isIdOrAtom?.getD "", stx[1][0]) + | _ => + match stx.getArg 1 |>.getArg 1 with + | `(declId|$id:ident$[.{$ls,*}]?) => (id.raw.getId.toString, id) + | _ => + let stx10 : Syntax := (stx.getArg 1).getArg 0 -- TODO: stx[1][0] times out + (stx10.isIdOrAtom?.getD "", stx10) if let some selRange := selection.getRange? then return (DocumentSymbol.mk { name := name @@ -277,22 +281,23 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams) selectionRange := selRange.toLspRange text } :: syms, stxs') return (syms, stxs') - sectionLikeToDocumentSymbols (text : FileMap) (stx : Syntax) (stxs : List Syntax) (name : String) (kind : SymbolKind) (selection : Syntax) := - let (syms, stxs') := toDocumentSymbols text stxs - -- discard `end` - let (syms', stxs'') := toDocumentSymbols text (stxs'.drop 1) - let endStx := match stxs' with - | endStx::_ => endStx - | [] => (stx::stxs').getLast! - -- we can assume that commands always have at least one position (see `parseCommand`) - let range := (mkNullNode #[stx, endStx]).getRange?.get!.toLspRange text - (DocumentSymbol.mk { - name - kind - range - selectionRange := selection.getRange? |>.map (·.toLspRange text) |>.getD range - children? := syms.toArray - } :: syms', stxs'') + + sectionLikeToDocumentSymbols (text : FileMap) (stx : Syntax) (stxs : List Syntax) (name : String) (kind : SymbolKind) (selection : Syntax) : List DocumentSymbol × List Syntax := + let (syms, stxs') := toDocumentSymbols text stxs + -- discard `end` + let (syms', stxs'') := toDocumentSymbols text (stxs'.drop 1) + let endStx := match stxs' with + | endStx::_ => endStx + | [] => (stx::stxs').getLast! + -- we can assume that commands always have at least one position (see `parseCommand`) + let range := (mkNullNode #[stx, endStx]).getRange?.get!.toLspRange text + (DocumentSymbol.mk { + name + kind + range + selectionRange := selection.getRange? |>.map (·.toLspRange text) |>.getD range + children? := syms.toArray + } :: syms', stxs'') def noHighlightKinds : Array SyntaxNodeKind := #[ -- usually have special highlighting by the client diff --git a/src/Lean/Server/Snapshots.lean b/src/Lean/Server/Snapshots.lean index 713b3953e6..3afe0e8565 100644 --- a/src/Lean/Server/Snapshots.lean +++ b/src/Lean/Server/Snapshots.lean @@ -66,7 +66,7 @@ def diagnostics (s : Snapshot) : Std.PersistentArray Lsp.Diagnostic := def infoTree (s : Snapshot) : InfoTree := -- the parser returns exactly one command per snapshot, and the elaborator creates exactly one node per command assert! s.cmdState.infoState.trees.size == 1 - s.cmdState.infoState.trees[0] + s.cmdState.infoState.trees[0]! def isAtEnd (s : Snapshot) : Bool := Parser.isEOI s.stx || Parser.isExitCommand s.stx diff --git a/src/Lean/Util/Trace.lean b/src/Lean/Util/Trace.lean index 7a7daec3ea..35d7a95a37 100644 --- a/src/Lean/Util/Trace.lean +++ b/src/Lean/Util/Trace.lean @@ -157,7 +157,7 @@ private def withNestedTracesFinalizer [Monad m] [MonadTrace m] (ref : Syntax) (c modifyTraces fun traces => if traces.size == 0 then currTraces - else if traces.size == 1 && traces[0].msg.isNest then + else if traces.size == 1 && traces[0]!.msg.isNest then currTraces ++ traces -- No nest of nest else let d := traces.foldl (init := MessageData.nil) fun d elem => diff --git a/src/Std/Data/HashMap.lean b/src/Std/Data/HashMap.lean index e53015a91f..bae921ab0d 100644 --- a/src/Std/Data/HashMap.lean +++ b/src/Std/Data/HashMap.lean @@ -181,6 +181,9 @@ def insert' (m : HashMap α β) (a : α) (b : β) : HashMap α β × Bool := @[inline] def getOp (self : HashMap α β) (idx : α) : Option β := self.find? idx +instance : GetElem (HashMap α β) α (Option β) fun _ _ => True where + getElem m k _ := m.find? k + @[inline] def contains (m : HashMap α β) (a : α) : Bool := match m with | ⟨ m, _ ⟩ => m.contains a diff --git a/src/Std/Data/PersistentArray.lean b/src/Std/Data/PersistentArray.lean index 367652b66d..a42a3831ce 100644 --- a/src/Std/Data/PersistentArray.lean +++ b/src/Std/Data/PersistentArray.lean @@ -54,8 +54,8 @@ abbrev div2Shift (i : USize) (shift : USize) : USize := i.shiftRight shift abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shiftLeft 1 shift) - 1) partial def getAux [Inhabited α] : PersistentArrayNode α → USize → USize → α - | node cs, i, shift => getAux (cs.get! (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift) - | leaf cs, i, _ => cs.get! i.toNat + | node cs, i, shift => getAux cs[(div2Shift i shift).toNat]! (mod2Shift i shift) (shift - initShift) + | leaf cs, i, _ => cs[i.toNat]! def get! [Inhabited α] (t : PersistentArray α) (i : Nat) : α := if i >= t.tailOff then @@ -66,6 +66,10 @@ def get! [Inhabited α] (t : PersistentArray α) (i : Nat) : α := def getOp [Inhabited α] (self : PersistentArray α) (idx : Nat) : α := self.get! idx +-- TODO: remove [Inhabited α] +instance [Inhabited α] : GetElem (PersistentArray α) Nat α fun as i => i < as.size where + getElem xs i _ := xs.get! i + partial def setAux : PersistentArrayNode α → USize → USize → α → PersistentArrayNode α | node cs, i, shift, a => let j := div2Shift i shift diff --git a/src/Std/Data/PersistentHashMap.lean b/src/Std/Data/PersistentHashMap.lean index 3f674fd3c4..95d6ed422c 100644 --- a/src/Std/Data/PersistentHashMap.lean +++ b/src/Std/Data/PersistentHashMap.lean @@ -106,8 +106,8 @@ partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize | ⟨Node.collision keys vals heq, _⟩ => let rec traverse (i : Nat) (entries : Node α β) : Node α β := if h : i < keys.size then - let k := keys[⟨i, h⟩] - let v := vals[⟨i, heq ▸ h⟩] + let k := keys[i] + let v := vals[i]'(heq ▸ h) let h := hash k |>.toUSize let h := div2Shift h (shift * (depth - 1)) traverse (i+1) (insertAux entries h depth k v) @@ -129,8 +129,8 @@ def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β := if h : i < keys.size then - let k' := keys[⟨i, h⟩] - if k == k' then some vals[⟨i, by rw [←heq]; assumption⟩] + let k' := keys[i] + if k == k' then some (vals[i]'(by rw [←heq]; assumption)) else findAtAux keys vals heq (i+1) k else none @@ -149,6 +149,9 @@ def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Op @[inline] def getOp {_ : BEq α} {_ : Hashable α} (self : PersistentHashMap α β) (idx : α) : Option β := self.find? idx +instance {_ : BEq α} {_ : Hashable α} : GetElem (PersistentHashMap α β) α (Option β) fun _ _ => True where + getElem m i _ := m.find? i + @[inline] def findD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (b₀ : β) : β := (m.find? a).getD b₀ @@ -159,8 +162,8 @@ def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Op partial def findEntryAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option (α × β) := if h : i < keys.size then - let k' := keys[⟨i, h⟩] - if k == k' then some (k', vals[⟨i, by rw [←heq]; assumption⟩]) + let k' := keys[i] + if k == k' then some (k', vals[i]'(by rw [←heq]; assumption)) else findEntryAtAux keys vals heq (i+1) k else none @@ -178,7 +181,7 @@ def findEntry? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool := if h : i < keys.size then - let k' := keys[⟨i, h⟩] + let k' := keys[i] if k == k' then true else containsAtAux keys vals heq (i+1) k else false @@ -197,7 +200,7 @@ def contains [BEq α] [Hashable α] : PersistentHashMap α β → α → Bool partial def isUnaryEntries (a : Array (Entry α β (Node α β))) (i : Nat) (acc : Option (α × β)) : Option (α × β) := if h : i < a.size then - match a[⟨i, h⟩] with + match a[i] with | Entry.null => isUnaryEntries a (i+1) acc | Entry.ref _ => none | Entry.entry k v => @@ -211,7 +214,8 @@ def isUnaryNode : Node α β → Option (α × β) | Node.collision keys vals heq => if h : 1 = keys.size then have : 0 < keys.size := by rw [←h]; decide - some (keys[⟨0, this⟩], vals[⟨0, by rw [←heq]; assumption⟩]) + have : 0 < vals.size := by rw [←heq]; assumption + some (keys[0], vals[0]) else none @@ -253,8 +257,8 @@ variable {σ : Type w} | Node.collision keys vals heq, acc => let rec traverse (i : Nat) (acc : σ) : m σ := do if h : i < keys.size then - let k := keys[⟨i, h⟩] - let v := vals[⟨i, heq ▸ h⟩] + let k := keys[i] + let v := vals[i]'(heq ▸ h) traverse (i+1) (← f acc k v) else pure acc