gauravprasadgp committed • Commit d132e19 • 1 Parent(s): 1eec2d0

initial commit
- Dockerfile +9 -0
- README.md +77 -11
- generator/__init__.py +0 -0
- generator/llm_calls.py +33 -0
- main.py +52 -0
- pgvector/docker-compose.yml +14 -0
- pgvector/init.sql +9 -0
- requirements.txt +13 -0
- rerank/__init__.py +0 -0
- rerank/rerank.py +17 -0
- retrieve/__init__.py +0 -0
- retrieve/vector_store.py +63 -0
- utils/db.py +38 -0
Dockerfile
ADDED
@@ -0,0 +1,9 @@
FROM python:3.11

WORKDIR /app

COPY . /app/

RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,11 +1,77 @@
# Modular RAG

A hybrid approach to implementing RAG, inspired by Advanced RAG. It is usually implemented with modules that act as plug-and-play components.

## Documentation

#### Generator:
The core component of RAG, responsible for transforming the retrieved information into natural, human-readable language.

#### Retriever:
The "R" in RAG, responsible for retrieving the top-K elements from the knowledge base.

#### ReRank:
As the name suggests, a model used to re-rank the retrieved documents. It orders them by the similarity score between the question and each document returned by the vector search.
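Put together, the modules form a linear pipeline: retrieve, rerank, generate. A minimal sketch of that flow using this repo's own functions, mirroring the `/answer` route in `main.py` (it assumes the pgvector stack below is running and the documents are already indexed):

```python
import asyncio

from generator.llm_calls import get_answer
from rerank.rerank import rerank_documents
from retrieve.vector_store import get_relevant_document
from utils.db import postgres_db


async def answer(query: str):
    await postgres_db.create_connection_pool()  # normally done by the FastAPI lifespan
    docs = await get_relevant_document(query=query)            # Retriever: top-K from pgvector
    ranked = rerank_documents(question=query, documents=docs)  # ReRank: cross-encoder ordering
    context = "\n\n".join(ranked)
    result = await get_answer(query=query, context=context)    # Generator: LLM completion
    await postgres_db.close_connection_pool()
    return result


print(asyncio.run(answer("What is modular RAG?")))
```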
## Run Locally

Clone the project

```bash
git clone https://github.com/gauravprasadgp/modular-rag
```

Go to the project directory

```bash
cd modular-rag
```

Install dependencies

```bash
pip install -r requirements.txt
```

Run Postgres locally

```bash
cd pgvector
```

```bash
docker compose up -d
```

Start the server

```bash
python main.py
```
## API Reference

#### Upload file to create embedding

```http
POST /create
```

| Parameter | Type   | Description                  |
|:----------|:-------|:-----------------------------|
| `file`    | `file` | **Required**. File to upload |

#### Get answer from user query

```http
POST /answer
```

| Parameter | Type     | Description              |
|:----------|:---------|:-------------------------|
| `query`   | `string` | **Required**. User query |
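A hypothetical client session against a locally running server (host and port match the `uvicorn` call in `main.py`; `sample.pdf` is a placeholder file name, and the `requests` package is assumed to be installed):

```python
import requests

BASE = "http://localhost:7860"

# Upload a document so its chunks get embedded into pgvector.
with open("sample.pdf", "rb") as f:
    requests.post(f"{BASE}/create", files={"file": f})

# Ask a question against the indexed documents.
resp = requests.post(f"{BASE}/answer", json={"query": "What is this document about?"})
print(resp.text)
```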
## License

[MIT](https://choosealicense.com/licenses/mit/)
generator/__init__.py
ADDED
File without changes
generator/llm_calls.py
ADDED
@@ -0,0 +1,33 @@
from functools import lru_cache

# Consistent non-legacy imports: LlamaCPP ships in llama-index-llms-llama-cpp,
# HuggingFaceEmbedding in llama-index-embeddings-huggingface (see requirements.txt).
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.llms.llama_cpp.llama_utils import (
    messages_to_prompt,
    completion_to_prompt,
)

# Llama 2 13B chat, 4-bit GGML quantization, served through llama.cpp.
llm = LlamaCPP(
    model_url="https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/resolve/main/llama-2-13b-chat.ggmlv3.q4_0"
    ".bin",
    temperature=0.1,
    max_new_tokens=256,
    context_window=3900,
    generate_kwargs={},
    model_kwargs={"n_gpu_layers": 1},
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
    verbose=True,
)


@lru_cache(maxsize=1)
def get_embed_model():
    # BGE small English embedding model, loaded once and reused across calls.
    return HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")


async def get_answer(query, context):
    # Stuff the reranked context into a plain QA prompt and complete asynchronously.
    prompt = f"""Given the context below answer the question.
    Context: {context}
    Question: {query}
    Answer:
    """
    return await llm.acomplete(prompt=prompt)
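As a standalone check of the generator module, a hypothetical snippet (importing `generator.llm_calls` downloads the multi-GB GGML model on first use, so this assumes enough disk and RAM):

```python
import asyncio

from generator.llm_calls import get_answer


async def main():
    completion = await get_answer(
        query="What does RAG stand for?",
        context="Retrieval-Augmented Generation (RAG) combines a retriever with a generator.",
    )
    print(completion)


asyncio.run(main())
```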
main.py
ADDED
@@ -0,0 +1,52 @@
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware

from generator.llm_calls import get_answer
from rerank.rerank import rerank_documents
from retrieve.vector_store import create_embeddings_from_file, get_relevant_document
from utils.db import postgres_db


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Open the Postgres pool on startup, close it on shutdown.
    await postgres_db.create_connection_pool()
    yield
    await postgres_db.close_connection_pool()


# The handler above only runs if it is passed to FastAPI via the lifespan argument.
app = FastAPI(title="Modular RAG", version="1.0.0", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/create")
async def create_embedding(file: UploadFile):
    await create_embeddings_from_file(file)


@app.post("/answer")
async def post_conversation(request: Request):
    payload = await request.json()
    query = payload.get("query")
    # Retrieve -> rerank -> generate: the core modular pipeline.
    context = await get_relevant_document(query=query)
    sorted_docs = rerank_documents(question=query, documents=context)
    sorted_context = "\n\n".join(sorted_docs)
    return await get_answer(context=sorted_context, query=query)


@app.get("/")
async def get_test(request: Request):
    return "Successfully Deployed"


if __name__ == "__main__":
    # Run last so all routes above are registered before the server starts.
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
pgvector/docker-compose.yml
ADDED
@@ -0,0 +1,14 @@
services:
  db:
    hostname: db
    image: pgvector/pgvector:pg16
    ports:
      - 5432:5432
    restart: always
    environment:
      - POSTGRES_DB=vectordb
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
      - POSTGRES_HOST_AUTH_METHOD=trust
    volumes:
      - ./init.sql:/docker-entrypoint-initdb.d/init.sql
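To confirm the container is up and the extension loaded, a quick check with `psycopg` (a hypothetical snippet; credentials are the compose defaults above):

```python
import psycopg

# Connect with the defaults from docker-compose.yml and ask Postgres
# which version of the pgvector extension is installed.
with psycopg.connect("dbname=vectordb user=user password=password host=localhost port=5432") as conn:
    row = conn.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector';").fetchone()
    print("pgvector version:", row[0] if row else "not installed")
```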
pgvector/init.sql
ADDED
@@ -0,0 +1,9 @@
CREATE EXTENSION IF NOT EXISTS vector;

-- Named to match the INSERT/SELECT statements in retrieve/vector_store.py,
-- which query document_embedding.
CREATE TABLE IF NOT EXISTS document_embedding (
    id SERIAL PRIMARY KEY,
    embedding vector,
    document text,
    metadata jsonb,
    created_at timestamptz DEFAULT now()
);
requirements.txt
ADDED
@@ -0,0 +1,13 @@
langchain
psycopg[binary,pool]
pgvector
langchain-community
uvicorn[standard]
pypdf
fastapi[all]
python-multipart
pydantic_settings
llama-index
gunicorn
llama-index-llms-llama-cpp
llama-index-embeddings-huggingface  # for HuggingFaceEmbedding in generator/llm_calls.py
FlagEmbedding
rerank/__init__.py
ADDED
File without changes
rerank/rerank.py
ADDED
@@ -0,0 +1,17 @@
from FlagEmbedding import FlagReranker

# Cross-encoder reranker; fp16 halves memory use with little quality loss.
reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)


def rerank_documents(question: str, documents: list[str]) -> list[str]:
    # Score each (question, document) pair with the cross-encoder.
    pairs = [(question, doc) for doc in documents]
    scores = reranker.compute_score(pairs)
    if isinstance(scores, float):  # compute_score returns a bare float for a single pair
        scores = [scores]
    # Sort documents by relevance score, highest first, and keep the top 7,
    # returning plain strings so main.py can join them into a context block.
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked[:7]]
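A quick standalone run of the reranker (hypothetical example; importing the module downloads `BAAI/bge-reranker-large` on first use):

```python
from rerank.rerank import rerank_documents

docs = [
    "Paris is the capital of France.",
    "FlagEmbedding provides BGE rerankers.",
    "The Eiffel Tower is in Paris.",
]
# The two Paris sentences should rank above the unrelated one.
print(rerank_documents("What is the capital of France?", docs))
```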
retrieve/__init__.py
ADDED
File without changes
retrieve/vector_store.py
ADDED
@@ -0,0 +1,63 @@
import datetime
import json
import tempfile
from pathlib import Path

from fastapi import UploadFile
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.schema import BaseNode

from generator.llm_calls import get_embed_model
from utils.db import postgres_db


def to_vector_literal(embedding: list[float]) -> str:
    # pgvector accepts the text form '[x1,x2,...]' for vector values.
    return "[" + ",".join(str(x) for x in embedding) + "]"


async def create_embeddings_from_file(file: UploadFile):
    # SimpleDirectoryReader works on file paths, so spool the upload to a temp file first.
    suffix = Path(file.filename or "upload").suffix
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    documents = await SimpleDirectoryReader(input_files=[tmp_path]).aload_data()
    node_parser = SentenceWindowNodeParser.from_defaults(
        window_size=3,
        window_metadata_key="window",
        original_text_metadata_key="original_text",
    )

    nodes = node_parser.get_nodes_from_documents(documents)
    embed_model = get_embed_model()  # load the model once, not per node
    for node in nodes:
        node.embedding = await embed_model.aget_text_embedding(node.get_content())
    # Persist the embedded nodes; without this call nothing is stored.
    await insert_documents(nodes)


def get_values_from_nodes(nodes: list[BaseNode]):
    values = []
    dt = datetime.datetime.now()
    for node in nodes:
        value = (to_vector_literal(node.embedding), node.get_content(), json.dumps(node.metadata), dt)
        values.append(value)
    return values


async def insert_documents(nodes):
    try:
        values = get_values_from_nodes(nodes)
        # pool.connection() checks a connection out of the pool and commits on
        # success / rolls back on error when the block exits.
        async with postgres_db.db_pool.connection() as conn:
            async with conn.cursor() as cur:
                await cur.executemany(
                    """
                    INSERT INTO document_embedding (embedding, document, metadata, created_at)
                    VALUES (%s::vector, %s, %s, %s);
                    """,
                    values,
                )
    except Exception as error:
        print(f"insert document exception {error}")


async def get_relevant_document(query: str):
    embedded_question = await get_embed_model().aget_query_embedding(query=query)
    try:
        async with postgres_db.db_pool.connection() as conn:
            async with conn.cursor() as cur:
                # <=> is pgvector's cosine-distance operator; 1 - distance = similarity.
                # The embedding is passed as a parameter, not interpolated into the SQL.
                await cur.execute(
                    """
                    SELECT metadata -> 'window',
                           1 - (embedding <=> %s::vector) AS cosine_similarity
                    FROM document_embedding
                    ORDER BY cosine_similarity DESC
                    LIMIT 10;
                    """,
                    (to_vector_literal(embedded_question),),
                )
                results = await cur.fetchall()
                return [row[0] for row in results]
    except Exception as error:
        print(f"retrieve document exception {error}")
        return []
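A sketch of a retrieval-only round trip through this module (assumes the compose stack is running and documents have been indexed via `/create`):

```python
import asyncio

from retrieve.vector_store import get_relevant_document
from utils.db import postgres_db


async def main():
    await postgres_db.create_connection_pool()  # normally handled by the FastAPI lifespan
    docs = await get_relevant_document("What is modular RAG?")
    print(f"retrieved {len(docs)} window snippets")
    await postgres_db.close_connection_pool()


asyncio.run(main())
```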
utils/db.py
ADDED
@@ -0,0 +1,38 @@
import os
from typing import Optional

from psycopg_pool import AsyncConnectionPool


def get_conn_str():
    # Connection settings come from the environment, with the
    # docker-compose defaults as fallbacks.
    return f"""
    dbname={os.getenv('POSTGRES_DB') or "vectordb"}
    user={os.getenv('POSTGRES_USER') or "user"}
    password={os.getenv('POSTGRES_PASSWORD') or "password"}
    host={os.getenv('POSTGRES_HOST') or "localhost"}
    port={os.getenv('POSTGRES_PORT') or "5432"}
    """


class PostgresDatabase:
    def __init__(self):
        # No pool until create_connection_pool() runs (the original assigned
        # Optional[None], which is a typing construct, not a value).
        self.db_pool: Optional[AsyncConnectionPool] = None

    async def create_connection_pool(self):
        try:
            # Opening in the constructor is deprecated for async pools:
            # create closed, then open explicitly.
            self.db_pool = AsyncConnectionPool(conninfo=get_conn_str(), open=False)
            await self.db_pool.open()
        except Exception as error:
            print(f"DB connection error {error}")

    async def close_connection_pool(self):
        try:
            if self.db_pool:
                await self.db_pool.close()
        except Exception as error:
            print(f"Error in closing db connection {error}")


postgres_db = PostgresDatabase()