lean4-htt/tests/lean/run/partial_fixpoint_probability.lean
Joachim Breitner 7b813d4f5d
feat: partial_fixpoint: partial functions with equations (#6355)
This PR adds the ability to define possibly non-terminating functions
and still be able to reason about them equationally, as long as they are
tail-recursive or monadic.

Typical uses of this feature are
```lean4
def ack : (n m : Nat) → Option Nat
  | 0,   y   => some (y+1)
  | x+1, 0   => ack x 1
  | x+1, y+1 => do ack x (← ack (x+1) y)
partial_fixpoint

def whileSome (f : α → Option α) (x : α) : α :=
  match f x with
  | none => x
  | some x' => whileSome f x'
partial_fixpoint

def computeLfp {α : Type u} [DecidableEq α] (f : α → α) (x : α) : α :=
  let next := f x
  if x ≠ next then
    computeLfp f next
  else
    x
partial_fixpoint

noncomputable def geom : Distr Nat := do
  let head ← coin
  if head then
    return 0
  else
    let n ← geom
    return (n + 1)
partial_fixpoint
```

This PR contains

* The necessary fragment of domain theory, up to (a variant of)
the Knaster–Tarski theorem (merged as
https://github.com/leanprover/lean4/pull/6477)
* A tactic to solve monotonicity goals compositionally (a bit like
mathlib’s `fun_prop`) (merged as
https://github.com/leanprover/lean4/pull/6506)
* An attribute to extend that tactic (merged as
https://github.com/leanprover/lean4/pull/6506)
* A “derecursifier” that uses that machinery to define recursive
functions, including support for dependent functions and mutual
recursion.
* Fixed-point induction principles (technical, tedious to use)
* For `Option`-valued functions: Partial correctness induction theorems
that hide all the domain theory

This is heavily inspired by [Isabelle’s `partial_function`
command](https://isabelle.in.tum.de/doc/codegen.pdf).
2025-01-21 09:54:30 +00:00

135 lines
4.3 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

-- Core Lean has no `ENNReal`, so this test axiomatizes exactly the fragment it needs:
-- an opaque carrier type, sums and suprema over arbitrary index types, the arithmetic
-- operations and constants, a few `simp` rewrite rules, and a partial order.
opaque ENNReal : Type

/-- Supremum of a family of values. -/
axiom ENNReal.sup {α} : (α → ENNReal) → ENNReal
/-- Sum of a family of values. -/
axiom ENNReal.sum {α} : (α → ENNReal) → ENNReal
axiom ENNReal.add : ENNReal → ENNReal → ENNReal
axiom ENNReal.mul : ENNReal → ENNReal → ENNReal
noncomputable instance : Add ENNReal := ⟨ENNReal.add⟩
noncomputable instance : Mul ENNReal := ⟨ENNReal.mul⟩

axiom ENNReal.zero : Zero ENNReal
attribute [instance] ENNReal.zero
axiom ENNReal.one : ENNReal
axiom ENNReal.one_half : ENNReal

-- Simplification rules (all tagged `@[simp]`).
@[simp] axiom ENNReal.mul_one (x : ENNReal) : x * ENNReal.one = x
@[simp] axiom ENNReal.mul_zero (x : ENNReal) : x * 0 = 0
@[simp] axiom ENNReal.add_zero (x : ENNReal) : x + 0 = x
@[simp] axiom ENNReal.zero_add (x : ENNReal) : 0 + x = x
/-- A sum over `Bool` has exactly two summands. -/
@[simp] axiom ENNReal.sum_bool (f : Bool → ENNReal) : ENNReal.sum f = f true + f false
@[simp] axiom ENNReal.sum_const_zero (α : Sort _) : ENNReal.sum (fun (_ : α) => 0) = 0
/-- Summing a function that vanishes everywhere except at `y` yields its value at `y`. -/
@[simp] axiom ENNReal.sum_dirac (α : Sort _) [DecidableEq α] (f : α → ENNReal) (y : α) :
    ENNReal.sum (fun x => if x = y then f x else 0) = f y

/-- The order on `ENNReal`, axiomatized below as reflexive, transitive and antisymmetric. -/
axiom ENNReal.le : LE ENNReal
attribute [instance] ENNReal.le
axiom ENNReal.le_refl (x : ENNReal) : x ≤ x
axiom ENNReal.le_trans {x y z : ENNReal} : x ≤ y → y ≤ z → x ≤ z
axiom ENNReal.le_antisymm {x y : ENNReal} : x ≤ y → y ≤ x → x = y
-- Monotonicity and supremum axioms. The hypothesis binders (`h`, `h₁`, `h₂`) are named
-- but never referenced in the conclusions, so the unused-variables linter is switched
-- off for this section only.
section
set_option linter.unusedVariables false

/-- `ENNReal.sum` is monotone under pointwise `≤`. -/
axiom ENNReal.sum_mono {α} (s₁ s₂ : α → ENNReal) (h : ∀ x, s₁ x ≤ s₂ x) :
    ENNReal.sum s₁ ≤ ENNReal.sum s₂

/-- `ENNReal.sup` is monotone under pointwise `≤`. -/
axiom ENNReal.sup_mono {α} (s₁ s₂ : α → ENNReal) (h : ∀ x, s₁ x ≤ s₂ x) :
    ENNReal.sup s₁ ≤ ENNReal.sup s₂

/-- Multiplication is monotone in both factors. The fourth value is named `Distr`; it is
an ordinary `ENNReal`, unrelated to the `Distr` type defined later in this file. -/
axiom ENNReal.mul_mono (a b c Distr : ENNReal) (h₁ : a ≤ c) (h₂ : b ≤ Distr) :
    a * b ≤ c * Distr

/-- Anything below some member `s i` of the family is below `ENNReal.sup s`. -/
axiom ENNReal.le_sup {α} (a : ENNReal) (s : α → ENNReal) (i : α) (h : a ≤ s i) :
    a ≤ ENNReal.sup s

/-- A common upper bound of every member of the family bounds `ENNReal.sup s`. -/
axiom ENNReal.sup_le {α} (a : ENNReal) (s : α → ENNReal) (h : ∀ (i : α), s i ≤ a) :
    ENNReal.sup s ≤ a
end
/--
Distributions over `α`: a weight in `ENNReal` for every outcome.

They are deliberately *not* normalized: the weights need not sum to one. This is crucial,
because normalized distributions would have no least element `⊥`, and the fixpoint
construction used by `partial_fixpoint` needs one.
-/
def Distr (α : Type) : Type := α → ENNReal
/-- Flattens a distribution over distributions. The weight of `x` is the sum, over every
inner distribution `d`, of `d x` scaled by the weight `dd d` of choosing `d`. -/
noncomputable def Distr.join {α : Type} : Distr (Distr α) → Distr α :=
  fun dd x => ENNReal.sum fun d => d x * dd d
/-- Push-forward along `f`: the weight of `x` is the total weight of all outcomes `y`
with `f y = x`. The equality test is decided classically. -/
noncomputable instance : Functor Distr where
  map f d x := ENNReal.sum fun y => open Classical in if f y = x then d y else 0
/-- Point mass: weight `ENNReal.one` on `x` and `0` on every other outcome. The equality
test is decided classically. -/
noncomputable instance : Pure Distr where
  pure x y := open Classical in if x = y then .one else 0
/-- Monadic bind: the weight of `x` is the sum, over every intermediate outcome `y`, of the
weight `d y` times the weight of `x` under `f y`. -/
noncomputable instance : Bind Distr where
  bind d f x := ENNReal.sum fun y => d y * f y x
-- `PartialOrder`, `CCPO` and `MonoBind` used below live in `Lean.Order`.
open Lean.Order

/-- Distributions are ordered pointwise: `d₁ ⊑ d₂` iff `d₁ x ≤ d₂ x` for every outcome `x`.
Each law is inherited pointwise from the corresponding `ENNReal` axiom. -/
noncomputable instance : PartialOrder (Distr α) where
  rel d₁ d₂ := ∀ x, d₁ x ≤ d₂ x
  rel_refl := fun _ => ENNReal.le_refl _
  rel_trans h₁ h₂ x := ENNReal.le_trans (h₁ x) (h₂ x)
  rel_antisymm h₁ h₂ := funext fun x => ENNReal.le_antisymm (h₁ x) (h₂ x)
/-- Distributions form a chain-complete partial order. The supremum of a chain `c` is
computed pointwise, as `ENNReal.sup` over the members of `c`. -/
noncomputable instance : CCPO (Distr α) where
  csup c x := ENNReal.sup fun (d : Subtype c) => d.val x
  csup_spec := by
    -- The chain hypothesis is never used: the argument works for any set `c`.
    intro d₁ c _hchain
    constructor
    -- (→) Each member `d₂` of `c` lies below `csup c`, and therefore below `d₁`.
    · intro h d₂ hd₂ x
      apply ENNReal.le_trans ?_ (h x)
      apply ENNReal.le_sup (i := ⟨d₂, hd₂⟩)
      apply ENNReal.le_refl
    -- (←) A common upper bound of all members bounds the pointwise supremum.
    · intro h x
      apply ENNReal.sup_le
      intro d
      apply h d.1 d.2 x
/-- `bind` on distributions is monotone in both arguments. This is what `partial_fixpoint`
needs to accept monadic recursion such as `geom`. -/
noncomputable instance : MonoBind Distr where
  bind_mono_left := by
    intro α β d₁ d₂ f h₁₂ y
    -- Expose the `ENNReal.sum` underneath `bind` before applying `sum_mono`.
    unfold bind instBindDistr
    dsimp
    apply ENNReal.sum_mono
    intro x
    apply ENNReal.mul_mono
    · apply h₁₂
    · apply ENNReal.le_refl
  bind_mono_right := by
    intro α β d f₁ f₂ h₁₂ y
    -- Here `sum_mono` is applied directly, without unfolding `bind` first.
    apply ENNReal.sum_mono
    intro x
    apply ENNReal.mul_mono
    · apply ENNReal.le_refl
    · apply h₁₂
/-- A fair coin: both `true` and `false` get weight `ENNReal.one_half`. -/
noncomputable def coin : Distr Bool := fun _ => ENNReal.one_half
/-- Geometric distribution. It flips `coin`: on `true` it returns `0`; on `false` it samples
`geom` again and returns the result plus one. The recursive call has no decreasing argument,
so the definition is admitted with `partial_fixpoint` instead of a termination proof. -/
noncomputable def geom : Distr Nat := do
let head ← coin
if head then
return 0
else
let n ← geom
return (n + 1)
partial_fixpoint
-- `partial_fixpoint` generates the unfolding equation `geom.eq_1`; pin its exact statement
-- (the expected `#check` output is the docstring immediately below).
/--
info: geom.eq_1 :
geom = do
let head ← coin
if head = true then pure 0
else do
let n ← geom
pure (n + 1)
-/
#guard_msgs in
#check geom.eq_1
-- And we can do proofs about this: unfold `geom` once with `rw [geom]`, then simplify.
/-- `geom` puts weight `ENNReal.one_half` on `0`. -/
theorem geom_0 : geom 0 = .one_half := by
  rw [geom]
  simp [bind, coin, pure]
/-- The weight of `n + 1` under `geom` is `ENNReal.one_half` times the weight of `n`. -/
theorem geom_succ {n : Nat} : geom (n+1) = .one_half * geom n := by
  -- Unfold only the left-hand `geom`; the `geom n` on the right must stay folded.
  conv =>
    lhs
    rw [geom]
  simp [bind, coin, pure, apply_ite]