# Spaces: Running
import numpy as np | |
import cv2 | |
import torch | |
import glob as glob | |
import os | |
import matplotlib.pyplot as plt | |
from models.create_fasterrcnn_model import create_model | |
from utils.annotations import CNNpostAnnotations | |
#from utils.annotations import inference_annotations | |
from utils.general import set_infer_dir | |
from utils.transforms import infer_transforms | |
import numpy as np | |
from skimage import transform | |
import os | |
from keras.models import Model | |
from keras.optimizers import Adam | |
from keras.applications.vgg16 import VGG16, preprocess_input | |
from keras.layers import Dense, Dropout, Flatten | |
import numpy as np | |
# VGG16-based classifier that ``main`` passes to ``CNNpostAnnotations``.
# It is built and its weights are loaded from CNN.hdf5 once, at import time.

# Number of trailing VGG16 layers left trainable; all earlier layers are frozen.
# Only relevant for training -- at inference the weights come from CNN.hdf5.
CNN_TRAINABLE_TAIL = 2

conv_base = VGG16(include_top=False,
                  weights='imagenet',
                  input_shape=(200, 200, 3))

# Guard the zero case: ``layers[:-0]`` would be an empty slice, which would
# freeze nothing instead of everything.
if CNN_TRAINABLE_TAIL > 0:
    _frozen_layers = conv_base.layers[:-CNN_TRAINABLE_TAIL]
else:
    _frozen_layers = conv_base.layers
for layer in _frozen_layers:
    layer.trainable = False

# Fully connected head on top of the convolutional base.
top_model = conv_base.output
top_model = Flatten(name="flatten")(top_model)
top_model = Dense(4096, activation='relu')(top_model)
# NOTE(review): 1048 looks like a typo for 1024, but every layer width here
# must match the weights stored in CNN.hdf5, so do not change it without
# retraining.
top_model = Dense(1048, activation='relu')(top_model)
top_model = Dense(256, activation='relu')(top_model)
top_model = Dense(128, activation='relu')(top_model)
top_model = Dense(64, activation='relu')(top_model)
top_model = Dropout(0.2)(top_model)
# Softmax over 5 output classes.
output_layer = Dense(5, activation='softmax')(top_model)

CNN = Model(inputs=conv_base.input, outputs=output_layer)
# Path is relative to the process working directory.
CNN.load_weights("CNN.hdf5")
def _load_detector(weight_url, device):
    """Load a detector checkpoint and return it ready for inference.

    Args:
        weight_url: Anything ``torch.load`` accepts (typically a file path). The
            loaded dict must contain ``'config'`` (with ``'NC'`` and
            ``'CLASSES'``), ``'model_name'`` and ``'model_state_dict'``.
        device: ``torch.device`` the checkpoint is mapped to and the model is
            moved to.

    Returns:
        ``(model, classes)``: the model in eval mode on ``device`` and the
        class-name list stored in the checkpoint config.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(weight_url, map_location=device)
    config = checkpoint['config']
    num_classes = config['NC']
    classes = config['CLASSES']
    build_model = create_model[checkpoint['model_name']]
    model = build_model(num_classes=num_classes, coco_model=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()
    return model, classes


def main(weightUrl, input, *, detection_threshold=0.3):
    """Detect cells in one image and annotate them using the ``CNN`` classifier.

    Args:
        weightUrl: Detector checkpoint, as accepted by ``_load_detector``.
        input: Image array, treated as BGR (converted with
            ``cv2.COLOR_BGR2RGB`` before inference). Only a copy is processed;
            ``input`` itself is not modified here.
        detection_threshold: Score threshold forwarded to
            ``CNNpostAnnotations`` (default 0.3, the previously hard-coded value).

    Returns:
        ``(annotated_image, cell_images)`` as produced by ``CNNpostAnnotations``.
        When the detector finds no boxes, returns the untouched copy of
        ``input`` and an empty list.
    """
    # Seed kept so any downstream use of numpy's global RNG stays reproducible.
    np.random.seed(42)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Result unused; called for its side effects (presumably creating the
    # inference output directory -- TODO confirm).
    set_infer_dir()
    model, classes = _load_detector(weightUrl, device)

    orig_image = input.copy()
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    image = infer_transforms(image)
    image = torch.unsqueeze(image, 0)  # add a batch dimension
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(image.to(device))
    # Load all detections to CPU for further operations.
    outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
    print(outputs)

    if len(outputs[0]['boxes']) == 0:
        # No detections: return a well-formed pair instead of falling through
        # with the cell images never assigned.
        return orig_image, []

    orig_image, cell_imgs = CNNpostAnnotations(
        outputs, detection_threshold, classes,
        (255, 255, 255), orig_image, CNN,
    )
    return orig_image, cell_imgs