import os
from typing import Any, Dict, List, Optional

from llama_cpp import Llama

import gemma_tools

MAX_TOKENS = 1000


class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        """
        Initialize the EndpointHandler with the path to the model directory.

        :param model_dir: Path to the directory containing the model file.
        """
        if not model_dir:
            # Without a model directory the handler would be unusable; fail early
            # instead of leaving self.model unset.
            raise ValueError("model_dir is required to locate the model file")

        # Update the model filename to match the one in your repository.
        model_path = os.path.join(model_dir, "comic_mistral-v5.2.q5_0.gguf")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"The model file was not found at {model_path}")

        try:
            self.model = Llama(
                model_path=model_path,
                n_ctx=MAX_TOKENS,  # n_ctx sets the context window size in llama_cpp
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load the model: {e}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle incoming requests for model inference.

        :param data: Dictionary containing input data and parameters for the model.
        :return: A list with a dictionary containing the status and response
                 or error details.
        """
        # Extract and validate arguments from the request. get_args_or_none is
        # assumed to return an (ok, payload) pair: on failure the payload is a
        # dict of error details, on success it is the parsed arguments.
        args_check = gemma_tools.get_args_or_none(data)
        if not args_check[0]:
            error = args_check[1]
            return [{
                "status": error.get("status", "error"),
                "reason": error.get("reason", "unknown"),
                "description": error.get("description", "Validation error in arguments"),
            }]

        args = args_check[1]

        # Prompt template using Gemma-style turn markers. The literal
        # <start_of_turn>/<end_of_turn> tags are assumed here; adjust them
        # to the chat format your model was trained on.
        prompt_format = (
            "<start_of_turn>system\n{system_prompt}<end_of_turn>\n"
            "<start_of_turn>user\n{inputs}<end_of_turn>\n"
            "<start_of_turn>model\n"
        )
        try:
            formatted_prompt = prompt_format.format(**args)
        except KeyError as e:
            return [{
                "status": "error",
                "reason": "Invalid format",
                "detail": f"Missing prompt field: {e}",
            }]

        # Parse max_length, defaulting to 212 if it is not provided.
        try:
            max_length = int(data.get("max_length", 212))
        except (TypeError, ValueError):
            return [{
                "status": "error",
                "reason": "max_length must be an integer",
                "detail": "max_length was not a valid integer",
            }]

        # Perform inference.
        try:
            res = self.model(
                formatted_prompt,
                temperature=args["temperature"],
                top_p=args["top_p"],
                top_k=args["top_k"],
                max_tokens=max_length,
            )
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Inference failed",
                "detail": str(e),
            }]

        return [{
            "status": "success",
            # Extract the generated text from the first completion choice.
            "response": res["choices"][0]["text"].strip(),
        }]


# Usage in your script, or wherever the handler is instantiated:
try:
    handler = EndpointHandler("/repository")
except (FileNotFoundError, RuntimeError, ValueError) as e:
    print(f"Initialization error: {e}")
    exit(1)  # Exit with an error code if the handler cannot be initialized.
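

# A minimal smoke-test sketch for the handler above. The request keys
# (system_prompt, inputs, temperature, top_p, top_k, max_length) mirror what
# __call__ and the prompt template consume, but whether
# gemma_tools.get_args_or_none accepts them in exactly this shape is an
# assumption; adapt the payload to your validator.
if __name__ == "__main__":
    sample_request = {
        "system_prompt": "You are a comic-book narrator.",
        "inputs": "Describe the hero's entrance in one sentence.",
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "max_length": 128,
    }
    result = handler(sample_request)
    # Prints either a success dict with a "response" key or an error dict.
    print(result[0])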