import os
from pathlib import Path

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_openai import OpenAIEmbeddings

from src.chatbot import Chatbot
from src.db.db import DB
from src.llm.llm import OpenAILLM
from src.processor.processor import Processor
|
|
|
def create_chain(llm, db_retriever):
    """Assemble a ConversationalRetrievalChain over *db_retriever*.

    A fresh ConversationBufferMemory is created on every call, so each
    chain starts with an empty chat history. The chain is configured to
    return source documents alongside the answer.
    """
    # output_key="answer" tells the memory which chain output to record,
    # required because return_source_documents=True makes the chain
    # produce more than one output key.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )

    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=db_retriever,
        memory=chat_memory,
        return_source_documents=True,
    )
    return chain
|
|
|
|
|
|
|
# Model and vector-index configuration for this script.
MODEL_NAME = "gpt-3.5-turbo"

INDEX_NAME = "test"


# Module-level wiring, executed at import time: build the chatbot, LLM,
# embeddings, document processor, and vector DB once, then assemble the
# conversational chain used by get_response() below.
cb = Chatbot()

llm = OpenAILLM(model_name=MODEL_NAME).oai_llm

embeddings = OpenAIEmbeddings()

doc_processor = Processor()

db = DB(index_name=INDEX_NAME, embeddings=embeddings)

# NOTE(review): "db_retreiver" / "get_reriever" look like typos for
# "retriever", but get_reriever is the DB class's actual method name —
# fix it in src/db/db.py first before renaming anything here.
db_retreiver = db.get_reriever()

# NOTE(review): this calls Chatbot.create_conversational_chain rather than
# the local create_chain() defined above (which is otherwise unused) —
# confirm which implementation is intended.
conversation_chain = cb.create_conversational_chain(llm, db_retreiver)
|
|
|
|
|
def get_response(query: str):
    """Answer *query* using the module-level conversational chain.

    Returns the chain's raw result mapping, which includes the answer
    and (per the chain's configuration) the source documents.
    """
    chain_input = {"question": query}
    return conversation_chain(chain_input)
|
|
|
def add_file(filename: str) -> None:
    """Process *filename* into documents and insert them into the vector DB.

    The DB entry is keyed by the file's base name with its final extension
    stripped (e.g. "docs/report.pdf" -> "report").
    """
    docs = doc_processor.process(filename)
    # Path(filename).stem is the idiomatic equivalent of
    # os.path.splitext(os.path.basename(filename))[0].
    db.insert(docs, file_name_without_extension=Path(filename).stem)