ranggafermata commited on
Commit
ff129ad
·
verified ·
1 Parent(s): 601c861

Upload app (7).py

Browse files
Files changed (1) hide show
  1. app (7).py +289 -0
app (7).py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ import json
5
+ import time
6
+ from threading import Thread
7
+ from typing import Iterable
8
+
9
+ import gradio as gr
10
+ import spaces
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ import cv2
15
+
16
+ from transformers import (
17
+ Qwen2VLForConditionalGeneration,
18
+ Qwen2_5_VLForConditionalGeneration,
19
+ AutoModelForImageTextToText,
20
+ AutoProcessor,
21
+ TextIteratorStreamer,
22
+ )
23
+ from transformers.image_utils import load_image
24
+ from gradio.themes import Soft
25
+ from gradio.themes.utils import colors, fonts, sizes
26
+
27
+ # --- Theme and CSS Definition ---
28
+
29
+ # Define the SteelBlue color palette
30
+ colors.steel_blue = colors.Color(
31
+ name="steel_blue",
32
+ c50="#EBF3F8",
33
+ c100="#D3E5F0",
34
+ c200="#A8CCE1",
35
+ c300="#7DB3D2",
36
+ c400="#529AC3",
37
+ c500="#4682B4", # SteelBlue base color
38
+ c600="#3E72A0",
39
+ c700="#36638C",
40
+ c800="#2E5378",
41
+ c900="#264364",
42
+ c950="#1E3450",
43
+ )
44
+
45
+ class SteelBlueTheme(Soft):
46
+ def __init__(
47
+ self,
48
+ *,
49
+ primary_hue: colors.Color | str = colors.gray,
50
+ secondary_hue: colors.Color | str = colors.steel_blue,
51
+ neutral_hue: colors.Color | str = colors.slate,
52
+ text_size: sizes.Size | str = sizes.text_lg,
53
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
54
+ fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
55
+ ),
56
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
57
+ fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
58
+ ),
59
+ ):
60
+ super().__init__(
61
+ primary_hue=primary_hue,
62
+ secondary_hue=secondary_hue,
63
+ neutral_hue=neutral_hue,
64
+ text_size=text_size,
65
+ font=font,
66
+ font_mono=font_mono,
67
+ )
68
+ super().set(
69
+ background_fill_primary="*primary_50",
70
+ background_fill_primary_dark="*primary_900",
71
+ body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
72
+ body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
73
+ button_primary_text_color="white",
74
+ button_primary_text_color_hover="white",
75
+ button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
76
+ button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
77
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
78
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
79
+ button_secondary_text_color="black",
80
+ button_secondary_text_color_hover="white",
81
+ button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
82
+ button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
83
+ button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
84
+ button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
85
+ slider_color="*secondary_500",
86
+ slider_color_dark="*secondary_600",
87
+ block_title_text_weight="600",
88
+ block_border_width="3px",
89
+ block_shadow="*shadow_drop_lg",
90
+ button_primary_shadow="*shadow_drop_lg",
91
+ button_large_padding="11px",
92
+ color_accent_soft="*primary_100",
93
+ block_label_background_fill="*primary_200",
94
+ )
95
+
96
+ # Instantiate the new theme
97
+ steel_blue_theme = SteelBlueTheme()
98
+
99
+ css = """
100
+ #main-title h1 {
101
+ font-size: 2.3em !important;
102
+ }
103
+ #output-title h2 {
104
+ font-size: 2.1em !important;
105
+ }
106
+ """
107
+
108
+ # Constants for text generation
109
+ MAX_MAX_NEW_TOKENS = 4096
110
+ DEFAULT_MAX_NEW_TOKENS = 1024
111
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
112
+
113
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
114
+
115
+ print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
116
+ print("torch.__version__ =", torch.__version__)
117
+ print("torch.version.cuda =", torch.version.cuda)
118
+ print("cuda available:", torch.cuda.is_available())
119
+ print("cuda device count:", torch.cuda.device_count())
120
+ if torch.cuda.is_available():
121
+ print("current device:", torch.cuda.current_device())
122
+ print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
123
+
124
+ print("Using device:", device)
125
+
126
+ # --- Model Loading ---
127
+ # Load Nanonets-OCR2-3B
128
+ MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
129
+ processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
130
+ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
131
+ MODEL_ID_V,
132
+ trust_remote_code=True,
133
+ torch_dtype=torch.float16
134
+ ).to(device).eval()
135
+
136
+ # Load Qwen2-VL-OCR-2B-Instruct
137
+ MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
138
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
139
+ model_x = Qwen2VLForConditionalGeneration.from_pretrained(
140
+ MODEL_ID_X,
141
+ trust_remote_code=True,
142
+ torch_dtype=torch.float16
143
+ ).to(device).eval()
144
+
145
+ # Load Aya-Vision-8b
146
+ MODEL_ID_A = "CohereForAI/aya-vision-8b"
147
+ processor_a = AutoProcessor.from_pretrained(MODEL_ID_A, trust_remote_code=True)
148
+ model_a = AutoModelForImageTextToText.from_pretrained(
149
+ MODEL_ID_A,
150
+ trust_remote_code=True,
151
+ torch_dtype=torch.float16
152
+ ).to(device).eval()
153
+
154
+ # Load olmOCR-7B-0725
155
+ MODEL_ID_W = "allenai/olmOCR-7B-0725"
156
+ processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
157
+ model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
158
+ MODEL_ID_W,
159
+ trust_remote_code=True,
160
+ torch_dtype=torch.float16
161
+ ).to(device).eval()
162
+
163
+ # Load RolmOCR
164
+ MODEL_ID_M = "reducto/RolmOCR"
165
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
166
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
167
+ MODEL_ID_M,
168
+ trust_remote_code=True,
169
+ torch_dtype=torch.float16
170
+ ).to(device).eval()
171
+
172
+
173
+ @spaces.GPU
174
+ def generate_image(model_name: str, text: str, image: Image.Image,
175
+ max_new_tokens: int, temperature: float, top_p: float,
176
+ top_k: int, repetition_penalty: float):
177
+ """
178
+ Generates responses using the selected model for image input.
179
+ Yields raw text and Markdown-formatted text.
180
+ """
181
+ if model_name == "RolmOCR-7B":
182
+ processor = processor_m
183
+ model = model_m
184
+ elif model_name == "Qwen2-VL-OCR-2B":
185
+ processor = processor_x
186
+ model = model_x
187
+ elif model_name == "Nanonets-OCR2-3B":
188
+ processor = processor_v
189
+ model = model_v
190
+ elif model_name == "Aya-Vision-8B":
191
+ processor = processor_a
192
+ model = model_a
193
+ elif model_name == "olmOCR-7B-0725":
194
+ processor = processor_w
195
+ model = model_w
196
+ else:
197
+ yield "Invalid model selected.", "Invalid model selected."
198
+ return
199
+
200
+ if image is None:
201
+ yield "Please upload an image.", "Please upload an image."
202
+ return
203
+
204
+ messages = [{
205
+ "role": "user",
206
+ "content": [
207
+ {"type": "image"},
208
+ {"type": "text", "text": text},
209
+ ]
210
+ }]
211
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
212
+
213
+ inputs = processor(
214
+ text=[prompt_full],
215
+ images=[image],
216
+ return_tensors="pt",
217
+ padding=True).to(device)
218
+
219
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
220
+ generation_kwargs = {
221
+ **inputs,
222
+ "streamer": streamer,
223
+ "max_new_tokens": max_new_tokens,
224
+ "do_sample": True,
225
+ "temperature": temperature,
226
+ "top_p": top_p,
227
+ "top_k": top_k,
228
+ "repetition_penalty": repetition_penalty,
229
+ }
230
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
231
+ thread.start()
232
+ buffer = ""
233
+ for new_text in streamer:
234
+ buffer += new_text
235
+ buffer = buffer.replace("<|im_end|>", "")
236
+ time.sleep(0.01)
237
+ yield buffer, buffer
238
+
239
+
240
+ # Define examples for image inference
241
+ image_examples = [
242
+ ["Extract the full page.", "images/ocr.png"],
243
+ ["Extract the content.", "images/4.png"],
244
+ ["Convert this page to doc [table] precisely for markdown.", "images/0.png"]
245
+ ]
246
+
247
+
248
+ # Create the Gradio Interface
249
+ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
250
+ gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
251
+ with gr.Row():
252
+ with gr.Column(scale=2):
253
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
254
+ image_upload = gr.Image(type="pil", label="Upload Image", height=290)
255
+
256
+ image_submit = gr.Button("Submit", variant="primary")
257
+ gr.Examples(
258
+ examples=image_examples,
259
+ inputs=[image_query, image_upload]
260
+ )
261
+
262
+ with gr.Accordion("Advanced options", open=False):
263
+ max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
264
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
265
+ top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
266
+ top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
267
+ repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
268
+
269
+ with gr.Column(scale=3):
270
+ gr.Markdown("## Output", elem_id="output-title")
271
+ output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
272
+ with gr.Accordion("(Result.md)", open=False):
273
+ markdown_output = gr.Markdown(label="(Result.Md)")
274
+
275
+ model_choice = gr.Radio(
276
+ choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
277
+ "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
278
+ label="Select Model",
279
+ value="Nanonets-OCR2-3B"
280
+ )
281
+
282
+ image_submit.click(
283
+ fn=generate_image,
284
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
285
+ outputs=[output, markdown_output]
286
+ )
287
+
288
+ if __name__ == "__main__":
289
+ demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)