Add more sparse math Ops in numba #1918

Open
tomicapretto wants to merge 6 commits into pymc-devs:main from tomicapretto:add-numba-sparse-math

Contributor

@tomicapretto tomicapretto commented Feb 26, 2026

Description

This PR implements the following sparse ops in numba:

  • SparseSparseMultiply
  • AddSS
  • AddSD
  • AddSSData
  • StructuredAddSV
  • Usmm
  • SamplingDot
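
To give a feel for what a specialized kernel for one of these ops can look like, here is a minimal sketch of the idea behind StructuredAddSV on a CSR matrix: the vector entry for column j is added only to the stored entries in column j, so the sparsity pattern is unchanged. The helper name `structured_add_s_v_csr` is hypothetical and this is not the code from the PR. It also assumes that every stored entry is nonzero; explicitly stored zeros may be treated differently by the reference implementation.

```python
import numpy as np


def structured_add_s_v_csr(data, indices, indptr, v):
    """Add v[j] to every stored entry (i, j) of a CSR matrix.

    Hypothetical helper for illustration only. The pattern arrays
    (indices, indptr) are reused unchanged by the caller; indptr is
    accepted only to mirror the usual CSR triple.
    """
    out_dtype = np.result_type(data.dtype, v.dtype)
    out = np.empty(data.shape[0], dtype=out_dtype)
    # In CSR, indices[k] is the column of the k-th stored value.
    for k in range(data.shape[0]):
        out[k] = data[k] + v[indices[k]]
    return out


# The 2x3 matrix [[1, 0, 2], [0, 3, 0]] in CSR form.
data = np.array([1.0, 2.0, 3.0])
indices = np.array([0, 2, 1])
indptr = np.array([0, 2, 3])
v = np.array([10.0, 20.0, 30.0])

new_data = structured_add_s_v_csr(data, indices, indptr, v)
```

The explicit loop is written in the style that a JIT compiler such as numba handles well; in plain NumPy the same result is `data + v[indices]`.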

Most of the initial implementations were generated by Codex; I checked and adapted them as needed.

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@tomicapretto tomicapretto force-pushed the add-numba-sparse-math branch 2 times, most recently from 622c700 to 93adb10 on February 26, 2026 15:35
@tomicapretto tomicapretto marked this pull request as ready for review February 26, 2026 16:16
@tomicapretto
Contributor Author

Some Usmm tests fail with the numba backend. The results differ from the reference values by more than the allowed tolerance.

Details
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csc-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csc-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csc-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csc-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csc-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csc-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csr-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csr-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csr-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csr-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csr-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[dense-csr-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-dense-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-dense-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-dense-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-dense-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-dense-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-dense-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csc-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csc-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csc-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csc-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csc-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csc-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csr-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csr-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csr-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csr-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csr-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csc-csr-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-dense-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-dense-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-dense-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-dense-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-dense-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-dense-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csc-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csc-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csc-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csc-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csc-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csc-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csr-False-float32-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csr-False-float32-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csr-False-float32-complex64] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csr-False-complex64-float32] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csr-False-complex64-int16] - tests.unittest_tools.WrongValue: WrongValue
FAILED tests/sparse/test_math.py::TestUsmm::test_basic[csr-csr-False-complex64-complex64] - tests.unittest_tools.WrongValue: WrongValue

I'm investigating whether this happens because some inputs need to be upcast before the math operations in the numba implementation.
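
As a rough illustration of why upcasting is a plausible suspect, the sketch below promotes the dtype pairs that appear in the failing test ids using NumPy's promotion rules. NumPy's `np.result_type` is used here only as a stand-in; PyTensor's own upcast logic and the exact meaning of the two dtypes in each test id are assumptions.

```python
import numpy as np

# dtype pairs copied from the failing test ids above
pairs = [
    ("float32", "float32"),
    ("float32", "int16"),
    ("float32", "complex64"),
    ("complex64", "float32"),
    ("complex64", "int16"),
    ("complex64", "complex64"),
]

# NumPy promotion as a proxy for the output dtype of each combination.
promoted = {pair: np.result_type(*pair).name for pair in pairs}
for pair, result in promoted.items():
    print(pair, "->", result)

# Machine epsilon: relative spacing between representable numbers.
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps
print("float32 eps:", eps32, "float64 eps:", eps64)
```

Under these rules every pair stays in single precision (float32 or complex64), where rounding errors are many orders of magnitude larger than in double precision, so small differences in the order of operations are more likely to exceed a tight tolerance.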

@tomicapretto
Contributor Author

The Python implementation of Usmm (its perform method) follows the definition of the operation more or less step by step. As far as I understand, Usmm only becomes faster when the local_usmm_csc_dense_inplace and/or local_usmm_csx rewrites are applied.

Something similar happens with SamplingDot: the Python implementation even computes the entire dense dot product, but the local_sampling_dot_csr rewrite replaces it with a special case implemented in C.

As far as I understand, these rewrites will not be triggered for the numba backend. Moreover, the numba implementations are already written in their specialized forms, so I don't think the rewrites are needed there.

Relevant code for the parts mentioned above:

Usmm

def perform(self, node, inputs, outputs):
    (alpha, x, y, z) = inputs
    (out,) = outputs
    x_is_sparse = psb._is_sparse(x)
    y_is_sparse = psb._is_sparse(y)

    if not x_is_sparse and not y_is_sparse:
        raise TypeError(x)

    rval = x * y
    if isinstance(rval, scipy_sparse.spmatrix):
        rval = rval.toarray()
    if rval.dtype == alpha.dtype:
        rval *= alpha  # Faster because operation is inplace
    else:
        rval = rval * alpha
    if rval.dtype == z.dtype:
        rval += z  # Faster because operation is inplace
    else:
        rval = rval + z

    out[0] = rval

@node_rewriter([usmm_csc_dense])
def local_usmm_csc_dense_inplace(fgraph, node):
    if node.op == usmm_csc_dense:
        return [usmm_csc_dense_inplace(*node.inputs)]


register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace")

# This is tested in tests/test_basic.py:UsmmTests
@node_rewriter([spm.usmm])
def local_usmm_csx(fgraph, node):
    """
    usmm -> usmm_csc_dense
    """
    if node.op == usmm:
        alpha, x, y, z = node.inputs

        x_is_sparse_variable = _is_sparse_variable(x)
        y_is_sparse_variable = _is_sparse_variable(y)

        if x_is_sparse_variable and not y_is_sparse_variable:
            if x.type.format == "csc":
                x_val, x_ind, x_ptr, x_shape = csm_properties(x)
                x_nsparse = x_shape[0]
                dtype_out = ps.upcast(
                    alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype
                )
                if dtype_out not in ("float32", "float64"):
                    return False
                # Sparse cast is not implemented.
                if y.type.dtype != dtype_out:
                    return False

                return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, x_nsparse, y, z)]
    return False


register_specialize(local_usmm_csx, "cxx_only")
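
For comparison with the generic perform method, here is a minimal sketch of what a specialized Usmm kernel can look like when x is CSR and y and z are dense. It accumulates alpha * (x @ y) directly into a copy of z without building the intermediate product. The function name `usmm_csr_dense_sketch` and its signature are hypothetical; this is not the PR's numba code, and the choice to do all arithmetic in the promoted output dtype is an assumption.

```python
import numpy as np


def usmm_csr_dense_sketch(alpha, x_data, x_indices, x_indptr, y, z):
    """Return z + alpha * (x @ y) for a CSR matrix x and dense y, z.

    Hypothetical helper for illustration only.
    """
    out_dtype = np.result_type(
        np.asarray(alpha).dtype, x_data.dtype, y.dtype, z.dtype
    )
    a = out_dtype.type(alpha)
    out = z.astype(out_dtype)  # astype copies, so z itself is not modified
    n_rows = x_indptr.shape[0] - 1
    n_cols = y.shape[1]
    for i in range(n_rows):
        # Stored entries of row i of x live in positions indptr[i]:indptr[i + 1].
        for k in range(x_indptr[i], x_indptr[i + 1]):
            scaled = a * x_data[k]
            col = x_indices[k]  # column of x, hence row of y
            for j in range(n_cols):
                out[i, j] += scaled * y[col, j]
    return out


# x = [[1, 0, 2], [0, 3, 0]] in CSR form.
x_data = np.array([1.0, 2.0, 3.0])
x_indices = np.array([0, 2, 1])
x_indptr = np.array([0, 2, 3])
y = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
z = np.ones((2, 2))
alpha = 2.0

result = usmm_csr_dense_sketch(alpha, x_data, x_indices, x_indptr, y, z)

# Dense reference for the same computation.
x_dense = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]])
expected = z + alpha * (x_dense @ y)
```

A kernel of this shape never materializes x @ y, which is the kind of saving the C rewrites above provide for the default backend.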

SamplingDot

def perform(self, node, inputs, outputs):
    (x, y, p) = inputs
    (out,) = outputs

    if psb._is_sparse(x):
        raise TypeError(x)

    if psb._is_sparse(y):
        raise TypeError(y)

    if not psb._is_sparse(p):
        raise TypeError(p)

    out[0] = p.__class__(p.multiply(np.dot(x, y.T)))

def local_sampling_dot_csr(fgraph, node):
    if not config.blas__ldflags:
        # The C implementation of SamplingDotCsr relies on BLAS routines
        return

    if node.op == spm.sampling_dot:
        x, y, p = node.inputs
        if p.type.format == "csr":
            p_data, p_ind, p_ptr, p_shape = sparse.csm_properties(p)

            z_data, z_ind, z_ptr = sampling_dot_csr(
                x, y, p_data, p_ind, p_ptr, p_shape[1]
            )
            # This is a hack that works around some missing `Type`-related
            # static shape narrowing. More specifically,
            # `TensorType.convert_variable` currently won't combine the static
            # shape information from `old_out.type` and `new_out.type`, only
            # the broadcast patterns, and, since `CSR.make_node` doesn't do
            # that either, we use `specify_shape` to produce an output `Type`
            # with the same level of static shape information as the original
            # `old_out`.
            old_out = node.outputs[0]
            new_out = specify_shape(
                sparse.CSR(z_data, z_ind, z_ptr, p_shape), shape(old_out)
            )
            return [new_out]
    return False
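
The difference between the generic SamplingDot.perform and a specialized kernel can be sketched as follows: instead of computing the full dense product x @ y.T and then masking it with p, only the dot products at the positions stored in p are evaluated. The helper `sampling_dot_csr_sketch` is hypothetical and is not the C or numba implementation; it assumes p is given as CSR arrays.

```python
import numpy as np


def sampling_dot_csr_sketch(x, y, p_data, p_indices, p_indptr):
    """Return the data array of p * (x @ y.T), using p's CSR pattern.

    Hypothetical helper for illustration only. The output shares
    p's indices and indptr.
    """
    out_dtype = np.result_type(x.dtype, y.dtype, p_data.dtype)
    out = np.zeros(p_data.shape[0], dtype=out_dtype)
    n_rows = p_indptr.shape[0] - 1
    for i in range(n_rows):
        for k in range(p_indptr[i], p_indptr[i + 1]):
            j = p_indices[k]
            # Dot product of row i of x with row j of y, i.e. (x @ y.T)[i, j].
            acc = out_dtype.type(0)
            for t in range(x.shape[1]):
                acc += x[i, t] * y[j, t]
            out[k] = p_data[k] * acc
    return out


x = np.array([[1.0, 2.0], [3.0, 4.0]])
y = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# p = [[1, 0, 2], [0, 3, 0]] in CSR form.
p_data = np.array([1.0, 2.0, 3.0])
p_indices = np.array([0, 2, 1])
p_indptr = np.array([0, 2, 3])

z_data = sampling_dot_csr_sketch(x, y, p_data, p_indices, p_indptr)

# Dense reference: mask the full product with p.
p_dense = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]])
dense_result = p_dense * (x @ y.T)
```

The work done by the sketch scales with the number of stored entries in p rather than with the full size of x @ y.T.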

@tomicapretto
Contributor Author

tomicapretto commented Feb 26, 2026

@ricardoV94, now the same tests fail with the default backend. I'm not sure which approach to follow here. Two ideas come to mind:

  • Implement a numba-specific result computation function that casts the values to the output dtype before doing any math, like this:
        def f_b(z, a, x, y):
            # Make sure operations are done with the precision of the output dtype
            x = x.astype(out_dtype)
            y = y.astype(out_dtype)
            z = z.astype(out_dtype)
            a = a.astype(out_dtype)
            return z - a * (x * y)
  • Use the previous function in all cases, but modify Usmm.perform to also upcast x and y before doing x @ y in z + a * (x @ y)

Given how Usmm is implemented in the numba backend, it is difficult, and possibly impossible, to match the default backend’s behavior. In the default backend, if x and y are float32, the x @ y operation is performed in float32, and upcasting only happens afterward, before multiplying by a or adding to z, not before the matrix multiplication itself.
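
The effect of the two orders of operation can be reproduced with plain NumPy. The sketch below uses dense arrays and random data as a stand-in for the sparse case; the sizes, the seed, and the use of `z + a * (x @ y)` as the expression are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 64)).astype(np.float32)
y = rng.normal(size=(64, 64)).astype(np.float32)
z = rng.normal(size=(64, 64))  # float64, so the output dtype is float64
a = np.float64(0.5)

# Order 1: the matrix product runs in float32, and the upcast to float64
# happens only when the product is combined with a and z.
late_upcast = z + a * (x @ y)

# Order 2: x and y are upcast first, so the whole computation is float64.
early_upcast = z + a * (x.astype(np.float64) @ y.astype(np.float64))

max_abs_diff = np.max(np.abs(late_upcast - early_upcast))
print(late_upcast.dtype, early_upcast.dtype, max_abs_diff)
```

Both results have dtype float64, but they are not bitwise identical, because the first one carries the float32 rounding error of the matrix product. A test that compares the two with a float64-level tolerance can therefore fail even though each computation is correct at its own precision.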
