# hamza2923 — "Update main.py" (commit 248afa7, verified)
import logging
import os
import shutil
import subprocess
import time
from pathlib import Path
from urllib.parse import urlparse

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from selenium import webdriver
from selenium.common.exceptions import (
    NoSuchElementException,
    StaleElementReferenceException,
    TimeoutException,
)
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from tenacity import retry, stop_after_attempt, wait_fixed
# Module-wide logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application served by uvicorn.
app = FastAPI()

# CORS: accept cross-origin calls from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests; nothing in this file appears to use cookies
# or auth, so allow_credentials=False may be sufficient — confirm with clients.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Pydantic models
# Request body for POST /transcript.
class VideoRequest(BaseModel):
    # URL of the video to process; the endpoint checks that it points at YouTube.
    url: str
# Response envelope for POST /transcript, used for both successful and failed scrapes.
class TranscriptResponse(BaseModel):
    # True only when at least one non-empty transcript segment was extracted.
    success: bool
    # Stripped segment texts in the order they were found on the page; None on failure.
    transcript: list[str] | None
    # Failure description; None on success.
    error: str | None
    # Wall-clock seconds spent in the handler (difference of time.time() readings).
    processing_time: float
@retry(stop=stop_after_attempt(3), wait=wait_fixed(3), reraise=True)
def init_driver():
    """Start and return a headless Chrome WebDriver session.

    Before launching, kills every process matching ``pkill -f chrome``. The
    Chrome binary is the first existing entry of a fixed list of paths, with a
    PATH lookup as fallback; ChromeDriver is resolved from PATH.

    tenacity retries the whole function up to 3 times, 3 seconds apart. With
    ``reraise=True`` the last underlying exception reaches the caller, rather
    than an opaque ``tenacity.RetryError``.

    Returns:
        A live ``webdriver.Chrome``. The caller is responsible for ``quit()``.

    Raises:
        Exception: no Chrome binary or ChromeDriver found, or the driver
            failed to start. Driver-start failures are chained with ``from``.
    """
    # Best-effort cleanup so a stale browser cannot hold the fixed debugging
    # port configured below. NOTE(review): "-f chrome" matches any process whose
    # command line contains "chrome" (chromedriver included), so it would also
    # kill a browser belonging to another in-flight request.
    try:
        subprocess.run(["pkill", "-f", "chrome"], check=False)
        logger.info("Terminated any existing Chrome processes")
    except Exception as e:
        logger.warning(f"Failed to terminate Chrome processes: {str(e)}")

    options = Options()
    options.add_argument("--headless=new")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")
    options.add_argument("--disable-gpu")
    options.add_argument("--disable-extensions")
    # --user-data-dir is deliberately not set, to avoid profile-lock conflicts.
    options.add_argument("--disable-setuid-sandbox")
    options.add_argument("--remote-debugging-port=9222")

    possible_chrome_paths = [
        "/usr/bin/google-chrome",
        "/usr/bin/google-chrome-stable",
    ]
    chrome_path = next((p for p in possible_chrome_paths if os.path.exists(p)), None)
    if chrome_path is None:
        # Fall back to the same PATH lookup that /health reports.
        chrome_path = shutil.which("google-chrome") or shutil.which("google-chrome-stable")
    if not chrome_path:
        logger.error(f"No Chrome binary found in paths: {possible_chrome_paths}")
        raise Exception(f"No Chrome binary found in paths: {possible_chrome_paths}")
    options.binary_location = chrome_path
    logger.info(f"Using Chrome binary: {chrome_path}")

    # shutil.which only returns paths that exist and are executable, so no
    # extra existence check is needed.
    chromedriver_path = shutil.which("chromedriver")
    if not chromedriver_path:
        logger.error("ChromeDriver not found on PATH")
        raise Exception("ChromeDriver not found on PATH")

    driver = None
    try:
        service = Service(executable_path=chromedriver_path)
        driver = webdriver.Chrome(service=service, options=options)
        capabilities = driver.capabilities
        chrome_version = capabilities.get("browserVersion", "unknown")
        # Keep only the first whitespace-separated token of chromedriverVersion
        # (the version number); tolerate the key being absent.
        raw_driver_version = capabilities.get("chrome", {}).get("chromedriverVersion") or "unknown"
        chromedriver_version = raw_driver_version.split()[0]
        logger.info(f"Chrome version: {chrome_version}, ChromeDriver version: {chromedriver_version}")
        return driver
    except Exception as e:
        # If the browser started but a later step failed, quit it so each
        # retry does not leave an orphaned Chrome process behind.
        if driver is not None:
            try:
                driver.quit()
            except Exception:
                logger.warning("Failed to quit partially initialised driver", exc_info=True)
        logger.error(f"Driver initialization failed: {str(e)}")
        raise Exception(f"Driver initialization failed: {str(e)}") from e
def _is_youtube_url(url):
    """Return True if *url* is an absolute http(s) URL on a YouTube host.

    Accepted hosts are ``youtube.com``, any ``*.youtube.com`` subdomain
    (e.g. ``www``, ``m``, ``music``) and the ``youtu.be`` short-link host.
    The host is parsed, not substring-matched, so URLs such as
    ``https://evil.example/?q=youtube.com`` or ``https://youtube.com.evil.example/``
    are rejected instead of being opened in the headless browser.
    """
    try:
        parsed = urlparse(url)
        host = parsed.hostname or ""  # urllib already lower-cases hostname
    except ValueError:
        # e.g. unbalanced IPv6 brackets in the netloc
        return False
    if parsed.scheme not in ("http", "https"):
        return False
    return host in ("youtube.com", "youtu.be") or host.endswith(".youtube.com")


@app.post("/transcript", response_model=TranscriptResponse)
async def get_transcript(request: VideoRequest):
    """Scrape the transcript of a YouTube video using headless Chrome.

    Opens the video page and dismisses the cookie banner if one appears. It
    then expands the description, opens the transcript panel, and collects
    the text of every transcript segment.

    Returns:
        TranscriptResponse with ``success=True`` and the non-empty segment
        texts. Timeouts and unexpected errors are reported as
        ``success=False`` with an ``error`` message.

    Raises:
        HTTPException: 400 if ``url`` is not an http(s) YouTube URL; 404 if
            the transcript panel loaded but no segment text was found.
    """
    # NOTE(review): this handler is async, but every Selenium call below blocks,
    # so the event loop is held for the whole scrape. That serialises requests,
    # which init_driver's global "pkill chrome" and fixed debugging port rely
    # on. Revisit both together before moving this work to a thread pool.
    start_time = time.time()
    driver = None
    try:
        video_url = request.url
        if not _is_youtube_url(video_url):
            raise HTTPException(status_code=400, detail="Invalid YouTube URL")
        driver = init_driver()
        logger.info(f"Processing URL: {video_url}")
        driver.get(video_url)

        # Consent banners only appear in some regions; absence is not an error.
        try:
            cookie_button = WebDriverWait(driver, 5).until(
                EC.element_to_be_clickable((By.XPATH, "//*[contains(text(), 'Accept all')]"))
            )
            cookie_button.click()
            logger.info("Accepted cookies")
        except TimeoutException:
            logger.info("No cookie consent found")

        # Clicks go through JavaScript so overlays cannot intercept them.
        logger.info("Clicking 'Show more' button")
        more_button = WebDriverWait(driver, 10).until(
            EC.element_to_be_clickable((By.ID, "expand"))
        )
        driver.execute_script("arguments[0].click();", more_button)

        logger.info("Clicking transcript button")
        transcript_button = WebDriverWait(driver, 10).until(
            EC.element_to_be_clickable((By.CSS_SELECTOR, "button[aria-label='Show transcript']"))
        )
        driver.execute_script("arguments[0].click();", transcript_button)

        logger.info("Waiting for transcript segments")
        WebDriverWait(driver, 15).until(
            EC.presence_of_element_located((By.ID, "segments-container"))
        )

        logger.info("Extracting transcript")
        segments = driver.find_elements(By.CSS_SELECTOR, "div.ytd-transcript-segment-renderer")
        transcript = []
        for segment in segments:
            try:
                text = segment.find_element(By.CLASS_NAME, "segment-text").text.strip()
            except (NoSuchElementException, StaleElementReferenceException):
                # Skip rows without a text node, or rows re-rendered mid-read.
                continue
            if text:
                transcript.append(text)

        if not transcript:
            raise HTTPException(status_code=404, detail="No transcript available")

        logger.info(f"Extracted {len(transcript)} transcript segments")
        return TranscriptResponse(
            success=True,
            transcript=transcript,
            error=None,
            processing_time=time.time() - start_time
        )
    except HTTPException:
        # Client errors raised above keep their intended status codes instead
        # of being folded into a 200 response by the generic handler below.
        raise
    except TimeoutException:
        error_msg = "Timed out waiting for page elements - the video might not have transcripts"
        logger.error(error_msg)
        return TranscriptResponse(
            success=False,
            transcript=None,
            error=error_msg,
            processing_time=time.time() - start_time
        )
    except Exception as e:
        logger.exception(f"Error: {str(e)}")
        return TranscriptResponse(
            success=False,
            transcript=None,
            error=str(e),
            processing_time=time.time() - start_time
        )
    finally:
        # A failure while closing the browser must not replace the response.
        if driver:
            try:
                driver.quit()
                logger.info("Driver closed")
            except Exception:
                logger.warning("Failed to close driver", exc_info=True)
@app.get("/health")
def health_check():
    """Report where Chrome and ChromeDriver resolve on PATH and count chrome processes.

    Returns:
        dict with ``ChromePath`` / ``ChromeDriverPath`` (str or None), matching
        ``ChromeExists`` / ``ChromeDriverExists`` booleans, and
        ``ChromeProcessCount``. The count is the number of ``ps aux`` lines
        containing "chrome" (case-insensitive, so chromedriver is included), or
        an error string if ``ps`` could not be run.
    """
    chrome_path = shutil.which("google-chrome")
    chromedriver_path = shutil.which("chromedriver")
    try:
        result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
        chrome_process_count = sum(
            1 for line in result.stdout.splitlines() if "chrome" in line.lower()
        )
    except Exception as e:
        chrome_process_count = f"Error checking processes: {str(e)}"
    return {
        "ChromePath": chrome_path,
        "ChromeDriverPath": chromedriver_path,
        # Path("") is Path("."), which always exists, so a missing binary must
        # be checked explicitly rather than via Path(path or "").
        "ChromeExists": chrome_path is not None and Path(chrome_path).exists(),
        "ChromeDriverExists": chromedriver_path is not None and Path(chromedriver_path).exists(),
        "ChromeProcessCount": chrome_process_count,
    }
@app.get("/")
async def root():
    """Landing endpoint: return a fixed welcome payload."""
    welcome = "Welcome to YouTube Transcript API"
    return dict(message=welcome)
if __name__ == "__main__":
    # Script entry point: serve the app with a single uvicorn worker.
    # The listen port comes from $PORT when set, otherwise 7860.
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, workers=1)