Spaces:
Sleeping
Sleeping
Commit
·
757abfc
1
Parent(s):
9d8e329
feat: Update application with new changes
Browse files- app/services/reranker_service.py +22 -9
- app/services/retrieval.py +24 -12
app/services/reranker_service.py
CHANGED
|
@@ -14,37 +14,50 @@ logger = logging.getLogger(__name__)
|
|
| 14 |
|
| 15 |
def load_reranker_model():
|
| 16 |
"""
|
| 17 |
-
Loads the custom-trained ExpertJudgeCrossEncoder model
|
| 18 |
-
|
| 19 |
"""
|
| 20 |
if state.reranker_model_loaded:
|
| 21 |
logger.info("Re-ranker model already loaded in state.")
|
| 22 |
return True
|
| 23 |
-
|
|
|
|
| 24 |
model_path = settings.RERANKER_MODEL_PATH
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
base_model_name = settings.RERANKER_MODEL_NAME
|
| 26 |
logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
|
| 27 |
logger.info(f"Using base model architecture: {base_model_name}")
|
| 28 |
-
|
| 29 |
try:
|
| 30 |
# 1. Instantiate the model structure
|
| 31 |
model = ExpertJudgeCrossEncoder(model_name=base_model_name)
|
| 32 |
-
|
| 33 |
# 2. Load the saved weights (the state_dict) into the model structure
|
| 34 |
model.load_state_dict(torch.load(model_path, map_location=state.device))
|
| 35 |
-
|
| 36 |
# 3. Set up the model for inference
|
| 37 |
model.to(state.device)
|
| 38 |
model.eval()
|
| 39 |
-
|
| 40 |
# 4. Load the corresponding tokenizer
|
| 41 |
tokenizer = get_tokenizer(model_name=base_model_name)
|
| 42 |
-
|
| 43 |
# 5. Store both in the state
|
| 44 |
state.reranker_model = model
|
| 45 |
state.reranker_tokenizer = tokenizer
|
| 46 |
state.reranker_model_loaded = True
|
| 47 |
-
|
| 48 |
logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
|
| 49 |
return True
|
| 50 |
except Exception as e:
|
|
|
|
| 14 |
|
| 15 |
def load_reranker_model():
|
| 16 |
"""
|
| 17 |
+
Loads the custom-trained ExpertJudgeCrossEncoder model. If running on a
|
| 18 |
+
new server, it first downloads the model from S3.
|
| 19 |
"""
|
| 20 |
if state.reranker_model_loaded:
|
| 21 |
logger.info("Re-ranker model already loaded in state.")
|
| 22 |
return True
|
| 23 |
+
|
| 24 |
+
# --- ADDED: Download from S3 if file doesn't exist ---
|
| 25 |
model_path = settings.RERANKER_MODEL_PATH
|
| 26 |
+
if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
|
| 27 |
+
logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
|
| 28 |
+
try:
|
| 29 |
+
# Create the 'data' directory if it doesn't exist
|
| 30 |
+
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
| 31 |
+
with requests.get(settings.S3_RERANKER_URL, stream=True) as r:
|
| 32 |
+
r.raise_for_status()
|
| 33 |
+
with open(model_path, 'wb') as f:
|
| 34 |
+
for chunk in r.iter_content(chunk_size=8192):
|
| 35 |
+
f.write(chunk)
|
| 36 |
+
logger.info("Successfully downloaded re-ranker model from S3.")
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
|
| 39 |
+
return False
|
| 40 |
+
# --- END OF ADDITION ---
|
| 41 |
+
|
| 42 |
base_model_name = settings.RERANKER_MODEL_NAME
|
| 43 |
logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
|
| 44 |
logger.info(f"Using base model architecture: {base_model_name}")
|
| 45 |
+
|
| 46 |
try:
|
| 47 |
# 1. Instantiate the model structure
|
| 48 |
model = ExpertJudgeCrossEncoder(model_name=base_model_name)
|
|
|
|
| 49 |
# 2. Load the saved weights (the state_dict) into the model structure
|
| 50 |
model.load_state_dict(torch.load(model_path, map_location=state.device))
|
|
|
|
| 51 |
# 3. Set up the model for inference
|
| 52 |
model.to(state.device)
|
| 53 |
model.eval()
|
|
|
|
| 54 |
# 4. Load the corresponding tokenizer
|
| 55 |
tokenizer = get_tokenizer(model_name=base_model_name)
|
| 56 |
+
|
| 57 |
# 5. Store both in the state
|
| 58 |
state.reranker_model = model
|
| 59 |
state.reranker_tokenizer = tokenizer
|
| 60 |
state.reranker_model_loaded = True
|
|
|
|
| 61 |
logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
|
| 62 |
return True
|
| 63 |
except Exception as e:
|
app/services/retrieval.py
CHANGED
|
@@ -17,42 +17,55 @@ logger = logging.getLogger(__name__)
|
|
| 17 |
|
| 18 |
def load_retrieval_artifacts():
|
| 19 |
"""
|
| 20 |
-
Loads all necessary artifacts for retrieval
|
| 21 |
-
|
| 22 |
-
query transformation matrix (Wq).
|
| 23 |
"""
|
| 24 |
if state.artifacts_loaded:
|
| 25 |
logger.info("Retrieval artifacts already loaded in state.")
|
| 26 |
return True
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 29 |
logger.info(f"Using device for retrieval: {device}")
|
| 30 |
state.device = device
|
| 31 |
-
|
| 32 |
# 1. Load the pre-computed artifacts file
|
| 33 |
-
artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
|
| 34 |
logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
|
| 35 |
try:
|
| 36 |
if not os.path.exists(artifacts_path):
|
| 37 |
logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
|
| 38 |
return False
|
| 39 |
-
|
| 40 |
artifacts = np.load(artifacts_path, allow_pickle=True)
|
| 41 |
-
|
| 42 |
# Load into state
|
| 43 |
state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
|
| 44 |
state.chunk_ids_in_order = artifacts['chunk_ids']
|
| 45 |
state.wq_weights = torch.from_numpy(artifacts['wq_weights']).to(device)
|
| 46 |
state.temperature = artifacts['temperature'][0] # Extract scalar from array
|
| 47 |
-
|
| 48 |
logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
|
| 49 |
logger.info(f"Loaded Wq matrix of shape: {state.wq_weights.shape}")
|
| 50 |
logger.info(f"Loaded temperature value: {state.temperature:.4f}")
|
| 51 |
-
|
| 52 |
except Exception as e:
|
| 53 |
logger.exception(f"Failed to load and process retrieval artifacts: {e}")
|
| 54 |
return False
|
| 55 |
-
|
| 56 |
# 2. Load the Sentence Transformer model for encoding queries
|
| 57 |
logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
|
| 58 |
try:
|
|
@@ -63,7 +76,6 @@ def load_retrieval_artifacts():
|
|
| 63 |
except Exception as e:
|
| 64 |
logger.exception(f"Failed to load Sentence Transformer model: {e}")
|
| 65 |
return False
|
| 66 |
-
|
| 67 |
state.artifacts_loaded = True
|
| 68 |
return True
|
| 69 |
|
|
|
|
| 17 |
|
| 18 |
def load_retrieval_artifacts():
|
| 19 |
"""
|
| 20 |
+
Loads all necessary artifacts for retrieval. If running on a new server,
|
| 21 |
+
it first downloads the artifacts from S3.
|
|
|
|
| 22 |
"""
|
| 23 |
if state.artifacts_loaded:
|
| 24 |
logger.info("Retrieval artifacts already loaded in state.")
|
| 25 |
return True
|
| 26 |
+
|
| 27 |
+
# --- ADDED: Download from S3 if file doesn't exist ---
|
| 28 |
+
artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
|
| 29 |
+
if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
|
| 30 |
+
logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
|
| 31 |
+
try:
|
| 32 |
+
# Create the 'data' directory if it doesn't exist
|
| 33 |
+
os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
|
| 34 |
+
with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
|
| 35 |
+
r.raise_for_status()
|
| 36 |
+
with open(artifacts_path, 'wb') as f:
|
| 37 |
+
for chunk in r.iter_content(chunk_size=8192):
|
| 38 |
+
f.write(chunk)
|
| 39 |
+
logger.info("Successfully downloaded artifacts from S3.")
|
| 40 |
+
except Exception as e:
|
| 41 |
+
logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
|
| 42 |
+
return False
|
| 43 |
+
# --- END OF ADDITION ---
|
| 44 |
+
|
| 45 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 46 |
logger.info(f"Using device for retrieval: {device}")
|
| 47 |
state.device = device
|
| 48 |
+
|
| 49 |
# 1. Load the pre-computed artifacts file
|
|
|
|
| 50 |
logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
|
| 51 |
try:
|
| 52 |
if not os.path.exists(artifacts_path):
|
| 53 |
logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
|
| 54 |
return False
|
|
|
|
| 55 |
artifacts = np.load(artifacts_path, allow_pickle=True)
|
|
|
|
| 56 |
# Load into state
|
| 57 |
state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
|
| 58 |
state.chunk_ids_in_order = artifacts['chunk_ids']
|
| 59 |
state.wq_weights = torch.from_numpy(artifacts['wq_weights']).to(device)
|
| 60 |
state.temperature = artifacts['temperature'][0] # Extract scalar from array
|
|
|
|
| 61 |
logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
|
| 62 |
logger.info(f"Loaded Wq matrix of shape: {state.wq_weights.shape}")
|
| 63 |
logger.info(f"Loaded temperature value: {state.temperature:.4f}")
|
| 64 |
+
|
| 65 |
except Exception as e:
|
| 66 |
logger.exception(f"Failed to load and process retrieval artifacts: {e}")
|
| 67 |
return False
|
| 68 |
+
|
| 69 |
# 2. Load the Sentence Transformer model for encoding queries
|
| 70 |
logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
|
| 71 |
try:
|
|
|
|
| 76 |
except Exception as e:
|
| 77 |
logger.exception(f"Failed to load Sentence Transformer model: {e}")
|
| 78 |
return False
|
|
|
|
| 79 |
state.artifacts_loaded = True
|
| 80 |
return True
|
| 81 |
|