# chat/app.py — Gradio chat UI for serverless small-LM inference
# (source: adamelliotfields's Hugging Face Space, commit 24d264b "Rename")
import random
import gradio as gr
import numpy as np
import torch
from lib import CONFIG, generate
# Extra <head> markup: CSS that widens the app container on very large screens.
HEAD = """
<style>
@media (min-width: 1536px) {
gradio-app > .gradio-container { max-width: 1280px !important }
}
</style>
"""

# Markdown rendered above the chat interface.
HEADER = """
# Chat 🤖
Serverless small language model inference.
"""

SEED = 0  # fixed seed for reproducible sampling
PORT = 7860  # standard Gradio / HF Spaces port

# Seed every RNG source once per process. gr.NO_RELOAD guards against
# re-seeding when the Gradio dev server hot-reloads this module.
if gr.NO_RELOAD:
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
# Widgets for the chat interface; built eagerly so they can be passed in below.
chatbot = gr.Chatbot(type="messages", show_label=False, height=None, scale=1)
textbox = gr.Textbox(placeholder="Type a message...", autofocus=True, scale=7)

# https://github.com/gradio-app/gradio/blob/main/gradio/chat_interface.py
chat_interface = gr.ChatInterface(
    title=None,
    fn=generate,
    chatbot=chatbot,
    textbox=textbox,
    type="messages",  # interface type must match bot type
    description=None,
    # Extra controls rendered below the chat; forwarded to `generate`
    # in addition to (message, history).
    additional_inputs=[
        gr.Textbox(
            label="System Message",
            lines=2,
            value="You are a helpful assistant. Be concise and precise.",
        ),
        gr.Dropdown(
            label="Model",
            filterable=False,
            value="Qwen/Qwen2.5-0.5B-Instruct",
            choices=list(CONFIG.keys()),  # model ids come from lib.CONFIG
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=512,
            info="Maximum number of new tokens to generate.",
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.6,
            info="Modulates next token probabilities.",
        ),
        gr.Slider(
            # https://arxiv.org/abs/1909.05858
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
            info="Penalizes repeating tokens.",
        ),
        gr.Slider(
            # https://arxiv.org/abs/1904.09751
            label="Top-p",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
            info="Only tokens with cumulative probability p are considered (nucleus sampling).",
        ),
        gr.Slider(
            # https://arxiv.org/pdf/1805.04833
            label="Top-k",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
            info="Only k-th highest probability tokens are considered.",
        ),
    ],
)
# Top-level layout: markdown header above the pre-built chat interface.
# fill_height lets the chatbot stretch to the viewport.
with gr.Blocks(head=HEAD, fill_height=True) as demo:
    gr.Markdown(HEADER)
    chat_interface.render()

if __name__ == "__main__":
    # Serialize generations (one request at a time) — model inference is heavy.
    demo.queue(default_concurrency_limit=1).launch(
        server_name="0.0.0.0",  # bind all interfaces (needed inside containers)
        server_port=PORT,
    )