Rob Caamano
commited on
Commit
•
39e1615
1
Parent(s):
bf5dc75
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
-
from transformers import AutoTokenizer
|
4 |
from transformers import (
|
5 |
TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
|
6 |
)
|
@@ -21,6 +21,9 @@ mod_name = model_options[selected_model]
|
|
21 |
|
22 |
tokenizer = AutoTokenizer.from_pretrained(mod_name)
|
23 |
model = AutoModelForSequenceClassification.from_pretrained(mod_name)
|
|
|
|
|
|
|
24 |
|
25 |
if selected_model in ["Fine-tuned Toxicity Model"]:
|
26 |
toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
|
@@ -30,10 +33,10 @@ def get_toxicity_class(predictions, threshold=0.3):
|
|
30 |
return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
|
31 |
|
32 |
input = tokenizer(text, return_tensors="tf")
|
33 |
-
prediction = model(input)[0].numpy()[0]
|
34 |
|
35 |
if st.button("Submit", type="primary"):
|
36 |
-
|
|
|
37 |
|
38 |
tweet_portion = text[:50] + "..." if len(text) > 50 else text
|
39 |
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
+
from transformers import AutoTokenizer, pipeline
|
4 |
from transformers import (
|
5 |
TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
|
6 |
)
|
|
|
21 |
|
22 |
tokenizer = AutoTokenizer.from_pretrained(mod_name)
|
23 |
model = AutoModelForSequenceClassification.from_pretrained(mod_name)
|
24 |
+
clf = pipeline(
|
25 |
+
"sentiment-analysis", model=model, tokenizer=tokenizer, return_all_scores=True
|
26 |
+
)
|
27 |
|
28 |
if selected_model in ["Fine-tuned Toxicity Model"]:
|
29 |
toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
|
|
|
33 |
return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
|
34 |
|
35 |
input = tokenizer(text, return_tensors="tf")
|
|
|
36 |
|
37 |
if st.button("Submit", type="primary"):
|
38 |
+
results = dict(d.values() for d in clf(text)[0])
|
39 |
+
toxic_labels = {k: results[k] for k in results.keys() if not k == "toxic"}
|
40 |
|
41 |
tweet_portion = text[:50] + "..." if len(text) > 50 else text
|
42 |
|