Source code for dcbench.common.task

import functools
import os
from dataclasses import dataclass, field
from typing import List
from urllib.request import urlretrieve
import warnings

import yaml
from meerkat.tools.lazy_loader import LazyLoader
from tqdm import tqdm

from dcbench.common.problem import ProblemTable
from dcbench.common.table import RowMixin, Table
from dcbench.config import config

from .artifact_container import ArtifactContainer

# defer importing google.cloud.storage until it is first used
storage = LazyLoader("google.cloud.storage")


@dataclass
class Task(RowMixin):
    task_id: str
    name: str
    summary: str
    problem_class: type
    solution_class: type
    # `default_factory` avoids sharing a single mutable Table across instances
    baselines: Table = field(default_factory=lambda: Table([]))

    def __post_init__(self):
        super().__init__(
            id=self.task_id, attributes={"name": self.name, "summary": self.summary}
        )

    @property
    def problems_path(self):
        return os.path.join(self.task_id, "problems.yaml")

    @property
    def local_problems_path(self):
        return os.path.join(config.local_dir, self.problems_path)

    @property
    def remote_problems_url(self):
        return os.path.join(config.public_remote_url, self.problems_path)
    def write_problems(self, containers: List[ArtifactContainer], append: bool = True):
        ids = []
        for container in containers:
            assert isinstance(container, self.problem_class)
            ids.append(container.id)

        if len(set(ids)) != len(ids):
            raise ValueError(
                "Duplicate container ids in the containers passed to `write_problems`."
            )

        if append:
            # copy so the caller's list is not mutated when appending existing problems
            containers = list(containers)
            for id, problem in self.problems.items():
                if id not in ids:
                    containers.append(problem)

        os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True)
        with open(self.local_problems_path, "w") as f:
            yaml.dump(containers, f)
        self._load_problems.cache_clear()
    def upload_problems(self, include_artifacts: bool = False, force: bool = True):
        """Upload the problems to remote storage.

        Args:
            include_artifacts (bool): If True, also upload the artifacts of the
                problems. Defaults to False.
            force (bool): If True, overwrite the remote problems. Defaults to
                True.

        .. warning::
            Setting ``force=False`` is somewhat dangerous, as it could leave the
            remote and local problems out of sync.
        """
        client = storage.Client()
        bucket = client.get_bucket(config.public_bucket_name)

        local_problems = self.problems
        if not force:
            # pull the remote problem list and skip problems already uploaded
            temp_fp, _ = urlretrieve(self.remote_problems_url)
            with open(temp_fp) as f:
                remote_problems_ids = [
                    problem.id for problem in yaml.load(f, Loader=yaml.FullLoader)
                ]
            for problem_id in list(local_problems.keys()):
                if problem_id in remote_problems_ids:
                    warnings.warn(
                        f"Skipping problem {problem_id} because it is already uploaded."
                    )
                    del local_problems._data[problem_id]

        for container in tqdm(local_problems.values()):
            assert isinstance(container, self.problem_class)
            if include_artifacts:
                container.upload(bucket=bucket, force=force)

        blob = bucket.blob(self.problems_path)
        blob.upload_from_filename(self.local_problems_path)
    def download_problems(self, include_artifacts: bool = False):
        os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True)

        # TODO: figure out issue with caching on this call to urlretrieve
        urlretrieve(self.remote_problems_url, self.local_problems_path)
        self._load_problems.cache_clear()

        for container in self.problems.values():
            assert isinstance(container, self.problem_class)
            if include_artifacts:
                container.download()
    @functools.lru_cache()
    def _load_problems(self):
        if not os.path.exists(self.local_problems_path):
            self.download_problems()
        with open(self.local_problems_path) as f:
            problems = yaml.load(f, Loader=yaml.FullLoader)
        return ProblemTable(problems)

    @property
    def problems(self):
        return self._load_problems()

    def __repr__(self):
        return f'Task(task_id="{self.task_id}", name="{self.name}")'

    def __hash__(self):
        # necessary for `functools.lru_cache` on `_load_problems`
        return hash(repr(self))
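
Below is a minimal usage sketch, not part of the module itself. The task id, name, and stand-in problem/solution classes are illustrative assumptions; in practice, preconfigured ``Task`` instances ship with dcbench and these methods are called on them directly.

# Usage sketch (illustrative only; the identifiers below are assumptions,
# not dcbench-provided values).
from dcbench.common.task import Task

task = Task(
    task_id="example_task",        # hypothetical id, for illustration only
    name="Example Task",
    summary="A stand-in task used to illustrate the API above.",
    problem_class=object,          # stand-in; real tasks supply concrete classes
    solution_class=object,
)

# First access lazily downloads problems.yaml and caches the parsed ProblemTable.
problems = task.problems

# Re-download problems.yaml and, optionally, each problem's artifacts.
task.download_problems(include_artifacts=True)

# Serialize a set of problems locally; append=True keeps existing entries.
task.write_problems(list(problems.values()), append=True)

# Push problems.yaml (and optionally artifacts) to the configured GCS bucket.
task.upload_problems(include_artifacts=False, force=True)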