-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscaler.py
More file actions
19 lines (13 loc) · 745 Bytes
/
scaler.py
File metadata and controls
19 lines (13 loc) · 745 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from sklearn.preprocessing import StandardScaler
def scale_data(train_features, test_features):
    """Standardize features using statistics learned from the training set only.

    A ``StandardScaler`` is fitted on ``train_features``. That same fitted
    transform is then applied to both ``train_features`` and
    ``test_features``, so the test data never contributes to the learned
    mean/variance.

    Args:
        train_features: Training samples, passed straight to ``StandardScaler``.
        test_features: Held-out samples with the same feature columns.

    Returns:
        A ``(train_features_scaled, test_features_scaled)`` tuple.
    """
    fitted_scaler = StandardScaler().fit(train_features)
    scaled_train = fitted_scaler.transform(train_features)
    scaled_test = fitted_scaler.transform(test_features)
    return scaled_train, scaled_test
def custom_scale_data(train_features, test_features):
    """Standardize N-D feature arrays along their last axis.

    All leading axes are flattened into the sample axis, so the last axis
    is treated as the feature axis. For an input of shape
    ``(a, b, ..., n_features)``, per-feature statistics are computed over
    all ``a * b * ...`` positions. The scaler is fitted on
    ``train_features`` only and then applied to ``test_features``. Each
    output has the same shape as its input.

    Args:
        train_features: Array-like with at least two dimensions whose last
            axis holds the features.
        test_features: Array-like whose last axis has the same length as the
            last axis of ``train_features``.

    Returns:
        A ``(train_features_scaled, test_features_scaled)`` tuple of arrays
        shaped like the corresponding inputs.

    Raises:
        ValueError: If ``train_features`` has fewer than two dimensions.
            A 1-D array would otherwise be fitted as a single sample and
            silently scale to all zeros. Also raised (by ``StandardScaler``)
            when the feature counts of the two inputs differ.
    """
    import numpy as np  # local import: the module itself only imports sklearn

    # Accept any array-like (lists, DataFrames, ...), like scale_data does,
    # instead of requiring objects that already have a .reshape method.
    train = np.asarray(train_features)
    test = np.asarray(test_features)
    if train.ndim < 2:
        raise ValueError(
            "train_features must have at least 2 dimensions, "
            f"got shape {train.shape}"
        )

    scaler = StandardScaler()
    # Collapse every leading axis into rows so the scaler sees (samples, features).
    train_rows = train.reshape(-1, train.shape[-1])
    train_features_scaled = scaler.fit_transform(train_rows).reshape(train.shape)
    # Use test's own last axis. A feature-count mismatch then raises in the
    # scaler instead of silently regrouping values into wrong columns.
    test_rows = test.reshape(-1, test.shape[-1])
    test_features_scaled = scaler.transform(test_rows).reshape(test.shape)
    return train_features_scaled, test_features_scaled