ydshieh HF staff commited on
Commit
3077814
1 Parent(s): fdc844c

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +33 -29
pipeline.py CHANGED
@@ -1,40 +1,44 @@
1
- from typing import Dict, List, Any
2
- from PIL import Image
3
-
4
  import os
5
- import json
6
- import numpy as np
7
- from fastai.learner import load_learner
8
 
9
- from helpers import is_cat
10
 
11
  class PreTrainedPipeline():
 
12
  def __init__(self, path=""):
13
- # IMPLEMENT_THIS
14
- # Preload all the elements you are going to need at inference.
15
- # For instance your model, processors, tokenizer that might be needed.
16
- # This function is only called once, so do all the heavy processing I/O here"""
17
- self.model = load_learner(os.path.join(path, "model.pkl"))
18
- with open(os.path.join(path, "config.json")) as config:
19
- config = json.load(config)
20
- self.id2label = config["id2label"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
23
  """
24
  Args:
25
- inputs (:obj:`PIL.Image`):
26
- The raw image representation as PIL.
27
- No transformation made whatsoever from the input. Make all necessary transformations here.
28
  Return:
29
- A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
30
- It is preferred if the returned list is in decreasing `score` order
31
  """
32
- # IMPLEMENT_THIS
33
- # FastAI expects a np array, not a PIL Image.
34
- _, _, preds = self.model.predict(np.array(inputs))
35
- preds = preds.tolist()
36
- labels = [
37
- {"label": str(self.id2label["0"]), "score": preds[0]},
38
- {"label": str(self.id2label["1"]), "score": preds[1]},
39
- ]
40
- return labels
 
 
 
1
  import os
2
+ from PIL import Image
3
+ from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
 
4
 
 
5
 
6
  class PreTrainedPipeline():
7
+
8
  def __init__(self, path=""):
9
+
10
+ model_dir = os.path.join(path, "ckpt_epoch_3_step_6900")
11
+
12
+ self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
13
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
14
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
15
+
16
+ max_length = 16
17
+ num_beams = 4
18
+ self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
19
+
20
+ # compile the model
21
+ image_path = os.path.join(path, 'val_000000039769.jpg')
22
+ image = Image.open(image_path)
23
+ self(image)
24
+ image.close()
25
+
26
+ @jax.jit
27
+ def generate(self, pixel_values):
28
+
29
+ output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
30
+ return output_ids
31
 
32
  def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
33
  """
34
  Args:
 
 
 
35
  Return:
 
 
36
  """
37
+
38
+ pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
39
+
40
+ output_ids = self.generate(pixel_values)
41
+ preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
42
+ preds = [pred.strip() for pred in preds]
43
+
44
+ return preds[0]