karlbooster committed
Commit 572062d
Parent(s): 1ec58ed

Update handler.py

Files changed (1):
  handler.py +72 -19
handler.py CHANGED
@@ -1,27 +1,80 @@
 import torch
 from typing import Dict, List, Any
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from transformers import StoppingCriteria, StoppingCriteriaList

 # get dtype
 dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

+class StoppingCriteriaSub(StoppingCriteria):
+    def __init__(self, stops=[], encounters=1):
+        super().__init__()
+        self.stops = [stop.to("cuda") for stop in stops]
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        for stop in self.stops:
+            stop_len = len(stop)
+            if input_ids.shape[1] >= stop_len:
+                if torch.all(stop == input_ids[:, -stop_len:]).item():
+                    return True
+        return False
+
 class EndpointHandler:
-    def __init__(self, path=""):
-        # load the model
-        tokenizer = AutoTokenizer.from_pretrained(path)
-        model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
-        # create inference pipeline
-        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
-
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
-        return prediction
+    def __init__(self, path=""):
+        # load the model
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
+        print("model loaded")
+
+    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+        inputs = data.pop("inputs", data)
+        parameters = data.pop("parameters", None)
+        if parameters is None:
+            parameters = {}
+
+        prompt = inputs
+        temperature = parameters.get("temperature", 0.8)
+        top_p = parameters.get("top_p", 0.9)
+        top_k = parameters.get("top_k", 0)
+        max_new_tokens = parameters.get("max_new_tokens", 100)
+        repetition_penalty = parameters.get("diversity_penalty", 1.1)
+        max_length = parameters.get("max_length", 2048)
+        stop_words = parameters.get("stop_words", [])
+        num_return_sequences = parameters.get("num_return_sequences", 1)
+
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_new_tokens=max_new_tokens,
+            max_length=max_length,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            repetition_penalty=repetition_penalty,
+            num_return_sequences=num_return_sequences,
+            do_sample=True
+        )
+
+        # Tokenize inputs
+        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt", max_length=max_length - max_new_tokens, truncation=True).to(self.model.device)
+
+        # Decode truncated prompt
+        truncated_prompt = self.tokenizer.decode(input_tokens.squeeze(), skip_special_tokens=True)
+
+        stop_words_ids = [self.tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
+        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+        # Create attention mask
+        attention_mask = torch.ones_like(input_tokens).to(self.model.device)
+
+        # Run the model
+        output = self.model.generate(input_tokens,
+                                     generation_config=generation_config,
+                                     stopping_criteria=stopping_criteria,
+                                     attention_mask=attention_mask,
+                                     )
+        # only return the part after the prompt
+        output_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0][len(truncated_prompt):]
+
+        return [{"generated_text": output_text}]
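
For reference, the updated handler accepts a request payload with "inputs" and an optional "parameters" dict using the keys read in __call__ above. A minimal local smoke test might look like the sketch below; it assumes the script sits next to handler.py, runs on a CUDA machine (the module queries torch.cuda.get_device_capability() at import), and uses a hypothetical model path, prompt, and stop word not taken from this commit.

# sketch only: exercise the new handler locally (hypothetical path, prompt, stop word)
from handler import EndpointHandler

handler = EndpointHandler(path="./")  # assumption: current directory holds the model weights

payload = {
    "inputs": "Once upon a time",
    "parameters": {
        "temperature": 0.8,
        "top_p": 0.9,
        "max_new_tokens": 100,
        "stop_words": ["\nUser:"],  # tokenized and matched by StoppingCriteriaSub
    },
}

# returns a list with one dict, e.g. [{"generated_text": "..."}]
print(handler(payload))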