-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathsubmission.py
More file actions
97 lines (80 loc) · 3.32 KB
/
Copy pathsubmission.py
File metadata and controls
97 lines (80 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# ##########################################################################
# # Example of submission files
# # ---------------------------
# The zip file must be flat (single level): put the files at its root,
# with NO enclosing folder.
# my_submission.zip
# ├─ submission.py
# ├─ weights_challenge_1.pt
# └─ weights_challenge_2.pt
from braindecode.models import EEGNeX
import torch
from pathlib import Path
def resolve_path(name="model_file_name"):
    """Return the first existing location of the file called *name*.

    The following locations are probed, in this order:

    1. ``/app/input/res/<name>``
    2. ``/app/input/<name>``
    3. ``<name>`` as given (relative names resolve against the current
       working directory)
    4. ``<name>`` inside the directory that contains this script

    Args:
        name: File name (or path) to look for.

    Returns:
        The matching path as a string.

    Raises:
        FileNotFoundError: If none of the probed locations exists.
    """
    script_dir = Path(__file__).parent
    # Order matters: the first candidate that exists wins.
    candidates = (
        f"/app/input/res/{name}",
        f"/app/input/{name}",
        f"{name}",
        str(script_dir.joinpath(f"{name}")),
    )
    for candidate in candidates:
        if Path(candidate).exists():
            return candidate
    # Mention every probed location, including the script directory, so a
    # failed lookup can be diagnosed from the message alone.
    raise FileNotFoundError(
        f"Could not find {name} in /app/input/res/ or /app/input/ or current directory"
        f" or script directory ({script_dir})"
    )
class Submission:
    """Factory for the models used in challenge 1 and challenge 2.

    An instance is created with the sampling frequency and the target
    device, and hands out one freshly constructed ``EEGNeX`` model per
    challenge, already moved to that device.
    """

    def __init__(self, SFREQ, DEVICE):
        """Store the sampling frequency and the device for later model creation."""
        self.sfreq = SFREQ
        self.device = DEVICE

    def get_model_challenge_1(self):
        """Build the challenge 1 model and return it on ``self.device``."""
        # The model is sized for windows of 2 * sfreq samples.
        window_samples = int(2 * self.sfreq)
        net = EEGNeX(
            n_chans=129,
            n_outputs=1,
            sfreq=self.sfreq,
            n_times=window_samples,
        )
        net = net.to(self.device)
        # Uncomment to load trained weights shipped next to this file:
        # net.load_state_dict(torch.load(resolve_path("weights_challenge_1.pt"), map_location=self.device))
        return net

    def get_model_challenge_2(self):
        """Build the challenge 2 model and return it on ``self.device``."""
        window_samples = int(2 * self.sfreq)
        # NOTE(review): unlike challenge 1, ``sfreq`` is not passed to EEGNeX
        # here -- confirm whether that difference is intentional.
        net = EEGNeX(
            n_chans=129,
            n_outputs=1,
            n_times=window_samples,
        )
        net = net.to(self.device)
        # Uncomment to load trained weights shipped next to this file:
        # net.load_state_dict(torch.load(resolve_path("weights_challenge_2.pt"), map_location=self.device))
        return net
# ##########################################################################
# # How Submission class will be used
# # ---------------------------------
# from submission import Submission
#
# SFREQ = 100
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# sub = Submission(SFREQ, DEVICE)
# model_1 = sub.get_model_challenge_1()
# model_1.eval()
# warmup_loader_challenge_1 = DataLoader(HBN_R5_dataset1, batch_size=BATCH_SIZE)
# final_loader_challenge_1 = DataLoader(secret_dataset1, batch_size=BATCH_SIZE)
# with torch.inference_mode():
#     for batch in warmup_loader_challenge_1:  # and final_loader later
#         X, y, infos = batch
#         X = X.to(dtype=torch.float32, device=DEVICE)
#         # X.shape is (BATCH_SIZE, 129, 200)
#         # Forward pass
#         y_pred = model_1.forward(X)
#         # save prediction for computing evaluation score
#         ...
# score1 = compute_score_challenge_1(y_true, y_preds)
# del model_1
# gc.collect()
# model_2 = sub.get_model_challenge_2()
# model_2.eval()
# warmup_loader_challenge_2 = DataLoader(HBN_R5_dataset2, batch_size=BATCH_SIZE)
# final_loader_challenge_2 = DataLoader(secret_dataset2, batch_size=BATCH_SIZE)
# with torch.inference_mode():
#     for batch in warmup_loader_challenge_2:  # and final_loader later
#         X, y, crop_inds, infos = batch
#         X = X.to(dtype=torch.float32, device=DEVICE)
#         # X.shape is (BATCH_SIZE, 129, 200)
#         # Forward pass
#         y_pred = model_2.forward(X)
#         # save prediction for computing evaluation score
#         ...
# score2 = compute_score_challenge_2(y_true, y_preds)
# overall_score = compute_leaderboard_score(score1, score2)