boris commited on
Commit
61c93f2
1 Parent(s): b257ca8

fix: update model name

Browse files
Files changed (1) hide show
  1. tools/train/train.py +1 -1
tools/train/train.py CHANGED
@@ -398,7 +398,7 @@ def main():
398
  artifact_dir = artifact.download()
399
 
400
  # load model
401
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
402
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
403
  print(model.params)
404
 
 
398
  artifact_dir = artifact.download()
399
 
400
  # load model
401
+ model = DalleBart.from_pretrained(artifact_dir)
402
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
403
  print(model.params)
404