File size: 844 Bytes
7e48337
f9d51f7
7e48337
 
 
 
 
 
 
 
 
 
 
 
 
a5ed112
 
 
 
7e48337
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os

import wandb


class PretrainedFromWandbMixin:
    """Mixin that lets ``from_pretrained`` load from a wandb artifact reference.

    Its ``from_pretrained`` delegates to the next class in the MRO, so it must be
    listed before the base class that actually implements ``from_pretrained``.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact, or delegates loading to the superclass.

        Args:
            pretrained_model_name_or_path: Either a wandb artifact reference
                (a ``str`` containing ``":"`` that is not an existing local
                directory, e.g. ``"entity/project/name:v3"``) or anything the
                superclass ``from_pretrained`` accepts. Non-``str`` values
                (e.g. ``pathlib.Path`` or ``None``) are passed through untouched.
            *model_args: Forwarded positionally to the superclass.
            **kwargs: Forwarded as keyword arguments to the superclass.

        Returns:
            Whatever the superclass ``from_pretrained`` returns.
        """
        name = pretrained_model_name_or_path
        # Only strings can be artifact references; evaluating `":" in name` on a
        # pathlib.Path or None would raise TypeError instead of delegating.
        # An existing local directory always wins over the artifact heuristic.
        # NOTE(review): a non-existent Windows path like "C:\\..." would be
        # treated as an artifact reference — confirm this is acceptable.
        if isinstance(name, str) and ":" in name and not os.path.isdir(name):
            if wandb.run is not None:
                # Inside an active run, use_artifact also declares the artifact
                # as an input of that run.
                artifact = wandb.run.use_artifact(name)
            else:
                artifact = wandb.Api().artifact(name)
            # download() returns the local directory containing the artifact's files,
            # which the superclass can then load like any local checkpoint.
            pretrained_model_name_or_path = artifact.download()

        return super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )