import time
from threading import Thread

import gradio as gr
import spaces
import torch
from peft import PeftModel
from transformers import (
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    TextIteratorStreamer,
)

base_model = "llava-hf/llava-v1.6-mistral-7b-hf"
finetune_repo = "erwannd/llava-v1.6-mistral-7b-finetune-combined4k"

# Load the base model in fp16 and attach the LoRA adapter from the finetune repo.
processor = LlavaNextProcessor.from_pretrained(base_model)
model = LlavaNextForConditionalGeneration.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(model, finetune_repo)
model.to("cuda:0")


@spaces.GPU
def predict(image, input_text):
    image = image.convert("RGB")
    # LLaVA-NeXT (Mistral) chat format; the <image> placeholder marks where the
    # image features are spliced into the prompt.
    prompt = f"[INST] <image>\n{input_text} [/INST]"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

    streamer = TextIteratorStreamer(processor, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200, do_sample=False)
    # Run generation in a background thread so tokens can be consumed from the
    # streamer while they are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # The streamer echoes the prompt (minus special tokens such as <image>),
    # so strip that prefix from the accumulated output before yielding.
    text_prompt = f"[INST] \n{input_text} [/INST]"
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer[len(text_prompt):]
        time.sleep(0.04)
        yield generated_text_without_prompt


image = gr.components.Image(type="pil")
input_prompt = gr.components.Textbox(label="Input Prompt")
model_output = gr.components.Textbox(label="Model Output")

examples = [
    ["./examples/bar_m01.png", "Evaluate and explain if this chart is misleading"],
    ["./examples/bar_n01.png", "Is this chart misleading? Explain"],
    ["./examples/fox_news_cropped.png", "Tell me if this chart is misleading"],
    ["./examples/line_m01.png", "Explain if this chart is misleading"],
    ["./examples/line_m04.png", "Evaluate and explain if this chart is misleading"],
    ["./examples/pie_m01.png", "Evaluate if this chart is misleading, if so explain"],
    ["./examples/pie_m02.png", "Is this chart misleading? Explain"],
]

description_markdown = """Demo for [LlavaNext](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) finetuned on a [charts dataset](https://huggingface.co/datasets/chart-misinformation-detection/bar_line_pie_4k)"""
title = "LlavaNext finetuned on Misleading Chart Dataset"

interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    examples=examples,
    title=title,
    theme="gradio/soft",
    cache_examples=False,
    description=description_markdown,
)
interface.launch()