Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		GitHub Actions Bot
		
	commited on
		
		
					Commit 
							
							Β·
						
						c48903e
	
0
								Parent(s):
							
							
Changes from ggruber193/polars-docu-chat-rag
Browse files- app.py +0 -0
- requirements.txt +7 -0
- src/config.py +12 -0
- src/data_processing/embeddings.py +37 -0
- src/data_processing/process_markdown.py +53 -0
- src/data_processing/upload_to_qdrant.py +61 -0
- src/testing.py +19 -0
- src/utils.py +12 -0
    	
        app.py
    ADDED
    
    | 
            File without changes
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            beautifulsoup4~=4.13.4
         | 
| 2 | 
            +
            markdown~=3.8
         | 
| 3 | 
            +
            langchain~=0.3.23
         | 
| 4 | 
            +
            transformers~=4.51.3
         | 
| 5 | 
            +
            torch~=2.6.0
         | 
| 6 | 
            +
            tqdm~=4.67.1
         | 
| 7 | 
            +
            qdrant_client
         | 
    	
        src/config.py
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            EMBEDDING_MODEL = "thenlper/gte-small"
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            QDRANT_COLLECTION_NAME = "polars-documentation"
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def get_qdrant_config():
         | 
| 6 | 
            +
                from qdrant_client import models
         | 
| 7 | 
            +
                QDRANT_COLLECTION_CONFIG = {
         | 
| 8 | 
            +
                    "collection_name": QDRANT_COLLECTION_NAME,
         | 
| 9 | 
            +
                    "vectors_config": models.VectorParams(size=384, distance=models.Distance.COSINE),   # on_disk=True),
         | 
| 10 | 
            +
                    # "hnsw_config": models.HnswConfigDiff(on_disk=True)
         | 
| 11 | 
            +
                }
         | 
| 12 | 
            +
                return QDRANT_COLLECTION_CONFIG
         | 
    	
        src/data_processing/embeddings.py
    ADDED
    
    | @@ -0,0 +1,37 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from transformers import AutoModel, AutoTokenizer
         | 
| 2 | 
            +
            from torch import Tensor
         | 
| 3 | 
            +
            from torch import functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from src.config import EMBEDDING_MODEL
         | 
| 6 | 
            +
            from src.utils import batched
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class TextEmbedder:
         | 
| 10 | 
            +
                def __init__(self, modelname=EMBEDDING_MODEL, max_length=512):
         | 
| 11 | 
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(modelname)
         | 
| 12 | 
            +
                    self.model = AutoModel.from_pretrained(modelname)
         | 
| 13 | 
            +
                    self.max_length = max_length
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                @staticmethod
         | 
| 16 | 
            +
                def average_pool(last_hidden_states: Tensor,
         | 
| 17 | 
            +
                                 attention_mask: Tensor) -> Tensor:
         | 
| 18 | 
            +
                    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
         | 
| 19 | 
            +
                    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def embed_text(self, text: str | list[str], batch_size=128):
         | 
| 22 | 
            +
                    if isinstance(text, str):
         | 
| 23 | 
            +
                        text = [text]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    outputs = []
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    for batch in batched(text, n=batch_size):
         | 
| 28 | 
            +
                        batch_dict = self.tokenizer(batch, max_length=self.max_length, padding=True, truncation=True, return_tensors='pt')
         | 
| 29 | 
            +
                        output = self.model(**batch_dict)
         | 
| 30 | 
            +
                        embeddings = self.average_pool(output.last_hidden_state, batch_dict['attention_mask'])
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                        # embeddings = F.norm(embeddings, p=2, dim=1)
         | 
| 33 | 
            +
                        # scores = (embeddings[:1] @ embeddings[1:].T) * 100
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                        embeddings = embeddings.tolist()
         | 
| 36 | 
            +
                        outputs += embeddings
         | 
| 37 | 
            +
                    return outputs
         | 
    	
        src/data_processing/process_markdown.py
    ADDED
    
    | @@ -0,0 +1,53 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Any
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from bs4 import BeautifulSoup
         | 
| 4 | 
            +
            from langchain_core.documents import Document
         | 
| 5 | 
            +
            from markdown import markdown
         | 
| 6 | 
            +
            from pathlib import Path
         | 
| 7 | 
            +
            from langchain.text_splitter import MarkdownTextSplitter, MarkdownHeaderTextSplitter, TextSplitter
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from src.utils import batched
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def read_markdown_file(path: str | Path) -> [str, str]:
         | 
| 13 | 
            +
                path = Path(path)
         | 
| 14 | 
            +
                with open(path, 'r', encoding="utf8") as f_r:
         | 
| 15 | 
            +
                    text = f_r.read()
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                # text = markdown(text)
         | 
| 18 | 
            +
                # text = ''.join(BeautifulSoup(text).findAll(text=True))
         | 
| 19 | 
            +
                return text, str(path)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def split_markdown(md: str | list[str],
         | 
| 23 | 
            +
                               metadata=dict[str, Any] | list[dict[str, Any]],
         | 
| 24 | 
            +
                               chunk_size=512,
         | 
| 25 | 
            +
                               overlap=64,
         | 
| 26 | 
            +
                               splitter: TextSplitter = None) -> list[Document]:
         | 
| 27 | 
            +
                if isinstance(md, str):
         | 
| 28 | 
            +
                    md = [md]
         | 
| 29 | 
            +
                    if isinstance(metadata, list):
         | 
| 30 | 
            +
                        raise ValueError("metadata should be a single dict")
         | 
| 31 | 
            +
                    metadata = [metadata]
         | 
| 32 | 
            +
                if splitter is None:
         | 
| 33 | 
            +
                    headers_to_split_on = [
         | 
| 34 | 
            +
                        ("#", "Header 1"),
         | 
| 35 | 
            +
                        ("##", "Header 2"),
         | 
| 36 | 
            +
                        ("###", "Header 3"),
         | 
| 37 | 
            +
                    ]
         | 
| 38 | 
            +
                    md = [MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False).split_text(i) for i in md]
         | 
| 39 | 
            +
                    metadata = [{**metadata[i], **text.metadata} for i, text_split in enumerate(md) for text in text_split]
         | 
| 40 | 
            +
                    md = [j.page_content for i in md for j in i]
         | 
| 41 | 
            +
                    splitter = MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                docs = splitter.create_documents(md, metadata)
         | 
| 44 | 
            +
                return docs
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def process_markdown_files(paths: list[str | Path], batch_size=1, chunk_size=512, overlap=64):
         | 
| 48 | 
            +
                for files in batched(paths, batch_size):
         | 
| 49 | 
            +
                    mds_w_paths = [read_markdown_file(i) for i in files]
         | 
| 50 | 
            +
                    metadata = [{"path": md_path} for _, md_path in mds_w_paths]
         | 
| 51 | 
            +
                    md = [md for md, _ in mds_w_paths]
         | 
| 52 | 
            +
                    docs = split_markdown(md, metadata, chunk_size=chunk_size, overlap=overlap)
         | 
| 53 | 
            +
                    yield [i.page_content for i in docs], [i.metadata for i in docs]
         | 
    	
        src/data_processing/upload_to_qdrant.py
    ADDED
    
    | @@ -0,0 +1,61 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Any
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from qdrant_client import QdrantClient, models
         | 
| 4 | 
            +
            from uuid import uuid4
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from src.config import QDRANT_COLLECTION_NAME
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class QdrantStore:
         | 
| 10 | 
            +
                def __init__(self, client: QdrantClient, collection_config=None):
         | 
| 11 | 
            +
                    self.client = client
         | 
| 12 | 
            +
                    self.collection_names = set([i.name for i in client.get_collections().collections])
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                    if collection_config is not None:
         | 
| 15 | 
            +
                        self.create_collection(collection_config)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def create_collection(self, collection_config: dict):
         | 
| 18 | 
            +
                    collection_name = collection_config["collection_name"]
         | 
| 19 | 
            +
                    if not self.client.collection_exists(collection_name):
         | 
| 20 | 
            +
                        self.client.create_collection(**collection_config)
         | 
| 21 | 
            +
                        self.collection_names.add(collection_name)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def _check_collection_name(self, collection_name):
         | 
| 24 | 
            +
                    if collection_name not in self.collection_names:
         | 
| 25 | 
            +
                        raise ValueError(f"Collection: {collection_name} does not exist.")
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def upsert_points(self,
         | 
| 28 | 
            +
                                  vectors: Any | list[Any],
         | 
| 29 | 
            +
                                  payloads: dict | list[dict],
         | 
| 30 | 
            +
                                  collection_name: str):
         | 
| 31 | 
            +
                    self._check_collection_name(collection_name)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    ids = [str(uuid4()) for _ in payloads]
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    self.client.upsert(
         | 
| 36 | 
            +
                        collection_name=collection_name,
         | 
| 37 | 
            +
                        points=models.Batch(
         | 
| 38 | 
            +
                            ids=ids,
         | 
| 39 | 
            +
                            payloads=payloads,
         | 
| 40 | 
            +
                            vectors=vectors
         | 
| 41 | 
            +
                        )
         | 
| 42 | 
            +
                    )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def delete_points(self,
         | 
| 45 | 
            +
                                  filters: dict[str, list[models.FieldCondition]],
         | 
| 46 | 
            +
                                  collection_name: str):
         | 
| 47 | 
            +
                    self._check_collection_name(collection_name)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    self.client.delete(
         | 
| 50 | 
            +
                        collection_name=collection_name,
         | 
| 51 | 
            +
                        points_selector=models.Filter(**filters)
         | 
| 52 | 
            +
                    )
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def delete_points_by_match(self,
         | 
| 55 | 
            +
                                           key_value: tuple[str, list[str] | str],
         | 
| 56 | 
            +
                                           collection_name: str):
         | 
| 57 | 
            +
                    key, values = key_value
         | 
| 58 | 
            +
                    if isinstance(values, str):
         | 
| 59 | 
            +
                        values = [values]
         | 
| 60 | 
            +
                    filter = {"must": [models.FieldCondition(key=key, match=models.MatchAny(any=values))]}
         | 
| 61 | 
            +
                    self.delete_points(filter, collection_name)
         | 
    	
        src/testing.py
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from bs4 import BeautifulSoup
         | 
| 2 | 
            +
            from markdown import markdown
         | 
| 3 | 
            +
            from langchain.text_splitter import MarkdownTextSplitter
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            path = "D:\PycharmProjects\polargs-docu-chat-rag\data\polars-docu\concepts\data-types-and-structures.md"
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            with open(path, 'r', encoding="utf8") as f_r:
         | 
| 9 | 
            +
                test_md = f_r.read()
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            html = markdown(test_md)
         | 
| 12 | 
            +
            text = ''.join(BeautifulSoup(html).findAll(text=True))
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            print(text[:10])
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            splitter = MarkdownTextSplitter(chunk_size=512, chunk_overlap=64)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            docs = splitter.create_documents([text])
         | 
| 19 | 
            +
            print(docs)
         | 
    	
        src/utils.py
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from itertools import islice
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def batched(iterable, n, *, strict=False):
         | 
| 5 | 
            +
                # batched('ABCDEFG', 3) β ABC DEF G
         | 
| 6 | 
            +
                if n < 1:
         | 
| 7 | 
            +
                    raise ValueError('n must be at least one')
         | 
| 8 | 
            +
                iterator = iter(iterable)
         | 
| 9 | 
            +
                while batch := tuple(islice(iterator, n)):
         | 
| 10 | 
            +
                    if strict and len(batch) != n:
         | 
| 11 | 
            +
                        raise ValueError('batched(): incomplete batch')
         | 
| 12 | 
            +
                    yield batch
         |