Spaces:
Paused
Paused
Upload 3 files
Browse files- Dockerfile +7 -3
- processor.py +54 -27
- requirements.txt +10 -9
Dockerfile
CHANGED
|
@@ -11,12 +11,16 @@ RUN apt-get update && apt-get install -y \
|
|
| 11 |
COPY requirements.txt .
|
| 12 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
COPY . .
|
| 15 |
|
| 16 |
-
# Create temp directory
|
| 17 |
RUN mkdir -p /tmp/video-bg-remover
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
|
|
|
|
| 21 |
|
|
|
|
| 22 |
CMD uvicorn app:app --host 0.0.0.0 --port $PORT
|
|
|
|
| 11 |
COPY requirements.txt .
|
| 12 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
|
| 14 |
+
# Pre-download models during build
|
| 15 |
+
RUN python -c "import torch; torch.hub.load('intel-isl/MiDaS', 'MiDaS_small', trust_repo=True)"
|
| 16 |
+
|
| 17 |
COPY . .
|
| 18 |
|
|
|
|
| 19 |
RUN mkdir -p /tmp/video-bg-remover
|
| 20 |
|
| 21 |
+
# Create a non-root user
|
| 22 |
+
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app /tmp/video-bg-remover
|
| 23 |
+
USER appuser
|
| 24 |
|
| 25 |
+
ENV PORT=7860
|
| 26 |
CMD uvicorn app:app --host 0.0.0.0 --port $PORT
|
processor.py
CHANGED
|
@@ -7,54 +7,65 @@ from pathlib import Path
|
|
| 7 |
import asyncio
|
| 8 |
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
import gc
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class VideoProcessor:
|
| 12 |
def __init__(self):
|
| 13 |
-
# Use CPU if no GPU
|
| 14 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
print(f"Using device: {self.device}")
|
| 16 |
|
| 17 |
-
# Load MiDaS
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
self.executor = ThreadPoolExecutor(max_workers=1)
|
| 27 |
|
| 28 |
def hex_to_rgb(self, hex_color: str):
|
| 29 |
-
"""Convert hex to RGB"""
|
| 30 |
hex_color = hex_color.lstrip('#')
|
| 31 |
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
|
| 32 |
|
| 33 |
async def process_video(self, input_path: str, threshold: float,
|
| 34 |
bg_color: str, session_id: str) -> str:
|
| 35 |
-
"""Process video asynchronously"""
|
| 36 |
loop = asyncio.get_event_loop()
|
| 37 |
output_path = str(Path("/tmp") / f"{session_id}_output.mp4")
|
| 38 |
|
| 39 |
-
# Run in thread pool
|
| 40 |
await loop.run_in_executor(
|
| 41 |
self.executor,
|
| 42 |
self._process_video_sync,
|
| 43 |
input_path, output_path, threshold, bg_color
|
| 44 |
)
|
| 45 |
-
|
| 46 |
return output_path
|
| 47 |
|
| 48 |
def _process_video_sync(self, input_path: str, output_path: str,
|
| 49 |
threshold: float, bg_color: str):
|
| 50 |
-
"""Synchronous video processing"""
|
| 51 |
cap = cv2.VideoCapture(input_path)
|
|
|
|
|
|
|
|
|
|
| 52 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 53 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 54 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 55 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 56 |
|
| 57 |
-
|
|
|
|
| 58 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 59 |
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 60 |
|
|
@@ -66,9 +77,13 @@ class VideoProcessor:
|
|
| 66 |
if not ret:
|
| 67 |
break
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
frame_count += 1
|
| 74 |
if frame_count % 30 == 0:
|
|
@@ -82,20 +97,28 @@ class VideoProcessor:
|
|
| 82 |
|
| 83 |
cap.release()
|
| 84 |
out.release()
|
|
|
|
| 85 |
|
| 86 |
def process_frame(self, frame: np.ndarray, threshold: float,
|
| 87 |
bg_color: tuple) -> np.ndarray:
|
| 88 |
-
"""Process a single frame"""
|
| 89 |
-
# Resize for speed
|
| 90 |
h, w = frame.shape[:2]
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
frame_small = cv2.resize(frame, (new_w, new_h))
|
| 94 |
frame_rgb = cv2.cvtColor(frame_small, cv2.COLOR_BGR2RGB)
|
| 95 |
|
| 96 |
# Get depth map
|
| 97 |
img = Image.fromarray(frame_rgb)
|
| 98 |
-
input_batch = self.transform(img).to(self.device)
|
| 99 |
|
| 100 |
with torch.no_grad():
|
| 101 |
depth = self.model(input_batch)
|
|
@@ -109,13 +132,17 @@ class VideoProcessor:
|
|
| 109 |
# Normalize depth
|
| 110 |
depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
|
| 111 |
|
| 112 |
-
# Create mask
|
| 113 |
mask = (depth_norm > threshold).astype(np.uint8) * 255
|
| 114 |
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
|
| 115 |
-
mask = mask.astype(bool)
|
| 116 |
|
| 117 |
-
#
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
return result
|
|
|
|
| 7 |
import asyncio
|
| 8 |
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
import gc
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings('ignore')
|
| 12 |
|
| 13 |
class VideoProcessor:
|
| 14 |
def __init__(self):
    """Set up the depth-estimation model and a single-worker thread pool.

    Loads the lightweight MiDaS_small model from torch.hub (downloading it on
    first use), moving it to GPU when available; on any load failure, falls
    back to the larger DPT_Large model.  A one-worker ThreadPoolExecutor is
    created so video processing runs off the event loop, one job at a time.
    """
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {self.device}")

    # Load MiDaS with proper error handling
    try:
        print("Loading MiDaS model...")
        self.model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
        self.model.to(self.device)
        self.model.eval()

        # Load transforms
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
        self.transform = midas_transforms.small_transform
        print("MiDaS model loaded successfully!")
    except Exception as e:
        print(f"Error loading MiDaS: {e}")
        print("Falling back to DPT model...")
        # Fallback to DPT model
        self.model = torch.hub.load("intel-isl/MiDaS", "DPT_Large", trust_repo=True)
        self.model.to(self.device)
        self.model.eval()
        # BUG FIX: `midas_transforms` was previously reused from the try
        # branch, but it is unbound whenever the first hub.load call is what
        # raised — the fallback then crashed with NameError.  Load the
        # transforms hub here explicitly before taking dpt_transform.
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
        self.transform = midas_transforms.dpt_transform

    # Single worker: process one video at a time to bound memory use on the
    # small Space; also serializes access to the (non-thread-safe) model.
    self.executor = ThreadPoolExecutor(max_workers=1)
|
| 39 |
|
| 40 |
def hex_to_rgb(self, hex_color: str):
    """Convert a hex colour string (with or without a leading '#') to an (R, G, B) tuple."""
    digits = hex_color.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return (red, green, blue)
|
| 43 |
|
| 44 |
async def process_video(self, input_path: str, threshold: float,
|
| 45 |
bg_color: str, session_id: str) -> str:
|
|
|
|
| 46 |
loop = asyncio.get_event_loop()
|
| 47 |
output_path = str(Path("/tmp") / f"{session_id}_output.mp4")
|
| 48 |
|
|
|
|
| 49 |
await loop.run_in_executor(
|
| 50 |
self.executor,
|
| 51 |
self._process_video_sync,
|
| 52 |
input_path, output_path, threshold, bg_color
|
| 53 |
)
|
|
|
|
| 54 |
return output_path
|
| 55 |
|
| 56 |
def _process_video_sync(self, input_path: str, output_path: str,
|
| 57 |
threshold: float, bg_color: str):
|
|
|
|
| 58 |
cap = cv2.VideoCapture(input_path)
|
| 59 |
+
if not cap.isOpened():
|
| 60 |
+
raise ValueError(f"Could not open video: {input_path}")
|
| 61 |
+
|
| 62 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 63 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 64 |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 65 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 66 |
|
| 67 |
+
print(f"Video info: {width}x{height}, {fps}fps, {total_frames} frames")
|
| 68 |
+
|
| 69 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 70 |
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 71 |
|
|
|
|
| 77 |
if not ret:
|
| 78 |
break
|
| 79 |
|
| 80 |
+
try:
|
| 81 |
+
processed = self.process_frame(frame, threshold, bg_rgb)
|
| 82 |
+
out.write(processed)
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"Error processing frame {frame_count}: {e}")
|
| 85 |
+
# Write original frame on error
|
| 86 |
+
out.write(frame)
|
| 87 |
|
| 88 |
frame_count += 1
|
| 89 |
if frame_count % 30 == 0:
|
|
|
|
| 97 |
|
| 98 |
cap.release()
|
| 99 |
out.release()
|
| 100 |
+
print(f"Video saved to {output_path}")
|
| 101 |
|
| 102 |
def process_frame(self, frame: np.ndarray, threshold: float,
|
| 103 |
bg_color: tuple) -> np.ndarray:
|
|
|
|
|
|
|
| 104 |
h, w = frame.shape[:2]
|
| 105 |
+
|
| 106 |
+
# Resize for speed while maintaining aspect ratio
|
| 107 |
+
max_size = 384
|
| 108 |
+
if h > max_size or w > max_size:
|
| 109 |
+
if h > w:
|
| 110 |
+
new_h, new_w = max_size, int(max_size * w / h)
|
| 111 |
+
else:
|
| 112 |
+
new_h, new_w = int(max_size * h / w), max_size
|
| 113 |
+
else:
|
| 114 |
+
new_h, new_w = h, w
|
| 115 |
|
| 116 |
frame_small = cv2.resize(frame, (new_w, new_h))
|
| 117 |
frame_rgb = cv2.cvtColor(frame_small, cv2.COLOR_BGR2RGB)
|
| 118 |
|
| 119 |
# Get depth map
|
| 120 |
img = Image.fromarray(frame_rgb)
|
| 121 |
+
input_batch = self.transform(img).unsqueeze(0).to(self.device)
|
| 122 |
|
| 123 |
with torch.no_grad():
|
| 124 |
depth = self.model(input_batch)
|
|
|
|
| 132 |
# Normalize depth
|
| 133 |
depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
|
| 134 |
|
| 135 |
+
# Create mask
|
| 136 |
mask = (depth_norm > threshold).astype(np.uint8) * 255
|
| 137 |
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
| 138 |
|
| 139 |
+
# Smooth mask edges
|
| 140 |
+
mask = cv2.GaussianBlur(mask, (5, 5), 0)
|
| 141 |
+
mask_float = mask.astype(np.float32) / 255.0
|
| 142 |
+
mask_3channel = np.stack([mask_float] * 3, axis=2)
|
| 143 |
+
|
| 144 |
+
# Apply background with soft edges
|
| 145 |
+
bg_array = np.array(bg_color, dtype=np.float32).reshape(1, 1, 3)
|
| 146 |
+
result = (frame * mask_3channel + bg_array * (1 - mask_3channel)).astype(np.uint8)
|
| 147 |
|
| 148 |
return result
|
requirements.txt
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
-
fastapi==0.104.1
|
| 2 |
-
uvicorn==0.24.0
|
| 3 |
-
torch==2.1.0
|
| 4 |
-
torchvision==0.16.0
|
| 5 |
-
opencv-python-headless==4.8.1.78
|
| 6 |
-
numpy==1.24.3
|
| 7 |
-
Pillow==10.1.0
|
| 8 |
-
python-multipart==0.0.6
|
| 9 |
-
timm==0.9.2
|
|
|
|
|
|
| 1 |
+
fastapi==0.104.1
|
| 2 |
+
uvicorn==0.24.0
|
| 3 |
+
torch==2.1.0
|
| 4 |
+
torchvision==0.16.0
|
| 5 |
+
opencv-python-headless==4.8.1.78
|
| 6 |
+
numpy==1.24.3
|
| 7 |
+
Pillow==10.1.0
|
| 8 |
+
python-multipart==0.0.6
|
| 9 |
+
timm==0.9.2
|
| 10 |
+
transformers==4.35.0
|