|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
|
from fastapi.responses import JSONResponse
|
|
|
import asyncio
|
|
|
import os
|
|
|
import time
|
|
|
import json
|
|
|
from typing import Optional, Dict, Any, List
|
|
|
from enum import Enum
|
|
|
from pydantic import BaseModel
|
|
|
from rich.progress import (
|
|
|
Progress,
|
|
|
SpinnerColumn,
|
|
|
TimeElapsedColumn,
|
|
|
DownloadColumn,
|
|
|
TransferSpeedColumn,
|
|
|
BarColumn,
|
|
|
TextColumn,
|
|
|
)
|
|
|
from rich.console import Console
|
|
|
from rich.live import Live
|
|
|
from rich.table import Table
|
|
|
import download_channel
|
|
|
|
|
|
|
|
|
console = Console()
|
|
|
|
|
|
app = FastAPI(title="Telegram Channel Downloader API")
|
|
|
|
|
|
|
|
|
active_downloads: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
class FileStatus(str, Enum):
|
|
|
PENDING = "pending"
|
|
|
DOWNLOADING = "downloading"
|
|
|
DOWNLOADED = "downloaded"
|
|
|
FAILED = "failed"
|
|
|
|
|
|
class ChannelFile(BaseModel):
|
|
|
message_id: int
|
|
|
filename: str
|
|
|
status: FileStatus
|
|
|
size: Optional[int] = None
|
|
|
download_time: Optional[float] = None
|
|
|
error: Optional[str] = None
|
|
|
upload_path: Optional[str] = None
|
|
|
|
|
|
class DownloadState(BaseModel):
|
|
|
channel: str
|
|
|
last_scanned_id: Optional[int] = None
|
|
|
files: List[ChannelFile] = []
|
|
|
current_download: Optional[int] = None
|
|
|
last_updated: float = time.time()
|
|
|
|
|
|
class DownloadRequest(BaseModel):
|
|
|
channel: Optional[str] = None
|
|
|
message_limit: Optional[int] = None
|
|
|
|
|
|
class DownloadStatus(BaseModel):
|
|
|
channel: str
|
|
|
status: str
|
|
|
message_count: int = 0
|
|
|
downloaded: int = 0
|
|
|
downloading: Optional[str] = None
|
|
|
error: Optional[str] = None
|
|
|
|
|
|
def create_hf_dataset(token: str) -> bool:
|
|
|
"""Create the Hugging Face dataset if it doesn't exist."""
|
|
|
try:
|
|
|
from huggingface_hub import create_repo, RepoNotFoundError
|
|
|
try:
|
|
|
|
|
|
create_repo(
|
|
|
repo_id=download_channel.HF_REPO_ID,
|
|
|
token=token,
|
|
|
repo_type="dataset",
|
|
|
exist_ok=True
|
|
|
)
|
|
|
console.print(f"[green]Created or verified dataset:[/green] {download_channel.HF_REPO_ID}")
|
|
|
|
|
|
|
|
|
initial_state = DownloadState(channel=download_channel.CHANNEL)
|
|
|
with open(download_channel.STATE_FILE, "w", encoding="utf-8") as f:
|
|
|
json.dump(initial_state.dict(), f, indent=2, ensure_ascii=False)
|
|
|
|
|
|
|
|
|
if download_channel.upload_file_to_hf(
|
|
|
download_channel.STATE_FILE,
|
|
|
download_channel.STATE_FILE,
|
|
|
token
|
|
|
):
|
|
|
console.print("[green]Initialized dataset with empty state file[/green]")
|
|
|
return True
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Failed to create dataset:[/red] {str(e)}")
|
|
|
return False
|
|
|
except ImportError:
|
|
|
console.print("[red]huggingface_hub not properly installed[/red]")
|
|
|
return False
|
|
|
return True
|
|
|
|
|
|
def download_state_from_hf(token: str) -> DownloadState:
|
|
|
"""Try to download the state file from the HF dataset. Returns state dict or creates new."""
|
|
|
if not token:
|
|
|
return DownloadState(channel=download_channel.CHANNEL)
|
|
|
try:
|
|
|
|
|
|
local_path = download_channel.hf_hub_download(
|
|
|
repo_id=download_channel.HF_REPO_ID,
|
|
|
filename=download_channel.STATE_FILE,
|
|
|
repo_type="dataset",
|
|
|
token=token
|
|
|
)
|
|
|
with open(local_path, "r", encoding="utf-8") as f:
|
|
|
data = json.load(f)
|
|
|
return DownloadState(**data)
|
|
|
except Exception as e:
|
|
|
console.print(f"[yellow]No existing state found, creating new dataset:[/yellow] {str(e)}")
|
|
|
if create_hf_dataset(token):
|
|
|
console.print("[green]Dataset created successfully![/green]")
|
|
|
return DownloadState(channel=download_channel.CHANNEL)
|
|
|
else:
|
|
|
console.print("[red]Failed to create dataset, using local state only[/red]")
|
|
|
return DownloadState(channel=download_channel.CHANNEL)
|
|
|
|
|
|
async def clean_downloaded_file(file_path: str):
|
|
|
"""Remove local file after successful upload"""
|
|
|
try:
|
|
|
os.remove(file_path)
|
|
|
console.print(f"[blue]Cleaned up:[/blue] {os.path.basename(file_path)}")
|
|
|
except Exception as e:
|
|
|
console.print(f"[yellow]Warning:[/yellow] Could not clean up {file_path}: {e}")
|
|
|
|
|
|
async def update_and_upload_state(state: DownloadState, token: str) -> bool:
|
|
|
"""Update state timestamp and upload to dataset"""
|
|
|
state.last_updated = time.time()
|
|
|
try:
|
|
|
|
|
|
with open(download_channel.STATE_FILE, "w", encoding="utf-8") as f:
|
|
|
json.dump(state.dict(), f, indent=2, ensure_ascii=False)
|
|
|
|
|
|
return download_channel.upload_file_to_hf(
|
|
|
download_channel.STATE_FILE,
|
|
|
download_channel.STATE_FILE,
|
|
|
token
|
|
|
)
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Failed to update state:[/red] {e}")
|
|
|
return False
|
|
|
|
|
|
async def process_message(message, state: DownloadState, client) -> Optional[str]:
|
|
|
"""Process a single message, return output path if file downloaded or None"""
|
|
|
if not message.media:
|
|
|
return None
|
|
|
|
|
|
|
|
|
is_rar = False
|
|
|
filename = ""
|
|
|
if message.file:
|
|
|
filename = getattr(message.file, 'name', '') or ''
|
|
|
if filename:
|
|
|
is_rar = filename.lower().endswith('.rar')
|
|
|
else:
|
|
|
mime_type = getattr(message.file, 'mime_type', '') or ''
|
|
|
is_rar = 'rar' in mime_type.lower() if mime_type else False
|
|
|
|
|
|
if not is_rar:
|
|
|
return None
|
|
|
|
|
|
|
|
|
if filename:
|
|
|
suggested = f"{message.id}_{filename}"
|
|
|
else:
|
|
|
suggested = f"{message.id}.rar"
|
|
|
|
|
|
return os.path.join(download_channel.OUTPUT_DIR, suggested)
|
|
|
|
|
|
async def run_download(channel: Optional[str], message_limit: Optional[int], task_id: str):
|
|
|
"""Background task to run the download with state management"""
|
|
|
try:
|
|
|
|
|
|
if channel:
|
|
|
download_channel.CHANNEL = channel
|
|
|
if message_limit is not None:
|
|
|
download_channel.MESSAGE_LIMIT = message_limit
|
|
|
|
|
|
|
|
|
state = download_state_from_hf(download_channel.HF_TOKEN)
|
|
|
|
|
|
|
|
|
status = {
|
|
|
"channel": state.channel,
|
|
|
"status": "running",
|
|
|
"message_count": len(state.files),
|
|
|
"downloaded": len([f for f in state.files if f.status == FileStatus.DOWNLOADED]),
|
|
|
"downloading": None,
|
|
|
"error": None
|
|
|
}
|
|
|
active_downloads[task_id] = status
|
|
|
|
|
|
|
|
|
progress = Progress(
|
|
|
SpinnerColumn(),
|
|
|
TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
|
|
|
BarColumn(bar_width=40),
|
|
|
"[progress.percentage]{task.percentage:>3.1f}%",
|
|
|
"•",
|
|
|
DownloadColumn(),
|
|
|
"•",
|
|
|
TransferSpeedColumn(),
|
|
|
"•",
|
|
|
TimeElapsedColumn(),
|
|
|
)
|
|
|
|
|
|
overall_progress = Progress(
|
|
|
TextColumn("[bold yellow]{task.description}", justify="right"),
|
|
|
BarColumn(bar_width=40),
|
|
|
"[progress.percentage]{task.percentage:>3.1f}%",
|
|
|
"•",
|
|
|
TextColumn("[bold green]{task.fields[stats]}")
|
|
|
)
|
|
|
|
|
|
|
|
|
client = download_channel.TelegramClient(
|
|
|
download_channel.SESSION_FILE,
|
|
|
download_channel.API_ID,
|
|
|
download_channel.API_HASH
|
|
|
)
|
|
|
|
|
|
async with client:
|
|
|
try:
|
|
|
entity = await client.get_entity(download_channel.CHANNEL)
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Failed to resolve channel:[/red] {e}")
|
|
|
return 1
|
|
|
|
|
|
console.print(f"[green]Starting download from:[/green] {entity.title if hasattr(entity, 'title') else download_channel.CHANNEL}")
|
|
|
|
|
|
|
|
|
scan_count = 0
|
|
|
last_message_id = state.last_scanned_id
|
|
|
|
|
|
try:
|
|
|
async for message in client.iter_messages(entity, limit=download_channel.MESSAGE_LIMIT or None):
|
|
|
scan_count += 1
|
|
|
|
|
|
|
|
|
if last_message_id is None or message.id > last_message_id:
|
|
|
last_message_id = message.id
|
|
|
|
|
|
|
|
|
if any(f.message_id == message.id for f in state.files):
|
|
|
continue
|
|
|
|
|
|
|
|
|
out_path = await process_message(message, state, client)
|
|
|
if out_path:
|
|
|
|
|
|
file_info = ChannelFile(
|
|
|
message_id=message.id,
|
|
|
filename=os.path.basename(out_path),
|
|
|
status=FileStatus.PENDING,
|
|
|
size=getattr(message.media, 'size', 0) or 0
|
|
|
)
|
|
|
state.files.append(file_info)
|
|
|
|
|
|
|
|
|
state.last_scanned_id = last_message_id
|
|
|
if download_channel.HF_TOKEN:
|
|
|
await update_and_upload_state(state, download_channel.HF_TOKEN)
|
|
|
|
|
|
console.print(f"[green]Channel scan complete:[/green] Found {scan_count} messages")
|
|
|
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Error during channel scan:[/red] {e}")
|
|
|
|
|
|
|
|
|
pending_files = [f for f in state.files if f.status == FileStatus.PENDING]
|
|
|
total_pending = len(pending_files)
|
|
|
|
|
|
if total_pending == 0:
|
|
|
console.print("[green]No new files to download![/green]")
|
|
|
return 0
|
|
|
|
|
|
console.print(f"[green]Starting downloads:[/green] {total_pending} files pending")
|
|
|
|
|
|
|
|
|
with Live(progress) as live_progress, Live(overall_progress) as live_overall:
|
|
|
overall_task = overall_progress.add_task(
|
|
|
f"Channel: {download_channel.CHANNEL}",
|
|
|
total=total_pending,
|
|
|
stats=f"Pending: {total_pending}"
|
|
|
)
|
|
|
|
|
|
for file_info in pending_files:
|
|
|
try:
|
|
|
|
|
|
file_info.status = FileStatus.DOWNLOADING
|
|
|
state.current_download = file_info.message_id
|
|
|
if download_channel.HF_TOKEN:
|
|
|
await update_and_upload_state(state, download_channel.HF_TOKEN)
|
|
|
|
|
|
|
|
|
status["downloading"] = file_info.filename
|
|
|
|
|
|
|
|
|
message = await client.get_messages(entity, ids=file_info.message_id)
|
|
|
if not message or not message.media:
|
|
|
file_info.status = FileStatus.FAILED
|
|
|
file_info.error = "Message not found or no media"
|
|
|
continue
|
|
|
|
|
|
out_path = os.path.join(download_channel.OUTPUT_DIR, file_info.filename)
|
|
|
file_task = progress.add_task(
|
|
|
"download",
|
|
|
total=file_info.size or 100,
|
|
|
filename=file_info.filename
|
|
|
)
|
|
|
|
|
|
|
|
|
start_time = time.time()
|
|
|
try:
|
|
|
async def progress_callback(current, total):
|
|
|
progress.update(file_task, completed=current)
|
|
|
overall_stats = f"Downloaded: {len([f for f in state.files if f.status == FileStatus.DOWNLOADED])}"
|
|
|
overall_progress.update(overall_task, completed=current/total*100, stats=overall_stats)
|
|
|
|
|
|
await client.download_media(
|
|
|
message,
|
|
|
file=out_path,
|
|
|
progress_callback=progress_callback
|
|
|
)
|
|
|
|
|
|
|
|
|
if download_channel.HF_TOKEN:
|
|
|
console.print(f"[yellow]Uploading to HF:[/yellow] {file_info.filename}")
|
|
|
path_in_repo = f"files/{file_info.filename}"
|
|
|
ok = download_channel.upload_file_to_hf(
|
|
|
out_path,
|
|
|
path_in_repo,
|
|
|
download_channel.HF_TOKEN
|
|
|
)
|
|
|
if ok:
|
|
|
console.print(f"[green]Uploaded:[/green] {file_info.filename}")
|
|
|
|
|
|
await clean_downloaded_file(out_path)
|
|
|
file_info.upload_path = path_in_repo
|
|
|
else:
|
|
|
console.print(f"[red]Upload failed:[/red] {file_info.filename}")
|
|
|
file_info.error = "Upload to dataset failed"
|
|
|
file_info.status = FileStatus.FAILED
|
|
|
continue
|
|
|
|
|
|
|
|
|
file_info.status = FileStatus.DOWNLOADED
|
|
|
file_info.download_time = time.time() - start_time
|
|
|
|
|
|
|
|
|
if download_channel.HF_TOKEN:
|
|
|
await update_and_upload_state(state, download_channel.HF_TOKEN)
|
|
|
|
|
|
|
|
|
status["downloaded"] += 1
|
|
|
await asyncio.sleep(0.2)
|
|
|
|
|
|
except download_channel.errors.FloodWaitError as fw:
|
|
|
wait = int(fw.seconds) if fw.seconds else 60
|
|
|
console.print(f"[yellow]FloodWait:[/yellow] Sleeping {wait}s")
|
|
|
await asyncio.sleep(wait + 1)
|
|
|
|
|
|
continue
|
|
|
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Error:[/red] {str(e)}")
|
|
|
file_info.status = FileStatus.FAILED
|
|
|
file_info.error = str(e)
|
|
|
if download_channel.HF_TOKEN:
|
|
|
await update_and_upload_state(state, download_channel.HF_TOKEN)
|
|
|
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Fatal error processing {file_info.filename}:[/red] {str(e)}")
|
|
|
continue
|
|
|
|
|
|
|
|
|
state.current_download = None
|
|
|
if download_channel.HF_TOKEN:
|
|
|
await update_and_upload_state(state, download_channel.HF_TOKEN)
|
|
|
|
|
|
console.print("[green]Download session completed![/green]")
|
|
|
status["status"] = "completed"
|
|
|
status["downloading"] = None
|
|
|
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Fatal error:[/red] {str(e)}")
|
|
|
if "status" in locals():
|
|
|
status["status"] = "failed"
|
|
|
status["error"] = str(e)
|
|
|
|
|
|
return 0
|
|
|
|
|
|
@app.on_event("startup")
|
|
|
async def start_initial_download():
|
|
|
"""Start the download process automatically when the server starts"""
|
|
|
task_id = "initial_download"
|
|
|
|
|
|
|
|
|
if not download_channel.HF_TOKEN:
|
|
|
console.print("[red]ERROR: HF_TOKEN not set. Please set your Hugging Face token.[/red]")
|
|
|
return
|
|
|
|
|
|
|
|
|
console.print("[yellow]Checking Hugging Face dataset...[/yellow]")
|
|
|
try:
|
|
|
state = download_state_from_hf(download_channel.HF_TOKEN)
|
|
|
console.print(f"[green]Using channel:[/green] {state.channel}")
|
|
|
|
|
|
|
|
|
os.makedirs(download_channel.OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
asyncio.create_task(run_download(
|
|
|
channel=None,
|
|
|
message_limit=None,
|
|
|
task_id=task_id
|
|
|
))
|
|
|
console.print(f"[green]Started initial download task:[/green] {task_id}")
|
|
|
|
|
|
except Exception as e:
|
|
|
console.print(f"[red]Failed to initialize:[/red] {str(e)}")
|
|
|
|
|
|
@app.post("/download", response_model=Dict[str, str])
|
|
|
async def start_download(request: DownloadRequest, background_tasks: BackgroundTasks):
|
|
|
"""Start a new download task"""
|
|
|
task_id = f"download_{len(active_downloads) + 1}"
|
|
|
|
|
|
background_tasks.add_task(
|
|
|
run_download,
|
|
|
channel=request.channel,
|
|
|
message_limit=request.message_limit,
|
|
|
task_id=task_id
|
|
|
)
|
|
|
|
|
|
return {"task_id": task_id}
|
|
|
|
|
|
@app.get("/status/{task_id}", response_model=DownloadStatus)
|
|
|
async def get_status(task_id: str):
|
|
|
"""Get the status of a download task"""
|
|
|
if task_id not in active_downloads:
|
|
|
raise HTTPException(status_code=404, detail="Task not found")
|
|
|
return active_downloads[task_id]
|
|
|
|
|
|
@app.get("/active", response_model=Dict[str, DownloadStatus])
|
|
|
async def list_active():
|
|
|
"""List all active or completed downloads"""
|
|
|
return active_downloads
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
import uvicorn
|
|
|
uvicorn.run(app, host="127.0.0.1", port=8000) |