amirgame197 commited on
Commit
20593ca
·
verified ·
1 Parent(s): 6b518fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -26
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import spaces
2
  import gradio as gr
3
  import cv2
@@ -7,13 +10,17 @@ import random
7
  from PIL import Image
8
  import torch
9
  import re
 
 
 
 
10
 
11
  torch.jit.script = lambda f: f
12
 
13
  from transparent_background import Remover
14
 
15
  @spaces.GPU(duration=90)
16
- def doo(video, color, mode, progress=gr.Progress()):
17
  print(str(color))
18
  if str(color).startswith('#'):
19
  color = color.lstrip('#')
@@ -22,56 +29,137 @@ def doo(video, color, mode, progress=gr.Progress()):
22
  elif str(color).startswith('rgba'):
23
  rgba_match = re.match(r'rgba\(([\d.]+), ([\d.]+), ([\d.]+), [\d.]+\)', color)
24
  if rgba_match:
25
- r, g, b = rgba_match.groups() # Extract r, g, b values
26
  color = str([int(float(r)), int(float(g)), int(float(b))])
27
- print(color)
28
  if mode == 'Fast':
29
  remover = Remover(mode='fast')
30
  else:
31
  remover = Remover()
32
 
33
  cap = cv2.VideoCapture(video)
34
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Get total frames
 
35
  writer = None
36
  tmpname = random.randint(111111111, 999999999)
37
  processed_frames = 0
38
  start_time = time.time()
39
 
40
- while cap.isOpened():
41
- ret, frame = cap.read()
42
 
43
- if ret is False:
44
- break
 
45
 
46
- if time.time() - start_time >= 20 * 60 - 5:
47
- print("GPU Timeout is coming")
48
- cap.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  writer.release()
50
- return str(tmpname) + '.mp4'
51
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
52
- img = Image.fromarray(frame).convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- if writer is None:
55
- writer = cv2.VideoWriter(str(tmpname) + '.mp4', cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), img.size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- processed_frames += 1
58
- print(f"Processing frame {processed_frames}")
59
- progress(processed_frames / total_frames, desc=f"Processing frame {processed_frames}/{total_frames}")
60
- out = remover.process(img, type=color)
61
- writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB))
62
 
63
- cap.release()
64
- writer.release()
65
- return str(tmpname) + '.mp4'
 
 
 
 
 
66
 
67
  title = "🎞️ Video Background Removal Tool 🎥"
68
- description = """*Please note that if your video file is long (has a high number of frames), there is a chance that processing break due to GPU timeout. In this case, consider trying Fast mode."""
69
 
70
  examples = [['./input.mp4']]
71
 
72
  iface = gr.Interface(
73
  fn=doo,
74
- inputs=["video", gr.ColorPicker(label="Background color", value="#00FF00"), gr.components.Radio(['Normal', 'Fast'], label='Select mode', value='Normal', info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.')],
 
 
 
 
 
75
  outputs="video",
76
  examples=examples,
77
  title=title,
 
1
+ #testing webm hello hello
2
+
3
+
4
  import spaces
5
  import gradio as gr
6
  import cv2
 
10
  from PIL import Image
11
  import torch
12
  import re
13
+ import os
14
+ import shutil
15
+ import subprocess
16
+ import tempfile
17
 
18
  torch.jit.script = lambda f: f
19
 
20
  from transparent_background import Remover
21
 
22
  @spaces.GPU(duration=90)
23
+ def doo(video, color, mode, out_format, progress=gr.Progress()):
24
  print(str(color))
25
  if str(color).startswith('#'):
26
  color = color.lstrip('#')
 
29
  elif str(color).startswith('rgba'):
30
  rgba_match = re.match(r'rgba\(([\d.]+), ([\d.]+), ([\d.]+), [\d.]+\)', color)
31
  if rgba_match:
32
+ r, g, b = rgba_match.groups()
33
  color = str([int(float(r)), int(float(g)), int(float(b))])
34
+ print("Parsed color:", color)
35
  if mode == 'Fast':
36
  remover = Remover(mode='fast')
37
  else:
38
  remover = Remover()
39
 
40
  cap = cv2.VideoCapture(video)
41
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
42
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
43
  writer = None
44
  tmpname = random.randint(111111111, 999999999)
45
  processed_frames = 0
46
  start_time = time.time()
47
 
48
+ mp4_path = str(tmpname) + '.mp4'
49
+ webm_path = str(tmpname) + '.webm'
50
 
51
+ if out_format == 'mp4':
52
+ while cap.isOpened():
53
+ ret, frame = cap.read()
54
 
55
+ if ret is False:
56
+ break
57
+
58
+ if time.time() - start_time >= 20 * 60 - 5:
59
+ print("GPU Timeout is coming")
60
+ cap.release()
61
+ if writer is not None:
62
+ writer.release()
63
+ return mp4_path
64
+
65
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66
+ img = Image.fromarray(frame).convert('RGB')
67
+
68
+ if writer is None:
69
+ writer = cv2.VideoWriter(mp4_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size)
70
+
71
+ processed_frames += 1
72
+ print(f"Processing frame {processed_frames}")
73
+ progress(processed_frames / total_frames, desc=f"Processing frame {processed_frames}/{total_frames}")
74
+
75
+ out = remover.process(img, type=color)
76
+
77
+ frame_bgr = cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR)
78
+ writer.write(frame_bgr)
79
+
80
+ cap.release()
81
+ if writer is not None:
82
  writer.release()
83
+ return mp4_path
84
+
85
+ else:
86
+ temp_dir = tempfile.mkdtemp(prefix=f"tb_{tmpname}_")
87
+ try:
88
+ frame_idx = 0
89
+ while cap.isOpened():
90
+ ret, frame = cap.read()
91
+
92
+ if ret is False:
93
+ break
94
+
95
+ if time.time() - start_time >= 20 * 60 - 5:
96
+ print("GPU Timeout is coming")
97
+ cap.release()
98
+ # cleanup
99
+ shutil.rmtree(temp_dir, ignore_errors=True)
100
+ return webm_path
101
+
102
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
103
+ img = Image.fromarray(frame).convert('RGB')
104
+
105
+ processed_frames += 1
106
+ frame_idx += 1
107
+ print(f"Processing frame {processed_frames}")
108
+ progress(processed_frames / total_frames, desc=f"Processing frame {processed_frames}/{total_frames}")
109
+
110
+ out = remover.process(img, type='rgba')
111
+ out = out.convert('RGBA')
112
+
113
+ frame_name = os.path.join(temp_dir, f"frame_{frame_idx:06d}.png")
114
+ out.save(frame_name, 'PNG')
115
+
116
+ cap.release()
117
 
118
+ fr_str = str(int(round(fps))) if fps > 0 else "25"
119
+ pattern = os.path.join(temp_dir, "frame_%06d.png")
120
+ ffmpeg_cmd = [
121
+ "ffmpeg", "-y",
122
+ "-framerate", fr_str,
123
+ "-i", pattern,
124
+ "-i", str(video),
125
+ "-map", "0:v",
126
+ "-map", "1:a?",
127
+ "-c:v", "libvpx-vp9",
128
+ "-pix_fmt", "yuva420p",
129
+ "-auto-alt-ref", "0",
130
+ "-metadata:s:v:0", "alpha_mode=1",
131
+ "-c:a", "libopus",
132
+ "-shortest",
133
+ webm_path
134
+ ]
135
+ print("Running ffmpeg:", " ".join(ffmpeg_cmd))
136
+ subprocess.run(ffmpeg_cmd, check=True)
137
 
138
+ shutil.rmtree(temp_dir, ignore_errors=True)
139
+ return webm_path
 
 
 
140
 
141
+ except subprocess.CalledProcessError as e:
142
+ print("ffmpeg failed:", e)
143
+ shutil.rmtree(temp_dir, ignore_errors=True)
144
+ return webm_path
145
+ except Exception as e:
146
+ print("Error during processing:", e)
147
+ shutil.rmtree(temp_dir, ignore_errors=True)
148
+ raise
149
 
150
  title = "🎞️ Video Background Removal Tool 🎥"
151
+ description = """*Please note that if your video file is long (has a high number of frames), there is a chance that processing break due to GPU timeout. In this case, consider trying Fast mode.*"""
152
 
153
  examples = [['./input.mp4']]
154
 
155
  iface = gr.Interface(
156
  fn=doo,
157
+ inputs=[
158
+ "video",
159
+ gr.ColorPicker(label="Background color", value="#00FF00"),
160
+ gr.components.Radio(['Normal', 'Fast'], label='Select mode', value='Normal', info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.'),
161
+ gr.components.Radio(['mp4', 'webm'], label='Output format', value='mp4')
162
+ ],
163
  outputs="video",
164
  examples=examples,
165
  title=title,