aelitta's picture
Upload folder using huggingface_hub
4bdb245 verified
raw
history blame
No virus
8.8 kB
import logging
from collections.abc import Callable
from typing import Any
from injector import inject, singleton
from llama_index.core.llms import LLM, MockLLM
from llama_index.core.settings import Settings as LlamaIndexSettings
from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer # type: ignore
from private_gpt.components.llm.prompt_helper import get_prompt_style
from private_gpt.paths import models_cache_path, models_path
from private_gpt.settings.settings import Settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@singleton
class LLMComponent:
llm: LLM
@inject
def __init__(self, settings: Settings) -> None:
llm_mode = settings.llm.mode
if settings.llm.tokenizer and settings.llm.mode != "mock":
# Try to download the tokenizer. If it fails, the LLM will still work
# using the default one, which is less accurate.
try:
set_global_tokenizer(
AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=settings.llm.tokenizer,
cache_dir=str(models_cache_path),
token=settings.huggingface.access_token,
)
)
except Exception as e:
logger.warning(
"Failed to download tokenizer %s. Falling back to "
"default tokenizer.",
settings.llm.tokenizer,
e,
)
logger.info("Initializing the LLM in mode=%s", llm_mode)
match settings.llm.mode:
case "llamacpp":
try:
from llama_index.llms.llama_cpp import LlamaCPP # type: ignore
except ImportError as e:
raise ImportError(
"Local dependencies not found, install with `poetry install --extras llms-llama-cpp`"
) from e
prompt_style = get_prompt_style(settings.llm.prompt_style)
settings_kwargs = {
"tfs_z": settings.llamacpp.tfs_z, # ollama and llama-cpp
"top_k": settings.llamacpp.top_k, # ollama and llama-cpp
"top_p": settings.llamacpp.top_p, # ollama and llama-cpp
"repeat_penalty": settings.llamacpp.repeat_penalty, # ollama llama-cpp
"n_gpu_layers": -1,
"offload_kqv": True,
}
self.llm = LlamaCPP(
model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
temperature=settings.llm.temperature,
max_new_tokens=settings.llm.max_new_tokens,
context_window=settings.llm.context_window,
generate_kwargs={},
callback_manager=LlamaIndexSettings.callback_manager,
# All to GPU
model_kwargs=settings_kwargs,
# transform inputs into Llama2 format
messages_to_prompt=prompt_style.messages_to_prompt,
completion_to_prompt=prompt_style.completion_to_prompt,
verbose=True,
)
case "sagemaker":
try:
from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
except ImportError as e:
raise ImportError(
"Sagemaker dependencies not found, install with `poetry install --extras llms-sagemaker`"
) from e
self.llm = SagemakerLLM(
endpoint_name=settings.sagemaker.llm_endpoint_name,
max_new_tokens=settings.llm.max_new_tokens,
context_window=settings.llm.context_window,
)
case "openai":
try:
from llama_index.llms.openai import OpenAI # type: ignore
except ImportError as e:
raise ImportError(
"OpenAI dependencies not found, install with `poetry install --extras llms-openai`"
) from e
openai_settings = settings.openai
self.llm = OpenAI(
api_base=openai_settings.api_base,
api_key=openai_settings.api_key,
model=openai_settings.model,
)
case "openailike":
try:
from llama_index.llms.openai_like import OpenAILike # type: ignore
except ImportError as e:
raise ImportError(
"OpenAILike dependencies not found, install with `poetry install --extras llms-openai-like`"
) from e
prompt_style = get_prompt_style(settings.llm.prompt_style)
openai_settings = settings.openai
self.llm = OpenAILike(
api_base=openai_settings.api_base,
api_key=openai_settings.api_key,
model=openai_settings.model,
is_chat_model=True,
max_tokens=settings.llm.max_new_tokens,
api_version="",
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
max_new_tokens=settings.llm.max_new_tokens,
messages_to_prompt=prompt_style.messages_to_prompt,
completion_to_prompt=prompt_style.completion_to_prompt,
)
case "ollama":
try:
from llama_index.llms.ollama import Ollama # type: ignore
except ImportError as e:
raise ImportError(
"Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
) from e
ollama_settings = settings.ollama
settings_kwargs = {
"tfs_z": ollama_settings.tfs_z, # ollama and llama-cpp
"num_predict": ollama_settings.num_predict, # ollama only
"top_k": ollama_settings.top_k, # ollama and llama-cpp
"top_p": ollama_settings.top_p, # ollama and llama-cpp
"repeat_last_n": ollama_settings.repeat_last_n, # ollama
"repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp
}
self.llm = Ollama(
model=ollama_settings.llm_model,
base_url=ollama_settings.api_base,
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
additional_kwargs=settings_kwargs,
request_timeout=ollama_settings.request_timeout,
)
if (
ollama_settings.keep_alive
!= ollama_settings.model_fields["keep_alive"].default
):
# Modify Ollama methods to use the "keep_alive" field.
def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
kwargs["keep_alive"] = ollama_settings.keep_alive
return func(*args, **kwargs)
return wrapper
Ollama.chat = add_keep_alive(Ollama.chat)
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
Ollama.complete = add_keep_alive(Ollama.complete)
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
case "azopenai":
try:
from llama_index.llms.azure_openai import ( # type: ignore
AzureOpenAI,
)
except ImportError as e:
raise ImportError(
"Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
) from e
azopenai_settings = settings.azopenai
self.llm = AzureOpenAI(
model=azopenai_settings.llm_model,
deployment_name=azopenai_settings.llm_deployment_name,
api_key=azopenai_settings.api_key,
azure_endpoint=azopenai_settings.azure_endpoint,
api_version=azopenai_settings.api_version,
)
case "mock":
self.llm = MockLLM()