Skip to content

Commit 9a17923

Browse files
Fix state.to() not propagating dtype/device to constraints & expose angle_tolerance in symmetry pipeline (#527)
1 parent 6c8d562 commit 9a17923

4 files changed

Lines changed: 164 additions & 12 deletions

File tree

tests/test_constraints.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,3 +1210,66 @@ def test_fix_com_system_idx_remapped_on_reordered_slice(
12101210
c = sliced.constraints[0]
12111211
assert isinstance(c, FixCom)
12121212
assert sorted(c.system_idx.tolist()) == [0, 1]
1213+
1214+
1215+
class TestConstraintToDeviceDtype:
1216+
"""Test that state.to() propagates device/dtype to constraint tensors."""
1217+
1218+
def test_fix_atoms_dtype_propagation(
1219+
self, ar_supercell_sim_state: ts.SimState
1220+
) -> None:
1221+
"""FixAtoms indices should be moved to the new device by state.to()."""
1222+
indices = torch.tensor([0, 3, 5], dtype=torch.long)
1223+
ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices)]
1224+
new_state = ar_supercell_sim_state.to(dtype=torch.float32)
1225+
1226+
c = new_state.constraints[0]
1227+
assert isinstance(c, FixAtoms)
1228+
assert torch.equal(c.atom_idx, indices)
1229+
# dtype change should not affect integer indices, but the constraint
1230+
# object must be a distinct copy
1231+
assert c is not ar_supercell_sim_state.constraints[0]
1232+
1233+
def test_fix_com_dtype_propagation(self, ar_supercell_sim_state: ts.SimState) -> None:
1234+
"""FixCom's cached coms tensor should follow state dtype changes."""
1235+
ar_supercell_sim_state.constraints = [FixCom([0])]
1236+
# Trigger lazy COM initialisation
1237+
ar_supercell_sim_state.set_constrained_positions(
1238+
ar_supercell_sim_state.positions.clone()
1239+
)
1240+
assert ar_supercell_sim_state.constraints[0].coms is not None
1241+
1242+
new_state = ar_supercell_sim_state.to(dtype=torch.float32)
1243+
c = new_state.constraints[0]
1244+
assert isinstance(c, FixCom)
1245+
assert c.coms is not None
1246+
assert c.coms.dtype == torch.float32
1247+
1248+
@pytest.mark.parametrize("target_dtype", [torch.float32, torch.float64])
1249+
def test_fix_symmetry_dtype_propagation(self, target_dtype: torch.dtype) -> None:
1250+
"""FixSymmetry rotations and reference_cells must follow dtype changes."""
1251+
rotations = [torch.eye(3, dtype=torch.float64).unsqueeze(0)]
1252+
symm_maps = [torch.zeros(1, 2, dtype=torch.long)]
1253+
ref_cells = [torch.eye(3, dtype=torch.float64)]
1254+
1255+
state = ts.SimState(
1256+
positions=torch.zeros(2, 3, dtype=torch.float64),
1257+
masses=torch.ones(2, dtype=torch.float64),
1258+
cell=torch.eye(3, dtype=torch.float64).unsqueeze(0) * 5.0,
1259+
pbc=True,
1260+
atomic_numbers=torch.tensor([14, 14]),
1261+
system_idx=torch.zeros(2, dtype=torch.long),
1262+
)
1263+
state.constraints = [FixSymmetry(rotations, symm_maps, reference_cells=ref_cells)]
1264+
1265+
new_state = state.to(dtype=target_dtype)
1266+
c = new_state.constraints[0]
1267+
assert isinstance(c, FixSymmetry)
1268+
assert c.rotations[0].dtype == target_dtype
1269+
assert c.reference_cells is not None
1270+
assert c.reference_cells[0].dtype == target_dtype
1271+
# integer symm_maps must stay long
1272+
assert c.symm_maps[0].dtype == torch.long
1273+
# original constraint unchanged
1274+
orig = state.constraints[0]
1275+
assert orig.rotations[0].dtype == torch.float64

torch_sim/constraints.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,18 @@ def merge(cls, constraints: list[Constraint]) -> Self:
151151
constraints: Constraints to merge (all same type, already reindexed)
152152
"""
153153

154+
@abstractmethod
155+
def to(
156+
self,
157+
device: torch.device | None = None,
158+
dtype: torch.dtype | None = None,
159+
) -> Self:
160+
"""Return a copy with all internal tensors moved to *device*/*dtype*.
161+
162+
Float tensors are cast to *dtype*; integer/bool tensors are only moved
163+
to *device*.
164+
"""
165+
154166

155167
def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor:
156168
"""Cumulative sum with a leading zero, e.g. [3, 2, 4] -> [0, 3, 5, 9]."""
@@ -272,6 +284,14 @@ def merge(cls, constraints: list[Constraint]) -> Self:
272284
)
273285
return cls(torch.cat([constraint.atom_idx for constraint in atom_constraints]))
274286

287+
def to(
288+
self,
289+
device: torch.device | None = None,
290+
dtype: torch.dtype | None = None, # noqa: ARG002
291+
) -> Self:
292+
"""Return a copy with atom indices moved to *device*."""
293+
return type(self)(self.atom_idx.to(device=device))
294+
275295

276296
class SystemConstraint(Constraint):
277297
"""Base class for constraints that act on specific system indices.
@@ -371,6 +391,14 @@ def merge(cls, constraints: list[Constraint]) -> Self:
371391
torch.cat([constraint.system_idx for constraint in system_constraints])
372392
)
373393

394+
def to(
395+
self,
396+
device: torch.device | None = None,
397+
dtype: torch.dtype | None = None, # noqa: ARG002
398+
) -> Self:
399+
"""Return a copy with system indices moved to *device*."""
400+
return type(self)(self.system_idx.to(device=device))
401+
374402

375403
def merge_constraints(
376404
constraint_lists: list[list[Constraint]],
@@ -612,6 +640,17 @@ def __repr__(self) -> str:
612640
"""String representation of the constraint."""
613641
return f"FixCom(system_idx={self.system_idx})"
614642

643+
def to(
644+
self,
645+
device: torch.device | None = None,
646+
dtype: torch.dtype | None = None,
647+
) -> Self:
648+
"""Return a copy with tensors moved to *device*/*dtype*."""
649+
new = type(self)(self.system_idx.to(device=device))
650+
if self.coms is not None:
651+
new.coms = self.coms.to(device=device, dtype=dtype)
652+
return new
653+
615654

616655
def count_degrees_of_freedom(
617656
state: SimState, constraints: list[Constraint] | None = None
@@ -801,6 +840,7 @@ def from_state(
801840
adjust_positions: bool = True,
802841
adjust_cell: bool = True,
803842
refine_symmetry_state: bool = True,
843+
angle_tolerance: float | None = None,
804844
) -> Self:
805845
"""Create from SimState, optionally refining to ideal symmetry first.
806846
@@ -814,6 +854,8 @@ def from_state(
814854
adjust_positions: Whether to symmetrize position displacements.
815855
adjust_cell: Whether to symmetrize cell/stress adjustments.
816856
refine_symmetry_state: Whether to refine positions/cell to ideal values.
857+
angle_tolerance: Angle tolerance in radians for moyopy symmetry
858+
detection. If None, moyopy uses its default behaviour.
817859
"""
818860
try:
819861
import moyopy # noqa: F401
@@ -839,11 +881,18 @@ def from_state(
839881
pos,
840882
nums,
841883
symprec=symprec,
884+
angle_tolerance=angle_tolerance,
842885
)
843886
state.cell[sys_idx] = cell.mT # row→column vector convention
844887
state.positions[start:end] = pos
845888
else:
846-
rots, smap = prep_symmetry(cell, pos, nums, symprec=symprec)
889+
rots, smap = prep_symmetry(
890+
cell,
891+
pos,
892+
nums,
893+
symprec=symprec,
894+
angle_tolerance=angle_tolerance,
895+
)
847896

848897
rotations.append(rots)
849898
symm_maps.append(smap)
@@ -973,6 +1022,26 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002
9731022
max_cumulative_strain=self.max_cumulative_strain,
9741023
)
9751024

1025+
def to(
1026+
self,
1027+
device: torch.device | None = None,
1028+
dtype: torch.dtype | None = None,
1029+
) -> Self:
1030+
"""Return a copy with tensors moved to *device*/*dtype*."""
1031+
return type(self)(
1032+
[r.to(device=device, dtype=dtype) for r in self.rotations],
1033+
[s.to(device=device) for s in self.symm_maps],
1034+
self.system_idx.to(device=device),
1035+
adjust_positions=self.do_adjust_positions,
1036+
adjust_cell=self.do_adjust_cell,
1037+
reference_cells=(
1038+
[c.to(device=device, dtype=dtype) for c in self.reference_cells]
1039+
if self.reference_cells is not None
1040+
else None
1041+
),
1042+
max_cumulative_strain=self.max_cumulative_strain,
1043+
)
1044+
9761045
@classmethod
9771046
def merge(cls, constraints: list[Constraint]) -> Self:
9781047
"""Merge by concatenating rotations, symm_maps, and system indices."""

torch_sim/state.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -793,15 +793,24 @@ def _state_to_device[T: SimState](
793793
attrs = state.attributes
794794
for attr_name, attr_value in attrs.items():
795795
if isinstance(attr_value, torch.Tensor):
796-
attrs[attr_name] = attr_value.to(device=device)
796+
if attr_value.is_floating_point() and dtype is not None:
797+
# also move floating point attributes like forces, velocities, etc.
798+
# to dtype.
799+
attrs[attr_name] = attr_value.to(device=device, dtype=dtype)
800+
else:
801+
# non-floating attributes like system_idx keep their dtype.
802+
attrs[attr_name] = attr_value.to(device=device)
797803
elif isinstance(attr_value, torch.Generator):
798804
attrs[attr_name] = coerce_prng(attr_value, device)
799805

800806
if dtype is not None:
801-
attrs["positions"] = attrs["positions"].to(dtype=dtype)
802-
attrs["masses"] = attrs["masses"].to(dtype=dtype)
803-
attrs["cell"] = attrs["cell"].to(dtype=dtype)
804807
attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
808+
809+
if attrs.get("_constraints"):
810+
attrs["_constraints"] = [
811+
c.to(device=device, dtype=dtype) for c in attrs["_constraints"]
812+
]
813+
805814
return type(state)(**attrs)
806815

807816

torch_sim/symmetrize.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def _moyo_dataset(
2222
frac_pos: torch.Tensor,
2323
atomic_numbers: torch.Tensor,
2424
symprec: float = 1e-4,
25+
angle_tolerance: float | None = None,
2526
) -> MoyoDataset:
2627
"""Get MoyoDataset from cell, fractional positions, and atomic numbers."""
2728
from moyopy import Cell, MoyoDataset
@@ -31,7 +32,7 @@ def _moyo_dataset(
3132
positions=frac_pos.detach().cpu().tolist(),
3233
numbers=atomic_numbers.detach().cpu().int().tolist(),
3334
)
34-
return MoyoDataset(moyo_cell, symprec=symprec)
35+
return MoyoDataset(moyo_cell, symprec=symprec, angle_tolerance=angle_tolerance)
3536

3637

3738
def _extract_symmetry_ops(
@@ -51,13 +52,19 @@ def _extract_symmetry_ops(
5152
return rotations, translations
5253

5354

54-
def get_symmetry_datasets(state: SimState, symprec: float = 1e-4) -> list[MoyoDataset]:
55+
def get_symmetry_datasets(
56+
state: SimState,
57+
symprec: float = 1e-4,
58+
angle_tolerance: float | None = None,
59+
) -> list[MoyoDataset]:
5560
"""Get MoyoDataset for each system in a SimState."""
5661
datasets = []
5762
for single in state.split():
5863
cell = single.row_vector_cell[0]
5964
frac = single.positions @ torch.linalg.inv(cell)
60-
datasets.append(_moyo_dataset(cell, frac, single.atomic_numbers, symprec))
65+
datasets.append(
66+
_moyo_dataset(cell, frac, single.atomic_numbers, symprec, angle_tolerance)
67+
)
6168
return datasets
6269

6370

@@ -105,14 +112,15 @@ def prep_symmetry(
105112
positions: torch.Tensor,
106113
atomic_numbers: torch.Tensor,
107114
symprec: float = 1e-4,
115+
angle_tolerance: float | None = None,
108116
) -> tuple[torch.Tensor, torch.Tensor]:
109117
"""Get symmetry rotations and atom mappings for a structure.
110118
111119
Returns:
112120
(rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms).
113121
"""
114122
frac_pos = positions @ torch.linalg.inv(cell)
115-
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec)
123+
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance)
116124
rotations, translations = _extract_symmetry_ops(dataset, cell.dtype, cell.device)
117125
return rotations, build_symmetry_map(rotations, translations, frac_pos)
118126

@@ -122,6 +130,7 @@ def _refine_symmetry_impl(
122130
positions: torch.Tensor,
123131
atomic_numbers: torch.Tensor,
124132
symprec: float = 0.01,
133+
angle_tolerance: float | None = None,
125134
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
126135
"""Core refinement returning all intermediate data for reuse.
127136
@@ -130,7 +139,7 @@ def _refine_symmetry_impl(
130139
"""
131140
dtype, device = cell.dtype, cell.device
132141
frac_pos = positions @ torch.linalg.inv(cell)
133-
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec)
142+
dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec, angle_tolerance)
134143
rotations, translations = _extract_symmetry_ops(dataset, dtype, device)
135144
n_ops, n_atoms = rotations.shape[0], positions.shape[0]
136145

@@ -165,6 +174,7 @@ def refine_symmetry(
165174
positions: torch.Tensor,
166175
atomic_numbers: torch.Tensor,
167176
symprec: float = 0.01,
177+
angle_tolerance: float | None = None,
168178
) -> tuple[torch.Tensor, torch.Tensor]:
169179
"""Symmetrize cell and positions according to the detected space group.
170180
@@ -175,7 +185,7 @@ def refine_symmetry(
175185
(symmetrized_cell, symmetrized_positions) as row vectors.
176186
"""
177187
new_cell, new_positions, _rotations, _translations = _refine_symmetry_impl(
178-
cell, positions, atomic_numbers, symprec
188+
cell, positions, atomic_numbers, symprec, angle_tolerance
179189
)
180190
return new_cell, new_positions
181191

@@ -185,6 +195,7 @@ def refine_and_prep_symmetry(
185195
positions: torch.Tensor,
186196
atomic_numbers: torch.Tensor,
187197
symprec: float = 0.01,
198+
angle_tolerance: float | None = None,
188199
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
189200
"""Refine symmetry and get ops/mappings in a single moyopy call.
190201
@@ -195,7 +206,7 @@ def refine_and_prep_symmetry(
195206
(refined_cell, refined_positions, rotations, symm_map)
196207
"""
197208
new_cell, new_positions, rotations, translations = _refine_symmetry_impl(
198-
cell, positions, atomic_numbers, symprec
209+
cell, positions, atomic_numbers, symprec, angle_tolerance
199210
)
200211
# Build symm_map on the final refined fractional coordinates
201212
refined_frac = new_positions @ torch.linalg.inv(new_cell)

0 commit comments

Comments
 (0)