import time
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

@lru_cache(maxsize=1)  # only cache the latest model
def get_model_and_tokenizer(model_id):
    config = AutoConfig.from_pretrained(model_id)
    if config.is_encoder_decoder:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer
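
# A quick sanity check of the dispatch above (a sketch; "t5-small" is just an
# illustrative encoder-decoder repo, "facebook/opt-125m" a decoder-only one):
#
#   model, _ = get_model_and_tokenizer("t5-small")           # -> AutoModelForSeq2SeqLM
#   model, _ = get_model_and_tokenizer("facebook/opt-125m")  # -> AutoModelForCausalLM, evicts T5 (maxsize=1)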

@lru_cache(maxsize=32768)  # cache up to 32k examples
def run_generation(
    text,
    model_id,
    max_new_tokens,
    alpha=0.0,
    top_k=0,
    num_beams=1,
    do_sample=False,
    top_p=0.0,
    seed=0
):
    model, tokenizer = get_model_and_tokenizer(model_id)
    inputs = tokenizer(text, return_tensors='pt')
    if seed:
        torch.manual_seed(seed)

    start = time.time_ns()
    contrastive_ids = model.generate(
        # from the tokenizer
        **inputs,
        # fixed arguments
        num_return_sequences=1,
        early_stopping=True,
        # variable arguments
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=num_beams,
        penalty_alpha=alpha or None,
        top_k=top_k or None,
        top_p=top_p or None,
    )
    end = time.time_ns()
    contrastive_time = (end - start) / 1e6

    contrastive_text = tokenizer.decode(contrastive_ids[0], skip_special_tokens=True)
    return contrastive_text, contrastive_time
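
# Usage sketch (values are illustrative, not part of the app). In transformers,
# passing penalty_alpha > 0 together with top_k > 1 to model.generate() is what
# activates contrastive search, so the alpha/top_k pair below selects it:
#
#   text, ms = run_generation("DeepMind Company is", "facebook/opt-125m", 50,
#                             alpha=0.6, top_k=4)
#   print(f"{ms:.1f} ms -> {text}")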

def generate_beam_search(text, model_id, max_new_tokens, alpha, k, num_beams):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    beam_search_text, beam_search_time = run_generation(text, model_id, max_new_tokens, num_beams=num_beams)
    return contrastive_text, contrastive_time, beam_search_text, beam_search_time


def generate_top_k(text, model_id, max_new_tokens, alpha, k, top_k, seed):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    top_k_text, top_k_time = run_generation(
        text, model_id, max_new_tokens, top_k=top_k, seed=seed, do_sample=True
    )
    return contrastive_text, contrastive_time, top_k_text, top_k_time


def generate_top_p(text, model_id, max_new_tokens, alpha, k, top_p, seed):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    top_p_text, top_p_time = run_generation(
        text, model_id, max_new_tokens, top_p=top_p, seed=seed, do_sample=True
    )
    return contrastive_text, contrastive_time, top_p_text, top_p_time
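
# Note on caching sampled outputs: run_generation reseeds torch's RNG when a
# nonzero seed is given, so a given (text, ..., seed) tuple maps to one
# deterministic sample -- which is what keeps the lru_cache above safe for
# do_sample=True calls (with seed=0, a previously cached sample is reused as-is).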
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # Contrastive Search Generation comparison

        Credits to the contrastive search generation [paper](https://arxiv.org/abs/2202.06417) authors, including
        @[pangpang666](https://huggingface.co/pangpang666) and @[GMFTBY](https://huggingface.co/GMFTBY). Check out the
        follow-up [work](https://arxiv.org/abs/2210.14140), which demonstrates the usefulness of the technique with
        off-the-shelf LLMs, as well as their [HF guest blog post](https://huggingface.co/blog/introducing-csearch).

        From the paper:

        "At each decoding step, the key ideas of contrastive search are (i) the generated output should be selected
        from the set of most probable candidates predicted by the model; and (ii) the generated output should be
        discriminative enough with respect to the previous context. In this way, the generated text can (i) better
        maintain the semantic coherence with respect to the prefix while (ii) avoiding model degeneration."

        🚨 Warnings: 🚨
        - Avoid using large models (> 1GB) in this demo. It will take a long time to load the model and generate text.
        - Too slow/long queue? Check our
        [colab](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/115_introducing_contrastive_search.ipynb)
        instead.
        """
    )
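
    # For reference, the selection rule quoted above can be sketched as follows
    # (following the paper; this app does not implement it directly --
    # transformers applies it internally when penalty_alpha and top_k are set):
    #
    #   x_t = argmax over v in V^(k) of
    #         (1 - alpha) * p(v | x_{<t}) - alpha * max_{j<t} cos_sim(h_v, h_{x_j})
    #
    # where V^(k) holds the model's top-k candidates, h are hidden states, and
    # alpha trades model confidence against the degeneration penalty.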
    with gr.Tabs():
        with gr.TabItem("vs. Beam Search"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Inputs ✍️")
                    gr.Markdown("General options:")
                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
                    gr.Markdown("Contrastive Search options:")
                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
                    gr.Markdown("Beam Search options:")
                    num_beams = gr.Slider(value=4, minimum=1, maximum=16, step=1, label="Number of beams")
                    generate_button = gr.Button(value="Generate", label="Generate")
                with gr.Column():
                    gr.Markdown("## Outputs 🤖")
                    gr.Markdown("Contrastive Search generation:")
                    text_contrastive = gr.Textbox(value="", label="")
                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
                    gr.Markdown("Beam Search generation:")
                    text_beam_search = gr.Textbox(value="", label="")
                    time_beam_search = gr.Number(value=0.0, precision=1, label="Generation time (ms)")

            # actions
            generate_button.click(
                fn=generate_beam_search,
                inputs=[input_text, model_id, max_new_tokens, alpha, k, num_beams],
                outputs=[text_contrastive, time_contrastive, text_beam_search, time_beam_search]
            )
with gr.TabItem("vs. Top K Sampling"):
with gr.Row():
with gr.Column():
gr.Markdown("## Inputs βοΈ")
gr.Markdown("General options:")
model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
gr.Markdown("Contrastive Search options:")
alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
gr.Markdown("Sampling options:")
top_k = gr.Slider(value=50, minimum=1, maximum=100, step=1, label="Top K")
seed = gr.Number(value=42, precision=0, label="Seed")
generate_button = gr.Button(value="Generate", label="Generate")
with gr.Column():
gr.Markdown("## Outputs π€")
gr.Markdown("Contrastive Search generation:")
text_contrastive = gr.Textbox(value="", label="")
time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
gr.Markdown("Top K Sampling generation:")
text_top_k = gr.Textbox(value="", label="")
time_top_k = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
# actions
generate_button.click(
fn=generate_top_k,
inputs=[input_text, model_id, max_new_tokens, alpha, k, top_k, seed],
outputs=[text_contrastive, time_contrastive, text_top_k, time_top_k]
)
with gr.TabItem("vs. Nucleus Sampling"):
with gr.Row():
with gr.Column():
gr.Markdown("## Inputs βοΈ")
gr.Markdown("General options:")
model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
gr.Markdown("Contrastive Search options:")
alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
gr.Markdown("Sampling options:")
top_p = gr.Slider(value=0.95, minimum=0.01, maximum=1.0, step=0.01, label="Top P")
seed = gr.Number(value=42, precision=0, label="Seed")
generate_button = gr.Button(value="Generate", label="Generate")
with gr.Column():
gr.Markdown("## Outputs π€")
gr.Markdown("Contrastive Search generation:")
text_contrastive = gr.Textbox(value="", label="")
time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
gr.Markdown("Nucleus Sampling generation:")
text_top_p = gr.Textbox(value="", label="")
time_top_p = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
# actions
generate_button.click(
fn=generate_top_p,
inputs=[input_text, model_id, max_new_tokens, alpha, k, top_p, seed],
outputs=[text_contrastive, time_contrastive, text_top_p, time_top_p]
)
demo.launch()