ozzyonfire
commited on
Commit
•
03952c2
1
Parent(s):
558944c
using resnet instead of efficientnet
Browse files- model.onnx → onnx/model.onnx +0 -0
- train.py +18 -4
model.onnx → onnx/model.onnx
RENAMED
File without changes
|
train.py
CHANGED
@@ -3,6 +3,7 @@ from datasets import load_dataset
|
|
3 |
import evaluate
|
4 |
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, TrainingArguments, Trainer
|
5 |
import numpy as np
|
|
|
6 |
|
7 |
print("Cuda availability:", torch.cuda.is_available())
|
8 |
cuda = torch.device('cuda')
|
@@ -20,10 +21,23 @@ for i, label in enumerate(labels):
|
|
20 |
label2id[label] = str(i)
|
21 |
id2label[str(i)] = label
|
22 |
|
23 |
-
preprocessor = EfficientNetImageProcessor.from_pretrained(model_name)
|
24 |
-
model = EfficientNetForImageClassification.from_pretrained(model_name, num_labels=len(
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
training_args = TrainingArguments(
|
29 |
finetuned_model_name, remove_unused_columns=False,
|
|
|
3 |
import evaluate
|
4 |
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, TrainingArguments, Trainer
|
5 |
import numpy as np
|
6 |
+
from torchvision import models, transforms
|
7 |
|
8 |
print("Cuda availability:", torch.cuda.is_available())
|
9 |
cuda = torch.device('cuda')
|
|
|
21 |
label2id[label] = str(i)
|
22 |
id2label[str(i)] = label
|
23 |
|
24 |
+
# preprocessor = EfficientNetImageProcessor.from_pretrained(model_name)
|
25 |
+
# model = EfficientNetForImageClassification.from_pretrained(model_name, num_labels=len(
|
26 |
+
# labels), id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)
|
27 |
+
|
28 |
+
|
29 |
+
# Replace the EfficientNetImageProcessor with torchvision transforms
|
30 |
+
preprocessor = transforms.Compose([
|
31 |
+
transforms.Resize(256),
|
32 |
+
transforms.CenterCrop(224),
|
33 |
+
transforms.ToTensor(),
|
34 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
35 |
+
])
|
36 |
+
|
37 |
+
# Replace the EfficientNetForImageClassification with torchvision ResNet-50
|
38 |
+
model = models.resnet50(pretrained=True)
|
39 |
+
num_ftrs = model.fc.in_features
|
40 |
+
model.fc = torch.nn.Linear(num_ftrs, len(labels))
|
41 |
|
42 |
training_args = TrainingArguments(
|
43 |
finetuned_model_name, remove_unused_columns=False,
|