# Copyright (c) OpenMMLab. All rights reserved.
"""Extract features and search with user query."""
from __future__ import annotations

import argparse
import json
import os
import re
import shutil
from multiprocessing import Pool
from typing import Any, List, Optional

import pytoml
from BCEmbedding.tools.langchain import BCERerank
from langchain.text_splitter import (MarkdownHeaderTextSplitter,
                                     MarkdownTextSplitter,
                                     RecursiveCharacterTextSplitter)
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.faiss import FAISS as Vectorstore
from langchain_core.documents import Document
from loguru import logger
from torch.cuda import empty_cache

from .cluster import Clusterer
from .file_operation import FileName, FileOperation
from .retriever import CacheRetriever, Retriever


def save_image(image, path, name):
    """Save an image into the given directory and return its path."""
    if not os.path.exists(path):
        os.makedirs(path)
    image_path = os.path.join(path, name)
    image.save(image_path)
    return image_path


def save_all_image(image_folder, tables):
    # each entry of `tables` is assumed to be an (image, text) pair
    for index, data in enumerate(tables):
        image, text = data
        image_path = save_image(image, image_folder, f'image_{index}.png')
        # relative_path = os.path.relpath(image_path, preprocessdir)
        # replace the in-memory image object with the saved image path
        tables[index] = (image_path, text)
    return tables


def create_html_file(tables, html_file_path):
    html_content = '<html><body>\n'
    for index, data in enumerate(tables):
        image_path, text = data
        # emit an image tag followed by its accompanying text
        html_content += f'<img src="{image_path}" alt="Image">\n{text}\n'
    html_content += '</body></html>'

    # write the HTML file
    with open(html_file_path, 'w') as file:
        file.write(html_content)
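

# A minimal sketch of how the three helpers above fit together, assuming
# `tables` is a list of (PIL.Image, str) pairs returned by FileOperation.read;
# the folder and file names here are hypothetical:
#
#   tables = save_all_image('preprocess/abc_images', tables)
#   create_html_file(tables, 'preprocess/abc.html')
#   # -> abc.html now references the PNGs written under preprocess/abc_images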


def read_and_save(file: FileName, file_opr: FileOperation):
    if os.path.exists(file.copypath) and os.path.exists(
            file.jsonpath) and os.path.exists(
                file.htmlpath) and os.path.exists(file.imagefolder):
        # already exists, return
        logger.info('already exist, skip load')
        return

    # file_opr = FileOperation()
    logger.info('reading {}, would save to {}'.format(file.origin,
                                                      file.copypath))
    content, tbls, error = file_opr.read(file.origin)
    if error is not None:
        logger.error('{} load error: {}'.format(file.origin, str(error)))
        return

    if content is None or len(content) < 1:
        logger.warning('{} empty, skip save'.format(file.origin))
        return

    # write with utf8 so it matches the utf8 reads elsewhere in this module
    with open(file.copypath, 'w', encoding='utf8') as f:
        f.write(content)
    tables = save_all_image(file.imagefolder, tbls)
    with open(file.jsonpath, 'w', encoding='utf8') as f:
        json.dump(tables, f, indent=4, ensure_ascii=False)
    create_html_file(tables, file.htmlpath)


def _split_text_with_regex_from_end(text: str, separator: str,
                                    keep_separator: bool) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f'({separator})', text)
            splits = [''.join(i) for i in zip(_splits[0::2], _splits[1::2])]
            if len(_splits) % 2 == 1:
                splits += _splits[-1:]
            # splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != '']
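

# A quick illustration of the "keep separator at the end" behaviour, using the
# sentence-level separator from ChineseRecursiveTextSplitter below (these
# literals are only an example):
#
#   _split_text_with_regex_from_end('今天很好。明天呢?', '。|!|?', True)
#   # -> ['今天很好。', '明天呢?']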


# copy from https://github.com/chatchat-space/Langchain-Chatchat/blob/master/text_splitter/chinese_recursive_text_splitter.py
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or [
            '\n\n', '\n', '。|!|?', r'\.\s|\!\s|\?\s', r';|;\s', r',|,\s'
        ]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == '':
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(
            separator)
        splits = _split_text_with_regex_from_end(text, _separator,
                                                 self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = '' if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return [
            re.sub(r'\n{2,}', '\n', chunk.strip()) for chunk in final_chunks
            if chunk.strip() != ''
        ]
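

# A hedged usage sketch of the splitter above; `long_chinese_text`, the chunk
# size and the overlap are placeholders, not values mandated by this module:
#
#   splitter = ChineseRecursiveTextSplitter(keep_separator=True,
#                                           is_separator_regex=True,
#                                           chunk_size=768,
#                                           chunk_overlap=32)
#   chunks = splitter.split_text(long_chinese_text)
#   # chunks are roughly <= 768 characters each, split preferentially on
#   # paragraph, line and Chinese sentence boundaries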


class FeatureStore:
    """Tokenize and extract features from the project's documents, for use in
    the reject pipeline and response pipeline."""

    def __init__(self,
                 embeddings: HuggingFaceEmbeddings,
                 reranker: BCERerank,
                 chunk_size: int,
                 n_clusters: int | list[int],
                 config_path: str = 'config.ini',
                 language: str = 'zh') -> None:
        """Init with model device type and config."""
        self.config_path = config_path
        self.reject_throttle = -1
        self.language = language
        with open(config_path, encoding='utf8') as f:
            config = pytoml.load(f)['feature_store']
            self.reject_throttle = config['reject_throttle']
        self.chunk_size = chunk_size
        if isinstance(n_clusters, int):
            self.n_clusters = [n_clusters]
        elif isinstance(n_clusters, list):
            self.n_clusters = n_clusters

        logger.warning(
            '!!! If your feature generated by `text2vec-large-chinese` before 20240208, please rerun `python3 -m huixiangdou.service.feature_store`'  # noqa E501
        )
        logger.debug('loading text2vec model..')
        self.embeddings = embeddings
        self.reranker = reranker
        self.compression_retriever = None
        self.rejecter = None
        self.retriever = None
        self.md_splitter = MarkdownTextSplitter(chunk_size=self.chunk_size,
                                                chunk_overlap=32)

        if language == 'zh':
            self.text_splitter = ChineseRecursiveTextSplitter(
                keep_separator=True,
                is_separator_regex=True,
                chunk_size=self.chunk_size,
                chunk_overlap=32)
        else:
            self.text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size, chunk_overlap=32)

        self.head_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[
            ('#', 'Header 1'),
            ('##', 'Header 2'),
            ('###', 'Header 3'),
        ])

    def split_md(self, text: str, source: Optional[str] = None):
        """Split the markdown document in a nested way, first extracting the
        header.

        If the extraction result exceeds 1024, split it again according to
        length.
        """
        docs = self.head_splitter.split_text(text)
        final = []
        for doc in docs:
            header = ''
            if len(doc.metadata) > 0:
                if 'Header 1' in doc.metadata:
                    header += doc.metadata['Header 1']
                if 'Header 2' in doc.metadata:
                    header += ' '
                    header += doc.metadata['Header 2']
                if 'Header 3' in doc.metadata:
                    header += ' '
                    header += doc.metadata['Header 3']

            if len(doc.page_content) >= 1024:
                subdocs = self.md_splitter.create_documents([doc.page_content])
                for subdoc in subdocs:
                    if len(subdoc.page_content) >= 10:
                        final.append('{} {}'.format(
                            header, subdoc.page_content.lower()))
            elif len(doc.page_content) >= 10:
                final.append('{} {}'.format(
                    header, doc.page_content.lower()))  # noqa E501

        for item in final:
            if len(item) >= 1024:
                logger.debug('source {} split length {}'.format(
                    source, len(item)))
        return final

    def clean_md(self, text: str):
        """Remove parts of the markdown document that do not contain the key
        question words, such as code blocks, URL links, etc."""
        # remove ref
        pattern_ref = r'\[(.*?)\]\(.*?\)'
        new_text = re.sub(pattern_ref, r'\1', text)

        # remove code block
        pattern_code = r'```.*?```'
        new_text = re.sub(pattern_code, '', new_text, flags=re.DOTALL)

        # remove underline
        new_text = re.sub('_{5,}', '', new_text)

        # remove table
        # new_text = re.sub('\|.*?\|\n\| *\:.*\: *\|.*\n(\|.*\|.*\n)*', '', new_text, flags=re.DOTALL)  # noqa E501

        # use lower
        new_text = new_text.lower()
        return new_text
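
    # A small illustration of `clean_md` on a hypothetical snippet, where `fs`
    # is a hypothetical FeatureStore instance: the link target and the fenced
    # code block are dropped and the remainder is lowercased.
    #
    #   fs.clean_md('See [Docs](https://example.com)\n```python\nprint(1)\n```')
    #   # -> 'see docs\n'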

    def get_md_documents(self, file: FileName):
        documents = []
        length = 0
        text = ''
        with open(file.copypath, encoding='utf8') as f:
            text = f.read()
        text = file.prefix + '\n' + self.clean_md(text)
        if len(text) <= 1:
            return [], length

        chunks = self.split_md(text=text,
                               source=os.path.abspath(file.copypath))
        for chunk in chunks:
            new_doc = Document(page_content=chunk,
                               metadata={
                                   'source': file.basename,
                                   'read': file.copypath
                               })
            length += len(chunk)
            documents.append(new_doc)
        return documents, length

    def get_text_documents(self, text: str, file: FileName):
        if len(text) <= 1:
            return []
        chunks = self.text_splitter.create_documents([text])
        documents = []
        for chunk in chunks:
            # `source` is for return references
            # `read` is for LLM response
            chunk.metadata = {'source': file.basename, 'read': file.copypath}
            documents.append(chunk)
        return documents

    def ingress_response(self, files: list, work_dir: str,
                         file_opr: FileOperation):
        """Extract the features required for the response pipeline based on
        the document."""
        chunk_dir = os.path.join(work_dir, f'chunksize_{self.chunk_size}')
        feature_dir = os.path.join(chunk_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        # logger.info('glob {} in dir {}'.format(files, file_dir))
        # file_opr = FileOperation()
        documents = []

        for i, file in enumerate(files):
            logger.debug('{}/{}.. {}'.format(i + 1, len(files), file.basename))
            if not file.state:
                continue

            if file._type == 'md':
                md_documents, md_length = self.get_md_documents(file)
                documents += md_documents
                logger.info('{} content length {}'.format(
                    file._type, md_length))
                file.reason = str(md_length)

            else:
                # now read pdf/word/excel/ppt text
                text, _, error = file_opr.read(file.copypath)
                if error is not None:
                    file.state = False
                    file.reason = str(error)
                    continue
                file.reason = str(len(text))
                logger.info('{} content length {}'.format(
                    file._type, len(text)))
                text = file.prefix + text
                documents += self.get_text_documents(text, file)

        if len(documents) < 1:
            return
        texts = [doc.page_content for doc in documents]
        logger.debug('calculating embeddings..')
        text_embeddings = self.embeddings.embed_documents(texts)

        logger.debug('clustering response data..')
        clusterer = Clusterer(texts,
                              text_embeddings,
                              n_clusters=self.n_clusters)
        clusterer.generate_cluster(chunk_dir)

        text_embedding_pairs = zip(texts, text_embeddings)
        metadatas = [doc.metadata for doc in documents]
        logger.debug('making vectorstore..')
        vs = Vectorstore.from_embeddings(text_embedding_pairs, self.embeddings,
                                         metadatas)
        vs.save_local(feature_dir)
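
    # The response features saved above could later be reloaded with the same
    # embedding model, for example (a sketch; recent langchain FAISS releases
    # may additionally require allow_dangerous_deserialization=True):
    #
    #   vs = Vectorstore.load_local(feature_dir, self.embeddings)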

    # def ingress_reject(self, files: list, work_dir: str):
    #     """Extract the features required for the reject pipeline based on
    #     documents."""
    #     feature_dir = os.path.join(work_dir, 'db_reject')
    #     if not os.path.exists(feature_dir):
    #         os.makedirs(feature_dir)
    #     documents = []
    #     file_opr = FileOperation()
    #     logger.debug('ingress reject..')
    #     for i, file in enumerate(files):
    #         if not file.state:
    #             continue
    #         if file._type == 'md':
    #             # reject base not clean md
    #             text = file.basename + '\n'
    #             with open(file.copypath, encoding='utf8') as f:
    #                 text += f.read()
    #             if len(text) <= 1:
    #                 continue
    #             chunks = self.split_md(text=text,
    #                                    source=os.path.abspath(file.copypath))
    #             for chunk in chunks:
    #                 new_doc = Document(page_content=chunk,
    #                                    metadata={
    #                                        'source': file.basename,
    #                                        'read': file.copypath
    #                                    })
    #                 documents.append(new_doc)
    #         else:
    #             text, error = file_opr.read(file.copypath)
    #             if error is not None:
    #                 continue
    #             text = file.basename + text
    #             documents += self.get_text_documents(text, file)
    #     if len(documents) < 1:
    #         return
    #     vs = Vectorstore.from_documents(documents, self.embeddings)
    #     vs.save_local(feature_dir)

    def preprocess(self, files: list, work_dir: str, file_opr: FileOperation):
        """Preprocesses files in a given directory. Copies each file to
        'preprocess' with new name formed by joining all subdirectories with
        '_'.

        Args:
            files (list): original file list.
            work_dir (str): Working directory where preprocessed files will be stored.  # noqa E501

        Returns:
            str: Path to the directory where preprocessed markdown files are saved.

        Raises:
            Exception: Raise an exception if no markdown files are found in the provided repository directory.  # noqa E501
        """
        preproc_dir = os.path.join(work_dir, 'preprocess')
        if not os.path.exists(preproc_dir):
            os.makedirs(preproc_dir)

        pool = Pool(processes=16)
        # file_opr = FileOperation()
        for idx, file in enumerate(files):
            if not os.path.exists(file.origin):
                file.state = False
                file.reason = 'skip not exist'
                continue

            if file._type == 'image':
                file.state = False
                file.reason = 'skip image'

            elif file._type in ['pdf']:
                # read pdf/word/excel file and save to text format
                md5 = file_opr.md5(file.origin)
                file.copypath = os.path.join(preproc_dir,
                                             '{}.text'.format(md5))
                file.jsonpath = os.path.join(preproc_dir,
                                             '{}.json'.format(md5))
                file.htmlpath = os.path.join(preproc_dir,
                                             '{}.html'.format(md5))
                file.imagefolder = os.path.join(preproc_dir,
                                                '{}_images'.format(md5))
                # pool.apply_async(read_and_save, (file, ))
                read_and_save(file, file_opr)

            elif file._type in ['md', 'text']:
                # rename text files to new dir
                md5 = file_opr.md5(file.origin)
                file.copypath = os.path.join(
                    preproc_dir,
                    file.origin.replace('/', '_')[-84:])
                try:
                    shutil.copy(file.origin, file.copypath)
                    file.state = True
                    file.reason = 'preprocessed'
                except Exception as e:
                    file.state = False
                    file.reason = str(e)

            else:  # 'word', 'excel', 'ppt', 'html' TODO
                file.state = False
                file.reason = 'skip unknown format'
        pool.close()
        logger.debug('waiting for preprocess read finish..')
        pool.join()

        # check process result
        for file in files:
            if file._type in ['pdf', 'word', 'excel']:
                if os.path.exists(file.copypath):
                    file.state = True
                    file.reason = 'preprocessed'
                else:
                    file.state = False
                    file.reason = 'read error'
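
    # For reference, the naming convention used above (paths are illustrative):
    #
    #   pdf:      repodir/a/b.pdf -> preprocess/<md5>.text, <md5>.json,
    #             <md5>.html and <md5>_images/
    #   md/text:  repodir/a/b.md  -> preprocess/repodir_a_b.md
    #             (path flattened with '_', truncated to the last 84 chars)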

    def saveconfig(self, workdir: str):
        # shutil.copy(self.config_path, os.path.join(workdir, self.config_path))
        with open(self.config_path, encoding='utf8') as f:
            config = pytoml.load(f)
        config['feature_store']['n_clusters'] = self.n_clusters
        config['feature_store']['chunk_size'] = self.chunk_size

        chunk_dir = os.path.join(workdir, f'chunksize_{self.chunk_size}')
        with open(os.path.join(workdir, self.config_path), 'w') as f:
            pytoml.dump(config, f)
        with open(os.path.join(chunk_dir, self.config_path), 'w') as f:
            pytoml.dump(config, f)
        with open(self.config_path, 'w') as f:
            pytoml.dump(config, f)

    def initialize(self, files: list, work_dir: str, file_opr: FileOperation):
        """Initializes response and reject feature store.

        Only needs to be called once. Also calculates the optimal threshold
        based on provided good and bad question examples, and saves it in the
        configuration file.
        """
        logger.info(
            'initialize response and reject feature store, you only need call this once.'  # noqa E501
        )
        self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
        self.ingress_response(files=files,
                              work_dir=work_dir,
                              file_opr=file_opr)
        # self.ingress_reject(files=files, work_dir=work_dir)
        self.saveconfig(work_dir)
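

# A minimal standalone sketch of driving FeatureStore outside the __main__
# block below; the chunk size and cluster counts here are illustrative only:
#
#   cache = CacheRetriever(config_path='config.ini')
#   store = FeatureStore(embeddings=cache.embeddings,
#                        reranker=cache.reranker,
#                        chunk_size=768,
#                        n_clusters=[20, 50],
#                        config_path='config.ini')
#   file_opr = FileOperation()
#   files = file_opr.scan_dir(repo_dir='repodir')
#   store.initialize(files=files, work_dir='workdir', file_opr=file_opr)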


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Feature store for processing directories.')
    parser.add_argument('--work_dir',
                        type=str,
                        default='workdir',
                        help='Working directory.')
    parser.add_argument(
        '--repo_dir',
        type=str,
        default='repodir',
        help='Root directory where the repositories are located.')
    parser.add_argument(
        '--config_path',
        default='config.ini',
        help='Feature store configuration path. Default value is config.ini')
    parser.add_argument(
        '--good_questions',
        default='resource/good_questions.json',
        help=  # noqa E251
        'Positive examples in the dataset. Default value is resource/good_questions.json'  # noqa E501
    )
    parser.add_argument(
        '--bad_questions',
        default='resource/bad_questions.json',
        help=  # noqa E251
        'Negative examples json path. Default value is resource/bad_questions.json'  # noqa E501
    )
    parser.add_argument(
        '--sample',
        help='Input a json file, save reject and search output.')
    args = parser.parse_args()
    return args
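

# A typical invocation; paths follow the argparse defaults above, adjust to
# your own layout:
#
#   python3 -m huixiangdou.service.feature_store --config_path config.ini \
#       --work_dir workdir --repo_dir repodir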


def test_reject(retriever: Retriever, sample: str = None,
                work_dir: str = 'workdir'):
    """Simple test reject pipeline."""
    if sample is None:
        real_questions = [
            'SAM 10个T 的训练集,怎么比比较公平呢~?速度上还有缺陷吧?',
            '想问下,如果只是推理的话,amp的fp16是不会省显存么,我看parameter仍然是float32,开和不开推理的显存占用都是一样的。能不能直接用把数据和model都 .half() 代替呢,相比之下amp好在哪里',  # noqa E501
            'mmdeploy支持ncnn vulkan部署么,我只找到了ncnn cpu 版本',
            '大佬们,如果我想在高空检测安全帽,我应该用 mmdetection 还是 mmrotate',
            '请问 ncnn 全称是什么',
            '有啥中文的 text to speech 模型吗?',
            '今天中午吃什么?',
            'huixiangdou 是什么?',
            'mmpose 如何安装?',
            '使用科研仪器需要注意什么?'
        ]
    else:
        with open(sample) as f:
            real_questions = json.load(f)

    for example in real_questions:
        reject, _ = retriever.is_reject(example)

        if reject:
            logger.error(f'reject query: {example}')
        else:
            logger.warning(f'process query: {example}')

        if sample is not None:
            if reject:
                with open(f'{work_dir}/negative.txt', 'a+') as f:
                    f.write(example)
                    f.write('\n')
            else:
                with open(f'{work_dir}/positive.txt', 'a+') as f:
                    f.write(example)
                    f.write('\n')

    empty_cache()


def test_query(retriever: Retriever, sample: str = None,
               work_dir: str = 'workdir'):
    """Simple test response pipeline."""
    if sample is not None:
        with open(sample) as f:
            real_questions = json.load(f)
        logger.add('logs/feature_store_query.log', rotation='4MB')
    else:
        real_questions = ['mmpose installation', 'how to use std::vector ?']

    for example in real_questions:
        example = example[0:400]
        print(retriever.query(example))
        empty_cache()

    empty_cache()


if __name__ == '__main__':
    args = parse_args()
    cache = CacheRetriever(config_path=args.config_path)

    with open(args.config_path, encoding='utf8') as f:
        config = pytoml.load(f)

    # chunk_size and n_clusters are assumed to live in the [feature_store]
    # section of config.ini (the same keys that `saveconfig` writes back);
    # adjust if your configuration stores them elsewhere
    fs_init = FeatureStore(embeddings=cache.embeddings,
                           reranker=cache.reranker,
                           chunk_size=config['feature_store']['chunk_size'],
                           n_clusters=config['feature_store']['n_clusters'],
                           config_path=args.config_path)

    # walk all files in repo dir
    file_opr = FileOperation()
    files = file_opr.scan_dir(repo_dir=config['feature_store']['repo_dir'])
    fs_init.initialize(files=files,
                       work_dir=config['feature_store']['work_dir'],
                       file_opr=file_opr)
    file_opr.summarize(files)
    del fs_init

    # update reject throttle
    retriever = cache.get(config_path=args.config_path,
                          work_dir=config['feature_store']['work_dir'])
    with open(os.path.join('resource', 'good_questions.json')) as f:
        good_questions = json.load(f)
    with open(os.path.join('resource', 'bad_questions.json')) as f:
        bad_questions = json.load(f)
    retriever.update_throttle(config_path=args.config_path,
                              good_questions=good_questions,
                              bad_questions=bad_questions)
    cache.pop('default')

    # test
    retriever = cache.get(config_path=args.config_path,
                          work_dir=config['feature_store']['work_dir'])
    test_reject(retriever, args.sample, config['feature_store']['work_dir'])
    test_query(retriever, args.sample, config['feature_store']['work_dir'])