import base64
from io import BytesIO
from typing import Dict, Any

import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the BLIP captioning model and its processor once at startup.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Expected payload: {"inputs": {"images": [<base64 str>, ...], "texts": [<prompt>, ...]}}
        input_data = data.get("inputs", {})
        encoded_images = input_data.get("images")

        if not encoded_images:
            return {"captions": [], "error": "No images provided"}

        # Optional per-image prompts; BLIP uses them as a conditional-captioning prefix.
        texts = input_data.get("texts", ["a photography of"] * len(encoded_images))
        if len(texts) != len(encoded_images):
            return {"captions": [], "error": "Number of texts must match number of images"}

        try:
            # Decode the base64 payloads into RGB PIL images.
            raw_images = [
                Image.open(BytesIO(base64.b64decode(img))).convert("RGB") for img in encoded_images
            ]
            # Process the whole batch in a single call; padding lets prompts of
            # different token lengths share one batch of input_ids.
            processed_inputs = self.processor(
                images=raw_images, text=texts, return_tensors="pt", padding=True
            ).to(device)

            # Generate caption token ids without tracking gradients.
            with torch.no_grad():
                out = self.model.generate(**processed_inputs)

            # Decode token ids back into caption strings.
            captions = self.processor.batch_decode(out, skip_special_tokens=True)
            return {"captions": captions}
        except Exception as e:
            print(f"Error during processing: {e}")
            return {"captions": [], "error": str(e)}