cyberandy committed on
Commit
94ca202
·
1 Parent(s): 574ab91
Files changed (1) hide show
  1. app.py +54 -53
app.py CHANGED
@@ -11,9 +11,11 @@ import logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
 
14
  @dataclass
15
  class MarketingFeature:
16
  """Structure to hold marketing-relevant feature information"""
 
17
  feature_id: int
18
  name: str
19
  category: str
@@ -22,6 +24,7 @@ class MarketingFeature:
22
  layer: int
23
  threshold: float = 0.1
24
 
 
25
  # Define marketing-relevant features from Gemma Scope
26
  MARKETING_FEATURES = [
27
  MarketingFeature(
@@ -30,7 +33,7 @@ MARKETING_FEATURES = [
30
  category="technical",
31
  description="Detects technical and specialized terminology",
32
  interpretation_guide="High activation indicates strong technical focus",
33
- layer=20
34
  ),
35
  MarketingFeature(
36
  feature_id=6680,
@@ -38,7 +41,7 @@ MARKETING_FEATURES = [
38
  category="technical",
39
  description="Identifies complex technical concepts",
40
  interpretation_guide="Consider simplifying language if activation is too high",
41
- layer=20
42
  ),
43
  MarketingFeature(
44
  feature_id=2,
@@ -46,10 +49,11 @@ MARKETING_FEATURES = [
46
  category="seo",
47
  description="Identifies potential SEO keywords",
48
  interpretation_guide="High activation suggests strong SEO potential",
49
- layer=20
50
  ),
51
  ]
52
 
 
53
  class MarketingAnalyzer:
54
  """Main class for analyzing marketing content using Gemma Scope"""
55
 
@@ -67,8 +71,7 @@ class MarketingAnalyzer:
67
 
68
  # Initialize model and tokenizer with token from environment
69
  self.model = AutoModelForCausalLM.from_pretrained(
70
- model_name,
71
- device_map='auto'
72
  )
73
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
74
 
@@ -87,25 +90,30 @@ class MarketingAnalyzer:
87
  # Load SAE parameters for each feature
88
  path = hf_hub_download(
89
  repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
90
- filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz"
91
  )
92
  params = np.load(path)
93
  self.saes[feature.feature_id] = {
94
- 'params': {k: torch.from_numpy(v).to(self.device) for k, v in params.items()},
95
- 'feature': feature
 
 
 
96
  }
97
  logger.info(f"Loaded SAE for feature {feature.feature_id}")
98
  except Exception as e:
99
- logger.error(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")
 
 
100
  continue
101
 
102
  def analyze_content(self, text: str) -> Dict:
103
  """Analyze marketing content using loaded SAEs"""
104
  results = {
105
- 'text': text,
106
- 'features': {},
107
- 'categories': {},
108
- 'recommendations': []
109
  }
110
 
111
  try:
@@ -116,14 +124,12 @@ class MarketingAnalyzer:
116
 
117
  # Analyze each feature
118
  for feature_id, sae_data in self.saes.items():
119
- feature = sae_data['feature']
120
  layer_output = outputs.hidden_states[feature.layer]
121
 
122
  # Apply SAE
123
  activations = self._apply_sae(
124
- layer_output,
125
- sae_data['params'],
126
- feature.threshold
127
  )
128
 
129
  # Skip BOS token and handle empty activations
@@ -137,25 +143,24 @@ class MarketingAnalyzer:
137
 
138
  # Record results
139
  feature_result = {
140
- 'name': feature.name,
141
- 'category': feature.category,
142
- 'activation_score': mean_activation,
143
- 'max_activation': max_activation,
144
- 'interpretation': self._interpret_activation(
145
- mean_activation,
146
- feature
147
- )
148
  }
149
 
150
- results['features'][feature_id] = feature_result
151
 
152
  # Aggregate by category
153
- if feature.category not in results['categories']:
154
- results['categories'][feature.category] = []
155
- results['categories'][feature.category].append(feature_result)
156
 
157
  # Generate recommendations
158
- results['recommendations'] = self._generate_recommendations(results)
159
 
160
  except Exception as e:
161
  logger.error(f"Error analyzing content: {str(e)}")
@@ -167,18 +172,16 @@ class MarketingAnalyzer:
167
  self,
168
  activations: torch.Tensor,
169
  sae_params: Dict[str, torch.Tensor],
170
- threshold: float
171
  ) -> torch.Tensor:
172
  """Apply SAE to get feature activations"""
173
- pre_acts = activations @ sae_params['W_enc'] + sae_params['b_enc']
174
- mask = pre_acts > sae_params['threshold']
175
  acts = mask * torch.nn.functional.relu(pre_acts)
176
  return acts
177
 
178
  def _interpret_activation(
179
- self,
180
- activation: float,
181
- feature: MarketingFeature
182
  ) -> str:
183
  """Interpret activation patterns for a feature"""
184
  if activation > 0.8:
@@ -195,13 +198,12 @@ class MarketingAnalyzer:
195
  try:
196
  # Get technical features
197
  tech_features = [
198
- f for f in results['features'].values()
199
- if f['category'] == 'technical'
200
  ]
201
 
202
  # Calculate average technical score if we have features
203
  if tech_features:
204
- tech_score = np.mean([f['activation_score'] for f in tech_features])
205
 
206
  if tech_score > 0.8:
207
  recommendations.append(
@@ -216,6 +218,7 @@ class MarketingAnalyzer:
216
 
217
  return recommendations
218
 
 
219
  def create_gradio_interface():
220
  """Create Gradio interface for marketing analysis"""
221
  try:
@@ -227,7 +230,7 @@ def create_gradio_interface():
227
  inputs=gr.Textbox(),
228
  outputs=gr.Textbox(),
229
  title="Marketing Content Analyzer (Error)",
230
- description="Failed to initialize. Please check if HF_TOKEN is properly set."
231
  )
232
 
233
  def analyze(text):
@@ -238,31 +241,29 @@ def create_gradio_interface():
238
 
239
  # Overall category scores
240
  output += "Category Scores:\n"
241
- for category, features in results['categories'].items():
242
  if features: # Check if we have features for this category
243
- avg_score = np.mean([f['activation_score'] for f in features])
244
  output += f"{category.title()}: {avg_score:.2f}\n"
245
 
246
  # Feature details
247
  output += "\nFeature Details:\n"
248
- for feature_id, feature in results['features'].items():
249
  output += f"\n{feature['name']}:\n"
250
  output += f"Score: {feature['activation_score']:.2f}\n"
251
  output += f"Interpretation: {feature['interpretation']}\n"
252
 
253
  # Recommendations
254
- if results['recommendations']:
255
  output += "\nRecommendations:\n"
256
- for rec in results['recommendations']:
257
  output += f"- {rec}\n"
258
 
259
  return output
260
 
261
  # Create interface with custom theming
262
  custom_theme = gr.themes.Soft(
263
- primary_hue="indigo",
264
- secondary_hue="blue",
265
- neutral_hue="gray"
266
  )
267
 
268
  interface = gr.Interface(
@@ -270,7 +271,7 @@ def create_gradio_interface():
270
  inputs=gr.Textbox(
271
  lines=5,
272
  placeholder="Enter your marketing content here...",
273
- label="Marketing Content"
274
  ),
275
  outputs=gr.Textbox(label="Analysis Results"),
276
  title="Marketing Content Analyzer",
@@ -278,14 +279,14 @@ def create_gradio_interface():
278
  examples=[
279
  ["WordLift is an AI-powered SEO tool"],
280
  ["Our advanced machine learning algorithms optimize your content"],
281
- ["Simple and effective website optimization"]
282
  ],
283
- theme=custom_theme
284
- )
285
  )
286
 
287
  return interface
288
 
 
289
  if __name__ == "__main__":
290
  iface = create_gradio_interface()
291
- iface.launch()
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
+
15
  @dataclass
16
  class MarketingFeature:
17
  """Structure to hold marketing-relevant feature information"""
18
+
19
  feature_id: int
20
  name: str
21
  category: str
 
24
  layer: int
25
  threshold: float = 0.1
26
 
27
+
28
  # Define marketing-relevant features from Gemma Scope
29
  MARKETING_FEATURES = [
30
  MarketingFeature(
 
33
  category="technical",
34
  description="Detects technical and specialized terminology",
35
  interpretation_guide="High activation indicates strong technical focus",
36
+ layer=20,
37
  ),
38
  MarketingFeature(
39
  feature_id=6680,
 
41
  category="technical",
42
  description="Identifies complex technical concepts",
43
  interpretation_guide="Consider simplifying language if activation is too high",
44
+ layer=20,
45
  ),
46
  MarketingFeature(
47
  feature_id=2,
 
49
  category="seo",
50
  description="Identifies potential SEO keywords",
51
  interpretation_guide="High activation suggests strong SEO potential",
52
+ layer=20,
53
  ),
54
  ]
55
 
56
+
57
  class MarketingAnalyzer:
58
  """Main class for analyzing marketing content using Gemma Scope"""
59
 
 
71
 
72
  # Initialize model and tokenizer with token from environment
73
  self.model = AutoModelForCausalLM.from_pretrained(
74
+ model_name, device_map="auto"
 
75
  )
76
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
77
 
 
90
  # Load SAE parameters for each feature
91
  path = hf_hub_download(
92
  repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
93
+ filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
94
  )
95
  params = np.load(path)
96
  self.saes[feature.feature_id] = {
97
+ "params": {
98
+ k: torch.from_numpy(v).to(self.device)
99
+ for k, v in params.items()
100
+ },
101
+ "feature": feature,
102
  }
103
  logger.info(f"Loaded SAE for feature {feature.feature_id}")
104
  except Exception as e:
105
+ logger.error(
106
+ f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
107
+ )
108
  continue
109
 
110
  def analyze_content(self, text: str) -> Dict:
111
  """Analyze marketing content using loaded SAEs"""
112
  results = {
113
+ "text": text,
114
+ "features": {},
115
+ "categories": {},
116
+ "recommendations": [],
117
  }
118
 
119
  try:
 
124
 
125
  # Analyze each feature
126
  for feature_id, sae_data in self.saes.items():
127
+ feature = sae_data["feature"]
128
  layer_output = outputs.hidden_states[feature.layer]
129
 
130
  # Apply SAE
131
  activations = self._apply_sae(
132
+ layer_output, sae_data["params"], feature.threshold
 
 
133
  )
134
 
135
  # Skip BOS token and handle empty activations
 
143
 
144
  # Record results
145
  feature_result = {
146
+ "name": feature.name,
147
+ "category": feature.category,
148
+ "activation_score": mean_activation,
149
+ "max_activation": max_activation,
150
+ "interpretation": self._interpret_activation(
151
+ mean_activation, feature
152
+ ),
 
153
  }
154
 
155
+ results["features"][feature_id] = feature_result
156
 
157
  # Aggregate by category
158
+ if feature.category not in results["categories"]:
159
+ results["categories"][feature.category] = []
160
+ results["categories"][feature.category].append(feature_result)
161
 
162
  # Generate recommendations
163
+ results["recommendations"] = self._generate_recommendations(results)
164
 
165
  except Exception as e:
166
  logger.error(f"Error analyzing content: {str(e)}")
 
172
  self,
173
  activations: torch.Tensor,
174
  sae_params: Dict[str, torch.Tensor],
175
+ threshold: float,
176
  ) -> torch.Tensor:
177
  """Apply SAE to get feature activations"""
178
+ pre_acts = activations @ sae_params["W_enc"] + sae_params["b_enc"]
179
+ mask = pre_acts > sae_params["threshold"]
180
  acts = mask * torch.nn.functional.relu(pre_acts)
181
  return acts
182
 
183
  def _interpret_activation(
184
+ self, activation: float, feature: MarketingFeature
 
 
185
  ) -> str:
186
  """Interpret activation patterns for a feature"""
187
  if activation > 0.8:
 
198
  try:
199
  # Get technical features
200
  tech_features = [
201
+ f for f in results["features"].values() if f["category"] == "technical"
 
202
  ]
203
 
204
  # Calculate average technical score if we have features
205
  if tech_features:
206
+ tech_score = np.mean([f["activation_score"] for f in tech_features])
207
 
208
  if tech_score > 0.8:
209
  recommendations.append(
 
218
 
219
  return recommendations
220
 
221
+
222
  def create_gradio_interface():
223
  """Create Gradio interface for marketing analysis"""
224
  try:
 
230
  inputs=gr.Textbox(),
231
  outputs=gr.Textbox(),
232
  title="Marketing Content Analyzer (Error)",
233
+ description="Failed to initialize. Please check if HF_TOKEN is properly set.",
234
  )
235
 
236
  def analyze(text):
 
241
 
242
  # Overall category scores
243
  output += "Category Scores:\n"
244
+ for category, features in results["categories"].items():
245
  if features: # Check if we have features for this category
246
+ avg_score = np.mean([f["activation_score"] for f in features])
247
  output += f"{category.title()}: {avg_score:.2f}\n"
248
 
249
  # Feature details
250
  output += "\nFeature Details:\n"
251
+ for feature_id, feature in results["features"].items():
252
  output += f"\n{feature['name']}:\n"
253
  output += f"Score: {feature['activation_score']:.2f}\n"
254
  output += f"Interpretation: {feature['interpretation']}\n"
255
 
256
  # Recommendations
257
+ if results["recommendations"]:
258
  output += "\nRecommendations:\n"
259
+ for rec in results["recommendations"]:
260
  output += f"- {rec}\n"
261
 
262
  return output
263
 
264
  # Create interface with custom theming
265
  custom_theme = gr.themes.Soft(
266
+ primary_hue="indigo", secondary_hue="blue", neutral_hue="gray"
 
 
267
  )
268
 
269
  interface = gr.Interface(
 
271
  inputs=gr.Textbox(
272
  lines=5,
273
  placeholder="Enter your marketing content here...",
274
+ label="Marketing Content",
275
  ),
276
  outputs=gr.Textbox(label="Analysis Results"),
277
  title="Marketing Content Analyzer",
 
279
  examples=[
280
  ["WordLift is an AI-powered SEO tool"],
281
  ["Our advanced machine learning algorithms optimize your content"],
282
+ ["Simple and effective website optimization"],
283
  ],
284
+ theme=custom_theme,
 
285
  )
286
 
287
  return interface
288
 
289
+
290
  if __name__ == "__main__":
291
  iface = create_gradio_interface()
292
+ iface.launch()