MogensR committed
Commit ad645ee · 1 Parent(s): e192748

Update app.py

Files changed (1):
  1. app.py +525 -781

app.py CHANGED
@@ -1,798 +1,542 @@
1
- import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from PIL import Image
6
  import numpy as np
7
- from torchvision import transforms
8
- import os
9
- from typing import Tuple, Optional
10
-
11
- # ===== VERSION 2.2 - COMPLETE IMPLEMENTATION WITH TRANSPARENCY FIX =====
12
- print("===== MYAVATARS.DK VERSION 2.2 - FULL IMPLEMENTATION =====")
13
- print("===== TRANSPARENCY FIX INCLUDED INLINE =====")
14
- print("===== CACHE BUSTER ACTIVE =====")
15
-
16
- # Fix OMP_NUM_THREADS warning
17
- os.environ['OMP_NUM_THREADS'] = '1'
18
-
19
- # Force CPU usage and disable CUDA
20
- os.environ['CUDA_VISIBLE_DEVICES'] = ''
21
- device = torch.device('cpu')
22
-
23
- # ============================================================================
24
- # PREPROCESSING AND POSTPROCESSING FUNCTIONS (INLINE FOR CACHE ISSUES)
25
- # ============================================================================
26
-
27
- def preprocess_image(im: Image.Image, model_input_size: list) -> torch.Tensor:
28
- """
29
- Preprocess image for model input.
30
- Fixed version that maintains proper tensor dimensions.
31
- """
32
- if im.mode != 'RGB':
33
- im = im.convert('RGB')
34
-
35
- # Create transform pipeline
36
- transform = transforms.Compose([
37
- transforms.Resize(model_input_size),
38
- transforms.ToTensor(),
39
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
40
- std=[0.229, 0.224, 0.225])
41
- ])
42
-
43
- im_tensor = transform(im)
44
- im_tensor = im_tensor.unsqueeze(0) # Add batch dimension
45
- return im_tensor
46
-
47
- def postprocess_image(result: torch.Tensor, im_size: list) -> Image.Image:
48
- """
49
- Postprocess model output to final image.
50
- FIXED: Returns proper alpha channel instead of inverted mask.
51
- """
52
- result = result.squeeze(0) # Remove batch dimension if present
53
-
54
- if result.dim() == 3 and result.shape[0] == 1:
55
- result = result.squeeze(0) # Remove channel dimension for single channel
56
-
57
- # Convert to numpy and ensure values are in [0, 1]
58
- result_np = result.detach().cpu().numpy()
59
-
60
- # CRITICAL FIX: Model outputs foreground probability (0=background, 1=foreground)
61
- # We need alpha channel where 255=opaque (foreground), 0=transparent (background)
62
- # So we DON'T invert - we scale directly to 0-255
63
- result_np = (result_np * 255).astype(np.uint8)
64
-
65
- # Create PIL Image and resize to original dimensions
66
- pil_im = Image.fromarray(result_np, mode='L')
67
- pil_im = pil_im.resize(im_size, Image.LANCZOS)
68
-
69
- return pil_im
70
-
71
- # ============================================================================
72
- # MODEL ARCHITECTURE DEFINITION
73
- # ============================================================================
74
-
75
- class REBNCONV(nn.Module):
76
- def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
77
- super(REBNCONV, self).__init__()
78
- self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate, stride=stride)
79
- self.bn_s1 = nn.BatchNorm2d(out_ch)
80
- self.relu_s1 = nn.ReLU(inplace=True)
81
-
82
- def forward(self, x):
83
- hx = x
84
- xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
85
- return xout
86
-
87
- def _upsample_like(src, tar):
88
- src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
89
- return src
90
-
91
- class RSU7(nn.Module):
92
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
93
- super(RSU7, self).__init__()
94
- self.in_ch = in_ch
95
- self.mid_ch = mid_ch
96
- self.out_ch = out_ch
97
-
98
- self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
99
-
100
- self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
101
- self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
102
-
103
- self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
104
- self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
105
-
106
- self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
107
- self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
108
-
109
- self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
110
- self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
111
-
112
- self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
113
- self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
114
-
115
- self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
116
-
117
- self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
118
-
119
- self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
120
- self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
121
- self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
122
- self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
123
- self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
124
- self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
125
-
126
- def forward(self, x):
127
- b, c, h, w = x.shape
128
-
129
- hx = x
130
- hxin = self.rebnconvin(hx)
131
-
132
- hx1 = self.rebnconv1(hxin)
133
- hx = self.pool1(hx1)
134
-
135
- hx2 = self.rebnconv2(hx)
136
- hx = self.pool2(hx2)
137
-
138
- hx3 = self.rebnconv3(hx)
139
- hx = self.pool3(hx3)
140
-
141
- hx4 = self.rebnconv4(hx)
142
- hx = self.pool4(hx4)
143
-
144
- hx5 = self.rebnconv5(hx)
145
- hx = self.pool5(hx5)
146
-
147
- hx6 = self.rebnconv6(hx)
148
-
149
- hx7 = self.rebnconv7(hx6)
150
-
151
- hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
152
- hx6dup = _upsample_like(hx6d, hx5)
153
-
154
- hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
155
- hx5dup = _upsample_like(hx5d, hx4)
156
-
157
- hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
158
- hx4dup = _upsample_like(hx4d, hx3)
159
-
160
- hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
161
- hx3dup = _upsample_like(hx3d, hx2)
162
-
163
- hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
164
- hx2dup = _upsample_like(hx2d, hx1)
165
-
166
- hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
167
-
168
- return hx1d + hxin
169
-
170
- class RSU6(nn.Module):
171
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
172
- super(RSU6, self).__init__()
173
-
174
- self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
175
-
176
- self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
177
- self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
178
-
179
- self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
180
- self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
181
-
182
- self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
183
- self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
184
-
185
- self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
186
- self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
187
-
188
- self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
189
-
190
- self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
191
-
192
- self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
193
- self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
194
- self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
195
- self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
196
- self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
197
-
198
- def forward(self, x):
199
- hx = x
200
-
201
- hxin = self.rebnconvin(hx)
202
-
203
- hx1 = self.rebnconv1(hxin)
204
- hx = self.pool1(hx1)
205
-
206
- hx2 = self.rebnconv2(hx)
207
- hx = self.pool2(hx2)
208
-
209
- hx3 = self.rebnconv3(hx)
210
- hx = self.pool3(hx3)
211
-
212
- hx4 = self.rebnconv4(hx)
213
- hx = self.pool4(hx4)
214
-
215
- hx5 = self.rebnconv5(hx)
216
-
217
- hx6 = self.rebnconv6(hx5)
218
-
219
- hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
220
- hx5dup = _upsample_like(hx5d, hx4)
221
-
222
- hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
223
- hx4dup = _upsample_like(hx4d, hx3)
224
-
225
- hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
226
- hx3dup = _upsample_like(hx3d, hx2)
227
-
228
- hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
229
- hx2dup = _upsample_like(hx2d, hx1)
230
-
231
- hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
232
-
233
- return hx1d + hxin
234
-
235
- class RSU5(nn.Module):
236
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
237
- super(RSU5, self).__init__()
238
-
239
- self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
240
-
241
- self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
242
- self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
243
-
244
- self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
245
- self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
246
-
247
- self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
248
- self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
249
-
250
- self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
251
-
252
- self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
253
-
254
- self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
255
- self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
256
- self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
257
- self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
258
-
259
- def forward(self, x):
260
- hx = x
261
-
262
- hxin = self.rebnconvin(hx)
263
-
264
- hx1 = self.rebnconv1(hxin)
265
- hx = self.pool1(hx1)
266
-
267
- hx2 = self.rebnconv2(hx)
268
- hx = self.pool2(hx2)
269
-
270
- hx3 = self.rebnconv3(hx)
271
- hx = self.pool3(hx3)
272
-
273
- hx4 = self.rebnconv4(hx)
274
-
275
- hx5 = self.rebnconv5(hx4)
276
-
277
- hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
278
- hx4dup = _upsample_like(hx4d, hx3)
279
-
280
- hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
281
- hx3dup = _upsample_like(hx3d, hx2)
282
-
283
- hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
284
- hx2dup = _upsample_like(hx2d, hx1)
285
-
286
- hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
287
-
288
- return hx1d + hxin
289
-
290
- class RSU4(nn.Module):
291
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
292
- super(RSU4, self).__init__()
293
-
294
- self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
295
-
296
- self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
297
- self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
298
-
299
- self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
300
- self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
301
-
302
- self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
303
-
304
- self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
305
-
306
- self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
307
- self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
308
- self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
309
-
310
- def forward(self, x):
311
- hx = x
312
-
313
- hxin = self.rebnconvin(hx)
314
-
315
- hx1 = self.rebnconv1(hxin)
316
- hx = self.pool1(hx1)
317
-
318
- hx2 = self.rebnconv2(hx)
319
- hx = self.pool2(hx2)
320
-
321
- hx3 = self.rebnconv3(hx)
322
-
323
- hx4 = self.rebnconv4(hx3)
324
-
325
- hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
326
- hx3dup = _upsample_like(hx3d, hx2)
327
-
328
- hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
329
- hx2dup = _upsample_like(hx2d, hx1)
330
-
331
- hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
332
-
333
- return hx1d + hxin
334
-
335
- class RSU4F(nn.Module):
336
- def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
337
- super(RSU4F, self).__init__()
338
-
339
- self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
340
-
341
- self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
342
- self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
343
- self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
344
-
345
- self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
346
-
347
- self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
348
- self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
349
- self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
350
-
351
- def forward(self, x):
352
- hx = x
353
-
354
- hxin = self.rebnconvin(hx)
355
-
356
- hx1 = self.rebnconv1(hxin)
357
- hx2 = self.rebnconv2(hx1)
358
- hx3 = self.rebnconv3(hx2)
359
-
360
- hx4 = self.rebnconv4(hx3)
361
-
362
- hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
363
- hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
364
- hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
365
-
366
- return hx1d + hxin
367
-
368
- class BriaRMBG(nn.Module):
369
- """
370
- BRIA RMBG Model for background removal.
371
- """
372
- def __init__(self, config=None):
373
- super(BriaRMBG, self).__init__()
374
-
375
- in_ch = 3
376
- out_ch = 1
377
-
378
- self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
379
- self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
380
-
381
- self.stage1 = RSU7(64, 32, 64)
382
- self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
383
-
384
- self.stage2 = RSU6(64, 32, 128)
385
- self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
386
-
387
- self.stage3 = RSU5(128, 64, 256)
388
- self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
389
-
390
- self.stage4 = RSU4(256, 128, 512)
391
- self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
392
-
393
- self.stage5 = RSU4F(512, 256, 512)
394
- self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
395
-
396
- self.stage6 = RSU4F(512, 256, 512)
397
-
398
- # decoder
399
- self.stage5d = RSU4F(1024, 256, 512)
400
- self.stage4d = RSU4(1024, 128, 256)
401
- self.stage3d = RSU5(512, 64, 128)
402
- self.stage2d = RSU6(256, 32, 64)
403
- self.stage1d = RSU7(128, 16, 64)
404
-
405
- self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
406
- self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
407
- self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
408
- self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
409
- self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
410
- self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
411
-
412
- self.outconv = nn.Conv2d(6, out_ch, 1)
413
-
414
- def forward(self, x):
415
- hx = x
416
-
417
- hxin = self.conv_in(hx)
418
- hxin = self.pool_in(hxin)
419
-
420
- # stage 1
421
- hx1 = self.stage1(hxin)
422
- hx = self.pool12(hx1)
423
-
424
- # stage 2
425
- hx2 = self.stage2(hx)
426
- hx = self.pool23(hx2)
427
-
428
- # stage 3
429
- hx3 = self.stage3(hx)
430
- hx = self.pool34(hx3)
431
-
432
- # stage 4
433
- hx4 = self.stage4(hx)
434
- hx = self.pool45(hx4)
435
-
436
- # stage 5
437
- hx5 = self.stage5(hx)
438
- hx = self.pool56(hx5)
439
-
440
- # stage 6
441
- hx6 = self.stage6(hx)
442
- hx6up = _upsample_like(hx6, hx5)
443
-
444
- # decoder
445
- hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
446
- hx5dup = _upsample_like(hx5d, hx4)
447
-
448
- hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
449
- hx4dup = _upsample_like(hx4d, hx3)
450
-
451
- hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
452
- hx3dup = _upsample_like(hx3d, hx2)
453
-
454
- hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
455
- hx2dup = _upsample_like(hx2d, hx1)
456
-
457
- hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
458
-
459
- # side output
460
- d1 = self.side1(hx1d)
461
- d1 = _upsample_like(d1, x)
462
-
463
- d2 = self.side2(hx2d)
464
- d2 = _upsample_like(d2, x)
465
-
466
- d3 = self.side3(hx3d)
467
- d3 = _upsample_like(d3, x)
468
-
469
- d4 = self.side4(hx4d)
470
- d4 = _upsample_like(d4, x)
471
-
472
- d5 = self.side5(hx5d)
473
- d5 = _upsample_like(d5, x)
474
-
475
- d6 = self.side6(hx6)
476
- d6 = _upsample_like(d6, x)
477
-
478
- d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
479
-
480
- return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
481
-
482
- # ============================================================================
483
- # MODEL LOADING AND INITIALIZATION
484
- # ============================================================================
485
-
486
- print("Loading BRIA RMBG model...")
487
-
488
- # Load the model
489
- model_path = "./model.pth"
490
- if not os.path.exists(model_path):
491
- print("Model not found locally, downloading from HuggingFace...")
492
- from huggingface_hub import hf_hub_download
493
- model_path = hf_hub_download(
494
- repo_id="briaai/RMBG-1.4",
495
- filename="model.pth",
496
- repo_type="model"
497
- )
498
- print(f"Model downloaded to: {model_path}")
499
-
500
- # Initialize model
501
- net = BriaRMBG()
502
-
503
- # Load state dict with error handling
504
- try:
505
- state_dict = torch.load(model_path, map_location=device)
506
-
507
- # Check if we need to adjust the state dict
508
- if 'outconv.weight' not in state_dict:
509
- print("Adjusting model state dict keys...")
510
- # The model might have different key names, let's check
511
- for key in list(state_dict.keys()):
512
- if 'outconv' in key:
513
- print(f"Found outconv key: {key}")
514
 
515
- net.load_state_dict(state_dict, strict=False)
516
- print("Model weights loaded successfully!")
517
- except Exception as e:
518
- print(f"Warning: Could not load all model weights: {e}")
519
- print("Attempting to load with strict=False...")
520
  try:
521
- net.load_state_dict(torch.load(model_path, map_location=device), strict=False)
522
- print("Model loaded with strict=False")
523
- except Exception as e2:
524
- print(f"Error loading model: {e2}")
525
  raise
526
 
527
- net.to(device)
528
- net.eval()
529
- print("Model loaded successfully!")
530
-
531
- # ============================================================================
532
- # IMAGE PROCESSING FUNCTION
533
- # ============================================================================
534
 
535
- def process_image(input_image):
536
  """
537
- Main function to process images and remove background.
538
- Returns RGBA image with transparent background.
539
  """
540
- if input_image is None:
541
- return None
542
 
543
- print(f"Processing image... Original type: {type(input_image)}")
544
-
545
- # Convert to PIL Image if needed
546
- if isinstance(input_image, np.ndarray):
547
- input_image = Image.fromarray(input_image)
548
- print("Converted numpy array to PIL Image")
549
-
550
- # Ensure RGB mode
551
- if input_image.mode != 'RGB':
552
- input_image = input_image.convert('RGB')
553
- print("Converted image to RGB")
554
-
555
- # Get original size
556
- orig_size = input_image.size
557
- print(f"Original image size: {orig_size}")
558
-
559
- # Preprocess
560
- model_input = preprocess_image(input_image, [1024, 1024])
561
- print(f"Model input shape: {model_input.shape}")
562
-
563
- # Run model
564
- with torch.no_grad():
565
- preds = net(model_input)[-1]
566
- print(f"Model output shape: {preds.shape}")
567
-
568
- # Postprocess with alpha channel support
569
- pred = postprocess_image(preds[0], orig_size)
570
- print(f"Postprocessed mask size: {pred.size}")
571
-
572
- # Convert to numpy array
573
- pred_np = np.array(pred)
574
- print(f"Mask values - Min: {pred_np.min()}, Max: {pred_np.max()}, Mean: {pred_np.mean():.2f}")
575
-
576
- # Create RGBA output with fixed transparency
577
- output = np.zeros((*pred_np.shape[:2], 4), dtype=np.uint8)
578
-
579
- # Copy RGB channels
580
- input_np = np.array(input_image)
581
- output[:, :, :3] = input_np
582
-
583
- # FIXED: Set alpha channel directly from prediction
584
- # The model outputs values 0-255, we use them directly
585
- output[:, :, 3] = pred_np
586
-
587
- print(f"Output shape: {output.shape}")
588
- print(f"Alpha channel - Min: {output[:,:,3].min()}, Max: {output[:,:,3].max()}")
589
-
590
- # Convert to PIL Image with RGBA
591
- output_image = Image.fromarray(output, mode='RGBA')
592
- print("Successfully created RGBA image with transparent background")
593
-
594
- return output_image
595
-
596
- # ============================================================================
597
- # GRADIO INTERFACE
598
- # ============================================================================
599
-
600
- # Custom CSS with MyAvatars.dk branding
601
- custom_css = """
602
- .logo-container {
603
- text-align: center;
604
- padding: 25px 0;
605
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
606
- border-radius: 15px;
607
- margin-bottom: 25px;
608
- box-shadow: 0 10px 30px rgba(0,0,0,0.2);
609
- }
610
- .logo-title {
611
- color: white;
612
- font-size: 3em;
613
- font-weight: bold;
614
- text-shadow: 3px 3px 6px rgba(0,0,0,0.3);
615
- margin-bottom: 10px;
616
- }
617
- .logo-subtitle {
618
- color: rgba(255,255,255,0.95);
619
- font-size: 1.3em;
620
- margin-top: 10px;
621
- font-weight: 300;
622
- }
623
- .powered-by {
624
- text-align: center;
625
- color: #666;
626
- font-size: 0.9em;
627
- margin-top: 20px;
628
- padding: 10px;
629
- background: rgba(0,0,0,0.05);
630
- border-radius: 5px;
631
- }
632
- .features-grid {
633
- display: grid;
634
- grid-template-columns: repeat(3, 1fr);
635
- gap: 20px;
636
- margin: 20px 0;
637
- }
638
- .feature-card {
639
- text-align: center;
640
- padding: 15px;
641
- background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
642
- border-radius: 10px;
643
- box-shadow: 0 5px 15px rgba(0,0,0,0.1);
644
- }
645
- .feature-icon {
646
- font-size: 2em;
647
- margin-bottom: 10px;
648
- }
649
- .feature-title {
650
- font-weight: bold;
651
- color: #333;
652
- margin-bottom: 5px;
653
- }
654
- .feature-desc {
655
- color: #666;
656
- font-size: 0.9em;
657
- }
658
- """
659
-
660
- print("Creating Gradio interface...")
661
-
662
- # Create Gradio interface with logo and enhanced UI
663
- with gr.Blocks(css=custom_css, title="MyAvatars.dk - AI Background Remover") as demo:
664
- # Logo header
665
- gr.HTML("""
666
- <div class="logo-container">
667
- <div class="logo-title">🎨 MyAvatars.dk</div>
668
- <div class="logo-subtitle">Professional AI-Powered Background Removal</div>
669
- </div>
670
- """)
671
-
672
- # Features grid
673
- gr.HTML("""
674
- <div class="features-grid">
675
- <div class="feature-card">
676
- <div class="feature-icon">⚡</div>
677
- <div class="feature-title">Lightning Fast</div>
678
- <div class="feature-desc">Process images in seconds</div>
679
- </div>
680
- <div class="feature-card">
681
- <div class="feature-icon">🎯</div>
682
- <div class="feature-title">High Precision</div>
683
- <div class="feature-desc">AI-powered edge detection</div>
684
- </div>
685
- <div class="feature-card">
686
- <div class="feature-icon">🔒</div>
687
- <div class="feature-title">Privacy First</div>
688
- <div class="feature-desc">Images processed locally</div>
689
- </div>
690
- </div>
691
- """)
692
-
693
- gr.Markdown("## Remove backgrounds instantly with state-of-the-art AI")
694
- gr.Markdown("Upload any image and get a perfect transparent background version. Ideal for avatars, product photos, and creative projects!")
695
-
696
- with gr.Row():
697
- with gr.Column():
698
- input_image = gr.Image(
699
- label="📤 Upload Image",
700
- type="pil",
701
- height=400,
702
- elem_id="input_image"
703
- )
704
-
705
- with gr.Row():
706
- clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")
707
- process_btn = gr.Button("✨ Remove Background", variant="primary", size="lg")
708
 
709
- with gr.Column():
710
- output_image = gr.Image(
711
- label="📥 Result (Transparent Background)",
712
- type="pil",
713
- height=400,
714
- image_mode="RGBA",
715
- elem_id="output_image"
716
- )
717
 
718
- gr.Markdown("""
719
- ### 💡 Tips for best results:
720
- - Use high-quality images with clear subjects
721
- - Ensure good contrast between subject and background
722
- - Works best with people, objects, and products
723
- """)
724
-
725
- # Examples section (commented out if no example images available)
726
- # gr.Markdown("### 🖼️ Try with examples:")
727
- # gr.Examples(
728
- # examples=[
729
- # ["./input.jpg"]
730
- # ],
731
- # inputs=input_image,
732
- # outputs=output_image,
733
- # fn=process_image,
734
- # cache_examples=True
735
- # )
736
-
737
- # Footer with version info
738
- gr.HTML("""
739
- <div class="powered-by">
740
- <strong>Powered by BRIA RMBG 1.4</strong> | Version 2.2 | Cache Buster Active<br>
741
- <small>© 2024 MyAvatars.dk - Professional Avatar Solutions</small>
742
- </div>
743
- """)
744
 
745
- # Instructions
746
- with gr.Accordion("📖 How to use", open=False):
747
  gr.Markdown("""
748
- 1. **Upload an image** using the upload area or drag & drop
749
- 2. **Click "Remove Background"** to process the image
750
- 3. **Download the result** with transparent background
751
- 4. **Use the result** in your projects, presentations, or as avatars
752
 
753
- **Supported formats:** JPG, PNG, WebP
754
- **Max resolution:** 4096x4096 pixels
755
- **Output format:** PNG with transparency (RGBA)
756
  """)
757
-
758
- # Event handlers
759
- process_btn.click(
760
- fn=process_image,
761
- inputs=[input_image],
762
- outputs=[output_image]
763
- )
764
-
765
- clear_btn.click(
766
- fn=lambda: (None, None),
767
- inputs=[],
768
- outputs=[input_image, output_image]
769
- )
770
-
771
- # Add keyboard shortcut
772
- input_image.change(
773
- fn=lambda x: gr.update(interactive=x is not None),
774
- inputs=[input_image],
775
- outputs=[process_btn]
776
- )
777
-
778
- print("=" * 80)
779
- print("MYAVATARS.DK BACKGROUND REMOVER - VERSION 2.2")
780
- print("=" * 80)
781
- print("Application initialized successfully!")
782
- print("Features enabled:")
783
- print(" ✓ Transparent background removal")
784
- print(" ✓ High-quality edge detection")
785
- print(" ✓ RGBA output support")
786
- print(" ✓ CPU-optimized processing")
787
- print(" ✓ Professional UI with branding")
788
- print("=" * 80)
789
-
790
- # Launch the application - HuggingFace Spaces compatible
791
  if __name__ == "__main__":
792
- print("Launching Gradio interface...")
793
- import sys
794
- if "google.colab" in sys.modules:
795
- demo.launch(debug=True)
796
- else:
797
- # For HuggingFace Spaces
798
- demo.launch(server_name="0.0.0.0", server_port=7860)
1
+ #!/usr/bin/env python3
2
+ # ========================= PRE-IMPORT ENV GUARDS =========================
3
+ import os
4
+ os.environ.pop("OMP_NUM_THREADS", None)
5
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
6
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
7
+ os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1")
8
+ os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
9
+
10
+ # ========================= IMPORTS =========================
11
+ import gc
12
+ import sys
13
+ import cv2
14
  import torch
15
  import numpy as np
16
+ import gradio as gr
17
+ import tempfile
18
+ import time
19
+ from pathlib import Path
20
+ import logging
21
+ import traceback
22
+ from datetime import datetime
23
+ import psutil
24
+ import warnings
25
+ warnings.filterwarnings("ignore")
26
+
27
+ # Import the properly implemented functions from utilities
28
+ from utilities import (
29
+ segment_person_hq,
30
+ refine_mask_hq,
31
+ replace_background_hq,
32
+ load_background_image,
33
+ resize_background_to_match,
34
+ apply_temporal_smoothing,
35
+ smooth_edges,
36
+ estimate_foreground
37
+ )
38
+
39
+ # Import two-stage processor for advanced mode
40
+ from two_stage_processor import TwoStageProcessor
41
+
42
+ # Import UI components
43
+ from ui_components import create_ui, get_example_videos, get_example_backgrounds
44
+
45
+ # ========================= LOGGING SETUP =========================
46
+ logging.basicConfig(
47
+ level=logging.INFO,
48
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
49
+ )
50
+ logger = logging.getLogger(__name__)
51
+
52
+ # ========================= GPU/DEVICE SETUP =========================
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ logger.info(f"Using device: {device}")
55
+
56
+ if device.type == "cuda":
57
+ torch.cuda.empty_cache()
58
+ # Optimize CUDA settings for memory efficiency
59
+ torch.backends.cudnn.benchmark = False
60
+ torch.backends.cudnn.deterministic = True
61
+ torch.cuda.set_per_process_memory_fraction(0.8) # Limit to 80% of VRAM
62
+
63
+ # ========================= GLOBAL MODELS =========================
64
+ # Models will be loaded on demand to save RAM
65
+ sam2_model = None
66
+ matta_model = None
67
+ two_stage_processor = None
68
+
69
+ # ========================= MODEL LOADING =========================
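These globals stay None until a job needs them; a usage sketch of the load/release cycle that the helpers defined below are meant to support (illustrative only):

    load_models_on_demand(use_two_stage=True)   # populate only what the selected mode needs
    try:
        ...                                     # run frame processing here
    finally:
        clear_models_from_memory()              # hand RAM/VRAM back before the next job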
70
+ def load_models_on_demand(use_two_stage=False):
71
+ """Load models only when needed, with proper memory management"""
72
+ global sam2_model, matta_model, two_stage_processor
73
74
  try:
75
+ # Clear any existing models first
76
+ clear_models_from_memory()
77
+
78
+ if use_two_stage and two_stage_processor is None:
79
+ logger.info("Loading Two-Stage Processor (SAM2 + MattA)...")
80
+ two_stage_processor = TwoStageProcessor(device=device)
81
+ logger.info("Two-Stage Processor loaded successfully")
82
+ elif not use_two_stage:
83
+ # Load individual models for single-stage processing
84
+ if sam2_model is None:
85
+ logger.info("Loading SAM2 model...")
86
+ # This should be imported from your SAM2 implementation
87
+ from sam2_integration import load_sam2_model
88
+ sam2_model = load_sam2_model(device=device)
89
+ logger.info("SAM2 model loaded")
90
+
91
+ if matta_model is None:
92
+ logger.info("Loading MattingAnything model...")
93
+ # This should be imported from your MattA implementation
94
+ from matta_integration import load_matta_model
95
+ matta_model = load_matta_model(device=device)
96
+ logger.info("MattingAnything model loaded")
97
+
98
+ # Force garbage collection after loading
99
+ gc.collect()
100
+ if device.type == "cuda":
101
+ torch.cuda.empty_cache()
102
+
103
+ except Exception as e:
104
+ logger.error(f"Error loading models: {str(e)}")
105
  raise
106
 
107
+ def clear_models_from_memory():
108
+ """Clear models from memory to free up RAM"""
109
+ global sam2_model, matta_model, two_stage_processor
110
+
111
+ if sam2_model is not None:
112
+ del sam2_model
113
+ sam2_model = None
114
+
115
+ if matta_model is not None:
116
+ del matta_model
117
+ matta_model = None
118
+
119
+ if two_stage_processor is not None:
120
+ del two_stage_processor
121
+ two_stage_processor = None
122
+
123
+ gc.collect()
124
+ if device.type == "cuda":
125
+ torch.cuda.empty_cache()
126
+
127
+ # ========================= MEMORY MONITORING =========================
128
+ def log_memory_usage(stage=""):
129
+ """Log current memory usage"""
130
+ process = psutil.Process()
131
+ mem_info = process.memory_info()
132
+ ram_usage = mem_info.rss / 1024 / 1024 / 1024 # GB
133
+
134
+ if device.type == "cuda":
135
+ vram_usage = torch.cuda.memory_allocated() / 1024 / 1024 / 1024 # GB
136
+ vram_reserved = torch.cuda.memory_reserved() / 1024 / 1024 / 1024 # GB
137
+ logger.info(f"[{stage}] RAM: {ram_usage:.2f}GB | VRAM: {vram_usage:.2f}GB (reserved: {vram_reserved:.2f}GB)")
138
+ else:
139
+ logger.info(f"[{stage}] RAM: {ram_usage:.2f}GB")
140
 
141
+ # ========================= PROGRESS TRACKING =========================
142
+ def write_progress_info(info_dict):
143
+ """Write formatted progress information to temp file for UI display"""
144
+ try:
145
+ progress_file = "/tmp/processing_info.txt"
146
+ with open(progress_file, "w") as f:
147
+ if "error" in info_dict:
148
+ f.write(f"❌ ERROR\n{info_dict['error']}\n")
149
+ elif "complete" in info_dict:
150
+ f.write(f"✅ COMPLETE\n")
151
+ f.write(f"Total Frames: {info_dict.get('total_frames', 'N/A')}\n")
152
+ f.write(f"Processing Time: {info_dict.get('time', 'N/A')}\n")
153
+ f.write(f"Average FPS: {info_dict.get('fps', 'N/A')}\n")
154
+ f.write(f"Resolution: {info_dict.get('resolution', 'N/A')}\n")
155
+ f.write(f"Background: {info_dict.get('background', 'N/A')}\n")
156
+ else:
157
+ f.write(f"📊 PROCESSING STATUS\n")
158
+ f.write(f"━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
159
+ f.write(f"🎬 Frame {info_dict.get('current_frame', 0)}/{info_dict.get('total_frames', 0)}\n")
160
+ f.write(f"⏱️ Elapsed: {info_dict.get('elapsed', '0s')}\n")
161
+ f.write(f"⚡ Speed: {info_dict.get('speed', '0')} fps\n")
162
+ f.write(f"🎯 ETA: {info_dict.get('eta', 'calculating...')}\n")
163
+ f.write(f"━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
164
+ f.write(f"📈 Progress: {info_dict.get('progress', 0):.1f}%\n")
165
+ except Exception as e:
166
+ logger.error(f"Error writing progress: {e}")
167
+
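For reference, the two dictionary shapes this helper formats, matching the keys read above; the values are illustrative:

    write_progress_info({                       # mid-run update
        'current_frame': 120, 'total_frames': 600,
        'elapsed': '45.0s', 'speed': '2.7', 'eta': '180s', 'progress': 20.0,
    })
    write_progress_info({                       # final summary
        'complete': True, 'total_frames': 600, 'time': '230.0s',
        'fps': '2.6', 'resolution': '1280x720', 'background': 'Image',
    })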
168
+ # ========================= MAIN PROCESSING FUNCTION =========================
169
+ def process_video(
170
+ input_video,
171
+ background_image,
172
+ use_two_stage=False,
173
+ use_mask_refinement=True,
174
+ use_temporal_smoothing=True,
175
+ mask_blur=5,
176
+ edge_smoothing=5,
177
+ background_type="Color",
178
+ background_color="#00FF00",
179
+ progress=gr.Progress()
180
+ ):
181
  """
182
+ Main video processing function with proper SAM2+MattA integration
 
183
  """
184
+ temp_dir = None
185
+ cap = None
186
+ out = None
187
+ start_time = time.time()
188
 
189
+ try:
190
+ # Initial setup
191
+ logger.info("Starting video processing...")
192
+ log_memory_usage("Start")
193
+
194
+ # Validate inputs
195
+ if input_video is None:
196
+ raise ValueError("No input video provided")
197
+
198
+ # Load models based on processing mode
199
+ load_models_on_demand(use_two_stage=use_two_stage)
200
+ log_memory_usage("Models Loaded")
201
+
202
+ # Setup video capture
203
+ cap = cv2.VideoCapture(input_video)
204
+ if not cap.isOpened():
205
+ raise ValueError(f"Failed to open video: {input_video}")
206
+
207
+ # Get video properties
208
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
209
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
210
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
211
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
212
+
213
+ logger.info(f"Video info: {width}x{height}, {fps} fps, {total_frames} frames")
214
+
215
+ # Prepare background
216
+ if background_type == "Color":
217
+ background = np.full((height, width, 3),
218
+ tuple(int(background_color[i:i+2], 16) for i in (5, 3, 1)),
219
+ dtype=np.uint8)
220
+ elif background_type == "Image" and background_image is not None:
221
+ background = load_background_image(background_image)
222
+ background = resize_background_to_match(background, (width, height))
223
+ elif background_type == "Blur":
224
+ # Will be handled per frame
225
+ background = None
226
+ else:
227
+ background = np.full((height, width, 3), (0, 255, 0), dtype=np.uint8)
228
+
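The index order (5, 3, 1) reads a "#RRGGBB" string back to front, so the solid background comes out in OpenCV's BGR channel order; a quick check with an assumed colour value:

    color = '#FF8000'                                         # R=255, G=128, B=0
    bgr = tuple(int(color[i:i + 2], 16) for i in (5, 3, 1))
    assert bgr == (0, 128, 255)                               # (B, G, R), ready for cv2 frames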
229
+ # Setup output video
230
+ temp_dir = tempfile.mkdtemp()
231
+ output_path = os.path.join(temp_dir, "output_video.mp4")
232
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
233
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
234
+
235
+ # Process frames
236
+ frame_idx = 0
237
+ processed_frames = []
238
+ masks_history = [] # For temporal smoothing
239
+
240
+ # Batch processing for memory efficiency
241
+ BATCH_SIZE = 10 if device.type == "cuda" else 5
242
+ frame_batch = []
243
+
244
+ while True:
245
+ ret, frame = cap.read()
246
+ if not ret:
247
+ break
 
248
 
249
+ frame_batch.append(frame)
250
 
251
+ # Process batch when full, or once every read frame has been queued
252
+ if len(frame_batch) == BATCH_SIZE or frame_idx + len(frame_batch) >= total_frames:
253
+
254
+ for batch_frame in frame_batch:
255
+ # Update progress
256
+ progress(frame_idx / total_frames, f"Processing frame {frame_idx}/{total_frames}")
257
+
258
+ # Calculate and write detailed progress info
259
+ elapsed_time = time.time() - start_time
260
+ if frame_idx > 0:
261
+ fps_current = frame_idx / elapsed_time
262
+ eta = (total_frames - frame_idx) / fps_current
263
+ write_progress_info({
264
+ 'current_frame': frame_idx,
265
+ 'total_frames': total_frames,
266
+ 'elapsed': f"{elapsed_time:.1f}s",
267
+ 'speed': f"{fps_current:.1f}",
268
+ 'eta': f"{eta:.0f}s",
269
+ 'progress': (frame_idx / total_frames) * 100
270
+ })
271
+
272
+ # Process frame based on mode
273
+ if use_two_stage:
274
+ # Use integrated two-stage processor
275
+ processed_frame, mask = two_stage_processor.process_frame(
276
+ batch_frame,
277
+ background if background is not None else batch_frame,
278
+ use_refinement=use_mask_refinement,
279
+ mask_blur=mask_blur
280
+ )
281
+ else:
282
+ # Use utilities functions (properly implemented with transparency fix)
283
+ # Step 1: Segment person using SAM2
284
+ mask = segment_person_hq(batch_frame, sam2_model)
285
+
286
+ # Step 2: Refine mask using MattA if enabled
287
+ if use_mask_refinement and matta_model is not None:
288
+ mask = refine_mask_hq(batch_frame, mask, matta_model)
289
+
290
+ # Step 3: Apply temporal smoothing if enabled
291
+ if use_temporal_smoothing and len(masks_history) > 0:
292
+ mask = apply_temporal_smoothing(mask, masks_history, window_size=5)
293
+
294
+ # Store mask for temporal smoothing
295
+ masks_history.append(mask)
296
+ if len(masks_history) > 10: # Keep only recent masks
297
+ masks_history.pop(0)
298
+
299
+ # Step 4: Apply edge smoothing
300
+ if edge_smoothing > 0:
301
+ mask = smooth_edges(mask, edge_smoothing)
302
+
303
+ # Step 5: Handle background
304
+ if background_type == "Blur":
305
+ background_frame = cv2.GaussianBlur(batch_frame, (21, 21), 0)
306
+ else:
307
+ background_frame = background
308
+
309
+ # Step 6: Replace background with proper alpha handling
310
+ processed_frame = replace_background_hq(
311
+ batch_frame,
312
+ mask,
313
+ background_frame
314
+ )
315
+
316
+ # Write frame
317
+ out.write(processed_frame)
318
+ processed_frames.append(processed_frame)
319
+ frame_idx += 1
320
+
321
+ # Memory management - clear every 100 frames
322
+ if frame_idx % 100 == 0:
323
+ gc.collect()
324
+ if device.type == "cuda":
325
+ torch.cuda.empty_cache()
326
+ log_memory_usage(f"Frame {frame_idx}")
327
+
328
+ # Clear batch
329
+ frame_batch = []
330
+
331
+ # Finalize
332
+ cap.release()
333
+ out.release()
334
+
335
+ # Write completion info
336
+ total_time = time.time() - start_time
337
+ avg_fps = total_frames / total_time if total_time > 0 else 0
338
+ write_progress_info({
339
+ 'complete': True,
340
+ 'total_frames': total_frames,
341
+ 'time': f"{total_time:.1f}s",
342
+ 'fps': f"{avg_fps:.1f}",
343
+ 'resolution': f"{width}x{height}",
344
+ 'background': background_type
345
+ })
346
+
347
+ logger.info(f"Processing complete: {total_frames} frames in {total_time:.1f}s ({avg_fps:.1f} fps)")
348
+ log_memory_usage("Complete")
349
+
350
+ return output_path
351
+
352
+ except Exception as e:
353
+ logger.error(f"Processing error: {str(e)}\n{traceback.format_exc()}")
354
+ write_progress_info({'error': str(e)})
355
+ raise gr.Error(f"Processing failed: {str(e)}")
356
+
357
+ finally:
358
+ # Cleanup
359
+ if cap is not None:
360
+ cap.release()
361
+ if out is not None:
362
+ out.release()
363
+
364
+ # Clear models to free memory
365
+ clear_models_from_memory()
366
+
367
+ # Final garbage collection
368
+ gc.collect()
369
+ if device.type == "cuda":
370
+ torch.cuda.empty_cache()
371
+
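In summary, each frame takes one of two paths through the function above. A condensed sketch using the same helpers, with the smoothing options and error handling omitted:

    def composite_one_frame(frame, background, use_two_stage):
        if use_two_stage:
            # SAM2 segmentation + MattingAnything refinement behind one call
            out_frame, _mask = two_stage_processor.process_frame(frame, background)
            return out_frame
        mask = segment_person_hq(frame, sam2_model)        # SAM2: coarse person mask
        mask = refine_mask_hq(frame, mask, matta_model)    # MattingAnything: edge refinement
        return replace_background_hq(frame, mask, background)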
372
+ # ========================= GRADIO APP =========================
373
+ def create_app():
374
+ """Create and configure the Gradio application"""
375
 
376
+ with gr.Blocks(title="Video Background Replacement - SAM2+MattA", theme=gr.themes.Soft()) as app:
 
377
  gr.Markdown("""
378
+ # 🎬 Video Background Replacement
379
+ ### Powered by SAM2 + MattingAnything
380
+
381
+ Upload a video and replace the background with:
382
+ - 🎨 Solid colors
383
+ - 🖼️ Custom images
384
+ - 🌫️ Blurred background
385
 
386
+ **Two-Stage Mode**: Combines SAM2 segmentation with MattA refinement for best quality
 
 
387
  """)
388
+
389
+ with gr.Tabs():
390
+ with gr.TabItem("🎥 Process Video"):
391
+ with gr.Row():
392
+ with gr.Column(scale=1):
393
+ input_video = gr.Video(label="Input Video", height=300)
394
+
395
+ with gr.Accordion("⚙️ Processing Options", open=True):
396
+ use_two_stage = gr.Checkbox(
397
+ label="Use Two-Stage Processing (SAM2→MattA)",
398
+ value=True,
399
+ info="Better quality but slower"
400
+ )
401
+ use_mask_refinement = gr.Checkbox(
402
+ label="Refine Masks",
403
+ value=True,
404
+ info="Use MattA for better edges"
405
+ )
406
+ use_temporal_smoothing = gr.Checkbox(
407
+ label="Temporal Smoothing",
408
+ value=True,
409
+ info="Reduce flickering between frames"
410
+ )
411
+ mask_blur = gr.Slider(
412
+ minimum=0,
413
+ maximum=21,
414
+ value=5,
415
+ step=2,
416
+ label="Mask Blur"
417
+ )
418
+ edge_smoothing = gr.Slider(
419
+ minimum=0,
420
+ maximum=21,
421
+ value=5,
422
+ step=2,
423
+ label="Edge Smoothing"
424
+ )
425
+
426
+ with gr.Accordion("🎨 Background Options", open=True):
427
+ background_type = gr.Radio(
428
+ choices=["Color", "Image", "Blur"],
429
+ value="Color",
430
+ label="Background Type"
431
+ )
432
+ background_color = gr.ColorPicker(
433
+ label="Background Color",
434
+ value="#00FF00",
435
+ visible=True
436
+ )
437
+ background_image = gr.Image(
438
+ label="Background Image",
439
+ type="filepath",
440
+ visible=False
441
+ )
442
+
443
+ # Show/hide based on background type
444
+ def update_background_inputs(bg_type):
445
+ return (
446
+ gr.update(visible=bg_type == "Color"),
447
+ gr.update(visible=bg_type == "Image")
448
+ )
449
+
450
+ background_type.change(
451
+ update_background_inputs,
452
+ inputs=[background_type],
453
+ outputs=[background_color, background_image]
454
+ )
455
+
456
+ with gr.Column(scale=1):
457
+ output_video = gr.Video(label="Output Video", height=300)
458
+
459
+ process_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")
460
+
461
+ processing_info = gr.Textbox(
462
+ label="📊 Processing Info",
463
+ lines=10,
464
+ max_lines=15,
465
+ interactive=False,
466
+ placeholder="Processing status will appear here...",
467
+ elem_id="processing-info"
468
+ )
469
+
470
+ # Connect processing
471
+ process_btn.click(
472
+ fn=process_video,
473
+ inputs=[
474
+ input_video,
475
+ background_image,
476
+ use_two_stage,
477
+ use_mask_refinement,
478
+ use_temporal_smoothing,
479
+ mask_blur,
480
+ edge_smoothing,
481
+ background_type,
482
+ background_color
483
+ ],
484
+ outputs=[output_video]
485
+ )
486
+
487
+ with gr.TabItem("📚 Examples"):
488
+ gr.Examples(
489
+ examples=get_example_videos(),
490
+ inputs=input_video,
491
+ label="Sample Videos"
492
+ )
493
+ gr.Examples(
494
+ examples=get_example_backgrounds(),
495
+ inputs=background_image,
496
+ label="Sample Backgrounds"
497
+ )
498
+
499
+ with gr.TabItem("ℹ️ About"):
500
+ gr.Markdown("""
501
+ ### Technology Stack
502
+
503
+ - **SAM2**: Segment Anything Model 2 for accurate person segmentation
504
+ - **MattingAnything**: Advanced alpha matting for refined edges
505
+ - **Two-Stage Processing**: Combines both models for optimal quality
506
+
507
+ ### Tips for Best Results
508
+
509
+ 1. **Use Two-Stage Mode** for highest quality output
510
+ 2. **Enable Temporal Smoothing** to reduce flickering
511
+ 3. **Adjust Edge Smoothing** for softer transitions
512
+ 4. **High contrast backgrounds** work best
513
+
514
+ ### Performance Notes
515
+
516
+ - Processing speed depends on video resolution and length
517
+ - GPU recommended for faster processing
518
+ - Two-stage mode is slower but produces better results
519
+ """)
520
+
521
+ return app
522
+
523
+ # ========================= MAIN ENTRY POINT =========================
524
  if __name__ == "__main__":
525
+ try:
526
+ # Create and launch app
527
+ app = create_app()
528
+
529
+ # Configure for HuggingFace Spaces
530
+ app.queue(max_size=5)
531
+ app.launch(
532
+ server_name="0.0.0.0",
533
+ server_port=7860,
534
+ share=False,
535
+ debug=False,
536
+ show_error=True
537
+ )
538
+
539
+ except Exception as e:
540
+ logger.error(f"Failed to start application: {str(e)}")
541
+ traceback.print_exc()
542
+ sys.exit(1)