import os

from huggingface_hub import InferenceClient
from langchain.schema import SystemMessage, AIMessage, HumanMessage
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

from data import Data
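
# Assumed interface of `data.Data` (data.py is not part of this file): it takes a list of
# URLs, loads and chunks the pages, and exposes a LangChain retriever, roughly like:
#
#     class Data:
#         def __init__(self, urls):
#             docs = WebBaseLoader(urls).load()
#             chunks = RecursiveCharacterTextSplitter(chunk_size=1000).split_documents(docs)
#             self.retriever = FAISS.from_documents(chunks, HuggingFaceEmbeddings()).as_retriever()
#
# Only the `retriever` attribute (with .invoke(query) -> list[Document]) is relied on below.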


class Model:
    def __init__(self, model_id="meta-llama/Llama-3.2-1B-Instruct"):
        # Streaming chat client used by respond().
        self.client = InferenceClient(model_id, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
        # Endpoint-backed LLM wrapped as a chat model; used by the RetrievalQA chain in predict().
        self.llm = HuggingFaceEndpoint(
            repo_id="HuggingFaceH4/zephyr-7b-beta",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        self.chat_model = ChatHuggingFace(llm=self.llm, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))

    def build_prompt(self, question, context_urls):
        # Retrieve the most relevant chunk for the question and inline it into the prompt.
        data = Data(context_urls)
        context = data.retriever.invoke(question)[0].page_content
        prompt = f"""
        Use the following piece of context to answer the question asked.
        Please try to provide the answer based only on the context.

        {context}

        Question: {question}
        Helpful Answer:
        """
        return prompt

    def _build_prompt_rag(self):
        # Same instructions as build_prompt, but as a PromptTemplate so the
        # RetrievalQA chain can fill in {context} and {question} itself.
        prompt_template = """
        Use the following piece of context to answer the question asked.
        Please try to provide the answer based only on the context.

        {context}

        Question: {question}
        Helpful Answer:
        """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        return prompt

    def _retrieval_qa(self, url):
        # Build a "stuff" RetrievalQA chain over the documents retrieved for the given URL.
        data = Data([url])
        prompt = self._build_prompt_rag()
        return RetrievalQA.from_chain_type(
            llm=self.chat_model,
            chain_type="stuff",
            retriever=data.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def predict(self, message, history, url, max_tokens, temperature, top_p):
        # Convert the message history into LangChain messages. This chat-history path is
        # currently bypassed in favour of the RetrievalQA chain below.
        history_langchain_format = [SystemMessage(content="You're a helpful python developer assistant")]
        for msg in history:
            if msg["role"] == "user":
                history_langchain_format.append(HumanMessage(content=msg["content"]))
            elif msg["role"] == "assistant":
                history_langchain_format.append(AIMessage(content=msg["content"]))
        history_langchain_format.append(HumanMessage(content=message))
        # ai_msg = self.chat_model.invoke(history_langchain_format)
        # return ai_msg.content
        ret = self._retrieval_qa(url)
        return ret.invoke({"query": message})["result"]

    def respond(
        self,
        message,
        history: list[tuple[str, str]],
        url,
        max_tokens,
        temperature,
        top_p,
    ):
        # The context URL doubles as the system message; the history uses
        # (user, assistant) tuples.
        messages = [{"role": "system", "content": url}]
        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
        # Wrap the latest user message in a retrieval-augmented prompt built from the URL.
        messages.append({"role": "user", "content": self.build_prompt(message, [url])})

        # Stream the completion chunk by chunk; `chunk` avoids shadowing `message`,
        # and empty/None deltas (e.g. the final chunk) are skipped.
        response = ""
        for chunk in self.client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:
                response += token
            yield response


model = Model()
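
# A minimal sketch of how this module might be wired into the Space's UI. This file only
# defines `model`; the actual app.py is not shown, so the Gradio hookup below (component
# choices, default URL, and slider ranges) is an assumption, kept commented out so that
# importing this module stays side-effect free.
#
#     import gradio as gr
#
#     demo = gr.ChatInterface(
#         model.respond,  # streams partial answers as they are generated
#         additional_inputs=[
#             gr.Textbox(value="https://docs.python.org/3/", label="Context URL"),  # hypothetical default
#             gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
#             gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
#             gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
#         ],
#     )
#     demo.launch()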