import os
from operator import itemgetter
from typing import Any, List, Literal, Mapping, Optional

import requests

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import DirectoryLoader, PyMuPDFLoader
def custom_chain_with_history(llm, memory):
    """Build an LCEL chain that prepends the buffered chat history to each prompt."""
    prompt = PromptTemplate.from_template("""<s><INST><|system|>
You are a helpful, respectful, and honest assistant. Always answer as helpfully as possible while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
{chat_history}
### User: {question}
### Assistant: """)

    def prompt_memory(memory):
        # Serialize the buffered messages into the "### User/### Assistant"
        # layout that the prompt template above expects.
        t = ""
        for x in memory.chat_memory.messages:
            t += f"### Assistant: {x.content}\n" if isinstance(x, AIMessage) else f"### User: {x.content}\n"
        return t

    def format_docs(docs):
        # Not wired into the chain below; kept for a retrieval (RAG) step.
        print(len(docs))
        return "\n".join([f"{i + 1}. {d.page_content}" for i, d in enumerate(docs)])

    return {"chat_history": lambda x: prompt_memory(x["memory"]), "question": lambda x: x["question"]} | prompt | llm
class CustomLLM(LLM):
    """LangChain LLM wrapper around the Hugging Face serverless Inference API."""

    repo_id: str
    api_token: str
    model_type: Literal["text2text-generation", "text-generation"]
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    timeout: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        return "custom"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        API_URL = f"https://api-inference.huggingface.co/models/{self.repo_id}"
        parameters_dict = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "timeout": self.timeout,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            # Prefer stop sequences passed at call time over the instance default.
            "stop": stop if stop is not None else self.stop,
        }
        if self.model_type == "text-generation":
            # Ask the API to return only the completion, not the echoed prompt.
            parameters_dict["return_full_text"] = False
        payload = {
            "inputs": prompt,
            "parameters": parameters_dict,
            "options": {"wait_for_model": True},
        }
        response = requests.post(API_URL, headers=headers, json=payload).json()
        return response[0]["generated_text"]
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "repo_id": self.repo_id,
            "model_type": self.model_type,
            "stop_sequences": self.stop,
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "timeout": self.timeout,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
        }
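
# Minimal wiring sketch (not part of the original Space): the repo_id and the
# HF_API_TOKEN environment variable name are assumptions for illustration.
if __name__ == "__main__":
    llm = CustomLLM(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical model choice
        api_token=os.environ.get("HF_API_TOKEN", ""),  # assumed env var name
        model_type="text-generation",
        max_new_tokens=256,
    )
    memory = ConversationBufferMemory()
    chain = custom_chain_with_history(llm, memory)
    print(chain.invoke({"question": "Hello! Who are you?", "memory": memory}))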