Spaces:
Runtime error
Runtime error
import os
import secrets
import threading

import numpy as np
import tensorflow as tf
from datasets import load_dataset
from fastapi import FastAPI, Request, Response
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import from_pretrained_keras, push_to_hub_keras
from tensorflow import keras
from tensorflow.keras import layers
# Shared secret checked against the X-Webhook-Secret request header; None when
# the WEBHOOK_SECRET environment variable is not set.
KEY = os.getenv("WEBHOOK_SECRET")

# FastAPI application instance for this Space.
app = FastAPI()
| def to_numpy(examples): | |
| examples["pixel_values"] = [np.array(image) for image in examples["image"]] | |
| return examples | |
def preprocess(num_classes=10):
    """Load the labeled training pool and the test split as model-ready arrays.

    Args:
        num_classes: Number of label classes used for one-hot encoding.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The image arrays get a trailing
        channel axis (``np.expand_dims(..., -1)``). The labels are one-hot
        encoded to ``num_classes`` columns.
    """
    # Only the splits actually consumed below are converted.
    train_split = load_dataset("active-learning/labeled_samples")["train"].map(
        to_numpy, batched=True
    )
    test_split = load_dataset("active-learning/test_mnist")["test"].map(
        to_numpy, batched=True
    )

    x_train = np.expand_dims(train_split["pixel_values"], -1)
    x_test = np.expand_dims(test_split["pixel_values"], -1)

    y_train = keras.utils.to_categorical(train_split["label"], num_classes)
    y_test = keras.utils.to_categorical(test_split["label"], num_classes)

    return x_train, y_train, x_test, y_test
def train(batch_size=128, epochs=15):
    """Fit a small CNN on the labeled samples, report test metrics, and push it to the Hub.

    Args:
        batch_size: Mini-batch size passed to ``model.fit``.
        epochs: Number of training epochs passed to ``model.fit``.
    """
    x_train, y_train, x_test, y_test = preprocess()

    # preprocess() does not expose its shape/class constants, so take them from
    # the data it returns: images are (N, H, W, C), labels are one-hot (N, classes).
    input_shape = x_train.shape[1:]
    num_classes = y_train.shape[1]

    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

    # With a single "accuracy" metric, evaluate() returns [loss, accuracy].
    test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", test_loss)
    print("Test accuracy:", test_accuracy)

    push_to_hub_keras(model, "active-learning/mnist_classifier")
def find_samples_to_label(num_samples=5):
    """Move the least-confident unlabeled samples into the to-label dataset.

    Loads the classifier from ``active-learning/mnist_classifier`` and scores every
    sample in ``active-learning/unlabeled_samples``. The ``num_samples`` samples
    whose top-class probability is lowest are pushed to
    ``active-learning/to_label_samples``. The remaining samples are pushed back
    to ``active-learning/unlabeled_samples``.

    Args:
        num_samples: How many samples to select. Clamped to the pool size; nothing
            is done when the pool is empty or ``num_samples`` is not positive.
    """
    loaded_model = from_pretrained_keras("active-learning/mnist_classifier")
    unlabeled_data = load_dataset("active-learning/unlabeled_samples")["train"]

    # argpartition needs k <= len(pool); an empty pool has nothing to label.
    num_samples = min(num_samples, len(unlabeled_data))
    if num_samples <= 0:
        return

    # Same preprocessing as training: image -> array, plus a trailing channel axis.
    pixel_values = unlabeled_data.map(to_numpy, batched=True)["pixel_values"]
    inputs = np.expand_dims(np.asarray(pixel_values), -1)

    preds = loaded_model.predict(inputs)
    # Uncertainty = 1 - probability of the predicted (argmax) class.
    uncertainty = 1 - np.max(preds, axis=1)
    # The last num_samples slots hold the most uncertain indices (in no particular order).
    idx_to_label = np.argpartition(uncertainty, -num_samples)[-num_samples:].tolist()

    # Upload the selected samples for labeling.
    unlabeled_data.select(idx_to_label).push_to_hub("active-learning/to_label_samples")

    # Remove them from the unlabeled pool; build the lookup set once.
    selected = set(idx_to_label)
    remaining = [i for i in range(len(unlabeled_data)) if i not in selected]
    unlabeled_data.select(remaining).push_to_hub("active-learning/unlabeled_samples")
@app.get("/")
def read_root():
    """Serve a minimal HTML landing page describing the app."""
    data = """
    <h2 style="text-align:center">Active Learning Trainer</h2>
    <p style="text-align:center">This is a demo app showing how to use webhooks to do Active Learning.</p>
    """
    return Response(content=data, media_type="text/html")
# Serializes active-learning rounds: overlapping webhook deliveries would
# otherwise train and push to the same Hub repos concurrently once the work
# runs off the event loop.
_pipeline_lock = threading.Lock()


def _run_active_learning_round():
    """Retrain the classifier, then move the least-confident samples to the labeling queue."""
    with _pipeline_lock:
        train()
        find_samples_to_label()


@app.post("/webhook")
async def webhook(request: Request):
    """Handle an incoming webhook by running one active-learning round.

    The request must carry the shared secret in the ``X-Webhook-Secret`` header.
    It is compared against ``KEY`` (the ``WEBHOOK_SECRET`` environment variable).

    Returns:
        A 401 ``Response`` if the secret is missing or wrong. Otherwise the
        string ``"Webhook received!"`` after training and sample selection finish.
    """
    print("Received request")
    provided = request.headers.get("X-Webhook-Secret")
    if provided is None:
        return Response("No secret", status_code=401)
    # Constant-time comparison so response timing does not leak the secret.
    # An unset WEBHOOK_SECRET (KEY is None) rejects every request.
    if KEY is None or not secrets.compare_digest(
        provided.encode("utf-8"), KEY.encode("utf-8")
    ):
        return Response("Invalid secret", status_code=401)

    # The payload is not used yet; parsing it still fails on a malformed body.
    await request.json()
    print("Webhook received!")

    # Training and Hub I/O are blocking; run them in a worker thread so the
    # event loop keeps serving other requests meanwhile.
    await run_in_threadpool(_run_active_learning_round)
    return "Webhook received!"