Update app.py
Browse files
app.py
CHANGED
@@ -13,17 +13,14 @@ import tempfile
|
|
13 |
import uuid
|
14 |
import time
|
15 |
from concurrent.futures import ThreadPoolExecutor
|
16 |
-
import asyncio
|
17 |
|
18 |
torch.set_float32_matmul_precision("medium")
|
19 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
|
21 |
# Load both BiRefNet models
|
22 |
-
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
23 |
-
"ZhengPeng7/BiRefNet", trust_remote_code=True)
|
24 |
birefnet.to(device)
|
25 |
-
birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
|
26 |
-
"ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
|
27 |
birefnet_lite.to(device)
|
28 |
|
29 |
transform_image = transforms.Compose([
|
@@ -32,74 +29,77 @@ transform_image = transforms.Compose([
|
|
32 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
33 |
])
|
34 |
|
35 |
-
# Function to process a single frame
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
|
51 |
@spaces.GPU
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
90 |
elapsed_time = time.time() - start_time
|
91 |
-
yield
|
|
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
processed_video.write_videofile(temp_filepath, codec="libx264")
|
99 |
-
|
100 |
-
elapsed_time = time.time() - start_time
|
101 |
-
yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
|
102 |
-
yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
|
103 |
|
104 |
def process(image, bg, fast_mode=False):
|
105 |
image_size = image.size
|
|
|
13 |
import uuid
|
14 |
import time
|
15 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
16 |
|
17 |
torch.set_float32_matmul_precision("medium")
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
|
20 |
# Load both BiRefNet models
|
21 |
+
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
|
|
|
22 |
birefnet.to(device)
|
23 |
+
birefnet_lite = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet_lite", trust_remote_code=True)
|
|
|
24 |
birefnet_lite.to(device)
|
25 |
|
26 |
transform_image = transforms.Compose([
|
|
|
29 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
30 |
])
|
31 |
|
32 |
+
# Function to process a single frame
|
33 |
+
def process_frame(frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color):
|
34 |
+
try:
|
35 |
+
pil_image = Image.fromarray(frame)
|
36 |
+
if bg_type == "Color":
|
37 |
+
processed_image = process(pil_image, color, fast_mode)
|
38 |
+
elif bg_type == "Image":
|
39 |
+
processed_image = process(pil_image, bg, fast_mode)
|
40 |
+
elif bg_type == "Video":
|
41 |
+
background_frame = background_frames[bg_frame_index % len(background_frames)]
|
42 |
+
bg_frame_index += 1
|
43 |
+
background_image = Image.fromarray(background_frame)
|
44 |
+
processed_image = process(pil_image, background_image, fast_mode)
|
45 |
+
else:
|
46 |
+
processed_image = pil_image # Default to original image if no background is selected
|
47 |
+
return np.array(processed_image), bg_frame_index
|
48 |
+
except Exception as e:
|
49 |
+
print(f"Error processing frame: {e}")
|
50 |
+
return frame, bg_frame_index
|
51 |
|
52 |
@spaces.GPU
|
53 |
+
def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=True, max_workers=6):
|
54 |
+
try:
|
55 |
+
start_time = time.time() # Start the timer
|
56 |
+
video = mp.VideoFileClip(vid)
|
57 |
+
if fps == 0:
|
58 |
+
fps = video.fps
|
59 |
+
|
60 |
+
audio = video.audio
|
61 |
+
frames = list(video.iter_frames(fps=fps))
|
62 |
+
|
63 |
+
processed_frames = []
|
64 |
+
yield gr.update(visible=True), gr.update(visible=False), f"Processing started... Elapsed time: 0 seconds"
|
65 |
+
|
66 |
+
if bg_type == "Video":
|
67 |
+
background_video = mp.VideoFileClip(bg_video)
|
68 |
+
if background_video.duration < video.duration:
|
69 |
+
if video_handling == "slow_down":
|
70 |
+
background_video = background_video.fx(mp.vfx.speedx, factor=video.duration / background_video.duration)
|
71 |
+
else: # video_handling == "loop"
|
72 |
+
background_video = mp.concatenate_videoclips([background_video] * int(video.duration / background_video.duration + 1))
|
73 |
+
background_frames = list(background_video.iter_frames(fps=fps))
|
74 |
+
else:
|
75 |
+
background_frames = None
|
76 |
+
|
77 |
+
bg_frame_index = 0 # Initialize background frame index
|
78 |
|
79 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
80 |
+
futures = [executor.submit(process_frame, frames[i], bg_type, bg_image, fast_mode, bg_frame_index, background_frames, color) for i in range(len(frames))]
|
81 |
+
for future in futures:
|
82 |
+
result, bg_frame_index = future.result()
|
83 |
+
processed_frames.append(result)
|
84 |
+
elapsed_time = time.time() - start_time
|
85 |
+
yield result, None, f"Processing frame {len(processed_frames)}... Elapsed time: {elapsed_time:.2f} seconds"
|
86 |
+
|
87 |
+
processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
|
88 |
+
processed_video = processed_video.set_audio(audio)
|
89 |
+
|
90 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
|
91 |
+
temp_filepath = temp_file.name
|
92 |
+
processed_video.write_videofile(temp_filepath, codec="libx264")
|
93 |
+
|
94 |
elapsed_time = time.time() - start_time
|
95 |
+
yield gr.update(visible=False), gr.update(visible=True), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
|
96 |
+
yield processed_frames[-1], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
|
97 |
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Error: {e}")
|
100 |
+
elapsed_time = time.time() - start_time
|
101 |
+
yield gr.update(visible=False), gr.update(visible=True), f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
|
102 |
+
yield None, f"Error processing video: {e}", f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
def process(image, bg, fast_mode=False):
|
105 |
image_size = image.size
|