# toshi456's picture
# Update app.py
# bc57bff verified
import spaces
import gradio as gr
import torch
import transformers
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.model.llava_gpt2 import LlavaGpt2ForCausalLM
from llava.train.arguments_dataclass import ModelArguments, DataArguments, TrainingArguments
from llava.train.dataset import tokenizer_image_token
# --- Model / tokenizer loading (executed once, at module import) ---
device = "cpu"
torch_dtype = torch.float32
if torch.cuda.is_available():
    # GPU: load weights in bfloat16; CPU: keep full float32 precision.
    device = "cuda"
    torch_dtype = torch.bfloat16

model_path = 'toshi456/llava-jp-1.3b-v1.1'

model = LlavaGpt2ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    device_map=device,
    use_safetensors=True,
    low_cpu_mem_usage=True,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    use_fast=False,
    padding_side="right",
    model_max_length=1532,
)

# Put the model in evaluation mode; this app only runs inference.
model.eval()

# Key into llava.conversation.conv_templates used to build chat prompts.
conv_mode = "v1"
def _preprocess_image(image):
    """Turn ``image`` into a pixel tensor on the model's device and dtype.

    The image is resized to a square whose side is the image processor's base
    height, multiplied by ``len(scales)`` when the vision tower defines scales.
    """
    vision_tower = model.get_model().vision_tower
    image_size = vision_tower.image_processor.size["height"]
    if vision_tower.scales is not None:
        image_size *= len(vision_tower.scales)
    pixel_values = vision_tower.image_processor(
        image,
        return_tensors='pt',
        size={"height": image_size, "width": image_size},
    )['pixel_values']
    # Cast straight to the target dtype; going through fp16 first (as before)
    # added a needless extra rounding step on the way to bfloat16.
    return pixel_values.to(device=device, dtype=torch_dtype)


def _extract_response(text, marker="システム: "):
    """Return the text after the first ``marker`` ("System: " role label).

    If the marker is absent, the whole ``text`` is returned instead of a
    silently truncated string.
    """
    _, found, response = text.partition(marker)
    return response if found else text


@spaces.GPU
def inference_fn(
    image,
    prompt,
    max_len,
    temperature,
    top_p,
):
    """Answer ``prompt`` about ``image`` with the LLaVA-JP model.

    Args:
        image: PIL image from the Gradio image input (``None`` when nothing
            was uploaded).
        prompt: User question; may be empty.
        max_len: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``0.0`` disables sampling.
        top_p: Nucleus-sampling threshold.

    Returns:
        The decoded model output with everything up to the first
        "システム: " marker removed.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    image_tensor = _preprocess_image(image)

    # Conversation: image placeholder + user text, followed by an empty reply turn.
    inp = DEFAULT_IMAGE_TOKEN + '\n' + (prompt or "")
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    input_ids = tokenizer_image_token(
        conv.get_prompt(),
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors='pt',
    ).unsqueeze(0).to(device)
    # A </sep> token ends up at the end of the input, so remove it.
    input_ids = input_ids[:, :-1]

    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=int(max_len),  # slider values may arrive as floats
        use_cache=True,
    )
    # Remove IMAGE_TOKEN_INDEX entries (presumably not decodable by the tokenizer).
    output_ids = [
        token_id for token_id in output_ids.tolist()[0]
        if token_id != IMAGE_TOKEN_INDEX
    ]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return _extract_response(output)
# Gradio UI: inputs and generation settings in the left column, answer in the right.
with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-JP Demo")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="prompt (optional)", value="")
            with gr.Accordion(label="Configs", open=False):
                max_len = gr.Slider(
                    minimum=10,
                    maximum=256,
                    value=128,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )
            input_button = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Output")

    # Order matches inference_fn's parameters: image, prompt, max_len, temperature, top_p.
    inputs = [input_image, prompt, max_len, temperature, top_p]
    # Run inference on button click or when the prompt textbox is submitted.
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])

    # Example rows: [image path, prompt, max new tokens, temperature, top p].
    img2txt_examples = gr.Examples(
        examples=[
            # "What is the cat doing?"
            ["./imgs/sample1.jpg", "猫は何をしていますか？", 32, 0.0, 0.9],
            # "Which brands of drinks are in this vending machine?"
            ["./imgs/sample2.jpg", "この自動販売機にはどのブランドの飲料が含まれていますか？", 256, 0.0, 0.9],
            # "Tell me how to make this dish."
            ["./imgs/sample3.jpg", "この料理の作り方を教えてください。", 256, 0.0, 0.9],
            # "Tell me the name of this computer."
            ["./imgs/sample4.jpg", "このコンピュータの名前を教えてください。", 256, 0.0, 0.9],
            # "Tell me a dish that can be made with these."
            ["./imgs/sample5.jpg", "これらを使って作ることができる料理を教えてください。", 256, 0.0, 0.9],
        ],
        inputs=inputs,
        cache_examples=False,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    app = demo.queue()
    app.launch()