from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline


class HuggingFaceQuestionAnswering:
    def __init__(self, retriever) -> None:
        self.retriever = retriever
        # Local Hugging Face text-generation pipeline wrapped as a LangChain LLM.
        self.llm = HuggingFacePipeline.from_model_id(
            # model_id="bigscience/bloom-1b7",
            model_id="bigscience/bloomz-1b1",
            task="text-generation",
            device=1,
            # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        self.chain = None

    def initialize(self):
        # Prompt kept for the chain_type_kwargs variant commented out below;
        # the chain actually used is rebuilt per query in answer_question.
        template = """Use the information contained in the following text: {context}. Complete the phrase: {question} """
        prompt_template = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )
        # self.chain = RetrievalQA.from_chain_type(self.llm, retriever=self.retriever.retriever, chain_type_kwargs={"prompt": prompt_template})

    def answer_question(self, question: str, filter_dict):
        # Build a retriever restricted by the metadata filter for this query.
        retriever = self.retriever.vector_store.db.as_retriever(
            search_kwargs={"filter": filter_dict, "fetch_k": 150}
        )
        try:
            self.chain = RetrievalQA.from_chain_type(
                self.llm, retriever=retriever, return_source_documents=True
            )
            result = self.chain({"query": question})
            docs = "\n".join(
                x.metadata["paper_title"][:40] + " - " + x.page_content[:40].replace("\n", " ") + "..."
                for x in result["source_documents"]
            )
            print(f"""
            Retrieved Documents:
            {docs if docs != "" else "No documents found."}""")
            return result
        except Exception:
            return {"result": "Error generating answer."}