vllm-chat / app.py
vilarin's picture
Create app.py
9443a16 verified
raw
history blame
3.08 kB
import os
import spaces
import gradio as gr
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# Hugging Face model id (e.g. "org/name") or local path, supplied through the
# MODEL_ID environment variable. Fail early with a readable message rather
# than an opaque error further down when the variable is missing or empty.
model = os.environ.get("MODEL_ID")
if not model:
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to a model id or path."
    )

# Short display name: the last path component of the model id.
MODEL_NAME = model.split("/")[-1]

# HTML shown under the page title; links to the model page on the Hub.
DESCRIPTION = f"""
<h3>MODEL: <a href="https://hf.co/{model}">{MODEL_NAME}</a></h3>
<center>
<p>Qwen is the large language model built by Alibaba Cloud.
<br>
Feel free to test without log.
</p>
</center>
"""

# Custom stylesheet: centre <h1> headings and hide the Gradio footer.
css = """
h1 {
text-align: center;
}
footer {
visibility: hidden;
}
"""

# Tokenizer is used only to render the chat template into a prompt string.
tokenizer = AutoTokenizer.from_pretrained(model)

# Load the model into the vLLM engine. Accepts a model name or path.
# NOTE(review): the original comment says GPTQ/AWQ checkpoints also work here
# -- confirm against the installed vLLM version.
llm = LLM(model=model)
@spaces.GPU
def generate(message, history, system, max_tokens, temperature, top_p, top_k, penalty):
    """Generate the assistant's reply to ``message`` with the vLLM engine.

    Args:
        message: The latest user message.
        history: Previous turns as an iterable of ``(user, assistant)`` pairs.
        system: System prompt placed at the start of the conversation.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of top tokens to sample from; values below 1 disable
            top-k filtering.
        penalty: Repetition penalty forwarded to vLLM.

    Returns:
        The generated text of the first completion, or an empty string if the
        engine returned no outputs.
    """
    # Rebuild the whole conversation in chat-template message format.
    conversation = [{"role": "system", "content": system}]
    for prompt, answer in history:
        conversation.extend(
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]
        )
    conversation.append({"role": "user", "content": message})
    print(f"Conversation is -\n{conversation}")

    # Render to a single prompt string; the generation prompt makes the model
    # continue as the assistant.
    text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        # The UI slider allows 0, which some vLLM versions reject; -1 is the
        # documented "consider all tokens" value.
        top_k=int(top_k) if top_k >= 1 else -1,
        repetition_penalty=penalty,
        max_tokens=int(max_tokens),
        # vLLM names this option `stop_token_ids` (there is no `eos_token_id`).
        # NOTE(review): presumably the Qwen2 <|im_end|> / <|endoftext|> ids --
        # confirm against the tokenizer of the configured model.
        stop_token_ids=[151645, 151643],
    )

    outputs = llm.generate([text], sampling_params)

    # A single prompt is submitted, so at most one request output is expected.
    generated_text = ""
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text
# Page heading rendered above the description.
# NOTE(review): TITLE was referenced but never defined in this file; the text
# below is a placeholder based on the Space name -- adjust as desired.
TITLE = "<h1>vLLM Chat</h1>"

# Build the UI: title, model description, and a chat interface whose extra
# inputs are passed to `generate` after (message, history), in this order:
# system, max_tokens, temperature, top_p, top_k, penalty.
with gr.Blocks(css=css, fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=generate,
        chatbot=gr.Chatbot(scale=1),
        additional_inputs=[
            gr.Textbox(value="You are a helpful assistant.", label="System message"),
            gr.Slider(minimum=1, maximum=30720, value=2048, step=1, label="Max tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            ),
            gr.Slider(
                minimum=0,
                maximum=20,
                value=20,
                step=1,
                label="Top-k",
            ),
            gr.Slider(
                # vLLM requires a strictly positive repetition penalty, so the
                # slider must not allow 0.0.
                minimum=0.1,
                maximum=2.0,
                value=1,
                step=0.1,
                label="Repetition penalty",
            ),
        ],
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        submit_btn="Send",
    )

if __name__ == "__main__":
    demo.launch()