# chatwithpdfs/model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.llms import HuggingFacePipeline


def initialize_llmchain(
    llm_model: str,
    temperature: float,
    max_tokens: int,
    top_k: int,
    access_token: str = None,
    torch_dtype: str = "auto",
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> HuggingFacePipeline:
"""
Initializes a language model chain based on the provided parameters.
Args:
- llm_model (str): The name of the language model to initialize.
- temperature (float): The temperature parameter for text generation.
- max_tokens (int): The maximum number of tokens to generate.
- top_k (int): The top-k parameter for token selection during generation.
- torch_dtype (str): The torch dtype to be used for model inference (default is "auto").
- load_in_8bit (bool): Whether to load the model in 8-bit format (default is False).
- load_in_4bit (bool): Whether to load the model in 4-bit format (default is False).
Returns:
- HuggingFacePipeline: Initialized language model pipeline.
"""
    # Select a bitsandbytes quantization config: 8-bit takes precedence over 4-bit,
    # and the 4-bit path uses NF4 with double quantization.
    if load_in_8bit:
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True
        )
    elif load_in_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        bnb_config = None

    # Initialize model and tokenizer; the access token is only needed for gated or
    # private checkpoints.
    model = AutoModelForCausalLM.from_pretrained(
        llm_model,
        low_cpu_mem_usage=True,
        quantization_config=bnb_config,
        torch_dtype=torch_dtype,
        token=access_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model, token=access_token)

    # Initialize the text-generation pipeline. Generation settings are passed directly
    # to the pipeline (not via `model_kwargs`, which is only forwarded to
    # `from_pretrained` and therefore never affects generation).
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_tokens,
        do_sample=True,  # sampling must be enabled for temperature/top_k to take effect
        temperature=temperature,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Wrap the transformers pipeline so it can be used as a LangChain LLM.
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
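

# Minimal usage sketch (not part of the original app): the model name and generation
# settings below are placeholder assumptions; substitute any causal LM available to you
# on the Hugging Face Hub. Running with load_in_4bit=True assumes bitsandbytes and a
# CUDA GPU are available. On older LangChain versions, call `llm(prompt)` instead of
# `llm.invoke(prompt)`.
if __name__ == "__main__":
    llm = initialize_llmchain(
        llm_model="tiiuae/falcon-7b-instruct",  # hypothetical example checkpoint
        temperature=0.7,
        max_tokens=256,
        top_k=50,
        load_in_4bit=True,  # quantize so the model fits on a single consumer GPU
    )
    print(llm.invoke("Summarize what this PDF assistant does in one sentence."))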