# Hugging Face Space application (page status when captured: Sleeping)
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import spaces # Required for Hugging Face Spaces GPU | |
import random | |
import numpy as np | |
# Set random seeds for reproducibility
def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed. Values outside ``[0, 2**32 - 1]`` are folded into
            that range (modulo ``2**32``) because ``np.random.seed`` rejects
            anything else; in-range seeds are used unchanged.
    """
    # Normalise once so that all three libraries receive the same value.
    seed = int(seed) % (2 ** 32)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner, which
    # may otherwise pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed every RNG once at import time
set_random_seed(42)

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Checkpoint and tokenizer are loaded from the same Hub repository with the
# same options, so both are spelled out once.
# NOTE: trust_remote_code=True runs Python code shipped in the model repo.
_MODEL_ID = "FlameF0X/SnowflakeCore-G1-Tiny2"
_HUB_KWARGS = {
    "trust_remote_code": True,
    "force_download": True,
    "use_safetensors": True,
}

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, **_HUB_KWARGS).to(device)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, **_HUB_KWARGS)
# Required decorator for GPU usage in Hugging Face Spaces
@spaces.GPU
def advanced_generate(prompt, max_length=50, temperature=1.0, top_k=50, top_p=0.9,
                      repetition_penalty=1.1, do_sample=True, seed=None):
    """
    Generate a continuation of ``prompt`` one token at a time.

    Every step feeds the whole sequence so far through ``model``, takes the
    logits of the last position, applies the repetition penalty and then
    either samples (temperature -> top-k -> top-p) or picks the most likely
    token. Generation stops after ``max_length`` new tokens or as soon as the
    tokenizer's EOS token is produced.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to append.
        temperature: Softmax temperature used when sampling. A value <= 0
            selects greedy decoding instead of dividing by zero.
        top_k: When sampling, keep only the ``top_k`` most likely tokens
            (``None`` or <= 0 disables the filter).
        top_p: When sampling, keep the smallest set of most likely tokens
            whose cumulative probability reaches ``top_p`` (>= 1.0 disables).
        repetition_penalty: Values > 1.0 make tokens already present in the
            sequence less likely; 1.0 disables the penalty.
        do_sample: Sample from the filtered distribution when True, use
            greedy decoding when False.
        seed: If not None, passed to ``set_random_seed`` before generating.

    Returns:
        The decoded prompt plus generated tokens, with special tokens removed.
    """
    # Set seed if provided
    if seed is not None:
        set_random_seed(seed)

    model.eval()

    # Move input_ids to the same device as the model
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    if input_ids.shape[-1] == 0:
        # Nothing to condition on: indexing the last position would fail.
        return ""
    generated = input_ids

    # Decide once; a non-positive temperature cannot be used as a divisor.
    use_sampling = bool(do_sample) and temperature > 0

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids=generated)
            next_token_logits = outputs["logits"][:, -1, :]

            # Repetition penalty: push already-used tokens towards "less
            # likely". Positive logits are divided and negative logits are
            # multiplied; dividing a negative logit would raise its
            # probability instead of lowering it.
            if repetition_penalty != 1.0:
                seen = torch.unique(generated[0])
                seen_logits = next_token_logits[0, seen]
                next_token_logits[0, seen] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

            if use_sampling:
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                probs = torch.softmax(next_token_logits, dim=-1)

                # Top-k: zero everything outside the k most likely tokens and
                # renormalise so top-p below sees a proper distribution.
                if top_k is not None and top_k > 0:
                    k = min(top_k, probs.size(-1))
                    top_probs, top_indices = torch.topk(probs, k)
                    probs = torch.zeros_like(probs).scatter_(1, top_indices, top_probs)
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Top-p (nucleus): drop a token only if the tokens ranked
                # before it already exceed the threshold, so the token that
                # crosses the threshold is kept.
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                    remove_sorted = mass_before > top_p
                    # Keep at least one token
                    remove_sorted[..., 0] = False
                    # Map the mask from sorted order back to vocabulary order
                    remove = torch.zeros_like(remove_sorted).scatter(
                        1, sorted_indices, remove_sorted
                    )
                    probs = probs.masked_fill(remove, 0.0)

                # Renormalize and sample from the filtered distribution
                probs = probs / probs.sum(dim=-1, keepdim=True)
                next_token_id = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding; softmax is monotonic so argmax over the
                # logits selects the same token.
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated = torch.cat((generated, next_token_id), dim=1)

            # Check for end of sequence
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Largest seed accepted by every RNG that gets seeded downstream
# (np.random.seed requires 0 <= seed < 2**32).
_MAX_SEED = 2 ** 32 - 1


def _clamp(value, low, high):
    """Return ``value`` limited to the closed interval ``[low, high]``."""
    return max(low, min(high, value))


def _resolve_seed(seed):
    """Turn the raw seed field into an int in ``[0, _MAX_SEED]``.

    An empty, missing or unparsable value yields a fresh random seed; any
    other integer is folded into range so downstream seeding cannot fail.
    """
    if seed is None or seed == "":
        return random.randint(0, _MAX_SEED)
    try:
        return int(seed) % (_MAX_SEED + 1)
    except (TypeError, ValueError, OverflowError):
        # Not a usable number (e.g. text, NaN, infinity): pick one instead.
        return random.randint(0, _MAX_SEED)


def gradio_generate(prompt, max_length, temperature, top_k, top_p, repetition_penalty, do_sample, seed):
    """
    Wrapper function for Gradio interface.

    Coerces the raw widget values to the expected types, clamps them to the
    ranges offered by the UI sliders, and forwards them to
    ``advanced_generate``.

    Args:
        prompt: Prompt text.
        max_length: Number of new tokens to generate, clamped to [1, 2048].
        temperature: Sampling temperature, clamped to [0.1, 2.0].
        top_k: Top-k cutoff, clamped to [1, 150].
        top_p: Nucleus threshold, clamped to [0.1, 1.0].
        repetition_penalty: Repetition penalty, clamped to [1.0, 3.0].
        do_sample: Whether to sample instead of decoding greedily.
        seed: Seed value, or ``None``/``""`` to use a random one.

    Returns:
        The text returned by ``advanced_generate``.
    """
    # Upper bounds match the slider ranges (max length 2048, top-k 150,
    # repetition penalty 3.0) so a value the user can select is never
    # silently reduced.
    return advanced_generate(
        prompt=prompt,
        max_length=_clamp(int(max_length), 1, 2048),
        temperature=_clamp(float(temperature), 0.1, 2.0),
        top_k=_clamp(int(top_k), 1, 150),
        top_p=_clamp(float(top_p), 0.1, 1.0),
        repetition_penalty=_clamp(float(repetition_penalty), 1.0, 3.0),
        do_sample=do_sample,
        seed=_resolve_seed(seed),
    )
# Custom CSS for ultra-compact UI and full width.
# This string is passed to gr.Blocks(css=...) below. The classes
# .compact-slider and .control-section are attached to components through
# elem_classes, and .parameter-info / .compact-header are used in the raw
# HTML snippets of the layout.
# NOTE(review): the .gradio-* selectors presumably target class names emitted
# by Gradio itself - confirm they still match the installed Gradio version.
custom_css = """
.gradio-container {
max-width: 100% !important; /* Changed from 1200px to 100% for full width */
padding: 10px !important;
}
.compact-slider {
margin: 2px 0 !important;
}
.parameter-info {
font-size: 10px;
color: #888;
margin: -8px 0 8px 0 !important;
line-height: 1.1;
padding: 2px 4px;
background: rgba(0,0,0,0.05);
border-radius: 3px;
}
.gradio-group {
padding: 8px !important;
margin: 4px 0 !important;
}
.gradio-textbox {
margin-bottom: 8px !important;
}
.gradio-slider {
margin: 4px 0 !important;
}
.gradio-checkbox {
margin: 4px 0 !important;
}
.gradio-number {
margin: 4px 0 !important;
}
.compact-header {
margin: 0 0 10px 0 !important;
padding: 8px !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 8px;
color: white;
}
.control-section {
background: rgba(255,255,255,0.02);
border-radius: 6px;
padding: 8px !important;
margin: 4px 0 !important;
}
"""
# Create the Gradio interface with ultra-compact layout.
# Layout: header, then a row with prompt/output on the left and the sampling
# controls on the right, followed by examples and a collapsible tips panel.
# NOTE(review): the emoji and bullet characters in the labels below were
# restored from mis-encoded text; a few glyphs are best guesses - confirm
# they match the intended icons.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="SnowflakeCore Text Generator") as iface:
    # Compact header
    gr.HTML(f"""
    <div class="compact-header">
    <h2 style="margin: 0; font-size: 18px;">🔥 SnowflakeCore-G1-Tiny2 | Running on: {device}</h2>
    </div>
    """)

    with gr.Row():
        # Left column - Input/Output (60% width)
        with gr.Column(scale=6, min_width=400):
            prompt_input = gr.Textbox(
                lines=4,
                placeholder="Enter your prompt here...",
                label="📝 Input Prompt",
                show_label=True
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary", size="sm", scale=1)
            output_text = gr.Textbox(
                lines=8,
                label="✨ Generated Text",
                show_label=True,
                interactive=False
            )

        # Right column - Parameters (40% width)
        with gr.Column(scale=4, min_width=300):
            gr.HTML("<div style='font-weight: bold; font-size: 14px; margin-bottom: 8px; color: #333;'>⚙️ Parameters</div>")

            # Core parameters in compact group
            with gr.Group(elem_classes=["control-section"]):
                max_length = gr.Slider(10, 2048, value=100, step=5, label="Max Length", elem_classes=["compact-slider"])
                gr.HTML("<div class='parameter-info'>📏 Tokens to generate (10-2048)</div>")
                temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature", elem_classes=["compact-slider"])
                gr.HTML("<div class='parameter-info'>🌡️ Creativity: 0.1=focused, 2.0=creative</div>")

            # Advanced parameters
            with gr.Group(elem_classes=["control-section"]):
                top_k = gr.Slider(1, 150, value=50, step=1, label="Top-K", elem_classes=["compact-slider"])
                gr.HTML("<div class='parameter-info'>🎯 Word choice diversity</div>")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P", elem_classes=["compact-slider"])
                gr.HTML("<div class='parameter-info'>🎪 Probability cutoff</div>")
                repetition_penalty = gr.Slider(1.0, 3.0, value=1.15, step=0.05, label="Rep. Penalty", elem_classes=["compact-slider"])
                gr.HTML("<div class='parameter-info'>🔄 Anti-repetition strength</div>")

            # Controls row
            with gr.Row():
                do_sample = gr.Checkbox(value=True, label="Sampling", scale=1)
                seed_input = gr.Number(label="Seed", value=None, precision=0, scale=1, minimum=0)
            gr.HTML("<div class='parameter-info'>🎲 Sampling=creative | Seed=reproducible</div>")

    # The examples table and the button feed the same components, in the
    # positional order expected by gradio_generate.
    generation_inputs = [
        prompt_input, max_length, temperature, top_k, top_p,
        repetition_penalty, do_sample, seed_input,
    ]

    # Examples section
    gr.HTML("<h3 style='margin-top: 25px; margin-bottom: 10px;'>💡 Quick Examples</h3>")
    examples = gr.Examples(
        examples=[
            ["Once upon a time in a magical forest,", 120, 0.8, 40, 0.9, 1.2, True, 42],
            ["The future of artificial intelligence is", 80, 1.0, 50, 0.95, 1.1, True, None],
            ["In a world where technology and nature coexist,", 150, 1.2, 60, 0.85, 1.3, True, 123],
            ["Write a haiku about winter:", 50, 0.7, 30, 0.8, 1.0, True, None],
            ["Explain quantum computing in simple terms:", 200, 0.6, 40, 0.9, 1.1, True, None]
        ],
        inputs=generation_inputs,
        outputs=[output_text],
        fn=gradio_generate,
        cache_examples=False,
        label=None
    )

    # Event handlers
    generate_btn.click(
        fn=gradio_generate,
        inputs=generation_inputs,
        outputs=[output_text]
    )

    # Tips section
    with gr.Accordion("📚 Parameter Guide & Tips", open=False):
        gr.HTML("""
        <div style="font-size: 12px; line-height: 1.4;">
        <h4>🎛️ Parameter Combinations for Different Use Cases:</h4>
        <strong>📝 Creative Writing:</strong> Temperature: 0.8-1.2, Top-K: 40-60, Top-P: 0.85-0.95, Rep. Penalty: 1.1-1.3<br>
        <strong>📊 Factual/Technical:</strong> Temperature: 0.3-0.7, Top-K: 20-40, Top-P: 0.9-1.0, Rep. Penalty: 1.0-1.1<br>
        <strong>🎭 Experimental/Artistic:</strong> Temperature: 1.2-2.0, Top-K: 60-100, Top-P: 0.7-0.9, Rep. Penalty: 1.2-1.5<br>
        <strong>🎯 Focused/Consistent:</strong> Temperature: 0.1-0.5, Top-K: 10-30, Top-P: 0.8-0.95, Rep. Penalty: 1.0-1.2<br><br>
        <strong>💡 Pro Tips:</strong><br>
        • Use same seed for reproducible results across generations<br>
        • Higher repetition penalty helps with stuck loops or repeated phrases<br>
        • Lower temperature + higher top-p = focused but varied vocabulary<br>
        • Disable sampling for completely deterministic output (useful for testing)<br>
        • Start with defaults and adjust one parameter at a time to understand effects
        </div>
        """)

# Launch the Gradio application
if __name__ == "__main__":
    iface.launch()