File size: 4,465 Bytes
0b17160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa179a2
0b17160
 
 
 
 
 
b7864ca
 
0b17160
 
 
b7864ca
0b17160
 
 
 
 
aa179a2
9d8bdb3
aa179a2
0b17160
 
 
 
 
e4536f3
0b17160
e4536f3
67e4fcc
e4536f3
 
0b17160
 
e4536f3
0b17160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa179a2
 
 
 
0b17160
aa179a2
 
0b17160
aa179a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import annotations

import spaces

import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os

from transformers import AutoModel, AutoProcessor
import torch

# Load the MC-LLaVA 3B checkpoint (it ships custom modelling code, hence
# trust_remote_code=True) in float16 and move it onto the GPU.
# NOTE(review): the target device is hard-coded to "cuda" here, while
# DEVICE/DTYPE below fall back to "cpu"/float32 when CUDA is unavailable; on a
# CPU-only host this call fails before that fallback is ever consulted. Kept
# as-is because the Space presumably relies on `spaces` making "cuda" usable at
# import time — confirm before switching to `.to(DEVICE)`.
model = AutoModel.from_pretrained(
    "visheratin/MC-LLaVA-3b",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

processor = AutoProcessor.from_pretrained(
    "visheratin/MC-LLaVA-3b",
    trust_remote_code=True,
)

# Device and dtype that cached_vision_process casts its image features to.
DEVICE, DTYPE = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

def cached_vision_process(image, max_crops, num_tokens):
    """Return projected visual features for *image*, using an on-disk cache.

    On a cache miss the image goes through ``processor.image_processor``,
    ``model.vision_model`` and ``model.multi_modal_projector``, and the result
    is saved under ``visual_cache/`` keyed by image content and settings. A
    later call with the same image and settings loads that file instead of
    re-running the vision model.

    Args:
        image: PIL image (must expose ``tobytes()``, ``mode`` and ``size``).
        max_crops: Maximum number of crops passed to the image processor.
        num_tokens: Number of image tokens passed to the vision model.

    Returns:
        The feature tensor, moved to ``DEVICE`` and cast to ``DTYPE``.
    """
    import tempfile

    # tobytes() is only the raw pixel buffer, with no dimensions or mode, so
    # e.g. a 4x4 and a 2x8 greyscale image with identical pixels would hash
    # the same. Fold mode and size into the digest to keep entries distinct.
    width, height = image.size
    hasher = hashlib.sha256()
    hasher.update(f"{image.mode}:{width}x{height}:".encode("utf-8"))
    hasher.update(image.tobytes())
    image_hash = hasher.hexdigest()
    cache_path = f"visual_cache/{image_hash}-{max_crops}-{num_tokens}.pt"

    if os.path.exists(cache_path):
        # map_location lets an entry written on a GPU host load where CUDA is
        # unavailable. The file is only ever written by this function, but
        # torch.load still unpickles it — do not point this at untrusted files.
        return torch.load(cache_path, map_location=DEVICE).to(DEVICE, dtype=DTYPE)

    # Pure inference: skip autograd bookkeeping (which would otherwise be
    # recorded if the model parameters require grad).
    with torch.no_grad():
        processor_outputs = processor.image_processor([image], max_crops)
        pixel_values = [
            value.to(model.device).to(model.dtype)
            for value in processor_outputs["pixel_values"]
        ]
        coords = [
            value.to(model.device).to(model.dtype)
            for value in processor_outputs["coords"]
        ]
        image_outputs = model.vision_model(pixel_values, coords, num_tokens)
        image_features = model.multi_modal_projector(image_outputs)

    os.makedirs("visual_cache", exist_ok=True)
    # Write to a temp file in the same directory, then atomically rename it
    # into place, so a concurrent request never reads a half-written entry.
    fd, tmp_path = tempfile.mkstemp(dir="visual_cache", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as handle:
            torch.save(image_features, handle)
        os.replace(tmp_path, cache_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return image_features.to(DEVICE, dtype=DTYPE)

@spaces.GPU(duration=20)
def answer_question(image, question, max_crops, num_tokens, sample, temperature, top_k):
    """Stream the model's answer to *question* about *image*.

    Gradio generator handler: each yielded value is the full answer text
    accumulated so far. Generation runs in a background thread and feeds a
    ``TextIteratorStreamer`` that this function consumes.

    Args:
        image: PIL image from the ``gr.Image`` component, or ``None`` when the
            user has not uploaded one.
        question: User question, inserted into the ChatML-style prompt.
        max_crops: Maximum number of image crops for the processor.
        num_tokens: Number of image tokens for the processor.
        sample: Passed as ``do_sample`` to ``model.generate``.
        temperature: Passed as ``temperature`` to ``model.generate``.
        top_k: Passed (as an int) as ``top_k`` to ``model.generate``.

    Yields:
        The growing answer string.

    Raises:
        gr.Error: If no image was provided.
        Exception: Any error raised by ``model.generate`` in the worker thread
            is re-raised here instead of leaving the request hanging.
    """
    if image is None:
        raise gr.Error("Please upload an image before asking a question.")

    prompt = f"""<|im_start|>user
<image>
{question}<|im_end|>
<|im_start|>assistant
"""
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    with torch.inference_mode():
        inputs = processor(prompt, [image], model, max_crops=max_crops, num_tokens=num_tokens)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "image_features": inputs["image_features"],
        "streamer": streamer,
        "max_length": 1000,
        "use_cache": True,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "pad_token_id": processor.tokenizer.eos_token_id,
        "temperature": temperature,
        "do_sample": sample,
        # Slider values may arrive as floats; generate's top-k handling
        # expects an integer.
        "top_k": int(top_k),
    }

    # Without this wrapper, an exception in generate (e.g. do_sample=True with
    # temperature=0) dies silently in the thread and the streamer, which has no
    # timeout, blocks the loop below forever.
    worker_errors = []

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the request thread below
            worker_errors.append(exc)
            streamer.end()  # push the stop signal so the consumer loop exits

    thread = Thread(target=_generate)
    thread.start()

    buffer = ""
    last_yielded = None
    output_started = False
    for new_text in streamer:
        # The streamer is created without skip_prompt, so the decoded prompt is
        # streamed first; discard chunks until the assistant marker shows up.
        # NOTE(review): relies on "<|im_start|>assistant" surviving
        # skip_special_tokens=True decoding for this tokenizer — confirm.
        if not output_started:
            if "<|im_start|>assistant" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            last_yielded = buffer
            yield buffer

    thread.join()
    if worker_errors:
        raise worker_errors[0]
    # A one-character answer never passes the len > 1 gate above; flush it.
    if buffer and buffer != last_yielded:
        yield buffer


# Gradio UI: question box + submit button, crop/token sliders, image input,
# streamed answer area, and sampling controls, all wired to answer_question.
with gr.Blocks() as demo:
    gr.HTML("<h1 class='gradio-heading'><center>MC-LLaVA 3B</center></h1>")
    gr.HTML(
        "<center><p class='gradio-sub-heading'>MC-LLaVA 3B is a model that can answer "
        "questions about small details in high-resolution images. Check out the "
        "<a href='https://huggingface.co/visheratin/MC-LLaVA-3b'>model card</a> for more "
        "details. If you have any questions or ideas how to make the model better, "
        "<a href='https://x.com/visheratin'>let me know</a></p></center>"
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label="Question", placeholder="e.g. What is this?", scale=4
            )
            submit = gr.Button(
                "Submit",
                scale=1,
            )
        with gr.Row():
            max_crops = gr.Slider(minimum=0, maximum=200, step=5, value=0, label="Max crops")
            num_tokens = gr.Slider(minimum=728, maximum=2184, step=10, value=728, label="Number of image tokens")
        with gr.Row():
            img = gr.Image(type="pil", label="Upload or Drag an Image")
            output = gr.TextArea(label="Answer")
        with gr.Row():
            sample = gr.Checkbox(label="Sample", value=False)
            # NOTE(review): temperature 0 is only meaningful with Sample off;
            # with sampling on, generate rejects a non-positive temperature.
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0, label="Temperature")
            top_k = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Top-K")

    # The Submit button and pressing Enter in the question box run the same
    # handler with the same inputs; keep the list in one place.
    handler_inputs = [img, prompt, max_crops, num_tokens, sample, temperature, top_k]
    submit.click(answer_question, handler_inputs, output)
    prompt.submit(answer_question, handler_inputs, output)

demo.queue().launch(debug=True)