File size: 1,461 Bytes
bb6fc44
 
 
 
 
 
 
 
 
 
6d0d7d1
bb6fc44
 
 
583c519
6d0d7d1
bb6fc44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2eb90c
bb6fc44
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
import os
from typing import Dict, List, Any
from llama_cpp import Llama
import gemma_tools as gem

# Context window size for the Gemma model (tokens).
MAX_TOKENS = 8192

class EndpointHandler():
    """Inference endpoint handler wrapping a GGUF Gemma-2B-it model via llama_cpp.

    The serving runtime constructs this once, then calls the instance per
    request with a payload dict.
    """

    def __init__(self, data):
        # `data` is supplied by the serving runtime but unused: the checkpoint
        # is fixed. n_ctx uses the module-level MAX_TOKENS constant instead of
        # repeating the magic number 8192.
        self.model = Llama.from_pretrained(
            "lmstudio-ai/gemma-2b-it-GGUF",
            filename="gemma-2b-it-q4_k_m.gguf",
            n_ctx=MAX_TOKENS,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Expects `data` to carry the fields gem.get_args_or_none validates
        (inputs, system_prompt, temperature, top_p, top_k) plus an optional
        `max_length` (default 512). Returns the raw llama_cpp completion dict
        on success, or a {"status": ..., ...} error dict on failure. All
        error branches now return plain dicts — originally two of them
        returned json.dumps strings, giving callers three different types.
        """
        args = gem.get_args_or_none(data)
        # Gemma chat template: system turn, user turn, then an open model turn.
        fmat = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{prompt} <endofturn>\n<startofturn>model"
        # NOTE(review): args appears to be a mapping that is also indexable by
        # 0 for a success flag — confirm against gemma_tools.get_args_or_none.
        if not args[0]:
            return {
                "status": args["status"],
                "message": args["description"]
            }
        try:
            fmat = fmat.format(system_prompt=args["system_prompt"], prompt=args["inputs"])
        except (KeyError, IndexError):
            # str.format with missing named fields raises KeyError; the args
            # lookups can too. Narrowed from a bare `except Exception`.
            return {
                "status": "error",
                "reason": "invalid format"
            }
        max_length = data.pop("max_length", 512)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            return {
                "status": "error",
                "reason": "max_length was passed as something that was absolutely not a plain old int"
            }

        res = self.model(
            fmat,
            temperature=args["temperature"],
            top_p=args["top_p"],
            top_k=args["top_k"],
            max_tokens=max_length,
        )
        return res