import os

from dotenv import load_dotenv

# from langchain_community.chat_models import ChatOpenAI  # optional alternative chat model
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_together import Together
from langchain_together.embeddings import TogetherEmbeddings

# Read the Together API key from a local .env file
load_dotenv()
together_api_key = os.getenv("TOGETHER_API_KEY")

# Together-hosted embedding model (must match the one used to build the vector DB)
embeddings = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-2k-retrieval")
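# Quick embedding smoke test (a sketch; assumes a valid TOGETHER_API_KEY —
# uncomment to check connectivity and the embedding dimensionality):
# vec = embeddings.embed_query("hello world")
# print(len(vec))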
# Example of document loading (from a URL), splitting, and creating the vectorstore
"""
# Load
from langchain_community.document_loaders import WebBaseLoader

loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()

# Split
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)

# Add to vectorDB, embedding with the TogetherEmbeddings instance defined above
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-chroma",
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()
"""
""" # Embed a single document as a test
vectorstore = Chroma.from_texts(
["harrison worked at kensho"],
collection_name="rag-chroma",
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever() """
# Load the persisted vector store (built with the same embedding model as above)
vectorstore = Chroma(
    persist_directory="packages/rag-chroma/rag_chroma/vecdb",
    collection_name="rag-chroma",
    embedding_function=embeddings,
)
# MMR retrieval trades off relevance against diversity among the top 7 chunks
retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 7})
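# Optional retrieval sanity check (a sketch; assumes the persisted DB above
# exists and is populated — uncomment to run):
# docs = retriever.get_relevant_documents("What is task decomposition?")
# print(len(docs), docs[0].page_content[:200])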
# RAG prompt
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
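# For illustration, the prompt can be rendered directly (values are placeholders):
# prompt.invoke({"context": "retrieved docs go here", "question": "a user question"})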
# LLM served via the Together API
model = Together(
    # model="mistralai/Mistral-7B-Instruct-v0.2",  # smaller alternative
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=1,
    together_api_key=together_api_key,
)
# RAG chain: fetch context and pass the question through in parallel,
# then fill the prompt, call the model, and parse the output to a string
chain = (
    RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
    | prompt
    | model
    | StrOutputParser()
)
# Add typing for the chain input (e.g. so LangServe can generate the request schema)
class Question(BaseModel):
    __root__: str


chain = chain.with_types(input_type=Question)
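# Local smoke test (a sketch; the question is illustrative and assumes the vector
# DB was built from the Lilian Weng agents post referenced above, with
# TOGETHER_API_KEY set in the environment):
if __name__ == "__main__":
    print(chain.invoke("What is task decomposition?"))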