KaizeShi committed on
Commit
ba61b45
1 Parent(s): 582c300

Delete inference.py

Files changed (1)
  1. inference.py +0 -125
inference.py DELETED
@@ -1,125 +0,0 @@
- import os
- import sys
- import fire
- import torch
- from peft import PeftModel
- from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
- from utils.prompter import Prompter
-
- if torch.cuda.is_available():
-     device = "cuda"
- else:
-     device = "cpu"
-
- try:
-     if torch.backends.mps.is_available():
-         device = "mps"
- except:
-     pass
-
-
- def main(
-     load_8bit: bool = False,
-     base_model: str = "",
-     lora_weights: str = "DSMI/LLaMA-E/7b",
-     prompt_template: str = "",
- ):
-     print("lora_weights: " + str(lora_weights))
-     base_model = base_model or os.environ.get("BASE_MODEL", "")
-
-     prompter = Prompter(prompt_template)
-     tokenizer = LlamaTokenizer.from_pretrained(base_model)
-     if device == "cuda":
-         model = LlamaForCausalLM.from_pretrained(
-             base_model,
-             load_in_8bit=load_8bit,
-             torch_dtype=torch.float16,
-             device_map="auto",
-         )
-         model = PeftModel.from_pretrained(
-             model,
-             lora_weights,
-             torch_dtype=torch.float16,
-         )
-     elif device == "mps":
-         model = LlamaForCausalLM.from_pretrained(
-             base_model,
-             device_map={"": device},
-             torch_dtype=torch.float16,
-         )
-         model = PeftModel.from_pretrained(
-             model,
-             lora_weights,
-             device_map={"": device},
-             torch_dtype=torch.float16,
-         )
-     else:
-         model = LlamaForCausalLM.from_pretrained(
-             base_model, device_map={"": device}, low_cpu_mem_usage=True
-         )
-         model = PeftModel.from_pretrained(
-             model,
-             lora_weights,
-             device_map={"": device},
-         )
-
-     model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-     model.config.bos_token_id = 1
-     model.config.eos_token_id = 2
-
-     if not load_8bit:
-         model.half()  # seems to fix bugs for some users.
-
-     model.eval()
-     if torch.__version__ >= "2" and sys.platform != "win32":
-         model = torch.compile(model)
-
-     def evaluate(
-         instruction,
-         input=None,
-         temperature=0.1,
-         top_p=0.75,
-         top_k=40,
-         num_beams=4,
-         max_new_tokens=128,
-         **kwargs,
-     ):
-         prompt = prompter.generate_prompt(instruction, input)
-         inputs = tokenizer(prompt, return_tensors="pt")
-         input_ids = inputs["input_ids"].to(device)
-         generation_config = GenerationConfig(
-             temperature=temperature,
-             top_p=top_p,
-             top_k=top_k,
-             num_beams=num_beams,
-             **kwargs,
-         )
-
-         with torch.no_grad():
-             generation_output = model.generate(
-                 input_ids=input_ids,
-                 generation_config=generation_config,
-                 return_dict_in_generate=True,
-                 output_scores=True,
-                 max_new_tokens=max_new_tokens,
-             )
-         s = generation_output.sequences[0]
-         output = tokenizer.decode(s)
-         return prompter.get_response(output).split("</s>")[0]
-
-     print()
-     instruction = "Where can I buy the handmade jewellery?"
-     print("Instruction:", instruction)
-     print("Response:", evaluate(instruction))
-     print()
-
-     instruction = "Generate an ad for the following product."
-     input = "Emerald Teardrop Necklace.May Birthstone Pendant.Dainty Gift for Her.925 Sterling Silver.Spring Sale"
-     print("Instruction:", instruction)
-     print("Input:", input)
-     print("Response:", evaluate(instruction, input))
-     print()
-
-
- if __name__ == "__main__":
-     fire.Fire(main)
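
For reference, the deleted script guarded fire.Fire(main) behind __main__, so main can also be called directly as a function. A minimal sketch, assuming a local copy of the file and some LLaMA-7B base checkpoint (the huggyllama/llama-7b id below is an assumption; the script itself left base_model empty and fell back to the BASE_MODEL environment variable):

    # Hypothetical usage sketch; requires the deleted inference.py locally.
    from inference import main

    main(
        base_model="huggyllama/llama-7b",  # assumption: any LLaMA-7B checkpoint
        lora_weights="DSMI/LLaMA-E/7b",    # default adapter path from the script
    )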