Pierce Maloney commited on
Commit
366e62e
1 Parent(s): 02ffbef

testing .generate instead of pipeline

Browse files
Files changed (3) hide show
  1. __pycache__/handler.cpython-311.pyc +0 -0
  2. handler.py +17 -28
  3. sample.py +1 -1
__pycache__/handler.cpython-311.pyc CHANGED
Binary files a/__pycache__/handler.cpython-311.pyc and b/__pycache__/handler.cpython-311.pyc differ
 
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Dict, List, Any
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList
3
 
4
 
5
 
@@ -7,11 +7,10 @@ class EndpointHandler():
7
  def __init__(self, path=""):
8
  # Preload all the elements you are going to need at inference.
9
  tokenizer = AutoTokenizer.from_pretrained(path)
10
- model = AutoModelForCausalLM.from_pretrained(path)
11
- tokenizer.pad_token = tokenizer.eos_token
12
- self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
13
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
14
- self.logits_processor = LogitsProcessorList([BanSpecificTokensLogitsProcessor(tokenizer, [3070])])
15
 
16
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
  """
@@ -22,19 +21,24 @@ class EndpointHandler():
22
  A :obj:`list` | `dict`: will be serialized and returned
23
  """
24
  inputs = data.pop("inputs", data)
 
25
 
26
  # Bad word: id 3070 corresponds to "(*", and we do not want to output a comment
27
- prediction = self.pipeline(
28
- inputs,
 
29
  stopping_criteria=self.stopping_criteria,
30
- max_new_tokens=50,
31
- return_full_text=False,
32
- # bad_words_ids=[[3070], [313, 334]],
33
- logits_processor=self.logits_processor,
34
  temperature=1,
35
  top_k=40,
 
 
36
  )
37
- return prediction
 
 
 
 
38
 
39
 
40
  class StopAtPeriodCriteria(StoppingCriteria):
@@ -45,19 +49,4 @@ class StopAtPeriodCriteria(StoppingCriteria):
45
  # Decode the last generated token to text
46
  last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
47
  # Check if the decoded text ends with a period
48
- return '.' in last_token_text
49
-
50
- class BanSpecificTokensLogitsProcessor(LogitsProcessor):
51
- """
52
- Logits processor that sets the logits of specific tokens to -inf,
53
- effectively banning them from being generated.
54
- """
55
- def __init__(self, tokenizer, banned_tokens_ids):
56
- self.tokenizer = tokenizer
57
- self.banned_tokens_ids = banned_tokens_ids
58
-
59
- def __call__(self, input_ids, scores):
60
- # Set logits of banned tokens to -inf
61
- for token_id in self.banned_tokens_ids:
62
- scores[:, token_id] = float('-inf')
63
- return scores
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
3
 
4
 
5
 
 
7
  def __init__(self, path=""):
8
  # Preload all the elements you are going to need at inference.
9
  tokenizer = AutoTokenizer.from_pretrained(path)
10
+ self.tokenizer = tokenizer
11
+ self.model = AutoModelForCausalLM.from_pretrained(path)
12
+ self.tokenizer.pad_token = tokenizer.eos_token
13
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
 
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
 
21
  A :obj:`list` | `dict`: will be serialized and returned
22
  """
23
  inputs = data.pop("inputs", data)
24
+ input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
25
 
26
  # Bad word: id 3070 corresponds to "(*", and we do not want to output a comment
27
+ prediction_ids = self.model.generate(
28
+ input_ids,
29
+ max_length=input_ids.shape[1] + 50,
30
  stopping_criteria=self.stopping_criteria,
31
+ bad_words_ids=[[3070], [313, 334]],
 
 
 
32
  temperature=1,
33
  top_k=40,
34
+ # pad_token_id=self.tokenizer.eos_token_id,
35
+ # return_dict_in_generate=True, # To get more detailed output (optional)
36
  )
37
+
38
+ # Decode the generated ids to text
39
+ # Exclude the input_ids length to get only the new tokens
40
+ prediction_text = self.tokenizer.decode(prediction_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
41
+ return [{"generated_text": prediction_text}]
42
 
43
 
44
  class StopAtPeriodCriteria(StoppingCriteria):
 
49
  # Decode the last generated token to text
50
  last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
51
  # Check if the decoded text ends with a period
52
+ return '.' in last_token_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sample.py CHANGED
@@ -4,7 +4,7 @@ from handler import EndpointHandler
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
- payload = {"inputs": "This is the format for a"}
8
 
9
  # test the handler
10
  payload=my_handler(payload)
 
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
+ payload = {"inputs": "I want to turn the next page of the"}
8
 
9
  # test the handler
10
  payload=my_handler(payload)