Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ class Database:

:param array_like parameters: the input parameters
:param array_like snapshots: the input snapshots
:param Scale scaler_parameters: the scaler for the parameters. Default
is None meaning no scaling.
:param Scale scaler_snapshots: the scaler for the snapshots. Default is
None meaning no scaling.
:param array_like space: the input spatial data

:Example:
Expand Down Expand Up @@ -46,6 +42,9 @@ def __init__(self, parameters=None, snapshots=None, space=None):
)
self._pairs = []

self.scaler_parameters = None
self.scaler_snapshots = None

if parameters is None and snapshots is None:
logger.debug("Empty database created")
return
Expand Down Expand Up @@ -149,18 +148,66 @@ def add(self, parameter, snapshot):
"""
if not isinstance(parameter, Parameter):
logger.error("Invalid parameter type: %s", type(parameter))
raise ValueError
raise TypeError(f"Expected a Parameter object, got {type(parameter)}")

if not isinstance(snapshot, Snapshot):
logger.error("Invalid snapshot type: %s", type(snapshot))
raise ValueError
raise TypeError(f"Expected a Snapshot object, got {type(snapshot)}")

self._pairs.append((parameter, snapshot))
logger.debug(
"Added parameter-snapshot pair. Total pairs: %d", len(self._pairs)
)

return self

def normalize_parameters(self, scaler=None):
    """
    Normalize the parameters in the database.

    The scaler is fitted on the matrix of all the stored parameters and
    each stored parameter is overwritten in place with its scaled values.
    The fitted scaler is stored in `scaler_parameters`. An empty database
    is left untouched.

    :param scaler: A scaling object exposing `fit_transform` (e.g., from
        sklearn.preprocessing). If None, it defaults to a MinMaxScaler.
    :return: the database itself.
    :rtype: Database
    """
    if not self._pairs:
        return self

    if scaler is None:
        # Imported lazily so that scikit-learn is only required when the
        # default scaler is actually requested.
        from sklearn.preprocessing import MinMaxScaler

        scaler = MinMaxScaler()

    normalized_params = scaler.fit_transform(self.parameters_matrix)

    for (parameter, _), row in zip(self._pairs, normalized_params):
        parameter.values = row

    self.scaler_parameters = scaler
    return self


def normalize_snapshots(self, scaler=None):
    """
    Normalize the snapshots in the database.

    The scaler is fitted on the matrix of all the stored snapshots and
    each stored snapshot is overwritten in place with its scaled values,
    reshaped to the shape the snapshot had before the scaling. The fitted
    scaler is stored in `scaler_snapshots`. An empty database is left
    untouched.

    :param scaler: A scaling object exposing `fit_transform` (e.g., from
        sklearn.preprocessing). If None, it defaults to a MinMaxScaler.
    :return: the database itself.
    :rtype: Database
    """
    if not self._pairs:
        return self

    if scaler is None:
        # Imported lazily so that scikit-learn is only required when the
        # default scaler is actually requested.
        from sklearn.preprocessing import MinMaxScaler

        scaler = MinMaxScaler()

    normalized_snaps = scaler.fit_transform(self.snapshots_matrix)

    for (_, snapshot), row in zip(self._pairs, normalized_snaps):
        # reshape the flat array back to its original multidimensional
        # shape
        snapshot.values = row.reshape(snapshot.values.shape)

    self.scaler_snapshots = scaler
    return self

def split(self, chunks, seed=None):
"""
Expand Down Expand Up @@ -209,7 +256,7 @@ def split(self, chunks, seed=None):

else:
logger.error("Invalid chunk type")
ValueError
raise TypeError(f"Invalid chunk type. Expected a list of integers or floats, but got {type(chunks)}.")

new_database = [Database() for _ in range(len(chunks))]
for i, chunk in enumerate(chunks):
Expand All @@ -235,4 +282,4 @@ def get_snapshot_space(self, index):
"""
if index < 0 or index >= len(self._pairs):
raise IndexError("Snapshot index out of range.")
return self._pairs[index][1].space
return self._pairs[index][1].space
17 changes: 11 additions & 6 deletions ezyrb/plugin/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ class DatabaseScaler(Plugin):
"""
The plugin to rescale the database of the reduced order model. It uses a
user defined `scaler`, which has to have implemented the `fit`, `transform`
and `inverse_trasform` methods (i.e. `sklearn` interface), to rescale
and `inverse_transform` methods (i.e. `sklearn` interface), to rescale
the parameters and/or the snapshots. It can be applied at the full order
(`mode='full'`), at the reduced one (`mode='reduced'`) or both of them
(`mode='both'`).
(`mode='full'`) or at the reduced one (`mode='reduced'`).

:param obj scaler: a generic object which has to have implemented the
`fit`, `transform` and `inverse_trasform` methods (i.e. `sklearn`
`fit`, `transform` and `inverse_transform` methods (i.e. `sklearn`
interface).
:param {'full', 'reduced'} mode: define if the rescaling has to be
applied at the full order ('full') or at the reduced one ('reduced').
Expand Down Expand Up @@ -62,11 +61,14 @@ def target(self):
:rtype: str
"""
return self._target


@target.setter
def target(self, new_target):
    """
    Set the target of the rescaling.

    :param str new_target: either 'snapshots' or 'parameters'.
    :raises ValueError: if `new_target` is not one of the accepted values.
    """
    if new_target not in ["snapshots", "parameters"]:
        # Build the message once so the log entry and the exception agree.
        error_msg = f"Invalid target: '{new_target}' must be 'snapshots' or 'parameters'."
        logger.error(error_msg)
        raise ValueError(error_msg)

    self._target = new_target

Expand All @@ -82,10 +84,13 @@ def mode(self):
@mode.setter
def mode(self, new_mode):
    """
    Set the order at which the rescaling is applied.

    :param str new_mode: either 'full' or 'reduced'.
    :raises ValueError: if `new_mode` is not one of the accepted values.
    """
    if new_mode not in ["full", "reduced"]:
        # Build the message once so the log entry and the exception agree.
        error_msg = f"Invalid mode: '{new_mode}' must be 'full' or 'reduced'."
        logger.error(error_msg)
        raise ValueError(error_msg)

    self._mode = new_mode


def _select_matrix(self, db):
"""
Helper function to select the proper matrix to rescale.
Expand Down
Loading