import os
import subprocess
import tempfile
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
from ultralytics import YOLO

# Ignore every FutureWarning raised after this point. The filter is installed
# at module level, so it applies to the whole process once this file is imported.
warnings.filterwarnings(action="ignore", category=FutureWarning)


class SegmenterBackground:
    """Keep objects segmented by a YOLO ``-seg`` model and replace everything else.

    The area outside the kept objects can be blacked out (``"remove"``, the
    default), replaced by a Gaussian-blurred copy of the input (``"cam"``), or
    replaced by a caller-supplied PIL image.

    Attributes:
        segment_names: ``{class_name: {detection_index: display_name}}`` for
            the image most recently processed by ``Back_step1``/``Back_step2``.
        person, animal, drive: class-name groups selectable through
            ``get_labels`` / ``Back_video``.
    """

    # (width, height) every image and frame is resized to before inference.
    # NOTE(review): presumably chosen so the masks returned by the model have
    # the same resolution as the image they are applied to -- confirm.
    _WORK_SIZE = (640, 480)

    # Weights file handed to ``YOLO``.
    _MODEL_WEIGHTS = "yolov8n-seg.pt"

    def __init__(self) -> None:
        self.segment_names = {}
        self.person = ['person']
        self.animal = ['bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                       'elephant', 'bear', 'zebra', 'giraffe']
        self.drive = ['bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                      'train', 'truck', 'boat']
        # Loaded lazily by predict_image so constructing the object stays cheap.
        self._model = None

    # ------------------------------------------------------------------ #
    # Inference helpers
    # ------------------------------------------------------------------ #
    def predict_image(self, raw_image: "Image.Image"):
        """Run the segmentation model on *raw_image*.

        The model is created on first use and cached on the instance, so
        repeated calls (one per video frame in ``Back_video``) do not reload
        the weights.

        Returns:
            ``(results, class_names)`` -- the model's result list and its
            ``names`` mapping from class id to class name.
        """
        model = getattr(self, "_model", None)
        if model is None:
            model = YOLO(self._MODEL_WEIGHTS)
            self._model = model
        return model(raw_image), model.names

    @staticmethod
    def _mask_list(result):
        """Return the instance masks of *result* as a list of NumPy arrays.

        Ultralytics reports ``masks`` as ``None`` when nothing was detected;
        that case yields an empty list instead of an ``AttributeError``.
        """
        if result.masks is None:
            return []
        return list(result.masks.data.cpu().numpy())

    @staticmethod
    def _class_ids(result):
        """Return the class id of every detection in *result* (empty if none)."""
        if result.boxes is None:
            return []
        return result.boxes.cls.cpu().numpy()

    def _fit_background(self, background_image):
        """Resize a PIL background to the working size; pass anything else through."""
        if isinstance(background_image, Image.Image):
            return background_image.resize(self._WORK_SIZE)
        return background_image

    # ------------------------------------------------------------------ #
    # Naming
    # ------------------------------------------------------------------ #
    def assign_segment_name(self, label, segment_id):
        """Return a stable display name such as ``"person 2"`` for a detection.

        The first time a ``(label, segment_id)`` pair is seen it is numbered
        with the count of names already issued for *label* plus one; later
        calls with the same pair return the same name.
        """
        names_for_label = self.segment_names.setdefault(label, {})
        if segment_id not in names_for_label:
            names_for_label[segment_id] = f"{label} {len(names_for_label) + 1}"
        return names_for_label[segment_id]

    # ------------------------------------------------------------------ #
    # Compositing
    # ------------------------------------------------------------------ #
    def putMaskImage(self, raw_image, masks, background_image="remove", blur_radius=23):
        """Composite the masked pixels of *raw_image* over a background.

        Args:
            raw_image: source image (PIL image).
            masks: sequence of 2-D masks; a pixel is kept when any mask
                equals 1 there. May be empty, in which case only the
                background is returned.
            background_image: a PIL image to use as background, ``"cam"`` for
                a Gaussian-blurred copy of *raw_image*, or anything else
                (default ``"remove"``) for a black background.
            blur_radius: blur radius used in ``"cam"`` mode.

        Returns:
            A new PIL image with the size of *raw_image*.
        """
        foreground = np.array(raw_image)
        height, width = foreground.shape[:2]

        if isinstance(background_image, Image.Image):
            # isinstance also accepts file-backed subclasses (PNG, JPEG, ...).
            backdrop = background_image
            wanted_mode = getattr(raw_image, "mode", backdrop.mode)
            if backdrop.mode != wanted_mode:
                backdrop = backdrop.convert(wanted_mode)
            if backdrop.size != (width, height):
                backdrop = backdrop.resize((width, height))
            canvas = np.array(backdrop)
        elif isinstance(background_image, str) and background_image == "cam":
            canvas = np.array(
                raw_image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            )
        else:
            canvas = np.zeros_like(foreground)

        if len(masks) > 0:
            # Union of all instance masks.
            keep = np.max(masks, axis=0) == 1
            if keep.shape != (height, width):
                # Bring the mask to the image resolution without interpolating
                # between "keep" and "drop".
                scaled = Image.fromarray(keep.astype(np.uint8) * 255).resize(
                    (width, height), Image.NEAREST
                )
                keep = np.array(scaled) > 0
            # Broadcast the 2-D mask over the channel axis when there is one.
            selector = keep if foreground.ndim == 2 else keep[..., np.newaxis]
            canvas = np.where(selector, foreground, canvas)

        return Image.fromarray(canvas)

    def getFont(self):
        """Return a 20 pt Arial font, or PIL's default font if Arial cannot be loaded."""
        try:
            return ImageFont.truetype("arial.ttf", size=20)
        except IOError:
            return ImageFont.load_default()

    # ------------------------------------------------------------------ #
    # Still images
    # ------------------------------------------------------------------ #
    def Back_step1(self, raw_image: "Image.Image", background_image: "Image.Image", blur_radius=23):
        """Segment *raw_image*, keep every detected object and annotate it.

        Each detection is outlined in red and captioned with its display name
        followed by its detection index.

        Args:
            raw_image: PIL image to process.
            background_image: PIL image, ``"cam"`` or ``"remove"``
                (see ``putMaskImage``).
            blur_radius: blur radius used in ``"cam"`` mode.

        Returns:
            ``(image, labels)`` -- the annotated composite at the original
            size and the display names in detection order (empty when nothing
            was detected).
        """
        org_size = raw_image.size
        raw_image = raw_image.resize(self._WORK_SIZE)
        background_image = self._fit_background(background_image)

        # Start numbering afresh so names from a previous image do not leak in.
        self.segment_names = {}

        results, class_names = self.predict_image(raw_image)
        detection = results[0]

        outpt = self.putMaskImage(
            raw_image, self._mask_list(detection), background_image, blur_radius
        )

        font = self.getFont()
        draw = ImageDraw.Draw(outpt)
        label_counter = []

        boxes = detection.boxes
        if boxes is not None:
            corners = boxes.xyxy.cpu().numpy()
            for seg_id, (box, label) in enumerate(zip(corners, self._class_ids(detection))):
                current_label = self.assign_segment_name(class_names[int(label)], seg_id)
                x1, y1, x2, y2 = map(int, box)
                draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
                draw.text((x1, y1), current_label + " " + str(seg_id), fill="black", font=font)
                label_counter.append(current_label)

        return outpt.resize(org_size), label_counter

    def Back_step2(self, raw_image: "Image.Image", background_image: "Image.Image",
                   things_replace: list, blur_radius=23):
        """Keep only the detections whose display name is in *things_replace*.

        Display names are generated exactly as in ``Back_step1`` (same model,
        same detection order), so the labels returned by that method can be
        passed back here.

        Returns:
            The composite image at the original size. When no detection is
            selected the result is the background alone.
        """
        org_size = raw_image.size
        raw_image = raw_image.resize(self._WORK_SIZE)
        background_image = self._fit_background(background_image)

        # Same fresh numbering as Back_step1 so both steps agree on names.
        self.segment_names = {}

        results, class_names = self.predict_image(raw_image)
        detection = results[0]

        masks = []
        pairs = zip(self._mask_list(detection), self._class_ids(detection))
        for seg_id, (segm, label) in enumerate(pairs):
            current_label = self.assign_segment_name(class_names[int(label)], seg_id)
            if current_label in things_replace:
                masks.append(segm)

        masked_image = self.putMaskImage(raw_image, masks, background_image, blur_radius)
        return masked_image.resize(org_size)

    def get_labels(self, kind_back):
        """Return the class names of every group mentioned in *kind_back*.

        *kind_back* is tested with ``in`` for ``'person'``, ``'animal'`` and
        ``'drive'``; the matching groups are concatenated in that order into
        a new list.
        """
        selected = []
        groups = (
            ('person', self.person),
            ('animal', self.animal),
            ('drive', self.drive),
        )
        for group_name, members in groups:
            if group_name in kind_back:
                selected.extend(members)
        return selected

    # ------------------------------------------------------------------ #
    # Video
    # ------------------------------------------------------------------ #
    def Back_video(self, video_path, output_path, background_image, kind_back, blur_radius=35):
        """Apply the background replacement to every frame of a video.

        Args:
            video_path: path of the input video.
            output_path: path of the XVID-encoded output video.
            background_image: PIL image, ``"cam"`` or ``"remove"``
                (see ``putMaskImage``).
            kind_back: groups of classes to keep (see ``get_labels``).
            blur_radius: blur radius used in ``"cam"`` mode.

        Raises:
            OSError: if *video_path* cannot be opened.
            FileNotFoundError: if the ``ffmpeg`` executable is not available.
        """
        cap = cv2.VideoCapture(video_path)
        out = None
        try:
            if not cap.isOpened():
                raise OSError(f"Could not open video: {video_path}")

            frame_size = (
                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            )
            # Keep fractional rates (e.g. 29.97): truncating to int changes
            # the playback speed of the output.
            fps = cap.get(cv2.CAP_PROP_FPS)

            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

            background_image = self._fit_background(background_image)
            # Invariant across frames, so computed once.
            things_replace = set(self.get_labels(kind_back))

            # The audio track is extracted into a private scratch directory so
            # no file in the working directory is touched, and clean-up cannot
            # fail when ffmpeg produced nothing (e.g. a video without audio).
            # NOTE(review): the extracted audio is never muxed into
            # output_path -- the written video is silent. Confirm whether a
            # mux step is missing.
            with tempfile.TemporaryDirectory() as scratch_dir:
                sound_tmp_file = os.path.join(scratch_dir, 'audio.mp3')
                subprocess.run(
                    ['ffmpeg', '-i', video_path, '-q:a', '0', '-map', 'a', sound_tmp_file],
                    check=False,
                )

                frame_number = 0
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    frame_rgb = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    org_size = frame_rgb.size
                    frame_rgb = frame_rgb.resize(self._WORK_SIZE)

                    results, class_names = self.predict_image(frame_rgb)
                    detection = results[0]

                    masks = [
                        segm
                        for segm, label in zip(self._mask_list(detection),
                                               self._class_ids(detection))
                        if class_names[int(label)] in things_replace
                    ]

                    masked_image = self.putMaskImage(
                        frame_rgb, masks, background_image, blur_radius
                    )
                    out.write(
                        cv2.cvtColor(np.array(masked_image.resize(org_size)), cv2.COLOR_RGB2BGR)
                    )

                    frame_number += 1
                    print(f"Completed frame {frame_number} ")

            print("Finished frames")
        finally:
            # Always hand the capture and writer back, even if a frame failed.
            cap.release()
            if out is not None:
                out.release()