import os

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as T

from model import get_model
from helper import *

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

name2idx = {'spatter': 1, 'undercut': 2}
idx2name = {v: k for k, v in name2idx.items()}

# Side length (pixels) of the square image fed to the network. transform_img
# uses the same value to compute the scale back to the original resolution.
INPUT_SIZE = 640

transform_norm = T.Compose([
    T.Resize((INPUT_SIZE, INPUT_SIZE)),
    T.ToTensor(),
])

# Seed the global NumPy RNG so every run assigns the same random colour to
# each class index.
np.random.seed(1)
class_to_color = {name2idx[v]: np.random.randint(0, 255, 3) for v in name2idx}

# Element-wise mapping from class indices to class names.
to_name = np.vectorize(lambda x: idx2name[x])

model_path = os.path.join("model")


def load_model(name):
    """Build the model via ``get_model`` and load weights from ``model_path/name``.

    The model is moved to ``dev`` and put in eval mode before being returned.
    """
    model = get_model()
    model.to(dev)
    state_dict = torch.load(os.path.join(model_path, name), map_location=dev)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def transform_img(img):
    """Resize *img* to the network input size and add a batch dimension.

    Assumes *img* is a PIL image (``img.size`` is read as ``(width, height)``).

    Returns:
        ``(batch, (scale_x, scale_y))``. ``batch`` is the output of
        ``transform_norm`` with a leading batch axis of size 1. The scales
        map coordinates in the resized image back to the original image.
    """
    original_width, original_height = img.size
    scale_x = original_width / INPUT_SIZE
    scale_y = original_height / INPUT_SIZE

    batch = transform_norm(img).unsqueeze(dim=0)
    return batch, (scale_x, scale_y)


def predict_image(model, img, scale, iou_threshold=0.3):
    """Run *model* on one preprocessed image and post-process the detections.

    Args:
        model: Detection model. ``model(img)[0]`` must be a dict with
            ``'boxes'``, ``'scores'`` and ``'labels'`` tensors.
        img: Batch tensor as produced by ``transform_img``. It is moved to
            the device of the model's parameters before inference.
        scale: ``(scale_x, scale_y)`` as returned by ``transform_img``.
        iou_threshold: IoU threshold for non-maximum suppression.

    Returns:
        A dict of NumPy arrays:

        - ``"boxes"``: int32 ``x1, y1, x2, y2`` in original-image pixels.
        - ``"scores"``: int32 percentages, truncated.
        - ``"labels"``: the class labels.

        When nothing is detected the arrays are empty (``boxes`` has shape
        ``(0, 4)``).
    """
    scale_x, scale_y = scale

    # Run on whatever device the model lives on. Without this, a CPU input
    # tensor fails against a model that load_model moved to CUDA.
    param = next(model.parameters(), None)
    if param is not None:
        img = img.to(param.device)

    with torch.no_grad():
        outputs = model(img)[0]

    # Move results to the CPU: ``.numpy()`` only works on CPU tensors, and the
    # scale tensor below is created on the CPU.
    boxes = outputs['boxes'].cpu()
    scores = outputs['scores'].cpu()
    labels = outputs['labels'].cpu()

    # NOTE(review): this NMS is class-agnostic, so overlapping boxes of
    # different classes can suppress each other. If per-class suppression
    # is intended, use torchvision.ops.batched_nms instead.
    keep = torchvision.ops.nms(boxes, scores, iou_threshold)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    box_scale = torch.tensor([scale_x, scale_y, scale_x, scale_y])
    return {
        "boxes": (boxes * box_scale).int().numpy(),
        "scores": (scores * 100).int().numpy(),
        "labels": labels.numpy(),
    }


def create_img_pred(img_path, result):
    """Draw the detections in *result* onto the image at *img_path*.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        result: Dict returned by ``predict_image``.

    Returns:
        The image array read by OpenCV. Each detection gets a rectangle in
        its class colour and a ``"<class name>: <score>"`` label.

    Raises:
        FileNotFoundError: If OpenCV cannot read *img_path*.
    """
    bboxes, probs, classes = result["boxes"], result["scores"], result["labels"]

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {img_path}")

    # Preprocess the image; only the ratio it returns is used here.
    _, ratio = format_img(img)
    # NOTE(review): predict_image already scales boxes back to the original
    # resolution. Confirm that get_real_coordinates(ratio, ...) does not
    # rescale them a second time.

    font = cv2.FONT_HERSHEY_DUPLEX
    for bbox, prob, cls in zip(bboxes, probs, classes):
        label = int(cls)
        x1, y1, x2, y2 = bbox
        real_x1, real_y1, real_x2, real_y2 = get_real_coordinates(ratio, x1, y1, x2, y2)

        color = tuple(int(c) for c in class_to_color[label])
        cv2.rectangle(img, (real_x1, real_y1), (real_x2, real_y2), color, 2)

        text_label = '{}: {}'.format(idx2name[label], prob)
        # Measure with the same font that putText draws with, so the
        # background box matches the rendered text.
        (text_w, text_h), baseline = cv2.getTextSize(text_label, font, 1, 1)
        text_org = (real_x1, real_y1)
        top_left = (text_org[0] - 5, text_org[1] + baseline - 5)
        bottom_right = (text_org[0] + text_w + 5, text_org[1] - text_h - 5)
        # Black outline first, then a filled white background behind the text.
        cv2.rectangle(img, top_left, bottom_right, (0, 0, 0), 2)
        cv2.rectangle(img, top_left, bottom_right, (255, 255, 255), -1)
        cv2.putText(img, text_label, text_org, font, 1, (0, 0, 0), 1)
    return img