Peterase commited on
Commit
c48d01a
·
1 Parent(s): 8425041

fix: use snapshot_download to avoid OOMKilled during model pre-cache build step

Browse files
Files changed (1) hide show
  1. download_models.py +28 -23
download_models.py CHANGED
@@ -1,38 +1,43 @@
1
  import os
2
  import sys
3
 
4
- # Monkeypatch for transformers/FlagEmbedding compatibility issue
5
- try:
6
- import transformers.utils.import_utils
7
- if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'):
8
- transformers.utils.import_utils.is_torch_fx_available = lambda: False
9
- except Exception:
10
- pass
11
 
12
- from FlagEmbedding import BGEM3FlagModel
13
- from sentence_transformers import CrossEncoder
 
 
 
 
 
 
14
 
15
- def download():
16
- print("--- STARTING MODEL PRE-CACHE ---")
17
-
18
- # 1. BGE-M3
19
  model_name = "BAAI/bge-m3"
20
- print(f"Downloading/Loading {model_name}...")
21
  try:
22
- # This will trigger the download if not present
23
- _ = BGEM3FlagModel(model_name, use_fp16=True)
24
- print(f"Successfully cached {model_name}")
 
 
 
25
  except Exception as e:
26
- print(f"Error caching {model_name}: {e}")
27
 
28
- # 2. Reranker
29
  reranker_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
30
- print(f"Downloading/Loading {reranker_name}...")
31
  try:
32
- _ = CrossEncoder(reranker_name)
33
- print(f"Successfully cached {reranker_name}")
 
 
 
 
34
  except Exception as e:
35
- print(f"Error caching {reranker_name}: {e}")
36
 
37
  print("--- PRE-CACHE COMPLETE ---")
38
 
 
1
  import os
2
  import sys
3
 
4
+ def download():
5
+ print("--- STARTING MODEL PRE-CACHE (file-only download) ---")
 
 
 
 
 
6
 
7
+ # Use snapshot_download to pull files WITHOUT loading model into RAM.
8
+ # Loading BGEM3FlagModel() during build causes OOMKilled on HF Spaces
9
+ # because the build container has very limited memory.
10
+ try:
11
+ from huggingface_hub import snapshot_download
12
+ except ImportError:
13
+ print("huggingface_hub not available, skipping pre-cache")
14
+ return
15
 
16
+ # 1. BGE-M3 — download files only, no model instantiation
 
 
 
17
  model_name = "BAAI/bge-m3"
18
+ print(f"Downloading files for {model_name}...")
19
  try:
20
+ snapshot_download(
21
+ repo_id=model_name,
22
+ repo_type="model",
23
+ ignore_patterns=["*.msgpack", "*.h5", "flax_model*", "tf_model*", "rust_model*"],
24
+ )
25
+ print(f"✓ Files cached: {model_name}")
26
  except Exception as e:
27
+ print(f"Warning: could not pre-cache {model_name}: {e}")
28
 
29
+ # 2. Reranker — download files only
30
  reranker_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
31
+ print(f"Downloading files for {reranker_name}...")
32
  try:
33
+ snapshot_download(
34
+ repo_id=reranker_name,
35
+ repo_type="model",
36
+ ignore_patterns=["*.msgpack", "*.h5", "flax_model*", "tf_model*"],
37
+ )
38
+ print(f"✓ Files cached: {reranker_name}")
39
  except Exception as e:
40
+ print(f"Warning: could not pre-cache {reranker_name}: {e}")
41
 
42
  print("--- PRE-CACHE COMPLETE ---")
43