from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import logging
import time
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, path: str = ""):
        logger.info(f"Initializing EndpointHandler with model path: {path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
            logger.info(f"Model loaded successfully on device: {self.model.device}")
            self.model.eval()
            logger.info("Model set to evaluation mode")
            # Default generation parameters
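            # Note: temperature is near zero while do_sample=True, so decoding
            # is effectively greedy; raise temperature for more varied output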
            self.default_params = {
                "max_new_tokens": 1000,
                "temperature": 0.01,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
            logger.info(f"Default generation parameters: {self.default_params}")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle chat completion requests.

        Args:
            data: Dictionary containing:
                - inputs: The user's message content, wrapped in a single
                  user-role message before templating
                - generation_params: Optional dictionary of generation
                  parameters that override the defaults

        Returns:
            List containing a single dictionary with the generated text and
            generation details
        """
        try:
            logger.info("Processing new request")
            logger.info(f"Input data: {data}")
            input_messages = data.get("inputs", [])
            if not input_messages:
                logger.warning("No input messages provided")
                return [{"role": "assistant", "content": "No input messages provided"}]
            # Merge caller-supplied generation parameters over the defaults
            gen_params = {**self.default_params, **data.get("generation_params", {})}
            logger.info(f"Generation parameters: {gen_params}")
            # Wrap the raw input as a single user message and apply the chat template
            messages = [{"role": "user", "content": input_messages}]
            logger.info("Applying chat template")
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
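            # With a Granite-style chat template (an assumption based on the
            # role markers parsed below), the rendered prompt looks roughly like:
            #   <|start_of_role|>user<|end_of_role|>...<|end_of_text|>
            #   <|start_of_role|>assistant<|end_of_role|>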
logger.info(f"Generated chat prompt: {json.dumps(prompt)}")
# Tokenize the prompt
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
# Generate response
logger.info("Generating response")
with torch.no_grad():
output_tokens = self.model.generate(
**inputs,
**gen_params
)
# Decode the response
logger.info("Decoding response")
output_text = self.tokenizer.batch_decode(output_tokens)[0]
            # Count only the newly generated tokens, excluding the prompt
            generated_tokens = output_tokens.shape[-1] - inputs["input_ids"].shape[-1]
            # Extract only the assistant's response by finding the last assistant role block
            assistant_start = output_text.rfind("<|start_of_role|>assistant<|end_of_role|>")
            if assistant_start != -1:
                response = output_text[assistant_start + len("<|start_of_role|>assistant<|end_of_role|>"):].strip()
                # Remove any trailing end_of_text marker
                if "<|end_of_text|>" in response:
                    response = response.split("<|end_of_text|>")[0].strip()
                # Check for function calling
                if "Calling function:" in response:
                    # Split response into text and function call
                    parts = response.split("Calling function:", 1)
                    text_response = parts[0].strip()
                    function_call = "Calling function:" + parts[1].strip()
                    logger.info(f"Function call: {function_call}")
                    logger.info(f"Text response: {text_response}")
                    # Return both the text and the extracted function call
                    return [
                        {
                            "generated_text": text_response,
                            "function_call": function_call,
                            "details": {
                                "finish_reason": "stop",
                                "generated_tokens": generated_tokens
                            }
                        }
                    ]
            else:
                response = output_text
            logger.info(f"Generated response: {json.dumps(response)}")
            return [{"generated_text": response, "details": {"finish_reason": "stop", "generated_tokens": generated_tokens}}]
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            raise
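

# Minimal local smoke test (a sketch only: Inference Endpoints construct
# EndpointHandler automatically in production, and the model path below is a
# hypothetical example, not taken from this repository).
if __name__ == "__main__":
    handler = EndpointHandler(path="ibm-granite/granite-3.0-8b-instruct")
    result = handler({
        "inputs": "What is the capital of France?",
        "generation_params": {"max_new_tokens": 64},
    })
    print(json.dumps(result, indent=2))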