File size: 4,294 Bytes
794a185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import base64
import math
import os
import tempfile
from io import BytesIO
from typing import Optional

import cv2
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from moviepy.editor import VideoFileClip
from nudenet import NudeDetector
from PIL import Image
from pydantic import BaseModel

# Application object. Swagger UI is served at the root path ("/") and ReDoc
# at "/redoc".
app = FastAPI(
    title="Nudenet API",
    version="1.0.0",
    description="API for detecting nudity in images and videos using Nudenet",
    redoc_url="/redoc",
    docs_url="/",
)

# Both detectors are constructed once, when this module is imported, and are
# shared by every endpoint; each request picks one via its ``use_640m`` flag.
#   detector_320n: NudeDetector with no arguments (the endpoints describe it as
#                  the default "320n" model).
#   detector_640m: the "640m" model loaded from "640m.onnx" with
#                  ``inference_resolution=640``. The path is relative —
#                  presumably resolved against the working directory; confirm
#                  the deployment layout.
detector_320n, detector_640m = (
    NudeDetector(),
    NudeDetector(model_path="640m.onnx", inference_resolution=640),
)

# JSON request body for POST /detect_base64: a single "image" field carrying
# the image as base64 text.
class Base64Image(BaseModel):
    image: str  # base64-encoded image; decoded with base64.b64decode by the endpoint

@app.post("/detect", summary="Detect nudity in an image")
async def detect_nudity(image: UploadFile = File(...), use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Detect nudity in an uploaded image file.
    \f
    The upload is decoded with Pillow, converted to RGB and re-encoded as
    JPEG bytes before being passed to the selected detector.

    Args:
        image: Multipart upload; its declared content type must be ``image/*``.
        use_640m: Use ``detector_640m`` instead of ``detector_320n``.

    Returns:
        JSONResponse of the form ``{"result": <detector output>}``.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*`` or
            Pillow cannot decode it.
    """
    # content_type may be None when the client omits the part's Content-Type;
    # treat that as "not an image" instead of crashing with a 500.
    if not (image.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File provided is not an image")

    contents = await image.read()
    detector = detector_640m if use_640m else detector_320n
    # Decoding, re-encoding and inference are CPU-bound and synchronous; run
    # them in the threadpool so the event loop keeps serving other requests.
    result = await run_in_threadpool(_detect_in_image_bytes, contents, detector)

    return JSONResponse(content={'result': result})


def _detect_in_image_bytes(contents: bytes, detector: NudeDetector):
    """Normalize raw image bytes to an RGB JPEG and run ``detector`` on it.

    Raises:
        HTTPException: 400 if Pillow cannot identify, decode or re-encode
            the image.
    """
    try:
        with Image.open(BytesIO(contents)) as img:
            # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so
            # convert every non-RGB mode rather than only RGBA; otherwise the
            # save below fails with "cannot write mode X as JPEG".
            rgb = img if img.mode == 'RGB' else img.convert('RGB')
            buffer = BytesIO()
            rgb.save(buffer, format='JPEG')
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-file errors subclass OSError.
        raise HTTPException(status_code=400, detail="File provided is not a valid image") from err

    return detector.detect(buffer.getvalue())

@app.post("/detect_base64", summary="Detect nudity in a base64 encoded image")
async def detect_nudity_base64(data: Base64Image, use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")):
    """Detect nudity in a base64-encoded image sent as JSON.

    The ``image`` field may be plain base64 or a data URL
    (``data:image/png;base64,...``).
    \f
    Args:
        data: Request body; ``data.image`` holds the base64 text.
        use_640m: Use ``detector_640m`` instead of ``detector_320n``.

    Returns:
        JSONResponse of the form ``{"result": <detector output>}``.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or decodes to
            no bytes.
    """
    encoded = data.image
    # Data URLs look like "data:<mime>;base64,<payload>". b64decode (with its
    # default validate=False) discards ':' ';' ',' but keeps the letters of the
    # prefix, which corrupts the decoded bytes, so strip the prefix first.
    # Plain base64 never starts with "data:" because ':' is not in its alphabet.
    if encoded.startswith('data:') and ',' in encoded:
        encoded = encoded.split(',', 1)[1]

    try:
        image_data = base64.b64decode(encoded)
    except ValueError as err:
        # binascii.Error (bad padding) subclasses ValueError, as does the error
        # for non-ASCII input. Catch only these rather than using a bare except.
        raise HTTPException(status_code=400, detail="Invalid base64 string") from err
    if not image_data:
        raise HTTPException(status_code=400, detail="Decoded image is empty")

    detector = detector_640m if use_640m else detector_320n
    # Inference is CPU-bound and synchronous; keep it off the event loop.
    result = await run_in_threadpool(detector.detect, image_data)

    return JSONResponse(content={'result': result})

@app.post("/detect_video", summary="Detect nudity in a video")
async def detect_nudity_video(
    video: UploadFile = File(...),
    frame_interval: Optional[int] = Query(default=None, description="Interval in seconds between analyzed frames. 0 analyzes every frame. If not provided, one frame per second is analyzed."),
    use_640m: bool = Query(False, description="Use the 640m model instead of the default 320n model")
):
    """Detect nudity in sampled frames of an uploaded video.

    Without ``frame_interval`` one frame per second of video is analyzed;
    with it, one frame every ``frame_interval`` seconds (``0`` analyzes every
    frame). Only frames with at least one detection are returned.
    \f
    Args:
        video: Multipart upload; its declared content type must be ``video/*``.
        frame_interval: Sampling interval in seconds, or ``None`` for the
            one-frame-per-second default. Negative values behave like ``0``.
        use_640m: Use ``detector_640m`` instead of ``detector_320n``.

    Returns:
        JSONResponse ``{"results": [...]}``. Each entry has ``timestamp``
        (seconds) and ``detections``; entries also carry ``frame_number``
        when ``frame_interval`` is given.

    Raises:
        HTTPException: 400 if the upload is not declared as ``video/*`` or
            MoviePy cannot open it.
    """
    # content_type may be None when the client omits the part's Content-Type;
    # treat that as "not a video" instead of crashing with a 500.
    if not (video.content_type or "").startswith('video/'):
        raise HTTPException(status_code=400, detail="File provided is not a video")

    detector = detector_640m if use_640m else detector_320n
    temp_video_path = None
    try:
        # VideoFileClip is given a file path, so spool the upload to disk in
        # 1 MiB chunks instead of holding the whole video in memory.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
            temp_video_path = temp_video.name
            while True:
                chunk = await video.read(1024 * 1024)
                if not chunk:
                    break
                temp_video.write(chunk)

        # Decoding and inference are CPU-bound and synchronous; run them in the
        # threadpool so a long video does not block the event loop.
        results = await run_in_threadpool(
            _analyze_video, temp_video_path, frame_interval, detector
        )
    finally:
        # Remove the temp file on every path, including failures while copying,
        # decoding or detecting.
        if temp_video_path is not None:
            try:
                os.unlink(temp_video_path)
            except OSError:
                pass  # best-effort cleanup; must not mask the request's outcome

    return JSONResponse(content={'results': results})


def _rgb_to_bgr(frame):
    """Swap the channel order of an RGB frame to BGR (MoviePy yields RGB).

    NudeDetector appears to expect cv2-style BGR arrays; confirm if the
    nudenet version changes. COLOR_RGB2BGR has the same value as the
    COLOR_BGR2RGB flag previously used here, so the pixels are identical.
    """
    return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)


def _analyze_video(video_path, frame_interval, detector):
    """Run ``detector`` over sampled frames of the video at ``video_path``.

    Synchronous helper for ``detect_nudity_video``; it is called through the
    threadpool.

    Args:
        video_path: Path of a video file readable by MoviePy.
        frame_interval: Seconds between sampled frames; ``0`` (or a negative
            value) samples every frame, ``None`` samples t = 0, 1, 2, ... s.
        detector: The NudeDetector to run on each sampled frame.

    Returns:
        List of result dicts for sampled frames with at least one detection.

    Raises:
        HTTPException: 400 if MoviePy raises OSError when opening the file.
    """
    try:
        video_clip = VideoFileClip(video_path)
    except OSError as err:
        raise HTTPException(status_code=400, detail="Could not read the provided video") from err

    try:
        fps = video_clip.fps
        results = []
        if frame_interval is None:
            # One frame per whole second. ceil() (rather than int()) keeps the
            # trailing partial second and gives clips shorter than 1 s a frame
            # at t=0 instead of no frames at all.
            for t in range(math.ceil(video_clip.duration)):
                detections = detector.detect(_rgb_to_bgr(video_clip.get_frame(t)))
                if detections:
                    results.append({
                        'timestamp': t,
                        'detections': detections
                    })
        else:
            # Frames between samples, computed once. Rounding (not truncating)
            # keeps e.g. 29.97 fps at 30 frames per second instead of 29.
            step = max(1, round(frame_interval * fps))
            for frame_number, frame in enumerate(video_clip.iter_frames()):
                if frame_number % step:
                    continue
                detections = detector.detect(_rgb_to_bgr(frame))
                if detections:
                    results.append({
                        'timestamp': frame_number / fps,
                        'frame_number': frame_number,
                        'detections': detections
                    })
        return results
    finally:
        # Close the clip so MoviePy releases its readers for the file.
        video_clip.close()

if __name__ == '__main__':
    # Script entry point: serve the app with uvicorn on all interfaces, on the
    # port from $PORT when set and 7860 otherwise.
    import uvicorn

    listen_port = int(os.environ.get('PORT', 7860))
    uvicorn.run(app, port=listen_port, host='0.0.0.0')