# app.py — "aibuild" Space (commit 0cfe6dd, uploaded by webolavo)
# --- flash_attn Mock ---------------------------------------------------------
# Register a stub "flash_attn" package before transformers is imported,
# presumably so Florence-2's trust_remote_code modules can import it on a
# CPU-only host where flash_attn cannot be installed — TODO confirm.
# NOTE(review): the stub has no real functions; this relies on
# attn_implementation="eager" (set at model load) so it is never called.
import sys
import types
import importlib.util
# Fake top-level module with just enough metadata (__version__, __spec__)
# to satisfy import-time availability checks.
flash_mock = types.ModuleType("flash_attn")
flash_mock.__version__ = "2.0.0"
flash_mock.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)
sys.modules["flash_attn"] = flash_mock
# Empty submodules for any `from flash_attn.<sub> import ...` statements.
sys.modules["flash_attn.flash_attn_interface"] = types.ModuleType("flash_attn.flash_attn_interface")
sys.modules["flash_attn.bert_padding"] = types.ModuleType("flash_attn.bert_padding")
# -----------------------------------------------------------------------------
import io
import os
import time
import uuid
import threading
import subprocess
import cv2
import torch
from PIL import Image
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from starlette.background import BackgroundTask
from transformers import (
BlipProcessor,
BlipForQuestionAnswering,
AutoProcessor,
AutoModelForCausalLM,
)
BLIP_MODEL_ID = "Salesforce/blip-vqa-base"
FLORENCE_MODEL_ID = "microsoft/Florence-2-large-ft"
FRAMES_PER_SECOND = 1
TEMP_DIR = "/tmp/video_filter"
os.makedirs(TEMP_DIR, exist_ok=True)
BLIP_QUESTIONS = [
"is there a person in this image?",
"is there a woman in this image?",
"is there a human body part in this image?",
"is there a hand or arm visible?",
"is there a face visible?",
"is there a leg or foot visible?",
"is there a belly or stomach visible?",
]
FLORENCE_QUESTION = (
"Is there a woman or any part of a woman's body in this image? "
"Answer yes or no only."
)
MODEL_DATA = {}
MODEL_STATUS = {"status": "loading", "message": "ุฌุงุฑูŠ ุชุญู…ูŠู„ ุงู„ู†ู…ุงุฐุฌ..."}
JOB_OUTPUTS = {}
def load_models() -> None:
    """Populate MODEL_DATA with the BLIP and Florence-2 processors/models.

    Runs in a daemon thread (started by the lifespan handler); progress and
    failures are reported through MODEL_STATUS so /health can expose them.
    """
    try:
        print("Loading BLIP...", flush=True)
        MODEL_STATUS.update({"status": "loading", "message": "جاري تحميل BLIP..."})
        t0 = time.time()
        MODEL_DATA["blip_processor"] = BlipProcessor.from_pretrained(BLIP_MODEL_ID)
        blip = BlipForQuestionAnswering.from_pretrained(BLIP_MODEL_ID, torch_dtype=torch.float32)
        MODEL_DATA["blip_model"] = blip.eval()
        print(f"BLIP ready in {time.time() - t0:.1f}s", flush=True)
        print("Loading Florence-2...", flush=True)
        MODEL_STATUS.update({"status": "loading", "message": "جاري تحميل Florence-2..."})
        t0 = time.time()
        MODEL_DATA["florence_processor"] = AutoProcessor.from_pretrained(FLORENCE_MODEL_ID, trust_remote_code=True)
        florence = AutoModelForCausalLM.from_pretrained(
            FLORENCE_MODEL_ID,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            # Eager attention so the flash_attn stub registered at import
            # time is never actually exercised.
            attn_implementation="eager",
        )
        MODEL_DATA["florence_model"] = florence.eval()
        print(f"Florence-2 ready in {time.time() - t0:.1f}s", flush=True)
        MODEL_STATUS.update({"status": "ready", "message": "النماذج جاهزة"})
        print("All models loaded", flush=True)
    except Exception as exc:
        # Surface any load failure via /health instead of crashing the thread.
        MODEL_STATUS.update({"status": "error", "message": str(exc)})
        print(f"Error loading models: {exc}", flush=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: start model loading in the background; tidy up on exit."""
    loader = threading.Thread(target=load_models, daemon=True)
    loader.start()
    print("Server started, models are loading in background", flush=True)
    yield
    # Shutdown: drop model references and any tracked job outputs.
    MODEL_DATA.clear()
    JOB_OUTPUTS.clear()
app = FastAPI(
title="Video Female Filter",
description="ุชุญู„ูŠู„ ุงู„ููŠุฏูŠูˆ ูˆุฅุฒุงู„ุฉ ู…ู‚ุงุทุน ุงู„ู†ุณุงุก | BLIP + Florence-2",
version="1.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
def run_blip(image: Image.Image) -> dict:
    """Ask BLIP the stage-1 screening questions about one frame.

    Args:
        image: RGB PIL image of a single video frame.

    Returns:
        {"yes": {question: answer}, "no": {question: answer}} — each entry
        of BLIP_QUESTIONS bucketed by whether the generated answer starts
        with "yes".
    """
    processor = MODEL_DATA["blip_processor"]
    model = MODEL_DATA["blip_model"]
    yes_answers: dict[str, str] = {}
    no_answers: dict[str, str] = {}
    for question in BLIP_QUESTIONS:
        inputs = processor(image, question, return_tensors="pt")
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=5)
        answer = processor.decode(out[0], skip_special_tokens=True).strip().lower()
        # startswith("yes") already covers the exact "yes" case, so the
        # original's separate equality check was redundant.
        if answer.startswith("yes"):
            yes_answers[question] = answer
        else:
            no_answers[question] = answer
    return {"yes": yes_answers, "no": no_answers}
def run_florence(image: Image.Image) -> str:
    """Ask Florence-2 the yes/no female-presence question about one frame.

    Returns the parsed answer for the task token, lowercased and stripped
    ("" if the task key is missing from the parsed output).
    """
    processor = MODEL_DATA["florence_processor"]
    model = MODEL_DATA["florence_model"]
    # NOTE(review): "<VQA>" is not among Florence-2's documented task prompts
    # (CAPTION, OD, OCR, ...); the -ft checkpoint may accept it, but confirm
    # against the model card that this yields a usable yes/no answer.
    task = "<VQA>"
    prompt = f"{task}{FLORENCE_QUESTION}"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=10,
            do_sample=False,
        )
    # Special tokens are kept here, presumably because post_process_generation
    # expects the raw decoded string — confirm against processor docs.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    return parsed.get(task, "").strip().lower()
def is_female_in_frame(image: Image.Image) -> tuple[bool, str]:
    """Decide whether a frame shows a woman; returns (flag, reason tag)."""
    screening = run_blip(image)
    positives = screening["yes"]
    # BLIP answering "yes" to the woman question is decisive on its own.
    if "is there a woman in this image?" in positives:
        return True, "blip_woman"
    # No human-related "yes" at all: nothing to escalate.
    if not positives:
        return False, "blip_clean"
    # Ambiguous (person/body part seen): let Florence-2 confirm.
    verdict = run_florence(image)
    return (True, "florence_confirmed") if "yes" in verdict else (False, "florence_clean")
def run_ffmpeg_command(args: list[str]) -> None:
    """Run a command, raising RuntimeError with trimmed stderr on failure."""
    result = subprocess.run(args, capture_output=True, text=True)
    if result.returncode == 0:
        return
    tail = (result.stderr or "").strip()
    # Keep only the last 600 chars — ffmpeg's stderr can be huge and the
    # useful error is at the end.
    tail = tail[-600:] if len(tail) > 600 else tail
    raise RuntimeError(f"ffmpeg failed (exit={result.returncode}): {tail}")
def merge_overlapping_segments(segments: list[list[float]], duration_sec: float) -> list[list[float]]:
    """Clamp segments to [0, duration_sec], drop empty ones, merge overlaps.

    Returns a sorted list of disjoint [start, end] pairs (touching segments
    are merged too).
    """
    bounded = []
    for seg_start, seg_end in segments:
        seg_start = max(0.0, min(seg_start, duration_sec))
        seg_end = max(0.0, min(seg_end, duration_sec))
        if seg_end > seg_start:
            bounded.append([seg_start, seg_end])
    if not bounded:
        return []
    bounded.sort(key=lambda seg: seg[0])
    merged: list[list[float]] = []
    for seg_start, seg_end in bounded:
        if merged and seg_start <= merged[-1][1]:
            # Overlaps (or touches) the previous segment: extend it in place.
            merged[-1][1] = max(merged[-1][1], seg_end)
        else:
            merged.append([seg_start, seg_end])
    return merged
def cleanup_files(paths: list[str]) -> None:
    """Best-effort removal of the given files; empty/missing paths are ignored."""
    for path in paths:
        try:
            if path and os.path.exists(path):
                os.remove(path)
        except Exception:
            # Deliberately swallowed: cleanup must never raise.
            pass
def cleanup_job_output(job_id: str) -> None:
    """Forget a finished job and delete its output file (runs post-download)."""
    path = JOB_OUTPUTS.pop(job_id, None)
    if path:
        cleanup_files([path])
def build_clean_video(
    input_path: str,
    output_path: str,
    keep_segments: list[list[float]],
    job_id: str,
) -> bool:
    """Re-encode the kept [start, end] ranges of input_path and join them.

    Each range is cut and re-encoded into its own segment file; the
    non-empty segments are then concatenated with ffmpeg's concat demuxer
    into output_path.  All intermediate files are removed whatever happens.

    Returns True iff output_path exists and is non-empty afterwards.
    """
    # Encoding settings shared by the per-segment cuts and the final concat.
    encode_opts = [
        "-c:v", "libx264",
        "-preset", "veryfast",
        "-crf", "23",
        "-pix_fmt", "yuv420p",
        "-c:a", "aac",
        "-b:a", "128k",
        "-movflags", "+faststart",
    ]
    usable_segments = []
    scratch = []
    try:
        for idx, (seg_start, seg_end) in enumerate(keep_segments):
            seg_path = f"{TEMP_DIR}/{job_id}_seg_{idx}.mp4"
            scratch.append(seg_path)
            cut_cmd = [
                "ffmpeg", "-y",
                "-ss", f"{seg_start:.3f}",
                "-to", f"{seg_end:.3f}",
                "-i", input_path,
                # '?' makes each map optional so audio-less input still works.
                "-map", "0:v:0?",
                "-map", "0:a:0?",
            ] + encode_opts + [seg_path]
            run_ffmpeg_command(cut_cmd)
            # Only keep segments ffmpeg actually produced with content.
            if os.path.exists(seg_path) and os.path.getsize(seg_path) > 0:
                usable_segments.append(seg_path)
        if not usable_segments:
            return False
        list_path = f"{TEMP_DIR}/{job_id}_list.txt"
        scratch.append(list_path)
        with open(list_path, "w", encoding="utf-8") as handle:
            handle.writelines(f"file '{seg}'\n" for seg in usable_segments)
        concat_cmd = [
            "ffmpeg", "-y",
            "-f", "concat",
            "-safe", "0",
            "-i", list_path,
        ] + encode_opts + [output_path]
        run_ffmpeg_command(concat_cmd)
        return os.path.exists(output_path) and os.path.getsize(output_path) > 0
    finally:
        cleanup_files(scratch)
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the bundled single-page front-end."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
@app.get("/health")
def health():
    """Report model-loading status plus per-model readiness flags."""
    report = {
        "status": MODEL_STATUS["status"],
        "message": MODEL_STATUS["message"],
    }
    report["blip_loaded"] = "blip_model" in MODEL_DATA
    report["florence_loaded"] = "florence_model" in MODEL_DATA
    return report
@app.post("/analyze-file")
async def analyze_file(file: UploadFile = File(...)):
    """Classify a single uploaded image.

    Returns:
        JSON verdict {has_female, decision, reason, status}.

    Raises:
        HTTPException: 503 while the models are still loading, 400 for a
        non-image upload, 500 on analysis failure.
    """
    if MODEL_STATUS["status"] != "ready":
        raise HTTPException(
            status_code=503,
            detail=f"النماذج لم تكتمل بعد: {MODEL_STATUS['message']}",
        )
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="الملف ليس صورة")
    try:
        image_bytes = await file.read()
        # convert("RGB") normalizes palette/alpha images to 3 channels.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        has_female, reason = is_female_in_frame(image)
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "has_female": has_female,
        "decision": "BLOCK" if has_female else "ALLOW",
        "reason": reason,
        "status": "success",
    }
@app.post("/analyze-video")
async def analyze_video(file: UploadFile = File(...)):
    """Analyze an uploaded video and cut out segments containing women.

    Samples roughly FRAMES_PER_SECOND frames, classifies each with
    is_female_in_frame, pads flagged spans by 0.5 s on each side, then
    re-encodes the remaining ranges into a downloadable output file.

    Returns:
        JSON report with flagged/kept segments, per-frame analysis log and,
        when an output was produced, a /download/{job_id} URL.

    Raises:
        HTTPException: 503 while models load, 400 for non-video uploads or
        unreadable files, 500 on processing/encoding failure.
    """
    if MODEL_STATUS["status"] != "ready":
        raise HTTPException(
            status_code=503,
            detail=f"النماذج لم تكتمل بعد: {MODEL_STATUS['message']}",
        )
    if not file.content_type or not file.content_type.startswith("video/"):
        raise HTTPException(status_code=400, detail="الملف ليس فيديو")
    job_id = str(uuid.uuid4())[:8]
    input_path = f"{TEMP_DIR}/{job_id}_input.mp4"
    output_path = f"{TEMP_DIR}/{job_id}_output.mp4"
    # Stream the upload to disk in 1 MiB chunks to bound memory use.
    with open(input_path, "wb") as f:
        while chunk := await file.read(1024 * 1024):
            f.write(chunk)
    try:
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="تعذر فتح الفيديو")
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        if fps <= 0:
            fps = 25  # guard against bogus (zero/negative) fps metadata
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # CAP_PROP_FRAME_COUNT is container metadata and may be 0/absent;
        # in that case the duration is recomputed after decoding (below).
        duration_sec = total_frames / fps if total_frames > 0 else 0.0
        print(f"Video info: {total_frames} frames, {fps:.2f} fps", flush=True)
        frame_interval = max(1, int(fps / FRAMES_PER_SECOND))
        female_segments = []
        analysis_log = []
        in_female_seg = False
        seg_start = 0.0
        frame_idx = 0
        start_time = time.time()
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % frame_interval == 0:
                    current_sec = frame_idx / fps
                    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    has_female, reason = is_female_in_frame(pil_image)
                    analysis_log.append(
                        {
                            "second": round(current_sec, 2),
                            "has_female": has_female,
                            "reason": reason,
                        }
                    )
                    # Open/close a flagged span with 0.5 s padding each side.
                    if has_female and not in_female_seg:
                        in_female_seg = True
                        seg_start = max(0.0, current_sec - 0.5)
                    elif not has_female and in_female_seg:
                        in_female_seg = False
                        female_segments.append([seg_start, current_sec + 0.5])
                frame_idx += 1
        finally:
            cap.release()
        # Fix: when frame-count metadata was missing, duration_sec stayed 0.0
        # and merge_overlapping_segments clipped every segment away. Derive
        # the duration from the frames actually decoded instead.
        if duration_sec <= 0:
            duration_sec = frame_idx / fps
        if in_female_seg:
            female_segments.append([seg_start, duration_sec])
        female_segments = merge_overlapping_segments(female_segments, duration_sec)
        elapsed_analysis = round(time.time() - start_time, 2)
        if not female_segments:
            return {
                "has_female": False,
                "female_segments": [],
                "analysis_log": analysis_log,
                "message": "✅ الفيديو نظيف لا يحتوي على نساء",
                "analysis_time": elapsed_analysis,
                "output_available": False,
                "status": "success",
            }
        # Invert the flagged ranges into the ranges to keep.
        keep_segments = []
        prev_end = 0.0
        for s, e in female_segments:
            if prev_end < s:
                keep_segments.append([prev_end, s])
            prev_end = e
        if prev_end < duration_sec:
            keep_segments.append([prev_end, duration_sec])
        if not keep_segments:
            return {
                "has_female": True,
                "female_segments": female_segments,
                "analysis_log": analysis_log,
                "message": "⚠️ الفيديو كله يحتوي على نساء",
                "analysis_time": elapsed_analysis,
                "output_available": False,
                "status": "success",
            }
        output_ok = build_clean_video(input_path, output_path, keep_segments, job_id)
        total_removed = sum(e - s for s, e in female_segments)
        if not output_ok:
            # Fix: this case previously fell through and returned None
            # (HTTP 200 with a null body); fail loudly instead. The except
            # handler below also removes any partial output file.
            raise HTTPException(status_code=500, detail="تعذر إنشاء الفيديو الناتج")
        JOB_OUTPUTS[job_id] = output_path
        return {
            "has_female": True,
            "female_segments": female_segments,
            "kept_segments": keep_segments,
            "total_removed_sec": round(total_removed, 2),
            "analysis_log": analysis_log,
            "analysis_time": elapsed_analysis,
            "output_available": True,
            "output_job_id": job_id,
            "download_url": f"/download/{job_id}",
            "message": f"✅ تم حذف {round(total_removed, 1)} ثانية من الفيديو",
            "status": "success",
        }
    except HTTPException:
        cleanup_files([output_path])
        raise
    except Exception as e:
        cleanup_files([output_path])
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        cleanup_files([input_path])
@app.get("/download/{job_id}")
def download_video(job_id: str):
    """Stream a cleaned video produced by /analyze-video.

    The output file is deleted by a background task after the response is
    sent, so each result can be downloaded once.

    Raises:
        HTTPException: 404 for unknown/invalid job ids.
    """
    # Security: job_id comes URL-decoded from the client and was previously
    # interpolated into a filesystem path unchecked. Generated ids are short
    # uuid4 hex prefixes, so rejecting anything non-alphanumeric blocks path
    # traversal without affecting legitimate callers.
    if not job_id.isalnum():
        raise HTTPException(status_code=404, detail="الفيديو غير موجود")
    output_path = JOB_OUTPUTS.get(job_id, f"{TEMP_DIR}/{job_id}_output.mp4")
    if not os.path.exists(output_path):
        raise HTTPException(status_code=404, detail="الفيديو غير موجود")
    return FileResponse(
        output_path,
        media_type="video/mp4",
        filename="clean_video.mp4",
        background=BackgroundTask(cleanup_job_output, job_id),
    )
if __name__ == "__main__":
    # Local/dev entry point; port 7860 is the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)