Spaces:
Runtime error
Runtime error
from matplotlib import cm | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
import numpy as np | |
# import onnx | |
import onnxruntime as ort | |
# from onnx import helper | |
# from optimum.onnxruntime import ORTModel | |
# import pandas as pd | |
from PIL import Image | |
from scipy import special | |
# import torch | |
# import torch.utils.data | |
import gradio as gr | |
# from transformers import pipeline | |
# Directory holding the exported ``<name>.onnx`` model files.
# (Remote alternative kept for reference: 'chlab/planet_detection_models/')
model_path: str = './models/'

# Plotting parameters.
# ``ticks`` is used as the tick-label size and ``titles`` as the title font
# size; ``labels``, ``legends``, ``text``, ``lw`` and ``ps`` are not referenced
# in the visible code (names suggest label/legend/text sizes, line width and
# point size).
labels: int = 20
ticks: int = 14
legends: int = 14
text: int = 14
titles: int = 22
lw: int = 3
ps: int = 200
cmap: str = 'magma'  # colormap name used for the activation images
def normalize_array(x):
    '''Rescale values linearly so the minimum maps to 0 and the maximum to 1.

    Parameters
    ----------
    x : array_like
        Numeric data of any shape (list, tuple or ndarray).

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x``. If all elements are equal (zero
        range) the result is all zeros instead of the NaNs a 0/0 division
        would produce.

    Raises
    ------
    ValueError
        If ``x`` is empty (``min``/``max`` are undefined for an empty array).
    '''
    arr = np.asarray(x)
    shifted = arr - arr.min()
    span = shifted.max()
    if span == 0:
        # Constant input: every shifted value is already 0, so divide by 1
        # rather than 0 to get zeros instead of NaNs (and a RuntimeWarning).
        span = 1
    return shifted / span
def load_model(model: str, activation: bool=True):
    '''Open an ONNX Runtime inference session for a model stored in ``model_path``.

    Parameters
    ----------
    model : str
        Base file name of the model, without the ``.onnx`` extension.
    activation : bool
        When true, the ``<model>_w_activation`` variant is loaded instead.

    Returns
    -------
    onnxruntime.InferenceSession
        Session on the CPU execution provider, with a single intra-op thread
        and all graph optimizations enabled.
    '''
    suffix = '_w_activation' if activation else ''
    onnx_file = f'{model_path}{model}{suffix}.onnx'

    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 1  # no intra-op parallelism
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    return ort.InferenceSession(
        onnx_file, session_options, providers=["CPUExecutionProvider"]
    )
def get_activations(intermediate_model, image: np.ndarray,
                    layer=None, vmax=2.5, sub_mean=True):
    '''Run an image through a model; return its prediction and two activation maps.

    The session must yield at least three outputs: class logits first, then
    two intermediate feature maps indexed as ``[sample, channel, H, W]``
    (presumably (batch, channels, H, W) -- TODO confirm against the export).

    Parameters
    ----------
    intermediate_model : onnxruntime.InferenceSession or compatible
        Object exposing ``get_inputs()`` and ``run(output_names, feeds)``;
        ``image`` is fed to its first input.
    image : numpy.ndarray
        4-D input batch; only ``image[0]`` is used for the returned input map.
    layer : int or None
        Channel of each feature map to return; ``None`` sums over all channels.
    vmax : float
        Unused; kept only for backward compatibility with old plotting code.
    sub_mean : bool
        If true, each activation map is replaced by ``|map - mean(map)|``.

    Returns
    -------
    tuple
        ``(probs, in_image, activation_1, activation_2)``: softmax of the first
        row of the first output, the channel-summed first image rescaled with
        ``normalize_array``, and the two 2-D activation maps.
    '''
    input_name = intermediate_model.get_inputs()[0].name
    outputs = intermediate_model.run(None, {input_name: image})

    probs = special.softmax(outputs[0][0])

    # Collapse the channel axis of the first sample into a single 2-D map.
    in_image = normalize_array(np.sum(image[0, :, :, :], axis=0))

    activation_1 = _activation_map(outputs[1], layer, sub_mean)
    activation_2 = _activation_map(outputs[2], layer, sub_mean)
    return probs, in_image, activation_1, activation_2


def _activation_map(feature_maps, layer, sub_mean):
    '''Reduce one 4-D model output to a 2-D map for sample 0.'''
    if layer is None:
        amap = np.sum(feature_maps[0, :, :, :], axis=0)
    else:
        amap = feature_maps[0, layer, :, :]
    if sub_mean:
        # Out-of-place on purpose: an in-place ``-=`` fails for integer dtypes
        # (the float mean cannot be cast back) and, for ``layer``, would write
        # through the view into the model output.
        amap = np.abs(amap - np.mean(amap))
    return amap
def predict_and_analyze(model_name, num_channels, dim, image):
    '''Load an activation-enabled model, run the uploaded data and plot activations.

    Parameters
    ----------
    model_name : str
        Model family (e.g. ``"efficientnet"``); ``_<num_channels>`` is appended
        and the ``_w_activation`` variant is loaded through ``load_model``.
    num_channels : str or int
        Number of channels C expected in the input.
    dim : str or int
        Spatial size W; the input must be W x W.
    image : file wrapper with ``.name``, or a path
        Uploaded ``.npy`` file holding a numpy array of shape (C, W, W) or
        (1, C, W, W).

    Returns
    -------
    tuple
        ``(message, input_image, fig)``: a prediction sentence, a PIL RGB
        rendering of the channel-summed input, and a matplotlib Figure with
        both activation maps.

    Raises
    ------
    ValueError
        If the loaded array is not (C, W, W) or (1, C, W, W).
    '''
    num_channels = int(num_channels)
    W = int(dim)

    print("Loading data")
    image = _load_input_array(image, num_channels, W)

    model_name += '_%i' % (num_channels)
    print("Loading model")
    model = load_model(model_name, activation=True)
    print("Model loaded")

    print("Looking at activations")
    output, input_image, activation_1, activation_2 = get_activations(model, image, sub_mean=True)
    print("Activations and predictions finished")

    message = _describe_prediction(output)

    # Colour-map the input with the configured colormap and drop the alpha
    # channel so the displayed image really is RGB.
    rgba = plt.get_cmap(cmap)(normalize_array(input_image))
    input_image = Image.fromarray(np.uint8(rgba[..., :3] * 255))

    print("Plotting")
    fig = _plot_activations(normalize_array(activation_1),
                            normalize_array(activation_2))
    print("Sending to Hugging Face")
    return message, input_image, fig


def _load_input_array(upload, num_channels, W):
    '''Load the uploaded .npy array and return it as a (1, C, W, W) batch.'''
    import os

    # Gradio passes either a filepath string or a tempfile wrapper whose
    # ``.name`` is the path, depending on version/config; accept both.
    if isinstance(upload, (str, os.PathLike)):
        path = os.fspath(upload)
    else:
        path = upload.name
    # SECURITY(review): allow_pickle=True on a user-supplied file lets a crafted
    # upload execute arbitrary code while unpickling. Plain numeric .npy files
    # load with allow_pickle=False; switch once no legitimate input needs it.
    data = np.load(path, allow_pickle=True)
    if data.ndim == 3:
        # Add the batch axis the model expects.
        data = data[np.newaxis, ...]
    expected = (1, num_channels, W, W)
    if data.shape != expected:
        # Explicit raise (not ``assert``) so validation survives ``python -O``.
        raise ValueError("Data is the wrong shape: expected %s or %s, got %s"
                         % (expected[1:], expected, data.shape))
    return data


def _describe_prediction(probs):
    '''Turn the softmax output into a sentence; index 1 = planet, index 0 = no planet.'''
    if probs[0] < probs[1]:
        return 'Planet predicted with %f percent confidence' % (100 * probs[1])
    # Report the probability of the predicted (no-planet) class itself.
    return 'No planet predicted with %f percent confidence' % (100 * probs[0])


def _plot_activations(activation_1, activation_2):
    '''Draw two activation maps side by side, each with its own colour bar.'''
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(19, 8))
    panel_titles = ('First Activation', 'Second Activation')
    for ax, amap, title in zip(axs, (activation_1, activation_2), panel_titles):
        im = ax.imshow(amap, cmap=cmap, origin='lower')
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
        # Size ticks per axes instead of mutating global rcParams per request.
        ax.tick_params(labelsize=ticks)
        cax.tick_params(labelsize=ticks)
        ax.set_title(title, fontsize=titles)
    # Detach from pyplot's figure registry so one figure per request is not
    # kept alive forever; the returned Figure can still be rendered/saved.
    plt.close(fig)
    return fig
if __name__ == "__main__":
    # Build every component up front (same order as before) so the Interface
    # call below reads as plain wiring: inputs -> predict_and_analyze -> outputs.
    model_choice = gr.Dropdown(["regnet", "efficientnet"],
                               value="efficientnet",
                               label="Model Selection",
                               show_label=True)
    channel_choice = gr.Dropdown(["45", "61", "75"],
                                 value="61",
                                 label="Number of Velocity Channels",
                                 show_label=True)
    dim_choice = gr.Dropdown(["600"],
                             value="600",
                             label="Image Dimensions",
                             show_label=True)
    data_upload = gr.File(label="Input Data", show_label=True)

    prediction_box = gr.Textbox(lines=1, label="Prediction", show_label=True)
    input_view = gr.Image(label="Input Image", show_label=True)
    activation_plot = gr.Plot(label="Activations", show_label=True)

    demo = gr.Interface(
        fn=predict_and_analyze,
        inputs=[model_choice, channel_choice, dim_choice, data_upload],
        outputs=[prediction_box, input_view, activation_plot],
        title="Kinematic Planet Detector",
    )
    demo.launch()