Skip to content

Commit db08aef

Browse files
committed
Removed deprecated functions
1 parent 5a413b3 commit db08aef

4 files changed

Lines changed: 9 additions & 9 deletions

File tree

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="xlb",
5-
version="0.3.1",
5+
version="0.3.2",
66
description="XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML",
77
long_description=open("README.md").read(),
88
long_description_content_type="text/markdown",
@@ -19,11 +19,11 @@
1919
"numpy-stl>=3.1.2",
2020
"pydantic>=2.9.1",
2121
"ruff>=0.14.1",
22-
"jax>=0.8.0", # Base JAX CPU-only requirement
22+
"jax>=0.8.2", # Base JAX CPU-only requirement
2323
],
2424
extras_require={
25-
"cuda": ["jax[cuda13]>=0.8.0"], # For CUDA installations
26-
"tpu": ["jax[tpu]>=0.8.0"], # For TPU installations
25+
"cuda": ["jax[cuda13]>=0.8.2"], # For CUDA installations (pip install -U "jax[cuda13]")
26+
"tpu": ["jax[tpu]>=0.8.2"], # For TPU installations
2727
"test": ["pytest>=8.0.0"],
2828
},
2929
python_requires=">=3.11",

xlb/distribute/distribute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from xlb.operator.stepper import IncompressibleNavierStokesStepper
44
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
55
from jax import lax
6-
from jax.experimental.shard_map import shard_map
6+
from jax import shard_map
77
from jax import jit
88

99

@@ -72,7 +72,7 @@ def _wrapped_operator(*args):
7272
mesh=grid.global_mesh,
7373
in_specs=in_specs,
7474
out_specs=out_specs,
75-
check_rep=False,
75+
check_vma=False,
7676
)
7777
return distributed_operator(*args)
7878

xlb/operator/parallel_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from jax.experimental.shard_map import shard_map
1+
from jax import shard_map
22
from jax.sharding import PartitionSpec as P
33
from jax import lax
44

@@ -47,7 +47,7 @@ def __call__(self, f):
4747
mesh=self.grid.global_mesh,
4848
in_specs=in_specs,
4949
out_specs=out_specs,
50-
check_rep=False,
50+
check_vma=False,
5151
)(f)
5252
return f
5353

xlb/operator/stepper/ibm_stepper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from xlb.operator import Operator
99
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
1010
from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper
11-
from warp.utils import ScopedTimer
11+
from warp import ScopedTimer
1212

1313

1414
class IBMStepper(IncompressibleNavierStokesStepper):

0 commit comments

Comments
 (0)