| import os |
| import asyncio |
| import logging |
| from datetime import datetime |
| from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks |
| from sqlalchemy.orm import Session |
| from typing import Dict, List |
| from api.websocket_routes import manager |
|
|
| from api.auth import get_current_user |
| from models.schemas import PodcastGenerateRequest, PodcastResponse |
| from models import db_models |
| from core.database import get_db, SessionLocal |
| from services.podcast_service import podcast_service |
| from services.s3_service import s3_service |
| from core import constants |
|
|
| router = APIRouter(prefix="/api/podcast", tags=["podcast"]) |
| logger = logging.getLogger(__name__) |
|
|
@router.get("/config")
async def get_podcast_config():
    """Return the option catalogues a client can choose from when requesting a podcast.

    The payload maps each option group (voices, background music, formats,
    TTS models, script models) to the matching value in ``core.constants``.
    """
    options = dict(
        voices=constants.PODCAST_VOICES,
        bgm=constants.PODCAST_BGM,
        formats=constants.PODCAST_FORMATS,
        tts_models=constants.PODCAST_TTS_MODALS,
        models=constants.PODCAST_MODALS,
    )
    return options
|
|
def _upload_audio_to_s3(audio_path: str, s3_key: str) -> None:
    """Upload the local file at ``audio_path`` to the configured S3 bucket as ``s3_key``.

    This call blocks; ``run_podcast_generation`` runs it through ``asyncio.to_thread``.
    """
    # Function-scope imports: neither name is imported at module level here.
    import boto3
    from core.config import settings

    s3_client = boto3.client(
        "s3",
        aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
        region_name=settings.AWS_REGION,
    )
    # Pass the open file handle instead of first reading the whole file into
    # a bytes object.
    with open(audio_path, "rb") as audio_file:
        s3_client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=audio_file)


def _record_generation_failure(db, podcast_id: int, error: Exception) -> None:
    """Set the podcast row to "failed" and store the error text, if the row exists."""
    # Drop whatever the failed attempt left in the session; after a failed
    # flush/commit the session must be rolled back before it can be used again.
    db.rollback()
    podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
    if podcast:
        podcast.status = "failed"
        podcast.error_message = str(error)
        db.commit()


async def run_podcast_generation(podcast_id: int, request: PodcastGenerateRequest, user_id: int):
    """Background task that generates a podcast and keeps its database row up to date.

    Steps, each announced to the ``user_{user_id}`` connection through ``manager``:
    optional PDF analysis (when ``request.file_key`` is set), script generation,
    audio generation, upload to S3 under ``users/{user_id}/outputs/podcasts/``,
    and finally marking the row "completed" and sending the result.

    On any failure the row is marked "failed" with the error text and an error
    message is sent to the same connection. The local audio file is removed on
    every exit path and the database session is always closed.

    Args:
        podcast_id: Primary key of the ``Podcast`` row to update. If no such
            row exists the task returns without doing anything.
        request: The generation parameters submitted by the client.
        user_id: Owner of the podcast; used for the connection id and S3 key.
    """
    connection_id = f"user_{user_id}"
    audio_path = None
    db = SessionLocal()
    try:
        podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
        if not podcast:
            return

        podcast.status = "processing"
        db.commit()

        await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")

        # The analysis report is optional input for the script step; it stays
        # empty when the request carries no source file.
        analysis_report = ""
        if request.file_key:
            analysis_report = await podcast_service.analyze_pdf(
                file_key=request.file_key,
                duration_minutes=request.duration_minutes
            )
        await manager.send_progress(connection_id, 20, "processing", "Generating podcast script...")

        script = await podcast_service.generate_script(
            user_prompt=request.user_prompt,
            model=request.model,
            duration_minutes=request.duration_minutes,
            podcast_format=request.podcast_format,
            pdf_suggestions=analysis_report,
            file_key=request.file_key
        )
        if not script:
            raise RuntimeError("Failed to generate script")

        await manager.send_progress(connection_id, 40, "processing", "Generating audio (this may take several minutes)...")

        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=request.tts_model,
            spk1_voice=request.spk1_voice,
            spk2_voice=request.spk2_voice,
            temperature=request.temperature,
            bgm_choice=request.bgm_choice
        )
        if not audio_path:
            raise RuntimeError("Failed to generate audio")

        await manager.send_progress(connection_id, 85, "processing", "Uploading to S3...")

        filename = os.path.basename(audio_path)
        s3_key = f"users/{user_id}/outputs/podcasts/{filename}"

        # The upload is blocking I/O, so keep it off the event loop.
        await asyncio.to_thread(_upload_audio_to_s3, audio_path, s3_key)

        public_url = s3_service.get_public_url(s3_key)

        podcast.s3_key = s3_key
        podcast.s3_url = public_url
        podcast.script = script
        podcast.status = "completed"
        db.commit()

        # The podcast is already stored and committed at this point. A failed
        # notification must not fall through to the failure handler below and
        # flip the row to "failed".
        try:
            await manager.send_result(connection_id, {
                "id": podcast.id,
                "status": "completed",
                "title": podcast.title,
                "public_url": public_url
            })
        except Exception:
            logger.exception("Could not deliver completion result for podcast ID %s", podcast_id)

    except Exception as e:
        logger.exception("Background podcast generation failed for ID %s: %s", podcast_id, e)
        try:
            _record_generation_failure(db, podcast_id, e)
        except Exception:
            # Still try to tell the user below even if the row could not be updated.
            logger.exception("Could not record failure for podcast ID %s", podcast_id)

        await manager.send_error(connection_id, f"Generation failed: {str(e)}")
    finally:
        # Remove the local audio file on every exit path, not only on success.
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                logger.warning("Could not remove local audio file %s", audio_path, exc_info=True)
        db.close()
|
|
@router.post("/generate", response_model=PodcastResponse)
async def generate_podcast(
    request: PodcastGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Initiates podcast generation in the background.

    Creates a Podcast record with status "processing" immediately, schedules
    run_podcast_generation as a background task, and returns the new record.

    When request.file_key is set it must match a Source owned by the current
    user; the podcast is then linked to that source and titled after the file
    name (last path segment, without its final extension). Otherwise the title
    carries the current UTC date and time.

    Raises:
        HTTPException: 403 if request.file_key does not match a Source that
            belongs to the current user.
    """
    # Function-scope import: only `datetime` is imported at module level here.
    from datetime import timezone

    source_id = None
    file_base = None
    if request.file_key:
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id
        file_base = request.file_key.split('/')[-1].rsplit('.', 1)[0]

    if file_base:
        title = f"Podcast-{file_base}"
    else:
        # datetime.utcnow() is deprecated; the timezone-aware call formats to
        # the same string.
        title = f"Podcast {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

    db_podcast = db_models.Podcast(
        title=title,
        user_id=current_user.id,
        source_id=source_id,
        status="processing"
    )
    db.add(db_podcast)
    db.commit()
    # Reload the row so database-generated fields (such as the id) are populated.
    db.refresh(db_podcast)

    background_tasks.add_task(run_podcast_generation, db_podcast.id, request, current_user.id)

    return db_podcast
|
|
@router.get("/list", response_model=List[PodcastResponse])
async def list_podcasts(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Lists all podcasts for the current user including their generation status.

    Results are ordered by created_at, newest first.

    Raises:
        HTTPException: 500 with the error text if the query or the response
            validation fails.
    """
    try:
        podcasts = db.query(db_models.Podcast).filter(
            db_models.Podcast.user_id == current_user.id
        ).order_by(db_models.Podcast.created_at.desc()).all()

        return [PodcastResponse.model_validate(p) for p in podcasts]
    except Exception as e:
        # Log server-side so the cause is not visible only in the HTTP response.
        logger.exception("Failed to list podcasts")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@router.delete("/{podcast_id}")
async def delete_podcast(
    podcast_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Deletes a specific podcast from database and S3.

    Only a podcast owned by the current user can be deleted.

    Raises:
        HTTPException: 404 if the user has no podcast with this id; 500 with
            the error text if the database or S3 deletion fails.
    """
    podcast = db.query(db_models.Podcast).filter(
        db_models.Podcast.id == podcast_id,
        db_models.Podcast.user_id == current_user.id
    ).first()

    if not podcast:
        raise HTTPException(status_code=404, detail="Podcast not found")

    # Read the key before the row is deleted.
    s3_key = podcast.s3_key

    try:
        # Stage and flush the row deletion first so a database-level rejection
        # surfaces while the audio object still exists. The delete stays
        # uncommitted and is rolled back if the S3 step fails.
        db.delete(podcast)
        db.flush()

        if s3_key:
            await s3_service.delete_file(s3_key)

        db.commit()

        return {"message": "Podcast and associated audio file deleted successfully"}
    except Exception as e:
        db.rollback()
        logger.exception("Failed to delete podcast: %s", e)
        raise HTTPException(status_code=500, detail=f"Deletion failed: {str(e)}") from e