update
app.py
CHANGED
@@ -11,9 +11,11 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 @dataclass
 class MarketingFeature:
     """Structure to hold marketing-relevant feature information"""
+
     feature_id: int
     name: str
     category: str

@@ -22,6 +24,7 @@ class MarketingFeature:
     layer: int
     threshold: float = 0.1
 
+
 # Define marketing-relevant features from Gemma Scope
 MARKETING_FEATURES = [
     MarketingFeature(

@@ -30,7 +33,7 @@ MARKETING_FEATURES = [
         category="technical",
         description="Detects technical and specialized terminology",
         interpretation_guide="High activation indicates strong technical focus",
-        layer=20
+        layer=20,
     ),
     MarketingFeature(
         feature_id=6680,

@@ -38,7 +41,7 @@ MARKETING_FEATURES = [
         category="technical",
         description="Identifies complex technical concepts",
         interpretation_guide="Consider simplifying language if activation is too high",
-        layer=20
+        layer=20,
     ),
     MarketingFeature(
         feature_id=2,

@@ -46,10 +49,11 @@ MARKETING_FEATURES = [
         category="seo",
         description="Identifies potential SEO keywords",
         interpretation_guide="High activation suggests strong SEO potential",
-        layer=20
+        layer=20,
     ),
 ]
 
+
 class MarketingAnalyzer:
     """Main class for analyzing marketing content using Gemma Scope"""
 
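As an orientation to the data structure being reformatted here: MARKETING_FEATURES is a flat list of MarketingFeature records, so extending the analyzer is just appending another entry. A hedged sketch (the field names come from the hunks above; the feature_id, category, and wording below are invented placeholders, not real Gemma Scope features):

# Hypothetical example of registering one more Gemma Scope feature with the
# analyzer; only the field names are taken from app.py, the values are invented.
MARKETING_FEATURES.append(
    MarketingFeature(
        feature_id=1234,                      # placeholder SAE feature index
        name="benefit_language",              # placeholder
        category="persuasion",                # placeholder category
        description="Flags benefit-led phrasing",
        interpretation_guide="Low activation may mean the copy is feature-led",
        layer=20,
        threshold=0.1,
    )
)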
@@ -67,8 +71,7 @@ class MarketingAnalyzer:
 
         # Initialize model and tokenizer with token from environment
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            device_map='auto'
+            model_name, device_map="auto"
         )
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 

@@ -87,25 +90,30 @@ class MarketingAnalyzer:
                 # Load SAE parameters for each feature
                 path = hf_hub_download(
                     repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
-                    filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz"
+                    filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
                 )
                 params = np.load(path)
                 self.saes[feature.feature_id] = {
-                    'params': {k: torch.from_numpy(v).to(self.device) for k, v in params.items()},
-                    'feature': feature
+                    "params": {
+                        k: torch.from_numpy(v).to(self.device)
+                        for k, v in params.items()
+                    },
+                    "feature": feature,
                 }
                 logger.info(f"Loaded SAE for feature {feature.feature_id}")
             except Exception as e:
-                logger.error(f"Error loading SAE for feature {feature.feature_id}: {str(e)}")
+                logger.error(
+                    f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
+                )
                 continue
 
     def analyze_content(self, text: str) -> Dict:
         """Analyze marketing content using loaded SAEs"""
         results = {
-            'text': text,
-            'features': {},
-            'categories': {},
-            'recommendations': []
+            "text": text,
+            "features": {},
+            "categories": {},
+            "recommendations": [],
        }
 
         try:
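For readers unfamiliar with the checkpoint downloaded in the hunk above: each Gemma Scope params.npz bundles one SAE's weights as NumPy arrays, and _apply_sae further down expects at least W_enc, b_enc, and threshold among them. A minimal standalone sketch of the same download-and-convert step, assuming the 2b checkpoints and layer 20 (matching the features defined above); the repo and filename follow the pattern shown in the diff:

import numpy as np
import torch
from huggingface_hub import hf_hub_download

# Fetch one Gemma Scope residual-stream SAE (same path pattern as app.py);
# "2b" and layer 20 are illustrative choices here.
path = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
)

params = np.load(path)
print(list(params.keys()))  # expect W_enc, b_enc, threshold, plus decoder weights

# Convert to torch tensors, as the hunk above does (kept on CPU for simplicity).
sae_params = {k: torch.from_numpy(v) for k, v in params.items()}
print({k: tuple(v.shape) for k, v in sae_params.items()})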
@@ -116,14 +124,12 @@ class MarketingAnalyzer:
 
             # Analyze each feature
             for feature_id, sae_data in self.saes.items():
-                feature = sae_data['feature']
+                feature = sae_data["feature"]
                 layer_output = outputs.hidden_states[feature.layer]
 
                 # Apply SAE
                 activations = self._apply_sae(
-                    layer_output,
-                    sae_data['params'],
-                    feature.threshold
+                    layer_output, sae_data["params"], feature.threshold
                 )
 
                 # Skip BOS token and handle empty activations

@@ -137,25 +143,24 @@ class MarketingAnalyzer:
 
                 # Record results
                 feature_result = {
-                    'name': feature.name,
-                    'category': feature.category,
-                    'activation_score': mean_activation,
-                    'max_activation': max_activation,
-                    'interpretation': self._interpret_activation(
-                        mean_activation,
-                        feature
-                    )
+                    "name": feature.name,
+                    "category": feature.category,
+                    "activation_score": mean_activation,
+                    "max_activation": max_activation,
+                    "interpretation": self._interpret_activation(
+                        mean_activation, feature
+                    ),
                 }
 
-                results['features'][feature_id] = feature_result
+                results["features"][feature_id] = feature_result
 
                 # Aggregate by category
-                if feature.category not in results['categories']:
-                    results['categories'][feature.category] = []
-                results['categories'][feature.category].append(feature_result)
+                if feature.category not in results["categories"]:
+                    results["categories"][feature.category] = []
+                results["categories"][feature.category].append(feature_result)
 
             # Generate recommendations
-            results['recommendations'] = self._generate_recommendations(results)
+            results["recommendations"] = self._generate_recommendations(results)
 
         except Exception as e:
             logger.error(f"Error analyzing content: {str(e)}")

@@ -167,18 +172,16 @@ class MarketingAnalyzer:
         self,
         activations: torch.Tensor,
         sae_params: Dict[str, torch.Tensor],
-        threshold: float
+        threshold: float,
     ) -> torch.Tensor:
         """Apply SAE to get feature activations"""
-        pre_acts = activations @ sae_params['W_enc'] + sae_params['b_enc']
-        mask = pre_acts > sae_params['threshold']
+        pre_acts = activations @ sae_params["W_enc"] + sae_params["b_enc"]
+        mask = pre_acts > sae_params["threshold"]
        acts = mask * torch.nn.functional.relu(pre_acts)
         return acts
 
     def _interpret_activation(
-        self,
-        activation: float,
-        feature: MarketingFeature
+        self, activation: float, feature: MarketingFeature
     ) -> str:
         """Interpret activation patterns for a feature"""
         if activation > 0.8:
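The two rewritten lines in _apply_sae are the whole encoder: a linear map into the SAE feature basis followed by a JumpReLU-style gate, where a feature only fires if its pre-activation clears a learned per-feature threshold. Note that the function's own threshold argument appears unused in the lines shown; the stored sae_params["threshold"] is what gates. A self-contained sketch with random tensors standing in for the real weights and hidden states (the unused argument is dropped here):

import torch

def apply_sae(activations: torch.Tensor, sae_params: dict) -> torch.Tensor:
    """Mirror of _apply_sae above: encode residual-stream activations into SAE features."""
    pre_acts = activations @ sae_params["W_enc"] + sae_params["b_enc"]
    mask = pre_acts > sae_params["threshold"]        # JumpReLU gate, per-feature threshold
    return mask * torch.nn.functional.relu(pre_acts)

# Toy shapes standing in for Gemma hidden states (d_model) and a 16k-feature SAE.
d_model, n_features, seq_len = 2304, 16384, 8
fake_params = {
    "W_enc": torch.randn(d_model, n_features) * 0.02,
    "b_enc": torch.zeros(n_features),
    "threshold": torch.full((n_features,), 0.1),
}
hidden = torch.randn(1, seq_len, d_model)            # (batch, seq, d_model)

acts = apply_sae(hidden, fake_params)                # (batch, seq, n_features)
print(acts.shape, (acts > 0).float().mean().item())  # fraction of active features

With real Gemma Scope weights only a small fraction of the 16k features fire on a given token; with the random placeholders above the run is just a shape check.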
@@ -195,13 +198,12 @@ class MarketingAnalyzer:
         try:
             # Get technical features
             tech_features = [
-                f for f in results['features'].values()
-                if f['category'] == 'technical'
+                f for f in results["features"].values() if f["category"] == "technical"
             ]
 
             # Calculate average technical score if we have features
             if tech_features:
-                tech_score = np.mean([f['activation_score'] for f in tech_features])
+                tech_score = np.mean([f["activation_score"] for f in tech_features])
 
                 if tech_score > 0.8:
                     recommendations.append(

@@ -216,6 +218,7 @@ class MarketingAnalyzer:
 
         return recommendations
 
+
 def create_gradio_interface():
     """Create Gradio interface for marketing analysis"""
     try:

@@ -227,7 +230,7 @@ def create_gradio_interface():
             inputs=gr.Textbox(),
             outputs=gr.Textbox(),
             title="Marketing Content Analyzer (Error)",
-            description="Failed to initialize. Please check if HF_TOKEN is properly set."
+            description="Failed to initialize. Please check if HF_TOKEN is properly set.",
         )
 
     def analyze(text):

@@ -238,31 +241,29 @@ def create_gradio_interface():
 
             # Overall category scores
             output += "Category Scores:\n"
-            for category, features in results['categories'].items():
+            for category, features in results["categories"].items():
                 if features:  # Check if we have features for this category
-                    avg_score = np.mean([f['activation_score'] for f in features])
+                    avg_score = np.mean([f["activation_score"] for f in features])
                     output += f"{category.title()}: {avg_score:.2f}\n"
 
             # Feature details
             output += "\nFeature Details:\n"
-            for feature_id, feature in results['features'].items():
+            for feature_id, feature in results["features"].items():
                 output += f"\n{feature['name']}:\n"
                 output += f"Score: {feature['activation_score']:.2f}\n"
                 output += f"Interpretation: {feature['interpretation']}\n"
 
             # Recommendations
-            if results['recommendations']:
+            if results["recommendations"]:
                 output += "\nRecommendations:\n"
-                for rec in results['recommendations']:
+                for rec in results["recommendations"]:
                     output += f"- {rec}\n"
 
             return output
 
         # Create interface with custom theming
         custom_theme = gr.themes.Soft(
-            primary_hue="indigo",
-            secondary_hue="blue",
-            neutral_hue="gray"
+            primary_hue="indigo", secondary_hue="blue", neutral_hue="gray"
         )
 
         interface = gr.Interface(

@@ -270,7 +271,7 @@ def create_gradio_interface():
             inputs=gr.Textbox(
                 lines=5,
                 placeholder="Enter your marketing content here...",
-                label="Marketing Content"
+                label="Marketing Content",
             ),
             outputs=gr.Textbox(label="Analysis Results"),
             title="Marketing Content Analyzer",

@@ -278,14 +279,14 @@ def create_gradio_interface():
             examples=[
                 ["WordLift is an AI-powered SEO tool"],
                 ["Our advanced machine learning algorithms optimize your content"],
-                ["Simple and effective website optimization"]
+                ["Simple and effective website optimization"],
             ],
-            theme=custom_theme
-            )
+            theme=custom_theme,
         )
 
         return interface
 
+
 if __name__ == "__main__":
     iface = create_gradio_interface()
-    iface.launch()
+    iface.launch()
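Taken together, the reformatted hunks fix the shape of the analysis payload that the Gradio callback renders. A hedged sketch of what MarketingAnalyzer.analyze_content returns (the key names are taken from the diff; every concrete value below is a placeholder for illustration):

# Illustrative only: keys mirror app.py, values are invented placeholders.
results = {
    "text": "WordLift is an AI-powered SEO tool",
    "features": {
        6680: {                                    # keyed by SAE feature_id
            "name": "complex technical concepts",  # placeholder name
            "category": "technical",
            "activation_score": 0.42,              # mean activation (placeholder)
            "max_activation": 0.91,                # placeholder
            "interpretation": "moderate technical focus",  # placeholder
        },
    },
    "categories": {
        "technical": [
            # the same feature_result dicts, grouped per category
        ],
    },
    "recommendations": [
        "Consider simplifying the language for a broader audience",  # placeholder
    ],
}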