nbroad (HF staff) committed on
Commit 73353d8
1 Parent(s): f1b11f2
Files changed (1)
  1. handler.py +38 -21
handler.py CHANGED
@@ -1,56 +1,73 @@
 import base64
 from io import BytesIO
-from typing import Dict, List, Any
+from typing import Dict, List, Any
 from transformers import Pix2StructForConditionalGeneration, AutoProcessor
 from PIL import Image
 import torch
 
-class EndpointHandler():
+
+class EndpointHandler:
+    """
+    A basic handler for a single GPU in Inference Endpoints.
+
+
+    Should not be used on multiple GPUs or on CPU.
+    """
 
     def __init__(self, *args, **kwargs):
 
         model_name = "google/pix2struct-infographics-vqa-large"
 
+        """
+        dtype tradeoffs:
+        - float16: works on T4, may have slight worse quality generations
+        - bfloat16: doesn't work on T4 (works on A10), better quality generation
+        - float32: works on all GPUs, best quality generation, 30-40% slower
+        """
+        self.dtype = torch.float16
 
-        self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
+        self.model = Pix2StructForConditionalGeneration.from_pretrained(
+            model_name,
+            device_map="cuda:0",
+            torch_dtype=self.dtype,
+        )
         self.processor = AutoProcessor.from_pretrained(model_name)
-        self.text_prompt = None #
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.model.to(self.device)
 
+        self.device = torch.device("cuda")
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         """
+        Can pass a list of images or a single image.
+
         Args:
             data (:obj:):
                 includes the input data and the parameters for the inference.
         Return:
-            a dictionary with the output of the model. The only key is `output` and the
-            value is a list of str.
+            a dictionary with the output of the model. The only key is `output` and the
+            value is a list of str.
         """
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
         if isinstance(inputs["image"], list):
-            img = [Image.open(BytesIO(base64.b64decode(img))) for img in inputs['image']]
+            img = [
+                Image.open(BytesIO(base64.b64decode(img))) for img in inputs["image"]
+            ]
         else:
-            img = Image.open(BytesIO(base64.b64decode(inputs['image'])))
+            img = Image.open(BytesIO(base64.b64decode(inputs["image"])))
 
-        question = inputs['question']
+        question = inputs["question"]
 
-
-
         with torch.inference_mode():
-            model_inputs = self.processor(images=img, text=question, return_tensors="pt").to(self.device)
+            model_inputs = self.processor(
+                images=img, text=question, return_tensors="pt"
+            ).to(self.device, dtype=self.dtype)
 
             raw_output = self.model.generate(**model_inputs, **parameters)
 
-            decoded_output = self.processor.batch_decode(raw_output, skip_special_tokens=True)
-
+            decoded_output = self.processor.batch_decode(
+                raw_output, skip_special_tokens=True
+            )
 
         # postprocess the prediction
-        return {
-            "output": decoded_output
-        }
+        return {"output": decoded_output}
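Note on the dtype tradeoff documented in the new `__init__`: rather than hardcoding `torch.float16`, the dtype could be picked at runtime based on what the attached GPU supports. A minimal sketch, not part of this commit, assuming a CUDA device is available (as the handler's class docstring already requires):

import torch

def pick_dtype() -> torch.dtype:
    # bfloat16 needs a newer GPU (e.g. A10); a T4 has to fall back to float16.
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

The commit instead hardcodes float16, which keeps the handler deployable on T4 instances at the cost of the slightly lower generation quality the comment mentions.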
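For reference, the `__call__` contract implies a JSON payload with a base64-encoded image (or a list of them) and a question under "inputs", plus optional generate() kwargs under "parameters"; the response is {"output": [str, ...]}. A client-side sketch; the endpoint URL, token, and file name below are placeholders for your own deployment, not values from this commit:

import base64
import requests

# Placeholder values; substitute your own Inference Endpoint URL and HF token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

# Base64-encode the image exactly as the handler expects to decode it.
with open("infographic.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "image": image_b64,  # or a list of base64-encoded images
        "question": "What is the largest value in the chart?",
    },
    "parameters": {"max_new_tokens": 64},  # forwarded to model.generate()
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
print(response.json())  # {"output": ["..."]}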