File size: 2,873 Bytes
84947fc
 
 
8c3e214
 
 
 
93bc725
8c3e214
84947fc
 
 
 
 
8c3e214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84947fc
 
 
 
8c3e214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from fastapi import FastAPI
from txtai.embeddings import Embeddings
from txtai.pipeline import Extractor
import os
from langchain import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# NOTE - we configure docs_url to serve the interactive Docs at the root path
# of the app. This way, we can use the docs as a landing page for the app on Spaces.
app = FastAPI(docs_url="/")


def _check_if_db_exists(db_path: str) -> bool:
    return os.path.exists(db_path)


def _load_embeddings_from_db(
    db_present: bool,
    domain: str,
    path: str = "sentence-transformers/all-MiniLM-L6-v2",
):
    """Create a txtai ``Embeddings`` instance, loading a saved index if present.

    Args:
        db_present: Whether a saved index was found on disk. When False the
            freshly constructed, unloaded embeddings instance is returned.
        domain: Index sub-directory to load from ``index/<domain>``. An empty
            string loads the top-level ``index`` directory instead.
        path: Value passed to txtai as the ``path`` setting (the vector model).

    Returns:
        The ``Embeddings`` instance, created with ``content`` enabled.
    """
    # Create embeddings model with content support.
    embeddings = Embeddings({"path": path, "content": True})

    # No saved vector DB: hand back the empty instance.
    if not db_present:
        return embeddings

    # Relative paths; TODO: revisit the bare "index" fallback ("change this later").
    index_dir = "index" if domain == "" else f"index/{domain}"
    embeddings.load(index_dir)
    return embeddings


def _prompt(question):
    return f"""Answer the following question using only the context below. Say 'no answer' when the question can't be answered.
            Question: {question}
            Context: """


def _search(query, extractor, question=None):
    """Ask ``extractor`` one question and return only the answer text.

    ``question`` falls back to ``query`` when it is None or empty. The
    extractor receives a single queue entry
    ``("answer", query, _prompt(question), False)``. This is presumably txtai's
    ``(name, query, question, snippet)`` layout; confirm against the txtai
    Extractor docs. The second element of the first result is returned.
    """
    prompt_text = _prompt(question or query)
    queue = [("answer", query, prompt_text, False)]
    results = extractor(queue)
    first_result = results[0]
    return first_result[1]


@app.get("/rag")
def rag(domain: str, question: str):
    """Answer ``question`` from the saved txtai index for ``domain``.

    Args:
        domain: Index sub-directory under ``index/`` in the working directory.
            An empty string selects the top-level ``index`` directory.
        question: Used both as the search query and as the prompt question.

    Returns:
        ``{"question": question, "answer": <extractor answer>}``.
    """
    index_root = os.path.abspath(os.path.join(os.getcwd(), "index"))
    domain_dir = os.path.abspath(os.path.join(index_root, domain))
    # `domain` comes straight from the query string. Anything that resolves
    # outside index/ (e.g. "../x" or an absolute path) is treated as a missing
    # index instead of being loaded from elsewhere on disk.
    inside_index = domain_dir == index_root or domain_dir.startswith(
        index_root + os.sep
    )
    # os.path.join keeps the check portable; hard-coded backslashes only form a
    # valid path on Windows, so the index was never found on POSIX hosts.
    db_exists = inside_index and _check_if_db_exists(
        db_path=os.path.join(domain_dir, "documents")
    )
    embeddings = _load_embeddings_from_db(db_exists, domain)
    # NOTE(review): the embeddings and the flan-t5-base extractor are rebuilt on
    # every request; consider caching them if latency matters.
    extractor = Extractor(embeddings, "google/flan-t5-base")
    answer = _search(question, extractor)
    return {"question": question, "answer": answer}