refactor: prepare to elaborate a[i] notation using typeclasses
This commit is contained in:
parent
e30ac86bd5
commit
e4b358a01e
25 changed files with 144 additions and 79 deletions
|
|
@ -39,13 +39,13 @@ def singleton (v : α) : Array α :=
|
|||
`fget` may be slightly slower than `uget`. -/
|
||||
@[extern "lean_array_uget"]
|
||||
def uget (a : @& Array α) (i : USize) (h : i.toNat < a.size) : α :=
|
||||
a[⟨i.toNat, h⟩]
|
||||
a[i.toNat]
|
||||
|
||||
def back [Inhabited α] (a : Array α) : α :=
|
||||
a.get! (a.size - 1)
|
||||
|
||||
def get? (a : Array α) (i : Nat) : Option α :=
|
||||
if h : i < a.size then some a[⟨i, h⟩] else none
|
||||
if h : i < a.size then some a[i] else none
|
||||
|
||||
abbrev getOp? (self : Array α) (idx : Nat) : Option α :=
|
||||
self.get? idx
|
||||
|
|
@ -55,7 +55,8 @@ def back? (a : Array α) : Option α :=
|
|||
|
||||
-- auxiliary declaration used in the equation compiler when pattern matching array literals.
|
||||
abbrev getLit {α : Type u} {n : Nat} (a : Array α) (i : Nat) (h₁ : a.size = n) (h₂ : i < n) : α :=
|
||||
a[⟨i, h₁.symm ▸ h₂⟩]
|
||||
have := h₁.symm ▸ h₂
|
||||
a[i]
|
||||
|
||||
@[simp] theorem size_set (a : Array α) (i : Fin a.size) (v : α) : (set a i v).size = a.size :=
|
||||
List.length_set ..
|
||||
|
|
@ -166,7 +167,7 @@ protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m
|
|||
have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h
|
||||
have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide)
|
||||
have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this
|
||||
match (← f as[⟨as.size - 1 - i, this⟩] b) with
|
||||
match (← f as[as.size - 1 - i] b) with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b => loop i (Nat.le_of_lt h') b
|
||||
loop as.size (Nat.le_refl _) b
|
||||
|
|
@ -199,7 +200,8 @@ def foldlM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β
|
|||
match i with
|
||||
| 0 => pure b
|
||||
| i'+1 =>
|
||||
loop i' (j+1) (← f b as[⟨j, Nat.lt_of_lt_of_le hlt h⟩])
|
||||
have : j < as.size := Nat.lt_of_lt_of_le hlt h
|
||||
loop i' (j+1) (← f b as[j])
|
||||
else
|
||||
pure b
|
||||
loop (stop - start) start init
|
||||
|
|
@ -236,7 +238,7 @@ def foldrM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α
|
|||
| 0, _ => pure b
|
||||
| i+1, h =>
|
||||
have : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self _) h
|
||||
fold i (Nat.le_of_lt this) (← f as[⟨i, this⟩] b)
|
||||
fold i (Nat.le_of_lt this) (← f as[i] b)
|
||||
if h : start ≤ as.size then
|
||||
if stop < start then
|
||||
fold start h init
|
||||
|
|
@ -330,7 +332,8 @@ def anyM {α : Type u} {m : Type → Type w} [Monad m] (p : α → m Bool) (as :
|
|||
let any (stop : Nat) (h : stop ≤ as.size) :=
|
||||
let rec loop (j : Nat) : m Bool := do
|
||||
if hlt : j < stop then
|
||||
if (← p as[⟨j, Nat.lt_of_lt_of_le hlt h⟩]) then
|
||||
have : j < as.size := Nat.lt_of_lt_of_le hlt h
|
||||
if (← p as[j]) then
|
||||
pure true
|
||||
else
|
||||
loop (j+1)
|
||||
|
|
@ -353,7 +356,7 @@ def findSomeRevM? {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m]
|
|||
| 0, _ => pure none
|
||||
| i+1, h => do
|
||||
have : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self _) h
|
||||
let r ← f as[⟨i, this⟩]
|
||||
let r ← f as[i]
|
||||
match r with
|
||||
| some _ => pure r
|
||||
| none =>
|
||||
|
|
@ -422,7 +425,7 @@ def findIdx? {α : Type u} (as : Array α) (p : α → Bool) : Option Nat :=
|
|||
rw [inv] at hlt
|
||||
exact absurd hlt (Nat.lt_irrefl _)
|
||||
| i+1, inv =>
|
||||
if p as[⟨j, hlt⟩] then
|
||||
if p as[j] then
|
||||
some j
|
||||
else
|
||||
have : i + (j+1) = as.size := by
|
||||
|
|
@ -515,7 +518,8 @@ namespace Array
|
|||
@[specialize]
|
||||
def isEqvAux (a b : Array α) (hsz : a.size = b.size) (p : α → α → Bool) (i : Nat) : Bool :=
|
||||
if h : i < a.size then
|
||||
p a[⟨i, h⟩] b[⟨i, hsz ▸ h⟩] && isEqvAux a b hsz p (i+1)
|
||||
have : i < b.size := hsz ▸ h
|
||||
p a[i] b[i] && isEqvAux a b hsz p (i+1)
|
||||
else
|
||||
true
|
||||
termination_by _ => a.size - i
|
||||
|
|
@ -553,7 +557,7 @@ def filterMap (f : α → Option β) (as : Array α) (start := 0) (stop := as.si
|
|||
@[specialize]
|
||||
def getMax? (as : Array α) (lt : α → α → Bool) : Option α :=
|
||||
if h : 0 < as.size then
|
||||
let a0 := as[⟨0, h⟩]
|
||||
let a0 := as[0]
|
||||
some <| as.foldl (init := a0) (start := 1) fun best a =>
|
||||
if lt best a then a else best
|
||||
else
|
||||
|
|
@ -572,7 +576,7 @@ def partition (p : α → Bool) (as : Array α) : Array α × Array α := Id.run
|
|||
|
||||
theorem ext (a b : Array α)
|
||||
(h₁ : a.size = b.size)
|
||||
(h₂ : (i : Nat) → (hi₁ : i < a.size) → (hi₂ : i < b.size) → a[⟨i, hi₁⟩] = b[⟨i, hi₂⟩])
|
||||
(h₂ : (i : Nat) → (hi₁ : i < a.size) → (hi₂ : i < b.size) → a[i] = b[i])
|
||||
: a = b := by
|
||||
let rec extAux (a b : List α)
|
||||
(h₁ : a.length = b.length)
|
||||
|
|
@ -758,8 +762,8 @@ theorem toArrayLit_eq (a : Array α) (n : Nat) (hsz : a.size = n) : a = toArrayL
|
|||
|
||||
def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size ≤ bs.size) (i : Nat) : Bool :=
|
||||
if h : i < as.size then
|
||||
let a := as[⟨i, h⟩]
|
||||
let b := bs[⟨i, Nat.lt_of_lt_of_le h hle⟩]
|
||||
let a := as[i]
|
||||
let b := bs[i]'(Nat.lt_of_lt_of_le h hle)
|
||||
if a == b then
|
||||
isPrefixOfAux as bs hle (i+1)
|
||||
else
|
||||
|
|
@ -780,11 +784,11 @@ private def allDiffAuxAux [BEq α] (as : Array α) (a : α) : forall (i : Nat),
|
|||
| 0, _ => true
|
||||
| i+1, h =>
|
||||
have : i < as.size := Nat.lt_trans (Nat.lt_succ_self _) h;
|
||||
a != as[⟨i, this⟩] && allDiffAuxAux as a i this
|
||||
a != as[i] && allDiffAuxAux as a i this
|
||||
|
||||
private def allDiffAux [BEq α] (as : Array α) (i : Nat) : Bool :=
|
||||
if h : i < as.size then
|
||||
allDiffAuxAux as as[⟨i, h⟩] i h && allDiffAux as (i+1)
|
||||
allDiffAuxAux as as[i] i h && allDiffAux as (i+1)
|
||||
else
|
||||
true
|
||||
termination_by _ => as.size - i
|
||||
|
|
@ -794,9 +798,9 @@ def allDiff [BEq α] (as : Array α) : Bool :=
|
|||
|
||||
@[specialize] def zipWithAux (f : α → β → γ) (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) : Array γ :=
|
||||
if h : i < as.size then
|
||||
let a := as[⟨i, h⟩]
|
||||
let a := as[i]
|
||||
if h : i < bs.size then
|
||||
let b := bs[⟨i, h⟩]
|
||||
let b := bs[i]
|
||||
zipWithAux f as bs (i+1) <| cs.push <| f a b
|
||||
else
|
||||
cs
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import Init.Classical
|
|||
|
||||
namespace Array
|
||||
|
||||
theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i ≤ a.size) (heqv : Array.isEqvAux a b hsz (fun x y => x = y) i) (j : Nat) (low : i ≤ j) (high : j < a.size) : a[⟨j, high⟩] = b[⟨j, hsz ▸ high⟩] := by
|
||||
theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i ≤ a.size) (heqv : Array.isEqvAux a b hsz (fun x y => x = y) i) (j : Nat) (low : i ≤ j) (high : j < a.size) : a[j] = b[j]'(hsz ▸ high) := by
|
||||
by_cases h : i < a.size
|
||||
· unfold Array.isEqvAux at heqv
|
||||
simp [h] at heqv
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ where
|
|||
| 0 => a
|
||||
| j'+1 =>
|
||||
have h' : j' < a.size := by subst j; exact Nat.lt_trans (Nat.lt_succ_self _) h
|
||||
if lt a[⟨j, h⟩] a[⟨j', h'⟩] then
|
||||
if lt a[j] a[j'] then
|
||||
swapLoop (a.swap ⟨j, h⟩ ⟨j', h'⟩) j' (by rw [size_swap]; assumption done)
|
||||
else
|
||||
a
|
||||
|
|
|
|||
|
|
@ -27,11 +27,14 @@ def get (s : Subarray α) (i : Fin s.size) : α :=
|
|||
simp [size] at this
|
||||
rw [Nat.add_comm]
|
||||
exact Nat.add_lt_of_lt_sub this
|
||||
s.as[⟨s.start + i.val, this⟩]
|
||||
s.as[s.start + i.val]
|
||||
|
||||
abbrev getOp (self : Subarray α) (idx : Fin self.size) : α :=
|
||||
self.get idx
|
||||
|
||||
instance : GetElem (Subarray α) Nat α fun xs i => i < xs.size where
|
||||
getElem xs i h := xs.get ⟨i, h⟩
|
||||
|
||||
@[inline] def getD (s : Subarray α) (i : Nat) (v₀ : α) : α :=
|
||||
if h : i < s.size then s.get ⟨i, h⟩ else v₀
|
||||
|
||||
|
|
|
|||
|
|
@ -112,6 +112,13 @@ theorem val_ne_of_ne {i j : Fin n} (h : i ≠ j) : val i ≠ val j :=
|
|||
theorem modn_lt : ∀ {m : Nat} (i : Fin n), m > 0 → (modn i m).val < m
|
||||
| _, ⟨_, _⟩, hp => Nat.lt_of_le_of_lt (mod_le _ _) (mod_lt _ hp)
|
||||
|
||||
theorem val_lt_of_le (i : Fin b) (h : b ≤ n) : i.val < n :=
|
||||
Nat.lt_of_lt_of_le i.isLt h
|
||||
|
||||
end Fin
|
||||
|
||||
open Fin
|
||||
instance [GetElem Cont Nat Elem Dom] : GetElem Cont (Fin n) Elem fun xs i => Dom xs i where
|
||||
getElem xs i h := getElem xs i.1 h
|
||||
|
||||
macro_rules
|
||||
| `(tactic| get_elem_tactic_trivial) => `(tactic| apply Fin.val_lt_of_le; (first | assumption | simp (config := { arith := true })); done)
|
||||
|
|
|
|||
|
|
@ -384,7 +384,7 @@ def unsetTrailing (stx : Syntax) : Syntax :=
|
|||
|
||||
@[specialize] private partial def updateFirst {α} [Inhabited α] (a : Array α) (f : α → Option α) (i : Nat) : Option (Array α) :=
|
||||
if h : i < a.size then
|
||||
let v := a[⟨i, h⟩]
|
||||
let v := a[i]
|
||||
match f v with
|
||||
| some v => some <| a.set ⟨i, h⟩ v
|
||||
| none => updateFirst a f (i+1)
|
||||
|
|
@ -1001,13 +1001,14 @@ open Lean
|
|||
|
||||
private partial def filterSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (p : Syntax → m Bool) (i : Nat) (acc : Array Syntax) : m (Array Syntax) := do
|
||||
if h : i < a.size then
|
||||
let stx := a[⟨i, h⟩]
|
||||
let stx := a[i]
|
||||
if (← p stx) then
|
||||
if acc.isEmpty then
|
||||
filterSepElemsMAux a p (i+2) (acc.push stx)
|
||||
else if hz : i ≠ 0 then
|
||||
have : i.pred < i := Nat.pred_lt hz
|
||||
let sepStx := a[⟨i.pred, Nat.lt_trans this h⟩]
|
||||
have : i.pred < a.size := Nat.lt_trans this h
|
||||
let sepStx := a[i.pred]
|
||||
filterSepElemsMAux a p (i+2) ((acc.push sepStx).push stx)
|
||||
else
|
||||
filterSepElemsMAux a p (i+2) (acc.push stx)
|
||||
|
|
@ -1024,7 +1025,7 @@ def filterSepElems (a : Array Syntax) (p : Syntax → Bool) : Array Syntax :=
|
|||
|
||||
private partial def mapSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) (i : Nat) (acc : Array Syntax) : m (Array Syntax) := do
|
||||
if h : i < a.size then
|
||||
let stx := a[⟨i, h⟩]
|
||||
let stx := a[i]
|
||||
if i % 2 == 0 then do
|
||||
let stx ← f stx
|
||||
mapSepElemsMAux a f (i+1) (acc.push stx)
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ macro_rules
|
|||
match i, skip with
|
||||
| 0, _ => pure result
|
||||
| i+1, true => expandListLit i false result
|
||||
| i+1, false => expandListLit i true (← ``(List.cons $(⟨elems.elemsAndSeps[i]!⟩) $result))
|
||||
| i+1, false => expandListLit i true (← ``(List.cons $(⟨elems.elemsAndSeps.get! i⟩) $result))
|
||||
if elems.elemsAndSeps.size < 64 then
|
||||
expandListLit elems.elemsAndSeps.size false (← ``(List.nil))
|
||||
else
|
||||
|
|
|
|||
|
|
@ -1229,6 +1229,11 @@ def panic {α : Type u} [Inhabited α] (msg : String) : α :=
|
|||
-- TODO: this be applied directly to `Inhabited`'s definition when we remove the above workaround
|
||||
attribute [nospecialize] Inhabited
|
||||
|
||||
class GetElem (Cont : Type u) (Idx : Type v) (Elem : outParam (Type w)) (Dom : outParam (Cont → Idx → Prop)) where
|
||||
getElem (xs : Cont) (i : Idx) (h : Dom xs i) : Elem
|
||||
|
||||
export GetElem (getElem)
|
||||
|
||||
/-
|
||||
The Compiler has special support for arrays.
|
||||
They are implemented using dynamic arrays: https://en.wikipedia.org/wiki/Dynamic_array
|
||||
|
|
@ -1270,6 +1275,9 @@ abbrev Array.getOp {α : Type u} (self : Array α) (idx : Fin self.size) : α :=
|
|||
abbrev Array.getOp! {α : Type u} [Inhabited α] (self : Array α) (idx : Nat) : α :=
|
||||
self.get! idx
|
||||
|
||||
instance : GetElem (Array α) Nat α fun xs i => LT.lt i xs.size where
|
||||
getElem xs i h := xs.get ⟨i, h⟩
|
||||
|
||||
@[extern "lean_array_push"]
|
||||
def Array.push {α : Type u} (a : Array α) (v : α) : Array α := {
|
||||
data := List.concat a.data v
|
||||
|
|
@ -1913,6 +1921,9 @@ def getArg (stx : Syntax) (i : Nat) : Syntax :=
|
|||
@[inline] def getOp (self : Syntax) (idx : Nat) : Syntax :=
|
||||
self.getArg idx
|
||||
|
||||
instance : GetElem Syntax Nat Syntax fun _ _ => True where
|
||||
getElem stx i _ := stx.getArg i
|
||||
|
||||
def getArgs (stx : Syntax) : Array Syntax :=
|
||||
match stx with
|
||||
| Syntax.node _ _ args => args
|
||||
|
|
|
|||
|
|
@ -427,3 +427,15 @@ end Lean
|
|||
`‹t›` resolves to an (arbitrary) hypothesis of type `t`. It is useful for referring to hypotheses without accessible names.
|
||||
`t` may contain holes that are solved by unification with the expected type; in particular, `‹_›` is a shortcut for `by assumption`. -/
|
||||
macro "‹" type:term "›" : term => `((by assumption : $type))
|
||||
|
||||
syntax "get_elem_tactic_trivial" : tactic -- extensible tactic
|
||||
|
||||
macro_rules | `(tactic| get_elem_tactic_trivial) => `(tactic| trivial)
|
||||
macro_rules | `(tactic| get_elem_tactic_trivial) => `(tactic| decide)
|
||||
macro_rules | `(tactic| get_elem_tactic_trivial) => `(tactic| assumption)
|
||||
|
||||
macro "get_elem_tactic" : tactic => `(get_elem_tactic_trivial) -- TODO: add error message
|
||||
|
||||
macro:max (priority := high) x:term noWs "[" i:term "]" : term => `(getElem $x $i (by get_elem_tactic))
|
||||
|
||||
macro x:term noWs "[" i:term "]'" h:term:max : term => `(getElem $x $i $h)
|
||||
|
|
|
|||
|
|
@ -57,3 +57,12 @@ def withPtrEq {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () =
|
|||
|
||||
@[implementedBy withPtrAddrUnsafe]
|
||||
def withPtrAddr {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀ u₁ u₂, k u₁ = k u₂) : β := k 0
|
||||
|
||||
@[inline] def getElem! [GetElem Cont Idx Elem Dom] [Inhabited Elem] (xs : Cont) (i : Idx) [Decidable (Dom xs i)] : Elem :=
|
||||
if h : _ then getElem xs i h else unreachable!
|
||||
|
||||
@[inline] def getElem? [GetElem Cont Idx Elem Dom] (xs : Cont) (i : Idx) [Decidable (Dom xs i)] : Option Elem :=
|
||||
if h : _ then some (getElem xs i h) else none
|
||||
|
||||
macro:max (priority := high) x:term noWs "[" i:term "]" noWs "?" : term => `(getElem? $x $i) -- TODO: remove priority
|
||||
macro:max (priority := high) x:term noWs "[" i:term "]" noWs "!" : term => `(getElem! $x $i) -- TODO: remove priority
|
||||
|
|
|
|||
|
|
@ -77,12 +77,12 @@ def Attribute.Builtin.ensureNoArgs (stx : Syntax) : AttrM Unit := do
|
|||
def Attribute.Builtin.getIdent? (stx : Syntax) : AttrM (Option Syntax) := do
|
||||
if stx.getKind == `Lean.Parser.Attr.simple then
|
||||
if !stx[1].isNone && stx[1][0].isIdent then
|
||||
return stx[1][0]
|
||||
return some stx[1][0]
|
||||
else
|
||||
return none
|
||||
/- We handle `macro` here because it is handled by the generic `KeyedDeclsAttribute -/
|
||||
else if stx.getKind == `Lean.Parser.Attr.«macro» || stx.getKind == `Lean.Parser.Attr.«export» then
|
||||
return stx[1]
|
||||
return some stx[1]
|
||||
else
|
||||
throwErrorAt stx "unexpected attribute argument"
|
||||
|
||||
|
|
|
|||
|
|
@ -193,7 +193,7 @@ def interpExpr : Expr → M Value
|
|||
| none => do
|
||||
let s ← get
|
||||
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
|
||||
| some idx => pure s.funVals[idx]
|
||||
| some idx => pure s.funVals[idx]!
|
||||
| none => pure top
|
||||
| _ => pure top
|
||||
|
||||
|
|
@ -273,14 +273,14 @@ def inferStep : M Bool := do
|
|||
match ctx.decls[idx]! with
|
||||
| Decl.fdecl (xs := ys) (body := b) .. => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]
|
||||
let currVals := s.funVals[idx]!
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) do
|
||||
ys.forM fun y => updateVarAssignment y.x top
|
||||
interpFnBody b
|
||||
let s ← get
|
||||
let newVals := s.funVals[idx]
|
||||
let newVals := s.funVals[idx]!
|
||||
pure (modified || currVals != newVals)
|
||||
| Decl.extern _ _ _ _ => pure modified
|
||||
| Decl.extern .. => pure modified
|
||||
|
||||
partial def inferMain : M Unit := do
|
||||
let modified ← inferStep
|
||||
|
|
@ -324,7 +324,7 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
|||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i]!.name funVals[i]
|
||||
addFunctionSummary env decls[i]!.name funVals[i]!
|
||||
{ s with env := env }
|
||||
return decls.mapIdx fun i decl => elimDead assignments[i]! decl
|
||||
|
||||
|
|
|
|||
|
|
@ -1475,7 +1475,7 @@ mutual
|
|||
partial def doTryToCode (doTry : Syntax) (doElems: List Syntax) : M CodeBlock := do
|
||||
let tryCode ← doSeqToCode (getDoSeqElems doTry[1])
|
||||
let optFinally := doTry[3]
|
||||
let catches ← doTry[2].getArgs.mapM fun catchStx => do
|
||||
let catches ← doTry[2].getArgs.mapM fun catchStx : Syntax => do
|
||||
if catchStx.getKind == ``Lean.Parser.Term.doCatch then
|
||||
let x := catchStx[1]
|
||||
if x.isIdent then
|
||||
|
|
|
|||
|
|
@ -449,7 +449,8 @@ def withMacroExpansionInfo [MonadFinally m] [Monad m] [MonadInfoTree m] [MonadLC
|
|||
if (← getInfoState).enabled then
|
||||
let treesSaved ← getResetInfoTrees
|
||||
Prod.fst <$> MonadFinally.tryFinally' x fun _ => modifyInfoState fun s =>
|
||||
if s.trees.size > 0 then
|
||||
if h : s.trees.size > 0 then
|
||||
have : s.trees.size - 1 < s.trees.size := Nat.sub_lt h (by decide)
|
||||
{ s with trees := treesSaved, assignment := s.assignment.insert mvarId s.trees[s.trees.size - 1] }
|
||||
else
|
||||
{ s with trees := treesSaved }
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ private def expandNonAtomicExplicitSources (stx : Syntax) : TermElabM (Option Sy
|
|||
return none
|
||||
if sources.any (·.isMissing) then
|
||||
throwAbortTerm
|
||||
go sources.toList #[]
|
||||
return some (← go sources.toList #[])
|
||||
where
|
||||
go (sources : List Syntax) (sourcesNew : Array Syntax) : TermElabM Syntax := do
|
||||
match sources with
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def EnvExtensionState : Type := EnvExtensionStateSpec.fst
|
|||
instance : Inhabited EnvExtensionState := EnvExtensionStateSpec.snd
|
||||
|
||||
def ModuleIdx := Nat
|
||||
abbrev ModuleIdx.toNat (midx : ModuleIdx) : Nat := midx
|
||||
|
||||
instance : Inhabited ModuleIdx := inferInstanceAs (Inhabited Nat)
|
||||
|
||||
|
|
|
|||
|
|
@ -326,13 +326,13 @@ instance : ForIn m LocalContext LocalDecl where
|
|||
|
||||
partial def isSubPrefixOfAux (a₁ a₂ : PArray (Option LocalDecl)) (exceptFVars : Array Expr) (i j : Nat) : Bool :=
|
||||
if i < a₁.size then
|
||||
match a₁[i] with
|
||||
match a₁[i]! with
|
||||
| none => isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) j
|
||||
| some decl₁ =>
|
||||
if exceptFVars.any fun fvar => fvar.fvarId! == decl₁.fvarId then
|
||||
isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) j
|
||||
else if j < a₂.size then
|
||||
match a₂[j] with
|
||||
match a₂[j]! with
|
||||
| none => isSubPrefixOfAux a₁ a₂ exceptFVars i (j+1)
|
||||
| some decl₂ => if decl₁.fvarId == decl₂.fvarId then isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) (j+1) else isSubPrefixOfAux a₁ a₂ exceptFVars i (j+1)
|
||||
else false
|
||||
|
|
@ -395,7 +395,7 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext
|
|||
if !getSanitizeNames st.options then pure lctx else
|
||||
StateT.run' (s := ({} : NameSet)) <|
|
||||
lctx.decls.size.foldRevM (init := lctx) fun i lctx => do
|
||||
match lctx.decls[i] with
|
||||
match lctx.decls[i]! with
|
||||
| none => pure lctx
|
||||
| some decl =>
|
||||
if decl.userName.hasMacroScopes || (← get).contains decl.userName then do
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ def findModuleOf? [Monad m] [MonadEnv m] [MonadError m] (declName : Name) : m (O
|
|||
discard <| getConstInfo declName -- ensure declaration exists
|
||||
match (← getEnv).getModuleIdxFor? declName with
|
||||
| none => return none
|
||||
| some modIdx => return some ((← getEnv).allImportedModuleNames[modIdx]!)
|
||||
| some modIdx => return some ((← getEnv).allImportedModuleNames[modIdx.toNat]!)
|
||||
|
||||
def isEnumType [Monad m] [MonadEnv m] [MonadError m] (declName : Name) : m Bool := do
|
||||
if let ConstantInfo.inductInfo info ← getConstInfo declName then
|
||||
|
|
|
|||
|
|
@ -326,7 +326,7 @@ partial def canBottomUp (e : Expr) (mvar? : Option Expr := none) (fuel : Nat :=
|
|||
inspectOutParams args[i]! mvars[i]!
|
||||
else if bInfos[i]! == BinderInfo.default then
|
||||
if ← isTrivialBottomUp args[i]! then tryUnify args[i]! mvars[i]!
|
||||
else if ← typeUnknown mvars[i]! <&&> canBottomUp args[i]! mvars[i]! fuel then tryUnify args[i]! mvars[i]!
|
||||
else if ← typeUnknown mvars[i]! <&&> canBottomUp args[i]! (some mvars[i]!) fuel then tryUnify args[i]! mvars[i]!
|
||||
if ← (pure (isHBinOp e) <&&> (valUnknown mvars[0]! <||> valUnknown mvars[1]!)) then tryUnify mvars[0]! mvars[1]!
|
||||
if mvar?.isSome then tryUnify resultType (← inferType mvar?.get!)
|
||||
return !(← valUnknown resultType)
|
||||
|
|
|
|||
|
|
@ -251,8 +251,9 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
|
|||
stxs := stxs ++ (← parseAhead doc.meta.mkInputContext lastSnap).toList
|
||||
let (syms, _) := toDocumentSymbols doc.meta.text stxs
|
||||
return { syms := syms.toArray }
|
||||
where
|
||||
toDocumentSymbols (text : FileMap)
|
||||
where
|
||||
toDocumentSymbols (text : FileMap) (stxs : List Syntax) : List DocumentSymbol × List Syntax :=
|
||||
match stxs with
|
||||
| [] => ([], [])
|
||||
| stx::stxs => match stx with
|
||||
| `(namespace $id) => sectionLikeToDocumentSymbols text stx stxs (id.getId.toString) SymbolKind.namespace id
|
||||
|
|
@ -263,12 +264,15 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
|
|||
unless stx.isOfKind ``Lean.Parser.Command.declaration do
|
||||
return (syms, stxs')
|
||||
if let some stxRange := stx.getRange? then
|
||||
let (name, selection) := match stx with
|
||||
let (name, selection) : String × Syntax := match stx with
|
||||
| `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id:ident$[.{$ls,*}]?]? $sig:declSig $_) =>
|
||||
((·.getId.toString) <$> id |>.getD s!"instance {sig.raw.reprint.getD ""}", id.map (·.raw) |>.getD sig)
|
||||
| _ => match stx[1][1] with
|
||||
| `(declId|$id:ident$[.{$ls,*}]?) => (id.getId.toString, id)
|
||||
| _ => (stx[1][0].isIdOrAtom?.getD "<unknown>", stx[1][0])
|
||||
| _ =>
|
||||
match stx.getArg 1 |>.getArg 1 with
|
||||
| `(declId|$id:ident$[.{$ls,*}]?) => (id.raw.getId.toString, id)
|
||||
| _ =>
|
||||
let stx10 : Syntax := (stx.getArg 1).getArg 0 -- TODO: stx[1][0] times out
|
||||
(stx10.isIdOrAtom?.getD "<unknown>", stx10)
|
||||
if let some selRange := selection.getRange? then
|
||||
return (DocumentSymbol.mk {
|
||||
name := name
|
||||
|
|
@ -277,22 +281,23 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
|
|||
selectionRange := selRange.toLspRange text
|
||||
} :: syms, stxs')
|
||||
return (syms, stxs')
|
||||
sectionLikeToDocumentSymbols (text : FileMap) (stx : Syntax) (stxs : List Syntax) (name : String) (kind : SymbolKind) (selection : Syntax) :=
|
||||
let (syms, stxs') := toDocumentSymbols text stxs
|
||||
-- discard `end`
|
||||
let (syms', stxs'') := toDocumentSymbols text (stxs'.drop 1)
|
||||
let endStx := match stxs' with
|
||||
| endStx::_ => endStx
|
||||
| [] => (stx::stxs').getLast!
|
||||
-- we can assume that commands always have at least one position (see `parseCommand`)
|
||||
let range := (mkNullNode #[stx, endStx]).getRange?.get!.toLspRange text
|
||||
(DocumentSymbol.mk {
|
||||
name
|
||||
kind
|
||||
range
|
||||
selectionRange := selection.getRange? |>.map (·.toLspRange text) |>.getD range
|
||||
children? := syms.toArray
|
||||
} :: syms', stxs'')
|
||||
|
||||
sectionLikeToDocumentSymbols (text : FileMap) (stx : Syntax) (stxs : List Syntax) (name : String) (kind : SymbolKind) (selection : Syntax) : List DocumentSymbol × List Syntax :=
|
||||
let (syms, stxs') := toDocumentSymbols text stxs
|
||||
-- discard `end`
|
||||
let (syms', stxs'') := toDocumentSymbols text (stxs'.drop 1)
|
||||
let endStx := match stxs' with
|
||||
| endStx::_ => endStx
|
||||
| [] => (stx::stxs').getLast!
|
||||
-- we can assume that commands always have at least one position (see `parseCommand`)
|
||||
let range := (mkNullNode #[stx, endStx]).getRange?.get!.toLspRange text
|
||||
(DocumentSymbol.mk {
|
||||
name
|
||||
kind
|
||||
range
|
||||
selectionRange := selection.getRange? |>.map (·.toLspRange text) |>.getD range
|
||||
children? := syms.toArray
|
||||
} :: syms', stxs'')
|
||||
|
||||
def noHighlightKinds : Array SyntaxNodeKind := #[
|
||||
-- usually have special highlighting by the client
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ def diagnostics (s : Snapshot) : Std.PersistentArray Lsp.Diagnostic :=
|
|||
def infoTree (s : Snapshot) : InfoTree :=
|
||||
-- the parser returns exactly one command per snapshot, and the elaborator creates exactly one node per command
|
||||
assert! s.cmdState.infoState.trees.size == 1
|
||||
s.cmdState.infoState.trees[0]
|
||||
s.cmdState.infoState.trees[0]!
|
||||
|
||||
def isAtEnd (s : Snapshot) : Bool :=
|
||||
Parser.isEOI s.stx || Parser.isExitCommand s.stx
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ private def withNestedTracesFinalizer [Monad m] [MonadTrace m] (ref : Syntax) (c
|
|||
modifyTraces fun traces =>
|
||||
if traces.size == 0 then
|
||||
currTraces
|
||||
else if traces.size == 1 && traces[0].msg.isNest then
|
||||
else if traces.size == 1 && traces[0]!.msg.isNest then
|
||||
currTraces ++ traces -- No nest of nest
|
||||
else
|
||||
let d := traces.foldl (init := MessageData.nil) fun d elem =>
|
||||
|
|
|
|||
|
|
@ -181,6 +181,9 @@ def insert' (m : HashMap α β) (a : α) (b : β) : HashMap α β × Bool :=
|
|||
@[inline] def getOp (self : HashMap α β) (idx : α) : Option β :=
|
||||
self.find? idx
|
||||
|
||||
instance : GetElem (HashMap α β) α (Option β) fun _ _ => True where
|
||||
getElem m k _ := m.find? k
|
||||
|
||||
@[inline] def contains (m : HashMap α β) (a : α) : Bool :=
|
||||
match m with
|
||||
| ⟨ m, _ ⟩ => m.contains a
|
||||
|
|
|
|||
|
|
@ -54,8 +54,8 @@ abbrev div2Shift (i : USize) (shift : USize) : USize := i.shiftRight shift
|
|||
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shiftLeft 1 shift) - 1)
|
||||
|
||||
partial def getAux [Inhabited α] : PersistentArrayNode α → USize → USize → α
|
||||
| node cs, i, shift => getAux (cs.get! (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift)
|
||||
| leaf cs, i, _ => cs.get! i.toNat
|
||||
| node cs, i, shift => getAux cs[(div2Shift i shift).toNat]! (mod2Shift i shift) (shift - initShift)
|
||||
| leaf cs, i, _ => cs[i.toNat]!
|
||||
|
||||
def get! [Inhabited α] (t : PersistentArray α) (i : Nat) : α :=
|
||||
if i >= t.tailOff then
|
||||
|
|
@ -66,6 +66,10 @@ def get! [Inhabited α] (t : PersistentArray α) (i : Nat) : α :=
|
|||
def getOp [Inhabited α] (self : PersistentArray α) (idx : Nat) : α :=
|
||||
self.get! idx
|
||||
|
||||
-- TODO: remove [Inhabited α]
|
||||
instance [Inhabited α] : GetElem (PersistentArray α) Nat α fun as i => i < as.size where
|
||||
getElem xs i _ := xs.get! i
|
||||
|
||||
partial def setAux : PersistentArrayNode α → USize → USize → α → PersistentArrayNode α
|
||||
| node cs, i, shift, a =>
|
||||
let j := div2Shift i shift
|
||||
|
|
|
|||
|
|
@ -106,8 +106,8 @@ partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize
|
|||
| ⟨Node.collision keys vals heq, _⟩ =>
|
||||
let rec traverse (i : Nat) (entries : Node α β) : Node α β :=
|
||||
if h : i < keys.size then
|
||||
let k := keys[⟨i, h⟩]
|
||||
let v := vals[⟨i, heq ▸ h⟩]
|
||||
let k := keys[i]
|
||||
let v := vals[i]'(heq ▸ h)
|
||||
let h := hash k |>.toUSize
|
||||
let h := div2Shift h (shift * (depth - 1))
|
||||
traverse (i+1) (insertAux entries h depth k v)
|
||||
|
|
@ -129,8 +129,8 @@ def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α →
|
|||
|
||||
partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
|
||||
if h : i < keys.size then
|
||||
let k' := keys[⟨i, h⟩]
|
||||
if k == k' then some vals[⟨i, by rw [←heq]; assumption⟩]
|
||||
let k' := keys[i]
|
||||
if k == k' then some (vals[i]'(by rw [←heq]; assumption))
|
||||
else findAtAux keys vals heq (i+1) k
|
||||
else none
|
||||
|
||||
|
|
@ -149,6 +149,9 @@ def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Op
|
|||
@[inline] def getOp {_ : BEq α} {_ : Hashable α} (self : PersistentHashMap α β) (idx : α) : Option β :=
|
||||
self.find? idx
|
||||
|
||||
instance {_ : BEq α} {_ : Hashable α} : GetElem (PersistentHashMap α β) α (Option β) fun _ _ => True where
|
||||
getElem m i _ := m.find? i
|
||||
|
||||
@[inline] def findD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (b₀ : β) : β :=
|
||||
(m.find? a).getD b₀
|
||||
|
||||
|
|
@ -159,8 +162,8 @@ def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Op
|
|||
|
||||
partial def findEntryAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option (α × β) :=
|
||||
if h : i < keys.size then
|
||||
let k' := keys[⟨i, h⟩]
|
||||
if k == k' then some (k', vals[⟨i, by rw [←heq]; assumption⟩])
|
||||
let k' := keys[i]
|
||||
if k == k' then some (k', vals[i]'(by rw [←heq]; assumption))
|
||||
else findEntryAtAux keys vals heq (i+1) k
|
||||
else none
|
||||
|
||||
|
|
@ -178,7 +181,7 @@ def findEntry? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α
|
|||
|
||||
partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool :=
|
||||
if h : i < keys.size then
|
||||
let k' := keys[⟨i, h⟩]
|
||||
let k' := keys[i]
|
||||
if k == k' then true
|
||||
else containsAtAux keys vals heq (i+1) k
|
||||
else false
|
||||
|
|
@ -197,7 +200,7 @@ def contains [BEq α] [Hashable α] : PersistentHashMap α β → α → Bool
|
|||
|
||||
partial def isUnaryEntries (a : Array (Entry α β (Node α β))) (i : Nat) (acc : Option (α × β)) : Option (α × β) :=
|
||||
if h : i < a.size then
|
||||
match a[⟨i, h⟩] with
|
||||
match a[i] with
|
||||
| Entry.null => isUnaryEntries a (i+1) acc
|
||||
| Entry.ref _ => none
|
||||
| Entry.entry k v =>
|
||||
|
|
@ -211,7 +214,8 @@ def isUnaryNode : Node α β → Option (α × β)
|
|||
| Node.collision keys vals heq =>
|
||||
if h : 1 = keys.size then
|
||||
have : 0 < keys.size := by rw [←h]; decide
|
||||
some (keys[⟨0, this⟩], vals[⟨0, by rw [←heq]; assumption⟩])
|
||||
have : 0 < vals.size := by rw [←heq]; assumption
|
||||
some (keys[0], vals[0])
|
||||
else
|
||||
none
|
||||
|
||||
|
|
@ -253,8 +257,8 @@ variable {σ : Type w}
|
|||
| Node.collision keys vals heq, acc =>
|
||||
let rec traverse (i : Nat) (acc : σ) : m σ := do
|
||||
if h : i < keys.size then
|
||||
let k := keys[⟨i, h⟩]
|
||||
let v := vals[⟨i, heq ▸ h⟩]
|
||||
let k := keys[i]
|
||||
let v := vals[i]'(heq ▸ h)
|
||||
traverse (i+1) (← f acc k v)
|
||||
else
|
||||
pure acc
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue