# src/streamlit_app.py
import os
import string

import joblib
import nltk
import numpy as np
import pandas as pd
import streamlit as st
from nltk import pos_tag, sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# NLTK reads the NLTK_DATA environment variable at import time, so in addition
# to setting it, append the directory to nltk.data.path for the module that is
# already imported.
NLTK_DATA_PATH = "/app/nltk_data"
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
os.environ["NLTK_DATA"] = NLTK_DATA_PATH
nltk.data.path.append(NLTK_DATA_PATH)

# Download the resources the cleaning pipeline needs: sentence/word tokenizers,
# the POS tagger, WordNet (for lemmatization), and the English stopword list.
for resource in (
    "punkt",
    "punkt_tab",
    "averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng",
    "wordnet",
    "stopwords",
):
    nltk.download(resource, download_dir=NLTK_DATA_PATH, quiet=True)
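# To avoid touching the network once the data is in place, each download above
# could be guarded with a lookup first (a sketch; the paths follow the standard
# NLTK resource layout):
#
#   try:
#       nltk.data.find("tokenizers/punkt")
#   except LookupError:
#       nltk.download("punkt", download_dir=NLTK_DATA_PATH, quiet=True)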
# === Cleaning Function ===
def sahi_karneka_function(x):
    """Clean question text: tokenize, drop punctuation and stopwords,
    keep nouns/verbs, and lemmatize them."""
    lem = WordNetLemmatizer()
    stop_words = set(stopwords.words("english"))

    # Tokenize sentence by sentence, lowercasing every word token.
    tokens = []
    for sentence in sent_tokenize(x):
        tokens.extend(word_tokenize(sentence.lower()))

    # Drop punctuation tokens and English stopwords.
    tokens = [t for t in tokens if t not in string.punctuation and t not in stop_words]

    # Keep only nouns (NN*) and verbs (V*), as tagged by the POS tagger.
    nouns = [word for word, tag in pos_tag(tokens) if tag.startswith("NN") or tag.startswith("V")]

    # Lemmatize each kept word with the matching part of speech, then rejoin.
    semi_final_words = [
        lem.lemmatize(m, pos="n") if tagg.startswith("NN") else lem.lemmatize(m, pos="v")
        for m, tagg in pos_tag(nouns)
    ]
    return " ".join(semi_final_words)
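# A quick illustration (hypothetical input, assuming the NLTK resources above):
#   sahi_karneka_function("How do I sort a list of dictionaries in Python?")
# drops the stopwords/punctuation, keeps the noun/verb tokens, and lemmatizes
# them, yielding something like "sort list dictionary python".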
# === Load Data and Models ===
df = pd.read_csv(r"src/c_d.csv")
model = joblib.load("src/logistic_models.pkl")
tfidf = joblib.load("src/tfidf.pkl")
ml = joblib.load("src/multilabels.pkl")
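# Note: Streamlit re-executes this script on every interaction, so the loads
# above run on each rerun. A cached variant (a sketch, assuming the same
# artifact paths) could look like:
#
#   @st.cache_resource
#   def load_artifacts():
#       return (
#           joblib.load("src/logistic_models.pkl"),
#           joblib.load("src/tfidf.pkl"),
#           joblib.load("src/multilabels.pkl"),
#       )
#
#   model, tfidf, ml = load_artifacts()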
# === Streamlit UI ===
st.title("🧠 Multi-Label Question Tag Predictor")
# --- Select a URL for context ---
selected_url = st.selectbox("Select a question URL (for context):", df['questions_url'])
st.markdown(f"πŸ”— [Open selected question]({selected_url})")
# --- Session State ---
if "user_input" not in st.session_state:
    st.session_state["user_input"] = ""
if "clear_input" not in st.session_state:
    st.session_state["clear_input"] = False

# --- Clear input if flagged (AFTER rerun) ---
# A widget's key cannot be written once the widget exists in the current run,
# so the clear is deferred to the top of the next rerun via this flag.
if st.session_state.clear_input:
    st.session_state.user_input = ""
    st.session_state.clear_input = False

# --- Input box ---
st.text_area("✍️ Type your question here:", key="user_input", height=150)
# --- Predict button ---
if st.button("Predict Tags"):
    final_question = st.session_state.user_input.strip()
    if not final_question:
        st.warning("⚠️ Please enter a question.")
    else:
        with st.spinner("🔍 Predicting tags..."):
            # Step 1: Clean input
            cleaned = sahi_karneka_function(final_question)

            # Step 2: TF-IDF (the vectorizer expects an iterable of documents)
            x_tfidf = tfidf.transform([cleaned])

            # Step 3: Predict. predict_proba returns one (n_samples, 2) array
            # per label; column 1 is the positive-class probability, so stack
            # those columns into shape (n_samples, n_labels) and threshold.
            y_probs = model.predict_proba(x_tfidf)
            threshold = 0.55
            probs_column1 = np.array([p[:, 1] for p in y_probs]).T
            y_pred = (probs_column1 >= threshold).astype(int)

            # Step 4: Decode the binary label matrix back into tag names.
            # (Decode the thresholded y_pred rather than model.predict(),
            # so the 0.55 threshold actually takes effect.)
            predicted_tags = ml.inverse_transform(y_pred)

        # Step 5: Display results
        if predicted_tags and predicted_tags[0]:
            st.success("✅ Predicted Tags:")
            for tag in predicted_tags[0]:
                st.markdown(f"🔹 **`{tag}`**")
        else:
            st.info("No tags matched the threshold.")
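# Optional refinement (a sketch, not part of the flow above): expose the
# decision threshold as a widget instead of hard-coding 0.55, e.g.
#   threshold = st.slider("Decision threshold", 0.05, 0.95, 0.55, 0.05)
# placed before the Predict button so the value is read on the same rerun.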
# Step 6: "Clear" button. Set the flag and rerun so the text_area is emptied
# at the top of the next run; writing st.session_state.user_input directly
# here would raise a StreamlitAPIException, since the widget already exists.
if st.button("Clear Input"):
    st.session_state.clear_input = True
    st.rerun()
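# To run locally (assuming the CSV and .pkl artifacts are present under src/):
#   streamlit run src/streamlit_app.py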