Spaces:
Runtime error
Runtime error
import gradio as gr | |
from load_model import device, load_gpt2, load_unprejudiced, tokenizer | |
# Maps each model identifier to the label shown above its output box.
# Insertion order matters: the loading loop and the output textboxes both
# iterate this dict, so the i-th model lines up with the i-th textbox.
model_names = {
    "gpt2": "Original GPT-2",
    "unprejudiced_ln": "LN",
    "unprejudiced_ln_wpe": "LN + WPE",
    "unprejudiced_ln_wpe_wte": "LN + WPE + WTE",
    "unprejudiced_ln_wpe_wte_io": "LN + WPE + WTE + IO",
    "unprejudiced_full": "FULL",
}
# Load every model once at import time, in the same order as `model_names`.
# The 'gpt2' key uses its own loader; every other key is passed to
# load_unprejudiced().
models = [
    load_gpt2() if name == 'gpt2' else load_unprejudiced(name)
    for name in model_names
]
def generate(prompt):
    """Sample a continuation of ``prompt`` from every loaded model.

    Args:
        prompt: The text entered in the demo's prompt box.

    Returns:
        A tuple of decoded strings, one per sequence returned by each model,
        in the same order as ``models``.
    """
    encoded_prompt = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt").to(device)

    # An empty prompt encodes to a zero-length tensor. Following the Hugging
    # Face text-generation example, pass None in that case so generate()
    # starts from the model's default start token instead of empty input.
    input_ids = encoded_prompt if encoded_prompt.size()[-1] > 0 else None

    output = []
    for model in models:
        output_sequences = model.generate(
            input_ids=input_ids, do_sample=True, max_length=100)
        # Collapse any extra dimension so iteration yields one token
        # sequence at a time.
        if len(output_sequences.shape) > 2:
            output_sequences.squeeze_()
        output.extend(
            tokenizer.decode(sequence, clean_up_tokenization_spaces=True)
            for sequence in output_sequences
        )
    return tuple(output)
# Build the web UI: one prompt box in, one output box per entry of
# `model_names` (labelled with that entry's display name), then start serving.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label='Prompt'),
    outputs=[gr.Textbox(label=label) for label in model_names.values()],
    examples=[
        'Classmates of the schoolgirl found her to be',
        'That award-winning chess player is',
        'I think Bangladesh is a',
        'The Bibles I looked at were very',
    ],
    title='Debiasing LMs GPT-2 Demo',
    # Adjacent literals are concatenated into a single description string.
    description=(
        'Official demo for _Debiasing Pre-Trained Language Models via '
        'Efficient Fine-Tuning_ published in the [Second Workshop on Language '
        'Technology for Equality, Diversity, Inclusion]'
        '(https://sites.google.com/view/lt-edi-2022) at ACL 2022. '
        '[View the code here.](https://github.com/michaelgira23/debiasing-lms)'
        '<br />WARNING: MODEL OUTPUTS MAY CONTAIN SENSITIVE MATERIAL.'
    ),
)
demo.launch()