prithivMLmods committed on
Commit 2768728 · verified · Parent: 81ecbe9

update app

Files changed (1): app.py (+161, -158)
app.py CHANGED
@@ -1,172 +1,175 @@
  import gradio as gr
  import torch
- from transformers import AutoModel, AutoTokenizer
- import spaces
- import os
- import tempfile
- from PIL import Image, ImageDraw
- import re
-
- # --- 1. Load Model and Tokenizer directly to the correct device ---
- print("Determining device...")
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Using device: {device}")
-
- print("Loading model and tokenizer...")
- model_name = "lvyufeng/DeepSeek-OCR-Community-Latest"
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
- # Load the model directly to the specified device and set to evaluation mode
- model = AutoModel.from_pretrained(
-     model_name,
-     _attn_implementation="flash_attention_2",
-     trust_remote_code=True,
-     use_safetensors=True,
- ).to(device).eval()  # Move to device and set to eval mode
-
- # Also apply the desired dtype if using a GPU
- if device.type == 'cuda':
-     model = model.to(torch.bfloat16)
-
- print("✅ Model loaded successfully to device and in eval mode.")
-
-
- # --- Helper function to find pre-generated result images ---
- def find_result_image(path):
-     for filename in os.listdir(path):
-         if "grounding" in filename or "result" in filename:
-             try:
-                 image_path = os.path.join(path, filename)
-                 return Image.open(image_path)
-             except Exception as e:
-                 print(f"Error opening result image {filename}: {e}")
-     return None
-
- # --- 2. Main Processing Function (Simplified) ---
- @spaces.GPU
- def process_ocr_task(image, model_size, task_type, ref_text):
      """
-     Processes an image with DeepSeek-OCR. The model is already on the correct device.
      """
      if image is None:
-         return "Please upload an image first.", None
-
-     # No need to move the model to GPU here; it's already done at startup.
-     print("✅ Model is already on the designated device.")
-
-     with tempfile.TemporaryDirectory() as output_path:
-         # Build the prompt for the selected task
-         if task_type == "📝 Free OCR":
-             prompt = "<image>\nFree OCR."
-         elif task_type == "📄 Convert to Markdown":
-             prompt = "<image>\n<|grounding|>Convert the document to markdown."
-         elif task_type == "📈 Parse Figure":
-             prompt = "<image>\nParse the figure."
-         elif task_type == "🔍 Locate Object by Reference":
-             if not ref_text or ref_text.strip() == "":
-                 raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
-             prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
-         else:
-             prompt = "<image>\nFree OCR."
-
-         temp_image_path = os.path.join(output_path, "temp_image.png")
-         image.save(temp_image_path)
-
-         # Configure model size
-         size_configs = {
-             "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
-             "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
-             "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
-             "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
-             "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
-         }
-         config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
-
-         print(f"🏃 Running inference with prompt: {prompt}")
-         # Use the globally defined 'model', which is already on the GPU
-         text_result = model.infer(
-             tokenizer,
-             prompt=prompt,
-             image_file=temp_image_path,
-             output_path=output_path,
-             base_size=config["base_size"],
-             image_size=config["image_size"],
-             crop_mode=config["crop_mode"],
-             save_results=True,
-             test_compress=True,
-             eval_mode=True,
-         )

-         print(f"====\n📄 Text Result: {text_result}\n====")
-
-         # --- Logic to draw bounding boxes ---
-         result_image_pil = None
-         pattern = re.compile(r"<\|det\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|/det\|>")
-         matches = list(pattern.finditer(text_result))
-
-         if matches:
-             print(f"✅ Found {len(matches)} bounding box(es). Drawing on the original image.")
-             image_with_bboxes = image.copy()
-             draw = ImageDraw.Draw(image_with_bboxes)
-             w, h = image.size
-
-             for match in matches:
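-                 # <|det|> coordinates come back on a 0-1000 normalized grid,
-                 # hence the division by 1000 when rescaling to pixel space below.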
-                 coords_norm = [int(c) for c in match.groups()]
-                 x1_norm, y1_norm, x2_norm, y2_norm = coords_norm
-
-                 x1 = int(x1_norm / 1000 * w)
-                 y1 = int(y1_norm / 1000 * h)
-                 x2 = int(x2_norm / 1000 * w)
-                 y2 = int(y2_norm / 1000 * h)
-
-                 draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
-
-             result_image_pil = image_with_bboxes
-         else:
-             print("⚠️ No bounding box coordinates found in the text result. Falling back to searching for a result image file.")
-             result_image_pil = find_result_image(output_path)
-
-         return text_result, result_image_pil
-
-
- # --- 3. Build the Gradio Interface ---
- with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         """
-         # 🐳 Full Demo of DeepSeek-OCR 🐳
-
-         **💡 How to use:**
-         1. **Upload an image** using the upload box.
-         2. Select a **Resolution**. `Gundam` is recommended for most documents.
-         3. Choose a **Task Type**:
-            - **📝 Free OCR**: Extracts raw text from the image.
-            - **📄 Convert to Markdown**: Converts the document into Markdown, preserving structure.
-            - **📈 Parse Figure**: Extracts structured data from charts and figures.
-            - **🔍 Locate Object by Reference**: Finds a specific object or piece of text.
-         4. If this is helpful, please give it a like! 🙏 ❤️
-         """
      )

-     with gr.Row():
-         with gr.Column(scale=1):
-             image_input = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
-             model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Gundam (Recommended)", label="⚙️ Resolution Size")
-             task_type = gr.Dropdown(choices=["📝 Free OCR", "📄 Convert to Markdown", "📈 Parse Figure", "🔍 Locate Object by Reference"], value="📄 Convert to Markdown", label="🚀 Task Type")
-             ref_text_input = gr.Textbox(label="📝 Reference Text (for Locate task)", placeholder="e.g., the teacher, 20-10, a red car...", visible=False)
-             submit_btn = gr.Button("Process Image", variant="primary")

-         with gr.Column(scale=2):
-             output_text = gr.Textbox(label="📄 Text Result", lines=15, show_copy_button=True)
-             output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

-     # --- UI Interaction Logic ---
-     def toggle_ref_text_visibility(task):
-         return gr.Textbox(visible=True) if task == "🔍 Locate Object by Reference" else gr.Textbox(visible=False)

-     task_type.change(fn=toggle_ref_text_visibility, inputs=task_type, outputs=ref_text_input)
-     submit_btn.click(fn=process_ocr_task, inputs=[image_input, model_size, task_type, ref_text_input], outputs=[output_text, output_image])

- # --- 4. Launch the App ---
  if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
 
+ import os
+ import sys
+ import spaces
+ from typing import Iterable
  import gradio as gr
  import torch
+ import requests
+ from PIL import Image
+ from transformers import AutoProcessor, Florence2ForConditionalGeneration
+ from gradio.themes import Soft
+ from gradio.themes.utils import colors, fonts, sizes
+
+ colors.steel_blue = colors.Color(
+     name="steel_blue",
+     c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2",
+     c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C",
+     c800="#2E5378", c900="#264364", c950="#1E3450",
+ )
+
+ class SteelBlueTheme(Soft):
+     def __init__(
+         self,
+         *,
+         primary_hue: colors.Color | str = colors.gray,
+         secondary_hue: colors.Color | str = colors.steel_blue,
+         neutral_hue: colors.Color | str = colors.slate,
+         text_size: sizes.Size | str = sizes.text_lg,
+         font: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
+         ),
+         font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
+         ),
+     ):
+         super().__init__(
+             primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue,
+             text_size=text_size, font=font, font_mono=font_mono,
+         )
+         super().set(
+             background_fill_primary="*primary_50",
+             background_fill_primary_dark="*primary_900",
+             body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
+             body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
+             button_primary_text_color="white",
+             button_primary_text_color_hover="white",
+             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
+             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
+             button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
+             button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
+             slider_color="*secondary_500",
+             slider_color_dark="*secondary_600",
+             block_title_text_weight="600",
+             block_border_width="3px",
+             block_shadow="*shadow_drop_lg",
+             button_primary_shadow="*shadow_drop_lg",
+             button_large_padding="11px",
+             color_accent_soft="*primary_100",
+             block_label_background_fill="*primary_200",
+         )
+
+ steel_blue_theme = SteelBlueTheme()
+
+ css = """
+ #main-title h1 {
+     font-size: 2.3em !important;
+ }
+ #output-title h2 {
+     font-size: 2.1em !important;
+ }
+ """
+
+ MODEL_IDS = {
+     "Florence-2-base": "florence-community/Florence-2-base",
+     "Florence-2-base-ft": "florence-community/Florence-2-base-ft",
+     "Florence-2-large": "florence-community/Florence-2-large",
+     "Florence-2-large-ft": "florence-community/Florence-2-large-ft",
+ }
+
+ models = {}
+ processors = {}
+
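+ # Load all four checkpoints once at startup and keep them in memory, so
+ # switching models in the UI never reloads weights during a request.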
+ print("Loading Florence-2 models... This may take a while.")
+ for name, repo_id in MODEL_IDS.items():
+     print(f"Loading {name}...")
+     model = Florence2ForConditionalGeneration.from_pretrained(
+         repo_id,
+         dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+     processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+     models[name] = model
+     processors[name] = processor
+     print(f"✅ Finished loading {name}.")
+
+ print("\n🎉 All models loaded successfully!")
+
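+ # On ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration of each
+ # call; duration=30 caps the expected per-request runtime at 30 seconds.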
+ @spaces.GPU(duration=30)
+ def run_florence2_inference(model_name: str, image: Image.Image, task_prompt: str,
+                             max_new_tokens: int = 1024, num_beams: int = 3):
      """
+     Runs inference using the selected Florence-2 model.
      """
      if image is None:
+         return "Please upload an image to get started."

+     model = models[model_name]
+     processor = processors[model_name]
+
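+     # A single processor call tokenizes the task prompt and preprocesses the image;
+     # the resulting tensors are moved to the model's device and cast to bfloat16.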
+     inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
+
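+     # Decode deterministically: beam search with sampling disabled, since
+     # Florence-2 outputs are structured strings meant to be parsed, not sampled.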
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=max_new_tokens,
+         num_beams=num_beams,
+         do_sample=False,
      )

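+     # Keep special tokens in the decode: Florence-2 encodes the task and any
+     # region/location information as special tokens that post_process_generation()
+     # needs in order to parse the structured output.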
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

+     image_size = image.size
+     parsed_answer = processor.post_process_generation(
+         generated_text, task=task_prompt, image_size=image_size
+     )

+     return parsed_answer

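+ # Florence-2 is steered with task tokens rather than free-form prompts;
+ # each entry below selects a different capability.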
+ florence_tasks = [
+     "<OD>", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
+     "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>",
+ ]

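+ # Fetch a sample image at startup so the demo opens with a populated input.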
+ url = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/venice.jpg?download=true"
+ example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
+     gr.Markdown("# **Florence-2 Vision Models**", elem_id="main-title")
+     gr.Markdown("Select a model, upload an image, choose a task, and click Submit to see the results.")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=290)
+             task_prompt = gr.Dropdown(
+                 label="Select Task",
+                 choices=florence_tasks,
+                 value="<MORE_DETAILED_CAPTION>"
+             )
+             model_choice = gr.Radio(
+                 choices=list(MODEL_IDS.keys()),
+                 label="Select Model",
+                 value="Florence-2-base"
+             )
+             image_submit = gr.Button("Submit", variant="primary")
+
+             with gr.Accordion("Advanced options", open=False):
+                 max_new_tokens = gr.Slider(
+                     label="Max New Tokens", minimum=128, maximum=2048, step=128, value=1024
+                 )
+                 num_beams = gr.Slider(
+                     label="Number of Beams", minimum=1, maximum=10, step=1, value=3
+                 )
+
+         with gr.Column(scale=3):
+             gr.Markdown("## Output", elem_id="output-title")
+             parsed_output = gr.JSON(label="Parsed Answer")
+
+     image_submit.click(
+         fn=run_florence2_inference,
+         inputs=[model_choice, image_upload, task_prompt, max_new_tokens, num_beams],
+         outputs=[parsed_output]
+     )

  if __name__ == "__main__":
+     demo.queue().launch(debug=True, mcp_server=True, ssr_mode=False, show_error=True)