# QLoRA+百万数据对baichuan-7b模型进行高效指令微调

更多详情请查看GitHub项目: [Firefly(流萤): 中文对话式大语言模型(全量微调+QLoRA)](https://github.com/yangjianxin1/Firefly)

单轮对话脚本:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Single-turn chat: every reply is generated from the current user message
# alone; no conversation history is carried between turns.

model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0
device = 'cuda'

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'
)
# device_map='auto' has already placed the weights (possibly across several
# GPUs or with offloading); a further model.to(device) would undo that split
# or fail on offloaded layers, so the model is not moved again.
model.eval()

# Each user turn is wrapped as <bos>{text}<eos>. The markers come from the
# tokenizer instead of being typed as literal "<s>"/"</s>" strings, which are
# easily lost when this snippet is rendered or copied as HTML.
input_pattern = tokenizer.bos_token + '{}' + tokenizer.eos_token

text = input('User:')
while True:
    prompt = input_pattern.format(text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens. Slicing by token count is exact,
    # whereas string-replacing the prompt in the decoded text breaks whenever
    # decoding does not reproduce the prompt exactly. skip_special_tokens
    # drops the trailing eos marker.
    response_ids = outputs[:, input_ids.size(1):]
    output = tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
    print("Firefly:{}".format(output))
    text = input('User:')
```

多轮对话脚本:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Multi-turn chat: the running token history (user turns and model replies)
# is fed back to the model, truncated to the most recent history_max_len tokens.

device = 'cuda'
model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'
)
# device_map='auto' has already placed the weights; do not call model.to(device)
# on top of it (it would undo a multi-GPU split or fail on offloaded layers).
model.eval()

# Token ids of the whole conversation so far, seeded with the bos token.
history_token_ids = tokenizer(tokenizer.bos_token, return_tensors="pt").input_ids
# Maximum number of (most recent) history tokens fed to the model per turn.
history_max_len = 1000

user_input = input('User:')
while True:
    # Terminate each user turn with the eos token before appending it.
    user_input_ids = tokenizer(user_input + tokenizer.eos_token, return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
    model_input_ids = history_token_ids[:, -history_max_len:].to(device)
    outputs = model.generate(
        input_ids=model_input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id
    )
    # generate() returns prompt + continuation; keep only the continuation.
    response_ids = outputs[:, model_input_ids.size(1):]
    # The reply's token ids (with any eos it generated) go back into the
    # history so the next turn can see them.
    history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
    response = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
    print("Firefly:" + response[0].strip())
    user_input = input('User:')
```