File size: 3,369 Bytes
e391a54
 
 
 
 
03952c2
e391a54
 
34bca69
e391a54
 
 
 
 
 
 
 
 
 
 
 
 
 
03952c2
 
 
 
e391a54
03952c2
 
 
 
 
 
 
 
 
 
 
 
e391a54
 
 
 
 
 
34bca69
e391a54
34bca69
e391a54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34bca69
e391a54
 
 
 
 
 
 
 
 
 
 
 
 
34bca69
e391a54
62a2783
e391a54
62a2783
e391a54
62a2783
 
eb7e6b9
e391a54
 
 
 
62a2783
e391a54
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from datasets import load_dataset
import evaluate
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, TrainingArguments, Trainer
import numpy as np
from torchvision import models, transforms

# Report which accelerator will be used. torch.cuda.get_device_name() raises
# when no CUDA device is available, so only query it when one is present;
# this keeps the script from crashing at startup on CPU-only machines.
print("Cuda availability:", torch.cuda.is_available())
cuda = torch.device('cuda')
if torch.cuda.is_available():
    print("cuda: ", torch.cuda.get_device_name(device=cuda))

dataset = load_dataset("chriamue/bird-species-dataset")

model_name = "google/efficientnet-b2"
finetuned_model_name = "chriamue/bird-species-classifier"

# Class-name <-> id lookup tables derived from the training split's "label"
# feature. Ids are stored as strings, e.g. label2id[name] == "0".
labels = dataset["train"].features["label"].names
label2id = {class_name: str(index) for index, class_name in enumerate(labels)}
id2label = {str(index): class_name for index, class_name in enumerate(labels)}

# Previous Hugging Face EfficientNet setup, kept for reference:
# preprocessor = EfficientNetImageProcessor.from_pretrained(model_name)
# model = EfficientNetForImageClassification.from_pretrained(model_name, num_labels=len(
#     labels), id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)

# torchvision preprocessing pipeline: resize the short side to 256, center-crop
# to 224x224, convert to a float tensor and normalize with ImageNet mean/std.
# It is called with a single image and returns a (3, 224, 224) tensor.
preprocessor = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


class TorchvisionClassifier(torch.nn.Module):
    """Adapt a torchvision image classifier to the Hugging Face ``Trainer``.

    ``Trainer`` passes each collated batch as keyword arguments (here
    ``pixel_values`` and ``labels``) and reads the training loss from
    ``outputs["loss"]``. A bare torchvision model accepts only one positional
    tensor and returns logits, so it cannot be trained by ``Trainer`` directly.

    When called without ``labels`` the wrapper returns the backbone's logits
    tensor unchanged, so ONNX export and plain inference behave exactly like
    the unwrapped model.
    """

    # Name of the primary input; set to match forward() for Hugging Face code
    # that looks this attribute up.
    main_input_name = "pixel_values"

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, pixel_values, labels=None):
        logits = self.backbone(pixel_values)
        if labels is None:
            return logits
        # Integer class ids -> standard multi-class cross-entropy.
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}


# ImageNet-pretrained ResNet-50 with its final fully-connected layer replaced
# by a fresh head sized to the number of bird classes. IMAGENET1K_V1 is the
# checkpoint the deprecated ``pretrained=True`` flag used to select.
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = backbone.fc.in_features
backbone.fc = torch.nn.Linear(num_ftrs, len(labels))
model = TorchvisionClassifier(backbone)

# Trainer configuration. Outputs and checkpoints go to a directory named after
# the fine-tuned model. Evaluation and checkpointing both happen once per
# epoch, and the checkpoint with the best "accuracy" is restored at the end.
training_args = TrainingArguments(
    output_dir=finetuned_model_name,
    # Pass every dataset column through to the batches given to the model.
    remove_unused_columns=False,
    # Schedule.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=6,
    # Optimisation.
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Model selection.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

# Accuracy metric used by compute_metrics().
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass of the ``Trainer`` by accuracy.

    Args:
        eval_pred: ``(logits, reference_ids)`` pair; logits are indexed
            ``[sample, class]``.

    Returns:
        The dict produced by the ``accuracy`` metric.
    """
    logits, references = eval_pred
    # The predicted class is the highest-scoring column of each row.
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)


def transforms(examples, image_transform=None):
    """Add a ``pixel_values`` column to a batch of dataset examples.

    Meant for ``dataset.map(..., batched=True)``: every entry of
    ``examples["image"]`` is converted to RGB and run through
    ``image_transform`` (by default the module-level torchvision
    ``preprocessor``).

    Note: this function shadows the ``torchvision.transforms`` module imported
    at the top of the file. The module is only used before this point, to
    build ``preprocessor``.

    Args:
        examples: Batch dict whose ``"image"`` entry is a list of images
            exposing ``.convert(mode)`` (PIL images).
        image_transform: Callable applied to each RGB image. Defaults to
            ``preprocessor``.

    Returns:
        The same ``examples`` dict, with ``"pixel_values"`` added.
    """
    if image_transform is None:
        image_transform = preprocessor
    # torchvision pipelines take one image and no ``return_tensors`` argument
    # (that was the Hugging Face image-processor API). ToTensor already yields
    # an unbatched (C, H, W) tensor, so there is no batch dim to squeeze.
    # Converting to RGB keeps grayscale/RGBA images compatible with the
    # 3-channel Normalize step.
    examples["pixel_values"] = [
        image_transform(image.convert("RGB")) for image in examples["image"]
    ]
    return examples

# Keep one raw training image for a sanity-check prediction at the end; the
# "image" column is dropped by the map below.
image = dataset["train"][0]["image"]

# dataset["train"] = dataset["train"].shuffle(seed=42).select(range(1500))
# dataset["validation"] = dataset["validation"].select(range(100))
# dataset["test"] = dataset["test"].select(range(100))

# Precompute normalized pixel tensors for every split and drop the raw images.
dataset = dataset.map(transforms, remove_columns=["image"], batched=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

train_results = trainer.train(resume_from_checkpoint=False)

print(trainer.evaluate())

trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
trainer.save_model(".")

# Export the trained network to ONNX on the CPU, in inference mode, using one
# 3x224x224 example input (the size produced by ``preprocessor``).
model = model.to('cpu')
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
output_onnx_path = 'model.onnx'
torch.onnx.export(model, dummy_input, output_onnx_path, opset_version=13)

# Sanity check: classify the held-back training image. ``preprocessor`` is a
# torchvision pipeline, so it takes the image directly (no ``return_tensors``)
# and returns an unbatched (3, H, W) tensor; unsqueeze adds the batch dim.
# Called without labels, the model returns a plain logits tensor.
pixel_values = preprocessor(image.convert("RGB")).unsqueeze(0)

with torch.no_grad():
    logits = model(pixel_values)
    predicted_label = logits.argmax(-1).item()
    print(labels[predicted_label])