Skip to content

Commit 576f8ff

Browse files
authored
Merge pull request #12 from OpenSportsLab/jeet-dev
Localization extension (Models and Trainer)
2 parents 4d0f851 + 77e0a03 commit 576f8ff

File tree

8 files changed

+1417
-3
lines changed

8 files changed

+1417
-3
lines changed

opensportslib/core/trainer/localization_trainer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
135135
trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
136136
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
137137
logging.info(f"Restored best epoch: {trainer.best_epoch}")
138+
139+
else:
140+
trainer = Trainer_pl(cfg, default_args["work_dir"])
138141

139142

140143
return trainer
@@ -147,6 +150,37 @@ def __init__(self):
147150
def train(self):
148151
pass
149152

153+
class Trainer_pl(Trainer):
    """Trainer class used for models that rely on lightning modules.

    Args:
        cfg (dict): Dict config. It should contain the key 'max_epochs' and the key 'GPU'.
    """

    def __init__(self, cfg, work_dir):
        # Imported lazily so the lightning dependency is only required when
        # this trainer is actually instantiated.
        from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback
        import pytorch_lightning as pl

        self.work_dir = work_dir
        self.trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            devices=[cfg.GPU],
            callbacks=[MyCallback(), CustomProgressBar(refresh_rate=1)],
            num_sanity_val_steps=0,
        )

    def train(self, **kwargs):
        """Run lightning's fit loop, then persist the best model snapshot.

        kwargs are forwarded verbatim to ``pl.Trainer.fit``; ``kwargs["model"]``
        must be the lightning module (MyCallback stores the best state on it).
        """
        self.trainer.fit(**kwargs)

        # MyCallback keeps the best snapshot on the lightning module.
        best_model = kwargs["model"].best_state
        save_path = os.path.join(self.work_dir, "model.pth.tar")

        logging.info("Done training")
        logging.info("Best epoch: {}".format(best_model.get("epoch")))
        torch.save(best_model, save_path)

        logging.info("Model saved")
        logging.info(save_path)
150184

151185

152186
class Trainer_e2e(Trainer):
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytorch_lightning as pl
2+
from pytorch_lightning.callbacks.progress import TQDMProgressBar
3+
import logging
4+
5+
6+
class CustomProgressBar(TQDMProgressBar):
    """Tweaked version of the pytorch-lightning TQDM progress bar."""

    def get_metrics(self, trainer, pl_module):
        """Hide the lightning version number ("v_num") from the displayed metrics."""
        metrics = super().get_metrics(trainer, pl_module)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics
14+
15+
16+
class MyCallback(pl.Callback):
    """Validation-epoch-end hook: track the best model and drive LR-plateau early stopping.

    Expects the lightning module to expose:
        losses    -- meter whose ``avg`` holds the running validation loss
        best_loss -- lowest validation loss seen so far
        model     -- the wrapped torch model
        optimizer -- torch optimizer
        scheduler -- a ReduceLROnPlateau-style scheduler (uses eps/patience/num_bad_epochs)
    """

    def on_validation_epoch_end(self, trainer, pl_module):
        loss_validation = pl_module.losses.avg
        state = {
            "epoch": trainer.current_epoch + 1,
            "state_dict": pl_module.model.state_dict(),
            # NOTE(review): this records the best loss *before* this epoch's
            # update; if the checkpoint should carry the new best value, store
            # loss_validation instead -- confirm intent.
            "best_loss": pl_module.best_loss,
            "optimizer": pl_module.optimizer.state_dict(),
        }

        # remember best prec@1 and save checkpoint
        is_better = loss_validation < pl_module.best_loss
        pl_module.best_loss = min(loss_validation, pl_module.best_loss)

        # Save the best model based on loss only if the evaluation frequency too long
        if is_better:
            pl_module.best_state = state

        # Reduce LR on Plateau after patience reached
        prevLR = pl_module.optimizer.param_groups[0]["lr"]
        pl_module.scheduler.step(loss_validation)
        currLR = pl_module.optimizer.param_groups[0]["lr"]

        # BUG FIX: compare the LR values with `!=` instead of `is not`.
        # Identity comparison on floats depends on object interning, so the
        # original condition could fire (or stay silent) regardless of whether
        # the learning rate actually changed.
        if currLR != prevLR and pl_module.scheduler.num_bad_epochs == 0:
            logging.info("\nPlateau Reached!")
        if (
            prevLR < 2 * pl_module.scheduler.eps
            and pl_module.scheduler.num_bad_epochs >= pl_module.scheduler.patience
        ):
            logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
            trainer.should_stop = True

opensportslib/core/utils/video_processing.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,131 @@ def oneHotToShifts(onehot, params):
719719
Shifts[:, i] = shifts
720720

721721
return Shifts
722+
723+
def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
    """Stitch per-chunk spotting predictions into one video-long tensor.

    Args:
        output_spotting: tensor of shape (num_chunks, num_predictions, 2 + C);
            per prediction: [confidence, relative position within the chunk,
            C class scores].
        video_size: total number of frames of the long output.
        chunk_size: number of frames covered by one chunk.
        receptive_field: model receptive field; half of it is trimmed from
            each interior chunk border.

    Returns:
        Float tensor of shape (video_size, C), initialised at -1, holding each
        spotted confidence at its predicted frame/class position.
    """
    half_rf = receptive_field // 2
    num_classes = output_spotting.size()[-1] - 2
    device = output_spotting.device

    timestamps_long = torch.full(
        [video_size, num_classes], -1.0, dtype=torch.float, device=device
    )

    cursor = 0
    is_final = False
    for chunk_idx in range(output_spotting.size()[0]):
        # Scatter this chunk's predictions into a fresh per-chunk vector.
        chunk_vec = torch.full(
            [chunk_size, num_classes], -1.0, dtype=torch.float, device=device
        )
        for pred_idx in range(output_spotting.size()[1]):
            frame = torch.floor(
                output_spotting[chunk_idx, pred_idx, 1] * (chunk_size - 1)
            ).type(torch.int)
            label = torch.argmax(output_spotting[chunk_idx, pred_idx, 2:]).type(torch.int)
            chunk_vec[frame, label] = output_spotting[chunk_idx, pred_idx, 0]

        # ------------------------------------------
        # Store the result of the chunk in the video
        # ------------------------------------------
        if cursor == 0:
            # First chunk: keep its left border.
            timestamps_long[0 : chunk_size - half_rf] = chunk_vec[0 : chunk_size - half_rf]
        elif is_final:
            # Last chunk: keep its right border, then stop.
            timestamps_long[cursor + half_rf : cursor + chunk_size] = chunk_vec[half_rf:]
            break
        else:
            # Interior chunk: trim half the receptive field on both sides.
            timestamps_long[cursor + half_rf : cursor + chunk_size - half_rf] = chunk_vec[
                half_rf : chunk_size - half_rf
            ]

        # Advance the window; clamp the final chunk so it ends at video_size.
        cursor += chunk_size - 2 * half_rf
        if cursor + chunk_size >= video_size:
            cursor = video_size - chunk_size
            is_final = True
    return timestamps_long
791+
792+
793+
def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
    """Stitch per-chunk segmentation outputs into one video-long tensor.

    Args:
        output_segmentation: tensor of shape (num_chunks, chunk_size, C) with
            per-frame class scores.
        video_size: total number of frames of the long output.
        chunk_size: number of frames covered by one chunk.
        receptive_field: model receptive field; half of it is trimmed from
            each interior chunk border.

    Returns:
        Float tensor of shape (video_size, C): per-frame one-hot of the argmax
        class, zero where no chunk contributed.
    """
    half_rf = receptive_field // 2
    num_classes = output_segmentation.size()[-1]

    segmentation_long = torch.zeros(
        [video_size, num_classes],
        dtype=torch.float,
        device=output_segmentation.device,
    )

    cursor = 0
    is_final = False
    for chunk_idx in range(output_segmentation.size()[0]):
        # Hard-assign each frame of the chunk to its argmax class.
        chunk_onehot = torch.nn.functional.one_hot(
            torch.argmax(output_segmentation[chunk_idx], dim=-1),
            num_classes=num_classes,
        )

        # ------------------------------------------
        # Store the result of the chunk in the video
        # ------------------------------------------
        if cursor == 0:
            # First chunk: keep its left border.
            segmentation_long[0 : chunk_size - half_rf] = chunk_onehot[
                0 : chunk_size - half_rf
            ]
        elif is_final:
            # Last chunk: keep its right border, then stop.
            segmentation_long[cursor + half_rf : cursor + chunk_size] = chunk_onehot[
                half_rf:
            ]
            break
        else:
            # Interior chunk: trim half the receptive field on both sides.
            segmentation_long[cursor + half_rf : cursor + chunk_size - half_rf] = (
                chunk_onehot[half_rf : chunk_size - half_rf]
            )

        # Advance the window; clamp the final chunk so it ends at video_size.
        cursor += chunk_size - 2 * half_rf
        if cursor + chunk_size >= video_size:
            cursor = video_size - chunk_size
            is_final = True
    return segmentation_long
846+
722847
# import torch
723848
# import numpy as np
724849
# import decord

0 commit comments

Comments
 (0)