Source code for monkey_wrench.chimp._patch

from pathlib import Path

import fsspec
from loguru import logger


[docs] def remote_file(path): tmp = fsspec.open(path).open() def is_dir(*_args): return False def match(*_args): return True tmp.is_dir = is_dir tmp.match = match tmp.stem = path.split("/")[-1].split(".")[0] return tmp
[docs] def cli( model: Path, input_datasets: list[str], input_paths: list[str], output_path: str, device: str = "cuda", precision: str = "single", tile_size: int = 128, sequence_length: int = 0, temporal_overlap: int = 8, forecast: int = 0 ) -> int: from pathlib import Path import torch from chimp.data.input import ( InputLoader, SequenceInputLoader, ) from chimp.processing import retrieval_step from pansat.time import to_datetime from pytorch_retrieve.architectures import load_model if forecast > 0: raise NotImplementedError("Forecast mode is not yet supported in Monkey Wrench!") input_path = [remote_file(i) for i in input_paths] input_datasets = input_datasets.split(",") if sequence_length < 2: input_data = InputLoader(input_path, input_datasets) temporal_overlap = 0 else: if temporal_overlap is None: temporal_overlap = sequence_length - 1 input_data = SequenceInputLoader( input_path, input_datasets, sequence_length=sequence_length, forecast=forecast, temporal_overlap=temporal_overlap ) model = load_model(model) output_path = Path(output_path) if not output_path.exists(): output_path.mkdir(parents=True) float_type = torch.float32 if precision == "single" else torch.float16 for input_step, (time, model_input) in enumerate(input_data): logger.info(f"Starting processing input @ {time}") results = retrieval_step( model, model_input, tile_size=tile_size, device=device, float_type=float_type ) logger.info(f"Finished processing input @ {time}") n_retrieved = len(results) - forecast drop_left = temporal_overlap // 2 for step in range(n_retrieved): curr_time = time - (n_retrieved - step - 1) * input_data.time_step results_s = results[step] results_s["time"] = curr_time.astype("datetime64[ns]") date = to_datetime(curr_time) date_str = date.strftime("%Y%m%d_%H_%M") output_file = output_path / f"chimp_{date_str}.nc" if input_step > 0 and output_file.exists() and step < drop_left: continue results_s.attrs["sequence_step"] = step + 1 logger.info(f"Writing retrieval results to {output_file}") results_s.to_netcdf(output_path / f"chimp_{date_str}.nc") return 0