# main.py: last commit 31e0768 ("Update main.py") by osmanYusuf (verified, virus scan clean); 5.88 kB.
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from uuid import uuid4
import uvicorn
import os
import requests
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationTokenBufferMemory
from langchain.document_loaders import DirectoryLoader
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.chains import ConversationalRetrievalChain
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import Qdrant
from langchain.text_splitter import RecursiveCharacterTextSplitter
app = FastAPI()
# Load environment variables
# Only fall back to the placeholder when no key is configured. A plain assignment
# would overwrite a real OPENAI_API_KEY supplied by the deployment environment.
os.environ.setdefault("OPENAI_API_KEY", "your_api_key_here")
# Initialize components
# Loading Data
# Maps a file extension to the LangChain loader class used for files of that type.
loaders = {'.csv': CSVLoader}
def create_directory_loader(file_type, directory_path):
    """Return a DirectoryLoader that recursively collects files ending in *file_type*.

    The loader class comes from the module-level ``loaders`` mapping, so an
    extension that is not registered there raises ``KeyError``.
    """
    pattern = f"**/*{file_type}"
    loader_cls = loaders[file_type]
    return DirectoryLoader(path=directory_path, glob=pattern, loader_cls=loader_cls)
# Directory scanned (recursively) for .csv files. It can be overridden with
# CSV_DATA_DIR; the default is the previously hard-coded path.
CSV_DATA_DIR = os.environ.get("CSV_DATA_DIR", "/my_file/py_project/")
csv_file = create_directory_loader('.csv', CSV_DATA_DIR)
csv_document = csv_file.load()
if not csv_document:
    # Nothing downstream can usefully index an empty corpus. Fail here with the
    # actual cause rather than later inside the vector-store / BM25 builders.
    raise RuntimeError(
        f"No CSV documents were loaded from {CSV_DATA_DIR!r}; "
        "set CSV_DATA_DIR to a directory containing .csv files."
    )
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
csv_splits = text_splitter.split_documents(csv_document)
# Loading Embeddings
embeddings = OpenAIEmbeddings()
# Vector Store
def create_qdrant_retriever(documents, embedding, collection_name, location=":memory:"):
    """Embed *documents* into a Qdrant collection and return the vector store.

    Despite its name, this returns the ``Qdrant`` vector store itself, not a
    retriever; callers obtain a retriever with ``.as_retriever()``.

    Args:
        documents: LangChain documents to index.
        embedding: embedding model used to vectorise the documents.
        collection_name: name of the Qdrant collection to create.
        location: passed straight through to ``Qdrant.from_documents``. The
            default ``":memory:"`` keeps the index in-process, matching the
            previous hard-coded behaviour.
    """
    return Qdrant.from_documents(
        documents,
        embedding,
        location=location,
        collection_name=collection_name,
    )


csv_vectorstore = create_qdrant_retriever(csv_splits, embeddings, "csv_documents")
# Retrievers
def create_ensemble_retriever(splits, vectorstore):
    """Combine keyword (BM25) and dense-vector retrieval over the same corpus.

    Args:
        splits: chunked documents the BM25 index is built from.
        vectorstore: vector store indexing those same chunks.

    Returns:
        An EnsembleRetriever over ``[bm25, vector]`` with its default weighting.
    """
    # The vector side contributes only its two nearest neighbours.
    dense = vectorstore.as_retriever(search_kwargs={"k": 2})
    sparse = BM25Retriever.from_documents(splits)
    return EnsembleRetriever(retrievers=[sparse, dense])


csv_retriever = create_ensemble_retriever(csv_splits, csv_vectorstore)
# Perform similarity search
def perform_similarity_search(query, instance):
    """Return the text and score of the best match for *query* in *instance*.

    *instance* must provide ``similarity_search_with_score(query)`` returning a
    sequence of ``(document, score)`` pairs. Only the first pair is used.

    Returns:
        ``(page_content, score)`` of the first hit, or
        ``("No relevant documents found.", None)`` when there are no hits.
    """
    hits = instance.similarity_search_with_score(query)
    if not hits:
        return "No relevant documents found.", None
    best_doc, best_score = hits[0]
    return best_doc.page_content, best_score
# Smoke check executed at import time: look up the closest CSV chunk for a
# sample customer message and print it alongside its similarity score.
query = (
    "\n"
    "Hi Auto, Thanks for the service today. We have a problem: the warning "
    "light on the dash is still showing a fault with the filter."
)
content, score = perform_similarity_search(query, csv_vectorstore)
print("Result:\n", content, "\nScore:", score)
# Prompt Design
# Prompt for the chain's "stuff" answer step (passed via combine_docs_chain_kwargs):
# {context} receives the retrieved CSV chunks and {question} the user's question.
# NOTE(review): the literal below is sent to the model verbatim, so any edit
# changes model behaviour. It contains a few typos (e.g. "name.You",
# "efficient,more") that are intentionally left untouched here.
template_csv = """
Your personal name is the "Auto_Assist Assistant", and you must use this name.You specialize in addressing a wide range of queries and concerns related to automotive services. Your core mission is to offer accurate, efficient,more precise concise solutions based on our extensive service records.
{context}:
Analyse the customer's query to understand the core issue or request, identifying the service category it falls under.
Search the CSV data to find exact issues or queries and the solutions provided across our broad service spectrum.
If a matching case is found, use the information to construct a concise, personalised response.
If no relevant records are found, kindly inform the customer that the specific query cannot be resolved at the moment and suggest the next best steps or encourage them to await expert review for more complex issues.
The output must be concise and must be UK British.
Question: {question}
Answer:
"""
# Both placeholders used in the template must be declared as input variables.
prompt_csv = PromptTemplate(template=template_csv, input_variables=["context", "question"])
# Creating memory
# Conversation buffer capped at 1000 tokens, exposed under the key "chat_history"
# as message objects and keyed on the chain's 'answer' output. The gpt-4o client
# is presumably used by the memory for token counting (verify against LangChain).
# NOTE(review): this is one module-level instance. Any chain it is attached to
# therefore shares a single history across all callers.
memory = ConversationTokenBufferMemory(
    memory_key="chat_history",
    output_key='answer',
    return_messages=True,
    max_token_limit=1000,
    llm=ChatOpenAI(model_name="gpt-4o"),
)
# Creating csv Retrieval Chain
# No `memory=` is attached. The chain is a single module-level object shared by
# every chat, so an attached memory would (a) replace the per-chat
# `chat_history` that callers pass in with its own stored history, and
# (b) merge turns from different users into one conversation. Callers (the
# /query endpoint) supply `chat_history` explicitly on every call instead.
csv_chain = ConversationalRetrievalChain.from_llm(
    llm=ChatOpenAI(model_name="gpt-4o"),
    retriever=csv_retriever,
    return_source_documents=True,
    chain_type='stuff',
    combine_docs_chain_kwargs={"prompt": prompt_csv},
    verbose=False,
)
# API Models
# Request body for POST /query/{chat_id}/.
# NOTE(review): this class shadows the `Query` imported from fastapi at the top of
# the file; rename it if fastapi's Query parameter helper is ever needed here.
class Query(BaseModel):
    user_query: str  # the user's message text


# Response body for POST /create_chat/.
class CreateChatResponse(BaseModel):
    chat_id: str  # uuid4 string identifying the new chat


# Response body for DELETE /chats/{chat_id}/.
class DeleteChatResponse(BaseModel):
    message: str  # human-readable confirmation
# Process-local store of conversations: chat id -> list of (question, answer)
# tuples. Contents are lost on restart and are not shared between worker processes.
chat_histories = dict()
# API Endpoints
@app.post("/create_chat/", response_model=CreateChatResponse)
async def create_chat():
    """Register a new, empty conversation and return its freshly generated id."""
    new_id = f"{uuid4()}"
    chat_histories[new_id] = []
    return {"chat_id": new_id}
# Upper bound on past (question, answer) turns forwarded to the chain per request,
# so the prompt does not grow without limit as a conversation gets longer.
MAX_HISTORY_TURNS = 10


@app.post("/query/{chat_id}/")
async def chatbot_endpoint(chat_id: str, query: Query):
    """Answer a user message within an existing chat and record the exchange.

    Args:
        chat_id: id previously returned by ``/create_chat/``.
        query: request body carrying ``user_query``.

    Returns:
        ``{"response": <answer text>}``.

    Raises:
        HTTPException: 404 when ``chat_id`` is unknown.
    """
    # fastapi is already a dependency of this module; imported locally so the
    # block stays self-contained.
    from fastapi.concurrency import run_in_threadpool

    if chat_id not in chat_histories:
        raise HTTPException(status_code=404, detail="Chat not found")
    chat_history = chat_histories[chat_id]
    inputs = {
        "question": query.user_query,
        # Slicing also hands the chain a snapshot rather than the live list.
        "chat_history": chat_history[-MAX_HISTORY_TURNS:],
    }
    # csv_chain(...) is synchronous and calls OpenAI over the network. Running it
    # in a worker thread keeps the event loop free to serve other requests;
    # calling it directly inside this coroutine would block the whole server.
    result = await run_in_threadpool(csv_chain, inputs)
    answer = result["answer"]
    # chat_history is the same list object stored in chat_histories, so an
    # in-place append is enough. Deliberately not re-assigning
    # chat_histories[chat_id]: that would resurrect a chat deleted while the
    # answer was being generated.
    chat_history.append((query.user_query, answer))
    return {"response": answer}
@app.get("/chats/", response_model=dict)
async def list_chats():
    """Return the ids of every conversation currently held in memory."""
    ids = [*chat_histories]
    return {"chat_ids": ids}
@app.get("/chats/{chat_id}/", response_model=dict)
async def get_chat_history(chat_id: str):
    """Return the stored (question, answer) turns of one conversation.

    Raises:
        HTTPException: 404 when ``chat_id`` is unknown.
    """
    if chat_id in chat_histories:
        return {"chat_history": chat_histories[chat_id]}
    raise HTTPException(status_code=404, detail="Chat not found")
@app.delete("/chats/{chat_id}/", response_model=DeleteChatResponse)
async def delete_chat(chat_id: str):
    """Forget a conversation and all of its stored turns.

    Raises:
        HTTPException: 404 when ``chat_id`` is unknown.
    """
    missing = object()  # sentinel: distinguishes "absent" from any stored value
    if chat_histories.pop(chat_id, missing) is missing:
        raise HTTPException(status_code=404, detail="Chat not found")
    return {"message": "Chat deleted successfully"}
if __name__ == "__main__":
    # Listen on all interfaces (e.g. so the service is reachable from outside a
    # container). The port defaults to 8006 and may be overridden with PORT.
    # Previously this line was a syntax error: `port=container_2:8006'`.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8006")))