# llm_6 / app.py
# (Hugging Face upload metadata, kept as a comment: Entz's picture — Upload 10 files — 8f24edc verified)
import streamlit as st
import pandas as pd
import sqlite3
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import PromptTemplate
import os
# Version tag stored alongside each Q&A row so history can be tied to a
# specific model/prompt configuration.
version = 2.2

# Initialize the SQLite3 database.
# check_same_thread=False is required under Streamlit: this module is imported
# once, but later script reruns execute on different ScriptRunner threads, and
# sqlite3 by default forbids using a connection outside its creating thread.
conn = sqlite3.connect('qa.db', check_same_thread=False)
c = conn.cursor()

# Create the qa table on first run; includes the version column so old
# databases without it are not silently reused with a different schema.
c.execute('CREATE TABLE IF NOT EXISTS qa (question TEXT, answer TEXT, version REAL)')
conn.commit()
# Read the LLM Model Description from a file
def read_description_from_file(file_path):
    """Return the entire text content of *file_path*.

    Used for the model description (tab1_intro.txt) and the prompt
    template (tab2_pe.txt).

    Args:
        file_path: Path to a UTF-8 text file.

    Returns:
        The file's full contents as a single string.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Explicit encoding: without it, open() falls back to the platform's
    # locale encoding, which breaks non-ASCII prompt/description files
    # on e.g. Windows (cp1252).
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()
# Define the folder containing the saved index
# Folder containing the previously persisted llama_index vector index.
INDEX_OUTPUT_PATH = "./output_index"

# Fail fast at import time if the index has not been built/uploaded yet.
if not os.path.exists(INDEX_OUTPUT_PATH):
    raise ValueError(f"Index directory {INDEX_OUTPUT_PATH} does not exist")

# Setup LLM and embedding model.
# NOTE(review): requires a local Ollama server with the "llama3" model pulled;
# the 120 s timeout accommodates slow first-token generation.
llm = Ollama(model="llama3", request_timeout=120.0)
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5", trust_remote_code=True)

# Reload the persisted index via a storage context pointing at the directory.
storage_context = StorageContext.from_defaults(persist_dir=INDEX_OUTPUT_PATH)
loaded_index = load_index_from_storage(embed_model=embed_model, storage_context=storage_context)

# Build a query engine over the loaded index using the LLM and embedding model.
query_engine = loaded_index.as_query_engine(llm=llm, embed_model=embed_model)

# Customise the prompt: load the QA template text from tab2_pe.txt and
# install it as the response synthesizer's text-QA template.
qa_prompt_tmpl_str = read_description_from_file("tab2_pe.txt")
qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)
query_engine.update_prompts(
    {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
)
# Save the question and answer to the SQLite3 database
def save_to_db(question, answer, version):
    """Persist one question/answer pair, tagged with *version*, to SQLite.

    Uses the module-level connection ``conn`` / cursor ``c`` and commits
    immediately so the row survives Streamlit reruns.
    """
    row = (question, answer, version)
    c.execute(
        'INSERT INTO qa (question, answer, version) VALUES (?, ?, ?)',
        row,
    )
    conn.commit()
# Fetch all data from the SQLite3 database
def fetch_from_db():
    """Return every stored (question, answer, version) row from the qa table."""
    return c.execute('SELECT * FROM qa').fetchall()
def main():
    """Render the three-tab Streamlit UI: model description, Q&A, and history."""
    st.title("How Much Does Mistral 7B Model Know About Wandsworth Council?")
    tab1, tab2, tab3 = st.tabs(["LLM Model Description", "Ask a Question", "View Q&A History"])

    # Tab 1: static description text loaded from disk on every rerun.
    with tab1:
        st.subheader("LLM Model Description")
        description = read_description_from_file("tab1_intro.txt")
        st.write(description)

    # Tab 2: free-text question answered by the RAG query engine.
    with tab2:
        st.subheader("Ask a Question")
        question = st.text_input("Enter your question:")
        if st.button("Get Answer"):
            if question:
                try:
                    response = query_engine.query(question)
                    # Try to extract the generated text.
                    try:
                        # NOTE(review): llama_index response objects typically
                        # expose `.response`, not `.text`, so this usually
                        # falls through to str(response) — confirm before
                        # relying on the `.text` branch.
                        if hasattr(response, 'text'):
                            answer = response.text
                        else:
                            answer = str(response)
                    except AttributeError as e:
                        st.error(f"Error extracting text from response: {e}")
                        answer = "Sorry, could not generate an answer."
                    st.write(f"**Answer:** {answer}")
                    # Persist the exchange (with the module-level version tag)
                    # so it shows up in the history tab.
                    save_to_db(question, answer, version)
                except Exception as e:
                    # Broad catch: surface any LLM/index failure in the UI
                    # instead of crashing the Streamlit script run.
                    st.error(f"An error occurred: {e}")
            else:
                st.warning("Please enter a question")

    # Tab 3: render all saved rows as a table, or a placeholder when empty.
    with tab3:
        st.subheader("View Q&A History")
        qa_data = fetch_from_db()
        if qa_data:
            df = pd.DataFrame(qa_data, columns=["Question", "Answer", "Version"])
            st.dataframe(df)
        else:
            st.write("No data available")


if __name__ == "__main__":
    main()