clemparpa committed on
Commit
ba02790
·
verified ·
1 Parent(s): 002d8ef

adding temperature hyperparameters

Browse files
Files changed (1) hide show
  1. handler.py +14 -8
handler.py CHANGED
@@ -37,23 +37,29 @@ class EndpointHandler():
37
 
38
 
39
  def _secure_inputs(self, data: Dict[str, Any]):
40
- inputs = data.get("inputs", None)
41
- if inputs is None:
42
- return [{"error": "inputs should be shaped like {'inputs': <string or List of strings (abstracts)>}"}], False
 
 
 
 
 
43
 
44
  if isinstance(inputs, str):
45
  inputs = [inputs]
46
 
47
- return inputs, True
48
 
 
49
  def _format_inputs(self, inputs: list[str]):
50
  prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
51
  prompts_lengths = [len(prompt) for prompt in prompts]
52
  return prompts, prompts_lengths
53
 
54
- def _generate_outputs(self, inputs):
55
  tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
56
- outputs = self.model.generate(**tokenized, max_new_tokens=500, use_cache=True)
57
  decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
58
  return decoded
59
 
@@ -67,13 +73,13 @@ class EndpointHandler():
67
 
68
 
69
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
70
- inputs, is_secure = self._secure_inputs(data)
71
 
72
  if not is_secure:
73
  return inputs
74
 
75
  inputs, inputs_length = self._format_inputs(inputs)
76
- outputs = self._generate_outputs(inputs)
77
  outputs = self._format_outputs(outputs, inputs_length)
78
 
79
  outputs = [{"summary": output_} for output_ in outputs]
 
37
 
38
 
39
  def _secure_inputs(self, data: Dict[str, Any]):
40
+ if not isinstance(data, dict):
41
+ return [{"error": "inputs should be shaped like {'temperature': float, 'inputs': <string or List of strings (abstracts)>}"}], False
42
+
43
+ if not 'inputs' in data:
44
+ return [{"error": "inputs should be shaped like {'temperature': float, 'inputs': <string or List of strings (abstracts)>}"}], False
45
+
46
+ temperature = data.get("temperature", 0.01)
47
+ inputs = data["inputs"]
48
 
49
  if isinstance(inputs, str):
50
  inputs = [inputs]
51
 
52
+ return inputs, temperature, True
53
 
54
+
55
  def _format_inputs(self, inputs: list[str]):
56
  prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
57
  prompts_lengths = [len(prompt) for prompt in prompts]
58
  return prompts, prompts_lengths
59
 
60
+ def _generate_outputs(self, inputs, temperature):
61
  tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
62
+ outputs = self.model.generate(**tokenized, temperature=temperature, max_new_tokens=500, use_cache=True)
63
  decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
64
  return decoded
65
 
 
73
 
74
 
75
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
76
+ inputs, temperature, is_secure = self._secure_inputs(data)
77
 
78
  if not is_secure:
79
  return inputs
80
 
81
  inputs, inputs_length = self._format_inputs(inputs)
82
+ outputs = self._generate_outputs(inputs, temperature)
83
  outputs = self._format_outputs(outputs, inputs_length)
84
 
85
  outputs = [{"summary": output_} for output_ in outputs]