import streamlit as st
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
import torch
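# Usage note: launch this app with `streamlit run <this file>`; Streamlit apps
# are not meant to be started with plain `python`.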

# Load the pre-trained BERT model and tokenizer once and cache them, since
# Streamlit re-executes this script on every interaction
@st.cache_resource
def load_model():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()
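# Caveat: bert-base-uncased ships without a fine-tuned classification head, so
# the six scores below are effectively random; for meaningful toxicity
# predictions, substitute a checkpoint fine-tuned on these labels.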

# Classify text with the BERT model, returning per-label probabilities
def classify_text(text):
    # Tokenize the input, truncating to BERT's 512-token limit
    input_ids = tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=512)
    # Convert the token ids to a batch-of-one tensor
    input_tensor = torch.tensor([input_ids])
    # Run the model without tracking gradients
    with torch.no_grad():
        logits = model(input_tensor).logits
    # Sigmoid gives independent per-label probabilities (multi-label setup)
    predicted_labels = torch.sigmoid(logits).numpy()
    return predicted_labels
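# Note: classify_text returns a NumPy array of shape (1, 6), one sigmoid
# probability per toxicity label for the single input text.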

# Toxicity categories predicted by the model (one output per label)
LABELS = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']

# Keep the results table in session state so it survives Streamlit reruns;
# a plain module-level DataFrame would be recreated on every interaction
if 'results_df' not in st.session_state:
    st.session_state.results_df = pd.DataFrame(columns=['Text'] + LABELS)

# Streamlit app
def app():
    st.title("Toxicity Classification App")
    st.write("Enter text below to classify its toxicity.")
    
    # User input
    user_input = st.text_area("Enter text here:", "", key='user_input')
    
    # Classification
    if st.button("Classify"):
        if not user_input.strip():
            st.warning("Please enter some text to classify.")
        else:
            # Take the single row of per-label probabilities
            labels = classify_text(user_input)[0]
            st.write("Classification Results:")
            for name, score in zip(LABELS, labels):
                st.write("{}: {:.2%}".format(name, score))
            # Append this result to the session-scoped results table
            df = st.session_state.results_df
            df.loc[len(df)] = [user_input] + list(labels)
    
    # Show the accumulated results
    st.write("Classification Results DataFrame:")
    st.write(st.session_state.results_df)

# Run the app
if __name__ == "__main__":
    app()