Source code for dcbench.common.task

import functools
import os
from dataclasses import dataclass
from typing import Sequence
from urllib.request import urlretrieve

import yaml
from meerkat.tools.lazy_loader import LazyLoader
from tqdm import tqdm

from dcbench.common.table import RowMixin, Table
from dcbench.config import config

from .artifact import ArtifactContainer

storage = LazyLoader("google.cloud.storage")


[docs]@dataclass class Task(RowMixin): task_id: str name: str summary: str problem_class: type solution_class: type baselines: Table = Table([]) def __post_init__(self): super().__init__( id=self.task_id, attributes={"name": self.name, "summary": self.summary} ) @property def problems_path(self): return os.path.join(self.task_id, "problems.yaml") @property def local_problems_path(self): return os.path.join(config.local_dir, self.problems_path) @property def remote_problems_url(self): return os.path.join(config.public_remote_url, self.problems_path)
[docs] def write_problems(self, containers: Sequence[ArtifactContainer]): ids = [] for container in containers: assert isinstance(container, self.problem_class) ids.append(container.id) if len(set(ids)) != len(ids): raise ValueError( "Duplicate container ids in the containers passed to `write_problems`." ) os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True) yaml.dump(containers, open(self.local_problems_path, "w"))
[docs] def upload_problems(self, include_artifacts: bool = False): client = storage.Client() bucket = client.get_bucket(config.public_bucket_name) for container in tqdm(self.problems.values()): assert isinstance(container, self.problem_class) if include_artifacts: container.upload(bucket=bucket, force=True) blob = bucket.blob(self.problems_path) blob.upload_from_filename(self.local_problems_path)
[docs] def download_problems(self, include_artifacts: bool = False): os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True) urlretrieve(self.remote_problems_url, self.local_problems_path) for container in self.problems.values(): assert isinstance(container, self.problem_class) if include_artifacts: container.upload()
@property @functools.lru_cache() def problems(self): if not os.path.exists(self.local_problems_path): self.download_problems() problems = yaml.load(open(self.local_problems_path), Loader=yaml.FullLoader) return Table(problems) def __repr__(self): return f'Task(task_id="{self.task_id}", name="{self.name}")' def __hash__(self): # necessary for lru cache return hash(repr(self))