test: matrix notation example
- Heterogeneous `*` for matrix and scalar multiplication - Homogeneous `+` for matrix addition - Whitespace sensitive `x[i, j]` notation
This commit is contained in:
parent
836fd46d90
commit
1d5df4f28b
1 changed files with 80 additions and 0 deletions
80
tests/lean/run/matrix.lean
Normal file
80
tests/lean/run/matrix.lean
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
/-
|
||||
Helper classes for Lean 3 users
|
||||
-/
|
||||
class One (α : Type u) where
|
||||
one : α
|
||||
|
||||
instance [OfNat α (natLit! 1)] : One α where
|
||||
one := 1
|
||||
|
||||
instance [One α] : OfNat α (natLit! 1) where
|
||||
ofNat := One.one
|
||||
|
||||
class Zero (α : Type u) where
|
||||
zero : α
|
||||
|
||||
instance [OfNat α (natLit! 0)] : Zero α where
|
||||
zero := 0
|
||||
|
||||
instance [Zero α] : OfNat α (natLit! 0) where
|
||||
ofNat := Zero.zero
|
||||
|
||||
/- Simple Matrix -/
|
||||
|
||||
def Matrix (m n : Nat) (α : Type u) : Type u :=
|
||||
Fin m → Fin n → α
|
||||
|
||||
namespace Matrix
|
||||
|
||||
/- Scoped notation for accessing values stored in matrices. -/
|
||||
scoped syntax:max term noWs "[" term ", " term "]" : term
|
||||
|
||||
macro_rules
|
||||
| `($x[$i, $j]) => `($x $i $j)
|
||||
|
||||
def dotProduct [Mul α] [Add α] [Zero α] (u v : Fin m → α) : α :=
|
||||
loop m (Nat.leRefl ..) Zero.zero
|
||||
where
|
||||
loop (i : Nat) (h : i ≤ m) (acc : α) : α :=
|
||||
match i, h with
|
||||
| 0, h => acc
|
||||
| i+1, h =>
|
||||
have i < m from Nat.ltOfLtOfLe (Nat.ltSuccSelf _) h
|
||||
loop i (Nat.leOfLt this) (acc + u ⟨i, this⟩ * v ⟨i, this⟩)
|
||||
|
||||
instance [Zero α] : Zero (Matrix m n α) where
|
||||
zero _ _ := 0
|
||||
|
||||
instance [Add α] : Add (Matrix m n α) where
|
||||
add x y i j := x[i, j] + y[i, j]
|
||||
|
||||
instance [Mul α] [Add α] [Zero α] : HMul (Matrix m n α) (Matrix n p α) (Matrix m p α) where
|
||||
hMul x y i j := dotProduct (x[i, ·]) (y[·, j])
|
||||
|
||||
instance [Mul α] : HMul α (Matrix m n α) (Matrix m n α) where
|
||||
hMul c x i j := c * x[i, j]
|
||||
|
||||
end Matrix
|
||||
|
||||
def m1 : Matrix 2 2 Int :=
|
||||
fun i j => #[#[1, 2], #[3, 4]][i][j]
|
||||
|
||||
def m2 : Matrix 2 2 Int :=
|
||||
fun i j => #[#[5, 6], #[7, 8]][i][j]
|
||||
|
||||
open Matrix -- activate .[.,.] notation
|
||||
|
||||
#eval (m1*m2)[0, 0] -- 19
|
||||
#eval (m1*m2)[0, 1] -- 22
|
||||
#eval (m1*m2)[1, 0] -- 43
|
||||
#eval (m1*m2)[1, 1] -- 50
|
||||
|
||||
def v := -2
|
||||
|
||||
#eval (v*m1*m2)[0, 0] -- -38
|
||||
|
||||
def ex1 (a b : Nat) (x : Matrix 10 20 Nat) (y : Matrix 20 10 Nat) (z : Matrix 10 10 Nat) : Matrix 10 10 Nat :=
|
||||
a * x * y + b * z
|
||||
|
||||
def ex2 (a b : Nat) (x : Matrix m n Nat) (y : Matrix n m Nat) (z : Matrix m m Nat) : Matrix m m Nat :=
|
||||
a * x * y + b * z
|
||||
Loading…
Add table
Reference in a new issue