from typing import Dict, List, Any
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import pipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler():
    def __init__(self, path="Salesforce/blip2-opt-6.7b-coco"):
        # BLIP-2 OPT checkpoints are vision-to-text models, not plain seq2seq models,
        # so AutoModelForSeq2SeqLM cannot load them. Let the pipeline resolve the
        # matching model class, tokenizer and image processor from the checkpoint.
        self.image_to_text_pipeline = pipeline(
            'image-to-text',
            model=path,
            device=device,
            # half precision keeps the 6.7B checkpoint within typical GPU memory
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`): {"image": base64-encoded image bytes}
            parameters (:obj:`dict`, optional): generation kwargs forwarded to the pipeline
        Return:
            A :obj:`list` of :obj:`dict` with the generated captions; it will be
            serialized and returned
        """
        # Extract inputs and kwargs from the data
        inputs = data["inputs"]
        parameters = data.pop("parameters", None)

        # Decode the base64 image to PIL; the pipeline's image processor handles
        # resizing and normalization, so no manual torchvision transform is needed
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")

        # Run the model for prediction
        if parameters is not None:
            predictions = self.image_to_text_pipeline(image, **parameters)
        else:
            predictions = self.image_to_text_pipeline(image)

        return predictions
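

# Minimal local smoke test (a sketch, not part of the Inference Endpoints contract):
# it builds the same base64 payload the endpoint expects and prints the caption.
# The image path below is a placeholder; point it at any local image file.
if __name__ == "__main__":
    handler = EndpointHandler()

    with open("example.jpg", "rb") as f:  # placeholder path
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "inputs": {"image": encoded_image},
        "parameters": {"max_new_tokens": 30},
    }
    print(handler(payload))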