ydshieh HF staff committed on
Commit
25a7779
1 Parent(s): 7bbd359

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +14 -6
pipeline.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  import jax
5
- from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
 
6
 
7
 
8
  class PreTrainedPipeline():
@@ -11,18 +12,24 @@ class PreTrainedPipeline():
11
 
12
  model_dir = path
13
 
14
- self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
 
15
  self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
16
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
17
 
18
  max_length = 16
19
  num_beams = 4
20
- self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
21
 
22
- @jax.jit
 
 
 
23
  def _generate(pixel_values):
24
 
25
- output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
 
26
  return output_ids
27
 
28
  self.generate = _generate
@@ -39,7 +46,8 @@ class PreTrainedPipeline():
39
  Return:
40
  """
41
 
42
- pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
 
43
 
44
  output_ids = self.generate(pixel_values)
45
  preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
2
  from typing import Dict, List, Any
3
  from PIL import Image
4
  import jax
5
+ from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel, VisionEncoderDecoderModel
6
+ import torch
7
 
8
 
9
  class PreTrainedPipeline():
 
12
 
13
  model_dir = path
14
 
15
+ # self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
16
+ self.model = VisionEncoderDecoderModel.from_pretrained(model_dir)
17
  self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
18
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
19
 
20
  max_length = 16
21
  num_beams = 4
22
+ # self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
23
+ self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "return_dict_in_generate": True}
24
 
25
+ self.model.to("cpu")
26
+ self.model.eval()
27
+
28
+ # @jax.jit
29
  def _generate(pixel_values):
30
 
31
+ with torch.no_grad():
32
+ output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
33
  return output_ids
34
 
35
  self.generate = _generate
 
46
  Return:
47
  """
48
 
49
+ # pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
50
+ pixel_values = self.feature_extractor(images=inputs, return_tensors="pt").pixel_values
51
 
52
  output_ids = self.generate(pixel_values)
53
  preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)