import gradio as gr
import os
import onnxruntime as ort

from inference.onnx_inference import generate_text, sequence_breaker_strings
from inference.model import ByteTokenizer

# --- Globals ---
MODEL_OPTIONS = [
    ("DAT-Byte Small (200M)", "small", True),
    ("DAT-Byte Medium", "medium", False),
    ("DAT-Byte Large", "large", False),
]
ONNX_PATH = "models/small.onnx"  # Path to the exported ONNX model

# Cache for the ONNX session
SESSION_CACHE = {}
TOKENIZER = ByteTokenizer()

# Prepare sequence breakers for DRY sampling
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id}
for s in sequence_breaker_strings:
    # These are single-byte tokens, so encode will return a list with one ID
    try:
        SEQUENCE_BREAKER_IDS.add(TOKENIZER.encode(s.encode("utf-8"))[0])
    except IndexError:
        print(f"Warning: Could not encode sequence breaker string: {s}")


# --- Model Loading ---
def get_session(model_key):
    if model_key != "small":
        raise ValueError("Only DAT-Byte Small is available.")
    if model_key not in SESSION_CACHE:
        if not os.path.exists(ONNX_PATH):
            raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
        # Using CPUExecutionProvider as per the project's goal
        SESSION_CACHE[model_key] = ort.InferenceSession(
            ONNX_PATH, providers=["CPUExecutionProvider"]
        )
    return SESSION_CACHE[model_key]


# --- Gradio Callbacks ---
def chat_respond(
    message,
    history,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
    user_role="user",
    assistant_role="assistant",
):
    # Normalize history first: the error branches below append to it
    history = history or []
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"Model '{model_name}' is not available."}
        )
        return history
    try:
        session = get_session(model_key)
    except Exception as e:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"[Model loading error: {str(e)}]"}
        )
        return history

    # Build a ChatML-style prompt from the conversation so far
    prompt = ""
    for turn in history:
        prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
    prompt += (
        f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n"
    )

    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        stop_sequences=["<|im_end|>".encode("utf-8")],
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    generated_text = generated_text.decode("utf-8", "ignore")

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history


def completion_respond(
    prompt,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
):
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        return f"[Model '{model_name}' is not available or unknown.]"
    try:
        session = get_session(model_key)
    except Exception as e:
        return f"[Model loading error: {str(e)}]"

    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    # generate_text returns raw bytes (the chat path decodes them too);
    # decode before handing the string back to Gradio
    return generated_text.decode("utf-8", "ignore")
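
# A minimal smoke test for the completion path, assuming models/small.onnx
# exists; the prompt and sampling values below are illustrative only:
#
#   print(completion_respond(
#       "Once upon a time", "DAT-Byte Small (200M)",
#       max_tokens=64, temperature=0.5, top_k=15,
#       dry_range=1024, dry_allowed_length=20, dry_base=2.0, dry_multiplier=0.0,
#   ))
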
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)")

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Radio(
                [opt[0] for opt in MODEL_OPTIONS],
                value=MODEL_OPTIONS[0][0],
                label="Model",
                interactive=True,
            )
            gr.Markdown("**Note:** Only DAT-Byte Small is currently available.")
            mode_selector = gr.Radio(
                ["Chat", "Raw Completion"], value="Chat", label="Mode"
            )
            max_tokens = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature"
            )
            top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k")
            with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False):
                dry_range = gr.Slider(
                    minimum=0, maximum=2048, value=1024, step=32, label="Range"
                )
                dry_allowed_length = gr.Slider(
                    minimum=1, maximum=64, value=20, step=1, label="Allowed Length"
                )
                dry_base = gr.Slider(
                    minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base"
                )
                dry_multiplier = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier"
                )
            user_role_box = gr.Textbox("user", label="User Role", visible=True)
            assistant_role_box = gr.Textbox(
                "assistant", label="Assistant Role", visible=True
            )

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=600)
            # Capture the row so its visibility can be toggled directly,
            # instead of reaching for chat_input.parent or a dummy component
            with gr.Row() as chat_input_row:
                chat_input = gr.Textbox(
                    label="Message", placeholder="Type a message...", scale=4
                )
                send_button = gr.Button("Send", scale=1)
            completion_input = gr.Textbox(label="Prompt", visible=False)
            completion_output = gr.Textbox(label="Completion", visible=False)

    # UI Logic: toggle between the Chat and Raw Completion layouts
    def update_mode(mode):
        is_chat = mode == "Chat"
        return (
            gr.update(visible=is_chat),      # chatbot
            gr.update(visible=is_chat),      # chat_input_row
            gr.update(visible=not is_chat),  # completion_input
            gr.update(visible=not is_chat),  # completion_output
            gr.update(visible=is_chat),      # user_role_box
            gr.update(visible=is_chat),      # assistant_role_box
        )

    mode_selector.change(
        update_mode,
        [mode_selector],
        [
            chatbot,
            chat_input_row,
            completion_input,
            completion_output,
            user_role_box,
            assistant_role_box,
        ],
    )

    # Event Handlers
    chat_inputs = [
        chat_input,
        chatbot,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
        user_role_box,
        assistant_role_box,
    ]
    chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]}

    def clear_input():
        return ""

    clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]}

    send_button.click(**chat_args).then(**clear_args)
    chat_input.submit(**chat_args).then(**clear_args)

    completion_inputs = [
        completion_input,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
    ]
    completion_input.submit(
        completion_respond,
        completion_inputs,
        [completion_output],
    )

if __name__ == "__main__":
    demo.launch()