Inference endpoint for Gemma 2b: it is almost at release 0.1!
handler.py CHANGED (+14 -37)
@@ -1,19 +1,13 @@
 from typing import Dict, List, Any
 from llama_cpp import Llama
 
+MAX_TOKENS=8192
+
 class EndpointHandler():
-    def __init__(self
+    def __init__(self):
         self.model = Llama.from_pretrained("MrOvkill/gemma-2-inference-endpoint-GGUF", filename="gemma-2b.q8_0.gguf")
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        """
-        data args:
-            inputs (:obj: `str`)
-            image (:obj: `Image`)
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
-        """
-        # get inputs
         inputs = data.pop("inputs", "")
         temperature = data.pop("temperature", None)
         if not temperature:
@@ -35,33 +29,16 @@ class EndpointHandler():
             "status": "error",
             "reason": "invalid top k ( 1 - 99 )"
         })
-
-
-
-
-
-
-
-
-
-
-        #if image:
-        # perform image classification using Obsidian 3b vision
-        #image_features = self.vision.encode_image(image)
-        #image_embedding = self.vision.extract_feature(image_features)
-        #image_caption = self.vision.generate_caption(image_embedding)
-
-        # combine text and image captions
-        #combined_captions = [inputs, image_caption]
-
-        # run text classification on combined captions
-        #prediction = self.pipeline(combined_captions, temperature=0.33, num_beams=5, stop=[], do_sample=True)
-
-        #return prediction
+        system_prompt = data.pop("system-prompt", "You are Gemma. Assist user with whatever they require, in a safe and moral manner.")
+        format = data.pop("format", "<startofturn>system\n{system_prompt} <endoftext>\n<startofturn>user\n{prompt} <endofturn>\n<startofturn>model")
+        try:
+            format = format.format(system_prompt = system_prompt, prompt = inputs)
+        except Exception as e:
+            return json.dumps({
+                "status": "error",
+                "reason": "invalid format"
+            })
 
+        res = self.model(format, temperature=temperature, top_p=top_p, top_k=42)
 
-
-        # run text classification on plain text input
-        # prediction = self.pipeline(inputs, temperature=0.33, num_beams=5, stop=[], do_sample=True)
-
-        # return prediction
+        return res
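A minimal local smoke test of the updated handler might look like the sketch below. Two assumptions are worth flagging: the elided middle of __call__ (the validation between the two hunks) is assumed to assign temperature, top_p, and top_k from the payload, and handler.py still needs an import json added at the top, since the error paths call json.dumps but no import appears anywhere in the diff. The payload values here are hypothetical.

# smoke_test.py - hedged sketch; assumes handler.py sits alongside this file
# and that `import json` has been added to it (the error paths call
# json.dumps, but the diff never imports json).
from handler import EndpointHandler

handler = EndpointHandler()

# __call__ pops known keys from the payload dict; values are hypothetical.
result = handler({
    "inputs": "Explain what a GGUF file is in one sentence.",
    "temperature": 0.33,  # a default is assumed to apply when absent
    "top_p": 0.9,         # assumed to be validated in the elided lines
    "top_k": 42,          # note: the model call currently hardcodes top_k=42
})

# On success, llama-cpp-python returns a completion dict; the generated
# text lives under result["choices"][0]["text"].
print(result)

Two loose ends visible in the diff itself: the new MAX_TOKENS constant is defined but not yet passed to the model call, and top_k is hardcoded to 42 even though a user-supplied top k is validated earlier in __call__.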