import spaces
import gradio as gr
import torch
import transformers

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.model.llava_gpt2 import LlavaGpt2ForCausalLM
from llava.train.arguments_dataclass import ModelArguments, DataArguments, TrainingArguments
from llava.train.dataset import tokenizer_image_token


# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

model_path = 'toshi456/llava-jp-1.3b-v1.1'
model = LlavaGpt2ForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    torch_dtype=torch_dtype,
    device_map=device,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    model_max_length=1532,
    padding_side="right",
    use_fast=False,
)
model.eval()

conv_mode = "v1"


@spaces.GPU
def inference_fn(
    image,
    prompt,
    max_len,
    temperature,
    top_p,
):
    # Prepare inputs: image pre-processing.
    # When the vision tower uses multiple scales, the processor expects a
    # proportionally larger square image.
    image_size = model.get_model().vision_tower.image_processor.size["height"]
    if model.get_model().vision_tower.scales is not None:
        image_size = model.get_model().vision_tower.image_processor.size["height"] * len(model.get_model().vision_tower.scales)

    if device == "cuda":
        image_tensor = model.get_model().vision_tower.image_processor(
            image,
            return_tensors='pt',
            size={"height": image_size, "width": image_size}
        )['pixel_values'].half().cuda().to(torch_dtype)
    else:
        image_tensor = model.get_model().vision_tower.image_processor(
            image,
            return_tensors='pt',
            size={"height": image_size, "width": image_size}
        )['pixel_values'].to(torch_dtype)

    # Create the prompt from the conversation template.
    inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt,
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors='pt'
    ).unsqueeze(0)
    if device == "cuda":
        input_ids = input_ids.to(device)

    # The tokenizer appends a trailing separator token to the input, so drop it.
    input_ids = input_ids[:, :-1]

    # Generate
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_len,
        use_cache=True,
    )
    output_ids = [token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)

    # Keep only the text after the assistant marker ("システム: ").
    target = "システム: "
    idx = output.find(target)
    output = output[idx + len(target):]

    return output


with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-JP Demo")
    with gr.Row():
        with gr.Column():
            # input_instruction = gr.TextArea(label="instruction", value=DEFAULT_INSTRUCTION)
            input_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="prompt (optional)", value="")
            with gr.Accordion(label="Configs", open=False):
                max_len = gr.Slider(
                    minimum=10,
                    maximum=256,
                    value=128,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )
            # Button
            input_button = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Output")

    inputs = [input_image, prompt, max_len, temperature, top_p]
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])

    # Example prompts (kept in Japanese, matching the Japanese-language model).
    img2txt_examples = gr.Examples(examples=[
        [
            "./imgs/sample1.jpg",
            "猫は何をしていますか?",
            32,
            0.0,
            0.9,
        ],
        [
            "./imgs/sample2.jpg",
            "この自動販売機にはどのブランドの飲料が含まれていますか?",
            256,
            0.0,
            0.9,
        ],
        [
            "./imgs/sample3.jpg",
            "この料理の作り方を教えてください。",
            256,
            0.0,
            0.9,
        ],
        [
            "./imgs/sample4.jpg",
            "このコンピュータの名前を教えてください。",
            256,
            0.0,
            0.9,
        ],
        [
            "./imgs/sample5.jpg",
            "これらを使って作ることができる料理を教えてください。",
            256,
            0.0,
            0.9,
        ],
    ], inputs=inputs, cache_examples=False)


if __name__ == "__main__":
    demo.queue().launch()