base_models_rag / extractor.py
DeepVen's picture
Upload 8 files
8c3e214
import sqlite3

from fastapi import FastAPI
# from transformers import pipeline
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor
# NOTE: the interactive API docs are mounted at "/" so that they double as
# the landing page when the app is hosted on Spaces.
app = FastAPI(docs_url="/")

# Embeddings model with content support enabled.
EMBEDDINGS_CONFIG = {
    "path": "sentence-transformers/all-MiniLM-L6-v2",
    "content": True,
}
embeddings = Embeddings(EMBEDDINGS_CONFIG)

# No module-level Extractor: one is built per request inside `rag`.
def _stream(dataset, limit, index: int = 0):
    """Yield ``(id, text, None)`` tuples for at most *limit* rows of *dataset*.

    Args:
        dataset: Iterable of objects exposing a ``page_content`` attribute.
        limit: Maximum number of rows to yield. This is a row count; it is
            independent of the starting id.
        index: Id assigned to the first yielded row; later rows get
            consecutive ids.

    Yields:
        ``(index + n, row.page_content, None)`` for the n-th row.
    """
    # Count rows separately from the id offset: comparing the running id to
    # ``limit`` would cut the stream short whenever ``index`` is non-zero.
    # ``range`` comes first in ``zip`` so no extra row is pulled from
    # ``dataset`` once the cap is reached.
    for offset, row in zip(range(limit), dataset):
        yield (index + offset, row.page_content, None)
def _max_index_id(path):
    """Return the highest ``indexid`` stored in the ``sections`` table.

    Args:
        path: Path of the SQLite database file to read.

    Returns:
        ``{"max_index": value}`` where ``value`` is the largest ``indexid``,
        or ``None`` when the table has no rows.

    Raises:
        sqlite3.Error: If the query fails, e.g. the ``sections`` table is
            missing.
    """
    db = sqlite3.connect(path)
    try:
        # Let SQLite compute the maximum instead of loading every row.
        (max_index,) = db.execute("select max(indexid) from sections").fetchone()
    finally:
        # Always release the connection, even when the query fails.
        db.close()
    return {"max_index": max_index}
def _prompt(question):
    """Build the extractor prompt for *question*.

    Args:
        question: Text inserted after the ``Question:`` label.

    Returns:
        A three-line prompt: the instruction, the question line, and a
        trailing ``Context: `` label left open.
    """
    instruction = (
        "Answer the following question using only the context below. "
        "Say 'no answer' when the question can't be answered."
    )
    return "\n".join([instruction, f"Question: {question}", "Context: "])
async def _search(query, extractor, question=None):
    """Run a single "answer" task through *extractor* and return its answer.

    Args:
        query: Query text passed as the second item of the task tuple.
        extractor: Callable that accepts a list of task tuples.
        question: Text used to build the prompt; falls back to *query* when
            empty or ``None``.

    Returns:
        Element 1 of the first result returned by *extractor*.
    """
    effective_question = question or query
    task = ("answer", query, _prompt(effective_question), False)
    results = extractor([task])
    first_result = results[0]
    return first_result[1]
def _text_splitter(doc):
    """Split *doc* with a ``RecursiveCharacterTextSplitter``.

    Args:
        doc: Documents handed unchanged to ``transform_documents``.

    Returns:
        Whatever ``transform_documents`` returns for *doc*.
    """
    # chunk_size 500 with an overlap of 50, lengths measured with ``len``.
    splitter_options = {
        "chunk_size": 500,
        "chunk_overlap": 50,
        "length_function": len,
    }
    splitter = RecursiveCharacterTextSplitter(**splitter_options)
    return splitter.transform_documents(doc)
def _load_docs(path: str):
    """Load the document at *path* and return it split into chunks.

    Args:
        path: Location given to ``WebBaseLoader`` (presumably a URL; verify
            against callers).

    Returns:
        The loaded documents after passing through ``_text_splitter``.
    """
    return _text_splitter(WebBaseLoader(path).load())
async def _upsert_docs(doc):
    """Append *doc* to the embeddings index and save it under ``index``.

    Args:
        doc: Iterable of objects exposing ``page_content``; at most 500 are
            requested from ``_stream``.

    Returns:
        The module-level ``embeddings`` object after the upsert and save.
    """
    last_id = _max_index_id("index/documents")["max_index"]
    # The maximum may be missing (None or NaN) when nothing has been stored
    # yet. NaN is the only value not equal to itself, so this check needs no
    # extra imports.
    if last_id is None or last_id != last_id:
        next_id = 0
    else:
        # Start one past the highest stored id: reusing it would make the
        # upsert replace the last stored row instead of appending.
        # NOTE(review): assumes ids were assigned sequentially so that
        # ``indexid`` mirrors the row id - confirm.
        next_id = int(last_id) + 1
    embeddings.upsert(_stream(doc, 500, next_id))
    embeddings.save("index")
    return embeddings
@app.put("/rag/{path}")
async def get_doc_path(path: str):
    """Echo back the ``path`` segment of a ``PUT /rag/{path}`` request.

    Args:
        path: Value captured from the URL path.

    Returns:
        The same ``path`` value, unchanged.
    """
    return path
@app.get("/rag")
async def rag(question: str, path: str = ""):
    """Answer *question* from the saved index, optionally ingesting a page first.

    Args:
        question: Question to answer.
        path: Optional location of a document to load, split and upsert into
            the index before answering. Nothing is ingested when empty.

    Returns:
        A one-element set containing the extractor's answer.
    """
    # Use the module-level ``embeddings`` without rebinding it: assigning to
    # that name here would make it local and break this ``load`` call.
    embeddings.load("index")

    if path:
        docs = _load_docs(path)
        # ``_upsert_docs`` is a coroutine function, so it must be awaited for
        # the upsert and save to actually run.
        await _upsert_docs(docs)

    # Create extractor instance
    extractor = Extractor(embeddings, "google/flan-t5-base")
    answer = await _search(question, extractor)
    # NOTE(review): this is a set literal, kept to leave the response shape
    # unchanged; presumably ``{"answer": answer}`` was intended - confirm.
    return {answer}