import subprocess

# Build the SAM2 CUDA extension in place before the sam2 package is imported below
command = ["python", "setup.py", "build_ext", "--inplace"]

# Execute the command
result = subprocess.run(command, capture_output=True, text=True)

# Print the output and error (if any)
print("Output:\n", result.stdout)
print("Errors:\n", result.stderr)

# Check if the command was successful
if result.returncode == 0:
    print("Command executed successfully.")
else:
    print("Command failed with return code:", result.returncode)

import gradio as gr
from datetime import datetime
import os

os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sam2.build_sam import build_sam2_video_predictor
from moviepy.editor import ImageSequenceClip


def preprocess_image(image):
    # Reset the click/label states and show the bare first frame again
    return image, gr.State([]), gr.State([]), image, gr.State()


def preprocess_video_in(video_path):
    # Generate a unique ID based on the current date and time
    unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
    output_dir = f'frames_{unique_id}'

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return None

    frame_number = 0
    first_frame = None

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Format the frame filename as '00000.jpg'
        frame_filename = os.path.join(output_dir, f'{frame_number:05d}.jpg')

        # Save the frame as a JPEG file
        cv2.imwrite(frame_filename, frame)

        # Store the first frame
        if frame_number == 0:
            first_frame = frame_filename

        frame_number += 1

    # Release the video capture object
    cap.release()

    # 'first_frame' is the first frame extracted from video_in
    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir, None, None


def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")

    tracking_points.value.append(evt.index)
    print(f"TRACKING POINT: {tracking_points.value}")

    if point_type == "include":
        trackings_input_label.value.append(1)
    elif point_type == "exclude":
        trackings_input_label.value.append(0)
    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")

    # Open the image and get its dimensions
    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size

    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.02  # You can adjust this value as needed
    radius = int(fraction * min(w, h))

    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for index, track in enumerate(tracking_points.value):
        if trackings_input_label.value[index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)

    return tracking_points, trackings_input_label, selected_point_map


# use bfloat16 for the entire script
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
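
# --- Optional housekeeping sketch (an assumption, not wired into the demo below):
# preprocess_video_in() creates a fresh `frames_<timestamp>` directory on every upload
# and nothing ever deletes it, so a long-running app slowly fills its disk. A helper
# like this one could be hooked to a button or a periodic task if that becomes a problem.
def cleanup_frame_dirs(keep_latest=1):
    """Remove old `frames_<timestamp>` directories, keeping the `keep_latest` most recent."""
    import glob
    import shutil
    # Match only the timestamped upload dirs, not `frames_output_images`
    dirs = sorted((d for d in glob.glob("frames_[0-9]*") if os.path.isdir(d)),
                  key=os.path.getmtime)
    to_delete = dirs[:-keep_latest] if keep_latest > 0 else dirs
    for d in to_delete:
        shutil.rmtree(d, ignore_errors=True)
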
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    combined_images = []  # List to store filenames of images with masks overlaid
    mask_images = []      # List to store filenames of separate mask images

    for i, (mask, score) in enumerate(zip(masks, scores)):
        # ---- Original Image with Mask Overlaid ----
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        # show_mask() above takes no `borders` argument, so draw the plain mask
        show_mask(mask, plt.gca())
        """
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        """
        if box_coords is not None:
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')

        # Save the figure as a JPG file
        combined_filename = f"combined_image_{i+1}.jpg"
        plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
        combined_images.append(combined_filename)

        plt.close()  # Close the figure to free up memory

        # ---- Separate Mask Image (White Mask on Black Background) ----
        # Create a black image
        mask_image = np.zeros_like(image, dtype=np.uint8)

        # The mask is a binary array where the masked area is 1, else 0.
        # Convert the mask to a white color in the mask_image
        mask_layer = (mask > 0).astype(np.uint8) * 255
        for c in range(3):  # Assuming RGB, repeat mask for all channels
            mask_image[:, :, c] = mask_layer

        # Save the mask image
        mask_filename = f"mask_image_{i+1}.png"
        Image.fromarray(mask_image).save(mask_filename)
        mask_images.append(mask_filename)

        plt.close()  # Close the figure to free up memory

    return combined_images, mask_images


def load_model(checkpoint):
    # Load the model according to the user's choice
    if checkpoint == "tiny":
        sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
        model_cfg = "sam2_hiera_t.yaml"
    elif checkpoint == "small":
        sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
        model_cfg = "sam2_hiera_s.yaml"
    elif checkpoint == "base-plus":
        sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
        model_cfg = "sam2_hiera_b+.yaml"
    elif checkpoint == "large":
        sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
        model_cfg = "sam2_hiera_l.yaml"
    return sam2_checkpoint, model_cfg
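
# --- Usage sketch (illustration only, kept as comments so nothing loads eagerly):
# load_model() just maps the dropdown choice to a (checkpoint path, config name) pair;
# the heavy lifting happens in build_sam2_video_predictor(), e.g.:
#
#   sam2_checkpoint, model_cfg = load_model("tiny")
#   predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
#
# Note that an unrecognized name would leave both variables unbound and raise
# UnboundLocalError, so the UI below restricts the choices to the four names handled above.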
def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir):
    # 1. We need to preprocess the video and store frames in the right directory
    #    — remember to use a unique ID for the directory
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

    # `video_dir` is a directory of JPEG frames with filenames like '00000.jpg'
    # (as written by preprocess_video_in)
    print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
    video_dir = video_frames_dir

    # scan all the JPEG frame names in this directory
    frame_names = [
        p for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    inference_state = predictor.init_state(video_path=video_dir)

    # segment and track one object
    # predictor.reset_state(inference_state)  # if any previous tracking, reset

    # Add new points
    ann_frame_idx = 0  # the frame index we interact with
    ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)

    # Use the clicks collected in the UI as prompts
    points = np.array(tracking_points.value, dtype=np.float32)
    # for labels, `1` means positive click and `0` means negative click
    labels = np.array(trackings_input_label.value, np.int32)
    _, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

    # Create the plot
    plt.figure(figsize=(12, 8))
    plt.title(f"frame {ann_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
    show_points(points, labels, plt.gca())
    show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

    # Save the plot as a JPG file
    first_frame_output_filename = "output_first_frame.jpg"
    plt.savefig(first_frame_output_filename, format='jpg')
    plt.close()

    return "output_first_frame.jpg", frame_names, inference_state
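
# --- Note on state handling (a reading of the wiring further below, not extra behavior):
# sam_process() returns `inference_state` with the user's clicks already registered via
# predictor.add_new_points(); the UI stores it in a gr.State and hands it straight to
# propagate_to_all(), so the clicks do not need to be re-applied before propagation.
# The same flow outside Gradio would look roughly like:
#
#   state = predictor.init_state(video_path=video_dir)
#   predictor.add_new_points(inference_state=state, frame_idx=0, obj_id=1,
#                            points=points, labels=labels)
#   for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
#       ...  # threshold mask_logits > 0.0 per object, as propagate_to_all() does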
def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type):
    #### PROPAGATION ####
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

    inference_state = stored_inference_state
    frame_names = stored_frame_names
    video_dir = video_frames_dir

    # Define a directory to save the JPEG images
    frames_output_dir = "frames_output_images"
    os.makedirs(frames_output_dir, exist_ok=True)

    # Initialize a list to store file paths of saved images
    jpeg_images = []

    # run propagation throughout the video and collect the results in a dict
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    # render the segmentation results every few frames
    if vis_frame_type == "check":
        vis_frame_stride = 15
    elif vis_frame_type == "render":
        vis_frame_stride = 1

    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.title(f"frame {out_frame_idx}")
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

        # Define the output filename and save the figure as a JPEG file
        output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
        plt.savefig(output_filename, format='jpg')

        # Append the file path to the list
        jpeg_images.append(output_filename)

        # Close the plot
        plt.close()

    if vis_frame_type == "check":
        return gr.update(value=jpeg_images, visible=True), gr.update(visible=False, value=None)
    elif vis_frame_type == "render":
        # Create a video clip from the image sequence
        fps = 24  # Frames per second
        clip = ImageSequenceClip(jpeg_images, fps=fps)

        # Write the result to a file
        final_vid_output_path = "output_video.mp4"
        clip.write_videofile(final_vid_output_path, codec='libx264')

        return gr.update(visible=False, value=None), gr.update(value=final_vid_output_path, visible=True)


with gr.Blocks() as demo:
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    video_frames_dir = gr.State()
    stored_inference_state = gr.State()
    stored_frame_names = gr.State()

    with gr.Column():
        gr.Markdown("# SAM2 Video Predictor")
        gr.Markdown("This is a simple demo for video segmentation with SAM2.")
        gr.Markdown("""Instructions:
        1. Upload your video
        2. With the 'include' point type selected, click on the object to mask in the first frame
        3. Switch to the 'exclude' point type if you want to specify an area to avoid
        4. Submit!
        """)
        with gr.Row():
            with gr.Column():
                input_first_frame_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
                points_map = gr.Image(
                    label="points map",
                    type="filepath",
                    interactive=False
                )
                video_in = gr.Video(label="Video IN")
                with gr.Row():
                    point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
                    clear_points_btn = gr.Button("Clear Points")
                checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output_result = gr.Image()
                with gr.Row():
                    vis_frame_type = gr.Radio(choices=["check", "render"], value="render", scale=2)
                    propagate_btn = gr.Button("Propagate", scale=1)
                output_propagated = gr.Gallery(visible=False)
                output_video = gr.Video(visible=False)
                # output_result_mask = gr.Image()

    clear_points_btn.click(
        fn=preprocess_image,
        inputs=input_first_frame_image,
        outputs=[first_frame_path, tracking_points, trackings_input_label, points_map, stored_inference_state],
        queue=False
    )

    video_in.upload(
        fn=preprocess_video_in,
        inputs=[video_in],
        outputs=[first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, video_frames_dir, stored_inference_state, stored_frame_names],
        queue=False
    )

    points_map.select(
        fn=get_point,
        inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
        outputs=[tracking_points, trackings_input_label, points_map],
        queue=False
    )

    submit_btn.click(
        fn=sam_process,
        inputs=[input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir],
        outputs=[output_result, stored_frame_names, stored_inference_state]
    )

    propagate_btn.click(
        fn=propagate_to_all,
        inputs=[checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type],
        outputs=[output_propagated, output_video]
    )

demo.launch(show_api=False, show_error=True)
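
# --- Run sketch (assumptions about the surrounding environment, not part of the demo):
# this script expects to sit at the root of a SAM2 checkout (the subprocess call at the
# top runs `python setup.py build_ext --inplace`) with the four checkpoints downloaded
# into ./checkpoints/, matching the paths hard-coded in load_model(). Something like:
#
#   pip install gradio moviepy opencv-python matplotlib   # plus the sam2 package itself
#   python app.py                                         # hypothetical filename for this script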