import gc

import torch
from fire import Fire
from omegaconf import OmegaConf

from ldcast.models.autoenc import autoenc, encoder
from ldcast.models.genforecast import analysis, training, unet

from train_nowcaster import setup_data


def setup_model(
    num_timesteps=5,
    model_dir="../models/test/",
    autoenc_weights_fn="../models/autoenc/autoenc-32-0.01.pt",
    use_obs=True,
    use_nwp=False,
    nwp_input_patches=4,
    num_nwp_vars=9,
    lr=1e-4
):
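    """Build the latent diffusion forecast model and its trainer."""
    # Pretrained observation autoencoder: its weights are loaded here
    # and kept frozen during forecaster training (train_autoenc=False).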
    enc = encoder.SimpleConvEncoder()
    dec = encoder.SimpleConvDecoder()
    autoencoder_obs = autoenc.AutoencoderKL(enc, dec)
    autoencoder_obs.load_state_dict(torch.load(autoenc_weights_fn))
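
    # One entry per conditioning input (obs and/or NWP); these lists
    # parameterize the analysis cascade built below.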
    autoencoders = []
    input_patches = []
    input_size_ratios = []
    embed_dim = []
    analysis_depth = []
    if use_obs:
        autoencoders.append(autoencoder_obs)
        input_patches.append(1)
        input_size_ratios.append(1)
        embed_dim.append(128)
        analysis_depth.append(4)
    if use_nwp:
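        # The NWP input gets no learned compression: DummyAutoencoder
        # presumably just exposes the raw variables through the same
        # autoencoder interface as the observation branch.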
        autoencoder_nwp = autoenc.DummyAutoencoder(width=num_nwp_vars)
        autoencoders.append(autoencoder_nwp)
        input_patches.append(nwp_input_patches)
        input_size_ratios.append(2)
        embed_dim.append(32)
        analysis_depth.append(2)
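
    # Conditioning stack: an AFNO-based analysis network arranged as a
    # multi-resolution cascade over the selected inputs.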
    analysis_net = analysis.AFNONowcastNetCascade(
        autoencoders,
        input_patches=input_patches,
        input_size_ratios=input_size_ratios,
        train_autoenc=False,
        output_patches=num_timesteps,
        cascade_depth=3,
        embed_dim=embed_dim,
        analysis_depth=analysis_depth
    )
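
    # Denoiser: a 3D U-Net operating in the latent space of the
    # observation autoencoder, conditioned on the cascade features.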
    model = unet.UNetModel(
        in_channels=autoencoder_obs.hidden_width,
        model_channels=256,
        out_channels=autoencoder_obs.hidden_width,
        num_res_blocks=2,
        attention_resolutions=(1, 2),
        dims=3,
        channel_mult=(1, 2, 4),
        num_heads=8,
        num_timesteps=num_timesteps,
        context_ch=analysis_net.cascade_dims
    )
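
    # Assemble the diffusion model and its trainer.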
    (ldm, trainer) = training.setup_genforecast_training(
        model, autoencoder_obs, context_encoder=analysis_net,
        model_dir=model_dir, lr=lr
    )
    gc.collect()
    return (ldm, trainer)


def train(
    future_timesteps=8,
    use_obs=True,
    use_nwp=False,
    sample_shape=(4, 4),
    batch_size=8,
    sampler=None,
    ckpt_path=None,
    initial_weights=None,
    strict_weights=True,
    model_dir=None,
    lr=1e-4
):
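    """Set up the data and model, then run the training loop."""
    # Optional presampled datasets, one pickle file per split.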
    if sampler is None:
        sampler_file = None
    else:
        sampler_file = {
            s: f"{sampler}_{s}.pkl" for s in ["test", "train", "valid"]
        }
| print("Loading data...") | |
| datamodule = setup_data( | |
| future_timesteps=future_timesteps, | |
| use_obs=use_obs, | |
| use_nwp=use_nwp, | |
| sampler_file=sampler_file, | |
| batch_size=batch_size, | |
| sample_shape=sample_shape | |
| ) | |
| print("Setting up model...") | |
    (model, trainer) = setup_model(
        num_timesteps=future_timesteps // 4,
        use_obs=use_obs,
        use_nwp=use_nwp,
        model_dir=model_dir,
        lr=lr
    )

    if initial_weights is not None:
        print(f"Loading weights from {initial_weights}...")
        model.load_state_dict(
            torch.load(initial_weights, map_location=model.device),
            strict=strict_weights
        )

    print("Starting training...")
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
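

# CLI entry point: keyword arguments given on the command line
# override values loaded from the optional YAML config file.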
def main(config=None, **kwargs):
    config = OmegaConf.load(config) if (config is not None) else {}
    config.update(kwargs)
    train(**config)


if __name__ == "__main__":
    Fire(main)
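
# Example invocation (hypothetical file name and paths; adjust to
# your setup, assuming this script is saved as train_genforecast.py):
#   python train_genforecast.py --config genforecast.yaml --batch_size 4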