import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Custom text-generation handler around a causal language model.

    The handler is called with a request dict whose ``"inputs"`` is either a
    raw prompt string or message data, plus optional ``"parameters"`` that
    override the default generation settings.
    """

    def __init__(self, path: str):
        """Load the tokenizer and model stored at *path*.

        Args:
            path: Directory or model id handed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Many causal-LM tokenizers define no pad token, and tokenizing with
        # ``padding=True`` then raises. Reuse EOS as the pad token in that case.
        # This must happen before default_params reads pad_token_id below.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float32,  # float32 so the model also runs on CPU
            device_map="auto",
        )

        # Defaults for ``model.generate``; per-request "parameters" override them.
        # Note: ``max_length`` counts prompt tokens plus generated tokens.
        self.default_params = {
            "max_length": 1000,
            "temperature": 0.7,
            "top_p": 0.7,
            "top_k": 50,
            "repetition_penalty": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

    @staticmethod
    def _format_prompt(inputs) -> Optional[str]:
        """Turn the request's ``"inputs"`` value into a prompt string.

        A string is returned unchanged. Otherwise a message list is taken
        from ``inputs`` itself when it is a list, or from ``inputs["messages"]``
        when it is a dict. Each message is rendered as one ``"role: content"``
        line.

        Returns:
            The prompt, or ``None`` when no messages were supplied.
        """
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, list):
            messages = inputs
        else:
            messages = (inputs or {}).get("messages", [])
        if not messages:
            return None
        return "".join(
            f"{msg.get('role', '')}: {msg.get('content', '')}\n" for msg in messages
        )

    def __call__(self, data: Dict):
        """Generate text for one request.

        Args:
            data: Request dict. ``data["inputs"]`` is a prompt string, a dict
                with a ``"messages"`` list, or a list of messages; each message
                is a dict with optional ``"role"`` and ``"content"`` keys.
                ``data["parameters"]`` optionally overrides generation settings.

        Returns:
            ``[{"generated_text": text}]`` on success, where ``text`` is the
            decoded first output sequence (prompt included, special tokens
            skipped). Returns ``{"error": message}`` when no messages are given
            or any step fails.
        """
        try:
            input_text = self._format_prompt(data.get("inputs"))
            if input_text is None:
                return {"error": "No messages provided"}

            # Request parameters override the defaults. pad_token_id is supplied
            # only through ``params``; also passing it as an explicit keyword to
            # generate() would raise TypeError (duplicate keyword argument).
            params = dict(self.default_params)
            params.update(data.get("parameters") or {})

            tokenizer_output = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,  # the prompt is truncated to 512 tokens
                return_attention_mask=True,
            )

            # With device_map="auto" the model may live on an accelerator, so
            # place the inputs on the device the model reports.
            device = self.model.device
            input_ids = tokenizer_output["input_ids"].to(device)
            attention_mask = tokenizer_output["attention_mask"].to(device)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    **params,
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [{"generated_text": generated_text}]
        except Exception as e:
            # Request boundary: log with traceback, report the error to the caller.
            logger.exception("Error in generation: %s", e)
            return {"error": str(e)}

    def preprocess(self, request):
        """Extract the JSON body from *request*.

        Raises:
            ValueError: If the request's content type is not application/json.
        """
        if request.content_type != "application/json":
            raise ValueError("Content type must be application/json")
        return request.json

    def postprocess(self, data):
        """Return the model output unchanged (hook for post-processing)."""
        return data