brandonmusic committed on
Commit 9a07481 · verified · 1 Parent(s): a74a6aa

Create precompute_cap_embeddings.py

Files changed (1)
  1. precompute_cap_embeddings.py +69 -0
precompute_cap_embeddings.py ADDED
@@ -0,0 +1,69 @@
import os
import time
import pickle
import logging

import numpy as np
from datasets import load_from_disk
from openai import OpenAI
from scipy.sparse import save_npz
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer

# === Logging setup ===
logger = logging.getLogger("precompute")
logging.basicConfig(level=logging.INFO)

# === API keys ===
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai_client = OpenAI(api_key=OPENAI_API_KEY)

# === Load CAP dataset ===
LOCAL_PATH = "/data/cap_dataset"
cap_dataset = load_from_disk(LOCAL_PATH)
cap_texts = [doc['text'] for doc in cap_dataset]
logger.info(f"Loaded {len(cap_texts)} CAP texts.")

# === TF-IDF Precomputation ===
if not (os.path.exists("/data/cap_tfidf.pkl") and os.path.exists("/data/cap_tfidf_matrix.npz")):
    logger.info("Precomputing TF-IDF...")
    tfidf = TfidfVectorizer(stop_words='english', max_features=100_000)
    tfidf_matrix = tfidf.fit_transform(cap_texts)
    with open("/data/cap_tfidf.pkl", 'wb') as f:
        pickle.dump(tfidf, f)
    save_npz("/data/cap_tfidf_matrix.npz", tfidf_matrix)
    logger.info("✅ Saved TF-IDF cache files.")
else:
    logger.info("TF-IDF cache files already exist, skipping.")

# === GTE Embeddings Precomputation ===
if not os.path.exists("/data/cap_gte.npy"):
    logger.info("Precomputing GTE embeddings...")
    encoder_gte = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct")
    embeddings_gte = encoder_gte.encode(cap_texts, normalize_embeddings=True)
    np.save("/data/cap_gte.npy", embeddings_gte)
    logger.info("✅ Saved GTE embeddings.")
else:
    logger.info("GTE embeddings cache file already exists, skipping.")

# === OpenAI Embeddings Precomputation ===
if not os.path.exists("/data/cap_openai.npy"):
    logger.info("Precomputing OpenAI embeddings...")

    def get_openai_embeddings(texts):
        chunk_size = 100  # Adjust based on average text length and token limit
        n_chunks = (len(texts) + chunk_size - 1) // chunk_size
        embeddings = []
        for i in range(0, len(texts), chunk_size):
            chunk = texts[i:i + chunk_size]
            response = openai_client.embeddings.create(
                model="text-embedding-3-large",
                input=chunk
            )
            embeddings.extend([item.embedding for item in response.data])
            logger.info(f"Processed chunk {i // chunk_size + 1} of {n_chunks}")
            time.sleep(1)  # Rate limit buffer for Tier 5
        return np.array(embeddings)

    embeddings_openai = get_openai_embeddings(cap_texts)
    np.save("/data/cap_openai.npy", embeddings_openai)
    logger.info("✅ Saved OpenAI embeddings.")
else:
    logger.info("OpenAI embeddings cache file already exists, skipping.")

logger.info("✅ Precomputation completed. Cache files are ready for use.")
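For context on how these caches might be consumed, here is a minimal downstream sketch; it is not part of this commit. It loads the three artifacts the script writes and ranks the CAP texts against a query. The helper names (tfidf_search, gte_search), the top_k parameter, and the example query are hypothetical, and it assumes queries are encoded with the same GTE checkpoint used above, so that dot products against the normalized corpus embeddings equal cosine similarities.

# Hypothetical downstream usage (illustration only, not from this repo).
import pickle
import numpy as np
from scipy.sparse import load_npz
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

# Load the cached TF-IDF vectorizer and document-term matrix.
with open("/data/cap_tfidf.pkl", "rb") as f:
    tfidf = pickle.load(f)
tfidf_matrix = load_npz("/data/cap_tfidf_matrix.npz")

# Load the dense GTE corpus embeddings (L2-normalized at encode time)
# and the same encoder checkpoint for query embedding.
embeddings_gte = np.load("/data/cap_gte.npy")
encoder_gte = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct")

def tfidf_search(query, top_k=5):
    # Sparse lexical ranking: cosine similarity in TF-IDF space.
    query_vec = tfidf.transform([query])
    scores = cosine_similarity(query_vec, tfidf_matrix).ravel()
    return np.argsort(scores)[::-1][:top_k]

def gte_search(query, top_k=5):
    # Dense ranking: with normalized vectors, a dot product is the
    # cosine similarity, so no extra normalization is needed here.
    query_emb = encoder_gte.encode([query], normalize_embeddings=True)[0]
    scores = embeddings_gte @ query_emb
    return np.argsort(scores)[::-1][:top_k]

print(tfidf_search("search and seizure"))  # indices of the top-5 matches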