from __future__ import annotations
import os
import uuid
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
import yaml
from meerkat.tools.lazy_loader import LazyLoader
import dcbench.constants as constants
from dcbench.config import config
from .artifact import Artifact
from .table import Attribute, AttributeSpec, RowMixin

storage = LazyLoader("google.cloud.storage")


@dataclass
class ArtifactSpec:
    """Specification for a single artifact slot in an :class:`ArtifactContainer`."""

    description: str
    artifact_type: type
    optional: bool = False


class ArtifactContainer(ABC, Mapping, RowMixin):
    """A logical collection of artifacts and attributes (simple tags describing
    the container), which are useful for finding, sorting, and grouping
    containers.

    Args:
        artifacts (Mapping[str, Union[Artifact, Any]]): A mapping with the same
            keys as `ArtifactContainer.artifact_specs` (possibly excluding
            optional artifacts). Each value can either be an :class:`Artifact`,
            in which case the artifact type must match the type specified in the
            corresponding :class:`ArtifactSpec`, or a raw object, in which case a
            new artifact of the type specified in `artifact_specs` is created
            from the raw object and an ``artifact_id`` is generated according to
            the following pattern:
            ``<task_id>/<container_type>/artifacts/<container_id>/<key>``.
        attributes (Mapping[str, PRIMITIVE_TYPE], optional): A mapping with the
            same keys as `ArtifactContainer.attribute_specs` (possibly excluding
            optional attributes). Each value must be of the type specified in the
            corresponding :class:`AttributeSpec`. Defaults to None.
        container_id (str, optional): The ID of the container. Defaults to None,
            in which case a UUID is generated.

    Attributes:
        artifacts (Dict[str, Artifact]): A dictionary of artifacts, indexed by
            name.

            .. Tip::

                We can use the index operator directly on
                :class:`ArtifactContainer` objects to fetch an artifact, download
                it if necessary, and load it into memory. For example, to load
                the artifact ``"data"`` into memory from a container
                ``container``, we can simply call ``container["data"]``, which is
                equivalent to calling ``container.artifacts["data"].download()``
                followed by ``container.artifacts["data"].load()``.

        attributes (Dict[str, Attribute]): A dictionary of attributes, indexed
            by name.

            .. Tip:: Accessing attributes

                Attributes can be accessed via dot-notation (as long as the
                attribute name does not conflict with an existing member). For
                example, to access the attribute ``"data"`` in a container
                ``container``, we can simply call ``container.data``.

    Notes
    -----
    :class:`ArtifactContainer` is an abstract base class, and should not be
    instantiated directly. There are two main groups of
    :class:`ArtifactContainer` subclasses:

    #. :class:`dcbench.Problem` - A logical collection of artifacts and
       attributes that correspond to a specific problem to be solved.

       - Example subclasses: :class:`dcbench.SliceDiscoveryProblem`,
         :class:`dcbench.BudgetcleanProblem`

    #. :class:`dcbench.Solution` - A logical collection of artifacts and
       attributes that correspond to a solution to a problem.

       - Example subclasses: :class:`dcbench.SliceDiscoverySolution`,
         :class:`dcbench.BudgetcleanSolution`

    A concrete (i.e. non-abstract) subclass of :class:`ArtifactContainer` must
    include (1) a specification for the artifacts it holds, (2) a specification
    for the attributes used to tag it, and (3) a `task_id` linking the subclass
    to one of dcbench's tasks (see :ref:`task-intro`). For example, in the code
    block below we include such a specification in the definition of a simple
    container that holds a training dataset and a test dataset (see
    :class:`dcbench.SliceDiscoveryProblem` for a real example):

    .. code-block:: python

        class DemoContainer(ArtifactContainer):
            artifact_specs = {
                "train_dataset": ArtifactSpec(
                    artifact_type=CSVArtifact,
                    description="A CSV containing training data.",
                ),
                "test_dataset": ArtifactSpec(
                    artifact_type=CSVArtifact,
                    description="A CSV containing test data.",
                ),
            }
            attribute_specs = {
                "dataset_name": AttributeSpec(
                    attribute_type=str,
                    description="The name of the dataset.",
                ),
            }
            task_id = "slice_discovery"
"""
artifact_specs: Mapping[str, ArtifactSpec]
task_id: str
attribute_specs: Mapping[str, AttributeSpec] = {}
# abstract subclasses like Problem and Solution specify this so that all of their
# subclasses may be grouped by container_type when stored on disk
container_type: str = "artifact_container"

    def __init__(
        self,
        artifacts: Mapping[str, Artifact],
        attributes: Mapping[str, Attribute] | None = None,
        container_id: str | None = None,
):
if container_id is None:
container_id = uuid.uuid4().hex
super().__init__(id=container_id)
self._check_artifact_specs(artifacts=artifacts)
artifacts = self._create_artifacts(artifacts=artifacts)
self.artifacts = artifacts
if attributes is None:
attributes = {}
        self.attributes = attributes  # this setter validates against attribute_specs

    @property
    def is_downloaded(self) -> bool:
        """Checks whether all of the artifacts in the container are downloaded to
        the local directory specified in the config file at ``config.local_dir``.

        Returns:
            bool: True if all artifacts are downloaded, False otherwise.
        """
return all(x.is_downloaded for x in self.artifacts.values())

    @property
    def is_uploaded(self) -> bool:
        """Checks whether all of the artifacts in the container are uploaded to
        the GCS bucket specified in the config file at
        ``config.public_bucket_name``.

        Returns:
            bool: True if all artifacts are uploaded, False otherwise.
        """
return all(x.is_uploaded for x in self.artifacts.values())

    def upload(self, force: bool = False, bucket: "storage.Bucket" = None):
        """Uploads all of the artifacts in the container to a GCS bucket, skipping
        artifacts that are already uploaded.

        Args:
            force (bool, optional): Force upload even if an artifact is already
                uploaded. Defaults to False.
            bucket (storage.Bucket, optional): The GCS bucket to which the
                artifacts are uploaded. Defaults to None, in which case the
                artifacts are uploaded to the bucket specified in the config file
                at ``config.public_bucket_name``.

        Returns:
            bool: True if any artifacts were uploaded, False otherwise.
        """
if bucket is None:
client = storage.Client()
bucket = client.get_bucket(config.public_bucket_name)
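        # Note: a list comprehension (not a generator) is used below so that every
        # artifact gets uploaded; a generator would let any() short-circuit after
        # the first artifact that reports True.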
return any(
[
artifact.upload(force=force, bucket=bucket)
for artifact in self.artifacts.values()
]
)

    def download(self, force: bool = False) -> bool:
        """Downloads artifacts in the container from the GCS bucket specified in
        the config file at ``config.public_bucket_name`` to the local directory
        specified in the config file at ``config.local_dir``. The relative path
        of each artifact within that directory is the artifact's ``path``, which
        by default is just the artifact ID with the default extension.

        Args:
            force (bool, optional): Force download even if an artifact is already
                downloaded. Defaults to False.

        Returns:
            bool: True if any artifacts were downloaded, False otherwise.
        """
return any(
[artifact.download(force=force) for artifact in self.artifacts.values()]
)

    @staticmethod
def from_yaml(loader: yaml.Loader, node):
"""This function is called by the YAML loader to convert a YAML node
into an :class:`ArtifactContainer` object.
It should not be called directly.
"""
data = loader.construct_mapping(node, deep=True)
return data["class"](
container_id=data["container_id"],
artifacts=data["artifacts"],
attributes=data["attributes"],
)

    @staticmethod
def to_yaml(dumper: yaml.Dumper, data: ArtifactContainer):
"""This function is called by the YAML dumper to convert an
:class:`ArtifactContainer` object into a YAML node.
It should not be called directly.
"""
        mapping = {
            "class": type(data),
            "container_id": data.id,
            "attributes": data._attributes,
            "artifacts": data.artifacts,
        }
        return dumper.represent_mapping("!ArtifactContainer", mapping)

    # Provide a dict-like interface for accessing artifacts by name.
    def __getitem__(self, key):
        artifact = self.artifacts[key]
        if not artifact.is_downloaded:
            artifact.download()
        return artifact.load()

    def __iter__(self):
        return iter(self.artifacts)

    def __len__(self):
        return len(self.artifacts)

    def __getattr__(self, k: str) -> Any:
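        # Fall back to the container's attributes; re-raising KeyError as
        # AttributeError preserves normal attribute-lookup semantics (e.g. hasattr).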
try:
return self.attributes[k]
except KeyError:
raise AttributeError(k)

    def __repr__(self):
artifacts = {k: v.__class__.__name__ for k, v in self.artifacts.items()}
return (
f"{self.__class__.__name__}(artifacts={artifacts}, "
f"attributes={self.attributes})"
)

    @classmethod
def _check_artifact_specs(cls, artifacts: Mapping[str, Artifact]):
for name, artifact in artifacts.items():
if name not in cls.artifact_specs:
raise ValueError(
f"Passed artifact name '{name}', but the specification for"
f" {cls.__name__} doesn't include it."
)
            # Whether an artifact can actually be created from the raw data is
            # checked later, in _create_artifacts.
if isinstance(artifact, Artifact) and not isinstance(
artifact, cls.artifact_specs[name].artifact_type
):
raise ValueError(
f"Passed an artifact of type {type(artifact)} to {cls.__name__}"
f" for the artifact named '{name}'. The specification for"
f" {cls.__name__} expects an Artifact of type"
f" {cls.artifact_specs[name].artifact_type}."
)
for name, spec in cls.artifact_specs.items():
if name not in artifacts:
if spec.optional:
continue
raise ValueError(
f"Must pass required artifact with key '{name}' to {cls.__name__}."
)

    def _create_artifacts(self, artifacts: Mapping[str, Artifact]):
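        # Raw values are wrapped in the artifact type given by the spec; generated
        # artifact IDs follow <task_id>/<container_type>/artifacts/<container_id>/<key>.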
return {
name: artifact
if isinstance(artifact, Artifact)
else self.artifact_specs[name].artifact_type.from_data(
data=artifact,
artifact_id=os.path.join(
self.task_id,
self.container_type,
constants.ARTIFACTS_DIR,
self.id,
name,
),
)
for name, artifact in artifacts.items()
}
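

# Register the YAML hooks: to_yaml handles ArtifactContainer and all of its
# subclasses, and from_yaml reconstructs containers tagged "!ArtifactContainer".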
yaml.add_multi_representer(ArtifactContainer, ArtifactContainer.to_yaml)
yaml.add_constructor("!ArtifactContainer", ArtifactContainer.from_yaml)