|
from llava.model.builder import load_pretrained_model |
|
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token |
|
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN |
|
from llava.conversation import conv_templates |
|
|
|
from PIL import Image |
|
import requests |
|
import copy |
|
import torch |
|
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path |
|
import spaces |
|
from io import BytesIO |
|
import base64 |
|
|
|
|
|
from src.utils import ( |
|
build_logger, |
|
) |
|
|
|
# Module-wide logger used by the loading and inference functions below.
# build_logger comes from src.utils; its handlers/output location are not
# visible here — presumably it writes to "model_llava.log". TODO confirm.
logger = build_logger("model_llava", "model_llava.log")
|
def load_llava_model(
    lora_checkpoint=None,
    model_path="Lin-Chen/open-llava-next-llama3-8b",
    conv_template="llama_v3_student",
):
    """Load a LLaVA model, optionally applying a LoRA checkpoint on top of a base.

    Args:
        lora_checkpoint: Path to a LoRA checkpoint. When ``None`` the model at
            ``model_path`` is loaded on its own; otherwise the checkpoint is
            loaded with ``model_path`` passed as its base model.
        model_path: Base model path or hub id.
        conv_template: Key into ``conv_templates``; returned unchanged so callers
            know which conversation template to build prompts with.

    Returns:
        ``(tokenizer, model, image_processor, conv_template)``. The model has had
        ``eval()`` and ``tie_weights()`` called on it. The fourth value returned
        by ``load_pretrained_model`` is discarded.
    """
    device_map = "auto"

    if lora_checkpoint is None:
        tokenizer, model, image_processor, _max_length = load_pretrained_model(
            model_path,
            None,
            get_model_name_from_path(model_path),
            device_map=device_map,
        )
    else:
        # NOTE(review): the fixed model name "llava_lora" is presumably what makes
        # the LLaVA builder take its "LoRA + base model" loading path — confirm
        # against llava.model.builder.
        tokenizer, model, image_processor, _max_length = load_pretrained_model(
            lora_checkpoint,
            model_path,
            "llava_lora",
            device_map=device_map,
        )

    model.eval()
    # NOTE(review): reason for re-tying weights after loading is not visible
    # here; kept as in the original.
    model.tie_weights()
    logger.info(f"model device {model.device}")
    return tokenizer, model, image_processor, conv_template
|
|
|
# Base model: used by inference() and inference_by_prompt_and_images().
tokenizer_llava, model_llava, image_processor_llava, conv_template_llava = load_llava_model(None)

# LoRA ("fire") model: used by inference_by_prompt_and_images_fire().
# Its conversation template is bound to its own name rather than overwriting
# conv_template_llava from the base-model load above.
(
    tokenizer_llava_fire,
    model_llava_fire,
    image_processor_llava_fire,
    conv_template_llava_fire,
) = load_llava_model("checkpoints/llava-next-llama-3-8b-student-lora-merged-115124")

# NOTE(review): the model was loaded with device_map="auto"; an explicit .to()
# may conflict with accelerate dispatch hooks if the model is sharded across
# devices. Only the fire model is moved here, not the base model — confirm intent.
model_llava_fire.to("cuda")
|
|
|
@spaces.GPU
def inference(image_path="assets/example.jpg", question="What is in the figure?"):
    """Smoke-test the base model on a single image and question.

    Builds a prompt with the ``conv_template_llava`` conversation template,
    runs greedy generation (up to 256 new tokens) and prints the result.

    Args:
        image_path: Path of the image to describe. It is converted to RGB.
        question: Question text. An ``<image>`` placeholder is prepended to it.

    Returns:
        The list of strings from ``tokenizer_llava.batch_decode`` on the
        generated ids.
    """
    device = "cuda"

    image = Image.open(image_path).convert("RGB")
    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)

    # The image placeholder is placed in front of the question text.
    prompt = "<image>" + question
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], prompt)
    # A None message for the assistant role leaves the turn open for generation.
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    # tokenizer_image_token maps the image placeholder to IMAGE_TOKEN_INDEX;
    # unsqueeze(0) adds a batch dimension of 1.
    input_ids = tokenizer_image_token(
        prompt_question, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    image_sizes = [image.size]
    print(input_ids.shape, image_tensor.shape)

    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )

    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)
    return text_outputs
|
|
|
|
|
@spaces.GPU
def inference_by_prompt_and_images(prompt, images):
    """Run greedy generation with the base model on a prompt and a set of images.

    Args:
        prompt: Fully formatted prompt text. It is passed as-is to
            ``tokenizer_image_token``, with no conversation template applied.
            Image placeholders in it presumably must match ``images`` — TODO confirm.
        images: Non-empty sequence. Each item is either a base64-encoded image
            string or an already-opened ``PIL.Image``; both kinds may be mixed.
            Every image is converted to RGB before preprocessing.

    Returns:
        The list of strings from ``tokenizer_llava.batch_decode`` on the
        generated ids (greedy, up to 256 new tokens).

    Raises:
        ValueError: If ``images`` is empty.
    """
    device = "cuda"
    if not images:
        raise ValueError("inference_by_prompt_and_images requires at least one image")

    # Decode each item on its own; the original only inspected images[0], so a
    # mixed str/PIL list failed. RGB conversion matches inference() and keeps
    # RGBA / grayscale / palette inputs from breaking preprocessing.
    images = [
        Image.open(BytesIO(base64.b64decode(img))) if isinstance(img, str) else img
        for img in images
    ]
    images = [img.convert("RGB") for img in images]

    image_tensor = process_images(images, image_processor_llava, model_llava.config)
    # process_images may return a list of per-image tensors (for example when
    # images yield different shapes) instead of one stacked tensor.
    if isinstance(image_tensor, torch.Tensor):
        image_tensor = image_tensor.to(dtype=torch.float16, device=device)
        shape_info, device_info = image_tensor.shape, image_tensor.device
    else:
        image_tensor = [t.to(dtype=torch.float16, device=device) for t in image_tensor]
        shape_info = [t.shape for t in image_tensor]
        device_info = [t.device for t in image_tensor]

    input_ids = tokenizer_image_token(
        prompt, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info(f"Shape: {input_ids.shape};{shape_info}; Devices: {input_ids.device};{device_info}")

    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )

    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    return text_outputs
|
|
|
@spaces.GPU
def inference_by_prompt_and_images_fire(prompt, images):
    """Run greedy generation with the LoRA ("fire") model on a prompt and images.

    Args:
        prompt: Fully formatted prompt text. It is passed as-is to
            ``tokenizer_image_token``, with no conversation template applied.
            Image placeholders in it presumably must match ``images`` — TODO confirm.
        images: Non-empty sequence. Each item is either a base64-encoded image
            string or an already-opened ``PIL.Image``; both kinds may be mixed.
            Every image is converted to RGB before preprocessing.

    Returns:
        The list of strings from ``tokenizer_llava_fire.batch_decode`` on the
        generated ids (greedy, up to 256 new tokens). The result is also logged.

    Raises:
        ValueError: If ``images`` is empty.
    """
    device = "cuda"
    if not images:
        raise ValueError("inference_by_prompt_and_images_fire requires at least one image")

    # Decode each item on its own (the original only inspected images[0]) and
    # normalise to RGB so non-RGB inputs don't break preprocessing.
    images = [
        Image.open(BytesIO(base64.b64decode(img))) if isinstance(img, str) else img
        for img in images
    ]
    images = [img.convert("RGB") for img in images]

    image_tensor = process_images(images, image_processor_llava_fire, model_llava_fire.config)
    if isinstance(image_tensor, torch.Tensor):
        image_tensor = image_tensor.to(dtype=torch.float16, device=device)
        shape_info, device_info = image_tensor.shape, image_tensor.device
        # One entry per image along dim 0. For a single image this equals the
        # original [image_tensor.squeeze(dim=0)]; for N > 1 images the original
        # packed them all into a single list entry.
        image_batch = list(image_tensor)
    else:
        # process_images already returned one tensor per image.
        image_batch = [t.to(dtype=torch.float16, device=device) for t in image_tensor]
        shape_info = [t.shape for t in image_batch]
        device_info = [t.device for t in image_batch]

    input_ids = tokenizer_image_token(
        prompt, tokenizer_llava_fire, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    image_sizes = [image.size for image in images]
    logger.info(f"Shape: {input_ids.shape};{shape_info}; Devices: {input_ids.device};{device_info}")

    with torch.inference_mode():
        cont = model_llava_fire.generate(
            input_ids,
            images=image_batch,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )

    text_outputs = tokenizer_llava_fire.batch_decode(cont, skip_special_tokens=True)
    logger.info(f"response={text_outputs}")
    return text_outputs
|
|
|
def _main():
    """Script entry point: run the single-image smoke test with the base model."""
    inference()


if __name__ == "__main__":
    _main()