"""
Load module for RAG-based utterance prediction.

This module loads a FAISS index and retriever instead of a HuggingFace model.
The index and metadata files are downloaded from the HuggingFace Hub under
model-like filenames: it first tries pytorch_model.bin (index) and
model.safetensors (data), then falls back to model.index and model.data.
"""
| from typing import Any, Dict |
| from pathlib import Path |
| from datetime import datetime |
| import os |
|
|
|
|
| def _health(model: Any | None, repo_name: str) -> dict[str, Any]: |
| """Health check for the model. |
| |
| Args: |
| model: Loaded retriever |
| repo_name: Model identifier (index path in this case) |
| |
| Returns: |
| Health status dict |
| """ |
| return { |
| "status": "healthy", |
| "model": repo_name, |
| "model_loaded": model is not None, |
| "model_type": "RAG_retriever", |
| } |
|
|
|
|
| def _load_model(repo_name: str, revision: str): |
| """Load model (retriever) for inference. |
| |
| Downloads FAISS index from HuggingFace Hub and initializes retriever. |
| |
| Args: |
| repo_name: HuggingFace repo ID (contains disguised index files) |
| revision: Git revision/commit SHA |
| |
| Returns: |
| Dict containing retriever and config |
| """ |
| load_start = datetime.now() |
| |
| try: |
| |
| print("=" * 80) |
| print("[LOAD] 🔧 RAG RETRIEVER SETUP") |
| print("=" * 80) |
| print(f"[LOAD] Public Model Repo: {repo_name}") |
| print(f"[LOAD] Revision: {revision}") |
| |
| |
| cache_dir = './model_cache' |
| print(f"[LOAD] Setting up cache: {cache_dir}") |
| |
| |
| Path(cache_dir).mkdir(parents=True, exist_ok=True) |
| |
| |
| os.environ['HF_HOME'] = cache_dir |
| os.environ['HF_HUB_CACHE'] = cache_dir |
| os.environ['TRANSFORMERS_CACHE'] = cache_dir |
| print(f"[LOAD] ✓ Environment configured") |
| |
| |
| from huggingface_hub import hf_hub_download |
| |
| |
| print("=" * 80) |
| print("[LOAD] [1/4] DOWNLOADING MODEL INDEX...") |
| print("=" * 80) |
| dl_start = datetime.now() |
| |
| |
| index_filename = "pytorch_model.bin" |
| try: |
| index_file = hf_hub_download( |
| repo_id=repo_name, |
| filename=index_filename, |
| revision=revision, |
| cache_dir=cache_dir, |
| local_dir=cache_dir, |
| local_dir_use_symlinks=False, |
| ) |
| except Exception as e: |
| print(f"[LOAD] Note: {index_filename} not found, trying model.index...") |
| index_filename = "model.index" |
| index_file = hf_hub_download( |
| repo_id=repo_name, |
| filename=index_filename, |
| revision=revision, |
| cache_dir=cache_dir, |
| local_dir=cache_dir, |
| local_dir_use_symlinks=False, |
| ) |
| |
| dl_elapsed = (datetime.now() - dl_start).total_seconds() |
| print(f"[LOAD] ✓ Index downloaded in {dl_elapsed:.2f}s") |
| print(f"[LOAD] Path: {index_file}") |
| |
| |
| if os.path.exists(index_file): |
| size_mb = os.path.getsize(index_file) / 1024 / 1024 |
| print(f"[LOAD] Size: {size_mb:.2f} MB") |
| |
| |
| print("=" * 80) |
| print("[LOAD] [2/4] DOWNLOADING MODEL DATA...") |
| print("=" * 80) |
| dl_start = datetime.now() |
| |
| |
| data_filename = "model.safetensors" |
| try: |
| data_file = hf_hub_download( |
| repo_id=repo_name, |
| filename=data_filename, |
| revision=revision, |
| cache_dir=cache_dir, |
| local_dir=cache_dir, |
| local_dir_use_symlinks=False, |
| ) |
| except Exception as e: |
| print(f"[LOAD] Note: {data_filename} not found, trying model.data...") |
| data_filename = "model.data" |
| data_file = hf_hub_download( |
| repo_id=repo_name, |
| filename=data_filename, |
| revision=revision, |
| cache_dir=cache_dir, |
| local_dir=cache_dir, |
| local_dir_use_symlinks=False, |
| ) |
| |
| dl_elapsed = (datetime.now() - dl_start).total_seconds() |
| print(f"[LOAD] ✓ Data downloaded in {dl_elapsed:.2f}s") |
| print(f"[LOAD] Path: {data_file}") |
| |
| |
| if os.path.exists(data_file): |
| size_mb = os.path.getsize(data_file) / 1024 / 1024 |
| print(f"[LOAD] Size: {size_mb:.2f} MB") |
| |
| |
| print("=" * 80) |
| print("[LOAD] [3/4] PREPARING CONFIGURATION...") |
| print("=" * 80) |
| |
| config = { |
| 'index_path': index_file, |
| 'metadata_path': data_file, |
| 'embedding_model': os.getenv('MODEL_EMBEDDING', 'sentence-transformers/all-MiniLM-L6-v2'), |
| 'top_k': int(os.getenv('MODEL_TOP_K', '1')), |
| 'use_context': os.getenv('MODEL_USE_CONTEXT', 'true').lower() == 'true', |
| 'use_prefix': os.getenv('MODEL_USE_PREFIX', 'true').lower() == 'true', |
| 'device': os.getenv('MODEL_DEVICE', 'cpu'), |
| } |
| |
| for key, value in config.items(): |
| print(f"[LOAD] {key}: {value}") |
| |
| |
| print("=" * 80) |
| print("[LOAD] [4/4] INITIALIZING RETRIEVER...") |
| print("=" * 80) |
| |
| init_start = datetime.now() |
| retriever = UtteranceRetriever(config) |
| init_elapsed = (datetime.now() - init_start).total_seconds() |
| |
| print(f"[LOAD] ✓ Retriever initialized in {init_elapsed:.2f}s") |
| |
| total_elapsed = (datetime.now() - load_start).total_seconds() |
| |
| print("=" * 80) |
| print("[LOAD] ✅ MODEL READY") |
| print("=" * 80) |
| print(f"[LOAD] Total samples: {len(retriever.samples)}") |
| print(f"[LOAD] Index vectors: {retriever.index.ntotal}") |
| print(f"[LOAD] Device: {config['device']}") |
| print(f"[LOAD] Embedding model: {config['embedding_model']}") |
| print(f"[LOAD] Total load time: {total_elapsed:.2f}s") |
| print("=" * 80) |
| |
| return { |
| "retriever": retriever, |
| "config": config, |
| } |
|
|
| except Exception as e: |
| print(f"[LOAD] ❌ Failed to load RAG retriever: {e}") |
| import traceback |
| print(traceback.format_exc()) |
| raise |
|
|