
KaizeShi committed
Commit d56218d · 1 Parent(s): ba61b45

Upload inference.py

Files changed (1)
  1. inference.py +125 -0
inference.py ADDED
@@ -0,0 +1,125 @@
+import os
+import sys
+import fire
+import torch
+from peft import PeftModel
+from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+from utils.prompter import Prompter
+
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+try:
+    if torch.backends.mps.is_available():
+        device = "mps"
+except:
+    pass
+
+
+def main(
+    load_8bit: bool = False,
+    base_model: str = "",
+    lora_weights: str = "DSMI/LLaMA-E/7b",
+    prompt_template: str = "",
+):
+    print("lora_weights: " + str(lora_weights))
+    base_model = base_model or os.environ.get("BASE_MODEL", "")
+
+    prompter = Prompter(prompt_template)
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+    if device == "cuda":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=load_8bit,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            torch_dtype=torch.float16,
+        )
+    elif device == "mps":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            base_model, device_map={"": device}, low_cpu_mem_usage=True
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            device_map={"": device},
+        )
+
+    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    if not load_8bit:
+        model.half()  # seems to fix bugs for some users.
+
+    model.eval()
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    def evaluate(
+        instruction,
+        input=None,
+        temperature=0.1,
+        top_p=0.75,
+        top_k=40,
+        num_beams=4,
+        max_new_tokens=256,
+        **kwargs,
+    ):
+        prompt = prompter.generate_prompt(instruction, input)
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            num_beams=num_beams,
+            **kwargs,
+        )
+
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s)
+        return prompter.get_response(output).split("</s>")[0]
+
+    print()
+    instruction = "Where can I buy the handmade jewellery?"
+    print("Instruction:", instruction)
+    print("Response:", evaluate(instruction))
+    print()
+
+    instruction = "Create an attractive advertisement for the Christmas sale of the following product."
+    input = "Custom Photo Music Plaque,Personalized Photo Frame,Album Cover Song Plaque,Music Photo Name Night Lamp,Photo and Music Gift, Music Prints"
+    print("Instruction:", instruction)
+    print("Input:", input)
+    print("Response:", evaluate(instruction, input))
+    print()
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
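
Because the script hands `main` to `fire.Fire`, its keyword arguments become command-line flags. Assuming a local or Hub copy of the LLaMA base weights (the script falls back to the `BASE_MODEL` environment variable when `--base_model` is not passed), an invocation might look like `python inference.py --base_model <path-to-llama-7b> --lora_weights DSMI/LLaMA-E/7b`, or `BASE_MODEL=<path-to-llama-7b> python inference.py --load_8bit`. Note that `--load_8bit` requires the `bitsandbytes` package, and the default `num_beams=4` beam search is noticeably slower than greedy decoding.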
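The script also imports `Prompter` from `utils.prompter`, which is not part of this commit; it is the prompt-template helper from the alpaca-lora project. For readers without that file, a minimal stand-in is sketched below, assuming the default Alpaca template (the real helper loads its template from a JSON file rather than hard-coding it):

class Prompter:
    """Minimal stand-in for utils.prompter.Prompter (assumes the Alpaca template)."""

    def __init__(self, template_name: str = ""):
        # inference.py passes an empty template name; the alpaca-lora helper
        # treats that as the default "alpaca" template, reproduced inline here.
        self.template = {
            "prompt_input": (
                "Below is an instruction that describes a task, paired with an input "
                "that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            ),
            "prompt_no_input": (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Response:\n"
            ),
            "response_split": "### Response:",
        }

    def generate_prompt(self, instruction, input=None, label=None):
        # Pick the with-input or no-input variant, mirroring how evaluate() calls it.
        if input:
            res = self.template["prompt_input"].format(instruction=instruction, input=input)
        else:
            res = self.template["prompt_no_input"].format(instruction=instruction)
        if label:
            res = f"{res}{label}"
        return res

    def get_response(self, output):
        # Everything after the response marker is the model's answer.
        return output.split(self.template["response_split"])[1].strip()

Saved as utils/prompter.py next to inference.py, this stand-in lets the uploaded script run unmodified; swap in the alpaca-lora original if the adapter was trained with a different template.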