# mining_rag / app.py
# Streamlit RAG application over a mining-industry Chroma vector store.
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
import os
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate, PromptTemplate
import streamlit as st
# Expose the OpenAI key from Streamlit secrets via the environment so that
# langchain_openai (ChatOpenAI / OpenAIEmbeddings) picks it up implicitly.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
@st.cache_resource
def load_resources():
    """Build the chat model and open the persisted Chroma store.

    Cached with ``st.cache_resource`` so the model client and the vector
    store are created once per process, not on every Streamlit rerun.

    Returns:
        tuple: ``(llm, vector_store)`` — a ChatOpenAI instance and the
        Chroma store backed by the ``mining-rag`` directory.
    """
    chat_model = ChatOpenAI(model='gpt-3.5-turbo-0125', temperature=0.2)
    store = Chroma(
        embedding_function=OpenAIEmbeddings(),
        persist_directory="mining-rag",
    )
    print('vector store loaded')
    return chat_model, store


llm, vector_store = load_resources()
# FAQ questions whose answers are precomputed at startup and cached, so a
# sidebar click can display the answer instantly instead of waiting on a
# full retrieval + generation round trip.
faq_questions = [
    "What are the primary methods of mineral extraction used in the mining industry?",
    "What are the most common minerals extracted through mining, and what are their primary uses?",
    "How do mineral exploration techniques like geophysical surveys and drilling work?",
    "What role does beneficiation play in the mining process, and what are some common beneficiation techniques?",
    "What are tailings, and how are they managed in mining operations?",
    "What are the key regulations and standards governing the mining industry?",
    "How is technology transforming the mining industry, particularly in terms of automation and data analysis?",
    "What is the role of reclamation and rehabilitation in the mining lifecycle?",
    "How do companies assess the feasibility of a mining project?",
    "What is artisanal and small-scale mining, and what are its challenges and opportunities?",
    "How do geopolitical factors affect the global mining industry?",
]
def get_answer(question):
    """Generate a markdown, report-style answer to *question* via RAG.

    A MultiQueryRetriever asks the LLM for three rephrasings of the
    question and retrieves documents from the Chroma store for each;
    the retrieved text plus the original question are then fed to the
    report-writer prompt.

    Args:
        question: The user's question as a plain string.

    Returns:
        str: The model's answer, formatted in markdown.
    """
    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from
a vector database. By generating multiple perspectives on the user question, your
goal is to help the user overcome some of the limitations of the distance-based
similarity search. Provide these alternative questions separated by newlines.
Original question: {question}""",
    )
    retriever = MultiQueryRetriever.from_llm(
        vector_store.as_retriever(),
        llm,
        prompt=QUERY_PROMPT,
    )

    WRITER_SYSTEM_PROMPT = "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text."
    RESEARCH_REPORT_TEMPLATE = """Information:
--------
{text}
--------
Using the above information, answer the following question or topic: "{question}" in a short manner-- \
The answer should focus on the answer to the question, should be well structured, informative, \
in depth, with facts and numbers if available and a minimum of 150 words and a maximum of 300 words.
You should strive to write the report using all relevant and necessary information provided.
You must write the report with markdown syntax.
You MUST determine your own concrete and valid opinion based on the given information. Do NOT deter to general and meaningless conclusions.
You must write the sources used in the context. if any article is used, mentioned in the end.
Please do your best, this is very important to my career."""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", WRITER_SYSTEM_PROMPT),
            ("user", RESEARCH_REPORT_TEMPLATE),
        ]
    )
    chain = (
        {"text": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    # BUG FIX: invoke with the bare question string. The dict-of-runnables
    # step fans the chain input out to both branches, so the retriever must
    # receive a plain query string and RunnablePassthrough forwards it into
    # the prompt's {question} slot. Invoking with {"question": question}
    # (as before) handed the whole dict to the retriever and rendered the
    # dict's repr into the final prompt.
    return chain.invoke(question)
# Precompute the answers to the FAQ questions
def precompute_faq_answers(faq_questions):
    """Return a {question: answer} dict for every question in *faq_questions*.

    Each answer is produced by a full ``get_answer`` retrieval/generation
    pass, so this is expensive — callers should cache the result.
    """
    return {faq: get_answer(faq) for faq in faq_questions}
# Cache the precomputed FAQ answers so the expensive generation pass runs
# once per cache lifetime instead of on every Streamlit script rerun.
@st.cache_data
def get_precomputed_answers():
    """Return the cached {question: answer} map for all FAQ questions."""
    answers = precompute_faq_answers(faq_questions)
    return answers


precomputed_faq_answers = get_precomputed_answers()
# ---- Streamlit UI ----
st.title('Mining Minerals Expert RAG App')

# Sidebar heading for the FAQ section (the expander itself is built below).
st.sidebar.subheader("Frequently Asked Questions")

# Session state holds the most recent FAQ answer and the most recent custom
# answer; the two handlers clear each other so only one is shown at a time.
st.session_state.setdefault("faq_answer", "")
st.session_state.setdefault("custom_answer", "")

# ---- Custom question section ----
st.subheader("Ask Your Own Question")
question = st.text_input('Write your question below:')
if st.button('Enter'):
    if not question:
        st.write("Please enter a question.")
    else:
        answer = get_answer(question)
        # Clear any FAQ answer so the custom answer is the one displayed.
        st.session_state.faq_answer = ""
        st.session_state.custom_answer = f"**Answer:** {answer}"
# Show FAQ questions inside a collapsible sidebar expander.
faq_expander = st.sidebar.expander("FAQs")
with faq_expander:
    for i, faq in enumerate(faq_questions):
        # BUG FIX: use st.button (container-aware) instead of
        # st.sidebar.button — the latter always renders on the sidebar
        # root, escaping the `with faq_expander` context and leaving the
        # expander empty.
        if st.button(f"Q{i+1}: {faq}", key=f"faq{i+1}"):
            st.session_state.custom_answer = ""  # Clear custom answer when FAQ is clicked
            # Bold only the label: wrapping the entire precomputed answer in
            # ** ... ** (as before) broke the answer's own markdown.
            st.session_state.faq_answer = f"**Answer to Q{i+1}:** {precomputed_faq_answers[faq]}"
# Display whichever answer was produced most recently. The two session keys
# are mutually exclusive (each handler clears the other), so a simple `or`
# selects the active one; nothing is shown when both are empty strings.
selected_answer = st.session_state.custom_answer or st.session_state.faq_answer
if selected_answer:
    st.write(selected_answer)