|
|
|
|
|
|
|
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI |
|
import torch |
|
|
|
|
|
|
|
def _local_model_kwargs(device, cache_dir):
    """Build the ``from_pretrained`` keyword arguments for a locally loaded model."""
    device_name = str(device)
    kwargs = {
        "trust_remote_code": True,
        # Half precision is only requested off-CPU: fp16 inference on CPU is
        # poorly supported (slow, and some kernels are missing on older
        # torch builds), so CPU loads fall back to float32.
        "torch_dtype": torch.float32 if device_name == "cpu" else torch.float16,
        "cache_dir": cache_dir,
    }

    # Prefix match so explicit indices such as "cuda:0" are quantized too.
    if device_name.startswith("cuda"):
        from transformers import BitsAndBytesConfig

        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
    return kwargs


def _inference_api_class():
    """Return ``HuggingFaceInferenceAPI`` from whichever import path is available."""
    try:
        # Same module path as this file's top-level import.
        from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
    except ImportError:
        # Older import path that this function used originally.
        from llama_index.llms import HuggingFaceInferenceAPI
    return HuggingFaceInferenceAPI


def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
              device: str = None,
              context_window: int = 4096,
              max_new_tokens: int = 512):
    """Set up the language model for the CSV chatbot.

    Backends are tried in order, each one only if the previous failed:

    1. ``HuggingFaceLLM`` running the model locally (4-bit quantized when the
       device is CUDA). Skipped on ``ImportError`` / ``AttributeError``.
    2. ``HuggingFaceInferenceAPI``. Skipped on any ``Exception``.
    3. ``HuggingFaceInference`` from ``llama_index.llms.huggingface``.

    Args:
        model_name: Hugging Face model id; also used as the tokenizer name.
        device: Value passed as ``device_map`` to the local backend. When
            ``None``, ``"cuda"`` is used if ``torch.cuda.is_available()``,
            otherwise ``"cpu"``.
        context_window: Context window size forwarded to the backend.
        max_new_tokens: Generation length limit forwarded to the backend.

    Returns:
        The first backend object that could be constructed.

    Raises:
        Exception: Whatever the last fallback raises (typically
            ``ImportError``) when no backend can be constructed.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    cache_dir = "./model_cache"

    # Arguments every backend constructor receives.
    # NOTE(review): temperature/top_p normally only take effect when sampling
    # is enabled (do_sample=True) — confirm whether greedy decoding is intended.
    common_kwargs = {
        "model_name": model_name,
        "tokenizer_name": model_name,
        "context_window": context_window,
        "max_new_tokens": max_new_tokens,
        "generate_kwargs": {"temperature": 0.7, "top_p": 0.95},
    }

    try:
        from llama_index.llms.huggingface import HuggingFaceLLM

        # NOTE(review): the cache location is passed as the transformers
        # `cache_dir` option via model/tokenizer kwargs rather than as a
        # `cache_folder` constructor argument, which HuggingFaceLLM does not
        # appear to accept — confirm against the installed version.
        llm = HuggingFaceLLM(
            device_map=device,
            tokenizer_kwargs={"trust_remote_code": True, "cache_dir": cache_dir},
            model_kwargs=_local_model_kwargs(device, cache_dir),
            **common_kwargs,
        )
    except (ImportError, AttributeError):
        try:
            llm = _inference_api_class()(**common_kwargs)
        except Exception:
            # Deliberately broad: any failure of the API client falls through
            # to the last-resort backend. Unlike a bare `except:`, this lets
            # KeyboardInterrupt and SystemExit propagate.
            from llama_index.llms.huggingface import HuggingFaceInference

            llm = HuggingFaceInference(**common_kwargs)

    return llm
|
|