# microScan / modInference.py
# Author: crazyscientist
# Last commit: 7b9d0de (unverified) - "Update modInference.py" (file size 1.24 kB)
import numpy as np
import cv2
import torch
import glob as glob
import os
import time
import matplotlib.pyplot as plt
from utils.annotations import CNNpostAnnotations
#from utils.annotations import inference_annotations
from utils.transforms import infer_transforms
def _model_device(model, default):
    """Return the device holding ``model``'s first parameter, else ``default``.

    Falls back to ``default`` when ``model`` has no ``parameters()`` method
    (e.g. a plain callable) or owns no parameters at all.
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return default


def main(CNN, model, input, detection_threshold=0.3, box_color=(255, 255, 255)):
    """Run the detector on one BGR image and annotate the detected cells.

    Args:
        CNN: Second model, forwarded unchanged to ``CNNpostAnnotations``
            (presumably a per-cell classifier; confirm against that helper).
        model: Detection model. It is called with a batch of one transformed
            image and is expected to return a list with one dict of tensors
            per image, containing at least ``'boxes'`` (assumed
            torchvision-style detection output; verify against the model).
        input: Image array in OpenCV BGR channel order. It is not modified,
            because annotations are drawn on a copy. The name shadows the
            builtin but is kept so keyword callers keep working.
        detection_threshold: Score threshold passed to ``CNNpostAnnotations``.
        box_color: Annotation colour passed to ``CNNpostAnnotations``.

    Returns:
        ``(annotated_image, cell_images)`` as produced by
        ``CNNpostAnnotations``. When the detector returns no boxes, the
        un-annotated copy of ``input`` is returned together with an empty
        list, instead of failing on an unbound ``cellImgs``.
    """
    # Re-seed NumPy's global RNG on every call so any NumPy randomness used
    # downstream (e.g. annotation colours; assumption) is reproducible.
    np.random.seed(42)
    classes = ['__background__', 'Cell']

    orig_image = input.copy()
    # Input is BGR (OpenCV order); convert to RGB before the transforms.
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    image = infer_transforms(image)
    # Add a leading batch dimension of size 1.
    image = torch.unsqueeze(image, 0)

    # Put the input on the same device as the model. If that cannot be
    # determined, use the previous rule (first CUDA device if available).
    fallback = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    device = _model_device(model, fallback)

    # Inference only. Without no_grad, autograd records the whole forward
    # pass (wasted memory), and the outputs require grad, so calling
    # .numpy() on them would raise.
    with torch.no_grad():
        outputs = model(image.to(device))

    # Load all detections to CPU for further operations.
    outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
    print(outputs)

    # Nothing detected: return the untouched copy and no cell crops.
    if not outputs or len(outputs[0]['boxes']) == 0:
        return orig_image, []

    orig_image, cell_imgs = CNNpostAnnotations(
        outputs, detection_threshold, classes,
        box_color, orig_image, CNN
    )
    return orig_image, cell_imgs
# NOTE(review): nothing visible in this module opens an OpenCV window, so this
# call looks like a leftover from an interactive script. If it sits at module
# level (as displayed), it runs on every import. On headless OpenCV builds it
# may raise, so consider removing it. Confirm no caller relies on it.
cv2.destroyAllWindows()