import base64
import math
import os
import tempfile
from io import BytesIO
from typing import Optional

import cv2
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from moviepy.editor import VideoFileClip
from nudenet import NudeDetector
from PIL import Image
from pydantic import BaseModel

app = FastAPI(
    title="Nudenet API",
    description="API for detecting nudity in images and videos using Nudenet",
    version="1.0.0",
    docs_url="/",
    redoc_url="/redoc",
)

# Initialize NudeDetector with both models once at import time; each request
# selects one of them via the ``use_640m`` query parameter.
detector_320n = NudeDetector()
detector_640m = NudeDetector(model_path="640m.onnx", inference_resolution=640)

_MODEL_QUERY_DESCRIPTION = "Use the 640m model instead of the default 320n model"

# Size of each read when copying a video upload to disk, so the whole upload
# is never held in memory at once.
_UPLOAD_CHUNK_SIZE = 1024 * 1024


class Base64Image(BaseModel):
    """Request body for ``/detect_base64``."""

    # Base64-encoded image bytes; a ``data:<mime>;base64,`` prefix is allowed.
    image: str


def _select_detector(use_640m: bool) -> NudeDetector:
    """Return the 640m detector when requested, otherwise the default 320n one."""
    return detector_640m if use_640m else detector_320n


def _has_content_type(upload: UploadFile, prefix: str) -> bool:
    """Return True if the upload's declared MIME type starts with *prefix*.

    ``UploadFile.content_type`` may be ``None`` when the client omits the
    header; that is treated as a mismatch instead of crashing.
    """
    return (upload.content_type or "").startswith(prefix)


def _normalize_image(raw: bytes) -> bytes:
    """Decode *raw* with Pillow and re-encode it as JPEG bytes for the detector.

    Raises:
        HTTPException: 400 if *raw* cannot be decoded as an image.
    """
    try:
        img = Image.open(BytesIO(raw))
        # JPEG can only store L/RGB(/CMYK); RGBA, P, LA, I, ... would make
        # ``save`` raise, so everything except RGB and L is converted to RGB.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        out = BytesIO()
        img.save(out, format="JPEG")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (unknown/corrupt data) is an OSError subclass.
        raise HTTPException(status_code=400, detail="Invalid image data") from exc
    return out.getvalue()


def _to_bgr(frame):
    """Convert an RGB moviepy frame to the BGR channel order the detector expects.

    NudeDetector treats ndarray input like an OpenCV image (BGR), while
    moviepy yields RGB frames.
    """
    return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)


def _analyze_video(path: str, detector: NudeDetector, frame_interval: Optional[float]) -> list:
    """Run *detector* over sampled frames of the video at *path*.

    Blocking (decoding + inference); callers should run it in a worker thread.

    Args:
        path: Path of the video file on disk.
        detector: The NudeDetector instance to use.
        frame_interval: Seconds between analyzed frames; ``0`` analyzes every
            frame; ``None`` samples one frame per whole second.

    Returns:
        A list of dicts, one per frame with at least one detection.

    Raises:
        HTTPException: 400 if the file cannot be opened as a video.
    """
    try:
        video_clip = VideoFileClip(path)
    except OSError as exc:
        raise HTTPException(status_code=400, detail="Could not read video") from exc

    results = []
    try:
        fps = video_clip.fps
        duration = video_clip.duration
        if frame_interval is None:
            # Default behavior: one frame per second at t = 0, 1, 2, ... < duration.
            # ceil (not int) keeps clips shorter than 1s and the final partial
            # second of longer clips from being skipped.
            for t in range(max(1, math.ceil(duration))):
                detections = detector.detect(_to_bgr(video_clip.get_frame(t)))
                if detections:
                    results.append({
                        'timestamp': t,
                        'detections': detections,
                    })
        else:
            # Custom interval (in seconds) converted to a frame step; 0 (or any
            # interval shorter than one frame) analyzes every frame. round()
            # avoids drift on non-integer rates such as 29.97 fps.
            step = max(1, round(frame_interval * fps))
            for frame_number, frame in enumerate(video_clip.iter_frames()):
                if frame_number % step:
                    continue
                detections = detector.detect(_to_bgr(frame))
                if detections:
                    results.append({
                        'timestamp': frame_number / fps,
                        'frame_number': frame_number,
                        'detections': detections,
                    })
    finally:
        # Release the ffmpeg reader even if decoding or detection fails.
        video_clip.close()
    return results


@app.post("/detect", summary="Detect nudity in an image")
async def detect_nudity(
    image: UploadFile = File(...),
    use_640m: bool = Query(False, description=_MODEL_QUERY_DESCRIPTION),
):
    """Detect nudity in an uploaded image file.

    Returns 400 if the upload is not declared as an image or cannot be decoded.
    """
    if not _has_content_type(image, 'image/'):
        raise HTTPException(status_code=400, detail="File provided is not an image")

    contents = await image.read()
    # Decoding and inference are CPU-bound; keep them off the event loop.
    img_bytes = await run_in_threadpool(_normalize_image, contents)
    detector = _select_detector(use_640m)
    result = await run_in_threadpool(detector.detect, img_bytes)
    return JSONResponse(content={'result': result})


@app.post("/detect_base64", summary="Detect nudity in a base64 encoded image")
async def detect_nudity_base64(
    data: Base64Image,
    use_640m: bool = Query(False, description=_MODEL_QUERY_DESCRIPTION),
):
    """Detect nudity in a base64-encoded image (a data-URL prefix is accepted).

    Returns 400 if the string is not valid base64 or not a decodable image.
    """
    encoded = data.image
    # Strip a "data:image/...;base64," header; otherwise its characters would
    # be decoded as part of the payload.
    if encoded.startswith("data:") and "," in encoded:
        encoded = encoded.split(",", 1)[1]
    try:
        image_data = base64.b64decode(encoded)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail="Invalid base64 string") from exc

    # Same normalization as /detect so both endpoints agree on the same image.
    img_bytes = await run_in_threadpool(_normalize_image, image_data)
    detector = _select_detector(use_640m)
    result = await run_in_threadpool(detector.detect, img_bytes)
    return JSONResponse(content={'result': result})


@app.post("/detect_video", summary="Detect nudity in a video")
async def detect_nudity_video(
    video: UploadFile = File(...),
    frame_interval: Optional[float] = Query(
        default=None,
        ge=0,
        description=(
            "Interval in seconds between analyzed frames (fractions allowed). "
            "0 analyzes every frame, 1 one frame per second, 2 one frame every "
            "two seconds, etc. If not provided, defaults to one frame per second."
        ),
    ),
    use_640m: bool = Query(False, description=_MODEL_QUERY_DESCRIPTION),
):
    """Detect nudity in sampled frames of an uploaded video.

    Returns 400 if the upload is not declared as a video or cannot be read.
    """
    if not _has_content_type(video, 'video/'):
        raise HTTPException(status_code=400, detail="File provided is not a video")

    # delete=False so the file can be reopened by path (ffmpeg) after closing;
    # the finally clause guarantees it is removed on every exit path.
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    try:
        with temp_video:
            while True:
                chunk = await video.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                temp_video.write(chunk)
        results = await run_in_threadpool(
            _analyze_video, temp_video.name, _select_detector(use_640m), frame_interval
        )
    finally:
        os.unlink(temp_video.name)

    return JSONResponse(content={'results': results})


if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))