ydshieh committed on
Commit
a722530
1 Parent(s): 3c34a2e

upload coco summary script

Browse files
Files changed (1) hide show
  1. run_summarization_coco.py +11 -5
run_summarization_coco.py CHANGED
@@ -37,6 +37,7 @@ import nltk # Here to have a nice missing dependency error message early on
37
  import numpy as np
38
  from datasets import Dataset, load_dataset, load_metric
39
  from tqdm import tqdm
 
40
 
41
  import jax
42
  import jax.numpy as jnp
@@ -418,19 +419,24 @@ def main():
418
 
419
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
420
  def preprocess_function(examples):
421
-
422
  _pixel_values = []
423
- for y in examples[image_file_column]:
 
424
  with Image.open(y) as image:
425
- encoder_inputs = feature_extractor(images=image, return_tensors="np")
 
 
 
426
  x = encoder_inputs.pixel_values
427
  _pixel_values.append(x)
 
428
  pixel_values = np.concatenate(_pixel_values)
429
 
430
- targets = examples[caption_column]
431
 
432
  # Add eos_token!!
433
- targets = [x.lower() + ' ' + tokenizer.eos_token for x in targets]
434
 
435
  model_inputs = {}
436
  model_inputs['pixel_values'] = pixel_values
 
37
  import numpy as np
38
  from datasets import Dataset, load_dataset, load_metric
39
  from tqdm import tqdm
40
+ from PIL import Image
41
 
42
  import jax
43
  import jax.numpy as jnp
 
419
 
420
  # Setting padding="max_length" as we need fixed length inputs for jitted functions
421
  def preprocess_function(examples):
422
+
423
  _pixel_values = []
424
+ _captions = []
425
+ for y, z in zip(examples[image_file_column], examples[caption_column]):
426
  with Image.open(y) as image:
427
+ try:
428
+ encoder_inputs = feature_extractor(images=image, return_tensors="np")
429
+ except:
430
+ continue
431
  x = encoder_inputs.pixel_values
432
  _pixel_values.append(x)
433
+ _captions.append(z + ' ' + tokenizer.eos_token)
434
  pixel_values = np.concatenate(_pixel_values)
435
 
436
+ targets = _captions
437
 
438
  # Add eos_token!!
439
+ #targets = [x + ' ' + tokenizer.eos_token for x in targets]
440
 
441
  model_inputs = {}
442
  model_inputs['pixel_values'] = pixel_values