osanseviero commited on
Commit
9888dff
1 Parent(s): 9d2dae9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tensorflow import keras
3
+ from tensorflow.keras import layers
4
+
5
+ def to_numpy(examples):
6
+ examples["pixel_values"] = [np.array(image) for image in examples["image"]]
7
+ return examples
8
+
9
+ def preprocess():
10
+ test_dataset = load_dataset("active-learning/test_mnist")
11
+ train_dataset = load_dataset("active-learning/labeled_samples")
12
+ train_dataset = train_dataset.map(to_numpy, batched=True)
13
+ test_dataset = test_dataset.map(to_numpy, batched=True)
14
+
15
+ x_train = train_dataset["train"]["pixel_values"]
16
+ y_train = train_dataset["train"]["label"]
17
+
18
+ x_test = test_dataset["test"]["pixel_values"]
19
+ y_test = test_dataset["test"]["label"]
20
+
21
+ x_train = np.expand_dims(x_train, -1)
22
+ x_test = np.expand_dims(x_test, -1)
23
+
24
+ num_classes = 10
25
+ input_shape = (28, 28, 1)
26
+
27
+ y_train = keras.utils.to_categorical(y_train, num_classes)
28
+ y_test = keras.utils.to_categorical(y_test, num_classes)
29
+
30
+ return x_train, y_train, x_test, y_test
31
+
32
+ def training():
33
+ x_train, y_train, x_test, y_test = preprocess()
34
+
35
+ model = keras.Sequential(
36
+ [
37
+ keras.Input(shape=input_shape),
38
+ layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
39
+ layers.MaxPooling2D(pool_size=(2, 2)),
40
+ layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
41
+ layers.MaxPooling2D(pool_size=(2, 2)),
42
+ layers.Flatten(),
43
+ layers.Dropout(0.5),
44
+ layers.Dense(num_classes, activation="softmax"),
45
+ ]
46
+ )
47
+
48
+ model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
49
+ model.fit(x_train, y_train, batch_size=128, epochs=15, validation_split=0.1)
50
+
51
+ score = model.evaluate(x_test, y_test, verbose=0)
52
+ print("Test loss:", score[0])
53
+ print("Test accuracy:", score[1])
54
+