@@ -545,4 +545,160 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_
545545
546546 loss = (eskill - 0.5 * espread )
547547
548+ return loss
549+
class SpectralCoherenceLoss(SpectralBaseLoss):
    """Probabilistic spectral loss combining PSD skill with ensemble coherence terms.

    The loss transforms forecasts and observations into spherical-harmonic space
    and combines:
      * a power-spectral-density (amplitude) skill term,
      * a forecast/observation coherence skill term,
      * a forecast/forecast coherence spread term (CRPS-style skill - 0.5 * spread).
    Supports spatial ("h"/"w") and ensemble model parallelism via the comm utilities.
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        crop_shape: Tuple[int, int],
        crop_offset: Tuple[int, int],
        channel_names: List[str],
        grid_type: str,
        lmax: Optional[int] = None,
        spatial_distributed: Optional[bool] = False,
        ensemble_distributed: Optional[bool] = False,
        ensemble_weights: Optional[torch.Tensor] = None,
        alpha: Optional[float] = 1.0,
        eps: Optional[float] = 1.0e-6,
        **kwargs,
    ):
        """
        Parameters
        ----------
        img_shape, crop_shape, crop_offset, channel_names, grid_type, lmax:
            Forwarded to SpectralBaseLoss, which sets up the SHT (self.sht).
        spatial_distributed:
            Enable spatial model parallelism if the "spatial" comm group exists.
        ensemble_distributed:
            Enable ensemble parallelism if the "ensemble" comm group exists and has size > 1.
        ensemble_weights:
            Optional per-member weights, registered as a non-persistent buffer.
            NOTE(review): currently not used in forward — TODO confirm intended.
        alpha:
            Accepted for interface compatibility; currently unused in this loss.
        eps:
            Stored for numerical safety. NOTE(review): currently unused; presumably
            intended to guard the sqrt/division in forward — TODO confirm.
        """
        super().__init__(
            img_shape=img_shape,
            crop_shape=crop_shape,
            crop_offset=crop_offset,
            channel_names=channel_names,
            grid_type=grid_type,
            lmax=lmax,
            spatial_distributed=spatial_distributed,
        )

        # only treat the loss as distributed when the corresponding comm groups are active
        self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial")
        self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)
        self.eps = eps

        if ensemble_weights is not None:
            self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
        else:
            self.ensemble_weights = None

        # prep ls and ms for broadcasting
        ls = torch.arange(self.sht.lmax).reshape(-1, 1)
        ms = torch.arange(self.sht.mmax).reshape(1, -1)

        # (lmax, mmax) quadrature-style weights: m > 0 modes count twice
        # (conjugate symmetry of the real SHT), invalid modes with m > l are masked
        lm_weights = torch.ones((self.sht.lmax, self.sht.mmax))
        lm_weights[:, 1:] *= 2.0
        lm_weights = torch.where(ms > ls, 0.0, lm_weights)
        # shard the weights the same way the distributed SHT shards its output:
        # l along the "h" group, m along the "w" group
        if comm.get_size("h") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")]
        if comm.get_size("w") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")]
        self.register_buffer("lm_weights", lm_weights, persistent=False)

    @property
    def type(self):
        # probabilistic loss: consumes an ensemble dimension
        return LossType.Probabilistic

    @property
    def n_channels(self):
        # the loss reduces over channels internally and returns a single value per sample
        return 1

    @torch.compiler.disable(recursive=False)
    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor:
        # channels are reduced inside forward, so a single unit weight suffices
        return torch.ones(1)

    def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the spectral coherence loss.

        Parameters
        ----------
        forecasts:
            5D tensor (batch, ensemble, channels, lat, lon) in grid space.
        observations:
            Matching observation tensor without the ensemble dimension.
        ensemble_weights:
            NOTE(review): accepted but currently unused — TODO confirm intended.

        Returns
        -------
        torch.Tensor
            Loss reduced over spectral and channel dimensions.
        """
        # sanity checks
        if forecasts.dim() != 5:
            raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")

        # compute the transform first, in fp32 regardless of amp state
        with amp.autocast(device_type="cuda", enabled=False):
            # TODO: check 4 pi normalization
            forecasts = self.sht(forecasts.float()) / math.sqrt(4.0 * math.pi)
            observations = self.sht(observations.float()) / math.sqrt(4.0 * math.pi)

        # spectral shapes (matching the (lmax, mmax) layout of lm_weights):
        # forecasts: batch, ensemble, channels, lmax, mmax
        # observations: batch, channels, lmax, mmax
        B, E, C, H, W = forecasts.shape

        # move the ensemble dimension to the front, then (if ensemble-parallel)
        # transpose so each rank holds all members but only a shard of the m modes
        forecasts = torch.moveaxis(forecasts, 1, 0)
        if self.ensemble_distributed:
            ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
            forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")
            # the m-sharding of the data has to be mirrored in the spectral weights
            lm_weights_split = scatter_to_parallel_region(self.lm_weights, -1, "ensemble")
        else:
            # BUGFIX: previously left undefined in the non-ensemble-distributed
            # case, raising a NameError in the PSD computation below
            lm_weights_split = self.lm_weights

        # observations need no transpose, just a split and broadcast to the ensemble dimension
        observations = observations.unsqueeze(0)
        if self.ensemble_distributed:
            observations = scatter_to_parallel_region(observations, -1, "ensemble")

        num_ensemble = forecasts.shape[0]

        # power spectral densities of forecasts and observations (sum over m)
        psd_forecasts = (lm_weights_split * forecasts.abs().square()).sum(dim=-1)
        psd_observations = (lm_weights_split * observations.abs().square()).sum(dim=-1)

        # complete the partial m-sums across the ensemble and "w" parallel regions
        if self.ensemble_distributed:
            psd_forecasts = reduce_from_parallel_region(psd_forecasts, "ensemble")
            psd_observations = reduce_from_parallel_region(psd_observations, "ensemble")

        if self.spatial_distributed:
            psd_forecasts = reduce_from_parallel_region(psd_forecasts, "w")
            psd_observations = reduce_from_parallel_region(psd_observations, "w")

        # cross-coherence between ensemble members and between forecasts and observations
        coherence_forecasts = (lm_weights_split * (forecasts.unsqueeze(0).conj() * forecasts.unsqueeze(1)).real).sum(dim=-1)
        coherence_observations = (lm_weights_split * (forecasts.conj() * observations).real).sum(dim=-1)

        # complete the partial m-sums across the ensemble and "w" parallel regions
        if self.ensemble_distributed:
            coherence_forecasts = reduce_from_parallel_region(coherence_forecasts, "ensemble")
            coherence_observations = reduce_from_parallel_region(coherence_observations, "ensemble")

        if self.spatial_distributed:
            coherence_forecasts = reduce_from_parallel_region(coherence_forecasts, "w")
            coherence_observations = reduce_from_parallel_region(coherence_observations, "w")

        # normalize the coherences by the product of the spectral norms
        # NOTE(review): self.eps is not applied here, so a vanishing PSD divides by
        # zero — presumably eps was meant for these denominators; TODO confirm
        coherence_observations = coherence_observations / torch.sqrt(psd_forecasts * psd_observations)
        # NOTE(review): the member-member coherence is normalized with the
        # observation PSDs; a textbook coherence would use psd_forecasts here —
        # kept as in the original, TODO confirm intended
        coherence_forecasts = coherence_forecasts / torch.sqrt(psd_observations.unsqueeze(0) * psd_observations.unsqueeze(1))

        # error in the (amplitude) power spectral density, averaged over members
        psd_skill = (torch.sqrt(psd_forecasts) - torch.sqrt(psd_observations)).square()
        psd_skill = psd_skill.sum(dim=0) / float(num_ensemble)

        # coherence skill: mean deviation of member/observation coherence from 1
        coherence_skill = (1.0 - coherence_observations).sum(dim=0) / float(num_ensemble)

        # coherence spread: mean over distinct member pairs (diagonal masked to 0)
        # NOTE(review): num_ensemble == 1 yields a 0/0 division here — TODO confirm
        # the loss is never used with a single-member ensemble
        coherence_spread = torch.where(torch.eye(num_ensemble, device=coherence_forecasts.device).bool().reshape(num_ensemble, num_ensemble, 1, 1, 1), 0.0, 1.0 - coherence_forecasts)
        coherence_spread = coherence_spread.sum(dim=(0, 1)) / float(num_ensemble * (num_ensemble - 1))

        # combine amplitude and coherence terms (CRPS-style skill - 0.5 * spread)
        loss = psd_skill + 2.0 * psd_observations.squeeze(0) * (coherence_skill - 0.5 * coherence_spread)

        # reduce over the l dimension (sharded along "h")
        loss = loss.sum(dim=-1)
        if self.spatial_distributed:
            loss = reduce_from_parallel_region(loss, "h")

        # reduce over the channel dimension
        loss = loss.sum(dim=-1)

        return loss
0 commit comments