Merge pull request #29 from borisdayma/load_checkpoint
seq2seq/run_seq2seq_flax.py CHANGED (+44 -23)
@@ -125,6 +125,12 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
+        },
+    )
 
 
 @dataclass
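Note: the new `from_checkpoint` field surfaces as a `--from_checkpoint` CLI flag through the `HfArgumentParser` the script already uses. A minimal sketch of that behavior, with the dataclass trimmed down to just the new field and an illustrative artifact reference:

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class ModelArguments:
        # Only the field added by this PR; the real dataclass has more fields.
        from_checkpoint: Optional[str] = field(
            default=None,
            metadata={"help": "Loads a pretrained wandb checkpoint. Use artifact reference."},
        )

    parser = HfArgumentParser((ModelArguments,))
    # Equivalent to: python run_seq2seq_flax.py --from_checkpoint user/project/checkpoint:latest
    (model_args,) = parser.parse_args_into_dataclasses(
        args=["--from_checkpoint", "user/project/checkpoint:latest"]
    )
    assert model_args.from_checkpoint == "user/project/checkpoint:latest"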
@@ -424,36 +430,51 @@ def main():
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Load pretrained model and tokenizer
-    base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
     )
 
-    # Set up our new model config
-    config = BartConfig.from_pretrained(model_args.model_name_or_path)
-    config.tie_word_embeddings = False
-    config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-    config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
-    config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
-    config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
-    config.forced_bos_token_id = None  # we don't need this token
-    config.forced_eos_token_id = None  # we don't need this token
-    config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
-    config.min_length = data_args.max_target_length
-    config.max_length = data_args.max_target_length
+    if model_args.from_checkpoint is not None:
+        artifact = wandb.run.use_artifact(model_args.from_checkpoint)
+        artifact_dir = artifact.download()
+        model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
-    # Create a custom model and initialize it randomly
-    model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+        # some models will try to change bos (because of force_bos_token_to_be_generated)
+        # we ensure bos and eos are not forced
+        model.config.force_bos_token_to_be_generated = False
+        model.config.forced_bos_token_id = None
+        model.config.forced_eos_token_id = None
+
+        # used in the preprocessing function
+        config = model.config
 
-    print(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
+    else:
+        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+            model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+        # Set up our new model config
+        config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        config.tie_word_embeddings = False
+        config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
+        config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
+        config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
+        config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
+        config.forced_bos_token_id = None  # we don't need this token
+        config.forced_eos_token_id = None  # we don't need this token
+        config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
+        config.min_length = data_args.max_target_length
+        config.max_length = data_args.max_target_length
+
+        # Create a custom model and initialize it randomly
+        model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+
+        # Use pre-trained weights for encoder
+        model.params['model']['encoder'] = base_model.params['model']['encoder']
+        model.params['model']['shared'] = base_model.params['model']['shared']
+        del base_model
 
-    # Use pre-trained weights for encoder
-    model.params['model']['encoder'] = base_model.params['model']['encoder']
-    model.params['model']['shared'] = base_model.params['model']['shared']
-    del base_model
+    print(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
 
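Note: the checkpoint branch above relies on the standard wandb artifact API. A self-contained sketch of that flow, assuming the run is initialized up front (the training script calls `wandb.init()` elsewhere) and using an illustrative artifact reference:

    import wandb

    # The training script initializes the run elsewhere; shown here for completeness.
    run = wandb.init(project="demo")  # illustrative project name

    # Resolve the reference passed via --from_checkpoint and fetch the files.
    artifact = run.use_artifact("user/project/checkpoint:latest")  # illustrative reference
    artifact_dir = artifact.download()  # returns the local directory containing the checkpoint

    # The script then restores the model from that directory, e.g.:
    # model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
    run.finish()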
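Note: the weight transplant in the else branch works because a Flax model's `params` is a plain nested dict of arrays, so whole sub-trees can be reassigned in place. A toy illustration of the same pattern (keys and shapes are made up):

    import jax.numpy as jnp

    # Stand-ins for a randomly initialized model and a pretrained one.
    random_params = {"model": {"shared": jnp.zeros((4, 8)), "encoder": {"w": jnp.zeros((8, 8))}}}
    pretrained_params = {"model": {"shared": jnp.ones((4, 8)), "encoder": {"w": jnp.ones((8, 8))}}}

    # Reuse the pretrained encoder and shared embeddings; keep everything else random.
    random_params["model"]["encoder"] = pretrained_params["model"]["encoder"]
    random_params["model"]["shared"] = pretrained_params["model"]["shared"]
    del pretrained_params  # as in the PR: drop the base model once the weights are copied

    assert bool((random_params["model"]["shared"] == 1).all())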