shajiu committed on
Commit
c5e1840
1 Parent(s): 5dfa0e8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -7
README.md CHANGED
@@ -6,22 +6,55 @@ license: llama2
6
  ## 多轮对话测试demo
7
  ```python
8
  # -- coding: utf-8 --
9
- # @time :
10
  # @author : shajiu
11
  # @email : 18810979033@163.com
12
  # @file : .py
13
  # @software: pycharm
 
 
14
  from transformers import AutoTokenizer
 
15
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- import sys
18
- sys.path.append("../../")
19
- from component.utils import ModelUtils
20
 
21
 
22
- def main():
23
  # 使用合并后的模型进行推理
24
- model_name_or_path = 'shajiu/Tibetan_Llama2_7B_Mental_Health'
25
  adapter_name_or_path = None
26
 
27
  # 使用base model和adapter进行推理
@@ -95,5 +128,6 @@ def main():
95
 
96
 
97
  if __name__ == '__main__':
98
- main()
 
99
  ```
 
6
  ## 多轮对话测试demo
7
  ```python
8
  # -- coding: utf-8 --
9
+ # @time : 2024/12/1 16:26
10
  # @author : shajiu
11
  # @email : 18810979033@163.com
12
  # @file : .py
13
  # @software: pycharm
14
+
15
+
16
  from transformers import AutoTokenizer
17
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
18
  import torch
19
+ from peft import PeftModel
20
+
21
class ModelUtils(object):
    """Helpers for loading a causal language model for inference."""

    @classmethod
    def load_model(cls, model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
        """Load a causal LM, optionally 4-bit quantized and with an adapter.

        Args:
            model_name_or_path: Identifier or local directory passed to
                ``AutoModelForCausalLM.from_pretrained``.
            load_in_4bit: When true, build a 4-bit NF4 ``BitsAndBytesConfig``
                and load the model with it; otherwise no quantization
                config is used.
            adapter_name_or_path: Optional adapter location. When not
                ``None``, the base model is wrapped with
                ``PeftModel.from_pretrained``.

        Returns:
            The loaded model (wrapped by ``PeftModel`` when an adapter is
            given).
        """
        # Build the quantization settings only when 4-bit inference is requested.
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
            )
        else:
            quantization_config = None

        # Load the base model.
        # NOTE: `load_in_4bit` is deliberately NOT forwarded as a separate
        # keyword. The 4-bit switch already lives in `quantization_config`,
        # and recent transformers releases raise ValueError when both
        # `load_in_4bit` and `quantization_config` are supplied together.
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            device_map='auto',
            quantization_config=quantization_config,
        )

        # Attach the adapter on top of the base model when one is given.
        if adapter_name_or_path is not None:
            model = PeftModel.from_pretrained(model, adapter_name_or_path)

        return model
 
 
54
 
55
 
56
+ def main(model_name_or_path):
57
  # 使用合并后的模型进行推理
 
58
  adapter_name_or_path = None
59
 
60
  # 使用base model和adapter进行推理
 
128
 
129
 
130
  if __name__ == '__main__':
131
+ model_name_or_path = r'E:\models\shajiuTibetan_Llama2_7B_Mental_Health'
132
+ main(model_name_or_path)
133
  ```