import os
import json
from pathlib import Path
from typing import List, Tuple

import gradio as gr
from sentence_transformers import SentenceTransformer

from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from tf_data_pipeline import TFDataPipeline
from response_quality_checker import ResponseQualityChecker
from environment_setup import EnvironmentSetup
from logger_config import config_logger

logger = config_logger(__name__)


def load_pipeline():
    """
    Load the config, FAISS index, response pool, SentenceTransformer encoder,
    and TFDataPipeline, then set up the chatbot and its quality checker.
    """
    MODEL_DIR = "models"
    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
    RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
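    # Expected on-disk layout (an assumption inferred from the paths above):
    #   models/
    #       config.json
    #       faiss_indices/
    #           faiss_index_production.index
    #           faiss_index_production_responses.json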

    # Load the persisted config if present; otherwise fall back to defaults.
    config_path = Path(MODEL_DIR) / "config.json"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
    else:
        config = ChatbotConfig()
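    # Either branch yields a complete ChatbotConfig; from_dict is assumed to
    # accept the dict produced when the config was saved.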

    env = EnvironmentSetup()
    env.initialize()

    encoder = SentenceTransformer(config.pretrained_model)
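    # config.pretrained_model is assumed to name a SentenceTransformer
    # checkpoint (a hub model ID or a local path); this one encoder embeds
    # queries and backs the retrieval pipeline below.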

    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH,
    )
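    # 'IndexFlatIP' is FAISS's exact inner-product index; with L2-normalized
    # embeddings, inner-product search ranks identically to cosine similarity.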

    # The response pool is expected to be a JSON list parallel to the FAISS
    # index, so that index row i maps to response i.
    if os.path.exists(FAISS_INDEX_PRODUCTION_PATH) and os.path.exists(RESPONSE_POOL_PATH):
        data_pipeline.load_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        data_pipeline.validate_faiss_index()
    else:
        logger.warning("FAISS index or response pool file is missing; the chatbot may not work properly.")

    chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
    quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)

    return chatbot, quality_checker


# Load everything once at import time so all Gradio callbacks share the same
# model, FAISS index, and quality checker.
chatbot, quality_checker = load_pipeline()


def respond(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
    """Generate a chatbot response; conversation context is handled internally."""
    if not message.strip():
        return "", history

    try:
        response, _, metrics, confidence = chatbot.chat(
            query=message,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=10,
        )
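        # chat() returns a 4-tuple unpacked here as (response, _, metrics,
        # confidence); only the response text is surfaced in the UI.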

        history.append((message, response))
        return "", history
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        error_message = "I apologize, but I encountered an error processing your request."
        history.append((message, error_message))
        return "", history


def main():
    """Build and return the Gradio interface; launching happens under __main__."""
    with gr.Blocks(
        title="Chatbot Demo",
        css="""
        .message-wrap { max-height: 800px !important; }
        .chatbot { min-height: 600px; }
        """
    ) as demo:
        gr.Markdown(
            """
            # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
            Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
            """
        )

        # Named chatbot_ui so it does not shadow the module-level chatbot model.
        chatbot_ui = gr.Chatbot(
            label="Conversation",
            container=True,
            height=600,
            show_label=True,
            elem_classes="chatbot"
        )
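        # The (user, bot) tuple history format is assumed here, matching the
        # List[Tuple[str, str]] history that respond() builds.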

        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                scale=8
            )
            send = gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=100
            )

        clear = gr.Button("Clear Conversation", variant="secondary")

        # Enter and the Send button share the same handler; Clear resets the
        # chat history and empties the input box.
        msg.submit(respond, [msg, chatbot_ui], [msg, chatbot_ui], queue=False)
        send.click(respond, [msg, chatbot_ui], [msg, chatbot_ui], queue=False)
        clear.click(lambda: ([], ""), outputs=[chatbot_ui, msg], queue=False)

    return demo


if __name__ == "__main__":
    demo = main()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
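    # To expose a temporary public URL for remote testing, Gradio also
    # supports demo.launch(share=True).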