File size: 3,798 Bytes
a763857 c7c1b4e 64e7c31 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 c7c1b4e a763857 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import json
from tqdm.auto import tqdm
from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from sentence_transformers import SentenceTransformer
from tf_data_pipeline import TFDataPipeline
from response_quality_checker import ResponseQualityChecker
from environment_setup import EnvironmentSetup
from logger_config import config_logger
# Module-level runtime setup (executed on import).
logger = config_logger(__name__)
# Quiet everything below WARNING so log chatter doesn't pollute the chat UI.
logger.setLevel("WARNING")
# Suppress TensorFlow C++ backend log spam (3 = errors only).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# NOTE(review): this constructs (and discards) a single disabled tqdm bar; it
# does not disable tqdm globally — confirm intent, as progress bars created
# elsewhere will still render.
tqdm(disable=True)
def run_chatbot_chat():
    """Launch an interactive retrieval-chatbot session.

    Steps:
      1. Initialize the runtime environment.
      2. Resolve FAISS index / response-pool paths for the selected environment.
      3. Load ``models/config.json`` (or fall back to default ChatbotConfig).
      4. Load the SentenceTransformer encoder.
      5. Build the TFDataPipeline, load + validate the FAISS index and
         response pool.
      6. Load the chatbot model and run the interactive chat loop.

    Any failure is logged and the function returns early; nothing is raised
    to the caller.
    """
    env = EnvironmentSetup()
    env.initialize()

    MODEL_DIR = "models"
    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")

    # Toggle 'production' or 'test' env
    ENVIRONMENT = "production"
    if ENVIRONMENT == "test":
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
    # The response pool lives next to its index as <name>_responses.json;
    # derive it once from whichever index path was selected above.
    RESPONSE_POOL_PATH = FAISS_INDEX_PATH.replace(".index", "_responses.json")

    # Load the config; fall back to defaults when models/config.json is absent.
    config_path = os.path.join(MODEL_DIR, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = ChatbotConfig.from_dict(json.load(f))
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")

    # Init SentenceTransformer; abort if the encoder cannot be loaded since
    # everything downstream depends on it.
    try:
        encoder = SentenceTransformer(config.pretrained_model)
        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer: {e}")
        return

    # Load FAISS index and response pool
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=encoder.tokenizer,
            encoder=encoder,
            response_pool=[],
            query_embeddings_cache={},
            index_type='IndexFlatIP',
            faiss_index_file_path=FAISS_INDEX_PATH,
        )

        # Check both artifacts up front so a missing file yields one clear
        # error instead of an exception mid-load.
        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
            logger.error("FAISS index or response pool file is missing.")
            return

        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")

        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

        # Validate dimension consistency between the index and the encoder.
        data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        logger.error(f"Failed to load or validate FAISS index: {e}")
        return

    # Run interactive chat
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=False)
    except Exception as e:
        logger.error(f"Interactive chat session failed: {e}")
# Script entry point: only start the chat session when run directly,
# not when imported as a module.
if __name__ == "__main__":
    run_chatbot_chat()
|