KingNish commited on
Commit
a8fd4c9
·
verified ·
1 Parent(s): 6764406

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -39
app.py CHANGED
@@ -13,20 +13,17 @@ import tempfile
13
  import uuid
14
  import time
15
  import threading
 
16
 
17
  torch.set_float32_matmul_precision("medium")
18
-
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Load both BiRefNet models
22
  birefnet = AutoModelForImageSegmentation.from_pretrained(
23
- "ZhengPeng7/BiRefNet", trust_remote_code=True
24
- )
25
  birefnet.to(device)
26
-
27
  birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
28
- "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
29
- )
30
  birefnet_lite.to(device)
31
 
32
  transform_image = transforms.Compose(
@@ -37,7 +34,6 @@ transform_image = transforms.Compose(
37
  ]
38
  )
39
 
40
-
41
  # Function to delete files older than 10 minutes in the temp directory
42
  def cleanup_temp_files():
43
  while True:
@@ -55,11 +51,29 @@ def cleanup_temp_files():
55
  print(f"Error deleting file {filepath}: {e}")
56
  time.sleep(60) # Check every minute
57
 
58
-
59
  # Start the cleanup thread
60
  cleanup_thread = threading.Thread(target=cleanup_temp_files, daemon=True)
61
  cleanup_thread.start()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  @spaces.GPU
65
  def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=True):
@@ -77,7 +91,7 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
77
  audio = video.audio
78
 
79
  # Extract frames at the specified FPS
80
- frames = video.iter_frames(fps=fps)
81
 
82
  # Process each frame for background removal
83
  processed_frames = []
@@ -96,29 +110,14 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
96
 
97
  bg_frame_index = 0 # Initialize background frame index
98
 
99
- for i, frame in enumerate(frames):
100
- pil_image = Image.fromarray(frame)
101
- if bg_type == "Color":
102
- processed_image = process(pil_image, color, fast_mode)
103
- elif bg_type == "Image":
104
- processed_image = process(pil_image, bg_image, fast_mode)
105
- elif bg_type == "Video":
106
- if video_handling == "slow_down":
107
- background_frame = background_frames[bg_frame_index % len(background_frames)]
108
- bg_frame_index += 1
109
- background_image = Image.fromarray(background_frame)
110
- processed_image = process(pil_image, background_image, fast_mode)
111
- else: # video_handling == "loop"
112
- background_frame = background_frames[bg_frame_index % len(background_frames)]
113
- bg_frame_index += 1
114
- background_image = Image.fromarray(background_frame)
115
- processed_image = process(pil_image, background_image, fast_mode)
116
- else:
117
- processed_image = pil_image # Default to original image if no background is selected
118
-
119
- processed_frames.append(np.array(processed_image))
120
- elapsed_time = time.time() - start_time
121
- yield processed_image, None, f"Processing frame {i+1}... Elapsed time: {elapsed_time:.2f} seconds"
122
 
123
  # Create a new video from the processed frames
124
  processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
@@ -135,8 +134,9 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
135
 
136
  elapsed_time = time.time() - start_time
137
  yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
 
138
  # Return the path to the temporary file
139
- yield processed_image, temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
140
 
141
  except Exception as e:
142
  print(f"Error: {e}")
@@ -144,7 +144,6 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
144
  yield gr.update(visible=False), gr.update(visible=True), f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
145
  yield None, f"Error processing video: {e}", f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
146
 
147
-
148
  def process(image, bg, fast_mode=False):
149
  image_size = image.size
150
  input_images = transform_image(image).unsqueeze(0).to("cuda")
@@ -169,12 +168,10 @@ def process(image, bg, fast_mode=False):
169
 
170
  # Composite the image onto the background using the mask
171
  image = Image.composite(image, background, mask)
172
-
173
  return image
174
 
175
-
176
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
177
- gr.Markdown("# Video Background Remover & Changer\n### You can replace image background with any color, image or video.\nNOTE: As this Space is running on ZERO GPU it has limit. It can handle approx 200frmaes at once. So, if you have big video than use small chunks or Duplicate this space.")
178
  with gr.Row():
179
  in_video = gr.Video(label="Input Video", interactive=True)
180
  stream_image = gr.Image(label="Streaming Output", visible=False)
@@ -196,8 +193,7 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
196
  with gr.Column(visible=False) as video_handling_options:
197
  video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)
198
  fast_mode_checkbox = gr.Checkbox(label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True)
199
-
200
- time_textbox = gr.Textbox(label="Time Elapsed", interactive=False) # Add time textbox
201
 
202
  def update_visibility(bg_type):
203
  if bg_type == "Color":
 
13
  import uuid
14
  import time
15
  import threading
16
+ from concurrent.futures import ThreadPoolExecutor
17
 
18
  torch.set_float32_matmul_precision("medium")
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Load both BiRefNet models
22
  birefnet = AutoModelForImageSegmentation.from_pretrained(
23
+ "ZhengPeng7/BiRefNet", trust_remote_code=True)
 
24
  birefnet.to(device)
 
25
  birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
26
+ "ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
 
27
  birefnet_lite.to(device)
28
 
29
  transform_image = transforms.Compose(
 
34
  ]
35
  )
36
 
 
37
  # Function to delete files older than 10 minutes in the temp directory
38
  def cleanup_temp_files():
39
  while True:
 
51
  print(f"Error deleting file {filepath}: {e}")
52
  time.sleep(60) # Check every minute
53
 
 
54
  # Start the cleanup thread
55
  cleanup_thread = threading.Thread(target=cleanup_temp_files, daemon=True)
56
  cleanup_thread.start()
57
 
58
+ # Function to process a single frame
59
+ def process_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
60
+ try:
61
+ pil_image = Image.fromarray(frame)
62
+ if bg_type == "Color":
63
+ processed_image = process(pil_image, color, fast_mode)
64
+ elif bg_type == "Image":
65
+ processed_image = process(pil_image, bg, fast_mode)
66
+ elif bg_type == "Video":
67
+ background_frame = background_frames[bg_frame_index % len(background_frames)]
68
+ bg_frame_index += 1
69
+ background_image = Image.fromarray(background_frame)
70
+ processed_image = process(pil_image, background_image, fast_mode)
71
+ else:
72
+ processed_image = pil_image # Default to original image if no background is selected
73
+ return np.array(processed_image), bg_frame_index
74
+ except Exception as e:
75
+ print(f"Error processing frame: {e}")
76
+ return frame, bg_frame_index
77
 
78
  @spaces.GPU
79
  def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=True):
 
91
  audio = video.audio
92
 
93
  # Extract frames at the specified FPS
94
+ frames = list(video.iter_frames(fps=fps))
95
 
96
  # Process each frame for background removal
97
  processed_frames = []
 
110
 
111
  bg_frame_index = 0 # Initialize background frame index
112
 
113
+ # Use ThreadPoolExecutor for parallel processing
114
+ with ThreadPoolExecutor(max_workers=4) as executor:
115
+ futures = [executor.submit(process_frame, frames[i], bg_type, bg_image, fast_mode, bg_frame_index, background_frames, color) for i in range(len(frames))]
116
+ for future in futures:
117
+ result, bg_frame_index = future.result()
118
+ processed_frames.append(result)
119
+ elapsed_time = time.time() - start_time
120
+ yield result, None, f"Processing frame {len(processed_frames)}... Elapsed time: {elapsed_time:.2f} seconds"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  # Create a new video from the processed frames
123
  processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
 
134
 
135
  elapsed_time = time.time() - start_time
136
  yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
137
+
138
  # Return the path to the temporary file
139
+ yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
140
 
141
  except Exception as e:
142
  print(f"Error: {e}")
 
144
  yield gr.update(visible=False), gr.update(visible=True), f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
145
  yield None, f"Error processing video: {e}", f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
146
 
 
147
  def process(image, bg, fast_mode=False):
148
  image_size = image.size
149
  input_images = transform_image(image).unsqueeze(0).to("cuda")
 
168
 
169
  # Composite the image onto the background using the mask
170
  image = Image.composite(image, background, mask)
 
171
  return image
172
 
 
173
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
174
+ gr.Markdown("# Video Background Remover & Changer\n### You can replace image background with any color, image or video.\nNOTE: As this Space is running on ZERO GPU it has limit. It can handle approx 200 frames at once. So, if you have a big video than use small chunks or Duplicate this space.")
175
  with gr.Row():
176
  in_video = gr.Video(label="Input Video", interactive=True)
177
  stream_image = gr.Image(label="Streaming Output", visible=False)
 
193
  with gr.Column(visible=False) as video_handling_options:
194
  video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)
195
  fast_mode_checkbox = gr.Checkbox(label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True)
196
+ time_textbox = gr.Textbox(label="Time Elapsed", interactive=False) # Add time textbox
 
197
 
198
  def update_visibility(bg_type):
199
  if bg_type == "Color":