Pierce Maloney committed on
Commit
bf66e5a
1 Parent(s): e223eee

base llemma-7b and custom handler.py

Browse files
Files changed (3) hide show
  1. __pycache__/handler.cpython-311.pyc +0 -0
  2. handler.py +37 -0
  3. sample.py +13 -0
__pycache__/handler.cpython-311.pyc ADDED
Binary file (2.76 kB). View file
 
handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
3
+
4
+
5
+
6
class EndpointHandler():
    """Custom endpoint handler that serves a causal LM via a text-generation pipeline.

    The tokenizer, model, pipeline and stopping criteria are built once in
    ``__init__`` and reused for every request handled by ``__call__``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and assemble the pipeline.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(path)
        # Use EOS as the padding token (presumably the checkpoint ships
        # without a dedicated pad token -- confirm against the tokenizer config).
        tokenizer.pad_token = tokenizer.eos_token
        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for one request.

        Args:
            data: Request body. ``data["inputs"]`` is the prompt handed to the
                pipeline; if the key is absent the whole ``data`` object is
                passed through instead (original fallback, kept for
                compatibility). ``data["parameters"]``, when present, is a
                dict of extra generation keyword arguments that override the
                handler's defaults.

        Returns:
            Whatever the text-generation pipeline returns for the prompt; it
            is serialized and sent back to the caller.
        """
        # Read instead of pop so the caller's request dict is left untouched.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters") or {}

        generation_kwargs = {
            "max_new_tokens": 100,  # default when the request does not override it
            **parameters,
            # Always applied last: a JSON request cannot carry a usable
            # stopping-criteria object, so the handler's own one wins.
            "stopping_criteria": self.stopping_criteria,
        }
        return self.pipeline(inputs, **generation_kwargs)
27
+
28
+
29
class StopAtPeriodCriteria(StoppingCriteria):
    """Stopping criterion that fires once the newest token contains a period."""

    def __init__(self, tokenizer):
        """Keep the tokenizer used to turn generated token ids back into text."""
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the most recently generated text contains ``'.'``.

        Args:
            input_ids: Token ids generated so far; indexed as ``[:, -1]``, so
                it is assumed to be 2-D (batch, sequence) -- TODO confirm.
            scores: Unused here.
            **kwargs: Unused here.
        """
        # Last position of every row, decoded together as one piece of text.
        newest_ids = input_ids[:, -1]
        newest_text = self.tokenizer.decode(newest_ids, skip_special_tokens=True)
        # Containment test, not a suffix test: a token such as '."' also stops.
        has_period = '.' in newest_text
        return has_period
sample.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Local smoke test: run one prompt through the custom endpoint handler."""
from handler import EndpointHandler

# Build the handler from the files in the current directory.
my_handler = EndpointHandler(path=".")

# Request body in the same shape the handler's __call__ reads ("inputs" key).
payload = {"inputs": "I am so happy to reach in my pocket and find a"}

# Keep the response under its own name so the request payload is not clobbered.
prediction = my_handler(payload)

# Show the handler's response.
print("output:", prediction)