ydshieh committed on
Commit
dbe1403
1 Parent(s): 68f6bad
Files changed (1) hide show
  1. run_image_captioning_flax_reduced.py +18 -41
run_image_captioning_flax_reduced.py CHANGED
@@ -516,57 +516,34 @@ def main():
516
  decoder_dtype=getattr(jnp, model_args.dtype),
517
  )
518
 
519
- feature_extractor = None
520
  if model_args.feature_extractor_name:
521
  feature_extractor = AutoFeatureExtractor.from_pretrained(
522
  model_args.feature_extractor_name,
523
  cache_dir=model_args.cache_dir,
524
  )
525
- elif model_args.model_name_or_path:
526
- try:
527
- feature_extractor = AutoFeatureExtractor.from_pretrained(
528
- model_args.model_name_or_path, cache_dir=model_args.cache_dir
529
- )
530
- except ValueError as e:
531
- logger.warning(e)
532
- # Check encoder
533
- if not feature_extractor:
534
- if model_args.encoder_model_name_or_path:
535
- feature_extractor = AutoFeatureExtractor.from_pretrained(
536
- model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
537
- )
538
- else:
539
- raise ValueError(
540
- "You are instantiating a new feature extractor from scratch. This is not supported by this script."
541
- "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
542
- )
543
 
544
- tokenizer = None
545
  if model_args.tokenizer_name:
546
  tokenizer = AutoTokenizer.from_pretrained(
547
  model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
548
  )
549
- elif model_args.model_name_or_path:
550
- try:
551
- tokenizer = AutoTokenizer.from_pretrained(
552
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
553
- )
554
- except ValueError as e:
555
- logger.warning(e)
556
-
557
- # Check decoder
558
- if not tokenizer:
559
- if model_args.decoder_model_name_or_path:
560
- tokenizer = AutoTokenizer.from_pretrained(
561
- model_args.decoder_model_name_or_path,
562
- cache_dir=model_args.cache_dir,
563
- use_fast=model_args.use_fast_tokenizer,
564
- )
565
- else:
566
- raise ValueError(
567
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
568
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
569
- )
570
  tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
571
 
572
  # Preprocessing the datasets.
 
516
  decoder_dtype=getattr(jnp, model_args.dtype),
517
  )
518
 
 
519
  if model_args.feature_extractor_name:
520
  feature_extractor = AutoFeatureExtractor.from_pretrained(
521
  model_args.feature_extractor_name,
522
  cache_dir=model_args.cache_dir,
523
  )
524
+ elif model_args.encoder_model_name_or_path:
525
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
526
+ model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
527
+ )
528
+ else:
529
+ raise ValueError(
530
+ "You are instantiating a new feature extractor from scratch. This is not supported by this script."
531
+ "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
532
+ )
 
 
 
 
 
 
 
 
 
533
 
 
534
  if model_args.tokenizer_name:
535
  tokenizer = AutoTokenizer.from_pretrained(
536
  model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
537
  )
538
+ elif model_args.decoder_model_name_or_path:
539
+ tokenizer = AutoTokenizer.from_pretrained(
540
+ model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
541
+ )
542
+ else:
543
+ raise ValueError(
544
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
545
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
546
+ )
 
 
 
 
 
 
 
 
 
 
 
 
547
  tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
548
 
549
  # Preprocessing the datasets.