helal94hb1 commited on
Commit
757abfc
·
1 Parent(s): 9d8e329

feat: Download re-ranker model and retrieval artifacts from S3 when missing locally

Browse files
app/services/reranker_service.py CHANGED
@@ -14,37 +14,50 @@ logger = logging.getLogger(__name__)
14
 
15
  def load_reranker_model():
16
  """
17
- Loads the custom-trained ExpertJudgeCrossEncoder model and its tokenizer
18
- into the state object.
19
  """
20
  if state.reranker_model_loaded:
21
  logger.info("Re-ranker model already loaded in state.")
22
  return True
23
-
 
24
  model_path = settings.RERANKER_MODEL_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  base_model_name = settings.RERANKER_MODEL_NAME
26
  logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
27
  logger.info(f"Using base model architecture: {base_model_name}")
28
-
29
  try:
30
  # 1. Instantiate the model structure
31
  model = ExpertJudgeCrossEncoder(model_name=base_model_name)
32
-
33
  # 2. Load the saved weights (the state_dict) into the model structure
34
  model.load_state_dict(torch.load(model_path, map_location=state.device))
35
-
36
  # 3. Set up the model for inference
37
  model.to(state.device)
38
  model.eval()
39
-
40
  # 4. Load the corresponding tokenizer
41
  tokenizer = get_tokenizer(model_name=base_model_name)
42
-
43
  # 5. Store both in the state
44
  state.reranker_model = model
45
  state.reranker_tokenizer = tokenizer
46
  state.reranker_model_loaded = True
47
-
48
  logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
49
  return True
50
  except Exception as e:
 
14
 
15
  def load_reranker_model():
16
  """
17
+ Loads the custom-trained ExpertJudgeCrossEncoder model. If running on a
18
+ new server, it first downloads the model from S3.
19
  """
20
  if state.reranker_model_loaded:
21
  logger.info("Re-ranker model already loaded in state.")
22
  return True
23
+
24
+ # --- ADDED: Download from S3 if file doesn't exist ---
25
  model_path = settings.RERANKER_MODEL_PATH
26
+ if not os.path.exists(model_path) and settings.S3_RERANKER_URL:
27
+ logger.info(f"Re-ranker model not found at {model_path}. Downloading from S3...")
28
+ try:
29
+ # Create the 'data' directory if it doesn't exist
30
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
31
+ with requests.get(settings.S3_RERANKER_URL, stream=True) as r:
32
+ r.raise_for_status()
33
+ with open(model_path, 'wb') as f:
34
+ for chunk in r.iter_content(chunk_size=8192):
35
+ f.write(chunk)
36
+ logger.info("Successfully downloaded re-ranker model from S3.")
37
+ except Exception as e:
38
+ logger.exception(f"FATAL: Failed to download re-ranker model from S3: {e}")
39
+ return False
40
+ # --- END OF ADDITION ---
41
+
42
  base_model_name = settings.RERANKER_MODEL_NAME
43
  logger.info(f"Loading custom ExpertJudgeCrossEncoder from: {model_path}")
44
  logger.info(f"Using base model architecture: {base_model_name}")
45
+
46
  try:
47
  # 1. Instantiate the model structure
48
  model = ExpertJudgeCrossEncoder(model_name=base_model_name)
 
49
  # 2. Load the saved weights (the state_dict) into the model structure
50
  model.load_state_dict(torch.load(model_path, map_location=state.device))
 
51
  # 3. Set up the model for inference
52
  model.to(state.device)
53
  model.eval()
 
54
  # 4. Load the corresponding tokenizer
55
  tokenizer = get_tokenizer(model_name=base_model_name)
56
+
57
  # 5. Store both in the state
58
  state.reranker_model = model
59
  state.reranker_tokenizer = tokenizer
60
  state.reranker_model_loaded = True
 
61
  logger.info("Custom ExpertJudgeCrossEncoder model and tokenizer loaded successfully.")
62
  return True
63
  except Exception as e:
app/services/retrieval.py CHANGED
@@ -17,42 +17,55 @@ logger = logging.getLogger(__name__)
17
 
18
  def load_retrieval_artifacts():
19
  """
20
- Loads all necessary artifacts for retrieval from the pre-computed NPZ file.
21
- This includes the query encoder, pre-transformed chunk embeddings, and the
22
- query transformation matrix (Wq).
23
  """
24
  if state.artifacts_loaded:
25
  logger.info("Retrieval artifacts already loaded in state.")
26
  return True
27
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
  logger.info(f"Using device for retrieval: {device}")
30
  state.device = device
31
-
32
  # 1. Load the pre-computed artifacts file
33
- artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
34
  logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
35
  try:
36
  if not os.path.exists(artifacts_path):
37
  logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
38
  return False
39
-
40
  artifacts = np.load(artifacts_path, allow_pickle=True)
41
-
42
  # Load into state
43
  state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
44
  state.chunk_ids_in_order = artifacts['chunk_ids']
45
  state.wq_weights = torch.from_numpy(artifacts['wq_weights']).to(device)
46
  state.temperature = artifacts['temperature'][0] # Extract scalar from array
47
-
48
  logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
49
  logger.info(f"Loaded Wq matrix of shape: {state.wq_weights.shape}")
50
  logger.info(f"Loaded temperature value: {state.temperature:.4f}")
51
-
52
  except Exception as e:
53
  logger.exception(f"Failed to load and process retrieval artifacts: {e}")
54
  return False
55
-
56
  # 2. Load the Sentence Transformer model for encoding queries
57
  logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
58
  try:
@@ -63,7 +76,6 @@ def load_retrieval_artifacts():
63
  except Exception as e:
64
  logger.exception(f"Failed to load Sentence Transformer model: {e}")
65
  return False
66
-
67
  state.artifacts_loaded = True
68
  return True
69
 
 
17
 
18
  def load_retrieval_artifacts():
19
  """
20
+ Loads all necessary artifacts for retrieval. If running on a new server,
21
+ it first downloads the artifacts from S3.
 
22
  """
23
  if state.artifacts_loaded:
24
  logger.info("Retrieval artifacts already loaded in state.")
25
  return True
26
+
27
+ # --- ADDED: Download from S3 if file doesn't exist ---
28
+ artifacts_path = settings.RETRIEVAL_ARTIFACTS_PATH
29
+ if not os.path.exists(artifacts_path) and settings.S3_ARTIFACTS_URL:
30
+ logger.info(f"Artifacts file not found at {artifacts_path}. Downloading from S3...")
31
+ try:
32
+ # Create the 'data' directory if it doesn't exist
33
+ os.makedirs(os.path.dirname(artifacts_path), exist_ok=True)
34
+ with requests.get(settings.S3_ARTIFACTS_URL, stream=True) as r:
35
+ r.raise_for_status()
36
+ with open(artifacts_path, 'wb') as f:
37
+ for chunk in r.iter_content(chunk_size=8192):
38
+ f.write(chunk)
39
+ logger.info("Successfully downloaded artifacts from S3.")
40
+ except Exception as e:
41
+ logger.exception(f"FATAL: Failed to download artifacts from S3: {e}")
42
+ return False
43
+ # --- END OF ADDITION ---
44
+
45
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46
  logger.info(f"Using device for retrieval: {device}")
47
  state.device = device
48
+
49
  # 1. Load the pre-computed artifacts file
 
50
  logger.info(f"Loading retrieval artifacts from {artifacts_path}...")
51
  try:
52
  if not os.path.exists(artifacts_path):
53
  logger.error(f"FATAL: Artifacts file not found at {artifacts_path}")
54
  return False
 
55
  artifacts = np.load(artifacts_path, allow_pickle=True)
 
56
  # Load into state
57
  state.transformed_chunk_embeddings = artifacts['transformed_chunk_embeddings']
58
  state.chunk_ids_in_order = artifacts['chunk_ids']
59
  state.wq_weights = torch.from_numpy(artifacts['wq_weights']).to(device)
60
  state.temperature = artifacts['temperature'][0] # Extract scalar from array
 
61
  logger.info(f"Successfully loaded {len(state.chunk_ids_in_order)} transformed embeddings.")
62
  logger.info(f"Loaded Wq matrix of shape: {state.wq_weights.shape}")
63
  logger.info(f"Loaded temperature value: {state.temperature:.4f}")
64
+
65
  except Exception as e:
66
  logger.exception(f"Failed to load and process retrieval artifacts: {e}")
67
  return False
68
+
69
  # 2. Load the Sentence Transformer model for encoding queries
70
  logger.info(f"Loading Sentence Transformer model: {settings.QUERY_ENCODER_MODEL_NAME}...")
71
  try:
 
76
  except Exception as e:
77
  logger.exception(f"Failed to load Sentence Transformer model: {e}")
78
  return False
 
79
  state.artifacts_loaded = True
80
  return True
81