from matplotlib import cm
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
# import onnx
import onnxruntime as ort
# from onnx import helper
# from optimum.onnxruntime import ORTModel
# import pandas as pd
from PIL import Image
from scipy import special
# import torch
# import torch.utils.data
import gradio as gr
# from transformers import pipeline
# model_path = 'chlab/planet_detection_models/'
model_path = './models/'
# plotting parameters
labels = 20
ticks = 14
legends = 14
text = 14
titles = 22
lw = 3
ps = 200
cmap = 'magma'

def normalize_array(x: list):
    '''Scales an array so that its values span the range [0, 1]'''
    x = np.array(x)
    return (x - np.min(x)) / np.max(x - np.min(x))
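
# Quick illustration (not part of the app's control flow): the minimum is shifted
# to zero and the result is rescaled by the range, e.g.
#   normalize_array([2.0, 4.0, 8.0])  ->  array([0.        , 0.33333333, 1.        ])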

def load_model(model: str, activation: bool=True):
    '''Loads an ONNX model from model_path as an onnxruntime CPU inference session'''
    if activation:
        model += '_w_activation'

    options = ort.SessionOptions()
    options.intra_op_num_threads = 1
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    provider = "CPUExecutionProvider"

    ort_session = ort.InferenceSession(model_path + '%s.onnx' % (model), options, providers=[provider])
    # ort_session = ORTModel.load_model(model_path + '%s.onnx' % (model))

    return ort_session
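
# Hypothetical usage sketch (model names follow the pattern used in
# predict_and_analyze below, e.g. 'efficientnet_61' resolves to
# './models/efficientnet_61_w_activation.onnx' when activation=True):
#   session = load_model('efficientnet_61')
#   input_name = session.get_inputs()[0].name
#   outputs = session.run(None, {input_name: image})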

def get_activations(intermediate_model, image: list,
                    layer=None, vmax=2.5, sub_mean=True):
    '''Gets the class prediction and intermediate activations for a given input image'''

    # run the ONNX session: the first output holds the class scores,
    # the remaining outputs are intermediate activation maps
    input_name = intermediate_model.get_inputs()[0].name
    outputs = intermediate_model.run(None, {input_name: image})
    # outputs = intermediate_model(image)

    output_1 = outputs[1]
    output_2 = outputs[2]

    output = outputs[0][0]
    output = special.softmax(output)

    # origin = 'lower'
    # plt.rcParams['xtick.labelsize'] = ticks
    # plt.rcParams['ytick.labelsize'] = ticks
    # fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(28, 8))
    # ax1, ax2, ax3 = axs[0], axs[1], axs[2]

    # collapse the velocity channels of the input into a single normalized image
    in_image = np.sum(image[0, :, :, :], axis=0)
    in_image = normalize_array(in_image)

    # im1 = ax1.imshow(in_image, cmap=cmap, vmin=0, vmax=vmax, origin=origin)

    # sum over all feature maps unless a specific layer (channel) is requested
    if layer is None:
        activation_1 = np.sum(output_1[0, :, :, :], axis=0)
        activation_2 = np.sum(output_2[0, :, :, :], axis=0)
    else:
        activation_1 = output_1[0, layer, :, :]
        activation_2 = output_2[0, layer, :, :]

    # optionally subtract the mean and take the absolute value to highlight
    # deviations from the average response
    if sub_mean:
        activation_1 -= np.mean(activation_1)
        activation_1 = np.abs(activation_1)

        activation_2 -= np.mean(activation_2)
        activation_2 = np.abs(activation_2)

    # im2 = ax2.imshow(activation_1, cmap=cmap, #vmin=0, vmax=1,
    #                  origin=origin)
    # im3 = ax3.imshow(activation_2, cmap=cmap, #vmin=0, vmax=1,
    #                  origin=origin)

    # ims = [im1, im2, im3]
    # for (i, ax) in enumerate(axs):
    #     divider = make_axes_locatable(ax)
    #     cax = divider.append_axes('right', size='5%', pad=0.05)
    #     fig.colorbar(ims[i], cax=cax, orientation='vertical')

    # ax1.set_title('Input', fontsize=titles)
    # plt.show()

    return output, in_image, activation_1, activation_2

def predict_and_analyze(model_name, num_channels, dim, image):
    '''Loads a model with activation outputs, runs the image through it, and visualizes the activations

    The image must be a numpy array of shape (C, W, W) or (1, C, W, W)
    '''
    num_channels = int(num_channels)
    W = int(dim)

    # image = image.read()
    # with open(image, 'rb') as f:
    #     im = f.readlines()
    # image = np.frombuffer(image)

    print("Loading data")
    image = np.load(image.name, allow_pickle=True)
    # image = image.reshape((num_channels, W, W))
    # W = int(np.sqrt(image.shape[1]))
    # image = image.reshape((num_channels, W, W))

    # add a batch axis if the uploaded array is (C, W, W)
    if len(image.shape) != 4:
        image = image[np.newaxis, :, :, :]

    assert image.shape == (1, num_channels, W, W), "Data is the wrong shape"

    model_name += '_%i' % (num_channels)

    print("Loading model")
    model = load_model(model_name, activation=True)
    print("Model loaded")

    print("Looking at activations")
    output, input_image, activation_1, activation_2 = get_activations(model, image, sub_mean=True)
    print("Activations and predictions finished")

    # output holds the softmaxed class scores: index 1 is the planet class,
    # index 0 is the no-planet class
    if output[0] < output[1]:
        output = 'Planet predicted with %.1f percent confidence' % (100 * output[1])
    else:
        output = 'No planet predicted with %.1f percent confidence' % (100 * output[0])

    input_image = normalize_array(input_image)
    activation_1 = normalize_array(activation_1)
    activation_2 = normalize_array(activation_2)

    # apply the magma colormap to turn the input into an RGB(A) image for display
    input_image = Image.fromarray(np.uint8(cm.magma(input_image) * 255))

    print("Plotting")
    origin = 'lower'
    plt.rcParams['xtick.labelsize'] = ticks
    plt.rcParams['ytick.labelsize'] = ticks

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(19, 8))
    ax1, ax2 = axs[0], axs[1]

    im1 = ax1.imshow(activation_1, cmap=cmap, #vmin=0, vmax=1,
                     origin=origin)
    im2 = ax2.imshow(activation_2, cmap=cmap, #vmin=0, vmax=1,
                     origin=origin)

    ims = [im1, im2]
    for (i, ax) in enumerate(axs):
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(ims[i], cax=cax, orientation='vertical')

    ax1.set_title('First Activation', fontsize=titles)
    ax2.set_title('Second Activation', fontsize=titles)

    print("Sending to Hugging Face")

    return output, input_image, fig
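
# Hypothetical example of preparing an input file for the interface below; the
# values are random placeholders and only the shape matters for the assertion in
# predict_and_analyze (here 61 velocity channels of 600x600 pixels):
#   cube = np.random.rand(61, 600, 600).astype(np.float32)
#   np.save('example_input.npy', cube)
# The saved .npy file can then be uploaded through the "Input Data" file field.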

if __name__ == "__main__":

    demo = gr.Interface(
        fn=predict_and_analyze,
        inputs=[gr.Dropdown(["regnet", "efficientnet"],
                            value="efficientnet",
                            label="Model Selection",
                            show_label=True),
                gr.Dropdown(["45", "61", "75"],
                            value="61",
                            label="Number of Velocity Channels",
                            show_label=True),
                gr.Dropdown(["600"],
                            value="600",
                            label="Image Dimensions",
                            show_label=True),
                gr.File(label="Input Data", show_label=True)],
        outputs=[gr.Textbox(lines=1, label="Prediction", show_label=True),
                 gr.Image(label="Input Image", show_label=True),
                 # gr.Image(label="Activation 1", show_label=True),
                 # gr.Image(label="Activation 2", show_label=True)],
                 gr.Plot(label="Activations", show_label=True)
                 # gr.Plot(label="Activation 2", show_label=True)],
                 ],
        title="Kinematic Planet Detector"
    )

    demo.launch()
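
# Running this script directly starts the Gradio interface; by default gradio
# serves it locally (typically at http://127.0.0.1:7860). Passing share=True to
# demo.launch() would expose a temporary public link if one is needed.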