"""Gradio demo for the MC-LLaVA 3B vision-language model.

The app loads the model and processor once at import time, caches projected
image features on disk (keyed by image content and the crop/token settings),
and streams generated answers into a Gradio text area.
"""
from __future__ import annotations

import contextlib
import hashlib
import os
import tempfile
from collections.abc import Iterator
from threading import Thread

# NOTE(review): `spaces` is deliberately kept ahead of `torch`, matching the
# original import order; Hugging Face's ZeroGPU guidance is to import it
# before CUDA-related packages -- confirm before reordering.
import spaces
import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor, TextIteratorStreamer

MODEL_ID = "visheratin/MC-LLaVA-3b"
CACHE_DIR = "visual_cache"

# Pick the device/dtype *before* loading the model so that the model and the
# cached image features always agree (half precision only when CUDA is used).
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    trust_remote_code=True,
).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)


def cached_vision_process(image, max_crops: int, num_tokens: int) -> torch.Tensor:
    """Return projected image features for ``image``, using an on-disk cache.

    Args:
        image: Image object exposing ``mode``, ``size`` and ``tobytes()``
            (the UI supplies a PIL image).
        max_crops: Forwarded to ``processor.image_processor``.
        num_tokens: Forwarded to ``model.vision_model``.

    Returns:
        The output of ``model.multi_modal_projector`` moved to
        ``DEVICE``/``DTYPE``. On a cache hit the value is loaded from
        ``visual_cache/<sha256>-<max_crops>-<num_tokens>.pt`` instead of being
        recomputed.
    """
    # Hash mode and size together with the raw bytes: two images with
    # different dimensions can share an identical byte buffer.
    digest = hashlib.sha256()
    digest.update(f"{image.mode}:{image.size}:".encode("utf-8"))
    digest.update(image.tobytes())
    image_hash = digest.hexdigest()
    cache_path = f"{CACHE_DIR}/{image_hash}-{max_crops}-{num_tokens}.pt"

    if os.path.exists(cache_path):
        # Cache files are produced only by this function. Load onto the CPU
        # first so a file saved from a CUDA run still loads on a CPU-only one.
        cached = torch.load(cache_path, map_location="cpu")
        return cached.to(DEVICE, dtype=DTYPE)

    processor_outputs = processor.image_processor([image], max_crops)
    # Inference only: skip autograd bookkeeping for the vision tower.
    with torch.no_grad():
        pixel_values = [
            value.to(model.device).to(model.dtype)
            for value in processor_outputs["pixel_values"]
        ]
        coords = [
            value.to(model.device).to(model.dtype)
            for value in processor_outputs["coords"]
        ]
        image_outputs = model.vision_model(pixel_values, coords, num_tokens)
        image_features = model.multi_modal_projector(image_outputs)

    # Write to a temporary file and rename it into place so that an
    # interrupted save never leaves a truncated cache entry behind.
    os.makedirs(CACHE_DIR, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=CACHE_DIR, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            torch.save(image_features, tmp_file)
        os.replace(tmp_path, cache_path)
    except Exception:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise

    return image_features.to(DEVICE, dtype=DTYPE)


@spaces.GPU(duration=20)
def answer_question(image, question: str, max_crops: int, num_tokens: int) -> Iterator[str]:
    """Stream the model's answer to ``question`` about ``image``.

    Args:
        image: Image from the UI, or ``None`` when nothing was uploaded.
        question: User question inserted into the chat prompt.
        max_crops: Value of the "Max crops" slider.
        num_tokens: Value of the "Number of image tokens" slider.

    Yields:
        The answer text accumulated so far, once per streamed chunk.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before asking a question.")

    # Slider values may arrive as floats; normalise them so the cache key and
    # the model calls always see integers.
    max_crops = int(max_crops)
    num_tokens = int(num_tokens)

    # NOTE(review): the model card's example prompt appears to include an
    # "<image>" placeholder line after "user" -- confirm whether the
    # processor needs it here.
    prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"

    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    inputs = processor(
        prompt,
        [image],
        model,
        max_crops=max_crops,
        num_tokens=num_tokens,
    )
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "image_features": cached_vision_process(image, max_crops, num_tokens),
        "streamer": streamer,
        "max_length": 1000,
        "use_cache": True,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "pad_token_id": processor.tokenizer.eos_token_id,
    }

    # Generation runs in a worker thread; this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Yield any non-empty text so single-character answers are shown too.
        if buffer:
            yield buffer

    thread.join()


with gr.Blocks() as demo:
    gr.HTML("<h1>MC-LLaVA 3B</h1>")
    gr.HTML(
        "<p>MC-LLaVA 3B is a model that can answer questions about small "
        "details in high-resolution images. Check out the model card for more "
        "details. If you have any questions or ideas how to make the model "
        "better, let me know.</p>"
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label="Question",
                placeholder="e.g. What is this?",
                scale=4,
            )
            submit = gr.Button("Submit", scale=1)
        with gr.Row():
            max_crops = gr.Slider(
                minimum=0,
                maximum=200,
                step=5,
                value=0,
                label="Max crops",
            )
            num_tokens = gr.Slider(
                minimum=728,
                maximum=2184,
                step=10,
                value=728,
                label="Number of image tokens",
            )
    with gr.Row():
        img = gr.Image(type="pil", label="Upload or Drag an Image")
        output = gr.TextArea(label="Answer")

    # Both the button and pressing Enter in the textbox trigger the same call.
    event_inputs = [img, prompt, max_crops, num_tokens]
    submit.click(answer_question, event_inputs, output)
    prompt.submit(answer_question, event_inputs, output)

demo.queue().launch(debug=True)