diff --git a/makani/models/networks/pangu.py b/makani/models/networks/pangu.py index b5b8e550..b8539e5e 100644 --- a/makani/models/networks/pangu.py +++ b/makani/models/networks/pangu.py @@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None): x: input features with shape of (B * num_lon, num_pl*num_lat, N, C) mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon) """ - + B_, nW_, N, C = x.shape qkv = ( self.qkv(x) @@ -478,7 +478,7 @@ def forward(self, x: torch.Tensor, mask=None): attn = self.attn_drop_fn(attn) x = self.apply_attention(attn, v, B_, nW_, N, C) - + else: if mask is not None: bias = mask.unsqueeze(1).unsqueeze(0) + earth_position_bias.unsqueeze(0).unsqueeze(0) @@ -486,10 +486,10 @@ def forward(self, x: torch.Tensor, mask=None): #bias = bias.squeeze(2) else: bias = earth_position_bias.unsqueeze(0) - + # extract batch size for q,k,v nLon = self.num_lon - q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4]) + q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4]) k = k.view(B_ // nLon, nLon, k.shape[1], k.shape[2], k.shape[3], k.shape[4]) v = v.view(B_ // nLon, nLon, v.shape[1], v.shape[2], v.shape[3], v.shape[4]) #### @@ -736,7 +736,7 @@ class Pangu(nn.Module): - https://arxiv.org/abs/2211.02556 """ - def __init__(self, + def __init__(self, inp_shape=(721,1440), out_shape=(721,1440), grid_in="equiangular", @@ -773,14 +773,14 @@ def __init__(self, self.checkpointing_level = checkpointing_level drop_path = np.linspace(0, drop_path_rate, 8).tolist() - + # Add static channels to surface self.num_aux = len(self.aux_channel_names) N_total_surface = self.num_aux + self.num_surface # compute static permutations to extract self._precompute_channel_groups(self.channel_names, self.aux_channel_names) - + # Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches self.patchembed2d = PatchEmbed2D( img_size=self.inp_shape, @@ -791,7 +791,7 @@ def __init__(self, flatten=False, norm_layer=None, ) - + self.patchembed3d = PatchEmbed3D( img_size=(num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size=patch_size, @@ -870,7 +870,7 @@ def __init__(self, self.patchrecovery3d = PatchRecovery3D( (num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size, 2 * embed_dim, num_atmospheric ) - + def _precompute_channel_groups( self, channel_names=[], @@ -901,7 +901,7 @@ def _precompute_channel_groups( def prepare_input(self, input): """ - Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric, + Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric, and reshaping the atmospheric variables into the required format. """ @@ -932,13 +932,13 @@ def prepare_output(self, output_surface, output_atmospheric): level_dict = {level: [idx for idx, value in enumerate(self.channel_names) if value[1:] == level] for level in levels} reordered_ids = [idx for level in levels for idx in level_dict[level]] check_reorder = [f'{level}_{idx}' for level in levels for idx in level_dict[level]] - + # Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!) flattened_atmospheric = output_atmospheric.reshape(output_atmospheric.shape[0], -1, output_atmospheric.shape[3], output_atmospheric.shape[4]) reordered_atmospheric = torch.cat([torch.zeros_like(output_surface), torch.zeros_like(flattened_atmospheric)], dim=1) for i in range(len(reordered_ids)): reordered_atmospheric[:, reordered_ids[i], :, :] = flattened_atmospheric[:, i, :, :] - + # Append the surface output, this has not been reordered. if output_surface is not None: _, surf_chans, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) @@ -948,7 +948,7 @@ def prepare_output(self, output_surface, output_atmospheric): output = reordered_atmospheric return output - + def forward(self, input): # Prep the input by splitting into surface and atmospheric variables @@ -959,7 +959,7 @@ def forward(self, input): surface = checkpoint(self.patchembed2d, surface_aux, use_reentrant=False) atmospheric = checkpoint(self.patchembed3d, atmospheric, use_reentrant=False) else: - surface = self.patchembed2d(surface_aux) + surface = self.patchembed2d(surface_aux) atmospheric = self.patchembed3d(atmospheric) if surface.shape[1] == 0: @@ -1011,11 +1011,5 @@ def forward(self, input): output_atmospheric = self.patchrecovery3d(output_atmospheric) output = self.prepare_output(output_surface, output_atmospheric) - - return output - - - - - + return output diff --git a/makani/models/networks/pangu_onnx.py b/makani/models/networks/pangu_onnx.py index 0805badb..0c641832 100644 --- a/makani/models/networks/pangu_onnx.py +++ b/makani/models/networks/pangu_onnx.py @@ -38,7 +38,7 @@ class PanguOnnx(OnnxWrapper): channel_order_PL: List containing the names of the pressure levels with the ordering that the ONNX model expects onnx_file: Path to the ONNX file containing the model ''' - def __init__(self, + def __init__(self, channel_names=[], aux_channel_names=[], onnx_file=None, @@ -78,12 +78,12 @@ def prepare_input(self, input): B,V,Lat,Long=input.shape if B>1: - raise NotImplementedError("Not implemented yet for batch size greater than 1") + raise NotImplementedError("Not implemented yet for batch size greater than 1") input=input.squeeze(0) surface_aux_inp=input[self.surf_channels] atmospheric_inp=input[self.atmo_channels].reshape(self.n_atmo_groups,self.n_atmo_chans,Lat,Long).transpose(1,0) - + return surface_aux_inp, atmospheric_inp def prepare_output(self, output_surface, output_atmospheric): @@ -99,9 +99,9 @@ def prepare_output(self, output_surface, output_atmospheric): return output.unsqueeze(0) - + def forward(self, input): - + surface, atmospheric = self.prepare_input(input) @@ -109,5 +109,5 @@ def forward(self, input): output = self.prepare_output(output_surface, output) - + return output diff --git a/makani/models/noise.py b/makani/models/noise.py index 56e3232b..9ddad17f 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -27,7 +27,54 @@ from physicsnemo.distributed.utils import split_tensor_along_dim, compute_split_shapes -class BaseNoiseS2(nn.Module): +class BaseNoise(nn.Module): + def __init__(self, seed=333, **kwargs): + super().__init__() + self.set_rng(seed=seed) + + def set_rng(self, seed=333): + self.rng_cpu = torch.Generator(device=torch.device("cpu")) + self.rng_cpu.manual_seed(seed) + if torch.cuda.is_available(): + self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}")) + self.rng_gpu.manual_seed(seed) + + def reset(self, batch_size=None): + if hasattr(self, "state") and self.state is not None: + if batch_size is not None: + # This assumes self.state is defined in the derived class with correct shape logic + # For BaseNoiseS2 and others, specific reset logic might still be needed or this needs to be generic + # We'll leave the generic implementation to the derived classes or implement a helper if shape is known + pass + + with torch.no_grad(): + self.state.fill_(0.0) + + def set_rng_state(self, cpu_state, gpu_state): + if cpu_state is not None: + self.rng_cpu.set_state(cpu_state) + if torch.cuda.is_available() and (gpu_state is not None): + self.rng_gpu.set_state(gpu_state) + + def get_rng_state(self): + cpu_state = self.rng_cpu.get_state() + gpu_state = None + if torch.cuda.is_available(): + gpu_state = self.rng_gpu.get_state() + return cpu_state, gpu_state + + def get_tensor_state(self): + if hasattr(self, "state"): + return self.state.detach().clone() + return None + + def set_tensor_state(self, newstate): + if hasattr(self, "state"): + with torch.no_grad(): + self.state.copy_(newstate) + + +class BaseNoiseS2(BaseNoise): def __init__( self, img_shape, @@ -43,7 +90,7 @@ def __init__( Abstract base class for noise on the sphere. Initializes the inverse SHT needed by many of the noise classes. Derived noise classes can be stateful or stateless. """ - super().__init__() + super().__init__(seed=seed) # Number of latitudinal modes. self.nlat, self.nlon = img_shape @@ -72,22 +119,12 @@ def __init__( self.lmax = self.isht.lmax self.mmax = self.isht.mmax - # generator objects: - self.set_rng(seed=seed) - # store the noise state: initialize to None self.register_buffer("state", torch.zeros((batch_size, self.num_time_steps, self.num_channels, self.lmax_local, self.mmax_local, 2), dtype=torch.float32), persistent=False) def is_stateful(self): raise NotImplementedError("is_stateful method not implemented for this noise class") - def set_rng(self, seed=333): - self.rng_cpu = torch.Generator(device=torch.device("cpu")) - self.rng_cpu.manual_seed(seed) - if torch.cuda.is_available(): - self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}")) - self.rng_gpu.manual_seed(seed) - # Resets the internal state. Can be used to change the batch size if required. def reset(self, batch_size=None): if self.state is not None: @@ -100,7 +137,7 @@ def reset(self, batch_size=None): # this routine generates a noise sample for a single time step and updates the state accordingly, by appending the last time step def update(self, replace_state=False, batch_size=None): - # Update should always create a new state, so + # Update should always create a new state, so # we don't need to check for replace_state # create single occurence with torch.no_grad(): @@ -122,30 +159,6 @@ def update(self, replace_state=False, batch_size=None): return - def set_rng_state(self, cpu_state, gpu_state): - if cpu_state is not None: - self.rng_cpu.set_state(cpu_state) - if torch.cuda.is_available() and (gpu_state is not None): - self.rng_gpu.set_state(gpu_state) - - return - - def get_rng_state(self): - cpu_state = self.rng_cpu.get_state() - gpu_state = None - if torch.cuda.is_available(): - gpu_state = self.rng_gpu.get_state() - - return cpu_state, gpu_state - - def get_tensor_state(self): - return self.state.detach().clone() - - def set_tensor_state(self, newstate): - with torch.no_grad(): - self.state.copy_(newstate) - return - class IsotropicGaussianRandomFieldS2(BaseNoiseS2): def __init__( @@ -518,3 +531,74 @@ def forward(self, update_internal_state=False): self.update() return state + +class GaussianVectorNoise(BaseNoise): + def __init__( + self, + img_shape, + batch_size, + num_channels, + num_time_steps=1, + sigma=1.0, + seed=333, + **kwargs, + ): + r""" + Gaussian noise vector in R^d. + + Parameters + ============ + img_shape : (int, int) + Ignored, kept for compatibility. + batch_size: int + Batch size for the noise + num_channels: int + Number of channels (dimension of the vector) + num_time_steps: int + Number of time steps + sigma : float, default is 1.0 + Standard deviation + """ + super().__init__(seed=seed) + + self.num_channels = num_channels + self.num_time_steps = num_time_steps + self.sigma = sigma + + # State: (B, T, C, 1, 1) + self.register_buffer("state", torch.zeros((batch_size, self.num_time_steps, self.num_channels, 1, 1), dtype=torch.float32), persistent=False) + + def is_stateful(self): + return False + + def reset(self, batch_size=None): + if self.state is not None: + if batch_size is not None: + self.state = torch.zeros(batch_size, self.num_time_steps, self.num_channels, 1, 1, dtype=self.state.dtype, device=self.state.device) + with torch.no_grad(): + self.state.fill_(0.0) + + def update(self, replace_state=False, batch_size=None): + with torch.no_grad(): + if batch_size is None: + batch_size = self.state.shape[0] + + # Generate new noise + newstate = torch.empty((batch_size, self.num_time_steps, self.num_channels, 1, 1), dtype=self.state.dtype, device=self.state.device) + if self.state.is_cuda: + newstate.normal_(mean=0.0, std=self.sigma, generator=self.rng_gpu) + else: + newstate.normal_(mean=0.0, std=self.sigma, generator=self.rng_cpu) + + if newstate.shape == self.state.shape: + self.state.copy_(newstate) + else: + self.state = newstate + return + + def forward(self, update_internal_state=False): + + if update_internal_state: + self.update() + + return self.state.clone() diff --git a/makani/models/preprocessor.py b/makani/models/preprocessor.py index 59dbc2fb..fe649eb6 100644 --- a/makani/models/preprocessor.py +++ b/makani/models/preprocessor.py @@ -161,6 +161,19 @@ def __init__(self, params): num_channels=noise_channels, num_time_steps=self.n_history + 1, ) + elif noise_params["type"] == "stochastic": + from makani.models.noise import GaussianVectorNoise + + self.noise_base_seed = 333 + comm.get_rank("data") + + self.input_noise = GaussianVectorNoise( + img_shape=self.img_shape, + batch_size=params.batch_size, + num_channels=noise_channels, + num_time_steps=self.n_history + 1, + sigma=noise_params.get("sigma", 1.0), + seed=self.noise_base_seed, + ) else: raise NotImplementedError(f'Error, input noise type {noise_params["type"]} not supported.') @@ -258,6 +271,14 @@ def _append_channels(self, x, xc): # this routine also adds noise every time a channel gets appended if hasattr(self, "input_noise"): n = self.input_noise() + + # expand spatial dimensions if necessary (e.g. for vector noise) + # n is expected to be (B, T, C, H, W) or (B, T, C, 1, 1) + # For concatenation, we must expand explicitly if dimensions match but size doesn't + if (n.dim() == 5) and (n.shape[-2] == 1) and (n.shape[-1] == 1): + if (x.shape[-2] != 1) or (x.shape[-1] != 1): + n = n.expand(-1, -1, -1, x.shape[-2], x.shape[-1]) + if self.input_noise_mode == "concatenate": xc = torch.cat([xc, n], dim=2) elif self.input_noise_mode == "perturb":