Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import onnxruntime as ort | |
from inference.onnx_inference import generate_text, sequence_breaker_strings | |
from inference.model import ByteTokenizer | |
# --- Globals --- | |
MODEL_OPTIONS = [ | |
("DAT-Byte Small (200M)", "small", True), | |
("DAT-Byte Medium", "medium", False), | |
("DAT-Byte Large", "large", False), | |
] | |
ONNX_PATH = "models/small.onnx" # Assumes model.onnx is in the root directory | |
# Cache for the ONNX session | |
SESSION_CACHE = {} | |
TOKENIZER = ByteTokenizer() | |
# Prepare sequence breakers | |
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id} | |
for s in sequence_breaker_strings: | |
# These are single-byte tokens, so encode will return a list with one ID | |
try: | |
SEQUENCE_BREAKER_IDS.add(TOKENIZER.encode(s.encode("utf-8"))[0]) | |
except IndexError: | |
print(f"Warning: Could not encode sequence breaker string: {s}") | |
# --- Model Loading --- | |
def get_session(model_key): | |
if model_key != "small": | |
raise ValueError("Only DAT-Byte Small is available.") | |
if model_key not in SESSION_CACHE: | |
if not os.path.exists(ONNX_PATH): | |
raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}") | |
# Using CPUExecutionProvider as per the project's goal | |
SESSION_CACHE[model_key] = ort.InferenceSession( | |
ONNX_PATH, providers=["CPUExecutionProvider"] | |
) | |
return SESSION_CACHE[model_key] | |
# --- Gradio Callbacks --- | |
def chat_respond( | |
message, | |
history, | |
model_name, | |
max_tokens, | |
temperature, | |
top_k, | |
dry_range, | |
dry_allowed_length, | |
dry_base, | |
dry_multiplier, | |
user_role="user", | |
assistant_role="assistant", | |
): | |
model_key = next( | |
(key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled), | |
None, | |
) | |
if not model_key: | |
history.append({"role": "user", "content": message}) | |
history.append( | |
{"role": "assistant", "content": f"Model '{model_name}' is not available."} | |
) | |
return history | |
history = history or [] | |
try: | |
session = get_session(model_key) | |
except Exception as e: | |
history.append({"role": "user", "content": message}) | |
history.append( | |
{"role": "assistant", "content": f"[Model loading error: {str(e)}]"} | |
) | |
return history | |
prompt = "" | |
for turn in history: | |
prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n" | |
prompt += ( | |
f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n" | |
) | |
generated_text, _ = generate_text( | |
session=session, | |
tokenizer=TOKENIZER, | |
prompt=prompt, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
top_k=top_k, | |
stop_sequences=["<|im_end|>".encode("utf-8")], | |
dry_sequence_breakers=SEQUENCE_BREAKER_IDS, | |
dry_range=dry_range, | |
dry_allowed_length=dry_allowed_length, | |
dry_base=dry_base, | |
dry_multiplier=dry_multiplier, | |
) | |
generated_text = generated_text.decode("utf-8", "ignore") | |
history.append({"role": "user", "content": message}) | |
history.append({"role": "assistant", "content": generated_text}) | |
return history | |
def completion_respond( | |
prompt, | |
model_name, | |
max_tokens, | |
temperature, | |
top_k, | |
dry_range, | |
dry_allowed_length, | |
dry_base, | |
dry_multiplier, | |
): | |
model_key = next( | |
(key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled), | |
None, | |
) | |
if not model_key: | |
return f"[Model '{model_name}' is not available or unknown.]" | |
try: | |
session = get_session(model_key) | |
except Exception as e: | |
return f"[Model loading error: {str(e)}]" | |
generated_text, _ = generate_text( | |
session=session, | |
tokenizer=TOKENIZER, | |
prompt=prompt, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
top_k=top_k, | |
dry_sequence_breakers=SEQUENCE_BREAKER_IDS, | |
dry_range=dry_range, | |
dry_allowed_length=dry_allowed_length, | |
dry_base=dry_base, | |
dry_multiplier=dry_multiplier, | |
) | |
return generated_text | |
# --- Gradio UI --- | |
with gr.Blocks() as demo: | |
gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)") | |
with gr.Row(): | |
with gr.Column(scale=1): | |
model_selector = gr.Radio( | |
[opt[0] for opt in MODEL_OPTIONS], | |
value=MODEL_OPTIONS[0][0], | |
label="Model", | |
interactive=True, | |
) | |
gr.Markdown("**Note:** Only DAT-Byte Small is currently available.") | |
mode_selector = gr.Radio( | |
["Chat", "Raw Completion"], value="Chat", label="Mode" | |
) | |
max_tokens = gr.Slider( | |
minimum=1, maximum=2048, value=512, step=1, label="Max new tokens" | |
) | |
temperature = gr.Slider( | |
minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature" | |
) | |
top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k") | |
with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False): | |
dry_range = gr.Slider( | |
minimum=0, maximum=2048, value=1024, step=32, label="Range" | |
) | |
dry_allowed_length = gr.Slider( | |
minimum=1, maximum=64, value=20, step=1, label="Allowed Length" | |
) | |
dry_base = gr.Slider( | |
minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base" | |
) | |
dry_multiplier = gr.Slider( | |
minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier" | |
) | |
user_role_box = gr.Textbox("user", label="User Role", visible=True) | |
assistant_role_box = gr.Textbox( | |
"assistant", label="Assistant Role", visible=True | |
) | |
with gr.Column(scale=3): | |
chatbot = gr.Chatbot(label="Chat", type="messages", height=600) | |
with gr.Row(): | |
chat_input = gr.Textbox( | |
label="Message", placeholder="Type a message...", scale=4 | |
) | |
send_button = gr.Button("Send", scale=1) | |
completion_input = gr.Textbox(label="Prompt", visible=False) | |
completion_output = gr.Textbox(label="Completion", visible=False) | |
# UI Logic | |
def update_mode(mode): | |
is_chat = mode == "Chat" | |
return ( | |
gr.update(visible=is_chat), # chatbot | |
gr.update(), # chat_input row - removed visible parameter | |
gr.update(visible=not is_chat), # completion_input | |
gr.update(visible=not is_chat), # completion_output | |
gr.update(visible=is_chat), # user_role_box | |
gr.update(visible=is_chat), # assistant_role_box | |
) | |
# Create a dummy component to replace chat_input.parent which is causing the Form visibility issue | |
chat_input_row_visibility = gr.Checkbox( | |
visible=False, value=True, label="Chat Input Row Visibility" | |
) | |
mode_selector.change( | |
update_mode, | |
[mode_selector], | |
[ | |
chatbot, | |
chat_input_row_visibility, # Replaced chat_input.parent with dummy component | |
completion_input, | |
completion_output, | |
user_role_box, | |
assistant_role_box, | |
], | |
) | |
# Add a separate event handler to show/hide the chat input row | |
def toggle_chat_input_visibility(mode): | |
is_chat = mode == "Chat" | |
return gr.update(visible=is_chat) | |
mode_selector.change( | |
toggle_chat_input_visibility, | |
[mode_selector], | |
[chat_input.parent], | |
) | |
# Event Handlers | |
chat_inputs = [ | |
chat_input, | |
chatbot, | |
model_selector, | |
max_tokens, | |
temperature, | |
top_k, | |
dry_range, | |
dry_allowed_length, | |
dry_base, | |
dry_multiplier, | |
user_role_box, | |
assistant_role_box, | |
] | |
chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]} | |
def clear_input(): | |
return "" | |
clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]} | |
send_button.click(**chat_args).then(**clear_args) | |
chat_input.submit(**chat_args).then(**clear_args) | |
completion_inputs = [ | |
completion_input, | |
model_selector, | |
max_tokens, | |
temperature, | |
top_k, | |
dry_range, | |
dry_allowed_length, | |
dry_base, | |
dry_multiplier, | |
] | |
completion_input.submit( | |
completion_respond, | |
completion_inputs, | |
[completion_output], | |
) | |
if __name__ == "__main__": | |
demo.launch() | |