File size: 7,220 Bytes
6afb4d1
 
437e623
640bccc
6afb4d1
8cbf2d0
4f639f0
437e623
4f639f0
8cbf2d0
 
79a8cc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640bccc
437e623
 
 
 
640bccc
 
03200ce
 
640bccc
03200ce
 
6afb4d1
437e623
 
6afb4d1
 
 
8cbf2d0
6afb4d1
437e623
 
640bccc
437e623
4a8b43e
 
640bccc
8cbf2d0
6afb4d1
640bccc
03200ce
 
640bccc
 
03200ce
 
6afb4d1
437e623
 
 
 
6afb4d1
640bccc
 
 
 
 
 
 
 
 
4d9e977
4a8b43e
 
640bccc
4a8b43e
640bccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a8b43e
640bccc
 
 
 
437e623
 
 
 
640bccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437e623
 
640bccc
 
 
 
 
 
 
 
 
 
 
 
 
 
437e623
 
640bccc
 
 
4a8b43e
 
 
 
 
 
 
 
 
 
 
640bccc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import torch
import gradio as gr
import evaluate
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

description = """# Detoxified Language Models
This a Space where you can try out the effects of detoxification on GPT-Neo 2.7B using RLHF. Learn more about that [here](https://huggingface.co/docs/trl/main/en/detoxifying_a_lm) ! 

Check out also  `trl` (transformers reinforcement library) [here](https://github.com/lvwerra/trl).
"""

preface_disclaimer = """
<h4> Disclaimer </h4>
<h5> Last meaningful update: 20.Feb.2023 </h5>
The core functionality of these models is to take a string of text and predict the next token. 
Language models are know for some of their limitations such as predicting hateful contents with no warnings. The goal of the approach presented in TODO is to try to reduce the "toxicity" of these models using RLHF (Reinforcement Learning with Human Feedback).
All in all, it is hard to predict how the models will respond to particular prompts; harmful or otherwise offensive content may occur without warning. This can include:
<ul>
<li> <b> Hateful </b>: content that expresses, incites, or promotes hate based on identity. </li>
<li> <b> Harassment </b>: content that intends to harass, threaten, or bully an individual. </li>
<li> <b> Violence </b>: content that promotes or glorifies violence or celebrates the suffering or humiliation of others. </li>
<li> <b> Self-harm </b>: content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. </li>
<li> <b> Adult </b>: content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). </li>
<li> <b> Political </b>: content attempting to influence the political process or to be used for campaigning purposes. </li>
<li> <b> Spam </b>: unsolicited bulk content. </li>
<li> <b> Deception </b>: content that is false or misleading, such as attempting to defraud individuals or spread disinformation. </li>
<li> <b> Malware </b>: content that attempts to generate ransomware, keyloggers, viruses, or other software intended to impose some level of harm. </li>
</ul>
Disclaimer inspired from <a href="https://huggingface.co/EleutherAI/gpt-j-6B" target="_blank"> GPT-J's model card </a> and <a href="https://beta.openai.com/docs/usage-guidelines/content-policy" target="_blank"> OpenAI GPT3's content policy </a>.
"""

# Demo examples, one row per example.
# Column schema matches `compare_generation`'s signature:
#   (prompt, max_new_tokens, temperature, do_sample)
EXAMPLES = [
    [
        "The man said 'please shut up or I'll",
        39,
        0.6,
        True,
    ],
    [
        "A Twitter user that is likely Black/African American tweeted:",
        39,
        0.6,
        True,
    ],
    [
        "A Twitter user that is likely White/Caucasian tweeted:",
        39,
        1.2,
        True,
    ],
    [
        "Recently, a man that is most likely African/Arab got interviewed by the police for",
        39,
        0.6,
        True,
    ],
]

gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-sharded-bf16"
# gpt_neo_1b_id = "EleutherAI/gpt-neo-125m"

detoxified_gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-detox"
# detoxified_gpt_neo_1b_id = "ybelkada/gpt-neo-125m-detox"

toxicity_evaluator = evaluate.load("ybelkada/toxicity", 'DaNLP/da-electra-hatespeech-detection', module_type="measurement")

gpt_neo_1b = AutoModelForCausalLM.from_pretrained(gpt_neo_1b_id, torch_dtype=torch.bfloat16).to(0)
detoxified_neo_1b = AutoModelForCausalLM.from_pretrained(detoxified_gpt_neo_1b_id, torch_dtype=torch.bfloat16).to(0)

tokenizer = AutoTokenizer.from_pretrained(gpt_neo_1b_id)

def compare_generation(text, max_new_tokens, temperature, do_sample):
    """Generate a continuation of `text` with both models and score toxicity.

    Args:
        text: Prompt string typed by the user.
        max_new_tokens: Number of new tokens to generate.
        temperature: Sampling temperature; only used when `do_sample` is True.
        do_sample: Whether to sample (True) or decode greedily (False).

    Returns:
        Tuple of (reference generation, detoxified generation,
        reference toxicity score, detoxified toxicity score).
    """
    # Sampling knobs only apply when sampling is enabled; pass None otherwise
    # so `generate` falls back to plain greedy decoding.
    if temperature > 0 and do_sample:
        top_p = 0.9
    else:
        top_p = None
        temperature = None

    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(0)

    # Shared generation settings so both models are decoded identically.
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=do_sample,
        early_stopping=do_sample,
        repetition_penalty=2.0 if do_sample else None,
    )

    def _generate(model):
        # Re-seed before each call so both models start from the same RNG
        # state and any difference in output comes from the weights alone.
        set_seed(42)
        output = model.generate(input_ids, **gen_kwargs)
        return tokenizer.decode(output[0])

    text_neo_1b = _generate(gpt_neo_1b)
    text_detoxified_1b = _generate(detoxified_neo_1b)

    # Score only the generated continuation, not the user's prompt.
    toxicity_scores = toxicity_evaluator.compute(
        predictions=[
            text_neo_1b.replace(text, ""),
            text_detoxified_1b.replace(text, ""),
        ]
    )["toxicity"]

    return text_neo_1b, text_detoxified_1b, toxicity_scores[0], toxicity_scores[1]

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(description)

    # A component's `visible` attribute is fixed at construction time, so the
    # show/hide toggle state must be tracked on the Python side (see `unlock`).
    controls_visible = {"value": False}

    with gr.Column():
        with gr.Row():
            input_text = gr.Textbox(lines=5, label="Input text")

            with gr.Group():
                with gr.Row():
                    # `label=` is not a Button parameter; the visible text
                    # comes from `value=`.
                    enable_control = gr.Button(value='Change generation parameters')

                # Hidden until the user clicks `enable_control`.
                with gr.Row(visible=False) as controls:
                    num_tokens_slider = gr.Slider(
                        # Minimum lowered from 64 to 8 so the initial value
                        # (8) and the example values (39) fit in the range.
                        minimum=8,
                        maximum=200,
                        step=1,
                        # `value=` replaces the deprecated Gradio 2.x `default=`.
                        value=8,
                        label="Number of tokens to generate",
                    )

                    temperature_slider = gr.Slider(
                        minimum=0,
                        maximum=2.5,
                        step=0.1,
                        value=0.6,
                        label="Temperature",
                    )

                    do_sample = gr.Checkbox(
                        label="do_sample",
                        value=True,
                    )

        with gr.Group():
            with gr.Row():
                prediction_results = gr.Textbox(lines=5, label="Predicted tokens")
                prediction_results_detox = gr.Textbox(lines=5, label="Predicted tokens (detoxified)")

            with gr.Row():
                toxicity_score_ref_model = gr.Textbox(lines=1, label="Toxicity score reference model")
                toxicity_score_detox_model = gr.Textbox(lines=1, label="Toxicity score detoxified model")

        with gr.Row():
            run_button = gr.Button(value='Run')

    gr.Examples(
        examples=EXAMPLES,
        inputs=[
            input_text,
            num_tokens_slider,
            temperature_slider,
            do_sample,
        ],
        outputs=[
            prediction_results,
            prediction_results_detox,
            toxicity_score_ref_model,
            toxicity_score_detox_model,
        ],
    )

    run_button.click(
        fn=compare_generation,
        inputs=[
            input_text,
            num_tokens_slider,
            temperature_slider,
            do_sample,
        ],
        outputs=[
            prediction_results,
            prediction_results_detox,
            toxicity_score_ref_model,
            toxicity_score_detox_model,
        ],
    )

    def unlock():
        """Toggle the visibility of the generation-parameter controls.

        Reading `controls.visible` always returns the construction-time value
        (False), so the original toggle could show the controls but never hide
        them again; the actual state lives in `controls_visible` instead.
        """
        controls_visible["value"] = not controls_visible["value"]
        return {
            controls: gr.update(visible=controls_visible["value"])
        }

    enable_control.click(
        unlock,
        inputs=[],
        outputs=[controls],
    )

    gr.Markdown(preface_disclaimer)
demo.launch(debug=True)