APJ23 committed
Commit ab4de6e
1 Parent(s): ac15df8

Create app.py

Files changed (1):
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Define the available models to choose from
models = {
    'BERT': 'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DistilBERT': 'distilbert-base-uncased'
}

# Define the classes and their corresponding labels
classes = {
    0: 'Non-Toxic',
    1: 'Toxic',
    2: 'Severely Toxic',
    3: 'Obscene',
    4: 'Threat',
    5: 'Insult',
    6: 'Identity Hate'
}

# Create a drop-down menu to select the model
model_name = st.sidebar.selectbox('Select Model', list(models.keys()))

# Load the tokenizer and model once per checkpoint and reuse them across reruns.
# st.cache_resource is used here instead of caching the prediction function,
# because model and tokenizer objects are not hashable as cache-key arguments.
# Note: these are base checkpoints, so the 7-way classification head is
# randomly initialized; outputs are not meaningful until the head is fine-tuned
# on toxicity data.
@st.cache_resource
def load_model(checkpoint):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, num_labels=len(classes)
    )
    model.eval()
    return tokenizer, model

tokenizer, model = load_model(models[model_name])

# Generate the toxicity prediction for a single tweet
def predict_toxicity(tweet, model, tokenizer):
    # Preprocess the text
    inputs = tokenizer(tweet, padding=True, truncation=True, return_tensors='pt')
    # Get the predictions from the model without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=1).numpy()
    # Get the class with the highest probability
    predicted_class = int(predictions.argmax())
    predicted_class_label = classes[predicted_class]
    predicted_prob = float(predictions[0][predicted_class])
    return predicted_class_label, predicted_prob

# Build a table of toxicity predictions
def create_table(predictions):
    data = {'Tweet': [], 'Highest Toxicity Class': [], 'Probability': []}
    for tweet, prediction in predictions.items():
        data['Tweet'].append(tweet)
        data['Highest Toxicity Class'].append(prediction[0])
        data['Probability'].append(prediction[1])
    return pd.DataFrame(data)

# Create the user interface
st.title('Toxicity Prediction App')
tweet_input = st.text_input('Enter a tweet:')
if st.button('Predict') and tweet_input:
    # Generate the toxicity prediction for the tweet using the selected model
    predicted_class_label, predicted_prob = predict_toxicity(tweet_input, model, tokenizer)
    st.write(f'Prediction: {predicted_class_label} ({predicted_prob:.2f})')

    # Display the toxicity prediction in a table
    predictions = {tweet_input: (predicted_class_label, predicted_prob)}
    st.table(create_table(predictions))
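
Usage sketch (assuming streamlit, torch, and transformers are installed in the environment):

    streamlit run app.py

Note that the checkpoints listed above are base models whose classification heads start untrained, so the displayed probabilities are placeholders until a toxicity fine-tuned checkpoint is substituted or the head is fine-tuned.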