Skip to content

Commit 9353d37

Browse files
pfebrerLuthaf
andcommitted
Change per_atom into sample_kind while keeping backward compatibility
Co-Authored-By: Guillaume Fraux <guillaume.fraux@epfl.ch>
1 parent fb14a30 commit 9353d37

File tree

22 files changed

+473
-115
lines changed

22 files changed

+473
-115
lines changed

docs/src/engines/plumed-model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def forward(
2525
if "features" not in outputs:
2626
return {}
2727

28-
if outputs["features"].per_atom:
28+
if outputs["features"].sample_kind == "atom":
2929
raise ValueError("per-atoms features are not supported in this model")
3030

3131
# PLUMED will first call the model with 0 atoms to get the size of the
@@ -94,7 +94,7 @@ def forward(
9494
# metatdata about what the model can do
9595
capabilities = mta.ModelCapabilities(
9696
length_unit="Angstrom",
97-
outputs={"features": mta.ModelOutput(per_atom=False)},
97+
outputs={"features": mta.ModelOutput(sample_kind="system")},
9898
atomic_types=[0],
9999
interaction_range=torch.inf,
100100
supported_devices=["cpu", "cuda"],

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
3535
ModelOutputHolder() = default;
3636

3737
/// Initialize `ModelOutput` with the given data
38+
ModelOutputHolder(
39+
std::string quantity,
40+
std::string unit,
41+
std::string sample_kind,
42+
std::vector<std::string> explicit_gradients_,
43+
std::string description_
44+
):
45+
description(std::move(description_)),
46+
explicit_gradients(std::move(explicit_gradients_))
47+
{
48+
this->set_quantity(std::move(quantity));
49+
this->set_unit(std::move(unit));
50+
this->set_sample_kind(std::move(sample_kind));
51+
}
52+
53+
/// For backward compatibility in the C++ API (per_atom argument)
3854
ModelOutputHolder(
3955
std::string quantity,
4056
std::string unit,
@@ -43,13 +59,24 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
4359
std::string description_
4460
):
4561
description(std::move(description_)),
46-
per_atom(per_atom_),
4762
explicit_gradients(std::move(explicit_gradients_))
4863
{
4964
this->set_quantity(std::move(quantity));
5065
this->set_unit(std::move(unit));
66+
this->set_per_atom(per_atom_);
5167
}
5268

69+
/// For backward compatibility in the Python API
70+
ModelOutputHolder(
71+
std::string quantity,
72+
std::string unit,
73+
torch::IValue per_atom_or_sample_kind,
74+
std::vector<std::string> explicit_gradients_,
75+
std::string description_,
76+
torch::optional<bool> per_atom = torch::nullopt,
77+
torch::optional<std::string> sample_kind = torch::nullopt
78+
);
79+
5380
~ModelOutputHolder() override = default;
5481

5582
/// description of this output, defaults to empty string of not set by the user
@@ -72,8 +99,21 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
7299
/// set the unit of the output
73100
void set_unit(std::string unit);
74101

75-
/// is the output defined per-atom or for the overall structure
76-
bool per_atom = false;
102+
/// The setter and getter for `per_atom` that are used in TorchBind, which
103+
/// allow us to raise an error if `sample_kind` can't be mapped to a boolean
104+
/// value for `per_atom`.
105+
void set_per_atom(bool per_atom);
106+
bool get_per_atom() const;
107+
108+
/// This is deprecated in favor of `sample_kind`, and kept for backward compatibility reasons only.
109+
[[deprecated("use sample_kind instead")]]
110+
bool per_atom;
111+
112+
/// Get the sample kind of the output. TODO: explain
113+
std::string sample_kind() const;
114+
115+
/// Set the `sample_kind` of the output.
116+
void set_sample_kind(std::string sample_kind);
77117

78118
/// Which gradients should be computed eagerly and stored inside the output
79119
/// `TensorMap`
@@ -85,8 +125,12 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
85125
static ModelOutput from_json(std::string_view json);
86126

87127
private:
128+
void set_per_atom_no_deprecation(bool per_atom);
129+
bool get_per_atom_no_deprecation() const;
130+
88131
std::string quantity_;
89132
std::string unit_;
133+
torch::optional<std::string> sample_kind_;
90134
};
91135

92136

metatomic-torch/src/model.cpp

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,51 @@ static void read_vector_int_json(
5353

5454
/******************************************************************************/
5555

56+
ModelOutputHolder::ModelOutputHolder(
57+
std::string quantity,
58+
std::string unit,
59+
torch::IValue per_atom_or_sample_kind,
60+
std::vector<std::string> explicit_gradients_,
61+
std::string description_,
62+
torch::optional<bool> per_atom,
63+
torch::optional<std::string> sample_kind
64+
):
65+
description(std::move(description_)),
66+
explicit_gradients(std::move(explicit_gradients_))
67+
{
68+
this->set_quantity(std::move(quantity));
69+
this->set_unit(std::move(unit));
70+
71+
if (per_atom_or_sample_kind.isNone()) {
72+
// check the kwargs for backward compatibility
73+
if (sample_kind.has_value() && per_atom.has_value()) {
74+
C10_THROW_ERROR(ValueError, "cannot specify both `per_atom` and `sample_kind`");
75+
} else if (sample_kind.has_value()) {
76+
this->set_sample_kind(sample_kind.value());
77+
} else if (per_atom.has_value()) {
78+
this->set_per_atom(per_atom.value());
79+
}
80+
} else if (per_atom_or_sample_kind.isBool()) {
81+
if (per_atom.has_value()) {
82+
C10_THROW_ERROR(ValueError,
83+
"cannot specify `per_atom` both as a positional and keyword argument"
84+
);
85+
}
86+
this->set_per_atom(per_atom_or_sample_kind.toBool());
87+
} else if (per_atom_or_sample_kind.isString()) {
88+
if (sample_kind.has_value()) {
89+
C10_THROW_ERROR(ValueError,
90+
"cannot specify `sample_kind` both as a positional and keyword argument"
91+
);
92+
}
93+
this->set_sample_kind(per_atom_or_sample_kind.toStringRef());
94+
} else {
95+
C10_THROW_ERROR(ValueError,
96+
"positional argument for `per_atom`/`sample_kind` must be either a boolean or a string"
97+
);
98+
}
99+
}
100+
56101
void ModelOutputHolder::set_quantity(std::string quantity) {
57102
if (valid_quantity(quantity)) {
58103
validate_unit(quantity, unit_);
@@ -72,7 +117,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
72117
result["class"] = "ModelOutput";
73118
result["quantity"] = self.quantity();
74119
result["unit"] = self.unit();
75-
result["per_atom"] = self.per_atom;
120+
result["sample_kind"] = self.sample_kind();
76121
result["explicit_gradients"] = self.explicit_gradients;
77122
result["description"] = self.description;
78123

@@ -112,11 +157,18 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
112157
result->set_unit(data["unit"]);
113158
}
114159

115-
if (data.contains("per_atom")) {
160+
if (data.contains("sample_kind")) {
161+
if (!data["sample_kind"].is_string()) {
162+
throw std::runtime_error("'sample_kind' in JSON for ModelOutput must be a string");
163+
}
164+
result->set_sample_kind(data["sample_kind"]);
165+
} else if (data.contains("per_atom")) {
116166
if (!data["per_atom"].is_boolean()) {
117167
throw std::runtime_error("'per_atom' in JSON for ModelOutput must be a boolean");
118168
}
119-
result->per_atom = data["per_atom"];
169+
result->set_per_atom(data["per_atom"]);
170+
} else {
171+
result->set_sample_kind("system");
120172
}
121173

122174
if (data.contains("explicit_gradients")) {
@@ -145,6 +197,87 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
145197
return model_output_from_json(data);
146198
}
147199

200+
static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
201+
"system",
202+
"atom",
203+
"atom_pair",
204+
};
205+
206+
void ModelOutputHolder::set_sample_kind(std::string sample_kind) {
207+
if (sample_kind == "atom") {
208+
this->set_per_atom_no_deprecation(true);
209+
} else if (sample_kind == "system") {
210+
this->set_per_atom_no_deprecation(false);
211+
} else {
212+
if (SUPPORTED_SAMPLE_KINDS.find(sample_kind) == SUPPORTED_SAMPLE_KINDS.end()) {
213+
C10_THROW_ERROR(ValueError,
214+
"invalid sample_kind '" + sample_kind + "': supported values are [" +
215+
torch::str(SUPPORTED_SAMPLE_KINDS) + "]"
216+
);
217+
}
218+
219+
this->sample_kind_ = std::move(sample_kind);
220+
}
221+
}
222+
223+
std::string ModelOutputHolder::sample_kind() const {
224+
if (sample_kind_.has_value()) {
225+
return sample_kind_.value();
226+
} else if (this->get_per_atom_no_deprecation()) {
227+
return "atom";
228+
} else {
229+
return "system";
230+
}
231+
}
232+
233+
void ModelOutputHolder::set_per_atom(bool per_atom_) {
234+
TORCH_WARN_DEPRECATION(
235+
"`per_atom` is deprecated, please use `sample_kind` instead"
236+
);
237+
238+
this->set_per_atom_no_deprecation(per_atom_);
239+
}
240+
241+
bool ModelOutputHolder::get_per_atom() const {
242+
TORCH_WARN_DEPRECATION(
243+
"`per_atom` is deprecated, please use `sample_kind` instead"
244+
);
245+
246+
return this->get_per_atom_no_deprecation();
247+
}
248+
249+
#if defined(__GNUC__) || defined(__clang__)
250+
#pragma GCC diagnostic push
251+
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
252+
#endif
253+
254+
void ModelOutputHolder::set_per_atom_no_deprecation(bool per_atom) {
255+
this->per_atom = per_atom;
256+
257+
this->sample_kind_ = torch::nullopt;
258+
}
259+
260+
bool ModelOutputHolder::get_per_atom_no_deprecation() const {
261+
if (sample_kind_.has_value()) {
262+
if (sample_kind_.value() == "atom") {
263+
return true;
264+
} else if (sample_kind_.value() == "system") {
265+
return false;
266+
} else {
267+
C10_THROW_ERROR(
268+
ValueError,
269+
"Can't infer `per_atom` from `sample_kind` '" + this->sample_kind() + "'. "
270+
"`per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
271+
);
272+
}
273+
}
274+
return per_atom;
275+
}
276+
277+
#if defined(__GNUC__) || defined(__clang__)
278+
#pragma GCC diagnostic pop
279+
#endif
280+
148281
/******************************************************************************/
149282

150283

metatomic-torch/src/outputs.cpp

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,27 @@ static void validate_atomic_samples(
8484
auto tensor_options = torch::TensorOptions().device(value->device());
8585
TensorBlock block = TensorMapHolder::block_by_id(value, 0);
8686

87-
// Check if the samples names are as expected based on whether the output is
88-
// per-atom or global
87+
// Check if the samples names are as expected based on the sample_kind
8988
std::vector<std::string> expected_samples_names;
90-
if (request->per_atom) {
89+
if (request->sample_kind() == "atom") {
9190
expected_samples_names = {"system", "atom"};
92-
} else {
91+
} else if (request->sample_kind() == "system") {
9392
expected_samples_names = {"system"};
93+
} else if (request->sample_kind() == "atom_pair") {
94+
expected_samples_names = {
95+
"system",
96+
"first_atom",
97+
"second_atom",
98+
"cell_shift_a",
99+
"cell_shift_b",
100+
"cell_shift_c"
101+
};
102+
} else {
103+
C10_THROW_ERROR(ValueError,
104+
"Metatomic does not support validating samples for sample_kind"
105+
"other than 'system', 'atom' or 'atom_pair' at the moment."
106+
" Received sample_kind '" + request->sample_kind()
107+
);
94108
}
95109

96110
if (block->samples()->names() != expected_samples_names) {
@@ -103,7 +117,7 @@ static void validate_atomic_samples(
103117

104118
// Check if the samples match the systems and selected_atoms
105119
Labels expected_samples;
106-
if (request->per_atom) {
120+
if (request->sample_kind() == "atom") {
107121
std::vector<int64_t> expected_values_flat;
108122
for (size_t s; s < systems.size(); s++) {
109123
for (size_t a; a < systems[s]->size(); a++) {
@@ -122,7 +136,7 @@ static void validate_atomic_samples(
122136
if (selected_atoms) {
123137
expected_samples = expected_samples->set_intersection(selected_atoms.value());
124138
}
125-
} else {
139+
} else if (request->sample_kind() == "system") {
126140
expected_samples = torch::make_intrusive<LabelsHolder>(
127141
"system",
128142
torch::arange(static_cast<int64_t>(systems.size()), tensor_options).reshape({-1, 1}),
@@ -138,6 +152,40 @@ static void validate_atomic_samples(
138152
);
139153
expected_samples = expected_samples->set_intersection(selected_systems);
140154
}
155+
} else if (request->sample_kind() == "atom_pair") {
156+
// minimal validation, just that indices are in-bounds
157+
auto values = block->samples()->values().to(torch::kCPU);
158+
for (int64_t i = 0; i < values.size(0); i++) {
159+
auto system_idx = values[i][0].item<int64_t>();
160+
auto first_atom_idx = values[i][1].item<int64_t>();
161+
auto second_atom_idx = values[i][2].item<int64_t>();
162+
163+
if (system_idx < 0 || system_idx >= static_cast<int64_t>(systems.size())) {
164+
C10_THROW_ERROR(ValueError,
165+
"invalid system index in samples for '" + name + "' output: " +
166+
std::to_string(system_idx) + " is out of bounds"
167+
);
168+
}
169+
const auto& system = systems[system_idx];
170+
if (first_atom_idx < 0 || first_atom_idx >= system->size()) {
171+
C10_THROW_ERROR(ValueError,
172+
"invalid first_atom index in samples for '" + name + "' output: " +
173+
std::to_string(first_atom_idx) + " is out of bounds for system " +
174+
std::to_string(system_idx)
175+
);
176+
}
177+
if (second_atom_idx < 0 || second_atom_idx >= system->size()) {
178+
C10_THROW_ERROR(ValueError,
179+
"invalid second_atom index in samples for '" + name + "' output: " +
180+
std::to_string(second_atom_idx) + " is out of bounds for system " +
181+
std::to_string(system_idx)
182+
);
183+
}
184+
}
185+
} else {
186+
C10_THROW_ERROR(ValueError,
187+
"got invalid sample_kind '" + request->sample_kind() + "' for '" + name + "'"
188+
);
141189
}
142190

143191
if (expected_samples->set_union(block->samples())->size() != expected_samples->size()) {
@@ -594,10 +642,10 @@ static void check_heat_flux(
594642
validate_single_block("heat_flux", value);
595643

596644
// Check samples values from systems
597-
if (request->per_atom) {
645+
if (request->sample_kind() == "atom") {
598646
C10_THROW_ERROR(ValueError,
599-
"invalid 'heat_flux' output: heat flux cannot be per-atom, but the request "
600-
"indicates `per_atom=True`"
647+
"invalid 'heat_flux' output: heat flux cannot be per-atom, "
648+
"but the request indicates `sample_kind='atom'`"
601649
);
602650
}
603651
validate_atomic_samples("heat_flux", value, systems, request, torch::nullopt);

0 commit comments

Comments
 (0)