# detoxified-lms / app.py
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
description = """# Detoxified Language Models
This is a Space where you can try out the effect of detoxifying GPT-Neo 2.7B with RLHF. Learn more about that [here]()
"""
preface_disclaimer = """
<h4> Disclaimer </h4>
<h5> Last meaningful update: 20.Feb.2023 </h5>
The core functionality of these models is to take a string of text and predict the next token.
Language models are known to have limitations, such as producing hateful content without warning. The goal of the approach presented in TODO is to reduce the "toxicity" of these models using RLHF (Reinforcement Learning from Human Feedback).
All in all, it is hard to predict how the models will respond to particular prompts; harmful or otherwise offensive content may occur without warning. This can include:
<ul>
<li> <b> Hateful </b>: content that expresses, incites, or promotes hate based on identity. </li>
<li> <b> Harassment </b>: content that intends to harass, threaten, or bully an individual. </li>
<li> <b> Violence </b>: content that promotes or glorifies violence or celebrates the suffering or humiliation of others. </li>
<li> <b> Self-harm </b>: content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. </li>
<li> <b> Adult </b>: content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). </li>
<li> <b> Political </b>: content attempting to influence the political process or to be used for campaigning purposes. </li>
<li> <b> Spam </b>: unsolicited bulk content. </li>
<li> <b> Deception </b>: content that is false or misleading, such as attempting to defraud individuals or spread disinformation. </li>
<li> <b> Malware </b>: content that attempts to generate ransomware, keyloggers, viruses, or other software intended to impose some level of harm. </li>
</ul>
Disclaimer inspired by <a href="https://huggingface.co/EleutherAI/gpt-j-6B" target="_blank"> GPT-J's model card </a> and <a href="https://beta.openai.com/docs/usage-guidelines/content-policy" target="_blank"> OpenAI GPT-3's content policy </a>.
"""
gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-sharded-bf16"
detoxified_gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-detox"
gpt_neo_1b = AutoModelForCausalLM.from_pretrained(gpt_neo_1b_id, torch_dtype=torch.bfloat16).to(0)
detoxified_neo_1b = AutoModelForCausalLM.from_pretrained(detoxified_gpt_neo_1b_id, torch_dtype=torch.bfloat16).to(0)
tokenizer = AutoTokenizer.from_pretrained(gpt_neo_1b_id)
def compare_generation(text, max_new_tokens, temperature, top_p, top_k):
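    # Top-p (nucleus) sampling and top-k filtering are treated as mutually exclusive here:
    # if a non-zero top_p is provided, top_k is disabled for both models.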
    if top_p > 0:
        top_k = 0
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(0)
    # Use identical sampling settings for both models so the outputs are comparable.
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, do_sample=True, early_stopping=True
    )
    text_neo_1b = tokenizer.decode(gpt_neo_1b.generate(input_ids, **gen_kwargs)[0])
    text_detoxified_1b = tokenizer.decode(detoxified_neo_1b.generate(input_ids, **gen_kwargs)[0])
    return text_neo_1b, text_detoxified_1b
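# Gradio UI: a prompt box plus sampling sliders as inputs; the outputs show the base and
# detoxified generations side by side.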
iface = gr.Interface(
    fn=compare_generation,
    inputs=[
        gr.Textbox(lines=5, label="Input text"),
        gr.Slider(
            minimum=8,
            maximum=1000,
            step=1,
            value=8,
            label="Number of tokens to generate",
        ),
        gr.Slider(
            minimum=0,
            maximum=2.5,
            step=0.1,
            value=0.6,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0,
            maximum=1,
            step=0.1,
            value=0,
            label="top_p",
        ),
        gr.Slider(
            minimum=0,
            maximum=50,
            step=1,
            value=0,
            label="top_k",
        ),
    ],
    outputs=[
        gr.Textbox(label="Predicted tokens - GPT-Neo 2.7B:", lines=5),
        gr.Textbox(label="Predicted detoxified tokens - GPT-Neo 2.7B:", lines=5),
    ],
    description=description,
    article=preface_disclaimer,  # show the content disclaimer below the demo
)
iface.launch()