import gradio as gr
import torch
import torch.nn as nn
from model import SmolLM2Model  # ✅ from-scratch implementation in model.py
from transformers import AutoTokenizer, AutoConfig

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and config
print("Loading tokenizer and config...")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
# Load model
@torch.no_grad()
def load_model():
    """Load the trained model"""
    print("Loading model...")

    # Initialize model with config
    model = SmolLM2Model(config).to(device)

    # Load checkpoint
    checkpoint = torch.load('checkpoint_step_5050.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"✅ Model loaded successfully on {device}")
    print(f"✅ Training step: {checkpoint.get('step', 'N/A')}")

    return model, checkpoint

# Load model at startup
model, checkpoint = load_model()
@torch.no_grad()
def generate_text(
    prompt,
    max_length=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9
):
    """Generate text from prompt"""
    try:
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = inputs['input_ids']

        # Generate using model's built-in method
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k if top_k > 0 else None,
            do_sample=temperature > 0
        )

        # Decode
        output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output_text

    except Exception as e:
        return f"❌ Error generating text: {str(e)}"
def get_model_info():
    """Display model information"""
    total_params = model.get_num_params()
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = f"""
### 📊 Model Information

- **Model:** SmolLM2-135M
- **Total Parameters:** {total_params:,} (~{total_params/1e6:.1f}M)
- **Trainable Parameters:** {trainable_params:,}
- **Training Steps:** {checkpoint.get('step', 'N/A')}
- **Device:** {device}
- **Vocab Size:** {config.vocab_size:,}

### 🏗️ Architecture

- **Layers:** {config.num_hidden_layers}
- **Hidden Size:** {config.hidden_size}
- **Attention Heads:** {config.num_attention_heads} (Query) / {config.num_key_value_heads} (KV)
- **FFN Size:** {config.intermediate_size}
- **Context Length:** {config.max_position_embeddings}

### 🎯 Training Details

- ✅ Trained for 5,000 steps
- ✅ Checkpoint saved and reloaded
- ✅ Additional 50 steps after reload
- ✅ Predictions logged every 500 steps
"""
    return info
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="SmolLM2-135M Demo") as demo:
    gr.Markdown("""
# 🤖 SmolLM2-135M: From-Scratch Implementation

Complete reverse-engineered implementation of SmolLM2-135M, trained from scratch.

**GitHub:** [abi2024/smollm2-135-implementation](https://github.com/abi2024/smollm2-135-implementation)
""")
    with gr.Tab("🎮 Generate Text"):
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time"
                )

                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max New Tokens"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature"
                    )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P (Nucleus)"
                    )
                generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=12,
                    interactive=False
                )

        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                max_length_slider,
                temperature_slider,
                top_k_slider,
                top_p_slider
            ],
            outputs=output_text
        )
        gr.Markdown("""
### 💡 Generation Tips

- **Temperature**: Controls randomness (0.1 = focused, 2.0 = creative)
- **Top-K**: Limits sampling to the K most likely tokens (0 = disabled)
- **Top-P**: Nucleus sampling threshold (0.9 recommended); the sketch below shows roughly how the three settings combine
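
A minimal, illustrative sampling step (the app itself delegates to `model.generate`, and the `sample_next_token` helper below is a hypothetical sketch, not the app's actual code path):

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9):
    # logits: 1-D tensor of next-token scores, shape [vocab_size]
    logits = logits / max(temperature, 1e-5)          # temperature scaling
    if top_k > 0:                                     # Top-K: keep only the K most likely tokens
        kth_value = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_value] = float("-inf")
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)   # Top-P: drop the low-probability tail
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    return sorted_idx[torch.multinomial(sorted_probs, num_samples=1)]
```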
""")
    with gr.Tab("📊 Model Info"):
        model_info_display = gr.Markdown(get_model_info())
        gr.Markdown("""
### 🔍 Reverse Engineering Process

1. **Architecture Analysis**
   - Studied the SmolLM2 GitHub repository
   - Extracted the model configuration from YAML
   - Downloaded the pretrained 135M checkpoint
2. **Implementation**
   - Built from scratch in PyTorch
   - Implemented Grouped Query Attention (9 query / 3 KV heads; see the attention sketch below)
   - Added RoPE position embeddings
   - Used a SwiGLU FFN and RMSNorm
3. **Validation**
   - Loaded the official pretrained weights
   - Verified the parameter count (134,515,008)
   - Confirmed the architecture matches exactly
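
A compact, illustrative sketch of that grouped-query attention pattern using PyTorch's `scaled_dot_product_attention`; the tensor names, batch size, and sequence length are stand-ins, and the real layer lives in `model.py`:

```python
import torch
import torch.nn.functional as F

# SmolLM2-135M shapes: hidden=576, 9 query heads, 3 KV heads, head_dim=64
B, T, n_q, n_kv, d = 1, 16, 9, 3, 64
q = torch.randn(B, n_q, T, d)    # one set of queries per query head
k = torch.randn(B, n_kv, T, d)   # keys/values shared by groups of query heads
v = torch.randn(B, n_kv, T, d)

# Expand each KV head to cover its group of 9 // 3 = 3 query heads
k = k.repeat_interleave(n_q // n_kv, dim=1)
v = v.repeat_interleave(n_q // n_kv, dim=1)

# Fused, memory-efficient (flash-style) attention kernel with a causal mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 9, 16, 64])
```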
### ⚡ Optimizations Applied

- ✅ Flash-attention-style fused kernels (via `torch.nn.functional.scaled_dot_product_attention`)
- ✅ Mixed Precision Training (BF16/FP16)
- ✅ Gradient Accumulation (sketched together with mixed precision below)
- ✅ torch.compile() for inference speedup
- ✅ Grouped Query Attention (memory efficient)
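
A minimal sketch of how mixed precision (bfloat16 autocast) and gradient accumulation fit together in a training step; the toy model, optimizer, data, and `accum_steps` value are placeholders (and a CUDA device is assumed), not the actual training configuration:

```python
import torch
import torch.nn as nn

# Toy stand-ins so the sketch runs on its own; the real run uses SmolLM2Model and a data loader
model = nn.Linear(64, 64).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
batches = [torch.randn(8, 64, device="cuda") for _ in range(16)]

accum_steps = 4  # micro-batches accumulated per optimizer update

for step, x in enumerate(batches):
    # bfloat16 autocast: faster matmuls and lower memory, while weights stay in FP32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = (model(x) ** 2).mean() / accum_steps
    loss.backward()                          # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()                     # one weight update every accum_steps batches
        optimizer.zero_grad(set_to_none=True)
```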
### 📈 Training Pipeline

1. **Main Training:** 5,000 steps with predictions every 500 steps
2. **Checkpoint Test:** Model saved and successfully reloaded
3. **Resume Training:** 50 additional steps (validates checkpoint integrity)
""")
    with gr.Tab("🎯 Example Prompts"):
        gr.Markdown("""
### Try these prompts

**1. Story Generation**
```
Once upon a time in a magical forest,
```

**2. Code Completion**
```
def calculate_fibonacci(n):
    # Calculate the nth Fibonacci number
```

**3. Question Answering**
```
Q: What is the capital of France?
A:
```

**4. Technical Writing**
```
The main advantage of transformer architectures is
```

**5. Creative Writing**
```
The scientist discovered something extraordinary:
```

### 🎛️ Recommended Settings

- **Creative Writing:** Temperature=1.0, Top-P=0.95
- **Code Generation:** Temperature=0.3, Top-P=0.9, Top-K=40 (equivalent call sketched below)
- **Factual Q&A:** Temperature=0.5, Top-P=0.8, Top-K=30
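
The Code Generation preset above corresponds to a call like this to the app's own `generate_text` function (the prompt and the `max_length` value are just examples):

```python
# Equivalent programmatic call using this app's generate_text helper
completion = generate_text(
    "def calculate_fibonacci(n):\n    # Calculate the nth Fibonacci number",
    max_length=80,       # max new tokens to generate
    temperature=0.3,     # low temperature for focused, deterministic code
    top_k=40,
    top_p=0.9,
)
print(completion)
```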
""")
# Launch
if __name__ == "__main__":
    demo.launch()