Spaces:
Running
Running
import os | |
import tempfile | |
from pathlib import Path | |
import wandb | |
class PretrainedFromWandbMixin: | |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |
""" | |
Initializes from a wandb artifact or delegates loading to the superclass. | |
""" | |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies | |
if ":" in pretrained_model_name_or_path and not os.path.isdir( | |
pretrained_model_name_or_path | |
): | |
# wandb artifact | |
if wandb.run is not None: | |
artifact = wandb.run.use_artifact(pretrained_model_name_or_path) | |
else: | |
artifact = wandb.Api().artifact(pretrained_model_name_or_path) | |
pretrained_model_name_or_path = artifact.download(tmp_dir) | |
return super(PretrainedFromWandbMixin, cls).from_pretrained( | |
pretrained_model_name_or_path, *model_args, **kwargs | |
) | |