Update llm/llm.py
Browse files- llm/llm.py +18 -0
llm/llm.py
CHANGED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
from langchain.chains import LLMChain
|
3 |
+
from langchain.llms import HuggingFaceHub
|
4 |
+
from .config import config
|
5 |
+
from prompts.prompt import prompts
|
6 |
+
|
7 |
+
class LLM_chain:
    """Prompt-templated wrapper around a HuggingFaceHub-hosted LLM.

    On construction it instantiates a ``HuggingFaceHub`` model whose repo id
    and sampling settings are read from the shared ``config`` mapping.
    Calling the instance runs one of the preconfigured ``prompts`` against
    an input string.
    """

    def __init__(self):
        # All model selection and generation knobs come from config so the
        # backing model can be swapped without touching this class.
        self.llm = HuggingFaceHub(
            repo_id=config["model"],
            model_kwargs={
                "temperature": config["temperature"],
                "max_new_tokens": config["max_new_tokens"],
                "top_k": config["top_k"],
                "load_in_8bit": config["load_in_8bit"],
            },
        )

    def __call__(self, entity: str, id: int = 0):
        """Run the prompt selected by ``id`` on ``entity``.

        ``id`` indexes into the ``prompts`` registry (default: first prompt);
        the chosen template is expected to expose an ``entity`` variable.
        Returns whatever ``LLMChain.invoke`` produces for the input.
        """
        chosen_template = prompts[id]["prompt_template"]
        chain_prompt = PromptTemplate(template=chosen_template, input_variables=["entity"])
        chain = LLMChain(prompt=chain_prompt, llm=self.llm, verbose=True)
        return chain.invoke(entity)