Source code for monkey_wrench.chimp._models

import tempfile
from pathlib import Path
from typing import Callable, Literal, Self
from uuid import uuid4

from loguru import logger
from pydantic import FilePath, NonNegativeInt, PositiveInt, model_validator

from monkey_wrench.date_time import ChimpFilePathParser
from monkey_wrench.generic import Pattern
from monkey_wrench.input_output import (
    DateTimeDirectory,
    ModelFile,
    copy_files_between_directories,
    output_filename_from_datetime,
)
from monkey_wrench.input_output.seviri import seviri_extension_context
from monkey_wrench.query import Collection, List


class ChimpRetrieval(Collection, DateTimeDirectory, ModelFile):
    """Pydantic model for CHIMP retrievals."""

    # These four options are forwarded as keyword arguments to the retrieval function.
    # ``sequence_length`` is additionally used as the requested batch size and as the expected number of input files.
    device: Literal["cpu", "cuda"] = "cpu"
    sequence_length: NonNegativeInt = 16
    temporal_overlap: NonNegativeInt = 0
    tile_size: PositiveInt = 256

    # NOTE(review): ``verbose`` is not read by any method of this class -- confirm whether it is consumed elsewhere.
    verbose: bool = True

    strict: bool = False
    """Determines whether an exception must be raised if there are missing timestamps. Defaults to ``False``.

    By default, CHIMP is expected to handle missing timestamps if they are not in the beginning or the end of the
    sequence. In such cases, we log a warning. However, if this is set to ``True`` we raise an exception instead.

    Warning:
        It is not 100% guaranteed that a retrieval can always be performed in the absence of some timestamps. There
        might be edge cases that CHIMP cannot handle.
    """

    @model_validator(mode="after")
    def validate_collection(self) -> Self:  # noqa: N804
        """Reject any collection other than ``seviri``.

        Returns:
            The model instance itself, once the collection has been accepted.

        Raises:
            ValueError: If the name of the collection is not ``"seviri"``.
        """
        if self.collection.name == "seviri":
            return self
        raise ValueError(f"Chimp retrieval is not implemented for `{self.collection.name}`.")

    def run_in_batches(self, lst: List) -> None:
        """Run one CHIMP retrieval for every batch generated from the given list.

        The retrieval function is obtained once from ``seviri_extension_context()`` and reused for all batches.

        Args:
            lst:
                The list from which the batches are generated. Batches are requested with a size of
                ``sequence_length`` and with ``strict=False``.
        """
        with seviri_extension_context() as chimp_cli:
            for batch in lst.generate_k_sized_batches_by_index(self.sequence_length, strict=False):
                self.__run_for_single_batch(batch, chimp_cli)

    def __input_filepaths_as_strings(self, batch: list[FilePath]) -> list[str]:
        """Return the items of the batch as strings and check the batch size against ``sequence_length``.

        Args:
            batch: The file paths to convert.

        Returns:
            The string representation of each item, in the same order.

        Raises:
            ValueError: If ``strict`` is set and the batch size differs from ``sequence_length``.
        """
        as_strings = list(map(str, batch))
        if len(as_strings) == self.sequence_length:
            return as_strings

        # Size mismatch: fatal in strict mode, otherwise only reported and the retrieval proceeds.
        msg = (
            f"Expected to receive {self.sequence_length} input files but got {len(as_strings)} instead!"
            f" Batch: {batch}"
        )
        if self.strict:
            raise ValueError(msg)
        logger.warning(msg)
        return as_strings

    def __run_for_single_batch(self, batch: list[FilePath], retrieve_function: Callable) -> None:
        """Perform the retrieval for one batch in a temporary directory and copy the result to its final location.

        Args:
            batch:
                The input file paths. The last item is parsed to determine both the destination directory and the
                filename that is looked for among the retrieval outputs.
            retrieve_function:
                The callable that performs the actual retrieval.
        """
        retrieval_options = {
            "device": self.device,
            "sequence_length": self.sequence_length,
            "temporal_overlap": self.temporal_overlap,
            "tile_size": self.tile_size,
        }

        # The random part of the prefix keeps the scratch directories of different batches distinguishable.
        with tempfile.TemporaryDirectory(prefix=f"chimp_{uuid4()}_") as scratch_dir:
            input_paths = self.__input_filepaths_as_strings(batch)
            retrieve_function(self.model_filepath, "seviri", input_paths, scratch_dir, **retrieval_options)

            reference = ChimpFilePathParser.parse(batch[-1])
            destination = self.create_datetime_directory(reference)
            expected_name = str(output_filename_from_datetime(reference))
            copied = copy_files_between_directories(
                Path(scratch_dir),
                destination,
                Pattern(sub_strings=expected_name),
            )

            # Exactly one copied file is the only successful outcome.
            n_copied = len(copied)
            if n_copied == 0:
                logger.error(f"Could not perform a retrieval for {batch}")
            elif n_copied == 1:
                logger.success("Successfully performed a retrieval.")
            else:
                logger.error(f"Expected a single file for the retrieval but copied {n_copied} files. Batch: {batch}")