visheratin commited on
Commit
0b17160
1 Parent(s): 4ab8a0b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import spaces
4
+
5
+ import gradio as gr
6
+ from threading import Thread
7
+ from transformers import TextIteratorStreamer
8
+ import hashlib
9
+ import os
10
+
11
+ from transformers import AutoModel, AutoProcessor
12
+ import torch
13
+
14
+ model = AutoModel.from_pretrained("visheratin/MC-LLaVA-3b", torch_dtype=torch.float16, trust_remote_code=True).to("cuda")
15
+
16
+ processor = AutoProcessor.from_pretrained("visheratin/MC-LLaVA-3b", trust_remote_code=True)
17
+
18
+ if torch.cuda.is_available():
19
+ DEVICE = "cuda"
20
+ DTYPE = torch.float16
21
+ else:
22
+ DEVICE = "cpu"
23
+ DTYPE = torch.float32
24
+
25
+ def cached_vision_process(image, max_crops, num_tokens):
26
+ image_hash = hashlib.sha256(image.tobytes()).hexdigest()
27
+ cache_path = f"visual_cache/{image_hash}-{max_crops}-{num_tokens}.pt"
28
+ if os.path.exists(cache_path):
29
+ return torch.load(cache_path).to(DEVICE, dtype=DTYPE)
30
+ else:
31
+ processor_outputs = processor.image_processor([image], max_crops)
32
+ pixel_values = processor_outputs["pixel_values"]
33
+ pixel_values = [
34
+ value.to(model.device).to(model.dtype) for value in pixel_values
35
+ ]
36
+ coords = processor_outputs["coords"]
37
+ coords = [value.to(model.device).to(model.dtype) for value in coords]
38
+ image_outputs = model.vision_model(pixel_values, coords, num_tokens)
39
+ image_features = model.multi_modal_projector(image_outputs)
40
+ os.makedirs("visual_cache", exist_ok=True)
41
+ torch.save(image_features, cache_path)
42
+ return image_features.to(DEVICE, dtype=DTYPE)
43
+
44
+ @spaces.GPU(duration=20)
45
+ def answer_question(image, question, max_crops, num_tokens):
46
+ prompt = f"""<|im_start|>user
47
+ <image>
48
+ {question}<|im_end|>
49
+ <|im_start|>assistant
50
+ """
51
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
52
+ inputs = processor(prompt, [image], model, max_crops=max_crops, num_tokens=num_tokens)
53
+ generation_kwargs = {
54
+ "input_ids": inputs["input_ids"],
55
+ "attention_mask": inputs["attention_mask"],
56
+ "image_features": cached_vision_process(image, max_crops, num_tokens),
57
+ "streamer": streamer,
58
+ "max_length": 1000,
59
+ "use_cache": True,
60
+ "eos_token_id": processor.tokenizer.eos_token_id,
61
+ "pad_token_id": processor.tokenizer.eos_token_id,
62
+ }
63
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
64
+ thread.start()
65
+
66
+ buffer = ""
67
+ for new_text in streamer:
68
+ buffer += new_text
69
+ if len(buffer) > 1:
70
+ yield buffer
71
+
72
+
73
+ with gr.Blocks() as demo:
74
+ gr.HTML("<h1 class='gradio-heading'><center>MC-LLaVA 3B</center></h1>")
75
+ gr.HTML(
76
+ "<center><p class='gradio-sub-heading'>MC-LLaVA 3B is a model that can answer questions about small details in high-resolution images. Check out the <a href='https://huggingface.co/visheratin/MC-LLaVA-3b'>model card</a> for more details. If you have any questions or ideas hot to make the model better, <a href='https://x.com/visheratin'>let me know</a></p></center>"
77
+ )
78
+ with gr.Group():
79
+ with gr.Row():
80
+ prompt = gr.Textbox(
81
+ label="Question", placeholder="e.g. What is this?", scale=4
82
+ )
83
+ submit = gr.Button(
84
+ "Submit",
85
+ scale=1,
86
+ )
87
+ with gr.Row():
88
+ max_crops = gr.Slider(minimum=0, maximum=200, step=5, value=0, label="Max crops")
89
+ num_tokens = gr.Slider(minimum=728, maximum=2184, step=10, value=728, label="Number of image tokens")
90
+ with gr.Row():
91
+ img = gr.Image(type="pil", label="Upload or Drag an Image")
92
+ output = gr.TextArea(label="Answer")
93
+
94
+ submit.click(answer_question, [img, prompt, max_crops, num_tokens], output)
95
+ prompt.submit(answer_question, [img, prompt, max_crops, num_tokens], output)
96
+
97
+ demo.queue().launch(debug=True)