"""Gradio demo UI for a (simulated) Gemma fine-tuning workflow.

The app exposes five tabs: data upload/preview, hyperparameter
configuration, a simulated training dashboard, a simulated model export
and a help page.
"""

import html
import time
from itertools import islice
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd

# Number of lines shown when previewing a JSONL file.
PREVIEW_LINES = 5
# Number of characters shown when previewing a TXT file.
PREVIEW_CHARS = 500
# Standard deviation of the Gaussian noise used by the demo augmentation.
NOISE_STD = 0.05
# Number of simulated training iterations and the delay of each one.
TOTAL_STEPS = 100
STEP_DELAY_SECONDS = 0.03

# Custom CSS inspired by Google's Material Design
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500&display=swap');

body, .gradio-container {
    font-family: 'Roboto', sans-serif;
    background-color: #FAFAFA;
    color: #202124;
}

h1, h2, h3, h4, .title {
    font-weight: 400;
}

.gr-button {
    background-color: #1a73e8 !important;
    color: #fff !important;
    border-radius: 4px !important;
    font-weight: 500;
    padding: 0.6em 1.2em;
    transition: background-color 0.3s ease;
}

.gr-button:hover {
    background-color: #1669c1 !important;
}

.gr-input, .gr-textbox, .gr-file, .gr-slider input {
    border: 1px solid #dadce0 !important;
    border-radius: 4px !important;
    padding: 0.5em;
}

.gr-slider > div {
    background-color: #1a73e8 !important;
}

.tab-item {
    padding: 1em;
}
"""


# ---------------------------------
# Data Upload and Preprocessing Module
# ---------------------------------
def _uploaded_path(file):
    """Return the filesystem path of an uploaded file as a ``Path``.

    Depending on the Gradio version and the ``gr.File`` ``type`` setting,
    the callback receives either a plain path string or a tempfile-like
    object exposing the path as ``.name``; both are accepted.
    """
    if isinstance(file, (str, Path)):
        return Path(file)
    return Path(file.name)


def _table_html(df):
    """Render the first rows of *df* as an HTML table.

    ``DataFrame.to_html`` escapes cell contents by default, so the result
    is safe to embed in the page.
    """
    return df.head().to_html(classes="dataframe", border=0)


def _preview_csv(path, augment):
    """Build the HTML preview for a CSV file, optionally with augmentation."""
    try:
        df = pd.read_csv(path)
    except (OSError, ValueError) as e:
        # pandas' ParserError / EmptyDataError and UnicodeDecodeError are
        # all ValueError subclasses; OSError covers unreadable files.
        return f"Error reading CSV: {html.escape(str(e))}"

    sections = [f"Original Data (Preview):<br>{_table_html(df)}"]
    if augment:
        # Only the preview rows are displayed, so only they are augmented.
        df_aug = df.head().copy()
        num_cols = df_aug.select_dtypes(include=[np.number]).columns
        if not num_cols.empty:
            # Simple augmentation: add random noise to numeric columns.
            noise = np.random.normal(0, NOISE_STD, df_aug[num_cols].shape)
            df_aug[num_cols] = df_aug[num_cols] + noise
            sections.append(
                f"Augmented Data (Preview):<br>{_table_html(df_aug)}"
            )
        else:
            sections.append("Note: No numeric columns found for augmentation.")
    return "<br><br>".join(sections)


def process_data(file, augment):
    """Validate an uploaded dataset file and return an HTML preview.

    Args:
        file: The upload handed over by ``gr.File`` (a path string or an
            object with a ``.name`` path attribute), or ``None`` when
            nothing has been uploaded.
        augment: When true and the file is a CSV, also show a preview with
            Gaussian noise added to the numeric columns.

    Returns:
        An HTML string: a table preview for CSV files, an escaped text
        preview for JSONL/TXT files, or a human-readable message when the
        file is missing, unreadable or of an unsupported type.
    """
    if file is None:
        return "No file uploaded yet."

    path = _uploaded_path(file)
    # Show only the base name; the full path is a server-side temp location.
    name = html.escape(path.name)
    ext = path.suffix.lower()

    if ext == ".csv":
        return _preview_csv(path, augment)

    if ext == ".jsonl":
        try:
            with open(path, encoding="utf-8", errors="replace") as f:
                # islice avoids loading the whole file for a short preview.
                preview = "".join(islice(f, PREVIEW_LINES))
        except OSError as e:
            return f"Error reading JSONL file: {html.escape(str(e))}"
        # File content is untrusted: escape it before embedding in HTML.
        return (
            f"File: {name}<br><br>Preview:<br>"
            f"<pre>{html.escape(preview)}</pre>"
        )

    if ext == ".txt":
        try:
            with open(path, encoding="utf-8", errors="replace") as f:
                content = f.read(PREVIEW_CHARS)
        except OSError as e:
            return f"Error reading TXT file: {html.escape(str(e))}"
        return (
            f"File: {name}<br><br>Preview (first 500 characters):<br>"
            f"<pre>{html.escape(content)}</pre>"
        )

    return "Unsupported file type. Please upload a CSV, JSONL, or TXT file."


data_upload_interface = gr.Interface(
    fn=process_data,
    inputs=[
        gr.File(label="Upload CSV/JSONL/TXT File"),
        gr.Checkbox(label="Apply Data Augmentation", value=False),
    ],
    outputs=gr.HTML(),
    title="Data Upload & Preprocessing",
    description=(
        "Upload your dataset file, validate its format, and optionally "
        "apply data augmentation."
    ),
)


# ---------------------------------
# Hyperparameter Configuration Module
# ---------------------------------
def configure_hyperparameters(learning_rate, batch_size, epochs):
    """Return an HTML summary of the selected training hyperparameters.

    Args:
        learning_rate: Value of the learning-rate slider.
        batch_size: Selected batch size (a string from the dropdown).
        epochs: Number of training epochs.

    Returns:
        One ``<br>``-separated line per hyperparameter.
    """
    return "<br>".join(
        [
            f"Learning Rate: {learning_rate}",
            f"Batch Size: {html.escape(str(batch_size))}",
            f"Epochs: {epochs}",
        ]
    )


hyperparameter_interface = gr.Interface(
    fn=configure_hyperparameters,
    inputs=[
        gr.Slider(0.0001, 0.1, value=0.001, label="Learning Rate", step=0.0001),
        gr.Dropdown(choices=["16", "32", "64", "128"], value="32", label="Batch Size"),
        # precision=0 makes the component yield an int instead of 10.0.
        gr.Number(value=10, label="Epochs", precision=0),
    ],
    outputs=gr.HTML(),
    title="Hyperparameter Settings",
    description="Adjust the training parameters for fine-tuning the model.",
)


# ---------------------------------
# Training Dashboard Module (Simulation)
# ---------------------------------
def simulate_training():
    """Simulate a training run.

    Sleeps briefly per iteration to mimic work being done.

    Returns:
        A ``(progress_vals, loss_vals, sample_output)`` tuple: the progress
        percentage per step (1..100), a noisy decreasing loss value per
        step, and a canned sample generation string.
    """
    progress_vals = []
    loss_vals = []
    for step in range(1, TOTAL_STEPS + 1):
        time.sleep(STEP_DELAY_SECONDS)  # Simulate training iteration delay
        progress_vals.append(step)
        # Simulated loss curve: uniform noise on top of a linear decay.
        loss_vals.append(np.random.rand() + (TOTAL_STEPS - step) / TOTAL_STEPS)
    sample_output = "This is a generated snippet from the fine-tuned Gemma model."
    return progress_vals, loss_vals, sample_output


def _training_dashboard_outputs():
    """Run the simulation and shape its results for the dashboard plots.

    The plot components need tabular data rather than bare lists, so the
    per-step values are wrapped in DataFrames keyed by step number.
    """
    progress_vals, loss_vals, sample_output = simulate_training()
    steps = list(range(1, len(progress_vals) + 1))
    progress_df = pd.DataFrame({"step": steps, "progress": progress_vals})
    loss_df = pd.DataFrame({"step": steps, "loss": loss_vals})
    return progress_df, loss_df, sample_output


training_interface = gr.Interface(
    fn=_training_dashboard_outputs,
    inputs=[],
    outputs=[
        gr.LinePlot(x="step", y="progress", label="Training Progress (%)"),
        gr.LinePlot(x="step", y="loss", label="Loss Curve"),
        gr.Textbox(label="Sample Output", lines=3),
    ],
    title="Training Dashboard",
    description="Monitor training progress in real-time (simulation).",
)


# ---------------------------------
# Model Export Module
# ---------------------------------
def export_model(export_format):
    """Simulate exporting the model and return an HTML status message.

    Args:
        export_format: The chosen export format, or ``None``/empty when the
            user has not selected one.

    Returns:
        A confirmation message, or a prompt to select a format.
    """
    if not export_format:
        return "Please select an export format."
    time.sleep(2)  # Simulate export process
    return (
        f"Model exported as {html.escape(str(export_format))}! "
        "Download link: [dummy_link]"
    )


export_interface = gr.Interface(
    fn=export_model,
    inputs=gr.Radio(
        ["TensorFlow SavedModel", "PyTorch", "GGUF"],
        label="Select Export Format",
    ),
    outputs=gr.HTML(),
    title="Model Export",
    description="Export your fine-tuned model in the desired format.",
)


# ---------------------------------
# Help & Documentation Module
# ---------------------------------
help_text = """
### Getting Started with Gemma Fine-tuning UI

1. **Data Upload & Preprocessing:**
   Upload your dataset in CSV, JSONL, or TXT format. The app validates your file and shows a preview.
   Optionally, enable data augmentation (e.g., adding random noise to numeric columns).

2. **Hyperparameter Settings:**
   Configure training parameters such as learning rate, batch size, and epochs.

3. **Training Dashboard:**
   Monitor the training progress in real-time. This demo simulates a training session.

4. **Model Export:**
   Export your fine-tuned model in a variety of formats.

For more detailed documentation, please refer to the [official documentation](https://example.com).
"""

help_interface = gr.Interface(
    fn=lambda: help_text,
    inputs=[],
    outputs="markdown",
    title="Help & Documentation",
)


# ---------------------------------
# Assemble the Tabbed Interface with Gradio Blocks
# ---------------------------------
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Gemma Fine-tuning UI")
    with gr.Tabs():
        with gr.TabItem("Data Upload & Preprocessing"):
            data_upload_interface.render()
        with gr.TabItem("Hyperparameters"):
            hyperparameter_interface.render()
        with gr.TabItem("Training"):
            training_interface.render()
        with gr.TabItem("Model Export"):
            export_interface.render()
        with gr.TabItem("Help"):
            help_interface.render()

# Guarded so importing this module does not start a server or open a
# public share tunnel as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)