# goal-rag-demo / tools.py  (commit 98997fa — "Update tools.py" by cmagganas)
import io
import json
import os
from io import BytesIO

import chainlit as cl
import openai
import requests
from langchain import hub
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import StructuredTool, Tool
from langchain.vectorstores import Chroma
from openai import OpenAI
# OpenAI API Key Setup
openai.api_key = os.environ["OPENAI_API_KEY"]
# Define our RAG tool function
def rag(query: str) -> str:
    """Answer `query` with retrieval-augmented generation over "The Goal" PDF.

    Loads the PDF, splits it into overlapping chunks, embeds the chunks into
    an in-memory Chroma vector store, then runs an LCEL chain:
    retrieve -> format context -> RAG prompt -> LLM -> string.
    The answer is streamed to stdout as it is generated and also returned.

    NOTE(review): the index is rebuilt on every call, which is slow and
    re-embeds the whole book each time — consider building it once at module
    load and reusing the retriever.

    Args:
        query: The user's natural-language question.

    Returns:
        The full generated answer as a single string.
    """
    # Load The Goal PDF
    loader = UnstructuredPDFLoader("data/The Goal - A Process of Ongoing Improvement (Third Revised Edition).pdf")  # , mode="elements"
    docs = loader.load()
    # Split text into overlapping chunks so retrieval hits stay coherent
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    # Embed chunks into an in-memory Chroma vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()
    # Canonical RAG prompt from LangChain Hub
    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)  # or gpt-3.5-turbo

    def format_docs(docs):
        # Join retrieved documents into one context string for the prompt.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    response = ""
    for chunk in rag_chain.stream(query):  # e.g. "What is a Bottleneck Constraint?"
        # BUG FIX: `cl.user_session(chunk, end=..., flush=...)` raised TypeError —
        # cl.user_session is Chainlit's session object, not a print-like callable.
        print(chunk, end="", flush=True)
        # BUG FIX: was `response += f"\n{chunk}"`, which inserted a newline
        # before every streamed token and mangled the returned answer.
        response += chunk
    return response
# this is our tool - which is what allows our agent to access RAG agent
# the `description` field is of utmost importance as it is what the LLM "brain" uses to determine
# which tool to use for a given input.
# Expected input shape for the tool (doubled braces so str.format-style templating keeps them literal).
rag_format = '{{"prompt": "prompt"}}'
rag_tool = Tool.from_function(
    func=rag,
    name="RAG",
    # BUG FIX: the description previously interpolated `generate_image_format`,
    # a name that does not exist in this file (copy-paste from an image tool),
    # which raised NameError at import time. `rag_format` is the intended value.
    description=f"Useful for retrieving contextual information about the PDF to answer user questions. Input should be a single string strictly in the following JSON format: {rag_format}",
    return_direct=True,  # hand the RAG answer straight back to the user, skipping further agent steps
)