from fastapi import FastAPI

# from transformers import pipeline
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor

from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

import pandas as pd
import sqlite3
import os

# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")
# app = FastAPI()
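
# A minimal sketch of how this app is typically served locally; the module name "app"
# and port 7860 are assumptions (7860 is the usual Spaces default), not taken from this repo:
#   uvicorn app:app --host 0.0.0.0 --port 7860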

# pipe = pipeline("text2text-generation", model="google/flan-t5-small")


# @app.get("/generate")
# def generate(text: str):
#     """
#     Using the text2text-generation pipeline from `transformers`, generate text
#     from the given input text. The model used is `google/flan-t5-small`, which
#     can be found [here](https://huggingface.co/google/flan-t5-small).
#     """
#     output = pipe(text)
#     return {"output": output[0]["generated_text"]}
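
# If the /generate route above were re-enabled, it could be exercised with a request
# like the one below (sketch; the host and port are assumptions, not taken from this repo):
#   curl "http://localhost:7860/generate?text=What+is+the+capital+of+France%3F"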


def load_embeddings(
    domain: str = "",
    db_present: bool = True,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
    index_name: str = "index",
):
    # Create embeddings model with content support
    embeddings = Embeddings({"path": path, "content": True})

    # If the vector DB is not present yet, return the empty embeddings instance
    if not db_present:
        return embeddings

    # Otherwise load the existing index, either the default one or a per-domain index
    if domain == "":
        embeddings.load(index_name)  # change this later
    else:
        print("Loading domain-specific index")
        embeddings.load(f"{index_name}/{domain}")
    return embeddings
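
# Example usage (sketch; the "docs" domain is a made-up placeholder and assumes an index
# was previously saved under ./index/docs):
#   embeddings = load_embeddings(domain="docs", db_present=True)
#   results = embeddings.search("what does the document cover?", 3)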


def _check_if_db_exists(db_path: str) -> bool:
    return os.path.exists(db_path)


def _text_splitter(doc):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
    )
    return text_splitter.transform_documents(doc)


def _load_docs(path: str):
    load_doc = WebBaseLoader(path).load()
    doc = _text_splitter(load_doc)
    return doc
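
# Example usage (sketch; the URL is a placeholder): load a page, split it into chunks,
# then inspect the first chunk.
#   chunks = _load_docs("https://example.com/article.html")
#   print(len(chunks), chunks[0].page_content[:100])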


def _stream(dataset, limit, index: int = 0):
    # Yield (id, data, tags) tuples in the format expected by txtai index()/upsert()
    for row in dataset:
        yield (index, row.page_content, None)
        index += 1

        if index >= limit:
            break


def _max_index_id(path):
    # txtai stores indexed content in a SQLite database; the "sections" table holds
    # one row per chunk with its integer "indexid"
    db = sqlite3.connect(path)

    table = "sections"
    df = pd.read_sql_query(f"select * from {table}", db)
    db.close()
    return {"max_index": df["indexid"].max()}


def _upsert_docs(doc, embeddings, vector_doc_path: str, db_present: bool):
    print(vector_doc_path)

    if db_present:
        # Existing index: continue numbering from the current max index id, then upsert and re-save
        max_index = _max_index_id(f"{vector_doc_path}/documents")
        print(max_index)
        embeddings.upsert(_stream(doc, 500, max_index["max_index"]))
        print("Embeddings upserted!!")
        embeddings.save(vector_doc_path)
        print("Embeddings saved!!")
    else:
        # No index yet: build a fresh one starting at id 0, then save it
        embeddings.index(_stream(doc, 500, 0))
        embeddings.save(vector_doc_path)
        max_index = _max_index_id(f"{vector_doc_path}/documents")
        print(max_index)

    return max_index
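
# End-to-end indexing sketch (the "docs" domain and URL below are placeholders):
#   embeddings = load_embeddings(domain="docs", db_present=False)
#   doc = _load_docs("https://example.com/article.html")
#   _upsert_docs(doc, embeddings, vector_doc_path=f"{os.getcwd()}/index/docs", db_present=False)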

# def prompt(question):
#     return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
#     Question: {question}
#     Context: """


# def search(query, question=None):
#     # Default question to query if empty
#     if not question:
#         question = query
#     return extractor([("answer", query, prompt(question), False)])[0][1]


# @app.get("/rag")
# def rag(question: str):
#     # question = "what is the document about?"
#     answer = search(question)
#     # print(question, answer)
#     return {answer}


# @app.get("/index")
# def get_url_file_path(url_path: str):
#     embeddings = load_embeddings()
#     doc = _load_docs(url_path)
#     embeddings, max_index = _upsert_docs(doc, embeddings)
#     return max_index


def get_domain_file_path(domain: str, file_path: str):
    print(domain, file_path)
    print(os.getcwd())

    db_exists = _check_if_db_exists(db_path=f"{os.getcwd()}/index/{domain}/documents")
    print(db_exists)

    # load_embeddings and _upsert_docs both branch on db_present themselves, so the same
    # call sequence covers both the "index already exists" and "new index" cases
    embeddings = load_embeddings(domain=domain, db_present=db_exists)
    print(embeddings)
    doc = _load_docs(file_path)
    max_index = _upsert_docs(
        doc=doc,
        embeddings=embeddings,
        vector_doc_path=f"{os.getcwd()}/index/{domain}",
        db_present=db_exists,
    )
    # print("Final - output : ", max_index)
    return "Executed Successfully!!"
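
# Example call (sketch; "docs" and the URL are placeholder values), which indexes
# the page under ./index/docs:
#   get_domain_file_path("docs", "https://example.com/article.html")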


def _load_embeddings_from_db(
    db_present: bool,
    domain: str,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
):
    # Create embeddings model with content support
    embeddings = Embeddings({"path": path, "content": True})

    # If the vector DB is not present yet, return the empty embeddings instance
    if not db_present:
        return embeddings

    # Otherwise load the existing index, either the default one or a per-domain index
    if domain == "":
        embeddings.load("index")  # change this later
    else:
        print("Loading domain-specific index")
        embeddings.load(f"{os.getcwd()}/index/{domain}")
    return embeddings


def _prompt(question):
    return f"""Answer the following question using only the context below. Say 'Could not find answer within the context' when the question can't be answered.
    Question: {question}
    Context: """


def _search(query, extractor, question=None):
    # Default question to query if empty
    if not question:
        question = query

    # template = f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
    # Question: {question}
    # Context: """

    # prompt = PromptTemplate(template=template, input_variables=["question"])
    # llm_chain = LLMChain(prompt=prompt, llm=extractor)

    # return {"question": question, "answer": llm_chain.run(question)}

    # Each work item is a (name, query, question, snippet) tuple; the Extractor returns
    # (name, answer) tuples, so [0][1] pulls the answer text out of the first result
    return extractor([("answer", query, _prompt(question), False)])[0][1]
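
# Example flow (sketch; the "docs" domain and the question are placeholders, and this
# assumes an index already exists under ./index/docs):
#   emb = _load_embeddings_from_db(True, "docs")
#   ext = Extractor(emb, "google/flan-t5-base")
#   print(_search("What is the page about?", ext))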


def rag(domain: str, question: str):
    db_exists = _check_if_db_exists(db_path=f"{os.getcwd()}/index/{domain}/documents")
    print(db_exists)

    # if db_exists:
    embeddings = _load_embeddings_from_db(db_exists, domain)

    # Create extractor instance
    # extractor = Extractor(embeddings, "google/flan-t5-base")
    # extractor = Extractor(embeddings, "TheBloke/Llama-2-7B-GGUF")
    extractor = Extractor(embeddings, "google/flan-t5-xl")

    # llm = HuggingFaceHub(
    #     repo_id="google/flan-t5-xxl",
    #     model_kwargs={"temperature": 1, "max_length": 1000000},
    # )

    # else:
    answer = _search(question, extractor)
    return {"question": question, "answer": answer}
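

# A minimal sketch of how these helpers could be exercised directly; the domain name,
# URL and question below are placeholders, and running this will download both the
# embedding model and flan-t5-xl, so treat it as an illustration rather than a test.
if __name__ == "__main__":
    get_domain_file_path("docs", "https://example.com/article.html")
    print(rag("docs", "What is the article about?"))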