# DAT-Byte-Demo / app.py
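# Gradio playground for the DAT-Byte byte-level language models. Inference
# runs through ONNX Runtime on CPU; the UI offers a ChatML-style chat mode
# and a raw completion mode with optional DRY sampling.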
import gradio as gr
import os
import onnxruntime as ort
from inference.onnx_inference import generate_text, sequence_breaker_strings
from inference.model import ByteTokenizer
# --- Globals ---
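# Each entry is (display name, model key, available); only the small model is wired up.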
MODEL_OPTIONS = [
("DAT-Byte Small (200M)", "small", True),
("DAT-Byte Medium", "medium", False),
("DAT-Byte Large", "large", False),
]
ONNX_PATH = "models/small.onnx"  # ONNX export of the small model, relative to the app root
# Cache for the ONNX session
SESSION_CACHE = {}
TOKENIZER = ByteTokenizer()
# Prepare sequence breakers
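# DRY treats these tokens as boundaries: repetition matching does not extend across them.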
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id}
for s in sequence_breaker_strings:
# These are single-byte tokens, so encode will return a list with one ID
try:
SEQUENCE_BREAKER_IDS.add(TOKENIZER.encode(s.encode("utf-8"))[0])
except IndexError:
print(f"Warning: Could not encode sequence breaker string: {s}")
# --- Model Loading ---
def get_session(model_key):
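    """Lazily create and cache an ONNX Runtime session for the given model key."""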
if model_key != "small":
raise ValueError("Only DAT-Byte Small is available.")
if model_key not in SESSION_CACHE:
if not os.path.exists(ONNX_PATH):
raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
# Using CPUExecutionProvider as per the project's goal
SESSION_CACHE[model_key] = ort.InferenceSession(
ONNX_PATH, providers=["CPUExecutionProvider"]
)
return SESSION_CACHE[model_key]
# --- Gradio Callbacks ---
def chat_respond(
message,
history,
model_name,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
user_role="user",
assistant_role="assistant",
):
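    """Run one chat turn: rebuild the conversation as a prompt, generate a
    reply, and return the updated message history."""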
    history = history or []
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"Model '{model_name}' is not available."}
        )
        return history
try:
session = get_session(model_key)
except Exception as e:
history.append({"role": "user", "content": message})
history.append(
{"role": "assistant", "content": f"[Model loading error: {str(e)}]"}
)
return history
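    # Serialize the conversation in ChatML-style markup
    # (<|im_start|>role\ncontent<|im_end|>), ending with an open assistant turn.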
prompt = ""
for turn in history:
prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
prompt += (
f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n"
)
generated_text, _ = generate_text(
session=session,
tokenizer=TOKENIZER,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
        stop_sequences=[b"<|im_end|>"],
dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
dry_range=dry_range,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_multiplier=dry_multiplier,
)
generated_text = generated_text.decode("utf-8", "ignore")
history.append({"role": "user", "content": message})
history.append({"role": "assistant", "content": generated_text})
return history
def completion_respond(
prompt,
model_name,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
):
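    """Run a raw completion of the prompt with the selected model and sampling settings."""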
model_key = next(
(key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
None,
)
if not model_key:
return f"[Model '{model_name}' is not available or unknown.]"
try:
session = get_session(model_key)
except Exception as e:
return f"[Model loading error: {str(e)}]"
generated_text, _ = generate_text(
session=session,
tokenizer=TOKENIZER,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
dry_range=dry_range,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_multiplier=dry_multiplier,
)
    # generate_text returns raw bytes; decode before handing the text back to the UI.
    return generated_text.decode("utf-8", "ignore")
# --- Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)")
with gr.Row():
with gr.Column(scale=1):
model_selector = gr.Radio(
[opt[0] for opt in MODEL_OPTIONS],
value=MODEL_OPTIONS[0][0],
label="Model",
interactive=True,
)
gr.Markdown("**Note:** Only DAT-Byte Small is currently available.")
mode_selector = gr.Radio(
["Chat", "Raw Completion"], value="Chat", label="Mode"
)
max_tokens = gr.Slider(
minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
temperature = gr.Slider(
minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature"
)
top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k")
with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False):
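                # DRY penalizes extending a repeat: within the last `Range` tokens,
                # a repeated suffix longer than `Allowed Length` is penalized by
                # roughly Multiplier * Base^(match length - allowed length).
                # A Multiplier of 0 disables the penalty entirely.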
dry_range = gr.Slider(
minimum=0, maximum=2048, value=1024, step=32, label="Range"
)
dry_allowed_length = gr.Slider(
minimum=1, maximum=64, value=20, step=1, label="Allowed Length"
)
dry_base = gr.Slider(
minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base"
)
dry_multiplier = gr.Slider(
minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier"
)
user_role_box = gr.Textbox("user", label="User Role", visible=True)
assistant_role_box = gr.Textbox(
"assistant", label="Assistant Role", visible=True
)
with gr.Column(scale=3):
chatbot = gr.Chatbot(label="Chat", type="messages", height=600)
            with gr.Row() as chat_input_row:
chat_input = gr.Textbox(
label="Message", placeholder="Type a message...", scale=4
)
send_button = gr.Button("Send", scale=1)
completion_input = gr.Textbox(label="Prompt", visible=False)
completion_output = gr.Textbox(label="Completion", visible=False)
    # UI Logic: toggle the chat and raw-completion widgets from a single handler.
    # Naming the chat input Row above lets its visibility be updated directly.
    def update_mode(mode):
        is_chat = mode == "Chat"
        return (
            gr.update(visible=is_chat),  # chatbot
            gr.update(visible=is_chat),  # chat input row
            gr.update(visible=not is_chat),  # completion_input
            gr.update(visible=not is_chat),  # completion_output
            gr.update(visible=is_chat),  # user_role_box
            gr.update(visible=is_chat),  # assistant_role_box
        )
    mode_selector.change(
        update_mode,
        [mode_selector],
        [
            chatbot,
            chat_input_row,
            completion_input,
            completion_output,
            user_role_box,
            assistant_role_box,
        ],
    )
# Event Handlers
chat_inputs = [
chat_input,
chatbot,
model_selector,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
user_role_box,
assistant_role_box,
]
chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]}
def clear_input():
return ""
clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]}
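    # Generate first, then clear the message box via .then() so the user can type the next turn.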
send_button.click(**chat_args).then(**clear_args)
chat_input.submit(**chat_args).then(**clear_args)
completion_inputs = [
completion_input,
model_selector,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
]
completion_input.submit(
completion_respond,
completion_inputs,
[completion_output],
)
if __name__ == "__main__":
demo.launch()