Anisha Bhatnagar
committed on
Commit
·
bd7d9f9
1
Parent(s):
8db24a7
fixed caching issues in LLM feature identification
Browse files- utils/interp_space_utils.py +43 -3
- utils/llm_feat_utils.py +1 -1
utils/interp_space_utils.py
CHANGED
@@ -22,7 +22,9 @@ import numpy as np
|
|
22 |
from sklearn.metrics.pairwise import cosine_similarity
|
23 |
|
24 |
CACHE_DIR = "datasets/embeddings_cache"
|
|
|
25 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
26 |
# Bump this whenever there is a change etc...
|
27 |
CACHE_VERSION = 1
|
28 |
|
@@ -418,7 +420,34 @@ def compute_clusters_style_representation_2(
|
|
418 |
|
419 |
return parsed_response
|
420 |
|
421 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
423 |
prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
|
424 |
Author Texts:
|
@@ -442,7 +471,18 @@ def identify_style_features(author_texts: list[str], max_num_feats: int = 5) ->
|
|
442 |
)
|
443 |
return json.loads(response.choices[0].message.content)
|
444 |
|
445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
|
447 |
def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
|
448 |
for attempt in range(max_attempts):
|
@@ -494,7 +534,7 @@ def compute_clusters_style_representation_3(
|
|
494 |
author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
|
495 |
print(f"Number of authors: {len(background_corpus_df_feat_id)}")
|
496 |
print(author_names)
|
497 |
-
features = identify_style_features(author_texts, max_num_feats=max_num_feats)
|
498 |
|
499 |
# STEP 2: Prepare author pool for span extraction
|
500 |
span_df = background_corpus_df.iloc[:4]
|
|
|
from sklearn.metrics.pairwise import cosine_similarity

# Directory holding cached text embeddings (created eagerly at import time).
CACHE_DIR = "datasets/embeddings_cache"
# JSON file caching LLM-identified style features, keyed by a hash of
# (sorted author names, max_num_feats) — see generate_cache_key.
ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
# Bump this whenever the cached data format changes, to invalidate stale entries.
CACHE_VERSION = 1
|
|
420 |
|
421 |
return parsed_response
|
422 |
|
423 |
+
def generate_cache_key(author_names: list[str], max_num_feats: int) -> str:
    """Build a deterministic cache key for a feature-identification request.

    Args:
        author_names: Names of the authors whose texts are being analyzed.
        max_num_feats: Maximum number of style features requested.

    Returns:
        A hex digest that is stable across runs and across author ordering.
    """
    # Sort author names so the same set of authors always maps to the same
    # key, regardless of the order the caller passes them in.
    sorted_authors = sorted(author_names)
    key_data = {
        "authors": sorted_authors,
        "max_num_feats": max_num_feats,
    }
    # sort_keys=True makes the JSON serialization itself deterministic.
    key_string = json.dumps(key_data, sort_keys=True)
    # md5 is acceptable here: the digest is a cache identifier, not a
    # security token.
    return hashlib.md5(key_string.encode()).hexdigest()
|
433 |
+
|
434 |
+
def identify_style_features(author_texts: list[str], author_names: list[str], max_num_feats: int = 5) -> list[str]:
|
435 |
+
cache_key = None
|
436 |
+
if author_names:
|
437 |
+
cache_key = generate_cache_key(author_names, max_num_feats)
|
438 |
+
|
439 |
+
if os.path.exists(ZOOM_CACHE):
|
440 |
+
with open(ZOOM_CACHE, 'r') as f:
|
441 |
+
cache = json.load(f)
|
442 |
+
else:
|
443 |
+
cache = {}
|
444 |
+
|
445 |
+
if cache_key in cache:
|
446 |
+
print(f"\nCache hit! Using cached features for authors: {author_names}")
|
447 |
+
return cache[cache_key]["features"]
|
448 |
+
else:
|
449 |
+
print(f"Cache miss. Computing features for authors: {author_names}")
|
450 |
+
|
451 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
452 |
prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
|
453 |
Author Texts:
|
|
|
471 |
)
|
472 |
return json.loads(response.choices[0].message.content)
|
473 |
|
474 |
+
features = retry_call(_make_call, FeatureIdentificationSchema).features
|
475 |
+
|
476 |
+
print(f"Adding to zoom cache")
|
477 |
+
if cache_key and author_names:
|
478 |
+
cache[cache_key] = {
|
479 |
+
"features": features
|
480 |
+
}
|
481 |
+
# save_cache(cache)
|
482 |
+
with open(ZOOM_CACHE, 'w') as f:
|
483 |
+
json.dump(cache, f, indent=2)
|
484 |
+
|
485 |
+
print(f"Cached features for authors: {author_names}")
|
486 |
|
487 |
def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
|
488 |
for attempt in range(max_attempts):
|
|
|
534 |
author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
|
535 |
print(f"Number of authors: {len(background_corpus_df_feat_id)}")
|
536 |
print(author_names)
|
537 |
+
features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)
|
538 |
|
539 |
# STEP 2: Prepare author pool for span extraction
|
540 |
span_df = background_corpus_df.iloc[:4]
|
utils/llm_feat_utils.py
CHANGED
@@ -125,7 +125,7 @@ def generate_feature_spans_cached(client, text: str, features: list[str], role:
|
|
125 |
result[feat] = spans
|
126 |
|
127 |
# 5) write back the combined cache
|
128 |
-
with open(cache_path, "
|
129 |
json.dump(cache, f, indent=2)
|
130 |
return result
|
131 |
|
|
|
125 |
result[feat] = spans
|
126 |
|
127 |
# 5) write back the combined cache
|
128 |
+
with open(cache_path, "w") as f:
|
129 |
json.dump(cache, f, indent=2)
|
130 |
return result
|
131 |
|