import logging
import os
import pickle
from datetime import datetime
from functools import cached_property
from typing import List

import faiss
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

NO_DATA_MESSAGE = "I apologize, but I encountered an error processing your request."


class LocalEmbedding:
    """Local embedding model wrapper."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.vector_dim = self.model.get_sentence_embedding_dimension()

    def get_embedding(self, text: str) -> List[float]:
        """Get an embedding for the text using the local model."""
        try:
            embedding = self.model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}")
            return []
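
# Usage sketch (illustrative; not executed by the app): the wrapper returns a
# plain Python list, ready for numpy/FAISS. For "all-MiniLM-L6-v2" the vector
# has 384 dimensions.
#
#   emb = LocalEmbedding()
#   vec = emb.get_embedding("How are cars exported?")
#   assert len(vec) == emb.vector_dim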


class DeepSeekChat(BaseModel):
    """DeepSeek chat model wrapper."""

    api_key: str = Field(default=os.getenv("DEEPSEEK_API_KEY"))
    base_url: str = Field(default="https://api.siliconflow.cn/v1")

    class Config:
        """Pydantic config class."""

        arbitrary_types_allowed = True
        # Note: on pydantic v1, `keep_untouched = (cached_property,)` is also
        # required here for the cached `client` property below; pydantic v2
        # handles cached_property natively.

    @cached_property
    def client(self) -> OpenAI:
        """Create and cache the OpenAI-compatible client instance."""
        return OpenAI(api_key=self.api_key, base_url=self.base_url)

    def chat(
        self,
        system_message: str,
        user_message: str,
        context: str = "",
        model: str = "deepseek-ai/DeepSeek-V3",
        max_tokens: int = 1024,
        temperature: float = 0.7,
    ) -> str:
        """Send a chat request to the DeepSeek API."""
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        if context:
            messages.append({"role": "user", "content": context})
        messages.append({"role": "user", "content": user_message})

        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error in DeepSeek API call: {e}")
            return NO_DATA_MESSAGE
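
# Usage sketch (illustrative; assumes DEEPSEEK_API_KEY is set in .env):
#
#   chat_model = DeepSeekChat()
#   reply = chat_model.chat(
#       system_message="You are a helpful assistant.",
#       user_message="Summarize the export process in one sentence.",
#   )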


class PDFChatbot:
    def __init__(self, index_path: str, texts_path: str, model_name: str = "all-MiniLM-L6-v2"):
        if not os.getenv("DEEPSEEK_API_KEY"):
            raise ValueError("DEEPSEEK_API_KEY not found in .env file")

        logger.info("Initializing models...")
        self.chat_model = DeepSeekChat()
        self.embedding_model = LocalEmbedding(model_name)

        logger.info("Loading vector database...")
        self.index = faiss.read_index(index_path)
        with open(texts_path, "rb") as f:
            self.texts = pickle.load(f)

        self.system_message = (
            "You are a knowledgeable AI assistant that helps users understand the content "
            "of the provided document. Use the context provided to answer questions "
            "accurately and comprehensively. If the answer cannot be found in the context, "
            "clearly state that the information is not available in the document."
        )

        self.log_file = f"pdf_chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
        self.log_conversation("Conversation started")

    def log_conversation(self, message, role="system"):
        """Append a timestamped message to the conversation log file."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(f"[{timestamp}] {role}: {message}\n")
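
    # Example log line written by log_conversation (format for reference):
    #   [2024-01-01 12:00:00] user: What is the import tax of Ethiopia?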

    def get_relevant_context(self, query: str, k: int = 3) -> str:
        """Return the k most relevant text chunks for the query, joined by newlines."""
        try:
            query_embedding = self.embedding_model.get_embedding(query)
            if not query_embedding:
                return ""

            # index.search returns (distances, indices), each of shape (n_queries, k).
            query_vector = np.array([query_embedding]).astype("float32")
            distances, indices = self.index.search(query_vector, k)

            # FAISS pads indices with -1 when fewer than k neighbors exist.
            relevant_texts = [self.texts[i] for i in indices[0] if i != -1]
            return "\n".join(relevant_texts)
        except Exception as e:
            logger.error(f"Error getting relevant context: {e}")
            return ""

    def chat(self, message, history):
        """Process a chat message and return the model's response."""
        try:
            self.log_conversation(message, "user")

            context = self.get_relevant_context(message)
            context_prompt = (
                f"Based on the following context from the document:\n{context}\n\n"
                "Please answer the question."
                if context
                else ""
            )

            response = self.chat_model.chat(
                system_message=self.system_message,
                user_message=message,
                context=context_prompt,
            )

            self.log_conversation(response, "assistant")
            return response
        except Exception as e:
            logger.error(f"Error in chat: {e}")
            return NO_DATA_MESSAGE
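

# This script assumes `vectordb/1.index` and `vectordb/1.pkl` already exist.
# Below is a minimal sketch of how such files could be built from pre-chunked
# text. It is illustrative only: `build_vectordb` is a hypothetical helper
# (not part of the original pipeline), and a flat L2 index is an assumption.
def build_vectordb(chunks: List[str], index_path: str, texts_path: str) -> None:
    """Embed text chunks, index them with FAISS, and pickle the raw texts."""
    model = LocalEmbedding()
    vectors = np.array([model.get_embedding(c) for c in chunks]).astype("float32")
    index = faiss.IndexFlatL2(model.vector_dim)  # exact L2 search (assumption)
    index.add(vectors)
    faiss.write_index(index, index_path)
    with open(texts_path, "wb") as f:
        pickle.dump(chunks, f)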


def main():
    try:
        index_path = "vectordb/1.index"
        texts_path = "vectordb/1.pkl"

        chatbot = PDFChatbot(index_path, texts_path)

        iface = gr.ChatInterface(
            fn=chatbot.chat,
            title="CEVauto AI expert",
            description="Ask questions about car exporting. I'll help you understand the document's contents.",
            theme=gr.themes.Soft(),
            examples=[
                "What is the car import regulation of UAE?",
                "What should I do when a customer complains that our quotation is too high?",
                "What is the import tax of Ethiopia?",
            ],
        )

        iface.launch()
    except Exception as e:
        logger.error(f"Failed to initialize chatbot: {e}")
        raise


if __name__ == "__main__":
    main()