from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from uuid import uuid4
import uvicorn
import os
import requests
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationTokenBufferMemory
from langchain.document_loaders import DirectoryLoader
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.chains import ConversationalRetrievalChain
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import Qdrant
from langchain.text_splitter import RecursiveCharacterTextSplitter
app = FastAPI()

# Set the OpenAI API key (replace the placeholder with a real key)
os.environ["OPENAI_API_KEY"] = "your_api_key_here"
# Initialize components
# Loading Data
loaders = {'.csv': CSVLoader}

def create_directory_loader(file_type, directory_path):
    """Build a DirectoryLoader for every file of the given type under directory_path."""
    return DirectoryLoader(
        path=directory_path,
        glob=f"**/*{file_type}",
        loader_cls=loaders[file_type],
    )
csv_file = create_directory_loader('.csv', '/my_file/py_project/')
csv_document = csv_file.load()
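
# The `loaders` mapping makes it straightforward to register further file types
# alongside CSV. A minimal sketch (assumes .txt files exist under the same
# directory; TextLoader ships with langchain.document_loaders):
# from langchain.document_loaders import TextLoader
# loaders['.txt'] = TextLoader
# txt_documents = create_directory_loader('.txt', '/my_file/py_project/').load()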

# Split the loaded documents into overlapping chunks for retrieval
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
csv_splits = text_splitter.split_documents(csv_document)

# Loading Embeddings
embeddings = OpenAIEmbeddings()

# Vector Store
def create_qdrant_vectorstore(documents, embedding, collection_name):
    """Build an in-memory Qdrant collection from the document chunks."""
    qdrant = Qdrant.from_documents(
        documents,
        embedding,
        location=":memory:",
        collection_name=collection_name,
    )
    return qdrant
csv_vectorstore = create_qdrant_vectorstore(csv_splits, embeddings, "csv_documents")
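
# location=":memory:" keeps the collection in-process, so it is rebuilt on every
# restart. A sketch of targeting a persistent Qdrant server instead (the URL is
# illustrative):
# qdrant = Qdrant.from_documents(
#     csv_splits, embeddings,
#     url="http://localhost:6333",
#     collection_name="csv_documents",
# )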

# Retrievers
def create_ensemble_retriever(splits, vectorstore):
    """Combine sparse (BM25) and dense (Qdrant) retrieval into one retriever."""
    qdrant_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
    bm25_retriever = BM25Retriever.from_documents(splits)
    ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, qdrant_retriever])
    return ensemble_retriever
csv_retriever = create_ensemble_retriever(csv_splits, csv_vectorstore)
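
# EnsembleRetriever weights its retrievers equally unless told otherwise. A
# sketch of explicit weighting, favouring the dense retriever (values are
# illustrative, not tuned):
# weighted_retriever = EnsembleRetriever(
#     retrievers=[BM25Retriever.from_documents(csv_splits),
#                 csv_vectorstore.as_retriever(search_kwargs={"k": 2})],
#     weights=[0.4, 0.6],
# )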

# Perform a similarity search to sanity-check the vector store
def perform_similarity_search(query, instance):
    found_docs = instance.similarity_search_with_score(query)
    if found_docs:
        document, score = found_docs[0]
        return document.page_content, score
    else:
        return "No relevant documents found.", None
query = '''
Hi Auto, Thanks for the service today. We have a problem: the warning light on the dash is still showing a fault with the filter.'''
content, score = perform_similarity_search(query, csv_vectorstore)
print("Result:\n", content, "\nScore:", score)

# Prompt Design
template_csv = """
Your name is the "Auto_Assist Assistant", and you must use this name. You specialise in addressing a wide range of queries and concerns related to automotive services. Your core mission is to offer accurate, efficient and concise solutions based on our extensive service records.
{context}:
Analyse the customer's query to understand the core issue or request, identifying the service category it falls under.
Search the CSV data to find matching issues or queries and the solutions provided across our broad service spectrum.
If a matching case is found, use the information to construct a concise, personalised response.
If no relevant records are found, kindly inform the customer that the specific query cannot be resolved at the moment and suggest the next best steps, or encourage them to await expert review for more complex issues.
The output must be concise and written in British English.
Question: {question}
Answer:
"""
prompt_csv = PromptTemplate(template=template_csv, input_variables=["context", "question"])
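
# A quick check of what the model receives once the placeholders are filled
# (values are illustrative):
# print(prompt_csv.format(
#     context="<retrieved CSV rows>",
#     question="The filter warning light is still on after today's service.",
# ))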

# Creating memory
memory = ConversationTokenBufferMemory(
    llm=ChatOpenAI(model_name="gpt-4o"),
    max_token_limit=1000,
    memory_key="chat_history",
    return_messages=True,
    output_key='answer',
)

# Creating the CSV retrieval chain
# NOTE: `memory` is shared across all requests; when it is attached, the chain's
# own buffer takes precedence over any `chat_history` passed at call time. Drop
# `memory=memory` if strict per-chat isolation is required.
csv_chain = ConversationalRetrievalChain.from_llm(
    llm=ChatOpenAI(model_name="gpt-4o"),
    retriever=csv_retriever,
    return_source_documents=True,
    chain_type='stuff',
    combine_docs_chain_kwargs={"prompt": prompt_csv},
    memory=memory,
    verbose=False,
)
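
# Example of invoking the chain directly, outside the API. `chat_history` is a
# list of (human, ai) tuples and is empty on the first turn:
# result = csv_chain({"question": "Is the filter replacement under warranty?",
#                     "chat_history": []})
# print(result["answer"])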

# API Models
class Query(BaseModel):
    user_query: str

class CreateChatResponse(BaseModel):
    chat_id: str

class DeleteChatResponse(BaseModel):
    message: str

# In-memory storage for chat histories
chat_histories = {}

# API Endpoints (route paths are illustrative)
@app.post("/chats", response_model=CreateChatResponse)
async def create_chat():
    chat_id = str(uuid4())
    chat_histories[chat_id] = []
    return {"chat_id": chat_id}

@app.post("/chats/{chat_id}/message")
async def chatbot_endpoint(chat_id: str, query: Query):
    if chat_id not in chat_histories:
        raise HTTPException(status_code=404, detail="Chat not found")
    chat_history = chat_histories[chat_id]
    result = csv_chain({"question": query.user_query, "chat_history": chat_history})
    chat_history.append((query.user_query, result["answer"]))
    chat_histories[chat_id] = chat_history
    return {"response": result["answer"]}

@app.get("/chats")
async def list_chats():
    return {"chat_ids": list(chat_histories.keys())}

@app.get("/chats/{chat_id}/history")
async def get_chat_history(chat_id: str):
    if chat_id not in chat_histories:
        raise HTTPException(status_code=404, detail="Chat not found")
    return {"chat_history": chat_histories[chat_id]}

@app.delete("/chats/{chat_id}", response_model=DeleteChatResponse)
async def delete_chat(chat_id: str):
    if chat_id not in chat_histories:
        raise HTTPException(status_code=404, detail="Chat not found")
    del chat_histories[chat_id]
    return {"message": "Chat deleted successfully"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)