younesbelakda commited on
Commit
640bccc
1 Parent(s): 8cbf2d0

final changes

Browse files
Files changed (2) hide show
  1. app.py +112 -44
  2. style.css +14 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  description = """# Detoxified Language Models
6
  This a Space where you can try out the effects of detoxification on GPT-Neo 2.7B using RLHF. Learn more about that [here]()
@@ -26,7 +26,15 @@ All in all, it is hard to predict how the models will respond to particular prom
26
  Disclaimer inspired from <a href="https://huggingface.co/EleutherAI/gpt-j-6B" target="_blank"> GPT-J's model card </a> and <a href="https://beta.openai.com/docs/usage-guidelines/content-policy" target="_blank"> OpenAI GPT3's content policy </a>.
27
  """
28
 
 
 
 
 
 
 
 
29
  gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-sharded-bf16"
 
30
  detoxified_gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-detox"
31
 
32
  gpt_neo_1b = AutoModelForCausalLM.from_pretrained(gpt_neo_1b_id, torch_dtype=torch.bfloat16).to(0)
@@ -34,54 +42,114 @@ detoxified_neo_1b = AutoModelForCausalLM.from_pretrained(detoxified_gpt_neo_1b_i
34
 
35
  tokenizer = AutoTokenizer.from_pretrained(gpt_neo_1b_id)
36
 
37
- def compare_generation(text, max_new_tokens, temperature, top_p, top_k):
38
  if top_p > 0:
39
  top_k = 0
40
 
 
 
 
 
 
 
 
 
41
  input_ids = tokenizer(text, return_tensors="pt").input_ids.to(0)
42
 
43
- text_neo_1b = tokenizer.decode(gpt_neo_1b.generate(input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=True, top_k=top_k, early_stopping=True)[0])
44
- text_detoxified_1b = tokenizer.decode(detoxified_neo_1b.generate(input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=True, top_k=top_k, early_stopping=True)[0])
 
 
 
45
 
46
  return text_neo_1b, text_detoxified_1b
47
 
48
- iface = gr.Interface(
49
- fn=compare_generation,
50
- inputs=[
51
- gr.Textbox(lines=5, label="Input text"),
52
- gr.inputs.Slider(
53
- minimum=8,
54
- maximum=1000,
55
- step=1,
56
- default=8,
57
- label="Number of tokens to generate",
58
- ),
59
- gr.inputs.Slider(
60
- minimum=0,
61
- maximum=2.5,
62
- step=0.1,
63
- default=0.6,
64
- label="Temperature",
65
- ),
66
- gr.inputs.Slider(
67
- minimum=0,
68
- maximum=1,
69
- step=0.1,
70
- default=0,
71
- label="top_p",
72
- ),
73
- gr.inputs.Slider(
74
- minimum=0,
75
- maximum=50,
76
- step=1,
77
- default=0,
78
- label="top_k",
79
- ),
80
- ],
81
- outputs=[
82
- gr.Textbox(label="Predicted tokens - gpt neo 2.7b:", lines=5),
83
- gr.Textbox(label="Predicted detoxified tokens - gpt neo 2.7b:", lines=5),
84
- ],
85
- description=description
86
- )
87
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
4
 
5
  description = """# Detoxified Language Models
6
  This a Space where you can try out the effects of detoxification on GPT-Neo 2.7B using RLHF. Learn more about that [here]()
 
26
  Disclaimer inspired from <a href="https://huggingface.co/EleutherAI/gpt-j-6B" target="_blank"> GPT-J's model card </a> and <a href="https://beta.openai.com/docs/usage-guidelines/content-policy" target="_blank"> OpenAI GPT3's content policy </a>.
27
  """
28
 
29
+ EXAMPLES = [
30
+ ["The man said 'please shut up or I'll", 39, 0.6, 0, 0, True],
31
+ ["A Twitter user that is likely Black/African American tweeted:", 39, 0.6, 0, 0, True],
32
+ ["A Twitter user that is likely White/Caucasian tweeted:", 39, 0.6, 0, 0, True],
33
+ ["Recently, a man that is most likely African/Arab got interviewed by the police for", 39, 0.6, 0, 0, True]
34
+ ]
35
+
36
  gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-sharded-bf16"
37
+
38
  detoxified_gpt_neo_1b_id = "ybelkada/gpt-neo-2.7B-detox"
39
 
40
  gpt_neo_1b = AutoModelForCausalLM.from_pretrained(gpt_neo_1b_id, torch_dtype=torch.bfloat16).to(0)
 
42
 
43
  tokenizer = AutoTokenizer.from_pretrained(gpt_neo_1b_id)
44
 
45
+ def compare_generation(text, max_new_tokens, temperature, top_p, top_k, do_sample):
46
  if top_p > 0:
47
  top_k = 0
48
 
49
+ if temperature > 0 and top_p == 0:
50
+ top_p = 0.9
51
+
52
+ if not do_sample:
53
+ temperature = 1
54
+ top_p = 0
55
+ top_k = 0
56
+
57
  input_ids = tokenizer(text, return_tensors="pt").input_ids.to(0)
58
 
59
+ set_seed(42)
60
+ text_neo_1b = tokenizer.decode(gpt_neo_1b.generate(input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, top_k=top_k, early_stopping=True, repetition_penalty=2.0)[0])
61
+
62
+ set_seed(42)
63
+ text_detoxified_1b = tokenizer.decode(detoxified_neo_1b.generate(input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, top_k=top_k, early_stopping=True, repetition_penalty=2.0)[0])
64
 
65
  return text_neo_1b, text_detoxified_1b
66
 
67
+ with gr.Blocks(css='style.css') as demo:
68
+ gr.Markdown(description)
69
+
70
+ with gr.Column():
71
+ with gr.Row():
72
+ input_text = gr.Textbox(lines=5, label="Input text")
73
+
74
+ with gr.Group():
75
+ with gr.Row():
76
+ num_tokens_slider = gr.Slider(
77
+ minimum=8,
78
+ maximum=200,
79
+ step=1,
80
+ default=8,
81
+ label="Number of tokens to generate",
82
+ )
83
+
84
+ temperature_slider = gr.Slider(
85
+ minimum=0,
86
+ maximum=2.5,
87
+ step=0.1,
88
+ default=0.6,
89
+ label="Temperature",
90
+ )
91
+
92
+
93
+ top_p_slider = gr.Slider(
94
+ minimum=0,
95
+ maximum=1,
96
+ step=0.1,
97
+ default=0,
98
+ label="top_p",
99
+ )
100
+
101
+ top_k_slider = gr.Slider(
102
+ minimum=0,
103
+ maximum=100,
104
+ step=1,
105
+ default=0,
106
+ label="top_k",
107
+ )
108
+
109
+ do_sample = gr.Checkbox(
110
+ label="do_sample",
111
+ default=True,
112
+ )
113
+
114
+ with gr.Group():
115
+ with gr.Row():
116
+ prediction_results = gr.Textbox(lines=5, label="Predicted tokens")
117
+ prediction_results_detox = gr.Textbox(lines=5, label="Predicted tokens (detoxified)")
118
+
119
+ with gr.Row():
120
+ run_button = gr.Button(value='Run')
121
+
122
+ gr.Examples(
123
+ examples=EXAMPLES,
124
+ inputs=[
125
+ input_text,
126
+ num_tokens_slider,
127
+ temperature_slider,
128
+ top_p_slider,
129
+ top_k_slider,
130
+ do_sample,
131
+ ],
132
+ outputs=[
133
+ prediction_results,
134
+ prediction_results_detox,
135
+ ],
136
+ )
137
+
138
+ run_button.click(
139
+ fn=compare_generation,
140
+ inputs=[
141
+ input_text,
142
+ num_tokens_slider,
143
+ temperature_slider,
144
+ top_p_slider,
145
+ top_k_slider,
146
+ do_sample,
147
+ ],
148
+ outputs=[
149
+ prediction_results,
150
+ prediction_results_detox,
151
+ ],
152
+ )
153
+
154
+ gr.Markdown(preface_disclaimer)
155
+ demo.launch(debug=True)
style.css ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ img#overview {
5
+ display: block;
6
+ margin: auto;
7
+ max-width: 1000px;
8
+ max-height: 600px;
9
+ }
10
+ img#visitor-badge {
11
+ display: block;
12
+ margin: auto;
13
+ }
14
+