from __future__ import annotations

import json
import os
import shutil
import subprocess
import tempfile
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import Any, Union
from urllib.error import HTTPError, URLError
from urllib.request import urlopen, urlretrieve

import meerkat as mk
import pandas as pd
import yaml
from meerkat.tools.lazy_loader import LazyLoader

from dcbench.common.modeling import Model
from dcbench.config import config
storage = LazyLoader("google.cloud.storage")
torch = LazyLoader("torch")
def _upload_dir_to_gcs(local_path: str, gcs_path: str, bucket: "storage.Bucket"):
    """Tar and gzip a local directory and upload the tarball to GCS.

    The directory is archived as ``<gcs_path>.tar.gz`` with members stored
    relative to ``local_path``, so extraction recreates the directory's
    contents rather than its absolute path.

    Args:
        local_path (str): Path to an existing local directory.
        gcs_path (str): Destination blob path; ``.tar.gz`` is appended.
        bucket (storage.Bucket): GCS bucket the tarball is uploaded to.

    Raises:
        subprocess.CalledProcessError: If the ``tar`` command fails.
    """
    assert os.path.isdir(local_path)
    with tempfile.TemporaryDirectory() as tmp_dir:
        tarball_path = os.path.join(tmp_dir, "run.tar.gz")
        # check=True: subprocess.call ignored the exit status, so a failed
        # tar could silently upload a truncated/empty archive.
        subprocess.run(
            ["tar", "-czf", tarball_path, "-C", local_path, "."],
            check=True,
        )
        remote_path = gcs_path + ".tar.gz"
        blob = bucket.blob(remote_path)
        blob.upload_from_filename(tarball_path)
def _url_exists(url: str) -> bool:
    """Return True if an HTTP GET of ``url`` succeeds with status 200.

    Any URL failure — an HTTP error status, DNS failure, or refused
    connection — is treated as "does not exist" rather than propagated.
    """
    try:
        # Context manager closes the connection instead of leaking the socket.
        with urlopen(url) as response:
            return response.getcode() == 200
    except URLError:
        # URLError is the parent of HTTPError, so this also covers non-HTTP
        # failures (previously only HTTPError was caught and e.g. a network
        # outage escaped as an unhandled exception).
        return False
class Artifact(ABC):
    """A pointer to a unit of data (e.g. a CSV file) that is stored locally on
    disk and/or in a remote GCS bucket.

    In DCBench, each artifact is identified by a unique artifact ID. The only
    state that the :class:`Artifact` object must maintain is this ID
    (``self.id``). The object does not hold the actual data in memory, making
    it lightweight.

    :class:`Artifact` is an abstract base class. Different types of artifacts
    (e.g. a CSV file vs. a PyTorch model) have corresponding subclasses of
    :class:`Artifact` (e.g. :class:`CSVArtifact`, :class:`ModelArtifact`).

    .. Tip::
        The vast majority of users should not call the :class:`Artifact`
        constructor directly. Instead, they should either create a new
        artifact by calling :meth:`from_data` or load an existing artifact
        from a YAML file.

    The class provides utilities for accessing and managing a unit of data:

    - Synchronizing the local and remote copies of a unit of data:
      :meth:`upload`, :meth:`download`
    - Loading the data into memory: :meth:`load`
    - Creating new artifacts from in-memory data: :meth:`from_data`
    - Serializing the pointer artifact so it can be shared:
      :meth:`to_yaml`, :meth:`from_yaml`

    Args:
        artifact_id (str): The unique artifact ID.

    Attributes:
        id (str): The unique artifact ID.
    """

    # File extension appended to the artifact ID to form ``self.path``;
    # subclasses override this (e.g. "csv", "pt").
    DEFAULT_EXT: str = ""
    # Whether the artifact is stored as a directory (tarred for upload)
    # rather than a single file.
    isdir: bool = False

    @classmethod
    def from_data(
        cls,
        data: Union[mk.DataPanel, pd.DataFrame, Model],
        artifact_id: str | None = None,
    ) -> Artifact:
        """Create a new artifact object from raw data and save the artifact to
        disk in the local directory specified in the config file at
        ``config.local_dir``.

        .. tip::
            When called on the abstract base class :class:`Artifact`, this
            method will infer which artifact subclass to use. If you know
            exactly which artifact class you'd like to use (e.g.
            :class:`DataPanelArtifact`), you should call this classmethod on
            that subclass.

        Args:
            data (Union[mk.DataPanel, pd.DataFrame, Model]): The raw data that
                will be saved to disk.
            artifact_id (str, optional): The ID to assign to the new artifact.
                Defaults to None, in which case a UUID will be generated and
                used.

        Returns:
            Artifact: A new artifact pointing to the ``data`` that was saved
            to disk.

        Raises:
            ValueError: If no artifact class supports ``type(data)``.
        """
        if artifact_id is None:
            artifact_id = uuid.uuid4().hex

        # TODO (): At some point we should probably enforce that ids are unique
        if cls is Artifact:
            # if called on the base class, infer which subclass to use from
            # the type of the data
            if isinstance(data, mk.DataPanel):
                cls = DataPanelArtifact
            elif isinstance(data, pd.DataFrame):
                cls = CSVArtifact
            elif isinstance(data, Model):
                cls = ModelArtifact
            elif isinstance(data, (list, dict)):
                cls = YAMLArtifact
            else:
                raise ValueError(
                    f"No Artifact in dcbench for object of type {type(data)}"
                )

        artifact = cls(artifact_id=artifact_id)
        artifact.save(data)
        return artifact

    @property
    def local_path(self) -> str:
        """The local path to the artifact in the local directory specified in
        the config file at ``config.local_dir``."""
        return os.path.join(config.local_dir, self.path)

    @property
    def remote_url(self) -> str:
        """The URL of the artifact in the remote GCS bucket specified in the
        config file at ``config.public_bucket_name``."""
        # Directory artifacts are stored remotely as gzipped tarballs.
        return os.path.join(
            config.public_remote_url, self.path + (".tar.gz" if self.isdir else "")
        )

    @property
    def is_downloaded(self) -> bool:
        """Checks if artifact is downloaded to local directory specified in
        the config file at ``config.local_dir``.

        Returns:
            bool: True if artifact is downloaded, False otherwise.
        """
        return os.path.exists(self.local_path)

    @property
    def is_uploaded(self) -> bool:
        """Checks if artifact is uploaded to GCS bucket specified in the
        config file at ``config.public_bucket_name``.

        Returns:
            bool: True if artifact is uploaded, False otherwise.
        """
        return _url_exists(self.remote_url)

    def upload(self, force: bool = False, bucket: "storage.Bucket" = None) -> bool:
        """Uploads artifact to a GCS bucket at ``self.path``, which by default
        is just the artifact ID with the default extension.

        Args:
            force (bool, optional): Force upload even if artifact is already
                uploaded. Defaults to False.
            bucket (storage.Bucket, optional): The GCS bucket to which the
                artifact is uploaded. Defaults to None, in which case the
                artifact is uploaded to the bucket specified in the config
                file at ``config.public_bucket_name``.

        Returns:
            bool: True if artifact was uploaded, False otherwise.

        Raises:
            ValueError: If the artifact is not stored locally.
        """
        if not os.path.exists(self.local_path):
            raise ValueError(
                f"Could not find Artifact to upload at '{self.local_path}'. "
                "Are you sure it is stored locally?"
            )

        if self.is_uploaded and not force:
            # Note: the two string literals are concatenated; the trailing
            # space keeps the message readable (previously rendered as
            # "re-uploaded.Set").
            warnings.warn(
                f"Artifact {self.id} is not being re-uploaded. "
                "Set `force=True` to force upload."
            )
            return False

        if bucket is None:
            client = storage.Client()
            bucket = client.get_bucket(config.public_bucket_name)

        if self.isdir:
            # Directories are tarred and uploaded as a single blob.
            _upload_dir_to_gcs(
                local_path=self.local_path,
                bucket=bucket,
                gcs_path=self.path,
            )
        else:
            blob = bucket.blob(self.path)
            blob.upload_from_filename(self.local_path)
            # Disable edge caching so updated artifacts propagate quickly.
            blob.metadata = {"Cache-Control": "private, max-age=0, no-transform"}
            blob.patch()
        return True

    def download(self, force: bool = False) -> bool:
        """Downloads artifact from GCS bucket to the local directory specified
        in the config file at ``config.local_dir``. The relative path to the
        artifact within that directory is ``self.path``, which by default is
        just the artifact ID with the default extension.

        Args:
            force (bool, optional): Force download even if artifact is already
                downloaded. Defaults to False.

        Returns:
            bool: True if artifact was downloaded, False otherwise.

        .. warning::
            By default, the GCS cache on public urls has a max-age up to an
            hour. Therefore, when updating existing artifacts, changes may not
            be immediately reflected in subsequent downloads. See `here
            <https://stackoverflow.com/questions/62897641>`_ for more details.
        """
        if self.is_downloaded and not force:
            return False

        if self.isdir:
            # Remove any stale copy, then fetch and unpack the tarball.
            if self.is_downloaded:
                shutil.rmtree(self.local_path)
            os.makedirs(self.local_path, exist_ok=True)
            tarball_path = self.local_path + ".tar.gz"
            urlretrieve(self.remote_url, tarball_path)
            subprocess.call(["tar", "-xzf", tarball_path, "-C", self.local_path])
            os.remove(tarball_path)
        else:
            if self.is_downloaded:
                os.remove(self.local_path)
            os.makedirs(os.path.dirname(self.local_path), exist_ok=True)
            urlretrieve(self.remote_url, self.local_path)
        return True

    @abstractmethod
    def load(self) -> Any:
        """Load the artifact into memory from disk at ``self.local_path``."""
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Any) -> None:
        """Save data to disk at ``self.local_path``."""
        raise NotImplementedError()

    def __init__(self, artifact_id: str, **kwargs) -> None:
        """
        .. warning::
            In general, you should not instantiate an Artifact directly.
            Instead, use :meth:`Artifact.from_data` to create an Artifact.
        """
        self.path = f"{artifact_id}.{self.DEFAULT_EXT}"
        self.id = artifact_id
        # Ensure the parent directory for the artifact exists locally.
        os.makedirs(os.path.dirname(self.local_path), exist_ok=True)
        super().__init__()

    @staticmethod
    def from_yaml(loader: yaml.Loader, node):
        """This function is called by the YAML loader to convert a YAML node
        into an Artifact object.

        It should not be called directly.
        """
        data = loader.construct_mapping(node, deep=True)
        return data["class"](artifact_id=data["artifact_id"])

    @staticmethod
    def to_yaml(dumper: yaml.Dumper, data: Artifact):
        """This function is called by the YAML dumper to convert an Artifact
        object into a YAML node.

        It should not be called directly.
        """
        data = {
            "artifact_id": data.id,
            "class": type(data),
        }
        node = dumper.represent_mapping("!Artifact", data)
        return node

    def _ensure_downloaded(self):
        # Guard used by subclasses' ``load`` implementations.
        if not self.is_downloaded:
            raise ValueError(
                "Cannot load `Artifact` that has not been downloaded. "
                "First call `artifact.download()`."
            )
# Register Artifact with PyYAML so that any Artifact subclass can be
# round-tripped through yaml.dump / yaml.load via the "!Artifact" tag
# (add_multi_representer applies to Artifact and all of its subclasses).
yaml.add_multi_representer(Artifact, Artifact.to_yaml)
yaml.add_constructor("!Artifact", Artifact.from_yaml)
class CSVArtifact(Artifact):
    """An artifact stored on disk as a single CSV file and loaded as a
    ``pd.DataFrame``."""

    DEFAULT_EXT: str = "csv"

    def load(self) -> pd.DataFrame:
        """Read the CSV at ``self.local_path`` into a DataFrame.

        String cells holding JSON (e.g. lists serialized by ``save``) are
        decoded back into Python objects; every other cell is returned
        unchanged.
        """
        self._ensure_downloaded()
        frame = pd.read_csv(self.local_path, index_col=0)

        def _maybe_decode(cell):
            # Only strings can carry JSON; anything undecodable stays as-is.
            if not isinstance(cell, str):
                return cell
            try:
                return json.loads(cell)
            except ValueError:
                return cell

        return frame.applymap(_maybe_decode)

    def save(self, data: pd.DataFrame) -> None:
        """Write ``data`` (including its index) as CSV to ``self.local_path``."""
        return data.to_csv(self.local_path)
class YAMLArtifact(Artifact):
    """An artifact stored on disk as a YAML file (e.g. a list or dict)."""

    DEFAULT_EXT: str = "yaml"

    def load(self) -> Any:
        """Load and return the YAML document at ``self.local_path``."""
        self._ensure_downloaded()
        # Context manager closes the handle; the previous bare ``open``
        # leaked the file descriptor until garbage collection.
        with open(self.local_path) as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    def save(self, data: Any) -> None:
        """Serialize ``data`` as YAML to ``self.local_path``."""
        with open(self.local_path, "w") as f:
            return yaml.dump(data, f)
class DataPanelArtifact(Artifact):
    """An artifact stored on disk as a Meerkat DataPanel directory."""

    DEFAULT_EXT: str = "mk"
    # DataPanels are written as directories, not single files, so they are
    # tarred for upload/download.
    isdir: bool = True

    def load(self) -> mk.DataPanel:
        """Read and return the DataPanel stored at ``self.local_path``.

        Note: the return annotation previously claimed ``pd.DataFrame``,
        but ``mk.DataPanel.read`` returns a DataPanel.
        """
        self._ensure_downloaded()
        return mk.DataPanel.read(self.local_path)

    def save(self, data: mk.DataPanel) -> None:
        """Write ``data`` to ``self.local_path`` in Meerkat's on-disk format."""
        return data.write(self.local_path)
class VisionDatasetArtifact(DataPanelArtifact):
    """A DataPanel artifact for a supported vision dataset (CelebA or
    ImageNet).

    Unlike generic artifacts, these are not fetched from the DCBench GCS
    bucket: :meth:`download` materializes the dataset locally via
    ``mk.datasets`` instead.
    """

    DEFAULT_EXT: str = "mk"
    isdir: bool = True

    # Columns retained for each supported dataset.
    COLUMN_SUBSETS = {
        "celeba": ["id", "image", "identity", "split"],
        "imagenet": ["id", "image", "name", "synset"],
    }

    @classmethod
    def _build_datapanel(cls, name: str) -> mk.DataPanel:
        """Fetch a supported dataset and normalize it to its column subset.

        Shared by :meth:`from_name` and :meth:`download`, which previously
        duplicated this logic.

        Raises:
            ValueError: If ``name`` is not a supported dataset.
        """
        if name == "celeba":
            dp = mk.datasets.get(name, dataset_dir=config.celeba_dir)
        elif name == "imagenet":
            dp = mk.datasets.get(name, dataset_dir=config.imagenet_dir, download=False)
        else:
            raise ValueError(f"No dataset named '{name}' supported by dcbench.")

        # Rename the dataset's "image_id" column to the conventional "id".
        dp["id"] = dp["image_id"]
        dp.remove_column("image_id")
        return dp[cls.COLUMN_SUBSETS[name]]

    @classmethod
    def from_name(cls, name: str):
        """Create a new artifact (with ``artifact_id=name``) from a supported
        dataset and save it to disk."""
        dp = cls._build_datapanel(name)
        artifact = cls.from_data(data=dp, artifact_id=name)
        return artifact

    def download(self, force: bool = False):
        """Materialize the dataset locally and save it at ``self.local_path``.

        (The original also subset the columns a second time before saving,
        which was a no-op; the redundant subsetting is removed.)
        """
        self.save(data=self._build_datapanel(self.id))
class ModelArtifact(Artifact):
    """An artifact storing a PyTorch :class:`Model` checkpoint on disk."""

    DEFAULT_EXT: str = "pt"

    def load(self) -> Model:
        """Rebuild the model from the checkpoint at ``self.local_path``.

        The checkpoint holds the model class, its config, and its state
        dict; the class is instantiated from the config and the weights
        are then loaded into the instance (on CPU).

        .. warning::
            ``torch.load`` unpickles arbitrary objects — only load
            checkpoints from trusted sources.
        """
        self._ensure_downloaded()
        checkpoint = torch.load(self.local_path, map_location="cpu")
        model = checkpoint["class"](checkpoint["config"])
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def save(self, data: Model) -> None:
        """Checkpoint ``data`` — its class, config, and state dict — to
        ``self.local_path``."""
        payload = {
            "state_dict": data.state_dict(),
            "config": data.config,
            "class": type(data),
        }
        return torch.save(payload, self.local_path)