|
|
|
|
|
|
|
|
|
|
|
"""A simple command-line interactive chat demo.""" |
|
|
|
import argparse |
|
import os |
|
import platform |
|
import shutil |
|
from copy import deepcopy |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from transformers.generation import GenerationConfig |
|
from transformers.trainer_utils import set_seed |
|
|
|
# Checkpoint name or path used when -c/--checkpoint-path is not given.
DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'

# Banner printed right after the screen is cleared (at start-up and on the
# :clear command). The trailing backslash after the opening quotes keeps the
# string from starting with an empty line.
_WELCOME_MSG = '''\
Welcome to use Qwen-7B-Chat model, type text to start chat, type :h to show command help
欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
'''

# Command reference printed by the :help / :h command.
_HELP_MSG = '''\
Commands:
:help / :h Show this help message 显示帮助信息
:exit / :quit / :q Exit the demo 退出Demo
:clear / :cl Clear screen 清屏
:clear-his / :clh Clear history 清除对话历史
:history / :his Show history 显示对话历史
:seed Show current random seed 显示当前随机种子
:seed <N> Set random seed to <N> 设置随机种子
:conf Show current generation config 显示生成配置
:conf <key>=<value> Change generation config 修改生成配置
:reset-conf Reset generation config 重置生成配置
'''
|
|
|
|
|
def _load_model_tokenizer(args):
    """Load the tokenizer, the chat model and the generation config.

    Args:
        args: Parsed command-line namespace. ``args.checkpoint_path`` names
            the checkpoint to load and ``args.cpu_only`` selects the
            ``"cpu"`` device map instead of ``"auto"``.

    Returns:
        A ``(model, tokenizer, config)`` tuple. ``.eval()`` has already been
        called on the model; ``config`` is the ``GenerationConfig`` loaded
        from the same checkpoint.
    """
    ckpt = args.checkpoint_path
    # Options passed unchanged to every loader call below.
    hub_options = dict(trust_remote_code=True, resume_download=True)

    tokenizer = AutoTokenizer.from_pretrained(ckpt, **hub_options)

    device_map = "cpu" if args.cpu_only else "auto"

    # A quantize_config.json found on the local filesystem under the
    # checkpoint path switches loading to auto_gptq. The package is imported
    # lazily so it is only required when such a checkpoint is used.
    if os.path.exists(os.path.join(ckpt, 'quantize_config.json')):
        from auto_gptq import AutoGPTQForCausalLM

        loaded = AutoGPTQForCausalLM.from_quantized(
            ckpt,
            device_map=device_map,
            use_safetensors=True,
            **hub_options,
        )
    else:
        loaded = AutoModelForCausalLM.from_pretrained(
            ckpt,
            device_map=device_map,
            **hub_options,
        )
    model = loaded.eval()

    config = GenerationConfig.from_pretrained(ckpt, **hub_options)

    return model, tokenizer, config
|
|
|
|
|
def _clear_screen():
    """Clear the terminal by running the platform's clear-screen command."""
    # Windows shells use ``cls``; every other platform is given ``clear``.
    clear_command = "cls" if platform.system() == "Windows" else "clear"
    os.system(clear_command)
|
|
|
|
|
def _print_history(history):
    """Print the conversation so far between two terminal-wide rulers.

    Args:
        history: Sequence of ``(query, response)`` pairs, oldest first.
    """
    width, _ = shutil.get_terminal_size()

    # Header ruler: the turn count centred in a line of '=' characters.
    title = f'History ({len(history)})'
    print(title.center(width, '='))

    for turn, (user_text, model_text) in enumerate(history):
        print(f'User[{turn}]: {user_text}')
        print(f'QWen[{turn}]: {model_text}')

    # Footer ruler.
    print('=' * width)
|
|
|
|
|
def _get_input() -> str:
    """Prompt until the user enters a non-blank line and return it stripped.

    Blank input and undecodable input are reported and the prompt is shown
    again.

    Returns:
        The entered line with surrounding whitespace removed; never empty.

    Raises:
        SystemExit: With status 1 when the user presses Ctrl-C at the
            prompt, and with status 0 when standard input reaches end of
            file (Ctrl-D, or a piped input that has run out).
    """
    while True:
        try:
            message = input('User> ').strip()
        except UnicodeDecodeError:
            print('[ERROR] Encoding error in input')
            continue
        except KeyboardInterrupt:
            # Raise SystemExit directly: the ``exit`` helper is installed by
            # the ``site`` module and is missing under ``python -S`` and in
            # some frozen builds.
            raise SystemExit(1) from None
        except EOFError:
            # End of input is a normal way to leave an interactive session;
            # without this handler it would surface as a traceback. The
            # newline keeps the shell prompt off the 'User> ' line.
            print()
            raise SystemExit(0) from None
        if message:
            return message
        print('[ERROR] Query is empty')
|
|
|
|
|
def main():
    """Run the interactive command-line chat loop.

    Parses the command-line options, loads the model, then repeatedly reads a
    line from the user. A line starting with ':' is looked up as a command
    (see ``_HELP_MSG``); recognised commands are handled locally. Any other
    line, including an unrecognised ':' command, is sent to the model as a
    chat query and the (query, response) pair is appended to the history.

    The loop ends on ``:exit`` / ``:quit`` / ``:q``.
    """
    parser = argparse.ArgumentParser(
        description='QWen-7B-Chat command-line interactive chat demo.')
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args()

    history = []

    model, tokenizer, config = _load_model_tokenizer(args)
    # ``config`` is the object handed to chat_stream() below, so it is the
    # one that :conf and :reset-conf must show, modify and restore. Editing
    # model.generation_config instead would never reach generation.
    orig_gen_config = deepcopy(config)

    _clear_screen()
    print(_WELCOME_MSG)

    seed = args.seed

    while True:
        query = _get_input()

        # Process commands.
        if query.startswith(':'):
            command_words = query[1:].strip().split()
            command = command_words[0] if command_words else ''

            if command in ('exit', 'quit', 'q'):
                break
            elif command in ('clear', 'cl'):
                _clear_screen()
                print(_WELCOME_MSG)
                continue
            elif command in ('clear-his', 'clear-history', 'clh'):
                # 'clear-his' is the spelling shown in the help text;
                # 'clear-history' is kept for backward compatibility.
                print(f'[INFO] All {len(history)} history cleared')
                history.clear()
                continue
            elif command in ('help', 'h'):
                print(_HELP_MSG)
                continue
            elif command in ('history', 'his'):
                _print_history(history)
                continue
            elif command == 'seed':
                if len(command_words) == 1:
                    print(f'[INFO] Current random seed: {seed}')
                else:
                    new_seed_s = command_words[1]
                    try:
                        new_seed = int(new_seed_s)
                    except ValueError:
                        print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
                    else:
                        print(f'[INFO] Random seed changed to {new_seed}')
                        seed = new_seed
                continue
            elif command == 'conf':
                if len(command_words) == 1:
                    print(config)
                else:
                    for key_value_pairs_str in command_words[1:]:
                        conf_key, eq, conf_value_str = key_value_pairs_str.partition('=')
                        if not eq:
                            print('[WARNING] format: <key>=<value>')
                            continue
                        try:
                            # NOTE(security): eval() runs arbitrary Python
                            # typed at the prompt. That is tolerable only
                            # because the person running this local demo is
                            # the sole input source; use ast.literal_eval if
                            # the prompt is ever fed untrusted text.
                            conf_value = eval(conf_value_str)
                        except Exception as e:
                            # Report the bad value and move on to the next pair.
                            print(e)
                            continue
                        print(f'[INFO] Change config: generation_config.{conf_key} = {conf_value}')
                        setattr(config, conf_key, conf_value)
                continue
            elif command == 'reset-conf':
                print('[INFO] Reset generation config')
                config = deepcopy(orig_gen_config)
                print(config)
                continue
            else:
                # Unknown command: fall through and treat it as a normal query.
                pass

        # Run chat.
        set_seed(seed)
        # Start from an empty response so that a stream yielding nothing
        # cannot store the previous turn's answer in the history.
        response = ''
        try:
            # Presumably each yielded value is the response text so far, so
            # the screen is redrawn with it on every step; confirm against
            # the model's chat_stream implementation.
            for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
                _clear_screen()
                print(f"\nUser: {query}")
                print(f"\nQwen-7B: {response}")
        except KeyboardInterrupt:
            # An interrupted turn is dropped; it is not added to the history.
            print('[WARNING] Generation interrupted')
            continue

        history.append((query, response))
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|