Update pipeline.py
Browse files- pipeline.py +3 -2
pipeline.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
from PIL import Image
|
3 |
import jax
|
4 |
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
|
@@ -30,7 +31,7 @@ class PreTrainedPipeline():
|
|
30 |
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
|
31 |
return output_ids
|
32 |
|
33 |
-
def __call__(self, inputs: "Image.Image") -> List[
|
34 |
"""
|
35 |
Args:
|
36 |
Return:
|
@@ -42,4 +43,4 @@ class PreTrainedPipeline():
|
|
42 |
preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
43 |
preds = [pred.strip() for pred in preds]
|
44 |
|
45 |
-
return preds
|
|
|
1 |
import os
|
2 |
+
from typing import Dict, List, Any
|
3 |
from PIL import Image
|
4 |
import jax
|
5 |
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
|
|
|
31 |
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
|
32 |
return output_ids
|
33 |
|
34 |
+
def __call__(self, inputs: "Image.Image") -> List[str]:
|
35 |
"""
|
36 |
Args:
|
37 |
Return:
|
|
|
43 |
preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
44 |
preds = [pred.strip() for pred in preds]
|
45 |
|
46 |
+
return preds
|