Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
import numpy as np
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import List, Dict, Optional
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Module-wide logging: configure the root logger at INFO level and create the
# logger that the rest of this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
@dataclass
class MarketingFeature:
    """Descriptor for one SAE feature that the analyzer reports on.

    Instances are plain records: they carry the identifiers needed to locate
    the feature plus the human-readable labels shown in the analysis output.
    """

    # Identifier of the feature; the analyzer keys its loaded SAEs and its
    # per-feature results by this value.
    feature_id: int
    # Display name used in reports and interpretation sentences.
    name: str
    # Grouping label (e.g. "technical", "seo") used to aggregate scores.
    category: str
    # Short explanation of what the feature is believed to detect.
    description: str
    # Guidance text on how to read a high or low activation.
    interpretation_guide: str
    # Model layer the feature belongs to; used to pick the SAE file and the
    # hidden state to analyze.
    layer: int
    # Per-feature threshold handed to MarketingAnalyzer._apply_sae.
    threshold: float = 0.1
|
24 |
+
|
25 |
+
# Registry of the Gemma Scope features the analyzer reports on.
# NOTE(review): the names/descriptions are the author's labels for these
# feature ids — confirm them against a feature browser before relying on them.
# Add more entries as further relevant features are discovered.
MARKETING_FEATURES = [
    MarketingFeature(**spec)
    for spec in (
        {
            "feature_id": 35,
            "name": "Technical Term Detector",
            "category": "technical",
            "description": "Detects technical and specialized terminology",
            "interpretation_guide": "High activation indicates strong technical focus",
            "layer": 20,
        },
        {
            "feature_id": 6680,
            "name": "Compound Technical Terms",
            "category": "technical",
            "description": "Identifies complex technical concepts",
            "interpretation_guide": "Consider simplifying language if activation is too high",
            "layer": 20,
        },
        {
            "feature_id": 2,
            "name": "SEO Keyword Detector",
            "category": "seo",
            "description": "Identifies potential SEO keywords",
            "interpretation_guide": "High activation suggests strong SEO potential",
            "layer": 20,
        },
    )
]
|
53 |
+
|
54 |
+
class MarketingAnalyzer:
    """Score marketing text against selected Gemma Scope SAE features.

    On construction the analyzer loads a Gemma 2 causal language model and the
    Gemma Scope residual-stream SAE parameters for every entry in
    ``MARKETING_FEATURES``. ``analyze_content`` then runs a text through the
    model and reports per-feature activations, per-category groupings and
    textual recommendations.
    """

    def __init__(self, model_size: str = "2b"):
        """Load the model, tokenizer and SAEs.

        Args:
            model_size: Size suffix shared by the model checkpoint and the
                Gemma Scope repository name (e.g. ``"2b"``).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # _load_saes builds the Gemma Scope repo id from this attribute; without
        # it every SAE load fails with AttributeError and is silently skipped.
        self.model_size = model_size
        self._initialize_model(model_size)
        self._load_saes()

    def _initialize_model(self, model_size: str):
        """Load the base model and tokenizer and move the model to ``self.device``.

        Args:
            model_size: Size suffix of the checkpoint, e.g. ``"2b"``.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            # Gemma Scope SAEs are trained on Gemma 2 checkpoints, so the base
            # model must be "gemma-2-<size>". The first-generation "gemma-<size>"
            # checkpoints do not match the SAE input width or layer indices.
            model_name = f"google/gemma-2-{model_size}"
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            self.model = self.model.to(self.device)
            self.model.eval()  # inference only

            logger.info("Initialized model: %s", model_name)
        except Exception as e:
            logger.error("Error initializing model: %s", e)
            raise

    def _load_saes(self):
        """Download and load SAE parameters for each configured feature.

        Populates ``self.saes`` as ``{feature_id: {'params': ..., 'feature': ...}}``.
        Loading is best-effort: a feature whose SAE cannot be loaded is logged
        and skipped so the remaining features stay usable.
        """
        self.saes = {}
        # Several features usually live in the same layer's SAE; load each
        # layer's parameters once and share the tensors between features.
        params_by_layer: Dict[int, Dict[str, torch.Tensor]] = {}
        for feature in MARKETING_FEATURES:
            try:
                if feature.layer not in params_by_layer:
                    # NOTE(review): the width_16k / average_l0_71 variant is
                    # hard-coded; other layers may publish different L0
                    # variants — confirm before adding features on new layers.
                    path = hf_hub_download(
                        repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
                        filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
                    )
                    with np.load(path) as params:
                        # Move to the analyzer's device (not unconditionally to
                        # CUDA) so CPU-only hosts work too.
                        params_by_layer[feature.layer] = {
                            k: torch.from_numpy(v).to(self.device)
                            for k, v in params.items()
                        }
                self.saes[feature.feature_id] = {
                    'params': params_by_layer[feature.layer],
                    'feature': feature,
                }
                logger.info("Loaded SAE for feature %s", feature.feature_id)
            except Exception as e:
                logger.error(
                    "Error loading SAE for feature %s: %s", feature.feature_id, e
                )
                continue

    def analyze_content(self, text: str) -> Dict:
        """Analyze marketing content using the loaded SAEs.

        Args:
            text: The content to analyze.

        Returns:
            A dict with keys ``'text'`` (the input), ``'features'`` (per-feature
            result dicts keyed by feature id), ``'categories'`` (lists of those
            result dicts grouped by category) and ``'recommendations'`` (list of
            strings).

        Raises:
            Exception: Any error during tokenization, the forward pass or
                scoring is logged and re-raised unchanged.
        """
        results = {
            'text': text,
            'features': {},
            'categories': {},
            'recommendations': [],
        }

        try:
            # One forward pass yields the hidden states for every layer.
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            for feature_id, sae_data in self.saes.items():
                feature = sae_data['feature']
                # hidden_states[0] is the embedding output, so the output of
                # decoder layer N (which the layer_N SAE reads) is at N + 1.
                layer_output = outputs.hidden_states[feature.layer + 1]

                # Score only this feature's unit, not the whole SAE width.
                # NOTE(review): scores cover every token position, including
                # any BOS token the tokenizer prepends — consider masking it.
                activations = self._apply_sae(
                    layer_output,
                    sae_data['params'],
                    feature.threshold,
                    feature_id=feature_id,
                )

                feature_result = {
                    'name': feature.name,
                    'category': feature.category,
                    'activation_score': float(activations.mean()),
                    'max_activation': float(activations.max()),
                    'interpretation': self._interpret_activation(
                        activations,
                        feature,
                    ),
                }

                results['features'][feature_id] = feature_result
                # Group the same result dict under its category.
                results['categories'].setdefault(feature.category, []).append(
                    feature_result
                )

            results['recommendations'] = self._generate_recommendations(results)

        except Exception as e:
            logger.error("Error analyzing content: %s", e)
            raise

        return results

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
        threshold: float,
        feature_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Encode activations with the SAE and return feature activations.

        A unit's output is ``relu(pre_activation)`` where the pre-activation
        exceeds the SAE's own learned ``'threshold'`` parameter, else zero.

        Args:
            activations: Model activations whose last dimension matches the
                first dimension of ``sae_params['W_enc']``.
            sae_params: Mapping with ``'W_enc'``, ``'b_enc'`` and ``'threshold'``.
            threshold: Currently unused; gating uses ``sae_params['threshold']``.
                Kept so existing callers keep working.
            feature_id: When given, only this feature's activation is computed
                and the trailing feature dimension is dropped. When ``None``,
                activations for all features are returned.
                Assumes ``W_enc`` is 2-D (input_dim, n_features) — TODO confirm.

        Returns:
            Tensor of gated activations: one value per position when
            ``feature_id`` is given, otherwise one vector per position.
        """
        w_enc = sae_params['W_enc']
        b_enc = sae_params['b_enc']
        gate = sae_params['threshold']
        if feature_id is not None:
            # Slice the single column up front so only one unit is evaluated.
            w_enc = w_enc[:, feature_id]
            b_enc = b_enc[feature_id]
            gate = gate[feature_id]

        pre_acts = activations @ w_enc + b_enc
        mask = pre_acts > gate
        return mask * torch.nn.functional.relu(pre_acts)

    def _interpret_activation(
        self,
        activations: torch.Tensor,
        feature: "MarketingFeature",
    ) -> str:
        """Turn a feature's mean activation into a short human-readable label.

        Args:
            activations: Activations for the feature.
            feature: The feature being described; only ``feature.name`` is read.

        Returns:
            "Very strong" above 0.8, "Moderate" above 0.5, otherwise "Limited"
            presence of the lower-cased feature name.
        """
        mean_activation = float(activations.mean())
        label = feature.name.lower()
        if mean_activation > 0.8:
            return f"Very strong presence of {label}"
        elif mean_activation > 0.5:
            return f"Moderate presence of {label}"
        else:
            return f"Limited presence of {label}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Generate content recommendations based on the analysis.

        Args:
            results: Analysis dict; only ``results['features']`` is read, and
                each entry needs ``'category'`` and ``'activation_score'``.

        Returns:
            A list of recommendation strings (possibly empty).
        """
        recommendations = []

        # Technical complexity: average score over all "technical" features.
        technical_scores = [
            f['activation_score']
            for f in results['features'].values()
            if f['category'] == 'technical'
        ]
        # Skip the check when no technical feature was scored; averaging an
        # empty list would produce NaN and a runtime warning.
        if technical_scores:
            tech_score = float(np.mean(technical_scores))
            if tech_score > 0.8:
                recommendations.append(
                    "Consider simplifying technical language for broader audience"
                )
            elif tech_score < 0.3:
                recommendations.append(
                    "Could benefit from more specific technical details"
                )

        # Add more recommendation logic as needed
        return recommendations
|
198 |
+
|
199 |
+
def _format_analysis(results: Dict) -> str:
    """Render an ``analyze_content`` result dict as the plain-text report.

    Args:
        results: Dict with ``'categories'``, ``'features'`` and
            ``'recommendations'`` entries as produced by
            ``MarketingAnalyzer.analyze_content``.

    Returns:
        The report text: a title, per-category average scores, per-feature
        score and interpretation, then the recommendation bullets.
    """
    # Collect the pieces and join once instead of growing a string with +=.
    parts = ["Content Analysis Results\n\n", "Category Scores:\n"]

    # Overall category scores: mean activation score of each category.
    for category, features in results['categories'].items():
        avg_score = np.mean([f['activation_score'] for f in features])
        parts.append(f"{category.title()}: {avg_score:.2f}\n")

    # Feature details
    parts.append("\nFeature Details:\n")
    for feature in results['features'].values():
        parts.append(f"\n{feature['name']}:\n")
        parts.append(f"Score: {feature['activation_score']:.2f}\n")
        parts.append(f"Interpretation: {feature['interpretation']}\n")

    # Recommendations (the header is emitted even when the list is empty).
    parts.append("\nRecommendations:\n")
    parts.extend(f"- {rec}\n" for rec in results['recommendations'])

    return "".join(parts)


def create_gradio_interface():
    """Create the Gradio interface for marketing analysis.

    Instantiates a ``MarketingAnalyzer`` once (which loads the model and SAEs)
    and wires it to a single text-in / text-out Gradio ``Interface``.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for
        launching it.
    """
    analyzer = MarketingAnalyzer()

    def analyze(text):
        """Analyze ``text`` and return the formatted report string."""
        return _format_analysis(analyzer.analyze_content(text))

    iface = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter your marketing content here...",
        ),
        outputs=gr.Textbox(),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=[
            ["WordLift is an AI-powered SEO tool"],
            ["Our advanced machine learning algorithms optimize your content"],
            ["Simple and effective website optimization"],
        ],
    )

    return iface
|
246 |
+
|
247 |
+
if __name__ == "__main__":
    # Script entry point: build the interface and launch it.
    create_gradio_interface().launch()
|