import os
import json
import time
import asyncio
from typing import Dict, List, Optional, Any
from pathlib import Path

import aiohttp
from fastapi import FastAPI
from pydantic import BaseModel, Field
from huggingface_hub import HfApi, hf_hub_download
# --- Configuration ---
AUTO_START_INDEX = 1  # Hardcoded default start index if no progress is found
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")

# Progress and state tracking
PROGRESS_FILE = Path("processing_progress.json")
HF_STATE_FILE = "processing_state_transcriptions.json"
LOCAL_STATE_FOLDER = Path(".state")
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)

# Processing configuration
MAX_UPLOADS_BEFORE_PAUSE = 120  # Pause uploading after 120 files
UPLOAD_PAUSE_ENABLED = True

# Directory within the HF dataset where the audio files are located
AUDIO_FILE_PREFIX = "audio/"

WHISPER_SERVERS = [
    f"https://makeitfr-mineo-{i}.hf.space/transcribe" for i in range(1, 21)
]
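# The comprehension above expands to 20 endpoints:
#   https://makeitfr-mineo-1.hf.space/transcribe ... https://makeitfr-mineo-20.hf.space/transcribe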
# Temporary storage for audio files (per-flow directory to avoid collisions)
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)
# --- Models ---
class ProcessStartRequest(BaseModel):
    start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")
class WhisperServer:
    def __init__(self, url: str):
        self.url = url
        self.is_processing = False
        self.current_file_index: Optional[int] = None
        self.total_processed = 0
        self.total_time = 0.0

    def fps(self) -> float:
        """Files per second."""
        return self.total_processed / self.total_time if self.total_time > 0 else 0

    def assign_file(self, file_index: int):
        """Assign a file index to this server."""
        self.is_processing = True
        self.current_file_index = file_index

    def release(self):
        """Release the server for a new file."""
        self.is_processing = False
        self.current_file_index = None
# Global state for whisper servers
servers = [WhisperServer(url) for url in WHISPER_SERVERS]
server_lock = asyncio.Lock()  # Serializes server assignment across concurrent asyncio tasks (not threads)
# --- Progress and State Management Functions ---
def load_progress() -> Dict:
    """Loads the local processing progress from the JSON file."""
    default_structure = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],  # Full list of all .wav files found in the audio dataset
        "uploaded_count": 0
    }
    if PROGRESS_FILE.exists():
        try:
            with PROGRESS_FILE.open('r') as f:
                data = json.load(f)
                # Ensure all expected keys exist (older progress files may lack some)
                for key, value in default_structure.items():
                    if key not in data:
                        data[key] = value
                return data
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
    return default_structure

def save_progress(progress_data: Dict):
    """Saves the local processing progress to the JSON file."""
    try:
        with PROGRESS_FILE.open('w') as f:
            json.dump(progress_data, f, indent=4)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load state from a JSON file, migrating older files to the current structure."""
    if os.path.exists(file_path):
        try:
            with open(file_path, "r") as f:
                data = json.load(f)
            if "file_states" not in data or not isinstance(data["file_states"], dict):
                data["file_states"] = {}
            if "next_download_index" not in data:
                data["next_download_index"] = 0
            return data
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
    return default_value

def save_json_state(file_path: str, data: Dict[str, Any]):
    """Save state to a JSON file."""
    with open(file_path, "w") as f:
        json.dump(data, f, indent=2)
async def download_hf_state() -> Dict[str, Any]:
    """Downloads the state file from Hugging Face or returns a default state."""
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}
    try:
        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN
        )
        return load_json_state(str(local_path), default_state)
    except Exception as e:
        # hf_hub_download raises if the file does not exist in the repo yet
        # (e.g. on first run), so fall back to any local copy or the default.
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Using local/default.")
        return load_json_state(str(local_path), default_state)
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Uploads the state file to Hugging Face."""
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        save_json_state(str(local_path), state)
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update transcription state: next_index={state.get('next_download_index')}"
        )
        return True
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False
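# Round-trip usage sketch for the two state helpers above (assumes HF_TOKEN has
# write access to HF_OUTPUT_DATASET_ID; this pairing is how the main loop uses them):
#   state = await download_hf_state()
#   state["next_download_index"] += 1
#   await upload_hf_state(state)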
# --- Hugging Face Utility Functions ---
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Returns the sorted list of .wav files in the audio dataset, cached in the progress file."""
    if progress_data.get('file_list'):
        return progress_data['file_list']
    try:
        api = HfApi(token=HF_TOKEN)
        repo_files = api.list_repo_files(repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset")
        wav_files = sorted(f for f in repo_files if f.endswith('.wav'))
        progress_data['file_list'] = wav_files
        save_progress(progress_data)
        return wav_files
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []
# --- Core Processing Logic ---
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """Sends one audio file to a Whisper server and returns the JSON result, or None on failure."""
    start_time = time.time()
    try:
        async with aiohttp.ClientSession() as session:
            with open(wav_path, 'rb') as f:
                data = aiohttp.FormData()
                data.add_field('file', f, filename=wav_path.name)
                timeout = aiohttp.ClientTimeout(total=600)
                async with session.post(server.url, data=data, timeout=timeout) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        elapsed = time.time() - start_time
                        server.total_processed += 1
                        server.total_time += elapsed
                        return result
                    print(f"[{FLOW_ID}] Server {server.url} returned status {resp.status}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {e}")
    return None
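# NOTE: process_file_task below marks a file "processed" but never persists the
# transcription JSON itself. A minimal local-persistence sketch, assuming the
# Whisper servers return a JSON-serializable dict (this helper is hypothetical,
# not wired into the flow):
def save_transcription_locally(wav_file: str, result: Dict) -> Path:
    """Hypothetical helper: write one transcription result into the temp directory."""
    out_path = TEMP_DIR / (Path(wav_file).stem + ".json")
    with out_path.open("w") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    return out_path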
async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Downloads one audio file, transcribes it on a free server, and records the outcome."""
    # Wait until a server is free, then claim it under the lock
    server = None
    while server is None:
        async with server_lock:
            for s in servers:
                if not s.is_processing:
                    s.is_processing = True
                    server = s
                    break
        if server is None:
            await asyncio.sleep(1)
    try:
        # Download in a worker thread so the event loop keeps serving other
        # tasks; use the returned path, since hf_hub_download mirrors the
        # repo's subdirectories under local_dir rather than flattening them
        local_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_AUDIO_DATASET_ID,
            filename=wav_file,
            repo_type="dataset",
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        wav_path = Path(local_path)
        # Transcribe
        result = await transcribe_with_server(server, wav_path)
        if result:
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] Success: {wav_file}")
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] Failed: {wav_file}")
        if wav_path.exists():
            wav_path.unlink()
    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        server.release()
async def main_processing_loop():
    """Continuously picks failed and new files from the dataset and fans them out to the servers."""
    print(f"[{FLOW_ID}] Starting main processing loop...")
    while True:
        try:
            state = await download_hf_state()
            progress = load_progress()
            file_list = await get_audio_file_list(progress)
            if not file_list:
                print(f"[{FLOW_ID}] File list empty, retrying in 60s...")
                await asyncio.sleep(60)
                continue
            # 1. Retry files that previously failed transcription
            failed_files = [f for f, s in state.get("file_states", {}).items() if s == "failed_transcription"]
            # 2. Take the next chunk of new files based on next_download_index
            next_idx = state.get("next_download_index", 0)
            new_files = file_list[next_idx:next_idx + 100]
            # Prioritize failed files, then append new ones not yet tracked
            files_to_process = failed_files + [f for f in new_files if f not in state["file_states"]]
            if not files_to_process:
                print(f"[{FLOW_ID}] No files to process. Sleeping...")
                await asyncio.sleep(60)
                continue
            print(f"[{FLOW_ID}] Processing {len(files_to_process)} files ({len(failed_files)} failed, {len(files_to_process) - len(failed_files)} new)...")
            # Process in batches sized to the number of servers
            batch_size = len(servers)
            for i in range(0, len(files_to_process), batch_size):
                batch = files_to_process[i:i + batch_size]
                tasks = [process_file_task(f, state, progress) for f in batch]
                await asyncio.gather(*tasks)
                # Advance next_download_index past the newest file handled in this batch
                processed_new = [f for f in batch if f in new_files]
                if processed_new:
                    last_new_file = processed_new[-1]
                    state["next_download_index"] = file_list.index(last_new_file) + 1
                # Save and upload state after each batch
                await upload_hf_state(state)
                save_progress(progress)
            await asyncio.sleep(10)
        except Exception as e:
            print(f"[{FLOW_ID}] Error in main loop: {e}")
            await asyncio.sleep(60)
# --- FastAPI App ---
app = FastAPI(title=f"Flow Server {FLOW_ID} API")

@app.on_event("startup")
async def startup_event():
    asyncio.create_task(main_processing_loop())

@app.get("/")
async def root():
    progress = load_progress()
    state = await download_hf_state()
    failed_count = sum(1 for s in state.get("file_states", {}).values() if s == "failed_transcription")
    return {
        "flow_id": FLOW_ID,
        "status": "running",
        "next_download_index": state.get("next_download_index", 0),
        "failed_transcriptions": failed_count,
        "uploaded_count": progress.get("uploaded_count", 0),
        "total_files_in_list": len(progress.get('file_list', []))
    }

@app.post("/start")  # route path not specified in the source; "/start" is assumed
async def start_processing(request: ProcessStartRequest):
    state = await download_hf_state()
    state["next_download_index"] = request.start_index - 1  # convert 1-indexed to 0-indexed
    await upload_hf_state(state)
    return {"status": "index_reset", "new_index": request.start_index}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
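# Example interaction (assuming the default FLOW_PORT=8001 and the routes wired above):
#   curl http://localhost:8001/
#       -> {"flow_id": "flow_default", "status": "running", ...}
#   curl -X POST http://localhost:8001/start \
#        -H "Content-Type: application/json" -d '{"start_index": 500}'
#       -> {"status": "index_reset", "new_index": 500}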