import logging
import re
import base64
from struct import pack

from pyrogram.file_id import FileId
from pymongo.errors import DuplicateKeyError
from umongo import Instance, Document, fields
from motor.motor_asyncio import AsyncIOMotorClient
from marshmallow.exceptions import ValidationError

from info import DATABASE_URI, DATABASE_NAME, COLLECTION_NAME, USE_CAPTION_FILTER, MAX_B_TN
from utils import get_settings, save_group_settings

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

client = AsyncIOMotorClient(DATABASE_URI)
db = client[DATABASE_NAME]
instance = Instance.from_db(db)


@instance.register
class Media(Document):
    file_id = fields.StrField(attribute='_id')
    file_ref = fields.StrField(allow_none=True)
    file_name = fields.StrField(required=True)
    file_size = fields.IntField(required=True)
    file_type = fields.StrField(allow_none=True)
    mime_type = fields.StrField(allow_none=True)
    caption = fields.StrField(allow_none=True)

    class Meta:
        indexes = ('$file_name',)
        collection_name = COLLECTION_NAME


async def save_file(media):
    """Save file in database"""

    # TODO: Find a better way to get the same file_id for the same media to avoid duplicates
    file_id, file_ref = unpack_new_file_id(media.file_id)
    file_name = re.sub(r"(_|\-|\.|\+)", " ", str(media.file_name))
    try:
        file = Media(
            file_id=file_id,
            file_ref=file_ref,
            file_name=file_name,
            file_size=media.file_size,
            file_type=media.file_type,
            mime_type=media.mime_type,
            caption=media.caption.html if media.caption else None,
        )
    except ValidationError:
        logger.exception('Error occurred while saving file in database')
        return False, 2
    else:
        try:
            await file.commit()
        except DuplicateKeyError:
            logger.warning(
                f'{getattr(media, "file_name", "NO_FILE")} is already saved in database'
            )
            return False, 0
        else:
            logger.info(f'{getattr(media, "file_name", "NO_FILE")} is saved to database')
            return True, 1


async def get_search_results(chat_id, query, file_type=None, max_results=10, offset=0, filter=False):
    """For the given query return (files, next_offset, total_results)"""
    if chat_id is not None:
        settings = await get_settings(int(chat_id))
        try:
            if settings['max_btn']:
                max_results = 10
            else:
                max_results = int(MAX_B_TN)
        except KeyError:
            # Older groups may predate the 'max_btn' setting; create it, then retry.
            await save_group_settings(int(chat_id), 'max_btn', False)
            settings = await get_settings(int(chat_id))
            if settings['max_btn']:
                max_results = 10
            else:
                max_results = int(MAX_B_TN)

    query = query.strip()
    #if filter:
        #better ?
        #query = query.replace(' ', r'(\s|\.|\+|\-|_)')
        #raw_pattern = r'(\s|_|\-|\.|\+)' + query + r'(\s|_|\-|\.|\+)'
    if not query:
        raw_pattern = '.'
    elif ' ' not in query:
        # Single word: match it as a token delimited by word boundaries
        # or common filename separators.
        raw_pattern = r'(\b|[\.\+\-_])' + query + r'(\b|[\.\+\-_])'
    else:
        # Multiple words: allow anything between words, each gap ending on a separator.
        raw_pattern = query.replace(' ', r'.*[\s\.\+\-_]')

    try:
        regex = re.compile(raw_pattern, flags=re.IGNORECASE)
    except re.error:
        # The query contained an invalid regex fragment; return an empty result set.
        return [], '', 0

    if USE_CAPTION_FILTER:
        filter = {'$or': [{'file_name': regex}, {'caption': regex}]}
    else:
        filter = {'file_name': regex}

    if file_type:
        filter['file_type'] = file_type

    total_results = await Media.count_documents(filter)
    next_offset = offset + max_results

    if next_offset > total_results:
        next_offset = ''

    cursor = Media.find(filter)
    # Sort by most recently indexed
    cursor.sort('$natural', -1)
    # Slice files according to offset and max results
    cursor.skip(offset).limit(max_results)
    # Get list of files
    files = await cursor.to_list(length=max_results)

    return files, next_offset, total_results


async def get_bad_files(query, file_type=None, filter=False):
    """For the given query return (files, total_results)"""
    query = query.strip()
    #if filter:
        #better ?
        #query = query.replace(' ', r'(\s|\.|\+|\-|_)')
        #raw_pattern = r'(\s|_|\-|\.|\+)' + query + r'(\s|_|\-|\.|\+)'
    if not query:
        raw_pattern = '.'
    elif ' ' not in query:
        raw_pattern = r'(\b|[\.\+\-_])' + query + r'(\b|[\.\+\-_])'
    else:
        raw_pattern = query.replace(' ', r'.*[\s\.\+\-_]')

    try:
        regex = re.compile(raw_pattern, flags=re.IGNORECASE)
    except re.error:
        return [], 0

    if USE_CAPTION_FILTER:
        filter = {'$or': [{'file_name': regex}, {'caption': regex}]}
    else:
        filter = {'file_name': regex}

    if file_type:
        filter['file_type'] = file_type

    total_results = await Media.count_documents(filter)
    cursor = Media.find(filter)
    # Sort by most recently indexed
    cursor.sort('$natural', -1)
    # Get list of files
    files = await cursor.to_list(length=total_results)

    return files, total_results


async def get_file_details(query):
    """Look up a single file by its encoded file_id."""
    filter = {'file_id': query}
    cursor = Media.find(filter)
    filedetails = await cursor.to_list(length=1)
    return filedetails


def encode_file_id(s: bytes) -> str:
    # Run-length encode zero bytes (b"\x00" + count), append the file-id
    # sub-version (22) and version (4) markers, then base64url-encode
    # without padding -- the format Telegram clients expect.
    r = b""
    n = 0

    for i in s + bytes([22]) + bytes([4]):
        if i == 0:
            n += 1
        else:
            if n:
                r += b"\x00" + bytes([n])
                n = 0
            r += bytes([i])

    return base64.urlsafe_b64encode(r).decode().rstrip("=")


def encode_file_ref(file_ref: bytes) -> str:
    return base64.urlsafe_b64encode(file_ref).decode().rstrip("=")


def unpack_new_file_id(new_file_id):
    """Return file_id, file_ref"""
    decoded = FileId.decode(new_file_id)
    file_id = encode_file_id(
        pack(
            "<iiqq",
            int(decoded.file_type),
            decoded.dc_id,
            decoded.media_id,
            decoded.access_hash
        )
    )
    file_ref = encode_file_ref(decoded.file_reference)
    return file_id, file_ref
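
# --- Usage sketch ------------------------------------------------------------
# A minimal manual smoke test, not part of the bot's runtime. Assumptions: a
# MongoDB instance is reachable at DATABASE_URI and the info/utils imports
# resolve; "example" is only a placeholder query. Passing chat_id=None skips
# the per-group settings lookup, so max_results stays at its default of 10.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        files, next_offset, total = await get_search_results(None, "example")
        print(f"{total} match(es); next_offset={next_offset!r}")
        for f in files:
            print(f.file_name, f.file_size, f.file_type)

    asyncio.run(_demo())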