@@ -719,6 +719,131 @@ def oneHotToShifts(onehot, params):
719719 Shifts [:, i ] = shifts
720720
721721 return Shifts
722+
def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
    """Reassemble per-chunk spotting predictions into one video-length tensor.

    Each chunk contributes only its "trusted" interior: half of
    ``receptive_field`` is trimmed from every chunk border (except the very
    start of the first chunk and the very end of the last one), and chunks
    are laid down left to right with a stride of
    ``chunk_size - receptive_field``.

    Args:
        output_spotting (torch.Tensor): shape
            ``(num_chunks, num_predictions, 2 + num_classes)``. Per
            prediction, index 0 is the confidence score, index 1 is the
            relative position of the event within the chunk (expected in
            ``[0, 1]``), and indices ``2:`` are class scores whose argmax
            selects the class.
        video_size (int): total number of frames in the video.
        chunk_size (int): number of frames covered by one chunk.
        receptive_field (int): border width (in frames) considered
            unreliable; halved internally, one half trimmed per side.

    Returns:
        torch.Tensor: shape ``(video_size, num_classes)``, filled with -1
        except at cells ``[frame, class]`` that hold a spotted event's
        confidence.
    """
    # NOTE(review): assumes video_size >= chunk_size; otherwise
    # `start = video_size - chunk_size` goes negative — confirm with callers.
    start = 0
    last = False
    # Only half the receptive field is trimmed from each chunk border.
    receptive_field = receptive_field // 2
    num_classes = output_spotting.size(-1) - 2

    # -1 marks "no event spotted at this frame/class".
    timestamps_long = torch.full(
        (video_size, num_classes),
        -1.0,
        dtype=torch.float,
        device=output_spotting.device,
    )

    for chunk_idx in range(output_spotting.size(0)):
        # Scatter this chunk's predictions into a chunk-local canvas.
        tmp_timestamps = torch.full(
            (chunk_size, num_classes),
            -1.0,
            dtype=torch.float,
            device=output_spotting.device,
        )
        for pred_idx in range(output_spotting.size(1)):
            # Relative position in [0, 1] -> integer frame index in the chunk.
            frame = int(
                torch.floor(output_spotting[chunk_idx, pred_idx, 1] * (chunk_size - 1))
            )
            cls = int(torch.argmax(output_spotting[chunk_idx, pred_idx, 2:]))
            tmp_timestamps[frame, cls] = output_spotting[chunk_idx, pred_idx, 0]

        # ------------------------------------------
        # Store the result of the chunk in the video
        # ------------------------------------------
        if start == 0:
            # First chunk: keep everything except the trailing border.
            timestamps_long[0 : chunk_size - receptive_field] = tmp_timestamps[
                0 : chunk_size - receptive_field
            ]
        elif last:
            # Last chunk: keep everything except the leading border, then stop.
            timestamps_long[start + receptive_field : start + chunk_size] = (
                tmp_timestamps[receptive_field:]
            )
            break
        else:
            # Interior chunk: trim both borders.
            timestamps_long[
                start + receptive_field : start + chunk_size - receptive_field
            ] = tmp_timestamps[receptive_field : chunk_size - receptive_field]

        # Advance by the trusted width; clamp the final chunk so it ends
        # exactly at the last video frame.
        start += chunk_size - 2 * receptive_field
        if start + chunk_size >= video_size:
            start = video_size - chunk_size
            last = True
    return timestamps_long
791+
792+
def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
    """Reassemble per-chunk segmentation scores into one video-length tensor.

    Each chunk's per-frame scores are first hardened into one-hot vectors
    (argmax over the class axis). Chunks then contribute only their
    "trusted" interior: half of ``receptive_field`` is trimmed from every
    chunk border (except the very start of the first chunk and the very end
    of the last one), with chunks laid down left to right at a stride of
    ``chunk_size - receptive_field``.

    Args:
        output_segmentation (torch.Tensor): shape
            ``(num_chunks, chunk_size, num_classes)`` per-frame class scores.
        video_size (int): total number of frames in the video.
        chunk_size (int): number of frames covered by one chunk.
        receptive_field (int): border width (in frames) considered
            unreliable; halved internally, one half trimmed per side.

    Returns:
        torch.Tensor: shape ``(video_size, num_classes)`` float tensor of
        one-hot rows (zeros wherever no chunk contributed).
    """
    # NOTE(review): assumes video_size >= chunk_size; otherwise
    # `start = video_size - chunk_size` goes negative — confirm with callers.
    start = 0
    last = False
    # Only half the receptive field is trimmed from each chunk border.
    receptive_field = receptive_field // 2
    num_classes = output_segmentation.size(-1)

    segmentation_long = torch.zeros(
        (video_size, num_classes),
        dtype=torch.float,
        device=output_segmentation.device,
    )

    for chunk_idx in range(output_segmentation.size(0)):
        # Harden the chunk's scores into one-hot class assignments.
        tmp_segmentation = torch.nn.functional.one_hot(
            torch.argmax(output_segmentation[chunk_idx], dim=-1),
            num_classes=num_classes,
        )

        # ------------------------------------------
        # Store the result of the chunk in the video
        # ------------------------------------------
        if start == 0:
            # First chunk: keep everything except the trailing border.
            segmentation_long[0 : chunk_size - receptive_field] = tmp_segmentation[
                0 : chunk_size - receptive_field
            ]
        elif last:
            # Last chunk: keep everything except the leading border, then stop.
            segmentation_long[start + receptive_field : start + chunk_size] = (
                tmp_segmentation[receptive_field:]
            )
            break
        else:
            # Interior chunk: trim both borders.
            segmentation_long[
                start + receptive_field : start + chunk_size - receptive_field
            ] = tmp_segmentation[receptive_field : chunk_size - receptive_field]

        # Advance by the trusted width; clamp the final chunk so it ends
        # exactly at the last video frame.
        start += chunk_size - 2 * receptive_field
        if start + chunk_size >= video_size:
            start = video_size - chunk_size
            last = True
    return segmentation_long
846+
722847# import torch
723848# import numpy as np
724849# import decord
0 commit comments