|
import sys |
|
sys.path.append('DenseMammogram') |
|
|
|
import torch |
|
|
|
from models import get_FRCNN_model, Bilateral_model |
|
|
|
# Checkpoint files for the two networks, relative to the working directory.
FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The bilateral model is constructed from the FRCNN instance, so the FRCNN
# model has to exist (and sit on the target device) first.
frcnn_model = get_FRCNN_model().to(device)
bilat_model = Bilateral_model(frcnn_model).to(device)

# Restore the weights in the same order as before: FRCNN checkpoint first,
# bilateral checkpoint second. `map_location` remaps the stored tensors onto
# the selected device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
for _module, _checkpoint in ((frcnn_model, FRCNN_PATH), (bilat_model, BILAR_PATH)):
    _module.load_state_dict(torch.load(_checkpoint, map_location=device))
|
|
|
import os |
|
import torchvision.transforms as T |
|
import cv2 |
|
from tqdm import tqdm |
|
import detection.transforms as transforms |
|
from dataloaders import get_direction |
|
|
|
def _read_image(path):
    """Load an image with OpenCV, raising instead of returning ``None``.

    ``cv2.imread`` signals failure by returning ``None``; turning that into an
    exception here keeps the error next to its cause instead of surfacing
    later inside the tensor conversion.

    Raises:
        FileNotFoundError: ``path`` does not point to an existing file.
        ValueError: the file exists but OpenCV could not decode it.
    """
    pixels = cv2.imread(path)
    if pixels is None:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image file not found: {path!r}")
        raise ValueError(f"Could not decode image file: {path!r}")
    return pixels


def predict(left_file, right_file, threshold = 0.80, baseIsLeft = True):
    """Run the bilateral model on an image pair and draw the detections.

    Args:
        left_file: Path of the left image. The detections are drawn on this
            image and it is the one returned.
        right_file: Path of the right image, fed to the model together with
            the left one.
        threshold: A detection is drawn only when its score is strictly
            greater than this value.
        baseIsLeft: When true the left image is mirrored horizontally before
            inference and the model output is mirrored again afterwards;
            when false the right image is mirrored instead and the output is
            used as returned by the model.

    Returns:
        The left image as loaded by ``cv2.imread``, annotated with a green
        rectangle and a ``Cancer: <percent>%`` caption for every detection
        whose label equals 1 and whose score exceeds ``threshold``.

    Raises:
        FileNotFoundError: one of the image paths does not exist.
        ValueError: one of the image files could not be decoded.
    """
    to_tensor = T.Compose([T.ToPILImage(), T.ToTensor()])
    mirror = transforms.RandomHorizontalFlip(1.0)  # p=1.0: always flips

    # Read (and validate) both inputs up front. The decoded left image is
    # reused below as the drawing canvas, so it is not read from disk twice.
    left_pixels = _read_image(left_file)
    right_pixels = _read_image(right_file)

    bilat_model.eval()
    with torch.no_grad():
        img1 = to_tensor(left_pixels)
        img2 = to_tensor(right_pixels)

        # NOTE(review): presumably the model expects both views to face the
        # same direction, hence exactly one of them is mirrored - confirm.
        if baseIsLeft:
            img1, _ = mirror(img1)
        else:
            img2, _ = mirror(img2)

        # The model takes a batch of image pairs; take the result of the
        # single pair submitted here.
        output = bilat_model([[img1.to(device), img2.to(device)]])[0]

        if baseIsLeft:
            # Mirror the output a second time so it lines up with the
            # un-mirrored left image that is drawn on below.
            _, output = mirror(img1, output)

    image = left_pixels
    for box, score, label in zip(output['boxes'], output['scores'], output['labels']):
        if label == 1 and score > threshold:
            # Plain Python ints are what OpenCV's drawing functions expect.
            x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist()
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            caption = 'Cancer: ' + str(round(round(score.item(), 2) * 100, 1)) + '%'
            cv2.putText(image, caption, (x1, y1 - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.6, (36, 255, 12), 6)
    return image