MySafeCode committed on
Commit
55b8985
·
verified ·
1 Parent(s): a360262

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +7 -3
  2. processor.py +54 -27
  3. requirements.txt +10 -9
Dockerfile CHANGED
@@ -11,12 +11,16 @@ RUN apt-get update && apt-get install -y \
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
 
 
 
14
  COPY . .
15
 
16
- # Create temp directory
17
  RUN mkdir -p /tmp/video-bg-remover
18
 
19
- # For Hugging Face Spaces
20
- ENV PORT=7860
 
21
 
 
22
  CMD uvicorn app:app --host 0.0.0.0 --port $PORT
 
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
+ # Pre-download models during build
15
+ RUN python -c "import torch; torch.hub.load('intel-isl/MiDaS', 'MiDaS_small', trust_repo=True)"
16
+
17
  COPY . .
18
 
 
19
  RUN mkdir -p /tmp/video-bg-remover
20
 
21
+ # Create a non-root user
22
+ RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app /tmp/video-bg-remover
23
+ USER appuser
24
 
25
+ ENV PORT=7860
26
  CMD uvicorn app:app --host 0.0.0.0 --port $PORT
processor.py CHANGED
@@ -7,54 +7,65 @@ from pathlib import Path
7
  import asyncio
8
  from concurrent.futures import ThreadPoolExecutor
9
  import gc
 
 
10
 
11
  class VideoProcessor:
12
  def __init__(self):
13
- # Use CPU if no GPU
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  print(f"Using device: {self.device}")
16
 
17
- # Load MiDaS (small model for speed)
18
- self.model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
19
- self.model.to(self.device)
20
- self.model.eval()
21
-
22
- # Load transforms
23
- midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
24
- self.transform = midas_transforms.small_transform
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  self.executor = ThreadPoolExecutor(max_workers=1)
27
 
28
  def hex_to_rgb(self, hex_color: str):
29
- """Convert hex to RGB"""
30
  hex_color = hex_color.lstrip('#')
31
  return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
32
 
33
  async def process_video(self, input_path: str, threshold: float,
34
  bg_color: str, session_id: str) -> str:
35
- """Process video asynchronously"""
36
  loop = asyncio.get_event_loop()
37
  output_path = str(Path("/tmp") / f"{session_id}_output.mp4")
38
 
39
- # Run in thread pool
40
  await loop.run_in_executor(
41
  self.executor,
42
  self._process_video_sync,
43
  input_path, output_path, threshold, bg_color
44
  )
45
-
46
  return output_path
47
 
48
  def _process_video_sync(self, input_path: str, output_path: str,
49
  threshold: float, bg_color: str):
50
- """Synchronous video processing"""
51
  cap = cv2.VideoCapture(input_path)
 
 
 
52
  fps = int(cap.get(cv2.CAP_PROP_FPS))
53
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
54
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
55
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
56
 
57
- # Output video
 
58
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
59
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
60
 
@@ -66,9 +77,13 @@ class VideoProcessor:
66
  if not ret:
67
  break
68
 
69
- # Process frame
70
- processed = self.process_frame(frame, threshold, bg_rgb)
71
- out.write(processed)
 
 
 
 
72
 
73
  frame_count += 1
74
  if frame_count % 30 == 0:
@@ -82,20 +97,28 @@ class VideoProcessor:
82
 
83
  cap.release()
84
  out.release()
 
85
 
86
  def process_frame(self, frame: np.ndarray, threshold: float,
87
  bg_color: tuple) -> np.ndarray:
88
- """Process a single frame"""
89
- # Resize for speed
90
  h, w = frame.shape[:2]
91
- new_h, new_w = 256, int(256 * w / h)
 
 
 
 
 
 
 
 
 
92
 
93
  frame_small = cv2.resize(frame, (new_w, new_h))
94
  frame_rgb = cv2.cvtColor(frame_small, cv2.COLOR_BGR2RGB)
95
 
96
  # Get depth map
97
  img = Image.fromarray(frame_rgb)
98
- input_batch = self.transform(img).to(self.device)
99
 
100
  with torch.no_grad():
101
  depth = self.model(input_batch)
@@ -109,13 +132,17 @@ class VideoProcessor:
109
  # Normalize depth
110
  depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
111
 
112
- # Create mask and resize to original
113
  mask = (depth_norm > threshold).astype(np.uint8) * 255
114
  mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
115
- mask = mask.astype(bool)
116
 
117
- # Apply background
118
- result = frame.copy()
119
- result[~mask] = bg_color
 
 
 
 
 
120
 
121
  return result
 
7
  import asyncio
8
  from concurrent.futures import ThreadPoolExecutor
9
  import gc
10
+ import warnings
11
+ warnings.filterwarnings('ignore')
12
 
13
  class VideoProcessor:
14
  def __init__(self):
 
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  print(f"Using device: {self.device}")
17
 
18
+ # Load MiDaS with proper error handling
19
+ try:
20
+ print("Loading MiDaS model...")
21
+ self.model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
22
+ self.model.to(self.device)
23
+ self.model.eval()
24
+
25
+ # Load transforms
26
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
27
+ self.transform = midas_transforms.small_transform
28
+ print("MiDaS model loaded successfully!")
29
+ except Exception as e:
30
+ print(f"Error loading MiDaS: {e}")
31
+ print("Falling back to DPT model...")
32
+ # Fallback to DPT model
33
+ self.model = torch.hub.load("intel-isl/MiDaS", "DPT_Large", trust_repo=True)
34
+ self.model.to(self.device)
35
+ self.model.eval()
36
+ self.transform = midas_transforms.dpt_transform
37
 
38
  self.executor = ThreadPoolExecutor(max_workers=1)
39
 
40
  def hex_to_rgb(self, hex_color: str):
 
41
  hex_color = hex_color.lstrip('#')
42
  return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
43
 
44
  async def process_video(self, input_path: str, threshold: float,
45
  bg_color: str, session_id: str) -> str:
 
46
  loop = asyncio.get_event_loop()
47
  output_path = str(Path("/tmp") / f"{session_id}_output.mp4")
48
 
 
49
  await loop.run_in_executor(
50
  self.executor,
51
  self._process_video_sync,
52
  input_path, output_path, threshold, bg_color
53
  )
 
54
  return output_path
55
 
56
  def _process_video_sync(self, input_path: str, output_path: str,
57
  threshold: float, bg_color: str):
 
58
  cap = cv2.VideoCapture(input_path)
59
+ if not cap.isOpened():
60
+ raise ValueError(f"Could not open video: {input_path}")
61
+
62
  fps = int(cap.get(cv2.CAP_PROP_FPS))
63
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
64
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
65
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
66
 
67
+ print(f"Video info: {width}x{height}, {fps}fps, {total_frames} frames")
68
+
69
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
70
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
71
 
 
77
  if not ret:
78
  break
79
 
80
+ try:
81
+ processed = self.process_frame(frame, threshold, bg_rgb)
82
+ out.write(processed)
83
+ except Exception as e:
84
+ print(f"Error processing frame {frame_count}: {e}")
85
+ # Write original frame on error
86
+ out.write(frame)
87
 
88
  frame_count += 1
89
  if frame_count % 30 == 0:
 
97
 
98
  cap.release()
99
  out.release()
100
+ print(f"Video saved to {output_path}")
101
 
102
  def process_frame(self, frame: np.ndarray, threshold: float,
103
  bg_color: tuple) -> np.ndarray:
 
 
104
  h, w = frame.shape[:2]
105
+
106
+ # Resize for speed while maintaining aspect ratio
107
+ max_size = 384
108
+ if h > max_size or w > max_size:
109
+ if h > w:
110
+ new_h, new_w = max_size, int(max_size * w / h)
111
+ else:
112
+ new_h, new_w = int(max_size * h / w), max_size
113
+ else:
114
+ new_h, new_w = h, w
115
 
116
  frame_small = cv2.resize(frame, (new_w, new_h))
117
  frame_rgb = cv2.cvtColor(frame_small, cv2.COLOR_BGR2RGB)
118
 
119
  # Get depth map
120
  img = Image.fromarray(frame_rgb)
121
+ input_batch = self.transform(img).unsqueeze(0).to(self.device)
122
 
123
  with torch.no_grad():
124
  depth = self.model(input_batch)
 
132
  # Normalize depth
133
  depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
134
 
135
+ # Create mask
136
  mask = (depth_norm > threshold).astype(np.uint8) * 255
137
  mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
 
138
 
139
+ # Smooth mask edges
140
+ mask = cv2.GaussianBlur(mask, (5, 5), 0)
141
+ mask_float = mask.astype(np.float32) / 255.0
142
+ mask_3channel = np.stack([mask_float] * 3, axis=2)
143
+
144
+ # Apply background with soft edges
145
+ bg_array = np.array(bg_color, dtype=np.float32).reshape(1, 1, 3)
146
+ result = (frame * mask_3channel + bg_array * (1 - mask_3channel)).astype(np.uint8)
147
 
148
  return result
requirements.txt CHANGED
@@ -1,9 +1,10 @@
1
- fastapi==0.104.1
2
- uvicorn==0.24.0
3
- torch==2.1.0
4
- torchvision==0.16.0
5
- opencv-python-headless==4.8.1.78
6
- numpy==1.24.3
7
- Pillow==10.1.0
8
- python-multipart==0.0.6
9
- timm==0.9.2
 
 
1
+ fastapi==0.104.1
2
+ uvicorn==0.24.0
3
+ torch==2.1.0
4
+ torchvision==0.16.0
5
+ opencv-python-headless==4.8.1.78
6
+ numpy==1.24.3
7
+ Pillow==10.1.0
8
+ python-multipart==0.0.6
9
+ timm==0.9.2
10
+ transformers==4.35.0