# kimi-coder-demo / app.py
# Error Lover
# fix startup when torch missing
# 51c594b
import gradio as gr
from threading import Thread
# Hub repo id of the checkpoint served by this demo.
MODEL = "yava-code/kimi-coder-135m"

# Tokenizer and model are only loaded when the model stack imports cleanly;
# otherwise they stay None and `err` holds the import failure so the UI and
# respond() can report which module is missing.
tok = None
model = None

try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
except ModuleNotFoundError as import_error:
    # Names bound by `except ... as` are deleted when the handler exits,
    # so keep a module-level reference to the failure.
    err = import_error
else:
    err = None
    tok = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Inference only: put the model in eval mode (disables dropout etc.).
    model.eval()
def _history_to_chat(history):
    """Convert Gradio chat history into a list of chat-template messages.

    Gradio's ``ChatInterface`` may pass history either as ``[user, assistant]``
    pairs (the "tuples" format) or as ``{"role": ..., "content": ...}`` dicts
    (the "messages" format); both are accepted. Turns whose content is ``None``
    (e.g. an assistant slot that was never filled) are dropped so the chat
    template never renders ``None``.
    """
    chat = []
    for turn in history or []:
        if isinstance(turn, dict):
            if turn.get("content") is not None:
                # Copy only role/content; Gradio may attach extra keys.
                chat.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, bot_text = turn
        if user_text is not None:
            chat.append({"role": "user", "content": user_text})
        if bot_text is not None:
            chat.append({"role": "assistant", "content": bot_text})
    return chat


def respond(msg, history, max_new, temp):
    """Stream a model reply to ``msg`` given the prior chat ``history``.

    Used as the ``gr.ChatInterface`` callback: yields the accumulated reply
    text after each streamed chunk so the UI can update incrementally.

    Args:
        msg: the new user message.
        history: prior turns, as ``[user, assistant]`` pairs or role/content
            dicts.
        max_new: maximum number of new tokens to generate.
        temp: sampling temperature; ``0`` (or less) selects greedy decoding.

    Yields:
        The reply text generated so far, or a single dependency-error message
        if the model stack failed to import.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised here
        once the stream has been closed.
    """
    if err is not None:
        yield (
            f"missing dependency: `{err.name}`\n\n"
            "add it to requirements.txt and rebuild the space."
        )
        return

    chat = _history_to_chat(history)
    chat.append({"role": "user", "content": msg})
    input_ids = tok.apply_chat_template(
        chat, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)

    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        # A single, unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        # Cast defensively in case the slider delivers a float.
        max_new_tokens=int(max_new),
    )
    if temp > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temp))
    else:
        # Greedy decoding: leave temperature unset so generate() does not warn
        # about a sampling parameter that has no effect.
        gen_kwargs["do_sample"] = False

    failures = []

    def _generate():
        # Runs in the worker thread. If generate() raises, the streamer would
        # never receive its end-of-stream signal and the loop below would
        # block forever, so close the stream and hand the error back.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    out = ""
    for token in streamer:
        out += token
        yield out

    worker.join()
    if failures:
        raise failures[0]
# Example prompts offered by the chat UI; each row holds a single element,
# the example user message.
EXAMPLES = [
    [prompt]
    for prompt in (
        "Write a Python function to find all prime numbers up to n using the Sieve of Eratosthenes.",
        "Implement a binary search tree with insert and search methods in Python.",
        "Write a decorator that caches function results (memoization).",
    )
]
# Markdown header shown at the top of the page in both startup outcomes.
_INTRO_MD = """
## 🤖 kimi-coder-135m
SmolLM2-135M fine-tuned on 15k coding samples distilled from KIMI-K2.5.
Model: [yava-code/kimi-coder-135m](https://huggingface.co/yava-code/kimi-coder-135m)
"""

with gr.Blocks(title="kimi-coder-135m") as demo:
    gr.Markdown(_INTRO_MD)
    if err is None:
        # Extra controls forwarded to respond() after (msg, history):
        # max_new, then temp.
        _chat_controls = [
            gr.Slider(64, 1024, value=512, label="Max new tokens"),
            gr.Slider(0, 1, value=0.3, step=0.05, label="Temperature"),
        ]
        chatbot = gr.ChatInterface(
            respond,
            additional_inputs=_chat_controls,
            examples=EXAMPLES,
            title="",
        )
    else:
        # The model stack failed to import: show a warning instead of the chat.
        _startup_warning = (
            "### startup warning\n"
            f"missing dependency: `{err.name}`\n\n"
            "current requirements include `torch`, so this usually means the build failed.\n"
            "try restarting/rebuilding the space."
        )
        gr.Markdown(_startup_warning)

demo.launch()