Rob Caamano committed
Commit 39e1615
1 Parent(s): bf5dc75

Update app.py

Files changed (1)
  1. app.py +6 -3
app.py CHANGED
@@ -1,6 +1,6 @@
 import streamlit as st
 import pandas as pd
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, pipeline
 from transformers import (
     TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
 )
@@ -21,6 +21,9 @@ mod_name = model_options[selected_model]
 
 tokenizer = AutoTokenizer.from_pretrained(mod_name)
 model = AutoModelForSequenceClassification.from_pretrained(mod_name)
+clf = pipeline(
+    "sentiment-analysis", model=model, tokenizer=tokenizer, return_all_scores=True
+)
 
 if selected_model in ["Fine-tuned Toxicity Model"]:
     toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
@@ -30,10 +33,10 @@ def get_toxicity_class(predictions, threshold=0.3):
     return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
 
 input = tokenizer(text, return_tensors="tf")
-prediction = model(input)[0].numpy()[0]
 
 if st.button("Submit", type="primary"):
-    toxic_labels = get_toxicity_class(prediction)
+    results = dict(d.values() for d in clf(text)[0])
+    toxic_labels = {k: results[k] for k in results.keys() if not k == "toxic"}
 
     tweet_portion = text[:50] + "..." if len(text) > 50 else text
 
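
Note on the change above: with return_all_scores=True the pipeline returns, for each input text, a list of {"label": ..., "score": ...} dicts (one per class), which is what dict(d.values() for d in clf(text)[0]) collapses into a label-to-score mapping. A minimal sketch of that pattern follows, assuming a placeholder public checkpoint and a made-up input string, since the Space's actual mod_name options are outside these hunks:

# Sketch only: the checkpoint name and input text are placeholders,
# not the Space's actual model or data.
from transformers import pipeline

clf = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    return_all_scores=True,
)

raw = clf("great example text")[0]
# raw is a list with one dict per class, e.g.
# [{'label': 'NEGATIVE', 'score': 0.0004}, {'label': 'POSITIVE', 'score': 0.9996}]

# Each dict yields its values in insertion order (label first, then score),
# so dict() over those pairs builds {'NEGATIVE': 0.0004, 'POSITIVE': 0.9996}.
results = dict(d.values() for d in raw)
print(results)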