#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio chat demo that streams text generated by a causal language model.

The model and tokenizer are loaded from ``--pretrained_model_name_or_path``
and exposed through a small Gradio Blocks UI (chat window, text box, submit
and clear buttons).
"""
import argparse
import os
from threading import Thread

import gradio as gr
from transformers import AutoModel, AutoTokenizer
from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
import torch

from project_settings import project_path


def get_args():
    """Parse the command-line options of the demo.

    ``--train_subset``, ``--valid_subset`` and ``--output_file`` are accepted
    but not read anywhere in this script; they are kept so existing command
    lines keep working.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)

    parser.add_argument(
        "--pretrained_model_name_or_path",
        default=(project_path / "trained_models/qwen_7b_chinese_modern_poetry").as_posix(),
        type=str
    )
    parser.add_argument("--output_file", default="result.xlsx", type=str)

    # Sampling parameters forwarded to model.generate().
    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    # Device the input tensors are moved to before generation.
    parser.add_argument('--device', default="cuda" if torch.cuda.is_available() else "cpu", type=str)

    args = parser.parse_args()
    return args


# Markdown shown at the top of the page.
description = """
## Qwen-7B

基于 [Qwen-7B](https://huggingface.co/qgyd2021/Qwen-7B) 模型, 在 [chinese_modern_poetry](https://huggingface.co/datasets/Iess/chinese_modern_poetry) 数据集上训练了 2 个 epoch.

可用于生成现代诗. 如下:

使用下列意象写一首现代诗:智慧,刀刃.

"""


# Example prompts offered below the text box.
examples = [
    "使用下列意象写一首现代诗:石头,森林",
    "使用下列意象写一首现代诗:花,纱布",
    "使用下列意象写一首现代诗:山壁,彩虹,诗句,山坡,泪",
    "使用下列意象写一首现代诗:味道,黄金,名字,银子,女人",
    "使用下列意象写一首现代诗:乳房,触感,车速,星星,路灯"
]


def main():
    """Load the tokenizer and model, build the Gradio UI and launch it."""
    args = get_args()

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id
    # are all None. eod_id is the id of the <|endoftext|> token, so it is
    # used for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    model = model.bfloat16().eval()

    def build_input_ids(text: str):
        """Tokenize ``text`` and wrap it as ``[bos] + tokens + [eos]``.

        Requires ``tokenizer.bos_token_id`` and ``tokenizer.eos_token_id`` to
        be set (done above for QWenTokenizer). The result is a 2-D tensor
        with a leading batch dimension of 1, placed on ``args.device``.
        """
        prompt_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)
        bos = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        return torch.cat([bos, prompt_ids, eos], dim=1)

    def fn_non_stream(text: str):
        """Generate the whole reply in one call.

        Returns:
            list: a single ``(prompt, response)`` pair in Chatbot format.
        """
        text = str(text).strip()
        input_ids = build_input_ids(text)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                top_p=args.top_p,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
        # generate() returns prompt + continuation; keep the continuation only.
        new_token_ids = outputs[0, input_ids.shape[1]:].tolist()
        response = tokenizer.decode(new_token_ids)
        response = response.strip().replace(tokenizer.eos_token, "").strip()
        return [(text, response)]

    def fn_stream(text: str):
        """Generate the reply incrementally.

        Generation runs in a background thread; this generator yields the
        conversation (a single ``(prompt, partial_response)`` pair) each time
        the streamer delivers a new chunk of text.
        """
        text = str(text).strip()
        input_ids = build_input_ids(text)

        # skip_prompt=True keeps the prompt out of the stream. Removing it
        # afterwards by string replacement is unreliable because the streamer
        # may deliver the prompt in several pieces, and it would also delete
        # any genuine repetition of the prompt inside the generated text.
        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Show the user's message right away, before the first chunk arrives.
        yield [(text, "")]

        output = ""
        for chunk in streamer:
            output += chunk.replace(tokenizer.eos_token, "")
            yield [(text, output)]

        # The streamer is exhausted, so generation has finished.
        thread.join()

    with gr.Blocks() as blocks:
        gr.Markdown(value=description)

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button("🗑️Clear", variant="secondary")

        gr.Examples(examples, text_box)

        text_box.submit(fn_stream, [text_box], [chatbot])
        submit_button.click(fn_stream, [text_box], [chatbot])
        clear_button.click(
            # Reset the text box to empty and the chatbot to the same empty
            # list it is initialised with above.
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()

    return


if __name__ == '__main__':
    main()