MrOvkill's picture
Update handler.py
dc0013e verified
raw
history blame contribute delete
No virus
1.45 kB
import json
import os
from typing import Dict, List, Any
import torch
from transformers import pipeline
# Chat template wrapped around each request's text before generation.
# ``{inputs}`` is the only placeholder and is filled with str.format.
PROMPT_FORMAT = (
    "\n"
    "<|user|>\n"
    "{inputs} <|end|>\n"
    "<|assistant|>\n"
)
class EndpointHandler:
    """Custom Inference Endpoints handler that serves text generation.

    The text-generation pipeline is built once in ``__init__`` and reused by
    every ``__call__``. Each request's ``inputs`` is wrapped in the module's
    ``PROMPT_FORMAT`` template and decoded greedily (``do_sample=False``).
    """

    # Hub repository id the pipeline is loaded from.
    MODEL_REPO = "MrOvkill/Phi-3-Instruct-Bloated"
    # Generation budget used when a request omits ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 1024

    def __init__(self, data):
        """Build the float16 text-generation pipeline.

        Args:
            data: Constructor argument supplied by the serving runtime
                (presumably the model directory -- confirm). It is not used;
                the model is always loaded from ``MODEL_REPO``.
        """
        # trust_remote_code allows executing model code shipped in the Hub repo.
        self.pipe = pipeline(
            "text-generation",
            self.MODEL_REPO,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for one request.

        Args:
            data: Request payload.
                ``inputs`` (required) is the user text inserted into the
                prompt template.
                ``max_new_tokens`` (optional, default
                ``DEFAULT_MAX_NEW_TOKENS``) may be anything ``int()``
                accepts.

        Returns:
            The pipeline's output on success.

            On invalid input, a dict ``{"status": "error", "reason": ...}``.
            It is returned as a plain dict, not a pre-serialized JSON string,
            because the return value is serialized by the caller.
        """
        # The pipeline is deliberately not rebuilt here. __init__ already
        # created self.pipe, and reloading the model per request is pure
        # overhead.
        try:
            max_new_tokens = int(
                data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
            )
        except (TypeError, ValueError, OverflowError):
            # TypeError: e.g. None or a list.
            # ValueError: e.g. "abc" or NaN.
            # OverflowError: infinite floats such as float("inf").
            return {
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int",
            }
        if "inputs" not in data:
            return {"status": "error", "reason": "inputs is required"}

        prompt = PROMPT_FORMAT.format(inputs=data["inputs"])
        return self.pipe(prompt, do_sample=False, max_new_tokens=max_new_tokens)