Mehrdad Esmaeili
Create app.py
eda774c
raw
history blame
No virus
2.08 kB
import os

import cohere
import gradio as gr
import openai
from langchain.chains import RetrievalQA
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.docstore.document import Document
from langchain.document_loaders import TextLoader
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import Cohere
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from tqdm import tqdm
# Build the retrieval index once, at import time: load every file in the
# corpus directory, split the text into chunks, embed the chunks, store them
# in Chroma and wrap the store in a retrieval-QA chain (`qa`).
documents = []
path = './bios/'  # alternative corpus directory used previously: './augBios/'

# Sorted so the load order is the same on every run; os.listdir() returns
# entries in arbitrary order.
for file in sorted(os.listdir(path)):
    file_path = os.path.join(path, file)
    # Skip sub-directories and other non-file entries: TextLoader can only
    # open regular files, and a stray folder would abort the whole start-up.
    if not os.path.isfile(file_path):
        continue
    # NOTE(review): 'unicode_escape' decodes the bytes as Latin-1 and expands
    # backslash escapes. Kept as in the original; confirm the corpus is not
    # UTF-8, otherwise non-ASCII characters are decoded incorrectly.
    loader = TextLoader(file_path, encoding='unicode_escape')
    documents += loader.load()
print(len(documents))

# Splitter configured with chunk_size=500 and no overlap between chunks.
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)
texts = text_splitter.split_documents(documents)

# Cohere embeddings are used here; OpenAIEmbeddings() was the alternative
# tried previously.
embeddings = CohereEmbeddings()
docsearch = Chroma.from_documents(texts, embeddings)

# k=1: only the single best-matching chunk is retrieved per query.
# return_source_documents=True so `predict` can report where the answer
# came from.
qa = RetrievalQA.from_chain_type(
    llm=Cohere(),
    chain_type="stuff",
    retriever=docsearch.as_retriever(search_kwargs={'k': 1}),
    return_source_documents=True,
)
def predict(message, history):
    """Answer one chat message with the module-level retrieval-QA chain.

    Args:
        message: The user's latest message.
        history: Earlier turns passed in by the chat UI. Unused: every
            question is answered independently of the conversation so far.

    Returns:
        A single string containing the chain's answer followed by the
        metadata of the retrieved source document (when one was returned).
    """
    # Steer the model toward a short "title-Author" style answer.
    query = message + '? just give me the book title-Author'
    result = qa({"query": query})

    answer = result['result']
    sources = result["source_documents"]
    if not sources:
        # Nothing was retrieved; return the bare answer instead of raising
        # IndexError on sources[0].
        return answer

    # Return ONE string. The chat UI expects a string reply; returning a
    # tuple makes the chatbot component treat the first element as a file
    # reference rather than as text.
    return f'{answer} || source is==> {sources[0].metadata}'
# Wrap `predict` in Gradio's chat interface and start serving it.
demo = gr.ChatInterface(predict)
demo.launch()