Adapters
Inference Endpoints
JeremyArancio committed on
Commit
dc32044
1 Parent(s): 18a6a4a

Update handler

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. handler.py +19 -29
README.md CHANGED
@@ -46,6 +46,7 @@ tokens = model.generate(
46
  eos_token_id=tokenizer.eos_token_id,
47
  early_stopping=True
48
  )
 
49
 
50
  # The hobbits were so suprised seeing their friend again that they did not
51
  # speak. Aragorn looked at them, and then he turned to the others.</s>
 
46
  eos_token_id=tokenizer.eos_token_id,
47
  early_stopping=True
48
  )
49
+ print(tokenizer.decode(tokens[0]))
50
 
51
  # The hobbits were so suprised seeing their friend again that they did not
52
  # speak. Aragorn looked at them, and then he turned to the others.</s>
handler.py CHANGED
@@ -1,9 +1,13 @@
1
  from typing import Dict, List, Any
 
2
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftConfig, PeftModel
5
 
6
 
 
 
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
  config = PeftConfig.from_pretrained(path)
@@ -14,35 +18,21 @@ class EndpointHandler():
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
17
- data args:
18
- prompt (:obj:`str`):
19
- temperature (:obj:`float`, `optional`, defaults to 0.5):
20
- eos_token_id (:obj:`int`, `optional`, defaults to tokenizer.eos_token_id):
21
- early_stopping (:obj:`bool`, `optional`, defaults to `True`):
22
- repetition_penalty (:obj:`float`, `optional`, defaults to 0.3):
23
- Return:
24
- A :obj:`str` : generated sequences
25
  """
 
26
  # Get inputs
27
- prompt = data.pop("prompt", None)
28
- temperature = data.pop("temperature", 0.5)
29
- eos_token_id = data.pop("eos_token_id", self.tokenizer.eos_token_id)
30
- early_stopping = data.pop('early_stopping', True)
31
- repetition_penalty = data.pop('repetition_penalty', 0.3)
32
- max_new_tokens = data.pop('max_new_tokens', 100)
33
-
34
- if prompt is None:
35
- raise ValueError("No prompt provided.")
36
-
37
- # Run prediction
38
  inputs = self.tokenizer(prompt, return_tensors="pt")
39
- prediction = self.model.generate(
40
- **inputs,
41
- temperature=temperature,
42
- eos_token_id=eos_token_id,
43
- early_stopping=early_stopping,
44
- repetition_penalty=repetition_penalty,
45
- max_new_tokens=max_new_tokens
46
- )
47
-
48
- return prediction
 
1
  from typing import Dict, List, Any
2
+ import logging
3
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from peft import PeftConfig, PeftModel
6
 
7
 
8
+ LOGGER = logging.getLogger(__name__)
9
+
10
+
11
  class EndpointHandler():
12
  def __init__(self, path=""):
13
  config = PeftConfig.from_pretrained(path)
 
18
 
19
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
20
  """
21
+ Args:
22
+ data (Dict): The payload with the text prompt and generation parameters.
 
 
 
 
 
 
23
  """
24
+ LOGGER.info(f"Received data: {data}")
25
  # Get inputs
26
+ prompt = data.pop("prompt", data)
27
+ parameters = data.pop("parameters", None)
28
+ # Preprocess
 
 
 
 
 
 
 
 
29
  inputs = self.tokenizer(prompt, return_tensors="pt")
30
+ # Forward
31
+ if parameters is not None:
32
+ outputs = self.model.generate(**inputs, **parameters)
33
+ else:
34
+ outputs = self.model.generate(**inputs)
35
+ # Postprocess
36
+ prediction = self.tokenizer.decode(outputs[0])
37
+ LOGGER.info(f"Generated text: {prediction}")
38
+ return [{"generated_text": prediction}]