DepthCoordinate.depth_integral produces NaN gradients for all variables involved in the depth integral (e.g. thetao_0, thetao_1) during training with the OHC conservation corrector.
The root cause is that the integrand is multiplied by dz, which contains NaNs at land-masked points:
integral = (integrand * self.dz * self.mask).nansum(dim=-1)
While nansum correctly ignores these NaNs in the forward pass, the backward pass computes d(nansum)/d(product) * d(product)/d(integrand) = 0 * NaN = NaN at masked points, because d(product)/d(integrand) = dz * mask and dz is NaN there (NaN * 0 is still NaN). These NaN gradients propagate through the OHC correction ratio back to the network output, eventually corrupting all model weights.
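The behaviour can be reproduced outside the model with a few tensors. The snippet below is a minimal sketch, not the actual DepthCoordinate code: the tensor values are made up, the names integrand, dz and mask only mirror the expression above, and replacing the NaNs in dz with torch.where before the multiplication is one possible remedy, assumed here for illustration.

```python
import torch

# Hypothetical 3-level column: the last level is land-masked, so dz is NaN there.
dz = torch.tensor([10.0, 20.0, float("nan")])
mask = torch.tensor([1.0, 1.0, 0.0])

# Current formulation: forward value is fine, gradient is not.
integrand = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
integral = (integrand * dz * mask).nansum(dim=-1)
integral.backward()
buggy_grad = integrand.grad.clone()  # NaN at the masked level

# Possible remedy (an assumption, not necessarily the fix adopted upstream):
# zero out dz at masked points so no NaN ever enters the autograd graph.
safe_dz = torch.where(mask.bool(), dz, torch.zeros_like(dz))
integrand_fixed = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
integral_fixed = (integrand_fixed * safe_dz * mask).sum(dim=-1)
integral_fixed.backward()
fixed_grad = integrand_fixed.grad.clone()  # finite everywhere

print("forward:", integral.item(), integral_fixed.item())
print("buggy grad:", buggy_grad)
print("fixed grad:", fixed_grad)
```

Both variants give the same forward value, so the change only affects the backward pass. Because safe_dz no longer contains NaNs, a plain sum can replace nansum in this sketch.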