import io
from typing import Dict, List, Any

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, util
from PIL import Image


def _fake_generate(n: int = 3) -> torch.Tensor:
    """Return a stand-in for ``model.generate`` output, used in test mode.

    Args:
        n: Number of rows (one per sample in the batch).

    Returns:
        An int32 tensor of shape ``(n, 5)`` whose rows are all the same fixed
        sequence of token ids.
    """
    return torch.stack([torch.IntTensor([103, 23, 48, 498, 536]) for _ in range(n)], dim=0)


# Processor outputs carried into sub-batches when a batch is split by pixel budget.
_SPLIT_KEYS = ('input_ids', 'attention_mask', 'pixel_values')


def _split_by_pixel_budget(inputs, max_pixel_values: int) -> List[Dict[str, torch.Tensor]]:
    """Split processor outputs into consecutive sub-batches that fit a pixel budget.

    If the whole batch's ``pixel_values`` has at most ``max_pixel_values``
    elements, ``inputs`` is returned unchanged as the only batch. Otherwise
    samples are grouped greedily, in input order, so that each sub-batch's
    total ``pixel_values`` element count stays within the budget. Only the
    keys in ``_SPLIT_KEYS`` are carried into split sub-batches.

    A single sample that exceeds the budget on its own is emitted as a
    one-sample batch, since it cannot be split any further.

    Args:
        inputs: Mapping returned by the processor; must contain ``_SPLIT_KEYS``.
        max_pixel_values: Maximum number of ``pixel_values`` elements per batch.

    Returns:
        List of batches (mappings of tensors) to feed to ``model.generate``.
    """
    pixel_values = inputs['pixel_values']
    if pixel_values.numel() <= max_pixel_values:
        return [inputs]

    total = len(pixel_values)
    batches: List[Dict[str, torch.Tensor]] = []
    start = 0  # first sample index of the batch being built
    used = 0  # pixel_values elements already in the batch being built
    for i in range(total):
        size = pixel_values[i].numel()
        # Close the current batch before it would overflow. Never close an empty
        # batch: an oversized sample is processed alone instead.
        if i > start and used + size > max_pixel_values:
            print(f'[{i}/{total}] - Reached max pixel values for batch prediction on T4 '
                  f'16GB GPU. Will split in more batches')
            batches.append({key: inputs[key][start:i] for key in _SPLIT_KEYS})
            start, used = i, 0
        used += size
    batches.append({key: inputs[key][start:total] for key in _SPLIT_KEYS})
    return batches


class EndpointHandler():
    """Inference endpoint that describes images with LLaVA-1.5 and embeds the text.

    Each request supplies encoded images and one prompt. Every image gets a
    generated description, and each description is embedded with a
    SentenceTransformer (``all-mpnet-base-v2``).
    """

    def __init__(self, path="", test_mode: bool = False):
        """Load the processor, the LLaVA model and the sentence embedder.

        Args:
            path: Not used; the model id is fixed in ``self.model_id``.
            test_mode: When True, ``text_to_image`` uses ``_fake_generate``
                instead of ``model.generate`` (the model is still loaded).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Compare the device *type*: a torch.device never compares equal to the str 'cuda'.
        use_cuda = self.device.type == 'cuda'
        self.test_mode = test_mode
        # Max pixel_values elements generated in one batch (sized for a T4 16GB GPU,
        # per the split log message).
        self.MAXIMUM_PIXEL_VALUES = 3725568
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        if use_cuda:
            self.load_quantized()
        else:
            # Testing without CUDA device does not allow quantization
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
            )

    def load_quantized(self):
        """(Re)load ``self.model`` with the 4-bit ``self.quantization_config``."""
        print('Loading model with quantization')
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            quantization_config=self.quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

    def text_to_image(self, image_batch, prompt):
        """Generate a description for each image and embed it.

        Despite the name, this maps images to text: the same prompt is applied
        to every image. Inputs are split into sub-batches whose pixel_values
        fit ``self.MAXIMUM_PIXEL_VALUES`` before generation.

        Args:
            image_batch: Sequence of PIL images.
            prompt: User instruction applied to every image.

        Returns:
            ``(descriptions, embeddings)``: two lists in input order. Each
            description is the text after the last ``"ASSISTANT:"`` in the
            decoded output, and each embedding is ``self.embedder.encode`` of it.
        """
        if not image_batch:
            return [], []

        # LLaVA-1.5 chat format; <image> marks where the image features are inserted.
        prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
        prompt_batch = [prompt] * len(image_batch)
        inputs = self.processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")

        maurice_description = list()
        maurice_embeddings = list()
        for batch in _split_by_pixel_budget(inputs, self.MAXIMUM_PIXEL_VALUES):
            # Per-batch device copy; Tensor.to returns a new tensor, so the CPU batch is untouched.
            device_batch = {
                key: value.to(self.model.device) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
            if self.test_mode:
                output = _fake_generate(n=len(device_batch['input_ids']))
            else:
                output = self.model.generate(**device_batch, max_new_tokens=500)
            # Drop the device inputs now so their memory can be reused by the next batch.
            del device_batch

            generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
            del output
            for text in generated_text:
                text_output = text.split("ASSISTANT:")[-1]
                maurice_description.append(text_output)
                maurice_embeddings.append(self.embedder.encode(text_output))
        return maurice_description, maurice_embeddings

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Describe and embed every image in a request.

        Args:
            data: Request payload with ``inputs``, an iterable of encoded image
                files as ``bytes`` (each opened with ``PIL.Image.open``), and
                ``prompt``, the instruction text applied to every image.

        Returns:
            One dict per image, in input order, with ``maurice_description``
            (generated text) and ``maurice_embedding`` (the value returned by
            ``SentenceTransformer.encode`` for that text).

        Raises:
            KeyError: if ``data`` lacks ``inputs`` or ``prompt``.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # A GPU became available after construction: switch to the quantized model once.
        if device != self.device and device.type == 'cuda':
            self.device = device
            self.load_quantized()

        images = data['inputs']
        prompt = data['prompt']
        pil_images = [Image.open(io.BytesIO(image)) for image in images]

        output_text, output_embedded = self.text_to_image(pil_images, prompt)
        return [
            dict(maurice_description=text, maurice_embedding=embed)
            for text, embed in zip(output_text, output_embedded)
        ]