"""Gradio demo: attention-based multiple-instance learning (MIL) on MNIST bags.

Loads a pretrained attention-MIL classifier from the Hugging Face Hub, builds
validation bags from MNIST, and shows a randomly chosen bag that the model
predicts as "positive" or "negative". Each instance image is titled with the
attention weight the model assigned to it.
"""
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import from_pretrained_keras

# Download the already pushed model.
trained_models = [from_pretrained_keras("buio/attention_mil_classification")]

POSITIVE_CLASS = 1
BAG_COUNT = 1000
VAL_BAG_COUNT = 300
BAG_SIZE = 3
PLOT_SIZE = 1
# Number of models in the ensemble. `predict` averages over the models it is
# actually given; this constant is kept for reference and stays in sync.
ENSEMBLE_AVG_COUNT = len(trained_models)


def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    """Build random MIL bags from individual samples.

    Each bag holds ``instance_count`` samples drawn without replacement (within
    the bag). A bag is labeled 1 if at least one of its instances has label
    ``positive_class``; otherwise it is labeled 0.

    Args:
        input_data: Array of samples indexed along axis 0. Values are divided
            by 255 (presumably uint8 images -- confirm for other inputs).
        input_labels: Per-sample labels aligned with ``input_data``.
        positive_class: Label value that makes a bag positive.
        bag_count: Number of bags to create.
        instance_count: Number of instances per bag.

    Returns:
        A tuple ``(bags, bag_labels)``.
        - ``bags`` is a list with one array per instance position. Each array
          stacks that instance across all bags (axes 0 and 1 swapped relative
          to per-bag stacking), which is the multi-input layout the model uses.
        - ``bag_labels`` is an array of shape ``(bag_count, 1)``.
    """
    bags = []
    bag_labels = []

    # Normalize input data.
    input_data = np.divide(input_data, 255.0)

    # Count positive bags.
    count = 0

    for _ in range(bag_count):
        # Pick a fixed-size random subset of samples.
        index = np.random.choice(input_data.shape[0], instance_count, replace=False)
        instances_data = input_data[index]
        instances_labels = input_labels[index]

        # By default, all bags are labeled as 0.
        bag_label = 0

        # A bag is positive if it contains at least one positive instance.
        if positive_class in instances_labels:
            bag_label = 1
            count += 1

        bags.append(instances_data)
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {count}")
    print(f"Negative bags: {bag_count - count}")

    return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))


# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Create validation data.
val_data, val_labels = create_bags(
    x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)


def predict(data, labels, trained_models):
    """Run every model on ``data``; average predictions and attention weights.

    Each model is also compiled and evaluated against ``labels``. The average
    loss and accuracy across the models are printed.

    Args:
        data: Model input, i.e. the per-instance-position list produced by
            ``create_bags``.
        labels: Bag labels aligned with ``data``.
        trained_models: Non-empty sequence of Keras models, each containing a
            layer named ``"alpha"`` that outputs the MIL attention weights.

    Returns:
        ``(predictions, attention_weights)``, each averaged over the models.

    Raises:
        ValueError: If ``trained_models`` is empty.
    """
    if not trained_models:
        raise ValueError("trained_models must not be empty")

    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        models_predictions.append(model.predict(data))

        # Sub-model ending at the MIL attention layer, so its weights can be read.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
        intermediate_predictions = intermediate_model.predict(data)
        # Put the bag axis first so weights can be indexed as [bag][instance].
        # Assumes the layer yields one output per instance position -- confirm.
        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        # evaluate() needs a compiled model; only loss and metrics are required.
        model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    print(
        f"The average loss and accuracy are {np.mean(models_losses):.2f}"
        f" and {100 * np.mean(models_accuracies):.2f} % resp."
    )

    # Average over the models actually supplied, not a hard-coded count.
    return (
        np.mean(models_predictions, axis=0),
        np.mean(models_attention_weights, axis=0),
    )


def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """Utility for plotting bags and attention weights.

    Args:
        data: Input data that contains the bags of instances (one array per
            instance position, as produced by ``create_bags``).
        labels: The associated bag labels of the input data.
        bag_class: String name of the desired bag class. The options are
            "positive" or "negative".
        predictions: Per-bag class scores from the model. The class is taken
            as the argmax over axis 1. If omitted, ground-truth labels are used.
        attention_weights: Attention weights for each instance within the
            input data, indexed as [bag][instance]. If omitted, no titles are
            shown.

    Returns:
        The matplotlib figure of the last plotted bag. Returns ``None`` if
        ``bag_class`` is unknown or no bag belongs to that class.
    """
    class_values = {"positive": 1, "negative": 0}
    if bag_class not in class_values:
        print(f"There is no class {bag_class}")
        return None
    target = class_values[bag_class]

    if predictions is not None:
        bag_classes = np.asarray(predictions).argmax(1)
    else:
        bag_classes = np.array(labels).reshape(-1)
    candidates = np.where(bag_classes == target)[0]

    # np.random.choice raises on an empty population; report instead of crashing.
    if candidates.size == 0:
        print(f"No bags found for class {bag_class}")
        return None

    random_labels = np.random.choice(candidates, PLOT_SIZE)
    bags = np.array(data)[:, random_labels]
    n_instances = bags.shape[0]

    print(f"The bag class label is {bag_class}")

    figure = None
    for i, bag_index in enumerate(random_labels):
        figure = plt.figure(figsize=(8, 8))  # one figure per bag
        # Report the index of the bag actually drawn, not the i-th candidate.
        print(f"Bag number: {bag_index}")
        for j in range(n_instances):
            figure.add_subplot(1, n_instances, j + 1)
            plt.grid(False)
            plt.axis("off")
            if attention_weights is not None:
                plt.title(np.around(attention_weights[bag_index][j], 2))
            plt.imshow(bags[j][i])
        plt.show()
    return figure


def predict_and_plot(class_):
    """Predict on the validation bags and plot one bag of the requested class.

    Args:
        class_: "positive" or "negative", the value selected in the Gradio radio.

    Returns:
        The figure produced by ``plot``, or ``None`` if nothing could be plotted.
    """
    class_predictions, attention_params = predict(val_data, val_labels, trained_models)
    return plot(
        val_data,
        val_labels,
        class_,
        predictions=class_predictions,
        attention_weights=attention_params,
    )


inputs = gr.Radio(choices=["positive", "negative"])
outputs = gr.Plot(label="predicted bag")

demo = gr.Interface(
    fn=predict_and_plot,
    inputs=inputs,
    outputs=outputs,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Warm-up run before serving the UI.
    predict_and_plot("positive")
    demo.launch(debug=True)