# autodocs/model.py
import os
from huggingface_hub import InferenceClient
from langchain.schema import SystemMessage, AIMessage, HumanMessage
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from data import Data
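
# Note: both the InferenceClient and the HuggingFaceEndpoint read the API token from the
# HUGGINGFACEHUB_API_TOKEN environment variable, so it must be set before this module is
# imported (the module-level `model = Model()` at the bottom runs at import time).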
class Model:
    def __init__(self, model_id="meta-llama/Llama-3.2-1B-Instruct"):
        # Client for streaming chat completions (used by `respond`).
        self.client = InferenceClient(model_id, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
        # Hosted endpoint backing the LangChain chat model (used by the RetrievalQA chain in `predict`).
        self.llm = HuggingFaceEndpoint(
            repo_id="HuggingFaceH4/zephyr-7b-beta",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        self.chat_model = ChatHuggingFace(llm=self.llm, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))

    def build_prompt(self, question, context_urls):
        """Retrieve the most relevant chunk for `question` and inline it into a plain-text prompt."""
        data = Data(context_urls)
        context = data.retriever.invoke(question)[0].page_content
        prompt = f"""
Use the following piece of context to answer the question asked.
Please try to provide the answer only based on the context
{context}
Question:{question}
Helpful Answers:
"""
        return prompt

    def _build_prompt_rag(self):
        """Prompt template for the RetrievalQA chain; expects `context` and `question` variables."""
        prompt_template = """
Use the following piece of context to answer the question asked.
Please try to provide the answer only based on the context
{context}
Question:{question}
Helpful Answers:
"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        return prompt

    def _retrieval_qa(self, url):
        """Build a "stuff" RetrievalQA chain over the documents scraped from `url`."""
        data = Data([url])
        prompt = self._build_prompt_rag()
        return RetrievalQA.from_chain_type(
            llm=self.chat_model,
            chain_type="stuff",
            retriever=data.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def predict(self, message, history, url, max_tokens, temperature, top_p):
        """Answer `message` through the RetrievalQA chain.

        `history`, `max_tokens`, `temperature` and `top_p` are accepted for interface
        compatibility but are not used by the RAG path below.
        """
        history_langchain_format = [SystemMessage(content="You're a helpful python developer assistant")]
        for msg in history:
            if msg['role'] == "user":
                history_langchain_format.append(HumanMessage(content=msg['content']))
            elif msg['role'] == "assistant":
                history_langchain_format.append(AIMessage(content=msg['content']))
        history_langchain_format.append(HumanMessage(content=message))
        # ai_msg = self.chat_model.invoke(history_langchain_format)
        # return ai_msg.content
        ret = self._retrieval_qa(url)
        return ret.invoke({"query": message})['result']

    def respond(
        self,
        message,
        history: list[tuple[str, str]],
        url,
        max_tokens,
        temperature,
        top_p,
    ):
        """Stream an answer built from the retrieved context, yielding the growing response."""
        # The documentation URL doubles as the system message for the chat completion.
        messages = [{"role": "system", "content": url}]
        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
        messages.append({"role": "user", "content": self.build_prompt(message, [url])})

        response = ""
        for chunk in self.client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""  # delta content may be empty on some chunks
            response += token
            yield response


model = Model()
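

# A minimal usage sketch of the streaming path (runs only when this file is executed
# directly). The documentation URL, the question and the sampling values below are
# illustrative assumptions, not part of the application; HUGGINGFACEHUB_API_TOKEN must
# be set and the URL must be reachable for the calls to succeed.
if __name__ == "__main__":
    docs_url = "https://docs.python.org/3/library/pathlib.html"  # hypothetical example URL
    answer = ""
    for partial in model.respond(
        message="How do I join two paths with pathlib?",
        history=[],
        url=docs_url,
        max_tokens=256,
        temperature=0.7,
        top_p=0.95,
    ):
        answer = partial  # `respond` yields the accumulated response after each streamed chunk
    print(answer)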