Gabriel committed
Commit e706c52
1 Parent(s): 26a9d66

Create handler.py

Files changed (1)
  1. handler.py +66 -0
handler.py ADDED
@@ -0,0 +1,66 @@
+ from typing import Dict, Any
+ from transformers import QwenImageProcessor, QwenTokenizer, QwenForMultiModalConditionalGeneration
+ import torch
+ from PIL import Image
+ import io
+ import base64
+ import requests
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         # Use half precision on GPU, full precision on CPU.
+         self.model = QwenForMultiModalConditionalGeneration.from_pretrained(
+             path,
+             torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
+         ).to(self.device)
+         self.image_processor = QwenImageProcessor.from_pretrained(path)
+         self.tokenizer = QwenTokenizer.from_pretrained(path)
+         # Disable the key/value cache during generation.
+         self.model.generation_config.use_cache = False
+
+     def __call__(self, data: Any) -> Dict[str, Any]:
+         """
+         Args:
+             data (Any): The input data, which can be:
+                 - Binary image data in the request body.
+                 - A dictionary with 'image' and 'text' keys:
+                     - 'image': Base64-encoded image string or an image URL.
+                     - 'text': The text prompt.
+
+         Returns:
+             Dict[str, Any]: The generated text output from the model.
+         """
+         default_prompt = "<|im_start|>user\nDescribe this image.\n<|im_end|><|im_start|>assistant\n"
+         if isinstance(data, (bytes, bytearray)):
+             # Raw image bytes in the request body; fall back to the default prompt.
+             image = Image.open(io.BytesIO(data)).convert('RGB')
+             text_input = default_prompt
+         elif isinstance(data, dict):
+             image_input = data.get('image')
+             text_input = data.get('text', '')
+             if image_input is None:
+                 return {"error": "No image provided."}
+             if image_input.startswith('http'):
+                 # Fetch the image from a URL.
+                 response = requests.get(image_input, timeout=10)
+                 image = Image.open(io.BytesIO(response.content)).convert('RGB')
+             else:
+                 # Treat the string as a base64-encoded image.
+                 image_data = base64.b64decode(image_input)
+                 image = Image.open(io.BytesIO(image_data)).convert('RGB')
+         else:
+             return {"error": "Invalid input data. Expected binary image data or a dictionary with an 'image' key."}
+
+         image_inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
+
+         if not text_input:
+             text_input = default_prompt
+         input_ids = self.tokenizer(text_input, return_tensors="pt").input_ids.to(self.device)
+
+         generated_ids = self.model.generate(
+             **image_inputs,
+             input_ids=input_ids,
+             max_new_tokens=256,
+             do_sample=True,
+             top_p=0.9,
+             temperature=0.7,
+         )
+         # Decode only the newly generated tokens, skipping the echoed prompt.
+         output_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+         return {"generated_text": output_text}
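For reference, a minimal local smoke test of the handler might look like the sketch below. It assumes the model repository has been cloned locally and that an example.jpg file sits alongside it; both names are placeholders, and the payload shape simply follows the docstring above.

import base64
from handler import EndpointHandler

# Point the handler at a local clone of the model repository (placeholder path).
handler = EndpointHandler(path=".")

# Base64-encode a local image, matching the 'image' key the handler expects.
with open("example.jpg", "rb") as f:  # hypothetical test image
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image": image_b64,
    "text": "<|im_start|>user\nWhat is in this photo?\n<|im_end|><|im_start|>assistant\n",
}
print(handler(payload))

When deployed as a custom Inference Endpoints handler, the same dictionary would be sent as the JSON request body.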