EC2 Default User commited on
Commit
a8fc6c5
1 Parent(s): 1ac6aba

Add lora model and custom inference file

Browse files
Files changed (3) hide show
  1. adapter_config.json +29 -0
  2. adapter_model.safetensors +3 -0
  3. handler.py +91 -0
adapter_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layers_pattern": null,
10
+ "layers_to_transform": null,
11
+ "loftq_config": {},
12
+ "lora_alpha": 16,
13
+ "lora_dropout": 0.05,
14
+ "megatron_config": null,
15
+ "megatron_core": "megatron.core",
16
+ "modules_to_save": null,
17
+ "peft_type": "LORA",
18
+ "r": 8,
19
+ "rank_pattern": {},
20
+ "revision": null,
21
+ "target_modules": [
22
+ "k_proj",
23
+ "v_proj",
24
+ "q_proj",
25
+ "o_proj"
26
+ ],
27
+ "task_type": "CAUSAL_LM",
28
+ "use_rslora": false
29
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78a07e0094519c2da9e7bf00fe01ff30bf78931f1b83f96fd7ce1cfa42a33782
3
+ size 27297032
handler.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import torch
4
+ from typing import List
5
+ from typing import Dict, Any
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
7
+ import torch
8
+
9
+
10
+ class MyStoppingCriteria(StoppingCriteria):
11
+ def __init__(self, target_sequence, prompt, tokenizer):
12
+ self.target_sequence = target_sequence
13
+ self.prompt = prompt
14
+ self.tokenizer = tokenizer
15
+
16
+ def __call__(self, input_ids, scores, **kwargs):
17
+ # Get the generated text as a string
18
+ generated_text = self.tokenizer.decode(input_ids[0])
19
+ generated_text = generated_text.replace(self.prompt, '')
20
+ # Check if the target sequence appears in the generated text
21
+ if self.target_sequence in generated_text:
22
+ return True # Stop generation
23
+
24
+ return False # Continue generation
25
+
26
+ def __len__(self):
27
+ return 1
28
+
29
+ def __iter__(self):
30
+ yield self
31
+
32
+
33
+ class EndpointHandler:
34
+ def __init__(self, model_dir=""):
35
+ # load model and processor from path
36
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
37
+ self.model = AutoModelForCausalLM.from_pretrained(model_dir, load_in_4bit=True, device_map="auto")
38
+
39
+ self.template = {
40
+ "prompt_input": """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n""",
41
+ "prompt_no_input": """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n""",
42
+ "response_split": """### Response:"""
43
+ }
44
+ self.instruction = """Extract the start and end sequences for the categories 'personal information', 'work experience', 'education' and 'skills' from the following text in dictionary form"""
45
+
46
+ if torch.cuda.is_available():
47
+ self.device = "cuda"
48
+ else:
49
+ self.device = "cpu"
50
+
51
+
52
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
53
+ """
54
+ Args:
55
+ data (:dict:):
56
+ The payload with the text prompt and generation parameters.
57
+ """
58
+ # process input
59
+ inputs = data.pop("inputs", data)
60
+ parameters = data.pop("parameters", None)
61
+
62
+ res = self.template["prompt_input"].format(
63
+ instruction=self.instruction, input=input
64
+ )
65
+ messages = [
66
+ {"role": "user", "content": res},
67
+ ]
68
+ input_ids = self.tokenizer.apply_chat_template(
69
+ messages, truncation=True, add_generation_prompt=True, return_tensors="pt"
70
+ ).input_ids
71
+ input_ids = input_ids.to(self.device)
72
+
73
+ # pass inputs with all kwargs in data
74
+ if parameters is not None:
75
+ outputs = self.model.generate(
76
+ input_ids=input_ids,
77
+ stopping_criteria=MyStoppingCriteria("</s>", inputs, self.tokenizer),
78
+ **parameters)
79
+ else:
80
+ outputs = self.model.generate(
81
+ input_ids=input_ids, max_new_tokens=32,
82
+ stopping_criteria=MyStoppingCriteria("</s>", inputs, self.tokenizer)
83
+ )
84
+
85
+ # postprocess the prediction
86
+ prediction = self.tokenizer.decode(outputs[0][input_ids.shape[1]:]) #, skip_special_tokens=True)
87
+ prediction = prediction.split("</s>")[0]
88
+
89
+ # TODO: add processing of the LLM output
90
+
91
+ return [{"generated_text": prediction}]