jjmakes committed on
Commit
b3fb325
·
1 Parent(s): b62ed83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -33
app.py CHANGED
@@ -1,37 +1,78 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer
3
-
4
- # Define a list of pretrained models
5
- models = {
6
- "BERTweet": "finiteautomata/bertweet-base-sentiment-analysis",
7
- "roBERTa": "cardiffnlp/twitter-roberta-base-sentiment",
8
- "Distilbert": "bhadresh-savani/distilbert-base-uncased-emotion",
9
- "BERT (BONUS!: Item Review)": "nlptown/bert-base-multilingual-uncased-sentiment",
10
- }
11
-
12
- # Display a selection box for the user to choose a model
13
- selected_model = st.selectbox("Select a pretrained model", list(models.keys()))
14
-
15
- # roBERTa specific label map
16
- roberta_label_map = {"LABEL_0": "negative",
17
- "LABEL_1": "neutral", "LABEL_2": "positive"}
18
-
19
- # Load the selected model and tokenizer
20
- model_name = models[selected_model]
21
- tokenizer = AutoTokenizer.from_pretrained(model_name)
22
- sentiment_pipeline = pipeline(
23
- "sentiment-analysis", model=model_name, tokenizer=tokenizer)
24
-
25
- # Get user input and perform sentiment analysis
26
- text_input = st.text_input("Enter text for sentiment analysis",
27
- "Anish is a very awesome and dedicated TA! 🤗")
 
 
 
 
 
 
 
 
28
  submit_btn = st.button("Submit")
29
 
 
 
 
 
30
  if submit_btn and text_input:
31
- result = sentiment_pipeline(text_input)
32
- if selected_model == "roBERTa":
33
- st.write("Sentiment:", roberta_label_map[result[0]["label"]])
34
- st.write("Score:", result[0]["score"])
35
- else:
36
- st.write("Sentiment:", result[0]["label"])
37
- st.write("Score:", result[0]["score"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ import torch
4
+ import pandas as pd
5
+ import random
6
+
7
+ classifiers = ['toxic', 'severe_toxic', 'obscene',
8
+ 'threat', 'insult', 'identity_hate']
9
+
10
+
11
+ def reset_scores():
12
+ global scores_df
13
+ scores_df = pd.DataFrame(columns=['Comment'] + classifiers)
14
+
15
+
16
+ def get_score(model_base, text):
17
+ if model_base == "bert-base-cased":
18
+ model_dir = "./bert/_bert_model"
19
+ elif model_base == "distilbert-base-cased":
20
+ model_dir = "./distilbert/_distilbert_model"
21
+ else:
22
+ model_dir = "./roberta/_roberta_model"
23
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
24
+ tokenizer = AutoTokenizer.from_pretrained(model_base)
25
+ inputs = tokenizer.encode_plus(
26
+ text, max_length=512, truncation=True, padding=True, return_tensors='pt')
27
+ outputs = model(**inputs)
28
+ predictions = torch.sigmoid(outputs.logits)
29
+ return predictions
30
+
31
+
32
+ # Ask user for input, return scores
33
+ st.title("Toxic Comment Classifier")
34
+ text_input = st.text_input("Enter text for toxicity classification",
35
+ "I hope you die")
36
  submit_btn = st.button("Submit")
37
 
38
+ # Drop down menu for model selection, default is roberta
39
+ model_base = st.selectbox("Select a pretrained model",
40
+ ["roberta-base", "bert-base-cased", "distilbert-base-cased"])
41
+
42
  if submit_btn and text_input:
43
+ result = get_score(model_base, text_input)
44
+
45
+ df = pd.DataFrame([result[0].tolist()], columns=classifiers)
46
+ df = df.round(2) # Round the values to 2 decimal places
47
+ # Format the values as percentages
48
+ df = df.applymap(lambda x: '{:.0%}'.format(x))
49
+
50
+ st.table(df)
51
+
52
+ # Read the test dataset
53
+ test_df = pd.read_csv(
54
+ "./jigsaw-toxic-comment-classification-challenge/test.csv")
55
+
56
+ # Select 3 random comments from the test dataset
57
+ sample_df = test_df.sample(n=3)
58
+
59
+ # Create an empty DataFrame to store the scores
60
+ reset_scores()
61
+
62
+ # Calculate the scores for each comment and add them to the DataFrame
63
+ for index, row in sample_df.iterrows():
64
+ result = get_score(model_base, row['comment_text'])
65
+ scores = result[0].tolist()
66
+ scores_df.loc[len(scores_df)] = [row['comment_text']] + scores
67
+
68
+ # Round the values to 2 decimal places
69
+ scores_df = scores_df.round(2)
70
+
71
+
72
+ st.subheader("Toxicity Scores for Random Comments")
73
+ st.table(scores_df)
74
+
75
+ # Create a button to reset the scores
76
+ if st.button("Refresh Random Tweets"):
77
+ reset_scores()
78
+ st.success("New tweets have been loaded!")