from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path: str = ""):
        logger.info(f"Initializing EndpointHandler with model path: {path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")
            
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
            logger.info(f"Model loaded successfully. Device map: {self.model.device}")
            
            self.model.eval()
            logger.info("Model set to evaluation mode")
            
            # Default generation parameters (near-greedy sampling: do_sample with a
            # very low temperature; override per request via "generation_params")
            self.default_params = {
                "max_new_tokens": 1000,
                "temperature": 0.01,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
            logger.info(f"Default generation parameters: {self.default_params}")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle chat completion requests.
        
        Args:
            data: Dictionary containing:
                - inputs: The user prompt to wrap in a single chat message
                - generation_params: Optional dictionary overriding the default generation parameters

        Returns:
            List containing one dictionary with the generated text under
            'generated_text' and generation metadata under 'details'
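
        Example (illustrative request; the prompt string and parameter value are hypothetical):
            data = {
                "inputs": "What is the capital of France?",
                "generation_params": {"max_new_tokens": 200},
            }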
        """
        try:
            logger.info("Processing new request")
            logger.info(f"Input data: {data}")
            
            input_messages = data.get("inputs", [])
            if not input_messages:
                logger.warning("No input messages provided")
                return [{"generated_text": "No input messages provided"}]
            
            # Get generation parameters, use defaults for missing values
            gen_params = {**self.default_params, **data.get("generation_params", {})}
            logger.info(f"Generation parameters: {gen_params}")
            
            # Wrap the raw input as a single user turn and apply the chat template
            messages = [{"role": "user", "content": input_messages}]
            logger.info("Applying chat template")
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            logger.info(f"Generated chat prompt: {json.dumps(prompt)}")
            
            # Tokenize the prompt
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            # Generate response
            logger.info("Generating response")
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **inputs,
                    **gen_params
                )
            
            # Decode the response, keeping special tokens so the role markers
            # below can be located in the output text
            logger.info("Decoding response")
            output_text = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=False)[0]
            
            # Extract only the assistant's reply by locating the last assistant role block.
            # These role markers are emitted by the model's chat template (as in IBM Granite models).
            assistant_marker = "<|start_of_role|>assistant<|end_of_role|>"
            assistant_start = output_text.rfind(assistant_marker)
            if assistant_start != -1:
                response = output_text[assistant_start + len(assistant_marker):].strip()
                # Remove any trailing end_of_text marker
                if "<|end_of_text|>" in response:
                    response = response.split("<|end_of_text|>")[0].strip()
                
                # Check for function calling
                if "Calling function:" in response:
                    # Split response into text and function call
                    parts = response.split("Calling function:", 1)
                    text_response = parts[0].strip()
                    function_call = "Calling function:" + parts[1].strip()
                    
                    logger.info(f"Function call: {function_call}")
                    logger.info(f"Text response: {text_response}")
                    # Return the text together with the extracted function call
                    return [
                        {
                            "generated_text": text_response,
                            "function_call": function_call,
                            "details": {
                                "finish_reason": "stop",
                                "generated_tokens": len(output_tokens[0])
                            }
                        }
                    ]
            else:
                # No assistant marker was found; fall back to the full decoded output
                response = output_text

            logger.info(f"Generated response: {json.dumps(response)}")
            return [{"generated_text": response, "details": {"finish_reason": "stop", "generated_tokens": len(output_tokens[0])}}]
            
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            raise
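

if __name__ == "__main__":
    # Minimal local smoke test (illustrative sketch). The "./model" path is a
    # hypothetical placeholder: point it at a checkpoint whose chat template
    # emits the <|start_of_role|>/<|end_of_role|> markers parsed above (for
    # example an IBM Granite model), or adapt the marker handling otherwise.
    handler = EndpointHandler(path="./model")
    result = handler({
        "inputs": "What is the capital of France?",
        "generation_params": {"max_new_tokens": 64},
    })
    print(json.dumps(result, indent=2))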