Spaces:
Running
Running
feat: use_artifact if run existing
Browse files
src/dalle_mini/model/configuration.py
CHANGED
@@ -18,7 +18,7 @@ import warnings
|
|
18 |
from transformers.configuration_utils import PretrainedConfig
|
19 |
from transformers.utils import logging
|
20 |
|
21 |
-
from .
|
22 |
|
23 |
logger = logging.get_logger(__name__)
|
24 |
|
|
|
18 |
from transformers.configuration_utils import PretrainedConfig
|
19 |
from transformers.utils import logging
|
20 |
|
21 |
+
from .utils import PretrainedFromWandbMixin
|
22 |
|
23 |
logger = logging.get_logger(__name__)
|
24 |
|
src/dalle_mini/model/modeling.py
CHANGED
@@ -46,7 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
|
|
46 |
from transformers.utils import logging
|
47 |
|
48 |
from .configuration import DalleBartConfig
|
49 |
-
from .
|
50 |
|
51 |
logger = logging.get_logger(__name__)
|
52 |
|
|
|
46 |
from transformers.utils import logging
|
47 |
|
48 |
from .configuration import DalleBartConfig
|
49 |
+
from .utils import PretrainedFromWandbMixin
|
50 |
|
51 |
logger = logging.get_logger(__name__)
|
52 |
|
src/dalle_mini/model/tokenizer.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
from transformers import BartTokenizer
|
3 |
from transformers.utils import logging
|
4 |
|
5 |
-
from .
|
6 |
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
|
|
2 |
from transformers import BartTokenizer
|
3 |
from transformers.utils import logging
|
4 |
|
5 |
+
from .utils import PretrainedFromWandbMixin
|
6 |
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
src/dalle_mini/model/{wandb_pretrained.py → utils.py}
RENAMED
@@ -13,7 +13,10 @@ class PretrainedFromWandbMixin:
|
|
13 |
pretrained_model_name_or_path
|
14 |
):
|
15 |
# wandb artifact
|
16 |
-
|
|
|
|
|
|
|
17 |
pretrained_model_name_or_path = artifact.download()
|
18 |
|
19 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
|
|
13 |
pretrained_model_name_or_path
|
14 |
):
|
15 |
# wandb artifact
|
16 |
+
if wandb.run is not None:
|
17 |
+
artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
|
18 |
+
else:
|
19 |
+
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
20 |
pretrained_model_name_or_path = artifact.download()
|
21 |
|
22 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|