# NudeNet-FastAPI / app.py
# parth parekh
# working demo
# 794a185
import base64
import math
import os
import shutil
import tempfile
from io import BytesIO
from typing import Optional

import cv2
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from moviepy.editor import VideoFileClip
from nudenet import NudeDetector
from PIL import Image
from pydantic import BaseModel
# The FastAPI application. Swagger UI is served from the site root ("/")
# and ReDoc from "/redoc".
app = FastAPI(
    docs_url="/",
    redoc_url="/redoc",
    title="Nudenet API",
    version="1.0.0",
    description="API for detecting nudity in images and videos using Nudenet",
)
# Both detectors are built once, when this module is imported, and every
# request handler below reuses them.
#   detector_320n - NudeDetector with its default model (the "320n" model the
#                   endpoint descriptions refer to).
#   detector_640m - weights read from "640m.onnx" (presumably resolved against
#                   the process working directory), run at resolution 640.
detector_320n = NudeDetector()
detector_640m = NudeDetector(
    inference_resolution=640,
    model_path="640m.onnx",
)
class Base64Image(BaseModel):
    """JSON request body accepted by the ``/detect_base64`` endpoint."""

    # The image file's bytes, base64-encoded as text.
    image: str
def _to_rgb_jpeg_bytes(contents: bytes) -> bytes:
    """Decode *contents* with Pillow and re-encode it as an RGB JPEG.

    Raises:
        OSError: Pillow cannot identify or decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        ValueError: Pillow cannot convert the image's mode to RGB.
    """
    with Image.open(BytesIO(contents)) as img:
        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so every
        # non-RGB mode is normalised, not only RGBA; otherwise save() fails.
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        buffer = BytesIO()
        rgb.save(buffer, format='JPEG')
    return buffer.getvalue()


@app.post("/detect", summary="Detect nudity in an image")
async def detect_nudity(image: UploadFile = File(...), use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Detect nudity in an uploaded image file.
    \f
    The upload is decoded with Pillow, re-encoded as RGB JPEG bytes and passed
    to the selected NudeDetector.

    Args:
        image: Multipart upload whose declared content type starts with ``image/``.
        use_640m: Use the 640m model instead of the default 320n model.

    Returns:
        JSON ``{"result": <detector output>}``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded as one.
    """
    # content_type is None when the client sends the part without a Content-Type.
    if not (image.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File provided is not an image")
    contents = await image.read()
    try:
        # Decoding/encoding is CPU-bound; keep it off the event loop.
        jpeg_bytes = await run_in_threadpool(_to_rgb_jpeg_bytes, contents)
    except (OSError, ValueError) as err:
        raise HTTPException(status_code=400, detail="File provided is not a valid image") from err
    detector = detector_640m if use_640m else detector_320n
    # NOTE(review): assumes NudeDetector.detect may run in several worker
    # threads at once (onnxruntime InferenceSession.run is thread-safe) — confirm.
    result = await run_in_threadpool(detector.detect, jpeg_bytes)
    return JSONResponse(content={'result': result})
@app.post("/detect_base64", summary="Detect nudity in a base64 encoded image")
async def detect_nudity_base64(data: Base64Image, use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Detect nudity in a base64-encoded image sent in a JSON body.
    \f
    Args:
        data: Request body; ``data.image`` holds the base64 payload. An optional
            data-URI header such as ``data:image/png;base64,`` is stripped first.
        use_640m: Use the 640m model instead of the default 320n model.

    Returns:
        JSON ``{"result": <detector output>}``. The decoded bytes are passed to
        the detector unchanged.

    Raises:
        HTTPException: 400 if the payload is not valid base64, decodes to no
            bytes, or is not an image format Pillow can identify.
    """
    payload = data.image.strip()
    # ',' is not in the base64 alphabet, so in a data URI everything up to the
    # first comma is header. b64decode (non-validating) would otherwise keep the
    # header's letters and silently produce corrupt bytes.
    if payload.startswith('data:') and ',' in payload:
        payload = payload.split(',', 1)[1]
    try:
        image_data = base64.b64decode(payload)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail="Invalid base64 string") from err
    if not image_data:
        raise HTTPException(status_code=400, detail="Empty image data")
    # Header-only check (Image.open is lazy), so non-image data is reported as
    # a client error instead of failing inside the detector.
    try:
        with Image.open(BytesIO(image_data)):
            pass
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail="Decoded data is not a valid image") from err
    detector = detector_640m if use_640m else detector_320n
    # Inference is CPU-bound; run it off the event loop so other requests proceed.
    # NOTE(review): assumes NudeDetector.detect is safe to call from several
    # worker threads at once — confirm.
    result = await run_in_threadpool(detector.detect, image_data)
    return JSONResponse(content={'result': result})
def _scan_video_frames(video_path, frame_interval, detector):
    """Run *detector* on sampled frames of the video file at *video_path*.

    Args:
        video_path: Path of a video file readable by moviepy.
        frame_interval: ``None`` samples one frame at each whole second
            (t = 0, 1, 2, ... while t < duration). Otherwise every
            ``max(1, int(frame_interval * fps))``-th frame is analysed, so 0
            means every frame.
        detector: NudeDetector used on each sampled frame.

    Returns:
        One dict per sampled frame that produced detections:
        ``{'timestamp', 'detections'}`` for whole-second sampling, plus
        ``'frame_number'`` for interval sampling.
    """
    # Only video frames are analysed, so the audio stream is not opened.
    video_clip = VideoFileClip(video_path, audio=False)
    try:
        results = []
        if frame_interval is None:
            # ceil, not int: clips shorter than 1 s and the trailing partial
            # second of longer clips must still be sampled.
            for t in range(math.ceil(video_clip.duration)):
                # moviepy frames are RGB. Swap to BGR, which is assumed to be what
                # NudeDetector expects for ndarray input (OpenCV convention) — confirm.
                frame_bgr = cv2.cvtColor(video_clip.get_frame(t), cv2.COLOR_RGB2BGR)
                detections = detector.detect(frame_bgr)
                if detections:
                    results.append({'timestamp': t, 'detections': detections})
        else:
            fps = video_clip.fps
            step = max(1, int(frame_interval * fps))  # frame stride, loop-invariant
            for frame_number, frame in enumerate(video_clip.iter_frames()):
                if frame_number % step:
                    continue
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                detections = detector.detect(frame_bgr)
                if detections:
                    results.append({
                        'timestamp': frame_number / fps,
                        'frame_number': frame_number,
                        'detections': detections,
                    })
        return results
    finally:
        # Release the ffmpeg reader even if decoding or detection fails.
        video_clip.close()


@app.post("/detect_video", summary="Detect nudity in a video")
async def detect_nudity_video(
    video: UploadFile = File(...),
    frame_interval: Optional[float] = Query(default=None, ge=0, description="Interval between analyzed frames, in seconds (converted to a whole number of frames, at least 1). 0 analyzes every frame, 1 roughly one frame per second, 2 roughly one frame every two seconds, 0.5 roughly two frames per second. If not provided, one frame is sampled at each whole second."),
    use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")
):
    """Detect nudity in sampled frames of an uploaded video.
    \f
    Args:
        video: Multipart upload whose declared content type starts with ``video/``.
        frame_interval: Sampling interval in seconds (see ``_scan_video_frames``);
            ``None`` samples one frame per whole second.
        use_640m: Use the 640m model instead of the default 320n model.

    Returns:
        JSON ``{"results": [...]}`` listing only frames with detections.

    Raises:
        HTTPException: 400 if the upload is not declared as a video.
    """
    # content_type is None when the client sends the part without a Content-Type.
    if not (video.content_type or "").startswith('video/'):
        raise HTTPException(status_code=400, detail="File provided is not a video")
    detector = detector_640m if use_640m else detector_320n
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    try:
        # Stream the upload to disk rather than holding it in memory, and close
        # the handle before moviepy reopens the path.
        with temp_video:
            await run_in_threadpool(shutil.copyfileobj, video.file, temp_video)
        # Decoding and inference are CPU-bound and can take a long time; run them
        # off the event loop so the server keeps serving other requests.
        results = await run_in_threadpool(_scan_video_frames, temp_video.name, frame_interval, detector)
    finally:
        # Always remove the temp file, including when decoding/detection raises.
        os.unlink(temp_video.name)
    return JSONResponse(content={'results': results})
if __name__ == '__main__':
    # Direct-run entry point: serve the app on all interfaces, using $PORT when
    # it is set and 7860 otherwise.
    import uvicorn

    listen_port = int(os.environ.get('PORT', 7860))
    uvicorn.run(app, host='0.0.0.0', port=listen_port)