KingNish committed
Commit: fb480c5
Parent: 262a1a2

Update app.py

Files changed (1)
  1. app.py +49 -26
app.py CHANGED
@@ -17,6 +17,7 @@ torch.set_float32_matmul_precision("highest")
 birefnet = AutoModelForImageSegmentation.from_pretrained(
     "ZhengPeng7/BiRefNet", trust_remote_code=True
 ).to("cuda")
+
 transform_image = transforms.Compose(
     [
         transforms.Resize((1024, 1024)),
@@ -25,6 +26,8 @@ transform_image = transforms.Compose(
     ]
 )
 
+BATCH_SIZE = 3
+
 @spaces.GPU
 def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
     try:
@@ -44,48 +47,44 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
                 else:
                     background_video = mp.concatenate_videoclips([background_video] * int(video.duration / background_video.duration + 1))
             background_frames = list(background_video.iter_frames(fps=fps))
-        elif bg_type in ["Color", "Image"]:
-            # Prepare background once if it's a static image or color
-            if bg_type == "Color":
-                color_rgb = tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
-                background_pil = Image.new("RGBA", (1024, 1024), color_rgb + (255,))
-            else: # bg_type == "Image":
-                background_pil = Image.open(bg_image).convert("RGBA").resize((1024, 1024))
-            background_tensor = transforms.ToTensor(background_pil).to("cuda")
         else:
-            background_tensor = None
-
+            background_frames = None
 
         bg_frame_index = 0
        frame_batch = []
+
         for i, frame in enumerate(frames):
-            frame = Image.fromarray(frame)
-            frame = transforms.ToTensor(frame).to('cuda')
             frame_batch.append(frame)
+            if len(frame_batch) == BATCH_SIZE or i == video.fps * video.duration -1: # Process batch or last frames
+                pil_images = [Image.fromarray(f) for f in frame_batch]
+
 
-            if len(frame_batch) >= 3 or i == int(video.fps * video.duration) - 1 :
-                input_images = torch.stack(frame_batch).to("cuda")
-                with torch.no_grad():
-                    preds = birefnet(input_images)[-1].sigmoid()
-                for j, pred in enumerate(preds):
-                    if bg_type == "Video":
+                if bg_type == "Color":
+                    processed_images = [process(img, color) for img in pil_images]
+                elif bg_type == "Image":
+                    processed_images = [process(img, bg_image) for img in pil_images]
+                elif bg_type == "Video":
+                    processed_images = []
+                    for _ in range(len(frame_batch)):
                         if video_handling == "slow_down":
                             background_frame = background_frames[bg_frame_index % len(background_frames)]
                             bg_frame_index += 1
-                            background_image = Image.fromarray(background_frame).resize((1024, 1024))
-                            background_tensor = transforms.ToTensor(background_image).to("cuda")
+                            background_image = Image.fromarray(background_frame)
                         else: # video_handling == "loop"
                             background_frame = background_frames[bg_frame_index % len(background_frames)]
                             bg_frame_index += 1
-                            background_image = Image.fromarray(background_frame).resize((1024, 1024))
-                            background_tensor = transforms.ToTensor(background_image).to("cuda")
-                    mask = transforms.ToPILImage()(pred.cpu().squeeze())
-                    processed_image = Image.composite(transforms.ToPILImage()(frame_batch[j].cpu()), transforms.ToPILImage()(background_tensor.cpu()), mask).resize(video.size)
+                            background_image = Image.fromarray(background_frame)
+
+                        processed_images.append(process(pil_images[_],background_image))
 
+
+                else:
+                    processed_images = pil_images
+
+                for processed_image in processed_images:
                     processed_frames.append(np.array(processed_image))
                     yield processed_image, None
-
-                frame_batch = []
+                frame_batch = [] # Clear the batch
 
 
         processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
@@ -107,6 +106,30 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
         yield None, f"Error processing video: {e}"
 
 
+def process(image, bg):
+    image_size = image.size
+    input_images = transform_image(image).unsqueeze(0).to("cuda")
+    # Prediction
+    with torch.no_grad():
+        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+    pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(image_size)
+
+    if isinstance(bg, str) and bg.startswith("#"):
+        color_rgb = tuple(int(bg[i:i+2], 16) for i in (1, 3, 5))
+        background = Image.new("RGBA", image_size, color_rgb + (255,))
+    elif isinstance(bg, Image.Image):
+        background = bg.convert("RGBA").resize(image_size)
+    else:
+        background = Image.open(bg).convert("RGBA").resize(image_size)
+
+    # Composite the image onto the background using the mask
+    image = Image.composite(image, background, mask)
+
+    return image
+
+
 with gr.Blocks(theme=gr.themes.Ocean()) as demo:
     with gr.Row():
         in_video = gr.Video(label="Input Video", interactive=True)
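
For anyone who wants to exercise the new process() helper outside the Gradio app, the sketch below is a minimal standalone reproduction, not part of the commit: the Normalize values in transform_image are an assumption (the Compose body is truncated in the diff above), "frame.png" is a hypothetical test file, and a CUDA device is assumed, as in app.py.

# Standalone sketch of the per-frame background replacement added in this commit.
# Assumptions: ImageNet-style normalization for transform_image (truncated in the
# diff), a CUDA device, and a hypothetical input file "frame.png".
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).to("cuda")

transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # assumed
    ]
)

def process(image, bg):
    # Segment the frame with BiRefNet, then composite it over the background,
    # which may be a hex color string, a PIL image, or an image file path.
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)

    if isinstance(bg, str) and bg.startswith("#"):
        color_rgb = tuple(int(bg[i:i + 2], 16) for i in (1, 3, 5))
        background = Image.new("RGBA", image_size, color_rgb + (255,))
    elif isinstance(bg, Image.Image):
        background = bg.convert("RGBA").resize(image_size)
    else:
        background = Image.open(bg).convert("RGBA").resize(image_size)

    return Image.composite(image, background, mask)

if __name__ == "__main__":
    frame = Image.open("frame.png").convert("RGB")  # hypothetical test frame
    result = process(frame, "#00FF00")              # solid green background
    result.save("frame_composited.png")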