import random

import cv2
import gradio as gr
import numpy as np
import torch

from yolact.yolact import Yolact
from yolact.utils.augmentations import BaseTransform, FastBaseTransform, Resize
from yolact.utils import timer
from yolact.layers.output_utils import postprocess, undo_image_transformation
from yolact.data import cfg, set_cfg, set_dataset

use_cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda' if use_cuda else 'cpu')


def apply_mask(image, mask, color, alpha=0.8):
    """Blend a solid colour into ``image`` wherever ``mask == 1``.

    The image is modified in place and also returned.

    Args:
        image: H x W x 3 array; written back channel by channel.
        mask: H x W array; pixels equal to 1 receive the colour.
        color: three per-channel intensities in [0, 1] (scaled by 255 here).
        alpha: weight of the colour in the blend.
    """
    for c in range(3):
        blended = image[:, :, c] * (1 - alpha) + alpha * color[c] * 255
        # Clip so out-of-range colours cannot wrap around when the float
        # result is stored back into an integer image.
        image[:, :, c] = np.where(mask == 1, np.clip(blended, 0, 255), image[:, :, c])
    return image


def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False,
                 mask_alpha=0.45, fps_str='', score_threshold=0.99):
    """Run YOLACT ``postprocess`` on raw network output and return its result.

    Despite the name, nothing is drawn here; the caller unpacks the returned
    value as ``(classes, scores, boxes, masks)``.

    Args:
        dets_out: raw output of the network's forward pass.
        img: input frame; only its shape is read when ``undo_transform`` is
            False.
        h, w: output height/width. May be None when ``undo_transform`` is
            False, in which case they are taken from ``img.shape``.
        class_color, mask_alpha, fps_str: unused; kept for call compatibility.
        score_threshold: detections scoring at or below this are dropped.
    """
    if not undo_transform:
        h, w, _ = img.shape

    with timer.env('Postprocess'):
        saved_rescore = cfg.rescore_bbox
        cfg.rescore_bbox = True
        try:
            return postprocess(dets_out, w, h, visualize_lincomb=False,
                               crop_masks=True, score_threshold=score_threshold)
        finally:
            # Always restore the global config, even if postprocess raises.
            cfg.rescore_bbox = saved_rescore


net = Yolact()
net.load_state_dict(torch.load('./yolact/weights/yolact_base_1351_50000.pth', map_location=DEVICE))
net.eval()


def detect_corn(inp):
    """Detect objects in ``inp`` and return an annotated preview image.

    Each detection gets a randomly coloured bounding box and a red mask
    overlay. The result is downscaled 4x and returned as floats in [0, 1].

    Args:
        inp: image array from Gradio (presumably H x W x 3 uint8 RGB; confirm).
    """
    # Swap the channel order (RGB -> BGR); presumably the order the YOLACT
    # transform expects. The swap is undone before returning.
    inp = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
    print(inp.shape)
    frame = torch.from_numpy(inp).cpu().float()
    batch = FastBaseTransform()(frame.unsqueeze(0))
    with torch.no_grad():  # inference only; no autograd graph needed
        preds = net(batch)
    classes, scores, box, mask = prep_display(preds, frame, None, None, undo_transform=False)

    # postprocess returns empty tensors when no detection passes the score
    # threshold, so only draw when there is at least one box.
    if box.numel() > 0:
        mask_to_save = mask.permute(1, 2, 0).cpu().numpy()
        thickness = max(1, int(inp.shape[0] * 0.01))
        # TODO: suppress redundant overlapping boxes.
        for i, bbox in enumerate(box):
            x1, y1, x2, y2 = (int(v) for v in bbox.tolist())
            color = tuple(random.sample(range(0, 255), 3))
            cv2.rectangle(inp, (x1, y1), (x2, y2), color, thickness)
            # apply_mask expects colour channels in [0, 1]; channel 2 is red
            # in the swapped (BGR) order.
            inp = apply_mask(inp, mask_to_save[:, :, i], (0, 0, 1), 0.5)

    inp = cv2.resize(inp, (max(1, inp.shape[1] // 4), max(1, inp.shape[0] // 4)))
    inp = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
    print(inp.shape)
    return inp / 255.


iface = gr.Interface(detect_corn, gr.inputs.Image(type="numpy"), "image")
iface.launch()