import gradio as gr
import os
import onnxruntime as ort
from inference.onnx_inference import generate_text, sequence_breaker_strings
from inference.model import ByteTokenizer
# --- Globals ---
MODEL_OPTIONS = [
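    # (display name, model key, available)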
("DAT-Byte Small (200M)", "small", True),
("DAT-Byte Medium", "medium", False),
("DAT-Byte Large", "large", False),
]
ONNX_PATH = "models/small.onnx"  # Path to the exported DAT-Byte Small ONNX model
# Cache for the ONNX session
SESSION_CACHE = {}
TOKENIZER = ByteTokenizer()
# Prepare sequence breakers
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id}
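# DRY treats these token IDs as sequence breakers: repetition matching does not
# extend across them, so chat markup and similar structural tokens are not penalized.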
for s in sequence_breaker_strings:
# These are single-byte tokens, so encode will return a list with one ID
try:
SEQUENCE_BREAKER_IDS.add(TOKENIZER.encode(s.encode("utf-8"))[0])
except IndexError:
print(f"Warning: Could not encode sequence breaker string: {s}")
# --- Model Loading ---
def get_session(model_key):
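    """Return a cached onnxruntime InferenceSession for the given model key.

    Only the "small" model is currently available; any other key raises so the
    calling UI code can surface a clear error message.
    """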
if model_key != "small":
raise ValueError("Only DAT-Byte Small is available.")
if model_key not in SESSION_CACHE:
if not os.path.exists(ONNX_PATH):
raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
# Using CPUExecutionProvider as per the project's goal
SESSION_CACHE[model_key] = ort.InferenceSession(
ONNX_PATH, providers=["CPUExecutionProvider"]
)
return SESSION_CACHE[model_key]
# --- Gradio Callbacks ---
def chat_respond(
message,
history,
model_name,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
user_role="user",
assistant_role="assistant",
):
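    """Generate one assistant reply for the chat tab.

    The history (a list of {"role": ..., "content": ...} dicts) and the new
    message are flattened into a ChatML-style prompt; the decoded reply is
    appended to the history backing the gr.Chatbot component.
    """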
    history = history or []
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"Model '{model_name}' is not available."}
        )
        return history
try:
session = get_session(model_key)
except Exception as e:
history.append({"role": "user", "content": message})
history.append(
{"role": "assistant", "content": f"[Model loading error: {str(e)}]"}
)
return history
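    # Flatten prior turns plus the new user message into a ChatML-style prompt,
    # leaving an open assistant turn for the model to complete.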
prompt = ""
for turn in history:
prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
prompt += (
f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n"
)
generated_text, _ = generate_text(
session=session,
tokenizer=TOKENIZER,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
stop_sequences=["<|im_end|>".encode("utf-8")],
dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
dry_range=dry_range,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_multiplier=dry_multiplier,
)
generated_text = generated_text.decode("utf-8", "ignore")
history.append({"role": "user", "content": message})
history.append({"role": "assistant", "content": generated_text})
return history
def completion_respond(
prompt,
model_name,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
):
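    """Generate a raw continuation of the prompt (no chat template applied)."""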
model_key = next(
(key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
None,
)
if not model_key:
return f"[Model '{model_name}' is not available or unknown.]"
try:
session = get_session(model_key)
except Exception as e:
return f"[Model loading error: {str(e)}]"
generated_text, _ = generate_text(
session=session,
tokenizer=TOKENIZER,
prompt=prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
dry_range=dry_range,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_multiplier=dry_multiplier,
)
    return generated_text.decode("utf-8", "ignore")
# --- Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)")
with gr.Row():
with gr.Column(scale=1):
model_selector = gr.Radio(
[opt[0] for opt in MODEL_OPTIONS],
value=MODEL_OPTIONS[0][0],
label="Model",
interactive=True,
)
gr.Markdown("**Note:** Only DAT-Byte Small is currently available.")
mode_selector = gr.Radio(
["Chat", "Raw Completion"], value="Chat", label="Mode"
)
max_tokens = gr.Slider(
minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
temperature = gr.Slider(
minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature"
)
top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k")
with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False):
dry_range = gr.Slider(
minimum=0, maximum=2048, value=1024, step=32, label="Range"
)
dry_allowed_length = gr.Slider(
minimum=1, maximum=64, value=20, step=1, label="Allowed Length"
)
dry_base = gr.Slider(
minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base"
)
dry_multiplier = gr.Slider(
minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier"
)
user_role_box = gr.Textbox("user", label="User Role", visible=True)
assistant_role_box = gr.Textbox(
"assistant", label="Assistant Role", visible=True
)
with gr.Column(scale=3):
chatbot = gr.Chatbot(label="Chat", type="messages", height=600)
with gr.Row():
chat_input = gr.Textbox(
label="Message", placeholder="Type a message...", scale=4
)
send_button = gr.Button("Send", scale=1)
completion_input = gr.Textbox(label="Prompt", visible=False)
completion_output = gr.Textbox(label="Completion", visible=False)
# UI Logic
def update_mode(mode):
is_chat = mode == "Chat"
return (
gr.update(visible=is_chat), # chatbot
            gr.update(),  # placeholder for the hidden checkbox; the chat input row is toggled by a separate handler
gr.update(visible=not is_chat), # completion_input
gr.update(visible=not is_chat), # completion_output
gr.update(visible=is_chat), # user_role_box
gr.update(visible=is_chat), # assistant_role_box
)
    # Hidden placeholder output: toggling chat_input.parent directly from update_mode
    # caused a Form visibility issue, so the chat input row is shown/hidden by a
    # separate handler below.
chat_input_row_visibility = gr.Checkbox(
visible=False, value=True, label="Chat Input Row Visibility"
)
mode_selector.change(
update_mode,
[mode_selector],
[
chatbot,
            chat_input_row_visibility,  # placeholder slot; chat_input.parent is toggled by toggle_chat_input_visibility
completion_input,
completion_output,
user_role_box,
assistant_role_box,
],
)
# Add a separate event handler to show/hide the chat input row
def toggle_chat_input_visibility(mode):
is_chat = mode == "Chat"
return gr.update(visible=is_chat)
mode_selector.change(
toggle_chat_input_visibility,
[mode_selector],
[chat_input.parent],
)
# Event Handlers
chat_inputs = [
chat_input,
chatbot,
model_selector,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
user_role_box,
assistant_role_box,
]
chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]}
def clear_input():
return ""
clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]}
send_button.click(**chat_args).then(**clear_args)
chat_input.submit(**chat_args).then(**clear_args)
completion_inputs = [
completion_input,
model_selector,
max_tokens,
temperature,
top_k,
dry_range,
dry_allowed_length,
dry_base,
dry_multiplier,
]
completion_input.submit(
completion_respond,
completion_inputs,
[completion_output],
)
if __name__ == "__main__":
demo.launch()
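
# Local usage sketch (assumes the ONNX export exists at models/small.onnx and
# that this file is saved as the Space's app.py):
#   python app.py
# Gradio prints a local URL to open in a browser.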