import asyncio
import logging
from typing import AsyncGenerator, Dict, Union

from pyrogram import Client, utils, raw
from pyrogram.errors import AuthBytesInvalid
from pyrogram.file_id import FileId, FileType, ThumbnailSource
from pyrogram.session import Session, Auth
from pyrogram.types import Message

from FileStream.bot import work_loads
from .file_properties import get_file_ids


class ByteStreamer:
    """Stream the bytes of Telegram media through a Pyrogram client.

    Each instance keeps a cache of ``FileId`` objects keyed by database ID
    and schedules a background task that empties that cache every
    ``clean_timer`` seconds.
    """

    def __init__(self, client: Client):
        """Bind the streamer to *client* and schedule the cache cleaner.

        Must be called while an asyncio event loop is running, because
        ``asyncio.create_task`` is used to schedule :meth:`clean_cache`.
        """
        self.clean_timer = 30 * 60  # seconds between cache wipes
        self.client: Client = client
        self.cached_file_ids: Dict[str, FileId] = {}
        # Keep a strong reference to the task: the event loop only holds weak
        # references, so an unreferenced task may be garbage-collected before
        # it finishes.
        self._clean_task = asyncio.create_task(self.clean_cache())

    async def get_file_properties(self, db_id: str, multi_clients) -> FileId:
        """Return the properties of the media identified by *db_id*.

        If the properties are already cached, the cached ``FileId`` is
        returned. Otherwise they are generated through
        :meth:`generate_file_properties`, cached, and then returned.

        Args:
            db_id: Cache key / database identifier of the file.
            multi_clients: Forwarded unchanged to ``get_file_ids``.
        """
        if db_id not in self.cached_file_ids:
            logging.debug("Before Calling generate_file_properties")
            await self.generate_file_properties(db_id, multi_clients)
            logging.debug(f"Cached file properties for file with ID {db_id}")
        return self.cached_file_ids[db_id]

    async def generate_file_properties(self, db_id: str, multi_clients) -> FileId:
        """Generate the properties of the media identified by *db_id*.

        The value returned by ``get_file_ids`` is stored in
        ``cached_file_ids`` under *db_id* and then returned.

        Args:
            db_id: Cache key / database identifier of the file.
            multi_clients: Forwarded unchanged to ``get_file_ids``.
        """
        logging.debug("Before calling get_file_ids")
        # The ``Message`` class itself is forwarded as the fourth argument;
        # its meaning is defined by ``get_file_ids``.
        file_id = await get_file_ids(self.client, db_id, multi_clients, Message)
        logging.debug(f"Generated file ID and Unique ID for file with ID {db_id}")
        self.cached_file_ids[db_id] = file_id
        logging.debug(f"Cached media file with ID {db_id}")
        return self.cached_file_ids[db_id]

    async def generate_media_session(self, client: Client, file_id: FileId) -> Session:
        """Return a started media session for the DC that holds *file_id*.

        This is required for getting the bytes from Telegram servers.
        Sessions are looked up in, and stored into, ``client.media_sessions``
        keyed by ``file_id.dc_id``.

        Raises:
            AuthBytesInvalid: If importing the exported authorization into a
                session on a foreign DC fails on all six attempts.
        """
        media_session = client.media_sessions.get(file_id.dc_id, None)

        if media_session is not None:
            logging.debug(f"Using cached media session for DC {file_id.dc_id}")
            return media_session

        if file_id.dc_id != await client.storage.dc_id():
            # The file lives on another DC: create a fresh auth key there and
            # transfer the client's authorization to the new session.
            media_session = Session(
                client,
                file_id.dc_id,
                await Auth(
                    client, file_id.dc_id, await client.storage.test_mode()
                ).create(),
                await client.storage.test_mode(),
                is_media=True,
            )
            await media_session.start()

            for _ in range(6):
                exported_auth = await client.invoke(
                    raw.functions.auth.ExportAuthorization(dc_id=file_id.dc_id)
                )

                try:
                    await media_session.invoke(
                        raw.functions.auth.ImportAuthorization(
                            id=exported_auth.id, bytes=exported_auth.bytes
                        )
                    )
                    break
                except AuthBytesInvalid:
                    logging.debug(
                        f"Invalid authorization bytes for DC {file_id.dc_id}"
                    )
                    continue
            else:
                # Every attempt failed: shut the session down before raising.
                await media_session.stop()
                raise AuthBytesInvalid
        else:
            # Same DC as the client: reuse the client's stored auth key.
            media_session = Session(
                client,
                file_id.dc_id,
                await client.storage.auth_key(),
                await client.storage.test_mode(),
                is_media=True,
            )
            await media_session.start()

        logging.debug(f"Created media session for DC {file_id.dc_id}")
        client.media_sessions[file_id.dc_id] = media_session
        return media_session

    @staticmethod
    async def get_location(file_id: FileId) -> Union[
        raw.types.InputPhotoFileLocation,
        raw.types.InputDocumentFileLocation,
        raw.types.InputPeerPhotoFileLocation,
    ]:
        """Return the raw input file location for the media file.

        ``FileType.CHAT_PHOTO`` maps to ``InputPeerPhotoFileLocation``,
        ``FileType.PHOTO`` to ``InputPhotoFileLocation`` and every other
        file type to ``InputDocumentFileLocation``.
        """
        file_type = file_id.file_type

        if file_type == FileType.CHAT_PHOTO:
            # Choose the peer kind: a positive chat_id is treated as a user;
            # otherwise a zero access hash means a basic chat (whose id is
            # the negated chat_id) and anything else is a channel.
            if file_id.chat_id > 0:
                peer = raw.types.InputPeerUser(
                    user_id=file_id.chat_id, access_hash=file_id.chat_access_hash
                )
            elif file_id.chat_access_hash == 0:
                peer = raw.types.InputPeerChat(chat_id=-file_id.chat_id)
            else:
                peer = raw.types.InputPeerChannel(
                    channel_id=utils.get_channel_id(file_id.chat_id),
                    access_hash=file_id.chat_access_hash,
                )

            location = raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                volume_id=file_id.volume_id,
                local_id=file_id.local_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG,
            )
        elif file_type == FileType.PHOTO:
            location = raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        else:
            location = raw.types.InputDocumentFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        return location

    async def yield_file(
        self,
        file_id: FileId,
        index: int,
        offset: int,
        first_part_cut: int,
        last_part_cut: int,
        part_count: int,
        chunk_size: int,
    ) -> AsyncGenerator[bytes, None]:
        """
        Custom generator that yields the bytes of the media file.
        Modded from <https://github.com/eyaadh/megadlbot_oss/blob/master/mega/telegram/utils/custom_download.py#L20>
        Thanks to Eyaadh <https://github.com/eyaadh>

        Args:
            file_id: Properties of the media to stream.
            index: Key into ``work_loads``; that counter is incremented while
                the generator is active and decremented when it finishes.
            offset: Offset passed to the first ``upload.GetFile`` request; it
                is advanced by ``chunk_size`` after every part.
            first_part_cut: Start index used to slice the first chunk.
            last_part_cut: End index used to slice the last chunk.
            part_count: Number of parts to yield before stopping.
            chunk_size: ``limit`` passed to every ``upload.GetFile`` request.

        Yields:
            The (possibly sliced) ``bytes`` payload of each response.

        A ``TimeoutError`` or ``AttributeError`` raised while fetching or
        slicing parts ends the stream early and is logged, not re-raised.
        """
        client = self.client
        work_loads[index] += 1
        current_part = 1

        # The outer try guarantees the workload counter is released even when
        # session or location setup fails before any bytes are streamed.
        try:
            logging.debug(f"Starting to yielding file with client {index}.")
            media_session = await self.generate_media_session(client, file_id)
            location = await self.get_location(file_id)

            try:
                r = await media_session.invoke(
                    raw.functions.upload.GetFile(
                        location=location, offset=offset, limit=chunk_size
                    ),
                )
                if isinstance(r, raw.types.upload.File):
                    while True:
                        chunk = r.bytes
                        if not chunk:
                            break
                        elif part_count == 1:
                            yield chunk[first_part_cut:last_part_cut]
                        elif current_part == 1:
                            yield chunk[first_part_cut:]
                        elif current_part == part_count:
                            yield chunk[:last_part_cut]
                        else:
                            yield chunk

                        current_part += 1
                        offset += chunk_size

                        if current_part > part_count:
                            break

                        r = await media_session.invoke(
                            raw.functions.upload.GetFile(
                                location=location, offset=offset, limit=chunk_size
                            ),
                        )
            except (TimeoutError, AttributeError) as exc:
                # Best effort: stop streaming, but leave a trace so truncated
                # downloads can be diagnosed.
                logging.warning(
                    "Stopped yielding file with client %s at part %s: %r",
                    index,
                    current_part,
                    exc,
                )
        finally:
            logging.debug(f"Finished yielding file with {current_part} parts.")
            work_loads[index] -= 1

    async def clean_cache(self) -> None:
        """Empty ``cached_file_ids`` every ``clean_timer`` seconds, forever.

        Runs until cancelled; it exists to reduce memory usage.
        """
        while True:
            await asyncio.sleep(self.clean_timer)
            self.cached_file_ids.clear()
            logging.debug("Cleaned the cache")