# SyvaAI-Bv1 / chunk_creation.py
# Uploaded via huggingface_hub by danielgrims (verified upload 1026698).
import os, json
from pathlib import Path
from tqdm import tqdm
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
# Root directory containing the extracted dump shards (files named "wiki_*").
DUMP_PATH = "/home/ubuntu/output"
# Destination path for the serialized FAISS index.
FAISS_OUT = "wiki_faiss.index"
# JSON checkpoint recording how far processing has progressed.
STATE_FILE = "progress.json"
# If a file with this name exists, the script saves state and exits cleanly.
PAUSE_FLAG = "PAUSE"
# Number of words per text chunk fed to the embedder.
CHUNK_SIZE = 200
# Number of chunks embedded per encode() call.
BATCH_SIZE = 1000
# Persist the index and state every N batches.
CHECKPOINT_BATCHES = 5
# Load the sentence embedder, then open the existing FAISS index if one was
# saved by a previous run, otherwise start a fresh inner-product index.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
dim = embedder.get_sentence_embedding_dimension()
index = (
    faiss.read_index(FAISS_OUT)
    if Path(FAISS_OUT).exists()
    else faiss.IndexFlatIP(dim)
)
# Gather every dump shard (files named "wiki_*") anywhere under DUMP_PATH.
files = []
for root, _dirs, names in os.walk(DUMP_PATH):
    files.extend(os.path.join(root, name) for name in names if name.startswith("wiki_"))
total_files = len(files)
# Restore checkpointed progress from a previous run, if any.
file_idx = 0
batch_idx = 0
if Path(STATE_FILE).exists():
    with open(STATE_FILE) as state_fh:
        state = json.load(state_fh)
    file_idx = state.get("file_idx", 0)
    batch_idx = state.get("batch_idx", 0)
    print(f"▶ Resuming from file {file_idx}, batch {batch_idx}")
# Helper: split text into fixed-size word chunks.
def chunk_text(text, size=CHUNK_SIZE):
    """Yield consecutive chunks of *text*, each at most *size* words long."""
    words = text.split()
    while words:
        yield " ".join(words[:size])
        words = words[size:]
# --- Precompute per-file chunk counts so the overall progress bar can show
# --- total work and be pre-advanced to the resume point. ---
def _count_chunks(path):
    """Return how many CHUNK_SIZE-word chunks one dump file will produce.

    Counts arithmetically (ceil of word count / CHUNK_SIZE) instead of
    materializing every chunk string, which the embedding loop will build
    anyway. On a read/parse failure the partial count is kept and the error
    is logged — the original bare ``except: pass`` hid real problems.
    """
    count = 0
    try:
        with open(path, "r", encoding="utf-8") as fh:
            for line in fh:
                text = json.loads(line).get("text", "").strip()
                if text:
                    # ceil(len(words) / CHUNK_SIZE) == number of chunks yielded
                    # by chunk_text(), without building the chunk strings.
                    count += -(-len(text.split()) // CHUNK_SIZE)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        print(f"Pre-scan stopped early for {path}: {e}")
    return count

file_chunk_counts = [_count_chunks(f) for f in files]
total_chunks = sum(file_chunk_counts)
# Chunks already embedded by earlier runs: all fully-finished files plus the
# partial batch offset within the current file.
processed_chunks = sum(file_chunk_counts[:file_idx]) + batch_idx
# Overall progress bar, pre-advanced to the resume point.
pbar = tqdm(total=total_chunks, initial=processed_chunks, desc="Embedding chunks", unit="chunk")
# --- Main loop: embed each file's chunks and add them to the FAISS index ---
for f_idx in range(file_idx, total_files):
    file_path = files[f_idx]

    # Honor a pause request before starting a new file: persist index + state
    # so the next run resumes exactly here.
    if Path(PAUSE_FLAG).exists():
        print("\n⏸ Pause requested. Saving state...")
        faiss.write_index(index, FAISS_OUT)
        with open(STATE_FILE, "w") as state_fh:
            json.dump({"file_idx": f_idx, "batch_idx": batch_idx}, state_fh)
        # SystemExit works even when site-builtins (exit) are unavailable.
        raise SystemExit(0)

    # Read and chunk every non-empty article in the file (one JSON per line).
    chunks = []
    try:
        with open(file_path, "r", encoding="utf-8") as fh:
            for line in fh:
                text = json.loads(line).get("text", "").strip()
                if text:
                    chunks.extend(chunk_text(text))
    except Exception as e:
        # Best-effort: skip unreadable/corrupt shards rather than abort.
        print(f"Error reading {file_path}: {e}")
        continue

    # Resume mid-file only for the first file processed after a restart.
    start = batch_idx if f_idx == file_idx else 0

    # Process chunks in batches.
    for b_idx in range(start, len(chunks), BATCH_SIZE):
        # Honor a pause request between batches.
        if Path(PAUSE_FLAG).exists():
            print("\n⏸ Pause requested. Saving state...")
            faiss.write_index(index, FAISS_OUT)
            with open(STATE_FILE, "w") as state_fh:
                json.dump({"file_idx": f_idx, "batch_idx": b_idx}, state_fh)
            raise SystemExit(0)

        batch_texts = chunks[b_idx:b_idx + BATCH_SIZE]
        # BUG FIX: SentenceTransformer.encode() has no `dtype` parameter —
        # passing dtype=np.float32 raises TypeError. Cast the returned array
        # instead; FAISS requires contiguous float32 input.
        embeddings = embedder.encode(batch_texts, convert_to_numpy=True)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        # Unit-normalize so inner product on IndexFlatIP == cosine similarity.
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Update overall progress bar.
        pbar.update(len(batch_texts))

        # Periodic checkpoint: a crash loses at most CHECKPOINT_BATCHES batches.
        if (b_idx // BATCH_SIZE + 1) % CHECKPOINT_BATCHES == 0:
            faiss.write_index(index, FAISS_OUT)
            with open(STATE_FILE, "w") as state_fh:
                json.dump({"file_idx": f_idx, "batch_idx": b_idx + BATCH_SIZE}, state_fh)

    # File completed: reset batch offset and advance checkpoint to next file.
    batch_idx = 0
    faiss.write_index(index, FAISS_OUT)
    with open(STATE_FILE, "w") as state_fh:
        json.dump({"file_idx": f_idx + 1, "batch_idx": 0}, state_fh)
# Wrap up: close the progress bar and clear any leftover pause flag.
pbar.close()
print("✅ All files processed.")
Path(PAUSE_FLAG).unlink(missing_ok=True)