boris's picture
feat: load from bucket
1c4e839
raw
history blame
No virus
1.94 kB
import os
import tempfile
from pathlib import Path
import wandb
class PretrainedFromWandbMixin:
    """
    Mixin that lets ``from_pretrained`` resolve wandb artifacts and ``gs://`` paths.

    The mixin only resolves the location to a local directory; the actual loading
    is delegated to the next ``from_pretrained`` in the MRO, so it has to be
    combined with a class that provides that classmethod.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact, google bucket path or delegates loading to the superclass.

        Args:
            pretrained_model_name_or_path: One of
                * a wandb artifact reference (contains ``":"``, is not an existing
                  local directory and is not a ``gs://`` URL),
                * a ``gs://<bucket>/<prefix>`` URL,
                * anything else, which is forwarded unchanged to the superclass.
                ``os.PathLike`` objects are accepted and converted to ``str``.
            *model_args: Forwarded to the superclass ``from_pretrained``.
            **kwargs: Forwarded to the superclass ``from_pretrained``.

        Returns:
            Whatever the superclass ``from_pretrained`` returns.
        """
        # Normalise PathLike inputs so the string checks below do not raise TypeError.
        pretrained_model_name_or_path = os.fspath(pretrained_model_name_or_path)

        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            # Use the full "gs://" scheme here (not just "gs") so that artifact
            # names that merely start with "gs" are still treated as artifacts,
            # consistent with the bucket check further down.
            is_wandb_artifact = (
                ":" in pretrained_model_name_or_path
                and not os.path.isdir(pretrained_model_name_or_path)
                and not pretrained_model_name_or_path.startswith("gs://")
            )

            if is_wandb_artifact:
                # Inside an active run, register the artifact as an input of that
                # run; otherwise fetch it through the public API.
                if wandb.run is not None:
                    artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
                else:
                    artifact = wandb.Api().artifact(pretrained_model_name_or_path)
                pretrained_model_name_or_path = artifact.download(tmp_dir)

                # An artifact may only point at a bucket holding the real files.
                if artifact.metadata.get("bucket_path"):
                    pretrained_model_name_or_path = artifact.metadata["bucket_path"]

            if pretrained_model_name_or_path.startswith("gs://"):
                copy_blobs(pretrained_model_name_or_path, tmp_dir)
                pretrained_model_name_or_path = tmp_dir

            # Load while still inside the ``with`` block: tmp_dir is removed on exit.
            return super(PretrainedFromWandbMixin, cls).from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
def copy_blobs(source_path, dest_path):
    """
    Download the blobs stored under a ``gs://`` prefix into a local directory.

    Files are written flat: only the last component of each blob name is kept,
    so blobs in nested "sub-directories" that share a file name overwrite each
    other.

    Args:
        source_path: Location of the form ``gs://<bucket>/<prefix>``.
        dest_path: Local directory receiving the files. It is not created here.

    Raises:
        ValueError: If ``source_path`` is not a ``gs://`` URL or does not contain
            a prefix after the bucket name.
    """
    from pathlib import PurePosixPath

    scheme = "gs://"
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not source_path.startswith(scheme):
        raise ValueError(f"Expected a path starting with 'gs://', got: {source_path!r}")

    # Bucket paths always use "/" regardless of the host OS, hence PurePosixPath
    # (it also drops trailing and duplicate slashes).
    normalized = str(PurePosixPath(source_path[len(scheme):]))
    bucket_name, _, dir_path = normalized.partition("/")
    if not dir_path:
        raise ValueError(
            f"Expected 'gs://<bucket>/<prefix>' with a non-empty prefix, got: {source_path!r}"
        )

    # Imported lazily so the dependency is only needed when a bucket is used.
    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket(bucket_name)
    dest_dir = Path(dest_path)
    for blob in client.list_blobs(bucket, prefix=f"{dir_path}/"):
        # Names ending in "/" are zero-byte directory placeholders, not files.
        if blob.name.endswith("/"):
            continue
        blob.download_to_filename(str(dest_dir / Path(blob.name).name))