cyberandy commited on
Commit
f85532f
·
verified ·
1 Parent(s): 37deb71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +249 -0
app.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from huggingface_hub import hf_hub_download
5
+ import numpy as np
6
+ from dataclasses import dataclass
7
+ from typing import List, Dict, Optional
8
+ import logging
9
+
10
+ # Initialize logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ @dataclass
15
+ class MarketingFeature:
16
+ """Structure to hold marketing-relevant feature information"""
17
+ feature_id: int
18
+ name: str
19
+ category: str
20
+ description: str
21
+ interpretation_guide: str
22
+ layer: int
23
+ threshold: float = 0.1
24
+
25
+ # Define marketing-relevant features from Gemma Scope
26
+ MARKETING_FEATURES = [
27
+ MarketingFeature(
28
+ feature_id=35,
29
+ name="Technical Term Detector",
30
+ category="technical",
31
+ description="Detects technical and specialized terminology",
32
+ interpretation_guide="High activation indicates strong technical focus",
33
+ layer=20
34
+ ),
35
+ MarketingFeature(
36
+ feature_id=6680,
37
+ name="Compound Technical Terms",
38
+ category="technical",
39
+ description="Identifies complex technical concepts",
40
+ interpretation_guide="Consider simplifying language if activation is too high",
41
+ layer=20
42
+ ),
43
+ MarketingFeature(
44
+ feature_id=2,
45
+ name="SEO Keyword Detector",
46
+ category="seo",
47
+ description="Identifies potential SEO keywords",
48
+ interpretation_guide="High activation suggests strong SEO potential",
49
+ layer=20
50
+ ),
51
+ # Add more relevant features as we discover them
52
+ ]
53
+
54
+ class MarketingAnalyzer:
55
+ """Main class for analyzing marketing content using Gemma Scope"""
56
+
57
+ def __init__(self, model_size: str = "2b"):
58
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ self._initialize_model(model_size)
60
+ self._load_saes()
61
+
62
+ def _initialize_model(self, model_size: str):
63
+ """Initialize Gemma model and tokenizer"""
64
+ try:
65
+ model_name = f"google/gemma-{model_size}"
66
+ self.model = AutoModelForCausalLM.from_pretrained(model_name)
67
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
68
+
69
+ self.model = self.model.to(self.device)
70
+ self.model.eval()
71
+
72
+ logger.info(f"Initialized model: {model_name}")
73
+ except Exception as e:
74
+ logger.error(f"Error initializing model: {str(e)}")
75
+ raise
76
+
77
+ def _load_saes(self):
78
+ """Load relevant SAEs from Gemma Scope"""
79
+ self.saes = {}
80
+ for feature in MARKETING_FEATURES:
81
+ try:
82
+ # Load SAE parameters for each feature
83
+ path = hf_hub_download(
84
+ repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
85
+ filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz"
86
+ )
87
+ params = np.load(path)
88
+ self.saes[feature.feature_id] = {
89
+ 'params': {k: torch.from_numpy(v).cuda() for k, v in params.items()},
90
+ 'feature': feature
91
+ }
92
+ logger.info(f"Loaded SAE for feature {feature.feature_id}")
93
+ except Exception as e:
94
+ logger.error(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")
95
+ continue
96
+
97
+ def analyze_content(self, text: str) -> Dict:
98
+ """Analyze marketing content using loaded SAEs"""
99
+ results = {
100
+ 'text': text,
101
+ 'features': {},
102
+ 'categories': {},
103
+ 'recommendations': []
104
+ }
105
+
106
+ try:
107
+ # Get model activations
108
+ inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
109
+ with torch.no_grad():
110
+ outputs = self.model(**inputs, output_hidden_states=True)
111
+
112
+ # Analyze each feature
113
+ for feature_id, sae_data in self.saes.items():
114
+ feature = sae_data['feature']
115
+ layer_output = outputs.hidden_states[feature.layer]
116
+
117
+ # Apply SAE
118
+ activations = self._apply_sae(
119
+ layer_output,
120
+ sae_data['params'],
121
+ feature.threshold
122
+ )
123
+
124
+ # Record results
125
+ feature_result = {
126
+ 'name': feature.name,
127
+ 'category': feature.category,
128
+ 'activation_score': float(activations.mean()),
129
+ 'max_activation': float(activations.max()),
130
+ 'interpretation': self._interpret_activation(
131
+ activations,
132
+ feature
133
+ )
134
+ }
135
+
136
+ results['features'][feature_id] = feature_result
137
+
138
+ # Aggregate by category
139
+ if feature.category not in results['categories']:
140
+ results['categories'][feature.category] = []
141
+ results['categories'][feature.category].append(feature_result)
142
+
143
+ # Generate recommendations
144
+ results['recommendations'] = self._generate_recommendations(results)
145
+
146
+ except Exception as e:
147
+ logger.error(f"Error analyzing content: {str(e)}")
148
+ raise
149
+
150
+ return results
151
+
152
+ def _apply_sae(
153
+ self,
154
+ activations: torch.Tensor,
155
+ sae_params: Dict[str, torch.Tensor],
156
+ threshold: float
157
+ ) -> torch.Tensor:
158
+ """Apply SAE to get feature activations"""
159
+ pre_acts = activations @ sae_params['W_enc'] + sae_params['b_enc']
160
+ mask = pre_acts > sae_params['threshold']
161
+ acts = mask * torch.nn.functional.relu(pre_acts)
162
+ return acts
163
+
164
+ def _interpret_activation(
165
+ self,
166
+ activations: torch.Tensor,
167
+ feature: MarketingFeature
168
+ ) -> str:
169
+ """Interpret activation patterns for a feature"""
170
+ mean_activation = float(activations.mean())
171
+ if mean_activation > 0.8:
172
+ return f"Very strong presence of {feature.name.lower()}"
173
+ elif mean_activation > 0.5:
174
+ return f"Moderate presence of {feature.name.lower()}"
175
+ else:
176
+ return f"Limited presence of {feature.name.lower()}"
177
+
178
+ def _generate_recommendations(self, results: Dict) -> List[str]:
179
+ """Generate content recommendations based on analysis"""
180
+ recommendations = []
181
+
182
+ # Analyze technical complexity
183
+ tech_score = np.mean([
184
+ f['activation_score'] for f in results['features'].values()
185
+ if f['category'] == 'technical'
186
+ ])
187
+ if tech_score > 0.8:
188
+ recommendations.append(
189
+ "Consider simplifying technical language for broader audience"
190
+ )
191
+ elif tech_score < 0.3:
192
+ recommendations.append(
193
+ "Could benefit from more specific technical details"
194
+ )
195
+
196
+ # Add more recommendation logic as needed
197
+ return recommendations
198
+
199
+ def create_gradio_interface():
200
+ """Create Gradio interface for marketing analysis"""
201
+ analyzer = MarketingAnalyzer()
202
+
203
+ def analyze(text):
204
+ results = analyzer.analyze_content(text)
205
+
206
+ # Format results for display
207
+ output = "Content Analysis Results\n\n"
208
+
209
+ # Overall category scores
210
+ output += "Category Scores:\n"
211
+ for category, features in results['categories'].items():
212
+ avg_score = np.mean([f['activation_score'] for f in features])
213
+ output += f"{category.title()}: {avg_score:.2f}\n"
214
+
215
+ # Feature details
216
+ output += "\nFeature Details:\n"
217
+ for feature_id, feature in results['features'].items():
218
+ output += f"\n{feature['name']}:\n"
219
+ output += f"Score: {feature['activation_score']:.2f}\n"
220
+ output += f"Interpretation: {feature['interpretation']}\n"
221
+
222
+ # Recommendations
223
+ output += "\nRecommendations:\n"
224
+ for rec in results['recommendations']:
225
+ output += f"- {rec}\n"
226
+
227
+ return output
228
+
229
+ iface = gr.Interface(
230
+ fn=analyze,
231
+ inputs=gr.Textbox(
232
+ lines=5,
233
+ placeholder="Enter your marketing content here..."
234
+ ),
235
+ outputs=gr.Textbox(),
236
+ title="Marketing Content Analyzer",
237
+ description="Analyze your marketing content using Gemma Scope's neural features",
238
+ examples=[
239
+ ["WordLift is an AI-powered SEO tool"],
240
+ ["Our advanced machine learning algorithms optimize your content"],
241
+ ["Simple and effective website optimization"]
242
+ ]
243
+ )
244
+
245
+ return iface
246
+
247
+ if __name__ == "__main__":
248
+ iface = create_gradio_interface()
249
+ iface.launch()