File size: 1,944 Bytes
7e48337
34cf91c
1c4e839
f9d51f7
7e48337
 
 
 
 
 
 
1c4e839
7e48337
34cf91c
1c4e839
 
 
 
34cf91c
 
 
 
 
 
 
1c4e839
 
 
 
 
 
7e48337
34cf91c
 
 
1c4e839
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import tempfile
from pathlib import Path

import wandb


class PretrainedFromWandbMixin:
    """Mixin that lets ``from_pretrained`` load from wandb artifacts and ``gs://`` paths.

    List it before the model class in the bases, e.g.
    ``class Model(PretrainedFromWandbMixin, BaseModel)``, so its ``from_pretrained``
    runs first and then delegates to the next class in the MRO.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact, google bucket path or delegates loading to the superclass.

        Args:
            pretrained_model_name_or_path: one of
                - a wandb artifact reference: any string containing ``":"`` that is
                  neither an existing local directory nor a ``gs://`` URI;
                - a ``gs://bucket/dir`` URI, whose objects are copied into a
                  temporary directory that is then loaded;
                - anything else (local path, model id), passed through unchanged.
                ``os.PathLike`` values are converted to ``str`` first.
            *model_args: forwarded unchanged to the superclass ``from_pretrained``.
            **kwargs: forwarded unchanged to the superclass ``from_pretrained``.

        Returns:
            Whatever the superclass ``from_pretrained`` returns.
        """
        if isinstance(pretrained_model_name_or_path, os.PathLike):
            # `":" in Path(...)` raises TypeError, so normalise path-likes to str.
            pretrained_model_name_or_path = os.fspath(pretrained_model_name_or_path)

        # Everything fetched into tmp_dir is deleted when this `with` exits,
        # i.e. after the superclass call below has returned.
        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            if (
                ":" in pretrained_model_name_or_path
                and not os.path.isdir(pretrained_model_name_or_path)
                # Exclude only real gs:// URIs; a bare "gs" prefix would also
                # swallow artifact refs whose entity happens to start with "gs".
                and not pretrained_model_name_or_path.startswith("gs://")
            ):
                # wandb artifact: use_artifact records it as an input of the
                # active run; without a run, fetch it through the public API.
                if wandb.run is not None:
                    artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
                else:
                    artifact = wandb.Api().artifact(pretrained_model_name_or_path)
                pretrained_model_name_or_path = artifact.download(tmp_dir)
                # If the artifact points at a bucket location, load from there;
                # those blobs are copied into the same tmp_dir below.
                if artifact.metadata.get("bucket_path"):
                    pretrained_model_name_or_path = artifact.metadata["bucket_path"]

            if pretrained_model_name_or_path.startswith("gs://"):
                copy_blobs(pretrained_model_name_or_path, tmp_dir)
                pretrained_model_name_or_path = tmp_dir

            return super(PretrainedFromWandbMixin, cls).from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )


def _split_gcs_path(source_path):
    """Split a ``gs://bucket[/dir]`` URI into ``(bucket_name, prefix)``.

    ``prefix`` is ``"dir/"`` (exactly one trailing slash), or ``None`` when the
    URI names the whole bucket. Plain string operations are used rather than
    ``pathlib``, so the result does not depend on the host OS path separator.

    Raises:
        ValueError: if ``source_path`` is not a ``gs://`` URI or has no bucket name.
    """
    if not source_path.startswith("gs://"):
        raise ValueError(f"Expected a 'gs://' path, got {source_path!r}")
    bucket_name, _, dir_path = source_path[len("gs://"):].partition("/")
    if not bucket_name:
        raise ValueError(f"Missing bucket name in {source_path!r}")
    dir_path = dir_path.strip("/")
    return bucket_name, (f"{dir_path}/" if dir_path else None)


def copy_blobs(source_path, dest_path):
    """Download every object under a ``gs://`` prefix into a local directory.

    Files are written flat into ``dest_path`` under the last component of their
    object name. Objects in nested "sub-directories" therefore land directly in
    ``dest_path``, and objects with the same base name overwrite each other.

    Args:
        source_path: ``gs://bucket/dir`` (trailing slash optional) or
            ``gs://bucket`` to copy the whole bucket.
        dest_path: local directory to write into (assumed to already exist).

    Raises:
        ValueError: if ``source_path`` is not a valid ``gs://`` URI. This is
            checked before google-cloud-storage is imported.
    """
    bucket_name, prefix = _split_gcs_path(source_path)
    # Imported here rather than at module top, so google-cloud-storage is only
    # needed when a gs:// path is actually copied.
    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket(bucket_name)
    dest_dir = Path(dest_path)
    for blob in client.list_blobs(bucket, prefix=prefix):
        if blob.name.endswith("/"):
            # Zero-byte "folder" placeholder object; downloading it would create
            # a stray empty file named after the folder.
            continue
        blob.download_to_filename(str(dest_dir / Path(blob.name).name))