"""Inference Endpoint handler for llava-hf/llava-1.5-7b-hf.

The model is reloaded with quantization on __call__ in case no GPU was available during __init__.
"""
import io
from typing import Dict, List, Any
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from PIL import Image
def _fake_generate(n: int = 3):
    """Return a stack of dummy token-id tensors, used to stub out model.generate in test mode."""
    generate = list()
    for _ in range(n):
        generate.append(torch.IntTensor([103, 23, 48, 498, 536]))
    return torch.stack(generate, dim=0)
class EndpointHandler:
    def __init__(self, path="", test_mode: bool = False):
        # Preload everything needed at inference time.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        use_cuda = self.device.type == 'cuda'
        self.test_mode = test_mode
        self.MAXIMUM_PIXEL_VALUES = 3725568
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        if use_cuda:
            self.load_quantized()
        else:
            # Without a CUDA device, quantization is not possible: load the full-precision model instead
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
    def load_quantized(self):
        """Load the LLaVA model with 4-bit quantization (requires a CUDA device)."""
        print('Loading model with quantization')
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            quantization_config=self.quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    def text_to_image(self, image_batch, prompt):
        """Describe each image with LLaVA and embed each description with the sentence encoder."""
        prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
        prompt_batch = [prompt for _ in range(len(image_batch))]
        inputs = self.processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")
        batched_inputs: List[Dict[str, torch.Tensor]] = list()
        if inputs['pixel_values'].flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
            batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
            i = 0
            while i < len(inputs['pixel_values']):
                batch['input_ids'].append(inputs['input_ids'][i])
                batch['attention_mask'].append(inputs['attention_mask'][i])
                batch['pixel_values'].append(inputs['pixel_values'][i])
                if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
                    print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on T4 '
                          f'16GB GPU. Will split in more batches')
                    # Remove the last added image because it's too big to process
                    batch['input_ids'].pop()
                    batch['attention_mask'].pop()
                    batch['pixel_values'].pop()
                    # Transform lists to tensors
                    batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                    # Add to the batched_inputs
                    batched_inputs.append(batch)
                    batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
                else:
                    i += 1
                    if i >= len(inputs['pixel_values']) and len(batch['input_ids']) > 0:
                        batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
                        batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
                        batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
                        # Add to the batched_inputs
                        batched_inputs.append(batch)
                        batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
        else:
            batched_inputs.append(inputs)
        maurice_description = list()
        maurice_embeddings = list()
        for batch in batched_inputs:
            # Load on device
            batch['input_ids'] = batch['input_ids'].to(self.model.device)
            batch['attention_mask'] = batch['attention_mask'].to(self.model.device)
            batch['pixel_values'] = batch['pixel_values'].to(self.model.device)
            # output = model.generate(**batch, max_new_tokens=500, temperature=0.3)
            if self.test_mode:
                output = _fake_generate(n=len(batch['input_ids']))
            else:
                output = self.model.generate(**batch, max_new_tokens=500)
            # Move tensors back to CPU to free GPU memory (.to() returns a new tensor, so reassign)
            batch['input_ids'] = batch['input_ids'].to('cpu')
            batch['attention_mask'] = batch['attention_mask'].to('cpu')
            batch['pixel_values'] = batch['pixel_values'].to('cpu')
            generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
            output = output.to('cpu')
            for text in generated_text:
                text_output = text.split("ASSISTANT:")[-1]
                text_embeddings = self.embedder.encode(text_output)
                maurice_description.append(text_output)
                maurice_embeddings.append(text_embeddings)
        return maurice_description, maurice_embeddings
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`list`): images as raw bytes
            prompt (:obj:`str`): prompt applied to every image
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if device.type == 'cuda' and self.device.type != 'cuda':
            # A GPU is available now but was not during __init__: reload the model with quantization
            self.load_quantized()
            self.device = device
        images = data['inputs']
        prompt = data['prompt']
        pil_images = list()
        for image in images:
            pil_images.append(Image.open(io.BytesIO(image)))
        output_text, output_embedded = self.text_to_image(pil_images, prompt)
        result = list()
        for text, embed in zip(output_text, output_embedded):
            result.append(
                dict(
                    maurice_description=text,
                    # Convert the numpy embedding to a plain list so the response is JSON-serializable
                    maurice_embedding=embed.tolist()
                )
            )
        return result
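
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original handler): a minimal local smoke test.
# Assumptions: a sample image exists at './test_image.jpg' (placeholder path);
# test_mode=True replaces model.generate with _fake_generate, but __init__ still
# downloads and loads the LLaVA weights, so this remains a heavy operation.
if __name__ == '__main__':
    handler = EndpointHandler(test_mode=True)
    with open('./test_image.jpg', 'rb') as f:
        image_bytes = f.read()
    payload = {'inputs': [image_bytes], 'prompt': 'Describe this image.'}
    for prediction in handler(payload):
        print(prediction['maurice_description'])
        print(f"embedding length: {len(prediction['maurice_embedding'])}")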