|
import gradio as gr |
|
import torch |
|
import json |
|
from io import BytesIO |
|
from PIL import Image, ImageOps |
|
from IPython.display import display, Markdown |
|
from transformers import AutoModelForCausalLM, LlamaTokenizer |
|
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch |
|
|
|
|
|
# Hugging Face Hub identifiers for the tokenizer and the vision-language model.
TOKENIZER_ID = "lmsys/vicuna-7b-v1.5"
MODEL_ID = "THUDM/cogvlm-chat-hf"

# The tokenizer is loaded from the Vicuna-7B v1.5 repo, not from the CogVLM repo.
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_ID)

# Load CogVLM once, at import time:
#   - device_map="auto" lets accelerate place the weights on the available device(s);
#   - load_in_4bit=True quantizes the weights (bitsandbytes) to cut GPU memory;
#   - trust_remote_code=True executes Python code shipped in the model repo
#     (presumably where build_conversation_input_ids comes from) - only use it
#     with repositories you trust.
# NOTE(review): newer transformers releases deprecate passing load_in_4bit
# directly in favour of quantization_config=BitsAndBytesConfig(load_in_4bit=True).
# NOTE(review): generate_description casts images to float16; this assumes the
# 4-bit load leaves non-quantized weights in float16 - confirm if torch_dtype
# is ever set explicitly here.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
    load_in_4bit=True,
)
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
|
|
|
def generate_description(image, query, top_p, top_k, output_length, temperature):
    """Run CogVLM on an uploaded image and return the model's textual answer.

    Args:
        image: PIL image from the Gradio ``Image`` component (``type="pil"``),
            or ``None`` when the user has not uploaded anything.
        query: Prompt text sent to the model together with the image.
        top_p: Nucleus-sampling probability mass forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``.
        output_length: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature forwarded to ``generate``.

    Returns:
        The decoded model response, without the echoed prompt.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Uploads such as PNGs can arrive as RGBA / palette / greyscale images;
    # normalize to 3-channel RGB as the CogVLM reference examples do.
    image = image.convert("RGB")

    # No manual resize here: build_conversation_input_ids applies the model's
    # own image transform (resize to the vision encoder's native resolution -
    # 490x490 for cogvlm-chat-hf, confirm in the model config). Shrinking to
    # 224x224 first would only throw away detail before it is upscaled again.
    built = model.build_conversation_input_ids(
        tokenizer, query=query, history=[], images=[image]
    )

    # Add the batch dimension and move everything onto the GPU. Images are cast
    # to float16 on the assumption that this matches the model's compute dtype.
    inputs = {
        "input_ids": built["input_ids"].unsqueeze(0).to("cuda"),
        "token_type_ids": built["token_type_ids"].unsqueeze(0).to("cuda"),
        "attention_mask": built["attention_mask"].unsqueeze(0).to("cuda"),
        "images": [[built["images"][0].to("cuda").to(torch.float16)]],
    }

    gen_kwargs = {
        # max_new_tokens rather than max_length: max_length also counts the
        # prompt (which for CogVLM presumably includes the image placeholder
        # positions), so a small "Output Length" left no room for an answer.
        # Gradio sliders may deliver floats; generate() needs integers here.
        "max_new_tokens": int(output_length),
        "do_sample": True,
        "top_p": float(top_p),
        "top_k": int(top_k),
        "temperature": float(temperature),
    }

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; decode only the new tokens so
    # the query is not echoed back into the output box.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
# Build the web UI: an image + prompt on the left, sampling controls and the
# generated output on the right.
with gr.Blocks() as app:
    gr.Markdown("# Visual Product DNA - Image to Attribute Extractor")

    with gr.Row():
        # Left column: the image to analyse and the prompt sent with it.
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", height=500)
            query_input = gr.Textbox(
                label="Enter your prompt",
                value="Capture all attributes as JSON",
                lines=4,
            )

        # Right column: sampling controls, the trigger button and the result.
        with gr.Column():
            top_p_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.1,
                label="Creativity (top_p)",
            )
            top_k_slider = gr.Slider(
                minimum=0, maximum=100, step=1, value=100,
                label="Coherence (top_k)",
            )
            output_length_slider = gr.Slider(
                minimum=1, maximum=4096, step=1, value=2048,
                label="Output Length",
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, step=0.01, value=0.1,
                label="Temperature",
            )
            submit_button = gr.Button("Extract Attributes")
            description_output = gr.Textbox(label="Generated JSON", lines=12)

    # Component values are passed positionally, so this list must follow the
    # parameter order of generate_description. Registered inside the Blocks
    # context, where Gradio expects event listeners to be attached.
    submit_button.click(
        fn=generate_description,
        inputs=[
            image_input,
            query_input,
            top_p_slider,
            top_k_slider,
            output_length_slider,
            temperature_slider,
        ],
        outputs=description_output,
    )

# share=True asks Gradio for a temporary public share link in addition to the
# local server.
app.launch(share=True)