Update app.py
Browse files
@@ -2,6 +2,10 @@ import streamlit as st
2 |
import numpy as np
3 |
import matplotlib.pyplot as plt
4 |
from sklearn.metrics import precision_recall_curve, auc
5 |
6 |
# Sidebar navigation
7 |
st.sidebar.title("App Navigation")
@@ -13,79 +17,77 @@ if page == "Sentiment Analysis":
13 |
st.title("Twitter Sentiment Analysis App")
14 |
15 |
# Load sentiment analysis pipeline
16 |
17 |
18 |
19 |
# Input box for user to enter a tweet
20 |
user_input = st.text_input("Enter a tweet to analyze:")
21 |
22 |
if user_input:
23 |
24 |
25 |
26 |
27 |
# Model Evaluation Page
28 |
elif page == "Model Evaluation":
29 |
st.title("Model Precision-Recall Evaluation")
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
# Display the plot
87 |
88 |
except Exception as e:
89 |
st.error(f"An error occurred while generating the PR curve: {e}")
90 |
91 |
st.info("Please select a model and ensure it generates valid data.")
2 |
import numpy as np
3 |
import matplotlib.pyplot as plt
4 |
from sklearn.metrics import precision_recall_curve, auc
5 |
from datasets import load_dataset
6 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
7 |
import torch
8 |
from tqdm import tqdm
9 |
10 |
# Sidebar navigation
11 |
st.sidebar.title("App Navigation")
17 |
st.title("Twitter Sentiment Analysis App")
18 |
19 |
# Load the tokenizer and classifier used by the single-tweet sentiment page.
# NOTE(review): this runs on every Streamlit rerun; consider caching the
# loaded objects if start-up latency becomes a problem.
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")

# Input box for user to enter a tweet
user_input = st.text_input("Enter a tweet to analyze:")

if user_input:
    # Tokenize the tweet and run a forward pass without tracking gradients.
    inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

    # The checkpoint is a three-way classifier (negative / neutral / positive
    # per its model card), so comparing only indices 0 and 1 would mislabel
    # tweets. Pick the highest-probability class and take its name from the
    # model config instead of hard-coding index positions.
    predicted_id = int(torch.argmax(probs[0]))
    sentiment = str(model.config.id2label[predicted_id]).upper()
    st.write(f"Sentiment: {sentiment}")
    st.write(f"Scores: {probs[0].numpy()}")
36 |
37 |
# Model Evaluation Page
38 |
elif page == "Model Evaluation":
39 |
st.title("Model Precision-Recall Evaluation")
40 |
41 |
# Let the user pick which tweet_eval task and split to evaluate.
dataset_name = "cardiffnlp/tweet_eval"
task = st.selectbox("Choose a dataset task:", ["emoji", "sentiment"])
split = st.selectbox("Choose data split:", ["train", "validation", "test"])

# Load dataset
with st.spinner("Loading dataset..."):
    dataset = load_dataset(dataset_name, task, split=split)

st.write(f"Loaded {len(dataset)} samples from {dataset_name} ({task}/{split}).")

# Load the checkpoint whose name matches the selected task.
model_name = f"cardiffnlp/twitter-roberta-base-{task}"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Batch predict on dataset
batch_size = 16
predicted_probs = []  # one list of per-class probabilities per sample
true_labels = dataset["label"]
texts = dataset["text"]

with st.spinner("Running model predictions..."):
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)
        # Keep every row of this batch so the scores stay aligned one-to-one
        # with true_labels; without this the score list would remain empty.
        predicted_probs.extend(probs.tolist())
72 |
# Select a class for PR Curve
num_classes = model.config.num_labels
class_to_evaluate = st.selectbox("Choose a class to evaluate:", list(range(num_classes)))

# One-vs-rest framing: a sample is positive when its gold label equals the
# selected class, and its score is the probability predicted for that class.
y_true = [1 if label == class_to_evaluate else 0 for label in true_labels]
y_score = [class_probs[class_to_evaluate] for class_probs in predicted_probs]

precision, recall, _ = precision_recall_curve(y_true, y_score)
pr_auc = auc(recall, precision)

# Plot PR Curve
fig, ax = plt.subplots()
ax.plot(recall, precision, label=f"PR Curve (AUC = {pr_auc:.2f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(f"Precision-Recall Curve for Class {class_to_evaluate}")
ax.legend()  # the curve carries a label, so show it

# Display the plot, then close the figure so repeated reruns do not
# accumulate open matplotlib figures.
st.pyplot(fig)
plt.close(fig)

st.success(f"Precision-Recall AUC: {pr_auc:.2f}")