from typing import Any, Dict, List

import torch
from transformers import pipeline
# Phi-3 chat template: wrap the user prompt in the model's control tokens.
PROMPT_FORMAT = """
<|user|>
{inputs} <|end|>
<|assistant|>
"""
MODEL_ID = "MrOvkill/Phi-3-Instruct-Bloated"

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Inference Endpoints pass the local model directory as `path`.
        # Load the pipeline once at startup; reloading it per request
        # would be prohibitively slow.
        self.pipe = pipeline(
            "text-generation",
            MODEL_ID,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the user prompt to complete
            max_new_tokens (:obj:`int`, optional): generation length cap; defaults to 1024
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        max_new_tokens = data.get("max_new_tokens", 1024)
        try:
            max_new_tokens = int(max_new_tokens)
        except (TypeError, ValueError):
            # Return a plain dict; the serving layer serializes the response.
            return {
                "status": "error",
                "reason": "max_new_tokens must be an integer",
            }
        prompt = PROMPT_FORMAT.format(inputs=data["inputs"])
        return self.pipe(
            prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )
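
# A minimal local smoke test, assuming the standard Hugging Face Inference
# Endpoints contract (the toolkit normally instantiates EndpointHandler itself
# and calls it with the parsed request payload); the guard and sample payload
# below are illustrative, not part of the deployed handler.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Explain beam search in one sentence.", "max_new_tokens": 64})
    print(result)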