Efficient instruction tuning of baichuan-7b with QLoRA on one million instruction examples

For more details, see the GitHub project: [Firefly(流萤): a Chinese conversational large language model (full fine-tuning + QLoRA)](https://github.com/yangjianxin1/Firefly)
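
For reference, a QLoRA training setup typically pairs 4-bit NF4 quantization (bitsandbytes) with LoRA adapters on the attention projections. The sketch below is a minimal illustration under assumptions, not the exact Firefly training configuration: the base checkpoint id, the LoRA hyperparameters, and the `W_pack` target module are assumptions; see the Firefly repo for the real training script.

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# Assumed base checkpoint; the actual Firefly run may differ
base_model = 'baichuan-inc/baichuan-7B'

# 4-bit NF4 quantization of the frozen base model: the "Q" in QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map='auto',
)
model = prepare_model_for_kbit_training(model)

# LoRA adapters on the fused QKV projection; hyperparameters are illustrative
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['W_pack'],  # assumption: baichuan-7b fuses q/k/v into W_pack
    task_type='CAUSAL_LM',
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```
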
Single-turn dialogue script:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0
device = 'cuda'
# Firefly prompt format: each query is wrapped as <s>query</s>
input_pattern = '<s>{}</s>'

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'
)
model.eval()
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

text = input('User:')
while True:
    text = input_pattern.format(text)
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    input_ids = input_ids.to(device)
    outputs = model.generate(
        input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
        top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode, then strip the prompt and the trailing </s> from the output
    rets = tokenizer.batch_decode(outputs)
    output = rets[0].strip().replace(text, "").replace('</s>', "")
    print("Firefly:{}".format(output))
    text = input('User:')
```
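The `-merge` suffix of the checkpoint name indicates that the QLoRA adapter weights have already been merged back into the base model, so the scripts here need only `transformers`; no `peft` adapter loading is required. Note that each query must be wrapped in the same `<s>{}</s>` pattern used during fine-tuning.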

Multi-turn dialogue script:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = 'cuda'
model_name = 'YeungNLP/firefly-baichuan-7b-qlora-sft-merge'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'
)
model.eval()
model = model.to(device)

# Token ids of the whole conversation history, starting with the BOS token <s>
history_token_ids = tokenizer('<s>', return_tensors="pt").input_ids
# Maximum number of history tokens fed to the model
history_max_len = 1000

user_input = input('User:')
while True:
    # Each turn (user query or model response) is terminated with </s>
    user_input = '{}</s>'.format(user_input)
    user_input_ids = tokenizer(user_input, return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
    # Feed only the most recent history_max_len tokens to the model
    model_input_ids = history_token_ids[:, -history_max_len:].to(device)
    outputs = model.generate(
        input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
        temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
    )
    # Keep only the newly generated tokens and append them to the history
    model_input_ids_len = model_input_ids.size(1)
    response_ids = outputs[:, model_input_ids_len:]
    history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
    response = tokenizer.batch_decode(response_ids)
    print("Firefly:" + response[0].strip().replace('</s>', ""))
    user_input = input('User:')
```
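Design note: `history_token_ids` grows without bound on the CPU side, but only the most recent `history_max_len` (1000) tokens are sent to the model each turn, so very long conversations silently drop their earliest turns, including the leading `<s>`. If the model supports a longer context, you can raise `history_max_len` accordingly.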