ydshieh HF staff commited on
Commit
3fac44f
1 Parent(s): 8df5e7a

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -5
pipeline.py CHANGED
@@ -18,6 +18,12 @@ class PreTrainedPipeline():
18
  max_length = 16
19
  num_beams = 4
20
  self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
 
 
 
 
 
21
 
22
  # compile the model
23
  image_path = os.path.join(path, 'val_000000039769.jpg')
@@ -25,11 +31,7 @@ class PreTrainedPipeline():
25
  self(image)
26
  image.close()
27
 
28
- @jax.jit
29
- def generate(self, pixel_values):
30
-
31
- output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
32
- return output_ids
33
 
34
  def __call__(self, inputs: "Image.Image") -> List[str]:
35
  """
18
  max_length = 16
19
  num_beams = 4
20
  self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
21
+
22
+ @jax.jit
23
+ def _generate(pixel_values):
24
+
25
+ output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
26
+ return output_ids
27
 
28
  # compile the model
29
  image_path = os.path.join(path, 'val_000000039769.jpg')
31
  self(image)
32
  image.close()
33
 
34
+ self.generate = _generate
 
 
 
 
35
 
36
  def __call__(self, inputs: "Image.Image") -> List[str]:
37
  """