chat_with_llm / main.py
qgyd2021's picture
[update]add main
a845f24
raw
history blame
No virus
4.79 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from functools import lru_cache
from threading import Thread
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from project_settings import project_path
def greet(question: str, history: List[Tuple[str, str]]):
    """Append a canned ``"Hello <question>!"`` reply to the chat history.

    Needs no model, so it is handy for exercising the chat UI wiring.

    Returns:
        A new list made of ``history`` followed by ``(question, reply)``;
        ``history`` itself is not modified.
    """
    new_turn = (question, "".join(("Hello ", question, "!")))
    return history + [new_turn]
@lru_cache(maxsize=1)
def _load_model_and_tokenizer(pretrained_model_name_or_path: str):
    """Load the causal LM and its tokenizer, caching the most recently used pair.

    Loading a multi-billion-parameter checkpoint is slow, so the last-used model
    is kept in memory instead of being reloaded on every chat turn. Selecting a
    different model replaces the cached pair (``maxsize=1``).

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    # device_map="auto" already places (and may offload) the weights, and
    # torch_dtype already selects bfloat16. The model is therefore not moved or
    # cast again: accelerate refuses .to() on a model with offloaded modules.
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # The llama tokenizer does not support the fast implementation.
        use_fast=model.config.model_type != "llama",
        padding_side="left",
    )
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are
    # all None. Its eod_id corresponds to the <|endoftext|> token.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id
    return model, tokenizer


def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                             ):
    """Generate one (non-streamed) reply to ``question`` and append it to the history.

    Only ``question`` is sent to the model; earlier turns in ``history`` are not
    used as context. The prompt is ``[bos] + question tokens + [eos]``, where each
    marker is included only if the tokenizer defines it. The reply is sampled
    with the given settings.

    Args:
        question: The user's message.
        history: Prior ``(user, bot)`` turns currently shown in the chatbot.
        pretrained_model_name_or_path: Model id or local path for ``from_pretrained``.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.
        repetition_penalty: Repetition penalty forwarded to ``generate``.

    Returns:
        A new list: ``history`` followed by ``(question, response)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = _load_model_and_tokenizer(pretrained_model_name_or_path)

    question_ids = tokenizer(
        question,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids.to(device)

    # Wrap the question in bos/eos markers, skipping any the tokenizer lacks
    # (torch.tensor cannot be built from a None id).
    pieces = [question_ids]
    if tokenizer.bos_token_id is not None:
        pieces.insert(0, torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long, device=device))
    if tokenizer.eos_token_id is not None:
        pieces.append(torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long, device=device))
    input_ids = torch.cat(pieces, dim=1)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            # UI widgets can deliver ints where floats are expected (e.g. a slider
            # resting on 2). Some transformers logits processors require real
            # floats, so the values are normalised here.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            eos_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; keep only the newly generated ids.
    new_token_ids = outputs[0][input_ids.shape[1]:].tolist()
    response = tokenizer.decode(new_token_ids).strip()
    if tokenizer.eos_token:
        response = response.replace(tokenizer.eos_token, "").strip()
    return list(history or []) + [(question, response)]
def main():
    """Build and launch the Gradio chat demo.

    Layout:
        - a chatbot pane;
        - a text box with Submit and Clear buttons;
        - sliders for the sampling parameters;
        - a model dropdown.

    Pressing Enter in the text box or clicking Submit calls
    ``chat_with_llm_non_stream``. It receives the text, the current history, the
    selected model and the slider values, and its return value replaces the
    chatbot contents. Clear empties both the text box and the chatbot.
    """
    with gr.Blocks() as blocks:
        gr.Markdown(value="gradio demo")
        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)

        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )

        # No trailing commas after these assignments: a trailing comma would bind
        # a 1-tuple instead of the Slider component and break the `inputs` wiring.
        # Minimums are strictly positive because sampling with temperature <= 0 or
        # repetition_penalty <= 0 is rejected by transformers, and generating
        # zero new tokens is pointless.
        with gr.Row():
            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                repetition_penalty = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty")

        with gr.Row():
            model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
                                     value="Qwen/Qwen-7B-Chat",
                                     label="model_name",
                                     )

        gr.Examples(examples=["你好"], inputs=text_box)

        # Order must match chat_with_llm_non_stream's positional parameters.
        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty
        ]
        outputs = [
            chatbot
        ]

        text_box.submit(chat_with_llm_non_stream, inputs, outputs)
        submit_button.click(chat_with_llm_non_stream, inputs, outputs)
        clear_button.click(
            # Empty string for the text box, empty history for the chatbot.
            fn=lambda: ('', []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
if __name__ == "__main__":
    # Launch the demo only when run as a script, not when imported.
    main()