import spaces
import gradio as gr
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Registers the open_lm model/config classes with the HF Auto* factories.
from open_lm.hf import *
from open_lm.precision import get_autocast

# Model options shown in the dropdown; "[IT]" marks instruction-tuned checkpoints.
MODEL_OPTIONS = {
    "TRI DCLM-1B": "TRI-ML/DCLM-1B",
    "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B",
    "[IT] TRI DCLM-1B": "TRI-ML/DCLM-1B-IT",
    "[IT] Apple DCLM-Baseline-7B": "mlfoundations/dclm-7b-it",
}

# Globals holding the currently loaded model and tokenizer.
current_model = None
current_tokenizer = None


def load_model(model_name):
    global current_model, current_tokenizer
    current_tokenizer = AutoTokenizer.from_pretrained(MODEL_OPTIONS[model_name])
    current_model = AutoModelForCausalLM.from_pretrained(MODEL_OPTIONS[model_name])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    current_model = current_model.to(device)
    return f"Loaded model: {model_name}"


@spaces.GPU
def generate_completion(
    prompt,
    model_choice,  # unused here; the model is loaded globally by load_model()
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream a text completion for the given prompt."""
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        # This function is a generator, so the message must be yielded to be shown.
        yield "Please select a model first."
        return

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=current_tokenizer.eos_token_id,
    )
    streamer = TextIteratorStreamer(
        current_tokenizer, skip_prompt=True, skip_special_tokens=False
    )
    # Decoded EOS token, used to detect where the streamed text should stop.
    streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
    generate_kwargs["streamer"] = streamer

    autocast = get_autocast("amp_bf16")

    def run_generation():
        # Autocast state is thread-local, so enter the context inside the
        # worker thread that actually runs generation.
        with autocast():
            current_model.generate(**generate_kwargs)

    thread = Thread(target=run_generation)
    thread.start()

    output = prompt
    for new_text in streamer:
        if isinstance(new_text, torch.Tensor):
            new_text = current_tokenizer.decode(new_text)
        if streamer.stop_signal in new_text:
            output += new_text.split(streamer.stop_signal)[0]
            break
        output += new_text
        yield output
    thread.join()
    # Yield once more so any text produced just before the stop signal is shown.
    yield output


def format_prompt(message, history):
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt


@spaces.GPU
def generate_chat(
    message,
    chat_history,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream a chatbot response for the instruction-tuned models."""
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        yield chat_history + [("Error", "Please select a model first.")]
        return

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    formatted_prompt = format_prompt(message, chat_history)
    inputs = current_tokenizer(formatted_prompt, return_tensors="pt").to(current_model.device)

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=current_tokenizer.eos_token_id,
    )
    streamer = TextIteratorStreamer(
        current_tokenizer, skip_prompt=True, skip_special_tokens=False
    )
    streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
    generate_kwargs["streamer"] = streamer

    thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
    thread.start()

    new_history = chat_history + [(message, "")]
    for new_text in streamer:
        if isinstance(new_text, torch.Tensor):
            new_text = current_tokenizer.decode(new_text)
        if streamer.stop_signal in new_text:
            new_text = new_text.split(streamer.stop_signal)[0]
            new_history[-1] = (message, new_history[-1][1] + new_text)
            break
        new_history[-1] = (message, new_history[-1][1] + new_text)
        yield new_history
    thread.join()
    # Yield once more so text appended just before the stop signal is shown.
    yield new_history


additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # DCLM Demo
        This demo generates text with DCLM models in two modes:
        1. Text Completion: for non-instruction-tuned models, it continues the input text.
        2. Chatbot: for instruction-tuned [IT] models, it responds to user messages as a chatbot.

        Select a model from the dropdown to start; it may take a few seconds to load.
        The interface switches between Text Completion and Chatbot modes automatically based on the selected model.
        """
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
        model_status = gr.Textbox(label="Model Status")

    # Text Completion interface (shown for base models)
    with gr.Row(visible=False) as completion_interface:
        with gr.Column():
            text_input = gr.Textbox(lines=3, label="Input Text")
            text_output = gr.Markdown(label="Generated Text")
            generate_button = gr.Button("Generate")

    # Chatbot interface (shown for instruction-tuned models)
    with gr.Row(visible=False) as chat_interface:
        with gr.Column():
            chatbot = gr.Chatbot(
                show_label=False,
                show_share_button=False,
                show_copy_button=True,
                likeable=True,
                layout="panel",
            )
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")

    with gr.Accordion("Advanced Options", open=False):
        for input_component in additional_inputs:
            input_component.render()

    def switch_interface(model_name):
        is_it_model = model_name.startswith("[IT]")
        status = load_model(model_name)
        return (
            gr.Row(visible=not is_it_model),  # completion_interface
            gr.Row(visible=is_it_model),      # chat_interface
            status,                           # model_status
        )

    model_dropdown.change(
        switch_interface,
        inputs=[model_dropdown],
        outputs=[completion_interface, chat_interface, model_status],
    )

    generate_button.click(
        generate_completion,
        inputs=[text_input, model_dropdown, *additional_inputs],
        outputs=[text_output],
    )

    msg.submit(generate_chat, [msg, chatbot, *additional_inputs], chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()