earnings-final / roaringkitty.py
mlara's picture
second commit
65e34bb
raw
history blame
1.38 kB
import chainlit as cl
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import CacheBackedEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import BasePromptTemplate
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
# Shared splitter for the CSV documents: target chunks of ~1000 characters,
# with 100 characters repeated between neighbouring chunks so that text cut
# at a boundary still appears intact in one of them.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
def _load_roaringkitty_documents():
    """Load ``./data/roaringkitty.csv`` and split its rows into chunks.

    ``CSVLoader`` turns each CSV row into a document whose ``source`` metadata
    comes from the row's ``Link`` column; the module-level ``text_splitter``
    then breaks long rows into smaller chunks. Both steps are synchronous
    (file read + string splitting), so async callers should run this through
    ``cl.make_async``.

    The path is relative to the current working directory, not to this module.
    """
    loader = CSVLoader(file_path="./data/roaringkitty.csv", source_column="Link")
    return text_splitter.transform_documents(loader.load())


async def roaringkiity_chain(prompt: BasePromptTemplate) -> RetrievalQA:
    """Build a GPT-4 ``RetrievalQA`` chain over the Roaring Kitty CSV.

    Loads and chunks the CSV, embeds the chunks with OpenAI embeddings (cached
    on disk under ``./cache/``), indexes them in an in-memory FAISS store, and
    wires that store as the retriever of a "stuff" RetrievalQA chain.

    Args:
        prompt: Prompt template handed to the "stuff" documents chain via
            ``chain_type_kwargs``. Presumably it must declare the input
            variables that chain fills (``context`` and ``question`` for the
            default stuff chain) -- TODO confirm against callers.

    Returns:
        A ``RetrievalQA`` chain using ``gpt-4`` at temperature 0 with streaming
        enabled, configured to return its source documents with each answer.

    Raises:
        ValueError: if the CSV yields no documents to index.
    """
    # CSV read and chunking block; keep them off the event loop, exactly like
    # the FAISS build below.
    documents = await cl.make_async(_load_roaringkitty_documents)()
    if not documents:
        # FAISS sizes its index from the embeddings it is given, so an empty
        # corpus cannot be indexed; fail here with an explicit message instead.
        raise ValueError("no documents loaded from ./data/roaringkitty.csv")

    # Wrap the OpenAI embedder in an on-disk cache namespaced by model name,
    # so previously embedded chunk texts can be served from ./cache/.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # NOTE(review): the FAISS index is rebuilt on every call; if this runs per
    # chat session, consider building it once and reusing it.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    return RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
    )