"""Embed text chunks from wiki_* JSONL dump files into a FAISS index, with pause/resume support."""

import os
import sys
import json
from pathlib import Path

import numpy as np
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
|
DUMP_PATH = "/home/ubuntu/output"   # tree of wiki_* JSONL dump files
FAISS_OUT = "wiki_faiss.index"      # index file, created or updated in place
STATE_FILE = "progress.json"        # resume checkpoint
PAUSE_FLAG = "PAUSE"                # presence of this file pauses the run
CHUNK_SIZE = 200                    # words per chunk
BATCH_SIZE = 1000                   # chunks embedded per batch
CHECKPOINT_BATCHES = 5              # write index + state every N batches
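
# Pause/resume: create a file named PAUSE in the working directory (e.g.
# `touch PAUSE`); the script checkpoints and exits at the next batch or file
# boundary. Delete PAUSE and rerun to resume from STATE_FILE.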

# Load the embedding model; reuse an existing on-disk index so interrupted runs resume.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
dim = embedder.get_sentence_embedding_dimension()

if Path(FAISS_OUT).exists():
    index = faiss.read_index(FAISS_OUT)
else:
    # Inner product over L2-normalized vectors is cosine similarity.
    index = faiss.IndexFlatIP(dim)

# Collect dump files; sort so file order (and resume indices) is stable across runs.
files = sorted(
    os.path.join(r, f)
    for r, _, fs in os.walk(DUMP_PATH)
    for f in fs
    if f.startswith("wiki_")
)
total_files = len(files)

# Restore progress from a previous run, if any.
if Path(STATE_FILE).exists():
    with open(STATE_FILE) as f:
        state = json.load(f)
    file_idx = state.get("file_idx", 0)
    batch_idx = state.get("batch_idx", 0)
    print(f"▶ Resuming from file {file_idx}, chunk offset {batch_idx}")
else:
    file_idx = 0
    batch_idx = 0

def chunk_text(text, size=CHUNK_SIZE):
    """Yield consecutive chunks of at most `size` words."""
    words = text.split()
    for i in range(0, len(words), size):
        yield " ".join(words[i:i + size])
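# For example, a 450-word article with CHUNK_SIZE=200 yields three chunks of
# 200, 200, and 50 words; the short tail chunk is kept as-is, not merged.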

# Pre-scan: count chunks per file so the progress bar has an accurate total.
# This reads the whole dump once before any embedding begins.
file_chunk_counts = []
total_chunks = 0
for path in files:
    cnt = 0
    try:
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                data = json.loads(line)
                text = data.get("text", "").strip()
                if text:
                    cnt += sum(1 for _ in chunk_text(text))
    except Exception:
        pass  # unreadable or corrupt files contribute zero chunks
    file_chunk_counts.append(cnt)
    total_chunks += cnt

# batch_idx is a chunk offset within the current file, so it adds directly
# to the count of chunks already embedded in fully processed files.
processed_chunks = sum(file_chunk_counts[:file_idx]) + batch_idx

pbar = tqdm(total=total_chunks, initial=processed_chunks, desc="Embedding chunks", unit="chunk")

for f_idx in range(file_idx, total_files):
    file_path = files[f_idx]

    # Honor a pause request before starting a new file.
    if Path(PAUSE_FLAG).exists():
        print("\n⏸ Pause requested. Saving state...")
        faiss.write_index(index, FAISS_OUT)
        with open(STATE_FILE, "w") as f:
            json.dump({"file_idx": f_idx, "batch_idx": batch_idx}, f)
        sys.exit(0)

    # Load and chunk every article in this file up front.
    chunks = []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                text = data.get("text", "").strip()
                if text:
                    chunks.extend(chunk_text(text))
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        continue

    # If resuming inside this file, skip chunks that were already embedded.
    start = batch_idx if f_idx == file_idx else 0
    total_chunks_in_file = len(chunks)

    # Embed the file's chunks in slices of BATCH_SIZE.
    for b_idx in range(start, total_chunks_in_file, BATCH_SIZE):
        # Honor a pause request between batches.
        if Path(PAUSE_FLAG).exists():
            print("\n⏸ Pause requested. Saving state...")
            faiss.write_index(index, FAISS_OUT)
            with open(STATE_FILE, "w") as f:
                json.dump({"file_idx": f_idx, "batch_idx": b_idx}, f)
            sys.exit(0)

        batch_texts = chunks[b_idx:b_idx + BATCH_SIZE]
        # encode() has no dtype argument; cast afterwards so faiss.normalize_L2
        # gets the float32 array it requires.
        embeddings = embedder.encode(batch_texts, convert_to_numpy=True).astype(np.float32)
        faiss.normalize_L2(embeddings)
        index.add(embeddings)
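        # Note: encode() batches internally (batch_size defaults to 32), so
        # BATCH_SIZE here mainly sets checkpoint and progress-bar granularity,
        # not the model's actual batch size.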

        pbar.update(len(batch_texts))

        # Periodic checkpoint: persist the index and the next chunk offset so a
        # crash or pause re-embeds at most CHECKPOINT_BATCHES batches of work.
        if (b_idx // BATCH_SIZE + 1) % CHECKPOINT_BATCHES == 0:
            faiss.write_index(index, FAISS_OUT)
            with open(STATE_FILE, "w") as f:
                json.dump({"file_idx": f_idx, "batch_idx": b_idx + BATCH_SIZE}, f)

    # File finished: reset the in-file offset and point the state at the next file.
    batch_idx = 0
    faiss.write_index(index, FAISS_OUT)
    with open(STATE_FILE, "w") as f:
        json.dump({"file_idx": f_idx + 1, "batch_idx": 0}, f)

pbar.close()
print("✅ All files processed.")
if Path(PAUSE_FLAG).exists():
    os.remove(PAUSE_FLAG)
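
# --- Example: querying the finished index (a sketch, not run by this script) ---
# Assumes the same embedder. Note that IndexFlatIP returns positional row ids,
# and this script stores no id -> chunk-text mapping, so mapping hits back to
# text would need a parallel chunk store.
#
#   query = embedder.encode(["history of papermaking"], convert_to_numpy=True).astype(np.float32)
#   faiss.normalize_L2(query)
#   scores, ids = index.search(query, 5)
#   print(scores, ids)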
|