TorchRef is a PyTorch-based crystallographic refinement library that uses automatic differentiation and GPU acceleration for structure refinement. It follows PyTorch's nn.Module architecture throughout.
Repository: https://github.com/HatPdotS/TorchRef Documentation: https://torchref.readthedocs.io/ Version: 0.4.1 License: MIT
from torchref import LBFGSRefinement, ModelFT, ReflectionData, Scaler, ROOT_TORCHREF
# Load and refine
refinement = LBFGSRefinement(data_file="data.mtz", pdb="structure.pdb")
refinement.refine_xyz() # coordinate refinement
refinement.refine_adp() # B-factor refinement
refinement.refine(macro_cycles=5) # full refinement
rwork, rfree = refinement.get_rfactor()# Conda environment
module load anaconda && conda activate /das/work/p17/p17490/CONDA/torchref
# Python interpreter
/das/work/units/LBR-FEL/p17490/CONDA/torchref/bin/python
# Threading: auto-detected or override with
export TORCHREF_NUM_THREADS=4# Install (development mode)
pip install -e ".[dev]"
# Run tests
pytest tests/ # all tests
pytest tests/unit/ # fast unit tests (~234 tests)
pytest tests/integration/ # integration tests (~95 tests)
pytest tests/functional/ # full workflow tests (~136 tests)
pytest -m "not gpu and not slow" # skip GPU/slow tests
# Test with coverage
pytest tests/ --cov=torchref
# Compatibility testing (tox)
tox -e py311-default # Python 3.11 default deps
# See tox.ini for 27 environments (Python 3.10-3.13, various dep versions)
# Linting and formatting
black torchref/ --check # check formatting (88 char line length)
isort torchref/ --check # check import order
flake8 torchref/ # lint (ignores F841, E741, E402, E722)cd docs/
# Regenerate API docs from source
make apidoc
# Build HTML docs
make html
# Output: docs/_build/html/
# Run doctests in documentation
make doctestSphinx config is in docs/conf.py. Uses RTD theme, Napoleon (NumPy docstrings), autodoc, intersphinx (Python, NumPy, PyTorch, Pandas).
torchref/
├── __init__.py # Public API exports, threading config
├── config.py # DtypeConfig singleton (dtypes.float, dtypes.int, dtypes.complex)
├── _bootstrap.py # Threading configuration
│
├── model/ # Atomic structure models
│ ├── model.py # Model (base) - coordinates, B-factors, occupancies
│ ├── model_ft.py # ModelFT - FFT structure factors, electron density
│ ├── parameter_wrappers.py # MixedTensor, PositiveMixedTensor, OccupancyTensor
│ ├── sf_fft.py # Structure factor calculation via FFT
│ ├── sf_ds.py # Structure factor via direct summation
│ ├── internal_coordinates.py # Internal coordinate systems
│ ├── segmented_internal_coordinates.py # Segmented IC (chains)
│ ├── closed_segmented_internal_coordinates.py # IC with chain closure
│ └── mixed_model.py # Multi-model handling
│
├── io/ # File I/O
│ ├── pdb.py, mtz.py, cif.py # Format readers/writers
│ ├── data_router.py # Auto file format detection
│ └── datasets/ # ReflectionData, FcalcData, DatasetCollection
│
├── refinement/ # Refinement framework
│ ├── base_refinement.py # Refinement base class
│ ├── lbfgs_refinement.py # LBFGSRefinement (primary refinement class)
│ ├── loss_state.py # LossState - target aggregation, weights, torch.compile
│ ├── logger.py # Progress logging
│ ├── targets/ # Target functions (xray/, geometry/, adp/, realspace, etc.)
│ ├── weighting/ # Loss weighting schemes (random, policy, component)
│ └── optimizers/ # Custom optimizers (SA, Langevin, exploratory LBFGS)
│
├── restraints/ # Geometry restraints (bonds, angles, torsions, planes, Ramachandran)
│ ├── restraints.py # Main Restraints class
│ ├── builders.py # Restraint construction from monomer library
│ └── library.py # Monomer library (lazy download)
│
├── scaling/ # Structure factor scaling
│ ├── scaler.py # Scaler (overall scale, anisotropy, bulk solvent)
│ └── solvent.py # Bulk solvent model
│
├── symmetry/ # Crystallographic symmetry (Cell, SpaceGroup, MapSymmetry)
├── maps/ # Electron density maps (Map, DifferenceMap)
├── alignment/ # Patterson-based structure alignment
├── base/ # Core math: coordinates, reciprocal, scattering, ED, FFT, kernels
├── utils/ # Caching, device mixin, debugging, timing, serialization
├── cli/ # 8 CLI entry points
└── data/ # Package data (scattering tables, monomer library CIFs)
nn.Module
├── Model (base atomic model)
│ └── ModelFT (+ FFT structure factors, CachedForwardMixin)
├── Refinement (base refinement coordinator)
│ └── LBFGSRefinement (LBFGS optimization)
├── Target (base target/loss function)
│ ├── ModelTarget (needs Model reference)
│ │ ├── GeometryTarget → BondTarget, AngleTarget, TorsionTarget, ...
│ │ └── ADPTarget → ADPSimilarityTarget, RigidBondTarget, ...
│ └── DataTarget (needs ReflectionData + optional Model/Scaler)
│ └── XrayTarget → LeastSquaresXrayTarget, MaximumLikelihoodXrayTarget, ...
├── ScalerBase
│ └── Scaler
└── MixedTensor, PositiveMixedTensor, OccupancyTensor (parameter wrappers)
-
Parameter Wrappers:
MixedTensorholdsrefinable_params(nn.Parameter) andfixed_values(buffer). Freeze/unfreeze moves atoms between the two. -
Smart Caching (
CachedForwardMixininutils/caching.py):- Tracks
(data_ptr, _version)of all parameters and buffers - Also fingerprints input tensors/kwargs
- Backward hook increments generation counter for automatic invalidation
reset_forward_cache()for manual invalidation- Used by: MixedTensor, PositiveMixedTensor, OccupancyTensor, SegmentedInternalCoordinateTensor, ModelFT
- Tracks
-
Structure Factor Cache (in
model_ft.py):- Separate from
CachedForwardMixin, HKL-specific reset_cache()resets SF cache + all wrapper forward caches
- Separate from
-
Hierarchical Loss Management (
LossState):state.register_target('xray/work', target) state.register_target('geometry/bond', target) state.set_weight('xray', 1.0) # group weight state.set_weight('geometry', 0.5) # group weight total = state.aggregate() # evaluates all, applies weights
-
Device Movement: All major classes inherit
DeviceMovementMixinfor automatic.to(device)propagation. -
Empty Initialization: Classes can be instantiated empty for
state_dictloading:model = Model() model.load_state_dict(torch.load('model.pt'))
model = ModelFT(max_res=1.0, radius_angstrom=4.0)
model.load_pdb("structure.pdb")
model.load_cif("structure.cif")
# Structure factors
fcalc = model(hkl) # or model.forward(hkl)
fcalc = model.get_structure_factor(hkl)
# Parameter control
model.freeze('b') # freeze B-factors
model.unfreeze('xyz') # unfreeze coordinates
model.freeze_selection("chain A and resseq 10:20") # phenix-style syntax
model.unfreeze_selection("all")
model.shake_coords(0.1) # perturb for testing
# Access parameters
xyz = model.xyz.refinable_params # nn.Parameter (N_refinable, 3)
b = model.adp.refinable_params # nn.Parameter (N_refinable,)data = ReflectionData()
data.load_mtz("reflections.mtz")
hkl, F, sigF, rfree_flags = data() # unpackref = LBFGSRefinement(
data_file="data.mtz",
pdb="structure.pdb",
device=torch.device("cuda"), # GPU acceleration
)
ref.refine_xyz() # coordinate refinement
ref.refine_adp() # B-factor refinement
ref.refine(macro_cycles=5) # full refinement
rwork, rfree = ref.get_rfactor()
# Output
ref.write_out_pdb("refined.pdb")
ref.write_out_mtz("refined.mtz")from torchref.refinement.targets import Target
class MyTarget(Target):
name = 'my_custom_target'
def __init__(self, refinement):
super().__init__()
# store what you need
def forward(self):
# define loss - gradients computed automatically by PyTorch
return loss
# Register with LossState
loss_state = refinement.create_loss_state()
loss_state.register_target('custom', MyTarget(refinement))
loss_state.set_weight('custom', 1.0)
total_loss = loss_state.aggregate()
total_loss.backward() # autograd handles gradientsscaler = Scaler(model, data, nbins=20)
scaler.initialize()
scaler.refine_lbfgs()
scaled_fcalc = scaler(fcalc)from torchref import Map, DifferenceMap
m = Map(model, data, scaler)
diff = DifferenceMap(model, data, scaler)| Command | Description |
|---|---|
torchref.refine |
Basic LBFGS refinement |
torchref.refine-static |
Fixed weights (xray=1.0, geom=10.0, adp=5.0) |
torchref.refine-hyper |
User-provided hyperparameters |
torchref.refine-policy |
Neural network policy weights |
torchref.refine-random-weights |
Random weight sampling (training data) |
torchref.difference-refine |
Difference refinement (time-resolved) |
torchref.mtz2map |
MTZ to CCP4 map conversion |
torchref.validate-ded |
Validate difference electron density |
torchref.phased-difference-map |
Phased difference maps |
Common CLI usage:
torchref.refine -f reflections.mtz -s structure.pdb --output refined.pdb- Formatter: Black (88 char line length, target py310)
- Imports: isort with Black profile
- Docstrings: NumPy style (mandatory for public API)
- Type hints: Required for public functions
- Linting: Flake8 (ignores F841, E741, E402, E722)
- Naming conventions:
f_calc= complex structure factors (with phase)F_calc= amplitudes (absolute values)adp= atomic displacement parametersu= anisotropic U tensor (6 components)b= B-factorxyz= Cartesian coordinates (Angstroms)xyz_fractional= fractional coordinates (0-1)d_min/d_max= resolution boundsspacegroup= one word (not "space_group")
Test structures in tests/files/: 1DAW, 2DQ6, 3A5V, 3E98, 3GR5, 3K7M, 3VRJ, 4BX9, 5BOV, 6G9X.
Example data in example_notebooks/: 1DAW.pdb, 1DAW.mtz.
Core: numpy>=2.0, torch>=2.4, pandas>=2.0, scipy>=1.10, gemmi>=0.5, reciprocalspaceship>=0.9.18, numba>=0.59, matplotlib>=3.7, tqdm, pyarrow
Optional:
pip install torchref[dev]- pytest, black, isort, flake8pip install torchref[alignment]- JAX, s2fft, s2ball (Patterson alignment)pip install torchref[forcefield]- torchmd-netpip install torchref[amber]- OpenMM, pdbfixerpip install torchref[docs]- Sphinx, RTD theme, numpydoc