File size: 1,454 Bytes
ccb82da
 
 
83a5aeb
89c1b74
83a5aeb
deb69ba
 
 
 
 
83a5aeb
ccb82da
 
 
 
 
df4183f
83a5aeb
ccb82da
3c57069
 
 
 
 
 
 
dc0013e
42565ce
 
3c57069
ccb82da
deb69ba
ccb82da
 
 
 
 
 
f72c5f5
937acbc
deb69ba
937acbc
 
deb69ba
83a5aeb
ccb82da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import os
from typing import Dict, List, Any
import torch
from transformers import pipeline

# Prompt template: places the caller's text between the `<|user|>` and
# `<|assistant|>` markers. The `{inputs}` placeholder is filled in with
# str.format(); the leading and trailing newlines are part of the template.
PROMPT_FORMAT = (
    "\n"
    "<|user|>\n"
    "{inputs} <|end|>\n"
    "<|assistant|>\n"
)

class EndpointHandler():
    """Request handler that serves a text-generation pipeline.

    The pipeline is built once in ``__init__`` and reused by every call, so
    the model is not reloaded per request.
    """

    # Model repository passed to ``pipeline``.
    MODEL_REPO = "MrOvkill/Phi-3-Instruct-Bloated"
    # Generation budget used when the request does not provide one.
    DEFAULT_MAX_NEW_TOKENS = 1024

    def __init__(self, data):
        """Build the text-generation pipeline.

        Args:
            data: Value supplied by whoever constructs the handler. It is not
                used: the model repository is fixed by ``MODEL_REPO``.
        """
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time; keep only while the repo is trusted.
        self.pipe = pipeline(
            "text-generation",
            self.MODEL_REPO,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for one request.

        Args:
            data: Request payload with the keys:
                ``inputs``: text substituted into ``PROMPT_FORMAT``.
                ``max_new_tokens`` (optional): anything ``int()`` accepts;
                defaults to ``DEFAULT_MAX_NEW_TOKENS``.

        Returns:
            Whatever the pipeline returns for the formatted prompt. If
            ``max_new_tokens`` cannot be converted to an int, a JSON string
            describing the error is returned instead (kept as a string for
            backward compatibility with existing clients).

        Raises:
            KeyError: if ``data`` has no ``inputs`` key.
        """
        requested = data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
        try:
            max_new_tokens = int(requested)
        except (TypeError, ValueError, OverflowError):
            # Validate before touching the model so a bad request is cheap.
            return json.dumps({
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int"
            })

        prompt = PROMPT_FORMAT.format(inputs=data["inputs"])
        # Greedy decoding (do_sample=False) with the requested token budget.
        return self.pipe(
            prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )