# debiasing-lms / app.py
# Author: Michael Gira
# Initialize demo (commit 6f82d3b)
import gradio as gr
from load_model import device, load_gpt2, load_unprejudiced, tokenizer
# Maps each checkpoint name to the human-readable label shown in the demo UI.
# 'gpt2' is the unmodified baseline; the rest are fine-tuned variants named by
# which parameter groups were tuned (LN = layer norm, WPE/WTE = position/token
# embeddings, IO = input/output layers, FULL = all parameters).
model_names = {
    'gpt2': 'Original GPT-2',
    'unprejudiced_ln': 'LN',
    'unprejudiced_ln_wpe': 'LN + WPE',
    'unprejudiced_ln_wpe_wte': 'LN + WPE + WTE',
    'unprejudiced_ln_wpe_wte_io': 'LN + WPE + WTE + IO',
    'unprejudiced_full': 'FULL',
}

# Load every model up front, in the same order as model_names, so the i-th
# model lines up with the i-th output textbox in the interface below.
models = [
    load_gpt2() if name == 'gpt2' else load_unprejudiced(name)
    for name in model_names
]
def generate(prompt):
    """Sample a continuation of `prompt` from every loaded model.

    Returns a tuple of decoded strings, one per model (in the order of
    `models`), so Gradio can route each string to its own output textbox.
    Sampling (`do_sample=True`) makes the output non-deterministic.
    """
    input_ids = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt").to(device)

    completions = []
    for model in models:
        sequences = model.generate(
            input_ids=input_ids, do_sample=True, max_length=100)
        # Drop any extra leading dimension so we can iterate over sequences.
        if len(sequences.shape) > 2:
            sequences.squeeze_()
        completions.extend(
            tokenizer.decode(seq, clean_up_tokenization_spaces=True)
            for seq in sequences
        )
    return tuple(completions)
# Build the demo UI: one prompt input, one output textbox per model variant.
# Iterate model_names.values() directly instead of re-indexing the dict by key.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label='Prompt'),
    outputs=[gr.Textbox(label=label) for label in model_names.values()],
    examples=[
        'Classmates of the schoolgirl found her to be',
        'That award-winning chess player is',
        'I think Bangladesh is a',
        'The Bibles I looked at were very'
    ],
    title='Debiasing LMs GPT-2 Demo',
    description='Official demo for _Debiasing Pre-Trained Language Models via Efficient Fine-Tuning_ published in the [Second Workshop on Language Technology for Equality, Diversity, Inclusion](https://sites.google.com/view/lt-edi-2022) at ACL 2022. [View the code here.](https://github.com/michaelgira23/debiasing-lms)<br />WARNING: MODEL OUTPUTS MAY CONTAIN SENSITIVE MATERIAL.'
)

# Guard the launch so importing this module (e.g. for testing) does not start
# a web server as a side effect; running the script directly still launches it.
if __name__ == '__main__':
    demo.launch()