MogensR committed on
Commit
e58ceed
·
1 Parent(s): 0197715

Create two_stage_processor.py

Browse files
Files changed (1) hide show
  1. two_stage_processor.py +325 -0
two_stage_processor.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Two-Stage Green Screen Processing System
4
+ Stage 1: Original → Green Screen
5
+ Stage 2: Green Screen → Final Background
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import os
11
+ import pickle
12
+ import logging
13
+ from pathlib import Path
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class TwoStageProcessor:
    """Handle two-stage video processing with green screen intermediate"""

    def __init__(self, sam2_predictor=None, matanyone_model=None,
                 mask_cache_dir="/tmp/mask_cache"):
        """
        Args:
            sam2_predictor: optional SAM2 predictor used for person
                segmentation; when None a rectangular fallback mask is used.
            matanyone_model: optional MatAnyone model used to refine masks;
                when None refinement is skipped.
            mask_cache_dir: directory where stage-1 mask pickles are cached.
                Generalized from the previously hard-coded "/tmp/mask_cache";
                the default preserves the old behavior.
        """
        self.sam2_predictor = sam2_predictor
        self.matanyone_model = matanyone_model
        # Created eagerly so stage 1 can always pickle its masks here.
        self.mask_cache_dir = Path(mask_cache_dir)
        self.mask_cache_dir.mkdir(exist_ok=True, parents=True)
25
+
26
+ def stage1_extract_to_greenscreen(self, video_path, output_path, progress_callback=None):
27
+ """
28
+ Stage 1: Extract person and create green screen video
29
+ Also saves masks for potential reuse
30
+ """
31
+ def _prog(pct: float, desc: str):
32
+ if progress_callback:
33
+ progress_callback(pct, desc)
34
+
35
+ try:
36
+ _prog(0.0, "Stage 1: Extracting to green screen...")
37
+
38
+ cap = cv2.VideoCapture(video_path)
39
+ if not cap.isOpened():
40
+ return None, "Could not open video file"
41
+
42
+ fps = cap.get(cv2.CAP_PROP_FPS)
43
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
44
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
45
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
46
+
47
+ # Pure green background for chroma keying
48
+ green_bg = np.zeros((height, width, 3), dtype=np.uint8)
49
+ green_bg[:, :] = [0, 255, 0] # Pure green in BGR
50
+
51
+ # Setup output
52
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
53
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
54
+
55
+ # Storage for masks (for potential reuse)
56
+ masks = []
57
+ frame_count = 0
58
+
59
+ while True:
60
+ ret, frame = cap.read()
61
+ if not ret:
62
+ break
63
+
64
+ _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8,
65
+ f"Stage 1: Processing frame {frame_count + 1}/{total_frames}")
66
+
67
+ # Get mask using SAM2
68
+ mask = self._extract_person_mask(frame)
69
+ masks.append(mask)
70
+
71
+ # Refine mask every 3rd frame with MatAnyone
72
+ if frame_count % 3 == 0 and self.matanyone_model:
73
+ mask = self._refine_mask(frame, mask)
74
+
75
+ # Apply green screen with HARD edges for clean keying
76
+ result = self._apply_greenscreen_hard(frame, mask, green_bg)
77
+ out.write(result)
78
+
79
+ frame_count += 1
80
+
81
+ cap.release()
82
+ out.release()
83
+
84
+ # Save masks for potential reuse
85
+ mask_file = self.mask_cache_dir / f"{Path(output_path).stem}_masks.pkl"
86
+ with open(mask_file, 'wb') as f:
87
+ pickle.dump(masks, f)
88
+
89
+ _prog(1.0, "Stage 1 complete: Green screen created")
90
+ return output_path, f"Green screen created: {frame_count} frames"
91
+
92
+ except Exception as e:
93
+ logger.error(f"Stage 1 error: {e}")
94
+ return None, f"Stage 1 failed: {str(e)}"
95
+
96
+ def stage2_greenscreen_to_final(self, greenscreen_path, background, output_path,
97
+ chroma_settings=None, progress_callback=None):
98
+ """
99
+ Stage 2: Replace green screen with final background using chroma keying
100
+ """
101
+ def _prog(pct: float, desc: str):
102
+ if progress_callback:
103
+ progress_callback(pct, desc)
104
+
105
+ if chroma_settings is None:
106
+ chroma_settings = {
107
+ 'key_color': [0, 255, 0], # Green in BGR
108
+ 'tolerance': 40,
109
+ 'edge_softness': 2,
110
+ 'spill_suppression': 0.3
111
+ }
112
+
113
+ try:
114
+ _prog(0.0, "Stage 2: Applying final background...")
115
+
116
+ cap = cv2.VideoCapture(greenscreen_path)
117
+ if not cap.isOpened():
118
+ return None, "Could not open green screen video"
119
+
120
+ fps = cap.get(cv2.CAP_PROP_FPS)
121
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
122
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
123
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
124
+
125
+ # Prepare background
126
+ if isinstance(background, str):
127
+ bg = cv2.imread(background)
128
+ if bg is None:
129
+ return None, "Could not load background image"
130
+ else:
131
+ bg = background
132
+
133
+ bg = cv2.resize(bg, (width, height))
134
+
135
+ # Setup output
136
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
137
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
138
+
139
+ frame_count = 0
140
+
141
+ while True:
142
+ ret, frame = cap.read()
143
+ if not ret:
144
+ break
145
+
146
+ _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8,
147
+ f"Stage 2: Compositing frame {frame_count + 1}/{total_frames}")
148
+
149
+ # Apply chroma keying
150
+ result = self._chroma_key_advanced(frame, bg, chroma_settings)
151
+ out.write(result)
152
+
153
+ frame_count += 1
154
+
155
+ cap.release()
156
+ out.release()
157
+
158
+ _prog(1.0, "Stage 2 complete: Final video created")
159
+ return output_path, f"Final video created: {frame_count} frames"
160
+
161
+ except Exception as e:
162
+ logger.error(f"Stage 2 error: {e}")
163
+ return None, f"Stage 2 failed: {str(e)}"
164
+
165
+ def _extract_person_mask(self, frame):
166
+ """Extract person mask using SAM2"""
167
+ if self.sam2_predictor is None:
168
+ # Fallback mask
169
+ h, w = frame.shape[:2]
170
+ mask = np.zeros((h, w), dtype=np.uint8)
171
+ mask[h//6:5*h//6, w//4:3*w//4] = 255
172
+ return mask
173
+
174
+ try:
175
+ self.sam2_predictor.set_image(frame)
176
+ h, w = frame.shape[:2]
177
+
178
+ # Strategic points for person
179
+ points = np.array([
180
+ [w//2, h//3], # Head
181
+ [w//2, h//2], # Torso
182
+ [w//2, 2*h//3], # Lower body
183
+ [w//3, h//2], # Left
184
+ [2*w//3, h//2], # Right
185
+ ])
186
+ labels = np.ones(len(points))
187
+
188
+ masks, scores, _ = self.sam2_predictor.predict(
189
+ point_coords=points,
190
+ point_labels=labels,
191
+ multimask_output=True
192
+ )
193
+
194
+ best_idx = np.argmax(scores)
195
+ mask = masks[best_idx]
196
+
197
+ if mask.dtype != np.uint8:
198
+ mask = (mask * 255).astype(np.uint8)
199
+
200
+ return mask
201
+
202
+ except Exception as e:
203
+ logger.error(f"Mask extraction error: {e}")
204
+ h, w = frame.shape[:2]
205
+ mask = np.zeros((h, w), dtype=np.uint8)
206
+ mask[h//6:5*h//6, w//4:3*w//4] = 255
207
+ return mask
208
+
209
+ def _refine_mask(self, frame, mask):
210
+ """Refine mask using MatAnyone if available"""
211
+ if self.matanyone_model is None:
212
+ return mask
213
+
214
+ try:
215
+ # MatAnyone refinement logic here
216
+ # This would depend on your MatAnyone implementation
217
+ return mask
218
+ except:
219
+ return mask
220
+
221
+ def _apply_greenscreen_hard(self, frame, mask, green_bg):
222
+ """Apply green screen with hard edges for clean chroma keying"""
223
+ # Binary threshold for clean edges
224
+ _, mask_binary = cv2.threshold(mask, 140, 255, cv2.THRESH_BINARY)
225
+
226
+ # No feathering - we want hard edges for chroma keying
227
+ mask_3ch = cv2.cvtColor(mask_binary, cv2.COLOR_GRAY2BGR)
228
+ mask_norm = mask_3ch.astype(float) / 255
229
+
230
+ # Composite
231
+ result = frame * mask_norm + green_bg * (1 - mask_norm)
232
+ return result.astype(np.uint8)
233
+
234
+ def _chroma_key_advanced(self, frame, background, settings):
235
+ """
236
+ Advanced chroma keying with spill suppression
237
+ """
238
+ key_color = np.array(settings['key_color'], dtype=np.uint8)
239
+ tolerance = settings['tolerance']
240
+ softness = settings['edge_softness']
241
+ spill_suppress = settings['spill_suppression']
242
+
243
+ # Convert to float for processing
244
+ frame_float = frame.astype(np.float32)
245
+ bg_float = background.astype(np.float32)
246
+
247
+ # Calculate color distance from key color
248
+ diff = np.abs(frame_float - key_color)
249
+ distance = np.sqrt(np.sum(diff ** 2, axis=2))
250
+
251
+ # Create mask based on distance
252
+ mask = np.where(distance < tolerance, 0, 1)
253
+
254
+ # Edge softening
255
+ if softness > 0:
256
+ mask = cv2.GaussianBlur(mask.astype(np.float32),
257
+ (softness*2+1, softness*2+1),
258
+ softness)
259
+
260
+ # Spill suppression - reduce green in edges
261
+ if spill_suppress > 0:
262
+ green_channel = frame_float[:, :, 1]
263
+ spill_mask = np.where(mask < 1, 1 - mask, 0)
264
+ green_suppression = green_channel * spill_mask * spill_suppress
265
+ frame_float[:, :, 1] -= green_suppression
266
+ frame_float = np.clip(frame_float, 0, 255)
267
+
268
+ # Expand mask to 3 channels
269
+ mask_3ch = np.stack([mask] * 3, axis=2)
270
+
271
+ # Composite
272
+ result = frame_float * mask_3ch + bg_float * (1 - mask_3ch)
273
+ return np.clip(result, 0, 255).astype(np.uint8)
274
+
275
+ def process_full_pipeline(self, video_path, background, final_output,
276
+ chroma_settings=None, progress_callback=None):
277
+ """
278
+ Run the complete two-stage pipeline
279
+ """
280
+ import tempfile
281
+
282
+ # Stage 1: Create green screen
283
+ greenscreen_path = tempfile.mktemp(suffix='_greenscreen.mp4')
284
+ gs_result, gs_msg = self.stage1_extract_to_greenscreen(
285
+ video_path, greenscreen_path, progress_callback
286
+ )
287
+
288
+ if gs_result is None:
289
+ return None, gs_msg
290
+
291
+ # Stage 2: Apply final background
292
+ final_result, final_msg = self.stage2_greenscreen_to_final(
293
+ greenscreen_path, background, final_output,
294
+ chroma_settings, progress_callback
295
+ )
296
+
297
+ # Cleanup
298
+ try:
299
+ os.remove(greenscreen_path)
300
+ except:
301
+ pass
302
+
303
+ return final_result, final_msg
304
+
305
# Chroma key settings presets
# Ready-made `settings` dicts for _chroma_key_advanced /
# stage2_greenscreen_to_final:
#   key_color         - BGR color to key out (pure green, matching stage 1)
#   tolerance         - color-distance threshold below which a pixel is keyed
#   edge_softness     - Gaussian blur radius applied to the key mask
#   spill_suppression - 0..1 strength of green removal near mask edges
CHROMA_PRESETS = {
    # Balanced default (same values stage 2 uses when no settings are given).
    'standard': {
        'key_color': [0, 255, 0],
        'tolerance': 40,
        'edge_softness': 2,
        'spill_suppression': 0.3
    },
    # Stricter key: lower tolerance, crisper edges, stronger spill removal.
    'tight': {
        'key_color': [0, 255, 0],
        'tolerance': 30,
        'edge_softness': 1,
        'spill_suppression': 0.4
    },
    # Looser key: higher tolerance, softer edges, gentler spill removal.
    'soft': {
        'key_color': [0, 255, 0],
        'tolerance': 50,
        'edge_softness': 3,
        'spill_suppression': 0.2
    }
}