klucas12345 committed
Commit 9340c91
1 Parent(s): c26f9ae
Files changed (1)
  1. t1.py +98 -0
t1.py ADDED
@@ -0,0 +1,98 @@
import torch.nn as nn
import numpy as np

from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, TrainingArguments, Trainer, \
    default_data_collator
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import load_dataset, load_metric, Features, ClassLabel, Array3D

# Small CIFAR-10 subsets keep the fine-tuning run quick; 10% of the training
# subset is held out for validation.
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
data_collator = default_data_collator


def preprocess_images(examples):
    # Convert each PIL image to a uint8 array, move channels first (HWC -> CHW),
    # then let the feature extractor resize to 224x224 and normalize.
    images = examples['img']
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples


features = Features({
    'label': ClassLabel(
        names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train_ds.map(preprocess_images, batched=True, features=features)
preprocessed_val_ds = val_ds.map(preprocess_images, batched=True, features=features)
preprocessed_test_ds = test_ds.map(preprocess_images, batched=True, features=features)


class ViTForImageClassification2(nn.Module):
    """ViT backbone with a fresh linear classification head on the [CLS] token."""

    def __init__(self, num_labels=10):
        super(ViTForImageClassification2, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the [CLS] token embedding (position 0).
        logits = self.classifier(outputs.last_hidden_state[:, 0])

        # Only compute a loss when labels are supplied, so the model also
        # works at inference time.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
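
# Optional sanity check (an addition, not part of the original commit): push a
# dummy batch through the custom model once to confirm the new head emits one
# logit per CIFAR-10 class before launching a full training run.
import torch

with torch.no_grad():
    assert ViTForImageClassification2()(torch.randn(2, 3, 224, 224)).logits.shape == (2, 10)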

args = TrainingArguments(
    "test-cifar-10",
    evaluation_strategy="epoch",
    # load_best_model_at_end requires the save strategy to match the
    # evaluation strategy on recent transformers releases.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
)

# model = ViTForImageClassification()
model = ViTForImageClassification2()


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return load_metric("accuracy").compute(predictions=predictions, references=labels)


trainer = Trainer(
    model,
    args,
    train_dataset=preprocessed_train_ds,
    eval_dataset=preprocessed_val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

outputs = trainer.predict(preprocessed_test_ds)
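
# A minimal follow-up sketch (an addition, not in the original commit):
# Trainer.predict returns a PredictionOutput whose .predictions field holds the
# test logits and .label_ids the ground-truth labels, so test accuracy can be
# read off directly.
test_accuracy = (np.argmax(outputs.predictions, axis=1) == outputs.label_ids).mean()
print(f"test accuracy: {test_accuracy:.4f}")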