Update pipeline.py
Browse files- pipeline.py +7 -5
pipeline.py
CHANGED
@@ -18,6 +18,12 @@ class PreTrainedPipeline():
|
|
18 |
max_length = 16
|
19 |
num_beams = 4
|
20 |
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# compile the model
|
23 |
image_path = os.path.join(path, 'val_000000039769.jpg')
|
@@ -25,11 +31,7 @@ class PreTrainedPipeline():
|
|
25 |
self(image)
|
26 |
image.close()
|
27 |
|
28 |
-
|
29 |
-
def generate(self, pixel_values):
|
30 |
-
|
31 |
-
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
|
32 |
-
return output_ids
|
33 |
|
34 |
def __call__(self, inputs: "Image.Image") -> List[str]:
|
35 |
"""
|
18 |
max_length = 16
|
19 |
num_beams = 4
|
20 |
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
21 |
+
|
22 |
+
@jax.jit
|
23 |
+
def _generate(pixel_values):
|
24 |
+
|
25 |
+
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
|
26 |
+
return output_ids
|
27 |
|
28 |
# compile the model
|
29 |
image_path = os.path.join(path, 'val_000000039769.jpg')
|
31 |
self(image)
|
32 |
image.close()
|
33 |
|
34 |
+
self.generate = _generate
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def __call__(self, inputs: "Image.Image") -> List[str]:
|
37 |
"""
|