ZennyKenny committed on
Commit
8f42518
·
verified ·
1 Parent(s): 3a57c78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -207
app.py CHANGED
@@ -1,13 +1,15 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import numpy as np
4
  import io
5
  import base64
 
6
  from typing import Optional, Tuple
7
- import plotly.express as px
 
 
 
8
  import plotly.graph_objects as go
9
  from plotly.subplots import make_subplots
10
- import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  # Import Mostly AI SDK
@@ -18,96 +20,142 @@ except ImportError:
18
  MOSTLY_AI_AVAILABLE = False
19
  print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
20
 
 
21
  class SyntheticDataGenerator:
22
  def __init__(self):
23
  self.mostly = None
24
  self.generator = None
25
  self.original_data = None
26
-
27
- def initialize_mostly_ai(self):
28
  """Initialize Mostly AI SDK"""
29
  if not MOSTLY_AI_AVAILABLE:
30
  return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
31
-
32
  try:
33
  self.mostly = MostlyAI(local=True, local_port=8080)
34
  return True, "Mostly AI SDK initialized successfully."
35
  except Exception as e:
36
  return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
37
-
38
-
39
- def train_generator(self, data: pd.DataFrame, name: str, epochs: int = 10, max_training_time: int = 60, batch_size: int = 32, value_protection: bool = True) -> Tuple[bool, str]:
 
 
 
 
 
 
 
40
  """Train the synthetic data generator"""
41
  if not self.mostly:
42
  return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
43
-
44
  try:
45
  self.original_data = data
46
- train_config = {'tables':
47
- [
48
- {
49
- 'name': name,
50
- 'data': data,
51
- 'tabular_model_configuration':
52
- {
53
- 'max_epochs': epochs,
54
- 'max_training_time': max_training_time,
55
- 'value_protection': value_protection,
56
- 'batch_size': batch_size
57
- }
58
- }
59
- ]
60
- }
61
-
62
- self.generator = self.mostly.train(
63
- config = train_config
64
- )
65
  return True, f"Training completed successfully. Model name: {name}"
66
  except Exception as e:
67
  return False, f"Training failed with error: {str(e)}"
68
-
69
- def generate_synthetic_data(self, size: int) -> Tuple[pd.DataFrame, str]:
70
  """Generate synthetic data"""
71
  if not self.generator:
72
  return None, "No trained generator available. Please train a model first."
73
-
74
  try:
75
  synthetic_data = self.mostly.generate(self.generator, size=size)
76
  df = synthetic_data.data()
77
  return df, f"Synthetic data generated successfully. {len(df)} records created."
78
  except Exception as e:
79
  return None, f"Synthetic data generation failed with error: {str(e)}"
80
-
81
- def get_quality_report(self) -> str:
82
- """Get quality assurance report"""
 
83
  if not self.generator:
84
  return "No trained generator available. Please train a model first."
85
-
86
  try:
87
- report = self.generator.reports(display=False)
88
- return str(report)
89
  except Exception as e:
90
  return f"Failed to generate quality report: {str(e)}"
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def estimate_memory_usage(self, df: pd.DataFrame) -> str:
93
  """Estimate memory usage for the dataset"""
94
  if df is None or df.empty:
95
  return "No data available to analyze."
96
-
97
- # Calculate approximate memory usage
98
  memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
99
  rows, cols = len(df), len(df.columns)
100
-
101
- # Estimate training memory (roughly 3-5x the data size)
102
  estimated_training_mb = memory_mb * 4
103
-
104
  if memory_mb < 100:
105
  status = "Good"
106
  elif memory_mb < 500:
107
  status = "Large"
108
  else:
109
  status = "Very Large"
110
-
111
  return f"""
112
  Memory Usage Estimate:
113
  - Data size: {memory_mb:.1f} MB
@@ -116,196 +164,180 @@ Memory Usage Estimate:
116
  - Rows: {rows:,} | Columns: {cols}
117
  """.strip()
118
 
 
119
  # Initialize the generator
120
  generator = SyntheticDataGenerator()
121
 
 
 
 
 
122
 
123
- def initialize_sdk() -> Tuple[str, str]:
124
- """Initialize the Mostly AI SDK"""
125
- success, message = generator.initialize_mostly_ai()
126
- status = "Success" if success else "Error"
127
- return status, message
128
 
129
- def train_model(data: pd.DataFrame, model_name: str, epochs: int, max_training_time: int, batch_size: int, value_protection: bool) -> Tuple[str, str]:
130
- """Train the synthetic data generator"""
 
 
 
 
 
 
131
  if data is None or data.empty:
132
- return "Error", "No data provided. Please upload or create sample data first."
133
-
134
- success, message = generator.train_generator(data, model_name, epochs, max_training_time, batch_size, value_protection)
135
- status = "Success" if success else "Error"
136
- return status, message
137
-
138
- def generate_data(size: int) -> Tuple[pd.DataFrame, str]:
139
- """Generate synthetic data"""
140
- if generator.generator is None:
141
- return None, "Error: No trained model available. Please train a model first."
142
-
143
  synthetic_df, message = generator.generate_synthetic_data(size)
144
- if synthetic_df is not None:
145
- status = "Success"
146
- else:
147
- status = "Error"
148
-
149
  return synthetic_df, f"{status}: {message}"
150
 
151
- def get_quality_report() -> str:
152
- """Get quality report"""
153
- return generator.get_quality_report()
154
 
155
- def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> go.Figure:
156
- """Create comparison plots between original and synthetic data"""
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if original_df is None or synthetic_df is None:
158
  return None
159
-
160
- # Select numeric columns for comparison
161
  numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
162
-
163
  if not numeric_cols:
164
  return None
165
-
166
- # Create subplots
167
  n_cols = min(3, len(numeric_cols))
168
  n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
169
-
170
- fig = make_subplots(
171
- rows=n_rows,
172
- cols=n_cols,
173
- subplot_titles=numeric_cols[:n_rows*n_cols]
174
- )
175
-
176
- for i, col in enumerate(numeric_cols[:n_rows*n_cols]):
177
  row = i // n_cols + 1
178
  col_idx = i % n_cols + 1
179
-
180
- # Add original data histogram
181
  fig.add_trace(
182
- go.Histogram(
183
- x=original_df[col],
184
- name=f'Original {col}',
185
- opacity=0.7,
186
- nbinsx=20
187
- ),
188
- row=row, col=col_idx
189
  )
190
-
191
- # Add synthetic data histogram
192
  fig.add_trace(
193
- go.Histogram(
194
- x=synthetic_df[col],
195
- name=f'Synthetic {col}',
196
- opacity=0.7,
197
- nbinsx=20
198
- ),
199
- row=row, col=col_idx
200
  )
201
-
202
- fig.update_layout(
203
- title="Original vs Synthetic Data Comparison",
204
- height=300 * n_rows,
205
- showlegend=True
206
- )
207
-
208
  return fig
209
 
210
- def download_csv(df: pd.DataFrame) -> str:
211
- """Convert DataFrame to CSV for download"""
212
  if df is None or df.empty:
213
  return None
214
-
215
- csv = df.to_csv(index=False)
216
- return csv
 
 
217
 
218
- # Create the Gradio interface
219
  def create_interface():
220
  with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
221
-
222
- # display image above tabs
223
- gr.Image(value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png", show_label=False, elem_id="header-image")
 
 
 
224
 
225
  # README
226
- gr.Markdown("""
 
227
  # Synthetic Data SDK by MOSTLY AI Demo Space
228
 
229
  [Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)
230
 
231
  A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
232
- """)
233
-
 
234
  with gr.Tab("Quick Start"):
235
  gr.Markdown("### Initialize the SDK and upload your data")
236
-
237
  with gr.Row():
238
  with gr.Column():
239
  init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
240
  init_status = gr.Textbox(label="Initialization Status", interactive=False)
241
-
242
  with gr.Column():
243
- gr.Markdown("""
 
244
  **Next Steps:**
245
  1. Initialize the SDK (click button above)
246
  2. Go to "Upload Data and Train Model" tab to upload your CSV file
247
  3. Train a model on your data
248
  4. Generate synthetic data
249
- """)
250
-
 
251
  with gr.Tab("Upload Data and Train Model"):
252
  gr.Markdown("### Upload your CSV file to generate synthetic data")
253
-
254
- gr.Markdown("""
255
  **File Requirements:**
256
  - Format: CSV with header row
257
  - Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
258
- """)
259
-
260
- file_upload = gr.File(
261
- label="Upload CSV File",
262
- file_types=[".csv"],
263
- file_count="single"
264
  )
265
-
 
266
  uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
267
-
268
  memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
269
-
270
  with gr.Row():
271
  with gr.Column():
272
  model_name = gr.Textbox(
273
- value="My Synthetic Model",
274
- label="Model Name",
275
- placeholder="Enter a name for your model"
276
  )
277
  epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
278
  max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
279
  batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
280
  value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
281
  train_btn = gr.Button("Train Model", variant="primary")
282
-
283
  with gr.Column():
284
  train_status = gr.Textbox(label="Training Status", interactive=False)
285
- quality_report = gr.Textbox(label="Quality Report", lines=10, interactive=False)
286
-
287
- get_report_btn = gr.Button("Get Quality Report", variant="secondary")
288
-
 
 
289
  with gr.Tab("Generate Data"):
290
  gr.Markdown("### Generate synthetic data from your trained model")
291
-
292
  with gr.Row():
293
  with gr.Column():
294
  gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
295
  generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
296
-
297
  with gr.Column():
298
  gen_status = gr.Textbox(label="Generation Status", interactive=False)
299
-
300
  synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
301
-
302
  with gr.Row():
303
- download_btn = gr.DownloadButton("Download CSV", variant="secondary")
304
  comparison_plot = gr.Plot(label="Data Comparison")
305
 
306
- # README
307
- gr.Markdown("""
308
-
309
  **Modes of operation:**
310
  - **LOCAL mode** trains and generates synthetic data on your own compute resources.
311
  - **CLIENT mode** connects to a remote MOSTLY AI platform for training and generation.
@@ -325,75 +357,52 @@ def create_interface():
325
  The open source Synthetic Data SDK by MOSTLY AI powers the MOSTLY AI Platform and MOSTLY AI Assistant.
326
 
327
  Sign up for free and try the [MOSTLY AI Platform](https://app.mostly.ai/) today!
328
- """)
329
-
330
- # Event handlers
331
- init_btn.click(
332
- initialize_sdk,
333
- outputs=[init_status, init_status]
334
  )
335
-
 
 
 
336
  train_btn.click(
337
  train_model,
338
  inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
339
- outputs=[train_status, train_status]
340
  )
341
-
 
342
  get_report_btn.click(
343
- get_quality_report,
344
- outputs=[quality_report]
345
- )
346
-
347
- generate_btn.click(
348
- generate_data,
349
- inputs=[gen_size],
350
- outputs=[synthetic_data, gen_status]
351
- )
352
-
353
- # Update download button when synthetic data changes
354
- synthetic_data.change(
355
- download_csv,
356
- inputs=[synthetic_data],
357
- outputs=[download_btn]
358
  )
359
-
360
- # Create comparison plot when both datasets are available
 
 
 
 
 
 
361
  synthetic_data.change(
362
- create_comparison_plot,
363
- inputs=[uploaded_data, synthetic_data],
364
- outputs=[comparison_plot]
365
  )
366
-
367
  # Handle file upload with size and column limits
368
  def process_uploaded_file(file):
369
  if file is None:
370
  return None, "No file uploaded.", gr.update(visible=False)
371
-
372
  try:
373
- # Read the CSV file
374
  df = pd.read_csv(file.name)
375
-
376
  success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
377
-
378
- memory_info = generator.estimate_memory_usage(df)
379
-
380
- return df, success_msg, gr.update(value=memory_info, visible=True)
381
-
382
  except Exception as e:
383
  return None, f"Error reading file: {str(e)}", gr.update(visible=False)
384
-
385
- file_upload.change(
386
- process_uploaded_file,
387
- inputs=[file_upload],
388
- outputs=[uploaded_data, train_status, memory_info]
389
- )
390
-
391
  return demo
392
 
 
393
  if __name__ == "__main__":
394
  demo = create_interface()
395
- demo.launch(
396
- server_name="0.0.0.0",
397
- server_port=7860,
398
- share=True
399
- )
 
1
+ import os
 
 
2
  import io
3
  import base64
4
+ import warnings
5
  from typing import Optional, Tuple
6
+
7
+ import gradio as gr
8
+ import pandas as pd
9
+ import numpy as np
10
  import plotly.graph_objects as go
11
  from plotly.subplots import make_subplots
12
+
13
  warnings.filterwarnings("ignore")
14
 
15
  # Import Mostly AI SDK
 
20
  MOSTLY_AI_AVAILABLE = False
21
  print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
22
 
23
+
24
  class SyntheticDataGenerator:
25
  def __init__(self):
26
  self.mostly = None
27
  self.generator = None
28
  self.original_data = None
29
+
30
def initialize_mostly_ai(self) -> Tuple[bool, str]:
    """Create the local MostlyAI client and keep it on ``self.mostly``.

    Returns:
        ``(ok, message)``. ``ok`` is False when the module-level
        ``MOSTLY_AI_AVAILABLE`` flag is falsy or when constructing the
        client raises; ``message`` is a human-readable status line.
    """
    if MOSTLY_AI_AVAILABLE:
        try:
            # Local mode on port 8080, same settings on every call.
            self.mostly = MostlyAI(local=True, local_port=8080)
        except Exception as exc:
            return False, f"Failed to initialize Mostly AI SDK: {str(exc)}"
        return True, "Mostly AI SDK initialized successfully."

    return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
39
+
40
def train_generator(
    self,
    data: pd.DataFrame,
    name: str,
    epochs: int = 10,
    max_training_time: int = 60,
    batch_size: int = 32,
    value_protection: bool = True,
) -> Tuple[bool, str]:
    """Train a generator on ``data`` through the client held in ``self.mostly``.

    On entry to the training attempt ``data`` is remembered as
    ``self.original_data``; on success the object returned by
    ``self.mostly.train`` is stored as ``self.generator``.

    Args:
        data: Table to learn from; passed through unchanged in the config.
        name: Table/model name placed in the config and echoed in the message.
        epochs: Forwarded as ``max_epochs``.
        max_training_time: Forwarded as ``max_training_time``.
        batch_size: Forwarded as ``batch_size``.
        value_protection: Forwarded as ``value_protection``.

    Returns:
        ``(ok, message)``; any exception raised while training is reported
        in ``message`` instead of propagating.
    """
    if not self.mostly:
        return False, "Mostly AI SDK not initialized. Please initialize the SDK first."

    try:
        self.original_data = data

        # Assemble the single-table training config from the inside out.
        model_settings = {
            "max_epochs": epochs,
            "max_training_time": max_training_time,
            "value_protection": value_protection,
            "batch_size": batch_size,
        }
        table_spec = {
            "name": name,
            "data": data,
            "tabular_model_configuration": model_settings,
        }

        self.generator = self.mostly.train(config={"tables": [table_spec]})
        return True, f"Training completed successfully. Model name: {name}"
    except Exception as exc:
        return False, f"Training failed with error: {str(exc)}"
73
+
74
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Ask the client for ``size`` synthetic records from the trained generator.

    Args:
        size: Forwarded as the ``size`` keyword to ``self.mostly.generate``.

    Returns:
        ``(frame, message)``. ``frame`` is None when no generator has been
        stored on ``self.generator`` or when generation raises; the
        exception text is then carried in ``message``.
    """
    if not self.generator:
        return None, "No trained generator available. Please train a model first."

    try:
        records = self.mostly.generate(self.generator, size=size).data()
        summary = f"Synthetic data generated successfully. {len(records)} records created."
        return records, summary
    except Exception as exc:
        return None, f"Synthetic data generation failed with error: {str(exc)}"
84
+
85
+ # ---- Report helpers (new) ----
86
def get_quality_report_text(self) -> str:
    """Build the quality report and describe the outcome in one sentence.

    The report object itself is discarded; only success or failure of
    ``self.generator.reports(display=False)`` is reported.

    Returns:
        A status string: a "no generator" notice, a success notice, or the
        failure text including the raised exception's message.
    """
    if not self.generator:
        return "No trained generator available. Please train a model first."

    try:
        self.generator.reports(display=False)
    except Exception as exc:
        return f"Failed to generate quality report: {str(exc)}"
    else:
        return "Quality report generated. Use the button to download."
95
+
96
def get_quality_report_file(self) -> Optional[str]:
    """Produce a downloadable quality-report file and return its path.

    Strategy, in order:
      1. ``reports()`` returned a path to an existing ``.zip`` -> use it.
      2. The report object exposes an existing path via one of
         ``archive_path`` / ``zip_path`` / ``path`` / ``file_path``.
      3. The report object has a ``save`` or ``export`` method -> ask it to
         write ``quality_report.zip``.
      4. Fallback: write ``str(report)`` to ``quality_report.txt``.

    Files created by steps 3 and 4 go into a fresh temporary directory
    rather than a hard-coded location, so the method does not depend on a
    particular directory existing on the host and concurrent callers do not
    overwrite each other's files.

    Returns:
        Path of the report file as ``str``, or None when there is no trained
        generator or the report could not be produced at all.
    """
    import tempfile

    if not self.generator:
        return None
    try:
        rep = self.generator.reports(display=False)

        # 1) The call itself handed back a path to a zip archive.
        if isinstance(rep, (str, os.PathLike)):
            candidate = os.fsdecode(rep)
            if candidate.endswith(".zip") and os.path.exists(candidate):
                return candidate

        # 2) The report object points at a file that already exists.
        for attr in ("archive_path", "zip_path", "path", "file_path"):
            located = getattr(rep, attr, None)
            if isinstance(located, (str, os.PathLike)) and os.path.exists(located):
                return os.fsdecode(located)

        out_dir = tempfile.mkdtemp(prefix="quality_report_")

        # 3) Let the report object write itself out, if it knows how.
        target_zip = os.path.join(out_dir, "quality_report.zip")
        for method_name in ("save", "export"):
            writer = getattr(rep, method_name, None)
            if not callable(writer):
                continue
            try:
                writer(target_zip)
            except Exception:
                # Best effort: these method names are guesses at the report
                # API, so a failure just means "try the next strategy".
                continue
            if os.path.exists(target_zip):
                return target_zip

        # 4) Last resort: dump the textual representation.
        target_txt = os.path.join(out_dir, "quality_report.txt")
        with open(target_txt, "w", encoding="utf-8") as fh:
            fh.write(str(rep))
        return target_txt

    except Exception:
        # Callers treat None as "no file available" and hide the button.
        return None
142
+
143
  def estimate_memory_usage(self, df: pd.DataFrame) -> str:
144
  """Estimate memory usage for the dataset"""
145
  if df is None or df.empty:
146
  return "No data available to analyze."
147
+
 
148
  memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
149
  rows, cols = len(df), len(df.columns)
 
 
150
  estimated_training_mb = memory_mb * 4
151
+
152
  if memory_mb < 100:
153
  status = "Good"
154
  elif memory_mb < 500:
155
  status = "Large"
156
  else:
157
  status = "Very Large"
158
+
159
  return f"""
160
  Memory Usage Estimate:
161
  - Data size: {memory_mb:.1f} MB
 
164
  - Rows: {rows:,} | Columns: {cols}
165
  """.strip()
166
 
167
+
168
  # Initialize the generator
169
  generator = SyntheticDataGenerator()
170
 
171
+ # ---- Wrapper functions for Gradio ----
172
def initialize_sdk() -> str:
    """Initialize the shared generator's SDK client and format the outcome.

    Returns:
        The message from ``generator.initialize_mostly_ai()`` prefixed with
        ``"Success: "`` or ``"Error: "``.
    """
    succeeded, detail = generator.initialize_mostly_ai()
    if succeeded:
        return "Success: " + detail
    return "Error: " + detail
175
 
 
 
 
 
 
176
 
177
def train_model(
    data: pd.DataFrame,
    model_name: str,
    epochs: int,
    max_training_time: int,
    batch_size: int,
    value_protection: bool,
) -> str:
    """Train the shared generator on ``data`` and return a one-line status.

    Missing or empty data is rejected before the generator is touched;
    otherwise all arguments are forwarded positionally to
    ``generator.train_generator``.

    Returns:
        The resulting message prefixed with ``"Success: "`` or ``"Error: "``.
    """
    has_rows = data is not None and not data.empty
    if not has_rows:
        return "Error: No data provided. Please upload or create sample data first."

    succeeded, detail = generator.train_generator(
        data, model_name, epochs, max_training_time, batch_size, value_protection
    )
    prefix = "Success: " if succeeded else "Error: "
    return prefix + detail
191
+
192
+
193
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Generate ``size`` synthetic records with the shared generator.

    Returns:
        ``(frame, status)`` where ``status`` is the generator's message
        prefixed with ``"Success"`` when a frame came back and ``"Error"``
        when it is None.
    """
    frame, detail = generator.generate_synthetic_data(size)
    prefix = "Error" if frame is None else "Success"
    return frame, f"{prefix}: {detail}"
197
 
 
 
 
198
 
199
def get_quality_report_and_file():
    """Fetch the report status text and an update for the download button.

    Returns:
        ``(status_text, button_update)``. When the shared generator yields a
        report file path, the update sets that path as the button's value and
        makes it visible; otherwise the update only hides the button.
    """
    report_status = generator.get_quality_report_text()
    report_path = generator.get_quality_report_file()

    if not report_path:
        # No file to offer, so the button stays hidden.
        return report_status, gr.update(visible=False)
    return report_status, gr.update(value=report_path, visible=True)
211
+
212
+
213
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> Optional[go.Figure]:
    """Overlay histograms of original vs. synthetic values per numeric column.

    Only numeric columns of ``original_df`` that are also present in
    ``synthetic_df`` are plotted. This keeps the function from raising
    ``KeyError`` when the synthetic table is empty or lacks a column, for
    example when the synthetic table has been cleared.

    Args:
        original_df: The uploaded/original table, or None.
        synthetic_df: The generated table, or None.

    Returns:
        A figure with up to three subplots per row, or None when either
        frame is None or there is no numeric column shared by both.
    """
    if original_df is None or synthetic_df is None:
        return None

    shared_numeric = [
        column
        for column in original_df.select_dtypes(include=[np.number]).columns.tolist()
        if column in synthetic_df.columns
    ]
    if not shared_numeric:
        return None

    n_cols = min(3, len(shared_numeric))
    # Ceiling division: enough rows to hold every column, n_cols per row.
    n_rows = (len(shared_numeric) + n_cols - 1) // n_cols

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=shared_numeric)

    for i, col in enumerate(shared_numeric):
        grid_row, grid_col = divmod(i, n_cols)
        # Original first, then synthetic, so trace order matches the legend.
        for label, frame in (("Original", original_df), ("Synthetic", synthetic_df)):
            fig.add_trace(
                go.Histogram(x=frame[col], name=f"{label} {col}", opacity=0.7, nbinsx=20),
                row=grid_row + 1,  # plotly subplot indices are 1-based
                col=grid_col + 1,
            )

    fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
    return fig
243
 
244
+
245
def download_csv(df: pd.DataFrame) -> Optional[str]:
    """Write ``df`` to a CSV file and return the file's path for download.

    The file is created inside a fresh temporary directory instead of a
    hard-coded location, so the call does not fail on hosts where that
    location is missing, and simultaneous calls do not overwrite each
    other's output. The file name is always ``synthetic_data.csv``.

    Args:
        df: Table to export; may be None.

    Returns:
        Path to the written CSV, or None when ``df`` is None or empty.
    """
    import os
    import tempfile

    if df is None or df.empty:
        return None

    out_dir = tempfile.mkdtemp(prefix="synthetic_data_")
    path = os.path.join(out_dir, "synthetic_data.csv")
    df.to_csv(path, index=False)
    return path
252
+
253
 
254
+ # ---- UI ----
255
  def create_interface():
256
  with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
257
+ # Header image
258
+ gr.Image(
259
+ value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
260
+ show_label=False,
261
+ elem_id="header-image",
262
+ )
263
 
264
  # README
265
+ gr.Markdown(
266
+ """
267
  # Synthetic Data SDK by MOSTLY AI Demo Space
268
 
269
  [Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)
270
 
271
  A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
272
+ """
273
+ )
274
+
275
  with gr.Tab("Quick Start"):
276
  gr.Markdown("### Initialize the SDK and upload your data")
 
277
  with gr.Row():
278
  with gr.Column():
279
  init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
280
  init_status = gr.Textbox(label="Initialization Status", interactive=False)
 
281
  with gr.Column():
282
+ gr.Markdown(
283
+ """
284
  **Next Steps:**
285
  1. Initialize the SDK (click button above)
286
  2. Go to "Upload Data and Train Model" tab to upload your CSV file
287
  3. Train a model on your data
288
  4. Generate synthetic data
289
+ """
290
+ )
291
+
292
  with gr.Tab("Upload Data and Train Model"):
293
  gr.Markdown("### Upload your CSV file to generate synthetic data")
294
+ gr.Markdown(
295
+ """
296
  **File Requirements:**
297
  - Format: CSV with header row
298
  - Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
299
+ """
 
 
 
 
 
300
  )
301
+
302
+ file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
303
  uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
 
304
  memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
305
+
306
  with gr.Row():
307
  with gr.Column():
308
  model_name = gr.Textbox(
309
+ value="My Synthetic Model", label="Model Name", placeholder="Enter a name for your model"
 
 
310
  )
311
  epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
312
  max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
313
  batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
314
  value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
315
  train_btn = gr.Button("Train Model", variant="primary")
 
316
  with gr.Column():
317
  train_status = gr.Textbox(label="Training Status", interactive=False)
318
+ quality_report = gr.Textbox(label="Quality Report", lines=8, interactive=False)
319
+
320
+ with gr.Row():
321
+ get_report_btn = gr.Button("Get Quality Report", variant="secondary")
322
+ report_download_btn = gr.DownloadButton("Download Quality Report", visible=False)
323
+
324
  with gr.Tab("Generate Data"):
325
  gr.Markdown("### Generate synthetic data from your trained model")
 
326
  with gr.Row():
327
  with gr.Column():
328
  gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
329
  generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
 
330
  with gr.Column():
331
  gen_status = gr.Textbox(label="Generation Status", interactive=False)
332
+
333
  synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
 
334
  with gr.Row():
335
+ download_btn = gr.DownloadButton("Download CSV", file_name="synthetic_data.csv", variant="secondary")
336
  comparison_plot = gr.Plot(label="Data Comparison")
337
 
338
+ # README footer
339
+ gr.Markdown(
340
+ """
341
  **Modes of operation:**
342
  - **LOCAL mode** trains and generates synthetic data on your own compute resources.
343
  - **CLIENT mode** connects to a remote MOSTLY AI platform for training and generation.
 
357
  The open source Synthetic Data SDK by MOSTLY AI powers the MOSTLY AI Platform and MOSTLY AI Assistant.
358
 
359
  Sign up for free and try the [MOSTLY AI Platform](https://app.mostly.ai/) today!
360
+ """
 
 
 
 
 
361
  )
362
+
363
+ # ---- Event handlers ----
364
+ init_btn.click(initialize_sdk, outputs=[init_status])
365
+
366
  train_btn.click(
367
  train_model,
368
  inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
369
+ outputs=[train_status],
370
  )
371
+
372
+ # Build + expose quality report for download
373
  get_report_btn.click(
374
+ get_quality_report_and_file,
375
+ outputs=[quality_report, report_download_btn],
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  )
377
+
378
+ # Generate data
379
+ generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
380
+
381
+ # Update CSV DownloadButton whenever synthetic data changes
382
+ synthetic_data.change(download_csv, inputs=[synthetic_data], outputs=[download_btn])
383
+
384
+ # Build comparison plot when both datasets are available
385
  synthetic_data.change(
386
+ create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot]
 
 
387
  )
388
+
389
  # Handle file upload with size and column limits
390
  def process_uploaded_file(file):
391
  if file is None:
392
  return None, "No file uploaded.", gr.update(visible=False)
 
393
  try:
 
394
  df = pd.read_csv(file.name)
 
395
  success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
396
+ mem_info = generator.estimate_memory_usage(df)
397
+ return df, success_msg, gr.update(value=mem_info, visible=True)
 
 
 
398
  except Exception as e:
399
  return None, f"Error reading file: {str(e)}", gr.update(visible=False)
400
+
401
+ file_upload.change(process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info])
402
+
 
 
 
 
403
  return demo
404
 
405
+
406
  if __name__ == "__main__":
407
  demo = create_interface()
408
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)