import os

from fastapi import FastAPI
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor

# from transformers import pipeline

# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")

# @app.get("/generate")
# def generate(text: str):
#     """
#     Using the text2text-generation pipeline from `transformers`, generate text
#     from the given input text. The model used is `google/flan-t5-small`, which
#     can be found [here](https://huggingface.co/google/flan-t5-small).
#     """
#     output = pipe(text)
#     return {"output": output[0]["generated_text"]}


def _check_if_db_exists(db_path: str) -> bool:
    """Return True if a vector index already exists at db_path."""
    return os.path.exists(db_path)


def _load_embeddings_from_db(
    db_present: bool,
    domain: str,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
):
    """Create an embeddings model with content support and, when an index
    already exists on disk, load it for the given domain."""
    embeddings = Embeddings({"path": path, "content": True})

    # If no vector DB is present, return the empty embeddings model
    if not db_present:
        return embeddings

    if domain == "":
        embeddings.load("index")  # TODO: change this later
    else:
        embeddings.load(f"index/{domain}")
    return embeddings


def _prompt(question):
    return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
Question: {question}
Context: """


def _search(query, extractor, question=None):
    # Default question to query if empty
    if not question:
        question = query

    # The extractor takes (name, query, prompt, snippet) tuples and returns
    # (name, answer) tuples; return just the answer text
    return extractor([("answer", query, _prompt(question), False)])[0][1]


@app.get("/rag")
def rag(domain: str, question: str):
    """Answer a question using retrieval-augmented generation over the
    vector index for the given domain."""
    db_exists = _check_if_db_exists(
        db_path=os.path.join(os.getcwd(), "index", domain, "documents")
    )
    embeddings = _load_embeddings_from_db(db_exists, domain)

    # Create extractor instance backed by a seq2seq LLM
    extractor = Extractor(embeddings, "google/flan-t5-base")

    answer = _search(question, extractor)
    return {"question": question, "answer": answer}
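

# --- Index-building sketch ---
# The /rag endpoint assumes a txtai index has already been saved under
# index/{domain}. This is a minimal sketch of how such an index could be
# built; the "demo" domain name and the sample documents are hypothetical
# placeholders, not part of the app itself.
if __name__ == "__main__":
    sample_docs = [
        "txtai is an all-in-one embeddings database for semantic search.",
        "FastAPI is a modern, high-performance web framework for Python APIs.",
    ]

    demo = Embeddings(
        {"path": "sentence-transformers/all-MiniLM-L6-v2", "content": True}
    )
    # txtai indexes (id, data, tags) tuples
    demo.index([(i, text, None) for i, text in enumerate(sample_docs)])
    demo.save("index/demo")

    # The app can then be served and queried, e.g. (assuming this file is app.py):
    #   uvicorn app:app --reload
    #   curl "http://localhost:8000/rag?domain=demo&question=What+is+txtai"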