import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces # Required for Hugging Face Spaces GPU
import random
import numpy as np
# Set random seeds for reproducibility
def set_random_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
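        # Note: deterministic cuDNN kernels trade some speed for exact reproducibility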
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set initial seed
set_random_seed(42)
# Determine the device to use (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load model and tokenizer
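# Note: force_download=True re-fetches the files on every cold start; it can be
# removed to reuse the Hugging Face cache once the downloads are known-good.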
model = AutoModelForCausalLM.from_pretrained(
"FlameF0X/SnowflakeCore-G1-Tiny2",
trust_remote_code=True,
force_download=True,
use_safetensors=True,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
    "FlameF0X/SnowflakeCore-G1-Tiny2",
    trust_remote_code=True,
    force_download=True,  # tokenizers do not load safetensors, so use_safetensors is dropped here
)
@spaces.GPU # Required decorator for GPU usage in Hugging Face Spaces
def advanced_generate(prompt, max_length=50, temperature=1.0, top_k=50, top_p=0.9,
repetition_penalty=1.1, do_sample=True, seed=None):
"""
Generates text with advanced sampling parameters.
The model and input tensors are moved to the appropriate device (GPU/CPU).
"""
# Set seed if provided
if seed is not None:
set_random_seed(seed)
model.eval()
# Move input_ids to the same device as the model
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
generated = input_ids
with torch.no_grad():
        for _ in range(max_length):
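            # Note: the full sequence is re-encoded on every step (no KV cache),
            # which is O(n^2) in sequence length but keeps the sampling loop simple.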
# Get model outputs
outputs = model(input_ids=generated)
next_token_logits = outputs["logits"][:, -1, :]
            # Apply repetition penalty (CTRL-style: divide positive logits and
            # multiply negative ones, so repeated tokens are always discouraged)
            if repetition_penalty != 1.0:
                for token_id in set(generated[0].tolist()):
                    if next_token_logits[0, token_id] > 0:
                        next_token_logits[0, token_id] /= repetition_penalty
                    else:
                        next_token_logits[0, token_id] *= repetition_penalty
            # Apply temperature (skip 0 to avoid division by zero; sampling
            # falls back to greedy decoding below when temperature <= 0)
            if temperature > 0 and temperature != 1.0:
                next_token_logits = next_token_logits / temperature
# Convert logits to probabilities
probs = torch.softmax(next_token_logits, dim=-1)
if do_sample and temperature > 0:
# Apply top-k filtering
if top_k > 0:
top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.size(-1)))
probs_filtered = torch.zeros_like(probs)
probs_filtered.scatter_(1, top_k_indices, top_k_probs)
probs = probs_filtered
# Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    # Remove tokens once the cumulative probability exceeds top_p,
                    # shifting the mask right so the token that crosses the
                    # threshold is kept (standard nucleus sampling)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    # Map the mask back to the original (unsorted) token order
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0
# Renormalize
probs = probs / probs.sum(dim=-1, keepdim=True)
# Sample from the filtered distribution
next_token_id = torch.multinomial(probs, num_samples=1)
else:
# Greedy decoding
next_token_id = torch.argmax(probs, dim=-1).unsqueeze(-1)
generated = torch.cat((generated, next_token_id), dim=1)
            # Stop at end of sequence, if the tokenizer defines one
            if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id:
                break
return tokenizer.decode(generated[0], skip_special_tokens=True)
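# Minimal direct-call sanity check for local debugging (illustrative values,
# commented out so the Space only serves the Gradio UI):
# print(advanced_generate("Once upon a time", max_length=30, temperature=0.8, seed=42))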
def gradio_generate(prompt, max_length, temperature, top_k, top_p, repetition_penalty, do_sample, seed):
"""
Wrapper function for Gradio interface.
"""
    # gr.Number returns None when the field is empty; fall back to a random seed
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    else:
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            seed = random.randint(0, 2**32 - 1)
    # Clamp parameters to the same ranges the UI sliders expose
    max_length = max(10, min(2048, int(max_length)))
    temperature = max(0.1, min(2.0, float(temperature)))
    top_k = max(1, min(150, int(top_k)))
    top_p = max(0.1, min(1.0, float(top_p)))
    repetition_penalty = max(1.0, min(3.0, float(repetition_penalty)))
return advanced_generate(
prompt=prompt,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
seed=seed
)
# Custom CSS for ultra-compact UI and full width
custom_css = """
.gradio-container {
    max-width: 100% !important;  /* full width */
padding: 10px !important;
}
.compact-slider {
margin: 2px 0 !important;
}
.parameter-info {
font-size: 10px;
color: #888;
margin: -8px 0 8px 0 !important;
line-height: 1.1;
padding: 2px 4px;
background: rgba(0,0,0,0.05);
border-radius: 3px;
}
.gradio-group {
padding: 8px !important;
margin: 4px 0 !important;
}
.gradio-textbox {
margin-bottom: 8px !important;
}
.gradio-slider {
margin: 4px 0 !important;
}
.gradio-checkbox {
margin: 4px 0 !important;
}
.gradio-number {
margin: 4px 0 !important;
}
.compact-header {
margin: 0 0 10px 0 !important;
padding: 8px !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 8px;
color: white;
}
.control-section {
background: rgba(255,255,255,0.02);
border-radius: 6px;
padding: 8px !important;
margin: 4px 0 !important;
}
"""
# Create the Gradio interface with ultra-compact layout
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="SnowflakeCore Text Generator") as iface:
# Compact header
gr.HTML(f"""
<div class="compact-header">
        <h2 style="margin: 0; font-size: 18px;">🔥 SnowflakeCore-G1-Tiny2 | Running on: {device}</h2>
</div>
""")
with gr.Row():
# Left column - Input/Output (60% width)
with gr.Column(scale=6, min_width=400):
prompt_input = gr.Textbox(
lines=4,
placeholder="Enter your prompt here...",
label="πŸ“ Input Prompt",
show_label=True
)
generate_btn = gr.Button("πŸš€ Generate", variant="primary", size="sm", scale=1)
output_text = gr.Textbox(
lines=8,
label="✨ Generated Text",
show_label=True,
interactive=False
)
# Right column - Parameters (40% width)
with gr.Column(scale=4, min_width=300):
gr.HTML("<div style='font-weight: bold; font-size: 14px; margin-bottom: 8px; color: #333;'>βš™οΈ Parameters</div>")
# Core parameters in compact group
with gr.Group(elem_classes=["control-section"]):
max_length = gr.Slider(10, 2048, value=100, step=5, label="Max Length", elem_classes=["compact-slider"])
gr.HTML("<div class='parameter-info'>πŸ“ Tokens to generate (10-2048)</div>")
temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature", elem_classes=["compact-slider"])
gr.HTML("<div class='parameter-info'>🌑️ Creativity: 0.1=focused, 2.0=creative</div>")
# Advanced parameters
with gr.Group(elem_classes=["control-section"]):
top_k = gr.Slider(1, 150, value=50, step=1, label="Top-K", elem_classes=["compact-slider"])
gr.HTML("<div class='parameter-info'>🎯 Word choice diversity</div>")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P", elem_classes=["compact-slider"])
gr.HTML("<div class='parameter-info'>πŸŽͺ Probability cutoff</div>")
repetition_penalty = gr.Slider(1.0, 3.0, value=1.15, step=0.05, label="Rep. Penalty", elem_classes=["compact-slider"])
gr.HTML("<div class='parameter-info'>πŸ”„ Anti-repetition strength</div>")
# Controls row
with gr.Row():
do_sample = gr.Checkbox(value=True, label="Sampling", scale=1)
seed_input = gr.Number(label="Seed", value=None, precision=0, scale=1, minimum=0)
gr.HTML("<div class='parameter-info'>🎲 Sampling=creative | Seed=reproducible</div>")
# Examples section
gr.HTML("<h3 style='margin-top: 25px; margin-bottom: 10px;'>πŸ’‘ Quick Examples</h3>")
examples = gr.Examples(
examples=[
["Once upon a time in a magical forest,", 120, 0.8, 40, 0.9, 1.2, True, 42],
["The future of artificial intelligence is", 80, 1.0, 50, 0.95, 1.1, True, None],
["In a world where technology and nature coexist,", 150, 1.2, 60, 0.85, 1.3, True, 123],
["Write a haiku about winter:", 50, 0.7, 30, 0.8, 1.0, True, None],
["Explain quantum computing in simple terms:", 200, 0.6, 40, 0.9, 1.1, True, None]
],
inputs=[prompt_input, max_length, temperature, top_k, top_p, repetition_penalty, do_sample, seed_input],
outputs=[output_text],
fn=gradio_generate,
cache_examples=False,
label=None
)
# Event handlers
generate_btn.click(
fn=gradio_generate,
inputs=[prompt_input, max_length, temperature, top_k, top_p, repetition_penalty, do_sample, seed_input],
outputs=[output_text]
)
# Tips section
with gr.Accordion("πŸ“š Parameter Guide & Tips", open=False):
gr.HTML("""
<div style="font-size: 12px; line-height: 1.4;">
<h4>πŸŽ›οΈ Parameter Combinations for Different Use Cases:</h4>
<strong>πŸ“ Creative Writing:</strong> Temperature: 0.8-1.2, Top-K: 40-60, Top-P: 0.85-0.95, Rep. Penalty: 1.1-1.3<br>
<strong>πŸ“‹ Factual/Technical:</strong> Temperature: 0.3-0.7, Top-K: 20-40, Top-P: 0.9-1.0, Rep. Penalty: 1.0-1.1<br>
<strong>🎭 Experimental/Artistic:</strong> Temperature: 1.2-2.0, Top-K: 60-100, Top-P: 0.7-0.9, Rep. Penalty: 1.2-1.5<br>
<strong>🎯 Focused/Consistent:</strong> Temperature: 0.1-0.5, Top-K: 10-30, Top-P: 0.8-0.95, Rep. Penalty: 1.0-1.2<br><br>
        <strong>💡 Pro Tips:</strong><br>
        • Use the same seed to reproduce results across generations<br>
        • A higher repetition penalty helps with stuck loops or repeated phrases<br>
        • Lower temperature + higher top-p = focused but varied vocabulary<br>
        • Disable sampling for completely deterministic output (useful for testing)<br>
        • Start with the defaults and adjust one parameter at a time to understand each effect
</div>
""")
# Launch the Gradio application
if __name__ == "__main__":
iface.launch()