import html
import json
import os
import random
import time
from datetime import datetime
from os.path import join

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from datasets import Dataset, get_dataset_config_info, load_dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from PIL import Image
from streamlit_feedback import streamlit_feedback

from src import (
    preprocess_and_load_df,
    load_agent,
    ask_agent,
    decorate_with_code,
    show_response,
    get_from_user,
    load_smart_df,
    ask_question,
)


# Page-level configuration. Older Streamlit releases require this to be the
# first st.* call of a script run, so keep it ahead of any other output.
_PAGE_CONFIG = {
    "page_title": "VayuBuddy - AI Air Quality Assistant",
    "page_icon": "π¬οΈ",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)


st.markdown("""
|
|
<style>
|
|
/* Clean app background */
|
|
.stApp {
|
|
background-color: #ffffff;
|
|
color: #212529;
|
|
font-family: 'Segoe UI', sans-serif;
|
|
}
|
|
|
|
/* Sidebar */
|
|
[data-testid="stSidebar"] {
|
|
background-color: #f8f9fa;
|
|
border-right: 1px solid #dee2e6;
|
|
padding: 1rem;
|
|
}
|
|
|
|
/* Main title */
|
|
.main-title {
|
|
text-align: center;
|
|
color: #343a40;
|
|
font-size: 2.5rem;
|
|
font-weight: 700;
|
|
margin-bottom: 0.5rem;
|
|
}
|
|
|
|
/* Subtitle */
|
|
.subtitle {
|
|
text-align: center;
|
|
color: #6c757d;
|
|
font-size: 1.1rem;
|
|
margin-bottom: 1.5rem;
|
|
}
|
|
|
|
/* Instructions */
|
|
.instructions {
|
|
background-color: #f1f3f5;
|
|
border-left: 4px solid #0d6efd;
|
|
padding: 1rem;
|
|
margin-bottom: 1.5rem;
|
|
border-radius: 6px;
|
|
color: #495057;
|
|
text-align: left;
|
|
}
|
|
|
|
/* Quick prompt buttons */
|
|
.quick-prompt-container {
|
|
display: flex;
|
|
flex-wrap: wrap;
|
|
gap: 8px;
|
|
margin-bottom: 1.5rem;
|
|
padding: 1rem;
|
|
background-color: #f8f9fa;
|
|
border-radius: 10px;
|
|
border: 1px solid #dee2e6;
|
|
}
|
|
|
|
.quick-prompt-btn {
|
|
background-color: #0d6efd;
|
|
color: white;
|
|
border: none;
|
|
padding: 8px 16px;
|
|
border-radius: 20px;
|
|
font-size: 0.9rem;
|
|
cursor: pointer;
|
|
transition: all 0.2s ease;
|
|
white-space: nowrap;
|
|
}
|
|
|
|
.quick-prompt-btn:hover {
|
|
background-color: #0b5ed7;
|
|
transform: translateY(-2px);
|
|
}
|
|
|
|
/* User message styling */
|
|
.user-message {
|
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
|
color: white;
|
|
padding: 15px 20px;
|
|
border-radius: 20px 20px 5px 20px;
|
|
margin: 10px 0;
|
|
margin-left: auto;
|
|
margin-right: 0;
|
|
max-width: 80%;
|
|
position: relative;
|
|
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
|
}
|
|
|
|
.user-info {
|
|
font-size: 0.8rem;
|
|
opacity: 0.8;
|
|
margin-bottom: 5px;
|
|
text-align: right;
|
|
}
|
|
|
|
/* Assistant message styling */
|
|
.assistant-message {
|
|
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
|
|
color: white;
|
|
padding: 15px 20px;
|
|
border-radius: 20px 20px 20px 5px;
|
|
margin: 10px 0;
|
|
margin-left: 0;
|
|
margin-right: auto;
|
|
max-width: 80%;
|
|
position: relative;
|
|
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
|
}
|
|
|
|
.assistant-info {
|
|
font-size: 0.8rem;
|
|
opacity: 0.8;
|
|
margin-bottom: 5px;
|
|
}
|
|
|
|
/* Processing indicator */
|
|
.processing-indicator {
|
|
background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
|
|
color: #333;
|
|
padding: 15px 20px;
|
|
border-radius: 20px 20px 20px 5px;
|
|
margin: 10px 0;
|
|
margin-left: 0;
|
|
margin-right: auto;
|
|
max-width: 80%;
|
|
position: relative;
|
|
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
|
animation: pulse 2s infinite;
|
|
}
|
|
|
|
@keyframes pulse {
|
|
0% { opacity: 1; }
|
|
50% { opacity: 0.7; }
|
|
100% { opacity: 1; }
|
|
}
|
|
|
|
/* Feedback box */
|
|
.feedback-section {
|
|
background-color: #f8f9fa;
|
|
border: 1px solid #dee2e6;
|
|
padding: 1rem;
|
|
border-radius: 8px;
|
|
margin: 1rem 0;
|
|
}
|
|
|
|
/* Success and error messages */
|
|
.success-message {
|
|
background-color: #d1e7dd;
|
|
color: #0f5132;
|
|
padding: 1rem;
|
|
border-radius: 6px;
|
|
border: 1px solid #badbcc;
|
|
}
|
|
|
|
.error-message {
|
|
background-color: #f8d7da;
|
|
color: #842029;
|
|
padding: 1rem;
|
|
border-radius: 6px;
|
|
border: 1px solid #f5c2c7;
|
|
}
|
|
|
|
/* Chat input */
|
|
.stChatInput {
|
|
border-radius: 6px;
|
|
border: 1px solid #ced4da;
|
|
background: #ffffff;
|
|
}
|
|
|
|
/* Button */
|
|
.stButton > button {
|
|
background-color: #0d6efd;
|
|
color: white;
|
|
border-radius: 6px;
|
|
padding: 0.5rem 1.25rem;
|
|
border: none;
|
|
font-weight: 600;
|
|
transition: background-color 0.2s ease;
|
|
}
|
|
|
|
.stButton > button:hover {
|
|
background-color: #0b5ed7;
|
|
}
|
|
|
|
/* Code details styling */
|
|
.code-details {
|
|
background-color: #f8f9fa;
|
|
border: 1px solid #dee2e6;
|
|
border-radius: 8px;
|
|
padding: 10px;
|
|
margin-top: 10px;
|
|
}
|
|
|
|
/* Hide default menu and footer */
|
|
#MainMenu {visibility: hidden;}
|
|
footer {visibility: hidden;}
|
|
header {visibility: hidden;}
|
|
|
|
/* Auto scroll */
|
|
.main-container {
|
|
height: 70vh;
|
|
overflow-y: auto;
|
|
}
|
|
</style>
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
st.markdown("""
|
|
<script>
|
|
function scrollToBottom() {
|
|
setTimeout(function() {
|
|
const mainContainer = document.querySelector('.main-container');
|
|
if (mainContainer) {
|
|
mainContainer.scrollTop = mainContainer.scrollHeight;
|
|
}
|
|
window.scrollTo(0, document.body.scrollHeight);
|
|
}, 100);
|
|
}
|
|
</script>
|
|
""", unsafe_allow_html=True)
|
|
|
|
|
|
# Load secrets from a local .env file; override=True makes .env values win
# over variables already present in the process environment.
load_dotenv(override=True)

# Provider credentials; each is None when the variable is not configured.
Groq_Token = os.getenv("GROQ_API_KEY")
hf_token = os.getenv("HF_TOKEN")
gemini_token = os.getenv("GEMINI_TOKEN")

# Short display name -> provider-specific model identifier (insertion order kept).
models = dict(
    [
        ("llama3.1", "llama-3.1-8b-instant"),
        ("mistral", "mistral-saba-24b"),
        ("llama3.3", "llama-3.3-70b-versatile"),
        ("gemma", "gemma2-9b-it"),
        ("gemini-pro", "gemini-1.5-pro"),
    ]
)

# Directory containing this script; bundled files are resolved relative to it.
self_path = os.path.split(os.path.abspath(__file__))[0]


st.markdown("<h1 class='main-title'>π¬οΈ VayuBuddy</h1>", unsafe_allow_html=True)
|
|
|
|
st.markdown("""
|
|
<div class='subtitle'>
|
|
<strong>AI-Powered Air Quality Insights</strong><br>
|
|
Simplifying pollution analysis using conversational AI.
|
|
</div>
|
|
""", unsafe_allow_html=True)
|
|
|
|
st.markdown("""
|
|
<div class='instructions'>
|
|
<strong>How to Use:</strong><br>
|
|
Select a model from the sidebar and ask questions directly in the chat. Use quick prompts below for common queries.
|
|
</div>
|
|
""", unsafe_allow_html=True)
|
|
|
|
os.environ["PANDASAI_API_KEY"] = "$2a$10$gbmqKotzJOnqa7iYOun8eO50TxMD/6Zw1pLI2JEoqncwsNx4XeBS2"
|
|
|
|
|
|
# Load and preprocess the bundled dataset; the app cannot continue without it,
# so any failure is reported and the script run is stopped.
try:
    df = preprocess_and_load_df(join(self_path, "Data.csv"))
except Exception as e:
    st.error(f"β Error loading data: {e}")
    st.stop()
else:
    # The literal previously held mojibake of "✅" whose 0x85 byte decoded to a
    # NEL control character, splitting the string across source lines.
    st.success("✅ Data loaded successfully!")


# NOTE(review): not referenced anywhere else in this script; possibly dead.
inference_server = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"

# Resolve the logo next to this script (as Data.csv and questions.txt are) so
# it is found regardless of the working directory Streamlit was launched from.
image_path = join(self_path, "IITGN_Logo.png")


with st.sidebar:
    # Centre the logo by placing it in the wider middle column.
    _, logo_col, _ = st.columns([1, 2, 1])
    with logo_col:
        if os.path.exists(image_path):
            st.image(image_path, use_column_width=True)

    st.markdown("### π€ AI Model Selection")

    # Offer only models whose provider credentials are configured.
    available_models = []
    if Groq_Token and Groq_Token.strip():
        available_models.extend(["llama3.1", "llama3.3", "mistral", "gemma"])
    if gemini_token and gemini_token.strip():
        available_models.append("gemini-pro")

    if not available_models:
        st.error("β No API keys available! Please set up your API keys in the .env file")
        st.stop()

    model_name = st.selectbox(
        "Choose your AI assistant:",
        available_models,
        help="Different models have different strengths. Try them all!",
    )

    model_descriptions = {
        "llama3.1": "π¦ Fast and efficient for general queries",
        "llama3.3": "π¦ Most advanced Llama model",
        "mistral": "β‘ Balanced performance and speed",
        "gemma": "π Google's lightweight model",
        "gemini-pro": "π§ Google's most powerful model",
    }
    if model_name in model_descriptions:
        st.info(model_descriptions[model_name])

    st.markdown("---")

    if st.button("π§Ή Clear Chat"):
        st.session_state.responses = []
        st.session_state.processing = False
        # Also forget the duplicate-question markers; otherwise asking the last
        # question again with the same model after clearing is silently ignored.
        st.session_state.pop("last_prompt", None)
        st.session_state.pop("last_model_name", None)
        try:
            st.rerun()
        except AttributeError:
            # Older Streamlit releases only provide experimental_rerun.
            st.experimental_rerun()

    st.markdown("---")

    # Compact history: a 50-character preview of every user/assistant entry.
    with st.expander("π Chat History"):
        for entry in st.session_state.get("responses", []):
            role = entry.get("role")
            if role == "user":
                speaker = "You"
            elif role == "assistant":
                speaker = "VayuBuddy"
            else:
                continue
            text = str(entry.get("content", ""))
            # Only mark the preview as truncated when something was cut off.
            preview = text[:50] + ("..." if len(text) > 50 else "")
            st.markdown(f"**{speaker}:** {preview}")
    st.markdown("---")


# Quick prompts: one question per non-blank line of questions.txt beside this
# script, falling back to a built-in list when the file is missing or empty.
questions = []
questions_file = join(self_path, "questions.txt")
if os.path.exists(questions_file):
    try:
        with open(questions_file, "r", encoding="utf-8") as f:
            questions = list(filter(None, (line.strip() for line in f.read().split("\n"))))
        print(f"Loaded {len(questions)} quick prompts")
    except Exception as e:
        st.error(f"Error loading questions: {e}")
        questions = []

questions = questions or [
    "What is the average PM2.5 level in the dataset?",
    "Show me the air quality trend over time",
    "Which pollutant has the highest concentration?",
    "Create a correlation plot between different pollutants",
    "What are the peak pollution hours?",
    "Compare weekday vs weekend pollution levels",
]

st.markdown("### π Quick Prompts")

# Lay the prompts out as a grid of buttons; the clicked one (if any) becomes
# this run's prompt further down.
cols_per_row = 2
selected_prompt = None
for row_idx, row_start in enumerate(range(0, len(questions), cols_per_row)):
    row = questions[row_start:row_start + cols_per_row]
    for col_idx, (cell, question) in enumerate(zip(st.columns(len(row)), row)):
        label = f"π {question[:35]}{'...' if len(question) > 35 else ''}"
        with cell:
            clicked = st.button(
                label,
                key=f"prompt_btn_{row_idx}_{col_idx}",
                help=question,
                use_container_width=True,
            )
        if clicked:
            selected_prompt = question

st.markdown("---")


if "responses" not in st.session_state:
|
|
st.session_state.responses = []
|
|
if "processing" not in st.session_state:
|
|
st.session_state.processing = False
|
|
|
|
def upload_feedback():
    """Upload feedback for one assistant response to the HF feedback dataset.

    Takes no arguments: it reads the module-level names that the chat-rendering
    loop assigns just before calling it -- ``feedback`` (dict with "score" and
    "text"), ``error``, ``output``, ``last_prompt``, ``code`` and ``status``
    (dict returned by ``show_custom_response``; its "is_image" flag is checked).

    A markdown report is stored as ``data/<timestamp>/feedback.md`` in the
    ``SustainabilityLabIITGN/VayuBuddy_Feedback`` dataset repo. When the
    response was an image, ``output`` is treated as a local file path and
    uploaded as ``plot.png`` next to the report. Errors are reported in the UI
    rather than raised.

    Returns:
        True if the upload succeeded, False otherwise (backward compatible:
        existing callers ignore the return value).
    """
    try:
        if not hf_token:
            st.warning("β οΈ Cannot upload feedback - HF_TOKEN not available")
            return False

        # Timestamp-derived folder so each submission gets its own directory.
        folder_name = str(datetime.now()).replace(" ", "_").replace(":", "-").replace(".", "-")
        repo_id = "SustainabilityLabIITGN/VayuBuddy_Feedback"

        report = "\n\n".join(
            [
                f"Prompt: {last_prompt}",
                f"Output: {output}",
                f"Code:\n\n```py\n{code}\n```",
                f"Error: {error}",
                f"Feedback: {feedback.get('score', '')}",
                f"Comments: {feedback.get('text', '')}",
            ]
        ) + "\n\n"

        api = HfApi(token=hf_token)
        # Upload from memory. The previous fixed /tmp path was shared by every
        # session, so concurrent users could overwrite each other's report
        # between writing and uploading; it also relied on the platform's
        # default text encoding for emoji-bearing content.
        api.upload_file(
            path_or_fileobj=report.encode("utf-8"),
            path_in_repo=f"data/{folder_name}/feedback.md",
            repo_id=repo_id,
            repo_type="dataset",
        )
        if status.get("is_image", False):
            api.upload_file(
                path_or_fileobj=output,
                path_in_repo=f"data/{folder_name}/plot.png",
                repo_id=repo_id,
                repo_type="dataset",
            )
        st.success("π Feedback uploaded successfully!")
        return True
    except Exception as e:
        st.error(f"β Error uploading feedback: {e}")
        return False


def show_custom_response(response):
    """Render one chat entry as a styled bubble.

    Args:
        response: chat entry dict. Reads "role" (default "assistant") and
            "content" (default ""); assistant entries may also carry
            "gen_code", which is shown in a collapsible code block.

    Returns:
        ``{"is_image": True}`` when the assistant content is a path to an
        existing .png/.jpg/.jpeg file that was displayed, otherwise
        ``{"is_image": False}``.
    """
    role = response.get("role", "assistant")
    content = response.get("content", "")

    if role == "user":
        # Escape: the text is user input but is rendered with raw HTML enabled.
        st.markdown(
            "<div class='user-message'>\n"
            "<div class='user-info'>You</div>\n"
            f"{html.escape(str(content))}\n"
            "</div>",
            unsafe_allow_html=True,
        )
    elif role == "assistant":
        # LLM output can echo prompt text, so it is escaped as well.
        st.markdown(
            "<div class='assistant-message'>\n"
            "<div class='assistant-info'>π€ VayuBuddy</div>\n"
            f"{html.escape(content if isinstance(content, str) else str(content))}\n"
            "</div>",
            unsafe_allow_html=True,
        )

        if response.get("gen_code"):
            with st.expander("π View Generated Code"):
                st.code(response["gen_code"], language="python")

        # Assistant answers that are plot files are shown as images too.
        is_image_path = isinstance(content, str) and content.lower().endswith((".png", ".jpg", ".jpeg"))
        if is_image_path and os.path.exists(content):
            try:
                st.image(content)
                return {"is_image": True}
            except Exception:
                # Best effort: an unreadable image just isn't shown. Unlike a bare
                # `except:`, this does not swallow Streamlit's rerun/stop signals.
                pass

    return {"is_image": False}


def show_processing_indicator(model_name, question):
    """Render the pulsing "generating response" bubble for a pending question.

    Args:
        model_name: label of the model handling the request, shown in the header.
        question: the pending user question. It is escaped before insertion
            because it is user-supplied and rendered with raw HTML enabled.
    """
    st.markdown(
        "<div class='processing-indicator'>\n"
        f"<div class='assistant-info'>π€ VayuBuddy β’ Processing with {html.escape(str(model_name))}</div>\n"
        f"<strong>Question:</strong> {html.escape(str(question))}<br>\n"
        "<em>π Generating response...</em>\n"
        "</div>",
        unsafe_allow_html=True,
    )


chat_container = st.container()

with chat_container:
    for response_id, response in enumerate(st.session_state.responses):
        # `status`, and error/output/last_prompt/code/feedback below, are
        # module-level names that upload_feedback() reads when it is called.
        status = show_custom_response(response)

        if response["role"] != "assistant":
            continue

        # Key widgets by list position rather than position // 2: the log is not
        # guaranteed to alternate user/assistant strictly, and colliding keys
        # would raise a duplicate-widget error.
        feedback_key = f"feedback_{response_id}"
        error = response.get("error", "No error information")
        output = response.get("content", "No output")
        last_prompt = response.get("last_prompt", "No prompt")
        code = response.get("gen_code", "No code generated")

        if "feedback" in response:
            saved = response["feedback"]
            if isinstance(saved, dict):
                # Show score and comment rather than the raw dict repr.
                saved = f"{saved.get('score', '')} {saved.get('text', '')}".strip()
            st.markdown(
                "<div class='feedback-section'>\n"
                f"<strong>π Your Feedback:</strong> {html.escape(str(saved))}\n"
                "</div>",
                unsafe_allow_html=True,
            )
            continue

        # st.button is True only on the run triggered by its own click. Remember
        # the chosen rating in session state; otherwise the comment box and the
        # Submit button disappear on the rerun their own click causes, and
        # feedback can never be submitted.
        score_key = f"{feedback_key}_score"
        up_col, down_col = st.columns(2)
        with up_col:
            if st.button("π Helpful", key=f"{feedback_key}_up", use_container_width=True):
                # Real emoji: the previous mojibake made both scores the same string.
                st.session_state[score_key] = "👍"
        with down_col:
            if st.button("π Not Helpful", key=f"{feedback_key}_down", use_container_width=True):
                st.session_state[score_key] = "👎"

        thumbs = st.session_state.get(score_key)
        if thumbs:
            comments = st.text_area(
                "π¬ Tell us more (optional):",
                key=f"{feedback_key}_comments",
                placeholder="What could be improved?",
            )
            feedback = {"score": thumbs, "text": comments}
            if st.button("π Submit Feedback", key=f"{feedback_key}_submit"):
                upload_feedback()
                st.session_state.responses[response_id]["feedback"] = feedback
                del st.session_state[score_key]
                st.rerun()


if st.session_state.get("processing"):
|
|
show_processing_indicator(
|
|
st.session_state.get("current_model", "Unknown"),
|
|
st.session_state.get("current_question", "Processing...")
|
|
)
|
|
|
|
|
|
# The chat box is always rendered; a quick-prompt button clicked this run
# takes precedence over anything typed.
prompt = st.chat_input("π¬ Ask me anything about air quality!", key="main_chat")
prompt = selected_prompt or prompt


if prompt and not st.session_state.get("processing"):
|
|
|
|
if "last_prompt" in st.session_state:
|
|
last_prompt = st.session_state["last_prompt"]
|
|
last_model_name = st.session_state.get("last_model_name", "")
|
|
if (prompt == last_prompt) and (model_name == last_model_name):
|
|
prompt = None
|
|
|
|
if prompt:
|
|
|
|
user_response = get_from_user(prompt)
|
|
st.session_state.responses.append(user_response)
|
|
|
|
|
|
st.session_state.processing = True
|
|
st.session_state.current_model = model_name
|
|
st.session_state.current_question = prompt
|
|
|
|
|
|
st.rerun()
|
|
|
|
|
|
if st.session_state.get("processing"):
|
|
prompt = st.session_state.get("current_question")
|
|
model_name = st.session_state.get("current_model")
|
|
|
|
try:
|
|
response = ask_question(model_name=model_name, question=prompt)
|
|
|
|
if not isinstance(response, dict):
|
|
response = {
|
|
"role": "assistant",
|
|
"content": "β Error: Invalid response format",
|
|
"gen_code": "",
|
|
"ex_code": "",
|
|
"last_prompt": prompt,
|
|
"error": "Invalid response format"
|
|
}
|
|
|
|
response.setdefault("role", "assistant")
|
|
response.setdefault("content", "No content generated")
|
|
response.setdefault("gen_code", "")
|
|
response.setdefault("ex_code", "")
|
|
response.setdefault("last_prompt", prompt)
|
|
response.setdefault("error", None)
|
|
|
|
except Exception as e:
|
|
response = {
|
|
"role": "assistant",
|
|
"content": f"Sorry, I encountered an error: {str(e)}",
|
|
"gen_code": "",
|
|
"ex_code": "",
|
|
"last_prompt": prompt,
|
|
"error": str(e)
|
|
}
|
|
|
|
st.session_state.responses.append(response)
|
|
st.session_state["last_prompt"] = prompt
|
|
st.session_state["last_model_name"] = model_name
|
|
st.session_state.processing = False
|
|
|
|
|
|
if "current_model" in st.session_state:
|
|
del st.session_state.current_model
|
|
if "current_question" in st.session_state:
|
|
del st.session_state.current_question
|
|
|
|
st.rerun()
|
|
|
|
|
|
# Keep the newest message in view. <script> tags passed to st.markdown are not
# executed, so the old `scrollToBottom();` call never ran. Run the scroll
# inside a zero-height components.html iframe instead, and reach the app's
# document through window.parent.
if st.session_state.responses:
    components.html(
        """
<script>
setTimeout(function () {
    const doc = window.parent.document;
    const selectors = [
        '.main-container',
        '[data-testid="stAppViewContainer"]',
        '[data-testid="stMain"]',
        'section.main'
    ];
    for (const selector of selectors) {
        const el = doc.querySelector(selector);
        if (el) {
            el.scrollTop = el.scrollHeight;
        }
    }
    window.parent.scrollTo(0, doc.body.scrollHeight);
}, 100);
</script>
""",
        height=0,
    )


# Sidebar footer: pointer to the VayuBuddy research paper.
_PAPER_HTML = "\n".join(
    [
        "",
        "<div class='contact-section'>",
        "<h4>π Paper on VayuBuddy</h4>",
        "<p>Learn more about VayuBuddy in our <a href='https://arxiv.org/abs/2411.12760' target='_blank'>Research Paper</a>.</p>",
        "</div>",
        "",
    ]
)

with st.sidebar:
    st.markdown("---")
    st.markdown(_PAPER_HTML, unsafe_allow_html=True)


st.markdown("""
|
|
<div style='text-align: center; margin-top: 3rem; padding: 2rem; background: rgba(255,255,255,0.1); border-radius: 15px;'>
|
|
<h3>π Together for Cleaner Air</h3>
|
|
<p>VayuBuddy - Empowering environmental awareness through AI</p>
|
|
<small>Β© 2024 IIT Gandhinagar Sustainability Lab</small>
|
|
</div>
|
|
""", unsafe_allow_html=True) |