# microScan / modInference.py
# Author: crazyscientist
# Last commit: 6ce5c48 "Update modInference.py" (unverified signature), 2.88 kB
import numpy as np
import cv2
import torch
import glob as glob
import os
import matplotlib.pyplot as plt
from models.create_fasterrcnn_model import create_model
from utils.annotations import CNNpostAnnotations
#from utils.annotations import inference_annotations
from utils.general import set_infer_dir
from utils.transforms import infer_transforms
import numpy as np
from skimage import transform
import os
from keras.models import Model
from keras.optimizers import Adam
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.layers import Dense, Dropout, Flatten
import numpy as np
# Fine-tuned VGG16 classifier that main() passes to CNNpostAnnotations for the
# detected cell crops. It is built and its weights are loaded once, at import time.

# Number of trailing VGG16 layers left trainable; every earlier layer is
# frozen. This only matters if the model is trained further -- it has no
# effect on inference results.
N_TRAINABLE_TAIL_LAYERS = 2

# ImageNet weights are deliberately not downloaded: CNN.load_weights() below
# restores every weighted layer (a legacy HDF5 weight load raises on a layer
# count mismatch), so pre-trained weights would be overwritten immediately.
conv_base = VGG16(include_top=False,
                  weights=None,
                  input_shape=(200, 200, 3))

# Use an explicit end index instead of `[:-n]`: `[:-0]` is an empty slice and
# would silently freeze nothing when n == 0. Clamp at 0 so an n larger than
# the layer count leaves everything trainable rather than wrapping around.
_n_frozen = max(0, len(conv_base.layers) - N_TRAINABLE_TAIL_LAYERS)
for layer in conv_base.layers[:_n_frozen]:
    layer.trainable = False

top_model = conv_base.output
top_model = Flatten(name="flatten")(top_model)
top_model = Dense(4096, activation='relu')(top_model)
# NOTE(review): 1048 looks like a typo for 1024, but the layer sizes must
# match the weights stored in CNN.hdf5 -- do not change without retraining.
top_model = Dense(1048, activation='relu')(top_model)
top_model = Dense(256, activation='relu')(top_model)
top_model = Dense(128, activation='relu')(top_model)
top_model = Dense(64, activation='relu')(top_model)
top_model = Dropout(0.2)(top_model)
# 5-way softmax head; the meaning of each class index is defined by whatever
# produced CNN.hdf5 and is not visible here.
output_layer = Dense(5, activation='softmax')(top_model)
CNN = Model(inputs=conv_base.input, outputs=output_layer)
# Relative path: resolved against the process's current working directory.
CNN.load_weights("CNN.hdf5")
def _load_detector(weight_path, device):
    """Rebuild the Faster R-CNN described by a training checkpoint.

    Args:
        weight_path: Anything ``torch.load`` accepts. The checkpoint must hold
            ``model_name`` (a key of ``create_model``) and
            ``model_state_dict``, and may hold ``config`` with ``NC`` and
            ``CLASSES``.
        device: ``torch.device`` to load the tensors onto and run the model on.

    Returns:
        ``(model, classes)``: the model on *device* in eval mode, and the class
        name list.
    """
    checkpoint = torch.load(weight_path, map_location=device)
    # Fall back to the single "Cell" class setup when the checkpoint has no
    # config section (previously this raised KeyError).
    config = checkpoint.get('config', {})
    num_classes = config.get('NC', 2)
    classes = config.get('CLASSES', ['__background__', 'Cell'])

    build_model = create_model[checkpoint['model_name']]
    model = build_model(num_classes=num_classes, coco_model=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()
    return model, classes


def main(weightUrl, input):
    """Detect cells in one image and classify each detection with ``CNN``.

    Args:
        weightUrl: Path to the Faster R-CNN training checkpoint (see
            ``_load_detector`` for the expected keys).
        input: Image fed to ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``;
            presumably an OpenCV-style BGR ``numpy`` array -- confirm with
            callers. The name shadows the builtin but is kept for callers
            that pass it by keyword.

    Returns:
        ``(annotated_image, cell_images)`` unpacked from ``CNNpostAnnotations``
        (score threshold 0.3, white annotations) when at least one box is
        detected; ``None`` when nothing is detected.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the returned directory was never used here. The call is
    # kept only in case set_infer_dir() has side effects callers rely on
    # (e.g. creating an output folder) -- confirm, and drop it if not.
    set_infer_dir()
    model, classes = _load_detector(weightUrl, device)

    orig_image = input.copy()
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    image = torch.unsqueeze(infer_transforms(image), 0)  # add batch dim

    # Pure inference: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = model(image.to(device))
    # Move all detections to the CPU for post-processing.
    outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
    print(outputs)

    if len(outputs[0]['boxes']) == 0:
        # Nothing detected: no crops to classify.
        return None

    orig_image, cellImgs = CNNpostAnnotations(
        outputs, 0.3, classes,
        (255, 255, 255), orig_image, CNN
    )
    return orig_image, cellImgs