peter-zeng committed
Commit 07c4d0f · 1 Parent(s): 88507e8

updated to use mystery + predicted

Files changed (1):
  1. utils/interp_space_utils.py  +133 -3
utils/interp_space_utils.py CHANGED

@@ -31,6 +31,14 @@ os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
 # Bump this whenever there is a change etc...
 CACHE_VERSION = 1
 
+# Features to exclude from Gram2Vec outputs
+EXCLUDED_G2V_FEATURE_PREFIXES = [
+    'num_tokens'
+]
+EXCLUDED_G2V_FEATURES = set([
+    'num_tokens:num_tokens'
+])
+
 class style_analysis_schema(BaseModel):
     features: list[str]
     spans: dict[str, dict[str, list[str]]]
@@ -59,8 +67,8 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
     print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
 
     # Gather the input texts (preserves list-of-strings if any)
-    #texts = background_corpus_df[text_clm].fillna("").tolist()
-    author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
+    # If an entry is a list of strings, join; otherwise use the string as-is
+    author_texts = [('\n\n'.join(x) if isinstance(x, list) else x) for x in clustered_authors_df.fullText.tolist()]
 
     print(f"Number of author_texts: {len(author_texts)}")
 
@@ -686,7 +694,11 @@ def compute_clusters_g2v_representation(
 
     # Keep only features that have a positive contrastive score
     top_g2v_feats = sorted(
-        [(feat, val, z_score) for feat, val, z_score in zip(all_g2v_feats, final_g2v_feats_values, z_scores) if val > 0],
+        [
+            (feat, val, z_score)
+            for feat, val, z_score in zip(all_g2v_feats, final_g2v_feats_values, z_scores)
+            if val > 0 and feat not in EXCLUDED_G2V_FEATURES and not any(feat.startswith(p) for p in EXCLUDED_G2V_FEATURE_PREFIXES)
+        ],
         key=lambda x: -x[1]  # Sort by contrastive score
     )
 
@@ -776,6 +788,124 @@ def compute_clusters_g2v_representation(
 
     return filtered_features[:top_n]  # Return tuples with z-scores
 
+def compute_task_only_g2v_similarity(
+    background_corpus_df: pd.DataFrame,
+    visible_author_ids: List[Any],
+    features_clm_name: str = 'g2v_vector',
+    top_n: int = 10,
+    require_spans: bool = True
+) -> List[tuple]:
+    """
+    Compute top Gram2Vec features that are shared between the Mystery author and the
+    predicted Candidate author, ignoring background authors and contrast.
+
+    Selection is limited to task authors within the zoom (i.e., present in
+    `visible_author_ids`). A feature is kept if:
+      - it has a positive value (> 0) for both Mystery and Predicted Candidate,
+      - and (optionally) at least one detected span exists in both authors' texts.
+
+    Scoring strategy prioritizes features strong in both authors: score = min(mystery_value, predicted_value).
+
+    Returns a list of (feature_name, score) tuples sorted by score desc, limited to top_n.
+    """
+    task_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}
+
+    # Filter to visible task authors
+    is_visible = background_corpus_df['authorID'].isin(visible_author_ids)
+    is_task = background_corpus_df['authorID'].isin(task_names)
+    visible_task_df = background_corpus_df[is_visible & is_task]
+
+    if visible_task_df.empty:
+        return []
+
+    # Identify Mystery author row within the visible set
+    mystery_rows = visible_task_df[visible_task_df['authorID'] == 'Mystery author']
+    if mystery_rows.empty:
+        # If Mystery is not visible, fall back to using any available Mystery row in the corpus
+        mystery_rows = background_corpus_df[background_corpus_df['authorID'] == 'Mystery author']
+        if mystery_rows.empty:
+            return []
+
+    mystery_row = mystery_rows.iloc[0]
+
+    # Identify the predicted candidate within the visible set using the 'predicted' flag if present
+    predicted_row = None
+    if 'predicted' in visible_task_df.columns:
+        pred_candidates = visible_task_df[visible_task_df['predicted'] == True]
+        if not pred_candidates.empty:
+            predicted_row = pred_candidates.iloc[0]
+
+    # If not found in visible, try to find anywhere in the corpus
+    if predicted_row is None and 'predicted' in background_corpus_df.columns:
+        pred_any = background_corpus_df[background_corpus_df['predicted'] == True]
+        # Prefer one that is also a task author
+        pred_any = pred_any[pred_any['authorID'].isin(task_names)] if not pred_any.empty else pred_any
+        if not pred_any.empty:
+            predicted_row = pred_any.iloc[0]
+
+    # If still not found, we cannot build a pair
+    if predicted_row is None:
+        return []
+
+    mystery_vec = mystery_row.get(features_clm_name, {})
+    predicted_vec = predicted_row.get(features_clm_name, {})
+
+    if not isinstance(mystery_vec, dict) or not isinstance(predicted_vec, dict):
+        return []
+
+    # Prepare texts for optional span gating
+    def _norm_txt(x):
+        if isinstance(x, list):
+            return '\n\n'.join(x)
+        return str(x)
+    mystery_text = _norm_txt(mystery_row.get('fullText', ''))
+    predicted_text = _norm_txt(predicted_row.get('fullText', ''))
+
+    try:
+        from gram2vec.feature_locator import find_feature_spans as _find_feature_spans
+    except Exception:
+        _find_feature_spans = None
+
+    shared_features = []
+    # Iterate over union of feature keys (both authors share the same feature space in practice)
+    for feature_name in set(list(mystery_vec.keys()) + list(predicted_vec.keys())):
+        # Exclude unwanted features
+        if feature_name in EXCLUDED_G2V_FEATURES or any(feature_name.startswith(p) for p in EXCLUDED_G2V_FEATURE_PREFIXES):
+            continue
+        m_val = float(mystery_vec.get(feature_name, 0.0))
+        p_val = float(predicted_vec.get(feature_name, 0.0))
+
+        # Optional span gate: require at least one span in both texts
+        spans_m = spans_p = None
+        if require_spans and _find_feature_spans is not None:
+            try:
+                spans_m = _find_feature_spans(mystery_text, feature_name) or []
+                spans_p = _find_feature_spans(predicted_text, feature_name) or []
+                if len(spans_m) == 0 or len(spans_p) == 0:
+                    continue
+            except Exception:
+                # On span errors, skip gating and proceed
+                spans_m = spans_m if spans_m is not None else []
+                spans_p = spans_p if spans_p is not None else []
+
+        # Similarity metric: |m| + |p| - |m - p|
+        score = abs(m_val) + abs(p_val) - abs(m_val - p_val)
+        shared_features.append((feature_name, score, m_val, p_val, len(spans_m) if spans_m is not None else -1, len(spans_p) if spans_p is not None else -1))
+
+    # Rank by score desc and return top_n
+    shared_features.sort(key=lambda x: x[1], reverse=True)
+    top = shared_features[:top_n]
+
+    # Debug print of top-N with values and span counts for presence sanity-check
+    try:
+        print("[DEBUG] Task-only G2V top features (feature, mystery_val, predicted_val, score | spans_mystery, spans_predicted):")
+        for feat_name, sc, m_val, p_val, c_m, c_p in top:
+            print(f"    {feat_name} | mystery={m_val:.4f}, predicted={p_val:.4f}, S={sc:.4f} | spans=({c_m}, {c_p})")
+    except Exception:
+        pass
+
+    return [(f, s) for (f, s, _, _, _, _) in top]
+
 def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
 
     styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
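
A minimal usage sketch of the new compute_task_only_g2v_similarity helper follows. The import path, DataFrame rows, and feature names are illustrative assumptions, not part of this commit; require_spans=False is passed so the example runs without gram2vec's span locator. Note that the metric |m| + |p| - |m - p| equals 2·min(m, p) for non-negative values, so it ranks features the same way as the min-based description in the docstring.

# Usage sketch (hypothetical data; import path assumes the repo layout above).
import pandas as pd
from utils.interp_space_utils import compute_task_only_g2v_similarity

# Each row carries an authorID, a dict-valued g2v_vector, the raw text, and a 'predicted' flag.
corpus_df = pd.DataFrame([
    {'authorID': 'Mystery author',     'predicted': False, 'fullText': 'A quick brown fox...',
     'g2v_vector': {'pos_unigrams:ADJ': 0.8, 'func_words:the': 0.5}},
    {'authorID': 'Candidate Author 1', 'predicted': True,  'fullText': 'Another quick sample...',
     'g2v_vector': {'pos_unigrams:ADJ': 0.6, 'func_words:the': 0.1}},
    {'authorID': 'Candidate Author 2', 'predicted': False, 'fullText': 'A very different style...',
     'g2v_vector': {'pos_unigrams:ADJ': 0.0, 'func_words:the': 0.9}},
])

# require_spans=False skips the span gate so the sketch runs standalone.
top_feats = compute_task_only_g2v_similarity(
    corpus_df,
    visible_author_ids=['Mystery author', 'Candidate Author 1', 'Candidate Author 2'],
    top_n=5,
    require_spans=False,
)
print(top_feats)  # e.g. [('pos_unigrams:ADJ', ~1.2), ('func_words:the', ~0.2)]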