Add a temperature hyperparameter to generation
Browse files- handler.py +14 -8
handler.py
CHANGED
@@ -37,23 +37,29 @@ class EndpointHandler():
|
|
37 |
|
38 |
|
39 |
def _secure_inputs(self, data: Dict[str, Any]):
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
if isinstance(inputs, str):
|
45 |
inputs = [inputs]
|
46 |
|
47 |
-
return inputs, True
|
48 |
|
|
|
49 |
def _format_inputs(self, inputs: list[str]):
|
50 |
prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
|
51 |
prompts_lengths = [len(prompt) for prompt in prompts]
|
52 |
return prompts, prompts_lengths
|
53 |
|
54 |
-
def _generate_outputs(self, inputs):
    """Tokenize the prompts, run generation on the GPU, and return the
    decoded output strings (special tokens stripped)."""
    encoded = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
    generation = self.model.generate(**encoded, max_new_tokens=500, use_cache=True)
    return self.tokenizer.batch_decode(generation, skip_special_tokens=True)
|
59 |
|
@@ -67,13 +73,13 @@ class EndpointHandler():
|
|
67 |
|
68 |
|
69 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
70 |
-
inputs, is_secure = self._secure_inputs(data)
|
71 |
|
72 |
if not is_secure:
|
73 |
return inputs
|
74 |
|
75 |
inputs, inputs_length = self._format_inputs(inputs)
|
76 |
-
outputs = self._generate_outputs(inputs)
|
77 |
outputs = self._format_outputs(outputs, inputs_length)
|
78 |
|
79 |
outputs = [{"summary": output_} for output_ in outputs]
|
|
|
37 |
|
38 |
|
39 |
def _secure_inputs(self, data: Dict[str, Any]):
|
40 |
+
if not isinstance(data, dict):
|
41 |
+
return [{"error": "inputs should be shaped like {'temperature': float, 'inputs': <string or List of strings (abstracts)>}"}], False
|
42 |
+
|
43 |
+
if not 'inputs' in data:
|
44 |
+
return [{"error": "inputs should be shaped like {'temperature': float, 'inputs': <string or List of strings (abstracts)>}"}], False
|
45 |
+
|
46 |
+
temperature = data.get("temperature", 0.01)
|
47 |
+
inputs = data["inputs"]
|
48 |
|
49 |
if isinstance(inputs, str):
|
50 |
inputs = [inputs]
|
51 |
|
52 |
+
return inputs, temperature, True
|
53 |
|
54 |
+
|
55 |
def _format_inputs(self, inputs: list[str]):
|
56 |
prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
|
57 |
prompts_lengths = [len(prompt) for prompt in prompts]
|
58 |
return prompts, prompts_lengths
|
59 |
|
60 |
+
def _generate_outputs(self, inputs, temperature):
    """Run the model over the prompts and decode the generated text.

    NOTE(review): HF ``generate`` only honors ``temperature`` when sampling
    is enabled (``do_sample=True``) — confirm the model's generation config,
    otherwise this parameter is silently ignored.
    """
    batch = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
    token_ids = self.model.generate(
        **batch,
        temperature=temperature,
        max_new_tokens=500,
        use_cache=True,
    )
    return self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
65 |
|
|
|
73 |
|
74 |
|
75 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
76 |
+
inputs, temperature, is_secure = self._secure_inputs(data)
|
77 |
|
78 |
if not is_secure:
|
79 |
return inputs
|
80 |
|
81 |
inputs, inputs_length = self._format_inputs(inputs)
|
82 |
+
outputs = self._generate_outputs(inputs, temperature)
|
83 |
outputs = self._format_outputs(outputs, inputs_length)
|
84 |
|
85 |
outputs = [{"summary": output_} for output_ in outputs]
|