Skip to content

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Nov 7, 2024
1 parent 6f88249 commit 1696069
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ It will return a dense array X.
Refer to the pardiso documentation for detailed description of options.
Consider this wrapper to be experimental.

#### SciPy Classes

`csr_array`, `csr_matrix`, `csc_array`, `csc_matrix`, `bsr_array`, `bsr_matrix`

Scipy sparse classes where `__matmul__` and `__rmatmul__` have been replaced to use MKL
for matrix math

#### Service Functions

Several service functions are available and can be imported from the base `sparse_dot_mkl` package.
Expand Down
43 changes: 25 additions & 18 deletions sparse_dot_mkl/tests/test_scipy_classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import numpy.testing as npt
import scipy as sp
import scipy.sparse as sps
from types import MethodType

Expand Down Expand Up @@ -77,28 +78,34 @@ def test_matmul_fail(self):
with self.assertRaises(ValueError):
b @ a

m1 = MATRIX_1.copy()
m2 = MATRIX_2.copy()

install_wire(m1)
install_wire(m2)
install_wire(a)
install_wire(b)
# Following tests dont work with old scipy
if (
(int(sp.__version__.split('.')[1]) > 1) or
(int(sp.__version__.split('.')[1]) > 13)
):

m1 = MATRIX_1.copy()
m2 = MATRIX_2.copy()

install_wire(m1)
install_wire(m2)
install_wire(a)
install_wire(b)
# SCIPY
with self.assertRaises(TripError):
m1 @ m2

# SCIPY
with self.assertRaises(TripError):
m1 @ m2
# SCIPY CSR_MATRIX USES RMATMUL DUNNO WHY
if self.arr != csr_matrix:
with self.assertRaises(TripError):
m1 @ b

# SCIPY CSR_MATRIX USES RMATMUL DUNNO WHY
if self.arr != csr_matrix:
with self.assertRaises(TripError):
m1 @ b
# MKL
a @ m2

# MKL
a @ m2
# MKL
a @ b

# MKL
a @ b

class TestCSRMat(TestCSR):
arr = csr_matrix
Expand Down

0 comments on commit 1696069

Please sign in to comment.