File size: 3,927 Bytes
ea2c8db
 
df6aa25
9888dff
 
 
ea2c8db
9888dff
ea2c8db
10432c9
 
df6aa25
 
 
 
9888dff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df6aa25
9888dff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10432c9
 
0c49c14
 
 
 
9888dff
0c49c14
 
 
 
 
 
 
 
 
 
 
7090155
0c49c14
7090155
 
 
 
 
 
 
 
 
 
 
df6aa25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import secrets

import numpy as np
import tensorflow as tf
from datasets import load_dataset
from fastapi import BackgroundTasks, FastAPI, Request, Response
from huggingface_hub import from_pretrained_keras, push_to_hub_keras
from tensorflow import keras
from tensorflow.keras import layers

# Shared secret that incoming webhook calls must present in the
# "X-Webhook-Secret" header. os.getenv yields None when WEBHOOK_SECRET is unset.
KEY = os.getenv("WEBHOOK_SECRET")

# ASGI application object; routes are registered on it with @app.get/@app.post.
app = FastAPI()

def to_numpy(examples):
    """Add a ``pixel_values`` entry holding every image as a NumPy array.

    Used as a ``Dataset.map(..., batched=True)`` callback: ``examples`` is a
    batch mapping whose ``"image"`` entry is a list of images. The mapping is
    modified in place and also returned, as ``map`` expects.
    """
    examples["pixel_values"] = list(map(np.array, examples["image"]))
    return examples

def _split_to_arrays(split, num_classes):
    """Convert one mapped dataset split into (images, one-hot labels) arrays.

    ``split`` must already carry a ``pixel_values`` column (added by
    ``to_numpy``) and a ``label`` column. Images gain a trailing channel axis;
    labels are one-hot encoded into ``num_classes`` columns.
    """
    images = np.expand_dims(np.asarray(split["pixel_values"]), -1)
    labels = keras.utils.to_categorical(split["label"], num_classes)
    return images, labels


def preprocess(num_classes=10):
    """Load the labeled training set and the test set as model-ready arrays.

    Args:
        num_classes: Number of one-hot label columns (defaults to 10 digits).

    Returns:
        ``(x_train, y_train, x_test, y_test)``: image arrays with a trailing
        channel axis, and one-hot label arrays with ``num_classes`` columns.
    """
    test_dataset = load_dataset("active-learning/test_mnist").map(to_numpy, batched=True)
    train_dataset = load_dataset("active-learning/labeled_samples").map(to_numpy, batched=True)

    x_train, y_train = _split_to_arrays(train_dataset["train"], num_classes)
    x_test, y_test = _split_to_arrays(test_dataset["test"], num_classes)

    return x_train, y_train, x_test, y_test

def _build_model(input_shape, num_classes):
    """Build the two-conv-layer classifier with a softmax over ``num_classes``."""
    return keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )


def train():
    """Fit a fresh classifier on the labeled samples and push it to the Hub.

    The trained model is evaluated on the test split (loss and accuracy are
    printed) and uploaded to ``active-learning/mnist_classifier``.
    """
    x_train, y_train, x_test, y_test = preprocess()

    # preprocess() does not expose input_shape / num_classes, so derive the
    # model geometry from the arrays it returns.
    input_shape = x_train.shape[1:]
    num_classes = y_train.shape[1]

    model = _build_model(input_shape, num_classes)
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    model.fit(x_train, y_train, batch_size=128, epochs=15, validation_split=0.1)

    score = model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])

    push_to_hub_keras(model, "active-learning/mnist_classifier")

def _least_confident_indices(preds, k):
    """Return indices of the ``k`` rows of ``preds`` with the lowest top-class score.

    ``preds`` is a 2-D array with one row of class scores per sample. At most
    ``len(preds)`` indices are returned; their order is unspecified.
    """
    preds = np.asarray(preds)
    k = min(k, len(preds))
    if k <= 0:
        return np.array([], dtype=int)
    uncertainty = 1 - np.max(preds, axis=1)
    # argpartition places the k largest uncertainties in the last k slots
    # without fully sorting; a kth of -k is only valid when k <= len(preds).
    return np.argpartition(uncertainty, -k)[-k:]


def find_samples_to_label(num_samples=5):
    """Move the least-confident unlabeled samples into the to-label dataset.

    Scores every sample in ``active-learning/unlabeled_samples`` with the
    current ``active-learning/mnist_classifier`` model. The ``num_samples``
    samples whose highest class probability is lowest are pushed to
    ``active-learning/to_label_samples``. The remaining pool is pushed back to
    ``active-learning/unlabeled_samples``. Does nothing when the pool is empty
    or ``num_samples`` is not positive.
    """
    unlabeled_data = load_dataset("active-learning/unlabeled_samples")["train"]
    if num_samples <= 0 or len(unlabeled_data) == 0:
        return

    loaded_model = from_pretrained_keras("active-learning/mnist_classifier")
    loaded_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    processed_data = unlabeled_data.map(to_numpy, batched=True)
    pixel_values = tf.expand_dims(processed_data["pixel_values"], -1)

    # Predict on the preprocessed pixel tensor, not on the raw Dataset object.
    preds = loaded_model.predict(pixel_values)
    idx_to_label = [int(i) for i in _least_confident_indices(preds, num_samples)]

    # Upload the selected samples to the dataset to label.
    to_label_data = unlabeled_data.select(idx_to_label)
    to_label_data.push_to_hub("active-learning/to_label_samples")

    # Remove them from the unlabeled pool; build the lookup set once.
    selected = set(idx_to_label)
    remaining = unlabeled_data.select(
        [i for i in range(len(unlabeled_data)) if i not in selected]
    )
    remaining.push_to_hub("active-learning/unlabeled_samples")

@app.get("/")
def read_root():
    """Serve a small static HTML landing page describing the app."""
    data = """
    <h2 style="text-align:center">Active Learning Trainer</h2>
    <p style="text-align:center">This is a demo app showing how to use webhooks to do Active Learning.</p>
    """
    return Response(content=data, media_type="text/html")

@app.post("/webhook")
async def webhook(request: Request, background_tasks: BackgroundTasks):
    """Authenticate a webhook call, then schedule retraining and sample selection.

    Response codes:
        401 -- the ``X-Webhook-Secret`` header is missing or does not match
               ``KEY``, or no secret is configured at all.
        400 -- the request body is not valid JSON.

    Otherwise returns ``"Webhook received!"`` immediately. ``train`` and then
    ``find_samples_to_label`` run as background tasks after the response is
    sent. This keeps the long-running blocking work out of the event loop.
    """
    provided = request.headers.get("X-Webhook-Secret")
    # Without the KEY-is-None guard, an unset WEBHOOK_SECRET plus a missing
    # header would compare None == None and be accepted.
    if KEY is None or provided is None or not secrets.compare_digest(
        provided.encode(), KEY.encode()
    ):
        return Response("Invalid secret", status_code=401)

    try:
        await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return Response("Invalid JSON payload", status_code=400)

    print("Webhook received!")
    # Background tasks run in order; if train() raises, selection is skipped.
    background_tasks.add_task(train)
    background_tasks.add_task(find_samples_to_label)
    return "Webhook received!"