@@ -22,6 +22,7 @@ def _moyo_dataset(
2222 frac_pos : torch .Tensor ,
2323 atomic_numbers : torch .Tensor ,
2424 symprec : float = 1e-4 ,
25+ angle_tolerance : float | None = None ,
2526) -> MoyoDataset :
2627 """Get MoyoDataset from cell, fractional positions, and atomic numbers."""
2728 from moyopy import Cell , MoyoDataset
@@ -31,7 +32,7 @@ def _moyo_dataset(
3132 positions = frac_pos .detach ().cpu ().tolist (),
3233 numbers = atomic_numbers .detach ().cpu ().int ().tolist (),
3334 )
34- return MoyoDataset (moyo_cell , symprec = symprec )
35+ return MoyoDataset (moyo_cell , symprec = symprec , angle_tolerance = angle_tolerance )
3536
3637
3738def _extract_symmetry_ops (
@@ -51,13 +52,19 @@ def _extract_symmetry_ops(
5152 return rotations , translations
5253
5354
54- def get_symmetry_datasets (state : SimState , symprec : float = 1e-4 ) -> list [MoyoDataset ]:
55+ def get_symmetry_datasets (
56+ state : SimState ,
57+ symprec : float = 1e-4 ,
58+ angle_tolerance : float | None = None ,
59+ ) -> list [MoyoDataset ]:
5560 """Get MoyoDataset for each system in a SimState."""
5661 datasets = []
5762 for single in state .split ():
5863 cell = single .row_vector_cell [0 ]
5964 frac = single .positions @ torch .linalg .inv (cell )
60- datasets .append (_moyo_dataset (cell , frac , single .atomic_numbers , symprec ))
65+ datasets .append (
66+ _moyo_dataset (cell , frac , single .atomic_numbers , symprec , angle_tolerance )
67+ )
6168 return datasets
6269
6370
@@ -105,14 +112,15 @@ def prep_symmetry(
105112 positions : torch .Tensor ,
106113 atomic_numbers : torch .Tensor ,
107114 symprec : float = 1e-4 ,
115+ angle_tolerance : float | None = None ,
108116) -> tuple [torch .Tensor , torch .Tensor ]:
109117 """Get symmetry rotations and atom mappings for a structure.
110118
111119 Returns:
112120 (rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms).
113121 """
114122 frac_pos = positions @ torch .linalg .inv (cell )
115- dataset = _moyo_dataset (cell , frac_pos , atomic_numbers , symprec )
123+ dataset = _moyo_dataset (cell , frac_pos , atomic_numbers , symprec , angle_tolerance )
116124 rotations , translations = _extract_symmetry_ops (dataset , cell .dtype , cell .device )
117125 return rotations , build_symmetry_map (rotations , translations , frac_pos )
118126
@@ -122,6 +130,7 @@ def _refine_symmetry_impl(
122130 positions : torch .Tensor ,
123131 atomic_numbers : torch .Tensor ,
124132 symprec : float = 0.01 ,
133+ angle_tolerance : float | None = None ,
125134) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
126135 """Core refinement returning all intermediate data for reuse.
127136
@@ -130,7 +139,7 @@ def _refine_symmetry_impl(
130139 """
131140 dtype , device = cell .dtype , cell .device
132141 frac_pos = positions @ torch .linalg .inv (cell )
133- dataset = _moyo_dataset (cell , frac_pos , atomic_numbers , symprec )
142+ dataset = _moyo_dataset (cell , frac_pos , atomic_numbers , symprec , angle_tolerance )
134143 rotations , translations = _extract_symmetry_ops (dataset , dtype , device )
135144 n_ops , n_atoms = rotations .shape [0 ], positions .shape [0 ]
136145
@@ -165,6 +174,7 @@ def refine_symmetry(
165174 positions : torch .Tensor ,
166175 atomic_numbers : torch .Tensor ,
167176 symprec : float = 0.01 ,
177+ angle_tolerance : float | None = None ,
168178) -> tuple [torch .Tensor , torch .Tensor ]:
169179 """Symmetrize cell and positions according to the detected space group.
170180
@@ -175,7 +185,7 @@ def refine_symmetry(
175185 (symmetrized_cell, symmetrized_positions) as row vectors.
176186 """
177187 new_cell , new_positions , _rotations , _translations = _refine_symmetry_impl (
178- cell , positions , atomic_numbers , symprec
188+ cell , positions , atomic_numbers , symprec , angle_tolerance
179189 )
180190 return new_cell , new_positions
181191
@@ -185,6 +195,7 @@ def refine_and_prep_symmetry(
185195 positions : torch .Tensor ,
186196 atomic_numbers : torch .Tensor ,
187197 symprec : float = 0.01 ,
198+ angle_tolerance : float | None = None ,
188199) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
189200 """Refine symmetry and get ops/mappings in a single moyopy call.
190201
@@ -195,7 +206,7 @@ def refine_and_prep_symmetry(
195206 (refined_cell, refined_positions, rotations, symm_map)
196207 """
197208 new_cell , new_positions , rotations , translations = _refine_symmetry_impl (
198- cell , positions , atomic_numbers , symprec
209+ cell , positions , atomic_numbers , symprec , angle_tolerance
199210 )
200211 # Build symm_map on the final refined fractional coordinates
201212 refined_frac = new_positions @ torch .linalg .inv (new_cell )
0 commit comments