"""FastAPI service that scrapes YouTube transcripts with a headless Chrome browser."""

import asyncio
import logging
import os
import shutil
import subprocess
import threading
import time
from pathlib import Path
from urllib.parse import urlparse

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from selenium import webdriver
from selenium.common.exceptions import (
    NoSuchElementException,
    StaleElementReferenceException,
    TimeoutException,
)
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from tenacity import retry, stop_after_attempt, wait_fixed

app = FastAPI()

# Configure CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins if cookies or auth are ever introduced.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Chrome binaries probed, in order, by init_driver() and reported by /health.
CHROME_BINARY_PATHS = [
    "/usr/bin/google-chrome",
    "/usr/bin/google-chrome-stable",
]

# Hosts /transcript is allowed to open (plus any *.youtube.com subdomain).
# Matched against the parsed hostname, so "https://evil.example/?youtube.com"
# is rejected instead of being loaded by the headless browser.
_YOUTUBE_HOSTS = frozenset({"youtube.com", "youtu.be"})

# init_driver() kills every running Chrome process and always binds remote
# debugging port 9222, so at most one browser session may be alive at a time.
_DRIVER_LOCK = threading.Lock()


# Pydantic models
class VideoRequest(BaseModel):
    url: str


class TranscriptResponse(BaseModel):
    success: bool
    transcript: list[str] | None
    error: str | None
    processing_time: float


def _find_chrome_binary():
    """Return the first path in CHROME_BINARY_PATHS that exists, or None."""
    return next((path for path in CHROME_BINARY_PATHS if os.path.exists(path)), None)


@retry(stop=stop_after_attempt(3), wait=wait_fixed(3), reraise=True)
def init_driver():
    """Start a headless Chrome WebDriver session.

    Retried up to 3 attempts, 3 s apart. ``reraise=True`` makes the final
    attempt's own exception propagate instead of tenacity's ``RetryError``,
    so the error reported to clients names the real cause.

    Returns:
        A started ``webdriver.Chrome``; the caller is responsible for ``quit()``.

    Raises:
        Exception: if no Chrome binary is found, or ChromeDriver is missing or
            the session cannot be started (message prefixed with
            "Driver initialization failed: ").
    """
    # Best-effort cleanup of Chrome processes left over from earlier sessions.
    # `pkill -f` matches any command line containing "chrome" (chromedriver
    # included), which is only safe because sessions are serialized.
    try:
        subprocess.run(["pkill", "-f", "chrome"], check=False)
        logger.info("Terminated any existing Chrome processes")
    except Exception as e:
        logger.warning(f"Failed to terminate Chrome processes: {str(e)}")

    options = Options()
    options.add_argument("--headless=new")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")
    options.add_argument("--disable-gpu")
    options.add_argument("--disable-extensions")
    # --user-data-dir is intentionally not set, to avoid profile-dir conflicts.
    options.add_argument("--disable-setuid-sandbox")
    # Fixed port: another concurrent session would collide here (see _DRIVER_LOCK).
    options.add_argument("--remote-debugging-port=9222")

    chrome_path = _find_chrome_binary()
    if not chrome_path:
        logger.error(f"No Chrome binary found in paths: {CHROME_BINARY_PATHS}")
        raise Exception(f"No Chrome binary found in paths: {CHROME_BINARY_PATHS}")
    options.binary_location = chrome_path
    logger.info(f"Using Chrome binary: {chrome_path}")

    try:
        chromedriver_path = shutil.which("chromedriver")
        if not chromedriver_path or not os.path.exists(chromedriver_path):
            logger.error(f"ChromeDriver not found at {chromedriver_path}")
            raise Exception(f"ChromeDriver not found at {chromedriver_path}")
        service = Service(executable_path=chromedriver_path)
        driver = webdriver.Chrome(service=service, options=options)
    except Exception as e:
        logger.error(f"Driver initialization failed: {str(e)}")
        raise Exception(f"Driver initialization failed: {str(e)}") from e

    # Version info is diagnostic only. A missing capability key must not raise
    # here, because the browser is already running and nothing would quit it.
    capabilities = driver.capabilities
    chrome_version = capabilities.get("browserVersion", "unknown")
    raw_driver_version = capabilities.get("chrome", {}).get("chromedriverVersion") or "unknown"
    chromedriver_version = raw_driver_version.split()[0]
    logger.info(f"Chrome version: {chrome_version}, ChromeDriver version: {chromedriver_version}")
    return driver


def _normalize_youtube_url(url):
    """Return *url* as an absolute http(s) YouTube URL, or None if it is not one.

    Surrounding whitespace is stripped and a missing scheme defaults to https
    (e.g. "youtu.be/abc"). Only youtube.com, its subdomains, and youtu.be are
    accepted, so the browser cannot be pointed at arbitrary hosts.
    """
    candidate = url.strip()
    if "://" not in candidate:
        candidate = "https://" + candidate
    try:
        parsed = urlparse(candidate)
    except ValueError:  # e.g. malformed IPv6 literal
        return None
    if parsed.scheme not in ("http", "https"):
        return None
    host = parsed.hostname or ""  # .hostname is already lower-cased
    if host in _YOUTUBE_HOSTS or host.endswith(".youtube.com"):
        return candidate
    return None


def _accept_cookies(driver):
    """Click the cookie-consent "Accept all" button if it shows up within 5 s."""
    try:
        cookie_button = WebDriverWait(driver, 5).until(
            EC.element_to_be_clickable((By.XPATH, "//*[contains(text(), 'Accept all')]"))
        )
        cookie_button.click()
        logger.info("Accepted cookies")
    except TimeoutException:
        logger.info("No cookie consent found")


def _open_transcript_panel(driver):
    """Expand the description, open the transcript panel, and wait for segments.

    Raises:
        TimeoutException: if a button or the segments container does not appear.
    """
    logger.info("Clicking 'Show more' button")
    more_button = WebDriverWait(driver, 10).until(
        EC.element_to_be_clickable((By.ID, "expand"))
    )
    # Clicked via JavaScript rather than .click(); presumably to avoid overlay
    # interception on the watch page — confirm if this is ever changed.
    driver.execute_script("arguments[0].click();", more_button)

    logger.info("Clicking transcript button")
    transcript_button = WebDriverWait(driver, 10).until(
        EC.element_to_be_clickable((By.CSS_SELECTOR, "button[aria-label='Show transcript']"))
    )
    driver.execute_script("arguments[0].click();", transcript_button)

    logger.info("Waiting for transcript segments")
    WebDriverWait(driver, 15).until(
        EC.presence_of_element_located((By.ID, "segments-container"))
    )


def _read_transcript_segments(driver):
    """Return the stripped, non-empty text of every rendered transcript segment."""
    logger.info("Extracting transcript")
    transcript = []
    for segment in driver.find_elements(By.CSS_SELECTOR, "div.ytd-transcript-segment-renderer"):
        try:
            text = segment.find_element(By.CLASS_NAME, "segment-text").text.strip()
        except (NoSuchElementException, StaleElementReferenceException):
            # Skip matched divs without a segment-text child, or ones re-rendered
            # since the lookup; any other error is a real failure and propagates.
            continue
        if text:
            transcript.append(text)
    return transcript


def _failure_response(error, start_time):
    """Build a success=False TranscriptResponse timed from *start_time*."""
    return TranscriptResponse(
        success=False,
        transcript=None,
        error=error,
        processing_time=time.perf_counter() - start_time,
    )


def _quit_driver(driver):
    """Close *driver*, logging instead of raising if the session is already gone.

    Runs from a ``finally`` clause, where an exception would replace the
    response that was already computed.
    """
    try:
        driver.quit()
        logger.info("Driver closed")
    except Exception as e:
        logger.warning(f"Failed to close driver cleanly: {str(e)}")


def _scrape_transcript(video_url):
    """Load *video_url* in headless Chrome and return its transcript response.

    Blocking; intended to run in a worker thread. Holds _DRIVER_LOCK for the
    whole browser session. processing_time starts once the lock is held, so
    time spent queued behind another request is not counted.

    Raises:
        HTTPException: 404 when the transcript panel opens but yields no text.

    Timeouts and all other failures are returned as ``success=False`` responses.
    """
    with _DRIVER_LOCK:
        start_time = time.perf_counter()
        driver = None
        try:
            driver = init_driver()
            logger.info(f"Processing URL: {video_url}")
            driver.get(video_url)

            _accept_cookies(driver)
            _open_transcript_panel(driver)
            transcript = _read_transcript_segments(driver)

            if not transcript:
                raise HTTPException(status_code=404, detail="No transcript available")

            logger.info(f"Extracted {len(transcript)} transcript segments")
            return TranscriptResponse(
                success=True,
                transcript=transcript,
                error=None,
                processing_time=time.perf_counter() - start_time,
            )
        except HTTPException:
            # Deliberate HTTP errors must reach FastAPI with their status code
            # instead of being flattened into a 200 by the generic handler.
            raise
        except TimeoutException:
            error_msg = "Timed out waiting for page elements - the video might not have transcripts"
            logger.error(error_msg)
            return _failure_response(error_msg, start_time)
        except Exception as e:
            logger.exception(f"Error: {str(e)}")
            return _failure_response(str(e), start_time)
        finally:
            if driver is not None:
                _quit_driver(driver)


# POST /transcript: validate the URL, then scrape in a worker thread. Selenium
# calls block for up to ~30 s, and running them on the event loop would stall
# every other endpoint (including /health) for that long.
#
# Responses:
#   400 (HTTPException)     - URL is not a YouTube URL.
#   404 (HTTPException)     - transcript panel opened but contained no text.
#   200 with success=False  - timeouts and any other scraping failure.
@app.post("/transcript", response_model=TranscriptResponse)
async def get_transcript(request: VideoRequest):
    video_url = _normalize_youtube_url(request.url)
    if video_url is None:
        raise HTTPException(status_code=400, detail="Invalid YouTube URL")
    return await asyncio.to_thread(_scrape_transcript, video_url)


# GET /health: report whether Chrome and ChromeDriver are present, plus a count
# of processes mentioning "chrome". ChromePath uses the same lookup as
# init_driver(), so it names the binary the scraper would actually launch.
@app.get("/health")
def health_check():
    chrome_path = _find_chrome_binary()
    chromedriver_path = shutil.which("chromedriver")
    try:
        # Counts every `ps aux` line containing "chrome" (case-insensitive),
        # so chromedriver processes are included in the total.
        result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
        chrome_process_count = sum(
            "chrome" in line.lower() for line in result.stdout.splitlines()
        )
    except Exception as e:
        chrome_process_count = f"Error checking processes: {str(e)}"

    return {
        "ChromePath": chrome_path,
        "ChromeDriverPath": chromedriver_path,
        "ChromeExists": Path(chrome_path or "").exists(),
        "ChromeDriverExists": Path(chromedriver_path or "").exists(),
        "ChromeProcessCount": chrome_process_count,
    }
# GET /: static landing payload confirming the service is reachable.
@app.get("/")
async def root():
    greeting = "Welcome to YouTube Transcript API"
    return {"message": greeting}


if __name__ == "__main__":
    # Local/dev entry point: serve on all interfaces, port taken from $PORT
    # (7860 when unset), in a single worker process.
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, workers=1)