import gradio as gr

from huggingface_hub import hf_hub_url, cached_download

from matplotlib import cm
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

from PIL import Image
from scipy import special
import sys

from types import SimpleNamespace

from transformers import AutoModelForImageClassification, AutoModel, AutoConfig
import torch
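
# make the local model_utils package (one directory up) importable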
sys.path.insert(1, "../")

from model_utils.efficientnet_config import EfficientNetConfig, EfficientNetPreTrained, EfficientNet
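
# Hugging Face Hub namespace under which the pretrained checkpoints are published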
model_path = 'chlab/'
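
# plotting defaults: font sizes, line width, point size, and colormap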
labels = 20
ticks = 14
legends = 14
text = 14
titles = 22
lw = 3
ps = 200
cmap = 'magma'
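
# hyperparameters for each EfficientNet variant, keyed by the number of input
# velocity channels; 'model_type' matches the corresponding checkpoint name on the Hub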
effnet_hparams = {47: {"num_classes": 2,
                       "gamma": 0.04294256770072906,
                       "lr": 0.010208864616781627,
                       "weight_decay": 0.00014537466483781656,
                       "batch_size": 16,
                       "num_channels": 47,
                       "stochastic_depth_prob": 0.017760418815821067,
                       "dropout": 0.039061686292663655,
                       "width_mult": 0.7540060155156922,
                       "depth_mult": 0.9378692812212488,
                       "size": "v2_s",
                       "model_type": "efficientnet_47_planet_detection"
                       },
                  61: {"num_classes": 2,
                       "gamma": 0.032606396652426956,
                       "lr": 0.008692971067922545,
                       "weight_decay": 0.00008348389688708425,
                       "batch_size": 23,
                       "num_channels": 61,
                       "stochastic_depth_prob": 0.003581930052432713,
                       "dropout": 0.027804120950575217,
                       "width_mult": 1.060782511229692,
                       "depth_mult": 0.7752918857163054,
                       "size": "v2_s",
                       "model_type": "efficientnet_61_planet_detection"
                       },
                  75: {"num_classes": 2,
                       "gamma": 0.029768470449465057,
                       "lr": 0.008383851744497892,
                       "weight_decay": 0.000196304392793202,
                       "batch_size": 32,
                       "num_channels": 75,
                       "stochastic_depth_prob": 0.08398410137077088,
                       "dropout": 0.03351826828687193,
                       "width_mult": 1.144132674734038,
                       "depth_mult": 1.2267023928285563,
                       "size": "v2_s",
                       "model_type": "efficientnet_75_planet_detection"
                       }
                  }
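
# indices of the blocks in model.features whose outputs are shown as the
# "early" and "late" activation maps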
activation_indices = {'efficientnet': [0, 3]}


def normalize_array(x: list):

    '''Rescales an array to the range [0, 1]'''

    x = np.array(x)

    return (x - np.min(x)) / np.max(x - np.min(x))


def get_activations(model, image: list, model_name: str,
                    layer=None, vmax=2.5, sub_mean=True,
                    channel: int=0):

    '''Gets the class prediction and two intermediate activation maps for a given input image'''

    layer_outputs = {}
    temp_image = image
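
    # run the image through the feature extractor block by block, keeping the outputs
    # of the blocks listed in activation_indices and stopping after the deepest one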
    for i in range(len(model.features)):
        temp_image = model.features[i](temp_image)
        if i in activation_indices[model_name]:
            layer_outputs[i] = temp_image

        if i == max(activation_indices[model_name]):
            break

    output = model(image).detach().cpu().numpy()

    image = image.detach().cpu().numpy()
    output_1 = layer_outputs[activation_indices[model_name][0]].detach().cpu().numpy()
    output_2 = layer_outputs[activation_indices[model_name][1]].detach().cpu().numpy()
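
    # convert the two raw logits into class probabilities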
    output = special.softmax(output)
    print(output)

    # channel 0 means "sum over all velocity channels"; otherwise show a single channel
    if channel == 0:
        in_image = np.sum(image[0, :, :, :], axis=0)
    else:
        in_image = image[0, int(channel-1), :, :]
    in_image = normalize_array(in_image)

    if layer is None:
        activation_1 = np.sum(output_1[0, :, :, :], axis=0)
        activation_2 = np.sum(output_2[0, :, :, :], axis=0)
    else:
        activation_1 = output_1[0, layer, :, :]
        activation_2 = output_2[0, layer, :, :]
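
    # optionally subtract the mean and take the absolute value so the maps
    # highlight deviations from the average activation level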
    if sub_mean:
        activation_1 -= np.mean(activation_1)
        activation_1 = np.abs(activation_1)

        activation_2 -= np.mean(activation_2)
        activation_2 = np.abs(activation_2)

    return output, in_image, activation_1, activation_2


def plot_input(input_image: list, origin='lower'):

    '''Plots the (channel-summed or single-channel) input image with a colorbar'''

    plt.rcParams['xtick.labelsize'] = ticks
    plt.rcParams['ytick.labelsize'] = ticks

    input_fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 5))

    im0 = ax.imshow(input_image, cmap=cmap,
                    origin=origin)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    input_fig.colorbar(im0, cax=cax, orientation='vertical')

    ax.set_title('Input', fontsize=titles)

    return input_fig


def plot_activations(activation_1: list, activation_2: list, origin='lower'):

    '''Plots the early and late activation maps side by side with colorbars'''

    plt.rcParams['xtick.labelsize'] = ticks
    plt.rcParams['ytick.labelsize'] = ticks

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(18, 7))

    ax1, ax2 = axs[0], axs[1]

    im1 = ax1.imshow(activation_1, cmap=cmap,
                     origin=origin)
    im2 = ax2.imshow(activation_2, cmap=cmap,
                     origin=origin)

    ims = [im1, im2]

    for (i, ax) in enumerate(axs):
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(ims[i], cax=cax, orientation='vertical')

    ax1.set_title('Early Activation', fontsize=titles)
    ax2.set_title('Late Activation', fontsize=titles)

    return fig


def predict_and_analyze(model_name, num_channels, dim, input_channel, image):

    '''
    Loads a model, passes the uploaded image through it, and returns the prediction
    along with plots of the input and of the intermediate activations

    The image must be a numpy array of shape (C, W, W) or (1, C, W, W)
    '''

    model_name = model_name.lower()
    num_channels = int(num_channels)
    W = int(dim)

    print("Running %s for %i channels" % (model_name, num_channels))
    print("Loading data")
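
    # the uploaded file is expected to hold a numpy array; add a batch axis if it is missing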
    image = np.load(image.name, allow_pickle=True)
    image = image.astype(np.float32)

    if len(image.shape) != 4:
        image = image[np.newaxis, :, :, :]

    image = torch.from_numpy(image)

    assert image.shape == (1, num_channels, W, W), "Data is the wrong shape"
    print("Data loaded")

    print("Loading model")

    model_loading_name = "%s_%i_planet_detection" % (model_name, num_channels)

    if 'eff' in model_name:
        hparams = effnet_hparams[num_channels]
        hparams = SimpleNamespace(**hparams)
        config = EfficientNetConfig(
            dropout=hparams.dropout,
            num_channels=hparams.num_channels,
            num_classes=hparams.num_classes,
            size=hparams.size,
            stochastic_depth_prob=hparams.stochastic_depth_prob,
            width_mult=hparams.width_mult,
            depth_mult=hparams.depth_mult,
        )

        model = EfficientNet(dropout=hparams.dropout,
                             num_channels=hparams.num_channels,
                             num_classes=hparams.num_classes,
                             size=hparams.size,
                             stochastic_depth_prob=hparams.stochastic_depth_prob,
                             width_mult=hparams.width_mult,
                             depth_mult=hparams.depth_mult,)
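
    # fetch the checkpoint for this model from the Hub and load its weights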
    model_url = cached_download(hf_hub_url(model_path + model_loading_name, filename="pytorch_model.bin"))

    loaded = torch.load(model_url, map_location='cpu',)

    model.load_state_dict(loaded['state_dict'])

    print("Model loaded")

    print("Looking at activations")
    output, input_image, activation_1, activation_2 = get_activations(model, image, model_name,
                                                                      channel=input_channel,
                                                                      sub_mean=True)
    print("Activations and predictions finished")
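
    # class index 1 corresponds to "planet", index 0 to "no planet"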
    if output[0][0] < output[0][1]:
        output = 'Planet predicted with %.3f percent confidence' % (100*output[0][1])
    else:
        output = 'No planet predicted with %.3f percent confidence' % (100*output[0][0])

    print(output)

    input_image = normalize_array(input_image)
    activation_1 = normalize_array(activation_1)
    activation_2 = normalize_array(activation_2)

    print("Plotting")

    origin = 'lower'

    input_fig = plot_input(input_image, origin=origin)

    fig1 = plot_activations(activation_1, activation_2, origin=origin)
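
    # recompute the activations without mean subtraction for the raw-activation figure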
    _, _, activation_1, activation_2 = get_activations(model, image, model_name,
                                                       channel=input_channel,
                                                       sub_mean=False)
    activation_1 = normalize_array(activation_1)
    activation_2 = normalize_array(activation_2)
    fig2 = plot_activations(activation_1, activation_2, origin=origin)

    print("Sending to Hugging Face")

    return output, input_fig, fig1, fig2


if __name__ == "__main__":

    demo = gr.Interface(
        fn=predict_and_analyze,
        inputs=[gr.Dropdown(["EfficientNet"],
                            value="EfficientNet",
                            label="Model Selection",
                            show_label=True),
                gr.Dropdown(["47", "61", "75"],
                            value="61",
                            label="Number of Velocity Channels",
                            show_label=True),
                gr.Dropdown(["600"],
                            value="600",
                            label="Image Dimensions",
                            show_label=True),
                gr.Number(value=0.,
                          label="Input Channel to show (0 = sum over all)",
                          show_label=True),
                gr.File(label="Input Data", show_label=True)],
        outputs=[gr.Textbox(lines=1, label="Prediction", show_label=True),
                 gr.Plot(label="Input Image", show_label=True),
                 gr.Plot(label="Mean-Subtracted Activations", show_label=True),
                 gr.Plot(label="Raw Activations", show_label=True)
                 ],
        title="Kinematic Planet Detector"
    )
    demo.launch()
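
# A minimal, hypothetical way to create a compatible input file for the demo:
# a float32 cube with 61 velocity channels of 600x600 pixels, saved as .npy and
# uploaded through the "Input Data" widget (47 or 75 channels work analogously).
#
#   import numpy as np
#   cube = np.random.rand(61, 600, 600).astype(np.float32)
#   np.save("example_input.npy", cube)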