# InstructiPhi / handler.py
# Author: acecalisto3
# Last commit: f363f0a "Update handler.py" (verified)
import json
import logging
import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set up logging with a start-up default level of INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Settings this handler reads, with their built-in defaults. Keys found in
# config.json override these one by one, so a config file that omits a key
# still yields a complete configuration instead of a KeyError later on.
_DEFAULT_CONFIG = {
    "model_name": "acecalisto3/InstructiPhi",  # model id passed to from_pretrained()
    # NOTE(review): 16788 looks like it may be a typo (16384?) and may exceed
    # the model's context window -- confirm the intended value.
    "max_length": 16788,  # passed to model.generate(max_length=...)
    "logging_level": "INFO",  # level name for the root logger
}

# Load configuration overrides from config.json in the current working
# directory. Example file:
# {
#     "model_name": "acecalisto3/InstructiPhi",
#     "max_length": 2048,
#     "logging_level": "DEBUG"
# }
# NOTE(review): model repositories commonly ship their own config.json with
# architecture keys (e.g. "hidden_size"). If that is the file picked up here,
# all of its keys are merged in, and any "max_length" key in it overrides the
# default above. Confirm which file is intended.
try:
    with open("config.json", encoding="utf-8") as f:
        _file_config = json.load(f)
except FileNotFoundError:
    logger.warning("Configuration file 'config.json' not found. Using default settings.")
    _file_config = {}
except (OSError, ValueError) as exc:
    # Unreadable file or invalid JSON (json.JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses).
    logger.error("Could not load 'config.json' (%s). Using default settings.", exc)
    _file_config = {}

if not isinstance(_file_config, dict):
    # e.g. the file holds a JSON list or number; it cannot be merged.
    logger.error(
        "Configuration file 'config.json' must contain a JSON object. "
        "Using default settings."
    )
    _file_config = {}

config = {**_DEFAULT_CONFIG, **_file_config}
# Resolve which checkpoint to serve from the configuration, then load its
# tokenizer and causal-LM weights once at import time so that
# handle_request() reuses the same objects for every request.
model_name = config["model_name"]
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
# Set logging level from configuration.
# logging.basicConfig() does nothing once the root logger already has
# handlers, and basicConfig() was already called earlier in this module, so
# a second basicConfig(level=...) call would never take effect. Set the level
# on the root logger directly instead.
_log_level = config.get("logging_level", "INFO")
try:
    # Level names are case-sensitive in `logging`; accept "debug" as "DEBUG".
    logging.getLogger().setLevel(
        _log_level.upper() if isinstance(_log_level, str) else _log_level
    )
except (TypeError, ValueError):
    # Unknown level name or wrong type: keep the current level rather than
    # failing at import time.
    logging.getLogger(__name__).warning(
        "Invalid logging_level %r in configuration; keeping the current level.",
        _log_level,
    )
def _json_response(status_code, payload):
    """Return a platform response dict whose body is *payload* encoded as JSON."""
    return {
        'statusCode': status_code,
        'body': json.dumps(payload),
    }


def handle_request(event, context):
    """Handle an incoming request to the deployed model.

    Args:
        event: The event data from the deployment platform. The prompt is
            read from ``event.get('body')`` and must be a non-empty string
            of at most 1000 characters.
        context: The context data from the deployment platform (unused).

    Returns:
        A dictionary with ``'statusCode'`` and a JSON-encoded ``'body'``:

        * 200 -- ``{"response": ..., "model": ..., "timestamp": ...}``, where
          ``timestamp`` is an ISO-8601 string in UTC.
        * 400 -- ``{"error": ...}`` if the body is missing or empty, is not a
          string, or is longer than 1000 characters.
        * 500 -- ``{"error": "Internal server error"}`` for any other failure;
          the exception is logged with its traceback.
    """
    # Bound before the ``try`` so the error handler can always log it, even
    # when reading the event itself fails (e.g. ``event`` is None).
    input_text = None
    try:
        # Extract input text from the event
        input_text = event.get('body')
        if not input_text:
            return _json_response(400, {'error': 'Missing input text'})
        if not isinstance(input_text, str):
            # e.g. a pre-parsed JSON object or raw bytes: report a client
            # error instead of letting the tokenizer fail with a 500.
            return _json_response(400, {'error': 'Input text must be a string'})
        # Input validation: cap the prompt length (in characters).
        if len(input_text) > 1000:
            return _json_response(400, {'error': 'Input text is too long'})

        # Tokenize and keep the attention mask so generate() does not have to
        # infer it from the input ids.
        inputs = tokenizer(input_text, return_tensors="pt")
        output = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # max_length bounds the total sequence length, prompt included.
            max_length=config["max_length"],
        )
        # NOTE(review): for decoder-only models output[0] normally begins with
        # the prompt tokens, so 'response' echoes the input. Slice off
        # inputs["input_ids"].shape[-1] tokens if only the completion is wanted.
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Return a successful response with structured output
        return _json_response(200, {
            'response': generated_text,
            'model': model_name,  # Include model name in the output
            'timestamp': datetime.datetime.now(datetime.timezone.utc).isoformat(),
        })
    except Exception:
        # Top-level boundary: log the full traceback. %r escapes newlines in
        # the user-supplied input so it cannot forge extra log lines.
        logger.exception("Error processing request, input: %r", input_text)
        return _json_response(500, {'error': 'Internal server error'})