ydshieh HF staff commited on
Commit
8df5e7a
1 Parent(s): c67db7a

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +3 -2
pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from PIL import Image
3
  import jax
4
  from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
@@ -30,7 +31,7 @@ class PreTrainedPipeline():
30
  output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
31
  return output_ids
32
 
33
- def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
34
  """
35
  Args:
36
  Return:
@@ -42,4 +43,4 @@ class PreTrainedPipeline():
42
  preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
43
  preds = [pred.strip() for pred in preds]
44
 
45
- return preds[0]
 
1
  import os
2
+ from typing import Dict, List, Any
3
  from PIL import Image
4
  import jax
5
  from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
 
31
  output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
32
  return output_ids
33
 
34
+ def __call__(self, inputs: "Image.Image") -> List[str]:
35
  """
36
  Args:
37
  Return:
 
43
  preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
44
  preds = [pred.strip() for pred in preds]
45
 
46
+ return preds