import base64
import io
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Qwen2-VL model and its processor from the repository path
        # that the Inference Endpoints runtime passes in, moving the model to
        # GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(path).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Any) -> Dict[str, Any]:
        default_prompt = "Describe this image."

        # Accept either raw image bytes or a dict with an 'image' field
        # (URL, base64 string, or raw bytes) and an optional 'text' prompt.
        if isinstance(data, (bytes, bytearray)):
            image = Image.open(io.BytesIO(data)).convert("RGB")
            text_input = default_prompt
        elif isinstance(data, dict):
            image_input = data.get("image")
            text_input = data.get("text", default_prompt)
            if image_input is None:
                return {"error": "No image provided."}
            if isinstance(image_input, (bytes, bytearray)):
                image = Image.open(io.BytesIO(image_input)).convert("RGB")
            elif image_input.startswith("http"):
                image = Image.open(requests.get(image_input, stream=True).raw).convert("RGB")
            else:
                # Fall back to treating the value as a base64-encoded image.
                image_data = base64.b64decode(image_input)
                image = Image.open(io.BytesIO(image_data)).convert("RGB")
        else:
            return {"error": "Invalid input. Expected binary image data or a dictionary with an 'image' key."}

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_input},
                ],
            }
        ]

        # Render the chat template to a prompt string, then build the combined
        # text + vision tensors in a single processor call.
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Pass every processor output (input_ids, attention_mask, pixel_values,
        # image_grid_thw, ...) to generate; forwarding input_ids alone would
        # drop the image features. max_new_tokens caps only the generated
        # answer (128 is an arbitrary budget), unlike max_length, which the
        # long multimodal prompt would exhaust on its own.
        generated_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_ids = generated_ids[:, inputs.input_ids.shape[1]:]
        output_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return {"generated_text": output_text}