Pedro Cuenca commited on
Commit
55a631d
1 Parent(s): 5ec61cc

Store resolved path after loading model.

Browse files

This should be useful to load additional artifacts such as the
tokenizer, or the optimizer state.

I couldn't find a better location to place this information.

Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +4 -2
src/dalle_mini/model/modeling.py CHANGED
@@ -575,5 +575,7 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
575
  # we download everything, including opt_state, so we can resume training if needed
576
  # see also: #120
577
  pretrained_model_name_or_path = artifact.download()
578
-
579
- return super(DalleBart, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
 
 
575
  # we download everything, including opt_state, so we can resume training if needed
576
  # see also: #120
577
  pretrained_model_name_or_path = artifact.download()
578
+
579
+ model = super(DalleBart, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
580
+ model.config.resolved_name_or_path = pretrained_model_name_or_path
581
+ return model