boris committed
Commit 5f954fc
Parent: 4cb21dd

feat: restore weights on CPU

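The override below makes `from_pretrained` deserialize checkpoint weights into host (CPU) memory instead of pushing them to accelerator devices. A minimal usage sketch, not part of this commit: the `DalleBart` import name and the checkpoint path are assumptions, and the keyword arguments simply mirror what the class docstring documents.

```python
# Hedged sketch: `DalleBart` and the checkpoint path are placeholders, not part
# of this commit. The point is that the custom from_pretrained restores weights
# as host-side arrays, which can be sharded onto devices explicitly later.
import jax

from dalle_mini.model.modeling import DalleBart  # assumed export name

model = DalleBart.from_pretrained(
    "path/to/checkpoint",  # placeholder: local dir or hub model id
    abstract_init=True,    # init from parameter shapes only (per class docstring)
    load_on_cpu=True,      # run the random init on CPU (per class docstring)
)

# Move / shard parameters onto an accelerator only when actually needed.
params_on_device = jax.device_put(model.params, jax.devices()[0])
```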
src/dalle_mini/model/modeling.py CHANGED
@@ -15,16 +15,30 @@
 """ DalleBart model. """
 
 import math
+import os
 from functools import partial
-from typing import Optional, Tuple
+from pickle import UnpicklingError
+from typing import Optional, Tuple, Union
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import msgpack.exceptions
 from flax.core.frozen_dict import unfreeze
 from flax.linen import make_causal_mask
-from flax.traverse_util import flatten_dict
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
 from jax.random import PRNGKey
+from transformers.configuration_utils import PretrainedConfig
+from transformers.file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from transformers.modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
@@ -300,7 +314,8 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
    - config_class replaced to DalleBartConfig
     - __init__ accepts abstract_init which does uses parameter shape to initialize the model
-    - init weights on CPU
+    - init weights on CPU with `load_on_cpu`
+    - restore weights on CPU with custom `from_pretrained`
     """
 
     config_class = DalleBartConfig
@@ -359,6 +374,243 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         ).values()
         return sum(list(num_params))
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        dtype: jnp.dtype = jnp.float32,
+        *model_args,
+        **kwargs,
+    ):
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        from_pt = kwargs.pop("from_pt", False)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {
+            "file_type": "model",
+            "framework": "flax",
+            "from_auto_class": from_auto_class,
+        }
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = (
+                config if config is not None else pretrained_model_name_or_path
+            )
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs
+
+        # Add the dtype to model_kwargs
+        model_kwargs["dtype"] = dtype
+
+        # Load model
+        if pretrained_model_name_or_path is not None:
+            if os.path.isdir(pretrained_model_name_or_path):
+                if from_pt and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                ):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, WEIGHTS_NAME
+                    )
+                elif os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+                ):
+                    # Load from a Flax checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
+                    )
+                else:
+                    raise EnvironmentError(
+                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
+                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
+                pretrained_model_name_or_path
+            ):
+                archive_file = pretrained_model_name_or_path
+            else:
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
+                    revision=revision,
+                )
+
+            # redirect to the cache, if necessary
+            try:
+                resolved_archive_file = cached_path(
+                    archive_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
+                    f"  (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+                )
+                raise EnvironmentError(msg)
+
+            if resolved_archive_file == archive_file:
+                logger.info(f"loading weights file {archive_file}")
+            else:
+                logger.info(
+                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
+                )
+        else:
+            resolved_archive_file = None
+
+        # init random models
+        model = cls(config, *model_args, **model_kwargs)
+
+        with open(resolved_archive_file, "rb") as state_f:
+            try:
+                state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                try:
+                    with open(resolved_archive_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please install "
+                                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                                "you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(
+                        f"Unable to convert {archive_file} to Flax deserializable object. "
+                    )
+
+        # if model is base model only use model_prefix key
+        if (
+            cls.base_model_prefix not in dict(model.params)
+            and cls.base_model_prefix in state
+        ):
+            state = state[cls.base_model_prefix]
+
+        # if model is head model and we are loading weights from base model
+        # we initialize new params dict with base_model_prefix
+        if (
+            cls.base_model_prefix in dict(model.params)
+            and cls.base_model_prefix not in state
+        ):
+            state = {cls.base_model_prefix: state}
+
+        # flatten dicts
+        state = flatten_dict(state)
+
+        random_state = flatten_dict(unfreeze(model.params))
+
+        missing_keys = model.required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - model.required_params
+
+        # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append(
+                        (key, state[key].shape, random_state[key].shape)
+                    )
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+                        "model."
+                    )
+
+        # add missing keys as random parameters
+        for missing_key in missing_keys:
+            state[missing_key] = random_state[missing_key]
+
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(
+                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
+            )
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+
+        # set correct parameters
+        model.params = unflatten_dict(state)
+
+        return model
+
 
 class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     """
tools/train/train.py CHANGED
@@ -249,6 +249,9 @@ class TrainingArguments:
             "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
+    gradient_checkpointing: bool = field(
+        default=False, metadata={"help": "Use gradient checkpointing."}
+    )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
@@ -515,10 +518,8 @@ def main():
         load_on_cpu=True,
     )
 
-    # Load tokenizer
-    tokenizer = DalleBartTokenizer.from_pretrained(
-        model_args.tokenizer_name, use_fast=True
-    )
+    # update model config per training args
+    model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
@@ -526,14 +527,15 @@
     # convert params to frozen dict
     model._params = freeze(model.params)
 
+    # Load tokenizer
+    tokenizer = DalleBartTokenizer.from_pretrained(
+        model_args.tokenizer_name, use_fast=True
+    )
+
    # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
-
     dataset.preprocess(tokenizer=tokenizer, config=model.config)
 
-    # no dropout (hardcoded)
-    model.config.dropout = 0.0
-
     # Initialize our training
     dropout_rng = jax.random.PRNGKey(training_args.seed_model)
 
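Not part of this diff, but for illustration of what the new `gradient_checkpointing` flag typically feeds into: in Flax, such a config flag is commonly consumed by wrapping a submodule with `nn.remat`, so activations are recomputed during the backward pass instead of being stored. A minimal sketch under that assumption; the module names below are hypothetical and not taken from dalle-mini.

```python
# Hypothetical sketch of how a `gradient_checkpointing` flag is commonly consumed
# in Flax: wrap a layer class with nn.remat so its activations are rematerialized
# in the backward pass instead of being kept in memory.
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.dim * 4)(x)
        x = nn.gelu(x)
        return nn.Dense(self.dim)(x)


class Block(nn.Module):
    dim: int
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # choose the rematerialized version of the layer when the flag is set
        mlp_cls = nn.remat(MLP) if self.gradient_checkpointing else MLP
        return x + mlp_cls(self.dim)(x)


x = jnp.ones((1, 8))
params = Block(dim=8, gradient_checkpointing=True).init(jax.random.PRNGKey(0), x)
```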