from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Phi-3 instruct chat template: one user turn, closed with <|end|>, then an open
# assistant turn for the model to complete.
PROMPT_FORMAT = """
<|user|>
{inputs} <|end|>
<|assistant|>
"""


class EndpointHandler():
    def __init__(self, path: str = ""):
        # `path` matches the custom-handler interface but is unused here; the
        # checkpoint is pinned to a fixed repo instead.
        cfg = {
            "repo": "MrOvkill/Phi-3-Instruct-Bloated",
        }
        self.model = AutoModelForCausalLM.from_pretrained(
            cfg["repo"], trust_remote_code=True, torch_dtype=torch.float16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(cfg["repo"])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", "Q: What is the chemical composition of common concrete in 2024?\nA: ")
        max_new_tokens = 1024
        if "max_new_tokens" in data:
            # Coerce inside the try block so a bad value is reported, not raised.
            try:
                max_new_tokens = int(data.pop("max_new_tokens"))
            except (TypeError, ValueError):
                return [{
                    "status": "error",
                    "reason": "max_new_tokens was passed as something that was absolutely not a plain old int",
                }]

        # The model cannot consume a raw string: tokenize the formatted prompt,
        # generate greedily, and decode only the newly produced tokens.
        prompt = PROMPT_FORMAT.format(inputs=inputs)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        res = self.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
        return [{"generated_text": res}]
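

# A minimal local smoke test, assuming the checkpoint can be downloaded and fits
# in memory. The payload keys ("inputs", "max_new_tokens") mirror the handler
# above; the question itself is illustrative, not part of the handler contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        "inputs": "Q: What is the boiling point of water at sea level?\nA: ",
        "max_new_tokens": 64,
    })
    print(result)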