lyc123456 commited on
Commit
42fc731
1 Parent(s): 3d8b7ec

call example

Browse files
Files changed (1) hide show
  1. README.md +46 -1
README.md CHANGED
@@ -3,4 +3,49 @@ license: other
3
  license_name: license
4
  license_link: https://huggingface.co/Qwen/Qwen1.5-0.5B/blob/main/LICENSE
5
  ---
6
- # A fine-tuned version of the Qwen/Qwen1.5-0.5B model, the data set used is alpaca_gpt4_data_zh.json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  license_name: license
4
  license_link: https://huggingface.co/Qwen/Qwen1.5-0.5B/blob/main/LICENSE
5
  ---
6
+ # A fine-tuned version of the Qwen/Qwen1.5-0.5B model, the data set used is alpaca_gpt4_data_zh.json
7
+ · Call example
8
+ ```python
9
+ import os
10
+
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ messages = [
14
+ {"role": "system", "content": "You are a helpful assistant."},
15
+ ]
16
+
17
+ device = "cuda" # the device to load the model onto
18
+ model_path = os.path.dirname(__file__)
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ model_path,
21
+ torch_dtype="auto",
22
+ device_map="auto"
23
+ )
24
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
25
+ response = ''
26
+ if __name__ == '__main__':
27
+
28
+ while True:
29
+ # prompt = "Give me a short introduction to large language model."
30
+ prompt = input("input:")
31
+ messages.append({"role": "user", "content": prompt})
32
+ text = tokenizer.apply_chat_template(
33
+ messages,
34
+ tokenize=False,
35
+ add_generation_prompt=True
36
+ )
37
+ model_inputs = tokenizer([text], return_tensors="pt").to(device)
38
+
39
+ generated_ids = model.generate(
40
+ model_inputs.input_ids,
41
+ max_new_tokens=512
42
+ )
43
+ generated_ids = [
44
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
45
+ ]
46
+
47
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
48
+ print(response)
49
+ messages.append({"role": "system", "content": response}, )
50
+
51
+ ```