# mistral-rag-qa / app.py
import gc

import streamlit as st
import torch
from omegaconf import OmegaConf

from src.rag_qa import RagQA
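# Inference-time configuration for the RAG pipeline (models, indices, paths).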
CONFIG_PATH = "src/rag_pipeline/conf/inference.yaml"
@st.cache_resource(
    show_spinner="Loading models and indices. This might take a while. Go get hydrated..."
)
def get_rag_qa():
    # Release leftover memory before loading the models, in case a previous
    # load was evicted or interrupted.
    gc.collect()
    torch.cuda.empty_cache()

    conf = OmegaConf.load(CONFIG_PATH)
    rag_qa = RagQA(conf)
    rag_qa.load()
    return rag_qa
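# Page layout: center the header image in the middle of three equal-width columns.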
left_column, cent_column, last_column = st.columns(3)
with cent_column:
    st.image("pittsburgh.webp", width=400)

st.title("Know anything about Pittsburgh")
# Initialize the RagQA model; it may already be cached from a previous run.
_ = get_rag_qa()
# Question input and submit button.
st.subheader("Ask away:")
question = st.text_input("Ask away:", "", label_visibility="collapsed")
submit = st.button("Submit")
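# A few example questions to get users started.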
st.markdown(
    """
> **For example, ask things like:**
>
> Who is the largest employer in Pittsburgh?
>
> Where is the Smithsonian-affiliated regional history museum in Pittsburgh?
>
> Who is the president of CMU?

---
""",
    unsafe_allow_html=False,
)
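# Run retrieval + answer generation when the user submits a question.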
if submit:
    if not question.strip():
        st.error("Machine Learning still can't read minds. Please enter a question.")
    else:
        try:
            with st.spinner("Combing through 20,000+ documents from 14,000+ URLs..."):
                answer, sources = get_rag_qa().answer(question)

            st.subheader("Answer:")
            st.write(answer)
            st.write("")

            # Show the retrieved passages that the answer was grounded in.
            with st.expander("Show Sources"):
                st.subheader("Sources:")
                for i, source in enumerate(sources):
                    st.markdown(f"**Name:** {source.name}")
                    st.markdown(f"**Index ID:** {source.index_id}")
                    st.markdown("**Text:**")
                    st.write(source.text)
                    # Draw a rule between consecutive sources.
                    if i < len(sources) - 1:
                        st.markdown("---")
        except Exception as e:
            st.error(f"An error occurred: {e}")