|
|
|
from langchain_together.embeddings import TogetherEmbeddings |
|
from langchain_together import Together |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.pydantic_v1 import BaseModel |
|
from langchain_core.runnables import RunnableParallel, RunnablePassthrough |
|
|
|
import os |
|
from dotenv import load_dotenv |
|
# Copy variables from a local .env file (if one exists) into os.environ so the
# Together API key can be provided without exporting it in the shell.
load_dotenv()

# Fail fast with an actionable message: without a key, the Together clients
# constructed below cannot authenticate, and os.getenv would otherwise just
# hand them None.
together_api_key = os.getenv("TOGETHER_API_KEY")
if not together_api_key:
    raise RuntimeError(
        "TOGETHER_API_KEY is not set; export it or add it to a .env file."
    )
|
|
|
|
|
# Embedding model used to vectorise incoming questions for the Chroma lookup.
# NOTE(review): the persisted collection was presumably built with this same
# model; vectors from different embedding models are not comparable — confirm.
embeddings = TogetherEmbeddings(
    model="togethercomputer/m2-bert-80M-2k-retrieval",
)
|
|
|
|
|
|
|
|
|
|
|
""" |
|
# Load |
|
from langchain_community.document_loaders import WebBaseLoader |
|
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/") |
|
data = loader.load() |
|
|
|
# Split |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0) |
|
all_splits = text_splitter.split_documents(data) |
|
|
|
# Add to vectorDB |
|
vectorstore = Chroma.from_documents(documents=all_splits, |
|
collection_name="rag-chroma", |
|
embedding=OpenAIEmbeddings(), |
|
) |
|
retriever = vectorstore.as_retriever() |
|
""" |
|
|
|
""" # Embed a single document as a test |
|
vectorstore = Chroma.from_texts( |
|
["harrison worked at kensho"], |
|
collection_name="rag-chroma", |
|
embedding=OpenAIEmbeddings(), |
|
) |
|
retriever = vectorstore.as_retriever() """ |
|
|
|
|
|
|
|
# Open the existing on-disk Chroma collection; question embeddings are
# computed with `embeddings`. The path is relative, so it resolves against the
# process's current working directory.
# NOTE(review): presumably this is launched from the repository root — confirm.
vectorstore = Chroma(
    collection_name="rag-chroma",
    persist_directory="packages/rag-chroma/rag_chroma/vecdb",
    embedding_function=embeddings,
)

# Retrieve 7 chunks per question using maximal marginal relevance (MMR),
# which trades relevance to the query against diversity among the results.
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs=dict(k=7),
)
|
|
|
|
|
|
|
template = """Answer the question based only on the following context: |
|
{context} |
|
|
|
Question: {question} |
|
""" |
|
prompt = ChatPromptTemplate.from_template(template) |
|
|
|
|
|
# Together-hosted Mixtral instruct model that writes the final answer,
# authenticated with the key read from the environment above.
model = Together(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    together_api_key=together_api_key,
    # Sampling settings: moderate temperature with top-p (nucleus) and top-k
    # truncation. repetition_penalty=1 is conventionally neutral.
    temperature=0.7,
    top_p=0.7,
    top_k=50,
    repetition_penalty=1,
)
|
|
|
|
|
def _format_docs(docs):
    """Join retrieved documents into a plain-text context block.

    Args:
        docs: Documents returned by the retriever; each must expose a
            ``page_content`` string.

    Returns:
        The documents' text separated by blank lines. This gives the prompt
        readable prose rather than the ``repr`` of a list of ``Document``
        objects, which would include metadata and escaped newlines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# RAG pipeline: fetch and format context for the question, fill the prompt,
# call the LLM, and return its output as a plain string. The question itself
# is passed through unchanged alongside the context.
chain = (
    RunnableParallel(
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    )
    | prompt
    | model
    | StrOutputParser()
)
|
|
|
|
|
|
|
class Question(BaseModel):
    """Input schema for ``chain``: a bare question string.

    ``__root__`` is pydantic v1's custom-root-type field. The model therefore
    validates a plain ``str`` rather than an object with named fields.
    """

    __root__: str
|
|
|
|
|
# Declare `Question` as the chain's input type, so the chain's input schema
# describes a plain question string instead of being left unspecified.
chain = chain.with_types(input_type=Question)
|
|