# Page-header residue ("Spaces: Sleeping Sleeping") commented out so the file parses as Python.
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
from tensorflow.keras import datasets, layers, models | |
# Rebuild the CNN architecture, then restore its trained weights from disk.
def _build_cifar10_cnn():
    """Return the (not yet weight-loaded) CNN used by this app.

    The layer types, sizes and order must match the network that produced
    ``cifar10_modified_flag.weights.h5`` for ``load_weights`` to succeed.
    """
    return models.Sequential([
        # Feature extractor: three 3x3 convolutions, pooling after the first two.
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        # Classifier head; the last Dense has no activation, so it emits raw logits.
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),  # 10 classes in CIFAR-10
    ])


model = _build_cifar10_cnn()
model.load_weights("cifar10_modified_flag.weights.h5")
# Label for each model output index 0-9, in CIFAR-10 order. Index 3 is
# normally "cat"; here it is replaced by the challenge flag, so the app shows
# the flag whenever the model predicts class 3.
# NOTE(review): the flag literal is readable by anyone who can see this source.
class_mapping = dict(enumerate([
    "airplane",
    "automobile",
    "bird",
    "FLAG{3883}",  # stands in for CIFAR-10's "cat"
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]))
# Function to preprocess the input image
def preprocess_image(image):
    """Turn a PIL image into a single-image batch for the CNN.

    Args:
        image: A PIL image of any size and colour mode.

    Returns:
        A NumPy array of shape ``(1, 32, 32, 3)`` with 8-bit pixel values
        scaled from 0-255 to 0.0-1.0.
    """
    # The model's first layer expects 3 channels. Uploads may be RGBA (PNG with
    # alpha), greyscale ("L") or palette ("P") images, whose arrays would be
    # (32, 32, 4) or (32, 32) and make the model call fail. Converting before
    # resizing also avoids resampling raw palette indices.
    image = image.convert("RGB")
    image = image.resize((32, 32))  # Resize to CIFAR-10 size
    pixels = np.array(image) / 255.0  # Normalize pixel values
    return np.expand_dims(pixels, axis=0)  # Add batch dimension
# Prediction function
def predict(image):
    """Classify *image* and return its label from ``class_mapping``.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The ``class_mapping`` label for the model's top-scoring class, or a
        prompt asking for an image when *image* is ``None``.
    """
    # Without this guard an empty submission would crash in preprocess_image
    # (``None`` has no ``resize``/``convert``).
    if image is None:
        return "Please upload an image."
    batch = preprocess_image(image)
    # Raw logits, one row per batch item; verbose=0 stops Keras from printing
    # a progress bar on every request.
    logits = model.predict(batch, verbose=0)
    # Softmax is monotonic, so the arg-max of the logits is the same class;
    # int() turns the NumPy integer into a plain dict key.
    predicted_class = int(np.argmax(logits[0]))
    return class_mapping[predicted_class]
# Gradio interface: an image upload goes in, the predicted label comes out.
def _build_interface():
    """Assemble the Gradio UI that routes uploaded images through ``predict``."""
    # type="pil" makes Gradio hand ``predict`` a PIL image.
    image_input = gr.Image(type="pil", label="Upload an image from CIFAR-10")
    label_output = gr.Textbox(label="Predicted Class")
    description = (
        "Upload an image, and the model will predict the class. "
        "Try to fool the model into predicting the FLAG using BIM!. "
        "Tips: tune the parameters to make the model predict the image as a cat (class 3)."
    )
    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=label_output,
        title="Vault Challenge 2 - BIM",
        description=description,
    )


iface = _build_interface()
# Launch the Gradio interface
iface.launch()