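"""Custom handler for a Hugging Face Inference Endpoint.

Generates image descriptions with LLaVA-1.5-7B (loaded 4-bit quantized when a CUDA
device is available) and embeds each description with a SentenceTransformer model.
Requests whose images exceed a pixel budget are split into smaller generation batches.
"""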
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration


def _fake_generate(n: int = 3):
    """Return a fixed batch of fake token ids so tests can run without calling the model."""
    generated = [torch.IntTensor([103, 23, 48, 498, 536]) for _ in range(n)]
    return torch.stack(generated, dim=0)


class EndpointHandler:
    def __init__(self, path: str = "", test_mode: bool = False):
        # Preload everything needed at inference time.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        use_cuda = self.device.type == 'cuda'
        self.test_mode = test_mode
        # Upper bound on the total number of pixel values that fit in one generation
        # batch on a 16 GB T4 GPU; larger requests are split in text_to_image().
        self.MAXIMUM_PIXEL_VALUES = 3725568
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        if use_cuda:
            self.load_quantized()
        else:
            # Without a CUDA device quantization is unavailable, so load the full-precision model.
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
            )

    def load_quantized(self):
        print('Loading model with quantization')
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            quantization_config=self.quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

    def text_to_image(self, image_batch, prompt):
        prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
        prompt_batch = [prompt] * len(image_batch)
        inputs = self.processor(text=prompt_batch, images=image_batch, padding=True, return_tensors="pt")
        batched_inputs: List[Dict[str, torch.Tensor]] = list()
        if inputs['pixel_values'].flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
            batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            i = 0
            while i < len(inputs['pixel_values']):
                batch['input_ids'].append(inputs['input_ids'][i])
                batch['attention_mask'].append(inputs['attention_mask'][i])
                batch['pixel_values'].append(inputs['pixel_values'][i])
                if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
                    print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction '
                          f'on a T4 16GB GPU. Splitting into more batches.')
                    # Drop the last image: it pushes this batch over the limit and will
                    # start the next batch instead (i is deliberately not incremented).
                    batch['input_ids'].pop()
                    batch['attention_mask'].pop()
                    batch['pixel_values'].pop()
                    # Turn the accumulated lists into tensors and flush the batch.
                    batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                    batched_inputs.append(batch)
                    batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
                else:
                    i += 1
                    if i >= len(inputs['pixel_values']) and len(batch['input_ids']) > 0:
                        # Flush the final, partially filled batch.
                        batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                        batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                        batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                        batched_inputs.append(batch)
        else:
            batched_inputs.append(inputs)
        maurice_description = list()
        maurice_embeddings = list()
        for batch in batched_inputs:
            # Move the batch onto the model's device.
            batch['input_ids'] = batch['input_ids'].to(self.model.device)
            batch['attention_mask'] = batch['attention_mask'].to(self.model.device)
            batch['pixel_values'] = batch['pixel_values'].to(self.model.device)
            if self.test_mode:
                # Skip the real generate call so tests can run without a GPU.
                output = _fake_generate(n=len(batch['input_ids']))
            else:
                output = self.model.generate(**batch, max_new_tokens=500)
            # Reassign to CPU tensors so the GPU copies can be released.
            batch['input_ids'] = batch['input_ids'].to('cpu')
            batch['attention_mask'] = batch['attention_mask'].to('cpu')
            batch['pixel_values'] = batch['pixel_values'].to('cpu')
            generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
            output = output.to('cpu')
            for text in generated_text:
                # Keep only the assistant's answer and embed it.
                text_output = text.split("ASSISTANT:")[-1]
                text_embeddings = self.embedder.encode(text_output)
                maurice_description.append(text_output)
                maurice_embeddings.append(text_embeddings)
        return maurice_description, maurice_embeddings

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`list`): raw image bytes, one entry per image
            prompt (:obj:`str`): instruction applied to every image
        Return:
            A :obj:`list` of :obj:`dict` entries that will be serialized and returned
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if device != self.device and device.type == 'cuda':
            # A GPU became available after init: reload the model quantized.
            self.device = device
            self.load_quantized()
        images = data['inputs']
        prompt = data['prompt']
        pil_images = list()
        for image in images:
            pil_images.append(Image.open(io.BytesIO(image)))
        output_text, output_embedded = self.text_to_image(pil_images, prompt)
        result = list()
        for text, embed in zip(output_text, output_embedded):
            result.append(
                dict(
                    maurice_description=text,
                    maurice_embedding=embed
                )
            )
        return result
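

if __name__ == "__main__":
    # Minimal local usage sketch (an illustration, not part of the deployed handler).
    # Assumptions: an "example.jpg" file exists next to this script, and test_mode=True
    # is acceptable (generation is faked, although the model weights are still loaded).
    handler = EndpointHandler(test_mode=True)
    with open("example.jpg", "rb") as f:
        payload = {"inputs": [f.read()], "prompt": "Describe this image."}
    predictions = handler(payload)
    for item in predictions:
        print(item["maurice_description"])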