diff --git a/.github/workflows/manual-pushing-to-HF1.yml b/.github/workflows/manual-pushing-to-HF1.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4ea9c6813ec207f3a469ea816fc044306e3ce79 --- /dev/null +++ b/.github/workflows/manual-pushing-to-HF1.yml @@ -0,0 +1,26 @@ +name: Sync to Hugging Face Hub +on: + push: + branches: [for_testing] + workflow_dispatch: +jobs: + sync-to-hub: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + lfs: true + - name: Configure Git identity + run: | + git config --global user.name "Andrchest" + git config --global user.email "andreipolevoi220@gmail.com" + - name: Push to Hugging Face + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + git checkout -b hf-single-commit + git reset --soft $(git rev-list --max-parents=0 HEAD) + git commit -m "Single commit for Hugging Face" + git remote add hf https://Andrchest:$HF_TOKEN@huggingface.co/spaces/The-Ultimate-RAG-HF/RAG-Integration-test + git push --force hf hf-single-commit:main \ No newline at end of file diff --git a/.github/workflows/sync-to-hf.yml b/.github/workflows/sync-to-hf.yml new file mode 100644 index 0000000000000000000000000000000000000000..dec51180ba78897480300b92e0d7c05603faf3c4 --- /dev/null +++ b/.github/workflows/sync-to-hf.yml @@ -0,0 +1,93 @@ +name: Sync to Hugging Face Hub +on: + push: + branches: [main] + workflow_dispatch: +jobs: + sync-to-hub: + runs-on: ubuntu-latest + environment: Integration test + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + lfs: true + - name: Configure Git identity + run: | + git config --global user.name "Andrchest" + git config --global user.email "andreipolevoi220@gmail.com" + - name: Push to HF1 + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + git checkout -b hf1-single-commit + git reset --soft $(git rev-list --max-parents=0 HEAD) + git commit -m "Single commit for HF1" + git remote add hf1 
https://Andrchest:$HF_TOKEN@huggingface.co/spaces/The-Ultimate-RAG-HF/RAG-Integration-test + git push --force hf1 hf1-single-commit:main + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pytest-cov + pip install -r app/requirements.txt + - name: Wait for HF1 deployment + run: sleep 120 + - name: Debug environment variables + env: + DATABASE_URL: ${{ secrets.DATABASE_URL }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + SECRET_PEPPER: ${{ secrets.SECRET_PEPPER }} + JWT_ALGORITHM: ${{ secrets.JWT_ALGORITHM }} + PYTHONPATH: ${{ github.workspace }} + run: | + echo "DATABASE_URL is set: ${DATABASE_URL:+set}" + echo "GEMINI_API_KEY is set: ${GEMINI_API_KEY:+set}" + echo "SECRET_PEPPER is set: ${SECRET_PEPPER:+set}" + echo "JWT_ALGORITHM is set: ${JWT_ALGORITHM:+set}" + env | grep -E 'DATABASE_URL|GEMINI_API_KEY|SECRET_PEPPER|JWT_ALGORITHM' + - name: Initialize directories + env: + DATABASE_URL: ${{ secrets.DATABASE_URL }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + SECRET_PEPPER: ${{ secrets.SECRET_PEPPER }} + JWT_ALGORITHM: ${{ secrets.JWT_ALGORITHM }} + PYTHONPATH: ${{ github.workspace }} + working-directory: ./ + run: | + python -m app.initializer + - name: Debug directory structure + run: | + ls -R + - name: Run integration tests with coverage + env: + HF1_URL: ${{ secrets.HF1_URL }} + DATABASE_URL: ${{ secrets.DATABASE_URL }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + SECRET_PEPPER: ${{ secrets.SECRET_PEPPER }} + JWT_ALGORITHM: ${{ secrets.JWT_ALGORITHM }} + PYTHONPATH: ${{ github.workspace }} + working-directory: ./ + run: | + echo "PYTHONPATH: $PYTHONPATH" + python -m pytest app/tests/integration/test.py -v --cov=app --cov-report=xml --cov-report=html + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: integration-coverage-report + path: | + coverage.xml + htmlcov/ + - name: Push 
to HF2 if tests pass + if: success() + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + git checkout -b hf2-single-commit + git reset --soft $(git rev-list --max-parents=0 HEAD) + git commit -m "Single commit for HF2" + git remote add hf2 https://Andrchest:$HF_TOKEN@huggingface.co/spaces/The-Ultimate-RAG-HF/The-Ultimate-RAG + git push --force hf2 hf2-single-commit:main diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..91f9a4afbd1a984228266eda698d4cfe5144c769 --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,40 @@ +name: Unit Tests +on: + pull_request: + branches: + - main +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install coverage + pip install -r app/requirements.txt + pip install flake8 pytest + - name: Run linter + run: | + flake8 app/ --max-line-length=160 --extend-ignore=E203 + - name: Run unit tests with coverage + run: | + coverage run -m pytest app/tests/unit/test.py + coverage xml + coverage html + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: unit-coverage-report + path: | + coverage.xml + htmlcov/ + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..40dc34d7cb5d95c88e34ff16914fa2127348df49 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +__pycache__ +/app/temp_storage +/database +/new_env +/prompt.txt +/app/key.py +/app/env_vars.py +/chats_storage +/.env +exp.* +response.txt \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 
0000000000000000000000000000000000000000..1dcce6cf0be73126567205b146cb917f03a7e128 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +# syntax=docker/dockerfile:1 +FROM python:3.12.10 + +RUN useradd -m -u 1000 user +USER user +ENV PATH="/home/user/.local/bin:$PATH" + +WORKDIR /app + +# copy and install Python reqs +COPY app/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r /app/requirements.txt + +# download Qdrant binary +RUN wget https://github.com/qdrant/qdrant/releases/download/v1.11.5/qdrant-x86_64-unknown-linux-gnu.tar.gz \ + && tar -xzf qdrant-x86_64-unknown-linux-gnu.tar.gz \ + && mv qdrant /home/user/.local/bin/qdrant \ + && rm qdrant-x86_64-unknown-linux-gnu.tar.gz + +COPY --chown=user . /app + +RUN chmod +x start.sh + +CMD ["./start.sh"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c5c67d807ffbfde903a637b5cf4192810b23063e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Danil Popov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md index 26e69188274102146f53ba337f1d2ff7127946b8..07b478e1d155152549990a9feb32ae17ee01582a 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,86 @@ +--- +title: The Ultimate RAG +emoji: 🌍 +colorFrom: pink +colorTo: indigo +sdk: docker +pinned: false +short_description: the ultimate rag +--- + # The-Ultimate-RAG -[S25] Software project for Innopolis University + +## Overview + +[S25] The Ultimate RAG is an Innopolis University software project that generates cited responses from a local database. + +## Prerequisites + +Before you begin, ensure the following is installed on your machine: + +- [Python](https://www.python.org/) +- [Docker](https://www.docker.com/get-started/) + +## Installation + +1. **Clone the repository** + ```bash + git clone https://github.com/PopovDanil/The-Ultimate-RAG + cd The-Ultimate-RAG + ``` +2. **Set up a virtual environment (recommended)** + + To isolate project dependencies and avoid conflicts, create a virtual environment: + - **On Unix/Linux/macOS:** + ```bash + python3 -m venv env + source env/bin/activate + ``` + - **On Windows:** + ```bash + python -m venv env + env\Scripts\activate + ``` +3. **Install required libraries** + + Within the activated virtual environment, install the dependencies: + ```bash + pip install -r ./app/requirements.txt + ``` + *Note:* ensure you are in the virtual environment before running the command + +4. 
**Set up Docker** + - Ensure Docker is running on your machine + - Open a terminal, navigate to project directory, and run: + ```bash + docker-compose up --build + ``` + *Note:* The initial build may take 10–20 minutes, as it needs to download large language models and other + dependencies. + Later launches will be much faster. + +5. **Server access** + + Once the containers are running, visit `http://localhost:5050`. You should see the application’s welcome page + +To stop the application and shut down all containers, press `Ctrl+C` in the terminal where `docker-compose` is running, +and then run: + +```bash + docker-compose down +``` + +## Usage + +1. **Upload your file:** click the upload button and select a supported file (`.txt`, `.doc`, `.docx`, or `.pdf`) +2. **Ask a question**: Once the file is processed, type your question into the prompt box and submit. +3. **Receive your answer** + +**A note on performance** + +Response generation is a computationally intensive task. +The time to receive an answer may vary depending on your machine's hardware and the complexity of the query. + +## License + +This project is licensed under the [MIT License](LICENSE). 
\ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/api/api.py b/app/api/api.py new file mode 100644 index 0000000000000000000000000000000000000000..ab30d4f10abfd8d5cec4b4b615a4f4efdf777969 --- /dev/null +++ b/app/api/api.py @@ -0,0 +1,260 @@ +from app.backend.controllers.messages import register_message +from app.core.document_validator import path_is_valid +from app.core.response_parser import add_links +from app.backend.models.users import User +from app.settings import BASE_DIR +from app.backend.controllers.chats import ( + get_chat_with_messages, + create_new_chat, + update_title, + list_user_chats +) +from app.backend.controllers.users import ( + extract_user_from_context, + get_current_user, + get_latest_chat, + refresh_cookie, + authorize_user, + check_cookie, + create_user +) +from app.core.utils import ( + construct_collection_name, + create_collection, + extend_context, + initialize_rag, + save_documents, + protect_chat, + TextHandler, + PDFHandler, +) + +from fastapi.templating import Jinja2Templates +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi import ( + HTTPException, + UploadFile, + Request, + Depends, + FastAPI, + Form, + File, +) +from fastapi.responses import ( + StreamingResponse, + RedirectResponse, + FileResponse, + JSONResponse, +) + +from typing import Optional +import os + +# <------------------------------------- API -------------------------------------> +api = FastAPI() +rag = initialize_rag() + +origins = [ + "*", +] + +api.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + 
+api.mount( + "/chats_storage", + StaticFiles(directory=os.path.join(BASE_DIR, "chats_storage")), + name="chats_storage", +) +api.mount( + "/static", + StaticFiles(directory=os.path.join(BASE_DIR, "app", "frontend", "static")), + name="static", +) +templates = Jinja2Templates( + directory=os.path.join(BASE_DIR, "app", "frontend", "templates") +) + + +# <--------------------------------- Middleware ---------------------------------> +@api.middleware("http") +async def require_user(request: Request, call_next): + print("&" * 40, "START MIDDLEWARE", "&" * 40) + try: + print(f"Path ----> {request.url.path}, Method ----> {request.method}, Port ----> {request.url.port}\n") + + stripped_path = request.url.path.strip("/") + + if ( + stripped_path.startswith("pdfs") + or "static/styles.css" in stripped_path + or "favicon.ico" in stripped_path + ): + return await call_next(request) + + user = get_current_user(request) + authorized = True + if user is None: + authorized = False + user = create_user() + + print(f"User in Context ----> {user.id}\n") + + request.state.current_user = user + response = await call_next(request) + + if authorized: + refresh_cookie(request=request, response=response) + else: + authorize_user(response, user) + return response + + except Exception as exception: + raise exception + finally: + print("&" * 40, "END MIDDLEWARE", "&" * 40, "\n\n") + + +# <--------------------------------- Common routes ---------------------------------> +@api.post("/message_with_docs") +async def send_message( + request: Request, + files: list[UploadFile] = File(None), + prompt: str = Form(...), + chat_id: str = Form(None), +) -> StreamingResponse: + status = 200 + try: + user = extract_user_from_context(request) + print("-" * 100, "User ---->", user, "-" * 100, "\n\n") + collection_name = construct_collection_name(user, chat_id) + + message_id = register_message(content=prompt, sender="user", chat_id=chat_id) + + await save_documents( + collection_name, files=files, 
RAG=rag, user=user, chat_id=chat_id, message_id=message_id + ) + + return StreamingResponse( + rag.generate_response_stream( + collection_name=collection_name, user_prompt=prompt, stream=True + ), + status, + media_type="text/event-stream", + ) + except Exception as e: + print(e) + + +@api.post("/replace_message") +async def replace_message(request: Request): + data = await request.json() + with open(os.path.join(BASE_DIR, "response.txt"), "w") as f: + f.write(data.get("message", "")) + updated_message = data.get("message", "") + register_message( + content=updated_message, sender="system", chat_id=data.get("chatId") + ) + return JSONResponse({"updated_message": updated_message}) + + +@api.get("/viewer/{path:path}") +def show_document( + request: Request, + path: str, + page: Optional[int] = 1, + lines: Optional[str] = "1-1", + start: Optional[int] = 0, +): + print(f"DEBUG: Show document with path: {path}, page: {page}, lines: {lines}, start: {start}") + path = os.path.realpath(path) + print(f"DEBUG: Real path: {path}") + + path = os.path.realpath(path) + if not path_is_valid(path): + return HTTPException(status_code=404, detail="Document not found") + + ext = path.split(".")[-1] + if ext == "pdf": + print("Open pdf file by path") + return FileResponse(path=path) + elif ext in ("txt", "csv", "md", "json"): + print("Open txt file by path") + return TextHandler(request, path=path, lines=lines, templates=templates) + elif ext in ("docx", "doc"): + return TextHandler( + request, path=path, lines=lines, templates=templates + ) + else: + return FileResponse(path=path) + + +# <--------------------------------- Get ---------------------------------> +@api.get("/list_chats") +def list_chats_for_user(request: Request): + user = extract_user_from_context(request) + chats = list_user_chats(user.id) + print(f"Chats for user {user.id}: {chats}") + return JSONResponse({"chats": chats}) + + +@api.get("/chats/{chat_id}") +def show_chat(request: Request, chat_id: str): + user = 
extract_user_from_context(request) + + if not protect_chat(user, chat_id): + raise HTTPException(401, "Yod do not have rights to use this chat!") + + chat_data = get_chat_with_messages(chat_id) + + print(f"DEBUG: Data for chat '{chat_id}' from get_chat_with_messages: {chat_data}") + + if not chat_data: + raise HTTPException(status_code=404, detail=f"Chat with id {chat_id} not found.") + + update_title(chat_data["chat_id"]) + + return JSONResponse(content=chat_data) + + +@api.get("/") +def last_user_chat(request: Request): + user = extract_user_from_context(request) + chat = get_latest_chat(user) + + if chat is None: + print("new_chat") + new_chat = create_new_chat("new chat", user) + url = new_chat.get("url") + + try: + create_collection(user, new_chat.get("chat_id"), rag) + except Exception as e: + raise HTTPException(500, e) + + else: + url = f"/chats/{chat.id}" + + return RedirectResponse(url, status_code=303) + + +# <--------------------------------- Post ---------------------------------> +@api.post("/new_chat") +def create_chat(request: Request, title: Optional[str] = "new chat"): + user = extract_user_from_context(request) + new_chat_data = create_new_chat(title, user) + if not new_chat_data.get("id"): + raise HTTPException(500, "New chat could not be created.") + + create_collection(user, new_chat_data["id"], rag) + + return JSONResponse(new_chat_data) + +if __name__ == "__main__": + pass diff --git a/app/automigration.py b/app/automigration.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fe2864bbbad06137cec2801c21cc705295111e --- /dev/null +++ b/app/automigration.py @@ -0,0 +1,4 @@ +from app.backend.models.db_service import automigrate + +if __name__ == "__main__": + automigrate() diff --git a/app/backend/__init__.py b/app/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/backend/controllers/__init__.py 
b/app/backend/controllers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/backend/controllers/base_controller.py b/app/backend/controllers/base_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..229fd2e4d6b038d61a1d2690f4a01dde326009a8 --- /dev/null +++ b/app/backend/controllers/base_controller.py @@ -0,0 +1,5 @@ +from app.settings import settings +from sqlalchemy import create_engine + +postgres_config = settings.postgres.model_dump() +engine = create_engine(**postgres_config) diff --git a/app/backend/controllers/chats.py b/app/backend/controllers/chats.py new file mode 100644 index 0000000000000000000000000000000000000000..750233347ff867c053cc3e069629ba137ac4c890 --- /dev/null +++ b/app/backend/controllers/chats.py @@ -0,0 +1,117 @@ +from app.backend.models.messages import get_messages_by_chat_id, Message +from app.backend.models.users import User, get_user_chats +from app.backend.models.documents import Document +from app.backend.controllers.utils import get_group_title +from app.settings import BASE_DIR +from app.backend.models.chats import ( + get_chats_by_user_id, + get_chat_by_id, + refresh_title, + add_new_chat, +) + +from datetime import datetime, timedelta +from fastapi import HTTPException +from uuid import uuid4 +import os + + +def create_new_chat(title: str | None, user: User) -> dict: + print("+" * 40, "START Creating Chat", "+" * 40) + try: + chat_id = str(uuid4()) + add_new_chat(id=chat_id, title=title, user=user) + try: + path_to_chat = os.path.join( + BASE_DIR, + "chats_storage", + f"user_id={user.id}", + f"chat_id={chat_id}", + "documents", + ) + os.makedirs(path_to_chat, exist_ok=True) + except Exception: + raise HTTPException(500, "error while creating chat folders") + + return {"id": chat_id, "title": title} + except Exception as exception: + raise exception + finally: + print("+" * 40, "END Creating Chat", "+" * 40, 
"\n\n") + + +def dump_documents_dict(documents: list[Document]) -> list[map]: + output = [] + for doc in documents: + output.append({"name": doc.name, "path": doc.path, "size": doc.size}) + print("Add document --->", doc.name) + return output + + +def dump_messages_dict(messages: list[Message], dst: dict) -> None: + history = [] + + print("!" * 40, "START Dumping History", "!" * 40) + for message in messages: + history.append({"sender": message.sender, "content": message.content, "documents": dump_documents_dict(message.documents)}) + print(f"Role ----> {message.sender}, Content ----> {message.content}\n") + print("!" * 40, "END Dumping History", "!" * 40, "\n\n") + + dst.update({"messages": history}) + + +def get_chat_with_messages(id: str) -> dict: + response = {"chat_id": id} + + chat = get_chat_by_id(id=id) + if chat is None: + raise HTTPException(418, f"Invalid chat id. Chat with id={id} does not exists!") + + messages = get_messages_by_chat_id(id=id) + dump_messages_dict(messages, response) + + return response + + +def create_dict_from_chat(chat) -> dict: + return {"id": chat.id, "title": chat.title} + + +def list_user_chats(user_id: str) -> list[dict]: + current_date = datetime.now() + + today = [] + last_week = [] + last_month = [] + later = [] + + groups = [today, last_week, last_month, later] + + chats = get_chats_by_user_id(user_id) + for chat in chats: + if current_date - timedelta(days=1) <= chat.created_at: + today.append(chat) + elif current_date - timedelta(weeks=1) <= chat.created_at: + last_week.append(chat) + elif current_date - timedelta(weeks=4) <= chat.created_at: + last_month.append(chat) + else: + later.append(chat) + + result = [] + + for id, group in enumerate(groups): + if len(group): + result.append( + {"title": get_group_title(id=id), "chats": [create_dict_from_chat(chat) for chat in group]} + ) + + return result + + +def verify_ownership_rights(user: User, chat_id: str) -> bool: + return chat_id in [chat.id for chat in 
get_user_chats(user)] + + +def update_title(chat_id: str) -> bool: + return refresh_title(chat_id) diff --git a/app/backend/controllers/messages.py b/app/backend/controllers/messages.py new file mode 100644 index 0000000000000000000000000000000000000000..f7da29b8fad29972980699914e186694a72bf038 --- /dev/null +++ b/app/backend/controllers/messages.py @@ -0,0 +1,26 @@ +from app.backend.models.messages import add_new_message +from uuid import uuid4 +import re + + +def remove_html_tags(content: str) -> str: + pattern = "<(.*?)>" + replace_with = ( + "click me" + ) + de_taggeed = re.sub(pattern, "REPLACE_WITH_RICKROLL", content) + + return de_taggeed.replace("REPLACE_WITH_RICKROLL", replace_with) + + +def register_message(content: str, sender: str, chat_id: str) -> str: + print("-" * 40, "START Registering Message", "-" * 40) + try: + id = str(uuid4()) + message = content if sender == "system" else remove_html_tags(content) + + print(f"Message -----> {message[:min(30, len(message))]}") + + return add_new_message(id=id, chat_id=chat_id, sender=sender, content=message) + finally: + print("-" * 40, "END Registering Message", "-" * 40, "\n\n") diff --git a/app/backend/controllers/users.py b/app/backend/controllers/users.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa2faf98ccdfab9f6ef9128470bb97025667e62 --- /dev/null +++ b/app/backend/controllers/users.py @@ -0,0 +1,164 @@ +from app.backend.models.chats import Chat +from app.settings import settings +from app.backend.models.users import ( + get_user_last_chat, + find_user_by_id, + add_new_user, + User, +) + +from fastapi import Response, Request, HTTPException +from datetime import datetime, timedelta, timezone + +from uuid import uuid4 +import jwt + + +def extract_user_from_context(request: Request) -> User | None: + if hasattr(request.state, "current_user"): + return request.state.current_user + print("*" * 40, "No attribute 'current_user`", "*" * 40, "\n") + return None + + +def 
create_access_token(user_id: str, expires_delta: timedelta = settings.max_cookie_lifetime) -> str: + token_payload = {"user_id": user_id} + token_payload.update({"exp": datetime.now() + expires_delta}) + + try: + encoded_jwt: str = jwt.encode( + token_payload, settings.secret_pepper, algorithm=settings.jwt_algorithm + ) + except Exception: + raise HTTPException(status_code=500, detail="json encoding error") + + print("^" * 40, "New JWT token was created", "^" * 40) + print(encoded_jwt) + print("^" * 105, "\n\n") + + return encoded_jwt + + +def create_user() -> User | None: + new_user_id = str(uuid4()) + try: + user = add_new_user(id=new_user_id) + except Exception as e: + raise HTTPException(status_code=418, detail=e) + + print("$" * 40, "New User was created", "$" * 40) + print("Created user - {user.id}") + print("$" * 100, "\n\n") + + return user + + +def authorize_user(response: Response, user: User) -> dict: + print("%" * 40, "START Authorizing User", "%" * 40) + try: + access_token: str = create_access_token(user_id=user.id) + expires = datetime.now(timezone.utc) + settings.max_cookie_lifetime + + response.set_cookie( + key="access_token", + value=access_token, + path="/", + expires=expires.strftime("%a, %d %b %Y %H:%M:%S GMT"), + max_age=settings.max_cookie_lifetime, + httponly=True, + secure=True, + samesite='None' + ) + + return {"status": "ok"} + finally: + print("%" * 40, "END Authorizing User", "%" * 40) + + +def get_current_user(request: Request) -> User | None: + print("-" * 40, "START Getting User", "-" * 40) + try: + user = None + token: str | None = request.cookies.get("access_token") + + print(f"Token -----> {token if token else 'Empty token!'}\n") + + if not token: + return None + + try: + user_id = jwt.decode( + jwt=bytes(token, encoding="utf-8"), + key=settings.secret_pepper, + algorithms=[settings.jwt_algorithm], + ).get("user_id") + + print(f"User id -----> {user_id if user_id else 'Empty user id!'}\n") + + user = find_user_by_id(id=user_id) + 
+ print(f"Found user -----> {user.id if user else 'No user was found!'}") + except Exception as e: + raise e + + if not user: + return None + + return user + except HTTPException as exception: + raise exception + finally: + print("-" * 40, "END Getting User", "-" * 40, "\n\n") + + +def check_cookie(request: Request) -> dict: + result = {"token": "No token is present"} + token = request.cookies.get("access_token") + if token: + result["token"] = token + return result + + +def clear_cookie(response: Response) -> dict: + response.set_cookie(key="access_token", value="", httponly=True) + return {"status": "ok"} + + +def get_latest_chat(user: User) -> Chat | None: + return get_user_last_chat(user) + + +def refresh_cookie(request: Request, response: Response) -> None: + print("+" * 40, "START Refreshing cookie", "+" * 40) + try: + token: str | None = request.cookies.get("access_token") + + print(f"Token -----> {token if token else 'Empty token!'}\n") + + if token is None: + return + + try: + jwt_token = jwt.decode( + jwt=bytes(token, encoding="utf-8"), + key=settings.secret_pepper, + algorithms=[settings.jwt_algorithm], + ) + exp_datetime = datetime.fromtimestamp(jwt_token.get("exp"), tz=timezone.utc) + print(f"Expires -----> {exp_datetime if exp_datetime else 'No expiration date!'}\n") + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="jwt signature has expired") + except jwt.PyJWTError as e: + raise HTTPException(status_code=500, detail=e) + + diff = exp_datetime - datetime.now(timezone.utc) + print(f"Difference -----> {diff if diff else 'No difference in date!'}\n") + + if diff.total_seconds() < 0.2 * settings.max_cookie_lifetime.total_seconds(): + print("<----- Refreshing ----->") + user = extract_user_from_context(request) + authorize_user(response, user) + except HTTPException as exception: + raise exception + finally: + print("+" * 40, "END Refreshing cookie", "+" * 40, "\n\n") diff --git a/app/backend/controllers/utils.py 
b/app/backend/controllers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc94d940dc3f8ffc0ab5a2a26f5084fcb1b77691 --- /dev/null +++ b/app/backend/controllers/utils.py @@ -0,0 +1,13 @@ +def get_group_title(id: int) -> str: + result = "LATER" + + if id == 0: + result = "TODAY" + elif id == 1: + result = "LAST_WEEK" + elif id == 2: + result = "LAST_MONTH" + elif id == 3: + result = "LATER" + + return result diff --git a/app/backend/models/__init__.py b/app/backend/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/backend/models/base_model.py b/app/backend/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..30e5db43a880fb261c34e72fc0eed7b79a233d01 --- /dev/null +++ b/app/backend/models/base_model.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, DateTime +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.sql import func + + +class Base(DeclarativeBase): + ''' + Base model for all others \\ + Defines base for table creation + ''' + __abstract__ = True + created_at = Column("created_at", DateTime, default=func.now()) + deleted_at = Column("deleted_at", DateTime, nullable=True) + updated_at = Column("updated_at", DateTime, nullable=True) diff --git a/app/backend/models/chats.py b/app/backend/models/chats.py new file mode 100644 index 0000000000000000000000000000000000000000..868f0e20ed27c79a36c9e93a183816fa7d4cf29a --- /dev/null +++ b/app/backend/models/chats.py @@ -0,0 +1,51 @@ +from app.backend.models.base_model import Base +from sqlalchemy import String, Column, ForeignKey +from sqlalchemy.orm import relationship, Session +from app.backend.controllers.base_controller import engine + + +class Chat(Base): + __tablename__ = "chats" + id = Column("id", String, primary_key=True, unique=True) + title = Column("title", String, nullable=True) + user_id = Column(String, ForeignKey("users.id")) + user = 
relationship("User", back_populates="chats") + messages = relationship("Message", back_populates="chat") + + +def add_new_chat(id: str, title: str | None, user) -> None: + with Session(autoflush=False, bind=engine) as db: + user = db.merge(user) + new_chat = Chat(id=id, user_id=user.id, user=user) + if title: + new_chat.title = title + db.add(new_chat) + db.commit() + + +def get_chat_by_id(id: str) -> Chat | None: + with Session(autoflush=False, bind=engine) as db: + return db.query(Chat).where(Chat.id == id).first() + + +def get_chats_by_user_id(id: str) -> list[Chat]: + with Session(autoflush=False, bind=engine) as db: + return ( + db.query(Chat).filter(Chat.user_id == id).order_by(Chat.created_at.desc()) + ) + + +def refresh_title(chat_id: str) -> bool: + with Session(autoflush=False, bind=engine) as db: + chat = db.get(Chat, chat_id) + messages = chat.messages + + if messages is None or len(messages) == 0: + return False + + chat.title = messages[0].content[:47] + if len(messages[0].content) > 46: + chat.title += "..." 
+ + db.commit() + return True diff --git a/app/backend/models/db_service.py b/app/backend/models/db_service.py new file mode 100644 index 0000000000000000000000000000000000000000..66937a583e43b661577d5c8aedeb6190f092dcda --- /dev/null +++ b/app/backend/models/db_service.py @@ -0,0 +1,37 @@ +from sqlalchemy import inspect +from app.backend.controllers.base_controller import engine +from app.backend.models.base_model import Base +from app.backend.models.chats import Chat +from app.backend.models.messages import Message +from app.backend.models.users import User +from app.backend.models.documents import Document + + +def table_exists(name: str) -> bool: + return inspect(engine).has_table(name) + + +def create_tables() -> None: + Base.metadata.create_all(engine) + + +def drop_tables() -> None: + # List tables in the correct order for dropping (considering dependencies) + tables = [Document.__table__, Message.__table__, Chat.__table__, User.__table__] + + for table in tables: + if table_exists(table.name): + try: + table.drop(engine) + print(f"Dropped table {table.name}") + except Exception as e: + print(f"Error dropping table {table.name}: {e}") + else: + print(f"Table {table.name} does not exist, skipping drop") + + +def automigrate() -> None: + print("Starting automigration...") + drop_tables() + create_tables() + print("Automigration completed.") \ No newline at end of file diff --git a/app/backend/models/documents.py b/app/backend/models/documents.py new file mode 100644 index 0000000000000000000000000000000000000000..15a341f6c76dda59e9b75a2db4c8fcb014826a3e --- /dev/null +++ b/app/backend/models/documents.py @@ -0,0 +1,22 @@ +from sqlalchemy import Column, ForeignKey, String, Text, Integer +from sqlalchemy.orm import Session, relationship + +from app.backend.controllers.base_controller import engine +from app.backend.models.base_model import Base + + +class Document(Base): + __tablename__ = "documents" + id = Column('id', String, primary_key=True, unique=True) + 
def get_messages_by_chat_id(id: str) -> list[Message]:
    """Return a chat's messages in chronological order, with their attached
    documents eagerly loaded.

    Bug fix: the original returned the lazy ``Query``; the session closes
    when the ``with`` block exits, so consuming the query later fails.
    ``.all()`` executes it while the session is open, and ``joinedload``
    keeps ``documents`` usable on the detached instances.
    """
    with Session(autoflush=False, bind=engine) as db:
        return (
            db.query(Message)
            .options(joinedload(Message.documents))
            .filter(Message.chat_id == id)
            .order_by(asc(Message.created_at))
            .all()
        )
def get_user_last_chat(user: User) -> Chat | None:
    """Return the most recently appended chat of ``user``, or ``None``.

    Re-fetches the user inside a fresh session so the ``chats``
    relationship is loaded, then takes the last element of the collection.
    """
    with Session(autoflush=False, bind=engine) as db:
        fresh = db.get(User, user.id)
        owned = fresh.chats

        # Empty (or missing) collection -> nothing to return.
        if not owned:
            return None

        return owned[-1]
class LanguageOptions(str, Enum):
    """Preferred response language for a user.

    Stored as a custom-defined field in the ``users`` table; values are
    ISO 639-1 language codes.
    """

    AR = "ar"  # Arabic
    EN = "en"  # English
    RU = "ru"  # Russian
({len(self.text)})...{self.text[-20:]}\n" + ) diff --git a/app/core/database.py b/app/core/database.py new file mode 100644 index 0000000000000000000000000000000000000000..64d8df278ec8828402cf652fabb00c23472202f4 --- /dev/null +++ b/app/core/database.py @@ -0,0 +1,233 @@ +from qdrant_client import QdrantClient # main component to provide the access to db +from qdrant_client.http.models import ( + ScoredPoint, + Filter, + FieldCondition, + MatchText +) +from qdrant_client.models import ( + VectorParams, + Distance, + PointStruct, + TextIndexParams, + TokenizerType +) # VectorParams -> config of vectors that will be used as primary keys +from app.core.models import Embedder # Distance -> defines the metric +from app.core.chunks import Chunk # PointStruct -> instance that will be stored in db +import numpy as np +from uuid import UUID +from app.settings import settings +import time +from fastapi import HTTPException +import re + + +class VectorDatabase: + def __init__(self, embedder: Embedder, host: str = "qdrant", port: int = 6333): + self.host: str = host + self.client: QdrantClient = self._initialize_qdrant_client() + self.embedder: Embedder = embedder # embedder is used to convert a user's query + self.already_stored: np.array[np.array] = np.array([]).reshape( + 0, embedder.get_vector_dimensionality() + ) + + def store( + self, collection_name: str, chunks: list[Chunk], batch_size: int = 1000 + ) -> None: + points: list[PointStruct] = [] + + print("Start getting text embeddings") + start = time.time() + vectors = self.embedder.encode([chunk.get_raw_text() for chunk in chunks]) + print(f"Embeddings - {time.time() - start}") + + for vector, chunk in zip(vectors, chunks): + if self.accept_vector(collection_name, vector): + points.append( + PointStruct( + id=str(chunk.id), + vector=vector, + payload={ + "metadata": chunk.get_metadata(), + "text": chunk.get_raw_text(), + }, + ) + ) + + if len(points): + for group in range(0, len(points), batch_size): + 
self.client.upsert( + collection_name=collection_name, + points=points[group : group + batch_size], + wait=False, + ) + + """ + Measures a cosine of angle between tow vectors + """ + + def cosine_similarity(self, vec1: list[float], vec2: list[float] | list[list[float]]) -> float: + if len(vec2) == 0: + return 0 + + vec1_np = np.array(vec1) + vec2_np = np.array(vec2) + + if vec2_np.ndim == 2: + vec2_np = vec2_np.T + + similarities = np.array(vec1_np @ vec2_np / (np.linalg.norm(vec1_np) * np.linalg.norm(vec2_np, axis=0))) + return np.max(similarities) + + """ + Defines weather the vector should be stored in the db by searching for the most + similar one + """ + + def accept_vector(self, collection_name: str, vector: np.array) -> bool: + most_similar = self.client.query_points( + collection_name=collection_name, query=vector, limit=1, with_vectors=True + ).points + + if not len(most_similar): + return True + else: + most_similar = most_similar[0] + + if 1 - self.cosine_similarity(vector, most_similar.vector) < settings.max_delta: + return False + return True + + def construct_keywords_list(self, query: str) -> list[FieldCondition]: + keywords = re.findall(r'\b[A-Z]{2,}\b', query) + filters = [] + + print(keywords) + + for word in keywords: + if len(word) > 30 or len(word) < 2: + continue + filters.append(FieldCondition(key="text", match=MatchText(text=word))) + + return filters + + def combine_points_without_duplications(self, first: list[ScoredPoint], second: list[ScoredPoint] = None) -> list[ScoredPoint]: + combined = [] + similarity_vectors = [] + + to_combine = [first] + if second is not None: + to_combine.append(second) + + for group in to_combine: + for point in group: + if 1 - self.cosine_similarity(point.vector, similarity_vectors) > min(settings.max_delta, 0.2): + combined.append(point) + similarity_vectors.append(point.vector) + return combined + + def search(self, collection_name: str, query: str, top_k: int = 5) -> list[Chunk]: + query_embedded: np.ndarray 
= self.embedder.encode(query) + + if isinstance(query_embedded, list): + query_embedded = query_embedded[0] + + keywords = self.construct_keywords_list(query) + + mixed_result: list[ScoredPoint] = self.client.query_points( + collection_name=collection_name, query=query_embedded, limit=top_k + int(top_k * 0.3), + query_filter=Filter(should=keywords), with_vectors=True + ).points + + print(f"Len of original array -> {len(mixed_result)}") + combined = self.combine_points_without_duplications(mixed_result) + print(f"Len of combined array -> {len(combined)}") + + return [ + Chunk( + id=UUID(point.payload.get("metadata", {}).get("id", "")), + filename=point.payload.get("metadata", {}).get("filename", ""), + page_number=point.payload.get("metadata", {}).get("page_number", 0), + start_index=point.payload.get("metadata", {}).get("start_index", 0), + start_line=point.payload.get("metadata", {}).get("start_line", 0), + end_line=point.payload.get("metadata", {}).get("end_line", 0), + text=point.payload.get("text", ""), + ) + for point in combined + ] + + def _initialize_qdrant_client(self, max_retries=5, delay=2) -> QdrantClient: + for attempt in range(max_retries): + try: + client = QdrantClient(**settings.qdrant.model_dump()) + client.get_collections() + return client + except Exception as e: + if attempt == max_retries - 1: + raise HTTPException( + 500, + f"Failed to connect to Qdrant server after {max_retries} attempts. " + f"Last error: {str(e)}", + ) + + print( + f"Connection attempt {attempt + 1} out of {max_retries} failed. " + f"Retrying in {delay} seconds..." + ) + + time.sleep(delay) + delay *= 2 + + def _check_collection_exists(self, collection_name: str) -> bool: + try: + return self.client.collection_exists(collection_name) + except Exception as e: + raise HTTPException( + 500, + f"Failed to check collection {collection_name} exists. 
Last error: {str(e)}", + ) + + def _create_collection(self, collection_name: str) -> None: + try: + self.client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=self.embedder.get_vector_dimensionality(), + distance=Distance.COSINE, + ), + ) + self.client.create_payload_index( + collection_name=collection_name, + field_name="text", + field_schema=TextIndexParams( + type="text", + tokenizer=TokenizerType.WORD, + min_token_len=2, + max_token_len=30, + lowercase=True + ) + ) + except Exception as e: + raise HTTPException( + 500, f"Failed to create collection {self.collection_name}: {str(e)}" + ) + + def create_collection(self, collection_name: str) -> None: + try: + if self._check_collection_exists(collection_name): + return + self._create_collection(collection_name) + except Exception as e: + print(e) + raise HTTPException(500, e) + + def __del__(self): + if hasattr(self, "client"): + self.client.close() + + def get_collections(self) -> list[str]: + try: + return self.client.get_collections() + except Exception as e: + print(e) + raise HTTPException(500, "Failed to get collection names") diff --git a/app/core/document_validator.py b/app/core/document_validator.py new file mode 100644 index 0000000000000000000000000000000000000000..739a297f4b5061193c96ba106583967eb5757d20 --- /dev/null +++ b/app/core/document_validator.py @@ -0,0 +1,9 @@ +import os + +""" +Checks if the given path is valid and file exists +""" + + +def path_is_valid(path: str) -> bool: + return os.path.exists(path) diff --git a/app/core/main.py b/app/core/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ee31712c773ffb8ce1c0d7fab0892486e6c6c8 --- /dev/null +++ b/app/core/main.py @@ -0,0 +1,50 @@ +from app.settings import settings, BASE_DIR +import uvicorn +import os +from app.backend.models.db_service import automigrate + + +def initialize_system() -> bool: + success = True + path = BASE_DIR + temp_storage_path = os.path.join(path, 
"app", "temp_storage") + static_path = os.path.join(path, "static") + pdfs_path = os.path.join(path, "app", "temp_storage", "pdfs") + database_path = os.path.join(path, "database") + chats_storage_path = os.path.join(path, "chats_storage") + + print(f"Base path: {BASE_DIR}") + print(f"Parent path: {path}") + print(f"Temp storage path: {temp_storage_path}") + print(f"Static path: {static_path}") + print(f"PDFs path: {pdfs_path}") + print(f"Database path: {database_path}") + print(f"Database path: {chats_storage_path}") + + try: + os.makedirs(temp_storage_path, exist_ok=True) + print("Created temp_storage_path") + os.makedirs(static_path, exist_ok=True) + print("Created static_path") + os.makedirs(pdfs_path, exist_ok=True) + print("Created pdfs_path") + os.makedirs(database_path, exist_ok=True) + print("Created database_path") + os.makedirs(chats_storage_path, exist_ok=True) + print("Created chats_storage_path") + except Exception as e: + success = False + print(f"Error creating directories: {str(e)}") + + return success + + +def main(): + # automigrate() # Note: it will drop all existing dbs and create a new ones + initialize_system() + uvicorn.run(**settings.api.model_dump()) + + +if __name__ == "__main__": + # ATTENTION: run from base dir ---> python -m app.main + main() diff --git a/app/core/models.py b/app/core/models.py new file mode 100644 index 0000000000000000000000000000000000000000..842163cbd8f4cefe14eea8c07c321681463653db --- /dev/null +++ b/app/core/models.py @@ -0,0 +1,214 @@ +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from dotenv import load_dotenv +from sentence_transformers import ( + SentenceTransformer, + CrossEncoder, +) # SentenceTransformer -> model for embeddings, CrossEncoder -> re-ranker +from ctransformers import AutoModelForCausalLM +from torch import Tensor +from google import genai +from google.genai import types +from app.core.chunks import Chunk +from app.settings import settings, BASE_DIR, 
GeminiEmbeddingSettings + +load_dotenv() + + +class Embedder: + def __init__(self, model: str = "BAAI/bge-m3"): + self.device: str = settings.device + self.model_name: str = model + self.model: SentenceTransformer = SentenceTransformer(model, device=self.device) + + """ + Encodes string to dense vector + """ + + def encode(self, text: str | list[str]) -> Tensor | list[Tensor]: + return self.model.encode(sentences=text, show_progress_bar=False, batch_size=32) + + """ + Returns the dimensionality of dense vector + """ + + def get_vector_dimensionality(self) -> int | None: + return self.model.get_sentence_embedding_dimension() + + +class Reranker: + def __init__(self, model: str = "cross-encoder/ms-marco-MiniLM-L6-v2"): + self.device: str = settings.device + self.model_name: str = model + self.model: CrossEncoder = CrossEncoder(model, device=self.device) + + """ + Returns re-sorted (by relevance) vector with dicts, from which we need only the 'corpus_id' + since it is a position of chunk in original list + """ + + def rank(self, query: str, chunks: list[Chunk]) -> list[dict[str, int]]: + return self.model.rank(query, [chunk.get_raw_text() for chunk in chunks]) + + +# TODO: add models parameters to global config file +# TODO: add exception handling when response have more tokens than was set +# TODO: find a way to restrict the model for providing too long answers + + +class LocalLLM: + def __init__(self): + self.model = AutoModelForCausalLM.from_pretrained( + **settings.local_llm.model_dump() + ) + + """ + Produces the response to user's prompt + + stream -> flag, determines weather we need to wait until the response is ready or can show it token by token + + TODO: invent a way to really stream the answer (as return value) + """ + + def get_response( + self, + prompt: str, + stream: bool = True, + logging: bool = True, + use_default_config: bool = True, + ) -> str: + + with open("../prompt.txt", "w") as f: + f.write(prompt) + + generated_text = "" + tokenized_text: 
list[int] = self.model.tokenize(text=prompt) + response: list[int] = self.model.generate( + tokens=tokenized_text, **settings.local_llm.model_dump() + ) + + if logging: + print(response) + + if not stream: + return self.model.detokenize(response) + + for token in response: + chunk = self.model.detokenize([token]) + generated_text += chunk + if logging: + print(chunk, end="", flush=True) # flush -> clear the buffer + + return generated_text + + +class GeminiLLM: + def __init__(self, model="gemini-2.0-flash"): + self.client = genai.Client(api_key=settings.api_key) + self.model = model + + def get_response( + self, + prompt: str, + stream: bool = True, + logging: bool = True, + use_default_config: bool = False, + ) -> str: + path_to_prompt = os.path.join(BASE_DIR, "prompt.txt") + with open(path_to_prompt, "w", encoding="utf-8", errors="replace") as f: + f.write(prompt) + + response = self.client.models.generate_content( + model=self.model, + contents=prompt, + config=( + types.GenerateContentConfig(**settings.gemini_generation.model_dump()) + if use_default_config + else None + ), + ) + + return response.text + + async def get_streaming_response( + self, + prompt: str, + stream: bool = True, + logging: bool = True, + use_default_config: bool = False, + ): + path_to_prompt = os.path.join(BASE_DIR, "prompt.txt") + with open(path_to_prompt, "w", encoding="utf-8", errors="replace") as f: + f.write(prompt) + + response = self.client.models.generate_content_stream( + model=self.model, + contents=prompt, + config=( + types.GenerateContentConfig(**settings.gemini_generation.model_dump()) + if use_default_config + else None + ), + ) + + for chunk in response: + yield chunk + + +class GeminiEmbed: + def __init__(self, model="text-embedding-004"): + self.client = genai.Client(api_key=settings.api_key) + self.model = model + self.settings = GeminiEmbeddingSettings() + self.max_workers = 5 + + def _embed_batch(self, batch: list[str], idx: int) -> dict: + response = 
self.client.models.embed_content( + model=self.model, + contents=batch, + config=types.EmbedContentConfig( + **settings.gemini_embedding.model_dump() + ), + ).embeddings + return {"idx": idx, "embeddings": response} + + def encode(self, text: str | list[str]) -> list[Tensor]: + + if isinstance(text, str): + text = [text] + + groups: list[list[float]] = [] + max_batch_size = 100 # can not be changed due to google restrictions + + batches: list[list[str]] = [text[i : i + max_batch_size] for i in range(0, len(text), max_batch_size)] + print(*[len(batch) for batch in batches]) + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [executor.submit(self._embed_batch, batch, idx) for idx, batch in enumerate(batches)] + for future in as_completed(futures): + groups.append(future.result()) + + groups.sort(key=lambda x: x["idx"]) + + result: list[float] = [] + for group in groups: + for vec in group["embeddings"]: + result.append(vec.values) + return result + + def get_vector_dimensionality(self) -> int | None: + return getattr(self.settings, "output_dimensionality") + + +class Wrapper: + def __init__(self, model: str = "gemini-2.0-flash"): + self.model = model + self.client = genai.Client(api_key=settings.api_key) + + def wrap(self, prompt: str) -> str: + response = self.client.models.generate_content( + model=self.model, + contents=prompt, + config=types.GenerateContentConfig(**settings.gemini_wrapper.model_dump()) + ) + + return response.text diff --git a/app/core/processor.py b/app/core/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..53a7af3c7fd73012c5e5f8ce2bb4f12830fc14e8 --- /dev/null +++ b/app/core/processor.py @@ -0,0 +1,305 @@ +from langchain_community.document_loaders import ( + UnstructuredWordDocumentLoader, + TextLoader, + CSVLoader, + UnstructuredMarkdownLoader, +) +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_core.documents import Document +from app.core.chunks 
import Chunk +import nltk # used for proper tokenizer workflow +from uuid import ( + uuid4, +) # for generating unique id as hex (uuid4 is used as it generates ids form pseudo random numbers unlike uuid1 and others) +import numpy as np +from app.settings import logging, settings +from concurrent.futures import ProcessPoolExecutor, as_completed +import os +import fitz + +class PDFLoader: + def __init__(self, file_path: str): + self.file_path = file_path + + def load(self) -> list[Document]: + docs = [] + with fitz.open(self.file_path) as doc: + for page in doc: + text = page.get_text("text") + metadata = { + "source": self.file_path, + "page": page.number, + } + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + +class DocumentProcessor: + """ + TODO: determine the most suitable chunk size + + chunks -> the list of chunks from loaded files + chunks_unsaved -> the list of recently added chunks that have not been saved to db yet + processed -> the list of files that were already splitted into chunks + unprocessed -> !processed + text_splitter -> text splitting strategy + """ + + def __init__(self): + self.chunks_unsaved: list[Chunk] = [] + self.unprocessed: list[Document] = [] + self.max_workers = min(4, os.cpu_count() or 1) + self.text_splitter = RecursiveCharacterTextSplitter( + **settings.text_splitter.model_dump() + ) + + """ + Measures cosine between two vectors + """ + + def cosine_similarity(self, vec1, vec2): + return vec1 @ vec2 / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) + + """ + Updates a list of the most relevant chunks without interacting with db + """ + + def update_most_relevant_chunk( + self, + chunk: list[np.float64, Chunk], + relevant_chunks: list[list[np.float64, Chunk]], + mx_len=15, + ): + relevant_chunks.append(chunk) + for i in range(len(relevant_chunks) - 1, 0, -1): + if relevant_chunks[i][0] > relevant_chunks[i - 1][0]: + relevant_chunks[i], relevant_chunks[i - 1] = ( + relevant_chunks[i - 1], + 
relevant_chunks[i], + ) + else: + break + + if len(relevant_chunks) > mx_len: + del relevant_chunks[-1] + + """ + Loads one file - extracts text from file + + TODO: Replace UnstructuredWordDocumentLoader with Docx2txtLoader + TODO: Play with .pdf and text from img extraction + TODO: Try chunking with llm + + add_to_unprocessed -> used to add loaded file to the list of unprocessed(unchunked) files if true + """ + + def check_size(self, file_path: str = "") -> bool: + try: + size = os.path.getsize(filename=file_path) + except Exception: + size = 0 + + if size > 1000000: + return True + return False + + def document_multiplexer(self, filepath: str, get_loader: bool = False, get_chunking_strategy: bool = False): + loader = None + parallelization = False + if filepath.endswith(".pdf"): + loader = PDFLoader( + file_path=filepath + ) # splits each presentation into slides and processes it as separate file + parallelization = False + elif filepath.endswith(".docx") or filepath.endswith(".doc"): + loader = UnstructuredWordDocumentLoader(file_path=filepath) + elif filepath.endswith(".txt"): + loader = TextLoader(file_path=filepath) + elif filepath.endswith(".csv"): + loader = CSVLoader(file_path=filepath) + elif filepath.endswith(".json"): + loader = TextLoader(file_path=filepath) + elif filepath.endswith(".md"): + loader = UnstructuredMarkdownLoader(file_path=filepath) + + if filepath.endswith(".pdf"): + parallelization = False + else: + parallelization = self.check_size(file_path=filepath) + + if get_loader: + return loader + elif get_chunking_strategy: + return parallelization + else: + raise RuntimeError("What to do, my lord?") + + def load_document( + self, filepath: str, add_to_unprocessed: bool = False + ) -> list[Document]: + loader = self.document_multiplexer(filepath=filepath, get_loader=True) + + if loader is None: + raise RuntimeError("Unsupported type of file") + + documents: list[Document] = [] # We can not assign a single value to the document since .pdf are 
splitted into several files + try: + documents = loader.load() + # print("-" * 100, documents, "-" * 100, sep="\n") + except Exception: + raise RuntimeError("File is corrupted") + + if add_to_unprocessed: + for doc in documents: + self.unprocessed.append(doc) + + strategy = self.document_multiplexer(filepath=filepath, get_chunking_strategy=True) + print(f"Strategy --> {strategy}") + self.generate_chunks(parallelization=strategy) + return documents + + """ + Similar to load_document, but for multiple files + + add_to_unprocessed -> used to add loaded files to the list of unprocessed(unchunked) files if true + """ + + def load_documents( + self, documents: list[str], add_to_unprocessed: bool = False + ) -> list[Document]: + extracted_documents: list[Document] = [] + + for doc in documents: + temp_storage: list[Document] = [] + + try: + temp_storage = self.load_document( + filepath=doc, add_to_unprocessed=True + ) + except Exception as e: + logging.error( + "Error at load_documents while loading %s", doc, exc_info=e + ) + continue + + for extrc_doc in temp_storage: + extracted_documents.append(extrc_doc) + + if add_to_unprocessed: + self.unprocessed.append(extrc_doc) + + return extracted_documents + + def split_into_groups(self, original_list: list[any], split_by: int = 15) -> list[list[any]]: + output = [] + for i in range(0, len(original_list), split_by): + new_group = original_list[i: i + split_by] + output.append(new_group) + return output + + def _chunkinize(self, document: Document, text: list[str], lines: list[dict]) -> list[Chunk]: + output: list[Chunk] = [] + for chunk in text: + start_l, end_l = self.get_start_end_lines( + splitted_text=lines, + start_char=chunk.metadata.get("start_index", 0), + end_char=chunk.metadata.get("start_index", 0) + + len(chunk.page_content), + ) + + new_chunk = Chunk( + id=uuid4(), + filename=document.metadata.get("source", ""), + page_number=document.metadata.get("page", 0), + start_index=chunk.metadata.get("start_index", 0), + 
start_line=start_l, + end_line=end_l, + text=chunk.page_content, + ) + # print(new_chunk) + output.append(new_chunk) + return output + + def precompute_lines(self, splitted_document: list[str]) -> list[dict]: + current_start = 0 + output: list[dict] = [] + for i, line in enumerate(splitted_document): + output.append({"id": i + 1, "start": current_start, "end": current_start + len(line) + 1, "text": line}) + current_start += len(line) + 1 + return output + + def generate_chunks(self, parallelization: bool = True): + intermediate = [] + for document in self.unprocessed: + text: list[str] = self.text_splitter.split_documents(documents=[document]) + lines: list[dict] = self.precompute_lines(splitted_document=document.page_content.splitlines()) + groups = self.split_into_groups(original_list=text, split_by=50) + + if parallelization: + print("<------- Apply Parallel Execution ------->") + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = [executor.submit(self._chunkinize, document, group, lines) for group in groups] + for feature in as_completed(futures): + intermediate.append(feature.result()) + else: + intermediate.append(self._chunkinize(document=document, text=text, lines=lines)) + + for group in intermediate: + for chunk in group: + self.chunks_unsaved.append(chunk) + + self.unprocessed = [] + + def find_line(self, splitted_text: list[dict], char) -> int: + l, r = 0, len(splitted_text) - 1 + + while l <= r: + m = (l + r) // 2 + line = splitted_text[m] + + if line["start"] <= char < line["end"]: + return m + 1 + elif char < line["start"]: + r = m - 1 + else: + l = m + 1 + + return r + + def get_start_end_lines( + self, + splitted_text: list[dict], + start_char: int, + end_char: int, + debug_mode: bool = False, + ) -> tuple[int, int]: + start = self.find_line(splitted_text=splitted_text, char=start_char) + end = self.find_line(splitted_text=splitted_text, char=end_char) + return (start, end) + + """ + Note: it should be used only once to 
download tokenizers, futher usage is not recommended + """ + + def update_nltk(self) -> None: + nltk.download("punkt") + nltk.download("averaged_perceptron_tagger") + + """ + For now the system works as follows: we save recently loaded chunks in two arrays: + chunks - for all chunks, even for that ones that havn't been saveed to db + chunks_unsaved - for chunks that have been added recently + I do not know weather we really need to store all chunks that were added in the + current session, but chunks_unsaved are used to avoid dublications while saving to db. + """ + + def get_and_save_unsaved_chunks(self) -> list[Chunk]: + chunks_copy: list[Chunk] = self.chunks_unsaved.copy() + self.clear_unsaved_chunks() + return chunks_copy + + def clear_unsaved_chunks(self): + self.chunks_unsaved = [] + + def get_all_chunks(self) -> list[Chunk]: + return self.chunks_unsaved diff --git a/app/core/rag_generator.py b/app/core/rag_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..329b767119787aca382b3b4d1b03e8198c266661 --- /dev/null +++ b/app/core/rag_generator.py @@ -0,0 +1,171 @@ +from typing import Any, AsyncGenerator +from app.core.models import LocalLLM, Embedder, Reranker, GeminiLLM, GeminiEmbed, Wrapper +from app.core.processor import DocumentProcessor +from app.core.database import VectorDatabase +import time +import os +from app.settings import settings, BASE_DIR + + +class RagSystem: + def __init__(self): + self.embedder = ( + GeminiEmbed() + if settings.use_gemini + else Embedder(model=settings.models.embedder_model) + ) + self.reranker = Reranker(model=settings.models.reranker_model) + self.processor = DocumentProcessor() + self.db = VectorDatabase(embedder=self.embedder) + self.llm = GeminiLLM() if settings.use_gemini else LocalLLM() + self.wrapper = Wrapper() + + """ + Provides a prompt with substituted context from chunks + + TODO: add template to prompt without docs + """ + + def get_general_prompt(self, user_prompt: str, 
collection_name: str) -> str: + enhanced_prompt = self.enhance_prompt(user_prompt.strip()) + + relevant_chunks = self.db.search(collection_name, query=enhanced_prompt, top_k=30) + if relevant_chunks is not None and len(relevant_chunks) > 0: + ranks = self.reranker.rank(query=enhanced_prompt, chunks=relevant_chunks) + relevant_chunks = [relevant_chunks[rank["corpus_id"]] for rank in ranks] + else: + relevant_chunks = [] + + sources = "" + prompt = "" + + for chunk in relevant_chunks[: min(10, len(relevant_chunks))]: + citation = ( + f"[Source: {chunk.filename}, " + f"Page: {chunk.page_number}, " + f"Lines: {chunk.start_line}-{chunk.end_line}, " + f"Start: {chunk.start_index}]\n\n" + ) + sources += f"Original text:\n{chunk.get_raw_text()}\nCitation:{citation}" + + with open( + os.path.join(BASE_DIR, "app", "prompt_templates", "test2.txt") + ) as prompt_file: + prompt = prompt_file.read() + + prompt += ( + "**QUESTION**: " + f"{enhanced_prompt}\n" + "**CONTEXT DOCUMENTS**:\n" + f"{sources}\n" + ) + print(prompt) + return prompt + + def enhance_prompt(self, original_prompt: str) -> str: + path_to_wrapping_prompt = os.path.join(BASE_DIR, "app", "prompt_templates", "wrapper.txt") + enhanced_prompt = "" + with open(path_to_wrapping_prompt, "r") as f: + enhanced_prompt = f.read().replace("[USERS_PROMPT]", original_prompt) + return self.wrapper.wrap(enhanced_prompt) + + """ + Splits the list of documents into groups with 'split_by' docs (done to avoid qdrant_client connection error handling), loads them, + splits into chunks, and saves to db + """ + + def upload_documents( + self, + collection_name: str, + documents: list[str], + split_by: int = 3, + debug_mode: bool = True, + ) -> None: + + for i in range(0, len(documents), split_by): + + if debug_mode: + print( + "<" + + "-" * 10 + + "New document group is taken into processing" + + "-" * 10 + + ">" + ) + + docs = documents[i : i + split_by] + + loading_time = 0 + chunk_generating_time = 0 + db_saving_time = 0 + + 
print("Start loading the documents") + start = time.time() + self.processor.load_documents(documents=docs, add_to_unprocessed=False) + loading_time = time.time() - start + + print("Start loading chunk generation") + start = time.time() + # self.processor.generate_chunks() + chunk_generating_time = time.time() - start + + print("Start saving to db") + start = time.time() + self.db.store(collection_name, self.processor.get_and_save_unsaved_chunks()) + db_saving_time = time.time() - start + + if debug_mode: + print( + f"loading time = {loading_time}, chunk generation time = {chunk_generating_time}, saving time = {db_saving_time}\n" + ) + + def extract_text(self, response) -> str: + text = "" + try: + text = response.candidates[0].content.parts[0].text + except Exception as e: + print(e) + return text + + """ + Produces answer to user's request. First, finds the most relevant chunks, generates prompt with them, and asks llm + """ + + async def generate_response( + self, collection_name: str, user_prompt: str, stream: bool = True + ) -> str: + general_prompt = self.get_general_prompt( + user_prompt=user_prompt, collection_name=collection_name + ) + + return self.llm.get_response(prompt=general_prompt) + + async def generate_response_stream( + self, collection_name: str, user_prompt: str, stream: bool = True + ) -> AsyncGenerator[Any, Any]: + general_prompt = self.get_general_prompt( + user_prompt=user_prompt, collection_name=collection_name + ) + + async for chunk in self.llm.get_streaming_response( + prompt=general_prompt, stream=True + ): + yield self.extract_text(chunk) + + """ + Produces the list of the most relevant chunks + """ + + def get_relevant_chunks(self, collection_name: str, query): + relevant_chunks = self.db.search(collection_name, query=query, top_k=15) + relevant_chunks = [ + relevant_chunks[ranked["corpus_id"]] + for ranked in self.reranker.rank(query=query, chunks=relevant_chunks) + ] + return relevant_chunks + + def create_new_collection(self, 
collection_name: str) -> None: + self.db.create_collection(collection_name) + + def get_collections_names(self) -> list[str]: + return self.db.get_collections() diff --git a/app/core/response_parser.py b/app/core/response_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..46a885dec5b0183b32c36889b92c88103c1d176b --- /dev/null +++ b/app/core/response_parser.py @@ -0,0 +1,29 @@ +from app.core.document_validator import path_is_valid +import re + +""" +Replaces the matched regular exp with link via html +""" + + +def create_url(match: re.Match) -> str: + path: str = match.group(1) + page: str = match.group(2) + lines: str = match.group(3) + start: str = match.group(4) + + if not path_is_valid(path): + return "###NOT VALID PATH###" + + return f'[Source]' + + +""" +Replaces all occurrences of citation pattern with links +""" + + +def add_links(response: str) -> str: + + citation_format = r"\[Source:\s*([^,]+?)\s*,\s*Page:\s*(\d+)\s*,\s*Lines:\s*(\d+\s*-\s*\d+)\s*,\s*Start:?\s*(\d+)\]" + return re.sub(pattern=citation_format, repl=create_url, string=response) diff --git a/app/core/some.py b/app/core/some.py new file mode 100644 index 0000000000000000000000000000000000000000..190b13832de265e745a2cca3e28b16c262e98f47 --- /dev/null +++ b/app/core/some.py @@ -0,0 +1,27 @@ +import re + +""" +Replaces the matched regular exp with link via html +""" + + +def create_url(match: re.Match) -> str: + path: str = match.group(1) + page: str = match.group(2) + lines: str = match.group(3) + start: str = match.group(4) + + return f'[Source]' + + +""" +Replaces all occurrences of citation pattern with links +""" + + +def add_links(response: str) -> str: + + citation_format = r"\[Source:\s*([^,]+?)\s*,\s*Page:\s*(\d+)\s*,\s*Lines:\s*(\d+\s*-\s*\d+)\s*,\s*Start:?\s*(\d+)\]" + return re.sub(pattern=citation_format, repl=create_url, string=response) + +print(add_links(r"[Source: 
C:\Users\User\mine\code\The-Ultimate-RAG\chats_storage\user_id=8b1be678-f2c7-4a63-b110-7627af9b1cf8\chat_id=d889d8dd-f74c-4b33-a214-d6c69b68eb98\documents\pdfs\7e2b5257-5261-4100-ae65-488e06af2e25.pdf, Page: 18, Lines: 1-2, Start: 0]")) \ No newline at end of file diff --git a/app/core/utils.py b/app/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..beb61c6d6765a0b65574ac3eb1f0adea5daac010 --- /dev/null +++ b/app/core/utils.py @@ -0,0 +1,208 @@ +from fastapi.templating import Jinja2Templates +from fastapi import Request, UploadFile + +from app.backend.controllers.chats import list_user_chats, verify_ownership_rights +from app.backend.controllers.users import get_current_user +from app.backend.models.users import User +from app.backend.models.documents import add_new_document +from app.core.rag_generator import RagSystem +from app.settings import BASE_DIR + +from uuid import uuid4 +import markdown +import os + +rag = None + + +# <----------------------- System -----------------------> +def initialize_rag() -> RagSystem: + global rag + if rag is None: + rag = RagSystem() + return rag + + +# <----------------------- Tools -----------------------> +""" +Updates response context and adds context of navbar (role, instance(or none)) and footer (none) +""" + + +def extend_context(context: dict, selected: int = None): + user = get_current_user(context.get("request")) + navbar = { + "navbar": False, + "navbar_path": "components/navbar.html", + "navbar_context": { + "chats": [], + "user": {"role": "user" if user else "guest", "instance": user}, + }, + } + sidebar = { + "sidebar": True, + "sidebar_path": "components/sidebar.html", + "sidebar_context": { + "selected": selected if selected is not None else None, + "chat_groups": list_user_chats(user.id) if user else [], + }, + } + footer = {"footer": False, "footer_context": None} + + context.update(**navbar) + context.update(**footer) + context.update(**sidebar) + + return context + + +""" 
+Validates chat viewing permission by comparing user's chats and requested one +""" + + +def protect_chat(user: User, chat_id: str) -> bool: + return verify_ownership_rights(user, chat_id) + + +async def save_documents( + collection_name: str, + files: list[UploadFile], + RAG: RagSystem, + user: User, + chat_id: str, + message_id: str +) -> None: + storage = os.path.join( + BASE_DIR, + "chats_storage", + f"user_id={user.id}", + f"chat_id={chat_id}", + "documents", + ) + docs = [] + + if files is None or len(files) == 0: + return + + os.makedirs(os.path.join(storage, "pdfs"), exist_ok=True) + + for file in files: + content = await file.read() + id = str(uuid4()) + if file.filename.endswith(".pdf"): + saved_file = os.path.join(storage, "pdfs", id + ".pdf") + else: + saved_file = os.path.join( + storage, id + "." + file.filename.split(".")[-1] + ) + + try: + add_new_document(id=id, name=file.filename, path=saved_file, message_id=message_id, size=file.size) + except Exception as e: + print(e) + raise RuntimeError("Error while adding document") + + with open(saved_file, "wb") as f: + f.write(content) + + docs.append(saved_file) + + if len(files) > 0: + RAG.upload_documents(collection_name, docs) + + +def get_pdf_path(path: str) -> str: + parts = path.split("chats_storage") + if len(parts) < 2: + return "" + return "chats_storage" + "".join(parts[1:]) + + +def construct_collection_name(user: User, chat_id: int) -> str: + return f"user_id_{user.id}_chat_id_{chat_id}" + + +def create_collection(user: User, chat_id: int, RAG: RagSystem) -> None: + if RAG is None: + raise RuntimeError("RAG was not initialized") + + RAG.create_new_collection(construct_collection_name(user, chat_id)) + print(rag.get_collections_names()) + + +def lines_to_markdown(lines: list[str]) -> list[str]: + return [markdown.markdown(line) for line in lines] + + +# <----------------------- Handlers -----------------------> +def PDFHandler( + request: Request, path: str, page: int, templates +) -> 
Jinja2Templates.TemplateResponse: + print(path) + url_path = get_pdf_path(path=path) + print(url_path) + + current_template = "pages/show_pdf.html" + return templates.TemplateResponse( + current_template, + extend_context( + { + "request": request, + "page": str(page or 1), + "url_path": url_path, + "user": get_current_user(request), + } + ), + ) + + +def TextHandler( + request: Request, path: str, lines: str, templates +) -> Jinja2Templates.TemplateResponse: + file_content = "" + with open(path, "r") as f: + file_content = f.read() + + start_line, end_line = map(int, lines.split("-")) + + text_before_citation = [] + text_after_citation = [] + citation = [] + anchor_added = False + + for index, line in enumerate(file_content.split("\n")): + if line == "" or line == "\n": + continue + if index + 1 < start_line: + text_before_citation.append(line) + elif end_line < index + 1: + text_after_citation.append(line) + else: + anchor_added = True + citation.append(line) + + current_template = "pages/show_text.html" + + return templates.TemplateResponse( + current_template, + extend_context( + { + "request": request, + "text_before_citation": lines_to_markdown(text_before_citation), + "text_after_citation": lines_to_markdown(text_after_citation), + "citation": lines_to_markdown(citation), + "anchor_added": anchor_added, + "user": get_current_user(request), + } + ), + ) + + +""" +Optional handler +""" + + +def DocHandler(): + pass diff --git a/app/frontend/static/styles.css b/app/frontend/static/styles.css new file mode 100644 index 0000000000000000000000000000000000000000..ad91438c2764f1f0db031b0aeb22e8f13853427f --- /dev/null +++ b/app/frontend/static/styles.css @@ -0,0 +1,377 @@ +#pdf-container { + margin: 0 auto; + max-width: 100%; + overflow-x: auto; + text-align: center; + padding: 20px 0; +} + +#pdf-canvas { + margin: 0 auto; + display: block; + max-width: 100%; + box-shadow: 0 0 5px rgba(0,0,0,0.2); +} + +#pageNum { + height: 40px; /* optional */ + font-size: 16px; /* 
makes text inside input larger */ + padding: 10px; + width: 9vh; /* optional for more padding inside the box */ +} + +.page-input { + width: 60px; + padding: 8px; + padding-right: 40px; /* reserve space for label inside input box */ + text-align: center; + border: 1px solid #ddd; + border-radius: 4px; + -moz-appearance: textfield; +} + +.page-input-label { + position: absolute; + right: 12px; + top: 50%; + transform: translateY(-50%); + font-size: 12px; + color: #666; + pointer-events: none; + background-color: #fff; /* Match background to prevent text overlapping */ + padding-left: 4px; +} + +.page-input-container { + position: relative; + display: inline-flex; + align-items: center; +} + +/* Hide number arrows in Chrome/Safari */ +.page-input::-webkit-outer-spin-button, +.page-input::-webkit-inner-spin-button { + -webkit-appearance: none; + margin: 0; +} + +/* Pagination styling */ +.pagination-container { + margin: 20px 0; + text-align: center; +} + +.pagination { + display: inline-flex; + align-items: center; +} + +.pagination-button { + padding: 8px 16px; + background: #4a6fa5; + color: white; + border: none; + border-radius: 4px; + cursor: pointer; + display: flex; + align-items: center; + gap: 5px; +} + +.pagination-button-text:hover { + background-color: #e0e0e0; + transform: translateY(-1px); +} + +.pagination-button-text:active { + transform: translateY(0); +} + +.text-viewer { + overflow-y: auto; /* Enables vertical scrolling when needed */ + height: 100%; + width: 100%; /* Or whatever height you prefer */ + font-family: monospace; + white-space: pre-wrap; /* Preserve line breaks but wrap text */ + background: #f8f8f8; + padding: 20px; + border-radius: 5px; + line-height: 1.5; +} + +.citation { + background-color: rgba(0, 255, 0, 0.2); + padding: 2px 0; +} + +.no-content { + color: #999; + font-style: italic; +} + +.pagination-container-text { + margin: 20px 0; + text-align: center; +} + +.pagination-button-text { + padding: 8px 16px; + background: 
#4a6fa5; + color: white; + border: none; + border-radius: 4px; + cursor: pointer; +} + + + +/* -------------------------------------------- */ + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + background-color: #f7f7f8; + color: #111827; + margin: 0; + overflow: hidden; + height: 100vh; + padding: 0; + display: flex; +} + +.sidebar { + width: 260px; + height: 100vh; + background-color: #1F2937; + /* border-right: 1px solid #e1e4e8; */ + overflow-y: auto; + padding: 8px; + position: sticky; + top: 0; +} + +.chat-page { + background-color: #111827; + flex: 1; + display: flex; + flex-direction: column; + height: 100vh; + overflow: hidden; /* Prevent double scrollbars */ +} + +.container { + flex: 1; + display: flex; + flex-direction: column; + padding: 0; + max-width: 100%; + height: 100%; +} + +/* Chat messages section */ +.chat-messages { + flex: 1; + overflow-y: auto; /* Make only this section scrollable */ + padding: 16px; + display: flex; + flex-direction: column; + gap: 16px; +} + +/* Input area - stays fixed at bottom */ +.input-group { + /* padding: 16px; + background-color: #44444C; */ + /* border-top: 1px solid #e1e4e8; */ + position: sticky; + bottom: 0; +} + +/* General styles */ + +/* Sidebar styles */ + +.chat-group { + font-weight: 500; + color: #9bb8d3; + text-transform: uppercase; + letter-spacing: 0.5px; + font-size: 12px; + padding: 8px 12px; +} + +.btn { + border-radius: 10px; + padding: 8px 12px; + font-size: 14px; + transition: all 0.2s; +} + +.btn-success { + background-color: #19c37d; + border-color: #19c37d; +} + +.btn-success:hover { + background-color: #16a369; + border-color: #16a369; +} + +.btn-outline-secondary { + /* border-color: #e1e4e8; */ + color: #374151; + background-color: transparent; +} + +.btn-outline-secondary:hover { + background-color: #273c50; + border-color: #e1e4e8; +} + +.btn-outline-light { + border-color: #e1e4e8; + color: #666; + background-color: 
transparent; +} + +.btn-outline-light:hover { + background-color: #e9ecef; + border-color: #e1e4e8; +} + +/* Chat page styles */ + +.message { + max-width: 80%; + padding: 12px 16px; + border-radius: 12px; + line-height: 1.5; +} + +.user-message { + align-self: flex-end; + background-color: #19c37d; + color: white; + border-bottom-right-radius: 4px; +} + +.assistant-message { + align-self: flex-start; + background-color: #f0f4f8; + border-bottom-left-radius: 4px; +} + +.message-header { + font-weight: 600; + font-size: 12px; + margin-bottom: 4px; + color: #666; +} + +.user-message .message-header { + color: rgba(255, 255, 255, 0.8); +} + +.message-content { + font-size: 14px; +} + + +.form-control { + border-radius: 6px; + padding: 10px 12px; + background-color: #374151; + /* border: 1px solid #e1e4e8; */ +} + +.form-control:focus { + box-shadow: none; + border-color: #19c37d; +} + +/* File input button */ +.btn-outline-secondary { + position: relative; +} + +.btn-outline-secondary input[type="file"] { + position: absolute; + opacity: 0; + width: 100%; + height: 100%; + top: 0; + left: 0; + cursor: pointer; +} + +/* Scrollbar styles */ +::-webkit-scrollbar { + width: 8px; +} + +::-webkit-scrollbar-track { + background: #f1f1f1; +} + +::-webkit-scrollbar-thumb { + background: #ccc; + border-radius: 4px; +} + +::-webkit-scrollbar-thumb:hover { + background: #aaa; +} + +/* Responsive adjustments */ +@media (max-width: 768px) { + .sidebar { + width: 220px; + } + + .message { + max-width: 90%; + } +} + +#queryInput { + background-color: #374151; + color: white; +} + +#queryInput:focus { + background-color: #374151; + color: white; + outline: none; + box-shadow: none; + border-color: #19c37d; /* optional green border for focus, remove if unwanted */ +} + +#searchButton { + background-color: #374151; +} + +#fileInput { + background-color: #374151; +} + + +/* For the placeholder text color */ +#queryInput::placeholder { + color: rgba(255, 255, 255, 0.7); /* Slightly 
transparent white */ +} + +.auth-card { + background-color: #1F2937; + border: none; + border-radius: 12px; +} + +.auth-input { + background-color: #374151 !important; + border: none !important; + color: white !important; +} + +.auth-input-group-text { + background-color: #374151 !important; + border: none !important; +} \ No newline at end of file diff --git a/app/frontend/templates/base.html b/app/frontend/templates/base.html new file mode 100644 index 0000000000000000000000000000000000000000..2df7f1bb9c225086a95d0b1f849f20ea4294ee2f --- /dev/null +++ b/app/frontend/templates/base.html @@ -0,0 +1,42 @@ + + +
+ + + {% block title %} + {% endblock %} + + + + {% block head_scripts %} + {% endblock %} + + + {% if navbar %} + {% with context=navbar_context %} + {% include navbar_path %} + {% endwith %} + {% endif %} + + {% if sidebar %} + {% with context=sidebar_context %} + {% include sidebar_path %} + {% endwith %} + {% endif %} + + {% block content %} + {% with context=sidebar_context %} + {% include sidebar_path %} + {% endwith %} + {% endblock %} + + {% if footer %} + {% with context=footer_context %} + {% include footer_path %} + {% endwith %} + {% endif %} + + {% block body_scripts %} + {% endblock %} + + \ No newline at end of file diff --git a/app/frontend/templates/components/navbar.html b/app/frontend/templates/components/navbar.html new file mode 100644 index 0000000000000000000000000000000000000000..ae1f05ff1efbe4d1eaa1150f6f56c21dd867dca6 --- /dev/null +++ b/app/frontend/templates/components/navbar.html @@ -0,0 +1,33 @@ + +Hello, guest!
+ {% else %} +Hello, {{ context.user.instance.email }}
+ {% endif %} + +Today
+Last week
+Last month
+Later
+ask anything...
+Join our community
+