# base_models_rag / main.py
# Hugging Face page header (kept as comments so the module parses):
# uploaded by DeepVen — "Upload 8 files", commit 8c3e214 — raw | history | blame — No virus — 2.87 kB
from fastapi import FastAPI
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor
import os
from langchain import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# from transformers import pipeline
# NOTE: docs_url="/" serves FastAPI's interactive API docs at the root path
# of the app, so the docs double as the landing page when deployed on Spaces.
app: FastAPI = FastAPI(docs_url="/")
# @app.get("/generate")
# def generate(text: str):
# """
# Using the text2text-generation pipeline from `transformers`, generate text
# from the given input text. The model used is `google/flan-t5-small`, which
# can be found [here](https://huggingface.co/google/flan-t5-small).
# """
# output = pipe(text)
# return {"output": output[0]["generated_text"]}
def _check_if_db_exists(db_path: str) -> bool:
    """Report whether ``db_path`` exists on disk.

    Args:
        db_path: Filesystem path to probe (file or directory).

    Returns:
        ``True`` if the path exists, ``False`` otherwise (the result of
        ``os.path.exists``).
    """
    found = os.path.exists(db_path)
    return found
def _load_embeddings_from_db(
    db_present: bool,
    domain: str,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
):
    """Create a txtai ``Embeddings`` instance, loading a saved index if present.

    Args:
        db_present: Whether a saved vector DB exists on disk. When ``False``,
            the freshly constructed ``Embeddings`` object is returned without
            loading anything.
        domain: Index sub-directory under ``index/``. An empty string loads
            the top-level ``index`` directory instead.
        path: Embedding model passed to txtai as the ``"path"`` setting.

    Returns:
        The ``Embeddings`` instance, configured with ``"content": True``.
    """
    # Create the embeddings model with content support enabled.
    embeddings = Embeddings({"path": path, "content": True})

    # No saved vector DB: hand back the unloaded instance.
    if not db_present:
        return embeddings

    # TODO: the bare "index" location for an empty domain is provisional
    # (original note: "change this later").
    index_path = "index" if domain == "" else f"index/{domain}"
    embeddings.load(index_path)
    return embeddings
def _prompt(question):
    """Build the instruction prompt for the answer-generation model.

    Args:
        question: The question to embed in the prompt.

    Returns:
        A three-line prompt: the answering instructions, ``"Question: ..."``,
        and a trailing ``"Context: "`` line. Presumably the caller appends
        retrieved context after it; confirm against the Extractor in use.
    """
    instructions = (
        "Answer the following question using only the context below. "
        "Say 'no answer' when the question can't be answered."
    )
    lines = (instructions, f"Question: {question}", "Context: ")
    return "\n".join(lines)
def _search(query, extractor, question=None):
    """Ask ``extractor`` to answer a question, using ``query`` for retrieval.

    Args:
        query: Search text passed to the extractor as the retrieval query.
        extractor: Callable taking a list of 4-tuples (presumably txtai's
            ``(name, query, question, snippet)`` format; confirm against the
            installed txtai version).
        question: Question to put in the prompt. When falsy, ``query`` is
            used instead.

    Returns:
        Element ``[1]`` of the first item the extractor returns.
    """
    # A missing or empty question falls back to the search query itself.
    question = question or query
    request = ("answer", query, _prompt(question), False)
    results = extractor([request])
    return results[0][1]
@app.get("/rag")
def rag(domain: str, question: str):
    """Answer ``question`` against the vector index stored for ``domain``.

    Query parameters:
        domain: Index sub-directory under ``index/``. An empty string selects
            the top-level ``index`` directory.
        question: Natural-language question to answer.

    Returns:
        ``{"question": question, "answer": answer}``.

    Raises:
        HTTPException: 400 if ``domain`` contains a ``..`` path component.
    """
    from fastapi import HTTPException  # local import: only needed for validation

    # ``domain`` comes straight from the query string; refuse parent-directory
    # components so it cannot steer the index lookup outside ``index/``.
    if ".." in domain.replace("\\", "/").split("/"):
        raise HTTPException(status_code=400, detail="Invalid domain")

    # Forward slashes work on both POSIX and Windows. The former backslash
    # path (built with invalid escape sequences) never matched on Linux, so
    # a saved index was never found there. This path mirrors the relative
    # "index/{domain}" location that _load_embeddings_from_db loads from.
    db_exists = _check_if_db_exists(db_path=f"{os.getcwd()}/index/{domain}/documents")

    embeddings = _load_embeddings_from_db(db_exists, domain)
    # NOTE(review): when no saved index exists, the extractor runs over an
    # unloaded Embeddings instance; confirm the intended response in that case.
    extractor = Extractor(embeddings, "google/flan-t5-base")

    answer = _search(question, extractor)
    return {"question": question, "answer": answer}