Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
-
import
|
2 |
-
import pandas as pd
|
3 |
-
import numpy as np
|
4 |
import io
|
5 |
import base64
|
|
|
6 |
from typing import Optional, Tuple
|
7 |
-
|
|
|
|
|
|
|
8 |
import plotly.graph_objects as go
|
9 |
from plotly.subplots import make_subplots
|
10 |
-
|
11 |
warnings.filterwarnings("ignore")
|
12 |
|
13 |
# Import Mostly AI SDK
|
@@ -18,96 +20,142 @@ except ImportError:
|
|
18 |
MOSTLY_AI_AVAILABLE = False
|
19 |
print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
|
20 |
|
|
|
21 |
class SyntheticDataGenerator:
|
22 |
def __init__(self):
|
23 |
self.mostly = None
|
24 |
self.generator = None
|
25 |
self.original_data = None
|
26 |
-
|
27 |
-
def initialize_mostly_ai(self):
|
28 |
"""Initialize Mostly AI SDK"""
|
29 |
if not MOSTLY_AI_AVAILABLE:
|
30 |
return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
|
31 |
-
|
32 |
try:
|
33 |
self.mostly = MostlyAI(local=True, local_port=8080)
|
34 |
return True, "Mostly AI SDK initialized successfully."
|
35 |
except Exception as e:
|
36 |
return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
"""Train the synthetic data generator"""
|
41 |
if not self.mostly:
|
42 |
return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
|
43 |
-
|
44 |
try:
|
45 |
self.original_data = data
|
46 |
-
train_config = {
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
self.generator = self.mostly.train(
|
63 |
-
config = train_config
|
64 |
-
)
|
65 |
return True, f"Training completed successfully. Model name: {name}"
|
66 |
except Exception as e:
|
67 |
return False, f"Training failed with error: {str(e)}"
|
68 |
-
|
69 |
-
def generate_synthetic_data(self, size: int) -> Tuple[pd.DataFrame, str]:
|
70 |
"""Generate synthetic data"""
|
71 |
if not self.generator:
|
72 |
return None, "No trained generator available. Please train a model first."
|
73 |
-
|
74 |
try:
|
75 |
synthetic_data = self.mostly.generate(self.generator, size=size)
|
76 |
df = synthetic_data.data()
|
77 |
return df, f"Synthetic data generated successfully. {len(df)} records created."
|
78 |
except Exception as e:
|
79 |
return None, f"Synthetic data generation failed with error: {str(e)}"
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
83 |
if not self.generator:
|
84 |
return "No trained generator available. Please train a model first."
|
85 |
-
|
86 |
try:
|
87 |
-
|
88 |
-
return
|
89 |
except Exception as e:
|
90 |
return f"Failed to generate quality report: {str(e)}"
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
|
93 |
"""Estimate memory usage for the dataset"""
|
94 |
if df is None or df.empty:
|
95 |
return "No data available to analyze."
|
96 |
-
|
97 |
-
# Calculate approximate memory usage
|
98 |
memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
|
99 |
rows, cols = len(df), len(df.columns)
|
100 |
-
|
101 |
-
# Estimate training memory (roughly 3-5x the data size)
|
102 |
estimated_training_mb = memory_mb * 4
|
103 |
-
|
104 |
if memory_mb < 100:
|
105 |
status = "Good"
|
106 |
elif memory_mb < 500:
|
107 |
status = "Large"
|
108 |
else:
|
109 |
status = "Very Large"
|
110 |
-
|
111 |
return f"""
|
112 |
Memory Usage Estimate:
|
113 |
- Data size: {memory_mb:.1f} MB
|
@@ -116,196 +164,180 @@ Memory Usage Estimate:
|
|
116 |
- Rows: {rows:,} | Columns: {cols}
|
117 |
""".strip()
|
118 |
|
|
|
119 |
# Initialize the generator
|
120 |
generator = SyntheticDataGenerator()
|
121 |
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
def initialize_sdk() -> Tuple[str, str]:
|
124 |
-
"""Initialize the Mostly AI SDK"""
|
125 |
-
success, message = generator.initialize_mostly_ai()
|
126 |
-
status = "Success" if success else "Error"
|
127 |
-
return status, message
|
128 |
|
129 |
-
def train_model(
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
if data is None or data.empty:
|
132 |
-
return "Error
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
return
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
if generator.generator is None:
|
141 |
-
return None, "Error: No trained model available. Please train a model first."
|
142 |
-
|
143 |
synthetic_df, message = generator.generate_synthetic_data(size)
|
144 |
-
if synthetic_df is not None
|
145 |
-
status = "Success"
|
146 |
-
else:
|
147 |
-
status = "Error"
|
148 |
-
|
149 |
return synthetic_df, f"{status}: {message}"
|
150 |
|
151 |
-
def get_quality_report() -> str:
|
152 |
-
"""Get quality report"""
|
153 |
-
return generator.get_quality_report()
|
154 |
|
155 |
-
def
|
156 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
if original_df is None or synthetic_df is None:
|
158 |
return None
|
159 |
-
|
160 |
-
# Select numeric columns for comparison
|
161 |
numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
|
162 |
-
|
163 |
if not numeric_cols:
|
164 |
return None
|
165 |
-
|
166 |
-
# Create subplots
|
167 |
n_cols = min(3, len(numeric_cols))
|
168 |
n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
|
169 |
-
|
170 |
-
fig = make_subplots(
|
171 |
-
|
172 |
-
|
173 |
-
subplot_titles=numeric_cols[:n_rows*n_cols]
|
174 |
-
)
|
175 |
-
|
176 |
-
for i, col in enumerate(numeric_cols[:n_rows*n_cols]):
|
177 |
row = i // n_cols + 1
|
178 |
col_idx = i % n_cols + 1
|
179 |
-
|
180 |
-
# Add original data histogram
|
181 |
fig.add_trace(
|
182 |
-
go.Histogram(
|
183 |
-
|
184 |
-
|
185 |
-
opacity=0.7,
|
186 |
-
nbinsx=20
|
187 |
-
),
|
188 |
-
row=row, col=col_idx
|
189 |
)
|
190 |
-
|
191 |
-
# Add synthetic data histogram
|
192 |
fig.add_trace(
|
193 |
-
go.Histogram(
|
194 |
-
|
195 |
-
|
196 |
-
opacity=0.7,
|
197 |
-
nbinsx=20
|
198 |
-
),
|
199 |
-
row=row, col=col_idx
|
200 |
)
|
201 |
-
|
202 |
-
fig.update_layout(
|
203 |
-
title="Original vs Synthetic Data Comparison",
|
204 |
-
height=300 * n_rows,
|
205 |
-
showlegend=True
|
206 |
-
)
|
207 |
-
|
208 |
return fig
|
209 |
|
210 |
-
|
211 |
-
|
212 |
if df is None or df.empty:
|
213 |
return None
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
217 |
|
218 |
-
#
|
219 |
def create_interface():
|
220 |
with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
|
221 |
-
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
224 |
|
225 |
# README
|
226 |
-
gr.Markdown(
|
|
|
227 |
# Synthetic Data SDK by MOSTLY AI Demo Space
|
228 |
|
229 |
[Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)
|
230 |
|
231 |
A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
|
232 |
-
"""
|
233 |
-
|
|
|
234 |
with gr.Tab("Quick Start"):
|
235 |
gr.Markdown("### Initialize the SDK and upload your data")
|
236 |
-
|
237 |
with gr.Row():
|
238 |
with gr.Column():
|
239 |
init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
|
240 |
init_status = gr.Textbox(label="Initialization Status", interactive=False)
|
241 |
-
|
242 |
with gr.Column():
|
243 |
-
gr.Markdown(
|
|
|
244 |
**Next Steps:**
|
245 |
1. Initialize the SDK (click button above)
|
246 |
2. Go to "Upload Data and Train Model" tab to upload your CSV file
|
247 |
3. Train a model on your data
|
248 |
4. Generate synthetic data
|
249 |
-
"""
|
250 |
-
|
|
|
251 |
with gr.Tab("Upload Data and Train Model"):
|
252 |
gr.Markdown("### Upload your CSV file to generate synthetic data")
|
253 |
-
|
254 |
-
|
255 |
**File Requirements:**
|
256 |
- Format: CSV with header row
|
257 |
- Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
|
258 |
-
"""
|
259 |
-
|
260 |
-
file_upload = gr.File(
|
261 |
-
label="Upload CSV File",
|
262 |
-
file_types=[".csv"],
|
263 |
-
file_count="single"
|
264 |
)
|
265 |
-
|
|
|
266 |
uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
|
267 |
-
|
268 |
memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
|
269 |
-
|
270 |
with gr.Row():
|
271 |
with gr.Column():
|
272 |
model_name = gr.Textbox(
|
273 |
-
value="My Synthetic Model",
|
274 |
-
label="Model Name",
|
275 |
-
placeholder="Enter a name for your model"
|
276 |
)
|
277 |
epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
|
278 |
max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
|
279 |
batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
|
280 |
value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
|
281 |
train_btn = gr.Button("Train Model", variant="primary")
|
282 |
-
|
283 |
with gr.Column():
|
284 |
train_status = gr.Textbox(label="Training Status", interactive=False)
|
285 |
-
quality_report = gr.Textbox(label="Quality Report", lines=
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
289 |
with gr.Tab("Generate Data"):
|
290 |
gr.Markdown("### Generate synthetic data from your trained model")
|
291 |
-
|
292 |
with gr.Row():
|
293 |
with gr.Column():
|
294 |
gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
|
295 |
generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
|
296 |
-
|
297 |
with gr.Column():
|
298 |
gen_status = gr.Textbox(label="Generation Status", interactive=False)
|
299 |
-
|
300 |
synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
|
301 |
-
|
302 |
with gr.Row():
|
303 |
-
download_btn = gr.DownloadButton("Download CSV", variant="secondary")
|
304 |
comparison_plot = gr.Plot(label="Data Comparison")
|
305 |
|
306 |
-
# README
|
307 |
-
gr.Markdown(
|
308 |
-
|
309 |
**Modes of operation:**
|
310 |
- **LOCAL mode** trains and generates synthetic data on your own compute resources.
|
311 |
- **CLIENT mode** connects to a remote MOSTLY AI platform for training and generation.
|
@@ -325,75 +357,52 @@ def create_interface():
|
|
325 |
The open source Synthetic Data SDK by MOSTLY AI powers the MOSTLY AI Platform and MOSTLY AI Assistant.
|
326 |
|
327 |
Sign up for free and try the [MOSTLY AI Platform](https://app.mostly.ai/) today!
|
328 |
-
"""
|
329 |
-
|
330 |
-
# Event handlers
|
331 |
-
init_btn.click(
|
332 |
-
initialize_sdk,
|
333 |
-
outputs=[init_status, init_status]
|
334 |
)
|
335 |
-
|
|
|
|
|
|
|
336 |
train_btn.click(
|
337 |
train_model,
|
338 |
inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
|
339 |
-
outputs=[train_status,
|
340 |
)
|
341 |
-
|
|
|
342 |
get_report_btn.click(
|
343 |
-
|
344 |
-
outputs=[quality_report]
|
345 |
-
)
|
346 |
-
|
347 |
-
generate_btn.click(
|
348 |
-
generate_data,
|
349 |
-
inputs=[gen_size],
|
350 |
-
outputs=[synthetic_data, gen_status]
|
351 |
-
)
|
352 |
-
|
353 |
-
# Update download button when synthetic data changes
|
354 |
-
synthetic_data.change(
|
355 |
-
download_csv,
|
356 |
-
inputs=[synthetic_data],
|
357 |
-
outputs=[download_btn]
|
358 |
)
|
359 |
-
|
360 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
synthetic_data.change(
|
362 |
-
create_comparison_plot,
|
363 |
-
inputs=[uploaded_data, synthetic_data],
|
364 |
-
outputs=[comparison_plot]
|
365 |
)
|
366 |
-
|
367 |
# Handle file upload with size and column limits
|
368 |
def process_uploaded_file(file):
|
369 |
if file is None:
|
370 |
return None, "No file uploaded.", gr.update(visible=False)
|
371 |
-
|
372 |
try:
|
373 |
-
# Read the CSV file
|
374 |
df = pd.read_csv(file.name)
|
375 |
-
|
376 |
success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
return df, success_msg, gr.update(value=memory_info, visible=True)
|
381 |
-
|
382 |
except Exception as e:
|
383 |
return None, f"Error reading file: {str(e)}", gr.update(visible=False)
|
384 |
-
|
385 |
-
file_upload.change(
|
386 |
-
|
387 |
-
inputs=[file_upload],
|
388 |
-
outputs=[uploaded_data, train_status, memory_info]
|
389 |
-
)
|
390 |
-
|
391 |
return demo
|
392 |
|
|
|
393 |
if __name__ == "__main__":
|
394 |
demo = create_interface()
|
395 |
-
demo.launch(
|
396 |
-
server_name="0.0.0.0",
|
397 |
-
server_port=7860,
|
398 |
-
share=True
|
399 |
-
)
|
|
|
1 |
+
import os
|
|
|
|
|
2 |
import io
|
3 |
import base64
|
4 |
+
import warnings
|
5 |
from typing import Optional, Tuple
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import pandas as pd
|
9 |
+
import numpy as np
|
10 |
import plotly.graph_objects as go
|
11 |
from plotly.subplots import make_subplots
|
12 |
+
|
13 |
warnings.filterwarnings("ignore")
|
14 |
|
15 |
# Import Mostly AI SDK
|
|
|
20 |
MOSTLY_AI_AVAILABLE = False
|
21 |
print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
|
22 |
|
23 |
+
|
24 |
class SyntheticDataGenerator:
|
25 |
def __init__(self):
|
26 |
self.mostly = None
|
27 |
self.generator = None
|
28 |
self.original_data = None
|
29 |
+
|
30 |
+
def initialize_mostly_ai(self) -> Tuple[bool, str]:
    """Create the MostlyAI client in local mode and store it on ``self.mostly``.

    Returns:
        ``(ok, message)`` where ``ok`` is False when the SDK import flag is
        unset or when constructing the client raised.
    """
    if MOSTLY_AI_AVAILABLE:
        try:
            self.mostly = MostlyAI(local=True, local_port=8080)
        except Exception as e:
            # Report the failure as text so the caller can show it to the user.
            return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
        return True, "Mostly AI SDK initialized successfully."
    return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
|
39 |
+
|
40 |
+
def train_generator(
    self,
    data: pd.DataFrame,
    name: str,
    epochs: int = 10,
    max_training_time: int = 60,
    batch_size: int = 32,
    value_protection: bool = True,
) -> Tuple[bool, str]:
    """Train a single-table generator on ``data`` via ``self.mostly.train``.

    The data is remembered on ``self.original_data`` and the object returned
    by the SDK is stored on ``self.generator``.

    Returns:
        ``(ok, message)``; ``ok`` is False when the SDK client is missing or
        when anything inside the training call raised.
    """
    if not self.mostly:
        return False, "Mostly AI SDK not initialized. Please initialize the SDK first."

    try:
        self.original_data = data

        # Assemble the config from the inside out.
        model_settings = {
            "max_epochs": epochs,
            "max_training_time": max_training_time,
            "value_protection": value_protection,
            "batch_size": batch_size,
        }
        table_spec = {
            "name": name,
            "data": data,
            "tabular_model_configuration": model_settings,
        }

        self.generator = self.mostly.train(config={"tables": [table_spec]})
        return True, f"Training completed successfully. Model name: {name}"
    except Exception as e:
        return False, f"Training failed with error: {str(e)}"
|
73 |
+
|
74 |
+
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Ask the SDK for ``size`` synthetic records from the trained generator.

    Returns:
        ``(data, message)``; ``data`` is None when no generator has been
        trained or when the SDK call raised.
    """
    if self.generator:
        try:
            job = self.mostly.generate(self.generator, size=size)
            df = job.data()
            outcome = (df, f"Synthetic data generated successfully. {len(df)} records created.")
        except Exception as e:
            outcome = (None, f"Synthetic data generation failed with error: {str(e)}")
        return outcome
    return None, "No trained generator available. Please train a model first."
|
84 |
+
|
85 |
+
# ---- Report helpers (new) ----
|
86 |
+
def get_quality_report_text(self) -> str:
    """Call ``reports(display=False)`` on the generator and return a status line.

    The value returned by ``reports`` is discarded; only success or failure
    of the call is reported.
    """
    if self.generator:
        try:
            self.generator.reports(display=False)
        except Exception as e:
            return f"Failed to generate quality report: {str(e)}"
        return "Quality report generated. Use the button to download."
    return "No trained generator available. Please train a model first."
|
95 |
+
|
96 |
+
def get_quality_report_file(self) -> Optional[str]:
    """Produce a downloadable quality-report file and return its path.

    What ``self.generator.reports(display=False)`` returns is not known
    here, so several shapes are probed in order:

    1. a string path to an existing ``.zip``;
    2. an object with a path-like attribute pointing at an existing file;
    3. an object with a ``save`` or ``export`` method that writes a ZIP;
    4. fallback: ``str(report)`` written to a text file.

    Generated files go into a fresh temporary directory instead of a
    hard-coded ``/mnt/data`` folder, which does not exist on most hosts and
    made every write path fail silently.

    Returns:
        The file path, or None when no generator is trained or the report
        could not be produced at all.
    """
    # Function-scope import: the file's import block does not include tempfile.
    import tempfile

    if not self.generator:
        return None
    try:
        rep = self.generator.reports(display=False)

        # 1) A string path to a .zip was returned directly.
        if isinstance(rep, str) and rep.endswith(".zip") and os.path.exists(rep):
            return rep

        # 2) The object exposes a path-like attribute (str or os.PathLike).
        for attr in ("archive_path", "zip_path", "path", "file_path"):
            candidate = getattr(rep, attr, None)
            if isinstance(candidate, (str, os.PathLike)) and os.path.exists(candidate):
                return os.fspath(candidate)

        # A new directory per call, so a stale file from an earlier report
        # can never be mistaken for the output of this one.
        out_dir = tempfile.mkdtemp(prefix="quality_report_")

        # 3) The object can save/export itself.
        target_zip = os.path.join(out_dir, "quality_report.zip")
        for method_name in ("save", "export"):
            writer = getattr(rep, method_name, None)
            if writer is None:
                continue
            try:
                writer(target_zip)
            except Exception:
                # Best effort: the signature is a guess, so try the next option.
                continue
            if os.path.exists(target_zip):
                return target_zip

        # 4) Fallback: write the string representation.
        target_txt = os.path.join(out_dir, "quality_report.txt")
        with open(target_txt, "w", encoding="utf-8") as f:
            f.write(str(rep))
        return target_txt

    except Exception:
        # The caller treats None as "no file available" and hides the button.
        return None
|
142 |
+
|
143 |
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
|
144 |
"""Estimate memory usage for the dataset"""
|
145 |
if df is None or df.empty:
|
146 |
return "No data available to analyze."
|
147 |
+
|
|
|
148 |
memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
|
149 |
rows, cols = len(df), len(df.columns)
|
|
|
|
|
150 |
estimated_training_mb = memory_mb * 4
|
151 |
+
|
152 |
if memory_mb < 100:
|
153 |
status = "Good"
|
154 |
elif memory_mb < 500:
|
155 |
status = "Large"
|
156 |
else:
|
157 |
status = "Very Large"
|
158 |
+
|
159 |
return f"""
|
160 |
Memory Usage Estimate:
|
161 |
- Data size: {memory_mb:.1f} MB
|
|
|
164 |
- Rows: {rows:,} | Columns: {cols}
|
165 |
""".strip()
|
166 |
|
167 |
+
|
168 |
# Initialize the generator
|
169 |
# Single module-level instance; the wrapper functions in this file read it as a global.
generator = SyntheticDataGenerator()
|
170 |
|
171 |
+
# ---- Wrapper functions for Gradio ----
|
172 |
+
def initialize_sdk() -> str:
    """Initialize the SDK on the shared generator and return a prefixed status line."""
    succeeded, detail = generator.initialize_mostly_ai()
    prefix = "Success: " if succeeded else "Error: "
    return prefix + detail
|
175 |
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
+
def train_model(
    data: pd.DataFrame,
    model_name: str,
    epochs: int,
    max_training_time: int,
    batch_size: int,
    value_protection: bool,
) -> str:
    """Train on ``data`` through the shared generator and return a status line.

    Returns an error line without calling the generator when ``data`` is
    None or empty.
    """
    has_rows = data is not None and not data.empty
    if not has_rows:
        return "Error: No data provided. Please upload or create sample data first."

    succeeded, detail = generator.train_generator(
        data, model_name, epochs, max_training_time, batch_size, value_protection
    )
    if succeeded:
        return "Success: " + detail
    return "Error: " + detail
|
191 |
+
|
192 |
+
|
193 |
+
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Generate ``size`` records via the shared generator.

    Returns:
        ``(frame_or_None, "<Success|Error>: <message>")``.
    """
    synthetic_df, message = generator.generate_synthetic_data(size)
    if synthetic_df is None:
        status = "Error"
    else:
        status = "Success"
    return synthetic_df, f"{status}: {message}"
|
197 |
|
|
|
|
|
|
|
198 |
|
199 |
+
def get_quality_report_and_file():
    """Return ``(status_text, download_button_update)`` for the report button.

    The update carries the report file path and makes the button visible
    when a file is available; otherwise it keeps the button hidden.
    """
    status = generator.get_quality_report_text()
    path = generator.get_quality_report_file()
    # Without a file there is nothing to download, so the button stays hidden.
    button_state = gr.update(value=path, visible=True) if path else gr.update(visible=False)
    return status, button_state
|
211 |
+
|
212 |
+
|
213 |
+
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> Optional[go.Figure]:
    """Overlay histograms of original vs. synthetic values, one subplot per column.

    Only columns that are numeric in ``original_df`` and also present in
    ``synthetic_df`` are plotted; previously a column missing from the
    synthetic frame raised ``KeyError``. Subplots are laid out in a grid of
    at most three columns.

    Returns:
        The figure, or None when either frame is None or no shared numeric
        column exists.
    """
    if original_df is None or synthetic_df is None:
        return None

    numeric_cols = [
        name
        for name in original_df.select_dtypes(include=[np.number]).columns.tolist()
        if name in synthetic_df.columns
    ]
    if not numeric_cols:
        return None

    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols  # ceiling division

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols)

    for i, col in enumerate(numeric_cols):
        # Plotly grid positions are 1-based.
        row = i // n_cols + 1
        col_idx = i % n_cols + 1

        fig.add_trace(
            go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20),
            row=row,
            col=col_idx,
        )
        fig.add_trace(
            go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20),
            row=row,
            col=col_idx,
        )

    fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
    return fig
|
243 |
|
244 |
+
|
245 |
+
def download_csv(df: pd.DataFrame) -> Optional[str]:
    """Write ``df`` to ``synthetic_data.csv`` and return the file path.

    The file goes into a fresh temporary directory. The previous hard-coded
    ``/mnt/data`` target made ``to_csv`` raise on any host where that
    directory does not exist.

    Returns:
        The CSV path, or None when ``df`` is None or empty.
    """
    # Function-scope import: the file's import block does not include tempfile.
    import tempfile

    if df is None or df.empty:
        return None

    out_dir = tempfile.mkdtemp(prefix="synthetic_data_")
    path = os.path.join(out_dir, "synthetic_data.csv")
    df.to_csv(path, index=False)
    return path
|
252 |
+
|
253 |
|
254 |
+
# ---- UI ----
|
255 |
def create_interface():
|
256 |
with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
|
257 |
+
# Header image
|
258 |
+
gr.Image(
|
259 |
+
value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
|
260 |
+
show_label=False,
|
261 |
+
elem_id="header-image",
|
262 |
+
)
|
263 |
|
264 |
# README
|
265 |
+
gr.Markdown(
|
266 |
+
"""
|
267 |
# Synthetic Data SDK by MOSTLY AI Demo Space
|
268 |
|
269 |
[Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)
|
270 |
|
271 |
A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
|
272 |
+
"""
|
273 |
+
)
|
274 |
+
|
275 |
with gr.Tab("Quick Start"):
|
276 |
gr.Markdown("### Initialize the SDK and upload your data")
|
|
|
277 |
with gr.Row():
|
278 |
with gr.Column():
|
279 |
init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
|
280 |
init_status = gr.Textbox(label="Initialization Status", interactive=False)
|
|
|
281 |
with gr.Column():
|
282 |
+
gr.Markdown(
|
283 |
+
"""
|
284 |
**Next Steps:**
|
285 |
1. Initialize the SDK (click button above)
|
286 |
2. Go to "Upload Data and Train Model" tab to upload your CSV file
|
287 |
3. Train a model on your data
|
288 |
4. Generate synthetic data
|
289 |
+
"""
|
290 |
+
)
|
291 |
+
|
292 |
with gr.Tab("Upload Data and Train Model"):
|
293 |
gr.Markdown("### Upload your CSV file to generate synthetic data")
|
294 |
+
gr.Markdown(
|
295 |
+
"""
|
296 |
**File Requirements:**
|
297 |
- Format: CSV with header row
|
298 |
- Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
|
299 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
300 |
)
|
301 |
+
|
302 |
+
file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
|
303 |
uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
|
|
|
304 |
memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
|
305 |
+
|
306 |
with gr.Row():
|
307 |
with gr.Column():
|
308 |
model_name = gr.Textbox(
|
309 |
+
value="My Synthetic Model", label="Model Name", placeholder="Enter a name for your model"
|
|
|
|
|
310 |
)
|
311 |
epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
|
312 |
max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
|
313 |
batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
|
314 |
value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
|
315 |
train_btn = gr.Button("Train Model", variant="primary")
|
|
|
316 |
with gr.Column():
|
317 |
train_status = gr.Textbox(label="Training Status", interactive=False)
|
318 |
+
quality_report = gr.Textbox(label="Quality Report", lines=8, interactive=False)
|
319 |
+
|
320 |
+
with gr.Row():
|
321 |
+
get_report_btn = gr.Button("Get Quality Report", variant="secondary")
|
322 |
+
report_download_btn = gr.DownloadButton("Download Quality Report", visible=False)
|
323 |
+
|
324 |
with gr.Tab("Generate Data"):
|
325 |
gr.Markdown("### Generate synthetic data from your trained model")
|
|
|
326 |
with gr.Row():
|
327 |
with gr.Column():
|
328 |
gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
|
329 |
generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
|
|
|
330 |
with gr.Column():
|
331 |
gen_status = gr.Textbox(label="Generation Status", interactive=False)
|
332 |
+
|
333 |
synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
|
|
|
334 |
with gr.Row():
|
335 |
+
download_btn = gr.DownloadButton("Download CSV", file_name="synthetic_data.csv", variant="secondary")
|
336 |
comparison_plot = gr.Plot(label="Data Comparison")
|
337 |
|
338 |
+
# README footer
|
339 |
+
gr.Markdown(
|
340 |
+
"""
|
341 |
**Modes of operation:**
|
342 |
- **LOCAL mode** trains and generates synthetic data on your own compute resources.
|
343 |
- **CLIENT mode** connects to a remote MOSTLY AI platform for training and generation.
|
|
|
357 |
The open source Synthetic Data SDK by MOSTLY AI powers the MOSTLY AI Platform and MOSTLY AI Assistant.
|
358 |
|
359 |
Sign up for free and try the [MOSTLY AI Platform](https://app.mostly.ai/) today!
|
360 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
361 |
)
|
362 |
+
|
363 |
+
# ---- Event handlers ----
|
364 |
+
init_btn.click(initialize_sdk, outputs=[init_status])
|
365 |
+
|
366 |
train_btn.click(
|
367 |
train_model,
|
368 |
inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
|
369 |
+
outputs=[train_status],
|
370 |
)
|
371 |
+
|
372 |
+
# Build + expose quality report for download
|
373 |
get_report_btn.click(
|
374 |
+
get_quality_report_and_file,
|
375 |
+
outputs=[quality_report, report_download_btn],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
)
|
377 |
+
|
378 |
+
# Generate data
|
379 |
+
generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
|
380 |
+
|
381 |
+
# Update CSV DownloadButton whenever synthetic data changes
|
382 |
+
synthetic_data.change(download_csv, inputs=[synthetic_data], outputs=[download_btn])
|
383 |
+
|
384 |
+
# Build comparison plot when both datasets are available
|
385 |
synthetic_data.change(
|
386 |
+
create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot]
|
|
|
|
|
387 |
)
|
388 |
+
|
389 |
# Handle file upload with size and column limits
|
390 |
def process_uploaded_file(file):
    """Load the uploaded CSV and return ``(frame, status_text, memory_info_update)``.

    On success the memory-info component is updated with the usage estimate
    and made visible; when no file is given or reading fails, the frame is
    None and the component is hidden.
    """
    if file is not None:
        try:
            df = pd.read_csv(file.name)
            success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
            mem_info = generator.estimate_memory_usage(df)
            return df, success_msg, gr.update(value=mem_info, visible=True)
        except Exception as e:
            return None, f"Error reading file: {str(e)}", gr.update(visible=False)
    return None, "No file uploaded.", gr.update(visible=False)
|
400 |
+
|
401 |
+
file_upload.change(process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info])
|
402 |
+
|
|
|
|
|
|
|
|
|
403 |
return demo
|
404 |
|
405 |
+
|
406 |
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 with sharing on.
    demo = create_interface()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    demo.launch(**launch_options)
|
|
|
|
|
|
|
|