import gradio as gr
import json
from insert_snippet import insert_snippet
from ai_model_plugins import list_available_models, list_models_gemini, list_models_anthropic
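
# Gradio front end for insert_snippet: describe a snippet, pick an LLM provider,
# then generate, preview, and save the code. A transformers.js playground tab
# additionally runs small models directly in the browser.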

# Inline CSS styling for the app's dark, modern theme
custom_css = """
/* Overall background and text styling for a dark theme */
body,
.gradio-container {
    background-color: #121212;
    color: #ffffff;
    font-family: 'Inter', 'Helvetica', 'Arial', sans-serif;
    margin: 0;
    padding: 0;
}

/* Style for markdown headers with gradient */
h1 {
    background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 700;
    margin-bottom: 1.5rem;
    font-size: 2.5rem;
}

h2, h3, h4, h5, h6 {
    color: #e0e0e0;
}

/* Style buttons with gradient and animation */
button {
    background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%);
    color: #ffffff;
    border: none;
    padding: 10px 15px;
    border-radius: 8px;
    cursor: pointer;
    font-size: 14px;
    font-weight: 600;
    transition: all 0.3s ease;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}

button:hover {
    transform: translateY(-2px);
    box-shadow: 0 7px 14px rgba(0, 0, 0, 0.2);
}

button:active {
    transform: translateY(0);
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}

/* Style for input fields with better focus states */
input, textarea, select {
    background-color: #2a2a2a;
    border: 1px solid #444;
    color: #e0e0e0;
    border-radius: 6px;
    padding: 8px 12px;
    transition: all 0.3s ease;
}

input:focus, textarea:focus, select:focus {
    border-color: #8E54E9;
    box-shadow: 0 0 0 2px rgba(142, 84, 233, 0.2);
    outline: none;
}

/* Card styling for examples */
.example-card {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    border-left: 4px solid #4776E6;
    transition: transform 0.3s ease, box-shadow 0.3s ease;
}

.example-card:hover {
    transform: translateY(-5px);
    box-shadow: 0 7px 14px rgba(0, 0, 0, 0.2);
}

/* Tab styling */
.tabs {
    margin-top: 20px;
}

.tab-button {
    background-color: transparent;
    color: #aaa;
    border: none;
    padding: 10px 15px;
    margin-right: 5px;
    border-radius: 5px 5px 0 0;
    cursor: pointer;
}

.tab-button.active {
    background-color: #2a2a2a;
    color: #fff;
}

.tab-content {
    background-color: #2a2a2a;
    padding: 20px;
    border-radius: 0 5px 5px 5px;
}

/* WebGPU status indicator */
.webgpu-status {
    display: inline-block;
    padding: 5px 10px;
    border-radius: 15px;
    font-size: 12px;
    font-weight: 600;
    margin-left: 10px;
}

.webgpu-status.available {
    background-color: #4CAF50;
    color: white;
}

.webgpu-status.unavailable {
    background-color: #F44336;
    color: white;
}

/* Model card styling */
.model-card {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    transition: transform 0.3s ease, box-shadow 0.3s ease;
    border-left: 4px solid #8E54E9;
}

.model-card:hover {
    transform: translateY(-5px);
    box-shadow: 0 7px 14px rgba(0, 0, 0, 0.2);
}

/* Progress bar styling */
.progress-container {
    width: 100%;
    background-color: #2a2a2a;
    border-radius: 5px;
    margin: 10px 0;
}

.progress-bar {
    height: 10px;
    background: linear-gradient(90deg, #4776E6 0%, #8E54E9 100%);
    border-radius: 5px;
    width: 0%;
    transition: width 0.3s ease;
}

/* Theme toggle switch */
.theme-switch {
    position: relative;
    display: inline-block;
    width: 60px;
    height: 34px;
}

.theme-switch input {
    opacity: 0;
    width: 0;
    height: 0;
}

.slider {
    position: absolute;
    cursor: pointer;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    background-color: #2a2a2a;
    transition: .4s;
    border-radius: 34px;
}

.slider:before {
    position: absolute;
    content: "";
    height: 26px;
    width: 26px;
    left: 4px;
    bottom: 4px;
    background-color: white;
    transition: .4s;
    border-radius: 50%;
}

input:checked + .slider {
    background-color: #8E54E9;
}

input:checked + .slider:before {
    transform: translateX(26px);
}
"""

def generate_snippet_gui(
    description, language, output_dir, output_file_name, clipboard,
    provider, model, api_key, template, params, existing_file, marker, format_code
):
    """Generate a snippet and return a (snippet, cost) tuple for the GUI outputs."""
    try:
        try:
            template_params = json.loads(params) if params else {}
        except Exception as e:
            return f"Error parsing template parameters: {e}", ""
        result = insert_snippet(
            description,
            language,
            output_dir=output_dir,
            output_file_name=output_file_name,
            use_clipboard=clipboard,
            model_name=model,
            template=template,
            template_params=template_params,
            existing_file=existing_file if existing_file.strip() != "" else None,
            marker=marker if marker.strip() != "" else None,
            format_code=format_code,
            provider=provider if provider.strip() != "" else None,
            api_key=api_key if api_key.strip() != "" else None,
            return_snippet=True
        )
        # result is a tuple: (snippet, cost)
        return result
    except Exception as e:
        return f"Error generating snippet: {e}", ""

def download_snippet_file_gui(
    description, language, output_dir, output_file_name, clipboard,
    provider, model, api_key, template, params, existing_file, marker, format_code
):
    """Generate a snippet, write it to disk, and return the file path for download."""
    try:
        try:
            template_params = json.loads(params) if params else {}
        except Exception as e:
            return f"Error parsing template parameters: {e}"
        result = insert_snippet(
            description,
            language,
            output_dir=output_dir,
            output_file_name=output_file_name,
            use_clipboard=clipboard,
            model_name=model,
            template=template,
            template_params=template_params,
            existing_file=existing_file if existing_file.strip() != "" else None,
            marker=marker if marker.strip() != "" else None,
            format_code=format_code,
            provider=provider if provider.strip() != "" else None,
            api_key=api_key if api_key.strip() != "" else None,
            return_file_path=True
        )
        # result[1] is the path of the saved file
        return result[1]
    except Exception as e:
        return f"Error saving file: {e}"

def list_models_gui(provider, api_key):
    try:
        models = list_available_models(provider, api_key if api_key.strip() != "" else None)
        return "\n".join(models)
    except Exception as e:
        return f"Error listing models: {e}"
def clear_fields():
    return (
        "", "python", "output", "snippet", False,
        "Gemini", "gemini-2.0-flash", "", "", "{}", "", "", False, ""
    )

def update_model_suggestions(provider):
    try:
        if provider == "Gemini":
            return "\n".join(list_models_gemini())
        elif provider == "OpenAI":
            return "gpt-3.5-turbo\ngpt-4\ngpt-4-turbo\ngpt-4o\n(API key required for full list)"
        elif provider == "Anthropic":
            return "\n".join(list_models_anthropic())
        elif provider == "Ollama":
            return "Local Ollama models (requires Ollama running on localhost:11434)"
        elif provider == "OpenRouter":
            return "Various models available (API key required)"
        else:
            return "Select a provider to see available models"
    except Exception as e:
        return f"Error fetching model suggestions: {e}"

def update_language_example(language):
    examples = {
        "python": "Create a function to download and process JSON data from an API with error handling",
        "javascript": "Write a modern ES6 utility function to deep merge two objects",
        "typescript": "Create a React hook for managing form state with TypeScript",
        "html": "Design a responsive card component with hover effects",
        "css": "Create a CSS animation for a loading spinner with gradient colors",
        "cpp": "Implement a thread-safe singleton pattern in C++",
        "java": "Create a Java class for parsing and validating CSV files",
        "go": "Write a Go function to concurrently process items from a channel",
        "rust": "Implement a safe wrapper around an unsafe Rust API",
        "sql": "Write a query to find the top 5 customers by purchase amount with their details",
        "bash": "Create a bash script to backup and compress log files older than 7 days",
        "shell": "Write a shell script to monitor system resources and send alerts"
    }
    return examples.get(language, "Enter code snippet description here...")

# Check WebGPU availability in the client browser via an inline script
def check_webgpu_availability():
    html_code = """
    <div id="webgpu-status"></div>
    <script>
    (async () => {
        try {
            if (!navigator.gpu) {
                document.getElementById('webgpu-status').innerHTML =
                    '<span class="webgpu-status unavailable">WebGPU Not Available</span>';
                return;
            }
            const adapter = await navigator.gpu.requestAdapter();
            if (!adapter) {
                document.getElementById('webgpu-status').innerHTML =
                    '<span class="webgpu-status unavailable">WebGPU Adapter Not Found</span>';
                return;
            }
            document.getElementById('webgpu-status').innerHTML =
                '<span class="webgpu-status available">WebGPU Available</span>';
        } catch (e) {
            document.getElementById('webgpu-status').innerHTML =
                '<span class="webgpu-status unavailable">WebGPU Error: ' + e.message + '</span>';
        }
    })();
    </script>
    """
    return html_code

# Generate the HTML for the WebGPU model playground
def generate_webgpu_playground():
    html_code = """
    <div class="model-playground">
        <div id="model-status">Loading transformers.js...</div>
        <div id="model-controls" style="display: none;">
            <div class="model-card">
                <h3>Text Generation</h3>
                <textarea id="text-input" placeholder="Enter text prompt here..." rows="4" style="width: 100%; margin-bottom: 10px;"></textarea>
                <div class="progress-container">
                    <div id="text-progress" class="progress-bar"></div>
                </div>
                <button id="generate-text-btn">Generate</button>
                <div id="text-output" style="margin-top: 15px; white-space: pre-wrap; background: #2a2a2a; padding: 10px; border-radius: 5px;"></div>
            </div>
            <div class="model-card" style="margin-top: 20px;">
                <h3>Image Classification</h3>
                <input type="file" id="image-upload" accept="image/*" style="margin-bottom: 10px;">
                <div class="progress-container">
                    <div id="image-progress" class="progress-bar"></div>
                </div>
                <button id="classify-image-btn">Classify</button>
                <div id="image-output" style="margin-top: 15px;"></div>
            </div>
        </div>
        <script type="module">
        // transformers.js is published as an ES module, so load it with a module
        // import rather than a classic <script src> tag (assumes the jsdelivr build
        // is served as ESM, as in the transformers.js docs).
        import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0';
        (async () => {
            const statusEl = document.getElementById('model-status');
            const controlsEl = document.getElementById('model-controls');
            try {
                // Check WebGPU availability; fall back to the WASM (CPU) backend
                let device = 'wasm';
                try {
                    if (navigator.gpu) {
                        const adapter = await navigator.gpu.requestAdapter();
                        if (adapter) {
                            device = 'webgpu';
                            statusEl.innerHTML = 'Using WebGPU acceleration! 🚀';
                        } else {
                            statusEl.innerHTML = 'WebGPU adapter not found. Using CPU fallback.';
                        }
                    } else {
                        statusEl.innerHTML = 'WebGPU not available. Using CPU fallback.';
                    }
                } catch (e) {
                    statusEl.innerHTML = 'WebGPU error: ' + e.message + '. Using CPU fallback.';
                }
                // Show controls
                controlsEl.style.display = 'block';

                // Text generation setup
                const textInput = document.getElementById('text-input');
                const textOutput = document.getElementById('text-output');
                const textProgress = document.getElementById('text-progress');
                const generateTextBtn = document.getElementById('generate-text-btn');
                generateTextBtn.addEventListener('click', async () => {
                    if (!textInput.value.trim()) return;
                    textOutput.textContent = 'Loading model...';
                    textProgress.style.width = '10%';
                    try {
                        // Use a small model for text generation
                        const generator = await pipeline(
                            'text-generation',
                            'Xenova/distilgpt2',
                            { device }
                        );
                        textProgress.style.width = '50%';
                        textOutput.textContent = 'Generating text...';
                        const result = await generator(textInput.value, {
                            max_new_tokens: 50,
                            temperature: 0.7
                        });
                        textProgress.style.width = '100%';
                        textOutput.textContent = result[0].generated_text;
                        // Reset progress after a delay
                        setTimeout(() => {
                            textProgress.style.width = '0%';
                        }, 1000);
                    } catch (e) {
                        textOutput.textContent = 'Error: ' + e.message;
                        textProgress.style.width = '0%';
                    }
                });

                // Image classification setup
                const imageUpload = document.getElementById('image-upload');
                const imageOutput = document.getElementById('image-output');
                const imageProgress = document.getElementById('image-progress');
                const classifyImageBtn = document.getElementById('classify-image-btn');
                classifyImageBtn.addEventListener('click', async () => {
                    if (!imageUpload.files || imageUpload.files.length === 0) return;
                    imageOutput.textContent = 'Loading model...';
                    imageProgress.style.width = '10%';
                    try {
                        // Use a small model for image classification
                        const classifier = await pipeline(
                            'image-classification',
                            'Xenova/vit-base-patch16-224',
                            { device }
                        );
                        imageProgress.style.width = '50%';
                        imageOutput.textContent = 'Classifying image...';
                        // The pipeline accepts a URL, so hand it an object URL for the upload
                        const imageUrl = URL.createObjectURL(imageUpload.files[0]);
                        const result = await classifier(imageUrl);
                        URL.revokeObjectURL(imageUrl);
                        imageProgress.style.width = '100%';
                        // Display results
                        imageOutput.innerHTML = '<h4>Results:</h4>';
                        result.forEach(prediction => {
                            const percent = (prediction.score * 100).toFixed(2);
                            imageOutput.innerHTML += `<div>${prediction.label}: ${percent}%</div>`;
                        });
                        // Reset progress after a delay
                        setTimeout(() => {
                            imageProgress.style.width = '0%';
                        }, 1000);
                    } catch (e) {
                        imageOutput.textContent = 'Error: ' + e.message;
                        imageProgress.style.width = '0%';
                    }
                });
            } catch (e) {
                statusEl.innerHTML = 'Error initializing transformers.js: ' + e.message;
            }
        })();
        </script>
    </div>
    """
    return html_code

# Function to provide theme toggle
def toggle_theme(dark_mode):
    if dark_mode:
        return """
        <style>
        body, .gradio-container { background-color: #121212; color: #ffffff; }
        </style>
        """
    else:
        return """
        <style>
        body, .gradio-container { background-color: #ffffff; color: #333333; }
        input, textarea, select { background-color: #f5f5f5; border: 1px solid #ddd; color: #333; }
        </style>
        """

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Code Snippet Generator")

    # WebGPU status indicator
    webgpu_status = gr.HTML(check_webgpu_availability())

    # Theme toggle
    with gr.Row():
        dark_mode = gr.Checkbox(label="Dark Mode", value=True)
        theme_html = gr.HTML(toggle_theme(True))
        dark_mode.change(toggle_theme, inputs=dark_mode, outputs=theme_html)

    with gr.Tabs() as tabs:
        with gr.TabItem("Generate Snippet"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        language = gr.Dropdown(
                            label="Programming Language",
                            choices=["python", "javascript", "typescript", "markdown", "html", "css",
                                     "cpp", "java", "go", "rust", "sql", "bash", "shell"],
                            value="python",
                            interactive=True
                        )
                        description = gr.Textbox(
                            label="Description",
                            lines=4,
                            placeholder="Enter code snippet description here...",
                            info="Describe what you want the code to do"
                        )
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### Output Options")
                        output_dir = gr.Textbox(label="Output Directory", value="output")
                        output_file_name = gr.Textbox(label="Output File Name", value="snippet")
                        clipboard = gr.Checkbox(label="Copy to Clipboard", value=False)
                        format_code = gr.Checkbox(label="Format Code (Python only)", value=False)
            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        gr.Markdown("### AI Model Configuration")
                        with gr.Row():
                            provider = gr.Dropdown(
                                label="LLM Provider",
                                choices=["Gemini", "OpenAI", "Anthropic", "Ollama", "LMStudios", "OpenRouter"],
                                value="Gemini"
                            )
                            model = gr.Textbox(label="Model", value="gemini-2.0-flash")
                        api_key = gr.Textbox(
                            label="API Key",
                            placeholder="Enter API key if required",
                            type="password",
                            info="Required for most providers except local Ollama"
                        )
                        models_list_output = gr.Textbox(
                            label="Available Models",
                            interactive=False,
                            value="Select a provider and click 'List Models' to see available options"
                        )
                        list_models_button = gr.Button("List Available Models", variant="secondary")
                with gr.Column():
                    with gr.Group():
                        gr.Markdown("### Template Options (Optional)")
                        template = gr.Textbox(
                            label="Template",
                            placeholder="e.g., functional_component.jsx",
                            info="Name of template file in templates directory"
                        )
                        params = gr.Textbox(
                            label="Template Parameters (JSON)",
                            placeholder='e.g., {"component_name": "MyComponent", "title": "Hello"}',
                            lines=2,
                            info="JSON object with values to fill in template placeholders"
                        )
                        existing_file = gr.Textbox(
                            label="Existing File",
                            placeholder="Path to existing file for snippet insertion",
                            info="Optional: insert snippet into this file"
                        )
                        marker = gr.Textbox(
                            label="Marker",
                            placeholder="Marker string in file (if applicable)",
                            info="Optional: insert snippet at this marker in existing file"
                        )
            with gr.Row():
                generate_button = gr.Button("Generate Snippet", variant="primary", size="lg")
                download_button = gr.Button("Download Snippet File", variant="secondary", size="lg")
                clear_button = gr.Button("Clear Fields", variant="stop", size="lg")
            with gr.Row():
                with gr.Column():
                    snippet_output = gr.Code(label="Generated Snippet", language="python", lines=15)
                    cost_output = gr.Textbox(label="Estimated Cost (USD)", interactive=False)
                    file_output = gr.File(label="Download File")
                    status = gr.Textbox(label="Status", interactive=False)

        with gr.TabItem("Examples"):
            gr.Markdown("### Example Prompts")
            example_cards = [
                ["Python", "Create a function to download and process JSON data from an API with error handling"],
                ["JavaScript", "Write a modern ES6 utility function to deep merge two objects"],
                ["TypeScript", "Create a React hook for managing form state with TypeScript"],
                ["HTML/CSS", "Design a responsive card component with hover effects"],
                ["SQL", "Write a query to find the top 5 customers by purchase amount with their details"]
            ]
            with gr.Column():
                for lang, desc in example_cards:
                    with gr.Group():
                        gr.Markdown(f"#### {lang} Example")
                        gr.Markdown(desc)
                        example_btn = gr.Button("Use this example", variant="secondary")
                        # Bind lang/desc as default arguments so each button keeps its own values
                        example_btn.click(
                            lambda l=lang.lower().split("/")[0], d=desc: [d, l],
                            inputs=None,
                            outputs=[description, language]
                        )

        # WebGPU-powered model playground tab
        with gr.TabItem("Model Playground"):
            gr.Markdown("### WebGPU-Powered Model Playground")
            gr.Markdown("""
            This playground runs AI models directly in your browser, using WebGPU acceleration when available.
            Try text generation and image classification!

            **Note:** WebGPU is supported in Chrome 113+, Edge 113+, and Firefox with the `dom.webgpu.enabled` flag.
            """)
            playground_html = gr.HTML(generate_webgpu_playground())

        # Local model management tab
        with gr.TabItem("Model Management"):
            gr.Markdown("### Local Model Management")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Download and Cache Models")
                    model_repo = gr.Textbox(
                        label="Model Repository",
                        placeholder="e.g., Xenova/distilgpt2",
                        info="Enter a Hugging Face model repository"
                    )
                    download_model_btn = gr.Button("Download Model", variant="primary")
                    model_status = gr.Textbox(label="Status", interactive=False)
                with gr.Column():
                    gr.Markdown("#### Cached Models")
                    cached_models_list = gr.Textbox(
                        label="Available Local Models",
                        interactive=False,
                        value="No models cached yet"
                    )
                    refresh_models_btn = gr.Button("Refresh List", variant="secondary")

            # Function to simulate model download (a real implementation would actually download)
            def download_model(repo_id):
                if not repo_id:
                    return "Please enter a valid model repository"
                # This is a placeholder - in a real implementation, you would download the model
                return f"Model {repo_id} downloaded and cached successfully"

            # Connect event handlers
            download_model_btn.click(
                fn=download_model,
                inputs=model_repo,
                outputs=model_status
            )

            # Placeholder function for refreshing the model list
            def refresh_models():
                # In a real implementation, this would scan a local directory
                return "Models in cache:\n- Xenova/distilgpt2\n- Xenova/vit-base-patch16-224"

            refresh_models_btn.click(
                fn=refresh_models,
                inputs=None,
                outputs=cached_models_list
            )
with gr.TabItem("Help"):
gr.Markdown("""
# How to Use the Snippet Generator
This tool generates code snippets using AI models based on your description. Here's how to use it:
1. **Select a programming language** from the dropdown
2. **Enter a description** of what you want the code to do
3. **Configure the AI model**:
- Choose a provider (Gemini, OpenAI, etc.)
- Select a model or use the default
- Enter your API key if required
4. **Set output options**:
- Specify where to save the file
- Enable clipboard copying if needed
- Enable code formatting for Python
5. **Click "Generate Snippet"** to create your code
### Templates
You can use templates to create code with a specific structure:
1. Place template files in the `templates` directory
2. Enter the template filename in the Template field
3. Provide parameters as a JSON object
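
            For example (the placeholder syntax shown is an assumption - it depends on how
            your `insert_snippet` implementation fills templates), a template
            `templates/greeting.py` containing:

            ```python
            def greet():
                print("Hello, {name}!")
            ```

            combined with Template Parameters `{"name": "World"}` would produce a greeting
            function for "World".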

            ### Inserting into Existing Files

            To insert code into an existing file:

            1. Specify the path to the existing file
            2. Optionally provide a marker string where the code should be inserted

            ### API Keys

            Most providers require an API key:

            - **Gemini**: Get from [Google AI Studio](https://aistudio.google.com/)
            - **OpenAI**: Get from [OpenAI Platform](https://platform.openai.com/)
            - **Anthropic**: Get from [Anthropic Console](https://console.anthropic.com/)
            - **Ollama**: Not required for local installation
            """)

            # WebGPU section of the help
            gr.Markdown("""
            ### WebGPU Acceleration

            The Model Playground tab uses WebGPU for GPU acceleration when available:

            - **WebGPU** is a new web standard that allows running AI models directly in your browser with GPU acceleration
            - Supported in Chrome 113+, Edge 113+, and Firefox with the `dom.webgpu.enabled` flag
            - Falls back to CPU when WebGPU is not available

            ### Local Model Management

            You can download and cache models for local use:

            1. Go to the Model Management tab
            2. Enter a Hugging Face model repository ID
            3. Click "Download Model" to cache it locally
            4. Use the cached model in the Model Playground
            """)

    # Event handlers
    language.change(
        fn=update_language_example,
        inputs=language,
        outputs=description
    )
    provider.change(
        fn=update_model_suggestions,
        inputs=provider,
        outputs=models_list_output
    )
    list_models_button.click(
        fn=list_models_gui,
        inputs=[provider, api_key],
        outputs=models_list_output
    )
    generate_button.click(
        fn=generate_snippet_gui,
        inputs=[description, language, output_dir, output_file_name, clipboard,
                provider, model, api_key, template, params, existing_file, marker, format_code],
        outputs=[snippet_output, cost_output]
    )
    download_button.click(
        fn=download_snippet_file_gui,
        inputs=[description, language, output_dir, output_file_name, clipboard,
                provider, model, api_key, template, params, existing_file, marker, format_code],
        outputs=file_output
    )
    clear_button.click(
        fn=clear_fields,
        inputs=None,
        outputs=[description, language, output_dir, output_file_name, clipboard,
                 provider, model, api_key, template, params, existing_file, marker, format_code, snippet_output]
    )

demo.launch()