Gabriel committed on
Commit 975fa91
1 Parent(s): 66b7b58

Update handler.py

Files changed (1)
  1. handler.py +47 -22
handler.py CHANGED
@@ -1,22 +1,27 @@
 from typing import Dict, Any
-from transformers import QwenImageProcessor, QwenTokenizer, QwenForMultiModalConditionalGeneration
 import torch
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from PIL import Image
 import io
-import json
 import base64
 import requests
+from qwen_vl_utils import process_vision_info

 class EndpointHandler():
     def __init__(self, path=""):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = QwenForMultiModalConditionalGeneration.from_pretrained(
+        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             path,
-            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
+            torch_dtype="auto",
+            device_map="auto"
         ).to(self.device)
-        self.image_processor = QwenImageProcessor.from_pretrained(path)
-        self.tokenizer = QwenTokenizer.from_pretrained(path)
-        self.model.generation_config.use_cache = False
+
+        self.processor = AutoProcessor.from_pretrained(path)
+
+        # Optionally, adjust min_pixels and max_pixels if needed
+        # min_pixels = 256*28*28
+        # max_pixels = 1280*28*28
+        # self.processor = AutoProcessor.from_pretrained(path, min_pixels=min_pixels, max_pixels=max_pixels)

     def __call__(self, data: Any) -> Dict[str, Any]:
         """
@@ -30,12 +35,14 @@ class EndpointHandler():
         Returns:
             Dict[str, Any]: The generated text output from the model.
         """
+        default_prompt = "Describe this image."
+
         if isinstance(data, (bytes, bytearray)):
             image = Image.open(io.BytesIO(data)).convert('RGB')
-            text_input = "<|im_start|>user\nDescribe this image.\n<|im_end|><|im_start|>assistant\n"
+            text_input = default_prompt
         elif isinstance(data, dict):
             image_input = data.get('image', None)
-            text_input = data.get('text', '')
+            text_input = data.get('text', default_prompt)
             if image_input is None:
                 return {"error": "No image provided."}
             if image_input.startswith('http'):
@@ -47,20 +54,38 @@ class EndpointHandler():
         else:
             return {"error": "Invalid input data. Expected binary image data or a dictionary with 'image' key."}

-        image_inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image,
+                    },
+                    {"type": "text", "text": text_input},
+                ],
+            }
+        ]

-        if not text_input:
-            text_input = "<|im_start|>user\nDescribe this image.\n<|im_end|><|im_start|>assistant\n"
-        input_ids = self.tokenizer(text_input, return_tensors="pt").input_ids.to(self.device)
+        text = self.processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(self.device)

-        generated_ids = self.model.generate(
-            **image_inputs,
-            input_ids=input_ids,
-            max_new_tokens=256,
-            do_sample=True,
-            top_p=0.9,
-            temperature=0.7,
+        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-        output_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

-        return {"generated_text": output_text}
+        return {"generated_text": output_text[0]}