ravi-vc committed on
Commit
5e9578a
·
verified ·
1 Parent(s): 6af56a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -40
app.py CHANGED
@@ -1,47 +1,227 @@
1
- import os
2
  import gradio as gr
 
3
  from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
4
  from PIL import Image
 
 
 
 
5
  import json
6
 
7
- # Fix threading error
8
- os.environ["OMP_NUM_THREADS"] = "1"
9
-
10
- # Load DePlot
11
- model_id = "google/deplot"
12
- processor = Pix2StructProcessor.from_pretrained(model_id)
13
- model = Pix2StructForConditionalGeneration.from_pretrained(model_id)
14
-
15
- def extract_chart(image):
16
- # Step 1: Run DePlot
17
- inputs = processor(images=image, text="Generate table from chart.", return_tensors="pt")
18
- predictions = model.generate(**inputs, max_new_tokens=512)
19
- table = processor.decode(predictions[0], skip_special_tokens=True)
20
-
21
- # Step 2: Dummy structured JSON
22
- structured_json = {
23
- "metadata": {"title": "Demo Chart", "chart_type": "bar", "confidence": 0.5},
24
- "axes": {"x_axis": {"label": "X", "ticks": []}, "y_axis": {"label": "Y", "ticks": []}},
25
- "series": [],
26
- "legend": {"entries": []}
27
- }
28
-
29
- # Step 3: Merge outputs
30
- merged_output = {
31
- "structured_json": structured_json,
32
- "deplot_table": table,
33
- "fusion_notes": "Fusion layer not implemented yet, just showing both outputs."
34
- }
35
-
36
- return json.dumps(merged_output, indent=2)
37
-
38
- demo = gr.Interface(
39
- fn=extract_chart,
40
- inputs=gr.Image(type="pil"),
41
- outputs="json",
42
- title="Chart-to-JSON Extractor (Prototype)",
43
- description="Uploads a chart, extracts structured JSON (dummy) and DePlot table side-by-side."
44
- )
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if __name__ == "__main__":
47
- demo.launch()
 
 
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
4
  from PIL import Image
5
+ import requests
6
+ import io
7
+ import re
8
+ import pandas as pd
9
  import json
10
 
11
+ # Load the DePlot model and processor
12
+ MODEL_NAME = "google/deplot"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
def load_model():
    """Fetch the DePlot processor and model named by ``MODEL_NAME``.

    Returns:
        A ``(processor, model)`` tuple. If anything goes wrong while loading,
        the error is printed and ``(None, None)`` is returned instead so the
        app can still start and report the problem to the user.
    """
    try:
        # Tuple elements are evaluated left to right: processor first, then model.
        loaded = (
            Pix2StructProcessor.from_pretrained(MODEL_NAME),
            Pix2StructForConditionalGeneration.from_pretrained(MODEL_NAME),
        )
    except Exception as err:
        print(f"Error loading model: {err}")
        return None, None
    return loaded


# Loaded once at import time and shared by every request.
processor, model = load_model()
25
+
26
def extract_chart_data(image, question="Generate underlying data table of the figure below:"):
    """Run DePlot on a chart image and return its raw and parsed output.

    Args:
        image: A PIL image, a path string, or an object exposing a ``name``
            attribute that holds a path (e.g. an uploaded-file wrapper).
        question: Prompt text passed to the processor alongside the image.

    Returns:
        A ``(text, table)`` pair: the decoded model output and whatever
        ``parse_extracted_data`` made of it. When the model is not loaded or
        processing raises, ``text`` is an error message and ``table`` is None.
    """
    model_unavailable = processor is None or model is None
    if model_unavailable:
        return "Error: Model not loaded properly", None

    try:
        # Normalise the supported input kinds down to a PIL image.
        if isinstance(image, str):
            image = Image.open(image)
        elif hasattr(image, 'name'):
            image = Image.open(image.name)

        # The processor is always fed an RGB image.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

        encoded = processor(images=rgb_image, text=question, return_tensors="pt")
        generated = model.generate(**encoded, max_new_tokens=512)

        # Only the first generated sequence is decoded.
        decoded = processor.decode(generated[0], skip_special_tokens=True)
        return decoded, parse_extracted_data(decoded)
    except Exception as err:
        return f"Error processing image: {str(err)}", None
67
+
68
def parse_extracted_data(text):
    """Turn delimiter-separated model output into a pandas DataFrame.

    Each line is split on the first delimiter kind it contains, tried in the
    order pipe, tab, comma (comma lines starting with "The" are treated as
    prose and skipped). The first parsed row becomes the header.

    Compared with a naive line/cell split this parser also:
      * treats the literal token ``<0x0A>`` as a row break (DePlot emits it
        in place of a newline);
      * skips a leading ``TITLE | ...`` metadata row so it is not mistaken
        for the table header;
      * keeps empty interior cells so later columns do not shift left;
      * pads ragged rows to a common width instead of failing.

    Args:
        text: Raw text produced by the model.

    Returns:
        A DataFrame of strings, or None when no tabular rows are found or
        parsing fails.
    """
    try:
        if not text:
            return None

        # DePlot writes "<0x0A>" where a row break belongs.
        normalized = text.replace("<0x0A>", "\n")

        rows = []
        for raw_line in normalized.splitlines():
            stripped = raw_line.strip()
            if '|' in stripped:
                # Markdown-style rows are wrapped in border pipes; remove the
                # borders so they do not create empty edge cells.
                if len(stripped) > 1 and stripped[0] == '|' and stripped[-1] == '|':
                    stripped = stripped[1:-1]
                cells = stripped.split('|')
            elif '\t' in raw_line:
                # Split the unstripped line so a leading empty cell survives.
                cells = raw_line.split('\t')
            elif ',' in stripped and not stripped.startswith('The'):
                cells = stripped.split(',')
            else:
                continue

            cells = [cell.strip() for cell in cells]
            # Trailing empties come from a dangling delimiter; padding below
            # restores them when the row really is short.
            while cells and not cells[-1]:
                cells.pop()
            if not cells:
                continue
            # A leading "TITLE | ..." row is chart metadata, not a header.
            if not rows and cells[0].upper() == "TITLE":
                continue
            rows.append(cells)

        if not rows:
            return None
        if len(rows) == 1:
            return pd.DataFrame(rows)

        # Pad every row (header included) to the widest row.
        width = max(len(row) for row in rows)
        padded = [row + [""] * (width - len(row)) for row in rows]
        header = [
            name or f"column_{position}"
            for position, name in enumerate(padded[0], start=1)
        ]
        return pd.DataFrame(padded[1:], columns=header)

    except Exception as e:
        print(f"Error parsing data: {e}")
        return None
106
+
107
def process_chart(image, custom_question):
    """Extract chart data and shape it for the three UI outputs.

    Args:
        image: The uploaded chart image, or None when nothing was uploaded.
        custom_question: Optional prompt from the user. None, empty and
            whitespace-only values all fall back to the default prompt.

    Returns:
        A ``(raw_output, table_html, csv_output)`` tuple. ``table_html`` is an
        HTML table, or a plain message when nothing could be parsed;
        ``csv_output`` is CSV text or None. With no image the tuple is
        ``("Please upload an image", None, None)``.
    """
    if image is None:
        return "Please upload an image", None, None

    default_question = "Generate underlying data table of the figure below:"
    # `or ""` guards against a None textbox value, which has no .strip().
    question = (custom_question or "").strip() or default_question

    raw_output, structured_data = extract_chart_data(image, question)

    if structured_data is not None and not structured_data.empty:
        # HTML for on-screen display, CSV text for the download.
        table_html = structured_data.to_html(index=False, classes='table table-striped')
        csv_output = structured_data.to_csv(index=False)
    else:
        table_html = "Could not parse data into structured format"
        csv_output = None

    return raw_output, table_html, csv_output
131
+
132
+ # Create Gradio interface
133
def create_interface():
    """Build the Gradio Blocks UI for the chart data extractor.

    The layout has an input column (image, optional question, button) and an
    output column (raw text tab, parsed table tab, CSV download). Clicking
    the button runs ``process_chart`` and, when CSV text is produced, writes
    it to a temporary file so the download component receives a real path.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    # Local imports: only this function needs them, so the file's import
    # block does not have to change.
    import os
    import tempfile

    with gr.Blocks(title="DePlot Chart Data Extractor", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 📊 DePlot Chart Data Extractor

        Upload any chart or plot image to extract the underlying data, even without visible data labels!
        This tool uses Google's DePlot model to understand and extract data from various types of charts.

        **Supported chart types:** Bar charts, line graphs, scatter plots, pie charts, and more!
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                image_input = gr.Image(
                    type="pil",
                    label="Upload Chart Image",
                    height=400
                )

                custom_question = gr.Textbox(
                    label="Custom Question (optional)",
                    placeholder="e.g., 'What are the values for each category?' or leave empty for default",
                    lines=2
                )

                extract_btn = gr.Button("Extract Data", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Output section
                with gr.Tab("Raw Output"):
                    raw_output = gr.Textbox(
                        label="Extracted Text",
                        lines=10,
                        show_copy_button=True
                    )

                with gr.Tab("Structured Data"):
                    structured_output = gr.HTML(
                        label="Parsed Data Table"
                    )

                # Download section; hidden until a CSV is available.
                csv_download = gr.File(
                    label="Download CSV",
                    visible=False
                )

        example_images = [
            ["examples/bar_chart.png", "Extract the data from this bar chart"],
            ["examples/line_graph.png", "What are the trend values over time?"],
            ["examples/pie_chart.png", "Give me the percentage breakdown"]
        ]

        # Only offer examples whose image files are actually present, so the
        # heading is never shown with nothing underneath it.
        available_examples = [
            example for example in example_images if os.path.exists(example[0])
        ]
        if available_examples:
            gr.Markdown("### 📋 Try these examples:")
            gr.Examples(
                examples=available_examples,
                inputs=[image_input, custom_question]
            )

        # Event handlers
        def process_and_download(image, question):
            """Run the extraction and expose the CSV as a downloadable file."""
            raw, table, csv_data = process_chart(image, question)

            if not csv_data:
                return raw, table, gr.File(visible=False)

            # gr.File expects a path, not file contents, so persist the CSV.
            # delete=False keeps the file on disk after the `with` block so
            # it can still be served.
            with tempfile.NamedTemporaryFile(
                mode="w",
                suffix=".csv",
                prefix="deplot_",
                delete=False,
                encoding="utf-8",
                newline="",
            ) as csv_file:
                csv_file.write(csv_data)

            return raw, table, gr.File(value=csv_file.name, visible=True)

        extract_btn.click(
            fn=process_and_download,
            inputs=[image_input, custom_question],
            outputs=[raw_output, structured_output, csv_download]
        )

        gr.Markdown("""
        ### 💡 Tips for better results:
        - Use clear, high-resolution images
        - Ensure chart elements are visible and not too small
        - Try different custom questions for specific data you need
        - Works best with standard chart types (bar, line, scatter, pie)

        ### 🔧 Model Information:
        This space uses Google's DePlot model, which is specifically trained to extract data from plots and figures.
        """)

    return demo
223
+
224
+ # Create and launch the interface
225
# Build the UI and start serving only when run as a script.
if __name__ == "__main__":
    app = create_interface()
    app.launch()