diff --git a/tests/lean/run/matrix.lean b/tests/lean/run/matrix.lean new file mode 100644 index 0000000000..70779d4893 --- /dev/null +++ b/tests/lean/run/matrix.lean @@ -0,0 +1,80 @@ +/- +Helper classes for Lean 3 users +-/ +class One (α : Type u) where + one : α + +instance [OfNat α (natLit! 1)] : One α where + one := 1 + +instance [One α] : OfNat α (natLit! 1) where + ofNat := One.one + +class Zero (α : Type u) where + zero : α + +instance [OfNat α (natLit! 0)] : Zero α where + zero := 0 + +instance [Zero α] : OfNat α (natLit! 0) where + ofNat := Zero.zero + +/- Simple Matrix -/ + +def Matrix (m n : Nat) (α : Type u) : Type u := + Fin m → Fin n → α + +namespace Matrix + +/- Scoped notation for accessing values stored in matrices. -/ +scoped syntax:max term noWs "[" term ", " term "]" : term + +macro_rules + | `($x[$i, $j]) => `($x $i $j) + +def dotProduct [Mul α] [Add α] [Zero α] (u v : Fin m → α) : α := + loop m (Nat.leRefl ..) Zero.zero +where + loop (i : Nat) (h : i ≤ m) (acc : α) : α := + match i, h with + | 0, h => acc + | i+1, h => + have i < m from Nat.ltOfLtOfLe (Nat.ltSuccSelf _) h + loop i (Nat.leOfLt this) (acc + u ⟨i, this⟩ * v ⟨i, this⟩) + +instance [Zero α] : Zero (Matrix m n α) where + zero _ _ := 0 + +instance [Add α] : Add (Matrix m n α) where + add x y i j := x[i, j] + y[i, j] + +instance [Mul α] [Add α] [Zero α] : HMul (Matrix m n α) (Matrix n p α) (Matrix m p α) where + hMul x y i j := dotProduct (x[i, ·]) (y[·, j]) + +instance [Mul α] : HMul α (Matrix m n α) (Matrix m n α) where + hMul c x i j := c * x[i, j] + +end Matrix + +def m1 : Matrix 2 2 Int := + fun i j => #[#[1, 2], #[3, 4]][i][j] + +def m2 : Matrix 2 2 Int := + fun i j => #[#[5, 6], #[7, 8]][i][j] + +open Matrix -- activate .[.,.] notation + +#eval (m1*m2)[0, 0] -- 19 +#eval (m1*m2)[0, 1] -- 22 +#eval (m1*m2)[1, 0] -- 43 +#eval (m1*m2)[1, 1] -- 50 + +def v := -2 + +#eval (v*m1*m2)[0, 0] -- -38 + +def ex1 (a b : Nat) (x : Matrix 10 20 Nat) (y : Matrix 20 10 Nat) (z : Matrix 10 10 Nat) : Matrix 10 10 Nat := + a * x * y + b * z + +def ex2 (a b : Nat) (x : Matrix m n Nat) (y : Matrix n m Nat) (z : Matrix m m Nat) : Matrix m m Nat := + a * x * y + b * z