Yeroyan committed on
Commit
5c02889
·
verified ·
1 Parent(s): 49a2a71

download HF models during image build

Browse files
Files changed (1) hide show
  1. download_models.py +54 -0
download_models.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Pre-download Hugging Face models during Docker image build.

This script loads each model and runs one tiny inference so that the model
files are downloaded and cached while the image is being built.

A failed download is reported but deliberately does NOT abort the build:
the model is then expected to be downloaded at runtime instead. The final
summary line reflects whether every model was actually cached.
"""
import os
import sys

print("🔽 Downloading Hugging Face models during build...")

# Model configurations from settings.yaml
EMBEDDING_MODEL = "BAAI/bge-m3"
RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"

# Models whose pre-download failed; used so the closing message is truthful
# instead of unconditionally claiming success.
failed_models = []

try:
    print(f"📦 Downloading embedding model: {EMBEDDING_MODEL}")
    # Imported lazily so a missing/broken package is reported as a warning
    # for this model only rather than crashing the whole script.
    from langchain_community.embeddings import HuggingFaceEmbeddings

    # Load embedding model (will download if not cached)
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": "cpu"},  # Use CPU during build
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )

    # Trigger actual download by encoding a small text
    _ = embeddings.embed_query("test")
    print(f"✅ Embedding model downloaded: {EMBEDDING_MODEL}")

except Exception as e:
    # Don't exit on error - allow build to continue (model will download at runtime)
    print(f"⚠️ Warning: Could not download embedding model: {e}")
    failed_models.append(EMBEDDING_MODEL)

try:
    print(f"📦 Downloading reranker model: {RERANKER_MODEL}")
    # Imported lazily for the same reason as the embedding import above.
    from sentence_transformers import CrossEncoder

    # Load reranker model (will download if not cached)
    reranker = CrossEncoder(RERANKER_MODEL)

    # Trigger actual download by running inference
    _ = reranker.predict([("test query", "test document")])
    print(f"✅ Reranker model downloaded: {RERANKER_MODEL}")

except Exception as e:
    # Don't exit on error - allow build to continue (model will download at runtime)
    print(f"⚠️ Warning: Could not download reranker model: {e}")
    failed_models.append(RERANKER_MODEL)

if failed_models:
    # Still exit with status 0 (best-effort), but do not claim success.
    print(
        "⚠️ Warning: Some models were not cached during build "
        f"and will be downloaded at runtime: {', '.join(failed_models)}"
    )
else:
    print("✅ All models downloaded and cached successfully!")