Commit
·
07c4d0f
1
Parent(s):
88507e8
updated to use mystery + predicted
Browse files — utils/interp_space_utils.py (+133 lines, −3 lines)
utils/interp_space_utils.py
CHANGED
|
@@ -31,6 +31,14 @@ os.makedirs(os.path.dirname(REGION_CACHE), exist_ok=True)
|
|
| 31 |
# Bump this whenever the cached data format changes, so stale cache entries are invalidated
|
| 32 |
CACHE_VERSION = 1
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
class style_analysis_schema(BaseModel):
    """Structured schema for a style analysis result.

    NOTE(review): field semantics inferred from names and usage context —
    confirm against the prompt/caller that populates this model.
    """

    # Names of the style features identified.
    features: list[str]
    # Nested mapping: presumably feature name -> {document/author key ->
    # list of text spans exhibiting that feature} — verify against caller.
    spans: dict[str, dict[str, list[str]]]
|
|
@@ -59,8 +67,8 @@ def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd
|
|
| 59 |
print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
|
| 60 |
|
| 61 |
# Gather the input texts (preserves list-of-strings if any)
|
| 62 |
-
#
|
| 63 |
-
author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
|
| 64 |
|
| 65 |
print(f"Number of author_texts: {len(author_texts)}")
|
| 66 |
|
|
@@ -686,7 +694,11 @@ def compute_clusters_g2v_representation(
|
|
| 686 |
|
| 687 |
# Keep only features that have a positive contrastive score
|
| 688 |
top_g2v_feats = sorted(
|
| 689 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
key=lambda x: -x[1] # Sort by contrastive score
|
| 691 |
)
|
| 692 |
|
|
@@ -776,6 +788,124 @@ def compute_clusters_g2v_representation(
|
|
| 776 |
|
| 777 |
return filtered_features[:top_n] # Return tuples with z-scores
|
| 778 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
|
| 780 |
|
| 781 |
styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
|
|
|
|
| 31 |
# Bump this whenever there is a change etc...
|
| 32 |
CACHE_VERSION = 1
|
| 33 |
|
| 34 |
+
# Features to exclude from Gram2Vec outputs.
# Any feature whose name starts with one of these prefixes is dropped
# (e.g. raw token counts, which track document length rather than style).
EXCLUDED_G2V_FEATURE_PREFIXES = [
    'num_tokens'
]

# Exact feature names to exclude. Use a set literal (not set([...])) for the
# idiomatic, allocation-free construction.
# NOTE(review): 'num_tokens:num_tokens' is already covered by the
# 'num_tokens' prefix above; kept for explicitness.
EXCLUDED_G2V_FEATURES = {
    'num_tokens:num_tokens',
}
|
| 41 |
+
|
| 42 |
class style_analysis_schema(BaseModel):
|
| 43 |
features: list[str]
|
| 44 |
spans: dict[str, dict[str, list[str]]]
|
|
|
|
| 67 |
print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
|
| 68 |
|
| 69 |
# Gather the input texts (preserves list-of-strings if any)
|
| 70 |
+
# If an entry is a list of strings, join; otherwise use the string as-is
|
| 71 |
+
author_texts = [('\n\n'.join(x) if isinstance(x, list) else x) for x in clustered_authors_df.fullText.tolist()]
|
| 72 |
|
| 73 |
print(f"Number of author_texts: {len(author_texts)}")
|
| 74 |
|
|
|
|
| 694 |
|
| 695 |
# Keep only features that have a positive contrastive score
|
| 696 |
top_g2v_feats = sorted(
|
| 697 |
+
[
|
| 698 |
+
(feat, val, z_score)
|
| 699 |
+
for feat, val, z_score in zip(all_g2v_feats, final_g2v_feats_values, z_scores)
|
| 700 |
+
if val > 0 and feat not in EXCLUDED_G2V_FEATURES and not any(feat.startswith(p) for p in EXCLUDED_G2V_FEATURE_PREFIXES)
|
| 701 |
+
],
|
| 702 |
key=lambda x: -x[1] # Sort by contrastive score
|
| 703 |
)
|
| 704 |
|
|
|
|
| 788 |
|
| 789 |
return filtered_features[:top_n] # Return tuples with z-scores
|
| 790 |
|
| 791 |
+
def compute_task_only_g2v_similarity(
    background_corpus_df: pd.DataFrame,
    visible_author_ids: List[Any],
    features_clm_name: str = 'g2v_vector',
    top_n: int = 10,
    require_spans: bool = True
) -> List[tuple]:
    """
    Compute top Gram2Vec features that are shared between the Mystery author and the
    predicted Candidate author, ignoring background authors and contrast.

    Selection is limited to task authors within the zoom (i.e., present in
    `visible_author_ids`). A feature is kept if:
    - it has a positive value (> 0) for both Mystery and Predicted Candidate,
    - and (optionally) at least one detected span exists in both authors' texts.

    Scoring strategy prioritizes features strong in both authors:
    score = |m| + |p| - |m - p|, which equals 2 * min(m, p) for positive values.

    Parameters
    ----------
    background_corpus_df : pd.DataFrame
        Must carry 'authorID' and a `features_clm_name` column holding a
        dict of feature -> value per author; may also carry a boolean
        'predicted' column and 'fullText' (str or list of str).
    visible_author_ids : List[Any]
        Author IDs currently visible in the zoomed view.
    features_clm_name : str
        Column name of the per-author Gram2Vec feature dict.
    top_n : int
        Maximum number of features to return.
    require_spans : bool
        When True and the span locator is importable, keep only features with
        at least one detected span in BOTH authors' texts.

    Returns
    -------
    List of (feature_name, score) tuples sorted by score desc, limited to top_n.
    Empty list when either the Mystery or the predicted author cannot be resolved.
    """
    task_names = {'Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3'}

    # Filter to task authors inside the current zoom.
    is_visible = background_corpus_df['authorID'].isin(visible_author_ids)
    is_task = background_corpus_df['authorID'].isin(task_names)
    visible_task_df = background_corpus_df[is_visible & is_task]

    if visible_task_df.empty:
        return []

    # Identify the Mystery author row, preferring the visible set but falling
    # back to anywhere in the corpus if Mystery is not currently visible.
    mystery_rows = visible_task_df[visible_task_df['authorID'] == 'Mystery author']
    if mystery_rows.empty:
        mystery_rows = background_corpus_df[background_corpus_df['authorID'] == 'Mystery author']
        if mystery_rows.empty:
            return []
    mystery_row = mystery_rows.iloc[0]

    # Identify the predicted candidate via the 'predicted' flag, preferring the
    # visible set, then anywhere in the corpus (restricted to task authors).
    predicted_row = None
    if 'predicted' in visible_task_df.columns:
        pred_candidates = visible_task_df[visible_task_df['predicted'] == True]
        if not pred_candidates.empty:
            predicted_row = pred_candidates.iloc[0]

    if predicted_row is None and 'predicted' in background_corpus_df.columns:
        pred_any = background_corpus_df[background_corpus_df['predicted'] == True]
        # Prefer one that is also a task author.
        if not pred_any.empty:
            pred_any = pred_any[pred_any['authorID'].isin(task_names)]
        if not pred_any.empty:
            predicted_row = pred_any.iloc[0]

    # If still not found, we cannot build a (mystery, predicted) pair.
    if predicted_row is None:
        return []

    mystery_vec = mystery_row.get(features_clm_name, {})
    predicted_vec = predicted_row.get(features_clm_name, {})

    if not isinstance(mystery_vec, dict) or not isinstance(predicted_vec, dict):
        return []

    # Prepare texts for optional span gating.
    def _norm_txt(x):
        # Entries may be a list of document strings; join them for span search.
        if isinstance(x, list):
            return '\n\n'.join(x)
        return str(x)

    mystery_text = _norm_txt(mystery_row.get('fullText', ''))
    predicted_text = _norm_txt(predicted_row.get('fullText', ''))

    # The span locator is optional; degrade gracefully when unavailable.
    try:
        from gram2vec.feature_locator import find_feature_spans as _find_feature_spans
    except Exception:
        _find_feature_spans = None

    shared_features = []
    # Iterate over the union of feature keys (both authors share the same
    # feature space in practice).
    for feature_name in set(mystery_vec) | set(predicted_vec):
        # Exclude globally unwanted features.
        if feature_name in EXCLUDED_G2V_FEATURES or any(feature_name.startswith(p) for p in EXCLUDED_G2V_FEATURE_PREFIXES):
            continue
        m_val = float(mystery_vec.get(feature_name, 0.0))
        p_val = float(predicted_vec.get(feature_name, 0.0))

        # BUGFIX: enforce the documented contract — the feature must have a
        # positive value for BOTH authors. Previously zero/negative features
        # could leak into the ranking.
        if m_val <= 0 or p_val <= 0:
            continue

        # Optional span gate: require at least one span in both texts.
        spans_m = spans_p = None
        if require_spans and _find_feature_spans is not None:
            try:
                spans_m = _find_feature_spans(mystery_text, feature_name) or []
                spans_p = _find_feature_spans(predicted_text, feature_name) or []
                if len(spans_m) == 0 or len(spans_p) == 0:
                    continue
            except Exception:
                # On span-locator errors, skip gating and proceed.
                spans_m = spans_m if spans_m is not None else []
                spans_p = spans_p if spans_p is not None else []

        # Similarity metric: |m| + |p| - |m - p| (== 2*min(m, p) when both > 0).
        score = abs(m_val) + abs(p_val) - abs(m_val - p_val)
        shared_features.append((
            feature_name, score, m_val, p_val,
            len(spans_m) if spans_m is not None else -1,
            len(spans_p) if spans_p is not None else -1,
        ))

    # Rank by score desc and return top_n.
    shared_features.sort(key=lambda x: x[1], reverse=True)
    top = shared_features[:top_n]

    # Debug print of top-N with values and span counts for presence sanity-check.
    try:
        print("[DEBUG] Task-only G2V top features (feature, mystery_val, predicted_val, score | spans_mystery, spans_predicted):")
        for feat_name, sc, m_val, p_val, c_m, c_p in top:
            print(f"  {feat_name} | mystery={m_val:.4f}, predicted={p_val:.4f}, S={sc:.4f} | spans=({c_m}, {c_p})")
    except Exception:
        pass

    return [(f, s) for (f, s, _, _, _, _) in top]
|
| 908 |
+
|
| 909 |
def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
|
| 910 |
|
| 911 |
styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
|