from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
class EndpointHandler:
    def __init__(self, path=""):
        # Load the model and tokenizer from the repository path the endpoint passes in
        logger.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(path or ".")
        logger.info("tokenizer loaded...")
        self.model = AutoModelForCausalLM.from_pretrained(path or ".")
        logger.info("model loaded...")
        # Some causal LM tokenizers (e.g. GPT-2) ship without a pad token;
        # fall back to EOS so padding in __call__ works
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data: JSON input with structure:
                {
                    "inputs": "your text prompt here",
                    "parameters": {
                        "max_new_tokens": 50,
                        "temperature": 0.7,
                        "top_p": 0.9,
                        "do_sample": true
                    }
                }

        Returns:
            A dict with the first generated text and the full list of generations.
        """
        # Get input text and parameters
        inputs = data.pop("inputs", data)
        logger.info("inputs loaded: %s", inputs)
        parameters = data.pop("parameters", {})
        # Default generation parameters, overridable via the "parameters" payload
        generation_config = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.eos_token_id,
            "num_return_sequences": parameters.get("num_return_sequences", 1),
        }
        # Tokenize the prompt and move tensors to the model's device
        # (named "encoded" so it doesn't shadow the raw text prompt)
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        # Generate text without tracking gradients
        with torch.no_grad():
            generated_ids = self.model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                **generation_config,
            )
        # Decode all generated sequences, dropping special tokens
        generated_texts = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )
        return {
            "generated_text": generated_texts[0],  # First generation
            "all_generations": generated_texts,  # All generations when num_return_sequences > 1
        }
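

# --- Local smoke test (illustrative sketch, not part of the endpoint contract) ---
# Inference Endpoints instantiate EndpointHandler with the model directory and
# call it with the parsed JSON body; this block mimics that flow for a quick
# check before deploying. The prompt and parameter values below are arbitrary
# examples; run this from the model repository directory.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": "Once upon a time",
        "parameters": {"max_new_tokens": 30, "temperature": 0.7, "do_sample": True},
    }
    result = handler(payload)
    print(result["generated_text"])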