From 7fc12014dae3c4622348f33ff8be5eed2a95fb20 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 9 Feb 2022 17:05:45 -0800 Subject: [PATCH] fix: make sure splitTarget? skips `match` expressions that produce type errors at `splitMatch` We can now generate the equation theorem for ``` attribute [simp] Array.heapSort.loop ``` see #998 --- src/Lean/Meta/Tactic/Split.lean | 29 ++++-- tests/lean/run/heapSort.lean | 176 ++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+), 11 deletions(-) create mode 100644 tests/lean/run/heapSort.lean diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index 0307dd6fef..0683240757 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -119,19 +119,21 @@ def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do throwNestedTacticEx `splitMatch ex /-- Return an `if-then-else` or `match-expr` to split. -/ -partial def findSplit? (env : Environment) (e : Expr) : Option Expr := +partial def findSplit? (env : Environment) (e : Expr) (exceptionSet : ExprSet := {}) : Option Expr := if let some target := e.find? isCandidate then if e.isIte || e.isDIte then let cond := target.getArg! 1 5 -- Try to find a nested `if` in `cond` - findSplit? env cond |>.getD target + findSplit? env cond exceptionSet |>.getD target else some target else none where isCandidate (e : Expr) : Bool := Id.run <| do - if e.isIte || e.isDIte then + if exceptionSet.contains e then + false + else if e.isIte || e.isDIte then !(e.getArg! 1 5).hasLooseBVars else if let some info := isMatcherAppCore? env e then let args := e.getAppArgs @@ -146,15 +148,20 @@ end Split open Split -def splitTarget? (mvarId : MVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do - if let some e := findSplit? (← getEnv) (← instantiateMVars (← getMVarType mvarId)) then - if e.isIte || e.isDIte then - return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId] +partial def splitTarget? (mvarId : MVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do + let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do + if let some e := findSplit? (← getEnv) (← instantiateMVars (← getMVarType mvarId)) badCases then + if e.isIte || e.isDIte then + return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId] + else + try + splitMatch mvarId e + catch _ => + go (badCases.insert e) else - splitMatch mvarId e - else - trace[Meta.Tactic.split] "did not find term to split\n{MessageData.ofGoal mvarId}" - return none + trace[Meta.Tactic.split] "did not find term to split\n{MessageData.ofGoal mvarId}" + return none + go {} def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do withMVarContext mvarId do diff --git a/tests/lean/run/heapSort.lean b/tests/lean/run/heapSort.lean new file mode 100644 index 0000000000..1b8f2db02a --- /dev/null +++ b/tests/lean/run/heapSort.lean @@ -0,0 +1,176 @@ +/- +Copyright (c) 2021 Mario Carneiro. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Mario Carneiro +-/ + +/-- A max-heap data structure. -/ +structure BinaryHeap (α) (lt : α → α → Bool) where + arr : Array α + +namespace BinaryHeap + +/-- Core operation for binary heaps, expressed directly on arrays. +Given an array which is a max-heap, push item `i` down to restore the max-heap property. -/ +def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : + {a' : Array α // a'.size = a.size} := + let left := 2 * i.1 + 1 + let right := left + 1 + have left_le : i ≤ left := Nat.le_trans + (by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i) + (Nat.le_add_right ..) + have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..) + have i_le : i ≤ i := Nat.le_refl _ + have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then + if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩ + have j := if h : right < a.size then + if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j + if h : i.1 = j then ⟨a, rfl⟩ else + let a' := a.swap i j + let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2⟩ + have : a'.size - j < a.size - i := by + rw [a.size_swap i j]; sorry + let ⟨a₂, h₂⟩ := heapifyDown lt a' j' + ⟨a₂, h₂.trans (a.size_swap i j)⟩ +termination_by _ => a.size - i +decreasing_by assumption + +@[simp] theorem size_heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : + (heapifyDown lt a i).1.size = a.size := (heapifyDown lt a i).2 + +/-- Core operation for binary heaps, expressed directly on arrays. +Construct a heap from an unsorted array, by heapifying all the elements. -/ +def mkHeap (lt : α → α → Bool) (a : Array α) : {a' : Array α // a'.size = a.size} := + let rec loop : (i : Nat) → (a : Array α) → i ≤ a.size → {a' : Array α // a'.size = a.size} + | 0, a, _ => ⟨a, rfl⟩ + | i+1, a, h => + let h := Nat.lt_of_succ_le h + let a' := heapifyDown lt a ⟨i, h⟩ + let ⟨a₂, h₂⟩ := loop i a' ((heapifyDown ..).2.symm ▸ Nat.le_of_lt h) + ⟨a₂, h₂.trans a'.2⟩ + loop (a.size / 2) a sorry + +@[simp] theorem size_mkHeap (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : + (mkHeap lt a).1.size = a.size := (mkHeap lt a).2 + +/-- Core operation for binary heaps, expressed directly on arrays. +Given an array which is a max-heap, push item `i` up to restore the max-heap property. -/ +def heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : + {a' : Array α // a'.size = a.size} := +if i0 : i.1 = 0 then ⟨a, rfl⟩ else + have : (i.1 - 1) / 2 < i := sorry + let j := ⟨(i.1 - 1) / 2, Nat.lt_trans this i.2⟩ + if lt (a.get j) (a.get i) then + let a' := a.swap i j + let ⟨a₂, h₂⟩ := heapifyUp lt a' ⟨j.1, by rw [a.size_swap i j]; exact j.2⟩ + ⟨a₂, h₂.trans (a.size_swap i j)⟩ + else ⟨a, rfl⟩ +termination_by _ => i.1 +decreasing_by assumption + +@[simp] theorem size_heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : + (heapifyUp lt a i).1.size = a.size := (heapifyUp lt a i).2 + +/-- `O(1)`. Build a new empty heap. -/ +def empty (lt) : BinaryHeap α lt := ⟨#[]⟩ + +instance (lt) : Inhabited (BinaryHeap α lt) := ⟨empty _⟩ +instance (lt) : EmptyCollection (BinaryHeap α lt) := ⟨empty _⟩ + +/-- `O(1)`. Build a one-element heap. -/ +def singleton (lt) (x : α) : BinaryHeap α lt := ⟨#[x]⟩ + +/-- `O(1)`. Get the number of elements in a `BinaryHeap`. -/ +def size {lt} (self : BinaryHeap α lt) : Nat := self.1.size + +/-- `O(1)`. Get an element in the heap by index. -/ +def get {lt} (self : BinaryHeap α lt) (i : Fin self.size) : α := self.1.get i + +/-- `O(log n)`. Insert an element into a `BinaryHeap`, preserving the max-heap property. -/ +def insert {lt} (self : BinaryHeap α lt) (x : α) : BinaryHeap α lt where + arr := let n := self.size; + heapifyUp lt (self.1.push x) ⟨n, by rw [Array.size_push]; apply Nat.lt_succ_self⟩ + +@[simp] theorem size_insert {lt} (self : BinaryHeap α lt) (x : α) : + (self.insert x).size = self.size + 1 := by + simp [insert, size, size_heapifyUp] + +/-- `O(1)`. Get the maximum element in a `BinaryHeap`. -/ +def max {lt} (self : BinaryHeap α lt) : Option α := self.1.get? 0 + +/-- Auxiliary for `popMax`. -/ +def popMaxAux {lt} (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size = self.size - 1} := + match e: self.1.size with + | 0 => ⟨self, by simp [size, e]⟩ + | n+1 => + have h0 := by rw [e]; apply Nat.succ_pos + have hn := by rw [e]; apply Nat.lt_succ_self + if hn0 : 0 < n then + let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop + ⟨⟨heapifyDown lt a ⟨0, sorry⟩⟩, + by simp [size]⟩ + else + ⟨⟨self.1.pop⟩, by simp [size]⟩ + +/-- `O(log n)`. Remove the maximum element from a `BinaryHeap`. +Call `max` first to actually retrieve the maximum element. -/ +def popMax {lt} (self : BinaryHeap α lt) : BinaryHeap α lt := self.popMaxAux + +@[simp] theorem size_popMax {lt} (self : BinaryHeap α lt) : + self.popMax.size = self.size - 1 := self.popMaxAux.2 + +/-- `O(log n)`. Return and remove the maximum element from a `BinaryHeap`. -/ +def extractMax {lt} (self : BinaryHeap α lt) : Option α × BinaryHeap α lt := + (self.max, self.popMax) + +theorem size_pos_of_max {lt} {self : BinaryHeap α lt} (e : self.max = some x) : 0 < self.size := + Decidable.of_not_not fun h: ¬ 0 < self.1.size => by simp [BinaryHeap.max, Array.get?, h] at e + +/-- `O(log n)`. Equivalent to `extractMax (self.insert x)`, except that extraction cannot fail. -/ +def insertExtractMax {lt} (self : BinaryHeap α lt) (x : α) : α × BinaryHeap α lt := + match e: self.max with + | none => (x, self) + | some m => + if lt x m then + let a := self.1.set ⟨0, size_pos_of_max e⟩ x + (m, ⟨heapifyDown lt a ⟨0, by simp; exact size_pos_of_max e⟩⟩) + else (x, self) + +/-- `O(log n)`. Equivalent to `(self.max, self.popMax.insert x)`. -/ +def replaceMax {lt} (self : BinaryHeap α lt) (x : α) : Option α × BinaryHeap α lt := + match e: self.max with + | none => (none, ⟨self.1.push x⟩) + | some m => + let a := self.1.set ⟨0, size_pos_of_max e⟩ x + (some m, ⟨heapifyDown lt a ⟨0, by simp; exact size_pos_of_max e⟩⟩) + +/-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `x ≤ self.get i`. -/ +def decreaseKey {lt} (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where + arr := heapifyDown lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩ + +/-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `self.get i ≤ x`. -/ +def increaseKey {lt} (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where + arr := heapifyUp lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩ + +end BinaryHeap + +/-- `O(n)`. Convert an unsorted array to a `BinaryHeap`. -/ +def Array.toBinaryHeap (lt : α → α → Bool) (a : Array α) : BinaryHeap α lt where + arr := BinaryHeap.mkHeap lt a + +/-- `O(n log n)`. Sort an array using a `BinaryHeap`. -/ +@[specialize] def Array.heapSort (a : Array α) (lt : α → α → Bool) : Array α := + let gt y x := lt x y + let rec loop (a : BinaryHeap α gt) (out : Array α) : Array α := + match e:a.max with + | none => out + | some x => + have : a.popMax.size < a.size := by + simp; exact Nat.sub_lt (BinaryHeap.size_pos_of_max e) Nat.zero_lt_one + loop a.popMax (out.push x) + loop (a.toBinaryHeap gt) #[] +termination_by _ => a.size +decreasing_by assumption + +attribute [simp] Array.heapSort.loop +#check @Array.heapSort.loop._eq_1