# NudeNet detection API (Hugging Face Space; "Spaces: Running" page-header residue removed)
import base64
import os
import tempfile
from io import BytesIO
from typing import Optional

import cv2
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from moviepy.editor import VideoFileClip
from nudenet import NudeDetector
from PIL import Image
from pydantic import BaseModel
# FastAPI application; interactive Swagger UI is served at the root path ("/").
app = FastAPI(
    title="Nudenet API",
    description="API for detecting nudity in images and videos using Nudenet",
    version="1.0.0",
    docs_url="/",
    redoc_url="/redoc"
)
# Initialize NudeDetector with both models:
# - detector_320n: library default (presumably the 320n model, per its name — confirm)
# - detector_640m: local ONNX weights, run at 640px inference resolution
detector_320n = NudeDetector()
detector_640m = NudeDetector(model_path="640m.onnx", inference_resolution=640)
class Base64Image(BaseModel):
    """Request body carrying a single base64-encoded image string."""
    image: str
# NOTE(review): this handler was defined but never registered on `app`
# (decorator apparently lost in extraction); path assumed — confirm.
@app.post("/detect")
async def detect_nudity(image: UploadFile = File(...), use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Detect nudity in an uploaded image.

    Args:
        image: Uploaded file; must carry an ``image/*`` content type.
        use_640m: When true, run the higher-resolution 640m model.

    Returns:
        JSONResponse with a ``result`` key holding the raw detector output.

    Raises:
        HTTPException: 400 when the upload is not an image.
    """
    # content_type can be None on malformed uploads; guard before startswith.
    if not (image.content_type or '').startswith('image/'):
        raise HTTPException(status_code=400, detail="File provided is not an image")
    contents = await image.read()
    img = Image.open(BytesIO(contents))
    # JPEG cannot encode alpha/palette modes; normalize ALL non-RGB modes
    # (RGBA, P, LA, CMYK, ...) — the original converted only RGBA and crashed
    # on e.g. palette PNGs with "cannot write mode P as JPEG".
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img_byte_arr = BytesIO()
    img.save(img_byte_arr, format='JPEG')
    jpeg_bytes = img_byte_arr.getvalue()
    detector = detector_640m if use_640m else detector_320n
    result = detector.detect(jpeg_bytes)
    return JSONResponse(content={'result': result})
# NOTE(review): this handler was defined but never registered on `app`
# (decorator apparently lost in extraction); path assumed — confirm.
@app.post("/detect_base64")
async def detect_nudity_base64(data: Base64Image, use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Detect nudity in a base64-encoded image.

    Args:
        data: Request body with the base64-encoded image in ``data.image``.
        use_640m: When true, run the higher-resolution 640m model.

    Returns:
        JSONResponse with a ``result`` key holding the raw detector output.

    Raises:
        HTTPException: 400 when the payload is not valid base64.
    """
    try:
        image_data = base64.b64decode(data.image)
    except ValueError:
        # b64decode raises binascii.Error (a ValueError subclass) on bad
        # padding/characters; the original bare `except:` swallowed everything,
        # including KeyboardInterrupt/SystemExit.
        raise HTTPException(status_code=400, detail="Invalid base64 string")
    detector = detector_640m if use_640m else detector_320n
    result = detector.detect(image_data)
    return JSONResponse(content={'result': result})
# NOTE(review): this handler was defined but never registered on `app`
# (decorator apparently lost in extraction); path assumed — confirm.
@app.post("/detect_video")
async def detect_nudity_video(
    video: UploadFile = File(...),
    frame_interval: Optional[int] = Query(default=None, description="Seconds between analyzed frames. 0 analyzes every frame. If not provided, defaults to 1 second."),
    use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")
):
    """Detect nudity in an uploaded video, sampling frames at an interval.

    Args:
        video: Uploaded file; must carry a ``video/*`` content type.
        frame_interval: Sampling interval in seconds (0 = every frame;
            None = one frame per second). The original description claimed
            "2 for every other frame", but the code treats the value as
            seconds (``frame_interval * fps`` frames between samples).
        use_640m: When true, run the higher-resolution 640m model.

    Returns:
        JSONResponse with a ``results`` list of per-frame detections
        (only frames that produced detections are included).

    Raises:
        HTTPException: 400 when the upload is not a video.
    """
    # content_type can be None on malformed uploads; guard before startswith.
    if not (video.content_type or '').startswith('video/'):
        raise HTTPException(status_code=400, detail="File provided is not a video")
    # Spool the upload to disk: moviepy needs a real file path.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
        contents = await video.read()
        temp_video.write(contents)
        temp_video_path = temp_video.name
    video_clip = None
    try:
        video_clip = VideoFileClip(temp_video_path)
        fps = video_clip.fps
        duration = video_clip.duration
        detector = detector_640m if use_640m else detector_320n
        results = []
        if frame_interval is None:
            # Default behavior: analyze one frame per second.
            for t in range(int(duration)):
                frame = video_clip.get_frame(t)
                # moviepy yields RGB frames; reverse channels for the cv2/ONNX
                # pipeline. (COLOR_RGB2BGR and the original COLOR_BGR2RGB are
                # the same channel reversal — behavior is unchanged.)
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                detections = detector.detect(frame_bgr)
                if detections:
                    results.append({
                        'timestamp': t,
                        'detections': detections
                    })
        else:
            # Custom interval: sample every `frame_interval` seconds of frames
            # (every frame when the interval rounds to 0).
            for frame_number, frame in enumerate(video_clip.iter_frames()):
                if frame_interval == 0 or frame_number % max(1, int(frame_interval * fps)) == 0:
                    timestamp = frame_number / fps
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    detections = detector.detect(frame_bgr)
                    if detections:
                        results.append({
                            'timestamp': timestamp,
                            'frame_number': frame_number,
                            'detections': detections
                        })
    finally:
        # Original leaked both the clip handle and the temp file whenever
        # decoding/detection raised; always release them.
        if video_clip is not None:
            video_clip.close()
        os.unlink(temp_video_path)
    return JSONResponse(content={'results': results})
if __name__ == '__main__':
    import uvicorn

    # Default to 7860 (the Hugging Face Spaces convention) unless PORT is set.
    listen_port = int(os.environ.get('PORT', 7860))
    uvicorn.run(app, host='0.0.0.0', port=listen_port)