# Source code for tango.integrations.beaker.workspace

import json
import logging
import os
import random
import tempfile
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path
from typing import (
    Dict,
    Generator,
    Iterable,
    Iterator,
    Optional,
    Type,
    TypeVar,
    Union,
    cast,
)
from urllib.parse import ParseResult

import petname
from beaker import (
    Dataset,
    DatasetConflict,
    DatasetNotFound,
    Digest,
    Experiment,
    ExperimentNotFound,
)

from tango.common.exceptions import StepStateError
from tango.common.logging import file_handler
from tango.common.util import (
    exception_to_string,
    make_safe_filename,
    tango_cache_dir,
    utc_now_datetime,
)
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, Workspace

from .common import (
    BeakerStepLock,
    Constants,
    dataset_url,
    get_client,
    run_dataset_name,
    step_dataset_name,
)
from .step_cache import BeakerStepCache

# Generic result type: ``step_finished`` takes a result of type ``T`` and returns a ``T``.
T = TypeVar("T")
# Constrained to the two kinds of objects the workspace's digest-keyed object cache
# stores and returns (see ``_get_object_from_cache`` / ``_add_object_to_cache``).
U = TypeVar("U", Run, StepInfo)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@Workspace.register("beaker")
class BeakerWorkspace(Workspace):
    """
    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on `Beaker`_.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "beaker".

    :param beaker_workspace: The name or ID of the Beaker workspace to use.
    :param max_workers: The ``max_workers`` value handed to every
        :class:`~concurrent.futures.ThreadPoolExecutor` this workspace creates to fetch
        step info and runs concurrently. ``None`` leaves the choice to the executor.
    :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env() <beaker.Beaker.from_env()>`.
    """

    # Maximum number of entries kept in the in-memory object cache before the
    # least recently used entries are evicted.
    MEM_CACHE_SIZE = 512

    def __init__(self, beaker_workspace: str, max_workers: Optional[int] = None, **kwargs):
        super().__init__()
        self.beaker = get_client(beaker_workspace=beaker_workspace, **kwargs)
        self.cache = BeakerStepCache(beaker=self.beaker)
        self.steps_dir = tango_cache_dir() / "beaker_workspace"
        # Locks held for steps that have been started but not yet finished/failed.
        self.locks: Dict[Step, BeakerStepLock] = {}
        self.max_workers = max_workers
        # Two-level cache of ``StepInfo`` / ``Run`` objects keyed by the digest of the
        # Beaker file they were loaded from: JSON files on disk plus an in-memory LRU.
        self._disk_cache_dir = tango_cache_dir() / "beaker_cache" / "_objects"
        self._mem_cache: "OrderedDict[Digest, Union[StepInfo, Run]]" = OrderedDict()

    @property
    def url(self) -> str:
        """The ``beaker://`` URL of this workspace, built from the workspace's full name."""
        return f"beaker://{self.beaker.workspace.get().full_name}"

    @classmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
        """
        Build a workspace from an already-parsed ``beaker://`` URL.

        :param parsed_url: The parsed URL. The workspace name is taken from the
            network location plus the path (when present).
        :raises ValueError: If the URL has no network location.
        """
        workspace: str
        if parsed_url.netloc and parsed_url.path:
            # e.g. "beaker://ai2/my-workspace"
            workspace = parsed_url.netloc + parsed_url.path
        elif parsed_url.netloc:
            # e.g. "beaker://my-workspace"
            workspace = parsed_url.netloc
        else:
            raise ValueError(f"Bad URL for Beaker workspace '{parsed_url}'")
        return cls(workspace)

    @property
    def step_cache(self) -> StepCache:
        """The :class:`BeakerStepCache` backing this workspace."""
        return self.cache

    @property
    def current_beaker_experiment(self) -> Optional[Experiment]:
        """
        When the workspace is being used within a Beaker experiment that was submitted
        by the Beaker executor, this will return the `Experiment` object.

        Returns ``None`` when the ``BEAKER_EXPERIMENT_NAME`` environment variable is not
        set or when the named experiment cannot be found.
        """
        experiment_name = os.environ.get("BEAKER_EXPERIMENT_NAME")
        if experiment_name is None:
            return None
        try:
            return self.beaker.experiment.get(experiment_name)
        except ExperimentNotFound:
            return None

    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:
        """
        Return the local directory for a step, creating it if necessary.

        :param step_or_unique_id: The step, or its unique ID.
        """
        unique_id = (
            step_or_unique_id
            if isinstance(step_or_unique_id, str)
            else step_or_unique_id.unique_id
        )
        path = self.steps_dir / unique_id
        path.mkdir(parents=True, exist_ok=True)
        return path

    def work_dir(self, step: Step) -> Path:
        """Return the ``work`` sub-directory of the step's directory, creating it if necessary."""
        path = self.step_dir(step) / "work"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def _evict_mem_cache_overflow(self) -> None:
        """Drop least-recently-used entries until the in-memory cache fits ``MEM_CACHE_SIZE``."""
        while len(self._mem_cache) > self.MEM_CACHE_SIZE:
            self._mem_cache.popitem(last=False)

    def _get_object_from_cache(self, digest: Digest, o_type: Type[U]) -> Optional[U]:
        """
        Look up a cached object by the digest of the Beaker file it came from.

        The in-memory cache is consulted first, then the on-disk cache. A hit on disk is
        promoted into the in-memory cache. A disk entry that fails to load is logged,
        deleted, and treated as a miss.

        :param digest: The digest identifying the cached object.
        :param o_type: The expected type of the object.
        :returns: The cached object, or ``None`` on a miss (including when the in-memory
            entry is not an instance of ``o_type``).
        """
        cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
        if digest in self._mem_cache:
            # Mark as most recently used.
            self._mem_cache.move_to_end(digest)
            cached = self._mem_cache[digest]
            return cached if isinstance(cached, o_type) else None
        elif cache_path.is_file():
            try:
                with cache_path.open("r+t") as f:
                    json_dict = json.load(f)
                    cached = o_type.from_json_dict(json_dict)
            except Exception as exc:
                logger.warning("Error while loading object from workspace cache: %s", str(exc))
                # The entry is unusable; remove it so that it gets rebuilt.
                try:
                    os.remove(cache_path)
                except FileNotFoundError:
                    pass
                return None
            # Add to in-memory cache.
            self._mem_cache[digest] = cached
            self._evict_mem_cache_overflow()
            return cached  # type: ignore
        else:
            return None

    def _add_object_to_cache(self, digest: Digest, o: U):
        """
        Store an object in both the in-memory and the on-disk cache.

        :param digest: The digest of the Beaker file the object was loaded from.
        :param o: The object to cache. Written to disk as JSON via ``to_json_dict()``.
        """
        self._disk_cache_dir.mkdir(parents=True, exist_ok=True)
        cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
        self._mem_cache[digest] = o
        with cache_path.open("w+t") as f:
            json.dump(o.to_json_dict(), f)
        self._evict_mem_cache_overflow()

    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Return the :class:`~tango.step_info.StepInfo` for a step.

        The info is read from the step's Beaker dataset, going through the local object
        cache keyed by the digest of the step info file. If the dataset or the file does
        not exist yet and a :class:`~tango.step.Step` was given, a fresh ``StepInfo`` is
        created from the step and uploaded.

        :param step_or_unique_id: The step, or its unique ID.
        :raises KeyError: If no step info exists and only a unique ID was given.
        """
        try:
            dataset = self.beaker.dataset.get(step_dataset_name(step_or_unique_id))
            file_info = self.beaker.dataset.file_info(dataset, Constants.STEP_INFO_FNAME)
            step_info: StepInfo
            cached = self._get_object_from_cache(file_info.digest, StepInfo)
            if cached is not None:
                step_info = cached
            else:
                step_info_bytes = self.beaker.dataset.get_file(dataset, file_info, quiet=True)
                step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
                self._add_object_to_cache(file_info.digest, step_info)
            return step_info
        except (DatasetNotFound, FileNotFoundError):
            # Without the actual step there is nothing to build new step info from.
            if not isinstance(step_or_unique_id, Step):
                raise KeyError(step_or_unique_id)
            step_info = StepInfo.new_from_step(step_or_unique_id)
            self._update_step_info(step_info)
            return step_info

    def step_starting(self, step: Step) -> None:
        """
        Acquire the step's lock and mark the step as running.

        Does nothing for steps that don't cache their results.

        :param step: The step that is about to run.
        :raises StepStateError: If the step is in a state from which it can't be
            started. The lock is released before raising.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return

        # Get local file lock + remote Beaker dataset lock.
        lock = BeakerStepLock(
            self.beaker, step, current_beaker_experiment=self.current_beaker_experiment
        )
        lock.acquire()
        self.locks[step] = lock

        step_info = self.step_info(step)
        if step_info.state == StepState.RUNNING:
            # Since we've acquired the step lock we know this step can't be running
            # elsewhere. But the step state can still say it's running if the last
            # run exited before this workspace had a chance to update the step info.
            warnings.warn(
                f"Step info for step '{step.unique_id}' is invalid - says step is running "
                "although it shouldn't be. Ignoring and overwriting step start time.",
                UserWarning,
            )
        elif step_info.state not in {StepState.INCOMPLETE, StepState.FAILED, StepState.UNCACHEABLE}:
            self.locks.pop(step).release()
            raise StepStateError(
                step,
                step_info.state,
                context=f"If you are certain the step is not running somewhere else, delete the step "
                f"datasets at {dataset_url(self.beaker.workspace.url(), step_dataset_name(step))}",
            )

        if step_info.state == StepState.FAILED:
            # Refresh the environment metadata since it might be out-of-date now.
            step_info.refresh()

        # Update StepInfo to mark as running.
        try:
            step_info.start_time = utc_now_datetime()
            step_info.end_time = None
            step_info.error = None
            step_info.result_location = None
            self._update_step_info(step_info)
        except BaseException:
            # Whatever went wrong (including interrupts), don't keep holding the lock.
            self.locks.pop(step).release()
            raise

    def step_finished(self, step: Step, result: T) -> T:
        """
        Record a successful step run, cache its result, and release the step's lock.

        Does nothing (other than returning ``result``) for steps that don't cache their
        results.

        :param step: The step that finished.
        :param result: The result of the step.
        :returns: The result. For iterator results this is a fresh value read back from
            the cache, since caching consumes the original iterator.
        :raises StepStateError: If the step's recorded state is not running.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return result

        step_info = self.step_info(step)
        if step_info.state != StepState.RUNNING:
            raise StepStateError(step, step_info.state)

        # Update step info and save step execution metadata.
        # This needs to be done *before* adding the result to the cache, since adding
        # the result to the cache will commit the step dataset, making it immutable.
        step_info.end_time = utc_now_datetime()
        step_info.result_location = self.beaker.dataset.url(step_dataset_name(step))
        self._update_step_info(step_info)

        self.cache[step] = result
        if hasattr(result, "__next__"):
            assert isinstance(result, Iterator)
            # Caching the iterator will consume it, so we write it to the cache and then read from the cache
            # for the return value.
            result = self.cache[step]

        self.locks.pop(step).release()

        return result

    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        Record a failed step run and release the step's lock.

        Does nothing for steps that don't cache their results. The lock is released
        even if updating the step info fails.

        :param step: The step that failed.
        :param e: The exception the step failed with.
        :raises StepStateError: If the step's recorded state is not running.
        """
        # We don't do anything with uncacheable steps.
        if not step.cache_results:
            return

        try:
            step_info = self.step_info(step)
            if step_info.state != StepState.RUNNING:
                raise StepStateError(step, step_info.state)
            step_info.end_time = utc_now_datetime()
            step_info.error = exception_to_string(e)
            self._update_step_info(step_info)
        finally:
            self.locks.pop(step).release()

    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """
        Register a new run covering the target steps and all of their dependencies.

        :param targets: The steps the run should produce. Any iterable is accepted,
            including one-shot iterators.
        :param name: The name of the run. When omitted, a random unused name is chosen.
        :returns: The newly registered run.
        :raises ValueError: If ``name`` is given but already in use.
        """
        import concurrent.futures

        # ``targets`` is traversed twice below, so materialize it first. Otherwise a
        # generator would be exhausted by the first pass and the dependencies of the
        # targets would silently be left out of the run.
        target_steps = list(targets)
        all_steps = set(target_steps)
        for step in target_steps:
            all_steps |= step.recursive_dependencies

        # Create a Beaker dataset that represents this run. The dataset will just contain
        # a JSON file that maps step names to step unique IDs.
        run_dataset: Dataset
        if name is None:
            # Find a unique name to use.
            while True:
                name = petname.generate() + str(random.randint(0, 100))
                try:
                    run_dataset = self.beaker.dataset.create(
                        run_dataset_name(cast(str, name)), commit=False
                    )
                except DatasetConflict:
                    continue
                else:
                    break
        else:
            try:
                run_dataset = self.beaker.dataset.create(run_dataset_name(name), commit=False)
            except DatasetConflict:
                raise ValueError(f"Run name '{name}' is already in use")

        steps: Dict[str, StepInfo] = {}
        run_data: Dict[str, str] = {}

        # Collect step info.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers, thread_name_prefix="BeakerWorkspace.register_run()-"
        ) as executor:
            step_info_futures = []
            for step in all_steps:
                # Unnamed steps are not recorded in the run.
                if step.name is None:
                    continue
                step_info_futures.append(executor.submit(self.step_info, step))
            for future in concurrent.futures.as_completed(step_info_futures):
                step_info = future.result()
                assert step_info.step_name is not None
                steps[step_info.step_name] = step_info
                run_data[step_info.step_name] = step_info.unique_id

        # Upload run data to dataset.
        # NOTE: We don't commit the dataset here since we'll need to upload the logs file
        # after the run.
        self.beaker.dataset.upload(
            run_dataset, json.dumps(run_data).encode(), Constants.RUN_DATA_FNAME, quiet=True
        )

        return Run(name=cast(str, name), steps=steps, start_date=run_dataset.created)

    def registered_runs(self) -> Dict[str, Run]:
        """
        Return all runs registered in this workspace, keyed by run name.

        Run datasets are discovered by name prefix and loaded concurrently. Datasets
        that don't yield a run are skipped.
        """
        import concurrent.futures

        runs: Dict[str, Run] = {}

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers,
            thread_name_prefix="BeakerWorkspace.registered_runs()-",
        ) as executor:
            run_futures = []
            for dataset in self.beaker.workspace.iter_datasets(
                match=Constants.RUN_DATASET_PREFIX, results=False
            ):
                run_futures.append(executor.submit(self._get_run_from_dataset, dataset))
            for future in concurrent.futures.as_completed(run_futures):
                run = future.result()
                if run is not None:
                    runs[run.name] = run

        return runs

    def registered_run(self, name: str) -> Run:
        """
        Return the run registered under the given name.

        :param name: The name of the run.
        :raises KeyError: If there is no such run in this workspace.
        """
        err_msg = f"Run '{name}' not found in workspace"

        try:
            dataset_for_run = self.beaker.dataset.get(run_dataset_name(name))
            # Make sure the run is in our workspace.
            if dataset_for_run.workspace_ref.id != self.beaker.workspace.get().id:
                raise DatasetNotFound
        except DatasetNotFound:
            raise KeyError(err_msg)

        run = self._get_run_from_dataset(dataset_for_run)
        if run is None:
            raise KeyError(err_msg)
        else:
            return run

    @contextmanager
    def capture_logs_for_run(self, name: str) -> Generator[None, None, None]:
        """
        Context manager that captures logs to a temporary file for the duration of the
        block, then syncs that file to the run's dataset and commits the dataset.

        The upload and commit happen even if the block raises.

        :param name: The name of the run.
        """
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            log_file = Path(tmp_dir_name) / "out.log"
            try:
                with file_handler(log_file):
                    yield None
            finally:
                run_dataset = run_dataset_name(name)
                self.beaker.dataset.sync(run_dataset, log_file, quiet=True)
                self.beaker.dataset.commit(run_dataset)

    def _get_run_from_dataset(self, dataset: Dataset) -> Optional[Run]:
        """
        Reconstruct a :class:`~tango.workspace.Run` from its Beaker dataset.

        :param dataset: The candidate run dataset.
        :returns: The run, or ``None`` if the dataset is unnamed, its name lacks the run
            prefix, or its run data file can't be found.
        """
        if dataset.name is None:
            return None
        if not dataset.name.startswith(Constants.RUN_DATASET_PREFIX):
            return None

        run_name = dataset.name[len(Constants.RUN_DATASET_PREFIX) :]

        try:
            file_info = self.beaker.dataset.file_info(dataset, Constants.RUN_DATA_FNAME)
            cached = self._get_object_from_cache(file_info.digest, Run)
            if cached is not None:
                return cached
            steps_info_bytes = self.beaker.dataset.get_file(dataset, file_info, quiet=True)
            steps_info = json.loads(steps_info_bytes)
        except (DatasetNotFound, FileNotFoundError):
            return None

        import concurrent.futures

        steps: Dict[str, StepInfo] = {}
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers,
            thread_name_prefix="BeakerWorkspace._get_run_from_dataset()-",
        ) as executor:
            # The run data maps step names to unique IDs; only the IDs are needed here
            # since each StepInfo carries its own step name.
            step_info_futures = [
                executor.submit(self.step_info, unique_id) for unique_id in steps_info.values()
            ]
            for future in concurrent.futures.as_completed(step_info_futures):
                step_info = future.result()
                assert step_info.step_name is not None
                steps[step_info.step_name] = step_info

        run = Run(name=run_name, start_date=dataset.created, steps=steps)
        self._add_object_to_cache(file_info.digest, run)
        return run

    def _update_step_info(self, step_info: StepInfo):
        """
        Upload the given step info as JSON to the step's Beaker dataset.

        The dataset is created (uncommitted) if it doesn't exist yet; otherwise the
        existing dataset is reused.

        :param step_info: The step info to store.
        """
        dataset_name = step_dataset_name(step_info)

        step_info_dataset: Dataset
        try:
            step_info_dataset = self.beaker.dataset.create(dataset_name, commit=False)
        except DatasetConflict:
            step_info_dataset = self.beaker.dataset.get(dataset_name)

        self.beaker.dataset.upload(
            step_info_dataset,
            json.dumps(step_info.to_json_dict()).encode(),
            Constants.STEP_INFO_FNAME,
            quiet=True,
        )