audit_assistant / download_models.py
akryldigital's picture
fix old "Document" import (#3)
9e1adca verified
"""
Pre-download Hugging Face models during Docker image build.
This script loads the models to trigger download and caching.
"""
import os
import sys
# Model configurations from settings.yaml
EMBEDDING_MODEL = "BAAI/bge-m3"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"


def load_embedding_model(model_name: str = EMBEDDING_MODEL) -> None:
    """Instantiate the embedding model and embed a short text.

    The import is kept local so that a missing or broken dependency is
    reported by the caller as a download warning instead of aborting the
    whole script.
    """
    from langchain_community.embeddings import HuggingFaceEmbeddings

    # Load embedding model (will download if not cached)
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": "cpu"},  # Use CPU during build
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    # Trigger actual download by encoding a small text
    embeddings.embed_query("test")


def load_reranker_model(model_name: str = RERANKER_MODEL) -> None:
    """Instantiate the reranker cross-encoder and score one dummy pair.

    The import is kept local for the same reason as in
    ``load_embedding_model``.
    """
    from sentence_transformers import CrossEncoder

    # Load reranker model (will download if not cached)
    reranker = CrossEncoder(model_name)
    # Trigger actual download by running inference
    reranker.predict([("test query", "test document")])


def prefetch_model(kind: str, model_name: str, loader) -> bool:
    """Run ``loader(model_name)`` and report the outcome on stdout.

    Args:
        kind: Human-readable model role used in the messages
            (e.g. ``"embedding"``).
        model_name: Model identifier passed to ``loader``.
        loader: Callable taking the model name; expected to raise on failure.

    Returns:
        ``True`` if the loader finished without raising, ``False`` otherwise.
        Failures are never re-raised: the build is allowed to continue and
        the model will be downloaded at runtime instead.
    """
    print(f"📦 Downloading {kind} model: {model_name}")
    try:
        loader(model_name)
    except Exception as e:  # deliberate best-effort boundary: warn, don't fail
        print(f"⚠️ Warning: Could not download {kind} model: {e}")
        return False
    print(f"✅ {kind.capitalize()} model downloaded: {model_name}")
    return True


def summary_message(failed_models) -> str:
    """Return the closing status line for the given failed model names.

    An empty collection yields the success message; otherwise the message
    lists the models that could not be pre-downloaded, so the log no longer
    claims success after a swallowed failure.
    """
    if not failed_models:
        return "✅ All models downloaded and cached successfully!"
    return (
        "⚠️ Some models could not be pre-downloaded "
        "(they will be downloaded at runtime): " + ", ".join(failed_models)
    )


print("🔽 Downloading Hugging Face models during build...")

# Names of the models whose pre-download failed, in attempt order.
failed_models = [
    name
    for kind, name, loader in (
        ("embedding", EMBEDDING_MODEL, load_embedding_model),
        ("reranker", RERANKER_MODEL, load_reranker_model),
    )
    if not prefetch_model(kind, name, loader)
]

print(summary_message(failed_models))