ynp3 committed on
Commit
36182b8
1 Parent(s): 5704371

Update app.py

Files changed (1)
  1. app.py +39 -54
app.py CHANGED
@@ -1,71 +1,56 @@
  import streamlit as st
  import pandas as pd
- import torch
  from transformers import BertTokenizer, BertForSequenceClassification

  # Load pre-trained BERT model and tokenizer
- MODEL_NAME = 'bert-base-uncased'
- tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
- model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=6)
  model.eval()

- # Create DataFrame to store classification results
- df_results = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
-
  def classify_text(text):
-     # Tokenize text
-     tokens = tokenizer.encode_plus(
-         text,
-         max_length=512,
-         truncation=True,
-         padding=True,
-         return_attention_mask=True,
-         return_tensors='pt'
-     )
-
-     # Get model's predictions
      with torch.no_grad():
-         outputs = model(**tokens)
-     logits = outputs.logits
-     probabilities = torch.softmax(logits, dim=1).tolist()[0]
-
-     # Extract predicted labels
-     threshold = 0.5
-     labels = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']
-     predicted_labels = [labels[i] for i, prob in enumerate(probabilities) if prob > threshold]
-
      return predicted_labels

  # Streamlit app
- def main():
-     st.title('Toxicity Classification')

      # User input
-     text = st.text_area('Enter text:', max_chars=512)

-     # Perform classification
-     if st.button('Classify'):
-         predicted_labels = classify_text(text)
-         st.write('Predicted Labels:', predicted_labels)
-
-     # Allow user to add classification results to DataFrame
-     if st.button('Add to Results'):
-         global df_results
-         df_results = df_results.append({
-             'Text': text,
-             'Toxic': 'Toxic' in predicted_labels,
-             'Severe Toxic': 'Severe Toxic' in predicted_labels,
-             'Obscene': 'Obscene' in predicted_labels,
-             'Threat': 'Threat' in predicted_labels,
-             'Insult': 'Insult' in predicted_labels,
-             'Identity Hate': 'Identity Hate' in predicted_labels
-         }, ignore_index=True)
-         st.success('Classification results added to DataFrame.')

-     # Show DataFrame with classification results
-     if not df_results.empty:
-         st.subheader('Classification Results')
-         st.dataframe(df_results)

- if __name__ == '__main__':
-     main()
  import streamlit as st
  import pandas as pd
  from transformers import BertTokenizer, BertForSequenceClassification
+ import torch

  # Load pre-trained BERT model and tokenizer
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
  model.eval()

+ # Function to classify text using the pre-trained BERT model
  def classify_text(text):
+     # Tokenize input text
+     input_ids = tokenizer.encode(text, add_special_tokens=True)
+     # Convert tokenized input to tensor
+     input_tensor = torch.tensor([input_ids])
+     # Get model predictions
      with torch.no_grad():
+         logits = model(input_tensor)[0]
+     # Get predicted labels
+     predicted_labels = torch.sigmoid(logits).numpy()
      return predicted_labels

+ # Create a persistent DataFrame to store classification results
+ results_df = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
+
  # Streamlit app
+ def app():
+     st.title("Toxicity Classification App")
+     st.write("Enter text below to classify its toxicity.")

      # User input
+     user_input = st.text_area("Enter text here:", "", key='user_input')

+     # Classification
+     if st.button("Classify"):
+         # Perform classification
+         labels = classify_text(user_input)
+         # Print classification results
+         st.write("Classification Results:")
+         st.write("Toxic: {:.2%}".format(labels[0][0]))
+         st.write("Severe Toxic: {:.2%}".format(labels[0][1]))
+         st.write("Obscene: {:.2%}".format(labels[0][2]))
+         st.write("Threat: {:.2%}".format(labels[0][3]))
+         st.write("Insult: {:.2%}".format(labels[0][4]))
+         st.write("Identity Hate: {:.2%}".format(labels[0][5]))
+         # Add results to persistent DataFrame
+         results_df.loc[len(results_df)] = [user_input, labels[0][0], labels[0][1], labels[0][2], labels[0][3], labels[0][4], labels[0][5]]

+     # Show results DataFrame
+     st.write("Classification Results DataFrame:")
+     st.write(results_df)

+ # Run the app
+ if __name__ == "__main__":
+     app()
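
The rewritten classify_text now returns the raw per-label sigmoid probabilities (a 1×6 array) rather than a list of label names. If label strings like the old version produced are still wanted, a small helper along these lines could be layered on top; the 0.5 threshold and the labels_above_threshold name are illustrative assumptions, not part of this commit:

LABELS = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']

def labels_above_threshold(probs, threshold=0.5):
    # probs is the (1, 6) array returned by classify_text; read its first row
    return [label for label, p in zip(LABELS, probs[0]) if p > threshold]

# Example:
#     probs = classify_text("some comment")
#     labels_above_threshold(probs)   # e.g. ['Toxic', 'Insult']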
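
One caveat with the module-level results_df: Streamlit re-executes the whole script on every interaction, so the DataFrame is recreated empty on each rerun and rows added via results_df.loc[...] do not survive the next button click. A common pattern is to keep the table in st.session_state; the sketch below shows that idea with the same column layout and is likewise not part of this commit:

COLUMNS = ['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']

# Initialise the results table once per browser session
if 'results_df' not in st.session_state:
    st.session_state['results_df'] = pd.DataFrame(columns=COLUMNS)

def add_result(text, probs):
    # probs is the (1, 6) array returned by classify_text
    df = st.session_state['results_df']
    df.loc[len(df)] = [text, *probs[0]]

# Inside app(), after classification:
#     add_result(user_input, labels)
#     st.write(st.session_state['results_df'])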