"""FastAPI service that classifies a video as AI-generated or human-made.

Downloads a video from a URL, samples every 30th frame, runs each sampled
frame through a HuggingFace image-classification model, and returns the
majority-vote label together with the average softmax confidence.
"""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import tempfile
import os
import cv2
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ----------------- CONFIG -----------------
MODEL_NAME = "umm-maybe/AI-image-detector"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only — make the mode explicit

# ----------------- INIT -----------------
app = FastAPI()

# Allow the React frontend (Vite on 5173, CRA on 3000) to call the backend.
# NOTE: the CORS spec forbids a wildcard "*" origin when
# allow_credentials=True — browsers reject that combination — so the dev
# origins are listed explicitly instead of mixing "*" in.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ----------------- MODELS -----------------
class VideoRequest(BaseModel):
    # url: location of the video to analyze; id: caller-supplied identifier
    # (currently unused server-side but kept for interface compatibility).
    url: str
    id: str


# ----------------- HELPERS -----------------
def predict_frame(frame):
    """Classify a single video frame.

    Args:
        frame: an OpenCV frame (numpy ndarray, BGR channel order).

    Returns:
        (pred_label, confidence): the argmax class index and its softmax
        probability for this frame.
    """
    # OpenCV decodes to BGR; the image processor expects RGB.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits.softmax(dim=1)
    pred_label = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred_label].item()
    return pred_label, confidence


def process_video(video_url):
    """Download a video, classify sampled frames, and aggregate a verdict.

    Args:
        video_url: HTTP(S) URL of the video to analyze.

    Returns:
        dict with keys "label", "confidence", "frames_checked" on success,
        or {"error": ...} when no frames could be decoded.

    Raises:
        requests.RequestException: if the download fails or times out.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    try:
        # Stream the download in 1 MiB chunks so large videos never have to
        # fit in memory; fail fast on HTTP errors and hung connections
        # instead of silently decoding an error page as video.
        with requests.get(video_url, stream=True, timeout=30) as resp:
            resp.raise_for_status()
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    tmp_file.write(chunk)
        tmp_file.close()

        cap = cv2.VideoCapture(tmp_file.name)
        frame_count = 0
        results = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % 30 == 0:  # sample every 30th frame
                    results.append(predict_frame(frame))
                frame_count += 1
        finally:
            cap.release()
    finally:
        # Always clean up the temp file, even when download/decoding failed.
        tmp_file.close()
        if os.path.exists(tmp_file.name):
            os.remove(tmp_file.name)

    if not results:
        return {"error": "No frames processed"}

    labels = [label for label, _ in results]
    confidences = [conf for _, conf in results]
    final_label = max(set(labels), key=labels.count)  # majority vote
    avg_conf = sum(confidences) / len(confidences)
    # NOTE(review): assumes class index 1 means "AI-generated" for this
    # checkpoint — confirm against model.config.id2label.
    return {
        "label": "AI-generated" if final_label == 1 else "Human",
        "confidence": float(avg_conf),
        "frames_checked": len(results),
    }


# ----------------- ROUTES -----------------
@app.post("/detect")
async def detect_video(request: VideoRequest):
    """Classify the video at ``request.url`` and return the aggregated result."""
    result = process_video(request.url)
    return result


# ----------------- ENTRY POINT -----------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)