peter-zeng commited on
Commit
ca61898
Β·
1 Parent(s): 07c4d0f

modified to use mystery + predicted

Browse files
Files changed (1) hide show
  1. utils/visualizations.py +57 -26
utils/visualizations.py CHANGED
@@ -10,7 +10,7 @@ import plotly.graph_objects as go
10
  from plotly.colors import sample_colorscale
11
  from gradio import update
12
  import re
13
- from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation, compute_precomputed_regions
14
  from utils.llm_feat_utils import split_features
15
  from utils.gram2vec_feat_utils import get_shorthand, get_fullform
16
  from gram2vec.feature_locator import find_feature_spans
@@ -204,11 +204,12 @@ def load_interp_space(cfg):
204
  # Function to process G2V features and create display choices
205
  def format_g2v_features_for_display(g2v_features_with_scores):
206
  """
207
- Convert G2V features with z-scores into display format for Gradio radio buttons.
 
208
 
209
  Args:
210
  g2v_features_with_scores: List of tuples like:
211
- [('None', None), ('Feature Name', z_score), ...]
212
 
213
  Returns:
214
  tuple: (display_choices, original_values)
@@ -218,21 +219,21 @@ def format_g2v_features_for_display(g2v_features_with_scores):
218
 
219
  for item in g2v_features_with_scores:
220
  if len(item) == 2:
221
- feature_name, z_score = item
222
 
223
  # Handle None case
224
- if feature_name == "None" or z_score is None:
225
  display_choices.append("None")
226
  original_values.append("None")
227
  else:
228
  # Convert numpy float to regular float if needed
229
- if hasattr(z_score, 'item'):
230
- z_score = float(z_score.item())
231
  else:
232
- z_score = float(z_score)
233
 
234
- # Create display string with z-score
235
- display_string = f"{feature_name} | [Z={z_score:.2f}]"
236
  display_choices.append(display_string)
237
  original_values.append(feature_name)
238
  else:
@@ -275,17 +276,18 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
275
  print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
276
  merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
277
  print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
278
- #style_analysis_response = {'features': [], 'spans': []}
279
- style_analysis_response = compute_clusters_style_representation_3(
280
- background_corpus_df=merged_authors_df,
281
- cluster_ids=visible_authors,
282
- cluster_label_clm_name='authorID',
283
- )
284
 
285
  llm_feats = ['None'] + style_analysis_response['features']
286
 
287
 
288
  merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
 
289
  g2v_feats = compute_clusters_g2v_representation(
290
  background_corpus_df=merged_authors_df,
291
  author_ids=visible_authors,
@@ -293,6 +295,34 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
293
  features_clm_name='g2v_vector'
294
  )
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  # ── Span-existence filter on task authors in the zoom ───────────────────
297
  # Keep only features that have at least one detected span in any of the
298
  # visible task authors' texts
@@ -305,16 +335,17 @@ def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors
305
 
306
  task_texts = [_to_text(x) for x in visible_task_authors['fullText'].tolist()]
307
 
308
- filtered_g2v_feats = []
309
- for feat in g2v_feats:
310
- try:
311
- # `feat` is shorthand already (e.g., 'pos_bigrams:NOUN PROPN')
312
- if any(find_feature_spans(txt, feat[0]) for txt in task_texts):
313
- filtered_g2v_feats.append(feat)
314
- else:
315
- print(f"[INFO] Dropping G2V feature with no spans in task texts: {feat}")
316
- except Exception as e:
317
- print(f"[WARN] Error while checking spans for {feat}: {e}")
 
318
 
319
  # Convert to human readable for display
320
  HR_g2v_list = []
 
10
  from plotly.colors import sample_colorscale
11
  from gradio import update
12
  import re
13
+ from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation, compute_precomputed_regions, compute_task_only_g2v_similarity
14
  from utils.llm_feat_utils import split_features
15
  from utils.gram2vec_feat_utils import get_shorthand, get_fullform
16
  from gram2vec.feature_locator import find_feature_spans
 
204
  # Function to process G2V features and create display choices
205
  def format_g2v_features_for_display(g2v_features_with_scores):
206
  """
207
+ Convert G2V features with a numeric score into display format for Gradio radio buttons.
208
+ The label uses S= for a generic similarity score (not Z).
209
 
210
  Args:
211
  g2v_features_with_scores: List of tuples like:
212
+ [('None', None), ('Feature Name', score), ...]
213
 
214
  Returns:
215
  tuple: (display_choices, original_values)
 
219
 
220
  for item in g2v_features_with_scores:
221
  if len(item) == 2:
222
+ feature_name, score = item
223
 
224
  # Handle None case
225
+ if feature_name == "None" or score is None:
226
  display_choices.append("None")
227
  original_values.append("None")
228
  else:
229
  # Convert numpy float to regular float if needed
230
+ if hasattr(score, 'item'):
231
+ score = float(score.item())
232
  else:
233
+ score = float(score)
234
 
235
+ # Create display string with similarity score
236
+ display_string = f"{feature_name}"
237
  display_choices.append(display_string)
238
  original_values.append(feature_name)
239
  else:
 
276
  print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
277
  merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
278
  print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
279
+ style_analysis_response = {'features': [], 'spans': []}
280
+ # style_analysis_response = compute_clusters_style_representation_3(
281
+ # background_corpus_df=merged_authors_df,
282
+ # cluster_ids=visible_authors,
283
+ # cluster_label_clm_name='authorID',
284
+ # )
285
 
286
  llm_feats = ['None'] + style_analysis_response['features']
287
 
288
 
289
  merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
290
+ # Default: contrastive Gram2Vec features
291
  g2v_feats = compute_clusters_g2v_representation(
292
  background_corpus_df=merged_authors_df,
293
  author_ids=visible_authors,
 
295
  features_clm_name='g2v_vector'
296
  )
297
 
298
+ # If both Mystery and the predicted candidate are inside the zoom, switch to task-only similarity
299
+ task_author_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
300
+ visible_task_names = set(visible_authors).intersection(task_author_names)
301
+ predicted_in_visible = None
302
+ if 'predicted' in task_authors_df.columns:
303
+ preds = task_authors_df[task_authors_df['predicted'] == True]['authorID'].tolist()
304
+ if preds:
305
+ predicted_in_visible = preds[0] if preds[0] in visible_task_names else None
306
+
307
+ use_task_only = ('Mystery author' in visible_task_names) and (predicted_in_visible is not None)
308
+ if use_task_only:
309
+ print("[INFO] Using task-only Gram2Vec similarity (Mystery + Predicted candidate) within zoom")
310
+ try:
311
+ g2v_feats = compute_task_only_g2v_similarity(
312
+ background_corpus_df=merged_authors_df,
313
+ visible_author_ids=visible_authors,
314
+ features_clm_name='g2v_vector',
315
+ top_n=10,
316
+ require_spans=True
317
+ )
318
+ # g2v_feats already enforces spans for both authors; treat as final
319
+ filtered_g2v_feats = g2v_feats
320
+ except Exception as e:
321
+ print(f"[WARN] Task-only similarity failed, falling back to contrastive: {e}")
322
+ filtered_g2v_feats = None
323
+ else:
324
+ filtered_g2v_feats = None
325
+
326
  # ── Span-existence filter on task authors in the zoom ───────────────────
327
  # Keep only features that have at least one detected span in any of the
328
  # visible task authors' texts
 
335
 
336
  task_texts = [_to_text(x) for x in visible_task_authors['fullText'].tolist()]
337
 
338
+ if filtered_g2v_feats is None:
339
+ filtered_g2v_feats = []
340
+ for feat in g2v_feats:
341
+ try:
342
+ # `feat` is shorthand already (e.g., 'pos_bigrams:NOUN PROPN')
343
+ if any(find_feature_spans(txt, feat[0]) for txt in task_texts):
344
+ filtered_g2v_feats.append(feat)
345
+ else:
346
+ print(f"[INFO] Dropping G2V feature with no spans in task texts: {feat}")
347
+ except Exception as e:
348
+ print(f"[WARN] Error while checking spans for {feat}: {e}")
349
 
350
  # Convert to human readable for display
351
  HR_g2v_list = []