"""Source code for ``tango.integrations.beaker.step_cache``: a step cache that stores step results on Beaker as datasets."""

import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Optional, Union

from beaker import Beaker, Dataset, DatasetConflict, DatasetNotFound, DatasetWriteError

from tango.common.exceptions import ConfigurationError
from tango.common.file_lock import FileLock
from tango.common.params import Params
from tango.common.util import make_safe_filename, tango_cache_dir
from tango.step import Step
from tango.step_cache import CacheMetadata, StepCache
from tango.step_caches.local_step_cache import LocalStepCache
from tango.step_info import StepInfo

from .common import Constants, get_client, step_dataset_name

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


@StepCache.register("beaker")
class BeakerStepCache(LocalStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`BeakerWorkspace`.
    It stores the results of steps on Beaker as datasets.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching
    a step's result subsequent times should be fast.

    .. tip::
        Registered as a :class:`~tango.step_cache.StepCache` under the name "beaker".

    :param beaker_workspace: The name or ID of the Beaker workspace to use.
    :param beaker: The Beaker client to use. If not given, one is created with ``get_client()``.

    :raises ConfigurationError: If no default Beaker workspace ends up being set.
    """

    def __init__(self, beaker_workspace: Optional[str] = None, beaker: Optional[Beaker] = None):
        self.beaker: Beaker
        if beaker is not None:
            self.beaker = beaker
            if beaker_workspace is not None:
                # Point the caller-supplied client at the requested workspace.
                # NOTE(review): `workspace.ensure` presumably creates the workspace if it's
                # missing -- confirm against the beaker-py docs.
                self.beaker.config.default_workspace = beaker_workspace
                self.beaker.workspace.ensure(beaker_workspace)
        else:
            self.beaker = get_client(beaker_workspace=beaker_workspace)
        if self.beaker.config.default_workspace is None:
            raise ConfigurationError("Beaker default workspace must be set")
        # The local on-disk backup lives in a per-workspace directory under the Tango cache dir.
        super().__init__(
            tango_cache_dir()
            / "beaker_cache"
            / make_safe_filename(self.beaker.config.default_workspace)
        )

    def _acquire_step_lock_file(self, step: Union[Step, StepInfo], read_only_ok: bool = False):
        """
        Acquire a file lock guarding the local cache directory of ``step``.

        The lock file is the step's directory path with a ``.lock`` suffix. The returned
        object is used as a context manager by the callers in this class.

        :param step: The step (or step info) whose cache entry should be locked.
        :param read_only_ok: Forwarded to :class:`FileLock`.
        """
        return FileLock(
            self.step_dir(step).with_suffix(".lock"), read_only_ok=read_only_ok
        ).acquire_with_updates(desc=f"acquiring step cache lock for '{step.unique_id}'")

    def _step_result_dataset(self, step: Union[Step, StepInfo]) -> Optional[Dataset]:
        """
        Return the committed Beaker dataset holding the result of ``step``.

        :returns: The dataset, or ``None`` if it doesn't exist or hasn't been committed.
        """
        try:
            dataset = self.beaker.dataset.get(step_dataset_name(step))
        except DatasetNotFound:
            return None
        # Only committed datasets count as a cached result.
        return dataset if dataset.committed is not None else None

    def _sync_step_dataset(self, step: Step, objects_dir: Path) -> Dataset:
        """
        Upload the contents of ``objects_dir`` to the Beaker dataset for ``step`` and commit it.

        The dataset is created uncommitted if it doesn't exist yet. If creation conflicts
        with an existing dataset, that dataset is reused. A :class:`DatasetWriteError`
        during sync/commit is tolerated (best-effort) but logged.

        :returns: The dataset, which is committed if sync and commit succeeded.
        """
        dataset_name = step_dataset_name(step)
        try:
            dataset = self.beaker.dataset.create(dataset_name, commit=False)
        except DatasetConflict:
            # Already created (e.g. by a concurrent writer or an earlier attempt); reuse it.
            dataset = self.beaker.dataset.get(dataset_name)
        try:
            self.beaker.dataset.sync(dataset, objects_dir, quiet=True)
            dataset = self.beaker.dataset.commit(dataset)
        except DatasetWriteError:
            # Deliberately best-effort, but don't hide it: the result may not be on Beaker.
            logger.warning(
                "Failed to write step result dataset '%s' to Beaker "
                "(it may already have been committed by another process).",
                dataset_name,
            )
        return dataset

    def __contains__(self, step: Any) -> bool:
        """
        Return ``True`` if ``step`` is a cacheable :class:`Step`/:class:`StepInfo` whose
        committed result dataset exists on Beaker. Any other object returns ``False``.
        """
        if not isinstance(step, (Step, StepInfo)):
            return False
        cacheable = step.cache_results if isinstance(step, Step) else step.cacheable
        if not cacheable:
            return False
        return self._step_result_dataset(step) is not None

    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """
        Return the cached result for ``step``.

        A committed Beaker dataset must exist for the step. After that check, the result is
        looked up in the in-memory cache, then the local on-disk copy, and finally downloaded
        from Beaker into the local cache.

        :raises KeyError: If there is no committed result dataset on Beaker for ``step``.
        """
        key = step.unique_id
        dataset = self._step_result_dataset(step)
        if dataset is None:
            raise KeyError(step)

        # Try getting the result from our in-memory caches first.
        result = self._get_from_cache(key)
        if result is not None:
            return result

        def load_and_return():
            # Read the result from the local step directory using the format recorded in
            # its metadata file, and remember it in the in-memory cache.
            metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))
            loaded = metadata.format.read(self.step_dir(step) / Constants.STEP_RESULT_DIR)
            self._add_to_cache(key, loaded)
            return loaded

        # Next check our local on-disk cache.
        with self._acquire_step_lock_file(step, read_only_ok=True):
            if self.step_dir(step).is_dir():
                return load_and_return()

        # Finally, check Beaker for the corresponding dataset.
        with self._acquire_step_lock_file(step):
            # Make sure the step wasn't cached since the last time we checked (above).
            if self.step_dir(step).is_dir():
                return load_and_return()

            # Download to a temporary directory first, in case something goes wrong.
            temp_dir = tempfile.mkdtemp(dir=self.dir, prefix=key)
            try:
                self.beaker.dataset.fetch(dataset, target=temp_dir, quiet=True)
                # Download succeeded; move the temp directory to the final step directory.
                os.replace(temp_dir, self.step_dir(step))
            except DatasetNotFound as err:
                raise KeyError(step) from err
            finally:
                # No-op after a successful os.replace; cleans up a partial download otherwise.
                shutil.rmtree(temp_dir, ignore_errors=True)

            return load_and_return()

    def __setitem__(self, step: Step, value: Any) -> None:
        """
        Cache ``value`` as the result of ``step``.

        The result and its metadata are written to a temporary directory and uploaded to
        Beaker. The directory then replaces the step's local cache directory, and the value
        is added to the in-memory cache. Uncacheable steps are skipped with a warning.
        """
        if not step.cache_results:
            logger.warning("Tried to cache step %s despite being marked as uncacheable.", step.name)
            return

        with self._acquire_step_lock_file(step):
            # Write the step's results to a temporary directory first, and upload to Beaker
            # from there in case anything goes wrong.
            temp_dir = Path(tempfile.mkdtemp(dir=self.dir, prefix=step.unique_id))
            try:
                # Inside the try so the temp dir is cleaned up even if this mkdir fails.
                result_dir = temp_dir / Constants.STEP_RESULT_DIR
                result_dir.mkdir()
                step.format.write(value, result_dir)
                metadata = CacheMetadata(step=step.unique_id, format=step.format)
                metadata.to_params().to_file(temp_dir / self.METADATA_FILE_NAME)

                # Create the dataset and upload the serialized result to it.
                self._sync_step_dataset(step, temp_dir)

                # Upload done; replace any stale local copy with the fresh one.
                if self.step_dir(step).is_dir():
                    shutil.rmtree(self.step_dir(step), ignore_errors=True)
                os.replace(temp_dir, self.step_dir(step))
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

        # Finally, add to in-memory caches.
        self._add_to_cache(step.unique_id, value)

    def __len__(self) -> int:
        """Return the number of committed step-result datasets in the workspace."""
        # NOTE: lock datasets should not count here. They start with the same prefix,
        # but they never get committed.
        return sum(
            1
            for ds in self.beaker.workspace.iter_datasets(
                uncommitted=False, match=Constants.STEP_DATASET_PREFIX
            )
            if ds.name is not None and ds.name.startswith(Constants.STEP_DATASET_PREFIX)
        )