# Spaces: Sleeping -- hosting-page status text captured with the file; not part of the program.
import base64 | |
import logging | |
from typing import List, Optional | |
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from pydantic import AnyHttpUrl, BaseModel, UrlConstraints | |
from contextlib import asynccontextmanager | |
from PIL import Image | |
from config import get_settings | |
import uvicorn | |
from utils.audio_utils import AudioUtils | |
from utils.caption_utils import ImageCaptioning | |
from utils.image_utils import UrlTest | |
from utils.topic_generation import TopicGenerator | |
# Setup logging: configure the root logger at INFO and create this module's
# named logger, which the lifespan hook and request handlers write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models for request/response
class TopicResponse(BaseModel):
    # Response body returned by generate_topic.

    # Topics produced for the submitted text and/or image.
    topics: List[str]
    # Caption for the image; generate_topic passes None on its text-only path.
    caption: Optional[str]
class AudioResponse(BaseModel):
    # Response body returned by generate_audio.

    # Audio bytes encoded as base64 text (see generate_audio).
    audio_base64: str
class TranscriptionResponse(BaseModel):
    # Response body returned by transcribe_audio.

    # Value returned by AudioUtils.improved_transcribe for the uploaded file.
    audio_transcription: str
# Context manager for startup and shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared helper objects at startup and log at shutdown.

    The function is handed to ``FastAPI(lifespan=...)``, which expects an
    async context manager; the ``asynccontextmanager`` decorator (already
    imported by this file) turns this generator into one.

    Args:
        app: The application whose ``state`` receives the helper instances.

    Yields:
        None. Control returns to the server between startup and shutdown.
    """
    # Startup: build each helper once and share it through ``app.state`` so
    # request handlers reuse the same instances.
    app.state.topic_generator = TopicGenerator()
    app.state.img_caption = ImageCaptioning()
    app.state.audio_utils = AudioUtils()
    app.state.url_utils = UrlTest()
    logger.info("Application startup complete")

    yield

    # Shutdown
    logger.info("Application shutdown")
# Application object; ``lifespan`` is the startup/shutdown hook defined in
# this module.
app = FastAPI(
    title="Rediones API",
    lifespan=lifespan,
)
# CORS
# NOTE(review): nothing visible in this file registers or calls this
# coroutine. Also confirm the timing: Starlette refuses ``add_middleware``
# once the application has started, and FastAPI documents that startup
# handlers are not run when a ``lifespan`` is supplied -- verify that CORS is
# actually being installed.
async def startup_event():
    """Add CORS middleware to ``app`` when settings list allowed origins.

    Does nothing when ``ALLOWED_ORIGINS`` is empty or unset. Otherwise the
    listed origins are allowed with credentials, any method and any header.
    """
    allowed_origins = get_settings().ALLOWED_ORIGINS
    if not allowed_origins:
        return

    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
# NOTE(review): no route registration for this handler is visible in this
# view of the file -- confirm it is attached to ``app``.
async def root():
    """Return the API welcome message."""
    return {"message": "Welcome To Rediones API"}
# NOTE(review): no route registration for this handler is visible in this
# view of the file -- confirm it is attached to ``app``.
async def health():
    """Return a fixed payload indicating the service is up."""
    return {"status": "OK"}
# NOTE(review): no route registration for this handler is visible in this
# view of the file -- confirm it is attached to ``app``.
async def generate_topic(
    img: UploadFile = File(None),
    text: Optional[str] = Form(None),
    img_url: Optional[AnyHttpUrl] = Form(None),
):
    """Generate topics from text, an uploaded image, or an image URL.

    With text only, topics come from the topic generator and the caption is
    ``None``. With an image (uploaded or fetched from ``img_url``), the
    captioning helper receives the image together with ``text`` and supplies
    both topics and caption.

    Args:
        img: Optional uploaded image; its file name must end with .jpg,
            .jpeg or .png (case-insensitive).
        text: Optional text input.
        img_url: Optional URL of an image; mutually exclusive with ``img``.

    Returns:
        TopicResponse with the generated topics and, for images, a caption.

    Raises:
        HTTPException: 400 when both ``img`` and ``img_url`` are given, when
            the uploaded file name has an unsupported extension, or when no
            input is given; 500 for any other failure.
    """
    try:
        if img_url and img:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Only one of image_url or img can be accepted",
            )

        if not (img or img_url):
            if text:
                generated_topics = app.state.topic_generator.generate_topics(text)
                return TopicResponse(topics=generated_topics, caption=None)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Enter text or image. Image URL and image file are mutually exclusive.",
            )

        if img:
            # An upload may arrive without a file name; treat that as an
            # unsupported extension instead of failing on ``None.lower()``.
            filename = (img.filename or "").lower()
            if not filename.endswith((".jpg", ".png", ".jpeg")):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Image file must be ended with .jpg, .png, .jpeg",
                )
            img_file_object = Image.open(img.file)
        else:
            img_file_object = app.state.url_utils.load_image(img_url)

        capt = app.state.img_caption.combo_model(img_file_object, text)
        logger.info("Caption result: %s", capt)
        return TopicResponse(topics=capt["topics"], caption=capt["caption"])
    except HTTPException:
        # Deliberate client errors raised above must reach the caller as-is
        # rather than being rewritten into a 500 by the handler below.
        raise
    except Exception as exc:
        logger.exception("Error in generate_topic: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred",
        ) from exc
# NOTE(review): no route registration for this handler is visible in this
# view of the file -- confirm it is attached to ``app``.
async def generate_audio(text: str):
    """Synthesize speech for ``text`` and return it base64-encoded.

    Args:
        text: The text handed to the audio helper's ``speak`` method.

    Returns:
        AudioResponse whose ``audio_base64`` is the base64 text of the bytes
        returned by ``speak``.

    Raises:
        HTTPException: 500 when synthesis or encoding fails.
    """
    try:
        audio_bytes = app.state.audio_utils.speak(text)
        encoded_audio = base64.b64encode(audio_bytes).decode("utf-8")
        return AudioResponse(audio_base64=encoded_audio)
    except HTTPException:
        # Keep any HTTP error raised further down the stack intact.
        raise
    except Exception as exc:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in generate_audio: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred",
        ) from exc
# NOTE(review): no route registration for this handler is visible in this
# view of the file -- confirm it is attached to ``app``.
async def transcribe_audio(
    audio: UploadFile = File(..., description="Audio file to be transcribed."),
):
    """Transcribe an uploaded audio file.

    Args:
        audio: The uploaded audio file (required).

    Returns:
        TranscriptionResponse carrying the helper's transcription result.

    Raises:
        HTTPException: 500 when transcription fails.
    """
    try:
        # NOTE(review): the meaning of the leading 0.8 is defined by
        # AudioUtils.improved_transcribe, which is not visible here -- confirm.
        transcription = app.state.audio_utils.improved_transcribe(
            0.8, audio_file=audio.file
        )
        return TranscriptionResponse(audio_transcription=transcription)
    except HTTPException:
        # Keep any HTTP error raised further down the stack intact.
        raise
    except Exception as exc:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in transcribe_audio: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred",
        ) from exc
# Script entry point: serve the app with uvicorn on all interfaces, port
# 8000, with auto-reload enabled. The "main:app" import string names the
# ``app`` object in a module called ``main``.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)