# handler.py (Hugging Face Hub file view)
# Last commit: "Update handler.py" by MrOvkill (42565ce, verified)
# File size: 1.36 kB
import json
import os
from typing import Dict, List, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Single-turn chat prompt template. The `{inputs}` placeholder is meant to be
# filled with `str.format(inputs=...)`; the template ends right after the
# assistant tag so that generated text continues from there.
PROMPT_FORMAT = (
    "\n"
    "<|user|>\n"
    "{inputs} <|end|>\n"
    "<|assistant|>\n"
)
class EndpointHandler():
    """Custom inference handler that serves a causal language model.

    The model and tokenizer are loaded once at construction time; each call
    formats the request text with ``PROMPT_FORMAT``, runs greedy generation
    and returns the newly generated text.
    """

    # Hub repository that provides both the model weights and the tokenizer.
    REPO_ID = "MrOvkill/Phi-3-Instruct-Bloated"
    # Prompt used when the request carries no "inputs" key.
    DEFAULT_INPUTS = "Q: What is the chemical composition of common concrete in 2024?\nA: "
    # Generation budget used when the request carries no "max_new_tokens" key.
    DEFAULT_MAX_NEW_TOKENS = 1024

    def __init__(self, data):
        """Load the model (float16) and tokenizer from ``REPO_ID``.

        Args:
            data: Value passed in by whoever constructs the handler. It is
                not used here because the repository id is hard-coded.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            self.REPO_ID,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.REPO_ID)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for one request.

        Args:
            data: Request payload. Recognized keys:
                ``"inputs"`` - the user text (falls back to
                ``DEFAULT_INPUTS``);
                ``"max_new_tokens"`` - anything ``int()`` accepts (falls back
                to ``DEFAULT_MAX_NEW_TOKENS``).

        Returns:
            ``[{"generated_text": <str>}]`` on success, or
            ``[{"status": "error", "reason": <str>}]`` when
            ``max_new_tokens`` cannot be converted to an int.
        """
        inputs = data.get("inputs", self.DEFAULT_INPUTS)

        # Validate before touching the model so that a bad request is cheap
        # and is reported to the caller instead of raising.
        try:
            max_new_tokens = int(
                data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
            )
        except (TypeError, ValueError):
            return [{
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int",
            }]

        prompt = PROMPT_FORMAT.format(inputs=inputs)

        # The model consumes token ids, not text: tokenize first and place
        # the tensors on whatever device the model lives on.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_length = encoded["input_ids"].shape[-1]
        generated_text = self.tokenizer.decode(
            output_ids[0][prompt_length:],
            skip_special_tokens=True,
        )
        return [{"generated_text": generated_text}]