|
|
|
|
|
|
|
""" |
|
Custom handler for Llama 2 text-generation model. |
|
|
|
Author: Henry |
|
Created on: Mon Nov 20, 2023 |
|
|
|
This module defines a custom handler for the Llama 2 text-generation model, |
|
utilizing Hugging Face's transformers pipeline. It's designed to process requests |
|
for text generation, leveraging the capabilities of the Llama 2 model. |
|
""" |
|
|
|
import torch |
|
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline, BitsAndBytesConfig |
|
from typing import Dict, List, Any |
|
import logging |
|
import sys |
|
|
|
# Send all INFO-and-above log records to stdout with a uniform prefix.
# (basicConfig builds the StreamHandler(sys.stdout) for us.)
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(asctime)s - %(message)s',
    stream=sys.stdout,
)
|
|
|
|
|
class EndpointHandler:
    """
    Handler class for Llama 2 text-generation model inference.

    This class initializes the model pipeline and processes incoming requests
    for text generation using the Llama 2 model.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the pipeline for the Llama 2 text-generation model.

        The model weights are loaded in 4-bit NF4 quantization (bitsandbytes)
        with bfloat16 compute, and the whole model is placed on device 0.

        Args:
            path (str): Path to the model, defaults to an empty string.
        """
        # 4-bit NF4 quantization with double quantization; compute in bfloat16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        tokenizer = LlamaTokenizer.from_pretrained(path)
        # device_map=0 pins the entire model to the first accelerator device.
        model = LlamaForCausalLM.from_pretrained(
            path, device_map=0, quantization_config=self.bnb_config
        )

        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process a request for text generation.

        Args:
            data (Dict[str, Any]): A dictionary with an "inputs" entry (the
                prompt) and an optional "parameters" entry whose contents are
                forwarded to the pipeline as keyword arguments. If "inputs"
                is absent, the remaining dictionary itself is used as the
                pipeline input.

        Returns:
            List[Dict[str, Any]]: The generated text as a list of dictionaries.

        Raises:
            ValueError: If no (truthy) inputs can be extracted from ``data``.
        """
        # Lazy %-style args avoid formatting the message unless it is emitted.
        logging.info("Received data: %s", data)

        # Pop "parameters" BEFORE resolving the "inputs" fallback: when
        # "inputs" is missing, the fallback aliases `data` itself, and
        # popping afterwards would mutate the already-extracted inputs.
        parameters = data.pop("parameters", None)
        inputs = data.pop("inputs", data)

        if not inputs:
            raise ValueError(f"'inputs' is required; got {inputs!r}.")

        logging.info("Extracted inputs: %s", inputs)
        logging.info("Extracted parameters: %s", parameters)

        # Only forward generation kwargs when the caller supplied them.
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)

        return prediction
|
|