import os
import json
import pickle
from pathlib import Path

import faiss
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

from tf_data_pipeline import TFDataPipeline
from chatbot_config import ChatbotConfig
from logger_config import config_logger

logger = config_logger(__name__)

# Disable tokenizer parallelism to avoid fork-related warnings from the
# Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    """Build the production FAISS index and supporting artifacts for retrieval."""
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')

    # Create all output directories up front.
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
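    # exist_ok=True makes directory creation idempotent, so the script can be
    # rerun without failing on paths that already exist.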
					
						
    # Load an existing ChatbotConfig from models/config.json, or fall back to
    # defaults and persist them for future runs.
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
        try:
            with open(config_json, "w", encoding="utf-8") as f:
                json.dump(config.to_dict(), f, indent=2)
            logger.info(f"Default ChatbotConfig saved to {config_json}")
        except Exception as e:
            logger.error(f"Failed to save default ChatbotConfig: {e}")
            raise
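    # The bare raise re-raises the original exception, so a default config
    # that cannot be persisted aborts the run instead of continuing unsaved.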
					
						
    # Initialize the sentence encoder; its tokenizer is reused by the data
    # pipeline below.
    encoder = SentenceTransformer(config.pretrained_model)
    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")

    # Load the Taskmaster dialogues that supply the response pool.
    if Path(JSON_TRAINING_DATA_PATH).exists():
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
        logger.info(f"Loaded {len(dialogues)} dialogues.")
    else:
        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
        dialogues = []
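    # An empty dialogue list is tolerated here; the embedding and indexing
    # steps below are skipped when there is nothing to index.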
					
						
    # Load the pickled query-embedding cache so queries embedded on earlier
    # runs are not re-encoded.
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'rb') as f:
            query_embeddings_cache = pickle.load(f)
        logger.info(f"Loaded query embeddings cache with {len(query_embeddings_cache)} entries.")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")

    # Load the existing production FAISS index if present; otherwise create a
    # new exact inner-product index matching the encoder's output dimension.
    dimension = encoder.get_sentence_embedding_dimension()
    if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
        faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"Loaded FAISS index from {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        faiss_index = faiss.IndexFlatIP(dimension)
        logger.info(f"Initialized new FAISS index with dimension {dimension}.")
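    # IndexFlatIP performs exact maximum inner-product search; when embeddings
    # are L2-normalized, inner product is equivalent to cosine similarity.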
					
						
    # Assemble the data pipeline; the response pool starts empty and is
    # populated below.
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache=query_embeddings_cache,
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
    )
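    # Note: the faiss_index object loaded above is not handed to the pipeline
    # directly; the pipeline is pointed at the same index via its file path.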
					
						
    if dialogues:
        # Collect candidate responses (with domain tags) from the dialogues.
        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
        data_pipeline.response_pool = response_pool

        # Save the response pool next to the index so FAISS ids can be mapped
        # back to response entries at query time.
        response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")

        # Embed the responses and write the populated index to disk.
        data_pipeline.compute_and_index_response_embeddings()
        data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"FAISS index saved at {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        logger.warning("No responses to embed. Skipping FAISS indexing.")

    # Persist the query-embedding cache for reuse on later runs.
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(query_embeddings_cache, f)
    logger.info(f"Query embeddings cache saved at {CACHE_FILE}.")

    logger.info("Pipeline completed successfully.")


if __name__ == "__main__":
    main()