File size: 2,955 Bytes
ca0693b
ca66b5c
ca0693b
ca66b5c
 
 
ca0693b
 
 
 
 
 
 
 
 
ca66b5c
 
 
 
 
 
 
ca0693b
ca66b5c
5eda3fd
 
ca0693b
 
 
 
 
 
 
 
 
 
 
 
 
 
ca66b5c
ca0693b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca66b5c
ca0693b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Dict, Any

import torch
from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration
from accelerate import init_empty_weights, infer_auto_device_map

from PIL import Image
from io import BytesIO
import base64
import torch.nn.functional as F


class EndpointHandler():
    """Inference-endpoint wrapper around BLIP-2 (Salesforce/blip2-flan-t5-xxl).

    ``__call__`` dispatches on ``data["inputs"]["mode"]``:

    * ``"generate_text"`` — image-conditioned text generation, optionally
      truncated at a ``stop`` string.
    * ``"get_continuation_likelihood"`` — per-token log-probabilities of
      ``prompt + continuation`` given the image.

    Images arrive base64-encoded in ``data["inputs"]["image"]``.
    """

    def __init__(self, path=""):
        # Processor bundles the image preprocessor and the T5 tokenizer.
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")

        # Build a device map from an empty-weight skeleton so the real
        # checkpoint can be sharded across available devices at load time.
        # T5Block is kept whole so no transformer layer is split across devices.
        config = Blip2Config.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        with init_empty_weights():
            model = Blip2ForConditionalGeneration(config)
            device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
        # lm_head is tied to the input embeddings; pin both to the same
        # device so the shared weight is not placed on two devices.
        device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]

        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl", device_map=device_map,
            # torch_dtype=torch.float16
            load_in_8bit=True,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: ``{"inputs": {...}}`` — see class docstring for the
                per-mode keys. ``image`` is a base64-encoded image file.

        Returns:
            ``{"output_text": str}`` for ``generate_text``;
            ``{"prompt", "continuation", "tokens", "logprobs"}`` for
            ``get_continuation_likelihood`` (``logprobs`` are plain floats,
            so the payload is JSON-serializable).

        Raises:
            KeyError: if a required key for the selected mode is missing.
        """
        request = data["inputs"]

        if request["mode"] == 'generate_text':

            input_text: str = request['input_text']
            image: Image.Image = Image.open(BytesIO(base64.b64decode(request['image'])))
            max_new_tokens: int = request['max_new_tokens']
            # `stop` is optional; when absent the full generation is returned.
            stop = request.get('stop')
            temperature: float = request['temperature']

            model_inputs = self.processor(images=image, text=input_text, return_tensors="pt").to(
                self.model.device, self.model.dtype
            )
            output = self.model.generate(
                **model_inputs, max_new_tokens=max_new_tokens, temperature=temperature
            )[0]
            output_text = self.processor.decode(output, skip_special_tokens=True).strip()
            # Truncate at the first occurrence of the stop string, if any.
            if stop is not None and stop in output_text:
                output_text = output_text[: output_text.find(stop)]

            return {'output_text': output_text}

        elif request["mode"] == 'get_continuation_likelihood':

            prompt: str = request['prompt']
            continuation = request['continuation']
            image: Image.Image = Image.open(BytesIO(base64.b64decode(request['image'])))

            model_inputs = self.processor(
                images=image, text=(prompt + continuation), return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # Feeding the sequence back as labels makes the seq2seq model
            # produce logits aligned with these target tokens.
            model_inputs["labels"] = model_inputs["input_ids"]
            input_ids = model_inputs["input_ids"][0]
            tokens = [self.processor.decode([t]) for t in input_ids]

            # Pure scoring pass: disable autograd so no gradient graph is
            # built for this very large model.
            with torch.no_grad():
                logits = self.model(**model_inputs).logits[0]
            logprobs_full = F.log_softmax(logits, dim=-1)
            # Pick each target token's log-probability and convert to plain
            # Python floats so the response is JSON-serializable.
            logprobs = [
                logprobs_full[i, input_ids[i]].item() for i in range(len(tokens))
            ]

            return {
                'prompt': prompt,
                'continuation': continuation,
                'tokens': tokens,
                'logprobs': logprobs
            }