from typing import Dict, Any
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor
from PIL import Image
from io import BytesIO
import base64
import torch.nn.functional as F


class EndpointHandler():
    """Hugging Face Inference Endpoints handler for BLIP-2 (Flan-T5-XXL).

    Dispatches on ``data["inputs"]["mode"]``:

    - ``"generate_text"``: conditional generation from an image + text prompt.
      Expects keys ``input_text``, ``image`` (base64), ``max_new_tokens``,
      ``stop``, ``temperature``. Returns ``{"output_text": str}``.
    - ``"get_continuation_likelihood"``: per-token log-probabilities of
      ``prompt + continuation`` given an image. Expects keys ``prompt``,
      ``continuation``, ``image`` (base64). Returns prompt, continuation,
      decoded tokens, and their log-probabilities as plain floats.
    """

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract but is
        # unused here: weights are always pulled from the Hub by model id.
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl", device_map="auto",
            torch_dtype=torch.float16
            # load_in_8bit=True,
        )

    @staticmethod
    def _decode_image(b64_payload: str) -> Image.Image:
        """Decode a base64-encoded image payload into a PIL image."""
        return Image.open(BytesIO(base64.b64decode(b64_payload)))

    def _generate_text(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Generate text conditioned on an image and a text prompt."""
        image = self._decode_image(request['image'])
        max_new_tokens: int = request['max_new_tokens']
        temperature: float = request['temperature']
        stop: str = request['stop']

        model_inputs = self.processor(
            images=image, text=request['input_text'], return_tensors="pt"
        ).to(self.model.device, self.model.dtype)
        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs, max_new_tokens=max_new_tokens, temperature=temperature
            )[0]
        output_text = self.processor.decode(output_ids, skip_special_tokens=True).strip()
        # Truncate at the stop sequence if present. The truthiness guard
        # matters: an empty stop string would match at index 0 and erase
        # the entire output.
        if stop and stop in output_text:
            output_text = output_text[: output_text.find(stop)]

        return {'output_text': output_text}

    def _continuation_likelihood(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Score prompt + continuation token-by-token under the model."""
        prompt: str = request['prompt']
        continuation = request['continuation']
        image = self._decode_image(request['image'])

        model_inputs = self.processor(
            images=image, text=(prompt + continuation), return_tensors="pt"
        ).to(self.model.device, self.model.dtype)
        # With labels == input_ids, the T5-style decoder shifts internally,
        # so logits[i] scores token input_ids[i].
        model_inputs["labels"] = model_inputs["input_ids"]
        input_ids = model_inputs["input_ids"][0]
        tokens = [self.processor.decode([token_id]) for token_id in input_ids]

        with torch.no_grad():
            logits = self.model(**model_inputs).logits[0]
        log_probs = F.log_softmax(logits, dim=-1)
        # .item() converts each 0-d tensor to a plain float so the response
        # is JSON-serializable (raw tensors are not).
        token_logprobs = [
            log_probs[i, token_id].item() for i, token_id in enumerate(input_ids)
        ]

        return {
            'prompt': prompt,
            'continuation': continuation,
            'tokens': tokens,
            'logprobs': token_logprobs
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request; dispatch on the request's mode.

        Raises:
            ValueError: if ``data["inputs"]["mode"]`` is not a known mode
                (the original silently returned ``None`` here, which the
                endpoint would serialize as ``null``).
        """
        request = data["inputs"]
        mode = request["mode"]
        if mode == 'generate_text':
            return self._generate_text(request)
        if mode == 'get_continuation_likelihood':
            return self._continuation_likelihood(request)
        raise ValueError(f"Unknown mode: {mode!r}")