import gradio as gr
import torch
from PIL import Image
from types import SimpleNamespace

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

# Disable redundant default weight initialization to speed up model loading
disable_torch_init()

# Load the pretrained model (4-bit quantized LLaVA v1.5 13B)
MODEL = "4bit/llava-v1.5-13b-3GB"
model_name = get_model_name_from_path(MODEL)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
)
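
# Note: load_4bit quantizes the weights with bitsandbytes, which is how the
# 13B model fits in a few GB of GPU memory (assumes bitsandbytes and a CUDA
# device are available).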

# Define the prompt creation function
def create_prompt(prompt: str):
    conv = conv_templates["llava_v0"].copy()
    roles = conv.roles
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv.append_message(roles[0], prompt)
    conv.append_message(roles[1], None)
    return conv.get_prompt(), conv
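
# For reference, the llava_v0 template renders the turns roughly as
# "### Human: <image>\nDescribe the image### Assistant:" preceded by a system
# preamble (a sketch; the exact string comes from conv_templates).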

# Define the image processing function
def process_image(image):
    # process_images looks up image_aspect_ratio via getattr, so pass an object
    # with attributes; a plain dict would be silently ignored
    args = SimpleNamespace(image_aspect_ratio="pad")
    image_tensor = process_images([image], image_processor, args)
    return image_tensor.to(model.device, dtype=torch.float16)
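
# With image_aspect_ratio="pad", LLaVA pads non-square images to a square
# canvas before CLIP preprocessing, which avoids distorting the aspect ratio.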

# Define the image description function
def describe_image(image):
    # gr.Image(type="pil") already delivers a PIL image, so no Image.open here;
    # resize returns a new image rather than modifying in place
    image = image.convert("RGB").resize((500, 500))
    processed_image = process_image(image)
    prompt, conv = create_prompt("Describe the image")
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
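    # Stop generation once the conversation separator appears, so the model
    # does not begin writing the next turn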
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=processed_image,
            do_sample=True,
            temperature=0.01,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    description = tokenizer.decode(
        output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
    ).strip()
    return description
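
# Standalone usage example (assumes a local file "example.jpg" exists):
# print(describe_image(Image.open("example.jpg")))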

iface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil", label="Image"),  # deliver the upload as a PIL image
    outputs=gr.Textbox(label="Description"),
    live=True,  # run on upload; capture_session was removed from the Gradio API
)
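
# Queue requests so long generations are handled one at a time (a sensible
# default for shared Spaces hardware; assumes a Gradio version with .queue())
iface.queue()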

# Launch the Gradio interface
iface.launch()