#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Custom handler for Llama 2 text-generation model.

Author: Henry
Created on: Mon Nov 20, 2023

This module defines a custom handler for the Llama 2 text-generation model,
utilizing Hugging Face's transformers pipeline. It's designed to process
requests for text generation, leveraging the capabilities of the Llama 2 model.
"""

import logging
import sys
from typing import Any, Dict, List

import torch
from transformers import (
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(asctime)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)


class EndpointHandler:
    """
    Handler class for Llama 2 text-generation model inference.

    This class initializes the model pipeline and processes incoming requests
    for text generation using the Llama 2 model.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the pipeline for the Llama 2 text-generation model.

        Args:
            path (str): Path to the model, defaults to an empty string.
        """
        # Quantization config: load the weights in 4-bit NF4 with double
        # quantization, computing in bfloat16
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        tokenizer = LlamaTokenizer.from_pretrained(path)
        # Place the quantized model on GPU 0
        model = LlamaForCausalLM.from_pretrained(
            path,
            device_map=0,
            quantization_config=self.bnb_config,
        )
        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process a request for text generation.

        Args:
            data (Dict[str, Any]): A dictionary containing inputs for text generation.

        Returns:
            List[Dict[str, Any]]: The generated text as a list of dictionaries.
        """
        # Log the received data
        logging.info(f"Received data: {data}")

        # Extract the prompt and optional generation parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Validate the input data
        if not inputs:
            raise ValueError(f"'inputs' is required; got {inputs!r}.")

        # Log the extracted inputs and parameters for debugging
        logging.info(f"Extracted inputs: {inputs}")
        logging.info(f"Extracted parameters: {parameters}")

        # Run text generation, forwarding any generation parameters as kwargs
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)
        return prediction
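

if __name__ == "__main__":
    # Minimal local smoke test -- a sketch, not part of the deployed handler.
    # Assumptions: the model id below is hypothetical (replace it with your
    # actual Llama 2 checkpoint directory or Hugging Face repo id), and a
    # CUDA GPU is available, since the 4-bit bitsandbytes load requires one.
    handler = EndpointHandler(path="meta-llama/Llama-2-7b-chat-hf")
    payload = {
        "inputs": "Explain 4-bit quantization in one sentence.",
        # Any keys under "parameters" are forwarded to the pipeline as
        # generation kwargs.
        "parameters": {"max_new_tokens": 64, "do_sample": True},
    }
    print(handler(payload))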