import os

import requests
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_core.vectorstores import VectorStoreRetriever


class LLMModel:
    """Retrieval-augmented question answering over a local GGUF Llama 2 model
    or the Hugging Face Inference API."""

    # Local quantized model, loaded through CTransformers.
    base_model = "TheBloke/Llama-2-7B-GGUF"
    specific_model = "llama-2-7b.Q4_K_M.gguf"
    # Hub id of the original full-precision model.
    token_model = "meta-llama/Llama-2-7b-hf"
    llm_config = {'context_length': 2048, 'max_new_tokens': 1024, 'temperature': 0.3, 'top_p': 1.0}

    question_answer_system_prompt = """You are a helpful question-answering assistant. Given the following context and a question, provide a set of potential questions and answers.
Keep answers brief and well-structured. Do not give one-word answers."""
    final_assistant_system_prompt = """You are a helpful assistant. Given the following list of relevant questions and answers, generate an answer based on this list only.
Keep answers brief and well-structured. Do not give one-word answers.
If the answer is not found in the list, kindly state "I don't know." Don't try to make up an answer."""
    # Llama 2 chat format: a <<SYS>> system block wrapped in [INST] ... [/INST].
    template = """<s>[INST] <<SYS>>
You are a question-answering assistant. Given the following context and a question, generate an answer based on this context only.
Keep answers brief and well-structured. Do not give one-word answers.
If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
<</SYS>>

Context: {context}

Question: Give me a step by step explanation of {question}[/INST]
Answer:"""
    qa_chain_prompt = PromptTemplate.from_template(template)
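
    # For inspection, the prompt can be rendered directly with placeholder values:
    #   qa_chain_prompt.format(context="...", question="...")
    # which yields the complete [INST] ... [/INST] string sent to the model.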
    retriever = None

    # Hugging Face Inference API credentials, read from the environment.
    hf_token = os.getenv('HF_TOKEN')
    api_url = os.getenv('API_URL')
    headers = {"Authorization": f"Bearer {hf_token}"}
    client = InferenceClient(api_url)

    llm = None

    def __init__(self, retriever: VectorStoreRetriever):
        self.retriever = retriever

    def create_qa_chain(self):
        # Load the quantized local model lazily on first use, from the model
        # files named in the class attributes.
        if self.llm is None:
            self.llm = CTransformers(model=self.base_model, model_file=self.specific_model, config=self.llm_config)
        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.qa_chain_prompt},
        )
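
    # Hypothetical call pattern for the chain returned above (names are placeholders):
    #   chain = model.create_qa_chain()
    #   out = chain({"query": "..."})
    # With return_source_documents=True, out carries both the generated "result"
    # and the retrieved "source_documents".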

    def format_retrieved_docs(self, docs):
        # Keep only chunks whose metadata carries a "source" reference.
        all_docs = []
        for doc in docs:
            if "source" in doc.metadata:
                all_docs.append(f"Document: {doc.metadata['source']}\nContent: {doc.page_content}\n\n")
        return all_docs

    def format_query(self, question, context, system_prompt):
        prompt = f"""[INST] {system_prompt}

Context: {context}

Question: Give me a step by step explanation of {question}[/INST]"""
        return prompt

    def format_question(self, question):
        relevant_docs = self.retriever.get_relevant_documents(question)
        # Join the per-document strings so the prompt receives plain text rather
        # than the repr of a Python list.
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        return self.format_query(question, formatted_docs, self.final_assistant_system_prompt)

    def get_potential_question_answer(self, document_chunk: str):
        prompt = self.format_query("potential questions and answers.", document_chunk, self.question_answer_system_prompt)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference_text_gen(self, question):
        prompt = self.format_question(question)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference(self, question):
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        if not formatted_docs:
            return "No uploaded documents. Please try uploading a document on the left side."
        print(formatted_docs)  # debug: show the context sent to the QA endpoint
        return self.client.question_answering(question=question, context=formatted_docs)

    def answer_question_api(self, question):
        formatted_prompt = self.format_question(question)
        resp = requests.post(self.api_url, headers=self.headers, json={"inputs": formatted_prompt}, stream=True)
        resp.raise_for_status()
        # Stream chunks back as they arrive instead of one byte at a time
        # (requests' iter_content default chunk size is 1).
        for chunk in resp.iter_content(chunk_size=None):
            yield chunk
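

# A minimal usage sketch, assuming faiss-cpu and sentence-transformers are
# installed; the sample text, source name, and questions are placeholders.
if __name__ == "__main__":
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(
        ["Quantization shrinks model weights to lower-precision types."],
        embeddings,
        # format_retrieved_docs drops chunks without a "source" key, so set one.
        metadatas=[{"source": "notes.txt"}],
    )
    model = LLMModel(vector_store.as_retriever())

    # Local route: downloads the GGUF weights on first use.
    qa = model.create_qa_chain()
    print(qa({"query": "quantization"})["result"])

    # Hosted route: requires HF_TOKEN and API_URL in the environment.
    print(model.answer_question_inference("what quantization does"))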
|