File size: 1,902 Bytes
6f82d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import torch

from load_model import device, load_gpt2, load_unprejudiced, tokenizer

# Maps each checkpoint identifier to the human-readable label shown in the
# demo UI; insertion order also fixes the order of the output textboxes.
model_names = dict(
    gpt2='Original GPT-2',
    unprejudiced_ln='LN',
    unprejudiced_ln_wpe='LN + WPE',
    unprejudiced_ln_wpe_wte='LN + WPE + WTE',
    unprejudiced_ln_wpe_wte_io='LN + WPE + WTE + IO',
    unprejudiced_full='FULL',
)

# Instantiate one model per entry in model_names, preserving its order.
# 'gpt2' is the unmodified baseline; every other key names a debiased
# fine-tuned variant loaded by its identifier.
models = [
    load_gpt2() if name == 'gpt2' else load_unprejudiced(name)
    for name in model_names
]


def generate(prompt):
    """Generate one sampled continuation of *prompt* per loaded model.

    Args:
        prompt: User-supplied text used as the generation context.

    Returns:
        A tuple of decoded strings, one per entry in ``models`` (same order
        as ``model_names``); each string contains the prompt followed by up
        to 100 sampled tokens.
    """
    encoded_prompt = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt").to(device)

    output = []
    # Inference only: disabling gradient tracking avoids building autograd
    # graphs for every generated token, cutting memory use and latency.
    with torch.no_grad():
        for model in models:
            output_sequences = model.generate(
                input_ids=encoded_prompt, do_sample=True, max_length=100)
            # generate() is expected to return (batch, seq_len); collapse any
            # extra leading dimensions so iteration yields 1-D token tensors.
            if len(output_sequences.shape) > 2:
                output_sequences.squeeze_()

            for output_sequence in output_sequences:
                output.append(tokenizer.decode(
                    output_sequence, clean_up_tokenization_spaces=True))

    return tuple(output)


# One output textbox per model, labeled with its human-readable name and
# ordered to match the tuple returned by generate().
output_boxes = [gr.Textbox(label=label) for label in model_names.values()]

# Canned prompts shown as clickable examples in the UI.
example_prompts = [
    'Classmates of the schoolgirl found her to be',
    'That award-winning chess player is',
    'I think Bangladesh is a',
    'The Bibles I looked at were very'
]

demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label='Prompt'),
    outputs=output_boxes,
    examples=example_prompts,
    title='Debiasing LMs GPT-2 Demo',
    description='Official demo for _Debiasing Pre-Trained Language Models via Efficient Fine-Tuning_ published in the [Second Workshop on Language Technology for Equality, Diversity, Inclusion](https://sites.google.com/view/lt-edi-2022) at ACL 2022. [View the code here.](https://github.com/michaelgira23/debiasing-lms)<br />WARNING: MODEL OUTPUTS MAY CONTAIN SENSITIVE MATERIAL.'
)

demo.launch()