File size: 1,782 Bytes
d7728be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63bc522
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import functools

import gradio as gr
import torch
from detoxify import Detoxify
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def run_model(model_choice, text):
    if model_choice == "original":
        results = Detoxify("original").predict(text)
    elif model_choice == "unbiased":
        results = Detoxify("unbiased").predict(text)
    elif model_choice == "multilingual":
        results = Detoxify("multilingual").predict(text)
    elif model_choice == "Toxic-BERT":
        tokenizer = AutoTokenizer.from_pretrained(
            "citizenlab/distilbert-base-multilingual-cased-toxicity"
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            "citizenlab/distilbert-base-multilingual-cased-toxicity"
        )
        # Perform inference with the Toxic-BERT model using tokenizer and model

        encoded_input = tokenizer(
            text, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        logits = model(**encoded_input).logits
        probabilities = torch.sigmoid(logits).detach().cpu().numpy().tolist()[0]
        results = {"toxic": probabilities[0], "non_toxic": probabilities[1]}
        # Convert the predicted_labels to your desired output format

    return results


# Options offered in the dropdown; each must be accepted by run_model.
model_choices = ["original", "unbiased", "multilingual", "Toxic-BERT"]

# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4; components are constructed from the top-level package.
input_textbox = gr.Textbox(label="Input Text")
# Pre-select a model so run_model never receives None as model_choice.
model_choice_dropdown = gr.Dropdown(
    choices=model_choices, value=model_choices[0], label="Model Choice"
)
# The result dict from run_model is rendered as text.
output_text = gr.Textbox(label="Output Tags")

iface = gr.Interface(
    fn=run_model,
    inputs=[model_choice_dropdown, input_textbox],
    outputs=output_text,
    title="Toxicity Detection App",
    description="Choose a model and input text to detect toxicity.",
)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()