Remove static_input from data pipeline and model call signatures#956
| for i, (batch, static_inputs) in enumerate(self.batch_generator): | ||
| aggregator: NoTargetAggregator | None = None | ||
| for i, batch in enumerate(self.batch_generator): | ||
| if aggregator is None: |
Lazy initialization, as is done in inference.py, so that the aggregator is only created once the output domain information is available.
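A minimal sketch of the lazy-initialization pattern described in this comment, assuming the aggregator needs output-domain information that is only known once the first batch arrives. `ToyBatch` and `ToyAggregator` are hypothetical stand-ins, not the repository's `BatchData` or `NoTargetAggregator`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyBatch:
    """Hypothetical stand-in for BatchData: a spatial shape plus some values."""

    shape: tuple[int, int]
    values: list[float]


@dataclass
class ToyAggregator:
    """Hypothetical stand-in for NoTargetAggregator; needs the output shape up front."""

    output_shape: tuple[int, int]
    n_batches: int = 0
    total: float = 0.0

    def record(self, batch: ToyBatch) -> None:
        self.n_batches += 1
        self.total += sum(batch.values)


def run(batches) -> ToyAggregator | None:
    aggregator: ToyAggregator | None = None
    for batch in batches:
        if aggregator is None:
            # The domain is only known after seeing the first batch,
            # so construction is deferred until this point.
            aggregator = ToyAggregator(output_shape=batch.shape)
        aggregator.record(batch)
    # Stays None when the generator yields nothing.
    return aggregator
```

One consequence of this pattern is that callers have to handle the `None` case for an empty generator.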
AnnaKwa left a comment
LGTM, just a couple small suggestions
| ) | ||
| static_inputs_patches = static_inputs.generate_from_patches(fine_patches) | ||
| else: | ||
| static_inputs_patches = null_generator(len(coarse_patches)) |
Can remove null_generator from utils now
| ) | ||
| return self.static_inputs.subset_latlon(lat_interval, lon_interval) | ||
| def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: |
Suggestion: take the ClosedInterval and LatLonCoordinates as args rather than the BatchData to remove the dependence on how BatchData stores this info (e.g. won't have to change this code if switching BatchData to store the coords as LatLonCoordinates instead of BatchedLatLonCoordinates)
I'm going to keep the single-argument method rather than expand it into three arguments, since I think the consolidated passing works in our favor. Expanding would force every call site to repeat the same access pattern (and we'd then have to update those multiple locations anyway when we switch to LatLonCoordinates at the batch level).
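To make the trade-off in this thread concrete, here is a rough sketch of the single-argument form, where the method is the only place that knows how a batch exposes its domain. All `Toy*` classes are hypothetical simplifications; in particular, how the real `BatchData` stores its coordinates and how the intervals are derived are assumptions here.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyInterval:
    """Hypothetical stand-in for ClosedInterval."""

    start: float
    stop: float

    def __contains__(self, value: float) -> bool:
        return self.start <= value <= self.stop


@dataclass(frozen=True)
class ToyCoords:
    """Hypothetical stand-in for LatLonCoordinates (1D lat/lon axes)."""

    lat: tuple[float, ...]
    lon: tuple[float, ...]

    def subset(self, lat_interval: ToyInterval, lon_interval: ToyInterval) -> ToyCoords:
        return ToyCoords(
            lat=tuple(v for v in self.lat if v in lat_interval),
            lon=tuple(v for v in self.lon if v in lon_interval),
        )


@dataclass(frozen=True)
class ToyBatch:
    """Hypothetical stand-in for BatchData; the storage layout is an assumption."""

    coarse_lat: tuple[float, ...]
    coarse_lon: tuple[float, ...]


class ToyModel:
    def __init__(self, fine_coords: ToyCoords) -> None:
        self.fine_coords = fine_coords

    def get_fine_coords_for_batch(self, batch: ToyBatch) -> ToyCoords:
        # The batch access pattern lives only here; call sites just pass the batch.
        lat_interval = ToyInterval(min(batch.coarse_lat), max(batch.coarse_lat))
        lon_interval = ToyInterval(min(batch.coarse_lon), max(batch.coarse_lon))
        return self.fine_coords.subset(lat_interval, lon_interval)


model = ToyModel(ToyCoords(lat=(0.0, 0.5, 1.0, 1.5, 2.0), lon=(10.0, 10.5, 11.0, 11.5)))
batch = ToyBatch(coarse_lat=(0.0, 1.0), coarse_lon=(10.0, 11.0))
fine = model.get_fine_coords_for_batch(batch)
```

With the three-argument alternative, the two `ToyInterval(...)` lines would move to each caller, which decouples the method from the batch type at the cost of repeating that extraction wherever the method is used.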
This PR finalizes the removal of `StaticInput` handling from the data pipeline. Passing of static_input objects is removed from the data configuration, batch iteration, and model call signatures in favor of the direct model handling introduced in the previous downscaling PR (#954).

Changes:
- add `get_fine_coords_for_batch` to facilitate translation of an input batch domain to output coordinates via the model's stored information. For now, this relies on the model's `static_inputs`, but will be switched to the model's stored coordinates in "Add fine coordinates to the model for easier inference handling" (#971)
- inference `Downscaler` now takes the batch `input_shape` instead of `static_inputs` to check the domain size and model type (regular `DiffusionModel` or `PatchPredictor`)
- downscaling `torch.datasets` generators for `BatchData` no longer include `StaticInputs`
- removed `_apply_patch` and `_generate_from_patches` from `StaticInputs`
- `config.py` no longer references static inputs as an argument
- Tests added
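As an illustration of the `input_shape` check mentioned in the description, the sketch below shows one plausible reading: a regular model requires the input domain to match its coarse shape exactly, while a patch-based model accepts any domain at least as large as one patch. The `Toy*` classes, the `coarse_shape` attribute, and the exact rules are assumptions for illustration and are not taken from the PR.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDiffusionModel:
    """Hypothetical regular model: trained on one fixed coarse domain."""

    coarse_shape: tuple[int, int]


@dataclass(frozen=True)
class ToyPatchPredictor:
    """Hypothetical patch-based model: coarse_shape is the size of one patch."""

    coarse_shape: tuple[int, int]


def check_domain(model: ToyDiffusionModel | ToyPatchPredictor,
                 input_shape: tuple[int, int]) -> None:
    """Raise ValueError if the batch domain is incompatible with the model type."""
    if isinstance(model, ToyPatchPredictor):
        # Assumed rule: patching needs at least one full patch in each dimension.
        if any(n < p for n, p in zip(input_shape, model.coarse_shape)):
            raise ValueError(
                f"input {input_shape} is smaller than patch {model.coarse_shape}"
            )
        return
    # Assumed rule: a regular model only accepts its own coarse domain size.
    if tuple(input_shape) != model.coarse_shape:
        raise ValueError(
            f"input {input_shape} does not match model domain {model.coarse_shape}"
        )
```

The point of the sketch is only that the batch's shape carries enough information for this kind of validation, so no static-input object has to be passed in.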