diff --git a/bgflow/distribution/distributions.py b/bgflow/distribution/distributions.py index 6bd4dd7..7d70870 100644 --- a/bgflow/distribution/distributions.py +++ b/bgflow/distribution/distributions.py @@ -115,3 +115,16 @@ def _energy(self, x): def _sample_with_temperature(self, n_samples, temperature): return self._sample(n_samples) + + def cdf(self, x): + return (x-self._delegate.base_dist.low)/self.domain + + def icdf(self, x): + return x*self.domain + self._delegate.base_dist.low + + def log_prob(self, x): + return torch.log(1/self.domain) + + @property + def domain(self): + return self._delegate.base_dist.high - self._delegate.base_dist.low \ No newline at end of file diff --git a/bgflow/factory/GNN_factory.py b/bgflow/factory/GNN_factory.py index 43248a5..5d1bc8d 100644 --- a/bgflow/factory/GNN_factory.py +++ b/bgflow/factory/GNN_factory.py @@ -22,7 +22,8 @@ from typing import Optional, Union, Callable -__all__ = ["NormalizedBasis", "CustomTransformerEncoderLayer", "make_allegro_config_dict"] +__all__ = ["NormalizedBasis", "CustomTransformerEncoderLayer", + "make_allegro_config_dict"] # define the Normalized Basis that is also shifted a bit to increase stability: @@ -52,24 +53,25 @@ class NormalizedBasis(torch.nn.Module): def __init__( self, - data = None, + data=None, original_basis=BesselBasis, original_basis_kwargs: dict = {}, norm_basis_mean_shift: bool = True, - offset = 1. + offset=1. ): super().__init__() self.offset = offset - #### shift all entries to the right a bit. + # shift all entries to the right a bit. data += self.offset - #### change r_max accordingly + # change r_max accordingly original_basis_kwargs["r_max"] += self.offset self.basis = original_basis(**original_basis_kwargs) self.num_basis = self.basis.num_basis with torch.no_grad(): if data is None: - raise ValueError("gotta pass data to inform the Basis Function") + raise ValueError( + "gotta pass data to inform the Basis Function") bs = self.basis(data) assert bs.ndim == 2 if norm_basis_mean_shift: @@ -87,8 +89,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return (self.basis(x_o) - self._mean) * self._inv_std - - class CustomTransformerEncoderLayer(torch.nn.Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. This standard encoder layer is based on the paper "Attention Is All You Need". @@ -152,7 +152,6 @@ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropou self.qkv_proj = torch.nn.Linear(d_model, 3 * d_model) - self.activation = activation def __setstate__(self, state): @@ -176,16 +175,18 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_paddin x = src if self.norm_first: - x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._sa_block(self.norm1(x), src_mask, + src_key_padding_mask) x = x + self._ff_block(self.norm2(x)) else: - x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask)) x = self.norm2(x + self._ff_block(x)) return x - # self-attention block + def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: qkv = self.qkv_proj(x) @@ -202,204 +203,191 @@ def _ff_block(self, x: Tensor) -> Tensor: return self.dropout2(x) - - - - - - allegro_hparams = { -"r_max" : 3.2, -"num_types" : 10, -"num_basis" : 32, -"p" : 6, -"avg_num_neighbors" : 9, -"num_layers" : 3, -"env_embed_multiplicity" : 32 , -"latent_dim" : 32, -"two_body_latent_intermediate_dims" : [128, 128, 128], -"nonscalars_include_parity" : False, -"irreps_edge_sh" : '1x0e+1x1o+1x2e', # embed direction vector to only vectors and scalar tensor would be :'1x0e+1x1o' -"RBF_distance_offset" : 1. , -"GNN_feature_dim" : 32, -"latent_resnet": True, -"GNN_scope": "atomwise" + "r_max": 3.2, + "num_types": 10, + "num_basis": 32, + "p": 6, + "avg_num_neighbors": 9, + "num_layers": 3, + "env_embed_multiplicity": 32, + "latent_dim": 32, + "two_body_latent_intermediate_dims": [128, 128, 128], + "nonscalars_include_parity": False, + # embed direction vector to only vectors and scalar tensor would be :'1x0e+1x1o' + "irreps_edge_sh": '1x0e+1x1o+1x2e', + "RBF_distance_offset": 1., + "GNN_feature_dim": 32, + "latent_resnet": True, + "GNN_scope": "atomwise" } -def make_allegro_config_dict(**kwargs): +def make_allegro_config_dict(r_max, + num_types, + num_basis, + p, + avg_num_neighbors, + num_layers, + env_embed_multiplicity, + latent_dim, + two_body_latent_intermediate_dims, + nonscalars_include_parity, + irreps_edge_sh, + RBF_distance_offset, + GNN_feature_dim, + latent_resnet, + GNN_scope + ): import allegro from allegro.nn._allegro import Allegro_Module - r_max = kwargs["r_max"] - num_types = kwargs["num_types"] - num_basis = kwargs["num_basis"] - p = kwargs["p"] - avg_num_neighbors = kwargs["avg_num_neighbors"] - num_layers = kwargs["num_layers"] - env_embed_multiplicity = kwargs["env_embed_multiplicity"] - latent_dim = kwargs["latent_dim"] - two_body_latent_intermediate_dims = kwargs["two_body_latent_intermediate_dims"] - nonscalars_include_parity = kwargs["nonscalars_include_parity"] - irreps_edge_sh = kwargs["irreps_edge_sh"] - RBF_distance_offset = kwargs["RBF_distance_offset"] - GNN_feature_dim = kwargs["GNN_feature_dim"] - latent_resnet = kwargs["latent_resnet"] - GNN_scope = kwargs["GNN_scope"] - - base_dict = {'one_hot': (nequip.nn.embedding._one_hot.OneHotAtomEncoding, - {'irreps_in': None, 'set_features': True, 'num_types': num_types}), - - - 'radial_basis': (nequip.nn.embedding._edge.RadialBasisEdgeEncoding, - {'basis': NormalizedBasis, - 'cutoff': nequip.nn.cutoffs.PolynomialCutoff, - 'basis_kwargs': {'data': None, - 'original_basis': nequip.nn.radial_basis.BesselBasis, - 'original_basis_kwargs': {'num_basis': num_basis, - 'trainable': True, - 'r_max': r_max}, - 'norm_basis_mean_shift': True, - 'offset': RBF_distance_offset}, - 'cutoff_kwargs': {'p': p, 'r_max': r_max}, - 'out_field': 'edge_embedding'}), - - - 'spharm': (nequip.nn.embedding._edge.SphericalHarmonicEdgeAttrs, - {'edge_sh_normalization': 'component', - 'edge_sh_normalize': True, - 'out_field': 'edge_attrs', - 'irreps_edge_sh': irreps_edge_sh}), - - - - - - 'allegro': (Allegro_Module, - {'avg_num_neighbors': avg_num_neighbors, - 'r_start_cos_ratio': 0.8, # unused - 'PolynomialCutoff_p': p, - 'per_layer_cutoffs': None, - 'cutoff_type': 'polynomial', - 'field': 'edge_attrs', - 'edge_invariant_field': 'edge_embedding', - 'node_invariant_field': 'node_attrs', - 'env_embed_multiplicity': env_embed_multiplicity, - 'embed_initial_edge': True, - 'linear_after_env_embed': False, - 'nonscalars_include_parity': nonscalars_include_parity, - 'two_body_latent': allegro.nn._fc.ScalarMLPFunction, - 'two_body_latent_kwargs': {'mlp_nonlinearity': 'silu', - 'mlp_initialization': 'uniform', - 'mlp_dropout_p': 0.0, - 'mlp_batchnorm': False, - 'mlp_latent_dimensions': [*two_body_latent_intermediate_dims, latent_dim]}, - 'env_embed': allegro.nn._fc.ScalarMLPFunction, - 'env_embed_kwargs': {'mlp_nonlinearity': None, - 'mlp_initialization': 'uniform', - 'mlp_dropout_p': 0.0, - 'mlp_batchnorm': False, - 'mlp_latent_dimensions': []}, - 'latent': allegro.nn._fc.ScalarMLPFunction, - 'latent_kwargs': {'mlp_nonlinearity': 'silu', - 'mlp_initialization': 'uniform', - 'mlp_dropout_p': 0.0, - 'mlp_batchnorm': False, - 'mlp_latent_dimensions': [latent_dim]}, - 'latent_resnet': latent_resnet, - 'latent_resnet_update_ratios': None, - 'latent_resnet_update_ratios_learnable': False, - 'latent_out_field': 'edge_features', - 'pad_to_alignment': 1, - 'sparse_mode': None, - 'r_max': r_max, - 'num_layers': num_layers, - 'num_types': num_types}), - } + base_dict = {'one_hot': (nequip.nn.embedding._one_hot.OneHotAtomEncoding, + {'irreps_in': None, 'set_features': True, 'num_types': num_types}), + + + 'radial_basis': (nequip.nn.embedding._edge.RadialBasisEdgeEncoding, + {'basis': NormalizedBasis, + 'cutoff': nequip.nn.cutoffs.PolynomialCutoff, + 'basis_kwargs': {'data': None, + 'original_basis': nequip.nn.radial_basis.BesselBasis, + 'original_basis_kwargs': {'num_basis': num_basis, + 'trainable': True, + 'r_max': r_max}, + 'norm_basis_mean_shift': True, + 'offset': RBF_distance_offset}, + 'cutoff_kwargs': {'p': p, 'r_max': r_max}, + 'out_field': 'edge_embedding'}), + + + 'spharm': (nequip.nn.embedding._edge.SphericalHarmonicEdgeAttrs, + {'edge_sh_normalization': 'component', + 'edge_sh_normalize': True, + 'out_field': 'edge_attrs', + 'irreps_edge_sh': irreps_edge_sh}), + + + + + + 'allegro': (Allegro_Module, + {'avg_num_neighbors': avg_num_neighbors, + 'r_start_cos_ratio': 0.8, # unused + 'PolynomialCutoff_p': p, + 'per_layer_cutoffs': None, + 'cutoff_type': 'polynomial', + 'field': 'edge_attrs', + 'edge_invariant_field': 'edge_embedding', + 'node_invariant_field': 'node_attrs', + 'env_embed_multiplicity': env_embed_multiplicity, + 'embed_initial_edge': True, + 'linear_after_env_embed': False, + 'nonscalars_include_parity': nonscalars_include_parity, + 'two_body_latent': allegro.nn._fc.ScalarMLPFunction, + 'two_body_latent_kwargs': {'mlp_nonlinearity': 'silu', + 'mlp_initialization': 'uniform', + 'mlp_dropout_p': 0.0, + 'mlp_batchnorm': False, + 'mlp_latent_dimensions': [*two_body_latent_intermediate_dims, latent_dim]}, + 'env_embed': allegro.nn._fc.ScalarMLPFunction, + 'env_embed_kwargs': {'mlp_nonlinearity': None, + 'mlp_initialization': 'uniform', + 'mlp_dropout_p': 0.0, + 'mlp_batchnorm': False, + 'mlp_latent_dimensions': []}, + 'latent': allegro.nn._fc.ScalarMLPFunction, + 'latent_kwargs': {'mlp_nonlinearity': 'silu', + 'mlp_initialization': 'uniform', + 'mlp_dropout_p': 0.0, + 'mlp_batchnorm': False, + 'mlp_latent_dimensions': [latent_dim]}, + 'latent_resnet': latent_resnet, + 'latent_resnet_update_ratios': None, + 'latent_resnet_update_ratios_learnable': False, + 'latent_out_field': 'edge_features', + 'pad_to_alignment': 1, + 'sparse_mode': None, + 'r_max': r_max, + 'num_layers': num_layers, + 'num_types': num_types}), + } atomwise_dict = {'atomwise_gather': (allegro.nn._edgewise.EdgewiseReduce, - {'field': 'edge_features', - 'out_field': 'atomwise_features', - 'avg_num_neighbors': avg_num_neighbors} - ), - 'atomwise_linear': (nequip.nn._atomwise.AtomwiseLinear, - {'field': 'atomwise_features', - 'out_field': 'outputs', - 'irreps_out': f"{GNN_feature_dim}x0e"} - ) - } - bondwise_dict = {'bondwise_linear': (nequip.nn._atomwise.AtomwiseLinear, ## it is still edgewiese - {'field': 'edge_features', - 'out_field': 'outputs', - 'irreps_out': f"{GNN_feature_dim}x0e"} - ) - } + {'field': 'edge_features', + 'out_field': 'atomwise_features', + 'avg_num_neighbors': avg_num_neighbors} + ), + 'atomwise_linear': (nequip.nn._atomwise.AtomwiseLinear, + {'field': 'atomwise_features', + 'out_field': 'outputs', + 'irreps_out': f"{GNN_feature_dim}x0e"} + ) + } + bondwise_dict = {'bondwise_linear': (nequip.nn._atomwise.AtomwiseLinear, # it is still edgewiese + {'field': 'edge_features', + 'out_field': 'outputs', + 'irreps_out': f"{GNN_feature_dim}x0e"} + ) + } if GNN_scope == "atomwise": - dict = base_dict|atomwise_dict + dict = base_dict | atomwise_dict elif GNN_scope == "bondwise": - dict = base_dict|bondwise_dict + dict = base_dict | bondwise_dict else: raise ValueError('GNN_scope must be either "atomwise" or "bondwise"') return dict -#allegro_config_dict = make_allegro_config_dict(**allegro_hparams) - nequip_hparams = { "r_max": 3.2, "num_types": 10, "num_basis": 32, "p": 6, - "avg_num_neighbors": 9, - "num_layers": 3, - "latent_dim": 32, - "nonscalars_include_parity": False, # True - "irreps_edge_sh": '1x0e+1x1o+1x2e', # change this to "1x0e" to reduce nequip to schnett basically "RBF_distance_offset": 1., "GNN_feature_dim": 32, - - "num_interaction_blocks": 4 + "num_interaction_blocks": 4, + "irreps_edge_sh": '1x0e+1x1o+1x2e', # change this to "1x0e" to reduce nequip to schnett basically } -def make_nequip_config_dict(**kwargs): ###rename to make_nequip_config_dict - r_max = kwargs["r_max"] - num_types = kwargs["num_types"] - num_basis = kwargs["num_basis"] - p = kwargs["p"] - irreps_edge_sh = kwargs["irreps_edge_sh"] - RBF_distance_offset = kwargs["RBF_distance_offset"] - GNN_feature_dim = kwargs["GNN_feature_dim"] - num_interaction_blocks = kwargs["num_interaction_blocks"] - - base_dict = {'one_hot': (nequip.nn.embedding._one_hot.OneHotAtomEncoding, - {'irreps_in': None, 'set_features': True, 'num_types': num_types}), - - - 'radial_basis': (nequip.nn.embedding._edge.RadialBasisEdgeEncoding, - {'basis': NormalizedBasis, - 'cutoff': nequip.nn.cutoffs.PolynomialCutoff, - 'basis_kwargs': {'data': None, - 'original_basis': nequip.nn.radial_basis.BesselBasis, - 'original_basis_kwargs': {'num_basis': num_basis, - 'trainable': True, - 'r_max': r_max}, - 'norm_basis_mean_shift': True, - 'offset': RBF_distance_offset}, - 'cutoff_kwargs': {'p': p, 'r_max': r_max}, - 'out_field': 'edge_embedding'}), - - - 'spharm': (nequip.nn.embedding._edge.SphericalHarmonicEdgeAttrs, - {'edge_sh_normalization': 'component', - 'edge_sh_normalize': True, - 'out_field': 'edge_attrs', - 'irreps_edge_sh': irreps_edge_sh}), - - 'chemical_embedding': (nequip.nn._atomwise.AtomwiseLinear, - {'field': 'node_features', - 'out_field': None, - 'irreps_out': '32x0e'}) - } +def make_nequip_config_dict(r_max, + num_types, + num_basis, + p, + irreps_edge_sh, + RBF_distance_offset, + GNN_feature_dim, + num_interaction_blocks + ): + + base_dict = {'one_hot': (nequip.nn.embedding._one_hot.OneHotAtomEncoding, + {'irreps_in': None, 'set_features': True, 'num_types': num_types}), + + + 'radial_basis': (nequip.nn.embedding._edge.RadialBasisEdgeEncoding, + {'basis': NormalizedBasis, + 'cutoff': nequip.nn.cutoffs.PolynomialCutoff, + 'basis_kwargs': {'data': None, + 'original_basis': nequip.nn.radial_basis.BesselBasis, + 'original_basis_kwargs': {'num_basis': num_basis, + 'trainable': True, + 'r_max': r_max}, + 'norm_basis_mean_shift': True, + 'offset': RBF_distance_offset}, + 'cutoff_kwargs': {'p': p, 'r_max': r_max}, + 'out_field': 'edge_embedding'}), + + + 'spharm': (nequip.nn.embedding._edge.SphericalHarmonicEdgeAttrs, + {'edge_sh_normalization': 'component', + 'edge_sh_normalize': True, + 'out_field': 'edge_attrs', + 'irreps_edge_sh': irreps_edge_sh}), + + 'chemical_embedding': (nequip.nn._atomwise.AtomwiseLinear, + {'field': 'node_features', + 'out_field': None, + 'irreps_out': '32x0e'}) + } conv_dict = {} for i in range(num_interaction_blocks): @@ -407,11 +395,11 @@ def make_nequip_config_dict(**kwargs): ###rename to make_nequip_config_dict {'convolution': nequip.nn._interaction_block.InteractionBlock, 'convolution_kwargs': {'invariant_layers': 2, 'invariant_neurons': 64, - 'avg_num_neighbors': None, ### have to set this from data + 'avg_num_neighbors': None, # have to set this from data 'use_sc': True, 'nonlinearity_scalars': {'e': 'silu', 'o': 'tanh'}}, - 'num_layers': 4, ## this is a dead end argument, I believe. + 'num_layers': 4, # this is a dead end argument, I believe. 'resnet': False, 'nonlinearity_type': 'gate', 'nonlinearity_scalars': {'e': 'silu', @@ -419,28 +407,30 @@ def make_nequip_config_dict(**kwargs): ###rename to make_nequip_config_dict 'nonlinearity_gates': {'e': 'silu', 'o': 'tanh'}, 'feature_irreps_hidden': '32x0e+32x1e+32x0o+32x1o'}) - } + } conv_dict = {**conv_dict, **conv_dict_i} - - - - - output_dict = { - 'self_interaction_0': (nequip.nn._atomwise.AtomwiseLinear, - {'field': 'node_features', - 'out_field': 'outputs', - 'irreps_out': f'{GNN_feature_dim}x0e'}) - } + 'self_interaction_0': (nequip.nn._atomwise.AtomwiseLinear, + {'field': 'node_features', + 'out_field': 'outputs', + 'irreps_out': f'{GNN_feature_dim}x0e'}) + } - dict = base_dict | conv_dict | output_dict ## join the dictionaries + dict = base_dict | conv_dict | output_dict # join the dictionaries return dict -#nequip_config_dict = make_nequip_config_dict(**nequip_hparams) - class NequipWrapper(torch.nn.Module): + """ + This is a wrapper for nequip-style GNNs (this also includes allegro GNNs). + It then constructs an atomicdatadict, which is what the nequip GNN acts on. + The atomicdatadict contains coordinates, atom types (here every atom gets a unique type) + and edges (these are constructed by the distance alone, and not from the molecular bondgraph). + the invariant feature vectors the GNN outputs are flattened and returned. + """ + + def __init__( self, nequip_GNN: Union[torch.nn.Module, Callable], @@ -450,30 +440,37 @@ def __init__( ): super().__init__() if isinstance(nequip_GNN, torch.nn.Module): - self.GNN = nequip_GNN #register the GNN that is to be used in the conditioner + self.GNN = nequip_GNN # register the GNN that is to be used in the conditioner elif callable(nequip_GNN): - self.GNN = nequip_GNN(**kwargs) #create a GNN from the constructor, using the kwargs + # create a GNN from the constructor, using the kwargs + GNN = nequip_GNN.from_parameters(**kwargs) + self.GNN = GNN self.output_field = output_field self.cutoff = cutoff - def forward(self, x): #this takes as input just the cartesian coordinates of the atoms. + + + # this takes as input just the cartesian coordinates of the atoms. + def forward(self, x): batchsize = x.shape[0] n_cart_atoms = x.shape[1] distances = torch.cdist(x, x) in_bonds = distances <= self.cutoff indices = in_bonds.nonzero() - indices = indices[indices[:, 1] != indices[:, 2]] # remove edges from node to itself: + # remove edges from node to itself: + indices = indices[indices[:, 1] != indices[:, 2]] edge_index = torch.vstack([indices[:, 1], indices[:, 2]]) batch_index = indices[:, 0] offset_tensor = (batch_index * n_cart_atoms).repeat((2, 1)) - edge_index_batch = edge_index + offset_tensor # make sure that all edges are in their respective graphs (multiple graphs in a batch) + # make sure that all edges are in their respective graphs (multiple graphs in a batch) + edge_index_batch = edge_index + offset_tensor atom_types = torch.arange(n_cart_atoms).repeat(batchsize) batch = torch.arange(x.shape[0]).repeat_interleave(n_cart_atoms) r_max = torch.full((batchsize,), self.cutoff) - ## this is an atomicdatadict, it is used by the nequip GNNs to keep track of multiple graphs in a batch. - ## The graphs may have a different number of atoms in their neighborhoods, so each graph in a batch can have a different number of edges + # this is an atomicdatadict, it is used by the nequip GNNs to keep track of multiple graphs in a batch. + # The graphs may have a different number of atoms in their neighborhoods, so each graph in a batch can have a different number of edges data = dict(pos=x.view(-1, 3), edge_index=edge_index_batch.to(x.device), batch=batch.to(x.device), @@ -484,28 +481,39 @@ def forward(self, x): #this takes as input just the cartesian coordinates of th features = feature_vectors.view(batchsize, -1) return features + class WrapDistancesGNN(torch.nn.Module): - def __init__(self, **kwargs): + """ + A wrapdistancesGNN is a kind of GNN that takes as input the cartesian coordinates and returns a feature vector + that contains just the pairwise distances of all pairs that are less than r_max apart. + It does not return the distances directly, but instead embeds them into N radial basis functions + """ + + def __init__(self, num_basis, r_max, env_p): super().__init__() - self.N = kwargs["num_basis"] # number of RBFs - self.c = kwargs["r_max"] # cutoff - self.p = kwargs["env_p"] # envelope parameter + self.N = num_basis # number of RBFs + self.c = r_max # cutoff + self.p = env_p # envelope parameter self.wrapdistancespure = WrapDistances(torch.nn.Identity()) - wavenumbers = torch.Tensor([torch.pi*(i+1)/self.c for i in range(self.N)]) + wavenumbers = torch.Tensor( + [torch.pi*(i+1)/self.c for i in range(self.N)]) self.wavenumbers = torch.nn.Parameter(wavenumbers) - - def bessel(self, x, wavenumber, c): return ((2 / c) ** 0.5 * torch.sin(wavenumber * x) / x) + def envelope(self, x, c, p): - #taken from the paper: https://arxiv.org/pdf/2003.03123.pdf + # taken from the paper: https://arxiv.org/pdf/2003.03123.pdf x = x/c return(1 - ((p+1)*(p+2))/2*x**p + p*(p+2)*x**(p+1) - (p*(p+1))/2*x**(p+2)) + def forward(self, x): - distances = self.wrapdistancespure(x.view(x.shape[0],-1)) - binned_distances = self.bessel(distances[...,None], self.wavenumbers, self.c) + distances = self.wrapdistancespure(x.view(x.shape[0], -1)) + distances = torch.clamp(distances, min=None, max=self.c) + binned_distances = self.bessel( + distances[..., None], self.wavenumbers, self.c) u = self.envelope(distances, c=self.c, p=self.p) - binned_enveloped_distances = (u.unsqueeze(-1) * binned_distances).view(x.shape[0], -1) + binned_enveloped_distances = ( + u.unsqueeze(-1) * binned_distances).view(x.shape[0], -1) return binned_enveloped_distances diff --git a/bgflow/factory/icmarginals.py b/bgflow/factory/icmarginals.py index 2c8adce..5562367 100644 --- a/bgflow/factory/icmarginals.py +++ b/bgflow/factory/icmarginals.py @@ -158,3 +158,29 @@ def inform_with_data( lower_bound=torch.as_tensor(torsion_lower, **self.ctx), upper_bound=torch.as_tensor(torsion_upper, **self.ctx), ) + + def inform_with_data_o( + self, + data, + coordinate_transform, + o_lower=-2.5, + o_upper=2.5, + origin=None + + ): + with torch.no_grad(): + bond_values, angle_values, torsion_values, origin_values, rotation, _ = coordinate_transform.forward(data) + + + assert o_lower < origin_values.min(), "Set a smaller o_lower" + assert o_upper > origin_values.max(), "Set a larger o_upper" + o_mu = origin_values.mean(axis=0).squeeze() + o_sigma = origin_values.std(axis=0).squeeze() + import ipdb + #ipdb.set_trace() + self[origin] = TruncatedNormalDistribution( + mu=torch.as_tensor(o_mu, **self.ctx), + sigma=torch.as_tensor(o_sigma, **self.ctx), + lower_bound=torch.as_tensor(o_lower, **self.ctx), + upper_bound=torch.as_tensor(o_upper, **self.ctx), + ) diff --git a/bgflow/nn/flow/cdf.py b/bgflow/nn/flow/cdf.py index c6b61e1..ffad43d 100644 --- a/bgflow/nn/flow/cdf.py +++ b/bgflow/nn/flow/cdf.py @@ -33,6 +33,8 @@ def _forward(self, x, *args, **kwargs): logdet = self.distribution.log_prob(x) if self._eps is not None: logdet = logdet.clamp_min(-1/self._eps) + import ipdb + #ipdb.set_trace() return y, logdet.sum(dim=-1, keepdim=True) diff --git a/bgflow/nn/flow/crd_transform/ic.py b/bgflow/nn/flow/crd_transform/ic.py index 9e3ed10..b19c5bb 100644 --- a/bgflow/nn/flow/crd_transform/ic.py +++ b/bgflow/nn/flow/crd_transform/ic.py @@ -197,8 +197,8 @@ def _forward(self, x0, x1, x2, *args, **kwargs): alpha, dlogp_alpha = normalize_torsions(alpha) dlogp += dlogp_alpha - # beta, dlogp_beta = normalize_angles(beta) - # dlogp += dlogp_beta + #beta, dlogp_beta = normalize_angles(beta) + #dlogp += dlogp_beta gamma, dlogp_gamma = normalize_torsions(gamma) dlogp += dlogp_gamma orientation = torch.cat([alpha, beta, gamma], dim=-1) diff --git a/bgflow/nn/flow/crd_transform/ic_helper.py b/bgflow/nn/flow/crd_transform/ic_helper.py index af1c8cc..fcc1f82 100644 --- a/bgflow/nn/flow/crd_transform/ic_helper.py +++ b/bgflow/nn/flow/crd_transform/ic_helper.py @@ -331,12 +331,14 @@ def _to_euler_angles(x, y, z): """ converts a basis made of three orthonormal vectors into the corresponding proper x-y-z euler angles output values are alpha in [-pi, pi] - beta in [0, pi] + beta in [0, pi] #beta in [-1,1] gamma in [-pi, pi] """ + import ipdb + #ipdb.set_trace() alpha = torch.atan2(z[..., 0], -z[..., 1]) beta = z[..., 2] - # beta = torch.acos(z[..., 2]) + #beta = torch.acos(z[..., 2]) gamma = torch.atan2(x[..., 2], y[..., 2]) return alpha, beta, gamma @@ -638,6 +640,8 @@ def _callback(xs): # and compute the euler angles given this basis (range is [0, pi]) alpha, beta, gamma = _to_euler_angles(*basis) + import ipdb + #ipdb.set_trace() # now we flatten the outputs (x0, ics, euler angles) into a 9-dim output vec ys = torch.cat( @@ -676,5 +680,7 @@ def _callback(xs): ) dlogp = det.abs().log() - + import ipdb + #ipdb.set_trace() + x0=x0.squeeze(1) return x0, d01, d12, a012, alpha, beta, gamma, dlogp diff --git a/bgflow/nn/training/trainers.py b/bgflow/nn/training/trainers.py index d38982d..8da6c0c 100644 --- a/bgflow/nn/training/trainers.py +++ b/bgflow/nn/training/trainers.py @@ -16,7 +16,7 @@ class LossReporter: """ def __init__(self, *labels): - self._labels = labels + self._labels = list(labels) self._n_reported = len(labels) self._raw = [[] for _ in range(self._n_reported)] @@ -44,6 +44,10 @@ def losses(self, n_smooth=1): def recent(self, n_recent=1): return np.array([raw[-n_recent:] for raw in self._raw]) + def add_loss(self,lossname): + self._labels.append(lossname) + self._n_reported +=1 + class KLTrainer(object): def __init__( @@ -78,8 +82,11 @@ def __init__( self.w_likelihood = 1.0 if test_likelihood: loss_names.append("NLL(Test)") - self.reporter = LossReporter(*loss_names) self.custom_loss = custom_loss + if self.custom_loss: + loss_names.append("custom_loss") + self.reporter = LossReporter(*loss_names) + def train( self, diff --git a/notebooks/cgn_GNN_example.ipynb b/notebooks/cgn_GNN_example.ipynb index d6c2150..7cd66da 100644 --- a/notebooks/cgn_GNN_example.ipynb +++ b/notebooks/cgn_GNN_example.ipynb @@ -99,7 +99,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/srv/public/mameyer/bgflow/bgflow/distribution/energy/openmm.py:197: UserWarning: It looks like you are using an OpenMMBridge with multiple workers in an ipython environment. This can behave a bit silly upon KeyboardInterrupt (e.g., kill the stdout stream). If you experience any issues, consider initializing the bridge with n_workers=1 in ipython/jupyter.\n", + "/home/marcel/BG/bgflow/bgflow/distribution/energy/openmm.py:197: UserWarning: It looks like you are using an OpenMMBridge with multiple workers in an ipython environment. This can behave a bit silly upon KeyboardInterrupt (e.g., kill the stdout stream). If you experience any issues, consider initializing the bridge with n_workers=1 in ipython/jupyter.\n", " warnings.warn(\n" ] } @@ -149,7 +149,7 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAANwAAADUCAYAAAD3CU3sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAobElEQVR4nO2de4wd93Xfv2dm7vvuch9ckhIpSiwtmY4lPyqCli0UjiXLlSK3QYwIsJM6DVpAKFAXNmJHdRrATf1PVcc2WjQGWqExkhauAz9kNH4IlmorMuxGpkXHjqRQlkwzFimJ5HK5y92975k5/eP3O7+ZvXv37pLcndnH+QCLuzszd+Z3786Zc37n9SNmhqIo2eDlPQBF2UmowClKhqjAKUqGqMApSoaowClKhqjAKUqG5C5wROQT0d8Q0TfyHouibDS5CxyADwM4mfcgFCULchU4IjoA4H4A/yPPcShKVuSt4f4zgIcAxDmPQ1EyIcjrwkT0XgAXmPkEEf3qkOMeBPAgANRqtduPHDmSzQAV5Qo4ceLERWaeWu04yiuXkoj+I4APAggBlAGMAniUmf/ZSu85evQoP/PMMxmNUFHWDhGdYOajqx2Xm0nJzH/AzAeY+SYA7wfw3WHCpijbgbzncIqyo8htDpeGmf8KwF/lPAxF2XBUwylKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAbQPu8R7APd4DeQ9DWQMqcIqSIbkJHBGVieg4Ef2UiJ4nov+Q11i2Ok/EX8YT8Zfd34O0nWrAzUGePU06AO5i5kUiKgD4PhE9xsxP5zimLY0IVfSu293v3fuPAQCqhw8NfU9aYJWNIzeBY9MQc9H+WbA/uuC4sq3JtWsXEfkATgB4HYDPMfMP8xzPVuQe7wFE77odANB539sAAFGJgA/cseS4zrG9uOvuh83+qg8AeOrrv5/hSBUgZ4Fj5gjAW4hoDMDXiOhWZn4ufUy61fnBgwezH+Qm4h7vAWf63XfzQ2570OguOW7u5hK80BgL8zeROaYNxIUiACAsm23vueOTS84NqGm50WwKLyUzz8H0pbx3wL5HmPkoMx+dmlq1dbuibGryXMxjCkCPmeeIqALg3QD+U17j2WqEp04DMA6S0suXAADnj10HAPAioDVmtFg4ahYmiioECs224oI5x/TtI9h9PMtRK3malNcB+HM7j/MAfImZdRVUZVuTp5fybwG8Na/rbyXS86v+eFrQ6KJzcAIAUH81AgDMHvERlZaeIxqJ0CibGUSnYV7HXmTQsdsAAI8//YmB11PWl02xtoAynEE3vmy76+6HndNk5o1GyoIG0JkwTpO4bExK9gAqmN+DphG4xQOE7sgIgMSB8vjTn1BB20A2hdNEUXYKKnCbmLUkJZdevoTOVAWdqQquf+w1XP/YawgrRqOxByAi80MM9s1P+/oQ7etDxEUgrJkff2YR/sziknCDsv6owClKhugcbhMjcynRcsHhQ26bzLnCgxPojJrnZuvOfQCA6gUGByYE0B0z+7hHbj6H2IYM6jGKc2Z/80gS47zjtz8DAHj6Cx8FYILsj730qQ34hDsPFbgtwCAnBh9/FgBQOnwI/pMmJrdg07nCsgcKzXHBohGu3hjDa5qUrsKCFbgaI6ya4y6+0dwKxQWget54O10mymR9nT/RzkVNSkXJENVwm5R03uQgAltu0zk4Ad9mnYyeagAAGgeqKFot1txjTcrAQ1wwoQI2uxLHCgC2d0L1fITSrFGPYa04cFyAxuiuFhW4TcpqN7SkdvmnTrvgtZiZ1ePA5d95OwAgaJnjJ56Pcfmwka7yjNnWGSdnepKd3tXONtE4YOxMEbyg0XXmpQrataEmpaJkiGq4LcR9Nz/kNJtoNX9mEaHVbEGqqrt+xmSfSO3bwv4AxTmzL7KWYv0su1Kdxn6zrTNVQXOveU9rt3kdewn47nc+7sYAQL2WV4lqOEXJEDKdDrYGR48e5WeeeSbvYWwK0hkoou3EyeE/eSKpAh9fbsT06uY529hHqFw0///R00YjLt5QXKYdqy9MO80qBIcPLdu2k+d3RHSCmY+udpyalJuMtXoBB3Xp8lP7xMkhic3+zCL6Kbx5j2nHgMRkvOO3P+PeE1UrAExQvGiFS8zWx176lHosrwI1KRUlQ1TDbTKuVFuslGwsGi2yWSLNI1PwmyaDpLmvAAConuvBf/IEAODOzqcBACOP/hC+1WK+NVFLL19Cx5qo1Nc/RbkyVMMpSoao02STsd7zItGA0WTdBcbFoQIkczxxuERVH6Xp1pJtpZcvYfbYXgDAyBefdueQXioaIlCnyZZlkDPkSrel08JEGN5zxyeXCfE93gOJ+WizT6LJOjpTxlniBO/UaYzb9zyWusZj6iy5YvJcW+AGInqSiE7atQU+nNdYFCUr8tRwIYCPMvOPiWgEwAkieoKZ/y7HMeVOWlsN62Wy2rb+86WzUFzIYIX1BkIbKpA+RMHhQ05TDmpolA4VKMPJTcMx82vM/GP7+wKAkwD25zUeRcmCTTGHI6KbYFrmLVtbYKe0Or9SZ8kgTSPvTZ9jUGt0CRmkM0XS87+qbQ7bsSvvFK2zBehzuFjN1p9xoqxM7gJHRHUAXwXwEWae79/PzI8AeAQwXsqMh5cZVyNo8nd/K4bVzi+m3113P+w8jQIduy0p83lh2hyfeq/E7dJtF3TtubWTaxzOrgv3VQBfYOZH8xyLomRBbnE4IiIAfw7gEjN/ZC3v2QlxOGG1iu+V3gMMNjPv8R5YZgKuloAsZmh46vQyB8tKZuROzavcCnG4OwF8EMCzRPQTu+3fMfO38hvS5mGlG3fYXG/QtmFexWiyDpwyx0nFwUrXGCTA/ddV03J18lxb4PsAKK/rK0oe5O40UTaWtJZa5lU8NVgr9puPg8zb1TSslu4MRpOXFSVDNHk5R9ZDCwwKC6ymadIr5cj+9LJVg+Zig7JJhoUodhpbwWmyrVmLMF2tFzL93kFB7pUcGvI7pwLZ/a3TB60jrh7J9UNNSkXJENVwG8R6Pv1F06z1nOmeJsPGk9aEzqGCq08zU1ZHNZyiZIhquA1iPd3iMocaNDdb6VrpNbv7WS0PU7TdsMyVYedVVkYFboNYzYRb6Zi1nnNQ3dwggUx3Su7PNFkpRpeuEu8/3yDBU0FbO2pSKkqGqIbbAFaLR62W8zhsW39bvHRS8qDzira67+aHBmrCQVpXruGntg9zjOh6A2tHBe4aWSntaS1B4WFCttJxMp8b9p5BsbRB11vpwbAWwVlpPqkMZ1WT0tas9W/bvTHDUZTtzYqpXUT0LgD/C6aXzN8AeJCZ/97u+zEz/8OsBinkndq11tSl1Y7r1zSrVU8P00j33fwQOgcnACTV2FeicfrNwZXSusK+tQXS7xnETkteXmtq1zAN9ykA/5iZp2BaHDxBRHfI+ddhjIqy4xg2hysy8/MAwMxfIaKTAB4loo8D2DoZz+vIakWhqx230v5osr7MGZLWKnfd/TAAoHT40NBryb7V4mvpRR0fH9L+bhBpTTdMGyuDGSZwPSLax8znAICZnyeiuwF8A8DhTEa3RVirmbnS8Xz8WdiltpdUYz/x0tLj3nPHJ12F9iCB6l/rO32+QctL3eM94ARdjrvv5oeWVQasJLQrOYEGfUbFMMyk/DiAvekNzHwWwDsBPLyRg1KU7cowp8kUgKn+TshE9EYAF5h5+povTvR5AO+157t1tePzdpoMYq31aGmGlb8M0iByjub73obqo0tbd6Z7kYhmSycvD1qIUZh/8x4EHfP/f+rrv++uNajfZLpeLj329HE7ueJ7PZwm/xXA1IDtBwD8l6sdWB9/BuDedTqXomx6hs3hbmPmp/o3MvO3iegz63FxZv6e7bq8ZVhr16z+kpq0619I5yym512A0XTv/Cd/DADwbbfjH3zlY86BIksEp88r10xrqfS10ktXAUDtbHPZZ0i/ryudl7953C1d1X8OAG6uuZMrvtfKMJPyRWa+ZYV9P2Pm16/LAIzAfWMlk7Kv1fntv/zlL9fjsuvKMCEclDw8zKRMm4jzh2sAzEqlgFm5dPz4eQBw67UBcNtm7twHAKif6bqOyiIYfPzZZeO78zc/7UxK6bKcfk/aHO1PaH786U844U/H//o/U9rbup1bMayHSfkSEf3agBPfB+AX1zK4K4GZH2Hmo8x8dGpqkIWrKFuHYRruFpgQwP8DcMJuPgrg7QDey8wvrssAVtFwaTa70yS9TUibaOn4V5rOVAXFb5oVNGSxjPSqo4M03PQ7jTYbPd1FVDVpxrKAIpCsXtoZN7OG0Z9ecBkpQlT13XXFfKy+ML3EXASMpmsemXL7+0lnulxpQ6PtwjVrOCtQtwF4CsBN9ucpAG9aL2FTlJ1Grm3yiOiLAH4VwG4A5wH8e2b+05WO34waDhiehT9sziOktWBaq4nm8psRAKPh4sLSrLr6ma5zoNz5m59228Uhkp6H9edczjz4DhQXYnc9AFi8oeiOH3vBvLczVXGLNMp5w1px2Txx0JxvJbabtlurhhtmUjYARIN2AWBmHr22IV45m1XgRKgGxcFk28IH7nA3tSA3bPPIlBOq2SNm3dHq+QidMWOAdEfMzU4RENvaDbbFan4bqL+69N80f6OP0mXzfy3NGYHq1T14PbttPnbHdkblGuZ1908WcOlW68U8b87bmvRRP9MFAAQN89o4UEVkhTD9uWS/MKzVw3ZiPfpSvsjMb13HMSnKjmeYwG37BOWrmdj3Z46kFzBMm5GdqQoAQAy00ny8RDsYjInXGfUAq2l866bvjHlOs0V2sW2vB5BVTkHbvHZ2AZf/gVF3sb1YYR7wbHJma7fZF1aA6gVzbjEPoxI5bSem6vljI0571ox/BnGQOF/EzG3u9TH6srmI7OvVPdTP2PHZz3rX3Q+739NZKoOyU3YCwwRuDxH93ko7mfmzGzCeTFlrXCjdfKc/hpY25sSM7ExVnIkoXsewRKjaOU5k42sX32QkqbjAzvRr7DOCF5WA3og5b/GyeWUfKDTM7609dpuXCFf9rBWoMmHxgJh7ZtvImdiZqCJks7f4qF4wx7UnErNVhPq1t5nbozgPlObMNvlcfsdHazLdhAHweozmvoK9rtkWNLrugbQkaH5q6XeWjhNut/ldmmEC5wOoQ2vfFGXdGOY0yaWqexjr7TRZbU3sQZkh/ZXP4anTLoYlcTB/ZhHn3nOd2TaXOCjCSmLKAUBkTcCoZMxFAAiN8kNvhFG8bI4LrQXKBMQl8//y23IuRmHB/C6mYNBIPoecr3rOaC8A8EJzju4IOXOVbYCIg8QxI8dPPhdh9hZz8rJdErw7mpiokz8w6mz+zXtQmpVCowQxKUXTRe+63XlW16rNNrvWW49ME9VsirLODDMp785sFDmx0hJOgsy/0k9XOU7mbt0jx9y8Rp7gZz72DhTnzX5xgohjAwCisnklqwziItCzyR2ipaIyoxtbzVUw54hGIiC02zzzrIwL7OZcohG7u4DCgr2GHejCjcnnrp21mrMChJJUYs/hhcl7hPkbk7E39pvX0iVg7nXmPL3adW6/5GYu7De3VuVi5BwtSGWzrIX+xrXpbZtV063GigLHzJeyHEgeDOtUDADoS4UClppFgDEjJY1q4QOm5cvkyRAX32i+2qCdGBEiVGLG+R3zSjEQ1syNGlfsnc9AOGbv/JJ97XmAFUKxP7yQnJnZ3G/eWz7vO49lb8Ts4wLD65g3Ld5kL+ED7Jv9FKWcJj3zu5i5rb1IPUCS76Jg49zjL5iNnfEAl26xgjaTTFVEwMQc77zr9qH9NYX0/0S+b+Guux9eUjEBDF+7brOgnZcVJUN0BdQ+0k/edGkLsNS0kTSp5r6Cc7PPHzRPd7/Dzs0urn32gdhaVnHZfOdBwx4zGoNHjX3pF602o9T/ha2DpO27RyT55po8VwSNWafEovF2eE0PntVSvUmjpqjtg8vm3NS1J4kJ8Ox1Avsapqbu9le/6TknTVgxxxUWCL6NBdZfTcIRQdvuXzTja+32XfqYfE/VF6aXOZ/SCdNiMQjf/c7HXVpc2gHT/15/ZjG37s/r4TRRFGWd0VbnGBz4Dk+dBk2aMppBVduuGrvDLnNDnA1xQE6zies/aAG9SauVitYtb+do5MfwC2YfWc3m+exyfYLAHBcGPnzRbFbr9XxGdNlexGqreLILzJptFNpnajEGieu/aM7hlSLEXd8eZ87nj3cR9eyBi+b2iIuMyM4TxUHDXqIJezXze6HBLl9TnCdAksWSLh+SEqWO1WZBo+uSn8Pa0rnzXXc/7BwvF99kvtiRA29z+5t77WeIRpYtRLnZcjnVpLT0T7zFhAGS7Pp0HZukZ9XONp0JNPNG4w0Jq8mN2bZN4eNyjFjMtrrN1igZMzIoJG7BWtmYTIutEibqJjO/FxsBKHgxFtrmGgUrhK1OAWHXpl2JoDDg7EExTdu+c76QFUy/GIHtuf3UGDpz1qsjpyCAGnJTW4dKmxBY+Qlsp4biPJz5KHV7nYMTrgIhXX3Q76mMJusuHU68vvOHzPtq5yN3vFS1Lx4gF28Mzdvgd4CJF8132rBCOPHcYiZCpyalomxC1KS09Dc99ZGYkv0T9jRhreiSdyVDo3KR0dpt42Wi1WKCP2m8DLHVRCN18/fe+gJiayKGVuNMVpN0EdlX9COMlMx7Xr28CwBQr3TQ8oxWiYrmvQSg2zFjKlgtGpU8+IE1Ja3Wi5kQ+GZ/q2m1Sa2DXsneFlbDxW0fXpIw4/ZJBoxoOPYT81GcSp3xwDlQxGSsziw6TSVaz59ZRCn1nQJGOwHA3JE6fHu+9niSpRPaeGZvzJw/WPBwbpdNpB61pm2rNrB8Ki/UpFyBe7wHlnUypmO3DWxdIO0O4sDebGNA6zob/7JxLh4NUa6beNWeXSYq7dlJ2kix44SqHBgBqPpdtCMjwUUrFM2w6I673DF3WydKnpmd0JqWMYGssLQ75hzyPgAI/MR8jKyAF62J2ukGbn4YdoxZxqEHb86ep5DcL0FDKhzM37tOMeK+R7jU1KVJxy7leyzNhi7xWbyZkhwdlgldW30pQl5YAJr7ecmY4moEsnHKYNYmXs8l6Wgjr5jvsfjN4+sudGpSKsomRE3KFUgnL7v428wi/Bnz6+wRk84UHduLXnVpknFnYqlmA4wnslYxqqBrNZGYh0U/wmTRmJAT9jVkf4n2AoCK33PbxCxsp47pRrYujgk9G/TzxdyMPRStZpP3zqccM4LnxQjlPJLVEjCiqtVmC/a8LULZfhfitABMUyMgcXjMHQpcRXo6sVm8jqM/vWC+s4MTrnLcpYJZ2E8yXYTWvkTTitlOPgMdM85w1FwzLnhga3mMvGI/Tl9GS5bkquGI6F4i+hkR/dyuyqMo25rcNBwR+QA+B+AeAGcB/IiI/rJ/LYM86S+IpGO3Ye6IyWqYOmHmYZduraPQNE/YXt08SeMCOw0nmSPkxagUzBO8EpjXfRVzjpuqM+6aE9bXHYEw2zO1NQXPFrOyh3nrKaj45hwLYQnN0GiTsp3rzXdLTkvJnLDgRWiHZo4ksb5iIXJasdk25/A8TkIONogYdwJ4zaVah9gkSQNIMk7OdF1RrYQMxk6HLv4m353XKzhtJvPfoMVum4QFYuva79WS8qGiTcpevClG4bJUulsHFQHB7rb7HADQXSihZy2K6dvM640v5JdrmadJeQzAz5n5FwBARH8B4NcBbBqBG4Q005FJf1QitCfNPqkCgAd4u8xxhaK54Ucq7WWCdl3ZlHL7FGN3YLYVbPR8zG9iwm+4/QDQiEvYFZjrXrSRdY9iBHZ/aO/KoBSjaq/lzMduyZmwIpiXqIq2vRmrNv7XbBedKSlBcY7IOSYqlxKjqGKsQRd7m35LycUfKxeXO0vSTYykgVJp3JSul2ZDZ4aK80lgDxj7uTnfa//ImrYtD5FNkeMRG8+s9JwnVuKZc10fPRtHjwPzWWfu3Icxmw6WteDlaVLuB3Am9fdZu20JRPQgET1DRM9MT1/zgj2Kkit5arhBBa7LYhTM/AjMksc4evRoZjGMQd2TH3v6Ey6lq5Nq+uMquWu2Ic9oiFrNPGGjyBw3VmlhT8XElcQcrFt/+t7CZYxYu2zMarUCIkwFxlPQiG3rPK+D2Gox335VVa+L+cB4LWa6xmvTBlC3ZmjX+umLfuRCA13rUCl4EcplM5bZlnkvESO0Y2abVVK4FLj/jGvTl/rvtSZtJkwjcaCkHSQX32K08dhL5vMGja6riJfSnqjqu+p4qS6XXi0UAxffJJkuNgRQZHcHFe137fuxC2kIlVoXDft54qo5/9wtPjxrro8cR6bkqeHOArgh9fcBAK/mNBZFyYQ8NdyPANxMRIcAvALg/QB+K8fxLCEdFkhnocjSUZIV0R0JXHW1q572E8fDeM243SdKLYzY3nb7SkZz7bZv3Fe4jDHPHDfimWP8lLIf84zTocc+LkRGW+zyzfEexVi0Fa1TRZuZ0auiZzXhSGA0SNELnXNFNF0LBcw1zZNe5m3dboDQZqlIGU84EsGzv0tVOZBos+KCZHUwRl6xYQjr2m9NJuU5SVfnpPxm+i1m7BMvhq6rmKgBCQW09iSFtHE5To6p2iwaq8Eq5S7GK0mCNADU6l0s1M01Fjvm9VI4jtnX24vYouFBy4ltBLkJHDOHRPQhAN+GyaT6PDM/n9d4BjFoIi2JzIu/83YAQHsC6I7bybu9GchjF/+qFYy5s6vQQj0wv+8pmDtpzArNlD+PEc/sG7FOkxHPR49t12QrfD3uoWz7MpTJmILnwl1oFswN3LPmYysqIrD2ViM0N5mYogCw2LMOn1Rsrm3f6/nskpuddRYRAtuoSCrJS6l+AJJNMncocEIjPSu9Hrt0L6kqqJ6PUD9j9kvCd2vSd3047Vfhsku8bmJewtb5xWM9FKvmOxipJvHMXdYxtLdsHmaBFzmzWmKYvyx18MvTxlnTuGxrGN+8B1mQa+Cbmb8F4Ft5jkFRskQzTYbQ7zJ+Iv6yc5qE5ZRpNWI1mzVx9u+ddSGAsaJ54taCDm4omXibaLYbAvN3jXoYs06OKtn4Gfko299F07URwVqv6MHsK1KEqlUJi1aLTRUXMGvTXmLfag2KXVaKhAoKXoSejcN1ezbptx2AJcPEVoijEbjmRrFUAPmAbzXR3CHz3vqrkauDkzzIsRcWXfxt3+OvATDt9CRvUuiMkjNRvXDJLrBnMluARMOCgMjW+klO6u5KA/urc+a7tQmUYqIDwEJs4jb7K3Mo2dDIqapZhiusVdz/Vpbw2ohQgeZSKkqGqIYbwqClgUuuCNU8jhcPJF+hZ7NLSn6I0aJxVlxfmQMAjAdNTPjGqTHlyxzOhgK8CCNks+bta4HSmR1GW7Y5QsFGlmtk1MsM1+HbHnc9q4Yi9tCReYt9bUcFNMOlWmWmUXMOFHGaEDFY/O2SXUKc5Ct2JJsmVeFuT5uu8paSnLBWdCU40ntkabNY8+bOmOcay3Z22bFY/0hrbxICiEt2Y5xYGJI3Wi90sN/2ZD9QMBpuf2HWOaAie5Jz4VhSddE2/8fzNwa42DVj2dtM2iOut5ZTgVsD6UXrz91uvIRV6yjojcUuRlWtGiGrBj2UbaxNMkcKXoiy7TsnDg+5EcqUGBoiaB48xFaQIltC1WTGgo3Jda1wxfDQY2sq2vO12XfnDu3N2AyLKFqzdbppbvxyoYdFW0HetfVwXhCBrUeSbHZJ6VzghEt8L34bLqVN4mfpspvqC+aGnz22160zINklnalkeS6JYRYanGrVYI6XcpqgFaAzZrY1JO94MYBv42+SvuYRu4eO+z4RoWgHL99ZzevgOrtgw3V18/CbqdUR1pY+kNSkVJQtjmq4IfQ7TaLJusuMkLw/v0EI9xiN1bG9RUL2UBOPgmWX33JP2tg+59r2ibsQ91AWRwYnixu2WY4XzeVh3k7852LjFJmLqliwSZxiPr7SGkuV7yQFqHOdVB0NgGa36JJ8pe1e3PNd46Fg3i7WOB6jdNH8XrVLWFGUrJEgXLq1nnJ4mATG0VMN16skjRSeSkjBbyYdmqUfiWTzsJ9UlQeXUx2sQ/v9WS19rjmCoh2AmJkLcQUjNo4p3/90OIKXO2Z8UsgbR+Qa3wobEZtTgRtCv0nBx58FbOBb5hnwGN68+Rq5boSs1Stg1qZZ1Wzg+XJUQdO62G6w84uub+dcaKEIMXvMOWIAPTvPEMG8ENVxKTLm4BmbkTvbq2HeuvekkuBCu+4EbqFbcmOStKeGrQzoNguuiRDZz0CcNAqSvplej5xZKJ7E8mxS3e3mbnOxq2+T+VpYK7oGQPOpWJcsXZzuZC3dz0Roxdz0O4zY1hxK/C8qA62GGcCiff3FQglzk2aAJz1ThbB/5LK7ZtU+BF9tjqJpY5HnL5pgX+2nZew5Yf5XEmtdvizJtaMmpaJkiGq4K+CJ+MuuoVDtnKwc6iWV3ruM+SapRABQtdklC2EZB8qzAICfRSZx90DRPK7PhSH22ETlmdg6WRDhUmy0hEz2p8NRXI7ME1yySs62x+FZd96rTfO0nmnWENhMF8l4abaLaDfMU13ibNzzkjUKpElQTM5BImZcFCdZH7t+YXYuXu/j+sdMXE0aBoWl5HYS50np5UtoHjGxLvFORlU/2W8Tw4OXL8F/0nRjlnUERNPImg0GcmMrXTLft6zZAK+AuYkS0pwbH4dXNteNpTN1ywOsFi9P2wZN84lT57ENLNVRDacoGaIabgVWisHIU3dcFmSs7EPjeqsmWkYTzZwfRWvMTNRlDlUp9DBdNo/i0YKJv51umcrV0aCNPbaUWbRVgSJ0bIDron2Ex0y42LVV09aRcrFTGzj+yM7XpqdNWTb3PFduI0mSaQeEZHIQJ6vmSCV3+RK7vi2ipWpnm25OVn30h2bbsductpPvqXl/spyXaBAg0YpuvJN11+nahQ+sphs/fn7JWg6AmdfJWMQBE6S6YMu+xRuKCMvmPbLuAbC0PR9gYqwbqdkEFbgVWCkGI+aO1DNPPLeI0pyxKRdsg53FGz20Z80/uW2D4bNTHUyXbOC3YG6G60aNGXk6nMRkpa+ZDxiBjZvNdRMv38WGETDfehd7sefSsiIxFZnQWjAOFBJbse2bJjsAgkvJv13WD5dWclERrumPkL6RhcaBauIgsckA/syia1PuzMJm5Pp5itAEje6yRkHy/jQieOGp0/huX9/Q9P8nnZgg1RxyzV3/84Q7Lt0y4/GcelOqSakoGaIabgj9a4A/EX/ZmUrppZKkJKV6waYzVTzXRKdxwLzS2TI6dvEO0XqLVgvFoY9XKsaOk4U7ioUwWbDDxpuYCe1FY6J69jhmcuU0UcOaW00Pnmg76fxcilF6zZpe0ik5SJKBZd1v9gm9uq2Dswnae7/zmnPzp9dUEI0l2qqERLOkTbWwzwmC1BoNsmxVdPB2APUl75VrPvHS8vXW0ya/nCNdwzhMg+XZeVk1nKJkiLY6XwPpp2u/1gsOH3JPWGmNHtaK7mn+6kPvAGC0ifTCt4Xf6NjC1aiaNO4P5m1JzmTolpCSZaiKFwJ0x5c2+Q+a5BJ5wwmbU9jwk6WubPA6qjKCxcSlDhjniIQ3pL14ekmptJYS0isI9S9UGZ46vWw9BlkWGFg61+pnUOPdvBZXvBrW2upcTco1MMgE6e9ZCaSyJu4/5uJLNnkdQJKlIeZmada2OpjyXSxPMsJoOkhSjaQBciNZM1yEF0jW2obn2+PIJRmLcPndpFOyrMQ6djp0WSDOu5haRTS9Guygz+s+V8qkE5z5iET4/JSgDTpf//c8bN9WRU1KRcmQXExKInoAwB8BeAOAY8y8JjsxL5NyNQaZSt37jwFIuggDSRMdaYgqLNzgoXbO1oKN2nzCyMSagCS3sD1OTgNKeVA6gbi5x7bQ6ya1aqJV40KSJSJOkcpMMjYxJcNacYkZKJ9vkBbr//yDTMB0AnBe3Y6zYK0mZV4C9waY/Nz/DuBj20Xg0jecbEsv+C5mmwvsptZQSy9qARiT7LJtVCTdntPL8krQuXa2ifnDJjYnyb5emDTsEfOxNL886MvHn3VzskHzte0oGBvFpp7DMfNJACCi1Q5VlG2FOk3WgX6v2z3eA87TJo6UEInTQMo+5O8qgPn3mUXiJSWpe/8xTP7A1MQ4TThVcf0w02UwzsNo+zh2Rj1MnUi0GGAyP/pjZDh8CGFqfP1slAm4Ea0LtgobJnBE9H8B7Buw6w+Z+f9cwXkeBPAgABw8eHCdRqco+bBhAsfM716n8+SytsC1sNLTu9+lnl7S+Adf+RgAuPWo/Sefdcm06XXHJYNiSUs366CR9m7h+97m5nXSO99/8oTr5Zw+bzqGJmO/Ws22Vs21U7UboCblurDWG2gtxz3+9CcADI53pYXhqVRHMZJgtRXg6qM/RNOaqOl4VzpoDSwN2gvXYu7tZEFaK7nE4YjoN4joLIC3A/gmEX07j3EoStbk5aX8GoCv5XHtzcqwWFZ6Tep+LZI+Pm0KyvmGmYcbkTq1nWNt64FmmihKhmjy8jZFNU22bOrAt3LtDDNBd3Kca7OjJqWiZIhquC3KMIeHarfNi2o4RckQFThFyRAVOEXJEBU4RckQFThFyRAVOEXJEBU4RckQFThFyRAVOEXJEBU4RckQFThFyRAVOEXJEBU4RcmQvHqa/DERvUBEf0tEXyOisTzGoShZk5eGewLArcz8JgAvAviDnMahKJmSi8Ax8+PMLM1+nwZwII9xKErWbIY53L8A8Fjeg1CULMi11TkR/SFMW/svDDmPa3UOoENEz633WDeA3QAu5j2INaDjXD9ev5aDcuvaRUT/HMC/AnA3MzfX+J5n1tIZKW90nOvLVhjnWseYS08TIroXwL8F8M61CpuibAfymsP9CYARAE8Q0U+I6L/lNA5FyZS8Wp2/7irf+si6DmTj0HGuL1thnGsa45bqvKwoW53NEBZQlB3DlhO4rZIWRkQPENHzRBQT0abysBHRvUT0MyL6ORF9PO/xDIKIPk9EFzZ7GIiIbiCiJ4nopP1/f3jY8VtO4LB10sKeA/A+AN/LeyBpiMgH8DkA9wH4FQAfIKJfyXdUA/kzAPfmPYg1EAL4KDO/AcAdAP71sO9zywncVkkLY+aTzPyzvMcxgGMAfs7Mv2DmLoC/APDrOY9pGcz8PQCX8h7HajDza8z8Y/v7AoCTAPavdPyWE7g+NC3sytkP4Ezq77MYcoMoa4eIbgLwVgA/XOmYTbmYx3qlhW00axnnJoQGbFNX9TVCRHUAXwXwEWaeX+m4TSlwzPzuYfttWth7YdLCcrtZVhvnJuUsgBtSfx8A8GpOY9kWEFEBRti+wMyPDjt2y5mUqbSwf6ppYVfFjwDcTESHiKgI4P0A/jLnMW1ZiIgA/CmAk8z82dWO33IChy2SFkZEv0FEZwG8HcA3iejbeY8JAKzD6UMAvg0zwf8SMz+f76iWQ0RfBPDXAF5PRGeJ6F/mPaYVuBPABwHcZe/HnxDRr610sGaaKEqGbEUNpyhbFhU4RckQFThFyRAVOEXJEBU4RckQFbhtBBFNplzT54joldTfzdRxtxDRt2y1wEki+hIR7bXvf5KIFonoT/L8LNsVDQtsU4jojwAsMvOn7d+LzFwnojKAZwH8HjN/3e57F4BpAKdhcgFvhanI+FAug9/GbMrULmVD+S0Afy3CBgDM/GRq//eJ6GpbYCiroCblzuNWACfyHsRORQVOUTJEBW7n8TyA2/MexE5FBW7n8b8BvIOI7pcNtsfJbTmOacegArfDYOYWTC3hvyGil4jo7wD8LoALAEBEfw/gswB+12bpb8Z+J1sWDQsoSoaohlOUDFGBU5QMUYFTlAxRgVOUDFGBU5QMUYFTlAxRgVOUDFGBU5QM+f9pfpc6Lh3BygAAAABJRU5ErkJggg==\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAANwAAADUCAYAAAD3CU3sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAobElEQVR4nO2de4wd93Xfv2dm7vvuch9ckhIpSiwtmY4lPyqCli0UjiXLlSK3QYwIsJM6DVpAKFAXNmJHdRrATf1PVcc2WjQGWqExkhauAz9kNH4IlmorMuxGpkXHjqRQlkwzFimJ5HK5y92975k5/eP3O7+ZvXv37pLcndnH+QCLuzszd+Z3786Zc37n9SNmhqIo2eDlPQBF2UmowClKhqjAKUqGqMApSoaowClKhqjAKUqG5C5wROQT0d8Q0TfyHouibDS5CxyADwM4mfcgFCULchU4IjoA4H4A/yPPcShKVuSt4f4zgIcAxDmPQ1EyIcjrwkT0XgAXmPkEEf3qkOMeBPAgANRqtduPHDmSzQAV5Qo4ceLERWaeWu04yiuXkoj+I4APAggBlAGMAniUmf/ZSu85evQoP/PMMxmNUFHWDhGdYOajqx2Xm0nJzH/AzAeY+SYA7wfw3WHCpijbgbzncIqyo8htDpeGmf8KwF/lPAxF2XBUwylKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAKUqGqMApSoaowClKhqjAbQPu8R7APd4DeQ9DWQMqcIqSIbkJHBGVieg4Ef2UiJ4nov+Q11i2Ok/EX8YT8Zfd34O0nWrAzUGePU06AO5i5kUiKgD4PhE9xsxP5zimLY0IVfSu293v3fuPAQCqhw8NfU9aYJWNIzeBY9MQc9H+WbA/uuC4sq3JtWsXEfkATgB4HYDPMfMP8xzPVuQe7wFE77odANB539sAAFGJgA/cseS4zrG9uOvuh83+qg8AeOrrv5/hSBUgZ4Fj5gjAW4hoDMDXiOhWZn4ufUy61fnBgwezH+Qm4h7vAWf63XfzQ2570OguOW7u5hK80BgL8zeROaYNxIUiACAsm23vueOTS84NqGm50WwKLyUzz8H0pbx3wL5HmPkoMx+dmlq1dbuibGryXMxjCkCPmeeIqALg3QD+U17j2WqEp04DMA6S0suXAADnj10HAPAioDVmtFg4ahYmiioECs224oI5x/TtI9h9PMtRK3malNcB+HM7j/MAfImZdRVUZVuTp5fybwG8Na/rbyXS86v+eFrQ6KJzcAIAUH81AgDMHvERlZaeIxqJ0CibGUSnYV7HXmTQsdsAAI8//YmB11PWl02xtoAynEE3vmy76+6HndNk5o1GyoIG0JkwTpO4bExK9gAqmN+DphG4xQOE7sgIgMSB8vjTn1BB20A2hdNEUXYKKnCbmLUkJZdevoTOVAWdqQquf+w1XP/YawgrRqOxByAi80MM9s1P+/oQ7etDxEUgrJkff2YR/sziknCDsv6owClKhugcbhMjcynRcsHhQ26bzLnCgxPojJrnZuvOfQCA6gUGByYE0B0z+7hHbj6H2IYM6jGKc2Z/80gS47zjtz8DAHj6Cx8FYILsj730qQ34hDsPFbgtwCAnBh9/FgBQOnwI/pMmJrdg07nCsgcKzXHBohGu3hjDa5qUrsKCFbgaI6ya4y6+0dwKxQWget54O10mymR9nT/RzkVNSkXJENVwm5R03uQgAltu0zk4Ad9mnYyeagAAGgeqKFot1txjTcrAQ1wwoQI2uxLHCgC2d0L1fITSrFGPYa04cFyAxuiuFhW4TcpqN7SkdvmnTrvgtZiZ1ePA5d95OwAgaJnjJ56Pcfmwka7yjNnWGSdnepKd3tXONtE4YOxMEbyg0XXmpQrataEmpaJkiGq4LcR9Nz/kNJtoNX9mEaHVbEGqqrt+xmSfSO3bwv4AxTmzL7KWYv0su1Kdxn6zrTNVQXOveU9rt3kdewn47nc+7sYAQL2WV4lqOEXJEDKdDrYGR48e5WeeeSbvYWwK0hkoou3EyeE/eSKpAh9fbsT06uY529hHqFw0///R00YjLt5QXKYdqy9MO80qBIcPLdu2k+d3RHSCmY+udpyalJuMtXoBB3Xp8lP7xMkhic3+zCL6Kbx5j2nHgMRkvOO3P+PeE1UrAExQvGiFS8zWx176lHosrwI1KRUlQ1TDbTKuVFuslGwsGi2yWSLNI1PwmyaDpLmvAAConuvBf/IEAODOzqcBACOP/hC+1WK+NVFLL19Cx5qo1Nc/RbkyVMMpSoao02STsd7zItGA0WTdBcbFoQIkczxxuERVH6Xp1pJtpZcvYfbYXgDAyBefdueQXioaIlCnyZZlkDPkSrel08JEGN5zxyeXCfE93gOJ+WizT6LJOjpTxlniBO/UaYzb9zyWusZj6iy5YvJcW+AGInqSiE7atQU+nNdYFCUr8tRwIYCPMvOPiWgEwAkieoKZ/y7HMeVOWlsN62Wy2rb+86WzUFzIYIX1BkIbKpA+RMHhQ05TDmpolA4VKMPJTcMx82vM/GP7+wKAkwD25zUeRcmCTTGHI6KbYFrmLVtbYKe0Or9SZ8kgTSPvTZ9jUGt0CRmkM0XS87+qbQ7bsSvvFK2zBehzuFjN1p9xoqxM7gJHRHUAXwXwEWae79/PzI8AeAQwXsqMh5cZVyNo8nd/K4bVzi+m3113P+w8jQIduy0p83lh2hyfeq/E7dJtF3TtubWTaxzOrgv3VQBfYOZH8xyLomRBbnE4IiIAfw7gEjN/ZC3v2QlxOGG1iu+V3gMMNjPv8R5YZgKuloAsZmh46vQyB8tKZuROzavcCnG4OwF8EMCzRPQTu+3fMfO38hvS5mGlG3fYXG/QtmFexWiyDpwyx0nFwUrXGCTA/ddV03J18lxb4PsAKK/rK0oe5O40UTaWtJZa5lU8NVgr9puPg8zb1TSslu4MRpOXFSVDNHk5R9ZDCwwKC6ymadIr5cj+9LJVg+Zig7JJhoUodhpbwWmyrVmLMF2tFzL93kFB7pUcGvI7pwLZ/a3TB60jrh7J9UNNSkXJENVwG8R6Pv1F06z1nOmeJsPGk9aEzqGCq08zU1ZHNZyiZIhquA1iPd3iMocaNDdb6VrpNbv7WS0PU7TdsMyVYedVVkYFboNYzYRb6Zi1nnNQ3dwggUx3Su7PNFkpRpeuEu8/3yDBU0FbO2pSKkqGqIbbAFaLR62W8zhsW39bvHRS8qDzira67+aHBmrCQVpXruGntg9zjOh6A2tHBe4aWSntaS1B4WFCttJxMp8b9p5BsbRB11vpwbAWwVlpPqkMZ1WT0tas9W/bvTHDUZTtzYqpXUT0LgD/C6aXzN8AeJCZ/97u+zEz/8OsBinkndq11tSl1Y7r1zSrVU8P00j33fwQOgcnACTV2FeicfrNwZXSusK+tQXS7xnETkteXmtq1zAN9ykA/5iZp2BaHDxBRHfI+ddhjIqy4xg2hysy8/MAwMxfIaKTAB4loo8D2DoZz+vIakWhqx230v5osr7MGZLWKnfd/TAAoHT40NBryb7V4mvpRR0fH9L+bhBpTTdMGyuDGSZwPSLax8znAICZnyeiuwF8A8DhTEa3RVirmbnS8Xz8WdiltpdUYz/x0tLj3nPHJ12F9iCB6l/rO32+QctL3eM94ARdjrvv5oeWVQasJLQrOYEGfUbFMMyk/DiAvekNzHwWwDsBPLyRg1KU7cowp8kUgKn+TshE9EYAF5h5+povTvR5AO+157t1tePzdpoMYq31aGmGlb8M0iByjub73obqo0tbd6Z7kYhmSycvD1qIUZh/8x4EHfP/f+rrv++uNajfZLpeLj329HE7ueJ7PZwm/xXA1IDtBwD8l6sdWB9/BuDedTqXomx6hs3hbmPmp/o3MvO3iegz63FxZv6e7bq8ZVhr16z+kpq0619I5yym512A0XTv/Cd/DADwbbfjH3zlY86BIksEp88r10xrqfS10ktXAUDtbHPZZ0i/ryudl7953C1d1X8OAG6uuZMrvtfKMJPyRWa+ZYV9P2Pm16/LAIzAfWMlk7Kv1fntv/zlL9fjsuvKMCEclDw8zKRMm4jzh2sAzEqlgFm5dPz4eQBw67UBcNtm7twHAKif6bqOyiIYfPzZZeO78zc/7UxK6bKcfk/aHO1PaH786U844U/H//o/U9rbup1bMayHSfkSEf3agBPfB+AX1zK4K4GZH2Hmo8x8dGpqkIWrKFuHYRruFpgQwP8DcMJuPgrg7QDey8wvrssAVtFwaTa70yS9TUibaOn4V5rOVAXFb5oVNGSxjPSqo4M03PQ7jTYbPd1FVDVpxrKAIpCsXtoZN7OG0Z9ecBkpQlT13XXFfKy+ML3EXASMpmsemXL7+0lnulxpQ6PtwjVrOCtQtwF4CsBN9ucpAG9aL2FTlJ1Grm3yiOiLAH4VwG4A5wH8e2b+05WO34waDhiehT9sziOktWBaq4nm8psRAKPh4sLSrLr6ma5zoNz5m59228Uhkp6H9edczjz4DhQXYnc9AFi8oeiOH3vBvLczVXGLNMp5w1px2Txx0JxvJbabtlurhhtmUjYARIN2AWBmHr22IV45m1XgRKgGxcFk28IH7nA3tSA3bPPIlBOq2SNm3dHq+QidMWOAdEfMzU4RENvaDbbFan4bqL+69N80f6OP0mXzfy3NGYHq1T14PbttPnbHdkblGuZ1908WcOlW68U8b87bmvRRP9MFAAQN89o4UEVkhTD9uWS/MKzVw3ZiPfpSvsjMb13HMSnKjmeYwG37BOWrmdj3Z46kFzBMm5GdqQoAQAy00ny8RDsYjInXGfUAq2l866bvjHlOs0V2sW2vB5BVTkHbvHZ2AZf/gVF3sb1YYR7wbHJma7fZF1aA6gVzbjEPoxI5bSem6vljI0571ox/BnGQOF/EzG3u9TH6srmI7OvVPdTP2PHZz3rX3Q+739NZKoOyU3YCwwRuDxH93ko7mfmzGzCeTFlrXCjdfKc/hpY25sSM7ExVnIkoXsewRKjaOU5k42sX32QkqbjAzvRr7DOCF5WA3og5b/GyeWUfKDTM7609dpuXCFf9rBWoMmHxgJh7ZtvImdiZqCJks7f4qF4wx7UnErNVhPq1t5nbozgPlObMNvlcfsdHazLdhAHweozmvoK9rtkWNLrugbQkaH5q6XeWjhNut/ldmmEC5wOoQ2vfFGXdGOY0yaWqexjr7TRZbU3sQZkh/ZXP4anTLoYlcTB/ZhHn3nOd2TaXOCjCSmLKAUBkTcCoZMxFAAiN8kNvhFG8bI4LrQXKBMQl8//y23IuRmHB/C6mYNBIPoecr3rOaC8A8EJzju4IOXOVbYCIg8QxI8dPPhdh9hZz8rJdErw7mpiokz8w6mz+zXtQmpVCowQxKUXTRe+63XlW16rNNrvWW49ME9VsirLODDMp785sFDmx0hJOgsy/0k9XOU7mbt0jx9y8Rp7gZz72DhTnzX5xgohjAwCisnklqwziItCzyR2ipaIyoxtbzVUw54hGIiC02zzzrIwL7OZcohG7u4DCgr2GHejCjcnnrp21mrMChJJUYs/hhcl7hPkbk7E39pvX0iVg7nXmPL3adW6/5GYu7De3VuVi5BwtSGWzrIX+xrXpbZtV063GigLHzJeyHEgeDOtUDADoS4UClppFgDEjJY1q4QOm5cvkyRAX32i+2qCdGBEiVGLG+R3zSjEQ1syNGlfsnc9AOGbv/JJ97XmAFUKxP7yQnJnZ3G/eWz7vO49lb8Ts4wLD65g3Ld5kL+ED7Jv9FKWcJj3zu5i5rb1IPUCS76Jg49zjL5iNnfEAl26xgjaTTFVEwMQc77zr9qH9NYX0/0S+b+Guux9eUjEBDF+7brOgnZcVJUN0BdQ+0k/edGkLsNS0kTSp5r6Cc7PPHzRPd7/Dzs0urn32gdhaVnHZfOdBwx4zGoNHjX3pF602o9T/ha2DpO27RyT55po8VwSNWafEovF2eE0PntVSvUmjpqjtg8vm3NS1J4kJ8Ox1Avsapqbu9le/6TknTVgxxxUWCL6NBdZfTcIRQdvuXzTja+32XfqYfE/VF6aXOZ/SCdNiMQjf/c7HXVpc2gHT/15/ZjG37s/r4TRRFGWd0VbnGBz4Dk+dBk2aMppBVduuGrvDLnNDnA1xQE6zies/aAG9SauVitYtb+do5MfwC2YfWc3m+exyfYLAHBcGPnzRbFbr9XxGdNlexGqreLILzJptFNpnajEGieu/aM7hlSLEXd8eZ87nj3cR9eyBi+b2iIuMyM4TxUHDXqIJezXze6HBLl9TnCdAksWSLh+SEqWO1WZBo+uSn8Pa0rnzXXc/7BwvF99kvtiRA29z+5t77WeIRpYtRLnZcjnVpLT0T7zFhAGS7Pp0HZukZ9XONp0JNPNG4w0Jq8mN2bZN4eNyjFjMtrrN1igZMzIoJG7BWtmYTIutEibqJjO/FxsBKHgxFtrmGgUrhK1OAWHXpl2JoDDg7EExTdu+c76QFUy/GIHtuf3UGDpz1qsjpyCAGnJTW4dKmxBY+Qlsp4biPJz5KHV7nYMTrgIhXX3Q76mMJusuHU68vvOHzPtq5yN3vFS1Lx4gF28Mzdvgd4CJF8132rBCOPHcYiZCpyalomxC1KS09Dc99ZGYkv0T9jRhreiSdyVDo3KR0dpt42Wi1WKCP2m8DLHVRCN18/fe+gJiayKGVuNMVpN0EdlX9COMlMx7Xr28CwBQr3TQ8oxWiYrmvQSg2zFjKlgtGpU8+IE1Ja3Wi5kQ+GZ/q2m1Sa2DXsneFlbDxW0fXpIw4/ZJBoxoOPYT81GcSp3xwDlQxGSsziw6TSVaz59ZRCn1nQJGOwHA3JE6fHu+9niSpRPaeGZvzJw/WPBwbpdNpB61pm2rNrB8Ki/UpFyBe7wHlnUypmO3DWxdIO0O4sDebGNA6zob/7JxLh4NUa6beNWeXSYq7dlJ2kix44SqHBgBqPpdtCMjwUUrFM2w6I673DF3WydKnpmd0JqWMYGssLQ75hzyPgAI/MR8jKyAF62J2ukGbn4YdoxZxqEHb86ep5DcL0FDKhzM37tOMeK+R7jU1KVJxy7leyzNhi7xWbyZkhwdlgldW30pQl5YAJr7ecmY4moEsnHKYNYmXs8l6Wgjr5jvsfjN4+sudGpSKsomRE3KFUgnL7v428wi/Bnz6+wRk84UHduLXnVpknFnYqlmA4wnslYxqqBrNZGYh0U/wmTRmJAT9jVkf4n2AoCK33PbxCxsp47pRrYujgk9G/TzxdyMPRStZpP3zqccM4LnxQjlPJLVEjCiqtVmC/a8LULZfhfitABMUyMgcXjMHQpcRXo6sVm8jqM/vWC+s4MTrnLcpYJZ2E8yXYTWvkTTitlOPgMdM85w1FwzLnhga3mMvGI/Tl9GS5bkquGI6F4i+hkR/dyuyqMo25rcNBwR+QA+B+AeAGcB/IiI/rJ/LYM86S+IpGO3Ye6IyWqYOmHmYZduraPQNE/YXt08SeMCOw0nmSPkxagUzBO8EpjXfRVzjpuqM+6aE9bXHYEw2zO1NQXPFrOyh3nrKaj45hwLYQnN0GiTsp3rzXdLTkvJnLDgRWiHZo4ksb5iIXJasdk25/A8TkIONogYdwJ4zaVah9gkSQNIMk7OdF1RrYQMxk6HLv4m353XKzhtJvPfoMVum4QFYuva79WS8qGiTcpevClG4bJUulsHFQHB7rb7HADQXSihZy2K6dvM640v5JdrmadJeQzAz5n5FwBARH8B4NcBbBqBG4Q005FJf1QitCfNPqkCgAd4u8xxhaK54Ucq7WWCdl3ZlHL7FGN3YLYVbPR8zG9iwm+4/QDQiEvYFZjrXrSRdY9iBHZ/aO/KoBSjaq/lzMduyZmwIpiXqIq2vRmrNv7XbBedKSlBcY7IOSYqlxKjqGKsQRd7m35LycUfKxeXO0vSTYykgVJp3JSul2ZDZ4aK80lgDxj7uTnfa//ImrYtD5FNkeMRG8+s9JwnVuKZc10fPRtHjwPzWWfu3Icxmw6WteDlaVLuB3Am9fdZu20JRPQgET1DRM9MT1/zgj2Kkit5arhBBa7LYhTM/AjMksc4evRoZjGMQd2TH3v6Ey6lq5Nq+uMquWu2Ic9oiFrNPGGjyBw3VmlhT8XElcQcrFt/+t7CZYxYu2zMarUCIkwFxlPQiG3rPK+D2Gox335VVa+L+cB4LWa6xmvTBlC3ZmjX+umLfuRCA13rUCl4EcplM5bZlnkvESO0Y2abVVK4FLj/jGvTl/rvtSZtJkwjcaCkHSQX32K08dhL5vMGja6riJfSnqjqu+p4qS6XXi0UAxffJJkuNgRQZHcHFe137fuxC2kIlVoXDft54qo5/9wtPjxrro8cR6bkqeHOArgh9fcBAK/mNBZFyYQ8NdyPANxMRIcAvALg/QB+K8fxLCEdFkhnocjSUZIV0R0JXHW1q572E8fDeM243SdKLYzY3nb7SkZz7bZv3Fe4jDHPHDfimWP8lLIf84zTocc+LkRGW+zyzfEexVi0Fa1TRZuZ0auiZzXhSGA0SNELnXNFNF0LBcw1zZNe5m3dboDQZqlIGU84EsGzv0tVOZBos+KCZHUwRl6xYQjr2m9NJuU5SVfnpPxm+i1m7BMvhq6rmKgBCQW09iSFtHE5To6p2iwaq8Eq5S7GK0mCNADU6l0s1M01Fjvm9VI4jtnX24vYouFBy4ltBLkJHDOHRPQhAN+GyaT6PDM/n9d4BjFoIi2JzIu/83YAQHsC6I7bybu9GchjF/+qFYy5s6vQQj0wv+8pmDtpzArNlD+PEc/sG7FOkxHPR49t12QrfD3uoWz7MpTJmILnwl1oFswN3LPmYysqIrD2ViM0N5mYogCw2LMOn1Rsrm3f6/nskpuddRYRAtuoSCrJS6l+AJJNMncocEIjPSu9Hrt0L6kqqJ6PUD9j9kvCd2vSd3047Vfhsku8bmJewtb5xWM9FKvmOxipJvHMXdYxtLdsHmaBFzmzWmKYvyx18MvTxlnTuGxrGN+8B1mQa+Cbmb8F4Ft5jkFRskQzTYbQ7zJ+Iv6yc5qE5ZRpNWI1mzVx9u+ddSGAsaJ54taCDm4omXibaLYbAvN3jXoYs06OKtn4Gfko299F07URwVqv6MHsK1KEqlUJi1aLTRUXMGvTXmLfag2KXVaKhAoKXoSejcN1ezbptx2AJcPEVoijEbjmRrFUAPmAbzXR3CHz3vqrkauDkzzIsRcWXfxt3+OvATDt9CRvUuiMkjNRvXDJLrBnMluARMOCgMjW+klO6u5KA/urc+a7tQmUYqIDwEJs4jb7K3Mo2dDIqapZhiusVdz/Vpbw2ohQgeZSKkqGqIYbwqClgUuuCNU8jhcPJF+hZ7NLSn6I0aJxVlxfmQMAjAdNTPjGqTHlyxzOhgK8CCNks+bta4HSmR1GW7Y5QsFGlmtk1MsM1+HbHnc9q4Yi9tCReYt9bUcFNMOlWmWmUXMOFHGaEDFY/O2SXUKc5Ct2JJsmVeFuT5uu8paSnLBWdCU40ntkabNY8+bOmOcay3Z22bFY/0hrbxICiEt2Y5xYGJI3Wi90sN/2ZD9QMBpuf2HWOaAie5Jz4VhSddE2/8fzNwa42DVj2dtM2iOut5ZTgVsD6UXrz91uvIRV6yjojcUuRlWtGiGrBj2UbaxNMkcKXoiy7TsnDg+5EcqUGBoiaB48xFaQIltC1WTGgo3Jda1wxfDQY2sq2vO12XfnDu3N2AyLKFqzdbppbvxyoYdFW0HetfVwXhCBrUeSbHZJ6VzghEt8L34bLqVN4mfpspvqC+aGnz22160zINklnalkeS6JYRYanGrVYI6XcpqgFaAzZrY1JO94MYBv42+SvuYRu4eO+z4RoWgHL99ZzevgOrtgw3V18/CbqdUR1pY+kNSkVJQtjmq4IfQ7TaLJusuMkLw/v0EI9xiN1bG9RUL2UBOPgmWX33JP2tg+59r2ibsQ91AWRwYnixu2WY4XzeVh3k7852LjFJmLqliwSZxiPr7SGkuV7yQFqHOdVB0NgGa36JJ8pe1e3PNd46Fg3i7WOB6jdNH8XrVLWFGUrJEgXLq1nnJ4mATG0VMN16skjRSeSkjBbyYdmqUfiWTzsJ9UlQeXUx2sQ/v9WS19rjmCoh2AmJkLcQUjNo4p3/90OIKXO2Z8UsgbR+Qa3wobEZtTgRtCv0nBx58FbOBb5hnwGN68+Rq5boSs1Stg1qZZ1Wzg+XJUQdO62G6w84uub+dcaKEIMXvMOWIAPTvPEMG8ENVxKTLm4BmbkTvbq2HeuvekkuBCu+4EbqFbcmOStKeGrQzoNguuiRDZz0CcNAqSvplej5xZKJ7E8mxS3e3mbnOxq2+T+VpYK7oGQPOpWJcsXZzuZC3dz0Roxdz0O4zY1hxK/C8qA62GGcCiff3FQglzk2aAJz1ThbB/5LK7ZtU+BF9tjqJpY5HnL5pgX+2nZew5Yf5XEmtdvizJtaMmpaJkiGq4K+CJ+MuuoVDtnKwc6iWV3ruM+SapRABQtdklC2EZB8qzAICfRSZx90DRPK7PhSH22ETlmdg6WRDhUmy0hEz2p8NRXI7ME1yySs62x+FZd96rTfO0nmnWENhMF8l4abaLaDfMU13ibNzzkjUKpElQTM5BImZcFCdZH7t+YXYuXu/j+sdMXE0aBoWl5HYS50np5UtoHjGxLvFORlU/2W8Tw4OXL8F/0nRjlnUERNPImg0GcmMrXTLft6zZAK+AuYkS0pwbH4dXNteNpTN1ywOsFi9P2wZN84lT57ENLNVRDacoGaIabgVWisHIU3dcFmSs7EPjeqsmWkYTzZwfRWvMTNRlDlUp9DBdNo/i0YKJv51umcrV0aCNPbaUWbRVgSJ0bIDron2Ex0y42LVV09aRcrFTGzj+yM7XpqdNWTb3PFduI0mSaQeEZHIQJ6vmSCV3+RK7vi2ipWpnm25OVn30h2bbsductpPvqXl/spyXaBAg0YpuvJN11+nahQ+sphs/fn7JWg6AmdfJWMQBE6S6YMu+xRuKCMvmPbLuAbC0PR9gYqwbqdkEFbgVWCkGI+aO1DNPPLeI0pyxKRdsg53FGz20Z80/uW2D4bNTHUyXbOC3YG6G60aNGXk6nMRkpa+ZDxiBjZvNdRMv38WGETDfehd7sefSsiIxFZnQWjAOFBJbse2bJjsAgkvJv13WD5dWclERrumPkL6RhcaBauIgsckA/syia1PuzMJm5Pp5itAEje6yRkHy/jQieOGp0/huX9/Q9P8nnZgg1RxyzV3/84Q7Lt0y4/GcelOqSakoGaIabgj9a4A/EX/ZmUrppZKkJKV6waYzVTzXRKdxwLzS2TI6dvEO0XqLVgvFoY9XKsaOk4U7ioUwWbDDxpuYCe1FY6J69jhmcuU0UcOaW00Pnmg76fxcilF6zZpe0ik5SJKBZd1v9gm9uq2Dswnae7/zmnPzp9dUEI0l2qqERLOkTbWwzwmC1BoNsmxVdPB2APUl75VrPvHS8vXW0ya/nCNdwzhMg+XZeVk1nKJkiLY6XwPpp2u/1gsOH3JPWGmNHtaK7mn+6kPvAGC0ifTCt4Xf6NjC1aiaNO4P5m1JzmTolpCSZaiKFwJ0x5c2+Q+a5BJ5wwmbU9jwk6WubPA6qjKCxcSlDhjniIQ3pL14ekmptJYS0isI9S9UGZ46vWw9BlkWGFg61+pnUOPdvBZXvBrW2upcTco1MMgE6e9ZCaSyJu4/5uJLNnkdQJKlIeZmada2OpjyXSxPMsJoOkhSjaQBciNZM1yEF0jW2obn2+PIJRmLcPndpFOyrMQ6djp0WSDOu5haRTS9Guygz+s+V8qkE5z5iET4/JSgDTpf//c8bN9WRU1KRcmQXExKInoAwB8BeAOAY8y8JjsxL5NyNQaZSt37jwFIuggDSRMdaYgqLNzgoXbO1oKN2nzCyMSagCS3sD1OTgNKeVA6gbi5x7bQ6ya1aqJV40KSJSJOkcpMMjYxJcNacYkZKJ9vkBbr//yDTMB0AnBe3Y6zYK0mZV4C9waY/Nz/DuBj20Xg0jecbEsv+C5mmwvsptZQSy9qARiT7LJtVCTdntPL8krQuXa2ifnDJjYnyb5emDTsEfOxNL886MvHn3VzskHzte0oGBvFpp7DMfNJACCi1Q5VlG2FOk3WgX6v2z3eA87TJo6UEInTQMo+5O8qgPn3mUXiJSWpe/8xTP7A1MQ4TThVcf0w02UwzsNo+zh2Rj1MnUi0GGAyP/pjZDh8CGFqfP1slAm4Ea0LtgobJnBE9H8B7Buw6w+Z+f9cwXkeBPAgABw8eHCdRqco+bBhAsfM716n8+SytsC1sNLTu9+lnl7S+Adf+RgAuPWo/Sefdcm06XXHJYNiSUs366CR9m7h+97m5nXSO99/8oTr5Zw+bzqGJmO/Ws22Vs21U7UboCblurDWG2gtxz3+9CcADI53pYXhqVRHMZJgtRXg6qM/RNOaqOl4VzpoDSwN2gvXYu7tZEFaK7nE4YjoN4joLIC3A/gmEX07j3EoStbk5aX8GoCv5XHtzcqwWFZ6Tep+LZI+Pm0KyvmGmYcbkTq1nWNt64FmmihKhmjy8jZFNU22bOrAt3LtDDNBd3Kca7OjJqWiZIhquC3KMIeHarfNi2o4RckQFThFyRAVOEXJEBU4RckQFThFyRAVOEXJEBU4RckQFThFyRAVOEXJEBU4RckQFThFyRAVOEXJEBU4RcmQvHqa/DERvUBEf0tEXyOisTzGoShZk5eGewLArcz8JgAvAviDnMahKJmSi8Ax8+PMLM1+nwZwII9xKErWbIY53L8A8Fjeg1CULMi11TkR/SFMW/svDDmPa3UOoENEz633WDeA3QAu5j2INaDjXD9ev5aDcuvaRUT/HMC/AnA3MzfX+J5n1tIZKW90nOvLVhjnWseYS08TIroXwL8F8M61CpuibAfymsP9CYARAE8Q0U+I6L/lNA5FyZS8Wp2/7irf+si6DmTj0HGuL1thnGsa45bqvKwoW53NEBZQlB3DlhO4rZIWRkQPENHzRBQT0abysBHRvUT0MyL6ORF9PO/xDIKIPk9EFzZ7GIiIbiCiJ4nopP1/f3jY8VtO4LB10sKeA/A+AN/LeyBpiMgH8DkA9wH4FQAfIKJfyXdUA/kzAPfmPYg1EAL4KDO/AcAdAP71sO9zywncVkkLY+aTzPyzvMcxgGMAfs7Mv2DmLoC/APDrOY9pGcz8PQCX8h7HajDza8z8Y/v7AoCTAPavdPyWE7g+NC3sytkP4Ezq77MYcoMoa4eIbgLwVgA/XOmYTbmYx3qlhW00axnnJoQGbFNX9TVCRHUAXwXwEWaeX+m4TSlwzPzuYfttWth7YdLCcrtZVhvnJuUsgBtSfx8A8GpOY9kWEFEBRti+wMyPDjt2y5mUqbSwf6ppYVfFjwDcTESHiKgI4P0A/jLnMW1ZiIgA/CmAk8z82dWO33IChy2SFkZEv0FEZwG8HcA3iejbeY8JAKzD6UMAvg0zwf8SMz+f76iWQ0RfBPDXAF5PRGeJ6F/mPaYVuBPABwHcZe/HnxDRr610sGaaKEqGbEUNpyhbFhU4RckQFThFyRAVOEXJEBU4RckQFbhtBBFNplzT54joldTfzdRxtxDRt2y1wEki+hIR7bXvf5KIFonoT/L8LNsVDQtsU4jojwAsMvOn7d+LzFwnojKAZwH8HjN/3e57F4BpAKdhcgFvhanI+FAug9/GbMrULmVD+S0Afy3CBgDM/GRq//eJ6GpbYCiroCblzuNWACfyHsRORQVOUTJEBW7n8TyA2/MexE5FBW7n8b8BvIOI7pcNtsfJbTmOacegArfDYOYWTC3hvyGil4jo7wD8LoALAEBEfw/gswB+12bpb8Z+J1sWDQsoSoaohlOUDFGBU5QMUYFTlAxRgVOUDFGBU5QMUYFTlAxRgVOUDFGBU5QM+f9pfpc6Lh3BygAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] @@ -322,7 +322,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/srv/public/mameyer/bgflow/bgflow/factory/generator_builder.py:219: UserWarning: No target energy for TensorInfo(name='CG_TORSIONS', is_circular=True, is_cartesian=False).\n", + "/home/marcel/BG/bgflow/bgflow/factory/generator_builder.py:220: UserWarning: No target energy for TensorInfo(name='CG_TORSIONS', is_circular=True, is_cartesian=False).\n", " warnings.warn(f\"No target energy for {field}.\", UserWarning)\n" ] } @@ -367,9 +367,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/srv/public/mameyer/bgflow/bgflow/factory/generator_builder.py:219: UserWarning: No target energy for TensorInfo(name='CG_BONDS', is_circular=False, is_cartesian=False).\n", + "/home/marcel/BG/bgflow/bgflow/factory/generator_builder.py:220: UserWarning: No target energy for TensorInfo(name='CG_BONDS', is_circular=False, is_cartesian=False).\n", " warnings.warn(f\"No target energy for {field}.\", UserWarning)\n", - "/srv/public/mameyer/bgflow/bgflow/factory/generator_builder.py:219: UserWarning: No target energy for TensorInfo(name='CG_ANGLES', is_circular=False, is_cartesian=False).\n", + "/home/marcel/BG/bgflow/bgflow/factory/generator_builder.py:220: UserWarning: No target energy for TensorInfo(name='CG_ANGLES', is_circular=False, is_cartesian=False).\n", " warnings.warn(f\"No target energy for {field}.\", UserWarning)\n" ] } @@ -414,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 58, "id": "482ca2e1-2bee-46e2-8780-a02ba935352c", "metadata": { "scrolled": true @@ -451,17 +451,38 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 91, + "id": "7611f5f0-36f4-4b55-b904-f795d951a3a2", + "metadata": {}, + "outputs": [], + "source": [ + "conditioner_types = [\"allegro\", \"nequip\", \"schnet\", \"wrapdistances\", \"ICs\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 98, "id": "bd791bdf-fb5c-43a7-aab4-ce98f42786a1", "metadata": {}, "outputs": [], "source": [ - "conditioner = \"allegro\" # \"allegro\", \"nequip\", \"schnet\", (\"wrapdistances\"), (\"ICs\")" + "conditioner = \"schnet\"" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "e84fead8-74ab-4adf-a87f-50c3df1d4d76", + "metadata": {}, + "outputs": [], + "source": [ + "if conditioner not in conditioner_types:\n", + " raise TypeError(\"not allowed\")" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 60, "id": "005f7963-4bbe-4601-b8c7-b50e6c15c67c", "metadata": {}, "outputs": [], @@ -476,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 61, "id": "f8aab29f-5eaf-422e-9a3b-82f40c915f02", "metadata": {}, "outputs": [], @@ -486,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 62, "id": "9bf67181-a417-47c8-bce3-f6c7bf6e5f9c", "metadata": { "scrolled": true @@ -526,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 63, "id": "0cde88d1-07be-4661-87ca-f6b970391e43", "metadata": {}, "outputs": [], @@ -585,7 +606,41 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 64, + "id": "50d0336f-ba41-4b88-bae7-6c50ec11c985", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "ShapeDictionary([(TensorInfo(name='BONDS', is_circular=False, is_cartesian=False),\n", + " (87,)),\n", + " (TensorInfo(name='ANGLES', is_circular=False, is_cartesian=False),\n", + " (165,)),\n", + " (TensorInfo(name='TORSIONS', is_circular=True, is_cartesian=False),\n", + " (165,)),\n", + " (TensorInfo(name='CG_BONDS', is_circular=False, is_cartesian=False),\n", + " (9,)),\n", + " (TensorInfo(name='CG_ANGLES', is_circular=False, is_cartesian=False),\n", + " (8,)),\n", + " (TensorInfo(name='CG_TORSIONS', is_circular=True, is_cartesian=False),\n", + " (7,))])" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "builder.current_dims" + ] + }, + { + "cell_type": "code", + "execution_count": 65, "id": "8eae4d48-1f18-49dd-baf3-035e80dfe6b5", "metadata": { "scrolled": true, @@ -597,21 +652,24 @@ " builder.default_conditioner_type = \"GNN\"\n", " builder.default_conditioner_kwargs = {\n", " \"r_max\": hparams[\"r_max\"],\n", - " \"GNN\": GNN_feature_extractor, #1\n", - " #\"GNN\": nequip_wrapper, #2\n", - " #\"GNN_kwargs\": {\"shared_params\": None,\n", - " # \"layers\": layers\n", - " # },\n", + " #\"GNN\": GNN_feature_extractor, #1\n", + " \"GNN\": NequipWrapper, #2\n", + " \"GNN_kwargs\": {\"nequip_GNN\": SequentialGraphNetwork,\n", + " \"shared_params\": None,\n", + " \"layers\": layers,\n", + " \"cutoff\": hparams[\"r_max\"]\n", + " \n", + " },\n", " \"use_checkpointing\": False,\n", " \"GNN_output_dim\": hparams[\"GNN_feature_dim\"]*len(c_alpha),\n", " \"attention_units\": len(c_alpha), #\"atomwise\",\n", - " \"attention_level\": \"Transformer\"\n", + " \"attention_level\": None\n", " }" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 66, "id": "5fc68532-d417-4c0d-8406-d9398cbc95c9", "metadata": { "tags": [] @@ -629,7 +687,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 67, "id": "f2fb226b-9e45-4625-af1e-867e995b9a67", "metadata": {}, "outputs": [], @@ -648,13 +706,13 @@ " \"use_checkpointing\": True, # False,\n", " \"GNN_output_dim\": (len(c_alpha)**2-len(c_alpha))//2*num_basis, \n", " \"attention_units\": (len(c_alpha)**2-len(c_alpha))//2,\n", - " \"attention_level\": \"Transformer\"\n", + " \"attention_level\": None\n", " }" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 68, "id": "528f9456-bd93-41f2-aa85-f3e446bf8b60", "metadata": {}, "outputs": [], @@ -670,7 +728,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 69, "id": "54a878c2-ef07-442d-9362-7cde96f89f8b", "metadata": {}, "outputs": [], @@ -680,7 +738,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 70, "id": "63f7f925-40b7-4c4d-9964-bcf7b1a68153", "metadata": {}, "outputs": [], @@ -690,7 +748,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 71, "id": "da9a6c9b-3ea7-4872-ac31-0c68b76e0540", "metadata": {}, "outputs": [], @@ -727,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 72, "id": "908acbfa-55cf-4551-933b-125412bcee39", "metadata": { "scrolled": true @@ -832,17 +890,17 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 73, "id": "50fdf077", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "22484281" + "25897369" ] }, - "execution_count": 34, + "execution_count": 73, "metadata": {}, "output_type": "execute_result" } @@ -863,7 +921,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 74, "id": "208cde95-5290-4f2a-85e5-87adf7ae3aae", "metadata": { "tags": [] @@ -872,21 +930,21 @@ { "data": { "text/plain": [ - "torch.Size([500, 525])" + "torch.Size([2, 525])" ] }, - "execution_count": 35, + "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "generator.sample(500).shape" + "generator.sample(2).shape" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 75, "id": "4dc5aede-08a0-4800-86bb-b76398d87db9", "metadata": {}, "outputs": [], @@ -896,7 +954,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 76, "id": "820b7a2d-a3b2-493e-8b70-77bb559e0985", "metadata": { "tags": [] @@ -905,19 +963,19 @@ { "data": { "text/plain": [ - "tensor([[-981.0965],\n", - " [-981.7209],\n", - " [-984.6757],\n", - " [-980.6040],\n", - " [-991.8463],\n", - " [-983.9974],\n", - " [-983.5727],\n", - " [-964.4722],\n", - " [-980.1158],\n", - " [-978.4524]], device='cuda:0', grad_fn=)" + "tensor([[ -996.0197],\n", + " [ -990.0566],\n", + " [ -951.0018],\n", + " [ -985.1042],\n", + " [ -967.7171],\n", + " [ -972.6763],\n", + " [ -979.6975],\n", + " [ -982.6862],\n", + " [-1025.4706],\n", + " [ -980.6635]], device='cuda:0', grad_fn=)" ] }, - "execution_count": 37, + "execution_count": 76, "metadata": {}, "output_type": "execute_result" } @@ -936,20 +994,20 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 77, "id": "9ae0e3ba-09f2-4c8c-831e-d7655ef23e73", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(tensor(5.4389e-07, device='cuda:0', grad_fn=),\n", - " tensor(1.6689e-06, device='cuda:0', grad_fn=),\n", + "(tensor(5.3644e-07, device='cuda:0', grad_fn=),\n", + " tensor(1.7881e-07, device='cuda:0', grad_fn=),\n", " tensor(6.5565e-07, device='cuda:0', grad_fn=),\n", - " tensor(0.0001, device='cuda:0', grad_fn=))" + " tensor(0.0002, device='cuda:0', grad_fn=))" ] }, - "execution_count": 38, + "execution_count": 77, "metadata": {}, "output_type": "execute_result" } @@ -971,12 +1029,12 @@ "id": "642e031c-10be-464f-a3aa-0ac5af229afe", "metadata": {}, "source": [ - "### check equivariance of GNN" + "### check invariance of GNN" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 78, "id": "7cb6e1ee-41e8-468b-a5d3-d96d55774642", "metadata": {}, "outputs": [], @@ -997,7 +1055,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 79, "id": "681a3134-8be7-4d9e-83b9-b962ba904c56", "metadata": {}, "outputs": [], @@ -1009,7 +1067,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 80, "id": "ae90b81f-0548-4d16-8c7a-b3ec8634fa96", "metadata": {}, "outputs": [], @@ -1019,7 +1077,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 81, "id": "f6c3210c-5b07-4aa7-bdf0-1cdff1105e5e", "metadata": {}, "outputs": [], @@ -1029,7 +1087,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 82, "id": "774fecac-99bd-4f55-84dc-33b9c85293a4", "metadata": {}, "outputs": [], @@ -1039,17 +1097,17 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 83, "id": "b471ebde-297c-4eba-9188-a086e154f0be", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.3901)" + "tensor(0.3857)" ] }, - "execution_count": 44, + "execution_count": 83, "metadata": {}, "output_type": "execute_result" } @@ -1060,17 +1118,17 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 84, "id": "aae44141-3a2f-4928-af08-a2aeecdf31f4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.3901)" + "tensor(0.3857)" ] }, - "execution_count": 45, + "execution_count": 84, "metadata": {}, "output_type": "execute_result" } @@ -1081,7 +1139,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 85, "id": "674585cd-ef54-4926-9af0-86a5acbefdd8", "metadata": {}, "outputs": [], @@ -1093,7 +1151,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 86, "id": "9e828f11-066e-4c9d-803a-42a250ed1857", "metadata": {}, "outputs": [], @@ -1104,7 +1162,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 87, "id": "fc659d28-e3b0-46f5-b578-9b6309a4c4e0", "metadata": { "tags": [] @@ -1116,7 +1174,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 88, "id": "0987dc52-42be-48f5-bf80-5f106f62bbdc", "metadata": { "tags": [] @@ -1128,7 +1186,7 @@ "tensor(True, device='cuda:0')" ] }, - "execution_count": 49, + "execution_count": 88, "metadata": {}, "output_type": "execute_result" } @@ -1139,7 +1197,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 52, "id": "8db52082-c520-4eb3-8662-7e7a6c6bd31f", "metadata": {}, "outputs": [], @@ -23638,7 +23696,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.10.5" } }, "nbformat": 4,