CineAI's picture
Update llm/huggingfacehub/hf_model.py
a4a08b7 verified
raw
history blame
5.54 kB
import os
import yaml
import logging
from abc import ABC
from llm.hf_interface import HFInterface
from llm.config import config
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceHub
# Module-level logger. The level is INFO so that both informational records
# and the ERROR/CRITICAL records emitted by the model classes reach the file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Log files in this package follow the template
# "llm/logs/chelsea_{module_name}_{dir_name}.log". The path below is relative
# to the current working directory of the process.
_log_file = "logs/chelsea_llm_huggingfacehub.log"
# FileHandler opens its file on construction and raises if the parent
# directory does not exist, so make sure it is there first.
os.makedirs(os.path.dirname(_log_file), exist_ok=True)
file_handler = logging.FileHandler(_log_file)
formatted = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatted)
logger.addHandler(file_handler)

logger.info("Getting information from hf_model module")

# Directory the model classes read ``prompts.yaml`` from (hard-coded path).
llm_dir = '/home/user/app/llm/'
# Working-directory-relative location of the prompts file; in the code shown
# here it is only printed for diagnostics.
path_to_yaml = os.path.join(os.getcwd(), "llm/prompts.yaml")
print("Path to prompts : ", path_to_yaml)
class HF_Mistaril(HFInterface, ABC):
    """Run a YAML-defined prompt against the model configured in ``config["HF_Mistrail"]``.

    The instance builds a ``HuggingFaceHub`` LLM client on construction and,
    on :meth:`execution`, loads a prompt template from ``prompts.yaml``,
    wraps it in an ``LLMChain`` and returns the generated text.
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        """Store the prompt parameters and create the LLM client.

        Args:
            prompt_entity: Value handed to the chain's ``invoke`` call; the
                prompt template declares a single input variable, ``entity``.
            prompt_id: Key used to select an entry from the ``prompts``
                section of the YAML file. Defaults to 0.

        Raises:
            KeyError: If ``config`` or the model section lacks a required key.
        """
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id
        self.model_config = config["HF_Mistrail"]

        # TODO: add repetition_penalty, task?, top_p, stop_sequences
        self.llm = HuggingFaceHub(
            repo_id=self.model_config["model"],
            # Generation settings are forwarded through model_kwargs.
            model_kwargs={
                "load_in_8bit": self.model_config["load_in_8bit"],
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
            },
            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        )

    @staticmethod
    def __read_yaml():
        """Load and parse ``prompts.yaml`` from ``llm_dir``.

        Returns:
            The parsed YAML content, or ``None`` when the file cannot be read
            or parsed. In that case the error is printed and logged at ERROR
            level instead of being raised.
        """
        try:
            yaml_file = os.path.join(llm_dir, 'prompts.yaml')
            # Explicit UTF-8 so decoding does not depend on the host locale;
            # the context manager closes the file, no manual close() needed.
            with open(yaml_file, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            print(f"Execution filed : {e}")
            logger.error(msg="Execution filed", exc_info=e)
            return None

    def execution(self):
        """Build the prompt chain and run it for ``self.prompt_entity``.

        Returns:
            The ``"text"`` field of the chain output, or ``None`` when any
            step fails. Failures are printed and logged at CRITICAL level
            rather than propagated to the caller.
        """
        try:
            data = self.__read_yaml()
            # Select the prompt entry addressed by prompt_id (0 is the first
            # one); pass a different prompt_id to use another prompt.
            # If the YAML could not be loaded, data is None and this lookup
            # raises TypeError, which the handler below reports.
            prompt_entry = data["prompts"][self.prompt_id]
            template = prompt_entry["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            print(f"Execution filed : {e}")
            logger.critical(msg="Execution filed", exc_info=e)
            return None

    def __str__(self):
        """Return a short, human-readable summary of the prompt parameters."""
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        """Return a debugging representation including the parameter types."""
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"
class HF_TinyLlama(HFInterface, ABC):
    """Run a YAML-defined prompt against the model configured in ``config["HF_TinyLlama"]``.

    The instance builds a ``HuggingFaceHub`` LLM client on construction and,
    on :meth:`execution`, loads a prompt template from ``prompts.yaml``,
    wraps it in an ``LLMChain`` and returns the generated text.
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        """Store the prompt parameters and create the LLM client.

        Args:
            prompt_entity: Value handed to the chain's ``invoke`` call; the
                prompt template declares a single input variable, ``entity``.
            prompt_id: Key used to select an entry from the ``prompts``
                section of the YAML file. Defaults to 0.

        Raises:
            KeyError: If ``config`` or the model section lacks a required key.
        """
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id
        self.model_config = config["HF_TinyLlama"]

        self.llm = HuggingFaceHub(
            repo_id=self.model_config["model"],
            # Generation settings are forwarded through model_kwargs.
            model_kwargs={
                "load_in_8bit": self.model_config["load_in_8bit"],
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
            },
            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        )

    @staticmethod
    def __read_yaml():
        """Load and parse ``prompts.yaml`` from ``llm_dir``.

        Returns:
            The parsed YAML content, or ``None`` when the file cannot be read
            or parsed. In that case the error is printed and logged at ERROR
            level instead of being raised.
        """
        try:
            yaml_file = os.path.join(llm_dir, 'prompts.yaml')
            # Explicit UTF-8 so decoding does not depend on the host locale;
            # the context manager closes the file, no manual close() needed.
            with open(yaml_file, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            print(f"Execution filed : {e}")
            logger.error(msg="Execution filed", exc_info=e)
            return None

    def execution(self):
        """Build the prompt chain and run it for ``self.prompt_entity``.

        Returns:
            The ``"text"`` field of the chain output, or ``None`` when any
            step fails. Failures are printed and logged at CRITICAL level
            rather than propagated to the caller.
        """
        try:
            data = self.__read_yaml()
            # Select the prompt entry addressed by prompt_id (0 is the first
            # one); pass a different prompt_id to use another prompt.
            # If the YAML could not be loaded, data is None and this lookup
            # raises TypeError, which the handler below reports.
            prompt_entry = data["prompts"][self.prompt_id]
            template = prompt_entry["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            print(f"Execution filed : {e}")
            logger.critical(msg="Execution filed", exc_info=e)
            return None

    def __str__(self):
        """Return a short, human-readable summary of the prompt parameters."""
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        """Return a debugging representation including the parameter types."""
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"