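# Gradio demo for visheratin/MC-LLaVA-3b: ask questions about small details in
# high-resolution images and stream the model's answer back token by token.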
from __future__ import annotations

import spaces

import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os

from transformers import AutoModel, AutoProcessor
import torch

# Pick the device and dtype up front so the model, the processed image
# features, and the cached tensors all end up in the same place.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

model = AutoModel.from_pretrained(
    "visheratin/MC-LLaVA-3b", torch_dtype=DTYPE, trust_remote_code=True
).to(DEVICE)

processor = AutoProcessor.from_pretrained("visheratin/MC-LLaVA-3b", trust_remote_code=True)

def cached_vision_process(image, max_crops, num_tokens):
    # Cache vision features on disk, keyed by image hash and crop/token settings,
    # so repeated requests for the same image skip the expensive vision pass.
    image_hash = hashlib.sha256(image.tobytes()).hexdigest()
    cache_path = f"visual_cache/{image_hash}-{max_crops}-{num_tokens}.pt"
    if os.path.exists(cache_path):
        return torch.load(cache_path).to(DEVICE, dtype=DTYPE)
    else:
        processor_outputs = processor.image_processor([image], max_crops)
        pixel_values = processor_outputs["pixel_values"]
        pixel_values = [
            value.to(model.device).to(model.dtype) for value in pixel_values
        ]
        coords = processor_outputs["coords"]
        coords = [value.to(model.device).to(model.dtype) for value in coords]
        image_outputs = model.vision_model(pixel_values, coords, num_tokens)
        image_features = model.multi_modal_projector(image_outputs)
        os.makedirs("visual_cache", exist_ok=True)
        torch.save(image_features, cache_path)
        return image_features.to(DEVICE, dtype=DTYPE)

@spaces.GPU(duration=20)
def answer_question(image, question, max_crops, num_tokens, sample, temperature, top_k):
    if question is None or question.strip() == "":
        yield "Please ask a question"
        return
    if image is None:
        yield "Please upload an image"
        return
    prompt = f"""<|im_start|>user
<image>
{question}<|im_end|>
<|im_start|>assistant
"""
    # skip_prompt=True keeps the echoed prompt out of the stream, so only newly
    # generated answer text reaches the UI.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    with torch.inference_mode():
        inputs = processor(prompt, [image], model, max_crops=max_crops, num_tokens=num_tokens)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "image_features": inputs["image_features"],
        "streamer": streamer,
        "max_length": 1000,
        "use_cache": True,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "pad_token_id": processor.tokenizer.eos_token_id,
        "temperature": temperature,
        "do_sample": sample,
        "top_k": top_k,
    }
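    # Run generation in a background thread so this generator can consume the
    # streamer and yield partial answers to Gradio as they arrive.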
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    output_started = False
    for new_text in streamer:
        if not output_started:
            if "<|im_start|>assistant" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer


with gr.Blocks() as demo:
    gr.HTML("<h1 class='gradio-heading'><center>MC-LLaVA 3B</center></h1>")
    gr.HTML(
        "<center><p class='gradio-sub-heading'>MC-LLaVA 3B is a model that can answer questions about small details in high-resolution images. </p></center>"
    )
    gr.HTML(
        "<center><p class='gradio-sub-heading'>The magic of LLM happened when we can combine them with different data sources. We are able to search for object on images and get answer prepared by Large Language Model.</p></center>"
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label="Question", placeholder="e.g. What is this?", scale=4
            )
            submit = gr.Button(
                "Submit",
                scale=1,
            )
        with gr.Row():
            max_crops = gr.Slider(minimum=0, maximum=200, step=5, value=0, label="Max crops")
            num_tokens = gr.Slider(minimum=728, maximum=2184, step=10, value=728, label="Number of image tokens")
        with gr.Row():
            img = gr.Image(type="pil", label="Upload or Drag an Image")
            output = gr.TextArea(label="Answer")
        with gr.Row():
            sample = gr.Checkbox(label="Sample", value=False)
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0, label="Temperature")
            top_k = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Top-K")

    submit.click(answer_question, [img, prompt, max_crops, num_tokens, sample, temperature, top_k], output)
    prompt.submit(answer_question, [img, prompt, max_crops, num_tokens, sample, temperature, top_k], output)

demo.queue().launch(debug=True)