File size: 3,798 Bytes
a763857
 
c7c1b4e
64e7c31
c7c1b4e
 
 
a763857
 
 
 
 
c7c1b4e
 
 
 
a763857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c1b4e
a763857
c7c1b4e
 
a763857
c7c1b4e
a763857
c7c1b4e
a763857
 
c7c1b4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a763857
c7c1b4e
 
 
 
a763857
c7c1b4e
 
a763857
 
 
c7c1b4e
 
 
 
 
 
 
 
 
 
a763857
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os

# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when it is first
# loaded, so the variable has to be set *before* importing any module that may
# pull in TensorFlow (presumably tf_data_pipeline / chatbot_model -- TODO
# confirm). Assigning it after those imports leaves import-time C++ logging on.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Intent: silence tqdm progress bars for the whole process. Constructing
# ``tqdm(disable=True)`` only creates a single disabled bar and changes no
# defaults, so it cannot achieve that. tqdm >= 4.66 applies ``TQDM_*``
# environment overrides that are read when tqdm is imported, hence this must
# run before the tqdm import below; older tqdm versions simply ignore it.
# ``setdefault`` leaves any value already present in the environment alone.
os.environ.setdefault("TQDM_DISABLE", "1")

import json
from tqdm.auto import tqdm
from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from sentence_transformers import SentenceTransformer
from tf_data_pipeline import TFDataPipeline
from response_quality_checker import ResponseQualityChecker
from environment_setup import EnvironmentSetup
from logger_config import config_logger

logger = config_logger(__name__)
# Drop this module's INFO records so the interactive session stays quiet;
# warnings and errors from this logger are still emitted.
logger.setLevel("WARNING")
            
def _resolve_index_paths(model_dir, environment):
    """Return ``(faiss_index_path, response_pool_path)`` for *environment*.

    The index lives at ``<model_dir>/faiss_indices/faiss_index_<environment>.index``
    and its response pool sits next to it with the ``.index`` extension
    replaced by ``_responses.json``.

    Raises:
        ValueError: if *environment* is neither ``"production"`` nor ``"test"``.
    """
    if environment not in ("production", "test"):
        raise ValueError(
            f"Unknown environment {environment!r}; expected 'production' or 'test'."
        )
    faiss_indices_dir = os.path.join(model_dir, "faiss_indices")
    index_path = os.path.join(faiss_indices_dir, f"faiss_index_{environment}.index")
    # Swap only the trailing extension; str.replace(".index", ...) would also
    # rewrite any ".index" that happens to occur earlier in model_dir.
    stem, _ext = os.path.splitext(index_path)
    return index_path, stem + "_responses.json"


def _load_config(model_dir):
    """Load ChatbotConfig from ``<model_dir>/config.json``, else return defaults.

    Errors while reading or parsing an existing config.json are not caught
    here; they propagate to the caller.
    """
    config_path = os.path.join(model_dir, "config.json")
    if not os.path.exists(config_path):
        logger.warning("No config.json found. Using default ChatbotConfig.")
        return ChatbotConfig()
    with open(config_path, "r", encoding="utf-8") as f:
        config = ChatbotConfig.from_dict(json.load(f))
    logger.info(f"Loaded ChatbotConfig from {config_path}")
    return config


def _load_data_pipeline(config, encoder, faiss_index_path, response_pool_path):
    """Build a TFDataPipeline, load its FAISS index and response pool, and validate them."""
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],  # replaced below by the pool loaded from disk
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=faiss_index_path
    )

    data_pipeline.load_faiss_index(faiss_index_path)
    logger.info(f"FAISS index loaded from {faiss_index_path}.")

    with open(response_pool_path, "r", encoding="utf-8") as f:
        data_pipeline.response_pool = json.load(f)
    logger.info(f"Response pool loaded from {response_pool_path}.")
    logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

    # Validate dimension consistency
    data_pipeline.validate_faiss_index()
    logger.info("FAISS index and response pool validated successfully.")
    return data_pipeline


def run_chatbot_chat(model_dir="models", environment="production"):
    """Load the retrieval chatbot's assets and run an interactive chat session.

    Steps: initialize the environment, read the config, load the
    SentenceTransformer encoder, load and validate the FAISS index plus
    response pool, then hand off to ``RetrievalChatbot.run_interactive_chat``.

    Args:
        model_dir: Directory holding ``config.json``, the saved model and the
            ``faiss_indices`` sub-directory.
        environment: ``"production"`` or ``"test"``; selects which FAISS index
            and response pool to load.

    Returns:
        True if the chat session finished without raising, False if any step
        failed (the failure is logged).

    Raises:
        ValueError: for an unknown *environment*.
        Exceptions from reading or parsing an existing config.json are not
        caught and propagate.
    """
    # Validate the arguments before doing any expensive setup.
    faiss_index_path, response_pool_path = _resolve_index_paths(model_dir, environment)

    EnvironmentSetup().initialize()
    config = _load_config(model_dir)

    # Fail fast on missing artifacts, before loading the encoder or pipeline.
    if not os.path.exists(faiss_index_path) or not os.path.exists(response_pool_path):
        logger.error("FAISS index or response pool file is missing.")
        return False

    try:
        encoder = SentenceTransformer(config.pretrained_model)
    except Exception as e:
        logger.exception(f"Failed to load SentenceTransformer: {e}")
        return False
    logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")

    try:
        data_pipeline = _load_data_pipeline(
            config, encoder, faiss_index_path, response_pool_path
        )
    except Exception as e:
        logger.exception(f"Failed to load or validate FAISS index: {e}")
        return False

    try:
        chatbot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)

        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=False)
    except Exception as e:
        logger.exception(f"Interactive chat session failed: {e}")
        return False
    return True


if __name__ == "__main__":
    # Exit non-zero on failure so shells and CI can detect a broken setup.
    raise SystemExit(0 if run_chatbot_chat() else 1)