from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain.llms import HuggingFacePipeline
def initialize_llmchain(
    llm_model: str,
    temperature: float,
    max_tokens: int,
    top_k: int,
    access_token: Optional[str] = None,
    torch_dtype: str = "auto",
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> HuggingFacePipeline:
""" | |
Initializes a language model chain based on the provided parameters. | |
Args: | |
- llm_model (str): The name of the language model to initialize. | |
- temperature (float): The temperature parameter for text generation. | |
- max_tokens (int): The maximum number of tokens to generate. | |
- top_k (int): The top-k parameter for token selection during generation. | |
- torch_dtype (str): The torch dtype to be used for model inference (default is "auto"). | |
- load_in_8bit (bool): Whether to load the model in 8-bit format (default is False). | |
- load_in_4bit (bool): Whether to load the model in 4-bit format (default is False). | |
Returns: | |
- HuggingFacePipeline: Initialized language model pipeline. | |
""" | |
    # Optional bitsandbytes quantization; 8-bit takes precedence if both flags are set
    if load_in_8bit:
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",  # normalized-float 4-bit quantization
            bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        bnb_config = None
    # Load the model and tokenizer; the access token is required for gated models
    model = AutoModelForCausalLM.from_pretrained(
        llm_model,
        low_cpu_mem_usage=True,
        quantization_config=bnb_config,
        torch_dtype=torch_dtype,
        token=access_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model, token=access_token)

    # Build the text-generation pipeline. Generation parameters are passed
    # directly (not via model_kwargs, which only affects model loading) so
    # they actually apply at inference time; do_sample=True is required for
    # temperature and top_k to take effect.
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
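
# Minimal usage sketch. The model name below is a placeholder assumption:
# substitute any causal LM you have access to, and pass your own Hugging Face
# token for gated checkpoints.
if __name__ == "__main__":
    llm = initialize_llmchain(
        llm_model="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical example model
        temperature=0.7,
        max_tokens=256,
        top_k=3,
        load_in_4bit=True,  # quantize so the model fits on a single consumer GPU
    )
    # LangChain >= 0.1 exposes invoke(); on older versions call llm("...") instead
    print(llm.invoke("Summarize what this pipeline does in one sentence."))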