# Source: Sacha-Mistral-0 / handler.py
# Nac31, commit ad5bd44 ("Handler")
import logging
import os
from typing import Any, Dict, List

import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate with the Hugging Face Hub only when a token is provided.
# ``login()`` without a token falls back to an interactive prompt, which
# cannot be answered inside a non-interactive server process.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; skipping Hugging Face Hub login")


def _select_dtype() -> torch.dtype:
    """Pick the torch dtype used to load the model weights.

    Returns:
        ``torch.bfloat16`` on CUDA GPUs with compute capability >= 8 (Ampere
        and newer, which support bf16 natively), ``torch.float16`` on older
        CUDA GPUs, and ``torch.float32`` when CUDA is unavailable (half
        precision is poorly supported on CPU, and querying the device
        capability without CUDA raises).
    """
    if not torch.cuda.is_available():
        return torch.float32
    major, _minor = torch.cuda.get_device_capability()
    return torch.bfloat16 if major >= 8 else torch.float16


# Module-level dtype read by EndpointHandler when loading the model.
dtype = _select_dtype()
logger.info("Using dtype: %s", dtype)
class EndpointHandler:
    """Serve a causal language model through a ``transformers`` text-generation pipeline.

    The tokenizer and model are loaded from ``path`` at construction time;
    each call runs generation on one request payload.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model from ``path`` and build the pipeline.

        Args:
            path: Model location passed to ``from_pretrained`` (local directory
                or Hub repo id).
        """
        logger.info("Initializing EndpointHandler")
        logger.info(f"Loading tokenizer from {path}")
        tokenizer = AutoTokenizer.from_pretrained(path)
        logger.info("Tokenizer loaded successfully")

        # ``dtype`` is chosen at module import time based on the available GPU.
        logger.info(f"Loading model from {path} with dtype {dtype}")
        model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype
        )
        logger.info("Model loaded successfully")

        logger.info("Creating inference pipeline")
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        logger.info("Inference pipeline created successfully")

    def __call__(self, data: Any) -> List[Any]:
        """Run text generation on one request payload.

        Args:
            data: Either a dict ``{"inputs": ..., "parameters": {...}}`` or a
                bare input (e.g. a prompt string or list of prompts). When a
                dict has no ``"inputs"`` key, the dict itself minus
                ``"parameters"`` is used as the input. The caller's dict is
                not modified.

        Returns:
            The pipeline output, typically a list of ``{"generated_text": ...}``
            dicts.
        """
        logger.info("Received data for inference")
        if isinstance(data, dict):
            parameters = data.get("parameters")
            if "inputs" in data:
                inputs = data["inputs"]
            else:
                # Fallback kept from the original handler: the payload itself is
                # the input, with the generation parameters stripped out.
                inputs = {k: v for k, v in data.items() if k != "parameters"}
        else:
            # Bare payloads (e.g. a raw prompt string) carry no parameters.
            inputs, parameters = data, None

        # Prompts and generations may contain user data and can be large, so
        # they are only emitted at DEBUG level.
        logger.debug("Inputs: %s", inputs)
        logger.debug("Parameters: %s", parameters)

        if parameters is None:
            logger.info("Generating prediction without parameters")
            parameters = {}
        else:
            logger.info("Generating prediction with parameters")
        prediction = self.pipeline(inputs, **parameters)

        logger.debug("Prediction: %s", prediction)
        return prediction