-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_data_without_DDP.py
More file actions
129 lines (108 loc) · 6.13 KB
/
process_data_without_DDP.py
File metadata and controls
129 lines (108 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# load and process data by get_data (why replace data with pft values?)
# make sure start_year, end_year by mt.config['window']
# filter data and get inputs and outputs by mt.build_dataloader then data_utils.dataloader
# now output like inputs = [ [data1, data2, ...] , ...] outputs = [ [target1] , ...], duplicated here
# kfold split inputs to get train_idx, valid_idx
### why feat_cnt = data.shape[-1] get feature count? data.shape return (rows, columns)
# sort_values by 'latitude', 'longitude', 'year', 'month', then DataFrame.values gives a numpy array; use reshape to fold the data into a 4D numpy array
# inputs = np.array([inputs[..., i - config['window']: i, :] for i in range(config['window'], inputs.shape[2] + 1)]) this multiplies the loaded data to nearly config['window'] times its size!
# input_add[:,:,:,(input_add.shape[3] - 1),:] = 0 remove all labels of the last month, and add other month labels as input
# inputs = np.concatenate((inputs, input_add), axis=4) want to add history target as input ### so axis=-1 is the same?
### why select_top_rows in train but valid don't care about nan_rate?
# use inputs, outputs = self.slice_patch(inputs, outputs, patch_size=config['patch']) to slice the data into patch_size pieces and concatenate them along the length dimension ### why not load small pieces of latitude and longitude directly? the only difference is that this way it can train on the pieces of a month sequentially, but what's the meaning?
# plan
# load data as before like get_data, now it's DataFrame
# pre-process data by preprocess, now it's sorted and np.array
# Dataset gets np.array data and indexes, then converts the data to torch.tensor; data should be fetched starting from startIndex.
### argument of Dataset: add history, latitude window, longitude window, month window,
# pass RandomSampler to DataLoader to shuffle data, in valid mode, batch_size = config['batch_size'] * 10
# split indexes to train and valid by kfold, so the data passed to Dataset is already processed to numpy array, and kfold need to split startIndex
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
# from random import shuffle
import pyarrow.feather as feather
# List of short variable names; not referenced anywhere in the visible part of this file.
# NOTE(review): presumably the driver/forcing feature columns — confirm against callers.
pro_name = ['sp', 'ssrd', 'strd', 't2m', 'tp', 'u10', 'v10', 'q', 'co2']
class OceanDataset_V2(Dataset):
    """Serve spatial patches over a sliding month window from a 4-D array.

    ``data`` is indexed as ``data[latitude, longitude, month, channel]`` (axis
    names follow the variable names used below) and its last channel is the
    target.  Every row of ``indexes`` is a
    ``(latitude_start, longitude_start, month_start)`` triple naming the
    corner of one sample.
    """

    def __init__(self, data, indexes, config, mode):
        """Build the dataset.

        Args:
            data: 4-D numpy array; the last channel is the target.
            indexes: sequence of ``(lat_start, lon_start, month_start)`` rows.
            config: mapping providing ``'patch'``, ``'window'``,
                ``'add_history_target'`` and, in train mode, ``'percent'``.
            mode: ``'train'`` keeps only the ``config['percent']`` fraction of
                samples whose labels have the lowest NaN rate; any other value
                keeps every index.
        """
        self.config = config
        self.lat_patch = config['patch']
        self.lon_patch = config['patch']
        self.window_size = config['window']
        indexes = np.array(indexes)
        if mode == 'train':
            # Must run before NaNs are filled: the ranking relies on label NaNs.
            indexes = self.select_top_rows(data, indexes, config['percent'])
        # astype() returns a fresh array, so NaNs are zeroed on a private copy
        # instead of silently rewriting the caller's array (which may be shared
        # between several datasets).
        data = data.astype(np.float32)
        features = data[..., :-1]  # view: every channel except the target
        features[np.isnan(features)] = 0
        self.data = torch.from_numpy(data)
        self.indexes = torch.from_numpy(indexes)
        print('self.data.dtype',self.data.dtype)

    def select_top_rows(self, data, indexes, percent):
        """Return the ``percent`` fraction of ``indexes`` with the lowest label NaN rate."""
        nan_rate = self.get_nan_rate(data, indexes)
        top_row_idx = np.argsort(nan_rate)
        top_row_idx = top_row_idx[:int(top_row_idx.shape[0] * percent)]
        return indexes[top_row_idx]

    def get_nan_rate(self, data, indexes):
        """Return, for each index row, the fraction of NaN values in its label patch."""
        if len(indexes) == 0:
            # np.stack() rejects an empty list; an empty selection has no rates.
            return np.zeros(0, dtype=float)
        labels = [self.get_label(data, d3_index) for d3_index in indexes]
        labels = np.stack(labels, axis=0).reshape(len(labels), -1)
        nan_rate = np.isnan(labels).astype('float')
        return nan_rate.sum(axis=-1) / labels.shape[-1]

    def get_label(self, data, d3_index):
        """Return the target patch at the last month of the window for ``d3_index``."""
        latitude_start, longitude_start, month_start = d3_index
        return data[latitude_start:latitude_start + self.lat_patch,
                    longitude_start:longitude_start + self.lon_patch,
                    month_start + self.window_size - 1, -1]

    def __getitem__(self, index):
        """Return ``(inputs, label)`` for sample ``index``.

        ``inputs`` is the ``[lat_patch, lon_patch, window, channels]`` slice
        with NaNs replaced by 0.  When ``config['add_history_target']`` is
        truthy the target channel is kept as an input but zeroed for the last
        month (the one being predicted); otherwise the target channel is
        dropped.  ``label`` is the last-month target patch with a trailing
        singleton axis; it is not NaN-filled.
        """
        # Plain Python ints: unpack once and reuse for both slices.
        latitude_start, longitude_start, month_start = self.indexes[index].tolist()
        data = self.data[latitude_start:latitude_start + self.lat_patch,
                         longitude_start:longitude_start + self.lon_patch,
                         month_start:month_start + self.window_size, ...].clone()
        if self.config['add_history_target']:
            # Hide the value to be predicted while keeping earlier targets.
            data[:, :, -1, -1] = 0
        else:
            data = data[..., :-1]
        data = torch.nan_to_num(data, nan=0.0)
        label = self.get_label(
            self.data, (latitude_start, longitude_start, month_start))
        return data, label.unsqueeze(-1)

    def __len__(self):
        """Return the number of samples (index rows) in the dataset."""
        return len(self.indexes)
def get_data(mode, config):
    """Load the preprocessed feature table for ``mode`` and attach PFT columns.

    Reads ``<config['preprocessed_data_dir']>/<mode>.feather``, keeps rows
    with ``year >= 2000`` and the columns listed in ``config['columns']``,
    then copies every column whose name contains ``'PFT'`` from
    ``<config['pft_data_dir']>/<config['target']>.feather`` onto it by
    position.  float64 columns of both tables are downcast to float32.

    Returns:
        The resulting DataFrame.

    Raises:
        AssertionError: if the PFT table and the filtered feature table do
            not have the same number of rows.
    """
    print('target:', config['target'])

    feature_file = os.path.join(
        config['preprocessed_data_dir'], '{}.feather'.format(mode))
    print('read preprocessed data')
    frame = feather.read_feather(feature_file, memory_map=True)
    from_2000 = frame['year'] >= 2000
    frame = optimize_floats(frame[from_2000][config['columns']])

    pft_file = os.path.join(
        config['pft_data_dir'], f"{config['target']}.feather")
    pft = optimize_floats(feather.read_feather(pft_file, memory_map=True))
    assert pft.shape[0] == frame.shape[0], 'PFT data length should equal to data length'

    # Positional copy (.values): the two tables are matched row by row,
    # ignoring their indexes.
    for name in pft.columns:
        if 'PFT' not in name:
            continue
        frame[name] = pft[name].values

    print(frame.columns)
    return frame
def optimize_floats(df):
    """Downcast every float64 column of ``df`` to float32 in place.

    The same DataFrame object is returned so that calls can be chained;
    columns of any other dtype are left untouched.
    """
    wide_columns = df.select_dtypes(include=['float64']).columns
    narrowed = df[wide_columns].astype('float32')
    df[wide_columns] = narrowed
    return df
def add_prefix(data, prefix_data, L):
    """Prepend the last ``L`` rows of each grid cell in ``prefix_data`` to ``data``.

    ``prefix_data`` is ordered by latitude, longitude, year and month; for
    every (latitude, longitude) pair its final ``L`` rows are kept and the
    result is stacked on top of ``data``.

    Args:
        data: DataFrame to extend.
        prefix_data: DataFrame with at least the columns ``latitude``,
            ``longitude``, ``year`` and ``month``.
        L: number of trailing rows to keep per grid cell; ``0`` keeps none.

    Returns:
        A new DataFrame with the selected prefix rows followed by ``data``.
        Neither input is modified.
    """
    cell_keys = ['latitude', 'longitude']
    ordered = prefix_data.sort_values(by=cell_keys + ['year', 'month'])
    # GroupBy.tail keeps the grouping columns and the sorted row order, and
    # returns nothing for L == 0.  The former apply(lambda x: x.iloc[-L:])
    # returned *every* row for L == 0 and relies on groupby.apply operating
    # on the grouping columns, which pandas has deprecated.
    history = ordered.groupby(cell_keys).tail(L)
    # NOTE(review): reset_index() without drop=True adds an 'index' column
    # that is NaN for the rows coming from `data`; kept as-is so the output
    # schema does not change — confirm whether callers actually need it.
    history = history.reset_index()
    return pd.concat([history, data], axis=0)