chatbot_v2 / custom_llm.py
jonathanjordan21's picture
Update custom_llm.py
a65d202 verified
raw
history blame
7.8 kB
from typing import Any, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from typing import Literal
import requests
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from operator import itemgetter
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models import ChatOpenAI
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
import os, requests, bs4
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.vectorstores import FAISS
import pickle, asyncio
# os.environ['FAISS_NO_AVX2'] = '1'
def load_web(web_url, timeout=30):
    """Fetch a web page and return the distinct, non-empty texts of its ``<div>`` elements.

    Each ``<div>``'s text is stripped of surrounding whitespace; empty results
    are dropped and duplicates are removed while keeping first-seen order.

    Args:
        web_url: URL of the page to download.
        timeout: Seconds to wait for the server before giving up, so a stalled
            host cannot block the caller indefinitely.

    Returns:
        A list of unique stripped strings, in order of first appearance.

    Raises:
        requests.RequestException: if the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(web_url, timeout=timeout)
    # Fail loudly rather than scraping the body of an error page.
    response.raise_for_status()
    soup = bs4.BeautifulSoup(response.content, "html.parser")

    # Compute each div's text once; ``div.text`` walks the subtree on every access.
    stripped_texts = (div.text.strip() for div in soup.find_all("div"))

    # dict keys preserve insertion order, giving an order-preserving de-duplication.
    return list(dict.fromkeys(text for text in stripped_texts if text))
async def create_vectorstore():
    """Build a FAISS vector store from the text in the ``knowledge_base`` env var.

    The text is split into chunks of 512 characters with an overlap of 20,
    embedded with the ``paraphrase-multilingual-mpnet-base-v2`` sentence
    transformer (normalized embeddings), and indexed with FAISS.

    Side effects:
        Prints the number of chunks, then loads ``wi_knowledge.dat`` from the
        current working directory and prints its length.

    Returns:
        The FAISS store returned by ``FAISS.afrom_documents``.

    Raises:
        ValueError: if the ``knowledge_base`` environment variable is unset or empty.
        OSError: if ``wi_knowledge.dat`` cannot be opened.
    """
    knowledge_text = os.getenv('knowledge_base')
    if not knowledge_text:
        # Without this guard ``None`` would be handed to the splitter and fail
        # there with a far less helpful error.
        raise ValueError("Environment variable 'knowledge_base' is not set or is empty")

    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=20)
    docs = splitter.create_documents([knowledge_text])
    print(len(docs))

    emb_model = HuggingFaceEmbeddings(
        model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        encode_kwargs={'normalize_embeddings': True},
    )
    db = await FAISS.afrom_documents(docs, emb_model)

    # SECURITY: pickle can execute arbitrary code while loading; only ever ship
    # a trusted ``wi_knowledge.dat`` alongside this module.
    # NOTE(review): the unpickled object is only used to report its size here;
    # it is not added to the index — confirm whether that is intended.
    with open("wi_knowledge.dat", "rb") as knowledge_file:
        pickled_docs = pickle.load(knowledge_file)
    print("Docs len :", len(pickled_docs))

    return db
def custom_chain_with_history(llm, memory):
    """Assemble the retrieval-augmented chat chain: inputs -> prompt -> ``llm``.

    The returned chain expects an input mapping with two keys:

    * ``"question"``: the user's message; it is used both as the retrieval
      query and as the ``{question}`` slot of the prompt.
    * ``"memory"``: an object exposing ``chat_memory.messages``; its messages
      are rendered into the ``{chat_history}`` slot.

    Building the chain runs ``create_vectorstore()`` to completion via
    ``asyncio.run``, so this function must not be called from inside an
    already-running event loop.

    Args:
        llm: The model (or runnable) that receives the formatted prompt.
        memory: Accepted for interface compatibility. The chain reads the
            conversation from the ``"memory"`` key of its input instead.

    Returns:
        A runnable chain that maps the input mapping to the output of ``llm``.
    """
    # English gist of the Indonesian prompt below: "You are a customer service
    # chatbot assistant. You have access to the context below to find the
    # information most relevant to the user's needs. Respond to the user based
    # on the following chat history; if there is no relevant information, it is
    # confidential and you are not allowed to share it with the user."
    prompt = PromptTemplate.from_template("""<s><INST><|system|>
Anda adalah asisten Chatbot customer service.
Anda memiliki akses konteks dibawah ini untuk menemukan informasi yang paling relevan dengan kebutuhan user:
{context}
Berikan respon kepada user berdasarkan riwayat chat berikut, jika tidak ada informasi yang relevan maka itu adalah informasi yang rahasia dan kamu tidak diizinkan untuk menyebarkan informasi tersebut kepada user:
{chat_history}
<|user|>
{question}
<|assistant|>
""")

    def prompt_memory(chat_memory_holder):
        """Render the stored messages as alternating assistant/user turns."""
        # isinstance (not an exact type check) so AIMessage subclasses are
        # also rendered as assistant turns.
        return "".join(
            f"<|assistant|>\n<s>{message.content}</s>\n"
            if isinstance(message, AIMessage)
            else f"<|user|>\n{message.content}\n"
            for message in chat_memory_holder.chat_memory.messages
        )

    def format_docs(docs):
        """Number the retrieved documents, one per line, starting at 1."""
        return "\n".join(
            f"{position}. {doc.page_content}"
            for position, doc in enumerate(docs, start=1)
        )

    retriever = asyncio.run(create_vectorstore()).as_retriever(
        search_type="similarity", search_kwargs={"k": 12}
    )

    return {
        "chat_history": lambda x: prompt_memory(x['memory']),
        # The retriever needs the question string, not the whole input mapping,
        # so select it before retrieval.
        "context": itemgetter("question") | retriever | format_docs,
        "question": itemgetter("question"),
    } | prompt | llm
class CustomLLM(LLM):
    """LangChain ``LLM`` that sends prompts to ``api-inference.huggingface.co`` over HTTP.

    Generation settings left as ``None`` are omitted from the request.
    """

    # Model repository id, appended to the inference endpoint URL.
    repo_id: str
    # Bearer token sent in the ``Authorization`` header.
    api_token: str
    # Task type; "text-generation" additionally requests ``return_full_text=False``.
    model_type: Literal["text2text-generation", "text-generation"]
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    # NOTE(review): forwarded as a generation parameter, not applied to the
    # HTTP request itself — confirm this is the intended meaning.
    timeout: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    # Default stop sequences, used when ``_call`` receives no ``stop`` argument.
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM implementation."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the inference endpoint and return the generated text.

        Args:
            prompt: The fully formatted prompt text.
            stop: Stop sequences for this call; when ``None`` the instance's
                ``stop`` field is used instead.
            run_manager: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The ``generated_text`` of the first item in the JSON response.

        Raises:
            ValueError: if the response does not have the expected
                ``[{"generated_text": ...}]`` shape (for example an error payload).
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"

        candidate_parameters = {
            'max_new_tokens': self.max_new_tokens,
            'temperature': self.temperature,
            'timeout': self.timeout,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'repetition_penalty': self.repetition_penalty,
            # Stop sequences passed for this call win over the instance default.
            'stop': self.stop if stop is None else stop,
        }
        # Omit unset options instead of sending explicit JSON nulls.
        parameters = {
            name: value
            for name, value in candidate_parameters.items()
            if value is not None
        }
        if self.model_type == 'text-generation':
            parameters["return_full_text"] = False

        payload = {
            "inputs": prompt,
            "parameters": parameters,
            "options": {"wait_for_model": True},
        }
        result = requests.post(api_url, headers=headers, json=payload).json()

        try:
            return result[0]['generated_text']
        except (KeyError, IndexError, TypeError) as err:
            # Error payloads are not a list of generations; surface the body
            # instead of an opaque ``KeyError: 0``.
            raise ValueError(
                f"Unexpected response from {api_url}: {result!r}"
            ) from err

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            'repo_id': self.repo_id,
            'model_type': self.model_type,
            'stop_sequences': self.stop,
            'max_new_tokens': self.max_new_tokens,
            'temperature': self.temperature,
            'timeout': self.timeout,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'repetition_penalty': self.repetition_penalty,
        }