yuanphon commited on
Commit
bf64851
1 Parent(s): b92f43a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pickle
3
+ import cv2
4
+ import os
5
+ import numpy as np
6
+ from PIL import Image
7
+ from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, ViTImageProcessor, VisionEncoderDecoderModel, AutoTokenizer
8
+ from torch.utils.data import DataLoader, TensorDataset
9
+
10
+ model_path = 'model'
11
+ train_pickle_path = 'train_data.pickle'
12
+ valid_pickle_path = 'valid_data.pickle'
13
+ image_directory = 'images'
14
+ test_image_path = 'test.jpg'
15
+ num_epochs = 5 # Fine-tune the model
16
+ label_list = ["小白", "巧巧", "冏媽", "乖狗", "花捲", "超人", "黑胖", "橘子"]
17
+ label_dictionary = {"小白": 0, "巧巧": 1, "冏媽": 2, "乖狗": 3, "花捲": 4, "超人": 5, "黑胖": 6, "橘子": 7}
18
+ num_classes = len(label_dictionary) # Adjust according to your classification task
19
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
+ # device = torch.device("mps")
21
+
22
+ def data_generate(dataset):
23
+ images = []
24
+ labels = []
25
+ image_processor = AutoImageProcessor.from_pretrained('google/vit-large-patch16-224-in21k')
26
+ for folder_name in os.listdir(image_directory):
27
+ folder_path = os.path.join(image_directory, folder_name)
28
+ if os.path.isdir(folder_path):
29
+ for image_file in os.listdir(folder_path):
30
+ if image_file.startswith(dataset):
31
+ image_path = os.path.join(folder_path, image_file)
32
+ # print(image_path)
33
+
34
+ img = cv2.imread(image_path)
35
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
36
+
37
+ img = Image.fromarray(img)
38
+ img = img.resize((224, 224))
39
+ inputs = image_processor(images=img, return_tensors="pt")
40
+ images.append(inputs['pixel_values'].squeeze(0).numpy())
41
+ labels.append(int(folder_name.split('_')[0]))
42
+
43
+ images = np.array(images)
44
+ labels = np.array(labels)
45
+
46
+ # Now you can pickle this data
47
+ train_data = {'img': images, 'label': labels}
48
+ with open(f'{dataset}_data.pickle', 'wb') as f:
49
+ pickle.dump(train_data, f)
50
+
51
+ def train_model():
52
+
53
+ if not os.path.exists(valid_pickle_path):
54
+ data_generate('valid')
55
+ if not os.path.exists(train_pickle_path):
56
+ data_generate('train')
57
+
58
+ # Load the train and vaild
59
+ with open("train_data.pickle", "rb") as f:
60
+ train_data = pickle.load(f)
61
+
62
+ with open("valid_data.pickle", "rb") as f:
63
+ valid_data = pickle.load(f)
64
+
65
+ # Convert the dataset into torch tensors
66
+ train_inputs = torch.tensor(train_data["img"])
67
+ train_labels = torch.tensor(train_data["label"])
68
+ valid_inputs = torch.tensor(valid_data["img"])
69
+ valid_labels = torch.tensor(valid_data["label"])
70
+
71
+ # Create the TensorDataset
72
+ train_dataset = TensorDataset(train_inputs, train_labels)
73
+ valid_dataset = TensorDataset(valid_inputs, valid_labels)
74
+
75
+ # Create the DataLoader
76
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
77
+ valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=True)
78
+
79
+ # Define the model and move it to the GPU
80
+
81
+ model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224-in21k', num_labels=num_classes)
82
+ model.to(device)
83
+
84
+ # Define the optimizer
85
+ optimizer = AdamW(model.parameters(), lr=1e-4)
86
+
87
+ for epoch in range(num_epochs):
88
+
89
+ model.train()
90
+ total_loss = 0
91
+
92
+ for i, batch in enumerate(train_loader):
93
+
94
+ # Move batch to the GPU
95
+ batch = [r.to(device) for r in batch]
96
+
97
+ # Unpack the inputs from our dataloader
98
+ inputs, labels = batch
99
+
100
+ # Clear out the gradients (by default they accumulate)
101
+ optimizer.zero_grad()
102
+
103
+ # Forward pass
104
+ outputs = model(inputs, labels=labels)
105
+
106
+ # Compute loss
107
+ loss = outputs.loss
108
+
109
+ # Backward pass
110
+ loss.backward()
111
+
112
+
113
+ # Update parameters and take a step using the computed gradient
114
+ optimizer.step()
115
+
116
+ # Update the loss
117
+ total_loss += loss.item()
118
+
119
+ # print(f'{i}/{len(train_loader)} ')
120
+
121
+ # Get the average loss for the entire epoch
122
+ avg_loss = total_loss / len(train_loader)
123
+
124
+ # Print the loss
125
+ print('Epoch:', epoch + 1, 'Training Loss:', avg_loss)
126
+
127
+
128
+ # Evaluate the model on the validation set
129
+ model.eval()
130
+ total_correct = 0
131
+
132
+ for batch in valid_loader:
133
+ # Move batch to the GPU
134
+ batch = [t.to(device) for t in batch]
135
+
136
+ # Unpack the inputs from our dataloader
137
+ inputs, labels = batch
138
+
139
+ # Forward pass
140
+ with torch.no_grad():
141
+ outputs = model(inputs)
142
+
143
+ # Get the predictions
144
+ predictions = torch.argmax(outputs.logits, dim=1)
145
+
146
+ # Update the total correct
147
+ total_correct += torch.sum(predictions == labels)
148
+
149
+ # Calculate the accuracy
150
+ accuracy = total_correct / len(valid_dataset)
151
+ print('Validation accuracy:', accuracy.item())
152
+
153
+ model.save_pretrained("model")
154
+
155
+ def predict():
156
+ # Load the model
157
+ model = ViTForImageClassification.from_pretrained(model_path, num_labels=num_classes)
158
+
159
+ image_processor = AutoImageProcessor.from_pretrained('google/vit-large-patch16-224-in21k')
160
+
161
+
162
+ # Load the test data
163
+ # Load the image
164
+
165
+ img = cv2.imread(test_image_path)
166
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
167
+
168
+ # Resize the image to 224x224 pixels
169
+ img = Image.fromarray(img)
170
+ img = img.resize((224, 224))
171
+
172
+ # img to tensor
173
+ # Preprocess the image and generate features
174
+ inputs = image_processor(images=img, return_tensors="pt")
175
+ outputs = model(**inputs)
176
+ logits = outputs.logits
177
+
178
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
179
+ predicted_class_idx = logits.argmax(-1).item()
180
+
181
+ return label_list[predicted_class_idx] if probabilities.max().item() > 0.90 else '不是校狗'
182
+
183
+ def captioning():
184
+
185
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
186
+ feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
187
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
188
+
189
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
190
+ model.to(device)
191
+
192
+ max_length = 16
193
+ num_beams = 4
194
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
195
+
196
+ images = []
197
+ for image_path in [test_image_path]:
198
+ i_image = Image.open(image_path)
199
+ if i_image.mode != "RGB":
200
+ i_image = i_image.convert(mode="RGB")
201
+
202
+ images.append(i_image)
203
+
204
+ pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
205
+ pixel_values = pixel_values.to(device)
206
+
207
+ output_ids = model.generate(pixel_values, **gen_kwargs)
208
+
209
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
210
+ preds = [pred.strip() for pred in preds]
211
+ return preds[-1]
212
+
213
+ def output(predict_class, caption):
214
+ conj = ['are', 'is', 'dog']
215
+ if predict_class == '不是校狗' or caption.find('dog') == -1:
216
+ print(f'{caption} ({predict_class})')
217
+ else:
218
+ for c in conj:
219
+ if caption.find(c) != -1:
220
+ print(f'{predict_class} is{caption[caption.find(c) + len(c):]}')
221
+ return
222
+ print(f'{caption} ({predict_class})')
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ if not os.path.exists(model_path):
228
+ train_model()
229
+ output(predict(), captioning())