Add in-place normalization helpers to Database and make errors descriptive.

* Database: keep the fitted scalers in `scaler_parameters` and
  `scaler_snapshots`; add `normalize_parameters()` / `normalize_snapshots()`.
* Database.add(): raise TypeError with a message on a wrong argument type.
* Database.split(): actually raise when the chunk type is invalid.
* DatabaseScaler: docstring fixes; ValueError for target/mode has a message.

Review notes:
* add() used to raise ValueError; callers or tests catching ValueError
  have to be updated together with this patch.
* ezyrb/plugin/scaler.py must already define a module-level `logger`;
  this patch does not add one (not visible in the diff, please confirm).

diff --git a/ezyrb/database.py b/ezyrb/database.py
--- a/ezyrb/database.py
+++ b/ezyrb/database.py
@@ -15,10 +15,6 @@ class Database:

     :param array_like parameters: the input parameters
     :param array_like snapshots: the input snapshots
-    :param Scale scaler_parameters: the scaler for the parameters. Default
-        is None meaning no scaling.
-    :param Scale scaler_snapshots: the scaler for the snapshots. Default is
-        None meaning no scaling.
     :param array_like space: the input spatial data

     :Example:
@@ -49,3 +45,6 @@ def __init__(self, parameters=None, snapshots=None, space=None):
+        self.scaler_parameters = None
+        self.scaler_snapshots = None
+
         if parameters is None and snapshots is None:
             logger.debug("Empty database created")
             return
@@ -149,11 +148,15 @@ def add(self, parameter, snapshot):
         """
         if not isinstance(parameter, Parameter):
             logger.error("Invalid parameter type: %s", type(parameter))
-            raise ValueError
+            raise TypeError(
+                f"Expected a Parameter object, got {type(parameter)}"
+            )

         if not isinstance(snapshot, Snapshot):
             logger.error("Invalid snapshot type: %s", type(snapshot))
-            raise ValueError
+            raise TypeError(
+                f"Expected a Snapshot object, got {type(snapshot)}"
+            )

         self._pairs.append((parameter, snapshot))
         logger.debug(
@@ -161,6 +164,73 @@ def add(self, parameter, snapshot):
         )
         return self

+    def normalize_parameters(self, scaler=None):
+        """
+        Normalize the parameters in the database.
+
+        The scaler is fitted on the parameters matrix and the stored
+        parameter values are overwritten with the scaled ones. The fitted
+        scaler is saved in `scaler_parameters`.
+
+        :param scaler: A scaling object (e.g., from sklearn.preprocessing).
+            If None, it defaults to a MinMaxScaler.
+        :return: the database itself (unchanged if it is empty).
+        :rtype: Database
+        """
+        if not self._pairs:
+            return self
+
+        if scaler is None:
+            # Imported here so that scikit-learn is needed only when the
+            # default scaler is actually used.
+            from sklearn.preprocessing import MinMaxScaler
+
+            scaler = MinMaxScaler()
+
+        normalized_params = scaler.fit_transform(self.parameters_matrix)
+
+        for i, (parameter, _) in enumerate(self._pairs):
+            parameter.values = normalized_params[i]
+
+        self.scaler_parameters = scaler
+        return self
+
+    def normalize_snapshots(self, scaler=None):
+        """
+        Normalize the snapshots in the database.
+
+        The scaler is fitted on the snapshots matrix and the stored
+        snapshot values are overwritten with the scaled ones, keeping the
+        shape of each snapshot. The fitted scaler is saved in
+        `scaler_snapshots`.
+
+        :param scaler: A scaling object (e.g., from sklearn.preprocessing).
+            If None, it defaults to a MinMaxScaler.
+        :return: the database itself (unchanged if it is empty).
+        :rtype: Database
+        """
+        if not self._pairs:
+            return self
+
+        if scaler is None:
+            # Imported here so that scikit-learn is needed only when the
+            # default scaler is actually used.
+            from sklearn.preprocessing import MinMaxScaler
+
+            scaler = MinMaxScaler()
+
+        normalized_snaps = scaler.fit_transform(self.snapshots_matrix)
+
+        for i, (_, snapshot) in enumerate(self._pairs):
+            # reshape the flat array back to its original multidimensional
+            # shape
+            snapshot.values = normalized_snaps[i].reshape(
+                snapshot.values.shape
+            )
+
+        self.scaler_snapshots = scaler
+        return self
+
     def split(self, chunks, seed=None):
         """

@@ -209,7 +279,10 @@ def split(self, chunks, seed=None):

         else:
             logger.error("Invalid chunk type")
-            ValueError
+            raise TypeError(
+                "Invalid chunk type. Expected a list of integers or "
+                f"floats, but got {type(chunks)}."
+            )

         new_database = [Database() for _ in range(len(chunks))]
         for i, chunk in enumerate(chunks):
diff --git a/ezyrb/plugin/scaler.py b/ezyrb/plugin/scaler.py
--- a/ezyrb/plugin/scaler.py
+++ b/ezyrb/plugin/scaler.py
@@ -13,10 +13,9 @@ class DatabaseScaler(Plugin):
-    and `inverse_trasform` methods (i.e. `sklearn` interface), to rescale
+    and `inverse_transform` methods (i.e. `sklearn` interface), to rescale
     the parameters and/or the snapshots. It can be applied at the full order
-    (`mode='full'`), at the reduced one (`mode='reduced'`) or both of them
-    (`mode='both'`).
+    (`mode='full'`) or at the reduced one (`mode='reduced'`).

     :param obj scaler: a generic object which has to have implemented the
-        `fit`, `transform` and `inverse_trasform` methods (i.e. `sklearn`
+        `fit`, `transform` and `inverse_transform` methods (i.e. `sklearn`
         interface).
     :param {'full', 'reduced'} mode: define if the rescaling has to be applied
         at the full order ('full') or at the reduced one ('reduced').
@@ -66,7 +65,12 @@ def target(self):
     @target.setter
     def target(self, new_target):
         if new_target not in ["snapshots", "parameters"]:
-            raise ValueError
+            error_msg = (
+                f"Invalid target: '{new_target}' must be 'snapshots' "
+                "or 'parameters'."
+            )
+            logger.error(error_msg)
+            raise ValueError(error_msg)

         self._target = new_target

@@ -82,8 +86,12 @@ def mode(self):
     @mode.setter
     def mode(self, new_mode):
         if new_mode not in ["full", "reduced"]:
-            raise ValueError
+            error_msg = (
+                f"Invalid mode: '{new_mode}' must be 'full' or 'reduced'."
+            )
+            logger.error(error_msg)
+            raise ValueError(error_msg)

         self._mode = new_mode

     def _select_matrix(self, db):