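"""FastAPI router that performs semantic (RAG) search over the Blender
user manual and the Blender developer documentation.

Document chunks are embedded once, cached as pickle files, and queried
with sentence-transformers dot-product similarity.
"""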
import base64
import os
import pickle
import re
import torch

from enum import Enum
from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
from heapq import nlargest
from sentence_transformers import util
from typing import Dict, List, Tuple, Set, LiteralString

try:
    from .rag import SplitDocs, EMBEDDING_CTX
    from .utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get
except ImportError:
    # Allow running both as a package module and as a standalone script.
    from rag import SplitDocs, EMBEDDING_CTX
    from utils_gitea import gitea_wiki_page_get, gitea_wiki_pages_get


MANUAL_DIR = "D:/BlenderDev/blender-manual/manual"
DOCS_DIR = "D:/BlenderDev/blender-developer-docs/docs"


class Group(str, Enum):
    dev_docs = "dev_docs"
    manual = "manual"


GROUPS_DEFAULT = {Group.dev_docs, Group.manual}


class _Data(dict):
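    """Map each `Group` member to its `{'texts': ..., 'embeddings': ...}`
    corpus, loading from the pickle cache when available and embedding
    from scratch otherwise."""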
    cache_path = "routers/rag/embeddings_{}.pkl"

    def __init__(self):
        for grp in Group:
            cache_path = self.cache_path.format(grp.name)
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as file:
                    self[grp] = pickle.load(file)
                continue

            print("Embedding Texts for", grp.name)
            self[grp] = {}

            if grp is Group.dev_docs:
                texts = self.docs_get_texts_to_embed()
            else:
                texts = self.manual_get_texts_to_embed()

            self[grp]['texts'] = texts
            self[grp]['embeddings'] = EMBEDDING_CTX.encode(texts)

            with open(cache_path, "wb") as file:
                # Move the embeddings to the CPU before pickling so the
                # cache can be loaded on machines without a GPU.
                self[grp]['embeddings'] = self[grp]['embeddings'].to(
                    torch.device('cpu'))
                pickle.dump(self[grp], file,
                            protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def manual_get_texts_to_embed(cls):
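        """Split the reStructuredText manual under `MANUAL_DIR` into
        text chunks ready for embedding."""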
        class SplitManual(SplitDocs):
            def reduce_text(_self, text):
                # Remove runs of characters that only serve as rST
                # heading underlines or separators.
                text = re.sub(r'\^{3,}', '', text)
                text = re.sub(r'-{3,}', '', text)

                text = text.replace('.rst', '.html')
                text = super().reduce_text(text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = rel_path.replace('.rst', '.html')
                return super().embedding_header(rel_path, titles)

        # Strip rST directive blocks (including their indented bodies)
        # and interpreted-text roles such as `:kbd:`.
        pattern_content_sub = r'\.\. [^\n]+\n+(?: {3,}[^\n]*\n)*|:\w+:'
        patterns_titles = (
            r'[\*#%]{3,}\n\s*(.+)\n[\*#%]{3,}',
            r'(?:[=+]{3,}\n)?\s*(.+)\n[=+]{3,}\n')

        return SplitManual().split_for_embedding(
            MANUAL_DIR,
            pattern_content_sub=pattern_content_sub,
            patterns_titles=patterns_titles,
        )

    @staticmethod
    def wiki_get_texts_to_embed():
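        """Split the blender/blender Gitea wiki into text chunks ready
        for embedding. Currently unused: no `Group` member selects it."""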
        class SplitWiki(SplitDocs):
            def split_in_topics(_self,
                                filedir: LiteralString = None,
                                *,
                                pattern_filename=None,
                                pattern_content_sub=None,
                                patterns_titles=None):
                owner = "blender"
                repo = "blender"
                pages = gitea_wiki_pages_get(owner, repo)
                for page_name in pages:
                    page_name_title = page_name["title"]
                    page = gitea_wiki_page_get(owner, repo, page_name_title)
                    rel_dir = f'/{owner}/{repo}/{page["sub_url"]}'
                    titles = [page_name_title]
                    text = base64.b64decode(
                        page["content_base64"]).decode('utf-8')
                    yield (rel_dir, titles, text)

            def reduce_text(_self, text):
                text = super().reduce_text(text)
                text = text.replace('https://projects.blender.org', '')
                return text

        return SplitWiki().split_for_embedding()

    @staticmethod
    def docs_get_texts_to_embed():
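        """Split the Markdown developer docs under `DOCS_DIR` into text
        chunks ready for embedding."""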
        class SplitBlenderDocs(SplitDocs):
            def reduce_text(_self, text):
                text = super().reduce_text(text)
                # Strip `index.md`/`.md` suffixes from links.
                text = re.sub(r'(index)?\.md', '', text)
                return text

            def embedding_header(_self, rel_path, titles):
                rel_path = re.sub(r'(index)?\.md', '', rel_path)
                return super().embedding_header(rel_path, titles)

        return SplitBlenderDocs().split_for_embedding(DOCS_DIR)

    def _sort_similarity(
            self,
            text_to_search: str,
            groups: Set[Group] = GROUPS_DEFAULT,
            limit: int = 5) -> List[str]:
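        """Return up to `limit` chunk texts from `groups`, ranked by
        dot-product similarity to `text_to_search`, each prefixed with
        its site's base URL."""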
        base_url: Dict[Group, str] = {
            Group.dev_docs: "https://developer.blender.org/docs",
            Group.manual: "https://docs.blender.org/manual/en/dev",
        }
        query_emb = EMBEDDING_CTX.encode([text_to_search])
        results: List[Tuple[float, str, Group]] = []
        for grp in groups:
            if grp not in self:
                continue

            search_results = util.semantic_search(
                query_emb, self[grp]['embeddings'],
                top_k=limit, score_function=util.dot_score)

            for score in search_results[0]:
                corpus_id = score['corpus_id']
                text = self[grp]['texts'][corpus_id]
                results.append((score['score'], text, grp))

        # Merge the per-group hits and keep the overall best `limit`.
        top_results = nlargest(limit, results, key=lambda x: x[0])

        sorted_texts = [base_url[grp] + text for _, text, grp in top_results]

        return sorted_texts


G_data = _Data()

router = APIRouter()


@router.get("/wiki_search", response_class=PlainTextResponse)
def wiki_search(
        query: str = "",
        groups: Set[Group] = Query(default=GROUPS_DEFAULT)
) -> str:
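    """Return the best-matching documentation chunks for `query`,
    one per block, separated by `---` lines."""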
    try:
        groups = GROUPS_DEFAULT.intersection(groups)
        if not groups:
            groups = GROUPS_DEFAULT
    except TypeError:
        # Direct calls (outside FastAPI) receive the `Query` default
        # object instead of a set.
        groups = GROUPS_DEFAULT

    texts = G_data._sort_similarity(query, groups)
    result: str = ''
    for text in texts:
        result += f'\n---\n{text}'
    return result


if __name__ == '__main__':
    tests = ["Set Snap Base", "Building the Manual", "Bisect Object",
             "Who are the Triagers", "4.3 Release Notes Motion Paths"]
    result = wiki_search(tests[0])
    print(result)