Merge pull request #109 from borisdayma/feat-model_pretrained
dev/seq2seq/run_seq2seq_flax.py  (+56 -19)
@@ -68,6 +68,12 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
+    config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
+    )
     image_vocab_size: Optional[int] = field(
         default=None,
         metadata={"help": "Vocab size of image encoder"},
@@ -82,9 +88,11 @@ class ModelArguments:
             "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
         },
     )
-    normalize_text: bool = field(
-        default=…,
-        metadata={…},
+    normalize_text: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
+        },
     )
     dtype: Optional[str] = field(
         default="float32",
@@ -125,8 +133,9 @@ class DataTrainingArguments:
             "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
         },
     )
+    # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: bool = field(
-        default=…,
+        default=True,
         metadata={"help": "Whether to stream the dataset."},
     )
     use_auth_token: bool = field(
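
For context on the new streaming default: a minimal sketch (not part of this diff; the dataset file name is a placeholder) of how the datasets library streams data, so examples are yielded lazily instead of requiring a full download before training.

from itertools import islice
from datasets import load_dataset

# streaming=True returns an IterableDataset that yields examples lazily,
# matching the new default of the `streaming` argument above.
dataset = load_dataset(
    "json",
    data_files={"train": "train.jsonl"},  # placeholder path
    streaming=True,
)
for example in islice(dataset["train"], 2):  # peek at a couple of examples
    print(example)
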
@@ -283,9 +292,9 @@ class TrainingArguments:
         },
     )
 
-    …: Optional[str] = field(
+    resume_from_checkpoint: Optional[str] = field(
         default=None,
-        metadata={"help": "…"},
+        metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
 
@@ -460,12 +469,14 @@ def main():
         config=parser.parse_args(),
     )
 
-    if training_args.…:
-        artifact = wandb.run.use_artifact(training_args.…)
+    if training_args.resume_from_checkpoint is not None:
+        artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
         artifact_dir = artifact.download()
 
         # load model
         model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
+        print(model.params)
 
         # load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
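
For reference, a hedged sketch of the wandb artifact calls used in this hunk, shown outside the training script; the artifact reference string is a placeholder in the usual entity/project/name:alias format. The same reference would be passed on the command line as --resume_from_checkpoint.

import wandb

run = wandb.init(project="dalle-mini")  # the script already holds an active run
artifact = run.use_artifact("my-entity/dalle-mini/model-run:latest")  # placeholder reference
artifact_dir = artifact.download()  # downloads the artifact files, returns the local directory
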
@@ -476,9 +487,20 @@ def main():
     else:
         # Set up our new model config
         # TODO: simplify with custom config class
-        …
-        …
-        …
+        if model_args.config_name:
+            config = BartConfig.from_pretrained(model_args.config_name)
+        else:
+            config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        if model_args.image_vocab_size:
+            config.image_vocab_size = model_args.image_vocab_size
+        assert (
+            getattr(config, "image_vocab_size") is not None
+        ), "image_vocab_size must be specified when not present in base model/config"
+        if model_args.image_length:
+            config.image_length = model_args.image_length
+        assert (
+            getattr(config, "image_length") is not None
+        ), "image_length must be specified when not present in base model/config"
         # we append decoder bos to image vocab
         config.decoder_start_token_id = config.image_vocab_size
         # ensure we don't generate bos (in addition to decoder start token)
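
A small sketch of why the asserts above are sufficient (values and checkpoint name are illustrative, not from this diff): extra attributes such as image_vocab_size and image_length set on a BartConfig are kept in its dict and survive save/load, so a base model's config may already carry them and only needs overriding when absent.

from transformers import BartConfig

config = BartConfig.from_pretrained("facebook/bart-large")  # illustrative base config
config.image_vocab_size = 16384  # illustrative VQGAN codebook size
config.image_length = 256        # illustrative number of image tokens
config.save_pretrained("tmp_config")                # extra attributes are written to config.json
restored = BartConfig.from_pretrained("tmp_config")
assert restored.image_vocab_size == 16384 and restored.image_length == 256
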
@@ -487,8 +509,8 @@ def main():
         config.forced_eos_token_id = None  # we don't need this token
 
         config.tie_word_embeddings = False
-        config.min_length = …
-        config.max_length = …
+        config.min_length = config.image_length + 1
+        config.max_length = config.image_length + 1
 
         # below tokens need to be set to avoid error during generation (converted to jnp.array)
         # they are not expected to be used and are set to unreachable token id
@@ -497,12 +519,27 @@ def main():
         config.eos_token_id = config.image_vocab_size + 1
 
         # save whether we normalize the text
-        …
+        if model_args.normalize_text is not None:
+            config.normalize_text = model_args.normalize_text
+        else:
+            config.normalize_text = getattr(config, "normalize_text", False)
 
-        # …
-        …
-        …
-        …
+        # Load or create new model
+        if model_args.model_name_or_path:
+            model = CustomFlaxBartForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                seed=training_args.seed_model,
+                dtype=getattr(jnp, model_args.dtype),
+            )
+            # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
+            print(model.params)
+        else:
+            model = CustomFlaxBartForConditionalGeneration(
+                config,
+                seed=training_args.seed_model,
+                dtype=getattr(jnp, model_args.dtype),
+            )
 
         # Load tokenizer
         if model_args.tokenizer_name is not None:
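
A minimal sketch of the dtype plumbing used in the model construction above (assumes jax is installed): ModelArguments.dtype is a string, and getattr(jnp, ...) resolves it to the dtype object handed to the Flax module.

import jax.numpy as jnp

for name in ("float32", "float16", "bfloat16"):
    print(name, "->", getattr(jnp, name))  # e.g. bfloat16 enables half-precision weights on TPU
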
@@ -741,7 +778,7 @@ def main():
         tx=optimizer,
         dropout_rng=dropout_rng,
     )
-    if training_args.…:
+    if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
         # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
         state = state.restore_state(artifact_dir)
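
state.restore_state comes from the repo's custom TrainState and is not shown in this diff. As a hedged illustration only, a restore helper along these lines could rebuild the optimizer state from a downloaded artifact with flax.serialization (the checkpoint file name is hypothetical).

import os
from flax import serialization

def restore_opt_state(state, artifact_dir):
    # hypothetical checkpoint file name; the real layout is defined elsewhere in the repo
    with open(os.path.join(artifact_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = serialization.from_bytes(state.opt_state, f.read())
    return state.replace(opt_state=opt_state)
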