import functools
import os
import warnings
from calendar import LocaleTextCalendar
from dataclasses import dataclass, field
from typing import List
from urllib.request import urlretrieve

import yaml
from meerkat.tools.lazy_loader import LazyLoader
from tqdm import tqdm

from dcbench.common.problem import ProblemTable
from dcbench.common.table import RowMixin, Table
from dcbench.config import config

from .artifact_container import ArtifactContainer
# Lazily import google.cloud.storage so the GCS client dependency is only
# required when upload/download methods are actually called.
storage = LazyLoader("google.cloud.storage")
@dataclass
class Task(RowMixin):
    """A dcbench task: a named collection of problems plus baseline results.

    Attributes:
        task_id: Unique identifier; also used as the storage subdirectory
            for this task's ``problems.yaml``.
        name: Human-readable task name.
        summary: Short description of the task.
        problem_class: The container class that every problem in this task
            must be an instance of.
        solution_class: The container class for solutions to this task.
        baselines: Table of baseline results for this task.
    """

    task_id: str
    name: str
    summary: str
    problem_class: type
    solution_class: type
    # default_factory gives each Task its own empty Table; a plain
    # `Table([])` default would be one mutable object shared by all instances.
    baselines: Table = field(default_factory=lambda: Table([]))

    def __post_init__(self):
        super().__init__(
            id=self.task_id, attributes={"name": self.name, "summary": self.summary}
        )

    @property
    def problems_path(self) -> str:
        """Relative path of the problems spec within the storage root."""
        return os.path.join(self.task_id, "problems.yaml")

    @property
    def local_problems_path(self) -> str:
        """Absolute path of the problems spec in the local cache directory."""
        return os.path.join(config.local_dir, self.problems_path)

    @property
    def remote_problems_url(self) -> str:
        """Public URL of the problems spec in remote storage."""
        return os.path.join(config.public_remote_url, self.problems_path)

    def write_problems(self, containers: List[ArtifactContainer], append: bool = True):
        """Write problem containers to the local problems spec.

        Args:
            containers: Problems to write. Each must be an instance of
                ``self.problem_class`` and carry a unique id.
            append: If True, keep existing problems whose ids do not collide
                with ``containers``; if False, replace the spec entirely.

        Raises:
            ValueError: If ``containers`` holds duplicate ids.
        """
        ids = []
        for container in containers:
            assert isinstance(container, self.problem_class)
            ids.append(container.id)

        if len(set(ids)) != len(ids):
            raise ValueError(
                "Duplicate container ids in the containers passed to `write_problems`."
            )

        if append:
            # Copy before extending so the caller's list is not mutated.
            containers = list(containers)
            for id, problem in self.problems.items():
                if id not in ids:
                    containers.append(problem)

        os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True)
        with open(self.local_problems_path, "w") as f:
            yaml.dump(containers, f)
        # Invalidate the cached table so the next `self.problems` access
        # re-reads the file we just wrote.
        self._load_problems.cache_clear()

    def upload_problems(self, include_artifacts: bool = False, force: bool = True):
        """Upload the problems spec (and optionally artifacts) to remote storage.

        Args:
            include_artifacts (bool): If True, also uploads the artifacts of the
                problems.
            force (bool): If True, overwrite problems that already exist in
                remote storage. Defaults to True.

        .. warning::
            It is somewhat dangerous to set `force=False`, as this could lead
            to remote and local problems being out of sync.
        """
        client = storage.Client()
        bucket = client.get_bucket(config.public_bucket_name)

        local_problems = self.problems
        # The skip-already-uploaded path was dead code (`if not force and
        # False:`); restored so `force=False` behaves as documented.
        if not force:
            temp_fp, _ = urlretrieve(self.remote_problems_url)
            with open(temp_fp) as f:
                remote_problem_ids = [
                    problem.id for problem in yaml.load(f, Loader=yaml.FullLoader)
                ]
            for problem_id in list(local_problems.keys()):
                if problem_id in remote_problem_ids:
                    warnings.warn(
                        f"Skipping problem {problem_id} because it is already uploaded."
                    )
                    # NOTE(review): reaches into the table's private `_data`,
                    # mutating the cached ProblemTable in place — confirm this
                    # is the intended way to drop rows from a Table.
                    del local_problems._data[problem_id]

        for container in tqdm(local_problems.values()):
            assert isinstance(container, self.problem_class)
            if include_artifacts:
                container.upload(bucket=bucket, force=force)

        blob = bucket.blob(self.problems_path)
        blob.upload_from_filename(self.local_problems_path)

    def download_problems(self, include_artifacts: bool = False):
        """Download the problems spec (and optionally artifacts) from remote storage.

        Args:
            include_artifacts (bool): If True, also downloads the artifacts of
                the problems.
        """
        os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True)
        # TODO: figure out issue with caching on this call to urlretrieve
        urlretrieve(self.remote_problems_url, self.local_problems_path)
        self._load_problems.cache_clear()

        for container in self.problems.values():
            assert isinstance(container, self.problem_class)
            if include_artifacts:
                container.download()

    @functools.lru_cache()
    def _load_problems(self):
        # lru_cache on a method keys on `self` (hence __hash__ below) and is
        # invalidated explicitly via cache_clear() after writes/downloads.
        if not os.path.exists(self.local_problems_path):
            self.download_problems()
        with open(self.local_problems_path) as f:
            problems = yaml.load(f, Loader=yaml.FullLoader)
        return ProblemTable(problems)

    @property
    def problems(self):
        """ProblemTable of this task's problems, loaded lazily and cached."""
        return self._load_problems()

    def __repr__(self):
        return f'Task(task_id="{self.task_id}", name="{self.name}")'

    def __hash__(self):
        # necessary for lru cache
        return hash(repr(self))