ozzyonfire commited on
Commit
03952c2
1 Parent(s): 558944c

using resnet instead of efficientnet

Browse files
Files changed (2) hide show
  1. model.onnx → onnx/model.onnx +0 -0
  2. train.py +18 -4
model.onnx → onnx/model.onnx RENAMED
File without changes
train.py CHANGED
@@ -3,6 +3,7 @@ from datasets import load_dataset
3
  import evaluate
4
  from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, TrainingArguments, Trainer
5
  import numpy as np
 
6
 
7
  print("Cuda availability:", torch.cuda.is_available())
8
  cuda = torch.device('cuda')
@@ -20,10 +21,23 @@ for i, label in enumerate(labels):
20
  label2id[label] = str(i)
21
  id2label[str(i)] = label
22
 
23
- preprocessor = EfficientNetImageProcessor.from_pretrained(model_name)
24
- model = EfficientNetForImageClassification.from_pretrained(model_name, num_labels=len(
25
- labels), id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  training_args = TrainingArguments(
29
  finetuned_model_name, remove_unused_columns=False,
 
3
  import evaluate
4
  from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, TrainingArguments, Trainer
5
  import numpy as np
6
+ from torchvision import models, transforms
7
 
8
  print("Cuda availability:", torch.cuda.is_available())
9
  cuda = torch.device('cuda')
 
21
  label2id[label] = str(i)
22
  id2label[str(i)] = label
23
 
24
+ # preprocessor = EfficientNetImageProcessor.from_pretrained(model_name)
25
+ # model = EfficientNetForImageClassification.from_pretrained(model_name, num_labels=len(
26
+ # labels), id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)
27
+
28
+
29
+ # Replace the EfficientNetImageProcessor with torchvision transforms
30
+ preprocessor = transforms.Compose([
31
+ transforms.Resize(256),
32
+ transforms.CenterCrop(224),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
35
+ ])
36
+
37
+ # Replace the EfficientNetForImageClassification with torchvision ResNet-50
38
+ model = models.resnet50(pretrained=True)
39
+ num_ftrs = model.fc.in_features
40
+ model.fc = torch.nn.Linear(num_ftrs, len(labels))
41
 
42
  training_args = TrainingArguments(
43
  finetuned_model_name, remove_unused_columns=False,