miracle01 committed on
Commit d51980e
1 Parent(s): a1a2c83

Upload 21 files

.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ calm.wav filter=lfs diff=lfs merge=lfs -text
+ data/chew1.wav filter=lfs diff=lfs merge=lfs -text
+ data/tapping1.wav filter=lfs diff=lfs merge=lfs -text
+ data/theeStallion1.wav filter=lfs diff=lfs merge=lfs -text
+ data/trump1.wav filter=lfs diff=lfs merge=lfs -text
+ data/trump2.wav filter=lfs diff=lfs merge=lfs -text
+ hrv-breathing.gif filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ import whisper
+ import evaluate
+ from evaluate.utils import launch_gradio_widget
+ import gradio as gr
+ import torch
+ import pandas as pd
+ import random
+ import classify
+ import replace_explitives
+ from whisper.model import Whisper
+ from whisper.tokenizer import get_tokenizer
+ from speechbrain.pretrained.interfaces import foreign_class
+ from transformers import AutoModelForSequenceClassification, pipeline, WhisperTokenizer, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
+
+
+ # pull in emotion detection
+ # --- Add element for specification
+ # pull in text classification
+ # --- Add custom labels
+ # --- Associate labels with radio elements
+ # add logic to initiate mock notification when detected
+ # pull in misophonia-specific model
+
+ model_cache = {}
+
+ # Building prediction function for gradio
+ emo_dict = {
+     'sad': 'Sad',
+     'hap': 'Happy',
+     'ang': 'Anger',
+     'neu': 'Neutral'
+ }
+
+ # static classes for now, but it would be best to have the user select from multiple, and to enter their own
+ class_options = {
+     "Racism": ["racism", "hate speech", "bigotry", "racially targeted", "racial slur", "ethnic slur", "ethnic hate", "pro-white nationalism"],
+     "LGBTQ+ Hate": ["gay slur", "trans slur", "homophobic slur", "transphobia", "anti-LGBTQ+"],
+     "Sexually Explicit": ["sexually explicit", "sexually coercive", "sexual exploitation", "vulgar", "raunchy", "sexist", "sexually demeaning", "sexual violence", "victim blaming"],
+     "Pregnancy Complications": ["miscarriage", "child loss", "child death", "abortion", "pregnancy", "childbirth", "baby shower", "postpartum"],
+ }
+
+ pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large")
+
+ toxicity_module = evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")
+ emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+ def classify_emotion(audio):
+     #### Emotion classification ####
+     # EMO MODEL LINE emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
+     out_prob, score, index, text_lab = emotion_classifier.classify_file(audio)
+     return emo_dict[text_lab[0]]
+
+ def slider_logic(slider):
+     threshold = 0
+     if slider == 1:
+         threshold = .90
+     elif slider == 2:
+         threshold = .80
+     elif slider == 3:
+         threshold = .60
+     elif slider == 4:
+         threshold = .50
+     elif slider == 5:
+         threshold = .40
+     else:
+         threshold = .50  # fallback; the slider is constrained to 1-5, so this branch should not be reached
+     return threshold
+
+ # Create a Gradio interface with audio file and text inputs
+ def classify_toxicity(audio_file, classify_anxiety, emo_class, explitive_selection, slider):
+
+     # Transcribe the audio file using Whisper ASR
+     transcribed_text = pipe(audio_file)["text"]
+
+     ## SLIDER ##
+     threshold = slider_logic(slider)
+
+     #------- explitive call ---------------
+
+     if explitive_selection is not None and emo_class is None:
+         transcribed_text = replace_explitives.sub_explitives(transcribed_text, explitive_selection)
+
+     #### Toxicity Classifier ####
+
+     # TOX MODEL LINE toxicity_module = evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")
+     #toxicity_module = evaluate.load("toxicity", 'DaNLP/da-electra-hatespeech-detection', module_type="measurement")
+
+     toxicity_results = toxicity_module.compute(predictions=[transcribed_text])
+
+     toxicity_score = toxicity_results["toxicity"][0]
+     print(toxicity_score)
+
+     # emo call
+     if emo_class is not None:
+         classify_emotion(audio_file)
+
+     #### Text classification #####
+     if classify_anxiety is not None:
+         # DEVICE LINE device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+         # CLASSIFICATION LINE text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+         sequence_to_classify = transcribed_text
+         print(classify_anxiety, class_options)
+         candidate_labels = class_options.get(classify_anxiety, [])
+         # classification_output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
+         classification_output = text_classifier(sequence_to_classify, candidate_labels, multi_label=True)
+         print("class output ", type(classification_output))
+         # classification_df = pd.DataFrame.from_dict(classification_output)
+         print("keys ", classification_output.keys())
+
+         # formatted_classification_output = "\n".join([f"{key}: {value}" for key, value in classification_output.items()])
+         # label_score_pairs = [(label, score) for label, score in zip(classification_output['labels'], classification_output['scores'])]
+         label_score_dict = {label: score for label, score in zip(classification_output['labels'], classification_output['scores'])}
+         k = max(label_score_dict, key=label_score_dict.get)
+         print("k keys: ", k)
+         maxval = label_score_dict[k]
+         print("max value: ", maxval)
+         topScore = ""
+         affirm = ""
+         if maxval > threshold:
+             print("Toxic")
+             affirm = positive_affirmations()
+             topScore = maxval
+         else:
+             print("Not Toxic")
+             affirm = ""
+             topScore = maxval
+     else:
+         topScore = ""
+         affirm = ""
+         if toxicity_score > threshold:
+             affirm = positive_affirmations()
+             topScore = toxicity_score
+         else:
+             affirm = ""
+             topScore = toxicity_score
+         label_score_dict = {"toxicity" : toxicity_score}
+
+     return transcribed_text, topScore, label_score_dict, affirm
+     # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"
+
+ def positive_affirmations():
+     affirmations = [
+         "I have survived my anxiety before and I will survive again now",
+         "I am not in danger; I am just uncomfortable; this too will pass",
+         "I forgive and release the past and look forward to the future",
+         "I can't control what other people say but I can control my breathing and my response"
+     ]
+     selected_affirm = random.choice(affirmations)
+     return selected_affirm
+
+ with gr.Blocks() as iface:
+     show_state = gr.State([])
+     with gr.Column():
+         anxiety_class = gr.Radio(label="Specify Subclass", choices=["Racism", "LGBTQ+ Hate", "Sexually Explicit", "Pregnancy Complications"])
+         explit_preference = gr.Radio(choices=["N-Word", "B-Word", "All Explitives"], label="Words to omit from general anxiety classes", info="Certain words may be acceptable within certain contexts for given groups of people, and some people may be unbothered by expletives broadly speaking.")
+         emo_class = gr.Radio(choices=["negative emotionality"], label="Negative Emotionality", info="Select if you would like expletives to be considered anxiety-inducing in the case of anger/negative emotionality.")
+         sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
+     with gr.Column():
+         aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
+         submit_btn = gr.Button("Run")
+     with gr.Column():
+         out_text = gr.Textbox(label="Transcribed Audio")
+         out_val = gr.Textbox(label="Overall Toxicity")
+         out_affirm = gr.Textbox(label="Intervention")
+         out_class = gr.Label(label="Toxicity Class Breakdown")
+     submit_btn.click(fn=classify_toxicity, inputs=[aud_input, anxiety_class, emo_class, explit_preference, sense_slider], outputs=[out_text, out_val, out_class, out_affirm])
+
+ iface.launch()
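
For reference, a self-contained sketch of the zero-shot scoring and threshold step used inside classify_toxicity, with the same facebook/bart-large-mnli checkpoint and the "Racism" label set from class_options; the input sentence is a placeholder, not from this repo:

from transformers import pipeline

# Same zero-shot checkpoint app.py loads at module level
text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
candidate_labels = ["racism", "hate speech", "bigotry", "racially targeted",
                    "racial slur", "ethnic slur", "ethnic hate", "pro-white nationalism"]

output = text_classifier("placeholder transcript to score", candidate_labels, multi_label=True)
label_score_dict = dict(zip(output["labels"], output["scores"]))
top_label = max(label_score_dict, key=label_score_dict.get)
threshold = 0.60  # slider position 3 in slider_logic
print(top_label, label_score_dict[top_label], label_score_dict[top_label] > threshold)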
b_word.txt ADDED
@@ -0,0 +1,2 @@
+ bitch
+ bitches
calm.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bafc436a822d1e3670b457660087b3caf518e9d0d83d8c999bc642a5166f4b1
+ size 15916220
classify.py ADDED
@@ -0,0 +1,66 @@
+ from typing import List, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from whisper.audio import N_FRAMES, N_MELS, log_mel_spectrogram, pad_or_trim
+ from whisper.model import Whisper
+ from whisper.tokenizer import Tokenizer
+
+
+ @torch.no_grad()
+ def calculate_audio_features(audio_path: Optional[str], model: Whisper) -> torch.Tensor:
+     if audio_path is None:
+         segment = torch.zeros((N_MELS, N_FRAMES), dtype=torch.float32).to(model.device)
+     else:
+         mel = log_mel_spectrogram(audio_path)
+         segment = pad_or_trim(mel, N_FRAMES).to(model.device)
+     return model.embed_audio(segment.unsqueeze(0))
+
+
+ @torch.no_grad()
+ def calculate_average_logprobs(
+     model: Whisper,
+     audio_features: torch.Tensor,
+     class_names: List[str],
+     tokenizer: Tokenizer,
+ ) -> torch.Tensor:
+     initial_tokens = (
+         torch.tensor(tokenizer.sot_sequence_including_notimestamps).unsqueeze(0).to(model.device)
+     )
+     eot_token = torch.tensor([tokenizer.eot]).unsqueeze(0).to(model.device)
+
+     average_logprobs = torch.zeros(len(class_names))
+     for i, class_name in enumerate(class_names):
+         class_name_tokens = (
+             torch.tensor(tokenizer.encode(" " + class_name)).unsqueeze(0).to(model.device)
+         )
+         input_tokens = torch.cat([initial_tokens, class_name_tokens, eot_token], dim=1)
+
+         logits = model.logits(input_tokens, audio_features)  # (1, T, V)
+         logprobs = F.log_softmax(logits, dim=-1).squeeze(0)  # (T, V)
+         logprobs = logprobs[len(tokenizer.sot_sequence_including_notimestamps) - 1 : -1]  # (T', V)
+         logprobs = torch.gather(logprobs, dim=-1, index=class_name_tokens.view(-1, 1))  # (T', 1)
+         average_logprob = logprobs.mean().item()
+         average_logprobs[i] = average_logprob
+
+     return average_logprobs
+
+
+ def calculate_internal_lm_average_logprobs(
+     model: Whisper,
+     class_names: List[str],
+     tokenizer: Tokenizer,
+     verbose: bool = False,
+ ) -> torch.Tensor:
+     audio_features_from_empty_input = calculate_audio_features(None, model)
+     average_logprobs = calculate_average_logprobs(
+         model=model,
+         audio_features=audio_features_from_empty_input,
+         class_names=class_names,
+         tokenizer=tokenizer,
+     )
+     if verbose:
+         print("Internal LM average log probabilities for each class:")
+         for i, class_name in enumerate(class_names):
+             print(f"  {class_name}: {average_logprobs[i]:.3f}")
+     return average_logprobs
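
classify.py is imported by app.py but not yet wired into the interface. A minimal sketch of how its helpers could be driven; the "base" checkpoint, the trigger-class names, and the choice of sample file are illustrative assumptions, not part of this commit's code:

import whisper
from whisper.tokenizer import get_tokenizer

from classify import calculate_audio_features, calculate_average_logprobs

model = whisper.load_model("base")  # assumption: any openai-whisper checkpoint
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
class_names = ["chewing", "tapping", "music"]  # illustrative misophonia trigger labels

# Embed one of the sample files shipped in data/ and score each class name
audio_features = calculate_audio_features("data/chew1.wav", model)
avg_logprobs = calculate_average_logprobs(
    model=model,
    audio_features=audio_features,
    class_names=class_names,
    tokenizer=tokenizer,
)
print(class_names[avg_logprobs.argmax().item()])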
data/chew1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91ded13633e05d45c791366de676d09c9eb2f4b51979e211112e734d97cfdf08
+ size 5120104
data/clears_throat1.wav ADDED
Binary file (180 kB).
 
data/mouth_sounds1.wav ADDED
Binary file (446 kB).
 
data/pop1.wav ADDED
Binary file (89.4 kB).
 
data/sigh1.wav ADDED
Binary file (485 kB).
 
data/slurp1.wav ADDED
Binary file (596 kB).
 
data/tapping1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9506ae17e7176ef99a36a3f556f9f45837f804e77a6b0013a38470e73f8ed5e4
+ size 2958378
data/theeStallion1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d347f0917b4989ebe491bb96955d31afcf3b31e6480d43136a7ff3bffd8dd9da
+ size 2736184
data/trump1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3952dd160ff5031da8fab133d6d685bbcdc190bb34fa0ee7c75b7c7e5ff9a8ea
+ size 10700952
data/trump2.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2ee9ec1a06544428bf8c0f66eda962533e7e70467a0b30492dfbfd1d30c0981
+ size 7329432
expletives.txt ADDED
@@ -0,0 +1,15 @@
+ shit
+ fuck
+ fucked
+ damn
+ damned
+ goddamn
+ goddamned
+ crap
+ crapped
+ ass
+ asshole
+ bastard
+ bastards
+ piss
+ pissed
hrv-breathing.gif ADDED

Git LFS Details

  • SHA256: 21006becdbb899ab341a246f42b415b1e2d391589e6eca1693cc8992b9857845
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
n_word.txt ADDED
@@ -0,0 +1,9 @@
+ nigga
+ niggas
+ nigg
+ nig
+ niggs
+ nigs
+ nigger
+ niggers
+ niggaz
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
replace_explitives.py ADDED
@@ -0,0 +1,44 @@
+ import regex as re
+ import nltk
+
+ def load_words_from_file(file_path):
+     with open(file_path, "r", encoding="utf-8") as f:
+         words = [line.strip() for line in f.readlines()]
+     return words
+
+ def sub_explitives(textfile, selection):
+
+     replacetext = "person"
+
+     # Load target words from text files
+     b_word_list = load_words_from_file("b_word.txt")
+     n_word_list = load_words_from_file("n_word.txt")
+     expletives_list = load_words_from_file("expletives.txt")
+
+     # text = word_tokenize(textfile)
+     # print(text)
+     # sentences = sent_tokenize(textfile)
+
+     if selection == "B-Word":
+         target_word = b_word_list
+     elif selection == "N-Word":
+         target_word = n_word_list
+     elif selection == "All Explitives":
+         target_word = expletives_list
+     else:
+         target_word = []
+
+     print("selection:", selection, "target_word:", target_word)
+     lines = textfile.split('\n')
+
+     if target_word:
+         print("target word was found, ", target_word)
+         print(textfile)
+         for i, line in enumerate(lines):
+             for word in target_word:
+                 pattern = r"\b" + re.escape(word) + r"\b"
+                 # textfile = re.sub(target_word, replacetext, textfile, flags=re.IGNORECASE)
+                 lines[i] = re.sub(pattern, replacetext, lines[i], flags=re.IGNORECASE)
+
+     textfile = '\n'.join(lines)
+     return textfile
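
A quick sketch of calling sub_explitives directly, assuming b_word.txt, n_word.txt, and expletives.txt sit in the working directory, as they do in this repo; the input sentence is a made-up example:

import replace_explitives

# "damn" appears in expletives.txt, so it is swapped for the placeholder "person"
cleaned = replace_explitives.sub_explitives("That damn song is playing again", "All Explitives")
print(cleaned)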
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ git+https://github.com/huggingface/evaluate@775555d80af30d83dc6e9f42051840d29a34f31b
+ git+https://github.com/openai/whisper.git
+ transformers
+ torch
+ speechbrain
+ torchaudio
+ git+https://github.com/openai/whisper.git
+ tqdm
+ gradio==3.14.0
+ regex
+ nltk
toxicity.py ADDED
@@ -0,0 +1,141 @@
+ # Copyright 2020 The HuggingFace Evaluate Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ Toxicity detection measurement. """
+
+ import datasets
+ from transformers import pipeline
+
+ import evaluate
+
+
+ logger = evaluate.logging.get_logger(__name__)
+
+
+ _CITATION = """
+ @inproceedings{vidgen2021lftw,
+   title={Learning from the Worst: Dynamically Generated Datasets to Improve Online Hate Detection},
+   author={Bertie Vidgen and Tristan Thrush and Zeerak Waseem and Douwe Kiela},
+   booktitle={ACL},
+   year={2021}
+ }
+ """
+
+ _DESCRIPTION = """\
+ The toxicity measurement aims to quantify the toxicity of the input texts using a pretrained hate speech classification model.
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Compute the toxicity of the input sentences.
+
+ Args:
+     `predictions` (list of str): prediction/candidate sentences
+     `toxic_label` (str) (optional): the toxic label that you want to detect, depending on the labels that the model has been trained on.
+         This can be found using the `id2label` function, e.g.:
+             model = AutoModelForSequenceClassification.from_pretrained("DaNLP/da-electra-hatespeech-detection")
+             print(model.config.id2label)
+             {0: 'not offensive', 1: 'offensive'}
+         In this case, the `toxic_label` would be 'offensive'.
+     `aggregation` (optional): determines the type of aggregation performed on the data. If set to `None`, the scores for each prediction are returned.
+         Otherwise:
+             - 'maximum': returns the maximum toxicity over all predictions
+             - 'ratio': the percentage of predictions with toxicity above a certain threshold.
+     `threshold`: (int) (optional): the toxicity detection threshold used for calculating the 'ratio' aggregation, described above.
+         The default threshold is 0.5, based on the one established by [RealToxicityPrompts](https://arxiv.org/abs/2009.11462).
+
+ Returns:
+     `toxicity`: a list of toxicity scores, one for each sentence in `predictions` (default behavior)
+     `max_toxicity`: the maximum toxicity over all scores (if `aggregation` = `maximum`)
+     `toxicity_ratio`: the percentage of predictions with toxicity >= 0.5 (if `aggregation` = `ratio`)
+
+ Examples:
+
+     Example 1 (default behavior):
+         >>> toxicity = evaluate.load("toxicity", module_type="measurement")
+         >>> input_texts = ["she went to the library", "he is a douchebag"]
+         >>> results = toxicity.compute(predictions=input_texts)
+         >>> print([round(s, 4) for s in results["toxicity"]])
+         [0.0002, 0.8564]
+
+     Example 2 (returns ratio of toxic sentences):
+         >>> toxicity = evaluate.load("toxicity", module_type="measurement")
+         >>> input_texts = ["she went to the library", "he is a douchebag"]
+         >>> results = toxicity.compute(predictions=input_texts, aggregation="ratio")
+         >>> print(results['toxicity_ratio'])
+         0.5
+
+     Example 3 (returns the maximum toxicity score):
+         >>> toxicity = evaluate.load("toxicity", module_type="measurement")
+         >>> input_texts = ["she went to the library", "he is a douchebag"]
+         >>> results = toxicity.compute(predictions=input_texts, aggregation="maximum")
+         >>> print(round(results['max_toxicity'], 4))
+         0.8564
+
+     Example 4 (uses a custom model):
+         >>> toxicity = evaluate.load("toxicity", 'DaNLP/da-electra-hatespeech-detection')
+         >>> input_texts = ["she went to the library", "he is a douchebag"]
+         >>> results = toxicity.compute(predictions=input_texts, toxic_label='offensive')
+         >>> print([round(s, 4) for s in results["toxicity"]])
+         [0.0176, 0.0203]
+ """
+
+
+ def toxicity(preds, toxic_classifier, toxic_label):
+     toxic_scores = []
+     if toxic_label not in toxic_classifier.model.config.id2label.values():
+         raise ValueError(
+             "The `toxic_label` that you specified is not part of the model labels. Run `model.config.id2label` to see what labels your model outputs."
+         )
+
+     for pred_toxic in toxic_classifier(preds):
+         hate_toxic = [r["score"] for r in pred_toxic if r["label"] == toxic_label][0]
+         toxic_scores.append(hate_toxic)
+     return toxic_scores
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class Toxicity(evaluate.Measurement):
+     def _info(self):
+         return evaluate.MeasurementInfo(
+             module_type="measurement",
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "predictions": datasets.Value("string", id="sequence"),
+                 }
+             ),
+             codebase_urls=[],
+             reference_urls=[],
+         )
+
+     def _download_and_prepare(self, dl_manager):
+         if self.config_name == "default":
+             logger.warning("Using default facebook/roberta-hate-speech-dynabench-r4-target checkpoint")
+             model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
+         else:
+             model_name = self.config_name
+         self.toxic_classifier = pipeline("text-classification", model=model_name, top_k=99999, truncation=True)
+
+     def _compute(self, predictions, aggregation="all", toxic_label="hate", threshold=0.5):
+         scores = toxicity(predictions, self.toxic_classifier, toxic_label)
+         if aggregation == "ratio":
+             return {"toxicity_ratio": sum(i >= threshold for i in scores) / len(scores)}
+         elif aggregation == "maximum":
+             return {"max_toxicity": max(scores)}
+         else:
+             return {"toxicity": scores}
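
Because the commit ships its own copy of the measurement script, it can presumably be loaded from the local path rather than by name from the Hub; a sketch under that assumption (the example sentence mirrors the docstring above):

import evaluate

# Assumption: evaluate.load resolves a local path to this measurement script
toxicity = evaluate.load("./toxicity.py", module_type="measurement")
results = toxicity.compute(predictions=["she went to the library"], aggregation="maximum")
print(results["max_toxicity"])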