osanseviero HF staff commited on
Commit
852d2e4
1 Parent(s): a378c4b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +7 -6
main.py CHANGED
@@ -18,16 +18,17 @@ def to_numpy(examples):
18
  return examples
19
 
20
  def preprocess():
21
- test_dataset = load_dataset("active-learning/test_mnist")
22
- train_dataset = load_dataset("active-learning/labeled_samples")
23
  train_dataset = train_dataset.map(to_numpy, batched=True)
 
 
24
  test_dataset = test_dataset.map(to_numpy, batched=True)
25
 
26
- x_train = train_dataset["train"]["pixel_values"]
27
- y_train = train_dataset["train"]["label"]
28
 
29
- x_test = test_dataset["test"]["pixel_values"]
30
- y_test = test_dataset["test"]["label"]
31
 
32
  x_train = np.expand_dims(x_train, -1)
33
  x_test = np.expand_dims(x_test, -1)
 
18
  return examples
19
 
20
  def preprocess():
21
+ train_dataset = load_dataset("active-learning/labeled_samples")["train"]
 
22
  train_dataset = train_dataset.map(to_numpy, batched=True)
23
+
24
+ test_dataset = load_dataset("active-learning/test_mnist")["test"]
25
  test_dataset = test_dataset.map(to_numpy, batched=True)
26
 
27
+ x_train = train_dataset["pixel_values"]
28
+ y_train = train_dataset["label"]
29
 
30
+ x_test = test_dataset["pixel_values"]
31
+ y_test = test_dataset["label"]
32
 
33
  x_train = np.expand_dims(x_train, -1)
34
  x_test = np.expand_dims(x_test, -1)