import io
from typing import Dict, List, Any
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration
from sentence_transformers import SentenceTransformer
from PIL import Image

def _fake_generate(n: int = 3):
    # Deterministic stand-in for model.generate() so the handler can be
    # exercised in test mode without running the LLaVA model.
    fake_tokens = [torch.IntTensor([103, 23, 48, 498, 536]) for _ in range(n)]
    return torch.stack(fake_tokens, dim=0)


class EndpointHandler:
    def __init__(self, path="", test_mode: bool = False):
        # Preload everything needed at inference time.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Compare the device type: comparing a torch.device to the plain string
        # 'cuda' is unreliable, so check .type instead.
        use_cuda = self.device.type == 'cuda'

        self.test_mode = test_mode
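        # Approximate pixel-value budget for a single generate() call; larger
        # batches are split in text_to_image so they fit on a 16 GB T4 GPU.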
        self.MAXIMUM_PIXEL_VALUES = 3725568
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )

        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        if use_cuda:
            self.load_quantized()
        else:
            # Without a CUDA device, bitsandbytes 4-bit quantization is not
            # available, so load the full-precision model instead.
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
            )

    def load_quantized(self):
        print('Loading model with quantization')
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            quantization_config=self.quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

    def text_to_image(self, image_batch, prompt):
        """Generate a description of each image in `image_batch` using `prompt`
        and return the texts together with their sentence embeddings."""
        # Wrap the user prompt in the LLaVA 1.5 chat template.
        prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
        prompt_batch = [prompt for _ in range(len(image_batch))]

        inputs = self.processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")

        batched_inputs: list[dict[str, torch.Tensor]] = list()
        if inputs['pixel_values'].flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
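            # Greedily pack images into sub-batches that stay under the pixel
            # budget so each generate() call fits in GPU memory.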
            batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            i = 0
            while i < len(inputs['pixel_values']):
                batch['input_ids'].append(inputs['input_ids'][i])
                batch['attention_mask'].append(inputs['attention_mask'][i])
                batch['pixel_values'].append(inputs['pixel_values'][i])

                if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
                    print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on T4 '
                          f'16GB GPU. Will split in more batches')
                    # Remove the last added image because it's too big to process
                    batch['input_ids'].pop()
                    batch['attention_mask'].pop()
                    batch['pixel_values'].pop()

                    # transform lists to tensors
                    batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)

                    # Add to the batched_inputs
                    batched_inputs.append(batch)
                    batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
                else:
                    i += 1
            # Flush whatever remains after the loop as the final sub-batch.
            if len(batch['input_ids']) > 0:
                batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)

                # Add to the batched_inputs
                batched_inputs.append(batch)
        else:
            batched_inputs.append(inputs)

        maurice_description = list()
        maurice_embeddings = list()
        for batch in batched_inputs:
            # Move the sub-batch onto the model's device.
            batch['input_ids'] = batch['input_ids'].to(self.model.device)
            batch['attention_mask'] = batch['attention_mask'].to(self.model.device)
            batch['pixel_values'] = batch['pixel_values'].to(self.model.device)
            if self.test_mode:
                output = _fake_generate(n=len(batch['input_ids']))
            else:
                output = self.model.generate(**batch, max_new_tokens=500)
            # Move the tensors back: .to('cpu') returns a copy, so the results
            # must be reassigned for the GPU memory to actually be released.
            batch['input_ids'] = batch['input_ids'].to('cpu')
            batch['attention_mask'] = batch['attention_mask'].to('cpu')
            batch['pixel_values'] = batch['pixel_values'].to('cpu')

            generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
            output = output.to('cpu')

            for text in generated_text:
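                # Keep only the assistant's reply, dropping the echoed prompt.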
                text_output = text.split("ASSISTANT:")[-1]
                text_embeddings = self.embedder.encode(text_output)
                maurice_description.append(text_output)
                maurice_embeddings.append(text_embeddings)

        return maurice_description, maurice_embeddings

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
       data args:
            inputs (:obj: `str` | `PIL.Image` | `np.array`)
            kwargs
      Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # If a GPU has become available since initialisation, switch to the
        # quantized model and remember the new device.
        if device != self.device and device.type == 'cuda':
            self.device = device
            self.load_quantized()
        images = data['inputs']
        prompt = data['prompt']

        pil_images = list()
        for image in images:
            pil_images.append(Image.open(io.BytesIO(image)))

        output_text, output_embedded = self.text_to_image(pil_images, prompt)

        result = list()
        for text, embed in zip(output_text, output_embedded):
            result.append(
                dict(
                    maurice_description=text,
                    maurice_embedding=embed
                )
            )
        return result
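

# Illustrative usage sketch, not part of the handler API: "example.jpg" and the
# prompt below are placeholder assumptions. test_mode=True swaps real generation
# for _fake_generate (the LLaVA processor and weights are still loaded).
if __name__ == "__main__":
    handler = EndpointHandler(test_mode=True)
    with open("example.jpg", "rb") as f:  # hypothetical sample image
        image_bytes = f.read()
    payload = {"inputs": [image_bytes], "prompt": "Describe this image."}
    for item in handler(payload):
        print(item["maurice_description"])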