Commit bd7d9f9 · 1 Parent(s): 8db24a7
Anisha Bhatnagar committed

fixed caching issues in LLM feature identification

utils/interp_space_utils.py CHANGED
@@ -22,7 +22,9 @@ import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity

 CACHE_DIR = "datasets/embeddings_cache"
+ZOOM_CACHE = "datasets/zoom_cache/features_cache.json"
 os.makedirs(CACHE_DIR, exist_ok=True)
+os.makedirs(os.path.dirname(ZOOM_CACHE), exist_ok=True)
 # Bump this whenever there is a change etc...
 CACHE_VERSION = 1

@@ -418,7 +420,34 @@ def compute_clusters_style_representation_2(

     return parsed_response

-def identify_style_features(author_texts: list[str], max_num_feats: int = 5) -> list[str]:
+def generate_cache_key(author_names: List[str], max_num_feats: int) -> str:
+    """Generate a unique cache key based on author names and max features"""
+    # Sort author names to ensure consistent key regardless of order
+    sorted_authors = sorted(author_names)
+    key_data = {
+        "authors": sorted_authors,
+        "max_num_feats": max_num_feats
+    }
+    key_string = json.dumps(key_data, sort_keys=True)
+    return hashlib.md5(key_string.encode()).hexdigest()
+
+def identify_style_features(author_texts: list[str], author_names: list[str], max_num_feats: int = 5) -> list[str]:
+    cache_key = None
+    if author_names:
+        cache_key = generate_cache_key(author_names, max_num_feats)
+
+    if os.path.exists(ZOOM_CACHE):
+        with open(ZOOM_CACHE, 'r') as f:
+            cache = json.load(f)
+    else:
+        cache = {}
+
+    if cache_key in cache:
+        print(f"\nCache hit! Using cached features for authors: {author_names}")
+        return cache[cache_key]["features"]
+    else:
+        print(f"Cache miss. Computing features for authors: {author_names}")
+
     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
     prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
     Author Texts:
@@ -442,7 +471,18 @@ def identify_style_features(author_texts: list[str], max_num_feats: int = 5) ->
         )
         return json.loads(response.choices[0].message.content)

-    return retry_call(_make_call, FeatureIdentificationSchema).features
+    features = retry_call(_make_call, FeatureIdentificationSchema).features
+
+    print(f"Adding to zoom cache")
+    if cache_key and author_names:
+        cache[cache_key] = {
+            "features": features
+        }
+        # save_cache(cache)
+        with open(ZOOM_CACHE, 'w') as f:
+            json.dump(cache, f, indent=2)
+
+        print(f"Cached features for authors: {author_names}")

 def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
     for attempt in range(max_attempts):
@@ -494,7 +534,7 @@ def compute_clusters_style_representation_3(
     author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
     print(f"Number of authors: {len(background_corpus_df_feat_id)}")
     print(author_names)
-    features = identify_style_features(author_texts, max_num_feats=max_num_feats)
+    features = identify_style_features(author_texts, author_names, max_num_feats=max_num_feats)

     # STEP 2: Prepare author pool for span extraction
     span_df = background_corpus_df.iloc[:4]
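
For orientation, here is a condensed, standalone sketch of the caching pattern these hunks introduce: build a deterministic MD5 key from the sorted author names plus max_num_feats, load a JSON cache file before calling the LLM, and rewrite the whole file after a miss. make_cache_key, cached_features, CACHE_PATH, and the compute_fn stub are illustrative names for this sketch, not identifiers from the repository.

    import hashlib
    import json
    import os

    CACHE_PATH = "features_cache.json"  # plays the role of ZOOM_CACHE in the diff above

    def make_cache_key(author_names: list[str], max_num_feats: int) -> str:
        # Sorting makes the key independent of author order, as in generate_cache_key
        payload = json.dumps(
            {"authors": sorted(author_names), "max_num_feats": max_num_feats},
            sort_keys=True,
        )
        return hashlib.md5(payload.encode()).hexdigest()

    def cached_features(author_names: list[str], max_num_feats: int, compute_fn) -> list[str]:
        # Load the existing cache, if any (read side of identify_style_features)
        cache = {}
        if os.path.exists(CACHE_PATH):
            with open(CACHE_PATH, "r") as f:
                cache = json.load(f)

        key = make_cache_key(author_names, max_num_feats)
        if key in cache:
            return cache[key]["features"]   # cache hit: skip the LLM call

        features = compute_fn()             # cache miss: compute_fn stands in for the OpenAI call
        cache[key] = {"features": features}
        with open(CACHE_PATH, "w") as f:    # rewrite the whole file rather than appending
            json.dump(cache, f, indent=2)
        return features

    if __name__ == "__main__":
        # Stubbed computation so the sketch runs without an API key
        feats = cached_features(["author_a", "author_b"], 5,
                                lambda: ["short sentences", "formal tone"])
        print(feats)

Because the key is json.dumps(..., sort_keys=True) over sorted names, the same set of authors always maps to the same cache entry regardless of the order in which they are passed.
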
utils/llm_feat_utils.py CHANGED
@@ -125,7 +125,7 @@ def generate_feature_spans_cached(client, text: str, features: list[str], role:
         result[feat] = spans

     # 5) write back the combined cache
-    with open(cache_path, "a") as f:
+    with open(cache_path, "w") as f:
         json.dump(cache, f, indent=2)
     return result

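
The one-character mode change matters because json.dump in append mode writes a second JSON document after the existing one, so the next json.load of the cache fails with "Extra data"; opening with "w" truncates the file and rewrites the merged cache as a single document. A minimal, self-contained illustration (the temporary file and keys are made up for the example, not repository data):

    import json
    import os
    import tempfile

    # Illustration only: why appending to a JSON cache file breaks it, and why "w" does not.
    path = os.path.join(tempfile.mkdtemp(), "cache.json")

    with open(path, "w") as f:
        json.dump({"feat_a": ["span 1"]}, f)
    with open(path, "a") as f:              # old behaviour: a second document is appended
        json.dump({"feat_b": ["span 2"]}, f)

    try:
        with open(path) as f:
            json.load(f)                    # file now holds "{...}{...}"
    except json.JSONDecodeError as err:
        print("append-mode cache is unreadable:", err)

    with open(path, "w") as f:              # fixed behaviour: rewrite the merged cache in full
        json.dump({"feat_a": ["span 1"], "feat_b": ["span 2"]}, f, indent=2)
    with open(path) as f:
        print(json.load(f))                 # loads cleanly
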