File size: 5,536 Bytes
d5436e0
 
c4f6685
d5436e0
 
 
f12e651
d5436e0
 
 
 
a4a08b7
d5436e0
 
 
a4a08b7
d5436e0
 
a4a08b7
d5436e0
 
 
 
 
 
 
 
 
58abf09
e805397
a4a08b7
 
 
2f68799
d5436e0
f12e651
d5436e0
 
 
 
 
 
c4f6685
a4a08b7
d5436e0
a4a08b7
 
 
 
 
 
 
 
2f68799
96e2394
d5436e0
 
 
 
58abf09
c4f6685
 
 
d5436e0
 
5cc25cc
d5436e0
 
 
 
 
 
 
 
 
 
 
 
 
5cc25cc
d5436e0
 
 
 
 
 
 
 
 
f12e651
d5436e0
 
 
 
 
 
a4a08b7
d5436e0
a4a08b7
 
 
 
 
 
 
 
2f68799
96e2394
d5436e0
 
 
 
58abf09
c4f6685
 
 
d5436e0
 
5cc25cc
d5436e0
 
 
 
 
 
 
 
 
 
 
 
 
5cc25cc
d5436e0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import yaml
import logging

from abc import ABC

from llm.hf_interface import HFInterface
from llm.config import config

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceHub

logger = logging.getLogger(__name__)

# Per-module log file. Naming convention for modules in this package:
# "llm/logs/chelsea_{module_name}_{dir_name}.log". The path below is relative to
# the current working directory. FileHandler opens its file immediately and
# raises FileNotFoundError if the directory does not exist, so create it first.
os.makedirs("logs", exist_ok=True)
file_handler = logging.FileHandler("logs/chelsea_llm_huggingfacehub.log")

# Record INFO and above from this module.
logger.setLevel(logging.INFO)

formatted = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatted)

logger.addHandler(file_handler)

logger.info("Getting information from hf_model module")

# Directory containing prompts.yaml; the model classes' __read_yaml joins this path.
# NOTE(review): hard-coded absolute path (looks like a container/Spaces layout) —
# confirm it matches the deployment environment.
llm_dir = '/home/user/app/llm/'

# CWD-relative location of the prompts file. In this file it is only printed,
# not used for loading (loading goes through llm_dir above).
path_to_yaml = os.path.join(os.getcwd(), "llm/prompts.yaml")

print("Path to prompts : ", path_to_yaml)


class HF_Mistaril(HFInterface, ABC):
    """Text generation with the model configured under ``config["HF_Mistrail"]``.

    The model is accessed through LangChain's ``HuggingFaceHub`` wrapper, and the
    prompt template is read from ``prompts.yaml`` inside ``llm_dir``.

    Args:
        prompt_entity: Value substituted for ``{entity}`` in the prompt template.
        prompt_id: Index into the ``prompts`` list of ``prompts.yaml``
            (0 selects the first prompt).
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id

        # NOTE(review): key is spelled "HF_Mistrail" (sic); presumably it matches
        # the key in llm.config — confirm before renaming either side.
        self.model_config = config["HF_Mistrail"]

        # TODO: add repetition_penalty, task?, top_p, stop_sequences
        self.llm = HuggingFaceHub(
            repo_id=self.model_config["model"],
            # Generation parameters are handed to HuggingFaceHub via model_kwargs.
            model_kwargs={
                "load_in_8bit": self.model_config["load_in_8bit"],
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
            },
            # None when the environment variable is not set.
            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        )

    @staticmethod
    def __read_yaml():
        """Load ``prompts.yaml`` from ``llm_dir``.

        Returns:
            The object produced by ``yaml.safe_load`` (``None`` for an empty
            file), or ``None`` if the file cannot be opened, decoded or parsed;
            in that case the error is printed and logged.
        """
        yaml_file = os.path.join(llm_dir, 'prompts.yaml')
        try:
            # Decode explicitly instead of relying on the platform's default locale.
            with open(yaml_file, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
            print(f"Execution failed : {e}")
            logger.error(msg="Execution failed", exc_info=e)
            return None

    def execution(self):
        """Render the selected prompt with ``prompt_entity`` and run it through the LLM.

        Returns:
            The generated text (``output["text"]`` of the chain), or ``None`` if
            any step fails; failures are printed and logged at CRITICAL level.
        """
        try:
            data = self.__read_yaml()
            if data is None:
                # Read/parse failure (already reported by __read_yaml) or empty file.
                print("Execution failed : no prompts loaded from prompts.yaml")
                logger.critical("Execution failed: no prompts loaded from prompts.yaml")
                return None
            # Select the prompt by index; the default prompt_id=0 is the first prompt.
            prompt_config = data["prompts"][self.prompt_id]
            template = prompt_config["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            # The chain has a single input variable ("entity"), so a bare value is accepted.
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Top-level boundary: report and return None rather than propagate.
            print(f"Execution failed : {e}")
            logger.critical(msg="Execution failed", exc_info=e)
            return None

    def __str__(self):
        """Short human-readable summary of the constructor arguments."""
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        """Debug representation including the argument types."""
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"


class HF_TinyLlama(HFInterface, ABC):
    """Text generation with the model configured under ``config["HF_TinyLlama"]``.

    The model is accessed through LangChain's ``HuggingFaceHub`` wrapper, and the
    prompt template is read from ``prompts.yaml`` inside ``llm_dir``.

    Args:
        prompt_entity: Value substituted for ``{entity}`` in the prompt template.
        prompt_id: Index into the ``prompts`` list of ``prompts.yaml``
            (0 selects the first prompt).
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id

        self.model_config = config["HF_TinyLlama"]

        self.llm = HuggingFaceHub(
            repo_id=self.model_config["model"],
            # Generation parameters are handed to HuggingFaceHub via model_kwargs.
            model_kwargs={
                "load_in_8bit": self.model_config["load_in_8bit"],
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
            },
            # None when the environment variable is not set.
            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        )

    @staticmethod
    def __read_yaml():
        """Load ``prompts.yaml`` from ``llm_dir``.

        Returns:
            The object produced by ``yaml.safe_load`` (``None`` for an empty
            file), or ``None`` if the file cannot be opened, decoded or parsed;
            in that case the error is printed and logged.
        """
        yaml_file = os.path.join(llm_dir, 'prompts.yaml')
        try:
            # Decode explicitly instead of relying on the platform's default locale.
            with open(yaml_file, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
            print(f"Execution failed : {e}")
            logger.error(msg="Execution failed", exc_info=e)
            return None

    def execution(self):
        """Render the selected prompt with ``prompt_entity`` and run it through the LLM.

        Returns:
            The generated text (``output["text"]`` of the chain), or ``None`` if
            any step fails; failures are printed and logged at CRITICAL level.
        """
        try:
            data = self.__read_yaml()
            if data is None:
                # Read/parse failure (already reported by __read_yaml) or empty file.
                print("Execution failed : no prompts loaded from prompts.yaml")
                logger.critical("Execution failed: no prompts loaded from prompts.yaml")
                return None
            # Select the prompt by index; the default prompt_id=0 is the first prompt.
            prompt_config = data["prompts"][self.prompt_id]
            template = prompt_config["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            # The chain has a single input variable ("entity"), so a bare value is accepted.
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Top-level boundary: report and return None rather than propagate.
            print(f"Execution failed : {e}")
            logger.critical(msg="Execution failed", exc_info=e)
            return None

    def __str__(self):
        """Short human-readable summary of the constructor arguments."""
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        """Debug representation including the argument types."""
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"