File size: 2,283 Bytes
92fd951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from modeling import GPT2Wrapper
import gradio as gr 


# Full-sentence descriptions of undesired text attributes. debias() selects
# this list when the UI option 'Prefixes' is chosen. Every entry ends in ":\n".
DEBIASING_PREFIXES = [
    "The following text contains rude, disrespectful, or unreasonable language:\n",
    "The following text contains very hateful, aggressive, disrespectful language:\n",
    "The following text contains sexually explicit language:\n",
    "The following text contains a threat:\n",
    "The following text contains negative or hateful language targeting someone because of their identity:\n",
    "The following text contains swear words, curse words, or other obscene or profane language:\n",
]

# Short parenthesised attribute tags. debias() selects this list for any
# option other than 'Prefixes' (i.e. 'Keywords').
# NOTE(review): each tag keeps a trailing space, presumably so the wrapper can
# prepend it directly to the prompt -- confirm against GPT2Wrapper.
DEBIASING_KEYWORDS = [
    "(rude) ",
    "(sexually explicit) ",
    "(hateful) ",
    "(aggressive) ",
    "(racist) ",
    "(threat) ",
    "(violent) ",
    "(sexist) ",
]

def debias(prompt, model, use_prefix, max_length=50, num_beam=3):
    """
    Generate a self-debiased and a plain continuation of ``prompt``.

    Both continuations come from ``GPT2Wrapper.generate_self_debiasing`` with
    identical decoding settings. The only difference between the two calls is
    the list of debiasing prefixes: the selected list vs. an empty list.

    :param prompt: The text to be continued.
    :param model: Name of the GPT-2 checkpoint to load
        (e.g. 'gpt2', 'gpt2-medium'); converted with ``str``.
    :param use_prefix: 'Prefixes' selects DEBIASING_PREFIXES; any other value
        selects DEBIASING_KEYWORDS.
    :param max_length: Maximum length passed to the generator. Coerced to
        ``int`` because ``gr.Number`` passes a float unless its precision is 0.
    :param num_beam: Forwarded (after ``int`` coercion) as the ``num_beam``
        keyword of ``generate_self_debiasing``.
    :return: A tuple ``(debiased_text, biased_text)`` of generated strings.
    """
    # NOTE(review): a fresh wrapper (and model) is built on every call; for
    # the larger checkpoints consider caching it if latency matters.
    wrapper = GPT2Wrapper(model_name=str(model), use_cuda=False)
    debiasing_prefixes = DEBIASING_PREFIXES if use_prefix == 'Prefixes' else DEBIASING_KEYWORDS

    # Shared settings so the two outputs differ only in the debiasing prefixes.
    # NOTE(review): Hugging Face `generate` spells the beam count `num_beams`;
    # whether `num_beam` has any effect depends on how GPT2Wrapper forwards
    # extra keyword arguments -- confirm before relying on beam search.
    generation_kwargs = dict(
        min_length=20,
        max_length=int(max_length) if max_length is not None else None,
        num_beam=int(num_beam) if num_beam is not None else None,
        no_repeat_ngram_size=2,
    )

    output_text = wrapper.generate_self_debiasing(
        [prompt], debiasing_prefixes=debiasing_prefixes, **generation_kwargs
    )[0]

    # No prefixes: this is the un-debiased baseline shown as "Biased text".
    biased_text = wrapper.generate_self_debiasing(
        [prompt], debiasing_prefixes=[], **generation_kwargs
    )[0]
    return output_text, biased_text


# Gradio UI. The five inputs map positionally onto
# debias(prompt, model, use_prefix, max_length, num_beam); the two outputs
# receive debias()'s (debiased_text, biased_text) tuple.
demo = gr.Interface(
    debias,
    inputs=[
        gr.Textbox(label='Prompt'),
        gr.Radio(choices=['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'], value='gpt2', label='Model'),
        gr.Radio(choices=['Prefixes', 'Keywords'], value='Prefixes', label='Use Debiasing Prefixes or Keywords'),
        # precision=0 makes gradio pass ints (not floats) for these counts.
        gr.Number(value=50, precision=0, label='Max output length'),
        gr.Number(value=3, precision=0, label='Number of beams for beam search'),
    ],
    outputs=[gr.Textbox(label="Debiased text"), gr.Textbox(label="Biased text")],
)

if __name__ == '__main__':
    demo.launch()