# corn_demo / app.py
# Author: swel4ik
# Last commit: 2fb32ef ("Update app.py")
from yolact.yolact import Yolact
from yolact.utils.augmentations import BaseTransform, FastBaseTransform, Resize
from yolact.utils import timer
from yolact.layers.output_utils import postprocess, undo_image_transformation
from yolact.data import cfg, set_cfg, set_dataset
import torch
import cv2
import gradio as gr
import numpy as np
import random
# Choose the compute device once at import time: the GPU when torch can see
# one, otherwise the CPU.
use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda") if use_cuda else torch.device("cpu")
def apply_mask(image, mask, color, alpha=0.8):
    """Blend a flat colour into ``image`` wherever ``mask`` equals 1.

    Only the first three channels are touched, and the blend is written back
    into ``image`` in place. Each colour component is multiplied by 255, so
    components are expected in the 0-1 range; larger values can exceed the
    value range of an integer image.

    Args:
        image: array indexed as ``image[row, col, channel]``.
        mask: array compared element-wise against ``image[:, :, channel]``;
            only entries equal to 1 are painted.
        color: indexable with at least three components.
        alpha: weight of ``color`` in the blend (0 keeps the pixel, 1
            replaces it with the colour).

    Returns:
        The same ``image`` object, after modification.
    """
    selected = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str='',
                 score_threshold=0.99):
    """Run YOLACT post-processing on raw network output.

    Despite the name, nothing is drawn here: the function only resolves the
    output size and returns whatever ``postprocess`` returns. The caller in
    this file unpacks that as ``(classes, scores, boxes, masks)``.

    Args:
        dets_out: raw output of the YOLACT network.
        img: the network input image; only consulted (as
            ``h, w, _ = img.shape``) when ``undo_transform`` is False.
        h, w: output height and width. They may be None when
            ``undo_transform`` is False, because they are then read from
            ``img``.
        undo_transform: when True, use the ``h``/``w`` passed in.
        class_color, mask_alpha, fps_str: unused; kept so existing callers
            keep working.
        score_threshold: forwarded to ``postprocess`` as its score cutoff
            (defaults to the previously hard-coded 0.99).

    Returns:
        The value returned by ``postprocess``.
    """
    if not undo_transform:
        h, w, _ = img.shape

    with timer.env('Postprocess'):
        # Force bbox rescoring for this call only. Restore the global config
        # even if postprocess raises, so later calls are not affected.
        saved_rescore = cfg.rescore_bbox
        cfg.rescore_bbox = True
        try:
            return postprocess(dets_out, w, h, visualize_lincomb=False,
                               crop_masks=True,
                               score_threshold=score_threshold)
        finally:
            cfg.rescore_bbox = saved_rescore
def _load_net(weights_path='./yolact/weights/yolact_base_1351_50000.pth'):
    """Build a YOLACT model, load the checkpoint at ``weights_path`` and switch it to eval mode."""
    model = Yolact()
    # map_location remaps the stored tensors onto DEVICE (the CPU when no GPU is visible).
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.eval()
    return model


net = _load_net()
def detect_corn(inp):
    """Detect objects in ``inp`` with the YOLACT net and return an annotated preview.

    Each detection is drawn as a randomly coloured bounding box, and its
    instance mask is blended in at 50 % opacity. The result is shrunk to a
    quarter of the input size.

    Args:
        inp: image array delivered by the Gradio ``Image(type="numpy")``
            input. NOTE(review): presumably RGB uint8 of shape (H, W, 3);
            confirm for the deployed Gradio version.

    Returns:
        A float image (values divided by 255), in the same channel order as
        ``inp``, at 1/4 of the input width and height.
    """
    # Swap the channel order for the network and OpenCV drawing. The same
    # swap is applied again before returning, which restores the input order.
    inp = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
    frame = torch.from_numpy(inp).cpu().float()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        batch = FastBaseTransform()(frame.unsqueeze(0))
        preds = net(batch)
        _classes, _scores, boxes, masks = prep_display(
            preds, frame, None, None, undo_transform=False)

    # TODO: suppress redundant overlapping boxes.
    # Draw only when masks is an (N, H, W) stack. NOTE(review): yolact's
    # postprocess appears to return empty tensors when nothing passes the
    # score threshold; in that case there is nothing to draw.
    if masks.dim() == 3:
        # Keep the outline at least 1 px thick on small images.
        thickness = max(1, int(inp.shape[0] * 0.01))
        for bbox, obj_mask in zip(boxes, masks):
            x1, y1, x2, y2 = (int(v) for v in bbox)
            color = random.sample(range(0, 255), 3)
            cv2.rectangle(inp, (x1, y1), (x2, y2), color, thickness)
            # apply_mask multiplies colour components by 255, so they must be
            # in [0, 1]. (0, 0, 1) fills the last channel, matching the
            # intent of the former [0, 0, 255], which overflowed uint8 pixels.
            inp = apply_mask(inp, obj_mask.detach().cpu().numpy(), (0, 0, 1), 0.5)

    inp = cv2.resize(inp, (max(1, inp.shape[1] // 4), max(1, inp.shape[0] // 4)))
    inp = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
    return inp / 255.
# Wire detect_corn into an image-in / image-out Gradio UI and start serving it.
iface = gr.Interface(
    fn=detect_corn,
    inputs=gr.inputs.Image(type="numpy"),
    outputs="image",
)
iface.launch()