viethoangtranduong committed
Commit: 85e6e1b
1 Parent(s): 774805b

Update handler.py

Files changed (1): handler.py (+81 −13)
handler.py CHANGED

@@ -40,14 +40,71 @@
 
 # return {"generated_text": output_strs}
 
+# import torch
+# from typing import Dict, List, Any
+# from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# class EndpointHandler():
+#     def __init__(self, path: str = ""):
+
+#         self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side = "left")
+#         self.model = AutoModelForCausalLM.from_pretrained(path, device_map = "auto", torch_dtype=torch.float16)
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         """
+#         Args:
+#             data (:obj:):
+#                 includes the input data and the parameters for the inference.
+#         Return:
+#             A :obj:`list`:. The list contains the answer and scores of the inference inputs
+#         """
+
+#         # process input
+#         inputs_list = data.pop("inputs", data)
+#         parameters = data.pop("parameters", {})
+
+#         prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_list]
+
+#         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+#         inputs = self.tokenizer(prompts, truncation=True, max_length=2048-512,
+#                                 return_tensors='pt', padding=True).to(self.model.device)
+#         input_length = inputs.input_ids.shape[1]
+
+#         if parameters.get("deterministic", False):
+#             torch.manual_seed(42)
+
+#         outputs = self.model.generate(
+#             **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50
+#         )
+
+#         output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
+
+#         return {"generated_text": output_strs}
+
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 from typing import Dict, List, Any
-from transformers import AutoTokenizer, AutoModelForCausalLM
+
+class StopWordsCriteria(StoppingCriteria):
+    def __init__(self, stop_words, tokenizer):
+        self.tokenizer = tokenizer
+        self.stop_words = stop_words
+        self._cache_str = ''
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        self._cache_str += self.tokenizer.decode(input_ids[0, -1])
+        for stop_word in self.stop_words:
+            if stop_word in self._cache_str:
+                return True
+        return False
+
 
 class EndpointHandler():
     def __init__(self, path: str = ""):
 
         self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side = "left")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.model = AutoModelForCausalLM.from_pretrained(path, device_map = "auto", torch_dtype=torch.float16)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
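The key addition in this hunk is StopWordsCriteria: after each generated token it decodes the newest token, appends it to a running string, and halts generation once any stop word appears, here the "\n<human>:" turn marker. Two caveats follow from the implementation: it only watches the first sequence in the batch (input_ids[0, -1]), so batched prompts all stop when the first one emits the marker, and _cache_str persists across calls, so every generate() call needs a fresh instance (the handler constructs one per request). Below is a minimal sketch, not part of the commit, exercising the criteria in isolation; it assumes handler.py is importable and uses "gpt2" as a stand-in checkpoint.

from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
from handler import StopWordsCriteria  # the class added in this commit

tokenizer = AutoTokenizer.from_pretrained("gpt2")        # stand-in model, not the real checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("<human>: Hello!\n<bot>:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    # Build a fresh instance per call: _cache_str accumulates decoded text,
    # so a reused instance would trigger immediately on the next call.
    stopping_criteria=StoppingCriteriaList(
        [StopWordsCriteria(["\n<human>:"], tokenizer)]
    ),
)
# Strip the prompt tokens, mirroring what the handler does with input_length.
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))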
 
@@ -65,19 +122,30 @@ class EndpointHandler():
 
         prompts = [f"<human>: {prompt}\n<bot>:" for prompt in inputs_list]
 
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        inputs = self.tokenizer(prompts, truncation=True, max_length=2048-512,
-                                return_tensors='pt', padding=True).to(self.model.device)
-        input_length = inputs.input_ids.shape[1]
-
-        if parameters.get("deterministic", False):
-            torch.manual_seed(42)
-
-        outputs = self.model.generate(
-            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50
-        )
+        if parameters.get("preset_truncation_token"):
+            preset_truncation_token_value = parameters["preset_truncation_token"]
+            DELIMITER = " "
+            prompts = [DELIMITER.join(prompt.split(DELIMITER)[:preset_truncation_token_value]) for prompt in prompts]
+            print("truncated prompts:", prompts)
+            del parameters["preset_truncation_token"]
+
+        with torch.no_grad():
+            inputs = self.tokenizer(prompts, truncation=True, max_length=2048-512,
+                                    return_tensors='pt', padding=True).to(self.model.device)
+            input_length = inputs.input_ids.shape[1]
+
+            if parameters.get("deterministic_seed", False):
+                torch.manual_seed(parameters["deterministic_seed"])
+                del parameters["deterministic_seed"]
+
+            outputs = self.model.generate(
+                **inputs, **parameters,
+                stopping_criteria=StoppingCriteriaList(
+                    [StopWordsCriteria(['\n<human>:'], self.tokenizer)]
+                )
+            )
 
         output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
+        output_strs = [output_str.replace("\n<human>:", "") for output_str in output_strs]
 
         return {"generated_text": output_strs}
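Since __call__ now forwards parameters verbatim to generate(), the old hard-coded sampling settings (max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50) must be supplied per request. A sketch of a request the updated handler would accept, with illustrative values; preset_truncation_token and deterministic_seed are consumed by the handler itself rather than forwarded, and note that a seed of 0 would be skipped because the lookup falls back to a falsy default.

from handler import EndpointHandler

handler = EndpointHandler(path="./")  # path to the model weights; placeholder here

payload = {
    "inputs": ["What is a language model?"],
    "parameters": {
        "preset_truncation_token": 200,  # keep only the first 200 whitespace-split words of each prompt
        "deterministic_seed": 42,        # seeds torch.manual_seed for reproducible sampling
        "max_new_tokens": 256,           # forwarded verbatim to model.generate()
        "do_sample": True,
        "temperature": 0.7,
    },
}

result = handler(payload)
print(result["generated_text"])  # list of completions, prompt stripped and "\n<human>:" scrubbed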