"""
Gradio demo for steered LLM generation using SAE features.
Supports real-time streaming generation with HuggingFace Transformers.
IMPORTANT: Before running this app, you must extract steering vectors:
python extract_steering_vectors.py
This creates steering_vectors.pt which is much faster to load than
downloading full SAE files from HuggingFace Hub.
For HuggingFace Spaces ZeroGPU deployment, the @spaces.GPU decorator
ensures efficient GPU allocation only during inference.
"""
import gradio as gr
import torch
import yaml
import os
# ZeroGPU support for HuggingFace Spaces
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

    # Create a dummy decorator for local development
    def spaces_gpu_decorator(func):
        return func

    # Wrap in staticmethod so `spaces.GPU` is not bound to the stand-in instance;
    # a plain function here would receive an extra `self` argument when used as a decorator.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu_decorator)})()
from transformers import AutoModelForCausalLM, AutoTokenizer
from steering import load_saes_from_file, stream_steered_answer_hf
# Global variables
model = None
tokenizer = None
steering_components = None
cfg = None
def initialize_model():
"""
Load model, SAEs, and configuration on startup.
For ZeroGPU: Model is loaded with device_map="auto" and will be automatically
moved to GPU when @spaces.GPU decorated functions are called. Steering vectors
are loaded on CPU initially and moved to GPU during inference.
"""
global model, tokenizer, steering_components, cfg
# Get HuggingFace token for gated models (if needed)
hf_token = os.getenv("HF_TOKEN", None)
if hf_token:
print("Using HF_TOKEN from environment")
print("Loading configuration...")
with open("demo.yaml", "r") as f:
cfg = yaml.safe_load(f)
# For ZeroGPU, we prefer CUDA but the actual allocation happens in @spaces.GPU functions
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model: {cfg['llm_name']}...")
print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}")
model = AutoModelForCausalLM.from_pretrained(
cfg['llm_name'],
device_map="auto",
dtype=torch.float16 if device == "cuda" else torch.float32,
token=hf_token
)
tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)
print("Loading SAE steering components...")
# Use pre-extracted steering vectors for faster loading
# For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference
steering_vectors_file = "steering_vectors.pt"
load_device = "cpu" if SPACES_AVAILABLE else device
steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)
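    # Normalize each steering vector to unit length (in place).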
    for i in range(len(steering_components)):
        steering_components[i]['vector'] /= steering_components[i]['vector'].norm()

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg
@spaces.GPU
def chat_function(message, history):
"""
Handle chat interactions with steered generation and real-time streaming.
Decorated with @spaces.GPU to allocate GPU only during inference on HuggingFace Spaces.
Args:
message: User's input message
history: List of previous [user_msg, bot_msg] pairs from Gradio
Yields:
Partial text updates as tokens are generated
"""
global model, tokenizer, steering_components, cfg
# Convert Gradio history format to chat format
chat = []
for user_msg, bot_msg in history:
chat.append({"role": "user", "content": user_msg})
if bot_msg is not None:
chat.append({"role": "assistant", "content": bot_msg})
# Add current message
chat.append({"role": "user", "content": message})
# Stream tokens as they are generated
for partial_text in stream_steered_answer_hf(
model=model,
tokenizer=tokenizer,
chat=chat,
steering_components=steering_components,
max_new_tokens=cfg['max_new_tokens'],
temperature=cfg['temperature'],
repetition_penalty=cfg['repetition_penalty'],
clamp_intensity=cfg['clamp_intensity']
):
yield partial_text
def create_demo():
"""Create and configure the Gradio interface."""
# Custom CSS for better appearance
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
#chatbot {
height: 600px;
}
"""
# Create the interface
demo = gr.ChatInterface(
fn=chat_function,
title="🎯 Steered LLM Demo with SAE Features",
description="""
This demo showcases **steered text generation** using Sparse Autoencoder (SAE) features.
The model (Llama 3.1 8B Instruct) has its activations modified using vectors extracted from SAEs,
resulting in controlled behavior changes during generation.
**Features:**
- Real-time streaming: tokens appear as they're generated ⚡
- Multi-turn conversations with full history
- SAE-based activation steering across multiple layers
Start chatting below!
""",
examples=[
"Explain how neural networks work.",
"Tell me a creative story about a robot.",
"What are the applications of AI in healthcare?"
],
cache_examples=False,
theme=gr.themes.Soft(),
css=custom_css,
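        # Custom Chatbot component: elem_id matches the #chatbot CSS rule above,
        # and a copy button is shown on each response.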
        chatbot=gr.Chatbot(
            elem_id="chatbot",
            bubble_full_width=False,
            show_copy_button=True
        ),
    )

    return demo
if __name__ == "__main__":
print("=" * 60)
print("Steered LLM Demo - Initializing")
print("=" * 60)
initialize_model()
print("\n" + "=" * 60)
print("Launching Gradio interface...")
print("=" * 60 + "\n")
demo = create_demo()
demo.launch(
share=False, # Set to True for public link
server_name="0.0.0.0", # Allow external access
server_port=7860 # Default HF Spaces port
)