from threading import Thread
import gradio as gr
from gradio.layouts import Accordion
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
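# Importing open_lm.hf registers the OpenLM architecture with transformers'
# Auto classes, which is what lets AutoModelForCausalLM load the DCLM checkpoints.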
from open_lm.hf import *
from open_lm.precision import get_autocast

# Define model options
MODEL_OPTIONS = {
    "TRI-ML/DCLM-1B": "TRI-ML/DCLM-1B",
    "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B",
}

# Global variables for model and tokenizer
current_model = None
current_tokenizer = None


def load_model(model_name):
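    """Load the selected checkpoint and move it to GPU when one is available."""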
    global current_model, current_tokenizer
    current_tokenizer = AutoTokenizer.from_pretrained(MODEL_OPTIONS[model_name])
    current_model = AutoModelForCausalLM.from_pretrained(MODEL_OPTIONS[model_name])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    current_model = current_model.to(device)
    return f"Loaded model: {model_name}"
@spaces.GPU
def generate(
    prompt, model_choice, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
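    """Stream a completion of `prompt` from the currently loaded model."""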
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        return "Please load a model first."

    # Clamp temperature away from zero to keep sampling numerically stable
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    # Run generation under bfloat16 autocast (open_lm's "amp_bf16" mode)
    autocast = get_autocast("amp_bf16")
    with autocast():
        generate_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=current_tokenizer.eos_token_id,
        )
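
        # TextIteratorStreamer lets model.generate() run in a background thread
        # while this function iterates over the decoded tokens as they arrive.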
        streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
        streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
        generate_kwargs["streamer"] = streamer
        thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
        thread.start()
        # Echo the prompt in blue, then append streamed tokens as they arrive
        output = "<span style='color: blue;'>" + prompt + "</span>"
        for new_text in streamer:
            if isinstance(new_text, torch.Tensor):
                new_text = current_tokenizer.decode(new_text)
            if streamer.stop_signal in new_text:
                # Keep only the text before the end-of-sequence token, and
                # yield once more so the final chunk is not dropped
                output += new_text.split(streamer.stop_signal)[0]
                yield output
                break
            output += new_text
            yield output
        thread.join()
    return output
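

# Sliders exposed under "Advanced Options"; passed to generate() in order
# after the prompt and model choice.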
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # DCLM Text Completion Demo

        This demo generates text completions with a DCLM model. These models are
        trained to predict the next word in a sequence of text; they are not chatbots.

        First select a model from the dropdown to load it. Then enter some text
        in the text box and click "Generate" to see the model's completion.
        """
    )
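    # Selecting an entry in the dropdown loads that model immediately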
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
        model_dropdown.select(
            load_model,
            inputs=[model_dropdown],
            outputs=[gr.Textbox(label="Model Status")],
        )
    text_input = gr.Textbox(lines=3, label="Input Text")
    text_output = gr.Markdown(label="Generated Text")
    generate_button = gr.Button("Generate")
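    # generate() is a generator, so the Markdown output refreshes as tokens stream in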
    generate_button.click(
        generate,
        inputs=[text_input, model_dropdown, *additional_inputs],
        outputs=[text_output],
    )
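    # Advanced sampling controls, collapsed by default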
    with Accordion(label="Advanced Options", open=False):
        for input_component in additional_inputs:
            if not input_component.is_rendered:
                input_component.render()

demo.launch()