import json
import os
from datetime import datetime
from datetime import timezone

import gspread
import joblib
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from google.oauth2.service_account import Credentials
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.text import tokenizer_from_json


def save_to_google_sheet(data):
    """Append one feedback record as a row of the "Sentiment Feedback Log" sheet.

    Service-account credentials are read from ``st.secrets["gcp_credentials"]``,
    which may be either a JSON string or a TOML table (exposed by Streamlit as
    a mapping).

    Args:
        data: Feedback record. Columns are written in the fixed order below;
            missing keys and ``None`` values become empty cells.

    Raises:
        Errors from the secrets lookup, JSON parsing, google-auth or gspread
        are not caught here; they propagate to the caller.
    """
    scopes = [
        # Sheets API v4 scope (replaces the legacy ".../feeds" scope).
        "https://www.googleapis.com/auth/spreadsheets",
        # Drive access lets gspread find the spreadsheet by its title.
        "https://www.googleapis.com/auth/drive",
    ]

    raw_creds = st.secrets["gcp_credentials"]
    # Build a plain dict copy so the key fix-up below never mutates st.secrets.
    creds_dict = json.loads(raw_creds) if isinstance(raw_creds, str) else dict(raw_creds)

    # A private key stored in secrets may contain literal "\n" sequences;
    # turn them back into real newlines so the PEM block parses.
    if "private_key" in creds_dict:
        creds_dict["private_key"] = creds_dict["private_key"].replace("\\n", "\n")

    creds = Credentials.from_service_account_info(creds_dict, scopes=scopes)
    client = gspread.authorize(creds)
    sheet = client.open("Sentiment Feedback Log").sheet1

    columns = (
        "timestamp",
        "username",
        "user_id",
        "text",
        "model_a",
        "model_b",
        "ensemble",
        "feedback",
    )
    row = ["" if data.get(column) is None else data.get(column) for column in columns]

    # RAW stores user-supplied text verbatim, so a review that starts with
    # "=" is never evaluated as a spreadsheet formula.
    sheet.append_row(row, value_input_option="RAW")


# Page configuration is applied before any other Streamlit element is rendered
# (older Streamlit versions require it to be the first st.* call).
st.set_page_config(page_title="Sentiment Model Comparison", layout="wide")

st.title("📊 Sentiment Classifier Comparison")


@st.cache_resource
def load_model_and_tokenizer(model_file, tokenizer_file, repo_id="Daksh0505/sentiment-model-imdb"):
    """Download a Keras model and its tokenizer from the Hugging Face Hub and load them.

    Wrapped in ``st.cache_resource`` so script reruns reuse the loaded objects
    instead of downloading and deserializing them again.

    Args:
        model_file: File name of the saved Keras model inside the repo.
        tokenizer_file: File name of the tokenizer JSON (as produced by
            ``Tokenizer.to_json``) inside the repo.
        repo_id: Hugging Face Hub repository holding both files.

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    tokenizer_path = hf_hub_download(repo_id=repo_id, filename=tokenizer_file)

    # Explicit UTF-8: the vocabulary can contain non-ASCII words, and the
    # platform default encoding (e.g. cp1252 on Windows) would mis-decode them.
    with open(tokenizer_path, "r", encoding="utf-8") as f:
        tokenizer = tokenizer_from_json(f.read())

    model = load_model(model_path)
    return model, tokenizer


# Two independently trained classifiers. Per the UI captions, A has ~6M
# parameters and a 50k-word vocabulary, B ~34M parameters and a 256k vocabulary.
try:
    model_a, tokenizer_a = load_model_and_tokenizer("sentiment_model_imdb_6.6M.keras", "tokenizer_50k.json")
    model_b, tokenizer_b = load_model_and_tokenizer("sentiment_model_imdb_34M.keras", "tokenizer_256k.json")
except Exception as err:  # download/deserialization errors come from several libraries
    st.error(f"Could not load the sentiment models: {err}")
    st.stop()

# Every review is padded/truncated to this many token ids before inference.
maxlen = 300

# Class names, indexed by position in each model's output probability vector.
labels = ["Negative", "Neutral", "Positive"]


def preprocess(text, tokenizer, max_length=None):
    """Turn one review into a padded token-id batch of size one.

    Args:
        text: Review text. ``None``/NaN (e.g. an empty CSV cell) is treated as
            an empty review; other non-string values are converted with ``str``.
        tokenizer: Fitted Keras tokenizer belonging to the target model.
        max_length: Output sequence length; defaults to the module-level
            ``maxlen``.

    Returns:
        Array of shape ``(1, max_length)`` with zeros appended after the tokens.
        Longer reviews lose tokens from the start (``pad_sequences`` defaults
        to ``truncating="pre"``); presumably this matches training, so confirm
        before changing it.
    """
    if not isinstance(text, str):
        text = "" if pd.isna(text) else str(text)

    length = maxlen if max_length is None else max_length
    sequences = tokenizer.texts_to_sequences([text.lower()])
    return pad_sequences(sequences, maxlen=length, padding='post')


def format_probs(probs, class_names=None):
    """Map each class name to its probability rendered as a percentage string.

    Args:
        probs: Sequence of class probabilities in model output order.
        class_names: Names for those positions; defaults to the module-level
            ``labels``. Pairs are formed with ``zip``, so any number of
            classes works and extra entries on either side are ignored.

    Returns:
        Dict such as ``{"Negative": "12.34%", ...}`` (two decimal places).
    """
    names = labels if class_names is None else class_names
    return {name: f"{p * 100:.2f}%" for name, p in zip(names, probs)}


# --- Single-review input ------------------------------------------------------
st.markdown("### 📝 Enter a review:")
# Streamlit warns about empty widget labels (accessibility); give the box a
# real label but hide it, since the heading above already describes it.
text = st.text_area("Review text", height=150, label_visibility="collapsed")

st.markdown("---")

# --- Bulk input -----------------------------------------------------------------
file = st.file_uploader("📁 Or upload a CSV file with a 'review' column for bulk analysis", type=["csv"])

# --- Optional user details (stored with feedback entries) ----------------------
user_name = st.text_input("🙋 Enter your name:")
user_id = st.text_input("🆔 Enter your email (optional):")


# Streamlit reruns this whole script on every widget interaction, and
# st.button() is True only on the one rerun right after its click. The latest
# single-review analysis is therefore kept in session state so the results,
# and the feedback that refers to them, survive later reruns (for example,
# choosing a feedback option).
if "last_analysis" not in st.session_state:
    st.session_state["last_analysis"] = None

pred_a = pred_b = ensemble_label = None

analyze_clicked = st.button("🔍 Analyze")

if analyze_clicked:
    if text.strip():
        st.session_state["last_analysis"] = {
            "text": text,
            "pred_a": model_a.predict(preprocess(text, tokenizer_a), verbose=0)[0],
            "pred_b": model_b.predict(preprocess(text, tokenizer_b), verbose=0)[0],
        }
    else:
        # A click without review text (e.g. bulk-only) must not leave an older
        # single-review result on screen or attached to later feedback.
        st.session_state["last_analysis"] = None
        if file is None:
            st.warning("Enter a review or upload a CSV file before clicking Analyze.")

analysis = st.session_state["last_analysis"]

if analysis is not None:
    pred_a = analysis["pred_a"]
    pred_b = analysis["pred_b"]
    # Unweighted mean of the two models' class probabilities.
    ensemble_pred = (pred_a + pred_b) / 2
    ensemble_label = labels[int(ensemble_pred.argmax())]

    if analysis["text"] != text:
        st.caption("Showing results for the previously analyzed review. Click Analyze to refresh them.")

    panels = (
        ("🔹 Model A", "🧠 6M Parameters | 📚 50k Vocab", pred_a, "✔ **Predicted:**"),
        ("🔸 Model B", "🧠 34M Parameters | 📚 256k Vocab", pred_b, "✔ **Predicted:**"),
        ("⚖️ Ensemble Average", "🧮 Averaged Output (A + B)", ensemble_pred, "✅ **Final Sentiment:**"),
    )
    for column, (title, caption, probs, verdict) in zip(st.columns(3), panels):
        with column:
            st.subheader(title)
            st.caption(caption)
            st.markdown(" | ".join(f"**{name}:** {share}" for name, share in format_probs(probs).items()))
            st.write(f"{verdict} _{labels[int(probs.argmax())]}_")

    st.markdown("### 📊 Confidence Comparison")
    # Index by class name so bars are labelled instead of 0/1/2; assumes each
    # model outputs exactly one probability per entry of `labels`.
    st.bar_chart(
        pd.DataFrame(
            {"Model A": pred_a, "Model B": pred_b, "Ensemble": ensemble_pred},
            index=labels,
        )
    )


def _run_bulk_analysis(uploaded):
    """Label every row of an uploaded CSV, then show and offer the results for download."""
    uploaded.seek(0)  # rewind in case the upload buffer was already read
    try:
        df = pd.read_csv(uploaded)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        st.error(f"Could not read the CSV file: {err}")
        return

    if 'review' not in df.columns:
        st.error("CSV must contain a 'review' column.")
        return
    if df.empty:
        st.warning("The CSV file contains no reviews.")
        return

    # Empty cells arrive as NaN (a float with no .lower()); treat them as empty text.
    reviews = df['review'].fillna("").astype(str)
    # One predict() call per model over the whole file instead of one per row.
    batch_a = np.vstack([preprocess(review, tokenizer_a) for review in reviews])
    batch_b = np.vstack([preprocess(review, tokenizer_b) for review in reviews])
    ensemble_probs = (model_a.predict(batch_a, verbose=0) + model_b.predict(batch_b, verbose=0)) / 2
    df['Predicted Sentiment'] = [labels[int(i)] for i in ensemble_probs.argmax(axis=1)]

    st.dataframe(df)
    st.download_button("📥 Download Results", df.to_csv(index=False), file_name="sentiment_predictions.csv")


if analyze_clicked and file is not None:
    _run_bulk_analysis(file)

with st.expander("ℹ️ Model Details"):
    st.markdown("""
- **Model A**: Smaller model, faster, trained on 50k vocab.
- **Model B**: Larger model, more accurate, trained on 256k vocab.
- Ensemble averages predictions from both.
""")

st.markdown("---")
st.markdown("### 💬 Feedback")
feedback = st.radio("Was the prediction helpful?", ["👍 Yes", "👎 No", "No comment"], horizontal=True)

# Log only on an explicit submit. The radio always has a value (it defaults
# to the first option), so keying off it logged a "Yes" on every rerun once
# any input field was filled in.
if st.button("Submit feedback"):
    if analysis is None and not (user_name.strip() or user_id.strip() or text.strip()):
        st.warning("Analyze a review or enter your name/email before sending feedback.")
    else:
        # Record the text the predictions belong to, not whatever is typed now.
        logged_text = analysis["text"] if analysis is not None else text
        feedback_data = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "username": user_name,
            "user_id": user_id,
            "text": logged_text or None,
            "model_a": labels[int(pred_a.argmax())] if pred_a is not None else None,
            "model_b": labels[int(pred_b.argmax())] if pred_b is not None else None,
            "ensemble": ensemble_label,
            "feedback": feedback if feedback != "No comment" else None,
        }

        saved = False

        log_path = "user_feedback.csv"
        try:
            # Append mode creates the file if needed; write the header only then.
            pd.DataFrame([feedback_data]).to_csv(
                log_path, mode="a", header=not os.path.exists(log_path), index=False
            )
            saved = True
        except OSError as err:
            st.warning(f"Could not write the local feedback log: {err}")

        try:
            save_to_google_sheet(feedback_data)
            saved = True
        except Exception as e:  # secrets/google-auth/gspread errors vary; report instead of crashing
            st.error(f'Error saving feedback to Google Sheets: {e}')

        # Confirm only after the feedback was actually stored somewhere.
        if saved:
            st.success("Thanks for your feedback! ✅")