Spaces:
Running
Running
| import os | |
| import base64 | |
| import tempfile | |
| from io import BytesIO | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Query | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from nudenet import NudeDetector | |
| import cv2 | |
| from moviepy.editor import VideoFileClip | |
| from typing import Optional | |
# ASGI application object. The interactive Swagger UI is mounted at the site
# root ("/") and the ReDoc rendering at "/redoc".
app = FastAPI(
    title="Nudenet API",
    description="API for detecting nudity in images and videos using Nudenet",
    version="1.0.0",
    docs_url="/",
    redoc_url="/redoc",
)
# Build both detectors once at import time so every request reuses them:
# one with the library defaults and one backed by the 640m ONNX model.
# NOTE(review): "640m.onnx" is a relative path, so it is presumably resolved
# against the process working directory — confirm against the deployment.
detector_320n = NudeDetector()
detector_640m = NudeDetector(
    model_path="640m.onnx",
    inference_resolution=640,
)
class Base64Image(BaseModel):
    # Request body for the base64 handler: `image` holds the base64-encoded
    # bytes of the image to analyse.
    image: str
async def detect_nudity(image: UploadFile = File(...), use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Run nudity detection on an uploaded image.

    The upload is decoded with Pillow, normalised to RGB, re-encoded as JPEG
    bytes and handed to the selected detector.

    Args:
        image: Uploaded file whose declared content type must start with
            ``image/``.
        use_640m: When true, use ``detector_640m`` instead of
            ``detector_320n``.

    Returns:
        A ``JSONResponse`` with body ``{'result': <detector output>}``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or its
            contents cannot be decoded as one.
    """
    # NOTE(review): no @app.<method>(...) route decorator is visible above this
    # handler in the reviewed source — confirm it is registered with `app`.

    # content_type is None when the client omits the part's Content-Type, so
    # fall back to an empty string instead of failing with AttributeError.
    if not (image.content_type or '').startswith('image/'):
        raise HTTPException(status_code=400, detail="File provided is not an image")

    contents = await image.read()
    try:
        img = Image.open(BytesIO(contents))
        # JPEG cannot store alpha or palette data; normalise every non-RGB
        # mode (RGBA, LA, P, ...) rather than RGBA alone so the save below
        # does not fail on such uploads.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        jpeg_buffer = BytesIO()
        img.save(jpeg_buffer, format='JPEG')
    except OSError:
        # Pillow reports unidentified or truncated image data as OSError
        # subclasses; that is a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image",
        ) from None

    detector = detector_640m if use_640m else detector_320n
    result = detector.detect(jpeg_buffer.getvalue())
    return JSONResponse(content={'result': result})
async def detect_nudity_base64(data: Base64Image, use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Run nudity detection on a base64-encoded image.

    Args:
        data: Request body whose ``image`` field is the base64 text of the
            image bytes.
        use_640m: When true, use ``detector_640m`` instead of
            ``detector_320n``.

    Returns:
        A ``JSONResponse`` with body ``{'result': <detector output>}``.

    Raises:
        HTTPException: 400 if ``data.image`` is not valid base64 or decodes
            to no data at all.
    """
    # NOTE(review): no @app.<method>(...) route decorator is visible above this
    # handler in the reviewed source — confirm it is registered with `app`.
    try:
        image_data = base64.b64decode(data.image)
    except ValueError:
        # binascii.Error (bad padding or characters) subclasses ValueError.
        # Catching only that keeps unrelated failures from being reported
        # as a malformed request, which the previous bare `except:` did.
        raise HTTPException(status_code=400, detail="Invalid base64 string") from None

    # An empty string is valid base64 but cannot contain an image.
    if not image_data:
        raise HTTPException(status_code=400, detail="Base64 string decodes to empty image data")

    detector = detector_640m if use_640m else detector_320n
    result = detector.detect(image_data)
    return JSONResponse(content={'result': result})
async def detect_nudity_video(
    video: UploadFile = File(...),
    frame_interval: Optional[int] = Query(default=None, description="Interval in seconds between analyzed frames. 0 analyzes every frame, 1 analyzes one frame per second, 2 analyzes one frame every two seconds, etc. If not provided, one frame per second is analyzed."),
    use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")
):
    """Run nudity detection on sampled frames of an uploaded video.

    The upload is spooled to a temporary file, opened with moviepy and
    sampled; only frames with at least one detection are reported.

    Args:
        video: Uploaded file whose declared content type must start with
            ``video/``.
        frame_interval: Sampling interval in seconds. ``None`` samples one
            frame per whole second of the clip (at least one frame);
            ``0`` or a negative value samples every frame.
        use_640m: When true, use ``detector_640m`` instead of
            ``detector_320n``.

    Returns:
        A ``JSONResponse`` with body ``{'results': [...]}``. Each entry has
        ``timestamp`` and ``detections``; entries produced with an explicit
        ``frame_interval`` also carry ``frame_number``.

    Raises:
        HTTPException: 400 if the upload is not declared as a video.
    """
    # NOTE(review): no @app.<method>(...) route decorator is visible above this
    # handler in the reviewed source — confirm it is registered with `app`.
    # NOTE(review): decoding and inference run synchronously inside this
    # coroutine, so the event loop is occupied for the whole analysis.

    # content_type is None when the client omits the part's Content-Type.
    if not (video.content_type or '').startswith('video/'):
        raise HTTPException(status_code=400, detail="File provided is not a video")

    detector = detector_640m if use_640m else detector_320n
    results = []
    chunk_size = 1024 * 1024
    temp_video_path = None
    video_clip = None
    try:
        # Spool the upload to disk in chunks instead of holding the whole
        # video in memory at once.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
            temp_video_path = temp_video.name
            while True:
                chunk = await video.read(chunk_size)
                if not chunk:
                    break
                temp_video.write(chunk)

        video_clip = VideoFileClip(temp_video_path)
        fps = video_clip.fps
        duration = video_clip.duration

        if frame_interval is None:
            # Default behavior: analyze one frame per second. Clips shorter
            # than one second still get their first frame analyzed.
            for t in range(max(1, int(duration))):
                frame = video_clip.get_frame(t)
                # moviepy yields RGB frames; swap to BGR channel order, which
                # the detector presumably expects for arrays (OpenCV
                # convention) — TODO confirm against nudenet.
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                detections = detector.detect(frame_bgr)
                if detections:
                    results.append({
                        'timestamp': t,
                        'detections': detections
                    })
        else:
            # Custom interval behavior: convert seconds to a frame step once.
            # A step of 1 (interval <= 0 or a very low fps) means every frame.
            frame_step = max(1, int(frame_interval * fps))
            for frame_number, frame in enumerate(video_clip.iter_frames()):
                if frame_number % frame_step != 0:
                    continue
                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                detections = detector.detect(frame_bgr)
                if detections:
                    results.append({
                        'timestamp': frame_number / fps,
                        'frame_number': frame_number,
                        'detections': detections
                    })
    finally:
        # Always release the clip's reader and delete the temporary file,
        # including when decoding or inference raises part-way through.
        try:
            if video_clip is not None:
                video_clip.close()
        finally:
            if temp_video_path is not None:
                os.unlink(temp_video_path)

    return JSONResponse(content={'results': results})
if __name__ == '__main__':
    # Script entry point: serve the app with uvicorn on all interfaces, using
    # the PORT environment variable when set and 7860 otherwise.
    import uvicorn

    port = int(os.environ.get('PORT', 7860))
    uvicorn.run(app, host='0.0.0.0', port=port)