brandonmusic committed
Commit 274ce09 · verified · 1 Parent(s): b6805a5

Update precompute_cap_embeddings.py

Files changed (1)
  1. precompute_cap_embeddings.py +18 -9
precompute_cap_embeddings.py CHANGED
@@ -12,40 +12,48 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 logger = logging.getLogger("precompute")
 logging.basicConfig(level=logging.INFO)

-# === API keys ===
+# === API key handling ===
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+if not OPENAI_API_KEY:
+    OPENAI_API_KEY = input("Please enter your OpenAI API Key (set OPENAI_API_KEY environment variable for future runs): ")
+if not OPENAI_API_KEY:
+    raise EnvironmentError("OPENAI_API_KEY must be provided either as an environment variable or input.")
 openai_client = OpenAI(api_key=OPENAI_API_KEY)

 # === Load CAP dataset ===
-LOCAL_PATH = "/data/cap_dataset"
+LOCAL_PATH = "./cap_dataset"  # Local path for testing
+if not os.path.exists(LOCAL_PATH):
+    raise FileNotFoundError(f"CAP dataset not found at {LOCAL_PATH}. Download it first.")
 cap_dataset = load_from_disk(LOCAL_PATH)
 cap_texts = [doc['text'] for doc in cap_dataset]
 logger.info(f"Loaded {len(cap_texts)} CAP texts.")

 # === TF-IDF Precomputation ===
-if not (os.path.exists("/data/cap_tfidf.pkl") and os.path.exists("/data/cap_tfidf_matrix.npz")):
+if not (os.path.exists("./data/cap_tfidf.pkl") and os.path.exists("./data/cap_tfidf_matrix.npz")):
     logger.info("Precomputing TF-IDF...")
     tfidf = TfidfVectorizer(stop_words='english', max_features=100_000)
     tfidf_matrix = tfidf.fit_transform(cap_texts)
-    with open("/data/cap_tfidf.pkl", 'wb') as f:
+    os.makedirs("./data", exist_ok=True)
+    with open("./data/cap_tfidf.pkl", 'wb') as f:
         pickle.dump(tfidf, f)
-    save_npz("/data/cap_tfidf_matrix.npz", tfidf_matrix)
+    save_npz("./data/cap_tfidf_matrix.npz", tfidf_matrix)
     logger.info("✅ Saved TF-IDF cache files.")
 else:
     logger.info("TF-IDF cache files already exist, skipping.")

 # === GTE Embeddings Precomputation ===
-if not os.path.exists("/data/cap_gte.npy"):
+if not os.path.exists("./data/cap_gte.npy"):
     logger.info("Precomputing GTE embeddings...")
     encoder_gte = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct")
     embeddings_gte = encoder_gte.encode(cap_texts, normalize_embeddings=True)
-    np.save("/data/cap_gte.npy", embeddings_gte)
+    os.makedirs("./data", exist_ok=True)
+    np.save("./data/cap_gte.npy", embeddings_gte)
     logger.info("✅ Saved GTE embeddings.")
 else:
     logger.info("GTE embeddings cache file already exists, skipping.")

 # === OpenAI Embeddings Precomputation ===
-if not os.path.exists("/data/cap_openai.npy"):
+if not os.path.exists("./data/cap_openai.npy"):
     logger.info("Precomputing OpenAI embeddings...")
     def get_openai_embeddings(texts):
         chunk_size = 100  # Adjust based on average text length and token limit
@@ -61,7 +69,8 @@ if not os.path.exists("/data/cap_openai.npy"):
             time.sleep(1)  # Rate limit buffer for Tier 5
         return np.array(embeddings)
     embeddings_openai = get_openai_embeddings(cap_texts)
-    np.save("/data/cap_openai.npy", embeddings_openai)
+    os.makedirs("./data", exist_ok=True)
+    np.save("./data/cap_openai.npy", embeddings_openai)
     logger.info("✅ Saved OpenAI embeddings.")
 else:
     logger.info("OpenAI embeddings cache file already exists, skipping.")
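
Note on the new key fallback: input() echoes the secret to the terminal. A minimal alternative sketch, assuming the standard-library getpass is acceptable here (a suggestion only, not part of this commit):

# Suggested alternative (not in this commit): prompt without echoing the key.
import os
from getpass import getpass

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or getpass(
    "OpenAI API key (set OPENAI_API_KEY to skip this prompt): "
)
if not OPENAI_API_KEY:
    raise EnvironmentError("OPENAI_API_KEY must be provided.")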
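
For context on what the TF-IDF branch produces: a minimal sketch of how a consumer could load the two cache files and rank the corpus against a query. The helper name and the query string are invented for illustration and do not appear in this commit.

# Illustrative only: consuming the cache files this script writes.
import pickle

import numpy as np
from scipy.sparse import load_npz
from sklearn.metrics.pairwise import cosine_similarity

with open("./data/cap_tfidf.pkl", "rb") as f:
    tfidf = pickle.load(f)  # the fitted TfidfVectorizer
tfidf_matrix = load_npz("./data/cap_tfidf_matrix.npz")  # sparse, docs x vocab

def top_k(query, k=5):
    """Hypothetical helper: indices of the k best-matching CAP texts."""
    q = tfidf.transform([query])  # reuse the fitted vocabulary
    scores = cosine_similarity(q, tfidf_matrix).ravel()
    return np.argsort(scores)[::-1][:k]

print(top_k("adverse possession of real property"))  # invented example query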
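
Likewise for the GTE cache: the embeddings are saved with normalize_embeddings=True, so cosine similarity against a normalized query vector reduces to a dot product. Sketch only; the query text is invented, and the model's recommended query prompt is omitted for brevity.

# Illustrative only: querying the cached GTE embeddings.
import numpy as np
from sentence_transformers import SentenceTransformer

embeddings_gte = np.load("./data/cap_gte.npy")  # (n_docs, dim), L2-normalized
encoder_gte = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct")

# Normalized on both sides, so cosine similarity equals a plain dot product.
query_vec = encoder_gte.encode(["habeas corpus petition"], normalize_embeddings=True)[0]
scores = embeddings_gte @ query_vec
top5 = np.argsort(scores)[::-1][:5]  # indices of the five closest texts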
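
The second hunk elides the body of get_openai_embeddings (old lines 52-60, new lines 60-68). A hypothetical reconstruction consistent with the visible chunk_size, time.sleep(1), and return lines is below; the embedding model name is an assumption, since the diff never shows it.

# Hypothetical reconstruction of the elided loop body.
import os
import time

import numpy as np
from openai import OpenAI

openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def get_openai_embeddings(texts):
    chunk_size = 100  # Adjust based on average text length and token limit
    embeddings = []
    for i in range(0, len(texts), chunk_size):
        chunk = texts[i:i + chunk_size]
        resp = openai_client.embeddings.create(
            model="text-embedding-3-small",  # assumption: model not visible in the diff
            input=chunk,
        )
        embeddings.extend(item.embedding for item in resp.data)
        time.sleep(1)  # Rate limit buffer for Tier 5
    return np.array(embeddings)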