Spaces:
Sleeping
Sleeping
File size: 5,878 Bytes
5eeafe5 31e0768 5eeafe5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from uuid import uuid4
import uvicorn
import os
import requests
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationTokenBufferMemory
from langchain.document_loaders import DirectoryLoader
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.chains import ConversationalRetrievalChain
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import Qdrant
from langchain.text_splitter import RecursiveCharacterTextSplitter
# ASGI application instance; the @app.get/@app.post/@app.delete route
# decorators further down in this file register their endpoints on it.
app = FastAPI()
# Make an OpenAI key available to the LangChain clients created below.
# setdefault keeps a key that is already present in the environment instead of
# clobbering it with the placeholder; the placeholder is only used when the
# variable is unset (same result as before in that case).
os.environ.setdefault("OPENAI_API_KEY", "your_api_key_here")
# Initialize components
# Loading Data
# File extension -> LangChain loader class used for files with that extension.
loaders = {'.csv': CSVLoader}


def create_directory_loader(file_type, directory_path):
    """Return a DirectoryLoader for every ``*<file_type>`` file below *directory_path*.

    The glob is recursive (``**/``). *file_type* must be a key of ``loaders``;
    any other extension raises ``KeyError``.
    """
    pattern = f"**/*{file_type}"
    loader_cls = loaders[file_type]
    return DirectoryLoader(path=directory_path, glob=pattern, loader_cls=loader_cls)
# Load every CSV under the data directory and split it into overlapping chunks.
# The directory may be overridden with the CSV_DATA_DIR environment variable;
# the previously hard-coded path remains the default.
CSV_DATA_DIR = os.environ.get("CSV_DATA_DIR", "/my_file/py_project/")
csv_file = create_directory_loader('.csv', CSV_DATA_DIR)
csv_document = csv_file.load()
if not csv_document:
    # Fail fast with an actionable message. An empty corpus would otherwise
    # presumably fail later, less clearly, while building the vector store /
    # BM25 index from the (empty) splits.
    raise RuntimeError(f"No CSV documents found under {CSV_DATA_DIR!r}")
# 1024-character chunks with 100 characters of overlap between neighbours.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
csv_splits = text_splitter.split_documents(csv_document)
# Loading Embeddings
# OpenAI embedding client; it is passed to the Qdrant vector store built below.
# NOTE(review): presumably authenticates via the OPENAI_API_KEY env var set above.
embeddings = OpenAIEmbeddings()
# Vector Store
def create_qdrant_retriever(documents, embedding, collection_name):
    """Embed *documents* into a fresh in-memory Qdrant collection.

    Despite its name, this returns the Qdrant vector store itself, not a
    retriever; call ``.as_retriever()`` on the result to obtain one.

    Args:
        documents: LangChain documents to index.
        embedding: embedding model used to vectorise the documents.
        collection_name: name of the Qdrant collection to create.
    """
    return Qdrant.from_documents(
        documents,
        embedding,
        collection_name=collection_name,
        location=":memory:",
    )
# In-memory Qdrant vector store holding the embedded CSV chunks
# (collection "csv_documents"); rebuilt from scratch on every process start.
csv_vectorstore = create_qdrant_retriever(csv_splits, embeddings, "csv_documents")
# Retrievers
def create_ensemble_retriever(splits, vectorstore):
    """Combine keyword (BM25) and dense (vector-store) retrieval.

    Args:
        splits: chunked documents used to build the BM25 keyword index.
        vectorstore: vector store queried for the dense half (top 2 hits).

    Returns:
        An EnsembleRetriever over ``[BM25, dense]``, built without explicit weights.
    """
    dense = vectorstore.as_retriever(search_kwargs={"k": 2})
    keyword = BM25Retriever.from_documents(splits)
    return EnsembleRetriever(retrievers=[keyword, dense])
# Hybrid BM25 + vector retriever over the CSV chunks; handed to the
# conversational retrieval chain further down.
csv_retriever = create_ensemble_retriever(csv_splits, csv_vectorstore)
# Perform similarity search
def perform_similarity_search(query, instance):
    """Return ``(page_content, score)`` of the best hit for *query* in *instance*.

    *instance* must provide ``similarity_search_with_score(query)`` returning
    ``(document, score)`` pairs; the first pair is taken as the best match.
    When there are no hits, returns ``("No relevant documents found.", None)``.
    """
    hits = instance.similarity_search_with_score(query)
    if not hits:
        return "No relevant documents found.", None
    best_doc, best_score = hits[0]
    return best_doc.page_content, best_score
# Smoke test executed at import time: looks up a sample customer message in the
# CSV vector store and prints the closest chunk together with its score.
# NOTE(review): this runs on every import/startup and presumably costs one
# embedding API call; consider moving it behind a debug flag or the __main__ guard.
query = '''
Hi Auto, Thanks for the service today. We have a problem: the warning light on the dash is still showing a fault with the filter.'''
content, score = perform_similarity_search(query, csv_vectorstore)
print("Result:\n", content, "\nScore:", score)
# Prompt Design
# Prompt for the answer ("combine documents") step of the retrieval chain.
# With chain_type='stuff', "{context}" receives the retrieved CSV chunks and
# "{question}" the (possibly rephrased) user question.
# NOTE(review): the literal contains small typos ("name.You",
# "efficient,more precise concise"); they are sent to the model verbatim.
template_csv = """
Your personal name is the "Auto_Assist Assistant", and you must use this name.You specialize in addressing a wide range of queries and concerns related to automotive services. Your core mission is to offer accurate, efficient,more precise concise solutions based on our extensive service records.
{context}:
Analyse the customer's query to understand the core issue or request, identifying the service category it falls under.
Search the CSV data to find exact issues or queries and the solutions provided across our broad service spectrum.
If a matching case is found, use the information to construct a concise, personalised response.
If no relevant records are found, kindly inform the customer that the specific query cannot be resolved at the moment and suggest the next best steps or encourage them to await expert review for more complex issues.
The output must be concise and must be UK British.
Question: {question}
Answer:
"""
prompt_csv = PromptTemplate(template=template_csv, input_variables=["context", "question"])
# Creating memory
# Token-capped conversation buffer (max 1000 tokens; the gpt-4o ChatOpenAI
# instance is presumably used only for token counting). History is exposed
# under "chat_history" as message objects, and output_key='answer' selects which
# chain output is stored.
# NOTE(review): this is a single module-level object, so any chain it is attached
# to shares one conversation history across every chat session.
memory = ConversationTokenBufferMemory(
    llm=ChatOpenAI(model_name="gpt-4o"),
    max_token_limit=1000,
    memory_key="chat_history",
    return_messages=True,
    output_key='answer',
)
# Creating csv Retrieval Chain
# Conversational RAG chain over the CSV service records. It returns the answer
# plus the source documents used to produce it.
#
# No `memory=` is attached on purpose. LangChain loads memory variables on top
# of the caller's inputs, so a chain-level memory keyed "chat_history" would
# silently replace the per-chat history passed by the /query endpoint. A single
# module-level memory would also be shared by every chat session, leaking one
# customer's conversation into another's. Callers must therefore supply
# "chat_history" (a list of (question, answer) tuples) on every call.
# NOTE(review): the per-chat history is passed in full (no token cap); consider
# trimming very long conversations before calling the chain.
csv_chain = ConversationalRetrievalChain.from_llm(
    llm=ChatOpenAI(model_name="gpt-4o"),
    retriever=csv_retriever,
    return_source_documents=True,
    chain_type='stuff',
    combine_docs_chain_kwargs={"prompt": prompt_csv},
    verbose=False,
)
# API Models
# Request body for POST /query/{chat_id}/.
# NOTE(review): this class name shadows `fastapi.Query` imported at the top of
# the file; rename one of them if fastapi's Query is ever needed here.
class Query(BaseModel):
    user_query: str  # the customer's message text
# Response body for POST /create_chat/.
class CreateChatResponse(BaseModel):
    chat_id: str  # uuid4 string identifying the new chat
# Response body for DELETE /chats/{chat_id}/.
class DeleteChatResponse(BaseModel):
    message: str  # human-readable confirmation
# In-memory storage for chat histories
# Maps chat_id -> list of (user_query, answer) tuples. It lives in this process
# only: it is lost on restart and not shared between worker processes.
chat_histories = {}
# API Endpoints
@app.post("/create_chat/", response_model=CreateChatResponse)
async def create_chat():
    """Open a new, empty chat session and return its generated id."""
    new_chat_id = f"{uuid4()}"
    chat_histories[new_chat_id] = []
    return {"chat_id": new_chat_id}
@app.post("/query/{chat_id}/")
async def chatbot_endpoint(chat_id: str, query: Query):
    """Answer ``query.user_query`` within the chat session *chat_id*.

    The question is run through ``csv_chain`` together with the session's prior
    (question, answer) turns. The new turn is then appended to the session.

    Returns:
        ``{"response": <answer text>}``.

    Raises:
        HTTPException: 404 if *chat_id* is not a known chat.
    """
    # FastAPI re-exports Starlette's helper for running blocking callables
    # in a worker thread.
    from fastapi.concurrency import run_in_threadpool

    if chat_id not in chat_histories:
        raise HTTPException(status_code=404, detail="Chat not found")
    chat_history = chat_histories[chat_id]
    # csv_chain(...) is a synchronous call that talks to the OpenAI-backed LLM.
    # Calling it directly inside this coroutine would block the event loop for
    # every other request, so run it in the threadpool instead. A snapshot of the
    # history is passed so concurrent requests on the same chat don't change it
    # mid-call.
    result = await run_in_threadpool(
        csv_chain,
        {"question": query.user_query, "chat_history": list(chat_history)},
    )
    answer = result["answer"]
    # Mutate the stored list in place. It is deliberately not re-assigned into
    # chat_histories, so a chat deleted while the chain was running stays deleted.
    chat_history.append((query.user_query, answer))
    return {"response": answer}
@app.get("/chats/", response_model=dict)
async def list_chats():
    """Return the ids of all chat sessions currently held in memory."""
    ids = [*chat_histories]
    return {"chat_ids": ids}
@app.get("/chats/{chat_id}/", response_model=dict)
async def get_chat_history(chat_id: str):
    """Return the stored (question, answer) turns of *chat_id*; 404 if unknown."""
    try:
        turns = chat_histories[chat_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Chat not found") from None
    return {"chat_history": turns}
@app.delete("/chats/{chat_id}/", response_model=DeleteChatResponse)
async def delete_chat(chat_id: str):
    """Forget the chat session *chat_id*; 404 if it does not exist."""
    try:
        del chat_histories[chat_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Chat not found") from None
    return {"message": "Chat deleted successfully"}
if __name__ == "__main__":
    # Serve the API on all interfaces. The previous line was the syntax error
    # `port=container_2:8006'`.
    # NOTE(review): 8006 is assumed to be the intended port; "container_2" looked
    # like a deployment hostname and is not needed when binding to 0.0.0.0.
    uvicorn.run(app, host="0.0.0.0", port=8006)
|