# ------------------------------------------------------------------------
# Modified from Grounded-SAM (https://github.com/IDEA-Research/Grounded-Segment-Anything)
# ------------------------------------------------------------------------
import os
import sys
import random
import warnings
os.system("export BUILD_WITH_CUDA=True")
os.system("python -m pip install -e segment-anything")
os.system("python -m pip install -e GroundingDINO")
os.system("pip install --upgrade diffusers[torch]")
#os.system("pip install opencv-python pycocotools matplotlib")
sys.path.insert(0, './GroundingDINO')
sys.path.insert(0, './segment-anything')
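# The GroundingDINO and segment-anything repos installed above are used as local source checkouts, so their roots are added to sys.path before the imports below.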
warnings.filterwarnings("ignore")
import cv2
from scipy import ndimage
import gradio as gr
import argparse
import numpy as np
from PIL import Image
from moviepy.editor import *
import torch
from torch.nn import functional as F
import torchvision
import networks
import utils
# Grounding DINO
from groundingdino.util.inference import Model
# SAM
from segment_anything.utils.transforms import ResizeLongestSide
# SD
from diffusers import StableDiffusionPipeline
transform = ResizeLongestSide(1024)
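# SAM-style preprocessing: ResizeLongestSide(1024) rescales the longest image side to 1024 px. The same transform instance is reused below for boxes and point coordinates so they stay aligned with the resized image.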
# Green Screen
PALETTE_back = (51, 255, 146)
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swint_ogc.pth"
mam_checkpoint="checkpoints/mam_sam_vitb.pth"
output_dir="outputs"
device = 'cuda'
background_list = os.listdir('assets/backgrounds')
#groundingdino_model = None
#mam_predictor = None
#generator = None
# initialize MAM
mam_model = networks.get_generator_m2m(seg='sam', m2m='sam_decoder_deep')
mam_model.to(device)
checkpoint = torch.load(mam_checkpoint, map_location=device)
mam_model.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
mam_model = mam_model.eval()
# initialize GroundingDINO
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=device)
# initialize StableDiffusionPipeline
generator = StableDiffusionPipeline.from_pretrained("checkpoints/stable-diffusion-v1-5", torch_dtype=torch.float16)
generator.to(device)
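# All three models (MAM matting network, GroundingDINO detector, Stable Diffusion background generator) are loaded once at startup and reused for every frame and request.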
def get_frames(video_in):
frames = []
#resize the video
clip = VideoFileClip(video_in)
#check fps
if clip.fps > 30:
print("vide rate is over 30, resetting to 30")
clip_resized = clip.resize(height=512)
clip_resized.write_videofile("video_resized.mp4", fps=30)
else:
print("video rate is OK")
clip_resized = clip.resize(height=512)
clip_resized.write_videofile("video_resized.mp4", fps=clip.fps)
print("video resized to 512 height")
# Opens the Video file with CV2
cap= cv2.VideoCapture("video_resized.mp4")
fps = cap.get(cv2.CAP_PROP_FPS)
print("video fps: " + str(fps))
i=0
while(cap.isOpened()):
ret, frame = cap.read()
if ret == False:
break
cv2.imwrite('kang'+str(i)+'.jpg',frame)
frames.append('kang'+str(i)+'.jpg')
i+=1
cap.release()
cv2.destroyAllWindows()
print("broke the video into frames")
return frames, fps
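# Stitch processed frames back into an mp4 at the given fps; `type` is only used to name the output file.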
def create_video(frames, fps, type):
print("building video result")
clip = ImageSequenceClip(frames, fps=fps)
clip.write_videofile(f"video_{type}_result.mp4", fps=fps)
return f"video_{type}_result.mp4"
def run_grounded_sam(input_image, text_prompt, task_type, background_prompt, bg_already):
background_type = "generated_by_text"
box_threshold = 0.25
text_threshold = 0.25
iou_threshold = 0.5
scribble_mode = "split"
guidance_mode = "alpha"
#global groundingdino_model, sam_predictor, generator
# make dir
os.makedirs(output_dir, exist_ok=True)
#if mam_predictor is None:
# initialize MAM
# build model
# mam_model = networks.get_generator_m2m(seg='sam', m2m='sam_decoder_deep')
# mam_model.to(device)
# load checkpoint
# checkpoint = torch.load(mam_checkpoint, map_location=device)
# mam_model.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
# inference
# mam_model = mam_model.eval()
#if groundingdino_model is None:
# grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=device)
#if generator is None:
# generator = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# generator.to(device)
# load image
#image_ori = input_image["image"]
image_ori = input_image
#scribble = input_image["mask"]
original_size = image_ori.shape[:2]
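# infer() always calls this function with task_type='text' and a plain RGB numpy frame (no scribble mask), so only the 'text' branch is exercised here; the scribble branches below still reference the `scribble` variable that is commented out above.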
if task_type == 'text':
if not text_prompt:
print('Please input a non-empty text prompt')
with torch.no_grad():
detections, phrases = grounding_dino_model.predict_with_caption(
image=cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR),
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
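# The frame is converted to BGR for the GroundingDINO Model wrapper; the returned detections expose pixel-space xyxy boxes and per-box confidence scores used below.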
if len(detections.xyxy) > 1:
nms_idx = torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
iou_threshold,
).numpy().tolist()
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
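# After NMS removes overlapping boxes, keep only the highest-confidence detection as the matting target. This assumes GroundingDINO found at least one box for the prompt; otherwise argmax raises.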
bbox = detections.xyxy[np.argmax(detections.confidence)]
bbox = transform.apply_boxes(bbox, original_size)
bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)
image = transform.apply_image(image_ori)
image = torch.as_tensor(image).to(device)
image = image.permute(2, 0, 1).contiguous()
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3,1,1).to(device)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3,1,1).to(device)
image = (image - pixel_mean) / pixel_std
h, w = image.shape[-2:]
pad_size = image.shape[-2:]
padh = 1024 - h
padw = 1024 - w
image = F.pad(image, (0, padw, 0, padh))
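# Normalize with SAM's pixel statistics and zero-pad to the fixed 1024x1024 input. pad_size holds the pre-padding (h, w) so predictions can be cropped back before resizing to the original resolution.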
if task_type == 'scribble_point':
scribble = scribble.transpose(2, 1, 0)[0]
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
centers = np.array(centers)
### (x,y)
centers = transform.apply_coords(centers, original_size)
point_coords = torch.from_numpy(centers).to(device)
point_coords = point_coords.unsqueeze(0).to(device)
point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
if scribble_mode == 'split':
point_coords = point_coords.permute(1, 0, 2)
point_labels = point_labels.permute(1, 0)
sample = {'image': image.unsqueeze(0), 'point': point_coords, 'label': point_labels, 'ori_shape': original_size, 'pad_shape': pad_size}
elif task_type == 'scribble_box':
scribble = scribble.transpose(2, 1, 0)[0]
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
centers = np.array(centers)
### (x1, y1, x2, y2)
x_min = centers[:, 0].min()
x_max = centers[:, 0].max()
y_min = centers[:, 1].min()
y_max = centers[:, 1].max()
bbox = np.array([x_min, y_min, x_max, y_max])
bbox = transform.apply_boxes(bbox, original_size)
bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)
sample = {'image': image.unsqueeze(0), 'bbox': bbox.unsqueeze(0), 'ori_shape': original_size, 'pad_shape': pad_size}
elif task_type == 'text':
sample = {'image': image.unsqueeze(0), 'bbox': bbox.unsqueeze(0), 'ori_shape': original_size, 'pad_shape': pad_size}
else:
print("task_type:{} error!".format(task_type))
with torch.no_grad():
feas, pred, post_mask = mam_model.forward_inference(sample)
alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
alpha_pred_os8 = alpha_pred_os8[..., : sample['pad_shape'][0], : sample['pad_shape'][1]]
alpha_pred_os4 = alpha_pred_os4[..., : sample['pad_shape'][0], : sample['pad_shape'][1]]
alpha_pred_os1 = alpha_pred_os1[..., : sample['pad_shape'][0], : sample['pad_shape'][1]]
alpha_pred_os8 = F.interpolate(alpha_pred_os8, sample['ori_shape'], mode="bilinear", align_corners=False)
alpha_pred_os4 = F.interpolate(alpha_pred_os4, sample['ori_shape'], mode="bilinear", align_corners=False)
alpha_pred_os1 = F.interpolate(alpha_pred_os1, sample['ori_shape'], mode="bilinear", align_corners=False)
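# The three alpha outputs (os1/os4/os8 refinement levels) are cropped to the unpadded region and upsampled to the original frame size.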
if guidance_mode == 'mask':
weight_os8 = utils.get_unknown_tensor_from_mask_oneside(post_mask, rand_width=10, train_mode=False)
post_mask[weight_os8>0] = alpha_pred_os8[weight_os8>0]
alpha_pred = post_mask.clone().detach()
else:
weight_os8 = utils.get_unknown_box_from_mask(post_mask)
alpha_pred_os8[weight_os8>0] = post_mask[weight_os8>0]
alpha_pred = alpha_pred_os8.clone().detach()
weight_os4 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=20, train_mode=False)
alpha_pred[weight_os4>0] = alpha_pred_os4[weight_os4>0]
weight_os1 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=10, train_mode=False)
alpha_pred[weight_os1>0] = alpha_pred_os1[weight_os1>0]
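# Progressive refinement: start from the coarse os8 estimate (guided by SAM's post_mask), then overwrite the remaining uncertain regions with the finer os4 and os1 predictions.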
alpha_pred = alpha_pred[0][0].cpu().numpy()
#### draw
### alpha matte
alpha_rgb = cv2.cvtColor(np.uint8(alpha_pred*255), cv2.COLOR_GRAY2RGB)
### com img with background
global background_img
if background_type == 'real_world_sample':
background_img_file = os.path.join('assets/backgrounds', random.choice(background_list))
background_img = cv2.imread(background_img_file)
background_img = cv2.cvtColor(background_img, cv2.COLOR_BGR2RGB)
background_img = cv2.resize(background_img, (image_ori.shape[1], image_ori.shape[0]))
com_img = alpha_pred[..., None] * image_ori + (1 - alpha_pred[..., None]) * np.uint8(background_img)
com_img = np.uint8(com_img)
else:
if background_prompt is None:
print('Please input a non-empty background prompt')
else:
if bg_already is False:
background_img = generator(background_prompt).images[0]
background_img = np.array(background_img)
background_img = cv2.resize(background_img, (image_ori.shape[1], image_ori.shape[0]))
com_img = alpha_pred[..., None] * image_ori + (1 - alpha_pred[..., None]) * np.uint8(background_img)
com_img = np.uint8(com_img)
### com img with green screen
green_img = alpha_pred[..., None] * image_ori + (1 - alpha_pred[..., None]) * np.array([PALETTE_back], dtype='uint8')
green_img = np.uint8(green_img)
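# Standard alpha compositing, out = alpha * foreground + (1 - alpha) * background, applied once with the chosen background image and once with a flat green-screen color.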
#return [(com_img, 'composite with background'), (green_img, 'green screen'), (alpha_rgb, 'alpha matte')]
return com_img, green_img, alpha_rgb
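# Minimal single-image usage sketch (the frame path is illustrative, not shipped with the repo):
#   frame = np.array(Image.open("examples/frame.jpg").convert("RGB"))
#   com, green, matte = run_grounded_sam(frame, "the man holding a bottle", "text", "the Sahara desert", bg_already=False)
#   Image.fromarray(matte).save("matte.png")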
def infer(video_in, trim_value, prompt, background_prompt):
print(prompt)
break_vid = get_frames(video_in)
frames_list= break_vid[0]
fps = break_vid[1]
n_frame = int(trim_value*fps)
if n_frame >= len(frames_list):
print("video is shorter than the cut value")
n_frame = len(frames_list)
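# Process at most trim_value seconds of video (trim_value * fps frames), capped at the number of extracted frames.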
with_bg_result_frames = []
with_green_result_frames = []
with_matte_result_frames = []
print("set stop frames to: " + str(n_frame))
bg_already = False
for i in frames_list[0:int(n_frame)]:
to_numpy_i = Image.open(i).convert("RGB")
#need to convert to numpy
# Convert the image to a NumPy array
image_array = np.array(to_numpy_i)
results = run_grounded_sam(image_array, prompt, "text", background_prompt, bg_already)
bg_already = True
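# The Stable Diffusion background is generated only for the first frame (bg_already flag) and cached in the module-level background_img, so every frame is composited over the same background.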
bg_img = Image.fromarray(results[0])
green_img = Image.fromarray(results[1])
matte_img = Image.fromarray(results[2])
# exporting the images
bg_img.save(f"bg_result_img-{i}.jpg")
with_bg_result_frames.append(f"bg_result_img-{i}.jpg")
green_img.save(f"green_result_img-{i}.jpg")
with_green_result_frames.append(f"green_result_img-{i}.jpg")
matte_img.save(f"matte_result_img-{i}.jpg")
with_matte_result_frames.append(f"matte_result_img-{i}.jpg")
print("frame " + i + "/" + str(n_frame) + ": done;")
vid_bg = create_video(with_bg_result_frames, fps, "bg")
vid_green = create_video(with_green_result_frames, fps, "greenscreen")
vid_matte = create_video(with_matte_result_frames, fps, "matte")
bg_already = False
print("finished !")
return vid_bg, vid_green, vid_matte
if __name__ == "__main__":
parser = argparse.ArgumentParser("MAM demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
parser.add_argument('--port', type=int, default=7589, help='port to run the server')
parser.add_argument('--no-gradio-queue', action="store_true", help='disable the gradio queue')
args = parser.parse_args()
print(args)
block = gr.Blocks()
if not args.no_gradio_queue:
block = block.queue()
with block:
gr.Markdown(
"""
# Matting Anything in Video Demo
Welcome to the Matting Anything in Video demo by @fffiloni. Upload your video to get started.
You may open the usage details below to understand how to use this demo.
## Usage
Upload a video to start. For the moment, only one prompt type is supported to get the alpha matte of the target:
**text**: Send a text prompt in the `Text prompt` box to identify the target instance.
Only one background type is currently supported for image composition with the alpha matte output:
**generated_by_text**: Send a background text prompt in the `Background prompt` box to create a background image with the Stable Diffusion model.
Duplicate this Space for longer sequences, more control and no queue.
""")
with gr.Row():
with gr.Column():
video_in = gr.Video()
trim_in = gr.Slider(label="Cut video at (s)", minimum=1, maximum=10, step=1, value=1)
#task_type = gr.Dropdown(["scribble_point", "scribble_box", "text"], value="text", label="Prompt type")
#task_type = "text"
text_prompt = gr.Textbox(label="Text prompt", placeholder="the girl in the middle", info="Describe the subject visible in your video that you want to matte")
#background_type = gr.Dropdown(["generated_by_text", "real_world_sample"], value="generated_by_text", label="Background type")
background_prompt = gr.Textbox(label="Background prompt", placeholder="downtown area in New York")
run_button = gr.Button("Run")
#with gr.Accordion("Advanced options", open=False):
# box_threshold = gr.Slider(
# label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05
# )
# text_threshold = gr.Slider(
# label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05
# )
# iou_threshold = gr.Slider(
# label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05
# )
# scribble_mode = gr.Dropdown(
# ["merge", "split"], value="split", label="scribble_mode"
# )
# guidance_mode = gr.Dropdown(
# ["mask", "alpha"], value="alpha", label="guidance_mode", info="mask guidance is for complex scenes with multiple instances, alpha guidance is for simple scene with single instance"
# )
with gr.Column():
#gallery = gr.Gallery(
# label="Generated images", show_label=True, elem_id="gallery"
#).style(preview=True, grid=3, object_fit="scale-down")
vid_bg_out = gr.Video(label="Video with background")
with gr.Row():
vid_green_out = gr.Video(label="Video green screen")
vid_matte_out = gr.Video(label="Video matte")
gr.Examples(
fn=infer,
examples=[
[
"./examples/example_men_bottle.mp4",
10,
"the man holding a bottle",
"the Sahara desert"
]
],
inputs=[video_in, trim_in, text_prompt, background_prompt],
outputs=[vid_bg_out, vid_green_out, vid_matte_out]
)
run_button.click(fn=infer, inputs=[
video_in, trim_in, text_prompt, background_prompt], outputs=[vid_bg_out, vid_green_out, vid_matte_out], api_name="go_matte")
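# Note: queue(max_size=24) below enables the gradio queue even when --no-gradio-queue is passed.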
block.queue(max_size=24).launch(debug=args.debug, share=args.share, show_error=True)
#block.queue(concurrency_count=100)
#block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)