# app.py -- Streamlit semantic search app (Hugging Face Space source).
# Snapshot: commit 3de15ee ("Update app.py") by Jayesh13, 4.89 kB.
import streamlit as st
import os
import time
import numpy as np
import pandas as pd
def add_custom_css():
    """Inject the app's custom stylesheet into the current Streamlit page.

    The stylesheet defines the ``container``, ``big-font`` and
    ``progress-bar`` classes. It is emitted as raw HTML, which is why
    ``unsafe_allow_html`` has to be enabled on the markdown call.
    """
    stylesheet = """
<style>
.container {
text-align: center;
background-color: #f0f0f0;
padding: 20px;
}
.big-font {
font-size: 50px;
color: #4CAF50;
}
.progress-bar {
margin-top: 20px;
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
import subprocess
import sys

# Install the third-party packages once per Streamlit session. The flag in
# session state keeps later script executions in the same session from
# running pip again.
if 'packages_installed' not in st.session_state:
    st.info("Installing required packages...")
    # Run pip through the interpreter that is executing this app so the
    # packages land in the environment the imports below will use, and pass
    # the arguments as a list so no shell is involved.
    for pip_args in (["-U", "sentence-transformers"], ["pinecone-client"]):
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "install", *pip_args],
            check=False,
        )
        if completed.returncode != 0:
            # Do not mark the packages as installed when pip failed; the
            # imports that follow would otherwise fail with a less clear error.
            st.error(f"Failed to install package: {pip_args[-1]}")
            st.stop()
    st.session_state['packages_installed'] = True
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec, PodSpec
# --- Pinecone client, embedding model and index setup ----------------------
# Each expensive object is cached in st.session_state and built at most once
# per session. Every guard below is independent, so a later script execution
# never depends on a local name that was only bound inside an earlier guard.
index_name = 'dataset'

if 'pc' not in st.session_state:
    # The API key must be supplied through the environment. Credentials must
    # never be committed to source control as a fallback value.
    api_key = os.environ.get('PINECONE_API_KEY')
    if not api_key:
        st.error("PINECONE_API_KEY environment variable is not set.")
        st.stop()
    st.session_state['pc'] = Pinecone(api_key=api_key)

if 'model' not in st.session_state:
    st.session_state['model'] = SentenceTransformer('intfloat/e5-small')

if 'index' not in st.session_state:
    if index_name not in st.session_state.pc.list_indexes().names():
        # The index spec is only needed when the index has to be created.
        use_serverless = False
        environment = os.environ.get('PINECONE_ENVIRONMENT', 'gcp-starter')
        if use_serverless:
            spec = ServerlessSpec(cloud='gcp', region='asia-southeast1-gcp')
        else:
            spec = PodSpec(environment=environment)

        # NOTE(review): 384 is expected to equal the embedding size of the
        # model loaded above -- confirm if the model is ever changed.
        dimensions = 384
        st.session_state.pc.create_index(
            name=index_name,
            dimension=dimensions,
            metric='cosine',
            spec=spec,
        )

        # Wait until the index is ready, but give up after a bounded time
        # instead of blocking the app forever.
        ready_deadline = time.monotonic() + 300
        while not st.session_state.pc.describe_index(index_name).status['ready']:
            if time.monotonic() > ready_deadline:
                st.error(f"Timed out waiting for Pinecone index '{index_name}' to become ready.")
                st.stop()
            time.sleep(1)

    st.session_state['index'] = st.session_state.pc.Index(index_name)
# Function to process data and insert into Pinecone index
def process_data(data, namespace, chunk_size=1000):
    """Embed the ``Query`` column of *data* and upsert the vectors into Pinecone.

    Rows are processed in chunks of *chunk_size*. After each chunk the
    Streamlit progress bar created by this function is advanced to the
    percentage of rows handled so far.

    Args:
        data: DataFrame with a ``Query`` column (texts to embed) and an
            ``id`` column (vector ids).
        namespace: Pinecone namespace the vectors are written to.
        chunk_size: Number of rows embedded and upserted per request.

    Raises:
        KeyError: If *data* lacks the ``Query`` or ``id`` column.
        ValueError: If *chunk_size* is smaller than 1.
    """
    missing = [col for col in ('Query', 'id') if col not in data.columns]
    if missing:
        raise KeyError(f"Dataset is missing required column(s): {missing}")
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    progress_bar = st.progress(0)
    total_rows = len(data)
    model = st.session_state.model
    index = st.session_state.index

    for chunk_start in range(0, total_rows, chunk_size):
        chunk_end = min(chunk_start + chunk_size, total_rows)
        chunk = data.iloc[chunk_start:chunk_end]

        # Encode the whole chunk in a single call instead of one call per row.
        embeddings = model.encode(chunk['Query'].tolist(), normalize_embeddings=True)

        # Build plain (id, values) tuples. Writing an 'embedding' column into
        # the iloc slice would modify a possible copy of the caller's frame.
        vectors = [
            (str(vector_id), embedding.tolist())
            for vector_id, embedding in zip(chunk['id'], embeddings)
        ]
        index.upsert(vectors=vectors, namespace=namespace)

        # Report progress as an integer percentage of rows processed.
        progress_bar.progress(int(chunk_end / total_rows * 100))
def load_and_process_data(file):
    """Read an uploaded CSV and make sure its rows are embedded in Pinecone.

    The row index, converted to strings, becomes the ``id`` column used as
    vector ids. The namespace is the first 15 characters of the file name.
    Embedding is tracked per namespace in session state, so a second,
    different file uploaded in the same session is embedded as well, while
    the same file is not embedded again on later calls.

    Args:
        file: Uploaded file object; it is passed to ``pandas.read_csv`` and
            its ``name`` attribute is read.

    Returns:
        A ``(data, namespace)`` tuple: the loaded DataFrame (with the added
        ``id`` column) and the namespace string.
    """
    data = pd.read_csv(file)
    data['id'] = data.index.astype(str)
    namespace = file.name[:15]  # Use first 15 characters of file name as namespace

    # A single global "done" flag would skip embedding for every file after
    # the first one, leaving the new namespace empty; track namespaces instead.
    if 'embedded_namespaces' not in st.session_state:
        st.session_state['embedded_namespaces'] = set()
    if namespace not in st.session_state['embedded_namespaces']:
        process_data(data, namespace)
        st.session_state['embedded_namespaces'].add(namespace)

    return data, namespace
def main():
    """Render the semantic-search page.

    The page lets the user upload a CSV dataset, which is loaded and
    embedded through ``load_and_process_data``. Once a namespace is active,
    the user can enter a query; the ``Answer`` value of the best-matching
    row is displayed. A button deletes all vectors stored in the active
    namespace.
    """
    add_custom_css()
    st.markdown("""
<div class='container'>
<h1 class='big-font'>Semantic Search Engine</h1>
</div>
""", unsafe_allow_html=True)

    # Use session state to retain information across interactions
    if 'namespace' not in st.session_state:
        st.session_state.namespace = None
    if 'df' not in st.session_state:
        st.session_state.df = None

    uploaded_file = st.file_uploader("Upload dataset (CSV format)", type=["csv"])
    if uploaded_file is not None:
        st.info("Dataset Processing Started...")
        st.session_state.df, st.session_state.namespace = load_and_process_data(uploaded_file)
        st.info("Dataset Processing Completed...")

    # Querying and deleting only make sense once a dataset namespace exists.
    if not st.session_state.namespace:
        return

    query = st.text_input("Enter your query about the data (or type 'exit' to quit):")
    cleaned_query = query.strip()
    # Skip the search for an empty box (the state on first render) and for
    # the 'exit' keyword.
    if cleaned_query and cleaned_query.lower() != 'exit':
        query_vector = st.session_state.model.encode(query)
        result = st.session_state.index.query(
            namespace=st.session_state.namespace,
            vector=query_vector.tolist(),
            top_k=5,
            include_values=False,
        )
        st.subheader("Query Results:")

        matches = result['matches'] if result is not None else []
        data = st.session_state.df
        if not matches:
            # Happens for an empty namespace, e.g. after the data was deleted.
            st.warning("No matching results found.")
        elif data is None or 'Answer' not in data.columns:
            st.error("The dataset has no 'Answer' column to display.")
        else:
            best_match_id = matches[0]['id']
            answers = data.loc[data['id'] == best_match_id, 'Answer']
            if answers.empty:
                st.warning("The best match is not present in the loaded dataset.")
            else:
                st.write(answers.iloc[0])

    if st.button("Delete Stored Data"):
        # 'delete_all' is the keyword documented for the Python client.
        st.session_state.index.delete(delete_all=True, namespace=st.session_state.namespace)
        st.success("Stored data deleted.")
        st.stop()


if __name__ == "__main__":
    main()