Skip to content

Commit 32fe69b

Browse files
committed
Add spectral coherence loss
1 parent 0fe0544 commit 32fe69b

File tree

4 files changed

+162
-3
lines changed

4 files changed

+162
-3
lines changed

makani/utils/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def get_optimizer(self, model, params):
640640
elif params.optimizer_type == "SGD":
641641
if self.log_to_screen:
642642
self.logger.info("using SGD optimizer")
643-
optimizer = optim.SGD(all_parameters, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), momentum=params.get("momentum", 0), foreach=True)
643+
optimizer = optim.SGD(all_parameters, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), momentum=params.get("momentum", 0), nesterov=params.get("nesterov", True), foreach=True)
644644
elif params.optimizer_type == "SIRFShampoo":
645645
if self.log_to_screen:
646646
self.logger.info("using SIRFShampoo optimizer")

makani/utils/loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss
3535
from .losses import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss, KernelScoreLoss
36-
from .losses import L2EnergyScoreLoss, SobolevEnergyScoreLoss, SpectralL2EnergyScoreLoss
36+
from .losses import L2EnergyScoreLoss, SobolevEnergyScoreLoss, SpectralL2EnergyScoreLoss, SpectralCoherenceLoss
3737
from .losses import GaussianMMDLoss
3838
from .losses import EnsembleNLLLoss
3939
from .losses import DriftRegularization, HydrostaticBalanceLoss, SpectralRegularization
@@ -213,6 +213,7 @@ def _compute_multistep_weight(self, multistep_weight_type: str) -> torch.Tensor:
213213
# linear weighting factor for the case of multistep training
214214
multistep_weight = torch.arange(1, self.n_future + 2, dtype=torch.float32) / float(self.n_future + 1)
215215
elif multistep_weight_type == "last-n-1":
216+
print(f"using last n-1")
216217
# weighting factor for the last n steps, with the first step weighted 0
217218
multistep_weight = torch.ones(self.n_future + 1, dtype=torch.float32) / float(self.n_future)
218219
multistep_weight[0] = 0.0
@@ -284,6 +285,8 @@ def _parse_loss_type(self, loss_type: str):
284285
loss_handle = partial(SobolevEnergyScoreLoss)
285286
elif "spectral_l2_energy_score" in loss_type:
286287
loss_handle = partial(SpectralL2EnergyScoreLoss)
288+
elif "spectral_coherence_loss" in loss_type:
289+
loss_handle = partial(SpectralCoherenceLoss)
287290
elif "drift_regularization" in loss_type:
288291
loss_handle = DriftRegularization
289292
elif "spectral_regularization" in loss_type:

makani/utils/losses/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .amse_loss import SpectralAMSELoss
2020
from .hydrostatic_loss import HydrostaticBalanceLoss
2121
from .crps_loss import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss, KernelScoreLoss
22-
from .energy_score import L2EnergyScoreLoss, SobolevEnergyScoreLoss, SpectralL2EnergyScoreLoss
22+
from .energy_score import L2EnergyScoreLoss, SobolevEnergyScoreLoss, SpectralL2EnergyScoreLoss, SpectralCoherenceLoss
2323
from .mmd_loss import GaussianMMDLoss
2424
from .likelihood_loss import EnsembleNLLLoss
2525
from .regularization import DriftRegularization, SpectralRegularization

makani/utils/losses/energy_score.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,4 +545,160 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_
545545

546546
loss = (eskill - 0.5 * espread)
547547

548+
return loss
549+
550+
class SpectralCoherenceLoss(SpectralBaseLoss):
    """Probabilistic spectral-coherence loss for ensemble forecasts.

    Works in spherical-harmonic space: it combines (i) the squared error between
    the square roots of the ensemble-member and observation power spectral
    densities (PSD) and (ii) an energy-score-like skill/spread decomposition of
    the normalized cross-spectral coherence between members and observations.

    Expected shapes (after the SHT, last two dims are the (l, m) coefficient
    grid — NOTE(review): original comment said "mmax, lmax"; confirm ordering):
      forecasts:    (batch, ensemble, channels, nlat, nlon)
      observations: (batch, channels, nlat, nlon)

    Parameters
    ----------
    img_shape, crop_shape, crop_offset, channel_names, grid_type, lmax:
        Forwarded to ``SpectralBaseLoss``, which builds the SHT operator.
    spatial_distributed:
        Enable reductions over the spatial model-parallel groups ("h"/"w").
    ensemble_distributed:
        Enable reductions/transposes over the "ensemble" parallel group.
    ensemble_weights:
        Optional per-member weights; registered as a buffer but not used in
        ``forward`` yet.  # NOTE(review): confirm whether weighting is intended
    alpha:
        Accepted for interface compatibility; currently unused.
    eps:
        Numerical-stability constant; stored but currently unused in ``forward``.
        # NOTE(review): likely intended inside the sqrt normalizations — confirm
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        crop_shape: Tuple[int, int],
        crop_offset: Tuple[int, int],
        channel_names: List[str],
        grid_type: str,
        lmax: Optional[int] = None,
        spatial_distributed: Optional[bool] = False,
        ensemble_distributed: Optional[bool] = False,
        ensemble_weights: Optional[torch.Tensor] = None,
        alpha: Optional[float] = 1.0,
        eps: Optional[float] = 1.0e-6,
        **kwargs,
    ):

        super().__init__(
            img_shape=img_shape,
            crop_shape=crop_shape,
            crop_offset=crop_offset,
            channel_names=channel_names,
            grid_type=grid_type,
            lmax=lmax,
            spatial_distributed=spatial_distributed,
        )

        # only treat the loss as distributed when the corresponding comm groups exist
        self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial")
        self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)
        self.eps = eps

        if ensemble_weights is not None:
            self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
        else:
            self.ensemble_weights = ensemble_weights

        # prep ls and ms for broadcasting
        ls = torch.arange(self.sht.lmax).reshape(-1, 1)
        ms = torch.arange(self.sht.mmax).reshape(1, -1)

        # quadrature-style weights over the (l, m) plane: m > 0 modes count twice
        # (real transforms store only m >= 0), invalid modes with m > l are zeroed
        lm_weights = torch.ones((self.sht.lmax, self.sht.mmax))
        lm_weights[:, 1:] *= 2.0
        lm_weights = torch.where(ms > ls, 0.0, lm_weights)
        # shard the weights the same way the spectral coefficients are sharded
        if comm.get_size("h") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")]
        if comm.get_size("w") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")]
        self.register_buffer("lm_weights", lm_weights, persistent=False)

    @property
    def type(self):
        # classified as probabilistic since it scores an ensemble against observations
        return LossType.Probabilistic

    @property
    def n_channels(self):
        # the loss reduces over channels internally, so it reports a single channel
        return 1

    @torch.compiler.disable(recursive=False)
    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor:
        # channel weighting is a no-op: reduction over channels happens inside forward
        return torch.ones(1)

    def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor:

        # sanity checks
        if forecasts.dim() != 5:
            raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")

        # get the data type before stripping amp types
        # NOTE(review): dtype is captured but the result is never cast back — confirm intent
        dtype = forecasts.dtype

        # before anything else compute the transform
        # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same
        with amp.autocast(device_type="cuda", enabled=False):
            # TODO: check 4 pi normalization
            forecasts = self.sht(forecasts.float()) / math.sqrt(4.0 * math.pi)
            observations = self.sht(observations.float()) / math.sqrt(4.0 * math.pi)

        # we assume the following shapes:
        # forecasts: batch, ensemble, channels, mmax, lmax
        # observations: batch, channels, mmax, lmax
        B, E, C, H, W = forecasts.shape

        # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction.
        # ideally we split spatial dims
        forecasts = torch.moveaxis(forecasts, 1, 0)
        if self.ensemble_distributed:
            ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
            forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")  # for correct spatial reduction we need to do the same with spatial weights

        if self.ensemble_distributed:
            lm_weights_split = scatter_to_parallel_region(self.lm_weights, -1, "ensemble")
        else:
            # BUGFIX: lm_weights_split was previously only assigned in the distributed
            # branch, raising a NameError in the (common) non-distributed case
            lm_weights_split = self.lm_weights

        # observations does not need a transpose, but just a split and broadcast to ensemble dimension
        observations = observations.unsqueeze(0)
        if self.ensemble_distributed:
            observations = scatter_to_parallel_region(observations, -1, "ensemble")

        num_ensemble = forecasts.shape[0]

        # compute power spectral densities of forecasts and observations
        psd_forecasts = (lm_weights_split * forecasts.abs().square()).sum(dim=-1)
        psd_observations = (lm_weights_split * observations.abs().square()).sum(dim=-1)

        # reduce over ensemble parallel region and m spatial dimensions
        if self.ensemble_distributed:
            psd_forecasts = reduce_from_parallel_region(psd_forecasts, "ensemble")
            psd_observations = reduce_from_parallel_region(psd_observations, "ensemble")

        if self.spatial_distributed:
            psd_forecasts = reduce_from_parallel_region(psd_forecasts, "w")
            psd_observations = reduce_from_parallel_region(psd_observations, "w")

        # compute coherence between forecasts and observations:
        # member-member cross spectra (E x E) and member-observation cross spectra
        coherence_forecasts = (lm_weights_split * (forecasts.unsqueeze(0).conj() * forecasts.unsqueeze(1)).real).sum(dim=-1)
        coherence_observations = (lm_weights_split * (forecasts.conj() * observations).real).sum(dim=-1)

        # reduce over ensemble parallel region and m spatial dimensions
        if self.ensemble_distributed:
            coherence_forecasts = reduce_from_parallel_region(coherence_forecasts, "ensemble")
            coherence_observations = reduce_from_parallel_region(coherence_observations, "ensemble")

        if self.spatial_distributed:
            coherence_forecasts = reduce_from_parallel_region(coherence_forecasts, "w")
            coherence_observations = reduce_from_parallel_region(coherence_observations, "w")

        # divide the coherence by the product of the norms
        coherence_observations = coherence_observations / torch.sqrt(psd_forecasts * psd_observations)
        # NOTE(review): member-member coherence is normalized by observation PSDs here;
        # a standard coherence would use psd_forecasts.unsqueeze(0)*psd_forecasts.unsqueeze(1) — confirm
        coherence_forecasts = coherence_forecasts / torch.sqrt(psd_observations.unsqueeze(0) * psd_observations.unsqueeze(1))

        # compute the error in the power spectral density
        psd_skill = (torch.sqrt(psd_forecasts) - torch.sqrt(psd_observations)).square()
        psd_skill = psd_skill.sum(dim=0) / float(num_ensemble)

        # compute the coherence skill and spread
        coherence_skill = (1.0 - coherence_observations).sum(dim=0) / float(num_ensemble)

        # mask the diagonal of coherence_spread with 0.0
        # NOTE(review): num_ensemble == 1 divides by zero below — this loss requires at least 2 members
        coherence_spread = torch.where(torch.eye(num_ensemble, device=coherence_forecasts.device).bool().reshape(num_ensemble, num_ensemble, 1, 1, 1), 0.0, 1.0 - coherence_forecasts)
        coherence_spread = coherence_spread.sum(dim=(0, 1)) / float(num_ensemble * (num_ensemble - 1))

        # compute the loss: PSD error plus coherence-based skill/spread term,
        # weighted by the observation PSD per spectral degree
        loss = psd_skill + 2.0 * psd_observations.squeeze(0) * (coherence_skill - 0.5 * coherence_spread)

        # reduce the loss over the l dimensions
        loss = loss.sum(dim=-1)
        if self.spatial_distributed:
            loss = reduce_from_parallel_region(loss, "h")

        # reduce over the channel dimension
        loss = loss.sum(dim=-1)

        return loss

0 commit comments

Comments
 (0)