Llama-2-7b-chat-hf / handler.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Custom handler for Llama 2 text-generation model.
Author: Henry
Created on: Mon Nov 20, 2023
This module defines a custom inference handler for the Llama 2 chat model.
It loads the model with 4-bit (NF4) quantization via bitsandbytes and serves
text-generation requests through Hugging Face's transformers pipeline.
"""
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline, BitsAndBytesConfig
from typing import Dict, List, Any
import logging
import sys
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(asctime)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)


class EndpointHandler:
"""
Handler class for Llama 2 text-generation model inference.
This class initializes the model pipeline and processes incoming requests
for text generation using the Llama 2 model.
"""
    def __init__(self, path: str = ""):
        """
        Initialize the pipeline for the Llama 2 text-generation model.

        Args:
            path (str): Path to the model, defaults to an empty string.
        """
        # Configure 4-bit NF4 quantization with double quantization and
        # bfloat16 compute via bitsandbytes
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        # Load the tokenizer and the quantized model, then build the
        # text-generation pipeline used for inference
        tokenizer = LlamaTokenizer.from_pretrained(path)
        model = LlamaForCausalLM.from_pretrained(path, device_map=0, quantization_config=self.bnb_config)
        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process a request for text generation.

        Args:
            data (Dict[str, Any]): A dictionary containing inputs for text generation.

        Returns:
            List[Dict[str, Any]]: The generated text as a list of dictionaries.
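
        Example payload (illustrative values; the keys mirror the code below):
            {"inputs": "Hello, who are you?", "parameters": {"max_new_tokens": 64}}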
"""
# Log the received data
logging.info(f"Received data: {data}")
        # Extract the prompt and optional generation parameters from the payload
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Validate the input data
        if not inputs:
            raise ValueError(f"'inputs' is required; received '{inputs}'.")
        # Log the extracted inputs and parameters for debugging
        logging.info(f"Extracted inputs: {inputs}")
        logging.info(f"Extracted parameters: {parameters}")
        # Run text generation, forwarding any generation parameters from the request
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)

        return prediction
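

# ---------------------------------------------------------------------------
# Minimal local smoke test. This is an illustrative sketch, not part of the
# endpoint contract: the model path and generation parameters below are
# placeholder assumptions, and a CUDA-capable GPU is required for the 4-bit
# quantized load configured above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="meta-llama/Llama-2-7b-chat-hf")  # placeholder path
    sample_request = {
        "inputs": "What is the capital of France?",
        "parameters": {"max_new_tokens": 64, "do_sample": True, "temperature": 0.7},
    }
    print(handler(sample_request))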